diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index c8abf965486..f532c116187 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -1,6 +1,6 @@
 import bisect
 import contextlib
-import weakref
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
 
 import torch
@@ -16,12 +16,35 @@
 from .scheduler import ScheduledRequests
 
 if TYPE_CHECKING:
-    from .model_engine import PyTorchModelEngine
+    from ..distributed import MPIDist
+    from ..mapping import Mapping
+    from ..speculative import DecodingBaseConfig
 
 # A large prime number used for dummy request IDs to avoid collisions
 CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1
 
 
+@dataclass
+class CUDAGraphRunnerConfig:
+    """Configuration for the CUDAGraphRunner, passed from the ModelEngine."""
+    use_cuda_graph: bool
+    cuda_graph_padding_enabled: bool
+    cuda_graph_batch_sizes: list[int]
+    max_cuda_graph_batch_size: int
+    max_beam_width: int
+    max_num_tokens: int
+    spec_config: Optional["DecodingBaseConfig"]
+    cuda_graph_mem_pool: Any
+    use_mrope: bool
+    original_max_draft_len: int
+    is_draft_model: bool
+    enable_attention_dp: bool
+    batch_size: int
+    mapping: Optional["Mapping"]
+    dist: Optional["MPIDist"]
+    kv_cache_manager_key: Any
+
+
 class CUDAGraphRunner:
     """
     Manages the lifecycle and execution of CUDA graphs for the model engine.
@@ -32,23 +55,22 @@ class CUDAGraphRunner:
     """
     WARMUP_STEPS = 2
 
-    def __init__(self, engine: "PyTorchModelEngine"):
-        self.engine_ref = weakref.ref(engine)
+    def __init__(self, config: CUDAGraphRunnerConfig):
+        self.config = config
 
-        # High-level configuration
-        config = engine.pytorch_backend_config
+        # High-level configuration from the config object
         self.enabled = config.use_cuda_graph
         self.padding_enabled = config.cuda_graph_padding_enabled
-        self.supported_batch_sizes = engine._cuda_graph_batch_sizes
-        self.max_supported_batch_size = engine._max_cuda_graph_batch_size
-        self.max_beam_width = engine.max_beam_width
-        self.spec_config = engine.spec_config
+        self.supported_batch_sizes = config.cuda_graph_batch_sizes
+        self.max_supported_batch_size = config.max_cuda_graph_batch_size
+        self.max_beam_width = config.max_beam_width
+        self.spec_config = config.spec_config
 
         self.graphs: Dict[Tuple[int, int, int], torch.cuda.CUDAGraph] = {}
         self.graph_outputs: Dict[Tuple[int, int, int],
                                  Callable[[], Optional[torch.Tensor]]] = {}
         self.graph_metadata: Dict[Tuple[int, int, int], Dict[str, Any]] = {}
-        self.memory_pool = engine._cuda_graph_mem_pool
+        self.memory_pool = config.cuda_graph_mem_pool
         self.padding_dummy_request: Optional["Request"] = None
 
         self.shared_static_tensors: Dict[str, torch.Tensor] = {}
@@ -58,12 +80,11 @@ def __init__(self, engine: "PyTorchModelEngine"):
 
     def _create_shared_static_tensors(self):
         """Allocates static tensors sized for the largest possible batch."""
-        engine = self._get_engine()
-
-        token_per_request = self.max_possible_draft_len + 1
+        max_draft_len = self.config.original_max_draft_len if self.config.spec_config is not None else 0
+        token_per_request = max_draft_len + 1
         max_total_tokens = (self.max_supported_batch_size *
                             self.max_beam_width * token_per_request)
-        max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
+        max_total_tokens = min(max_total_tokens, self.config.max_num_tokens)
 
         self.shared_static_tensors = {
             "input_ids":
@@ -72,7 +93,7 @@ def _create_shared_static_tensors(self):
             torch.zeros((1,
max_total_tokens), device="cuda", dtype=torch.int32), } - if engine.use_mrope: + if self.config.use_mrope: self.shared_static_tensors["position_ids"] = torch.zeros( (3, 1, max_total_tokens), device="cuda", dtype=torch.int32) self.shared_static_tensors["multimodal_params"] = [ @@ -86,55 +107,31 @@ def _create_shared_static_tensors(self): }) for _ in range(max_total_tokens) ] - @property - def enable_spec_decode(self): - return self._get_engine().enable_spec_decode - - @property - def max_possible_draft_len(self): - engine = self._get_engine() - return (engine.original_max_draft_len if self.enable_spec_decode else 0) - def get_graph_key( self, batch_size, + enable_spec_decode: bool, spec_resource_manager: Optional[BaseResourceManager] = None): - engine = self._get_engine() - if engine.is_draft_model and spec_resource_manager is not None and isinstance( + if self.config.is_draft_model and spec_resource_manager is not None and isinstance( spec_resource_manager, Eagle3ResourceManager): - draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0 + draft_len = self.config.original_max_draft_len if spec_resource_manager.is_first_draft else 0 key = (batch_size, draft_len, spec_resource_manager.is_first_draft) else: - draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0 + draft_len = self.spec_config.max_draft_len if enable_spec_decode else 0 key = (batch_size, draft_len, False) return key - @property - def spec_metadata(self): - return self._get_engine().spec_metadata - - @property - def draft_tokens_cuda(self): - return self._get_engine().draft_tokens_cuda - - @property - def attn_metadata(self): - return self._get_engine().attn_metadata - def __del__(self): self.clear() - def _get_engine(self) -> "PyTorchModelEngine": - """Safely dereferences the weak reference to the engine.""" - engine = self.engine_ref() - if engine is None: - raise RuntimeError( - "The parent PyTorchModelEngine has been garbage collected.") - return engine - def maybe_get_cuda_graph( self, batch: ScheduledRequests, + iter_counter: int, + enable_spec_decode: bool, + attn_metadata: Any, + spec_metadata: Optional[Any], + draft_tokens_cuda: torch.Tensor, spec_resource_manager: Optional[BaseResourceManager] = None): """ Determines if the current batch can be run with a CUDA graph. @@ -145,17 +142,14 @@ def maybe_get_cuda_graph( - The spec_metadata for the graph, if applicable. - The key for the graph. 
""" - engine = self._get_engine() - # disable when doing statistic - if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter( - engine.iter_counter): + if ExpertStatistic.set_iter(iter_counter): return False, None, None, None can_run_cuda_graph = batch.can_run_cuda_graph batch_size = batch.batch_size - if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1: - all_can_graph_batch = engine.dist.tp_allgather( + if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1: + all_can_graph_batch = self.config.dist.tp_allgather( [can_run_cuda_graph, batch_size]) is_all_gen_only = all(all_can_graph[0] for all_can_graph in all_can_graph_batch) @@ -168,7 +162,8 @@ def maybe_get_cuda_graph( if not self.enabled or not can_run_cuda_graph: return False, None, None, None - key = self.get_graph_key(batch_size, spec_resource_manager) + key = self.get_graph_key(batch_size, enable_spec_decode, + spec_resource_manager) if key in self.graphs: return True, self.graph_metadata[key][ @@ -178,29 +173,28 @@ def maybe_get_cuda_graph( return False, None, None, None num_sequences_in_batch = batch_size * self.max_beam_width - attn_metadata = self.attn_metadata.create_cuda_graph_metadata( + graph_attn_metadata = attn_metadata.create_cuda_graph_metadata( num_sequences_in_batch, False, key[1], self.cuda_graph_meta_buffers) - assert attn_metadata.is_cuda_graph + assert graph_attn_metadata.is_cuda_graph - if self.enable_spec_decode: - spec_metadata = self.spec_metadata.create_cuda_graph_metadata( + if enable_spec_decode: + graph_spec_metadata = spec_metadata.create_cuda_graph_metadata( num_sequences_in_batch) - spec_metadata.draft_tokens = self.draft_tokens_cuda + graph_spec_metadata.draft_tokens = draft_tokens_cuda else: - spec_metadata = None - return True, attn_metadata, spec_metadata, key + graph_spec_metadata = None + return True, graph_attn_metadata, graph_spec_metadata, key def needs_capture(self, key: Tuple[int, int, int]): - return key not in self.graph_outputs def capture(self, key: Tuple[int, int, int], forward_fn: Callable, initial_inputs: Dict[str, Any], + enable_spec_decode: bool, postprocess_fn: Optional[Callable] = None): """Captures the forward pass for a given batch size.""" - engine = self._get_engine() batch_size = key[0] # [CUDA graph spec decode padding] # We pad input IDs/position IDs to the maximum draft length (token per request). @@ -217,7 +211,7 @@ def capture(self, self.shared_static_tensors["position_ids"] [:, :num_tokens_for_capture], } - if engine.use_mrope: + if self.config.use_mrope: sliced_static_tensors["position_ids"] = self.shared_static_tensors[ "position_ids"][:, :, :num_tokens_for_capture], sliced_static_tensors[ @@ -235,12 +229,10 @@ def capture(self, def _setup_spec_decoding_and_forward(key: Tuple[int, int, int], forward_fn: Callable, capture_inputs: Dict[str, Any]): - engine = self._get_engine() - # for the first inference of draft model, we need to set the use_spec_decoding to True when capture the graph for multiple runs. 
is_first_draft = key[2] - needs_kv_cache_recompute = True if engine.enable_spec_decode and engine.spec_config.spec_dec_mode.needs_kv_cache_recompute( + needs_kv_cache_recompute = True if enable_spec_decode and self.config.spec_config.spec_dec_mode.needs_kv_cache_recompute( ) else False - if is_first_draft and engine.is_draft_model and needs_kv_cache_recompute: + if is_first_draft and self.config.is_draft_model and needs_kv_cache_recompute: capture_inputs['attn_metadata'].use_spec_decoding = True return forward_fn(capture_inputs) @@ -268,7 +260,6 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int], def replay(self, key: Tuple[int, int, int], current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]: """Replays a previously captured graph.""" - engine = self._get_engine() stored_meta = self.graph_metadata[key] assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"] if stored_meta["spec_metadata"] is not None: @@ -282,7 +273,7 @@ def replay(self, key: Tuple[int, int, int], static_tensors["input_ids"][:seqlen].copy_(input_ids) position_ids = current_inputs["position_ids"] - if engine.use_mrope and current_inputs.get( + if self.config.use_mrope and current_inputs.get( 'multimodal_params') is not None: static_tensors["position_ids"][:, :, :seqlen].copy_(position_ids) for i, multimodal_param in enumerate( @@ -302,16 +293,16 @@ def replay(self, key: Tuple[int, int, int], return output_ref def _get_padded_batch(self, batch: ScheduledRequests, - resource_manager: ResourceManager) -> int: - engine = self._get_engine() + resource_manager: ResourceManager, + runtime_draft_len: int) -> int: kv_cache_manager = resource_manager.get_resource_manager( - engine.kv_cache_manager_key) + self.config.kv_cache_manager_key) can_run_cuda_graph = batch.can_run_cuda_graph batch_size = batch.batch_size new_batch_size = batch_size - if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1: - graph_batch_size = engine.dist.tp_allgather( + if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1: + graph_batch_size = self.config.dist.tp_allgather( [can_run_cuda_graph, batch_size]) all_can_graph = all(graph_batch[0] for graph_batch in graph_batch_size) @@ -329,7 +320,7 @@ def _get_padded_batch(self, batch: ScheduledRequests, return 0 padding_size = padded_batch_size - batch_size - if padding_size + batch.batch_size > engine.batch_size: + if padding_size + batch.batch_size > self.config.batch_size: return 0 # No padding if it would create too many concurrent requests. 
@@ -344,9 +335,9 @@ def _get_padded_batch(self, batch: ScheduledRequests,
             self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
                 [CUDA_GRAPH_DUMMY_REQUEST_ID],
                 is_gen=True,
-                max_num_draft_tokens=engine.runtime_draft_len,
-                use_mrope=engine.use_mrope,
-                max_beam_width=engine.max_beam_width)[0]
+                max_num_draft_tokens=runtime_draft_len,
+                use_mrope=self.config.use_mrope,
+                max_beam_width=self.config.max_beam_width)[0]
             self.padding_dummy_request.is_cuda_graph_dummy = True
             spec_res_mgr = resource_manager.get_resource_manager(
                 ResourceManagerType.SPEC_RESOURCE_MANAGER)
@@ -368,11 +359,11 @@ def _round_up_batch_size(self, batch_size: int) -> int:
 
     @contextlib.contextmanager
     def pad_batch(self, scheduled_requests: ScheduledRequests,
-                  resource_manager: ResourceManager):
+                  resource_manager: ResourceManager, runtime_draft_len: int):
         """Context manager to pad a batch to a graph-compatible size."""
-        padding_size = self._get_padded_batch(scheduled_requests,
-                                              resource_manager)
+        padding_size = self._get_padded_batch(scheduled_requests,
+                                              resource_manager,
+                                              runtime_draft_len)
         try:
             yield scheduled_requests
         finally:
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index db38219fe03..76b6acb5448 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -51,7 +51,7 @@
                        set_torch_compiling, with_model_extra_attrs)
 from .config import PyTorchConfig
 from .config_utils import is_mla
-from .cuda_graph_runner import CUDAGraphRunner
+from .cuda_graph_runner import CUDAGraphRunner, CUDAGraphRunnerConfig
 from .guided_decoder import CapturableGuidedDecoder
 from .layerwise_nvtx_marker import LayerwiseNvtxMarker
 from .llm_request import get_draft_token_length
@@ -336,7 +336,28 @@ def __init__(
         # with different KV cache managers.
         self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
         self.lora_model_config: Optional[LoraModelConfig] = None
-        self.cuda_graph_runner = CUDAGraphRunner(self)
+
+        # Create config and runner
+        cuda_graph_runner_config = CUDAGraphRunnerConfig(
+            use_cuda_graph=pytorch_backend_config.use_cuda_graph,
+            cuda_graph_padding_enabled=pytorch_backend_config.
+            cuda_graph_padding_enabled,
+            cuda_graph_batch_sizes=self._cuda_graph_batch_sizes,
+            max_cuda_graph_batch_size=self._max_cuda_graph_batch_size,
+            max_beam_width=self.max_beam_width,
+            spec_config=self.spec_config,
+            cuda_graph_mem_pool=self._cuda_graph_mem_pool,
+            max_num_tokens=self.max_num_tokens,
+            use_mrope=self.use_mrope,
+            original_max_draft_len=self.original_max_draft_len,
+            is_draft_model=self.is_draft_model,
+            enable_attention_dp=self.enable_attention_dp,
+            batch_size=self.batch_size,
+            mapping=self.mapping,
+            dist=self.dist,
+            kv_cache_manager_key=self.kv_cache_manager_key,
+        )
+        self.cuda_graph_runner = CUDAGraphRunner(cuda_graph_runner_config)
 
         # Setup the local cache indirection buffer only once and reuse it.
         # This way it can also be used for CUDA graphs.
@@ -1443,16 +1464,16 @@ def previous_seq_slots_device():
         # The order of requests in a batch: [context requests, generation requests]
         # generation requests: ['requests that do not have previous batch', 'requests that already have previous batch', 'dummy requests']
-        # 1) 'requests that do not have previous batch': disable overlap scheduler or the first step in the generation server of disaggregated serving.
-        # 2) 'requests that already have previous batch': previous iteration's requests.
-        # 3) 'dummy requests': pad dummy requests for CUDA graph or attention dp.
+ # 1) 'requests that do not have previous batch': disable overlap scheduler or the first step in the generation server of disaggregated serving. + # 2) 'requests that already have previous batch': previous iteration's requests. + # 3) 'dummy requests': pad dummy requests for CUDA graph or attention dp. # Therefore, both of self.previous_pos_id_offsets_cuda and self.previous_kv_lens_offsets_cuda are also 3 segments. - # For 1) 'requests that do not have previous batch': disable overlap scheduler or the first step in the generation server of disaggregated serving. + # For 1) 'requests that do not have previous batch': disable overlap scheduler or the first step in the generation server of disaggregated serving. # Set these requests' previous_pos_id_offsets and previous_kv_lens_offsets to '0' to skip the value changes in _preprocess_inputs. # Already set to '0' during initialization. - # For 2) 'requests that already have previous batch': enable overlap scheduler. + # For 2) 'requests that already have previous batch': enable overlap scheduler. # Set their previous_pos_id_offsets and previous_kv_lens_offsets according to new_tokens_lens_device and kv_len_offsets_device. - # For 3) 'dummy requests': pad dummy requests for CUDA graph or attention dp. + # For 3) 'dummy requests': pad dummy requests for CUDA graph or attention dp. # Already set to '0' during initialization. num_extend_reqeust_wo_dummy = len(extend_requests) - len( @@ -2177,10 +2198,19 @@ def forward( return self._forward_step(inputs, gather_ids, gather_context_logits) with self.cuda_graph_runner.pad_batch( - scheduled_requests, resource_manager) as padded_requests: + scheduled_requests, resource_manager, + self.runtime_draft_len) as padded_requests: maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph( - padded_requests, spec_resource_manager) + padded_requests, + iter_counter=self.iter_counter, + enable_spec_decode=self.enable_spec_decode, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + draft_tokens_cuda=self.draft_tokens_cuda + if self.is_spec_decode else None, + spec_resource_manager=spec_resource_manager, + ) if maybe_graph: attn_metadata = maybe_attn_metadata spec_metadata = maybe_spec_metadata @@ -2215,9 +2245,12 @@ def capture_forward_fn(inputs: Dict[str, Any]): def capture_postprocess_fn(inputs: Dict[str, Any]): self._postprocess_inputs(inputs) - self.cuda_graph_runner.capture(key, capture_forward_fn, - inputs, - capture_postprocess_fn) + self.cuda_graph_runner.capture( + key, + capture_forward_fn, + inputs, + enable_spec_decode=self.enable_spec_decode, + postprocess_fn=capture_postprocess_fn) # here we don't need to use context since cuda graph capture didn't run kernel. # maybe we need a cleaner way to do this. 
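The net effect of the model_engine.py changes above is that the engine now owns a CUDAGraphRunnerConfig and passes per-iteration state into the runner explicitly, instead of letting the runner dereference a weakref back into the engine. The following is a condensed, hedged sketch of that call flow: it uses only the signatures introduced in this diff, while `engine`, `scheduled_requests`, `resource_manager`, `spec_resource_manager`, `forward_fn`, and `inputs` are illustrative stand-ins rather than code from the patch.

# Hypothetical sketch (not part of the diff) of the decoupled flow: the engine
# builds the config once and hands every mutable value to the runner per call.
runner = CUDAGraphRunner(
    CUDAGraphRunnerConfig(
        use_cuda_graph=True,
        cuda_graph_padding_enabled=True,
        cuda_graph_batch_sizes=[1, 2, 4, 8],
        max_cuda_graph_batch_size=8,
        max_beam_width=engine.max_beam_width,
        max_num_tokens=engine.max_num_tokens,
        spec_config=engine.spec_config,
        cuda_graph_mem_pool=None,
        use_mrope=engine.use_mrope,
        original_max_draft_len=engine.original_max_draft_len,
        is_draft_model=engine.is_draft_model,
        enable_attention_dp=engine.enable_attention_dp,
        batch_size=engine.batch_size,
        mapping=engine.mapping,
        dist=engine.dist,
        kv_cache_manager_key=engine.kv_cache_manager_key,
    ))

with runner.pad_batch(scheduled_requests, resource_manager,
                      engine.runtime_draft_len) as padded:
    # The runner no longer reaches back into the engine; the iteration counter,
    # spec-decode flag, metadata objects, and draft tokens arrive as arguments.
    can_graph, graph_attn, graph_spec, key = runner.maybe_get_cuda_graph(
        padded,
        iter_counter=engine.iter_counter,
        enable_spec_decode=engine.enable_spec_decode,
        attn_metadata=engine.attn_metadata,
        spec_metadata=engine.spec_metadata,
        draft_tokens_cuda=engine.draft_tokens_cuda,
        spec_resource_manager=spec_resource_manager)
    if can_graph:
        if runner.needs_capture(key):
            runner.capture(key,
                           forward_fn,
                           inputs,
                           enable_spec_decode=engine.enable_spec_decode)
        logits = runner.replay(key, inputs)
    else:
        logits = forward_fn(inputs)

The same engine-owned values that used to be read through the weakref (`enable_spec_decode`, `attn_metadata`, `spec_metadata`, `draft_tokens_cuda`) now appear only as explicit parameters, which is what makes the runner testable with a plain config object below.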
diff --git a/tests/unittest/_torch/helpers.py b/tests/unittest/_torch/helpers.py
index 8200caa3584..33f29c9c8f2 100644
--- a/tests/unittest/_torch/helpers.py
+++ b/tests/unittest/_torch/helpers.py
@@ -3,6 +3,11 @@
 import torch
 import torch.nn.functional as F
 
+from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import (
+    CUDAGraphRunner, CUDAGraphRunnerConfig)
+from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
+from tensorrt_llm.mapping import Mapping
+
 
 def ceil_div(x: int, y: int) -> int:
     return (x + y - 1) // y
@@ -164,42 +169,21 @@ def block_scale_gemm(mat_a: torch.Tensor, mat_scale_a: torch.Tensor,
     return results.view_as(x)
 
 
-class MockPytorchBackendConfig:
-
-    def __init__(self, use_cuda_graph, cuda_graph_padding_enabled):
-        self.use_cuda_graph = use_cuda_graph
-        self.cuda_graph_padding_enabled = cuda_graph_padding_enabled
-
-
-class MockEngine:
-    """A replacement for SimpleNamespace that supports weak references."""
-
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
-
-
-def create_mock_engine(batch_size: int):
-
-    class MockSpecConfig:
-
-        class SpecDecMode:
-
-            def needs_kv_cache_recompute(self):
-                return False
-
-        spec_dec_mode = SpecDecMode()
-
-    return MockEngine(
-        pytorch_backend_config=MockPytorchBackendConfig(
-            use_cuda_graph=True, cuda_graph_padding_enabled=False),
-        _cuda_graph_batch_sizes=[batch_size],
-        _max_cuda_graph_batch_size=batch_size,
+def create_mock_cuda_graph_runner(batch_size: int, use_mrope: bool = False):
+    config = CUDAGraphRunnerConfig(
+        use_cuda_graph=True,
+        cuda_graph_padding_enabled=False,
+        cuda_graph_batch_sizes=[batch_size],
+        max_cuda_graph_batch_size=batch_size,
+        batch_size=batch_size,
         max_beam_width=1,
-        max_num_tokens=8192,
-        is_spec_decode=False,
-        enable_spec_decode=False,
-        spec_config=MockSpecConfig(),
-        is_draft_model=False,
-        _cuda_graph_mem_pool=None,
-        use_mrope=False,
-    )
+        original_max_draft_len=0,
+        max_num_tokens=1,
+        use_mrope=use_mrope,
+        spec_config=None,
+        cuda_graph_mem_pool=None,
+        is_draft_model=False,
+        enable_attention_dp=False,
+        mapping=Mapping(),
+        dist=None,
+        kv_cache_manager_key=ResourceManagerType.KV_CACHE_MANAGER)
+    return CUDAGraphRunner(config)
diff --git a/tests/unittest/_torch/modeling/test_modeling_exaone4.py b/tests/unittest/_torch/modeling/test_modeling_exaone4.py
index 28a35323b6e..a224cea1186 100644
--- a/tests/unittest/_torch/modeling/test_modeling_exaone4.py
+++ b/tests/unittest/_torch/modeling/test_modeling_exaone4.py
@@ -22,7 +22,7 @@ class Exaone4Config(PretrainedConfig):
 # TODO: Remove this once we have a proper config for Exaone4
 SKIP_EXAONE4_HF_ACCURACY_TEST = True
 
-from _torch.helpers import create_mock_engine
+from _torch.helpers import create_mock_cuda_graph_runner
 from transformers.cache_utils import HybridCache
 from utils.util import getSMVersion
 
@@ -31,7 +31,6 @@ class Exaone4Config(PretrainedConfig):
 from tensorrt_llm._torch.metadata import KVCacheParams
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.modeling_exaone4 import Exaone4ForCausalLM
-from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.mapping import Mapping
@@ -338,10 +337,8 @@ def test_exaone4_allclose_to_hf(self, scenario: Scenario) -> None:
         ]
         gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
 
-        graph_runner = None
-        if scenario.use_cuda_graph:
-            mock_engine = create_mock_engine(1)
-
graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py index 8de665741d8..c4c8ecaa283 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama.py @@ -4,7 +4,7 @@ from typing import Any import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import LlamaConfig from transformers import LlamaForCausalLM as HFLlamaForCausalLM @@ -15,7 +15,6 @@ from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -326,10 +325,8 @@ def test_llama_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() diff --git a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py index 7b3e74a1bb2..879abd73ab5 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py @@ -4,7 +4,7 @@ import torch import transformers -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import Llama4Config from transformers import \ @@ -21,7 +21,6 @@ from tensorrt_llm._torch.models.modeling_llama import \ Llama4ForConditionalGeneration from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -403,10 +402,8 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None: input_ids.size(-1) + gen_input_ids.size(-1)) ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() diff --git a/tests/unittest/_torch/modeling/test_modeling_mistral.py b/tests/unittest/_torch/modeling/test_modeling_mistral.py index d3ea00c32ad..e21e1ae1a35 100644 --- a/tests/unittest/_torch/modeling/test_modeling_mistral.py +++ b/tests/unittest/_torch/modeling/test_modeling_mistral.py @@ -8,7 +8,7 @@ import torch 
import transformers import transformers.models.mistral3 -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from PIL import Image from utils.util import getSMVersion @@ -19,7 +19,6 @@ from tensorrt_llm._torch.attention_backend import utils as attention_utils from tensorrt_llm._torch.models import modeling_mistral from tensorrt_llm._torch.pyexecutor import resource_manager -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm.bindings import executor as executor_lib from tensorrt_llm.models import modeling_utils @@ -404,10 +403,7 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner(1) if use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() diff --git a/tests/unittest/_torch/modeling/test_modeling_mixtral.py b/tests/unittest/_torch/modeling/test_modeling_mixtral.py index b8beecaa772..7071a440ff5 100644 --- a/tests/unittest/_torch/modeling/test_modeling_mixtral.py +++ b/tests/unittest/_torch/modeling/test_modeling_mixtral.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import MixtralConfig from transformers import MixtralForCausalLM as HFMixtralForCausalLM @@ -16,7 +16,6 @@ from tensorrt_llm._torch.models.checkpoints.hf.mixtral_weight_mapper import \ MixtralHfWeightMapper from tensorrt_llm._torch.models.modeling_mixtral import MixtralForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -310,10 +309,8 @@ def test_mixtral_allclose_to_hf(self, scenario: Scenario): ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() diff --git a/tests/unittest/_torch/modeling/test_modeling_mllama.py b/tests/unittest/_torch/modeling/test_modeling_mllama.py index 597c084b41d..a9423b86d35 100644 --- a/tests/unittest/_torch/modeling/test_modeling_mllama.py +++ b/tests/unittest/_torch/modeling/test_modeling_mllama.py @@ -4,7 +4,7 @@ import pytest import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from test_modeling_llama import Scenario, reduce_llama_config from transformers import MllamaConfig @@ -17,7 +17,6 @@ from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_mllama import \ MllamaForConditionalGeneration -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import 
Mapping @@ -420,10 +419,8 @@ def test_mllama_allclose_to_hf_text_only(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron.py b/tests/unittest/_torch/modeling/test_modeling_nemotron.py index d06a6bc6b81..2dcac56ea55 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron.py @@ -4,7 +4,7 @@ from typing import Any import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import NemotronConfig from transformers import NemotronForCausalLM as HFNemotronForCausalLM @@ -15,7 +15,6 @@ from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_nemotron import NemotronForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -318,10 +317,8 @@ def test_nemotron_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() diff --git a/tests/unittest/_torch/modeling/test_modeling_phi3.py b/tests/unittest/_torch/modeling/test_modeling_phi3.py index 1a50b874ae5..1f7f0316611 100644 --- a/tests/unittest/_torch/modeling/test_modeling_phi3.py +++ b/tests/unittest/_torch/modeling/test_modeling_phi3.py @@ -4,7 +4,7 @@ from typing import Any import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from transformers import Phi3Config from transformers import Phi3ForCausalLM as HFPhi3ForCausalLM from utils.util import default_dtype @@ -14,7 +14,6 @@ from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_phi3 import Phi3ForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -310,10 +309,8 @@ def test_phi3_allclose_to_hf(self) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen.py 
b/tests/unittest/_torch/modeling/test_modeling_qwen.py index a35dc9131f6..d2f9cdaac73 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen.py @@ -17,12 +17,11 @@ from tensorrt_llm._torch.models.modeling_qwen import ( Qwen2ForCausalLM, Qwen2ForProcessRewardModel) # yapf: enable -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from utils.llm_data import llm_models_root from utils.util import getSMVersion @@ -265,10 +264,8 @@ def test_qwen_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen2_5vl.py b/tests/unittest/_torch/modeling/test_modeling_qwen2_5vl.py index 8c04f744f60..4a597b2dbd3 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen2_5vl.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen2_5vl.py @@ -6,7 +6,7 @@ import pytest import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import AutoProcessor, AutoTokenizer, Qwen2_5_VLConfig from transformers import \ @@ -19,7 +19,6 @@ from tensorrt_llm._torch.models.checkpoints.hf.qwen2vl_weight_mapper import \ Qwen2VLHfWeightMapper from tensorrt_llm._torch.models.modeling_qwen2vl import Qwen2_5_VLModel -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.inputs import (create_input_processor, @@ -480,11 +479,8 @@ def test_qwen2_5_vl_allclose_to_hf(self, scenario: Scenario) -> None: target_keywords=["mrope_config.mrope_position_deltas"]) gen_multimodal_params_list.append(multimodal_param) - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - mock_engine.use_mrope = True - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1, True) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata, multimodal_params): diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py b/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py index b8db3be83d6..39cbf33b823 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import Qwen2MoeConfig from transformers import Qwen2MoeForCausalLM as HFQwen2MoeForCausalLM @@ -16,7 +16,6 @@ from 
tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \ Qwen2MoeHfWeightMapper from tensorrt_llm._torch.models.modeling_qwen_moe import Qwen2MoeForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -315,10 +314,8 @@ def test_qwen_moe_allclose_to_hf(self, scenario: Scenario): ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare()
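The test updates above are mechanical: each modeling test now obtains its runner from create_mock_cuda_graph_runner instead of wrapping a mock engine in CUDAGraphRunner. A hypothetical usage sketch follows, assuming the `model`, `input_ids`, `position_ids`, and `attn_metadata` objects that the tests already build (with attn_metadata created as CUDA-graph-compatible metadata); the key layout follows get_graph_key for a non-speculative, non-draft run. It is not code from this patch.

# Hypothetical test-side sketch: drive a mock runner through capture and replay
# using only the API introduced in this diff.
graph_runner = create_mock_cuda_graph_runner(1)  # batch size 1, no mrope

inputs = {
    "input_ids": input_ids,
    "position_ids": position_ids,
    "attn_metadata": attn_metadata,  # must be CUDA-graph metadata for capture
}
key = (1, 0, False)  # (batch_size, draft_len, is_first_draft) without spec decode
if graph_runner.needs_capture(key):
    graph_runner.capture(key,
                         lambda inp: model.forward(**inp),
                         inputs,
                         enable_spec_decode=False)
logits = graph_runner.replay(key, inputs)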