当前位置:   article > 正文

PyTorch笔记 - Diffusion Model 源码开发 (2)_diffusion model 数据集制作

diffusion model 数据集制作

Diffusion Model的效果如下:
Diffusion Model
源码如下:

  1. 选择一个数据集
  2. 确定超参数的值
  3. 确定扩散过程任意时刻的采样值
  4. 演示原始数据分布加噪100步后的效果
  5. 编写拟合逆扩散过程高斯分布的模型
  6. 编写训练的误差函数
  7. 编写逆扩散采样函数(inference过程)
  8. 开始训练模型,并打印loss及中间的重构效果
  9. 动画演示扩散过程和逆扩散过程
# 1、选择一个数据集
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torch

s_curve, _ = make_s_curve(10**4, noise=0.1)  # 10000个点
s_curve = s_curve[:, [0,2]]/10.0  # 每个点只取第0维和第2维
print(f'[Info] s_curve[0]: {
     s_curve[0]}')

print(f"[Info] shape of moons: {
     np.shape(s_curve)}")

data = s_curve.T
fig, ax = plt.subplots()
ax.scatter(*data, color="red", edgecolor="white")
ax.axis("off")
dataset = torch.Tensor(s_curve).float()



# 2、确定超参数的值

num_steps = 100

# 制定每一步的beta
betas = torch.linspace(-6, 6, num_steps)
print(f'[Info] betas: {
     betas.shape}, betas[0]: {
     betas[1]}')
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
# 近似:最小值0.00001,最大值0.005
print(f'[Info] betas[0]: {
     betas[0]}, betas[-1]: {
     betas[-1]}')

# 计算alpha
alphas = 1 - betas
alphas_prod = torch.cumprod(alphas, 0)  # prod是product乘法,连乘
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1  - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)

assert alphas.shape == alphas_prod.shape == alphas_prod_p.shape == alphas_bar_sqrt.shape \
    == one_minus_alphas_bar_log.shape == one_minus_alphas_bar_sqrt.shape

print("[Info] all the same shape:", betas.shape)


# 测试linspace、cumprod
alist = torch.linspace(1, 4, 4)   # [1, 2, 3, 4]</
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号