当前位置:   article > 正文

深入浅出 diffusion(3):pytorch 实现 diffusion 中的 U-Net

深入浅出 diffusion(3):pytorch 实现 diffusion 中的 U-Net

导入python包

  1. import math
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F

 silu激活函数

  class SiLU(nn.Module):
      """SiLU activation module: multiplies the input by its own sigmoid."""

      @staticmethod
      def forward(x):
          """Return ``x * sigmoid(x)`` computed element-wise on tensor ``x``."""
          return x * torch.sigmoid(x)

归一化设置

  def get_norm(norm, num_channels, num_groups):
      """Build the normalization layer selected by ``norm``.

      Args:
          norm: ``"in"`` (instance norm with affine parameters), ``"bn"``
              (batch norm), ``"gn"`` (group norm) or ``None`` (no-op).
          num_channels: number of channels of the tensor being normalized.
          num_groups: number of groups; only used when ``norm == "gn"``.

      Returns:
          The corresponding ``nn.Module``; ``nn.Identity()`` when ``norm`` is None.

      Raises:
          ValueError: if ``norm`` is not one of the supported values.
      """
      if norm is None:
          return nn.Identity()
      if norm == "in":
          return nn.InstanceNorm2d(num_channels, affine=True)
      if norm == "bn":
          return nn.BatchNorm2d(num_channels)
      if norm == "gn":
          return nn.GroupNorm(num_groups, num_channels)
      raise ValueError("unknown normalization type")

 计算时间步(timestep)的位置嵌入:前一半维度为 sin,后一半维度为 cos

  class PositionalEmbedding(nn.Module):
      """Sinusoidal embedding of a 1-D tensor of positions (e.g. timesteps).

      The first ``dim // 2`` output features are sines and the last
      ``dim // 2`` are cosines of the scaled input multiplied by a
      geometric sequence of frequencies.

      Args:
          dim: size of the produced embedding; must be a positive even number.
          scale: factor the input is multiplied by before embedding.

      Raises:
          ValueError: if ``dim`` is not a positive even number.
      """

      def __init__(self, dim, scale=1.0):
          super().__init__()
          # Explicit check instead of ``assert`` so it survives ``python -O``
          # and so dim == 0 does not fail later with a ZeroDivisionError.
          if dim <= 0 or dim % 2 != 0:
              raise ValueError("embedding dimension must be a positive even number")
          self.dim = dim
          self.scale = scale

      def forward(self, x):
          """Embed ``x``.

          Args:
              x: 1-D tensor (``torch.outer`` requires a vector).

          Returns:
              Tensor of shape ``(len(x), dim)``: sines followed by cosines.
          """
          half_dim = self.dim // 2
          # Frequencies exp(-i * log(10000) / half_dim) for i in [0, half_dim).
          log_step = math.log(10000) / half_dim
          freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
          # Outer product: one row per position, one column per frequency.
          angles = torch.outer(x * self.scale, freqs)
          return torch.cat((angles.sin(), angles.cos()), dim=-1)

 上下采样层设置

  class Downsample(nn.Module):
      """Halve the spatial size with a 3x3 convolution of stride 2.

      The channel count is preserved.

      Args:
          in_channels: number of input (and output) channels.
      """

      def __init__(self, in_channels):
          super().__init__()
          self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)

      def forward(self, x, time_emb=None, y=None):
          """Downsample ``x``.

          ``time_emb`` and ``y`` are accepted only so this layer can be called
          with the same arguments as ``ResidualBlock``; they are ignored.

          Raises:
              ValueError: if the height (dim 2) or width (dim 3) of ``x`` is odd.
          """
          if x.shape[2] % 2 == 1:
              raise ValueError("downsampling tensor height should be even")
          if x.shape[3] % 2 == 1:
              raise ValueError("downsampling tensor width should be even")
          return self.downsample(x)
  class Upsample(nn.Module):
      """Double the spatial size, then apply a 3x3 convolution.

      Nearest-neighbour upsampling (scale factor 2) is followed by a padded
      convolution; the channel count is preserved.

      Args:
          in_channels: number of input (and output) channels.
      """

      def __init__(self, in_channels):
          super().__init__()
          self.upsample = nn.Sequential(
              nn.Upsample(scale_factor=2, mode="nearest"),
              nn.Conv2d(in_channels, in_channels, 3, padding=1),
          )

      def forward(self, x, time_emb=None, y=None):
          """Upsample ``x``.

          ``time_emb`` and ``y`` are accepted only so this layer can be called
          with the same arguments as ``ResidualBlock``; they are ignored.
          """
          return self.upsample(x)

 使用Self-Attention注意力机制,做一个全局的Self-Attention

  class AttentionBlock(nn.Module):
      """Single-head self-attention over all spatial positions, with a residual.

      Args:
          in_channels: number of channels of the input (and output) feature map.
          norm: normalization type forwarded to ``get_norm``.
          num_groups: number of groups forwarded to ``get_norm``.
      """

      def __init__(self, in_channels, norm="gn", num_groups=32):
          super().__init__()
          self.in_channels = in_channels
          self.norm = get_norm(norm, in_channels, num_groups)
          # A single 1x1 convolution produces queries, keys and values at once.
          self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
          self.to_out = nn.Conv2d(in_channels, in_channels, 1)

      def forward(self, x):
          """Apply attention to ``x`` of shape ``(b, c, h, w)``; the shape is kept."""
          b, c, h, w = x.shape
          q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

          # ``reshape`` rather than ``view`` so memory layouts that cannot be
          # viewed do not raise; it is a view whenever one is possible.
          q = q.permute(0, 2, 3, 1).reshape(b, h * w, c)
          k = k.reshape(b, c, h * w)
          v = v.permute(0, 2, 3, 1).reshape(b, h * w, c)

          # Scaled dot-product between every pair of spatial positions.
          dot_products = torch.bmm(q, k) * (c ** (-0.5))
          assert dot_products.shape == (b, h * w, h * w)

          attention = torch.softmax(dot_products, dim=-1)
          out = torch.bmm(attention, v)
          assert out.shape == (b, h * w, c)

          # Back to channel-first layout before the output projection.
          out = out.reshape(b, h, w, c).permute(0, 3, 1, 2)
          return self.to_out(out) + x

 用于特征提取的残差结构

  class ResidualBlock(nn.Module):
      """Residual block with optional time/class conditioning and attention.

      Layout: norm -> activation -> conv -> (+ time bias) -> (+ class bias)
      -> norm -> activation -> dropout -> conv -> (+ shortcut) -> attention.

      Args:
          in_channels: channels of the input feature map.
          out_channels: channels of the output feature map.
          dropout: dropout probability applied before the second convolution.
          time_emb_dim: size of the time embedding, or None to disable time
              conditioning.
          num_classes: number of classes, or None to disable class conditioning.
          activation: callable applied after each normalization and to the
              time embedding.
          norm: normalization type forwarded to ``get_norm``.
          num_groups: number of groups forwarded to ``get_norm``.
          use_attention: whether to finish with an ``AttentionBlock``.
      """

      def __init__(
          self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,
          norm="gn", num_groups=32, use_attention=False,
      ):
          super().__init__()
          self.activation = activation

          self.norm_1 = get_norm(norm, in_channels, num_groups)
          self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

          self.norm_2 = get_norm(norm, out_channels, num_groups)
          self.conv_2 = nn.Sequential(
              nn.Dropout(p=dropout),
              nn.Conv2d(out_channels, out_channels, 3, padding=1),
          )

          self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
          self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

          # A 1x1 convolution matches the channel count on the shortcut when needed.
          self.residual_connection = (
              nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
          )
          self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)

      def forward(self, x, time_emb=None, y=None):
          """Run the block.

          Args:
              x: input feature map.
              time_emb: time embedding; required when the block was built with
                  ``time_emb_dim``.
              y: class indices; required when the block was built with
                  ``num_classes``.

          Raises:
              ValueError: if a required conditioning input is missing.
          """
          # First convolution.
          out = self.activation(self.norm_1(x))
          out = self.conv_1(out)

          # Project the time embedding and add it as a per-channel bias.
          if self.time_bias is not None:
              if time_emb is None:
                  raise ValueError("time conditioning was specified but time_emb is not passed")
              out = out + self.time_bias(self.activation(time_emb))[:, :, None, None]

          # Look up the class embedding and add it as a per-channel bias.
          if self.class_bias is not None:
              if y is None:
                  raise ValueError("class conditioning was specified but y is not passed")
              out = out + self.class_bias(y)[:, :, None, None]

          # Second convolution plus the residual shortcut.
          out = self.activation(self.norm_2(out))
          out = self.conv_2(out) + self.residual_connection(x)

          # Optional attention (identity when disabled).
          return self.attention(out)

 U-Net模型设计

  class UNet(nn.Module):
      """U-Net with optional timestep and class conditioning.

      Args:
          img_channels: channels of the input image; the output has the same.
          base_channels: channels after the first convolution.
          channel_mults: per-level multipliers of ``base_channels``.
          num_res_blocks: residual blocks per level on the down path (the up
              path uses one more per level).
          time_emb_dim: size of the time embedding, or None to disable time
              conditioning.
          time_emb_scale: scale passed to ``PositionalEmbedding``.
          num_classes: number of classes, or None to disable class conditioning.
          activation: activation callable used by the residual blocks and
              before the output convolution.
          dropout: dropout probability used by the residual blocks.
          attention_resolutions: indices into ``channel_mults`` of the levels
              whose residual blocks use attention.
          norm: normalization type forwarded to ``get_norm``.
          num_groups: number of groups forwarded to ``get_norm``.
          initial_pad: padding added on every side of the input and cropped
              from the output.
      """

      def __init__(
          self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),
          num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,
          dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,
      ):
          super().__init__()
          self.activation = activation
          self.initial_pad = initial_pad
          self.num_classes = num_classes

          # MLP turning raw timesteps into the time embedding.
          self.time_mlp = nn.Sequential(
              PositionalEmbedding(base_channels, time_emb_scale),
              nn.Linear(base_channels, time_emb_dim),
              nn.SiLU(),
              nn.Linear(time_emb_dim, time_emb_dim),
          ) if time_emb_dim is not None else None

          # First convolution applied to the input image.
          self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)

          self.downs = nn.ModuleList()
          self.ups = nn.ModuleList()

          # ``channels`` records the channel count of every skip tensor pushed
          # on the down path; ``now_channels`` tracks the current channel count.
          channels = [base_channels]
          now_channels = base_channels

          # Down path: residual blocks, then a Downsample except at the last level.
          for i, mult in enumerate(channel_mults):
              out_channels = base_channels * mult
              for _ in range(num_res_blocks):
                  self.downs.append(
                      ResidualBlock(
                          now_channels, out_channels, dropout,
                          time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                          norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                      )
                  )
                  now_channels = out_channels
                  channels.append(now_channels)
              if i != len(channel_mults) - 1:
                  self.downs.append(Downsample(now_channels))
                  channels.append(now_channels)

          # Bottleneck: one block with attention, one without.
          self.mid = nn.ModuleList(
              [
                  ResidualBlock(
                      now_channels, now_channels, dropout,
                      time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                      norm=norm, num_groups=num_groups, use_attention=True,
                  ),
                  ResidualBlock(
                      now_channels, now_channels, dropout,
                      time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                      norm=norm, num_groups=num_groups, use_attention=False,
                  ),
              ]
          )

          # Up path: each residual block consumes one skip (concatenated on the
          # channel axis), then an Upsample except at the outermost level.
          for i, mult in reversed(list(enumerate(channel_mults))):
              out_channels = base_channels * mult
              for _ in range(num_res_blocks + 1):
                  self.ups.append(ResidualBlock(
                      channels.pop() + now_channels, out_channels, dropout,
                      time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                      norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                  ))
                  now_channels = out_channels
              if i != 0:
                  self.ups.append(Upsample(now_channels))

          # Internal invariant: every recorded skip has been consumed.
          assert len(channels) == 0

          self.out_norm = get_norm(norm, base_channels, num_groups)
          self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

      def forward(self, x, time=None, y=None):
          """Run the network.

          Args:
              x: input image batch.
              time: timesteps; required when time conditioning is enabled.
              y: class indices; required when class conditioning is enabled.

          Returns:
              Tensor with ``img_channels`` channels; any initial padding is
              cropped off again.

          Raises:
              ValueError: if a required conditioning input is missing.
          """
          # Optional padding of the input on all four sides.
          ip = self.initial_pad
          if ip != 0:
              x = F.pad(x, (ip,) * 4)

          # Time embedding.
          if self.time_mlp is not None:
              if time is None:
                  raise ValueError("time conditioning was specified but time is not passed")
              time_emb = self.time_mlp(time)
          else:
              time_emb = None

          if self.num_classes is not None and y is None:
              raise ValueError("class conditioning was specified but y is not passed")

          # First convolution.
          x = self.init_conv(x)

          # Down path; every intermediate output is kept as a skip connection.
          skips = [x]
          for layer in self.downs:
              x = layer(x, time_emb, y)
              skips.append(x)

          # Bottleneck.
          for layer in self.mid:
              x = layer(x, time_emb, y)

          # Up path; residual blocks first concatenate the matching skip.
          for layer in self.ups:
              if isinstance(layer, ResidualBlock):
                  x = torch.cat([x, skips.pop()], dim=1)
              x = layer(x, time_emb, y)

          # Output head: norm -> activation -> conv.
          x = self.activation(self.out_norm(x))
          x = self.out_conv(x)

          # Remove the initial padding, if any.
          if self.initial_pad != 0:
              return x[:, :, ip:-ip, ip:-ip]
          else:
              return x

参考链接:GitCode - 开发者的代码家园(https://gitcode.com/bubbliiiing/ddpm-pytorch/tree/master?utm_source=csdn_github_accelerator&isLogin=1)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/47402?site
推荐阅读
相关标签
  

闽ICP备14008679号