当前位置:   article > 正文

k-近邻算法_k近邻算法

k近邻算法

k-近邻算法


k-近邻算法概述

k-近邻算法(k-NearestNeighor Algorithm)是采用测量不同特征值之间的距离方法进行分类,简称kNN。

这里用到的距离计算是欧几里德距离。

工作原理:存在一个样本数据集合(x_{1}^{i},x_{2}^{i},...,xki,yi)i(0,n),也称作训练样本集,并且样本集中每条数据都存在标签yi,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后(x1,x2,...,xk),将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。

距离计算如下:

  

然后把d按照从小到大排序,选择k个最相似数据中出现次数中最多的分类,作为新数据的分类。

kNN除了做分类任务,还可以做回归任务,在后面的章节中会讲到。


k-近邻算法的一般流程

(1)收集数据:方法不限

(2)准备数据:距离计算所需要的数值,最好是结构化的数据格式

(3)分析数据:方法不限

(4)训练算法:此步骤不适用于k-近邻算法(不需要训练,直接计算距离,找出距离最近的k个

(5)测试算法:计算错误率

(6)使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理


代码实现

所有代码实现都是在python集成的IDLE上完成,创建名为kNN.py的python模块,本节使用的代码都在这个文件中。

1.导包

  1. '''
  2. KNN是一种最简单最有效的算法,但是KNN必须保留所有的数据集,
  3. 如果训练数据集很大,必须使用大量的存储空间,
  4. 此外,需要对每一个数据计算距离,非常耗时
  5. 另外,它无法给出任何数据的基础结构信息(无法给出一个模型)
  6. '''
  7. from numpy import *
  8. import operator
  9. import matplotlib
  10. import matplotlib.pyplot as plt
  11. from os import listdir
  12. import numpy as np
  13. import matplotlib as mpl
  14. import matplotlib.lines as mlines

2.导入数据

  1. #使用python导入数据,创建数据集和标签
  2. def createDataSet():
  3. group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
  4. labels = ['A','A','B','B']
  5. return group,labels
  6. #每次使用数据,调用该函数即可

3.实施kNN分类算法

  1. '''
  2. 伪代码
  3. 对未知类别属性的数据集中的每个点依此执行以下操作
  4. 1、计算已知类别数据集中的点与当前点之间的距离;
  5. 2、按照距离递增次序排序;
  6. 3、选取与当前点距离最小的k个点;
  7. 4、确定前k个点所在的类别的出现频率;
  8. 5、返回前k个点出现频率最高的类别作为当前点的预测分类。
  9. '''
  10. #实施kNN分类算法
  11. def classify0(inX,dataSet,labels,k):
  12. dataSetSize = dataSet.shape[0]#查看矩阵的维度
  13. diffMat = tile(inX,(dataSetSize,1)) - dataSet
  14. #tile(数组,(在行上重复次数,在列上重复次数))
  15. sqDiffMat = diffMat**2
  16. sqDistances = sqDiffMat.sum(axis=1)
  17. #sum默认axis=0,是普通的相加,axis=1是将一个矩阵的每一行向量相加
  18. distances = sqDistances**0.5
  19. sortedDistIndicies = distances.argsort()
  20. #sort函数按照数组值从小到大排序
  21. #argsort函数返回的是数组值从小到大的索引值
  22. classCount={}
  23. for i in range(k):
  24. voteIlabel = labels[sortedDistIndicies[i]]
  25. classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
  26. #get(key,k),当字典dic中不存在key时,返回默认值k;存在时返回key对应的值
  27. sortedClassCount = sorted(classCount.items(),
  28. key=operator.itemgetter(1),reverse=True)
  29. #python2中用iteritems,python3中用items代替;operator.itemgetter(k),返回第k个域的值
  30. return sortedClassCount[0][0]

4、测试

  1. #测试KNN
  2. #>>> import KNN
  3. #>>> group,labels = KNN.createDataSet()
  4. #>>> KNN.classify0([0,0],group,labels,3) # B
  5. #>>> KNN.classify0([1.2,1.5],group,labels,3) # A

自行测试~


总结

本节实现了比较简单的k-近邻算法,相信看完会加深对kNN的理解!

欢迎交流~

下一节将会结合一个实例-手写数字识别系统来加深对kNN的理解!

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号