From f07606e30ee0b828623aad38ad4f4be9f2a81a42 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 3 Dec 2025 19:41:04 -0800 Subject: [PATCH] introduce cuda sdpa (#15996) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/15996 Differential Revision: D87950475 Pulled By: Gasoonjia --- .../ci_commit_pins/optimum-executorch.txt | 2 +- .ci/scripts/test_model_e2e.sh | 2 +- backends/aoti/aoti_backend.py | 8 +- backends/cuda/CMakeLists.txt | 7 +- backends/cuda/cuda_backend.py | 39 +- backends/cuda/runtime/TARGETS | 4 + backends/cuda/runtime/shims/sdpa.cu | 1938 +++++++++++++++++ backends/cuda/runtime/shims/sdpa.cuh | 395 ++++ backends/cuda/runtime/shims/sdpa.h | 113 + backends/cuda/runtime/shims/tests/targets.bzl | 1 + ...orch_cuda_scaled_dot_product_attention.cpp | 1815 +++++++++++++++ 11 files changed, 4298 insertions(+), 26 deletions(-) create mode 100644 backends/cuda/runtime/shims/sdpa.cu create mode 100644 backends/cuda/runtime/shims/sdpa.cuh create mode 100644 backends/cuda/runtime/shims/sdpa.h create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp diff --git a/.ci/docker/ci_commit_pins/optimum-executorch.txt b/.ci/docker/ci_commit_pins/optimum-executorch.txt index df87f35a69d..edd4419d470 100644 --- a/.ci/docker/ci_commit_pins/optimum-executorch.txt +++ b/.ci/docker/ci_commit_pins/optimum-executorch.txt @@ -1 +1 @@ -d03e90c2cd9048e6d9a75285c0355f033cd016fc +8967fe914c252bf242b7d0ad4f5e098a007a6993 diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index e26a843733f..16097ad55cb 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -126,7 +126,7 @@ case "$HF_MODEL" in esac echo "::group::Setup ExecuTorch Requirements" -./install_requirements.sh +# ./install_requirements.sh pip list echo "::endgroup::" diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index c2c587da9fe..1fcca781b29 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -214,6 +214,8 @@ def preprocess( with open(so_path, "rb") as f: so_data = f.read() + print("so_path: ", so_path) + # Read weights blob with open(blob_path, "rb") as f: blob_data = f.read() @@ -229,9 +231,9 @@ def preprocess( method_name + "_weights_blob", blob_data, 1, weights_blob_data_type ) - # Clean up the generated files - os.remove(so_path) - os.remove(blob_path) + # # Clean up the generated files + # os.remove(so_path) + # os.remove(blob_path) return PreprocessResult( processed_bytes=b"", diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index ac97b9809bf..914421ef654 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -98,6 +98,7 @@ install( set(_aoti_cuda_shim_sources runtime/shims/memory.cpp runtime/shims/tensor_attribute.cpp runtime/guard.cpp runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu + runtime/shims/sdpa.cu ${EXECUTORCH_ROOT}/backends/aoti/common_shims.cpp ) @@ -130,12 +131,12 @@ target_link_options( aoti_cuda_shims PUBLIC $<$>:-Wl,--export-dynamic> ) -# Link against CUDA::cudart, common AOTI library, cuda_tensor_maker, and +# Link against CUDA::cudart, CUDA::cublas, common AOTI library, cuda_tensor_maker, and # platform utilities target_link_libraries( aoti_cuda_shims - PRIVATE cuda_platform - PUBLIC extension_tensor cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS} + PRIVATE cuda_platform executorch_core + PUBLIC extension_tensor cuda_tensor_maker CUDA::cudart CUDA::cublas ${CMAKE_DL_LIBS} ) if(NOT MSVC) diff 
--git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index f0d3a000ec0..8cd7c2df119 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -39,6 +39,8 @@ def get_device_name(cls) -> str: def get_supported_fallback_kernels(cls) -> Dict[str, Any]: return { "at::_ops::_weight_int4pack_mm::call": None, + "at::_ops::_scaled_dot_product_flash_attention::call": None, + "at::_ops::_scaled_dot_product_efficient_attention::call": None, } @classmethod @@ -68,7 +70,8 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any] ) triton_kernel_mode = mode - return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else [] + return [] + # return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else [] @classmethod def get_aoti_compile_options( @@ -134,20 +137,20 @@ def get_aoti_compile_options( return options - @classmethod - def get_extra_aoti_compile_context_manager(cls): - """ - Return SDPA MATH backend context manager for CUDA compilation. - - This context manager plays as a fallback solution for any remaining PyTorch SDPA - operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. - - Note: - - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, - this context manager will have no effect on those ops (they are no longer - PyTorch SDPA ops). - - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this - context manager will force them to use the MATH backend, causing them to - be automatically decomposed during compilation. - """ - return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) + # @classmethod + # def get_extra_aoti_compile_context_manager(cls): + # """ + # Return SDPA MATH backend context manager for CUDA compilation. + + # This context manager plays as a fallback solution for any remaining PyTorch SDPA + # operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. + + # Note: + # - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, + # this context manager will have no effect on those ops (they are no longer + # PyTorch SDPA ops). + # - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this + # context manager will force them to use the MATH backend, causing them to + # be automatically decomposed during compilation. + # """ + # return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index a85f3a7e6a3..01dabee9086 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -53,6 +53,7 @@ runtime.cxx_library( "shims/cuda_guard.cpp", "shims/int4mm.cu", "shims/memory.cpp", + "shims/sdpa.cu", "shims/tensor_attribute.cpp", ], headers = [ @@ -61,6 +62,8 @@ runtime.cxx_library( "shims/int4mm.cuh", "shims/int4mm.h", "shims/memory.h", + "shims/sdpa.cuh", + "shims/sdpa.h", "shims/tensor_attribute.h", "utils.h", ], @@ -84,6 +87,7 @@ runtime.cxx_library( ], external_deps = [ ("cuda", None, "cuda-lazy"), + ("cuda", None, "cublas-lazy"), ], ) diff --git a/backends/cuda/runtime/shims/sdpa.cu b/backends/cuda/runtime/shims/sdpa.cu new file mode 100644 index 00000000000..56033144708 --- /dev/null +++ b/backends/cuda/runtime/shims/sdpa.cu @@ -0,0 +1,1938 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; +using executorch::runtime::Error; + +// ============================================================================ +// CUDA Kernels for Softmax and Masking +// ============================================================================ + +// Helper function for max with different types +__device__ __forceinline__ float device_max(float a, float b) { + return fmaxf(a, b); +} + +__device__ __forceinline__ __half device_max(__half a, __half b) { + return __hgt(a, b) ? a : b; +} + +__device__ __forceinline__ __nv_bfloat16 +device_max(__nv_bfloat16 a, __nv_bfloat16 b) { +#if __CUDA_ARCH__ >= 800 + return __hgt(a, b) ? a : b; +#else + return __float2bfloat16(fmaxf(__bfloat162float(a), __bfloat162float(b))); +#endif +} + +/** + * Softmax kernel with optional causal masking and attention bias + * + * Computes softmax along the last dimension (seq_len_k) of a 4D tensor. + * Supports: + * - Causal masking where positions j > i are masked out + * - Explicit attention bias (additive mask, 0.0 = allow, -inf = mask) + * + * Input: [batch, num_heads, seq_len_q, seq_len_k] + * Output: [batch, num_heads, seq_len_q, seq_len_k] + * Bias (optional): [batch, num_heads, seq_len_q, seq_len_k] or broadcastable + * + * Each thread processes one row (seq_len_q position). + * + * Note: Supports in-place operation (input == output). + */ +template +__global__ void softmax_with_mask_kernel( + const scalar_t* input, + scalar_t* output, + const scalar_t* attn_bias, // Optional attention bias (can be nullptr) + const int64_t batch, + const int64_t num_heads, + const int64_t seq_len_q, + const int64_t seq_len_k, + const bool is_causal, + const float scale) { + // Each block processes one row of the attention matrix + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t total_rows = batch * num_heads * seq_len_q; + + if (idx >= total_rows) { + return; + } + + // Decode position - we only need i for causal masking + const int64_t i = idx % seq_len_q; + + // Pointer to the start of this row + const int64_t row_offset = idx * seq_len_k; + const scalar_t* input_row = input + row_offset; + scalar_t* output_row = output + row_offset; + + // Attention bias for this row (if provided) + const scalar_t* bias_row = attn_bias ? 
(attn_bias + row_offset) : nullptr; + + // Find max for numerical stability (two-pass algorithm) + float max_val = -FLT_MAX; + for (int64_t j = 0; j < seq_len_k; ++j) { + // Apply scaling and bias + float val = static_cast<float>(input_row[j]) * scale; + if (bias_row) { + float bias = static_cast<float>(bias_row[j]); + val += bias; // Additive bias (0.0 = allow, -inf = mask) + } + + // Apply causal mask if needed + if (!is_causal || j <= i) { + max_val = fmaxf(max_val, val); + } + } + + // Compute exp and sum + float sum_exp = 0.0f; + for (int64_t j = 0; j < seq_len_k; ++j) { + float val = static_cast<float>(input_row[j]) * scale; + if (bias_row) { + float bias = static_cast<float>(bias_row[j]); + val += bias; + } + + float exp_val; + if (!is_causal || j <= i) { + exp_val = expf(val - max_val); + } else { + exp_val = 0.0f; + } + output_row[j] = static_cast<scalar_t>(exp_val); + sum_exp += exp_val; + } + + // Normalize + const float inv_sum = 1.0f / (sum_exp + 1e-12f); // Add epsilon for stability + for (int64_t j = 0; j < seq_len_k; ++j) { + output_row[j] = + static_cast<scalar_t>(static_cast<float>(output_row[j]) * inv_sum); + } +} + +/** + * Scale kernel - multiply all elements by a scalar + */ +template <typename scalar_t> +__global__ void scale_kernel( + scalar_t* __restrict__ data, + const int64_t size, + const float scale) { + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + data[idx] = static_cast<scalar_t>(static_cast<float>(data[idx]) * scale); + } +} + +// ============================================================================ +// cuBLAS Helper Functions +// ============================================================================ + +/** + * Get or create a cuBLAS handle for the current stream + * + * Note: In production, this should use a handle pool or be managed + * by the backend infrastructure. This is a simplified version.
+ */ +cublasHandle_t get_cublas_handle(cudaStream_t stream) { + static cublasHandle_t handle = nullptr; + static bool init_attempted = false; + static bool init_success = false; + + if (!init_attempted) { + init_attempted = true; + cublasStatus_t status = cublasCreate_v2(&handle); + if (status != CUBLAS_STATUS_SUCCESS) { + ET_LOG( + Error, + "Failed to create cuBLAS handle: %d", + static_cast(status)); + handle = nullptr; + init_success = false; + } else { + ET_LOG(Info, "cuBLAS handle created successfully"); + init_success = true; + } + } + + if (!init_success || handle == nullptr) { + return nullptr; + } + + cublasStatus_t status = cublasSetStream_v2(handle, stream); + if (status != CUBLAS_STATUS_SUCCESS) { + ET_LOG(Error, "Failed to set cuBLAS stream: %d", static_cast(status)); + return nullptr; + } + + return handle; +} + +/** + * Batched matrix multiplication wrapper for cuBLAS + * + * Computes: C = alpha * op(A) @ op(B) + beta * C + * for a batch of matrices + */ +template +cublasStatus_t batched_gemm( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const scalar_t* alpha, + const scalar_t* A, + int lda, + int64_t strideA, + const scalar_t* B, + int ldb, + int64_t strideB, + const scalar_t* beta, + scalar_t* C, + int ldc, + int64_t strideC, + int batchCount); + +// Specializations for different types +template <> +cublasStatus_t batched_gemm( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* A, + int lda, + int64_t strideA, + const float* B, + int ldb, + int64_t strideB, + const float* beta, + float* C, + int ldc, + int64_t strideC, + int batchCount) { + return cublasSgemmStridedBatched( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + lda, + strideA, + B, + ldb, + strideB, + beta, + C, + ldc, + strideC, + batchCount); +} + +template <> +cublasStatus_t batched_gemm<__half>( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const __half* alpha, + const __half* A, + int lda, + int64_t strideA, + const __half* B, + int ldb, + int64_t strideB, + const __half* beta, + __half* C, + int ldc, + int64_t strideC, + int batchCount) { + return cublasHgemmStridedBatched( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + lda, + strideA, + B, + ldb, + strideB, + beta, + C, + ldc, + strideC, + batchCount); +} + +// Note: BFloat16 uses compute type float internally +template <> +cublasStatus_t batched_gemm<__nv_bfloat16>( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const __nv_bfloat16* alpha, + const __nv_bfloat16* A, + int lda, + int64_t strideA, + const __nv_bfloat16* B, + int ldb, + int64_t strideB, + const __nv_bfloat16* beta, + __nv_bfloat16* C, + int ldc, + int64_t strideC, + int batchCount) { +// cuBLAS BFloat16 GEMM - introduced in CUDA 11+ +#if CUDA_VERSION >= 11000 + // For BFloat16, we need to use cublasGemmStridedBatchedEx + // with compute type CUBLAS_COMPUTE_32F + float alpha_f = 1.0f; + float beta_f = 0.0f; + + return cublasGemmStridedBatchedEx( + handle, + transa, + transb, + m, + n, + k, + &alpha_f, + A, + CUDA_R_16BF, + lda, + strideA, + B, + CUDA_R_16BF, + ldb, + strideB, + &beta_f, + C, + CUDA_R_16BF, + ldc, + strideC, + batchCount, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT); +#else + ET_LOG(Error, "BFloat16 GEMM requires CUDA 11.0 or later"); + return 
CUBLAS_STATUS_NOT_SUPPORTED; +#endif +} + +// ============================================================================ +// Flash Attention Implementation +// ============================================================================ + +/** + * Flash Attention kernel - memory-efficient attention computation + * + * This kernel implements the Flash Attention algorithm which computes + * attention in blocks to reduce memory usage from O(N^2) to O(N). + */ +template +__global__ void flash_attention_kernel( + const scalar_t* __restrict__ query, + const scalar_t* __restrict__ key, + const scalar_t* __restrict__ value, + scalar_t* __restrict__ output, + const int64_t seq_len_q, + const int64_t seq_len_k, + const int64_t head_dim, + const int64_t head_dim_v, + const float scale, + const bool is_causal) { + const int64_t batch_head_idx = blockIdx.z; + const int64_t q_block_idx = blockIdx.y; + const int64_t q_start = q_block_idx * BLOCK_SIZE; + const int64_t q_end = min(q_start + BLOCK_SIZE, seq_len_q); + const int64_t q_block_size = q_end - q_start; + const int tid = threadIdx.x; + + extern __shared__ char shared_mem[]; + scalar_t* shared_q = reinterpret_cast(shared_mem); + scalar_t* shared_k = shared_q + BLOCK_SIZE * head_dim; + + const scalar_t* q_base = query + batch_head_idx * seq_len_q * head_dim; + const scalar_t* k_base = key + batch_head_idx * seq_len_k * head_dim; + const scalar_t* v_base = value + batch_head_idx * seq_len_k * head_dim_v; + scalar_t* out_base = output + batch_head_idx * seq_len_q * head_dim_v; + + for (int i = tid; i < q_block_size * head_dim; i += blockDim.x) { + int q_idx = i / head_dim; + int d_idx = i % head_dim; + if (q_start + q_idx < seq_len_q) { + shared_q[i] = q_base[(q_start + q_idx) * head_dim + d_idx]; + } + } + __syncthreads(); + + for (int64_t q_local = 0; q_local < q_block_size; ++q_local) { + if (tid != 0) + continue; + + const int64_t q_idx = q_start + q_local; + float max_score = -FLT_MAX; + float sum_exp = 0.0f; + float output_acc[64]; + for (int d = 0; d < head_dim_v; ++d) { + output_acc[d] = 0.0f; + } + + const int64_t k_blocks = (seq_len_k + BLOCK_SIZE - 1) / BLOCK_SIZE; + for (int64_t k_block_idx = 0; k_block_idx < k_blocks; ++k_block_idx) { + const int64_t k_start = k_block_idx * BLOCK_SIZE; + const int64_t k_end = min(k_start + BLOCK_SIZE, seq_len_k); + const int64_t k_block_size = k_end - k_start; + + float block_scores[64]; + float block_max = -FLT_MAX; + + for (int64_t k_local = 0; k_local < k_block_size; ++k_local) { + const int64_t k_idx = k_start + k_local; + if (is_causal && k_idx > q_idx) { + block_scores[k_local] = -FLT_MAX; + continue; + } + + float score = 0.0f; + for (int64_t d = 0; d < head_dim; ++d) { + float q_val = static_cast(shared_q[q_local * head_dim + d]); + float k_val = static_cast(k_base[k_idx * head_dim + d]); + score += q_val * k_val; + } + score *= scale; + block_scores[k_local] = score; + block_max = fmaxf(block_max, score); + } + + float new_max = fmaxf(max_score, block_max); + float exp_correction = expf(max_score - new_max); + + for (int d = 0; d < head_dim_v; ++d) { + output_acc[d] *= exp_correction; + } + sum_exp *= exp_correction; + + for (int64_t k_local = 0; k_local < k_block_size; ++k_local) { + const int64_t k_idx = k_start + k_local; + if (is_causal && k_idx > q_idx) + continue; + + float exp_score = expf(block_scores[k_local] - new_max); + sum_exp += exp_score; + + for (int64_t d = 0; d < head_dim_v; ++d) { + float v_val = static_cast(v_base[k_idx * head_dim_v + d]); + output_acc[d] += exp_score * 
v_val; + } + } + max_score = new_max; + } + + float inv_sum = 1.0f / sum_exp; + for (int64_t d = 0; d < head_dim_v; ++d) { + out_base[q_idx * head_dim_v + d] = + static_cast(output_acc[d] * inv_sum); + } + } +} + +/** + * Flash Attention implementation dispatcher + */ +template +Tensor* sdpa_flash_attention_impl( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + bool is_causal, + float scale_factor, + cudaStream_t stream) { + const int64_t batch = query->size(0); + const int64_t num_heads = query->size(1); + const int64_t seq_len_q = query->size(2); + const int64_t head_dim = query->size(3); + const int64_t seq_len_k = key->size(2); + const int64_t head_dim_v = value->size(3); + + Tensor* output = nullptr; + std::array output_shape = { + batch, num_heads, seq_len_q, head_dim_v}; + std::array output_stride = { + num_heads * seq_len_q * head_dim_v, + seq_len_q * head_dim_v, + head_dim_v, + 1}; + + auto dtype_int = static_cast(query->dtype()); + aoti_torch_empty_strided( + 4, + output_shape.data(), + output_stride.data(), + dtype_int, + static_cast(SupportedDevices::CUDA), + 0, + &output); + + if (output == nullptr) { + ET_LOG(Error, "sdpa_flash_attention: Failed to allocate output tensor"); + return nullptr; + } + + constexpr int BLOCK_SIZE = 64; + const int64_t q_blocks = (seq_len_q + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int64_t batch_head_count = batch * num_heads; + const size_t shared_mem_size = BLOCK_SIZE * head_dim * sizeof(scalar_t) * 2 + + BLOCK_SIZE * BLOCK_SIZE * sizeof(float); + + dim3 grid(1, q_blocks, batch_head_count); + dim3 block(256); + + const scalar_t* q_ptr = reinterpret_cast(query->data_ptr()); + const scalar_t* k_ptr = reinterpret_cast(key->data_ptr()); + const scalar_t* v_ptr = reinterpret_cast(value->data_ptr()); + scalar_t* out_ptr = reinterpret_cast(output->data_ptr()); + + flash_attention_kernel + <<>>( + q_ptr, + k_ptr, + v_ptr, + out_ptr, + seq_len_q, + seq_len_k, + head_dim, + head_dim_v, + scale_factor, + is_causal); + + cudaError_t cuda_err = cudaGetLastError(); + if (cuda_err != cudaSuccess) { + ET_LOG( + Error, + "sdpa_flash_attention: Kernel launch failed: %s", + cudaGetErrorString(cuda_err)); + aoti_torch_delete_tensor_object(output); + return nullptr; + } + return output; +} + +/** + * Flash Attention entry point with dtype dispatch + */ +Tensor* sdpa_flash_attention( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + bool is_causal, + double scale_factor, + cudaStream_t stream) { + auto dtype = query->dtype(); + if (dtype == executorch::aten::ScalarType::Float) { + return sdpa_flash_attention_impl( + query, + key, + value, + attn_mask, + is_causal, + static_cast(scale_factor), + stream); + } else if (dtype == executorch::aten::ScalarType::Half) { + return sdpa_flash_attention_impl<__half>( + query, + key, + value, + attn_mask, + is_causal, + static_cast(scale_factor), + stream); + } else if (dtype == executorch::aten::ScalarType::BFloat16) { + return sdpa_flash_attention_impl<__nv_bfloat16>( + query, + key, + value, + attn_mask, + is_causal, + static_cast(scale_factor), + stream); + } else { + ET_LOG(Error, "sdpa_flash_attention: Unsupported dtype"); + return nullptr; + } +} + +// ============================================================================ +// Memory-Efficient Attention Implementation (with attn_bias support) +// ============================================================================ + +/** + * Memory-Efficient Attention kernel with attention 
bias support + * + * This kernel computes scaled dot-product attention with full support for + * attention bias (additive mask). It uses online softmax for numerical + * stability. + * + * Each thread processes one query position independently. + * + * Input shapes: + * - query: [batch, num_heads, seq_len_q, head_dim] + * - key: [batch, num_heads, seq_len_k, head_dim] + * - value: [batch, num_heads, seq_len_k, head_dim_v] + * - attn_bias: [batch, num_heads, seq_len_q, seq_len_k] (optional) + * - output: [batch, num_heads, seq_len_q, head_dim_v] + */ +template +__global__ void efficient_attention_kernel( + const scalar_t* __restrict__ query, + const scalar_t* __restrict__ key, + const scalar_t* __restrict__ value, + const scalar_t* __restrict__ attn_bias, + scalar_t* __restrict__ output, + const int64_t num_heads, + const int64_t seq_len_q, + const int64_t seq_len_k, + const int64_t head_dim, + const int64_t head_dim_v, + const float scale, + const bool is_causal, + // Query strides [batch, head, seq, dim] + const int64_t q_stride_batch, + const int64_t q_stride_head, + const int64_t q_stride_seq, + const int64_t q_stride_dim, + // Key strides [batch, head, seq, dim] + const int64_t k_stride_batch, + const int64_t k_stride_head, + const int64_t k_stride_seq, + const int64_t k_stride_dim, + // Value strides [batch, head, seq, dim] + const int64_t v_stride_batch, + const int64_t v_stride_head, + const int64_t v_stride_seq, + const int64_t v_stride_dim, + // Output strides [batch, head, seq, dim] - always contiguous + const int64_t o_stride_batch, + const int64_t o_stride_head, + const int64_t o_stride_seq, + const int64_t o_stride_dim, + // Bias strides [batch, head, seq_q, seq_k] + const int64_t bias_stride_batch, + const int64_t bias_stride_head, + const int64_t bias_stride_q, + const int64_t bias_stride_k) { + const int64_t batch_head_idx = blockIdx.x; + const int64_t q_idx = blockIdx.y * blockDim.x + threadIdx.x; + + if (q_idx >= seq_len_q) { + return; + } + + // Decompose batch_head_idx into batch and head indices + const int64_t batch_idx = batch_head_idx / num_heads; + const int64_t head_idx = batch_head_idx % num_heads; + + // Compute base pointers using proper 4D strides + const scalar_t* q_base = + query + batch_idx * q_stride_batch + head_idx * q_stride_head; + const scalar_t* k_base = + key + batch_idx * k_stride_batch + head_idx * k_stride_head; + const scalar_t* v_base = + value + batch_idx * v_stride_batch + head_idx * v_stride_head; + scalar_t* out_base = + output + batch_idx * o_stride_batch + head_idx * o_stride_head; + + // Compute bias base pointer using proper 4D indexing with broadcasting + // support + const scalar_t* bias_base = nullptr; + if (attn_bias != nullptr) { + // Only add stride contribution if the dimension size > 1 (not broadcasting) + int64_t bias_offset = 0; + + // Note: bias_stride will be 0 for broadcasting dimensions (size=1) + // This is correct - we want all positions to point to the same element + bias_offset += batch_idx * bias_stride_batch; + bias_offset += head_idx * bias_stride_head; + bias_offset += q_idx * bias_stride_q; + + bias_base = attn_bias + bias_offset; + } + + float output_acc[MAX_HEAD_DIM_V]; + for (int64_t d = 0; d < head_dim_v && d < MAX_HEAD_DIM_V; ++d) { + output_acc[d] = 0.0f; + } + + float max_score = -FLT_MAX; + float sum_exp = 0.0f; + + for (int64_t k_idx = 0; k_idx < seq_len_k; ++k_idx) { + if (is_causal && k_idx > q_idx) { + continue; + } + + float score = 0.0f; + for (int64_t d = 0; d < head_dim; ++d) { + float q_val = + 
static_cast(q_base[q_idx * q_stride_seq + d * q_stride_dim]); + float k_val = + static_cast(k_base[k_idx * k_stride_seq + d * k_stride_dim]); + score += q_val * k_val; + } + score *= scale; + + // Add bias if provided + // Note: bias_stride_k should be 1, and we're indexing along the last + // dimension + if (bias_base != nullptr) { + float bias_val = static_cast(bias_base[k_idx * bias_stride_k]); + score += bias_val; + } + + float new_max = fmaxf(max_score, score); + float exp_correction = expf(max_score - new_max); + + for (int64_t d = 0; d < head_dim_v && d < MAX_HEAD_DIM_V; ++d) { + output_acc[d] *= exp_correction; + } + sum_exp = sum_exp * exp_correction + expf(score - new_max); + + float exp_score = expf(score - new_max); + for (int64_t d = 0; d < head_dim_v && d < MAX_HEAD_DIM_V; ++d) { + float v_val = + static_cast(v_base[k_idx * v_stride_seq + d * v_stride_dim]); + output_acc[d] += exp_score * v_val; + } + + max_score = new_max; + } + + float inv_sum = (sum_exp > 0.0f) ? (1.0f / sum_exp) : 0.0f; + for (int64_t d = 0; d < head_dim_v && d < MAX_HEAD_DIM_V; ++d) { + out_base[q_idx * o_stride_seq + d * o_stride_dim] = + static_cast(output_acc[d] * inv_sum); + } +} + +/** + * Memory-Efficient Attention implementation dispatcher + */ +template +Tensor* sdpa_efficient_attention_impl( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_bias, + bool is_causal, + float scale_factor, + cudaStream_t stream) { + const int64_t batch = query->size(0); + const int64_t num_heads = query->size(1); + const int64_t seq_len_q = query->size(2); + const int64_t head_dim = query->size(3); + const int64_t seq_len_k = key->size(2); + const int64_t head_dim_v = value->size(3); + + Tensor* output = nullptr; + std::array output_shape = { + batch, num_heads, seq_len_q, head_dim_v}; + std::array output_stride = { + num_heads * seq_len_q * head_dim_v, + seq_len_q * head_dim_v, + head_dim_v, + 1}; + + auto dtype_int = static_cast(query->dtype()); + aoti_torch_empty_strided( + 4, + output_shape.data(), + output_stride.data(), + dtype_int, + static_cast(SupportedDevices::CUDA), + 0, + &output); + + if (output == nullptr) { + ET_LOG(Error, "sdpa_efficient_attention: Failed to allocate output tensor"); + return nullptr; + } + + const int64_t batch_head_count = batch * num_heads; + const int threads_per_block = 128; + const int64_t q_blocks = + (seq_len_q + threads_per_block - 1) / threads_per_block; + + dim3 grid(batch_head_count, q_blocks); + dim3 block(threads_per_block); + + const scalar_t* q_ptr = reinterpret_cast(query->data_ptr()); + const scalar_t* k_ptr = reinterpret_cast(key->data_ptr()); + const scalar_t* v_ptr = reinterpret_cast(value->data_ptr()); + scalar_t* out_ptr = reinterpret_cast(output->data_ptr()); + + const scalar_t* bias_ptr = nullptr; + int64_t bias_stride_batch = 0; + int64_t bias_stride_head = 0; + int64_t bias_stride_q = 0; + int64_t bias_stride_k = 1; + + if (attn_bias != nullptr) { + bias_ptr = reinterpret_cast(attn_bias->data_ptr()); + + int64_t bias_dim = attn_bias->dim(); + printf(" attn_bias: ptr=%p, dim=%ld\n", (void*)bias_ptr, bias_dim); + fflush(stdout); + + if (bias_dim == 4) { + // Handle attention bias with shape [batch, num_heads, seq_len_q, + // seq_len_k] or broadcastable variants + auto bias_strides = attn_bias->strides(); + auto bias_sizes = attn_bias->sizes(); + + // Extract sizes safely + int64_t bias_size_0 = bias_sizes[0]; + int64_t bias_size_1 = bias_sizes[1]; + int64_t bias_size_2 = bias_sizes[2]; + int64_t bias_size_3 = 
bias_sizes[3]; + + // Extract strides safely + int64_t bias_stride_0 = bias_strides[0]; + int64_t bias_stride_1 = bias_strides[1]; + int64_t bias_stride_2 = bias_strides[2]; + int64_t bias_stride_3 = bias_strides[3]; + + printf( + " bias dim=4, sizes=[%ld, %ld, %ld, %ld], strides=[%ld, %ld, %ld, %ld]\n", + bias_size_0, + bias_size_1, + bias_size_2, + bias_size_3, + bias_stride_0, + bias_stride_1, + bias_stride_2, + bias_stride_3); + fflush(stdout); + + bias_stride_batch = bias_stride_0; + bias_stride_head = bias_stride_1; + bias_stride_q = bias_stride_2; + bias_stride_k = bias_stride_3; + } else { + printf( + " WARNING: attn_bias has unexpected dim=%ld (expected 4) tried to print its top 4 size and stride to see the value\n", + bias_dim); + fflush(stdout); + + bias_dim = 4; + + // Try to handle 1D or other dimensions + auto bias_strides = attn_bias->strides(); + auto bias_sizes = attn_bias->sizes(); + printf(" bias_sizes: "); + for (int64_t i = 0; i < bias_dim; ++i) { + printf("%d ", bias_sizes[i]); + } + printf("\n bias_strides: "); + for (int64_t i = 0; i < bias_dim; ++i) { + printf("%d ", bias_strides[i]); + } + printf("\n"); + fflush(stdout); + exit(1); + } + } + + // Debug: Print query tensor info + auto query_strides = query->strides(); + auto key_strides = key->strides(); + auto value_strides = value->strides(); + + printf("Launching efficient_attention_kernel:\n"); + printf( + " batch=%ld, num_heads=%ld, seq_len_q=%ld, seq_len_k=%ld, head_dim=%ld, head_dim_v=%ld\n", + batch, + num_heads, + seq_len_q, + seq_len_k, + head_dim, + head_dim_v); + // printf( + // " query_strides=[%ld, %ld, %ld, %ld]\n", + // query_strides[0], + // query_strides[1], + // query_strides[2], + // query_strides[3]); + // printf( + // " key_strides=[%ld, %ld, %ld, %ld]\n", + // key_strides[0], + // key_strides[1], + // key_strides[2], + // key_strides[3]); + // printf( + // " value_strides=[%ld, %ld, %ld, %ld]\n", + // value_strides[0], + // value_strides[1], + // Debug: Print query tensor info + // auto query_strides = query->strides(); + // auto key_strides = key->strides(); + // auto value_strides = value->strides(); + + printf("\n=== Efficient Attention Kernel Launch Details ===\n"); + printf("Tensor Dimensions:\n"); + printf( + " batch=%ld, num_heads=%ld, seq_len_q=%ld, seq_len_k=%ld\n", + batch, + num_heads, + seq_len_q, + seq_len_k); + printf(" head_dim=%ld, head_dim_v=%ld\n", head_dim, head_dim_v); + + printf( + "\nQuery Tensor (shape: [%ld, %ld, %ld, %ld]):\n", + batch, + num_heads, + seq_len_q, + head_dim); + printf( + " strides=[%d, %d, %d, %d]\n", + query_strides[0], + query_strides[1], + query_strides[2], + query_strides[3]); + + printf( + "\nKey Tensor (shape: [%ld, %ld, %ld, %ld]):\n", + batch, + num_heads, + seq_len_k, + head_dim); + printf( + " strides=[%d, %d, %d, %d]\n", + key_strides[0], + key_strides[1], + key_strides[2], + key_strides[3]); + + printf( + "\nValue Tensor (shape: [%ld, %ld, %ld, %ld]):\n", + batch, + num_heads, + seq_len_k, + head_dim_v); + printf( + " strides=[%d, %d, %d, %d]\n", + value_strides[0], + value_strides[1], + value_strides[2], + value_strides[3]); + + if (attn_bias != nullptr) { + auto bias_sizes = attn_bias->sizes(); + printf("\nAttention Bias Tensor (shape: ["); + for (int64_t i = 0; i < attn_bias->dim(); ++i) { + printf("%d", bias_sizes[i]); + if (i < attn_bias->dim() - 1) + printf(", "); + } + printf("]):\n"); + printf(" bias_ptr=%p\n", (void*)bias_ptr); + printf( + " strides=[batch:%ld, head:%ld, q:%ld, k:%ld]\n", + bias_stride_batch, + bias_stride_head, + 
bias_stride_q, + bias_stride_k); + } else { + printf("\nAttention Bias: None (nullptr)\n"); + } + + printf("\nKernel Configuration:\n"); + printf(" scale_factor=%.6f\n", scale_factor); + printf(" is_causal=%d\n", is_causal); + printf( + " grid=(%ld, %ld, 1) [batch_head_count=%ld, q_blocks=%ld]\n", + batch_head_count, + q_blocks, + batch_head_count, + q_blocks); + printf( + " block=(%d, 1, 1) [threads_per_block=%d]\n", + threads_per_block, + threads_per_block); + + // Verify that Q/K/V are contiguous in the batch_head*seq*dim layout + bool q_is_contiguous = + (query_strides[0] == num_heads * seq_len_q * head_dim) && + (query_strides[1] == seq_len_q * head_dim) && + (query_strides[2] == head_dim) && (query_strides[3] == 1); + bool k_is_contiguous = (key_strides[0] == num_heads * seq_len_k * head_dim) && + (key_strides[1] == seq_len_k * head_dim) && + (key_strides[2] == head_dim) && (key_strides[3] == 1); + bool v_is_contiguous = + (value_strides[0] == num_heads * seq_len_k * head_dim_v) && + (value_strides[1] == seq_len_k * head_dim_v) && + (value_strides[2] == head_dim_v) && (value_strides[3] == 1); + + printf("\nMemory Layout Check:\n"); + printf(" Q is contiguous: %s\n", q_is_contiguous ? "YES" : "NO"); + printf(" K is contiguous: %s\n", k_is_contiguous ? "YES" : "NO"); + printf(" V is contiguous: %s\n", v_is_contiguous ? "YES" : "NO"); + + if (!q_is_contiguous || !k_is_contiguous || !v_is_contiguous) { + printf( + " WARNING: Non-contiguous tensor detected! Kernel will use stride-based indexing.\n"); + } + printf("==============================================\n\n"); + fflush(stdout); + + // Output strides (always contiguous) + int64_t o_stride_batch = num_heads * seq_len_q * head_dim_v; + int64_t o_stride_head = seq_len_q * head_dim_v; + int64_t o_stride_seq = head_dim_v; + int64_t o_stride_dim = 1; + + if (head_dim_v <= 64) { + efficient_attention_kernel<<>>( + q_ptr, + k_ptr, + v_ptr, + bias_ptr, + out_ptr, + num_heads, + seq_len_q, + seq_len_k, + head_dim, + head_dim_v, + scale_factor, + is_causal, + // Query strides + query_strides[0], + query_strides[1], + query_strides[2], + query_strides[3], + // Key strides + key_strides[0], + key_strides[1], + key_strides[2], + key_strides[3], + // Value strides + value_strides[0], + value_strides[1], + value_strides[2], + value_strides[3], + // Output strides + o_stride_batch, + o_stride_head, + o_stride_seq, + o_stride_dim, + // Bias strides + bias_stride_batch, + bias_stride_head, + bias_stride_q, + bias_stride_k); + } else { + efficient_attention_kernel<<>>( + q_ptr, + k_ptr, + v_ptr, + bias_ptr, + out_ptr, + num_heads, + seq_len_q, + seq_len_k, + head_dim, + head_dim_v, + scale_factor, + is_causal, + // Query strides + query_strides[0], + query_strides[1], + query_strides[2], + query_strides[3], + // Key strides + key_strides[0], + key_strides[1], + key_strides[2], + key_strides[3], + // Value strides + value_strides[0], + value_strides[1], + value_strides[2], + value_strides[3], + // Output strides + o_stride_batch, + o_stride_head, + o_stride_seq, + o_stride_dim, + // Bias strides + bias_stride_batch, + bias_stride_head, + bias_stride_q, + bias_stride_k); + } + + cudaError_t cuda_err = cudaGetLastError(); + if (cuda_err != cudaSuccess) { + ET_LOG( + Error, + "sdpa_efficient_attention: Kernel launch failed: %s", + cudaGetErrorString(cuda_err)); + aoti_torch_delete_tensor_object(output); + return nullptr; + } + + // Synchronize to check for kernel execution errors + cuda_err = cudaStreamSynchronize(stream); + if (cuda_err != cudaSuccess) { 
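+ // A synchronize failure means the kernel itself faulted asynchronously; free the freshly allocated output before reporting the error so callers never receive a tensor backed by a failed launch.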
+ ET_LOG( + Error, + "sdpa_efficient_attention: Kernel execution failed: %s", + cudaGetErrorString(cuda_err)); + aoti_torch_delete_tensor_object(output); + return nullptr; + } + + printf("efficient_attention_kernel completed successfully\n"); + fflush(stdout); + + return output; +} + +/** + * Memory-Efficient Attention entry point with dtype dispatch + */ +Tensor* sdpa_efficient_attention( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_bias, + bool is_causal, + double scale_factor, + cudaStream_t stream) { + auto dtype = query->dtype(); + + if (dtype == executorch::aten::ScalarType::Float) { + return sdpa_efficient_attention_impl<float>( + query, + key, + value, + attn_bias, + is_causal, + static_cast<float>(scale_factor), + stream); + } else if (dtype == executorch::aten::ScalarType::Half) { + return sdpa_efficient_attention_impl<__half>( + query, + key, + value, + attn_bias, + is_causal, + static_cast<float>(scale_factor), + stream); + } else if (dtype == executorch::aten::ScalarType::BFloat16) { + return sdpa_efficient_attention_impl<__nv_bfloat16>( + query, + key, + value, + attn_bias, + is_causal, + static_cast<float>(scale_factor), + stream); + } else { + ET_LOG(Error, "sdpa_efficient_attention: Unsupported dtype"); + return nullptr; + } +} + +// ============================================================================ +// Math Fallback Implementation +// ============================================================================ + +/** + * Math fallback implementation for SDPA + * + * This implementation uses cuBLAS for matrix multiplications and custom + * kernels for softmax. It provides maximum compatibility across all CUDA + * devices but may not be as optimized as Flash Attention or Memory Efficient + * Attention. + * + * Algorithm: + * 1. Compute attention scores: S = (Q @ K^T) + * 2. Apply scaling and compute softmax with optional causal mask + * 3.
Compute output: O = attention_weights @ V + */ +template +Tensor* sdpa_math_fallback_impl( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + bool is_causal, + float scale_factor, + cudaStream_t stream) { + printf("Inside sdpa_math_fallback_impl\n"); + fflush(stdout); + + // Get tensor dimensions + printf("Getting tensor dimensions\n"); + fflush(stdout); + const int64_t batch = query->size(0); + const int64_t num_heads = query->size(1); + const int64_t seq_len_q = query->size(2); + const int64_t head_dim = query->size(3); + const int64_t seq_len_k = key->size(2); + const int64_t head_dim_v = value->size(3); + + printf( + "Dimensions: batch=%ld, num_heads=%ld, seq_len_q=%ld, head_dim=%ld\n", + batch, + num_heads, + seq_len_q, + head_dim); + fflush(stdout); + + // Get cuBLAS handle + printf("About to get cuBLAS handle\n"); + fflush(stdout); + cublasHandle_t handle = get_cublas_handle(stream); + printf("Got cuBLAS handle: %p\n", (void*)handle); + fflush(stdout); + + // Step 1: Allocate temporary buffer for attention scores + // Shape: [batch, num_heads, seq_len_q, seq_len_k] + const int64_t scores_size = batch * num_heads * seq_len_q * seq_len_k; + scalar_t* scores_ptr = nullptr; + cudaMalloc(&scores_ptr, scores_size * sizeof(scalar_t)); + if (scores_ptr == nullptr) { + ET_LOG(Error, "sdpa_math_fallback: Failed to allocate scores buffer"); + return nullptr; + } + + // Step 2: Compute Q @ K^T using cuBLAS + // Q: [batch * num_heads, seq_len_q, head_dim] + // K^T: [batch * num_heads, head_dim, seq_len_k] + // Output: [batch * num_heads, seq_len_q, seq_len_k] + + const int m = seq_len_q; + const int n = seq_len_k; + const int k = head_dim; + const int batch_count = batch * num_heads; + + const scalar_t alpha = static_cast(1.0f); + const scalar_t beta = static_cast(0.0f); + + const scalar_t* q_ptr = reinterpret_cast(query->data_ptr()); + const scalar_t* k_ptr = reinterpret_cast(key->data_ptr()); + + // Strides for batched GEMM + const int64_t stride_q = seq_len_q * head_dim; + const int64_t stride_k = seq_len_k * head_dim; + const int64_t stride_scores = seq_len_q * seq_len_k; + + // Q @ K^T + cublasStatus_t status = batched_gemm( + handle, + CUBLAS_OP_T, // Transpose K + CUBLAS_OP_N, // No transpose Q + n, // seq_len_k + m, // seq_len_q + k, // head_dim + &alpha, + k_ptr, + k, // K matrix + stride_k, + q_ptr, + k, // Q matrix + stride_q, + &beta, + scores_ptr, + n, // Output scores + stride_scores, + batch_count); + + if (status != CUBLAS_STATUS_SUCCESS) { + ET_LOG(Error, "sdpa_math_fallback: cuBLAS GEMM failed for Q @ K^T"); + cudaFree(scores_ptr); + return nullptr; + } + + // Step 3: Apply softmax with scaling, optional causal mask, and attention + // bias + const int threads_per_block = 256; + const int64_t total_rows = batch * num_heads * seq_len_q; + const int num_blocks = + (total_rows + threads_per_block - 1) / threads_per_block; + + // Get attn_bias pointer if provided + const scalar_t* bias_ptr = attn_mask + ? 
reinterpret_cast(attn_mask->data_ptr()) + : nullptr; + + printf("About to launch softmax kernel with attn_bias=%p\n", (void*)bias_ptr); + fflush(stdout); + + softmax_with_mask_kernel + <<>>( + scores_ptr, + scores_ptr, // in-place + bias_ptr, // attention bias (additive) + batch, + num_heads, + seq_len_q, + seq_len_k, + is_causal, + scale_factor); + + cudaError_t cuda_err = cudaGetLastError(); + if (cuda_err != cudaSuccess) { + printf( + "sdpa_math_fallback: Softmax kernel launch failed: %s", + cudaGetErrorString(cuda_err)); + fflush(stdout); + ET_LOG( + Error, + "sdpa_math_fallback: Softmax kernel launch failed: %s", + cudaGetErrorString(cuda_err)); + cudaFree(scores_ptr); + return nullptr; + } + + // Step 4: Allocate output tensor [batch, num_heads, seq_len_q, head_dim_v] + Tensor* output = nullptr; + std::array output_shape = { + batch, num_heads, seq_len_q, head_dim_v}; + std::array output_stride = { + num_heads * seq_len_q * head_dim_v, + seq_len_q * head_dim_v, + head_dim_v, + 1}; + + auto dtype_int = static_cast(query->dtype()); + aoti_torch_empty_strided( + 4, + output_shape.data(), + output_stride.data(), + dtype_int, + static_cast(SupportedDevices::CUDA), + 0, + &output); + + if (output == nullptr) { + printf("sdpa_math_fallback: Failed to allocate output tensor"); + fflush(stdout); + ET_LOG(Error, "sdpa_math_fallback: Failed to allocate output tensor"); + cudaFree(scores_ptr); + return nullptr; + } + + // Step 5: Compute attention_weights @ V + // attention_weights: [batch * num_heads, seq_len_q, seq_len_k] + // V: [batch * num_heads, seq_len_k, head_dim_v] + // Output: [batch * num_heads, seq_len_q, head_dim_v] + + const int m_v = seq_len_q; + const int n_v = head_dim_v; + const int k_v = seq_len_k; + + const scalar_t* v_ptr = reinterpret_cast(value->data_ptr()); + scalar_t* out_ptr = reinterpret_cast(output->data_ptr()); + + const int64_t stride_v = seq_len_k * head_dim_v; + const int64_t stride_out = seq_len_q * head_dim_v; + + status = batched_gemm( + handle, + CUBLAS_OP_N, // No transpose V + CUBLAS_OP_N, // No transpose attention_weights + n_v, // head_dim_v + m_v, // seq_len_q + k_v, // seq_len_k + &alpha, + v_ptr, + n_v, // V matrix + stride_v, + scores_ptr, + k_v, // attention_weights + stride_scores, + &beta, + out_ptr, + n_v, // Output + stride_out, + batch_count); + + // Cleanup temporary buffers + cudaFree(scores_ptr); + + if (status != CUBLAS_STATUS_SUCCESS) { + printf("sdpa_math_fallback: cuBLAS GEMM failed for attention_weights @ V"); + fflush(stdout); + ET_LOG( + Error, + "sdpa_math_fallback: cuBLAS GEMM failed for attention_weights @ V"); + aoti_torch_delete_tensor_object(output); + return nullptr; + } + + return output; +} + +Tensor* sdpa_math_fallback( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + bool is_causal, + double scale_factor, + cudaStream_t stream) { + printf("Inside sdpa_math_fallback\n"); + fflush(stdout); + + // Dispatch based on dtype + auto dtype = query->dtype(); + printf("Query dtype: %d\n", static_cast(dtype)); + fflush(stdout); + + if (dtype == executorch::aten::ScalarType::Float) { + printf("Calling sdpa_math_fallback_impl\n"); + fflush(stdout); + return sdpa_math_fallback_impl( + query, + key, + value, + attn_mask, + is_causal, + static_cast(scale_factor), + stream); + } else if (dtype == executorch::aten::ScalarType::Half) { + printf("Calling sdpa_math_fallback_impl\n"); + fflush(stdout); + return sdpa_math_fallback_impl<__half>( + query, + key, + value, + attn_mask, + is_causal, + 
static_cast(scale_factor), + stream); + } else if (dtype == executorch::aten::ScalarType::BFloat16) { + printf("Calling sdpa_math_fallback_impl\n"); + fflush(stdout); + return sdpa_math_fallback_impl<__nv_bfloat16>( + query, + key, + value, + attn_mask, + is_causal, + static_cast(scale_factor), + stream); + } else { + ET_LOG(Error, "sdpa_math_fallback: Unsupported dtype"); + return nullptr; + } +} + +/** + * Main entry point for SDPA computation + */ +Tensor* scaled_dot_product_attention_cuda( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + double dropout_p, + bool is_causal, + const double* scale, + bool enable_gqa, + cudaStream_t stream) { + // Select backend + SDPBackend backend = + select_sdp_backend(query, key, value, attn_mask, dropout_p, is_causal); + + if (backend == SDPBackend::Error) { + ET_LOG( + Error, "scaled_dot_product_attention_cuda: No valid backend selected"); + return nullptr; + } + + printf("selected backend: %d\n", static_cast(backend)); + fflush(stdout); + + // Calculate scale factor + printf("About to calculate scale factor\n"); + fflush(stdout); + double scale_factor = calculate_scale(query, scale); + printf("Calculated scale factor: %f\n", scale_factor); + fflush(stdout); + + printf("enable_gqa: %d\n", enable_gqa); + fflush(stdout); + + // Handle GQA if needed + printf("About to check GQA configuration\n"); + fflush(stdout); + + if (enable_gqa && is_gqa_configuration(query, key, value)) { + printf("GQA configuration detected\n"); + fflush(stdout); + + if (!validate_gqa(query, key, value)) { + ET_LOG( + Error, + "scaled_dot_product_attention_cuda: Invalid GQA configuration"); + return nullptr; + } + ET_LOG( + Error, + "scaled_dot_product_attention_cuda: GQA support not yet implemented. 
" + "Need to repeat K/V heads to match Q heads."); + return nullptr; + } + + printf("Passed GQA check\n"); + fflush(stdout); + + printf( + "About to enter switch statement, backend = %d\n", + static_cast(backend)); + fflush(stdout); + + // Dispatch to appropriate backend + switch (backend) { + case SDPBackend::Math: + printf("In Math case, about to call sdpa_math_fallback\n"); + fflush(stdout); + return sdpa_math_fallback( + query, key, value, attn_mask, is_causal, scale_factor, stream); + + case SDPBackend::FlashAttention: + printf("In FlashAttention case\n"); + fflush(stdout); + return sdpa_flash_attention( + query, key, value, attn_mask, is_causal, scale_factor, stream); + + case SDPBackend::MemoryEfficientAttention: + printf("Memory Efficient Attention backend\n"); + fflush(stdout); + return sdpa_efficient_attention( + query, key, value, attn_mask, is_causal, scale_factor, stream); + + case SDPBackend::CuDNN: + printf("cuDNN backend not yet implemented\n"); + fflush(stdout); + return nullptr; + + default: + printf("Unknown SDPA backend\n"); + fflush(stdout); + return nullptr; + } +} + +// ============================================================================ +// C API Implementation +// ============================================================================ + +#ifdef __cplusplus +extern "C" { +#endif + +AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention( + Tensor* query, + Tensor* key, + Tensor* value, + double dropout_p, + int32_t is_causal, + int32_t return_debug_mask, + double* scale, + Tensor** ret0, // output + Tensor** ret1, // logsumexp (nullptr for inference) + Tensor** ret2, // cum_seq_q (nullptr for inference) + Tensor** ret3, // cum_seq_k (nullptr for inference) + int64_t* max_seqlen_q, + int64_t* max_seqlen_k, + Tensor** ret4, // philox_seed (nullptr for inference) + Tensor** ret5, // philox_offset (nullptr for inference) + Tensor** ret6) { // debug_attn_mask (nullptr for inference) + // Input validation + if (!query || !key || !value || !ret0) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Null pointer input"); + return Error::InvalidArgument; + } + + // Currently only support dropout_p = 0.0 for inference + if (dropout_p != 0.0) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: dropout_p != 0.0 is not supported"); + return Error::InvalidArgument; + } + + // Check tensor dimensions + if (query->dim() != 4 || key->dim() != 4 || value->dim() != 4) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Query, Key, Value must be 4D tensors"); + return Error::InvalidArgument; + } + + // Check that Q, K, V have the same dtype + if (query->dtype() != key->dtype() || query->dtype() != value->dtype()) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Query, Key, Value must have the same dtype"); + return Error::InvalidArgument; + } + + // Check dtype support + if (!is_supported_dtype(query) || !is_supported_dtype(key) || + !is_supported_dtype(value)) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Unsupported dtype, only Float32/Float16/BFloat16 supported"); + return Error::InvalidArgument; + } + + // Check tensor shapes + const int64_t batch = query->size(0); + const int64_t num_heads = query->size(1); + const int64_t seq_len_q = query->size(2); + const int64_t head_dim_q = query->size(3); + + const int64_t num_heads_kv = key->size(1); + const int64_t seq_len_k = key->size(2); + const int64_t head_dim_k = key->size(3); + 
+ const int64_t seq_len_v = value->size(2); + const int64_t head_dim_v = value->size(3); + + // Validate shapes + if (key->size(0) != batch || value->size(0) != batch) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Batch size mismatch"); + return Error::InvalidArgument; + } + + if (seq_len_k != seq_len_v) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Key and Value sequence length mismatch"); + return Error::InvalidArgument; + } + + if (head_dim_q != head_dim_k) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Query and Key head dimension mismatch"); + return Error::InvalidArgument; + } + + if (value->size(1) != num_heads_kv) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Key and Value num_heads mismatch"); + return Error::InvalidArgument; + } + + // Determine if GQA is being used + bool enable_gqa = (num_heads != num_heads_kv); + + // GQA validation and check + if (enable_gqa) { + if (num_heads % num_heads_kv != 0) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: For GQA, num_heads must be divisible by num_heads_kv"); + return Error::InvalidArgument; + } + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: GQA support not yet implemented"); + return Error::InvalidArgument; + } + + // Check if flash attention can be used + if (!supports_flash_attention()) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Flash Attention not supported on this GPU"); + return Error::InvalidArgument; + } + + if (!can_use_flash_attention(query, key, value, nullptr, is_causal != 0)) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Input conditions not suitable for Flash Attention"); + return Error::InvalidArgument; + } + + // Calculate scale factor + double scale_factor = calculate_scale(query, scale); + + // Get CUDA stream + auto stream_result = getCurrentCUDAStream(0); + if (!stream_result.ok()) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Failed to get CUDA stream"); + return Error::Internal; + } + cudaStream_t stream = stream_result.get(); + + // Call flash attention directly + Tensor* output = sdpa_flash_attention( + query, + key, + value, + nullptr, // attn_mask - Flash Attention doesn't support it + is_causal != 0, + scale_factor, + stream); + + if (output == nullptr) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_flash_attention: Flash Attention computation failed"); + return Error::Internal; + } + + // Set the main output + *ret0 = output; + + // Set all training-related outputs to nullptr (for inference) + if (ret1) + *ret1 = nullptr; // logsumexp + if (ret2) + *ret2 = nullptr; // cum_seq_q + if (ret3) + *ret3 = nullptr; // cum_seq_k + if (ret4) + *ret4 = nullptr; // philox_seed + if (ret5) + *ret5 = nullptr; // philox_offset + if (ret6) + *ret6 = nullptr; // debug_attn_mask + + return Error::Ok; +} + +AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention( + Tensor* query, + Tensor* key, + Tensor* value, + Tensor** attn_bias, // Optional attention bias (can be nullptr) + int32_t compute_log_sumexp, + double dropout_p, + int32_t is_causal, + double* scale, + Tensor** ret0, // output + Tensor** ret1, // logsumexp (nullptr for inference) + Tensor** ret2, // philox_seed (nullptr for inference) + Tensor** ret3) { // philox_offset (nullptr for inference) + + // Input validation + if (!query || !key || !value || !ret0) { + 
ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Null pointer input"); + return Error::InvalidArgument; + } + + // Currently only support dropout_p = 0.0 for inference + if (dropout_p != 0.0) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: dropout_p != 0.0 is not supported"); + return Error::InvalidArgument; + } + + // Check tensor dimensions + if (query->dim() != 4 || key->dim() != 4 || value->dim() != 4) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Query, Key, Value must be 4D tensors"); + return Error::InvalidArgument; + } + + // Check that Q, K, V have the same dtype + if (query->dtype() != key->dtype() || query->dtype() != value->dtype()) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Query, Key, Value must have the same dtype"); + return Error::InvalidArgument; + } + + // Check dtype support + if (!is_supported_dtype(query) || !is_supported_dtype(key) || + !is_supported_dtype(value)) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Unsupported dtype, only Float32/Float16/BFloat16 supported"); + return Error::InvalidArgument; + } + + // Check tensor shapes + const int64_t batch = query->size(0); + const int64_t num_heads = query->size(1); + const int64_t seq_len_q = query->size(2); + const int64_t head_dim_q = query->size(3); + + const int64_t num_heads_kv = key->size(1); + const int64_t seq_len_k = key->size(2); + const int64_t head_dim_k = key->size(3); + + const int64_t seq_len_v = value->size(2); + const int64_t head_dim_v = value->size(3); + + // Validate shapes + if (key->size(0) != batch || value->size(0) != batch) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Batch size mismatch"); + return Error::InvalidArgument; + } + + if (seq_len_k != seq_len_v) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Key and Value sequence length mismatch"); + return Error::InvalidArgument; + } + + if (head_dim_q != head_dim_k) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Query and Key head dimension mismatch"); + return Error::InvalidArgument; + } + + if (value->size(1) != num_heads_kv) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Key and Value num_heads mismatch"); + return Error::InvalidArgument; + } + + // Determine if GQA is being used + bool enable_gqa = (num_heads != num_heads_kv); + + // GQA validation and check + if (enable_gqa) { + if (num_heads % num_heads_kv != 0) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: For GQA, num_heads must be divisible by num_heads_kv"); + return Error::InvalidArgument; + } + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: GQA support not yet implemented"); + return Error::InvalidArgument; + } + + // Extract attn_bias tensor if provided + Tensor* attn_bias_tensor = (attn_bias && *attn_bias) ? 
*attn_bias : nullptr; + + // Check if efficient attention can be used + if (!can_use_efficient_attention(query, key, value, attn_bias_tensor, is_causal != 0)) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Input conditions not suitable for Efficient Attention"); + return Error::InvalidArgument; + } + + // Calculate scale factor + double scale_factor = calculate_scale(query, scale); + + // Get CUDA stream + auto stream_result = getCurrentCUDAStream(0); + if (!stream_result.ok()) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Failed to get CUDA stream"); + return Error::Internal; + } + cudaStream_t stream = stream_result.get(); + + // Call efficient attention directly + Tensor* output = sdpa_efficient_attention( + query, + key, + value, + attn_bias_tensor, // Pass attn_bias (can be nullptr) + is_causal != 0, + scale_factor, + stream); + + if (output == nullptr) { + ET_LOG( + Error, + "aoti_torch_cuda__scaled_dot_product_efficient_attention: Efficient Attention computation failed"); + return Error::Internal; + } + + // Set the main output + *ret0 = output; + + // Set all training-related outputs to nullptr (for inference) + if (ret1) + *ret1 = nullptr; // logsumexp + if (ret2) + *ret2 = nullptr; // philox_seed + if (ret3) + *ret3 = nullptr; // philox_offset + + return Error::Ok; +} + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/sdpa.cuh b/backends/cuda/runtime/shims/sdpa.cuh new file mode 100644 index 00000000000..7a0a2cf340e --- /dev/null +++ b/backends/cuda/runtime/shims/sdpa.cuh @@ -0,0 +1,395 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// This file implements scaled_dot_product_attention for ExecuTorch. +// +// IMPLEMENTATION NOTES: +// --------------------- +// This is NOT a direct port from PyTorch. Instead, we implemented +// a custom Math Fallback using cuBLAS and custom CUDA kernels. +// +// PyTorch reference implementations (for architecture reference only): +// - CPU/General: aten/src/ATen/native/transformers/attention.cpp +// - CUDA: aten/src/ATen/native/transformers/cuda/attention.cu +// +// Key differences from PyTorch: +// - PyTorch uses high-level ATen ops (at::matmul, at::_safe_softmax) +// - We use direct cuBLAS calls and custom softmax kernels +// - Optimized for inference (no dropout, no backward pass) +// - Simplified memory management +// - No ATen/c10 dependencies +// +// PORTING NOTES: +// -------------- +// 1. KERNEL CODE: Adapted from PyTorch attention kernels +// - Math fallback implementation for maximum compatibility +// - Supports Float32, Float16, and BFloat16 dtypes +// - Standard attention computation: softmax(Q @ K^T / scale) @ V +// +// 2. API ADAPTATIONS: +// - Replaced at::Tensor with executorch::backends::aoti::Tensor +// - Output returned via pointer-to-pointer instead of by-value +// - Simplified interface for inference (dropout=0.0 only) +// +// 3. REMOVED FEATURES: +// - Flash Attention backend (requires external library) +// - Memory Efficient Attention backend (requires external library) +// - cuDNN backend (requires cuDNN library) +// - Dropout support (training-only feature) +// - Nested tensor support (complex layout) +// - Backward pass (training-only feature) +// +// 4. 
INFRASTRUCTURE CHANGES: +// - Removed c10::cuda::CUDAGuard: Device management handled by AOTI backend +// - Removed at::cuda::getCurrentCUDAStream(): Stream passed explicitly +// - Simplified error handling using ExecutorTorch Error codes + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::Tensor; +using executorch::runtime::Error; + +// ============================================================================ +// Utility Functions for SDPA +// ============================================================================ + +// Calculate the scaling factor for attention scores +inline double calculate_scale(const Tensor* query, const double* scale) { + if (scale != nullptr) { + return *scale; + } + // Default: 1 / sqrt(head_dim) + // Query shape: [batch, num_heads, seq_len_q, head_dim] + // head_dim is at index 3 (0-indexed) + const int64_t head_dim = query->size(3); + return 1.0 / std::sqrt(static_cast(head_dim)); +} + +// Check if tensor dtype is supported for SDPA +inline bool is_supported_dtype(const Tensor* tensor) { + auto dtype = tensor->dtype(); + return dtype == executorch::aten::ScalarType::Float || + dtype == executorch::aten::ScalarType::Half || + dtype == executorch::aten::ScalarType::BFloat16; +} + +// ============================================================================ +// Math Fallback Implementation +// ============================================================================ + +// This is the basic, portable implementation that works on all CUDA devices. +// It computes attention using explicit matrix multiplications and softmax: +// 1. Compute scores: S = Q @ K^T * scale +// 2. Apply mask if provided +// 3. Compute attention weights: A = softmax(S) +// 4. Compute output: O = A @ V + +/** + * Math fallback kernel for scaled dot product attention + * + * This is a basic implementation that performs: + * output = softmax(query @ key^T / scale) @ value + * + * Supports: + * - Batch processing + * - Multiple attention heads + * - Optional causal masking + * - Optional explicit attention mask + * - Float32, Float16, BFloat16 dtypes + * + * Note: This implementation is for reference and maximum compatibility. + * For production use, consider using Flash Attention or other optimized + * backends. + */ +Tensor* sdpa_math_fallback( + const Tensor* query, // [batch, num_heads, seq_len_q, head_dim] + const Tensor* key, // [batch, num_heads_kv, seq_len_k, head_dim] + const Tensor* value, // [batch, num_heads_kv, seq_len_k, head_dim_v] + const Tensor* attn_mask, // Optional: [batch, num_heads, seq_len_q, + // seq_len_k] or broadcastable + bool is_causal, // Apply causal masking + double scale_factor, // Scaling factor for attention scores + cudaStream_t stream); // CUDA stream for execution + +// ============================================================================ +// Memory-Efficient Attention Implementation +// ============================================================================ + +/** + * Memory-Efficient Attention kernel for scaled dot product attention + * + * This implementation uses online softmax to compute attention efficiently + * without materializing the full attention matrix. Supports attention bias. 
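// --- Illustrative sketch (not part of the patch) ----------------------------
// A minimal host-side reference of the math fallback described above, for one
// (batch, head) slice, assuming contiguous row-major float buffers. The
// function and parameter names (reference_sdpa_single_head, Sq, Sk, D, Dv) are
// hypothetical and exist only for this sketch; they do not appear in sdpa.cu.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> reference_sdpa_single_head(
    const std::vector<float>& Q, // [Sq, D]
    const std::vector<float>& K, // [Sk, D]
    const std::vector<float>& V, // [Sk, Dv]
    int64_t Sq, int64_t Sk, int64_t D, int64_t Dv,
    bool is_causal, double scale) {
  std::vector<float> out(Sq * Dv, 0.0f);
  std::vector<float> scores(Sk);
  for (int64_t i = 0; i < Sq; ++i) {
    // 1. scores[j] = (Q[i] . K[j]) * scale; causally masked positions -> -inf
    float row_max = -INFINITY;
    for (int64_t j = 0; j < Sk; ++j) {
      if (is_causal && j > i) { scores[j] = -INFINITY; continue; }
      float s = 0.0f;
      for (int64_t d = 0; d < D; ++d) s += Q[i * D + d] * K[j * D + d];
      scores[j] = static_cast<float>(s * scale);
      row_max = std::max(row_max, scores[j]);
    }
    // 2./3. softmax over the row (max-subtracted for numerical stability)
    float denom = 0.0f;
    for (int64_t j = 0; j < Sk; ++j) {
      scores[j] = (scores[j] == -INFINITY) ? 0.0f : std::exp(scores[j] - row_max);
      denom += scores[j];
    }
    // 4. out[i] = sum_j softmax(scores)[j] * V[j]
    for (int64_t j = 0; j < Sk; ++j) {
      const float w = scores[j] / denom;
      for (int64_t d = 0; d < Dv; ++d) out[i * Dv + d] += w * V[j * Dv + d];
    }
  }
  return out;
}
// -----------------------------------------------------------------------------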
+ * + * Supports: + * - Batch processing + * - Multiple attention heads + * - Optional causal masking + * - Optional attention bias (additive) + * - Float32, Float16, BFloat16 dtypes + */ +Tensor* sdpa_efficient_attention( + const Tensor* query, // [batch, num_heads, seq_len_q, head_dim] + const Tensor* key, // [batch, num_heads_kv, seq_len_k, head_dim] + const Tensor* value, // [batch, num_heads_kv, seq_len_k, head_dim_v] + const Tensor* attn_bias, // Optional: [batch, num_heads, seq_len_q, + // seq_len_k] or broadcastable + bool is_causal, // Apply causal masking + double scale_factor, // Scaling factor for attention scores + cudaStream_t stream); // CUDA stream for execution + +// ============================================================================ +// Backend Selection +// ============================================================================ + +enum class SDPBackend { + Error = -1, + Math = 0, + FlashAttention = 1, + MemoryEfficientAttention = 2, + CuDNN = 3 +}; + +/** + * Check if Flash Attention is supported on the current GPU + */ +inline bool supports_flash_attention() { + int device; + cudaGetDevice(&device); + + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device); + + // Flash Attention typically requires compute capability >= 7.0 (Volta+) + // For better performance, prefer Ampere+ (8.0+) + // We'll use a conservative threshold of 7.0 for compatibility + return props.major >= 7; +} + +/** + * Check if inputs are suitable for Flash Attention + */ +inline bool can_use_flash_attention( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + bool is_causal) { + // Flash Attention doesn't support explicit attention masks (yet) + // It only supports causal masking + if (attn_mask != nullptr) { + return false; + } + + // Check head dimensions are reasonable for Flash Attention + const int64_t head_dim = query->size(3); + const int64_t head_dim_v = value->size(3); + + // Flash Attention works best with head_dim <= 128 + // Our simple implementation requires head_dim_v <= 64 (output accumulator + // size) which can be improved in the future + if (head_dim > 128 || head_dim_v > 64) { + return false; + } + + // Flash Attention works for all sequence lengths + // While it's most beneficial for longer sequences, it still works correctly + // for short sequences used in testing + return true; +} + +/** + * Check if inputs are suitable for Memory Efficient Attention + */ +inline bool can_use_efficient_attention( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + bool is_causal) { + // Check head dimensions + const int64_t head_dim_v = value->size(3); + + // Our efficient attention implementation supports head_dim_v <= 128 + if (head_dim_v > 128) { + return false; + } + + // Efficient attention supports attention bias/mask + return true; +} + +/** + * Select the best available backend for SDPA based on input parameters + * + * Selection priority: + * 1. Memory Efficient Attention - For inputs with attention bias/mask + * 2. Flash Attention - For supported hardware and suitable inputs (no mask) + * 3. 
Math fallback - For maximum compatibility + */ +inline SDPBackend select_sdp_backend( + const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* attn_mask, + double dropout_p, + bool is_causal) { + // Check for unsupported features + if (dropout_p > 0.0) { + ET_LOG(Error, "SDPA: Dropout not supported in inference mode"); + return SDPBackend::Error; + } + + // Check tensor dimensions + if (query->dim() != 4 || key->dim() != 4 || value->dim() != 4) { + ET_LOG(Error, "SDPA: All inputs must be 4D tensors"); + return SDPBackend::Error; + } + + // Check dtype support + if (!is_supported_dtype(query) || !is_supported_dtype(key) || + !is_supported_dtype(value)) { + ET_LOG( + Error, + "SDPA: Unsupported dtype, only Float32/Float16/BFloat16 supported"); + return SDPBackend::Error; + } + + // Check dtype consistency + if (query->dtype() != key->dtype() || query->dtype() != value->dtype()) { + ET_LOG(Error, "SDPA: Query, Key, Value must have the same dtype"); + return SDPBackend::Error; + } + + // If attention mask/bias is provided, use Memory Efficient Attention + if (attn_mask != nullptr && + can_use_efficient_attention(query, key, value, attn_mask, is_causal)) { + printf("Selected backend: MemoryEfficientAttention (has attn_mask)\n"); + return SDPBackend::MemoryEfficientAttention; + } + + // Try Flash Attention if hardware supports it and inputs are suitable + if (supports_flash_attention() && + can_use_flash_attention(query, key, value, attn_mask, is_causal)) { + printf("Selected backend: FlashAttention\n"); + return SDPBackend::FlashAttention; + } + + // Fall back to math implementation + printf("Selected backend: Math (fallback)\n"); + return SDPBackend::Math; +} + +// ============================================================================ +// Helper Functions for Causal Mask +// ============================================================================ + +/** + * Check if we need to apply causal masking + */ +inline bool needs_causal_mask(bool is_causal, const Tensor* attn_mask) { + if (!is_causal) { + return false; + } + if (attn_mask != nullptr) { + ET_LOG( + Error, "SDPA: Cannot use both is_causal=true and explicit attn_mask"); + return false; + } + return true; +} + +// ============================================================================ +// Grouped Query Attention (GQA) Support +// ============================================================================ + +/** + * Check if inputs require GQA handling + * + * GQA allows num_heads_q != num_heads_kv, where num_heads_q must be + * divisible by num_heads_kv. Key and Value heads are repeated to match + * Query heads. 
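// --- Illustrative sketch (not part of the patch) ----------------------------
// The "online softmax" idea behind sdpa_efficient_attention declared above,
// shown for a single query row on the host: keys/values are consumed one at a
// time while a running max (m), running denominator (l) and a rescaled
// accumulator (acc) are maintained, so the full [Sq, Sk] score matrix is never
// materialized. All names here are hypothetical; the actual CUDA kernel in
// sdpa.cu is organized differently.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void online_softmax_row(
    const float* q,  // [D]
    const float* K,  // [Sk, D]
    const float* V,  // [Sk, Dv]
    float* out,      // [Dv]
    int64_t Sk, int64_t D, int64_t Dv, double scale) {
  float m = -INFINITY;              // running max of scores seen so far
  float l = 0.0f;                   // running softmax denominator
  std::vector<float> acc(Dv, 0.0f); // running (unnormalized) output

  for (int64_t j = 0; j < Sk; ++j) {
    float s = 0.0f;
    for (int64_t d = 0; d < D; ++d) s += q[d] * K[j * D + d];
    s = static_cast<float>(s * scale);

    const float m_new = std::max(m, s);
    const float correction = std::exp(m - m_new); // rescales the old state
    const float p = std::exp(s - m_new);          // weight of the current key

    l = l * correction + p;
    for (int64_t d = 0; d < Dv; ++d)
      acc[d] = acc[d] * correction + p * V[j * Dv + d];
    m = m_new;
  }
  for (int64_t d = 0; d < Dv; ++d) out[d] = acc[d] / l;
}
// -----------------------------------------------------------------------------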
+ */
+inline bool is_gqa_configuration(
+    const Tensor* query,
+    const Tensor* key,
+    const Tensor* value) {
+  const int64_t num_heads_q = query->size(1);
+  const int64_t num_heads_kv = key->size(1);
+
+  return num_heads_q != num_heads_kv;
+}
+
+/**
+ * Validate GQA configuration
+ */
+inline bool
+validate_gqa(const Tensor* query, const Tensor* key, const Tensor* value) {
+  const int64_t num_heads_q = query->size(1);
+  const int64_t num_heads_kv = key->size(1);
+  const int64_t num_heads_v = value->size(1);
+
+  // Key and Value must have same num_heads
+  if (num_heads_kv != num_heads_v) {
+    ET_LOG(Error, "SDPA GQA: Key and Value must have same num_heads");
+    return false;
+  }
+
+  // Query heads must be divisible by Key/Value heads
+  if (num_heads_q % num_heads_kv != 0) {
+    ET_LOG(
+        Error,
+        "SDPA GQA: Query num_heads must be divisible by Key/Value num_heads");
+    return false;
+  }
+
+  return true;
+}
+
+// ============================================================================
+// Main SDPA Entry Point
+// ============================================================================
+
+/**
+ * Compute scaled dot product attention
+ *
+ * This is the main entry point that selects the appropriate backend
+ * and dispatches to the corresponding implementation.
+ *
+ * The Math fallback, Flash Attention, and Memory-Efficient Attention
+ * backends are implemented. Future versions may add:
+ * - cuDNN backend
+ */
+Tensor* scaled_dot_product_attention_cuda(
+    const Tensor* query,
+    const Tensor* key,
+    const Tensor* value,
+    const Tensor* attn_mask,
+    double dropout_p,
+    bool is_causal,
+    const double* scale,
+    bool enable_gqa,
+    cudaStream_t stream);
+
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/sdpa.h b/backends/cuda/runtime/shims/sdpa.h
new file mode 100644
index 00000000000..1c037a1ad94
--- /dev/null
+++ b/backends/cuda/runtime/shims/sdpa.h
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace executorch::backends::cuda {
+
+using executorch::backends::aoti::AOTITorchError;
+using executorch::backends::aoti::Tensor;
+
+#ifdef __cplusplus
extern "C" {
+#endif
+
+/**
+ * Performs scaled dot-product attention on CUDA.
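// --- Illustrative sketch (not part of the patch) ----------------------------
// One possible dispatch flow for the entry point declared in sdpa.cuh above,
// assuming the helpers from that header (calculate_scale, select_sdp_backend,
// sdpa_efficient_attention, sdpa_math_fallback) are in scope. The real
// definition lives in sdpa.cu and may differ; GQA handling and the dedicated
// Flash kernel are omitted, so the Flash case reuses the math fallback in this
// sketch. The function name dispatch_sdpa_sketch is hypothetical.
inline Tensor* dispatch_sdpa_sketch(
    const Tensor* query,
    const Tensor* key,
    const Tensor* value,
    const Tensor* attn_mask,
    double dropout_p,
    bool is_causal,
    const double* scale,
    cudaStream_t stream) {
  const double scale_factor = calculate_scale(query, scale);
  switch (
      select_sdp_backend(query, key, value, attn_mask, dropout_p, is_causal)) {
    case SDPBackend::MemoryEfficientAttention:
      return sdpa_efficient_attention(
          query, key, value, attn_mask, is_causal, scale_factor, stream);
    case SDPBackend::FlashAttention:
    case SDPBackend::Math:
      return sdpa_math_fallback(
          query, key, value, attn_mask, is_causal, scale_factor, stream);
    case SDPBackend::Error:
    default:
      return nullptr; // unsupported configuration
  }
}
// -----------------------------------------------------------------------------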
+ * + * This matches PyTorch's AOTI signature for: + * torch.ops.aten._scaled_dot_product_flash_attention + * + * @param query Query tensor [batch, num_heads, seq_len_q, head_dim] + * @param key Key tensor [batch, num_heads, seq_len_k, head_dim] + * @param value Value tensor [batch, num_heads, seq_len_k, head_dim] + * @param dropout_p Dropout probability (must be 0.0 for inference) + * @param is_causal Whether to apply causal masking + * @param return_debug_mask Whether to return debug attention mask (ignored for + * inference) + * @param scale Optional scaling factor for attention scores + * @param ret0 Output: attention result [batch, num_heads, seq_len_q, head_dim] + * @param ret1 Output: logsumexp (set to nullptr for inference) + * @param ret2 Output: cumulative sequence length Q (set to nullptr for + * inference) + * @param ret3 Output: cumulative sequence length K (set to nullptr for + * inference) + * @param max_seqlen_q Maximum sequence length in Q (set to seq_len_q) + * @param max_seqlen_k Maximum sequence length in K (set to seq_len_k) + * @param ret4 Output: philox seed (set to nullptr for inference) + * @param ret5 Output: philox offset (set to nullptr for inference) + * @param ret6 Output: debug attention mask (set to nullptr for inference) + * + * @return AOTITorchError error code + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_cuda__scaled_dot_product_flash_attention( + Tensor* query, + Tensor* key, + Tensor* value, + double dropout_p, + int32_t is_causal, + int32_t return_debug_mask, + double* scale, + Tensor** ret0, // output + Tensor** ret1, // logsumexp (nullptr for inference) + Tensor** ret2, // cum_seq_q (nullptr for inference) + Tensor** ret3, // cum_seq_k (nullptr for inference) + int64_t* max_seqlen_q, + int64_t* max_seqlen_k, + Tensor** ret4, // philox_seed (nullptr for inference) + Tensor** ret5, // philox_offset (nullptr for inference) + Tensor** ret6); // debug_attn_mask (nullptr for inference) + +/** + * Performs scaled dot-product efficient attention on CUDA. 
+ * + * This matches PyTorch's AOTI signature for: + * torch.ops.aten._scaled_dot_product_efficient_attention + * + * @param query Query tensor [batch, num_heads, seq_len_q, head_dim] + * @param key Key tensor [batch, num_heads, seq_len_k, head_dim] + * @param value Value tensor [batch, num_heads, seq_len_k, head_dim] + * @param attn_bias Optional attention bias (additive mask) + * @param compute_log_sumexp Whether to compute logsumexp (ignored for + * inference) + * @param dropout_p Dropout probability (must be 0.0 for inference) + * @param is_causal Whether to apply causal masking + * @param scale Optional scaling factor for attention scores + * @param ret0 Output: attention result [batch, num_heads, seq_len_q, head_dim] + * @param ret1 Output: logsumexp (set to nullptr for inference) + * @param ret2 Output: philox seed (set to nullptr for inference) + * @param ret3 Output: philox offset (set to nullptr for inference) + * + * @return AOTITorchError error code + * + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_cuda__scaled_dot_product_efficient_attention( + Tensor* query, + Tensor* key, + Tensor* value, + Tensor** attn_bias, // Optional attention bias (can be nullptr) + int32_t compute_log_sumexp, + double dropout_p, + int32_t is_causal, + double* scale, + Tensor** ret0, // output + Tensor** ret1, // logsumexp (nullptr for inference) + Tensor** ret2, // philox_seed (nullptr for inference) + Tensor** ret3); // philox_offset (nullptr for inference) + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index b274ecf3675..0896b3b6a3b 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -34,4 +34,5 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_copy_") cuda_shim_cpp_unittest("aoti_torch_cuda_guard") cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm") + cuda_shim_cpp_unittest("aoti_torch_cuda_scaled_dot_product_attention") cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp new file mode 100644 index 00000000000..25286d7b4f8 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_scaled_dot_product_attention.cpp @@ -0,0 +1,1815 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
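// --- Illustrative sketch (not part of the patch) ----------------------------
// Calling the efficient-attention shim declared in sdpa.h above. The optional
// bias is passed as a Tensor**: nullptr when there is no bias, or the address
// of a Tensor* when there is one. For inference only ret0 is populated; the
// remaining outputs come back as nullptr. Error handling is elided and the
// wrapper name run_efficient_attention_example is hypothetical (the unit tests
// below use a similar helper).
AOTITorchError run_efficient_attention_example(
    Tensor* q,
    Tensor* k,
    Tensor* v,
    Tensor* bias, // may be nullptr
    bool causal,
    Tensor** out) {
  Tensor* logsumexp = nullptr;
  Tensor* philox_seed = nullptr;
  Tensor* philox_offset = nullptr;
  return aoti_torch_cuda__scaled_dot_product_efficient_attention(
      q,
      k,
      v,
      bias ? &bias : nullptr, // optional attention bias
      /*compute_log_sumexp=*/0,
      /*dropout_p=*/0.0,
      /*is_causal=*/causal ? 1 : 0,
      /*scale=*/nullptr, // default 1/sqrt(head_dim)
      out,
      &logsumexp,
      &philox_seed,
      &philox_offset);
}
// -----------------------------------------------------------------------------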
+ */ + +#include + +#include +#include +#include +#include +#include +#include // For can_use_flash_attention +#include +#include +#include +#include + +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::backends::aoti; +using namespace executorch::runtime; + +// Test fixture for SDPA tests +class AOTITorchSDPATest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + } + + void TearDown() override { + // Clean up after each test + cleanup_tensor_metadata(); + } + + // Helper function to create a Float32 tensor filled with a specific value + Tensor* create_float_tensor( + std::vector shape, + float fill_value = 1.0f) { + Tensor* tensor = nullptr; + + // Calculate size + int64_t total_size = 1; + for (auto dim : shape) { + total_size *= dim; + } + + // Calculate strides (row-major) + std::vector strides(shape.size()); + int64_t stride = 1; + for (int i = shape.size() - 1; i >= 0; --i) { + strides[i] = stride; + stride *= shape[i]; + } + + // Create tensor + Error error = aoti_torch_empty_strided( + shape.size(), + shape.data(), + strides.data(), + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, + &tensor); + + if (error != Error::Ok || tensor == nullptr) { + return nullptr; + } + + // Fill with value + std::vector host_data(total_size, fill_value); + cudaMemcpy( + tensor->data_ptr(), + host_data.data(), + total_size * sizeof(float), + cudaMemcpyHostToDevice); + + return tensor; + } + + // Helper function to create a BFloat16 tensor + Tensor* create_bfloat16_tensor( + std::vector shape, + float fill_value = 1.0f) { + Tensor* tensor = nullptr; + + // Calculate size + int64_t total_size = 1; + for (auto dim : shape) { + total_size *= dim; + } + + // Calculate strides (row-major) + std::vector strides(shape.size()); + int64_t stride = 1; + for (int i = shape.size() - 1; i >= 0; --i) { + strides[i] = stride; + stride *= shape[i]; + } + + // Create tensor + Error error = aoti_torch_empty_strided( + shape.size(), + shape.data(), + strides.data(), + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, + &tensor); + + if (error != Error::Ok || tensor == nullptr) { + return nullptr; + } + + // Fill with value + // Note: For simplicity, we'll fill with float and let the runtime handle conversion + // In production, you'd want to properly convert to bfloat16 + std::vector host_data(total_size, fill_value); + cudaMemcpy( + tensor->data_ptr(), + host_data.data(), + total_size * sizeof(float), + cudaMemcpyHostToDevice); + + return tensor; + } + + // Helper to check if output tensor has expected shape + bool check_output_shape( + Tensor* output, + const std::vector& expected_shape) { + if (output == nullptr) { + return false; + } + if (output->dim() != expected_shape.size()) { + return false; + } + for (size_t i = 0; i < expected_shape.size(); ++i) { + if (output->size(i) != expected_shape[i]) { + return false; + } + } + return true; + } + + // Helper to copy tensor data from GPU to CPU for verification + std::vector copy_tensor_to_host(Tensor* tensor) { + int64_t total_size = 1; + for (int i = 0; i < 
tensor->dim(); ++i) { + total_size *= tensor->size(i); + } + + std::vector host_data(total_size); + cudaMemcpy( + host_data.data(), + tensor->data_ptr(), + total_size * sizeof(float), + cudaMemcpyDeviceToHost); + + return host_data; + } + + // Helper to check if a value is approximately equal (for floating point comparison) + bool approx_equal(float a, float b, float epsilon = 1e-5f) { + return std::abs(a - b) < epsilon; + } + + // ======================================================================== + // Wrapper Functions for Simplified Testing + // ======================================================================== + + /** + * Simplified wrapper for Flash Attention testing + * Only requires the essential parameters, sets others to nullptr/defaults + */ + AOTITorchError call_flash_attention( + Tensor* query, + Tensor* key, + Tensor* value, + double dropout_p, + int32_t is_causal, + double* scale, + Tensor** output) { + // Initialize all optional outputs to nullptr + Tensor* logsumexp = nullptr; + Tensor* cum_seq_q = nullptr; + Tensor* cum_seq_k = nullptr; + int64_t max_seqlen_q = query->size(2); + int64_t max_seqlen_k = key->size(2); + Tensor* philox_seed = nullptr; + Tensor* philox_offset = nullptr; + Tensor* debug_mask = nullptr; + + return aoti_torch_cuda__scaled_dot_product_flash_attention( + query, + key, + value, + dropout_p, + is_causal, + 0, // return_debug_mask = 0 + scale, + output, + &logsumexp, + &cum_seq_q, + &cum_seq_k, + &max_seqlen_q, + &max_seqlen_k, + &philox_seed, + &philox_offset, + &debug_mask); + } + + /** + * Simplified wrapper for Efficient Attention testing + * Only requires the essential parameters, sets others to nullptr/defaults + */ + AOTITorchError call_efficient_attention( + Tensor* query, + Tensor* key, + Tensor* value, + Tensor* attn_bias, + int32_t is_causal, + double* scale, + Tensor** output) { + // Initialize all optional outputs to nullptr + Tensor* logsumexp = nullptr; + Tensor* philox_seed = nullptr; + Tensor* philox_offset = nullptr; + + return aoti_torch_cuda__scaled_dot_product_efficient_attention( + query, + key, + value, + attn_bias ? &attn_bias : nullptr, + 0, // compute_log_sumexp = 0 + 0.0, // dropout_p = 0.0 + is_causal, + scale, + output, + &logsumexp, + &philox_seed, + &philox_offset); + } +}; + +// ============================================================================ +// Basic Functionality Tests +// ============================================================================ + +// Test basic SDPA with Float32, no causal mask +TEST_F(AOTITorchSDPATest, BasicFunctionalityFloat32) { + // Create tensors: [batch=1, num_heads=2, seq_len=4, head_dim=8] + const int64_t batch = 1; + const int64_t num_heads = 2; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); + + // Create V with different values at each position so attention weight changes matter + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.0f); + std::vector value_host(batch * num_heads * seq_len * head_dim); + for (int64_t pos = 0; pos < seq_len; ++pos) { + for (int64_t d = 0; d < head_dim; ++d) { + // V values: pos 0=1.0, pos 1=2.0, pos 2=3.0, etc. 
+ value_host[pos * head_dim + d] = static_cast(pos + 1); + } + } + cudaMemcpy( + value->data_ptr(), + value_host.data(), + value_host.size() * sizeof(float), + cudaMemcpyHostToDevice); + + ASSERT_NE(query, nullptr) << "Failed to create query tensor"; + ASSERT_NE(key, nullptr) << "Failed to create key tensor"; + ASSERT_NE(value, nullptr) << "Failed to create value tensor"; + + printf("Testing SDPA Float32: [%ldx%ldx%ldx%ld]\n", batch, num_heads, seq_len, head_dim); + + // Call SDPA - Flash Attention + Tensor* output = nullptr; + Tensor* logsumexp = nullptr; + Tensor* cum_seq_q = nullptr; + Tensor* cum_seq_k = nullptr; + int64_t max_seqlen_q = seq_len; + int64_t max_seqlen_k = seq_len; + Tensor* philox_seed = nullptr; + Tensor* philox_offset = nullptr; + Tensor* debug_mask = nullptr; + + AOTITorchError error = aoti_torch_cuda__scaled_dot_product_flash_attention( + query, + key, + value, + 0.0, // no dropout + 0, // not causal + 0, // no debug mask + nullptr, // default scale + &output, + &logsumexp, + &cum_seq_q, + &cum_seq_k, + &max_seqlen_q, + &max_seqlen_k, + &philox_seed, + &philox_offset, + &debug_mask); + + // Check result + EXPECT_EQ(error, Error::Ok) << "SDPA should succeed"; + ASSERT_NE(output, nullptr) << "Output should not be null"; + + // Verify output shape: [batch, num_heads, seq_len, head_dim] + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})) + << "Output shape mismatch"; + + printf("SDPA Float32 test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test SDPA with causal masking +TEST_F(AOTITorchSDPATest, CausalMasking) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 8; + const int64_t head_dim = 16; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing SDPA with causal masking: [%ldx%ldx%ldx%ld]\n", + batch, num_heads, seq_len, head_dim); + + // Call SDPA with causal mask using wrapper + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, + 0.0, // no dropout + 1, // causal mask enabled + nullptr, // default scale + &output); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Causal masking test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test SDPA with BFloat16 +TEST_F(AOTITorchSDPATest, BFloat16Precision) { + const int64_t batch = 2; + const int64_t num_heads = 4; + const int64_t seq_len = 16; + const int64_t head_dim = 32; + + Tensor* query = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); + Tensor* value = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr) << "Failed to create BFloat16 query tensor"; + ASSERT_NE(key, nullptr) << "Failed to create BFloat16 key tensor"; + ASSERT_NE(value, nullptr) << 
"Failed to create BFloat16 value tensor"; + + printf("Testing SDPA BFloat16: [%ldx%ldx%ldx%ld]\n", + batch, num_heads, seq_len, head_dim); + + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, + 0.0, // no dropout + 0, // not causal + nullptr, // default scale + &output); + + EXPECT_EQ(error, Error::Ok) << "SDPA BFloat16 should succeed"; + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("BFloat16 precision test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test SDPA with custom scale factor +TEST_F(AOTITorchSDPATest, CustomScale) { + const int64_t batch = 1; + const int64_t num_heads = 2; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing SDPA with custom scale\n"); + + // Use custom scale instead of default 1/sqrt(head_dim) + double custom_scale = 0.25; + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, + 0.0, // no dropout + 0, // not causal + &custom_scale, // custom scale + &output); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Custom scale test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test with larger tensors (closer to real-world usage) +TEST_F(AOTITorchSDPATest, LargerTensors) { + const int64_t batch = 4; + const int64_t num_heads = 8; + const int64_t seq_len = 128; + const int64_t head_dim = 64; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing SDPA with larger tensors: [%ldx%ldx%ldx%ld]\n", + batch, num_heads, seq_len, head_dim); + + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, + 0.0, // no dropout + 1, // causal + nullptr, // default scale + &output); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Larger tensors test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// ============================================================================ +// Error Handling Tests +// ============================================================================ + +// Test dimension mismatch +TEST_F(AOTITorchSDPATest, DimensionMismatch) { + Tensor* query = create_float_tensor({1, 2, 4, 8}, 0.5f); + Tensor* key = create_float_tensor({1, 2, 6, 
8}, 0.5f); // Different seq_len + Tensor* value = create_float_tensor({1, 2, 6, 8}, 1.0f); + Tensor* output = nullptr; + + // This should succeed (Q and K can have different seq_len) + AOTITorchError error = call_flash_attention( + query, key, value, 0.0, 0, nullptr, &output); + + EXPECT_EQ(error, Error::Ok) << "Different Q and K seq_len should be allowed"; + + if (output != nullptr) { + // Output should have Q's seq_len + EXPECT_EQ(output->size(2), 4) << "Output seq_len should match Query"; + aoti_torch_delete_tensor_object(output); + } + + printf("Dimension handling test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); +} + +// Test dropout error (should fail since we don't support dropout) +TEST_F(AOTITorchSDPATest, DropoutNotSupported) { + Tensor* query = create_float_tensor({1, 1, 4, 8}, 0.5f); + Tensor* key = create_float_tensor({1, 1, 4, 8}, 0.5f); + Tensor* value = create_float_tensor({1, 1, 4, 8}, 1.0f); + Tensor* output = nullptr; + + AOTITorchError error = call_flash_attention( + query, key, value, 0.5, 0, nullptr, &output); // dropout=0.5 + + EXPECT_NE(error, Error::Ok) << "Should fail with non-zero dropout"; + + printf("Dropout rejection test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); +} + +// ============================================================================ +// Numerical Correctness Tests +// ============================================================================ + +// Test that output values are in reasonable range +TEST_F(AOTITorchSDPATest, OutputValueRangeCheck) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + // Use small values to avoid numerical overflow + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing SDPA output value range\n"); + + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, 0.0, 0, nullptr, &output); + + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + + // Copy output back to CPU for verification + std::vector output_data = copy_tensor_to_host(output); + + // Since V is all 1.0, and softmax produces weights that sum to 1, + // output should be close to 1.0 (weighted average of 1.0) + bool all_in_range = true; + for (size_t i = 0; i < output_data.size(); ++i) { + // Output should be around 1.0 with some tolerance + if (output_data[i] < 0.5f || output_data[i] > 1.5f) { + printf("Output[%zu] = %f is out of expected range [0.5, 1.5]\n", + i, output_data[i]); + all_in_range = false; + } + } + + EXPECT_TRUE(all_in_range) << "Some output values are out of reasonable range"; + + printf("Output value range check passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test with identity Q=K, verify attention weights sum to 1 +TEST_F(AOTITorchSDPATest, IdentityQKTest) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 4; + const int64_t 
head_dim = 8; + + // When Q=K, attention scores will be uniform (since all positions are equally similar) + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 2.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing SDPA with Q=K (identity attention)\n"); + + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, 0.0, 0, nullptr, &output); + + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + + // Copy output back to CPU + std::vector output_data = copy_tensor_to_host(output); + + // When Q=K and V is uniform, output should be close to V + // (since attention weights are uniform due to identical scores) + bool values_correct = true; + for (size_t i = 0; i < output_data.size(); ++i) { + // Output should be close to 2.0 (the value of V) + if (!approx_equal(output_data[i], 2.0f, 0.1f)) { + printf("Output[%zu] = %f, expected ~2.0\n", i, output_data[i]); + values_correct = false; + } + } + + EXPECT_TRUE(values_correct) << "Output values don't match expected for identity Q=K"; + + printf("Identity Q=K test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test that different scales produce different outputs +TEST_F(AOTITorchSDPATest, ScaleEffectTest) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + // Make K different at different positions so attention scores vary + std::vector key_host(batch * num_heads * seq_len * head_dim); + for (int64_t pos = 0; pos < seq_len; ++pos) { + for (int64_t d = 0; d < head_dim; ++d) { + // Different values per position: pos 0=0.1, pos 1=0.3, pos 2=0.5, pos 3=0.7 + key_host[pos * head_dim + d] = 0.1f + 0.2f * pos; + } + } + cudaMemcpy( + key->data_ptr(), + key_host.data(), + key_host.size() * sizeof(float), + cudaMemcpyHostToDevice); + + // Make V also different at different positions to amplify differences + std::vector value_host(batch * num_heads * seq_len * head_dim); + for (int64_t pos = 0; pos < seq_len; ++pos) { + for (int64_t d = 0; d < head_dim; ++d) { + // V values: pos 0=1.0, pos 1=2.0, pos 2=3.0, pos 3=4.0 + value_host[pos * head_dim + d] = static_cast(pos + 1); + } + } + cudaMemcpy( + value->data_ptr(), + value_host.data(), + value_host.size() * sizeof(float), + cudaMemcpyHostToDevice); + + printf("Testing SDPA scale effect\n"); + + // Test with default scale + Tensor* output1 = nullptr; + AOTITorchError error1 = call_flash_attention( + query, key, value, 0.0, 0, nullptr, &output1); + ASSERT_EQ(error1, Error::Ok); + ASSERT_NE(output1, nullptr); + + // Test with custom scale (much smaller, should make attention more uniform) + double small_scale = 0.01; + Tensor* output2 = nullptr; + AOTITorchError error2 = call_flash_attention( + query, key, value, 0.0, 0, &small_scale, &output2); + ASSERT_EQ(error2, 
Error::Ok); + ASSERT_NE(output2, nullptr); + + // Copy outputs back to CPU + std::vector output1_data = copy_tensor_to_host(output1); + std::vector output2_data = copy_tensor_to_host(output2); + + // Outputs should be different (scale affects softmax sharpness) + // With varied V values, even small changes in attention weights will produce + // noticeably different outputs + bool outputs_differ = false; + float max_diff = 0.0f; + for (size_t i = 0; i < output1_data.size(); ++i) { + float diff = std::abs(output1_data[i] - output2_data[i]); + max_diff = std::max(max_diff, diff); + if (diff > 0.05f) { // More lenient threshold due to varied V values + outputs_differ = true; + break; + } + } + + printf("Max difference between outputs: %f\n", max_diff); + EXPECT_TRUE(outputs_differ) << "Different scales should produce different outputs (max_diff=" << max_diff << ")"; + + printf("Scale effect test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output1); + aoti_torch_delete_tensor_object(output2); +} + +// Test causal masking correctness +TEST_F(AOTITorchSDPATest, CausalMaskingCorrectness) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 4; + const int64_t head_dim = 8; + + // Create distinct values at different positions in V + // This allows us to verify that causal masking works correctly + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + // Manually set different values for each position in V + // V[position i] = i+1 (so we can track which positions contribute) + std::vector value_host(batch * num_heads * seq_len * head_dim); + for (int64_t pos = 0; pos < seq_len; ++pos) { + for (int64_t d = 0; d < head_dim; ++d) { + value_host[pos * head_dim + d] = static_cast(pos + 1); + } + } + cudaMemcpy( + value->data_ptr(), + value_host.data(), + value_host.size() * sizeof(float), + cudaMemcpyHostToDevice); + + printf("Testing SDPA causal masking correctness\n"); + + // Run with causal masking + Tensor* output_causal = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, 0.0, 1, nullptr, &output_causal); + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(output_causal, nullptr); + + // Copy output back to CPU + std::vector output_data = copy_tensor_to_host(output_causal); + + // With causal masking: + // - Position 0 can only see position 0, so output[0] should be ~1.0 + // - Position 1 can see positions 0,1, so output[1] should be ~1.5 (average of 1 and 2) + // - Position 2 can see positions 0,1,2, so output[2] should be ~2.0 (average of 1,2,3) + // - Position 3 can see all, so output[3] should be ~2.5 (average of 1,2,3,4) + + std::vector expected_values = {1.0f, 1.5f, 2.0f, 2.5f}; + + bool causal_correct = true; + for (int64_t pos = 0; pos < seq_len; ++pos) { + float avg_output = 0.0f; + for (int64_t d = 0; d < head_dim; ++d) { + avg_output += output_data[pos * head_dim + d]; + } + avg_output /= head_dim; + + printf("Position %ld: output avg = %f, expected ~%f\n", + pos, avg_output, expected_values[pos]); + + if (!approx_equal(avg_output, expected_values[pos], 0.2f)) { + causal_correct = false; + } + } + + 
EXPECT_TRUE(causal_correct) << "Causal masking did not produce expected values"; + + printf("Causal masking correctness test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output_causal); +} + +// ============================================================================ +// Flash Attention Specific Tests +// ============================================================================ + + + + +// Test that Flash Attention is selected for longer sequences +TEST_F(AOTITorchSDPATest, FlashAttentionLongSequence) { + const int64_t batch = 2; + const int64_t num_heads = 4; + const int64_t seq_len = 256; // Long sequence to trigger Flash Attention + const int64_t head_dim = 64; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing Flash Attention with long sequence: [%ldx%ldx%ldx%ld]\n", + batch, num_heads, seq_len, head_dim); + + // Verify that can_use_flash_attention returns true for this configuration + bool can_use_fa = can_use_flash_attention(query, key, value, nullptr, false); + EXPECT_TRUE(can_use_fa) << "Should be able to use Flash Attention for long sequences"; + printf(" can_use_flash_attention: %s\n", can_use_fa ? "true" : "false"); + + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, 0.0, 0, nullptr, &output); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Flash Attention long sequence test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test Flash Attention with causal masking +TEST_F(AOTITorchSDPATest, FlashAttentionCausalMasking) { + const int64_t batch = 1; + const int64_t num_heads = 2; + const int64_t seq_len = 128; // Long enough for Flash Attention + const int64_t head_dim = 64; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing Flash Attention with causal masking\n"); + + // Verify that can_use_flash_attention returns true for causal masking + bool can_use_fa = can_use_flash_attention(query, key, value, nullptr, true); + EXPECT_TRUE(can_use_fa) << "Should be able to use Flash Attention with causal masking"; + printf(" can_use_flash_attention (causal): %s\n", can_use_fa ? 
"true" : "false"); + + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, 0.0, 1, nullptr, &output); // causal=1 + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Flash Attention causal masking test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test Flash Attention with BFloat16 +TEST_F(AOTITorchSDPATest, FlashAttentionBFloat16) { + const int64_t batch = 1; + const int64_t num_heads = 4; + const int64_t seq_len = 64; + const int64_t head_dim = 64; + + Tensor* query = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); + Tensor* value = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing Flash Attention with BFloat16\n"); + + // Verify that can_use_flash_attention returns true for BFloat16 + bool can_use_fa = can_use_flash_attention(query, key, value, nullptr, false); + EXPECT_TRUE(can_use_fa) << "Should be able to use Flash Attention with BFloat16"; + printf(" can_use_flash_attention (BFloat16): %s\n", can_use_fa ? "true" : "false"); + + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, 0.0, 0, nullptr, &output); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Flash Attention BFloat16 test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test correctness: Compare Flash Attention output with Math fallback +TEST_F(AOTITorchSDPATest, FlashAttentionCorrectness) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 32; // Short enough to potentially use Math fallback + const int64_t head_dim = 32; + + // Create inputs with known values + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing Flash Attention numerical correctness\n"); + + // Verify that can_use_flash_attention returns true for this configuration + bool can_use_fa = can_use_flash_attention(query, key, value, nullptr, false); + EXPECT_TRUE(can_use_fa) << "Should be able to use Flash Attention for this configuration"; + printf(" can_use_flash_attention: %s\n", can_use_fa ? 
"true" : "false"); + + // Run SDPA (will auto-select backend) + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, 0.0, 0, nullptr, &output); + + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + + // Copy output back to CPU for validation + std::vector output_data = copy_tensor_to_host(output); + + // Since V is all 1.0 and softmax weights sum to 1, output should be close to 1.0 + bool all_correct = true; + for (size_t i = 0; i < output_data.size(); ++i) { + if (!approx_equal(output_data[i], 1.0f, 0.1f)) { + printf("Output[%zu] = %f, expected ~1.0\n", i, output_data[i]); + all_correct = false; + } + } + + EXPECT_TRUE(all_correct) << "Flash Attention output doesn't match expected values"; + + printf("Flash Attention correctness test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test Flash Attention with different query and key sequence lengths +TEST_F(AOTITorchSDPATest, FlashAttentionDifferentSeqLengths) { + const int64_t batch = 1; + const int64_t num_heads = 2; + const int64_t seq_len_q = 64; + const int64_t seq_len_k = 128; // K/V longer than Q + const int64_t head_dim = 32; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len_q, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len_k, head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len_k, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing Flash Attention with different Q and K/V sequence lengths\n"); + + // Verify that can_use_flash_attention returns true for different seq lengths + bool can_use_fa = can_use_flash_attention(query, key, value, nullptr, false); + EXPECT_TRUE(can_use_fa) << "Should be able to use Flash Attention with different seq lengths"; + printf(" can_use_flash_attention (different seq lengths): %s\n", can_use_fa ? 
"true" : "false"); + + Tensor* output = nullptr; + AOTITorchError error = call_flash_attention( + query, key, value, 0.0, 0, nullptr, &output); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len_q, head_dim})); + + printf("Flash Attention different seq lengths test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test that explicit attention mask forces Math fallback instead of Flash Attention +TEST_F(AOTITorchSDPATest, ExplicitMaskFallsBackToMath) { + const int64_t batch = 1; + const int64_t num_heads = 1; + const int64_t seq_len = 128; // Long enough to prefer Flash Attention + const int64_t head_dim = 64; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + // Create explicit attention mask (will force Math fallback) + Tensor* attn_mask = create_float_tensor({batch, num_heads, seq_len, seq_len}, 0.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + ASSERT_NE(attn_mask, nullptr); + + printf("Testing that explicit mask forces Math fallback\n"); + + // Verify that can_use_flash_attention returns false when explicit mask is provided + bool can_use_fa = can_use_flash_attention(query, key, value, attn_mask, false); + EXPECT_FALSE(can_use_fa) << "Should NOT be able to use Flash Attention with explicit mask"; + printf(" can_use_flash_attention (with explicit mask): %s\n", can_use_fa ? 
"true" : "false"); + + Tensor* output = nullptr; + AOTITorchError error = call_efficient_attention( + query, key, value, attn_mask, 0, nullptr, &output); + + // Should succeed but use Math fallback instead of Flash Attention + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Explicit mask fallback test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(attn_mask); + aoti_torch_delete_tensor_object(output); +} + +// ============================================================================ +// can_use_flash_attention Function Tests +// ============================================================================ + +// Test can_use_flash_attention returns true for valid configuration +TEST_F(AOTITorchSDPATest, CanUseFlashAttention_ValidConfig) { + // Create tensors with valid flash attention configuration: + // - head_dim <= 128 + // - head_dim_v <= 64 + // - At least one sequence length >= 32 + // - No explicit attention mask + const int64_t batch = 2; + const int64_t num_heads = 8; + const int64_t seq_len = 64; // >= 32 + const int64_t head_dim = 64; // <= 128 + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing can_use_flash_attention with valid config: " + "seq_len=%ld, head_dim=%ld\n", seq_len, head_dim); + + // Test without attention mask, should return true + bool can_use = can_use_flash_attention( + query, key, value, + nullptr, // no attention mask + false); // not causal + + EXPECT_TRUE(can_use) << "Should allow flash attention for valid configuration"; + + // Also test with causal masking, should still return true + bool can_use_causal = can_use_flash_attention( + query, key, value, + nullptr, // no attention mask + true); // causal + + EXPECT_TRUE(can_use_causal) << "Should allow flash attention with causal masking"; + + printf("Flash attention valid config test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); +} + +// Test can_use_flash_attention returns false for various invalid configurations +TEST_F(AOTITorchSDPATest, CanUseFlashAttention_InvalidConfigs) { + const int64_t batch = 2; + const int64_t num_heads = 8; + + printf("Testing can_use_flash_attention with invalid configurations\n"); + + // Test 1: Explicit attention mask provided + { + Tensor* query = create_float_tensor({batch, num_heads, 64, 64}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, 64, 64}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, 64, 64}, 1.0f); + Tensor* attn_mask = create_float_tensor({batch, num_heads, 64, 64}, 0.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + ASSERT_NE(attn_mask, nullptr); + + bool can_use = can_use_flash_attention( + query, key, value, attn_mask, false); + + EXPECT_FALSE(can_use) << "Should reject flash attention with explicit attention mask"; + printf(" - Explicit attention mask: correctly rejected\n"); + + aoti_torch_delete_tensor_object(query); + 
aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(attn_mask); + } + + // Test 2: head_dim > 128 + { + const int64_t large_head_dim = 256; // > 128 + Tensor* query = create_float_tensor({batch, num_heads, 64, large_head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, 64, large_head_dim}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, 64, large_head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + bool can_use = can_use_flash_attention( + query, key, value, nullptr, false); + + EXPECT_FALSE(can_use) << "Should reject flash attention with head_dim > 128"; + printf(" - head_dim > 128: correctly rejected\n"); + + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + } + + // Test 3: head_dim_v > 64 + { + const int64_t large_head_dim_v = 128; // > 64 + Tensor* query = create_float_tensor({batch, num_heads, 64, 64}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, 64, 64}, 0.5f); + Tensor* value = create_float_tensor({batch, num_heads, 64, large_head_dim_v}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + bool can_use = can_use_flash_attention( + query, key, value, nullptr, false); + + EXPECT_FALSE(can_use) << "Should reject flash attention with head_dim_v > 64"; + printf(" - head_dim_v > 64: correctly rejected\n"); + + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + } + + printf("Flash attention invalid configs test passed!\n"); +} + +// ============================================================================ +// Efficient Attention Specific Tests +// ============================================================================ + +// Test basic efficient attention without attention bias +TEST_F(AOTITorchSDPATest, EfficientAttention_BasicNoBias) { + const int64_t batch = 2; + const int64_t num_heads = 4; + const int64_t seq_len = 32; + const int64_t head_dim = 64; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); + Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f); + + ASSERT_NE(query, nullptr); + ASSERT_NE(key, nullptr); + ASSERT_NE(value, nullptr); + + printf("Testing Efficient Attention without bias: [%ldx%ldx%ldx%ld]\n", + batch, num_heads, seq_len, head_dim); + + Tensor* output = nullptr; + AOTITorchError error = call_efficient_attention( + query, key, value, nullptr, 0, nullptr, &output); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(output, nullptr); + EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim})); + + printf("Efficient attention basic test passed!\n"); + + // Cleanup + aoti_torch_delete_tensor_object(query); + aoti_torch_delete_tensor_object(key); + aoti_torch_delete_tensor_object(value); + aoti_torch_delete_tensor_object(output); +} + +// Test efficient attention with full attention bias (no broadcasting) +TEST_F(AOTITorchSDPATest, EfficientAttention_WithBias) { + const int64_t batch = 2; + const int64_t num_heads = 4; + const int64_t seq_len = 16; + const int64_t head_dim = 32; + + Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f); + Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f); + 
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f);
+  Tensor* attn_bias = create_float_tensor({batch, num_heads, seq_len, seq_len}, 0.1f);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+  ASSERT_NE(attn_bias, nullptr);
+
+  printf("Testing Efficient Attention with full bias: [%ldx%ldx%ldx%ld]\n",
+      batch, num_heads, seq_len, seq_len);
+
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, attn_bias, 0, nullptr, &output);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+  EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim}));
+
+  printf("Efficient attention with full bias test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(attn_bias);
+  aoti_torch_delete_tensor_object(output);
+}
+
+// Test efficient attention with broadcasted bias (batch broadcast)
+TEST_F(AOTITorchSDPATest, EfficientAttention_BroadcastBatchDim) {
+  const int64_t batch = 4;
+  const int64_t num_heads = 2;
+  const int64_t seq_len = 8;
+  const int64_t head_dim = 16;
+
+  Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f);
+  Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f);
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f);
+  // Broadcast across batch dimension: [1, num_heads, seq_len, seq_len]
+  Tensor* attn_bias = create_float_tensor({1, num_heads, seq_len, seq_len}, 0.2f);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+  ASSERT_NE(attn_bias, nullptr);
+
+  printf("Testing Efficient Attention with batch-broadcasted bias\n");
+
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, attn_bias, 0, nullptr, &output);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+  EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim}));
+
+  printf("Efficient attention with batch broadcast test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(attn_bias);
+  aoti_torch_delete_tensor_object(output);
+}
+
+// Test efficient attention with broadcasted bias (head broadcast)
+TEST_F(AOTITorchSDPATest, EfficientAttention_BroadcastHeadDim) {
+  const int64_t batch = 2;
+  const int64_t num_heads = 8;
+  const int64_t seq_len = 16;
+  const int64_t head_dim = 32;
+
+  Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f);
+  Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f);
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f);
+  // Broadcast across head dimension: [batch, 1, seq_len, seq_len]
+  Tensor* attn_bias = create_float_tensor({batch, 1, seq_len, seq_len}, -0.1f);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+  ASSERT_NE(attn_bias, nullptr);
+
+  printf("Testing Efficient Attention with head-broadcasted bias\n");
+
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, attn_bias, 0, nullptr, &output);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+  EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim}));
+
+  printf("Efficient attention with head broadcast test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(attn_bias);
+  aoti_torch_delete_tensor_object(output);
+}
+
+// Test efficient attention with causal masking and bias
+TEST_F(AOTITorchSDPATest, EfficientAttention_CausalWithBias) {
+  const int64_t batch = 1;
+  const int64_t num_heads = 2;
+  const int64_t seq_len = 32;
+  const int64_t head_dim = 64;
+
+  Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f);
+  Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f);
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f);
+  Tensor* attn_bias = create_float_tensor({batch, num_heads, seq_len, seq_len}, 0.05f);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+  ASSERT_NE(attn_bias, nullptr);
+
+  printf("Testing Efficient Attention with causal masking and bias\n");
+
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, attn_bias, 1, nullptr, &output); // causal=1
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+  EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim}));
+
+  printf("Efficient attention causal with bias test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(attn_bias);
+  aoti_torch_delete_tensor_object(output);
+}
+
+// Test efficient attention with BFloat16
+TEST_F(AOTITorchSDPATest, EfficientAttention_BFloat16) {
+  const int64_t batch = 2;
+  const int64_t num_heads = 4;
+  const int64_t seq_len = 64;
+  const int64_t head_dim = 64;
+
+  Tensor* query = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 0.5f);
+  Tensor* key = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 0.3f);
+  Tensor* value = create_bfloat16_tensor({batch, num_heads, seq_len, head_dim}, 1.0f);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+
+  printf("Testing Efficient Attention with BFloat16\n");
+
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, nullptr, 0, nullptr, &output);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+  EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim}));
+
+  printf("Efficient attention BFloat16 test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(output);
+}
+
+// Test efficient attention with different Q and K/V sequence lengths
+TEST_F(AOTITorchSDPATest, EfficientAttention_DifferentSeqLengths) {
+  const int64_t batch = 2;
+  const int64_t num_heads = 4;
+  const int64_t seq_len_q = 16;
+  const int64_t seq_len_kv = 32; // K/V longer than Q
+  const int64_t head_dim = 32;
+
+  Tensor* query = create_float_tensor({batch, num_heads, seq_len_q, head_dim}, 0.5f);
+  Tensor* key = create_float_tensor({batch, num_heads, seq_len_kv, head_dim}, 0.3f);
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len_kv, head_dim}, 1.0f);
+  Tensor* attn_bias = create_float_tensor({batch, num_heads, seq_len_q, seq_len_kv}, 0.1f);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+  ASSERT_NE(attn_bias, nullptr);
+
+  printf("Testing Efficient Attention with different Q (%ld) and K/V (%ld) lengths\n",
+      seq_len_q, seq_len_kv);
+
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, attn_bias, 0, nullptr, &output);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+  EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len_q, head_dim}));
+
+  printf("Efficient attention different seq lengths test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(attn_bias);
+  aoti_torch_delete_tensor_object(output);
+}
+
+// Test efficient attention with large head_dim_v (up to 128)
+TEST_F(AOTITorchSDPATest, EfficientAttention_LargeHeadDimV) {
+  const int64_t batch = 1;
+  const int64_t num_heads = 2;
+  const int64_t seq_len = 16;
+  const int64_t head_dim = 64;
+  const int64_t head_dim_v = 128; // Maximum supported by efficient attention
+
+  Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f);
+  Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f);
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim_v}, 1.0f);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+
+  printf("Testing Efficient Attention with large head_dim_v=%ld\n", head_dim_v);
+
+  // Verify that can_use_efficient_attention returns true
+  bool can_use = can_use_efficient_attention(query, key, value, nullptr, false);
+  EXPECT_TRUE(can_use) << "Should support head_dim_v=128";
+
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, nullptr, 0, nullptr, &output);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+  EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim_v}));
+
+  printf("Efficient attention large head_dim_v test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(output);
+}
+
+// Test numerical correctness: bias affects output
+TEST_F(AOTITorchSDPATest, EfficientAttention_BiasAffectsOutput) {
+  const int64_t batch = 1;
+  const int64_t num_heads = 1;
+  const int64_t seq_len = 8;
+  const int64_t head_dim = 16;
+
+  Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f);
+  Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f);
+
+  // Create V with different values at each position so attention weight changes matter
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.0f);
+  std::vector<float> value_host(batch * num_heads * seq_len * head_dim);
+  for (int64_t pos = 0; pos < seq_len; ++pos) {
+    for (int64_t d = 0; d < head_dim; ++d) {
+      // V values: pos 0=1.0, pos 1=2.0, pos 2=3.0, etc.
+      value_host[pos * head_dim + d] = static_cast<float>(pos + 1);
+    }
+  }
+  cudaMemcpy(
+      value->data_ptr(),
+      value_host.data(),
+      value_host.size() * sizeof(float),
+      cudaMemcpyHostToDevice);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+
+  printf("Testing that attention bias affects output\n");
+
+  // Run without bias
+  Tensor* output_no_bias = nullptr;
+  AOTITorchError error1 = call_efficient_attention(
+      query, key, value, nullptr, 0, nullptr, &output_no_bias);
+  ASSERT_EQ(error1, Error::Ok);
+  ASSERT_NE(output_no_bias, nullptr);
+
+  // Create bias tensor with varied values (not uniform)
+  // A uniform bias doesn't change softmax output, so we need variation
+  Tensor* attn_bias = create_float_tensor({batch, num_heads, seq_len, seq_len}, 0.0f);
+  ASSERT_NE(attn_bias, nullptr);
+
+  // Fill bias with varied values to actually affect attention distribution
+  // Use a strong pattern that will significantly change attention weights
+  std::vector<float> bias_host(batch * num_heads * seq_len * seq_len);
+  for (int64_t i = 0; i < seq_len; ++i) {
+    for (int64_t j = 0; j < seq_len; ++j) {
+      // Create a diagonal pattern with large values to strongly affect attention
+      if (i == j) {
+        bias_host[i * seq_len + j] = 10.0f; // Large positive bias for diagonal
+      } else {
+        bias_host[i * seq_len + j] = -5.0f; // Negative bias for off-diagonal
+      }
+    }
+  }
+  cudaMemcpy(
+      attn_bias->data_ptr(),
+      bias_host.data(),
+      bias_host.size() * sizeof(float),
+      cudaMemcpyHostToDevice);
+
+  // Run with bias
+  Tensor* output_with_bias = nullptr;
+  AOTITorchError error2 = call_efficient_attention(
+      query, key, value, attn_bias, 0, nullptr, &output_with_bias);
+  ASSERT_EQ(error2, Error::Ok);
+  ASSERT_NE(output_with_bias, nullptr);
+
+  // Verify both computations succeeded
+  EXPECT_TRUE(check_output_shape(output_no_bias, {batch, num_heads, seq_len, head_dim}));
+  EXPECT_TRUE(check_output_shape(output_with_bias, {batch, num_heads, seq_len, head_dim}));
+
+  // Copy outputs to host and verify they differ
+  std::vector<float> output_no_bias_data = copy_tensor_to_host(output_no_bias);
+  std::vector<float> output_with_bias_data = copy_tensor_to_host(output_with_bias);
+
+  // Outputs should be different when bias is non-uniform
+  bool outputs_differ = false;
+  float max_diff = 0.0f;
+  for (size_t i = 0; i < output_no_bias_data.size(); ++i) {
+    float diff = std::abs(output_no_bias_data[i] - output_with_bias_data[i]);
+    max_diff = std::max(max_diff, diff);
+    if (diff > 0.001f) { // Small threshold for float comparison
+      outputs_differ = true;
+    }
+  }
+
+  printf("Max difference between outputs: %f\n", max_diff);
+  EXPECT_TRUE(outputs_differ) << "Attention bias should affect the output (max_diff=" << max_diff << ")";
+
+  printf("Bias affects output test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(attn_bias);
+  aoti_torch_delete_tensor_object(output_no_bias);
+  aoti_torch_delete_tensor_object(output_with_bias);
+}
+
+// Test numerical correctness: negative bias masks out positions
+TEST_F(AOTITorchSDPATest, EfficientAttention_NegativeBiasMasking) {
+  const int64_t batch = 1;
+  const int64_t num_heads = 1;
+  const int64_t seq_len = 4;
+  const int64_t head_dim = 8;
+
+  Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f);
+  Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f);
+
+  // Create distinct values at each position in V
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.0f);
+  std::vector<float> value_host(batch * num_heads * seq_len * head_dim);
+  for (int64_t pos = 0; pos < seq_len; ++pos) {
+    for (int64_t d = 0; d < head_dim; ++d) {
+      value_host[pos * head_dim + d] = static_cast<float>(pos + 1);
+    }
+  }
+  cudaMemcpy(
+      value->data_ptr(),
+      value_host.data(),
+      value_host.size() * sizeof(float),
+      cudaMemcpyHostToDevice);
+
+  // Create bias that masks out the last position (apply large negative bias)
+  Tensor* attn_bias = create_float_tensor({batch, num_heads, seq_len, seq_len}, 0.0f);
+  std::vector<float> bias_host(batch * num_heads * seq_len * seq_len, 0.0f);
+  // Mask out last position for all queries
+  for (int64_t q = 0; q < seq_len; ++q) {
+    bias_host[q * seq_len + (seq_len - 1)] = -10000.0f; // Large negative value
+  }
+  cudaMemcpy(
+      attn_bias->data_ptr(),
+      bias_host.data(),
+      bias_host.size() * sizeof(float),
+      cudaMemcpyHostToDevice);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+  ASSERT_NE(attn_bias, nullptr);
+
+  printf("Testing negative bias masking\n");
+
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, attn_bias, 0, nullptr, &output);
+  ASSERT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+
+  // Copy output to host
+  std::vector<float> output_data = copy_tensor_to_host(output);
+
+  // With Q=K, without masking, output would be average of all positions: (1+2+3+4)/4 = 2.5
+  // With last position masked, output should be: (1+2+3)/3 = 2.0
+  float avg_output = 0.0f;
+  for (size_t i = 0; i < output_data.size(); ++i) {
+    avg_output += output_data[i];
+  }
+  avg_output /= output_data.size();
+
+  printf(" Average output: %f (expected ~2.0)\n", avg_output);
+  EXPECT_TRUE(approx_equal(avg_output, 2.0f, 0.2f))
+      << "Masked position should not contribute to output";
+
+  printf("Negative bias masking test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(attn_bias);
+  aoti_torch_delete_tensor_object(output);
+}
+
+// Test with custom scale and bias
+TEST_F(AOTITorchSDPATest, EfficientAttention_CustomScaleWithBias) {
+  const int64_t batch = 1;
+  const int64_t num_heads = 2;
+  const int64_t seq_len = 16;
+  const int64_t head_dim = 32;
+
+  Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.5f);
+  Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.3f);
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f);
+  Tensor* attn_bias = create_float_tensor({batch, num_heads, seq_len, seq_len}, 0.1f);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+  ASSERT_NE(attn_bias, nullptr);
+
+  printf("Testing Efficient Attention with custom scale and bias\n");
+
+  double custom_scale = 0.1;
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, attn_bias, 0, &custom_scale, &output);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+  EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim}));
+
+  printf("Custom scale with bias test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(attn_bias);
+  aoti_torch_delete_tensor_object(output);
+}
+
+// Test an edge case: all zeros in query
+TEST_F(AOTITorchSDPATest, EfficientAttention_ZeroQuery) {
+  const int64_t batch = 1;
+  const int64_t num_heads = 1;
+  const int64_t seq_len = 8;
+  const int64_t head_dim = 16;
+
+  Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.0f);
+  Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f);
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 2.0f);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+
+  printf("Testing Efficient Attention with zero query\n");
+
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, nullptr, 0, nullptr, &output);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+  EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim}));
+
+  // With zero query, attention scores are uniform, so the output should be the average of V
+  std::vector<float> output_data = copy_tensor_to_host(output);
+  bool all_close_to_2 = true;
+  for (size_t i = 0; i < output_data.size(); ++i) {
+    if (!approx_equal(output_data[i], 2.0f, 0.1f)) {
+      all_close_to_2 = false;
+      break;
+    }
+  }
+  EXPECT_TRUE(all_close_to_2) << "With uniform attention, output should equal the mean of V (2.0)";
+
+  printf("Zero query test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(output);
+}
+
+// Test long sequences (stress test)
+TEST_F(AOTITorchSDPATest, EfficientAttention_LongSequence) {
+  const int64_t batch = 1;
+  const int64_t num_heads = 4;
+  const int64_t seq_len = 512; // Long sequence
+  const int64_t head_dim = 64;
+
+  Tensor* query = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f);
+  Tensor* key = create_float_tensor({batch, num_heads, seq_len, head_dim}, 0.1f);
+  Tensor* value = create_float_tensor({batch, num_heads, seq_len, head_dim}, 1.0f);
+
+  ASSERT_NE(query, nullptr);
+  ASSERT_NE(key, nullptr);
+  ASSERT_NE(value, nullptr);
+
+  printf("Testing Efficient Attention with long sequence: %ld\n", seq_len);
+
+  Tensor* output = nullptr;
+  AOTITorchError error = call_efficient_attention(
+      query, key, value, nullptr, 0, nullptr, &output);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(output, nullptr);
+  EXPECT_TRUE(check_output_shape(output, {batch, num_heads, seq_len, head_dim}));
+
+  printf("Long sequence test passed!\n");
+
+  // Cleanup
+  aoti_torch_delete_tensor_object(query);
+  aoti_torch_delete_tensor_object(key);
+  aoti_torch_delete_tensor_object(value);
+  aoti_torch_delete_tensor_object(output);
+}
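The expected values in the numerical-correctness tests above (the output averaging V when the query is all zeros, and the (1+2+3)/3 = 2.0 average once the last position gets a large negative bias) follow from the reference definition softmax(Q K^T * scale + bias) V. The host-side sketch below illustrates that reference computation; it is not part of the patch, and the helper name reference_sdpa_single_head, along with its single batch/head, contiguous row-major layout, is an assumption made only for illustration.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Reference softmax(Q K^T * scale + bias) V for one batch and one head.
// q: [seq_q, d], k: [seq_kv, d], v: [seq_kv, d_v], bias: [seq_q, seq_kv] or empty.
std::vector<float> reference_sdpa_single_head(
    const std::vector<float>& q,
    const std::vector<float>& k,
    const std::vector<float>& v,
    const std::vector<float>& bias,
    int64_t seq_q,
    int64_t seq_kv,
    int64_t d,
    int64_t d_v) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  std::vector<float> out(seq_q * d_v, 0.0f);
  for (int64_t i = 0; i < seq_q; ++i) {
    // Scaled dot-product scores for query row i, plus the optional additive bias.
    std::vector<float> scores(seq_kv, 0.0f);
    float max_score = -std::numeric_limits<float>::infinity();
    for (int64_t j = 0; j < seq_kv; ++j) {
      float s = 0.0f;
      for (int64_t c = 0; c < d; ++c) {
        s += q[i * d + c] * k[j * d + c];
      }
      s *= scale;
      if (!bias.empty()) {
        s += bias[i * seq_kv + j];
      }
      scores[j] = s;
      max_score = std::max(max_score, s);
    }
    // Numerically stable softmax over the key positions.
    float denom = 0.0f;
    for (int64_t j = 0; j < seq_kv; ++j) {
      scores[j] = std::exp(scores[j] - max_score);
      denom += scores[j];
    }
    // Weighted sum of V rows. With Q all zeros the weights are uniform, so each
    // output row is the mean of V; with bias[i][seq_kv - 1] = -10000 the last V
    // row effectively drops out, matching the (1 + 2 + 3) / 3 = 2.0 expectation.
    for (int64_t j = 0; j < seq_kv; ++j) {
      const float w = scores[j] / denom;
      for (int64_t c = 0; c < d_v; ++c) {
        out[i * d_v + c] += w * v[j * d_v + c];
      }
    }
  }
  return out;
}

Comparing the shim output against a small reference like this on tiny shapes would be one way to tighten the approximate average-based checks above into elementwise comparisons.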