Commit 827fad6

leisuzz, J石页, and a-r-r-o-w authored
Improve performance of NPU FA (#12260)
Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Aryan <aryan@huggingface.co>
1 parent 9b721db commit 827fad6

File tree

1 file changed (+6, −3 lines)

src/diffusers/models/attention_dispatch.py

Lines changed: 6 additions & 3 deletions
@@ -955,12 +955,13 @@ def _native_npu_attention(
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
 ) -> torch.Tensor:
-    return npu_fusion_attention(
+    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+    out = npu_fusion_attention(
         query,
         key,
         value,
-        query.size(2),  # num_heads
-        input_layout="BSND",
+        query.size(1),  # num_heads
+        input_layout="BNSD",
         pse=None,
         scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
         pre_tockens=65536,
@@ -969,6 +970,8 @@ def _native_npu_attention(
         sync=False,
         inner_precise=0,
     )[0]
+    out = out.transpose(1, 2).contiguous()
+    return out


 # Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
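
For context, here is a minimal sketch of how _native_npu_attention reads after this commit, reconstructed only from the diff above: the fused torch_npu kernel is now fed BNSD-layout tensors (transposed and made contiguous up front), and the output is transposed back to the BSND layout the caller expects. The signature shown here is abbreviated; the module-level imports (math, torch, typing.Optional, npu_fusion_attention) are assumed from the rest of the file, and the kernel arguments that fall between the two hunks are elided.

# Sketch reconstructed from the diff above; not the verbatim file contents.
# Assumes the module already imports math, torch, Optional, and
# npu_fusion_attention (from torch_npu), as the unchanged context suggests.
def _native_npu_attention(
    query: torch.Tensor,   # BSND layout: (batch, seq_len, num_heads, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    # (signature abbreviated; other parameters in the file are omitted here)
) -> torch.Tensor:
    # Move to BNSD layout (batch, num_heads, seq_len, head_dim) before the fused kernel.
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    out = npu_fusion_attention(
        query,
        key,
        value,
        query.size(1),  # num_heads: dimension 1 after the transpose
        input_layout="BNSD",
        pse=None,
        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
        pre_tockens=65536,
        # ... remaining kernel arguments are unchanged and not shown in this diff ...
        sync=False,
        inner_precise=0,
    )[0]
    # Transpose back to BSND so callers see the same layout as before this change.
    out = out.transpose(1, 2).contiguous()
    return out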
