Skip to content

Commit 0ff1e4a

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
introduce cuda sdpa (#15996)
Summary: Pull Request resolved: #15996 Test Plan: buck2 test mode/dev-nosan fbcode//executorch/backends/cuda/runtime/shims/tests:test_aoti_torch_cuda_scaled_dot_product_attention Differential Revision: D87950475 Pulled By: Gasoonjia
1 parent a93f59e commit 0ff1e4a

File tree

7 files changed

+2794
-1
lines changed

7 files changed

+2794
-1
lines changed

backends/cuda/cuda_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def get_device_name(cls) -> str:
3838
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
3939
return {
4040
"at::_ops::_weight_int4pack_mm::call": None,
41+
"at::_ops::_scaled_dot_product_flash_attention::call": None,
4142
}
4243

4344
@classmethod
@@ -49,7 +50,8 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
4950
@classmethod
5051
def get_custom_passes(cls) -> List[typing.Any]:
5152
"""Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass"""
52-
return [ReplaceEdgeOpWithTritonOpPass()]
53+
return []
54+
# return [ReplaceEdgeOpWithTritonOpPass()]
5355

5456
@classmethod
5557
def get_aoti_compile_options(

backends/cuda/runtime/TARGETS

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ runtime.cxx_library(
5353
"shims/cuda_guard.cpp",
5454
"shims/int4mm.cu",
5555
"shims/memory.cpp",
56+
"shims/sdpa.cu",
5657
"shims/tensor_attribute.cpp",
5758
],
5859
headers = [
@@ -61,6 +62,8 @@ runtime.cxx_library(
6162
"shims/int4mm.cuh",
6263
"shims/int4mm.h",
6364
"shims/memory.h",
65+
"shims/sdpa.cuh",
66+
"shims/sdpa.h",
6467
"shims/tensor_attribute.h",
6568
"utils.h",
6669
],
@@ -84,6 +87,7 @@ runtime.cxx_library(
8487
],
8588
external_deps = [
8689
("cuda", None, "cuda-lazy"),
90+
("cuda", None, "cublas-lazy"),
8791
],
8892
)
8993

0 commit comments

Comments
 (0)