当前位置:   article > 正文

基于pytorch实现手写数字识别(附python代码)_python utils下没有plot_image

python utils下没有plot_image

/1加载图片:加载数据集,没有的话会自动下载,数据分布在0附近,并打散。

训练集:测试集=60k:10k(即6:1)。

utils.py文件:plot_curve()绘制loss下降曲线;plot_image()显示图片,用于可视化样本和识别结果;one_hot()将标签转换为one-hot编码。mnist_train.py文件:读取MNIST数据集,并完成模型的训练与测试。

/2 加载模型:三层线性模型,前两层用ReLU函数,batch_size=512,一张图片28*28,Normalize将数据标准化(减去均值0.1307、除以标准差0.3081),使其分布在0附近。

/3 训练:学习率0.01,momentum = 0.9,loss定义,梯度清零、计算、更新,每10次显示loss,可以看到loss下降:

/4 测试

计算正确率并显示梯度下降:

遇到的问题:pytorch中优化器获得的是空参数表

ValueError:optimizer got an empty parameter list

解决:初始化函数名写错了,__init__ 前后各是两个下划线(名字不对时构造函数不会被调用,网络层没有注册,net.parameters() 就是空的):

def __init__(self):

        super(Net, self).__init__()

win10+anaconda3+python3.7,安装tensorflow、pytorch、opencv、CUDA10.2

mnist_train.py

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Tue Jan 14 15:10:20 2020
  4. @author: ZM
  5. """
  6. import torch
  7. from torch import nn
  8. from torch.nn import functional as F
  9. from torch import optim
  10. import torchvision
  11. from matplotlib import pyplot as plt
  12. from utils import plot_image, plot_curve, one_hot
  batch_size = 512

  # Step 1: load the dataset.
  # MNIST is downloaded automatically when it is not on disk. Every image is
  # converted to a tensor and normalized with mean 0.1307 / std 0.3081 so the
  # values are centred near 0. Only the training set is shuffled.
  mnist_transform = torchvision.transforms.Compose([
      torchvision.transforms.ToTensor(),
      torchvision.transforms.Normalize((0.1307,), (0.3081,)),
  ])

  train_dataset = torchvision.datasets.MNIST(
      'mnist_data', train=True, download=True, transform=mnist_transform)
  train_loader = torch.utils.data.DataLoader(
      train_dataset, batch_size=batch_size, shuffle=True)

  test_dataset = torchvision.datasets.MNIST(
      'mnist_data/', train=False, download=True, transform=mnist_transform)
  test_loader = torch.utils.data.DataLoader(
      test_dataset, batch_size=batch_size, shuffle=False)

  # Inspect one training batch: print the image/label shapes and the value
  # range after normalization, then draw a few samples with their labels.
  x, y = next(iter(train_loader))
  print(x.shape, y.shape, x.min(), x.max())
  plot_image(x, y, 'image sample')
  # Build the model
  class Net(nn.Module):
      """Three fully connected layers: 784 -> 256 -> 64 -> 10.

      The first two layers are followed by ReLU; the last layer returns the
      raw scores for the ten digit classes.
      """

      def __init__(self):
          super(Net, self).__init__()
          # Each layer computes w @ x + b.
          self.fc1 = nn.Linear(28*28, 256)
          self.fc2 = nn.Linear(256, 64)
          self.fc3 = nn.Linear(64, 10)

      def forward(self, x):
          """Return the [b, 10] class scores for a batch of images.

          Args:
              x: either an already flattened batch whose last dimension is
                  784, or an image batch such as [b, 1, 28, 28]; the latter
                  is flattened to [b, 784] here so callers need not do it.

          Returns:
              Tensor of unnormalized scores, one row of 10 values per sample.
          """
          # Inputs that already end in 784 features pass through untouched,
          # which keeps callers that flatten beforehand working as before.
          if x.size(-1) != self.fc1.in_features:
              x = x.flatten(start_dim=1)
          # h1 = relu(w1 x + b1)
          x = F.relu(self.fc1(x))
          # h2 = relu(w2 h1 + b2)
          x = F.relu(self.fc2(x))
          # h3 = w3 h2 + b3 (no activation on the output layer)
          x = self.fc3(x)
          return x
  # Training
  net = Net()  # instantiate the model
  # net.parameters() yields [w1, b1, w2, b2, w3, b3]
  optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
  train_loss = []
  for epoch in range(3):
      for batch_idx, (x, y) in enumerate(train_loader):
          # x: [b, 1, 28, 28], y: [b]
          # Flatten the 4-D image batch to 2-D: [b, 1, 28, 28] => [b, 784]
          x = x.view(x.size(0), 28 * 28)
          # Network output: [b, 10]
          out = net(x)
          # One-hot targets: [b, 10]
          y_onehot = one_hot(y)
          # loss = mse(out, y_onehot)
          loss = F.mse_loss(out, y_onehot)

          # Clear old gradients, back-propagate, then update the parameters
          # (basic form w' = w - lr * grad; the optimizer also uses momentum).
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          # Record the loss of every step and report it every 10 batches.
          loss_value = loss.item()
          train_loss.append(loss_value)
          if batch_idx % 10 == 0:
              print(epoch, batch_idx, loss_value)

  # Plot how the loss evolved over all training steps.
  plot_curve(train_loss)
  # Testing: the parameters [w1, b1, w2, b2, w3, b3] are trained by now.
  total_correct = 0
  # Inference needs no gradients; no_grad() avoids building an autograd graph
  # for every test batch.
  with torch.no_grad():
      for x, y in test_loader:
          # [b, 1, 28, 28] => [b, 784]
          x = x.view(x.size(0), 28*28)
          # out: [b, 10] => pred: [b]
          out = net(x)
          pred = out.argmax(dim=1)
          correct = pred.eq(y).sum().float().item()
          total_correct += correct
  total_num = len(test_loader.dataset)
  acc = total_correct / total_num
  print('test acc:', acc)

  # Show a few test images, each titled with the predicted digit.
  x, y = next(iter(test_loader))
  with torch.no_grad():
      out = net(x.view(x.size(0), 28*28))
  pred = out.argmax(dim=1)
  plot_image(x, pred, 'test')

utils.py 

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Tue Jan 14 16:37:46 2020
  4. @author: ZM
  5. """
  6. import torch
  7. from matplotlib import pyplot as plt
  def plot_curve(data):
      """Draw ``data`` as a blue line against its step index and show it.

      Args:
          data: sequence of numbers (for example the per-step training loss);
              the x axis is the position of each value in the sequence.
      """
      plt.figure()
      steps = range(len(data))
      plt.plot(steps, data, color='blue')
      plt.xlabel('step')
      plt.ylabel('value')
      plt.legend(['value'], loc='upper right')
      plt.show()
  def plot_image(img, label, name):
      """Show up to six images of a batch in a 2x3 grid with titled labels.

      Args:
          img: batch of images; sample ``i`` is drawn from ``img[i][0]``
              (its first channel). The values are rescaled with
              ``* 0.3081 + 0.1307`` before display, which undoes the
              Normalize((0.1307,), (0.3081,)) transform.
          label: per-sample labels; ``label[i].item()`` goes into the title.
          name: prefix used in every subplot title.
      """
      plt.figure()
      # Never index past the end of a batch holding fewer than six samples.
      count = min(6, len(img), len(label))
      for i in range(count):
          plt.subplot(2, 3, i + 1)
          plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
          plt.title("{}: {}".format(name, label[i].item()))
          plt.xticks([])
          plt.yticks([])
      # Adjust the layout once, after every subplot and its title exist.
      plt.tight_layout()
      plt.show()
  def one_hot(label, depth=10):
      """Convert a 1-D tensor of class indices into one-hot rows.

      Args:
          label: tensor of integer class indices with ``label.size(0)``
              entries; any integer dtype is accepted.
          depth: number of classes, i.e. the width of each one-hot row.

      Returns:
          Float tensor of shape ``[label.size(0), depth]`` on the same device
          as ``label``, with a 1 at each sample's class index and 0 elsewhere.
      """
      # Allocate on the label's device so scatter_ never mixes devices.
      out = torch.zeros(label.size(0), depth, device=label.device)
      # .long() converts any integer dtype to the int64 index scatter_ needs;
      # the legacy torch.LongTensor(label) rejected non-int64 label tensors.
      idx = label.long().reshape(-1, 1)
      out.scatter_(dim=1, index=idx, value=1)
      return out

 

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号