Commit 0c0ae2d

expand: improve error handling and error messages. (#9645)
This PR refactors the `expand` operation implementation by improving its error messages and returning a status-type value.

**Key Changes:**

- Make `tensor_methods::expand` return `StatusOr<XLATensorPtr>`.
- Improve error messages and error handling.
- Add a new `CheckExpandValidRank` function that checks the rank of the input against the rank of the given sizes.
- Modify `GetExpandDimensions` to call the check above and to verify that each corresponding dimension of the input and the given sizes is valid.
1 parent 6d755ee commit 0c0ae2d
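
For context, the sketch below (not part of this commit) shows the caller-side pattern the PR adopts: `tensor_methods::expand` now returns a `StatusOr`, and call sites at the ATen boundary unwrap it with `XLA_ASSIGN_OR_THROW`. The `ExpandToAten` helper name and the omitted status-macro header are assumptions for illustration; the macros and signatures mirror the hunks below.

```cpp
// Sketch only: caller-side use of the StatusOr-returning expand.
// `ExpandToAten` is a hypothetical helper; the header that defines the
// XLA_ASSIGN_OR_THROW macro is assumed to be included elsewhere.
#include "absl/types/span.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/tensor_methods.h"

namespace torch_xla {

at::Tensor ExpandToAten(const at::Tensor& self,
                        absl::Span<const int64_t> sizes) {
  // Both GetXlaTensor and tensor_methods::expand now return status values;
  // XLA_ASSIGN_OR_THROW unwraps them or throws, which surfaces in Python as
  // a RuntimeError carrying the improved message (see the new tests).
  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
  XLA_ASSIGN_OR_THROW(XLATensorPtr output,
                      tensor_methods::expand(xla_self, sizes));
  return bridge::AtenFromXlaTensor(std::move(output));
}

}  // namespace torch_xla
```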

File tree: 7 files changed, +132 −48 lines changed


test/run_tests.sh

Lines changed: 1 addition & 0 deletions
```diff
@@ -151,6 +151,7 @@ function run_xla_op_tests1 {
   run_eager_debug "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_test "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_test "$_TEST_DIR/test_ops_error_message.py"
+  run_test "$_TEST_DIR/test_ops_error_message_functionalization_disabled.py"
   run_test "$_TEST_DIR/test_xla_graph_execution.py" "$@" --verbosity=$VERBOSITY
   run_pt_xla_debug_level2 "$_TEST_DIR/test_xla_graph_execution.py" "$@" --verbosity=$VERBOSITY
   run_test_without_functionalization "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY
```

test/test_ops_error_message_functionalization_disabled.py (new file)

Lines changed: 43 additions & 0 deletions

```python
import os

os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import expecttest
import torch
import torch_xla
import unittest


class TestOpsErrorMessageFunctionalizationDisabled(expecttest.TestCase):

  def test_expand_raises_error_on_higher_rank_tensor(self):
    device = torch_xla.device()
    a = torch.rand(1, 1, 2, 3, device=device)
    sizes = [-1, 3]

    def test():
      return a.expand(sizes)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""expand(): expected the `input` tensor f32[1,1,2,3] (rank: 4) to have a rank smaller or equal to the given `sizes` [-1, 3] (rank: 2)."""
    )

  def test_expand_raises_error_on_size_mismatch(self):
    device = torch_xla.device()
    a = torch.rand(1, 1, 2, 3, device=device)
    sizes = [1, 1, 1, 3]

    def test():
      return a.expand(sizes)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""expand(): expected dimension 2 of the given `sizes` [1, 1, 1, 3] (1) to be -1, or equal to the size of the `input` tensor f32[1,1,2,3] at dimension 2 (2)."""
    )


if __name__ == "__main__":
  unittest.main()
```

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 18 additions & 14 deletions
```diff
@@ -1786,19 +1786,21 @@ at::Tensor XLANativeFunctions::empty_strided_symint(
 }
 
 at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self,
-                                                  at::SymIntArrayRef sym_size,
+                                                  at::SymIntArrayRef sym_sizes,
                                                   bool implicit) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  std::optional<at::IntArrayRef> size = c10::asIntArrayRefSlowOpt(sym_size);
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  if (size.has_value()) {
-    return bridge::AtenFromXlaTensor(tensor_methods::expand(
-        xla_self, torch::lazy::ToVector<int64_t>(*size)));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::optional<at::IntArrayRef> sizes = c10::asIntArrayRefSlowOpt(sym_sizes);
+  if (sizes.has_value()) {
+    XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                        tensor_methods::expand(xla_self, *sizes));
+    return bridge::AtenFromXlaTensor(std::move(output));
   } else {
     // at least one of the dimension is symbolic, use the sym_int version of the
     // node
     return bridge::AtenFromXlaTensor(
-        tensor_methods::expand_symint(xla_self, sym_size));
+        tensor_methods::expand_symint(xla_self, sym_sizes));
   }
 }
 
@@ -4563,19 +4565,21 @@ at::Tensor XLANativeFunctions::diagonal(const at::Tensor& self, int64_t offset,
 }
 
 at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self,
-                                             at::SymIntArrayRef sym_size,
+                                             at::SymIntArrayRef sym_sizes,
                                              bool implicit) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  std::optional<at::IntArrayRef> size = c10::asIntArrayRefSlowOpt(sym_size);
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  if (size.has_value()) {
-    return bridge::AtenFromXlaTensor(tensor_methods::expand(
-        xla_self, torch::lazy::ToVector<int64_t>(*size)));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::optional<at::IntArrayRef> sizes = c10::asIntArrayRefSlowOpt(sym_sizes);
+  if (sizes.has_value()) {
+    XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                        tensor_methods::expand(xla_self, *sizes));
+    return bridge::AtenFromXlaTensor(std::move(output));
   } else {
     // at least one of the dimension is symbolic, use the sym_int version of the
     // node
     return bridge::AtenFromXlaTensor(
-        tensor_methods::expand_symint(xla_self, sym_size));
+        tensor_methods::expand_symint(xla_self, sym_sizes));
   }
 }
 
```

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 10 additions & 6 deletions
```diff
@@ -24,6 +24,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/base/nullability.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/log/absl_check.h"
 #include "absl/strings/str_cat.h"
@@ -451,15 +452,18 @@ at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input,
 }
 
 at::Tensor DynamicExpand(const at::Tensor& input,
-                         const std::vector<int64_t>& size,
+                         const std::vector<int64_t>& sizes,
                          const at::Tensor& src_tensor, int src_dim,
                          int target_dim) {
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_src_tensor,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_input,
+                      bridge::GetXlaTensor(input));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_src_tensor,
                       bridge::GetXlaTensor(src_tensor));
-  XLATensorPtr result = tensor_methods::dynamic_expand(
-      xla_input, size, xla_src_tensor, src_dim, target_dim);
-  return bridge::AtenFromXlaTensor(std::move(result));
+  XLA_ASSIGN_OR_THROW(
+      absl_nonnull XLATensorPtr output,
+      tensor_methods::dynamic_expand(xla_input, sizes, xla_src_tensor, src_dim,
+                                     target_dim));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor DynamicView(const at::Tensor& input,
```

torch_xla/csrc/tensor_methods.cpp

Lines changed: 52 additions & 21 deletions
```diff
@@ -10,11 +10,13 @@
 #include <functional>
 #include <iterator>
 
+#include "absl/base/nullability.h"
 #include "absl/log/absl_check.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
+#include "absl/types/span.h"
 #include "torch_xla/csrc/LazyIr.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/data_ops.h"
@@ -234,16 +236,45 @@ void CheckBmmDimension(const std::string& tag, const XLATensorPtr& batch1,
                        "batch2", 2);
 }
 
-std::vector<int64_t> GetExpandDimensions(const xla::Shape& shape,
-                                         std::vector<int64_t> dimensions) {
-  XLA_CHECK_GE(dimensions.size(), shape.dimensions_size()) << shape;
-  int64_t base = dimensions.size() - shape.dimensions_size();
-  for (size_t i = 0; i < shape.dimensions_size(); ++i) {
-    if (dimensions[base + i] == -1) {
-      dimensions[base + i] = shape.dimensions(i);
+absl::Status CheckExpandValidRank(const XLATensorPtr& input,
+                                  const absl::Span<const int64_t> sizes) {
+  xla::Shape shape = input->shape();
+  int64_t rank = shape.dimensions().size();
+  if (rank > sizes.size()) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "expand(): expected the `input` tensor ", shape.ToString(), " (rank: ",
+        rank, ") to have a rank smaller or equal to the given `sizes` [",
+        absl::StrJoin(sizes, /* sep= */ ", "), "] (rank: ", sizes.size(),
+        ").")));
+  }
+  return absl::OkStatus();
+}
+
+absl::StatusOr<std::vector<int64_t>> GetExpandDimensions(
+    const XLATensorPtr& input, const absl::Span<const int64_t> sizes) {
+  XLA_RETURN_IF_ERROR(CheckExpandValidRank(input, sizes));
+
+  xla::Shape shape = input->shape();
+  const int64_t rank = shape.dimensions().size();
+  const int64_t base = sizes.size() - rank;
+
+  std::vector<int64_t> expanded_dimensions(sizes.begin(), sizes.end());
+  for (size_t i = 0; i < shape.dimensions().size(); ++i) {
+    const int64_t dim = base + i;
+    const int64_t size = sizes[dim];
+    if (size == -1) {
+      expanded_dimensions[dim] = shape.dimensions(i);
+    } else if (shape.dimensions(i) != 1 && size != shape.dimensions(i)) {
+      return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+          "expand(): expected dimension ", dim, " of the given `sizes` [",
+          absl::StrJoin(sizes, /* sep= */ ", "), "] (", size,
+          ") to be -1, or equal to the size of the `input` tensor ",
+          shape.ToString(), " at dimension ", i, " (", shape.dimensions(i),
+          ").")));
     }
   }
-  return dimensions;
+
+  return expanded_dimensions;
 }
 
 // Resizes and / or checks whether a list is of the given size. The list is only
@@ -1791,11 +1822,12 @@ XLATensorPtr exp(const XLATensorPtr& input) {
   return input->CreateFrom(Exp(input->GetIrValue()));
 }
 
-XLATensorPtr expand(const XLATensorPtr& input, std::vector<int64_t> size) {
-  auto input_shape = input->shape();
-  auto output = input->CreateFrom(torch_xla::MakeNode<Expand>(
-      input->GetIrValue(),
-      GetExpandDimensions(input_shape.get(), std::move(size))));
+absl::StatusOr<absl_nonnull XLATensorPtr> expand(
+    const XLATensorPtr& input, const absl::Span<const int64_t> sizes) {
+  XLA_ASSIGN_OR_RETURN(std::vector<int64_t> expanded_dimensions,
+                       GetExpandDimensions(input, sizes));
+  auto output = input->CreateFrom(
+      torch_xla::MakeNode<Expand>(input->GetIrValue(), expanded_dimensions));
   output->SetStorage(input->Storage());
   return output;
 }
@@ -2927,15 +2959,14 @@ XLATensorPtr cast_int4(const XLATensorPtr& weight,
 // Dynamic Reshape ops here.
 //////////////////////////////////////////////////////////////////////////////
 
-XLATensorPtr dynamic_expand(const XLATensorPtr& input,
-                            const std::vector<int64_t>& size,
-                            const XLATensorPtr& src_tensor, int src_dim,
-                            int target_dim) {
-  std::vector<int64_t> expanded_size =
-      GetExpandDimensions(input->shape().get(), size);
+absl::StatusOr<absl_nonnull XLATensorPtr> dynamic_expand(
+    const XLATensorPtr& input, const absl::Span<const int64_t> sizes,
+    const XLATensorPtr& src_tensor, int src_dim, int target_dim) {
+  XLA_ASSIGN_OR_RETURN(std::vector<int64_t> expanded_dimensions,
+                       GetExpandDimensions(input, sizes));
   torch::lazy::NodePtr node = torch_xla::MakeNode<DynamicExpand>(
-      input->GetIrValue(), expanded_size, src_tensor->GetIrValue(), src_dim,
-      target_dim);
+      input->GetIrValue(), expanded_dimensions, src_tensor->GetIrValue(),
+      src_dim, target_dim);
   return input->CreateFrom(torch::lazy::Value(node));
 }
```
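
To make the dimension-resolution rule in the new `GetExpandDimensions` easier to follow, here is a small self-contained sketch using plain containers instead of XLA types (`ResolveExpandSizes` is a hypothetical name, not part of the commit): trailing entries of `sizes` are matched against the input shape, `-1` keeps the input size, and any other value must equal the input size unless that input dimension is 1.

```cpp
// Standalone illustration of the expand size-resolution rule implemented by
// GetExpandDimensions above; returns std::nullopt where the real code
// returns an InvalidArgumentError.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

std::optional<std::vector<int64_t>> ResolveExpandSizes(
    const std::vector<int64_t>& shape, std::vector<int64_t> sizes) {
  if (shape.size() > sizes.size()) return std::nullopt;  // rank error
  const size_t base = sizes.size() - shape.size();
  for (size_t i = 0; i < shape.size(); ++i) {
    int64_t& size = sizes[base + i];
    if (size == -1) {
      size = shape[i];  // -1 keeps the input size
    } else if (shape[i] != 1 && size != shape[i]) {
      return std::nullopt;  // size-mismatch error
    }
  }
  return sizes;
}

int main() {
  // f32[1,1,2,3] expanded with [5, -1, -1, 2, 3] resolves to [5, 1, 1, 2, 3].
  auto ok = ResolveExpandSizes({1, 1, 2, 3}, {5, -1, -1, 2, 3});
  // [1, 1, 1, 3] fails: dimension 2 asks for 1, but the input size there is 2
  // (the case covered by test_expand_raises_error_on_size_mismatch).
  auto bad = ResolveExpandSizes({1, 1, 2, 3}, {1, 1, 1, 3});
  std::cout << ok.has_value() << " " << bad.has_value() << "\n";  // prints: 1 0
}
```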

torch_xla/csrc/tensor_methods.h

Lines changed: 6 additions & 5 deletions
```diff
@@ -2,6 +2,7 @@
 #define XLA_TORCH_XLA_CSRC_TENSOR_METHODS_H_
 
 #include "absl/base/nullability.h"
+#include "absl/types/span.h"
 #include "torch_xla/csrc/cross_replica_reduces.h"
 #include "torch_xla/csrc/ops/custom_sharding.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
@@ -158,10 +159,9 @@ XLATensorPtr cast_int4(const XLATensorPtr& weight,
 // Dynamic Reshape ops here.
 //////////////////////////////////////////////////////////////////////////////
 
-XLATensorPtr dynamic_expand(const XLATensorPtr& input,
-                            const std::vector<int64_t>& size,
-                            const XLATensorPtr& src_tensor, int src_dim,
-                            int target_dim);
+absl::StatusOr<absl_nonnull XLATensorPtr> dynamic_expand(
+    const XLATensorPtr& input, const absl::Span<const int64_t> sizes,
+    const XLATensorPtr& src_tensor, int src_dim, int target_dim);
 
 XLATensorPtr dynamic_view(const XLATensorPtr& input,
                           const std::vector<int64_t>& size,
@@ -427,7 +427,8 @@ XLATensorPtr eq(const XLATensorPtr& input, const XLATensorPtr& other);
 
 XLATensorPtr exp(const XLATensorPtr& input);
 
-XLATensorPtr expand(const XLATensorPtr& input, std::vector<int64_t> size);
+absl::StatusOr<absl_nonnull XLATensorPtr> expand(
+    const XLATensorPtr& input, const absl::Span<const int64_t> sizes);
 
 XLATensorPtr expand_symint(const XLATensorPtr& input,
                            c10::SymIntArrayRef sym_size);
```

torch_xla/csrc/tensor_ops.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -240,9 +240,9 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
   // padding_idx.
   XLATensorPtr skip_padding = tensor_methods::unsqueeze(
       tensor_methods::ne(indices_rank1, padding_idx), 1);
-  skip_padding = tensor_methods::expand(
+  XLA_ASSIGN_OR_THROW(
       skip_padding,
-      torch::lazy::ToVector<int64_t>(grad->shape().get().dimensions()));
+      tensor_methods::expand(skip_padding, grad->shape().get().dimensions()));
   XLATensorPtr zero_grad =
       tensor_methods::full_like(grad, 0, grad->GetDevice(), grad->dtype());
   return tensor_methods::index_put(
```
