diff --git a/.gitignore b/.gitignore index 0c9ef52c..916fce8b 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,9 @@ __pycache__/ # Cache cache/ +# Cache +.cache/clangd/ + # JSON *.json diff --git a/include/infinicore_infer.h b/include/infinicore_infer.h index 0bed7bc7..a936359e 100644 --- a/include/infinicore_infer.h +++ b/include/infinicore_infer.h @@ -6,5 +6,6 @@ #include "infinicore_infer/models/deepseek.h" #include "infinicore_infer/models/jiuge.h" +#include "infinicore_infer/models/qwen3moe.h" #endif /* INFINICORE_INFER_H */ diff --git a/include/infinicore_infer/models/qwen3moe.h b/include/infinicore_infer/models/qwen3moe.h new file mode 100644 index 00000000..92a8c918 --- /dev/null +++ b/include/infinicore_infer/models/qwen3moe.h @@ -0,0 +1,120 @@ +#ifndef _QWEN3MOE_H_ +#define _QWEN3MOE_H_ + +#include +#include +#include +#include +#include +namespace Qwen3MoE { +struct Weights; +struct Model; + +struct Meta { + infiniDtype_t dt_logits; + size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc; + float epsilon, theta; + uint32_t end_token; + // + size_t _moe_intermediate_size; + size_t _shared_expert_intermediate_size; + size_t _num_experts; + size_t _num_experts_per_tok; + bool _norm_topk_prob; + +public: + void print_info() const { + printf("\n"); + printf(" dt_logits : %d\n", dt_logits); + printf(" nlayer : %ld\n", nlayer); + printf(" d : %ld\n", d); + printf(" nh : %ld\n", nh); + printf(" nkvh : %ld\n", nkvh); + printf(" dh : %ld\n", dh); + printf(" di : %ld\n", di); + printf(" dvoc : %ld\n", dvoc); + printf(" nkvh : %ld\n", nkvh); + + printf(" epsilon : %f\n", epsilon); + printf(" theta : %f\n", theta); + + printf(" end_token : %d\n", end_token); + + printf(" _moe_intermediate_size : %ld\n", _moe_intermediate_size); + printf(" _shared_expert_intermediate_size : %ld\n", _shared_expert_intermediate_size); + printf(" _num_experts : %ld\n", _num_experts); + printf(" _num_experts_per_tok : %ld\n", _num_experts_per_tok); + printf(" _norm_topk_prob : %d\n", _norm_topk_prob); + } +}; + +}; // namespace Qwen3MoE + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////// Qwen3 APIs ///////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/// @brief 创建模型 +/// @param device 协处理器种类 +/// @param ndev 协处理器数量 +/// @param dev_ids 协处理器编号,长度为 ndev +__C __export struct Qwen3MoE::Model * +Qwen3MoEcreateModel(const Qwen3MoE::Meta *, + const Qwen3MoE::Weights *, + infiniDevice_t device, + int ndev, + const int *dev_ids); + +/// @brief 销毁模型 +__C __export void +Qwen3MoEdestroyModel(struct Qwen3MoE::Model *); + +/// @brief 创建 KV Cache +__C __export struct KVCache * +Qwen3MoEcreateKVCache(const struct Qwen3MoE::Model *); + +/// @brief 复制 KV Cache +__C __export struct KVCache * +Qwen3MoEduplicateKVCache(const struct Qwen3MoE::Model *, + const struct KVCache *, uint32_t seq_len); + +/// @brief 销毁 KV Cache +__C __export void +Qwen3MoEdropKVCache(const struct Qwen3MoE::Model *, + struct KVCache *); + +/// @brief 批次推理一轮,并采样出新的 token +/// @param tokens 输入 token 地址 +/// @param ntok 输入 token 数量 +/// @param nreq 请求数量 +/// @param req_lens 每个请求的 token 数量 +/// @param req_pos 每个请求的起始位置 +/// @param kv_caches 每个请求的 KV Cache +/// @param temperature 采样温度(0. 
表示贪心采样)
+/// @param topk Sampling top-k (1 means greedy sampling)
+/// @param topp Sampling top-p
+/// @param output Output token array, one token per request, length at least nreq
+__C __export void
+Qwen3MoEinferBatch(struct Qwen3MoE::Model *,
+                   const uint32_t *tokens, uint32_t ntok,
+                   const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
+                   struct KVCache **kv_caches,
+                   const float *temperature, const uint32_t *topk, const float *topp,
+                   uint32_t *output);
+
+/// @brief Run one batched inference round and return the logits after the output embedding
+/// @param tokens Pointer to the input tokens
+/// @param ntok Number of input tokens
+/// @param nreq Number of requests
+/// @param req_lens Number of tokens in each request
+/// @param req_pos Start position of each request
+/// @param kv_caches KV cache of each request
+/// @param logits Output logits buffer (ntok * dvoc values)
+__C __export void
+Qwen3MoEforwardBatch(struct Qwen3MoE::Model *,
+                     const uint32_t *tokens, uint32_t ntok,
+                     const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
+                     struct KVCache **kv_caches,
+                     void *logits);
+
+#endif
diff --git a/scripts/libinfinicore_infer/qwen3_moe.py b/scripts/libinfinicore_infer/qwen3_moe.py
new file mode 100644
index 00000000..fd2f08ad
--- /dev/null
+++ b/scripts/libinfinicore_infer/qwen3_moe.py
@@ -0,0 +1,406 @@
+import ctypes
+from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, c_bool, POINTER
+import torch
+from .base import DataType, DeviceType
+import os
+from typing import List
+
+
+def find_name_in_state_dict(name_list: List[str], state_dict: dict):
+    retname = None
+    for name in name_list:
+        if name in state_dict:
+            retname = name
+            break
+    return retname
+
+class MLPCStruct(ctypes.Structure):
+    _fields_ = [
+        ("_gate_up_proj_weight", c_void_p),
+        ("_down_proj_weight", c_void_p),
+    ]
+
+    def __init__(self, ilayer: int, di, ndev, d,
+                 torch_dt_mat, transpose_weight,
+                 state_dict: dict,
+                 gate_proj=None, up_proj=None, down_proj=None):
+        # transpose_weight defaults to True
+
+        ### gate_up
+        self.gate_up_tensor = torch.concat(self.gate_up_slices(ilayer, di, ndev, state_dict, gate_proj=gate_proj, up_proj=up_proj)).to(torch_dt_mat)
+
+        if not transpose_weight:
+            self.gate_up_tensor = self.gate_up_tensor.reshape(ndev, 2 * di // ndev, d).transpose(1, 2).contiguous()
+        setattr(self, "_gate_up_proj_weight", self.gate_up_tensor.data_ptr())
+
+        ### down
+        if down_proj is None:
+            down_proj = find_name_in_state_dict([f"model.layers.{ilayer}.mlp.down_proj.weight", f"layers.{ilayer}.mlp.down_proj.weight"], state_dict)
+        if transpose_weight:
+            self.ffn_down_tensor = state_dict[down_proj].to(torch_dt_mat).reshape([d, ndev, di // ndev]).transpose(0, 1).contiguous()
+        else:
+            self.ffn_down_tensor = state_dict[down_proj].transpose(0, 1).to(torch_dt_mat).contiguous()
+
+        setattr(self, "_down_proj_weight", self.ffn_down_tensor.data_ptr())
+
+    def gate_up_slices(self, ilayer: int, di, ndev, state_dict: dict,
+                       gate_proj=None, up_proj=None):
+        if gate_proj is None:
+            gate_proj = find_name_in_state_dict([f"model.layers.{ilayer}.mlp.gate_proj.weight", f"layers.{ilayer}.mlp.gate_proj.weight"], state_dict)
+        if up_proj is None:
+            up_proj = find_name_in_state_dict([f"model.layers.{ilayer}.mlp.up_proj.weight", f"layers.{ilayer}.mlp.up_proj.weight"], state_dict)
+
+        _result = []
+        _di = di // ndev
+        for _idev in range(ndev):
+            _start = _idev * _di
+            _end = (_idev + 1) * _di
+            _result.append(state_dict[gate_proj][_start:_end, :])
+            _result.append(state_dict[up_proj][_start:_end, :])
+
+        return _result
+
+
+class AttentionCStruct(ctypes.Structure):
+    _fields_ = [
+        ("_qkv_proj_weight", c_void_p),
+        ("_qkv_proj_bias", c_void_p),
+        ("_qk_norm_weight", c_void_p),
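The MLPCStruct loader above concatenates gate_proj and up_proj shards in an interleaved per-device order, so each tensor-parallel rank sees one contiguous [2 * di/ndev, d] block. A minimal torch sketch of the same packing, with made-up toy sizes (not the real loader):

```python
import torch

# Toy sizes; real values come from config.json.
di, d, ndev = 8, 4, 2
_di = di // ndev
gate = torch.arange(di * d, dtype=torch.float32).reshape(di, d)
up = -torch.arange(di * d, dtype=torch.float32).reshape(di, d)

# Same interleaving as MLPCStruct.gate_up_slices: for each device,
# append its gate shard, then its up shard.
slices = []
for idev in range(ndev):
    slices.append(gate[idev * _di:(idev + 1) * _di])
    slices.append(up[idev * _di:(idev + 1) * _di])
packed = torch.cat(slices)                      # [2 * di, d], grouped per device

# Rank 0 reads one contiguous block holding its gate shard followed by its up shard.
shard0 = packed.reshape(ndev, 2 * _di, d)[0]
assert torch.equal(shard0[:_di], gate[:_di])
assert torch.equal(shard0[_di:], up[:_di])
```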
("_o_proj_weight", c_void_p), + ] + + def __init__(self, ilayer: int, nh, nkvh, d, dh, ndev, + torch_dt_mat, torch_dt_logits, torch_dt_norm, + transpose_weight, + state_dict: dict): + ### + self.qkv_tensor = torch.concat(self.qkv_slices(ilayer, nh, nkvh, d, dh, ndev, state_dict)).to(torch_dt_mat) + if not transpose_weight: + self.qkv_tensor = self.qkv_tensor.reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d).transpose(1, 2).contiguous() + setattr(self, "_qkv_proj_weight", self.qkv_tensor.data_ptr()) + + ### + self.qkv_b_tensor = None + attn_q_b = f"model.layers.{ilayer}.self_attn.q_proj.bias" + if attn_q_b in state_dict: + self.qkv_b_tensor = torch.concat(self.qkv_b_slices(ilayer, nh, nkvh, d, dh, ndev, state_dict)).to(torch_dt_logits) + setattr(self, "_qkv_proj_bias", self.qkv_b_tensor.data_ptr()) + + ### + self.qk_norm_tensor = None + attn_q_norm = find_name_in_state_dict([f"model.layers.{ilayer}.self_attn.q_norm.weight", f"layers.{ilayer}.self_attn.q_norm.weight"], state_dict) + if attn_q_norm in state_dict: + attn_q_norm = find_name_in_state_dict([f"model.layers.{ilayer}.self_attn.q_norm.weight", f"layers.{ilayer}.self_attn.q_norm.weight"], state_dict) + attn_k_norm = find_name_in_state_dict([f"model.layers.{ilayer}.self_attn.k_norm.weight", f"layers.{ilayer}.self_attn.k_norm.weight"], state_dict) + + q_norm = state_dict[attn_q_norm].reshape([2, dh // 2]).transpose(1, 0) + k_norm = state_dict[attn_k_norm].reshape([2, dh // 2]).transpose(1, 0) + self.qk_norm_tensor = torch.concat([q_norm, k_norm]).to(torch_dt_norm) + setattr(self, "_qk_norm_weight", self.qk_norm_tensor.data_ptr()) + + ### + attn_o = find_name_in_state_dict([f"model.layers.{ilayer}.self_attn.o_proj.weight", f"layers.{ilayer}.self_attn.o_proj.weight"], state_dict) + if transpose_weight: + self.attn_o_tensor = state_dict[attn_o].to(torch_dt_mat).reshape([d, ndev, nh // ndev * dh]).transpose(0, 1).contiguous() + else: + self.attn_o_tensor = state_dict[attn_o].transpose(0, 1).to(torch_dt_mat).contiguous() + setattr(self, "_o_proj_weight", self.attn_o_tensor.data_ptr()) + + def qkv_b_slices(self, ilayer, nh, nkvh, d, dh, ndev, state_dict): + attn_q_b = f"model.layers.{ilayer}.self_attn.q_proj.bias" + attn_k_b = f"model.layers.{ilayer}.self_attn.k_proj.bias" + attn_v_b = f"model.layers.{ilayer}.self_attn.v_proj.bias" + + _QB = state_dict[attn_q_b].reshape([nh, 2, dh // 2]).transpose(1, 2) + _KB = state_dict[attn_k_b].reshape([nkvh, 2, dh // 2]).transpose(1, 2) + _VB = state_dict[attn_v_b].reshape([nkvh, dh // 2, 2]) + + _result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_QB[_idev * _nh: (_idev + 1) * _nh, :, :].flatten()) + _result.append(_KB[_idev * _nkvh: (_idev + 1) * _nkvh, :, :].flatten()) + _result.append(_VB[_idev * _nkvh: (_idev + 1) * _nkvh, :, :].flatten()) + return _result + + def qkv_slices(self, ilayer: int, nh, nkvh, d, dh, ndev, state_dict): + attn_q = find_name_in_state_dict([f"model.layers.{ilayer}.self_attn.q_proj.weight", f"layers.{ilayer}.self_attn.q_proj.weight"], state_dict) + attn_k = find_name_in_state_dict([f"model.layers.{ilayer}.self_attn.k_proj.weight", f"layers.{ilayer}.self_attn.k_proj.weight"], state_dict) + attn_v = find_name_in_state_dict([f"model.layers.{ilayer}.self_attn.v_proj.weight", f"layers.{ilayer}.self_attn.v_proj.weight"], state_dict) + + _Q = state_dict[attn_q].reshape([nh, 2, dh // 2, d]).transpose(1, 2) + _K = state_dict[attn_k].reshape([nkvh, 2, dh // 2, d]).transpose(1, 2) + _V = state_dict[attn_v].reshape([nkvh, dh // 2, 2, d]) + + 
_result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_Q[_idev * _nh: (_idev + 1) * _nh, :, :, :]) + _result.append(_K[_idev * _nkvh: (_idev + 1) * _nkvh, :, :, :]) + _result.append(_V[_idev * _nkvh: (_idev + 1) * _nkvh, :, :]) + return _result + + + +class MoEMetaCStruct(ctypes.Structure): + _fields_ = [ + ("dt_logits", DataType), + ("nlayer", c_size_t), + ("d", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("dctx", c_size_t), + ("dvoc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_uint), + # + ("_moe_intermediate_size", c_size_t), + ("_shared_expert_intermediate_size", c_size_t), + ("_num_experts", c_size_t), + ("_num_experts_per_tok", c_size_t), + ("_norm_topk_prob", c_bool), + ] + + +class SparseMLPCStruct(ctypes.Structure): + _fields_ = [ + ("_shared_expert_num", c_size_t), + ("_num_experts", c_size_t), + ("_shared_expert_gate_weight", c_void_p), + ("_gate_weight", c_void_p), + ("_shared_expert", MLPCStruct), + ("_experts", POINTER(MLPCStruct)), + ] + + def __init__(self, ilayer: int, num_experts, ndev, d, + torch_dt_mat, transpose_weight, + _moe_intermediate_size, _shared_expert_intermediate_size, _num_experts_per_tok, _norm_topk_prob, + state_dict: dict): + + setattr(self, "_num_experts", num_experts) + + # shared_expert + shared_expert_gate = f"model.layers.{ilayer}.mlp.shared_expert_gate.weight" + if shared_expert_gate in state_dict: + + self.shared_expert_gate_tensor = state_dict[shared_expert_gate].to(torch_dt_mat) + if transpose_weight: + self.shared_expert_gate_tensor = self.shared_expert_gate_tensor.transpose(0, 1).contiguous() + setattr(self, "_shared_expert_gate_weight", self.shared_expert_gate_tensor.data_ptr()) + setattr(self, "_shared_expert_num", 1) + + ## shared_expert + gate_proj = f"model.layers.{ilayer}.mlp.shared_expert.gate_proj.weight" + up_proj = f"model.layers.{ilayer}.mlp.shared_expert.up_proj.weight" + down_proj = f"model.layers.{ilayer}.mlp.shared_expert.down_proj.weight" + self.shared_expert_mlp = MLPCStruct(ilayer, _shared_expert_intermediate_size, ndev, d, + torch_dt_mat, transpose_weight, + state_dict, + gate_proj=gate_proj, up_proj=up_proj, down_proj=down_proj) + setattr(self, "_shared_expert", self.shared_expert_mlp) + else: + setattr(self, "_shared_expert_num", 0) + + ## experts + experts_gate = f"model.layers.{ilayer}.mlp.gate.weight" + self.experts_gate_tensor = state_dict[experts_gate].to(torch_dt_mat) + if transpose_weight: + self.experts_gate_tensor = self.experts_gate_tensor.transpose(0, 1).contiguous() + setattr(self, "_gate_weight", self.experts_gate_tensor.data_ptr()) + + self.experts_mlp = [] + for i in range(num_experts): + gate_proj = f"model.layers.{ilayer}.mlp.experts.{i}.gate_proj.weight" + up_proj = f"model.layers.{ilayer}.mlp.experts.{i}.up_proj.weight" + down_proj = f"model.layers.{ilayer}.mlp.experts.{i}.down_proj.weight" + self.experts_mlp.append( + MLPCStruct(ilayer, _moe_intermediate_size, ndev, d, + torch_dt_mat, transpose_weight, + state_dict, + gate_proj=gate_proj, up_proj=up_proj, down_proj=down_proj) + ) + + self.experts_mlp_array = (MLPCStruct * num_experts)(*self.experts_mlp) + setattr(self, "_experts", self.experts_mlp_array) + + +# Define the Decoder Layer struct +class DecoderLayerCStruct(ctypes.Structure): + _fields_ = [ + ("_ilayer", c_int), + ("_post_attention_layernorm_weight", c_void_p), + ("_input_layernorm_weight", c_void_p), + ("_self_attn", AttentionCStruct), + ("_mlp", SparseMLPCStruct), 
+ ] + + def __init__(self, ilayer: int, num_experts, nh, nkvh, d, di, dh, ndev, + torch_dt_mat, torch_dt_logits, torch_dt_norm, + transpose_weight, + _moe_intermediate_size, _shared_expert_intermediate_size, _num_experts_per_tok, _norm_topk_prob, + state_dict: dict): + setattr(self, "_ilayer", ilayer) + + attn_norm = f"model.layers.{ilayer}.input_layernorm.weight" + self.attn_norm_tensor = state_dict[attn_norm].to(torch_dt_norm) + setattr(self, "_input_layernorm_weight", self.attn_norm_tensor.data_ptr()) + + ffn_norm = f"model.layers.{ilayer}.post_attention_layernorm.weight" + self.mlp_norm_tensor = state_dict[ffn_norm].to(torch_dt_norm) + setattr(self, "_post_attention_layernorm_weight", self.mlp_norm_tensor.data_ptr()) + + self.self_attn = AttentionCStruct(ilayer, nh, nkvh, d, dh, ndev, torch_dt_mat, torch_dt_logits, torch_dt_norm, transpose_weight, state_dict) + setattr(self, "_self_attn", self.self_attn) + + self.mlp = SparseMLPCStruct(ilayer, num_experts, ndev, d, torch_dt_mat, transpose_weight, + _moe_intermediate_size, _shared_expert_intermediate_size, _num_experts_per_tok, _norm_topk_prob, + state_dict) + setattr(self, "_mlp", self.mlp) + + +# Define the QwenWeights struct +class WeightsCStruct(ctypes.Structure): + _fields_ = [ + ("_nlayer", c_size_t), + ("_dt_norm", DataType), + ("_dt_mat", DataType), + ("_transpose_linear_weights", c_int), + ### + ("_embed_tokens_weight", c_void_p), + ("_norm_weight", c_void_p), + ("_lm_head_weight", c_void_p), + ### + ("_layers", POINTER(DecoderLayerCStruct)), + ] + + def __init__(self, nlayer, num_experts, nh, nkvh, d, di, dh, ndev, + torch_dt_mat, torch_dt_logits, torch_dt_norm, + transpose_weight, + _moe_intermediate_size, _shared_expert_intermediate_size, _num_experts_per_tok, _norm_topk_prob, + state_dict: dict): + ### + setattr(self, "_nlayer", nlayer) + setattr(self, "_transpose_linear_weights", 1 if transpose_weight else 0) + + if torch_dt_mat == torch.float16: + setattr(self, "_dt_mat", DataType.INFINI_DTYPE_F16) + elif torch_dt_mat == torch.float32: + setattr(self, "_dt_mat", DataType.INFINI_DTYPE_F32) + elif torch_dt_mat == torch.bfloat16: + setattr(self, "_dt_mat", DataType.INFINI_DTYPE_BF16) + else: + raise ValueError("Unsupported proj weight data type") + + if torch_dt_norm == torch.float16: + setattr(self, "_dt_norm", DataType.INFINI_DTYPE_F16) + elif torch_dt_norm == torch.float32: + setattr(self, "_dt_norm", DataType.INFINI_DTYPE_F32) + elif torch_dt_norm == torch.bfloat16: + setattr(self, "_dt_norm", DataType.INFINI_DTYPE_BF16) + else: + raise ValueError("Unsupported norm weight data type") + + ### + input_embd = "model.embed_tokens.weight" + output_norm = "model.norm.weight" + output_embd = "lm_head.weight" + + input_embd_naming = input_embd if input_embd in state_dict else output_embd + self.input_embd_tensor = state_dict[input_embd_naming].to(torch_dt_logits) + + setattr(self, "_embed_tokens_weight", self.input_embd_tensor.data_ptr()) + + output_embd_naming = output_embd if output_embd in state_dict else input_embd + self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) # 这里把输入数据强制类型转换了 ?? 
使用的不是 bf16 了 + if not transpose_weight: + self.output_embd_tensor = self.output_embd_tensor.transpose(0, 1).contiguous() + setattr(self, "_lm_head_weight", self.output_embd_tensor.data_ptr()) + + self.output_norm_tensor = state_dict[output_norm].to(torch_dt_norm) + setattr(self, "_norm_weight", self.output_norm_tensor.data_ptr()) + + ### + self.layers = [] + for ilayer in range(nlayer): + self.layers.append( + DecoderLayerCStruct(ilayer, num_experts, nh, nkvh, d, di, dh, ndev, + torch_dt_mat, torch_dt_logits, torch_dt_norm, + transpose_weight, + _moe_intermediate_size, _shared_expert_intermediate_size, _num_experts_per_tok, _norm_topk_prob, + state_dict) + ) + + self.layers_array = (DecoderLayerCStruct * nlayer)(*self.layers) + setattr(self, "_layers", self.layers_array) + + +class ModelCSruct(ctypes.Structure): + pass + + +class KVCacheCStruct(ctypes.Structure): + pass + + +def __open_library__(): + lib_path = os.path.join( + os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so" + ) + + lib = ctypes.CDLL(lib_path) + lib.Qwen3MoEcreateModel.restype = POINTER(ModelCSruct) + lib.Qwen3MoEcreateModel.argtypes = [ + POINTER(MoEMetaCStruct), # JiugeMeta const * + POINTER(WeightsCStruct), # JiugeWeights const * + DeviceType, # DeviceType + c_int, # int ndev + POINTER(c_int), # int const *dev_ids + ] + lib.Qwen3MoEdestroyModel.argtypes = [POINTER(ModelCSruct)] + lib.Qwen3MoEcreateKVCache.argtypes = [POINTER(ModelCSruct)] + lib.Qwen3MoEcreateKVCache.restype = POINTER(KVCacheCStruct) + lib.Qwen3MoEdropKVCache.argtypes = [POINTER(ModelCSruct), POINTER(KVCacheCStruct)] + lib.Qwen3MoEinferBatch.restype = None + lib.Qwen3MoEinferBatch.argtypes = [ + POINTER(ModelCSruct), # struct JiugeModel const * + POINTER(c_uint), # unsigned int const *tokens + c_uint, # unsigned int ntok + POINTER(c_uint), # unsigned int const *req_lens + c_uint, # unsigned int nreq + POINTER(c_uint), # unsigned int const *req_pos + POINTER(POINTER(KVCacheCStruct)), # struct KVCache **kv_caches + POINTER(c_float), # float temperature + POINTER(c_uint), # unsigned int topk + POINTER(c_float), # float topp + POINTER(c_uint), # unsigned int *output + ] + lib.Qwen3MoEforwardBatch.restype = None + lib.Qwen3MoEforwardBatch.argtypes = [ + POINTER(ModelCSruct), # struct JiugeModel const * + POINTER(c_uint), # unsigned int const *tokens + c_uint, # unsigned int ntok + POINTER(c_uint), # unsigned int const *req_lens + c_uint, # unsigned int nreq + POINTER(c_uint), # unsigned int const *req_pos + POINTER(POINTER(KVCacheCStruct)), # struct KVCache **kv_caches + c_void_p, # void *logits + ] + + return lib + + +LIB = __open_library__() + +create_model = LIB.Qwen3MoEcreateModel +destroy_model = LIB.Qwen3MoEdestroyModel +create_kv_cache = LIB.Qwen3MoEcreateKVCache +drop_kv_cache = LIB.Qwen3MoEdropKVCache +infer_batch = LIB.Qwen3MoEinferBatch +forward_batch = LIB.Qwen3MoEforwardBatch diff --git a/scripts/qwen3_moe.py b/scripts/qwen3_moe.py new file mode 100644 index 00000000..7fcea8e7 --- /dev/null +++ b/scripts/qwen3_moe.py @@ -0,0 +1,355 @@ +from typing import List, Sequence + +from sympy import true +from libinfinicore_infer.qwen3_moe import ( + MoEMetaCStruct, + WeightsCStruct, + DataType, + DeviceType, + KVCacheCStruct, + create_model, + destroy_model, + create_kv_cache, + drop_kv_cache, + infer_batch, + forward_batch, +) +from infer_task import InferTask, KVCache + +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import os +from pathlib import Path +import safetensors +import sys +import time +import json +import 
math +import torch +import transformers + +torch.set_default_device("cpu") + + +class Qwen3MoEMeta(MoEMetaCStruct): + def __init__(self, config: dict, dtype=torch.float16, max_tokens=None): + if dtype == torch.float16: + dt_ = DataType.INFINI_DTYPE_F16 + elif dtype == torch.float32: + dt_ = DataType.INFINI_DTYPE_F32 + elif dtype == torch.bfloat16: + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_F16 + + super().__init__( + dt_logits=dt_, + nlayer=config["num_hidden_layers"], + d=config["hidden_size"], + nh=config["num_attention_heads"], + nkvh=config["num_key_value_heads"] if "num_key_value_heads" in config else config["num_attention_heads"], + dh=config["head_dim"] if "head_dim" in config else (config["hidden_size"] // config["num_attention_heads"]), + di=config["intermediate_size"], + dctx=config["max_position_embeddings"] if max_tokens is None else max_tokens, + dvoc=config["vocab_size"], + epsilon=config["rms_norm_eps"], + theta=config["rope_theta"] if "rope_theta" in config else 100000.0, + end_token=config["eos_token_id"], + # + _moe_intermediate_size=config["moe_intermediate_size"], + _shared_expert_intermediate_size=config["shared_expert_intermediate_size"] if "shared_expert_intermediate_size" in config else 0, + _num_experts=config["num_experts"], + _num_experts_per_tok=config["num_experts_per_tok"], + _norm_topk_prob=config["norm_topk_prob"], + ) + self.torch_dtype_logits = dtype + + +class Qwen3MoEWeights(WeightsCStruct): + def __init__(self, + meta: Qwen3MoEMeta, + state_dict: dict, + torch_dt_mat=torch.float16, + torch_dt_norm=torch.float32, + ndev=1, + transpose_weight: bool = True, + ): + nlayer = meta.nlayer + nh = meta.nh + nkvh = meta.nkvh + dh = meta.dh + d = meta.d + di = meta.di + num_experts = meta._num_experts + + assert nh % nkvh == 0 + assert nh % ndev == 0 + assert nkvh % ndev == 0 + assert di % ndev == 0 + + torch_dt_logits = meta.torch_dtype_logits + + _moe_intermediate_size = meta._moe_intermediate_size + _shared_expert_intermediate_size = meta._shared_expert_intermediate_size + + _num_experts_per_tok = meta._num_experts_per_tok + _norm_topk_prob = meta._norm_topk_prob + + super().__init__(nlayer, num_experts, nh, nkvh, d, di, dh, ndev, + torch_dt_mat, torch_dt_logits, torch_dt_norm, + transpose_weight, + _moe_intermediate_size, _shared_expert_intermediate_size, _num_experts_per_tok, _norm_topk_prob, + state_dict) + + +class BatchedTask: + def __init__(self, tasks: List[InferTask]): + self.tasks = tasks + self.nreq = len(tasks) + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + 
self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.temperaturas, + self.topks, + self.topps, + ) + + +# --------------------------- +# --------------------------- +# --------------------------- +def load_all_safetensors_from_dir(dir_path_: str): + tensors_ = {} + dir_path_ = Path(dir_path_) + for file in sorted(dir_path_.glob("*.safetensors")): + data_ = safetensors.safe_open(file, "pt") + for name_ in data_.keys(): + tensors_[name_] = data_.get_tensor(name_) + return tensors_ + + +def load_config_json(dir_path_: str): + with open(os.path.join(dir_path_, "config.json"), "r") as f: + config = json.load(f) + return config + + +class Qwen3MoEForCauslLM: + def __init__( + self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None + ): + print("Loading model weights to host...") + load_start_time = time.time() + + self.config = load_config_json(model_dir_path) + eos_token_id = self.config["eos_token_id"] + self.eos_token_id = [eos_token_id] if type(eos_token_id) == int else eos_token_id + + transpose_weight = ( + device != DeviceType.DEVICE_TYPE_ASCEND + ) # y = xW is faster than y=xW^T on Ascend + + if "qwen3_moe" == self.config["model_type"]: + state_dict = load_all_safetensors_from_dir(model_dir_path) + + self.meta = Qwen3MoEMeta(self.config, max_tokens=max_tokens) + self.weights = Qwen3MoEWeights( + self.meta, + state_dict, + ndev=ndev, + transpose_weight=transpose_weight, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) + else: + raise ValueError("Unsupported model architecture") + + load_end_time = time.time() + print(f"Qwen3MoEWeights, Time used: {load_end_time - load_start_time:.3f}s") + + print(f"Creating model on {ndev} devices...") + load_start_time = time.time() + dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + self.model_instance = create_model( + byref(self.meta), + byref(self.weights), + device, + ndev, + dev_ids, + ) + load_end_time = time.time() + print(f"create_model Time used: {load_end_time - load_start_time:.3f}s") + + def max_context_len(self): + return self.meta.dctx + + def create_kv_cache(self): + return create_kv_cache(self.model_instance) + + def drop_kv_cache(self, kv_cache): + drop_kv_cache(self.model_instance, kv_cache) + + def batch_infer_one_round(self, tasks: List[InferTask]): + output = (c_uint * len(tasks))() + batch_inputs = BatchedTask(tasks) + infer_batch( + self.model_instance, + *(batch_inputs.input_args()), + output, + ) + return list(output) + + def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): + input_content = self.tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": input_content}], + add_generation_prompt=True, + tokenize=False, + ) + print(input_content, end="", flush=True) + tokens = self.tokenizer.encode(input_content) + # print("tokens: ", tokens) + + infer_task = InferTask( + 0, + tokens, + self.max_context_len(), + temperature_, + topk_, + topp_, + self.eos_token_id, + ) + infer_task.bind_kvcache(KVCache(self)) + + steps = 0 + total_time = 0 + output_content = "" + for step_i in range(max_steps): + start_time = time.time() + output_tokens = self.batch_infer_one_round([infer_task]) + + end_time = time.time() + steps += 1 + output_str = self.tokenizer.decode(output_tokens[0]) + output_content += output_str + print(output_str, end="", flush=True) + if output_tokens[0] in self.eos_token_id: + break + infer_task.next(output_tokens[0]) + + if step_i > 0: + total_time += end_time - start_time + + print("\n") + 
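For reference, BatchedTask above flattens the ragged per-request token lists into the parallel arrays the C API expects (tokens, req_lens, req_pos, ntok, nreq). A toy illustration with hypothetical requests, not using the real structs:

```python
# Two hypothetical requests: one starting at position 0 with 3 prompt tokens,
# one resuming decode at cached position 5 with a single new token.
requests = [
    {"tokens": [11, 12, 13], "pos": 0},
    {"tokens": [99], "pos": 5},
]

flat_tokens = [t for r in requests for t in r["tokens"]]  # [11, 12, 13, 99]
req_lens = [len(r["tokens"]) for r in requests]           # [3, 1]
req_pos = [r["pos"] for r in requests]                    # [0, 5]
ntok, nreq = len(flat_tokens), len(requests)              # 4, 2
```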
avg_time = total_time * 1000 / (steps - 1 + 1e-9) + print(f"Time per step: {avg_time:.3f}ms") + + + infer_task._kv_cache.drop(self) + return output_content, avg_time + + def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): + tasks = [ + InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id) + for i in range(batch_size) + ] + kv_caches = [KVCache(self) for _ in range(batch_size)] + + nll = 0.0 + total_len = 0 + + for i in range(0, len(test_sequences), batch_size): + batch_id = 0 + true_tokens = [] + while batch_id < batch_size and batch_id + i < len(test_sequences): + input_tokens = test_sequences[i + batch_id][:-1] + true_tokens.extend(test_sequences[i + batch_id][1:]) + tasks[batch_id].tokens = input_tokens + tasks[batch_id].bind_kvcache(kv_caches[batch_id]) + batch_id += 1 + + batch_inputs = BatchedTask(tasks[:batch_id]) + logits = torch.zeros( + (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits + ) + forward_batch( + self.model_instance, + batch_inputs.tokens, + batch_inputs.ntok, + batch_inputs.req_lens, + batch_inputs.nreq, + batch_inputs.req_pos, + batch_inputs.kv_caches, + logits.data_ptr(), + ) + + logits = logits.float() + token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,] + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab) + token_logprobs = log_probs[ + torch.arange(batch_inputs.ntok), token_ids + ] # (ntok,) + + start = 0 + for l in batch_inputs.req_lens_list: + nll += -token_logprobs[start: start + l].sum().item() + start += l + total_len += token_logprobs.numel() + + for task in tasks: + task.release_kvcache() + + return math.exp(nll / total_len) + + def destroy_model_instance(self): + destroy_model(self.model_instance) + print("Model destroyed") + + +def test(): + if len(sys.argv) < 3: + print("Usage: python qwen3_moe.py --nvidia [n_device]") + sys.exit(1) + model_path = sys.argv[2] + device_type = DeviceType.DEVICE_TYPE_NVIDIA + if sys.argv[1] == "--nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif sys.argv[1] == "--metax": + device_type = DeviceType.DEVICE_TYPE_METAX + else: + print("Usage: python qwen3_moe.py --nvidia [n_device]") + sys.exit(1) + + ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 + model = Qwen3MoEForCauslLM(model_path, device_type, ndev) + model.generate("山东最高的山是?", 500) + model.destroy_model_instance() + + +if __name__ == "__main__": + test() diff --git a/src/cache_manager/opcache_manager.hpp b/src/cache_manager/opcache_manager.hpp index 69a20b47..bfedd483 100644 --- a/src/cache_manager/opcache_manager.hpp +++ b/src/cache_manager/opcache_manager.hpp @@ -160,6 +160,7 @@ class CacheManager { DECLARE_OP_CACHE(CausalSoftmax) DECLARE_OP_CACHE(LogSoftmax) DECLARE_OP_CACHE(Topkrouter) + DECLARE_OP_CACHE(Topksoftmax) DECLARE_OP_CACHE(SwiGLU) DECLARE_OP_CACHE(RandomSample) DECLARE_OP_CACHE(DequantizeAWQ) @@ -173,6 +174,7 @@ class CacheManager { CausalSoftmax_cache(capacity, DESTROY_FUNC(CausalSoftmax)), LogSoftmax_cache(capacity, DESTROY_FUNC(LogSoftmax)), Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)), + Topksoftmax_cache(capacity, DESTROY_FUNC(Topksoftmax)), SwiGLU_cache(capacity, DESTROY_FUNC(SwiGLU)), RandomSample_cache(capacity, DESTROY_FUNC(RandomSample)), DequantizeAWQ_cache(capacity, DESTROY_FUNC(DequantizeAWQ)) {} diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index f604b2d9..feb9b67f 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -188,6 +188,47 @@ void 
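The perplexity() method above sums per-token negative log-likelihoods across all requests and exponentiates the mean. A compact restatement of the same computation on precomputed logits (a sketch; in the script the logits are filled in by forward_batch):

```python
import torch

def perplexity_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """logits: [ntok, vocab] predictions; targets: [ntok] next-token ids (int64)."""
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    nll = -log_probs[torch.arange(targets.numel()), targets].sum()
    return torch.exp(nll / targets.numel()).item()
```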
InferenceContext::topkrouter(std::shared_ptr values, // F32 routed_scaling_factor, topk, stream)); } +/** + * @brief Performs Top-K Softmax operation + * + * This function performs Top-K Softmax operation on the input tensor x: + * 1. Finds the top-k largest values and their indices in the input tensor + * 2. Applies softmax normalization to these top-k values + * 3. Writes the normalized probability values to values and corresponding indices to indices + * + * This operation is commonly used in sparse attention mechanisms and Mixture of Experts (MoE) + * models to select the most important top-k elements and perform probability normalization. + * + * @param values Output tensor storing the normalized probability values of top-k elements (F32 type) + * @param indices Output tensor storing the index positions corresponding to the top-k values (I32 type) + * @param x Input tensor containing the data to perform top-k softmax operation on + * @param topk The top-k value to select, i.e., selecting the top k largest elements + * @param norm_topk_prob Whether to normalize the top-k probabilities (non-zero value indicates normalization) + */ +void InferenceContext::topksoftmax(std::shared_ptr values, // F32 + std::shared_ptr indices, // I32 + std::shared_ptr x, + size_t topk, + int norm_topk_prob) { + + size_t key = CacheManager::createDescriptorKey(values, indices, x); + + infiniopTopksoftmaxDescriptor_t desc; + if (!cache_manager->getTopksoftmaxDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateTopksoftmaxDescriptor(op_handle, &desc, x->desc())); + cache_manager->putTopksoftmaxDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetTopksoftmaxWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopTopksoftmax( + desc, workspace, workspace_size, + values->data(), indices->data(), x->data(), topk, norm_topk_prob, stream)); +} + void InferenceContext::swiglu(std::shared_ptr out, std::shared_ptr up, std::shared_ptr gate) { diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index d8597b5c..129ded4e 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -47,6 +47,12 @@ struct InferenceContext { float routed_scaling_factor, size_t topk); + void topksoftmax(std::shared_ptr values, // F32 + std::shared_ptr indices, // I32 + std::shared_ptr x, + size_t topk, + int norm_topk_prob); + void swiglu(std::shared_ptr out, std::shared_ptr up, std::shared_ptr gate); @@ -132,6 +138,19 @@ inline void topkrouter(std::shared_ptr values, // F32 topk); } +inline void topksoftmax(std::shared_ptr values, // F32 + std::shared_ptr indices, // I32 + std::shared_ptr x, + size_t topk, + int norm_topk_prob) { + + getInferenceContext().topksoftmax(values, + indices, + x, + topk, + norm_topk_prob); +} + inline void swiglu(std::shared_ptr out, std::shared_ptr up, std::shared_ptr gate) { getInferenceContext().swiglu(out, up, gate); diff --git a/src/models/jiuge/jiuge_impl.hpp b/src/models/jiuge/jiuge_impl.hpp index 64ba72dd..47b0e76c 100644 --- a/src/models/jiuge/jiuge_impl.hpp +++ b/src/models/jiuge/jiuge_impl.hpp @@ -20,7 +20,7 @@ struct JiugeDeviceResource { // Weights std::shared_ptr w_in_embd, w_out_norm, w_out_embd, sin_table, cos_table; - std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm,w_attn_out, + std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out, w_ffn_norm, w_ffn_gate_up, 
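As a reference for the topksoftmax operation documented above (the real kernel is infiniopTopksoftmax inside InfiniCore, so this torch sketch is only an assumed equivalent, following the usual Qwen-MoE routing convention):

```python
import torch

def topk_softmax(router_logits: torch.Tensor, topk: int, norm_topk_prob: bool):
    """router_logits: [ntok, num_experts] -> (values [ntok, topk] F32, indices [ntok, topk] I32)."""
    all_probs = torch.softmax(router_logits.float(), dim=-1)   # softmax over every expert
    values, indices = torch.topk(all_probs, topk, dim=-1)
    if norm_topk_prob:
        # Renormalize so the k selected weights sum to 1 per token.
        values = values / values.sum(dim=-1, keepdim=True)
    return values, indices.to(torch.int32)
```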
w_ffn_down;
 // Streams
 infinirtStream_t stream;
};
diff --git a/src/models/qwen/qwen3moe/qwen3moe.cpp b/src/models/qwen/qwen3moe/qwen3moe.cpp
new file mode 100644
index 00000000..e520f60a
--- /dev/null
+++ b/src/models/qwen/qwen3moe/qwen3moe.cpp
@@ -0,0 +1,75 @@
+#include "../../../tensor.hpp"
+#include "../../../utils.hpp"
+#include "../../inference_context.hpp"
+#include "../qwen_device_resource.hpp"
+#include "../qwen_kv_cache.hpp"
+#include "../qwen_model.hpp"
+#include "../qwen_weight.hpp"
+#include "infinicore_infer.h"
+#include
+#include
+#include
+
+//////////////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////               Model API                ////////////////////////////////////////////
+//////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+__C Qwen3MoE::Model *Qwen3MoEcreateModel(const Qwen3MoE::Meta *meta,
+                                         const Qwen3MoE::Weights *weight,
+                                         infiniDevice_t device,
+                                         int ndev,
+                                         const int *dev_ids) {
+    return createModel(meta, weight, device, ndev, dev_ids);
+}
+
+/// @brief Destroy the model
+__C void Qwen3MoEdestroyModel(struct Qwen3MoE::Model *model) {
+    destroyModel(model);
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////              KVCache API               ////////////////////////////////////////////
+//////////////////////////////////////////////////////////////////////////////////////////////////////////////
+/// @brief Create a KV cache
+__C KVCache *Qwen3MoEcreateKVCache(const Qwen3MoE::Model *model) {
+    return createKVCache(model);
+}
+
+/// @brief Duplicate a KV cache
+__C KVCache *
+Qwen3MoEduplicateKVCache(const Qwen3MoE::Model *model,
+                         const KVCache *kv_cache, uint32_t seq_len) {
+    return duplicateKVCache(model, kv_cache, seq_len);
+}
+
+/// @brief Destroy a KV cache
+__C void Qwen3MoEdropKVCache(const Qwen3MoE::Model *model, KVCache *kv_cache) {
+    dropKVCache(model, kv_cache);
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////               infer API                //////////////////////////////////////////////
+//////////////////////////////////////////////////////////////////////////////////////////////////////////////
+__C void Qwen3MoEinferBatch(struct Qwen3MoE::Model *model,
+                            const uint32_t *tokens, uint32_t ntok,
+                            const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
+                            KVCache **kv_caches,
+                            const float *temperature, const uint32_t *topk, const float *topp,
+                            uint32_t *output) {
+    inferBatch(model, tokens, ntok,
+               req_lens, nreq, req_pos,
+               kv_caches, temperature, topk, topp, output);
+}
+
+__C void Qwen3MoEforwardBatch(Qwen3MoE::Model *model,
+                              const uint32_t *tokens, uint32_t ntok,
+                              const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
+                              KVCache **kv_caches,
+                              void *logits) {
+
+    forwardBatch(model,
+                 tokens, ntok,
+                 req_lens, nreq, req_pos,
+                 kv_caches,
+                 logits);
+}
\ No newline at end of file
diff --git a/src/models/qwen/qwen3moe/qwen3moe_infer.cpp b/src/models/qwen/qwen3moe/qwen3moe_infer.cpp
new file mode 100644
index 00000000..494e2f26
--- /dev/null
+++ b/src/models/qwen/qwen3moe/qwen3moe_infer.cpp
@@ -0,0 +1,349 @@
+#include "../../../tensor.hpp"
+#include "../../../utils.hpp"
+#include "../../inference_context.hpp"
+#include "../qwen_device_resource.hpp"
+#include "../qwen_model.hpp"
+#include "infinicore_infer.h"
+#include
+#include
+#include
+
+void Qwen3MoEinferDeviceBatch(const
Qwen3MoE::Meta *meta, DeviceResource &rsrc, + uint32_t idev, uint32_t ndev, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output, void *last_logits) { + + // ======================================================================== + // 1. 提取模型配置参数 + // ======================================================================== + auto nlayer = meta->nlayer; // Transformer层数 + auto nkvh = meta->nkvh / ndev; // 每个设备的KV头数(分布式时分割) + auto nh = meta->nh / ndev; // 每个设备的注意力头数(分布式时分割) + auto ngroup = nh / nkvh; // GQA分组数(Grouped Query Attention) + auto dh = meta->dh; // 每个注意力头的维度 + auto d = meta->d; // 模型隐藏层维度 + auto dt_logits = meta->dt_logits; // logits的数据类型 + auto dvoc = meta->dvoc; // 词汇表大小 + auto stream = rsrc.stream; // CUDA流,用于异步操作 + + // ======================================================================== + // 2. 分配主要计算缓冲区 + // ======================================================================== + // 主计算缓冲区:用于存储每层的输入输出 + auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); // 层输入 [ntok, d] + auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); // 层输出 [ntok, d] + + // QKV缓冲区:存储query、key、value投影结果 + // 形状: [ntok, (nh + nkvh * 2) * dh] + // nh个query头 + nkvh个key头 + nkvh个value头 + auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool); + auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh}); // 用于RoPE的视图 + + // 注意力输出缓冲区 + auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool); // 注意力输出 [ntok, nh*dh] + + // 采样相关缓冲区 + auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool); // 输出概率分布 [nreq, dvoc] + auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); // 采样结果 [nreq] + auto result_cpu = std::vector(nreq); // CPU端的结果缓冲区 + + // Prepare inputs + auto batch_pos_ids = std::vector(ntok); + size_t req_start = 0; + for (uint32_t req = 0; req < nreq; req++) { + for (uint32_t i = 0; i < req_lens[req]; i++) { + batch_pos_ids[req_start + i] = req_pos[req] + i; + } + req_start += req_lens[req]; + } + + // 获取权重张量指针 + const Qwen3MoE::WeightsTensor *g_WeightsTensor = rsrc.weights_tensor_ptr.get(); + if (!g_WeightsTensor) { + return; // 权重未加载,直接返回 + } + + std::shared_ptr pos_ids_buf; + if (rsrc.device == INFINI_DEVICE_CPU) { + pos_ids_buf = Tensor::weight(batch_pos_ids.data(), INFINI_DTYPE_U32, {ntok}); + } else { + pos_ids_buf = Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool); + RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok, + INFINIRT_MEMCPY_H2D, stream)); + } + + // 将输入token嵌入到隐藏空间:从词汇表嵌入矩阵中查找每个token的嵌入向量 + for (uint32_t i = 0; i < ntok; i++) { + RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d), g_WeightsTensor->w_in_embd->data(tokens[i] * d), + dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream)); + } + + // Attention + // attention inner + size_t max_qk_size = 0; + size_t max_seq_len = 0; + + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; // 历史长度(已缓存的token数) + auto seq_len = req_lens[req]; // 当前请求的新token数 + auto total_len = past_len + seq_len; // 总长度(历史 + 当前) + + max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len)); + max_seq_len = std::max(max_seq_len, size_t(seq_len)); + } + + auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool); + auto rearrange_q_buf = 
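The position-id preparation above gives each request a contiguous range starting at its cached length (req_pos). A plain-Python restatement of that loop:

```python
# req_pos[i] is the number of tokens already cached for request i,
# req_lens[i] is the number of new tokens it contributes to this batch.
def build_pos_ids(req_pos, req_lens):
    pos_ids = []
    for start, length in zip(req_pos, req_lens):
        pos_ids.extend(range(start, start + length))
    return pos_ids

assert build_pos_ids([0, 5], [3, 1]) == [0, 1, 2, 5]
```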
Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); + auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh}); + auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); + auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh}); + + // Compute + for (uint32_t ilayer = 0; ilayer < nlayer; ilayer++) { + auto layer_tensor = g_WeightsTensor->layers[ilayer]; + + // 1. Attention + // rms norm + rmsnorm(logits_out, logits_in, layer_tensor->w_attn_norm, meta->epsilon); + // qkv_proj + linear(qkv_buf, logits_out, layer_tensor->self_attn->w_attn_qkv, 1.0, 0.0, nullptr, layer_tensor->self_attn->b_attn_qkv ? layer_tensor->self_attn->b_attn_qkv : nullptr); + + if (layer_tensor->self_attn->w_attn_qk_norm) { + auto qkv_buf_view = qkv_buf->view({ntok, nh + nkvh * 2, dh}); + auto q_buf = qkv_buf_view->slice(1, 0, nh); + auto k_buf = qkv_buf_view->slice(1, nh, nkvh); + rmsnorm(q_buf, q_buf, layer_tensor->self_attn->w_attn_qk_norm->slice(0, 0, dh), meta->epsilon); + rmsnorm(k_buf, k_buf, layer_tensor->self_attn->w_attn_qk_norm->slice(0, dh, dh), meta->epsilon); + } + + // rope + rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, g_WeightsTensor->sin_table, g_WeightsTensor->cos_table); + rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, g_WeightsTensor->sin_table, g_WeightsTensor->cos_table); + + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; // 该请求的历史长度 + auto seq_len = req_lens[req]; // 该请求的新token数 + auto total_len = past_len + seq_len; // 总长度 + + // 提取当前请求的Q、K、V + // Q: [seq_len, nh, dh] -> 重排为 [nkvh, ngroup, seq_len, dh] 用于GQA + auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); + auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); + + // self attention + // concat + rearrange(kv_caches[req]->k[idev][ilayer]->slice(0, past_len, seq_len), k); + rearrange(kv_caches[req]->v[idev][ilayer]->slice(0, past_len, seq_len), v); + // qk + rearrange(q_rearrange->slice(2, 0, seq_len), q); + auto qk_gemm = qk_buf->slice(1, 0, seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); + auto k_gemm = kv_caches[req]->k[idev][ilayer]->slice(0, 0, total_len)->permute({1, 2, 0}); + linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + // softmax + auto qk_softmax = qk_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len}); + causalSoftmax(qk_softmax, qk_softmax); + auto v_gemm = kv_caches[req]->v[idev][ilayer]->slice(0, 0, total_len)->permute({1, 0, 2}); + linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); + // rearrange attn val + rearrange(o, attn_val_gemm->slice(2, 0, seq_len)); + + token_offset += seq_len; + } + + // o_proj + linear(logits_in, o_buf, layer_tensor->self_attn->w_attn_out, 1.0, 0.0, idev == 0 ? 
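The per-request attention loop above rearranges Q into [nkvh, ngroup, seq, dh] for grouped-query attention, scores it against the cached K with a 1/sqrt(dh) scale, applies a causal softmax, and multiplies by the cached V. A dense torch restatement under the assumed cache layout [total_len, nkvh, dh] (a reference sketch, not the kernel path):

```python
import torch

def gqa_attention(q, k_cache, v_cache, past_len):
    """q: [seq, nh, dh]; k_cache, v_cache: [total_len, nkvh, dh], total_len = past_len + seq."""
    seq, nh, dh = q.shape
    total, nkvh, _ = k_cache.shape
    ngroup = nh // nkvh

    q = q.view(seq, nkvh, ngroup, dh).permute(1, 2, 0, 3)        # [nkvh, ngroup, seq, dh]
    k = k_cache.permute(1, 2, 0)                                 # [nkvh, dh, total]
    v = v_cache.permute(1, 0, 2)                                 # [nkvh, total, dh]

    scores = torch.einsum("hgsd,hdt->hgst", q, k) / dh**0.5      # [nkvh, ngroup, seq, total]
    # Query i sits at absolute position past_len + i and may attend to keys <= that position.
    causal = torch.arange(total) <= (torch.arange(seq)[:, None] + past_len)
    scores = scores.masked_fill(~causal, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    out = torch.einsum("hgst,htd->hgsd", probs, v)               # [nkvh, ngroup, seq, dh]
    return out.permute(2, 0, 1, 3).reshape(seq, nh * dh)         # matches the o_buf layout
```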
logits_in : nullptr, nullptr); // only rank 0 adds residual + + // All_reduce if distributed + if (rsrc.comm != nullptr) { + RUN_INFINI(infinicclAllReduce( + logits_in->data(), logits_in->data(), ntok * d, dt_logits, + INFINICCL_SUM, rsrc.comm, stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + } + + // 2. FFN + rmsnorm(logits_out, logits_in, layer_tensor->w_ffn_norm, meta->epsilon); + + // ------------------------------------------------------------------------ + // SparseMLP: 稀疏混合专家网络 + // 每个token根据路由权重选择top-k个专家进行计算 + // ------------------------------------------------------------------------ + { + std::shared_ptr hidden_states = logits_out; // MoE的输入 + + // MoE配置参数 + size_t moe_intermediate_size = meta->_moe_intermediate_size / ndev; // 每个设备的专家中间层大小 + + // 分配MoE计算缓冲区 + // gate_up_buf: 存储gate和up投影的拼接结果 [1, 2 * moe_intermediate_size] + auto router_gate_up_buf = Tensor::buffer(dt_logits, {1, 2 * moe_intermediate_size}, rsrc.memory_pool); + auto router_gate_buf = router_gate_up_buf->slice(1, 0, moe_intermediate_size); // gate部分 + auto router_up_buf = router_gate_up_buf->slice(1, moe_intermediate_size, moe_intermediate_size); // up部分 + + // 输出缓冲区:存储所有专家输出的加权和 + std::shared_ptr router_states_sum = Tensor::buffer(hidden_states->dtype(), hidden_states->shape(), rsrc.memory_pool); + + // 路由logits:每个token对所有专家的路由分数 [ntok, num_experts] + std::shared_ptr router_logits = Tensor::buffer(dt_logits, {ntok, meta->_num_experts}, rsrc.memory_pool); + + // TopK路由参数 + size_t topk = meta->_num_experts_per_tok; // 每个token选择的专家数量(通常为4或8) + bool norm_topk_prob = meta->_norm_topk_prob; // 是否对topk概率进行归一化 + + // TopK路由结果缓冲区 + std::shared_ptr values_gpu = Tensor::buffer(infiniDtype_t::INFINI_DTYPE_F32, {ntok * topk}, rsrc.memory_pool); // 专家权重 [ntok * topk] + std::shared_ptr indices_gpu = Tensor::buffer(infiniDtype_t::INFINI_DTYPE_I32, {ntok * topk}, rsrc.memory_pool); // 专家索引 [ntok * topk] + std::vector values_cpu(ntok * topk, 0.f); // CPU端权重(用于后续计算) + std::vector indices_cpu(ntok * topk, 0); // CPU端索引(用于后续计算) + + // ------------------------------------------------------------------------ + // 开始MoE计算 + // ------------------------------------------------------------------------ + auto ffn = layer_tensor->ffn; + + // Step 1: 计算路由logits并执行TopK选择 + // 将hidden_states通过路由门控网络,得到每个token对所有专家的路由分数 + linear(router_logits, hidden_states, ffn->_gate_weight, 1.0, 0.0, nullptr, nullptr); + { + topksoftmax(values_gpu, indices_gpu, router_logits, topk, norm_topk_prob); + RUN_INFINI(infinirtMemcpy((void *)values_cpu.data(), values_gpu->data(), values_cpu.size() * sizeof(float), INFINIRT_MEMCPY_D2H)); + RUN_INFINI(infinirtMemcpy((void *)indices_cpu.data(), indices_gpu->data(), indices_cpu.size() * sizeof(int), INFINIRT_MEMCPY_D2H)); + RUN_INFINI(infinirtStreamSynchronize(rsrc.stream)); + } + + // Step 2: 对每个token执行MoE计算 + // 每个token根据路由结果,依次经过topk个专家,并将结果加权求和 + { + for (size_t itok = 0; itok < ntok; ++itok) { + // 提取当前token的输入和输出缓冲区 + std::shared_ptr hidden_states_i = hidden_states->slice(0, itok, 1); // [1, d] + std::shared_ptr router_states_sum_i = router_states_sum->slice(0, itok, 1); // [1, d] + + // 第一个专家:初始化输出(alpha * Expert(hidden_states_i)) + { + int index = indices_cpu[itok * topk + 0]; + float alpha = values_cpu[itok * topk + 0]; + linear(router_gate_up_buf, hidden_states_i, layer_tensor->ffn->_experts[index]->w_ffn_gate_up, 1.0, 0.0, nullptr, nullptr); + swiglu(router_gate_buf, router_up_buf, router_gate_buf); + linear(router_states_sum_i, router_gate_buf, layer_tensor->ffn->_experts[index]->w_ffn_down, 
alpha, 0.0, nullptr, nullptr); + } + + // 后续专家:累加到已有输出(alpha * Expert(hidden_states_i) + router_states_sum_i) + for (size_t k = 1; k < topk; ++k) { + int index = indices_cpu[itok * topk + k]; + float alpha = values_cpu[itok * topk + k]; + linear(router_gate_up_buf, hidden_states_i, layer_tensor->ffn->_experts[index]->w_ffn_gate_up, 1.0, 0.0, nullptr, nullptr); + swiglu(router_gate_buf, router_up_buf, router_gate_buf); + // 加权累加(注意这里使用router_states_sum_i作为残差,实现累加) + linear(router_states_sum_i, router_gate_buf, + layer_tensor->ffn->_experts[index]->w_ffn_down, alpha, 0.0, router_states_sum_i, nullptr); + } + } + + // 分布式AllReduce:聚合所有设备的MoE输出 + if (rsrc.comm != nullptr) { + RUN_INFINI(infinicclAllReduce( + router_states_sum->data(), router_states_sum->data(), ntok * d, dt_logits, + INFINICCL_SUM, rsrc.comm, stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + } + } + + // Step 3: 残差连接 + // 将MoE输出与注意力输出相加,完成Transformer块的计算 + add(logits_in, router_states_sum, logits_in); + } + + // All_reduce if distributed + // if (rsrc.comm != nullptr) { + // RUN_INFINI(infinicclAllReduce( + // logits_in->data(), logits_in->data(), ntok * d, dt_logits, + // INFINICCL_SUM, rsrc.comm, stream)); + // RUN_INFINI(infinirtStreamSynchronize(stream)); + // } + } + + // Sample and Output + if (idev == 0) { + if (last_logits != nullptr) { + rmsnorm(logits_out, logits_in, g_WeightsTensor->w_out_norm, meta->epsilon); + auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); + linear(last_logits_buf, logits_out, g_WeightsTensor->w_out_embd, 1.0, 0.0, nullptr, nullptr); + RUN_INFINI(infinirtStreamSynchronize(stream)); + RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H)); + } + if (output != nullptr) { + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto seq_len = req_lens[req]; + token_offset += seq_len; + rmsnorm(logits_out->slice(0, req, 1), + logits_in->slice(0, token_offset - 1, 1), + g_WeightsTensor->w_out_norm, + meta->epsilon); + } + linear(prob_buf, logits_out->slice(0, 0, nreq), g_WeightsTensor->w_out_embd, 1.0, 0.0, nullptr, nullptr); + std::random_device _rd; + std::mt19937 gen(_rd()); + token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto seq_len = req_lens[req]; + float random_val = std::uniform_real_distribution(0, 1)(gen); + randomSample(result_buf->slice(0, req, 1)->view_as({}, {}), + prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}), + random_val, topp[req], topk[req], temperature[req]); + token_offset += seq_len; + } + RUN_INFINI(infinirtStreamSynchronize(stream)); + RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(), + sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H)); + for (uint32_t req = 0; req < nreq; req++) { + output[req] = uint32_t(result_cpu[req]); + } + } + } +} + +namespace Qwen3MoE { +Model::Model(const Meta *_meta, const Weights *weights, infiniDevice_t device_, std::vector device_ids) : meta(*_meta) { + + // 初始化设备相关参数 + int ndev = int(device_ids.size()); // 设备数量 + device = device_; + dev_ids = device_ids; + dev_resources = std::vector>(ndev); // 每个设备的资源 + states = std::vector(ndev); // 每个设备的推理状态 + threads.resize(ndev); // 每个设备的推理线程 + + // 初始化InfiniRT运行时 + RUN_INFINI(infinirtInit()); + + // 初始化通信器(用于多设备间的AllReduce) + auto comms = std::vector(ndev, nullptr); + if (ndev > 1) { + RUN_INFINI(infinicclCommInitAll(device, comms.data(), ndev, dev_ids.data())); + } + + // 为每个设备启动推理线程 + for (int i = 0; i < ndev; i++) { + threads[i] = 
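The per-token expert accumulation above computes sum_k alpha_k * Expert_k(x), where each expert is a SwiGLU MLP and alpha_k comes from the top-k routing weights. A minimal torch sketch of the same routing and accumulation, assuming the usual silu(gate) * up convention (function and variable names here are illustrative, not from the source):

```python
import torch

def moe_forward(x, gate_w, experts, topk, norm_topk_prob):
    """x: [ntok, d]; gate_w: [num_experts, d]; experts: list of (w_gate, w_up, w_down), row-major [out, in]."""
    logits = x @ gate_w.t()                                   # [ntok, num_experts]
    probs = torch.softmax(logits.float(), dim=-1)
    vals, idx = torch.topk(probs, topk, dim=-1)
    if norm_topk_prob:
        vals = vals / vals.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for t in range(x.shape[0]):                               # per token, as in the loop above
        for k in range(topk):
            w_gate, w_up, w_down = experts[idx[t, k]]
            gate = torch.nn.functional.silu(x[t] @ w_gate.t())
            up = x[t] @ w_up.t()
            out[t] += vals[t, k].to(x.dtype) * ((gate * up) @ w_down.t())
    return out
```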
std::thread(launchDevice, + std::cref(meta), weights, &dev_resources[i], + std::ref(states[i]), std::ref(req), device, i, ndev, + dev_ids[i], comms[i], Qwen3MoEinferDeviceBatch); + } + + // 等待所有设备完成权重加载 + for (int i = 0; i < ndev; i++) { + std::unique_lock lock(states[i].mtx); + states[i].cv_load.wait(lock, [&] { return states[i].loaded; }); + lock.unlock(); + } +} + +}; // namespace Qwen3MoE diff --git a/src/models/qwen/qwen3moe/qwen3moe_model.hpp b/src/models/qwen/qwen3moe/qwen3moe_model.hpp new file mode 100644 index 00000000..b09b66a4 --- /dev/null +++ b/src/models/qwen/qwen3moe/qwen3moe_model.hpp @@ -0,0 +1,20 @@ +#ifndef _QWEN3MOE_MODEL_HPP_ +#define _QWEN3MOE_MODEL_HPP_ +#include "../qwen_device_resource.hpp" +#include "qwen3moe_weight.hpp" + +namespace Qwen3MoE { + +struct Model { + Meta meta; + infiniDevice_t device; + std::vector dev_ids; + std::vector> dev_resources; + std::vector states; + std::vector threads; + InferRequest req; + + Model(const Meta *, const Weights *, infiniDevice_t device, std::vector device_ids); +}; +}; // namespace Qwen3MoE +#endif diff --git a/src/models/qwen/qwen3moe/qwen3moe_weight.hpp b/src/models/qwen/qwen3moe/qwen3moe_weight.hpp new file mode 100644 index 00000000..0fe8281b --- /dev/null +++ b/src/models/qwen/qwen3moe/qwen3moe_weight.hpp @@ -0,0 +1,113 @@ +#ifndef _QWEN3MOE_WEIGHT_HPP_ +#define _QWEN3MOE_WEIGHT_HPP_ +#include "../qwen_weight.hpp" +#include "infinicore_infer/models/qwen3moe.h" +#include + +// +// cpu 地址 +// +namespace Qwen3MoE { + +using AttentionCStruct = Qwen::AttentionCStruct; + +using SparseMLPCStruct = Qwen::SparseMLPCStruct; + +using DecoderLayerCStruct = Qwen::DecoderLayerCStruct; + +struct Weights { + size_t _nlayer{0}; // ("_nlayer", c_size_t) + infiniDtype_t _dt_norm; // ("_dt_norm", DataType) + infiniDtype_t _dt_mat; // ("_dt_mat", DataType), + int _transpose_linear_weights{false}; // ("_transpose_linear_weights", c_int), + void *_embed_tokens_weight{nullptr}; // ("_embed_tokens_weight", c_void_p), + void *_norm_weight{nullptr}; // ("_norm_weight", c_void_p), + void *_lm_head_weight{nullptr}; // ("_lm_head_weight", c_void_p) + DecoderLayerCStruct *_layers{nullptr}; // ("_layers", POINTER(DecoderLayerCStruct)) + + void print_info() const { + printf("Qwen3MoE Weights:\n"); + printf("\tnlayer : %ld\n", _nlayer); + printf("\ttranspose_linear_weights : %d\n", _transpose_linear_weights); + printf("\tembed_tokens_weight : %p\n", _embed_tokens_weight); + printf("\tnorm_weight : %p\n", _norm_weight); + printf("\tlm_head_weight : %p\n", _lm_head_weight); + if (_layers) { + _layers[0].print_info(); + } + } +}; +}; // namespace Qwen3MoE + +// +// gpu 地址 +// +namespace Qwen3MoE { + +using SharedMLPTensor = Qwen::SharedMLPTensor; + +using RouterMLPTensor = Qwen::RouterMLPTensor; + +using SparseMLPTensor = Qwen::SparseMLPTensor; + +using AttentionTensor = Qwen::AttentionTensor; + +using DecoderLayerTensor = Qwen::DecoderLayerTensor; + +struct WeightsTensor { + size_t nlayer{0}; + std::shared_ptr sin_table; + std::shared_ptr cos_table; + std::shared_ptr w_in_embd; + std::shared_ptr w_out_norm; + std::shared_ptr w_out_embd; + std::vector> layers; + +public: + WeightsTensor(Meta const *meta, Weights const *w, size_t idev, size_t ndev) { + this->nlayer = meta->nlayer; + + size_t d = meta->d; + size_t dh = meta->dh; + float theta = meta->theta; + size_t dvoc = meta->dvoc; + size_t dctx = meta->dctx; + + infiniDtype_t dt_logits = meta->dt_logits; + infiniDtype_t dt_norm = w->_dt_norm; + + int transpose_linear_weights = 
w->_transpose_linear_weights; + + void *embed_tokens_weight_ptr = w->_embed_tokens_weight; + void *lm_head_weight_ptr = w->_lm_head_weight; + void *norm_weight_ptr = w->_norm_weight; + + this->sin_table = Qwen::getSinTable(dh, theta, dctx, dt_logits); + this->cos_table = Qwen::getCosTable(dh, theta, dctx, dt_logits); + this->w_in_embd = Qwen::getInEmbd(d, dvoc, dt_logits, embed_tokens_weight_ptr); + this->w_out_embd = Qwen::getOutEmbd(d, dvoc, dt_logits, transpose_linear_weights, lm_head_weight_ptr); + this->w_out_norm = Qwen::getNorm(d, dt_norm, norm_weight_ptr); + + this->layers.reserve(this->nlayer); + for (size_t ilayer = 0; ilayer < this->nlayer; ++ilayer) { + this->layers.push_back(std::make_shared(meta, w, ilayer, idev, ndev)); + } + } + + void print_info() const { + printf(" \n "); + printf("Qwen3MoE::WeightsTensor nlayer: %ld \n ", nlayer); + printf("\t\t sin_table :: %p\t%s \n", sin_table.get(), sin_table->info().c_str()); + printf("\t\t cos_table :: %p\t%s \n ", cos_table.get(), cos_table->info().c_str()); + printf("\t\t w_in_embd :: %p\t%s \n ", w_in_embd.get(), w_in_embd->info().c_str()); + printf("\t\t w_out_norm :: %p\t%s \n ", w_out_norm.get(), w_out_norm->info().c_str()); + printf("\t\t w_out_embd :: %p\t%s \n", w_out_embd.get(), w_out_embd->info().c_str()); + for (auto &layer : layers) { + layer->print_info(); + break; + } + } +}; + +}; // namespace Qwen3MoE +#endif diff --git a/src/models/qwen/qwen_device_resource.hpp b/src/models/qwen/qwen_device_resource.hpp new file mode 100644 index 00000000..0864854a --- /dev/null +++ b/src/models/qwen/qwen_device_resource.hpp @@ -0,0 +1,268 @@ +#ifndef _QWEN_DEVICE_RESOURCE_ +#define _QWEN_DEVICE_RESOURCE_ + +#include "../inference_context.hpp" +#include "../jiuge/jiuge_impl.hpp" +#include +#include +#include +#include +#include + +template +struct DeviceResource { + // Device + infiniDevice_t device; + int device_id; + infiniopHandle_t handle; + // Streams + infinirtStream_t stream; + // Communicator + infinicclComm_t comm; + std::shared_ptr memory_pool; + + // Pointer to the GPU parameters of the model + std::unique_ptr weights_tensor_ptr{nullptr}; +}; + +/** + * @brief Create and initialize device resource for model inference + * @tparam WeightsTensor Type of weights tensor + * @tparam Meta Type of model metadata + * @tparam Weights Type of model weights + * @param rsrc Pointer to DeviceResource to initialize (must not be nullptr) + * @param meta Pointer to model metadata (must not be nullptr) + * @param weights Pointer to model weights (must not be nullptr) + * @param device Device type + * @param idev Device index + * @param ndev Total number of devices + * @param dev_id Physical device ID + * @param comm Communication handle for multi-device + * @throws std::invalid_argument if any required pointer is nullptr + * @throws std::runtime_error if resource creation fails + */ +template +void createDeviceResource(DeviceResource *rsrc, const Meta *meta, + const Weights *weights, + infiniDevice_t device, int idev, + int ndev, int dev_id, + infinicclComm_t comm) { + // Input validation + if (rsrc == nullptr) { + throw std::invalid_argument("createDeviceResource: rsrc cannot be nullptr"); + } + if (meta == nullptr) { + throw std::invalid_argument("createDeviceResource: meta cannot be nullptr"); + } + if (weights == nullptr) { + throw std::invalid_argument("createDeviceResource: weights cannot be nullptr"); + } + if (ndev <= 0) { + throw std::invalid_argument("createDeviceResource: ndev must be positive"); + } + if (idev < 0 || idev >= 
ndev) { + throw std::invalid_argument("createDeviceResource: idev out of range"); + } + + RUN_INFINI(infinirtSetDevice(device, dev_id)); + infiniopHandle_t handle; + infiniopCreateHandle(&handle); + infinirtStream_t stream; + infinirtStreamCreate(&stream); + + auto memory_pool = std::make_shared(128 * 1024 * 1024); + if (!memory_pool) { + throw std::runtime_error("createDeviceResource: memory pool allocation failed"); + } + + // Use member-wise assignment instead of aggregate initialization to avoid stack smashing + rsrc->device = device; + rsrc->device_id = dev_id; + rsrc->handle = handle; + rsrc->stream = stream; + rsrc->comm = comm; + rsrc->memory_pool = memory_pool; + rsrc->weights_tensor_ptr = std::make_unique(meta, weights, idev, ndev); + + RUN_INFINI(infinirtDeviceSynchronize()); +} + +/** + * @brief Release device resource and clean up allocated resources + * @tparam WeightsTensor Type of weights tensor + * @param res DeviceResource reference to release + * @note This function is safe to call multiple times or with partially initialized resources + */ +template +void releaseDeviceResource(DeviceResource &res) { + infinirtDeviceSynchronize(); + + // Release weights tensor (smart pointer will automatically free memory) + res.weights_tensor_ptr.reset(); + + // Release device handles + if (res.handle != nullptr) { + infiniopDestroyHandle(res.handle); + res.handle = nullptr; + } + if (res.stream != nullptr) { + infinirtStreamDestroy(res.stream); + res.stream = nullptr; + } + if (res.comm != nullptr) { + infinicclCommDestroy(res.comm); + res.comm = nullptr; + } +} + +/** + * @brief Launch device thread for model inference + * @tparam WeightsTensor Type of weights tensor + * @tparam Meta Type of model metadata + * @tparam Weights Type of model weights + * @param meta Model metadata reference, it is config of the model + * @param weights Pointer to model weights, it is cpu pointer, it will be copied to gpu memory + * @param rsrc Pointer to DeviceResource to initialize (must not be nullptr) + * @param state Inference state for synchronization + * @param req Inference request structure + * @param device Device type + * @param idev Device index + * @param ndev Total number of devices + * @param dev_id Physical device ID + * @param comm Communication handle for multi-device + * @param inferDeviceBatch Function pointer to device batch inference function (must not be nullptr) + * @throws std::invalid_argument if any required pointer is nullptr + */ +template +void launchDevice(const Meta &meta, const Weights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req, + infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm, + void (*inferDeviceBatch)(const Meta *, DeviceResource &, uint32_t, uint32_t, const uint32_t *, uint32_t, const uint32_t *, uint32_t, const uint32_t *, struct KVCache **kv_caches, const float *, const uint32_t *, const float *, uint32_t *, void *)) { + // Input validation + if (rsrc == nullptr) { + throw std::invalid_argument("launchDevice: rsrc cannot be nullptr"); + } + if (weights == nullptr) { + throw std::invalid_argument("launchDevice: weights cannot be nullptr"); + } + if (inferDeviceBatch == nullptr) { + throw std::invalid_argument("launchDevice: inferDeviceBatch cannot be nullptr"); + } + + // Create Device Resource + createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm); + + CacheManager cache_manager(100); + InferenceContext ctx(rsrc->handle, rsrc->memory_pool, &cache_manager, rsrc->stream); + + // Set the 
inference context for this thread + setInferenceContext(&ctx); + { + std::unique_lock lock(state.mtx); + state.loaded = true; + lock.unlock(); + state.cv_load.notify_one(); + } + + // Infer Loop + while (true) { + std::unique_lock lock(state.mtx); + state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; }); + // quit if exit_flag is set + if (state.exit_flag) { + break; + } + + inferDeviceBatch(&meta, *rsrc, idev, ndev, req.tokens, req.ntok, + req.req_lens, req.nreq, req.req_pos, req.kv_caches, + req.temperature, req.topk, req.topp, req.output, req.logits); + + state.proceed = false; + lock.unlock(); + state.cv_done.notify_one(); + } + + // Clean-Up + releaseDeviceResource(*rsrc); + setInferenceContext(nullptr); // Clear the context when done +} + +/** + * @brief Perform batch inference on the model + */ +template +void inferBatch(Model *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output) { + if (model == nullptr) { + throw std::invalid_argument("inferBatch: model cannot be nullptr"); + } + + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.output = output; + model->req.logits = nullptr; + model->req.temperature = temperature; + model->req.topk = topk; + model->req.topp = topp; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +/** + * @brief Perform forward pass (compute logits) for batch inference + */ +template +void forwardBatch(Model *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + KVCache **kv_caches, + void *logits) { + if (model == nullptr) { + throw std::invalid_argument("forwardBatch: model cannot be nullptr"); + } + + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.output = nullptr; + model->req.logits = logits; + model->req.temperature = nullptr; + model->req.topk = nullptr; + model->req.topp = nullptr; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +#endif diff --git a/src/models/qwen/qwen_kv_cache.hpp b/src/models/qwen/qwen_kv_cache.hpp new file mode 100644 index 00000000..b7576746 --- /dev/null +++ b/src/models/qwen/qwen_kv_cache.hpp @@ -0,0 +1,66 @@ +#ifndef _QWEN_KV_CACHE_H_ +#define _QWEN_KV_CACHE_H_ + +#include "../../cache.hpp" + +template +struct KVCache *createKVCache(const Model *model) { + KVCache *cache = 
new KVCache(); + auto ndev = model->dev_resources.size(); + auto nkvh = model->meta.nkvh / ndev; + auto max_len = model->meta.dctx; + auto dh = model->meta.dh; + auto shape = std::vector{max_len, nkvh, dh}; + for (unsigned int idev = 0; idev < ndev; idev++) { + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev])); + auto kcache = std::vector>(); + auto vcache = std::vector>(); + for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) { + kcache.push_back(std::move(Tensor::buffer(model->meta.dt_logits, shape))); + vcache.push_back(std::move(Tensor::buffer(model->meta.dt_logits, shape))); + } + cache->k.push_back(kcache); + cache->v.push_back(vcache); + } + + return cache; +} + +template +struct KVCache *duplicateKVCache(const Model *model, + const KVCache *kv_cache, + unsigned int seq_len) { + auto new_kv_cache = createKVCache(model); + auto ndev = model->dev_resources.size(); + auto nkvh = model->meta.nkvh / ndev; + auto dh = model->meta.dh; + auto dt_size = dsize(model->meta.dt_logits); + for (unsigned int idev = 0; idev < ndev; idev++) { + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev])); + for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) { + RUN_INFINI(infinirtMemcpy(new_kv_cache->k[idev][layer]->data(), + kv_cache->k[idev][layer]->data(), + seq_len * nkvh * dh * dt_size, + INFINIRT_MEMCPY_D2D)); + RUN_INFINI(infinirtMemcpy(new_kv_cache->v[idev][layer]->data(), + kv_cache->v[idev][layer]->data(), + seq_len * nkvh * dh * dt_size, + INFINIRT_MEMCPY_D2D)); + } + } + return new_kv_cache; +} + +template +void dropKVCache(Model const *model, KVCache *kv_cache) { + auto ndev = model->dev_resources.size(); + for (unsigned int idev = 0; idev < ndev; idev++) { + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev])); + for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) { + kv_cache->k[idev][layer].reset(); + kv_cache->v[idev][layer].reset(); + } + } + delete kv_cache; +} +#endif diff --git a/src/models/qwen/qwen_model.hpp b/src/models/qwen/qwen_model.hpp new file mode 100644 index 00000000..4005accc --- /dev/null +++ b/src/models/qwen/qwen_model.hpp @@ -0,0 +1,89 @@ +#ifndef _QWEN_MODEL_H_ +#define _QWEN_MODEL_H_ + +// #include "infinicore_infer/models/qwen3.h" + + +#include "qwen3moe/qwen3moe_model.hpp" + +#include +#include + +/** + * @brief Create a model instance + * @tparam Model Model type + * @tparam Meta Metadata type + * @tparam Weights Weights type + * @param meta Pointer to model config metadata (must not be nullptr) + * @param weights Pointer to model weights, it is a cpu pointer, it will be copied to gpu memory + * @param device Device type + * @param ndev Number of devices (must be positive) + * @param dev_ids Array of device IDs (must not be nullptr if ndev > 0) + * @return Pointer to the created model instance + * @throws std::invalid_argument if any input parameter is invalid + * @throws std::bad_alloc if memory allocation fails + */ +template +Model *createModel(const Meta *meta, + const Weights *weights, + infiniDevice_t device, + int ndev, + const int *dev_ids) { + // Input validation + if (meta == nullptr) { + throw std::invalid_argument("createModel: meta cannot be nullptr"); + } + if (weights == nullptr) { + throw std::invalid_argument("createModel: weights cannot be nullptr"); + } + if (ndev <= 0) { + throw std::invalid_argument("createModel: ndev must be positive"); + } + if (dev_ids == nullptr) { + throw std::invalid_argument("createModel: dev_ids cannot be nullptr"); + } + + // Copy 
device IDs
+    std::vector<int> device_ids(dev_ids, dev_ids + ndev);
+
+    // Create model instance; operator new throws std::bad_alloc on failure,
+    // so no explicit nullptr check is needed here.
+    Model *model = new Model(meta, weights, device, device_ids);
+
+    return model;
+}
+
+/**
+ * @brief Destroy a model instance and clean up resources
+ * @tparam Model Model type
+ * @param model Pointer to the model instance to destroy (must not be nullptr)
+ * @throws std::invalid_argument if model is nullptr
+ */
+template <typename Model>
+void destroyModel(Model *model) {
+    if (model == nullptr) {
+        throw std::invalid_argument("destroyModel: model cannot be nullptr");
+    }
+
+    auto ndev = model->dev_resources.size();
+
+    // Signal all device threads to exit
+    for (size_t idev = 0; idev < ndev; idev++) {
+        std::unique_lock lock(model->states[idev].mtx);
+        model->states[idev].exit_flag = true;
+        lock.unlock();
+        model->states[idev].cv_start.notify_one();
+    }
+
+    // Wait for all device threads to finish
+    for (size_t idev = 0; idev < ndev; idev++) {
+        model->threads[idev].join();
+    }
+
+    // Delete the model instance
+    delete model;
+}
+
+#endif
diff --git a/src/models/qwen/qwen_weight.hpp b/src/models/qwen/qwen_weight.hpp
new file mode 100644
index 00000000..0bc20958
--- /dev/null
+++ b/src/models/qwen/qwen_weight.hpp
@@ -0,0 +1,556 @@
+#ifndef _QWEN_WEIGHT_HPP_
+#define _QWEN_WEIGHT_HPP_
+
+#include "../../tensor.hpp"
+#include "../../utils.hpp"
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <memory>
+#include <stdexcept>
+#include <vector>
+
+namespace Qwen {
+//
+// Host (CPU) weight pointers
+//
+struct MLPCStruct {
+    void *_gate_up_proj_weight{nullptr}; // ("_gate_up_proj_weight", c_void_p),
+    void *_down_proj_weight{nullptr};    // ("_down_proj_weight", c_void_p),
+
+    void print_info() const {
+        printf("\n");
+        printf("\t\t\tMLPCStruct:\n");
+        printf("\t\t\t\tgate_up_proj_weight : %p\n", _gate_up_proj_weight);
+        printf("\t\t\t\tdown_proj_weight : %p\n", _down_proj_weight);
+    }
+};
+
+struct SparseMLPCStruct {
+    size_t _shared_expert_num{0};              // ("_shared_expert_num", c_size_t)
+    size_t _num_experts{0};                    // ("_num_experts", c_size_t)
+    void *_shared_expert_gate_weight{nullptr}; // ("_shared_expert_gate_weight", c_void_p)
+    void *_gate_weight{nullptr};               // ("_gate_weight", c_void_p)
+    MLPCStruct _shared_expert;                 // ("_shared_expert", MLPCStruct)
+    MLPCStruct *_experts{nullptr};             // ("_experts", POINTER(MLPCStruct))
+
+    void print_info() const {
+        printf("\n");
+        printf("\t\tSparseMLPCStruct:\n");
+        printf("\t\t\tshared_expert_gate_weight : %p\n", _shared_expert_gate_weight);
+        printf("\t\t\tgate_weight : %p\n", _gate_weight);
+
+        printf("\t\t\t shared_expert : \n");
+        _shared_expert.print_info();
+
+        printf("\t\t\t experts : \n");
+        if (_experts) {
+            _experts[0].print_info();
+        }
+    }
+};
+
+struct AttentionCStruct {
+    void *_qkv_proj_weight{nullptr}; // ("_qkv_proj_weight", c_void_p)
+    void *_qkv_proj_bias{nullptr};   // ("_qkv_proj_bias", c_void_p)
+    void *_qk_norm_weight{nullptr};  // ("_qk_norm_weight", c_void_p)
+    void *_o_proj_weight{nullptr};   // ("_o_proj_weight", c_void_p)
+
+    void print_info() const {
+        printf("\t\tAttentionCStruct:\n");
+        printf("\t\t\tqkv_proj_weight : %p\n", _qkv_proj_weight);
+        printf("\t\t\tqkv_proj_bias : %p\n", _qkv_proj_bias);
+        printf("\t\t\tqk_norm_weight : %p\n", _qk_norm_weight);
+        printf("\t\t\to_proj_weight : %p\n", _o_proj_weight);
+    }
+};
+
+template <typename FFNCStruct>
+struct DecoderLayerCStruct {
+    int _ilayer{0};
+    void *_post_attention_layernorm_weight{nullptr}; // ("_post_attention_layernorm_weight", c_void_p),
+    void *_input_layernorm_weight{nullptr};          // ("_input_layernorm_weight", c_void_p),
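+    // NOTE: the members and their order mirror the ctypes structure used on the
+    // loader side (see the `(..., c_void_p)` annotations), so the layout here has
+    // to stay in sync with it.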
AttentionCStruct _self_attn; + FFNCStruct _mlp; + + void print_info() const { + printf("\tDecoderLayerCStruct:\n"); + printf("\t\tilayer : %d\n", _ilayer); + printf("\t\tpost_attention_layernorm_weight : %p\n", _post_attention_layernorm_weight); + printf("\t\tinput_layernorm_weight : %p\n", _input_layernorm_weight); + _self_attn.print_info(); + _mlp.print_info(); + } +}; + +}; // namespace Qwen + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace Qwen { +// getAttnNorm getFFNNorm getOutNorm getAttnQKNorm +/** + * @brief Create a normalization weight tensor + * @param d Dimension size + * @param dt_norm Data type for normalization weights + * @param norm_weight_ptr is cpu pointer, it will be copied to gpu memory + * @return Shared pointer to the created Tensor + * @throws std::invalid_argument if norm_weight_ptr is nullptr + */ +inline std::shared_ptr getNorm(size_t d, infiniDtype_t dt_norm, void *norm_weight_ptr) { + if (norm_weight_ptr == nullptr) { + throw std::invalid_argument("getNorm: norm_weight_ptr cannot be nullptr"); + } + auto shape = std::vector({d}); + return Tensor::weight(static_cast(norm_weight_ptr), dt_norm, shape); +} + +/** + * @brief Create a sine table for RoPE (Rotary Position Embedding) + * @param dh Head dimension + * @param theta Base frequency parameter + * @param dctx Maximum context length + * @param dt_logits Data type for the table + * @return Shared pointer to the created Tensor + * @throws std::runtime_error if memory allocation fails + */ +inline std::shared_ptr getSinTable(size_t dh, float theta, size_t dctx, infiniDtype_t dt_logits) { + + if (theta <= 0.0f) { + throw std::invalid_argument("getSinTable: theta must be positive"); + } + + auto half_dh = dh / 2; + auto unit = dsize(dt_logits); + void *table = std::malloc(dctx * half_dh * unit); + + for (size_t i = 0; i < dctx; i++) { + for (size_t j = 0; j < half_dh; j++) { + float _sin = std::sin( + static_cast(i) / std::pow(theta, static_cast(j) / half_dh)); + if (dt_logits == INFINI_DTYPE_F16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin); + } else if (dt_logits == INFINI_DTYPE_BF16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_sin); + } else if (dt_logits == INFINI_DTYPE_F32) { + ((float *)table)[i * half_dh + j] = _sin; + } else { + throw std::invalid_argument("getSinTable: unsupported data type"); + } + } + } + auto shape = std::vector({dctx, half_dh}); + auto tensor = Tensor::weight(table, dt_logits, shape); + std::free(table); + return tensor; +} + +/** + * @brief Create a cosine table for RoPE (Rotary Position Embedding) + * @param dh Head dimension + * @param theta Base frequency parameter + * @param dctx Maximum context length + * @param dt_logits Data type for the table + * @return Shared pointer to the created Tensor + * @throws std::runtime_error if memory allocation fails + */ +inline std::shared_ptr getCosTable(size_t dh, float theta, size_t dctx, infiniDtype_t dt_logits) { + auto half_dh = dh / 2; + auto unit = dsize(dt_logits); + void *table = std::malloc(dctx * half_dh * unit); + + for (size_t i = 0; i < dctx; i++) { + for (size_t j = 0; j < half_dh; j++) { + float _cos = std::cos( + static_cast(i) / std::pow(theta, static_cast(j) / half_dh)); + if (dt_logits == INFINI_DTYPE_F16) { + 
((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos); + } else if (dt_logits == INFINI_DTYPE_BF16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_cos); + } else if (dt_logits == INFINI_DTYPE_F32) { + ((float *)table)[i * half_dh + j] = _cos; + } else { + throw std::invalid_argument("getCosTable: unsupported data type"); + } + } + } + auto shape = std::vector({dctx, half_dh}); + auto tensor = Tensor::weight(table, dt_logits, shape); + std::free(table); + return tensor; +} + +/** + * @brief Create input embedding tensor + * @param d Hidden dimension + * @param dvoc Vocabulary size + * @param dt_logits Data type + * @param embed_tokens_weight_ptr is cpu pointer, it will be copied to gpu memory + * @return Shared pointer to the created Tensor + * @throws std::invalid_argument if embed_tokens_weight_ptr is nullptr + */ +inline std::shared_ptr getInEmbd(size_t d, size_t dvoc, infiniDtype_t dt_logits, void *embed_tokens_weight_ptr) { + if (embed_tokens_weight_ptr == nullptr) { + throw std::invalid_argument("getInEmbd: embed_tokens_weight_ptr cannot be nullptr"); + } + auto shape = std::vector({dvoc, d}); + return Tensor::weight(static_cast(embed_tokens_weight_ptr), dt_logits, shape); +} + +/** + * @brief Create output embedding (LM head) tensor + * @param d Hidden dimension + * @param dvoc Vocabulary size + * @param dt_logits Data type + * @param transpose_linear_weights Whether to transpose weights + * @param lm_head_weight_ptr is cpu pointer, it will be copied to gpu memory + * @return Shared pointer to the created Tensor + * @throws std::invalid_argument if lm_head_weight_ptr is nullptr + */ +inline std::shared_ptr getOutEmbd(size_t d, size_t dvoc, infiniDtype_t dt_logits, int transpose_linear_weights, void *lm_head_weight_ptr) { + if (lm_head_weight_ptr == nullptr) { + throw std::invalid_argument("getOutEmbd: lm_head_weight_ptr cannot be nullptr"); + } + if (transpose_linear_weights != 0) { + auto shape = std::vector({dvoc, d}); + return Tensor::weight(static_cast(lm_head_weight_ptr), dt_logits, shape)->permute({1, 0}); + } else { + auto shape = std::vector({d, dvoc}); + return Tensor::weight(static_cast(lm_head_weight_ptr), dt_logits, shape); + } +} + +}; // namespace Qwen + +namespace Qwen { +class BaseMLPTensor { +public: + std::shared_ptr w_ffn_gate_up; + std::shared_ptr w_ffn_down; + +public: + BaseMLPTensor() = default; + + void Init(size_t di, size_t d, infiniDtype_t dt_mat, int transpose_linear_weights, size_t idev, size_t ndev, void *gate_up_proj_weight_ptr, void *down_proj_weight_ptr) { + this->w_ffn_gate_up = this->getFFNGateUp(di, d, dt_mat, transpose_linear_weights, idev, ndev, gate_up_proj_weight_ptr); + this->w_ffn_down = this->getFFNDown(di, d, dt_mat, transpose_linear_weights, idev, ndev, down_proj_weight_ptr); + } + +private: + inline std::shared_ptr getFFNGateUp(size_t di, size_t d, infiniDtype_t dt_mat, int transpose_linear_weights, size_t idev, size_t ndev, void *gate_up_proj_weight_ptr) { + + size_t offset = idev * (2 * di / ndev) * d * dsize(dt_mat); + if (transpose_linear_weights != 0) { + auto shape = std::vector({2 * di / ndev, d}); + return Tensor::weight((char *)(gate_up_proj_weight_ptr) + offset, dt_mat, shape)->permute({1, 0}); + } else { + auto shape = std::vector({d, 2 * di / ndev}); + return Tensor::weight((char *)(gate_up_proj_weight_ptr) + offset, dt_mat, shape); + } + } + + inline std::shared_ptr getFFNDown(size_t di, size_t d, infiniDtype_t dt_mat, int transpose_linear_weights, size_t idev, size_t ndev, void *down_proj_weight_ptr) { + size_t 
offset = idev * d * (di / ndev) * dsize(dt_mat); + if (transpose_linear_weights != 0) { + auto shape = std::vector({d, di / ndev}); + return Tensor::weight((char *)(down_proj_weight_ptr) + offset, dt_mat, shape)->permute({1, 0}); + } else { + auto shape = std::vector({di / ndev, d}); + return Tensor::weight((char *)(down_proj_weight_ptr) + offset, dt_mat, shape); + } + } +}; + +class BaseAttentionTensor { +public: + std::shared_ptr w_attn_qkv; + std::shared_ptr b_attn_qkv; + std::shared_ptr w_attn_qk_norm; + std::shared_ptr w_attn_out; + +public: + BaseAttentionTensor() = default; + void Init(size_t nkvh, + size_t nh, + size_t dh, + size_t d, + infiniDtype_t dt_mat, + infiniDtype_t dt_norm, + int transpose_linear_weights, + size_t idev, size_t ndev, + void *qkv_proj_weight_ptr, + void *qkv_proj_bias_ptr, + void *qk_norm_weight_ptr, + void *o_proj_weight_ptr) { + + this->w_attn_qkv = this->getAttnQKV(nkvh, nh, dh, d, dt_mat, dt_norm, transpose_linear_weights, idev, ndev, qkv_proj_weight_ptr); + + if (qkv_proj_bias_ptr != nullptr) { + this->b_attn_qkv = this->getAttnQKVBias(nkvh, nh, dh, d, dt_mat, dt_norm, transpose_linear_weights, idev, ndev, qkv_proj_bias_ptr); + } + + if (qk_norm_weight_ptr != nullptr) { + this->w_attn_qk_norm = getNorm(dh * 2, dt_norm, qk_norm_weight_ptr); + } + + this->w_attn_out = this->getAttnO(nkvh, nh, dh, d, dt_mat, dt_norm, transpose_linear_weights, idev, ndev, o_proj_weight_ptr); + } + +private: + inline std::shared_ptr getAttnQKV(size_t nkvh, + size_t nh, + size_t dh, + size_t d, + infiniDtype_t dt_mat, + infiniDtype_t dt_norm, + int transpose_linear_weights, + size_t idev, + size_t ndev, void *qkv_proj_weight_ptr) { + + size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(dt_mat); + if (transpose_linear_weights != 0) { + auto shape = std::vector({(nh + 2 * nkvh) / ndev * dh, d}); + return Tensor::weight((char *)(qkv_proj_weight_ptr) + offset, dt_mat, shape)->permute({1, 0}); + } else { + auto shape = std::vector({d, (nh + 2 * nkvh) / ndev * dh}); + return Tensor::weight((char *)(qkv_proj_weight_ptr) + offset, dt_mat, shape); + } + } + + inline std::shared_ptr getAttnQKVBias(size_t nkvh, + size_t nh, + size_t dh, + size_t d, + infiniDtype_t dt_mat, + infiniDtype_t dt_norm, + int transpose_linear_weights, + size_t idev, + size_t ndev, void *qkv_proj_bias_ptr) { + + size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * dsize(dt_mat); + auto shape = std::vector({(nh + 2 * nkvh) / ndev * dh}); + + return Tensor::weight((char *)(qkv_proj_bias_ptr) + offset, dt_mat, shape); + } + + inline std::shared_ptr getAttnO(size_t nkvh, + size_t nh, + size_t dh, + size_t d, + infiniDtype_t dt_mat, + infiniDtype_t dt_norm, + int transpose_linear_weights, + size_t idev, + size_t ndev, void *o_proj_weight_ptr) { + + size_t offset = idev * d * (nh / ndev * dh) * dsize(dt_mat); + if (transpose_linear_weights != 0) { + auto shape = std::vector({d, nh / ndev * dh}); + return Tensor::weight((char *)(o_proj_weight_ptr) + offset, dt_mat, shape)->permute({1, 0}); + } else { + auto shape = std::vector({nh / ndev * dh, d}); + return Tensor::weight((char *)(o_proj_weight_ptr) + offset, dt_mat, shape); + } + } +}; + +}; // namespace Qwen + +// +// 存储 gpu 地址 +// +namespace Qwen { + +template +class MLPTensor : public Qwen::BaseMLPTensor { +public: + MLPTensor(Meta const *meta, Weights const *w, int ilayer, size_t idev, size_t ndev) { + size_t di = meta->di; + size_t d = meta->d; + infiniDtype_t dt_mat = w->_dt_mat; + int transpose_linear_weights = w->_transpose_linear_weights; + 
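        // The pointers fetched below are host (CPU) addresses taken from the per-layer
        // C structs; BaseMLPTensor::Init() slices each matrix by device (idev / ndev)
        // and Tensor::weight() copies the selected shard to device memory.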
void *gate_up_proj_weight_ptr = w->_layers[ilayer]._mlp._gate_up_proj_weight; + void *down_proj_weight_ptr = w->_layers[ilayer]._mlp._down_proj_weight; + this->Init(di, d, dt_mat, transpose_linear_weights, idev, ndev, gate_up_proj_weight_ptr, down_proj_weight_ptr); + } + +public: + void print_info() const { + printf("\t\t\t Qwen3::MLPTensor \n"); + printf("\t\t\t\t w_ffn_gate_up :: %p\t%s \n", w_ffn_gate_up.get(), w_ffn_gate_up->info().c_str()); + printf("\t\t\t\t w_ffn_down :: %p\t%s \n", w_ffn_down.get(), w_ffn_down->info().c_str()); + } +}; + +template +class SharedMLPTensor : public Qwen::BaseMLPTensor { +public: + SharedMLPTensor(Meta const *meta, Weights const *w, int ilayer, size_t idev, size_t ndev) { + size_t di = meta->_shared_expert_intermediate_size; + size_t d = meta->d; + infiniDtype_t dt_mat = w->_dt_mat; + int transpose_linear_weights = w->_transpose_linear_weights; + void *gate_up_proj_weight_ptr = w->_layers[ilayer]._mlp._shared_expert._gate_up_proj_weight; + void *down_proj_weight_ptr = w->_layers[ilayer]._mlp._shared_expert._down_proj_weight; + this->Init(di, d, dt_mat, transpose_linear_weights, idev, ndev, gate_up_proj_weight_ptr, down_proj_weight_ptr); + } + +public: + void print_info() const { + printf("\t\t\t\t SharedMLPTensor \n"); + printf("\t\t\t\t\t w_ffn_gate_up :: %p\t%s \n", w_ffn_gate_up.get(), w_ffn_gate_up->info().c_str()); + printf("\t\t\t\t\t w_ffn_down :: %p\t%s \n", w_ffn_down.get(), w_ffn_down->info().c_str()); + } +}; + +template +class RouterMLPTensor : public Qwen::BaseMLPTensor { +public: + RouterMLPTensor(Meta const *meta, Weights const *w, int ilayer, int iexpert, size_t idev, size_t ndev) { + size_t di = meta->_moe_intermediate_size; + size_t d = meta->d; + infiniDtype_t dt_mat = w->_dt_mat; + int transpose_linear_weights = w->_transpose_linear_weights; + void *gate_up_proj_weight_ptr = w->_layers[ilayer]._mlp._experts[iexpert]._gate_up_proj_weight; + void *down_proj_weight_ptr = w->_layers[ilayer]._mlp._experts[iexpert]._down_proj_weight; + this->Init(di, d, dt_mat, transpose_linear_weights, idev, ndev, gate_up_proj_weight_ptr, down_proj_weight_ptr); + } + +public: + void print_info() const { + printf("\t\t\t\t RouterMLPTensor \n"); + printf("\t\t\t\t\t w_ffn_gate_up :: %p\t%s \n", w_ffn_gate_up.get(), w_ffn_gate_up->info().c_str()); + printf("\t\t\t\t\t w_ffn_down :: %p\t%s \n", w_ffn_down.get(), w_ffn_down->info().c_str()); + } +}; + +template +class SparseMLPTensor { +public: + size_t _shared_expert_num; + size_t _num_experts; + std::shared_ptr _shared_expert_gate_weight; + std::shared_ptr _gate_weight; + std::shared_ptr _shared_expert; + std::vector> _experts; + +public: + SparseMLPTensor(Meta const *meta, Weights const *w, int ilayer, size_t idev, size_t ndev) { + this->_shared_expert_num = 1; + this->_num_experts = meta->_num_experts; + + if (w->_layers[ilayer]._mlp._shared_expert_gate_weight) { + // gate + void *shared_expert_gate = w->_layers[ilayer]._mlp._shared_expert_gate_weight; + auto shape = std::vector({meta->d, 1}); + this->_shared_expert_gate_weight = Tensor::weight((char *)(shared_expert_gate), w->_dt_mat, shape); + + // 权重 + this->_shared_expert = std::make_shared(meta, w, ilayer, idev, ndev); + } + + // + void *experts_gate = w->_layers[ilayer]._mlp._gate_weight; + auto shape = std::vector({meta->d, meta->_num_experts}); + this->_gate_weight = Tensor::weight((char *)(experts_gate), w->_dt_mat, shape); + + // experts + this->_experts.reserve(meta->_num_experts); + for (size_t iexpert = 0; iexpert < meta->_num_experts; 
++iexpert) { + this->_experts.push_back( + std::make_shared(meta, w, ilayer, iexpert, idev, ndev)); + } + } + +public: + void print_info() const { + printf("\t\t\t SparseMLPTensor \n"); + printf("\t\t\t\t shared_expert_num %ld \n", _shared_expert_num); + printf("\t\t\t\t shared_expert_gate_weight %p %s \n", _shared_expert_gate_weight.get(), _shared_expert_gate_weight.get() ? _shared_expert_gate_weight->info().c_str() : ""); + printf("\t\t\t\t gate_weight %p %s \n", _gate_weight.get(), _gate_weight.get() ? _gate_weight->info().c_str() : ""); + printf("\n"); + printf("\t\t\t\t _shared_expert %p %s \n", _shared_expert.get(), _shared_expert.get() ? "_shared_expert" : ""); + if (_shared_expert) { + _shared_expert->print_info(); + } + printf("\n"); + printf("\t\t\t\t _experts size %ld \n", _experts.size()); + for (auto expert : _experts) { + expert->print_info(); + break; + } + } +}; + +template +class AttentionTensor : public Qwen::BaseAttentionTensor { +public: + AttentionTensor(Meta const *meta, Weights const *w, size_t ilayer, size_t idev, size_t ndev) { + size_t nkvh = meta->nkvh; + size_t nh = meta->nh; + size_t dh = meta->dh; + size_t d = meta->d; + infiniDtype_t dt_mat = w->_dt_mat; + infiniDtype_t dt_norm = w->_dt_norm; + int transpose_linear_weights = w->_transpose_linear_weights; + + void *qkv_proj_weight_ptr = w->_layers[ilayer]._self_attn._qkv_proj_weight; + void *qkv_proj_bias_ptr = w->_layers[ilayer]._self_attn._qkv_proj_bias; + void *qk_norm_weight_ptr = w->_layers[ilayer]._self_attn._qk_norm_weight; + void *o_proj_weight_ptr = w->_layers[ilayer]._self_attn._o_proj_weight; + + this->Init(nkvh, nh, dh, d, dt_mat, dt_norm, transpose_linear_weights, idev, ndev, qkv_proj_weight_ptr, qkv_proj_bias_ptr, qk_norm_weight_ptr, o_proj_weight_ptr); + } + + void print_info() const { + printf("\t\t\t AttentionTensor \n"); + printf("\t\t\t\t w_attn_qkv :: %p\t%s \n", w_attn_qkv.get(), w_attn_qkv->info().c_str()); + printf("\t\t\t\t b_attn_qkv :: %p\t%s \n", b_attn_qkv.get(), b_attn_qkv.get() ? b_attn_qkv->info().c_str() : ""); + printf("\t\t\t\t w_attn_qk_norm :: %p\t%s \n", w_attn_qk_norm.get(), w_attn_qk_norm.get() ? 
w_attn_qk_norm->info().c_str() : ""); + printf("\t\t\t\t w_attn_out :: %p\t%s \n", w_attn_out.get(), w_attn_out->info().c_str()); + } +}; + +template +class DecoderLayerTensor { +public: + int ilayer; + std::shared_ptr w_attn_norm; + std::shared_ptr w_ffn_norm; + std::shared_ptr self_attn; + std::shared_ptr ffn; + +public: + DecoderLayerTensor(Meta const *meta, Weights const *w, size_t ilayer, size_t idev, size_t ndev) { + this->ilayer = ilayer; + + size_t d = meta->d; + infiniDtype_t dt_norm = w->_dt_norm; + void *att_norm_weight_ptr = w->_layers[ilayer]._input_layernorm_weight; + void *ffn_norm_weight_ptr = w->_layers[ilayer]._post_attention_layernorm_weight; + + this->w_attn_norm = Qwen::getNorm(d, dt_norm, att_norm_weight_ptr); + this->w_ffn_norm = Qwen::getNorm(d, dt_norm, ffn_norm_weight_ptr); + + this->self_attn = std::make_shared(meta, w, ilayer, idev, ndev); + this->ffn = std::make_shared(meta, w, ilayer, idev, ndev); + } + void print_info() const { + printf("\n "); + printf("\t\t DecoderLayerTensor %d \n ", ilayer); + printf("\t\t\t w_attn_norm :: %p\t%s \n", w_attn_norm.get(), w_attn_norm->info().c_str()); + printf("\t\t\t w_ffn_norm :: %p\t%s \n", w_ffn_norm.get(), w_ffn_norm->info().c_str()); + printf("\n"); + printf("\t\t\t self_attn :: %p \n", self_attn.get()); + self_attn->print_info(); + printf("\n"); + printf("\t\t\t ffn :: %p \n", ffn.get()); + ffn->print_info(); + } +}; + +}; // namespace Qwen + +#endif diff --git a/third_party/spdlog b/third_party/spdlog new file mode 160000 index 00000000..88a0e07a --- /dev/null +++ b/third_party/spdlog @@ -0,0 +1 @@ +Subproject commit 88a0e07ad5bb3e2651cd5613530b3f06a15fc400 diff --git a/xmake.lua b/xmake.lua index 598ac534..172da092 100644 --- a/xmake.lua +++ b/xmake.lua @@ -14,6 +14,7 @@ target("infinicore_infer") add_files("src/models/*.cpp") add_files("src/models/*/*.cpp") + add_files("src/models/qwen/*/*.cpp") add_files("src/tensor/*.cpp") add_files("src/allocator/*.cpp") add_files("src/dataloader/*.cpp")
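
Usage sketch: the templated helpers in qwen_model.hpp, qwen_kv_cache.hpp and qwen_device_resource.hpp are intended to be driven by a model-specific front end roughly as follows. MyModel, MyMeta and MyWeights are hypothetical stand-ins (they must expose the members the templates touch: meta, device, dev_ids, dev_resources, states, threads, req), every literal value is made up, and the sketch assumes that createModel's template parameters are declared in the order its doc comment lists them (so only Model has to be spelled out) and that the concrete Model constructor launches the per-device worker threads via launchDevice, since inferBatch only signals threads that already exist.

#include "qwen_model.hpp"           // createModel / destroyModel
#include "qwen_kv_cache.hpp"        // createKVCache / dropKVCache
#include "qwen_device_resource.hpp" // inferBatch / forwardBatch
#include <cstdint>

// MyMeta / MyWeights / MyModel are placeholders for a concrete model type.
void run_single_request(const MyMeta *meta, const MyWeights *weights,
                        infiniDevice_t device, int ndev, const int *dev_ids) {
    MyModel *model = createModel<MyModel>(meta, weights, device, ndev, dev_ids);
    KVCache *cache = createKVCache(model);  // one cache per request

    uint32_t tokens[] = {9707, 11, 1879};   // prompt token ids (made-up values)
    uint32_t req_len = 3;                   // this request contributes 3 tokens
    uint32_t req_pos = 0;                   // first round, cache is empty
    KVCache *caches[] = {cache};

    float temperature = 1.0f, topp = 1.0f;  // sampling parameters (placeholders)
    uint32_t topk = 1;
    uint32_t next_token = 0;

    // Blocks until every device thread has finished this round (waits on cv_done).
    inferBatch(model, tokens, /*ntok=*/3, &req_len, /*nreq=*/1, &req_pos,
               caches, &temperature, &topk, &topp, &next_token);

    dropKVCache(model, cache);
    destroyModel(model);
}

forwardBatch follows the same call pattern, except that the sampling parameters and the output token array are replaced by a logits buffer.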