Commit 3bb031b

sxuagrima1304 authored and committed
Call .detach() in static attention cache update helper
Differential Revision: D80853817 Pull Request resolved: pytorch#13618
1 parent 9c0280c commit 3bb031b

File tree

1 file changed: +2 −2 lines


examples/models/llama/static_attention.py

Lines changed: 2 additions & 2 deletions

@@ -549,7 +549,7 @@ def _update_states(self, attn_updates, update_pos, update_len):
                 style=self.style,
                 update_pos=update_pos,
                 update_len=update_len,
-            )
+            ).detach()
         for cache_id, update in v_cache_updates.items():
             self.v_caches[cache_id] = StaticKVCache.apply_update(
                 self.v_caches[cache_id],
@@ -558,7 +558,7 @@ def _update_states(self, attn_updates, update_pos, update_len):
                 style=self.style,
                 update_pos=update_pos,
                 update_len=update_len,
-            )
+            ).detach()
         self.pos += update_len

     def _get_lookahead_decoding_mask(
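The change is small but consequential: a tensor produced by an autograd-tracked operation keeps a reference to the graph that created it, so storing an undetached result in a long-lived cache attribute retains that graph across decode steps. The sketch below illustrates the effect with a hypothetical stand-in for the cache update (it is not the ExecuTorch `StaticKVCache.apply_update` helper, just a concatenate-and-trim placeholder):

```python
import torch

cache = torch.zeros(1, 4)
update = torch.ones(1, 4, requires_grad=True)

# Hypothetical stand-in for a cache update: append the new entries,
# then keep only the most recent positions.
new_cache = torch.cat([cache, update], dim=1)[:, -4:]
assert new_cache.grad_fn is not None  # still tied to the autograd graph

# Detaching breaks the link, so the stored cache tensor no longer
# pins the graph (and its intermediates) in memory between steps.
detached = new_cache.detach()
assert detached.grad_fn is None
assert not detached.requires_grad
```

Wrapping the update in `torch.no_grad()` would avoid building the graph in the first place; calling `.detach()` on the result, as this commit does, achieves the same end for the stored tensor with a one-line change at each call site.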

0 commit comments
