Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

from flash_sparse_attn.ops.triton import (
assert_inputs,
flash_backward_postprocess,
flash_backward_preprocess,
launch_template,
launch_grid,
seqlen_info,
block_info,
mask,
flash_bwd_preprocess,
flash_bwd_postprocess,
)


Expand Down Expand Up @@ -712,7 +712,7 @@ def _flash_dense_attn_base_backward(
device=query.device,
)

flash_backward_preprocess._flash_attn_bwd_preprocess(
flash_bwd_preprocess._flash_attn_bwd_preprocess(
out=out,
dout=dout,
dpsum=dpsum,
Expand Down Expand Up @@ -793,7 +793,7 @@ def _flash_dense_attn_base_backward(
num_ctas=num_ctas,
)

flash_backward_postprocess._flash_attn_bwd_postprocess(
flash_bwd_postprocess._flash_attn_bwd_postprocess(
dq_accum=dq_accum,
dq=dq,
scale=softmax_scale,
Expand Down Expand Up @@ -901,7 +901,7 @@ def _flash_dense_attn_varlen_base_backward(
device=query.device,
)

flash_backward_preprocess._flash_attn_bwd_preprocess(
flash_bwd_preprocess._flash_attn_bwd_preprocess(
out=out,
dout=dout,
dpsum=dpsum,
Expand Down Expand Up @@ -985,7 +985,7 @@ def _flash_dense_attn_varlen_base_backward(
num_ctas=num_ctas,
)

flash_backward_postprocess._flash_attn_bwd_postprocess(
flash_bwd_postprocess._flash_attn_bwd_postprocess(
dq_accum=dq_accum,
dq=dq,
scale=softmax_scale,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

from flash_sparse_attn.ops.triton import (
assert_inputs,
flash_decode_combine,
utils,
launch_template,
launch_grid,
seqlen_info,
block_info,
activations,
mask,
flash_fwd_combine,
)


Expand Down Expand Up @@ -743,7 +743,7 @@ def _flash_dense_attn_base_forward(
)

if is_split_kv:
flash_decode_combine._flash_attn_fwd_combine(
flash_fwd_combine._flash_attn_fwd_combine(
out_partial,
lse_partial,
out,
Expand Down Expand Up @@ -894,7 +894,7 @@ def _flash_dense_attn_varlen_base_forward(
)

if is_split_kv:
flash_decode_combine._flash_attn_fwd_combine(
flash_fwd_combine._flash_attn_fwd_combine(
out_partial,
lse_partial,
out,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

from flash_sparse_attn.ops.triton import (
assert_inputs,
flash_backward_postprocess,
flash_backward_preprocess,
launch_template,
launch_grid,
seqlen_info,
block_info,
activations,
mask,
flash_bwd_preprocess,
flash_bwd_postprocess,
)


Expand Down Expand Up @@ -1101,7 +1101,7 @@ def _flash_gated_attn_base_backward(
device=query.device,
)

flash_backward_preprocess._flash_attn_bwd_preprocess(
flash_bwd_preprocess._flash_attn_bwd_preprocess(
out=out,
dout=dout,
dpsum=dpsum,
Expand Down Expand Up @@ -1198,7 +1198,7 @@ def _flash_gated_attn_base_backward(
num_ctas=num_ctas,
)

flash_backward_postprocess._flash_attn_bwd_postprocess(
flash_bwd_postprocess._flash_attn_bwd_postprocess(
dq_accum=dq_accum,
dq=dq,
da_accum=da_accum,
Expand Down Expand Up @@ -1324,7 +1324,7 @@ def _flash_gated_attn_varlen_base_backward(
device=query.device,
)

flash_backward_preprocess._flash_attn_bwd_preprocess(
flash_bwd_preprocess._flash_attn_bwd_preprocess(
out=out,
dout=dout,
dpsum=dpsum,
Expand Down Expand Up @@ -1424,7 +1424,7 @@ def _flash_gated_attn_varlen_base_backward(
num_ctas=num_ctas,
)

flash_backward_postprocess._flash_attn_bwd_postprocess(
flash_bwd_postprocess._flash_attn_bwd_postprocess(
dq_accum=dq_accum,
dq=dq,
da_accum=da_accum,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

from flash_sparse_attn.ops.triton import (
assert_inputs,
flash_decode_combine,
utils,
launch_template,
launch_grid,
seqlen_info,
block_info,
activations,
mask,
flash_fwd_combine,
)


Expand Down Expand Up @@ -1095,7 +1095,7 @@ def _flash_gated_attn_base_forward(
)

if is_split_kv:
flash_decode_combine._flash_attn_fwd_combine(
flash_fwd_combine._flash_attn_fwd_combine(
out_partial,
lse_partial,
out,
Expand Down Expand Up @@ -1266,7 +1266,7 @@ def _flash_gated_attn_varlen_base_forward(
)

if is_split_kv:
flash_decode_combine._flash_attn_fwd_combine(
flash_fwd_combine._flash_attn_fwd_combine(
out_partial,
lse_partial,
out,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

from flash_sparse_attn.ops.triton import (
assert_inputs,
flash_backward_postprocess,
flash_backward_preprocess,
launch_template,
launch_grid,
seqlen_info,
block_info,
mask,
flash_bwd_preprocess,
flash_bwd_postprocess,
)


Expand Down Expand Up @@ -773,7 +773,7 @@ def _flash_sparse_attn_base_backward(
device=query.device,
)

flash_backward_preprocess._flash_attn_bwd_preprocess(
flash_bwd_preprocess._flash_attn_bwd_preprocess(
out=out,
dout=dout,
dpsum=dpsum,
Expand Down Expand Up @@ -855,7 +855,7 @@ def _flash_sparse_attn_base_backward(
num_ctas=num_ctas,
)

flash_backward_postprocess._flash_attn_bwd_postprocess(
flash_bwd_postprocess._flash_attn_bwd_postprocess(
dq_accum=dq_accum,
dq=dq,
scale=softmax_scale,
Expand Down Expand Up @@ -965,7 +965,7 @@ def _flash_sparse_attn_varlen_base_backward(
device=query.device,
)

flash_backward_preprocess._flash_attn_bwd_preprocess(
flash_bwd_preprocess._flash_attn_bwd_preprocess(
out=out,
dout=dout,
dpsum=dpsum,
Expand Down Expand Up @@ -1050,7 +1050,7 @@ def _flash_sparse_attn_varlen_base_backward(
num_ctas=num_ctas,
)

flash_backward_postprocess._flash_attn_bwd_postprocess(
flash_bwd_postprocess._flash_attn_bwd_postprocess(
dq_accum=dq_accum,
dq=dq,
scale=softmax_scale,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

from flash_sparse_attn.ops.triton import (
assert_inputs,
flash_decode_combine,
utils,
launch_template,
launch_grid,
seqlen_info,
block_info,
activations,
mask,
flash_fwd_combine,
)


Expand Down Expand Up @@ -766,7 +766,7 @@ def _flash_sparse_attn_base_forward(
)

if is_split_kv:
flash_decode_combine._flash_attn_fwd_combine(
flash_fwd_combine._flash_attn_fwd_combine(
out_partial,
lse_partial,
out,
Expand Down Expand Up @@ -920,7 +920,7 @@ def _flash_sparse_attn_varlen_base_forward(
)

if is_split_kv:
flash_decode_combine._flash_attn_fwd_combine(
flash_fwd_combine._flash_attn_fwd_combine(
out_partial,
lse_partial,
out,
Expand Down
12 changes: 6 additions & 6 deletions flash_sparse_attn/ops/triton/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,27 @@

import torch

from flash_sparse_attn.ops.triton.flash_dense_forward import (
from flash_sparse_attn.ops.triton.flash_dense_fwd import (
_flash_dense_attn_base_forward,
_flash_dense_attn_varlen_base_forward,
)
from flash_sparse_attn.ops.triton.flash_dense_backward import (
from flash_sparse_attn.ops.triton.flash_dense_bwd import (
_flash_dense_attn_base_backward,
_flash_dense_attn_varlen_base_backward,
)
from flash_sparse_attn.ops.triton.flash_sparse_forward import (
from flash_sparse_attn.ops.triton.flash_sparse_fwd import (
_flash_sparse_attn_base_forward,
_flash_sparse_attn_varlen_base_forward,
)
from flash_sparse_attn.ops.triton.flash_sparse_backward import (
from flash_sparse_attn.ops.triton.flash_sparse_bwd import (
_flash_sparse_attn_base_backward,
_flash_sparse_attn_varlen_base_backward,
)
from flash_sparse_attn.ops.triton.flash_gated_forward import (
from flash_sparse_attn.ops.triton.flash_gated_fwd import (
_flash_gated_attn_base_forward,
_flash_gated_attn_varlen_base_forward,
)
from flash_sparse_attn.ops.triton.flash_gated_backward import (
from flash_sparse_attn.ops.triton.flash_gated_bwd import (
_flash_gated_attn_base_backward,
_flash_gated_attn_varlen_base_backward,
)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,27 @@

import torch

from flash_sparse_attn.ops.triton.flash_dense_backward import (
from flash_sparse_attn.ops.triton.flash_dense_bwd import (
_flash_dense_attn_base_backward,
_flash_dense_attn_varlen_base_backward,
)
from flash_sparse_attn.ops.triton.flash_dense_forward import (
from flash_sparse_attn.ops.triton.flash_dense_fwd import (
_flash_dense_attn_base_forward,
_flash_dense_attn_varlen_base_forward,
)
from flash_sparse_attn.ops.triton.flash_gated_backward import (
from flash_sparse_attn.ops.triton.flash_gated_bwd import (
_flash_gated_attn_base_backward,
_flash_gated_attn_varlen_base_backward,
)
from flash_sparse_attn.ops.triton.flash_gated_forward import (
from flash_sparse_attn.ops.triton.flash_gated_fwd import (
_flash_gated_attn_base_forward,
_flash_gated_attn_varlen_base_forward,
)
from flash_sparse_attn.ops.triton.flash_sparse_backward import (
from flash_sparse_attn.ops.triton.flash_sparse_bwd import (
_flash_sparse_attn_base_backward,
_flash_sparse_attn_varlen_base_backward,
)
from flash_sparse_attn.ops.triton.flash_sparse_forward import (
from flash_sparse_attn.ops.triton.flash_sparse_fwd import (
_flash_sparse_attn_base_forward,
_flash_sparse_attn_varlen_base_forward,
)
Expand Down
Loading