Ray Train lets you scale your PyTorch and TensorFlow training jobs across multiple machines, but it’s not just about throwing more GPUs at the problem.

Let’s see it in action. Imagine we have a simple PyTorch model and we want to train it on two workers.

import torch
import torch.nn as nn
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer, get_device, prepare_model

# Define a simple PyTorch model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# Define the training function (runs on every worker)
def train_fn(config):
    # prepare_model moves the model to this worker's device and wraps it in
    # DistributedDataParallel so gradients are averaged across workers.
    model = prepare_model(SimpleModel())
    device = get_device()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    # In a real scenario, you'd load data here and give each worker its own shard.
    # For demonstration, each worker generates its own dummy data.
    data_loader = [(torch.randn(16, 10), torch.randint(0, 2, (16,))) for _ in range(10)]

    for epoch in range(config.get("num_epochs", 1)):
        for inputs, labels in data_loader:
            # Move each batch to the same device as the model.
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()  # DDP all-reduces gradients across workers here
            optimizer.step()
        # Report metrics to Ray Train instead of only printing them;
        # the last report is available as results.metrics after fit().
        train.report({"epoch": epoch + 1, "loss": loss.item()})

# Initialize Ray
ray.init()

# Configure scaling
# This tells Ray Train to use 2 workers, each with 1 GPU.
# If you don't have GPUs, set use_gpu=False to run the workers on CPU.
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

# Create a TorchTrainer
trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"num_epochs": 3},
    scaling_config=scaling_config,
)

# Run the training
results = trainer.fit()

print("Training finished!")
print("Final metrics:", results.metrics)
ray.shutdown()

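This code sets up a TorchTrainer. train_loop_per_worker is the function each worker executes, and train_loop_config is passed to it as its config argument. scaling_config dictates how many workers to launch and what resources each one gets (CPUs, GPUs). trainer.fit() starts the distributed run and blocks until it finishes. Ray Train launches the worker processes, sets up the torch.distributed process group that connects them, and collects the metrics each worker reports. Two pieces still come from your training function: gradients are only synchronized if the model is wrapped with ray.train.torch.prepare_model (which applies PyTorch's DistributedDataParallel), and each worker only sees a different slice of the data if you shard it, for example with ray.train.torch.prepare_data_loader or Ray Data.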
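To make the gradient-synchronization step concrete, here is a toy sketch in plain Python (no Ray or PyTorch) of one synchronous data-parallel SGD step for a one-parameter model y_hat = w * x with a mean-squared-error loss. The helpers local_gradient, all_reduce_mean and data_parallel_step are hypothetical stand-ins for what DistributedDataParallel does under the hood, assuming equal-sized shards; they are not Ray Train APIs.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y)


def local_gradient(w: float, shard: List[Sample]) -> float:
    """d/dw of the mean squared error of y_hat = w * x over one worker's shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values: List[float]) -> float:
    """Stand-in for the all-reduce DDP performs: every worker ends up with the average."""
    return sum(values) / len(values)


def data_parallel_step(w: float, batch: List[Sample], num_workers: int, lr: float) -> float:
    """One synchronous data-parallel SGD step (hypothetical helper, equal shards only)."""
    shard_size = len(batch) // num_workers
    assert shard_size * num_workers == len(batch), "sketch assumes equal-sized shards"
    shards = [batch[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)]
    # Each worker computes a gradient on its own shard only.
    grads = [local_gradient(w, shard) for shard in shards]
    # After the all-reduce every worker holds the same averaged gradient,
    # so every model replica applies the identical update and stays in sync.
    return w - lr * all_reduce_mean(grads)


def single_worker_step(w: float, batch: List[Sample], lr: float) -> float:
    """The same SGD step computed on one machine over the whole batch."""
    return w - lr * local_gradient(w, batch)


batch = [(float(x), 3.0 * x) for x in range(1, 9)]  # 8 samples of y = 3x
w_dp = data_parallel_step(0.0, batch, num_workers=2, lr=0.01)
w_single = single_worker_step(0.0, batch, lr=0.01)
print(w_dp, w_single)  # both print the same value, about 1.53
```

Averaging the per-worker mean gradients reproduces the full-batch gradient exactly only when the shards are the same size; with unequal shards the samples in the smaller shard are weighted slightly more.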

The core problem Ray Train solves is the difficulty of scaling deep learning training beyond a single machine. Traditionally, this involves setting up distributed environments by hand, configuring communication backends (such as NCCL or Gloo), and sharding data carefully. Ray Train abstracts much of this away. It builds on Ray's distributed execution framework to manage worker processes, set up inter-worker communication, and provide a consistent API across PyTorch and TensorFlow. You define your training logic much as you would for a single machine, and you scale out by adjusting the ScalingConfig. The default strategy is data parallelism, where every worker holds a full copy of the model and processes a different batch of data; this is the most common starting point. For models that don't fit on one device, sharded or model-parallel approaches such as PyTorch FSDP or DeepSpeed can run inside the same training function.
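Giving each worker a different batch usually comes down to splitting the dataset's indices by worker rank. The sketch below uses a hypothetical helper, shard_indices, that mirrors the strided split performed by PyTorch's torch.utils.data.DistributedSampler with shuffle=False and drop_last=False; in Ray Train, ray.train.torch.prepare_data_loader can add such a sampler to a DataLoader for you. It is written in plain Python so the indexing is easy to inspect.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices that worker `rank` reads (hypothetical helper), mirroring the strided
    split DistributedSampler uses when shuffle=False and drop_last=False."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every worker gets the same
    # number of samples and therefore takes the same number of steps.
    padding = total_size - dataset_len
    if padding > 0:
        indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Worker `rank` takes every num_replicas-th index, starting at `rank`.
    return indices[rank:total_size:num_replicas]


for rank in range(2):
    print(rank, shard_indices(5, num_replicas=2, rank=rank))
# 0 [0, 2, 4]
# 1 [1, 3, 0]
```

Padding with repeated indices keeps every worker on the same number of steps, so no worker waits on a gradient all-reduce that the others never join; the cost is that a few samples are seen twice per epoch.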

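One of the most powerful, yet often overlooked, aspects of Ray Train is its integration with Ray Tune for hyperparameter optimization. Instead of calling trainer.fit() directly, you pass the trainer to tune.Tuner and put the search space under the "train_loop_config" key of param_space, for example param_space={"train_loop_config": {"lr": tune.grid_search([0.01, 0.001])}} (train_fn would then read config["lr"] rather than hard-coding 0.01). Calling tuner.fit() launches one distributed training run, or trial, per configuration, each using the workers and resources defined in the trainer's ScalingConfig. This means you're not just scaling your training; you're also scaling your hyperparameter search, so larger search spaces become practical as long as the cluster has room to run trials side by side. trainer.fit() returns a Result object whose metrics attribute holds the last metrics reported with ray.train.report and whose checkpoint attribute holds the latest saved checkpoint, if any. tuner.fit() instead returns a ResultGrid, and get_best_result(metric="loss", mode="min") picks out the best-performing trial.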
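A quick back-of-the-envelope sketch of what the Tune integration means for cluster resources. expand_grid and max_concurrent_trials are hypothetical helpers, not Ray APIs: the first mimics how several tune.grid_search entries combine into a Cartesian product of trials, and the second estimates how many trials can run at once when each trial is itself a multi-GPU training run, ignoring CPU and memory requirements.

```python
import itertools
from typing import Any, Dict, List


def expand_grid(grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """One trial config per combination of grid values (Cartesian product)."""
    keys = list(grid)
    return [dict(zip(keys, combo)) for combo in itertools.product(*(grid[k] for k in keys))]


def max_concurrent_trials(cluster_gpus: int, num_workers: int, gpus_per_worker: int = 1) -> int:
    """Trials that fit at once if each needs num_workers * gpus_per_worker GPUs."""
    return cluster_gpus // (num_workers * gpus_per_worker)


trials = expand_grid({"lr": [0.1, 0.01, 0.001], "num_epochs": [3, 5]})
print(len(trials))  # 6
print(trials[0])    # {'lr': 0.1, 'num_epochs': 3}
# With 8 GPUs and ScalingConfig(num_workers=2, use_gpu=True), 4 trials fit at once.
print(max_concurrent_trials(cluster_gpus=8, num_workers=2))  # 4
```

The arithmetic shows why the two kinds of scaling multiply: six trials of two GPU workers each need twelve GPUs to run all at once, so on an eight-GPU cluster Tune would queue the remaining trials until resources free up.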

The next step after distributed training is often integrating sophisticated checkpointing and fault tolerance mechanisms.
