
Using onnxruntime to Load a YOLOv8-Generated ONNX File for Instance Segmentation

Over 60 images containing watermelons and winter melons were downloaded from the internet to form the melon dataset. They were annotated with the EISeg tool, and the eiseg2yolov8 script was then used to convert the .json files into the .txt label files supported by YOLOv8 and to automatically generate the YOLOv8 directory structure (a sketch of the layout and label format follows the yaml), including the melon.yaml file, whose content is as follows:

path: ../datasets/melon_seg # dataset root dir
train: images/train # train images (relative to 'path')
val: images/val # val images (relative to 'path')
test: # test images (optional)

# Classes
names:
  0: watermelon
  1: wintermelon
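The conversion also produces the standard Ultralytics dataset layout. A sketch of the expected directory tree and of one segmentation label line follows; the polygon coordinates shown are fabricated for illustration (the real values come from the EISeg annotations):

datasets/melon_seg/
├── images/
│   ├── train/
│   └── val/
└── labels/
    ├── train/
    └── val/

Each .txt label file holds one object per line in the Ultralytics segmentation format: a class index followed by the normalized polygon vertices of the mask outline, e.g.

0 0.612 0.334 0.598 0.351 0.574 0.366 0.553 0.372

where 0 is the watermelon class index and each x y pair is divided by the image width and height respectively.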

The Python implementation for training on the melon dataset is as follows (an example invocation follows the code); the finally generated model files are best.pt, best.onnx, and best.torchscript:

import argparse
import colorama
from ultralytics import YOLO

def parse_args():
    parser = argparse.ArgumentParser(description="YOLOv8 train")
    parser.add_argument("--yaml", required=True, type=str, help="yaml file")
    parser.add_argument("--epochs", required=True, type=int, help="number of training epochs")
    parser.add_argument("--task", required=True, type=str, choices=["detect", "segment"], help="specify what kind of task")

    args = parser.parse_args()
    return args

def train(task, yaml, epochs):
    if task == "detect":
        model = YOLO("yolov8n.pt") # load a pretrained model
    elif task == "segment":
        model = YOLO("yolov8n-seg.pt") # load a pretrained model
    else:
        print(colorama.Fore.RED + "Error: unsupported task:", task)
        raise ValueError(f"unsupported task: {task}") # a bare raise would fail outside an except block

    results = model.train(data=yaml, epochs=epochs, imgsz=640) # train the model

    metrics = model.val() # automatically evaluates on the dataset and settings remembered from training; no arguments needed

    model.export(format="onnx") #, dynamic=True) # export the model; dynamic=True cannot be specified, OpenCV does not support it
    # model.export(format="onnx", opset=12, simplify=True, dynamic=False, imgsz=640)
    model.export(format="torchscript") # libtorch

if __name__ == "__main__":
    colorama.init()
    args = parse_args()

    train(args.task, args.yaml, args.epochs)

    print(colorama.Fore.GREEN + "====== execution completed ======")
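Assuming the script above is saved as train.py and melon.yaml sits in the current directory (both names here are illustrative, as is the epoch count), a segmentation training run could be launched like this:

python train.py --yaml melon.yaml --epochs 100 --task segment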

The following is the C++ implementation that uses the onnxruntime API to load the ONNX file and perform instance segmentation (a minimal driver sketch follows the code):

#include <chrono>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>

namespace {

constexpr bool cuda_enabled{ false };
constexpr int input_size[2]{ 640, 640 }; // {height, width}; input shape (1, 3, 640, 640) BCHW; output shape(s): detect: (1, 6, 8400); segment: (1, 38, 8400), (1, 32, 160, 160)
constexpr float confidence_threshold{ 0.45f }; // confidence threshold
constexpr float iou_threshold{ 0.50f }; // iou threshold
constexpr float mask_threshold{ 0.50f }; // segment mask threshold

#ifdef _MSC_VER
constexpr const char* onnx_file{ "../../../data/best.onnx" };
constexpr const char* torchscript_file{ "../../../data/best.torchscript" };
constexpr const char* images_dir{ "../../../data/images/predict" };
constexpr const char* result_dir{ "../../../data/result" };
constexpr const char* classes_file{ "../../../data/images/labels.txt" };
#else
constexpr const char* onnx_file{ "data/best.onnx" };
constexpr const char* torchscript_file{ "data/best.torchscript" };
constexpr const char* images_dir{ "data/images/predict" };
constexpr const char* result_dir{ "data/result" };
constexpr const char* classes_file{ "data/images/labels.txt" };
#endif

std::vector<std::string> parse_classes_file(const char* name)
{
    std::vector<std::string> classes;
    std::ifstream file(name);
    if (!file.is_open()) {
        std::cerr << "Error: fail to open classes file: " << name << std::endl;
        return classes;
    }

    std::string line;
    while (std::getline(file, line)) {
        auto pos = line.find_first_of(" ");
        classes.emplace_back(line.substr(0, pos)); // keep only the class name before the first space
    }

    file.close();
    return classes;
}

auto get_dir_images(const char* name)
{
    std::map<std::string, std::string> images; // image name, image path + image name
    for (auto const& dir_entry : std::filesystem::directory_iterator(name)) {
        if (dir_entry.is_regular_file())
            images[dir_entry.path().filename().string()] = dir_entry.path().string();
    }

    return images;
}

#ifdef _MSC_VER
std::wstring ctow(const char* str)
{
    //std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(std::string); // std::string -> std::wstring
    constexpr size_t len{ 128 };
    wchar_t wch[len];
    swprintf(wch, len, L"%hs", str);

    return std::wstring(wch);
}
#endif

float image_preprocess(const cv::Mat& src, cv::Mat& dst)
{
    // letterbox-style preprocessing: scale the longer side to the input size,
    // then pad the remainder of the canvas with zeros
    cv::cvtColor(src, dst, cv::COLOR_BGR2RGB);

    float scalex = src.cols * 1.f / input_size[1];
    float scaley = src.rows * 1.f / input_size[0];

    if (scalex > scaley)
        cv::resize(dst, dst, cv::Size(input_size[1], static_cast<int>(src.rows / scalex)));
    else
        cv::resize(dst, dst, cv::Size(static_cast<int>(src.cols / scaley), input_size[0]));

    cv::Mat tmp = cv::Mat::zeros(input_size[0], input_size[1], CV_8UC3);
    dst.copyTo(tmp(cv::Rect(0, 0, dst.cols, dst.rows)));
    dst = tmp;

    return (scalex > scaley) ? scalex : scaley; // the factor needed to map boxes back to the original image
}

template<typename T>
void image_to_blob(const cv::Mat& src, T* blob)
{
    // HWC (uint8) -> CHW (float, normalized to [0, 1])
    for (auto c = 0; c < 3; ++c) {
        for (auto h = 0; h < src.rows; ++h) {
            for (auto w = 0; w < src.cols; ++w) {
                blob[c * src.rows * src.cols + h * src.cols + w] = (src.at<cv::Vec3b>(h, w)[c]) / 255.f;
            }
        }
    }
}

void get_masks(const cv::Mat& features, const cv::Mat& proto, const std::vector<int>& output1_sizes, const cv::Mat& frame, const cv::Rect box, cv::Mat& mk)
{
    const cv::Size shape_src(frame.cols, frame.rows), shape_input(input_size[1], input_size[0]), shape_mask(output1_sizes[3], output1_sizes[2]);

    // multiply the 1x32 mask coefficients with the 32x(160*160) prototype masks
    cv::Mat res = (features * proto).t();
    res = res.reshape(1, { shape_mask.height, shape_mask.width });

    // apply sigmoid to the mask: 1 / (1 + exp(-x))
    cv::exp(-res, res);
    res = 1.0 / (1.0 + res);

    cv::resize(res, res, shape_input);

    // undo the letterbox scaling and crop the mask to the original image size
    float scalex = shape_src.width * 1.f / shape_input.width;
    float scaley = shape_src.height * 1.f / shape_input.height;

    cv::Mat tmp;
    if (scalex > scaley)
        cv::resize(res, tmp, cv::Size(shape_src.width, static_cast<int>(shape_input.height * scalex)));
    else
        cv::resize(res, tmp, cv::Size(static_cast<int>(shape_input.width * scaley), shape_src.height));

    cv::Mat dst = tmp(cv::Rect(0, 0, shape_src.width, shape_src.height));
    mk = dst(box) > mask_threshold; // binarize the mask inside the detection box
}

void draw_boxes_mask(const std::vector<std::string>& classes, const std::vector<int>& ids, const std::vector<float>& confidences,
    const std::vector<cv::Rect>& boxes, const std::vector<cv::Mat>& masks, const std::string& name, cv::Mat& frame)
{
    std::cout << "image name: " << name << ", number of detections: " << ids.size() << std::endl;

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<int> dis(100, 255);

    cv::Mat mk = frame.clone();

    std::vector<cv::Scalar> colors;
    for (size_t i = 0; i < classes.size(); ++i)
        colors.emplace_back(cv::Scalar(dis(gen), dis(gen), dis(gen))); // one random color per class

    for (size_t i = 0; i < ids.size(); ++i) {
        cv::rectangle(frame, boxes[i], colors[ids[i]], 2);

        std::string class_string = classes[ids[i]] + ' ' + std::to_string(confidences[i]).substr(0, 4);
        cv::Size text_size = cv::getTextSize(class_string, cv::FONT_HERSHEY_DUPLEX, 1, 2, 0);
        cv::Rect text_box(boxes[i].x, boxes[i].y - 40, text_size.width + 10, text_size.height + 20);

        cv::rectangle(frame, text_box, colors[ids[i]], cv::FILLED);
        cv::putText(frame, class_string, cv::Point(boxes[i].x + 5, boxes[i].y - 10), cv::FONT_HERSHEY_DUPLEX, 1, cv::Scalar(0, 0, 0), 2, 0);

        mk(boxes[i]).setTo(colors[ids[i]], masks[i]); // paint the mask region inside the box
    }

    cv::addWeighted(frame, 0.5, mk, 0.5, 0, frame); // blend the mask layer with the original image
    //cv::imshow("Inference", frame);
    //cv::waitKey(-1);

    std::string path(result_dir);
    cv::imwrite(path + "/" + name, frame);
}

void post_process_mask(const cv::Mat& output0, const cv::Mat& output1, const std::vector<int>& output1_sizes, const std::vector<std::string>& classes, const std::string& name, cv::Mat& frame)
{
    std::vector<int> class_ids;
    std::vector<float> confidences;
    std::vector<cv::Rect> boxes;
    std::vector<std::vector<float>> masks;

    float scalex = frame.cols * 1.f / input_size[1]; // note: must match the image_preprocess function
    float scaley = frame.rows * 1.f / input_size[0];
    auto scale = (scalex > scaley) ? scalex : scaley;

    // each row of output0 is: 4 box values (cx, cy, w, h) + per-class scores + 32 mask coefficients
    const float* data = (float*)output0.data;
    for (auto i = 0; i < output0.rows; ++i) {
        cv::Mat scores(1, static_cast<int>(classes.size()), CV_32FC1, (float*)data + 4);
        cv::Point class_id;
        double max_class_score;
        cv::minMaxLoc(scores, 0, &max_class_score, 0, &class_id);

        if (max_class_score > confidence_threshold) {
            confidences.emplace_back(max_class_score);
            class_ids.emplace_back(class_id.x);
            masks.emplace_back(std::vector<float>(data + 4 + classes.size(), data + output0.cols)); // 32 mask coefficients

            float x = data[0];
            float y = data[1];
            float w = data[2];
            float h = data[3];

            // convert the center-based box back to original-image coordinates, clamped to the frame
            int left = std::max(0, std::min(int((x - 0.5 * w) * scale), frame.cols));
            int top = std::max(0, std::min(int((y - 0.5 * h) * scale), frame.rows));
            int width = std::max(0, std::min(int(w * scale), frame.cols - left));
            int height = std::max(0, std::min(int(h * scale), frame.rows - top));
            boxes.emplace_back(cv::Rect(left, top, width, height));
        }

        data += output0.cols;
    }

    std::vector<int> nms_result;
    cv::dnn::NMSBoxes(boxes, confidences, confidence_threshold, iou_threshold, nms_result);

    cv::Mat proto = output1.reshape(0, { output1_sizes[1], output1_sizes[2] * output1_sizes[3] }); // 32 x (160*160)

    std::vector<int> ids;
    std::vector<float> confs;
    std::vector<cv::Rect> rects;
    std::vector<cv::Mat> mks;
    for (size_t i = 0; i < nms_result.size(); ++i) {
        auto index = nms_result[i];
        ids.emplace_back(class_ids[index]);
        confs.emplace_back(confidences[index]);
        boxes[index] = boxes[index] & cv::Rect(0, 0, frame.cols, frame.rows);

        cv::Mat mk;
        get_masks(cv::Mat(masks[index]).t(), proto, output1_sizes, frame, boxes[index], mk);
        mks.emplace_back(mk);
        rects.emplace_back(boxes[index]);
    }

    draw_boxes_mask(classes, ids, confs, rects, mks, name, frame);
}

} // namespace

int test_yolov8_segment_onnxruntime()
{
    try {
        Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "Yolo");
        Ort::SessionOptions session_option;

        if (cuda_enabled) {
            OrtCUDAProviderOptions cuda_option;
            cuda_option.device_id = 0;
            session_option.AppendExecutionProvider_CUDA(cuda_option);
        }

        session_option.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
        session_option.SetIntraOpNumThreads(1);
        session_option.SetLogSeverityLevel(3);

#ifdef _MSC_VER
        Ort::Session session(env, ctow(onnx_file).c_str(), session_option); // Windows: the model path is const wchar_t*
#else
        Ort::Session session(env, onnx_file, session_option); // Linux: the model path is const char*
#endif

        Ort::AllocatorWithDefaultOptions allocator;
        std::vector<const char*> input_node_names, output_node_names;
        std::vector<std::string> input_node_names_, output_node_names_;

        for (size_t i = 0; i < session.GetInputCount(); ++i) {
            Ort::AllocatedStringPtr input_node_name = session.GetInputNameAllocated(i, allocator);
            input_node_names_.emplace_back(input_node_name.get());
        }

        for (size_t i = 0; i < session.GetOutputCount(); ++i) {
            Ort::AllocatedStringPtr output_node_name = session.GetOutputNameAllocated(i, allocator);
            output_node_names_.emplace_back(output_node_name.get());
        }

        // the AllocatedStringPtr instances above own the memory, so keep std::string copies
        for (size_t i = 0; i < input_node_names_.size(); ++i)
            input_node_names.emplace_back(input_node_names_[i].c_str());
        for (size_t i = 0; i < output_node_names_.size(); ++i)
            output_node_names.emplace_back(output_node_names_[i].c_str());

        std::unique_ptr<float[]> blob(new float[input_size[0] * input_size[1] * 3]);
        std::vector<int64_t> input_node_dims{ 1, 3, input_size[0], input_size[1] }; // BCHW

        auto classes = parse_classes_file(classes_file);
        if (classes.size() == 0) {
            std::cerr << "Error: fail to parse classes file: " << classes_file << std::endl;
            return -1;
        }

        if (!std::filesystem::exists(result_dir)) {
            std::filesystem::create_directories(result_dir);
        }

        for (const auto& [key, val] : get_dir_images(images_dir)) {
            cv::Mat frame = cv::imread(val, cv::IMREAD_COLOR);
            if (frame.empty()) {
                std::cerr << "Warning: unable to load image: " << val << std::endl;
                continue;
            }

            auto tstart = std::chrono::high_resolution_clock::now();

            cv::Mat rgb;
            image_preprocess(frame, rgb);
            image_to_blob(rgb, blob.get());

            Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
                Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob.get(), 3 * input_size[1] * input_size[0], input_node_dims.data(), input_node_dims.size());
            auto output_tensors = session.Run(Ort::RunOptions{ nullptr }, input_node_names.data(), &input_tensor, input_node_names.size(), output_node_names.data(), output_node_names.size());

            if (output_tensors.size() != 2) {
                std::cerr << "Error: output must have 2 layers: " << output_tensors.size() << std::endl;
                return -1;
            }

            // output0: (1, 38, 8400) -> transposed to 8400 rows x 38 columns (4 box + 2 classes + 32 mask coefficients)
            std::vector<int64_t> output0_node_dims = output_tensors[0].GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape();
            auto output0 = output_tensors[0].GetTensorMutableData<float>();
            cv::Mat data0 = cv::Mat(static_cast<int>(output0_node_dims[1]), static_cast<int>(output0_node_dims[2]), CV_32F, output0);
            data0 = data0.t();

            // output1: (1, 32, 160, 160) prototype masks
            std::vector<int64_t> output1_node_dims = output_tensors[1].GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape();
            auto output1 = output_tensors[1].GetTensorMutableData<float>();
            std::vector<int> sizes;
            for (auto val : output1_node_dims)
                sizes.emplace_back(static_cast<int>(val));
            cv::Mat data1 = cv::Mat(sizes, CV_32F, output1);

            auto tend = std::chrono::high_resolution_clock::now();
            std::cout << "elapsed milliseconds: " << std::chrono::duration_cast<std::chrono::milliseconds>(tend - tstart).count() << " ms" << std::endl;

            post_process_mask(data0, data1, sizes, classes, key, frame);
        }
    }
    catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return -1;
    }

    return 0;
}
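The function above is an excerpt from a larger test file; a minimal driver to exercise it might look like the sketch below (the actual entry point in the repository may differ). The program must be linked against both OpenCV (core, imgproc, imgcodecs, dnn) and onnxruntime 1.18.0.

// minimal driver sketch, assuming it is linked together with the file above
int test_yolov8_segment_onnxruntime(); // defined above

int main()
{
    return test_yolov8_segment_onnxruntime();
}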

The content of the labels.txt file is as follows (only 2 classes). Note that parse_classes_file above keeps only the class name before the first space on each line, so the trailing indices are ignored:

watermelon 0
wintermelon 1

Notes:

1. The onnxruntime version used here is 1.18.0.

2. On Windows, onnxruntime ships a single set of libraries for both debug and release, and the program runs under both configurations.

3. The cuda_enabled variable selects between the CPU and GPU execution paths.

4. On Windows some onnxruntime interfaces take wchar_t* parameters, while on Linux they take char*; a separate conversion is therefore needed on Windows, implemented here by the ctow function, which converts from char* to wchar_t*. An alternative sketch follows.
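As an alternative to a hand-rolled conversion such as ctow, std::filesystem::path can serve as a portable bridge, because its native character type is wchar_t on Windows and char elsewhere, matching the platform-specific ORTCHAR_T* parameter of the Ort::Session constructor. A sketch, not from the original post:

#include <filesystem>
#include <onnxruntime_cxx_api.h>

// path::c_str() yields const wchar_t* on Windows and const char* on POSIX,
// which matches Ort::Session's platform-specific model-path parameter
Ort::Session make_session(Ort::Env& env, const char* model_file, const Ort::SessionOptions& options)
{
    const std::filesystem::path model_path{ model_file };
    return Ort::Session(env, model_path.c_str(), options);
}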

The execution results are shown in the figure below. On the same set of prediction images, the results are similar to those produced with OpenCV's DNN module, as both share the same post-processing pipeline. The elapsed times shown were measured on the CPU; on the GPU each image takes only about 20 milliseconds.

The segmentation result for one of the images is shown in the figure below:

GitHub: https://github.com/fengbingchun/NN_Test
