当前位置:   article > 正文

Unet代码实现(PyTorch)_unet代码pytorch

unet代码pytorch
  1. import torch
  2. import torch.nn as nn
  3. from torch.nn import functional as F
  4. # 基本卷积块, 长宽不变,in_channels -> out_channels
  5. class Conv(nn.Module):
  6. def __init__(self, C_in, C_out):
  7. super(Conv, self).__init__()
  8. self.layer = nn.Sequential(
  9. nn.Conv2d(C_in, C_out, 3, 1, 1),
  10. nn.BatchNorm2d(C_out),
  11. # 防止过拟合
  12. nn.Dropout(0.3),
  13. nn.LeakyReLU(),
  14. nn.Conv2d(C_out, C_out, 3, 1, 1),
  15. nn.BatchNorm2d(C_out),
  16. # 防止过拟合
  17. nn.Dropout(0.4),
  18. nn.LeakyReLU(),
  19. )
  20. def forward(self, x):
  21. return self.layer(x)
  22. # 下采样模块,长宽下采样2倍,通道数不变
  23. class DownSampling(nn.Module):
  24. def __init__(self, C):
  25. super(DownSampling, self).__init__()
  26. self.Down = nn.Sequential(
  27. # 使用卷积进行2倍的下采样,通道数不变
  28. nn.Conv2d(C, C, 3, 2, 1),
  29. nn.LeakyReLU()
  30. )
  31. def forward(self, x):
  32. return self.Down(x)
  33. # 上采样模块,长宽扩大2倍,通道数减少2
  34. class UpSampling(nn.Module):
  35. def __init__(self, C):
  36. super(UpSampling, self).__init__()
  37. self.Up = nn.Conv2d(C, C // 2, 1, 1)
  38. def forward(self, x, r):
  39. # 双线性插值进行上采样
  40. up = F.interpolate(x, scale_factor=2, mode="bilinear")
  41. x = self.Up(up)
  42. # 拼接,当前上采样的,和之前下采样过程中的
  43. return torch.cat((x, r), 1)
  44. # 主干网络
  45. class UNet(nn.Module):
  46. def __init__(self):
  47. super(UNet, self).__init__()
  48. # 4次下采样
  49. self.C1 = Conv(3, 64)
  50. self.D1 = DownSampling(64)
  51. self.C2 = Conv(64, 128)
  52. self.D2 = DownSampling(128)
  53. self.C3 = Conv(128, 256)
  54. self.D3 = DownSampling(256)
  55. self.C4 = Conv(256, 512)
  56. self.D4 = DownSampling(512)
  57. self.C5 = Conv(512, 1024)
  58. # 4次上采样
  59. self.U1 = UpSampling(1024)
  60. self.C6 = Conv(1024, 512)
  61. self.U2 = UpSampling(512)
  62. self.C7 = Conv(512, 256)
  63. self.U3 = UpSampling(256)
  64. self.C8 = Conv(256, 128)
  65. self.U4 = UpSampling(128)
  66. self.C9 = Conv(128, 64)
  67. self.Th = torch.nn.Sigmoid()
  68. self.pred = torch.nn.Conv2d(64, 3, 3, 1, 1)
  69. def forward(self, x):
  70. # 下采样部分
  71. R1 = self.C1(x)
  72. R2 = self.C2(self.D1(R1))
  73. R3 = self.C3(self.D2(R2))
  74. R4 = self.C4(self.D3(R3))
  75. Y1 = self.C5(self.D4(R4))
  76. # 上采样部分
  77. # 上采样的时候需要拼接起来
  78. O1 = self.C6(self.U1(Y1, R4))
  79. O2 = self.C7(self.U2(O1, R3))
  80. O3 = self.C8(self.U3(O2, R2))
  81. O4 = self.C9(self.U4(O3, R1))
  82. # 输出预测,这里大小跟输入是一致的
  83. # 可以把下采样时的中间抠出来再进行拼接,这样修改后输出就会更小
  84. return self.Th(self.pred(O4))
  85. if __name__ == '__main__':
  86. a = torch.randn(2, 3, 256, 256)
  87. net = UNet()
  88. print(net(a).shape)

参考UNet详解(附图文和代码实现)_liiiiiiiiiiiiike的博客-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/351107
推荐阅读
相关标签
  

闽ICP备14008679号