164 changes: 115 additions & 49 deletions dlinfer/vendor/maca/maca_ops.py
@@ -3,15 +3,32 @@
import torch
import torch.distributed as dist

-from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from flash_attn import flash_attn_varlen_func
from .context_flashattention import context_attention_fwd

from dlinfer.vendor import vendor_ops_registry
from dlinfer.utils.registry import register_ops
from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple

from .fused_moe import fused_experts
from .maca_extension import ops as maca_ext_ops
+from mcoplib import lmdeploy as mcoplib_ops
+from mcoplib import op as op_origin
+import mcoplib._C

+env_value = os.getenv("MACA_LMDEPLOY_MCOPLIB_OPS", "false")
+USE_MCOPLIB_OPS = env_value.lower() in ("true", "1", "yes", "on")
+
+# Select the ops library based on the environment variable.
+if USE_MCOPLIB_OPS:
+    ops = mcoplib_ops
+    ops_name = "mcoplib_ops"
+else:
+    ops = maca_ext_ops
+    ops_name = "maca_ext_ops"
+
+# Log the environment variable value and the selected ops library.
+print(f"[DLInfer] MACA_LMDEPLOY_MCOPLIB_OPS={env_value} (USE_MCOPLIB_OPS={USE_MCOPLIB_OPS})")
+print(f"[DLInfer] Using ops library: {ops_name}")

__all__ = [
"add_rms_norm",
@@ -58,7 +75,8 @@ def add_rms_norm(
    weight: Tensor,
    epsilon: float,
) -> Tuple[Tensor, Tensor]:
-    maca_ext_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
+    # Call the fused kernel directly; ops.fused_add_rms_norm is the
+    # backend-dispatched alternative.
+    torch.ops._C.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
    return hidden_states, residual


@@ -76,16 +94,30 @@ def apply_rotary_pos_emb(
    query = query.flatten(-2, -1)
    key = key.flatten(-2, -1)
    rot_dim = cos.size(-1)
+    if USE_MCOPLIB_OPS:
+        ops.lmdeploy_rotary_embedding(
+            position_ids_1d,
+            query,
+            key,
+            head_size,
+            cos.view(-1, rot_dim),
+            sin.view(-1, rot_dim),
+            True,
+        )
+    else:
+        ops.rotary_embedding(
+            position_ids_1d,
+            query,
+            key,
+            head_size,
+            cos.view(-1, rot_dim),
+            sin.view(-1, rot_dim),
+            True,
+        )

-    maca_ext_ops.rotary_embedding(
-        position_ids_1d,
-        query,
-        key,
-        head_size,
-        cos.view(-1, rot_dim),
-        sin.view(-1, rot_dim),
-        True,
-    )
    return query, key


@@ -200,15 +232,13 @@ def fill_kv_cache(
    quant_bits: int,
) -> Tuple[Tensor, Tensor]:
    kv_indices = kv_indices.squeeze(-1)
-    maca_ext_ops.reshape_and_cache_flash(
-        key,
-        value,
-        key_cache,
-        value_cache,
-        kv_indices,
-        "auto",
-        torch.tensor(1.0),
-        torch.tensor(1.0),
+    # k_scale/v_scale are used only by the commented torch.ops._C_cache_ops path.
+    k_scale = torch.tensor(1.0)
+    v_scale = torch.tensor(1.0)
+    # torch.ops._C_cache_ops.reshape_and_cache(
+    #     key, value, key_cache, value_cache, kv_indices, "auto", k_scale, v_scale
+    # )
+    ops.reshape_and_cache_new(
+        key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0
    )
    return key_cache, value_cache

@@ -247,17 +277,27 @@ def paged_decode_attention(
        -1, num_kv_heads, 576, block_size
    )

-    output = flash_attn_with_kvcache(
-        q=query.unsqueeze(1),
-        k_cache=key_cache,  # [num_blocks, block_size, num_heads, head_size]
-        v_cache=value_cache,  # [num_blocks, block_size, num_heads, head_size]
-        block_table=block_table,
-        cache_seqlens=kv_seq_len,
-        softmax_scale=softmax_scale,
-        causal=True,
-        softcap=0,
-    ).squeeze(1)

+    ops.paged_attention_v1(
+        output,
+        query,
+        key_cache,
+        value_cache,
+        num_kv_heads,
+        softmax_scale,
+        block_table,
+        kv_seq_len,
+        block_size,
+        max_kv_seq_len,
+        None,  # alibi_slopes
+        "auto",  # kv_cache_dtype
+        1.0,  # k_scale
+        1.0,  # v_scale
+        torch.cuda.current_device(),  # tp_rank
+        0,  # blocksparse_local_blocks
+        1,  # blocksparse_vert_stride
+        1,  # blocksparse_block_size
+        1,  # blocksparse_head_sliding_step
+    )
    if is_mla:
        return output[..., :512]
    else:
@@ -325,6 +365,7 @@ def paged_prefill_attention(
        key,
        value,
        output,
+        "auto",
        key_cache,
        value_cache,
        b_loc=block_table,
@@ -347,8 +388,11 @@ def rms_norm(
    hidden_states = hidden_states.to(torch.float32)
    weight = weight.to(torch.float32)
    output = torch.empty_like(hidden_states)
-    maca_ext_ops.rms_norm(output, hidden_states, weight, epsilon)
+    if USE_MCOPLIB_OPS:
+        op_origin.rms_norm(output, hidden_states, weight, epsilon, None, None, False)
+    else:
+        ops.rms_norm(output, hidden_states, weight, epsilon)
    return output.to(input_dtype)


@@ -366,12 +410,19 @@ def moe_gating_topk_softmax(

    token_expert_indicies = torch.empty_like(topk_ids)

-    maca_ext_ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        router_logits.float(),
-    )
+    if USE_MCOPLIB_OPS:
+        op_origin.moe_softmax_topk(
+            topk_weights,
+            topk_ids,
+            router_logits.float(),
+        )
+    else:
+        ops.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            router_logits.float(),
+        )

    del token_expert_indicies  # Not used. Will be used in the future.
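
For reference, both kernels fill `topk_weights`/`topk_ids` with a softmax over the router logits followed by top-k. A pure-PyTorch sketch of that assumed contract:

```python
import torch

def topk_softmax_ref(router_logits: torch.Tensor, top_k: int):
    # Softmax over the expert dimension, then keep the top_k
    # probabilities and their expert ids.
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    return topk_weights, topk_ids
```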

@@ -388,7 +439,8 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-    maca_ext_ops.silu_and_mul(out, x)
+    # Call the fused kernel directly; ops.silu_and_mul is the
+    # backend-dispatched alternative.
+    torch.ops._C.silu_and_mul(out, x)
    return out
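
The fused kernel halves the last dimension; a pure-PyTorch reference for the assumed SiLU-gate semantics:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half: SiLU(first half) * second half.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up
```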


@@ -402,15 +454,27 @@ def fused_moe(
    top_k: int,
    renormalize: bool,
) -> Tensor:
-    N = hidden_states.size(0)
-    topk_weights = topk_weights.reshape(N, top_k)
-    topk_ids = topk_ids.reshape(N, top_k)
+    N, D = hidden_states.shape
+    hidden_states = hidden_states.view(N, -1, D).repeat(1, top_k, 1).reshape(-1, D)
+    out = torch.zeros(
+        N * top_k,
+        down_weights.shape[1],
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

-    return fused_experts(
-        hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids
-    )
+    for i in range(gate_up_weights.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            out[mask] = silu_and_mul(
+                hidden_states[mask] @ gate_up_weights[i].transpose(0, 1)
+            ) @ down_weights[i].transpose(0, 1)
+    return (
+        out.view(N, -1, down_weights.shape[1])
+        * topk_weights.view(N, -1, 1).to(out.dtype)
+    ).sum(dim=1)
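
The loop computes each expert's MLP for the tokens routed to it (gate/up projection, SiLU-gated multiply, down projection), then the final reshape applies the routing weights and sums the top_k expert outputs per token. A self-contained shape sketch, with the weight layouts stated as assumptions (`gate_up_weights` as `[E, 2*inter, D]`, `down_weights` as `[E, D, inter]`, flat `topk_ids` of length `N * top_k`):

```python
import torch
import torch.nn.functional as F

E, N, D, inter, top_k = 4, 3, 8, 16, 2   # experts, tokens, hidden, intermediate, top-k
x = torch.randn(N, D)
gate_up = torch.randn(E, 2 * inter, D)   # assumed layout: output features in rows
down = torch.randn(E, D, inter)
topk_weights = torch.rand(N, top_k)
topk_ids = torch.randint(0, E, (N * top_k,))

# One copy of each token per routed expert, as in the diff above.
x_rep = x.view(N, 1, D).repeat(1, top_k, 1).reshape(-1, D)
out = torch.zeros(N * top_k, down.shape[1])
for i in range(E):
    mask = topk_ids == i
    if mask.sum():
        h = x_rep[mask] @ gate_up[i].T               # [tokens_i, 2*inter]
        gate, up = h.chunk(2, dim=-1)
        out[mask] = (F.silu(gate) * up) @ down[i].T  # [tokens_i, D]

# Weight each expert output by its routing weight and sum over top_k.
result = (out.view(N, top_k, -1) * topk_weights.view(N, top_k, 1)).sum(dim=1)
assert result.shape == (N, D)
```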


@register_ops(vendor_ops_registry)
@@ -421,7 +485,7 @@ def linear(
    all_reduce: Optional[bool],
    group: Optional[str],
) -> Tensor:
    if os.getenv("DLINFER_LINEAR_USE_NN_LAYOUT", "0") == "1":
        out = torch.matmul(x, weight)
        if bias is not None:
            out += bias
@@ -430,3 +494,5 @@
    if all_reduce:
        dist.all_reduce(out)
    return out