Thanks for the awesome work. I see you print the device of the input ids before and after completion generation. I'm assuming that's because DataParallel replicates the model across GPUs and automatically scatters the input tensors to each replica's device, so there can be a mismatch between where the inputs start and where the model replica lives?
Also, it would be good to see a DistributedDataParallel implementation of this code, as that could scale beyond a single node.
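For reference, here's a minimal sketch of what a DDP version might look like. The model, data, and hyperparameters are placeholders (not this repo's actual code), and in a real multi-node run the rank/world-size and rendezvous env vars would come from `torchrun` rather than being set in the script:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank: int, world_size: int) -> float:
    # With torchrun these are set by the launcher; hardcoded here for a
    # single-process demo.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # "gloo" works on CPU; in practice you'd use "nccl" with one process
    # per GPU and pass device_ids=[local_rank] to DDP.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = nn.Linear(10, 1)  # stand-in for the actual model
    ddp_model = DDP(model)    # gradients are all-reduced across ranks
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    x = torch.randn(8, 10)  # toy batch; each rank sees its own shard
    y = torch.randn(8, 1)
    loss = nn.functional.mse_loss(ddp_model(x), y)
    loss.backward()          # DDP synchronizes gradients here
    optimizer.step()

    dist.destroy_process_group()
    return loss.item()

if __name__ == "__main__":
    print(run(rank=0, world_size=1))
```

Unlike DataParallel's single-process scatter/gather, each DDP rank is its own process holding a full model copy, so inputs never need to be moved between devices implicitly; that's also what lets it span multiple nodes.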