赞
踩
背景:项目需求,python框架只适合实现快速验证,但是算法真正部署项目中是不行的,需要将相关算法通过c++翻译并训练得到相应模型文件,并封装dll文件,本博客只实现训练和预测,dll文件详见参考文章。
前言:为保护客户数据,暂时使用鸢尾花数据集做测试;
数据集下载链接:
文件格式 iris_training.csv,iris_test.csv
链接:https://pan.baidu.com/s/1KzUwJwTgOYiy_tNUPZrCUQ
提取码:tojv
直接上代码
- #include <iostream>
- #include <fstream>
- #include "opencv2/core/core.hpp"
- #include "opencv2/ml/ml.hpp"
-
-
- cv::Ptr <cv::ml::RTrees> model_load;
-
-
- // 读取CSV文件并返回数据
- float** readCSV(const char* filePath, int& rows, int& cols) {
- std::ifstream file(filePath);
- // 检查文件是否成功打开
- if (!file.is_open()) {
- std::cerr << "无法打开文件\n";
- }
-
- std::string line;
- // 跳过第一行
- getline(file, line);
-
- // 统计行和列数
- rows = 0;
- cols = 0;
- while (getline(file, line)) {
- ++rows;
- std::istringstream iss(line);
- std::string value;
- while (getline(iss, value, ',')) {
- ++cols;
- }
- }
- cols /= rows;
-
- // 重新定位文件指针到文件开头
- file.clear();
- file.seekg(0, std::ios::beg);
-
- // 跳过第一行
- getline(file, line);
-
- // 分配内存
- float** data = new float* [rows];
- for (int i = 0; i < rows; ++i) {
- data[i] = new float[cols];
- }
-
- // 读取数据
- for (int i = 0; i < rows; ++i) {
- getline(file, line);
- std::istringstream iss(line);
- std::string value;
- for (int j = 0; j < cols; ++j) {
- getline(iss, value, ',');
- data[i][j] = stof(value);
- }
- }
-
-
- return data;
- }
-
-
- int train(float** data, int rows, int cols) {
-
- float* data_arr = new float[rows * cols];
-
- for (int i = 0; i < rows * cols; i++) {
- data_arr[i] = data[i / cols][i % cols];
- }
-
- cv::Mat data_mat = cv::Mat(rows, cols, CV_32FC1, data_arr);
-
- //获得标签
- cv::Mat label = data_mat.col(cols - 1).clone();
-
- //获得训练特征数据
- data_mat = data_mat.colRange(0, cols - 1);
-
- //std::cout << data_mat << "\n";
- //std::cout << label << "\n";
-
- //std::cout << data_mat.size() << "\n";
- //std::cout << label.size() << "\n";
-
- cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, label, cv::noArray(), cv::noArray(), cv::noArray(), cv::noArray());
- cv::Ptr<cv::ml::RTrees> model = cv::ml::RTrees::create();
-
- //树的最大可能深度
- //model->setMaxDepth(100);
- //节点最小样本数量
- //model->setMinSampleCount(5);
- //回归树的终止标准
- //model->setRegressionAccuracy(0.01f);
- //是否建立替代分裂点
- //model->setUseSurrogates(false);
- //最大聚类簇数
- //model->setMaxCategories(15);
- //先验类概率数组
- //model->setPriors(cv::Mat());
- //计算的变量重要性
- //model->setCalculateVarImportance(true);
- //树节点随机选择的特征子集的大小
- //model->setActiveVarCount(1);
-
- //训练模型
- model->train(train_data);
-
- //保存模型
- model->save("test_model.xml");
- printf("model saved success!\n");
-
- delete[] data_arr;
-
-
- return 0;
- }
-
-
- int init_model(const char* modelPath) {
- model_load = cv::Algorithm::load<cv::ml::RTrees>(modelPath);
- if (model_load.empty()) {
- printf("load model failed!\n");
- return -1;
- }
-
- return 0;
- }
-
-
- int predict(float** data, int rows, int cols) {
-
- float* data_arr = new float[rows * cols];
-
- for (int i = 0; i < rows * cols; i++) {
- data_arr[i] = data[i / cols][i % cols];
- }
-
- cv::Mat data_mat = cv::Mat(rows, cols, CV_32FC1, data_arr);
-
- //获得标签
- cv::Mat label = data_mat.col(cols - 1).clone();
-
- //获得训练特征数据
- data_mat = data_mat.colRange(0, cols - 1);
-
- //std::cout << data_mat << "\n";
- //std::cout << label << "\n";
-
- //std::cout << data_mat.size() << "\n";
- //std::cout << label.size() << "\n";
-
- cv::Ptr<cv::ml::TrainData> test_data = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, label, cv::noArray(), cv::noArray(), cv::noArray(), cv::noArray());
-
-
- for (int i = 0; i < rows; i++) {
- cv::Mat test_data = data_mat.row(i);
- float out = model_load->predict(test_data);
- std::cout << out << "\n";
- //res[i] = out;
- }
-
-
- return 0;
- }
-
-
- int main()
- {
- const char* trainData = "iris_training.csv";
- const char* testPath = "iris_test.csv";
-
- // 读取csv文件
- int rows, cols;
- float** data = readCSV(trainData, rows, cols);
-
- // 01 训练模型
- train(data, rows, cols);
-
- // 02 初始化
- const char* modelPath = "test_model.xml";
- init_model(modelPath);
-
- // 04 加载测试集
- float** testData = readCSV(testPath, rows, cols);
-
- // 05 预测
- predict(testData, rows, cols);
-
-
- // 释放每行的内存
- for (int i = 0; i < rows; ++i) {
- delete[] data[i];
- }
- // 释放指向每行的指针的内存
- delete[] data;
-
-
- for (int i = 0; i < rows; ++i) {
- delete[] testData[i];
- }
- delete[] testData;
-
-
- return 0;
- }

关于封装dll文件,参考
vs2022环境下,使用c#调用c++生成的dll动态链接库,实现ocr和条形码的识别_vs2022 c# c++-CSDN博客
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。