变分自编码器(VAE)

变分自编码器(Variational Autoencoder, VAE)是一种生成模型,结合了概率图模型和深度学习的优点。VAE 通过学习数据的潜在表示来生成新的数据样本。与传统的自编码器不同,VAE 在编码器部分引入了概率分布,使得生成的潜在表示具有更好的连续性和可操作性。

VAE 的模型结构

VAE 的模型结构主要包括两个部分:编码器(Encoder)和解码器(Decoder)。

  1. 编码器(Encoder):将输入数据映射到潜在空间的概率分布,通常是高斯分布。编码器输出潜在变量的均值和方差。
  2. 解码器(Decoder):从潜在空间的样本生成数据。解码器将潜在变量映射回原始数据空间。

VAE 的用途

VAE 可以用于以下任务:

  • 数据生成:生成与训练数据相似的新数据样本。
  • 数据降维:将高维数据映射到低维潜在空间。
  • 图像生成和重建:生成和重建图像数据。
  • 异常检测:通过检测潜在空间中的异常点来识别异常数据。

使用 PyTorch 实现 VAE

以下是一个使用 PyTorch 实现的简单变分自编码器(VAE)示例。示例使用 MNIST 数据集来训练 VAE 模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义编码器
class Encoder(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super(Encoder, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2_mean = nn.Linear(hidden_dim, latent_dim)
self.fc2_log_var = nn.Linear(hidden_dim, latent_dim)

def forward(self, x):
h = torch.relu(self.fc1(x))
mean = self.fc2_mean(h)
log_var = self.fc2_log_var(h)
return mean, log_var

# 定义解码器
class Decoder(nn.Module):
def __init__(self, latent_dim, hidden_dim, output_dim):
super(Decoder, self).__init__()
self.fc1 = nn.Linear(latent_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)

def forward(self, z):
h = torch.relu(self.fc1(z))
x_reconstructed = torch.sigmoid(self.fc2(h))
return x_reconstructed

# 定义 VAE 模型
class VAE(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super(VAE, self).__init__()
self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

def forward(self, x):
mean, log_var = self.encoder(x)
std = torch.exp(0.5 * log_var)
epsilon = torch.randn_like(std)
z = mean + std * epsilon
x_reconstructed = self.decoder(z)
return x_reconstructed, mean, log_var

# 定义损失函数
def vae_loss(x, x_reconstructed, mean, log_var):
reconstruction_loss = nn.functional.binary_cross_entropy(x_reconstructed, x, reduction='sum')
kl_divergence = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
return reconstruction_loss + kl_divergence

# 加载数据
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 创建模型
input_dim = 28 * 28
hidden_dim = 256
latent_dim = 20
vae = VAE(input_dim, hidden_dim, latent_dim)

# 定义优化器
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

# 训练模型
num_epochs = 20
vae.train()
for epoch in range(num_epochs):
total_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.view(-1, input_dim)
optimizer.zero_grad()
x_reconstructed, mean, log_var = vae(data)
loss = vae_loss(data, x_reconstructed, mean, log_var)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader.dataset)}')

# 保存模型
torch.save(vae.state_dict(), 'vae.pth')

以上示例展示了如何使用 PyTorch 实现一个简单的 VAE 模型。编码器将输入数据映射到潜在空间的均值和方差,解码器从潜在空间的样本生成数据。损失函数包括重建损失和 KL 散度,用于平衡生成数据的质量和潜在空间的连续性。


变分自编码器(VAE)
https://wenzhaoabc.github.io/llm/VAE/
作者
wenzhaoabc
发布于
2023年12月3日
许可协议