PyTorch’s GPU memory management is a subtle dance. Many "out of memory" (OOM) errors are not caused by a model that is genuinely too large, but by memory that is still being held: tensors your code keeps referencing, or blocks that PyTorch’s caching allocator has reserved but isn’t using.
Let’s see what happens when we train a simple model.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Dummy data and model
input_size = 1024
hidden_size = 2048
output_size = 10
model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, output_size)
)

# Move model to GPU if available
if torch.cuda.is_available():
    model.cuda()

# Dummy input and target
batch_size = 32
inputs = torch.randn(batch_size, input_size)
targets = torch.randint(0, output_size, (batch_size,))
if torch.cuda.is_available():
    inputs = inputs.cuda()
    targets = targets.cuda()

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training step
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
print(f"Loss: {loss.item()}")
```
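This looks straightforward, but under the hood PyTorch is managing memory for the model’s parameters, their gradients, any optimizer state, and the intermediate activations that autograd saves during the forward pass so it can compute gradients in `backward()`. When an OOM error hits, it’s usually because peak memory usage, the point where all of these coexist, exceeds what your GPU has free.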
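To get a feel for the sizes involved, here is a back-of-the-envelope estimate for the example model. It is a rough sketch under simplifying assumptions: float32 values (4 bytes each), plain SGD without momentum (so no optimizer state), and one saved activation tensor per layer boundary. Real autograd saves a somewhat different set of tensors, and the CUDA context adds its own overhead; `estimate_mlp_memory` is a hypothetical helper, not a PyTorch API.

```python
def estimate_mlp_memory(layer_sizes, batch_size, bytes_per_elem=4):
    """Hypothetical helper: rough memory estimate (bytes) for an MLP of nn.Linear layers.

    Assumes float32 values, plain SGD without momentum (no optimizer state),
    and one saved activation tensor per layer boundary.
    """
    pairs = zip(layer_sizes, layer_sizes[1:])
    # Each nn.Linear(n_in, n_out) has an n_out x n_in weight plus an n_out bias
    n_params = sum(n_in * n_out + n_out for n_in, n_out in pairs)
    # Activations: batch_size values per unit at every layer boundary
    n_activations = batch_size * sum(layer_sizes)
    return {
        "parameters": n_params * bytes_per_elem,
        "gradients": n_params * bytes_per_elem,  # one .grad tensor per parameter
        "activations": n_activations * bytes_per_elem,
    }

small = estimate_mlp_memory([1024, 2048, 10], batch_size=32)
large = estimate_mlp_memory([1024, 2048, 10], batch_size=256)
for name in small:
    print(f"{name:>11}: {small[name] / 2**20:7.2f} MiB (batch 32)  "
          f"{large[name] / 2**20:7.2f} MiB (batch 256)")
```

Under these assumptions the parameters and gradients take about 8 MiB each no matter the batch size, while the activation estimate grows eightfold when the batch goes from 32 to 256.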
Here are the common culprits and how to fix them:
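1. Not Zeroing Gradients: Every call to `loss.backward()` adds (accumulates) the new gradients into each parameter’s `.grad` tensor instead of overwriting it. If you never reset them, each update uses the sum of all past gradients. This is first a correctness bug rather than a growing allocation: the `.grad` buffers are allocated on the first backward pass and reused afterwards.
- Diagnosis: Compare GPU memory usage (`nvidia-smi` or `torch.cuda.memory_allocated()`) before and after the first `loss.backward()`. The jump you see is the gradient buffers, roughly the size of the model’s parameters. If memory keeps climbing on every iteration, the cause is something else, usually a retained tensor (see item 2).
- Fix: Call `optimizer.zero_grad()` once per iteration, before `loss.backward()`:
```python
optimizer.zero_grad()  # Reset gradients left over from the previous iteration
loss.backward()
optimizer.step()
```
- Why it works: Each iteration starts from clean gradients. Since PyTorch 2.0, `zero_grad()` defaults to `set_to_none=True`, which sets each `.grad` to `None` instead of filling it with zeros, so the gradient memory goes back to PyTorch’s caching allocator until the next backward pass needs it.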
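A small CPU-only sketch makes the accumulation visible: two backward passes without zeroing double the gradient, and `zero_grad(set_to_none=True)` (the default since PyTorch 2.0, passed explicitly here) resets `.grad` to `None`. The tiny `nn.Linear(4, 1)` model and the `backward_once` helper are purely illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)  # toy model, just for illustration
x = torch.ones(2, 4)

def backward_once():
    # d(sum of outputs)/d(weight) = sum of the input rows, i.e. 2.0 for every weight here
    model(x).sum().backward()
    return model.weight.grad.clone()

first = backward_once()
second = backward_once()  # no zero_grad(): the new gradient is added to the old one
print(first)   # every entry is 2.0
print(second)  # every entry is 4.0

model.zero_grad(set_to_none=True)
print(model.weight.grad)  # None: the gradient tensor has been released
```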
2. Keeping Unnecessary Tensors: Holding onto tensors that are no longer needed, especially large ones or those on the GPU, will consume memory. This often happens with intermediate results or data you’ve moved to GPU and forgotten about.
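- Diagnosis: Use `torch.cuda.memory_allocated()` (bytes currently occupied by tensors) and `torch.cuda.memory_reserved()` (bytes held by the caching allocator, including unused cache) to track memory across iterations. If `memory_allocated()` keeps climbing from one iteration to the next, you’re likely holding onto tensors.
- Fix: Drop the reference with `del tensor_name`. Optionally call `torch.cuda.empty_cache()` afterwards to hand the now-unused cached blocks back to the CUDA driver:
```python
del large_tensor
torch.cuda.empty_cache()
```
- Why it works: `del` removes the Python reference; once no references remain, the tensor is freed and its memory returns to PyTorch’s caching allocator, where later allocations can reuse it. `empty_cache()` only releases cached blocks that no tensor occupies. It doesn’t give your own process more usable memory; it mainly makes that memory show up as free in `nvidia-smi` and available to other processes.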
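A common way to hold onto tensors by accident is collecting loss tensors, rather than Python numbers, across iterations: each stored loss keeps a reference to its autograd graph. The sketch below contrasts that with `loss.item()` and logs allocated/reserved memory through `cuda_memory_mib`, a hypothetical helper of our own that returns `None` on machines without a GPU, so the sketch also runs on the CPU.

```python
import torch
import torch.nn as nn

def cuda_memory_mib():
    """Hypothetical helper: (allocated, reserved) GPU memory in MiB, or None without a GPU."""
    if not torch.cuda.is_available():
        return None
    mib = 2 ** 20
    return torch.cuda.memory_allocated() / mib, torch.cuda.memory_reserved() / mib

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 1).to(device)

kept_tensors = []  # anti-pattern: every entry holds a reference to its autograd graph
kept_floats = []   # better: plain Python floats, no GPU memory involved

for step in range(3):
    x = torch.randn(8, 16, device=device)
    loss = model(x).pow(2).mean()
    kept_tensors.append(loss)        # keeps the loss tensor and its graph alive
    kept_floats.append(loss.item())  # copies the scalar value out of the tensor
    print(f"step {step}: {cuda_memory_mib()}")

graphs_kept = all(t.grad_fn is not None for t in kept_tensors)
# Drop the references so the tensors and their graphs can be freed
del kept_tensors, loss
```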
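3. Large Batch Sizes: The most direct way to increase GPU memory usage is to increase the batch size. Larger batches mean more data and more intermediate activations saved for the backward pass; the parameters, their gradients, and optimizer state stay the same size regardless of batch size.
- Diagnosis: Monitor `nvidia-smi` as you increase `batch_size`. Memory grows roughly linearly with it: activation memory scales with the batch, on top of a fixed cost for the model itself.
- Fix: Reduce `batch_size`. If you need the effective batch size of a larger batch, use gradient accumulation:
```python
# Accumulate gradients over several small batches instead of one large one
accumulation_steps = 4
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    # Scale the loss so the summed gradients equal the average over all micro-batches
    loss = criterion(outputs, targets) / accumulation_steps
    loss.backward()  # Gradients accumulate in .grad across micro-batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
- Why it works: Only one small batch’s activations are alive at a time, which keeps peak memory low, while summing the scaled gradients of `accumulation_steps` equal-sized micro-batches reproduces the average gradient of the larger batch (for a mean-reduced loss such as `CrossEntropyLoss`’s default) before each optimizer step.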
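To check that the accumulation pattern really matches a large batch, this CPU-only sketch compares the gradient from one batch of 32 with the gradient accumulated over 4 micro-batches of 8. The small model and random data are illustrative; the equality relies on equal-sized micro-batches, a mean-reduced loss, and no batch-dependent layers such as BatchNorm.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 5))  # toy model
criterion = nn.CrossEntropyLoss()  # mean reduction by default
inputs = torch.randn(32, 20)
targets = torch.randint(0, 5, (32,))

def flat_grad():
    # Concatenate every parameter's gradient into one vector for comparison
    return torch.cat([p.grad.flatten() for p in model.parameters()])

# Reference: one full batch of 32
model.zero_grad()
criterion(model(inputs), targets).backward()
full_batch_grad = flat_grad()

# Accumulation: 4 micro-batches of 8, each loss scaled by 1/4
accumulation_steps = 4
model.zero_grad()
for micro_inputs, micro_targets in zip(inputs.chunk(accumulation_steps),
                                       targets.chunk(accumulation_steps)):
    loss = criterion(model(micro_inputs), micro_targets) / accumulation_steps
    loss.backward()  # adds into the existing .grad tensors
accumulated_grad = flat_grad()

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))
```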
4. Inefficient Data Loading: If your `DataLoader` is too slow or doesn’t use multiple workers effectively, your GPU sits idle while waiting for data. Conversely, loading too much data at once, or moving an entire dataset onto the GPU up front, wastes memory and can cause OOMs.
- Diagnosis: Profile your data loading pipeline. If `nvidia-smi` shows low GPU utilization during training, your data loader is likely the bottleneck.
- Fix: Tune `num_workers` in `DataLoader` and make sure your `collate_fn` is efficient. Keep the dataset on the CPU and move one batch at a time to the GPU; `pin_memory=True` speeds up those CPU-to-GPU transfers:
```python
from torch.utils.data import DataLoader, TensorDataset

# Build the dataset from CPU tensors: only CPU tensors can be pinned,
# and worker processes should not hold CUDA tensors
dataset = TensorDataset(inputs.cpu(), targets.cpu())
# Adjust num_workers based on your CPU cores and workload
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
```
- Why it works: `num_workers > 0` loads batches in separate processes, so the main training loop doesn’t block on data. `pin_memory=True` places batches in page-locked (pinned) host memory, which speeds up the copy to the GPU and lets it run asynchronously with `.to(device, non_blocking=True)`.
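Here is a sketch of the full loading pattern, with the dataset kept on the CPU and each batch moved to the device with `.to(device, non_blocking=True)`. The data is random dummy data, `num_workers=0` keeps the sketch single-process (raise it in a real training script, ideally under an `if __name__ == "__main__":` guard on platforms that spawn worker processes), and pinning is only enabled when a GPU is present.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"

# Keep the full (dummy) dataset on the CPU; only one batch at a time goes to the device
inputs = torch.randn(256, 1024)
targets = torch.randint(0, 10, (256,))
dataset = TensorDataset(inputs, targets)

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,        # raise (e.g. to 4) in a real script; 0 keeps this sketch single-process
    pin_memory=use_cuda,  # page-locked host buffers only help when copying to a GPU
)

batches_seen = 0
for batch_inputs, batch_targets in loader:
    # With pinned memory, non_blocking=True lets the copy overlap with GPU compute
    batch_inputs = batch_inputs.to(device, non_blocking=True)
    batch_targets = batch_targets.to(device, non_blocking=True)
    batches_seen += 1
```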
5. Saving Intermediate Tensors During Training: Sometimes, for debugging or analysis, you might save intermediate tensors. If these are large and kept in memory for too long, they contribute to the memory footprint.
- Diagnosis: Look for any code that explicitly stores tensors in lists or dictionaries that grow over the training loop.
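- Fix: Detach stored tensors and move them to the CPU (`.detach().cpu()`) if they don’t need to stay on the GPU, or delete them once they’re no longer needed:
```python
# Instead of:
# gpu_intermediates.append(intermediate_tensor)
# Do: detach() drops the autograd graph, cpu() copies the data off the GPU
cpu_intermediates.append(intermediate_tensor.detach().cpu())
del intermediate_tensor  # Or just let it go out of scope
```
- Why it works: Moving tensors to the CPU frees GPU memory. The `detach()` matters: a plain `.cpu()` copy of a tensor that requires grad stays attached to the autograd graph (it has a `grad_fn`), which can keep the graph’s saved GPU tensors alive, for example in an evaluation loop run without `torch.no_grad()`. Deleting the GPU tensor promptly returns its memory to PyTorch’s caching allocator for reuse.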
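A typical place this comes up is capturing activations with a forward hook. This sketch stores a detached CPU copy (`.detach().cpu()`) of a layer’s output on every forward pass, so the stored copies hold no reference to the autograd graph. The toy model and the `save_activation` hook name are illustrative; `register_forward_hook` and the returned handle’s `remove()` are standard `nn.Module` APIs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))  # toy model
captured = []

def save_activation(module, inputs, output):
    # detach() drops the autograd graph; cpu() moves the copy off the GPU (no-op on CPU)
    captured.append(output.detach().cpu())

handle = model[1].register_forward_hook(save_activation)  # hook the ReLU layer
for _ in range(3):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
handle.remove()  # stop capturing once the analysis is done
```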
6. PyTorch Caching Allocator Fragmentation: PyTorch uses a caching allocator to speed up memory allocation. Over time, this can lead to fragmentation, where there’s enough total free memory but not a single contiguous block large enough for a new allocation.
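- Diagnosis: A large gap between `torch.cuda.memory_reserved()` and `torch.cuda.memory_allocated()` often indicates fragmentation, especially when an OOM error reports far more memory reserved than allocated.
- Fix: Periodically call `torch.cuda.empty_cache()`. This can slow training slightly, because released blocks have to be re-allocated from the driver later, so use it sparingly:
```python
# Call this every N steps, or when reserved memory grows unexpectedly
if i % 100 == 0:
    torch.cuda.empty_cache()
```
- Why it works: This releases all unused cached memory back to the CUDA driver, so later allocations can be laid out afresh instead of being carved from fragmented cached blocks. It doesn’t free memory that live tensors occupy. PyTorch also lets you tune the allocator through the `PYTORCH_CUDA_ALLOC_CONF` environment variable (for example `max_split_size_mb`), which targets fragmentation directly; the OOM error message itself points to this setting when reserved-but-unallocated memory is large.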
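The diagnosis and the fix can be combined so that the cache is flushed only when a large share of reserved memory is unused. This is a hypothetical heuristic, not a PyTorch feature: `unused_fraction`, `maybe_empty_cache`, and the 0.5 threshold are illustrative names and values. Without a GPU, `maybe_empty_cache` does nothing.

```python
import torch

def unused_fraction(allocated_bytes, reserved_bytes):
    """Share of reserved (cached) memory that no tensor currently occupies."""
    if reserved_bytes == 0:
        return 0.0
    return (reserved_bytes - allocated_bytes) / reserved_bytes

def maybe_empty_cache(step, every=100, threshold=0.5):
    """Hypothetical heuristic: every `every` steps, flush the cache if more than
    `threshold` of the reserved memory is unused. Returns True if it flushed."""
    if not torch.cuda.is_available() or step % every != 0:
        return False
    fraction = unused_fraction(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
    if fraction > threshold:
        torch.cuda.empty_cache()  # return unused cached blocks to the CUDA driver
        return True
    return False

print(unused_fraction(300 * 2**20, 1000 * 2**20))  # 0.7
```

In a training loop, `maybe_empty_cache(i)` would take the place of the unconditional every-100-steps call.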
The next error you’ll likely encounter after fixing these is the familiar `RuntimeError: CUDA out of memory. Tried to allocate X GiB ...`, raised when a single allocation still doesn’t fit in the GPU memory that remains free.