当前位置:   article > 正文

机器学习实战之knn(C++复现)

knnh ,@,,,,,?@、
//knn.h
#ifndef KNN_H_ 
#define KNN_H_ 
#include<opencv2/opencv.hpp>
#include<string>
using cv::Mat;
using std::string;
//k-nearest-neighbour classifier operating on OpenCV matrices.
class knn
{
private:
	//Path of the text file that file2matrix reads
	string dataPath;

public:
	//Remembers the path of the data file
	knn(string path = " ");
	~knn() {}

	//Loads the text file at dataPath into dataSet (3 float columns) and label
	//(1 float column); returns false when the file cannot be opened
	bool file2matrix(Mat& dataSet, Mat& label);

	//Min-max rescales every column of dataSet to the range 0-1, in place
	void autoNorm(Mat& dataSet);

	//Returns the label that collects the most votes among the k training rows
	//of dataSet closest to the first row of inX
	int classify0(Mat inX, Mat dataSet, Mat labels, int labelNum, int k);
};
#endif
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
//knn.cpp
#include"knn.h"
#include<iostream>
#include<vector>
#include<algorithm>
#include<fstream>
#include<sstream>
#include<string>
using std::string;
using std::vector;
using std::cout;
using std::endl;
using namespace cv;
//Stores the location of the data file; nothing is read until file2matrix is called.
knn::knn(string path)
	: dataPath(path)
{
}
int knn::classify0(Mat inX, Mat dataSet, Mat labels,int labelNum, int k)
{
	/*
	Classifies one sample by majority vote among its nearest training samples,
	using the Euclidean distance.
	inX:      query sample; only its first row is used
	dataSet:  training samples, one sample per row
	labels:   one label per training row, read as float from column 0 and
	          truncated to int
	labelNum: number of slots in the vote table; labels outside [0, labelNum)
	          are ignored instead of being counted
	k:        number of nearest neighbours that vote; values larger than the
	          number of training rows are clamped
	Returns the label with the most votes (the smallest label wins a tie),
	or -1 when labelNum is not positive.
	NOTE(review): labels is accessed with at<float>, so it is presumably CV_32F - confirm at call sites.
	*/
	if (labelNum <= 0)
		return -1;//there is no label slot to vote for

	//Matrix shaped like dataSet whose every row is the query sample
	Mat tiled(dataSet.size(), dataSet.type());
	for (int row = 0; row < dataSet.rows; row++)
	{
		inX.row(0).copyTo(tiled.row(row));
	}

	//Euclidean distance between the query and every training row
	Mat diffMat = tiled - dataSet;
	Mat sqDiffMat = diffMat.mul(diffMat);
	Mat sqDistances;
	reduce(sqDiffMat, sqDistances, 1, CV_REDUCE_SUM);//sum each row into a single column
	Mat distances;
	sqrt(sqDistances, distances);

	//Row numbers of the training samples, ordered by ascending distance
	Mat index;
	sortIdx(distances, index, CV_SORT_EVERY_COLUMN + CV_SORT_ASCENDING);

	//Never read more neighbours than there are sorted rows
	const int voters = std::min(k, index.rows);
	vector<int> classCount(labelNum, 0);
	for (int i = 0; i < voters; i++)
	{
		const int nearest = index.at<int>(i, 0);
		if (nearest < 0 || nearest >= labels.rows)
			continue;//no label is available for this training row
		const int label = static_cast<int>(labels.at<float>(nearest, 0));
		if (label < 0 || label >= labelNum)
			continue;//a label outside the vote table would write out of bounds
		classCount[label]++;
	}

	//max_element returns the first maximum, so the smallest label wins a tie
	return static_cast<int>(std::max_element(classCount.begin(), classCount.end()) - classCount.begin());
}
bool knn::file2matrix(Mat& dataSet, Mat& labels)
{
	/*
	Reads the text file at dataPath into two CV_32F matrices.
	Every usable line holds at least four whitespace-separated numbers: the
	first three are the features, the fourth is the label. Anything after the
	fourth number is ignored, and lines with fewer than four numbers (blank
	lines included) are skipped.
	dataSet: receives one row of three features per usable line
	labels:  receives the matching label, one row per usable line
	Returns false when the file cannot be opened, true otherwise.
	*/
	const int featureCount = 3;

	std::ifstream file(dataPath);
	if (!file.is_open())
	{
		cout << "read fail..." << endl;
		return false;
	}

	vector<float> features;//row-major, featureCount values per sample
	vector<float> labelValues;
	string line;
	while (std::getline(file, line))
	{
		std::istringstream fields(line);
		float values[featureCount + 1];
		int parsed = 0;
		//Stop after the label so that extra fields can never write past the buffer
		while (parsed <= featureCount && fields >> values[parsed])
			parsed++;
		if (parsed != featureCount + 1)
			continue;//incomplete line: storing it would leave unset values in the data
		features.insert(features.end(), values, values + featureCount);
		labelValues.push_back(values[featureCount]);
	}

	const int sampleCount = static_cast<int>(labelValues.size());
	dataSet.create(sampleCount, featureCount, CV_32F);
	labels.create(sampleCount, 1, CV_32F);
	for (int row = 0; row < sampleCount; row++)
	{
		for (int col = 0; col < featureCount; col++)
			dataSet.at<float>(row, col) = features[row * featureCount + col];
		labels.at<float>(row, 0) = labelValues[row];
	}
	return true;
}
void knn::autoNorm(Mat& dataSet)
{
	/*
	Rescales the data set in place so that every column spans the range 0-1.
	dataSet: the matrix to rescale; each column is min-max normalised on its own
	*/
	const int columnCount = dataSet.cols;
	int current = 0;
	while (current < columnCount)
	{
		//col() returns a header sharing dataSet's storage, so the result lands in dataSet
		Mat column = dataSet.col(current);
		normalize(column, column, 0, 1, NORM_MINMAX);
		++current;
	}
}
float datingClassTest(string path)
{
	/*
	Hold-out evaluation of the classifier on the data file at path.
	The file is loaded and normalised, the first 10% of the rows become the test
	set and the remaining rows the training set. Every test row is classified and
	compared with its stored label.
	path: location of the data file handed to knn
	Returns the fraction of misclassified test rows (also printed), or -1 when
	the file cannot be read or holds too few rows to form a test set.
	*/
	const float hoRatio = 0.1f;//share of the rows held out for testing
	const int labelSlots = 4;//size of the vote table passed to classify0
	const int neighbours = 4;//number of voting neighbours passed to classify0

	knn a(path);
	Mat dataSet, labels;
	if (!a.file2matrix(dataSet, labels))
		return -1.0f;//file2matrix has already reported the failure
	a.autoNorm(dataSet);

	const int num = static_cast<int>(dataSet.rows * hoRatio);//number of test rows
	if (num <= 0)
	{
		//Without test rows the error rate would be 0/0
		cout << "not enough samples to build a test set" << endl;
		return -1.0f;
	}

	Mat testDataSet = dataSet.rowRange(0, num);
	Mat testLabels = labels.rowRange(0, num);
	Mat trainDataSet = dataSet.rowRange(num, dataSet.rows);
	Mat trainLabels = labels.rowRange(num, dataSet.rows);

	int errorCount = 0;
	for (int i = 0; i < num; i++)
	{
		const int predicted = a.classify0(testDataSet.row(i), trainDataSet, trainLabels, labelSlots, neighbours);
		if (static_cast<float>(predicted) != testLabels.at<float>(i, 0))
			errorCount++;
	}
	const float error = static_cast<float>(errorCount) / num;
	cout << "the total error is :" << error;
	return error;
}
int main()
{
	string path = "E:/c++工程/knn/Ch02/datingTestSet2.txt";
	datingClassTest(path);

	system("pause");
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号