Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions magi_compiler/magi_backend/piecewise_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,14 @@ def compile(
dynamic_shapes = "from_tracing_context"

# Step2: Compile the graph
import torch._functorch.config as functorch_config
from torch._inductor import standalone_compile

try:
compiled_graph = standalone_compile(
graph, example_inputs, dynamic_shapes=dynamic_shapes, options={"config_patches": current_config}
)
with functorch_config.patch(autograd_cache_allow_custom_autograd_functions=True):
compiled_graph = standalone_compile(
graph, example_inputs, dynamic_shapes=dynamic_shapes, options={"config_patches": current_config}
)
except torch._dynamo.exc.RestartAnalysis as e:
if key is not None:
self._restart_analysis_counts[key] = self._restart_analysis_counts.get(key, 0) + 1
Expand All @@ -212,21 +214,13 @@ def compile(
raise

# Step3: Save the compiled artifact
# When standalone_compile is invoked from within a torch.compile backend,
# AOTAutograd's cache key computation may be silently bypassed, leaving
# aot_autograd artifacts empty. In that case save() will raise an
# AssertionError, so we fall back to running without a cache handle.
# TODO: Support caching the compiled artifact.
# autograd_cache_allow_custom_autograd_functions=True is required above so that
# autograd_function_apply (a HigherOrderOperator) does not bypass AOTAutograd cache
# key computation, which would leave aot_autograd_artifacts empty and cause save() to fail.
assert key is not None
restart_analysis_count = self._restart_analysis_counts.get(key, 0)
if hasattr(self, "cache_dir") and self.cache_dir is not None:
try:
# Workaround for empty aot_autograd artifacts
if getattr(compiled_graph, "_artifacts", None) is not None:
_, cache_info = compiled_graph._artifacts
if not cache_info.artifacts.get("aot_autograd"):
cache_info.artifacts["aot_autograd"] = [key]

path: Path = self.cache_dir / key
compiled_graph.save(path=path.as_posix(), format="unpacked")
compilation_counter.num_compiled_artifacts_saved += 1
Expand Down
139 changes: 139 additions & 0 deletions tests/feature_tests/cache_reuse_helper/autograd_cache_flag_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper script for test_autograd_function_cache_flag.py.

Each subprocess invocation runs one full training step (forward + backward +
optimizer) on a model that contains a torch.autograd.Function subclass, causing
Dynamo to emit an autograd_function_apply HigherOrderOperator node.

A spy is installed on torch._inductor.standalone_compile to record the value of
autograd_cache_allow_custom_autograd_functions at the moment standalone_compile
is called. Because piecewise_compiler.py imports standalone_compile inside its
compile() method body (``from torch._inductor import standalone_compile``), the
import resolves to the patched spy while the patch is active.

The spy delegates to the real standalone_compile so that compilation and artifact
saving proceed normally.

Output JSON payload
-------------------
- flag_during_compile: list of bool, one entry per standalone_compile call
- all_flags_true: True iff every entry in flag_during_compile is True
- num_standalone_compile_calls: len(flag_during_compile)
- num_compiled_artifacts_saved: from compilation_counter
- num_inductor_compiles: from compilation_counter
- loss: scalar training loss value
"""

from __future__ import annotations

import argparse
import json
from unittest.mock import patch

import torch
import torch._functorch.config as functorch_config
import torch._inductor as _inductor_mod
import torch.nn as nn

from magi_compiler import magi_compile
from magi_compiler.config import CompileMode, get_compile_config
from magi_compiler.utils import compilation_counter

# Test-fixture constants. The helper is only launched by a CUDA-gated pytest
# wrapper, so targeting "cuda" directly is safe here; the hidden size is kept
# tiny so each subprocess compile step finishes quickly.
DEVICE = "cuda"
DTYPE = torch.bfloat16  # bf16 matches the training configuration under test
HIDDEN = 16  # feature width of the single Linear layer in TrainingModel


class _ScaledSigmoid(torch.autograd.Function):
"""A custom autograd function.

When Dynamo traces ``_ScaledSigmoid.apply(x)`` it emits an
``autograd_function_apply`` HigherOrderOperator node β€” the node that
previously caused AOTAutograd caching to be bypassed.
"""

@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(x)
return torch.sigmoid(x) * 2.0

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
(x,) = ctx.saved_tensors
sig = torch.sigmoid(x)
return grad_output * sig * (1.0 - sig) * 2.0


@magi_compile(dynamic_arg_dims={"x": 0})
class TrainingModel(nn.Module):
    """A single Linear layer followed by the custom autograd function.

    Compiled via ``magi_compile`` with a dynamic batch dimension on ``x``.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(HIDDEN, HIDDEN, dtype=DTYPE, device=DEVICE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Routing the Linear output through _ScaledSigmoid.apply forces Dynamo
        # to emit an autograd_function_apply node in the captured graph.
        hidden = self.linear(x)
        return _ScaledSigmoid.apply(hidden)


def main() -> None:
    """Run one compiled training step and dump the spy observations as JSON.

    Reads ``--cache-root`` (shared artifact cache directory) and ``--output``
    (path for the JSON payload) from the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache-root", required=True)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()

    # Point magi_compile at the shared on-disk cache used by both runs.
    config = get_compile_config()
    config.compile_mode = CompileMode.MAGI_COMPILE
    config.aot = False
    config.cache_root_dir = args.cache_root

    # Deterministic setup so the warm and cache-hit runs produce comparable losses.
    torch._dynamo.reset()
    torch.manual_seed(2026)
    torch.cuda.manual_seed_all(2026)

    # Spy on standalone_compile: record the functorch flag at call time, then
    # forward to the real implementation (captured before patching) so that
    # compilation and artifact saving proceed normally.
    real_standalone_compile = _inductor_mod.standalone_compile
    observed_flags: list[bool] = []

    def _record_and_delegate(graph, example_inputs, **kwargs):
        observed_flags.append(functorch_config.autograd_cache_allow_custom_autograd_functions)
        return real_standalone_compile(graph, example_inputs, **kwargs)

    with patch("torch._inductor.standalone_compile", side_effect=_record_and_delegate):
        model = TrainingModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        batch = torch.randn(4, HIDDEN, device=DEVICE, dtype=DTYPE)

        # One full training step: forward + backward + optimizer update.
        optimizer.zero_grad()
        out = model(batch)
        loss = out.sum()
        loss.backward()
        optimizer.step()

    payload = {
        "flag_during_compile": observed_flags,
        "all_flags_true": all(observed_flags),
        "num_standalone_compile_calls": len(observed_flags),
        "num_compiled_artifacts_saved": compilation_counter.num_compiled_artifacts_saved,
        "num_inductor_compiles": compilation_counter.num_inductor_compiles,
        "loss": float(loss.float().item()),
    }
    with open(args.output, "w") as f:
        json.dump(payload, f)


if __name__ == "__main__":
main()
109 changes: 109 additions & 0 deletions tests/feature_tests/test_autograd_function_cache_flag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests verifying that InductorStandaloneAdaptor.compile() correctly patches
autograd_cache_allow_custom_autograd_functions=True around standalone_compile().

Two-process integration test (mirrors test_transformer_cache_reuse.py):

run 1 (warm) – compiles a training model containing autograd_function_apply,
saves the artifact to a shared cache directory, and verifies
that autograd_cache_allow_custom_autograd_functions was True
inside every standalone_compile() call.
run 2 (cache) – starts fresh, loads the artifact from disk, and verifies
that no recompilation occurred.

Assertions:
- "Failed to save compiled artifact" must NOT appear in run 1 stderr.
- flag_during_compile entries are all True on run 1.
- num_compiled_artifacts_saved > 0 on run 1.
- num_inductor_compiles == 0 on run 2.
- loss values are numerically consistent between runs.
"""

from __future__ import annotations

import json
import os
import subprocess
import sys
from pathlib import Path

import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_autograd_cache_flag_and_cache_reuse(tmp_path: Path):
    """Two-subprocess integration test around a model with an autograd.Function.

    Run 1 warms the cache and must observe
    autograd_cache_allow_custom_autograd_functions=True inside every
    standalone_compile() call; run 2 must load the saved artifact without any
    inductor recompilation, and both runs must agree on the loss.
    """
    helper_path = Path(__file__).parent / "cache_reuse_helper" / "autograd_cache_flag_helper.py"
    cache_root = tmp_path / "cache"
    out1 = tmp_path / "run1.json"
    out2 = tmp_path / "run2.json"

    env = os.environ.copy()
    env["MAGI_LOGGING_LEVEL"] = "info"

    def _launch(output: Path) -> subprocess.CompletedProcess:
        # Each invocation is a fresh interpreter so run 2 starts with cold
        # in-process state and can only hit the on-disk cache.
        cmd = [
            sys.executable,
            str(helper_path),
            "--cache-root",
            str(cache_root),
            "--output",
            str(output),
        ]
        return subprocess.run(cmd, env=env, capture_output=True, text=True)

    # Run 1: warm the cache.
    p1 = _launch(out1)
    assert p1.returncode == 0, f"run 1 failed\nstdout:\n{p1.stdout}\nstderr:\n{p1.stderr}"

    # The fix must prevent "Failed to save compiled artifact" from appearing.
    assert "Failed to save compiled artifact" not in p1.stderr, (
        f"CompiledArtifact.save() still failing — autograd_function_apply bypass not fixed.\nstderr:\n{p1.stderr}"
    )

    payload1 = json.loads(out1.read_text())

    # The spy must have fired, and every call must have seen the flag as True.
    assert payload1["num_standalone_compile_calls"] > 0, "Spy was never called — standalone_compile was not intercepted."
    assert payload1["all_flags_true"], (
        "autograd_cache_allow_custom_autograd_functions was NOT True during "
        f"standalone_compile(); observed per-call values: {payload1['flag_during_compile']}"
    )

    # At least one artifact must have been written to the cache.
    assert payload1["num_compiled_artifacts_saved"] > 0, (
        "Expected at least one artifact to be saved on the warm run, "
        f"got num_compiled_artifacts_saved={payload1['num_compiled_artifacts_saved']}"
    )

    # Run 2: must be a pure cache hit.
    p2 = _launch(out2)
    assert p2.returncode == 0, f"run 2 failed\nstdout:\n{p2.stdout}\nstderr:\n{p2.stderr}"

    payload2 = json.loads(out2.read_text())

    # No recompilation: PiecewiseCompiler.load() returns early before compile().
    assert payload2["num_inductor_compiles"] == 0, (
        "Expected 0 inductor compiles on the cache-hit run — artifact was not loaded.\n"
        f"num_inductor_compiles={payload2['num_inductor_compiles']}\n"
        f"stderr:\n{p2.stderr}"
    )

    # Both runs should agree numerically on the training loss.
    loss_delta = abs(payload1["loss"] - payload2["loss"])
    assert loss_delta < 1e-2, f"Loss mismatch between runs: run1={payload1['loss']}, run2={payload2['loss']}"
Loading