
PyTorch: Recognizing Handwritten Digits with a CNN (Classification)

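The test split is loaded without a transform, so the raw images come out of torchvision as an (N, 28, 28) uint8 tensor. torch.unsqueeze adds a channel dimension at dim=1, and dividing by 255 rescales the pixel values into [0, 1], giving the (N, 1, 28, 28) float input that Conv2d expects. A quick stand-alone illustration of that reshaping, using a random dummy tensor instead of the real MNIST download:

import torch

# dummy stand-in for test_data.test_data: 2000 grayscale 28x28 images stored as uint8
fake_images = torch.randint(0, 256, (2000, 28, 28), dtype=torch.uint8)

# insert a channel axis at dim=1 and scale pixel values from [0, 255] to [0, 1]
test_x = torch.unsqueeze(fake_images, dim=1).type(torch.FloatTensor) / 255.

print(fake_images.shape)  # torch.Size([2000, 28, 28])
print(test_x.shape)       # torch.Size([2000, 1, 28, 28])  (batch, channel, height, width)

The full training script follows.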
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

EPOCH = 1               # how many passes over the whole training set
BATCH_SIZE = 50         # samples per mini-batch
LR = 0.001              # learning rate
DOWNLOAD_MNIST = False  # whether to download MNIST: True on the first run, False afterwards

# MNIST handwritten digits
# training set
train_data = torchvision.datasets.MNIST(
    root='./mnist',                               # where to save / load the data
    train=True,                                   # this is the training data
    transform=torchvision.transforms.ToTensor(),  # convert PIL images to tensors in [0, 1]
    download=DOWNLOAD_MNIST
)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# # inspect the dataset
# print(train_data.train_data.size())
# print(train_data.train_labels.size())
# plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
# plt.title('%i' % train_data.train_labels[0])
# plt.show()

# test set
test_data = torchvision.datasets.MNIST(
    root='./mnist',  # where to save / load the data
    train=False,     # this is the test data
)
# shape from (2000, 28, 28) to (2000, 1, 28, 28), pixel values scaled into range (0, 1);
# newer torchvision versions expose these tensors as test_data.data / test_data.targets
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000] / 255.
test_y = test_data.test_labels[:2000]

# build the CNN
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(       # input shape (1, 28, 28)
            nn.Conv2d(                    # convolution: a bank of filters
                in_channels=1,            # 3 for RGB, 1 for grayscale
                out_channels=16,          # number of filters
                kernel_size=5,            # filter size
                stride=1,                 # stride
                padding=2,                # zero padding: 28 = (28 + 2*2 - 5)/1 + 1
            ),                            # -> (16, 28, 28)
            nn.ReLU(),                    # activation
            nn.MaxPool2d(kernel_size=2),  # pooling -> (16, 14, 14)
        )
        self.conv2 = nn.Sequential(       # input shape (16, 14, 14)
            nn.Conv2d(                    # convolution: a bank of filters
                in_channels=16,           # matches conv1's out_channels
                out_channels=32,          # number of filters
                kernel_size=5,            # filter size
                stride=1,                 # stride
                padding=2,                # zero padding: 14 = (14 + 2*2 - 5)/1 + 1
            ),                            # -> (32, 14, 14)
            nn.ReLU(),                    # activation
            nn.MaxPool2d(kernel_size=2),  # pooling -> (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)  # 10 classes: digits 0-9

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)          # (batch_size, 32, 7, 7)
        x = x.view(x.size(0), -1)  # flatten the feature maps to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output

cnn = CNN()
# print(cnn)

# training
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        output = cnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            test_output = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1].data.numpy()
            accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)

test_output = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')
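The shape comments inside conv1 and conv2 follow the usual convolution output-size formula out = (in + 2*padding - kernel_size)/stride + 1, for example (28 + 2*2 - 5)/1 + 1 = 28, and each MaxPool2d(kernel_size=2) then halves the spatial size. To confirm the (16, 14, 14) and (32, 7, 7) feature maps that justify the 32 * 7 * 7 input of the linear layer, you can push a dummy batch through the layers (a small sketch, assuming the cnn instance defined above):

dummy = torch.zeros(1, 1, 28, 28)          # one fake grayscale 28x28 image
print(cnn.conv1(dummy).shape)              # torch.Size([1, 16, 14, 14])
print(cnn.conv2(cnn.conv1(dummy)).shape)   # torch.Size([1, 32, 7, 7]) -> flattened to 32*7*7 = 1568
print(cnn(dummy).shape)                    # torch.Size([1, 10]), one score per digit class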

Run results:

Epoch: 0 | train loss: 2.2940 | test accuracy: 0.10
Epoch: 0 | train loss: 0.4439 | test accuracy: 0.82
Epoch: 0 | train loss: 0.3219 | test accuracy: 0.90
Epoch: 0 | train loss: 0.3983 | test accuracy: 0.92
Epoch: 0 | train loss: 0.3778 | test accuracy: 0.93
Epoch: 0 | train loss: 0.0972 | test accuracy: 0.94
Epoch: 0 | train loss: 0.1544 | test accuracy: 0.95
Epoch: 0 | train loss: 0.0502 | test accuracy: 0.95
Epoch: 0 | train loss: 0.1359 | test accuracy: 0.96
Epoch: 0 | train loss: 0.0192 | test accuracy: 0.97
Epoch: 0 | train loss: 0.2811 | test accuracy: 0.97
Epoch: 0 | train loss: 0.1555 | test accuracy: 0.96
Epoch: 0 | train loss: 0.0601 | test accuracy: 0.96
Epoch: 0 | train loss: 0.0407 | test accuracy: 0.97
Epoch: 0 | train loss: 0.2625 | test accuracy: 0.97
Epoch: 0 | train loss: 0.0915 | test accuracy: 0.98
Epoch: 0 | train loss: 0.1415 | test accuracy: 0.98
Epoch: 0 | train loss: 0.1100 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0870 | test accuracy: 0.97
Epoch: 0 | train loss: 0.0553 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0267 | test accuracy: 0.97
Epoch: 0 | train loss: 0.0187 | test accuracy: 0.97
Epoch: 0 | train loss: 0.1733 | test accuracy: 0.97
Epoch: 0 | train loss: 0.0438 | test accuracy: 0.98
[7 2 1 0 4 1 4 9 5 9] prediction number
[7 2 1 0 4 1 4 9 5 9] real number
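The last two lines show that all ten sample predictions match the ground-truth labels. The predicted digit is simply the index of the largest of the 10 output scores, which is what torch.max(test_output, 1)[1] extracts; the same thing can be written with torch.argmax (a minimal sketch, reusing the trained cnn and test_x from the script above):

with torch.no_grad():
    scores = cnn(test_x[:10])              # (10, 10) raw class scores, one row per image
    pred_y = torch.argmax(scores, dim=1)   # index of the largest score = predicted digit
print(pred_y.numpy(), 'prediction number')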
