赞
踩
- import torch
- import torch.nn as nn
- from torch.autograd import Variable
- import torch.utils.data as Data
- import torchvision
- import matplotlib.pyplot as plt
-
# Hyperparameters.
EPOCH = 1               # number of full passes over the training set
BATCH_SIZE = 50         # samples per mini-batch
LR = 0.001              # learning rate
DOWNLOAD_MNIST = False  # set to True on the first run to fetch MNIST, then back to False

# MNIST handwritten digits.
# Training set.
train_data = torchvision.datasets.MNIST(
    root='./mnist',     # directory the dataset is stored in / loaded from
    train=True,         # training split
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# Optional preview of the dataset:
# print(train_data.data.size())
# print(train_data.targets.size())
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# plt.title('%i' % train_data.targets[0])
# plt.show()

# Test set.
test_data = torchvision.datasets.MNIST(
    root='./mnist',     # directory the dataset is stored in / loaded from
    train=False,        # test split
)
# `data` / `targets` are the current attribute names; `test_data` / `test_labels`
# are deprecated aliases. Only the first 2000 samples are kept, and they are
# sliced before the float conversion so the rest of the split is never converted.
# Shape goes from (2000, 28, 28) to (2000, 1, 28, 28); dividing by 255 scales the
# raw pixel values into [0, 1].
test_x = torch.unsqueeze(test_data.data[:2000], dim=1).type(torch.FloatTensor) / 255.
test_y = test_data.targets[:2000]
-
-
- # 建立CNN网络
class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Two conv -> ReLU -> max-pool stages followed by one linear layer that
    produces 10 scores (digits 0-9). The output is raw scores with no softmax
    applied.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: (1, 28, 28) -> conv -> (16, 28, 28) -> pool -> (16, 14, 14)
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,      # grayscale input (RGB would be 3)
                out_channels=16,    # number of filters
                kernel_size=5,
                stride=1,
                padding=2,          # preserves size: 28 = (28 + 2*2 - 5)/1 + 1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Stage 2: (16, 14, 14) -> conv -> (32, 14, 14) -> pool -> (32, 7, 7)
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,     # channels produced by stage 1
                out_channels=32,    # number of filters
                kernel_size=5,
                stride=1,
                padding=2,          # preserves size: 14 = (14 + 2*2 - 5)/1 + 1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Classifier over the flattened (32, 7, 7) feature map; one score per digit.
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        """Return class scores of shape (batch_size, 10) for input (batch_size, 1, 28, 28)."""
        x = self.conv1(x)
        x = self.conv2(x)               # (batch_size, 32, 7, 7)
        x = x.view(x.size(0), -1)       # flatten to (batch_size, 32 * 7 * 7)
        return self.out(x)
-
-
cnn = CNN()
# print(cnn)  # uncomment to inspect the architecture

# Training setup: Adam over all model parameters, cross-entropy on the raw scores.
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()

for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        output = cnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()   # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()

        # Report progress every 50 mini-batches.
        if step % 50 == 0:
            # No gradients are needed for evaluation; skipping autograd avoids
            # building a graph over the whole held-out batch.
            with torch.no_grad():
                test_output = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1]   # index of the highest score per sample
            accuracy = (pred_y == test_y).sum().item() / float(test_y.size(0))
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)

# Compare predictions with the true labels for the first 10 test samples.
with torch.no_grad():
    test_output = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].numpy()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')
运行结果:
- Epoch: 0 | train loss: 2.2940 | test accuracy: 0.10
- Epoch: 0 | train loss: 0.4439 | test accuracy: 0.82
- Epoch: 0 | train loss: 0.3219 | test accuracy: 0.90
- Epoch: 0 | train loss: 0.3983 | test accuracy: 0.92
- Epoch: 0 | train loss: 0.3778 | test accuracy: 0.93
- Epoch: 0 | train loss: 0.0972 | test accuracy: 0.94
- Epoch: 0 | train loss: 0.1544 | test accuracy: 0.95
- Epoch: 0 | train loss: 0.0502 | test accuracy: 0.95
- Epoch: 0 | train loss: 0.1359 | test accuracy: 0.96
- Epoch: 0 | train loss: 0.0192 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.2811 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.1555 | test accuracy: 0.96
- Epoch: 0 | train loss: 0.0601 | test accuracy: 0.96
- Epoch: 0 | train loss: 0.0407 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.2625 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.0915 | test accuracy: 0.98
- Epoch: 0 | train loss: 0.1415 | test accuracy: 0.98
- Epoch: 0 | train loss: 0.1100 | test accuracy: 0.98
- Epoch: 0 | train loss: 0.0870 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.0553 | test accuracy: 0.98
- Epoch: 0 | train loss: 0.0267 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.0187 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.1733 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.0438 | test accuracy: 0.98
- [7 2 1 0 4 1 4 9 5 9] prediction number
- [7 2 1 0 4 1 4 9 5 9] real number
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。