Ray Batch Inference at Scale: Process Millions of Rows
The most surprising thing about processing millions of rows with Ray Batch Inference is how little it actually requires you to change your existing inference code.
Let’s see it in action. Imagine you have a PyTorch model and a function to run inference on a single batch.
import torch
import ray
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint
# Assume model_path points to your saved PyTorch model
model_path = "path/to/your/pytorch_model.pt"
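# torch.load unpickles a full model object here, so only load files you trust.
# Recent PyTorch releases default to weights_only=True, which rejects pickled models.
model = torch.load(model_path, weights_only=False)
model.eval()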
def predict_batch(batch):
    # Convert the pandas DataFrame batch to a float32 tensor.
    # Adjust this line based on your model's input requirements.
    inputs = torch.tensor(batch.values, dtype=torch.float32)
    with torch.no_grad():
        predictions = model(inputs)
    # A pandas UDF must return a DataFrame, so wrap the model outputs in one.
    return pd.DataFrame({"predictions": predictions.numpy().tolist()})
# Initialize Ray (if not already initialized)
if not ray.is_initialized():
    ray.init()
# Create a BatchPredictor instance.
# The predict_batch function will be called on each batch.
# In the Ray 2.x releases that ship BatchPredictor, from_pandas_udf takes only
# the UDF; batch size and CPU/GPU resources are configured on predict() instead.
# For simplicity, we'll assume CPU for now.
batch_predictor = BatchPredictor.from_pandas_udf(predict_batch)
# Assume you have a large dataset loaded into a pandas DataFrame
# For demonstration, let's create a dummy DataFrame
import pandas as pd
import numpy as np
num_rows = 1_000_000
num_features = 10
data = np.random.rand(num_rows, num_features)
df = pd.DataFrame(data, columns=[f"feature_{i}" for i in range(num_features)])
# Run batch inference.
# predict() expects a Ray Dataset, so convert the DataFrame first.
dataset = ray.data.from_pandas(df)
# batch_size is the number of rows handed to predict_batch per call.
predictions = batch_predictor.predict(dataset, batch_size=4096)
# The result is a Ray Dataset whose columns are whatever predict_batch returns.
# Older releases cap to_pandas() at 100,000 rows unless a limit is passed.
predictions_df = predictions.to_pandas(limit=len(df))
print(f"Processed {len(predictions_df)} rows. First 5 predictions:")
print(predictions_df.head())
ray.shutdown()
Under the hood, Ray takes your predict_batch function and distributes it across potentially many workers. It splits the dataset into batches, sends them to the workers, runs predict_batch on each one, and collects the results. Batch size is a crucial performance knob: in the Ray 2.x releases that ship BatchPredictor it is set with the batch_size argument to predict(), counted in rows, and it controls the trade-off between parallelism and per-batch overhead.
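The split, map, and collect flow described above can be sketched on a single machine with nothing but the standard library. This is only an analogy under stated assumptions, not how Ray is implemented: the toy_model function, the run_batched helper, and the thread pool standing in for Ray workers are all hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor


def toy_model(row):
    # Hypothetical stand-in for a real model: sums the features of one row.
    return sum(row)


def predict_batch(batch):
    # Runs the "model" on every row of one batch.
    return [toy_model(row) for row in batch]


def split_into_batches(rows, batch_size):
    # Yields consecutive slices of at most batch_size rows.
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


def run_batched(rows, batch_size, num_workers=4):
    # Threads stand in for Ray workers here. map() preserves batch order,
    # so flattening the results lines predictions up with the input rows.
    batches = list(split_into_batches(rows, batch_size))
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        per_batch = list(pool.map(predict_batch, batches))
    return [pred for batch_preds in per_batch for pred in batch_preds]


rows = [[float(i), float(i + 1)] for i in range(10)]
predictions = run_batched(rows, batch_size=3)
```

The per-batch function never sees the whole dataset, which is why the same function can run unchanged whether there are ten rows or ten million.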
The core problem this solves is the out-of-memory error and the sheer slowness of trying to run inference on millions of rows all at once on a single machine. Batch inference, especially distributed batch inference with Ray, breaks the problem down. Instead of one massive inference call, you have many smaller, manageable inference calls happening in parallel. Ray’s BatchPredictor abstracts away the complexities of data sharding, task scheduling, and result aggregation. You provide the model loading and the per-batch prediction logic, and Ray handles the rest.
The from_pandas_udf method is a convenient way to wrap existing inference logic. Ray serializes predict_batch, along with the model object it references, and ships it to worker processes, which deserialize it and call it on each batch they receive. Batch size is the key control over resource utilization. Larger batches mean fewer, larger inference calls, which is more efficient as long as the model can handle them without exhausting memory. Smaller batches suit memory-constrained environments and spread the work over more, smaller calls.
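To make the trade-off concrete, here is a rough back-of-the-envelope calculation for the dummy dataset used earlier (1,000,000 rows of 10 float64 features). The batch_plan helper is hypothetical, and the byte counts cover raw values only, ignoring any DataFrame or serialization overhead.

```python
import math


def batch_plan(num_rows, bytes_per_row, rows_per_batch):
    # Returns (number of batches, approximate bytes in a full batch).
    num_batches = math.ceil(num_rows / rows_per_batch)
    return num_batches, rows_per_batch * bytes_per_row


# 1,000,000 rows of 10 float64 features: 10 * 8 = 80 bytes per row.
num_rows = 1_000_000
bytes_per_row = 10 * 8

# A 10 MB target corresponds to this many rows per batch.
rows_for_10mb = (10 * 1024 * 1024) // bytes_per_row

small = batch_plan(num_rows, bytes_per_row, rows_per_batch=4096)
large = batch_plan(num_rows, bytes_per_row, rows_per_batch=rows_for_10mb)
print(small)  # (245, 327680)
print(large)  # (8, 10485760)
```

With 4096-row batches there are 245 units of work to spread across workers, each about 320 KB. With 10 MB batches there are only 8, so a cluster with more than 8 workers would leave some of them idle.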
A detail that often trips people up is model loading. Reloading a large model for every batch would be extremely inefficient. Ray's TorchCheckpoint (and its counterparts for other frameworks) addresses this: you create a checkpoint with TorchCheckpoint.from_model() and build the predictor with BatchPredictor.from_checkpoint(checkpoint, TorchPredictor). Each scoring worker then loads the model once from the checkpoint and reuses it for every batch it processes, rather than paying the loading cost repeatedly. This is essential for performance at scale.
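The load-once idea can be illustrated without Ray at all. The sketch below is a minimal, hypothetical example: BatchModel and its hard-coded weights stand in for a real model, and the single instance plays the role of one long-lived worker. The expensive step sits in the constructor, while the per-batch call only reuses state that is already in memory.

```python
class BatchModel:
    # Counts how many times the expensive load step runs, for illustration.
    load_count = 0

    def __init__(self):
        # Expensive setup happens once, when the worker creates the instance.
        BatchModel.load_count += 1
        self.weights = [0.5, 2.0]  # hypothetical stand-in for loaded weights

    def __call__(self, batch):
        # Cheap per-batch work that reuses the already-loaded state.
        return [sum(w * x for w, x in zip(self.weights, row)) for row in batch]


worker_model = BatchModel()  # one instance per worker
batches = [[[1.0, 1.0], [2.0, 0.0]], [[0.0, 3.0]]]
results = [worker_model(batch) for batch in batches]
```

However many batches pass through worker_model, load_count stays at 1, which is the property that matters when the load step takes seconds and a batch takes milliseconds.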
The next step after successfully processing millions of rows is often optimizing the throughput and latency of your batch inference jobs.