当前位置:   article > 正文

Transformers 源码解析(二百八十九)

Transformers 源码解析(二百八十九)

.\pipelines\zero_shot_image_classification.py

# 导入必要的模块和函数
from collections import UserDict  # 导入UserDict用于创建自定义字典
from typing import List, Union  # 导入List和Union用于类型提示

# 从上级目录的utils模块导入各种函数和类
from ..utils import (
    add_end_docstrings,  # 导入函数add_end_docstrings,用于添加文档字符串
    is_tf_available,  # 导入函数is_tf_available,检查是否可以使用TensorFlow
    is_torch_available,  # 导入函数is_torch_available,检查是否可以使用PyTorch
    is_vision_available,  # 导入函数is_vision_available,检查是否可以使用视觉处理功能
    logging,  # 导入logging模块,用于日志记录
    requires_backends,  # 导入requires_backends函数,用于检查后端依赖
)

# 从当前目录的base模块导入Pipeline类和build_pipeline_init_args函数
from .base import Pipeline, build_pipeline_init_args

# 如果可以使用视觉处理功能
if is_vision_available():
    # 从PIL库中导入Image模块,用于处理图像
    from PIL import Image
    # 从image_utils模块导入load_image函数,用于加载图像数据

# 如果可以使用PyTorch
if is_torch_available():
    # 导入torch库,用于深度学习任务
    import torch
    # 从models.auto模块导入模型映射名称字典

# 如果可以使用TensorFlow
if is_tf_available():
    # 从models.auto模块导入TensorFlow相关的模型映射名称字典
    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
    # 从tf_utils模块导入稳定的softmax函数,用于概率计算

# 获取当前模块的日志记录器对象
logger = logging.get_logger(__name__)

# 使用装饰器add_end_docstrings为ZeroShotImageClassificationPipeline类添加文档字符串
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ZeroShotImageClassificationPipeline(Pipeline):
    """
    Zero shot image classification pipeline using `CLIPModel`. This pipeline predicts the class of an image when you
    provide an image and a set of `candidate_labels`.

    Example:

    ```
    >>> from transformers import pipeline

    >>> classifier = pipeline(model="google/siglip-so400m-patch14-384")
    >>> classifier(
    ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    ...     candidate_labels=["animals", "humans", "landscape"],
    ... )
    [{'score': 0.965, 'label': 'animals'}, {'score': 0.03, 'label': 'humans'}, {'score': 0.005, 'label': 'landscape'}]

    >>> classifier(
    ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    ...     candidate_labels=["black and white", "photorealist", "painting"],
    ... )
    [{'score': 0.996, 'label': 'black and white'}, {'score': 0.003, 'label': 'photorealist'}, {'score': 0.0, 'label': 'painting'}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This image classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"zero-shot-image-classification"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-image-classification).
    """

    # 初始化函数,继承自Pipeline类
    def __init__(self, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 检查当前实例是否满足视觉后端的依赖
        requires_backends(self, "vision")
        
        # 根据当前框架选择适当的模型映射名称字典,用于后续任务
        self.check_model_type(
            TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
            if self.framework == "tf"
            else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
        )
    def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
        """
        将标签分配给作为输入传递的图像。

        Args:
            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                处理三种类型的图像:

                - 包含指向图像的 http 链接的字符串
                - 包含指向本地图像路径的字符串
                - 直接加载到 PIL 中的图像

            candidate_labels (`List[str]`):
                此图像的候选标签列表

            hypothesis_template (`str`, *可选*, 默认为 `"This is a photo of {}"`):
                与 *candidate_labels* 结合使用的句子,通过将占位符替换为 candidate_labels 尝试图像分类。
                然后使用 logits_per_image 估算可能性。

            timeout (`float`, *可选*, 默认为 None):
                从网络获取图像的最长等待时间(以秒为单位)。如果为 None,则不设置超时,调用可能会永远阻塞。

        Return:
            包含结果的字典列表,每个提议的标签一个字典。字典包含以下键:

            - **label** (`str`) -- 模型识别的标签之一。它是建议的 `candidate_label` 之一。
            - **score** (`float`) -- 模型为该标签分配的分数(介于0和1之间)。
        """
        return super().__call__(images, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        preprocess_params = {}
        if "candidate_labels" in kwargs:
            preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
        if "timeout" in kwargs:
            preprocess_params["timeout"] = kwargs["timeout"]
        if "hypothesis_template" in kwargs:
            preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]

        return preprocess_params, {}, {}

    def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}.", timeout=None):
        """
        预处理图像及其相关参数。

        Args:
            image: 图像数据
            candidate_labels (`List[str]`, optional): 图像的候选标签
            hypothesis_template (`str`, optional, defaults to `"This is a photo of {}."`):
                用于替换占位符生成假设句子的模板
            timeout (`float`, optional): 从网络获取图像的最长等待时间(以秒为单位)

        Returns:
            inputs: 包含预处理后数据的字典
        """
        image = load_image(image, timeout=timeout)  # 加载图像数据
        inputs = self.image_processor(images=[image], return_tensors=self.framework)  # 处理图像数据
        inputs["candidate_labels"] = candidate_labels  # 设置候选标签
        sequences = [hypothesis_template.format(x) for x in candidate_labels]  # 根据模板生成假设句子序列
        padding = "max_length" if self.model.config.model_type == "siglip" else True  # 根据模型类型设置填充方式
        text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=padding)  # 对假设句子序列进行tokenize
        inputs["text_inputs"] = [text_inputs]  # 设置文本输入
        return inputs
    # 定义一个方法用于模型推断,接收模型输入
    def _forward(self, model_inputs):
        # 弹出输入中的候选标签
        candidate_labels = model_inputs.pop("candidate_labels")
        # 弹出输入中的文本数据
        text_inputs = model_inputs.pop("text_inputs")
        
        # 如果文本输入的第一个元素是 UserDict 类型的对象
        if isinstance(text_inputs[0], UserDict):
            # 将文本输入重新赋值为第一个元素(UserDict对象)
            text_inputs = text_inputs[0]
        else:
            # 如果不是 UserDict 对象,则为批处理情况,取第一个元素的第一个元素
            # (这里假设 text_inputs 是一个二重嵌套列表,第一个元素是批处理的列表)
            text_inputs = text_inputs[0][0]

        # 使用模型进行推断,传入文本输入和模型输入
        outputs = self.model(**text_inputs, **model_inputs)

        # 构建模型输出字典,包括候选标签和模型的 logits
        model_outputs = {
            "candidate_labels": candidate_labels,
            "logits": outputs.logits_per_image,
        }
        return model_outputs

    # 定义一个方法用于后处理模型输出
    def postprocess(self, model_outputs):
        # 弹出模型输出中的候选标签
        candidate_labels = model_outputs.pop("candidate_labels")
        # 取出 logits,并在第一个维度上进行压缩,即去除维度为1的维度
        logits = model_outputs["logits"][0]

        # 根据不同的框架和模型类型进行处理概率
        if self.framework == "pt" and self.model.config.model_type == "siglip":
            # 对 logits 应用 sigmoid 函数,并在最后一个维度上进行压缩
            probs = torch.sigmoid(logits).squeeze(-1)
            # 将概率转换为列表
            scores = probs.tolist()
            # 如果 scores 不是列表,则转换为列表
            if not isinstance(scores, list):
                scores = [scores]
        elif self.framework == "pt":
            # 对 logits 应用 softmax 函数,并在最后一个维度上进行压缩
            probs = logits.softmax(dim=-1).squeeze(-1)
            # 将概率转换为列表
            scores = probs.tolist()
            # 如果 scores 不是列表,则转换为列表
            if not isinstance(scores, list):
                scores = [scores]
        elif self.framework == "tf":
            # 对 logits 应用稳定的 softmax 函数,并在最后一个维度上进行处理
            probs = stable_softmax(logits, axis=-1)
            # 将概率转换为 numpy 数组,再转换为列表
            scores = probs.numpy().tolist()
        else:
            # 如果框架不支持,则引发异常
            raise ValueError(f"Unsupported framework: {self.framework}")

        # 将概率分数与候选标签组成字典列表,并按分数降序排列
        result = [
            {"score": score, "label": candidate_label}
            for score, candidate_label in sorted(zip(scores, candidate_labels), key=lambda x: -x[0])
        ]
        return result
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215

.\pipelines\zero_shot_object_detection.py

from typing import Any, Dict, List, Union  # 导入需要的类型提示模块

from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends  # 导入自定义工具函数和模块
from .base import ChunkPipeline, build_pipeline_init_args  # 导入基础类和初始化函数构建器


if is_vision_available():  # 如果视觉处理模块可用
    from PIL import Image  # 导入PIL图像处理库中的Image模块
    from ..image_utils import load_image  # 从自定义图像处理工具中导入加载图像的函数

if is_torch_available():  # 如果PyTorch可用
    import torch  # 导入PyTorch模块
    from transformers.modeling_outputs import BaseModelOutput  # 导入模型输出基类
    from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES  # 导入零样本对象检测模型映射名称

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))  # 添加文档字符串的装饰器,指定初始化参数为具有图像处理器
class ZeroShotObjectDetectionPipeline(ChunkPipeline):  # 定义零样本对象检测流水线,继承自ChunkPipeline基类
    """
    Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of
    objects when you provide an image and a set of `candidate_labels`.

    Example:

    ```
    >>> from transformers import pipeline

    >>> detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
    >>> detector(
    ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
    ...     candidate_labels=["cat", "couch"],
    ... )
    [{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]

    >>> detector(
    ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    ...     candidate_labels=["head", "bird"],
    ... )
    [{'score': 0.119, 'label': 'bird', 'box': {'xmin': 71, 'ymin': 170, 'xmax': 410, 'ymax': 508}}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"zero-shot-object-detection"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-object-detection).
    """

    def __init__(self, **kwargs):  # 定义初始化方法,接受任意关键字参数
        super().__init__(**kwargs)  # 调用父类的初始化方法,传递所有接收到的关键字参数

        if self.framework == "tf":  # 如果当前框架是TensorFlow
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")  # 抛出错误,表明该类只在PyTorch中可用

        requires_backends(self, "vision")  # 确保必要的后端模块可用,这里要求视觉处理模块可用
        self.check_model_type(MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES)  # 检查当前模型类型是否符合零样本对象检测模型的映射名称

    def __call__(  # 定义对象实例可调用的方法
        self,
        image: Union[str, "Image.Image", List[Dict[str, Any]]],  # 图像参数可以是字符串、PIL图像对象或包含字典的列表
        candidate_labels: Union[str, List[str]] = None,  # 候选标签可以是字符串或字符串列表,默认为None
        **kwargs,  # 允许接收额外的关键字参数
   `
# 定义一个方法用于清理参数
def _sanitize_parameters(self, **kwargs):
    # 创建一个空的预处理参数字典
    preprocess_params = {}
    # 如果参数中包含超时(timeout),将其加入预处理参数中
    if "timeout" in kwargs:
        preprocess_params["timeout"] = kwargs["timeout"]
    
    # 创建一个空的后处理参数字典
    postprocess_params = {}
    # 如果参数中包含阈值(threshold),将其加入后处理参数中
    if "threshold" in kwargs:
        postprocess_params["threshold"] = kwargs["threshold"]
    # 如果参数中包含前 k 个(top_k),将其加入后处理参数中
    if "top_k" in kwargs:
        postprocess_params["top_k"] = kwargs["top_k"]
    
    # 返回预处理参数字典、空字典和后处理参数字典
    return preprocess_params, {}, postprocess_params

# 定义一个预处理方法
def preprocess(self, inputs, timeout=None):
    # 加载图像,并设定超时时间
    image = load_image(inputs["image"], timeout=timeout)
    # 获取候选标签
    candidate_labels = inputs["candidate_labels"]
    # 如果候选标签是字符串,则按逗号分隔
    if isinstance(candidate_labels, str):
        candidate_labels = candidate_labels.split(",")

    # 创建目标尺寸张量
    target_size = torch.tensor([[image.height, image.width]], dtype=torch.int32)
    
    # 遍历候选标签
    for i, candidate_label in enumerate(candidate_labels):
        # 使用分词器处理候选标签,返回张量
        text_inputs = self.tokenizer(candidate_label, return_tensors=self.framework)
        # 使用图像处理器处理图像,返回张量
        image_features = self.image_processor(image, return_tensors=self.framework)
        
        # 生成字典,包括是否最后一个、目标尺寸、候选标签及其它特征
        yield {
            "is_last": i == len(candidate_labels) - 1,
            "target_size": target_size,
            "candidate_label": candidate_label,
            **text_inputs,
            **image_features,
        }

# 定义一个前向方法
def _forward(self, model_inputs):
    # 弹出目标尺寸、候选标签和是否最后一个标志
    target_size = model_inputs.pop("target_size")
    candidate_label = model_inputs.pop("candidate_label")
    is_last = model_inputs.pop("is_last")

    # 使用模型处理输入,返回输出
    outputs = self.model(**model_inputs)

    # 创建模型输出字典,包括目标尺寸、候选标签、是否最后一个及其它输出
    model_outputs = {"target_size": target_size, "candidate_label": candidate_label, "is_last": is_last, **outputs}
    return model_outputs

# 定义一个后处理方法
def postprocess(self, model_outputs, threshold=0.1, top_k=None):
    # 存储结果列表
    results = []
    
    # 遍历模型输出
    for model_output in model_outputs:
        # 获取候选标签
        label = model_output["candidate_label"]
        # 将模型输出封装成基本模型输出对象
        model_output = BaseModelOutput(model_output)
        
        # 使用图像处理器后处理目标检测结果,返回输出
        outputs = self.image_processor.post_process_object_detection(
            outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
        )[0]

        # 遍历输出的分eshold, target_sizes=model_output["target_size"]
            )[0]

            # 遍历输出结果中的得分,生成包含得分、标签和边界框的结果字典,并添加到结果列表中
            for index in outputs["scores"].nonzero():
                score = outputs["scores"][index].item()
                box = self._get_bounding_box(outputs["boxes"][index][0])

                result = {"score": score, "label": label, "box": box}
                results.append(result)

        # 按得分倒序排列结果列表
        results = sorted(results, key=lambda x: x["score"], reverse=True)
        # 如果指定了 top_k 参数,则返回前 top_k 个结果
        if top_k:
            results = results[:top_k]

        return results
    # 定义一个方法 `_get_bounding_box`,用于将列表 [xmin, xmax, ymin, ymax] 转换为包含这些坐标的字典
    def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
        """
        Turns list [xmin, xmax, ymin, ymax] into dict { "xmin": xmin, ... }

        Args:
            box (`torch.Tensor`): Tensor containing the coordinates in corners format.

        Returns:
            bbox (`Dict[str, int]`): Dict containing the coordinates in corners format.
        """
        # 检查当前所用的深度学习框架是否为 PyTorch,若不是则抛出 ValueError 异常
        if self.framework != "pt":
            raise ValueError("The ZeroShotObjectDetectionPipeline is only available in PyTorch.")
        # 将输入的 box 张量转换为整数列表,并将其转换为 Python 中的标准列表形式
        xmin, ymin, xmax, ymax = box.int().tolist()
        # 创建包含坐标的字典 bbox,键为坐标名,值为对应的坐标值
        bbox = {
            "xmin": xmin,
            "ymin": ymin,
            "xmax": xmax,
            "ymax": ymax,
        }
        # 返回坐标字典 bbox
        return bbox
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190

.\pipelines\__init__.py

# 导入所需的模块和函数

import json  # 导入处理 JSON 数据的模块
import os  # 导入操作系统相关的功能模块
import warnings  # 导入警告处理模块
from pathlib import Path  # 导入处理路径的模块 Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union  # 导入类型提示相关的功能

from huggingface_hub import model_info  # 从 huggingface_hub 模块导入 model_info

# 从不同模块中导入所需的类和函数
from ..configuration_utils import PretrainedConfig  # 导入预训练配置类
from ..dynamic_module_utils import get_class_from_dynamic_module  # 导入从动态模块获取类的函数
from ..feature_extraction_utils import PreTrainedFeatureExtractor  # 导入预训练特征提取器类
from ..image_processing_utils import BaseImageProcessor  # 导入基础图像处理器类
from ..models.auto.configuration_auto import AutoConfig  # 导入自动配置类
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor  # 导入自动特征提取映射和自动特征提取器类
from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor  # 导入自动图像处理映射和自动图像处理器类
from ..models.auto.modeling_auto import AutoModelForDepthEstimation, AutoModelForImageToImage  # 导入自动深度估计模型和自动图像转换模型
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer  # 导入自动分词映射和自动分词器类
from ..tokenization_utils import PreTrainedTokenizer  # 导入预训练分词器类
from ..utils import (
    CONFIG_NAME,  # 导入配置文件名常量
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,  # 导入 Hugging Face 协作解决端点常量
    cached_file,  # 导入缓存文件函数
    extract_commit_hash,  # 导入提取提交哈希函数
    find_adapter_config_file,  # 导入查找适配器配置文件函数
    is_kenlm_available,  # 导入检查 kenlm 是否可用函数
    is_offline_mode,  # 导入检查是否离线模式函数
    is_peft_available,  # 导入检查 peft 是否可用函数
    is_pyctcdecode_available,  # 导入检查 pyctcdecode 是否可用函数
    is_tf_available,  # 导入检查是否有 TensorFlow 函数
    is_torch_available,  # 导入检查是否有 PyTorch 函数
    logging,  # 导入日志记录模块
)

# 从不同子模块导入具体的任务流水线类
from .audio_classification import AudioClassificationPipeline  # 导入音频分类任务流水线类
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline  # 导入自动语音识别任务流水线类
from .base import (  # 从基础模块导入多个类和函数
    ArgumentHandler,  # 导入参数处理器类
    CsvPipelineDataFormat,  # 导入 CSV 数据格式流水线类
    JsonPipelineDataFormat,  # 导入 JSON 数据格式流水线类
    PipedPipelineDataFormat,  # 导入管道数据格式流水线类
    Pipeline,  # 导入任务流水线基类
    PipelineDataFormat,  # 导入任务流水线数据格式基类
    PipelineException,  # 导入任务流水线异常类
    PipelineRegistry,  # 导入任务流水线注册表类
    get_default_model_and_revision,  # 导入获取默认模型和版本函数
    infer_framework_load_model,  # 导入推断框架加载模型函数
)

# 从不同子模块导入特定任务流水线类
from .conversational import Conversation, ConversationalPipeline  # 导入对话任务流水线类
from .depth_estimation import DepthEstimationPipeline  # 导入深度估计任务流水线类
from .document_question_answering import DocumentQuestionAnsweringPipeline  # 导入文档问答任务流水线类
from .feature_extraction import FeatureExtractionPipeline  # 导入特征提取任务流水线类
from .fill_mask import FillMaskPipeline  # 导入填充掩码任务流水线类
from .image_classification import ImageClassificationPipeline  # 导入图像分类任务流水线类
from .image_feature_extraction import ImageFeatureExtractionPipeline  # 导入图像特征提取任务流水线类
from .image_segmentation import ImageSegmentationPipeline  # 导入图像分割任务流水线类
from .image_to_image import ImageToImagePipeline  # 导入图像到图像任务流水线类
from .image_to_text import ImageToTextPipeline  # 导入图像到文本任务流水线类
from .mask_generation import MaskGenerationPipeline  # 导入生成掩码任务流水线类
from .object_detection import ObjectDetectionPipeline  # 导入对象检测任务流水线类
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline  # 导入问答任务流水线相关类和函数
# 导入表格问答模块中的参数处理器和管道
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
# 导入文本到文本生成模块中的摘要生成管道、文本到文本生成管道和翻译管道
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
# 导入文本分类模块中的文本分类管道
from .text_classification import TextClassificationPipeline
# 导入文本生成模块中的文本生成管道
from .text_generation import TextGenerationPipeline
# 导入文本到音频模块中的文本到音频管道
from .text_to_audio import TextToAudioPipeline
# 导入标记分类模块中的聚合策略、命名实体识别管道、标记分类参数处理器和标记分类管道
from .token_classification import (
    AggregationStrategy,
    NerPipeline,
    TokenClassificationArgumentHandler,
    TokenClassificationPipeline,
)
# 导入视频分类模块中的视频分类管道
from .video_classification import VideoClassificationPipeline
# 导入视觉问答模块中的视觉问答管道
from .visual_question_answering import VisualQuestionAnsweringPipeline
# 导入零样本音频分类模块中的零样本音频分类管道
from .zero_shot_audio_classification import ZeroShotAudioClassificationPipeline
# 导入零样本分类模块中的零样本分类参数处理器和零样本分类管道
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
# 导入零样本图像分类模块中的零样本图像分类管道
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
# 导入零样本目标检测模块中的零样本目标检测管道
from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline

# 如果 TensorFlow 可用,则导入相关模块和类
if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import (
        TFAutoModel,
        TFAutoModelForCausalLM,
        TFAutoModelForImageClassification,
        TFAutoModelForMaskedLM,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSeq2SeqLM,
        TFAutoModelForSequenceClassification,
        TFAutoModelForTableQuestionAnswering,
        TFAutoModelForTokenClassification,
        TFAutoModelForVision2Seq,
        TFAutoModelForZeroShotImageClassification,
    )

# 如果 PyTorch 可用,则导入相关模块和类
if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import (
        AutoModel,
        AutoModelForAudioClassification,
        AutoModelForCausalLM,
        AutoModelForCTC,
        AutoModelForDocumentQuestionAnswering,
        AutoModelForImageClassification,
        AutoModelForImageSegmentation,
        AutoModelForMaskedLM,
        AutoModelForMaskGeneration,
        AutoModelForObjectDetection,
        AutoModelForQuestionAnswering,
        AutoModelForSemanticSegmentation,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForSpeechSeq2Seq,
        AutoModelForTableQuestionAnswering,
        AutoModelForTextToSpectrogram,
        AutoModelForTextToWaveform,
        AutoModelForTokenClassification,
        AutoModelForVideoClassification,
        AutoModelForVision2Seq,
        AutoModelForVisualQuestionAnswering,
        AutoModelForZeroShotImageClassification,
        AutoModelForZeroShotObjectDetection,
    )

# 如果支持类型检查,则导入必要的模块
if TYPE_CHECKING:
    from ..modeling_tf_utils import TFPreTrainedModel
    from ..modeling_utils import PreTrainedModel
    from ..tokenization_utils_fast import PreTrainedTokenizerFast

# 获取日志记录器并命名空间化
logger = logging.get_logger(__name__)

# 注册所有支持的任务别名
TASK_ALIASES = {
    "sentiment-analysis": "text-classification",  # 情感分析任务的别名为文本分类
    "ner": "token-classification",  # 命名实体识别任务的别名为标记分类
    "vqa": "visual-question-answering",  # 视觉问答任务的别名为视觉问答
    "text-to-speech": "text-to-audio",  # 文本转语音任务的别名为文本到音频
}
# 支持的任务及其配置信息字典,每个任务对应一个字典条目
SUPPORTED_TASKS = {
    # 音频分类任务
    "audio-classification": {
        # 实现类为 AudioClassificationPipeline
        "impl": AudioClassificationPipeline,
        # TensorFlow 空元组,无特定的 TensorFlow 模型
        "tf": (),
        # 如果 Torch 可用,包含 AutoModelForAudioClassification 类
        "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
        # 默认模型为 wav2vec2-base-superb-ks,版本为 "372e048"
        "default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}},
        # 类型为音频
        "type": "audio",
    },
    # 自动语音识别任务
    "automatic-speech-recognition": {
        # 实现类为 AutomaticSpeechRecognitionPipeline
        "impl": AutomaticSpeechRecognitionPipeline,
        # TensorFlow 空元组,无特定的 TensorFlow 模型
        "tf": (),
        # 如果 Torch 可用,包含 AutoModelForCTC 和 AutoModelForSpeechSeq2Seq 类
        "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
        # 默认模型为 wav2vec2-base-960h,版本为 "55bb623"
        "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},
        # 类型为多模态
        "type": "multimodal",
    },
    # 文本转音频任务
    "text-to-audio": {
        # 实现类为 TextToAudioPipeline
        "impl": TextToAudioPipeline,
        # TensorFlow 空元组,无特定的 TensorFlow 模型
        "tf": (),
        # 如果 Torch 可用,包含 AutoModelForTextToWaveform 和 AutoModelForTextToSpectrogram 类
        "pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (),
        # 默认模型为 bark-small,版本为 "645cfba"
        "default": {"model": {"pt": ("suno/bark-small", "645cfba")}},
        # 类型为文本
        "type": "text",
    },
    # 特征提取任务
    "feature-extraction": {
        # 实现类为 FeatureExtractionPipeline
        "impl": FeatureExtractionPipeline,
        # 如果 TensorFlow 可用,包含 TFAutoModel 类
        "tf": (TFAutoModel,) if is_tf_available() else (),
        # 如果 Torch 可用,包含 AutoModel 类
        "pt": (AutoModel,) if is_torch_available() else (),
        # 默认模型为 distilbert-base-cased,版本为 "935ac13",同时支持 TensorFlow 和 Torch
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-cased", "935ac13"),
                "tf": ("distilbert/distilbert-base-cased", "935ac13"),
            }
        },
        # 类型为多模态
        "type": "multimodal",
    },
    # 文本分类任务
    "text-classification": {
        # 实现类为 TextClassificationPipeline
        "impl": TextClassificationPipeline,
        # 如果 TensorFlow 可用,包含 TFAutoModelForSequenceClassification 类
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        # 如果 Torch 可用,包含 AutoModelForSequenceClassification 类
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        # 默认模型为 distilbert-base-uncased-finetuned-sst-2-english,版本为 "af0f99b",同时支持 TensorFlow 和 Torch
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
                "tf": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
            },
        },
        # 类型为文本
        "type": "text",
    },
    # 标记分类任务
    "token-classification": {
        # 实现类为 TokenClassificationPipeline
        "impl": TokenClassificationPipeline,
        # 如果 TensorFlow 可用,包含 TFAutoModelForTokenClassification 类
        "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
        # 如果 Torch 可用,包含 AutoModelForTokenClassification 类
        "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
        # 默认模型为 bert-large-cased-finetuned-conll03-english,版本为 "f2482bf",同时支持 TensorFlow 和 Torch
        "default": {
            "model": {
                "pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
                "tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
            },
        },
        # 类型为文本
        "type": "text",
    },
    # 问答任务
    "question-answering": {
        # 实现类为 QuestionAnsweringPipeline
        "impl": QuestionAnsweringPipeline,
        # 如果 TensorFlow 可用,包含 TFAutoModelForQuestionAnswering 类
        "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
        # 如果 Torch 可用,包含 AutoModelForQuestionAnswering 类
        "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
        # 默认模型为 distilbert-base-cased-distilled-squad,版本为 "626af31",同时支持 TensorFlow 和 Torch
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-cased-distilled-squad", "626af31"),
                "tf": ("distilbert/distilbert-base-cased-distilled-squad", "626af31"),
            },
        },
        # 类型为文本
        "type": "text",
    },
    # 定义 table-question-answering 任务配置项
    "table-question-answering": {
        # 使用 TableQuestionAnsweringPipeline 处理该任务
        "impl": TableQuestionAnsweringPipeline,
        # 如果有 Torch 可用,则提供 Torch 模型
        "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
        # 如果有 TensorFlow 可用,则提供 TensorFlow 模型
        "tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (),
        # 默认模型设定
        "default": {
            "model": {
                # Torch 模型及其版本
                "pt": ("google/tapas-base-finetuned-wtq", "69ceee2"),
                # TensorFlow 模型及其版本
                "tf": ("google/tapas-base-finetuned-wtq", "69ceee2"),
            },
        },
        # 任务类型为文本处理
        "type": "text",
    },
    
    # 定义 visual-question-answering 任务配置项
    "visual-question-answering": {
        # 使用 VisualQuestionAnsweringPipeline 处理该任务
        "impl": VisualQuestionAnsweringPipeline,
        # 如果有 Torch 可用,则提供 Torch 模型
        "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
        # TensorFlow 模型部分为空,表示无 TensorFlow 模型
        "tf": (),
        # 默认模型设定
        "default": {
            "model": {
                # Torch 模型及其版本
                "pt": ("dandelin/vilt-b32-finetuned-vqa", "4355f59"),
            },
        },
        # 任务类型为多模态处理
        "type": "multimodal",
    },
    
    # 定义 document-question-answering 任务配置项
    "document-question-answering": {
        # 使用 DocumentQuestionAnsweringPipeline 处理该任务
        "impl": DocumentQuestionAnsweringPipeline,
        # 如果有 Torch 可用,则提供 Torch 模型
        "pt": (AutoModelForDocumentQuestionAnswering,) if is_torch_available() else (),
        # TensorFlow 模型部分为空,表示无 TensorFlow 模型
        "tf": (),
        # 默认模型设定
        "default": {
            "model": {
                # Torch 模型及其版本
                "pt": ("impira/layoutlm-document-qa", "52e01b3"),
            },
        },
        # 任务类型为多模态处理
        "type": "multimodal",
    },
    
    # 定义 fill-mask 任务配置项
    "fill-mask": {
        # 使用 FillMaskPipeline 处理该任务
        "impl": FillMaskPipeline,
        # 如果有 TensorFlow 可用,则提供 TensorFlow 模型
        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
        # 如果有 Torch 可用,则提供 Torch 模型
        "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
        # 默认模型设定
        "default": {
            "model": {
                # Torch 模型及其版本
                "pt": ("distilbert/distilroberta-base", "ec58a5b"),
                # TensorFlow 模型及其版本
                "tf": ("distilbert/distilroberta-base", "ec58a5b"),
            }
        },
        # 任务类型为文本处理
        "type": "text",
    },
    
    # 定义 summarization 任务配置项
    "summarization": {
        # 使用 SummarizationPipeline 处理该任务
        "impl": SummarizationPipeline,
        # 如果有 TensorFlow 可用,则提供 TensorFlow 模型
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        # 如果有 Torch 可用,则提供 Torch 模型
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        # 默认模型设定
        "default": {
            "model": {
                # Torch 模型及其版本
                "pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"),
                # TensorFlow 模型及其版本
                "tf": ("google-t5/t5-small", "d769bba")
            }
        },
        # 任务类型为文本处理
        "type": "text",
    },
    
    # translation 任务是特殊情况,参数化为 SRC 和 TGT 语言
    "translation": {
        # 使用 TranslationPipeline 处理该任务
        "impl": TranslationPipeline,
        # 如果有 TensorFlow 可用,则提供 TensorFlow 模型
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        # 如果有 Torch 可用,则提供 Torch 模型
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        # 默认模型设定
        "default": {
            # 设定不同的 SRC 和 TGT 语言对应的模型
            ("en", "fr"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
            ("en", "de"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
            ("en", "ro"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
        },
        # 任务类型为文本处理
        "type": "text",
    },
    "text2text-generation": {  # 文本到文本生成任务配置
        "impl": Text2TextGenerationPipeline,  # 使用 Text2TextGenerationPipeline 类实现
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),  # 如果 TensorFlow 可用,使用 TFAutoModelForSeq2SeqLM 模型
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),  # 如果 PyTorch 可用,使用 AutoModelForSeq2SeqLM 模型
        "default": {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},  # 默认模型配置
        "type": "text",  # 任务类型为文本生成
    },
    "text-generation": {  # 文本生成任务配置
        "impl": TextGenerationPipeline,  # 使用 TextGenerationPipeline 类实现
        "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),  # 如果 TensorFlow 可用,使用 TFAutoModelForCausalLM 模型
        "pt": (AutoModelForCausalLM,) if is_torch_available() else (),  # 如果 PyTorch 可用,使用 AutoModelForCausalLM 模型
        "default": {"model": {"pt": ("openai-community/gpt2", "6c0e608"), "tf": ("openai-community/gpt2", "6c0e608")}},  # 默认模型配置
        "type": "text",  # 任务类型为文本生成
    },
    "zero-shot-classification": {  # 零样本分类任务配置
        "impl": ZeroShotClassificationPipeline,  # 使用 ZeroShotClassificationPipeline 类实现
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),  # 如果 TensorFlow 可用,使用 TFAutoModelForSequenceClassification 模型
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),  # 如果 PyTorch 可用,使用 AutoModelForSequenceClassification 模型
        "default": {  # 默认配置
            "model": {  # 模型配置
                "pt": ("facebook/bart-large-mnli", "c626438"),  # PyTorch 使用 Facebook BART 大型 MNLI 模型
                "tf": ("FacebookAI/roberta-large-mnli", "130fb28"),  # TensorFlow 使用 Facebook RoBERTa 大型 MNLI 模型
            },
            "config": {  # 额外配置
                "pt": ("facebook/bart-large-mnli", "c626438"),  # PyTorch 使用相同的 BART 大型 MNLI 模型
                "tf": ("FacebookAI/roberta-large-mnli", "130fb28"),  # TensorFlow 使用相同的 RoBERTa 大型 MNLI 模型
            },
        },
        "type": "text",  # 任务类型为文本分类
    },
    "zero-shot-image-classification": {  # 零样本图像分类任务配置
        "impl": ZeroShotImageClassificationPipeline,  # 使用 ZeroShotImageClassificationPipeline 类实现
        "tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),  # 如果 TensorFlow 可用,使用 TFAutoModelForZeroShotImageClassification 模型
        "pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),  # 如果 PyTorch 可用,使用 AutoModelForZeroShotImageClassification 模型
        "default": {  # 默认配置
            "model": {  # 模型配置
                "pt": ("openai/clip-vit-base-patch32", "f4881ba"),  # PyTorch 使用 OpenAI CLIP-ViT Base 模型
                "tf": ("openai/clip-vit-base-patch32", "f4881ba"),  # TensorFlow 使用相同的 CLIP-ViT Base 模型
            }
        },
        "type": "multimodal",  # 任务类型为多模态
    },
    "zero-shot-audio-classification": {  # 零样本音频分类任务配置
        "impl": ZeroShotAudioClassificationPipeline,  # 使用 ZeroShotAudioClassificationPipeline 类实现
        "tf": (),  # TensorFlow 不适用于此任务,设为空元组
        "pt": (AutoModel,) if is_torch_available() else (),  # 如果 PyTorch 可用,使用 AutoModel 模型
        "default": {  # 默认配置
            "model": {  # 模型配置
                "pt": ("laion/clap-htsat-fused", "973b6e5"),  # PyTorch 使用 Laion CLAP-HTSAT-Fused 模型
            }
        },
        "type": "multimodal",  # 任务类型为多模态
    },
    "conversational": {  # 对话生成任务配置
        "impl": ConversationalPipeline,  # 使用 ConversationalPipeline 类实现
        "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),  # 如果 TensorFlow 可用,使用 TFAutoModelForSeq2SeqLM 和 TFAutoModelForCausalLM 模型
        "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),  # 如果 PyTorch 可用,使用 AutoModelForSeq2SeqLM 和 AutoModelForCausalLM 模型
        "default": {  # 默认配置
            "model": {"pt": ("microsoft/DialoGPT-medium", "8bada3b"), "tf": ("microsoft/DialoGPT-medium", "8bada3b")}  # 使用 Microsoft DialoGPT 中等模型
        },
        "type": "text",  # 任务类型为文本生成
    },
    {
        # 图像分类任务的配置
        "image-classification": {
            # 实现图像分类任务的流水线
            "impl": ImageClassificationPipeline,
            # TensorFlow 可用时的模型配置,包含自动图像分类模型
            "tf": (TFAutoModelForImageClassification,) if is_tf_available() else (),
            # PyTorch 可用时的模型配置,包含自动图像分类模型
            "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
            # 默认模型配置
            "default": {
                "model": {
                    # PyTorch 的默认模型为 VIT-base-patch16-224,版本为 5dca96d
                    "pt": ("google/vit-base-patch16-224", "5dca96d"),
                    # TensorFlow 的默认模型为 VIT-base-patch16-224,版本为 5dca96d
                    "tf": ("google/vit-base-patch16-224", "5dca96d"),
                }
            },
            # 任务类型为图像处理
            "type": "image",
        },
        # 图像特征提取任务的配置
        "image-feature-extraction": {
            # 实现图像特征提取任务的流水线
            "impl": ImageFeatureExtractionPipeline,
            # TensorFlow 可用时的模型配置,包含自动模型
            "tf": (TFAutoModel,) if is_tf_available() else (),
            # PyTorch 可用时的模型配置,包含自动模型
            "pt": (AutoModel,) if is_torch_available() else (),
            # 默认模型配置
            "default": {
                "model": {
                    # PyTorch 的默认模型为 VIT-base-patch16-224,版本为 29e7a1e183
                    "pt": ("google/vit-base-patch16-224", "29e7a1e183"),
                    # TensorFlow 的默认模型为 VIT-base-patch16-224,版本为 29e7a1e183
                    "tf": ("google/vit-base-patch16-224", "29e7a1e183"),
                }
            },
            # 任务类型为图像处理
            "type": "image",
        },
        # 图像分割任务的配置
        "image-segmentation": {
            # 实现图像分割任务的流水线
            "impl": ImageSegmentationPipeline,
            # TensorFlow 可用时的模型配置为空元组,表示不可用
            "tf": (),
            # PyTorch 可用时的模型配置,包含自动目标分割和语义分割模型
            "pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),
            # 默认模型配置,PyTorch 的默认模型为 DETR-resnet-50-panoptic,版本为 fc15262
            "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
            # 任务类型为多模态处理
            "type": "multimodal",
        },
        # 图像到文本任务的配置
        "image-to-text": {
            # 实现图像到文本任务的流水线
            "impl": ImageToTextPipeline,
            # TensorFlow 可用时的模型配置,包含自动视觉到序列模型
            "tf": (TFAutoModelForVision2Seq,) if is_tf_available() else (),
            # PyTorch 可用时的模型配置,包含自动视觉到序列模型
            "pt": (AutoModelForVision2Seq,) if is_torch_available() else (),
            # 默认模型配置,PyTorch 的默认模型为 VIT-GPT2-COCO-en,版本为 65636df
            "default": {
                "model": {
                    "pt": ("ydshieh/vit-gpt2-coco-en", "65636df"),
                    "tf": ("ydshieh/vit-gpt2-coco-en", "65636df"),
                }
            },
            # 任务类型为多模态处理
            "type": "multimodal",
        },
        # 目标检测任务的配置
        "object-detection": {
            # 实现目标检测任务的流水线
            "impl": ObjectDetectionPipeline,
            # TensorFlow 可用时的模型配置为空元组,表示不可用
            "tf": (),
            # PyTorch 可用时的模型配置,包含自动目标检测模型
            "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
            # 默认模型配置,PyTorch 的默认模型为 DETR-resnet-50,版本为 2729413
            "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
            # 任务类型为多模态处理
            "type": "multimodal",
        },
        # 零样本目标检测任务的配置
        "zero-shot-object-detection": {
            # 实现零样本目标检测任务的流水线
            "impl": ZeroShotObjectDetectionPipeline,
            # TensorFlow 可用时的模型配置为空元组,表示不可用
            "tf": (),
            # PyTorch 可用时的模型配置,包含自动零样本目标检测模型
            "pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (),
            # 默认模型配置,PyTorch 的默认模型为 OWL-ViT-base-patch32,版本为 17740e1
            "default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}},
            # 任务类型为多模态处理
            "type": "multimodal",
        },
        # 深度估计任务的配置
        "depth-estimation": {
            # 实现深度估计任务的流水线
            "impl": DepthEstimationPipeline,
            # TensorFlow 可用时的模型配置为空元组,表示不可用
            "tf": (),
            # PyTorch 可用时的模型配置,包含自动深度估计模型
            "pt": (AutoModelForDepthEstimation,) if is_torch_available() else (),
            # 默认模型配置,PyTorch 的默认模型为 DPT-large,版本为 e93beec
            "default": {"model": {"pt": ("Intel/dpt-large", "e93beec")}},
            # 任务类型为图像处理
            "type": "image",
        },
        # 视频分类任务的配置
        "video-classification": {
            # 实现视频分类任务的流水线
            "impl": VideoClassificationPipeline,
            # TensorFlow 可用时的模型配置为空元组,表示不可用
            "tf": (),
            # PyTorch 可用时的模型配置,包含自动视频分类模型
            "pt": (AutoModelForVideoClassification,) if is_torch_available() else (),
            # 默认模型配置,PyTorch 的默认模型为 VideoMae-base-finetuned-kinetics,版本为 4800870
            "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
            # 任务类型为视频处理
            "type": "video",
        },
    }
    # "mask-generation"任务配置
    "mask-generation": {
        # 使用MaskGenerationPipeline作为实现
        "impl": MaskGenerationPipeline,
        # TensorFlow环境下不需要额外模型
        "tf": (),
        # 如果有PyTorch环境,使用AutoModelForMaskGeneration作为模型
        "pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
        # 默认模型配置,使用Facebook的"facebook/sam-vit-huge"模型
        "default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},
        # 任务类型为多模态处理
        "type": "multimodal",
    },
    
    # "image-to-image"任务配置
    "image-to-image": {
        # 使用ImageToImagePipeline作为实现
        "impl": ImageToImagePipeline,
        # TensorFlow环境下不需要额外模型
        "tf": (),
        # 如果有PyTorch环境,使用AutoModelForImageToImage作为模型
        "pt": (AutoModelForImageToImage,) if is_torch_available() else (),
        # 默认模型配置,使用"caidas/swin2SR-classical-sr-x2-64"模型
        "default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "4aaedcb")}},
        # 任务类型为图像处理
        "type": "image",
    },
}

# 初始化空集合,用于存放没有特征提取器的任务
NO_FEATURE_EXTRACTOR_TASKS = set()
# 初始化空集合,用于存放没有图像处理器的任务
NO_IMAGE_PROCESSOR_TASKS = set()
# 初始化空集合,用于存放没有分词器的任务
NO_TOKENIZER_TASKS = set()

# 下面这些模型配置是特殊的,它们是通用的,适用于多种任务,意味着任何分词器/特征提取器都可能用于给定的模型,
# 因此我们无法使用静态定义的 TOKENIZER_MAPPING 和 FEATURE_EXTRACTOR_MAPPING 来查看模型是否定义了这些对象。
MULTI_MODEL_AUDIO_CONFIGS = {"SpeechEncoderDecoderConfig"}
MULTI_MODEL_VISION_CONFIGS = {"VisionEncoderDecoderConfig", "VisionTextDualEncoderConfig"}

# 遍历 SUPPORTED_TASKS 中的任务及其值
for task, values in SUPPORTED_TASKS.items():
    if values["type"] == "text":
        # 如果任务类型为文本,将其添加到没有特征提取器的任务集合中
        NO_FEATURE_EXTRACTOR_TASKS.add(task)
        # 如果任务类型为文本,将其添加到没有图像处理器的任务集合中
        NO_IMAGE_PROCESSOR_TASKS.add(task)
    elif values["type"] in {"image", "video"}:
        # 如果任务类型为图像或视频,将其添加到没有分词器的任务集合中
        NO_TOKENIZER_TASKS.add(task)
    elif values["type"] in {"audio"}:
        # 如果任务类型为音频,将其添加到没有分词器的任务集合中
        NO_TOKENIZER_TASKS.add(task)
        # 如果任务类型为音频,将其添加到没有图像处理器的任务集合中
        NO_IMAGE_PROCESSOR_TASKS.add(task)
    elif values["type"] != "multimodal":
        # 如果任务类型不是多模态,抛出异常,说明不支持的任务类型
        raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")

# 创建管道注册对象,使用支持的任务和任务别名作为参数
PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES)


def get_supported_tasks() -> List[str]:
    """
    返回支持的任务列表。
    """
    return PIPELINE_REGISTRY.get_supported_tasks()


def get_task(model: str, token: Optional[str] = None, **deprecated_kwargs) -> str:
    """
    根据模型和令牌返回任务字符串,支持废弃的参数。
    """
    # 弹出废弃的参数 use_auth_token,并赋值给 use_auth_token
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    
    # 如果 use_auth_token 不为 None,发出废弃警告信息
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        # 如果 token 不为 None,引发值错误,说明同时指定了 token 和 use_auth_token 参数
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        # 将 use_auth_token 赋值给 token
        token = use_auth_token

    # 如果处于离线模式,引发运行时错误,说明不能在离线模式下自动推断任务
    if is_offline_mode():
        raise RuntimeError("You cannot infer task automatically within `pipeline` when using offline mode")
    
    # 尝试获取模型信息,如果出现异常,引发运行时错误
    try:
        info = model_info(model, token=token)
    except Exception as e:
        raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
    
    # 如果信息中没有 pipeline_tag 属性,引发运行时错误,说明模型没有正确设置 pipeline_tag 来自动推断任务
    if not info.pipeline_tag:
        raise RuntimeError(
            f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
        )
    
    # 如果 info 的 library_name 属性不是 "transformers",引发运行时错误,说明该模型应该使用其他库而不是 transformers
    if getattr(info, "library_name", "transformers") != "transformers":
        raise RuntimeError(f"This model is meant to be used with {info.library_name} not with transformers")
    
    # 返回从 info 中推断的 pipeline_tag 作为任务
    task = info.pipeline_tag
    return task


def check_task(task: str) -> Tuple[str, Dict, Any]:
    """
    检查传入的任务字符串,验证其正确性,并返回默认的管道和模型类,以及默认模型(如果存在)。
    """
    Args:
        task (`str`):
            指定要返回的流水线的任务。目前接受的任务包括:

            - `"audio-classification"`
            - `"automatic-speech-recognition"`
            - `"conversational"`
            - `"depth-estimation"`
            - `"document-question-answering"`
            - `"feature-extraction"`
            - `"fill-mask"`
            - `"image-classification"`
            - `"image-feature-extraction"`
            - `"image-segmentation"`
            - `"image-to-text"`
            - `"image-to-image"`
            - `"object-detection"`
            - `"question-answering"`
            - `"summarization"`
            - `"table-question-answering"`
            - `"text2text-generation"`
            - `"text-classification"`(别名为 `"sentiment-analysis"` 可用)
            - `"text-generation"`
            - `"text-to-audio"`(别名为 `"text-to-speech"` 可用)
            - `"token-classification"`(别名为 `"ner"` 可用)
            - `"translation"`
            - `"translation_xx_to_yy"`
            - `"video-classification"`
            - `"visual-question-answering"`(别名为 `"vqa"` 可用)
            - `"zero-shot-classification"`
            - `"zero-shot-image-classification"`
            - `"zero-shot-object-detection"`

    Returns:
        返回一个元组,包含标准化后的任务名称 `normalized_task`(去除了别名和选项)、任务默认设置字典 `task_defaults`,以及一些额外的任务选项 `task_options`(对于像 "translation_XX_to_YY" 这样带参数的任务)。

    """
    return PIPELINE_REGISTRY.check_task(task)
def clean_custom_task(task_info):
    import transformers  # 导入transformers库

    # 检查任务信息中是否包含实现信息,如果没有则抛出运行时错误
    if "impl" not in task_info:
        raise RuntimeError("This model introduces a custom pipeline without specifying its implementation.")
    
    pt_class_names = task_info.get("pt", ())  # 获取pt_class_names,如果不存在则默认为空元组
    if isinstance(pt_class_names, str):
        pt_class_names = [pt_class_names]  # 如果pt_class_names是字符串,转换为列表
    # 将pt_class_names中每个类名对应的类对象存入task_info["pt"]中
    task_info["pt"] = tuple(getattr(transformers, c) for c in pt_class_names)
    
    tf_class_names = task_info.get("tf", ())  # 获取tf_class_names,如果不存在则默认为空元组
    if isinstance(tf_class_names, str):
        tf_class_names = [tf_class_names]  # 如果tf_class_names是字符串,转换为列表
    # 将tf_class_names中每个类名对应的类对象存入task_info["tf"]中
    task_info["tf"] = tuple(getattr(transformers, c) for c in tf_class_names)
    
    return task_info, None  # 返回更新后的task_info和None作为第二个返回值


def pipeline(
    task: str = None,
    model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None,
    feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
    image_processor: Optional[Union[str, BaseImageProcessor]] = None,
    framework: Optional[str] = None,
    revision: Optional[str] = None,
    use_fast: bool = True,
    token: Optional[Union[str, bool]] = None,
    device: Optional[Union[int, str, "torch.device"]] = None,
    device_map=None,
    torch_dtype=None,
    trust_remote_code: Optional[bool] = None,
    model_kwargs: Dict[str, Any] = None,
    pipeline_class: Optional[Any] = None,
    **kwargs,
) -> Pipeline:
    """
    Utility factory method to build a [`Pipeline`].

    Pipelines are made of:

        - A [tokenizer](tokenizer) in charge of mapping raw textual input to token.
        - A [model](model) to make predictions from the inputs.
        - Some (optional) post processing for enhancing model's output.

    Returns:
        [`Pipeline`]: A suitable pipeline for the task.

    Examples:

    ```
    >>> from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

    >>> # Sentiment analysis pipeline
    >>> analyzer = pipeline("sentiment-analysis")

    >>> # Question answering pipeline, specifying the checkpoint identifier
    >>> oracle = pipeline(
    ...     "question-answering", model="distilbert/distilbert-base-cased-distilled-squad", tokenizer="google-bert/bert-base-cased"
    ... )

    >>> # Named entity recognition pipeline, passing in a specific model and tokenizer
    >>> model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
    >>> recognizer = pipeline("ner", model=model, tokenizer=tokenizer)
    ```"""
    if model_kwargs is None:
        model_kwargs = {}
    
    # 确保只将use_auth_token作为一个关键字参数传递(以前可以将其传递给model_kwargs,为了保持向后兼容性)
    use_auth_token = model_kwargs.pop("use_auth_token", None)
    # 如果 use_auth_token 参数不为 None,则发出警告,提醒该参数在 Transformers v5 版本中将被移除,建议使用 `token` 参数代替
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        # 如果 token 参数也不为 None,则抛出 ValueError,说明同时指定了 `token` 和 `use_auth_token` 参数,应只设置 `token` 参数
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        # 将 use_auth_token 的值赋给 token 参数
        token = use_auth_token

    # 从 kwargs 字典中弹出 code_revision 和 _commit_hash 参数的值
    code_revision = kwargs.pop("code_revision", None)
    commit_hash = kwargs.pop("_commit_hash", None)

    # 创建 hub_kwargs 字典,用于存储 revision、token、trust_remote_code 和 _commit_hash 参数的值
    hub_kwargs = {
        "revision": revision,
        "token": token,
        "trust_remote_code": trust_remote_code,
        "_commit_hash": commit_hash,
    }

    # 如果既未指定 task 参数也未指定 model 参数,则抛出 RuntimeError,说明无法实例化 Pipeline
    if task is None and model is None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline without either a task or a model "
            "being specified. "
            "Please provide a task class or a model"
        )

    # 如果未指定 model 参数但指定了 tokenizer 参数,则抛出 RuntimeError,说明无法实例化 Pipeline
    if model is None and tokenizer is not None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer"
            " may not be compatible with the default model. Please provide a PreTrainedModel class or a"
            " path/identifier to a pretrained model when providing tokenizer."
        )

    # 如果未指定 model 参数但指定了 feature_extractor 参数,则抛出 RuntimeError,说明无法实例化 Pipeline
    if model is None and feature_extractor is not None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the provided"
            " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
            " or a path/identifier to a pretrained model when providing feature_extractor."
        )

    # 如果 model 参数的类型是 Path 对象,则将其转换为字符串类型
    if isinstance(model, Path):
        model = str(model)

    # 如果 commit_hash 参数为 None
    if commit_hash is None:
        # 预先训练的模型名或路径名为 None
        pretrained_model_name_or_path = None
        # 如果 config 参数是字符串类型,则将其赋值给 pretrained_model_name_or_path
        if isinstance(config, str):
            pretrained_model_name_or_path = config
        # 如果 config 参数为 None 且 model 参数为字符串类型,则将 model 参数赋值给 pretrained_model_name_or_path
        elif config is None and isinstance(model, str):
            pretrained_model_name_or_path = model

        # 如果 config 参数不是 PretrainedConfig 类型且 pretrained_model_name_or_path 不为 None
        if not isinstance(config, PretrainedConfig) and pretrained_model_name_or_path is not None:
            # 首先调用配置文件 (可能不存在) 获取 commit hash
            resolved_config_file = cached_file(
                pretrained_model_name_or_path,
                CONFIG_NAME,
                _raise_exceptions_for_gated_repo=False,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
                **hub_kwargs,
            )
            # 从配置文件中提取 commit hash,更新 hub_kwargs 中的 _commit_hash 参数
            hub_kwargs["_commit_hash"] = extract_commit_hash(resolved_config_file, commit_hash)
        else:
            # 否则,从 config 对象中获取 _commit_hash 属性的值,更新 hub_kwargs 中的 _commit_hash 参数
            hub_kwargs["_commit_hash"] = getattr(config, "_commit_hash", None)

    # 配置是最原始的信息项。
    # 如有需要则实例化配置
    # 如果配置是字符串,则根据预训练模型配置自动生成配置对象
    if isinstance(config, str):
        config = AutoConfig.from_pretrained(
            config, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs
        )
        # 更新 hub_kwargs 中的 _commit_hash
        hub_kwargs["_commit_hash"] = config._commit_hash
    # 如果配置为 None 且模型路径是字符串
    elif config is None and isinstance(model, str):
        # 如果 PEFT 可用,检查模型路径中是否存在适配器文件
        if is_peft_available():
            # 在模型路径中查找适配器配置文件,不包括 `trust_remote_code` 参数
            _hub_kwargs = {k: v for k, v in hub_kwargs.items() if k != "trust_remote_code"}
            maybe_adapter_path = find_adapter_config_file(
                model,
                token=hub_kwargs["token"],
                revision=hub_kwargs["revision"],
                _commit_hash=hub_kwargs["_commit_hash"],
            )

            # 如果找到适配器路径,则加载适配器配置文件中的基础模型名称或路径
            if maybe_adapter_path is not None:
                with open(maybe_adapter_path, "r", encoding="utf-8") as f:
                    adapter_config = json.load(f)
                    model = adapter_config["base_model_name_or_path"]

        # 根据模型路径加载自动配置对象
        config = AutoConfig.from_pretrained(
            model, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs
        )
        # 更新 hub_kwargs 中的 _commit_hash
        hub_kwargs["_commit_hash"] = config._commit_hash

    # 自定义任务字典初始化为空
    custom_tasks = {}
    # 如果配置对象不为空且存在自定义流水线,则获取自定义流水线任务
    if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
        custom_tasks = config.custom_pipelines
        # 如果任务为 None 且不禁止远程代码,则尝试自动推断任务
        if task is None and trust_remote_code is not False:
            # 如果只有一个自定义任务,则自动选择该任务
            if len(custom_tasks) == 1:
                task = list(custom_tasks.keys())[0]
            else:
                # 如果存在多个自定义任务,则抛出运行时错误,要求手动选择任务
                raise RuntimeError(
                    "We can't infer the task automatically for this model as there are multiple tasks available. Pick "
                    f"one in {', '.join(custom_tasks.keys())}"
                )

    # 如果任务仍为 None 且模型不为空,则尝试获取任务
    if task is None and model is not None:
        # 如果模型不是字符串,则抛出运行时错误
        if not isinstance(model, str):
            raise RuntimeError(
                "Inferring the task automatically requires to check the hub with a model_id defined as a `str`. "
                f"{model} is not a valid model_id."
            )
        # 根据模型 ID 和 token 获取任务
        task = get_task(model, token)

    # 获取任务后的处理流程
    if task in custom_tasks:
        # 标准化任务名称
        normalized_task = task
        # 清理自定义任务,获取目标任务和任务选项
        targeted_task, task_options = clean_custom_task(custom_tasks[task])
        # 如果未指定流水线类,则根据情况抛出 ValueError
        if pipeline_class is None:
            # 如果不信任远程代码,则要求设置 `trust_remote_code=True` 以移除错误
            if not trust_remote_code:
                raise ValueError(
                    "Loading this pipeline requires you to execute the code in the pipeline file in that"
                    " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
                    " set the option `trust_remote_code=True` to remove this error."
                )
            # 从动态模块中获取类引用
            class_ref = targeted_task["impl"]
            pipeline_class = get_class_from_dynamic_module(
                class_ref,
                model,
                code_revision=code_revision,
                **hub_kwargs,
            )
    else:
        # 检查任务并返回标准化的任务、目标任务和任务选项
        normalized_task, targeted_task, task_options = check_task(task)
        # 如果未指定流水线类,则使用目标任务的实现类作为默认流水线类
        if pipeline_class is None:
            pipeline_class = targeted_task["impl"]

    # 如果未提供模型,则使用任务的默认模型、配置和分词器
    if model is None:
        # 获取任务的默认模型及其修订版本
        model, default_revision = get_default_model_and_revision(targeted_task, framework, task_options)
        # 如果未指定修订版本,则使用默认修订版本
        revision = revision if revision is not None else default_revision
        # 记录警告信息,指出未提供模型,使用默认模型和修订版本
        logger.warning(
            f"No model was supplied, defaulted to {model} and revision"
            f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
            "Using a pipeline without specifying a model name and revision in production is not recommended."
        )
        # 如果未提供配置且模型名称为字符串,则从预训练模型中创建配置对象
        if config is None and isinstance(model, str):
            config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
            # 将配置的提交哈希记录到 hub_kwargs 中
            hub_kwargs["_commit_hash"] = config._commit_hash

    # 如果设备映射不为空,则处理相关参数
    if device_map is not None:
        # 如果模型参数中已包含 device_map,抛出错误
        if "device_map" in model_kwargs:
            raise ValueError(
                'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
                " arguments might conflict, use only one.)"
            )
        # 如果同时指定了 device 和 device_map,则发出警告
        if device is not None:
            logger.warning(
                "Both `device` and `device_map` are specified. `device` will override `device_map`. You"
                " will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`."
            )
        # 将 device_map 添加到模型参数中
        model_kwargs["device_map"] = device_map

    # 如果 torch 数据类型不为空,则处理相关参数
    if torch_dtype is not None:
        # 如果模型参数中已包含 torch_dtype,抛出错误
        if "torch_dtype" in model_kwargs:
            raise ValueError(
                'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
                " arguments might conflict, use only one.)"
            )
        # 如果 torch_dtype 是字符串且存在于 torch 模块中,则转换成相应的 torch 数据类型
        if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
            torch_dtype = getattr(torch, torch_dtype)
        # 将 torch_dtype 添加到模型参数中
        model_kwargs["torch_dtype"] = torch_dtype

    # 如果模型名称是字符串,则推断框架并加载模型
    if isinstance(model, str) or framework is None:
        # 定义模型类别(TensorFlow 或 PyTorch)并根据模型加载相应的框架和模型
        model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
        framework, model = infer_framework_load_model(
            model,
            model_classes=model_classes,
            config=config,
            framework=framework,
            task=task,
            **hub_kwargs,
            **model_kwargs,
        )

    # 获取模型的配置信息
    model_config = model.config
    # 将配置的提交哈希记录到 hub_kwargs 中
    hub_kwargs["_commit_hash"] = model.config._commit_hash
    # 判断是否需要加载分词器
    load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
    # 判断是否需要加载特征提取器
    load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
    # 检查是否需要加载图像处理器,条件为模型配置在图像处理器映射中或者图像处理器不为空
    load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None

    # 如果传入的`model`(`PretrainedModel`的实例而不是字符串),并且`image_processor`或`feature_extractor`为空,
    # 则加载将失败。这在某些视觉任务中特别发生,当使用`pipeline()`函数时传入`model`和其中一个`image_processor`或`feature_extractor`时。
    # TODO: 我们需要使`NO_IMAGE_PROCESSOR_TASKS`和`NO_FEATURE_EXTRACTOR_TASKS`更加健壮,以避免这种问题。
    # 这段代码仅用于临时使CI通过。
    if load_image_processor and load_feature_extractor:
        load_feature_extractor = False

    # 如果`tokenizer`为空,并且不需要加载`tokenizer`,并且`normalized_task`不在`NO_TOKENIZER_TASKS`中,
    # 并且`model_config`的类名在`MULTI_MODEL_AUDIO_CONFIGS`或`MULTI_MODEL_VISION_CONFIGS`中,
    # 则尝试强制加载`tokenizer`。
    if (
        tokenizer is None
        and not load_tokenizer
        and normalized_task not in NO_TOKENIZER_TASKS
        # 使用类名来避免导入真实类。
        and (
            model_config.__class__.__name__ in MULTI_MODEL_AUDIO_CONFIGS
            or model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS
        )
    ):
        load_tokenizer = True

    # 如果`image_processor`为空,并且不需要加载`image_processor`,并且`normalized_task`不在`NO_IMAGE_PROCESSOR_TASKS`中,
    # 并且`model_config`的类名在`MULTI_MODEL_VISION_CONFIGS`中,
    # 则尝试强制加载`image_processor`。
    if (
        image_processor is None
        and not load_image_processor
        and normalized_task not in NO_IMAGE_PROCESSOR_TASKS
        # 使用类名来避免导入真实类。
        and model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS
    ):
        load_image_processor = True

    # 如果`feature_extractor`为空,并且不需要加载`feature_extractor`,并且`normalized_task`不在`NO_FEATURE_EXTRACTOR_TASKS`中,
    # 并且`model_config`的类名在`MULTI_MODEL_AUDIO_CONFIGS`中,
    # 则尝试强制加载`feature_extractor`。
    if (
        feature_extractor is None
        and not load_feature_extractor
        and normalized_task not in NO_FEATURE_EXTRACTOR_TASKS
        # 使用类名来避免导入真实类。
        and model_config.__class__.__name__ in MULTI_MODEL_AUDIO_CONFIGS
    ):
        load_feature_extractor = True

    # 如果任务在`NO_TOKENIZER_TASKS`中,则不需要加载`tokenizer`。
    if task in NO_TOKENIZER_TASKS:
        load_tokenizer = False

    # 如果任务在`NO_FEATURE_EXTRACTOR_TASKS`中,则不需要加载`feature_extractor`。
    if task in NO_FEATURE_EXTRACTOR_TASKS:
        load_feature_extractor = False

    # 如果任务在`NO_IMAGE_PROCESSOR_TASKS`中,则不需要加载`image_processor`。
    if task in NO_IMAGE_PROCESSOR_TASKS:
        load_image_processor = False
    # 如果需要加载分词器
    if load_tokenizer:
        # 尝试根据模型名称或配置名称推断分词器(如果提供的话)
        if tokenizer is None:
            # 如果 model_name 是字符串,则尝试使用其作为分词器
            if isinstance(model_name, str):
                tokenizer = model_name
            # 如果 config 是字符串,则尝试使用其作为分词器
            elif isinstance(config, str):
                tokenizer = config
            else:
                # 在这里无法猜测应该使用哪个分词器
                raise Exception(
                    "Impossible to guess which tokenizer to use. "
                    "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
                )

        # 如果需要,实例化分词器
        if isinstance(tokenizer, (str, tuple)):
            if isinstance(tokenizer, tuple):
                # 对于元组,格式为(分词器名称,{kwargs})
                use_fast = tokenizer[1].pop("use_fast", use_fast)
                tokenizer_identifier = tokenizer[0]
                tokenizer_kwargs = tokenizer[1]
            else:
                tokenizer_identifier = tokenizer
                tokenizer_kwargs = model_kwargs.copy()
                tokenizer_kwargs.pop("torch_dtype", None)

            # 根据给定的参数创建 AutoTokenizer 实例
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs
            )

    # 如果需要加载图像处理器
    if load_image_processor:
        # 尝试根据模型名称或配置名称推断图像处理器(如果提供的话)
        if image_processor is None:
            # 如果 model_name 是字符串,则尝试使用其作为图像处理器
            if isinstance(model_name, str):
                image_processor = model_name
            # 如果 config 是字符串,则尝试使用其作为图像处理器
            elif isinstance(config, str):
                image_processor = config
            # 为了向后兼容,如果 feature_extractor 是 BaseImageProcessor 的实例,则使用其作为图像处理器
            elif feature_extractor is not None and isinstance(feature_extractor, BaseImageProcessor):
                image_processor = feature_extractor
            else:
                # 在这里无法猜测应该使用哪个图像处理器
                raise Exception(
                    "Impossible to guess which image processor to use. "
                    "Please provide a PreTrainedImageProcessor class or a path/identifier "
                    "to a pretrained image processor."
                )

        # 如果需要,实例化图像处理器
        if isinstance(image_processor, (str, tuple)):
            # 根据给定的参数创建 AutoImageProcessor 实例
            image_processor = AutoImageProcessor.from_pretrained(
                image_processor, _from_pipeline=task, **hub_kwargs, **model_kwargs
            )
    # 如果需要加载特征提取器
    if load_feature_extractor:
        # 尝试从模型名称或配置名称(如果是字符串)推断特征提取器
        if feature_extractor is None:
            # 如果模型名称是字符串,则将其作为特征提取器
            if isinstance(model_name, str):
                feature_extractor = model_name
            # 如果配置是字符串,则将其作为特征提取器
            elif isinstance(config, str):
                feature_extractor = config
            else:
                # 在此无法猜测正确的特征提取器
                raise Exception(
                    "Impossible to guess which feature extractor to use. "
                    "Please provide a PreTrainedFeatureExtractor class or a path/identifier "
                    "to a pretrained feature extractor."
                )

        # 如果特征提取器是字符串或元组,则实例化特征提取器
        if isinstance(feature_extractor, (str, tuple)):
            feature_extractor = AutoFeatureExtractor.from_pretrained(
                feature_extractor, _from_pipeline=task, **hub_kwargs, **model_kwargs
            )

            # 如果特征提取器包含语言模型且模型名称是字符串
            if (
                feature_extractor._processor_class
                and feature_extractor._processor_class.endswith("WithLM")
                and isinstance(model_name, str)
            ):
                try:
                    import kenlm  # 触发 `ImportError` 如果未安装
                    from pyctcdecode import BeamSearchDecoderCTC

                    # 如果模型名称是目录或文件
                    if os.path.isdir(model_name) or os.path.isfile(model_name):
                        decoder = BeamSearchDecoderCTC.load_from_dir(model_name)
                    else:
                        # 语言模型的全局路径及字母表文件名
                        language_model_glob = os.path.join(
                            BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*"
                        )
                        alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
                        allow_patterns = [language_model_glob, alphabet_filename]
                        # 从 HF Hub 加载模型名称对应的解码器
                        decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_patterns=allow_patterns)

                    # 将解码器加入参数中
                    kwargs["decoder"] = decoder
                except ImportError as e:
                    # 如果无法加载 `decoder`,则记录警告信息,并默认使用原始 CTC
                    logger.warning(f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Error: {e}")
                    # 如果未安装 kenlm
                    if not is_kenlm_available():
                        logger.warning("Try to install `kenlm`: `pip install kenlm")

                    # 如果未安装 pyctcdecode
                    if not is_pyctcdecode_available():
                        logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode")

    # 如果任务是翻译且模型配置具有特定任务参数
    if task == "translation" and model.config.task_specific_params:
        # 遍历模型配置的特定任务参数
        for key in model.config.task_specific_params:
            # 如果参数以 "translation" 开头
            if key.startswith("translation"):
                # 将任务设为该参数值,并发出警告
                task = key
                warnings.warn(
                    f'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{task}"',
                    UserWarning,
                )
                break

    # 如果存在分词器,则将其加入参数中
    if tokenizer is not None:
        kwargs["tokenizer"] = tokenizer
    # 如果提供了特征提取器,则将其添加到 kwargs 字典中
    if feature_extractor is not None:
        kwargs["feature_extractor"] = feature_extractor

    # 如果提供了 torch 的数据类型,则将其添加到 kwargs 字典中
    if torch_dtype is not None:
        kwargs["torch_dtype"] = torch_dtype

    # 如果提供了图像处理器,则将其添加到 kwargs 字典中
    if image_processor is not None:
        kwargs["image_processor"] = image_processor

    # 如果提供了设备信息,则将其添加到 kwargs 字典中
    if device is not None:
        kwargs["device"] = device

    # 使用给定的参数和 kwargs 字典创建一个新的 pipeline_class 对象并返回
    return pipeline_class(model=model, framework=framework, task=task, **kwargs)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398
  • 399
  • 400
  • 401
  • 402
  • 403
  • 404
  • 405
  • 406
  • 407
  • 408
  • 409
  • 410
  • 411
  • 412
  • 413
  • 414
  • 415
  • 416
  • 417
  • 418
  • 419
  • 420
  • 421
  • 422
  • 423
  • 424
  • 425
  • 426
  • 427
  • 428
  • 429
  • 430
  • 431
  • 432
  • 433
  • 434
  • 435
  • 436
  • 437
  • 438
  • 439
  • 440
  • 441
  • 442
  • 443
  • 444
  • 445
  • 446
  • 447
  • 448
  • 449
  • 450
  • 451
  • 452
  • 453
  • 454
  • 455
  • 456
  • 457
  • 458
  • 459
  • 460
  • 461
  • 462
  • 463
  • 464
  • 465
  • 466
  • 467
  • 468
  • 469
  • 470
  • 471
  • 472
  • 473
  • 474
  • 475
  • 476
  • 477
  • 478
  • 479
  • 480
  • 481
  • 482
  • 483
  • 484
  • 485
  • 486
  • 487
  • 488
  • 489
  • 490
  • 491
  • 492
  • 493
  • 494
  • 495
  • 496
  • 497
  • 498
  • 499
  • 500
  • 501
  • 502
  • 503
  • 504
  • 505
  • 506
  • 507
  • 508
  • 509
  • 510
  • 511
  • 512
  • 513
  • 514
  • 515
  • 516
  • 517
  • 518
  • 519
  • 520
  • 521
  • 522
  • 523
  • 524
  • 525
  • 526
  • 527
  • 528
  • 529
  • 530
  • 531
  • 532
  • 533
  • 534
  • 535
  • 536
  • 537
  • 538
  • 539
  • 540
  • 541
  • 542
  • 543
  • 544
  • 545
  • 546
  • 547
  • 548
  • 549
  • 550
  • 551
  • 552
  • 553
  • 554
  • 555
  • 556
  • 557
  • 558
  • 559
  • 560
  • 561
  • 562
  • 563
  • 564
  • 565
  • 566
  • 567
  • 568
  • 569
  • 570
  • 571
  • 572
  • 573
  • 574
  • 575
  • 576
  • 577
  • 578
  • 579
  • 580
  • 581
  • 582
  • 583
  • 584
  • 585
  • 586
  • 587
  • 588
  • 589
  • 590
  • 591
  • 592
  • 593
  • 594
  • 595
  • 596
  • 597
  • 598
  • 599
  • 600
  • 601
  • 602
  • 603
  • 604
  • 605
  • 606
  • 607
  • 608
  • 609
  • 610
  • 611
  • 612
  • 613
  • 614
  • 615
  • 616
  • 617
  • 618
  • 619
  • 620
  • 621
  • 622
  • 623
  • 624
  • 625
  • 626
  • 627
  • 628
  • 629
  • 630
  • 631
  • 632
  • 633
  • 634
  • 635
  • 636
  • 637
  • 638
  • 639
  • 640
  • 641
  • 642
  • 643
  • 644
  • 645
  • 646
  • 647
  • 648
  • 649
  • 650
  • 651
  • 652
  • 653
  • 654
  • 655
  • 656
  • 657
  • 658
  • 659
  • 660
  • 661
  • 662
  • 663
  • 664
  • 665
  • 666
  • 667
  • 668
  • 669
  • 670
  • 671
  • 672
  • 673
  • 674
  • 675
  • 676
  • 677
  • 678
  • 679
  • 680
  • 681
  • 682
  • 683
  • 684
  • 685
  • 686
  • 687
  • 688
  • 689
  • 690
  • 691
  • 692
  • 693
  • 694
  • 695
  • 696
  • 697
  • 698
  • 699
  • 700
  • 701
  • 702
  • 703
  • 704
  • 705
  • 706
  • 707
  • 708
  • 709
  • 710
  • 711
  • 712
  • 713
  • 714
  • 715
  • 716
  • 717
  • 718
  • 719
  • 720
  • 721
  • 722
  • 723
  • 724
  • 725
  • 726
  • 727
  • 728
  • 729
  • 730
  • 731
  • 732
  • 733
  • 734
  • 735
  • 736
  • 737
  • 738
  • 739
  • 740
  • 741
  • 742
  • 743
  • 744
  • 745
  • 746
  • 747
  • 748
  • 749
  • 750
  • 751
  • 752
  • 753
  • 754
  • 755
  • 756
  • 757
  • 758
  • 759
  • 760
  • 761
  • 762
  • 763
  • 764
  • 765
  • 766
  • 767
  • 768
  • 769
  • 770
  • 771
  • 772
  • 773
  • 774
  • 775
  • 776
  • 777
  • 778
  • 779
  • 780
  • 781
  • 782
  • 783
  • 784
  • 785
  • 786
  • 787
  • 788
  • 789
  • 790
  • 791
  • 792
  • 793
  • 794
  • 795
  • 796
  • 797
  • 798
  • 799
  • 800
  • 801
  • 802
  • 803
  • 804
  • 805
  • 806
  • 807
  • 808
  • 809
  • 810
  • 811
  • 812
  • 813
  • 814
  • 815
  • 816
  • 817
  • 818
  • 819
  • 820
  • 821
  • 822
  • 823
  • 824
  • 825
  • 826
  • 827
  • 828
  • 829
  • 830
  • 831
  • 832
  • 833
  • 834
  • 835
  • 836
  • 837
  • 838
  • 839
  • 840
  • 841
  • 842
  • 843
  • 844
  • 845
  • 846
  • 847
  • 848
  • 849
  • 850
  • 851
  • 852
  • 853
  • 854
  • 855
  • 856
  • 857
  • 858
  • 859
  • 860
  • 861
  • 862
  • 863
  • 864
  • 865
  • 866
  • 867
  • 868
  • 869
  • 870
  • 871
  • 872
  • 873
  • 874
  • 875
  • 876
  • 877
  • 878
  • 879
  • 880
  • 881
  • 882
  • 883
  • 884
  • 885
  • 886
  • 887
  • 888
  • 889
  • 890
  • 891
  • 892
  • 893
  • 894
  • 895
  • 896
  • 897
  • 898
  • 899
  • 900
  • 901
  • 902
  • 903
  • 904
  • 905
  • 906
  • 907
  • 908
  • 909
  • 910
  • 911
  • 912
  • 913
  • 914
  • 915
  • 916
  • 917
  • 918
  • 919
  • 920
  • 921
  • 922
  • 923
  • 924
  • 925
  • 926
  • 927
  • 928
  • 929
  • 930
  • 931
  • 932
  • 933
  • 934
  • 935
  • 936
  • 937
  • 938
  • 939
  • 940
  • 941
  • 942
  • 943
  • 944
  • 945
  • 946
  • 947
  • 948
  • 949
  • 950
  • 951
  • 952
  • 953
  • 954
  • 955
  • 956
  • 957
  • 958
  • 959
  • 960
  • 961
  • 962
  • 963
  • 964
  • 965
  • 966
  • 967
  • 968
  • 969
  • 970
  • 971
  • 972
  • 973
  • 974
  • 975
  • 976
  • 977
  • 978
  • 979
  • 980
  • 981
  • 982
  • 983
  • 984
  • 985
  • 986
  • 987
  • 988
  • 989
  • 990
  • 991
  • 992
  • 993
  • 994
  • 995
  • 996
  • 997
  • 998
  • 999
  • 1000
  • 1001
  • 1002
  • 1003
  • 1004
  • 1005
  • 1006
  • 1007
  • 1008
  • 1009
  • 1010
  • 1011
  • 1012
  • 1013
  • 1014
  • 1015
  • 1016
  • 1017
  • 1018
  • 1019
  • 1020
  • 1021
  • 1022
  • 1023
  • 1024
  • 1025
  • 1026
  • 1027
  • 1028
  • 1029
  • 1030
  • 1031
  • 1032
  • 1033
  • 1034
  • 1035
  • 1036
  • 1037
  • 1038
  • 1039
  • 1040
  • 1041
  • 1042
  • 1043
  • 1044
  • 1045
  • 1046
  • 1047
  • 1048
  • 1049
  • 1050
  • 1051
  • 1052
  • 1053
  • 1054
  • 1055
  • 1056
  • 1057
  • 1058
  • 1059
  • 1060
  • 1061
  • 1062
  • 1063
  • 1064
  • 1065
  • 1066
  • 1067
  • 1068
  • 1069
  • 1070
  • 1071
  • 1072
  • 1073
  • 1074
  • 1075
  • 1076
  • 1077
  • 1078
  • 1079
  • 1080
  • 1081
  • 1082
  • 1083
  • 1084
  • 1085
  • 1086
  • 1087
  • 1088
  • 1089
  • 1090
  • 1091
  • 1092
  • 1093
  • 1094
  • 1095
  • 1096
  • 1097
  • 1098
  • 1099
  • 1100
  • 1101
  • 1102
  • 1103
  • 1104
  • 1105
  • 1106
  • 1107
  • 1108
  • 1109
  • 1110
  • 1111
  • 1112
  • 1113
  • 1114
  • 1115
  • 1116
  • 1117
  • 1118
  • 1119
  • 1120
  • 1121
  • 1122
  • 1123
  • 1124
  • 1125
  • 1126
  • 1127
  • 1128
  • 1129
  • 1130
  • 1131
  • 1132
  • 1133
  • 1134
  • 1135
  • 1136
  • 1137
  • 1138
  • 1139
  • 1140
  • 1141
  • 1142
  • 1143
  • 1144
  • 1145
  • 1146
  • 1147
  • 1148
  • 1149
  • 1150
  • 1151
  • 1152
  • 1153
  • 1154
  • 1155
  • 1156
  • 1157
  • 1158
  • 1159
  • 1160
  • 1161
  • 1162
  • 1163
  • 1164
  • 1165
  • 1166
  • 1167
  • 1168
  • 1169
  • 1170
  • 1171
  • 1172
  • 1173
  • 1174
  • 1175
  • 1176
  • 1177
  • 1178
  • 1179
  • 1180
  • 1181
  • 1182
  • 1183
  • 1184
  • 1185
  • 1186
  • 1187
  • 1188
  • 1189
  • 1190
  • 1191
  • 1192
  • 1193
  • 1194
  • 1195
  • 1196
  • 1197
  • 1198
  • 1199
  • 1200
  • 1201
  • 1202
  • 1203
  • 1204
  • 1205
  • 1206
  • 1207
  • 1208
  • 1209
  • 1210
  • 1211
  • 1212
  • 1213
  • 1214
  • 1215
  • 1216
  • 1217
  • 1218
  • 1219
  • 1220
  • 1221
  • 1222
  • 1223
  • 1224
  • 1225
  • 1226
  • 1227
  • 1228
  • 1229
  • 1230
  • 1231
  • 1232
  • 1233
  • 1234
  • 1235
  • 1236
  • 1237
  • 1238
  • 1239
  • 1240
  • 1241
  • 1242
  • 1243
  • 1244
  • 1245
  • 1246
  • 1247
  • 1248
  • 1249
  • 1250
  • 1251
  • 1252
  • 1253
  • 1254
  • 1255
  • 1256
  • 1257
  • 1258
  • 1259
  • 1260
  • 1261
  • 1262
  • 1263
  • 1264
  • 1265
  • 1266
  • 1267
  • 1268
  • 1269
  • 1270
  • 1271
  • 1272
  • 1273

.\processing_utils.py

# 设置文件编码为 UTF-8
# 版权声明,声明代码的版权归 The HuggingFace Inc. 团队所有
#
# 根据 Apache 许可证版本 2.0 使用此文件,除非遵守许可证,否则不得使用此文件
# 可以在以下网址获取许可证副本:http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件根据"原样"提供,不附带任何形式的明示或暗示的担保或条件
# 有关详细信息,请参阅许可证
"""
通用处理器的保存/加载类。
"""

import copy  # 导入复制模块
import inspect  # 导入检查模块
import json  # 导入 JSON 模块
import os  # 导入操作系统模块
import warnings  # 导入警告模块
from pathlib import Path  # 导入 Path 类
from typing import Any, Dict, Optional, Tuple, Union  # 导入类型提示

from .dynamic_module_utils import custom_object_save  # 从动态模块工具导入自定义对象保存函数
from .tokenization_utils_base import PreTrainedTokenizerBase  # 从基础标记化工具导入预训练分词器基类
from .utils import (
    PROCESSOR_NAME,  # 从工具模块导入处理器名称常量
    PushToHubMixin,  # 从工具模块导入推送至 Hub 的 Mixin 类
    add_model_info_to_auto_map,  # 从工具模块导入将模型信息添加到自动映射的函数
    cached_file,  # 从工具模块导入缓存文件函数
    copy_func,  # 从工具模块导入复制函数函数
    direct_transformers_import,  # 从工具模块导入直接导入 Transformers 模块的函数
    download_url,  # 从工具模块导入下载 URL 函数
    is_offline_mode,  # 从工具模块导入检查是否为离线模式的函数
    is_remote_url,  # 从工具模块导入检查是否为远程 URL 的函数
    logging,  # 从工具模块导入日志记录对象
)

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象

# 动态导入 Transformers 模块,以获取处理器类的属性类
transformers_module = direct_transformers_import(Path(__file__).parent)

# 自动映射到基类的映射表,用于自动模型加载时的类关联
AUTO_TO_BASE_CLASS_MAPPING = {
    "AutoTokenizer": "PreTrainedTokenizerBase",  # 自动分词器映射到基础分词器基类
    "AutoFeatureExtractor": "FeatureExtractionMixin",  # 自动特征提取器映射到特征提取混合类
    "AutoImageProcessor": "ImageProcessingMixin",  # 自动图像处理器映射到图像处理混合类
}


class ProcessorMixin(PushToHubMixin):
    """
    这是一个 Mixin 类,用于为所有处理器类提供保存/加载功能。
    """

    attributes = ["feature_extractor", "tokenizer"]  # 处理器类中需要保存的属性列表
    # 对应属性列表中的类属性定义
    feature_extractor_class = None  # 特征提取器类属性初始化为空
    tokenizer_class = None  # 分词器类属性初始化为空
    _auto_class = None  # 自动加载的类属性初始化为空

    # args have to match the attributes class attribute
    def __init__(self, *args, **kwargs):
        # 对传入的参数和关键字参数进行清理和验证
        for key in kwargs:
            # 检查关键字参数是否在对象的属性列表中,否则引发异常
            if key not in self.attributes:
                raise TypeError(f"Unexpected keyword argument {key}.")
        
        for arg, attribute_name in zip(args, self.attributes):
            # 检查位置参数是否与属性名匹配的关键字参数冲突,如果有冲突则引发异常
            if attribute_name in kwargs:
                raise TypeError(f"Got multiple values for argument {attribute_name}.")
            else:
                kwargs[attribute_name] = arg

        if len(kwargs) != len(self.attributes):
            # 检查最终的关键字参数数量是否与对象属性数量匹配,不匹配则引发数值错误异常
            raise ValueError(
                f"This processor requires {len(self.attributes)} arguments: {', '.join(self.attributes)}. Got "
                f"{len(args)} arguments instead."
            )

        # 检查每个参数是否属于其对应的预期类别,这也会捕获用户错误顺序初始化的情况
        for attribute_name, arg in kwargs.items():
            class_name = getattr(self, f"{attribute_name}_class")
            # 如果类名为"AutoXxx",则检查其对应的基类
            class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
            if isinstance(class_name, tuple):
                # 如果类名是元组,则获取模块中对应的类列表
                proper_class = tuple(getattr(transformers_module, n) for n in class_name if n is not None)
            else:
                # 否则直接获取模块中的类
                proper_class = getattr(transformers_module, class_name)

            # 检查参数是否属于预期的类别,不属于则引发数值错误异常
            if not isinstance(arg, proper_class):
                raise ValueError(
                    f"Received a {type(arg).__name__} for argument {attribute_name}, but a {class_name} was expected."
                )

            # 将参数设置为对象的属性
            setattr(self, attribute_name, arg)
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this processor instance.
        """
        # Create a deep copy of the instance's __dict__ to prevent unintended modifications
        output = copy.deepcopy(self.__dict__)

        # Retrieve the signature of the __init__ method to get its parameters
        sig = inspect.signature(self.__init__)
        
        # Filter out attributes that are not listed in the __init__ parameters
        attrs_to_save = sig.parameters
        attrs_to_save = [x for x in attrs_to_save if x not in self.__class__.attributes]
        
        # Add "auto_map" to the list of attributes to be saved
        attrs_to_save += ["auto_map"]

        # Filter the output dictionary to include only the attributes to be saved
        output = {k: v for k, v in output.items() if k in attrs_to_save}

        # Add the class name of the processor instance to the output dictionary
        output["processor_class"] = self.__class__.__name__

        # Remove specific attributes that should not be included in the output
        if "tokenizer" in output:
            del output["tokenizer"]
        if "image_processor" in output:
            del output["image_processor"]
        if "feature_extractor" in output:
            del output["feature_extractor"]

        # Filter out attributes with names indicating objects not suitable for serialization
        output = {
            k: v
            for k, v in output.items()
            if not (isinstance(v, PushToHubMixin) or v.__class__.__name__ == "BeamSearchDecoderCTC")
        }

        return output

    def to_json_string(self) -> str:
        """
        Serializes this instance to a JSON string.

        Returns:
            `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
        """
        # Convert the instance to a dictionary
        dictionary = self.to_dict()

        # Serialize the dictionary to a JSON string with formatting
        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this processor instance's parameters will be saved.
        """
        # Open the JSON file for writing
        with open(json_file_path, "w", encoding="utf-8") as writer:
            # Write the instance's JSON representation to the file
            writer.write(self.to_json_string())

    def __repr__(self):
        """
        Returns a string representation of the processor instance.

        Returns:
            `str`: String representation of the processor instance, including key attributes and JSON serialization.
        """
        # Generate representations of all attributes specified in self.attributes
        attributes_repr = [f"- {name}: {repr(getattr(self, name))}" for name in self.attributes]
        
        # Concatenate attribute representations into a single string
        attributes_repr = "\n".join(attributes_repr)
        
        # Return a formatted string including class name, attributes, and JSON serialization
        return f"{self.__class__.__name__}:\n{attributes_repr}\n\n{self.to_json_string()}"

    @classmethod
    def get_processor_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ):
        """
        Placeholder method for defining how to get processor dictionary.

        This method is not implemented in the provided code snippet.
        """
        pass
    def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs):
        """
        从参数字典和额外关键字参数实例化一个 [`~processing_utils.ProcessingMixin`] 类型的对象。

        Args:
            processor_dict (`Dict[str, Any]`):
                用于实例化处理器对象的参数字典。可以利用预训练检查点的
                [`~processing_utils.ProcessingMixin.to_dict`] 方法来获取这样一个字典。
            kwargs (`Dict[str, Any]`):
                初始化处理器对象的额外参数。

        Returns:
            [`~processing_utils.ProcessingMixin`]: 从这些参数实例化的处理器对象。
        """
        processor_dict = processor_dict.copy()
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        # 不像图像处理器或特征提取器那样,处理器的 `__init__` 方法不接受 `kwargs`。
        # 我们必须弹出一些未使用的(但是特定的)参数才能使其正常工作。
        if "processor_class" in processor_dict:
            del processor_dict["processor_class"]

        if "auto_map" in processor_dict:
            del processor_dict["auto_map"]

        # 使用给定的 `args` 和 `processor_dict` 实例化处理器对象
        processor = cls(*args, **processor_dict)

        # 如果需要,使用 `kwargs` 更新处理器对象
        for key in set(kwargs.keys()):
            if hasattr(processor, key):
                setattr(processor, key, kwargs.pop(key))

        # 记录处理器对象的信息
        logger.info(f"Processor {processor}")
        if return_unused_kwargs:
            return processor, kwargs
        else:
            return processor

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
        ):
        r"""
        Instantiate a processor associated with a pretrained model.

        <Tip>

        This class method is simply calling the feature extractor
        [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], image processor
        [`~image_processing_utils.ImageProcessingMixin`] and the tokenizer
        [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] methods. Please refer to the docstrings of the
        methods above for more information.

        </Tip>

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
                  huggingface.co.
                - a path to a *directory* containing a feature extractor file saved using the
                  [`~SequenceFeatureExtractor.save_pretrained`] method, e.g., `./my_model_directory/`.
                - a path or url to a saved feature extractor JSON *file*, e.g.,
                  `./my_model_directory/preprocessor_config.json`.
            **kwargs
                Additional keyword arguments passed along to both
                [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and
                [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`].
        """
        kwargs["cache_dir"] = cache_dir
        kwargs["force_download"] = force_download
        kwargs["local_files_only"] = local_files_only
        kwargs["revision"] = revision

        # Check and handle deprecated use_auth_token argument
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        # If token is provided, set it in kwargs
        if token is not None:
            kwargs["token"] = token

        # Get arguments from pretrained model and process kwargs
        args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
        # Obtain processor dictionary and update kwargs
        processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs)

        # Instantiate the class using obtained arguments and processor dictionary
        return cls.from_args_and_dict(args, processor_dict, **kwargs)

    @classmethod
    # 注册一个自动类别名,用于自定义特征提取器,这应仅用于自定义的特征提取器,因为库中的提取器已经与 `AutoProcessor` 映射好了。
    def register_for_auto_class(cls, auto_class="AutoProcessor"):
        """
        Register this class with a given auto class. This should only be used for custom feature extractors as the ones
        in the library are already mapped with `AutoProcessor`.

        <Tip warning={true}>

        This API is experimental and may have some slight breaking changes in the next releases.

        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`):
                The auto class to register this new feature extractor with.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        # 导入 transformers.models.auto 模块,用于检查 auto_class 是否存在
        import transformers.models.auto as auto_module

        # 如果 auto_module 中没有找到指定的 auto_class,则抛出 ValueError
        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        # 将 auto_class 赋值给当前类的 _auto_class 属性
        cls._auto_class = auto_class

    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # 初始化一个空列表,用于存储从预训练模型中获取的参数
        args = []
        # 遍历类的 attributes 列表
        for attribute_name in cls.attributes:
            # 获取当前属性对应的类名
            class_name = getattr(cls, f"{attribute_name}_class")

            # 如果 class_name 是一个元组
            if isinstance(class_name, tuple):
                # 从 transformers_module 中获取类,如果为 None 则跳过
                classes = tuple(getattr(transformers_module, n) if n is not None else None for n in class_name)
                # 获取 kwargs 中的 use_fast 参数,默认为 True
                use_fast = kwargs.get("use_fast", True)
                # 如果 use_fast 为 True 并且 classes[1] 不为 None,则使用 classes[1],否则使用 classes[0]
                if use_fast and classes[1] is not None:
                    attribute_class = classes[1]
                else:
                    attribute_class = classes[0]
            else:
                # 如果 class_name 不是元组,则直接从 transformers_module 中获取对应的类
                attribute_class = getattr(transformers_module, class_name)

            # 使用 from_pretrained 方法从预训练模型加载参数,并添加到 args 列表中
            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
        return args

    @property
    def model_input_names(self):
        # 获取当前对象的第一个属性,并尝试获取其 model_input_names 属性,如果不存在则返回 None
        first_attribute = getattr(self, self.attributes[0])
        return getattr(first_attribute, "model_input_names", None)
# 将 ProcessorMixin 类的 push_to_hub 方法复制一份,赋值给原方法
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
# 检查 push_to_hub 方法的文档字符串是否不为空
if ProcessorMixin.push_to_hub.__doc__ is not None:
    # 如果文档字符串不为空,使用格式化字符串将文档字符串中的占位符替换为指定的内容
    ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
        object="processor", object_class="AutoProcessor", object_files="processor files"
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/932531
推荐阅读
相关标签
  

闽ICP备14008679号