
feat(x86,cpu) add support for Qwen3 MoE model on x86 #619

Merged
oreomaker merged 5 commits into UbiquitousLearning:main from HayzelHan:feat/qwen3-30b-a3b
Feb 17, 2026

Conversation


@HayzelHan HayzelHan commented Jan 30, 2026

  • Add support for the Qwen3 MoE model (tested with Qwen3-30B-A3B) on x86
  • Add elewise/reduce kernel implementations for common use
  • Remove the matmul restriction for vecdot

Summary by CodeRabbit

  • New Features

    • Added Qwen3 MoE model support with tokenizer, configuration, and an interactive example CLI
    • Added quantization config for Qwen3 MoE and packaging support
  • Improvements

    • Added integer elementwise arithmetic (8/16/32-bit) and scalar-by-constant variants
    • Introduced SIMD-accelerated reduction API and broader CPU kernel dispatch coverage


coderabbitai bot commented Jan 30, 2026

📝 Walkthrough

Adds a Qwen3 Mixture-of-Experts model (tokenizer, config, model, example CLI) and expands CPU backend kernels with integer elementwise operations, scalar-by-constant variants, and a SIMD-accelerated reduction API.

Changes

  • Example build & runner — examples/CMakeLists.txt, examples/qwen3_moe/CMakeLists.txt: register the new qwen3_moe example and add the mllm-qwen3-moe-runner executable with proper links and include dirs.
  • Example assets & CLI — examples/qwen3_moe/config_30B_A3B_gguf.json, examples/qwen3_moe/quant_cfg_30B_q4_k.json, examples/qwen3_moe/main.cpp: add model config and quantization hints; new interactive CLI that loads tokenizer/config/model, streams generation, and optionally emits perf traces.
  • Model config & public API — mllm/models/qwen3_moe/configuration_qwen3_moe.hpp: introduce the Qwen3MoeConfig struct with MoE-specific parameters and a mapping for the linear impl type.
  • Model implementation — mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp: add the Qwen3 MoE model: RoPE utilities, attention, MoE gate/routing, expert ensemble, decoder stack, and causal LM wrapper.
  • Tokenizer — mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp: add Qwen3Tokenizer with pattern matching, BPE, Unicode mapping, message conversion, and detokenization utilities.
  • Elementwise kernels (impl) — mllm/backends/cpu/kernels/common/elewise-inl.hpp: reintroduce the generic elementwise_impl and expand the public elementwise API with int8/16/32 ops and scalar-by-constant (scl_) variants; remove legacy scalar-Op structs; FP16 variants left commented.
  • Kernel dispatch & exports — mllm/backends/cpu/kernels/common/kernel_dispatch.hpp, .../kernel_dispatch.cpp: add new call/export symbols for int32/int16/int8 elewise and scl variants; remove older scalar_fp32 exports; expose reduce_sum_fp32 and corresponding call wrappers.
  • Reduction (SIMD) — mllm/backends/cpu/kernels/common/reduce-inl.hpp: add a Highway-based SIMD reduction implementation with a 4x-unrolled vector path and a reduce_sum_fp32 entry point.
  • Quantize helper — mllm/backends/cpu/kernels/common/ggml/quantize/quantize.hpp: change the lookup_fp16_to_fp32 parameter type from mllm_fp16_t to uint16_t.
  • Ops dispatch & usage — mllm/backends/cpu/ops/ElewiseOps.cpp, mllm/backends/cpu/ops/ReduceOps.cpp: reorder architecture-specific dispatches (ARM vs x86) for elementwise ops; route reduce calls on x86/x64 to call_reduce_sum_fp32.
  • MatMul fallback tweak — mllm/backends/cpu/ops/MatMulOp.cpp: remove the M >= 4 requirement from the GGUF fallback condition; add a commented x86 BLAS alternative.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant CLI as examples/qwen3_moe/main.cpp
    participant Tokenizer as Qwen3Tokenizer
    participant Model as Qwen3MoeForCausalLM
    participant Router as Qwen3MoE
    participant Output as Console

    User->>CLI: enter prompt
    CLI->>Tokenizer: tokenize(prompt)
    Tokenizer-->>CLI: token_ids
    CLI->>Model: forward(token_ids)
    Model->>Model: embed + RoPE + cache
    Model->>Router: route per-token (MoEGate)
    Router->>Router: dispatch to experts, compute weighted outputs
    Router-->>Model: aggregated moe_output
    Model-->>CLI: logits / token stream
    CLI->>Tokenizer: detokenize(token_id)
    Tokenizer-->>Output: token_text
    Output->>User: display streamed response
sequenceDiagram
    participant App as Caller
    participant Dispatch as kernel_dispatch
    participant Arch as ArchSelector
    participant ARM as ARM Path (NEON)
    participant X86 as Common Impl (Highway)
    participant Impl as elewise-inl.hpp

    App->>Dispatch: call_elewise_add_int32(...)
    Dispatch->>Arch: detect architecture
    alt ARM
        Arch->>ARM: call ARM-specific ew_add_int32
        ARM->>Impl: NEON vector ops
        Impl-->>ARM: result
        ARM-->>Dispatch: done
    else x86/x64
        Arch->>X86: call elementwise_impl via Highway
        X86->>Impl: vectorized elementwise_impl
        Impl-->>X86: result
        X86-->>Dispatch: done
    end
    Dispatch-->>App: output buffer filled

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • yirongjie
  • chenghuaWang
  • oreomaker

Poem

🐰 A rabbit's hop for code and cheer
Experts gather, swift and near,
Integers hum, reductions sing,
Token streams take gentle wing,
Hooray—our MoE springs to gear! 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 4.65%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Description check — ❓ Inconclusive: the PR description lists key changes but lacks the structure and detail recommended by the contribution template. Resolution: expand the description to follow the template, including context, testing details, and any breaking changes or migration notes.
✅ Passed checks (2 passed)
  • Title check — ✅ Passed: the title clearly summarizes the main change, adding Qwen3 MoE model support on x86 CPU.
  • Merge Conflict Detection — ✅ Passed: no merge conflicts detected when merging into main.


@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 15

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
mllm/backends/cpu/ops/MatMulOp.cpp (1)

52-59: ⚠️ Potential issue | 🟠 Major

Add contiguity guard and fix TODO format and trailing whitespace.

The default fallback routes all (!transpose_a && transpose_b) cases to GGUF (line 53), but the GGUF path throws NYI for non‑contiguous tensors (line 77). This will crash at runtime instead of safely falling back to MllmBlas. Add a contiguity check to the fallback condition and fix the TODO format and trailing whitespace:

  • Line 53: TODO comment missing colon—change // TODO to // TODO:
  • Line 124: Remove trailing whitespace
  • Lines 52–59: Guard GGUF selection with contiguity check:
    -    if (!transpose_a && transpose_b) {
    -      // TODO kGGUF still buggy !!!
    +    if (!transpose_a && transpose_b && lhs.isContiguousN(0)) {
    +      // TODO: kGGUF still buggy !!!
           mt = aops::MatMulOpType::kGGUF;
mllm/backends/cpu/ops/ElewiseOps.cpp (1)

142-281: ⚠️ Potential issue | 🔴 Critical

Remove trailing whitespace in the CPUAddOp dispatch block.
Lines 150 and 155 (and nearby lines in this block) end with trailing spaces; please trim trailing whitespace here.
As per coding guidelines, no line may end with trailing whitespace.

🤖 Fix all issues with AI agents
In `@examples/qwen3_moe/main.cpp`:
- Around line 44-63: The loop currently always processes input even when the
user types the advertised "exit" or "quit"; after reading prompt_text (variable
prompt_text) trim/normalize it and check for "exit" or "quit" (case-insensitive)
and if matched, break out of the interactive loop or return from main instead of
calling qwen3_moe_tokenizer.convertMessage and qwen3_moe.chat; place this check
immediately after std::getline and before the try block so convertMessage and
qwen3_moe.chat are not invoked for exit/quit.
- Around line 23-28: The code silently defaults file_version to
mllm::ModelFileVersion::kV1 when model_version contains an unknown string;
change the logic in the block that reads model_version/get() so it validates the
value and handles unknown values explicitly (e.g., log an error and abort or
throw an exception) instead of falling through to kV1. Specifically, inspect the
model_version.get() branches that set file_version and add an else branch that
reports the invalid model_version value (using process logging or throwing) and
prevents loading the wrong format; ensure references to model_version,
file_version and mllm::ModelFileVersion are used so the error message identifies
the offending string.

In `@mllm/backends/cpu/kernels/common/kernel_dispatch.hpp`:
- Around line 292-295: The public function call_reduce_sum_fp32 lacks a
docstring explaining parameter contracts; add a concise comment above its
declaration that documents: what dst and src point to, that src_stride is the
distance in elements (mllm_fp32_t units) between consecutive rows/segments in
src, that size is the number of elements per segment to reduce, and the exact
semantics and valid range for thread_count (e.g., 1 = single-threaded, 0 or -1 =
auto-select max hardware threads if supported, upper bound is the number of
logical workers), plus any alignment, concurrency or error expectations (e.g.,
caller must ensure dst has space for results and thread_count must be >0 or
documented special value). Reference the symbol call_reduce_sum_fp32 when adding
the doc so callers know the contract.
- Around line 26-74: Public HWY_DLLEXPORT elementwise entry points (e.g.,
call_elewise_add_int32 / call_elewise_sub_int8 and the scalar variants like
call_elewise_add_scl_fp32 / call_elewise_div_scl_int8) lack documentation; add
short Doxygen-style blocks above the grouped declarations for both "Elementwise
+ - * / By Vector" and "Elementwise + - * / By Const" sections that state the
purpose, describe each parameter (out, x, y, n) and expected buffer lengths (n
elements), explain scalar semantics for _scl_ variants (scalar applied to each
element of x), define behavior on division (division-by-zero handling or
undefined behavior), and note any data-type-specific caveats (e.g., fp16 TODOs
or integer overflow behavior) so callers understand inputs, outputs, and error
expectations.
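Such a contract comment might look like the following sketch; the signature and the thread_count semantics are assumed from the review text, not copied from the actual header:

```cpp
/// @brief Sums `size` fp32 elements of `src` into `*dst`.
/// @param dst          Output; receives the single fp32 sum.
/// @param src          Input buffer of at least size * src_stride elements.
/// @param src_stride   Distance, in elements (mllm_fp32_t units), between
///                     consecutive reads.
/// @param size         Number of elements to reduce.
/// @param thread_count Worker count; 1 means single-threaded (assumed).
HWY_DLLEXPORT void call_reduce_sum_fp32(mllm_fp32_t* dst, const mllm_fp32_t* src,
                                        size_t src_stride, size_t size,
                                        int32_t thread_count);
```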

In `@mllm/backends/cpu/kernels/common/reduce-inl.hpp`:
- Around line 10-107: Trim trailing whitespace on the two affected lines in this
header: remove the extra space at the end of the line containing "namespace hn =
hwy::HWY_NAMESPACE;" and the trailing space at the end of the line containing
"return vec_reduce_op(d, vec_result);". Update the file
mllm/backends/cpu/kernels/common/reduce-inl.hpp so no line ends with trailing
spaces (verify with your editor or run a whitespace linter), then save and
re-run formatting checks.

In `@mllm/backends/cpu/ops/MatMulOp.cpp`:
- Line 124: Remove the trailing whitespace at the end of the commented line
containing "//       }" in MatMulOp.cpp (the stray spaces after the closing
brace comment); edit the commented line so it ends immediately after the comment
text with no spaces or tabs following it to comply with the "no trailing
whitespace" rule.
- Around line 53-54: The inline comment before assigning mt =
aops::MatMulOpType::kGGUF is using "TODO" without the required colon; update the
comment to use the mandated format by changing "// TODO kGGUF still buggy !!!"
to start with "TODO:" (e.g. "TODO: kGGUF still buggy") near the assignment to mt
in MatMulOp.cpp so it conforms to the "TODO:" convention while leaving the rest
of the message intact.

In `@mllm/backends/cpu/ops/ReduceOps.cpp`:
- Around line 349-351: The lines in the x86/x86_64 branch calling
cpu::common::call_reduce_sum_fp32 in ReduceOps.cpp contain trailing whitespace
at the end of the line(s); remove those trailing spaces so no line ends with
extraneous whitespace (trim the end of the line(s) around the call to
call_reduce_sum_fp32 and the surrounding preprocessor block).
- Around line 294-299: The x86/x86_64 branch for kFloat32 currently calls NYI
but should perform a full-tensor FP32 sum like the ARM path; replace the NYI
with a call to the existing call_reduce_sum_fp32 helper using the same arguments
as the ARM implementation (output.ptr<mllm_fp32_t>(), input.ptr<mllm_fp32_t>(),
1, input.numel(), options_.getThreads()), ensuring the ReduceSum codepath for
FP32 uses call_reduce_sum_fp32 rather than NYI.

In `@mllm/models/qwen3_moe/configuration_qwen3_moe.hpp`:
- Around line 10-73: Add Doxygen-style docstrings to the public API: document
the Qwen3MoeConfig struct itself, the explicit constructor Qwen3MoeConfig(const
std::string& file_path), and each publicly visible config field (e.g.,
attention_bias, hidden_size, head_dim, intermediate_size, num_attention_heads,
num_key_value_heads, num_hidden_layers, max_position_embeddings, rms_norm_eps,
vocab_size, bos_token_id, eos_token_id, rope_theta, tie_word_embeddings,
max_cache_length, end_of_text_token_id, thinking_start_token_id,
thinking_end_token_id, num_experts, num_experts_per_tok, moe_intermediate_size,
norm_topk_prob, decoder_sparse_step, mlp_only_layers, linear_impl_type). For
each give a one-line summary of purpose, any important value ranges or units,
and note that the constructor reads values from the ConfigFile; mention expected
behavior on missing/invalid entries (e.g., defaults are used). Ensure comments
use Doxygen tags (`@brief`, `@param`, `@note` or `@throws` where applicable) and are
placed immediately above the struct, constructor, and each field declaration.
- Around line 26-41: The constructor currently reads bos_token_id and
eos_token_id but does not load end_of_text_token_id or any thinking_* fields
from the JSON, so JSON overrides are ignored; modify the same initialization
block (near bos_token_id, eos_token_id, rope_theta, tie_word_embeddings,
max_cache_length, num_experts, num_experts_per_tok, moe_intermediate_size,
norm_topk_prob, decoder_sparse_step, mlp_only_layers, linear_impl_type) to also
assign end_of_text_token_id and all thinking_* configuration keys from data()
(using the same access pattern and appropriate get<...>() types as used for
mlp_only_layers) so those tokens honor JSON overrides.

In `@mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp`:
- Around line 15-491: Add concise docstrings/comments above each public function
and class declaration to describe purpose, parameters, return values, and any
error/edge-case behaviors; specifically add comments for makeRoPEInvFreq,
makeRotaryPosEmbedding, Qwen3MoeMLP (and its constructor/forward), MoEGate
(constructor/forward), Qwen3MoE (constructor/forward/moeInfer),
Qwen3MoeAttention (constructor/forward), Qwen3MoeDecoder (constructor/forward),
Qwen3MoeText (constructor/forward), and Qwen3MoeForCausalLM
(constructor/forward/kvCache); keep each docstring short (one- to three-sentence
description plus brief param/return lines), place them immediately above the
corresponding function/class, and mention any important preconditions (e.g.,
expected tensor shapes, dtype requirements, or when exceptions/assertions are
raised).
- Around line 197-199: The line assigning sorted_tokens using auto sorted_tokens
= x[{idxs / topk_ids.size(1), {kAll}}]; is flagged as uncertain and the inline
TODO must be converted to the required "TODO:" format and clarified; verify that
the batch index computation uses the correct dimension (ensure idxs is divided
by the correct topk_ids dimension and that the division yields integer batch
indices) and confirm that x[{...}] returns tokens in the intended score-sorted
order (adjust the index expression if the batch/beam mapping is reversed), then
replace the comment with a clear "TODO: confirm token ordering and correct batch
index computation for sorted_tokens (uses idxs, topk_ids.size(1), x and kAll)"
describing what to validate or change.
- Around line 345-347: The code uses modulo with cfg.decoder_sparse_step in the
condition around mlp_opt0_ initialization which can divide by zero; update the
relevant initialization/constructor logic (the block that checks
cfg.num_experts, layer_idx_ and uses (layer_idx_+1) % cfg.decoder_sparse_step)
to first validate cfg.decoder_sparse_step is > 0 and either throw a clear
runtime_error (fail fast) or set a safe fallback value (e.g., treat zero as 1)
before performing the modulo; reference cfg.decoder_sparse_step,
cfg.num_experts, layer_idx_ and the mlp_opt0_ assignment so you modify the exact
conditional guarding mlp_opt0_ initialization.

In `@mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp`:
- Around line 19-267: Add brief doc comments to all public-facing symbols: the
free functions qwen3TokenizerMatchPattern and qwen3Regex, the Qwen3Message
struct, and the Qwen3Tokenizer public API (constructor Qwen3Tokenizer,
_tokenize, tokenize, _detokenize, detokenize, convert2Ids, convertMessage). For
each, state one-line purpose, parameters (e.g., str, pos, matched, message),
return value, and any error/edge-case behavior (e.g., returns false or empty
vector on no match/input). Place comments above the declarations so they
document intent, inputs, outputs and error conditions per coding guidelines.
🧹 Nitpick comments (3)
mllm/backends/cpu/ops/MatMulOp.cpp (1)

113-124: Consider removing or compile‑guarding the commented‑out x86 path.

The large commented block adds dead code and makes the dispatch harder to reason about. If this is intended for later work, consider a proper #if guard or a follow‑up PR instead of commented‑out code.

🧹 Example cleanup (remove commented block)
-// #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)
-//       if (lhs.dtype() == kFloat32 && rhs.dtype() == kFloat32 && o.dtype() == kFloat32) {
-//         if (batch_count == 1) {
-//           x86::mllm_blas_matmul_fp32(M, K, N, o.ptr<mllm_fp32_t>(), lhs.ptr<mllm_fp32_t>(), rhs.ptr<mllm_fp32_t>(), nullptr,
-//                                        transpose_a, transpose_b);
-//         } else {
-//           x86::mllm_blas_batch_matmul_fp32(batch_count, M, K, N, o.stride()[o.shape().size() - 3],
-//                                               lhs.stride()[lhs_shape.size() - 3], rhs.stride()[rhs_shape.size() - 3], 0,
-//                                               o.ptr<mllm_fp32_t>(), lhs.ptr<mllm_fp32_t>(), rhs.ptr<mllm_fp32_t>(), nullptr,
-//                                               transpose_a, transpose_b);
-//         }
-//       }  
mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp (1)

163-179: Consider registering <unk> as a special token for parity.

mllm/models/qwen_npu/tokenization_qwen.hpp:178-183 adds <unk> to the special-token trie; adding it here keeps special-token handling consistent when upstream emits <unk>.

♻️ Suggested change
     special_tokens_trie_.add(L"<|endoftext|>");
     special_tokens_trie_.add(L"<|im_start|>");
     special_tokens_trie_.add(L"<|im_end|>");
+    special_tokens_trie_.add(L"<unk>");
mllm/backends/cpu/kernels/common/reduce-inl.hpp (1)

113-118: Silence the unused thread_count parameter.
thread_count is currently unused and may trigger warnings; consider casting to void until parallel reduction is implemented.

🛠️ Proposed fix
 HWY_NOINLINE HWY_MAYBE_UNUSED void reduce_sum_fp32(mllm_fp32_t* dst,const mllm_fp32_t* src,
 size_t src_stride, size_t size, int32_t thread_count) {
+  (void)thread_count;
   const mllm_fp32_t v = reduce_impl<mllm_fp32_t>(src, src_stride, size,
       ScalarAddOp{}, VecAddOp{}, VecSumReduce{});
   *dst = v;
 }

Comment on lines +44 to +63
  fmt::print("\n{:*^60}\n", " Qwen3 MoE Interactive CLI ");
  fmt::print("Enter 'exit' or 'quit' to end the session\n\n");

  std::string prompt_text;

  fmt::print("💬 Prompt text (or 'exit/quit'): ");
  std::getline(std::cin, prompt_text);

  try {
    fmt::print("🔄 Processing...\n");
    auto inputs = qwen3_moe_tokenizer.convertMessage({.prompt = prompt_text});

    fmt::print("\n🤖 Response: ");

    // Use for loop
    for (auto& step : qwen3_moe.chat(inputs)) { std::wcout << qwen3_moe_tokenizer.detokenize(step.cur_token_id) << std::flush; }

    fmt::print("\n{}\n", std::string(60, '-'));
  } catch (const std::exception& e) { fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-')); }

⚠️ Potential issue | 🟡 Minor

Honor the advertised exit/quit commands.
The prompt tells users these commands end the session, but the input is always processed.

🛠️ Proposed fix
     fmt::print("💬 Prompt text (or 'exit/quit'): ");
     std::getline(std::cin, prompt_text);
 
-    try {
-      fmt::print("🔄 Processing...\n");
-      auto inputs = qwen3_moe_tokenizer.convertMessage({.prompt = prompt_text});
-
-      fmt::print("\n🤖 Response: ");
-
-      // Use for loop
-      for (auto& step : qwen3_moe.chat(inputs)) { std::wcout << qwen3_moe_tokenizer.detokenize(step.cur_token_id) << std::flush; }
-
-      fmt::print("\n{}\n", std::string(60, '-'));
-    } catch (const std::exception& e) { fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-')); }
+    if (prompt_text == "exit" || prompt_text == "quit") {
+      fmt::print("👋 Bye!\n");
+    } else {
+      try {
+        fmt::print("🔄 Processing...\n");
+        auto inputs = qwen3_moe_tokenizer.convertMessage({.prompt = prompt_text});
+
+        fmt::print("\n🤖 Response: ");
+
+        // Use for loop
+        for (auto& step : qwen3_moe.chat(inputs)) { std::wcout << qwen3_moe_tokenizer.detokenize(step.cur_token_id) << std::flush; }
+
+        fmt::print("\n{}\n", std::string(60, '-'));
+      } catch (const std::exception& e) { fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-')); }
+    }

Comment on lines +10 to +107
namespace hn = hwy::HWY_NAMESPACE;


struct ScalarAddOp { template<typename T> HWY_INLINE T operator()(T a, T b) const { return a + b; } };

struct ScalarSubOp { template<typename T> HWY_INLINE T operator()(T a, T b) const { return a - b; } };

struct ScalarMulOp { template<typename T> HWY_INLINE T operator()(T a, T b) const { return a * b; } };

struct ScalarDivOp { template<typename T> HWY_INLINE T operator()(T a, T b) const { return a / b; } };

struct ScalarMaxOp { template<typename T> HWY_INLINE T operator()(T a, T b) const { return a > b ? a : b; } };

struct ScalarMinOp { template<typename T> HWY_INLINE T operator()(T a, T b) const { return a < b ? a : b; } };

struct VecAddOp {
  template<class D, class V>
  HWY_INLINE V operator()(D d, V a, V b) const { return hn::Add(a, b); }
};

struct VecSubOp {
  template<class D, class V>
  HWY_INLINE V operator()(D d, V a, V b) const { return hn::Sub(a, b); }
};

struct VecMulOp {
  template<class D, class V>
  HWY_INLINE V operator()(D d, V a, V b) const { return hn::Mul(a, b); }
};

struct VecDivOp {
  template<class D, class V>
  HWY_INLINE V operator()(D d, V a, V b) const { return hn::Div(a, b); }
};

struct VecMaxOp {
  template<class D, class V>
  HWY_INLINE V operator()(D d, V a, V b) const { return hn::Max(a, b); }
};

struct VecMinOp {
  template<class D, class V>
  HWY_INLINE V operator()(D d, V a, V b) const { return hn::Min(a, b); }
};

struct VecSumReduce {
  template <class D, class V>
  HWY_INLINE hn::TFromD<D> operator()(D d, V v) const { return hn::ReduceSum(d, v); }
};


template<typename T, typename ScalarOp, typename VectorOp, typename VectorReduceOp>
HWY_INLINE T reduce_impl(const T* HWY_RESTRICT src, size_t src_stride, size_t size,
                         ScalarOp&& scalar_op, VectorOp&& vec_op, VectorReduceOp&& vec_reduce_op) {
  if (size == 0) return T(0);

  const hn::ScalableTag<T> d;
  const size_t N = hn::Lanes(d);

  // SIMD fast path
  if (src_stride == 1 && size >= N) {
    using V = hn::Vec<decltype(d)>;

    // Init with first vector
    V vec_result = hn::LoadU(d, src);
    size_t i = N;

    // 4x unroll
    for (; i + 4 * N <= size; i += 4 * N) {
      const V v0 = hn::LoadU(d, src + i);
      const V v1 = hn::LoadU(d, src + i + N);
      const V v2 = hn::LoadU(d, src + i + 2 * N);
      const V v3 = hn::LoadU(d, src + i + 3 * N);

      vec_result = vec_op(d, vec_result, v0);
      vec_result = vec_op(d, vec_result, v1);
      vec_result = vec_op(d, vec_result, v2);
      vec_result = vec_op(d, vec_result, v3);
    }

    for (; i + N <= size; i += N) {
      const V v = hn::LoadU(d, src + i);
      vec_result = vec_op(d, vec_result, v);
    }

    if (i < size) {
      const V vt = hn::LoadN(d, src + i, size - i);
      vec_result = vec_op(d, vec_result, vt);
    }

    return vec_reduce_op(d, vec_result);
  }

  // Scalar path (stride != 1 or too small)
  T scalar_result = src[0];
  for (size_t i = 1; i < size; ++i) {
    scalar_result = scalar_op(scalar_result, src[i * src_stride]);
  }
⚠️ Potential issue | 🔴 Critical

Remove trailing whitespace in the new reduction code.
Line 10 and Line 62 appear to end with trailing spaces; please trim whitespace in this block.
As per coding guidelines, no line may end with trailing whitespace.


Comment on lines +26 to +41
bos_token_id = data()["bos_token_id"];
eos_token_id = data()["eos_token_id"];
rope_theta = data()["rope_theta"];

tie_word_embeddings = data()["tie_word_embeddings"];
max_cache_length = data()["max_cache_length"];

// MoE config
num_experts = data()["num_experts"];
num_experts_per_tok = data()["num_experts_per_tok"];
moe_intermediate_size = data()["moe_intermediate_size"];
norm_topk_prob = data()["norm_topk_prob"];
decoder_sparse_step = data()["decoder_sparse_step"];
mlp_only_layers = data()["mlp_only_layers"].get<std::vector<int>>();

linear_impl_type = aops::str2LinearImplTypes(data()["linear_impl_type"]);
⚠️ Potential issue | 🟠 Major

Config constructor ignores end/thinking token IDs.

Lines 26-39 load bos/eos but skip end_of_text_token_id and the thinking_* IDs; JSON overrides for those fields won’t apply even though the tokens are used later during generation. Consider loading them (or documenting that they are fixed defaults).

💡 Suggested fix
   bos_token_id = data()["bos_token_id"];
   eos_token_id = data()["eos_token_id"];
+  end_of_text_token_id = data()["end_of_text_token_id"];
+  thinking_start_token_id = data()["thinking_start_token_id"];
+  thinking_end_token_id = data()["thinking_end_token_id"];
   rope_theta = data()["rope_theta"];

Comment on lines +15 to +491
inline auto makeRoPEInvFreq(int output_dim, float rope_theta) -> Tensor {
auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc();
auto inv_freq_ptr = inv_freq.ptr<float>();
for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0 / std::pow(rope_theta, 2.0 * i / output_dim); }
return inv_freq;
}

inline auto makeRotaryPosEmbedding(Tensor& position_ids, const Tensor& inv_freq,
float attention_scaling = 1.0f) -> std::pair<Tensor, Tensor> {
auto batch_size = position_ids.shape()[0];
auto seq_len = position_ids.shape()[1];
auto inv_freq_len = inv_freq.shape()[0];
auto dim = inv_freq_len * 2;

// Create freqs tensor: position_ids @ inv_freq
auto freqs = Tensor::empty({batch_size, seq_len, inv_freq_len}, kFloat32, kCPU).alloc();
auto freqs_ptr = freqs.ptr<float>();
auto position_ids_ptr = position_ids.ptr<int64_t>();
auto inv_freq_ptr = inv_freq.ptr<float>();

// Compute freqs = position_ids[:, :, None] @ inv_freq[None, :]
for (int b = 0; b < batch_size; ++b) {
for (int s = 0; s < seq_len; ++s) {
auto pos = position_ids_ptr[b * seq_len + s];
for (int d = 0; d < inv_freq_len; ++d) {
freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d] = static_cast<float>(pos) * inv_freq_ptr[d];
}
}
}

// Create sin and cos tensors with shape [batch_size, seq_len, dim]
auto sin_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc();
auto cos_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc();
auto sin_ptr = sin_emb.ptr<float>();
auto cos_ptr = cos_emb.ptr<float>();

// Compute sin and cos embeddings: emb = [freqs, freqs]
for (int b = 0; b < batch_size; ++b) {
for (int s = 0; s < seq_len; ++s) {
for (int d = 0; d < inv_freq_len; ++d) {
auto freq = freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d];
auto sin_val = std::sin(freq) * attention_scaling;
auto cos_val = std::cos(freq) * attention_scaling;

// Store the same values in both halves: [freqs, freqs]
sin_ptr[b * seq_len * dim + s * dim + d] = sin_val;
sin_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = sin_val;
cos_ptr[b * seq_len * dim + s * dim + d] = cos_val;
cos_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = cos_val;
}
}
}

return {sin_emb, cos_emb};
}

class Qwen3MoeMLP final : public nn::Module {
nn::Linear gate_proj_;
nn::Linear up_proj_;
nn::Linear down_proj_;
nn::SiLU act_;

int hidden_size_;
int intermediate_size_;

public:
Qwen3MoeMLP() = default;

explicit Qwen3MoeMLP(const std::string& name, const Qwen3MoeConfig& config,
const std::optional<int>& hidden_size = std::nullopt,
const std::optional<int>& intermediate_size = std::nullopt)
: nn::Module(name) {
hidden_size_ = hidden_size.value_or(config.hidden_size);
intermediate_size_ = intermediate_size.value_or(config.intermediate_size);

// clang-format off
gate_proj_ = reg<nn::Linear>("gate_proj", hidden_size_, intermediate_size_, false, config.linear_impl_type);
up_proj_ = reg<nn::Linear>("up_proj", hidden_size_, intermediate_size_, false, config.linear_impl_type);
down_proj_ = reg<nn::Linear>("down_proj", intermediate_size_, hidden_size_, false, config.linear_impl_type);
act_ = reg<nn::SiLU>("act");
// clang-format on
}

std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
return {down_proj_(act_(gate_proj_(inputs[0])) * up_proj_(inputs[0]))};
}
};

class MoEGate final : public nn::Module {
int top_k_;
int num_experts_;
bool norm_topk_prob_;

nn::Param weight_;

public:
MoEGate() = default;

MoEGate(const std::string& name, const Qwen3MoeConfig& config) : nn::Module(name) {
top_k_ = config.num_experts_per_tok;
num_experts_ = config.num_experts;
norm_topk_prob_ = config.norm_topk_prob;

weight_ = reg<nn::Param>("weight", getModuleName() + ".weight");
}

std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto hidden_states = inputs[0];
auto bsz = hidden_states.size(0);
auto seq_len = hidden_states.size(1);
auto h = hidden_states.size(2);

// Compute gating score
hidden_states = hidden_states.view({-1, h});
// hidden_states and weight must be in fp32 to keep precision !!!
auto logits = nn::functional::matmul(hidden_states, weight_.weight(), false, true);
auto scores = nn::functional::softmax(logits, -1);
auto [topk_weight, topk_idx] = nn::functional::topk(scores, top_k_, -1, true, false);

if(norm_topk_prob_){
topk_weight = topk_weight / topk_weight.sum(-1, true);
}

return {topk_idx, topk_weight};
}
};

class Qwen3MoE final : public nn::Module {
int num_experts_per_tok_;
nn::ModuleList<Qwen3MoeMLP> experts_;
MoEGate gate_;

public:
Qwen3MoE() = default;

Qwen3MoE(const std::string& name, const Qwen3MoeConfig& config) : nn::Module(name) {
num_experts_per_tok_ = config.num_experts_per_tok;
// Init experts
experts_ = reg<nn::ModuleList<Qwen3MoeMLP>>("experts", config.num_experts, config, std::nullopt,
config.moe_intermediate_size);
gate_ = reg<MoEGate>("gate", config);
}

std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto hidden_states = inputs[0];
auto identity = hidden_states;
auto orig_shape = hidden_states.shape();
auto topk_idx = Tensor::nil();
auto topk_weight = Tensor::nil();
auto gated_ret = gate_(hidden_states);
topk_idx = gated_ret[0];
topk_weight = gated_ret[1];
hidden_states = hidden_states.view({-1, hidden_states.size(-1)});

auto y = moeInfer(hidden_states, topk_idx, topk_weight).view(orig_shape);

return {y};
}

private:
Tensor moeInfer(const Tensor& x, Tensor& topk_ids, Tensor& topk_weights) {
// x shape is [batch_size * seq, hidden_dim]

auto cnts = Tensor::zeros({topk_ids.size(0), (int32_t)experts_.list().size()});
// Do scatter_ operation
{
const int32_t* idx_ptr = topk_ids.ptr<mllm_int32_t>();
float* cnt_ptr = cnts.ptr<mllm_fp32_t>();
const int batch = topk_ids.size(0);
const int k = topk_ids.size(1);
const int n_exp = cnts.size(1);
for (int b = 0; b < batch; ++b) {
for (int j = 0; j < k; ++j) {
int32_t e = idx_ptr[b * k + j];
MLLM_RT_ASSERT(e >= 0 && e < n_exp);
cnt_ptr[b * n_exp + e] += 1.f; // +1
}
}
}
auto tokens_per_expert = cnts.sum(0);
auto idxs = topk_ids.view({-1}).argsort();

// TODO this line maybe error
auto sorted_tokens = x[{idxs / topk_ids.size(1), {kAll}}];

std::vector<Tensor> outputs;
int start_idx = 0;

// tokens_per_expert shape is [num_experts]
// Loop through each expert
for (int i = 0; i < experts_.list().size(); ++i) {
auto num_tokens = tokens_per_expert.ptr<mllm_fp32_t>()[i];
auto end_idx = start_idx + (int32_t)num_tokens;
if (num_tokens == 0) { continue; }
auto& expert = experts_.list()[i];
auto tokens_for_this_expert = sorted_tokens[{{start_idx, end_idx}, kAll}];
auto expert_out = expert(tokens_for_this_expert)[0];
outputs.push_back(expert_out);
start_idx = end_idx;
}

auto outs = nn::functional::concat(outputs, 0);
auto new_x = Tensor::emptyLike(outs).alloc();

// indexed_write
// python logic: new_x[idxs] = outs
{
const int32_t* idx_ptr = idxs.ptr<mllm_int32_t>();
float* outs_ptr = outs.ptr<mllm_fp32_t>();
float* new_x_ptr = new_x.ptr<mllm_fp32_t>();
MLLM_RT_ASSERT_EQ(new_x.rank(), 2);
MLLM_RT_ASSERT_EQ(new_x.size(0), idxs.size(0));
auto dim = new_x.size(1);
for (int i = 0; i < idxs.size(0); ++i) {
int32_t idx = idx_ptr[i];
std::memcpy(new_x_ptr + idx * dim, outs_ptr + i * dim, dim * sizeof(float));
}
}

auto final_out_shape = topk_ids.shape();
final_out_shape.emplace_back(-1);
auto final_out =
new_x.view(final_out_shape).to(topk_weights.dtype()).mul_(topk_weights.unsqueeze(-1)).sum(1).to(new_x.dtype());
return final_out;
}
};

class Qwen3MoeAttention final : public nn::Module {
nn::Linear q_proj_;
nn::Linear k_proj_;
nn::Linear v_proj_;
nn::Linear o_proj_;
nn::RMSNorm rms_norm_q_;
nn::RMSNorm rms_norm_k_;
nn::RoPE q_rope_;
nn::RoPE k_rope_;

int hidden_size_;
int head_dim_;
int num_attention_heads_;
int num_key_value_heads_;
int num_key_value_groups_;

public:
Qwen3MoeAttention() = default;

Qwen3MoeAttention(const std::string& name, const Qwen3MoeConfig& cfg) : nn::Module(name) {
hidden_size_ = cfg.hidden_size;
num_attention_heads_ = cfg.num_attention_heads;
num_key_value_heads_ = cfg.num_key_value_heads;
head_dim_ = cfg.head_dim;
num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_;

// clang-format off
q_proj_ = reg<nn::Linear>("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type);
k_proj_ = reg<nn::Linear>("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type).redirect();
v_proj_ = reg<nn::Linear>("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type).redirect();
o_proj_ = reg<nn::Linear>("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, cfg.linear_impl_type);
// clang-format on

rms_norm_q_ = reg<nn::RMSNorm>("q_norm", cfg.rms_norm_eps).inplace();
rms_norm_k_ = reg<nn::RMSNorm>("k_norm", cfg.rms_norm_eps).inplace();

// clang-format off
q_rope_ = reg<nn::RoPE>("q_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD).inplace();
k_rope_ = reg<nn::RoPE>("k_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD).inplace();
// clang-format on
}

std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto x = inputs[0];
auto llm_embedding_sin = inputs[1];
auto llm_embedding_cos = inputs[2];
auto past_kv_cache = args[0].get<nn::StaticCache*>();

int B = inputs[0].shape()[0];
int S = inputs[0].shape()[1];

// Get KV cache for Key and Value first.
// [B, S, H * D]
auto [key_states_redirect, value_states_redirect] = past_kv_cache->preGetKVWriteLocation(layer_idx_, S);

// [B, S, H * D]
auto query_states = q_proj_(x);
auto key_states = k_proj_(x, key_states_redirect);
auto value_states = v_proj_(x, value_states_redirect);

// [B, S, H, D]
query_states = query_states.view({B, S, num_attention_heads_, head_dim_});
key_states = key_states.view({B, S, num_key_value_heads_, head_dim_});

// [B, S, H, D]
query_states = rms_norm_q_(query_states);
key_states = rms_norm_k_(key_states);

// [B, S, H, D]
query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos);
key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos);

// Get KV
auto [K, V] = past_kv_cache->getKVCache(layer_idx_);

// [B, S, H, D] FA2
auto output = o_proj_(nn::functional::flashAttention2(query_states, K, V).view({B, S, num_attention_heads_ * head_dim_}));

return {output};
}

int layer_idx_;
};

class Qwen3MoeDecoder final : public nn::Module {
Qwen3MoeAttention self_attn_;
nn::RMSNorm input_layer_norm_;
nn::RMSNorm post_attention_layer_norm_;

std::optional<Qwen3MoE> mlp_opt0_ = std::nullopt;
std::optional<Qwen3MoeMLP> mlp_opt1_ = std::nullopt;

public:
int layer_idx_;

Qwen3MoeDecoder() = default;

Qwen3MoeDecoder(const std::string& name, const Qwen3MoeConfig& cfg, int layer_idx) : nn::Module(name) {
layer_idx_ = layer_idx;

self_attn_ = reg<Qwen3MoeAttention>("self_attn", cfg);
self_attn_.layer_idx_ = layer_idx;

bool is_mlp_only = std::find(cfg.mlp_only_layers.begin(), cfg.mlp_only_layers.end(), layer_idx) != cfg.mlp_only_layers.end();
if ((!is_mlp_only) && (cfg.num_experts > 0 && (layer_idx_+1) % cfg.decoder_sparse_step == 0)) {
mlp_opt0_ = reg<Qwen3MoE>("mlp", cfg);
} else {
mlp_opt1_ = reg<Qwen3MoeMLP>("mlp", cfg);
}

input_layer_norm_ = reg<nn::RMSNorm>("input_layernorm", cfg.rms_norm_eps);
post_attention_layer_norm_ = reg<nn::RMSNorm>("post_attention_layernorm", cfg.rms_norm_eps);
}

std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto llm_embedding_sin = inputs[1];
auto llm_embedding_cos = inputs[2];
auto& kv_cache = args[0];

auto x = input_layer_norm_(inputs[0]);
x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0];
auto tmp = x + inputs[0];
x = post_attention_layer_norm_(tmp);
if(mlp_opt0_){
x = mlp_opt0_.value()(x)[0];
} else {
x = mlp_opt1_.value()(x)[0];
}
x = x + tmp;
return {x};
}
};

class Qwen3MoeText final : public nn::Module {
nn::Embedding embedding_;
nn::ModuleListWithIdx<Qwen3MoeDecoder> decode_blocks_;
nn::RMSNorm norm_;

public:
Qwen3MoeText() = default;

explicit Qwen3MoeText(const std::string& name, const Qwen3MoeConfig& cfg) : nn::Module(name) {
embedding_ = reg<nn::Embedding>("embed_tokens", cfg.vocab_size, cfg.hidden_size);
decode_blocks_ = reg<nn::ModuleListWithIdx<Qwen3MoeDecoder>>("layers", cfg.num_hidden_layers, cfg);
norm_ = reg<nn::RMSNorm>("norm", cfg.rms_norm_eps);

}

std::vector<Tensor> forward(const std::vector<Tensor>& inputs, const std::vector<AnyValue>& args) override {
auto& blocks = decode_blocks_.list();

// Embed the input token ids
auto x = embedding_(inputs[0]);

auto llm_embedding_sin = inputs[1];
auto llm_embedding_cos = inputs[2];
auto& kv_cache = args[0];

for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; }

x = norm_(x);

return {x};
}
};

class Qwen3MoeForCausalLM : public ARGeneration, public nn::Module {
public:
explicit Qwen3MoeForCausalLM(const Qwen3MoeConfig& cfg) : cfg(cfg) {
kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers,
cfg.num_attention_heads, // q_heads
cfg.num_key_value_heads, // kv_heads
cfg.head_dim, // kv_dim
kFloat32, // k_dtype
kFloat32, // v_dtype
kCPU, // device_type
true // use_fa2
);
eos_token_id_ = cfg.end_of_text_token_id;
max_length_ = cfg.max_cache_length;
tie_word_embeddings_ = cfg.tie_word_embeddings;

llm = reg<Qwen3MoeText>("model", cfg);

if (cfg.tie_word_embeddings) {
// NOTE:
// model.lm_head.weight is quantization weights of model.embed_tokens.weight
lm_head_ = reg<nn::Linear>("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type);
}

// Init inv freq
auto inv = makeRoPEInvFreq(cfg.head_dim, cfg.rope_theta);
registerBuffer("inv_freq", inv);
}

ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override {
auto sequence = input.at("sequence");

// Generate position_ids for the current sequence
auto batch_size = sequence.shape()[0];
auto seq_len = sequence.shape()[1];

Tensor position_ids = Tensor::nil();
if (input.count("position_ids")) {
// Use existing position_ids for decode phase
position_ids = input.at("position_ids");

// For decode phase, increment the last position
if (seq_len == 1) {
auto last_pos = *position_ids.offsettedPtr<int64_t>({0, position_ids.shape()[1] - 1});
position_ids = Tensor::empty({batch_size, 1}, kInt64, kCPU).alloc();
*position_ids.offsettedPtr<int64_t>({0, 0}) = last_pos + 1;
}
} else {
// Generate position_ids for prefill phase
position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc();
auto position_ids_ptr = position_ids.ptr<int64_t>();
for (int b = 0; b < batch_size; ++b) {
for (int s = 0; s < seq_len; ++s) { position_ids_ptr[b * seq_len + s] = s; }
}
}

// Generate RoPE embeddings using the inv_freq buffer
auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, getBuffer("inv_freq"), 1.0f);

sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0];

// clip x to one seq length
{
auto S = sequence.shape()[1];
sequence = sequence[{kAll, {S - 1}, kAll}];
}
if (tie_word_embeddings_) { sequence = lm_head_(sequence); }

return {
{"sequence", sequence},
{"position_ids", position_ids},
};
}

inline nn::StaticCache& kvCache() { return kv_cache_; }

private:
const Qwen3MoeConfig& cfg;
Qwen3MoeText llm;
nn::Linear lm_head_;
bool tie_word_embeddings_;
nn::StaticCache kv_cache_;
};

⚠️ Potential issue | 🟠 Major

Add docstrings for Qwen3 MoE public APIs.

Public functions/classes (e.g., makeRoPEInvFreq, Qwen3Moe* modules, Qwen3MoeForCausalLM) lack docstrings explaining purpose, parameters, returns, and errors. Please add brief API comments.

As per coding guidelines: Ensure public APIs, classes, and functions have clear docstrings or comments explaining purpose, parameters, returns, and errors.

🤖 Prompt for AI Agents
In `@mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp` around lines 15 - 491, Add
concise docstrings/comments above each public function and class declaration to
describe purpose, parameters, return values, and any error/edge-case behaviors;
specifically add comments for makeRoPEInvFreq, makeRotaryPosEmbedding,
Qwen3MoeMLP (and its constructor/forward), MoEGate (constructor/forward),
Qwen3MoE (constructor/forward/moeInfer), Qwen3MoeAttention
(constructor/forward), Qwen3MoeDecoder (constructor/forward), Qwen3MoeText
(constructor/forward), and Qwen3MoeForCausalLM (constructor/forward/kvCache);
keep each docstring short (one- to three-sentence description plus brief
param/return lines), place them immediately above the corresponding
function/class, and mention any important preconditions (e.g., expected tensor
shapes, dtype requirements, or when exceptions/assertions are raised).

Comment on lines +197 to +199
// TODO this line maybe error
auto sorted_tokens = x[{idxs / topk_ids.size(1), {kAll}}];

⚠️ Potential issue | 🟠 Major

Resolve TODO and confirm token ordering.

Line 197 contains a TODO without the required TODO: prefix and flags a potential correctness uncertainty in the routing index; please verify the logic or fix it, and update the TODO format.

🔧 Suggested comment fix
-    // TODO this line maybe error
+    // TODO: Confirm idxs/topk mapping is correct; add a targeted test or rationale.

As per coding guidelines: TODO and FIXME comments must be written as 'TODO:' or 'FIXME:' followed by UTF-8 text that adheres to character set rules.

🤖 Prompt for AI Agents
In `@mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp` around lines 197 - 199, The
line assigning sorted_tokens using auto sorted_tokens = x[{idxs /
topk_ids.size(1), {kAll}}]; is flagged as uncertain and the inline TODO must be
converted to the required "TODO:" format and clarified; verify that the batch
index computation uses the correct dimension (ensure idxs is divided by the
correct topk_ids dimension and that the division yields integer batch indices)
and confirm that x[{...}] returns tokens in the intended score-sorted order
(adjust the index expression if the batch/beam mapping is reversed), then
replace the comment with a clear "TODO: confirm token ordering and correct batch
index computation for sorted_tokens (uses idxs, topk_ids.size(1), x and kAll)"
describing what to validate or change.

Comment on lines +19 to +267
inline bool qwen3TokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) {
if (pos >= str.size()) return false;

// 1. Match contractions: "'s|'t|'re|'ve|'m|'ll|'d"
static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"};
for (const auto& contraction : contractions) {
if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) {
matched = contraction;
pos += contraction.size();
return true;
}
}

// 2. Match [^\r\n\p{L}\p{N}]?\p{L}+ (non-letter/digit followed by letters)
{
size_t original_pos = pos;
bool has_prefix = false;
matched.clear();

// Check optional non-letter/digit prefix (excluding \r\n)
if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') {
matched += str[pos];
++pos;
has_prefix = true;
}

// Require at least one letter
if (pos < str.size() && preprocessor::isLetter(str[pos])) {
do {
matched += str[pos];
++pos;
} while (pos < str.size() && preprocessor::isLetter(str[pos]));
return true;
} else {
// Rollback if no letters after prefix
if (has_prefix) {
pos = original_pos;
matched.clear();
}
}
}

// 3. Match \p{N} (digits)
if (preprocessor::isDigit(str[pos])) {
matched = str.substr(pos, 1);
++pos;
return true;
}

// 4. Match ?[^\s\p{L}\p{N}]+[\r\n]* (punctuation/symbols with optional space prefix)
{
size_t original_pos = pos;
matched.clear();
size_t start = pos;

// Optional space
if (str[pos] == L' ') { ++pos; }

// Require at least one non-letter/digit/whitespace
if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) {
do {
++pos;
} while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos])
&& !preprocessor::isDigit(str[pos]));

// Capture from start (after optional space) to current pos
matched = str.substr(start, pos - start);

// Capture trailing newlines
while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) {
matched += str[pos];
++pos;
}
return true;
} else {
// Rollback if no symbols found
pos = original_pos;
}
}

// 5. Match \s*[\r\n]+ (newlines with leading whitespace)
{
size_t start = pos;
while (pos < str.size() && std::iswspace(str[pos])) ++pos;
if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) {
while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos;
matched = str.substr(start, pos - start);
return true;
} else {
pos = start;
}
}

// 6. Match \s+(?!\S) (whitespace not followed by non-space)
if (std::iswspace(str[pos])) {
size_t start = pos;
while (pos < str.size() && std::iswspace(str[pos])) ++pos;
// Check if at end or followed by whitespace
if (pos >= str.size() || std::iswspace(str[pos])) {
matched = str.substr(start, pos - start);
return true;
} else {
pos = start;
}
}

// 7. Match remaining whitespace
if (std::iswspace(str[pos])) {
size_t start = pos;
while (pos < str.size() && std::iswspace(str[pos])) ++pos;
matched = str.substr(start, pos - start);
return true;
}

return false;
}

inline bool qwen3Regex(const std::string& str, std::vector<std::wstring>& splitted) {
auto w_string = preprocessor::utf8string2WideString(str);
size_t pos = 0;
while (pos < w_string.size()) {
std::wstring matched;
if (qwen3TokenizerMatchPattern(w_string, pos, matched)) {
splitted.push_back(matched);
} else {
++pos;
}
}
return true;
}

struct Qwen3Message {
std::string prompt;
static inline std::string message_template =
"<|im_start|>user\n{{{prompt}}}<|im_end|>\n<|im_start|>assistant\n";
};

class Qwen3Tokenizer final : public mllm::preprocessor::AutoTokenizer {
public:
explicit Qwen3Tokenizer(const std::string& file_path) {
preprocessor::initLocal();
preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_);
for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); }
bpe_.initFromSentencePieceJson(file_path);
special_tokens_trie_.add(L"<|endoftext|>");
special_tokens_trie_.add(L"<|im_start|>");
special_tokens_trie_.add(L"<|im_end|>");
special_tokens_trie_.add(L"<|object_ref_start|>");
special_tokens_trie_.add(L"<|object_ref_end|>");
special_tokens_trie_.add(L"<|box_start|>");
special_tokens_trie_.add(L"<|box_end|>");
special_tokens_trie_.add(L"<|quad_start|>");
special_tokens_trie_.add(L"<|quad_end|>");
special_tokens_trie_.add(L"<|vision_start|>");
special_tokens_trie_.add(L"<|vision_end|>");
special_tokens_trie_.add(L"<|vision_pad|>");
special_tokens_trie_.add(L"<|image_pad|>");
special_tokens_trie_.add(L"<|video_pad|>");
special_tokens_trie_.add(L"<think>");
special_tokens_trie_.add(L"</think>");
}

std::vector<std::wstring> _tokenize(const std::string& str) override {
std::vector<std::wstring> ret;
std::vector<std::wstring> splitted;
::mllm::models::qwen3_moe::qwen3Regex(str, splitted);
for (const auto& s : splitted) {
auto utf_8_str = preprocessor::wideString2Utf8String(s);
std::wstring mapped_str;
for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); }

auto bpe_ts = bpe_._bpe(mapped_str);

for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); }
}

return ret;
}

std::vector<std::wstring> tokenize(const std::string& str) override {
auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str));
std::vector<std::wstring> all_tokens;
for (const auto& token : tokens) {
if (special_tokens_trie_.isSpecialToken(token)) {
all_tokens.emplace_back(token);
continue;
}
auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token));
all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end());
}
return all_tokens;
}

std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); }

std::wstring detokenize(int64_t pos_idx) override {
auto str = _detokenize(pos_idx);
std::string utf_8_str;
for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); }
return {mllm::preprocessor::utf8string2WideString(utf_8_str)};
}

Tensor convert2Ids(const std::vector<std::wstring>& strs) override {
std::vector<int64_t> ids;
ids.reserve(strs.size());
for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); }
Tensor ret = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU)
.setMemType(kExtraInput)
.setName("qwen2-tokenizer-i0")
.alloc();

auto ptr = ret.ptr<int64_t>();
for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; }

return ret;
}

ARGenerationOutputPast convertMessage(const Qwen3Message& message) {
// process prompt
auto applied_string = Qwen3Message::message_template;
size_t pos = applied_string.find("{{{prompt}}}");
applied_string.replace(pos, 12, message.prompt);

// process sequence
auto sequence_str = tokenize(applied_string);
std::vector<int64_t> ids;
ids.reserve(sequence_str.size());
for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); }

// Get sequence Tensor
Tensor sequence = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU)
.setMemType(kNormal)
.setName("qwen2-tokenizer-i0")
.alloc();

auto ptr = sequence.ptr<int64_t>();
for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; }

return {
{"sequence", sequence},
};
}

private:
// For text
preprocessor::BPE bpe_;
std::unordered_map<std::wint_t, wchar_t> bytes_2_unicode_dict_;
std::unordered_map<wchar_t, std::wint_t> bytes_2_unicode_dict_inverse_;
};
⚠️ Potential issue | 🟠 Major

Add docstrings for Qwen3 tokenizer public APIs.

The free functions, Qwen3Message, and Qwen3Tokenizer public methods lack docstrings for purpose/params/returns/errors. Please add brief API comments.

As per coding guidelines: Ensure public APIs, classes, and functions have clear docstrings or comments explaining purpose, parameters, returns, and errors.

🤖 Prompt for AI Agents
In `@mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp` around lines 19 - 267, Add
brief doc comments to all public-facing symbols: the free functions
qwen3TokenizerMatchPattern and qwen3Regex, the Qwen3Message struct, and the
Qwen3Tokenizer public API (constructor Qwen3Tokenizer, _tokenize, tokenize,
_detokenize, detokenize, convert2Ids, convertMessage). For each, state one-line
purpose, parameters (e.g., str, pos, matched, message), return value, and any
error/edge-case behavior (e.g., returns false or empty vector on no
match/input). Place comments above the declarations so they document intent,
inputs, outputs and error conditions per coding guidelines.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 18

🤖 Fix all issues with AI agents
In `@examples/qwen3_moe/config_30B_A3B_gguf.json`:
- Line 28: The default rope_theta in configuration_qwen3_moe.hpp is mismatched
with the model JSON; update the default value used by the class (the
member/initializer named rope_theta in the configuration_qwen3_moe.hpp default
constructor or in-class initializer) from 1e6 (1000000.0) to 1e7 (10000000.0) so
it matches the examples/qwen3_moe/config_30B_A3B_gguf.json value.

In `@examples/qwen3_moe/main.cpp`:
- Around line 48-56: The UI claims an interactive CLI but main flow only reads
one prompt into prompt_text then exits; wrap the prompt/read/response/exit-check
sequence (the block using prompt_text and the prompt/ getline logic) inside a
loop (e.g., while(true)) so the program repeatedly prompts until prompt_text ==
"exit" || "quit", or alternatively update the banner/strings to remove
"interactive" and the "Enter 'exit' or 'quit' to end the session" line; modify
the code paths referencing prompt_text to support repeated iterations and clean
termination.
- Around line 23-38: Move the help handling so the program prints help without
requiring other args: call Argparse::parse() only after checking if
help.isSet(), i.e., check help.isSet() immediately after constructing/parsing
raw flags (or the minimal pre-parse if needed) and call Argparse::printHelp()
then mllm::shutdownContext() and return 0; ensure this check runs before
invoking Argparse::parse() (which performs required-argument validation and may
exit), and keep references to Argparse::printHelp(), help.isSet(),
Argparse::parse(), and mllm::shutdownContext() while relocating the check.

In `@examples/qwen3_moe/quant_cfg_30B_q4_k.json`:
- Line 68: The regex entry for "lm_head.weight" uses an unescaped dot which
matches any character; update the pattern to escape the dot (change the pattern
for lm_head.weight to use lm_head\\.weight) to match the literal field name
consistently with other keys (e.g., self_attn\\.q_proj) so the config only
targets the intended parameter.
- Around line 46-67: The MoE expert config is missing the gate_proj pattern used
in the forward pass; add a new JSON entry matching
"^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.gate_proj.(bias|weight)" with
the same quant hints style as up_proj/down_proj (quant_method "gguf", gguf_type
"Q4_K", shape [2048,2048], replace true) so gate_proj weights are quantized, or
if omission was intentional, add an inline comment next to the up_proj/down_proj
blocks explaining why gate_proj is excluded; reference the existing up_proj,
down_proj and the model.layers.*.mlp.experts.*.gate_proj symbol to locate where
to add the change.

In `@mllm/backends/cpu/kernels/common/elewise-inl.hpp`:
- Around line 109-111: Add an inline TODO/FIXME explaining why the matrix
variants elewise_div_int16 and elewise_div_int8 are commented out while their
scalar-by-const counterparts remain enabled: locate the commented-out matrix
functions elewise_div_int16 and elewise_div_int8 and either (a) re-enable them
if they were wrongly disabled (uncomment and run CI) or (b) leave them commented
and add a clear TODO/FIXME above each commented block stating the specific
reason (e.g., known HWY/overflow/ABI bug, performance regression, or test
failure) and reference the working scalar functions
(elewise_div_int16_scalar/elewise_div_int8_scalar or similar) so future
maintainers know why only scalar variants are used and what must be resolved to
re-enable matrix paths.

In `@mllm/backends/cpu/kernels/common/kernel_dispatch.cpp`:
- Around line 99-170: The two exports HWY_EXPORT(elewise_div_scl_int16) and
HWY_EXPORT(elewise_div_scl_int8) and their wrappers call_elewise_div_scl_int16 /
call_elewise_div_scl_int8 rely on hn::Div for integer lanes (see
elewise-inl.hpp) which may not be supported by Highway; either remove/disable
these exports and wrappers or conditionally compile them behind a feature macro
(e.g., HWY_HAVE_INT_DIV) and/or implement a safe fallback (cast to a supported
type or use a software integer division path) so the file compiles; update the
entries for elewise-div scalars accordingly to reference the guarded symbols
HWY_EXPORT(elewise_div_scl_int16), HWY_EXPORT(elewise_div_scl_int8),
call_elewise_div_scl_int16, and call_elewise_div_scl_int8.

In `@mllm/backends/cpu/kernels/common/kernel_dispatch.hpp`:
- Around line 131-141: elewise_div_anytype currently only handles mllm_fp32_t
and mllm_int32_t and falls back to scalar for int16/int8, while
elewise_div_scl_anytype routes int16/int8 to SIMD helpers; update
elewise_div_anytype to mirror elewise_div_scl_anytype by adding branches for
mllm_int16_t and mllm_int8_t that call call_elewise_div_int16 and
call_elewise_div_int8 (or, if the Highway integer Div paths in elewise-inl.hpp
are not available/compilable, ensure those branches instead fall through to the
scalar loop), and keep the existing fp32/int32 branches intact so behavior is
consistent across both functions.
- Around line 83-97: Add the missing header include for std::is_same_v by adding
`#include` <type_traits> at the top of the file that defines elewise_add_anytype;
this ensures the template in elewise_add_anytype (which calls std::is_same_v<T,
...>) compiles correctly and resolves the type-trait dependency used alongside
call_elewise_add_fp32 / call_elewise_add_int32 / call_elewise_add_int16 /
call_elewise_add_int8.

In `@mllm/backends/cpu/kernels/common/reduce-inl.hpp`:
- Around line 95-100: The tail handling using hn::LoadN in reduce_impl is only
correct for additive reductions because LoadN zero-fills lanes; update
reduce_impl (and its callers like reduce_sum_fp32) to accept an identity-element
parameter (or a lane-mask-aware load) so the tail lanes are filled with the
reduction identity rather than zeros; specifically, change the tail path where
LoadN is used (referencing LoadN, reduce_impl, vec_op, and vec_reduce_op) to use
the provided identity element for the remaining lanes (or perform a masked load
that preserves a neutral value) and ensure reduce_sum_fp32 passes the
appropriate identity (0 for sum) while other reductions (Min/Max/Mul) would pass
their own identity if/when exposed.

In `@mllm/backends/cpu/ops/MatMulOp.cpp`:
- Around line 52-54: The change routes all cases matching the condition
(!transpose_a && transpose_b) to a known-buggy path MatMulOpType::kGGUF and
removed the previous M >= 4 safeguard that used to fall back to
MatMulOpType::kMllmBlas for small M; if the GGUF issues are not fixed, restore a
size-based guard (e.g., reintroduce the M >= 4 check) inside the same
conditional in MatMulOp.cpp so only sufficiently large M uses kGGUF and small-M
uses kMllmBlas, or alternatively update the conditional to explicitly document
and gate GGUF behind a resolved-bug flag and add a TODO referencing
MatMulOpType::kGGUF for follow-up.

In `@mllm/models/qwen3_moe/configuration_qwen3_moe.hpp`:
- Line 60: The default value for rope_theta is incorrect; update the declaration
of rope_theta in the configuration (float rope_theta) to match the Qwen3-30B
model JSON (set to 10000000.0) so the default constructor uses the correct model
value; ensure any documentation/comments that mention rope_theta reflect the new
1e7 default.

In `@mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp`:
- Around line 199-216: The code may call nn::functional::concat on an empty
outputs vector if every expert got zero tokens; add a guard after the expert
loop: if outputs.empty() create an appropriate empty tensor (e.g. using
Tensor::emptyLike(sorted_tokens).alloc() or Tensor::zeros with the expected
concat-dimension shape) and assign that to outs/new_x, otherwise call
nn::functional::concat(outputs, 0) as before; reference the outputs vector,
nn::functional::concat, sorted_tokens, tokens_per_expert, and new_x (and keep
the original concat path for the non-empty case).
- Around line 484-486: The class Qwen3MoeForCausalLM currently holds cfg as a
const Qwen3MoeConfig& which risks dangling references; change the member
declaration from a reference to a value (store Qwen3MoeConfig cfg;) and update
the constructor to take either const Qwen3MoeConfig& or Qwen3MoeConfig and
copy/move it into the new cfg member; ensure any uses in the class (and
initialization of Qwen3MoeText llm if it requires config) are adjusted to read
from the stored value instead of a reference.

In `@mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp`:
- Around line 225-228: The tensor name string "qwen2-tokenizer-i0" is hardcoded
in convert2Ids and convertMessage; update both places to use a
Qwen3-MoE-specific name (e.g., "qwen3-moe-tokenizer-i0") so logs/debugging
reflect the correct model; locate the Tensor construction in convert2Ids and the
equivalent Tensor in convertMessage and replace the name passed to setName(...)
with the new identifier, ensuring consistency between both functions.
- Around line 236-241: In convertMessage, the result of
applied_string.find("{{{prompt}}}") (stored in pos) isn't checked before calling
applied_string.replace, which can pass npos and cause undefined behavior; update
convertMessage to verify pos != std::string::npos (or otherwise detect a
malformed Qwen3Message::message_template) before calling replace and handle the
error path (e.g., log/throw or fall back to appending the prompt) so replace is
only called with a valid position.
- Around line 214-219: In detokenize (function detokenize in
tokenization_qwen3_moe.hpp) the loop uses operator[] on
bytes_2_unicode_dict_inverse_ which will insert default entries for missing
keys; change the lookup to use .at() or .find() and handle the missing-key case
(throw a meaningful exception or log and skip/return an error) so the map is not
mutated silently and invalid byte values are not produced; ensure the error path
includes the offending wchar_t value for debugging.
- Around line 99-110: In the pattern-5 block inside tokenization_qwen3_moe.hpp
(the code that implements the `\s*[\r\n]+` match), change the leading-whitespace
loop so it does not consume CR/LF: instead of while (pos < str.size() &&
std::iswspace(str[pos])) ++pos; only advance while std::iswspace(str[pos]) AND
str[pos] is not L'\r' and not L'\n' (e.g., while (pos < str.size() &&
std::iswspace(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') ++pos;), then
keep the subsequent check for CR/LF and the rest of the logic in the same block
(use the same start/pos/matched variables and return true on success).
🧹 Nitpick comments (3)
mllm/backends/cpu/ops/ElewiseOps.cpp (1)

220-240: Missing #else fallback for integer/fp16/complex types on non-ARM/non-x86 architectures.

The kFloat32 branches consistently include a #else NYI(...) fallback (e.g., lines 149–151), but the kInt32, kInt16, kInt8, kFloat16, and kComplexFloat32 branches silently compile to no-ops on unsupported architectures. This is pre-existing for the ARM-only paths, but now that x86 is also covered, it would be good to add #else NYI(...) for consistency and to catch misconfigured builds.

Example for kInt32 Add (apply similar pattern to all integer/fp16/complex cases)
 #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)
         cpu::common::elewise_add_anytype(output.ptr<mllm_int32_t>(), input0.ptr<mllm_int32_t>(), input1.ptr<mllm_int32_t>(),
                                output.numel());
+#else
+        NYI("AddOp int32 not supported on this architecture.");
 #endif
mllm/backends/cpu/kernels/common/reduce-inl.hpp (2)

113-118: thread_count parameter is accepted but unused.

reduce_sum_fp32 takes int32_t thread_count but never uses it. If this is a placeholder for future multi-threaded reduction, add a comment (e.g., // TODO: multi-threaded reduction). Otherwise, consider removing the parameter to avoid confusion.


13-53: Duplicate scalar/vector op structs with elewise-inl.hpp.

ScalarAddOp/VecAddOp etc. duplicate the functionality of AddOp/SubOp etc. in elewise-inl.hpp. The existing elewise-inl.hpp ops already have the same (D d, V a, V b) signature. Consider reusing those ops and extracting the scalar ops into a shared location to avoid maintaining two parallel sets.

"output_router_logits": false,
"rms_norm_eps": 1e-06,
"rope_scaling": 1.0,
"rope_theta": 10000000,

⚠️ Potential issue | 🟡 Minor

rope_theta value here (10000000) differs from the default in configuration_qwen3_moe.hpp (1000000.0).

The JSON provides 10000000 (10⁷) while the C++ default at configuration_qwen3_moe.hpp:60 is 1000000.0 (10⁶). When loading from JSON this is fine, but anyone using the default constructor gets the wrong value. Please align the C++ default with the model's actual rope_theta.

🤖 Prompt for AI Agents
In `@examples/qwen3_moe/config_30B_A3B_gguf.json` at line 28, The default
rope_theta in configuration_qwen3_moe.hpp is mismatched with the model JSON;
update the default value used by the class (the member/initializer named
rope_theta in the configuration_qwen3_moe.hpp default constructor or in-class
initializer) from 1e6 (1000000.0) to 1e7 (10000000.0) so it matches the
examples/qwen3_moe/config_30B_A3B_gguf.json value.

Comment on lines +23 to +38
mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1;
if (model_version.get() == "v1") {
file_version = mllm::ModelFileVersion::kV1;
} else if (model_version.get() == "v2") {
file_version = mllm::ModelFileVersion::kV2;
} else {
fmt::print("❌ Unsupported model_version: {} (expected v1 or v2)\n", model_version.get());
mllm::shutdownContext();
return 1;
}

if (help.isSet()) {
Argparse::printHelp();
mllm::shutdownContext();
return 0;
}

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, locate Argparse.hpp
find . -name "Argparse.hpp" -type f

Repository: UbiquitousLearning/mllm

Length of output: 92


🏁 Script executed:

# Once located, check the implementation of required arguments and get() method
rg -n "required" mllm/utils/Argparse.hpp -A 3 -B 1

Repository: UbiquitousLearning/mllm

Length of output: 605


🏁 Script executed:

# Also check the get() method implementation
rg -n "\.get\(\)" mllm/utils/Argparse.hpp -A 2 -B 1

Repository: UbiquitousLearning/mllm

Length of output: 484


🏁 Script executed:

# Look at the full context of main.cpp around the lines in question
head -50 examples/qwen3_moe/main.cpp | tail -40

Repository: UbiquitousLearning/mllm

Length of output: 1651


🏁 Script executed:

# Get the full parse() method implementation
rg -n "void parse" mllm/utils/Argparse.hpp -A 30

Repository: UbiquitousLearning/mllm

Length of output: 3232


🏁 Script executed:

# Also check if help flag has any special behavior
rg -n "help.*isSet\|--help" mllm/utils/Argparse.hpp -B 2 -A 2

Repository: UbiquitousLearning/mllm

Length of output: 49


🏁 Script executed:

# Get more of the parse method to see required arg validation
rg -n "static void parse" mllm/utils/Argparse.hpp -A 50 | head -80

Repository: UbiquitousLearning/mllm

Length of output: 1821


🏁 Script executed:

# Check if there's special handling for help before validation
rg -n "MLLM_ERROR_EXIT.*required" mllm/utils/Argparse.hpp -B 5 -A 2

Repository: UbiquitousLearning/mllm

Length of output: 344


🏁 Script executed:

# Get the complete parse method to see if there's early exit handling for help
rg -n "static void parse" mllm/utils/Argparse.hpp -A 70

Repository: UbiquitousLearning/mllm

Length of output: 2631


--help check happens after parse(), which exits if required arguments are missing—help will be unreachable.

Argparse::parse() validates required arguments at the end (lines 200-205 in Argparse.hpp) and calls MLLM_ERROR_EXIT if any are missing. The --help flag is parsed successfully within parse(), but the subsequent required argument validation still exits before the help check at line 34 executes. Move the help check before calling parse() to make help accessible without required arguments.

Proposed fix
   Argparse::parse(argc, argv);

+  if (help.isSet()) {
+    Argparse::printHelp();
+    mllm::shutdownContext();
+    return 0;
+  }
+
   mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1;
   if (model_version.get() == "v1") {
     file_version = mllm::ModelFileVersion::kV1;
   } else if (model_version.get() == "v2") {
     file_version = mllm::ModelFileVersion::kV2;
   } else {
     fmt::print("❌ Unsupported model_version: {} (expected v1 or v2)\n", model_version.get());
     mllm::shutdownContext();
     return 1;
    }

-  if (help.isSet()) {
-    Argparse::printHelp();
-    mllm::shutdownContext();
-    return 0;
-  }
🤖 Prompt for AI Agents
In `@examples/qwen3_moe/main.cpp` around lines 23 - 38, Move the help handling so
the program prints help without requiring other args: call Argparse::parse()
only after checking if help.isSet(), i.e., check help.isSet() immediately after
constructing/parsing raw flags (or the minimal pre-parse if needed) and call
Argparse::printHelp() then mllm::shutdownContext() and return 0; ensure this
check runs before invoking Argparse::parse() (which performs required-argument
validation and may exit), and keep references to Argparse::printHelp(),
help.isSet(), Argparse::parse(), and mllm::shutdownContext() while relocating
the check.

Comment on lines +48 to +56
fmt::print("\n{:*^60}\n", " Qwen3 MoE Interactive CLI ");
fmt::print("Enter 'exit' or 'quit' to end the session\n\n");

std::string prompt_text;

fmt::print("💬 Prompt text (or 'exit/quit'): ");
std::getline(std::cin, prompt_text);

if(prompt_text == "exit" || prompt_text == "quit") { return 0; }

⚠️ Potential issue | 🟡 Minor

UI says "interactive" but implementation is single-prompt only.

Lines 48-49 suggest an interactive session ("Enter 'exit' or 'quit' to end the session"), but there is no loop — only one prompt is accepted. Either wrap lines 53-68 in a while(true) loop or update the banner to reflect single-prompt behavior.

🤖 Prompt for AI Agents
In `@examples/qwen3_moe/main.cpp` around lines 48 - 56, The UI claims an
interactive CLI but main flow only reads one prompt into prompt_text then exits;
wrap the prompt/read/response/exit-check sequence (the block using prompt_text
and the prompt/ getline logic) inside a loop (e.g., while(true)) so the program
repeatedly prompts until prompt_text == "exit" || "quit", or alternatively
update the banner/strings to remove "interactive" and the "Enter 'exit' or
'quit' to end the session" line; modify the code paths referencing prompt_text
to support repeated iterations and clean termination.

Comment on lines +46 to +67
"^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.up_proj.(bias|weight)": {
"hints": {
"quant_method": "gguf",
"gguf_type": "Q4_K",
"shape": [
768,
2048
],
"replace": true
}
},
"^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.down_proj.(bias|weight)": {
"hints": {
"quant_method": "gguf",
"gguf_type": "Q6_K",
"shape": [
2048,
768
],
"replace": true
}
},

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Check if gate_proj exists in quantization configs and model definitions
rg -n "gate_proj" --type=json

Repository: UbiquitousLearning/mllm

Length of output: 1218


🏁 Script executed:

# Also search in Python files for model definitions
rg -n "gate_proj" --type=py

Repository: UbiquitousLearning/mllm

Length of output: 1146


🏁 Script executed:

# List all quantization config files for comparison
fd -e json -p "quant_cfg"

Repository: UbiquitousLearning/mllm

Length of output: 477


🏁 Script executed:

# Search for qwen3_moe model definitions
fd -p "qwen3_moe" --type=py | head -20

Repository: UbiquitousLearning/mllm

Length of output: 236


🏁 Script executed:

# Check the qwen3_moe model file for MoE expert structure and gate_proj
fd -p "modeling_qwen3" --type=py | grep -i moe

Repository: UbiquitousLearning/mllm

Length of output: 236


🏁 Script executed:

# Search for Qwen3MoeMLP or similar MoE expert definitions
rg -n "Qwen3.*Moe|MoeMLP" --type=py -A 10

Repository: UbiquitousLearning/mllm

Length of output: 49


🏁 Script executed:

# Search for Qwen3MoeMLP or MoE expert definitions
rg -n "Qwen3.*Moe|MoeMLP" --type-add "py:*.py"

Repository: UbiquitousLearning/mllm

Length of output: 3702


🏁 Script executed:

# Search for qwen3_moe references in Python files
rg -n "qwen3_moe"

Repository: UbiquitousLearning/mllm

Length of output: 1964


🏁 Script executed:

# Search for expert projection definitions with gate_proj
rg -n "experts.*gate_proj|gate_proj.*experts"

Repository: UbiquitousLearning/mllm

Length of output: 50379


🏁 Script executed:

cat mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp | sed -n '71,112p'

Repository: UbiquitousLearning/mllm

Length of output: 1468


🏁 Script executed:

# Check if gate_proj is mentioned anywhere in the qwen3_moe model file
rg -n "gate_proj" mllm/models/qwen3_moe/

Repository: UbiquitousLearning/mllm

Length of output: 436


🏁 Script executed:

# Also check the current quant_cfg_30B_q4_k.json file directly for context
cat examples/qwen3_moe/quant_cfg_30B_q4_k.json

Repository: UbiquitousLearning/mllm

Length of output: 1612


Add gate_proj to MoE expert quantization config.

The model defines three projections per MoE expert (gate_proj, up_proj, down_proj), but only up_proj and down_proj are quantized. The gate_proj is actively used in the forward pass yet missing from the config. If intentional, add a comment; otherwise, add the following pattern:

"^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.gate_proj.(bias|weight)": {
  "hints": {
    "quant_method": "gguf",
    "gguf_type": "Q4_K",
    "shape": [2048, 2048],
    "replace": true
  }
}
🤖 Prompt for AI Agents
In `@examples/qwen3_moe/quant_cfg_30B_q4_k.json` around lines 46 - 67, The MoE
expert config is missing the gate_proj pattern used in the forward pass; add a
new JSON entry matching
"^model\\.layers\\.\\d+\\.mlp\\.experts\\.\\d+\\.gate_proj.(bias|weight)" with
the same quant hints style as up_proj/down_proj (quant_method "gguf", gguf_type
"Q4_K", shape [2048,2048], replace true) so gate_proj weights are quantized, or
if omission was intentional, add an inline comment next to the up_proj/down_proj
blocks explaining why gate_proj is excluded; reference the existing up_proj,
down_proj and the model.layers.*.mlp.experts.*.gate_proj symbol to locate where
to add the change.

"replace": true
}
},
"^lm_head.weight": {

⚠️ Potential issue | 🟡 Minor

Unescaped dot in lm_head.weight regex pattern.

All other patterns escape dots as \\. (e.g., self_attn\\.q_proj), but lm_head.weight uses an unescaped . which matches any character in regex. Should be ^lm_head\\.weight for consistency and correctness.

Proposed fix
-  "^lm_head.weight": {
+  "^lm_head\\.weight": {
🤖 Prompt for AI Agents
In `@examples/qwen3_moe/quant_cfg_30B_q4_k.json` at line 68, The regex entry for
"lm_head.weight" uses an unescaped dot which matches any character; update the
pattern to escape the dot (change the pattern for lm_head.weight to use
lm_head\\.weight) to match the literal field name consistently with other keys
(e.g., self_attn\\.q_proj) so the config only targets the intended parameter.

Comment on lines +484 to +486
private:
const Qwen3MoeConfig& cfg;
Qwen3MoeText llm;

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "*.hpp" -o -name "*.cpp" | head -20

Repository: UbiquitousLearning/mllm

Length of output: 1044


🏁 Script executed:

cat -n mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp | sed -n '470,500p'

Repository: UbiquitousLearning/mllm

Length of output: 731


🏁 Script executed:

rg -n "Qwen3MoeForCausalLM" -A 5 -B 2 --type=cpp --type=hpp

Repository: UbiquitousLearning/mllm

Length of output: 94


🏁 Script executed:

rg "class Qwen3MoeForCausalLM" -A 20

Repository: UbiquitousLearning/mllm

Length of output: 2131


🏁 Script executed:

rg "Qwen3MoeForCausalLM" -A 5 -B 2

Repository: UbiquitousLearning/mllm

Length of output: 2218


🏁 Script executed:

cat -n mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp | sed -n '1,100p'

Repository: UbiquitousLearning/mllm

Length of output: 4826


🏁 Script executed:

rg "cfg\." mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp | head -30

Repository: UbiquitousLearning/mllm

Length of output: 2519


🏁 Script executed:

cat -n mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp | sed -n '400,490p'

Repository: UbiquitousLearning/mllm

Length of output: 4035


🏁 Script executed:

rg "this->cfg" mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp

Repository: UbiquitousLearning/mllm

Length of output: 49


🏁 Script executed:

rg "const.*Config&.*cfg" mllm/models/ -A 2 | head -50

Repository: UbiquitousLearning/mllm

Length of output: 4501


🏁 Script executed:

rg "private:" mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp -A 5

Repository: UbiquitousLearning/mllm

Length of output: 459


Dangling reference risk: cfg stored as const Qwen3MoeConfig&.

Qwen3MoeForCausalLM stores a const reference to the config. Since the cfg member is only used during construction and never accessed afterward, storing a reference creates unnecessary fragility. If the config object is destroyed or moved before the model, this becomes a dangling reference. Store by value instead.

Proposed fix
  private:
-  const Qwen3MoeConfig& cfg;
+  const Qwen3MoeConfig cfg;
🤖 Prompt for AI Agents
In `@mllm/models/qwen3_moe/modeling_qwen3_moe_fa2.hpp` around lines 484 - 486, The
class Qwen3MoeForCausalLM currently holds cfg as a const Qwen3MoeConfig& which
risks dangling references; change the member declaration from a reference to a
value (store Qwen3MoeConfig cfg;) and update the constructor to take either
const Qwen3MoeConfig& or Qwen3MoeConfig and copy/move it into the new cfg
member; ensure any uses in the class (and initialization of Qwen3MoeText llm if
it requires config) are adjusted to read from the stored value instead of a
reference.

Comment on lines +99 to +110
// 5. Match \s*[\r\n]+ (newlines with leading whitespace)
{
size_t start = pos;
while (pos < str.size() && std::iswspace(str[pos])) ++pos;
if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) {
while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos;
matched = str.substr(start, pos - start);
return true;
} else {
pos = start;
}
}

⚠️ Potential issue | 🔴 Critical

Pattern 5 (\s*[\r\n]+) can never match — greedy \s* consumes newlines first.

The while (std::iswspace(...)) loop on Line 102 treats \r and \n as whitespace and consumes them greedily. When the subsequent check on Line 103 looks for \r/\n, they've already been consumed and pos has moved past them. The pattern always falls through to the rollback at Line 108.

To fix, the leading whitespace consumption should skip \r/\n:

Proposed fix
   // 5. Match \s*[\r\n]+ (newlines with leading whitespace)
   {
     size_t start = pos;
-    while (pos < str.size() && std::iswspace(str[pos])) ++pos;
+    while (pos < str.size() && std::iswspace(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') ++pos;
     if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) {
       while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos;
       matched = str.substr(start, pos - start);
🤖 Prompt for AI Agents
In `@mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp` around lines 99 - 110, In
the pattern-5 block inside tokenization_qwen3_moe.hpp (the code that implements
the `\s*[\r\n]+` match), change the leading-whitespace loop so it does not
consume CR/LF: instead of while (pos < str.size() && std::iswspace(str[pos]))
++pos; only advance while std::iswspace(str[pos]) AND str[pos] is not L'\r' and
not L'\n' (e.g., while (pos < str.size() && std::iswspace(str[pos]) && str[pos]
!= L'\r' && str[pos] != L'\n') ++pos;), then keep the subsequent check for CR/LF
and the rest of the logic in the same block (use the same start/pos/matched
variables and return true on success).

Comment on lines +214 to +219
std::wstring detokenize(int64_t pos_idx) override {
auto str = _detokenize(pos_idx);
std::string utf_8_str;
for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); }
return {mllm::preprocessor::utf8string2WideString(utf_8_str)};
}

⚠️ Potential issue | 🟠 Major

bytes_2_unicode_dict_inverse_ lookup may silently insert default entries.

Using operator[] on bytes_2_unicode_dict_inverse_ (Line 217) for a key not in the map will insert a zero entry, corrupting the map and producing incorrect byte values. Use .at() to get an exception on missing keys, or .find() with an error path.

Proposed fix
   std::wstring detokenize(int64_t pos_idx) override {
     auto str = _detokenize(pos_idx);
     std::string utf_8_str;
-    for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); }
+    for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_.at(c))); }
     return {mllm::preprocessor::utf8string2WideString(utf_8_str)};
   }
🤖 Prompt for AI Agents
In `@mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp` around lines 214 - 219, In
detokenize (function detokenize in tokenization_qwen3_moe.hpp) the loop uses
operator[] on bytes_2_unicode_dict_inverse_ which will insert default entries
for missing keys; change the lookup to use .at() or .find() and handle the
missing-key case (throw a meaningful exception or log and skip/return an error)
so the map is not mutated silently and invalid byte values are not produced;
ensure the error path includes the offending wchar_t value for debugging.

Comment on lines +225 to +228
Tensor ret = Tensor::empty({/*batch*/ 1, /*seq*/ (int32_t)ids.size()}, kInt64, kCPU)
.setMemType(kExtraInput)
.setName("qwen2-tokenizer-i0")
.alloc();

⚠️ Potential issue | 🟡 Minor

Tensor names hardcoded as "qwen2-tokenizer-i0" — likely a copy-paste from Qwen2.

Both convert2Ids (Line 227) and convertMessage (Line 251) name tensors "qwen2-tokenizer-i0". For a Qwen3 MoE model, this should be updated to avoid confusion during debugging and logging.

Proposed fix
-                     .setName("qwen2-tokenizer-i0")
+                     .setName("qwen3-moe-tokenizer-i0")

Also applies to: 249-252

🤖 Prompt for AI Agents
In `@mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp` around lines 225 - 228, The
tensor name string "qwen2-tokenizer-i0" is hardcoded in convert2Ids and
convertMessage; update both places to use a Qwen3-MoE-specific name (e.g.,
"qwen3-moe-tokenizer-i0") so logs/debugging reflect the correct model; locate
the Tensor construction in convert2Ids and the equivalent Tensor in
convertMessage and replace the name passed to setName(...) with the new
identifier, ensuring consistency between both functions.

Comment on lines +236 to +241
ARGenerationOutputPast convertMessage(const Qwen3Message& message) {
// process prompt
auto applied_string = Qwen3Message::message_template;
size_t pos = applied_string.find("{{{prompt}}}");
applied_string.replace(pos, 12, message.prompt);


⚠️ Potential issue | 🟠 Major

std::string::find result is not checked before replace, risking UB if template is malformed.

If applied_string.find("{{{prompt}}}") returns std::string::npos, the replace call on Line 240 passes npos as the position, which is undefined behavior.

Proposed fix
     auto applied_string = Qwen3Message::message_template;
     size_t pos = applied_string.find("{{{prompt}}}");
+    MLLM_RT_ASSERT(pos != std::string::npos);
     applied_string.replace(pos, 12, message.prompt);
🤖 Prompt for AI Agents
In `@mllm/models/qwen3_moe/tokenization_qwen3_moe.hpp` around lines 236 - 241, In
convertMessage, the result of applied_string.find("{{{prompt}}}") (stored in
pos) isn't checked before calling applied_string.replace, which can pass npos
and cause undefined behavior; update convertMessage to verify pos !=
std::string::npos (or otherwise detect a malformed
Qwen3Message::message_template) before calling replace and handle the error path
(e.g., log/throw or fall back to appending the prompt) so replace is only called
with a valid position.

Collaborator

@oreomaker oreomaker left a comment

Is this a runnable demo? The float16 support seems unsolved. Can you explain it more?

@HayzelHan
Copy link
Contributor Author

Is this a runnable demo? The float16 support seems unsolved. Can you explain it more?

Yes, it is a runnable demo. I tested it successfully with Qwen3-30B-A3B using Q4_K/Q6_K quantization (the quantization file is examples/qwen3_moe/quant_cfg_30B_q4_k.json).

Regarding FP16: sorry for the incomplete elementwise/reduce operator implementation. I prioritized FP32 and int types since:

  • Highway library doesn't directly support FP16 operations
  • The current Q4_K/Q6_K models don't require FP16

Happy to complete the missing types in a follow-up PR if you'd like!

@oreomaker oreomaker merged commit 4fd29d0 into UbiquitousLearning:main Feb 17, 2026
4 checks passed