
[Fix] Enable AOTAutograd caching for graphs containing autograd_function_apply#16

Merged
jiahy0825 merged 3 commits into SandAI-org:main from wtr0504:fix/auto_grad_cache on Apr 3, 2026

Conversation


@wtr0504 wtr0504 commented Apr 2, 2026

🗂️ PR Category

  • [ ] ✨ New Feature
  • [ ] 🚀 Optimization (performance, memory, etc.)
  • [ ] 💥 Breaking Change
  • [x] 🐛 Bug Fix
  • [ ] 🛠️ Development / Refactoring
  • [ ] 📚 Documentation
  • [ ] 🧹 Chore (Dependencies, CI/CD, Configuration, etc.)
  • [ ] 🧪 Testing

📝 Description

Problem

When standalone_compile() is called from within the MagiBackend (i.e. from
inside a torch.compile backend), saving the compiled artifact always fails
with:

AssertionError: (raised inside CompiledArtifact.save())

causing a fallback to "no cache handle" and forcing recompilation on every
subsequent call.

Root cause

standalone_compile() internally calls compile_fx(), which runs AOTAutograd.
AOTAutograd attempts to compute a cache key by calling check_cacheable(gm),
which visits every node in the graph. The graphs produced by MagiBackend's
split_module pass contain nodes that invoke
torch.ops.higher_order.autograd_function_apply (the internal representation
of torch.autograd.Function subclasses). By default,
HigherOrderOperator.cacheable() returns False, causing check_cacheable() to
raise BypassAOTAutogradCache.

This exception is silently caught upstream:

except Exception as e:
    cache_key = None          # no key computed
    cache_state = "bypass"    # artifact never recorded

As a result, CacheArtifactManager.record_artifact() is never called for the
aot_autograd type, leaving cache_info.aot_autograd_artifacts empty.

CompiledArtifact.save() then asserts:

assert len(cache_info.aot_autograd_artifacts) == 1

which raises, is caught by the outer try/except in PiecewiseCompiler.compile(),
logs a warning, and returns cache_handle=None — so the artifact is never
persisted and is recompiled every time.
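The failure chain above can be mimicked in a few lines of plain Python. This is an illustrative sketch, not the real torch internals: CacheInfo, save, and compile_with_fallback are simplified stand-ins for the classes in torch._inductor / torch._functorch, modeling only the assertion and the silent fallback.

```python
# Simplified stand-ins for the torch internals described above; only the
# save-side assertion and the silent fallback are modeled.

class CacheInfo:
    def __init__(self):
        # Stays empty because record_artifact() was never called
        # for the "aot_autograd" artifact type.
        self.aot_autograd_artifacts = []

def save(cache_info):
    # Mirrors the assertion inside CompiledArtifact.save().
    assert len(cache_info.aot_autograd_artifacts) == 1
    return "cache_handle"

def compile_with_fallback(cache_info):
    # Mirrors the outer try/except in PiecewiseCompiler.compile():
    # the AssertionError is swallowed, a warning is logged, and no
    # handle is returned.
    try:
        return save(cache_info)
    except AssertionError:
        return None  # cache_handle=None -> recompile on every call

print(compile_with_fallback(CacheInfo()))  # None
```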

Why the previous workaround was wrong

A prior attempt injected a fake key directly into cache_info:

cache_info.artifacts["aot_autograd"] = [key]   # key = magi/inductor key

This passed the assertion in save(), but CompiledArtifact.save() then calls
load_cache_artifacts(artifact_bytes) to unpack the serialized bytes into the
cache directory. Because there was never any real AOT data captured in
artifact_bytes, no aotautograd/ subdirectory was created on disk. On the
next run, CompiledArtifact.load() (format="unpacked") asserts:

assert os.path.isdir(os.path.join(path, "aotautograd"))

which crashes. The workaround fixed the save-side assertion only by lying
about the artifact contents, breaking the load path instead.
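The save/load asymmetry can be sketched the same way. The only detail taken from the PR description is the aotautograd/ directory check on load; save_with_fake_key and load_unpacked are illustrative names, not the real CompiledArtifact methods.

```python
import os
import tempfile

def save_with_fake_key(path, artifacts):
    # The injected fake key satisfies the save-side assertion...
    assert len(artifacts) == 1
    # ...but no real AOT data was ever captured in artifact_bytes, so
    # unpacking never creates an aotautograd/ subdirectory under `path`.

def load_unpacked(path):
    # Mirrors the format="unpacked" assertion in CompiledArtifact.load().
    assert os.path.isdir(os.path.join(path, "aotautograd"))

with tempfile.TemporaryDirectory() as path:
    save_with_fake_key(path, ["fake_key"])  # passes
    try:
        load_unpacked(path)
    except AssertionError:
        print("load crashed: aotautograd/ missing")
```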

Fix

Set torch._functorch.config.autograd_cache_allow_custom_autograd_functions=True
via a scoped config.patch() around the standalone_compile() call.

HigherOrderOperator.cacheable() is defined as:

def cacheable(self) -> bool:
    from torch._functorch.autograd_function import AutogradFunctionApply

    return (
        self._cacheable
        or f"{self.__module__}.{self.__name__}"
        in torch._inductor.config.unsafe_marked_cacheable_functions
        or (
            isinstance(self, AutogradFunctionApply)
            and torch._functorch.config.autograd_cache_allow_custom_autograd_functions
        )
    )

With the flag set, cacheable() returns True, check_cacheable() no longer
raises, AOTAutograd computes a valid cache key, records the artifact in
artifact_bytes, and save()/load() both work correctly end-to-end.
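The gate can be illustrated with a stripped-down mimic. Config and AutogradFunctionApply below are stand-ins, and the real check also consults unsafe_marked_cacheable_functions and an isinstance test, both omitted here.

```python
class Config:
    # Stand-in for torch._functorch.config; the flag defaults to False.
    autograd_cache_allow_custom_autograd_functions = False

class AutogradFunctionApply:
    # Stand-in for the higher-order op; the real cacheable() also checks
    # unsafe_marked_cacheable_functions, omitted for brevity.
    _cacheable = False

    def cacheable(self):
        return (
            self._cacheable
            or Config.autograd_cache_allow_custom_autograd_functions
        )

op = AutogradFunctionApply()
print(op.cacheable())  # False -> check_cacheable() raises BypassAOTAutogradCache

Config.autograd_cache_allow_custom_autograd_functions = True
print(op.cacheable())  # True -> a valid cache key can be computed
```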

Why this is safe in this context

The flag defaults to False because, in general use, torch.autograd.Function
subclasses may capture Python closures whose state is invisible to the cache
key computation, potentially making two semantically different compilations
collide on the same key.

In MagiBackend, all graphs arrive after Dynamo tracing. Dynamo explicitly
lifts every closed-over variable into a graph input (placeholder node), so
there is no hidden state: the graph is fully self-contained and deterministic.
The same graph structure with the same inputs will always produce the same
output, satisfying the correctness requirement for caching.

The patch() scope ensures the flag is active only during standalone_compile()
and does not affect any other code paths.
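A minimal sketch of that scoping guarantee, assuming only the save/override/restore contract of a scoped patch (torch's config modules implement the real patch(); Config here is a stand-in):

```python
from contextlib import contextmanager

class Config:
    # Stand-in for torch._functorch.config.
    autograd_cache_allow_custom_autograd_functions = False

@contextmanager
def patch(**overrides):
    # Save current values, apply overrides, and always restore on exit,
    # the same contract as a scoped config.patch().
    saved = {name: getattr(Config, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(Config, name, value)
        yield
    finally:
        for name, value in saved.items():
            setattr(Config, name, value)

with patch(autograd_cache_allow_custom_autograd_functions=True):
    print(Config.autograd_cache_allow_custom_autograd_functions)  # True
print(Config.autograd_cache_allow_custom_autograd_functions)      # False (restored)
```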

@wtr0504 wtr0504 changed the title [Fix] Enable AOTAutograd caching for graphs containing autograd_functi… [Fix] Enable AOTAutograd caching for graphs containing autograd_function_apply Apr 2, 2026
jiahy0825 previously approved these changes Apr 2, 2026

@jiahy0825 jiahy0825 left a comment


LGTM

@jiahy0825 jiahy0825 merged commit 7cc38bd into SandAI-org:main Apr 3, 2026
2 of 4 checks passed