Skip to content

Commit 2a25807

Browse files
authored
Revert "Skip size calculation during async copy wait=True" (#1148)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 6f311b9 commit 2a25807

File tree

2 files changed

+108
-154
lines changed

2 files changed

+108
-154
lines changed

tpu_inference/kernels/ragged_paged_attention/v3/kernel.py

Lines changed: 54 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -440,54 +440,42 @@ def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
440440
debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
441441
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
442442

443-
if not wait:
444-
# Fetch effective kv from kv cache.
445-
def loop_body(i, offset):
446-
sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
447-
_async_copy(
448-
cache_hbm_ref.at[pl.ds(
449-
page_indices_ref[page_indices_offset + i] * page_size,
450-
sz)],
451-
vmem_ref.at[pl.ds(i * page_size, sz)],
452-
sem,
453-
wait=False,
454-
)
455-
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
456-
return offset + sz
457-
458-
offset = lax.fori_loop(
459-
0,
460-
bkv_p_frm_cache,
461-
loop_body,
462-
0, # offset
463-
unroll=False,
443+
# Fetch effective kv from kv cache.
444+
def loop_body(i, offset):
445+
sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
446+
_async_copy(
447+
cache_hbm_ref.at[pl.ds(
448+
page_indices_ref[page_indices_offset + i] * page_size,
449+
sz)],
450+
vmem_ref.at[pl.ds(i * page_size, sz)],
451+
sem,
452+
wait,
464453
)
454+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
455+
return offset + sz
456+
457+
offset = lax.fori_loop(
458+
0,
459+
bkv_p_frm_cache,
460+
loop_body,
461+
0, # offset
462+
unroll=False,
463+
)
465464

466-
# Fetch kv directly from new kv.
467-
@pl.when(bkv_sz_frm_new > 0)
468-
def _fetch_bkv_from_new_kv():
469-
new_kv_len_start = q_end - kv_left_frm_new
470-
debug_print("[RPA debug] new_kv_len_start={}",
471-
new_kv_len_start)
472-
debug_print("[RPA debug] offset_in_bkv={}", offset)
473-
_async_copy(
474-
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
475-
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
476-
sem,
477-
wait,
478-
)
479-
480-
return kv_len_start + offset, bkv_sz_frm_new
481-
else:
482-
offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
483-
dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
465+
# Fetch kv directly from new kv.
466+
@pl.when(bkv_sz_frm_new > 0)
467+
def _fetch_bkv_from_new_kv():
468+
new_kv_len_start = q_end - kv_left_frm_new
469+
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
470+
debug_print("[RPA debug] offset_in_bkv={}", offset)
484471
_async_copy(
485-
src=dst,
486-
dst=dst,
487-
sem=sem,
488-
wait=True,
472+
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
473+
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
474+
sem,
475+
wait,
489476
)
490-
return kv_len_start + offset, bkv_sz_frm_new
477+
478+
return kv_len_start + offset, bkv_sz_frm_new
491479

492480
def _update_kv_cache(seq_idx,
493481
bkv_sem_idx,
@@ -523,41 +511,30 @@ def _update_kv_cache(seq_idx,
523511
debug_print("[RPA debug] p_ignore={}", p_ignore)
524512
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
525513

526-
if not wait:
527-
528-
def loop_body(i, states):
529-
update_sz, ignore = states
530-
sz = jnp.minimum(page_size - ignore, update_sz)
531-
532-
_async_copy(
533-
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
534-
sz)],
535-
cache_hbm_ref.at[pl.ds(
536-
page_indices_ref[page_indices_offset + i] * page_size +
537-
ignore,
538-
sz,
539-
)],
540-
sem,
541-
wait=False,
542-
)
543-
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
544-
return update_sz - sz, 0
545-
546-
lax.fori_loop(
547-
0,
548-
kv_p_end - kv_p_start,
549-
loop_body,
550-
(update_sz, ignore), # total transfer size
551-
unroll=False,
552-
)
553-
else:
554-
dst = cache_hbm_ref.at[pl.ds(0, update_sz)],
514+
def loop_body(i, states):
515+
update_sz, ignore = states
516+
sz = jnp.minimum(page_size - ignore, update_sz)
517+
555518
_async_copy(
556-
src=dst,
557-
dst=dst,
558-
sem=sem,
559-
wait=True,
519+
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
520+
cache_hbm_ref.at[pl.ds(
521+
page_indices_ref[page_indices_offset + i] * page_size +
522+
ignore,
523+
sz,
524+
)],
525+
sem,
526+
wait,
560527
)
528+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
529+
return update_sz - sz, 0
530+
531+
lax.fori_loop(
532+
0,
533+
kv_p_end - kv_p_start,
534+
loop_body,
535+
(update_sz, ignore), # total transfer size
536+
unroll=False,
537+
)
561538

562539
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
563540
sem = sems.at[1, bq_sem_idx]

tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py

Lines changed: 54 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -475,54 +475,42 @@ def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
475475
debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
476476
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
477477

478-
if not wait:
479-
# Fetch effective kv from kv cache.
480-
def loop_body(i, offset):
481-
sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
482-
_async_copy(
483-
cache_hbm_ref.at[pl.ds(
484-
page_indices_ref[page_indices_offset + i] * page_size,
485-
sz)],
486-
vmem_ref.at[pl.ds(i * page_size, sz)],
487-
sem,
488-
wait=False,
489-
)
490-
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
491-
return offset + sz
492-
493-
offset = lax.fori_loop(
494-
0,
495-
bkv_p_frm_cache,
496-
loop_body,
497-
0, # offset
498-
unroll=False,
478+
# Fetch effective kv from kv cache.
479+
def loop_body(i, offset):
480+
sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
481+
_async_copy(
482+
cache_hbm_ref.at[pl.ds(
483+
page_indices_ref[page_indices_offset + i] * page_size,
484+
sz)],
485+
vmem_ref.at[pl.ds(i * page_size, sz)],
486+
sem,
487+
wait,
499488
)
489+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
490+
return offset + sz
491+
492+
offset = lax.fori_loop(
493+
0,
494+
bkv_p_frm_cache,
495+
loop_body,
496+
0, # offset
497+
unroll=False,
498+
)
500499

501-
# Fetch kv directly from new kv.
502-
@pl.when(bkv_sz_frm_new > 0)
503-
def _fetch_bkv_from_new_kv():
504-
new_kv_len_start = q_end - kv_left_frm_new
505-
debug_print("[RPA debug] new_kv_len_start={}",
506-
new_kv_len_start)
507-
debug_print("[RPA debug] offset_in_bkv={}", offset)
508-
_async_copy(
509-
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
510-
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
511-
sem,
512-
wait,
513-
)
514-
515-
return kv_len_start + offset, bkv_sz_frm_new
516-
else:
517-
offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
518-
dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
500+
# Fetch kv directly from new kv.
501+
@pl.when(bkv_sz_frm_new > 0)
502+
def _fetch_bkv_from_new_kv():
503+
new_kv_len_start = q_end - kv_left_frm_new
504+
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
505+
debug_print("[RPA debug] offset_in_bkv={}", offset)
519506
_async_copy(
520-
src=dst,
521-
dst=dst,
522-
sem=sem,
523-
wait=True,
507+
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
508+
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
509+
sem,
510+
wait,
524511
)
525-
return kv_len_start + offset, bkv_sz_frm_new
512+
513+
return kv_len_start + offset, bkv_sz_frm_new
526514

527515
def _update_kv_cache(seq_idx,
528516
bkv_sem_idx,
@@ -558,41 +546,30 @@ def _update_kv_cache(seq_idx,
558546
debug_print("[RPA debug] p_ignore={}", p_ignore)
559547
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
560548

561-
if not wait:
562-
563-
def loop_body(i, states):
564-
update_sz, ignore = states
565-
sz = jnp.minimum(page_size - ignore, update_sz)
566-
567-
_async_copy(
568-
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
569-
sz)],
570-
cache_hbm_ref.at[pl.ds(
571-
page_indices_ref[page_indices_offset + i] * page_size +
572-
ignore,
573-
sz,
574-
)],
575-
sem,
576-
wait=False,
577-
)
578-
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
579-
return update_sz - sz, 0
580-
581-
lax.fori_loop(
582-
0,
583-
kv_p_end - kv_p_start,
584-
loop_body,
585-
(update_sz, ignore), # total transfer size
586-
unroll=False,
587-
)
588-
else:
589-
dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
549+
def loop_body(i, states):
550+
update_sz, ignore = states
551+
sz = jnp.minimum(page_size - ignore, update_sz)
552+
590553
_async_copy(
591-
src=dst,
592-
dst=dst,
593-
sem=sem,
594-
wait=True,
554+
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
555+
cache_hbm_ref.at[pl.ds(
556+
page_indices_ref[page_indices_offset + i] * page_size +
557+
ignore,
558+
sz,
559+
)],
560+
sem,
561+
wait,
595562
)
563+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
564+
return update_sz - sz, 0
565+
566+
lax.fori_loop(
567+
0,
568+
kv_p_end - kv_p_start,
569+
loop_body,
570+
(update_sz, ignore), # total transfer size
571+
unroll=False,
572+
)
596573

597574
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
598575
sem = sems.at[1, bq_sem_idx]

0 commit comments

Comments
 (0)