PyTorch’s GPU memory management is subtle, and many "out of memory" (OOM) errors come from memory that PyTorch is still holding but could have released, not from a model that is genuinely too large for the GPU.

Let’s see what happens when we train a simple model.

import torch
import torch.nn as nn
import torch.optim as optim

# Dummy data and model
input_size = 1024
hidden_size = 2048
output_size = 10

model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, output_size)
)

# Move model to GPU if available
if torch.cuda.is_available():
    model.cuda()

# Dummy input and target
batch_size = 32
inputs = torch.randn(batch_size, input_size)
targets = torch.randint(0, output_size, (batch_size,))

if torch.cuda.is_available():
    inputs = inputs.cuda()
    targets = targets.cuda()

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training step
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

print(f"Loss: {loss.item()}")

This looks straightforward, but under the hood, PyTorch is managing memory for tensors, gradients, and intermediate activations. When an OOM error hits, it’s usually because the peak memory usage exceeds what your GPU has.
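
Peak usage is the number to watch, so it helps to measure it directly. The following is a minimal sketch, assuming the same toy model as above; the helper name measure_peak_memory is hypothetical, and on a machine without CUDA it simply returns 0 because there are no allocator statistics to read.

```python
import torch
import torch.nn as nn


def measure_peak_memory(model, inputs, targets, criterion):
    """Run one forward/backward pass and return peak allocated bytes (0 on CPU)."""
    device = inputs.device
    if device.type == "cuda":
        # Start counting the peak from this point, ignoring earlier allocations.
        torch.cuda.reset_peak_memory_stats(device)

    loss = criterion(model(inputs), targets)
    loss.backward()

    if device.type == "cuda":
        torch.cuda.synchronize(device)
        return torch.cuda.max_memory_allocated(device)
    return 0  # no CUDA allocator statistics on a CPU-only run


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(1024, 2048), nn.ReLU(), nn.Linear(2048, 10)).to(device)
inputs = torch.randn(32, 1024, device=device)
targets = torch.randint(0, 10, (32,), device=device)

peak_bytes = measure_peak_memory(model, inputs, targets, nn.CrossEntropyLoss())
print(f"Peak allocated: {peak_bytes / 1024**2:.1f} MiB")
```

Running this before and after a change (a different batch size, for instance) shows whether the change moved the peak at all.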

Here are the common culprits and how to fix them:

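1. Not Zeroing Gradients: loss.backward() adds into each parameter’s .grad tensor rather than overwriting it. If you never call optimizer.zero_grad(), gradients from earlier batches keep summing into the same buffers. The buffers themselves stay the same size as the parameters, so this corrupts your updates more than it grows memory, but they also stay allocated through the next forward pass, which can raise the peak.

  • Diagnosis: Observe GPU memory usage (nvidia-smi, or torch.cuda.memory_allocated() for a finer reading) before and after the first loss.backward(). You’ll see a jump of roughly the size of the model’s parameters, because that is when the .grad tensors are allocated.
  • Fix: Call optimizer.zero_grad() once per iteration, before loss.backward().
    optimizer.zero_grad() # Reset gradients from the previous iteration
    loss.backward()
    optimizer.step()
    
  • Why it works: zero_grad() discards the previously computed gradients so each step uses only the current batch. In recent PyTorch releases it does this by setting .grad to None (set_to_none=True is the default), which also frees those buffers until the next backward pass.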
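
To see what skipping zero_grad() actually does, here is a small sketch on a toy nn.Linear layer (the layer, its sizes, and the variable names are illustrative assumptions, not part of the training script above). It runs two backward passes on the same input and then resets the gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)
x = torch.randn(8, 4)

# First backward pass: .grad is allocated and filled.
layer(x).sum().backward()
first_grad = layer.weight.grad.clone()

# Second backward pass without zero_grad(): the new gradient is added in place.
layer(x).sum().backward()
accumulated_grad = layer.weight.grad.clone()
print(torch.allclose(accumulated_grad, 2 * first_grad))

# set_to_none=True drops the .grad tensors instead of filling them with zeros.
optimizer.zero_grad(set_to_none=True)
print(layer.weight.grad is None)
```

The second gradient has the same shape as the first, so the buffer does not grow; what changes is its contents, and whether it stays allocated between iterations.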

2. Keeping Unnecessary Tensors: Holding onto tensors that are no longer needed, especially large ones or those on the GPU, will consume memory. This often happens with intermediate results or data you’ve moved to GPU and forgotten about.

  • Diagnosis: Use torch.cuda.memory_allocated() and torch.cuda.memory_reserved() to track memory usage. If these numbers keep climbing without bound, you’re likely holding onto tensors.
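  • Fix: Explicitly delete tensors using del tensor_name. Follow it with torch.cuda.empty_cache() only if you need the freed blocks handed back to the CUDA driver, for example so another process can use them.
    del large_tensor
    torch.cuda.empty_cache()
    
  • Why it works: del removes the Python reference. Once the last reference is gone, the tensor’s memory returns to PyTorch’s caching allocator, which can reuse it immediately. empty_cache() then releases cached blocks that no tensor is using back to the CUDA driver (not the OS). It cannot free memory that live tensors still occupy, and it does not increase the amount PyTorch itself can allocate.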
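
The difference between the two counters is easier to see in a short experiment. This is a sketch under assumptions: allocated_mib and reserved_mib are hypothetical helpers, the 16 MiB tensor is arbitrary, and on a CPU-only machine both counters read 0.

```python
import torch


def allocated_mib():
    """Memory held by live tensors on the current CUDA device, in MiB."""
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.memory_allocated() / 1024**2


def reserved_mib():
    """Memory held by PyTorch's caching allocator (live + cached), in MiB."""
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.memory_reserved() / 1024**2


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

large_tensor = torch.zeros(4, 1024, 1024, device=device)  # 16 MiB of float32
during = allocated_mib()

del large_tensor  # last reference gone: the block returns to the allocator's cache
after_del = allocated_mib()
reserved_before = reserved_mib()

if torch.cuda.is_available():
    torch.cuda.empty_cache()  # hand unused cached blocks back to the CUDA driver
reserved_after = reserved_mib()

print(f"allocated: {during:.1f} -> {after_del:.1f} MiB")
print(f"reserved:  {reserved_before:.1f} -> {reserved_after:.1f} MiB")
```

On a GPU, the allocated figure should drop right after del, while the reserved figure should only drop after empty_cache().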

3. Large Batch Sizes: The most direct way to increase GPU memory usage is by increasing the batch size. Larger batches mean more data, more intermediate activations, and larger gradients to store.

  • Diagnosis: Monitor nvidia-smi as you increase batch_size. The memory usage will scale roughly linearly.
  • Fix: Reduce batch_size. If you need to process more data, consider gradient accumulation.
    # Instead of one large batch, accumulate gradients over several small ones
    accumulation_steps = 4
    for i, (inputs, targets) in enumerate(dataloader):
        loss = criterion(model(inputs), targets)
        (loss / accumulation_steps).backward()  # backward on every small batch
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()       # one update per accumulation_steps batches
            optimizer.zero_grad()
    
  • Why it works: Gradient accumulation simulates a larger batch size by summing scaled gradients over several smaller batches before performing an optimizer step. Only one small batch’s activations are alive at a time, which keeps the peak memory usage lower.
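
As a sanity check on the technique, the sketch below compares the gradient from one batch of 32 with the gradient accumulated over four batches of 8. The toy nn.Linear model and the names reference and max_diff are assumptions made for illustration; the loss uses the default mean reduction, which is why each small-batch loss is divided by the number of accumulation steps.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)
reference = copy.deepcopy(model)   # identical weights, used for the full batch
criterion = nn.CrossEntropyLoss()  # default reduction="mean"

inputs = torch.randn(32, 16)
targets = torch.randint(0, 4, (32,))

# Reference: one backward pass over the full batch of 32.
criterion(reference(inputs), targets).backward()

# Accumulation: four batches of 8, each loss scaled by 1/4 before backward().
accumulation_steps = 4
chunks = zip(inputs.chunk(accumulation_steps), targets.chunk(accumulation_steps))
for micro_inputs, micro_targets in chunks:
    loss = criterion(model(micro_inputs), micro_targets) / accumulation_steps
    loss.backward()

max_diff = (model.weight.grad - reference.weight.grad).abs().max().item()
print(f"Largest gradient difference: {max_diff:.2e}")
```

The match holds for layers whose output for one sample does not depend on the rest of the batch. Layers such as batch normalization compute statistics per batch, so accumulation is only an approximation there.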

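4. Inefficient Data Loading: If your DataLoader is too slow or not using multiple workers effectively, your GPU sits idle while waiting for data. That is a throughput problem rather than a GPU memory problem, but the two meet when the workaround is to move an entire dataset onto the GPU up front: that consumes VRAM the model needs for activations and gradients. Loading too much at once on the CPU side can also exhaust host RAM.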

  • Diagnosis: Profile your data loading pipeline. If nvidia-smi shows low GPU utilization during training, your data loader is likely the bottleneck.
  • Fix: Optimize num_workers in DataLoader. Ensure your collate_fn is efficient. If you’re loading large datasets, consider using pin_memory=True for faster CPU-to-GPU transfers.
    from torch.utils.data import DataLoader, TensorDataset
    
    # Keep the dataset on the CPU: pinned memory and worker processes expect CPU tensors
    dataset = TensorDataset(inputs.cpu(), targets.cpu())
    # Adjust num_workers based on your CPU cores and workload
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    
  • Why it works: num_workers uses separate processes to load data in parallel, preventing the main training loop from blocking. pin_memory=True places each batch in page-locked host memory, which speeds up the transfer to the GPU and lets it run asynchronously when you copy with .to(device, non_blocking=True).
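
A sketch of the batch-at-a-time pattern this implies, assuming a synthetic in-memory dataset (the sizes and the names use_cuda and batches_seen are illustrative). num_workers is set to 0 here so the snippet runs anywhere; raise it after profiling on your own machine.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"

# The dataset stays in host memory; only one batch at a time is copied to the GPU.
dataset = TensorDataset(torch.randn(256, 1024), torch.randint(0, 10, (256,)))
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,        # 0 loads in the main process; tune after profiling
    pin_memory=use_cuda,  # pinning only helps when a GPU is the destination
)

batches_seen = 0
for batch_inputs, batch_targets in dataloader:
    # non_blocking=True lets the copy overlap with compute when the source is pinned.
    batch_inputs = batch_inputs.to(device, non_blocking=True)
    batch_targets = batch_targets.to(device, non_blocking=True)
    batches_seen += 1

print(f"Processed {batches_seen} batches on {device}")
```

With this layout the GPU holds at most one batch of input data at any moment, regardless of how large the dataset is.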

5. Saving Intermediate Tensors During Training: Sometimes, for debugging or analysis, you might save intermediate tensors. If these are large and kept in memory for too long, they contribute to the memory footprint.

  • Diagnosis: Look for any code that explicitly stores tensors in lists or dictionaries that grow over the training loop.
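  • Fix: Detach these stored tensors from the autograd graph and move them to the CPU (.detach().cpu()) if they don’t need to be on the GPU, or delete them once they’re no longer needed. For scalar values such as the loss, store loss.item() instead of the tensor.
    # Instead of:
    # gpu_intermediates.append(intermediate_tensor)
    # Do:
    cpu_intermediates.append(intermediate_tensor.detach().cpu())
    del intermediate_tensor # Or just let it go out of scope
    
  • Why it works: detach() cuts the link to the autograd graph, so the stored copy no longer keeps that graph and the tensors it saved alive. Moving the copy to the CPU frees GPU VRAM, and dropping the last GPU reference returns that memory to the allocator.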
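
The same idea applied to a logging loop, as a minimal sketch: the toy model, the three-step loop, and the list names loss_history and saved_outputs are all assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)
criterion = nn.CrossEntropyLoss()

loss_history = []   # plain Python floats
saved_outputs = []  # detached CPU copies

for step in range(3):
    inputs = torch.randn(8, 16)
    targets = torch.randint(0, 4, (8,))
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()

    # .item() copies the scalar out; appending `loss` itself would keep a
    # reference to its autograd graph for as long as the list lives.
    loss_history.append(loss.item())
    # detach() cuts the autograd link before the copy to host memory.
    saved_outputs.append(outputs.detach().cpu())

print(loss_history)
```

Nothing in either list holds a reference back into the graph, so each iteration’s intermediate tensors can be released as soon as the iteration ends.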

6. PyTorch Caching Allocator Fragmentation: PyTorch uses a caching allocator to speed up memory allocation. Over time, this can lead to fragmentation, where there’s enough total free memory but not a single contiguous block large enough for a new allocation.

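  • Diagnosis: A torch.cuda.memory_reserved() value far above torch.cuda.memory_allocated() at the moment an allocation fails suggests fragmentation. A gap on its own is normal, since the allocator keeps freed blocks cached for reuse.
  • Fix: Call torch.cuda.empty_cache() at points where the allocation pattern changes, or periodically if that is simpler. Each call makes later allocations go back to the driver, which slows training slightly, so avoid calling it on every step. If fragmentation persists, the allocator can also be tuned through the PYTORCH_CUDA_ALLOC_CONF environment variable (for example with max_split_size_mb).
    # Call this perhaps every N steps or when memory usage plateaus unexpectedly
    if i % 100 == 0:
        torch.cuda.empty_cache()
    
  • Why it works: This forces PyTorch to release all currently unused cached memory back to the CUDA driver, so subsequent allocations can be laid out afresh instead of being squeezed into leftover gaps.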
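
To put a number on the gap between reserved and allocated memory, something like the following can be dropped into a training loop. It is a sketch: cache_report is a hypothetical helper, and on a machine without CUDA it returns zeros.

```python
import torch


def cache_report(device=None):
    """Return (allocated, reserved, cached-but-unused) bytes for a CUDA device."""
    if not torch.cuda.is_available():
        return 0, 0, 0  # nothing to report on a CPU-only machine
    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    return allocated, reserved, reserved - allocated


allocated, reserved, unused = cache_report()
print(f"allocated={allocated} reserved={reserved} cached_but_unused={unused}")

if torch.cuda.is_available():
    # Detailed breakdown from the allocator, useful when the gap above is large.
    print(torch.cuda.memory_summary(abbreviated=True))
```

Logging these three values every few hundred steps shows whether the unused portion is stable (ordinary caching) or growing alongside allocation failures (a hint of fragmentation).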

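If you still see RuntimeError: CUDA out of memory after working through this list, read the rest of the message: it reports the size of the failed allocation along with how much memory is allocated, reserved, and free. A large gap between reserved and allocated points to fragmentation (item 6); otherwise the workload’s peak really does exceed the GPU, and a smaller batch size (item 3) is the first thing to try.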
