Commit f1887eb

Move autoparallel to use leaner export API (#181)
1 parent f4ef815 commit f1887eb

File tree

1 file changed (+20, -6 lines)

autoparallel/api.py

Lines changed: 20 additions & 6 deletions
@@ -8,9 +8,10 @@
 import warnings
 from contextlib import ExitStack, contextmanager
 from types import MethodType
-from typing import Optional, Union
+from typing import Any, Optional, Tuple, Union
 
 import torch
+from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
 from torch._functorch.aot_autograd import (
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
@@ -22,6 +23,7 @@
 from torch._subclasses import FakeTensorMode
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor import DeviceMesh
+from torch.export._trace import _restore_state_dict
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 
@@ -163,6 +165,21 @@ def enable_local_map_wrapping():
     yield
 
 
+def _export(model: torch.nn.Module, inputs: Tuple[Any]) -> torch.nn.Module:
+    """
+    Thin wrapper around graph capture output that restores the
+    original calling convention and attribute fqn. TODO:
+    1) Use bytecode for calling convention instead of pytree for more
+       seamless UX.
+    2) Attach guards
+    3) Be more careful about tensor constants names.
+    """
+    with torch._dynamo.config.patch(install_free_tensors=True):
+        gm = _dynamo_graph_capture_for_export(model)(*inputs)
+    _restore_state_dict(model, gm)
+    return gm
+
+
 class AutoParallel:
     """
     Args:
@@ -279,13 +296,10 @@ def build_model_graph(self):
         with set_dtype_cast(
             True
         ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
-            with torch._dynamo.config.patch(
-                install_free_tensors=True
-            ), monkey_patch_export_verifier():
-                ep = torch.export.export(self.model, inputs, strict=True)
+            torch_ir_with_fqn = _export(self.model, inputs)
             self.joint_with_descriptors = aot_export_joint_with_descriptors(
                 self.stack,
-                ep.module(),
+                torch_ir_with_fqn,
                 inputs,
                 decompositions=decomp_table,
            )
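For context, below is a minimal usage sketch of the new _export helper. The Toy module and example inputs are hypothetical, and _export wraps private PyTorch APIs (_dynamo_graph_capture_for_export, _restore_state_dict), so treat this as an illustration of the intended calling pattern rather than a supported recipe.

# Hypothetical usage sketch (not part of the commit). Assumes a recent
# PyTorch build that provides the private APIs used by _export.
import torch
import torch.nn as nn

from autoparallel.api import _export


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).relu()


model = Toy()
inputs = (torch.randn(2, 4),)

# Capture a GraphModule; per the helper's docstring, this keeps the
# module's original calling convention and restores parameter/buffer
# fully-qualified names (e.g. "linear.weight"), replacing the previous
# torch.export.export(...).module() path in build_model_graph.
gm = _export(model, inputs)

print(sorted(gm.state_dict().keys()))  # expected to match model.state_dict()
out = gm(*inputs)                      # same calling convention as model(*inputs)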
