当前位置:   article > 正文

pytorch实现 --- 手写数字识别_人工智能手写数字识别是怎么实现的

人工智能手写数字识别是怎么实现的

        本篇文章是博主在人工智能等领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对人工智能等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在Pytorch

       Pytorch(1)---pytorch实现 --- 手写数字识别》

pytorch实现 --- 手写数字识别

目录

1.项目介绍

2.实现方法

3.程序代码

4.运行结果


1.项目介绍

        使用pytorch实现手写数字识别,十分简单的小项目,环境搭建好,一跑就通。


2.实现方法

2.1方式1        

 安装库:

pip install numpy torch torchvision matplotlib

 运行:

python test.py

首次运行会下载MNIST数据集,请保持网络畅通

2.2方式2

        如果使用pycharm,已经安装好了pytorch环境,那么直接在pytorch环境中运行下面这份代码就好。


3.程序代码

  1. """手写数字识别项目
  2. 时间:2023.11.6
  3. 环境:pytorch
  4. 作者:Rainbook
  5. """
  6. import torch
  7. from torch.utils.data import DataLoader
  8. from torchvision import transforms
  9. from torchvision.datasets import MNIST
  10. import matplotlib.pyplot as plt
  11. class Net(torch.nn.Module): # 定义一个Net类,神经网络的主体
  12. def __init__(self): # 全连接层,四个
  13. super().__init__()
  14. self.fc1 = torch.nn.Linear(28*28, 64) # 输入层输入28*28,输出64
  15. self.fc2 = torch.nn.Linear(64, 64) # 中间层,输入64,输出64
  16. self.fc3 = torch.nn.Linear(64, 64)
  17. self.fc4 = torch.nn.Linear(64, 10) # 中间层(隐藏层)的最后一层,输出10个特征值
  18. def forward(self, x): # 前向传播过程
  19. # self.fc1(x)全连接线性计算,再套上一个激活函数torch.nn.functional.relu()
  20. x = torch.nn.functional.relu(self.fc1(x))
  21. x = torch.nn.functional.relu(self.fc2(x))
  22. x = torch.nn.functional.relu(self.fc3(x))
  23. # 最后一层进行softmax归一化,log_softmax是为了提高计算稳定性,在softmax后面套上了一个对数运算
  24. x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
  25. return x
  26. def get_data_loader(is_train):
  27. to_tensor = transforms.Compose([transforms.ToTensor()]) # 定义数据转换类型tensor,多维数组(张量)
  28. """下载MNIST数据集,
  29. "":当前位置
  30. is_train:判断是训练集还是测试集;
  31. batch_size:一个批次包含15张图片;
  32. shuffle:数据随机打乱的
  33. """
  34. data_set = MNIST("", is_train, transform=to_tensor, download=True)
  35. return DataLoader(data_set, batch_size=15, shuffle=True) # 数据加载器
  36. def evaluate(test_data, net): # 用来评估神经网络
  37. n_correct = 0
  38. n_total = 0
  39. with torch.no_grad():
  40. for (x, y) in test_data:
  41. outputs = net.forward(x.view(-1, 28*28)) # 计算神经网络的预测值
  42. for i, output in enumerate(outputs): # 对每个批次的预测值进行比较,累加正确预测的数量
  43. if torch.argmax(output) == y[i]:
  44. n_correct += 1
  45. n_total += 1
  46. return n_correct / n_total # 返回正确率
  47. def main():
  48. # 导入训练集和测试集
  49. train_data = get_data_loader(is_train=True)
  50. test_data = get_data_loader(is_train=False)
  51. net = Net() # 初始化神经网络
  52. # 打印初始网络的正确率,应当是10%附近。手写数字有十种结果,随机猜的正确率就是1/10
  53. print("initial accuracy:", evaluate(test_data, net))
  54. """训练神经网络
  55. pytorch的固定写法
  56. """
  57. optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
  58. for epoch in range(5): # 需要在一个数据集上反复训练神经网络,epoch网络轮次,提高数据集的利用率
  59. for (x, y) in train_data:
  60. net.zero_grad() # 初始化
  61. output = net.forward(x.view(-1, 28*28)) # 正向传播
  62. # 计算差值,nll_loss对数损失函数,为了匹配log_softmax的log运算
  63. loss = torch.nn.functional.nll_loss(output, y)
  64. loss.backward() # 反向误差传播
  65. optimizer.step() # 优化网络参数
  66. print("epoch", epoch, "accuracy:", evaluate(test_data, net)) # 打印当前网络的正确率
  67. """测试神经网络
  68. 训练完成后,随机抽取3张图片进行测试
  69. """
  70. for (n, (x, _)) in enumerate(test_data):
  71. if n > 3:
  72. break
  73. predict = torch.argmax(net.forward(x[0].view(-1, 28*28))) # 测试结果
  74. plt.figure(n) # 画出图像
  75. plt.imshow(x[0].view(28, 28)) # 像素大小28*28
  76. plt.title("prediction: " + str(int(predict))) # figure的标题
  77. plt.show()
  78. if __name__ == "__main__":
  79. main()

4.运行结果

4.1正确率

4.2测试结果

        参考资料来源:B站

        文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者关注VX公众号:Rain21321,联系作者。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/751715
推荐阅读
相关标签
  

闽ICP备14008679号