Ray Train’s FSDP support builds on PyTorch’s FullyShardedDataParallel, and it is about more than sharding your model parameters across nodes: gradients and optimizer states are sharded too, and all three together are what reduce memory pressure on individual GPUs.
Let’s see FSDP in action. Imagine a simple PyTorch model, a few layers deep, and we want to train it with FSDP.
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy  # discussed below

# Assume the distributed environment is already set up (e.g., by torchrun or
# Ray Train's TorchTrainer), so the default process group exists and this
# process has already selected its GPU (e.g., via torch.cuda.set_device).
# dist.init_process_group("nccl")  # This would be handled by the training framework


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        return x


device = torch.device("cuda", torch.cuda.current_device())
model = SimpleModel()

# FSDP wrapper; device_id moves the module to this rank's GPU before sharding
fsdp_model = FSDP(model, device_id=device)

# Dummy data and optimizer (create the optimizer after wrapping, so it sees
# the sharded parameters)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=0.001)
data = torch.randn(32, 1024, device=device)
labels = torch.randint(0, 10, (32,), device=device)

# Training step
optimizer.zero_grad()
outputs = fsdp_model(data)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f"Loss: {loss.item()}")
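This code snippet shows the core of FSDP integration. You wrap your existing PyTorch nn.Module with FullyShardedDataParallel and build the optimizer from the wrapped model’s parameters. The FSDP wrapper intercepts forward and backward passes, gathering parameter shards when they are needed and scattering reduced gradients back to the ranks that own them. Because the optimizer only ever sees each rank’s shard, its state ends up sharded as a consequence.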
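Since the snippet above leaves process-group setup to the framework, here is a rough sketch of how the same training step might look when Ray Train does that setup. It assumes a Ray 2.x installation with at least two GPUs available; the worker count, the learning rate, and the function name train_loop_per_worker are illustrative choices, not taken from the original example.

```python
import torch
import torch.nn as nn
import ray.train
import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config):
    # Ray Train has already initialized the process group for this worker.
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))
    # Ask Ray Train to wrap the model in FSDP instead of the default DDP.
    model = ray.train.torch.prepare_model(model, parallel_strategy="fsdp")

    # Create the optimizer after wrapping so it sees the sharded parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    criterion = nn.CrossEntropyLoss()

    device = ray.train.torch.get_device()
    data = torch.randn(32, 1024, device=device)
    labels = torch.randint(0, 10, (32,), device=device)

    optimizer.zero_grad()
    loss = criterion(model(data), labels)
    loss.backward()
    optimizer.step()
    ray.train.report({"loss": loss.item()})


if __name__ == "__main__":
    trainer = TorchTrainer(
        train_loop_per_worker,
        train_loop_config={"lr": 1e-3},  # illustrative value
        scaling_config=ScalingConfig(num_workers=2, use_gpu=True),  # assumes 2 GPUs
    )
    result = trainer.fit()
    print(result.metrics)
```

Extra FSDP constructor arguments can be passed through the parallel_strategy_kwargs argument of prepare_model.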
The problem FSDP solves is the ever-increasing memory demands of large deep learning models. As models grow, the memory required to store model parameters, gradients, and optimizer states on each GPU in a Data Parallel setup becomes prohibitive. FSDP tackles this by sharding these components across all participating GPUs. Instead of each GPU holding a full copy of the model, its gradients, and its optimizer state, each GPU only holds a shard of these.
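To put rough numbers on this, the sketch below estimates the per-GPU memory needed for parameters, gradients, and optimizer state. It assumes FP32 training with Adam (4 bytes per parameter, 4 per gradient, 8 for Adam's two moment buffers), and it ignores activations and the temporarily gathered parameters, so it is a lower bound rather than a measurement. The function name and the 1-billion-parameter model are illustrative.

```python
def per_gpu_state_gib(num_params, world_size, sharded,
                      param_bytes=4, grad_bytes=4, optim_bytes=8):
    """Estimate GiB per GPU for parameters + gradients + optimizer state.

    Assumes FP32 parameters and gradients and Adam's two FP32 moment buffers.
    Activations and transient all-gathered parameters are not counted.
    """
    total_bytes = num_params * (param_bytes + grad_bytes + optim_bytes)
    if sharded:
        # Full sharding: each GPU keeps 1/world_size of every component.
        total_bytes /= world_size
    return total_bytes / 2**30


num_params = 1_000_000_000  # hypothetical 1B-parameter model
replicated = per_gpu_state_gib(num_params, world_size=8, sharded=False)
fully_sharded = per_gpu_state_gib(num_params, world_size=8, sharded=True)
print(f"replicated: {replicated:.1f} GiB per GPU")    # about 14.9 GiB
print(f"sharded:    {fully_sharded:.1f} GiB per GPU")  # about 1.9 GiB
```

Under these assumptions, the persistent state per GPU shrinks in proportion to the number of GPUs.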
Internally, FSDP works by dividing the model into units (often based on nn.Module boundaries or automatically by size). Every GPU still runs the full computation on its own slice of the batch, so during the forward pass, when a unit’s parameters are needed, FSDP all-gathers the shards so that each GPU temporarily holds that unit’s full parameters. Once the unit’s computation is done, each GPU frees the shards it does not own and keeps only the activations needed for the backward pass. During the backward pass, the parameters are all-gathered again, gradients are computed, and a reduce-scatter (rather than the all-reduce used by plain data parallelism) leaves each GPU with the averaged gradient for its own shard only. The optimizer then updates only the shards of parameters residing on each GPU. This "all-gather, reduce-scatter" pattern is key.
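The communication pattern can be illustrated with a toy, single-process simulation. This is not how FSDP is implemented: plain Python lists stand in for tensors, the helper names are made up for illustration, and the "gradients" are hand-picked numbers. It only shows what each rank holds before and after each collective.

```python
def shard(flat, world_size):
    """Split a flat list into equal shards (assumes len is divisible)."""
    size = len(flat) // world_size
    return [flat[r * size:(r + 1) * size] for r in range(world_size)]


def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_grads):
    """Average gradients across ranks, then give each rank only its shard."""
    world_size = len(per_rank_grads)
    averaged = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    return shard(averaged, world_size)


world_size = 2
param_shards = shard([1.0, 2.0, 3.0, 4.0], world_size)  # what each rank stores

# Forward/backward: each rank temporarily sees the full parameters.
gathered = all_gather(param_shards)

# Each rank computes full-size gradients on its own batch (made-up values).
local_grads = [[1.0, 2.0, 3.0, 4.0], [1.0, 0.0, 1.0, 0.0]]

# Reduce-scatter: each rank keeps only the averaged gradient for its shard.
grad_shards = reduce_scatter(local_grads)

# Optimizer step (plain SGD here) touches only the local shard.
lr = 0.5
new_shards = [[p - lr * g for p, g in zip(ps, gs)]
              for ps, gs in zip(param_shards, grad_shards)]
print(new_shards)
```

Note that no rank ever needs the full gradient or the full optimizer state; only the parameters are materialized in full, and only briefly.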
You control FSDP’s behavior primarily through its constructor arguments. The auto_wrap_policy is crucial for defining how the model is segmented. size_based_auto_wrap_policy is a common choice, where FSDP automatically wraps submodules that exceed a certain parameter count threshold (e.g., min_num_params=100_000_000). You can also use transformer_auto_wrap_policy to wrap each transformer block as its own unit. Other important parameters include cpu_offload (to offload parameters, and with them gradients and the optimizer step, to CPU), mixed_precision (to use FP16 or BF16 for parameters, gradient reduction, and buffers), and backward_prefetch (to overlap communication and computation by prefetching the next unit’s parameters during the backward pass).
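As a sketch of how these arguments fit together, the fragment below builds a set of constructor keyword arguments. The specific values (a 1M-parameter threshold, BF16, no CPU offload) are illustrative assumptions, not recommendations, and fsdp_kwargs is just a convenience name.

```python
import functools

import torch
from torch.distributed.fsdp import BackwardPrefetch, CPUOffload, MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

fsdp_kwargs = dict(
    # Wrap any submodule with at least 1M not-yet-wrapped parameters.
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy, min_num_params=1_000_000
    ),
    # Compute and communicate in BF16; sharded parameters stay in FP32.
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    # Set offload_params=True to keep parameters (and gradients) on CPU.
    cpu_offload=CPUOffload(offload_params=False),
    # Prefetch the next unit's parameters before the current gradient computation.
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)

# With a process group initialized and a model in hand:
# fsdp_model = FSDP(model, **fsdp_kwargs)
```

Building the arguments separately from the FSDP call makes it easy to pass the same configuration through a framework wrapper.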
The auto_wrap_policy can seem like a bit of a black box, but size_based_auto_wrap_policy is as simple as it sounds. FSDP walks the module tree bottom-up and wraps a submodule when the number of parameters under it that are not already inside a wrapped descendant reaches min_num_params. The policy does not model communication overhead or how often parameters are reused; that balance is yours to strike through the threshold. Set it too low and you pay for many small all-gathers; set it too high and large units are gathered whole, which raises peak memory. One consequence of counting only not-yet-wrapped parameters is that a parent module can stay unwrapped even though its total size is above the threshold, because most of its parameters already belong to wrapped children.
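To make the size-based rule concrete, here is a simplified pure-Python model of the decision, applied to the parameter counts of the SimpleModel from earlier. It is an approximation written for illustration, not FSDP's actual code: it counts parameters only, ignores the policy's special-cased module types, and the Node class and plan_wrapping function are hypothetical names.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """A stand-in for an nn.Module: a name, its own params, and children."""
    name: str
    own_params: int = 0
    children: list = field(default_factory=list)

    def total_params(self):
        return self.own_params + sum(c.total_params() for c in self.children)


def plan_wrapping(node, min_num_params, wrapped, is_root=True):
    """Record which submodules get wrapped; return params already wrapped."""
    total = node.total_params()
    if total < min_num_params:
        return 0  # too small to wrap or to recurse into
    wrapped_below = sum(
        plan_wrapping(c, min_num_params, wrapped, is_root=False)
        for c in node.children
    )
    remainder = total - wrapped_below  # params not yet in a wrapped child
    if not is_root and remainder >= min_num_params:
        wrapped.append(node.name)
        return total
    # The root is wrapped by the outer FSDP(...) call itself, not by the policy.
    return wrapped_below


model = Node("model", children=[
    Node("layer1", own_params=1024 * 1024 + 1024),  # nn.Linear(1024, 1024)
    Node("relu"),
    Node("layer2", own_params=1024 * 10 + 10),      # nn.Linear(1024, 10)
])

wrapped = []
plan_wrapping(model, min_num_params=1_000_000, wrapped=wrapped)
print(wrapped)

wrapped_large_threshold = []
plan_wrapping(model, min_num_params=100_000_000, wrapped=wrapped_large_threshold)
print(wrapped_large_threshold)
```

With a 1M threshold only layer1 becomes its own unit; with a 100M threshold nothing does, and the whole model is a single unit under the outer wrapper.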
The next challenge you’ll face is managing checkpointing with FSDP, which requires careful coordination to save and load sharded states correctly.