Ray Serve’s dynamic batching is a surprisingly effective way to boost throughput for your inference workloads by grouping independent requests together.
Let’s see it in action. Imagine we have a simple Keras model that predicts house prices. Without batching, each request would be processed individually, leading to underutilization of our GPU.
from ray import serve
from starlette.requests import Request
import numpy as np
import tensorflow as tf


@serve.deployment(num_replicas=1)
class HousePricePredictor:
    def __init__(self):
        # Load a pre-trained Keras model
        self.model = tf.keras.models.load_model("house_price_model.h5")

    async def __call__(self, request: Request):
        data = await request.json()
        # Convert input data to a NumPy array and then to a TensorFlow tensor
        input_data = tf.constant(np.array(data['features']), dtype=tf.float32)
        predictions = self.model.predict(input_data)
        return {"predictions": predictions.tolist()}


app = serve.run(HousePricePredictor.bind())
Now, let’s enable dynamic batching. Ray Serve exposes it through the @serve.batch decorator rather than a deployment option: we move inference into an async method that accepts a list of inputs and returns a list of results of the same length, and __call__ invokes that method with a single input:
from typing import List

from ray import serve
from starlette.requests import Request
import numpy as np
import tensorflow as tf


@serve.deployment(
    num_replicas=1,
    # A replica must accept at least max_batch_size concurrent requests,
    # otherwise batches can never fill
    max_ongoing_requests=64,
)
class HousePricePredictor:
    def __init__(self):
        self.model = tf.keras.models.load_model("house_price_model.h5")

    # Group up to 32 concurrent calls; wait at most 0.1 s for a batch to fill
    @serve.batch(max_batch_size=32, batch_wait_timeout_s=0.1)
    async def predict_batch(self, inputs: List[np.ndarray]) -> List[list]:
        # Assumes each request carries one feature vector of the same length
        input_data = tf.constant(np.stack(inputs), dtype=tf.float32)
        # The predict method now receives a batch of inputs
        predictions = self.model.predict(input_data)
        # Return exactly one result per request, in the same order
        return predictions.tolist()

    async def __call__(self, request: Request):
        data = await request.json()
        features = np.array(data['features'], dtype=np.float32)
        # Called with a single input; Serve assembles the batch behind the scenes
        prediction = await self.predict_batch(features)
        return {"predictions": prediction}


app = serve.run(HousePricePredictor.bind())
When requests arrive, each call to predict_batch is queued inside the replica instead of running immediately. Once the first request is queued, Serve waits until either max_batch_size requests have accumulated or batch_wait_timeout_s has elapsed, then passes everything collected so far to the method as a single list. Each caller receives only its own element of the returned list. Notice how self.model.predict now operates on a batch of inputs, which is significantly more efficient for hardware like GPUs that excel at parallel computation.
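To make the "dispatch when full or when the timeout expires" rule concrete, here is a minimal pure-asyncio sketch of a batcher. It is a toy written for illustration, not Ray Serve's implementation; the MicroBatcher class and the demo function are hypothetical names, and the parameter names simply mirror the two settings discussed here.

```python
import asyncio
from typing import Any, Callable, List


class MicroBatcher:
    """Toy size-or-timeout batcher; illustrative only, not Ray Serve's code."""

    def __init__(self, handler: Callable[[List[Any]], List[Any]],
                 max_batch_size: int, batch_wait_timeout_s: float):
        self.handler = handler
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s
        self.queue: asyncio.Queue = asyncio.Queue()

    async def submit(self, item: Any) -> Any:
        # Each caller gets a future that resolves to its own result.
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((item, future))
        return await future

    async def run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            # Block until the first request arrives; the timeout starts then.
            batch = [await self.queue.get()]
            deadline = loop.time() + self.batch_wait_timeout_s
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(
                        await asyncio.wait_for(self.queue.get(), remaining)
                    )
                except asyncio.TimeoutError:
                    break
            # One handler call per batch; results map back by position.
            results = self.handler([item for item, _ in batch])
            for (_, future), result in zip(batch, results):
                future.set_result(result)


async def demo() -> List[int]:
    batch_sizes: List[int] = []

    def handler(items: List[int]) -> List[int]:
        batch_sizes.append(len(items))
        return [x * 2 for x in items]

    batcher = MicroBatcher(handler, max_batch_size=3, batch_wait_timeout_s=0.05)
    worker = asyncio.create_task(batcher.run())
    results = await asyncio.gather(*(batcher.submit(i) for i in range(5)))
    worker.cancel()
    assert results == [0, 2, 4, 6, 8]
    return batch_sizes


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Submitting five items at once with a maximum batch size of three produces one full batch of three, followed by a partial batch of two that is released only when the timeout expires.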
The core problem batching solves is the overhead of processing many small, independent requests. Every model invocation carries a fixed cost (e.g., Python function call overhead, framework dispatch, GPU kernel launches, and host-to-device transfers). When requests are batched, this fixed cost is paid once per batch and amortized across all the requests in it. For deep learning models, especially those running on GPUs, the computation itself is also highly parallel, so a single forward pass on a batch of 32 inputs can be many times faster than 32 separate forward passes: the time per pass typically grows much more slowly than the batch size.
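The amortization argument can be put into numbers with a simple linear cost model. This is only a sketch: the model (a fixed cost plus a per-item cost) is an assumption, and FIXED_MS and PER_ITEM_MS are made-up values for illustration, not measurements of any real model.

```python
def batch_latency_ms(batch_size: int, fixed_ms: float, per_item_ms: float) -> float:
    """Time for one forward pass under an assumed linear cost model."""
    return fixed_ms + per_item_ms * batch_size


def throughput_rps(batch_size: int, fixed_ms: float, per_item_ms: float) -> float:
    """Requests per second if batches of this size run back to back."""
    return batch_size / batch_latency_ms(batch_size, fixed_ms, per_item_ms) * 1000.0


# Hypothetical costs chosen for illustration only.
FIXED_MS = 8.0      # per-invocation overhead, paid once per batch
PER_ITEM_MS = 0.5   # marginal cost of one more input in the batch

if __name__ == "__main__":
    for n in (1, 8, 32):
        latency = batch_latency_ms(n, FIXED_MS, PER_ITEM_MS)
        rps = throughput_rps(n, FIXED_MS, PER_ITEM_MS)
        print(f"batch={n:>2}  latency={latency:.1f} ms  throughput={rps:.0f} req/s")
```

With these hypothetical numbers, one batched pass over 32 inputs takes 24 ms, while 32 single-input passes take 272 ms in total. Real models rarely follow a clean linear model, so measure on your own hardware before choosing a batch size.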
The key levers you control are the arguments to @serve.batch:
max_batch_size: This is the upper limit on the number of requests that will be grouped into a single batch. A larger value can increase throughput but also increases latency if requests have to wait for the batch to fill. A batch can never be larger than the number of requests a replica is handling concurrently, so keep the deployment’s max_ongoing_requests at or above max_batch_size.
batch_wait_timeout_s: This is the maximum time Serve will wait, counted from the arrival of the first request in a batch, for the batch to fill up to max_batch_size. If this timeout is reached, the current batch is processed even if it’s smaller than max_batch_size. This is crucial for balancing throughput and latency. A shorter timeout means lower latency but potentially smaller batches and lower throughput.
There is no separate polling interval or allow-small-batch switch to tune: Serve dispatches a batch as soon as it is full or the timeout expires, and partial batches are always allowed. This is what you want in production, because it keeps latency bounded when traffic is light.
A common misconception is that you must always fill max_batch_size for optimal performance. In reality, the optimal max_batch_size and batch_wait_timeout_s depend heavily on your specific model, hardware, and the arrival rate of your requests. For instance, if your requests arrive at a rate that naturally leads to batches of 10-15 requests, setting max_batch_size to 32 with a batch_wait_timeout_s of 0.1 might be perfectly fine; just be aware that such batches are released by the timeout rather than by filling up, so the first request in each batch waits the full 0.1 seconds. If, however, you have bursty traffic, a larger max_batch_size and a slightly longer batch_wait_timeout_s might be beneficial to capture those bursts, but at the cost of increased latency for individual requests within those larger batches. At the other extreme, a timeout close to zero dispatches almost immediately; batches then consist mostly of requests that queued up while the previous batch was being processed, which keeps added latency minimal and can still yield sizable batches under heavy load.
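A back-of-the-envelope estimate can help pick starting values before load testing. The sketch below assumes requests arrive at a steady rate and ignores requests that queue up while a previous batch is running; both function names are hypothetical helpers, not part of Ray Serve.

```python
def expected_batch_size(arrival_rate_rps: float, max_batch_size: int,
                        batch_wait_timeout_s: float) -> float:
    """Rough batch size: the first request plus arrivals during the wait window."""
    return min(float(max_batch_size), 1.0 + arrival_rate_rps * batch_wait_timeout_s)


def max_queue_wait_s(arrival_rate_rps: float, max_batch_size: int,
                     batch_wait_timeout_s: float) -> float:
    """Longest time the first request in a batch waits before dispatch."""
    if arrival_rate_rps <= 0:
        return batch_wait_timeout_s
    # Time needed for the remaining (max_batch_size - 1) requests to arrive.
    time_to_fill_s = (max_batch_size - 1) / arrival_rate_rps
    return min(batch_wait_timeout_s, time_to_fill_s)


if __name__ == "__main__":
    for rate in (20, 120, 500):  # hypothetical arrival rates in requests/second
        size = expected_batch_size(rate, 32, 0.1)
        wait = max_queue_wait_s(rate, 32, 0.1)
        print(f"{rate:>3} req/s -> ~{size:.0f} per batch, first request waits up to {wait * 1000:.0f} ms")
```

Under these assumptions, 120 requests per second with a 0.1 second timeout gives batches of about 13, all released by the timeout, whereas at 500 requests per second the batch fills to 32 before the timeout is reached. Treat the output as a starting point for measurement, not a prediction.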
Once you have dynamic batching configured, the next hurdle is understanding how request ordering and dependencies interact with batch processing.