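Brief introduction:
A CVAE (conditional VAE) works like a VAE, except that the model input fuses the image with a condition. The encoder maps the fused input to the parameters of a Gaussian distribution (a mean and a variance); a latent sample is drawn at random from that distribution and fused with the condition again; the decoder then reconstructs the image from the result.
Because both the encoder and the decoder see the condition, the model learns to generate images conditioned on it.
The loss compares the source image with the reconstructed image, together with the latent (KL) term used in a plain VAE; for the derivation of the loss function, see: 变分自编码器(VAE) (Variational Autoencoder).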
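For reference, the objective behind this loss can be written as a conditional evidence lower bound. This is the standard formulation rather than something derived in this post, and it assumes (consistently with the code here) a standard normal prior over z that does not depend on Y and an encoder that outputs a diagonal Gaussian:

```latex
\log p_\theta(X \mid Y) \;\ge\;
\mathbb{E}_{q_\phi(z \mid X, Y)}\big[\log p_\theta(X \mid z, Y)\big]
\;-\; D_{\mathrm{KL}}\big(q_\phi(z \mid X, Y)\,\|\,\mathcal{N}(0, I)\big)

D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))\,\|\,\mathcal{N}(0, I)\big)
= \frac{1}{2}\sum_{j}\big(\sigma_j^2 + \mu_j^2 - 1 - \log\sigma_j^2\big)
```

Minimizing the negative of the bound gives a reconstruction term plus the KL term. With a Gaussian decoder of fixed variance, the reconstruction term reduces to a squared error up to scaling and constants.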
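The CVAE below uses the simplest fusion method, concatenation (concat): the condition Y is concatenated with the input X to form X_given_Y, and likewise Y is concatenated with the sampled latent z to form z_given_Y.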
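The shape bookkeeping of this concatenation can be sketched in isolation. The sizes are illustrative assumptions (flattened 28x28 images, 10 classes), and the one-hot variant is a common alternative that the original code does not use:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions, not taken from the original post).
batch, in_features, num_classes = 4, 784, 10

X = torch.rand(batch, in_features)            # flattened images in [0, 1]
y = torch.randint(0, num_classes, (batch,))   # integer class labels, shape (batch,)

# Scalar-label fusion: the label becomes one extra column, so y_size = 1.
X_given_Y = torch.cat((X, y.unsqueeze(1).to(X.dtype)), dim=1)
print(X_given_Y.shape)  # torch.Size([4, 785])

# One-hot fusion (alternative, not in the original code): y_size = num_classes.
y_onehot = F.one_hot(y, num_classes=num_classes).to(X.dtype)
X_given_Y_onehot = torch.cat((X, y_onehot), dim=1)
print(X_given_Y_onehot.shape)  # torch.Size([4, 794])
```

With scalar labels the network sees the class as a single number, which implies an ordering between classes. One-hot encoding avoids that at the cost of a wider input.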
    import torch
    import torch.nn as nn


    class VAE(nn.Module):
        def __init__(self, in_features, latent_size, y_size=0):
            super(VAE, self).__init__()

            self.latent_size = latent_size

            # Encoder: outputs 2 * latent_size values (mean and log-variance).
            self.encoder_forward = nn.Sequential(
                nn.Linear(in_features + y_size, in_features),
                nn.LeakyReLU(),
                nn.Linear(in_features, in_features),
                nn.LeakyReLU(),
                nn.Linear(in_features, self.latent_size * 2)
            )

            # Decoder: maps the (optionally conditioned) latent back to pixel space.
            self.decoder_forward = nn.Sequential(
                nn.Linear(self.latent_size + y_size, in_features),
                nn.LeakyReLU(),
                nn.Linear(in_features, in_features),
                nn.LeakyReLU(),
                nn.Linear(in_features, in_features),
                nn.Sigmoid()
            )

        def encoder(self, X):
            out = self.encoder_forward(X)         # encoder output: mean and log-variance
            mu = out[:, :self.latent_size]        # first half: mean
            log_var = out[:, self.latent_size:]   # second half: log-variance (not the std)
            return mu, log_var

        def decoder(self, z):
            mu_prime = self.decoder_forward(z)
            return mu_prime

        def reparameterization(self, mu, log_var):
            # z = mu + sigma * epsilon, with sigma = exp(0.5 * log_var)
            epsilon = torch.randn_like(log_var)
            z = mu + epsilon * torch.exp(0.5 * log_var)
            return z

        def loss(self, X, mu_prime, mu, log_var):
            # Reconstruction term: squared error summed over features, averaged over the batch.
            reconstruction_loss = torch.mean(torch.square(X - mu_prime).sum(dim=1))
            # Latent term: KL(q(z|X) || N(0, I)); the constant -1 per dimension is omitted,
            # which shifts the value but does not change the gradients.
            latent_loss = torch.mean(0.5 * (log_var.exp() + torch.square(mu) - log_var).sum(dim=1))
            return reconstruction_loss + latent_loss

        def forward(self, X, *args, **kwargs):
            mu, log_var = self.encoder(X)             # encode
            z = self.reparameterization(mu, log_var)  # sample z via reparameterization
            mu_prime = self.decoder(z)                # decode
            return mu_prime, mu, log_var


    class CVAE(VAE):
        def __init__(self, in_features, latent_size, y_size):
            super(CVAE, self).__init__(in_features, latent_size, y_size)

        def forward(self, X, y=None, *args, **kwargs):
            if y is None:
                raise ValueError("CVAE.forward requires the condition y")
            # y has shape (batch,): one scalar label per sample, so y_size must be 1.
            # Cast to X's dtype so integer labels can be concatenated with float inputs.
            y = y.to(device=next(self.parameters()).device, dtype=X.dtype)
            X_given_Y = torch.cat((X, y.unsqueeze(1)), dim=1)

            mu, log_var = self.encoder(X_given_Y)
            z = self.reparameterization(mu, log_var)
            z_given_Y = torch.cat((z, y.unsqueeze(1)), dim=1)

            mu_prime_given_Y = self.decoder(z_given_Y)
            return mu_prime_given_Y, mu, log_var
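Once such a model is trained, conditional generation only needs the decoder: draw z from the standard normal prior, append the desired label, and decode. The following is a minimal sketch of that idea. The helper `sample_given_label` is hypothetical (it is not part of the original code), and the small stand-in decoder only mimics the input and output sizes of a trained one:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def sample_given_label(decoder, latent_size, label, num_samples):
    """Hypothetical helper: draw z ~ N(0, I), append a scalar label, decode."""
    device = next(decoder.parameters()).device
    z = torch.randn(num_samples, latent_size, device=device)
    y = torch.full((num_samples, 1), float(label), device=device)
    z_given_Y = torch.cat((z, y), dim=1)  # same fusion as in the forward pass
    return decoder(z_given_Y)


# Stand-in decoder (assumption): untrained, only the shapes match a real model.
latent_size, in_features = 16, 784
decoder = nn.Sequential(nn.Linear(latent_size + 1, in_features), nn.Sigmoid())

images = sample_given_label(decoder, latent_size, label=3, num_samples=8)
print(images.shape)  # torch.Size([8, 784])
```

With the classes above, the module to pass would presumably be `model.decoder_forward` (the `nn.Sequential`), since `model.decoder` is a bound method and has no `parameters()`.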

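The loss computation in brief:
    def loss(self, X, mu_prime, mu, log_var):
        # Reconstruction term: squared error summed over features, averaged over the batch.
        reconstruction_loss = torch.mean(torch.square(X - mu_prime).sum(dim=1))
        # Latent term: KL(q(z|X) || N(0, I)) with the constant -1 per dimension omitted.
        latent_loss = torch.mean(0.5 * (log_var.exp() + torch.square(mu) - log_var).sum(dim=1))
        return reconstruction_loss + latent_loss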
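The latent term is the closed-form KL divergence between a diagonal Gaussian and the standard normal, up to an additive constant. A quick numerical check against `torch.distributions` makes the relationship concrete. The tensor sizes are arbitrary assumptions chosen for illustration:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
batch, latent_size = 5, 3  # arbitrary illustrative sizes
mu = torch.randn(batch, latent_size)
log_var = torch.randn(batch, latent_size)

# Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions.
kl_closed = 0.5 * (log_var.exp() + mu.square() - 1.0 - log_var).sum(dim=1)

# Reference value computed by torch.distributions.
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum(dim=1)

# Per-sample latent term as written in the loss above (no "- 1").
latent_terms = 0.5 * (log_var.exp() + mu.square() - log_var).sum(dim=1)

# The two closed forms differ by the constant latent_size / 2.
offset = latent_terms - kl_closed
print(torch.allclose(kl_closed, kl_ref, atol=1e-4))
print(torch.allclose(offset, torch.full((batch,), latent_size / 2), atol=1e-4))
```

Because the offset is constant, both versions give the same gradients. Only the reported loss value differs.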
Full code: liujf69/VAE