赞
踩
网络模型为3层(含输入层):
输入(28×28)784个特征单元(神经元);
隐藏层:256个单元;
输出层:10 个单元(例如 softmax 的 10 分类)。
import torch from torch import nn from torch.nn import init import torchvision import torchvision.transforms as transforms import sys import time class FlattenLayer(torch.nn.Module): def __init__(self): super(FlattenLayer, self).__init__() def forward(self, x): # x shape: (batch, *, *, ...) return x.view(x.shape[0], -1) def load_data_fashion_mnist(batch_size, root='Datasets/FashionMNIST'): mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=False, transform=transforms.ToTensor()) mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=False, transform=transforms.ToTensor()) if sys.platform.startswith('win'): num_workers = 0 # 0表示不用额外的进程来加速读取数据 else: num_workers = 4 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers) test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers) return train_iter, test_iter # 评估 def evaluate_accuracy(data_iter, net): acc_sum, n = 0.0, 0 for X, y in data_iter: acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() n += y.shape[0] return acc_sum / n # 训练,随机梯度下降 def sgd(params, lr, batch_size): for param in params: param.data -= lr * param.grad / batch_size # 注意这里更改param时用的param.data # 训练 def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None): for epoch in range(num_epochs): train_l_sum, train_acc_sum, n = 0.0, 0.0, 0 for X, y in train_iter: y_hat = net(X) l = loss(y_hat, y).sum() # 梯度清零 optimizer.zero_grad() l.backward() optimizer.step() train_l_sum += l.item() train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item() n += y.shape[0] test_acc = evaluate_accuracy(test_iter, net) print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f' % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc)) if __name__ == '__main__': num_inputs, num_outputs, num_hiddens = 784, 10, 256 # 定义网络: 输入层(784) --> 隐藏层(256) --> 
输出层(10) net = nn.Sequential( FlattenLayer(), nn.Linear(num_inputs, num_hiddens), nn.ReLU(), nn.Linear(num_hiddens, num_outputs), ) # print('len: ', len(list(net.parameters()))) # print('param: ', net.parameters()) # print('param list: ', list(net.parameters())) # 初始化网络参数 for params in net.parameters(): # net.parameters() 为各层的网络参数,可迭代 init.normal_(params, mean=0, std=0.01) print('len: ', len(list(net.parameters()))) print('param init: ', net.parameters()) # <generator object Module.parameters at 0x000001A0B1356620> print('param list init: ', list(net.parameters())) # 加载数据 batch_size = 256 train_iter, test_iter = load_data_fashion_mnist(batch_size) loss = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), lr=0.5) # 训练 num_epochs = 5 train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)
① net.parameters()
net.parameters() 为:网络定义后,网络模型中的各层的参数 ( <generator object Module.parameters at 0x000001A0B1356620> )
可以转化为一个list查看其参数:list(net.parameters())
②torch.nn.init.normal_(params, mean=0, std=0.01)
对参数params进行均值为0、标准差为0.01的正态分布初始化(注意 std 是标准差,不是方差)
len: 4 param init: <generator object Module.parameters at 0x000001A0B1356620> param list init: [Parameter containing: tensor([[ 0.0096, 0.0091, 0.0036, ..., -0.0016, -0.0025, 0.0082], [ 0.0014, -0.0060, -0.0170, ..., -0.0087, -0.0124, 0.0029], [-0.0002, -0.0053, -0.0205, ..., 0.0140, 0.0015, -0.0047], ..., [-0.0067, -0.0192, -0.0026, ..., 0.0223, 0.0114, 0.0003], [ 0.0033, 0.0095, 0.0097, ..., 0.0015, 0.0030, -0.0138], [ 0.0139, -0.0073, 0.0012, ..., -0.0143, 0.0085, 0.0056]], requires_grad=True), Parameter containing: tensor([-3.3042e-03, 4.7892e-04, -5.6590e-03, 3.2377e-03, 3.5846e-03, 6.6989e-03, 2.3601e-03, -9.1927e-04, 2.4281e-02, -2.2155e-02, 5.3163e-03, 3.7543e-03, 4.3089e-03, 1.1811e-02, -4.1673e-03, -1.9667e-02, 2.6118e-03, -3.2978e-03, 1.3942e-02, -1.6289e-02, 9.8179e-03, -2.2531e-02, -1.4156e-02, -1.4382e-03, 1.7384e-02, -1.2549e-02, 9.4562e-03, -9.0459e-03, 8.4983e-03, -1.5124e-03, -1.4963e-02, -7.0390e-03, 1.0951e-02, -1.6487e-02, -5.2332e-03, -5.2680e-03, -1.7785e-03, -1.3423e-03, 1.9302e-03, -4.9111e-03, 1.7328e-03, -8.0625e-03, -3.9449e-03, 3.6381e-03, 1.1906e-02, 5.0710e-03, 5.1031e-03, 9.2445e-04, 2.6244e-02, -2.9451e-03, 9.6235e-03, -2.1532e-03, -1.3756e-02, -2.1489e-03, -1.3318e-02, 4.8365e-03, -1.0427e-02, 5.2636e-03, 8.1710e-03, -2.8734e-03, -4.0999e-03, -3.3395e-03, 9.2141e-03, 1.8420e-02, 2.3903e-03, 6.3389e-03, -7.1875e-03, 9.3982e-03, -1.6983e-02, -1.9021e-03, -6.3871e-03, 6.7952e-03, -1.2235e-02, -1.6785e-02, -6.6447e-03, 1.2196e-02, 7.3601e-03, -1.5027e-02, -2.6593e-03, -9.6182e-03, -8.4485e-03, 2.2411e-02, -7.5373e-03, 3.6415e-02, 2.6785e-03, 1.9647e-02, -1.4472e-03, -2.1426e-03, -1.0003e-02, -6.0945e-03, 6.1464e-04, 6.1757e-03, 1.2456e-02, 1.0664e-02, 8.7811e-03, -1.9107e-02, -8.5125e-03, -3.2865e-04, 1.0192e-02, -2.4412e-02, -2.1226e-02, 1.0242e-02, 4.0445e-03, -3.3238e-03, 4.4551e-04, 1.7880e-02, 1.4732e-02, 7.4244e-04, 1.5565e-02, 6.3838e-03, 4.2519e-03, 3.7454e-04, 6.0372e-03, 1.0598e-02, 6.6352e-03, 9.3732e-03, 7.1993e-03, 
-8.0230e-03, -2.0376e-02, 1.7323e-03, 1.5667e-02, -1.0637e-02, -1.9101e-02, -8.6477e-03, 4.6590e-03, -4.7290e-03, 1.2458e-02, 1.0215e-02, 1.4719e-02, -3.4490e-03, -4.6496e-03, 6.5331e-03, -3.9560e-03, -1.1488e-02, -8.5887e-03, 1.5083e-02, 1.0957e-02, 1.9015e-02, -2.1299e-03, -8.0287e-03, -1.4993e-02, -1.1674e-02, 7.0364e-03, -2.5001e-03, -1.0356e-03, 5.7498e-03, 5.7233e-04, 7.9161e-04, -6.0469e-03, -2.6913e-03, 6.7641e-03, 1.8129e-03, 1.5494e-03, -9.7351e-03, 6.8967e-05, 2.2971e-03, -9.1847e-03, -2.3717e-03, -6.4801e-03, 2.9549e-03, -7.2387e-03, -1.6071e-02, -1.1841e-02, -4.3262e-03, -7.4287e-04, -1.0381e-02, -1.9941e-02, 1.2515e-02, 1.1387e-02, -3.3133e-03, 1.3639e-02, -1.9078e-03, -1.5026e-02, 3.7264e-03, 1.2014e-02, -8.0367e-03, -3.5969e-02, 6.3780e-03, 3.4895e-03, 1.5735e-02, -5.6254e-04, -5.5807e-03, 5.4600e-04, -8.7495e-04, 7.8439e-03, -1.2823e-02, -1.4356e-02, 7.8702e-03, 4.3848e-04, 5.3145e-03, -6.1489e-03, 8.7027e-04, -1.0802e-03, 7.2241e-03, 5.0439e-03, 1.3031e-02, 7.4891e-03, -7.3666e-03, -6.0929e-03, -6.1948e-03, 8.1562e-03, -6.0273e-03, -1.0222e-02, -1.7376e-03, -1.2922e-02, 1.1247e-02, -1.0559e-02, -1.5887e-02, 1.0038e-02, -1.4515e-02, -9.5886e-03, 1.2830e-02, 8.8126e-03, -9.1111e-03, 6.2043e-03, 1.9829e-02, 1.5241e-02, 2.2486e-03, 9.0140e-03, 1.7259e-02, -5.6758e-03, 4.1752e-03, 4.8623e-04, 1.9457e-02, 8.3239e-03, -1.1590e-02, -5.5052e-03, -2.0561e-02, 2.8499e-03, 1.1046e-02, -7.4051e-03, 1.1231e-02, 1.4840e-02, 4.9973e-03, 1.3801e-02, -1.4826e-02, -7.4246e-03, -1.5146e-02, 1.2617e-02, 7.5188e-03, 1.9418e-02, -1.0118e-03, -8.8281e-03, -5.6416e-03, 1.8890e-04, -3.9850e-03, -4.7776e-03, 9.0903e-03, -3.2510e-02, 3.5589e-03, 3.9693e-03, 1.9995e-02, 2.7695e-03, 9.5730e-03, -9.2412e-03, 1.0012e-02], requires_grad=True), Parameter containing: tensor([[-0.0074, -0.0100, 0.0084, ..., 0.0004, -0.0123, -0.0015], [-0.0017, 0.0017, 0.0104, ..., -0.0067, -0.0016, -0.0096], [-0.0132, -0.0034, 0.0193, ..., 0.0191, 0.0004, 0.0105], ..., [ 0.0167, -0.0144, 0.0048, 
..., 0.0061, -0.0083, 0.0072], [ 0.0018, 0.0048, 0.0050, ..., 0.0015, -0.0165, 0.0046], [-0.0039, 0.0027, -0.0014, ..., 0.0078, -0.0054, -0.0089]], requires_grad=True), Parameter containing: tensor([-0.0058, -0.0037, 0.0099, -0.0099, 0.0018, -0.0193, -0.0041, 0.0043, -0.0114, -0.0049], requires_grad=True)] epoch 1, loss 0.0031, train acc 0.699, test acc 0.763 epoch 2, loss 0.0019, train acc 0.817, test acc 0.785 epoch 3, loss 0.0017, train acc 0.844, test acc 0.844 epoch 4, loss 0.0015, train acc 0.856, test acc 0.798 epoch 5, loss 0.0014, train acc 0.865, test acc 0.845 Process finished with exit code 0
参考:
https://pytorch.org/docs/stable/nn.html#
https://pytorch.org/docs/stable/nn.init.html (torch.nn.init)
https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.10_mlp-pytorch
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。