Description
Hi @jianc99, great work!
Currently, DFlash extracts hidden states from 5 target layers, concatenates them (5 × d_target), projects down to d_target via a learned projection, and injects the same fused vector into every draft layer's KV cache.
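To make sure I'm describing the current design correctly, here is a minimal NumPy sketch of how I understand the fusion step. It is not taken from the DFlash code: the function name `fuse_target_states`, the toy shapes, and the random weights are all my own assumptions for illustration.

```python
import numpy as np

def fuse_target_states(tapped_states, w_proj):
    """Concatenate the tapped target hidden states and project back to d_target.

    tapped_states: list of k arrays, each (seq_len, d_target)
    w_proj:        (k * d_target, d_target) learned projection (random here)
    """
    concat = np.concatenate(tapped_states, axis=-1)  # (seq_len, k * d_target)
    return concat @ w_proj                           # (seq_len, d_target)

rng = np.random.default_rng(0)
seq_len, d_target, num_taps, num_draft_layers = 4, 8, 5, 5  # toy sizes (assumed)

tapped = [rng.standard_normal((seq_len, d_target)) for _ in range(num_taps)]
w_proj = rng.standard_normal((num_taps * d_target, d_target))

fused = fuse_target_states(tapped, w_proj)

# The same fused vector feeds every draft layer's K/V computation.
draft_inputs = [fused for _ in range(num_draft_layers)]
```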
Have you considered a simpler alternative? Instead of fusing, map the hidden state of each of the same 5 target layers directly to the corresponding draft layer. E.g., if the target has 40 layers and the draft has 5 layers, then
Draft layer 1 ← Target layer 1 hidden state
Draft layer 2 ← Target layer 8 hidden state
...and so on.
Since the draft model shares d_target with the target model, no projection is needed at all: the hidden states can go directly through each draft layer's W_K and W_V. This avoids the lossy 5 × d_target → d_target compression and removes the extra projection layer entirely. It still doesn't restrict information access across target layers, since the draft layers are stacked and information from earlier injections propagates through the residual stream.
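A rough sketch of the proposal, again in NumPy and again hypothetical: `direct_kv_injection`, the per-layer weight lists, and the toy sizes are my own names and assumptions, not DFlash APIs. Each draft layer gets exactly one tapped target state, which goes straight through that layer's own W_K and W_V.

```python
import numpy as np

def direct_kv_injection(tapped_states, w_k, w_v):
    """Pair draft layer i with tapped target state i; no fusion projection.

    tapped_states[i]: (seq_len, d_target) hidden state for draft layer i
    w_k[i], w_v[i]:   (d_target, d_kv) projections of draft layer i
    Returns one (K, V) pair per draft layer.
    """
    if not (len(tapped_states) == len(w_k) == len(w_v)):
        raise ValueError("need one tapped state and one W_K/W_V per draft layer")
    return [(h @ wk, h @ wv) for h, wk, wv in zip(tapped_states, w_k, w_v)]

rng = np.random.default_rng(0)
seq_len, d_target, d_kv, num_draft_layers = 4, 8, 8, 5  # toy sizes (assumed)

tapped = [rng.standard_normal((seq_len, d_target)) for _ in range(num_draft_layers)]
w_k = [rng.standard_normal((d_target, d_kv)) for _ in range(num_draft_layers)]
w_v = [rng.standard_normal((d_target, d_kv)) for _ in range(num_draft_layers)]

kv_per_layer = direct_kv_injection(tapped, w_k, w_v)

# Parameters in the (5 * d_target) -> d_target fusion projection that this variant drops.
saved_params = num_draft_layers * d_target * d_target
```

With these toy sizes the dropped projection is 5 × 8 × 8 weights; at real model widths the saving scales as 5 × d_target².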
Curious whether this was ablated and, if so, whether there were specific reasons it underperformed.