From 8867b74aeab3d73f2193f3e50553c7b61e9fc30d Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 19 Aug 2025 19:37:05 -0700 Subject: [PATCH 1/6] feature support --- py/torch_tensorrt/dynamo/_refit.py | 11 ++-- py/torch_tensorrt/dynamo/_settings.py | 1 - .../dynamo/conversion/_TRTInterpreter.py | 63 ++++++++++--------- 3 files changed, 38 insertions(+), 37 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 9aae901f87..467daab529 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -52,7 +52,7 @@ logger = logging.getLogger(__name__) -@needs_refit +@needs_refit # type: ignore[misc] def construct_refit_mapping( module: torch.fx.GraphModule, inputs: Sequence[Input], @@ -85,7 +85,7 @@ def construct_refit_mapping( return weight_refit_map -@needs_refit +@needs_refit # type: ignore[misc] def construct_refit_mapping_from_weight_name_map( weight_name_map: dict[Any, Any], state_dict: dict[Any, Any], @@ -128,7 +128,7 @@ def construct_refit_mapping_from_weight_name_map( return engine_weight_map -@needs_refit +@needs_refit # type: ignore[misc] def _refit_single_trt_engine_with_gm( new_gm: torch.fx.GraphModule, old_engine: trt.ICudaEngine, @@ -211,7 +211,7 @@ def _refit_single_trt_engine_with_gm( raise AssertionError("Refitting failed.") -@needs_refit +@needs_refit # type: ignore[misc] def refit_module_weights( compiled_module: torch.fx.GraphModule | ExportedProgram, new_weight_module: ExportedProgram, @@ -484,9 +484,10 @@ def refit_module_weights( weight_name_map=None, ) - # clear EXCLUDE_WEIGHTS flag + # clear EXCLUDE_WEIGHTS flag and set INCLUDE_REFIT flag to make the engine refittable serialization_config = engine.create_serialization_config() serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT) serialized_engine = engine.serialize_with_config(serialization_config) if isinstance(compiled_submodule, PythonTorchTensorRTModule): diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d8f6809eae..a64c5b3800 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -167,7 +167,6 @@ def __setstate__(self, state: dict[str, Any]) -> None: "engine_capability", "hardware_compatible", "refit_identical_engine_weights", - "strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default? "immutable_weights", "enable_weight_streaming", "tiling_optimization_level", diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index b8d4994fca..0915f8d61c 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -440,7 +440,7 @@ def check_weight_equal( except Exception: return torch.all(sd_weight == network_weight) - @needs_refit + @needs_refit # type: ignore[misc] def _save_weight_mapping(self) -> None: """ Construct the weight name mapping from engine weight name to state_dict weight name. @@ -577,21 +577,19 @@ def _save_weight_mapping(self) -> None: gc.collect() torch.cuda.empty_cache() - @needs_refit + @needs_refit # type: ignore[misc] def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None: - # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine - # if not self.compilation_settings.strip_engine_weights: - # # set EXCLUDE_WEIGHTS flag to strip weights - # runtime = trt.Runtime(TRT_LOGGER) - # engine = runtime.deserialize_cuda_engine(serialized_engine) - - # serialization_config = engine.create_serialization_config() - # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) - # serialized_engine = engine.serialize_with_config( - # serialization_config - # ) - - # Cache weighted engine for now + # Cache the weight-stripped engine regardless of the `strip_engine_weights` setting + if not self.compilation_settings.strip_engine_weights: + # set EXCLUDE_WEIGHTS flag to strip weights + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + serialization_config = engine.create_serialization_config() + serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + serialized_engine = engine.serialize_with_config(serialization_config) + + # Insert weight-stripped engine to cache self.engine_cache.insert( # type: ignore[union-attr] hash_val, ( @@ -605,13 +603,13 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No ), ) - @needs_refit + @needs_refit # type: ignore[misc] def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: # query the cached TRT engine cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr] if cached_data is not None: # hit the cache ( - serialized_engine, + weight_stripped_serialized_engine, self._input_names, self._output_names, cached_engine_input_specs, @@ -644,31 +642,34 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: # refit the cached engine with the new graph module if not self.compilation_settings.strip_engine_weights: runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) + weight_stripped_engine = runtime.deserialize_cuda_engine( + weight_stripped_serialized_engine + ) from torch_tensorrt.dynamo._refit import ( _refit_single_trt_engine_with_gm, ) + # weight_stripped_engine --in place--> weight_included_engine _refit_single_trt_engine_with_gm( new_gm=self.module, - old_engine=engine, + old_engine=weight_stripped_engine, input_list=self.input_specs, settings=self.compilation_settings, weight_name_map=self.weight_name_map, ) - serialized_engine = engine.serialize() - - # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine - # # EXCLUDE_WEIGHTS flag must be cleared - # serialization_config = engine.create_serialization_config() - # serialization_config.clear_flag( - # trt.SerializationFlag.EXCLUDE_WEIGHTS - # ) - # serialized_engine = engine.serialize_with_config( - # serialization_config - # ) - # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller + + # Load the cached weight-stripped engine + # EXCLUDE_WEIGHTS flag must be cleared and INCLUDE_REFIT flag must be set + serialization_config = ( + weight_stripped_engine.create_serialization_config() + ) + serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT) + serialized_engine = weight_stripped_engine.serialize_with_config( + serialization_config + ) + # Start from there, the serialized_engine is weight-stripped and refittable with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) From 88ffb24097a57c6b48ba16d2258abde530e091ee Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 19 Aug 2025 19:53:38 -0700 Subject: [PATCH 2/6] update weight stripped engine tests --- .../models/test_weight_stripped_engine.py | 92 +++++++++++++++++-- 1 file changed, 86 insertions(+), 6 deletions(-) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index d2079d11bf..d2093fc04f 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -449,11 +449,11 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): not torch_trt.ENABLED_FEATURES.refit, "Engine caching requires refit feature that is not supported in Python 3.13 or higher", ) - def test_different_args_dont_share_cached_engine(self): + def test_different_args_share_cached_engine(self): class MyModel(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(3, 4, 3, stride=1, bias=True) + self.conv = torch.nn.Conv2d(512, 64, 32, stride=1, bias=True) self.relu = torch.nn.ReLU() def forward(self, x): @@ -463,11 +463,11 @@ def forward(self, x): pyt_model = MyModel().eval().to("cuda") - engine_cache_dir = "/tmp/test_different_args_dont_share_cached_engine" + engine_cache_dir = "/tmp/test_different_args_share_cached_engine" if os.path.exists(engine_cache_dir): shutil.rmtree(engine_cache_dir) - inputs = [torch.rand((4, 3, 32, 32)).to("cuda")] + inputs = [torch.rand((64, 512, 32, 32)).to("cuda")] for i in range(2): if i == 0: @@ -493,8 +493,8 @@ def forward(self, x): assertions.assertEqual( len(os.listdir(engine_cache_dir)), - 2, - msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 2 engines", + 1, + msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 1 engine", ) @unittest.skipIf( @@ -631,3 +631,83 @@ def test_refit_identical_engine_weights(self): ) except Exception as e: pass + + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) + @unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", + ) + def test_refit_weight_stripped_engine_multiple_times(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + # Mark the dim0 of inputs as dynamic + batch = torch.export.Dim("batch", min=1, max=200) + exp_program = torch.export.export( + pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + ) + + inputs = (torch.rand((128, 3, 224, 224)).to("cuda"),) + + trt_gm = torch_trt.dynamo.compile( + exp_program, + inputs, + use_python_runtime=True, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + cache_built_engines=False, + reuse_cached_engines=False, + strip_engine_weights=True, + refit_identical_engine_weights=False, + ) + output = trt_gm(*inputs) + assertions.assertEqual( + output.sum(), 0, msg="weight-stripped engine results should be all zeros" + ) + + # Refit the weight-stripped engine with the same weights + refitted_trt_gm = refit_module_weights(trt_gm, exp_program) + refitted_output = refitted_trt_gm(*inputs) + assertions.assertNotEqual( + refitted_output.sum(), + 0, + msg="refitted engine results should not be all zeros", + ) + + inputs2 = (torch.rand((64, 3, 224, 224)).to("cuda"),) + exp_program2 = torch.export.export( + pyt_model, args=inputs2, dynamic_shapes={"x": {0: batch}} + ) + + # Refit with different weights + refitted_trt_gm = refit_module_weights(refitted_trt_gm, exp_program2) + refitted_output = refitted_trt_gm(*inputs2) + assertions.assertNotEqual( + refitted_output.sum(), + 0, + msg="refitted engine results should not be all zeros", + ) + + compiled_model = torch.compile( + pyt_model, + backend="tensorrt", + options={ + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "min_block_size": 1, + "immutable_weights": False, + "cache_built_engines": False, + "reuse_cached_engines": False, + "refit_identical_engine_weights": False, + "strip_engine_weights": False, + }, + ) + compiled_model_output = compiled_model(*inputs2) + cos_sim = cosine_similarity(refitted_output, compiled_model_output) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"refitted_output doesn't match with compiled_model_output. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) From 9ff8c169918d252a1e7146eec5b73bc316efd9ae Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 19 Aug 2025 20:42:23 -0700 Subject: [PATCH 3/6] fix bug --- py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 0915f8d61c..f030a4e0e3 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -609,7 +609,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr] if cached_data is not None: # hit the cache ( - weight_stripped_serialized_engine, + serialized_engine, # weight-stripped engine self._input_names, self._output_names, cached_engine_input_specs, @@ -643,7 +643,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: if not self.compilation_settings.strip_engine_weights: runtime = trt.Runtime(TRT_LOGGER) weight_stripped_engine = runtime.deserialize_cuda_engine( - weight_stripped_serialized_engine + serialized_engine ) from torch_tensorrt.dynamo._refit import ( From 08012a7bf7cb048b7af3cdc67b58aa7cc502eeb0 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 19 Aug 2025 20:44:32 -0700 Subject: [PATCH 4/6] remove restriction --- py/torch_tensorrt/dynamo/backend/backends.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index c39fe57197..28b99e9400 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -157,10 +157,6 @@ def _pretraced_backend( logger.warning( "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt" ) - if settings.strip_engine_weights: - logger.error( - "strip_engine_weights arg is not supported for torch.compile()" - ) trt_compiled = compile_module( gm, torchtrt_inputs, From f0c9c7e8ebe5f4dd3563c817b4157bf98ab41c61 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 20 Aug 2025 13:02:35 -0700 Subject: [PATCH 5/6] fix typo --- py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index f030a4e0e3..9cfaebf06a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -669,7 +669,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: serialized_engine = weight_stripped_engine.serialize_with_config( serialization_config ) - # Start from there, the serialized_engine is weight-stripped and refittable + # Start from here, the serialized_engine is weight-included and refittable with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) From 65b76cf6839a13c9e261e6361f5f84bae9d3deaf Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 20 Aug 2025 13:41:46 -0700 Subject: [PATCH 6/6] add todo --- py/torch_tensorrt/dynamo/_compiler.py | 10 +++++----- py/torch_tensorrt/dynamo/_settings.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 608c8e84c9..fb07a71a5a 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -53,7 +53,7 @@ logger = logging.getLogger(__name__) -@needs_cross_compile +@needs_cross_compile # type: ignore def cross_compile_for_windows( exported_program: ExportedProgram, inputs: Optional[Sequence[Sequence[Any]]] = None, @@ -141,7 +141,7 @@ def cross_compile_for_windows( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels + engine_capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels workspace_size (int): Maximum size of workspace given to TensorRT dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. @@ -479,7 +479,7 @@ def compile( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels + engine_capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels workspace_size (int): Maximum size of workspace given to TensorRT dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. @@ -723,7 +723,7 @@ def compile( return trt_gm -@fn_supports_debugger +@fn_supports_debugger # type: ignore def compile_module( gm: torch.fx.GraphModule, sample_arg_inputs: Sequence[Input], @@ -1289,7 +1289,7 @@ def convert_exported_program_to_serialized_trt_engine( return serialized_engine -@needs_cross_compile +@needs_cross_compile # type: ignore def save_cross_compiled_exported_program( gm: torch.fx.GraphModule, file_path: str, diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index a64c5b3800..782b1e68c7 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -157,6 +157,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update(state) +# TODO: @Evan If changing the setting would affect the behavior of engine compilation, then should be added to this list _SETTINGS_TO_BE_ENGINE_INVARIANT = ( "enabled_precisions", "max_aux_streams",