赞
踩
这次我们使用朴素贝叶斯分类最常用的3个衍生模型:多项式模型、伯努利模型、高斯模型对手写数字图像进行训练分类。
我们直接采用sklearn框架,这个框架基本包含了所有机器学习统计学习模型,导入使用很方便简单,也省得我们自己手写模型框架了。
关于朴素贝叶斯分类看我之前的博文,统计学习(二)朴素贝叶斯分类,当然博文有什么问题也欢迎及时指出交流,谢谢~
完整工程文件点击这里
如果对你的学习有所帮助欢迎点star,感谢~
MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。
数据内容为60000张训练集,包含了手写数字0-9,10000张测试集,每个样本为28 x 28大小的矩阵数据,每个元素数值在0-255之间。
数据的获取方式:
可从官网直接下载: http://yann.lecun.com/exdb/mnist/ ,下载得到的内容为:
Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签为0-9)
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签为0-9)
这样解压还挺麻烦,可以这样下载。
tensorflow 里面包含了很多常见的机器学习数据集,可以直接下载下来,代码如下:
import tensorflow as tf
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei'] #显示中文标签
plt.rcParams['axes.unicode_minus']=False #这两行需要手动设置
mnist = tf.keras.datasets.mnist
# 下载
(train_x, train_y), (test_x, test_y) = mnist.load_data()
print('train_x shape=',train_x.shape)
print('test_x shape=',test_x.shape)
plt.imshow(train_x[0])#显示训练集中的一张图片
plt.title('标签='+str(train_y[0]))#把样本的标签作为图像标题
plt.show()
运行内容:
或者也可以直接到我的github仓库里下载,我把数据集上传到了github仓库。https://github.com/fmc123653/Statistical-learning/tree/main/naive_bayes%E6%9C%B4%E7%B4%A0%E8%B4%9D%E5%8F%B6%E6%96%AF%E5%88%86%E7%B1%BB
这里我按我上传到github的数据集mnist.npz为例,加载数据集。
import numpy as np
data=np.load('mnist.npz')
for title in data:#查看所有标签
print(title)
里面成字典关系,我们打印看看。
import numpy as np
data=np.load('mnist.npz')
print('train_x shape=',data['x_train'].shape)
print('test_x shape=',data['x_test'].shape)
数据格式是符合的,训练集60000个样本,每个样本大小为28 x 28。
如果是卷积神经网络是直接对二维图像卷积处理,但是朴素贝叶斯分类,可以理解为28 x 28个元素点,每个像素点单独看成一个特征,于是每个手写数字图像有28 x 28 = 784 个特征信息,因为原始数据是二维矩阵,我们需要转换成一维的1x784格式的数组,这样才方便朴素贝叶斯模型的加载处理。
import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm def rebulid_picture(data):#将数据均值化处理成0或1 for i in range(28): for j in range(28): if data[i][j]>=50: data[i][j]=1 else: data[i][j]=0 return data.reshape([1,28*28])[0]#将数据格式转换为1x784 train_data=[] test_data=[] dic_num={} data = np.load('mnist.npz')#加载数据 id=0 for val in tqdm(data['x_train']): lab=data['y_train'][id] res=list(rebulid_picture(val)) res.append(lab)#把标签作为最后一列 train_data.append(res) id+=1 id=0 for val in tqdm(data['x_test']): lab=data['y_test'][id] res=list(rebulid_picture(val)) res.append(lab)#把标签作为最后一列 test_data.append(res) id+=1 train_data=np.array(train_data)#转换为数组格式 test_data=np.array(test_data) print('train_data.shape=',train_data.shape) print('test_data.shape=',test_data.shape) np.savez('train_data.npz',train_data)#保存为train_data.npz文件 np.savez('test_data.npz',test_data)#保存为test_data.npz文件
这里分别保存为train_data.npz文件和test_data.npz文件,方便模型训练时直接加载,不用每次训练都要重新处理数据。
然后开始训练,多项式模型、伯努利模型和高斯模型之间的联系和不同可以看这篇博客,总结的很不错。朴素贝叶斯的三个常用模型:高斯、多项式、伯努利
我们将处理好的数据分别导入到多项式模型、伯努利模型和高斯模型中,分别得到的结果如下:
import numpy as np import matplotlib.pyplot as plt from sklearn.metrics import accuracy_score,classification_report from sklearn.naive_bayes import MultinomialNB,BernoulliNB,GaussianNB train_data=np.load('train_data.npz',allow_pickle=True)#加载数据 test_data=np.load('test_data.npz',allow_pickle=True) train_x=train_data['arr_0'][:,:-1]#训练集 train_y=train_data['arr_0'][:,-1]#训练集标签 test_x=test_data['arr_0'][:,:-1] test_y=test_data['arr_0'][:,-1] Mnb = MultinomialNB()#加载多项式朴素贝叶斯模型 Bnm = BernoulliNB()#加载伯努利朴素贝叶斯模型 Gnb = GaussianNB()#加载高斯朴素贝叶斯模型 Mnb.fit(train_x,train_y) Mpredict = Mnb.predict(test_x) Bnm.fit(train_x,train_y) Bpredict = Bnm.predict(test_x) Gnb.fit(train_x,train_y) Gpredict = Gnb.predict(test_x) print("多项式模型 accuracy_score: %.4lf" % accuracy_score(Mpredict,test_y)) print("伯努利模型 accuracy_score: %.4lf" % accuracy_score(Bpredict,test_y)) print("高斯模型 accuracy_score: %.4lf" % accuracy_score(Gpredict,test_y))
其中多项式模型和伯努利模型都是针对0或者1这种比较单一的数据处理,效果较好,高斯模型是针对离散的数据,比如某个特征的取值为:0.1,0.3,0.76,0.877这样。所以在这里表现很差。
如果后面遇到离散化的数据可以考虑用高斯朴素贝叶斯模型处理看看。
希望我的分享对你的学习有所帮助,如果有问题请及时指出,谢谢~
赞
踩
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。