30 changes: 29 additions & 1 deletion models/common/rmsnorm.py
@@ -52,6 +52,7 @@ def __init__(
output_mem_config=None,
ccl_topology=ttnn.Topology.Ring,
tt_ccl=None,
simplified_rms=False,
):
super().__init__()
self.device = device
@@ -114,13 +115,21 @@ def __init__(
fp32_dest_acc_en=True,
packer_l1_acc=True,
)
self.simplified_rms = simplified_rms

def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor:
# If the input is sharded, run sharded RMSNorm and optionally return a sharded output
program_config = self.sharded_program_config if in_sharded else None
memory_config = self.sharded_output_config if out_sharded else None
distributed = self.is_distributed and self.is_distributed(mode)
norm = self._distributed_rmsnorm if distributed else ttnn.rms_norm
norm = (
self._simplified_rmsnorm
if self.simplified_rms
else self._distributed_rmsnorm
if distributed
else ttnn.rms_norm
)

weight = self.weight_distributed if distributed else self.weight

if in_sharded:
@@ -142,6 +151,25 @@ def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) ->
else:
return x

def _simplified_rmsnorm(
self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None
):
# Composite RMSNorm: x * rsqrt(mean(x^2) + eps) * weight, built from elementwise ttnn ops
inp = ttnn.sharded_to_interleaved(inp, ttnn.DRAM_MEMORY_CONFIG)
xnorm = ttnn.pow(inp, 2)
xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True)
xnorm = ttnn.rsqrt(xnorm + epsilon)
xnorm = ttnn.multiply(inp, xnorm)
# Reshape the learned scale so it broadcasts over the leading dims
weight = ttnn.reshape(weight, [1, 1, -1])
output = ttnn.multiply(xnorm, weight, use_legacy=False)

if memory_config is not None:
output = ttnn.to_memory_config(output, memory_config)

ttnn.deallocate(xnorm)
ttnn.deallocate(weight)

return output

def _distributed_rmsnorm(
self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None
):
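For reference, a minimal plain-PyTorch sketch of the computation the new _simplified_rmsnorm path performs. This is an illustration only, under the assumption that weight is a 1-D tensor of length dim and eps matches the epsilon passed at call time; it is not part of the diff.

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm: x * rsqrt(mean(x^2) + eps), reduced over the last dimension
    xnorm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    # The 1-D weight broadcasts over all leading dimensions
    return xnorm * weight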
128 changes: 128 additions & 0 deletions models/experimental/qwen25_vl/tests/test_attention.py
@@ -0,0 +1,128 @@
"""Test for Qwen 2.5 VL Vision Attention"""

import os

import pytest
import torch
from loguru import logger

import ttnn
from models.tt_transformers.tt.model_config import ModelArgs

from models.experimental.qwen25_vl.tt.attention import TtQwen2_5_VLVisionSdpaAttention
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull


@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
"batch, num_chunks",
((1, 4),),
)
@pytest.mark.parametrize(
"mesh_device",
[
{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
)
],
indirect=True,
)
def test_attention_inference(batch, num_chunks, mesh_device, reset_seeds):
dtype = ttnn.bfloat16
pcc_required = 0.99

model_args = ModelArgs(mesh_device)
state_dict = model_args.load_state_dict()

# The reference model needs a partial state dict, but our models use full state-dict keys as cached weight names
first_layer_prefix = "visual.blocks.0.attn."
partial_state_dict = {
k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix))
}

dim = model_args.vision_dim

reference_model = model_args.reference_vision_attention()
reference_model.load_state_dict(partial_state_dict)
reference_model.eval()

hidden_size = model_args.vision_dim
n_heads = model_args.vision_attn_n_heads
head_dim = hidden_size // n_heads
seq_len = model_args.vision_chunk_ntok

tt_model = TtQwen2_5_VLVisionSdpaAttention(
mesh_device,
state_dict,
state_dict_prefix=first_layer_prefix,
# weight_cache_path=model_args.weight_cache_path(dtype),
dtype=dtype,
configuration=model_args,
)

# Fixed test dimensions (these override the model_args-derived values above)
seq_len = 4096
hidden_dim = 1280
num_heads = 16
head_dim = hidden_dim // num_heads # 80
rotary_dim = head_dim // 2 # 40

# Step 1: PyTorch input
pt_attention_input = torch.randn(seq_len, hidden_dim) # no batch dim
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32)

# Step 2: precompute cos/sin
cos, sin = precompute_rope_cos_sin(seq_len, head_dim)

# Step 3: run PyTorch reference
reference_output = reference_model(
pt_attention_input, cu_seqlens, rotary_pos_emb=None, position_embeddings=(cos, sin)
)

# Step 4: TT input
tt_attention_input = model_args.prepare_residual_tensor_prefill(
pt_attention_input.unsqueeze(0), force_replicated=True
)

# Step 5: move cos/sin to the device
cos_tensor = ttnn.from_torch(cos, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
sin_tensor = ttnn.from_torch(sin, device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

# Step 6: run TT
tt_out = tt_model(tt_attention_input, cu_seqlens, position_embeddings=(cos_tensor, sin_tensor))

# Convert the TT output back to torch for comparison with the reference
tt_output_torch = ttnn.to_torch(tt_out, device=mesh_device).squeeze(0)

passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required)

logger.info(comp_allclose(reference_output, tt_output_torch))
logger.info(f"PCC: {pcc_message}")

assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"


def precompute_rope_cos_sin(seq_len: int, dim: int, theta: float = 10000.0):
"""
Precompute RoPE cos/sin tensors.
Args:
seq_len: sequence length (number of tokens)
dim: per-head dimension (head_dim, not the full hidden size)
theta: RoPE theta parameter (default 10000)
Returns:
cos, sin: [seq_len, dim] each
"""
# Build the rope frequencies
half_dim = dim // 2
freq_seq = torch.arange(half_dim, dtype=torch.float32)
inv_freq = 1.0 / (theta ** (freq_seq / half_dim))

# positions: [seq_len]
positions = torch.arange(seq_len, dtype=torch.float32)

# Outer product: [seq_len, half_dim]
sinusoid_inp = torch.outer(positions, inv_freq)

# Duplicate the angles across both halves so cos/sin have shape [seq_len, dim]
sin = torch.sin(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1))
cos = torch.cos(torch.cat([sinusoid_inp, sinusoid_inp], dim=-1))

return cos, sin
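
For context, a hedged sketch of how the cos/sin tensors produced above are typically consumed with a rotate-half RoPE. This is a generic illustration; the helper names and shapes here are assumptions drawn from the test (seq_len=4096, 16 heads, head_dim=80), not the TT attention module's actual code path.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half and swap the halves with a sign flip
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q: [seq_len, n_heads, head_dim]; cos/sin: [seq_len, head_dim]
    return q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)

cos, sin = precompute_rope_cos_sin(seq_len=4096, dim=80)
q = torch.randn(4096, 16, 80)
q_rot = apply_rope(q, cos, sin)  # same shape as q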