From 190c7e78149fd037afba533e96126d8bb16d77f3 Mon Sep 17 00:00:00 2001 From: Sp0tless <104404122+Sp0tless@users.noreply.github.com> Date: Wed, 12 Nov 2025 16:14:08 +0800 Subject: [PATCH] feat: implement internlm2.5-1.8B-chat --- examples/CMakeLists.txt | 1 + examples/internlm2_5/CMakeLists.txt | 3 + examples/internlm2_5/config_1.8B.json | 29 ++ examples/internlm2_5/main.cpp | 67 ++++ .../internlm2/configuration_internlm2.hpp | 82 +++++ mllm/models/internlm2/modeling_internlm2.hpp | 324 ++++++++++++++++++ .../internlm2/tokenization_internlm2.hpp | 157 +++++++++ 7 files changed, 663 insertions(+) create mode 100644 examples/internlm2_5/CMakeLists.txt create mode 100644 examples/internlm2_5/config_1.8B.json create mode 100644 examples/internlm2_5/main.cpp create mode 100644 mllm/models/internlm2/configuration_internlm2.hpp create mode 100644 mllm/models/internlm2/modeling_internlm2.hpp create mode 100644 mllm/models/internlm2/tokenization_internlm2.hpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index abfb7375c..effc52c1a 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -8,6 +8,7 @@ add_subdirectory(qwen3) add_subdirectory(qwen3_service) add_subdirectory(deepseek_ocr) add_subdirectory(smollm3_3B) +add_subdirectory(internlm2_5) if(MLLM_BUILD_QNN_BACKEND) add_subdirectory(qwen_npu) diff --git a/examples/internlm2_5/CMakeLists.txt b/examples/internlm2_5/CMakeLists.txt new file mode 100644 index 000000000..383097ffa --- /dev/null +++ b/examples/internlm2_5/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(mllm-internlm2_5-chat-runner main.cpp) +target_link_libraries(mllm-internlm2_5-chat-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-internlm2_5-chat-runner PRIVATE ${MLLM_INCLUDE_DIR}) \ No newline at end of file diff --git a/examples/internlm2_5/config_1.8B.json b/examples/internlm2_5/config_1.8B.json new file mode 100644 index 000000000..400d17e54 --- /dev/null +++ b/examples/internlm2_5/config_1.8B.json @@ 
-0,0 +1,29 @@ +{ + "architectures": [ + "InternLM2ForCausalLM" + ], + "bias": false, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 32768, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "num_key_value_heads": 8, + "pad_token_id": 2, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 2.0, + "type": "dynamic" + }, + "rope_theta": 1000000, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.34.0", + "use_cache": true, + "vocab_size": 92544, + "linear_impl_type": "Default" +} \ No newline at end of file diff --git a/examples/internlm2_5/main.cpp b/examples/internlm2_5/main.cpp new file mode 100644 index 000000000..159ca7a17 --- /dev/null +++ b/examples/internlm2_5/main.cpp @@ -0,0 +1,67 @@ +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer JSON path").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + +#ifdef MLLM_PERFETTO_ENABLE + mllm::perf::start(); +#endif + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v2") { file_version = mllm::ModelFileVersion::kV2; } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + auto cfg = mllm::models::internlm2::InternLM2Config(config_path.get()); + auto tokenizer = mllm::models::internlm2::InternLM2Tokenizer(tokenizer_path.get()); + auto model = mllm::models::internlm2::InternLM2ForCausalLM(cfg); + + 
auto params = mllm::load(model_path.get(), file_version); + model.load(params); + + fmt::print("\n{:*^60}\n", " InternLM2.5 1.8B CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string prompt_text; + fmt::print("šŸ’¬ Prompt text (or 'exit/quit'): "); + std::getline(std::cin, prompt_text); + if (!(prompt_text == "exit" || prompt_text == "quit")) { + try { + fmt::print("šŸ”„ Processing...\n"); + mllm::models::internlm2::InternLM2Message prompt{prompt_text}; + auto inputs = tokenizer.convertMessage(prompt); + + fmt::print("\nšŸ¤– Response: "); + for (auto& step : model.chat(inputs)) { + auto token = tokenizer.detokenize(step.cur_token_id); + std::wcout << token << std::flush; + } + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nāŒ Error: {}\n{}\n", e.what(), std::string(60, '-')); } + model.perfSummary(); + } + +#ifdef MLLM_PERFETTO_ENABLE + mllm::perf::stop(); + mllm::perf::saveReport("internlm2_5.perf"); +#endif + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/mllm/models/internlm2/configuration_internlm2.hpp b/mllm/models/internlm2/configuration_internlm2.hpp new file mode 100644 index 000000000..179c0f7ec --- /dev/null +++ b/mllm/models/internlm2/configuration_internlm2.hpp @@ -0,0 +1,82 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::internlm2 { + +struct InternLM2Config : protected ConfigFile { + InternLM2Config() = default; + + explicit InternLM2Config(const std::string& file_path) : ConfigFile(file_path) { + auto& json = data(); + + if (json.contains("bias")) { bias = json["bias"].get(); } + if (json.contains("hidden_size")) { hidden_size = json["hidden_size"].get(); } + if (json.contains("intermediate_size")) { intermediate_size = json["intermediate_size"].get(); } + if (json.contains("num_hidden_layers")) { num_hidden_layers = json["num_hidden_layers"].get(); } + if (json.contains("num_attention_heads")) { num_attention_heads = json["num_attention_heads"].get(); } + if (json.contains("num_key_value_heads")) { + num_key_value_heads = json["num_key_value_heads"].get(); + } else { + num_key_value_heads = num_attention_heads; + } + if (json.contains("max_position_embeddings")) { max_position_embeddings = json["max_position_embeddings"].get(); } + if (json.contains("rms_norm_eps")) { rms_norm_eps = json["rms_norm_eps"].get(); } + if (json.contains("vocab_size")) { vocab_size = json["vocab_size"].get(); } + if (json.contains("rope_theta")) { rope_theta = json["rope_theta"].get(); } + if (json.contains("tie_word_embeddings")) { tie_word_embeddings = json["tie_word_embeddings"].get(); } + if (json.contains("use_cache")) { use_cache = json["use_cache"].get(); } + if (json.contains("pad_token_id")) { pad_token_id = json["pad_token_id"].get(); } + if (json.contains("bos_token_id")) { bos_token_id = json["bos_token_id"].get(); } + if (json.contains("eos_token_id")) { eos_token_id = json["eos_token_id"].get(); } + if (json.contains("initializer_range")) { initializer_range = json["initializer_range"].get(); } + + if (json.contains("rope_scaling")) { + const auto& scaling = json["rope_scaling"]; + if (scaling.contains("type")) { rope_scaling_type = 
scaling["type"].get(); } + if (scaling.contains("factor")) { rope_scaling_factor = scaling["factor"].get(); } + } + + if (json.contains("linear_impl_type")) { + linear_impl_type = aops::str2LinearImplTypes(json["linear_impl_type"].get()); + } + + head_dim = hidden_size / num_attention_heads; + max_cache_length = max_position_embeddings; + end_of_text_token_id = static_cast(eos_token_id); + } + + bool bias = false; + int32_t hidden_size = 4096; + int32_t intermediate_size = 11008; + int32_t num_hidden_layers = 32; + int32_t num_attention_heads = 32; + int32_t num_key_value_heads = 32; + int32_t max_position_embeddings = 2048; + int32_t max_cache_length = 2048; + int32_t head_dim = 128; + int32_t vocab_size = 32000; + float rms_norm_eps = 1e-6f; + float rope_theta = 10000.0f; + float rope_scaling_factor = 1.0f; + std::string rope_scaling_type; + + float initializer_range = 0.02f; + bool tie_word_embeddings = false; + bool use_cache = true; + + int32_t pad_token_id = 0; + int32_t bos_token_id = 1; + int32_t eos_token_id = 2; + int32_t end_of_text_token_id = 2; + + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::internlm2 diff --git a/mllm/models/internlm2/modeling_internlm2.hpp b/mllm/models/internlm2/modeling_internlm2.hpp new file mode 100644 index 000000000..c52110da0 --- /dev/null +++ b/mllm/models/internlm2/modeling_internlm2.hpp @@ -0,0 +1,324 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include + +#include "fmt/base.h" +#include "mllm/mllm.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/internlm2/configuration_internlm2.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/utils/Enumerate.hpp" + +namespace mllm::models::internlm2 { + +inline Tensor makeRoPEInvFreq(int output_dim, float rope_theta, float linear_scale = 1.0f) { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; ++i) { + auto base = 1.0f / std::pow(rope_theta, 2.0f * static_cast(i) / static_cast(output_dim)); + inv_freq_ptr[i] = base / linear_scale; + } + return inv_freq; +} + +inline std::pair makeRotaryPosEmbedding(Tensor& position_ids, const Tensor& inv_freq) { + auto batch_size = position_ids.shape()[0]; + auto seq_len = position_ids.shape()[1]; + auto inv_freq_len = inv_freq.shape()[0]; + auto dim = inv_freq_len * 2; + + auto freqs = Tensor::empty({batch_size, seq_len, inv_freq_len}, kFloat32, kCPU).alloc(); + auto freqs_ptr = freqs.ptr(); + auto position_ids_ptr = position_ids.ptr(); + auto inv_freq_ptr = inv_freq.ptr(); + + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + auto pos = position_ids_ptr[b * seq_len + s]; + for (int d = 0; d < inv_freq_len; ++d) { + freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d] = static_cast(pos) * inv_freq_ptr[d]; + } + } + } + + auto sin_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto cos_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto sin_ptr = sin_emb.ptr(); + auto cos_ptr = cos_emb.ptr(); + + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + for (int d = 0; d < inv_freq_len; ++d) { + auto freq = freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d]; + auto sin_val = 
std::sin(freq); + auto cos_val = std::cos(freq); + + sin_ptr[b * seq_len * dim + s * dim + d] = sin_val; + sin_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = sin_val; + cos_ptr[b * seq_len * dim + s * dim + d] = cos_val; + cos_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = cos_val; + } + } + } + + return {sin_emb, cos_emb}; +} + +class InternLM2MLP final : public nn::Module { + nn::Linear w1_; + nn::Linear w3_; + nn::Linear w2_; + nn::SiLU silu_; + + public: + InternLM2MLP() = default; + + InternLM2MLP(const std::string& name, const InternLM2Config& cfg) : nn::Module(name) { + w1_ = reg("w1", cfg.hidden_size, cfg.intermediate_size, cfg.bias, cfg.linear_impl_type); + silu_ = reg("act"); + w3_ = reg("w3", cfg.hidden_size, cfg.intermediate_size, cfg.bias, cfg.linear_impl_type); + w2_ = reg("w2", cfg.intermediate_size, cfg.hidden_size, cfg.bias, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& /*args*/) override { + auto x = w1_(inputs[0]); + x = silu_(x); + auto y = w3_(inputs[0]); + x = x * y; + x = w2_(x); + return {x}; + } +}; + +class InternLM2Attention final : public nn::Module { + nn::Linear wqkv_; + nn::Linear o_proj_; + nn::RoPE q_rope_; + nn::RoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_ = 0; + int head_dim_ = 0; + int num_attention_heads_ = 0; + int num_key_value_heads_ = 0; + int num_key_value_groups_ = 0; + + public: + InternLM2Attention() = default; + + InternLM2Attention(const std::string& name, const InternLM2Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = hidden_size_ / num_attention_heads_; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + auto out_features = (num_attention_heads_ + 2 * num_key_value_heads_) * head_dim_; + wqkv_ = reg("wqkv", hidden_size_, out_features, cfg.bias, 
cfg.linear_impl_type); + o_proj_ = reg("wo", num_attention_heads_ * head_dim_, hidden_size_, cfg.bias, cfg.linear_impl_type); + + q_rope_ = reg("q_rope", cfg.rope_theta, cfg.max_position_embeddings); + k_rope_ = reg("k_rope", cfg.rope_theta, cfg.max_position_embeddings); + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + auto qkv = wqkv_(x); + + int B = x.shape()[0]; + int S = x.shape()[1]; + + qkv = qkv.view({B, S, num_key_value_heads_, num_key_value_groups_ + 2, head_dim_}); + + // now we have to contiguous before reshape, this is why the model is not efficient + auto query_blocks = qkv[{kAll, kAll, kAll, {0, num_key_value_groups_}, kAll}].contiguous(); + query_blocks = query_blocks.view({B, S, num_key_value_heads_ * num_key_value_groups_, head_dim_}); + auto query_states = query_blocks.permute({0, 2, 1, 3}); // [B, num_heads, S, D] + + auto key_blocks = qkv[{kAll, kAll, kAll, {num_key_value_groups_, num_key_value_groups_ + 1}, kAll}].contiguous().squeeze(3); + auto key_states = key_blocks.permute({0, 2, 1, 3}); // [B, num_kv_heads, S, D] + + auto value_blocks = + qkv[{kAll, kAll, kAll, {num_key_value_groups_ + 1, num_key_value_groups_ + 2}, kAll}].contiguous().squeeze(3); + auto value_states = value_blocks.permute({0, 2, 1, 3}); + + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + auto [cached_k, cached_v] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = cached_k; + value_states = cached_v; + + auto scale = 1.f / std::sqrt(static_cast(head_dim_)); + auto attn = nn::functional::matmul(query_states, key_states, false, true) * scale; + attn = mask_(attn); + attn = softmax_(attn); + + auto output 
= nn::functional::matmul(attn, value_states); + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + + return {output}; + } + + int layer_idx_ = 0; +}; + +class InternLM2Decoder final : public nn::Module { + public: + InternLM2Attention self_attn_; + InternLM2MLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + public: + InternLM2Decoder() = default; + + InternLM2Decoder(const std::string& name, const InternLM2Config& cfg) : nn::Module(name) { + self_attn_ = reg("attention", cfg); + mlp_ = reg("feed_forward", cfg); + input_layer_norm_ = reg("attention_norm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("ffn_norm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class InternLM2Model final : public nn::Module { + nn::Embedding embedding_; + nn::ModuleList layers_; + nn::RMSNorm norm_; + + public: + InternLM2Model() = default; + + InternLM2Model(const std::string& name, const InternLM2Config& cfg) : nn::Module(name) { + embedding_ = reg("tok_embeddings", cfg.vocab_size, cfg.hidden_size); + layers_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, block] : enumerate(layers_.list())) { block.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = embedding_(inputs[0]); + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : layers_.list()) { x = block(x, llm_embedding_sin, 
llm_embedding_cos, kv_cache)[0]; } + + x = norm_(x); + return {x}; + } +}; + +class InternLM2ForCausalLM : public ARGeneration, public nn::Module { + public: + explicit InternLM2ForCausalLM(const InternLM2Config& cfg) + : cfg_(cfg), rope_linear_scale_(cfg.rope_scaling_type == "linear" ? cfg.rope_scaling_factor : 1.0f) { + kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, cfg.num_attention_heads, cfg.num_key_value_heads, + cfg.head_dim, kFloat32, kFloat32, kCPU, false); + + eos_token_id_ = cfg.eos_token_id; + max_length_ = cfg.max_cache_length; + + decoder_ = reg("model", cfg); + lm_head_ = reg("output", cfg.hidden_size, cfg.vocab_size, cfg.bias, cfg.linear_impl_type); + + tie_word_embeddings_ = cfg.tie_word_embeddings; + rope_scaling_type_ = cfg.rope_scaling_type; + rope_scaling_factor_ = cfg.rope_scaling_factor; + + auto inv_freq = makeRoPEInvFreq(cfg.head_dim, cfg.rope_theta, rope_linear_scale_); + registerBuffer("inv_freq_base", inv_freq); + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& /*args*/) override { + auto sequence = input.at("sequence"); + + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids; + if (input.count("position_ids")) { + position_ids = input.at("position_ids"); + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({0, position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({batch_size, 1}, kInt64, kCPU).alloc(); + *position_ids.offsettedPtr({0, 0}) = last_pos + 1; + } + } else { + position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[b * seq_len + s] = s; } + } + } + + auto max_position = int64_t{0}; + auto pos_ptr = position_ids.ptr(); + for (size_t i = 0; i < position_ids.numel(); ++i) { max_position = std::max(max_position, pos_ptr[i]); } 
+ max_position += 1; + + Tensor inv_freq = getBuffer("inv_freq_base"); + if (rope_scaling_type_ == "dynamic" && rope_scaling_factor_ > 1.0f && max_position > cfg_.max_position_embeddings) { + auto factor = rope_scaling_factor_; + auto base = cfg_.rope_theta + * std::pow((factor * static_cast(max_position) / static_cast(cfg_.max_position_embeddings)) + - (factor - 1.0f), + static_cast(cfg_.head_dim) / static_cast(cfg_.head_dim - 2)); + inv_freq = makeRoPEInvFreq(cfg_.head_dim, base, rope_linear_scale_); + } + + auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, inv_freq); + + auto hidden_states = decoder_(sequence, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + + auto S = hidden_states.shape()[1]; + hidden_states = hidden_states[{kAll, {S - 1}, kAll}]; + + auto logits = lm_head_(hidden_states); + + return { + {"sequence", logits}, + {"position_ids", position_ids}, + }; + } + + nn::StaticCache& kvCache() { return kv_cache_; } + + private: + InternLM2Config cfg_; + InternLM2Model decoder_; + nn::Linear lm_head_; + nn::StaticCache kv_cache_; + bool tie_word_embeddings_ = false; + std::string rope_scaling_type_; + float rope_scaling_factor_ = 1.0f; + float rope_linear_scale_ = 1.0f; +}; + +} // namespace mllm::models::internlm2 \ No newline at end of file diff --git a/mllm/models/internlm2/tokenization_internlm2.hpp b/mllm/models/internlm2/tokenization_internlm2.hpp new file mode 100644 index 000000000..37c148886 --- /dev/null +++ b/mllm/models/internlm2/tokenization_internlm2.hpp @@ -0,0 +1,157 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include +#include + +#include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" + +namespace mllm::models::internlm2 { + +struct InternLM2Message { + std::string prompt; + static inline std::string message_template = "<|im_start|>user\n{{{prompt}}}<|im_end|>\n<|im_start|>assistant\n"; + bool add_bos = true; + bool add_eos = false; +}; + +class InternLM2Tokenizer final : public preprocessor::AutoTokenizer { + public: + explicit InternLM2Tokenizer(const std::string& file_path, bool add_bos = true, bool add_eos = false) + : add_bos_(add_bos), add_eos_(add_eos) { + preprocessor::initLocal(); + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + bpe_.initFromSentencePieceJson(file_path); + + special_tokens_trie_.add(L"<|plugin|>"); + special_tokens_trie_.add(L"<|interpreter|>"); + special_tokens_trie_.add(L"<|action_end|>"); + special_tokens_trie_.add(L"<|action_start|>"); + special_tokens_trie_.add(L"<|im_end|>"); + special_tokens_trie_.add(L"<|im_start|>"); + } + + std::vector _tokenize(const std::string& str) override { + std::vector tokens; + auto w_string = preprocessor::utf8string2WideString(str); + auto normalized = normalize(w_string); + + auto bpe_tokens = bpe_._bpe(normalized); + tokens.reserve(bpe_tokens.size()); + for (const auto& token : bpe_tokens) { tokens.push_back(token); } + return tokens; + } + + std::vector tokenize(const std::string& str) override { + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str)); + std::vector all_tokens; + + if (add_bos_) { all_tokens.emplace_back(L""); } + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto 
tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + if (add_eos_) { all_tokens.emplace_back(L""); } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { + auto str = bpe_._lookup_inverse_vocab(pos_idx); + return postprocess(str); + } + + std::wstring detokenize(int64_t pos_idx) override { return _detokenize(pos_idx); } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + Tensor ret = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("internlm2-tokenizer-i0") + .alloc(); + + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return ret; + } + + ARGenerationOutputPast convertMessage(const InternLM2Message& prompt) { + auto applied_string = InternLM2Message::message_template; + size_t pos = applied_string.find("{{{prompt}}}"); + applied_string.replace(pos, 12, prompt.prompt); + // BUGFIX(review): tokenize the template-applied string, not the raw prompt — + // previously `applied_string` was built and then discarded, so the chat + // template (<|im_start|> markers) never reached the model. + auto tokens = tokenize(applied_string); + + if (!prompt.add_bos && !tokens.empty() && tokens.front() == L"") { tokens.erase(tokens.begin()); } + if (prompt.add_eos && (tokens.empty() || tokens.back() != L"")) { tokens.emplace_back(L""); } + + std::vector ids; + ids.reserve(tokens.size()); + for (const auto& token : tokens) { ids.emplace_back(bpe_._lookup_vocab(token)); } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("internlm2-seq-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return { + {"sequence", sequence}, + }; + } + + private: + static std::wstring normalize(const std::wstring& text) { + if (text.empty()) { return text; } + + std::wstring normalized = text; + std::wstring space_char = L" "; + std::wstring 
underline_char = L"▁"; + + size_t pos = 0; + while ((pos = normalized.find(space_char, pos)) != std::wstring::npos) { + normalized.replace(pos, space_char.length(), underline_char); + pos += underline_char.length(); + } + + // if (normalized[0] != L'▁') { normalized = underline_char + normalized; } + + return normalized; + } + + static std::wstring postprocess(const std::wstring& text) { + if (text == L"" || text == L"" || text == L"") { return L""; } + + std::wstring processed = text; + std::wregex underline_regex(L"▁"); + processed = std::regex_replace(processed, underline_regex, L" "); + + if (processed == L"<0x0A>") { return L"\n"; } + return processed; + } + + private: + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; + bool add_bos_ = true; + bool add_eos_ = false; +}; + +} // namespace mllm::models::internlm2 \ No newline at end of file