diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh
index a68e0671a3b..b052b754fb1 100755
--- a/test/neuron/run_tests.sh
+++ b/test/neuron/run_tests.sh
@@ -257,6 +257,7 @@ function run_xla_op_tests3 {
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
diff --git a/test/run_tests.sh b/test/run_tests.sh
index bb03d7abe16..10715855aa6 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -238,6 +238,7 @@ function run_xla_op_tests3 {
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
diff --git a/test/spmd/test_xla_dtensor_to_local.py b/test/spmd/test_xla_dtensor_to_local.py
new file mode 100644
index 00000000000..2335720a027
--- /dev/null
+++ b/test/spmd/test_xla_dtensor_to_local.py
@@ -0,0 +1,115 @@
+import sys
+import unittest
+import torch
+import numpy as np
+
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed._tensor import DTensor
+from torch.distributed.tensor.placement_types import Replicate, Shard
+import torch_xla
+import torch_xla.runtime as xr
+import torch_xla.core.xla_model as xm
+from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor
+import test_xla_sharding_base
+
+
+class DTensorXLAToLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
+  """
+  Test suite for XLAShardedTensor.to_local(), which recovers a regular
+  tensor from an XLAShardedTensor created on an XLA device mesh.
+ """ + + def test_to_local(self): + from torch.distributed.tensor import distribute_tensor + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(world_size))) + + big_tensor = torch.randn(100000, 88) + sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)]) + + local_tensor = sharded_tensor.to_local() + + # Verify the shapes are the same + self.assertEqual(local_tensor.shape, big_tensor.shape) + + # Check the value of the tensor + torch.testing.assert_close(local_tensor, big_tensor, check_device=False) + + def test_to_local_requires_grad(self): + """Test that gradients flow correctly through to_local().""" + # Create a tensor with requires_grad=True + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(world_size))) + + tensor = torch.randn(100_000, 88, requires_grad=True) + + # Create XLAShardedTensor + sharded_tensor = XLAShardedTensor( + tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) + + # Verify requires_grad is set + self.assertTrue(sharded_tensor.requires_grad) + + res = sharded_tensor.sum() + res.backward() + + # Verify grad are calculated + self.assertTrue(sharded_tensor.grad is not None) + + # Call to local function + local_tensor = sharded_tensor.to_local() + + # Verify requires_grad is preserved + self.assertTrue(local_tensor.requires_grad) + + # All gradients should be 1.0 since we did a sum() + self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor))) + + def test_to_local_grad_independence(self): + """Test that gradients are independent between original and local tensor.""" + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(world_size))) + + tensor = torch.randn(100_000, 88, requires_grad=True) + sharded_tensor = XLAShardedTensor( + tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) + + # Create gradients + res = sharded_tensor.sum() + res.backward() + + # Get local tensor + local_tensor = sharded_tensor.to_local() + + # Verify gradients are initially the same + self.assertTrue(torch.allclose(local_tensor.grad, sharded_tensor.grad)) + + # Modify local tensor's gradient + local_tensor.grad[0, 0] = 999.0 + + # Verify gradients are now independent (not the same object) + self.assertFalse(local_tensor.grad is sharded_tensor.grad) + self.assertFalse(torch.allclose(local_tensor.grad, sharded_tensor.grad)) + + def test_to_local_grad_none_handling(self): + """Test that to_local() handles None gradients correctly.""" + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(world_size))) + + tensor = torch.randn(100_000, 88, requires_grad=True) + sharded_tensor = XLAShardedTensor( + tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) + + # Don't do backward pass, so grad remains None + self.assertIsNone(sharded_tensor.grad) + + # Get local tensor + local_tensor = sharded_tensor.to_local() + + # Verify local tensor has correct properties + self.assertTrue(local_tensor.requires_grad) + self.assertIsNone(local_tensor.grad) + + +if __name__ == "__main__": + result = unittest.main(exit=False) + sys.exit(0 if result.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index e1ad7c0023a..0712e131a81 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -63,6 +63,7 @@ run_test "$_TEST_DIR/spmd/test_fsdp_v2.py" run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test 
"$_TEST_DIR/spmd/test_dtensor_redistribute.py" +run_test "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py" run_test "$_TEST_DIR/spmd/test_xla_sharded_tensor.py" run_test "$_TEST_DIR/test_gradient_accumulation.py" XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 652a2011cbd..1355111eeb6 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -151,6 +151,34 @@ def load_local_shards_(self, shards: List[XLAShard]): # Invalidate cached spec since the global_tensor data has changed self.invalidate_spec_cache() + def to_local(self): + """ + Returns the local representation of the XLAShardedTensor. + + This method returns the global tensor representation, which contains + the combined data across all devices. The returned tensor is on the + same device as the original XLAShardedTensor. The returned tensor + will have the same requires_grad value as the XLAShardedTensor. + If the original tensor has gradients, those will be preserved. + + Returns: + torch.Tensor: The global tensor representation with appropriate requires_grad setting. + """ + + if not self.requires_grad: + # When requires_grad is False, global_tensor is the original tensor + return self.global_tensor + else: + # When requires_grad is True, global_tensor is detached + # Create a new tensor with the same values of global_tensor + result = self.global_tensor.clone() + # Since global tensor is detached, add requires_grad and grad values back to the local tensor + if self.requires_grad: + result.requires_grad_(self.requires_grad) + result.grad = self.grad.clone() if self.grad is not None else None + + return result + @property def sharding_spec(self): return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor)