156 changes: 90 additions & 66 deletions examples/host_build_graph/paged_attention/golden.py
@@ -120,76 +120,101 @@ def paged_attention(
"""
Compute paged attention using online softmax with head tiling and GQA.

Vectorized across the batch dimension for performance.
Supports different context_lens per batch via masking.

Args:
query: (batch, num_heads, head_dim) float16
key_cache: (total_blocks, block_size, num_kv_heads, head_dim) float16
value_cache: (total_blocks, block_size, num_kv_heads, head_dim) float16
query: (batch, num_heads, head_dim) bfloat16
key_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16
value_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16
num_kv_heads: int
num_heads: int
scale_value: float
block_table: (batch, block_num) int32
context_lens: (batch,) int32

Returns:
out: (batch, num_heads, head_dim) float32
out: (batch * num_heads, head_dim) float32
"""
assert num_kv_heads == 1
batch, num_heads, head_dim = query.shape
batch, num_heads_dim, head_dim = query.shape
_, block_size, _, _ = key_cache.shape
_, block_num = block_table.shape

query = query.reshape(-1, head_dim)
key_cache = key_cache.reshape(-1, block_size, head_dim)
value_cache = value_cache.reshape(-1, block_size, head_dim)

out = torch.zeros((batch * num_heads, head_dim), dtype=torch.float32)

for b_idx in range(batch):
cur_seq = int(context_lens[b_idx])
bn_this_batch = (cur_seq + block_size - 1) // block_size
assert bn_this_batch <= block_num

q_tile = min(num_heads, 128)
for cur_offset in range(0, num_heads, q_tile):
q_tile_size = min(q_tile, num_heads - cur_offset)
base_idx = b_idx * num_heads + cur_offset
qi = query[base_idx : base_idx + q_tile_size].to(torch.float32)

oi = None
li = None
mi = None

for bn in range(bn_this_batch):
cur_block_idx = block_table[b_idx, bn]
valid_len = min(block_size, cur_seq - bn * block_size)
kj = key_cache[cur_block_idx, :valid_len, :].to(torch.float32)
vj = value_cache[cur_block_idx, :valid_len, :].to(torch.float32)

sij = (qi @ kj.T) * scale_value
mij = sij.max(dim=-1, keepdim=True)[0]
pij = torch.exp(sij - mij).to(torch.float16).to(torch.float32)
lij = torch.sum(pij, dim=1, keepdim=True)

if bn == 0:
oi = pij @ vj
li = lij
mi = mij
else:
mi_new = torch.maximum(mi, mij)
alpha = torch.exp(mi - mi_new)
beta = torch.exp(mij - mi_new)
li_new = alpha * li + beta * lij
oi_new = pij @ vj
oi = alpha * oi + beta * oi_new
li = li_new
mi = mi_new

if bn == bn_this_batch - 1:
oi = oi / li

out[base_idx : base_idx + q_tile_size] = oi

return out

# Reshape for batched computation
key_cache_flat = key_cache.reshape(-1, block_size, head_dim)
value_cache_flat = value_cache.reshape(-1, block_size, head_dim)

out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32)

q_tile = min(num_heads_dim, 128)

# Max blocks across all batches (each batch may have different context_len)
max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size)

for q_offset in range(0, num_heads_dim, q_tile):
q_tile_size = min(q_tile, num_heads_dim - q_offset)
# qi: (batch, q_tile_size, head_dim)
qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32)

oi = None # (batch, q_tile_size, head_dim)
li = None # (batch, q_tile_size, 1)
mi = None # (batch, q_tile_size, 1)

for bn in range(max_bn):
# valid_len per batch for this block position
valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size)
active_mask = valid_lens > 0 # (batch,)

if not active_mask.any():
break

# Gather block indices for all batches
block_indices = block_table[:, bn] # (batch,)

# Gather K and V: (batch, block_size, head_dim)
kj_all = key_cache_flat[block_indices].to(torch.float32)
vj_all = value_cache_flat[block_indices].to(torch.float32)

# QK matmul: (batch, q_tile_size, block_size)
sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value

# Mask out invalid positions (beyond valid_len per batch)
pos = torch.arange(block_size, device=sij.device).unsqueeze(0) # (1, block_size)
valid_mask = pos < valid_lens.unsqueeze(1) # (batch, block_size)
valid_mask = valid_mask.unsqueeze(1) # (batch, 1, block_size)
sij = sij.masked_fill(~valid_mask, float('-inf'))

# Also mask inactive batches (no blocks at this position)
batch_mask = active_mask.view(-1, 1, 1) # (batch, 1, 1)
sij = sij.masked_fill(~batch_mask, float('-inf'))

mij = sij.max(dim=-1, keepdim=True)[0] # (batch, q_tile_size, 1)
mij = mij.clamp(min=-1e30)
pij = torch.exp(sij - mij)
pij = pij.masked_fill(~valid_mask, 0.0)
pij = pij.masked_fill(~batch_mask, 0.0)
pij = pij.to(torch.bfloat16).to(torch.float32)
lij = pij.sum(dim=-1, keepdim=True) # (batch, q_tile_size, 1)

# PV matmul: (batch, q_tile_size, head_dim)
oi_new = torch.bmm(pij, vj_all)

if bn == 0:
oi = oi_new
li = lij
mi = mij
else:
mi_new = torch.maximum(mi, mij)
alpha = torch.exp(mi - mi_new)
beta = torch.exp(mij - mi_new)
li = alpha * li + beta * lij
oi = alpha * oi + beta * oi_new
mi = mi_new

# Final normalization
out[:, q_offset:q_offset + q_tile_size, :] = oi / li

return out.reshape(-1, head_dim)


def compute_golden(tensors: dict, params: dict) -> None:
@@ -203,13 +228,12 @@ def compute_golden(tensors: dict, params: dict) -> None:

max_num_blocks_per_req = max_model_len // block_size

# Reconstruct shaped arrays from flat float16 tensors
# Convert to torch tensors (handles both array types)
query = torch.as_tensor(tensors["query"]).reshape(batch, num_heads, head_dim)
key_cache = torch.as_tensor(tensors["key_cache"]).reshape(-1, block_size, kv_head_num, head_dim)
value_cache = torch.as_tensor(tensors["value_cache"]).reshape(-1, block_size, kv_head_num, head_dim)
block_table = torch.as_tensor(tensors["block_table"]).reshape(batch, max_num_blocks_per_req)
context_lens = torch.as_tensor(tensors["context_lens"])
# Reconstruct shaped tensors from flat tensors
query = tensors["query"].reshape(batch, num_heads, head_dim)
key_cache = tensors["key_cache"].reshape(-1, block_size, kv_head_num, head_dim)
value_cache = tensors["value_cache"].reshape(-1, block_size, kv_head_num, head_dim)
block_table = tensors["block_table"].reshape(batch, max_num_blocks_per_req)
context_lens = tensors["context_lens"]

out = paged_attention(
query=query,
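For reference, the vectorized path can be sanity-checked against plain softmax attention over the gathered blocks. A minimal sketch — the shapes, the scale value, and the `naive_attention` helper are illustrative, not part of this change:

```python
import torch

def naive_attention(q, k, v, scale):
    # q: (heads, head_dim); k, v: (seq, head_dim); plain float32 attention
    s = (q @ k.T) * scale
    return torch.softmax(s, dim=-1) @ v

torch.manual_seed(0)
batch, num_heads, head_dim, block_size = 2, 4, 16, 8
query = torch.randn(batch, num_heads, head_dim, dtype=torch.bfloat16)
key_cache = torch.randn(6, block_size, 1, head_dim, dtype=torch.bfloat16)
value_cache = torch.randn(6, block_size, 1, head_dim, dtype=torch.bfloat16)
context_lens = torch.tensor([13, 5], dtype=torch.int32)   # ragged on purpose
block_table = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.int32)

out = paged_attention(query=query, key_cache=key_cache, value_cache=value_cache,
                      num_kv_heads=1, num_heads=num_heads, scale_value=0.125,
                      block_table=block_table, context_lens=context_lens)

for b in range(batch):
    n = int(context_lens[b])
    # Gather this batch's blocks and truncate to its context length
    blocks = block_table[b, : (n + block_size - 1) // block_size].long()
    k = key_cache[blocks].reshape(-1, head_dim)[:n].float()
    v = value_cache[blocks].reshape(-1, head_dim)[:n].float()
    ref = naive_attention(query[b].float(), k, v, 0.125)
    # Loose tolerance: pij is rounded through bfloat16 in the paged path
    torch.testing.assert_close(out[b * num_heads:(b + 1) * num_heads], ref,
                               rtol=2e-2, atol=2e-2)
```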
10 changes: 10 additions & 0 deletions examples/scripts/code_runner.py
@@ -36,6 +36,7 @@ def compute_golden(tensors: dict, params: dict) -> None:
import logging
import os
import sys
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@@ -629,6 +630,7 @@ def run(self) -> None:
runtime.enable_profiling(True)
logger.info("Profiling enabled")

_t_init_start = time.perf_counter()
with _temporary_env(run_env):
runtime.initialize(
orch_so_binary,
@@ -638,12 +640,20 @@
arg_sizes=arg_sizes,
kernel_binaries=kernel_binaries,
)
_t_init_end = time.perf_counter()
logger.info(f">>> runtime.initialize() took {_t_init_end - _t_init_start:.3f}s")

# Save expected values BEFORE hardware execution (outputs will be overwritten)
golden = {k: v.clone() for k, v in outputs.items()}
# Convert to dict for compute_golden (may expect numpy-like interface)
golden_with_inputs = {**inputs, **golden}
_t_golden_start = time.perf_counter()
self._golden_module.compute_golden(golden_with_inputs, params)
_t_golden_end = time.perf_counter()
logger.info(f">>> compute_golden() took {_t_golden_end - _t_golden_start:.3f}s")
logger.info(f">>> Total init-to-launch: {_t_golden_end - _t_init_start:.3f}s "
f"(initialize={_t_init_end - _t_init_start:.3f}s, "
f"golden={_t_golden_end - _t_golden_start:.3f}s)")

# Launch runtime
logger.info("=== Launching Runtime ===")
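The perf_counter/logger pairs added here all follow one pattern; since the module already imports `contextmanager`, they could be expressed once with a small helper. A sketch — the `_timed` name is hypothetical, not part of this change:

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)

@contextmanager
def _timed(label: str):
    """Log the wall-clock duration of the enclosed block."""
    start = time.perf_counter()
    try:
        yield
    finally:
        logger.info(f">>> {label} took {time.perf_counter() - start:.3f}s")

# Usage, mirroring the instrumented calls above:
# with _timed("runtime.initialize()"):
#     runtime.initialize(...)
# with _timed("compute_golden()"):
#     self._golden_module.compute_golden(golden_with_inputs, params)
```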
143 changes: 84 additions & 59 deletions examples/tensormap_and_ringbuffer/paged_attention/golden.py
@@ -117,76 +117,101 @@ def paged_attention(
"""
Compute paged attention using online softmax with head tiling and GQA.

Vectorized across the batch dimension for performance.
Supports different context_lens per batch via masking.

Args:
query: (batch, num_heads, head_dim) float16
key_cache: (total_blocks, block_size, num_kv_heads, head_dim) float16
value_cache: (total_blocks, block_size, num_kv_heads, head_dim) float16
query: (batch, num_heads, head_dim) bfloat16
key_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16
value_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16
num_kv_heads: int
num_heads: int
scale_value: float
block_table: (batch, block_num) int32
context_lens: (batch,) int32

Returns:
out: (batch, num_heads, head_dim) float32
out: (batch * num_heads, head_dim) float32
"""
assert num_kv_heads == 1
batch, num_heads, head_dim = query.shape
batch, num_heads_dim, head_dim = query.shape
_, block_size, _, _ = key_cache.shape
_, block_num = block_table.shape

query = query.reshape(-1, head_dim)
key_cache = key_cache.reshape(-1, block_size, head_dim)
value_cache = value_cache.reshape(-1, block_size, head_dim)

out = torch.zeros((batch * num_heads, head_dim), dtype=torch.float32)

for b_idx in range(batch):
cur_seq = int(context_lens[b_idx])
bn_this_batch = (cur_seq + block_size - 1) // block_size
assert bn_this_batch <= block_num

q_tile = min(num_heads, 128)
for cur_offset in range(0, num_heads, q_tile):
q_tile_size = min(q_tile, num_heads - cur_offset)
base_idx = b_idx * num_heads + cur_offset
qi = query[base_idx : base_idx + q_tile_size].to(torch.float32)

oi = None
li = None
mi = None

for bn in range(bn_this_batch):
cur_block_idx = block_table[b_idx, bn]
valid_len = min(block_size, cur_seq - bn * block_size)
kj = key_cache[cur_block_idx, :valid_len, :].to(torch.float32)
vj = value_cache[cur_block_idx, :valid_len, :].to(torch.float32)

sij = (qi @ kj.T) * scale_value
mij = sij.max(dim=-1, keepdim=True)[0]
pij = torch.exp(sij - mij).to(torch.float16).to(torch.float32)
lij = pij.sum(dim=1, keepdim=True)

if bn == 0:
oi = pij @ vj
li = lij
mi = mij
else:
mi_new = torch.maximum(mi, mij)
alpha = torch.exp(mi - mi_new)
beta = torch.exp(mij - mi_new)
li_new = alpha * li + beta * lij
oi_new = pij @ vj
oi = alpha * oi + beta * oi_new
li = li_new
mi = mi_new

if bn == bn_this_batch - 1:
oi = oi / li

out[base_idx : base_idx + q_tile_size] = oi

return out

# Reshape for batched computation
key_cache_flat = key_cache.reshape(-1, block_size, head_dim)
value_cache_flat = value_cache.reshape(-1, block_size, head_dim)

out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32)

q_tile = min(num_heads_dim, 128)

# Max blocks across all batches (each batch may have different context_len)
max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size)

for q_offset in range(0, num_heads_dim, q_tile):
q_tile_size = min(q_tile, num_heads_dim - q_offset)
# qi: (batch, q_tile_size, head_dim)
qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32)

oi = None # (batch, q_tile_size, head_dim)
li = None # (batch, q_tile_size, 1)
mi = None # (batch, q_tile_size, 1)

for bn in range(max_bn):
# valid_len per batch for this block position
valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size)
active_mask = valid_lens > 0 # (batch,)

if not active_mask.any():
break

# Gather block indices for all batches
block_indices = block_table[:, bn] # (batch,)

# Gather K and V: (batch, block_size, head_dim)
kj_all = key_cache_flat[block_indices].to(torch.float32)
vj_all = value_cache_flat[block_indices].to(torch.float32)

# QK matmul: (batch, q_tile_size, block_size)
sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value

# Mask out invalid positions (beyond valid_len per batch)
pos = torch.arange(block_size, device=sij.device).unsqueeze(0) # (1, block_size)
valid_mask = pos < valid_lens.unsqueeze(1) # (batch, block_size)
valid_mask = valid_mask.unsqueeze(1) # (batch, 1, block_size)
sij = sij.masked_fill(~valid_mask, float('-inf'))

# Also mask inactive batches (no blocks at this position)
batch_mask = active_mask.view(-1, 1, 1) # (batch, 1, 1)
sij = sij.masked_fill(~batch_mask, float('-inf'))

mij = sij.max(dim=-1, keepdim=True)[0] # (batch, q_tile_size, 1)
mij = mij.clamp(min=-1e30)
pij = torch.exp(sij - mij)
pij = pij.masked_fill(~valid_mask, 0.0)
pij = pij.masked_fill(~batch_mask, 0.0)
pij = pij.to(torch.bfloat16).to(torch.float32)
lij = pij.sum(dim=-1, keepdim=True) # (batch, q_tile_size, 1)

# PV matmul: (batch, q_tile_size, head_dim)
oi_new = torch.bmm(pij, vj_all)

if bn == 0:
oi = oi_new
li = lij
mi = mij
else:
mi_new = torch.maximum(mi, mij)
alpha = torch.exp(mi - mi_new)
beta = torch.exp(mij - mi_new)
li = alpha * li + beta * lij
oi = alpha * oi + beta * oi_new
mi = mi_new

# Final normalization
out[:, q_offset:q_offset + q_tile_size, :] = oi / li

return out.reshape(-1, head_dim)


def compute_golden(tensors: dict, params: dict) -> None:
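The alpha/beta update above is the standard online-softmax merge: rescaling both partial results to the shared running max makes the combination equal to the softmax over the concatenated key blocks. A standalone numeric check of that identity (illustrative shapes):

```python
import torch

torch.manual_seed(0)
d = 16
s1, s2 = torch.randn(4, 8), torch.randn(4, 8)   # scores for two key blocks
v1, v2 = torch.randn(8, d), torch.randn(8, d)

def partial(s, v):
    # Unnormalized per-block result: output o, weight sum l, running max m
    m = s.max(dim=-1, keepdim=True)[0]
    p = torch.exp(s - m)
    return p @ v, p.sum(dim=-1, keepdim=True), m

o1, l1, m1 = partial(s1, v1)
o2, l2, m2 = partial(s2, v2)

# Merge: rescale both sides to the new running max, then combine
m = torch.maximum(m1, m2)
alpha, beta = torch.exp(m1 - m), torch.exp(m2 - m)
o = alpha * o1 + beta * o2
l = alpha * l1 + beta * l2

# Reference: direct softmax over the concatenated scores
ref = torch.softmax(torch.cat([s1, s2], dim=-1), dim=-1) @ torch.cat([v1, v2], dim=0)
torch.testing.assert_close(o / l, ref, rtol=1e-5, atol=1e-6)
```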
9 changes: 9 additions & 0 deletions src/platform/a2a3/aicpu/device_log.cpp
@@ -61,3 +61,12 @@ void dev_log_error(const char* func, const char* fmt, ...) {
va_end(args);
dlog_error(AICPU, "%lu %s\n\"%s\"", GET_TID(), func, buffer);
}

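// Always-on log helper: formats like dev_log_error and emits on the
// error-level channel, which is assumed to bypass verbosity filtering.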
void dev_log_always(const char* func, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
char buffer[2048];
vsnprintf(buffer, sizeof(buffer), fmt, args);
va_end(args);
dlog_error(AICPU, "%lu %s\n\"%s\"", GET_TID(), func, buffer);
}