当前位置:   article > 正文

Pytorch:搭建CNN模型的基本架构_torch cnn

torch cnn
  1. import torch.nn as nn
  2. # 搭建CNN模型
  3. class CNN(nn.Module):
  4. def __init__(self):
  5. super(CNN, self).__init__()
  6. # 1号网络
  7. self.conv1 = nn.Sequential(
  8. # 卷积参数设置
  9. nn.Conv2d(
  10. in_channels=1, # 输入数据的通道为1,即卷积核通道数为1
  11. out_channels=16, # 输出通道数为16,即卷积核个数为16
  12. kernel_size=5, # 卷积核的尺寸为 5*5
  13. stride=1, # 卷积核的滑动步长为1
  14. padding=2, # 边缘零填充为2
  15. ),
  16. nn.ReLU(), # 激活函数为Relu
  17. nn.MaxPool2d(kernel_size=2), # 最大池化 2*2
  18. )
  19. self.conv2 = nn.Sequential(
  20. nn.Conv2d(16, 32, 5, 1, 2),
  21. nn.ReLU(),
  22. nn.MaxPool2d(2),
  23. )
  24. # 全连接层
  25. self.fc1 = nn.Sequential(
  26. nn.Linear(32 * 7 * 7, 120),
  27. nn.ReLU(),
  28. )
  29. self.fc2 = nn.Sequential(
  30. nn.Linear(120, 84),
  31. nn.ReLU(),
  32. )
  33. self.fc3 = nn.Linear(84, 10) # 最后输出结果个数为10,因为数字的类别为10
  34. # 前向传播
  35. def forward(self, x):
  36. x = self.conv1(x)
  37. x = self.conv2(x)
  38. # nn.Linear的输入输出都是一维的,所以对参数实现扁平化成一维,便于全连接层接入
  39. x = x.view(x.size(0), -1)
  40. x = self.fc1(x)
  41. x = self.fc2(x)
  42. x = self.fc3(x)
  43. return x
  44. # 将CNN网络类实例化
  45. model = CNN()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/717042
推荐阅读
相关标签
  

闽ICP备14008679号