Skip to content

Commit 4e1e4cb

Browse files
Gasoonjia authored and facebook-github-bot committed
introduce cuda sdpa (#15996)
Summary: Pull Request resolved: #15996
Test Plan: buck2 test mode/dev-nosan fbcode//executorch/backends/cuda/runtime/shims/tests:test_aoti_torch_cuda_scaled_dot_product_attention
Differential Revision: D87950475
Pulled By: Gasoonjia
1 parent 33ec615 commit 4e1e4cb

File tree

8 files changed

+2810
-2
lines changed

8 files changed

+2810
-2
lines changed
Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
d03e90c2cd9048e6d9a75285c0355f033cd016fc
1+
8967fe914c252bf242b7d0ad4f5e098a007a6993

backends/cuda/cuda_backend.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,8 @@ def get_device_name(cls) -> str:
3939
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
4040
return {
4141
"at::_ops::_weight_int4pack_mm::call": None,
42+
"at::_ops::_scaled_dot_product_flash_attention::call": None,
43+
"at::_ops::_scaled_dot_product_efficient_attention::call": None,
4244
}
4345

4446
@classmethod
@@ -68,7 +70,8 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
6870
)
6971
triton_kernel_mode = mode
7072

71-
return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
73+
return []
74+
# return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
7275

7376
@classmethod
7477
def get_aoti_compile_options(

backends/cuda/runtime/TARGETS

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ runtime.cxx_library(
5353
"shims/cuda_guard.cpp",
5454
"shims/int4mm.cu",
5555
"shims/memory.cpp",
56+
"shims/sdpa.cu",
5657
"shims/tensor_attribute.cpp",
5758
],
5859
headers = [
@@ -61,6 +62,8 @@ runtime.cxx_library(
6162
"shims/int4mm.cuh",
6263
"shims/int4mm.h",
6364
"shims/memory.h",
65+
"shims/sdpa.cuh",
66+
"shims/sdpa.h",
6467
"shims/tensor_attribute.h",
6568
"utils.h",
6669
],
@@ -84,6 +87,7 @@ runtime.cxx_library(
8487
],
8588
external_deps = [
8689
("cuda", None, "cuda-lazy"),
90+
("cuda", None, "cublas-lazy"),
8791
],
8892
)
8993

0 commit comments

Comments
 (0)