Skip to content
Closed
152 changes: 113 additions & 39 deletions test/cpp/test_xla_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,34 @@ class XLAShardingTest : public AtenXlaTensorTestBase {
}
};

TEST_F(XLAShardingTest, NormalizeTileAssignment) {
  // An empty tile assignment should normalize to an empty result.
  const std::vector<int64_t> no_devices;
  EXPECT_TRUE(ShardingUtil::NormalizeTileAssignment(no_devices).empty());

  // Positive device ids: the expected output shifts every id so that the
  // smallest one becomes zero ({3,1,4,2} -> {2,0,3,1}).
  const std::vector<int64_t> positive_ids = {3, 1, 4, 2};
  auto result = ShardingUtil::NormalizeTileAssignment(positive_ids);
  EXPECT_EQ(result, (std::vector<int64_t>{2, 0, 3, 1}));

  // All-identical ids collapse to all zeros.
  const std::vector<int64_t> uniform_ids = {5, 5, 5, 5};
  result = ShardingUtil::NormalizeTileAssignment(uniform_ids);
  EXPECT_EQ(result, (std::vector<int64_t>{0, 0, 0, 0}));

  // Any negative id is rejected, whether all values are negative...
  const std::vector<int64_t> all_negative_ids = {-3, -1, -4, -2};
  EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(all_negative_ids),
               std::runtime_error);

  // ...or mixed in with positive ones.
  const std::vector<int64_t> mixed_sign_ids = {3, -1, 4, 2};
  EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(mixed_sign_ids),
               std::runtime_error);
}

TEST_F(XLAShardingTest, GetShardShape) {
auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
Expand All @@ -50,15 +78,19 @@ TEST_F(XLAShardingTest, GetShardShape) {
{0, 1},
{2, 3},
});
auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);

auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
// For tiled sharding, each dimension should be halved
EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));

sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
xla_sharding = xla::HloSharding::Replicate().ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec->sharding = sharding;
shard_shape = ShardingUtil::GetShardShape(sharding_spec);
// For replicated sharding, each dimension should be preserved
EXPECT_EQ(shard_shape, std::vector<int64_t>({8, 7}));
Expand All @@ -74,7 +106,9 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
{0, 1},
{2, 3},
});
auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
Expand Down Expand Up @@ -103,7 +137,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
EXPECT_EQ(slice.step(), 1);
}
}
sharding = xla::HloSharding::Replicate().ToProto();
xla_sharding = xla::HloSharding::Replicate().ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec->sharding = sharding;
shard_shape = ShardingUtil::GetShardShape(sharding_spec);
replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
Expand All @@ -121,16 +156,18 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
TEST_F(XLAShardingTest, ShardTensor) {
std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3",
"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};

// 1D tiled
at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
xla::OpSharding sharding =
xla::OpSharding xla_sharding =
xla::HloSharding::Tile1D(
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
devices.size())
.ToProto();
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
Expand All @@ -148,7 +185,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
{0, 1, 2, 3},
{4, 5, 6, 7},
});
sharding = xla::HloSharding::Tile(mesh).ToProto();
xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
Expand All @@ -160,15 +198,19 @@ TEST_F(XLAShardingTest, ShardTensor) {
// 3D tiled, the first dim is replicated and the last halved. The last shard
// size should be smaller in dim=1 because it's not evenly divisible.
xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
xla_sharding = xla::HloSharding::Tile(cube).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec->sharding = sharding;
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));

// Replicated, all shards should be identical.
sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
xla_sharding = xla::HloSharding::Replicate().ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec->sharding = sharding;
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 8);
Expand All @@ -182,7 +224,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
sharding = xla::HloSharding::Tile(tesseract).ToProto();
xla_sharding = xla::HloSharding::Tile(tesseract).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
Expand All @@ -206,7 +249,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
hypercube.FillIota(0);
sharding = xla::HloSharding::Tile(hypercube).ToProto();
xla_sharding = xla::HloSharding::Tile(hypercube).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
Expand Down Expand Up @@ -234,7 +278,9 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
{4, 5, 0, 1},
{6, 7, 2, 3},
});
auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
std::vector<int64_t> denormalized_tile_assignment = {4, 5, 0, 1, 6, 7, 2, 3};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
// For devices at the start of the mesh, all shards should have the same
Expand All @@ -251,7 +297,10 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
{0, 1, 4, 5},
{2, 3, 6, 7},
});
sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto();
xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
denormalized_tile_assignment = {0, 1, 4, 5, 2, 3, 6, 7};
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
sharding_spec->sharding = sharding;
shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
/*padded=*/false);
EXPECT_EQ(shards.size(), 4);
Expand All @@ -278,7 +327,9 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
{{7}},
});

auto sharding = xla::HloSharding::Tile(mesh).ToProto();
auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, global_shape, /*minibatch=*/true);
auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
Expand All @@ -292,17 +343,21 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
{0, 1, 2, 3},
{4, 5, 6, 7},
})
.ToProto(),
tensor_shape);
XLATensor::ShardingSpec tiled_3d(
xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(),
tensor_shape);
XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(),
tensor_shape);
auto xla_sharding = xla::HloSharding::Tile({
{0, 1, 2, 3},
{4, 5, 6, 7},
})
.ToProto();
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape);
xla_sharding =
xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape);
xla_sharding = xla::HloSharding::Replicate().ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
XLATensor::ShardingSpec replicated(sharding, tensor_shape);
EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated));
Expand All @@ -323,12 +378,17 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
std::vector<std::string> devices(3);
std::fill_n(devices.begin(), devices.size(),
bridge::GetDefaultDevice()->toString());
auto replicate_xla_sharding = xla::HloSharding::Replicate().ToProto();
auto unknown_xla_sharding = xla::HloSharding::Unknown().ToProto();
torch_xla::OpSharding replicate_sharding(replicate_xla_sharding,
std::nullopt);
torch_xla::OpSharding unknown_sharding(unknown_xla_sharding, std::nullopt);
std::vector<XLATensor::ShardingSpecPtr> shardings = {
nullptr,
std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Replicate().ToProto(), tensor_shape),
std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Unknown().ToProto(), tensor_shape)};
std::make_shared<XLATensor::ShardingSpec>(replicate_sharding,
tensor_shape),
std::make_shared<XLATensor::ShardingSpec>(unknown_sharding,
tensor_shape)};
std::vector<torch::lazy::BackendDataPtr> tensors_data =
CreateTensorsData(tensors, shardings, devices);

Expand Down Expand Up @@ -387,13 +447,29 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
xla::XlaComputation xla_computation =
GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false));

std::vector<XLATensorPtr> tensors{XLATensor::Create(
torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
bridge::GetDefaultDevice()->toString(), std::move(shape)))};
std::vector<std::vector<int64_t>> denormalized_tile_assignments;
for (auto tensor : tensors) {
auto sharding_spec = tensor->sharding_spec();
if (sharding_spec) {
denormalized_tile_assignments.push_back(
sharding_spec->sharding.GetDenormalizedTileAssignment());
}
}

std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
instances.push_back({std::move(xla_computation),
bridge::GetDefaultDevice()->toString(),
{bridge::GetDefaultDevice()->toString()},
&shape,
/*should_wrap_parameter=*/false,
/*is_sharded=*/true});
instances.push_back(
{std::move(xla_computation),
bridge::GetDefaultDevice()->toString(),
{bridge::GetDefaultDevice()->toString()},
&shape,
/*should_wrap_parameter=*/false,
/*is_sharded=*/true,
/*allow_spmd_sharding_propagation_to_output=*/true,
/*denormalized_tile_assignments=*/denormalized_tile_assignments});

std::vector<
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
Expand All @@ -404,9 +480,6 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
"add", std::move(computations[0]->move_computation()));

// Prepare output sharding propagation, expect a sharded output placeholder.
std::vector<XLATensorPtr> tensors{XLATensor::Create(
torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
bridge::GetDefaultDevice()->toString(), std::move(shape)))};
std::vector<torch::lazy::BackendDataPtr> data_placeholders;
std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
ShardingUtil::PrepareOutputShardingPropagation(
Expand All @@ -417,11 +490,12 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
if (n_devices > 1) {
// Tiled sharding requires multiple devices.
EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
tiled, sharding_specs[0]->sharding));
tiled, sharding_specs[0]->sharding.GetXlaOpSharding()));
} else {
// Single device execution defaults to replication sharding.
EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
xla::HloSharding::Replicate().ToProto(), sharding_specs[0]->sharding));
xla::HloSharding::Replicate().ToProto(),
sharding_specs[0]->sharding.GetXlaOpSharding()));
}

// Check if the placeholder is on a SPMD device (sharded) with no real values.
Expand Down
2 changes: 2 additions & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ function run_xla_op_tests3 {
run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py" "$@" --skip-gradient-checkpointing
run_test "$_TEST_DIR/test_gradient_accumulation.py"
run_save_tensor_hlo run_test "$_TEST_DIR/spmd/test_spmd_lowering_context.py"
run_test_multi_devices "$_TEST_DIR/spmd/test_submesh_zero_indexed.py"
run_test_multi_devices "$_TEST_DIR/spmd/test_submesh_non_zero_indexed.py"
run_test "$_TEST_DIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$_TEST_DIR/test_input_output_aliases.py"
run_test_without_functionalization "$_TEST_DIR/test_input_output_aliases.py"
Expand Down
Loading