赞
踩
本文提到的sampleMNISTAPI与之前0. 前言
一点疑问:TensorRT 使用 ONNX 模型应该有两种方式。一种是像本例一样,在程序中直接用 parser 解析 ONNX 模型并构建 engine;另一种是先用官方工具把 ONNX 模型转换为 engine 文件,再在程序中加载。不知道这两种方式有什么区别。
SampleOnnxMNIST::build() 函数。bool SampleOnnxMNIST::build() { // 构建模型builder auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger())); if (!builder) { return false; } // 构建空白network对象 const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch)); if (!network) { return false; } // 创建BuildConfig,我也不知道是干啥用的 auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig()); if (!config) { return false; } // 构建Onnx模型解析器 auto parser = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger())); if (!parser) { return false; } // 构建模型,通过parser解析,并将解析结果导入network中 auto constructed = constructNetwork(builder, network, config, parser); if (!constructed) { return false; } mEngine = std::shared_ptr<nvinfer1::ICudaEngine>( builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter()); if (!mEngine) { return false; } // 验证结果 assert(network->getNbInputs() == 1); mInputDims = network->getInput(0)->getDimensions(); assert(mInputDims.nbDims == 4); assert(network->getNbOutputs() == 1); mOutputDims = network->getOutput(0)->getDimensions(); assert(mOutputDims.nbDims == 2); return true; }
constrctNetwork,即通过parser解析模型并保存到network中//! //! \brief Uses a ONNX parser to create the Onnx MNIST Network and marks the //! output layers //! //! \param network Pointer to the network that will be populated with the Onnx MNIST network //! //! \param builder Pointer to the engine builder //! bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder, SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config, SampleUniquePtr<nvonnxparser::IParser>& parser) { // 注意,构建解析器的时候就已经把network对象作为参数传入了 auto parsed = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(), static_cast<int>(sample::gLogger.getReportableSeverity())); if (!parsed) { return false; } // 模型量化,不知道跟onnx_tensorrt工具有啥区别 config->setMaxWorkspaceSize(16_MiB); if (mParams.fp16) { config->setFlag(BuilderFlag::kFP16); } if (mParams.int8) { config->setFlag(BuilderFlag::kINT8); samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f); } // 这里的 DLA 就是 Deep Learning Accelerator // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dla_layers samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore); return true; }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。