printing the device before/after forward pass #15

@ankur6ue

Thanks for the awesome work. I see you print the device of the input ids before and after generating completions. I'm assuming that's because `DataParallel` replicates the model across GPUs and automatically moves the input tensors to the correct GPU when the inputs and a model replica are on different devices. Is that right?
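
To make the question concrete, here is a minimal sketch of the behavior I have in mind. It is not taken from this repo: `DeviceEcho` and `run_demo` are hypothetical names, and the toy linear layer stands in for the real model. It assumes PyTorch is installed and falls back to the CPU when no GPU is available, in which case `DataParallel` simply calls the wrapped module directly.

```python
import torch
import torch.nn as nn


class DeviceEcho(nn.Module):
    """Hypothetical toy module that reports where its input chunk lives."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        # Under DataParallel, each replica receives one chunk of the batch,
        # already copied to that replica's device.
        print(f"inside forward: chunk of {x.size(0)} on {x.device}")
        return self.linear(x)


def run_demo(batch_size=8):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The wrapped module's parameters must be on the first device.
    model = nn.DataParallel(DeviceEcho().to(device))

    inputs = torch.randn(batch_size, 4)  # deliberately left on the CPU
    print(f"before forward: inputs on {inputs.device}")

    outputs = model(inputs)  # batch is split along dim 0 across devices
    print(f"after forward: outputs on {outputs.device}")
    return inputs, outputs


if __name__ == "__main__":
    run_demo()
```

If my understanding is right, on a multi-GPU machine the print inside `forward` should fire once per GPU with a different device each time, while the original `inputs` tensor stays where it was created and the gathered outputs land on the first GPU.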

Also, it would be good to see a `DistributedDataParallel` implementation of this code, since that could scale beyond a single node.
