From 5c5c0b0eb4cbd3d8f5148cf548243087f66785ee Mon Sep 17 00:00:00 2001 From: HayzelHan Date: Wed, 7 Jan 2026 06:00:22 +0000 Subject: [PATCH 1/5] feat(qwen3-moe): add support for Qwen3 MoE --- examples/CMakeLists.txt | 1 + examples/qwen3_moe/CMakeLists.txt | 3 + examples/qwen3_moe/config_30B_A3B_gguf.json | 37 ++ examples/qwen3_moe/main.cpp | 74 +++ examples/qwen3_moe/quant_cfg_30B_q4_k.json | 79 +++ .../qwen3_moe/configuration_qwen3_moe.hpp | 75 +++ .../qwen3_moe/modeling_qwen3_moe_fa2.hpp | 492 ++++++++++++++++++ .../qwen3_moe/tokenization_qwen3_moe.hpp | 269 ++++++++++ 8 files changed, 1030 insertions(+) create mode 100644 examples/qwen3_moe/CMakeLists.txt create mode 100644 examples/qwen3_moe/config_30B_A3B_gguf.json create mode 100644 examples/qwen3_moe/main.cpp create mode 100644 examples/qwen3_moe/quant_cfg_30B_q4_k.json create mode 100644 mllm/models/qwen3_moe/configuration_qwen3_moe.hpp create mode 100644 mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp create mode 100644 mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 31bd8e1b..3df37bdd 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -7,6 +7,7 @@ add_subdirectory(minicpm_o) add_subdirectory(minicpm4) add_subdirectory(qwen3) add_subdirectory(qwen3_service) +add_subdirectory(qwen3_moe) add_subdirectory(deepseek_ocr) if(MLLM_BUILD_QNN_BACKEND) diff --git a/examples/qwen3_moe/CMakeLists.txt b/examples/qwen3_moe/CMakeLists.txt new file mode 100644 index 00000000..d20fa815 --- /dev/null +++ b/examples/qwen3_moe/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(mllm-qwen3-moe-runner main.cpp) +target_link_libraries(mllm-qwen3-moe-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen3-moe-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen3_moe/config_30B_A3B_gguf.json b/examples/qwen3_moe/config_30B_A3B_gguf.json new file mode 100644 index 00000000..0ae3fd17 --- /dev/null +++ 
b/examples/qwen3_moe/config_30B_A3B_gguf.json @@ -0,0 +1,37 @@ +{ + "architectures": [ + "Qwen3MoeForCausalLM" + ], + "attention_bias": false, + "bos_token_id": 151643, + "decoder_sparse_step": 1, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "max_position_embeddings": 262144, + "max_window_layers": 48, + "mlp_only_layers": [], + "model_type": "qwen3_moe", + "moe_intermediate_size": 768, + "norm_topk_prob": true, + "num_attention_heads": 32, + "num_experts": 128, + "num_experts_per_tok": 8, + "num_hidden_layers": 48, + "num_key_value_heads": 4, + "output_router_logits": false, + "rms_norm_eps": 1e-06, + "rope_scaling": 1.0, + "rope_theta": 10000000, + "router_aux_loss_coef": 0.001, + "tie_word_embeddings": true, + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "max_cache_length": 16384, + "linear_impl_type": "Default" +} diff --git a/examples/qwen3_moe/main.cpp b/examples/qwen3_moe/main.cpp new file mode 100644 index 00000000..1fb01fd4 --- /dev/null +++ b/examples/qwen3_moe/main.cpp @@ -0,0 +1,74 @@ +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + +#ifdef MLLM_PERFETTO_ENABLE + mllm::perf::start(); +#endif + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else 
if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen3_moe_cfg = mllm::models::qwen3_moe::Qwen3MoeConfig(config_path.get()); + auto qwen3_moe_tokenizer = mllm::models::qwen3_moe::Qwen3Tokenizer(tokenizer_path.get()); + auto qwen3_moe = mllm::models::qwen3_moe::Qwen3MoeForCausalLM(qwen3_moe_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen3_moe.load(param); + + fmt::print("\n{:*^60}\n", " Qwen3 MoE Interactive CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string prompt_text; + + fmt::print("šŸ’¬ Prompt text (or 'exit/quit'): "); + std::getline(std::cin, prompt_text); + + try { + fmt::print("šŸ”„ Processing...\n"); + auto inputs = qwen3_moe_tokenizer.convertMessage({.prompt = prompt_text}); + + fmt::print("\nšŸ¤– Response: "); + + // Use for loop + for (auto& step : qwen3_moe.chat(inputs)) { std::wcout << qwen3_moe_tokenizer.detokenize(step.cur_token_id) << std::flush; } + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nāŒ Error: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen3_moe.perfSummary(); + } + +#ifdef MLLM_PERFETTO_ENABLE + mllm::perf::stop(); + mllm::perf::saveReport("qwen3_moe.perf"); +#endif + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen3_moe/quant_cfg_30B_q4_k.json b/examples/qwen3_moe/quant_cfg_30B_q4_k.json new file mode 100644 index 00000000..f93829ab --- /dev/null +++ b/examples/qwen3_moe/quant_cfg_30B_q4_k.json @@ -0,0 +1,79 @@ +{ + "^model\\.layers\\.\\d+\\.self_attn\\.q_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 4096, + 2048 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.self_attn\\.k_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 512, + 2048 + 
], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.self_attn\\.v_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q6_K", + "shape": [ + 512, + 2048 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.self_attn\\.o_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 2048, + 4096 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.up_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 768, + 2048 + ], + "replace": true + } + }, + "^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.down_proj.(bias|weight)": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q6_K", + "shape": [ + 2048, + 768 + ], + "replace": true + } + }, + "^lm_head.weight": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 151936, + 2048 + ], + "replace": true + } + } +} diff --git a/mllm/models/qwen3_moe/configuration_qwen3_moe.hpp b/mllm/models/qwen3_moe/configuration_qwen3_moe.hpp new file mode 100644 index 00000000..7dd92c49 --- /dev/null +++ b/mllm/models/qwen3_moe/configuration_qwen3_moe.hpp @@ -0,0 +1,75 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::qwen3_moe { + +struct Qwen3MoeConfig : protected ConfigFile { + Qwen3MoeConfig() = default; + + explicit Qwen3MoeConfig(const std::string& file_path) : ConfigFile(file_path) { + // Init all + attention_bias = data()["attention_bias"]; + hidden_size = data()["hidden_size"]; + intermediate_size = data()["intermediate_size"]; + num_attention_heads = data()["num_attention_heads"]; + num_key_value_heads = data()["num_key_value_heads"]; + num_hidden_layers = data()["num_hidden_layers"]; + max_position_embeddings = data()["max_position_embeddings"]; + rms_norm_eps = data()["rms_norm_eps"]; + vocab_size = data()["vocab_size"]; + head_dim = data()["head_dim"]; + + bos_token_id = data()["bos_token_id"]; + eos_token_id = data()["eos_token_id"]; + rope_theta = data()["rope_theta"]; + + tie_word_embeddings = data()["tie_word_embeddings"]; + max_cache_length = data()["max_cache_length"]; + + // MoE config + num_experts = data()["num_experts"]; + num_experts_per_tok = data()["num_experts_per_tok"]; + moe_intermediate_size = data()["moe_intermediate_size"]; + norm_topk_prob = data()["norm_topk_prob"]; + decoder_sparse_step = data()["decoder_sparse_step"]; + mlp_only_layers = data()["mlp_only_layers"].get>(); + + linear_impl_type = aops::str2LinearImplTypes(data()["linear_impl_type"]); + } + + bool attention_bias = false; + int32_t hidden_size = 2048; + int32_t head_dim = 128; + int32_t intermediate_size = 6144; + int32_t num_attention_heads = 32; + int32_t num_key_value_heads = 4; + int32_t num_hidden_layers = 48; + int32_t max_position_embeddings = 262144; + float rms_norm_eps = 1e-06; + int32_t vocab_size = 151936; + + int64_t bos_token_id = 151643; + int64_t eos_token_id = 151645; + float rope_theta = 1000000.0; + + bool tie_word_embeddings = false; + int32_t max_cache_length = 4096; + int32_t end_of_text_token_id = 151645; + int32_t thinking_start_token_id 
= 151667; + int32_t thinking_end_token_id = 151668; + + int32_t num_experts = 128; + int32_t num_experts_per_tok = 8; + int32_t moe_intermediate_size = 768; + bool norm_topk_prob = true; + int32_t decoder_sparse_step = 1; + std::vector mlp_only_layers; + + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::qwen3_moe diff --git a/mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp b/mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp new file mode 100644 index 00000000..83a569df --- /dev/null +++ b/mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp @@ -0,0 +1,492 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/mllm.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/models/qwen3_moe/configuration_qwen3_moe.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/models/ARGeneration.hpp" + +namespace mllm::models::qwen3_moe { + +inline auto makeRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0 / std::pow(rope_theta, 2.0 * i / output_dim); } + return inv_freq; +} + +inline auto makeRotaryPosEmbedding(Tensor& position_ids, const Tensor& inv_freq, + float attention_scaling = 1.0f) -> std::pair { + auto batch_size = position_ids.shape()[0]; + auto seq_len = position_ids.shape()[1]; + auto inv_freq_len = inv_freq.shape()[0]; + auto dim = inv_freq_len * 2; + + // Create freqs tensor: position_ids @ inv_freq + auto freqs = Tensor::empty({batch_size, seq_len, inv_freq_len}, kFloat32, kCPU).alloc(); + auto freqs_ptr = freqs.ptr(); + auto position_ids_ptr = position_ids.ptr(); + auto inv_freq_ptr = inv_freq.ptr(); + + // Compute freqs = position_ids[:, :, None] @ inv_freq[None, :] + for (int b = 0; b 
< batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + auto pos = position_ids_ptr[b * seq_len + s]; + for (int d = 0; d < inv_freq_len; ++d) { + freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d] = static_cast(pos) * inv_freq_ptr[d]; + } + } + } + + // Create sin and cos tensors with shape [batch_size, seq_len, dim] + auto sin_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto cos_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto sin_ptr = sin_emb.ptr(); + auto cos_ptr = cos_emb.ptr(); + + // Compute sin and cos embeddings: emb = [freqs, freqs] + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + for (int d = 0; d < inv_freq_len; ++d) { + auto freq = freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d]; + auto sin_val = std::sin(freq) * attention_scaling; + auto cos_val = std::cos(freq) * attention_scaling; + + // Store the same values in both halves: [freqs, freqs] + sin_ptr[b * seq_len * dim + s * dim + d] = sin_val; + sin_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = sin_val; + cos_ptr[b * seq_len * dim + s * dim + d] = cos_val; + cos_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = cos_val; + } + } + } + + return {sin_emb, cos_emb}; +} + +class Qwen3MoeMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU act_; + + int hidden_size_; + int intermediate_size_; + + public: + Qwen3MoeMLP() = default; + + explicit Qwen3MoeMLP(const std::string& name, const Qwen3MoeConfig& config, + const std::optional& hidden_size = std::nullopt, + const std::optional& intermediate_size = std::nullopt) + : nn::Module(name) { + hidden_size_ = hidden_size.value_or(config.hidden_size); + intermediate_size_ = intermediate_size.value_or(config.intermediate_size); + + // clang-format off + gate_proj_ = reg("gate_proj", hidden_size_, intermediate_size_, false, config.linear_impl_type); + up_proj_ = 
reg("up_proj", hidden_size_, intermediate_size_, false, config.linear_impl_type); + down_proj_ = reg("down_proj", intermediate_size_, hidden_size_, false, config.linear_impl_type); + act_ = reg("act"); + // clang-format on + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {down_proj_(act_(gate_proj_(inputs[0])) * up_proj_(inputs[0]))}; + } +}; + +class MoEGate final : public nn::Module { + int top_k_; + int num_experts_; + bool norm_topk_prob_; + + nn::Param weight_; + + public: + MoEGate() = default; + + MoEGate(const std::string& name, const Qwen3MoeConfig& config) : nn::Module(name) { + top_k_ = config.num_experts_per_tok; + num_experts_ = config.num_experts; + norm_topk_prob_ = config.norm_topk_prob; + + weight_ = reg("weight", getModuleName() + ".weight"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto bsz = hidden_states.size(0); + auto seq_len = hidden_states.size(1); + auto h = hidden_states.size(2); + + // Compute gating score + hidden_states = hidden_states.view({-1, h}); + // hidden_states and weight must in fp32 to keep precision !!! 
+ auto logits = nn::functional::matmul(hidden_states, weight_.weight(), false, true); + auto scores = nn::functional::softmax(logits, -1); + auto [topk_weight, topk_idx] = nn::functional::topk(scores, top_k_, -1, true, false); + + if(norm_topk_prob_){ + topk_weight = topk_weight / topk_weight.sum(-1, true); + } + + return {topk_idx, topk_weight}; + } +}; + +class Qwen3MoE final : public nn::Module { + int num_experts_per_tok_; + nn::ModuleList experts_; + MoEGate gate_; + + public: + Qwen3MoE() = default; + + Qwen3MoE(const std::string& name, const Qwen3MoeConfig& config) : nn::Module(name) { + num_experts_per_tok_ = config.num_experts_per_tok; + // Init experts + experts_ = reg>("experts", config.num_experts, config, std::nullopt, + config.moe_intermediate_size); + gate_ = reg("gate", config); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto identity = hidden_states; + auto orig_shape = hidden_states.shape(); + auto topk_idx = Tensor::nil(); + auto topk_weight = Tensor::nil(); + auto gated_ret = gate_(hidden_states); + topk_idx = gated_ret[0]; + topk_weight = gated_ret[1]; + hidden_states = hidden_states.view({-1, hidden_states.size(-1)}); + + auto y = moeInfer(hidden_states, topk_idx, topk_weight).view(orig_shape); + + return {y}; + } + + private: + Tensor moeInfer(const Tensor& x, Tensor& topk_ids, Tensor& topk_weights) { + // x shape is [batch_size * seq, hidden_dim] + + auto cnts = Tensor::zeros({topk_ids.size(0), (int32_t)experts_.list().size()}); + // Do scatter_ operation + { + const int32_t* idx_ptr = topk_ids.ptr(); + float* cnt_ptr = cnts.ptr(); + const int batch = topk_ids.size(0); + const int k = topk_ids.size(1); + const int n_exp = cnts.size(1); + for (int b = 0; b < batch; ++b) { + for (int j = 0; j < k; ++j) { + int32_t e = idx_ptr[b * k + j]; + MLLM_RT_ASSERT(e >= 0 && e < n_exp); + cnt_ptr[b * n_exp + e] += 1.f; // +1 + } + } + } + auto tokens_per_expert = 
cnts.sum(0); + auto idxs = topk_ids.view({-1}).argsort(); + + // TODO this line maybe error + auto sorted_tokens = x[{idxs / topk_ids.size(1), {kAll}}]; + + std::vector outputs; + int start_idx = 0; + + // tokens_per_expert shape is [num_experts] + // Loop through each expert + for (int i = 0; i < experts_.list().size(); ++i) { + auto num_tokens = tokens_per_expert.ptr()[i]; + auto end_idx = start_idx + (int32_t)num_tokens; + if (num_tokens == 0) { continue; } + auto& expert = experts_.list()[i]; + auto tokens_for_this_expert = sorted_tokens[{{start_idx, end_idx}, kAll}]; + auto expert_out = expert(tokens_for_this_expert)[0]; + outputs.push_back(expert_out); + start_idx = end_idx; + } + + auto outs = nn::functional::concat(outputs, 0); + auto new_x = Tensor::emptyLike(outs).alloc(); + + // indexed_write + // python logic: new_x[idxs] = outs + { + const int32_t* idx_ptr = idxs.ptr(); + float* outs_ptr = outs.ptr(); + float* new_x_ptr = new_x.ptr(); + MLLM_RT_ASSERT_EQ(new_x.rank(), 2); + MLLM_RT_ASSERT_EQ(new_x.size(0), idxs.size(0)); + auto dim = new_x.size(1); + for (int i = 0; i < idxs.size(0); ++i) { + int32_t idx = idx_ptr[i]; + std::memcpy(new_x_ptr + idx * dim, outs_ptr + i * dim, dim * sizeof(float)); + } + } + + auto final_out_shape = topk_ids.shape(); + final_out_shape.emplace_back(-1); + auto final_out = + new_x.view(final_out_shape).to(topk_weights.dtype()).mul_(topk_weights.unsqueeze(-1)).sum(1).to(new_x.dtype()); + return final_out; + } +}; + +class Qwen3MoeAttention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::RMSNorm rms_norm_q_; + nn::RMSNorm rms_norm_k_; + nn::RoPE q_rope_; + nn::RoPE k_rope_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen3MoeAttention() = default; + + Qwen3MoeAttention(const std::string& name, const Qwen3MoeConfig& cfg) : nn::Module(name) { + hidden_size_ = 
cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + // clang-format off + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type).redirect(); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type).redirect(); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, cfg.linear_impl_type); + // clang-format on + + rms_norm_q_ = reg("q_norm", cfg.rms_norm_eps).inplace(); + rms_norm_k_ = reg("k_norm", cfg.rms_norm_eps).inplace(); + + // clang-format off + q_rope_ = reg("q_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD).inplace(); + k_rope_ = reg("k_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD).inplace(); + // clang-format on + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + // Get KV cache for Key and Value first. 
+ // [B, S, H * D] + auto [key_states_redirect, value_states_redirect] = past_kv_cache->preGetKVWriteLocation(layer_idx_, S); + + // [B, S, H * D] + auto query_states = q_proj_(x); + auto key_states = k_proj_(x, key_states_redirect); + auto value_states = v_proj_(x, value_states_redirect); + + // [B, S, H, D] + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + + // [B, S, H, D] + query_states = rms_norm_q_(query_states); + key_states = rms_norm_k_(key_states); + + // [B, S, H, D] + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + // Get KV + auto [K, V] = past_kv_cache->getKVCache(layer_idx_); + + // [B, S, H, D] FA2 + auto output = o_proj_(nn::functional::flashAttention2(query_states, K, V).view({B, S, num_attention_heads_ * head_dim_})); + + return {output}; + } + + int layer_idx_; +}; + +class Qwen3MoeDecoder final : public nn::Module { + Qwen3MoeAttention self_attn_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + std::optional mlp_opt0_ = std::nullopt; + std::optional mlp_opt1_ = std::nullopt; + + public: + int layer_idx_; + + Qwen3MoeDecoder() = default; + + Qwen3MoeDecoder(const std::string& name, const Qwen3MoeConfig& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + + self_attn_ = reg("self_attn", cfg); + self_attn_.layer_idx_ = layer_idx; + + bool is_mlp_only = std::find(cfg.mlp_only_layers.begin(), cfg.mlp_only_layers.end(), layer_idx) != cfg.mlp_only_layers.end(); + if ((!is_mlp_only) && (cfg.num_experts > 0 && (layer_idx_+1) % cfg.decoder_sparse_step == 0)) { + mlp_opt0_ = reg("mlp", cfg); + } else { + mlp_opt1_ = reg("mlp", cfg); + } + + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector 
forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + if(mlp_opt0_){ + x = mlp_opt0_.value()(x)[0]; + } else { + x = mlp_opt1_.value()(x)[0]; + } + x = x + tmp; + return {x}; + } +}; + +class Qwen3MoeText final : public nn::Module { + nn::Embedding embedding_; + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + + public: + Qwen3MoeText() = default; + + explicit Qwen3MoeText(const std::string& name, const Qwen3MoeConfig& cfg) : nn::Module(name) { + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + norm_ = reg("norm", cfg.rms_norm_eps); + + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + + x = norm_(x); + + return {x}; + } +}; + +class Qwen3MoeForCausalLM : public ARGeneration, public nn::Module { + public: + explicit Qwen3MoeForCausalLM(const Qwen3MoeConfig& cfg) : cfg(cfg) { + kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, + cfg.num_attention_heads, // q_heads + cfg.num_key_value_heads, // kv_heads + cfg.head_dim, // kv_dim + kFloat32, // k_dtype + kFloat32, // v_dtype + kCPU, // device_type + true // use_fa2 + ); + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) 
{ + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + } + + // Init inv freq + auto inv = makeRoPEInvFreq(cfg.head_dim, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + auto sequence = input.at("sequence"); + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({0, position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({batch_size, 1}, kInt64, kCPU).alloc(); + *position_ids.offsettedPtr({0, 0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[b * seq_len + s] = s; } + } + } + + // Generate RoPE embeddings using the inv_freq buffer + auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, getBuffer("inv_freq"), 1.0f); + + sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + + // clip x to one seq length + { + auto S = sequence.shape()[1]; + sequence = sequence[{kAll, {S - 1}, kAll}]; + } + if (tie_word_embeddings_) { sequence = lm_head_(sequence); } + + return { + {"sequence", sequence}, + {"position_ids", position_ids}, + }; + } + + inline nn::StaticCache& kvCache() { return kv_cache_; } + + private: + const Qwen3MoeConfig& cfg; + Qwen3MoeText 
llm; + nn::Linear lm_head_; + bool tie_word_embeddings_; + nn::StaticCache kv_cache_; +}; + +} // namespace mllm::models::qwen3_moe diff --git a/mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp b/mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp new file mode 100644 index 00000000..181d576f --- /dev/null +++ b/mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp @@ -0,0 +1,269 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include + +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" + +namespace mllm::models::qwen3_moe { + +// we need to handle this: +// +// (?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| +// ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+ +inline bool qwen3TokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + // 1. Match contractions: "'s|'t|'re|'ve|'m|'ll|'d" + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + // 2. 
Match [^\r\n\p{L}\p{N}]?\p{L}+ (non-letter/digit followed by letters) + { + size_t original_pos = pos; + bool has_prefix = false; + matched.clear(); + + // Check optional non-letter/digit prefix (excluding \r\n) + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + has_prefix = true; + } + + // Require at least one letter + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } else { + // Rollback if no letters after prefix + if (has_prefix) { + pos = original_pos; + matched.clear(); + } + } + } + + // 3. Match \p{N} (digits) + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + // 4. Match ?[^\s\p{L}\p{N}]+[\r\n]* (punctuation/symbols with optional space prefix) + { + size_t original_pos = pos; + matched.clear(); + size_t start = pos; + + // Optional space + if (str[pos] == L' ') { ++pos; } + + // Require at least one non-letter/digit/whitespace + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && !preprocessor::isDigit(str[pos])); + + // Capture from start (after optional space) to current pos + matched = str.substr(start, pos - start); + + // Capture trailing newlines + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } else { + // Rollback if no symbols found + pos = original_pos; + } + } + + // 5. 
Match \s*[\r\n]+ (newlines with leading whitespace) + { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos; + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + // 6. Match \s+(?!\S) (whitespace not followed by non-space) + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + // Check if at end or followed by whitespace + if (pos >= str.size() || std::iswspace(str[pos])) { + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + // 7. Match remaining whitespace + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline bool qwen3Regex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (qwen3TokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } + return true; +} + +struct Qwen3Message { + std::string prompt; + static inline std::string message_template = + "<|im_start|>user\n{{{prompt}}}<|im_end|>\n<|im_start|>assistant\n"; +}; + +class Qwen3Tokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit Qwen3Tokenizer(const std::string& file_path) { + preprocessor::initLocal(); + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + bpe_.initFromSentencePieceJson(file_path); + special_tokens_trie_.add(L"<|endoftext|>"); + special_tokens_trie_.add(L"<|im_start|>"); + 
special_tokens_trie_.add(L"<|im_end|>"); + special_tokens_trie_.add(L"<|object_ref_start|>"); + special_tokens_trie_.add(L"<|object_ref_end|>"); + special_tokens_trie_.add(L"<|box_start|>"); + special_tokens_trie_.add(L"<|box_end|>"); + special_tokens_trie_.add(L"<|quad_start|>"); + special_tokens_trie_.add(L"<|quad_end|>"); + special_tokens_trie_.add(L"<|vision_start|>"); + special_tokens_trie_.add(L"<|vision_end|>"); + special_tokens_trie_.add(L"<|vision_pad|>"); + special_tokens_trie_.add(L"<|image_pad|>"); + special_tokens_trie_.add(L"<|video_pad|>"); + special_tokens_trie_.add(L""); + special_tokens_trie_.add(L""); + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + ::mllm::models::qwen3_moe::qwen3Regex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + + auto bpe_ts = bpe_._bpe(mapped_str); + + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + } + + return ret; + } + + std::vector tokenize(const std::string& str) override { + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str)); + std::vector all_tokens; + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); } + return 
{mllm::preprocessor::utf8string2WideString(utf_8_str)}; + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("qwen2-tokenizer-i0") + .alloc(); + + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return ret; + } + + ARGenerationOutputPast convertMessage(const Qwen3Message& message) { + // process prompt + auto applied_string = Qwen3Message::message_template; + size_t pos = applied_string.find("{{{prompt}}}"); + applied_string.replace(pos, 12, message.prompt); + + // process sequence + auto sequence_str = tokenize(applied_string); + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + // Get sequence Tensor + Tensor sequence = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return { + {"sequence", sequence}, + }; + } + + private: + // For text + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; +}; + +} // namespace mllm::models::qwen3_moe From ad3cb3cf95619ffd29bd7a47f84f4682aa76aa81 Mon Sep 17 00:00:00 2001 From: HayzelHan Date: Fri, 30 Jan 2026 09:09:02 +0000 Subject: [PATCH 2/5] feat(x86): add ops (elewise/reduce) for x86 cpu - add elewise/reduce kernels implementation - remove matmul restriction for vecdot --- .../cpu/kernels/common/elewise-inl.hpp | 225 +++++++++++++----- .../kernels/common/ggml/quantize/quantize.hpp | 2 +- .../cpu/kernels/common/kernel_dispatch.cpp | 182 +++++++++++++- 
.../cpu/kernels/common/kernel_dispatch.hpp | 55 ++++- .../cpu/kernels/common/reduce-inl.hpp | 123 ++++++++++ mllm/backends/cpu/ops/ElewiseOps.cpp | 206 +++++++++++----- mllm/backends/cpu/ops/MatMulOp.cpp | 14 +- mllm/backends/cpu/ops/ReduceOps.cpp | 5 + 8 files changed, 676 insertions(+), 136 deletions(-) diff --git a/mllm/backends/cpu/kernels/common/elewise-inl.hpp b/mllm/backends/cpu/kernels/common/elewise-inl.hpp index a2f2ee42..b839e32d 100644 --- a/mllm/backends/cpu/kernels/common/elewise-inl.hpp +++ b/mllm/backends/cpu/kernels/common/elewise-inl.hpp @@ -8,31 +8,6 @@ HWY_BEFORE_NAMESPACE(); namespace mllm::cpu::common { // NOLINT namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; - -//===----------------------------------------------------------------------===// -// Elementwise + - * / By Matrix -//===----------------------------------------------------------------------===// -template -HWY_INLINE void elementwise_impl(const T* HWY_RESTRICT x, const T* HWY_RESTRICT y, T* HWY_RESTRICT out, size_t count, Op&& op) { - const hn::ScalableTag d; - const size_t N = hn::Lanes(d); - size_t idx = 0; - - for (; idx + N <= count; idx += N) { - const hn::Vec vx = hn::LoadU(d, x + idx); - const hn::Vec vy = hn::LoadU(d, y + idx); - const hn::Vec result = op(d, vx, vy); - hn::StoreU(result, d, out + idx); - } - - if (idx < count) { - const hn::Vec vx = hn::LoadN(d, x + idx, count - idx); - const hn::Vec vy = hn::LoadN(d, y + idx, count - idx); - const hn::Vec result = op(d, vx, vy); - hn::StoreN(result, d, out + idx, count - idx); - } -} - struct AddOp { template HWY_INLINE V operator()(D d, V a, V b) const { @@ -61,6 +36,30 @@ struct DivOp { } }; +//===----------------------------------------------------------------------===// +// Elementwise + - * / By Matrix +//===----------------------------------------------------------------------===// +template +HWY_INLINE void elementwise_impl(const T* HWY_RESTRICT x, const T* HWY_RESTRICT y, T* HWY_RESTRICT out, size_t 
count, Op&& op) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { + const hn::Vec vx = hn::LoadU(d, x + idx); + const hn::Vec vy = hn::LoadU(d, y + idx); + const hn::Vec result = op(d, vx, vy); + hn::StoreU(result, d, out + idx); + } + + if (idx < count) { + const hn::Vec vx = hn::LoadN(d, x + idx, count - idx); + const hn::Vec vy = hn::LoadN(d, y + idx, count - idx); + const hn::Vec result = op(d, vx, vy); + hn::StoreN(result, d, out + idx, count - idx); + } +} + HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { elementwise_impl(x, y, out, n, AddOp{}); } @@ -77,12 +76,81 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_fp32(mllm_fp32_t* out, const mllm elementwise_impl(x, y, out, n, DivOp{}); } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n) { +// elementwise_impl(x, y, out, n, AddOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n) { +// elementwise_impl(x, y, out, n, SubOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n) { +// elementwise_impl(x, y, out, n, MulOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n) { +// elementwise_impl(x, y, out, n, DivOp{}); +// } + + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + elementwise_impl(x, y, out, n, AddOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + elementwise_impl(x, y, out, n, SubOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void 
elewise_mul_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + elementwise_impl(x, y, out, n, MulOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + elementwise_impl(x, y, out, n, DivOp{}); +} + + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + elementwise_impl(x, y, out, n, AddOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + elementwise_impl(x, y, out, n, SubOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + elementwise_impl(x, y, out, n, MulOp{}); +} + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { +// elementwise_impl(x, y, out, n, DivOp{}); +// } + + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + elementwise_impl(x, y, out, n, AddOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + elementwise_impl(x, y, out, n, SubOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + elementwise_impl(x, y, out, n, MulOp{}); +} + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { +// elementwise_impl(x, y, out, n, DivOp{}); +// } + + //===----------------------------------------------------------------------===// // Elementwise + - * / By Const //===----------------------------------------------------------------------===// template -HWY_INLINE void 
elementwise_scalar_impl(T* HWY_RESTRICT out, const T* HWY_RESTRICT x, const T y, size_t count, Op&& op) { +HWY_INLINE void elementwise_scl_impl(T* HWY_RESTRICT out, const T* HWY_RESTRICT x, const T y, size_t count, Op&& op) { const hn::ScalableTag d; const size_t N = hn::Lanes(d); size_t idx = 0; @@ -103,50 +171,91 @@ HWY_INLINE void elementwise_scalar_impl(T* HWY_RESTRICT out, const T* HWY_RESTRI } } -struct AddScalarOp { - template - HWY_INLINE V operator()(D d, V a, V b) const { - return hn::Add(a, b); - } -}; +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, AddOp{}); +} -struct SubScalarOp { - template - HWY_INLINE V operator()(D d, V a, V b) const { - return hn::Sub(a, b); - } -}; +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, SubOp{}); +} -struct MulScalarOp { - template - HWY_INLINE V operator()(D d, V a, V b) const { - return hn::Mul(a, b); - } -}; +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, MulOp{}); +} -struct DivScalarOp { - template - HWY_INLINE V operator()(D d, V a, V b) const { - return hn::Div(a, b); - } -}; +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, DivOp{}); +} + + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t y, size_t n) { +// elementwise_scl_impl(out, x, y, n, AddOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t y, size_t n) { +// elementwise_scl_impl(out, x, y, n, SubOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED 
void elewise_mul_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t y, size_t n) { +// elementwise_scl_impl(out, x, y, n, MulOp{}); +// } + +// HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t y, size_t n) { +// elementwise_scl_impl(out, x, y, n, DivOp{}); +// } + + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, AddOp{}); +} -HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { - elementwise_scalar_impl(out, x, y, n, AddScalarOp{}); +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, SubOp{}); } -HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { - elementwise_scalar_impl(out, x, y, n, SubScalarOp{}); +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, MulOp{}); } -HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { - elementwise_scalar_impl(out, x, y, n, MulScalarOp{}); +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, DivOp{}); } -HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t y, size_t n) { - elementwise_scalar_impl(out, x, y, n, DivScalarOp{}); + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t y, size_t n) { + 
elementwise_scl_impl(out, x, y, n, AddOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, SubOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, MulOp{}); } +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, DivOp{}); +} + + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_add_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, AddOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_sub_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, SubOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_mul_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, MulOp{}); +} + +HWY_NOINLINE HWY_MAYBE_UNUSED void elewise_div_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t y, size_t n) { + elementwise_scl_impl(out, x, y, n, DivOp{}); +} + + //===----------------------------------------------------------------------===// // Inplace Elementwise + - * / // diff --git a/mllm/backends/cpu/kernels/common/ggml/quantize/quantize.hpp b/mllm/backends/cpu/kernels/common/ggml/quantize/quantize.hpp index edb98305..e318451a 100644 --- a/mllm/backends/cpu/kernels/common/ggml/quantize/quantize.hpp +++ b/mllm/backends/cpu/kernels/common/ggml/quantize/quantize.hpp @@ -111,7 +111,7 @@ namespace mllm::cpu { static float table_f32_f16[1 << 16]; static bool table_f32_f16_init = false; -inline static float lookup_fp16_to_fp32(mllm_fp16_t f) { +inline static float lookup_fp16_to_fp32(uint16_t f) { if 
(!table_f32_f16_init) { uint16_t ii; for (int i = 0; i < (1 << 16); ++i) { diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp index 7e81adfd..052bbf93 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp @@ -18,6 +18,7 @@ // Include all inline implementations here #include "mllm/backends/cpu/kernels/common/elewise-inl.hpp" #include "mllm/backends/cpu/kernels/common/fill-inl.hpp" +#include "mllm/backends/cpu/kernels/common/reduce-inl.hpp" #if HWY_ONCE namespace mllm::cpu::common { @@ -29,10 +30,52 @@ HWY_EXPORT(elewise_add_fp32); HWY_EXPORT(elewise_sub_fp32); HWY_EXPORT(elewise_mul_fp32); HWY_EXPORT(elewise_div_fp32); -HWY_EXPORT(elewise_add_scalar_fp32); -HWY_EXPORT(elewise_sub_scalar_fp32); -HWY_EXPORT(elewise_mul_scalar_fp32); -HWY_EXPORT(elewise_div_scalar_fp32); + +// HWY_EXPORT(elewise_add_fp16); +// HWY_EXPORT(elewise_sub_fp16); +// HWY_EXPORT(elewise_mul_fp16); +// HWY_EXPORT(elewise_div_fp16); + +HWY_EXPORT(elewise_add_int32); +HWY_EXPORT(elewise_sub_int32); +HWY_EXPORT(elewise_mul_int32); +HWY_EXPORT(elewise_div_int32); + +HWY_EXPORT(elewise_add_int16); +HWY_EXPORT(elewise_sub_int16); +HWY_EXPORT(elewise_mul_int16); +// HWY_EXPORT(elewise_div_int16); + +HWY_EXPORT(elewise_add_int8); +HWY_EXPORT(elewise_sub_int8); +HWY_EXPORT(elewise_mul_int8); +// HWY_EXPORT(elewise_div_int8); + +HWY_EXPORT(elewise_add_scl_fp32); +HWY_EXPORT(elewise_sub_scl_fp32); +HWY_EXPORT(elewise_mul_scl_fp32); +HWY_EXPORT(elewise_div_scl_fp32); + +// HWY_EXPORT(elewise_add_scl_fp16); +// HWY_EXPORT(elewise_sub_scl_fp16); +// HWY_EXPORT(elewise_mul_scl_fp16); +// HWY_EXPORT(elewise_div_scl_fp16); + +HWY_EXPORT(elewise_add_scl_int32); +HWY_EXPORT(elewise_sub_scl_int32); +HWY_EXPORT(elewise_mul_scl_int32); +HWY_EXPORT(elewise_div_scl_int32); + +HWY_EXPORT(elewise_add_scl_int16); +HWY_EXPORT(elewise_sub_scl_int16); 
+HWY_EXPORT(elewise_mul_scl_int16); +HWY_EXPORT(elewise_div_scl_int16); + +HWY_EXPORT(elewise_add_scl_int8); +HWY_EXPORT(elewise_sub_scl_int8); +HWY_EXPORT(elewise_mul_scl_int8); +HWY_EXPORT(elewise_div_scl_int8); + HWY_DLLEXPORT void call_elewise_add_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_add_fp32)(out, x, y, n); @@ -50,22 +93,128 @@ HWY_DLLEXPORT void call_elewise_div_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, HWY_DYNAMIC_DISPATCH(elewise_div_fp32)(out, x, y, n); } -HWY_DLLEXPORT void call_elewise_add_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { - HWY_DYNAMIC_DISPATCH(elewise_add_scalar_fp32)(out, x, y, n); + +HWY_DLLEXPORT void call_elewise_add_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_int32)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_sub_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_int32)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_mul_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_int32)(out, x, y, n); } -HWY_DLLEXPORT void call_elewise_sub_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { - HWY_DYNAMIC_DISPATCH(elewise_sub_scalar_fp32)(out, x, y, n); +HWY_DLLEXPORT void call_elewise_div_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_div_int32)(out, x, y, n); +} + + +HWY_DLLEXPORT void call_elewise_add_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_int16)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_sub_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_int16)(out, x, y, n); +} + 
+HWY_DLLEXPORT void call_elewise_mul_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_int16)(out, x, y, n); +} + +// HWY_DLLEXPORT void call_elewise_div_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { +// HWY_DYNAMIC_DISPATCH(elewise_div_int16)(out, x, y, n); +// } + + +HWY_DLLEXPORT void call_elewise_add_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_int8)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_sub_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_int8)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_mul_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_int8)(out, x, y, n); +} + +// HWY_DLLEXPORT void call_elewise_div_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { +// HWY_DYNAMIC_DISPATCH(elewise_div_int8)(out, x, y, n); +// } + + + +HWY_DLLEXPORT void call_elewise_add_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_scl_fp32)(out, x, y, n); } -HWY_DLLEXPORT void call_elewise_mul_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { - HWY_DYNAMIC_DISPATCH(elewise_mul_scalar_fp32)(out, x, y, n); +HWY_DLLEXPORT void call_elewise_sub_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_scl_fp32)(out, x, y, n); } -HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { - HWY_DYNAMIC_DISPATCH(elewise_div_scalar_fp32)(out, x, y, n); +HWY_DLLEXPORT void call_elewise_mul_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_scl_fp32)(out, x, y, n); } +HWY_DLLEXPORT void 
call_elewise_div_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_div_scl_fp32)(out, x, y, n); +} + + +HWY_DLLEXPORT void call_elewise_add_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_scl_int32)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_sub_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_scl_int32)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_mul_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_scl_int32)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_div_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_div_scl_int32)(out, x, y, n); +} + + +HWY_DLLEXPORT void call_elewise_add_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_scl_int16)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_sub_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_scl_int16)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_mul_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_scl_int16)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_div_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_div_scl_int16)(out, x, y, n); +} + + +HWY_DLLEXPORT void call_elewise_add_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_add_scl_int8)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_sub_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_sub_scl_int8)(out, x, y, n); +} + +HWY_DLLEXPORT 
void call_elewise_mul_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_mul_scl_int8)(out, x, y, n); +} + +HWY_DLLEXPORT void call_elewise_div_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { + HWY_DYNAMIC_DISPATCH(elewise_div_scl_int8)(out, x, y, n); +} + + + //===----------------------------------------------------------------------===// // GELU //===----------------------------------------------------------------------===// @@ -252,6 +401,15 @@ HWY_DLLEXPORT void call_fill_random_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t HWY_DYNAMIC_DISPATCH(fill_random_u8)(dst, n, start, end, seed); } +//===----------------------------------------------------------------------===// +// Reduce +//===----------------------------------------------------------------------===// +HWY_EXPORT(reduce_sum_fp32); + +HWY_DLLEXPORT void call_reduce_sum_fp32(mllm_fp32_t* dst, const mllm_fp32_t* src, size_t src_stride, size_t size, int32_t thread_count) { + HWY_DYNAMIC_DISPATCH(reduce_sum_fp32)(dst, src, src_stride, size, thread_count); +} + } // namespace mllm::cpu::common #endif // HWY_ONCE diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp index 4df34db0..6a6e4098 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp @@ -23,13 +23,55 @@ HWY_DLLEXPORT void call_elewise_sub_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, HWY_DLLEXPORT void call_elewise_mul_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n); HWY_DLLEXPORT void call_elewise_div_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n); +//TODO: fp16 support not implemented yet +// HWY_DLLEXPORT void call_elewise_add_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); +// HWY_DLLEXPORT void call_elewise_sub_fp16(mllm_fp16_t* out, const 
mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); +// HWY_DLLEXPORT void call_elewise_mul_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); +// HWY_DLLEXPORT void call_elewise_div_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); + +HWY_DLLEXPORT void call_elewise_add_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_div_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); + +HWY_DLLEXPORT void call_elewise_add_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); +// HWY_DLLEXPORT void call_elewise_div_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); + +HWY_DLLEXPORT void call_elewise_add_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); +// HWY_DLLEXPORT void call_elewise_div_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); + //===----------------------------------------------------------------------===// // Elementwise + - * / By Const //===----------------------------------------------------------------------===// -HWY_DLLEXPORT void call_elewise_add_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, 
mllm_fp32_t y, size_t n); -HWY_DLLEXPORT void call_elewise_sub_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); -HWY_DLLEXPORT void call_elewise_mul_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); -HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_add_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_div_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); + +//TODO: fp16 support not implemented yet +// HWY_DLLEXPORT void call_elewise_add_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); +// HWY_DLLEXPORT void call_elewise_sub_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); +// HWY_DLLEXPORT void call_elewise_mul_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); +// HWY_DLLEXPORT void call_elewise_div_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); + +HWY_DLLEXPORT void call_elewise_add_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); +HWY_DLLEXPORT void call_elewise_div_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); + +HWY_DLLEXPORT void call_elewise_add_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, 
mllm_int16_t y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); +HWY_DLLEXPORT void call_elewise_div_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); + +HWY_DLLEXPORT void call_elewise_add_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); +HWY_DLLEXPORT void call_elewise_sub_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); +HWY_DLLEXPORT void call_elewise_mul_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); +HWY_DLLEXPORT void call_elewise_div_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); //===----------------------------------------------------------------------===// // Fill Zeros @@ -247,6 +289,11 @@ inline void fill_random_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t } } +//===----------------------------------------------------------------------===// +// Reduce +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_reduce_sum_fp32(mllm_fp32_t* dst, const mllm_fp32_t* src, size_t src_stride, size_t size, int32_t thread_count); + } // namespace mllm::cpu::common #endif diff --git a/mllm/backends/cpu/kernels/common/reduce-inl.hpp b/mllm/backends/cpu/kernels/common/reduce-inl.hpp index e69de29b..2c381ae0 100644 --- a/mllm/backends/cpu/kernels/common/reduce-inl.hpp +++ b/mllm/backends/cpu/kernels/common/reduce-inl.hpp @@ -0,0 +1,123 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include "mllm/core/DataTypes.hpp" + +HWY_BEFORE_NAMESPACE(); +namespace mllm::cpu::common { // NOLINT +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + + +struct ScalarAddOp { template HWY_INLINE T operator()(T a, T b) const { return a + b; } }; + +struct ScalarSubOp { template HWY_INLINE T operator()(T a, T b) const { return a - b; } }; + +struct ScalarMulOp { template HWY_INLINE T operator()(T a, T b) const { return a * b; } }; + +struct ScalarDivOp { template HWY_INLINE T operator()(T a, T b) const { return a / b; } }; + +struct ScalarMaxOp { template HWY_INLINE T operator()(T a, T b) const { return a > b ? a : b; } }; + +struct ScalarMinOp { template HWY_INLINE T operator()(T a, T b) const { return a < b ? a : b; } }; + +struct VecAddOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Add(a, b); } +}; + +struct VecSubOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Sub(a, b); } +}; + +struct VecMulOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Mul(a, b); } +}; + +struct VecDivOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Div(a, b); } +}; + +struct VecMaxOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Max(a, b); } +}; + +struct VecMinOp { + template + HWY_INLINE V operator()(D d, V a, V b) const { return hn::Min(a, b); } +}; + +struct VecSumReduce { + template + HWY_INLINE hn::TFromD operator()(D d, V v) const { return hn::ReduceSum(d, v); } +}; + + +template +HWY_INLINE T reduce_impl(const T* HWY_RESTRICT src, size_t src_stride, size_t size, + ScalarOp&& scalar_op, VectorOp&& vec_op, VectorReduceOp&& vec_reduce_op) { + if (size == 0) return T(0); + + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + + // SIMD fast path + if (src_stride == 1 && size >= N) { + using V = hn::Vec; + + // Init with first vector + V vec_result = hn::LoadU(d, src); + size_t i = N; + + // 4x unroll + for (; i + 
4 * N <= size; i += 4 * N) { + const V v0 = hn::LoadU(d, src + i); + const V v1 = hn::LoadU(d, src + i + N); + const V v2 = hn::LoadU(d, src + i + 2 * N); + const V v3 = hn::LoadU(d, src + i + 3 * N); + + vec_result = vec_op(d, vec_result, v0); + vec_result = vec_op(d, vec_result, v1); + vec_result = vec_op(d, vec_result, v2); + vec_result = vec_op(d, vec_result, v3); + } + + for (; i + N <= size; i += N) { + const V v = hn::LoadU(d, src + i); + vec_result = vec_op(d, vec_result, v); + } + + if (i < size) { + const V vt = hn::LoadN(d, src + i, size - i); + vec_result = vec_op(d, vec_result, vt); + } + + return vec_reduce_op(d, vec_result); + } + + // Scalar path (stride != 1 or too small) + T scalar_result = src[0]; + for (size_t i = 1; i < size; ++i) { + scalar_result = scalar_op(scalar_result, src[i * src_stride]); + } + return scalar_result; + +} + + +HWY_NOINLINE HWY_MAYBE_UNUSED void reduce_sum_fp32(mllm_fp32_t* dst,const mllm_fp32_t* src, +size_t src_stride, size_t size, int32_t thread_count) { + const mllm_fp32_t v = reduce_impl(src, src_stride, size, + ScalarAddOp{}, VecAddOp{}, VecSumReduce{}); + *dst = v; +} + + +} // namespace HWY_NAMESPACE +} // namespace mllm::cpu::common +HWY_AFTER_NAMESPACE(); diff --git a/mllm/backends/cpu/ops/ElewiseOps.cpp b/mllm/backends/cpu/ops/ElewiseOps.cpp index a3e1f7dd..e751b197 100644 --- a/mllm/backends/cpu/ops/ElewiseOps.cpp +++ b/mllm/backends/cpu/ops/ElewiseOps.cpp @@ -140,22 +140,22 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o switch (dtype) { case kFloat32: { if (input0.numel() == input1.numel()) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_fp32(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); 
+#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_add_fp32(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #else NYI("AddOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_scalar_fp32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel(), options_.getThreads()); + output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_add_scl_fp32(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #else NYI("AddOp not supported on this architecture."); #endif @@ -163,7 +163,7 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o const float* a = input0.ptr(); const float* b = input1.ptr(); float* out = output.ptr(); -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { // Each batch processes broadcast_naive_loops iterations of vector_size elements @@ -172,11 +172,10 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_add_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::arm::ew_add_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } } - -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#elif 
defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { // Each batch processes broadcast_naive_loops iterations of vector_size elements @@ -185,7 +184,7 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::arm::ew_add_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); + cpu::common::call_elewise_add_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else @@ -202,11 +201,15 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_add_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("AddOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_add_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("AddOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("AddOp broadcast not supported."); @@ -219,11 +222,17 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_add_int32(output.ptr(), input0.ptr(), input1.ptr(), + 
output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_add_scl_int32(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("AddOp broadcast not supported."); @@ -236,11 +245,17 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_add_int16(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_add_scl_int16(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("AddOp broadcast not supported."); @@ -253,11 +268,16 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_add_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif 
defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_add_scl_int8(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("AddOp broadcast not supported."); @@ -323,22 +343,22 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o switch (dtype) { case kFloat32: { if (input0.numel() == input1.numel()) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + cpu::arm::ew_sub_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), + options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) cpu::common::call_elewise_sub_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) - cpu::arm::ew_sub_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), - options_.getThreads()); #else NYI("SubOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_scalar_fp32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel(), options_.getThreads()); + output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_sub_scl_fp32(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #else NYI("SubOp not supported on this architecture."); #endif @@ -346,7 +366,7 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o const float* a = input0.ptr(); const float* b = input1.ptr(); float* out = output.ptr(); -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) 
+#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { // Each batch processes broadcast_naive_loops iterations of vector_size elements @@ -355,10 +375,10 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_sub_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::arm::ew_sub_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } - } -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + } +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { // Each batch processes broadcast_naive_loops iterations of vector_size elements @@ -367,7 +387,7 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::arm::ew_sub_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); + cpu::common::call_elewise_sub_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else @@ -384,11 +404,15 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_sub_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("SubOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || 
defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_sub_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("SubOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("SubOp broadcast not supported."); @@ -401,11 +425,17 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_sub_int32(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_sub_scl_int32(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("SubOp broadcast not supported."); @@ -418,11 +448,17 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_sub_int16(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + 
cpu::common::call_elewise_sub_scl_int16(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("SubOp broadcast not supported."); @@ -435,11 +471,16 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_sub_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_sub_scl_int8(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("SubOp broadcast not supported."); @@ -505,22 +546,22 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o switch (dtype) { case kFloat32: { if (input0.numel() == input1.numel()) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + cpu::arm::ew_mul_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), + options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) cpu::common::call_elewise_mul_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) - cpu::arm::ew_mul_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), - options_.getThreads()); #else NYI("MulOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_mul_scalar_fp32(output.ptr(), input0.ptr(), 
*input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel(), options_.getThreads()); + output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_mul_scl_fp32(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #else NYI("MulOp not supported on this architecture."); #endif @@ -528,28 +569,28 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o const float* a = input0.ptr(); const float* b = input1.ptr(); float* out = output.ptr(); -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { - // Each batch processes broadcast_naive_loops iterations of vector_size elements + // Each batch processes broadcast_naive_loops iterations of vector_size elements for (int l = 0; l < broadcast_naive_loops; ++l) { size_t a_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_mul_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::arm::ew_mul_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } - } -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + } +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { - // Each batch processes broadcast_naive_loops iterations of vector_size elements + // Each batch processes broadcast_naive_loops iterations of 
vector_size elements for (int l = 0; l < broadcast_naive_loops; ++l) { size_t a_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::arm::ew_mul_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); + cpu::common::call_elewise_mul_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else @@ -566,11 +607,15 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_mul_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("MulOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_mul_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("MulOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("MulOp broadcast not supported."); @@ -583,11 +628,17 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_mul_int32(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int32_scalar(output.ptr(), 
input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_mul_scl_int32(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("MulOp broadcast not supported."); @@ -600,11 +651,17 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_mul_int16(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_mul_scl_int16(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("MulOp broadcast not supported."); @@ -617,11 +674,16 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_mul_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_mul_scl_int8(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { 
NYI("MulOp broadcast not supported."); @@ -687,22 +749,22 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o switch (dtype) { case kFloat32: { if (input0.numel() == input1.numel()) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + cpu::arm::ew_div_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), + options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) cpu::common::call_elewise_div_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) - cpu::arm::ew_div_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), - options_.getThreads()); #else NYI("DivOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_div_scalar_fp32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel(), options_.getThreads()); + output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_div_scl_fp32(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #else NYI("DivOp not supported on this architecture."); #endif @@ -710,28 +772,28 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o const float* a = input0.ptr(); const float* b = input1.ptr(); float* out = output.ptr(); -#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { - // Each batch processes 
broadcast_naive_loops iterations of vector_size elements + // Each batch processes broadcast_naive_loops iterations of vector_size elements for (int l = 0; l < broadcast_naive_loops; ++l) { size_t a_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_div_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::arm::ew_div_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } - } -#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) + } +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { - // Each batch processes broadcast_naive_loops iterations of vector_size elements + // Each batch processes broadcast_naive_loops iterations of vector_size elements for (int l = 0; l < broadcast_naive_loops; ++l) { size_t a_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::arm::ew_div_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); + cpu::common::call_elewise_div_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else @@ -748,11 +810,15 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_div_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp fp16 not supported on x86 architecture yet."); #endif } else if 
(input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_div_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast not supported."); @@ -765,11 +831,17 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_div_int32(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + cpu::common::call_elewise_div_scl_int32(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("DivOp broadcast not supported."); @@ -782,11 +854,15 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp int16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp int16 
not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast not supported."); @@ -799,11 +875,15 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp int8 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp int8 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast not supported."); @@ -817,11 +897,15 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_fp32_complex(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp complex fp32 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_fp32_complex_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("DivOp complex fp32 not supported on x86 architecture yet."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); @@ -840,6 +924,8 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_div_fp32_complex(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } } +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + 
NYI("DivOp complex fp32 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast for complex output not supported."); diff --git a/mllm/backends/cpu/ops/MatMulOp.cpp b/mllm/backends/cpu/ops/MatMulOp.cpp index cc7dddde..b22549fb 100644 --- a/mllm/backends/cpu/ops/MatMulOp.cpp +++ b/mllm/backends/cpu/ops/MatMulOp.cpp @@ -49,7 +49,7 @@ void CPUMatMulOp::forward(const std::vector& inputs, std::vector #if defined(MLLM_USE_BLAS) mt = aops::MatMulOpType::kBLAS; #else - if (!transpose_a && transpose_b && M >= 4) { + if (!transpose_a && transpose_b) { // TODO kGGUF still buggy !!! mt = aops::MatMulOpType::kGGUF; } else @@ -110,6 +110,18 @@ void CPUMatMulOp::forward(const std::vector& inputs, std::vector transpose_a, transpose_b, thread_count); } } +// #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) +// if (lhs.dtype() == kFloat32 && rhs.dtype() == kFloat32 && o.dtype() == kFloat32) { +// if (batch_count == 1) { +// x86::mllm_blas_matmul_fp32(M, K, N, o.ptr(), lhs.ptr(), rhs.ptr(), nullptr, +// transpose_a, transpose_b); +// } else { +// x86::mllm_blas_batch_matmul_fp32(batch_count, M, K, N, o.stride()[o.shape().size() - 3], +// lhs.stride()[lhs_shape.size() - 3], rhs.stride()[rhs_shape.size() - 3], 0, +// o.ptr(), lhs.ptr(), rhs.ptr(), nullptr, +// transpose_a, transpose_b); +// } +// } #else NYI("MllmBlas only support MLLM_HOST_ARCH_ARM64 or MLLM_HOST_ARCH_ARM right now.") #endif diff --git a/mllm/backends/cpu/ops/ReduceOps.cpp b/mllm/backends/cpu/ops/ReduceOps.cpp index a60ae67e..e9eb325d 100644 --- a/mllm/backends/cpu/ops/ReduceOps.cpp +++ b/mllm/backends/cpu/ops/ReduceOps.cpp @@ -294,6 +294,8 @@ void CPUReduceSumOp::forward(const std::vector& inputs, std::vector(), input.ptr(), 1, input.numel(), options_.getThreads()); +#elif defined(MLLM_HOST_ARCH_X86) || defined(MLLM_HOST_ARCH_X86_64) + NYI("ReduceSumOp not implemented for x86/x86_64 yet."); #endif break; } @@ -344,6 +346,9 @@ void CPUReduceSumOp::forward(const std::vector& 
inputs, std::vector Date: Sun, 15 Feb 2026 15:17:19 +0000 Subject: [PATCH 3/5] fix: add error handling and comments for qwen3-moe files --- examples/qwen3_moe/main.cpp | 8 +++++++- mllm/models/qwen3_moe/configuration_qwen3_moe.hpp | 15 ++++++++------- mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp | 8 ++++---- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/examples/qwen3_moe/main.cpp b/examples/qwen3_moe/main.cpp index 1fb01fd4..367bbae2 100644 --- a/examples/qwen3_moe/main.cpp +++ b/examples/qwen3_moe/main.cpp @@ -25,7 +25,11 @@ MLLM_MAIN({ file_version = mllm::ModelFileVersion::kV1; } else if (model_version.get() == "v2") { file_version = mllm::ModelFileVersion::kV2; - } + } else { + fmt::print("❌ Unsupported model_version: {} (expected v1 or v2)\n", model_version.get()); + mllm::shutdownContext(); + return 1; + } if (help.isSet()) { Argparse::printHelp(); @@ -49,6 +53,8 @@ MLLM_MAIN({ fmt::print("💬 Prompt text (or 'exit/quit'): "); std::getline(std::cin, prompt_text); + if(prompt_text == "exit" || prompt_text == "quit") { return 0; } + try { fmt::print("🔄 Processing...\n"); auto inputs = qwen3_moe_tokenizer.convertMessage({.prompt = prompt_text}); diff --git a/mllm/models/qwen3_moe/configuration_qwen3_moe.hpp b/mllm/models/qwen3_moe/configuration_qwen3_moe.hpp index 7dd92c49..f3cf5d32 100644 --- a/mllm/models/qwen3_moe/configuration_qwen3_moe.hpp +++ b/mllm/models/qwen3_moe/configuration_qwen3_moe.hpp @@ -6,7 +6,7 @@ #include "mllm/engine/ConfigFile.hpp" namespace mllm::models::qwen3_moe { - +// Configuration for Qwen3 Mixture-of-Experts model struct Qwen3MoeConfig : protected ConfigFile { Qwen3MoeConfig() = default; @@ -22,22 +22,25 @@ struct Qwen3MoeConfig : protected ConfigFile { rms_norm_eps = data()["rms_norm_eps"]; vocab_size = data()["vocab_size"]; head_dim = data()["head_dim"]; + rope_theta = data()["rope_theta"]; + // Special tokens bos_token_id = data()["bos_token_id"]; eos_token_id = data()["eos_token_id"]; - rope_theta =
data()["rope_theta"]; + // Generation config tie_word_embeddings = data()["tie_word_embeddings"]; max_cache_length = data()["max_cache_length"]; // MoE config num_experts = data()["num_experts"]; num_experts_per_tok = data()["num_experts_per_tok"]; - moe_intermediate_size = data()["moe_intermediate_size"]; + moe_intermediate_size = data()["moe_intermediate_size"]; norm_topk_prob = data()["norm_topk_prob"]; - decoder_sparse_step = data()["decoder_sparse_step"]; + decoder_sparse_step = data()["decoder_sparse_step"]; mlp_only_layers = data()["mlp_only_layers"].get>(); + // Linear implementation type linear_impl_type = aops::str2LinearImplTypes(data()["linear_impl_type"]); } @@ -58,9 +61,7 @@ struct Qwen3MoeConfig : protected ConfigFile { bool tie_word_embeddings = false; int32_t max_cache_length = 4096; - int32_t end_of_text_token_id = 151645; - int32_t thinking_start_token_id = 151667; - int32_t thinking_end_token_id = 151668; + int32_t end_of_text_token_id = 151645; // fixed default int32_t num_experts = 128; int32_t num_experts_per_tok = 8; diff --git a/mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp b/mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp index 83a569df..379db0c7 100644 --- a/mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp +++ b/mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp @@ -130,7 +130,7 @@ class MoEGate final : public nn::Module { auto logits = nn::functional::matmul(hidden_states, weight_.weight(), false, true); auto scores = nn::functional::softmax(logits, -1); auto [topk_weight, topk_idx] = nn::functional::topk(scores, top_k_, -1, true, false); - + if(norm_topk_prob_){ topk_weight = topk_weight / topk_weight.sum(-1, true); } @@ -174,7 +174,7 @@ class Qwen3MoE final : public nn::Module { private: Tensor moeInfer(const Tensor& x, Tensor& topk_ids, Tensor& topk_weights) { // x shape is [batch_size * seq, hidden_dim] - + auto cnts = Tensor::zeros({topk_ids.size(0), (int32_t)experts_.list().size()}); // Do scatter_ operation { @@ -194,9 +194,8 @@ 
class Qwen3MoE final : public nn::Module { auto tokens_per_expert = cnts.sum(0); auto idxs = topk_ids.view({-1}).argsort(); - // TODO this line maybe error auto sorted_tokens = x[{idxs / topk_ids.size(1), {kAll}}]; - + std::vector outputs; int start_idx = 0; @@ -342,6 +341,7 @@ class Qwen3MoeDecoder final : public nn::Module { self_attn_ = reg("self_attn", cfg); self_attn_.layer_idx_ = layer_idx; + MLLM_RT_ASSERT(cfg.decoder_sparse_step > 0); bool is_mlp_only = std::find(cfg.mlp_only_layers.begin(), cfg.mlp_only_layers.end(), layer_idx) != cfg.mlp_only_layers.end(); if ((!is_mlp_only) && (cfg.num_experts > 0 && (layer_idx_+1) % cfg.decoder_sparse_step == 0)) { mlp_opt0_ = reg("mlp", cfg); From 8adc4cb339d017fb603ea38ef918ffa44bb5e492 Mon Sep 17 00:00:00 2001 From: HayzelHan Date: Sun, 15 Feb 2026 15:24:45 +0000 Subject: [PATCH 4/5] refactor: add generic template wrapper for elementwise operators --- .../cpu/kernels/common/kernel_dispatch.cpp | 93 +++------ .../cpu/kernels/common/kernel_dispatch.hpp | 153 ++++++++++++++- mllm/backends/cpu/ops/ElewiseOps.cpp | 177 +++++++++--------- 3 files changed, 257 insertions(+), 166 deletions(-) diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp index 052bbf93..324039c8 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp @@ -24,197 +24,152 @@ namespace mllm::cpu::common { //===----------------------------------------------------------------------===// -// Element-wise +// Elementwise + - * / By Matrix //===----------------------------------------------------------------------===// HWY_EXPORT(elewise_add_fp32); HWY_EXPORT(elewise_sub_fp32); HWY_EXPORT(elewise_mul_fp32); HWY_EXPORT(elewise_div_fp32); - // HWY_EXPORT(elewise_add_fp16); // HWY_EXPORT(elewise_sub_fp16); // HWY_EXPORT(elewise_mul_fp16); // HWY_EXPORT(elewise_div_fp16); - HWY_EXPORT(elewise_add_int32); 
HWY_EXPORT(elewise_sub_int32); HWY_EXPORT(elewise_mul_int32); HWY_EXPORT(elewise_div_int32); - HWY_EXPORT(elewise_add_int16); HWY_EXPORT(elewise_sub_int16); HWY_EXPORT(elewise_mul_int16); // HWY_EXPORT(elewise_div_int16); - HWY_EXPORT(elewise_add_int8); HWY_EXPORT(elewise_sub_int8); HWY_EXPORT(elewise_mul_int8); // HWY_EXPORT(elewise_div_int8); -HWY_EXPORT(elewise_add_scl_fp32); -HWY_EXPORT(elewise_sub_scl_fp32); -HWY_EXPORT(elewise_mul_scl_fp32); -HWY_EXPORT(elewise_div_scl_fp32); - -// HWY_EXPORT(elewise_add_scl_fp16); -// HWY_EXPORT(elewise_sub_scl_fp16); -// HWY_EXPORT(elewise_mul_scl_fp16); -// HWY_EXPORT(elewise_div_scl_fp16); - -HWY_EXPORT(elewise_add_scl_int32); -HWY_EXPORT(elewise_sub_scl_int32); -HWY_EXPORT(elewise_mul_scl_int32); -HWY_EXPORT(elewise_div_scl_int32); - -HWY_EXPORT(elewise_add_scl_int16); -HWY_EXPORT(elewise_sub_scl_int16); -HWY_EXPORT(elewise_mul_scl_int16); -HWY_EXPORT(elewise_div_scl_int16); - -HWY_EXPORT(elewise_add_scl_int8); -HWY_EXPORT(elewise_sub_scl_int8); -HWY_EXPORT(elewise_mul_scl_int8); -HWY_EXPORT(elewise_div_scl_int8); - - HWY_DLLEXPORT void call_elewise_add_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_add_fp32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_sub_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_sub_fp32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_mul_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_mul_fp32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_div_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_div_fp32)(out, x, y, n); } - - HWY_DLLEXPORT void call_elewise_add_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_add_int32)(out, x, y, n); } - HWY_DLLEXPORT void 
call_elewise_sub_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_sub_int32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_mul_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_mul_int32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_div_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_div_int32)(out, x, y, n); } - - HWY_DLLEXPORT void call_elewise_add_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_add_int16)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_sub_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_sub_int16)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_mul_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_mul_int16)(out, x, y, n); } - // HWY_DLLEXPORT void call_elewise_div_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n) { // HWY_DYNAMIC_DISPATCH(elewise_div_int16)(out, x, y, n); // } - - HWY_DLLEXPORT void call_elewise_add_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_add_int8)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_sub_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_sub_int8)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_mul_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_mul_int8)(out, x, y, n); } - // HWY_DLLEXPORT void call_elewise_div_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n) { // HWY_DYNAMIC_DISPATCH(elewise_div_int8)(out, x, y, n); // } - 
+//===----------------------------------------------------------------------===// +// Elementwise + - * / By Const +//===----------------------------------------------------------------------===// +HWY_EXPORT(elewise_add_scl_fp32); +HWY_EXPORT(elewise_sub_scl_fp32); +HWY_EXPORT(elewise_mul_scl_fp32); +HWY_EXPORT(elewise_div_scl_fp32); +// HWY_EXPORT(elewise_add_scl_fp16); +// HWY_EXPORT(elewise_sub_scl_fp16); +// HWY_EXPORT(elewise_mul_scl_fp16); +// HWY_EXPORT(elewise_div_scl_fp16); +HWY_EXPORT(elewise_add_scl_int32); +HWY_EXPORT(elewise_sub_scl_int32); +HWY_EXPORT(elewise_mul_scl_int32); +HWY_EXPORT(elewise_div_scl_int32); +HWY_EXPORT(elewise_add_scl_int16); +HWY_EXPORT(elewise_sub_scl_int16); +HWY_EXPORT(elewise_mul_scl_int16); +HWY_EXPORT(elewise_div_scl_int16); +HWY_EXPORT(elewise_add_scl_int8); +HWY_EXPORT(elewise_sub_scl_int8); +HWY_EXPORT(elewise_mul_scl_int8); +HWY_EXPORT(elewise_div_scl_int8); HWY_DLLEXPORT void call_elewise_add_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_add_scl_fp32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_sub_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_sub_scl_fp32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_mul_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_mul_scl_fp32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_div_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_div_scl_fp32)(out, x, y, n); } - - HWY_DLLEXPORT void call_elewise_add_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_add_scl_int32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_sub_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_sub_scl_int32)(out, x, y, n); } - HWY_DLLEXPORT void 
call_elewise_mul_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_mul_scl_int32)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_div_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_div_scl_int32)(out, x, y, n); } - - HWY_DLLEXPORT void call_elewise_add_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_add_scl_int16)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_sub_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_sub_scl_int16)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_mul_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_mul_scl_int16)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_div_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_div_scl_int16)(out, x, y, n); } - - HWY_DLLEXPORT void call_elewise_add_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_add_scl_int8)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_sub_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_sub_scl_int8)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_mul_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_mul_scl_int8)(out, x, y, n); } - HWY_DLLEXPORT void call_elewise_div_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n) { HWY_DYNAMIC_DISPATCH(elewise_div_scl_int8)(out, x, y, n); } - //===----------------------------------------------------------------------===// // GELU //===----------------------------------------------------------------------===// diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp 
b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp index 6a6e4098..170a7406 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp @@ -18,27 +18,29 @@ namespace mllm::cpu::common { //===----------------------------------------------------------------------===// // Elementwise + - * / By Matrix //===----------------------------------------------------------------------===// +/// @brief Elementwise operations on contiguous buffers: out[i] = x[i] (op) y[i]. +/// @param out Output buffer of length n. +/// @param x Input buffer of length n. +/// @param y Input buffer of length n. +/// @param n Number of elements. +/// @note For integer division, behavior is undefined when a divisor is zero. HWY_DLLEXPORT void call_elewise_add_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n); HWY_DLLEXPORT void call_elewise_sub_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n); HWY_DLLEXPORT void call_elewise_mul_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n); HWY_DLLEXPORT void call_elewise_div_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, const mllm_fp32_t* y, size_t n); - //TODO: fp16 support not implemented yet // HWY_DLLEXPORT void call_elewise_add_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); // HWY_DLLEXPORT void call_elewise_sub_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); // HWY_DLLEXPORT void call_elewise_mul_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); // HWY_DLLEXPORT void call_elewise_div_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, const mllm_fp16_t* y, size_t n); - HWY_DLLEXPORT void call_elewise_add_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); HWY_DLLEXPORT void call_elewise_sub_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); HWY_DLLEXPORT void 
call_elewise_mul_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); HWY_DLLEXPORT void call_elewise_div_int32(mllm_int32_t* out, const mllm_int32_t* x, const mllm_int32_t* y, size_t n); - HWY_DLLEXPORT void call_elewise_add_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); HWY_DLLEXPORT void call_elewise_sub_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); HWY_DLLEXPORT void call_elewise_mul_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); // HWY_DLLEXPORT void call_elewise_div_int16(mllm_int16_t* out, const mllm_int16_t* x, const mllm_int16_t* y, size_t n); - HWY_DLLEXPORT void call_elewise_add_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); HWY_DLLEXPORT void call_elewise_sub_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); HWY_DLLEXPORT void call_elewise_mul_int8(mllm_int8_t* out, const mllm_int8_t* x, const mllm_int8_t* y, size_t n); @@ -47,32 +49,161 @@ HWY_DLLEXPORT void call_elewise_mul_int8(mllm_int8_t* out, const mllm_int8_t* x, //===----------------------------------------------------------------------===// // Elementwise + - * / By Const //===----------------------------------------------------------------------===// +/// @brief Elementwise operations with a scalar constant: out[i] = x[i] (op) y. +/// @param out Output buffer of length n. +/// @param x Input buffer of length n. +/// @param y Scalar constant. +/// @param n Number of elements. +/// @note For integer division, behavior is undefined when y == 0. 
HWY_DLLEXPORT void call_elewise_add_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); HWY_DLLEXPORT void call_elewise_sub_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); HWY_DLLEXPORT void call_elewise_mul_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); HWY_DLLEXPORT void call_elewise_div_scl_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); - //TODO: fp16 support not implemented yet // HWY_DLLEXPORT void call_elewise_add_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); // HWY_DLLEXPORT void call_elewise_sub_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); // HWY_DLLEXPORT void call_elewise_mul_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); // HWY_DLLEXPORT void call_elewise_div_scl_fp16(mllm_fp16_t* out, const mllm_fp16_t* x, mllm_fp16_t y, size_t n); - HWY_DLLEXPORT void call_elewise_add_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); HWY_DLLEXPORT void call_elewise_sub_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); HWY_DLLEXPORT void call_elewise_mul_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); HWY_DLLEXPORT void call_elewise_div_scl_int32(mllm_int32_t* out, const mllm_int32_t* x, mllm_int32_t y, size_t n); - HWY_DLLEXPORT void call_elewise_add_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); HWY_DLLEXPORT void call_elewise_sub_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); HWY_DLLEXPORT void call_elewise_mul_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); HWY_DLLEXPORT void call_elewise_div_scl_int16(mllm_int16_t* out, const mllm_int16_t* x, mllm_int16_t y, size_t n); - HWY_DLLEXPORT void call_elewise_add_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); HWY_DLLEXPORT void 
call_elewise_sub_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); HWY_DLLEXPORT void call_elewise_mul_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); HWY_DLLEXPORT void call_elewise_div_scl_int8(mllm_int8_t* out, const mllm_int8_t* x, mllm_int8_t y, size_t n); +//===----------------------------------------------------------------------===// +// Template wrapper for generic elewise operations +//===----------------------------------------------------------------------===// +template +inline void elewise_add_anytype(T* out, const T* x, const T* y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_add_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] + y[i]; } + } +} + +template +inline void elewise_sub_anytype(T* out, const T* x, const T* y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_sub_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] - y[i]; } + } +} + +template +inline void elewise_mul_anytype(T* out, const T* x, const T* y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_mul_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] * 
y[i]; } + } +} + +template +inline void elewise_div_anytype(T* out, const T* x, const T* y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_div_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_div_int32(out, x, y, n); + } else { + // Fallback (note: division by zero is undefined) + for (size_t i = 0; i < n; ++i) { out[i] = x[i] / y[i]; } + } +} + +template +inline void elewise_add_scl_anytype(T* out, const T* x, T y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_add_scl_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_scl_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_scl_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_add_scl_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] + y; } + } +} + +template +inline void elewise_sub_scl_anytype(T* out, const T* x, T y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_sub_scl_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_scl_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_scl_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_sub_scl_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] - y; } + } +} + +template +inline void elewise_mul_scl_anytype(T* out, const T* x, T y, size_t n) { + if constexpr (std::is_same_v) { + call_elewise_mul_scl_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_scl_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_scl_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_mul_scl_int8(out, x, y, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { out[i] = x[i] * y; } + } +} + +template +inline void elewise_div_scl_anytype(T* out, const T* x, T y, size_t n) { + if 
constexpr (std::is_same_v) { + call_elewise_div_scl_fp32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_div_scl_int32(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_div_scl_int16(out, x, y, n); + } else if constexpr (std::is_same_v) { + call_elewise_div_scl_int8(out, x, y, n); + } else { + // Fallback (note: division by zero is undefined) + for (size_t i = 0; i < n; ++i) { out[i] = x[i] / y; } + } +} + //===----------------------------------------------------------------------===// // Fill Zeros //===----------------------------------------------------------------------===// @@ -292,6 +423,12 @@ inline void fill_random_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t //===----------------------------------------------------------------------===// // Reduce //===----------------------------------------------------------------------===// +/// Sum-reduction over a strided FP32 buffer. +/// @param dst Output buffer receiving the reduction result(s). +/// @param src Input buffer. +/// @param src_stride Stride between consecutive source elements. +/// @param size Number of elements to reduce. +/// @param thread_count Requested number of threads (implementation may clamp). 
HWY_DLLEXPORT void call_reduce_sum_fp32(mllm_fp32_t* dst, const mllm_fp32_t* src, size_t src_stride, size_t size, int32_t thread_count); } // namespace mllm::cpu::common diff --git a/mllm/backends/cpu/ops/ElewiseOps.cpp b/mllm/backends/cpu/ops/ElewiseOps.cpp index e751b197..fd743094 100644 --- a/mllm/backends/cpu/ops/ElewiseOps.cpp +++ b/mllm/backends/cpu/ops/ElewiseOps.cpp @@ -144,20 +144,20 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_add_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_fp32(output.ptr(), input0.ptr(), input1.ptr(), + cpu::common::elewise_add_anytype(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #else - NYI("AddOp not supported on this architecture."); + NYI("AddOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel(), options_.getThreads()); + output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_scl_fp32(output.ptr(), input0.ptr(), *input1.ptr(), + cpu::common::elewise_add_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), output.numel()); #else - NYI("AddOp not supported on this architecture."); + NYI("AddOp not supported on this architecture."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); @@ -184,11 +184,11 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_add_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + 
cpu::common::elewise_add_anytype(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else - NYI("AddOp not supported on this architecture."); + NYI("AddOp not supported on this architecture."); #endif } else { NYI("AddOp broadcast not supported."); @@ -202,14 +202,14 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_add_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("AddOp fp16 not supported on x86 architecture yet."); + NYI("AddOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_add_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("AddOp fp16 not supported on x86 architecture yet."); + NYI("AddOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("AddOp broadcast not supported."); @@ -223,16 +223,16 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_add_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_int32(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); + cpu::common::elewise_add_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_scl_int32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); + 
cpu::common::elewise_add_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("AddOp broadcast not supported."); @@ -246,16 +246,16 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_add_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_int16(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); + cpu::common::elewise_add_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_scl_int16(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); + cpu::common::elewise_add_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("AddOp broadcast not supported."); @@ -269,15 +269,15 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_add_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); + cpu::common::elewise_add_anytype(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_add_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_add_scl_int8(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); + 
cpu::common::elewise_add_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("AddOp broadcast not supported."); @@ -301,7 +301,6 @@ void CPUAddOp::forward(const std::vector& inputs, std::vector& o const float* a = input0.ptr(); const mllm_complex_fp32_t* b = input1.ptr(); mllm_complex_fp32_t* out = output.ptr(); - #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { @@ -345,22 +344,22 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o if (input0.numel() == input1.numel()) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), - options_.getThreads()); + options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_fp32(output.ptr(), input0.ptr(), input1.ptr(), + cpu::common::elewise_sub_anytype(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #else - NYI("SubOp not supported on this architecture."); + NYI("SubOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel(), options_.getThreads()); + output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_scl_fp32(output.ptr(), input0.ptr(), *input1.ptr(), + cpu::common::elewise_sub_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), output.numel()); #else - NYI("SubOp not supported on this architecture."); + NYI("SubOp not supported on this architecture."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); @@ -377,7 +376,7 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_sub_fp32(out + 
out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } - } + } #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { @@ -387,11 +386,11 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_sub_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::common::elewise_sub_anytype(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else - NYI("SubOp not supported on this architecture."); + NYI("SubOp not supported on this architecture."); #endif } else { NYI("SubOp broadcast not supported."); @@ -405,14 +404,14 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_sub_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("SubOp fp16 not supported on x86 architecture yet."); + NYI("SubOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_sub_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("SubOp fp16 not supported on x86 architecture yet."); + NYI("SubOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("SubOp broadcast not supported."); @@ -426,16 +425,16 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_sub_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || 
defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_int32(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); + cpu::common::elewise_sub_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_scl_int32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); + cpu::common::elewise_sub_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("SubOp broadcast not supported."); @@ -449,16 +448,16 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_sub_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_int16(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); + cpu::common::elewise_sub_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_scl_int16(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); + cpu::common::elewise_sub_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("SubOp broadcast not supported."); @@ -472,15 +471,15 @@ void CPUSubOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_sub_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) 
|| defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); + cpu::common::elewise_sub_anytype(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_sub_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_sub_scl_int8(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); + cpu::common::elewise_sub_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("SubOp broadcast not supported."); @@ -548,22 +547,22 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o if (input0.numel() == input1.numel()) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), - options_.getThreads()); + options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_mul_fp32(output.ptr(), input0.ptr(), input1.ptr(), + cpu::common::elewise_mul_anytype(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #else - NYI("MulOp not supported on this architecture."); + NYI("MulOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel(), options_.getThreads()); + output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_mul_scl_fp32(output.ptr(), input0.ptr(), *input1.ptr(), + cpu::common::elewise_mul_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), output.numel()); #else - NYI("MulOp not supported on this 
architecture."); + NYI("MulOp not supported on this architecture."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); @@ -580,21 +579,21 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_mul_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } - } + } #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { - // Each batch processes broadcast_naive_loops iterations of vector_size elements + // Each batch processes broadcast_naive_loops iterations of vector_size elements for (int l = 0; l < broadcast_naive_loops; ++l) { size_t a_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_mul_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::common::elewise_mul_anytype(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else - NYI("MulOp not supported on this architecture."); + NYI("MulOp not supported on this architecture."); #endif } else { NYI("MulOp broadcast not supported."); @@ -608,14 +607,14 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_mul_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("MulOp fp16 not supported on x86 architecture yet."); + NYI("MulOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_mul_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif 
defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("MulOp fp16 not supported on x86 architecture yet."); + NYI("MulOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("MulOp broadcast not supported."); @@ -629,16 +628,16 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_mul_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_mul_int32(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); + cpu::common::elewise_mul_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_mul_scl_int32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); + cpu::common::elewise_mul_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("MulOp broadcast not supported."); @@ -652,16 +651,16 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_mul_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_mul_int16(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); + cpu::common::elewise_mul_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - 
cpu::common::call_elewise_mul_scl_int16(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); + cpu::common::elewise_mul_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("MulOp broadcast not supported."); @@ -675,15 +674,15 @@ void CPUMulOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_mul_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_mul_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); + cpu::common::elewise_mul_anytype(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_mul_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_mul_scl_int8(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); + cpu::common::elewise_mul_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("MulOp broadcast not supported."); @@ -751,22 +750,22 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o if (input0.numel() == input1.numel()) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_fp32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), - options_.getThreads()); + options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_div_fp32(output.ptr(), input0.ptr(), input1.ptr(), + cpu::common::elewise_div_anytype(output.ptr(), input0.ptr(), input1.ptr(), output.numel()); #else - NYI("DivOp not supported on this architecture."); + NYI("DivOp not supported on this architecture."); #endif } else if (input1.numel() == 1) { #if 
defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_fp32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel(), options_.getThreads()); + output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_div_scl_fp32(output.ptr(), input0.ptr(), *input1.ptr(), + cpu::common::elewise_div_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), output.numel()); #else - NYI("DivOp not supported on this architecture."); + NYI("DivOp not supported on this architecture."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); @@ -783,21 +782,21 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_div_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size, options_.getThreads()); } - } + } #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) // Process each batch separately for (int batch = 0; batch < batch_dims; ++batch) { - // Each batch processes broadcast_naive_loops iterations of vector_size elements + // Each batch processes broadcast_naive_loops iterations of vector_size elements for (int l = 0; l < broadcast_naive_loops; ++l) { size_t a_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; size_t b_offset = batch * vector_size; // b doesn't broadcast over loops dimension size_t out_offset = batch * broadcast_naive_loops * vector_size + l * vector_size; - cpu::common::call_elewise_div_fp32(out + out_offset, a + a_offset, b + b_offset, vector_size); + cpu::common::elewise_div_anytype(out + out_offset, a + a_offset, b + b_offset, vector_size); } } #else - NYI("DivOp not supported on this architecture."); + NYI("DivOp not supported on this architecture."); #endif } else { NYI("DivOp broadcast not supported."); @@ -811,14 +810,14 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_div_fp16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), 
options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("DivOp fp16 not supported on x86 architecture yet."); + NYI("DivOp fp16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) cpu::arm::ew_div_fp16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("DivOp fp16 not supported on x86 architecture yet."); + NYI("DivOp fp16 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast not supported."); @@ -832,16 +831,16 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_div_int32(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_div_int32(output.ptr(), input0.ptr(), input1.ptr(), - output.numel()); + cpu::common::elewise_div_anytype(output.ptr(), input0.ptr(), input1.ptr(), + output.numel()); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int32_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - cpu::common::call_elewise_div_scl_int32(output.ptr(), input0.ptr(), *input1.ptr(), - output.numel()); + cpu::common::elewise_div_scl_anytype(output.ptr(), input0.ptr(), *input1.ptr(), + output.numel()); #endif } else { NYI("DivOp broadcast not supported."); @@ -855,14 +854,14 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_div_int16(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("DivOp 
int16 not supported on x86 architecture yet."); + NYI("DivOp int16 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int16_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("DivOp int16 not supported on x86 architecture yet."); + NYI("DivOp int16 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast not supported."); @@ -876,14 +875,14 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_div_int8(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("DivOp int8 not supported on x86 architecture yet."); + NYI("DivOp int8 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_int8_scalar(output.ptr(), input0.ptr(), *input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("DivOp int8 not supported on x86 architecture yet."); + NYI("DivOp int8 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast not supported."); @@ -898,14 +897,14 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o cpu::arm::ew_div_fp32_complex(output.ptr(), input0.ptr(), input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("DivOp complex fp32 not supported on x86 architecture yet."); + NYI("DivOp complex fp32 not supported on x86 architecture yet."); #endif } else if (input1.numel() == 1) { #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) cpu::arm::ew_div_fp32_complex_scalar(output.ptr(), input0.ptr(), 
*input1.ptr(), output.numel(), options_.getThreads()); #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("DivOp complex fp32 not supported on x86 architecture yet."); + NYI("DivOp complex fp32 not supported on x86 architecture yet."); #endif } else if (can_be_broadcast_naive) { const float* a = input0.ptr(); @@ -925,7 +924,7 @@ void CPUDivOp::forward(const std::vector& inputs, std::vector& o } } #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - NYI("DivOp complex fp32 not supported on x86 architecture yet."); + NYI("DivOp complex fp32 not supported on x86 architecture yet."); #endif } else { NYI("DivOp broadcast for complex output not supported."); From 3b5c08813824d25bf909f2dd324a6019b0b6cedc Mon Sep 17 00:00:00 2001 From: HayzelHan Date: Sun, 15 Feb 2026 15:32:37 +0000 Subject: [PATCH 5/5] feat: add reduce over all dimensions for reduce operator --- mllm/backends/cpu/kernels/common/reduce-inl.hpp | 6 +++--- mllm/backends/cpu/ops/MatMulOp.cpp | 4 ++-- mllm/backends/cpu/ops/ReduceOps.cpp | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mllm/backends/cpu/kernels/common/reduce-inl.hpp b/mllm/backends/cpu/kernels/common/reduce-inl.hpp index 2c381ae0..357c0b0c 100644 --- a/mllm/backends/cpu/kernels/common/reduce-inl.hpp +++ b/mllm/backends/cpu/kernels/common/reduce-inl.hpp @@ -7,7 +7,7 @@ HWY_BEFORE_NAMESPACE(); namespace mllm::cpu::common { // NOLINT namespace HWY_NAMESPACE { -namespace hn = hwy::HWY_NAMESPACE; +namespace hn = hwy::HWY_NAMESPACE; struct ScalarAddOp { template HWY_INLINE T operator()(T a, T b) const { return a + b; } }; @@ -59,7 +59,7 @@ struct VecSumReduce { template -HWY_INLINE T reduce_impl(const T* HWY_RESTRICT src, size_t src_stride, size_t size, +HWY_INLINE T reduce_impl(const T* HWY_RESTRICT src, size_t src_stride, size_t size, ScalarOp&& scalar_op, VectorOp&& vec_op, VectorReduceOp&& vec_reduce_op) { if (size == 0) return T(0); @@ -99,7 +99,7 @@ HWY_INLINE T reduce_impl(const 
T* HWY_RESTRICT src, size_t src_stride, size_t si return vec_reduce_op(d, vec_result); } - + // Scalar path (stride != 1 or too small) T scalar_result = src[0]; for (size_t i = 1; i < size; ++i) { diff --git a/mllm/backends/cpu/ops/MatMulOp.cpp b/mllm/backends/cpu/ops/MatMulOp.cpp index b22549fb..4f4cc0ef 100644 --- a/mllm/backends/cpu/ops/MatMulOp.cpp +++ b/mllm/backends/cpu/ops/MatMulOp.cpp @@ -50,7 +50,7 @@ void CPUMatMulOp::forward(const std::vector& inputs, std::vector mt = aops::MatMulOpType::kBLAS; #else if (!transpose_a && transpose_b) { - // TODO kGGUF still buggy !!! + // TODO: kGGUF still buggy !!! mt = aops::MatMulOpType::kGGUF; } else // All fallback to mllm blas @@ -121,7 +121,7 @@ void CPUMatMulOp::forward(const std::vector& inputs, std::vector // o.ptr(), lhs.ptr(), rhs.ptr(), nullptr, // transpose_a, transpose_b); // } -// } +// } #else NYI("MllmBlas only support MLLM_HOST_ARCH_ARM64 or MLLM_HOST_ARCH_ARM right now.") #endif diff --git a/mllm/backends/cpu/ops/ReduceOps.cpp b/mllm/backends/cpu/ops/ReduceOps.cpp index e9eb325d..c42f4070 100644 --- a/mllm/backends/cpu/ops/ReduceOps.cpp +++ b/mllm/backends/cpu/ops/ReduceOps.cpp @@ -294,8 +294,8 @@ void CPUReduceSumOp::forward(const std::vector& inputs, std::vector(), input.ptr(), 1, input.numel(), options_.getThreads()); -#elif defined(MLLM_HOST_ARCH_X86) || defined(MLLM_HOST_ARCH_X86_64) - NYI("ReduceSumOp not implemented for x86/x86_64 yet."); +#elif defined(MLLM_HOST_ARCH_X86) || defined(MLLM_HOST_ARCH_X86_64) + common::call_reduce_sum_fp32(output.ptr(), input.ptr(), 1, input.numel(), options_.getThreads()); #endif break; } @@ -348,7 +348,7 @@ void CPUReduceSumOp::forward(const std::vector& inputs, std::vector