当前位置:   article > 正文

TensorRT学习笔记--官方示例sampleOnnxMNIST.cpp的理解与运行

params.inputtensornames.push_back

前言

        重点关注官方示例如何导入 ONNX 模型,以及相关 API 的用法。

1--代码

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <cuda_runtime_api.h>

#include "NvInfer.h"
#include "NvOnnxParser.h"

#include "argsParser.h"
#include "buffers.h"
#include "common.h"
#include "logger.h"
#include "parserOnnxConfig.h"
  12. using samplesCommon::SampleUniquePtr;
  13. const std::string gSampleName = "TensorRT.sample_onnx_mnist";
  14. class SampleOnnxMNIST
  15. {
  16. public:
  17. SampleOnnxMNIST(const samplesCommon::OnnxSampleParams& params): mParams(params), mEngine(nullptr){}
  18. bool build();
  19. bool infer();
  20. private:
  21. samplesCommon::OnnxSampleParams mParams;
  22. nvinfer1::Dims mInputDims; // 网络的输入维度
  23. nvinfer1::Dims mOutputDims; // 网络的输出维度
  24. int mNumber{0}; // 类别数
  25. std::shared_ptr<nvinfer1::ICudaEngine> mEngine; // 初始化engine
  26. bool constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
  27. SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
  28. SampleUniquePtr<nvonnxparser::IParser>& parser);
  29. // 前处理
  30. bool processInput(const samplesCommon::BufferManager& buffers);
  31. // 验证结果
  32. bool verifyOutput(const samplesCommon::BufferManager& buffers);
  33. };
  34. bool SampleOnnxMNIST::build()
  35. {
  36. // 创建 builder
  37. auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
  38. if (!builder)
  39. {
  40. return false;
  41. }
  42. // 显式 batch
  43. const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  44. // 创建network
  45. auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
  46. if (!network)
  47. {
  48. return false;
  49. }
  50. // 创建config
  51. auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
  52. if (!config)
  53. {
  54. return false;
  55. }
  56. // 创建parser
  57. auto parser = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
  58. if (!parser)
  59. {
  60. return false;
  61. }
  62. // 调用成员函数生成网络
  63. auto constructed = constructNetwork(builder, network, config, parser);
  64. if (!constructed)
  65. {
  66. return false;
  67. }
  68. // CUDA stream used for profiling by the builder.
  69. auto profileStream = samplesCommon::makeCudaStream(); // 创建Cuda stream
  70. if (!profileStream)
  71. {
  72. return false;
  73. }
  74. config->setProfileStream(*profileStream);
  75. SampleUniquePtr<IHostMemory> plan{builder->buildSerializedNetwork(*network, *config)}; // 创建推理引擎
  76. if (!plan)
  77. {
  78. return false;
  79. }
  80. SampleUniquePtr<IRuntime> runtime{createInferRuntime(sample::gLogger.getTRTLogger())}; // 创建Runtime接口
  81. if (!runtime)
  82. {
  83. return false;
  84. }
  85. mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
  86. runtime->deserializeCudaEngine(plan->data(), plan->size()), samplesCommon::InferDeleter());
  87. if (!mEngine)
  88. {
  89. return false;
  90. }
  91. ASSERT(network->getNbInputs() == 1); // 输入 batch 为1
  92. mInputDims = network->getInput(0)->getDimensions();
  93. ASSERT(mInputDims.nbDims == 4); // 输入维度为 4(包含 batch 维度)
  94. ASSERT(network->getNbOutputs() == 1); // 输出 batch 为1
  95. mOutputDims = network->getOutput(0)->getDimensions();
  96. ASSERT(mOutputDims.nbDims == 2); // 输出维度为 2
  97. return true;
  98. }
  99. // constructNetwork 成员函数实现
  100. bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
  101. SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
  102. SampleUniquePtr<nvonnxparser::IParser>& parser)
  103. {
  104. auto parsed = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),
  105. static_cast<int>(sample::gLogger.getReportableSeverity())); // 解析onnx模型
  106. if (!parsed)
  107. {
  108. return false;
  109. }
  110. config->setMaxWorkspaceSize(16_MiB); // 设置最大工作空间
  111. if (mParams.fp16) // 设置精度
  112. {
  113. config->setFlag(BuilderFlag::kFP16);
  114. }
  115. if (mParams.int8)
  116. {
  117. config->setFlag(BuilderFlag::kINT8);
  118. samplesCommon::setAllDynamicRanges(network.get(), 127.0f, 127.0f);
  119. }
  120. samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
  121. return true;
  122. }
  123. bool SampleOnnxMNIST::infer()
  124. {
  125. // 创建内存管理
  126. samplesCommon::BufferManager buffers(mEngine);
  127. // 创建context
  128. auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
  129. if (!context)
  130. {
  131. return false;
  132. }
  133. // 从 buffer 中读取数据,并进行前处理
  134. ASSERT(mParams.inputTensorNames.size() == 1);
  135. if (!processInput(buffers))
  136. {
  137. return false;
  138. }
  139. // 将数据复制到 GPU 中
  140. buffers.copyInputToDevice();
  141. bool status = context->executeV2(buffers.getDeviceBindings().data()); // 执行推理
  142. if (!status)
  143. {
  144. return false;
  145. }
  146. // 将结果复制到 CPU 中
  147. buffers.copyOutputToHost();
  148. // 验证推理结果
  149. if (!verifyOutput(buffers))
  150. {
  151. return false;
  152. }
  153. return true;
  154. }
  155. // 前处理成员函数实现
  156. bool SampleOnnxMNIST::processInput(const samplesCommon::BufferManager& buffers) // 传入 buffers 的引用
  157. {
  158. const int inputH = mInputDims.d[2];
  159. const int inputW = mInputDims.d[3];
  160. // 随机选择一个数据进行读取
  161. srand(unsigned(time(nullptr)));
  162. std::vector<uint8_t> fileData(inputH * inputW);
  163. mNumber = rand() % 10;
  164. readPGMFile(locateFile(std::to_string(mNumber) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);
  165. // Print an ascii representation
  166. sample::gLogInfo << "Input:" << std::endl;
  167. for (int i = 0; i < inputH * inputW; i++)
  168. {
  169. sample::gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
  170. }
  171. sample::gLogInfo << std::endl;
  172. // buffers.getHostBuffer() 返回 Name 对应的 buffer 地址
  173. float* hostDataBuffer = static_cast<float*>(buffers.getHostBuffer(mParams.inputTensorNames[0]));
  174. for (int i = 0; i < inputH * inputW; i++) // 对这段 buffer 地址的内容进行赋值操作,由于 processInput() 函数传入的是buffer引用,所以能改变buffer的值
  175. {
  176. hostDataBuffer[i] = 1.0 - float(fileData[i] / 255.0);
  177. }
  178. return true;
  179. }
  180. // 验证推理结果的成员函数实现
  181. bool SampleOnnxMNIST::verifyOutput(const samplesCommon::BufferManager& buffers)
  182. {
  183. const int outputSize = mOutputDims.d[1];
  184. float* output = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[0]));
  185. float val{0.0f};
  186. int idx{0};
  187. // Calculate Softmax
  188. float sum{0.0f};
  189. for (int i = 0; i < outputSize; i++)
  190. {
  191. output[i] = exp(output[i]);
  192. sum += output[i];
  193. }
  194. sample::gLogInfo << "Output:" << std::endl;
  195. for (int i = 0; i < outputSize; i++)
  196. {
  197. output[i] /= sum;
  198. val = std::max(val, output[i]);
  199. if (val == output[i])
  200. {
  201. idx = i;
  202. }
  203. sample::gLogInfo << " Prob " << i << " " << std::fixed << std::setw(5) << std::setprecision(4) << output[i]
  204. << " "
  205. << "Class " << i << ": " << std::string(int(std::floor(output[i] * 10 + 0.5f)), '*')
  206. << std::endl;
  207. }
  208. sample::gLogInfo << std::endl;
  209. return idx == mNumber && val > 0.9f;
  210. }
  211. // 初始化参数
  212. samplesCommon::OnnxSampleParams initializeSampleParams(const samplesCommon::Args& args)
  213. {
  214. samplesCommon::OnnxSampleParams params; // 创建参数对象
  215. if (args.dataDirs.empty()) //!< Use default directories if user hasn't provided directory paths
  216. {
  217. params.dataDirs.push_back("data/mnist/");
  218. params.dataDirs.push_back("data/samples/mnist/");
  219. }
  220. else //!< Use the data directory provided by the user
  221. {
  222. params.dataDirs = args.dataDirs;
  223. }
  224. // 设置参数对象的默认属性
  225. params.onnxFileName = "mnist.onnx";
  226. params.inputTensorNames.push_back("Input3");
  227. params.outputTensorNames.push_back("Plus214_Output_0");
  228. params.dlaCore = args.useDLACore;
  229. params.int8 = args.runInInt8;
  230. params.fp16 = args.runInFp16;
  231. return params;
  232. }
  233. // 打印参数信息
  234. void printHelpInfo()
  235. {
  236. std::cout
  237. << "Usage: ./sample_onnx_mnist [-h or --help] [-d or --datadir=<path to data directory>] [--useDLACore=<int>]"
  238. << std::endl;
  239. std::cout << "--help Display help information" << std::endl;
  240. std::cout << "--datadir Specify path to a data directory, overriding the default. This option can be used "
  241. "multiple times to add multiple directories. If no data directories are given, the default is to use "
  242. "(data/samples/mnist/, data/mnist/)"
  243. << std::endl;
  244. std::cout << "--useDLACore=N Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, "
  245. "where n is the number of DLA engines on the platform."
  246. << std::endl;
  247. std::cout << "--int8 Run in Int8 mode." << std::endl;
  248. std::cout << "--fp16 Run in FP16 mode." << std::endl;
  249. }
  250. int main(int argc, char** argv)
  251. {
  252. samplesCommon::Args args;
  253. bool argsOK = samplesCommon::parseArgs(args, argc, argv);
  254. if (!argsOK)
  255. {
  256. sample::gLogError << "Invalid arguments" << std::endl;
  257. printHelpInfo();
  258. return EXIT_FAILURE;
  259. }
  260. if (args.help)
  261. {
  262. printHelpInfo();
  263. return EXIT_SUCCESS;
  264. }
  265. auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
  266. sample::gLogger.reportTestStart(sampleTest);
  267. SampleOnnxMNIST sample(initializeSampleParams(args)); // 创建对象
  268. sample::gLogInfo << "Building and running a GPU inference engine for Onnx MNIST" << std::endl;
  269. if (!sample.build())
  270. {
  271. return sample::gLogger.reportFail(sampleTest);
  272. }
  273. if (!sample.infer())
  274. {
  275. return sample::gLogger.reportFail(sampleTest);
  276. }
  277. return sample::gLogger.reportPass(sampleTest);
  278. }

2--编译

① CMakeLists.txt:

  1. cmake_minimum_required(VERSION 3.13)
  2. project(TensorRT_test)
  3. set(CMAKE_CXX_STANDARD 11)
  4. set(SAMPLES_COMMON_SOURCES "/home/liujinfu/Downloads/TensorRT-8.2.5.1/samples/common/logger.cpp")
  5. add_executable(TensorRT_test_OnnxMNIST sampleOnnxMNIST.cpp ${SAMPLES_COMMON_SOURCES})
  6. # add TensorRT8
  7. include_directories(/home/liujinfu/Downloads/TensorRT-8.2.5.1/include)
  8. include_directories(/home/liujinfu/Downloads/TensorRT-8.2.5.1/samples/common)
  9. set(TENSORRT_LIB_PATH "/home/liujinfu/Downloads/TensorRT-8.2.5.1/lib")
  10. file(GLOB LIBS "${TENSORRT_LIB_PATH}/*.so")
  11. # add CUDA
  12. find_package(CUDA 11.3 REQUIRED)
  13. message("CUDA_LIBRARIES:${CUDA_LIBRARIES}")
  14. message("CUDA_INCLUDE_DIRS:${CUDA_INCLUDE_DIRS}")
  15. include_directories(${CUDA_INCLUDE_DIRS})
  16. # link
  17. target_link_libraries(TensorRT_test_OnnxMNIST ${LIBS} ${CUDA_LIBRARIES})

② 编译

  1. mkdir build && cd build
  2. cmake ..
  3. make

3--运行结果

./TensorRT_test_OnnxMNIST -d /home/liujinfu/Downloads/TensorRT-8.2.5.1/data/mnist

 

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号