diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index bb23810f234..42858fc84f6 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -222,6 +222,19 @@ def test(): expect="""roll(): expected `dims` [0] (size=1) to match the size of `shifts` [2, 2] (size=2).""" ) + def test_trace_raises_error_on_non_matrix_input(self): + device = torch_xla.device() + a = torch.rand(2, 2, 2, device=device) + + def test(): + torch.trace(a) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""trace(): expected the input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor).""" + ) + if __name__ == "__main__": unittest.main() diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index c2514b41097..970e6608520 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3873,8 +3873,11 @@ std::tuple XLANativeFunctions::topk( at::Tensor XLANativeFunctions::trace(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::trace(xla_self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::trace(xla_self)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 11545f72ff4..7740b0e06ff 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -10,6 +10,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -480,13 +481,16 @@ absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input, return absl::OkStatus(); } -absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat, - const std::string_view arg) { - xla::Shape shape = mat->shape(); +absl::Status CheckInputIsMatrix(const XLATensorPtr& tensor, + const std::string_view op, + const std::string_view arg = "") { + xla::Shape shape = tensor->shape(); if (shape.dimensions().size() != 2) { - return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( - absl::StrCat("mm(): expected the ", arg, " input tensor ", - shape.ToString(), " to be a matrix (i.e. a 2D tensor)."))); + const std::string arg_with_trailing_space = + arg.empty() ? std::string("") : absl::StrCat(arg, " "); + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + op, "(): expected the ", arg_with_trailing_space, "input tensor ", + shape.ToString(), " to be a matrix (i.e. a 2D tensor)."))); } return absl::OkStatus(); } @@ -2420,8 +2424,8 @@ XLATensorPtr mish(const XLATensorPtr& input) { absl::StatusOr mm(const XLATensorPtr& input, const XLATensorPtr& weight) { - XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first")); - XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second")); + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "mm", "first")); + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(weight, "mm", "second")); XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight)); return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue())); } @@ -3617,13 +3621,11 @@ std::tuple topk(const XLATensorPtr& input, return std::make_tuple(t1, t2); } -XLATensorPtr trace(const XLATensorPtr& input) { - auto input_shape_ref = input->shape(); - XLA_CHECK_EQ((*input_shape_ref).dimensions_size(), 2) - << "invalid argument for trace: expected a matrix"; - torch::lazy::NodePtr eye = Identity((*input_shape_ref).dimensions(0), - (*input_shape_ref).dimensions(1), - (*input_shape_ref).element_type()); +absl::StatusOr trace(const XLATensorPtr& input) { + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "trace")); + xla::Shape shape = input->shape(); + torch::lazy::NodePtr eye = + Identity(shape.dimensions(0), shape.dimensions(1), shape.element_type()); return sum(input->CreateFrom(eye * input->GetIrValue()), {0, 1}, false, input->dtype()); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 71a522d4d06..2ca64c50291 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -972,7 +972,7 @@ std::tuple topk(const XLATensorPtr& input, bool stable); // Returns the sum of the elements of the diagonal of the input 2-D matrix. -XLATensorPtr trace(const XLATensorPtr& input); +absl::StatusOr trace(const XLATensorPtr& input); // Swap given dimensions of the input. XLATensorPtr transpose(const XLATensorPtr& input, int64_t dim0, int64_t dim1);