Commit 8767c73

Replace 'GetDefaultDevice()' with its safer version, and remove 'Safe' from its name.

1 parent 1815bb9

9 files changed: +69, -53 lines
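Every call site below follows the same mechanical migration: the raw-pointer `GetDefaultDevice()` is gone, and callers unwrap the `StatusOr`-returning accessor (formerly `SafeGetDefaultDevice()`) instead. A minimal before/after sketch; `UseDefaultDevice` is a hypothetical call site, and the `status.h` include path for `XLA_ASSIGN_OR_THROW` is an assumption, not taken from this commit:

#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/status.h"  // assumed home of XLA_ASSIGN_OR_THROW

// Hypothetical call site, assumed to live inside namespace torch_xla.
void UseDefaultDevice() {
  // Before: a raw pointer that hid ComputationClient initialization failures.
  //   const torch::lazy::BackendDevice* device = bridge::GetDefaultDevice();

  // After: the accessor returns a StatusOr reference; the macro binds the
  // contained pointer on success or throws the contained error.
  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* device,
                      bridge::GetDefaultDevice());
  (void)device;  // ... use the device ...
}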

test/cpp/cpp_test_util.cpp

Lines changed: 4 additions & 2 deletions
@@ -154,7 +154,8 @@ bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) {
 void ForEachDevice(
     absl::Span<const DeviceType> device_types,
     const std::function<void(const torch::lazy::BackendDevice&)>& devfn) {
-  const torch::lazy::BackendDevice* default_device = bridge::GetDefaultDevice();
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   if (device_types.empty() ||
       std::find_if(device_types.begin(), device_types.end(),
                    [&](const DeviceType device_type) {
@@ -169,7 +170,8 @@ void ForEachDevice(

 void ForEachDevice(absl::Span<const DeviceType> device_types,
                    const std::function<void(const torch::Device&)>& devfn) {
-  const torch::lazy::BackendDevice* default_device = bridge::GetDefaultDevice();
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   if (device_types.empty() ||
       std::find_if(device_types.begin(), device_types.end(),
                    [&](const DeviceType device_type) {

test/cpp/test_aten_xla_tensor_5.cpp

Lines changed: 6 additions & 4 deletions
@@ -1422,8 +1422,9 @@ TEST_F(AtenXlaTensorTest, TestAvgPool3DNoBatch) {
 }

 TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2D) {
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  XlaDeviceType hw_type = static_cast<XlaDeviceType>(default_device->type());
   // skip this test until the tile mismatch bug is fixed.
   if (hw_type == XlaDeviceType::TPU) {
     return;
@@ -1455,8 +1456,9 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2D) {
 TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2DBackward) {
   GTEST_SKIP() << "failing due to PyTorch upstream changes. "
               << "See: https://github.com/pytorch/xla/issues/9651.";
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  XlaDeviceType hw_type = static_cast<XlaDeviceType>(default_device->type());
   // skip this test until the tile mismatch bug is fixed.
   if (hw_type == XlaDeviceType::TPU) {
     return;

test/cpp/test_aten_xla_tensor_6.cpp

Lines changed: 3 additions & 2 deletions
@@ -943,8 +943,9 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) {
 }

 TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) {
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  XlaDeviceType hw_type = static_cast<XlaDeviceType>(default_device->type());
   if (hw_type != XlaDeviceType::CPU) {
     return;
   }

test/cpp/test_xla_sharding.cpp

Lines changed: 34 additions & 22 deletions
@@ -46,8 +46,10 @@ class XLAShardingTest : public AtenXlaTensorTestBase {

 TEST_F(XLAShardingTest, GetShardShape) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1},
       {2, 3},
@@ -70,8 +72,10 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
   std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"};

   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1},
       {2, 3},
@@ -126,11 +130,13 @@ TEST_F(XLAShardingTest, ShardTensor) {

   // 1D tiled
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::OpSharding sharding =
       xla::HloSharding::Tile1D(
-          CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
+          CreateComputationShapeFromTensor(tensor, default_device),
           devices.size())
           .ToProto();
   auto sharding_spec =
@@ -144,8 +150,7 @@
   // 2D tiled, The first dim is halved and the last replicated. The last shard
   // size should be smaller in dim=1 because it's not evenly divisible.
   tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1, 2, 3},
       {4, 5, 6, 7},
@@ -181,8 +186,7 @@
   // last shard size should be smaller in dim=2 because it's not evenly
   // divisible.
   tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
   sharding = xla::HloSharding::Tile(tesseract).ToProto();
   sharding_spec =
@@ -204,8 +208,7 @@
   // last shard size should be smaller in dim=2 because it's not evenly
   // divisible.
   tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
   sharding = xla::HloSharding::Tile(hypercube).ToProto();
@@ -230,8 +233,10 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {

   // 2D tiled, The first dim is halved and the last replicated.
   at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {4, 5, 0, 1},
       {6, 7, 2, 3},
@@ -265,8 +270,10 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
   std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
   at::Tensor minibatch_tensor =
       at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
-  xla::Shape global_shape = CreateComputationShapeFromTensor(
-      minibatch_tensor, bridge::GetDefaultDevice());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  xla::Shape global_shape =
+      CreateComputationShapeFromTensor(minibatch_tensor, default_device);
   global_shape.set_dimensions(
       0, minibatch_tensor.sizes()[0] * 2);  // Assuming 2 hosts
   xla::Array3D<int64_t> mesh({
@@ -292,8 +299,10 @@

 TEST_F(XLAShardingTest, EqualShardingSpecs) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
                                        {0, 1, 2, 3},
                                        {4, 5, 6, 7},
@@ -319,12 +328,13 @@ TEST_F(XLAShardingTest, CreateTensorsData) {

   std::vector<at::Tensor> tensors(3);
   auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   std::fill_n(tensors.begin(), tensors.size(), tensor);
   std::vector<std::string> devices(3);
-  std::fill_n(devices.begin(), devices.size(),
-              bridge::GetDefaultDevice()->toString());
+  std::fill_n(devices.begin(), devices.size(), default_device->toString());
   std::vector<XLATensor::ShardingSpecPtr> shardings = {
       nullptr,
       std::make_shared<XLATensor::ShardingSpec>(
@@ -388,10 +398,13 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
   XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation,
                       b.Build(/*remove_dynamic_dimensions=*/false));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  std::string default_device_str = default_device->toString();
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back({std::move(xla_computation),
-                       bridge::GetDefaultDevice()->toString(),
-                       {bridge::GetDefaultDevice()->toString()},
+                       default_device_str,
+                       {default_device_str},
                        &shape,
                        /*should_wrap_parameter=*/false,
                        /*is_sharded=*/true});
@@ -404,9 +417,8 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
       "add", std::move(computations[0]->move_computation()));

   // Prepare output sharding propagation, expect a sharded output placeholder.
-  std::vector<XLATensorPtr> tensors{
-      XLATensor::Create(client->CreateDataPlaceholder(
-          bridge::GetDefaultDevice()->toString(), std::move(shape)))};
+  std::vector<XLATensorPtr> tensors{XLATensor::Create(
+      client->CreateDataPlaceholder(default_device_str, std::move(shape)))};
   std::vector<torch::lazy::BackendDataPtr> data_placeholders;
   std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
   ShardingUtil::PrepareOutputShardingPropagation(

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 7 additions & 9 deletions
@@ -419,14 +419,8 @@ InitializeDefaultBackendDevice() {
   return new torch::lazy::BackendDevice(device);
 }

-const torch::lazy::BackendDevice* absl_nonnull GetDefaultDevice() {
-  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* absl_nonnull device,
-                      SafeGetDefaultDevice());
-  return device;
-}
-
 const absl::StatusOr<torch::lazy::BackendDevice * absl_nonnull>&
-SafeGetDefaultDevice() {
+GetDefaultDevice() {
   static absl::StatusOr<torch::lazy::BackendDevice* absl_nonnull>&
       default_backend_device =
           *new absl::StatusOr<torch::lazy::BackendDevice * absl_nonnull>(
@@ -435,12 +429,16 @@ SafeGetDefaultDevice() {
 }

 c10::Device AtenDefaultDevice() {
-  return XlaDeviceToAtenDevice(*GetDefaultDevice());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  return XlaDeviceToAtenDevice(*default_device);
 }

 torch::lazy::BackendDevice GetCurrentDevice() {
   if (!g_current_device) {
-    g_current_device = *GetDefaultDevice();
+    XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                        bridge::GetDefaultDevice());
+    g_current_device = *default_device;
   }
   return *g_current_device;
 }
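The renamed accessor keeps the one-time initialization idiom visible above: the `StatusOr` is heap-allocated on first use and intentionally leaked, so the device lookup runs exactly once and no static destructor touches it at shutdown. A standalone sketch of that idiom, with an illustrative `int*` payload standing in for the commit's `BackendDevice*`:

#include "absl/status/statusor.h"

// Illustrative stand-in for InitializeDefaultBackendDevice(): any fallible
// initializer works here.
absl::StatusOr<int*> Initialize() { return new int(42); }

const absl::StatusOr<int*>& GetCached() {
  // Initialized exactly once on first call (thread-safe since C++11) and
  // deliberately leaked: no destructor runs during static teardown, and all
  // callers share the same cached success-or-error result.
  static const absl::StatusOr<int*>& cached =
      *new absl::StatusOr<int*>(Initialize());
  return cached;
}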

torch_xla/csrc/aten_xla_bridge.h

Lines changed: 5 additions & 9 deletions
@@ -30,11 +30,12 @@ namespace bridge {
 // A StatusOr type fulfills only (2), so we can't use it there. In order
 // to do so, we have to change upstream accordingly.
 //
-ABSL_DEPRECATED(
+[[deprecated(
     "Use GetXlaTensor(), instead. "
     "This function returns an null-initialized `XLATensorPtr`, instead of "
-    "propagating errors with StatusOr values.")
-XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor);
+    "propagating errors with StatusOr values.")]]  //
+XLATensorPtr
+TryGetXlaTensor(const at::Tensor& tensor);

 // Retrieves the underlying `XLATensorPtr` from `tensor`.
 //
@@ -144,16 +145,11 @@ c10::Device XlaDeviceToAtenDevice(const torch::lazy::BackendDevice& device);

 std::string ToXlaString(const c10::Device& device);

-[[deprecated(
-    "Use SafeGetDefaultDevice for better error handling.")]] const torch::lazy::
-    BackendDevice* absl_nonnull
-    GetDefaultDevice();
-
 // Returns the default `BackendDevice`.
 // This function returns an error if the `ComputationClient` wasn't correctly
 // initialized.
 const absl::StatusOr<torch::lazy::BackendDevice * absl_nonnull>&
-SafeGetDefaultDevice();
+GetDefaultDevice();

 c10::Device AtenDefaultDevice();
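With the old raw-pointer overload deleted, every caller now sees the `const absl::StatusOr<...>&` signature. Callers that prefer not to throw can inspect the status explicitly instead of using `XLA_ASSIGN_OR_THROW`; a sketch, where `ReportDefaultDevice` is a hypothetical function and not part of this commit:

#include <iostream>

#include "torch_xla/csrc/aten_xla_bridge.h"

// Hypothetical non-throwing caller, assumed to live inside namespace torch_xla.
bool ReportDefaultDevice() {
  const auto& device_or = bridge::GetDefaultDevice();
  if (!device_or.ok()) {
    // The ComputationClient initialization error is preserved in the status.
    std::cerr << device_or.status() << "\n";
    return false;
  }
  std::cout << "Default device: " << (*device_or)->toString() << "\n";
  return true;
}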

torch_xla/csrc/ir_builder.h

Lines changed: 3 additions & 1 deletion
@@ -27,8 +27,10 @@ struct XLAIrBuilder : torch::lazy::IrBuilder {

   torch::lazy::NodePtr MakeScalar(const at::Scalar& value,
                                   const at::ScalarType& type) const override {
+    XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                        bridge::GetDefaultDevice());
     return torch_xla::MakeNode<Scalar>(
-        value, MakeXlaPrimitiveType(type, bridge::GetDefaultDevice()));
+        value, MakeXlaPrimitiveType(type, default_device));
   }
   torch::lazy::NodePtr MakeExpand(const torch::lazy::Value& input0,
                                   const std::vector<int64_t>& size,

torch_xla/csrc/ops/dynamic_ir.cpp

Lines changed: 4 additions & 2 deletions
@@ -50,8 +50,10 @@ int64_t SizeNode::getDynamicValue() const {
   // Wrap the IR of SizeNode into a dummy tensor and execute/fetch the value
   // of this tensor. GetTensors will return a cpu at::Tensor so we can just
   // extract the value of it.
-  std::vector<XLATensorPtr> dummy_size_tensors = {XLATensor::Create(
-      cloned, *bridge::GetDefaultDevice(), at::ScalarType::Long)};
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  std::vector<XLATensorPtr> dummy_size_tensors = {
+      XLATensor::Create(cloned, *default_device, at::ScalarType::Long)};
   std::vector<at::Tensor> res =
       XLAGraphExecutor::Get()->GetTensors(&dummy_size_tensors);
   runtime_size_ = res[0].item().toInt();

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 3 additions & 2 deletions
@@ -42,8 +42,9 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
   if (!default_device_ordinal_inited_) {
     // bridge::GetDefaultDevice will trigger the runtime device init, should
     // not do it during class init time.
-    torch::lazy::BackendDevice default_device = *bridge::GetDefaultDevice();
-    default_device_ordinal_ = default_device.ordinal();
+    XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                        bridge::GetDefaultDevice());
+    default_device_ordinal_ = default_device->ordinal();
     default_device_ordinal_inited_ = true;
   }
   return true;
