-
Notifications
You must be signed in to change notification settings - Fork 175
tools/mllm-llm-benchmark: add llama benchmark template #617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
huangzhenhua111
wants to merge
2
commits into
UbiquitousLearning:main
from
huangzhenhua111:fix/llm-benchmark-mv
Closed
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| // Copyright (c) MLLM Team. | ||
| // Licensed under the MIT License. | ||
| #pragma once | ||
|
|
||
| #include <memory> | ||
| #include <chrono> | ||
| #include <string> | ||
|
|
||
| #include "BenchmarkTemplate.hpp" | ||
|
|
||
| #include <mllm/mllm.hpp> | ||
| #include <mllm/models/llama/modeling_llama.hpp> | ||
| #include <mllm/models/llama/configuration_llama.hpp> | ||
|
|
||
| class Llama_Benchmark final : public BenchmarkTemplate { | ||
| public: | ||
| void init(const std::string& cfg_path, const std::string& model_path, int32_t cache_length) override { | ||
| cfg_ = std::make_unique<mllm::models::llama::LLaMAConfig>(cfg_path); | ||
|
|
||
| // LLaMA config uses max_position_embeddings as KV-cache upper bound | ||
| if (cache_length > 0) { | ||
| cfg_->max_position_embeddings = cache_length; | ||
| } | ||
|
|
||
| model_ = std::make_unique<mllm::models::llama::LlamaForCausalLM>("", *cfg_); | ||
|
|
||
| // NOTE: | ||
| // tinyllama-fp32.mllm used in examples is a V1 parameter file. | ||
| // Loading it as V2 will assert on magic number mismatch. | ||
| // We keep V1-only here to make the benchmark runnable; V2 support can be added later | ||
| // once we have either: | ||
| // (1) a reliable file-version probe, or | ||
| // (2) a CLI flag to select model file version. | ||
| auto param = mllm::load(model_path, mllm::ModelFileVersion::kV1); | ||
| model_->load(param); | ||
|
|
||
| mllm::print("Model initialized successfully"); | ||
| } | ||
|
|
||
| void printModelInfo() override { | ||
| if (!cfg_) return; | ||
| mllm::print("========== Model Information =========="); | ||
| mllm::print("Model Type : LLaMA / TinyLlama"); | ||
| mllm::print("Hidden Size :", cfg_->hidden_size); | ||
| mllm::print("Num Layers :", cfg_->num_hidden_layers); | ||
| mllm::print("Num Heads :", cfg_->num_attention_heads); | ||
| mllm::print("Num KV Heads :", cfg_->num_key_value_heads); | ||
| // NOTE: Defensive guard (shouldn't happen with valid configs, but keeps benchmark robust). | ||
| int32_t head_dim = (cfg_->num_attention_heads > 0) ? (cfg_->hidden_size / cfg_->num_attention_heads) : 0; | ||
| mllm::print("Head Dim :", head_dim); | ||
| mllm::print("Intermediate Size :", cfg_->intermediate_size); | ||
| mllm::print("Vocab Size :", cfg_->vocab_size); | ||
| mllm::print("Max Pos Embeddings :", cfg_->max_position_embeddings); | ||
| mllm::print("======================================="); | ||
| } | ||
|
|
||
| void warmup() override { | ||
| if (!model_) return; | ||
|
|
||
| const int32_t warmup_length = 8; | ||
| const int32_t warmup_gen = 4; | ||
|
|
||
| auto input_ids = mllm::Tensor::empty({1, warmup_length}, mllm::kInt64, mllm::kCPU) | ||
| .setMemType(mllm::kNormal) | ||
| .alloc(); | ||
| auto ptr = input_ids.ptr<mllm::mllm_int64_t>(); | ||
| for (int i = 0; i < warmup_length; ++i) ptr[i] = 1; | ||
|
|
||
| mllm::models::ARGenerationOutputPast inputs; | ||
| inputs["sequence"] = input_ids; | ||
|
|
||
| mllm::models::ARGenerationArgs args; | ||
| args["max_length"] = mllm::AnyValue((int)warmup_gen); | ||
| args["do_sample"] = mllm::AnyValue(false); | ||
|
|
||
| model_->generate(inputs, args); | ||
| mllm::print("Warmup completed"); | ||
| } | ||
|
|
||
| void clear() override { | ||
| // TODO: expose a public KV-cache reset API for LlamaForCausalLM (if needed). | ||
| // For now, keep it as no-op to minimize API changes in PR1. | ||
| } | ||
|
|
||
| BenchmarkTemplateResult run(int32_t pp, int32_t tg) override { | ||
| if (!model_) return {0.f, 0.f, 0.f}; | ||
|
|
||
| auto input_ids = mllm::Tensor::empty({1, pp}, mllm::kInt64, mllm::kCPU) | ||
| .setMemType(mllm::kNormal) | ||
| .alloc(); | ||
| auto ptr = input_ids.ptr<mllm::mllm_int64_t>(); | ||
| for (int i = 0; i < pp; ++i) ptr[i] = 1 + (i % 100); | ||
|
|
||
| mllm::models::ARGenerationOutputPast inputs; | ||
| inputs["sequence"] = input_ids; | ||
|
|
||
| mllm::models::ARGenerationArgs args; | ||
| args["max_length"] = mllm::AnyValue((int)tg); | ||
| args["do_sample"] = mllm::AnyValue(false); | ||
|
|
||
| auto prefill_start = std::chrono::high_resolution_clock::now(); | ||
| auto decode_start = prefill_start; | ||
| auto decode_end = prefill_start; | ||
|
|
||
| bool first_token = true; | ||
| int token_count = 0; | ||
|
|
||
| model_->streamGenerate(inputs, args, [&](int64_t /*token_id*/) { | ||
| if (first_token) { | ||
| decode_start = std::chrono::high_resolution_clock::now(); | ||
| first_token = false; | ||
| } | ||
| token_count++; | ||
| decode_end = std::chrono::high_resolution_clock::now(); | ||
| }); | ||
|
|
||
| auto prefill_us = std::chrono::duration_cast<std::chrono::microseconds>(decode_start - prefill_start).count(); | ||
| auto decode_us = std::chrono::duration_cast<std::chrono::microseconds>(decode_end - decode_start).count(); | ||
|
|
||
| BenchmarkTemplateResult r; | ||
| r.ttft = prefill_us / 1000.0f; | ||
| r.prefill_speed = (prefill_us > 0) ? (static_cast<float>(pp) / prefill_us) * 1e6f : 0.f; | ||
| // NOTE: decode_us is measured from first token timestamp; exclude that first token from decode throughput. | ||
| int decode_tokens = (token_count > 0) ? (token_count - 1) : 0; | ||
| r.decode_speed = (decode_us > 0 && decode_tokens > 0) | ||
| ? (static_cast<float>(decode_tokens) / decode_us) * 1e6f | ||
| : 0.f; | ||
| return r; | ||
| } | ||
|
|
||
| private: | ||
| std::unique_ptr<mllm::models::llama::LLaMAConfig> cfg_; | ||
| std::unique_ptr<mllm::models::llama::LlamaForCausalLM> model_; | ||
| }; |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.