diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 31bd8e1b1..3df37bddc 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -7,6 +7,7 @@ add_subdirectory(minicpm_o) add_subdirectory(minicpm4) add_subdirectory(qwen3) add_subdirectory(qwen3_service) +add_subdirectory(qwen3_moe) add_subdirectory(deepseek_ocr) if(MLLM_BUILD_QNN_BACKEND) diff --git a/examples/qwen3_moe/CMakeLists.txt b/examples/qwen3_moe/CMakeLists.txt new file mode 100644 index 000000000..d20fa8158 --- /dev/null +++ b/examples/qwen3_moe/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(mllm-qwen3-moe-runner main.cpp) +target_link_libraries(mllm-qwen3-moe-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen3-moe-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen3_moe/config_30B_A3B_gguf.json b/examples/qwen3_moe/config_30B_A3B_gguf.json new file mode 100644 index 000000000..0ae3fd17d --- /dev/null +++ b/examples/qwen3_moe/config_30B_A3B_gguf.json @@ -0,0 +1,37 @@ +{ + "architectures": [ + "Qwen3MoeForCausalLM" + ], + "attention_bias": false, + "bos_token_id": 151643, + "decoder_sparse_step": 1, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "max_position_embeddings": 262144, + "max_window_layers": 48, + "mlp_only_layers": [], + "model_type": "qwen3_moe", + "moe_intermediate_size": 768, + "norm_topk_prob": true, + "num_attention_heads": 32, + "num_experts": 128, + "num_experts_per_tok": 8, + "num_hidden_layers": 48, + "num_key_value_heads": 4, + "output_router_logits": false, + "rms_norm_eps": 1e-06, + "rope_scaling": 1.0, + "rope_theta": 10000000, + "router_aux_loss_coef": 0.001, + "tie_word_embeddings": true, + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "max_cache_length": 16384, + "linear_impl_type": "Default" +} diff --git a/examples/qwen3_moe/main.cpp b/examples/qwen3_moe/main.cpp new file mode 100644 index 000000000..367bbae26 --- /dev/null +++ b/examples/qwen3_moe/main.cpp @@ -0,0 +1,80 @@ +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + +#ifdef MLLM_PERFETTO_ENABLE + mllm::perf::start(); +#endif + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } else { + fmt::print("āŒ Unsupported model_version: {} (expected v1 or v2)\n", model_version.get()); + mllm::shutdownContext(); + return 1; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen3_moe_cfg = mllm::models::qwen3_moe::Qwen3MoeConfig(config_path.get()); + auto qwen3_moe_tokenizer = mllm::models::qwen3_moe::Qwen3Tokenizer(tokenizer_path.get()); + auto qwen3_moe = mllm::models::qwen3_moe::Qwen3MoeForCausalLM(qwen3_moe_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen3_moe.load(param); + + fmt::print("\n{:*^60}\n", " Qwen3 MoE Interactive CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string prompt_text; + + fmt::print("šŸ’¬ Prompt text (or 'exit/quit'): "); + std::getline(std::cin, prompt_text); + + if(prompt_text == "exit" || prompt_text == "quit") { return 0; } + + try { + fmt::print("šŸ”„ Processing...\n"); + auto inputs = qwen3_moe_tokenizer.convertMessage({.prompt = prompt_text}); + + fmt::print("\nšŸ¤– Response: "); + + // Use for loop + for (auto& step : qwen3_moe.chat(inputs)) { std::wcout << qwen3_moe_tokenizer.detokenize(step.cur_token_id) << std::flush; } + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nāŒ Error: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen3_moe.perfSummary(); + } + +#ifdef MLLM_PERFETTO_ENABLE + mllm::perf::stop(); + mllm::perf::saveReport("qwen3_moe.perf"); +#endif + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen3_moe/quant_cfg_30B_q4_k.json b/examples/qwen3_moe/quant_cfg_30B_q4_k.json new file mode 100644 index 000000000..f93829ab0 --- /dev/null +++ b/examples/qwen3_moe/quant_cfg_30B_q4_k.json @@ -0,0 +1,79 @@ +{ + "^model\\.layers\\.\\d+\\.self_attn\\.q_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 4096, + 2048 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.self_attn\\.k_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 512, + 2048 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.self_attn\\.v_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q6_K", + "shape": [ + 512, + 2048 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.self_attn\\.o_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 2048, + 4096 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.up_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 768, + 2048 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.down_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q6_K", + "shape": [ + 2048, + 768 + ], + "replace": true + } + }, + "^lm_head.weight": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 151936, + 2048 + ], + "replace": true + } + } +} diff --git a/mllm/backends/cpu/kernels/common/elewise-inl.hpp b/mllm/backends/cpu/kernels/common/elewise-inl.hpp index a2f2ee429..b839e32dc 100644 --- a/mllm/backends/cpu/kernels/common/elewise-inl.hpp +++ b/mllm/backends/cpu/kernels/common/elewise-inl.hpp @@ -8,31 +8,6 @@ HWY_BEFORE_NAMESPACE(); namespace mllm::cpu::common { // NOLINT namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; - -//===----------------------------------------------------------------------===// -// Elementwise + - * / By Matrix -//===----------------------------------------------------------------------===// -template -HWY_INLINE void elementwise_impl(const T* HWY_RESTRICT x, const T* HWY_RESTRICT y, T* HWY_RESTRICT out, size_t count, Op&& op) { - const hn::ScalableTag d; - const size_t N = hn::Lanes(d); - size_t idx = 0; - - for (; idx + N <= count; idx += N) { - const hn::Vec vx = hn::LoadU(d, x + idx); - const hn::Vec vy = hn::LoadU(d, y + idx); - const hn::Vec result = op(d, vx, vy); - hn::StoreU(result, d, out + idx); - } - - if (idx < count) { - const hn::Vec vx = hn::LoadN(d, x + idx, count - idx); - const hn::Vec vy = hn::LoadN(d, y + idx, count - idx); - const hn::Vec result = op(d, vx, vy); - hn::StoreN(result, d, out + idx, count - idx); - } -} - struct AddOp { template HWY_INLINE V operator()(D d, V a, V b) const { @@ -61,6 +36,30 @@ struct DivOp { } }; +//===----------------------------------------------------------------------===// +// Elementwise + - * / By Matrix +//===----------------------------------------------------------------------===// +template +HWY_INLINE void elementwise_impl(const T* HWY_RESTRICT x, const T* HWY_RESTRICT y, T* HWY_RESTRICT out, size_t count, Op&& op) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { + const hn::Vec vx = hn::LoadU(d, x + idx); + const hn::Vec vy = hn::LoadU(d, y + idx); + const hn::Vec result = op(d, vx, vy); + hn::StoreU(result, d, out + idx); + } + + if (idx < count) { + const hn::Vec vx = hn::LoadN(d, x + idx, count - idx); + const hn::Vec vy = hn::LoadN(d, y + idx, count - idx); + const hn::Vec result = op(d, vx, vy); + hn::StoreN(result, d, out + idx, count - idx); + } +} + HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { elementwise_impl(x, y, out, n, AddOp{}); } @@ -77,12 +76,81 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_fp32(mllm_fp32_t* out, const mllm elementwise_impl(x, y, out, n, DivOp{}); } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n) { +// elementwise_impl(x, y, out, n, AddOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n) { +// elementwise_impl(x, y, out, n, SubOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n) { +// elementwise_impl(x, y, out, n, MulOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n) { +// elementwise_impl(x, y, out, n, DivOp{}); +// } + + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + elementwise_impl(x, y, out, n, AddOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + elementwise_impl(x, y, out, n, SubOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + elementwise_impl(x, y, out, n, MulOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + elementwise_impl(x, y, out, n, DivOp{}); +} + + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + elementwise_impl(x, y, out, n, AddOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + elementwise_impl(x, y, out, n, SubOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + elementwise_impl(x, y, out, n, MulOp{}); +} + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { +// elementwise_impl(x, y, out, n, DivOp{}); +// } + + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + elementwise_impl(x, y, out, n, AddOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + elementwise_impl(x, y, out, n, SubOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + elementwise_impl(x, y, out, n, MulOp{}); +} + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { +// elementwise_impl(x, y, out, n, DivOp{}); +// } + + //===----------------------------------------------------------------------===// // Elementwise + - * / By Const //===----------------------------------------------------------------------===// template -HWY_INLINE void elementwise_scalar_impl(T* HWY_RESTRICT out, const T* HWY_RESTRICT x, const T y, size_t count, Op&& op) { +HWY_INLINE void elementwise_scl_impl(T* HWY_RESTRICT out, const T* HWY_RESTRICT x, const T y, size_t count, Op&& op) { const hn::ScalableTag d; const size_t N = hn::Lanes(d); size_t idx = 0; @@ -103,50 +171,91 @@ HWY_INLINE void elementwise_scalar_impl(T* HWY_RESTRICT out, const T* HWY_RESTRI } } -struct AddScalarOp { - template - HWY_INLINE V operator()(D d, V a, V b) const { - return hn::Add(a, b); - } -}; +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, AddOp{}); +} -struct SubScalarOp { - template - HWY_INLINE V operator()(D d, V a, V b) const { - return hn::Sub(a, b); - } -}; +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, SubOp{}); +} -struct MulScalarOp { - template - HWY_INLINE V operator()(D d, V a, V b) const { - return hn::Mul(a, b); - } -}; +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, MulOp{}); +} -struct DivScalarOp { - template - HWY_INLINE V operator()(D d, V a, V b) const { - return hn::Div(a, b); - } -}; +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, DivOp{}); +} + + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t y, size_t n) { +// elementwise_scl_impl(out, x, y, n, AddOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t y, size_t n) { +// elementwise_scl_impl(out, x, y, n, SubOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t y, size_t n) { +// elementwise_scl_impl(out, x, y, n, MulOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t y, size_t n) { +// elementwise_scl_impl(out, x, y, n, DivOp{}); +// } + + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, AddOp{}); +} -HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { - elementwise_scalar_impl(out, x, y, n, AddScalarOp{}); +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, SubOp{}); } -HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { - elementwise_scalar_impl(out, x, y, n, SubScalarOp{}); +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, MulOp{}); } -HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { - elementwise_scalar_impl(out, x, y, n, MulScalarOp{}); +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, DivOp{}); } -HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { - elementwise_scalar_impl(out, x, y, n, DivScalarOp{}); + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, AddOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, SubOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, MulOp{}); } +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, DivOp{}); +} + + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, AddOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, SubOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, MulOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, DivOp{}); +} + + //===----------------------------------------------------------------------===// // Inplace Elementwise + - * / // diff --git a/mllm/backends/cpu/kernels/common/ggml/quantize/quantize.hpp b/mllm/backends/cpu/kernels/common/ggml/quantize/quantize.hpp index edb983058..e318451a1 100644 --- a/mllm/backends/cpu/kernels/common/ggml/quantize/quantize.hpp +++ b/mllm/backends/cpu/kernels/common/ggml/quantize/quantize.hpp @@ -111,7 +111,7 @@ namespace mllm::cpu { static float table_f32_f16[1 << 16]; static bool table_f32_f16_init = false; -inline static float lookup_fp16_to_fp32(mllm_fp16_t f) { +inline static float lookup_fp16_to_fp32(uint16_t f) { if (!table_f32_f16_init) { uint16_t ii; for (int i = 0; i < (1 << 16); ++i) { diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp index 7e81adfdf..324039c8f 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp @@ -18,54 +18,158 @@ // Include all inline implementations here #include "mllm/backends/cpu/kernels/common/elewise-inl.hpp" #include "mllm/backends/cpu/kernels/common/fill-inl.hpp" +#include "mllm/backends/cpu/kernels/common/reduce-inl.hpp" #if HWY_ONCE namespace mllm::cpu::common { //===----------------------------------------------------------------------===// -// Element-wise +// Elementwise + - * / By Matrix //===----------------------------------------------------------------------===// HWY_EXPORT(elewise_add_fp32); HWY_EXPORT(elewise_sub_fp32); HWY_EXPORT(elewise_mul_fp32); HWY_EXPORT(elewise_div_fp32); -HWY_EXPORT(elewise_add_scalar_fp32); -HWY_EXPORT(elewise_sub_scalar_fp32); -HWY_EXPORT(elewise_mul_scalar_fp32); -HWY_EXPORT(elewise_div_scalar_fp32); +// HWY_EXPORT(elewise_add_fp16); +// HWY_EXPORT(elewise_sub_fp16); +// HWY_EXPORT(elewise_mul_fp16); +// HWY_EXPORT(elewise_div_fp16); +HWY_EXPORT(elewise_add_int32); +HWY_EXPORT(elewise_sub_int32); +HWY_EXPORT(elewise_mul_int32); +HWY_EXPORT(elewise_div_int32); +HWY_EXPORT(elewise_add_int16); +HWY_EXPORT(elewise_sub_int16); +HWY_EXPORT(elewise_mul_int16); +// HWY_EXPORT(elewise_div_int16); +HWY_EXPORT(elewise_add_int8); +HWY_EXPORT(elewise_sub_int8); +HWY_EXPORT(elewise_mul_int8); +// HWY_EXPORT(elewise_div_int8); HWY_DLLEXPORT void call_elewise_add_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_add_fp32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_sub_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_sub_fp32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_mul_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_mul_fp32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_div_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_div_fp32)(out, x, y, n); } - -HWY_DLLEXPORT void call_elewise_add_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { - HWY_DYNAMIC_DISPATCH(elewise_add_scalar_fp32)(out, x, y, n); +HWY_DLLEXPORT void call_elewise_add_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_int32)(out, x, y, n); } - -HWY_DLLEXPORT void call_elewise_sub_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { - HWY_DYNAMIC_DISPATCH(elewise_sub_scalar_fp32)(out, x, y, n); +HWY_DLLEXPORT void call_elewise_sub_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_int32)(out, x, y, n); } - -HWY_DLLEXPORT void call_elewise_mul_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { - HWY_DYNAMIC_DISPATCH(elewise_mul_scalar_fp32)(out, x, y, n); +HWY_DLLEXPORT void call_elewise_mul_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_int32)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_div_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_div_int32)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_add_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_int16)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_sub_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_int16)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_mul_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_int16)(out, x, y, n); } +// HWY_DLLEXPORT void call_elewise_div_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { +// HWY_DYNAMIC_DISPATCH(elewise_div_int16)(out, x, y, n); +// } +HWY_DLLEXPORT void call_elewise_add_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_int8)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_sub_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_int8)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_mul_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_int8)(out, x, y, n); +} +// HWY_DLLEXPORT void call_elewise_div_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { +// HWY_DYNAMIC_DISPATCH(elewise_div_int8)(out, x, y, n); +// } -HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { - HWY_DYNAMIC_DISPATCH(elewise_div_scalar_fp32)(out, x, y, n); +//===----------------------------------------------------------------------===// +// Elementwise + - * / By Const +//===----------------------------------------------------------------------===// +HWY_EXPORT(elewise_add_scl_fp32); +HWY_EXPORT(elewise_sub_scl_fp32); +HWY_EXPORT(elewise_mul_scl_fp32); +HWY_EXPORT(elewise_div_scl_fp32); +// HWY_EXPORT(elewise_add_scl_fp16); +// HWY_EXPORT(elewise_sub_scl_fp16); +// HWY_EXPORT(elewise_mul_scl_fp16); +// HWY_EXPORT(elewise_div_scl_fp16); +HWY_EXPORT(elewise_add_scl_int32); +HWY_EXPORT(elewise_sub_scl_int32); +HWY_EXPORT(elewise_mul_scl_int32); +HWY_EXPORT(elewise_div_scl_int32); +HWY_EXPORT(elewise_add_scl_int16); +HWY_EXPORT(elewise_sub_scl_int16); +HWY_EXPORT(elewise_mul_scl_int16); +HWY_EXPORT(elewise_div_scl_int16); +HWY_EXPORT(elewise_add_scl_int8); +HWY_EXPORT(elewise_sub_scl_int8); +HWY_EXPORT(elewise_mul_scl_int8); +HWY_EXPORT(elewise_div_scl_int8); + +HWY_DLLEXPORT void call_elewise_add_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_scl_fp32)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_sub_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_scl_fp32)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_mul_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_scl_fp32)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_div_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_div_scl_fp32)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_add_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_scl_int32)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_sub_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_scl_int32)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_mul_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_scl_int32)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_div_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_div_scl_int32)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_add_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_scl_int16)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_sub_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_scl_int16)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_mul_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_scl_int16)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_div_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_div_scl_int16)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_add_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_scl_int8)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_sub_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_scl_int8)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_mul_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_scl_int8)(out, x, y, n); +} +HWY_DLLEXPORT void call_elewise_div_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_div_scl_int8)(out, x, y, n); } + //===----------------------------------------------------------------------===// // GELU //===----------------------------------------------------------------------===// @@ -252,6 +356,15 @@ HWY_DLLEXPORT void call_fill_random_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t HWY_DYNAMIC_DISPATCH(fill_random_u8)(dst, n, start, end, seed); } +//===----------------------------------------------------------------------===// +// Reduce +//===----------------------------------------------------------------------===// +HWY_EXPORT(reduce_sum_fp32); + +HWY_DLLEXPORT void call_reduce_sum_fp32(mllm_fp32_t* dst, const mllm_fp32_t* src, size_t src_stride, size_t size, int32_t thread_count) { + HWY_DYNAMIC_DISPATCH(reduce_sum_fp32)(dst, src, src_stride, size, thread_count); +} + } // namespace mllm::cpu::common #endif // HWY_ONCE diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp index 4df34db0e..170a74069 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp @@ -18,18 +18,191 @@ namespace mllm::cpu::common { //===----------------------------------------------------------------------===// // Elementwise + - * / By Matrix //===----------------------------------------------------------------------===// +/// @brief Elementwise operations on contiguous buffers: out[i] = x[i] (op) y[i]. +/// @param out Output buffer of length n. +/// @param x Input buffer of length n. +/// @param y Input buffer of length n. +/// @param n Number of elements. +/// @note For integer division, behavior is undefined when a divisor is zero. HWY_DLLEXPORT void call_elewise_add_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n); HWY_DLLEXPORT void call_elewise_sub_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n); HWY_DLLEXPORT void call_elewise_mul_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n); HWY_DLLEXPORT void call_elewise_div_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n); +//TODO: fp16 support not implemented yet +// HWY_DLLEXPORT void call_elewise_add_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); +// HWY_DLLEXPORT void call_elewise_sub_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); +// HWY_DLLEXPORT void call_elewise_mul_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); +// HWY_DLLEXPORT void call_elewise_div_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_add_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_div_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_add_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); +// HWY_DLLEXPORT void call_elewise_div_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_add_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); +// HWY_DLLEXPORT void call_elewise_div_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); //===----------------------------------------------------------------------===// // Elementwise + - * / By Const //===----------------------------------------------------------------------===// -HWY_DLLEXPORT void call_elewise_add_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); -HWY_DLLEXPORT void call_elewise_sub_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); -HWY_DLLEXPORT void call_elewise_mul_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); -HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +/// @brief Elementwise operations with a scalar constant: out[i] = x[i] (op) y. +/// @param out Output buffer of length n. +/// @param x Input buffer of length n. +/// @param y Scalar constant. +/// @param n Number of elements. +/// @note For integer division, behavior is undefined when y == 0. +HWY_DLLEXPORT void call_elewise_add_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_div_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +//TODO: fp16 support not implemented yet +// HWY_DLLEXPORT void call_elewise_add_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); +// HWY_DLLEXPORT void call_elewise_sub_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); +// HWY_DLLEXPORT void call_elewise_mul_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); +// HWY_DLLEXPORT void call_elewise_div_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); +HWY_DLLEXPORT void call_elewise_add_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_div_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_add_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); +HWY_DLLEXPORT void call_elewise_div_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); +HWY_DLLEXPORT void call_elewise_add_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); +HWY_DLLEXPORT void call_elewise_div_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); + +//===----------------------------------------------------------------------===// +// Template wrapper for generic elewise operations +//===----------------------------------------------------------------------===// +template +inline void elewise_add_anytype(T* out, const T* x, const T* y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_add_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] + y[i]; } + } +} + +template +inline void elewise_sub_anytype(T* out, const T* x, const T* y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_sub_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] - y[i]; } + } +} + +template +inline void elewise_mul_anytype(T* out, const T* x, const T* y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_mul_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] * y[i]; } + } +} + +template +inline void elewise_div_anytype(T* out, const T* x, const T* y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_div_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_div_int32(out, x, y, n); + } else { + // Fallback (note: division by zero is undefined) + for (size_t i = 0; i < n; ++i) { out[i] = x[i] / y[i]; } + } +} + +template +inline void elewise_add_scl_anytype(T* out, const T* x, T y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_add_scl_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_scl_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_scl_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_scl_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] + y; } + } +} + +template +inline void elewise_sub_scl_anytype(T* out, const T* x, T y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_sub_scl_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_scl_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_scl_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_scl_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] - y; } + } +} + +template +inline void elewise_mul_scl_anytype(T* out, const T* x, T y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_mul_scl_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_scl_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_scl_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_scl_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] * y; } + } +} + +template +inline void elewise_div_scl_anytype(T* out, const T* x, T y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_div_scl_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_div_scl_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_div_scl_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_div_scl_int8(out, x, y, n); + } else { + // Fallback (note: division by zero is undefined) + for (size_t i = 0; i < n; ++i) { out[i] = x[i] / y; } + } +} //===----------------------------------------------------------------------===// // Fill Zeros @@ -247,6 +420,17 @@ inline void fill_random_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t } } +//===----------------------------------------------------------------------===// +// Reduce +//===----------------------------------------------------------------------===// +/// Sum-reduction over a strided FP32 buffer. +/// @param dst Output buffer receiving the reduction result(s). +/// @param src Input buffer. +/// @param src_stride Stride between consecutive source elements. +/// @param size Number of elements to reduce. +/// @param thread_count Requested number of threads (implementation may clamp). +HWY_DLLEXPORT void call_reduce_sum_fp32(mllm_fp32_t* dst, const mllm_fp32_t* src, size_t src_stride, size_t size, int32_t thread_count); + } // namespace mllm::cpu::common #endif diff --git a/mllm/backends/cpu/kernels/common/reduce-inl.hpp b/mllm/backends/cpu/kernels/common/reduce-inl.hpp index e69de29bb..357c0b0cb 100644 --- a/mllm/backends/cpu/kernels/common/reduce-inl.hpp +++ b/mllm/backends/cpu/kernels/common/reduce-inl.hpp @@ -0,0 +1,123 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include "mllm/core/DataTypes.hpp" + +HWY_BEFORE_NAMESPACE(); +namespace mllm::cpu::common { // NOLINT +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + + +struct ScalarAddOp { template HWY_INLINE T operator()(T a, T b) const { return a + b; } }; + +struct ScalarSubOp { template HWY_INLINE T operator()(T a, T b) const { return a - b; } }; + +struct ScalarMulOp { template HWY_INLINE T operator()(T a, T b) const { return a * b; } }; + +struct ScalarDivOp { template HWY_INLINE T operator()(T a, T b) const { return a / b; } }; + +struct ScalarMaxOp { template HWY_INLINE T operator()(T a, T b) const { return a > b ? a : b; } }; + +struct ScalarMinOp { template HWY_INLINE T operator()(T a, T b) const { return a < b ? a : b; } }; + +struct VecAddOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Add(a, b); } +}; + +struct VecSubOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Sub(a, b); } +}; + +struct VecMulOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Mul(a, b); } +}; + +struct VecDivOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Div(a, b); } +}; + +struct VecMaxOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Max(a, b); } +}; + +struct VecMinOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Min(a, b); } +}; + +struct VecSumReduce { + template + HWY_INLINE hn::TFromD operator()(D d, V v) const { return hn::ReduceSum(d, v); } +}; + + +template +HWY_INLINE T reduce_impl(const T* HWY_RESTRICT src, size_t src_stride, size_t size, + ScalarOp&& scalar_op, VectorOp&& vec_op, VectorReduceOp&& vec_reduce_op) { + if (size == 0) return T(0); + + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + + // SIMD fast path + if (src_stride == 1 && size >= N) { + using V = hn::Vec; + + // Init with first vector + V vec_result = hn::LoadU(d, src); + size_t i = N; + + // 4x unroll + for (; i + 4 * N <= size; i += 4 * N) { + const V v0 = hn::LoadU(d, src + i); + const V v1 = hn::LoadU(d, src + i + N); + const V v2 = hn::LoadU(d, src + i + 2 * N); + const V v3 = hn::LoadU(d, src + i + 3 * N); + + vec_result = vec_op(d, vec_result, v0); + vec_result = vec_op(d, vec_result, v1); + vec_result = vec_op(d, vec_result, v2); + vec_result = vec_op(d, vec_result, v3); + } + + for (; i + N <= size; i += N) { + const V v = hn::LoadU(d, src + i); + vec_result = vec_op(d, vec_result, v); + } + + if (i < size) { + const V vt = hn::LoadN(d, src + i, size - i); + vec_result = vec_op(d, vec_result, vt); + } + + return vec_reduce_op(d, vec_result); + } + + // Scalar path (stride != 1 or too small) + T scalar_result = src[0]; + for (size_t i = 1; i < size; ++i) { + scalar_result = scalar_op(scalar_result, src[i * src_stride]); + } + return scalar_result; + +} + + +HWY_NOINLINE HWY_MAYBE_UNUSED void reduce_sum_fp32(mllm_fp32_t* dst,const mllm_fp32_t* src, +size_t src_stride, size_t size, int32_t thread_count) { + const mllm_fp32_t v = reduce_impl(src, src_stride, size, + ScalarAddOp{}, VecAddOp{}, VecSumReduce{}); + *dst = v; +} + + +} // namespace HWY_NAMESPACE +} // namespace mllm::cpu::common +HWY_AFTER_NAMESPACE(); diff --git a/mllm/backends/cpu/ops/ElewiseOps.cpp b/mllm/backends/cpu/ops/ElewiseOps.cpp index a3e1f7dd1..fd7430948 100644 --- a/mllm/backends/cpu/ops/ElewiseOps.cpp +++ b/mllm/backends/cpu/ops/ElewiseOps.cpp @@ -140,30 +140,30 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o switch (dtype) { case kFloat32: { if (input0.numel() == input1.numel()) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_fp32(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_add_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #else - NYI("AddOp not supported on this architecture."); + NYI("AddOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_scalar_fp32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_add_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #else - NYI("AddOp not supported on this architecture."); + NYI("AddOp not supported on this architecture."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); const float* b = input1.ptr(); float* out = output.ptr(); -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { // Each batch processes broadcast_naive_loops iterations of vector_size elements @@ -172,11 +172,10 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_add_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::arm::ew_add_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } } - -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { // Each batch processes broadcast_naive_loops iterations of vector_size elements @@ -185,11 +184,11 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::arm::ew_add_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); + cpu::common::elewise_add_anytype(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else - NYI("AddOp not supported on this architecture."); + NYI("AddOp not supported on this architecture."); #endif } else { NYI("AddOp broadcast not supported."); @@ -202,11 +201,15 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_add_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("AddOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_add_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("AddOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("AddOp broadcast not supported."); @@ -219,11 +222,17 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_add_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_add_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("AddOp broadcast not supported."); @@ -236,11 +245,17 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_add_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_add_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("AddOp broadcast not supported."); @@ -253,11 +268,16 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_add_anytype(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_add_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("AddOp broadcast not supported."); @@ -281,7 +301,6 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o const float* a = input0.ptr(); const mllm_complex_fp32_t* b = input1.ptr(); mllm_complex_fp32_t* out = output.ptr(); - #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { @@ -323,30 +342,30 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o switch (dtype) { case kFloat32: { if (input0.numel() == input1.numel()) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_fp32(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_sub_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #else - NYI("SubOp not supported on this architecture."); + NYI("SubOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_scalar_fp32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_sub_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #else - NYI("SubOp not supported on this architecture."); + NYI("SubOp not supported on this architecture."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); const float* b = input1.ptr(); float* out = output.ptr(); -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { // Each batch processes broadcast_naive_loops iterations of vector_size elements @@ -355,10 +374,10 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_sub_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::arm::ew_sub_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } } -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { // Each batch processes broadcast_naive_loops iterations of vector_size elements @@ -367,11 +386,11 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::arm::ew_sub_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); + cpu::common::elewise_sub_anytype(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else - NYI("SubOp not supported on this architecture."); + NYI("SubOp not supported on this architecture."); #endif } else { NYI("SubOp broadcast not supported."); @@ -384,11 +403,15 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_sub_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("SubOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_sub_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("SubOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("SubOp broadcast not supported."); @@ -401,11 +424,17 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_sub_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_sub_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("SubOp broadcast not supported."); @@ -418,11 +447,17 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_sub_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_sub_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("SubOp broadcast not supported."); @@ -435,11 +470,16 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_sub_anytype(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_sub_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("SubOp broadcast not supported."); @@ -505,42 +545,42 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o switch (dtype) { case kFloat32: { if (input0.numel() == input1.numel()) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_mul_fp32(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_mul_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #else - NYI("MulOp not supported on this architecture."); + NYI("MulOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_mul_scalar_fp32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_mul_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #else - NYI("MulOp not supported on this architecture."); + NYI("MulOp not supported on this architecture."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); const float* b = input1.ptr(); float* out = output.ptr(); -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { - // Each batch processes broadcast_naive_loops iterations of vector_size elements + // Each batch processes broadcast_naive_loops iterations of vector_size elements for (int l = 0; l < broadcast_naive_loops; ++l) { size_t a_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_mul_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::arm::ew_mul_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } } -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { // Each batch processes broadcast_naive_loops iterations of vector_size elements @@ -549,11 +589,11 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::arm::ew_mul_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); + cpu::common::elewise_mul_anytype(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else - NYI("MulOp not supported on this architecture."); + NYI("MulOp not supported on this architecture."); #endif } else { NYI("MulOp broadcast not supported."); @@ -566,11 +606,15 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_mul_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("MulOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_mul_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("MulOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("MulOp broadcast not supported."); @@ -583,11 +627,17 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_mul_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_mul_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("MulOp broadcast not supported."); @@ -600,11 +650,17 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_mul_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_mul_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("MulOp broadcast not supported."); @@ -617,11 +673,16 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_mul_anytype(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_mul_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("MulOp broadcast not supported."); @@ -687,42 +748,42 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o switch (dtype) { case kFloat32: { if (input0.numel() == input1.numel()) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_div_fp32(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_div_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #else - NYI("DivOp not supported on this architecture."); + NYI("DivOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_div_scalar_fp32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_div_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #else - NYI("DivOp not supported on this architecture."); + NYI("DivOp not supported on this architecture."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); const float* b = input1.ptr(); float* out = output.ptr(); -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { - // Each batch processes broadcast_naive_loops iterations of vector_size elements + // Each batch processes broadcast_naive_loops iterations of vector_size elements for (int l = 0; l < broadcast_naive_loops; ++l) { size_t a_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_div_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::arm::ew_div_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } } -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { // Each batch processes broadcast_naive_loops iterations of vector_size elements @@ -731,11 +792,11 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::arm::ew_div_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); + cpu::common::elewise_div_anytype(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else - NYI("DivOp not supported on this architecture."); + NYI("DivOp not supported on this architecture."); #endif } else { NYI("DivOp broadcast not supported."); @@ -748,11 +809,15 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_div_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_div_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast not supported."); @@ -765,11 +830,17 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_div_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::elewise_div_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("DivOp broadcast not supported."); @@ -782,11 +853,15 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp int16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp int16 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast not supported."); @@ -799,11 +874,15 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp int8 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp int8 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast not supported."); @@ -817,11 +896,15 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_fp32_complex(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp complex fp32 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_fp32_complex_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp complex fp32 not supported on x86 architecture yet."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); @@ -840,6 +923,8 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_div_fp32_complex(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } } +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp complex fp32 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast for complex output not supported."); diff --git a/mllm/backends/cpu/ops/MatMulOp.cpp b/mllm/backends/cpu/ops/MatMulOp.cpp index cc7dddde7..4f4cc0efa 100644 --- a/mllm/backends/cpu/ops/MatMulOp.cpp +++ b/mllm/backends/cpu/ops/MatMulOp.cpp @@ -49,8 +49,8 @@ void CPUMatMulOp::forward(const std::vector& inputs, std::vector #if defined(MLLM_USE_BLAS) mt = aops::MatMulOpType::kBLAS; #else - if (!transpose_a && transpose_b && M >= 4) { - // TODO kGGUF still buggy !!! + if (!transpose_a && transpose_b) { + // TODO: kGGUF still buggy !!! mt = aops::MatMulOpType::kGGUF; } else // All fallback to mllm blas @@ -110,6 +110,18 @@ void CPUMatMulOp::forward(const std::vector& inputs, std::vector transpose_a, transpose_b, thread_count); } } +// #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +// if (lhs.dtype() == kFloat32 && rhs.dtype() == kFloat32 && o.dtype() == kFloat32) { +// if (batch_count == 1) { +// x86::mllm_blas_matmul_fp32(M, K, N, o.ptr(), lhs.ptr(), rhs.ptr(), nullptr, +// transpose_a, transpose_b); +// } else { +// x86::mllm_blas_batch_matmul_fp32(batch_count, M, K, N, o.stride()[o.shape().size() - 3], +// lhs.stride()[lhs_shape.size() - 3], rhs.stride()[rhs_shape.size() - 3], 0, +// o.ptr(), lhs.ptr(), rhs.ptr(), nullptr, +// transpose_a, transpose_b); +// } +// } #else NYI("MllmBlas only support MLLM_HOST_ARCH_ARM64 or MLLM_HOST_ARCH_ARM right now.") #endif diff --git a/mllm/backends/cpu/ops/ReduceOps.cpp b/mllm/backends/cpu/ops/ReduceOps.cpp index a60ae67e2..c42f40709 100644 --- a/mllm/backends/cpu/ops/ReduceOps.cpp +++ b/mllm/backends/cpu/ops/ReduceOps.cpp @@ -294,6 +294,8 @@ void CPUReduceSumOp::forward(const std::vector& inputs, std::vector(), input.ptr(), 1, input.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86) || defined(MLLM_HOST_ARCH_X86_64) + common::call_reduce_sum_fp32(output.ptr(), input.ptr(), 1, input.numel(), options_.getThreads()); #endif break; } @@ -344,6 +346,9 @@ void CPUReduceSumOp::forward(const std::vector& inputs, std::vector>(); + + // Linear implementation type + linear_impl_type = aops::str2LinearImplTypes(data()["linear_impl_type"]); + } + + bool attention_bias = false; + int32_t hidden_size = 2048; + int32_t head_dim = 128; + int32_t intermediate_size = 6144; + int32_t num_attention_heads = 32; + int32_t num_key_value_heads = 4; + int32_t num_hidden_layers = 48; + int32_t max_position_embeddings = 262144; + float rms_norm_eps = 1e-06; + int32_t vocab_size = 151936; + + int64_t bos_token_id = 151643; + int64_t eos_token_id = 151645; + float rope_theta = 1000000.0; + + bool tie_word_embeddings = false; + int32_t max_cache_length = 4096; + int32_t end_of_text_token_id = 151645; // fixed default + + int32_t num_experts = 128; + int32_t num_experts_per_tok = 8; + int32_t moe_intermediate_size = 768; + bool norm_topk_prob = true; + int32_t decoder_sparse_step = 1; + std::vector mlp_only_layers; + + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::qwen3_moe diff --git a/mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp b/mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp new file mode 100644 index 000000000..379db0c79 --- /dev/null +++ b/mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp @@ -0,0 +1,492 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/mllm.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/models/qwen3_moe/configuration_qwen3_moe.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/models/ARGeneration.hpp" + +namespace mllm::models::qwen3_moe { + +inline auto makeRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0 / std::pow(rope_theta, 2.0 * i / output_dim); } + return inv_freq; +} + +inline auto makeRotaryPosEmbedding(Tensor& position_ids, const Tensor& inv_freq, + float attention_scaling = 1.0f) -> std::pair { + auto batch_size = position_ids.shape()[0]; + auto seq_len = position_ids.shape()[1]; + auto inv_freq_len = inv_freq.shape()[0]; + auto dim = inv_freq_len * 2; + + // Create freqs tensor: position_ids @ inv_freq + auto freqs = Tensor::empty({batch_size, seq_len, inv_freq_len}, kFloat32, kCPU).alloc(); + auto freqs_ptr = freqs.ptr(); + auto position_ids_ptr = position_ids.ptr(); + auto inv_freq_ptr = inv_freq.ptr(); + + // Compute freqs = position_ids[:, :, None] @ inv_freq[None, :] + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + auto pos = position_ids_ptr[b * seq_len + s]; + for (int d = 0; d < inv_freq_len; ++d) { + freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d] = static_cast(pos) * inv_freq_ptr[d]; + } + } + } + + // Create sin and cos tensors with shape [batch_size, seq_len, dim] + auto sin_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto cos_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto sin_ptr = sin_emb.ptr(); + auto cos_ptr = cos_emb.ptr(); + + // Compute sin and cos embeddings: emb = [freqs, freqs] + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + for (int d = 0; d < inv_freq_len; ++d) { + auto freq = freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d]; + auto sin_val = std::sin(freq) * attention_scaling; + auto cos_val = std::cos(freq) * attention_scaling; + + // Store the same values in both halves: [freqs, freqs] + sin_ptr[b * seq_len * dim + s * dim + d] = sin_val; + sin_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = sin_val; + cos_ptr[b * seq_len * dim + s * dim + d] = cos_val; + cos_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = cos_val; + } + } + } + + return {sin_emb, cos_emb}; +} + +class Qwen3MoeMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU act_; + + int hidden_size_; + int intermediate_size_; + + public: + Qwen3MoeMLP() = default; + + explicit Qwen3MoeMLP(const std::string& name, const Qwen3MoeConfig& config, + const std::optional& hidden_size = std::nullopt, + const std::optional& intermediate_size = std::nullopt) + : nn::Module(name) { + hidden_size_ = hidden_size.value_or(config.hidden_size); + intermediate_size_ = intermediate_size.value_or(config.intermediate_size); + + // clang-format off + gate_proj_ = reg("gate_proj", hidden_size_, intermediate_size_, false, config.linear_impl_type); + up_proj_ = reg("up_proj", hidden_size_, intermediate_size_, false, config.linear_impl_type); + down_proj_ = reg("down_proj", intermediate_size_, hidden_size_, false, config.linear_impl_type); + act_ = reg("act"); + // clang-format on + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {down_proj_(act_(gate_proj_(inputs[0])) * up_proj_(inputs[0]))}; + } +}; + +class MoEGate final : public nn::Module { + int top_k_; + int num_experts_; + bool norm_topk_prob_; + + nn::Param weight_; + + public: + MoEGate() = default; + + MoEGate(const std::string& name, const Qwen3MoeConfig& config) : nn::Module(name) { + top_k_ = config.num_experts_per_tok; + num_experts_ = config.num_experts; + norm_topk_prob_ = config.norm_topk_prob; + + weight_ = reg("weight", getModuleName() + ".weight"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto bsz = hidden_states.size(0); + auto seq_len = hidden_states.size(1); + auto h = hidden_states.size(2); + + // Compute gating score + hidden_states = hidden_states.view({-1, h}); + // hidden_states and weight must in fp32 to keep precision !!! + auto logits = nn::functional::matmul(hidden_states, weight_.weight(), false, true); + auto scores = nn::functional::softmax(logits, -1); + auto [topk_weight, topk_idx] = nn::functional::topk(scores, top_k_, -1, true, false); + + if(norm_topk_prob_){ + topk_weight = topk_weight / topk_weight.sum(-1, true); + } + + return {topk_idx, topk_weight}; + } +}; + +class Qwen3MoE final : public nn::Module { + int num_experts_per_tok_; + nn::ModuleList experts_; + MoEGate gate_; + + public: + Qwen3MoE() = default; + + Qwen3MoE(const std::string& name, const Qwen3MoeConfig& config) : nn::Module(name) { + num_experts_per_tok_ = config.num_experts_per_tok; + // Init experts + experts_ = reg>("experts", config.num_experts, config, std::nullopt, + config.moe_intermediate_size); + gate_ = reg("gate", config); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto identity = hidden_states; + auto orig_shape = hidden_states.shape(); + auto topk_idx = Tensor::nil(); + auto topk_weight = Tensor::nil(); + auto gated_ret = gate_(hidden_states); + topk_idx = gated_ret[0]; + topk_weight = gated_ret[1]; + hidden_states = hidden_states.view({-1, hidden_states.size(-1)}); + + auto y = moeInfer(hidden_states, topk_idx, topk_weight).view(orig_shape); + + return {y}; + } + + private: + Tensor moeInfer(const Tensor& x, Tensor& topk_ids, Tensor& topk_weights) { + // x shape is [batch_size * seq, hidden_dim] + + auto cnts = Tensor::zeros({topk_ids.size(0), (int32_t)experts_.list().size()}); + // Do scatter_ operation + { + const int32_t* idx_ptr = topk_ids.ptr(); + float* cnt_ptr = cnts.ptr(); + const int batch = topk_ids.size(0); + const int k = topk_ids.size(1); + const int n_exp = cnts.size(1); + for (int b = 0; b < batch; ++b) { + for (int j = 0; j < k; ++j) { + int32_t e = idx_ptr[b * k + j]; + MLLM_RT_ASSERT(e >= 0 && e < n_exp); + cnt_ptr[b * n_exp + e] += 1.f; // +1 + } + } + } + auto tokens_per_expert = cnts.sum(0); + auto idxs = topk_ids.view({-1}).argsort(); + + auto sorted_tokens = x[{idxs / topk_ids.size(1), {kAll}}]; + + std::vector outputs; + int start_idx = 0; + + // tokens_per_expert shape is [num_experts] + // Loop through each expert + for (int i = 0; i < experts_.list().size(); ++i) { + auto num_tokens = tokens_per_expert.ptr()[i]; + auto end_idx = start_idx + (int32_t)num_tokens; + if (num_tokens == 0) { continue; } + auto& expert = experts_.list()[i]; + auto tokens_for_this_expert = sorted_tokens[{{start_idx, end_idx}, kAll}]; + auto expert_out = expert(tokens_for_this_expert)[0]; + outputs.push_back(expert_out); + start_idx = end_idx; + } + + auto outs = nn::functional::concat(outputs, 0); + auto new_x = Tensor::emptyLike(outs).alloc(); + + // indexed_write + // python logic: new_x[idxs] = outs + { + const int32_t* idx_ptr = idxs.ptr(); + float* outs_ptr = outs.ptr(); + float* new_x_ptr = new_x.ptr(); + MLLM_RT_ASSERT_EQ(new_x.rank(), 2); + MLLM_RT_ASSERT_EQ(new_x.size(0), idxs.size(0)); + auto dim = new_x.size(1); + for (int i = 0; i < idxs.size(0); ++i) { + int32_t idx = idx_ptr[i]; + std::memcpy(new_x_ptr + idx * dim, outs_ptr + i * dim, dim * sizeof(float)); + } + } + + auto final_out_shape = topk_ids.shape(); + final_out_shape.emplace_back(-1); + auto final_out = + new_x.view(final_out_shape).to(topk_weights.dtype()).mul_(topk_weights.unsqueeze(-1)).sum(1).to(new_x.dtype()); + return final_out; + } +}; + +class Qwen3MoeAttention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::RMSNorm rms_norm_q_; + nn::RMSNorm rms_norm_k_; + nn::RoPE q_rope_; + nn::RoPE k_rope_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen3MoeAttention() = default; + + Qwen3MoeAttention(const std::string& name, const Qwen3MoeConfig& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + // clang-format off + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type).redirect(); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type).redirect(); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, cfg.linear_impl_type); + // clang-format on + + rms_norm_q_ = reg("q_norm", cfg.rms_norm_eps).inplace(); + rms_norm_k_ = reg("k_norm", cfg.rms_norm_eps).inplace(); + + // clang-format off + q_rope_ = reg("q_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD).inplace(); + k_rope_ = reg("k_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD).inplace(); + // clang-format on + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + // Get KV cache for Key and Value first. + // [B, S, H * D] + auto [key_states_redirect, value_states_redirect] = past_kv_cache->preGetKVWriteLocation(layer_idx_, S); + + // [B, S, H * D] + auto query_states = q_proj_(x); + auto key_states = k_proj_(x, key_states_redirect); + auto value_states = v_proj_(x, value_states_redirect); + + // [B, S, H, D] + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + + // [B, S, H, D] + query_states = rms_norm_q_(query_states); + key_states = rms_norm_k_(key_states); + + // [B, S, H, D] + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + // Get KV + auto [K, V] = past_kv_cache->getKVCache(layer_idx_); + + // [B, S, H, D] FA2 + auto output = o_proj_(nn::functional::flashAttention2(query_states, K, V).view({B, S, num_attention_heads_ * head_dim_})); + + return {output}; + } + + int layer_idx_; +}; + +class Qwen3MoeDecoder final : public nn::Module { + Qwen3MoeAttention self_attn_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + std::optional mlp_opt0_ = std::nullopt; + std::optional mlp_opt1_ = std::nullopt; + + public: + int layer_idx_; + + Qwen3MoeDecoder() = default; + + Qwen3MoeDecoder(const std::string& name, const Qwen3MoeConfig& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + + self_attn_ = reg("self_attn", cfg); + self_attn_.layer_idx_ = layer_idx; + + MLLM_RT_ASSERT(cfg.decoder_sparse_step > 0); + bool is_mlp_only = std::find(cfg.mlp_only_layers.begin(), cfg.mlp_only_layers.end(), layer_idx) != cfg.mlp_only_layers.end(); + if ((!is_mlp_only) && (cfg.num_experts > 0 && (layer_idx_+1) % cfg.decoder_sparse_step == 0)) { + mlp_opt0_ = reg("mlp", cfg); + } else { + mlp_opt1_ = reg("mlp", cfg); + } + + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + if(mlp_opt0_){ + x = mlp_opt0_.value()(x)[0]; + } else { + x = mlp_opt1_.value()(x)[0]; + } + x = x + tmp; + return {x}; + } +}; + +class Qwen3MoeText final : public nn::Module { + nn::Embedding embedding_; + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + + public: + Qwen3MoeText() = default; + + explicit Qwen3MoeText(const std::string& name, const Qwen3MoeConfig& cfg) : nn::Module(name) { + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + norm_ = reg("norm", cfg.rms_norm_eps); + + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + + x = norm_(x); + + return {x}; + } +}; + +class Qwen3MoeForCausalLM : public ARGeneration, public nn::Module { + public: + explicit Qwen3MoeForCausalLM(const Qwen3MoeConfig& cfg) : cfg(cfg) { + kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, + cfg.num_attention_heads, // q_heads + cfg.num_key_value_heads, // kv_heads + cfg.head_dim, // kv_dim + kFloat32, // k_dtype + kFloat32, // v_dtype + kCPU, // device_type + true // use_fa2 + ); + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + } + + // Init inv freq + auto inv = makeRoPEInvFreq(cfg.head_dim, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + auto sequence = input.at("sequence"); + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({0, position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({batch_size, 1}, kInt64, kCPU).alloc(); + *position_ids.offsettedPtr({0, 0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[b * seq_len + s] = s; } + } + } + + // Generate RoPE embeddings using the inv_freq buffer + auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, getBuffer("inv_freq"), 1.0f); + + sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + + // clip x to one seq length + { + auto S = sequence.shape()[1]; + sequence = sequence[{kAll, {S - 1}, kAll}]; + } + if (tie_word_embeddings_) { sequence = lm_head_(sequence); } + + return { + {"sequence", sequence}, + {"position_ids", position_ids}, + }; + } + + inline nn::StaticCache& kvCache() { return kv_cache_; } + + private: + const Qwen3MoeConfig& cfg; + Qwen3MoeText llm; + nn::Linear lm_head_; + bool tie_word_embeddings_; + nn::StaticCache kv_cache_; +}; + +} // namespace mllm::models::qwen3_moe diff --git a/mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp b/mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp new file mode 100644 index 000000000..181d576f7 --- /dev/null +++ b/mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp @@ -0,0 +1,269 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include + +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" + +namespace mllm::models::qwen3_moe { + +// we need to handle this: +// +// (?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| +// ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+ +inline bool qwen3TokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + // 1. Match contractions: "'s|'t|'re|'ve|'m|'ll|'d" + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + // 2. Match [^\r\n\p{L}\p{N}]?\p{L}+ (non-letter/digit followed by letters) + { + size_t original_pos = pos; + bool has_prefix = false; + matched.clear(); + + // Check optional non-letter/digit prefix (excluding \r\n) + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + has_prefix = true; + } + + // Require at least one letter + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } else { + // Rollback if no letters after prefix + if (has_prefix) { + pos = original_pos; + matched.clear(); + } + } + } + + // 3. Match \p{N} (digits) + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + // 4. Match ?[^\s\p{L}\p{N}]+[\r\n]* (punctuation/symbols with optional space prefix) + { + size_t original_pos = pos; + matched.clear(); + size_t start = pos; + + // Optional space + if (str[pos] == L' ') { ++pos; } + + // Require at least one non-letter/digit/whitespace + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && !preprocessor::isDigit(str[pos])); + + // Capture from start (after optional space) to current pos + matched = str.substr(start, pos - start); + + // Capture trailing newlines + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } else { + // Rollback if no symbols found + pos = original_pos; + } + } + + // 5. Match \s*[\r\n]+ (newlines with leading whitespace) + { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos; + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + // 6. Match \s+(?!\S) (whitespace not followed by non-space) + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + // Check if at end or followed by whitespace + if (pos >= str.size() || std::iswspace(str[pos])) { + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + // 7. Match remaining whitespace + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline bool qwen3Regex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (qwen3TokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } + return true; +} + +struct Qwen3Message { + std::string prompt; + static inline std::string message_template = + "<|im_start|>user\n{{{prompt}}}<|im_end|>\n<|im_start|>assistant\n"; +}; + +class Qwen3Tokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit Qwen3Tokenizer(const std::string& file_path) { + preprocessor::initLocal(); + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + bpe_.initFromSentencePieceJson(file_path); + special_tokens_trie_.add(L"<|endoftext|>"); + special_tokens_trie_.add(L"<|im_start|>"); + special_tokens_trie_.add(L"<|im_end|>"); + special_tokens_trie_.add(L"<|object_ref_start|>"); + special_tokens_trie_.add(L"<|object_ref_end|>"); + special_tokens_trie_.add(L"<|box_start|>"); + special_tokens_trie_.add(L"<|box_end|>"); + special_tokens_trie_.add(L"<|quad_start|>"); + special_tokens_trie_.add(L"<|quad_end|>"); + special_tokens_trie_.add(L"<|vision_start|>"); + special_tokens_trie_.add(L"<|vision_end|>"); + special_tokens_trie_.add(L"<|vision_pad|>"); + special_tokens_trie_.add(L"<|image_pad|>"); + special_tokens_trie_.add(L"<|video_pad|>"); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + ::mllm::models::qwen3_moe::qwen3Regex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + + auto bpe_ts = bpe_._bpe(mapped_str); + + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + } + + return ret; + } + + std::vector tokenize(const std::string& str) override { + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str)); + std::vector all_tokens; + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); } + return {mllm::preprocessor::utf8string2WideString(utf_8_str)}; + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("qwen2-tokenizer-i0") + .alloc(); + + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return ret; + } + + ARGenerationOutputPast convertMessage(const Qwen3Message& message) { + // process prompt + auto applied_string = Qwen3Message::message_template; + size_t pos = applied_string.find("{{{prompt}}}"); + applied_string.replace(pos, 12, message.prompt); + + // process sequence + auto sequence_str = tokenize(applied_string); + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + // Get sequence Tensor + Tensor sequence = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return { + {"sequence", sequence}, + }; + } + + private: + // For text + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; +}; + +} // namespace mllm::models::qwen3_moe