FSDP isn’t just about fitting big models onto fewer GPUs; it’s also about keeping training fast at that scale by distributing computation and overlapping communication with it.

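Let’s see FSDP in action. Imagine you’ve got a massive model, say 70 billion parameters, and you want to train it on 8 A100s. Without FSDP, every GPU needs a full copy of the parameters, gradients, and optimizer states, and the weights alone take about 140 GB in 16-bit precision, more than any single A100 offers. With FSDP, you can distribute the model’s parameters, gradients, and optimizer states across those GPUs.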

Here’s a minimal FSDP setup for a simple model:

import functools
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# 1. Initialize distributed environment (essential for FSDP)
# Launch with torchrun, which sets RANK, LOCAL_RANK, and WORLD_SIZE per process.
# Example: torchrun --nproc_per_node=8 your_script.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

# 2. Define your model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 2048)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(2048, 1024)

    def forward(self, x):
        return self.layer2(self.relu(self.layer1(x)))

# 3. Wrap the model with FSDP
model = SimpleModel()

# Define a wrapping policy - for large models, auto-wrapping is key
# This example uses size_based_auto_wrap_policy, which automatically
# wraps modules larger than a certain size. You'd tune `min_num_params`
# based on your model and GPU memory.
# Each Linear layer here has about 2.1M parameters, so a 1M threshold wraps
# both of them; a threshold of 1e9 would leave this toy model as a single unit.
# For a 70B model, you'd likely want to wrap individual layers or blocks.
wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=int(1e6)  # Tune this value
)

# Instantiate FSDP
# Common kwargs for FSDP:
# - cpu_offload (CPUOffload): Keep sharded parameters (and gradients) in CPU
#   memory when they are not in use; the optimizer step then runs on CPU.
# - auto_wrap_policy (Callable): Determines which modules to wrap.
# - sharding_strategy (ShardingStrategy): Controls how parameters, gradients,
#   and optimizer states are sharded. Common options:
#   - FULL_SHARD: Each GPU holds only a shard of everything.
#   - SHARD_GRAD_OP: Shard gradients and optimizer states; parameters stay
#     unsharded between the forward and backward pass.
#   - NO_SHARD: Nothing is sharded, similar to plain data parallelism.
#   - HYBRID_SHARD: Full sharding within a node, replication across nodes.
# - backward_prefetch (BackwardPrefetch): When to prefetch the next parameters
#   during the backward pass.
# - device_id: The GPU this rank's module is moved to.
# - use_orig_params (bool): Expose the original parameters instead of FSDP's
#   flattened ones.

fsdp_model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=None,  # Or CPUOffload(offload_params=True) from torch.distributed.fsdp
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    device_id=torch.cuda.current_device(),
    use_orig_params=True,
)

# 4. Prepare dummy data and optimizer
# Create the optimizer after wrapping so it receives FSDP's parameters.
batch_size = 32
input_data = torch.randn(batch_size, 1024, device=device)
labels = torch.randn(batch_size, 1024, device=device)
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-5)
criterion = nn.MSELoss()

# 5. Training loop (simplified)
# A real job would use a DataLoader with a DistributedSampler
# so that each rank sees a different slice of the data.
for epoch in range(1):
    optimizer.zero_grad()
    outputs = fsdp_model(input_data)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f"Rank {dist.get_rank()}: Loss = {loss.item()}")

dist.destroy_process_group()

FSDP tackles the memory bottleneck by sharding everything (parameters, gradients, and optimizer states) across all participating GPUs. When a GPU needs the full parameters of a wrapped unit for a forward or backward pass, FSDP all-gathers the shards from every GPU, runs the computation, and frees the gathered copy again. The key is that, with sensible wrapping, no single GPU holds the full model or its associated optimizer states at any one time.

The core problem FSDP solves is fitting models that exceed the memory of a single GPU, which is increasingly common with state-of-the-art LLMs. It does this by implementing a sophisticated sharding strategy that distributes the model’s components across available devices.

Here’s how it works internally:

  1. Parameter Sharding: Instead of each GPU having a full copy of the model’s parameters, FSDP divides them. Each GPU is responsible for a shard of the parameters. During the forward pass, when a layer needs its parameters, FSDP gathers the necessary shards from other GPUs onto the current GPU, performs the computation, and then discards the gathered parameters, keeping only its local shard.
  2. Gradient Sharding: Similarly, gradients computed during the backward pass are also sharded. Each GPU computes gradients for the gathered parameters, then a reduce-scatter averages them across GPUs and leaves each GPU holding only the gradients for the parameter shards it owns.
  3. Optimizer State Sharding: The optimizer states (like the first and second moment estimates in Adam) are often the largest memory consumers. FSDP shards these states along with the parameters, meaning each GPU only holds the optimizer states for the parameter shards it’s responsible for.
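
The three steps above can be sketched as a toy simulation on plain Python lists. This is only an illustration of the data movement, not how FSDP is implemented: the "ranks" are list entries, and `WORLD_SIZE`, `shard`, `all_gather`, and `reduce_scatter` are hypothetical names that stand in for the real collectives.

```python
# Toy, single-process simulation of FULL_SHARD data movement (no GPUs, no torch).
WORLD_SIZE = 4  # hypothetical number of ranks


def shard(flat, world_size):
    """Split a flat list into equal contiguous shards, one per rank."""
    assert len(flat) % world_size == 0, "pad the flat list so it divides evenly"
    n = len(flat) // world_size
    return [flat[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_grads):
    """Average full-size gradients across ranks, then keep one shard per rank."""
    world_size = len(per_rank_grads)
    averaged = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    return shard(averaged, world_size)


params = [float(i) for i in range(8)]
param_shards = shard(params, WORLD_SIZE)  # steady state: 2 values per rank

# Forward/backward: temporarily materialize the full parameters on every rank.
gathered = all_gather(param_shards)
assert all(full == params for full in gathered)

# Each rank computes a full-size gradient from its own batch (fake values here).
per_rank_grads = [[float(rank + 1)] * len(params) for rank in range(WORLD_SIZE)]
del gathered  # the gathered copy is freed once the computation is done

# Each rank keeps only the averaged gradients for the shard it owns.
grad_shards = reduce_scatter(per_rank_grads)

# The optimizer step touches only the local shard (plain SGD for brevity).
lr = 0.5
param_shards = [
    [p - lr * g for p, g in zip(p_shard, g_shard)]
    for p_shard, g_shard in zip(param_shards, grad_shards)
]
print(param_shards)
```

At no point after `del gathered` does any rank hold more than its own two parameters, two gradients, and whatever optimizer state belongs to them.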

This fine-grained distribution means that the memory footprint per GPU is drastically reduced, allowing much larger models to be trained. The communication overhead is managed by FSDP, which orchestrates the gathering and scattering of parameter shards, often overlapping communication with computation to minimize latency.
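
To put rough numbers on that reduction, here is a back-of-the-envelope calculation. The 16 bytes per parameter is an assumption (a common rule of thumb for mixed-precision training with Adam), activations and temporary buffers are ignored, and the function names are made up for this sketch.

```python
# Rough per-GPU memory for model state under full sharding.
# Assumed breakdown for mixed-precision Adam (a rule of thumb, not a measurement):
BYTES_PER_PARAM = {
    "bf16 parameters": 2,
    "bf16 gradients": 2,
    "fp32 master parameters": 4,
    "fp32 Adam first moment": 4,
    "fp32 Adam second moment": 4,
}
GB = 1e9  # decimal gigabytes


def training_state_gb(num_params):
    """Total model-state memory across the whole job, in GB."""
    return num_params * sum(BYTES_PER_PARAM.values()) / GB


def per_gpu_gb(num_params, world_size):
    """Model-state memory per GPU when everything is sharded evenly."""
    return training_state_gb(num_params) / world_size


num_params = 70_000_000_000
print(f"unsharded: {training_state_gb(num_params):.0f} GB")
for world_size in (8, 16, 64):
    print(f"{world_size:>3} GPUs: {per_gpu_gb(num_params, world_size):.1f} GB per GPU")
```

Under these assumptions the per-GPU share at 8 GPUs (140 GB) is still larger than an 80 GB A100, so the sharding factor alone may not be enough at that model size; more GPUs or CPU offload would have to close the gap.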

The sharding_strategy parameter is crucial. ShardingStrategy.FULL_SHARD (imported from torch.distributed.fsdp) is the most aggressive, sharding parameters, gradients, and optimizer states. Other strategies exist: SHARD_GRAD_OP shards only gradients and optimizer states, and HYBRID_SHARD combines full sharding within a node with replication across nodes. For massive models, FULL_SHARD is typically the go-to.

The auto_wrap_policy is how FSDP decides which modules to shard independently. For very large models, you don’t want to shard individual nn.Linear layers if they are small. Instead, you want to shard larger blocks like transformer layers or attention mechanisms. size_based_auto_wrap_policy is a common choice, automatically wrapping modules that exceed a certain parameter count threshold, ensuring that FSDP shards meaningful computational units. Tuning min_num_params here is critical to balance memory savings and communication overhead.
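
To see how the threshold changes what gets wrapped, here is a simplified pure-Python model of a size-based decision. It reflects my reading of how the policy walks the module tree and leaves out the real policy's extra rules about specific module types; the dict-based "modules" and the `plan_wrapping` helper are hypothetical, not part of PyTorch.

```python
# Simplified model of size-based auto-wrapping on a tree of plain dicts.
def linear(in_features, out_features):
    """A leaf 'module' with the parameter count of nn.Linear (weight + bias)."""
    return {"own": in_features * out_features + out_features}


def numel(module):
    """Total parameter count of a module subtree."""
    children = module.get("children", {})
    return module.get("own", 0) + sum(numel(c) for c in children.values())


def plan_wrapping(name, module, min_num_params, is_root=True):
    """Return (names of wrapped units, parameters already covered by them)."""
    total = numel(module)
    if total < min_num_params:
        return [], 0  # too small: stays part of its parent's unit
    units, wrapped = [], 0
    for child_name, child in module.get("children", {}).items():
        child_units, child_wrapped = plan_wrapping(
            f"{name}.{child_name}", child, min_num_params, is_root=False
        )
        units += child_units
        wrapped += child_wrapped
    remainder = total - wrapped  # parameters not yet inside a wrapped child
    if not is_root and remainder >= min_num_params:
        return units + [name], total
    return units, wrapped  # the root is wrapped by the outer FSDP call anyway


simple_model = {
    "children": {
        "layer1": linear(1024, 2048),
        "relu": {"own": 0},
        "layer2": linear(2048, 1024),
    }
}

print(plan_wrapping("model", simple_model, 1_000_000)[0])
print(plan_wrapping("model", simple_model, 1_000_000_000)[0])
```

With a threshold of one million parameters both Linear layers become their own units; with a threshold of one billion nothing inside the toy model qualifies, and the whole model is gathered as a single unit by the outer wrapper.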

One aspect that often trips people up is how FSDP handles the optimizer. When using FULL_SHARD, the optimizer doesn’t see the full parameters. It only sees the local shard of parameters on each GPU, so when optimizer.step() is called, it operates on the locally available parameter shard and its corresponding optimizer state shard. This is also why the optimizer should be created after the model is wrapped. By default, FSDP flattens the parameters of each wrapped unit into a single flat parameter, and that is what the optimizer receives. Setting use_orig_params=True exposes the original parameters instead, which matters when you need per-parameter optimizer settings, such as excluding biases from weight decay.
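
Stepping the optimizer on a shard works because Adam-style updates are elementwise: each parameter's update depends only on its own gradient and state. The sketch below checks this with a hand-written Adam step on plain lists. It is an illustration only; `adam_step` is a hypothetical helper (plain Adam, no weight decay), not PyTorch's implementation.

```python
import math


def adam_step(params, grads, m, v, step, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One elementwise Adam update; returns new params and new moment estimates."""
    new_p, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        m_i = beta1 * m_i + (1 - beta1) * g
        v_i = beta2 * v_i + (1 - beta2) * g * g
        m_hat = m_i / (1 - beta1 ** step)  # bias correction
        v_hat = v_i / (1 - beta2 ** step)
        new_p.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_p, new_m, new_v


params = [1.0, -2.0, 3.0, -4.0]
grads = [0.5, 0.25, -1.0, 2.0]
zeros = [0.0] * len(params)

# Reference: one optimizer that sees every parameter.
full_p, _, _ = adam_step(params, grads, zeros, zeros, step=1)

# Sharded: two "ranks", each stepping only its own half with its own state.
half = len(params) // 2
sharded_p = []
for lo in (0, half):
    hi = lo + half
    p_shard, _, _ = adam_step(
        params[lo:hi], grads[lo:hi], zeros[lo:hi], zeros[lo:hi], step=1
    )
    sharded_p += p_shard

print(sharded_p == full_p)
```

Operations that need a global view of the gradients, such as clipping by the total gradient norm, are not elementwise, which is why FSDP provides its own clip_grad_norm_ method on the wrapped model.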

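If you’re running out of memory even with FSDP, the next step is often to explore cpu_offload=CPUOffload(offload_params=True), with CPUOffload imported from torch.distributed.fsdp. This keeps the sharded parameters and gradients in CPU RAM when they are not needed for computation, so the optimizer step and its states live on the CPU as well. CPU RAM is much larger but slower, so this frees up precious GPU VRAM at the cost of throughput.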

The next hurdle you’ll likely face is optimizing communication. While FSDP reduces memory, it increases communication. Understanding and tuning communication primitives, like all_gather and reduce_scatter, or exploring mixed-precision training with FSDP, becomes paramount.
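
A rough way to quantify the extra communication is to count elements moved per GPU per training step. The accounting below is a common back-of-the-envelope estimate, not a measurement: it ignores the (N-1)/N factor of ring collectives and any overlap with computation, and the function names and the 2-byte element size are assumptions made for this sketch.

```python
# Approximate elements communicated per GPU per training step.
def ddp_elements_per_step(num_params):
    """Plain data parallelism: one gradient all-reduce, which costs about
    a reduce-scatter plus an all-gather, so roughly 2x the parameter count."""
    return 2 * num_params


def full_shard_elements_per_step(num_params):
    """Full sharding: all-gather parameters for forward, all-gather them again
    for backward, then reduce-scatter gradients, so roughly 3x."""
    return 3 * num_params


num_params = 70_000_000_000
bytes_per_element = 2  # assuming 16-bit communication

ddp = ddp_elements_per_step(num_params)
full_shard = full_shard_elements_per_step(num_params)
ratio = full_shard / ddp

print(f"data parallel: {ddp * bytes_per_element / 1e9:.0f} GB per step")
print(f"full shard:    {full_shard * bytes_per_element / 1e9:.0f} GB per step")
print(f"ratio:         {ratio:.1f}x")
```

By this estimate full sharding moves about 1.5 times as much data as plain data parallelism, which is why overlapping all_gather and reduce_scatter with computation matters so much.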
