Ray Data’s distributed preprocessing pipeline can feel like a black box, but it’s actually a surprisingly straightforward series of steps that process your data in parallel, transforming it from raw files into a format ready for machine learning.
Let’s see it in action. Imagine you have a directory of CSV files, each containing customer transaction data, and you want to aggregate these transactions by customer ID to get the total amount spent.
import ray
import ray.data
# Initialize Ray (if not already running)
if not ray.is_initialized():
    ray.init()
# Assume you have a directory named 'transactions' with CSV files like:
# transactions/part_0.csv
# transactions/part_1.csv
# ...
# each with at least 'customer_id' and 'amount' columns.
# Read all CSV files in the directory into a Ray Dataset
ds = ray.data.read_csv("transactions/")
# Define a function to process each row
def extract_fields(row):
    # Keep only the columns the aggregation needs.
    # In a real pipeline, cleaning or type conversion would also happen here.
    return {"customer_id": row["customer_id"], "amount": float(row["amount"])}
# Apply the transformation.
# `map` applies a function to each row. It is lazy: nothing runs yet.
# Parallelism follows the number of blocks in the dataset
# (use `ds.repartition(n)` to change it).
processed_ds = ds.map(extract_fields)
# Now, group by customer_id and sum the amounts.
# `groupby` is a distributed aggregation that shuffles rows by key.
aggregated_ds = processed_ds.groupby("customer_id").sum("amount")
# Computation runs only when you consume the results, e.g. view or write them:
# print(aggregated_ds.take(5))  # shows 5 aggregated records
# aggregated_ds.write_csv("output/aggregated_transactions")
ray.shutdown()
This code snippet demonstrates a common pattern: reading data, mapping a function over it, and then performing a group-by aggregation. Ray Data handles the distribution of these operations across your Ray cluster.
The core problem Ray Data solves is the bottleneck of preprocessing large datasets on a single machine. Traditional Python scripts often struggle with memory limits and slow processing times when dealing with gigabytes or terabytes of data. Ray Data breaks down your dataset into smaller "blocks" and processes these blocks in parallel across multiple worker nodes.
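To make the block idea concrete, here is a minimal single-machine sketch of the split, process, and combine pattern, assuming amounts arrive as integer cents. The helpers split_into_blocks, process_block, and run_in_parallel are hypothetical names for illustration, and a thread pool stands in for Ray’s workers; this is not how Ray Data is implemented internally.

```python
from concurrent.futures import ThreadPoolExecutor

def split_into_blocks(rows, num_blocks):
    """Split rows into at most num_blocks contiguous, roughly equal blocks."""
    size = max(1, -(-len(rows) // num_blocks))  # ceiling division
    return [rows[i:i + size] for i in range(0, len(rows), size)]

def process_block(block):
    """Per-block work: drop non-positive amounts and convert cents to dollars."""
    return [cents / 100 for cents in block if cents > 0]

def run_in_parallel(rows, num_blocks, max_workers=4):
    blocks = split_into_blocks(rows, num_blocks)
    # Threads stand in for Ray workers; each block is an independent task
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(process_block, blocks))  # keeps block order
    # Concatenate the per-block outputs back into a single result
    return [row for block in results for row in block]
```

Because each block is processed independently, adding workers (or, in Ray’s case, nodes) lets more blocks run at once without changing the per-block logic.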
Internally, Ray Data represents your pipeline as a logical plan, essentially a directed acyclic graph (DAG) of operators. Each operation (like read_csv, map, filter, or a groupby aggregation) becomes an operator in this plan. Ray Data optimizes the plan and then runs it with a streaming executor that schedules Ray tasks across the cluster and passes blocks between operators. Execution is lazy: when you call .map(), Ray doesn’t immediately run the function on your data. It only adds a Map operator to the plan. Similarly, .groupby("customer_id").sum("amount") adds an aggregation operator. The actual computation happens only when you consume the dataset, for example with .take(), .count(), .write_csv(), or by iterating over it.
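The lazy-execution behavior can be illustrated with a toy class that records operations and only runs them when a consuming method is called. LazyDataset and its methods are hypothetical and only mimic the shape of Ray Data’s API; Ray’s real planner, optimizer, and streaming executor are far more involved.

```python
from itertools import islice

class LazyDataset:
    """Toy lazily evaluated dataset; a hypothetical stand-in, not Ray's class."""

    def __init__(self, source, ops=()):
        self._source = source    # zero-argument callable that yields input rows
        self._ops = list(ops)    # the recorded plan: a list of (kind, fn) pairs

    def map(self, fn):
        # Record the operation and return a new dataset; nothing runs yet
        return LazyDataset(self._source, self._ops + [("map", fn)])

    def filter(self, pred):
        return LazyDataset(self._source, self._ops + [("filter", pred)])

    def _execute(self):
        # Only here do the recorded operations actually run, row by row
        for row in self._source():
            keep = True
            for kind, fn in self._ops:
                if kind == "map":
                    row = fn(row)
                elif not fn(row):
                    keep = False
                    break
            if keep:
                yield row

    def take(self, n):
        # Consuming method: triggers execution and stops after n rows
        return list(islice(self._execute(), n))

    def count(self):
        # Consuming method: triggers execution over all rows
        return sum(1 for _ in self._execute())
```

Calling map and filter returns immediately with a longer recorded plan; the functions only run inside take or count, which matches the lazy contract described above.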
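The number of blocks is the crucial lever for parallelism. Blocks are the unit of parallel work, so the block count caps how many tasks an operation like .map() can run at once: more blocks mean more parallelism, but also more per-task scheduling overhead. You generally want at least as many blocks as there are CPU cores across your cluster, and often a few times more, to keep every core busy. Ray Data picks a block count automatically when reading, but you can change it with ds.repartition(n), or, in recent Ray versions, set it at read time with the override_num_blocks argument. For a groupby, Ray Data shuffles rows so that all rows sharing a key end up in the same partition. A heavily skewed key distribution, where a few customers account for most of the rows, can therefore leave one reduce task doing most of the work, so it’s worth checking your key distribution before a large aggregation.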
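To see why key skew matters, the sketch below counts how many rows each reducer would receive when rows are routed by a hash of their key, which is one common way to implement a shuffle (Ray Data’s own shuffle strategy may differ by version). The function names reducer_for and partition_sizes are hypothetical. Whatever the strategy, every row with the same key must meet in one place, so a single hot key cannot be spread out.

```python
import zlib

def reducer_for(key, num_reducers):
    # Stable hash, so every block routes a given key to the same reducer
    return zlib.crc32(str(key).encode()) % num_reducers

def partition_sizes(keys, num_reducers):
    """Count how many rows each reducer would receive after a hash shuffle."""
    sizes = [0] * num_reducers
    for key in keys:
        sizes[reducer_for(key, num_reducers)] += 1
    return sizes

# Skewed input: one customer owns 900 of 1,000 rows
keys = ["hot"] * 900 + [f"c{i}" for i in range(100)]
print(partition_sizes(keys, 4))
```

With 900 of 1,000 rows belonging to one customer, one reducer receives at least 900 rows no matter how many reducers you add, so raising the block count alone won’t balance this aggregation.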
The real strength of Ray Data is its ability to chain these operations efficiently. When you write ds.map(...).filter(...), Ray doesn’t materialize the full output of .map() before filtering. When adjacent operators such as map, filter, and map_batches have compatible resource requirements, its optimizer fuses them into a single operator, so each block passes through both functions inside the same task without an intermediate dataset being stored. This operator fusion, together with filtering early so that less data reaches expensive shuffles like groupby, is key to its performance.
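The difference between running a map and a filter as separate stages and fusing them can be sketched in plain Python. run_unfused, fuse, and run_fused are hypothetical helpers; the point is only that the fused version produces the same result per block without ever building the intermediate mapped dataset.

```python
def run_unfused(blocks, map_fn, pred):
    # Stage 1 materializes the complete mapped output of every block...
    mapped = [[map_fn(row) for row in block] for block in blocks]
    # ...and stage 2 reads that intermediate data back to filter it
    return [[row for row in block if pred(row)] for block in mapped]

def fuse(map_fn, pred):
    """Combine a map and a filter into one function applied per block."""
    def fused_block(block):
        out = []
        for row in block:
            mapped = map_fn(row)
            if pred(mapped):   # filter immediately; no intermediate list
                out.append(mapped)
        return out
    return fused_block

def run_fused(blocks, map_fn, pred):
    fused_block = fuse(map_fn, pred)
    # One pass per block, as if both operators ran inside the same task
    return [fused_block(block) for block in blocks]
```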
One aspect that often trips people up is custom aggregation. Built-in aggregations such as sum, count, mean, min, and max cover the common cases, and map_groups() lets you run an arbitrary Python function over all rows of each group. For true distributed performance on large datasets, however, a custom aggregation (for example one built with ray.data.aggregate.AggregateFn) needs a merge step that is associative and commutative, because partial results computed on different blocks can be combined in any grouping and order. Summing and counting satisfy this. Averaging does not: the mean of two partial means is wrong unless the partitions happen to be the same size. Instead, each partial result should carry intermediate state, such as a (count, sum) pair, which is merged by adding componentwise and turned into an average only at the end. If your aggregation can’t be expressed this way, you might need to rethink your strategy, for instance by using map_groups, which gathers each group’s rows together before applying your function.
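The (count, sum) approach for a per-customer average can be sketched as plain functions: an accumulator built per block, an associative and commutative merge, and a finalize step. The names partial_stats, merge, combine_all, and finalize are hypothetical and only mirror the shape of a custom aggregation, not Ray’s exact AggregateFn interface.

```python
from collections import defaultdict
from functools import reduce

def partial_stats(block):
    """Per-block accumulator: customer_id -> (count, sum)."""
    acc = defaultdict(lambda: (0, 0.0))
    for customer_id, amount in block:
        count, total = acc[customer_id]
        acc[customer_id] = (count + 1, total + amount)
    return dict(acc)

def merge(a, b):
    """Associative, commutative merge of two partial results."""
    out = dict(a)
    for key, (count, total) in b.items():
        c0, t0 = out.get(key, (0, 0.0))
        out[key] = (c0 + count, t0 + total)
    return out

def combine_all(partials):
    # Partial results may arrive in any order; the merge makes that safe
    return reduce(merge, partials, {})

def finalize(acc):
    """Turn (count, sum) state into a mean only at the very end."""
    return {key: total / count for key, (count, total) in acc.items()}
```

For example, if customer "a" has 10.0 in one block and 20.0 and 30.0 in another, averaging the two per-block means (10.0 and 25.0) gives 17.5, while merging (1, 10.0) and (2, 50.0) into (3, 60.0) and then finalizing gives the correct 20.0, whichever order the partial results are merged in.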
The next concept you’ll likely encounter is handling more complex data formats and distributed training integrations.