当前位置:   article > 正文

OpenCV中使用vulkan 进行dnn推理。_dnn_backend_vkcom

dnn_backend_vkcom

核心代码仅3行 

  1. m_model = readNetFromDarknet(m_modelConfig, m_modelWeights);
  2. m_model.setPreferableBackend(DNN_BACKEND_VKCOM);
  3. m_model.setPreferableTarget(DNN_TARGET_VULKAN);

以下代码为核心代码,用于opencv加载darknet模型:

  1. #pragma once
  2. #ifndef __DETECTION_H__
  3. #define __DETECTION_H__
  4. #include <opencv2/opencv.hpp>
  5. #include <opencv2/dnn.hpp>
  6. #include <opencv2/imgproc.hpp>
  7. #include <opencv2/highgui.hpp>
  8. #include <string.h>
  9. #include <vector>
  10. #include <fstream>
  11. using namespace std;
  12. using namespace cv;
  13. using namespace dnn;
  14. class Detection
  15. {
  16. public:
  17. //构造、析构函数
  18. Detection();
  19. Detection( const char* szModelConfig, const char* szModelWeights, const char* szClassFile=NULL);
  20. bool InitData(const char* szModelConfig, const char* szModelWeights, const char* szClassFile);
  21. ~Detection();
  22. //初始化函数
  23. void Initialize(int width, int height);
  24. //读取网络模型
  25. void ReadModel();
  26. //行人与车辆检测
  27. bool Detecting(Mat frame);
  28. //获取网络输出层名称
  29. vector<String> GetOutputsNames();
  30. //对输出进行处理,使用NMS选出最合适的框
  31. void PostProcess();
  32. //画检测结果
  33. void Drawer();
  34. //画出检测框和相关信息
  35. void DrawBoxes(int classId, float conf, int left, int top, int right, int bottom);
  36. //获取Mat对象
  37. Mat GetFrame();
  38. //获取图像宽度
  39. int GetResWidth();
  40. //获取图像高度
  41. int GetResHeight();
  42. void SetGPU(bool bUseGPU) { m_bUseGPU = bUseGPU; }
  43. private:
  44. //图像属性
  45. int m_width; //图像宽度
  46. int m_height; //图像高度
  47. //网络处理相关
  48. Net m_model; //网络模型
  49. Mat m_frame; //每一帧
  50. Mat m_blob; //从每一帧创建一个4D的blob用于网络输入
  51. vector<Mat> m_outs; //网络输出
  52. vector<float> m_confs; //置信度
  53. vector<Rect> m_boxes; //检测框左上角坐标、宽、高
  54. vector<int> m_classIds; //类别id
  55. vector<int> m_perfIndx; //非极大阈值处理后边界框的下标
  56. //检测超参数
  57. int m_inpWidth; //网络输入图像宽度
  58. int m_inpHeight; //网络输入图像高度
  59. float m_confThro; //置信度阈值
  60. float m_NMSThro; //NMS非极大抑制阈值
  61. vector<string> m_classes; //类别名称
  62. bool m_bUseGPU = false;
  63. private:
  64. //内存释放
  65. void Dump();
  66. private:
  67. string m_classesFile = "../data/objdetect.names";
  68. String m_modelConfig = "../data/objdetect.cfg";
  69. String m_modelWeights = "../data/objdetect.pattern";
  70. bool m_bReady = false;
  71. };
  72. #endif
  1. #include "Detection.h"
  2. using namespace cv;
  3. using namespace dnn;
  4. //构造函数,成员变量初始化
  5. Detection::Detection()
  6. {
  7. //图像属性
  8. m_width = 0;
  9. m_height = 0;
  10. m_inpWidth = 416;
  11. m_inpHeight = 416;
  12. //其他成员变量
  13. m_confThro = 0.25;
  14. m_NMSThro = 0.4;
  15. 网络模型加载
  16. //ReadModel();
  17. }
  18. //析构函数
  19. Detection::~Detection()
  20. {
  21. Dump();
  22. }
  23. //内存释放
  24. void Detection::Dump()
  25. {
  26. //网络输出相关清零
  27. m_outs.clear();
  28. m_boxes.clear();
  29. m_confs.clear();
  30. m_classIds.clear();
  31. m_perfIndx.clear();
  32. }
  33. //初始化函数
  34. void Detection::Initialize(int width, int height)
  35. {
  36. //图像属性
  37. m_width = width;
  38. m_height = height;
  39. }
  40. Detection::Detection( const char* szModelConfig, const char* szModelWeights, const char* szClassFile)
  41. {
  42. m_bReady=InitData(szModelConfig, szModelWeights, szClassFile);
  43. }
  44. bool Detection::InitData(const char* szModelConfig, const char* szModelWeights, const char* szClassFile)
  45. {
  46. m_classesFile = szClassFile;
  47. m_modelConfig = szModelConfig;
  48. m_modelWeights = szModelWeights;
  49. ReadModel();
  50. return true;
  51. }
  52. //读取网络模型和类别
  53. void Detection::ReadModel()
  54. {
  55. //加载类别名
  56. if (!m_classesFile.empty())
  57. {
  58. ifstream ifs(m_classesFile.c_str());
  59. string line;
  60. while (getline(ifs, line)) m_classes.push_back(line);
  61. }
  62. //加载网络模型
  63. m_model = readNetFromDarknet(m_modelConfig, m_modelWeights);
  64. if (m_bUseGPU)
  65. {
  66. m_model.setPreferableBackend(DNN_BACKEND_VKCOM);
  67. m_model.setPreferableTarget(DNN_TARGET_VULKAN);
  68. }
  69. else
  70. {
  71. m_model.setPreferableBackend(DNN_BACKEND_OPENCV);
  72. //m_model.setPreferableTarget(DNN_TARGET_CPU);
  73. m_model.setPreferableTarget(DNN_TARGET_OPENCL); // opencl
  74. }
  75. }
  76. //行人与车辆检测
  77. bool Detection::Detecting(Mat frame)
  78. {
  79. m_frame = frame.clone();
  80. //创建4D的blob用于网络输入
  81. blobFromImage(m_frame, m_blob, 1 / 255.0,Size(m_inpWidth, m_inpHeight), Scalar(0, 0, 0), true, false);
  82. //设置网络输入
  83. m_model.setInput(m_blob);
  84. //前向预测得到网络输出,forward需要知道输出层,这里用了一个函数找到输出层
  85. m_model.forward(m_outs, GetOutputsNames());
  86. //使用非极大抑制NMS删除置信度较低的边界框
  87. PostProcess();
  88. //画检测框
  89. //Drawer();
  90. return true;
  91. }
  92. //获取网络输出层名称
  93. vector<String> Detection::GetOutputsNames()
  94. {
  95. static vector<String> names;
  96. if (names.empty())
  97. {
  98. //得到输出层索引号
  99. vector<int> outLayers = m_model.getUnconnectedOutLayers();
  100. //得到网络中所有层名称
  101. vector<String> layersNames = m_model.getLayerNames();
  102. //获取输出层名称
  103. names.resize(outLayers.size());
  104. for (int i = 0; i < outLayers.size(); ++i)
  105. names[i] = layersNames[outLayers[i] - 1];
  106. }
  107. return names;
  108. }
  109. //使用非极大抑制NMS去除置信度较低的边界框
  110. void Detection::PostProcess()
  111. {
  112. for (int num = 0; num < m_outs.size(); num++)
  113. {
  114. Point Position;
  115. double confidence;
  116. //得到每个输出的数据
  117. float* data = (float*)m_outs[num].data;
  118. for (int j = 0; j < m_outs[num].rows; j++, data += m_outs[num].cols)
  119. {
  120. //得到该输出的所有类别的
  121. Mat scores = m_outs[num].row(j).colRange(5, m_outs[num].cols);
  122. //获取最大置信度对应的值和位置
  123. minMaxLoc(scores, 0, &confidence, 0, &Position);
  124. //对置信度大于阈值的边界框进行相关计算和保存
  125. if (confidence > m_confThro)
  126. {
  127. //data[0],data[1],data[2],data[3]都是相对于原图像的比例
  128. int centerX = (int)(data[0] * m_width);
  129. int centerY = (int)(data[1] * m_height);
  130. int width = (int)(data[2] * m_width);
  131. int height = (int)(data[3] * m_height);
  132. int left = centerX - width / 2;
  133. int top = centerY - height / 2;
  134. //保存信息
  135. m_classIds.push_back(Position.x);
  136. m_confs.push_back((float)confidence);
  137. m_boxes.push_back(Rect(left, top, width, height));
  138. }
  139. }
  140. }
  141. //非极大值抑制,以消除具有较低置信度的冗余重叠框
  142. NMSBoxes(m_boxes, m_confs, m_confThro, m_NMSThro, m_perfIndx);
  143. }
  144. //画出检测结果
  145. void Detection::Drawer()
  146. {
  147. //获取所有最佳检测框信息
  148. for (int i = 0; i < m_perfIndx.size(); i++)
  149. {
  150. int idx = m_perfIndx[i];
  151. Rect box = m_boxes[idx];
  152. DrawBoxes(m_classIds[idx], m_confs[idx], box.x, box.y,
  153. box.x + box.width, box.y + box.height);
  154. }
  155. }
  156. //画出检测框和相关信息
  157. void Detection::DrawBoxes(int classId, float conf, int left, int top, int right, int bottom)
  158. {
  159. //画检测框
  160. rectangle(m_frame, Point(left, top), Point(right, bottom), Scalar(255, 178, 50), 3);
  161. //该检测框对应的类别和置信度
  162. string label = format("%.2f", conf);
  163. if (!m_classes.empty())
  164. {
  165. CV_Assert(classId < (int)m_classes.size());
  166. label = m_classes[classId] + ":" + label;
  167. }
  168. //将标签显示在检测框顶部
  169. int baseLine;
  170. Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
  171. top = max(top, labelSize.height);
  172. rectangle(m_frame, Point(left, top - round(1.5*labelSize.height)), Point(left + round(1.5*labelSize.width), top + baseLine), Scalar(255, 255, 255), FILLED);
  173. putText(m_frame, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 0, 0), 1);
  174. }
  175. //获取Mat对象
  176. Mat Detection::GetFrame()
  177. {
  178. return m_frame;
  179. }
  180. //获取结果图像宽度
  181. int Detection::GetResWidth()
  182. {
  183. return m_width;
  184. }
  185. //获取结果图像高度
  186. int Detection::GetResHeight()
  187. {
  188. return m_height;
  189. }

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/371518
推荐阅读
相关标签
  

闽ICP备14008679号