VGG Network Structure
The basic unit (vgg_block) is a few convolutional layers followed by one pooling layer. This unit repeats throughout the network, so we wrap the repetition in a function (vgg_stack).
```python
import numpy as np
import torch
from torch import nn
from torchvision.datasets import CIFAR10

from utils import train


def vgg_block(num_convs, in_channels, out_channels):
    # first convolutional layer
    net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
           nn.ReLU(True)]

    for i in range(num_convs - 1):  # the remaining num_convs-1 convolutional layers
        net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        net.append(nn.ReLU(True))

    net.append(nn.MaxPool2d(2, 2))  # pooling layer
    return nn.Sequential(*net)
```
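Each block halves the spatial resolution through its final 2×2 max-pool, while the convolutions change only the channel count. A quick sanity check (the shapes here are just for illustration):

```python
block = vgg_block(2, 64, 128)
x = torch.zeros(1, 64, 16, 16)  # a dummy batch: 1 image, 64 channels, 16x16
print(block(x).shape)           # torch.Size([1, 128, 8, 8])
```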
```python
# num_convs gives the number of convolutional layers in each vgg_block;
# the input and output channel counts are stored in channels
def vgg_stack(num_convs, channels):
    net = []
    for n, c in zip(num_convs, channels):
        in_c = c[0]
        out_c = c[1]
        net.append(vgg_block(n, in_c, out_c))
    return nn.Sequential(*net)


# Define a VGG structure with five blocks, whose conv-layer counts are 1, 1, 2, 2, 2.
# The input/output channel counts of each block are listed below.
vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
# print(vgg_net)
```
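Each of the five blocks ends in a 2×2 max-pool, so a 32×32 CIFAR-10 image is halved five times: 32 → 16 → 8 → 4 → 2 → 1. The final feature map is therefore 512×1×1, which is why the classifier below can start with nn.Linear(512, 100). A quick check:

```python
out = vgg_net(torch.zeros(1, 3, 32, 32))  # one dummy CIFAR-10-sized image
print(out.shape)                          # torch.Size([1, 512, 1, 1])
```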
```python
class vgg(nn.Module):
    def __init__(self):
        super(vgg, self).__init__()
        self.feature = vgg_net
        self.fc = nn.Sequential(
            nn.Linear(512, 100),
            nn.ReLU(True),
            nn.Linear(100, 10)
        )

    def forward(self, x):
        x = self.feature(x)
        x = x.view(x.shape[0], -1)  # flatten the 512x1x1 feature map to (batch_size, 512)
        x = self.fc(x)
        return x
```
```python
def data_tf(x):
    x = np.array(x, dtype='float32') / 255  # scale pixel values to [0, 1]
    x = (x - 0.5) / 0.5                     # normalize to [-1, 1]
    x = x.transpose((2, 0, 1))              # move channel to the first dim, the layout PyTorch expects
    x = torch.from_numpy(x)
    return x
```
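The hand-written data_tf is equivalent to the standard torchvision pipeline. As a sketch, the same preprocessing with transforms.Compose would look like this (data_tf_equiv is a hypothetical name):

```python
from torchvision import transforms

# ToTensor converts an HWC PIL image to a CHW float tensor in [0, 1];
# Normalize then maps each channel through (x - 0.5) / 0.5 into [-1, 1]
data_tf_equiv = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
```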
```python
# assumes CIFAR-10 has already been downloaded to ./data;
# set download=True to fetch it on the first run
train_set = CIFAR10('./data', train=True, transform=data_tf, download=False)
train_data = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_tf, download=False)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = vgg().cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

train(net, train_data, test_data, 20, optimizer, criterion)
```

utils.py
```python
from datetime import datetime

import torch


def get_acc(output, label):
    total = output.shape[0]
    _, pred_label = output.max(1)
    # .item() extracts a Python number (the old .data[0] idiom no longer works)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    if torch.cuda.is_available():
        net = net.cuda()
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            if torch.cuda.is_available():
                im = im.cuda()        # (bs, 3, h, w)
                label = label.cuda()  # (bs,)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            # torch.no_grad() replaces the deprecated volatile=True flag
            with torch.no_grad():
                for im, label in valid_data:
                    if torch.cuda.is_available():
                        im = im.cuda()
                        label = label.cuda()
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)
```

The code above uses a batch_size of 128, with which one pass over the dataset takes a bit over three minutes; with a batch_size of 64 it takes a bit over five minutes.
Training results:
By the eighth epoch, training accuracy had reached 99.1% and test accuracy 98.9%.