
Training VGGNet on the CIFAR10 Dataset


VGG Network Structure

The basic building block (vgg_block) is a few convolutional layers followed by a max-pooling layer. Because this block structure repeats throughout the network, the stacking of blocks is wrapped in a second function (vgg_stack).
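For instance, vgg_block(2, 64, 128) from the code below is equivalent to writing the layers out by hand (a minimal sketch for illustration):

from torch import nn

# Two 3x3 conv + ReLU pairs that keep the spatial size (padding=1),
# followed by a 2x2 max-pool that halves it.
block = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
    nn.ReLU(True),
    nn.Conv2d(128, 128, kernel_size=3, padding=1),
    nn.ReLU(True),
    nn.MaxPool2d(2, 2),
)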

import numpy as np
import torch
from torch import nn
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.nn.functional as F
from utils import train

def vgg_block(num_convs, in_channels, out_channels):
    # first convolutional layer of the block
    net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
           nn.ReLU(True)]
    # the remaining num_convs - 1 convolutional layers
    for i in range(num_convs - 1):
        net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        net.append(nn.ReLU(True))
    net.append(nn.MaxPool2d(2, 2))  # pooling layer
    return nn.Sequential(*net)

# num_convs gives the number of convolutional layers in each vgg_block;
# the input/output channel counts are stored in channels
def vgg_stack(num_convs, channels):
    net = []
    for n, c in zip(num_convs, channels):
        in_c = c[0]
        out_c = c[1]
        net.append(vgg_block(n, in_c, out_c))
    return nn.Sequential(*net)

# Define a VGG structure with five vgg_blocks containing 1, 1, 2, 2, 2 conv layers.
# The input/output channel counts of each block are given below.
vgg_net = vgg_stack((1, 1, 2, 2, 2),
                    ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
# print(vgg_net)

class vgg(nn.Module):
    def __init__(self):
        super(vgg, self).__init__()
        self.feature = vgg_net
        self.fc = nn.Sequential(
            nn.Linear(512, 100),
            nn.ReLU(True),
            nn.Linear(100, 10)
        )

    def forward(self, x):
        x = self.feature(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

def data_tf(x):
    x = np.array(x, dtype='float32') / 255
    x = (x - 0.5) / 0.5         # normalize to [-1, 1]
    x = x.transpose((2, 0, 1))  # move channels to the first dimension, as PyTorch expects
    x = torch.from_numpy(x)
    return x

train_set = CIFAR10('./data', train=True, transform=data_tf, download=False)
train_data = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_tf, download=False)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = vgg().cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

train(net, train_data, test_data, 20, optimizer, criterion)
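As a quick sanity check of the architecture (a minimal sketch, assuming the definitions above are already in scope): each of the five vgg_blocks ends in MaxPool2d(2, 2), so a 32×32 CIFAR10 image is halved five times down to 1×1, which is why the flattened feature vector has exactly 512 entries, matching nn.Linear(512, 100).

import torch

x = torch.zeros(1, 3, 32, 32)  # one dummy CIFAR10-sized image
feat = vgg_net(x)              # spatial size: 32 -> 16 -> 8 -> 4 -> 2 -> 1
print(feat.shape)              # torch.Size([1, 512, 1, 1])
print(vgg()(x).shape)          # torch.Size([1, 10]), one score per class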

utils.py

from datetime import datetime
import torch
import torch.nn.functional as F
from torch import nn

def get_acc(output, label):
    # fraction of samples in the batch whose argmax prediction matches the label
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total

def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    if torch.cuda.is_available():
        net = net.cuda()
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            if torch.cuda.is_available():
                im = im.cuda()        # (bs, 3, h, w)
                label = label.cuda()  # (bs,)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc += get_acc(output, label)
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            with torch.no_grad():  # no gradients needed for validation
                for im, label in valid_data:
                    if torch.cuda.is_available():
                        im = im.cuda()
                        label = label.cuda()
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)
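To see what get_acc computes, here is a small example with made-up tensors (hypothetical values, not from the post):

import torch

# Two samples, three classes: row 0 predicts class 0 (correct),
# row 1 predicts class 1 while the label is 2 (wrong), so accuracy is 0.5.
logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.3, 1.5,  0.2]])
labels = torch.tensor([0, 2])
print(get_acc(logits, labels))  # 0.5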

 

The code above uses a batch_size of 128; one training pass over the dataset takes a bit more than three minutes. With a batch_size of 64, it takes a bit more than five minutes.

Training results:

By the eighth epoch, accuracy reached 99.1% on the training set and 98.9% on the test set.