Skip to content

Commit 33ec615

Browse files
authored
make triton kernel usage user controllable
Differential Revision: D88096054 Pull Request resolved: #16030
1 parent 7fa93a7 commit 33ec615

File tree

4 files changed

+92
-13
lines changed

4 files changed

+92
-13
lines changed

backends/aoti/aoti_backend.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,15 @@ def get_aoti_compile_options(
7070

7171
@classmethod
7272
@abstractmethod
73-
def get_custom_passes(cls) -> List[typing.Any]:
73+
def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
7474
"""Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition."""
7575
pass
7676

77+
@classmethod
78+
def get_extra_aoti_compile_context_manager(cls):
79+
"""Return extra context manager to apply during aoti_compile stage. By default returns an empty context manager."""
80+
return contextlib.nullcontext()
81+
7782
@classmethod
7883
@contextlib.contextmanager
7984
def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
@@ -149,7 +154,7 @@ def preprocess(
149154
ReplaceViewCopyWithViewPass()(device_edge_program.graph_module)
150155

151156
# Apply custom backend-specific passes
152-
custom_passes = cls.get_custom_passes()
157+
custom_passes = cls.get_custom_passes(compile_specs)
153158
for custom_pass in custom_passes:
154159
custom_pass(device_edge_program.graph_module)
155160

@@ -174,7 +179,7 @@ def preprocess(
174179
# Compile with fallback kernel collection
175180
with cls.collect_unsupported_fallback_kernels(
176181
missing_fallback_kernels
177-
), torch.no_grad():
182+
), torch.no_grad(), cls.get_extra_aoti_compile_context_manager():
178183
paths = torch._inductor.aot_compile(
179184
edge_program_module, tuple(user_input_placeholders), options=options
180185
)

backends/apple/metal/metal_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
4242
return {}
4343

4444
@classmethod
45-
def get_custom_passes(cls) -> List[typing.Any]:
45+
def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
4646
"""Return Metal-specific passes (currently none)"""
4747
return []
4848

backends/cuda/cuda_backend.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from executorch.exir.backend.backend_details import BackendDetails
1818
from executorch.exir.backend.compile_spec_schema import CompileSpec
1919
from torch._inductor.decomposition import conv1d_to_conv2d
20+
from torch.nn.attention import SDPBackend
2021

2122

2223
@final
@@ -47,9 +48,27 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
4748
}
4849

4950
@classmethod
50-
def get_custom_passes(cls) -> List[typing.Any]:
51-
"""Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass"""
52-
return [ReplaceEdgeOpWithTritonOpPass()]
51+
def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
52+
"""
53+
Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass.
54+
55+
The Triton kernel replacement behavior can be controlled via compile_specs:
56+
- triton_kernel_mode="ON": Always use Triton kernels
57+
- triton_kernel_mode="OFF": Never use Triton kernels and fall back to other implementations such as CUDA kernels or decomposed operators.
58+
"""
59+
# Parse compile_specs for triton_kernel_mode
60+
triton_kernel_mode = "ON" # Default mode
61+
for spec in compile_specs:
62+
if spec.key == "triton_kernel_mode":
63+
mode = spec.value.decode("utf-8").upper()
64+
if mode not in ["ON", "OFF"]:
65+
raise ValueError(
66+
f"Invalid triton_kernel_mode: {mode}. "
67+
f"Expected 'ON' or 'OFF'."
68+
)
69+
triton_kernel_mode = mode
70+
71+
return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
5372

5473
@classmethod
5574
def get_aoti_compile_options(
@@ -114,3 +133,21 @@ def get_aoti_compile_options(
114133
), "shim_library_path should not be set for Linux"
115134

116135
return options
136+
137+
@classmethod
138+
def get_extra_aoti_compile_context_manager(cls):
139+
"""
140+
Return SDPA MATH backend context manager for CUDA compilation.
141+
142+
This context manager serves as a fallback solution, allowing any remaining PyTorch SDPA
143+
operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation.
144+
145+
Note:
146+
- If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass,
147+
this context manager will have no effect on those ops (they are no longer
148+
PyTorch SDPA ops).
149+
- If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this
150+
context manager will force them to use the MATH backend, causing them to
151+
be automatically decomposed during compilation.
152+
"""
153+
return torch.nn.attention.sdpa_kernel([SDPBackend.MATH])

backends/cuda/tests/test_cuda_export.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
1313
from executorch.examples.models.toy_model import SdpaModule
1414
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
15+
from executorch.exir.backend.compile_spec_schema import CompileSpec
1516
from torch.export import export
1617

1718

@@ -25,16 +26,27 @@ def setUp(self):
2526
self.skipTest("CUDA is not available")
2627

2728
def _export_to_cuda_with_lower(
28-
self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...]
29+
self,
30+
module: torch.nn.Module,
31+
inputs: Tuple[torch.Tensor, ...],
32+
compile_specs: list[CompileSpec] | None = None,
2933
) -> None:
30-
"""Helper method to export a module to CUDA backend using to_edge_transform_and_lower."""
34+
"""Helper method to export a module to CUDA backend using to_edge_transform_and_lower.
35+
36+
Args:
37+
module: The torch.nn.Module to export
38+
inputs: The example inputs for the module
39+
compile_specs: Optional list of compile specs. If not provided, defaults to
40+
only the method name compile spec for "forward"
41+
"""
3142
# Export the model
3243
exported_program = export(module, inputs, strict=True)
3344

34-
# Create partitioner and compile specs
35-
partitioner = CudaPartitioner(
36-
[CudaBackend.generate_method_name_compile_spec("forward")]
37-
)
45+
# Create partitioner with compile specs
46+
if compile_specs is None:
47+
compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]
48+
49+
partitioner = CudaPartitioner(compile_specs)
3850

3951
# Use to_edge_transform_and_lower for complete pipeline
4052
edge_program_manager = to_edge_transform_and_lower(
@@ -288,3 +300,28 @@ def test_sdpa_single_kernel(self):
288300
edge_program_manager,
289301
"SDPA single kernel operation export failed",
290302
)
303+
304+
def test_triton_kernel_mode_off(self):
305+
"""
306+
Test CUDA export with triton_kernel_mode set to OFF for SDPA kernel.
307+
This validates that the backend correctly processes the triton_kernel_mode
308+
compile spec and can export SDPA operations without Triton kernel replacements.
309+
When triton_kernel_mode is OFF, SDPA should be decomposed using the MATH backend.
310+
"""
311+
312+
sdpa = SdpaModule()
313+
314+
# Create compile specs with triton_kernel_mode set to OFF
315+
compile_specs = [
316+
CudaBackend.generate_method_name_compile_spec("forward"),
317+
CompileSpec(key="triton_kernel_mode", value=b"OFF"),
318+
]
319+
320+
# Test export with triton_kernel_mode=OFF
321+
edge_program_manager = self._export_to_cuda_with_lower(
322+
sdpa.get_eager_model(), sdpa.get_example_inputs(), compile_specs
323+
)
324+
self.assertIsNotNone(
325+
edge_program_manager,
326+
"SDPA kernel export with triton_kernel_mode=OFF failed",
327+
)

0 commit comments

Comments
 (0)