当前位置:   article > 正文

TensorRT&Sample&Python[end_to_end_tensorflow_mnist]

tensorrt samples python

本文是基于TensorRT 5.0.2基础上,关于其内部的end_to_end_tensorflow_mnist例子的分析和介绍。

1 引言

假设当前路径为:

TensorRT-5.0.2.6/samples

其对应当前例子文件目录树为:

  1. # tree python
  2. python
  3. ├── common.py
  4. ├── end_to_end_tensorflow_mnist
  5. │   ├── model.py
  6. │   ├── README.md
  7. │   ├── requirements.txt
  8. │   └── sample.py

2 基于tensorflow生成模型

其中只有2个文件:

  • model:该文件包含简单的训练模型代码
  • sample:该文件使用UFF mnist模型去创建一个TensorRT inference engine

首先介绍下model.py

  1. # 该脚本包含一个简单的模型训练过程
  2. import tensorflow as tf
  3. import numpy as np
  4. '''main中第一步:获取数据集 '''
  5. def process_dataset():
  6. # 导入mnist数据集
  7. # 手动下载aria2c -x 16 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
  8. # 将mnist.npz移动到~/.keras/datasets/
  9. # tf.keras.datasets.mnist.load_data会去读取~/.keras/datasets/mnist.npz,而不从网络下载
  10. (x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
  11. x_train, x_test = x_train / 255.0, x_test / 255.0
  12. # Reshape
  13. NUM_TRAIN = 60000
  14. NUM_TEST = 10000
  15. x_train = np.reshape(x_train, (NUM_TRAIN, 28, 28, 1))
  16. x_test = np.reshape(x_test, (NUM_TEST, 28, 28, 1))
  17. return x_train, y_train, x_test, y_test
  18. '''main中第二步:构建模型 '''
  19. def create_model():
  20. model = tf.keras.models.Sequential()
  21. model.add(tf.keras.layers.InputLayer(input_shape=[28,28, 1]))
  22. model.add(tf.keras.layers.Flatten())
  23. model.add(tf.keras.layers.Dense(512, activation=tf.nn.relu))
  24. model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
  25. model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  26. return model
  27. '''main中第五步:模型存储 '''
  28. def save(model, filename):
  29. output_names = model.output.op.name
  30. sess = tf.keras.backend.get_session()
  31. # freeze graph
  32. frozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_names])
  33. # 移除训练的节点
  34. frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)
  35. # 保存模型
  36. with open(filename, "wb") as ofile:
  37. ofile.write(frozen_graph.SerializeToString())
  38. def main():
  39. ''' 1 - 获取数据'''
  40. x_train, y_train, x_test, y_test = process_dataset()
  41. ''' 2 - 构建模型'''
  42. model = create_model()
  43. ''' 3 - 模型训练'''
  44. model.fit(x_train, y_train, epochs = 5, verbose = 1)
  45. ''' 4 - 模型评估'''
  46. model.evaluate(x_test, y_test)
  47. ''' 5 - 模型存储'''
  48. save(model, filename="models/lenet5.pb")
  49. if __name__ == '__main__':
  50. main()

在获得

models/lenet5.pb

之后,执行下述命令,将其转换成uff文件,输出结果如

  1. '''该converter会显示关于input/output nodes的信息,这样你就可以用来在解析的时候进行注册;
  2. 本例子中,我们基于tensorflow.keras的命名规则,事先已知input/output nodes名称了 '''
  3. [root@30d4bceec4c4 end_to_end_tensorflow_mnist]# convert-to-uff models/lenet5.pb
  4. Loading models/lenet5.pb

441382-20190313200015204-370012470.png

3 基于tensorflow的pb文件生成UFF并处理

  1. # 该例子使用UFF MNIST 模型去创建一个TensorRT Inference Engine
  2. from random import randint
  3. from PIL import Image
  4. import numpy as np
  5. import pycuda.driver as cuda
  6. import pycuda.autoinit # 该import会让pycuda自动管理CUDA上下文的创建和清理工作
  7. import tensorrt as trt
  8. import sys, os
  9. # sys.path.insert(1, os.path.join(sys.path[0], ".."))
  10. # import common
  11. # 这里将common中的GiB和find_sample_dataallocate_buffers,do_inference等函数移动到该py文件中,保证自包含。
  12. def GiB(val):
  13. '''以GB为单位,计算所需要的存储值,向左位移10bit表示KB,20bit表示MB '''
  14. return val * 1 << 30
  15. def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]):
  16. '''该函数就是一个参数解析函数。
  17. Parses sample arguments.
  18. Args:
  19. description (str): Description of the sample.
  20. subfolder (str): The subfolder containing data relevant to this sample
  21. find_files (str): A list of filenames to find. Each filename will be replaced with an absolute path.
  22. Returns:
  23. str: Path of data directory.
  24. Raises:
  25. FileNotFoundError
  26. '''
  27. # 为了简洁,这里直接将路径硬编码到代码中。
  28. data_root = kDEFAULT_DATA_ROOT = os.path.abspath("/TensorRT-5.0.2.6/python/data/")
  29. subfolder_path = os.path.join(data_root, subfolder)
  30. if not os.path.exists(subfolder_path):
  31. print("WARNING: " + subfolder_path + " does not exist. Using " + data_root + " instead.")
  32. data_path = subfolder_path if os.path.exists(subfolder_path) else data_root
  33. if not (os.path.exists(data_path)):
  34. raise FileNotFoundError(data_path + " does not exist.")
  35. for index, f in enumerate(find_files):
  36. find_files[index] = os.path.abspath(os.path.join(data_path, f))
  37. if not os.path.exists(find_files[index]):
  38. raise FileNotFoundError(find_files[index] + " does not exist. ")
  39. if find_files:
  40. return data_path, find_files
  41. else:
  42. return data_path
  43. #-----------------
  44. TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
  45. class ModelData(object):
  46. MODEL_FILE = os.path.join(os.path.dirname(__file__), "models/lenet5.uff")
  47. INPUT_NAME ="input_1"
  48. INPUT_SHAPE = (1, 28, 28)
  49. OUTPUT_NAME = "dense_1/Softmax"
  50. '''main中第二步:构建engine'''
  51. def build_engine(model_file):
  52. with trt.Builder(TRT_LOGGER) as builder, \
  53. builder.create_network() as network, \
  54. trt.UffParser() as parser:
  55. builder.max_workspace_size = GiB(1)
  56. # 解析 Uff 网络
  57. parser.register_input(ModelData.INPUT_NAME, ModelData.INPUT_SHAPE)
  58. parser.register_output(ModelData.OUTPUT_NAME)
  59. parser.parse(model_file, network)
  60. # 构建并返回一个engine
  61. return builder.build_cuda_engine(network)
  62. '''main中第三步 '''
  63. def allocate_buffers(engine):
  64. inputs = []
  65. outputs = []
  66. bindings = []
  67. stream = cuda.Stream()
  68. for binding in engine:
  69. size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
  70. dtype = trt.nptype(engine.get_binding_dtype(binding))
  71. # 分配host和device端的buffer
  72. host_mem = cuda.pagelocked_empty(size, dtype)
  73. device_mem = cuda.mem_alloc(host_mem.nbytes)
  74. # 将device端的buffer追加到device的bindings.
  75. bindings.append(int(device_mem))
  76. # Append to the appropriate list.
  77. if engine.binding_is_input(binding):
  78. inputs.append(HostDeviceMem(host_mem, device_mem))
  79. else:
  80. outputs.append(HostDeviceMem(host_mem, device_mem))
  81. return inputs, outputs, bindings, stream
  82. '''main中第四步 '''
  83. # 从pagelocked_buffer.中读取测试样本
  84. def load_normalized_test_case(data_path, pagelocked_buffer, case_num=randint(0, 9)):
  85. test_case_path = os.path.join(data_path, str(case_num) + ".pgm")
  86. # Flatten该图像成为一个1维数组,然后归一化,并copy到host端的 pagelocked内存中.
  87. img = np.array(Image.open(test_case_path)).ravel()
  88. np.copyto(pagelocked_buffer, 1.0 - img / 255.0)
  89. return case_num
  90. '''main中第五步:执行inference '''
  91. # 该函数可以适应多个输入/输出;输入和输出格式为HostDeviceMem对象组成的列表
  92. def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
  93. # 将数据移动到GPU
  94. [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
  95. # 执行inference.
  96. context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
  97. # 将结果从 GPU写回到host端
  98. [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
  99. # 同步stream
  100. stream.synchronize()
  101. # 返回host端的输出结果
  102. return [out.host for out in outputs]
  103. def main():
  104. ''' 1 - 寻找模型文件'''
  105. data_path = find_sample_data(
  106. description="Runs an MNIST network using a UFF model file",
  107. subfolder="mnist")
  108. model_file = ModelData.MODEL_FILE
  109. ''' 2 - 基于build_engine函数构建engine'''
  110. with build_engine(model_file) as engine:
  111. ''' 3 - 分配buffer并创建一个流'''
  112. inputs, outputs, bindings, stream = allocate_buffers(engine)
  113. with engine.create_execution_context() as context:
  114. ''' 4 - 读取测试样本,并归一化'''
  115. case_num = load_normalized_test_case(data_path, pagelocked_buffer=inputs[0].host)
  116. ''' 5 - 执行inference,do_inference函数会返回一个list类型,此处只有一个元素'''
  117. [output] = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
  118. pred = np.argmax(output)
  119. print("Test Case: " + str(case_num))
  120. print("Prediction: " + str(pred))
  121. if __name__ == '__main__':
  122. main()

结果如:

441382-20190313200623396-1861776747.png

转载于:https://www.cnblogs.com/shouhuxianjian/p/10525000.html

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/57569
推荐阅读
相关标签
  

闽ICP备14008679号