import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # Two heads: the mean and log-variance of the approximate posterior q(z|x)
        self.fc2_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc2_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        mean = self.fc2_mean(h)
        log_var = self.fc2_log_var(h)
        return mean, log_var
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = torch.relu(self.fc1(z))
        # Sigmoid maps outputs to [0, 1], matching MNIST pixel intensities
        x_reconstructed = torch.sigmoid(self.fc2(h))
        return x_reconstructed
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        mean, log_var = self.encoder(x)
        # Reparameterization trick: z = mean + std * epsilon with epsilon ~ N(0, I),
        # so sampling stays differentiable with respect to mean and log_var
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        z = mean + std * epsilon
        x_reconstructed = self.decoder(z)
        return x_reconstructed, mean, log_var
def vae_loss(x, x_reconstructed, mean, log_var):
    # Negative ELBO: a summed binary cross-entropy reconstruction term plus the
    # closed-form KL divergence between N(mean, var) and the standard normal prior
    reconstruction_loss = nn.functional.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reconstruction_loss + kl_divergence
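# --- Optional sanity check (not part of the original listing; shapes are illustrative) ---
# Run the model and loss on a random batch before training: binary_cross_entropy
# requires inputs in [0, 1], and the loss should be one scalar summed over the batch.
demo_x = torch.rand(8, 28 * 28)
demo_vae = VAE(28 * 28, 256, 20)
demo_recon, demo_mean, demo_log_var = demo_vae(demo_x)
print(vae_loss(demo_x, demo_recon, demo_mean, demo_log_var))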
# ToTensor scales pixels to [0, 1], which binary_cross_entropy requires
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
input_dim = 28 * 28  # flattened 28x28 MNIST images
hidden_dim = 256
latent_dim = 20
vae = VAE(input_dim, hidden_dim, latent_dim)

optimizer = optim.Adam(vae.parameters(), lr=1e-3)
num_epochs = 20
vae.train()
for epoch in range(num_epochs):
    total_loss = 0
    for data, _ in train_loader:  # labels are unused
        data = data.view(-1, input_dim)  # flatten images to 784-dim vectors
        optimizer.zero_grad()
        x_reconstructed, mean, log_var = vae(data)
        loss = vae_loss(data, x_reconstructed, mean, log_var)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the average loss per training example
    print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader.dataset)}')
torch.save(vae.state_dict(), 'vae.pth')
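# --- Sampling sketch (an addition, not part of the original listing) ---
# After training, new digits can be generated by decoding draws from the
# N(0, I) prior. save_image and the 'samples.png' filename are just one
# convenient way to inspect the output as an 8x8 grid.
from torchvision.utils import save_image

vae.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dim)  # 64 latent draws from the prior
    samples = vae.decoder(z)         # decoded pixels, already in [0, 1]
    save_image(samples.view(64, 1, 28, 28), 'samples.png', nrow=8)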