3 files changed: +6 −4 lines changed

@@ -2594,12 +2594,12 @@ void InitXlaModuleBindings(py::module m) {
        return GetXLAShardingSpec(xtensor);
      })
      .def("_get_xla_op_sharding",
-          [](const at::Tensor& input) -> std::optional<xla::OpSharding> {
+          [](const at::Tensor& input) -> std::optional<torch_xla::OpSharding> {
            XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input));
            XLATensor::ShardingSpecPtr sharding_spec =
                xtensor ? xtensor->sharding_spec() : nullptr;
            if (sharding_spec != nullptr) {
-             return sharding_spec->sharding.GetXlaOpSharding();
+             return sharding_spec->sharding;
            }
            return std::nullopt;
          })
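
For context, a minimal pybind11 sketch of the pattern this hunk relies on: a lambda returning std::optional<T> is exposed to Python as a function returning T or None, as long as pybind11/stl.h is included and T itself is bound. The Sharding type, module name, and _get_sharding function below are hypothetical stand-ins, not the actual torch_xla bindings.

#include <optional>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // converts std::optional<T> to T-or-None

namespace py = pybind11;

// Hypothetical stand-in for torch_xla::OpSharding.
struct Sharding {
  int type = 0;
};

PYBIND11_MODULE(example, m) {
  // The wrapper type must be bound so pybind11 can convert return values.
  py::class_<Sharding>(m, "Sharding")
      .def_readonly("type", &Sharding::type);

  // Same shape as the patched _get_xla_op_sharding: return the bound
  // wrapper directly, or std::nullopt (None in Python) when no sharding
  // exists, instead of first converting to a lower-level representation.
  m.def("_get_sharding", [](bool has_sharding) -> std::optional<Sharding> {
    if (has_sharding) {
      return Sharding{1};
    }
    return std::nullopt;
  });
}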
@@ -285,7 +285,8 @@ class IfrtComputationClient : public ComputationClient {
             denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
       xla_output_shardings_ = this->executable->GetOutputShardings();
       if (xla_output_shardings_.has_value()) {
-        output_shardings_->reserve(xla_output_shardings_->size());
+        output_shardings_ = std::vector<torch_xla::OpSharding>{};
+        output_shardings_->reserve(xla_output_shardings_.value().size());
         for (const auto& sharding : xla_output_shardings_.value()) {
           // convert each into torch_xla::OpSharding object
           torch_xla::OpSharding torch_xla_op_sharding(
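
The two added lines fix a std::optional pitfall: output_shardings_ is an optional vector, and calling ->reserve(...) on it while it is still disengaged dereferences an empty optional, which is undefined behavior. A standalone sketch of the before/after pattern, using a placeholder OpSharding type rather than the real client code:

#include <optional>
#include <vector>

struct OpSharding {};  // placeholder for torch_xla::OpSharding

int main() {
  std::optional<std::vector<OpSharding>> output_shardings;  // disengaged

  // Buggy pattern: operator-> on a disengaged optional is undefined
  // behavior, even though reserve() itself looks harmless.
  // output_shardings->reserve(4);

  // Fixed pattern from the diff: engage the optional with an empty
  // vector first, then reserve on the now-valid contained value.
  output_shardings = std::vector<OpSharding>{};
  output_shardings->reserve(4);
  return 0;
}

Keeping output_shardings_ optional rather than defaulting it to an empty vector presumably lets callers distinguish "never computed" from "computed but empty"; that reading is an inference from the diff, not stated in it. The PjRt hunk below applies the same fix.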
@@ -345,7 +345,8 @@ class PjRtComputationClient : public ComputationClient {
             denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
       xla_output_shardings_ = this->executable->GetOutputShardings();
       if (xla_output_shardings_.has_value()) {
-        output_shardings_->reserve(xla_output_shardings_->size());
+        output_shardings_ = std::vector<torch_xla::OpSharding>{};
+        output_shardings_->reserve(xla_output_shardings_.value().size());
         for (const auto& sharding : xla_output_shardings_.value()) {
           // convert each into torch_xla::OpSharding object
           torch_xla::OpSharding torch_xla_op_sharding(