示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps an input vector to the mean and log-variance of a
    diagonal Gaussian over the latent space; the decoder maps a latent
    sample back to a vector of per-element values in (0, 1).

    Args:
        input_dim: Size of each flattened input vector (default 784;
            presumably 28x28 images -- confirm against the data loader).
        hidden_dim: Width of the single hidden layer on each side.
        latent_dim: Dimensionality of the latent code.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        # The constructor must be ``__init__`` (not ``init``): otherwise
        # nn.Module is never initialised and no layers get registered.
        super().__init__()
        # Encoder: one shared hidden layer, then two heads (mean, log-variance).
        self.encoder_fc1 = nn.Linear(input_dim, hidden_dim)
        self.encoder_fc2_mu = nn.Linear(hidden_dim, latent_dim)
        self.encoder_fc2_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: mirror of the encoder.
        self.decoder_fc1 = nn.Linear(latent_dim, hidden_dim)
        self.decoder_fc2 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` produced by the encoder for input ``x``."""
        x = torch.relu(self.encoder_fc1(x))
        return self.encoder_fc2_mu(x), self.encoder_fc2_logvar(x)

    def reparametrize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        Sampling the noise separately keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2) -> sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent code ``z`` to a reconstruction squashed by a sigmoid."""
        z = torch.relu(self.decoder_fc1(z))
        return torch.sigmoid(self.decoder_fc2(z))

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar
def negative_log_likelihood(x_pred, x_actual):
    """Bernoulli negative log-likelihood (binary cross-entropy).

    The per-element log-likelihood is summed over ``dim=1`` and the
    resulting per-sample values are averaged over the remaining dimension.

    Args:
        x_pred: Predicted probabilities in [0, 1], indexed (sample, feature).
        x_actual: Target values with the same shape as ``x_pred``.

    Returns:
        Scalar tensor holding the mean negative log-likelihood.
    """
    if not torch.is_floating_point(x_pred):
        x_pred = x_pred.to(torch.get_default_dtype())
    # A saturated prediction (exactly 0 or 1) makes log() return -inf, and
    # 0 * -inf is NaN. Clamping to [eps, 1 - eps] keeps both the loss and
    # its gradient finite while leaving non-saturated values untouched.
    eps = torch.finfo(x_pred.dtype).eps
    p = torch.clamp(x_pred, min=eps, max=1 - eps)
    log_likelihood = x_actual * torch.log(p) + (1 - x_actual) * torch.log(1 - p)
    return -torch.mean(torch.sum(log_likelihood, dim=1))
def kl_divergence(mu, logvar):
    """Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).

    The per-element terms are summed over ``dim=1`` and the per-sample
    totals are then averaged.

    Args:
        mu: Latent means, indexed (sample, latent dimension).
        logvar: Latent log-variances with the same shape as ``mu``.

    Returns:
        Scalar tensor holding the mean KL divergence.
    """
    variance = logvar.exp()
    per_sample = (1 + logvar - mu ** 2 - variance).sum(dim=1)
    return -0.5 * per_sample.mean()
# Build the model and an Adam optimizer over all of its parameters.
vae = VAE()
optimizer = optim.Adam(vae.parameters(), lr=0.001)
# NOTE(review): no use of this MSE criterion is visible in this snippet;
# confirm whether the training loop actually relies on it.
criterion = nn.MSELoss()
# NOTE(review): the training loop below is cut off mid-statement -- the
# unpacking `img, _ =` has no right-hand side -- and `train_loader` is not
# defined anywhere in the visible snippet. The loop body must be restored
# from the original source before this code can run.
num_epochs = 10 for epoch in range(num_epochs): for data in train_loader: img, _ =
上一篇:变分自编码器损失值未正确显示?
下一篇:变分自编码器推断问题。