Mixed precision training in PyTorch, often referred to as Automatic Mixed Precision (AMP), is a technique that leverages both 16-bit (half-precision) and 32-bit (single-precision) floating-point types during neural network training. The primary goal is to significantly accelerate training speed and reduce memory consumption without a substantial loss in model accuracy.
Let’s see it in action. Imagine you’re training a ResNet-50. To keep the example quick, the code below uses a 1,000-image subset of CIFAR-10 in place of ImageNet. Without AMP, a typical training loop might look like this:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10 # Using CIFAR10 for a smaller, quicker example
# Model and optimizer setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet50(weights=None)  # `pretrained=False` is deprecated in torchvision >= 0.13
model.fc = nn.Linear(model.fc.in_features, 10) # Adjust for CIFAR10 classes
model.to(device)  # Move the model before building the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Dummy data loader
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
# Using a subset for faster demonstration
subset_indices = list(range(1000))
train_dataset = Subset(dataset, subset_indices)
dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Standard FP32 training loop
num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(dataloader):.3f}')
print("Finished FP32 Training")
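Now, let’s introduce AMP. The core idea is to use torch.cuda.amp.autocast to run each operation in an appropriate precision and torch.cuda.amp.GradScaler to handle gradient scaling. Recent PyTorch releases deprecate the torch.cuda.amp spellings in favor of torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda"); the older imports used below still work there but emit a warning.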
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from torch.cuda.amp import autocast, GradScaler # Import AMP components
# Model and optimizer setup (AMP as used here targets CUDA and is disabled on CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"
model = resnet50(weights=None)  # `pretrained=False` is deprecated in torchvision >= 0.13
model.fc = nn.Linear(model.fc.in_features, 10)
model.to(device)  # Move the model before building the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Dummy data loader (same as before)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
subset_indices = list(range(1000))
train_dataset = Subset(dataset, subset_indices)
dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# AMP training loop
scaler = GradScaler(enabled=use_amp) # Initialize GradScaler (a no-op when disabled)
num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        # Wrap the forward pass and loss computation with autocast
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        # Backward runs outside autocast, on the scaled loss
        scaler.scale(loss).backward()
        scaler.step(optimizer) # Unscales gradients; skips the step on inf/NaN
        scaler.update() # Update the scale for next iteration
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(dataloader):.3f}')
print("Finished AMP Training")
The key piece is the with autocast(): block. Inside it, PyTorch chooses a precision for each operation according to the operation type. Matrix multiplications and convolutions, which are computationally intensive and benefit most from reduced precision, run in FP16. Operations that need more numerical range or precision, such as softmax, layer normalization, and most loss functions, run in FP32. The model's stored parameters stay in FP32 throughout; autocast casts inputs on the fly for each operation. Only the forward pass and the loss computation go inside the block; backward() is called outside it and reuses the dtypes chosen during the forward pass.
The GradScaler is crucial because FP16 has a much narrower dynamic range than FP32. Small gradient values can underflow, that is, round to zero in FP16, and the affected parameters then stop updating. GradScaler multiplies the loss by a scale factor before backward(), which by the chain rule multiplies every gradient by the same factor and lifts small values into FP16's representable range. scaler.step(optimizer) then un-scales the gradients and calls optimizer.step() only if none of them are inf or NaN; otherwise the step is skipped. scaler.update() adjusts the scale factor for the next iteration: it decreases the factor if an overflow occurred, and increases it only after a run of consecutive overflow-free steps (growth_interval, 2000 by default).
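To make the underflow argument concrete, here is a minimal sketch that uses only Python's standard library to round values to FP16. The gradient value 1e-8 is a made-up illustration, and the scale 65536 matches GradScaler's documented default init_scale; nothing here calls PyTorch itself.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

true_grad = 1e-8   # hypothetical small gradient value
scale = 65536.0    # 2**16, GradScaler's documented default init_scale

# Unscaled, the value lies below FP16's smallest subnormal (about 6e-8)
# and rounds to zero.
unscaled_fp16 = to_fp16(true_grad)

# Scaled, the value lands in FP16's normal range and survives rounding.
scaled_fp16 = to_fp16(true_grad * scale)

# Un-scaling is done in higher precision (Python floats here), which
# recovers the gradient up to FP16 rounding error.
recovered = scaled_fp16 / scale

print(unscaled_fp16)
print(scaled_fp16)
print(recovered)
```

The same arithmetic is what scaling the loss achieves for every gradient at once, since multiplying the loss by a constant multiplies all of its gradients by that constant.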
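This combination of automatic precision casting and gradient scaling lets PyTorch AMP deliver substantial speedups on GPUs with FP16 Tensor Cores (figures of 1.5x to 3x are often cited, depending on hardware and model) and reduce GPU memory usage (roughly 20-40% is often cited), while maintaining accuracy comparable to full FP32 training. These numbers vary widely with model, batch size, and GPU, so measure them on your own workload.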
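Since the gains depend on the setup, it helps to time both loops yourself. The helper below is a rough sketch, not a rigorous benchmark: time_steps and its parameters are hypothetical names introduced here, and the lambda workload is a stand-in for a real training step.

```python
import time
from typing import Callable, Optional

def time_steps(step_fn: Callable[[], None], n_steps: int = 20, warmup: int = 3,
               sync: Optional[Callable[[], None]] = None) -> float:
    """Return the mean wall-clock seconds per call of step_fn."""
    for _ in range(warmup):   # warm-up iterations are not timed
        step_fn()
    if sync is not None:
        sync()                # wait for queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    if sync is not None:
        sync()                # wait again so the timing covers all launched work
    return (time.perf_counter() - start) / n_steps

# Stand-in workload so the sketch runs anywhere; replace it with a function
# that performs one FP32 or one AMP training step.
seconds = time_steps(lambda: sum(range(10_000)))
print(f"{seconds * 1e3:.3f} ms per step")
```

On a GPU, pass sync=torch.cuda.synchronize, because CUDA kernels launch asynchronously and the clock would otherwise stop too early. Peak memory can be compared the same way with torch.cuda.reset_peak_memory_stats() before a run and torch.cuda.max_memory_allocated() after it.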
The primary problem AMP solves is the trade-off between training speed/memory and accuracy. Historically, you had to choose: slower, more memory-intensive FP32 training for maximum accuracy, or faster, memory-efficient FP16 training that often struggled with stability and accuracy. AMP provides a way to get the best of both worlds by selectively using FP16 where it’s beneficial and FP32 where it’s necessary.
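Internally, autocast works at the level of individual operators. Each eligible operator falls into one of a few categories: operators that run in lower precision (FP16 on CUDA by default), such as matrix multiplications and convolutions; operators that run in FP32, such as softmax and most loss functions; and operators that take several inputs and promote them to the widest dtype among those inputs. When an operator in one of the first two categories is called inside the autocast context, its floating-point tensor inputs are cast to that category's dtype before it executes. The decision depends on the operator, the device type, and the input dtypes, not on the values stored in the tensors. Operators outside these categories run in whatever dtype their inputs already have. Because all of this happens per operator, with no change to the model code, it is "automatic."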
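The per-operator behavior can be observed by checking dtypes directly. The sketch below assumes a PyTorch version that provides torch.autocast (1.10 or later); it uses FP16 on CUDA when a GPU is present and falls back to CPU autocast with bfloat16 otherwise, so the low-precision dtype it prints depends on the machine.

```python
import torch
import torch.nn as nn

# Pick whichever autocast backend is available on this machine.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
low_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

layer = nn.Linear(8, 4).to(device_type)
x = torch.randn(2, 8, device=device_type)

with torch.autocast(device_type=device_type, dtype=low_dtype):
    y = layer(x)  # linear is one of the operators autocast runs in lower precision

print(x.dtype)             # the input tensor itself is untouched
print(layer.weight.dtype)  # the stored parameter is untouched
print(y.dtype)             # the output carries the lower-precision dtype
```

The input and the weight are cast only for the duration of the operator call, which is why the model can keep FP32 master weights without any manual .half() calls.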
The GradScaler is the key to numerical stability. It maintains a scale value, which starts at 65536 (2^16) by default. scaler.scale(loss) returns loss * scale, and the subsequent backward() call computes gradients of this scaled loss. When scaler.step(optimizer) is called, it first multiplies the gradients by 1/scale (un-scaling them) and checks them for inf or NaN. If none are found, it calls optimizer.step(); otherwise it skips the step, so the parameters are left untouched for that iteration. scaler.update() then adjusts the scale. If an overflow occurred, it multiplies the scale by backoff_factor (0.5 by default). If growth_interval consecutive iterations (2000 by default) passed without an overflow, it multiplies the scale by growth_factor (2.0 by default). Growing the scale keeps small gradients as far from underflow as possible, while backing off keeps large gradients from overflowing, so gradients stay within FP16's representable range as much as possible.
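The update rule is easy to simulate in plain Python. ToyScaler below is a hypothetical, simplified model written for this article, not PyTorch's implementation; it borrows GradScaler's parameter names and documented defaults, and takes the overflow flag as an input where the real class inspects the gradients.

```python
class ToyScaler:
    """Simplified model of the scale update rule; not PyTorch's implementation."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive steps without inf/NaN

    def update(self, found_inf: bool) -> None:
        if found_inf:
            # The optimizer step was skipped: back off and restart the count.
            self.scale *= self.backoff_factor
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.growth_interval:
                # A full run of clean steps: try a larger scale.
                self.scale *= self.growth_factor
                self.good_steps = 0

toy = ToyScaler(growth_interval=3)  # short interval so growth is visible
history = []
for found_inf in [False, False, False, True, False]:
    toy.update(found_inf)
    history.append(toy.scale)
print(history)
```

With a growth interval of 3, the scale doubles after the third clean step and halves again on the simulated overflow in the fourth, which is the probe-and-retreat pattern described above.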
One thing many people don’t realize is that autocast doesn’t just cast everything to FP16. It maintains lists of operations that are safe and beneficial to run in FP16. The operations behind torch.nn.Linear and torch.nn.Conv2d are prime candidates. By contrast, the operation behind torch.nn.LayerNorm runs in FP32, because its statistics are sensitive to precision loss and it gains little speed from FP16. torch.nn.BatchNorm2d keeps its parameters and running statistics in FP32, since autocast never changes a module's stored tensors. autocast picks the precision from the operation type, the device, and the input tensor dtypes.
The next logical step after mastering mixed precision is exploring distributed training strategies, where you’ll encounter challenges related to communication overhead and gradient synchronization across multiple GPUs or machines.