diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index ddef4be14..27e5550c0 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -2367,6 +2367,11 @@ def __init__( # Set EmbeddingLocation.HOST to make embedding op in FBGEMM choose CPU path. # But the tensor will still be created on MTIA with device type "mtia". managed.append(EmbeddingLocation.HOST) + elif device is not None and device.type == torch._C._get_privateuse1_backend_name(): + compute_devices.append(ComputeDevice.PRIVATEUSE1) + managed.append( + compute_kernel_to_embedding_location(table.compute_kernel) + ) else: compute_devices.append(ComputeDevice.CPU) managed.append(EmbeddingLocation.HOST) @@ -2480,6 +2485,7 @@ def __init__( device is None or device.type == "cpu" or (not (torch.cuda.is_available() or torch.mtia.is_available())) + or (not (torch.get_device_module(device) and torch.get_device_module(device).is_available())) ) self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = ( DenseTableBatchedEmbeddingBagsCodegen( @@ -3259,6 +3265,11 @@ def __init__( # Set EmbeddingLocation.HOST to make embedding op in FBGEMM choose CPU path. # But the tensor will still be created on MTIA with device type "mtia". managed.append(EmbeddingLocation.HOST) + elif device is not None and device.type == torch._C._get_privateuse1_backend_name(): + compute_devices.append(ComputeDevice.PRIVATEUSE1) + managed.append( + compute_kernel_to_embedding_location(table.compute_kernel) + ) else: compute_devices.append(ComputeDevice.CPU) managed.append(EmbeddingLocation.HOST) @@ -3374,6 +3385,7 @@ def __init__( device is None or device.type == "cpu" or (not (torch.cuda.is_available() or torch.mtia.is_available())) + or (not (torch.get_device_module(device) and torch.get_device_module(device).is_available())) ) self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = ( DenseTableBatchedEmbeddingBagsCodegen( diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py index f9fcbaebd..e1ed6cee9 100644 --- a/torchrec/distributed/comm_ops.py +++ b/torchrec/distributed/comm_ops.py @@ -157,7 +157,8 @@ def _wait_impl(self) -> W: """ ret = self.wait_function.apply(self.pg, self, self.dummy_tensor) - if isinstance(ret, torch.Tensor) and ret.device.type == "cuda": + if isinstance(ret, torch.Tensor) and (ret.device.type == "cuda" or + ret.device.type == torch._C._get_privateuse1_backend_name()): ret.record_stream(torch.get_device_module(ret.device).current_stream()) self.req = None self.tensor = None diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index 99ed9e049..d129dcae0 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -794,7 +794,7 @@ def __init__( # BUG: device will default to cuda if cpu specified self._device_type: str = ( device.type - if device is not None and device.type in {"meta", "cuda", "mtia"} + if device is not None and device.type in {"meta", "cuda", "mtia", torch._C._get_privateuse1_backend_name()} else "cuda" ) assert self._world_size == len(splits)