From 1cb743941e5fcd8c49a4544b02a3d9114dbcbd75 Mon Sep 17 00:00:00 2001 From: huangzhenhua111 Date: Fri, 30 Jan 2026 17:44:08 +0800 Subject: [PATCH 1/3] tools/mllm-llm-benchmark: add llama benchmark template --- tools/mllm-llm-benchmark/models/All.hpp | 25 +++- tools/mllm-llm-benchmark/models/Llama.hpp | 134 ++++++++++++++++++++++ 2 files changed, 154 insertions(+), 5 deletions(-) create mode 100644 tools/mllm-llm-benchmark/models/Llama.hpp diff --git a/tools/mllm-llm-benchmark/models/All.hpp b/tools/mllm-llm-benchmark/models/All.hpp index 340fe6bf8..728363355 100644 --- a/tools/mllm-llm-benchmark/models/All.hpp +++ b/tools/mllm-llm-benchmark/models/All.hpp @@ -4,20 +4,35 @@ #include #include +#include +#include // for std::tolower -#include "Qwen3_W4A32_KAI.hpp" #include "BenchmarkTemplate.hpp" +#include "Qwen3_W4A32_KAI.hpp" +#include "Llama.hpp" -std::shared_ptr createBenchmark(const std::string& model_name) { +inline std::shared_ptr createBenchmark(const std::string& model_name) { auto tolower = [](const std::string& str) { std::string result = str; - std::transform(result.begin(), result.end(), result.begin(), ::tolower); + // NOTE: std::tolower expects unsigned char cast to avoid UB for negative char values. + std::transform(result.begin(), result.end(), result.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); return result; }; + auto normalized_model_name = tolower(model_name); - if (normalized_model_name.find("qwen3") != std::string::npos && normalized_model_name.find("w4a32") != std::string::npos - && normalized_model_name.find("kai") != std::string::npos) { + + if (normalized_model_name.find("qwen3") != std::string::npos && + normalized_model_name.find("w4a32") != std::string::npos && + normalized_model_name.find("kai") != std::string::npos) { return std::make_shared(); } + + if (normalized_model_name.find("llama") != std::string::npos || + normalized_model_name.find("tinyllama") != std::string::npos || + normalized_model_name.find("tiny_llama") != std::string::npos) { + return std::make_shared(); + } + return nullptr; } diff --git a/tools/mllm-llm-benchmark/models/Llama.hpp b/tools/mllm-llm-benchmark/models/Llama.hpp new file mode 100644 index 000000000..f313922ef --- /dev/null +++ b/tools/mllm-llm-benchmark/models/Llama.hpp @@ -0,0 +1,134 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include + +#include "BenchmarkTemplate.hpp" + +#include +#include +#include + +class Llama_Benchmark final : public BenchmarkTemplate { + public: + void init(const std::string& cfg_path, const std::string& model_path, int32_t cache_length) override { + cfg_ = std::make_unique(cfg_path); + + // LLaMA config uses max_position_embeddings as KV-cache upper bound + if (cache_length > 0) { + cfg_->max_position_embeddings = cache_length; + } + + model_ = std::make_unique("", *cfg_); + + // NOTE: + // tinyllama-fp32.mllm used in examples is a V1 parameter file. + // Loading it as V2 will assert on magic number mismatch. + // We keep V1-only here to make the benchmark runnable; V2 support can be added later + // once we have either: + // (1) a reliable file-version probe, or + // (2) a CLI flag to select model file version. + auto param = mllm::load(model_path, mllm::ModelFileVersion::kV1); + model_->load(param); + + mllm::print("Model initialized successfully"); + } + + void printModelInfo() override { + if (!cfg_) return; + mllm::print("========== Model Information =========="); + mllm::print("Model Type : LLaMA / TinyLlama"); + mllm::print("Hidden Size :", cfg_->hidden_size); + mllm::print("Num Layers :", cfg_->num_hidden_layers); + mllm::print("Num Heads :", cfg_->num_attention_heads); + mllm::print("Num KV Heads :", cfg_->num_key_value_heads); + // NOTE: Defensive guard (shouldn't happen with valid configs, but keeps benchmark robust). + int32_t head_dim = (cfg_->num_attention_heads > 0) ? (cfg_->hidden_size / cfg_->num_attention_heads) : 0; + mllm::print("Head Dim :", head_dim); + mllm::print("Intermediate Size :", cfg_->intermediate_size); + mllm::print("Vocab Size :", cfg_->vocab_size); + mllm::print("Max Pos Embeddings :", cfg_->max_position_embeddings); + mllm::print("======================================="); + } + + void warmup() override { + if (!model_) return; + + const int32_t warmup_length = 8; + const int32_t warmup_gen = 4; + + auto input_ids = mllm::Tensor::empty({1, warmup_length}, mllm::kInt64, mllm::kCPU) + .setMemType(mllm::kNormal) + .alloc(); + auto ptr = input_ids.ptr(); + for (int i = 0; i < warmup_length; ++i) ptr[i] = 1; + + mllm::models::ARGenerationOutputPast inputs; + inputs["sequence"] = input_ids; + + mllm::models::ARGenerationArgs args; + args["max_length"] = mllm::AnyValue((int)warmup_gen); + args["do_sample"] = mllm::AnyValue(false); + + model_->generate(inputs, args); + mllm::print("Warmup completed"); + } + + void clear() override { + // TODO: expose a public KV-cache reset API for LlamaForCausalLM (if needed). + // For now, keep it as no-op to minimize API changes in PR1. + } + + BenchmarkTemplateResult run(int32_t pp, int32_t tg) override { + if (!model_) return {0.f, 0.f, 0.f}; + + auto input_ids = mllm::Tensor::empty({1, pp}, mllm::kInt64, mllm::kCPU) + .setMemType(mllm::kNormal) + .alloc(); + auto ptr = input_ids.ptr(); + for (int i = 0; i < pp; ++i) ptr[i] = 1 + (i % 100); + + mllm::models::ARGenerationOutputPast inputs; + inputs["sequence"] = input_ids; + + mllm::models::ARGenerationArgs args; + args["max_length"] = mllm::AnyValue((int)tg); + args["do_sample"] = mllm::AnyValue(false); + + auto prefill_start = std::chrono::high_resolution_clock::now(); + auto decode_start = prefill_start; + auto decode_end = prefill_start; + + bool first_token = true; + int token_count = 0; + + model_->streamGenerate(inputs, args, [&](int64_t /*token_id*/) { + if (first_token) { + decode_start = std::chrono::high_resolution_clock::now(); + first_token = false; + } + token_count++; + decode_end = std::chrono::high_resolution_clock::now(); + }); + + auto prefill_us = std::chrono::duration_cast(decode_start - prefill_start).count(); + auto decode_us = std::chrono::duration_cast(decode_end - decode_start).count(); + + BenchmarkTemplateResult r; + r.ttft = prefill_us / 1000.0f; + r.prefill_speed = (prefill_us > 0) ? (static_cast(pp) / prefill_us) * 1e6f : 0.f; + // NOTE: decode_us is measured from first token timestamp; exclude that first token from decode throughput. + int decode_tokens = (token_count > 0) ? (token_count - 1) : 0; + r.decode_speed = (decode_us > 0 && decode_tokens > 0) + ? (static_cast(decode_tokens) / decode_us) * 1e6f + : 0.f; + return r; + } + + private: + std::unique_ptr cfg_; + std::unique_ptr model_; +}; From 5444b15ea56787fd46fcc929d4b398ef6c688b8d Mon Sep 17 00:00:00 2001 From: huangzhenhua111 Date: Sat, 31 Jan 2026 11:52:57 +0800 Subject: [PATCH 2/3] tools/mllm-llm-benchmark: add csv output and configurable runs --- tools/mllm-llm-benchmark/main.cpp | 112 +++++++++++++++++----- tools/mllm-llm-benchmark/models/Llama.hpp | 4 + 2 files changed, 94 insertions(+), 22 deletions(-) diff --git a/tools/mllm-llm-benchmark/main.cpp b/tools/mllm-llm-benchmark/main.cpp index af275a2e6..382355a73 100644 --- a/tools/mllm-llm-benchmark/main.cpp +++ b/tools/mllm-llm-benchmark/main.cpp @@ -1,10 +1,13 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. +#include +#include #include #include #include #include +#include // For std::transform #include #include @@ -16,6 +19,14 @@ #include "models/All.hpp" +#ifndef MLLM_GIT_COMMIT_HASH +#define MLLM_GIT_COMMIT_HASH unknown +#endif + +#define STR_HELPER(x) #x +#define STR(x) STR_HELPER(x) + + MLLM_MAIN({ auto& help = mllm::Argparse::add("-h|--help").help("Show help message"); auto& model_name = mllm::Argparse::add("-n|--model_name").help("Model name"); @@ -25,8 +36,19 @@ MLLM_MAIN({ auto& pp = mllm::Argparse::add("-pp|--prompt_length").help("Prompt length"); auto& tg = mllm::Argparse::add("-tg|--test_generation_length").help("Test Generation length"); auto& cache_length = mllm::Argparse::add("-cl|--cache_length").help("Cache length"); + + // New CLI Arguments + auto& runs = mllm::Argparse::add("-r|--runs").help("Number of benchmark runs").def(3); + auto& cooldown_s = mllm::Argparse::add("-cs|--cooldown_s").help("Cooldown time between runs in seconds").def(5); + auto& output_csv = mllm::Argparse::add("-oc|--output_csv").help("Output results to a CSV file").def(""); + auto& schema_version = mllm::Argparse::add("-sv|--schema_version").help("Schema version for output format").def(1); + auto& kv_dtype_bytes = mllm::Argparse::add("-kv|--kv_dtype_bytes").help("KV cache data type bytes (1: int8, 2: fp16, 4: fp32)").def(4); + mllm::Argparse::parse(argc, argv); + mllm::Context::instance().setCpuOpThreads(num_threads.get()); + mllm::setMaximumNumThreads((uint32_t)num_threads.get()); + // Print Build Version mllm::print("MLLM Build Version :", STRINGIFY(MLLM_GIT_COMMIT_HASH)); @@ -58,6 +80,25 @@ MLLM_MAIN({ auto benchmark = createBenchmark(model_name.get()); MLLM_RT_ASSERT(benchmark != nullptr); + + // Validate runs early to avoid huge reserve() when negative values cast to size_t. + int R = runs.get(); + if (R <= 0) { + mllm::print("[ERROR] --runs must be > 0, got:", R); + return 1; + } + + // Open file stream + std::ofstream csv_file; + if (!output_csv.get().empty()) { + csv_file.open(output_csv.get()); + if (!csv_file.is_open()) { + mllm::print("[ERROR] Failed to open --output_csv:", output_csv.get()); + return 1; + } + csv_file << "schema_version,git_commit,arch,model_name,pp,tg,ttft_ms,prefill_speed,decode_speed,prefill_ms,decode_ms_per_tok,kv_est_bytes_pp,kv_est_bytes_final\n"; + } + // Print Model Info mllm::print("Model Info"); benchmark->init(config_path.get(), model_path.get(), cache_length.get()); @@ -92,7 +133,7 @@ MLLM_MAIN({ for (size_t i = 0; i < pp_values.size(); ++i) { pp_tg_pairs.emplace_back(pp_values[i], tg_values[i]); } } - // Actual run for 3 turns and gives avg results. Each turn will sleep for 5 seconds to let the SoC or GPU/NPU cool down. + // Actual run for configurable number of turns mllm::print("\n========================================"); mllm::print("Starting Benchmark Tests"); mllm::print("========================================\n"); @@ -106,15 +147,12 @@ MLLM_MAIN({ // Storage for results std::vector results; - results.reserve(3); + results.reserve(static_cast(R)); - for (int i = 0; i < 3; ++i) { - mllm::print(" Run", i + 1, "of 3..."); + for (int i = 0; i < R; ++i) { + mllm::print(" Run", i + 1, "of", R, "..."); - // Clear cache before each run benchmark->clear(); - - // Run benchmark auto result = benchmark->run(pp, tg); results.push_back(result); @@ -122,14 +160,20 @@ MLLM_MAIN({ mllm::print(" Prefill Speed:", result.prefill_speed, "tokens/s"); mllm::print(" Decode Speed :", result.decode_speed, "tokens/s"); - // Sleep for 5 seconds between runs to cool down - if (i < 2) { - mllm::print(" Cooling down for 5 seconds..."); - std::this_thread::sleep_for(std::chrono::seconds(5)); + float prefill_ms = (result.prefill_speed > 0.0f) ? (pp / result.prefill_speed) * 1000.0f : 0.0f; + float decode_ms_per_tok = (result.decode_speed > 0.0f) ? (1.0f / result.decode_speed) * 1000.0f : 0.0f; + mllm::print(" Prefill Latency :", prefill_ms, "ms"); + mllm::print(" Decode Latency :", decode_ms_per_tok, "ms"); + + int cool = cooldown_s.get(); + if (i + 1 < R && cool > 0) { + mllm::print(" Cooling down for", cool, "seconds..."); + std::this_thread::sleep_for(std::chrono::seconds(cool)); } } // Calculate average results + float denom = (R > 0) ? static_cast(R) : 1.0f; float avg_ttft = 0.0f; float avg_prefill_speed = 0.0f; float avg_decode_speed = 0.0f; @@ -140,20 +184,44 @@ MLLM_MAIN({ avg_decode_speed += result.decode_speed; } - avg_ttft /= 3.0f; - avg_prefill_speed /= 3.0f; - avg_decode_speed /= 3.0f; - - // Print average results - mllm::print("\n========== Average Results =========="); - mllm::print("Configuration: PP=", pp, " TG=", tg); - mllm::print("Average TTFT :", avg_ttft, "ms"); - mllm::print("Average Prefill Speed:", avg_prefill_speed, "tokens/s"); - mllm::print("Average Decode Speed :", avg_decode_speed, "tokens/s"); - mllm::print("=====================================\n"); + avg_ttft /= denom; + avg_prefill_speed /= denom; + avg_decode_speed /= denom; + + float avg_prefill_ms = (avg_prefill_speed > 0.0f) ? (pp / avg_prefill_speed) * 1000.0f : 0.0f; + float avg_decode_ms_per_tok = (avg_decode_speed > 0.0f) ? (1.0f / avg_decode_speed) * 1000.0f : 0.0f; + + // Rough KV cache estimate (bytes) + double kv_est_bytes_pp = 0.0; + double kv_est_bytes_final = 0.0; + + // Prepare one line output (avg) + std::stringstream ss; + ss << schema_version.get() << "," + << STRINGIFY(MLLM_GIT_COMMIT_HASH) << "," + << mllm::cpu::CURRENT_ARCH_STRING << "," + << model_name.get() << "," + << pp << "," + << tg << "," + << avg_ttft << "," + << avg_prefill_speed << "," + << avg_decode_speed << "," + << avg_prefill_ms << "," + << avg_decode_ms_per_tok << "," + << kv_est_bytes_pp << "," + << kv_est_bytes_final; + + if (csv_file.is_open()) { + csv_file << ss.str() << std::endl; + } } mllm::print("\n========================================"); mllm::print("Benchmark Tests Completed"); mllm::print("========================================"); + + //close file stream + if (csv_file.is_open()) { + csv_file.close(); + } }) diff --git a/tools/mllm-llm-benchmark/models/Llama.hpp b/tools/mllm-llm-benchmark/models/Llama.hpp index f313922ef..74128cb18 100644 --- a/tools/mllm-llm-benchmark/models/Llama.hpp +++ b/tools/mllm-llm-benchmark/models/Llama.hpp @@ -83,6 +83,10 @@ class Llama_Benchmark final : public BenchmarkTemplate { } BenchmarkTemplateResult run(int32_t pp, int32_t tg) override { + if (pp <= 0 || tg < 0) { + mllm::print("[ERROR] invalid pp/tg:", pp, tg); + return {0.f, 0.f, 0.f}; + } if (!model_) return {0.f, 0.f, 0.f}; auto input_ids = mllm::Tensor::empty({1, pp}, mllm::kInt64, mllm::kCPU) From b17c8f55eea881bc0560e1b3de53c832938c1d13 Mon Sep 17 00:00:00 2001 From: huangzhenhua111 Date: Sat, 31 Jan 2026 14:23:53 +0800 Subject: [PATCH 3/3] tools/mllm-llm-benchmark: estimate KV cache bytes from model config --- tools/mllm-llm-benchmark/main.cpp | 6 ++++++ .../mllm-llm-benchmark/models/BenchmarkTemplate.hpp | 12 ++++++++++++ tools/mllm-llm-benchmark/models/Llama.hpp | 9 +++++++++ 3 files changed, 27 insertions(+) diff --git a/tools/mllm-llm-benchmark/main.cpp b/tools/mllm-llm-benchmark/main.cpp index 382355a73..d832d0500 100644 --- a/tools/mllm-llm-benchmark/main.cpp +++ b/tools/mllm-llm-benchmark/main.cpp @@ -194,6 +194,12 @@ MLLM_MAIN({ // Rough KV cache estimate (bytes) double kv_est_bytes_pp = 0.0; double kv_est_bytes_final = 0.0; + if (auto info = benchmark->kvEstimateInfo(); info.has_value()) { + const int32_t bytes_per = kv_dtype_bytes.get(); // 1/2/4 + // LLaMA-like KV: 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes + kv_est_bytes_pp = 2.0 * info->num_layers * info->num_kv_heads * info->head_dim * (double)pp * bytes_per; + kv_est_bytes_final = 2.0 * info->num_layers * info->num_kv_heads * info->head_dim * (double)(pp + tg) * bytes_per; + } // Prepare one line output (avg) std::stringstream ss; diff --git a/tools/mllm-llm-benchmark/models/BenchmarkTemplate.hpp b/tools/mllm-llm-benchmark/models/BenchmarkTemplate.hpp index 4724a8ca8..aded345e8 100644 --- a/tools/mllm-llm-benchmark/models/BenchmarkTemplate.hpp +++ b/tools/mllm-llm-benchmark/models/BenchmarkTemplate.hpp @@ -3,6 +3,8 @@ #pragma once #include +#include +#include /** * @brief Benchmark result structure @@ -13,6 +15,12 @@ struct BenchmarkTemplateResult { float decode_speed; ///< Decode phase speed in tokens/s }; +struct KVCacheEstimateInfo { + int32_t num_layers = 0; + int32_t num_kv_heads = 0; + int32_t head_dim = 0; // hidden_size / num_attention_heads +}; + /** * @brief Base class for benchmark templates * @@ -58,4 +66,8 @@ class BenchmarkTemplate { * @return Test results */ virtual BenchmarkTemplateResult run(int32_t pp, int32_t tg) = 0; + + // Optional: provide info for KV cache size estimation. + // If a model does not support it, return std::nullopt. + virtual std::optional kvEstimateInfo() const { return std::nullopt; } }; diff --git a/tools/mllm-llm-benchmark/models/Llama.hpp b/tools/mllm-llm-benchmark/models/Llama.hpp index 74128cb18..07629fd72 100644 --- a/tools/mllm-llm-benchmark/models/Llama.hpp +++ b/tools/mllm-llm-benchmark/models/Llama.hpp @@ -14,6 +14,15 @@ class Llama_Benchmark final : public BenchmarkTemplate { public: + std::optional kvEstimateInfo() const override { + if (!cfg_) return std::nullopt; + KVCacheEstimateInfo info; + info.num_layers = cfg_->num_hidden_layers; + info.num_kv_heads = cfg_->num_key_value_heads; + info.head_dim = cfg_->hidden_size / cfg_->num_attention_heads; + return info; + } + void init(const std::string& cfg_path, const std::string& model_path, int32_t cache_length) override { cfg_ = std::make_unique(cfg_path);