[Fix] Enable AOTAutograd caching for graphs containing autograd_function_apply #16
Merged
jiahy0825 merged 3 commits into SandAI-org:main on Apr 3, 2026
🗂️ PR Category
📝 Description
Problem
When standalone_compile() is called from within the MagiBackend (i.e. from
inside a torch.compile backend), saving the compiled artifact always fails,
causing a fallback to "no cache handle" and forcing recompilation on every
subsequent call.
Root cause
standalone_compile() internally calls compile_fx(), which runs AOTAutograd.
AOTAutograd attempts to compute a cache key by calling check_cacheable(gm),
which visits every node in the graph. The graphs produced by MagiBackend's
split_module pass contain nodes that invoke torch.ops.higher_order.
autograd_function_apply (the internal representation of torch.autograd.
Function subclasses). By default, HigherOrderOperator.cacheable() returns
False, causing check_cacheable() to raise BypassAOTAutogradCache.
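The gating described above can be sketched as follows. This is a simplified, self-contained mimic, not PyTorch's actual source; only the names BypassAOTAutogradCache, check_cacheable, and cacheable() are taken from the description above, and the node/graph classes are illustrative stand-ins.

```python
class BypassAOTAutogradCache(Exception):
    """Raised when the graph contains a node the cache cannot key on."""

class HigherOrderOperator:
    def cacheable(self):
        # Default: higher-order ops are assumed unsafe to cache.
        return False

class Node:
    def __init__(self, target):
        self.target = target

def check_cacheable(graph_nodes):
    # Visit every node; any non-cacheable higher-order op aborts caching.
    for node in graph_nodes:
        target = node.target
        if isinstance(target, HigherOrderOperator) and not target.cacheable():
            raise BypassAOTAutogradCache("non-cacheable higher-order op")

# A graph containing an autograd_function_apply-style node:
autograd_function_apply = HigherOrderOperator()
graph = [Node("aten.add"), Node(autograd_function_apply)]
bypassed = False
try:
    check_cacheable(graph)
except BypassAOTAutogradCache:
    bypassed = True
print(bypassed)  # True: caching is bypassed for this graph
```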
This exception is silently caught upstream.
As a result, CacheArtifactManager.record_artifact() is never called for the
aot_autograd type, leaving cache_info.aot_autograd_artifacts empty.
CompiledArtifact.save() then asserts that AOT artifacts were recorded. The
assertion fails, the failure is caught by the outer try/except in
PiecewiseCompiler.compile(), which logs a warning and returns
cache_handle=None, so the artifact is never persisted and is recompiled
every time.
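The resulting failure chain can be mimicked in a few lines. These are illustrative stand-ins, not the real CompiledArtifact.save() / PiecewiseCompiler.compile() implementations:

```python
import logging

def save_compiled_artifact(cache_info):
    # Stand-in for CompiledArtifact.save(): it requires that AOT
    # artifacts were actually recorded during compilation.
    assert cache_info["aot_autograd_artifacts"], "no AOT artifacts recorded"
    return "cache-handle"

def compile_piecewise(cache_info):
    # Stand-in for PiecewiseCompiler.compile(): the save failure is
    # caught, logged, and turned into a missing cache handle.
    try:
        return save_compiled_artifact(cache_info)
    except AssertionError as exc:
        logging.warning("artifact save failed: %s", exc)
        return None  # cache_handle=None -> recompiled on every call

# BypassAOTAutogradCache left the artifact list empty:
cache_handle = compile_piecewise({"aot_autograd_artifacts": []})
print(cache_handle)  # None
```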
Why the previous workaround was wrong
A prior attempt injected a fake key directly into cache_info.
This passed the assertion in save(), but CompiledArtifact.save() then calls
load_cache_artifacts(artifact_bytes) to unpack the serialized bytes into the
cache directory. Because there was never any real AOT data captured in
artifact_bytes, no aotautograd/ subdirectory was created on disk. On the
next run, CompiledArtifact.load() (format="unpacked") fails its assertion
that the aotautograd/ data is present in the unpacked cache directory, and
crashes. The workaround fixed the save-side assertion only by lying about
the artifact contents, breaking the load path instead.
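The save/load asymmetry can be demonstrated with a small mimic. The save() and load() functions below are illustrative stand-ins, and the aotautograd/ directory layout is taken from the description above:

```python
import os
import tempfile

def save(cache_info, cache_dir, artifact_bytes):
    # Stand-in for CompiledArtifact.save(): the fake key satisfies the
    # assertion, but unpacking artifact_bytes writes nothing to disk
    # because no real AOT data was ever captured.
    assert cache_info["aot_autograd_artifacts"]
    for name in artifact_bytes:
        os.makedirs(os.path.join(cache_dir, name), exist_ok=True)

def load(cache_dir):
    # Stand-in for CompiledArtifact.load(format="unpacked"): it expects
    # the aotautograd/ subdirectory a real save would have created.
    assert os.path.isdir(os.path.join(cache_dir, "aotautograd")), \
        "missing aotautograd/ subdirectory"

load_failed = False
with tempfile.TemporaryDirectory() as cache_dir:
    cache_info = {"aot_autograd_artifacts": ["fake-key"]}  # the lie
    save(cache_info, cache_dir, artifact_bytes={})  # passes the assert
    try:
        load(cache_dir)
    except AssertionError as exc:
        load_failed = True
        print("load failed:", exc)
```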
Fix
Set torch._functorch.config.autograd_cache_allow_custom_autograd_functions=True
via a scoped config.patch() around the standalone_compile() call.
With the flag enabled, HigherOrderOperator.cacheable() no longer reports
autograd_function_apply as uncacheable, so check_cacheable() no longer
raises, AOTAutograd computes a valid cache key, records the artifact in
artifact_bytes, and save()/load() both work correctly end-to-end.
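The scoped-patch approach can be illustrated with a stand-in config object and unittest.mock rather than the real torch._functorch.config, so the sketch runs without torch; torch's config.patch() context manager behaves analogously, and only the flag name comes from the fix description:

```python
from types import SimpleNamespace
from unittest import mock

# Stand-in for torch._functorch.config with the flag at its default.
config = SimpleNamespace(autograd_cache_allow_custom_autograd_functions=False)

def standalone_compile_stub():
    # Inside the patched scope the cache gate sees the flag as True, so
    # autograd_function_apply nodes no longer bypass the AOT cache.
    return config.autograd_cache_allow_custom_autograd_functions

with mock.patch.object(
    config, "autograd_cache_allow_custom_autograd_functions", True
):
    flag_inside = standalone_compile_stub()

# Outside the scope the default is restored automatically.
flag_outside = config.autograd_cache_allow_custom_autograd_functions
print(flag_inside, flag_outside)  # True False
```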
Why this is safe in this context
The flag defaults to False because, in general use, torch.autograd.Function
subclasses may capture Python closures whose state is invisible to the cache
key computation, potentially making two semantically different compilations
collide on the same key.
In MagiBackend, all graphs arrive after Dynamo tracing. Dynamo explicitly
lifts every closed-over variable into a graph input (placeholder node), so
there is no hidden state: the graph is fully self-contained and deterministic.
The same graph structure with the same inputs will always produce the same
output, satisfying the correctness requirement for caching.
The patch() scope ensures the flag is active only during standalone_compile()
and does not affect any other code paths.