server: implement GLM-style MTP #15225

Draft · wants to merge 5 commits into master

Conversation

@F1LM1 F1LM1 commented Aug 11, 2025

This is very much a draft/proof of concept I'm playing with, just one idea for an MTP implementation. I'm planning to test on GLM-4.5 because it's the only model out there for which we've preserved NextN tensors.

From what I can tell:

  • the three models with MTP implemented in vLLM right now are all "DeepseekV3-style,"
  • they only have one MTP head, which predicts the token at position n+2,
  • the MTP layers take as input the output embedding from the last conventional layer and their own input embedding (rough sketch of this dataflow below).
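
For reference, here is a rough graph-build outline of that dataflow, assuming GLM-4.5's NextN head follows the DeepseekV3 recipe. The ggml calls exist, but the tensor names (model.nextn_eh_proj, etc.), the concat order, and the surrounding builder context (ctx0, hparams) are assumptions, so treat this as a sketch rather than compiling code:

```cpp
// h_last   : output of the last conventional transformer layer   [n_embd, n_tokens]
// inp_embd : input embedding of the token the head conditions on [n_embd, n_tokens]
// ctx0, model.*, hparams stand in for whatever the real builder uses.

ggml_tensor * h = ggml_rms_norm(ctx0, h_last,   hparams.f_norm_rms_eps); // "hnorm"
ggml_tensor * e = ggml_rms_norm(ctx0, inp_embd, hparams.f_norm_rms_eps); // "enorm"

// concatenate the two normalized vectors along the embedding dim -> [2*n_embd, n_tokens]
ggml_tensor * eh = ggml_concat(ctx0, e, h, 0);

// project back down to n_embd ("eh_proj"); one regular transformer block plus the
// shared output norm + lm_head then run on the result to produce the n+2 logits
ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_eh_proj, eh);
```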

So implementation-wise it seems like:

  • we should try to reuse the existing speculative decode functionality (including nice stuff like main-model KV cache management, the various samplers, etc.),
  • but a lot of the full draft-model functionality is redundant or even harmful here, e.g. context/cache management for the draft model and vocab matching,
  • it probably makes sense to write a new function like mtp_speculative_gen_draft in speculative.cpp that is vastly simplified, and branch into it in server.cpp when a slot has MTP (versus common_speculative_gen_draft); a rough sketch follows this list.
  • AFAICT it looks like the server.cpp loop currently alternates between conventional forward pass and draft, which in the MTP case will probably sabotage performance gains (since our max throughput is only 1.5 tok/pass assuming zero rejections, instead of 2 tok/pass). Let me know if this isn't the case! But if it is, we should probably avoid doing non-speculative decodes after the first response token.
  • It doesn't make sense to manage a distinct ctx_dft in this case either. It's a bit hacky, but I was thinking we could just have ctx_dft = ctx and then have both normal and MTP passes write over the shared ctx logits. I think this minimizes the required code changes elsewhere.
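
To make that concrete, here is a minimal sketch of what the MTP draft path could look like, assuming some accessor (a hypothetical llama_get_logits_mtp_ith(), not in llama.h) that exposes the MTP head's logits from the main context. It is an outline of the idea, not working code:

```cpp
#include "common.h" // llama_tokens (std::vector<llama_token>) and common helpers
#include "llama.h"

// Hypothetical MTP counterpart of common_speculative_gen_draft(): no ctx_dft,
// no draft-model KV cache, no vocab matching - just read the extra logits the
// main context already produced for the last accepted token.
llama_tokens mtp_speculative_gen_draft(llama_context * ctx, int32_t i_last) {
    const llama_vocab * vocab   = llama_model_get_vocab(llama_get_model(ctx));
    const int32_t       n_vocab = llama_vocab_n_tokens(vocab);

    // hypothetical accessor for the MTP head's logits - not part of llama.h (yet)
    const float * logits = llama_get_logits_mtp_ith(ctx, i_last);

    // a greedy pick is enough for a draft; the conventional pass verifies it anyway
    llama_token best = 0;
    for (int32_t v = 1; v < n_vocab; ++v) {
        if (logits[v] > logits[best]) {
            best = v;
        }
    }

    return llama_tokens{ best };
}
```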

This is my first time (1) working with ML stuff outside of Python and (2) attempting to contribute, so patience is appreciated :)

@ggerganov added the "hot" label on Aug 12, 2025
@ggerganov (Member)

AFAICT it looks like the server.cpp loop currently alternates between conventional forward pass and draft, which in the MTP case will probably sabotage performance gains (since our max throughput is only 1.5 tok/pass assuming zero rejections, instead of 2 tok/pass). Let me know if this isn't the case! But if it is, we should probably avoid doing non-speculative decodes after the first response token.

This is correct - we always alternate between conventional and speculative passes. It's definitely not optimal, but it improves flexibility for regular sampling: it allows changing the speculative parameters and even disabling speculation per request, while the logic stays quite simple.

It should be possible to improve this by keeping track of which slots are speculating on each iteration and skipping adding tokens to the conventional batch for them (rough sketch below). It might be a good idea to implement this separately, to avoid huge changes to the logic in a single PR.
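
A rough sketch of that idea against the slot loop in server.cpp (the is_speculating flag is hypothetical; the rest mirrors existing structures, abbreviated):

```cpp
// when building the conventional batch in the server's update loop
for (auto & slot : slots) {
    if (slot.state != SLOT_STATE_GENERATING) {
        continue;
    }

    // hypothetical new per-slot flag: set while the slot's next tokens are being
    // produced by the speculative/MTP path this iteration
    if (slot.is_speculating) {
        continue; // skip adding this slot's token to the conventional batch
    }

    common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
    slot.n_past += 1;
}
```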

@ggerganov (Member)

Generally we should try to minimize the changes to llama.h, since changing/extending the public API requires a lot of effort.

On first look, I think the path that involves minimal changes is:

  • Add an int n_mtp flag to llama_context_params (default 1 = MTP disabled, 2 = predict logits for one additional token, 3 = predict logits for two additional tokens, etc.)
  • Use this flag during graph build to determine whether the MTP heads should be appended to the graph
  • Keep the conventional logits in the t_logits tensor in llm_graph_result
  • Add a new tensor t_logits_mtp (or whatever name is more appropriate) to llm_graph_result and use it to store the MTP results
  • In llama_decode(), extract the t_logits_mtp data when available, following the same logic as for t_logits (a header-level sketch of these additions follows this list)
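
A header-level sketch of those additions (everything marked "proposed" is new; the surrounding declarations are abbreviated and not the actual headers):

```cpp
// llama.h - proposed field in the context params
struct llama_context_params {
    // ... existing fields ...
    int32_t n_mtp; // 1 = MTP disabled (default), 2 = logits for 1 extra token, 3 = for 2, ...
};

// llama-graph.h - proposed tensor next to the existing t_logits
class llm_graph_result {
public:
    // ... existing members ...
    ggml_tensor * t_logits     = nullptr; // conventional logits (already there)
    ggml_tensor * t_logits_mtp = nullptr; // proposed: logits produced by the MTP head(s)
};
```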

Extracting the MTP logits during llama_decode() can be done in 2 ways:

  • Create a separate buffer in the llama_context to store them, and add a new llama_get_logits_mtp_ith() API that works with that buffer in a similar way to the existing llama_get_logits_ith()
  • Reuse the existing logits buffer by expanding it from [n_outputs][n_vocab] to [n_outputs][n_mtp*n_vocab]. This would avoid the need to add llama_get_logits_mtp_ith(), and we can generalize the existing llama_get_logits_ith() by taking the value of n_mtp into account (rough indexing sketch below).
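
For the second option, the indexing change is small. Roughly (a sketch of the proposed layout, not existing code):

```cpp
#include <cstdint>

// Logits buffer laid out as [n_outputs][n_mtp * n_vocab]: for output row i,
// k == 0 is the conventional distribution and k == 1 .. n_mtp-1 are the
// distributions predicted for the additional tokens.
const float * logits_ith(const float * buf, int64_t n_vocab, int64_t n_mtp,
                         int64_t i, int64_t k) {
    return buf + (i * n_mtp + k) * n_vocab;
}
```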

Currently, I am not sure which way is better. The first requires a new API call, while the second might break some existing assumptions (not sure if that's the case yet).

In any case, you can defer this decision until you get the implementation working with a reasonable speedup. After that, we can discuss how to best refactor the implementation.

@slaren (Member) commented Aug 13, 2025

Currently, I am not sure which way is better. The first requires a new API call, while the second might break some existing assumptions (not sure if that's the case yet).

I don't see an issue with adding a new API for this, and it would be easier to use.
