From 81da1cead37f501091b5c2458182c190bbedb3f4 Mon Sep 17 00:00:00 2001 From: abrar Date: Fri, 28 Nov 2025 07:06:31 +0000 Subject: [PATCH] [Serve] add support for custom batch size function Signed-off-by: abrar --- .../serve/advanced-guides/dyn-req-batch.md | 49 ++++- doc/source/serve/doc_code/batching_guide.py | 91 +++++++++ python/ray/serve/batching.py | 75 +++++++- python/ray/serve/tests/test_batching.py | 176 ++++++++++++++++++ 4 files changed, 384 insertions(+), 7 deletions(-) diff --git a/doc/source/serve/advanced-guides/dyn-req-batch.md b/doc/source/serve/advanced-guides/dyn-req-batch.md index 0ff2b9a4cd84..033e6d0d9e98 100644 --- a/doc/source/serve/advanced-guides/dyn-req-batch.md +++ b/doc/source/serve/advanced-guides/dyn-req-batch.md @@ -30,12 +30,13 @@ emphasize-lines: 11-12 --- ``` -You can supply 3 optional parameters to the decorators. -- `batch_wait_timeout_s` controls how long Serve should wait for a batch once the first request arrives. The default value is 0.01 (10 milliseconds). +You can supply 4 optional parameters to the decorators: - `max_batch_size` controls the size of the batch. The default value is 10. +- `batch_wait_timeout_s` controls how long Serve should wait for a batch once the first request arrives. The default value is 0.01 (10 milliseconds). - `max_concurrent_batches` maximum number of batches that can run concurrently. The default value is 1. +- `batch_size_fn` optional function to compute the effective batch size. If provided, this function takes a list of items and returns an integer representing the batch size. This is useful for batching based on custom metrics such as total nodes in graphs or total tokens in sequences. If `None` (the default), the batch size is computed as `len(batch)`. -Once the first request arrives, the batching decorator waits for a full batch (up to `max_batch_size`) until `batch_wait_timeout_s` is reached. If the timeout is reached, the Serve sends the batch to the model regardless the batch size. +Once the first request arrives, the batching decorator waits for a full batch (up to `max_batch_size`) until `batch_wait_timeout_s` is reached. If the timeout is reached, Serve sends the batch to the model regardless of the batch size. :::{tip} You can reconfigure your `batch_wait_timeout_s` and `max_batch_size` parameters using the `set_batch_wait_timeout_s` and `set_max_batch_size` methods: @@ -50,6 +51,44 @@ end-before: __batch_params_update_end__ Use these methods in the constructor or the `reconfigure` [method](serve-user-config) to control the `@serve.batch` parameters through your Serve configuration file. ::: +## Custom batch size functions + +By default, Ray Serve measures batch size as the number of items in the batch (`len(batch)`). However, in many workloads, the computational cost depends on properties of the items themselves rather than just the count. 
For example: + +- **Graph Neural Networks (GNNs)**: The cost depends on the total number of nodes across all graphs, not the number of graphs +- **Natural Language Processing (NLP)**: Transformer models batch by total token count, not the number of sequences +- **Variable-resolution images**: Memory usage depends on total pixels, not the number of images + +Use the `batch_size_fn` parameter to define a custom metric for batch size: + +### Graph Neural Network example + +The following example shows how to batch graph data by total node count: + +```{literalinclude} ../doc_code/batching_guide.py +--- +start-after: __batch_size_fn_begin__ +end-before: __batch_size_fn_end__ +emphasize-lines: 14-17 +--- +``` + +In this example, `batch_size_fn=lambda graphs: sum(g.num_nodes for g in graphs)` ensures that the batch contains at most 10,000 total nodes, preventing GPU memory overflow regardless of how many individual graphs are in the batch. + +### NLP token batching example + +The following example shows how to batch text sequences by total token count: + +```{literalinclude} ../doc_code/batching_guide.py +--- +start-after: __batch_size_fn_nlp_begin__ +end-before: __batch_size_fn_nlp_end__ +emphasize-lines: 7-10 +--- +``` + +This pattern ensures that the total number of tokens doesn't exceed the model's context window or memory limits. + (serve-streaming-batched-requests-guide)= ## Streaming batched requests @@ -80,7 +119,9 @@ Some inputs within a batch may generate fewer outputs than others. When a partic `max_batch_size` ideally should be a power of 2 (2, 4, 8, 16, ...) because CPUs and GPUs are both optimized for data of these shapes. Large batch sizes incur a high memory cost as well as latency penalty for the first few requests. -Set `batch_wait_timeout_s` considering the end to end latency SLO (Service Level Objective). For example, if your latency target is 150ms, and the model takes 100ms to evaluate the batch, set the `batch_wait_timeout_s` to a value much lower than 150ms - 100ms = 50ms. +When using `batch_size_fn`, set `max_batch_size` based on your custom metric rather than item count. For example, if batching by total nodes in graphs, set `max_batch_size` to your GPU's maximum node capacity (such as 10,000 nodes) rather than a count of graphs. + +Set `batch_wait_timeout_s` considering the end-to-end latency SLO (Service Level Objective). For example, if your latency target is 150ms, and the model takes 100ms to evaluate the batch, set the `batch_wait_timeout_s` to a value much lower than 150ms - 100ms = 50ms. When using batching in a Serve Deployment Graph, the relationship between an upstream node and a downstream node might affect the performance as well. Consider a chain of two models where first model sets `max_batch_size=8` and second model sets `max_batch_size=6`. In this scenario, when the first model finishes a full batch of 8, the second model finishes one batch of 6 and then to fill the next batch, which Serve initially only partially fills with 8 - 6 = 2 requests, leads to incurring latency costs. The batch size of downstream models should ideally be multiples or divisors of the upstream models to ensure the batches work optimally together. 
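+For reference, the following is a minimal sketch of the sizing advice above. The 10,000-node budget is an assumed figure and `NodeBudgetedModel` is a hypothetical deployment used only for illustration:
+
+```python
+from typing import List
+
+from ray import serve
+
+
+@serve.deployment
+class NodeBudgetedModel:
+    # Assumption: the GPU fits roughly 10,000 graph nodes per forward pass, so
+    # max_batch_size is expressed in total nodes rather than in number of requests.
+    @serve.batch(
+        max_batch_size=10_000,
+        batch_wait_timeout_s=0.05,
+        batch_size_fn=lambda graphs: sum(g["num_nodes"] for g in graphs),
+    )
+    async def predict(self, graphs: List[dict]) -> List[int]:
+        # Each request contributes its node count toward the 10,000-node budget.
+        return [g["num_nodes"] for g in graphs]
+
+    async def __call__(self, graph: dict) -> int:
+        return await self.predict(graph)
+```
+
+The same idea applies to token budgets: keep `max_batch_size` in the same unit that `batch_size_fn` returns.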
diff --git a/doc/source/serve/doc_code/batching_guide.py b/doc/source/serve/doc_code/batching_guide.py index 5ae7852a5723..f78a929dfdb9 100644 --- a/doc/source/serve/doc_code/batching_guide.py +++ b/doc/source/serve/doc_code/batching_guide.py @@ -152,3 +152,94 @@ def issue_request(max) -> List[str]: chunks_list = [fut.result() for fut in futs] for max, chunks in zip(requested_maxes, chunks_list): assert chunks == [str(i) for i in range(max)] + +# __batch_size_fn_begin__ +from typing import List + +from ray import serve +from ray.serve.handle import DeploymentHandle + + +class Graph: + """Simple graph data structure for GNN workloads.""" + + def __init__(self, num_nodes: int, node_features: list): + self.num_nodes = num_nodes + self.node_features = node_features + + +@serve.deployment +class GraphNeuralNetwork: + @serve.batch( + max_batch_size=10000, # Maximum total nodes per batch + batch_wait_timeout_s=0.1, + batch_size_fn=lambda graphs: sum(g.num_nodes for g in graphs), + ) + async def predict(self, graphs: List[Graph]) -> List[float]: + """Process a batch of graphs, batching by total node count.""" + # The batch_size_fn ensures that the total number of nodes + # across all graphs in the batch doesn't exceed max_batch_size. + # This prevents GPU memory overflow. + results = [] + for graph in graphs: + # Your GNN model inference logic here + # For this example, just return a simple score + score = float(graph.num_nodes * 0.1) + results.append(score) + return results + + async def __call__(self, graph: Graph) -> float: + return await self.predict(graph) + + +handle: DeploymentHandle = serve.run(GraphNeuralNetwork.bind()) + +# Create test graphs with varying node counts +graphs = [ + Graph(num_nodes=100, node_features=[1.0] * 100), + Graph(num_nodes=5000, node_features=[2.0] * 5000), + Graph(num_nodes=3000, node_features=[3.0] * 3000), +] + +# Send requests - they'll be batched by total node count +results = [handle.remote(g).result() for g in graphs] +print(f"Results: {results}") +# __batch_size_fn_end__ + +# __batch_size_fn_nlp_begin__ +from typing import List + +from ray import serve +from ray.serve.handle import DeploymentHandle + + +@serve.deployment +class TokenBatcher: + @serve.batch( + max_batch_size=512, # Maximum total tokens per batch + batch_wait_timeout_s=0.1, + batch_size_fn=lambda sequences: sum(len(s.split()) for s in sequences), + ) + async def process(self, sequences: List[str]) -> List[int]: + """Process text sequences, batching by total token count.""" + # The batch_size_fn ensures total tokens don't exceed max_batch_size. + # This is useful for transformer models with fixed context windows. 
+ return [len(seq.split()) for seq in sequences] + + async def __call__(self, sequence: str) -> int: + return await self.process(sequence) + + +handle: DeploymentHandle = serve.run(TokenBatcher.bind()) + +# Create sequences with different lengths +sequences = [ + "This is a short sentence", + "This is a much longer sentence with many more words to process", + "Short", +] + +# Send requests - they'll be batched by total token count +results = [handle.remote(seq).result() for seq in sequences] +print(f"Token counts: {results}") +# __batch_size_fn_nlp_end__ diff --git a/python/ray/serve/batching.py b/python/ray/serve/batching.py index ab16fe47e962..e7c312116d1f 100644 --- a/python/ray/serve/batching.py +++ b/python/ray/serve/batching.py @@ -109,6 +109,7 @@ def __init__( batch_wait_timeout_s: float, max_concurrent_batches: int, handle_batch_func: Optional[Callable] = None, + batch_size_fn: Optional[Callable[[List], int]] = None, ) -> None: """Async queue that accepts individual items and returns batches. @@ -128,11 +129,18 @@ def __init__( max_concurrent_batches: max number of batches to run concurrently. handle_batch_func(Optional[Callable]): callback to run in the background to handle batches if provided. + batch_size_fn(Optional[Callable[[List], int]]): optional function to + compute the effective batch size. If None, uses len(batch). + The function takes a list of requests and returns an integer + representing the batch size. This is useful for batching based + on custom metrics such as total nodes in graphs, total tokens + in sequences, etc. """ self.queue: asyncio.Queue[_SingleRequest] = asyncio.Queue() self.max_batch_size = max_batch_size self.batch_wait_timeout_s = batch_wait_timeout_s self.max_concurrent_batches = max_concurrent_batches + self.batch_size_fn = batch_size_fn self.semaphore = asyncio.Semaphore(max_concurrent_batches) self.requests_available_event = asyncio.Event() self.tasks: Set[asyncio.Task] = set() @@ -174,6 +182,23 @@ def put(self, request: Tuple[_SingleRequest, asyncio.Future]) -> None: self.queue.put_nowait(request) self.requests_available_event.set() + def _compute_batch_size(self, batch: List[_SingleRequest]) -> int: + """Compute the effective batch size using batch_size_fn or len().""" + if self.batch_size_fn is None: + return len(batch) + + # Extract the actual data items from requests to pass to batch_size_fn. + # We need to reconstruct the original arguments from flattened_args. + items = [] + for request in batch: + # Recover the original arguments from flattened format + args, kwargs = recover_args(request.flattened_args) + # The batch function expects a single positional argument (the item) + # after 'self' has been extracted (if it was a method) + items.append(args[0]) + + return self.batch_size_fn(items) + async def wait_for_batch(self) -> List[_SingleRequest]: """Wait for batch respecting self.max_batch_size and self.timeout_s. @@ -207,8 +232,31 @@ async def wait_for_batch(self) -> List[_SingleRequest]: pass # Add all new arrivals to the batch. 
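+            # When batch_size_fn is provided, the loop below measures the batch with
+            # the custom metric and stops before the next item would push it past
+            # max_batch_size; that item is deferred back to the queue for the next batch.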
- while len(batch) < max_batch_size and not self.queue.empty(): - batch.append(self.queue.get_nowait()) + # Track items we need to put back if they don't fit + deferred_item = None + + # Custom batch size function logic + if self.batch_size_fn is not None: + while not self.queue.empty(): + next_item = self.queue.get_nowait() + # Temporarily add to check size + batch.append(next_item) + new_size = self._compute_batch_size(batch) + + if new_size > max_batch_size: + # Would exceed limit, remove it and save for later + batch.pop() + deferred_item = next_item + break + # Size is OK, keep it in the batch (already added above) + else: + # Default behavior: use original len() check logic + while len(batch) < max_batch_size and not self.queue.empty(): + batch.append(self.queue.get_nowait()) + + # Put deferred item back in queue for next batch + if deferred_item is not None: + self.queue.put_nowait(deferred_item) # Only clear the put event if the queue is empty. If it's not empty # we can start constructing a new batch immediately in the next loop. @@ -219,9 +267,10 @@ async def wait_for_batch(self) -> List[_SingleRequest]: if self.queue.empty(): self.requests_available_event.clear() + current_batch_size = self._compute_batch_size(batch) if ( time.time() - batch_start_time >= batch_wait_timeout_s - or len(batch) >= max_batch_size + or current_batch_size >= max_batch_size ): break @@ -409,12 +458,14 @@ def __init__( batch_wait_timeout_s: float = 0.0, max_concurrent_batches: int = 1, handle_batch_func: Optional[Callable] = None, + batch_size_fn: Optional[Callable[[List], int]] = None, ): self._queue: Optional[_BatchQueue] = None self.max_batch_size = max_batch_size self.batch_wait_timeout_s = batch_wait_timeout_s self.max_concurrent_batches = max_concurrent_batches self.handle_batch_func = handle_batch_func + self.batch_size_fn = batch_size_fn @property def queue(self) -> _BatchQueue: @@ -428,6 +479,7 @@ def queue(self) -> _BatchQueue: self.batch_wait_timeout_s, self.max_concurrent_batches, self.handle_batch_func, + self.batch_size_fn, ) return self._queue @@ -516,6 +568,13 @@ def _validate_max_concurrent_batches(max_concurrent_batches: int) -> None: ) +def _validate_batch_size_fn(batch_size_fn: Optional[Callable[[List], int]]) -> None: + if batch_size_fn is not None and not callable(batch_size_fn): + raise TypeError( + f"batch_size_fn must be a callable or None, got {type(batch_size_fn)}" + ) + + SelfType = TypeVar("SelfType", contravariant=True) T = TypeVar("T") R = TypeVar("R") @@ -564,6 +623,7 @@ def batch( max_batch_size: int = 10, batch_wait_timeout_s: float = 0.01, max_concurrent_batches: int = 1, + batch_size_fn: Optional[Callable[[List], int]] = None, ) -> "_BatchDecorator": ... @@ -601,6 +661,7 @@ def batch( max_batch_size: int = 10, batch_wait_timeout_s: float = 0.01, max_concurrent_batches: int = 1, + batch_size_fn: Optional[Callable[[List], int]] = None, ) -> Callable: """Converts a function to asynchronously handle batches. @@ -652,6 +713,12 @@ async def __call__(self, request: Request): executed concurrently. If the number of concurrent batches exceeds this limit, the batch handler will wait for a batch to complete before sending the next batch to the underlying function. + batch_size_fn: optional function to compute the effective batch size. + If provided, this function takes a list of items and returns an + integer representing the batch size. 
This is useful for batching + based on custom metrics such as total nodes in graphs, total tokens + in sequences, or other domain-specific measures. If None, the batch + size is computed as len(batch). """ # `_func` will be None in the case when the decorator is parametrized. # See the comment at the end of this function for a detailed explanation. @@ -667,6 +734,7 @@ async def __call__(self, request: Request): _validate_max_batch_size(max_batch_size) _validate_batch_wait_timeout_s(batch_wait_timeout_s) _validate_max_concurrent_batches(max_concurrent_batches) + _validate_batch_size_fn(batch_size_fn) def _batch_decorator(_func): lazy_batch_queue_wrapper = _LazyBatchQueueWrapper( @@ -674,6 +742,7 @@ def _batch_decorator(_func): batch_wait_timeout_s, max_concurrent_batches, _func, + batch_size_fn, ) async def batch_handler_generator( diff --git a/python/ray/serve/tests/test_batching.py b/python/ray/serve/tests/test_batching.py index ac149e79ad13..21c2000efcd2 100644 --- a/python/ray/serve/tests/test_batching.py +++ b/python/ray/serve/tests/test_batching.py @@ -402,6 +402,182 @@ def do_request(): ), f"Expected 6 unique request IDs, got {len(request_ids_in_batch_context)}" +def test_batch_size_fn_simple(serve_instance): + """Test batch_size_fn with a simple custom batch size metric.""" + + @serve.deployment + class BatchSizeFnExample: + def __init__(self): + self.batches_received = [] + + @serve.batch( + max_batch_size=100, # Set based on total size, not count + batch_wait_timeout_s=0.5, + batch_size_fn=lambda items: sum(item["size"] for item in items), + ) + async def handle_batch(self, requests: List): + # Record the batch for verification + self.batches_received.append(requests) + # Return results + return [req["value"] * 2 for req in requests] + + async def __call__(self, request): + return await self.handle_batch(request) + + def get_batches(self): + return self.batches_received + + handle = serve.run(BatchSizeFnExample.bind()) + + # Send requests with different sizes + # Request 1: size=30, value=1 + # Request 2: size=40, value=2 + # Request 3: size=20, value=3 + # Request 4: size=25, value=4 + # Total of first 3 = 90 (< 100), but adding 4th would be 115 (> 100) + requests = [ + {"size": 30, "value": 1}, + {"size": 40, "value": 2}, + {"size": 20, "value": 3}, + {"size": 25, "value": 4}, + ] + + result_futures = [handle.remote(req) for req in requests] + results = [future.result() for future in result_futures] + + # Verify results are correct + assert results == [2, 4, 6, 8] + + # Verify batching behavior + batches = handle.get_batches.remote().result() + # Should have created at least one batch + assert len(batches) > 0 + + +def test_batch_size_fn_graph_nodes(serve_instance): + """Test batch_size_fn with a GNN-style use case (batching by total nodes).""" + + class Graph: + def __init__(self, num_nodes: int, graph_id: int): + self.num_nodes = num_nodes + self.graph_id = graph_id + + @serve.deployment + class GraphBatcher: + def __init__(self): + self.batch_sizes = [] + + @serve.batch( + max_batch_size=100, # Max 100 nodes per batch + batch_wait_timeout_s=0.5, + batch_size_fn=lambda graphs: sum(g.num_nodes for g in graphs), + ) + async def process_graphs(self, graphs: List[Graph]): + # Record batch size (total nodes) + total_nodes = sum(g.num_nodes for g in graphs) + self.batch_sizes.append(total_nodes) + # Return graph_id * num_nodes as result + return [g.graph_id * g.num_nodes for g in graphs] + + async def __call__(self, graph): + return await self.process_graphs(graph) + + def 
get_batch_sizes(self): + return self.batch_sizes + + handle = serve.run(GraphBatcher.bind()) + + # Create graphs with different node counts + # Graph 1: 30 nodes, Graph 2: 40 nodes, Graph 3: 35 nodes, Graph 4: 50 nodes + # First 3 total = 105 nodes (> 100), so should be 2 batches + graphs = [ + Graph(num_nodes=30, graph_id=1), + Graph(num_nodes=40, graph_id=2), + Graph(num_nodes=35, graph_id=3), + Graph(num_nodes=50, graph_id=4), + ] + + result_futures = [handle.remote(g) for g in graphs] + results = [future.result() for future in result_futures] + + # Verify results + assert results == [30, 80, 105, 200] + + # Verify batch sizes respect the limit + batch_sizes = handle.get_batch_sizes.remote().result() + for batch_size in batch_sizes: + # Each batch should have <= 100 nodes + assert batch_size <= 100, f"Batch size {batch_size} exceeds limit of 100" + + +def test_batch_size_fn_token_count(serve_instance): + """Test batch_size_fn with an NLP-style use case (batching by total tokens).""" + + @serve.deployment + class TokenBatcher: + @serve.batch( + max_batch_size=1000, # Max 1000 tokens per batch + batch_wait_timeout_s=0.5, + batch_size_fn=lambda sequences: sum(len(s.split()) for s in sequences), + ) + async def process_sequences(self, sequences: List[str]): + # Return word count for each sequence + return [len(s.split()) for s in sequences] + + async def __call__(self, sequence): + return await self.process_sequences(sequence) + + handle = serve.run(TokenBatcher.bind()) + + # Create sequences with different token counts + sequences = [ + "This is a short sequence", # 5 tokens + "This is a much longer sequence with many more words in it", # 12 tokens + "Short", # 1 token + "A B C D E F G H I J", # 10 tokens + ] + + result_futures = [handle.remote(s) for s in sequences] + results = [future.result() for future in result_futures] + + # Verify results are correct + assert results == [5, 12, 1, 10] + + +def test_batch_size_fn_validation(): + """Test that batch_size_fn validation works correctly.""" + from ray.serve.batching import batch + + # Test with non-callable batch_size_fn + with pytest.raises(TypeError, match="batch_size_fn must be a callable or None"): + + @batch(batch_size_fn="not_a_function") + async def my_batch_handler(items): + return items + + +def test_batch_size_fn_default_behavior(serve_instance): + """Test that default behavior (batch_size_fn=None) still works as expected.""" + + @serve.deployment + class DefaultBatcher: + @serve.batch(max_batch_size=5, batch_wait_timeout_s=0.5) + async def handle_batch(self, requests): + return [r * 2 for r in requests] + + async def __call__(self, request): + return await self.handle_batch(request) + + handle = serve.run(DefaultBatcher.bind()) + + # Send 10 requests + result_futures = [handle.remote(i) for i in range(10)] + results = [future.result() for future in result_futures] + + # Verify all results are correct + assert results == [i * 2 for i in range(10)] + + if __name__ == "__main__": import sys