From d3cc2b932b274df3d7b9eb9c78aeb8fddb8492e5 Mon Sep 17 00:00:00 2001 From: poursoul Date: Fri, 13 Feb 2026 17:20:33 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BF=AE=E5=A4=8D64batch=E5=8D=A1=E6=AD=BB?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98=EF=BC=8C=E6=9A=82=E6=97=B6=E5=B0=86?= =?UTF-8?q?window=E5=A4=A7=E5=B0=8F=E6=94=B9=E6=88=9065536=EF=BC=8C?= =?UTF-8?q?=E5=90=8E=E9=9D=A2=E5=86=8D=E6=94=AF=E6=8C=81task=20ring?= =?UTF-8?q?=E7=9A=84=E7=8E=AF=E5=BD=A2=E5=88=86=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../aicpu/aicpu_executor.cpp | 3 ++- .../orchestration/tensor_orch.cpp | 24 +++++++++---------- .../paged_attention/golden.py | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp b/src/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp index 66a7463b..4ce464e4 100644 --- a/src/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp +++ b/src/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp @@ -22,6 +22,7 @@ // Runtime headers (full struct definition for create/destroy + PTO2_SCOPE) #include "pto_runtime2.h" #include "pto_shared_memory.h" +#include "pto_runtime2_types.h" // Performance profiling headers #include "common/perf_profiling.h" @@ -136,7 +137,7 @@ struct AicpuExecutor { static AicpuExecutor g_aicpu_executor; // PTO2 device-mode state (shared memory view + per-task fanin refcount) -static constexpr int PTO2_MAX_SLOTS = 16384; +static constexpr int PTO2_MAX_SLOTS = PTO2_TASK_WINDOW_SIZE; static int s_pto2_fanin_refcount[PTO2_MAX_SLOTS]; static volatile int32_t s_pto2_task_completed[PTO2_MAX_SLOTS]; static PTO2DispatchPayload s_pto2_payload_per_core[RUNTIME_MAX_WORKER]; diff --git a/src/runtime/tensormap_and_ringbuffer/orchestration/tensor_orch.cpp b/src/runtime/tensormap_and_ringbuffer/orchestration/tensor_orch.cpp index 31c55a08..dfe84df0 100644 --- 
a/src/runtime/tensormap_and_ringbuffer/orchestration/tensor_orch.cpp +++ b/src/runtime/tensormap_and_ringbuffer/orchestration/tensor_orch.cpp @@ -134,20 +134,20 @@ void Tensor::resort_strides() { } Tensor& Tensor::optimize() { -#ifndef NDEBUG - uint64_t original_strides[RUNTIME_MAX_TENSOR_DIMS]; - uint64_t original_repeats[RUNTIME_MAX_TENSOR_DIMS]; - int32_t original_ndims = ndims; - for (uint64_t i = 0; i < ndims; i++) { - original_strides[i] = this->strides[i]; - original_repeats[i] = this->repeats[i]; - } -#endif +// #ifndef NDEBUG +// uint64_t original_strides[RUNTIME_MAX_TENSOR_DIMS]; +// uint64_t original_repeats[RUNTIME_MAX_TENSOR_DIMS]; +// int32_t original_ndims = ndims; +// for (uint64_t i = 0; i < ndims; i++) { +// original_strides[i] = this->strides[i]; +// original_repeats[i] = this->repeats[i]; +// } +// #endif resort_strides(); -#ifndef NDEBUG - debug_assert(validate_memory_access_preserved(original_strides, original_repeats, original_ndims)); -#endif +// #ifndef NDEBUG +// debug_assert(validate_memory_access_preserved(original_strides, original_repeats, original_ndims)); +// #endif return *this; } diff --git a/tests/device_tests/tensormap_and_ringbuffer/paged_attention/golden.py b/tests/device_tests/tensormap_and_ringbuffer/paged_attention/golden.py index 2ae2bf0a..61ab6a1d 100644 --- a/tests/device_tests/tensormap_and_ringbuffer/paged_attention/golden.py +++ b/tests/device_tests/tensormap_and_ringbuffer/paged_attention/golden.py @@ -27,7 +27,7 @@ # All test cases - production scale ALL_CASES = { "Case1": { - "batch": 16, + "batch": 64, "num_heads": 16, "kv_head_num": 1, "head_dim": 128, From d372c32272e49cd3e385aa0570aaf24e2f88cc7f Mon Sep 17 00:00:00 2001 From: poursoul Date: Sat, 14 Feb 2026 14:40:46 +0800 Subject: [PATCH 2/2] Update: optimize PTO2 scheduler and orchestrator performance - Replace std::mutex with spinlocks and bitmask circular queues in AICPU scheduler ready queues; remove redundant atomic counters - Add orchestrator ready queue 
(SPSC ring) to push early-return ready tasks directly, replacing O(N) readiness scan in scheduler - Stack-allocate tensormap lookup results (PTO2LookupResult) to avoid per-lookup heap allocation from std::vector - Inline Tensor copy ctor/operator= into header; skip memset on task ring slot allocation; bulk memcpy params in submit_task - Add PTO2_SPIN_PAUSE_LIGHT() (yield without sched_yield) for tight spinloops on aarch64/x86_64 - Add DEV_ALWAYS log level for diagnostics; change diagnostic reports from DEV_ERROR to DEV_ALWAYS, per-task logs from DEV_INFO to DEV_DEBUG - Add profiling instrumentation to scheduler, orchestrator, and host runtime init - Vectorize paged attention golden computation across batch dimension --- .../paged_attention/golden.py | 156 ++++++---- examples/scripts/code_runner.py | 10 + .../paged_attention/golden.py | 143 +++++---- src/platform/a2a3/aicpu/device_log.cpp | 9 + src/platform/a2a3sim/aicpu/device_log.cpp | 9 + src/platform/include/aicpu/device_log.h | 2 + .../aicpu/aicpu_executor.cpp | 30 +- .../aicpu/aicpu_executor.cpp | 288 +++++++++++------- .../host/runtime_maker.cpp | 26 ++ .../orchestration/tensor_orch.cpp | 27 +- .../runtime/pto_orchestrator.cpp | 121 +++++++- .../runtime/pto_orchestrator.h | 41 +++ .../runtime/pto_ring_buffer.cpp | 4 +- .../runtime/pto_runtime2_types.h | 4 +- .../runtime/pto_tensormap.cpp | 10 +- .../runtime/pto_tensormap.h | 23 +- .../tensormap_and_ringbuffer/runtime/tensor.h | 28 +- .../paged_attention/golden.py | 139 +++++---- .../paged_attention/golden.py | 137 +++++---- .../paged_attention/golden.py | 137 +++++---- 20 files changed, 872 insertions(+), 472 deletions(-) diff --git a/examples/host_build_graph/paged_attention/golden.py b/examples/host_build_graph/paged_attention/golden.py index 08519066..7ea48a54 100644 --- a/examples/host_build_graph/paged_attention/golden.py +++ b/examples/host_build_graph/paged_attention/golden.py @@ -120,10 +120,13 @@ def paged_attention( """ Compute paged attention using 
online softmax with head tiling and GQA. + Vectorized across the batch dimension for performance. + Supports different context_lens per batch via masking. + Args: - query: (batch, num_heads, head_dim) float16 - key_cache: (total_blocks, block_size, num_kv_heads, head_dim) float16 - value_cache: (total_blocks, block_size, num_kv_heads, head_dim) float16 + query: (batch, num_heads, head_dim) bfloat16 + key_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 + value_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 num_kv_heads: int num_heads: int scale_value: float @@ -131,65 +134,87 @@ def paged_attention( context_lens: (batch,) int32 Returns: - out: (batch, num_heads, head_dim) float32 + out: (batch * num_heads, head_dim) float32 """ assert num_kv_heads == 1 - batch, num_heads, head_dim = query.shape + batch, num_heads_dim, head_dim = query.shape _, block_size, _, _ = key_cache.shape - _, block_num = block_table.shape - - query = query.reshape(-1, head_dim) - key_cache = key_cache.reshape(-1, block_size, head_dim) - value_cache = value_cache.reshape(-1, block_size, head_dim) - - out = torch.zeros((batch * num_heads, head_dim), dtype=torch.float32) - - for b_idx in range(batch): - cur_seq = int(context_lens[b_idx]) - bn_this_batch = (cur_seq + block_size - 1) // block_size - assert bn_this_batch <= block_num - - q_tile = min(num_heads, 128) - for cur_offset in range(0, num_heads, q_tile): - q_tile_size = min(q_tile, num_heads - cur_offset) - base_idx = b_idx * num_heads + cur_offset - qi = query[base_idx : base_idx + q_tile_size].to(torch.float32) - - oi = None - li = None - mi = None - - for bn in range(bn_this_batch): - cur_block_idx = block_table[b_idx, bn] - valid_len = min(block_size, cur_seq - bn * block_size) - kj = key_cache[cur_block_idx, :valid_len, :].to(torch.float32) - vj = value_cache[cur_block_idx, :valid_len, :].to(torch.float32) - - sij = (qi @ kj.T) * scale_value - mij = sij.max(dim=-1, keepdim=True)[0] - pij = 
torch.exp(sij - mij).to(torch.float16).to(torch.float32) - lij = torch.sum(pij, dim=1, keepdim=True) - - if bn == 0: - oi = pij @ vj - li = lij - mi = mij - else: - mi_new = torch.maximum(mi, mij) - alpha = torch.exp(mi - mi_new) - beta = torch.exp(mij - mi_new) - li_new = alpha * li + beta * lij - oi_new = pij @ vj - oi = alpha * oi + beta * oi_new - li = li_new - mi = mi_new - - if bn == bn_this_batch - 1: - oi = oi / li - - out[base_idx : base_idx + q_tile_size] = oi - - return out + + # Reshape for batched computation + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + + q_tile = min(num_heads_dim, 128) + + # Max blocks across all batches (each batch may have different context_len) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + # qi: (batch, q_tile_size, head_dim) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + + oi = None # (batch, q_tile_size, head_dim) + li = None # (batch, q_tile_size, 1) + mi = None # (batch, q_tile_size, 1) + + for bn in range(max_bn): + # valid_len per batch for this block position + valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 # (batch,) + + if not active_mask.any(): + break + + # Gather block indices for all batches + block_indices = block_table[:, bn] # (batch,) + + # Gather K and V: (batch, block_size, head_dim) + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + + # QK matmul: (batch, q_tile_size, block_size) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + + # Mask out invalid positions (beyond valid_len per batch) + pos = torch.arange(block_size, 
device=sij.device).unsqueeze(0) # (1, block_size) + valid_mask = pos < valid_lens.unsqueeze(1) # (batch, block_size) + valid_mask = valid_mask.unsqueeze(1) # (batch, 1, block_size) + sij = sij.masked_fill(~valid_mask, float('-inf')) + + # Also mask inactive batches (no blocks at this position) + batch_mask = active_mask.view(-1, 1, 1) # (batch, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + + mij = sij.max(dim=-1, keepdim=True)[0] # (batch, q_tile_size, 1) + mij = mij.clamp(min=-1e30) + pij = torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) # (batch, q_tile_size, 1) + + # PV matmul: (batch, q_tile_size, head_dim) + oi_new = torch.bmm(pij, vj_all) + + if bn == 0: + oi = oi_new + li = lij + mi = mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + + # Final normalization + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + + return out.reshape(-1, head_dim) def compute_golden(tensors: dict, params: dict) -> None: @@ -203,13 +228,12 @@ def compute_golden(tensors: dict, params: dict) -> None: max_num_blocks_per_req = max_model_len // block_size - # Reconstruct shaped arrays from flat float16 tensors - # Convert to torch tensors (handles both array types) - query = torch.as_tensor(tensors["query"]).reshape(batch, num_heads, head_dim) - key_cache = torch.as_tensor(tensors["key_cache"]).reshape(-1, block_size, kv_head_num, head_dim) - value_cache = torch.as_tensor(tensors["value_cache"]).reshape(-1, block_size, kv_head_num, head_dim) - block_table = torch.as_tensor(tensors["block_table"]).reshape(batch, max_num_blocks_per_req) - context_lens = torch.as_tensor(tensors["context_lens"]) + # Reconstruct shaped tensors from flat tensors + query = tensors["query"].reshape(batch, 
num_heads, head_dim) + key_cache = tensors["key_cache"].reshape(-1, block_size, kv_head_num, head_dim) + value_cache = tensors["value_cache"].reshape(-1, block_size, kv_head_num, head_dim) + block_table = tensors["block_table"].reshape(batch, max_num_blocks_per_req) + context_lens = tensors["context_lens"] out = paged_attention( query=query, diff --git a/examples/scripts/code_runner.py b/examples/scripts/code_runner.py index 06530be0..3870b0fb 100644 --- a/examples/scripts/code_runner.py +++ b/examples/scripts/code_runner.py @@ -36,6 +36,7 @@ def compute_golden(tensors: dict, params: dict) -> None: import logging import os import sys +import time from contextlib import contextmanager from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -629,6 +630,7 @@ def run(self) -> None: runtime.enable_profiling(True) logger.info("Profiling enabled") + _t_init_start = time.perf_counter() with _temporary_env(run_env): runtime.initialize( orch_so_binary, @@ -638,12 +640,20 @@ def run(self) -> None: arg_sizes=arg_sizes, kernel_binaries=kernel_binaries, ) + _t_init_end = time.perf_counter() + logger.info(f">>> runtime.initialize() took {_t_init_end - _t_init_start:.3f}s") # Save expected values BEFORE hardware execution (outputs will be overwritten) golden = {k: v.clone() for k, v in outputs.items()} # Convert to dict for compute_golden (may expect numpy-like interface) golden_with_inputs = {**inputs, **golden} + _t_golden_start = time.perf_counter() self._golden_module.compute_golden(golden_with_inputs, params) + _t_golden_end = time.perf_counter() + logger.info(f">>> compute_golden() took {_t_golden_end - _t_golden_start:.3f}s") + logger.info(f">>> Total init-to-launch: {_t_golden_end - _t_init_start:.3f}s " + f"(initialize={_t_init_end - _t_init_start:.3f}s, " + f"golden={_t_golden_end - _t_golden_start:.3f}s)") # Launch runtime logger.info("=== Launching Runtime ===") diff --git a/examples/tensormap_and_ringbuffer/paged_attention/golden.py 
b/examples/tensormap_and_ringbuffer/paged_attention/golden.py index 96293385..cb02beb1 100644 --- a/examples/tensormap_and_ringbuffer/paged_attention/golden.py +++ b/examples/tensormap_and_ringbuffer/paged_attention/golden.py @@ -117,10 +117,13 @@ def paged_attention( """ Compute paged attention using online softmax with head tiling and GQA. + Vectorized across the batch dimension for performance. + Supports different context_lens per batch via masking. + Args: - query: (batch, num_heads, head_dim) float16 - key_cache: (total_blocks, block_size, num_kv_heads, head_dim) float16 - value_cache: (total_blocks, block_size, num_kv_heads, head_dim) float16 + query: (batch, num_heads, head_dim) bfloat16 + key_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 + value_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 num_kv_heads: int num_heads: int scale_value: float @@ -128,65 +131,87 @@ def paged_attention( context_lens: (batch,) int32 Returns: - out: (batch, num_heads, head_dim) float32 + out: (batch * num_heads, head_dim) float32 """ assert num_kv_heads == 1 - batch, num_heads, head_dim = query.shape + batch, num_heads_dim, head_dim = query.shape _, block_size, _, _ = key_cache.shape - _, block_num = block_table.shape - - query = query.reshape(-1, head_dim) - key_cache = key_cache.reshape(-1, block_size, head_dim) - value_cache = value_cache.reshape(-1, block_size, head_dim) - - out = torch.zeros((batch * num_heads, head_dim), dtype=torch.float32) - - for b_idx in range(batch): - cur_seq = int(context_lens[b_idx]) - bn_this_batch = (cur_seq + block_size - 1) // block_size - assert bn_this_batch <= block_num - - q_tile = min(num_heads, 128) - for cur_offset in range(0, num_heads, q_tile): - q_tile_size = min(q_tile, num_heads - cur_offset) - base_idx = b_idx * num_heads + cur_offset - qi = query[base_idx : base_idx + q_tile_size].to(torch.float32) - - oi = None - li = None - mi = None - - for bn in range(bn_this_batch): - cur_block_idx = 
block_table[b_idx, bn] - valid_len = min(block_size, cur_seq - bn * block_size) - kj = key_cache[cur_block_idx, :valid_len, :].to(torch.float32) - vj = value_cache[cur_block_idx, :valid_len, :].to(torch.float32) - - sij = (qi @ kj.T) * scale_value - mij = sij.max(dim=-1, keepdim=True)[0] - pij = torch.exp(sij - mij).to(torch.float16).to(torch.float32) - lij = pij.sum(dim=1, keepdim=True) - - if bn == 0: - oi = pij @ vj - li = lij - mi = mij - else: - mi_new = torch.maximum(mi, mij) - alpha = torch.exp(mi - mi_new) - beta = torch.exp(mij - mi_new) - li_new = alpha * li + beta * lij - oi_new = pij @ vj - oi = alpha * oi + beta * oi_new - li = li_new - mi = mi_new - - if bn == bn_this_batch - 1: - oi = oi / li - - out[base_idx : base_idx + q_tile_size] = oi - - return out + + # Reshape for batched computation + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + + q_tile = min(num_heads_dim, 128) + + # Max blocks across all batches (each batch may have different context_len) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + # qi: (batch, q_tile_size, head_dim) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + + oi = None # (batch, q_tile_size, head_dim) + li = None # (batch, q_tile_size, 1) + mi = None # (batch, q_tile_size, 1) + + for bn in range(max_bn): + # valid_len per batch for this block position + valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 # (batch,) + + if not active_mask.any(): + break + + # Gather block indices for all batches + block_indices = block_table[:, bn] # (batch,) + + # Gather K and V: (batch, block_size, head_dim) + kj_all = key_cache_flat[block_indices].to(torch.float32) 
+ vj_all = value_cache_flat[block_indices].to(torch.float32) + + # QK matmul: (batch, q_tile_size, block_size) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + + # Mask out invalid positions (beyond valid_len per batch) + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) # (1, block_size) + valid_mask = pos < valid_lens.unsqueeze(1) # (batch, block_size) + valid_mask = valid_mask.unsqueeze(1) # (batch, 1, block_size) + sij = sij.masked_fill(~valid_mask, float('-inf')) + + # Also mask inactive batches (no blocks at this position) + batch_mask = active_mask.view(-1, 1, 1) # (batch, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + + mij = sij.max(dim=-1, keepdim=True)[0] # (batch, q_tile_size, 1) + mij = mij.clamp(min=-1e30) + pij = torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) # (batch, q_tile_size, 1) + + # PV matmul: (batch, q_tile_size, head_dim) + oi_new = torch.bmm(pij, vj_all) + + if bn == 0: + oi = oi_new + li = lij + mi = mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + + # Final normalization + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + + return out.reshape(-1, head_dim) def compute_golden(tensors: dict, params: dict) -> None: diff --git a/src/platform/a2a3/aicpu/device_log.cpp b/src/platform/a2a3/aicpu/device_log.cpp index db92a7a2..b3c347a4 100644 --- a/src/platform/a2a3/aicpu/device_log.cpp +++ b/src/platform/a2a3/aicpu/device_log.cpp @@ -61,3 +61,12 @@ void dev_log_error(const char* func, const char* fmt, ...) { va_end(args); dlog_error(AICPU, "%lu %s\n\"%s\"", GET_TID(), func, buffer); } + +void dev_log_always(const char* func, const char* fmt, ...) 
{ + va_list args; + va_start(args, fmt); + char buffer[2048]; + vsnprintf(buffer, sizeof(buffer), fmt, args); + va_end(args); + dlog_error(AICPU, "%lu %s\n\"%s\"", GET_TID(), func, buffer); +} diff --git a/src/platform/a2a3sim/aicpu/device_log.cpp b/src/platform/a2a3sim/aicpu/device_log.cpp index b7993152..b785849a 100644 --- a/src/platform/a2a3sim/aicpu/device_log.cpp +++ b/src/platform/a2a3sim/aicpu/device_log.cpp @@ -120,3 +120,12 @@ void dev_log_error(const char* func, const char* fmt, ...) { printf("\n"); va_end(args); } + +void dev_log_always(const char* func, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + printf("[ALWAYS] %s: ", func); + vprintf(fmt, args); + printf("\n"); + va_end(args); +} diff --git a/src/platform/include/aicpu/device_log.h b/src/platform/include/aicpu/device_log.h index c84cd061..fd21c685 100644 --- a/src/platform/include/aicpu/device_log.h +++ b/src/platform/include/aicpu/device_log.h @@ -51,6 +51,7 @@ void dev_log_debug(const char* func, const char* fmt, ...); void dev_log_info(const char* func, const char* fmt, ...); void dev_log_warn(const char* func, const char* fmt, ...); void dev_log_error(const char* func, const char* fmt, ...); +void dev_log_always(const char* func, const char* fmt, ...); // ============================================================================= // High-Level Logging Macros (Platform-Independent Layer) @@ -92,6 +93,7 @@ void dev_log_error(const char* func, const char* fmt, ...); #define DEV_INFO(fmt, args...) D_DEV_LOGI(TILE_FWK_DEVICE_MACHINE, fmt, ##args) #define DEV_WARN(fmt, args...) D_DEV_LOGW(TILE_FWK_DEVICE_MACHINE, fmt, ##args) #define DEV_ERROR(fmt, args...) D_DEV_LOGE(TILE_FWK_DEVICE_MACHINE, fmt, ##args) +#define DEV_ALWAYS(fmt, args...) 
dev_log_always(__FUNCTION__, fmt, ##args) // ============================================================================= // Platform-Specific Assertion diff --git a/src/runtime/aicpu_build_graph/aicpu/aicpu_executor.cpp b/src/runtime/aicpu_build_graph/aicpu/aicpu_executor.cpp index 9f02abd1..0fef4011 100644 --- a/src/runtime/aicpu_build_graph/aicpu/aicpu_executor.cpp +++ b/src/runtime/aicpu_build_graph/aicpu/aicpu_executor.cpp @@ -825,11 +825,11 @@ void AicpuExecutor::deinit() { void AicpuExecutor::diagnose_stuck_state( Runtime& runtime, int thread_idx, const int* cur_thread_cores, int core_num, Handshake* hank) { - DEV_ERROR("========== DIAGNOSTIC REPORT: Thread %d ==========", thread_idx); + DEV_ALWAYS("========== DIAGNOSTIC REPORT: Thread %d ==========", thread_idx); int completed = completed_tasks_.load(std::memory_order_acquire); int published = published_tasks_.load(std::memory_order_acquire); - DEV_ERROR("Progress: completed=%d published=%d build_done=%d build_failed=%d", + DEV_ALWAYS("Progress: completed=%d published=%d build_done=%d build_failed=%d", completed, published, build_done_.load(std::memory_order_acquire) ? 
1 : 0, @@ -837,13 +837,13 @@ void AicpuExecutor::diagnose_stuck_state( int aic_ready = ready_count_aic_.load(std::memory_order_acquire); int aiv_ready = ready_count_aiv_.load(std::memory_order_acquire); - DEV_ERROR("Ready Queues: AIC=%d, AIV=%d", aic_ready, aiv_ready); + DEV_ALWAYS("Ready Queues: AIC=%d, AIV=%d", aic_ready, aiv_ready); int busy_cores = 0; int idle_cores = 0; int anomaly_cores = 0; - DEV_ERROR("Core Status:"); + DEV_ALWAYS("Core Status:"); for (int i = 0; i < core_num; i++) { int core_id = cur_thread_cores[i]; Handshake* h = &hank[core_id]; @@ -854,7 +854,7 @@ void AicpuExecutor::diagnose_stuck_state( Task* task = reinterpret_cast(h->task); busy_cores++; - DEV_ERROR(" Core %d [%s, BUSY]: task_id=%d, func_id=%d, fanin=%d, fanout=%d", + DEV_ALWAYS(" Core %d [%s, BUSY]: task_id=%d, func_id=%d, fanin=%d, fanout=%d", core_id, core_type_str, task->task_id, @@ -863,39 +863,39 @@ void AicpuExecutor::diagnose_stuck_state( task->fanout_count); } else if (h->task_status != 0) { anomaly_cores++; - DEV_ERROR(" Core %d [%s, ANOMALY]: status=BUSY but task=NULL", core_id, core_type_str); + DEV_ALWAYS(" Core %d [%s, ANOMALY]: status=BUSY but task=NULL", core_id, core_type_str); } else { idle_cores++; } } - DEV_ERROR("Summary: %d busy, %d idle, %d anomaly", busy_cores, idle_cores, anomaly_cores); + DEV_ALWAYS("Summary: %d busy, %d idle, %d anomaly", busy_cores, idle_cores, anomaly_cores); // Diagnose deadlock vs livelock if (busy_cores == 0 && aic_ready == 0 && aiv_ready == 0 && completed < published) { - DEV_ERROR("*** DEADLOCK DETECTED ***"); - DEV_ERROR("All cores idle, no ready tasks, but %d tasks incomplete", published - completed); + DEV_ALWAYS("*** DEADLOCK DETECTED ***"); + DEV_ALWAYS("All cores idle, no ready tasks, but %d tasks incomplete", published - completed); - DEV_ERROR("Tasks with fanin > 0:"); + DEV_ALWAYS("Tasks with fanin > 0:"); int stuck_count = 0; int task_count = runtime.get_task_count(); for (int tid = 0; tid < task_count && stuck_count < 10; 
tid++) { Task* t = runtime.get_task(tid); int fanin = t->fanin.load(std::memory_order_acquire); if (fanin > 0) { - DEV_ERROR(" Task %d: fanin=%d (waiting for dependencies)", tid, fanin); + DEV_ALWAYS(" Task %d: fanin=%d (waiting for dependencies)", tid, fanin); stuck_count++; } } if (stuck_count == 0) { - DEV_ERROR(" No tasks waiting! Possible counter corruption."); + DEV_ALWAYS(" No tasks waiting! Possible counter corruption."); } } else if (busy_cores > 0) { - DEV_ERROR("*** LIVELOCK / HUNG TASK ***"); - DEV_ERROR("%d cores executing but no progress", busy_cores); + DEV_ALWAYS("*** LIVELOCK / HUNG TASK ***"); + DEV_ALWAYS("%d cores executing but no progress", busy_cores); } - DEV_ERROR("========== END DIAGNOSTIC =========="); + DEV_ALWAYS("========== END DIAGNOSTIC =========="); } // ===== Public Entry Point ===== diff --git a/src/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp b/src/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp index 4ce464e4..e8f5699c 100644 --- a/src/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp +++ b/src/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp @@ -28,6 +28,14 @@ #include "common/perf_profiling.h" #include "common/memory_barrier.h" #include "common/unified_log.h" + +// Scheduler profiling helper +#include +static inline uint64_t _orch_now_ns() { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (uint64_t)ts.tv_sec * 1000000000ULL + ts.tv_nsec; +} #include "inner_aicpu.h" // Device orchestration function signature (loaded via dlopen). 
@@ -45,6 +53,14 @@ constexpr int MAX_CORES_PER_THREAD = MAX_AIC_PER_THREAD + MAX_AIV_PER_THREAD; // Maximum tasks for ready queue (PTO2 mode uses shared memory task count) constexpr int AICPU_MAX_READY_TASKS = 16384; +constexpr int AICPU_READY_MASK = AICPU_MAX_READY_TASKS - 1; + +// Lightweight spinlock (avoids futex syscall overhead of std::mutex) +struct SpinLock { + std::atomic flag{0}; + void lock() { while (flag.exchange(1, std::memory_order_acquire) != 0) { PTO2_SPIN_PAUSE_LIGHT(); } } + void unlock() { flag.store(0, std::memory_order_release); } +}; // Core information for discovery (aligned with host_build_graph) struct CoreInfo { @@ -73,17 +89,16 @@ struct AicpuExecutor { int aiv_count_{0}; // ===== Task queue state (FIFO circular queue, aligned with host_build_graph) ===== - std::mutex ready_queue_aic_mutex_; + // ===== Spinlock-based MPMC ready queues (lighter than std::mutex) ===== + SpinLock ready_queue_aic_lock_; int ready_queue_aic_[AICPU_MAX_READY_TASKS]; - std::atomic ready_count_aic_{0}; - int ready_queue_aic_head_{0}; // Circular queue: read position (front) - int ready_queue_aic_tail_{0}; // Circular queue: write position (back) + int ready_queue_aic_head_{0}; + int ready_queue_aic_tail_{0}; - std::mutex ready_queue_aiv_mutex_; + SpinLock ready_queue_aiv_lock_; int ready_queue_aiv_[AICPU_MAX_READY_TASKS]; - std::atomic ready_count_aiv_{0}; - int ready_queue_aiv_head_{0}; // Circular queue: read position (front) - int ready_queue_aiv_tail_{0}; // Circular queue: write position (back) + int ready_queue_aiv_head_{0}; + int ready_queue_aiv_tail_{0}; // Task execution tracking std::atomic completed_tasks_{0}; @@ -97,6 +112,12 @@ struct AicpuExecutor { std::atomic perf_init_done_{false}; std::atomic sm_header_ready_{false}; // Thread 3 sets after SM header init + // Orchestrator ready queue pointers (set by Thread 3, read by scheduler threads) + volatile int32_t* orch_ready_queue_{nullptr}; + volatile int32_t* orch_ready_tail_{nullptr}; + volatile 
int32_t* orch_ready_head_{nullptr}; + int32_t orch_ready_capacity_{0}; + // Orchestration SO handle - defer dlclose until all tasks complete void* orch_so_handle_{nullptr}; char orch_so_path_[256]{}; // Path to orchestration SO file for cleanup @@ -292,8 +313,6 @@ int AicpuExecutor::init(Runtime* runtime) { orchestrator_done_.store(orch_on_host, std::memory_order_release); // Initial ready tasks will be populated from PTO2 shared memory in resolve_and_dispatch_pto2 - ready_count_aic_.store(0, std::memory_order_release); - ready_count_aiv_.store(0, std::memory_order_release); ready_queue_aic_head_ = 0; ready_queue_aic_tail_ = 0; ready_queue_aiv_head_ = 0; @@ -357,14 +376,6 @@ static void build_pto2_payload(PTO2DispatchPayload* out, Runtime* runtime, } out->num_args = n; - DEV_INFO("build_pto2_payload ok"); - for (int i = 0; i < task->param_count; i++) { - if (task->params[i].type == PTOParamType::SCALAR) { - DEV_INFO("build_pto2_payload param %d scalar: %d", i, out->args[i]); - } else { - DEV_INFO("build_pto2_payload param %d addr: %x", i, out->args[i]); - } - } } int AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int thread_idx, @@ -433,7 +444,17 @@ int AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int thread_idx, const int WARN_INTERVAL = 1000000; bool profiling_enabled = runtime->enable_profiling; + // Scheduler profiling counters + uint64_t sched_scan_ns = 0, sched_orch_drain_ns = 0; + uint64_t sched_complete_ns = 0, sched_dispatch_ns = 0, sched_yield_ns = 0; + uint64_t sched_loop_count = 0, sched_yield_count = 0; + // Fanout traversal statistics + uint64_t total_fanout_traversed = 0; + int32_t max_fanout_len = 0; + while (true) { + sched_loop_count++; + uint64_t _phase_t0 = _orch_now_ns(), _phase_t1; // Dynamic task_count (Thread 3 sets total_tasks_ when orchestration completes) int32_t task_count = total_tasks_.load(std::memory_order_acquire); bool orch_done = orchestrator_done_.load(std::memory_order_acquire); @@ -481,57 +502,57 @@ 
int AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int thread_idx, __atomic_store_n(&s_pto2_task_completed[slot], 1, __ATOMIC_RELEASE); int32_t wt = t->worker_type; if (wt == PTO2_WORKER_CUBE) { - std::lock_guard<std::mutex> lock(ready_queue_aic_mutex_); - ready_queue_aic_[ready_queue_aic_tail_] = idx; - ready_queue_aic_tail_ = (ready_queue_aic_tail_ + 1) % AICPU_MAX_READY_TASKS; - ready_count_aic_.fetch_add(1, std::memory_order_release); + ready_queue_aic_lock_.lock(); + ready_queue_aic_[ready_queue_aic_tail_++ & AICPU_READY_MASK] = idx; + ready_queue_aic_lock_.unlock(); } else { - std::lock_guard<std::mutex> lock(ready_queue_aiv_mutex_); - ready_queue_aiv_[ready_queue_aiv_tail_] = idx; - ready_queue_aiv_tail_ = (ready_queue_aiv_tail_ + 1) % AICPU_MAX_READY_TASKS; - ready_count_aiv_.fetch_add(1, std::memory_order_release); + ready_queue_aiv_lock_.lock(); + ready_queue_aiv_[ready_queue_aiv_tail_++ & AICPU_READY_MASK] = idx; + ready_queue_aiv_lock_.unlock(); } made_progress = true; } } } + _phase_t1 = _orch_now_ns(); sched_scan_ns += (_phase_t1 - _phase_t0); _phase_t0 = _phase_t1; - // Readiness scan: enqueue tasks made ready by orchestrator's early-return path - // (producer already completed → refcount incremented directly, but not enqueued) - { - int32_t scanned = next_scan_index_.load(std::memory_order_acquire); - for (int32_t idx = 0; idx < scanned; idx++) { - int32_t slot = idx & window_mask; - int32_t state = __atomic_load_n(&s_pto2_task_completed[slot], __ATOMIC_ACQUIRE); - if (state != 0) continue; // already enqueued (1) or completed (2) - PTO2TaskDescriptor* t = &task_descriptors[slot]; - int32_t fanin_count = __atomic_load_n(&t->fanin_count, __ATOMIC_ACQUIRE); - if (fanin_count <= 0) continue; // root tasks handled by incremental scan + // Drain orchestrator ready queue: tasks made ready by orchestrator's early-return path + // (producer already completed → refcount incremented directly, consumer pushed to queue) + if (orch_ready_queue_ != nullptr) { + while (true) 
{ + int32_t head = __atomic_load_n(orch_ready_head_, __ATOMIC_ACQUIRE); + int32_t tail = __atomic_load_n(orch_ready_tail_, __ATOMIC_ACQUIRE); + if (head == tail) break; // queue empty - int32_t refcount = __atomic_load_n(&s_pto2_fanin_refcount[slot], __ATOMIC_ACQUIRE); - if (refcount < fanin_count) continue; + // CAS to claim this slot (multiple scheduler threads compete) + if (!__atomic_compare_exchange_n(orch_ready_head_, &head, head + 1, + false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) continue; - // CAS from 0 → 1 to claim enqueue rights + int32_t task_id = orch_ready_queue_[head & (orch_ready_capacity_ - 1)]; + int32_t slot = task_id & window_mask; + + // CAS from 0 → 1 to claim enqueue rights (may already be enqueued by fanout path) int32_t expected = 0; if (!__atomic_compare_exchange_n(&s_pto2_task_completed[slot], &expected, 1, false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) continue; + PTO2TaskDescriptor* t = &task_descriptors[slot]; int32_t wt = t->worker_type; if (wt == PTO2_WORKER_CUBE) { - std::lock_guard lock(ready_queue_aic_mutex_); - ready_queue_aic_[ready_queue_aic_tail_] = idx; - ready_queue_aic_tail_ = (ready_queue_aic_tail_ + 1) % AICPU_MAX_READY_TASKS; - ready_count_aic_.fetch_add(1, std::memory_order_release); + ready_queue_aic_lock_.lock(); + ready_queue_aic_[ready_queue_aic_tail_++ & AICPU_READY_MASK] = task_id; + ready_queue_aic_lock_.unlock(); } else { - std::lock_guard lock(ready_queue_aiv_mutex_); - ready_queue_aiv_[ready_queue_aiv_tail_] = idx; - ready_queue_aiv_tail_ = (ready_queue_aiv_tail_ + 1) % AICPU_MAX_READY_TASKS; - ready_count_aiv_.fetch_add(1, std::memory_order_release); + ready_queue_aiv_lock_.lock(); + ready_queue_aiv_[ready_queue_aiv_tail_++ & AICPU_READY_MASK] = task_id; + ready_queue_aiv_lock_.unlock(); } made_progress = true; } } + _phase_t1 = _orch_now_ns(); sched_orch_drain_ns += (_phase_t1 - _phase_t0); _phase_t0 = _phase_t1; + // Phase 1: Process completed tasks (Handshake.task = PTO2DispatchPayload*) for (int i = 0; i < 
core_num; i++) { @@ -566,17 +587,19 @@ int AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int thread_idx, int32_t task_id = payload->task_id; PTO2TaskDescriptor* pto2_task = &task_descriptors[task_id & window_mask]; - DEV_INFO("Thread %d: Core %d completed PTO2 task %d", thread_idx, core_id, task_id); + DEV_DEBUG("Thread %d: Core %d completed PTO2 task %d", thread_idx, core_id, task_id); // Acquire fanout_lock, mark completed (state=2), snapshot fanout_head - while (PTO2_EXCHANGE(&pto2_task->fanout_lock, 1) != 0) { PTO2_SPIN_PAUSE(); } + while (PTO2_EXCHANGE(&pto2_task->fanout_lock, 1) != 0) { PTO2_SPIN_PAUSE_LIGHT(); } __atomic_store_n(&s_pto2_task_completed[task_id & window_mask], 2, __ATOMIC_RELEASE); int32_t fanout_head = pto2_task->fanout_head; PTO2_STORE_RELEASE(&pto2_task->fanout_lock, 0); // Traverse fanout outside lock + int32_t fanout_len = 0; int32_t current = fanout_head; while (current > 0) { + fanout_len++; PTO2DepListEntry* entry = &dep_list_pool[current]; int32_t consumer_id = entry->task_id; int32_t consumer_slot = consumer_id & window_mask; @@ -587,21 +610,19 @@ int AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int thread_idx, __atomic_store_n(&s_pto2_task_completed[consumer_slot], 1, __ATOMIC_RELEASE); int32_t wt = consumer_desc->worker_type; if (wt == PTO2_WORKER_CUBE) { - std::lock_guard lock(ready_queue_aic_mutex_); - // FIFO: enqueue to tail - ready_queue_aic_[ready_queue_aic_tail_] = consumer_id; - ready_queue_aic_tail_ = (ready_queue_aic_tail_ + 1) % AICPU_MAX_READY_TASKS; - ready_count_aic_.fetch_add(1, std::memory_order_release); + ready_queue_aic_lock_.lock(); + ready_queue_aic_[ready_queue_aic_tail_++ & AICPU_READY_MASK] = consumer_id; + ready_queue_aic_lock_.unlock(); } else { - std::lock_guard lock(ready_queue_aiv_mutex_); - // FIFO: enqueue to tail - ready_queue_aiv_[ready_queue_aiv_tail_] = consumer_id; - ready_queue_aiv_tail_ = (ready_queue_aiv_tail_ + 1) % AICPU_MAX_READY_TASKS; - 
ready_count_aiv_.fetch_add(1, std::memory_order_release); + ready_queue_aiv_lock_.lock(); + ready_queue_aiv_[ready_queue_aiv_tail_++ & AICPU_READY_MASK] = consumer_id; + ready_queue_aiv_lock_.unlock(); } } current = entry->next_offset; } + total_fanout_traversed += fanout_len; + if (fanout_len > max_fanout_len) max_fanout_len = fanout_len; cur_thread_tasks_in_flight--; cur_thread_completed++; @@ -609,6 +630,7 @@ int AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int thread_idx, completed_tasks_.fetch_add(1, std::memory_order_release); } } + _phase_t1 = _orch_now_ns(); sched_complete_ns += (_phase_t1 - _phase_t0); _phase_t0 = _phase_t1; // Phase 2: Dispatch ready tasks to idle cores (build PTO2DispatchPayload) if (cur_thread_tasks_in_flight < core_num) { @@ -616,53 +638,37 @@ int AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int thread_idx, int core_id = cur_thread_cores[i]; Handshake* h = &hank[core_id]; if (h->task_status == 0 && h->task == 0) { - bool dispatched = false; - if (h->core_type == CoreType::AIC && ready_count_aic_.load(std::memory_order_acquire) > 0) { - std::lock_guard lock(ready_queue_aic_mutex_); - int count = ready_count_aic_.load(std::memory_order_relaxed); - if (count > 0) { - // FIFO: dequeue from head - int32_t task_id = ready_queue_aic_[ready_queue_aic_head_]; - ready_queue_aic_head_ = (ready_queue_aic_head_ + 1) % AICPU_MAX_READY_TASKS; - ready_count_aic_.fetch_sub(1, std::memory_order_release); - PTO2TaskDescriptor* task = &task_descriptors[task_id & window_mask]; - PTO2DispatchPayload* payload = &s_pto2_payload_per_core[core_id]; - build_pto2_payload(payload, runtime, task, task_descriptors, dep_list_pool, window_size); - h->task = reinterpret_cast(payload); - if (runtime->enable_profiling) { - dispatch_timestamps_[core_id] = get_sys_cnt_aicpu(); - } - h->task_status = 1; - cur_thread_tasks_in_flight++; - made_progress = true; - dispatched = true; - DEV_INFO("Thread %d: Dispatching PTO2 AIC task %d to core %d", 
thread_idx, task_id, core_id); + int32_t task_id = -1; + if (h->core_type == CoreType::AIC) { + ready_queue_aic_lock_.lock(); + if (ready_queue_aic_head_ < ready_queue_aic_tail_) { + task_id = ready_queue_aic_[ready_queue_aic_head_++ & AICPU_READY_MASK]; } + ready_queue_aic_lock_.unlock(); + } else { + ready_queue_aiv_lock_.lock(); + if (ready_queue_aiv_head_ < ready_queue_aiv_tail_) { + task_id = ready_queue_aiv_[ready_queue_aiv_head_++ & AICPU_READY_MASK]; + } + ready_queue_aiv_lock_.unlock(); } - if (!dispatched && h->core_type == CoreType::AIV && ready_count_aiv_.load(std::memory_order_acquire) > 0) { - std::lock_guard<std::mutex> lock(ready_queue_aiv_mutex_); - int count = ready_count_aiv_.load(std::memory_order_relaxed); - if (count > 0) { - // FIFO: dequeue from head - int32_t task_id = ready_queue_aiv_[ready_queue_aiv_head_]; - ready_queue_aiv_head_ = (ready_queue_aiv_head_ + 1) % AICPU_MAX_READY_TASKS; - ready_count_aiv_.fetch_sub(1, std::memory_order_release); - PTO2TaskDescriptor* task = &task_descriptors[task_id & window_mask]; - PTO2DispatchPayload* payload = &s_pto2_payload_per_core[core_id]; - build_pto2_payload(payload, runtime, task, task_descriptors, dep_list_pool, window_size); - h->task = reinterpret_cast<uint64_t>(payload); - if (runtime->enable_profiling) { - dispatch_timestamps_[core_id] = get_sys_cnt_aicpu(); - } - h->task_status = 1; - cur_thread_tasks_in_flight++; - made_progress = true; - DEV_INFO("Thread %d: Dispatching PTO2 AIV task %d to core %d", thread_idx, task_id, core_id); + if (task_id >= 0) { + PTO2TaskDescriptor* task = &task_descriptors[task_id & window_mask]; + PTO2DispatchPayload* payload = &s_pto2_payload_per_core[core_id]; + build_pto2_payload(payload, runtime, task, task_descriptors, dep_list_pool, window_size); + h->task = reinterpret_cast<uint64_t>(payload); + if (runtime->enable_profiling) { + dispatch_timestamps_[core_id] = get_sys_cnt_aicpu(); } + h->task_status = 1; + cur_thread_tasks_in_flight++; + made_progress = true; + DEV_DEBUG("Thread %d: 
Dispatching PTO2 task %d to core %d", thread_idx, task_id, core_id); } } } } + _phase_t1 = _orch_now_ns(); sched_dispatch_ns += (_phase_t1 - _phase_t0); _phase_t0 = _phase_t1; if (!made_progress) { idle_iterations++; @@ -675,12 +681,32 @@ int AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int thread_idx, return -1; } std::this_thread::yield(); + sched_yield_count++; + _phase_t1 = _orch_now_ns(); sched_yield_ns += (_phase_t1 - _phase_t0); } else { idle_iterations = 0; } } - DEV_INFO("Thread %d: PTO2 execution complete, completed %d tasks", thread_idx, cur_thread_completed); + uint64_t sched_total = sched_scan_ns + sched_orch_drain_ns + sched_complete_ns + sched_dispatch_ns + sched_yield_ns; + if (sched_total == 0) sched_total = 1; // avoid div-by-zero + DEV_ALWAYS("Thread %d: PTO2 scheduler stats: loops=%llu, completed=%d, total=%.3fms", + thread_idx, (unsigned long long)sched_loop_count, cur_thread_completed, sched_total/1e6); + DEV_ALWAYS("Thread %d: scan=%.3fms (%.1f%%), orch_drain=%.3fms (%.1f%%), complete=%.3fms (%.1f%%), dispatch=%.3fms (%.1f%%)", + thread_idx, + sched_scan_ns/1e6, sched_scan_ns*100.0/sched_total, + sched_orch_drain_ns/1e6, sched_orch_drain_ns*100.0/sched_total, + sched_complete_ns/1e6, sched_complete_ns*100.0/sched_total, + sched_dispatch_ns/1e6, sched_dispatch_ns*100.0/sched_total); + DEV_ALWAYS("Thread %d: yield=%.3fms (%.1f%%, %llu calls, avg=%.1fus)", + thread_idx, sched_yield_ns/1e6, sched_yield_ns*100.0/sched_total, + (unsigned long long)sched_yield_count, + sched_yield_count > 0 ? sched_yield_ns/1e3/sched_yield_count : 0.0); + DEV_ALWAYS("Thread %d: fanout: total_traversed=%llu, max_len=%d, avg=%.1f", + thread_idx, (unsigned long long)total_fanout_traversed, max_fanout_len, + cur_thread_completed > 0 ? 
(double)total_fanout_traversed / cur_thread_completed : 0.0); + + DEV_ALWAYS("Thread %d: PTO2 execution complete, completed %d tasks", thread_idx, cur_thread_completed); // Flush performance buffers for cores managed by this thread if (profiling_enabled) { @@ -860,6 +886,12 @@ int AicpuExecutor::run(Runtime* runtime) { rt->orchestrator.aicpu_task_completed = s_pto2_task_completed; rt->orchestrator.aicpu_window_mask = ws - 1; + // Expose orchestrator ready queue to scheduler threads + orch_ready_queue_ = rt->orchestrator.orch_ready_queue; + orch_ready_tail_ = &rt->orchestrator.orch_ready_tail; + orch_ready_head_ = &rt->orchestrator.orch_ready_head; + orch_ready_capacity_ = PTO2OrchestratorState::ORCH_READY_QUEUE_SIZE; + // Call orchestration wrapped in outer scope (matches old PTO2_ORCHESTRATION behavior) DEV_INFO("Thread 3: Calling aicpu_orchestration_entry from SO"); PTO2_SCOPE(rt) { @@ -867,6 +899,28 @@ int AicpuExecutor::run(Runtime* runtime) { } DEV_INFO("Thread 3: aicpu_orchestration_entry returned"); + // Print orchestrator profiling data +#if PTO2_ORCH_PROFILING + { + PTO2OrchProfilingData p = pto2_orchestrator_get_profiling(); + uint64_t total = p.sync_ns + p.alloc_ns + p.params_ns + + p.lookup_ns + p.heap_ns + p.insert_ns + + p.fanin_ns + p.finalize_ns; + DEV_INFO("=== Orchestrator Profiling: %lld tasks, total=%.3fms ===", + (long long)p.submit_count, total / 1e6); + DEV_INFO(" sync_tensormap : %.3fms (%.1f%%)", p.sync_ns / 1e6, p.sync_ns * 100.0 / total); + DEV_INFO(" task_ring_alloc: %.3fms (%.1f%%)", p.alloc_ns / 1e6, p.alloc_ns * 100.0 / total); + DEV_INFO(" param_copy : %.3fms (%.1f%%)", p.params_ns / 1e6, p.params_ns * 100.0 / total); + DEV_INFO(" lookup+dep : %.3fms (%.1f%%)", p.lookup_ns / 1e6, p.lookup_ns * 100.0 / total); + DEV_INFO(" heap_alloc : %.3fms (%.1f%%)", p.heap_ns / 1e6, p.heap_ns * 100.0 / total); + DEV_INFO(" tensormap_ins : %.3fms (%.1f%%)", p.insert_ns / 1e6, p.insert_ns * 100.0 / total); + DEV_INFO(" fanin+ready : %.3fms 
(%.1f%%)", p.fanin_ns / 1e6, p.fanin_ns * 100.0 / total); + DEV_INFO(" finalize+SM : %.3fms (%.1f%%)", p.finalize_ns / 1e6, p.finalize_ns * 100.0 / total); + DEV_INFO(" scope_end : %.3fms", p.scope_end_ns / 1e6); + DEV_INFO(" avg/task : %.3fus", total / 1e3 / p.submit_count); + } +#endif + // Teardown runtime pto2_rt_orchestration_done(rt); pto2_runtime_destroy(rt); @@ -913,8 +967,6 @@ int AicpuExecutor::run(Runtime* runtime) { void AicpuExecutor::deinit() { // Cleanup runtime execution state - ready_count_aic_.store(0, std::memory_order_release); - ready_count_aiv_.store(0, std::memory_order_release); ready_queue_aic_head_ = 0; ready_queue_aic_tail_ = 0; ready_queue_aiv_head_ = 0; @@ -954,22 +1006,22 @@ void AicpuExecutor::diagnose_stuck_state(Runtime* runtime, int thread_idx, const int* cur_thread_cores, int core_num, Handshake* hank) { (void)runtime; // Reserved for future use - DEV_ERROR("========== DIAGNOSTIC REPORT: Thread %d ==========", thread_idx); + DEV_ALWAYS("========== DIAGNOSTIC REPORT: Thread %d ==========", thread_idx); int completed = completed_tasks_.load(std::memory_order_acquire); int total = total_tasks_.load(std::memory_order_acquire); - DEV_ERROR("Progress: %d/%d tasks (%.1f%%)", + DEV_ALWAYS("Progress: %d/%d tasks (%.1f%%)", completed, total, total > 0 ? 
completed * 100.0 / total : 0.0); - int aic_ready = ready_count_aic_.load(std::memory_order_acquire); - int aiv_ready = ready_count_aiv_.load(std::memory_order_acquire); - DEV_ERROR("Ready Queues: AIC=%d, AIV=%d", aic_ready, aiv_ready); + int aic_ready = ready_queue_aic_tail_ - ready_queue_aic_head_; + int aiv_ready = ready_queue_aiv_tail_ - ready_queue_aiv_head_; + DEV_ALWAYS("Ready Queues: AIC=%d, AIV=%d", aic_ready, aiv_ready); int busy_cores = 0; int idle_cores = 0; int anomaly_cores = 0; - DEV_ERROR("Core Status:"); + DEV_ALWAYS("Core Status:"); for (int i = 0; i < core_num; i++) { int core_id = cur_thread_cores[i]; Handshake* h = &hank[core_id]; @@ -980,30 +1032,30 @@ void AicpuExecutor::diagnose_stuck_state(Runtime* runtime, int thread_idx, PTO2DispatchPayload* payload = reinterpret_cast(h->task); busy_cores++; - DEV_ERROR(" Core %d [%s, BUSY]: task_id=%d, kernel_id=%d", + DEV_ALWAYS(" Core %d [%s, BUSY]: task_id=%d, kernel_id=%d", core_id, core_type_str, payload->task_id, payload->kernel_id); } else if (h->task_status != 0) { anomaly_cores++; - DEV_ERROR(" Core %d [%s, ANOMALY]: status=BUSY but task=NULL", core_id, core_type_str); + DEV_ALWAYS(" Core %d [%s, ANOMALY]: status=BUSY but task=NULL", core_id, core_type_str); } else { idle_cores++; } } - DEV_ERROR("Summary: %d busy, %d idle, %d anomaly", busy_cores, idle_cores, anomaly_cores); + DEV_ALWAYS("Summary: %d busy, %d idle, %d anomaly", busy_cores, idle_cores, anomaly_cores); // Diagnose deadlock vs livelock if (busy_cores == 0 && aic_ready == 0 && aiv_ready == 0 && completed < total) { - DEV_ERROR("*** DEADLOCK DETECTED ***"); - DEV_ERROR("All cores idle, no ready tasks, but %d tasks incomplete", total - completed); - DEV_ERROR("Check PTO2 shared memory for task dependency state"); + DEV_ALWAYS("*** DEADLOCK DETECTED ***"); + DEV_ALWAYS("All cores idle, no ready tasks, but %d tasks incomplete", total - completed); + DEV_ALWAYS("Check PTO2 shared memory for task dependency state"); } else if (busy_cores 
> 0) { - DEV_ERROR("*** LIVELOCK / HUNG TASK ***"); - DEV_ERROR("%d cores executing but no progress", busy_cores); + DEV_ALWAYS("*** LIVELOCK / HUNG TASK ***"); + DEV_ALWAYS("%d cores executing but no progress", busy_cores); } - DEV_ERROR("========== END DIAGNOSTIC =========="); + DEV_ALWAYS("========== END DIAGNOSTIC =========="); } // ============================================================================= diff --git a/src/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp b/src/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp index 409be4f2..3a5493fe 100644 --- a/src/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp +++ b/src/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp @@ -24,6 +24,14 @@ #include #include #include +#include + +// Helper: return current time in milliseconds +static long long _now_ms() { + struct timeval tv; + gettimeofday(&tv, nullptr); + return (long long)tv.tv_sec * 1000 + tv.tv_usec / 1000; +} // Max args for device orchestration #define RT2_MAX_DEVICE_ARGS 32 @@ -103,9 +111,12 @@ extern "C" int init_runtime_impl(Runtime *runtime, std::cout << "RT2 init: " << func_args_count << " arguments, device orchestration mode\n"; + long long t_total_start = _now_ms(); + // Convert host pointers to device pointers based on arg_types uint64_t device_args[RT2_MAX_DEVICE_ARGS]; + long long t_args_start = _now_ms(); for (int i = 0; i < func_args_count; i++) { switch (arg_types[i]) { case ARG_SCALAR: @@ -186,8 +197,10 @@ extern "C" int init_runtime_impl(Runtime *runtime, return -1; } } + long long t_args_end = _now_ms(); // Copy orchestration SO to device memory (AICPU cannot access host memory) + long long t_so_start = _now_ms(); void* dev_so = runtime->host_api.device_malloc(orch_so_size); if (dev_so == nullptr) { std::cerr << "Error: Failed to allocate device memory for orchestration SO\n"; @@ -205,9 +218,12 @@ extern "C" int init_runtime_impl(Runtime *runtime, runtime->set_device_orch_so(orch_so_binary, orch_so_size); 
runtime->record_tensor_pair(nullptr, dev_so, orch_so_size); std::cout << "Orchestration SO: " << orch_so_size << " bytes copied to device\n"; + long long t_so_end = _now_ms(); // Allocate GM heap for orchestrator output buffers + long long t_heap_start = _now_ms(); void* gm_heap = runtime->host_api.device_malloc(PTO2_HEAP_SIZE); + long long t_heap_end = _now_ms(); if (gm_heap == nullptr) { std::cerr << "Error: Failed to allocate GM heap\n"; return -1; @@ -216,8 +232,10 @@ extern "C" int init_runtime_impl(Runtime *runtime, runtime->set_pto2_gm_heap(gm_heap); // Allocate PTO2 shared memory + long long t_sm_start = _now_ms(); int32_t sm_size = pto2_sm_calculate_size(PTO2_TASK_WINDOW_SIZE, PTO2_DEP_LIST_POOL_SIZE); void* sm_ptr = runtime->host_api.device_malloc(static_cast(sm_size)); + long long t_sm_end = _now_ms(); if (sm_ptr == nullptr) { std::cerr << "Error: Failed to allocate PTO2 shared memory\n"; return -1; @@ -230,6 +248,14 @@ extern "C" int init_runtime_impl(Runtime *runtime, runtime->set_orch_args(device_args, func_args_count); std::cout << "Device orchestration ready: " << func_args_count << " args\n"; + + long long t_total_end = _now_ms(); + printf("TIMING: args_malloc_copy = %lldms\n", t_args_end - t_args_start); + printf("TIMING: orch_so_copy = %lldms\n", t_so_end - t_so_start); + printf("TIMING: gm_heap_alloc(1GB) = %lldms\n", t_heap_end - t_heap_start); + printf("TIMING: shared_mem_alloc = %lldms\n", t_sm_end - t_sm_start); + printf("TIMING: total_init_runtime_impl = %lldms\n", t_total_end - t_total_start); + return 0; } diff --git a/src/runtime/tensormap_and_ringbuffer/orchestration/tensor_orch.cpp b/src/runtime/tensormap_and_ringbuffer/orchestration/tensor_orch.cpp index dfe84df0..4b1105d9 100644 --- a/src/runtime/tensormap_and_ringbuffer/orchestration/tensor_orch.cpp +++ b/src/runtime/tensormap_and_ringbuffer/orchestration/tensor_orch.cpp @@ -54,32 +54,7 @@ Tensor::Tensor(Tensor&& other) } } -Tensor::Tensor(const Tensor& other) - : 
buffer(other.buffer), - start_offset(other.start_offset), - ndims(other.ndims), - dtype(other.dtype), - version(other.version), - overlap_type(other.overlap_type) { - for (uint64_t i = 0; i < ndims; i++) { - strides[i] = other.strides[i]; - repeats[i] = other.repeats[i]; - } -} - -Tensor& Tensor::operator=(const Tensor& other) { - buffer = other.buffer; - start_offset = other.start_offset; - ndims = other.ndims; - dtype = other.dtype; - version = other.version; - overlap_type = other.overlap_type; - for (uint64_t i = 0; i < ndims; i++) { - strides[i] = other.strides[i]; - repeats[i] = other.repeats[i]; - } - return *this; -} +// Copy constructor and operator= are now inline in tensor.h // ============================================================================= // Validation and optimization (called by constructor's debug_assert) diff --git a/src/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp b/src/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp index 0d687b1c..b4bb1fed 100644 --- a/src/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp +++ b/src/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp @@ -16,6 +16,39 @@ #include "pto_tensormap.h" #include "tensor.h" +// ============================================================================= +// Orchestrator Profiling (compile-time toggle) +// ============================================================================= +#ifndef PTO2_ORCH_PROFILING +#define PTO2_ORCH_PROFILING 1 +#endif + +#if PTO2_ORCH_PROFILING +#include +static inline uint64_t _orch_now_ns() { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (uint64_t)ts.tv_sec * 1000000000ULL + ts.tv_nsec; +} + +// Accumulated nanoseconds per sub-step +static uint64_t g_orch_sync_ns = 0; // tensormap sync +static uint64_t g_orch_alloc_ns = 0; // task ring alloc +static uint64_t g_orch_params_ns = 0; // param copy +static uint64_t g_orch_lookup_ns = 0; // tensormap lookup + dep building 
+static uint64_t g_orch_heap_ns = 0; // heap alloc + output assign +static uint64_t g_orch_insert_ns = 0; // tensormap insert +static uint64_t g_orch_fanin_ns = 0; // fanin list + early-return check +static uint64_t g_orch_finalize_ns = 0; // scheduler init + SM update +static uint64_t g_orch_scope_end_ns = 0; // scope_end overhead +static int64_t g_orch_submit_count = 0; +#define ORCH_PROF_START() uint64_t _t0 = _orch_now_ns(), _t1 +#define ORCH_PROF_LAP(acc) do { _t1 = _orch_now_ns(); acc += (_t1 - _t0); _t0 = _t1; } while(0) +#else +#define ORCH_PROF_START() +#define ORCH_PROF_LAP(acc) +#endif + // ============================================================================= // Per-Task Spinlock Implementation // ============================================================================= @@ -25,7 +58,7 @@ */ static inline void task_fanout_lock(PTO2TaskDescriptor* task) { while (PTO2_EXCHANGE(&task->fanout_lock, 1) != 0) { - PTO2_SPIN_PAUSE(); + PTO2_SPIN_PAUSE_LIGHT(); } } @@ -149,6 +182,10 @@ void pto2_scope_begin(PTO2OrchestratorState* orch) { void pto2_scope_end(PTO2OrchestratorState* orch) { assert(orch->scope_stack_top >= 0 && "Scope stack underflow"); +#if PTO2_ORCH_PROFILING + uint64_t _se0 = _orch_now_ns(); +#endif + int32_t begin = orch->scope_begins[orch->scope_stack_top--]; int32_t count = orch->scope_tasks_size - begin; @@ -158,6 +195,10 @@ void pto2_scope_end(PTO2OrchestratorState* orch) { // Rewind the task buffer — these entries are no longer needed orch->scope_tasks_size = begin; + +#if PTO2_ORCH_PROFILING + g_orch_scope_end_ns += (_orch_now_ns() - _se0); +#endif } // ============================================================================= @@ -224,15 +265,21 @@ void pto2_submit_task(PTO2OrchestratorState* orch, const char* func_name, PTOParam* params, int32_t num_params) { + ORCH_PROF_START(); + // === STEP 0: Sync TensorMap validity and optional cleanup === pto2_orchestrator_sync_tensormap(&orch->tensor_map); + 
ORCH_PROF_LAP(g_orch_sync_ns); + // Submission without an open scope is illegal assert(orch->scope_stack_top >= 0 && "Cannot submit task outside a scope"); // === STEP 1: Allocate task slot from Task Ring (blocks until available) === int32_t task_id = pto2_task_ring_alloc(&orch->task_ring); + ORCH_PROF_LAP(g_orch_alloc_ns); + PTO2TaskDescriptor* task = pto2_task_ring_get(&orch->task_ring, task_id); // Initialize task descriptor @@ -262,15 +309,18 @@ void pto2_submit_task(PTO2OrchestratorState* orch, int32_t fanin_count = 0; task->param_count = num_params; - for (int i = 0; i < task->param_count; i++) { - task->params[i] = params[i]; - // Copy tensor data into task-owned storage; redirect task's pointer to it + // Bulk copy all params at once + memcpy(task->params, params, num_params * sizeof(PTOParam)); + // Copy tensor data into task-owned storage; redirect pointers + for (int i = 0; i < num_params; i++) { if (params[i].tensor) { task->tensor_copies[i] = *params[i].tensor; task->params[i].tensor = &task->tensor_copies[i]; } } + ORCH_PROF_LAP(g_orch_params_ns); + // === STEP 2: First pass - collect output sizes and process inputs === for (int i = 0; i < num_params; i++) { @@ -280,9 +330,12 @@ void pto2_submit_task(PTO2OrchestratorState* orch, case PTOParamType::INOUT: case PTOParamType::INPUT: { // Look up producer via TensorMap - auto dependency_task_ids = pto2_tensormap_lookup(&orch->tensor_map, params[i].tensor); + PTO2LookupResult lookup_result; + pto2_tensormap_lookup(&orch->tensor_map, params[i].tensor, &lookup_result); - for (auto [entry, overlap_status] : dependency_task_ids) { + for (int r = 0; r < lookup_result.count; r++) { + auto* entry = lookup_result.entries[r].entry; + auto overlap_status = lookup_result.entries[r].overlap_status; // Check if this producer is already in fanin list (avoid duplicates) int producer_task_id = entry->producer_task_id; bool already_added = false; @@ -328,6 +381,8 @@ void pto2_submit_task(PTO2OrchestratorState* orch, } } + 
ORCH_PROF_LAP(g_orch_lookup_ns); + // === STEP 3: Allocate packed buffer from Heap Ring (may stall) === // Each output slot is aligned to PTO2_PACKED_OUTPUT_ALIGN (1024B); gap after data is padding. if (total_output_size > 0) { @@ -350,6 +405,8 @@ void pto2_submit_task(PTO2OrchestratorState* orch, } } + ORCH_PROF_LAP(g_orch_heap_ns); + // === STEP 4: Second pass - register outputs in TensorMap === int32_t output_idx = 0; for (int i = 0; i < num_params; i++) { @@ -363,6 +420,8 @@ void pto2_submit_task(PTO2OrchestratorState* orch, } } + ORCH_PROF_LAP(g_orch_insert_ns); + // === STEP 5: Finalize fanin list === // First build the fanin list for (int i = 0; i < fanin_count; i++) { @@ -371,6 +430,28 @@ void pto2_submit_task(PTO2OrchestratorState* orch, // Use release semantics to ensure fanin list is visible before fanin_count __atomic_store_n(&task->fanin_count, fanin_count, __ATOMIC_RELEASE); + ORCH_PROF_LAP(g_orch_fanin_ns); + + // === STEP 5b: Check if task is already ready (all producers completed via early-return) === + // In AICPU parallel mode, early-return in pto2_add_consumer_to_producer may have + // already incremented aicpu_fanin_refcount for this task. Now that fanin_count is + // finalized, check if the task is already satisfied and push it to the orchestrator + // ready queue so scheduler threads can pick it up without an O(N) scan. 
+ if (orch->aicpu_fanin_refcount && fanin_count > 0) { + int32_t slot = task_id & orch->aicpu_window_mask; + int32_t refcount = __atomic_load_n(&orch->aicpu_fanin_refcount[slot], __ATOMIC_ACQUIRE); + if (refcount >= fanin_count) { + // All producers already completed — push to orch ready queue + int32_t tail = orch->orch_ready_tail; + int32_t capacity = PTO2OrchestratorState::ORCH_READY_QUEUE_SIZE; + int32_t head = __atomic_load_n(&orch->orch_ready_head, __ATOMIC_ACQUIRE); + if (((tail + 1) & (capacity - 1)) != (head & (capacity - 1))) { + orch->orch_ready_queue[tail & (capacity - 1)] = task_id; + __atomic_store_n(&orch->orch_ready_tail, tail + 1, __ATOMIC_RELEASE); + } + } + } + // === STEP 6: Initialize task in scheduler === // In multi-threaded mode, scheduler thread handles task initialization via polling if (orch->scheduler && orch->init_task_on_submit) { @@ -380,7 +461,12 @@ void pto2_submit_task(PTO2OrchestratorState* orch, // === STEP 7: Update shared memory with current task index === PTO2_STORE_RELEASE(&orch->sm_handle->header->current_task_index, orch->task_ring.current_index); + ORCH_PROF_LAP(g_orch_finalize_ns); + orch->tasks_submitted++; +#if PTO2_ORCH_PROFILING + g_orch_submit_count++; +#endif } // ============================================================================= @@ -436,3 +522,26 @@ void pto2_orchestrator_print_scope_stack(PTO2OrchestratorState* orch) { printf("==================\n"); } + +#if PTO2_ORCH_PROFILING +PTO2OrchProfilingData pto2_orchestrator_get_profiling() { + PTO2OrchProfilingData d; + d.sync_ns = g_orch_sync_ns; + d.alloc_ns = g_orch_alloc_ns; + d.params_ns = g_orch_params_ns; + d.lookup_ns = g_orch_lookup_ns; + d.heap_ns = g_orch_heap_ns; + d.insert_ns = g_orch_insert_ns; + d.fanin_ns = g_orch_fanin_ns; + d.finalize_ns = g_orch_finalize_ns; + d.scope_end_ns = g_orch_scope_end_ns; + d.submit_count = g_orch_submit_count; + + // Reset + g_orch_sync_ns = g_orch_alloc_ns = g_orch_params_ns = 0; + g_orch_lookup_ns = 
g_orch_heap_ns = g_orch_insert_ns = 0; + g_orch_fanin_ns = g_orch_finalize_ns = g_orch_scope_end_ns = 0; + g_orch_submit_count = 0; + return d; +} +#endif diff --git a/src/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.h b/src/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.h index 227acc78..fbeb3ff0 100644 --- a/src/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.h +++ b/src/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.h @@ -77,6 +77,18 @@ struct PTO2OrchestratorState { int32_t* aicpu_fanin_refcount; volatile int32_t* aicpu_task_completed; int32_t aicpu_window_mask; + + // === ORCHESTRATOR READY QUEUE (early-return path → scheduler) === + // When the orchestrator discovers a producer already completed, it + // increments the consumer's refcount directly. If that makes the + // consumer ready, the consumer_id is pushed here so scheduler threads + // can pick it up without an O(N) scan. + // SPSC-ish ring: orchestrator writes (single producer), scheduler + // threads read via CAS on orch_ready_head (multiple consumers). 
+ static constexpr int32_t ORCH_READY_QUEUE_SIZE = 4096; + volatile int32_t orch_ready_queue[4096]; + volatile int32_t orch_ready_tail; // written by orchestrator only + volatile int32_t orch_ready_head; // advanced by scheduler via CAS }; // ============================================================================= @@ -226,4 +238,33 @@ void pto2_orchestrator_print_stats(PTO2OrchestratorState* orch); */ void pto2_orchestrator_print_scope_stack(PTO2OrchestratorState* orch); +// ============================================================================= +// Orchestrator Profiling Data +// ============================================================================= + +#ifndef PTO2_ORCH_PROFILING +#define PTO2_ORCH_PROFILING 1 +#endif + +#if PTO2_ORCH_PROFILING +struct PTO2OrchProfilingData { + uint64_t sync_ns; + uint64_t alloc_ns; + uint64_t params_ns; + uint64_t lookup_ns; + uint64_t heap_ns; + uint64_t insert_ns; + uint64_t fanin_ns; + uint64_t finalize_ns; + uint64_t scope_end_ns; + int64_t submit_count; +}; + +/** + * Get and reset orchestrator profiling data. + * Returns accumulated profiling data and resets counters. 
+ */ +PTO2OrchProfilingData pto2_orchestrator_get_profiling(); +#endif + #endif // PTO_ORCHESTRATOR_H diff --git a/src/runtime/tensormap_and_ringbuffer/runtime/pto_ring_buffer.cpp b/src/runtime/tensormap_and_ringbuffer/runtime/pto_ring_buffer.cpp index 6aa6a1a3..da9bfcc7 100644 --- a/src/runtime/tensormap_and_ringbuffer/runtime/pto_ring_buffer.cpp +++ b/src/runtime/tensormap_and_ringbuffer/runtime/pto_ring_buffer.cpp @@ -272,9 +272,9 @@ int32_t pto2_task_ring_try_alloc(PTO2TaskRing* ring) { int32_t task_id = current; int32_t slot = task_id & (ring->window_size - 1); - // Initialize task descriptor + // Mark slot as occupied (skip full memset — pto2_submit_task + // explicitly initializes all fields it needs) PTO2TaskDescriptor* task = &ring->descriptors[slot]; - memset(task, 0, sizeof(PTO2TaskDescriptor)); task->task_id = task_id; task->is_active = true; diff --git a/src/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h b/src/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h index 31d51783..2dd7023c 100644 --- a/src/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h +++ b/src/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h @@ -279,7 +279,6 @@ typedef struct { PTOParam params[16]; Tensor tensor_copies[16]; // Owned tensor data (params[i].tensor points here) int param_count{0}; - } PTO2TaskDescriptor; // ============================================================================= @@ -378,10 +377,13 @@ typedef void (*PTO2InCoreFunc)(void** args, int32_t num_args); #include #if defined(__aarch64__) #define PTO2_SPIN_PAUSE() do { __asm__ __volatile__("yield" ::: "memory"); sched_yield(); } while(0) + #define PTO2_SPIN_PAUSE_LIGHT() __asm__ __volatile__("yield" ::: "memory") #elif defined(__x86_64__) #define PTO2_SPIN_PAUSE() do { __builtin_ia32_pause(); sched_yield(); } while(0) + #define PTO2_SPIN_PAUSE_LIGHT() __builtin_ia32_pause() #else #define PTO2_SPIN_PAUSE() sched_yield() + #define PTO2_SPIN_PAUSE_LIGHT() 
((void)0) #endif /** diff --git a/src/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.cpp b/src/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.cpp index 2b63959a..aebace93 100644 --- a/src/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.cpp +++ b/src/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.cpp @@ -248,12 +248,12 @@ void pto2_tensormap_cleanup_retired(PTO2TensorMap* tm, int32_t old_last_task_ali // Lookup with Chain Truncation // ============================================================================= -std::vector<std::pair<PTO2TensorMapEntry*, OverlapStatus>> pto2_tensormap_lookup(PTO2TensorMap* tm, Tensor* tensor) { +void pto2_tensormap_lookup(PTO2TensorMap* tm, Tensor* tensor, PTO2LookupResult* result) { uint32_t bucket = pto2_tensormap_hash(tm, tensor); int32_t* prev_ptr = &tm->buckets[bucket]; // For truncation int32_t offset = *prev_ptr; - std::vector<std::pair<PTO2TensorMapEntry*, OverlapStatus>> task_ids; + result->count = 0; while (offset >= 0) { PTO2TensorMapEntry* entry = &tm->entry_pool[offset]; @@ -275,7 +275,7 @@ std::vector<std::pair<PTO2TensorMapEntry*, OverlapStatus>> pto2_tensormap_lookup offset = next; } - return task_ids; + return; } // Entry is valid - check if regions OVERLAP (not just exact match) @@ -283,15 +283,13 @@ // potential to overlap. We must check actual byte-range overlap. 
auto overlap_status = tensor->is_overlap(entry->tensor); if (overlap_status != OverlapStatus::NO_OVERLAP) { - task_ids.emplace_back(entry, overlap_status); + result->push(entry, overlap_status); } // Move to next entry prev_ptr = &entry->next_in_bucket; offset = *prev_ptr; } - - return task_ids; } // ============================================================================= diff --git a/src/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h b/src/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h index caadafc8..6e974cde 100644 --- a/src/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h +++ b/src/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h @@ -130,6 +130,25 @@ void pto2_tensormap_reset(PTO2TensorMap* tm); */ void pto2_tensormap_sync_validity(PTO2TensorMap* tm, int32_t last_task_alive); +/** + * Stack-allocated lookup result (avoids heap allocation per lookup) + */ +#define PTO2_LOOKUP_MAX_RESULTS 16 +struct PTO2LookupResult { + struct Entry { + PTO2TensorMapEntry* entry; + OverlapStatus overlap_status; + }; + Entry entries[PTO2_LOOKUP_MAX_RESULTS]; + int32_t count{0}; + + void push(PTO2TensorMapEntry* e, OverlapStatus s) { + if (count < PTO2_LOOKUP_MAX_RESULTS) { + entries[count++] = {e, s}; + } + } +}; + /** * Lookup producer for a tensor region * @@ -141,9 +160,9 @@ void pto2_tensormap_sync_validity(PTO2TensorMap* tm, int32_t last_task_alive); * * @param tm TensorMap * @param tensor Tensor to look up - * @return Producer entry, and overlap status + * @param result Output: stack-allocated result buffer */ -std::vector> pto2_tensormap_lookup(PTO2TensorMap* tm, Tensor* tensor); +void pto2_tensormap_lookup(PTO2TensorMap* tm, Tensor* tensor, PTO2LookupResult* result); /** * Insert a new entry (called when task produces output) diff --git a/src/runtime/tensormap_and_ringbuffer/runtime/tensor.h b/src/runtime/tensormap_and_ringbuffer/runtime/tensor.h index 388d05f0..7d061b69 100644 --- 
a/src/runtime/tensormap_and_ringbuffer/runtime/tensor.h +++ b/src/runtime/tensormap_and_ringbuffer/runtime/tensor.h @@ -132,9 +132,33 @@ struct Tensor { OverlapType overlap_type = OverlapType::Accurate); Tensor(Tensor&& other); - Tensor(const Tensor& other); - Tensor& operator=(const Tensor& other); + Tensor(const Tensor& other) + : buffer(other.buffer), + start_offset(other.start_offset), + ndims(other.ndims), + dtype(other.dtype), + version(other.version), + overlap_type(other.overlap_type) { + for (uint64_t i = 0; i < ndims; i++) { + strides[i] = other.strides[i]; + repeats[i] = other.repeats[i]; + } + } + + Tensor& operator=(const Tensor& other) { + buffer = other.buffer; + start_offset = other.start_offset; + ndims = other.ndims; + dtype = other.dtype; + version = other.version; + overlap_type = other.overlap_type; + for (uint64_t i = 0; i < ndims; i++) { + strides[i] = other.strides[i]; + repeats[i] = other.repeats[i]; + } + return *this; + } std::string dump() const; diff --git a/tests/device_tests/aicpu_build_graph/paged_attention/golden.py b/tests/device_tests/aicpu_build_graph/paged_attention/golden.py index 396789bc..72566fe5 100644 --- a/tests/device_tests/aicpu_build_graph/paged_attention/golden.py +++ b/tests/device_tests/aicpu_build_graph/paged_attention/golden.py @@ -118,6 +118,9 @@ def paged_attention( """ Compute paged attention using online softmax with head tiling and GQA. + Vectorized across the batch dimension for performance. + Supports different context_lens per batch via masking. 
+ Args: query: (batch, num_heads, head_dim) bfloat16 key_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 @@ -129,65 +132,87 @@ def paged_attention( context_lens: (batch,) int32 Returns: - out: (batch, num_heads, head_dim) float32 + out: (batch * num_heads, head_dim) float32 """ assert num_kv_heads == 1 - batch, num_heads, head_dim = query.shape + batch, num_heads_dim, head_dim = query.shape _, block_size, _, _ = key_cache.shape - _, block_num = block_table.shape - - query = query.reshape(-1, head_dim) - key_cache = key_cache.reshape(-1, block_size, head_dim) - value_cache = value_cache.reshape(-1, block_size, head_dim) - - out = torch.zeros((batch * num_heads, head_dim), dtype=torch.float32) - - for b_idx in range(batch): - cur_seq = int(context_lens[b_idx]) - bn_this_batch = (cur_seq + block_size - 1) // block_size - assert bn_this_batch <= block_num - - q_tile = min(num_heads, 128) - for cur_offset in range(0, num_heads, q_tile): - q_tile_size = min(q_tile, num_heads - cur_offset) - base_idx = b_idx * num_heads + cur_offset - qi = query[base_idx : base_idx + q_tile_size].to(torch.float32) - - oi = None - li = None - mi = None - - for bn in range(bn_this_batch): - cur_block_idx = block_table[b_idx, bn] - valid_len = min(block_size, cur_seq - bn * block_size) - kj = key_cache[cur_block_idx, :valid_len, :].to(torch.float32) - vj = value_cache[cur_block_idx, :valid_len, :].to(torch.float32) - - sij = (qi @ kj.T) * scale_value - mij = sij.max(dim=-1, keepdim=True)[0] - pij = torch.exp(sij - mij).to(torch.bfloat16).to(torch.float32) - lij = pij.sum(dim=1, keepdim=True) - - if bn == 0: - oi = pij @ vj - li = lij - mi = mij - else: - mi_new = torch.maximum(mi, mij) - alpha = torch.exp(mi - mi_new) - beta = torch.exp(mij - mi_new) - li_new = alpha * li + beta * lij - oi_new = pij @ vj - oi = alpha * oi + beta * oi_new - li = li_new - mi = mi_new - - if bn == bn_this_batch - 1: - oi = oi / li - - out[base_idx : base_idx + q_tile_size] = oi - - return out 
+ + # Reshape for batched computation + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + + q_tile = min(num_heads_dim, 128) + + # Max blocks across all batches (each batch may have different context_len) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + # qi: (batch, q_tile_size, head_dim) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + + oi = None # (batch, q_tile_size, head_dim) + li = None # (batch, q_tile_size, 1) + mi = None # (batch, q_tile_size, 1) + + for bn in range(max_bn): + # valid_len per batch for this block position + valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 # (batch,) + + if not active_mask.any(): + break + + # Gather block indices for all batches + block_indices = block_table[:, bn] # (batch,) + + # Gather K and V: (batch, block_size, head_dim) + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + + # QK matmul: (batch, q_tile_size, block_size) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + + # Mask out invalid positions (beyond valid_len per batch) + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) # (1, block_size) + valid_mask = pos < valid_lens.unsqueeze(1) # (batch, block_size) + valid_mask = valid_mask.unsqueeze(1) # (batch, 1, block_size) + sij = sij.masked_fill(~valid_mask, float('-inf')) + + # Also mask inactive batches (no blocks at this position) + batch_mask = active_mask.view(-1, 1, 1) # (batch, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + + mij = sij.max(dim=-1, keepdim=True)[0] # (batch, q_tile_size, 1) + mij = mij.clamp(min=-1e30) + pij = 
torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) # (batch, q_tile_size, 1) + + # PV matmul: (batch, q_tile_size, head_dim) + oi_new = torch.bmm(pij, vj_all) + + if bn == 0: + oi = oi_new + li = lij + mi = mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + + # Final normalization + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + + return out.reshape(-1, head_dim) def compute_golden(tensors: dict, params: dict) -> None: @@ -201,7 +226,7 @@ def compute_golden(tensors: dict, params: dict) -> None: max_num_blocks_per_req = max_model_len // block_size - # Reconstruct shaped arrays from flat bfloat16 tensors + # Reconstruct shaped tensors from flat tensors query = tensors["query"].reshape(batch, num_heads, head_dim) key_cache = tensors["key_cache"].reshape(-1, block_size, kv_head_num, head_dim) value_cache = tensors["value_cache"].reshape(-1, block_size, kv_head_num, head_dim) diff --git a/tests/device_tests/host_build_graph/paged_attention/golden.py b/tests/device_tests/host_build_graph/paged_attention/golden.py index 8133ca03..a74bd24f 100644 --- a/tests/device_tests/host_build_graph/paged_attention/golden.py +++ b/tests/device_tests/host_build_graph/paged_attention/golden.py @@ -118,6 +118,9 @@ def paged_attention( """ Compute paged attention using online softmax with head tiling and GQA. + Vectorized across the batch dimension for performance. + Supports different context_lens per batch via masking. 
+ Args: query: (batch, num_heads, head_dim) bfloat16 key_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 @@ -129,65 +132,87 @@ def paged_attention( context_lens: (batch,) int32 Returns: - out: (batch, num_heads, head_dim) float32 + out: (batch * num_heads, head_dim) float32 """ assert num_kv_heads == 1 - batch, num_heads, head_dim = query.shape + batch, num_heads_dim, head_dim = query.shape _, block_size, _, _ = key_cache.shape - _, block_num = block_table.shape - - query = query.reshape(-1, head_dim) - key_cache = key_cache.reshape(-1, block_size, head_dim) - value_cache = value_cache.reshape(-1, block_size, head_dim) - - out = torch.zeros((batch * num_heads, head_dim), dtype=torch.float32) - - for b_idx in range(batch): - cur_seq = int(context_lens[b_idx]) - bn_this_batch = (cur_seq + block_size - 1) // block_size - assert bn_this_batch <= block_num - - q_tile = min(num_heads, 128) - for cur_offset in range(0, num_heads, q_tile): - q_tile_size = min(q_tile, num_heads - cur_offset) - base_idx = b_idx * num_heads + cur_offset - qi = query[base_idx : base_idx + q_tile_size].to(torch.float32) - - oi = None - li = None - mi = None - - for bn in range(bn_this_batch): - cur_block_idx = block_table[b_idx, bn] - valid_len = min(block_size, cur_seq - bn * block_size) - kj = key_cache[cur_block_idx, :valid_len, :].to(torch.float32) - vj = value_cache[cur_block_idx, :valid_len, :].to(torch.float32) - - sij = (qi @ kj.T) * scale_value - mij = sij.max(dim=-1, keepdim=True)[0] - pij = torch.exp(sij - mij).to(torch.bfloat16).to(torch.float32) - lij = pij.sum(dim=1, keepdim=True) - - if bn == 0: - oi = pij @ vj - li = lij - mi = mij - else: - mi_new = torch.maximum(mi, mij) - alpha = torch.exp(mi - mi_new) - beta = torch.exp(mij - mi_new) - li_new = alpha * li + beta * lij - oi_new = pij @ vj - oi = alpha * oi + beta * oi_new - li = li_new - mi = mi_new - - if bn == bn_this_batch - 1: - oi = oi / li - - out[base_idx : base_idx + q_tile_size] = oi - - return out 
+ + # Reshape for batched computation + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + + q_tile = min(num_heads_dim, 128) + + # Max blocks across all batches (each batch may have different context_len) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + # qi: (batch, q_tile_size, head_dim) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + + oi = None # (batch, q_tile_size, head_dim) + li = None # (batch, q_tile_size, 1) + mi = None # (batch, q_tile_size, 1) + + for bn in range(max_bn): + # valid_len per batch for this block position + valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 # (batch,) + + if not active_mask.any(): + break + + # Gather block indices for all batches + block_indices = block_table[:, bn] # (batch,) + + # Gather K and V: (batch, block_size, head_dim) + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + + # QK matmul: (batch, q_tile_size, block_size) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + + # Mask out invalid positions (beyond valid_len per batch) + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) # (1, block_size) + valid_mask = pos < valid_lens.unsqueeze(1) # (batch, block_size) + valid_mask = valid_mask.unsqueeze(1) # (batch, 1, block_size) + sij = sij.masked_fill(~valid_mask, float('-inf')) + + # Also mask inactive batches (no blocks at this position) + batch_mask = active_mask.view(-1, 1, 1) # (batch, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + + mij = sij.max(dim=-1, keepdim=True)[0] # (batch, q_tile_size, 1) + mij = mij.clamp(min=-1e30) + pij = 
torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) # (batch, q_tile_size, 1) + + # PV matmul: (batch, q_tile_size, head_dim) + oi_new = torch.bmm(pij, vj_all) + + if bn == 0: + oi = oi_new + li = lij + mi = mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + + # Final normalization + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + + return out.reshape(-1, head_dim) def compute_golden(tensors: dict, params: dict) -> None: diff --git a/tests/device_tests/tensormap_and_ringbuffer/paged_attention/golden.py b/tests/device_tests/tensormap_and_ringbuffer/paged_attention/golden.py index 61ab6a1d..99c98072 100644 --- a/tests/device_tests/tensormap_and_ringbuffer/paged_attention/golden.py +++ b/tests/device_tests/tensormap_and_ringbuffer/paged_attention/golden.py @@ -118,6 +118,9 @@ def paged_attention( """ Compute paged attention using online softmax with head tiling and GQA. + Vectorized across the batch dimension for performance. + Supports different context_lens per batch via masking. 
+ Args: query: (batch, num_heads, head_dim) bfloat16 key_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 @@ -129,65 +132,87 @@ def paged_attention( context_lens: (batch,) int32 Returns: - out: (batch, num_heads, head_dim) float32 + out: (batch * num_heads, head_dim) float32 """ assert num_kv_heads == 1 - batch, num_heads, head_dim = query.shape + batch, num_heads_dim, head_dim = query.shape _, block_size, _, _ = key_cache.shape - _, block_num = block_table.shape - - query = query.reshape(-1, head_dim) - key_cache = key_cache.reshape(-1, block_size, head_dim) - value_cache = value_cache.reshape(-1, block_size, head_dim) - - out = torch.zeros((batch * num_heads, head_dim), dtype=torch.float32) - - for b_idx in range(batch): - cur_seq = int(context_lens[b_idx]) - bn_this_batch = (cur_seq + block_size - 1) // block_size - assert bn_this_batch <= block_num - - q_tile = min(num_heads, 128) - for cur_offset in range(0, num_heads, q_tile): - q_tile_size = min(q_tile, num_heads - cur_offset) - base_idx = b_idx * num_heads + cur_offset - qi = query[base_idx : base_idx + q_tile_size].to(torch.float32) - - oi = None - li = None - mi = None - - for bn in range(bn_this_batch): - cur_block_idx = block_table[b_idx, bn] - valid_len = min(block_size, cur_seq - bn * block_size) - kj = key_cache[cur_block_idx, :valid_len, :].to(torch.float32) - vj = value_cache[cur_block_idx, :valid_len, :].to(torch.float32) - - sij = (qi @ kj.T) * scale_value - mij = sij.max(dim=-1, keepdim=True)[0] - pij = torch.exp(sij - mij).to(torch.bfloat16).to(torch.float32) - lij = pij.sum(dim=1, keepdim=True) - - if bn == 0: - oi = pij @ vj - li = lij - mi = mij - else: - mi_new = torch.maximum(mi, mij) - alpha = torch.exp(mi - mi_new) - beta = torch.exp(mij - mi_new) - li_new = alpha * li + beta * lij - oi_new = pij @ vj - oi = alpha * oi + beta * oi_new - li = li_new - mi = mi_new - - if bn == bn_this_batch - 1: - oi = oi / li - - out[base_idx : base_idx + q_tile_size] = oi - - return out 
+ + # Reshape for batched computation + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + + q_tile = min(num_heads_dim, 128) + + # Max blocks across all batches (each batch may have different context_len) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + # qi: (batch, q_tile_size, head_dim) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + + oi = None # (batch, q_tile_size, head_dim) + li = None # (batch, q_tile_size, 1) + mi = None # (batch, q_tile_size, 1) + + for bn in range(max_bn): + # valid_len per batch for this block position + valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 # (batch,) + + if not active_mask.any(): + break + + # Gather block indices for all batches + block_indices = block_table[:, bn] # (batch,) + + # Gather K and V: (batch, block_size, head_dim) + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + + # QK matmul: (batch, q_tile_size, block_size) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + + # Mask out invalid positions (beyond valid_len per batch) + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) # (1, block_size) + valid_mask = pos < valid_lens.unsqueeze(1) # (batch, block_size) + valid_mask = valid_mask.unsqueeze(1) # (batch, 1, block_size) + sij = sij.masked_fill(~valid_mask, float('-inf')) + + # Also mask inactive batches (no blocks at this position) + batch_mask = active_mask.view(-1, 1, 1) # (batch, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + + mij = sij.max(dim=-1, keepdim=True)[0] # (batch, q_tile_size, 1) + mij = mij.clamp(min=-1e30) + pij = 
torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) # (batch, q_tile_size, 1) + + # PV matmul: (batch, q_tile_size, head_dim) + oi_new = torch.bmm(pij, vj_all) + + if bn == 0: + oi = oi_new + li = lij + mi = mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + + # Final normalization + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + + return out.reshape(-1, head_dim) def compute_golden(tensors: dict, params: dict) -> None: