From 9e7fafd5fc15d62b7e1b9c4bc9867a83dfb5cb2b Mon Sep 17 00:00:00 2001 From: pftq Date: Sat, 27 Sep 2025 21:36:40 -0700 Subject: [PATCH] Update attention.py Fixed TypeError "only integer tensors of a single element can be converted to an index" when run on flash_attn_2 with version 2.8.3 and B200 GPU. --- wanvideo/modules/attention.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/wanvideo/modules/attention.py b/wanvideo/modules/attention.py index b923419a..4942fabc 100644 --- a/wanvideo/modules/attention.py +++ b/wanvideo/modules/attention.py @@ -101,8 +101,13 @@ def half(x): [lk] * b, dtype=torch.int32).to( device=k.device, non_blocking=True) else: - k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) - v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + # 20250927 pftq: Convert k_lens to a single tensor to avoid 'list' object has no attribute 'new_zeros' and optimize performance + k_lens_tensor = torch.tensor(k_lens, dtype=torch.int32, device=k.device) if not isinstance(k_lens, torch.Tensor) else k_lens.to(dtype=torch.int32, device=k.device) + k = half(torch.cat([u[:v] for u, v in zip(k, k_lens_tensor)])) # Use tensor directly + #k = half(torch.cat([u[:v.item()] for u, v in zip(k, k_lens)])) # original line + v = half(torch.cat([u[:v] for u, v in zip(v, k_lens_tensor)])) # Use tensor directly + #v = half(torch.cat([u[:v.item()] for u, v in zip(v, k_lens)])) # original line + k_lens = k_lens_tensor # 20250927 pftq: Reuse k_lens_tensor for cu_seqlens_k q = q.to(v.dtype) k = k.to(v.dtype)