当前位置:   article > 正文

BERT 模型到 TensorRT 的转换_bert-base-uncased onnx模型转trt代码

bert-base-uncased onnx模型转trt代码

总体流程

1. 准备环境:安装TensorRT和相关依赖库,确保环境可用。

2. 获取预训练BERT模型:可以直接下载官方提供的PyTorch模型,或训练自定义BERT模型。

3. 模型验证:在PyTorch环境中加载模型,给定样本输入,验证模型输出正确性。

4. 导出ONNX模型:使用torch.onnx.export()将PyTorch模型导出为ONNX格式。要注意输入输出格式。

5. 简化ONNX模型:使用onnx-simplifier等工具对ONNX模型进行优化和简化。这一步可以省略。

6. 导入TensorRT:使用TensorRT的API(Python或C++)加载ONNX模型。

7. 配置TensorRT引擎:设置最大batchsize、workspace大小等参数。开启FP16精度模式。

8. 生成序列化引擎:使用trt.Builder的build_serialized_network()方法生成序列化后的TensorRT引擎。

9. 反序列化和验证:在另一个Python脚本中加载序列化后的引擎,验证模型输出正确性。

10. 制作数据输入:根据模型输入格式,准备calib数据集或随机生成输入tensor。

11. TRT模型校准:运行trt.Calibrator对引擎进行校准,生成量化后的TensorRT引擎。

12. TRT模型评估:用calib数据集进行多轮测试,记录推理时间,计算加速比。与PyTorch原模型对比。

13. 模型部署:将量化后的TensorRT引擎保存为文件,以用于实际部署。可以编译为C++引擎。

14. 结果分析:分析日志,整理转换前后模型精度和速度指标,总结优化经验。

详解:

1. 准备工作- 安装TensorRT库及依赖
- 确保CUDA可用
- 准备PyTorch环境,安装onnx等模块
- 获取预训练BERT模型或训练自定义BERT
- 设定评测设备(GPU或CPU)

2. PyTorch模型检查- 在PyTorch环境中加载BERT模型
- 给定样本输入,获取模型输出
- 检查输出结果正确性
- 分析模型参数大小、浮点运算量等

3. 导出ONNX模型- 使用torch.onnx.export()导出模型
- 输入需是torch.Tensor,输出也为tensor
- 检查ONNX模型输入输出和PyTorch一致
- 可视化ONNX模型结构,分析层信息

  1. import torch
  2. from torch.nn import functional as F
  3. import numpy as np
  4. import os
  5. from transformers import BertTokenizer, BertForMaskedLM
  6. import logging
  7. # import onnxruntime as ort
  8. import transformers
  9. import time
  10. current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
  11. print("pytorch:", torch.__version__)
  12. # print("onnxruntime version:", ort.__version__)
  13. # print("onnxruntime device:", ort.get_device())
  14. print("transformers:", transformers.__version__)
  15. BERT_PATH = 'bert-base-uncased'
  16. logging.basicConfig(filename='test.log', level=logging.DEBUG)
  17. '''
  18. logging.basicConfig(level=logging.DEBUG)这行代码是用于设置logging模块的日志级别和基本配置的。
  19. 具体来说:
  20. - logging是Python的日志模块。使用该模块可以方便地进行日志记录和调试。
  21. - basicConfig是logging模块中的一个函数,用于进行日志的基本配置。
  22. - level=logging.DEBUG 这部分是设置日志级别为DEBUG。
  23. logging模块定义了多个日志级别,按严重程度从低到高为:
  24. - DEBUG:调试信息,日志量大,用于追踪程序的所有流程。
  25. - INFO:信息提示,确认程序按预期运行。
  26. - WARNING:警告信息,可能会有潜在问题,但程序还能执行。
  27. - ERROR:错误信息,程序无法按预期执行。
  28. - CRITICAL:严重错误,程序可能无法继续运行。
  29. 通过设置level参数,可以控制打印哪些级别的日志信息。
  30. setLevel为DEBUG时,会打印所有级别的日志,这对调试程序非常有用。
  31. 总结一下:
  32. 1. logging.basicConfig配置日志模块
  33. 2. 设置level=logging.DEBUG将日志级别设置为最低级别DEBUG
  34. 3. 这样可以打印调试程序过程中的所有日志信息,方便调试和追踪程序执行流程
  35. '''
  36. def model_test(model, tokenizer, text):
  37. #print("==============model test===================")
  38. logging.info("==============model test===================")
  39. encoded_input = tokenizer.encode_plus(text, return_tensors = "pt")
  40. '''
  41. 1. tokenizer.encode_plus()会返回一个字典,包含:
  42. - input_ids: 输入序列的数字化表示
  43. - token_type_ids: 对于每个输入元组的第一个和第二个序列的区分(对BERT等双序列模型需要)
  44. - attention_mask: 指定对哪些词元进行self-attention操作的掩码(可选择性返回)
  45. 2. tokenizer.encode()只返回输入的数字化表示input_ids,不返回其他信息。
  46. '''
  47. # encoded_input = {k: v.unsqueeze(0) for k,v in encoded_input.items()}
  48. '''
  49. 这行使用tokenizer对文本进行编码,返回的encoded_input是一个字典,包含 input_ids, token_type_ids等字段。
  50. 然后这行:
  51. encoded_input = {k: v.unsqueeze(0) for k,v in encoded_input.items()}
  52. 是将encoded_input中的每个tensor进行unsqueeze(0)操作,即在tensor的第一个维度上增加一个大小为1的维度。
  53. 因为在PyTorch中,batch size对应的是tensor的第一个维度。通过unsqueeze(0),可以将batch size从原来的不定,变为1。
  54. 举个例子,输入的input_ids大小可能原来是(5, 20)(5是batch size,20是序列长度)。
  55. 经过unsqueeze(0)后,大小变为(1, 5, 20),其中第一个维度被增加为1,即batch size变为1。
  56. 这样做的目的是为了固定batch size,之后测试模型的时候输入是一个固定batch size的tensor,方便调试。
  57. unsqueeze(0)后返回的仍是一个字典,字典的key保持不变,value是进行过unsqueeze的tensor。
  58. 所以这行代码的作用就是:
  59. 1. 对编码后的输入进行遍历,遍历字典中的每一个tensor
  60. 2. 对每个tensor进行unsqueeze(0),以固定batch size为1
  61. 3. 将处理后的tensor重新组成字典,赋值给encoded_input
  62. '''
  63. mask_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)
  64. #mask_index = torch.where(encoded_input["input_ids"] == tokenizer.mask_token_id)
  65. '''
  66. 之所以这样修改,是因为修改后的代码设置了batch size为1,所以encoded_input["input_ids"]中只有一个示例,不需要再取[0]。
  67. 取[0]主要是为了兼容batch size大于1的情况,但现在batch size固定为1,所以可以简化代码,直接判断整个tensor。
  68. 另外判断一个tensor中哪些位置等于某值,torch.where()是比较方便的操作。
  69. 总结一下:
  70. 1. 原代码中需要取[0],是为了兼容batch size大于1的情况
  71. 2. 修改后代码固定了batch size为1,所以可以简化判断,不再需要取[0]
  72. 3. torch.where可以直接判断tensor中哪些位置符合条件,很方便实现mask token的索引查找
  73. 4. 这种修改使代码更简洁,同时也更符合batch size为1的情况
  74. '''
  75. output = model(**encoded_input)
  76. #print(output[0].shape)
  77. logging.debug(f"Output shape: {output[0].shape}")
  78. logits = output.logits
  79. softmax = F.softmax(logits, dim = -1)
  80. mask_word = softmax[0, mask_index, :]
  81. top_10 = torch.topk(mask_word, 10, dim = 1)[1][0]
  82. # print("model test topk10 output:")
  83. logging.info("model test topk10 output:")
  84. for token in top_10:
  85. word = tokenizer.decode([token])
  86. new_sentence = text.replace(tokenizer.mask_token, word)
  87. print(new_sentence)
  88. # save inputs and output
  89. # print("Saving inputs and output to case_data.npz ...")
  90. logging.info("Saving inputs and output to case_data.npz ...")
  91. position_ids = torch.arange(0, encoded_input['input_ids'].shape[1]).int().view(1, -1)
  92. #position_ids = torch.arange(0, encoded_input['input_ids'].shape[1]).unsqueeze(0).int()
  93. '''
  94. 主要区别在于使用view和unsqueeze的不同。
  95. view是对tensor进行reshape,可能会造成原tensor被覆盖。
  96. 而unsqueeze是返回一个新的tensor,对原tensor无影响。
  97. 之所以修改成unsqueeze,主要有以下几点考虑:
  98. 1. unsqueeze更安全,不会改变原tensor,生成的是新的tensor。
  99. 2. 这里的position_ids只需要增加一个维度,使其size为(1, seq_len),unsqueeze可以很方便地做到。
  100. 3. view需要指定-1这个特殊大小,unsqueeze无需指定size。
  101. 4. unsqueeze更符合“增加维度”这个操作的语义。
  102. 所以修改成unsqueeze可以:
  103. 1. 生成一个新的安全的tensor,不影响原tensor
  104. 2. 简化代码,不需要指定-1这个特殊大小
  105. 3. 更符合逻辑地增加维度
  106. '''
  107. # print(position_ids)
  108. logging.debug(f"Position ids shape: {position_ids.shape}")
  109. input_ids=encoded_input['input_ids'].int().detach().numpy()
  110. '''
  111. 这行代码是将PyTorch的tensor转换为numpy数组的过程:
  112. 1. encoded_input['input_ids'] 取出编码后的input_ids的tensor
  113. 2. .int() 将tensor转换为整数类型(int32)
  114. 3. .detach() 将这个tensor与计算图分离,即得到一个纯tensor,不再与计算图相关
  115. 4. .numpy() 将这个detach的整数tensor转换为numpy数组
  116. 需要注意的是,由于默认PyTorch tensor是int32,转换的numpy数组也是int32。
  117. 如果模型需要int64类型,这里就需要额外处理,比如:
  118. input_ids = encoded_input['input_ids'].int().to(torch.int64).detach().numpy().astype(np.int64)
  119. 先转换为int64 tensor,再转为int64 numpy数组。
  120. 在后面ONNXRuntime推理时就出现了这个问题,从npz加载的数据为int32,模型的输入类型定义为int64
  121. '''
  122. logging.debug(f"Input_ids: {input_ids}")
  123. token_type_ids=encoded_input['token_type_ids'].int().detach().numpy()
  124. logging.debug(f"Token_type_ids: {token_type_ids}")
  125. # print(input_ids.shape)
  126. logging.debug(f"Input ids shape: {input_ids.shape}")
  127. # save data
  128. npz_file = BERT_PATH + '/case_data.npz'
  129. np.savez(npz_file,
  130. input_ids=input_ids,
  131. token_type_ids=token_type_ids,
  132. position_ids=position_ids,
  133. logits=output[0].detach().numpy())
  134. data = np.load(npz_file)
  135. # print(data['input_ids'])
  136. logging.debug(f"Saved input ids: {data['input_ids']}")
  137. def model2onnx(model, tokenizer, text):
  138. # print("===================model2onnx=======================")
  139. logging.info("===================model2onnx=======================")
  140. encoded_input = tokenizer.encode_plus(text, return_tensors = "pt")
  141. # encoded_input = {k: v.unsqueeze(0) for k,v in encoded_input.items()}
  142. # print(encoded_input)
  143. logging.debug(f"Encoded input: \n {encoded_input}")
  144. logging.debug(f"tuple(encoded_input.values()): \n {tuple(encoded_input.values())}")
  145. # convert model to onnx
  146. model.eval()
  147. export_model_path = BERT_PATH + "/model.onnx"
  148. opset_version = 12
  149. #batch_size = 1
  150. symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
  151. '''
  152. dynamic_axes的格式是:
  153. {输出张量名: {轴编号: 轴名称}}
  154. dynamic_axes中的数字表示固定大小,字符串表示动态轴。
  155. '''
  156. # input_names = ['input_ids', 'attention_mask', 'token_type_ids']
  157. # input_shapes = [encoded_input[name].shape for name in input_names]
  158. # shape_tuples = [tuple(shape) for shape in input_shapes]
  159. torch.onnx.export(model, # model being run
  160. args = tuple(encoded_input.values()), # model input (or a tuple for multiple inputs)
  161. f=export_model_path, # where to save the model (can be a file or file-like object)
  162. opset_version=opset_version, # the ONNX version to export the model to
  163. do_constant_folding=False, # whether to execute constant folding for optimization
  164. input_names=['input_ids', # the model's input names
  165. 'attention_mask',
  166. 'token_type_ids'],
  167. output_names=['logits'], # the model's output names
  168. dynamic_axes={'input_ids': symbolic_names, # variable length axes
  169. 'attention_mask' : symbolic_names,
  170. 'token_type_ids' : symbolic_names,
  171. 'logits' : symbolic_names})
  172. #print("Model exported at ", export_model_path)
  173. logging.info("Model exported at " + export_model_path)
  174. if __name__ == '__main__':
  175. if not os.path.exists(BERT_PATH):
  176. print(f"Download {BERT_PATH} model first!")
  177. assert(0)
  178. logging.info("---------------------------------------------------------------------------------------------")
  179. logging.info(current_time)
  180. tokenizer = BertTokenizer.from_pretrained(BERT_PATH)
  181. model = BertForMaskedLM.from_pretrained(BERT_PATH, return_dict = True)
  182. text = "The capital of France, " + tokenizer.mask_token + ", contains the Eiffel Tower."
  183. # bias = model.cls.predictions.bias.data
  184. # logging.info(f"bias: {bias.shape}")
  185. # bias = torch.unsqueeze(bias, dim=0)
  186. # bias = torch.nn.Parameter(bias)
  187. # logging.info(f"n_bias: {bias.shape}")
  188. # model.cls.predictions.bias = bias
  189. model_test(model, tokenizer, text)
  190. model2onnx(model, tokenizer, text)

4. 优化ONNX模型- 使用onnx-simplifier等工具简化模型
- 删除冗余/无用节点,减小模型大小
- 可选:使用ONNXRuntime测试简化后模型

  1. import onnxsim
  2. import onnx
  3. import onnxruntime as ort
  4. import os
  5. import logging
  6. import time
  7. import numpy as np
  8. import torch
  9. from transformers import BertTokenizer
  10. def model_simplify(model_path, model_simp_path):
  11. model = onnx.load(model_path)
  12. model_simp, check = onnxsim.simplify(model)
  13. onnx.save(model_simp, model_simp_path)
  14. logging.info("Model exported at " + model_simp_path)
  15. # 把输入都打印出来
  16. for i in range(len(model.graph.input)):
  17. logging.debug(f"input[{i}]: {model.graph.input[i].type.tensor_type.shape}")
  18. for i in range(len(model_simp.graph.input)):
  19. logging.debug(f"input[{i}]: {model_simp.graph.input[i].type.tensor_type.shape}")
  20. # 把输出都打印出来
  21. logging.debug(model.graph.output[0].type.tensor_type.shape)
  22. logging.debug(model_simp.graph.output[0].type.tensor_type.shape)
  23. # logging.debug(f"bias: {model.cls.predictions.bias.data.shape}")
  24. # logging.debug(f"bias_sim: {model.cls.predictions.bias.data.shape}")
  25. logging.info("尺寸对比...")
  26. model_size = os.path.getsize(model_path)
  27. model_simp_size = os.path.getsize(model_simp_path)
  28. logging.debug(f"Simplified ONNX model from {model_size} to {model_simp_size} bytes")
  29. class ortInfer:
  30. def __init__(self, model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']):
  31. self.model_path = model_path
  32. self.providers = providers
  33. self.session = None
  34. def load_model(self):
  35. try:
  36. self.session = ort.InferenceSession(self.model_path, providers=self.providers)
  37. except Exception as e:
  38. logging.debug(f"Failed to load model: {e}")
  39. self.session = None
  40. def check_input(self, input_dataA, input_dataB, input_dataC, dtype):
  41. is_match = (input_dataA.dtype == dtype and
  42. input_dataB.dtype == dtype and
  43. input_dataC.dtype == dtype )
  44. return is_match
  45. def release(self):
  46. if self.session:
  47. self.session.release()
  48. def onnxRuntime_gpu(self):
  49. data = np.load("bert-base-uncased/case_data.npz", allow_pickle=True)
  50. # np.load()注意检查加载数据类型,可能出现类型转换
  51. '''
  52. encoded_input = tokenizer.encode_plus(text, return_tensors="pt")
  53. input_ids = encoded_input["input_ids"]
  54. attention_mask = encoded_input["attention_mask"]
  55. token_type_ids = encoded_input["token_type_ids"]
  56. '''
  57. input_ids = data['input_ids']
  58. # 构造attention_mask
  59. atten_mask = np.where(input_ids == 0, 0, 1)
  60. # 保持shape一致
  61. attention_mask = atten_mask.reshape(input_ids.shape)
  62. token_type_ids = data['token_type_ids']
  63. # 这里检查出数据类型不对
  64. logging.debug(f"input_ids.dtype: {input_ids.dtype}")
  65. logging.debug(f"attention_mask.dtype: {attention_mask.dtype}")
  66. logging.debug(f"token_type_ids.dtype: {token_type_ids.dtype}")
  67. # 类型转换
  68. input_ids = input_ids.astype(np.int64)
  69. # attention_mask = attention_mask.astype(np.int64)
  70. token_type_ids = token_type_ids.astype(np.int64)
  71. if not self.check_input(input_ids, attention_mask, token_type_ids, np.int64):
  72. logging.debug(f"input_ids.dtype: {input_ids.dtype}")
  73. logging.debug(f"attention_mask.dtype: {attention_mask.dtype}")
  74. logging.debug(f"token_type_ids.dtype: {token_type_ids.dtype}")
  75. raise ValueError("Input dada type mismatch!")
  76. # 对于bert模型,主要的输入是token_ids和attention_mask,输出是推理结果logits
  77. #ort.set_default_logger_severity(3)
  78. self.load_model()
  79. start = time.time()
  80. outputs = self.session.run(None, {"input_ids": input_ids,
  81. "attention_mask": attention_mask,
  82. "token_type_ids": token_type_ids
  83. })
  84. end = time.time()
  85. for output in outputs:
  86. logging.debug(f"output.shape: {output.shape} output.ndim: {output.ndim}")
  87. # logits = outputs[0]
  88. # softmax = torch.softmax(torch.tensor(logits), dim=-1)
  89. # values, indices = torch.topk(softmax, 5)
  90. # for i in indices[0]:
  91. # logging.info(tokenizer.decode([i]))
  92. name = (((self.model_path.split('/'))[1]).split('.'))[0]
  93. Tort = end - start
  94. logging.info(self.session.get_providers())
  95. # self.session.get_default_device()
  96. # logging.info(name + "by ONNX Runtime GPU inference time:" + Tort)
  97. logging.info("{} by ONNX Runtime GPU inference time: {}s".format(name, Tort))
  98. '''
  99. case_data.npz:
  100. - 这是在PyTorch模型测试时保存输入输出的文件,使用numpy的npz格式保存。
  101. - 里面包含了input_ids、token_type_ids、position_ids和模型输出logits。
  102. - 这些数据可以作为ONNX Runtime推理的输入。
  103. data:
  104. - 通过numpy.load()加载case_data.npz文件,得到一个字典。
  105. - 里面包含输入和输出的numpy数组。
  106. input_ids:
  107. - BERT模型输入中的一个,对应输入文本的id化结果。
  108. - shape类似(1,序列长度),1是batch_size。
  109. token_type_ids:
  110. - BERT模型输入中的另一个,标识token属于句子1还是2。
  111. - shape同input_ids。
  112. outputs:
  113. - ONNX Runtime推理后的模型输出,是一个list。
  114. - 里面按序包含了模型所有的输出tensor。
  115. logits:
  116. - outputs[0],也就是模型输出的第一个tensor。
  117. - 对应BERT模型推理后的logits结果,后续可以通过该数组得到预测结果。
  118. - shape类似(1,序列长度,词表大小)。
  119. 举个输入数据的例子:
  120. input_ids = [[101, 2054, 2003, ...]]
  121. # 输入序列token id
  122. token_type_ids = [[0, 0, 0, ...]]
  123. # 全部为句子1
  124. position_ids = [[0, 1, 2, ...]]
  125. # 对应位置
  126. outputs = [logits_array, ...]
  127. # 模型输出
  128. logits = logits_array
  129. # shape (1, 序列长度, 词表大小)
  130. '''
  131. '''
  132. set_default_logger_severity方法可以设置ONNX Runtime的默认日志级别。
  133. 日志级别包括:
  134. 0 - ORT_LOGGING_LEVEL_VERBOSE
  135. 1 - ORT_LOGGING_LEVEL_INFO
  136. 2 - ORT_LOGGING_LEVEL_WARNING
  137. 3 - ORT_LOGGING_LEVEL_ERROR
  138. 4 - ORT_LOGGING_LEVEL_FATAL
  139. 数字越大,日志级别越高。
  140. ort.set_default_logger_severity(3)表示设置ONNX Runtime的日志级别为ORT_LOGGING_LEVEL_ERROR。
  141. 也就是说,只会打印错误和致命级别的日志信息,屏蔽信息、警告和详细日志。
  142. 这通常在生产环境中使用,可以避免输出过多无用的运行日志,只打印错误信息。
  143. 我们也可以根据需要设置其他级别,如:
  144. python
  145. # 打印所有详细日志
  146. ort.set_default_logger_severity(0)
  147. # 只打印错误
  148. ort.set_default_logger_severity(3)
  149. # 完全不打印日志
  150. ort.set_default_logger_severity(4)
  151. '''
  152. if __name__ == '__main__':
  153. current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
  154. logging.basicConfig(filename='simp.log', level=logging.DEBUG)
  155. model_path = 'bert-base-uncased/model.onnx'
  156. model_simp_path = 'bert-base-uncased/model_simp.onnx'
  157. logging.info("---------------------------------------------------------------------------------------------")
  158. logging.info(current_time)
  159. model_simplify(model_path, model_simp_path)
  160. model_simp = ortInfer(model_simp_path)
  161. model_simp.onnxRuntime_gpu()
  162. model = ortInfer(model_path)
  163. model.onnxRuntime_gpu()

5. 引入TensorRT- 加载TensorRT Python API
- 从ONNX模型导入TensorRT网络
- 打印网络层信息,与ONNX对比
- 为Builder配置工作参数(batchsize等)

6. 构建TensorRT引擎- 使用Builder构建引擎
- 生成序列化模型,并保存为文件
- 在新的脚本中反序列化加载
- 重新检查模型输出正确性

  1. #include <fstream>
  2. #include <iostream>
  3. #include <memory>
  4. #include <NvInfer.h>
  5. #include <NvOnnxParser.h>
  6. using namespace nvinfer1;
  7. using namespace nvonnxparser;
  8. class Logger : public nvinfer1::ILogger{
  9. public:
  10. void log(Severity severity, const char* msg) noexcept override {
  11. // suppress info-level messages
  12. if (severity != Severity::kINFO)
  13. std::cout << msg << std::endl;
  14. }
  15. } gLogger;
  16. int main(int argc, char** argv)
  17. {
  18. // Create builder
  19. // std::unique_ptr<ICudaEngine> engine{nullptr};
  20. IBuilder* builder = nvinfer1::createInferBuilder(gLogger);
  21. const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  22. IBuilderConfig* config = builder->createBuilderConfig();
  23. // Create model to populate the network
  24. INetworkDefinition* network = builder->createNetworkV2(explicitBatch);
  25. // Parse ONNX file
  26. IParser* parser = nvonnxparser::createParser(*network, gLogger);
  27. bool parser_status = parser->parseFromFile("../bert-base-uncased/model_simp.onnx", static_cast<int>(ILogger::Severity::kWARNING));
  28. // Get the name of network input
  29. Dims dim = network->getInput(0)->getDimensions();
  30. if (dim.d[0] == -1 && dim.nbDims > 0) // -1 means it is a dynamic model
  31. {
  32. const char* name0 = network->getInput(0)->getName();
  33. IOptimizationProfile* profile = builder->createOptimizationProfile();
  34. profile->setDimensions(name0, OptProfileSelector::kMIN, Dims2(1, 6));
  35. profile->setDimensions(name0, OptProfileSelector::kOPT, Dims2(1, 64));
  36. profile->setDimensions(name0, OptProfileSelector::kMAX, Dims2(1, 256));
  37. const char* name1 = network->getInput(1)->getName();
  38. profile->setDimensions(name1, OptProfileSelector::kMIN, Dims2(1, 6));
  39. profile->setDimensions(name1, OptProfileSelector::kOPT, Dims2(1, 64));
  40. profile->setDimensions(name1, OptProfileSelector::kMAX, Dims2(1, 256));
  41. const char* name2 = network->getInput(2)->getName();
  42. profile->setDimensions(name2, OptProfileSelector::kMIN, Dims2(1, 6));
  43. profile->setDimensions(name2, OptProfileSelector::kOPT, Dims2(1, 64));
  44. profile->setDimensions(name2, OptProfileSelector::kMAX, Dims2(1, 256));
  45. config->addOptimizationProfile(profile);
  46. }
  47. // Build engine
  48. config->setMaxWorkspaceSize(10000 << 20);
  49. // ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
  50. // engine.reset(builder->buildEngineWithConfig(*network, *config));
  51. // Serialize the model to engine file
  52. IHostMemory* serialized_engine{ nullptr };
  53. // assert(engine != nullptr);
  54. serialized_engine = builder->buildSerializedNetwork(*network, *config);;
  55. // config->setEngineCapability(nvinfer1::EngineCapability::kSAFETY);
  56. // config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 1<<40); // 10GB workspace
  57. // nvinfer1::IHostMemory* serialized_engine = builder->buildSerializedNetwork(*network, *config);
  58. std::ofstream output("model.plan", std::ios::binary);
  59. if (!output) {
  60. std::cerr << "Failed to open plan file: " << "model.plan" << std::endl;
  61. return -1;
  62. }
  63. // 通过reinterpret_cast将数据指针转换为const char*
  64. // 将序列化后的引擎数据写入文件,写入的数据来自serialized_engine->data(),大小为serialized_engine->size()
  65. output.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());
  66. std::cout << "generate file success!" << std::endl;
  67. output.close();
  68. // Release resources
  69. delete serialized_engine;
  70. // delete engine;
  71. delete config;
  72. delete parser;
  73. delete network;
  74. delete builder;
  75. // modelStream->destroy();
  76. // engine->destroy();
  77. // config->destroy();
  78. // parser->destroy();
  79. // network->destroy();
  80. // builder->destroy();
  81. return 0;
  82. }

7. 模型校准- 准备校准数据集,或生成随机输入
- 使用EntropyCalibrator进行INT8校准
- 校准后生成量化的TensorRT引擎

8. 模型评估- 用校准数据集进行多轮测试
- 统计平均推理时间
- 计算与PyTorch模型的加速比
- 保存量化后的TensorRT引擎

9. 模型部署- 将TensorRT引擎封装为Python接口
- 可选:用C++ API加载引擎,编译为.so库
- 在实际环境中部署TensorRT模型

10. 结果分析- 整理日志,生成统计报告 
- 分析每个转换步骤的优化效果
- 总结经验教训,改进流程

注意事项:

1. 模型转换要注意输入输出格式的对应,确保ONNX模型和原始PyTorch模型的输入输出一致。

2. 在转为TensorRT模型时,可以适当调整最大batchsize、workspace大小等参数,以取得最佳的性能。

3. 可以在转TRT模型时开启FP16精度,可能会进一步提升性能。

4. TRT模型测速时,要运行足够多次取平均,以消除误差。同时要注意需要进行模型和数据的预热,避免第一次运行有额外的开销。

5. 最终模型测速要对比ONNX模型和原始PyTorch模型,看TensorRT优化的效果如何。同时也要和规定的baseline作比较。

6. 如果条件允许,可以尝试更大的序列长度,看对转换后的性能的影响。

7. 可以考虑使用Python API,因为它更方便调试和修改代码。C++代码可以在最终确定模型后使用。

8. 推荐使用onnx-simplifier等工具简化模型,可以减少不必要的运算,提高转换性能。

9. 模型转换和测速都要记录详细日志,包括每一步的时间、精度数字等,方便后期分析和改进。

10. TensorRT 在创建引擎时,默认情况下会确定网络中的所有 tensor(包括输入和输出)的形状和大小信息。但是对于输入 tensors,TensorRT 允许用户在执行时 Override 它们的形状,这就是 setBindingDimensions 的作用。而对于输出 tensors,其形状和大小是已固定的,不能再次 Override。这么做的设计考虑是:1. 输入形状可能需要根据不同的 Batch Size 进行调整,所以提供了Override的能力。2. 但输出形状应该是固定的,防止用户误操作导致内存出错。3. 引擎内部会根据输入形状自动计算输出形状和大小。这样既提供了输入形状的灵活性,也保证了输出的安全性。

PyTorch到ONNX的转换代码

1. 模型测试部分可以保留,用来验证ONNX模型转换后的正确性。

2. encode_plus()编码后的输入不需要保存在NPZ文件中,可以直接传给export接口。

3. export时输入名应该对应encode_plus的输出字典key,比如input_ids, attention_mask等。

4. 输出名logits可以更详细一些,如‘predictions’或‘output_logits’。

5. dynamic_axes参数可以设置为None,直接使用批量大小和序列长度作为常量。

6. 可以设置do_constant_folding=True开启常量折叠优化。

7. 导出后可以用onnxruntime等工具推理,和PyTorch输出对比验证正确性。

8. 可以使用onnx-simplifier等工具进一步优化ONNX模型。

9. 为了部署,可以通过--opset指定较老的opset版本,如11。

10. 需要注意ONNX模型文件大小,避免过大导致部署问题。

11. 可以参考官方示例,导出时设置训练是否为True,以及排除不需要的层。

12. 需要记录日志,保存模型转换的参数设置,方便排查问题。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/348189
推荐阅读
相关标签
  

闽ICP备14008679号