
Input masks for generation - Potential small bug. #20

@nilsec

Description

Looks like there may be a small bug in the generation:

eos_idx = torch.where(mod_dict[domain]['tensor'] == eos_id)[1][0].item()

The input mask for text is determined by the position of the eos token in the first batch element only, but it is then applied to all batch elements. Is this intentional? It looks like this code path is mostly exercised with single-sample generation (as in the examples), so the multi-batch case may have fallen through the cracks. If it is intentional, I'd be curious about the reasoning; otherwise I'm happy to open a PR.
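A minimal sketch of the issue, using hypothetical tensors of assumed shape [batch, seq_len]: `torch.where` flattens matches across the whole batch, so indexing `[1][0]` returns the eos column of the first sample only. Computing the index per row (e.g. via `argmax` over a boolean mask) would give each sample its own eos position.

```python
import torch

eos_id = 2  # hypothetical eos token id
tensor = torch.tensor([
    [5, 7, 2, 0, 0],   # eos at position 2
    [5, 7, 9, 4, 2],   # eos at position 4
])

# Current behaviour: [1] selects the column indices of all matches,
# [0] takes the first one, i.e. the eos position of batch element 0.
eos_idx = torch.where(tensor == eos_id)[1][0].item()
# eos_idx == 2, which is wrong for the second sample

# Per-sample alternative: argmax over the boolean mask finds the
# first eos position in each row independently.
eos_idx_per_sample = (tensor == eos_id).int().argmax(dim=1)
# eos_idx_per_sample.tolist() == [2, 4]
```

This is only an illustration of the symptom, not a drop-in patch for the repo.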

Great stuff, by the way — thanks for open-sourcing this!
