当前位置:   article > 正文

【机器学习实战】-k-近邻算法之手写数字识别_机器学习数字识别

机器学习数字识别

机器学习实战】-k-近邻算法之手写数字识别


《机器学习实战》中k-近邻算法中的第二个例子,使用k-近邻算法来识别数字,跟海伦算法区别在于,数字是图像的形式,宽高是32像素×32像素的黑白图像,需要转换成文本格式

1.准备数据

使用的是书中提供的数据集,trainingDigits中包含了大约2000个例子,每个数字大约有200个样本,testDigits中包含了大约900个测试数据。
编写函数img2vector将图像转换成向量,是32×32的规格,有1024个像素,使用函数创建1×1024的numpy数组,循环读出32行,存储下来。
代码:

# 1.数据处理模块
def img2vector(filename):
    # 初始化1×1024的向量存储结果
    returnVect = np.zeros((1, 1024))
    # 打开文件
    fr = open(filename)
    # 逐行读取数据
    for i in range(32):
        lineStr = fr.readline()
        # 逐列读取数据
        for j in range(32):
            returnVect[0, 32*i+j] = int(lineStr[j])
    # 返回处理好的向量
    return returnVect
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

输出第一个数字的前32个像素,
在这里插入图片描述

2.训练算法
一般这一步应该使用训练数据集来对模型进行训练,但是k-近邻算法不需要进行训练,这一步编写函数来计算欧氏距离,最后排序判断未知量的标签,详细解释看k-近邻算法的简易实现
代码:

# 2.分类器使用欧几里得距离
def classify0(inx, dataset, labels, k):
    # 距离计算
    datasetSetSize = dataset.shape[0]
    diffMat = tile(inx, (datasetSetSize, 1)) - dataset
    sqDiffmat = diffMat ** 2
    sqDistances = sqDiffmat.sum(axis=1)
    distance = sqDistances ** 0.5
    # 距离排序
    sortedDistIndicies = distance.argsort()

    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    sorteClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sorteClassCount[0][0]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

3.测试算法,使用k-近邻算法识别数字

使用traningDigits中的数据来进行k-近邻判断,然后使用testDigits中的数据来检测正确率,使用os库中的listdir函数读取文件的文件名,然后从文件名中找出图片的类别,然后跟判断类别进行对比。
代码:

# 3.测试算法:使用k-近邻算法识别手写数字
def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir("trainingDigits")
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        # 从文件名解析分类数字
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        # 使用训练集来进行k近邻算法
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if(classifierResult != classNumStr):
            errorCount += 1.0
    print("\n the total number of errors is: %d" % errorCount)
    print("\n the total error rate is: %f" % (errorCount/float(mTest)))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

结果效果图:
在这里插入图片描述
可以看出这个算法的执行效率不高,每个测试向量要跟2000多个数据进行1024维度的欧几里得距离距离计算,然后一共执行900次。还要2MB的存储空间来存储k-近邻的计算数据,需要大的开销,k决策树是k-近邻算法的优化版,可以节省大量的计算开销。还有k-近邻算法的一个缺点是,无法给出任何数据的基础结构信息,无法知晓平均实例样本和典型实例样本有什么特征。如有错误,欢迎批评指正。

坚持,时间不等人了。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/空白诗007/article/detail/1008042
推荐阅读
相关标签
  

闽ICP备14008679号