diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp
index d79bcba70fa..01dbca6ecb0 100644
--- a/test/cpp/cpp_test_util.cpp
+++ b/test/cpp/cpp_test_util.cpp
@@ -154,7 +154,8 @@ bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) {
 void ForEachDevice(
     absl::Span<const DeviceType> device_types,
     const std::function<void(const torch::lazy::BackendDevice&)>& devfn) {
-  const torch::lazy::BackendDevice* default_device = bridge::GetDefaultDevice();
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   if (device_types.empty() ||
       std::find_if(device_types.begin(), device_types.end(),
                    [&](const DeviceType device_type) {
@@ -169,7 +170,8 @@ void ForEachDevice(
 
 void ForEachDevice(absl::Span<const DeviceType> device_types,
                    const std::function<void(const std::string&)>& devfn) {
-  const torch::lazy::BackendDevice* default_device = bridge::GetDefaultDevice();
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   if (device_types.empty() ||
       std::find_if(device_types.begin(), device_types.end(),
                    [&](const DeviceType device_type) {
diff --git a/test/cpp/test_aten_xla_tensor_5.cpp b/test/cpp/test_aten_xla_tensor_5.cpp
index 7a9383b4d77..788b78bf978 100644
--- a/test/cpp/test_aten_xla_tensor_5.cpp
+++ b/test/cpp/test_aten_xla_tensor_5.cpp
@@ -1422,8 +1422,9 @@ TEST_F(AtenXlaTensorTest, TestAvgPool3DNoBatch) {
 }
 
 TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2D) {
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  XlaDeviceType hw_type = static_cast<XlaDeviceType>(default_device->type());
   // skip this test until the tile mismatch bug is fixed.
   if (hw_type == XlaDeviceType::TPU) {
     return;
@@ -1455,8 +1456,9 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2D) {
 TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2DBackward) {
   GTEST_SKIP() << "failing due to PyTorch upstream changes. "
                << "See: https://github.com/pytorch/xla/issues/9651.";
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  XlaDeviceType hw_type = static_cast<XlaDeviceType>(default_device->type());
   // skip this test until the tile mismatch bug is fixed.
   if (hw_type == XlaDeviceType::TPU) {
     return;
diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp
index bb1a00a203e..5b0e31a0f55 100644
--- a/test/cpp/test_aten_xla_tensor_6.cpp
+++ b/test/cpp/test_aten_xla_tensor_6.cpp
@@ -943,8 +943,9 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) {
 }
 
 TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) {
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  XlaDeviceType hw_type = static_cast<XlaDeviceType>(default_device->type());
   if (hw_type != XlaDeviceType::CPU) {
     return;
   }
diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index 57ab2de44e2..e8a5169b720 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -46,8 +46,10 @@ class XLAShardingTest : public AtenXlaTensorTestBase {
 
 TEST_F(XLAShardingTest, GetShardShape) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1},
       {2, 3},
@@ -70,8 +72,10 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
   std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"};
 
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1},
       {2, 3},
@@ -126,11 +130,13 @@ TEST_F(XLAShardingTest, ShardTensor) {
 
   // 1D tiled
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::OpSharding sharding =
       xla::HloSharding::Tile1D(
-          CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
+          CreateComputationShapeFromTensor(tensor, default_device),
           devices.size())
           .ToProto();
   auto sharding_spec =
@@ -144,8 +150,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // 2D tiled, The first dim is halved and the last replicated. The last shard
   // size should be smaller in dim=1 because it's not evenly divisible.
   tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1, 2, 3},
       {4, 5, 6, 7},
@@ -181,8 +186,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // last shard size should be smaller in dim=2 because it's not evenly
   // divisible.
   tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
   sharding = xla::HloSharding::Tile(tesseract).ToProto();
   sharding_spec =
@@ -204,8 +208,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // last shard size should be smaller in dim=2 because it's not evenly
   // divisible.
   tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
   sharding = xla::HloSharding::Tile(hypercube).ToProto();
@@ -230,8 +233,10 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
 
   // 2D tiled, The first dim is halved and the last replicated.
   at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {4, 5, 0, 1},
       {6, 7, 2, 3},
@@ -265,8 +270,10 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
   std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
   at::Tensor minibatch_tensor =
       at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
-  xla::Shape global_shape = CreateComputationShapeFromTensor(
-      minibatch_tensor, bridge::GetDefaultDevice());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  xla::Shape global_shape =
+      CreateComputationShapeFromTensor(minibatch_tensor, default_device);
   global_shape.set_dimensions(
       0, minibatch_tensor.sizes()[0] * 2);  // Assuming 2 hosts
   xla::Array3D<int64_t> mesh({
@@ -292,8 +299,10 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
 
 TEST_F(XLAShardingTest, EqualShardingSpecs) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
                                        {0, 1, 2, 3},
                                        {4, 5, 6, 7},
@@ -319,12 +328,13 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
 
   std::vector<at::Tensor> tensors(3);
   auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   std::fill_n(tensors.begin(), tensors.size(), tensor);
   std::vector<std::string> devices(3);
-  std::fill_n(devices.begin(), devices.size(),
-              bridge::GetDefaultDevice()->toString());
+  std::fill_n(devices.begin(), devices.size(), default_device->toString());
   std::vector<XLATensor::ShardingSpecPtr> shardings = {
       nullptr,
       std::make_shared<XLATensor::ShardingSpec>(
@@ -388,10 +398,13 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
   XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation,
                       b.Build(/*remove_dynamic_dimensions=*/false));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  std::string default_device_str = default_device->toString();
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back({std::move(xla_computation),
-                       bridge::GetDefaultDevice()->toString(),
-                       {bridge::GetDefaultDevice()->toString()},
+                       default_device_str,
+                       {default_device_str},
                        &shape,
                        /*should_wrap_parameter=*/false,
                        /*is_sharded=*/true});
@@ -404,9 +417,8 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
       "add", std::move(computations[0]->move_computation()));
 
   // Prepare output sharding propagation, expect a sharded output placeholder.
-  std::vector<XLATensorPtr> tensors{
-      XLATensor::Create(client->CreateDataPlaceholder(
-          bridge::GetDefaultDevice()->toString(), std::move(shape)))};
+  std::vector<XLATensorPtr> tensors{XLATensor::Create(
+      client->CreateDataPlaceholder(default_device_str, std::move(shape)))};
   std::vector<torch::lazy::BackendDataPtr> data_placeholders;
   std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
   ShardingUtil::PrepareOutputShardingPropagation(
diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp
index 9fd9634fd76..b7b038d0bb7 100644
--- a/torch_xla/csrc/aten_xla_bridge.cpp
+++ b/torch_xla/csrc/aten_xla_bridge.cpp
@@ -419,14 +419,8 @@ InitializeDefaultBackendDevice() {
   return new torch::lazy::BackendDevice(device);
 }
 
-const torch::lazy::BackendDevice* absl_nonnull GetDefaultDevice() {
-  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* absl_nonnull device,
-                      SafeGetDefaultDevice());
-  return device;
-}
-
 const absl::StatusOr<torch::lazy::BackendDevice* absl_nonnull>&
-SafeGetDefaultDevice() {
+GetDefaultDevice() {
   static absl::StatusOr<torch::lazy::BackendDevice* absl_nonnull>&
       default_backend_device =
           *new absl::StatusOr<torch::lazy::BackendDevice* absl_nonnull>(
@@ -435,12 +429,16 @@ SafeGetDefaultDevice() {
 }
 
 c10::Device AtenDefaultDevice() {
-  return XlaDeviceToAtenDevice(*GetDefaultDevice());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  return XlaDeviceToAtenDevice(*default_device);
 }
 
 torch::lazy::BackendDevice GetCurrentDevice() {
   if (!g_current_device) {
-    g_current_device = *GetDefaultDevice();
+    XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                        bridge::GetDefaultDevice());
+    g_current_device = *default_device;
   }
   return *g_current_device;
 }
diff --git a/torch_xla/csrc/aten_xla_bridge.h b/torch_xla/csrc/aten_xla_bridge.h
index f22a68ca49b..7ac95fafaf6 100644
--- a/torch_xla/csrc/aten_xla_bridge.h
+++ b/torch_xla/csrc/aten_xla_bridge.h
@@ -30,11 +30,12 @@ namespace bridge {
 // A StatusOr type fulfills only (2), so we can't use it there. In order
 // to do so, we have to change upstream accordingly.
 //
-ABSL_DEPRECATED(
+[[deprecated(
     "Use GetXlaTensor(), instead. "
     "This function returns an null-initialized `XLATensorPtr`, instead of "
-    "propagating errors with StatusOr values.")
-XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor);
+    "propagating errors with StatusOr values.")]]  //
+XLATensorPtr
+TryGetXlaTensor(const at::Tensor& tensor);
 
 // Retrieves the underlying `XLATensorPtr` from `tensor`.
 //
@@ -144,16 +145,11 @@ c10::Device XlaDeviceToAtenDevice(const torch::lazy::BackendDevice& device);
 
 std::string ToXlaString(const c10::Device& device);
 
-[[deprecated(
-    "Use SafeGetDefaultDevice for better error handling.")]] const torch::lazy::
-    BackendDevice* absl_nonnull
-    GetDefaultDevice();
-
 // Returns the default `BackendDevice`.
 // This function returns an error if the `ComputationClient` wasn't correctly
 // initialized.
 const absl::StatusOr<torch::lazy::BackendDevice* absl_nonnull>&
-SafeGetDefaultDevice();
+GetDefaultDevice();
 
 c10::Device AtenDefaultDevice();
 
diff --git a/torch_xla/csrc/ir_builder.h b/torch_xla/csrc/ir_builder.h
index cb338ea2888..bfea30f5678 100644
--- a/torch_xla/csrc/ir_builder.h
+++ b/torch_xla/csrc/ir_builder.h
@@ -27,8 +27,10 @@ struct XLAIrBuilder : torch::lazy::IrBuilder {
 
   torch::lazy::NodePtr MakeScalar(const at::Scalar& value,
                                   const at::ScalarType& type) const override {
+    XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                        bridge::GetDefaultDevice());
     return torch_xla::MakeNode<Scalar>(
-        value, MakeXlaPrimitiveType(type, bridge::GetDefaultDevice()));
+        value, MakeXlaPrimitiveType(type, default_device));
   }
   torch::lazy::NodePtr MakeExpand(const torch::lazy::Value& input0,
                                   const std::vector<int64_t>& size,
diff --git a/torch_xla/csrc/ops/dynamic_ir.cpp b/torch_xla/csrc/ops/dynamic_ir.cpp
index 25f332d7a18..389f052a87e 100644
--- a/torch_xla/csrc/ops/dynamic_ir.cpp
+++ b/torch_xla/csrc/ops/dynamic_ir.cpp
@@ -50,8 +50,10 @@ int64_t SizeNode::getDynamicValue() const {
   // Wrap the IR of SizeNode into a dummy tensor and execute/fetch the value
   // of this tensor. GetTensors will return a cpu at::Tensor so we can just
   // extract the value of it.
-  std::vector<XLATensorPtr> dummy_size_tensors = {XLATensor::Create(
-      cloned, *bridge::GetDefaultDevice(), at::ScalarType::Long)};
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  std::vector<XLATensorPtr> dummy_size_tensors = {
+      XLATensor::Create(cloned, *default_device, at::ScalarType::Long)};
   std::vector<at::Tensor> res =
       XLAGraphExecutor::Get()->GetTensors(&dummy_size_tensors);
   runtime_size_ = res[0].item().toInt();
diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp
index 7d7acb735ef..4a94a7091e6 100644
--- a/torch_xla/csrc/xla_backend_impl.cpp
+++ b/torch_xla/csrc/xla_backend_impl.cpp
@@ -42,8 +42,9 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
     if (!default_device_ordinal_inited_) {
       // bridge::GetDefaultDevice will trigger the runtime device init, should
       // not do it during class init time.
-      torch::lazy::BackendDevice default_device = *bridge::GetDefaultDevice();
-      default_device_ordinal_ = default_device.ordinal();
+      XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                          bridge::GetDefaultDevice());
+      default_device_ordinal_ = default_device->ordinal();
       default_device_ordinal_inited_ = true;
     }
     return true;
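
Call-site pattern after this change, for reference: bridge::GetDefaultDevice() now returns const absl::StatusOr<torch::lazy::BackendDevice* absl_nonnull>& instead of a bare pointer (the old SafeGetDefaultDevice() spelling is removed), so every caller unwraps the result with XLA_ASSIGN_OR_THROW. A minimal sketch; the helper function below is hypothetical and the include path for XLA_ASSIGN_OR_THROW is assumed, not taken from this diff:

#include <string>

#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/status.h"  // assumed home of XLA_ASSIGN_OR_THROW

// Hypothetical caller: unwraps the StatusOr reference returned by
// GetDefaultDevice(), throwing if the ComputationClient failed to initialize.
std::string DefaultDeviceString() {
  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
                      torch_xla::bridge::GetDefaultDevice());
  return default_device->toString();
}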