Skip to content

Commit 1cfe785

Browse files
kvshbg-awskvshbg-aws
authored andcommitted
fix for failing ci/cd tests
1 parent 50faf0d commit 1cfe785

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,12 +2594,12 @@ void InitXlaModuleBindings(py::module m) {
25942594
return GetXLAShardingSpec(xtensor);
25952595
})
25962596
.def("_get_xla_op_sharding",
2597-
[](const at::Tensor& input) -> std::optional<xla::OpSharding> {
2597+
[](const at::Tensor& input) -> std::optional<torch_xla::OpSharding> {
25982598
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input));
25992599
XLATensor::ShardingSpecPtr sharding_spec =
26002600
xtensor ? xtensor->sharding_spec() : nullptr;
26012601
if (sharding_spec != nullptr) {
2602-
return sharding_spec->sharding.GetXlaOpSharding();
2602+
return sharding_spec->sharding;
26032603
}
26042604
return std::nullopt;
26052605
})

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,8 @@ class IfrtComputationClient : public ComputationClient {
285285
denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
286286
xla_output_shardings_ = this->executable->GetOutputShardings();
287287
if (xla_output_shardings_.has_value()) {
288-
output_shardings_->reserve(xla_output_shardings_->size());
288+
output_shardings_ = std::vector<torch_xla::OpSharding>{};
289+
output_shardings_->reserve(xla_output_shardings_.value().size());
289290
for (const auto& sharding : xla_output_shardings_.value()) {
290291
// convert each into torch_xla::OpSharding object
291292
torch_xla::OpSharding torch_xla_op_sharding(

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,8 @@ class PjRtComputationClient : public ComputationClient {
345345
denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
346346
xla_output_shardings_ = this->executable->GetOutputShardings();
347347
if (xla_output_shardings_.has_value()) {
348-
output_shardings_->reserve(xla_output_shardings_->size());
348+
output_shardings_ = std::vector<torch_xla::OpSharding>{};
349+
output_shardings_->reserve(xla_output_shardings_.value().size());
349350
for (const auto& sharding : xla_output_shardings_.value()) {
350351
// convert each into torch_xla::OpSharding object
351352
torch_xla::OpSharding torch_xla_op_sharding(

0 commit comments

Comments
 (0)