当前位置:   article > 正文

C/C++开发,opencv-ml库学习,K近邻(KNN)应用_opencv knearest c++

opencv knearest c++

目录

一、k近邻算法

1.1 算法简介

1.2 opencv-k近邻算法

二、cv::ml::KNearest应用

2.1 数据集样本准备

2.2 KNearest应用

2.3 程序编译

2.4 main.cpp全代码


一、k近邻算法

1.1 算法简介

        K近邻算法(K-Nearest Neighbor,KNN)基本原理是:

        在特征空间中,如果一个样本附近的K个最近(即特征空间中最邻近)样本的大多数属于某一个类别,则该样本也属于这个类别。具体来说,给定一个训练数据集,对于新的输入实例,KNN算法会在训练数据集中找到与该实例最邻近的K个实例(即K个邻居)。然后,根据这K个邻居的类别进行投票,将票数最多的类别作为新输入实例的预测类别。

        KNN算法的优点包括:

  1. 简单易理解:KNN算法非常直观和简单,易于理解和实现。
  2. 适用于多分类问题:KNN算法可以很容易地应用于多分类问题。
  3. 适用于非线性数据:KNN算法对于非线性数据具有良好的适应性。
  4. 无需训练:KNN算法属于懒惰学习(lazy learning),不需要训练过程,节省了模型训练时间。

        KNN算法也存在一些缺点:

  1. 需要大量内存:KNN算法需要存储所有训练数据,因此在处理大规模数据集时需要大量内存。
  2. 计算复杂度高:当训练集很大时,KNN算法需要计算测试样本与所有训练样本之间的距离,计算复杂度较高。
  3. 预测时间长:由于需要计算测试样本与所有训练样本的距离,KNN算法的预测时间较长。
  4. 敏感度高:KNN算法对于异常值和噪声数据非常敏感,容易受到局部特征的影响。

        KNN算法的应用实例包括但不限于手写数字识别、电影推荐系统、人脸识别和疾病诊断等。在这些应用中,KNN算法可以通过计算测试样本与训练样本之间的距离,找到最相似的邻居,并根据邻居的类别进行预测或分类。常用的距离度量方法包括欧式距离、曼哈顿距离、闵可夫斯基距离等。

1.2 opencv-k近邻算法

        OpenCV中的K近邻算法(K-Nearest Neighbors, KNN)是一种常用的监督学习算法,主要用于分类和回归问题。该算法基于样本之间的距离来进行分类或回归。

        1)对于分类问题,KNN算法将未知样本与训练集中的样本逐个比较距离,并选择距离最近的K个邻居样本。然后,根据这K个邻居样本的标签进行投票,将未知样本归类为票数最多的标签。

        2)对于回归问题,KNN算法同样将未知样本与训练集中的样本逐个比较距离,并选择距离最近的K个邻居样本。但是,在回归问题中,KNN算法会取这K个邻居样本的平均值作为未知样本的预测值。

        OpenCV中的KNN算法实现包含在ml模块中,函数为cv.ml.KNearest_create()。通过这个函数,你可以创建一个KNN分类器或回归器,并使用训练数据对其进行训练。训练完成后,你可以使用训练好的模型对新的数据进行预测。

  1. k近邻算法是使用的是KNearest类 继承了StatModel类(base类)
  2. class CV_EXPORTS_W KNearest : public StatModel
  3. {
  4. public:
  5. //类代码
  6. };
  7. StatModel类 方法:
  8. 训练函数
  9. ret = cv.ml_StatModel.train(samples,layout,responses)
  10. samples: 训练的样本矩阵
  11. layout: 排列方式
  12. responses: 标签矩阵
  13. 返还一个bool类型变量来作为是否完成了模型训练
  14. samples 必须为float32类型
  15. layout cv.ml.ROW_SAMPLE 样本按行排列
  16. cv.ml.COL_SAMPLE 按列排列
  17. responses: 单行或者单列的矩阵 类型为int 或者 float
  18. 检测函数
  19. retval, res = cv.ml_StatModel.predict(samples)
  20. flags模型标志
  21. res 结果矩阵
  22. samples 输入矩阵
  23. retval: 第一个值得标签
  24. 除了使用StatModel提供的通用的预测方法
  25. KNearest类也提供了预测方法
  26. retval,results,neighborResponses,dist = cv.ml_KNearest.findNearest(
  27. sample
  28. k
  29. )
  30. sample: 待预测数据
  31. k: 近邻数
  32. results: 预测结果
  33. neighborResponses: 可以选择输出的每个数据的k个最近邻
  34. dist: 输出k个最近邻的距离

        KNN算法的优点包括在线技术(新数据可以直接加入数据集而不必进行重新训练)、理论简单、容易实现、准确性高和对异常值和噪声有较高的容忍度。然而,KNN算法也存在一些缺点,如对于样本容量大的数据集计算量比较大、容易导致维度灾难、样本不平衡时预测偏差比较大,以及k值大小的选择需要依靠经验或交叉验证等。

        在OpenCV中,KNN算法可以用于各种图像处理任务,如图像分类、目标检测和模式识别等。通过调整k值和使用不同的距离度量方法,你可以优化KNN算法的性能以适应你的具体任务。

二、cv::ml::KNearest应用

2.1 数据集样本准备

   本文为了快速验证使用,采用mnist数据集,参考本专栏博文《C/C++开发,opencv-ml库学习,支持向量机(SVM)应用-CSDN博客》下载MNIST 数据集(手写数字识别),并解压。

        同时参考该博文“2.4 SVM(支持向量机)实时识别应用”的章节资料,利用python代码解压t10k-images.idx3-ubyte出图片数据文件。

2.2 KNearest应用

         类似ml模块的其他算法一样,创建了一个 cv::ml::KNearest对象,并设置了训练数据和终止条件。接着,我们调用 train 方法来训练决策树模型。最后,我们使用训练好的模型来预测一个新样本的类别。

  1. // 4. 设置并训练KNN模型
  2. // 创建KNN模型
  3. cv::Ptr<cv::ml::KNearest> knn = cv::ml::KNearest::create();
  4. // 设置KNN参数
  5. knn->setAlgorithmType(cv::ml::KNearest::BRUTE_FORCE); // 使用暴力搜索
  6. knn->setIsClassifier(true); // 设置为分类器
  7. knn->setDefaultK(3); // 设置K值
  8. // 训练KNN模型
  9. knn->train(trainingData, cv::ml::ROW_SAMPLE, labelsMat);
  10. //同样预测函数调用
  11. cv::Mat testResp;
  12. float response = knn->predict(testData,testResp);
  13. //存储模型,文件名借用了博文的命名,不必在意
  14. knn->save("mnist_svm.xml");

        训练及测试过的算法模型,保存输出(.xml),然后调用。PS,训练图片解压读取请参见C/C++开发,opencv-ml库学习,支持向量机(SVM)应用-CSDN博客的“2.4 SVM(支持向量机)实时识别应用”章节。

  1. cv::Ptr<cv::ml::KNearest> kkn = cv::ml::StatModel::load<cv::ml::KNearest>("mnist_svm.xml");
  2. //read img 28*28 size
  3. cv::Mat image = cv::imread(fileName, cv::IMREAD_GRAYSCALE);
  4. //uchar->float32
  5. image.convertTo(image, CV_32F);
  6. //image data normalization
  7. image = image / 255.0;
  8. //28*28 -> 1*784
  9. image = image.reshape(1, 1);
  10. //预测图片
  11. float ret = knn->predict(image);
  12. std::cout << "predict val = "<< ret << std::endl;
2.3 程序编译

        和讲述支持向量机(SVM)应用的博文编译类似,采用opencv+mingw+makefile方式编译:

  1. #/bin/sh
  2. #win32
  3. CX= g++ -DWIN32
  4. #linux
  5. #CX= g++ -Dlinux
  6. BIN := ./
  7. TARGET := opencv_ml04.exe
  8. FLAGS := -std=c++11 -static
  9. SRCDIR := ./
  10. #INCLUDES
  11. INCLUDEDIR := -I"../../opencv_MinGW/include" -I"./"
  12. #-I"$(SRCDIR)"
  13. staticDir := ../../opencv_MinGW/x64/mingw/staticlib/
  14. #LIBDIR := $(staticDir)/libopencv_world460.a\
  15. # $(staticDir)/libade.a \
  16. # $(staticDir)/libIlmImf.a \
  17. # $(staticDir)/libquirc.a \
  18. # $(staticDir)/libzlib.a \
  19. # $(wildcard $(staticDir)/liblib*.a) \
  20. # -lgdi32 -lComDlg32 -lOleAut32 -lOle32 -luuid
  21. #opencv_world放弃前,然后是opencv依赖的第三方库,后面的库是MinGW编译工具的库
  22. LIBDIR := -L $(staticDir) -lopencv_world460 -lade -lIlmImf -lquirc -lzlib \
  23. -llibjpeg-turbo -llibopenjp2 -llibpng -llibprotobuf -llibtiff -llibwebp \
  24. -lgdi32 -lComDlg32 -lOleAut32 -lOle32 -luuid
  25. source := $(wildcard $(SRCDIR)/*.cpp)
  26. $(TARGET) :
  27. $(CX) $(FLAGS) $(INCLUDEDIR) $(source) -o $(BIN)/$(TARGET) $(LIBDIR)
  28. clean:
  29. rm $(BIN)/$(TARGET)

        make编译,make clean 清除可重新编译。

        运行效果,同样数据样本,相比前面博文所述算法训练结果,其准确率有了较大改善,大家可以尝试调整参数验证: 

2.4 main.cpp全代码

        main.cpp源代码,由于是基于前三篇博文支持向量机(SVM)应用、决策树(DTrees)应用、随机森林(RTrees)应用基础上,快速移用实现的,有很多支持向量机(SVM)应用或决策树(DTrees)的痕迹,采用的数据样本也非较合适的,仅仅是为了阐述c++ opencv K近邻算法(KNearest)应用说明。

  1. #include <opencv2/opencv.hpp>
  2. #include <opencv2/ml/ml.hpp>
  3. #include <opencv2/imgcodecs.hpp>
  4. #include <iostream>
  5. #include <vector>
  6. #include <iostream>
  7. #include <fstream>
  8. int intReverse(int num)
  9. {
  10. return (num>>24|((num&0xFF0000)>>8)|((num&0xFF00)<<8)|((num&0xFF)<<24));
  11. }
  12. std::string intToString(int num)
  13. {
  14. char buf[32]={0};
  15. itoa(num,buf,10);
  16. return std::string(buf);
  17. }
  18. cv::Mat read_mnist_image(const std::string fileName) {
  19. int magic_number = 0;
  20. int number_of_images = 0;
  21. int img_rows = 0;
  22. int img_cols = 0;
  23. cv::Mat DataMat;
  24. std::ifstream file(fileName, std::ios::binary);
  25. if (file.is_open())
  26. {
  27. std::cout << "open images file: "<< fileName << std::endl;
  28. file.read((char*)&magic_number, sizeof(magic_number));//format
  29. file.read((char*)&number_of_images, sizeof(number_of_images));//images number
  30. file.read((char*)&img_rows, sizeof(img_rows));//img rows
  31. file.read((char*)&img_cols, sizeof(img_cols));//img cols
  32. magic_number = intReverse(magic_number);
  33. number_of_images = intReverse(number_of_images);
  34. img_rows = intReverse(img_rows);
  35. img_cols = intReverse(img_cols);
  36. std::cout << "format:" << magic_number
  37. << " img num:" << number_of_images
  38. << " img row:" << img_rows
  39. << " img col:" << img_cols << std::endl;
  40. std::cout << "read img data" << std::endl;
  41. DataMat = cv::Mat::zeros(number_of_images, img_rows * img_cols, CV_32FC1);
  42. unsigned char temp = 0;
  43. for (int i = 0; i < number_of_images; i++) {
  44. for (int j = 0; j < img_rows * img_cols; j++) {
  45. file.read((char*)&temp, sizeof(temp));
  46. //svm data is CV_32FC1
  47. float pixel_value = float(temp);
  48. DataMat.at<float>(i, j) = pixel_value;
  49. }
  50. }
  51. std::cout << "read img data finish!" << std::endl;
  52. }
  53. file.close();
  54. return DataMat;
  55. }
  56. cv::Mat read_mnist_label(const std::string fileName) {
  57. int magic_number;
  58. int number_of_items;
  59. cv::Mat LabelMat;
  60. std::ifstream file(fileName, std::ios::binary);
  61. if (file.is_open())
  62. {
  63. std::cout << "open label file: "<< fileName << std::endl;
  64. file.read((char*)&magic_number, sizeof(magic_number));
  65. file.read((char*)&number_of_items, sizeof(number_of_items));
  66. magic_number = intReverse(magic_number);
  67. number_of_items = intReverse(number_of_items);
  68. std::cout << "format:" << magic_number << " ;label_num:" << number_of_items << std::endl;
  69. std::cout << "read Label data" << std::endl;
  70. //data type:CV_32SC1,channel:1
  71. LabelMat = cv::Mat::zeros(number_of_items, 1, CV_32SC1);
  72. for (int i = 0; i < number_of_items; i++) {
  73. unsigned char temp = 0;
  74. file.read((char*)&temp, sizeof(temp));
  75. LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
  76. }
  77. std::cout << "read label data finish!" << std::endl;
  78. }
  79. file.close();
  80. return LabelMat;
  81. }
  82. //change path for real paths
  83. std::string trainImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-images.idx3-ubyte";
  84. std::string trainLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-labels.idx1-ubyte";
  85. std::string testImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images.idx3-ubyte";
  86. std::string testLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-labels.idx1-ubyte";
  87. void train_SVM()
  88. {
  89. //read train images, data type CV_32FC1
  90. cv::Mat trainingData = read_mnist_image(trainImgFile);
  91. //images data normalization
  92. trainingData = trainingData/255.0;
  93. std::cout << "trainingData.size() = " << trainingData.size() << std::endl;
  94. std::cout << "trainingData.type() = " << trainingData.type() << std::endl;
  95. std::cout << "trainingData.rows = " << trainingData.rows << std::endl;
  96. std::cout << "trainingData.cols = " << trainingData.cols << std::endl;
  97. //read train label, data type CV_32SC1
  98. cv::Mat labelsMat = read_mnist_label(trainLabeFile);
  99. std::cout << "labelsMat.size() = " << labelsMat.size() << std::endl;
  100. std::cout << "labelsMat.type() = " << labelsMat.type() << std::endl;
  101. std::cout << "labelsMat.rows = " << labelsMat.rows << std::endl;
  102. std::cout << "labelsMat.cols = " << labelsMat.cols << std::endl;
  103. std::cout << "trainingData & labelsMat finish!" << std::endl;
  104. // //create SVM model
  105. // cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();
  106. // //set svm args,type and KernelTypes
  107. // svm->setType(cv::ml::SVM::C_SVC);
  108. // svm->setKernel(cv::ml::SVM::POLY);
  109. // //KernelTypes POLY is need set gamma and degree
  110. // svm->setGamma(3.0);
  111. // svm->setDegree(2.0);
  112. // //Set iteration termination conditions, maxCount is importance
  113. // svm->setTermCriteria(cv::TermCriteria(cv::TermCriteria::EPS | cv::TermCriteria::COUNT, 1000, 1e-8));
  114. // std::cout << "create SVM object finish!" << std::endl;
  115. // std::cout << "trainingData.rows = " << trainingData.rows << std::endl;
  116. // std::cout << "trainingData.cols = " << trainingData.cols << std::endl;
  117. // std::cout << "trainingData.type() = " << trainingData.type() << std::endl;
  118. // // svm model train
  119. // svm->train(trainingData, cv::ml::ROW_SAMPLE, labelsMat);
  120. // std::cout << "SVM training finish!" << std::endl;
  121. // // 创建决策树对象
  122. // cv::Ptr<cv::ml::DTrees> dtree = cv::ml::DTrees::create();
  123. // dtree->setMaxDepth(30); // 设置树的最大深度
  124. // dtree->setCVFolds(0);
  125. // dtree->setMinSampleCount(1); // 设置分裂内部节点所需的最小样本数
  126. // std::cout << "create dtree object finish!" << std::endl;
  127. // // 训练决策树--trainingData训练数据,labelsMat训练标签
  128. // cv::Ptr<cv::ml::TrainData> td = cv::ml::TrainData::create(trainingData, cv::ml::ROW_SAMPLE, labelsMat);
  129. // std::cout << "create TrainData object finish!" << std::endl;
  130. // if(dtree->train(td))
  131. // {
  132. // std::cout << "dtree training finish!" << std::endl;
  133. // }else{
  134. // std::cout << "dtree training fail!" << std::endl;
  135. // }
  136. // // 3. 设置并训练随机森林模型
  137. // cv::Ptr<cv::ml::RTrees> rf = cv::ml::RTrees::create();
  138. // rf->setMaxDepth(30); // 设置决策树的最大深度
  139. // rf->setMinSampleCount(2); // 设置叶子节点上的最小样本数
  140. // rf->setTermCriteria(cv::TermCriteria(cv::TermCriteria::EPS + cv::TermCriteria::COUNT, 10, 0.1)); // 设置终止条件
  141. // rf->train(trainingData, cv::ml::ROW_SAMPLE, labelsMat);
  142. // 4. 设置并训练KNN模型
  143. // 创建KNN模型
  144. cv::Ptr<cv::ml::KNearest> knn = cv::ml::KNearest::create();
  145. // 设置KNN参数
  146. knn->setAlgorithmType(cv::ml::KNearest::BRUTE_FORCE); // 使用暴力搜索
  147. knn->setIsClassifier(true); // 设置为分类器
  148. knn->setDefaultK(3); // 设置K值
  149. // 训练KNN模型
  150. knn->train(trainingData, cv::ml::ROW_SAMPLE, labelsMat);
  151. // svm model test
  152. cv::Mat testData = read_mnist_image(testImgFile);
  153. //images data normalization
  154. testData = testData/255.0;
  155. std::cout << "testData.rows = " << testData.rows << std::endl;
  156. std::cout << "testData.cols = " << testData.cols << std::endl;
  157. std::cout << "testData.type() = " << testData.type() << std::endl;
  158. //read test label, data type CV_32SC1
  159. cv::Mat testlabel = read_mnist_label(testLabeFile);
  160. cv::Mat testResp;
  161. // float response = svm->predict(testData,testResp);
  162. // float response = dtree->predict(testData,testResp);
  163. // float response = rf->predict(testData,testResp);
  164. float response = knn->predict(testData,testResp);
  165. // std::cout << "response = " << response << std::endl;
  166. testResp.convertTo(testResp,CV_32SC1);
  167. int map_num = 0;
  168. for (int i = 0; i <testResp.rows&&testResp.rows==testlabel.rows; i++)
  169. {
  170. if (testResp.at<int>(i, 0) == testlabel.at<int>(i, 0))
  171. {
  172. map_num++;
  173. }
  174. // else{
  175. // std::cout << "testResp.at<int>(i, 0) " << testResp.at<int>(i, 0) << std::endl;
  176. // std::cout << "testlabel.at<int>(i, 0) " << testlabel.at<int>(i, 0) << std::endl;
  177. // }
  178. }
  179. float proportion = float(map_num) / float(testResp.rows);
  180. std::cout << "map rate: " << proportion * 100 << "%" << std::endl;
  181. std::cout << "SVM testing finish!" << std::endl;
  182. //save svm model
  183. // svm->save("mnist_svm.xml");
  184. // dtree->save("mnist_svm.xml");
  185. // rf->save("mnist_svm.xml");
  186. knn->save("mnist_svm.xml");
  187. }
  188. void prediction(const std::string fileName,cv::Ptr<cv::ml::KNearest> knn)
  189. // void prediction(const std::string fileName,cv::Ptr<cv::ml::DTrees> dtree)
  190. // void prediction(const std::string fileName,cv::Ptr<cv::ml::SVM> svm)
  191. {
  192. //read img 28*28 size
  193. cv::Mat image = cv::imread(fileName, cv::IMREAD_GRAYSCALE);
  194. //uchar->float32
  195. image.convertTo(image, CV_32F);
  196. //image data normalization
  197. image = image / 255.0;
  198. //28*28 -> 1*784
  199. image = image.reshape(1, 1);
  200. //预测图片
  201. // float ret = dtree->predict(image);
  202. float ret = knn->predict(image);
  203. std::cout << "predict val = "<< ret << std::endl;
  204. }
  205. std::string imgDir = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images\\";
  206. std::string ImgFiles[5] = {"image_0.png","image_10.png","image_20.png","image_30.png","image_40.png",};
  207. void predictimgs()
  208. {
  209. //load svm model
  210. // cv::Ptr<cv::ml::SVM> svm = cv::ml::StatModel::load<cv::ml::SVM>("mnist_svm.xml");
  211. //load DTrees model
  212. // cv::Ptr<cv::ml::DTrees> dtree = cv::ml::StatModel::load<cv::ml::DTrees>("mnist_svm.xml");
  213. // cv::Ptr<cv::ml::RTrees> rf = cv::ml::StatModel::load<cv::ml::RTrees>("mnist_svm.xml");
  214. cv::Ptr<cv::ml::KNearest> kkn = cv::ml::StatModel::load<cv::ml::KNearest>("mnist_svm.xml");
  215. for (size_t i = 0; i < 5; i++)
  216. {
  217. prediction(imgDir+ImgFiles[i],kkn);
  218. }
  219. }
  220. int main()
  221. {
  222. train_SVM();
  223. predictimgs();
  224. return 0;
  225. }
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/801234
推荐阅读
相关标签
  

闽ICP备14008679号