150 changes: 150 additions & 0 deletions common/speculative.cpp
@@ -278,3 +278,153 @@ llama_tokens common_speculative_gen_draft(

return result;
}

llama_tokens common_speculative_gen_draft_eagle(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt_tgt,
llama_token id_last,
std::vector<uint8_t> & data) {
auto & batch = spec->batch;
auto & ctx = spec->ctx;
auto & smpl = spec->smpl;
auto & prompt = spec->prompt;

auto * mem = llama_get_memory(ctx);

int reuse_i = 0;
int reuse_n = 0;

const int n_ctx = llama_n_ctx(ctx) - params.n_draft;

const int i_start = std::max<int>(1, (int) prompt_tgt.size() - n_ctx);

    // data packs the target model's hidden states as floats; the number of previously
    // accepted draft tokens is the number of n_embd-sized states minus one
    int n_accepted_draft_tokens = data.size() / sizeof(float) / llama_model_n_embd(llama_get_model(ctx)) - 1;

// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt_tgt.size() &&
i + cur < (int) prompt.size() &&
prompt_tgt[i_start + cur] == prompt[i + cur]) {
cur++;
}

        // shorten the reusable match by the number of previously accepted draft tokens
        // (unless that would make it non-positive) so those positions are re-evaluated below
        cur = (cur - n_accepted_draft_tokens) > 0 ? (cur - n_accepted_draft_tokens) : cur;

if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
reuse_i = i;
reuse_n = cur;
}
}

LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());

llama_tokens result;
result.reserve(params.n_draft);

if (reuse_n == 0) {
llama_memory_clear(mem, false);

prompt.clear();
} else {
// this happens when a previous draft has been discarded (for example, due to being too small), but the
// target model agreed with it. in this case, we simply pass back the previous results to save compute
if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
result.push_back(prompt[i]);

if (params.n_draft <= (int) result.size()) {
break;
}
}

return result;
}

if (reuse_i > 0) {
llama_memory_seq_rm (mem, 0, 0, reuse_i);
llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);

prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
}

if (reuse_n < (int) prompt.size()) {
llama_memory_seq_rm (mem, 0, reuse_n, -1);

prompt.erase(prompt.begin() + reuse_n, prompt.end());
}
}

// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);

for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, (i < prompt_tgt.size() - 1) ? false : true);

prompt.push_back(prompt_tgt[i]);
}

    // we should rarely end up here during normal decoding
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

        // unlike the non-EAGLE path, the decode also receives the packed target hidden states
        llama_decode_eagle(ctx, batch, data.data());
}

const llama_pos n_past = prompt.size();

LOG_DBG("%s: n_past = %d\n", __func__, n_past);

common_batch_clear(batch);
common_batch_add (batch, id_last, n_past, { 0 }, true);

prompt.push_back(id_last);

//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());

    // evaluate id_last through the EAGLE draft, again passing the target hidden states
    llama_decode_eagle(ctx, batch, data.data());

common_sampler_reset(smpl);

// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);

common_sampler_sample(smpl, ctx, -1, true);

const auto * cur_p = common_sampler_get_candidates(smpl);

for (int k = 0; k < std::min(1, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
}

// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;

common_sampler_accept(smpl, id, true);

result.push_back(id);

if (params.n_draft <= (int) result.size()) {
break;
}

// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}

common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

// evaluate the drafted tokens on the draft model
llama_decode_eagle(ctx, batch, data.data());

prompt.push_back(id);
}

return result;
}
7 changes: 7 additions & 0 deletions common/speculative.h
@@ -26,3 +26,10 @@ llama_tokens common_speculative_gen_draft(
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);

llama_tokens common_speculative_gen_draft_eagle(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last,
std::vector<uint8_t> & data);
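A minimal caller sketch for the new entry point (not part of this diff; the wrapper name and the way the hidden states are obtained are assumptions). Judging from the size arithmetic in `common_speculative_gen_draft_eagle`, `data` is expected to hold the target model's hidden states as a flat float buffer of `(n_accepted_draft_tokens + 1) * n_embd` values, passed as raw bytes:

```cpp
// Hypothetical usage sketch -- not from the PR. Obtaining the target hidden
// states (one n_embd-sized vector per previously accepted draft token plus
// one more) is left to the caller.
#include <cstring>
#include <vector>

#include "speculative.h"

static llama_tokens draft_next_tokens(
        common_speculative        * spec,
        common_speculative_params   params,
        const llama_tokens        & prompt_tgt,
        llama_token                 id_last,
        const std::vector<float>  & hidden_states) {
    // repack the float hidden states into the byte buffer the API expects
    std::vector<uint8_t> data(hidden_states.size() * sizeof(float));
    std::memcpy(data.data(), hidden_states.data(), data.size());

    return common_speculative_gen_draft_eagle(spec, params, prompt_tgt, id_last, data);
}
```

The existing `common_speculative_gen_draft` call site in `examples/speculative-simple` is the natural reference for the surrounding decode loop.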
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -31,6 +31,7 @@ else()
add_subdirectory(simple-chat)
add_subdirectory(speculative)
add_subdirectory(speculative-simple)
add_subdirectory(speculative-simple-eagle)
add_subdirectory(gen-docs)
add_subdirectory(training)
if (NOT GGML_BACKEND_DL)
5 changes: 5 additions & 0 deletions examples/speculative-simple-eagle/CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET llama-speculative-simple-eagle)
add_executable(${TARGET} speculative-simple-eagle.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
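With the usual llama.cpp CMake workflow this registers a standalone executable target, built with e.g. `cmake --build build --target llama-speculative-simple-eagle`.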
12 changes: 12 additions & 0 deletions examples/speculative-simple-eagle/README.md
@@ -0,0 +1,12 @@
# llama.cpp/examples/speculative-simple-eagle

Demonstration of basic greedy speculative decoding with an EAGLE draft model

```bash
./bin/llama-speculative-simple-eagle \
-m ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \
-md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \
-f test.txt -c 0 -ngl 99 --color \
--sampling-seq k --top-k 1 -fa --temp 0.0 \
-ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9
```
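
The draft-control flags behave as in `llama-speculative-simple`: `--draft-max` bounds how many tokens are drafted per step (`params.n_draft`), and `--draft-p-min` is the confidence cutoff below which `common_speculative_gen_draft_eagle` stops drafting early (`params.p_min`).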