赞
踩
- import torch.nn as nn
-
-
- # 搭建CNN模型
- class CNN(nn.Module):
- def __init__(self):
- super(CNN, self).__init__()
- # 1号网络
- self.conv1 = nn.Sequential(
- # 卷积参数设置
- nn.Conv2d(
- in_channels=1, # 输入数据的通道为1,即卷积核通道数为1
- out_channels=16, # 输出通道数为16,即卷积核个数为16
- kernel_size=5, # 卷积核的尺寸为 5*5
- stride=1, # 卷积核的滑动步长为1
- padding=2, # 边缘零填充为2
- ),
- nn.ReLU(), # 激活函数为Relu
- nn.MaxPool2d(kernel_size=2), # 最大池化 2*2
- )
- self.conv2 = nn.Sequential(
- nn.Conv2d(16, 32, 5, 1, 2),
- nn.ReLU(),
- nn.MaxPool2d(2),
- )
- # 全连接层
- self.fc1 = nn.Sequential(
- nn.Linear(32 * 7 * 7, 120),
- nn.ReLU(),
- )
- self.fc2 = nn.Sequential(
- nn.Linear(120, 84),
- nn.ReLU(),
- )
- self.fc3 = nn.Linear(84, 10) # 最后输出结果个数为10,因为数字的类别为10
-
- # 前向传播
- def forward(self, x):
- x = self.conv1(x)
- x = self.conv2(x)
- # nn.Linear的输入输出都是一维的,所以对参数实现扁平化成一维,便于全连接层接入
- x = x.view(x.size(0), -1)
- x = self.fc1(x)
- x = self.fc2(x)
- x = self.fc3(x)
- return x
-
-
- # 将CNN网络类实例化
- model = CNN()

Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。