Hi, line 320 of /runtime/communication.py should pass [target_tensor_name, i, self.training_tensor_dtypes[target_tensor_name], False]. Otherwise, if one of the targets used to calculate the loss is not of type int64, there will be an error.
The original code snippet:
self.start_helper_thread(self.recv_helper_thread_args, recv_helper_thread, [target_tensor_name, i, torch.int64, False], num_iterations_for_backward_threads)
My suggestion is:
self.start_helper_thread(self.recv_helper_thread_args, recv_helper_thread, [target_tensor_name, i, self.training_tensor_dtypes[target_tensor_name], False], num_iterations_for_backward_threads)
With this change, the dtype used for each target tensor is consistent with the data type specified in the main.py file, instead of always being int64.
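
To illustrate the difference between the two calls, here is a minimal sketch. It is not the actual runtime code: the tensor names, the contents of training_tensor_dtypes, and the helper functions are hypothetical, and dtypes are written as plain strings so the example runs without PyTorch installed. It only shows that a hard-coded int64 disagrees with the mapping as soon as one target has another dtype.

```python
# Hypothetical stand-in for the dtype mapping specified in main.py.
# Plain strings are used in place of real torch dtypes (assumption for
# illustration only), so this sketch does not need PyTorch.
training_tensor_dtypes = {
    "input0": "torch.float32",
    "target": "torch.int64",
    "target_length": "torch.int32",  # hypothetical non-int64 target
}


def recv_args_hardcoded(target_tensor_name, index):
    """Argument list as in the original call: dtype is always int64."""
    return [target_tensor_name, index, "torch.int64", False]


def recv_args_from_mapping(target_tensor_name, index, dtypes):
    """Argument list as in the suggested call: dtype is looked up per tensor."""
    return [target_tensor_name, index, dtypes[target_tensor_name], False]


def find_dtype_mismatches(target_tensor_names, dtypes):
    """Return the targets whose configured dtype differs from hard-coded int64."""
    mismatches = []
    for i, name in enumerate(target_tensor_names):
        hardcoded_dtype = recv_args_hardcoded(name, i)[2]
        if hardcoded_dtype != dtypes[name]:
            mismatches.append(name)
    return mismatches


targets = ["target", "target_length"]
mismatches = find_dtype_mismatches(targets, training_tensor_dtypes)
suggested = recv_args_from_mapping("target_length", 1, training_tensor_dtypes)
print(mismatches)
print(suggested)
```

In this sketch only the hypothetical target_length is reported as a mismatch, which is the situation where the original line would fail, while the lookup-based argument list carries the configured dtype.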