当前位置:   article > 正文

TensorRT学习笔记--官方示例sampleMNIST.cpp的理解与运行_sample_mnist

sample_mnist

目录

1--前言

2--编写CMakeLists.txt

3--代码注释

4--运行结果


1--前言

        本文基于 TensorRT 8.2.5.1 和 CUDA 11.3,在 Ubuntu 20.04 上通过编写 CMakeLists.txt 编译官方示例 sampleMNIST.cpp,生成可执行文件并运行。

2--编写CMakeLists.txt

        下面提供一个 CMakeLists.txt 示例,其中的 TensorRT 路径需根据本机的实际安装位置修改:

  # Build script for the annotated TensorRT sampleMNIST example.
  cmake_minimum_required(VERSION 3.13)
  project(TensorRT_test)

  set(CMAKE_CXX_STANDARD 11)
  set(CMAKE_CXX_STANDARD_REQUIRED ON)

  # Root of the unpacked TensorRT release. Override it on the command line
  # instead of editing this file:
  #   cmake -DTENSORRT_ROOT=/path/to/TensorRT-8.2.5.1 ..
  set(TENSORRT_ROOT "/home/liujinfu/Downloads/TensorRT-8.2.5.1"
      CACHE PATH "Root directory of the TensorRT installation")

  # The sample uses the logger implementation shipped with the TensorRT samples.
  set(SAMPLES_COMMON_SOURCES "${TENSORRT_ROOT}/samples/common/logger.cpp")
  add_executable(TensorRT_test sampleMNIST.cpp ${SAMPLES_COMMON_SOURCES})

  # add TensorRT8
  target_include_directories(TensorRT_test PRIVATE
      "${TENSORRT_ROOT}/include"
      "${TENSORRT_ROOT}/samples/common")
  set(TENSORRT_LIB_PATH "${TENSORRT_ROOT}/lib")
  file(GLOB LIBS "${TENSORRT_LIB_PATH}/*.so")
  if(NOT LIBS)
    # An empty glob would otherwise surface later as obscure link errors.
    message(FATAL_ERROR "No TensorRT libraries found in ${TENSORRT_LIB_PATH}; set TENSORRT_ROOT")
  endif()

  # add CUDA
  find_package(CUDA 11.3 REQUIRED)
  message("CUDA_LIBRARIES:${CUDA_LIBRARIES}")
  message("CUDA_INCLUDE_DIRS:${CUDA_INCLUDE_DIRS}")
  target_include_directories(TensorRT_test PRIVATE ${CUDA_INCLUDE_DIRS})

  # link
  target_link_libraries(TensorRT_test ${LIBS} ${CUDA_LIBRARIES})

3--代码注释

  1. // 包含头文件
  2. #include "argsParser.h"
  3. #include "buffers.h"
  4. #include "common.h"
  5. #include "logger.h"
  6. #include "NvCaffeParser.h"
  7. #include "NvInfer.h"
  8. #include <algorithm>
  9. #include <cmath>
  10. #include <cuda_runtime_api.h>
  11. #include <fstream>
  12. #include <iostream>
  13. #include <sstream>
  14. using samplesCommon::SampleUniquePtr;
  15. const std::string gSampleName = "TensorRT.sample_mnist";
  16. // 定义 SampleMNIST 类
  17. class SampleMNIST
  18. {
  19. public:
  20. // 构造函数
  21. SampleMNIST(const samplesCommon::CaffeSampleParams& params)
  22. : mParams(params)
  23. {
  24. }
  25. // 用于 build engine 的成员函数
  26. bool build();
  27. // 用于 inference 的成员函数
  28. bool infer();
  29. // 用于 清理状态 的成员函数
  30. bool teardown();
  31. private:
  32. // 使用 caffe 的 parser 创建网络并标记输出层
  33. bool constructNetwork(
  34. SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser, SampleUniquePtr<nvinfer1::INetworkDefinition>& network);
  35. // 对输入数据进行前处理
  36. bool processInput(
  37. const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIdx) const;
  38. // 验证输出结果是否正确并打印
  39. bool verifyOutput(
  40. const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const;
  41. std::shared_ptr<nvinfer1::ICudaEngine> mEngine{nullptr}; //!< The TensorRT engine used to run the network
  42. samplesCommon::CaffeSampleParams mParams; //!< The parameters for the sample.
  43. nvinfer1::Dims mInputDims; // 输入数据的维度
  44. SampleUniquePtr<nvcaffeparser1::IBinaryProtoBlob>
  45. mMeanBlob; //! the mean blob, which we need to keep around until build is done
  46. };
  47. // build engine 的成员函数实现
  48. bool SampleMNIST::build()
  49. {
  50. // 创建 builder
  51. auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
  52. if (!builder)
  53. {
  54. return false;
  55. }
  56. // 使用 builder 创建 network
  57. auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(0));
  58. if (!network)
  59. {
  60. return false;
  61. }
  62. // 创建config,用于配置模型
  63. auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
  64. if (!config)
  65. {
  66. return false;
  67. }
  68. // 创建parser,用于解析模型
  69. auto parser = SampleUniquePtr<nvcaffeparser1::ICaffeParser>(nvcaffeparser1::createCaffeParser());
  70. if (!parser)
  71. {
  72. return false;
  73. }
  74. // 执行 constructNetwork() 成员函数,使用 parser 创建 MNIST network 并标记输出层
  75. if (!constructNetwork(parser, network))
  76. {
  77. return false;
  78. }
  79. // 设置推理网络的相关参数
  80. builder->setMaxBatchSize(mParams.batchSize); // 设定 batchsize
  81. config->setMaxWorkspaceSize(16_MiB); // 设定 workspace
  82. config->setFlag(BuilderFlag::kGPU_FALLBACK); // 启用GPU回退模式
  83. // 设置精度计算
  84. if (mParams.fp16) // 使用fp16精度
  85. {
  86. config->setFlag(BuilderFlag::kFP16);
  87. }
  88. if (mParams.int8) // 使用int8精度
  89. {
  90. config->setFlag(BuilderFlag::kINT8);
  91. }
  92. // 启用 DLA
  93. samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
  94. // 创建 Cuda stream
  95. auto profileStream = samplesCommon::makeCudaStream();
  96. if (!profileStream)
  97. {
  98. return false;
  99. }
  100. config->setProfileStream(*profileStream);
  101. // builder->buildSerializedNetwork 创建推理引擎
  102. SampleUniquePtr<IHostMemory> plan{builder->buildSerializedNetwork(*network, *config)};
  103. if (!plan)
  104. {
  105. return false;
  106. }
  107. // 序列化模型后,创建 Runtime 接口
  108. SampleUniquePtr<IRuntime> runtime{createInferRuntime(sample::gLogger.getTRTLogger())};
  109. if (!runtime)
  110. {
  111. return false;
  112. }
  113. // 反序列化
  114. mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
  115. runtime->deserializeCudaEngine(plan->data(), plan->size()), samplesCommon::InferDeleter());
  116. if (!mEngine)
  117. {
  118. return false;
  119. }
  120. // 断定输入的 batchsize 为1,即单次只能处理一个样本
  121. ASSERT(network->getNbInputs() == 1);
  122. // 获取输入的维度,断定输入数据的维度为3
  123. mInputDims = network->getInput(0)->getDimensions();
  124. ASSERT(mInputDims.nbDims == 3);
  125. // build engine 成功,返回True
  126. return true;
  127. }
  128. // 输入数据前处理成员函数的实现
  129. bool SampleMNIST::processInput(
  130. const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIdx) const
  131. {
  132. // 获取输入图像的高和宽,mInputDims 已在私有成员变量中定义
  133. const int inputH = mInputDims.d[1];
  134. const int inputW = mInputDims.d[2];
  135. // 随机读取一个数字文件
  136. srand(unsigned(time(nullptr)));
  137. std::vector<uint8_t> fileData(inputH * inputW);
  138. readPGMFile(locateFile(std::to_string(inputFileIdx) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);
  139. // 打印数字的 ASCII 表示,在控制台中用 ASCII 码显示图片
  140. sample::gLogInfo << "Input:\n";
  141. for (int i = 0; i < inputH * inputW; i++)
  142. {
  143. sample::gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
  144. }
  145. sample::gLogInfo << std::endl;
  146. // 将输入数据存储到主机(host)的内存中
  147. float* hostInputBuffer = static_cast<float*>(buffers.getHostBuffer(inputTensorName));
  148. for (int i = 0; i < inputH * inputW; i++)
  149. {
  150. hostInputBuffer[i] = float(fileData[i]);
  151. }
  152. return true;
  153. }
  154. // 成员函数实现,验证输出结果是否正确并打印
  155. bool SampleMNIST::verifyOutput(
  156. const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const
  157. {
  158. // 从 host 的 output buffer 中读取 输出结果
  159. const float* prob = static_cast<const float*>(buffers.getHostBuffer(outputTensorName));
  160. // 打印输出分布的直方图
  161. sample::gLogInfo << "Output:\n";
  162. float val{0.0f};
  163. int idx{0};
  164. const int kDIGITS = 10;
  165. for (int i = 0; i < kDIGITS; i++)
  166. {
  167. if (val < prob[i])
  168. {
  169. val = prob[i];
  170. idx = i;
  171. }
  172. sample::gLogInfo << i << ": " << std::string(int(std::floor(prob[i] * 10 + 0.5f)), '*') << "\n";
  173. }
  174. sample::gLogInfo << std::endl;
  175. return (idx == groundTruthDigit && val > 0.9f);
  176. }
  177. // constructNetwork()成员函数的实现
  178. bool SampleMNIST::constructNetwork(
  179. SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser, SampleUniquePtr<nvinfer1::INetworkDefinition>& network)
  180. {
  181. const nvcaffeparser1::IBlobNameToTensor* blobNameToTensor = parser->parse(
  182. mParams.prototxtFileName.c_str(), mParams.weightsFileName.c_str(), *network, nvinfer1::DataType::kFLOAT);
  183. // 输出 Tensor 标记
  184. for (auto& s : mParams.outputTensorNames)
  185. {
  186. network->markOutput(*blobNameToTensor->find(s.c_str()));
  187. }
  188. // 在网络开头添加减均值的操作(针对本示例而言)
  189. nvinfer1::Dims inputDims = network->getInput(0)->getDimensions();
  190. mMeanBlob
  191. = SampleUniquePtr<nvcaffeparser1::IBinaryProtoBlob>(parser->parseBinaryProto(mParams.meanFileName.c_str()));
  192. nvinfer1::Weights meanWeights{nvinfer1::DataType::kFLOAT, mMeanBlob->getData(), inputDims.d[1] * inputDims.d[2]};
  193. float maxMean
  194. = samplesCommon::getMaxValue(static_cast<const float*>(meanWeights.values), samplesCommon::volume(inputDims));
  195. auto mean = network->addConstant(nvinfer1::Dims3(1, inputDims.d[1], inputDims.d[2]), meanWeights);
  196. if (!mean->getOutput(0)->setDynamicRange(-maxMean, maxMean))
  197. {
  198. return false;
  199. }
  200. if (!network->getInput(0)->setDynamicRange(-maxMean, maxMean))
  201. {
  202. return false;
  203. }
  204. auto meanSub = network->addElementWise(*network->getInput(0), *mean->getOutput(0), ElementWiseOperation::kSUB);
  205. if (!meanSub->getOutput(0)->setDynamicRange(-maxMean, maxMean))
  206. {
  207. return false;
  208. }
  209. network->getLayer(0)->setInput(0, *meanSub->getOutput(0));
  210. samplesCommon::setAllDynamicRanges(network.get(), 127.0f, 127.0f);
  211. return true;
  212. }
  213. // 推理过程的成员函数实现
  214. bool SampleMNIST::infer()
  215. {
  216. // 创建 buffer 管理对象
  217. samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);
  218. // 创建上下文 Context
  219. auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
  220. if (!context)
  221. {
  222. return false;
  223. }
  224. // 随机选择一个数字进行推理
  225. srand(time(NULL));
  226. const int digit = rand() % 10;
  227. // 断定数字的个数为1
  228. ASSERT(mParams.inputTensorNames.size() == 1);
  229. // 对输入数字进行前处理
  230. if (!processInput(buffers, mParams.inputTensorNames[0], digit))
  231. {
  232. return false;
  233. }
  234. // 创建 Cuda stream
  235. cudaStream_t stream;
  236. CHECK(cudaStreamCreate(&stream));
  237. // 异步地将 input_data 从 host 的 input buffer 拷贝到 device 的 input buffer (host -> device)
  238. buffers.copyInputToDeviceAsync(stream);
  239. // 异步地执行推理
  240. if (!context->enqueue(mParams.batchSize, buffers.getDeviceBindings().data(), stream, nullptr))
  241. {
  242. return false;
  243. }
  244. // 推理结果从 device 转移到 host
  245. buffers.copyOutputToHostAsync(stream);
  246. // 等待 stream 中的 work 完成
  247. cudaStreamSynchronize(stream);
  248. // 释放stream
  249. cudaStreamDestroy(stream);
  250. // 断定输出结果只有一个Tensor
  251. ASSERT(mParams.outputTensorNames.size() == 1);
  252. // 执行函数 verifyOutput() 验证结果的正确性,并打印推理结果,这个函数可以理解成后处理函数
  253. bool outputCorrect = verifyOutput(buffers, mParams.outputTensorNames[0], digit);
  254. return outputCorrect;
  255. }
  256. // 成员函数实现,清理对象的所有状态
  257. bool SampleMNIST::teardown()
  258. {
  259. nvcaffeparser1::shutdownProtobufLibrary();
  260. return true;
  261. }
  262. // 使用命令行参数初始化结构体参数成员
  263. samplesCommon::CaffeSampleParams initializeSampleParams(const samplesCommon::Args& args)
  264. {
  265. samplesCommon::CaffeSampleParams params;
  266. if (args.dataDirs.empty()) // 输入参数的路径为空时,使用默认的路径
  267. {
  268. params.dataDirs.push_back("data/mnist/");
  269. params.dataDirs.push_back("data/samples/mnist/");
  270. }
  271. else // 输入参数的路径不为空时,使用输入的路径
  272. {
  273. params.dataDirs = args.dataDirs;
  274. }
  275. // 从路径中读取相应的文件
  276. params.prototxtFileName = locateFile("mnist.prototxt", params.dataDirs);
  277. params.weightsFileName = locateFile("mnist.caffemodel", params.dataDirs);
  278. params.meanFileName = locateFile("mnist_mean.binaryproto", params.dataDirs);
  279. params.inputTensorNames.push_back("data");
  280. params.batchSize = 1;
  281. params.outputTensorNames.push_back("prob");
  282. params.dlaCore = args.useDLACore;
  283. params.int8 = args.runInInt8;
  284. params.fp16 = args.runInFp16;
  285. return params;
  286. }
  /// Prints the command-line usage text for the sample to standard output.
  void printHelpInfo()
  {
      std::cout << "Usage: ./sample_mnist [-h or --help] [-d or --datadir=<path to data directory>] [--useDLACore=<int>]\n"
                << "--help Display help information\n"
                << "--datadir Specify path to a data directory, overriding the default. This option can be used "
                   "multiple times to add multiple directories. If no data directories are given, the default is to use "
                   "(data/samples/mnist/, data/mnist/)"
                << std::endl
                << "--useDLACore=N Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, "
                   "where n is the number of DLA engines on the platform."
                << std::endl
                << "--int8 Run in Int8 mode.\n"
                << "--fp16 Run in FP16 mode.\n";
  }
  303. int main(int argc, char** argv)
  304. {
  305. // 解析参数的含义,例如--help
  306. samplesCommon::Args args;
  307. bool argsOK = samplesCommon::parseArgs(args, argc, argv);
  308. // 参数不匹配
  309. if (!argsOK)
  310. {
  311. sample::gLogError << "Invalid arguments" << std::endl;
  312. printHelpInfo();
  313. return EXIT_FAILURE;
  314. }
  315. // 输出参数帮助信息
  316. if (args.help)
  317. {
  318. printHelpInfo();
  319. return EXIT_SUCCESS;
  320. }
  321. // 定义用于 test 的logger
  322. auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
  323. // 输出: 表明 test 开始
  324. sample::gLogger.reportTestStart(sampleTest);
  325. // 执行定义的 initializeSampleParams() 函数,使用命令行参数初始化结构体参数成员
  326. samplesCommon::CaffeSampleParams params = initializeSampleParams(args);
  327. // 使用 SampleMNIST 类根据params定义一个sample对象
  328. SampleMNIST sample(params);
  329. sample::gLogInfo << "Building and running a GPU inference engine for MNIST" << std::endl; // 打印信息
  330. // 执行 sample.build() 编译推理引擎
  331. if (!sample.build())
  332. {
  333. return sample::gLogger.reportFail(sampleTest); // 编译失败,使用gLogger报告信息
  334. }
  335. // 执行 sample.infer() 进行 inference 推理
  336. if (!sample.infer())
  337. {
  338. return sample::gLogger.reportFail(sampleTest); // 推理失败,使用gLogger报告信息
  339. }
  340. // 清除 sample 的所有状态,释放内存
  341. if (!sample.teardown())
  342. {
  343. return sample::gLogger.reportFail(sampleTest);
  344. }
  345. return sample::gLogger.reportPass(sampleTest);
  346. }

4--运行结果

        执行以下命令进行编译,生成可执行文件:

  1. mkdir build
  2. cd build
  3. cmake ..
  4. make

        执行以下命令运行可执行文件,并通过 -d 参数指定模型与数据文件所在的目录:

./TensorRT_test -d /home/liujinfu/Downloads/TensorRT-8.2.5.1/data/mnist

        运行结果截图:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/57543
推荐阅读
相关标签
  

闽ICP备14008679号