Looks like there may be a small bug in the generation:
ml-4m/fourm/models/generate.py, line 138 in 2db0125:

```python
eos_idx = torch.where(mod_dict[domain]['tensor'] == eos_id)[1][0].item()
```
The input masks for text are determined by the eos position of the first batch element only, but are then applied to all batch elements. Is this intentional? Since the examples mostly use single-batch generation, this may have fallen through the cracks. If it is intentional I'd be curious about the reasoning behind it; otherwise I'm happy to open a PR. A small sketch of what I mean is below.
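For illustration, here is a minimal standalone sketch of the difference, assuming `mod_dict[domain]['tensor']` has shape `[batch_size, seq_len]` and that every sample contains at least one eos token. The per-sample variant is just one possible alternative, not a proposal for how the final fix should look:

```python
import torch

# Toy stand-in for mod_dict[domain]['tensor']: two samples with eos (id=3)
# at different positions.
tensor = torch.tensor([
    [5, 7, 3, 2, 0, 0],   # eos at position 2
    [5, 7, 9, 8, 3, 0],   # eos at position 4
])
eos_id = 3

# Current behavior: the first match across the whole batch is taken,
# so the eos position of sample 0 is applied to every sample.
eos_idx = torch.where(tensor == eos_id)[1][0].item()   # -> 2

# Per-sample alternative: argmax over the boolean mask returns the index
# of the first eos in each row (assumes each row actually contains an eos).
eos_idx_per_sample = (tensor == eos_id).int().argmax(dim=1)   # -> tensor([2, 4])
```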
Great stuff btw, thanks for open sourcing this!