当前位置:   article > 正文

PyTorch2ONNX-分类模型:速度比较(固定维度、动态维度)、精度比较

PyTorch2ONNX-分类模型:速度比较(固定维度、动态维度)、精度比较
图像分类模型部署: PyTorch -> ONNX

1. 模型部署介绍

1.1 人工智能开发部署全流程

step1
数据
数据采集
定义类别
标注
数据集
step2
模型
训练模型
测试集评估
调参优化
可解释分析
step3
部署
手机/平板
服务器
PC/浏览器
嵌入式开发板

1.2 模型部署平台和芯片介绍

  • 设备:PC、浏览器、APP、小程序、服务器、嵌入式开发板、无人车、无人机、Jetson Nano、树莓派、机械臂、物联网设备
  • 厂商
    • 英特尔(Intel):主要生产 CPU(中央处理器)和一些 FPGA(现场可编程门阵列)芯片。代表作品包括 Intel Core 系列 CPU 和 Xeon 系列服务器 CPU,以及 FPGA 产品如 Intel Stratix 系列。
    • 英伟达(NVIDIA):以 GPU(图形处理器)为主打产品,广泛应用于图形渲染、深度学习等领域。代表作品包括 NVIDIA GeForce 系列用于游戏图形处理,NVIDIA Tesla 和 NVIDIA A100 用于深度学习加速。
    • AMD:主要生产 CPU 和 GPU。代表作品包括 AMD Ryzen 系列 CPU 和 AMD EPYC 系列服务器 CPU,以及 AMD Radeon 系列 GPU 用于游戏和专业图形处理。
    • 苹果(Apple):生产自家设计的芯片,主要包括苹果 M 系列芯片。代表作品有 M1 芯片,广泛应用于苹果的 Mac 电脑、iPad 和一些其他设备。
    • 高通(Qualcomm):主要生产移动平台芯片,包括移动处理器和调制解调器。代表作品包括 Snapdragon 系列芯片,用于智能手机和移动设备。
    • 昇腾(Ascend):由华为生产,主要生产 NPU(神经网络处理器),用于深度学习任务。代表作品包括昇腾 910 和昇腾 310。
    • 麒麟(Kirin):同样由华为生产,主要生产手机芯片,包括 CPU 和 GPU。代表作品包括麒麟 9000 系列,用于华为旗舰手机。
    • 瑞芯微(Rockchip):主要生产 VPU(视觉处理器)和一些移动平台芯片。代表作品包括 RK3288 和 RK3399,广泛应用于智能显示、机器人等领域。
芯片名英文名中文名厂商主要任务是否训练是否推理算力速度
CPUCentral Processing Unit(CPU)中央处理器各大厂商通用计算中等
GPUGraphics Processing Unit(GPU)图形处理器NVIDIA、AMD等图形渲染、深度学习加速
TPUTensor Processing Unit(TPU)张量处理器谷歌机器学习中的张量运算
NPUNeural Processing Unit(NPU)神经网络处理器华为、联发科等深度学习模型的性能提升中等
VPUVision Processing Unit(VPU)视觉处理器英特尔、博通等图像和视频处理中等中等
DSPDigital Signal Processor(DSP)数字信号处理器德州仪器、高通等数字信号处理、音频信号处理中等中等
FPGAField-Programmable Gate Array(FPGA)现场可编程门阵列英特尔、赛灵思等可编程硬件加速器中等

1.3 模型部署的通用流程

转换
运行
PyTorch
TensorFlow
Caffe
PaddlePaddle
训练框架
ONNX/中间表示
推理框架/引擎/后端
TensorRT
ONNXRuntime
OpenVINO
NCNN/TNN
PPL

2. 使用 ONNX 的意义

从这两张图可以很明显的看到,当有了中间表示 ONNX 后,从原来的 M × N M \times N M×N 变为了 M + N M + N M+N,让模型部署的流程变得简单。

3. ONNX 的介绍

开源机器学习通用中间格式,由微软、Facebook(Meta)、亚马逊、IBM 共同发起的。它可以兼容各种深度学习框架,也可以兼容各种推理引擎和终端硬件、操作系统

4. ONNX 环境安装

pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
  • 1
  • 2

5. PyTorch → ONNX

5.1 将一个分类模型转换为 ONNX

import torch
from torchvision import models


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"正在使用的设备: {device}")

# 创建一个训练好的模型
model = models.resnet18(pretrained=True)  # ImageNet 预训练权重
model = model.eval().to(device)

# 构建一个输入
dummy_input = torch.randn(size=[1, 3, 256, 256]).to(device)  # [N, B, H, W]

# 让模型推理
output = model(dummy_input)
print(f"output.shape: {output.shape}")

# 使用 PyTorch 自带的函数将模型转换为 ONNX 格式
onnx_save_path = 'ONNX/saves/resnet18_imagenet.onnx'  # 导出的ONNX模型路径 
with torch.no_grad():
    torch.onnx.export(
        model=model,                            # 要转换的模型
        args=dummy_input,                       # 模型的输入
        f=onnx_save_path,                       # 导出的ONNX模型路径 
        input_names=['input'],                  # ONNX模型输入的名字(自定义)
        output_names=['output'],                # ONNX模型输出的名字(自定义)
        opset_version=11,                       # Opset算子集合的版本(默认为17)
    )
    
print(f"ONNX 模型导出成功,路径为:{onnx_save_path}")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
正在使用的设备: cpu
/home/leovin/anaconda3/envs/wsl/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/leovin/anaconda3/envs/wsl/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/leovin/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:03<00:00, 13.9MB/s]
output.shape: torch.Size([1, 1000])
ONNX 模型导出成功,路径为:ONNX/saves/resnet18_imagenet.onnx
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/55534

推荐阅读
相关标签