Fine-tuning a Hugging Face LLM with Ray Train is surprisingly like teaching a very smart, very expensive parrot to speak a new dialect, except the parrot is actually a massive neural network and the dialect is your specific data.
Let’s see this in action. Imagine we have a base LLM, say meta-llama/Llama-2-7b-hf, and we want to fine-tune it on a dataset of customer support tickets to make it better at answering common questions.
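First, we need to set up our environment. This means installing transformers, datasets, and ray[train], plus two more packages:

- accelerate, which the Hugging Face Trainer requires for PyTorch training.
- bitsandbytes, which is not used in this walkthrough but is needed if you later move to quantized fine-tuning.

meta-llama/Llama-2-7b-hf is a gated model on the Hugging Face Hub. Before it can be downloaded, you also need to accept its license and authenticate, for example with huggingface-cli login.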
```shell
# Quote ray[train] so shells like zsh don't treat the brackets as a glob pattern
pip install transformers datasets "ray[train]" accelerate bitsandbytes
```
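Before going further, it can help to confirm that the packages are importable in the current environment. The following is a minimal standard-library sketch. missing_packages is a hypothetical helper, and the names checked are the import names of the packages installed above (ray[train] is imported as ray). It only checks the machine it runs on, not the Ray worker nodes.

```python
import importlib.util


def missing_packages(module_names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    required = ["transformers", "datasets", "ray", "accelerate", "bitsandbytes", "torch"]
    missing = missing_packages(required)
    print("Missing:", missing if missing else "none")
```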
Now, let’s get our data ready. We’ll use the datasets library to load a small stand-in dataset: a 1% slice of IMDb movie reviews. In a real scenario, you’d load your own CSV or JSON files, or a dataset from the Hugging Face Hub.
```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Load a stand-in dataset (replace with your actual data loading)
dataset = load_dataset("imdb", split="train[:1%]")  # Small slice for demonstration

model_name = "meta-llama/Llama-2-7b-hf"  # Or any other causal LM checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Llama-2 has no padding token. Reusing EOS as the pad token avoids adding a new
# token, which would require model.resize_token_embeddings(len(tokenizer)) later.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def preprocess_function(examples):
    # Assumes the dataset has a "text" column
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

# A causal LM only needs input_ids and attention_mask. IMDb's "label" column is a
# sentiment class, not per-token targets, so drop it along with "text". The
# token-level labels are built from input_ids by the data collator at training time.
# (With padding="max_length" every sequence has exactly 512 tokens, so there are
# no empty sequences to filter out.)
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset.column_names,
)
```
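For causal language modeling, the training labels come from the tokens themselves. DataCollatorForLanguageModeling(mlm=False) works like this:

- It copies input_ids into labels.
- It sets every position holding the pad token to -100, the index PyTorch's cross-entropy loss ignores by default.

The model then shifts the labels by one position internally to predict the next token. The pure-Python sketch below mimics the collator's label step on already-padded token ids. The helper name and the toy ids are illustrative only, not part of any library.

```python
IGNORE_INDEX = -100  # Label value ignored by PyTorch's CrossEntropyLoss by default


def causal_lm_labels(input_ids, pad_token_id):
    """Copy each sequence's token ids into labels, masking padding with IGNORE_INDEX.

    Mimics the label construction of DataCollatorForLanguageModeling(mlm=False).
    When the pad token is the EOS token, genuine EOS tokens get masked as well.
    """
    return [
        [IGNORE_INDEX if tok == pad_token_id else tok for tok in seq]
        for seq in input_ids
    ]


if __name__ == "__main__":
    batch = [[1, 450, 2079, 2, 2], [1, 306, 763, 445, 2]]  # Toy ids; 2 plays the pad role
    print(causal_lm_labels(batch, pad_token_id=2))
    # [[1, 450, 2079, -100, -100], [1, 306, 763, 445, -100]]
```

The example output also shows the caveat from the docstring: when padding reuses EOS, the model is never trained to emit EOS at the end of a real sequence.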
Next, we define our Ray Train training configuration. This is where Ray’s distributed capabilities come into play.
```python
import ray
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from transformers import (
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Start a fresh Ray session (shut down any session left over from earlier cells)
if ray.is_initialized():
    ray.shutdown()
ray.init()

# A first, naive attempt: build everything on the driver.
# This alone loads the full 7B model into driver memory (~28 GB in fp32).
model = AutoModelForCausalLM.from_pretrained(model_name)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=2,  # Small batch size for demonstration
    gradient_accumulation_steps=4,  # Accumulate gradients to simulate a larger batch
    learning_rate=2e-5,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    report_to="none",  # Disable reporting to external services for this example
    fp16=True,  # Mixed precision; needs a CUDA GPU on the machine that builds these args
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    # Builds causal-LM labels from input_ids (pad positions set to -100)
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    # eval_dataset=tokenized_eval_dataset,  # Include if you have an evaluation set
)

def train_fn(config):
    # This closure captures the driver-side `trainer`, so Ray must pickle the whole
    # Trainer (model weights included) and ship it to every worker.
    trainer.train()

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)  # 2 workers, 1 GPU each
run_config = RunConfig(storage_path="./ray_results")
```
Why isn't that enough? TorchTrainer runs the training function on every worker. Before calling it, Ray launches one process per worker and sets up a torch.distributed process group among them (using NCCL when use_gpu=True).

The Trainer above was constructed on the driver, before that process group existed, so it is not configured for distributed data-parallel training. On top of that, Ray has to pickle the entire 7B model along with the train_fn closure and ship it to each worker.

Ray Train's documented pattern for Hugging Face Transformers is therefore to load the model, the data, and the Trainer inside the training function. Each worker then builds its own instance after the process group exists, and the Trainer (through Accelerate) can wrap the model in PyTorch's DistributedDataParallel.

Let's refine the training function to follow that pattern:
```python
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

def train_fn_ray(config):
    # Everything below runs on each worker, after Ray has set up the process group.
    model_name = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # Reuse EOS; no embedding resize needed

    # Caveat: full fine-tuning of a 7B model with plain DDP keeps a complete copy of
    # the weights, gradients and Adam states on every GPU, which is more memory than a
    # single GPU typically has. For a real run, consider a smaller model, LoRA, or
    # DeepSpeed/FSDP configured through TrainingArguments.
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # The preprocessed HF dataset is captured from the driver and serialized by Ray
    # along with this function, so every worker receives the full dataset. The HF
    # Trainer (via Accelerate) then gives each worker a different portion of it.
    # For large datasets, Ray Data is the more efficient option.
    train_data = tokenized_dataset

    training_args = TrainingArguments(
        output_dir="./results_worker",  # Local scratch dir; checkpoints are reported to Ray
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        logging_dir="./logs_worker",
        logging_steps=10,
        save_strategy="epoch",
        report_to="none",
        fp16=True,
        # No need to set local_rank: Ray exports LOCAL_RANK, RANK and WORLD_SIZE on
        # each worker, and TrainingArguments picks them up from the environment.
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    # Report metrics and checkpoints to Ray Train so they appear in the Result object
    trainer.add_callback(RayTrainReportCallback())
    # Ray's integration hook; it matters mainly when training on Ray Data iterables
    trainer = prepare_trainer(trainer)
    trainer.train()
```
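To see why full fine-tuning of a 7B model under plain data parallelism is demanding, here is some back-of-the-envelope arithmetic. It assumes a common rule of thumb for mixed-precision training with Adam, roughly 16 bytes per parameter in total:

- fp32 weights: 4 bytes per parameter.
- fp32 gradients: 4 bytes per parameter.
- Two fp32 Adam moment buffers: 8 bytes per parameter.

The estimate ignores activations and temporary buffers. The function name is hypothetical, and 7e9 is the nominal parameter count.

```python
BYTES_PER_PARAM_ADAM_MIXED = 4 + 4 + 8  # weights + gradients + two Adam moments (fp32)


def full_finetune_gib(num_params, bytes_per_param=BYTES_PER_PARAM_ADAM_MIXED):
    """Rough per-GPU memory (GiB) for model state under plain DDP, excluding activations."""
    return num_params * bytes_per_param / 2**30


if __name__ == "__main__":
    print(f"{full_finetune_gib(7e9):.1f} GiB per GPU")  # 104.3 GiB per GPU
```

Under DDP, every worker holds this full state, so adding workers speeds up training but does not reduce per-GPU memory. Sharding that state across GPUs is what DeepSpeed ZeRO and FSDP are for.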
```python
import os

# Restart Ray with a runtime environment so every worker installs the same
# dependencies. (Optional: skip runtime_env if your worker nodes already have them.)
if ray.is_initialized():
    ray.shutdown()
ray.init(runtime_env={"pip": ["transformers", "datasets", "torch", "accelerate", "bitsandbytes"]})

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)  # 2 workers, 1 GPU each
# An absolute path avoids surprises about the working directory. On a multi-node
# cluster, point this at shared storage (NFS or a cloud URI) that all nodes can reach.
run_config = RunConfig(storage_path=os.path.abspath("./ray_results"))

# TorchTrainer runs train_fn_ray once on each worker, after setting up the
# torch.distributed process group.
trainer_ray = TorchTrainer(
    train_loop_per_worker=train_fn_ray,
    scaling_config=scaling_config,
    run_config=run_config,
    # train_loop_config={...} would be passed to train_fn_ray as its `config` argument.
    # Ray Data datasets would go in datasets={"train": ...}.
)

# Start the training
results = trainer_ray.fit()
print("Training finished!")

# `results` is a ray.train.Result: results.metrics holds the last metrics the workers
# reported and results.checkpoint the latest reported checkpoint. Both are only
# populated if the training function reports them, e.g. via RayTrainReportCallback.
print(results.metrics)
print(results.checkpoint)
```
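With the TrainingArguments and ScalingConfig used above, each optimizer step sees more examples than per_device_train_batch_size suggests. Every worker accumulates gradients over several micro-batches, and DDP then averages across workers. The effective batch is per-device batch × accumulation steps × number of workers (one GPU per worker here). A small sketch of that arithmetic follows; the helper name is hypothetical.

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_workers):
    """Examples contributing to each optimizer step across all data-parallel workers."""
    return per_device_batch * grad_accum_steps * num_workers


if __name__ == "__main__":
    # Values from the TrainingArguments and ScalingConfig above
    print(effective_batch_size(per_device_batch=2, grad_accum_steps=4, num_workers=2))  # 16
```

Changing num_workers therefore changes the effective batch size as well, so a learning rate tuned for one setting may need retuning for another.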
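The core idea is that TorchTrainer takes your training function (model loading, data loading, and the training loop) and runs it on every worker requested in ScalingConfig. Ray launches the worker processes and sets up the torch.distributed process group between them. If you set failure_config=FailureConfig(max_failures=...) in RunConfig, Ray also retries the run after a worker failure. Gradient synchronization itself is done by PyTorch's DistributedDataParallel, which the Hugging Face Trainer applies inside that process group. You configure the number of workers, whether to use GPUs, and where to store results.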
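TorchTrainer runs the same function on every worker, so how does each worker see different data? PyTorch's DistributedSampler illustrates the idea:

1. Pad the index list (by wrapping around) until it divides evenly across ranks.
2. Have each rank take every world_size-th index, starting at its own rank.

The Hugging Face Trainer delegates sharding to Accelerate, whose exact scheme differs in details, but the principle is the same: each rank gets its own subset of the data. The function below is a simplified, hypothetical sketch without shuffling, not library code.

```python
import math


def distributed_indices(dataset_len, rank, world_size):
    """Indices one rank would see, mimicking DistributedSampler(shuffle=False, drop_last=False).

    Assumes dataset_len > 0.
    """
    indices = list(range(dataset_len))
    total_size = math.ceil(dataset_len / world_size) * world_size
    padding = total_size - dataset_len
    # Wrap around to pad, so every rank gets the same number of samples
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    return indices[rank:total_size:world_size]


if __name__ == "__main__":
    for rank in range(2):
        print(rank, distributed_indices(5, rank, world_size=2))
    # 0 [0, 2, 4]
    # 1 [1, 3, 0]
```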
The mental model here is that Ray is the conductor, and your training function is the orchestra’s score. Ray ensures each musician (worker) gets the right notes (data, model weights) and that they play in sync (gradients are averaged correctly). TorchTrainer is Ray Train's entry point for PyTorch-based training, which is why a Hugging Face Trainer, itself built on PyTorch, can run inside it.
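"Gradients are averaged correctly" has a precise meaning. After each backward pass, DDP all-reduces every gradient tensor, so each worker ends up with the element-wise mean of all workers' gradients and applies the same optimizer update. When workers use equal-sized batches and a mean-reduced loss, that average equals the gradient over the combined batch.

The plain-Python sketch below checks this on a one-parameter model y = w·x with mean squared error. Lists stand in for tensors, and mse_grad and allreduce_mean are hypothetical helpers, not DDP's API.

```python
def mse_grad(w, batch):
    """d/dw of mean((w*x - y)^2) over a batch of (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def allreduce_mean(per_worker_grads):
    """Element-wise mean across workers: the result DDP's gradient all-reduce produces."""
    n = len(per_worker_grads)
    return [sum(vals) / n for vals in zip(*per_worker_grads)]


if __name__ == "__main__":
    w = 0.5
    shard_a = [(1.0, 2.0), (2.0, 3.0)]  # Worker 0's micro-batch
    shard_b = [(3.0, 1.0), (4.0, 5.0)]  # Worker 1's micro-batch
    synced = allreduce_mean([[mse_grad(w, shard_a)], [mse_grad(w, shard_b)]])
    full = mse_grad(w, shard_a + shard_b)
    print(synced[0], full)  # -8.0 -8.0
```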
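One aspect that often trips people up is how data is handled. You can hand a datasets.Dataset to your workers by capturing it in the training function or passing it through train_loop_config. In that case Ray serializes it and every worker receives a full copy, which is inefficient for very large datasets.

The idiomatic Ray approach is to use ray.data.Dataset:

1. Load your data with ray.data.read_*, or convert an existing Hugging Face dataset with ray.data.from_huggingface.
2. Pass it to TorchTrainer's datasets argument, for example datasets={"train": ray_ds}.
3. Inside the training function, fetch this worker's portion with ray.train.get_dataset_shard("train") and iterate over it with iter_torch_batches.

Ray Data is designed for distributed loading, sharding, and preprocessing, so each worker streams its own portion of the data instead of receiving a serialized copy of everything. Because such an iterable has no length, the Hugging Face Trainer then needs max_steps set in TrainingArguments, and the Trainer should be wrapped with prepare_trainer.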
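If you follow the Ray Data route, the Trainer receives an iterable without a length, so TrainingArguments needs an explicit max_steps, which counts optimizer steps. Below is one conservative way to derive it. It assumes you know the total row count and that rows are split roughly evenly across workers, and it rounds partial steps down. The helper is hypothetical.

```python
def max_steps_for(num_rows, per_device_batch, grad_accum_steps, num_workers, epochs=1):
    """Optimizer steps covering about `epochs` passes over the data, rounding partial steps down."""
    rows_per_step = per_device_batch * grad_accum_steps * num_workers
    return max(1, num_rows // rows_per_step) * epochs


if __name__ == "__main__":
    # e.g. 10,000 rows with the batch settings used in this walkthrough
    print(max_steps_for(10_000, per_device_batch=2, grad_accum_steps=4, num_workers=2))  # 625
```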
After this, you’ll likely want to evaluate your fine-tuned model and potentially deploy it.