The most surprising truth about GAN training is that the discriminator often learns too well, too quickly, and that’s precisely what breaks the whole process.
Let’s watch a GAN train. Imagine we have a generator G trying to create images of digits, and a discriminator D trying to tell if an image is real (from a dataset of MNIST digits) or fake (from G).
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# Simplified Generator and Discriminator
class Generator(nn.Module):
    def __init__(self, latent_dim, img_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, img_dim),
            nn.Tanh(),  # Output in [-1, 1] to match the normalized MNIST pixels
        )

    def forward(self, x):
        return self.net(x)


class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # Outputting probability between 0 and 1
        )

    def forward(self, x):
        return self.net(x)


# Hyperparameters
latent_dim = 100
img_dim = 784  # 28*28 for MNIST
lr = 0.0002
batch_size = 64
num_epochs = 100

# Initialize models, loss, and optimizers
generator = Generator(latent_dim, img_dim)
discriminator = Discriminator(img_dim)
criterion = nn.BCELoss()  # Binary Cross Entropy Loss
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# drop_last=True keeps every batch at exactly batch_size, so the reshapes below never fail
dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

# Training loop (simplified for illustration)
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.view(batch_size, img_dim)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train Discriminator
        discriminator.zero_grad()
        # Real images
        outputs = discriminator(real_images)
        loss_D_real = criterion(outputs, real_labels)
        # Fake images
        noise = torch.randn(batch_size, latent_dim)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())  # Detach to avoid training generator here
        loss_D_fake = criterion(outputs, fake_labels)
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        # Train Generator
        generator.zero_grad()
        noise = torch.randn(batch_size, latent_dim)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        loss_G = criterion(outputs, real_labels)  # Generator wants discriminator to think fake is real
        loss_G.backward()
        optimizer_G.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Step [{i}/{len(dataloader)}], "
                  f"D Loss: {loss_D.item():.4f}, G Loss: {loss_G.item():.4f}")
```
The core problem is that G and D are locked in a minimax game. D wants to maximize the probability that it labels real and fake inputs correctly. G wants to minimize the probability that D labels its outputs as fake. When D becomes too powerful, it separates real from fake with near-total confidence and its sigmoid output saturates. Under the original minimax generator loss, log(1 - D(G(z))), the gradient reaching G then shrinks toward zero. The loop above uses the non-saturating form (BCE against real labels), which keeps the gradient from vanishing outright, but a near-perfect D can still give G noisy updates that say little about how to make its samples more realistic. Either way, G essentially gets stuck.
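To see how D’s confidence changes the signal G receives, it helps to look at the numbers. The sketch below is a minimal illustration rather than part of the training script: it writes D’s output on a fake sample as sigmoid(a) for a logit a, and evaluates the derivative with respect to a of two common generator losses, the original minimax form log(1 - D(G(z))) and the non-saturating form -log D(G(z)), which is what BCE against real labels computes in the loop above. The function names are invented for this example.

```python
import math


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def saturating_grad(a: float) -> float:
    """d/da of log(1 - sigmoid(a)), the original minimax generator loss."""
    return -sigmoid(a)


def non_saturating_grad(a: float) -> float:
    """d/da of -log(sigmoid(a)), i.e. BCE against a target of 1."""
    return sigmoid(a) - 1.0


# A more negative logit means D is more certain the sample is fake.
for logit in (0.0, -2.0, -5.0, -10.0):
    print(
        f"logit={logit:6.1f}  D(G(z))={sigmoid(logit):.5f}  "
        f"minimax grad={saturating_grad(logit):.5f}  "
        f"non-saturating grad={non_saturating_grad(logit):.5f}"
    )
```

As the logit becomes more negative, the minimax gradient shrinks toward zero while the non-saturating gradient approaches -1. This only covers the last step of the chain rule; what happens further back in D depends on the network.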
Here’s how to keep them balanced and prevent the discriminator from running away with the game:
1. Label Smoothing for the Discriminator: Instead of using hard labels (0 for fake, 1 for real), use slightly softened labels. For instance, train D to output 0.9 for real images and 0.1 for fake images.
   - Diagnosis: Monitor `D`’s accuracy. If it consistently hovers near 100% on both real and fake batches, it’s too strong.
   - Fix: In the training loop, modify the label creation:
     ```python
     real_labels = torch.full((batch_size, 1), 0.9)  # Softened real label
     fake_labels = torch.full((batch_size, 1), 0.1)  # Softened fake label
     ```
     A common variant is one-sided smoothing, which softens only the real labels and leaves the fake labels at 0.
   - Why it works: This prevents `D` from becoming overly confident. By aiming for 0.9 instead of 1.0, it’s less likely to saturate its output, providing a more informative gradient to `G` even when `D` is generally good.
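The diagnosis step can be automated with a small helper. This is a sketch under assumptions: `discriminator_accuracy` and `smoothed_labels` are hypothetical names, and the 0.5 decision threshold assumes D ends in a sigmoid, as it does in the script above.

```python
import torch


def discriminator_accuracy(real_out: torch.Tensor, fake_out: torch.Tensor,
                           threshold: float = 0.5) -> tuple:
    """Return (accuracy on the real batch, accuracy on the fake batch)."""
    real_acc = (real_out > threshold).float().mean().item()
    fake_acc = (fake_out <= threshold).float().mean().item()
    return real_acc, fake_acc


def smoothed_labels(n: int, real_value: float = 0.9, fake_value: float = 0.1) -> tuple:
    """Build softened targets of shape (n, 1) for the real and fake batches."""
    real = torch.full((n, 1), real_value)
    fake = torch.full((n, 1), fake_value)
    return real, fake


# Toy sigmoid outputs standing in for discriminator(real_images) and
# discriminator(fake_images.detach()).
real_out = torch.tensor([[0.9], [0.8], [0.3], [0.7]])
fake_out = torch.tensor([[0.1], [0.6]])
print(discriminator_accuracy(real_out, fake_out))
```

Logging both numbers every few hundred steps makes it easy to see when D stays near 1.0 on both batches, which is the signal described above.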
2. Use a Different Loss Function (Wasserstein GAN - WGAN): The standard BCE loss can lead to vanishing gradients. WGANs use the Earth Mover’s Distance, which provides a more stable gradient. This requires a few changes:
   - Diagnosis: If your BCE loss is fluctuating wildly, or `G`’s loss is consistently very high and not decreasing, consider WGAN.
   - Fix:
     - Remove `nn.Sigmoid()` from the output of the Discriminator (it becomes a "Critic" that outputs an unbounded score instead of a probability).
     - Drop `nn.BCELoss()`. The WGAN losses are plain means of the critic’s scores, so no criterion object is needed.
     - Modify the loss calculation:
       ```python
       # Train Discriminator (critic)
       discriminator.zero_grad()
       real_validity = discriminator(real_images)  # No sigmoid
       real_loss = -torch.mean(real_validity)  # Maximize validity for real
       noise = torch.randn(batch_size, latent_dim)
       fake_images = generator(noise)
       fake_validity = discriminator(fake_images.detach())  # No sigmoid
       fake_loss = torch.mean(fake_validity)  # Minimize validity for fake
       loss_D = real_loss + fake_loss
       loss_D.backward()
       optimizer_D.step()

       # Train Generator
       generator.zero_grad()
       noise = torch.randn(batch_size, latent_dim)
       fake_images = generator(noise)
       validity = discriminator(fake_images)  # No sigmoid
       loss_G = -torch.mean(validity)  # Generator wants to maximize validity
       loss_G.backward()
       optimizer_G.step()
       ```
     - WGAN-GP (Gradient Penalty): Without a Lipschitz constraint the critic’s scores can grow without bound and the losses above stop being meaningful. The original WGAN enforces the constraint by clipping the critic’s weights. WGAN-GP instead adds a gradient penalty term to `loss_D`, which is crucial for stability. This typically involves calculating the norm of the gradients of `D`’s output with respect to its input and penalizing deviations from 1.
   - Why it works: The Wasserstein distance is a better-behaved metric for comparing probability distributions, especially when they have limited overlap, providing smoother and more reliable gradients for `G`.
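The gradient penalty described above can be sketched as follows. This is a minimal version under assumptions: inputs are flattened to shape `(batch, features)` as in the MNIST script, the function name `gradient_penalty` is my own, and the penalty is evaluated on random interpolations between real and fake samples, which is the usual WGAN-GP recipe.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Mean of (||d critic(x_hat) / d x_hat||_2 - 1)^2 over random interpolates x_hat."""
    n = real.size(0)
    eps = torch.rand(n, 1, device=real.device)  # one mixing weight per sample
    x_hat = (eps * real.detach() + (1.0 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()


# Sanity check on a linear critic whose input gradient has norm exactly 2,
# so the penalty should come out as (2 - 1)^2 = 1.
toy_critic = nn.Linear(4, 1, bias=False)
with torch.no_grad():
    toy_critic.weight.copy_(torch.tensor([[2.0, 0.0, 0.0, 0.0]]))
print(gradient_penalty(toy_critic, torch.randn(8, 4), torch.randn(8, 4)).item())
```

In the critic step this would be added as `loss_D = real_loss + fake_loss + lambda_gp * gradient_penalty(discriminator, real_images, fake_images)`, where `lambda_gp` is a hypothetical name for the penalty weight. The WGAN-GP paper uses a weight of 10.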
3. Lower the Learning Rate: A high learning rate can cause G and D to overshoot optimal points, leading to instability.
   - Diagnosis: If you see erratic loss curves or generated samples that are nonsensical and don’t improve over many epochs.
   - Fix: Reduce `lr`. A common starting point is `lr = 0.0001` or even `0.00005`.
   - Why it works: Smaller steps allow the optimizers to converge more gracefully towards a stable equilibrium rather than jumping around erratically.
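If training is already running, the learning rate can be lowered without rebuilding the optimizer, because PyTorch optimizers expose their settings through `param_groups`. The helper name `set_learning_rate` below is an assumption made for this sketch.

```python
import torch.nn as nn
import torch.optim as optim


def set_learning_rate(optimizer: optim.Optimizer, new_lr: float) -> None:
    """Overwrite the learning rate of every parameter group in place."""
    for group in optimizer.param_groups:
        group["lr"] = new_lr


# Stand-in model so the example runs on its own.
model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))
set_learning_rate(optimizer, 0.0001)
print([group["lr"] for group in optimizer.param_groups])
```

Calling this separately on `optimizer_G` and `optimizer_D` also makes it possible to try different rates for the two networks.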
4. Adjust Optimizer Betas: The betas parameter in Adam controls the decay rates of the first and second moment estimates. For GANs, (0.5, 0.999) is often recommended over the default (0.9, 0.999).
   - Diagnosis: If you’re using Adam and experiencing oscillations or instability, especially in conjunction with other issues.
   - Fix: Set `betas=(0.5, 0.999)` for both `optimizer_G` and `optimizer_D` (the script above already does this):
     ```python
     optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
     optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
     ```
   - Why it works: `beta1` sets how long Adam’s momentum term remembers past gradients. With a high `beta1` (0.9), each update keeps following gradients from many steps ago, and in a GAN those were computed against an opponent that has since changed, so the stale momentum can feed oscillations. A lower `beta1` (like 0.5) shortens that memory, so updates track the current state of the game more closely.
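The effect of `beta1` on Adam’s memory can be illustrated with the first-moment update alone, m_t = beta1 * m_{t-1} + (1 - beta1) * g_t. The sketch below is a toy, not Adam itself: it omits bias correction and the second moment, and feeds in a made-up gradient that is +1 for a while and then flips to -1, roughly what happens when the opponent changes strategy. Both function names are invented for this example.

```python
def ema_trace(grads, beta1):
    """Running first-moment estimate m_t = beta1 * m_{t-1} + (1 - beta1) * g_t."""
    m, trace = 0.0, []
    for g in grads:
        m = beta1 * m + (1.0 - beta1) * g
        trace.append(m)
    return trace


def steps_until_sign_flip(beta1, warmup=50, horizon=50):
    """Steps the estimate needs to turn negative after the gradient flips from +1 to -1."""
    grads = [1.0] * warmup + [-1.0] * horizon
    trace = ema_trace(grads, beta1)
    for k, m in enumerate(trace[warmup:], start=1):
        if m < 0:
            return k
    return None


for beta1 in (0.9, 0.5):
    print(f"beta1={beta1}: momentum changes sign after {steps_until_sign_flip(beta1)} step(s)")
```

With `beta1 = 0.9` the estimate keeps pointing in the old direction for several steps after the flip, while with `beta1 = 0.5` it turns around almost immediately. A common rule of thumb is that the average spans roughly 1 / (1 - beta1) steps.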
5. Discriminator Architecture Tweaks:
   * Fewer Layers/Neurons: If D is too powerful, simplifying its architecture can help.
   * Batch Normalization: While often helpful, Batch Norm in D can sometimes leak information between samples in a batch, which can be problematic for GANs. Try removing it from D or using alternative normalization layers like LayerNorm or InstanceNorm.
   - Diagnosis: If other methods fail and you suspect `D` is just too complex and learning too fast.
   - Fix: Remove `nn.BatchNorm1d` layers if present, or reduce the number of layers/neurons in `D`.
   - Why it works: A less powerful discriminator is less likely to overfit and memorize the training data, thus providing more meaningful gradients to the generator.
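One rough way to compare how much capacity different discriminator variants have is to count their trainable parameters. The sketch below assumes the same two-layer layout as the `Discriminator` in the script above. `count_parameters` and `make_discriminator` are hypothetical helpers, and parameter count is only a crude proxy for how quickly `D` learns.

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    """Number of trainable scalar parameters in the model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def make_discriminator(img_dim: int = 784, hidden: int = 128) -> nn.Sequential:
    """Two-layer MLP discriminator with a configurable hidden width."""
    return nn.Sequential(
        nn.Linear(img_dim, hidden),
        nn.ReLU(),
        nn.Linear(hidden, 1),
        nn.Sigmoid(),
    )


for hidden in (128, 64, 32):
    n_params = count_parameters(make_discriminator(hidden=hidden))
    print(f"hidden={hidden:4d}  trainable parameters={n_params}")
```

Halving the hidden width roughly halves the parameter count here, since almost all of the weights sit in the first linear layer.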
6. Train Discriminator More Than Generator (or vice-versa): Sometimes, one network needs more updates than the other to maintain balance.
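   - Diagnosis: If `G`’s loss is consistently low and `D`’s loss is consistently high (or vice versa), indicating one is significantly outperforming the other.
   - Fix: For example, to train `D` more, run its update step multiple times for each `G` update. If `D` is the one running away, swap the roles and give `G` the extra steps instead.
     ```python
     # Inside the training loop, once per batch
     for _ in range(k):  # k is typically 1 to 5
         # Train Discriminator on the real batch and freshly generated fakes
         discriminator.zero_grad()
         noise = torch.randn(batch_size, latent_dim)
         fake_images = generator(noise).detach()
         loss_D = (criterion(discriminator(real_images), real_labels)
                   + criterion(discriminator(fake_images), fake_labels))
         loss_D.backward()
         optimizer_D.step()

     # Train Generator once
     generator.zero_grad()
     noise = torch.randn(batch_size, latent_dim)
     loss_G = criterion(discriminator(generator(noise)), real_labels)
     loss_G.backward()
     optimizer_G.step()
     ```
   - Why it works: This allows one network to "catch up" if it’s falling behind, helping to maintain the adversarial equilibrium.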
7. Spectral Normalization: Apply spectral normalization to the weights of the discriminator. This is another technique to enforce the Lipschitz constraint, similar in spirit to WGAN-GP but can be applied to standard GANs.
   - Diagnosis: Similar to WGAN-GP, if `D` is too powerful and gradients become unstable.
   - Fix: Wrap the discriminator’s weight layers with spectral normalization (e.g., `torch.nn.utils.spectral_norm`).
     ```python
     from torch.nn.utils import spectral_norm


     class Discriminator(nn.Module):
         def __init__(self, img_dim):
             super().__init__()
             self.net = nn.Sequential(
                 spectral_norm(nn.Linear(img_dim, 128)),  # Apply spectral norm
                 nn.ReLU(),
                 spectral_norm(nn.Linear(128, 1)),  # Normalize every weight layer, not just the first
                 nn.Sigmoid(),
             )

         def forward(self, x):
             return self.net(x)
     ```
   - Why it works: Spectral normalization rescales each layer’s weight matrix so that its largest singular value is about 1, which constrains the Lipschitz constant of each layer. That limits how much the discriminator’s output can change with respect to its input, which helps stabilize training and prevents exploding gradients.
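The claim about the Lipschitz constant can be checked directly on a single layer. The sketch below is an illustration under assumptions: the layer sizes are arbitrary, and the weights are deliberately inflated before wrapping so the effect is visible. It relies on `torch.nn.utils.spectral_norm` refining its estimate with one power-iteration step per forward pass in training mode, so a few dozen passes are run before measuring.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)

linear = nn.Linear(64, 32)
with torch.no_grad():
    linear.weight.mul_(10.0)  # exaggerate the weights so normalization has something to do
raw_sigma = torch.linalg.matrix_norm(linear.weight.detach(), ord=2).item()

layer = spectral_norm(linear)
layer.train()
with torch.no_grad():
    for _ in range(100):
        layer(torch.randn(4, 64))  # each training-mode forward pass runs one power iteration

# After the forward passes, layer.weight holds the normalized weight matrix.
sigma = torch.linalg.matrix_norm(layer.weight.detach(), ord=2).item()
print(f"largest singular value before: {raw_sigma:.3f}, after: {sigma:.3f}")
```

The normalized weight should end up with a largest singular value close to 1 regardless of how large the raw weights were. Newer PyTorch versions also provide `torch.nn.utils.parametrizations.spectral_norm`, which serves the same purpose.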
After stabilizing your GAN, the next hurdle you’ll likely face is mode collapse, where the generator produces only a limited variety of outputs.