diff --git a/tools/mllm-llm-benchmark/models/All.hpp b/tools/mllm-llm-benchmark/models/All.hpp
index 340fe6bf8..728363355 100644
--- a/tools/mllm-llm-benchmark/models/All.hpp
+++ b/tools/mllm-llm-benchmark/models/All.hpp
@@ -4,20 +4,35 @@
 #include <memory>
 #include <string>
+#include <algorithm>  // for std::transform
+#include <cctype>     // for std::tolower
 
-#include "Qwen3_W4A32_KAI.hpp"
 #include "BenchmarkTemplate.hpp"
+#include "Qwen3_W4A32_KAI.hpp"
+#include "Llama.hpp"
 
-std::shared_ptr<BenchmarkTemplate> createBenchmark(const std::string& model_name) {
+// Factory: maps a case-insensitive substring pattern of `model_name` to a
+// concrete benchmark implementation; returns nullptr when nothing matches.
+// `inline` because this definition lives in a header that may be included
+// from multiple translation units (avoids ODR violations).
+inline std::shared_ptr<BenchmarkTemplate> createBenchmark(const std::string& model_name) {
   auto tolower = [](const std::string& str) {
     std::string result = str;
-    std::transform(result.begin(), result.end(), result.begin(), ::tolower);
+    // NOTE: std::tolower expects unsigned char cast to avoid UB for negative char values.
+    std::transform(result.begin(), result.end(), result.begin(),
+                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
     return result;
   };
+
   auto normalized_model_name = tolower(model_name);
-  if (normalized_model_name.find("qwen3") != std::string::npos && normalized_model_name.find("w4a32") != std::string::npos
-      && normalized_model_name.find("kai") != std::string::npos) {
+
+  if (normalized_model_name.find("qwen3") != std::string::npos &&
+      normalized_model_name.find("w4a32") != std::string::npos &&
+      normalized_model_name.find("kai") != std::string::npos) {
+    // NOTE(review): template argument reconstructed from "Qwen3_W4A32_KAI.hpp" — confirm class name.
     return std::make_shared<Qwen3_W4A32_KAI_Benchmark>();
   }
+
+  // NOTE(review): the "tinyllama"/"tiny_llama" disjuncts are redundant — both
+  // strings already contain "llama", so the first test alone matches them.
+  if (normalized_model_name.find("llama") != std::string::npos ||
+      normalized_model_name.find("tinyllama") != std::string::npos ||
+      normalized_model_name.find("tiny_llama") != std::string::npos) {
+    return std::make_shared<Llama_Benchmark>();
+  }
+
   return nullptr;
 }
diff --git a/tools/mllm-llm-benchmark/models/Llama.hpp b/tools/mllm-llm-benchmark/models/Llama.hpp
new file mode 100644
index 000000000..f313922ef
--- /dev/null
+++ b/tools/mllm-llm-benchmark/models/Llama.hpp
@@ -0,0 +1,134 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+#pragma once
+
+#include <chrono>
+#include <memory>
+#include <string>
+
+#include "BenchmarkTemplate.hpp"
+
+// NOTE(review): angle-bracket include paths were lost in extraction; the three
+// mllm headers below are reconstructed from the API used in this file
+// (LlamaConfig / LlamaForCausalLM / mllm::load) — confirm against the repo.
+#include <mllm/mllm.hpp>
+#include <mllm/models/llama/configuration_llama.hpp>
+#include <mllm/models/llama/modeling_llama.hpp>
+
+// Benchmark driver for LLaMA-family models (e.g. TinyLlama) on the mllm
+// runtime. Measures TTFT, prefill speed and decode speed in run().
+class Llama_Benchmark final : public BenchmarkTemplate {
+ public:
+  // cfg_path: model config file; model_path: .mllm weight file;
+  // cache_length: KV-cache upper bound (values <= 0 keep the config default).
+  void init(const std::string& cfg_path, const std::string& model_path, int32_t cache_length) override {
+    // NOTE(review): template arguments on make_unique were lost in extraction;
+    // LlamaConfig / LlamaForCausalLM reconstructed from the clear() TODO below — confirm.
+    cfg_ = std::make_unique<mllm::models::llama::LlamaConfig>(cfg_path);
+
+    // LLaMA config uses max_position_embeddings as KV-cache upper bound
+    if (cache_length > 0) {
+      cfg_->max_position_embeddings = cache_length;
+    }
+
+    model_ = std::make_unique<mllm::models::llama::LlamaForCausalLM>("", *cfg_);
+
+    // NOTE:
+    // tinyllama-fp32.mllm used in examples is a V1 parameter file.
+    // Loading it as V2 will assert on magic number mismatch.
+    // We keep V1-only here to make the benchmark runnable; V2 support can be added later
+    // once we have either:
+    //   (1) a reliable file-version probe, or
+    //   (2) a CLI flag to select model file version.
+    auto param = mllm::load(model_path, mllm::ModelFileVersion::kV1);
+    model_->load(param);
+
+    mllm::print("Model initialized successfully");
+  }
+
+  // Dumps the loaded config to stdout; no-op when init() has not run.
+  void printModelInfo() override {
+    if (!cfg_) return;
+    mllm::print("========== Model Information ==========");
+    mllm::print("Model Type         : LLaMA / TinyLlama");
+    mllm::print("Hidden Size        :", cfg_->hidden_size);
+    mllm::print("Num Layers         :", cfg_->num_hidden_layers);
+    mllm::print("Num Heads          :", cfg_->num_attention_heads);
+    mllm::print("Num KV Heads       :", cfg_->num_key_value_heads);
+    // NOTE: Defensive guard (shouldn't happen with valid configs, but keeps benchmark robust).
+    int32_t head_dim = (cfg_->num_attention_heads > 0)
+                           ? (cfg_->hidden_size / cfg_->num_attention_heads)
+                           : 0;
+    mllm::print("Head Dim           :", head_dim);
+    mllm::print("Intermediate Size  :", cfg_->intermediate_size);
+    mllm::print("Vocab Size         :", cfg_->vocab_size);
+    mllm::print("Max Pos Embeddings :", cfg_->max_position_embeddings);
+    mllm::print("=======================================");
+  }
+
+  // Small fixed-size generation to warm caches/JIT paths before timing.
+  void warmup() override {
+    if (!model_) return;
+
+    const int32_t warmup_length = 8;  // prompt tokens
+    const int32_t warmup_gen = 4;     // generated tokens
+
+    auto input_ids = mllm::Tensor::empty({1, warmup_length}, mllm::kInt64, mllm::kCPU)
+                         .setMemType(mllm::kNormal)
+                         .alloc();
+    // kInt64 tensor => element type int64_t (ptr<> argument lost in extraction).
+    auto ptr = input_ids.ptr<int64_t>();
+    for (int i = 0; i < warmup_length; ++i) ptr[i] = 1;
+
+    mllm::models::ARGenerationOutputPast inputs;
+    inputs["sequence"] = input_ids;
+
+    mllm::models::ARGenerationArgs args;
+    args["max_length"] = mllm::AnyValue((int)warmup_gen);
+    args["do_sample"] = mllm::AnyValue(false);
+
+    model_->generate(inputs, args);
+    mllm::print("Warmup completed");
+  }
+
+  void clear() override {
+    // TODO: expose a public KV-cache reset API for LlamaForCausalLM (if needed).
+    // For now, keep it as no-op to minimize API changes in PR1.
+  }
+
+  // Runs one prefill of `pp` synthetic tokens followed by `tg` decode steps.
+  // TTFT = time to first streamed token; decode throughput excludes that
+  // first token (its latency is attributed to prefill).
+  BenchmarkTemplateResult run(int32_t pp, int32_t tg) override {
+    if (!model_) return {0.f, 0.f, 0.f};
+
+    auto input_ids = mllm::Tensor::empty({1, pp}, mllm::kInt64, mllm::kCPU)
+                         .setMemType(mllm::kNormal)
+                         .alloc();
+    auto ptr = input_ids.ptr<int64_t>();
+    // Synthetic prompt: cycling token ids 1..100 (content is irrelevant to timing).
+    for (int i = 0; i < pp; ++i) ptr[i] = 1 + (i % 100);
+
+    mllm::models::ARGenerationOutputPast inputs;
+    inputs["sequence"] = input_ids;
+
+    mllm::models::ARGenerationArgs args;
+    args["max_length"] = mllm::AnyValue((int)tg);
+    args["do_sample"] = mllm::AnyValue(false);
+
+    auto prefill_start = std::chrono::high_resolution_clock::now();
+    auto decode_start = prefill_start;
+    auto decode_end = prefill_start;
+
+    bool first_token = true;
+    int token_count = 0;
+
+    model_->streamGenerate(inputs, args, [&](int64_t /*token_id*/) {
+      if (first_token) {
+        decode_start = std::chrono::high_resolution_clock::now();
+        first_token = false;
+      }
+      token_count++;
+      decode_end = std::chrono::high_resolution_clock::now();
+    });
+
+    auto prefill_us = std::chrono::duration_cast<std::chrono::microseconds>(decode_start - prefill_start).count();
+    auto decode_us = std::chrono::duration_cast<std::chrono::microseconds>(decode_end - decode_start).count();
+
+    BenchmarkTemplateResult r;
+    r.ttft = prefill_us / 1000.0f;
+    r.prefill_speed = (prefill_us > 0) ? (static_cast<float>(pp) / prefill_us) * 1e6f : 0.f;
+    // NOTE: decode_us is measured from first token timestamp; exclude that first token from decode throughput.
+    int decode_tokens = (token_count > 0) ? (token_count - 1) : 0;
+    r.decode_speed = (decode_us > 0 && decode_tokens > 0)
+                         ? (static_cast<float>(decode_tokens) / decode_us) * 1e6f
+                         : 0.f;
+    return r;
+  }
+
+ private:
+  // NOTE(review): unique_ptr element types reconstructed (lost in extraction) — confirm.
+  std::unique_ptr<mllm::models::llama::LlamaConfig> cfg_;
+  std::unique_ptr<mllm::models::llama::LlamaForCausalLM> model_;
+};