
PyTorch Learning Basics: RNN from Training to Testing


The previous post, PyTorch Learning Basics: LeNet from Training to Testing, walked through classifying the MNIST dataset with the simple LeNet network. As a comparison, this post uses an LSTM to classify the same MNIST dataset.

Implementation steps:

  • Import the required packages and set the hyperparameters:

    import torch
    import torchvision
    from torch import nn
    import torchvision.datasets as dsets
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt

    # define hyperparameters
    EPOCH = 1          # number of passes over the training set
    BATCH_SIZE = 64
    TIME_STEP = 28     # sequence length = image height (one row per time step)
    INPUT_SIZE = 28    # input size per time step = image width
    LR = 0.01          # learning rate for the Adam optimizer
    DOWNLOAD = True    # set to False if MNIST has already been downloaded
  • Download and load the MNIST dataset (if MNIST has already been downloaded, just set DOWNLOAD = False):

    # get the MNIST dataset
    train_data = dsets.MNIST(root='./', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD)
    test_data = dsets.MNIST(root='./', train=False, transform=torchvision.transforms.ToTensor())
    # take the first 2000 test images and scale pixel values to [0, 1]
    test_x = test_data.data.type(torch.FloatTensor)[:2000] / 255
    test_y = test_data.targets.numpy()[:2000]
    # use a DataLoader to batch the training set
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
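As a quick sanity check (my addition, not part of the original post), you can print the shapes of the prepared test split; test_x should hold 2000 images of 28x28 floats in [0, 1], and test_y a NumPy array of labels so that accuracy can later be computed with NumPy:

    # optional sanity check on the prepared test split
    print(test_x.shape)                              # expected: torch.Size([2000, 28, 28])
    print(test_x.min().item(), test_x.max().item())  # pixel values scaled to [0, 1]
    print(test_y.shape)                              # expected: (2000,)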

Note how this differs from data loading for a CNN: the LSTM reads each image as a sequence of rows, one row per time step, which is what lets a recurrent network be trained and tested on image data.
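Here is a minimal sketch (my addition) of what "rows as a sequence" means for a single image:

    # illustration only: one MNIST image viewed as a sequence of rows
    img = test_x[0]             # shape (28, 28): 28 rows of 28 pixels
    seq = img.view(1, 28, 28)   # (batch=1, time_step=28, input_size=28)
    print(seq.shape)            # torch.Size([1, 28, 28])
    print(seq[0, 0].shape)      # torch.Size([28]) -- the row fed in at the first time step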

  • Define and instantiate the LSTM network:

    # define the RNN class
    class RNN(nn.Module):
        # override the __init__() method
        def __init__(self):
            super(RNN, self).__init__()
            self.rnn = nn.LSTM(
                input_size=28,     # one image row (28 pixels) per time step
                hidden_size=64,    # size of the hidden state
                num_layers=1,
                batch_first=True,  # tensors are laid out as (batch, time_step, input_size)
            )
            self.out = nn.Linear(64, 10)  # map the last hidden state to 10 digit classes

        # override the forward() method
        def forward(self, x):
            # r_out: (batch, time_step, hidden_size); (h_n, h_c): final hidden and cell states
            r_out, (h_n, h_c) = self.rnn(x, None)  # None -> zero initial states
            # classify from the output of the last time step
            out = self.out(r_out[:, -1, :])
            return out

    rnn = RNN()
    print(rnn)
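Before training, a throwaway forward pass (my addition, not in the original post) confirms that the network maps a batch of 28-step sequences to one 10-way logit vector per image:

    # quick shape check with a dummy batch of 4 "images"
    dummy = torch.zeros(4, 28, 28)   # (batch, time_step, input_size)
    print(rnn(dummy).shape)          # expected: torch.Size([4, 10])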
  • Define the optimizer and the loss function:

    # use the Adam optimizer
    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
    # cross-entropy loss for multi-class classification
    loss_func = nn.CrossEntropyLoss()
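Note that nn.CrossEntropyLoss applies log-softmax internally, so the network's raw logits are passed in directly together with integer class labels. A tiny illustration with made-up values (my addition):

    # CrossEntropyLoss expects logits of shape (batch, classes) and integer targets
    fake_logits = torch.randn(3, 10)             # pretend outputs for a batch of 3
    fake_targets = torch.tensor([0, 7, 3])       # pretend ground-truth digits
    print(loss_func(fake_logits, fake_targets))  # a scalar loss tensor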
  • Train the model:

    # training, with a quick evaluation on the held-out test images every 50 steps
    for epoch in range(EPOCH):
        for step, (b_x, b_y) in enumerate(train_loader):
            # reshape x to (batch, time_step, input_size)
            b_x = b_x.view(-1, 28, 28)
            output = rnn(b_x)
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 50 == 0:
                # evaluate on the 2000 held-out test images
                test_output = rnn(test_x)
                # predicted class = index of the largest logit
                pred_y = torch.max(test_output, 1)[1].data.numpy()
                # accuracy against the NumPy labels
                acc = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
                print('Epoch: ', epoch, '| train loss: %.3f' % loss.data.numpy(), '| test acc: %.3f' % acc)
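The original post keeps the trained weights only in memory; if you want to test again later without retraining, a common pattern (the file name here is arbitrary) is to save and reload the state_dict:

    # optional: persist the trained weights and restore them into a fresh RNN
    torch.save(rnn.state_dict(), 'rnn_mnist.pth')
    rnn_restored = RNN()
    rnn_restored.load_state_dict(torch.load('rnn_mnist.pth'))
    rnn_restored.eval()   # switch to evaluation mode before testing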
  • Test the model:

    # print 100 predictions from the test data
    numTest = 100
    test_output = rnn(test_x[:numTest].view(-1, 28, 28))
    pred_y = torch.max(test_output, 1)[1].data.numpy()
    print(pred_y, 'prediction number')
    print(test_y[:numTest], 'real number')
    # count the misclassified samples
    ErrorCount = 0.0
    for i in range(numTest):
        if pred_y[i] != test_y[i]:
            ErrorCount += 1
    print('ErrorRate : %.3f' % (ErrorCount / numTest))
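Since matplotlib is already imported, you can also eyeball individual predictions; this small display snippet is my addition and not part of the original post:

    # show the first test image together with its predicted and true labels
    plt.imshow(test_x[0].numpy(), cmap='gray')
    plt.title('prediction: %d, ground truth: %d' % (pred_y[0], test_y[0]))
    plt.show()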

Experimental results:

As the results show, an LSTM network is not limited to sequence data such as speech; it can also perform image classification, with each "image" abstracted into a sequence of rows. The test on the MNIST dataset shows that the LSTM learns to recognize handwritten digits within a fairly short training time.
