server: implement GLM-style MTP #15225

Draft · wants to merge 6 commits into base: master

5 changes: 5 additions & 0 deletions common/sampling.cpp
@@ -348,6 +348,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co

llama_sampler_apply(chain, &cur_p);

/*for (int k = 0; k < (int)cur_p.size; ++k) {
LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n",
k, 0, cur_p.data[k].id, cur_p.data[k].p);
}*/

GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

const llama_token id = cur_p.data[cur_p.selected].id;
41 changes: 41 additions & 0 deletions common/speculative.cpp
@@ -5,6 +5,8 @@
#include "log.h"
#include "common.h"
#include "sampling.h"
#include "../src/llama-graph.h"
#include "../src/llama-context.h"

#include <cstring>
#include <algorithm>
@@ -359,3 +361,42 @@ llama_tokens common_speculative_gen_draft(
}
return result;
}


llama_token mtp_speculative_gen_draft(
struct common_sampler* smpl,
struct llama_context* ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx) {

const auto * model = llama_get_model(ctx);
auto * last_embd = llama_get_embeddings_tensor(ctx);

GGML_ASSERT(model != nullptr);
GGML_ASSERT(last_embd != nullptr);
llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx);

common_sampler_sample(smpl, ctx, last_tok_idx, true);

const auto* cur_p = common_sampler_get_candidates(smpl);
/*LOG_INF("cur_p->size: %d\n", cur_p->size);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
}*/

// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;

    // skip accepting the draft token -- since we only draft a single token, this cannot affect future outputs
    // smpl will accept the token later if it is not rejected by the main model
    // common_sampler_accept(smpl, id, true);

//llama_tokens result;
//result.reserve(1);
//result.push_back(id);
//return result;
return id;
}
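
A minimal usage sketch (not part of this diff) of how a server-side speculative step could drive the helper above: sample and accept a token from the main model, draft exactly one follow-up token with the MTP head, then submit both positions for verification on the next decode. The batch, sampler, sequence id and n_past/last_tok_idx bookkeeping are assumed to already exist in the caller; the batch helpers are the existing ones from common.h.

#include "common.h"
#include "sampling.h"
#include "speculative.h"

// one speculative step: returns the drafted token that the next decode will verify
static llama_token mtp_draft_step(common_sampler * smpl, llama_context * ctx,
                                  llama_batch & batch, llama_seq_id seq_id,
                                  int32_t n_past, int32_t last_tok_idx) {
    // sample the next token from the main model's logits and accept it
    const llama_token id = common_sampler_sample(smpl, ctx, last_tok_idx, false);
    common_sampler_accept(smpl, id, true);

    // one draft token from the GLM MTP head, conditioned on the accepted token
    const llama_token draft = mtp_speculative_gen_draft(smpl, ctx, id, n_past, last_tok_idx);

    // queue the accepted token and the draft; the next decode produces logits for
    // both positions, and the draft is kept only if the main model samples the
    // same id at the drafted position
    common_batch_clear(batch);
    common_batch_add(batch, id,    n_past,     { seq_id }, true);
    common_batch_add(batch, draft, n_past + 1, { seq_id }, true);
    llama_decode(ctx, batch);
    return draft;
}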
9 changes: 9 additions & 0 deletions common/speculative.h
@@ -27,6 +27,15 @@ void common_speculative_add_replacement_tgt_dft(
struct common_speculative * spec,
const char *source, const char *dest);


// draft a single token using the target model's own MTP (NextN) head, reusing the hidden state of the last decode
llama_token mtp_speculative_gen_draft(
struct common_sampler* smpl,
struct llama_context* ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx);

// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
20 changes: 20 additions & 0 deletions include/llama.h
@@ -495,6 +495,8 @@ extern "C" {

LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);

LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model);

// Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure
@@ -542,12 +544,17 @@ extern "C" {
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);

LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params,
struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past);

// Returns 0 on success
LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp,
const char * fname_out,
const llama_model_quantize_params * params);



//
// Adapters
//
@@ -992,6 +999,8 @@ extern "C" {
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx);

//
// Vocab
//
@@ -1450,6 +1459,17 @@ extern "C" {
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);

LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch);

LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);

LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);

LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);



#ifdef __cplusplus
}
#endif
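
A short sketch (not part of this diff) of how the new public entry points fit together, assuming llama_model_n_nextn_layer() reports the number of NextN/MTP layers from the model hparams and that ctx has just finished a regular decode; last_token_id, n_past and last_tok_idx come from the caller's sampling loop.

#include "llama.h"

// gate MTP drafting on the model actually shipping NextN/MTP weights
static void maybe_draft_with_mtp(llama_context * ctx, llama_token last_token_id,
                                 int32_t n_past, int32_t last_tok_idx) {
    const llama_model * model = llama_get_model(ctx);
    if (llama_model_n_nextn_layer(model) <= 0) {
        return; // no MTP head in this model
    }

    // hidden-state tensor kept from the last decode of the main model
    ggml_tensor * embd = llama_get_embeddings_tensor(ctx);

    // builds and runs the MTP graph, overwriting the logits row at last_tok_idx
    llama_build_and_execute_mtp_graph(ctx, embd, last_token_id, n_past, last_tok_idx);

    // the draft token can now be sampled from the regular logits accessor
    const float * logits = llama_get_logits_ith(ctx, last_tok_idx);
    (void) logits;
}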
13 changes: 7 additions & 6 deletions src/llama-arch.cpp
@@ -2240,12 +2240,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
// Changed to LLM_TENSOR_LAYER_REPEATING because these tensors are saved under a blk with a non-negative id
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
};

LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
109 changes: 105 additions & 4 deletions src/llama-context.cpp
@@ -6,6 +6,7 @@
#include "llama-memory.h"
#include "llama-mmap.h"
#include "llama-model.h"
#include "llama-graph.h"

#include <cinttypes>
#include <cstring>
@@ -522,6 +523,18 @@ float * llama_context::get_logits() {
return logits;
}

void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) {
output_reorder();

ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);

int64_t j = output_ids[i];

ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float));
}

float * llama_context::get_logits_ith(int32_t i) {
int64_t j = -1;

@@ -617,6 +630,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}

ggml_tensor * llama_context::get_embeddings_tensor() {
return embd_tensor;
}

void llama_context::attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) {
@@ -1113,6 +1130,7 @@ int llama_context::decode(const llama_batch & batch_inp) {

auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
embd_tensor = res->get_embd();

if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
@@ -1429,6 +1447,29 @@ llm_graph_params llama_context::graph_params(
};
}

llm_graph_params llama_context::mtp_graph_params(
llm_graph_result* res,
const llama_ubatch& ubatch) {
size_t n_nodes = std::max<uint32_t>(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
return {
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.gtype =*/ LLM_GRAPH_TYPE_DECODER,
/*.sched =*/ temp_sched,
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ memory->init_batch(*balloc, 1, false).get(),
/*.cross =*/ &cross,
/*.n_outputs =*/ 1,
/*.cb =*/ graph_get_cb(temp_sched),
/*.res =*/ res,
};
}

ggml_status llama_context::graph_compute(
ggml_cgraph * gf,
bool batched) {
@@ -1456,8 +1497,10 @@ ggml_status llama_context::graph_compute(
return status;
}

llm_graph_cb llama_context::graph_get_cb() const {
return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
llm_graph_cb llama_context::graph_get_cb(ggml_backend_sched * sched_override) const {
ggml_backend_sched * cb_sched = sched_override ? sched_override : sched.get();

return [=](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
if (il >= 0) {
ggml_format_name(cur, "%s-%d", name, il);
} else {
@@ -1467,7 +1510,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
if (!cparams.offload_kqv) {
if (strcmp(name, "kqv_merged_cont") == 0) {
// all nodes between the KV store and the attention output are run on the CPU
ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend_cpu);
}
}

@@ -1480,7 +1523,7 @@
for (const auto & backend : backends) {
if (ggml_backend_get_device(backend.get()) == dev_layer) {
if (ggml_backend_supports_op(backend.get(), cur)) {
ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend.get());
}
}
}
@@ -1489,6 +1532,10 @@
};
}

ggml_backend_sched_t llama_context::create_temp_scheduler(size_t n_nodes) {
return ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), n_nodes, false, cparams.op_offload);
}

//
// state save/load
//
@@ -2233,6 +2280,7 @@ void llama_context::opt_epoch(
llama_batch_free(batch);
}


//
// interface implementation
//
@@ -2274,6 +2322,8 @@ llama_context_params llama_context_default_params() {
return result;
}



llama_context * llama_init_from_model(
llama_model * model,
llama_context_params params) {
@@ -2412,6 +2462,7 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
return ctx->get_logits_ith(i);
}


float * llama_get_embeddings(llama_context * ctx) {
ctx->synchronize();

@@ -2430,6 +2481,13 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}

ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) {
ctx->synchronize();

return ctx->get_embeddings_tensor();
}


// llama adapter API

int32_t llama_set_adapter_lora(
@@ -2926,3 +2984,46 @@ void llama_opt_epoch(
callback_train,
callback_eval);
}

llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) {
return ctx->mtp_graph_params(res, ubatch);
}


ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
return ctx->graph_compute(gf, batched);
}

void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {

const auto * model = llama_get_model(ctx);

auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());

llama_ubatch ubatch_mtp;
ubatch_mtp.n_tokens = 1;
ubatch_mtp.pos = &n_past;

auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp));

auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past);

ggml_backend_sched_t sched = params_mtp->sched;

ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it

ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");

ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
ggml_backend_sched_graph_compute(sched, gf); // execute the graph

struct ggml_tensor * logits_mtp = res_mtp->get_logits();
LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);

if (logits_mtp) {
ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
}
}
