Commit 3b02ef9

Added example and fixed lru problem
1 parent 99581e2 commit 3b02ef9

2 files changed: +104 -5 lines
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
"""
.. _low_cpu_memory_compilation:

Low CPU Memory Compilation Example
==================================

This example demonstrates compiling a model with a bounded CPU (host) memory
budget using Torch-TensorRT Dynamo. Limiting host RAM use is helpful on
memory-constrained machines or when compiling very large models.

Key notes:

- The toy model below has roughly 430 MB of FP32 parameters, and we set the
  CPU memory budget to 2 GiB. At compile time, only about 900 MB of host RAM
  may remain available, and we expect at most 430 * 4 = 1720 MB of memory to
  be used by the model, so the model is partitioned into two subgraphs to
  fit the memory budget.

- Performance impact varies by model. When the number of TensorRT engines
  created is small, the impact is typically minimal.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt


class net(nn.Module):
    def __init__(self):
        super().__init__()
        # Intentionally large layers to stress host memory during compilation.
        self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(4096)
        self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(1024)
        self.fc1 = nn.Linear(1024 * 56 * 56, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))
        x = torch.flatten(x, 1)
        return self.fc1(x)


model = net().eval()
model.to("cuda")
inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")]

enabled_precisions = {torch.float}
use_python_runtime = False

compilation_options = {
    "use_python_runtime": use_python_runtime,
    "enabled_precisions": enabled_precisions,
    "min_block_size": 1,
    "immutable_weights": True,
    "reuse_cached_engines": False,
    "cpu_memory_budget": 2 * 1024 * 1024 * 1024,  # 2 GiB in bytes
}

with torchtrt.dynamo.Debugger(
    log_level="debug",
    logging_dir="/home/profile/logging/moe",
    engine_builder_monitor=False,
):
    exp_program = torch.export.export(model, tuple(inputs))
    trt_gm = torchtrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        **compilation_options,
    )

# Expect two back-to-back TensorRT engines due to partitioning under the memory budget.
print(trt_gm)
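To sanity-check the partitioning, one can count the accelerated submodules after compilation. A minimal sketch, assuming the Dynamo partitioner's usual "_run_on_acc" naming for TensorRT submodules (that naming convention is an assumption here, not something the example above guarantees):

# Hypothetical follow-up to the example, after print(trt_gm).
# Assumes TensorRT-accelerated submodules are named "_run_on_acc_<i>".
engine_names = [name for name, _ in trt_gm.named_children() if "_run_on_acc" in name]
print(f"TensorRT engines created: {len(engine_names)}")  # expected: 2 under the 2 GiB budget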

py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py

Lines changed: 20 additions & 5 deletions
@@ -165,7 +165,6 @@ def forward(
         return x
 
 
-@lru_cache(maxsize=None)
 def get_node_in_fusion_pattern(
     graph: torch.fx.Graph,
 ) -> Dict[torch.fx.Node, Set[torch.fx.Node]]:
@@ -175,10 +174,8 @@ def get_node_in_fusion_pattern(
     Value: the list of nodes that should be fused together
     """
     fusion_nodes = {}
-    for pattern, is_aten in ATOMIC_SUBGRAPHS:
-        pattern_graph = torch.fx.symbolic_trace(pattern())
-        # TODO: Add decomposition and lowering if is_aten is False
-        subgraph_matcher = SubgraphMatcher(pattern_graph.graph)
+    for compiled_pattern_graph in get_compiled_atomic_subgraphs():
+        subgraph_matcher = SubgraphMatcher(compiled_pattern_graph.graph)
         match_result = subgraph_matcher.match(graph)
         for match in match_result:
             fusion_group = {
@@ -193,3 +190,21 @@ def get_node_in_fusion_pattern(
         fusion_nodes[node] = fusion_group
 
     return fusion_nodes
+
+
+@lru_cache(maxsize=None)
+def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]:
+    """
+    Trace each pattern in ATOMIC_SUBGRAPHS into a GraphModule.
+    LRU-cache the result to avoid recompiling the same patterns multiple times.
+    """
+    compiled_atomic_subgraphs = []
+    for pattern, is_aten in ATOMIC_SUBGRAPHS:
+        pattern_graph = torch.fx.symbolic_trace(pattern())
+        if not is_aten:
+            # TODO: Add decomposition and lowering if is_aten is False
+            raise NotImplementedError(
+                "Atomic subgraphs are not supported for non-aten subgraphs yet."
+            )
+        compiled_atomic_subgraphs.append(pattern_graph)
+    return compiled_atomic_subgraphs
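The "lru problem" in the commit title: `@lru_cache` on `get_node_in_fusion_pattern` keyed the cache on its `torch.fx.Graph` argument, so each distinct graph added an entry that was never reused and was kept alive by the cache. The fix moves the cache to an argument-free function so only the graph-independent pattern tracing is memoized. A self-contained sketch of the same refactor, with illustrative names not taken from the repo:

from functools import lru_cache

# Anti-pattern: caching on a per-call argument. Each distinct object creates a
# new cache entry that is never hit again, and the entry pins the argument in
# memory for the life of the process.
#
#     @lru_cache(maxsize=None)
#     def find_matches(graph): ...

@lru_cache(maxsize=None)
def build_patterns() -> tuple:
    # Argument-free setup: built once on first call, shared by all later calls.
    return ("conv-bn", "conv-relu", "linear-gelu")

def find_matches(items: list) -> list:
    # Per-input work stays uncached; only the shared setup is memoized.
    patterns = build_patterns()
    return [p for p in patterns if p in items]

Returning an immutable tuple from the cached function also keeps callers from mutating the shared cached value; the real get_compiled_atomic_subgraphs returns a list, which does not guard against that.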
