49 changes: 45 additions & 4 deletions doc/source/serve/advanced-guides/dyn-req-batch.md
@@ -30,12 +30,13 @@ emphasize-lines: 11-12
---
```

You can supply 4 optional parameters to the decorators:
- `max_batch_size` controls the size of the batch. The default value is 10.
- `batch_wait_timeout_s` controls how long Serve should wait for a batch once the first request arrives. The default value is 0.01 (10 milliseconds).
- `max_concurrent_batches` controls the maximum number of batches that can run concurrently. The default value is 1.
- `batch_size_fn` is an optional function that computes the effective batch size. If provided, it takes the list of batched items and returns an integer representing the batch size. This is useful for batching on custom metrics such as total nodes in graphs or total tokens in sequences. If `None` (the default), the batch size is computed as `len(batch)`.

Once the first request arrives, the batching decorator waits for a full batch (up to `max_batch_size`) until `batch_wait_timeout_s` is reached. If the timeout is reached, Serve sends the batch to the model regardless of the batch size.
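
As a quick reference, here is a minimal sketch of a deployment that sets all four parameters explicitly; the deployment name and the uppercase transform are placeholders rather than part of the guide's example code:

```python
from typing import List

from ray import serve


@serve.deployment
class BatchedTextModel:
    @serve.batch(
        max_batch_size=16,
        batch_wait_timeout_s=0.05,
        max_concurrent_batches=2,
        batch_size_fn=None,  # None (the default) means batch size is len(batch)
    )
    async def handle_batch(self, inputs: List[str]) -> List[str]:
        # Placeholder for running the model over the whole batch at once.
        return [text.upper() for text in inputs]

    async def __call__(self, text: str) -> str:
        return await self.handle_batch(text)
```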

:::{tip}
You can reconfigure your `batch_wait_timeout_s` and `max_batch_size` parameters using the `set_batch_wait_timeout_s` and `set_max_batch_size` methods:
@@ -50,6 +51,44 @@ end-before: __batch_params_update_end__
Use these methods in the constructor or the `reconfigure` [method](serve-user-config) to control the `@serve.batch` parameters through your Serve configuration file.
:::
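
As a rough sketch of the pattern described in the tip above (the `user_config` keys below are illustrative, not a required schema), a deployment can adjust both parameters in `reconfigure`:

```python
from typing import Any, Dict, List

from ray import serve


@serve.deployment
class ConfigurableBatcher:
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.01)
    async def handle_batch(self, inputs: List[int]) -> List[int]:
        return [value * 2 for value in inputs]

    def reconfigure(self, user_config: Dict[str, Any]):
        # The keys read from user_config here are illustrative; match them
        # to whatever your Serve configuration file provides.
        self.handle_batch.set_max_batch_size(user_config.get("max_batch_size", 8))
        self.handle_batch.set_batch_wait_timeout_s(
            user_config.get("batch_wait_timeout_s", 0.01)
        )

    async def __call__(self, value: int) -> int:
        return await self.handle_batch(value)
```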

## Custom batch size functions

By default, Ray Serve measures batch size as the number of items in the batch (`len(batch)`). However, in many workloads, the computational cost depends on properties of the items themselves rather than just the count. For example:

- **Graph Neural Networks (GNNs)**: The cost depends on the total number of nodes across all graphs, not the number of graphs
- **Natural Language Processing (NLP)**: Transformer models batch by total token count, not the number of sequences
- **Variable-resolution images**: Memory usage depends on total pixels, not the number of images

Use the `batch_size_fn` parameter to define a custom metric for batch size:

### Graph Neural Network example

The following example shows how to batch graph data by total node count:

```{literalinclude} ../doc_code/batching_guide.py
---
start-after: __batch_size_fn_begin__
end-before: __batch_size_fn_end__
emphasize-lines: 14-17
---
```

In this example, `batch_size_fn=lambda graphs: sum(g.num_nodes for g in graphs)` ensures that the batch contains at most 10,000 total nodes, preventing GPU memory overflow regardless of how many individual graphs are in the batch.

### NLP token batching example

The following example shows how to batch text sequences by total token count:

```{literalinclude} ../doc_code/batching_guide.py
---
start-after: __batch_size_fn_nlp_begin__
end-before: __batch_size_fn_nlp_end__
emphasize-lines: 7-10
---
```

This pattern ensures that the total number of tokens doesn't exceed the model's context window or memory limits.
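The whitespace `split()` in the example is only a stand-in for real tokenization. A sketch of the same idea using a Hugging Face tokenizer (assuming the `transformers` package and the `bert-base-uncased` checkpoint, which are not part of the guide's example code) could look like this:

```python
from typing import List

from transformers import AutoTokenizer

from ray import serve

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def total_token_count(sequences: List[str]) -> int:
    # Measure batch size in real subword tokens rather than whitespace words.
    return sum(len(tokenizer.encode(seq)) for seq in sequences)


@serve.deployment
class TokenizerAwareBatcher:
    @serve.batch(
        max_batch_size=512,  # maximum total tokens per batch
        batch_wait_timeout_s=0.1,
        batch_size_fn=total_token_count,
    )
    async def process(self, sequences: List[str]) -> List[int]:
        return [len(tokenizer.encode(seq)) for seq in sequences]

    async def __call__(self, sequence: str) -> int:
        return await self.process(sequence)
```
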

(serve-streaming-batched-requests-guide)=

## Streaming batched requests
@@ -80,7 +119,9 @@ Some inputs within a batch may generate fewer outputs than others. When a partic

`max_batch_size` should ideally be a power of 2 (2, 4, 8, 16, ...) because CPUs and GPUs are both optimized for data of these shapes. Large batch sizes incur a high memory cost as well as a latency penalty for the first few requests.

When using `batch_size_fn`, set `max_batch_size` based on your custom metric rather than item count. For example, if batching by total nodes in graphs, set `max_batch_size` to your GPU's maximum node capacity (such as 10,000 nodes) rather than a count of graphs.

Set `batch_wait_timeout_s` considering the end-to-end latency SLO (Service Level Objective). For example, if your latency target is 150ms, and the model takes 100ms to evaluate the batch, set the `batch_wait_timeout_s` to a value much lower than 150ms - 100ms = 50ms.

When using batching in a Serve Deployment Graph, the relationship between an upstream node and a downstream node might affect performance as well. Consider a chain of two models where the first model sets `max_batch_size=8` and the second model sets `max_batch_size=6`. When the first model finishes a full batch of 8, the second model processes one batch of 6 and then must wait to fill its next batch, which initially contains only 8 - 6 = 2 requests, incurring extra latency. The batch size of downstream models should ideally be a multiple or a divisor of the upstream models' batch size so that the batches work optimally together.
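
A minimal sketch of that alignment, with hypothetical `Preprocessor` and `Model` deployments where the downstream batch size (4) evenly divides the upstream batch size (8):

```python
from typing import List

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class Model:
    @serve.batch(max_batch_size=4, batch_wait_timeout_s=0.01)
    async def predict(self, inputs: List[float]) -> List[float]:
        return [value + 1.0 for value in inputs]

    async def __call__(self, value: float) -> float:
        return await self.predict(value)


@serve.deployment
class Preprocessor:
    def __init__(self, model: DeploymentHandle):
        self._model = model

    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.01)
    async def preprocess(self, inputs: List[float]) -> List[float]:
        return [value * 2.0 for value in inputs]

    async def __call__(self, value: float) -> float:
        preprocessed = await self.preprocess(value)
        # Each upstream batch of 8 splits cleanly into two downstream batches of 4.
        return await self._model.remote(preprocessed)


app = Preprocessor.bind(Model.bind())
```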

91 changes: 91 additions & 0 deletions doc/source/serve/doc_code/batching_guide.py
@@ -152,3 +152,94 @@ def issue_request(max) -> List[str]:
chunks_list = [fut.result() for fut in futs]
for max, chunks in zip(requested_maxes, chunks_list):
assert chunks == [str(i) for i in range(max)]

# __batch_size_fn_begin__
from typing import List

from ray import serve
from ray.serve.handle import DeploymentHandle


class Graph:
"""Simple graph data structure for GNN workloads."""

def __init__(self, num_nodes: int, node_features: list):
self.num_nodes = num_nodes
self.node_features = node_features


@serve.deployment
class GraphNeuralNetwork:
@serve.batch(
max_batch_size=10000, # Maximum total nodes per batch
batch_wait_timeout_s=0.1,
batch_size_fn=lambda graphs: sum(g.num_nodes for g in graphs),
)
async def predict(self, graphs: List[Graph]) -> List[float]:
"""Process a batch of graphs, batching by total node count."""
# The batch_size_fn ensures that the total number of nodes
# across all graphs in the batch doesn't exceed max_batch_size.
# This prevents GPU memory overflow.
results = []
for graph in graphs:
# Your GNN model inference logic here
# For this example, just return a simple score
score = float(graph.num_nodes * 0.1)
results.append(score)
return results

async def __call__(self, graph: Graph) -> float:
return await self.predict(graph)


handle: DeploymentHandle = serve.run(GraphNeuralNetwork.bind())

# Create test graphs with varying node counts
graphs = [
Graph(num_nodes=100, node_features=[1.0] * 100),
Graph(num_nodes=5000, node_features=[2.0] * 5000),
Graph(num_nodes=3000, node_features=[3.0] * 3000),
]

# Send requests - they'll be batched by total node count
results = [handle.remote(g).result() for g in graphs]
print(f"Results: {results}")
# __batch_size_fn_end__

# __batch_size_fn_nlp_begin__
from typing import List

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class TokenBatcher:
@serve.batch(
max_batch_size=512, # Maximum total tokens per batch
batch_wait_timeout_s=0.1,
batch_size_fn=lambda sequences: sum(len(s.split()) for s in sequences),
)
async def process(self, sequences: List[str]) -> List[int]:
"""Process text sequences, batching by total token count."""
# The batch_size_fn ensures total tokens don't exceed max_batch_size.
# This is useful for transformer models with fixed context windows.
return [len(seq.split()) for seq in sequences]

async def __call__(self, sequence: str) -> int:
return await self.process(sequence)


handle: DeploymentHandle = serve.run(TokenBatcher.bind())

# Create sequences with different lengths
sequences = [
"This is a short sentence",
"This is a much longer sentence with many more words to process",
"Short",
]

# Send requests - they'll be batched by total token count
results = [handle.remote(seq).result() for seq in sequences]
print(f"Token counts: {results}")
# __batch_size_fn_nlp_end__
75 changes: 72 additions & 3 deletions python/ray/serve/batching.py
@@ -109,6 +109,7 @@ def __init__(
batch_wait_timeout_s: float,
max_concurrent_batches: int,
handle_batch_func: Optional[Callable] = None,
batch_size_fn: Optional[Callable[[List], int]] = None,
) -> None:
"""Async queue that accepts individual items and returns batches.

@@ -128,11 +129,18 @@ def __init__(
max_concurrent_batches: max number of batches to run concurrently.
handle_batch_func(Optional[Callable]): callback to run in the
background to handle batches if provided.
batch_size_fn(Optional[Callable[[List], int]]): optional function to
compute the effective batch size. If None, uses len(batch).
The function takes a list of requests and returns an integer
representing the batch size. This is useful for batching based
on custom metrics such as total nodes in graphs, total tokens
in sequences, etc.
"""
self.queue: asyncio.Queue[_SingleRequest] = asyncio.Queue()
self.max_batch_size = max_batch_size
self.batch_wait_timeout_s = batch_wait_timeout_s
self.max_concurrent_batches = max_concurrent_batches
self.batch_size_fn = batch_size_fn
self.semaphore = asyncio.Semaphore(max_concurrent_batches)
self.requests_available_event = asyncio.Event()
self.tasks: Set[asyncio.Task] = set()
@@ -174,6 +182,23 @@ def put(self, request: Tuple[_SingleRequest, asyncio.Future]) -> None:
self.queue.put_nowait(request)
self.requests_available_event.set()

def _compute_batch_size(self, batch: List[_SingleRequest]) -> int:
"""Compute the effective batch size using batch_size_fn or len()."""
if self.batch_size_fn is None:
return len(batch)

# Extract the actual data items from requests to pass to batch_size_fn.
# We need to reconstruct the original arguments from flattened_args.
items = []
for request in batch:
# Recover the original arguments from flattened format
args, kwargs = recover_args(request.flattened_args)
# The batch function expects a single positional argument (the item)
# after 'self' has been extracted (if it was a method)
items.append(args[0])

return self.batch_size_fn(items)

async def wait_for_batch(self) -> List[_SingleRequest]:
"""Wait for batch respecting self.max_batch_size and self.timeout_s.

@@ -207,8 +232,31 @@ async def wait_for_batch(self) -> List[_SingleRequest]:
pass

# Add all new arrivals to the batch.
# Track items we need to put back if they don't fit
deferred_item = None

# Custom batch size function logic
if self.batch_size_fn is not None:
while not self.queue.empty():
next_item = self.queue.get_nowait()
# Temporarily add to check size
batch.append(next_item)
new_size = self._compute_batch_size(batch)

if new_size > max_batch_size:
# Would exceed limit, remove it and save for later
batch.pop()
deferred_item = next_item
break
# Size is OK, keep it in the batch (already added above)
else:
# Default behavior: use original len() check logic
while len(batch) < max_batch_size and not self.queue.empty():
batch.append(self.queue.get_nowait())

# Put deferred item back in queue for next batch
if deferred_item is not None:
self.queue.put_nowait(deferred_item)

# Only clear the put event if the queue is empty. If it's not empty
# we can start constructing a new batch immediately in the next loop.
@@ -219,9 +267,10 @@ async def wait_for_batch(self) -> List[_SingleRequest]:
if self.queue.empty():
self.requests_available_event.clear()

current_batch_size = self._compute_batch_size(batch)
if (
time.time() - batch_start_time >= batch_wait_timeout_s
or current_batch_size >= max_batch_size
):
break

@@ -409,12 +458,14 @@ def __init__(
batch_wait_timeout_s: float = 0.0,
max_concurrent_batches: int = 1,
handle_batch_func: Optional[Callable] = None,
batch_size_fn: Optional[Callable[[List], int]] = None,
):
self._queue: Optional[_BatchQueue] = None
self.max_batch_size = max_batch_size
self.batch_wait_timeout_s = batch_wait_timeout_s
self.max_concurrent_batches = max_concurrent_batches
self.handle_batch_func = handle_batch_func
self.batch_size_fn = batch_size_fn

@property
def queue(self) -> _BatchQueue:
@@ -428,6 +479,7 @@ def queue(self) -> _BatchQueue:
self.batch_wait_timeout_s,
self.max_concurrent_batches,
self.handle_batch_func,
self.batch_size_fn,
)
return self._queue

@@ -516,6 +568,13 @@ def _validate_max_concurrent_batches(max_concurrent_batches: int) -> None:
)


def _validate_batch_size_fn(batch_size_fn: Optional[Callable[[List], int]]) -> None:
if batch_size_fn is not None and not callable(batch_size_fn):
raise TypeError(
f"batch_size_fn must be a callable or None, got {type(batch_size_fn)}"
)


SelfType = TypeVar("SelfType", contravariant=True)
T = TypeVar("T")
R = TypeVar("R")
@@ -564,6 +623,7 @@ def batch(
max_batch_size: int = 10,
batch_wait_timeout_s: float = 0.01,
max_concurrent_batches: int = 1,
batch_size_fn: Optional[Callable[[List], int]] = None,
) -> "_BatchDecorator":
...

@@ -601,6 +661,7 @@ def batch(
max_batch_size: int = 10,
batch_wait_timeout_s: float = 0.01,
max_concurrent_batches: int = 1,
batch_size_fn: Optional[Callable[[List], int]] = None,
) -> Callable:
"""Converts a function to asynchronously handle batches.

@@ -652,6 +713,12 @@ async def __call__(self, request: Request):
executed concurrently. If the number of concurrent batches exceeds
this limit, the batch handler will wait for a batch to complete
before sending the next batch to the underlying function.
batch_size_fn: optional function to compute the effective batch size.
If provided, this function takes a list of items and returns an
integer representing the batch size. This is useful for batching
based on custom metrics such as total nodes in graphs, total tokens
in sequences, or other domain-specific measures. If None, the batch
size is computed as len(batch).
"""
# `_func` will be None in the case when the decorator is parametrized.
# See the comment at the end of this function for a detailed explanation.
Expand All @@ -667,13 +734,15 @@ async def __call__(self, request: Request):
_validate_max_batch_size(max_batch_size)
_validate_batch_wait_timeout_s(batch_wait_timeout_s)
_validate_max_concurrent_batches(max_concurrent_batches)
_validate_batch_size_fn(batch_size_fn)

def _batch_decorator(_func):
lazy_batch_queue_wrapper = _LazyBatchQueueWrapper(
max_batch_size,
batch_wait_timeout_s,
max_concurrent_batches,
_func,
batch_size_fn,
)

async def batch_handler_generator(