Why fused KV injection instead of independent per-layer injection? #26

@shreyas269

Hi @jianc99, great work!

Currently, DFlash extracts hidden states from 5 target layers, concatenates them (5 × d_target), projects down to d_target via a learned projection, and injects the same fused vector into every draft layer's KV cache.

Have you considered a simpler alternative? Instead of fusing, map the hidden states from a subset of target layers (still 5) directly to the corresponding draft layers. E.g., say the target has 40 layers and the draft has 5; then

Draft layer 1 ← Target layer 1 hidden state
Draft layer 2 ← Target layer 8 hidden state
...and so on.
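The mapping above could be something like an evenly spaced selection of target layers (a hypothetical sketch; the exact indices and stride are my assumption, not anything from the DFlash code):

```python
# Hypothetical evenly spaced layer mapping: 40 target layers -> 5 draft layers.
n_target, n_draft = 40, 5
stride = n_target // n_draft                      # 8
# mapping[i] = index of the target layer whose hidden state feeds draft layer i
mapping = [i * stride for i in range(n_draft)]
print(mapping)  # [0, 8, 16, 24, 32]
```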

Since the draft model shares d_target with the target model, no projection is needed at all: each hidden state can go directly through the corresponding draft layer's W_K and W_V. This avoids the lossy 5 × d_target → d_target compression, removes the extra projection layer entirely, and still doesn't restrict information access across target layers, since draft layers are stacked and information from earlier injections propagates through the residual stream.
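To make the contrast concrete, here is a minimal numpy sketch of the two schemes as I understand them. All dimensions, the random weights, and the single shared `W_K` are placeholders for illustration, not DFlash's actual parameters:

```python
import numpy as np

d = 16           # stand-in for d_target; real models are much larger
n_draft = 5      # number of draft layers = number of extracted target layers

rng = np.random.default_rng(0)
# Hidden states extracted from 5 target layers for one token: shape (5, d).
target_hiddens = rng.standard_normal((n_draft, d))

# --- Fused injection (current scheme): concat to 5*d, project down to d,
# --- then inject the SAME fused vector into every draft layer's KV cache.
W_proj = rng.standard_normal((n_draft * d, d)) / np.sqrt(n_draft * d)
fused = target_hiddens.reshape(-1) @ W_proj           # shape (d,)
fused_inputs = [fused] * n_draft                      # identical per draft layer

# --- Proposed per-layer injection: no projection at all; draft layer i
# --- receives the hidden state of its mapped target layer directly.
perlayer_inputs = [target_hiddens[i] for i in range(n_draft)]

# Either way, each draft layer then applies its own W_K / W_V.
W_K = rng.standard_normal((d, d))
kv_fused    = [x @ W_K for x in fused_inputs]
kv_perlayer = [x @ W_K for x in perlayer_inputs]
```

The point of the sketch: the fused path spends `5*d*d` projection parameters and collapses five vectors into one, while the per-layer path keeps the five vectors distinct and needs no extra parameters.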

Curious whether this was ablated, and if so, whether there were specific reasons it underperformed.
