23 changes: 20 additions & 3 deletions tensorrt_llm/_torch/memory_buffer_utils.py
@@ -4,6 +4,10 @@

import torch

from tensorrt_llm.logger import logger

from .utils import get_shared_pool


@dataclass
class BufferBlock:
@@ -80,9 +84,22 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,

# No suitable buffer was found, so allocate a new one.
# The new buffer is created with uint8 to represent raw bytes.
new_buffer_tensor = torch.zeros((required_memory_size, ),
device='cuda',
dtype=torch.uint8)
new_buffer_tensor = None
try:
with torch.cuda.memory.use_mem_pool(get_shared_pool()):
new_buffer_tensor = torch.zeros((required_memory_size, ),
device='cuda',
dtype=torch.uint8)
except Exception as ex:
# TODO: check whether this is specifically an out-of-memory error.
logger.debug(
f"Failed to allocate buffer from the shared memory pool: {str(ex)}"
)
# If allocation from the shared pool fails, fall back to the
# default CUDA caching allocator.
new_buffer_tensor = torch.zeros((required_memory_size, ),
device='cuda',
dtype=torch.uint8)

new_block = BufferBlock(buffer=new_buffer_tensor,
is_reserved=reserve_buffer)

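For reference, a minimal sketch of the fallback-allocation pattern introduced above: try to allocate from the shared pool via `torch.cuda.memory.use_mem_pool`, and fall back to the default caching allocator if that raises. It assumes a PyTorch build that exposes `torch.cuda.MemPool` (used here only as a stand-in for whatever `get_shared_pool()` returns); `allocate_with_fallback` is an illustrative helper, not part of the TensorRT-LLM API.

```python
import torch


def allocate_with_fallback(num_bytes: int, shared_pool) -> torch.Tensor:
    """Try to allocate from `shared_pool`; fall back to the default allocator."""
    if shared_pool is not None:
        try:
            with torch.cuda.memory.use_mem_pool(shared_pool):
                return torch.zeros((num_bytes, ), device='cuda', dtype=torch.uint8)
        except Exception:
            # Allocation from the shared pool failed (e.g. the pool is
            # exhausted); fall through to the default allocator below.
            pass
    return torch.zeros((num_bytes, ), device='cuda', dtype=torch.uint8)


if __name__ == "__main__":
    pool = torch.cuda.MemPool()  # hypothetical stand-in for the shared pool
    buf = allocate_with_fallback(1 << 20, pool)
    print(buf.shape, buf.dtype, buf.device)
```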
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -194,6 +194,15 @@ def needs_capture(self, key: Tuple[int, int, int]):

return key not in self.graph_outputs

def get_graph_pool(self):
"""Returns the CUDA memory pool used by this graph runner.

Returns:
The CUDA memory pool associated with captured graphs, or None if
no graphs have been captured yet.
"""
return self.memory_pool

def capture(self,
key: Tuple[int, int, int],
forward_fn: Callable,
@@ -255,6 +264,7 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
capture_inputs)
if postprocess_fn is not None:
postprocess_fn(capture_inputs)

with torch.cuda.graph(graph, pool=self.memory_pool):
output = _setup_spec_decoding_and_forward(
key, forward_fn, capture_inputs)
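As background for `get_graph_pool()`, the sketch below shows the underlying PyTorch pattern: one pool handle is created once, passed as `pool=` to every `torch.cuda.graph(...)` capture (as `capture()` does with `self.memory_pool` above), and can then be handed to other components for reuse. The tensors and warm-up step are illustrative, not the TensorRT-LLM API.

```python
import torch

# One pool handle shared across all captures; a runner would store this as
# self.memory_pool and return it from get_graph_pool().
pool = torch.cuda.graph_pool_handle()

static_x = torch.randn(8, 16, device="cuda")
weight = torch.randn(16, 16, device="cuda")

# Warm up outside capture so lazy initialization (cuBLAS handles, workspaces)
# does not happen while the graph is being recorded.
_ = static_x @ weight
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):  # same pool= argument used in capture()
    static_y = static_x @ weight

graph.replay()
torch.cuda.synchronize()
print(static_y.sum().item())
```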
51 changes: 26 additions & 25 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -48,7 +48,8 @@
from ..speculative.mtp import SampleStateTensorsMTP
from ..utils import (get_model_extra_attrs,
set_per_request_piecewise_cuda_graph_flag,
set_torch_compiling, with_model_extra_attrs)
set_shared_mem_pool, set_torch_compiling,
with_model_extra_attrs)
from .config import PyTorchConfig
from .config_utils import is_mla
from .cuda_graph_runner import CUDAGraphRunner
@@ -2186,35 +2187,35 @@ def forward(
new_tensors_device, cache_indirection_buffer)

self.iter_counter += 1
with set_shared_mem_pool(self.cuda_graph_runner.get_graph_pool()):
if not maybe_graph:
# Fallback to eager execution if graph was not used
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = self._forward_step(inputs, gather_ids,
gather_context_logits)
else:
if self.cuda_graph_runner.needs_capture(key):

if not maybe_graph:
# Fallback to eager execution if graph was not used
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = self._forward_step(inputs, gather_ids,
gather_context_logits)
else:
if self.cuda_graph_runner.needs_capture(key):

def capture_forward_fn(inputs: Dict[str, Any]):
with MoeLoadBalancerIterContext(moe_load_balancer):
return self._forward_step(
inputs,
gather_ids=gather_ids,
gather_context_logits=gather_context_logits)
def capture_forward_fn(inputs: Dict[str, Any]):
with MoeLoadBalancerIterContext(moe_load_balancer):
return self._forward_step(
inputs,
gather_ids=gather_ids,
gather_context_logits=gather_context_logits)

def capture_postprocess_fn(inputs: Dict[str, Any]):
self._postprocess_inputs(inputs)
def capture_postprocess_fn(inputs: Dict[str, Any]):
self._postprocess_inputs(inputs)

self.cuda_graph_runner.capture(key, capture_forward_fn,
inputs,
capture_postprocess_fn)
self.cuda_graph_runner.capture(key, capture_forward_fn,
inputs,
capture_postprocess_fn)

# We don't need the MoeLoadBalancer context here because CUDA graph capture did not run the kernels.
# TODO: find a cleaner way to do this.
outputs = self.cuda_graph_runner.replay(key, inputs)
else:
with MoeLoadBalancerIterContext(moe_load_balancer):
# We don't need the MoeLoadBalancer context here because CUDA graph capture did not run the kernels.
# TODO: find a cleaner way to do this.
outputs = self.cuda_graph_runner.replay(key, inputs)
else:
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = self.cuda_graph_runner.replay(key, inputs)

self._execute_logit_post_processors(scheduled_requests, outputs)

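The effect of the re-indentation above is that the whole forward path now runs under `set_shared_mem_pool(self.cuda_graph_runner.get_graph_pool())`, so code deep in the call stack (such as `memory_buffer_utils`) can discover the graph pool via `get_shared_pool()` without it being threaded through every signature. A hedged sketch of that wiring, assuming this branch of `tensorrt_llm` is importable; `forward_step` is a stand-in for the model engine, not the real class:

```python
import torch

from tensorrt_llm._torch.utils import get_shared_pool, set_shared_mem_pool


def forward_step() -> torch.Tensor:
    # A component deep in the call stack (e.g. memory_buffer_utils) can now
    # look up the shared pool without the caller passing it down explicitly.
    pool = get_shared_pool()
    print("shared pool visible inside forward:", pool is not None)
    return torch.zeros(4, device="cuda")


# Stand-in for cuda_graph_runner.get_graph_pool().
graph_pool = torch.cuda.graph_pool_handle()

with set_shared_mem_pool(graph_pool):
    out = forward_step()

print("pool cleared after exit:", get_shared_pool() is None)
```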
50 changes: 50 additions & 0 deletions tensorrt_llm/_torch/utils.py
@@ -312,3 +312,53 @@ def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
# It's here so that unit tests can mock it and turn it off.
def _get_allow_chain_drafter() -> bool:
return True


_buffer_pool = None


def set_shared_pool(buffer_pool):
"""Sets the global memory pool for buffer allocation.

Args:
buffer_pool: A CUDA memory pool object to use for allocations.
"""
global _buffer_pool
_buffer_pool = buffer_pool


def get_shared_pool():
"""Retrieves the current global memory pool.

Returns:
The current memory pool, or None if not set.
"""
global _buffer_pool
return _buffer_pool


@contextlib.contextmanager
def set_shared_mem_pool(mem_pool) -> contextlib.AbstractContextManager:
"""Temporarily sets a preferred memory pool and restores the previous one on exit.

This context manager allows temporarily switching to a different memory pool
for CUDA graph operations, ensuring the original pool is restored even if
an exception occurs.

Args:
mem_pool: The memory pool to use within the context.

Yields:
None

Example:
>>> with set_shared_mem_pool(graph_pool):
... # Allocations within this block use graph_pool
... tensor = allocate_buffer(...)
"""
old_buffer_pool = get_shared_pool()
set_shared_pool(mem_pool)
try:
yield
finally:
set_shared_pool(old_buffer_pool)
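
A small usage sketch for the three helpers above, assuming they are importable from `tensorrt_llm._torch.utils` as added in this diff. Plain strings stand in for real pool handles, since the helpers only store and return the object; the point is that the previously installed pool is restored both on normal exit and when the body raises.

```python
from tensorrt_llm._torch.utils import (get_shared_pool, set_shared_mem_pool,
                                       set_shared_pool)

set_shared_pool("outer-pool")  # placeholder handle; any object works here

try:
    with set_shared_mem_pool("inner-pool"):
        assert get_shared_pool() == "inner-pool"
        raise RuntimeError("simulated allocation failure inside the block")
except RuntimeError:
    pass

# The finally block in set_shared_mem_pool reinstates the previous pool.
assert get_shared_pool() == "outer-pool"
print("previous pool restored:", get_shared_pool())
```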