From db60623e7926fb151b3cc63f029929122cac342a Mon Sep 17 00:00:00 2001
From: Aaron Lee
Date: Sun, 10 Aug 2025 23:52:54 -0400
Subject: [PATCH 1/8] added getter for nextn layer count and server slot has_mtp property

---
 include/llama.h         |  2 ++
 src/llama-model.cpp     |  4 ++++
 tools/server/server.cpp | 12 +++++++++++-
 3 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/include/llama.h b/include/llama.h
index 545e957e5f52b..3bade3ae71cce 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -495,6 +495,8 @@ extern "C" {
 
     LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
 
+    LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model);
+
     // Functions to access the model's GGUF metadata scalar values
     // - The functions return the length of the string on success, or -1 on failure
     // - The output string is always null-terminated and cleared on failure
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 58ca7df707ef3..2351478c2f056 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -18587,6 +18587,10 @@ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i)
     return nullptr;
 }
 
+int32_t llama_model_n_nextn_layer(const llama_model * model) {
+    return model->hparams.nextn_predict_layers;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index a255d481a4d1c..7a931cc6b0740 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1294,7 +1294,8 @@ struct server_slot {
     mtmd_context * mctx = nullptr;
 
     common_speculative * spec = nullptr;
-
+    bool has_mtp = false;
+
     std::vector lora;
 
     // the index relative to completion multi-task request
@@ -2121,6 +2122,15 @@ struct server_context {
                     common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
                 }
             }
+            else if (llama_model_n_nextn_layer(model) > 0) {
+                SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
+                slot.has_mtp = true;
+
+                // assume one speculative token (true of all well-known MTP models so far)
+                slot.batch_spec = llama_batch_init(2, 0, 1);
+                params_base.speculative.n_min = 0;
+                params_base.speculative.n_max = 1;
+            }
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

From e434f87cc739a1901931d88e33f777170a4e18e7 Mon Sep 17 00:00:00 2001
From: Aaron Lee
Date: Mon, 11 Aug 2025 01:21:47 -0400
Subject: [PATCH 2/8] some work towards building mtp layer graph

---
 src/llama-model.cpp     | 139 ++++++++++++++++++++++++++++++++++++++++
 tools/server/server.cpp |  18 +++---
 2 files changed, 149 insertions(+), 8 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 2351478c2f056..9e09e7e0a4f97 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -4507,6 +4507,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     // but only PROCESS up to last layer (skipping final NextN layer) in forward pass
                     for (int i = 0; i < n_layer; ++i) {
                         int flags = 0;
+
                         if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) {
                             // skip all tensors in the NextN layers
                             flags |= TENSOR_SKIP;
                         }
@@ -13919,6 +13920,144 @@ struct llm_build_glm4_moe : public llm_graph_context {
     }
 };
 
+struct llm_build_glm4_moe_mtp : public llm_graph_context {
+    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
+        // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
+
ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past + ) : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Assuming a single MTP layer at the end + const int il = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il]; + + ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); + ggml_set_i32(inp_pos, n_past); + llm_graph_input_attn_no_cache * inp_attn = nullptr; + + ggml_tensor * cur; + + // get MTP embedding for last (conventionally sampled) token + ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); + ggml_set_i32(inp_token_id, last_token_id); + ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id); + ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); + + // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states) + ggml_tensor * hidden_state_norm = build_norm(hidden_state_inp, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + + ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat + cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj + + + // now proceed through last layer (skipped in main model) + ggml_tensor * inpSA = cur; + + // Pre-attention norm for the MTP block + ggml_tensor* attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur); + if (mtp_layer.bq) { + Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq); + } + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur); + if (mtp_layer.bk) { + Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk); + } + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur); + if (mtp_layer.bv) { + Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv); + } + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (mtp_layer.attn_q_norm) { + Qcur = build_norm(Qcur, mtp_layer.attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (mtp_layer.attn_k_norm) { + Kcur = build_norm(Kcur, mtp_layer.attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + mtp_layer.wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + + cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il); + + // moe ffn for nextn block + { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + mtp_layer.ffn_gate_inp, + mtp_layer.ffn_up_exps, + mtp_layer.ffn_gate_exps, + mtp_layer.ffn_down_exps, + mtp_layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + 
true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); + + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + mtp_layer.ffn_up_shexp, NULL, NULL, + mtp_layer.ffn_gate_shexp, NULL, NULL, + mtp_layer.ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); + + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); + cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, res->t_logits); + } +}; + struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 7a931cc6b0740..08ffb25d2417e 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1432,7 +1432,7 @@ struct server_slot { } bool can_speculate() const { - return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; } void add_token(const completion_token_output & token) { @@ -2122,14 +2122,16 @@ struct server_context { common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); } } + + // if model has MTP and no draft model is specified... else if (llama_model_n_nextn_layer(model) > 0) { - SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model)); - slot.has_mtp = true; - - // assume one speculative token (true of all well-known MTP models so far) - slot.batch_spec = llama_batch_init(2, 0, 1); - params_base.speculative.n_min = 0; - params_base.speculative.n_max = 1; + SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model)); + slot.has_mtp = true; + + // assume one speculative token (true of all well-known MTP models so far) + slot.batch_spec = llama_batch_init(2, 0, 1); + params_base.speculative.n_min = 0; + params_base.speculative.n_max = 1; } SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); From 1f477b375504aa557ed21066aa6783b11781a179 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Mon, 11 Aug 2025 20:54:45 -0400 Subject: [PATCH 3/8] make nextn weights loadable without a crash --- src/llama-arch.cpp | 13 +++++++------ src/llama-model.cpp | 27 ++++++++++++++++++++++++++- tools/server/server.cpp | 3 ++- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 18dcc6ddfe567..4b6fa3e6059a2 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2240,12 +2240,13 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - 
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // Changed to LLM_TENSOR_LAYER_REPEATING because we saved these under a blk with a non-negative id + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9e09e7e0a4f97..a9310a6090562 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4510,7 +4510,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; + // flags |= TENSOR_SKIP; } auto & layer = layers[i]; @@ -4574,12 +4574,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + + // our input/output layer sanity check prevents us from loading the eh_proj layer! + // this is because eh_proj is labelled with a layer number in existing GGUFs, + // so we need to set bid == to successfully load the tensors, but our io layer sanity check requires bid == -1. + // this function is a hack that creates the nextn layers as LLM_TENSOR_LAYER_REPEATING instead. 
+ /* auto create_tensor_override_io_sanity_check = + [&](llm_tensor type_enum, const char * suffix, int bid, const std::initializer_list& ne, int flags) -> ggml_tensor * { + + auto tn_orig = tn(type_enum, suffix, bid); + llm_tensor_info info_override = *tn_orig.info; + info_override.layer = LLM_TENSOR_LAYER_REPEATING; + + auto tn_override = tn_orig; + tn_override.info = &info_override; + + return create_tensor(tn_override, ne, flags); + };*/ + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + + // layer.nextn.eh_proj = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i, { 2 * n_embd, n_embd }, flags); + // layer.nextn.embed_tokens = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i, { n_embd, n_vocab }, flags); + // layer.nextn.enorm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_ENORM, "weight", i, { n_embd }, flags); + // layer.nextn.hnorm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_HNORM, "weight", i, { n_embd }, flags); + // layer.nextn.shared_head_head = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i, { n_embd, n_vocab }, flags); + // layer.nextn.shared_head_norm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i, { n_embd }, flags); } } } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 08ffb25d2417e..a9ad900ce39ee 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1432,7 +1432,8 @@ struct server_slot { } bool can_speculate() const { - return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; + // return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; + return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt; } void add_token(const completion_token_output & token) { From 03231da69eec20677e25e2307d4fe31ac2ede034 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Tue, 12 Aug 2025 01:03:59 -0400 Subject: [PATCH 4/8] add model member function to build mtp graph, to be called from speculative.cpp --- src/llama-model.cpp | 16 ++++++++++++++++ src/llama-model.h | 2 ++ 2 files changed, 18 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a9310a6090562..667d9e442b398 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18673,6 +18673,22 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { return llm->res->get_gf(); } +ggml_cgraph* llama_model::build_mtp_graph(const llm_graph_params& params, + ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const { + std::unique_ptr llm; + + switch (arch) { + case LLM_ARCH_GLM4_MOE: + { + llm = std::make_unique(*this, params, hidden_state_inp, last_token_id, n_past); + } break; + default: + GGML_ABORT("fatal error"); + } + + return llm->res->get_gf(); +} + // // interface implementation // diff --git 
a/src/llama-model.h b/src/llama-model.h index 6fcd74d57fdca..77a18aca7164f 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -475,6 +475,8 @@ struct llama_model { // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; + ggml_cgraph * build_mtp_graph(const llm_graph_params & params, + ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const; private: struct impl; From cf0f7c0448c2c1736588673114558e5829db7879 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Wed, 13 Aug 2025 02:21:17 -0400 Subject: [PATCH 5/8] broad thrust of the mtp implementation --- common/speculative.cpp | 126 ++++++++++++++++++++++++++++++++++++++++ common/speculative.h | 9 +++ include/llama.h | 17 ++++++ src/llama-context.cpp | 59 +++++++++++++++++++ src/llama-context.h | 7 +++ src/llama-graph.cpp | 4 ++ src/llama-graph.h | 1 + src/llama-model.cpp | 12 +++- tools/server/server.cpp | 36 ++++++++---- 9 files changed, 260 insertions(+), 11 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 262b2c23e720f..e46a0968bdec2 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -5,6 +5,7 @@ #include "log.h" #include "common.h" #include "sampling.h" +#include "../src/llama-graph.h" #include #include @@ -359,3 +360,128 @@ llama_tokens common_speculative_gen_draft( } return result; } + + +llama_tokens mtp_speculative_gen_draft( + struct common_sampler * smpl, + struct llama_context * ctx, + llama_token id_last, + int32_t n_past, + int32_t last_tok_idx) { + + llama_tokens result; + + LOG_INF("step: '%d'\n", 1); + + // sample one token from the draft model -- this does NOT generalize to >1 MTP head + result.reserve(1); + + // need to determine which architecture we're using so we call the correct MTP model + const auto * model = llama_get_model(ctx); + + LOG_INF("step: '%d'\n", 2); + + //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + //auto * gf = model.build_graph(gparams); + + LOG_INF("step: '%d'\n", 3); + + /*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + }*/ + + //llm_graph_result res_mtp(ctx->graph_max_nodes()); + llm_graph_result * res_mtp; + llama_ubatch ubatch_mtp; + ubatch_mtp.n_tokens = 1; + ubatch_mtp.pos = &n_past; // Critical for positional encoding + + // We also need a minimal ubatch to provide positional context (RoPE) + // ubatch_mtp.tokens = &last_token_id; + // ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper + // ubatch_mtp.logits = nullptr; + // ubatch_mtp.all_pos_0 = -1; + // ubatch_mtp.all_pos_1 = -1; + // ubatch_mtp.all_seq_id = -1; + + // Manually construct the graph parameters + //const llm_graph_params params_mtp = { + // /*.arch =*/ model->arch, + // /*.hparams =*/ model->hparams, + // /*.cparams =*/ ctx->cparams, + // /*.ubatch =*/ ubatch_mtp, + // /*.gtype =*/ LLM_GRAPH_TYPE_DECODER, + // /*.sched =*/ ctx->sched.get(), + // /*.backend_cpu =*/ ctx->backend_cpu, + // /*.cvec =*/ &ctx->cvec, + // /*.loras =*/ &ctx->loras, + // /*.mctx =*/ llama_get_memory(ctx), // Use the KV cache's memory context + // /*.cross =*/ &ctx->cross, + // /*.n_outputs =*/ 1, + // /*.cb =*/ ctx->graph_get_cb(), + // /*.res =*/ &res_mtp, // Point to our temporary result object + //}; + llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp); + + LOG_INF("step: '%d'\n", 4); + + // ggml_cgraph* 
build_mtp_graph(const llm_graph_params & params, + // ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const; + auto * last_embd = llama_get_embeddings_tensor(ctx); + + LOG_INF("step: '%d'\n", 5); + + GGML_ASSERT(model != nullptr); + GGML_ASSERT(last_embd != nullptr); + + auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past); + + if (!gf) { + LOG_INF("%s: failed to initialize graph\n", __func__); + //ret = GGML_STATUS_FAILED; + return result; + } + + LOG_INF("step: '%d'\n", 6); + + const auto status = llama_graph_compute(ctx, gf, false); + + LOG_INF("step: '%d'\n", 7); + + struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp); + float * ctx_logit_pointer = llama_get_logits(ctx); + + LOG_INF("step: '%d'\n", 8); + + if (logits_mtp) { + llama_set_logits(ctx, logits_mtp); + } + + LOG_INF("step: '%d'\n", 9); + + { + common_sampler_sample(smpl, ctx, last_tok_idx, true); + + LOG_INF("step: '%d'\n", 10); + + const auto * cur_p = common_sampler_get_candidates(smpl); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + // skip accepting draft token -- since we're only drafting one token this can't affect future outputs + // smpl will accept the token if it doesn't get rejected by main model later + // common_sampler_accept(smpl, id, true); + + result.push_back(id); + } + + return result; +} diff --git a/common/speculative.h b/common/speculative.h index e69d7aaa1eb00..3b04890073867 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -27,6 +27,15 @@ void common_speculative_add_replacement_tgt_dft( struct common_speculative * spec, const char *source, const char *dest); + +// sample up to n_draft tokens and add them to the batch using the draft model +llama_tokens mtp_speculative_gen_draft( + struct common_sampler* smpl, + struct llama_context* ctx, + llama_token id_last, + int32_t n_past, + int32_t last_tok_idx); + // sample up to n_draft tokens and add them to the batch using the draft model llama_tokens common_speculative_gen_draft( struct common_speculative * spec, diff --git a/include/llama.h b/include/llama.h index 3bade3ae71cce..2134f62d52714 100644 --- a/include/llama.h +++ b/include/llama.h @@ -544,12 +544,17 @@ extern "C" { // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); + LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params, + struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, const char * fname_out, const llama_model_quantize_params * params); + + // // Adapters // @@ -972,6 +977,8 @@ extern "C" { // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + LLAMA_API void llama_set_logits(struct llama_context* ctx, struct ggml_tensor* logit_override); + // Get all output token embeddings. 
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously @@ -994,6 +1001,8 @@ extern "C" { // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx); + // // Vocab // @@ -1452,6 +1461,14 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); + LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch); + + LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched); + + LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res); + + + #ifdef __cplusplus } #endif diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 26a5cf9c3f8db..26c3e639d8a95 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,6 +6,7 @@ #include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-graph.h" #include #include @@ -522,6 +523,14 @@ float * llama_context::get_logits() { return logits; } +void llama_context::set_logits(struct ggml_tensor * logit_override) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float)); +} + float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; @@ -617,6 +626,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +ggml_tensor * llama_context::get_embeddings_tensor() { + return embd_tensor; +} + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -1113,6 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) { auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? 
res->get_embd() : nullptr; + embd_tensor = res->get_embd(); if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1429,6 +1443,27 @@ llm_graph_params llama_context::graph_params( }; } +llm_graph_params llama_context::mtp_graph_params( + llm_graph_result* res, + const llama_ubatch& ubatch) const { + return { + /*.arch =*/ model.arch, + /*.hparams =*/ model.hparams, + /*.cparams =*/ cparams, + /*.ubatch =*/ ubatch, + /*.gtype =*/ LLM_GRAPH_TYPE_DECODER, + /*.sched =*/ sched.get(), + /*.backend_cpu =*/ backend_cpu, + /*.cvec =*/ &cvec, + /*.loras =*/ &loras, + /*.mctx =*/ memory->init_batch(*balloc, 1, false).get(), + /*.cross =*/ &cross, + /*.n_outputs =*/ 1, + /*.cb =*/ graph_get_cb(), + /*.res =*/ res, + }; +} + ggml_status llama_context::graph_compute( ggml_cgraph * gf, bool batched) { @@ -2233,6 +2268,7 @@ void llama_context::opt_epoch( llama_batch_free(batch); } + // // interface implementation // @@ -2274,6 +2310,8 @@ llama_context_params llama_context_default_params() { return result; } + + llama_context * llama_init_from_model( llama_model * model, llama_context_params params) { @@ -2412,6 +2450,11 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return ctx->get_logits_ith(i); } +void llama_set_logits(llama_context* ctx, struct ggml_tensor* logit_override) { + ctx->set_logits(logit_override); +} + + float * llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); @@ -2430,6 +2473,13 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_tensor(); +} + + // llama adapter API int32_t llama_set_adapter_lora( @@ -2926,3 +2976,12 @@ void llama_opt_epoch( callback_train, callback_eval); } + +llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) { + return ctx->mtp_graph_params(res, ubatch); +} + + +ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) { + return ctx->graph_compute(gf, batched); +} diff --git a/src/llama-context.h b/src/llama-context.h index 25c143d56dfb2..44bcdf6d95260 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -59,6 +59,7 @@ struct llama_context { float * get_embeddings(); float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + ggml_tensor * get_embeddings_tensor(); void attach_threadpool( ggml_threadpool_t threadpool, @@ -199,6 +200,10 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); + llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const; + + void set_logits(struct ggml_tensor* logit_override); + private: llm_graph_params graph_params( llm_graph_result * res, @@ -240,6 +245,7 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE size_t embd_size = 0; // capacity (of floats) for embeddings float * embd = nullptr; + ggml_tensor * embd_tensor = nullptr; // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE @@ -308,3 +314,4 @@ struct llama_context { mutable int32_t n_reused = 0; // number of times the previous graph was reused }; + diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 
053c72d6dc8d1..b5184e4559d4e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1911,3 +1911,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck return relative_bucket; } + +ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) { + return res->get_logits(); +} diff --git a/src/llama-graph.h b/src/llama-graph.h index 6ff49de3a1ce8..10702ed219c01 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -818,3 +818,4 @@ struct llm_graph_context { // TODO: better name int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional); + diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 667d9e442b398..8a9ba8480327e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18673,19 +18673,21 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { return llm->res->get_gf(); } -ggml_cgraph* llama_model::build_mtp_graph(const llm_graph_params& params, +ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params, ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const { std::unique_ptr llm; switch (arch) { case LLM_ARCH_GLM4_MOE: { + printf("step: '%d'\n", 56); llm = std::make_unique(*this, params, hidden_state_inp, last_token_id, n_past); } break; default: GGML_ABORT("fatal error"); } + printf("step: '%d'\n", 57); return llm->res->get_gf(); } @@ -19004,3 +19006,11 @@ bool llama_model_is_diffusion(const llama_model * model) { const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } + +ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params, + ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) { + printf("step: '%d'\n", 55); + + return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past); +} + diff --git a/tools/server/server.cpp b/tools/server/server.cpp index a9ad900ce39ee..29d551ea5131b 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1294,7 +1294,8 @@ struct server_slot { mtmd_context * mctx = nullptr; common_speculative * spec = nullptr; - bool has_mtp = false; + bool has_mtp = false; + int32_t last_tok_idx = -1; std::vector lora; @@ -1432,8 +1433,8 @@ struct server_slot { } bool can_speculate() const { - // return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; - return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt; + return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; + // return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt; } void add_token(const completion_token_output & token) { @@ -1993,7 +1994,7 @@ struct server_context { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } - + vocab = llama_model_get_vocab(model); n_ctx = llama_n_ctx(ctx); @@ -3531,6 +3532,7 @@ struct server_context { const int tok_idx = slot.i_batch - i; llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + slot.last_tok_idx = tok_idx; slot.i_batch = -1; @@ -3567,6 +3569,8 @@ struct server_context { } } + SRV_DBG("starting speculative decoding: %d\n", 1); + // do speculative decoding for (auto & slot : slots) { if (!slot.is_processing() || !slot.can_speculate()) { @@ -3583,7 +3587,9 @@ struct server_context { } // determine the max draft that fits the current slot state + SLT_DBG(slot, "starting mtp draft: %d\n", 2); int n_draft_max = 
slot.params.speculative.n_max; + SLT_DBG(slot, "starting mtp draft: %d\n", 3); // note: n_past is not yet increased for the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts @@ -3601,15 +3607,25 @@ struct server_context { continue; } + SLT_DBG(slot, "slot has mtp: %d\n", slot.has_mtp); + llama_token id = slot.sampled; - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; - params_spec.p_min = slot.params.speculative.p_min; + llama_tokens draft; + if (slot.has_mtp) { + SLT_DBG(slot, "starting mtp draft: %d\n", 1); + llama_tokens draft = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); + } + else { + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens(); - const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + } // ignore small drafts if (slot.params.speculative.n_min > (int) draft.size()) { From 6e9bafc7a738b4c99f9440c0ec461e08cf6ce702 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Fri, 15 Aug 2025 23:13:56 -0400 Subject: [PATCH 6/8] failed attempt to implement MTP; outputs tokens but KV cache management is unreasonable --- common/sampling.cpp | 5 ++ common/speculative.cpp | 135 ++++++++-------------------------------- common/speculative.h | 2 +- include/llama.h | 5 +- src/llama-context.cpp | 70 ++++++++++++++++----- src/llama-context.h | 8 ++- src/llama-model.cpp | 37 ++++++++--- tools/server/server.cpp | 26 +++++--- 8 files changed, 141 insertions(+), 147 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 9c04d35fd00a2..a5824ebeedbaa 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -348,6 +348,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co llama_sampler_apply(chain, &cur_p); + /*for (int k = 0; k < (int)cur_p.size; ++k) { + LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n", + k, 0, cur_p.data[k].id, cur_p.data[k].p); + }*/ + GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); const llama_token id = cur_p.data[cur_p.selected].id; diff --git a/common/speculative.cpp b/common/speculative.cpp index e46a0968bdec2..fa784f62f69db 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -6,6 +6,7 @@ #include "common.h" #include "sampling.h" #include "../src/llama-graph.h" +#include "../src/llama-context.h" #include #include @@ -362,126 +363,40 @@ llama_tokens common_speculative_gen_draft( } -llama_tokens mtp_speculative_gen_draft( - struct common_sampler * smpl, - struct llama_context * ctx, - llama_token id_last, - int32_t n_past, - int32_t last_tok_idx) { +llama_token mtp_speculative_gen_draft( + struct common_sampler* smpl, + struct llama_context* ctx, + llama_token id_last, + int32_t n_past, + int32_t last_tok_idx) { - llama_tokens result; - - LOG_INF("step: '%d'\n", 1); - - // sample one token from the draft model -- this does NOT generalize to >1 MTP head - result.reserve(1); - - // need to 
determine which architecture we're using so we call the correct MTP model const auto * model = llama_get_model(ctx); - - LOG_INF("step: '%d'\n", 2); - - //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); - //auto * gf = model.build_graph(gparams); - - LOG_INF("step: '%d'\n", 3); - - /*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); - ret = GGML_STATUS_ALLOC_FAILED; - return nullptr; - }*/ - - //llm_graph_result res_mtp(ctx->graph_max_nodes()); - llm_graph_result * res_mtp; - llama_ubatch ubatch_mtp; - ubatch_mtp.n_tokens = 1; - ubatch_mtp.pos = &n_past; // Critical for positional encoding - - // We also need a minimal ubatch to provide positional context (RoPE) - // ubatch_mtp.tokens = &last_token_id; - // ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper - // ubatch_mtp.logits = nullptr; - // ubatch_mtp.all_pos_0 = -1; - // ubatch_mtp.all_pos_1 = -1; - // ubatch_mtp.all_seq_id = -1; - - // Manually construct the graph parameters - //const llm_graph_params params_mtp = { - // /*.arch =*/ model->arch, - // /*.hparams =*/ model->hparams, - // /*.cparams =*/ ctx->cparams, - // /*.ubatch =*/ ubatch_mtp, - // /*.gtype =*/ LLM_GRAPH_TYPE_DECODER, - // /*.sched =*/ ctx->sched.get(), - // /*.backend_cpu =*/ ctx->backend_cpu, - // /*.cvec =*/ &ctx->cvec, - // /*.loras =*/ &ctx->loras, - // /*.mctx =*/ llama_get_memory(ctx), // Use the KV cache's memory context - // /*.cross =*/ &ctx->cross, - // /*.n_outputs =*/ 1, - // /*.cb =*/ ctx->graph_get_cb(), - // /*.res =*/ &res_mtp, // Point to our temporary result object - //}; - llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp); - - LOG_INF("step: '%d'\n", 4); - - // ggml_cgraph* build_mtp_graph(const llm_graph_params & params, - // ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const; auto * last_embd = llama_get_embeddings_tensor(ctx); - LOG_INF("step: '%d'\n", 5); - GGML_ASSERT(model != nullptr); GGML_ASSERT(last_embd != nullptr); + llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx); - auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past); - - if (!gf) { - LOG_INF("%s: failed to initialize graph\n", __func__); - //ret = GGML_STATUS_FAILED; - return result; - } - - LOG_INF("step: '%d'\n", 6); - - const auto status = llama_graph_compute(ctx, gf, false); - - LOG_INF("step: '%d'\n", 7); - - struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp); - float * ctx_logit_pointer = llama_get_logits(ctx); + common_sampler_sample(smpl, ctx, last_tok_idx, true); - LOG_INF("step: '%d'\n", 8); + const auto* cur_p = common_sampler_get_candidates(smpl); + /*LOG_INF("cur_p->size: %d\n", cur_p->size); - if (logits_mtp) { - llama_set_logits(ctx, logits_mtp); - } - - LOG_INF("step: '%d'\n", 9); - - { - common_sampler_sample(smpl, ctx, last_tok_idx, true); - - LOG_INF("step: '%d'\n", 10); - - const auto * cur_p = common_sampler_get_candidates(smpl); - - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); - } - - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, 0, cur_p->data[k].id, 
cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + }*/ - // skip accepting draft token -- since we're only drafting one token this can't affect future outputs - // smpl will accept the token if it doesn't get rejected by main model later - // common_sampler_accept(smpl, id, true); + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; - result.push_back(id); - } + // skip accepting draft token -- since we're only drafting one token this can't affect future outputs + // smpl will accept the token if it doesn't get rejected by main model later + // common_sampler_accept(smpl, id, true); - return result; + //llama_tokens result; + //result.reserve(1); + //result.push_back(id); + //return result; + return id; } diff --git a/common/speculative.h b/common/speculative.h index 3b04890073867..6ff9e822f8d37 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -29,7 +29,7 @@ void common_speculative_add_replacement_tgt_dft( // sample up to n_draft tokens and add them to the batch using the draft model -llama_tokens mtp_speculative_gen_draft( +llama_token mtp_speculative_gen_draft( struct common_sampler* smpl, struct llama_context* ctx, llama_token id_last, diff --git a/include/llama.h b/include/llama.h index 2134f62d52714..16dc10d4032b7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -977,8 +977,6 @@ extern "C" { // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); - LLAMA_API void llama_set_logits(struct llama_context* ctx, struct ggml_tensor* logit_override); - // Get all output token embeddings. // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously @@ -1465,6 +1463,9 @@ extern "C" { LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched); + LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx, + ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); + LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 26c3e639d8a95..ca713fa389097 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -523,12 +523,16 @@ float * llama_context::get_logits() { return logits; } -void llama_context::set_logits(struct ggml_tensor * logit_override) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override); +void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) { + output_reorder(); + + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); - ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float)); + int64_t j = output_ids[i]; + + ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float)); } float * llama_context::get_logits_ith(int32_t i) { @@ -1445,21 +1449,23 @@ llm_graph_params llama_context::graph_params( llm_graph_params llama_context::mtp_graph_params( llm_graph_result* res, - const llama_ubatch& ubatch) const { + const llama_ubatch& ubatch) { + size_t n_nodes = std::max(1024u, 8u * 8u * 
(((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer)); + ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes); return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, /*.cparams =*/ cparams, /*.ubatch =*/ ubatch, /*.gtype =*/ LLM_GRAPH_TYPE_DECODER, - /*.sched =*/ sched.get(), + /*.sched =*/ temp_sched, /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, /*.mctx =*/ memory->init_batch(*balloc, 1, false).get(), /*.cross =*/ &cross, /*.n_outputs =*/ 1, - /*.cb =*/ graph_get_cb(), + /*.cb =*/ graph_get_cb(temp_sched), /*.res =*/ res, }; } @@ -1491,8 +1497,10 @@ ggml_status llama_context::graph_compute( return status; } -llm_graph_cb llama_context::graph_get_cb() const { - return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { +llm_graph_cb llama_context::graph_get_cb(ggml_backend_sched * sched_override) const { + ggml_backend_sched * cb_sched = sched_override ? sched_override : sched.get(); + + return [=](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { if (il >= 0) { ggml_format_name(cur, "%s-%d", name, il); } else { @@ -1502,7 +1510,7 @@ llm_graph_cb llama_context::graph_get_cb() const { if (!cparams.offload_kqv) { if (strcmp(name, "kqv_merged_cont") == 0) { // all nodes between the KV store and the attention output are run on the CPU - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); + ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend_cpu); } } @@ -1515,7 +1523,7 @@ llm_graph_cb llama_context::graph_get_cb() const { for (const auto & backend : backends) { if (ggml_backend_get_device(backend.get()) == dev_layer) { if (ggml_backend_supports_op(backend.get(), cur)) { - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); + ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend.get()); } } } @@ -1524,6 +1532,10 @@ llm_graph_cb llama_context::graph_get_cb() const { }; } +ggml_backend_sched_t llama_context::create_temp_scheduler(size_t n_nodes) { + return ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), n_nodes, false, cparams.op_offload); +} + // // state save/load // @@ -2450,10 +2462,6 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return ctx->get_logits_ith(i); } -void llama_set_logits(llama_context* ctx, struct ggml_tensor* logit_override) { - ctx->set_logits(logit_override); -} - float * llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); @@ -2985,3 +2993,37 @@ llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* re ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) { return ctx->graph_compute(gf, batched); } + +void llama_build_and_execute_mtp_graph(struct llama_context * ctx, + ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { + + const auto * model = llama_get_model(ctx); + + auto res_mtp = std::make_unique(ctx->graph_max_nodes()); + + llama_ubatch ubatch_mtp; + ubatch_mtp.n_tokens = 1; + ubatch_mtp.pos = &n_past; + + auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp)); + + auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past); + + ggml_backend_sched_t sched = params_mtp->sched; + + ggml_backend_sched_reset(sched); // clear the allocation of the previous graph + ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute 
it + + ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input"); + + ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors + ggml_backend_sched_graph_compute(sched, gf); // execute the graph + + struct ggml_tensor * logits_mtp = res_mtp->get_logits();; + LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp); + + if (logits_mtp) { + ctx->set_logits_ith(logits_mtp, sched, last_tok_idx); + } +} + diff --git a/src/llama-context.h b/src/llama-context.h index 44bcdf6d95260..20314304c070e 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -200,9 +200,11 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); - llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const; + llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch); - void set_logits(struct ggml_tensor* logit_override); + void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i); + + ggml_backend_sched_t create_temp_scheduler(size_t n_nodes); private: llm_graph_params graph_params( @@ -211,7 +213,7 @@ struct llama_context { const llama_memory_context_i * mctx, llm_graph_type gtype) const; - llm_graph_cb graph_get_cb() const; + llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const; // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8a9ba8480327e..b0c096dec6521 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13950,7 +13950,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past ) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13958,22 +13957,43 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { const int il = hparams.n_layer - 1; const auto & mtp_layer = model.layers[il]; - ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); - ggml_set_i32(inp_pos, n_past); - llm_graph_input_attn_no_cache * inp_attn = nullptr; + // ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); + // ggml_set_i32(inp_pos, n_past); + ggml_tensor * inp_pos = build_inp_pos(); + + llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr; ggml_tensor * cur; // get MTP embedding for last (conventionally sampled) token + // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); + // LLAMA_LOG_INFO("step: '%d'\n", 5641); + // ggml_set_i32(inp_token_id, last_token_id); + //ggml_set_no_alloc(ctx0, false); + //LLAMA_LOG_INFO("last token id: '%d'\n", last_token_id); + ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); - ggml_set_i32(inp_token_id, last_token_id); + ggml_set_name(inp_token_id, "mtp_token_id_input"); + ggml_set_input(inp_token_id); + + //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id); + //ggml_set_no_alloc(ctx0, true); + ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id); ggml_tensor * 
token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * prev_embedding_leaf = ggml_dup_tensor(ctx0, hidden_state_inp); + ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_leaf"); + ggml_cpy(ctx0, hidden_state_inp, prev_embedding_leaf); + // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states) - ggml_tensor * hidden_state_norm = build_norm(hidden_state_inp, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + //token_emb_norm = ggml_cont(ctx0, token_emb_norm); + //hidden_state_norm = ggml_cont(ctx0, hidden_state_norm); ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat + + cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj @@ -14071,7 +14091,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { cur = ggml_add(ctx0, routed_out, shared_out); cb(cur, "ffn_out", il); } - cur = ggml_add(ctx0, cur, ffn_inp); cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); @@ -18680,14 +18699,12 @@ ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params, switch (arch) { case LLM_ARCH_GLM4_MOE: { - printf("step: '%d'\n", 56); llm = std::make_unique(*this, params, hidden_state_inp, last_token_id, n_past); } break; default: GGML_ABORT("fatal error"); } - printf("step: '%d'\n", 57); return llm->res->get_gf(); } @@ -19009,8 +19026,8 @@ const std::vector> & llama_internal_get_te ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params, ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) { - printf("step: '%d'\n", 55); return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past); } + diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 29d551ea5131b..e5039fe86ae66 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2132,6 +2132,8 @@ struct server_context { // assume one speculative token (true of all well-known MTP models so far) slot.batch_spec = llama_batch_init(2, 0, 1); + SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens); + params_base.speculative.n_min = 0; params_base.speculative.n_max = 1; } @@ -3587,9 +3589,7 @@ struct server_context { } // determine the max draft that fits the current slot state - SLT_DBG(slot, "starting mtp draft: %d\n", 2); int n_draft_max = slot.params.speculative.n_max; - SLT_DBG(slot, "starting mtp draft: %d\n", 3); // note: n_past is not yet increased for the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts @@ -3607,14 +3607,13 @@ struct server_context { continue; } - SLT_DBG(slot, "slot has mtp: %d\n", slot.has_mtp); - llama_token id = slot.sampled; llama_tokens draft; if (slot.has_mtp) { - SLT_DBG(slot, "starting mtp draft: %d\n", 1); - llama_tokens draft = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); + llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); + draft.reserve(1); + draft.push_back(draft_id); } else { struct common_speculative_params params_spec; @@ -3624,7 +3623,16 @@ struct server_context { const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + draft = common_speculative_gen_draft(slot.spec, 
params_spec, cached_text_tokens, id); + } + + //llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); + //llama_tokens draft; + //draft.reserve(1); + //draft.push_back(draft_id); + + for (const auto& str : draft) { + SLT_DBG(slot, "%s\n", str); } // ignore small drafts @@ -3636,6 +3644,7 @@ struct server_context { // keep track of total number of drafted tokens tested slot.n_draft_total += draft.size(); + SLT_DBG(slot, "draft size = %d\n", draft.size()); // construct the speculation batch common_batch_clear(slot.batch_spec); @@ -3652,6 +3661,9 @@ struct server_context { // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + // if slot has mtp + // call + slot.n_past += ids.size(); slot.n_decoded += ids.size(); From 6870f9790c1bb1d0254241267b1a6c8a7fc82830 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Sun, 17 Aug 2025 04:59:36 -0400 Subject: [PATCH 7/8] added proper KV cache management for MTP layers and slightly refactored --- common/speculative.cpp | 58 ++++++++++++++++++++++-------- common/speculative.h | 8 +++++ include/llama.h | 15 +------- src/llama-batch.cpp | 6 ++-- src/llama-context.cpp | 66 +++++++++++++++++++--------------- src/llama-context.h | 4 ++- src/llama-graph.cpp | 4 --- src/llama-kv-cache-unified.cpp | 2 +- src/llama-model.cpp | 22 +++++------- src/llama-model.h | 2 +- tools/server/server.cpp | 43 ++++++++++++++-------- 11 files changed, 136 insertions(+), 94 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index fa784f62f69db..9f8384abb1325 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -370,25 +370,45 @@ llama_token mtp_speculative_gen_draft( int32_t n_past, int32_t last_tok_idx) { - const auto * model = llama_get_model(ctx); - auto * last_embd = llama_get_embeddings_tensor(ctx); + llama_token token_data[] = { id_last }; + llama_pos pos_data[] = { n_past }; + int32_t n_seq_id_data[] = { 1 }; + llama_seq_id seq_id_data_internal[] = { 0 }; + llama_seq_id* seq_id_data[] = {seq_id_data_internal}; + int8_t logits_data[] = { (int8_t) (smpl != nullptr) }; + + llama_batch batch = { + /*.n_tokens = */ 1, + /*.token = */ token_data, + /*.embd = */ nullptr, + /*.pos = */ pos_data, + /*.n_seq_id = */ n_seq_id_data, + /*.seq_id = */ seq_id_data, + /*.logits = */ logits_data + }; + + llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx); + //LOG_INF("updating kv cache for n_past: %d\n", n_past); - GGML_ASSERT(model != nullptr); - GGML_ASSERT(last_embd != nullptr); - llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx); + if (!smpl) { + return -1; + } + else { + common_sampler_sample(smpl, ctx, last_tok_idx, true); + const auto* cur_p = common_sampler_get_candidates(smpl); - common_sampler_sample(smpl, ctx, last_tok_idx, true); + //for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) { + // LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + // k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + //} - const auto* cur_p = common_sampler_get_candidates(smpl); - /*LOG_INF("cur_p->size: %d\n", cur_p->size); + const llama_token id = cur_p->data[0].id; + return id; + } + // LOG_INF("cur_p->size: %d\n", cur_p->size); - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, 
cur_p->data[k].id).c_str()); - }*/ // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; // skip accepting draft token -- since we're only drafting one token this can't affect future outputs // smpl will accept the token if it doesn't get rejected by main model later @@ -398,5 +418,15 @@ llama_token mtp_speculative_gen_draft( //result.reserve(1); //result.push_back(id); //return result; - return id; +} + + +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens) { + mtp_kv_update_data token; + for (int i = 0; i < tokens.size(); ++i) { + token = tokens[i]; + mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx); + } + + tokens.clear(); } diff --git a/common/speculative.h b/common/speculative.h index 6ff9e822f8d37..786f3ad1e8d6d 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -12,6 +12,12 @@ struct common_speculative_params { float p_min = 0.75f; // min probability required to accept a token in the draft }; +struct mtp_kv_update_data { + llama_token id; + int32_t n_past; + int32_t tok_idx; +}; + struct common_speculative * common_speculative_init( struct llama_context * ctx_tgt, struct llama_context * ctx_dft @@ -42,3 +48,5 @@ llama_tokens common_speculative_gen_draft( struct common_speculative_params params, const llama_tokens & prompt, llama_token id_last); + +void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens); diff --git a/include/llama.h b/include/llama.h index 16dc10d4032b7..1de8a963cc034 100644 --- a/include/llama.h +++ b/include/llama.h @@ -544,9 +544,6 @@ extern "C" { // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); - LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params, - struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past); - // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, @@ -999,8 +996,6 @@ extern "C" { // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); - LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx); - // // Vocab // @@ -1459,16 +1454,8 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); - LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch); - - LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched); - LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx, - ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); - - LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res); - - + const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx); #ifdef __cplusplus } diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 8698d89acecb2..ff73429301d68 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -275,7 +275,9 @@ bool llama_batch_allocr::init( } } - if (!ok) { + // TEMPORARILY DISABLING THIS SANITY CHECK + // TODO: UNDO THIS IF IT WORKS + /*if (!ok) { LLAMA_LOG_ERROR( "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" " - the last position stored in the memory module 
of the context (i.e. the KV cache) for sequence %d is X = %d\n" @@ -284,7 +286,7 @@ bool llama_batch_allocr::init( __func__, s, s, p0, s, seq_pos_min(s)); return false; - } + }*/ } if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ca713fa389097..34d514387b9dd 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1448,8 +1448,9 @@ llm_graph_params llama_context::graph_params( } llm_graph_params llama_context::mtp_graph_params( - llm_graph_result* res, - const llama_ubatch& ubatch) { + llm_graph_result * res, + const llama_ubatch& ubatch, + const llama_memory_context_i * mctx) { size_t n_nodes = std::max(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer)); ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes); return { @@ -1462,7 +1463,7 @@ llm_graph_params llama_context::mtp_graph_params( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.mctx =*/ memory->init_batch(*balloc, 1, false).get(), + /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.n_outputs =*/ 1, /*.cb =*/ graph_get_cb(temp_sched), @@ -1470,6 +1471,21 @@ llm_graph_params llama_context::mtp_graph_params( }; } +std::unique_ptr llama_context::mtp_memory_batch(const llama_batch& batch_inp) { + const auto& vocab = model.vocab; + const auto& hparams = model.hparams; + + const int64_t n_vocab = vocab.n_tokens(); + const int64_t n_embd = hparams.n_embd; + + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, false)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return nullptr; + } + + return memory->init_batch(*balloc, 1, false); +} + ggml_status llama_context::graph_compute( ggml_cgraph * gf, bool batched) { @@ -2481,13 +2497,6 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } -ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) { - ctx->synchronize(); - - return ctx->get_embeddings_tensor(); -} - - // llama adapter API int32_t llama_set_adapter_lora( @@ -2985,42 +2994,43 @@ void llama_opt_epoch( callback_eval); } -llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) { - return ctx->mtp_graph_params(res, ubatch); -} - - -ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) { - return ctx->graph_compute(gf, batched); -} - void llama_build_and_execute_mtp_graph(struct llama_context * ctx, - ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { + const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { const auto * model = llama_get_model(ctx); auto res_mtp = std::make_unique(ctx->graph_max_nodes()); + llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp); + const auto& ubatch_mtp = mctx->get_ubatch(); - llama_ubatch ubatch_mtp; - ubatch_mtp.n_tokens = 1; - ubatch_mtp.pos = &n_past; + //llama_ubatch ubatch_mtp; + //ubatch_mtp.n_tokens = 1; + //ubatch_mtp.pos = &n_past; - auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp)); + auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get())); + ggml_backend_sched_t sched = params_mtp->sched; - auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past); + auto * last_embd = 
ctx->get_embeddings_ith(last_tok_idx); - ggml_backend_sched_t sched = params_mtp->sched; + if (mctx && !mctx->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); + } + + auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past); ggml_backend_sched_reset(sched); // clear the allocation of the previous graph ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input"); - ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors + + ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input"); + ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors + ggml_backend_sched_graph_compute(sched, gf); // execute the graph struct ggml_tensor * logits_mtp = res_mtp->get_logits();; - LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp); + //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp); if (logits_mtp) { ctx->set_logits_ith(logits_mtp, sched, last_tok_idx); diff --git a/src/llama-context.h b/src/llama-context.h index 20314304c070e..e8ea3a4c9be39 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -200,12 +200,14 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); - llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch); + llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx); void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i); ggml_backend_sched_t create_temp_scheduler(size_t n_nodes); + std::unique_ptr mtp_memory_batch(const llama_batch& batch_inp); + private: llm_graph_params graph_params( llm_graph_result * res, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b5184e4559d4e..053c72d6dc8d1 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1911,7 +1911,3 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck return relative_bucket; } - -ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) { - return res->get_logits(); -} diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index e539142e6b8cd..ed6cf969d4a9a 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -41,7 +41,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( } if (model.arch == LLM_ARCH_GLM4_MOE) { // GLM-4.5: Only process up to last layer, skip final NextN layer - n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers; + n_layer_cache = hparams.n_layer;// - hparams.nextn_predict_layers; } // create a context for each buffer type diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b0c096dec6521..04743e01f37a2 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13948,7 +13948,7 @@ struct llm_build_glm4_moe : public llm_graph_context { struct llm_build_glm4_moe_mtp : public llm_graph_context { llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params, // For v0, let's rebuild the computational graph for every step + this 
mimics the vLLM impl parameterization - ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past + llama_token last_token_id, int n_past ) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13961,7 +13961,8 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { // ggml_set_i32(inp_pos, n_past); ggml_tensor * inp_pos = build_inp_pos(); - llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr; + //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr; + auto * inp_attn = build_attn_inp_kv_unified(); ggml_tensor * cur; @@ -13982,9 +13983,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context { ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id); ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); - ggml_tensor * prev_embedding_leaf = ggml_dup_tensor(ctx0, hidden_state_inp); - ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_leaf"); - ggml_cpy(ctx0, hidden_state_inp, prev_embedding_leaf); + ggml_tensor* prev_embedding_leaf = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, model.hparams.n_embd); + ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_input"); + ggml_set_input(prev_embedding_leaf); // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states) ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); @@ -18693,13 +18694,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params, - ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const { + llama_token last_token_id, int n_past) const { std::unique_ptr llm; switch (arch) { case LLM_ARCH_GLM4_MOE: { - llm = std::make_unique(*this, params, hidden_state_inp, last_token_id, n_past); + llm = std::make_unique(*this, params, last_token_id, n_past); } break; default: GGML_ABORT("fatal error"); @@ -19024,10 +19025,3 @@ const std::vector> & llama_internal_get_te return model->tensors_by_name; } -ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params, - ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) { - - return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past); -} - - diff --git a/src/llama-model.h b/src/llama-model.h index 77a18aca7164f..b28a37488f78a 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -476,7 +476,7 @@ struct llama_model { // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; ggml_cgraph * build_mtp_graph(const llm_graph_params & params, - ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const; + llama_token last_token_id, int n_past) const; private: struct impl; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index e5039fe86ae66..b85fa4e7691c9 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1278,6 +1278,7 @@ struct server_task_result_apply_lora : server_task_result { } }; + struct server_slot { int id; int id_task = -1; @@ -1295,8 +1296,9 @@ struct server_slot { common_speculative * spec = nullptr; bool has_mtp = false; + std::vector mtp_kv_update_batch; int32_t last_tok_idx = -1; - + std::vector lora; // the index relative to completion multi-task request @@ -1393,7 +1395,7 @@ 
struct server_slot { } bool need_embd() const { - return server_task_type_need_embd(task_type); + return server_task_type_need_embd(task_type) || has_mtp; } bool need_logits() const { @@ -1569,6 +1571,7 @@ struct server_slot { } }; + struct server_metrics { int64_t t_start = 0; @@ -1994,7 +1997,7 @@ struct server_context { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } - + vocab = llama_model_get_vocab(model); n_ctx = llama_n_ctx(ctx); @@ -2124,18 +2127,21 @@ struct server_context { common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); } } - + // if model has MTP and no draft model is specified... else if (llama_model_n_nextn_layer(model) > 0) { SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model)); slot.has_mtp = true; - + // assume one speculative token (true of all well-known MTP models so far) slot.batch_spec = llama_batch_init(2, 0, 1); SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens); params_base.speculative.n_min = 0; params_base.speculative.n_max = 1; + + SRV_INF("%s\n", "MTP needs embeddings on decode, enabling"); + llama_set_embeddings(ctx, true); } SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); @@ -3383,7 +3389,11 @@ struct server_context { // embedding requires all tokens in the batch to be output const bool need_embd = server_task_type_need_embd(slot.task_type); + if (slot.has_mtp) { + slot.mtp_kv_update_batch.push_back({ cur_tok, slot.n_past, batch.n_tokens }); + } common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); + slot.cache_tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; @@ -3533,6 +3543,11 @@ struct server_context { const int tok_idx = slot.i_batch - i; + // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation + if (slot.has_mtp) { + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); + } + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); slot.last_tok_idx = tok_idx; @@ -3571,8 +3586,6 @@ struct server_context { } } - SRV_DBG("starting speculative decoding: %d\n", 1); - // do speculative decoding for (auto & slot : slots) { if (!slot.is_processing() || !slot.can_speculate()) { @@ -3631,13 +3644,9 @@ struct server_context { //draft.reserve(1); //draft.push_back(draft_id); - for (const auto& str : draft) { - SLT_DBG(slot, "%s\n", str); - } - // ignore small drafts - if (slot.params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + if (slot.params.speculative.n_min > (int)draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min); continue; } @@ -3661,8 +3670,12 @@ struct server_context { // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); - // if slot has mtp - // call + if (slot.has_mtp) { + for (int32_t i = 0; i < ids.size(); ++i) { + slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + 1 + i, i }); + } + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); + } slot.n_past += ids.size(); slot.n_decoded += ids.size(); From 382135aa3619294ab8bf87b0de4b1255ab7942f0 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Sun, 17 Aug 2025 21:54:45 -0400 Subject: [PATCH 8/8] fixed mtp kv cache update sequencing after prompt processing --- tools/server/server.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 
deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index b85fa4e7691c9..e323f7b521083 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3543,18 +3543,19 @@ struct server_context { const int tok_idx = slot.i_batch - i; - // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - if (slot.has_mtp) { - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); - } - llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); slot.last_tok_idx = tok_idx; + SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str()); slot.i_batch = -1; common_sampler_accept(slot.smpl, id, true); + // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation + if (slot.has_mtp) { + mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch); + } + slot.n_decoded += 1; const int64_t t_current = ggml_time_us();
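
Taken together, patches 7 and 8 arrange the per-token generation step in tools/server/server.cpp roughly as in the condensed C++ sketch below. Identifiers are the ones introduced in the diffs above; the non-MTP draft-model path, error handling, and logging are omitted, and in the actual code the sampling and the speculative part live in separate loops over the slots, so this fragment is illustrative rather than compilable on its own.

// Per-token flow for a slot with slot.has_mtp == true
// (set when llama_model_n_nextn_layer(model) > 0).

llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); // conventional sample
slot.last_tok_idx = tok_idx;
common_sampler_accept(slot.smpl, id, true);

// After prompt processing, flush the queued prompt tokens through the NextN
// layer so its KV cache catches up with the main model (patch 8 moves this
// to after common_sampler_accept).
if (slot.has_mtp) {
    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
}

// Draft exactly one token with the MTP head.
llama_tokens draft;
draft.push_back(mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx));

// Verify: the sampled token plus the single draft token in one decode call.
common_batch_clear(slot.batch_spec);
common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}
llama_decode(ctx, slot.batch_spec);

// Accept the verified prefix, then bring the MTP layer's KV cache up to date
// with the accepted tokens before the next drafting step.
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
if (slot.has_mtp) {
    for (int32_t i = 0; i < (int32_t) ids.size(); ++i) {
        slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + 1 + i, i });
    }
    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
}
slot.n_past    += ids.size();
slot.n_decoded += ids.size();

Note that mtp_update_kv_cache reuses mtp_speculative_gen_draft with a null sampler: the MTP graph is built and executed purely to advance the NextN layer's KV cache, and the -1 return value is discarded.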