3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -24,4 +24,7 @@ elseif(NOT DEVICE IN_LIST SUPPORTED_DEVICE)
message(FATAL_ERROR "Device ${DEVICE} is not supported! Supported devices: ${SUPPORTED_DEVICE}")
endif()

add_subdirectory(dlinfer/vendor/${DEVICE})
add_subdirectory(dlinfer/graph/dicp/vendor)

install(CODE "message(STATUS \"Install completed for device: ${DEVICE}\")")
4 changes: 2 additions & 2 deletions dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py
@@ -19,11 +19,11 @@ def MacaCudaGraphMixin_make_buffers_cudagraph(
num_blocks = graph_meta.num_blocks
device = graph_meta.device
input_buffers: BuffType = dict()
input_buffers["input_ids"] = torch.empty(
input_buffers["input_ids"] = torch.zeros(
1, max_tokens, dtype=torch.int32, device=device
)

input_buffers["position_ids"] = torch.empty(
input_buffers["position_ids"] = torch.zeros(
(1, max_tokens), dtype=torch.int32, device=device
)

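A note on the torch.empty -> torch.zeros change above: CUDA graph replays reuse the same persistent input buffers, so slots that are not overwritten before a replay keep whatever the buffer already held. Zero-initialising input_ids and position_ids gives the padded tail a defined value at capture time. A minimal sketch of the buffer pattern, with illustrative sizes (only the buffer shapes and dtypes come from the diff):

import torch

max_tokens, device = 256, "cuda"  # illustrative values

# Persistent input buffers reused across CUDA graph replays;
# zeros (not empty) so padded positions hold a defined value.
input_buffers = {
    "input_ids": torch.zeros(1, max_tokens, dtype=torch.int32, device=device),
    "position_ids": torch.zeros((1, max_tokens), dtype=torch.int32, device=device),
}

# Before each replay, copy the real tokens into the front of the buffer;
# the zero-initialised tail acts as padding.
new_ids = torch.arange(7, dtype=torch.int32, device=device)
input_buffers["input_ids"][0, : new_ids.numel()].copy_(new_ids)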
1 change: 1 addition & 0 deletions dlinfer/vendor/maca/CMakeLists.txt
@@ -35,6 +35,7 @@ ExternalProject_Add(${MACA_SUB_MODULE}
BUILD_ALWAYS ON
USES_TERMINAL_BUILD ON
USES_TERMINAL_INSTALL ON
INSTALL_COMMAND ""
CMAKE_ARGS
"-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}"
"-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}"
22 changes: 10 additions & 12 deletions dlinfer/vendor/maca/fused_moe.py
@@ -4,13 +4,10 @@
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import triton
import triton.language as tl

from .maca_extension import ops as maca_ext_ops

import logging

logger = logging.getLogger(__name__)
@@ -251,9 +248,13 @@ def moe_align_block_size(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

maca_ext_ops.moe_align_block_size(
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
torch.ops._moe_C.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)

return sorted_ids, expert_ids, num_tokens_post_pad
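For context on the torch.ops._moe_C.moe_align_block_size call above: the op sorts and pads token-to-expert assignments so that every block of block_size tokens is routed to a single expert, writing its results into the pre-allocated sorted_ids, expert_ids and num_tokens_post_pad buffers. A hedged sketch of how it is invoked, assuming the _moe_C extension has been registered (for example via import mcoplib._moe_C, as maca_ops.py does); the buffer sizing mirrors the function above and the concrete numbers are illustrative:

import torch
import mcoplib._moe_C  # assumed to register torch.ops._moe_C, as in maca_ops.py

topk_ids = torch.randint(0, 8, (16, 2), dtype=torch.int32, device="cuda")
num_experts, block_size = 8, 16

# Worst case: every expert's token count is padded up to a block boundary.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_m_blocks = -(-max_num_tokens_padded // block_size)  # ceiling division

sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device)
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device)
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)

torch.ops._moe_C.moe_align_block_size(
    topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
)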
@@ -460,8 +461,7 @@ def fused_topk(
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)

maca_ext_ops.topk_softmax(
torch.ops._moe_C.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
@@ -796,8 +796,7 @@ def fused_experts_impl(
use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape,
)

maca_ext_ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

invoke_fused_moe_kernel(
intermediate_cache2,
@@ -818,8 +817,7 @@
use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape,
)

maca_ext_ops.moe_sum(
torch.ops._moe_C.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
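Taken together, the fused_moe.py changes above stop routing MoE kernels through the bundled maca_extension module and instead call custom ops registered on the torch.ops namespaces _C (silu_and_mul) and _moe_C (moe_align_block_size, topk_softmax, moe_sum); maca_ops.py makes those namespaces available by importing mcoplib._C and mcoplib._moe_C. A small sanity-check sketch, assuming an mcoplib installation that registers these ops:

import torch
import mcoplib._C      # assumed to register torch.ops._C ops such as silu_and_mul
import mcoplib._moe_C  # assumed to register the torch.ops._moe_C ops used below

# Verify that every op this file now calls resolved after import.
for namespace, op_name in [("_C", "silu_and_mul"),
                           ("_moe_C", "moe_align_block_size"),
                           ("_moe_C", "topk_softmax"),
                           ("_moe_C", "moe_sum")]:
    ns = getattr(torch.ops, namespace)
    assert hasattr(ns, op_name), f"torch.ops.{namespace}.{op_name} is not registered"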
51 changes: 26 additions & 25 deletions dlinfer/vendor/maca/maca_ops.py
@@ -11,7 +11,10 @@
from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple

from .fused_moe import fused_experts
from .maca_extension import ops as maca_ext_ops
from mcoplib import lmdeploy as ops
from mcoplib import op as op_origin
import mcoplib._C
import mcoplib._moe_C

__all__ = [
"add_rms_norm",
@@ -58,7 +61,7 @@ def add_rms_norm(
weight: Tensor,
epsilon: float,
) -> Tuple[Tensor, Tensor]:
maca_ext_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
torch.ops._C.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
return hidden_states, residual
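For reference, fused_add_rms_norm works in place, which is why add_rms_norm can simply return the two tensors it passed in: assuming the op follows the usual fused add-plus-RMSNorm semantics, residual is overwritten with hidden_states + residual and hidden_states with the RMS-normalised sum. An unfused pure-PyTorch sketch of that behaviour, for illustration only:

import torch

def add_rms_norm_reference(hidden_states, residual, weight, epsilon):
    # residual <- hidden_states + residual (kept for the next skip connection)
    residual = hidden_states + residual
    # hidden_states <- RMSNorm(residual) * weight
    x = residual.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    hidden_states = (x * torch.rsqrt(variance + epsilon)).to(residual.dtype) * weight
    return hidden_states, residual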


@@ -76,8 +79,7 @@ def apply_rotary_pos_emb(
query = query.flatten(-2, -1)
key = key.flatten(-2, -1)
rot_dim = cos.size(-1)

maca_ext_ops.rotary_embedding(
ops.lmdeploy_rotary_embedding(
position_ids_1d,
query,
key,
Expand All @@ -86,6 +88,7 @@ def apply_rotary_pos_emb(
sin.view(-1, rot_dim),
True,
)

return query, key


@@ -161,6 +164,13 @@ def prefill_attention(
)
softmax_scale = float(1 / math.sqrt(head_dim))

# for qwen vl part.
if q_start_loc.shape[0] == q_seq_len.shape[0]:
causal = False
q_start_loc = torch.cat(
[q_start_loc, q_seq_len.sum().to(torch.int32).unsqueeze(0)]
)

output = flash_attn_varlen_func(
query,
key,
@@ -173,6 +183,7 @@
causal=causal,
window_size=(-1, -1),
)
attn_output.copy_(output)
return output
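A note on the Qwen-VL branch added to prefill_attention above: flash_attn_varlen_func expects cumulative sequence lengths with one more entry than there are sequences (ending in the total token count), so when q_start_loc arrives with the same length as q_seq_len the code appends that total; causal is also forced to False, matching bidirectional attention over vision tokens. The added attn_output.copy_(output) additionally writes the result into the caller-provided buffer. A small sketch of the cu_seqlens construction, with illustrative values:

import torch

q_seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)    # per-sequence token counts
q_start_loc = torch.tensor([0, 3, 8], dtype=torch.int32)  # start offsets, no trailing total

if q_start_loc.shape[0] == q_seq_len.shape[0]:
    causal = False  # bidirectional attention, e.g. over vision tokens
    q_start_loc = torch.cat(
        [q_start_loc, q_seq_len.sum().to(torch.int32).unsqueeze(0)]
    )

print(q_start_loc)  # tensor([0, 3, 8, 10], dtype=torch.int32): a valid cu_seqlens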


@@ -200,15 +211,11 @@ def fill_kv_cache(
quant_bits: int,
) -> Tuple[Tensor, Tensor]:
kv_indices = kv_indices.squeeze(-1)
maca_ext_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
kv_indices,
"auto",
torch.tensor(1.0),
torch.tensor(1.0),
k_scale = torch.tensor(1.0)
v_scale = torch.tensor(1.0)

torch.ops._C_cache_ops.reshape_and_cache_flash(
key, value, key_cache, value_cache, kv_indices, "auto", k_scale, v_scale
)
return key_cache, value_cache
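On the fill_kv_cache rewrite above: reshape_and_cache_flash scatters each token's key and value into the paged caches at the flattened slot index given by kv_indices, here with "auto" cache dtype and unit k/v scales. A rough pure-PyTorch sketch of that scatter, assuming flash-layout caches of shape (num_blocks, block_size, num_heads, head_dim); the layout is an assumption, not something stated in the diff:

import torch

def fill_kv_cache_reference(key, value, key_cache, value_cache, kv_indices):
    # key/value:        (num_tokens, num_heads, head_dim)
    # key/value caches: (num_blocks, block_size, num_heads, head_dim)  -- assumed layout
    num_blocks, block_size = key_cache.shape[:2]
    flat_k = key_cache.view(num_blocks * block_size, *key_cache.shape[2:])
    flat_v = value_cache.view(num_blocks * block_size, *value_cache.shape[2:])
    flat_k[kv_indices.long()] = key    # write each token's K into its slot
    flat_v[kv_indices.long()] = value  # write each token's V into its slot
    return key_cache, value_cache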

@@ -238,8 +245,6 @@ def paged_decode_attention(

num_kv_heads = value_cache.size(1)
block_size = value_cache.size(-2)
output = torch.empty_like(query)

is_mla = query.size(-1) == 576

if is_mla:
@@ -319,7 +324,8 @@ def paged_prefill_attention(
)
return output[..., :512]

value_cache = value_cache.permute(0, 1, 3, 2)
value_cache = value_cache.permute(0, 2, 3, 1)
key_cache = key_cache.permute(0, 2, 3, 1)
context_attention_fwd(
query,
key,
@@ -347,8 +353,7 @@ def rms_norm(
hidden_states = hidden_states.to(torch.float32)
weight = weight.to(torch.float32)
output = torch.empty_like(hidden_states)
maca_ext_ops.rms_norm(output, hidden_states, weight, epsilon)

op_origin.rms_norm(output, hidden_states, weight, epsilon, None, None, False)
return output.to(input_dtype)
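The rms_norm path above now calls op_origin.rms_norm from mcoplib, whose extra trailing arguments come from that library; the computation being delegated is plain RMS normalisation over the last dimension. A pure-PyTorch reference of the math, useful for comparing against the fused op:

import torch

def rms_norm_reference(hidden_states, weight, epsilon):
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + epsilon) * weight.to(torch.float32)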


@@ -366,13 +371,9 @@

token_expert_indicies = torch.empty_like(topk_ids)

maca_ext_ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
router_logits.float(),
torch.ops._moe_C.topk_softmax(
topk_weights, topk_ids, token_expert_indicies, router_logits.float()
)

del token_expert_indicies # Not used. Will be used in the future.

if renormalize:
@@ -388,7 +389,7 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
maca_ext_ops.silu_and_mul(out, x)
torch.ops._C.silu_and_mul(out, x)
return out
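silu_and_mul above fuses SwiGLU-style gating: the last dimension of x is split in half, SiLU is applied to the first half, and the result is multiplied elementwise by the second half (which is why the output shape is halved). An unfused reference, handy for checking the torch.ops._C.silu_and_mul output:

import torch

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]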

