40 changes: 34 additions & 6 deletions core/runtime/TRTEngine.cpp
@@ -61,7 +61,8 @@ TRTEngine::TRTEngine(
const Platform& target_platform,
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata)
const std::string& serialized_metadata,
const ResourceAllocationStrategy resource_allocation_strategy)
: TRTEngine(
"deserialized_trt",
serialized_engine,
@@ -71,7 +72,8 @@ TRTEngine::TRTEngine(
target_platform,
hardware_compatible,
requires_output_allocator,
serialized_metadata) {}
serialized_metadata,
resource_allocation_strategy) {}

TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
: TRTEngine(
@@ -83,7 +85,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
Platform(serialized_info[TARGET_PLATFORM_IDX]),
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
serialized_info[SERIALIZED_METADATA_IDX]) {}
serialized_info[SERIALIZED_METADATA_IDX],
(static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic)) {}

TRTEngine::TRTEngine(
const std::string& mod_name,
@@ -94,7 +97,8 @@ TRTEngine::TRTEngine(
const Platform& target_platform,
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata) {
const std::string& serialized_metadata,
const ResourceAllocationStrategy resource_allocation_strategy) {
TORCHTRT_CHECK(
is_supported_on_current_platform(target_platform),
"This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
@@ -124,7 +128,14 @@ TRTEngine::TRTEngine(
cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
}

exec_ctx = make_trt(cuda_engine->createExecutionContext());
this->resource_allocation_strategy = resource_allocation_strategy;
LOG_DEBUG("Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));
if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
this->exec_ctx =
make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
} else {
this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
}
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");

runtime_states.old_cudagraphs = CUDAGRAPHS_MODE;
@@ -393,6 +404,7 @@ std::string TRTEngine::to_str() const {
ss << " Device: " << device_info << std::endl;
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
ss << " Target Platform: " << target_platform << std::endl;
ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
// clang-format on
return ss.str();
}
@@ -436,7 +448,8 @@ FlattenedState TRTEngine::__obj_flatten__() {
std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]),
std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]));
}

std::vector<std::string> TRTEngine::serialize() {
@@ -459,6 +472,7 @@ std::vector<std::string> TRTEngine::serialize() {
serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";
serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";

return serialized_info;
}
@@ -467,6 +481,20 @@ void TRTEngine::reset_captured_graph() {
cudagraph.reset();
}

void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) {
if (new_strategy != this->resource_allocation_strategy) {
this->resource_allocation_strategy = new_strategy;
if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
LOG_DEBUG("Setting resource allocation strategy to dynamic");
this->exec_ctx = make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
} else {
LOG_DEBUG("Setting resource allocation strategy to static");
this->exec_ctx = make_trt(
cuda_engine->createExecutionContext());
}
}
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
16 changes: 13 additions & 3 deletions core/runtime/TRTEngine.h
@@ -29,7 +29,8 @@ using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // HW compatibility
std::tuple<std::string, std::string>, // requires_output_allocator
std::tuple<std::string, std::string>, // serialized metadata
std::tuple<std::string, std::string>>; // Platform
std::tuple<std::string, std::string>, // Platform
std::tuple<std::string, std::string>>; // Resource Allocation Strategy

struct TorchTRTRuntimeStates {
// Indicates whether CUDAGraphs were enabled in the previous execute_engine
@@ -98,6 +99,8 @@ class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {
};

struct TRTEngine : torch::CustomClassHolder {
// Resource Allocation Strategy
typedef enum { kStatic = 0, kDynamic } ResourceAllocationStrategy;
// Each engine needs it's own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
@@ -128,7 +131,9 @@ struct TRTEngine : torch::CustomClassHolder {
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
bool requires_output_allocator = false,
const std::string& serialized_metadata = "");
const std::string& serialized_metadata = "",
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
TRTEngine::ResourceAllocationStrategy::kStatic);

TRTEngine(std::vector<std::string> serialized_info);

@@ -141,7 +146,9 @@
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
bool requires_output_allocator = false,
const std::string& serialized_metadata = "");
const std::string& serialized_metadata = "",
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
TRTEngine::ResourceAllocationStrategy::kStatic);

TRTEngine& operator=(const TRTEngine& other);
std::string to_str() const;
@@ -200,6 +207,9 @@ struct TRTEngine : torch::CustomClassHolder {
std::string cuda_graph_debug_path;
std::mutex mu;
std::unique_ptr<TRTEngineProfiler> trt_engine_profiler;
ResourceAllocationStrategy resource_allocation_strategy = kStatic;
void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy);
ResourceAllocationStrategy get_resource_allocation_strategy();
};

} // namespace runtime
6 changes: 6 additions & 0 deletions core/runtime/execute_engine.cpp
@@ -201,6 +201,12 @@ void create_output_allocator(c10::intrusive_ptr<TRTEngine> compiled_engine) {
}

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
torch::Tensor dynamic_workspace;
if (compiled_engine->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
dynamic_workspace = torch::empty(compiled_engine->cuda_engine->getDeviceMemorySizeV2(), {torch::kCUDA});
compiled_engine->exec_ctx->setDeviceMemory(dynamic_workspace.data_ptr());
}

auto run_standard_execution = [&]() {
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
bool shape_changed = _validate_shapes(inputs, compiled_engine);
9 changes: 9 additions & 0 deletions core/runtime/register_jit_hooks.cpp
@@ -90,6 +90,13 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
.def(
"use_dynamically_allocated_resources",
[](const c10::intrusive_ptr<TRTEngine>& self, bool dynamic) -> void {
self->set_resource_allocation_strategy(
dynamic ? TRTEngine::ResourceAllocationStrategy::kDynamic
: TRTEngine::ResourceAllocationStrategy::kStatic);
})
.def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
.def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
.def_property(
@@ -102,6 +109,7 @@
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
LOG_DEBUG("Deserialized resource allocation strategy: " << (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? "Dynamic" : "Static"));
TRTEngine::verify_serialization_fmt(serialized_info);
return c10::make_intrusive<TRTEngine>(serialized_info);
});
@@ -135,6 +143,7 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; });
m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; });
m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; });
m.def("_platform_linux_x86_64", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64);
return it->second;
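The `use_dynamically_allocated_resources` binding registered above is what the Python runtime uses to flip the allocation strategy at run time. A minimal sketch of driving it from Python, mirroring the context manager added later in this PR (it assumes, as that helper does, that the TRT-accelerated submodules of a partitioned graph module carry "_run_on_acc" in their names and expose this method):

import torch

def set_dynamic_allocation(compiled_module: torch.nn.Module, dynamic: bool = True) -> None:
    # Hypothetical helper: walk the partitioned graph module and toggle every
    # TensorRT-accelerated submodule between dynamic and static allocation.
    for name, submodule in compiled_module.named_modules():
        if "_run_on_acc" in name:
            submodule.use_dynamically_allocated_resources(
                dynamically_allocate_resources=dynamic
            )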
3 changes: 2 additions & 1 deletion core/runtime/runtime.h
@@ -16,7 +16,7 @@ namespace core {
namespace runtime {

using EngineID = int64_t;
const std::string ABI_VERSION = "7";
const std::string ABI_VERSION = "8";
extern bool MULTI_DEVICE_SAFE_MODE;

typedef enum {
@@ -38,6 +38,7 @@ typedef enum {
SERIALIZED_METADATA_IDX,
TARGET_PLATFORM_IDX,
REQUIRES_OUTPUT_ALLOCATOR_IDX,
RESOURCE_ALLOCATION_STRATEGY_IDX,
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

45 changes: 45 additions & 0 deletions examples/dynamo/dynamic_memory_allocation.py
@@ -0,0 +1,45 @@
# %%
import gc
import time

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

np.random.seed(5)
torch.manual_seed(5)
inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]

settings = {
    "ir": "dynamo",
    "use_python_runtime": False,
    "enabled_precisions": {torch.float32},
    "immutable_weights": False,
    "lazy_engine_init": True,
    "dynamically_allocate_resources": True,
}

model = models.resnet152(pretrained=True).eval().to("cuda")
compiled_module = torch_trt.compile(model, inputs=inputs, **settings)
print((torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3)
compiled_module(*inputs)


time.sleep(30)
with torch_trt.dynamo.runtime.ResourceAllocationStrategy(
    compiled_module, dynamically_allocate_resources=False
):
    print(
        "Memory used (GB):",
        (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
    )
    compiled_module(*inputs)
gc.collect()
torch.cuda.empty_cache()
time.sleep(30)
print(
    "Memory used (GB):",
    (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
)
compiled_module(*inputs)
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -103,6 +103,7 @@ def cross_compile_for_windows(
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -177,6 +178,7 @@ def cross_compile_for_windows(
enable_weight_streaming (bool): Enable weight streaming.
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -340,6 +342,7 @@ def cross_compile_for_windows(
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"dynamically_allocate_resources": dynamically_allocate_resources,
}

# disable the following settings is not supported for cross compilation for windows feature
@@ -440,6 +443,7 @@ def compile(
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -517,6 +521,7 @@ def compile(
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -690,6 +695,7 @@ def compile(
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"dynamically_allocate_resources": dynamically_allocate_resources,
}

settings = CompilationSettings(**compilation_options)
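The `dynamically_allocate_resources` flag documented above flows from `compile()` into `CompilationSettings` and on to the runtime engine. A minimal opt-in sketch, assuming a toy model and input (any traceable module works):

import torch
import torch_tensorrt as torch_trt

model = torch.nn.Linear(8, 8).eval().cuda()  # placeholder model
inputs = [torch.randn(1, 8).cuda()]

trt_module = torch_trt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float32},
    dynamically_allocate_resources=True,  # new flag added by this PR
)
trt_module(*inputs)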
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -57,6 +57,7 @@
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
DYNAMICALLY_ALLOCATE_RESOURCES = False

if platform.system() == "Linux":
import pwd
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -11,6 +11,7 @@
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DYNAMICALLY_ALLOCATE_RESOURCES,
DRYRUN,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
@@ -97,6 +98,8 @@ class CompilationSettings:
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation
dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -140,6 +143,7 @@ class CompilationSettings:
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
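The same flag lands on `CompilationSettings`, defaulting to `DYNAMICALLY_ALLOCATE_RESOURCES` (False). A small sketch, assuming the dataclass is constructed directly rather than through `compile()`:

from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(dynamically_allocate_resources=True)
print(settings.dynamically_allocate_resources)  # True; the default remains False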
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py
@@ -0,0 +1,32 @@
from typing import Any

import torch


class ResourceAllocationStrategy(torch.nn.Module):  # type: ignore[misc]
    """
    ResourceAllocationStrategy is a context manager module that temporarily switches the resource
    allocation strategy of all TRT submodules of the given compiled_module. On entering the context,
    it sets these submodules to the requested mode (dynamically allocated resources by default); on
    exit, it switches them back to the opposite mode.
    """

    def __init__(
        self,
        compiled_module: torch.nn.Module,
        dynamically_allocate_resources: bool = True,
    ) -> None:
        super(ResourceAllocationStrategy, self).__init__()
        self.compiled_module = compiled_module
        self.dynamically_allocate_resources = dynamically_allocate_resources

    def __enter__(self) -> None:
        print("Entering resource allocator context")
        for name, submodule in self.compiled_module.named_modules():
            if "_run_on_acc" in name:
                submodule.use_dynamically_allocated_resources(
                    dynamically_allocate_resources=self.dynamically_allocate_resources
                )

    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
        # Switch back to the opposite mode so the strategy change is scoped to the context,
        # matching the behavior described in the class docstring.
        for name, submodule in self.compiled_module.named_modules():
            if "_run_on_acc" in name:
                submodule.use_dynamically_allocated_resources(
                    dynamically_allocate_resources=not self.dynamically_allocate_resources
                )
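A short usage sketch of this context manager, mirroring the example script above (`compiled_module` and `inputs` are assumed to come from a prior `torch_trt.compile` call):

import torch_tensorrt as torch_trt

# Run these calls with statically allocated resources, then flip back to the
# previous strategy when the context exits.
with torch_trt.dynamo.runtime.ResourceAllocationStrategy(
    compiled_module, dynamically_allocate_resources=False
):
    compiled_module(*inputs)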