From d913d49b48edc7e2386646e358342e368208bfa8 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 9 Sep 2025 17:04:32 -0300 Subject: [PATCH 1/4] Improve error handling and error message for `trace`. --- test/test_operations.py | 12 +++++++++++ torch_xla/csrc/aten_xla_type.cpp | 7 +++++-- torch_xla/csrc/tensor_methods.cpp | 33 +++++++++++++++++-------------- torch_xla/csrc/tensor_methods.h | 2 +- 4 files changed, 36 insertions(+), 18 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index a02bb746666..3c61291f125 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2367,6 +2367,18 @@ def test_isneginf_no_fallback(self): t = t.to(torch.float16) self._test_no_fallback(torch.isneginf, (t,)) + def test_trace_raises_error_on_non_matrix_input(self): + device = torch_xla.device() + a = torch.rand(2, 2, 2, device=device) + + try: + torch.trace(a) + except RuntimeError as e: + expected_error = ( + "trace(): expected the input tensor f32[2,2,2] to be a " + "matrix (i.e. a 2D tensor).") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index c2514b41097..970e6608520 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3873,8 +3873,11 @@ std::tuple XLANativeFunctions::topk( at::Tensor XLANativeFunctions::trace(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::trace(xla_self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::trace(xla_self)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 11545f72ff4..8eb99aaa19a 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -10,11 +10,13 @@ #include #include +#include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "status.h" #include "torch_xla/csrc/LazyIr.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/data_ops.h" @@ -480,13 +482,16 @@ absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input, return absl::OkStatus(); } -absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat, - const std::string_view arg) { - xla::Shape shape = mat->shape(); +absl::Status CheckInputIsMatrix(const XLATensorPtr& tensor, + const std::string_view op, + const std::string_view arg = "") { + xla::Shape shape = tensor->shape(); if (shape.dimensions().size() != 2) { - return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( - absl::StrCat("mm(): expected the ", arg, " input tensor ", - shape.ToString(), " to be a matrix (i.e. a 2D tensor)."))); + const std::string arg_with_trailing_space = + arg.empty() ? "" : std::string(arg) + " "; + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + op, "(): expected the ", arg_with_trailing_space, "input tensor ", + shape.ToString(), " to be a matrix (i.e. a 2D tensor)."))); } return absl::OkStatus(); } @@ -2420,8 +2425,8 @@ XLATensorPtr mish(const XLATensorPtr& input) { absl::StatusOr mm(const XLATensorPtr& input, const XLATensorPtr& weight) { - XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first")); - XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second")); + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "mm", "first")); + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(weight, "mm", "second")); XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight)); return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue())); } @@ -3617,13 +3622,11 @@ std::tuple topk(const XLATensorPtr& input, return std::make_tuple(t1, t2); } -XLATensorPtr trace(const XLATensorPtr& input) { - auto input_shape_ref = input->shape(); - XLA_CHECK_EQ((*input_shape_ref).dimensions_size(), 2) - << "invalid argument for trace: expected a matrix"; - torch::lazy::NodePtr eye = Identity((*input_shape_ref).dimensions(0), - (*input_shape_ref).dimensions(1), - (*input_shape_ref).element_type()); +absl::StatusOr trace(const XLATensorPtr& input) { + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "trace")); + xla::Shape shape = input->shape(); + torch::lazy::NodePtr eye = + Identity(shape.dimensions(0), shape.dimensions(1), shape.element_type()); return sum(input->CreateFrom(eye * input->GetIrValue()), {0, 1}, false, input->dtype()); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 71a522d4d06..2ca64c50291 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -972,7 +972,7 @@ std::tuple topk(const XLATensorPtr& input, bool stable); // Returns the sum of the elements of the diagonal of the input 2-D matrix. -XLATensorPtr trace(const XLATensorPtr& input); +absl::StatusOr trace(const XLATensorPtr& input); // Swap given dimensions of the input. XLATensorPtr transpose(const XLATensorPtr& input, int64_t dim0, int64_t dim1); From 0e4c662cff284bfee3543dc205a66f07a26ae7db Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 10 Sep 2025 11:05:56 -0300 Subject: [PATCH 2/4] Fix lint. --- test/test_operations.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 3c61291f125..971220e2d3d 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2374,9 +2374,8 @@ def test_trace_raises_error_on_non_matrix_input(self): try: torch.trace(a) except RuntimeError as e: - expected_error = ( - "trace(): expected the input tensor f32[2,2,2] to be a " - "matrix (i.e. a 2D tensor).") + expected_error = ("trace(): expected the input tensor f32[2,2,2] to be a " + "matrix (i.e. a 2D tensor).") self.assertEqual(str(e), expected_error) From e013a2524e36f177cde67d1f05b4d847474550aa Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 12 Sep 2025 16:19:22 -0300 Subject: [PATCH 3/4] Move test. --- test/test_operations.py | 11 ----------- test/test_ops_error_message.py | 13 +++++++++++++ 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 971220e2d3d..a02bb746666 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2367,17 +2367,6 @@ def test_isneginf_no_fallback(self): t = t.to(torch.float16) self._test_no_fallback(torch.isneginf, (t,)) - def test_trace_raises_error_on_non_matrix_input(self): - device = torch_xla.device() - a = torch.rand(2, 2, 2, device=device) - - try: - torch.trace(a) - except RuntimeError as e: - expected_error = ("trace(): expected the input tensor f32[2,2,2] to be a " - "matrix (i.e. a 2D tensor).") - self.assertEqual(str(e), expected_error) - class MNISTComparator(nn.Module): diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index bb23810f234..42858fc84f6 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -222,6 +222,19 @@ def test(): expect="""roll(): expected `dims` [0] (size=1) to match the size of `shifts` [2, 2] (size=2).""" ) + def test_trace_raises_error_on_non_matrix_input(self): + device = torch_xla.device() + a = torch.rand(2, 2, 2, device=device) + + def test(): + torch.trace(a) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""trace(): expected the input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor).""" + ) + if __name__ == "__main__": unittest.main() From e36e378dd5b6784ecdd9d984351b4f393e4aefe1 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 17 Sep 2025 16:38:12 -0300 Subject: [PATCH 4/4] Update. --- torch_xla/csrc/tensor_methods.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 8eb99aaa19a..7740b0e06ff 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -16,7 +16,6 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" -#include "status.h" #include "torch_xla/csrc/LazyIr.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/data_ops.h" @@ -488,7 +487,7 @@ absl::Status CheckInputIsMatrix(const XLATensorPtr& tensor, xla::Shape shape = tensor->shape(); if (shape.dimensions().size() != 2) { const std::string arg_with_trailing_space = - arg.empty() ? "" : std::string(arg) + " "; + arg.empty() ? std::string("") : absl::StrCat(arg, " "); return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( op, "(): expected the ", arg_with_trailing_space, "input tensor ", shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));