L519: Repeating single chunk embedding T times correct? #5
https://github.com/simulanics/REFRAG/blob/main/refrag.py#L519
- In the reconstruction loss computation linked above, to match the input and output sequence lengths, you repeat the (1, D) chunk embedding T times to get (T, D), and you iterate over the K chunks one at a time in a for-loop.
Is this the correct approach? With it, we feed the same embedding at every decoding position and expect the decoder to produce a different token each time. Perhaps the authors intended a prefix-tuning-style setup instead, something like:
(say chunk size = 2)
Input embeds -> [chunk_1_embed] [chunk_2_embed] [s1_embed] [s2_embed] [s3_embed] [s4_embed]
Labels -> [-100] [-100] [s1] [s2] [s3] [s4]
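A minimal sketch of the prefix-style batch described above, in PyTorch. All names here (`chunk_embeds`, `token_embeds`, `token_ids`) are illustrative and not taken from refrag.py:

```python
# Hypothetical sketch: feed all chunk embeddings as a prefix and mask them
# out of the loss with -100, supervising only the original token positions.
import torch

D = 8                      # embedding dim (illustrative)
K, chunk_size = 2, 2       # two chunks of two tokens each
T = K * chunk_size         # number of tokens to reconstruct

chunk_embeds = torch.randn(K, D)          # (K, D) compressed chunk embeddings
token_embeds = torch.randn(T, D)          # (T, D) embeddings of the original tokens
token_ids = torch.randint(0, 1000, (T,))  # original token ids (the labels)

# Decoder input: all K chunk embeddings as a prefix, then the token embeddings.
inputs_embeds = torch.cat([chunk_embeds, token_embeds], dim=0)  # (K + T, D)

# Labels: -100 on the prefix positions so they are ignored by the loss,
# real token ids on the positions to be reconstructed.
labels = torch.cat([torch.full((K,), -100), token_ids])         # (K + T,)
```

This way the decoder sees every chunk embedding once, as context, rather than the same embedding copied T times.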
- The authors also mention that they start with 1-2 chunks initially and then pass L chunks at the same time to decode the sequence back (see the curriculum-learning paragraph on page 4). Since in phase-2 CPT the model will receive L chunks sequentially, shouldn't the reconstruction task account for this instead of feeding one chunk at a time?
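A minimal sketch of such a curriculum schedule, purely illustrative (the paper does not specify the exact ramp; a linear one is assumed here):

```python
# Hypothetical curriculum: grow the number of chunks decoded jointly
# from 1 up to L over the course of training.
def chunks_for_step(step: int, total_steps: int, L: int) -> int:
    """Linearly ramp the chunk count from 1 at step 0 to L at the last step."""
    frac = min(step / max(total_steps - 1, 1), 1.0)
    return max(1, round(1 + frac * (L - 1)))

# Example: over 10 training steps, ramp toward L = 8 chunks per example.
schedule = [chunks_for_step(s, total_steps=10, L=8) for s in range(10)]
```

The reconstruction batch at each step would then cover `chunks_for_step(...)` chunks at once, instead of always one.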
Were you able to reproduce the paper's results with this approach?