当前位置:   article > 正文

opencv C++ SVM模型训练与分类实现

opencv c++ svm

最近想学习一下分类算法的内容,恰好opencv有SVM的函数,故先从这个下手。找了许多资料,发现要么是opencv2、3的,要么就没有具体实现代码,学习还是把代码与原理一起结合来看比较好。

其中,我主要参考的是这一篇文章:

学习SVM(一) SVM模型训练与分类的OpenCV实现icon-default.png?t=M4ADhttps://blog.csdn.net/chaipp0607/article/details/68067098写得非常好!但是是2017年发布的文章,其中许多内容都做了更新,我用的是opencv 4.5.1版本,win10系统,vs2019作开发工具。具体opencv配置不说了,我对上面那篇文章的代码进行了更新。

步骤一样.

一、数据准备

首先找到opencv库自带的digits图片,我的电脑上路径在:D:\app\opencv4.5.1\opencv\opencv\sources\samples\data\digits.png

然后在D盘建立如下文件夹:

只需新建命名就好了,不用往里面放东西。接下来建立vs2019项目工程,新建源文件

复制如下代码:

  1. #include <windows.h>
  2. #include <iostream>
  3. #include <opencv2/core/core.hpp>
  4. #include <opencv2/highgui/highgui.hpp>
  5. #include <opencv2/imgproc/imgproc.hpp>
  6. #include <opencv2/core/utils/logger.hpp>
  7. #include <thread>
  8. #include <time.h>
  9. //#include <stdio.h>
  10. #include <string.h>
  11. using namespace std;
  12. using namespace cv;
  13. int main()
  14. {
  15. char ad[128] = { 0 };
  16. int filename = 0, filenum = 0;
  17. Mat img = imread("digits.png");
  18. Mat gray;
  19. cvtColor(img, gray, COLOR_BGR2GRAY);
  20. int b = 20;
  21. int m = gray.rows / b; //原图为1000*2000
  22. int n = gray.cols / b; //裁剪为5000个20*20的小图块
  23. for (int i = 0; i < m; i++)
  24. {
  25. int offsetRow = i * b; //行上的偏移量
  26. if (i % 5 == 0 && i != 0)
  27. {
  28. filename++;
  29. filenum = 0;
  30. }
  31. for (int j = 0; j < n; j++)
  32. {
  33. int offsetCol = j * b; //列上的偏移量
  34. sprintf_s(ad, "D:\\data\\%d\\%d.jpg", filename, filenum++);
  35. //截取20*20的小块
  36. Mat tmp;
  37. gray(Range(offsetRow, offsetRow + b), Range(offsetCol, offsetCol + b)).copyTo(tmp);
  38. imwrite(ad, tmp);
  39. }
  40. }
  41. return 0;
  42. }

 运行结束后,在刚刚新建的文件夹中,0、1文件夹内各有500张分割好的图片。

最后在test_image、train_image分别新建0、1文件夹。

把data\0中的0-399复制到data\test_image\0,399-499复制到data\train_image\0;

把data\1中的0-399复制到data\test_image\1,399-499复制到data\train_image\1。第一步完成。

  1. --D:
  2. --data
  3. --0
  4. --1
  5. --train_image
  6. --0400张)
  7. --1400张)
  8. --test_image
  9. --0100张)
  10. --1100张)

 二、模型训练

 再新建一个源文件:SVM模型训练.cpp,将第一步的SVM数据准备文件从项目中移除。

复制上如下代码,其中最主要的就是opencv4中的SVM函数改变很大,配置参数上与原文完全不同

  1. #include <stdio.h>
  2. #include <time.h>
  3. #include <opencv2/opencv.hpp>
  4. #include <iostream>
  5. #include <opencv2/core/core.hpp>
  6. #include <opencv2/highgui/highgui.hpp>
  7. #include <opencv2/imgproc/imgproc.hpp>
  8. #include "opencv2/imgcodecs.hpp"
  9. #include <opencv2/core/utils/logger.hpp>
  10. #include <opencv2/ml/ml.hpp>
  11. #include <io.h>
  12. using namespace std;
  13. using namespace cv;
  14. using namespace cv::ml;
  15. void getFiles(string path, vector<string>& files);
  16. void get_1(Mat& trainingImages, vector<int>& trainingLabels);
  17. void get_0(Mat& trainingImages, vector<int>& trainingLabels);
  18. int main()
  19. {
  20. //获取训练数据
  21. Mat classes;
  22. Mat trainingData;
  23. Mat trainingImages;
  24. vector<int> trainingLabels;
  25. get_1(trainingImages, trainingLabels);
  26. //waitKey(2000);
  27. get_0(trainingImages, trainingLabels);
  28. Mat(trainingImages).copyTo(trainingData);
  29. trainingData.convertTo(trainingData, CV_32FC1);
  30. Mat(trainingLabels).copyTo(classes);
  31. //配置SVM训练器参数
  32. Ptr<SVM> svm = SVM::create();
  33. svm->setType(SVM::C_SVC);
  34. svm->setKernel(SVM::LINEAR);
  35. svm->setDegree(0);
  36. svm->setGamma(1);
  37. svm->setCoef0(0);
  38. svm->setC(1);
  39. svm->setNu(0);
  40. svm->setP(0);
  41. svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 1000, 0.01));
  42. //训练
  43. svm->train(trainingData, ROW_SAMPLE, classes );
  44. //保存模型
  45. svm->save("svm.xml");
  46. cout << "训练好了!!!" << endl;
  47. getchar();
  48. return 0;
  49. }
  50. void getFiles(string path, vector<string>& files)
  51. {
  52. long long hFile = 0;
  53. struct _finddata_t fileinfo;
  54. string p;
  55. if ((hFile = _findfirst(p.assign(path).append("\\*").c_str(), &fileinfo)) != -1)
  56. {
  57. do
  58. {
  59. if ((fileinfo.attrib & _A_SUBDIR))
  60. {
  61. if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0)
  62. getFiles(p.assign(path).append("\\").append(fileinfo.name), files);
  63. }
  64. else
  65. {
  66. files.push_back(p.assign(path).append("\\").append(fileinfo.name));
  67. }
  68. } while (_findnext(hFile, &fileinfo) == 0);
  69. _findclose(hFile);
  70. }
  71. }
  72. void get_1(Mat& trainingImages, vector<int>& trainingLabels)
  73. {
  74. string filePath = "D:\\data\\train_image\\1";
  75. cout << "获取D:\\data\\1" << endl;
  76. vector<string> files;
  77. getFiles(filePath, files);
  78. int number = files.size();
  79. for (int i = 0; i < number; i++)
  80. {
  81. Mat SrcImage = imread(files[i].c_str());
  82. SrcImage = SrcImage.reshape(1, 1);
  83. trainingImages.push_back(SrcImage);
  84. trainingLabels.push_back(1);
  85. }
  86. }
  87. void get_0(Mat& trainingImages, vector<int>& trainingLabels)
  88. {
  89. string filePath = "D:\\data\\train_image\\0";
  90. cout << "获取D:\\data\\0" << endl;
  91. vector<string> files;
  92. getFiles(filePath, files);
  93. int number = files.size();
  94. for (int i = 0; i < number; i++)
  95. {
  96. Mat SrcImage = imread(files[i].c_str());
  97. SrcImage = SrcImage.reshape(1, 1);
  98. trainingImages.push_back(SrcImage);
  99. trainingLabels.push_back(0);
  100. }
  101. }

 训练完毕后,在这个解决方案文件夹下就生成了一个.xml文件,即是我们训练出来的模型。

训练时还可以选择自动训练,会自己寻找最优参数,效果也很好。

  1. //训练
  2. svm->trainAuto(trainingData, ROW_SAMPLE, classes );

三、加载模型实现分类

同样的,新建一个源文件:

复制如下代码:

  1. #include <stdio.h>
  2. #include <time.h>
  3. #include <opencv2/opencv.hpp>
  4. #include <iostream>
  5. #include <opencv2/core/core.hpp>
  6. #include <opencv2/highgui/highgui.hpp>
  7. #include <opencv2/imgproc/imgproc.hpp>
  8. #include "opencv2/imgcodecs.hpp"
  9. #include <opencv2/core/utils/logger.hpp>
  10. #include <opencv2/ml/ml.hpp>
  11. #include <io.h>
  12. using namespace std;
  13. using namespace cv;
  14. using namespace cv::ml;
  15. void getFiles(string path, vector<string>& files);
  16. int main()
  17. {
  18. int result = 0;
  19. string filePath = "D:\\data\\test_image\\1";
  20. vector<string> files;
  21. getFiles(filePath, files);
  22. int number = files.size();
  23. cout << number << endl;
  24. string modelpath = "svm.xml";
  25. cv::Ptr<cv::ml::SVM> svm;
  26. svm = cv::Algorithm::load<cv::ml::SVM>(modelpath);
  27. /*CvSVM svm;
  28. svm.clear();
  29. string modelpath = "svm.xml";
  30. FileStorage svm_fs(modelpath, FileStorage::READ);
  31. if (svm_fs.isOpened())
  32. {
  33. svm.load(modelpath.c_str());
  34. }*/
  35. for (int i = 0; i < number; i++)
  36. {
  37. Mat inMat = imread(files[i].c_str());
  38. Mat p = inMat.reshape(1, 1);
  39. p.convertTo(p, CV_32FC1);
  40. int response = (int)svm->predict(p);
  41. if (response == 1)//要预测1,如果用0来做测试集就改成response == 0
  42. {
  43. result++;
  44. }
  45. else
  46. {
  47. cout << files[i] << endl;
  48. }
  49. }
  50. cout << result << endl;
  51. getchar();
  52. return 0;
  53. }
  54. void getFiles(string path, vector<string>& files)
  55. {
  56. long long hFile = 0;
  57. struct _finddata_t fileinfo;
  58. string p;
  59. if ((hFile = _findfirst(p.assign(path).append("\\*").c_str(), &fileinfo)) != -1)
  60. {
  61. do
  62. {
  63. if ((fileinfo.attrib & _A_SUBDIR))
  64. {
  65. if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0)
  66. getFiles(p.assign(path).append("\\").append(fileinfo.name), files);
  67. }
  68. else
  69. {
  70. files.push_back(p.assign(path).append("\\").append(fileinfo.name));
  71. }
  72. } while (_findnext(hFile, &fileinfo) == 0);
  73. _findclose(hFile);
  74. }
  75. }

 如果想要检测0的分类准确率就让第46行的response == 0。

 可以看到,100张1有99张被识别出来,有一张452没有识别成功。100张0都识别出来了。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/263913
推荐阅读
相关标签
  

闽ICP备14008679号