6 changes: 4 additions & 2 deletions test/cpp/cpp_test_util.cpp
@@ -154,7 +154,8 @@ bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) {
 void ForEachDevice(
     absl::Span<const DeviceType> device_types,
     const std::function<void(const torch::lazy::BackendDevice&)>& devfn) {
-  const torch::lazy::BackendDevice* default_device = bridge::GetDefaultDevice();
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   if (device_types.empty() ||
       std::find_if(device_types.begin(), device_types.end(),
                    [&](const DeviceType device_type) {
@@ -169,7 +170,8 @@ void ForEachDevice(

 void ForEachDevice(absl::Span<const DeviceType> device_types,
                    const std::function<void(const torch::Device&)>& devfn) {
-  const torch::lazy::BackendDevice* default_device = bridge::GetDefaultDevice();
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   if (device_types.empty() ||
       std::find_if(device_types.begin(), device_types.end(),
                    [&](const DeviceType device_type) {
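Note: every call site in this diff follows the pattern above, unwrapping the StatusOr now returned by bridge::GetDefaultDevice(). For readers unfamiliar with the macro, the sketch below illustrates the behavior the new call sites rely on. It is a simplified stand-in, not the project's actual XLA_ASSIGN_OR_THROW definition.

#include <stdexcept>
#include <utility>

#include "absl/status/statusor.h"

// Illustration only: a simplified stand-in for XLA_ASSIGN_OR_THROW.
// Evaluates a StatusOr-returning expression, throws if it holds an error,
// and otherwise binds the contained value to the declaration on the left.
// (A production macro would generate a collision-free temporary name.)
#define ASSIGN_OR_THROW_SKETCH(lhs, rexpr)                       \
  auto status_or_tmp = (rexpr);                                  \
  if (!status_or_tmp.ok()) {                                     \
    throw std::runtime_error(status_or_tmp.status().ToString()); \
  }                                                              \
  lhs = std::move(status_or_tmp).value()

// Usage mirroring the call sites in this diff:
//   ASSIGN_OR_THROW_SKETCH(const torch::lazy::BackendDevice* default_device,
//                          bridge::GetDefaultDevice());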
10 changes: 6 additions & 4 deletions test/cpp/test_aten_xla_tensor_5.cpp
@@ -1422,8 +1422,9 @@ TEST_F(AtenXlaTensorTest, TestAvgPool3DNoBatch) {
 }

 TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2D) {
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  XlaDeviceType hw_type = static_cast<XlaDeviceType>(default_device->type());
   // skip this test until the tile mismatch bug is fixed.
   if (hw_type == XlaDeviceType::TPU) {
     return;
@@ -1455,8 +1456,9 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2D) {
 TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2DBackward) {
   GTEST_SKIP() << "failing due to PyTorch upstream changes. "
                << "See: https://github.com/pytorch/xla/issues/9651.";
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  XlaDeviceType hw_type = static_cast<XlaDeviceType>(default_device->type());
   // skip this test until the tile mismatch bug is fixed.
   if (hw_type == XlaDeviceType::TPU) {
     return;
5 changes: 3 additions & 2 deletions test/cpp/test_aten_xla_tensor_6.cpp
@@ -943,8 +943,9 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) {
 }

 TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) {
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  XlaDeviceType hw_type = static_cast<XlaDeviceType>(default_device->type());
   if (hw_type != XlaDeviceType::CPU) {
     return;
   }
56 changes: 34 additions & 22 deletions test/cpp/test_xla_sharding.cpp
@@ -46,8 +46,10 @@ class XLAShardingTest : public AtenXlaTensorTestBase {

 TEST_F(XLAShardingTest, GetShardShape) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1},
       {2, 3},
@@ -70,8 +72,10 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
   std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"};

   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1},
       {2, 3},
@@ -126,11 +130,13 @@ TEST_F(XLAShardingTest, ShardTensor) {

   // 1D tiled
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::OpSharding sharding =
       xla::HloSharding::Tile1D(
-          CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
+          CreateComputationShapeFromTensor(tensor, default_device),
           devices.size())
           .ToProto();
   auto sharding_spec =
@@ -144,8 +150,7 @@
   // 2D tiled, The first dim is halved and the last replicated. The last shard
   // size should be smaller in dim=1 because it's not evenly divisible.
   tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {0, 1, 2, 3},
       {4, 5, 6, 7},
@@ -181,8 +186,7 @@
   // last shard size should be smaller in dim=2 because it's not evenly
   // divisible.
   tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
   sharding = xla::HloSharding::Tile(tesseract).ToProto();
   sharding_spec =
@@ -204,8 +208,7 @@
   // last shard size should be smaller in dim=2 because it's not evenly
   // divisible.
   tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat));
-  tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  tensor_shape = CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
   sharding = xla::HloSharding::Tile(hypercube).ToProto();
@@ -230,8 +233,10 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {

   // 2D tiled, The first dim is halved and the last replicated.
   at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   xla::Array2D<int64_t> mesh({
       {4, 5, 0, 1},
       {6, 7, 2, 3},
@@ -265,8 +270,10 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
   std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
   at::Tensor minibatch_tensor =
       at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
-  xla::Shape global_shape = CreateComputationShapeFromTensor(
-      minibatch_tensor, bridge::GetDefaultDevice());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  xla::Shape global_shape =
+      CreateComputationShapeFromTensor(minibatch_tensor, default_device);
   global_shape.set_dimensions(
       0, minibatch_tensor.sizes()[0] * 2);  // Assuming 2 hosts
   xla::Array3D<int64_t> mesh({
@@ -292,8 +299,10 @@

 TEST_F(XLAShardingTest, EqualShardingSpecs) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
                                        {0, 1, 2, 3},
                                        {4, 5, 6, 7},
@@ -319,12 +328,13 @@ TEST_F(XLAShardingTest, CreateTensorsData) {

   std::vector<at::Tensor> tensors(3);
   auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
   xla::Shape tensor_shape =
-      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+      CreateComputationShapeFromTensor(tensor, default_device);
   std::fill_n(tensors.begin(), tensors.size(), tensor);
   std::vector<std::string> devices(3);
-  std::fill_n(devices.begin(), devices.size(),
-              bridge::GetDefaultDevice()->toString());
+  std::fill_n(devices.begin(), devices.size(), default_device->toString());
   std::vector<XLATensor::ShardingSpecPtr> shardings = {
       nullptr,
       std::make_shared<XLATensor::ShardingSpec>(
@@ -388,10 +398,13 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
   XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation,
                       b.Build(/*remove_dynamic_dimensions=*/false));
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  std::string default_device_str = default_device->toString();
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back({std::move(xla_computation),
-                       bridge::GetDefaultDevice()->toString(),
-                       {bridge::GetDefaultDevice()->toString()},
+                       default_device_str,
+                       {default_device_str},
                        &shape,
                        /*should_wrap_parameter=*/false,
                        /*is_sharded=*/true});
@@ -404,9 +417,8 @@
       "add", std::move(computations[0]->move_computation()));

   // Prepare output sharding propagation, expect a sharded output placeholder.
-  std::vector<XLATensorPtr> tensors{
-      XLATensor::Create(client->CreateDataPlaceholder(
-          bridge::GetDefaultDevice()->toString(), std::move(shape)))};
+  std::vector<XLATensorPtr> tensors{XLATensor::Create(
+      client->CreateDataPlaceholder(default_device_str, std::move(shape)))};
   std::vector<torch::lazy::BackendDataPtr> data_placeholders;
   std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
   ShardingUtil::PrepareOutputShardingPropagation(
16 changes: 7 additions & 9 deletions torch_xla/csrc/aten_xla_bridge.cpp
@@ -419,14 +419,8 @@ InitializeDefaultBackendDevice() {
   return new torch::lazy::BackendDevice(device);
 }

-const torch::lazy::BackendDevice* absl_nonnull GetDefaultDevice() {
-  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* absl_nonnull device,
-                      SafeGetDefaultDevice());
-  return device;
-}
-
 const absl::StatusOr<torch::lazy::BackendDevice * absl_nonnull>&
-SafeGetDefaultDevice() {
+GetDefaultDevice() {
   static absl::StatusOr<torch::lazy::BackendDevice* absl_nonnull>&
       default_backend_device =
           *new absl::StatusOr<torch::lazy::BackendDevice * absl_nonnull>(
@@ -435,12 +429,16 @@ SafeGetDefaultDevice() {
 }

 c10::Device AtenDefaultDevice() {
-  return XlaDeviceToAtenDevice(*GetDefaultDevice());
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  return XlaDeviceToAtenDevice(*default_device);
 }

 torch::lazy::BackendDevice GetCurrentDevice() {
   if (!g_current_device) {
-    g_current_device = *GetDefaultDevice();
+    XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                        bridge::GetDefaultDevice());
+    g_current_device = *default_device;
   }
   return *g_current_device;
 }
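The cached result in GetDefaultDevice() above uses a deliberately leaked function-local static: the StatusOr is computed once on first use and never destroyed, which sidesteps static-destruction-order problems at process shutdown. A minimal self-contained sketch of the same idiom, with illustrative names:

#include <string>

#include "absl/status/statusor.h"

// Illustrative only. The StatusOr is heap-allocated and intentionally leaked:
// the initializer runs exactly once (thread-safe since C++11), and skipping
// destruction avoids use-after-destruction during process shutdown.
const absl::StatusOr<std::string>& GetCachedConfig() {
  static const absl::StatusOr<std::string>& cached =
      *new absl::StatusOr<std::string>(std::string("computed-once value"));
  return cached;
}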
14 changes: 5 additions & 9 deletions torch_xla/csrc/aten_xla_bridge.h
@@ -30,11 +30,12 @@ namespace bridge {
 // A StatusOr type fulfills only (2), so we can't use it there. In order
 // to do so, we have to change upstream accordingly.
 //
-ABSL_DEPRECATED(
+[[deprecated(
     "Use GetXlaTensor(), instead. "
     "This function returns an null-initialized `XLATensorPtr`, instead of "
-    "propagating errors with StatusOr values.")
-XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor);
+    "propagating errors with StatusOr values.")]]  //
+XLATensorPtr
+TryGetXlaTensor(const at::Tensor& tensor);

 // Retrieves the underlying `XLATensorPtr` from `tensor`.
 //
@@ -144,16 +145,11 @@ c10::Device XlaDeviceToAtenDevice(const torch::lazy::BackendDevice& device);

 std::string ToXlaString(const c10::Device& device);

-[[deprecated(
-    "Use SafeGetDefaultDevice for better error handling.")]] const torch::lazy::
-    BackendDevice* absl_nonnull
-    GetDefaultDevice();
-
 // Returns the default `BackendDevice`.
 // This function returns an error if the `ComputationClient` wasn't correctly
 // initialized.
 const absl::StatusOr<torch::lazy::BackendDevice * absl_nonnull>&
-SafeGetDefaultDevice();
+GetDefaultDevice();

 c10::Device AtenDefaultDevice();

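Since GetDefaultDevice() now exposes failures through the StatusOr declared above, callers that prefer not to throw can branch on the status instead of using XLA_ASSIGN_OR_THROW. A caller-side sketch; the include path, namespace qualification, and the choice of iostream for reporting are assumptions, not part of this PR:

#include <iostream>

#include "torch_xla/csrc/aten_xla_bridge.h"

// Sketch: explicit, non-throwing handling of the new return type.
bool PrintDefaultDevice() {
  const absl::StatusOr<torch::lazy::BackendDevice* absl_nonnull>& device =
      torch_xla::bridge::GetDefaultDevice();
  if (!device.ok()) {
    // The ComputationClient wasn't correctly initialized.
    std::cerr << device.status() << "\n";
    return false;
  }
  std::cout << (*device)->toString() << "\n";
  return true;
}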
4 changes: 3 additions & 1 deletion torch_xla/csrc/ir_builder.h
@@ -27,8 +27,10 @@ struct XLAIrBuilder : torch::lazy::IrBuilder {

   torch::lazy::NodePtr MakeScalar(const at::Scalar& value,
                                   const at::ScalarType& type) const override {
+    XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                        bridge::GetDefaultDevice());
     return torch_xla::MakeNode<Scalar>(
-        value, MakeXlaPrimitiveType(type, bridge::GetDefaultDevice()));
+        value, MakeXlaPrimitiveType(type, default_device));
   }
   torch::lazy::NodePtr MakeExpand(const torch::lazy::Value& input0,
                                   const std::vector<int64_t>& size,
6 changes: 4 additions & 2 deletions torch_xla/csrc/ops/dynamic_ir.cpp
@@ -50,8 +50,10 @@ int64_t SizeNode::getDynamicValue() const {
   // Wrap the IR of SizeNode into a dummy tensor and execute/fetch the value
   // of this tensor. GetTensors will return a cpu at::Tensor so we can just
   // extract the value of it.
-  std::vector<XLATensorPtr> dummy_size_tensors = {XLATensor::Create(
-      cloned, *bridge::GetDefaultDevice(), at::ScalarType::Long)};
+  XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                      bridge::GetDefaultDevice());
+  std::vector<XLATensorPtr> dummy_size_tensors = {
+      XLATensor::Create(cloned, *default_device, at::ScalarType::Long)};
   std::vector<at::Tensor> res =
       XLAGraphExecutor::Get()->GetTensors(&dummy_size_tensors);
   runtime_size_ = res[0].item().toInt();
5 changes: 3 additions & 2 deletions torch_xla/csrc/xla_backend_impl.cpp
@@ -42,8 +42,9 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
     if (!default_device_ordinal_inited_) {
       // bridge::GetDefaultDevice will trigger the runtime device init, should
       // not do it during class init time.
-      torch::lazy::BackendDevice default_device = *bridge::GetDefaultDevice();
-      default_device_ordinal_ = default_device.ordinal();
+      XLA_ASSIGN_OR_THROW(const torch::lazy::BackendDevice* default_device,
+                          bridge::GetDefaultDevice());
+      default_device_ordinal_ = default_device->ordinal();
       default_device_ordinal_inited_ = true;
     }
     return true;