23 changes: 17 additions & 6 deletions flash-attn2/flash_attn_xpu/flash_api.cpp
@@ -79,10 +79,16 @@ mha_fwd(
out_padded = torch::zeros_like(q_padded);
}

bool is_local = (window_size_left != -1) || (window_size_right != -1);

q_padded = ensure_contiguous(q_padded);
k_padded = ensure_contiguous(k_padded);
v_padded = ensure_contiguous(v_padded);
cutlass::flash_attention::fixed::cutlass_fixed_impl(q_padded, k_padded, v_padded, out_padded, softmax_scale, is_causal);
cutlass::flash_attention::fixed::cutlass_fixed_impl(
q_padded, k_padded, v_padded, out_padded,
softmax_scale,
window_size_left, window_size_right,
is_causal, is_local);

// Remove padding from output
at::Tensor out = out_padded;
@@ -170,14 +176,19 @@ mha_varlen_fwd(
} else {
out_padded = torch::zeros_like(q_padded);
}


bool is_local = (window_size_left != -1) || (window_size_right != -1);

q_padded = ensure_contiguous(q_padded);
k_padded = ensure_contiguous(k_padded);
v_padded = ensure_contiguous(v_padded);
cutlass::flash_attention::varlen::cutlass_varlen_impl(q_padded, k_padded, v_padded, out_padded, block_table_,
cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k,
softmax_scale, is_causal);
cutlass::flash_attention::varlen::cutlass_varlen_impl(
q_padded, k_padded, v_padded, out_padded, block_table_,
cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k,
softmax_scale,
window_size_left, window_size_right,
is_causal, is_local);

// Remove padding from output
at::Tensor out = out_padded;
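Note on the new API arguments: window_size_left/window_size_right bound how far each query may look backward/forward, with -1 meaning unbounded on that side (hence the is_local derivation above). A minimal host-side sketch of the visibility predicate the kernels below implement — the standalone helper and its name are illustrative, not part of this PR:

#include <algorithm>

// Reference sliding-window visibility check, assuming the convention used
// in the kernels below: query row `row` is aligned to key column
// row + col_ref, where col_ref = seq_len_k - seq_len_q aligns sequence ends.
bool key_visible(int row, int col, int seq_len_q, int seq_len_k,
                 int window_left, int window_right) {
  int col_ref = seq_len_k - seq_len_q;
  bool left_ok  = col >= std::max(0, row + col_ref - window_left);
  bool right_ok = col <= std::min(seq_len_k, row + col_ref + window_right);
  return left_ok && right_ok;
}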
15 changes: 10 additions & 5 deletions flash-attn2/flash_attn_xpu/src/collective/fixed_mma.hpp
@@ -62,7 +62,7 @@ CUTLASS_DEVICE auto convert_type(Tensor<Engine, Layout> const &tensor) {

template <class DispatchPolicy, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOperation_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, class GmemTiledCopyK_,
class GmemTiledCopyV_, bool CausalMask_>
class GmemTiledCopyV_, bool CausalMask_, bool LocalMask_>
struct FlashPrefillMma {
static_assert(cutlass::detail::dependent_false<ElementQ_>, "Could not find a mainloop specialization.");
};
@@ -71,9 +71,9 @@ struct FlashPrefillMma {

template <int Stages, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOperation_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, class GmemTiledCopyK_,
class GmemTiledCopyV_, bool CausalMask_>
class GmemTiledCopyV_, bool CausalMask_, bool LocalMask_>
struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, ElementQ_, StrideQ_, ElementK_, StrideK_, ElementV_,
StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_> {
StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_, LocalMask_> {
//
// Type Aliases
//
@@ -97,6 +97,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
using TiledMmaPV = typename TiledMMAHelper<MmaAtom, Layout<TileShapePV>, SubgroupLayout>::TiledMMA;
using ElementAccumulator = typename TiledMmaQK::ValTypeC;
static constexpr bool CausalMask = CausalMask_;
static constexpr bool LocalMask = LocalMask_;
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

using MmaAtomShape = typename MmaAtom::Shape_MNK;
@@ -158,12 +159,16 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
StrideK dK;
ElementV const *ptr_V;
StrideV dV;
int window_left;
int window_right;
};

struct Params {
XE_Copy_Q gmem_tiled_copy_q;
XE_Copy_K gmem_tiled_copy_k;
XE_Copy_V gmem_tiled_copy_v;
int window_left;
int window_right;
};

//
@@ -185,7 +190,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};

return Params{copyQ, copyK, copyV};
return Params{copyQ, copyK, copyV, args.window_left, args.window_right};
}

template <class FragQccum, class TensorQ, class TensorK, class FragSrc>
@@ -376,7 +381,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};

return Params{copyQ, copyK, copyV};
return Params{copyQ, copyK, copyV, params.window_left, params.window_right};
}
}
};
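The change in this file is pure plumbing: the window bounds ride along from the host-facing Arguments into the device-side Params in both to_underlying_arguments overloads, so the mainloop can read them next to the tiled-copy objects. A stripped-down sketch of the pattern (members elided to the relevant fields; the real structs also carry pointers, strides, and copies). The FlashPrefillSoftmaxEpilogue changes follow below.

// Sketch of the Arguments -> Params plumbing added above (simplified).
struct Arguments { int window_left; int window_right; };
struct Params    { int window_left; int window_right; };

Params to_underlying_arguments(Arguments const& args) {
  // Copy the window bounds through unchanged; cutlass_fixed_impl has
  // already replaced -1 with the full sequence length when is_local is set.
  return Params{args.window_left, args.window_right};
}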
@@ -50,13 +50,13 @@ namespace collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <bool CausalMask_, class DispatchPolicy, class... Args> class FlashPrefillSoftmaxEpilogue {
template <bool CausalMask_, bool LocalMask_, class DispatchPolicy, class... Args> class FlashPrefillSoftmaxEpilogue {
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
};


template <bool CausalMask_, class Element_>
class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_> {
template <bool CausalMask_, bool LocalMask_, class Element_>
class FlashPrefillSoftmaxEpilogue<CausalMask_, LocalMask_, epilogue::IntelXeXMX16, Element_> {
public:

//
@@ -66,6 +66,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
using Element = Element_;

static constexpr bool CausalMask = CausalMask_;
static constexpr bool LocalMask = LocalMask_;

using GmemTiledCopyOut = void;

@@ -115,8 +116,18 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
CUTLASS_PRAGMA_UNROLL
for (int z = 0; z < FragsN; z++) {
auto base_indx = indx + (z * Vec * FragsM);
Element eq = frag_s(base_indx) - max_scale_bcast;
frag_s(base_indx) = sycl::native::exp2(eq);
if constexpr (LocalMask) {
if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) ||
(std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) {
frag_s(base_indx) = 0.f;
} else {
Element eq = frag_s(base_indx) - max_scale_bcast;
frag_s(base_indx) = sycl::native::exp2(eq);
}
} else {
Element eq = frag_s(base_indx) - max_scale_bcast;
frag_s(base_indx) = sycl::native::exp2(eq);
}
sum(indx) += frag_s(base_indx);
}
}
@@ -155,7 +166,18 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
if (!is_first) {
auto g = COMPAT::get_nd_item<1>().get_sub_group();
Element max_scale{max * params.scale};
Element exp_scale{sycl::native::exp2(max_prev * params.scale - max_scale)};
Element exp_scale;
if constexpr (LocalMask) {
if ((std::isinf(max_scale) && max_scale < 0) ||
(std::isinf(max_prev) && max_prev < 0)) {
exp_scale = 0.f;
} else {
exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale);
}
} else {
exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale);
}

CUTLASS_PRAGMA_UNROLL
for (int indx = 0; indx < Vec * FragsM; indx++) {
auto max_scale_bcast = group_broadcast(g, max_scale, indx);
@@ -164,7 +186,17 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
CUTLASS_PRAGMA_UNROLL
for (int z = 0; z < FragsNAcc; z++) {
auto base_indx = indx + (z * Vec * FragsM);
frag_s(base_indx) = sycl::native::exp2((frag_s(base_indx) - max_scale_bcast));
if constexpr (LocalMask) {
if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) ||
(std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) {
frag_s(base_indx) = 0.f;
} else {
Element eq = frag_s(base_indx) - max_scale_bcast;
frag_s(base_indx) = sycl::native::exp2(eq);
}
} else {
frag_s(base_indx) = sycl::native::exp2((frag_s(base_indx) - max_scale_bcast));
}
sum(indx) += frag_s(base_indx);
}
CUTLASS_PRAGMA_UNROLL
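Why the LocalMask branches above guard on -inf: with a sliding window, an entire row of a score tile can be masked to -INFINITY, so the running row max is itself -inf, and the unguarded path would compute exp2(-inf - (-inf)), i.e. exp2(NaN), poisoning the running sum. The guard forces those lanes to contribute an exact zero instead. A scalar model of the same logic, assuming IEEE float semantics:

#include <cmath>

// Scalar sketch of the guarded exp2 used in the LocalMask branches above.
float safe_scaled_exp2(float score, float row_max) {
  bool max_is_neg_inf   = std::isinf(row_max) && row_max < 0.0f;
  bool score_is_neg_inf = std::isinf(score)   && score   < 0.0f;
  if (max_is_neg_inf || score_is_neg_inf)
    return 0.0f;                    // fully masked lane: contribute nothing
  return std::exp2(score - row_max);
}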
33 changes: 25 additions & 8 deletions flash-attn2/flash_attn_xpu/src/fixed.hpp
@@ -33,6 +33,9 @@ struct prefill_args_fixed_t {
int batch_size;
int seq_len_q;
int seq_len_k;
int window_size_left;
int window_size_right;
bool is_local;
};

template <class FMHAPrefillKernel>
@@ -83,7 +86,7 @@ struct KernelLauncher {
{reinterpret_cast<ElementQ*>(args.query), stride_Q,
reinterpret_cast<ElementK*>(args.key), stride_K,
reinterpret_cast<ElementV*>(args.value), stride_V,
}, // window_left, window_right for local mask (not supported currently)
args.window_size_left, args.window_size_right,},
{args.softmax_scale},
{reinterpret_cast<ElementOutput*>(args.out), stride_O},
hw_info};
@@ -165,7 +168,7 @@ template <typename TileShapeQK, typename TileShapePV, typename TileShapeOutput,
typename ElementAccumulator = float,
typename ElementComputeEpilogue = float>
struct FMHAKernel {
template <bool Causal, class Scheduler>
template <bool Causal, bool Local, class Scheduler>
static void run_impl(const prefill_args_fixed_t& args) {
cutlass::KernelHardwareInfo hw_info;

@@ -186,7 +189,7 @@

using CollectiveSoftmaxEpilogue =
cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue<
Causal, EpilogueDispatchPolicy, ElementAccumulator>;
Causal, Local, EpilogueDispatchPolicy, ElementAccumulator>;

using ProblemShape = cute::tuple<int, int, int, int, int, int, int>;

@@ -197,7 +200,7 @@
cutlass::gemm::TagToStrideB_t<LayoutK>, ElementInputKV,
cutlass::gemm::TagToStrideB_t<LayoutV>, MMAOperation, TileShapeQK,
TileShapePV, SubgroupLayout, GmemTiledCopyQ, GmemTiledCopyK,
GmemTiledCopyV, Causal>;
GmemTiledCopyV, Causal, Local>;

using FMHAPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefill<
ProblemShape, CollectiveMainloop, CollectiveSoftmaxEpilogue,
@@ -209,9 +212,17 @@

static void dispatch(const prefill_args_fixed_t& args) {
if (args.is_causal) {
run_impl<true, cutlass::flash_attention::kernel::fixed::IndividualScheduler>(args);
if (args.is_local) {
run_impl<true, true, cutlass::flash_attention::kernel::fixed::IndividualScheduler>(args);
} else {
run_impl<true, false, cutlass::flash_attention::kernel::fixed::IndividualScheduler>(args);
}
} else {
run_impl<false, cutlass::flash_attention::kernel::fixed::IndividualScheduler>(args);
if (args.is_local) {
run_impl<false, true, cutlass::flash_attention::kernel::fixed::IndividualScheduler>(args);
} else {
run_impl<false, false, cutlass::flash_attention::kernel::fixed::IndividualScheduler>(args);
}
}
}
};
@@ -271,7 +282,8 @@ void cutlass_fixed_impl(
const at::Tensor& key, // [batch, seq_k, heads, head_size]
const at::Tensor& value, // [batch, seq_k, heads, head_size]
at::Tensor& out, // [batch, seq_q, heads, head_size]
double softmax_scale, bool is_causal) {
double softmax_scale, int window_size_left = -1, int window_size_right = -1,
bool is_causal = false, bool is_local = false) {

int batch_size = query.size(0);
int seq_len_q = query.size(1);
@@ -286,12 +298,17 @@
auto v_reshaped = value.transpose(1, 2).contiguous();
auto out_temp = torch::zeros_like(q_reshaped);

if (is_local) {
window_size_left = window_size_left == -1 ? seq_len_q : window_size_left;
window_size_right = window_size_right == -1 ? seq_len_k : window_size_right;
}

// Prepare arguments
prefill_args_fixed_t args{
q_reshaped.data_ptr(), k_reshaped.data_ptr(), v_reshaped.data_ptr(),
out_temp.data_ptr(), static_cast<float>(softmax_scale),
num_heads_q, num_heads_kv, head_size, is_causal, batch_size,
seq_len_q, seq_len_k
seq_len_q, seq_len_k, window_size_left, window_size_right, is_local
};

dispatch_by_head_size(aten_to_Cutlass_dtype(query), args);
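Two things happen above before launch: a window size of -1 is widened to the full sequence length (so "unbounded" sides degrade gracefully), and the {is_causal, is_local} pair is expanded into four compile-time template instantiations in dispatch. A small worked example of the normalization, with hypothetical sizes:

// Normalization sketch (mirrors the two lines in cutlass_fixed_impl):
//   seq_len_q = 4, seq_len_k = 6, window_size_left = -1, window_size_right = 0
//   => window_size_left becomes 4, window_size_right stays 0, i.e. each
//      query attends to its full past but nothing beyond the aligned
//      diagonal - a sliding-window variant of causal attention.
void normalize_window(int& window_left, int& window_right,
                      int seq_len_q, int seq_len_k) {
  if (window_left  == -1) window_left  = seq_len_q;
  if (window_right == -1) window_right = seq_len_k;
}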
50 changes: 50 additions & 0 deletions flash-attn2/flash_attn_xpu/src/kernel/fixed_kernel.hpp
@@ -101,6 +101,7 @@ class FMHAPrefill {
static constexpr int SharedStorageSize = 0;

static constexpr bool CausalMask = CollectiveMainloop::CausalMask;
static constexpr bool LocalMask = CollectiveMainloop::LocalMask;
static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16
@@ -341,6 +342,30 @@ class FMHAPrefill {
}
}

if constexpr (LocalMask) {
const int item_id = thread_idx % SubgroupSize;
int col_idx = item_id + nblock * QK_BLK_N;

CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < FragsM; m++) {
int row_idx = m * Vec + seq_coord;
int col_ref = seq_len_kv - seq_len_qo;
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < Vec; row++) {
bool left_mask = col_idx < cute::max(0,
row + row_idx + col_ref - mainloop_params.window_left);
bool right_mask = col_idx > cute::min(seq_len_kv,
row + row_idx + col_ref + mainloop_params.window_right);
if (left_mask || right_mask) {
tSr(row, m, n) = ElementAccumulator{-INFINITY};
}
}
}
}
}

// we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big,
// prefetching it the same way as cutlass K matrix does not make sense
for(int i=0; i < size<1>(pVgV); i++) {
@@ -389,6 +414,31 @@ class FMHAPrefill {
}
}
}

if constexpr (LocalMask) {
const int item_id = thread_idx % SubgroupSize;
int col_idx = item_id + (nblock_limit - 1) * QK_BLK_N;

CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < FragsM; m++) {
int row_idx = m * Vec + seq_coord;
int col_ref = seq_len_kv - seq_len_qo;
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < Vec; row++) {
bool left_mask = col_idx < cute::max(0,
row + row_idx + col_ref - mainloop_params.window_left);
bool right_mask = col_idx > cute::min(seq_len_kv,
row + row_idx + col_ref + mainloop_params.window_right);
if (left_mask || right_mask) {
tSr(row, m, n) = ElementAccumulator{-INFINITY};
}
}
}
}
}

// Add masking for partial tiles at the end block
if (seq_len % QK_BLK_N != 0) {
const int remainder = seq_len % QK_BLK_N;
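The two LocalMask blocks above apply the same predicate at two points in the mainloop (the current K block, and the final block before the partial-tile fixup), walking the Vec x FragsM x FragsN fragment layout. A naive whole-matrix reference of that predicate, useful for checking the fragment indexing — a standalone helper under the PR's conventions, not code from the PR:

#include <vector>

// Dense reference for the local mask: keep (row, col) iff col lies in
// [row + col_ref - window_left, row + col_ref + window_right], with
// col_ref = seq_len_kv - seq_len_qo aligning the sequence ends.
std::vector<std::vector<bool>> reference_local_mask(
    int seq_len_qo, int seq_len_kv, int window_left, int window_right) {
  int col_ref = seq_len_kv - seq_len_qo;
  std::vector<std::vector<bool>> keep(
      seq_len_qo, std::vector<bool>(seq_len_kv, false));
  for (int row = 0; row < seq_len_qo; ++row)
    for (int col = 0; col < seq_len_kv; ++col)
      keep[row][col] = col >= row + col_ref - window_left &&
                       col <= row + col_ref + window_right;
  return keep;
}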
10 changes: 5 additions & 5 deletions flash-attn2/flash_attn_xpu/src/kernel/varlen_kernel.hpp
@@ -111,7 +111,7 @@ class FMHAPrefillChunk {
static constexpr bool CausalMask = CollectiveMainloop::CausalMask;
static constexpr bool LocalMask = CollectiveMainloop::LocalMask;

static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local");
// static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local");
static constexpr bool PagedKV = CollectiveMainloop::PagedKV;

static constexpr int SubgroupSize =
@@ -464,15 +464,15 @@ class FMHAPrefillChunk {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < FragsM; m++) { // 2
int row_idx = m * Vec + seq_coord;
int col_ref = seq_len_kv_cache - seq_len_qo;
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < Vec; row++) { // 8
bool left_mask =
col_idx < cute::max(0, row + row_idx + seq_len_kv_cache -
mainloop_params.window_left);
col_idx < cute::max(0,
row + row_idx + col_ref - mainloop_params.window_left);
bool right_mask =
col_idx > cute::min(seq_len_kv_cache,
row + row_idx + seq_len_kv_cache +
mainloop_params.window_right);
row + row_idx + col_ref + mainloop_params.window_right);
if (left_mask || right_mask) {
tSr(row, m, n) = ElementAccumulator{-INFINITY};
}
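The varlen change is a correction rather than new plumbing: the right-window bound previously offset the diagonal by seq_len_kv_cache alone, where the fixed-length kernel uses seq_len_kv - seq_len_qo; introducing col_ref makes the two kernels agree. A tiny worked check with hypothetical sizes:

#include <cassert>

int main() {
  // Hypothetical sizes: 2 queries against a 6-entry KV cache, window_right = 1.
  int seq_len_qo = 2, seq_len_kv_cache = 6, window_right = 1, row_idx = 1;
  int col_ref = seq_len_kv_cache - seq_len_qo;                  // 4
  int fixed_bound = row_idx + col_ref + window_right;           // 6
  int old_bound   = row_idx + seq_len_kv_cache + window_right;  // 8: too wide
  assert(fixed_bound < old_bound);  // the old bound admitted extra key columns
  return 0;
}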