@@ -46,8 +46,10 @@ class XLAShardingTest : public AtenXlaTensorTestBase {
 
 TEST_F(XLAShardingTest, GetShardShape) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1},
       {2, 3},
@@ -70,8 +72,10 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
   std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"};
 
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1},
       {2, 3},
@@ -126,11 +130,13 @@ TEST_F(XLAShardingTest, ShardTensor) {
 
   // 1D tiled
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::OpSharding sharding =
       xla::HloSharding::Tile1D(
-          CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
+          CreateComputationShapeFromTensor(tensor, default_device),
           devices.size())
           .ToProto();
   auto sharding_spec =
@@ -144,8 +150,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // 2D tiled, The first dim is halved and the last replicated. The last shard
   // size should be smaller in dim=1 because it's not evenly divisible.
   tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1, 2, 3},
       {4, 5, 6, 7},
@@ -181,8 +186,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // last shard size should be smaller in dim=2 because it's not evenly
   // divisible.
   tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
   sharding = xla::HloSharding::Tile(tesseract).ToProto();
   sharding_spec =
@@ -204,8 +208,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // last shard size should be smaller in dim=2 because it's not evenly
   // divisible.
   tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
   sharding = xla::HloSharding::Tile(hypercube).ToProto();
@@ -230,8 +233,10 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
 
   // 2D tiled, The first dim is halved and the last replicated.
   at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {4, 5, 0, 1},
       {6, 7, 2, 3},
@@ -265,8 +270,10 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
   std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
   at::Tensor minibatch_tensor =
       at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
-  xla::Shape global_shape = CreateComputationShapeFromTensor(
-      minibatch_tensor, bridge::GetDefaultDevice());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  xla::Shape global_shape =
+      CreateComputationShapeFromTensor(minibatch_tensor, default_device);
   global_shape.set_dimensions(
       0, minibatch_tensor.sizes()[0] * 2);  // Assuming 2 hosts
   xla::Array3D<int64_t> mesh({
@@ -292,8 +299,10 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
 
 TEST_F(XLAShardingTest, EqualShardingSpecs) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
                                        {0, 1, 2, 3},
                                        {4, 5, 6, 7},
@@ -319,12 +328,13 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
 
   std::vector<at::Tensor> tensors(3);
   auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   std::fill_n(tensors.begin(), tensors.size(), tensor);
   std::vector<std::string> devices(3);
-  std::fill_n(devices.begin(), devices.size(),
-              bridge::GetDefaultDevice()->toString());
+  std::fill_n(devices.begin(), devices.size(), default_device->toString());
   std::vector<XLATensor::ShardingSpecPtr> shardings = {
       nullptr,
       std::make_shared<XLATensor::ShardingSpec>(
@@ -388,10 +398,13 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
   XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation,
                       b.Build(/*remove_dynamic_dimensions=*/false));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  std::string default_device_str = default_device->toString();
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back({std::move(xla_computation),
-                       bridge::GetDefaultDevice()->toString(),
-                       {bridge::GetDefaultDevice()->toString()},
+                       default_device_str,
+                       {default_device_str},
                        &shape,
                        /*should_wrap_parameter=*/false,
                        /*is_sharded=*/true});
@@ -404,9 +417,8 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
       "add", std::move(computations[0]->move_computation()));
 
   // Prepare output sharding propagation, expect a sharded output placeholder.
-  std::vector<XLATensorPtr> tensors{
-      XLATensor::Create(client->CreateDataPlaceholder(
-          bridge::GetDefaultDevice()->toString(), std::move(shape)))};
+  std::vector<XLATensorPtr> tensors{XLATensor::Create(
+      client->CreateDataPlaceholder(default_device_str, std::move(shape)))};
   std::vector<torch::lazy::BackendDataPtr> data_placeholders;
   std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
   ShardingUtil::PrepareOutputShardingPropagation(
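For context on the pattern this diff applies at every call site: bridge::GetDefaultDevice() evidently now returns a status-wrapped const torch::lazy::BackendDevice* rather than a bare pointer, so each test unwraps it once with XLA_ASSIGN_OR_THROW and then reuses the plain default_device pointer. Below is a minimal self-contained sketch of that unwrap-once pattern; the StatusOr type, ASSIGN_OR_THROW macro, and GetDefaultDevice stub are simplified stand-ins for illustration, not the actual torch_xla definitions.

#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>

// Toy status-wrapped return type (stand-in for the real StatusOr used by
// torch_xla; only what the sketch needs).
template <typename T>
struct StatusOr {
  std::optional<T> value;
  std::string error;
  bool ok() const { return value.has_value(); }
};

// Stand-in for XLA_ASSIGN_OR_THROW: evaluate `expr`, throw if it failed,
// otherwise declare `lhs` bound to the unwrapped value. The __LINE__
// concatenation keeps the temporary's name unique per use.
#define SOR_CONCAT_INNER(a, b) a##b
#define SOR_CONCAT(a, b) SOR_CONCAT_INNER(a, b)
#define ASSIGN_OR_THROW(lhs, expr)                                  \
  auto SOR_CONCAT(sor_tmp_, __LINE__) = (expr);                     \
  if (!SOR_CONCAT(sor_tmp_, __LINE__).ok())                         \
    throw std::runtime_error(SOR_CONCAT(sor_tmp_, __LINE__).error); \
  lhs = *SOR_CONCAT(sor_tmp_, __LINE__).value

struct BackendDevice {
  std::string toString() const { return "TPU:0"; }
};

// Stand-in for bridge::GetDefaultDevice() after the API change: the device
// pointer now arrives wrapped in a status object.
StatusOr<const BackendDevice*> GetDefaultDevice() {
  static BackendDevice device;
  return {&device, ""};
}

int main() {
  // Unwrap once, then pass the bare pointer to every subsequent call site,
  // as the diff does with `default_device`.
  ASSIGN_OR_THROW(const BackendDevice* default_device, GetDefaultDevice());
  std::cout << default_device->toString() << "\n";
  return 0;
}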