
Commit e7cad5b

Added decorator and tests
1 parent b9fe0c1 commit e7cad5b

4 files changed (+120, -13 lines)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 4 additions & 0 deletions

@@ -618,6 +618,10 @@ def compile(
             "'arg_inputs' and 'inputs' should not be used at the same time."
         )
 
+    assert (
+        cpu_memory_budget >= 2 * 1024 * 1024 * 1024
+    ), "CPU memory budget must be at least 2GB"
+
     arg_inputs = inputs or arg_inputs
 
     if kwarg_inputs is None:
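For context, a minimal sketch of how a caller would hit this new check, assuming cpu_memory_budget is a byte-count keyword argument of torch_tensorrt.dynamo.compile (as the assertion above implies); the toy model and shapes are placeholders, not part of the commit:

# Sketch only: exercising the new cpu_memory_budget validation in dynamo compile.
import torch
import torch_tensorrt as torchtrt

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]
exp_program = torch.export.export(model, tuple(inputs))

trt_module = torchtrt.dynamo.compile(
    exp_program,
    arg_inputs=inputs,
    cpu_memory_budget=2 * 1024 * 1024 * 1024,  # 2 GiB: the smallest value the new assert allows
)
# Any budget below 2 GiB would trip the assertion before compilation starts.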

py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py

Lines changed: 1 addition & 1 deletion

@@ -211,7 +211,7 @@ def calculate_size_budget(
             int: Budget in bytes for a single accelerated subgraph.
         """
 
-        used_rss: int = psutil.virtual_memory().used
+        used_rss: int = psutil.Process().memory_info().rss
         available_rss = self.cpu_memory_budget - used_rss
        return available_rss // engine_compilation_memory_usage_multiplier
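The one-line change above swaps a system-wide memory figure for the current process's resident set size. A small illustrative sketch of the difference (only the two psutil calls come from the diff; the budget value is hypothetical):

# Illustration of the two psutil measurements compared above.
# virtual_memory().used counts memory held by every process on the host,
# while Process().memory_info().rss counts only the current process, so the
# remaining budget is no longer shrunk by unrelated workloads on the machine.
import psutil

cpu_memory_budget = 2 * 1024 * 1024 * 1024            # hypothetical 2 GiB budget
system_used = psutil.virtual_memory().used            # old basis: whole machine
process_rss = psutil.Process().memory_info().rss      # new basis: this process only

print(f"system-wide used : {system_used >> 20} MiB")
print(f"this process RSS : {process_rss >> 20} MiB")
print(f"budget remaining : {(cpu_memory_budget - process_rss) >> 20} MiB")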

py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py

Lines changed: 22 additions & 12 deletions

@@ -1,11 +1,25 @@
 from functools import lru_cache
-from typing import Dict, List, Set
+from typing import Callable, Dict, List, Set
 
 import torch
 from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
 from torch.ops import aten
 
+ATOMIC_SUBGRAPHS = []
 
+
+def register_atomic_subgraph(
+    is_aten: bool = False,
+) -> Callable[[torch.nn.Module], torch.nn.Module]:
+
+    def decorator(subgraph: torch.nn.Module) -> torch.nn.Module:
+        ATOMIC_SUBGRAPHS.append((subgraph, is_aten))
+        return subgraph
+
+    return decorator
+
+
+@register_atomic_subgraph(is_aten=True)
 class ConvBNReLU(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()

@@ -46,6 +60,7 @@ def forward(
         return x
 
 
+@register_atomic_subgraph(is_aten=True)
 class ConvReLU(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()

@@ -77,6 +92,7 @@ def forward(
         return x
 
 
+@register_atomic_subgraph(is_aten=True)
 class ConvGelu(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()

@@ -108,6 +124,7 @@ def forward(
         return x
 
 
+@register_atomic_subgraph(is_aten=True)
 class ConvSilu(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()

@@ -122,6 +139,7 @@ def forward(
         return x
 
 
+@register_atomic_subgraph(is_aten=True)
 class MulAdd(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()

@@ -134,6 +152,7 @@ def forward(
         return x
 
 
+@register_atomic_subgraph(is_aten=True)
 class MulMul(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()

@@ -146,16 +165,6 @@ def forward(
         return x
 
 
-All_FUSION_PATTERNS = [
-    ConvBNReLU,
-    ConvReLU,
-    ConvGelu,
-    ConvSilu,
-    MulAdd,
-    MulMul,
-]
-
-
 @lru_cache(maxsize=None)
 def get_node_in_fusion_pattern(
     graph: torch.fx.Graph,

@@ -166,8 +175,9 @@ def get_node_in_fusion_pattern(
         Value: the list of nodes that should be fused together
     """
     fusion_nodes = {}
-    for pattern in All_FUSION_PATTERNS:
+    for pattern, is_aten in ATOMIC_SUBGRAPHS:
         pattern_graph = torch.fx.symbolic_trace(pattern())
+        # TODO: Add decomposition and lowering if is_aten is False
         subgraph_matcher = SubgraphMatcher(pattern_graph.graph)
         match_result = subgraph_matcher.match(graph)
         for match in match_result:
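The decorator replaces the hard-coded All_FUSION_PATTERNS list, so patterns self-register at import time and the matcher loop above picks them up automatically. A hypothetical downstream registration, following the same shape as the built-in patterns (ConvAdd and its forward signature are illustrative, not part of this commit):

# Hypothetical use of the new registration hook: any module decorated with
# register_atomic_subgraph is appended to ATOMIC_SUBGRAPHS and will be matched
# by get_node_in_fusion_pattern alongside the built-in patterns.
import torch
from torch.ops import aten

from torch_tensorrt.dynamo.partitioning.fusion_patterns import register_atomic_subgraph


@register_atomic_subgraph(is_aten=True)
class ConvAdd(torch.nn.Module):  # illustrative pattern, not in the commit
    def forward(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        other: torch.Tensor,
    ) -> torch.Tensor:
        # aten::convolution(input, weight, bias, stride, padding, dilation,
        #                   transposed, output_padding, groups)
        x = aten.convolution.default(
            x, weight, bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1
        )
        return aten.add.Tensor(x, other)

Because registration happens at decoration time, merely importing the module that defines the pattern is enough for it to be treated as an atomic subgraph.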
New file

Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch_tensorrt as torchtrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+from torch_tensorrt.dynamo import partitioning
+from torch_tensorrt.dynamo.conversion import CompilationSettings
+from torch_tensorrt.dynamo.lowering import (
+    get_decompositions,
+    post_lowering,
+    pre_export_lowering,
+)
+from torch_tensorrt.dynamo.lowering.passes import post_lowering, pre_export_lowering
+from torch_tensorrt.dynamo.partitioning._resource_partitioner import resource_partition
+
+
+class TestResourcePartitioning(TestCase):
+    def test_resource_partitioning(self):
+        class net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1)
+                self.bn1 = nn.BatchNorm2d(4096)
+                self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1)
+                self.bn2 = nn.BatchNorm2d(1024)
+                self.fc1 = nn.Linear(1024 * 56 * 56, 10)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                x = F.relu(x)
+                x = F.max_pool2d(x, (2, 2))
+                x = self.conv2(x)
+                x = self.bn2(x)
+                x = F.relu(x)
+                x = F.max_pool2d(x, (2, 2))
+                x = torch.flatten(x, 1)
+                return self.fc1(x)
+
+        model = net().eval()
+        model.to("cuda")
+        inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")]
+
+        enabled_precisions = {torch.float}
+        use_python_runtime = False
+
+        exp_program = torch.export.export(model, tuple(inputs))
+
+        compilation_options = {
+            "use_python_runtime": use_python_runtime,
+            "enabled_precisions": enabled_precisions,
+            "min_block_size": 1,
+            "immutable_weights": True,
+            "reuse_cached_engines": False,
+        }
+        settings = CompilationSettings(**compilation_options)
+        with torchtrt.dynamo.Debugger(
+            log_level="debug",
+            logging_dir="/home/profile/logging/moe",
+            engine_builder_monitor=False,
+        ):
+
+            exported_program = pre_export_lowering(exp_program, settings)
+            exported_program = exported_program.run_decompositions(
+                get_decompositions(False)
+            )
+
+            gm = exported_program.module()
+            gm = post_lowering(gm, settings)
+
+            partitioned_module, supported_ops = partitioning.fast_partition(
+                gm,
+                min_block_size=settings.min_block_size,
+                torch_executed_ops=settings.torch_executed_ops,
+                require_full_compilation=settings.require_full_compilation,
+                skip_fusion=True,
+            )
+
+            partitioned_module = resource_partition(
+                gm, partitioned_module, cpu_memory_budget=2 * 1024 * 1024 * 1024  # 2GB
+            )
+
+            self.assertEqual(
+                len(list[Any](partitioned_module.named_children())),
+                2,
+                "The graph should have 2 subgraphs",
+            )
+
+
+if __name__ == "__main__":
+    run_tests()

0 commit comments
