当前位置:   article > 正文

【机器学习】02. vs平台c++随机森林实现回归训练和预测并保存xml模型

【机器学习】02. vs平台c++随机森林实现回归训练和预测并保存xml模型

背景:项目需求,python框架只适合实现快速验证,但是算法真正部署项目中是不行的,需要将相关算法通过c++翻译并训练得到相应模型文件,并封装dll文件,本博客只实现训练和预测,dll文件详见参考文章。

前言:为保护客户数据,暂时使用鸢尾花数据集做测试;

数据集下载链接:

文件格式 iris_training.csv,iris_test.csv

链接:https://pan.baidu.com/s/1KzUwJwTgOYiy_tNUPZrCUQ 
提取码:tojv

 直接上代码

  1. #include <iostream>
  2. #include <fstream>
  3. #include "opencv2/core/core.hpp"
  4. #include "opencv2/ml/ml.hpp"
  5. cv::Ptr <cv::ml::RTrees> model_load;
  6. // 读取CSV文件并返回数据
  7. float** readCSV(const char* filePath, int& rows, int& cols) {
  8. std::ifstream file(filePath);
  9. // 检查文件是否成功打开
  10. if (!file.is_open()) {
  11. std::cerr << "无法打开文件\n";
  12. }
  13. std::string line;
  14. // 跳过第一行
  15. getline(file, line);
  16. // 统计行和列数
  17. rows = 0;
  18. cols = 0;
  19. while (getline(file, line)) {
  20. ++rows;
  21. std::istringstream iss(line);
  22. std::string value;
  23. while (getline(iss, value, ',')) {
  24. ++cols;
  25. }
  26. }
  27. cols /= rows;
  28. // 重新定位文件指针到文件开头
  29. file.clear();
  30. file.seekg(0, std::ios::beg);
  31. // 跳过第一行
  32. getline(file, line);
  33. // 分配内存
  34. float** data = new float* [rows];
  35. for (int i = 0; i < rows; ++i) {
  36. data[i] = new float[cols];
  37. }
  38. // 读取数据
  39. for (int i = 0; i < rows; ++i) {
  40. getline(file, line);
  41. std::istringstream iss(line);
  42. std::string value;
  43. for (int j = 0; j < cols; ++j) {
  44. getline(iss, value, ',');
  45. data[i][j] = stof(value);
  46. }
  47. }
  48. return data;
  49. }
  50. int train(float** data, int rows, int cols) {
  51. float* data_arr = new float[rows * cols];
  52. for (int i = 0; i < rows * cols; i++) {
  53. data_arr[i] = data[i / cols][i % cols];
  54. }
  55. cv::Mat data_mat = cv::Mat(rows, cols, CV_32FC1, data_arr);
  56. //获得标签
  57. cv::Mat label = data_mat.col(cols - 1).clone();
  58. //获得训练特征数据
  59. data_mat = data_mat.colRange(0, cols - 1);
  60. //std::cout << data_mat << "\n";
  61. //std::cout << label << "\n";
  62. //std::cout << data_mat.size() << "\n";
  63. //std::cout << label.size() << "\n";
  64. cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, label, cv::noArray(), cv::noArray(), cv::noArray(), cv::noArray());
  65. cv::Ptr<cv::ml::RTrees> model = cv::ml::RTrees::create();
  66. //树的最大可能深度
  67. //model->setMaxDepth(100);
  68. //节点最小样本数量
  69. //model->setMinSampleCount(5);
  70. //回归树的终止标准
  71. //model->setRegressionAccuracy(0.01f);
  72. //是否建立替代分裂点
  73. //model->setUseSurrogates(false);
  74. //最大聚类簇数
  75. //model->setMaxCategories(15);
  76. //先验类概率数组
  77. //model->setPriors(cv::Mat());
  78. //计算的变量重要性
  79. //model->setCalculateVarImportance(true);
  80. //树节点随机选择的特征子集的大小
  81. //model->setActiveVarCount(1);
  82. //训练模型
  83. model->train(train_data);
  84. //保存模型
  85. model->save("test_model.xml");
  86. printf("model saved success!\n");
  87. delete[] data_arr;
  88. return 0;
  89. }
  90. int init_model(const char* modelPath) {
  91. model_load = cv::Algorithm::load<cv::ml::RTrees>(modelPath);
  92. if (model_load.empty()) {
  93. printf("load model failed!\n");
  94. return -1;
  95. }
  96. return 0;
  97. }
  98. int predict(float** data, int rows, int cols) {
  99. float* data_arr = new float[rows * cols];
  100. for (int i = 0; i < rows * cols; i++) {
  101. data_arr[i] = data[i / cols][i % cols];
  102. }
  103. cv::Mat data_mat = cv::Mat(rows, cols, CV_32FC1, data_arr);
  104. //获得标签
  105. cv::Mat label = data_mat.col(cols - 1).clone();
  106. //获得训练特征数据
  107. data_mat = data_mat.colRange(0, cols - 1);
  108. //std::cout << data_mat << "\n";
  109. //std::cout << label << "\n";
  110. //std::cout << data_mat.size() << "\n";
  111. //std::cout << label.size() << "\n";
  112. cv::Ptr<cv::ml::TrainData> test_data = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, label, cv::noArray(), cv::noArray(), cv::noArray(), cv::noArray());
  113. for (int i = 0; i < rows; i++) {
  114. cv::Mat test_data = data_mat.row(i);
  115. float out = model_load->predict(test_data);
  116. std::cout << out << "\n";
  117. //res[i] = out;
  118. }
  119. return 0;
  120. }
  121. int main()
  122. {
  123. const char* trainData = "iris_training.csv";
  124. const char* testPath = "iris_test.csv";
  125. // 读取csv文件
  126. int rows, cols;
  127. float** data = readCSV(trainData, rows, cols);
  128. // 01 训练模型
  129. train(data, rows, cols);
  130. // 02 初始化
  131. const char* modelPath = "test_model.xml";
  132. init_model(modelPath);
  133. // 04 加载测试集
  134. float** testData = readCSV(testPath, rows, cols);
  135. // 05 预测
  136. predict(testData, rows, cols);
  137. // 释放每行的内存
  138. for (int i = 0; i < rows; ++i) {
  139. delete[] data[i];
  140. }
  141. // 释放指向每行的指针的内存
  142. delete[] data;
  143. for (int i = 0; i < rows; ++i) {
  144. delete[] testData[i];
  145. }
  146. delete[] testData;
  147. return 0;
  148. }

关于封装dll文件,参考

vs2022环境下,使用c#调用c++生成的dll动态链接库,实现ocr和条形码的识别_vs2022 c# c++-CSDN博客

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/article/detail/52895
推荐阅读
相关标签
  

闽ICP备14008679号