当前位置:   article > 正文

Pytorch通过保存为ONNX模型转TensorRT5_saveonnx

saveonnx

1 Pytorch以ONNX方式保存模型

    def saveONNX(model, filepath):
        '''
        保存ONNX模型
        :param model: 神经网络模型
        :param filepath: 文件保存路径
        '''
        
        # 神经网络输入数据类型
        dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
        torch.onnx.export(model, dummy_input, filepath, verbose=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

2 利用TensorRT5中ONNX解析器构建Engine

    def ONNX_build_engine(onnx_file_path):
        '''
        通过加载onnx文件,构建engine
        :param onnx_file_path: onnx文件路径
        :return: engine
        '''
        # 打印日志
        G_LOGGER = trt.Logger(trt.Logger.WARNING)

        with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
            builder.max_batch_size = 100
            builder.max_workspace_size = 1 << 20

            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                parser.parse(model.read())
            print('Completed parsing of ONNX file')

            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
            engine = builder.build_cuda_engine(network)
            print("Completed creating Engine")

            # 保存计划文件
            # with open(engine_file_path, "wb") as f:
            #     f.write(engine.serialize())
            return engine
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

3 构建TensorRT运行引擎进行预测

    def loadONNX2TensorRT(filepath):
        '''
        通过onnx文件,构建TensorRT运行引擎
        :param filepath: onnx文件路径
        '''
        # 计算开始时间
        Start = time()

        engine = self.ONNX_build_engine(filepath)

        # 读取测试集
        datas = DataLoaders()
        test_loader = datas.testDataLoader()
        img, target = next(iter(test_loader))
        img = img.numpy()
        target = target.numpy()

        img = img.ravel()

        context = engine.create_execution_context()
        output = np.empty((100, 10), dtype=np.float32)

        # 分配内存
        d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
        d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
        bindings = [int(d_input), int(d_output)]

        # pycuda操作缓冲区
        stream = cuda.Stream()
        # 将输入数据放入device
        cuda.memcpy_htod_async(d_input, img, stream)
        # 执行模型
        context.execute_async(100, bindings, stream.handle, None)
        # 将预测结果从从缓冲区取出
        cuda.memcpy_dtoh_async(output, d_output, stream)
        # 线程同步
        stream.synchronize()

        print("Test Case: " + str(target))
        print("Prediction: " + str(np.argmax(output, axis=1)))
        print("tensorrt time:", time() - Start)

        del context
        del engine
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/57630
推荐阅读
相关标签
  

闽ICP备14008679号