A VAE doesn’t reconstruct its input directly; it decodes a sample drawn from a probability distribution that the encoder places over a compressed latent space.
Let’s build a VAE for MNIST. It will show how the model learns to encode each image as a distribution over latent codes, and how sampling from the prior and decoding those samples generates new, similar images.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Hyperparameters
input_dim = 784  # MNIST images are 28x28 = 784 pixels
hidden_dim = 400
latent_dim = 20
epochs = 15
batch_size = 128
learning_rate = 1e-3

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. DataLoader for MNIST
# ToTensor() already scales pixels to [0, 1]. We deliberately do NOT normalize to
# [-1, 1]: the decoder ends in a sigmoid and the loss is binary cross-entropy,
# and both assume pixel targets in [0, 1].
transform = transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 2. VAE Model Definition
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.input_dim = input_dim
        # Encoder: x -> hidden -> (mu, logvar)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: z -> hidden -> reconstructed x
        self.fc_decoder_input = nn.Linear(latent_dim, hidden_dim)
        self.fc_decoder_output = nn.Linear(hidden_dim, input_dim)
        self.sigmoid = nn.Sigmoid()  # Output pixel values in (0, 1)

    def encode(self, x):
        h = self.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)  # log(sigma^2): unconstrained, numerically stable
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # sigma = exp(0.5 * log(sigma^2))
        eps = torch.randn_like(std)    # eps ~ N(0, I), independent of the parameters
        return mu + eps * std

    def decode(self, z):
        h = self.relu(self.fc_decoder_input(z))
        # Sigmoid keeps outputs in (0, 1), matching the [0, 1] inputs and the BCE loss.
        return self.sigmoid(self.fc_decoder_output(h))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

# 3. Loss Function (Binary Cross-Entropy + KL Divergence), summed over the batch
def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction loss (BCE); targets must lie in [0, 1]
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, recon_x.size(1)), reduction='sum')
    # KL divergence loss
    # D_KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # where sigma^2 = exp(logvar)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# 4. Model, Optimizer, and Training Loop
model = VAE(input_dim, hidden_dim, latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    # The loss is summed, so dividing by the dataset size gives a per-image average
    avg_loss = train_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

# 5. Visualize Results (Reconstructions and Generations)
model.eval()
with torch.no_grad():
    # Reconstructions
    data, _ = next(iter(test_loader))
    data = data.to(device)
    recon_batch, mu, logvar = model(data)

    # Plot original and reconstructed images
    n = 10
    plt.figure(figsize=(20, 4))
    for i in range(n):
        # Original: shape (1, 28, 28) -> (28, 28)
        plt.subplot(2, n, i + 1)
        plt.imshow(data[i].cpu().squeeze().numpy(), cmap='gray')
        plt.title("Original")
        plt.axis('off')
        # Reconstruction: flat (784,) vector -> (28, 28) image
        plt.subplot(2, n, i + 1 + n)
        plt.imshow(recon_batch[i].cpu().view(28, 28).numpy(), cmap='gray')
        plt.title("Reconstructed")
        plt.axis('off')
    plt.suptitle("Original vs. Reconstructed Images")
    plt.show()

    # Generations
    # Sample from the prior distribution N(0, I) and decode
    sample_z = torch.randn(64, latent_dim, device=device)
    generated_images = model.decode(sample_z).cpu()
    plt.figure(figsize=(8, 8))
    for i in range(64):
        plt.subplot(8, 8, i + 1)
        plt.imshow(generated_images[i].view(28, 28).numpy(), cmap='gray')
        plt.axis('off')
    plt.suptitle("Generated Images from Latent Space")
    plt.show()
```
The core idea of a VAE is to learn a compressed representation of data, but instead of mapping each input to a single point in latent space, it maps it to a probability distribution. This distribution is typically a diagonal Gaussian parameterized by a mean ($\mu$) and a variance ($\sigma^2$); in the code above the encoder outputs $\mu$ and the log-variance $\log \sigma^2$ (`logvar`), which is unconstrained and numerically more stable than $\sigma^2$ itself. The "variational" part comes from the fact that we approximate the true posterior distribution of the latent variables with a simpler, tractable distribution (here a Gaussian) to make inference feasible.
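Written out in standard notation (not used elsewhere in this post: $\phi$ and $\theta$ stand for the encoder and decoder parameters, and $\mathcal{N}(0, I)$ is the standard normal applied independently to every latent dimension), the approximate posterior and the bound that training maximizes look like this:

$$
q_\phi(z \mid x) = \mathcal{N}\!\big(z;\ \mu_\phi(x),\ \operatorname{diag}(\sigma_\phi^2(x))\big), \qquad p(z) = \mathcal{N}(0, I)
$$

$$
\log p_\theta(x) = \underbrace{\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - D_{\mathrm{KL}}\big(q_\phi(z \mid x)\,\|\,p(z)\big)}_{\mathrm{ELBO}(x)} + D_{\mathrm{KL}}\big(q_\phi(z \mid x)\,\|\,p_\theta(z \mid x)\big)
$$

Because the last KL term is non-negative, the ELBO is a lower bound on $\log p_\theta(x)$, and the gap is exactly how far $q_\phi$ is from the true posterior. If the decoder outputs are treated as Bernoulli pixel probabilities (a common approximation for grayscale MNIST), $-\log p_\theta(x \mid z)$ is the summed BCE, so `vae_loss` (BCE at one reparameterized sample of $z$, plus the analytic KLD) is a single-sample estimate of $-\mathrm{ELBO}(x)$.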
The loss function is crucial here. It has two components:
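- Reconstruction Loss: This measures how well the decoder reconstructs the original input from the sampled latent code. For images with pixel values in [0, 1], Binary Cross-Entropy (BCE) is common: it treats each decoder output as the probability that a pixel is "on" and is minimized when the prediction matches the target intensity, so predictions far from the true pixel value are penalized heavily. BCE assumes targets in [0, 1], so the images must be fed in that range (for example with `ToTensor()` alone, without normalizing to [-1, 1]).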
- KL Divergence Loss: This term acts as a regularizer. It pushes each learned latent distribution $\mathcal{N}(\mu, \sigma^2)$ toward the standard normal prior $\mathcal{N}(0, 1)$. This keeps the latent space continuous and well-behaved, so we can later sample from the prior and decode those samples into new data. For a diagonal Gaussian against the standard normal, the KL divergence has a closed form, $D_{\mathrm{KL}} = -\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$, which is exactly what `vae_loss` computes from `mu` and `logvar`.
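As a sanity check on the closed-form KL term, here is a small sketch that compares it with PyTorch's generic `torch.distributions.kl_divergence` on random toy values. The helper name `kl_closed_form` and the toy shapes are illustrative assumptions; the formula is the same one used in `vae_loss`, but reduced per sample instead of over the whole batch.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_closed_form(mu, logvar):
    # Per-sample KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

torch.manual_seed(0)
mu = torch.randn(4, 20)      # toy batch: 4 inputs, 20 latent dimensions
logvar = torch.randn(4, 20)

# Normal takes the standard deviation sigma, not the variance
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_reference = kl_divergence(q, p).sum(dim=1)  # elementwise KL, summed over dimensions

print(torch.allclose(kl_closed_form(mu, logvar), kl_reference, atol=1e-5))
```

The check should print `True`. Note that the KL is exactly zero when `mu` and `logvar` are both zero, i.e. when the encoder's output already equals the prior.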
The reparameterization trick is how we backpropagate gradients through the sampling process. Drawing $z$ directly from $\mathcal{N}(\mu, \sigma^2)$ is a random operation with no derivative with respect to $\mu$ and $\sigma$, so gradients could not flow back into the encoder. By drawing $\epsilon \sim \mathcal{N}(0, 1)$ and re-expressing the sample as $z = \mu + \sigma \cdot \epsilon$, we move the randomness into an input that does not depend on any parameters. $z$ is now a deterministic, differentiable function of $\mu$ and $\sigma$, so the gradient can flow back through both.
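A minimal sketch of that argument on toy tensors (the values and shapes are arbitrary choices for illustration): build $z = \mu + \sigma \cdot \epsilon$ exactly as `reparameterize` does, backpropagate, and inspect the gradients. It also contrasts `torch.distributions.Normal.sample()`, which draws without tracking gradients, with `rsample()`, which uses the same reparameterization.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.zeros(3, requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)

# Reparameterized sample: eps is drawn independently of mu and logvar
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std

# Any scalar function of z can now be differentiated w.r.t. mu and logvar
z.sum().backward()
print(mu.grad)      # tensor([1., 1., 1.]), since dz/dmu = 1
print(logvar.grad)  # equals 0.5 * sigma * eps, i.e. 0.5 * eps here because sigma = 1

# For contrast: sample() is not differentiable, rsample() is
dist = Normal(mu, torch.exp(0.5 * logvar))
print(dist.sample().requires_grad)   # False
print(dist.rsample().requires_grad)  # True
```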
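When you train this, you’ll see the total loss decrease, driven mostly by the reconstruction term as the VAE gets better at turning an image into a latent code and back. The KL term does not necessarily decrease: it typically starts small (an untrained encoder outputs distributions roughly close to $\mathcal{N}(0, 1)$) and often rises as the encoder learns to pack information about each image into $z$. The two terms trade off against each other, so it is worth watching them separately rather than only their sum.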
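One way to watch the two terms separately is a variant of `vae_loss` that returns them as a pair. The function name `vae_loss_terms` is hypothetical, and the toy tensors below only stand in for one MNIST batch; in the real training loop you would accumulate `bce.item()` and `kld.item()` alongside `train_loss` and backpropagate `bce + kld` as before.

```python
import torch
import torch.nn.functional as F

def vae_loss_terms(recon_x, x, mu, logvar):
    """Hypothetical variant of vae_loss returning (BCE, KLD), each summed over the batch."""
    bce = F.binary_cross_entropy(recon_x, x.view(recon_x.shape), reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce, kld

# Toy tensors standing in for one batch (784 pixels, 20 latent dimensions)
torch.manual_seed(0)
x = torch.rand(8, 1, 28, 28)                  # targets in [0, 1]
recon_x = torch.sigmoid(torch.randn(8, 784))  # decoder-like outputs in (0, 1)
mu, logvar = torch.randn(8, 20), torch.randn(8, 20)

bce, kld = vae_loss_terms(recon_x, x, mu, logvar)
loss = bce + kld  # the same quantity vae_loss would return for these inputs
print(f"BCE/image: {bce.item() / x.size(0):.2f}, KL/image: {kld.item() / x.size(0):.2f}")
```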
The magic happens when you sample from $\mathcal{N}(0, 1)$ (the prior distribution we regularized towards) and pass those samples through the decoder. Because the latent space is regularized to be continuous and centered around zero, sampling points from it and decoding them should produce new data that resembles the training data.
The next hurdle you’ll likely face is understanding how to interpret the learned latent space. What do the different dimensions of z actually represent, and how can you manipulate them to control generated outputs?
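A common first probe is a latent traversal: hold a latent code fixed, sweep one coordinate across a range, and decode each result to see what changes. The helper below is a hypothetical sketch that only builds the batch of codes; the sweep range of ±3 (where nearly all of a standard normal's mass lies) and the choice of dimension are arbitrary assumptions. With a plain VAE like the one above, a single dimension need not correspond to one clean, interpretable factor.

```python
import torch

def latent_traversal(z_base, dim, values):
    """Hypothetical helper: copies of z_base with coordinate `dim` replaced by each entry of `values`."""
    values = torch.as_tensor(values, dtype=z_base.dtype)
    z = z_base.unsqueeze(0).repeat(len(values), 1)  # shape (num_values, latent_dim)
    z[:, dim] = values
    return z

z_base = torch.zeros(20)  # start at the prior mean
z_batch = latent_traversal(z_base, dim=3, values=torch.linspace(-3, 3, 7))
print(z_batch.shape)      # torch.Size([7, 20])

# With the trained model from above (not run here):
# with torch.no_grad():
#     images = model.decode(z_batch.to(device)).cpu().view(-1, 28, 28)
```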