8 | 8 | import warnings |
9 | 9 | from contextlib import ExitStack, contextmanager |
10 | 10 | from types import MethodType |
11 | | -from typing import Optional, Union |
| 11 | +from typing import Any, Optional, Tuple, Union |
12 | 12 |
13 | 13 | import torch |
| 14 | +from torch._dynamo.functional_export import _dynamo_graph_capture_for_export |
14 | 15 | from torch._functorch.aot_autograd import ( |
15 | 16 | aot_compile_joint_with_descriptors, |
16 | 17 | aot_export_joint_with_descriptors, |
22 | 23 | from torch._subclasses import FakeTensorMode |
23 | 24 | from torch.distributed.fsdp import MixedPrecisionPolicy |
24 | 25 | from torch.distributed.tensor import DeviceMesh |
| 26 | +from torch.export._trace import _restore_state_dict |
25 | 27 | from torch.export._unlift import _assign_attr |
26 | 28 | from torch.export.unflatten import _AttrKind |
27 | 29 |
@@ -163,6 +165,21 @@ def enable_local_map_wrapping(): |
163 | 165 | yield |
164 | 166 |
165 | 167 |
| 168 | +def _export(model: torch.nn.Module, inputs: Tuple[Any, ...]) -> torch.nn.Module: |
| 169 | + """ |
| 170 | + Thin wrapper around the graph capture output that restores the |
| 171 | + original calling convention and attribute FQNs. TODO: |
| 172 | + 1) Use bytecode for the calling convention instead of pytree for a more |
| 173 | + seamless UX. |
| 174 | + 2) Attach guards. |
| 175 | + 3) Be more careful about tensor constant names. |
| 176 | + """ |
| 177 | + with torch._dynamo.config.patch(install_free_tensors=True): |
| 178 | + gm = _dynamo_graph_capture_for_export(model)(*inputs) |
| 179 | + _restore_state_dict(model, gm) |
| 180 | + return gm |
| 181 | + |
| 182 | + |
166 | 183 | class AutoParallel: |
167 | 184 | """ |
168 | 185 | Args: |
@@ -279,13 +296,10 @@ def build_model_graph(self): |
279 | 296 | with set_dtype_cast( |
280 | 297 | True |
281 | 298 | ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): |
282 | | - with torch._dynamo.config.patch( |
283 | | - install_free_tensors=True |
284 | | - ), monkey_patch_export_verifier(): |
285 | | - ep = torch.export.export(self.model, inputs, strict=True) |
| 299 | + torch_ir_with_fqn = _export(self.model, inputs) |
286 | 300 | self.joint_with_descriptors = aot_export_joint_with_descriptors( |
287 | 301 | self.stack, |
288 | | - ep.module(), |
| 302 | + torch_ir_with_fqn, |
289 | 303 | inputs, |
290 | 304 | decompositions=decomp_table, |
291 | 305 | ) |
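For context, below is a minimal sketch of how the new `_export` helper added in this diff is expected to behave. It is not part of the commit: the toy module, input shapes, and the equivalence checks are illustrative assumptions, `_export` is assumed to be in scope from this file, and the private `torch._dynamo` / `torch.export` APIs it relies on may change between versions.

```python
# Illustrative sketch only (not part of this commit).
# Assumes _export from this file is in scope.
import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x).relu()


model = Toy()
inputs = (torch.randn(2, 4),)

# Dynamo graph capture with install_free_tensors=True, after which the
# original parameter/buffer FQNs are restored onto the captured GraphModule.
gm = _export(model, inputs)

# The captured module is expected to keep the eager module's state-dict keys
# (e.g. "linear.weight"), so downstream passes such as
# aot_export_joint_with_descriptors can map attributes back by FQN.
assert set(gm.state_dict().keys()) == set(model.state_dict().keys())
torch.testing.assert_close(gm(*inputs), model(*inputs))
```

Compared with the previous `torch.export.export(self.model, inputs, strict=True)` + `ep.module()` path, this keeps the traced module in Torch IR with the original attribute names and no longer needs `monkey_patch_export_verifier`, at the cost of the TODOs listed in the `_export` docstring (calling convention via bytecode, guards, tensor constant names).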