diff --git a/CMakeLists.txt b/CMakeLists.txt
index ae2d9dc5..e2db297b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -24,4 +24,7 @@ elseif(NOT DEVICE IN_LIST SUPPORTED_DEVICE)
     message(FATAL_ERROR "Device ${DEVICE} is not supported! Supported devices: ${SUPPORTED_DEVICE}")
 endif()
+add_subdirectory(dlinfer/vendor/${DEVICE})
 add_subdirectory(dlinfer/graph/dicp/vendor)
+
+install(CODE "message(STATUS \"Install completed for device: ${DEVICE}\")")
diff --git a/dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py b/dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py
index 09ec0d0c..ef3c20b9 100644
--- a/dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py
+++ b/dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py
@@ -19,11 +19,11 @@ def MacaCudaGraphMixin_make_buffers_cudagraph(
     num_blocks = graph_meta.num_blocks
     device = graph_meta.device
     input_buffers: BuffType = dict()
-    input_buffers["input_ids"] = torch.empty(
+    input_buffers["input_ids"] = torch.zeros(
         1, max_tokens, dtype=torch.int32, device=device
     )
-    input_buffers["position_ids"] = torch.empty(
+    input_buffers["position_ids"] = torch.zeros(
         (1, max_tokens), dtype=torch.int32, device=device
     )
diff --git a/dlinfer/vendor/maca/CMakeLists.txt b/dlinfer/vendor/maca/CMakeLists.txt
index 26d58f54..de03e3ed 100644
--- a/dlinfer/vendor/maca/CMakeLists.txt
+++ b/dlinfer/vendor/maca/CMakeLists.txt
@@ -35,6 +35,7 @@ ExternalProject_Add(${MACA_SUB_MODULE}
     BUILD_ALWAYS ON
     USES_TERMINAL_BUILD ON
     USES_TERMINAL_INSTALL ON
+    INSTALL_COMMAND ""
     CMAKE_ARGS
         "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}"
         "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}"
diff --git a/dlinfer/vendor/maca/fused_moe.py b/dlinfer/vendor/maca/fused_moe.py
index 44fae5d2..f7b58aad 100644
--- a/dlinfer/vendor/maca/fused_moe.py
+++ b/dlinfer/vendor/maca/fused_moe.py
@@ -4,13 +4,10 @@ import json
 import os
 from typing import Any, Callable, Dict, List, Optional, Tuple
-
 import torch
 import triton
 import triton.language as tl
-from .maca_extension import ops as maca_ext_ops
-
 import logging

 logger = logging.getLogger(__name__)
@@ -251,9 +248,13 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-
-    maca_ext_ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    torch.ops._moe_C.moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
     )
     return sorted_ids, expert_ids, num_tokens_post_pad
@@ -460,8 +461,7 @@ def fused_topk(
     token_expert_indicies = torch.empty(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
-
-    maca_ext_ops.topk_softmax(
+    torch.ops._moe_C.topk_softmax(
         topk_weights,
         topk_ids,
         token_expert_indicies,
@@ -796,8 +796,7 @@ def fused_experts_impl(
         use_int8_w8a16=use_int8_w8a16,
         block_shape=block_shape,
     )
-
-    maca_ext_ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+    torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

     invoke_fused_moe_kernel(
         intermediate_cache2,
@@ -818,8 +817,7 @@ def fused_experts_impl(
         use_int8_w8a16=use_int8_w8a16,
         block_shape=block_shape,
     )
-
-    maca_ext_ops.moe_sum(
+    torch.ops._moe_C.moe_sum(
         intermediate_cache3.view(*intermediate_cache3.shape),
         out_hidden_states[begin_chunk_idx:end_chunk_idx],
     )
diff --git a/dlinfer/vendor/maca/maca_ops.py b/dlinfer/vendor/maca/maca_ops.py
index 15bfa7bb..e3941896 100644
--- a/dlinfer/vendor/maca/maca_ops.py
+++ b/dlinfer/vendor/maca/maca_ops.py
@@ -11,7 +11,10 @@ from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple

 from .fused_moe import fused_experts
-from .maca_extension import ops as maca_ext_ops
+from mcoplib import lmdeploy as ops
+from mcoplib import op as op_origin
+import mcoplib._C
+import mcoplib._moe_C

 __all__ = [
     "add_rms_norm",
@@ -58,7 +61,7 @@ def add_rms_norm(
     weight: Tensor,
     epsilon: float,
 ) -> Tuple[Tensor, Tensor]:
-    maca_ext_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
+    torch.ops._C.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
     return hidden_states, residual
@@ -76,8 +79,7 @@ def apply_rotary_pos_emb(
     query = query.flatten(-2, -1)
     key = key.flatten(-2, -1)
     rot_dim = cos.size(-1)
-
-    maca_ext_ops.rotary_embedding(
+    ops.lmdeploy_rotary_embedding(
         position_ids_1d,
         query,
         key,
@@ -86,6 +88,7 @@ def apply_rotary_pos_emb(
         sin.view(-1, rot_dim),
         True,
     )
+
     return query, key
@@ -161,6 +164,13 @@ def prefill_attention(
     )
     softmax_scale = float(1 / math.sqrt(head_dim))
+    # for qwen vl part.
+    if q_start_loc.shape[0] == q_seq_len.shape[0]:
+        causal = False
+        q_start_loc = torch.cat(
+            [q_start_loc, q_seq_len.sum().to(torch.int32).unsqueeze(0)]
+        )
+
     output = flash_attn_varlen_func(
         query,
         key,
@@ -173,6 +183,7 @@ def prefill_attention(
         causal=causal,
         window_size=(-1, -1),
     )
+
     attn_output.copy_(output)
     return output
@@ -200,15 +211,11 @@ def fill_kv_cache(
     quant_bits: int,
 ) -> Tuple[Tensor, Tensor]:
     kv_indices = kv_indices.squeeze(-1)
-    maca_ext_ops.reshape_and_cache_flash(
-        key,
-        value,
-        key_cache,
-        value_cache,
-        kv_indices,
-        "auto",
-        torch.tensor(1.0),
-        torch.tensor(1.0),
+    k_scale = torch.tensor(1.0)
+    v_scale = torch.tensor(1.0)
+
+    torch.ops._C_cache_ops.reshape_and_cache_flash(
+        key, value, key_cache, value_cache, kv_indices, "auto", k_scale, v_scale
     )
     return key_cache, value_cache
@@ -238,8 +245,6 @@ def paged_decode_attention(
     num_kv_heads = value_cache.size(1)
     block_size = value_cache.size(-2)

-    output = torch.empty_like(query)
-
     is_mla = query.size(-1) == 576

     if is_mla:
@@ -319,7 +324,8 @@ def paged_prefill_attention(
         )
         return output[..., :512]

-    value_cache = value_cache.permute(0, 1, 3, 2)
+    value_cache = value_cache.permute(0, 2, 3, 1)
+    key_cache = key_cache.permute(0, 2, 3, 1)
     context_attention_fwd(
         query,
         key,
@@ -347,8 +353,7 @@ def rms_norm(
     hidden_states = hidden_states.to(torch.float32)
     weight = weight.to(torch.float32)
     output = torch.empty_like(hidden_states)
-    maca_ext_ops.rms_norm(output, hidden_states, weight, epsilon)
-
+    op_origin.rms_norm(output, hidden_states, weight, epsilon, None, None, False)
     return output.to(input_dtype)
@@ -366,13 +371,9 @@ def moe_gating_topk_softmax(

     token_expert_indicies = torch.empty_like(topk_ids)

-    maca_ext_ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        router_logits.float(),
+    torch.ops._moe_C.topk_softmax(
+        topk_weights, topk_ids, token_expert_indicies, router_logits.float()
     )
-
     del token_expert_indicies  # Not used. Will be used in the future.

     if renormalize:
@@ -388,7 +389,7 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
     d = x.shape[-1] // 2
     output_shape = x.shape[:-1] + (d,)
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-    maca_ext_ops.silu_and_mul(out, x)
+    torch.ops._C.silu_and_mul(out, x)
     return out