当前位置:   article > 正文

KNN的C++实现_knn c++库

knn c++库
  1. #include"stdafx.h"
  2. #include<iostream>
  3. #include<map>
  4. #include<vector>
  5. #include<stdio.h>
  6. #include<cmath>
  7. #include<cstdlib>
  8. #include<algorithm>
  9. #include<fstream>
  10. using namespace std;
  11. typedef char tLabel;
  12. typedef double tData;
  13. typedef pair<int, double> PAIR;
  14. const int colLen = 2;
  15. const int rowLen = 12;
  16. ifstream fin;
  17. ofstream fout;
  18. class KNN
  19. {
  20. private:
  21. tData dataSet[rowLen][colLen];
  22. tLabel labels[rowLen];
  23. tData testData[colLen];
  24. int k;
  25. map<int, double> map_index_dis;
  26. map<tLabel, int> map_label_freq;
  27. double get_distance(tData *d1, tData *d2);
  28. public:
  29. KNN(int k);
  30. void get_all_distance();
  31. void get_max_freq_label();
  32. struct CmpByValue
  33. {
  34. bool operator() (const PAIR& lhs, const PAIR& rhs)
  35. {
  36. return lhs.second < rhs.second;
  37. }
  38. };
  39. };
  40. KNN::KNN(int k)
  41. {
  42. this->k = k;
  43. fin.open("data.txt");
  44. if (!fin)
  45. {
  46. cout << "can not open the file data.txt" << endl;
  47. exit(1);
  48. }
  49. /* input the dataSet */
  50. for (int i = 0; i<rowLen; i++)
  51. {
  52. for (int j = 0; j<colLen; j++)
  53. {
  54. fin >> dataSet[i][j];
  55. }
  56. fin >> labels[i];
  57. }
  58. cout << "please input the test data :" << endl;
  59. /* inuput the test data */
  60. for (int i = 0; i<colLen; i++)
  61. cin >> testData[i];
  62. }
  63. /*
  64. * calculate the distance between test data and dataSet[i]
  65. */
  66. double KNN::get_distance(tData *d1, tData *d2)
  67. {
  68. double sum = 0;
  69. for (int i = 0; i<colLen; i++)
  70. {
  71. sum += pow((d1[i] - d2[i]), 2);
  72. }
  73. // cout<<"the sum is = "<<sum<<endl;
  74. return sqrt(sum);
  75. }
  76. /*
  77. * calculate all the distance between test data and each training data
  78. */
  79. void KNN::get_all_distance()
  80. {
  81. double distance;
  82. int i;
  83. for (i = 0; i<rowLen; i++)
  84. {
  85. distance = get_distance(dataSet[i], testData);
  86. //<key,value> => <i,distance>
  87. map_index_dis[i] = distance;
  88. }
  89. //traverse the map to print the index and distance
  90. map<int, double>::const_iterator it = map_index_dis.begin();
  91. while (it != map_index_dis.end())
  92. {
  93. cout << "index = " << it->first << " distance = " << it->second << endl;
  94. it++;
  95. }
  96. }
  97. /*
  98. * check which label the test data belongs to to classify the test data
  99. */
  100. void KNN::get_max_freq_label()
  101. {
  102. //transform the map_index_dis to vec_index_dis
  103. vector<PAIR> vec_index_dis(map_index_dis.begin(), map_index_dis.end());
  104. //sort the vec_index_dis by distance from low to high to get the nearest data
  105. sort(vec_index_dis.begin(), vec_index_dis.end(), CmpByValue());
  106. for (int i = 0; i<k; i++)
  107. {
  108. cout << "the index = " << vec_index_dis[i].first << " the distance = " << vec_index_dis[i].second
  109. << " the label = " << labels[vec_index_dis[i].first]
  110. << " the coordinate ( " << dataSet[vec_index_dis[i].first][0] << "," << dataSet[vec_index_dis[i].first][1] << " )" << endl;
  111. //calculate the count of each label
  112. map_label_freq[labels[vec_index_dis[i].first]]++;
  113. }
  114. map<tLabel, int>::const_iterator map_it = map_label_freq.begin();
  115. tLabel label;
  116. int max_freq = 0;
  117. //find the most frequent label
  118. while (map_it != map_label_freq.end())
  119. {
  120. if (map_it->second > max_freq)
  121. {
  122. max_freq = map_it->second;
  123. label = map_it->first;
  124. }
  125. map_it++;
  126. }
  127. cout << "The test data belongs to the " << label << " label" << endl;
  128. }
  129. int main()
  130. {
  131. int k;
  132. cout << "please input the k value : " << endl;
  133. cin >> k;
  134. KNN knn(k);
  135. knn.get_all_distance();
  136. knn.get_max_freq_label();
  137. system("pause");
  138. return 0;
  139. }

data.txt数据如下(每行依次为两个特征值和一个类别标签,共12行,分别对应代码中的 colLen = 2 和 rowLen = 12):

0.0 1.1 A  
1.0 1.0 A  
2.0 1.0 B  
0.5 0.5 A  
2.5 0.5 B  
0.0 0.0 A  
1.0 0.0 A   
2.0 0.0 B  
3.0 0.0 B  
0.0 -1.0 A  
1.0 -1.0 A  
2.0 -1.0 B

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小惠珠哦/article/detail/801255
推荐阅读
相关标签
  

闽ICP备14008679号