Saving and restoring your Ray training job mid-execution is surprisingly complex because it involves coordinating state across potentially thousands of distributed workers, not just saving a single file.
Let’s walk through a simple example using the ray.train API from Ray 2.7 and later. We’ll train a small PyTorch model for a few epochs, save a checkpoint, and then resume training from that checkpoint.
```python
import os
import tempfile

import ray
import ray.train
import torch
import torch.nn as nn
import torch.optim as optim
from ray.train import Checkpoint, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer, prepare_model

# Initialize Ray (starts a local cluster if none is running)
ray.init(ignore_reinit_error=True)

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# Dummy data, shipped to every worker together with train_fn
data = torch.randn(128, 10)
labels = torch.randint(0, 2, (128,))

# Training function executed by every worker
def train_fn(config):
    model = SimpleModel()
    start_epoch = 0
    state = None

    # Resume from a checkpoint if one was passed in through the config
    resume_checkpoint = config.get("checkpoint")
    if resume_checkpoint:
        with resume_checkpoint.as_directory() as ckpt_dir:
            state = torch.load(os.path.join(ckpt_dir, "checkpoint.pt"))
        model.load_state_dict(state["model_state_dict"])
        start_epoch = state["epoch"] + 1
        print(f"Resuming training from epoch {start_epoch}")

    # Wrap the model in DistributedDataParallel so gradients are synchronized
    model = prepare_model(model)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    if state is not None:
        optimizer.load_state_dict(state["optimizer_state_dict"])
    criterion = nn.CrossEntropyLoss()

    for epoch in range(start_epoch, start_epoch + 5):
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item()}")

        # Save a checkpoint every epoch. Every worker must call report(), but
        # only rank 0 writes files, because all ranks hold identical weights.
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint = None
            if ray.train.get_context().get_world_rank() == 0:
                # Unwrap DDP (if present) so saved keys lack the "module." prefix
                raw_model = getattr(model, "module", model)
                torch.save(
                    {
                        "epoch": epoch,
                        "model_state_dict": raw_model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    os.path.join(tmpdir, "checkpoint.pt"),
                )
                checkpoint = Checkpoint.from_directory(tmpdir)
            # This is the key: Ray Train persists the reported checkpoint
            # directory under storage_path.
            ray.train.report({"epoch": epoch, "loss": loss.item()}, checkpoint=checkpoint)

# Configure a trainer; train_loop_config is what train_fn receives as config
def make_trainer(train_loop_config=None):
    return TorchTrainer(
        train_loop_per_worker=train_fn,
        train_loop_config=train_loop_config or {},
        scaling_config=ScalingConfig(num_workers=2),
        run_config=RunConfig(storage_path="/tmp/ray_checkpoints"),
    )

# First run: train for a bit and save
print("--- Starting first training run ---")
result1 = make_trainer().fit()
latest_checkpoint = result1.checkpoint

print("\n--- Starting second training run (resuming) ---")
# Second run: fit() takes no arguments, so build a new trainer whose
# config carries the latest checkpoint
make_trainer({"checkpoint": latest_checkpoint}).fit()

ray.shutdown()
```
In this example, TorchTrainer orchestrates the training, and train_fn is executed by each of the two workers. Inside train_fn, we check config.get("checkpoint"). If a checkpoint is provided, we load the model and optimizer states from it. Crucially, ray.train.report(..., checkpoint=...) is how the trainer knows to save the current state. The worker writes whatever it needs (model weights, optimizer state, the epoch number, and so on) into a local directory and wraps it with Checkpoint.from_directory. Ray Train then persists that directory to the location specified by storage_path. To resume, we can’t pass anything to trainer.fit(), which takes no arguments. Instead, we build a second trainer whose train_loop_config carries latest_checkpoint, the most recent checkpoint from the first run’s result, so train_fn can find and load it.
The core problem Ray Train solves here is distributed state management. It’s not just about saving a file. It’s about making sure each checkpoint corresponds to one consistent point in training across all workers. ray.train.report acts as a synchronization point: every worker must call it. In data-parallel training, where all ranks hold the same weights, it is common for only rank 0 to attach a checkpoint. Checkpoints are directory-based: a Checkpoint points to a directory of files that you write yourself. For PyTorch, these files typically hold the model’s state_dict() and the optimizer’s state_dict(). If you were using other frameworks or custom state, you’d serialize those objects into the same directory. The storage_path is essential; it’s where Ray Train persists each reported checkpoint directory. By default this is a local directory (~/ray_results), but it can point to cloud storage such as S3 or GCS. Multi-node clusters need shared storage like this so that every node can reach the same checkpoints.
The surprising part for many is how the state is passed back into the training loop. It’s not an automatic injection: Ray Train never loads weights into your model for you. In this example, train_fn explicitly checks its config dictionary for a "checkpoint" key and, if it exists, loads from it. Ray 2.x also provides ray.train.get_checkpoint(). It returns the latest reported checkpoint when a run is restored after a failure, or the checkpoint passed to the trainer’s resume_from_checkpoint argument. Either way, restoring the state is your code’s job. This design gives you fine-grained control. You can decide when to load a checkpoint and how to integrate its state into your existing training variables. This is also how you handle custom state beyond model and optimizer parameters: add it to the dictionary you save into the checkpoint, then read it back from the dictionary you load from the checkpoint directory.
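As an illustration of custom state, the sketch below stores a best-loss value and PyTorch’s global CPU RNG state next to the epoch counter. On resume, random operations then continue exactly where they left off. The helpers save_training_state and load_training_state are hypothetical names for this example. In the training loop above, these entries would simply go into the same dictionary as the model and optimizer state.

```python
import os
import tempfile

import torch


def save_training_state(path, epoch, best_loss):
    """Hypothetical helper: bundle custom state alongside the epoch."""
    torch.save(
        {
            "epoch": epoch,
            "best_loss": best_loss,
            # Snapshot of torch's global CPU RNG, so random ops replay identically.
            "torch_rng_state": torch.get_rng_state(),
        },
        path,
    )


def load_training_state(path):
    """Hypothetical helper: restore the RNG and return (start_epoch, best_loss)."""
    state = torch.load(path)
    torch.set_rng_state(state["torch_rng_state"])
    return state["epoch"] + 1, state["best_loss"]


path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.manual_seed(0)
save_training_state(path, epoch=4, best_loss=0.37)
expected = torch.rand(3)  # what an uninterrupted run would draw next

torch.manual_seed(123)  # simulate a fresh process with a different RNG state
start_epoch, best_loss = load_training_state(path)
resumed = torch.rand(3)  # matches `expected` because the RNG was restored
```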
If your trainer consumes Ray Datasets (configured through ray.train.DataConfig), checkpointing also needs to account for how far the data pipeline has progressed. Don’t assume this is captured automatically: a checkpoint contains only the files you write into it. The straightforward approach, used in the example above, is to checkpoint at epoch boundaries and record the epoch number. A resumed run then starts at the next full epoch without re-processing data you’ve already seen or skipping data. If you need finer-grained resumption, write the iteration state yourself into the checkpoint alongside the model and optimizer, for example the shuffle seed and the number of batches already consumed. You then use it to rebuild the data iterator when you resume.
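Here is a plain-Python sketch of that finer-grained approach. It is not a Ray Data API, and it assumes you control the shuffle order. The functions epoch_order and iter_batches are hypothetical, and the checkpoint fields seed, epoch and batches_done are our own naming. Because the order is derived only from (seed, epoch), a resumed run can regenerate it and skip the batches that were already consumed.

```python
import random


def epoch_order(num_samples, seed, epoch):
    """Deterministic shuffle: the same (seed, epoch) always yields the same order."""
    rng = random.Random(seed + epoch)
    order = list(range(num_samples))
    rng.shuffle(order)
    return order


def iter_batches(num_samples, batch_size, seed, epoch, batches_done=0):
    order = epoch_order(num_samples, seed, epoch)
    batches = [order[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    # Skip the batches consumed before the checkpoint was taken.
    yield from batches[batches_done:]


# Uninterrupted epoch 2 of a 128-sample dataset in batches of 32
full = list(iter_batches(128, 32, seed=7, epoch=2))

# Suppose the checkpoint recorded {"seed": 7, "epoch": 2, "batches_done": 3}
resumed = list(iter_batches(128, 32, seed=7, epoch=2, batches_done=3))
```

The resumed iterator yields exactly the batches the interrupted run had not yet seen. The cost of this design is that the sample order must be reproducible from what you stored. With a shuffle you can’t replay, checkpointing at epoch boundaries is the safer choice.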
After successfully resuming and completing your training, the next hurdle is often managing and cleaning up old checkpoints, especially in long-running or frequently interrupted training jobs.