diff --git a/common/speculative.cpp b/common/speculative.cpp
index 843bd1ddb..5a67b46e1 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -278,3 +278,153 @@ llama_tokens common_speculative_gen_draft(
 
     return result;
 }
+
+llama_tokens common_speculative_gen_draft_eagle(
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt_tgt,
+        llama_token id_last,
+        std::vector<uint8_t> & data) {
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;
+
+    auto * mem = llama_get_memory(ctx);
+
+    int reuse_i = 0;
+    int reuse_n = 0;
+
+    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+
+    const int i_start = std::max(1, (int) prompt_tgt.size() - n_ctx);
+
+    int n_accepted_draft_tokens = data.size() / sizeof(float) / llama_model_n_embd(llama_get_model(ctx)) - 1;
+
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
+    for (int i = 0; i < (int) prompt.size(); ++i) {
+        int cur = 0;
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i       + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+            cur++;
+        }
+
+        cur = (cur - n_accepted_draft_tokens) > 0 ? (cur - n_accepted_draft_tokens) : cur;
+
+        if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
+            reuse_i = i;
+            reuse_n = cur;
+        }
+    }
+
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+
+    llama_tokens result;
+    result.reserve(params.n_draft);
+
+    if (reuse_n == 0) {
+        llama_memory_clear(mem, false);
+
+        prompt.clear();
+    } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
+        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
+                result.push_back(prompt[i]);
+
+                if (params.n_draft <= (int) result.size()) {
+                    break;
+                }
+            }
+
+            return result;
+        }
+
+        if (reuse_i > 0) {
+            llama_memory_seq_rm (mem, 0, 0, reuse_i);
+            llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
+
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_memory_seq_rm (mem, 0, reuse_n, -1);
+
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
+    }
+
+    // prepare a batch to evaluate any new tokens in the prompt
+    common_batch_clear(batch);
+
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, (i < prompt_tgt.size() - 1) ? false : true);
+
+        prompt.push_back(prompt_tgt[i]);
+    }
+
+    // we should rarely end-up here during normal decoding
+    if (batch.n_tokens > 0) {
+        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+
+        llama_decode_eagle(ctx, batch, data.data());
+    }
+
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+    common_batch_clear(batch);
+    common_batch_add (batch, id_last, n_past, { 0 }, true);
+
+    prompt.push_back(id_last);
+
+    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+
+    llama_decode_eagle(ctx, batch, data.data());
+
+    common_sampler_reset(smpl);
+
+    // sample n_draft tokens from the draft model
+    for (int i = 0; i < params.n_draft; ++i) {
+        common_batch_clear(batch);
+
+        common_sampler_sample(smpl, ctx, -1, true);
+
+        const auto * cur_p = common_sampler_get_candidates(smpl);
+
+        for (int k = 0; k < std::min(1, (int) cur_p->size); ++k) {
+            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        }
+
+        // add drafted token for each sequence
+        const llama_token id = cur_p->data[0].id;
+
+        common_sampler_accept(smpl, id, true);
+
+        result.push_back(id);
+
+        if (params.n_draft <= (int) result.size()) {
+            break;
+        }
+
+        // only collect very high-confidence draft tokens
+        if (cur_p->data[0].p < params.p_min) {
+            break;
+        }
+
+        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+
+        // evaluate the drafted tokens on the draft model
+        llama_decode_eagle(ctx, batch, data.data());
+
+        prompt.push_back(id);
+    }
+
+    return result;
+}
diff --git a/common/speculative.h b/common/speculative.h
index 2b51a70ca..4f1725f1e 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -26,3 +26,10 @@ llama_tokens common_speculative_gen_draft(
         struct common_speculative_params params,
         const llama_tokens & prompt,
         llama_token id_last);
+
+llama_tokens common_speculative_gen_draft_eagle(
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt,
+        llama_token id_last,
+        std::vector<uint8_t> & data);
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 49e4d2cf8..0f313bff1 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -31,6 +31,7 @@ else()
     add_subdirectory(simple-chat)
     add_subdirectory(speculative)
     add_subdirectory(speculative-simple)
+    add_subdirectory(speculative-simple-eagle)
     add_subdirectory(gen-docs)
     add_subdirectory(training)
     if (NOT GGML_BACKEND_DL)
diff --git a/examples/speculative-simple-eagle/CMakeLists.txt b/examples/speculative-simple-eagle/CMakeLists.txt
new file mode 100644
index 000000000..1013ca88e
--- /dev/null
+++ b/examples/speculative-simple-eagle/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-speculative-simple-eagle)
+add_executable(${TARGET} speculative-simple-eagle.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/speculative-simple-eagle/README.md b/examples/speculative-simple-eagle/README.md
new file mode 100644
index 000000000..43cf9d81d
--- /dev/null
+++ b/examples/speculative-simple-eagle/README.md
@@ -0,0 +1,12 @@
+# llama.cpp/examples/speculative-simple-eagle
+
+Demonstration of basic greedy speculative decoding for EAGLE
+
+```bash
+./bin/llama-speculative-simple-eagle \
+    -m  ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \
+    -md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \
+    -f test.txt -c 0 -ngl 99 --color \
+    --sampling-seq k --top-k 1 -fa --temp 0.0 \
+    -ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9
+```
diff --git a/examples/speculative-simple-eagle/speculative-simple-eagle.cpp b/examples/speculative-simple-eagle/speculative-simple-eagle.cpp
new file mode 100644
index 000000000..ceecb4fb5
--- /dev/null
+++ b/examples/speculative-simple-eagle/speculative-simple-eagle.cpp
@@ -0,0 +1,301 @@
+#include "arg.h"
+#include "common.h"
+#include "sampling.h"
+#include "speculative.h"
+#include "log.h"
+#include "llama.h"
+#include "../src/llama-model.h"
+
+#include <cstdio>
+#include <cstring>
+#include <string>
+#include <vector>
+
+struct callback_data {
+    std::vector<uint8_t> data;
+};
+
+static bool cb_get_hidden(struct ggml_tensor * tensor, bool ask, void * user_data) {
+    if (ask) {
+        static const char * result_norm_name = "result_norm";
+        const bool is_result_norm = strcmp(tensor->name, result_norm_name) == 0;
+        return is_result_norm;
+    }
+
+    auto * cb_data = (struct callback_data *) user_data;
+    auto n_bytes = ggml_nbytes(tensor);
+    cb_data->data.resize(n_bytes);
+    ggml_backend_tensor_get(tensor, cb_data->data.data(), 0, n_bytes);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
+        return 1;
+    }
+
+    if (params.n_predict < -1) {
+        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+        return 1;
+    }
+
+    common_init();
+
+    if (params.speculative.model.path.empty()) {
+        LOG_ERR("%s: --model-draft is required\n", __func__);
+        return 1;
+    }
+
+    // init llama.cpp
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    callback_data cb_data;
+    params.cb_eval           = cb_get_hidden;
+    params.cb_eval_user_data = &cb_data;
+
+    llama_model * model_tgt = NULL;
+    //llama_model * model_dft = NULL;
+
+    llama_context * ctx_tgt = NULL;
+    llama_context * ctx_dft = NULL;
+
+    // load the target model
+    common_init_result llama_init_tgt = common_init_from_params(params);
+
+    model_tgt = llama_init_tgt.model.get();
+    ctx_tgt   = llama_init_tgt.context.get();
+
+    const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
+
+    // load the draft model
+    params.devices      = params.speculative.devices;
+    params.model        = params.speculative.model;
+    params.n_ctx        = params.speculative.n_ctx;
+    params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
+    params.n_gpu_layers = params.speculative.n_gpu_layers;
+
+    if (params.speculative.cpuparams.n_threads > 0) {
+        params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
+    }
+
+    params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+    common_init_result llama_init_dft = common_init_from_params(params);
+
+    //model_dft = llama_init_dft.model.get();
+    ctx_dft = llama_init_dft.context.get();
+
+    // Trick: if the output buffer is in host memory, we need to allocate a new buffer for the draft model
+    if (ggml_backend_buffer_is_host(llama_get_model(ctx_dft)->output->buffer)) {
+        void * data = malloc(ggml_nbytes(llama_get_model(ctx_tgt)->output));
+        llama_get_model(ctx_dft)->output->data = data;
+    }
+    // copy output parameters from target to draft
+    ggml_backend_tensor_copy(llama_get_model(ctx_tgt)->output, llama_get_model(ctx_dft)->output);
+
+    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
+        return 1;
+    }
+
+    // Tokenize the prompt
+    std::vector<llama_token> inp;
+    inp = common_tokenize(ctx_tgt, params.prompt, true, true);
+
+    if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+
+        return 1;
+    }
+
+    if (llama_n_batch(ctx_tgt) < (uint32_t) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
+
+        return 1;
+    }
+
+    LOG("\n\n");
+
+    for (auto id : inp) {
+        LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
+    }
+
+    // how many tokens to draft each time
+    int n_draft     = params.speculative.n_max;
+    int n_draft_min = params.speculative.n_min;
+
+    float p_min = params.speculative.p_min;
+
+    int n_predict = 0;
+    int n_drafted = 0;
+    int n_accept  = 0;
+
+    // used to determine end of generation
+    bool has_eos = false;
+
+    // ================================================
+    // everything until here is standard initialization
+    // the relevant stuff for speculative decoding starts here
+
+    const auto t_enc_start = ggml_time_us();
+
+    // target model sampling context
+    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
+
+    llama_batch temp_batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
+    int temp_n_past = 0;
+    for (size_t i = 0; i < inp.size(); i++) {
+        common_batch_add(temp_batch_tgt, inp[i], temp_n_past++, { 0 }, true);
+    }
+
+    // eval the prompt
+    llama_decode(ctx_tgt, temp_batch_tgt);
+
+    // note: keep the last token separate!
+    llama_token id_last = inp.back();
+
+    // all tokens currently in the target context
+    llama_tokens prompt_tgt(inp.begin(), inp.end());
+    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
+
+    int n_past = inp.size() + 1;
+
+    // init the speculator
+    struct common_speculative_params params_spec;
+    params_spec.n_draft = n_draft;
+    params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
+    params_spec.p_min   = p_min;
+
+    struct common_speculative * spec = common_speculative_init(ctx_dft);
+
+    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
+
+    const auto t_enc_end = ggml_time_us();
+
+    const auto t_dec_start = ggml_time_us();
+
+    while (true) {
+        // optionally, generate draft tokens that can be appended to the target batch
+        //
+        // this is the most important part of the speculation. the more probable tokens that are provided here
+        // the better the performance will be. in theory, this computation can be performed asynchronously and even
+        // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
+        // from a cache or lookup tables.
+        //
+        llama_tokens draft = common_speculative_gen_draft_eagle(spec, params_spec, prompt_tgt, id_last, cb_data.data);
+
+        //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
+
+        // always have a token to evaluate from before - id_last
+        common_batch_clear(batch_tgt);
+        common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
+
+        // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
+        {
+            // do not waste time on small drafts
+            if (draft.size() < (size_t) n_draft_min) {
+                draft.clear();
+            }
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+            }
+
+            //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
+
+            llama_decode(ctx_tgt, batch_tgt);
+        }
+
+        // sample from the full target batch and return the accepted tokens based on the target sampler
+        //
+        // for each token to be accepted, the sampler would have to sample that same token
+        // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
+        // available logits from the batch and sample the next token until we run out of logits or the sampler
+        // disagrees with the draft
+        //
+        const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
+
+        //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
+
+        GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
+
+        n_past    += ids.size() - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
+        n_accept  += ids.size() - 1;
+        n_predict += ids.size();
+
+        // process the accepted tokens and update contexts
+        //
+        // this is the standard token post-processing that we normally do
+        // in this case, we do it for a group of accepted tokens at once
+        //
+        for (size_t i = 0; i < ids.size(); ++i) {
+            prompt_tgt.push_back(id_last);
+
+            id_last = ids[i];
+
+            if (llama_vocab_is_eog(vocab, id_last)) {
+                has_eos = true;
+                break;
+            }
+
+            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
+
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
+            }
+        }
+
+        cb_data.data.resize(ids.size() * sizeof(float) * llama_model_n_embd(model_tgt));
+
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
+        }
+
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
+        }
+    }
+
+    auto t_dec_end = ggml_time_us();
+
+    const int n_input = inp.size();
+
+    LOG("\n\n");
+
+    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
+
+    LOG_INF("\n");
+    LOG_INF("n_draft   = %d\n", n_draft);
+    LOG_INF("n_predict = %d\n", n_predict);
+    LOG_INF("n_drafted = %d\n", n_drafted);
+    LOG_INF("n_accept  = %d\n", n_accept);
+    LOG_INF("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
LOG_INF("\n"); + LOG_INF("draft:\n\n"); + + llama_perf_context_print(ctx_dft); + + LOG_INF("\n"); + LOG_INF("target:\n\n"); + common_perf_print(ctx_tgt, smpl); + + common_sampler_free(smpl); + common_speculative_free(spec); + + llama_backend_free(); + + LOG("\n\n"); + + return 0; +} diff --git a/include/llama.h b/include/llama.h index 635508b10..4842d90d3 100644 --- a/include/llama.h +++ b/include/llama.h @@ -953,6 +953,11 @@ extern "C" { struct llama_context * ctx, struct llama_batch batch); + LLAMA_API int32_t llama_decode_eagle( + struct llama_context * ctx, + struct llama_batch batch, + void * data); + // Set the number of threads used for decoding // n_threads is the number of threads used for generation (single token) // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 0bc60565d..ac1a7d298 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -75,6 +75,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_BAILINGMOE, "bailingmoe" }, { LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_ARCEE, "arcee" }, + { LLM_ARCH_EAGLE, "eagle" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1620,6 +1621,23 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, } }, + { + LLM_ARCH_EAGLE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_EMBD_FC, "fc" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1633,6 +1651,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_EMBD_FC, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 51b242c66..ee65100a9 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -79,6 +79,7 @@ enum llm_arch { LLM_ARCH_BAILINGMOE, LLM_ARCH_DOTS1, LLM_ARCH_ARCEE, + LLM_ARCH_EAGLE, LLM_ARCH_UNKNOWN, }; @@ -229,6 +230,7 @@ enum llm_tensor { LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_TOKEN_TYPES, LLM_TENSOR_POS_EMBD, + LLM_TENSOR_EMBD_FC, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index f56a58e9b..60f4441c7 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -721,6 +721,50 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, return res; } +llm_graph_result_ptr llama_context::process_ubatch_eagle(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret, void * data) { + if (mstate && !mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto * gf = graph_init(); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to initialize 
graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } + + res->set_inputs(&ubatch); + ggml_backend_tensor_set(res->get_hidden_states(), data, 0, ggml_element_size(res->get_hidden_states()) * model.hparams.n_embd * ubatch.n_tokens); + + const auto status = graph_compute(gf, ubatch.n_tokens > 1); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); + ret = status; + return nullptr; + } + + ret = GGML_STATUS_SUCCESS; + + return res; +} + int llama_context::encode(const llama_batch & batch_inp) { if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -1208,6 +1252,331 @@ int llama_context::decode(const llama_batch & batch_inp) { return 0; } +int llama_context::decode_eagle(const llama_batch & batch_inp, void * data) { + if (!memory) { + LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); + return encode(batch_inp); + } + + if (batch_inp.n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); + return -1; + } + + // when computing embeddings, all tokens are output + const bool embd_all = cparams.embeddings; + + if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return -1; + } + + const llama_batch & batch = batch_allocr->get_batch(); + + const auto & vocab = model.vocab; + const auto & hparams = model.hparams; + + const int32_t n_vocab = vocab.n_tokens(); + const int64_t n_embd = hparams.n_embd; + + const uint32_t n_tokens_all = batch.n_tokens; + + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + + const uint32_t n_outputs_all = batch_allocr->get_n_outputs(); + + if (embd_all) { + // require that all tokens are output + if (n_outputs_all != n_tokens_all) { + LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", + __func__, n_outputs_all, n_tokens_all); + return -1; + } + } + + GGML_ASSERT(n_tokens_all <= cparams.n_batch); + + GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + n_queued_tokens += n_tokens_all; + + // TODO: this clear of the buffer can easily be forgotten - need something better + embd_seq.clear(); + + bool did_optimize = false; + + // handle any pending defrags/shifts + kv_self_update(false); + + llama_memory_state_ptr mstate; + + while (true) { + mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all); + if (!mstate) { + return -2; + } + + switch (mstate->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + } break; + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status()); + + return -2; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + { + if (!did_optimize) 
+                        did_optimize = true;
+
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
+
+                    return -2;
+                }
+        }
+
+        break;
+    }
+
+    // reserve output buffer
+    if (output_reserve(n_outputs_all) < n_outputs_all) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
+        return -2;
+    };
+
+    int64_t n_outputs_prev = 0;
+
+    do {
+        const auto & ubatch = mstate->get_ubatch();
+
+        // count the outputs in this ubatch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
+
+            // needs to happen before the graph is built
+            n_outputs = n_outputs_new;
+        }
+
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        ggml_status status;
+        const auto res = process_ubatch_eagle(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status, data);
+
+        if (!res) {
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            llama_pos pos_min[LLAMA_MAX_SEQ];
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+                pos_min[s] = std::numeric_limits<llama_pos>::max();
+            }
+
+            // TODO: fix sequence indexing
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                const auto & seq_id = ubatch.seq_id[i][0];
+
+                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+            }
+
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                    continue;
+                }
+
+                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+
+                memory->seq_rm(s, pos_min[s], -1);
+            }
+
+            switch (status) {
+                case GGML_STATUS_ABORTED:      return 2;
+                case GGML_STATUS_ALLOC_FAILED: return -2;
+                case GGML_STATUS_FAILED:       return -3;
+                case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
+            }
+        }
+
+        // plot the computation graph in dot format (for debugging purposes)
+        //if (n_past%100 == 0) {
+        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
+        //}
+
+        auto * t_logits = res->get_logits();
+        auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
+
+        if (t_embd && res->get_embd_pooled()) {
+            t_embd = res->get_embd_pooled();
+        }
+
+        // extract logits
+        if (t_logits && n_outputs > 0) {
+            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(logits != nullptr);
+
+            float * logits_out = logits + n_outputs_prev*n_vocab;
+
+            if (n_outputs) {
+                GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+                ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
+            }
+        }
+
+        // extract embeddings
+        if (t_embd && n_outputs > 0) {
+            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
+            GGML_ASSERT(backend_embd != nullptr);
+
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        GGML_ASSERT(embd != nullptr);
+                        float * embd_out = embd + n_outputs_prev*n_embd;
+
+                        if (n_outputs) {
+                            GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
+                    {
+                        // extract sequence embeddings (cleared before processing each batch)
+                        auto & embd_seq_out = embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // extract the rerank score - a single float per sequence
+                        auto & embd_seq_out = embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(1);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ABORT("unknown pooling type");
+                    }
+            }
+        }
+
+        n_outputs_prev += n_outputs;
+    } while (mstate->next());
+
+    // set to total number of outputs in the batch, for use in llama_get_logits_ith
+    n_outputs = n_outputs_all;
+
+    // set output mappings
+    if (n_outputs > 0) {
+        bool sorted_output = true;
+
+        auto & out_ids = mstate->out_ids();
+
+        GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
+
+        for (int64_t i = 0; i < n_outputs; ++i) {
+            int64_t out_id = out_ids[i];
+            output_ids[out_id] = i;
+            if (out_id != i) {
+                sorted_output = false;
+            }
+        }
+
+        // make the outputs have the same order they had in the user-provided batch
+        // note: this is mostly relevant for recurrent models atm
+        if (!sorted_output) {
+            const uint32_t n_vocab = model.vocab.n_tokens();
+            const uint64_t n_embd  = model.hparams.n_embd;
+
+            GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+            // TODO: is there something more efficient which also minimizes swaps?
+            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+            for (uint32_t i = 0; i < n_outputs - 1; ++i) {
+                uint32_t j_min = i;
+                for (uint32_t j = i + 1; j < n_outputs; ++j) {
+                    if (out_ids[j] < out_ids[j_min]) {
+                        j_min = j;
+                    }
+                }
+                if (j_min == i) {
+                    continue;
+                }
+                std::swap(out_ids[i], out_ids[j_min]);
+                if (logits_size > 0) {
+                    for (uint32_t k = 0; k < n_vocab; k++) {
+                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+                    }
+                }
+                if (embd_size > 0) {
+                    for (uint32_t k = 0; k < n_embd; k++) {
+                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+                    }
+                }
+            }
+
+            std::fill(output_ids.begin(), output_ids.end(), -1);
+
+            for (uint32_t i = 0; i < n_outputs; ++i) {
+                output_ids[out_ids[i]] = i;
+            }
+        }
+    }
+
+    // wait for the computation to finish (automatically done when obtaining the model output)
+    //synchronize();
+
+    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+    // overlap with device computation.
+    ggml_backend_sched_reset(sched.get());
+
+    return 0;
+}
+
 //
 // output
 //
@@ -2786,6 +3155,18 @@ int32_t llama_decode(
     return ret;
 }
 
+int32_t llama_decode_eagle(
+        llama_context * ctx,
+        llama_batch batch,
+        void * data) {
+    const int ret = ctx->decode_eagle(batch, data);
+    if (ret != 0 && ret != 1) {
+        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
 //
 // perf
 //
diff --git a/src/llama-context.h b/src/llama-context.h
index 040f03ae4..367a61cac 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -102,9 +102,18 @@ struct llama_context {
             llama_memory_state_i * mstate,
             ggml_status & ret);
 
+    llm_graph_result_ptr process_ubatch_eagle(
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            llama_memory_state_i * mstate,
+            ggml_status & ret,
+            void * data);
+
     int encode(const llama_batch & batch_inp);
     int decode(const llama_batch & batch_inp);
 
+    int decode_eagle(const llama_batch & batch_inp, void * data);
+
     //
     // state save/load
     //
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 65d98cbbb..79bac1b91 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -914,6 +914,26 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     return cur;
 }
 
+ggml_tensor * llm_graph_context::build_inp_embd_fc(ggml_tensor * embd, ggml_tensor * fc, ggml_tensor * fc_b) const {
+    ggml_tensor * cur = nullptr;
+
+    ggml_tensor * hidden_states = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, ubatch.n_tokens);
+    ggml_set_input(hidden_states);
+
+    cur = ggml_concat(ctx0, embd, hidden_states, 0);
+    cur = ggml_mul_mat(ctx0, fc, cur);
+    if (fc_b) {
+        cur = ggml_add(ctx0, cur, fc_b);
+    }
+
+    cb(cur, "inp_embd_fc_out", -1);
+
+    res->set_hidden_states(hidden_states);
+
+    return cur;
+}
+
+
 ggml_tensor * llm_graph_context::build_inp_pos() const {
     auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 58845e284..c8182843e 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -347,6 +347,9 @@ class llm_graph_result_i {
     virtual ggml_tensor * get_embd() = 0;
     virtual ggml_tensor * get_embd_pooled() = 0;
 
+    virtual void set_hidden_states(ggml_tensor * hidden_states) = 0;
+    virtual ggml_tensor * get_hidden_states() = 0;
+
     virtual void set_inputs(const llama_ubatch * ubatch) = 0;
 };
 
@@ -362,6 +365,13 @@ class llm_graph_result : public llm_graph_result_i {
     ggml_tensor * get_embd() override { return t_embd; }
     ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
 
+    void set_hidden_states(ggml_tensor * hidden_states) override {
+        t_hidden_states = hidden_states;
+    }
+    ggml_tensor * get_hidden_states() override {
+        return t_hidden_states;
+    }
+
     void set_inputs(const llama_ubatch * ubatch) override {
         for (auto & input : inputs) {
             input->set_input(ubatch);
@@ -379,6 +389,8 @@ class llm_graph_result : public llm_graph_result_i {
     ggml_tensor * t_embd        = nullptr;
     ggml_tensor * t_embd_pooled = nullptr;
 
+    ggml_tensor * t_hidden_states = nullptr;
+
     std::vector<llm_graph_input_ptr> inputs;
 };
 
@@ -531,6 +543,7 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
+    ggml_tensor * build_inp_embd_fc(ggml_tensor * embd, ggml_tensor * fc, ggml_tensor * fc_b) const;
     ggml_tensor * build_inp_pos() const;
     ggml_tensor * build_inp_attn_scale() const;
     ggml_tensor * build_inp_out_ids() const;
diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h
index 0f52b011b..5e6f5f65e 100644
--- a/src/llama-model-loader.h
+++ b/src/llama-model-loader.h
@@ -60,6 +60,7 @@ struct llama_model_loader {
     static const int TENSOR_NOT_REQUIRED = 1;
     static const int TENSOR_DUPLICATED   = 2;
+    static const int TENSOR_DUPLICATED_EAGLE = 4;
 
     int n_kv      = 0;
     int n_tensors = 0;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index a5853f8b1..1c52dc9fa 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1484,6 +1484,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_EAGLE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -1613,6 +1617,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     const auto TENSOR_DUPLICATED   = llama_model_loader::TENSOR_DUPLICATED;
     const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED;
+    const auto TENSOR_DUPLICATED_EAGLE = llama_model_loader::TENSOR_DUPLICATED_EAGLE;
 
     // create tensors for the weights
     {
@@ -1660,6 +1665,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             tn_tensor = LLM_TENSOR_OUTPUT;
         }
 
+        if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & TENSOR_DUPLICATED_EAGLE) {
+            tn_tensor = LLM_TENSOR_OUTPUT;
+        }
+
         llm_tensor_info info;
         try {
             info = llm_tensor_info_for(tn_tensor);
@@ -4268,6 +4277,38 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                }
+            } break;
+        case LLM_ARCH_EAGLE:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                embd_fc   = create_tensor(tn(LLM_TENSOR_EMBD_FC, "weight"), {n_embd * 2, n_embd}, 0);
+                embd_fc_b = create_tensor(tn(LLM_TENSOR_EMBD_FC, "bias"),   {n_embd}, 0);
+
+                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED_EAGLE);
+                    ml.n_created--;
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+                    layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                     layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                     layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                 }
@@ -13731,6 +13772,161 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+struct llm_build_eagle : public llm_graph_context {
+    llm_build_eagle(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // For EAGLE architecture
+        inpL = build_inp_embd_fc(inpL, model.embd_fc, model.embd_fc_b);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = inpL;
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network (non-MoE)
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(cur, "ffn_moe_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
@@ -14102,6 +14298,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_arcee>(*this, params, gf);
             } break;
+        case LLM_ARCH_EAGLE:
+            {
+                llm = std::make_unique<llm_build_eagle>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -14227,6 +14427,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_RWKV7:
         case LLM_ARCH_ARWKV7:
         case LLM_ARCH_WAVTOKENIZER_DEC:
+        case LLM_ARCH_EAGLE:
             return LLAMA_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
diff --git a/src/llama-model.h b/src/llama-model.h
index 06e6c6879..12469254e 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -338,6 +338,8 @@ struct llama_model {
     struct ggml_tensor * pos_embd   = nullptr;
     struct ggml_tensor * tok_norm   = nullptr;
    struct ggml_tensor * tok_norm_b = nullptr;
+    struct ggml_tensor * embd_fc    = nullptr;
+    struct ggml_tensor * embd_fc_b  = nullptr;
 
     struct ggml_tensor * output_norm   = nullptr;
     struct ggml_tensor * output_norm_b = nullptr;
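Note on usage: the sketch below condenses the decode loop from `examples/speculative-simple-eagle/speculative-simple-eagle.cpp` above, to show how the new pieces fit together (the eval callback capturing the target's `result_norm` hidden states, `common_speculative_gen_draft_eagle`, and target-side verification). It is a minimal illustration only; it assumes the models, contexts, sampler, and `params_spec` are initialized exactly as in the example, and it omits EOS handling, logging, and the small-draft filter.

```cpp
// capture the target model's final hidden states ("result_norm") as raw bytes
callback_data cb_data;
params.cb_eval           = cb_get_hidden;  // defined in the example above
params.cb_eval_user_data = &cb_data;

while (true) {
    // 1. draft tokens with the EAGLE head, conditioned on the captured hidden states
    llama_tokens draft = common_speculative_gen_draft_eagle(spec, params_spec, prompt_tgt, id_last, cb_data.data);

    // 2. evaluate [id_last, draft0, ..., draftN-1] on the target model
    //    (cb_get_hidden refills cb_data.data during this call)
    common_batch_clear(batch_tgt);
    common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true);
    for (size_t i = 0; i < draft.size(); ++i) {
        common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
    }
    llama_decode(ctx_tgt, batch_tgt);

    // 3. accept the longest prefix of the draft that the target sampler agrees with
    const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);

    n_past += ids.size() - 1;
    for (size_t i = 0; i < ids.size(); ++i) {
        prompt_tgt.push_back(id_last);
        id_last = ids[i];
    }

    // keep only the hidden states of the accepted tokens for the next draft call
    cb_data.data.resize(ids.size() * sizeof(float) * llama_model_n_embd(model_tgt));

    // 4. drop the rejected tail of the draft from the target KV cache
    llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
}
```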