Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
logger = logging.getLogger(__name__)


@needs_cross_compile
@needs_cross_compile # type: ignore
def cross_compile_for_windows(
exported_program: ExportedProgram,
inputs: Optional[Sequence[Sequence[Any]]] = None,
Expand Down Expand Up @@ -141,7 +141,7 @@ def cross_compile_for_windows(
assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
engine_capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
workspace_size (int): Maximum size of workspace given to TensorRT
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
Expand Down Expand Up @@ -479,7 +479,7 @@ def compile(
assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
engine_capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
workspace_size (int): Maximum size of workspace given to TensorRT
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
Expand Down Expand Up @@ -723,7 +723,7 @@ def compile(
return trt_gm


@fn_supports_debugger
@fn_supports_debugger # type: ignore
def compile_module(
gm: torch.fx.GraphModule,
sample_arg_inputs: Sequence[Input],
Expand Down Expand Up @@ -1289,7 +1289,7 @@ def convert_exported_program_to_serialized_trt_engine(
return serialized_engine


@needs_cross_compile
@needs_cross_compile # type: ignore
def save_cross_compiled_exported_program(
gm: torch.fx.GraphModule,
file_path: str,
Expand Down
11 changes: 6 additions & 5 deletions py/torch_tensorrt/dynamo/_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
logger = logging.getLogger(__name__)


@needs_refit
@needs_refit # type: ignore[misc]
def construct_refit_mapping(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
Expand Down Expand Up @@ -85,7 +85,7 @@ def construct_refit_mapping(
return weight_refit_map


@needs_refit
@needs_refit # type: ignore[misc]
def construct_refit_mapping_from_weight_name_map(
weight_name_map: dict[Any, Any],
state_dict: dict[Any, Any],
Expand Down Expand Up @@ -128,7 +128,7 @@ def construct_refit_mapping_from_weight_name_map(
return engine_weight_map


@needs_refit
@needs_refit # type: ignore[misc]
def _refit_single_trt_engine_with_gm(
new_gm: torch.fx.GraphModule,
old_engine: trt.ICudaEngine,
Expand Down Expand Up @@ -211,7 +211,7 @@ def _refit_single_trt_engine_with_gm(
raise AssertionError("Refitting failed.")


@needs_refit
@needs_refit # type: ignore[misc]
def refit_module_weights(
compiled_module: torch.fx.GraphModule | ExportedProgram,
new_weight_module: ExportedProgram,
Expand Down Expand Up @@ -484,9 +484,10 @@ def refit_module_weights(
weight_name_map=None,
)

# clear EXCLUDE_WEIGHTS flag
# clear EXCLUDE_WEIGHTS flag and set INCLUDE_REFIT flag to make the engine refittable
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
serialized_engine = engine.serialize_with_config(serialization_config)

if isinstance(compiled_submodule, PythonTorchTensorRTModule):
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update(state)


# TODO: @Evan If changing a setting would affect the behavior of engine compilation, then it should be added to this list
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
"enabled_precisions",
"max_aux_streams",
Expand All @@ -167,7 +168,6 @@ def __setstate__(self, state: dict[str, Any]) -> None:
"engine_capability",
"hardware_compatible",
"refit_identical_engine_weights",
"strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
"immutable_weights",
"enable_weight_streaming",
"tiling_optimization_level",
Expand Down
4 changes: 0 additions & 4 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,6 @@ def _pretraced_backend(
logger.warning(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
)
if settings.strip_engine_weights:
logger.error(
"strip_engine_weights arg is not supported for torch.compile()"
)
trt_compiled = compile_module(
gm,
torchtrt_inputs,
Expand Down
63 changes: 32 additions & 31 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def check_weight_equal(
except Exception:
return torch.all(sd_weight == network_weight)

@needs_refit
@needs_refit # type: ignore[misc]
def _save_weight_mapping(self) -> None:
"""
Construct the weight name mapping from engine weight name to state_dict weight name.
Expand Down Expand Up @@ -577,21 +577,19 @@ def _save_weight_mapping(self) -> None:
gc.collect()
torch.cuda.empty_cache()

@needs_refit
@needs_refit # type: ignore[misc]
def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
# if not self.compilation_settings.strip_engine_weights:
# # set EXCLUDE_WEIGHTS flag to strip weights
# runtime = trt.Runtime(TRT_LOGGER)
# engine = runtime.deserialize_cuda_engine(serialized_engine)

# serialization_config = engine.create_serialization_config()
# serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
# serialized_engine = engine.serialize_with_config(
# serialization_config
# )

# Cache weighted engine for now
# Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
if not self.compilation_settings.strip_engine_weights:
# set EXCLUDE_WEIGHTS flag to strip weights
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)

serialization_config = engine.create_serialization_config()
serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialized_engine = engine.serialize_with_config(serialization_config)

# Insert weight-stripped engine to cache
self.engine_cache.insert( # type: ignore[union-attr]
hash_val,
(
Expand All @@ -605,13 +603,13 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No
),
)

@needs_refit
@needs_refit # type: ignore[misc]
def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
# query the cached TRT engine
cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr]
if cached_data is not None: # hit the cache
(
serialized_engine,
serialized_engine, # weight-stripped engine
self._input_names,
self._output_names,
cached_engine_input_specs,
Expand Down Expand Up @@ -644,31 +642,34 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
# refit the cached engine with the new graph module
if not self.compilation_settings.strip_engine_weights:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)
weight_stripped_engine = runtime.deserialize_cuda_engine(
serialized_engine
)

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)

# weight_stripped_engine --in place--> weight_included_engine
_refit_single_trt_engine_with_gm(
new_gm=self.module,
old_engine=engine,
old_engine=weight_stripped_engine,
input_list=self.input_specs,
settings=self.compilation_settings,
weight_name_map=self.weight_name_map,
)
serialized_engine = engine.serialize()

# TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
# # EXCLUDE_WEIGHTS flag must be cleared
# serialization_config = engine.create_serialization_config()
# serialization_config.clear_flag(
# trt.SerializationFlag.EXCLUDE_WEIGHTS
# )
# serialized_engine = engine.serialize_with_config(
# serialization_config
# )
# # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller

# Load the cached weight-stripped engine
# EXCLUDE_WEIGHTS flag must be cleared and INCLUDE_REFIT flag must be set
serialization_config = (
weight_stripped_engine.create_serialization_config()
)
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
serialized_engine = weight_stripped_engine.serialize_with_config(
serialization_config
)
# Start from here, the serialized_engine is weight-included and refittable

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
Expand Down
92 changes: 86 additions & 6 deletions tests/py/dynamo/models/test_weight_stripped_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,11 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
not torch_trt.ENABLED_FEATURES.refit,
"Engine caching requires refit feature that is not supported in Python 3.13 or higher",
)
def test_different_args_dont_share_cached_engine(self):
def test_different_args_share_cached_engine(self):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 4, 3, stride=1, bias=True)
self.conv = torch.nn.Conv2d(512, 64, 32, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
Expand All @@ -463,11 +463,11 @@ def forward(self, x):

pyt_model = MyModel().eval().to("cuda")

engine_cache_dir = "/tmp/test_different_args_dont_share_cached_engine"
engine_cache_dir = "/tmp/test_different_args_share_cached_engine"
if os.path.exists(engine_cache_dir):
shutil.rmtree(engine_cache_dir)

inputs = [torch.rand((4, 3, 32, 32)).to("cuda")]
inputs = [torch.rand((64, 512, 32, 32)).to("cuda")]

for i in range(2):
if i == 0:
Expand All @@ -493,8 +493,8 @@ def forward(self, x):

assertions.assertEqual(
len(os.listdir(engine_cache_dir)),
2,
msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 2 engines",
1,
msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 1 engine",
)

@unittest.skipIf(
Expand Down Expand Up @@ -631,3 +631,83 @@ def test_refit_identical_engine_weights(self):
)
except Exception as e:
pass

@unittest.skipIf(
    not torch_trt.ENABLED_FEATURES.refit,
    "Engine caching requires refit feature that is not supported in Python 3.13 or higher",
)
@unittest.skipIf(
    not importlib.util.find_spec("torchvision"),
    "torchvision is not installed",
)
def test_refit_weight_stripped_engine_multiple_times(self):
    # Scenario: build a weight-stripped TRT module, refit it twice in a row,
    # then compare the final output with a torch.compile(backend="tensorrt")
    # run of the same model.
    pyt_model = models.resnet18(pretrained=True).eval().to("cuda")

    # Export with dim 0 of input "x" marked dynamic in the range [1, 200].
    export_sample = (torch.randn((100, 3, 224, 224)).to("cuda"),)
    batch = torch.export.Dim("batch", min=1, max=200)
    exp_program = torch.export.export(
        pyt_model,
        args=export_sample,
        dynamic_shapes={"x": {0: batch}},
    )

    inputs = (torch.rand((128, 3, 224, 224)).to("cuda"),)

    # Engine caching is switched off, and strip_engine_weights=True requests
    # a weight-stripped build.
    stripped_build_options = dict(
        use_python_runtime=True,
        enabled_precisions={torch.float},
        min_block_size=1,
        immutable_weights=False,
        cache_built_engines=False,
        reuse_cached_engines=False,
        strip_engine_weights=True,
        refit_identical_engine_weights=False,
    )
    trt_gm = torch_trt.dynamo.compile(exp_program, inputs, **stripped_build_options)

    # Before any refit, the test expects the module's output to sum to zero.
    output = trt_gm(*inputs)
    assertions.assertEqual(
        output.sum(),
        0,
        msg="weight-stripped engine results should be all zeros",
    )

    # First refit: reuse the exported program the module was built from.
    refitted_trt_gm = refit_module_weights(trt_gm, exp_program)
    refitted_output = refitted_trt_gm(*inputs)
    assertions.assertNotEqual(
        refitted_output.sum(),
        0,
        msg="refitted engine results should not be all zeros",
    )

    # Second refit: re-export the same pyt_model with a different example
    # batch (64 instead of 100) and refit the already-refitted module.
    inputs2 = (torch.rand((64, 3, 224, 224)).to("cuda"),)
    exp_program2 = torch.export.export(
        pyt_model,
        args=inputs2,
        dynamic_shapes={"x": {0: batch}},
    )
    refitted_trt_gm = refit_module_weights(refitted_trt_gm, exp_program2)
    refitted_output = refitted_trt_gm(*inputs2)
    assertions.assertNotEqual(
        refitted_output.sum(),
        0,
        msg="refitted engine results should not be all zeros",
    )

    # Reference run through torch.compile with use_python_runtime=False and
    # strip_engine_weights=False; the twice-refitted output must match it.
    reference_options = {
        "use_python_runtime": False,
        "enabled_precisions": {torch.float},
        "min_block_size": 1,
        "immutable_weights": False,
        "cache_built_engines": False,
        "reuse_cached_engines": False,
        "refit_identical_engine_weights": False,
        "strip_engine_weights": False,
    }
    compiled_model = torch.compile(
        pyt_model,
        backend="tensorrt",
        options=reference_options,
    )
    compiled_model_output = compiled_model(*inputs2)

    cos_sim = cosine_similarity(refitted_output, compiled_model_output)
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"refitted_output doesn't match with compiled_model_output. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
    )
Loading