Skip to content

Commit d913d49

Browse files
committed
Improve error handling and error message for trace.
1 parent 6d755ee commit d913d49

File tree

4 files changed

+36
-18
lines changed

4 files changed

+36
-18
lines changed

test/test_operations.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2367,6 +2367,18 @@ def test_isneginf_no_fallback(self):
23672367
t = t.to(torch.float16)
23682368
self._test_no_fallback(torch.isneginf, (t,))
23692369

2370+
def test_trace_raises_error_on_non_matrix_input(self):
2371+
device = torch_xla.device()
2372+
a = torch.rand(2, 2, 2, device=device)
2373+
2374+
try:
2375+
torch.trace(a)
2376+
except RuntimeError as e:
2377+
expected_error = (
2378+
"trace(): expected the input tensor f32[2,2,2] to be a "
2379+
"matrix (i.e. a 2D tensor).")
2380+
self.assertEqual(str(e), expected_error)
2381+
23702382

23712383
class MNISTComparator(nn.Module):
23722384

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3873,8 +3873,11 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::topk(
38733873

38743874
at::Tensor XLANativeFunctions::trace(const at::Tensor& self) {
38753875
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
3876-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
3877-
return bridge::AtenFromXlaTensor(tensor_methods::trace(xla_self));
3876+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
3877+
bridge::GetXlaTensor(self));
3878+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
3879+
tensor_methods::trace(xla_self));
3880+
return bridge::AtenFromXlaTensor(std::move(output));
38783881
}
38793882

38803883
at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
#include <functional>
1111
#include <iterator>
1212

13+
#include "absl/base/nullability.h"
1314
#include "absl/log/absl_check.h"
1415
#include "absl/status/status.h"
1516
#include "absl/strings/str_cat.h"
1617
#include "absl/strings/str_join.h"
1718
#include "absl/strings/str_split.h"
19+
#include "status.h"
1820
#include "torch_xla/csrc/LazyIr.h"
1921
#include "torch_xla/csrc/aten_xla_bridge.h"
2022
#include "torch_xla/csrc/data_ops.h"
@@ -480,13 +482,16 @@ absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input,
480482
return absl::OkStatus();
481483
}
482484

483-
absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat,
484-
const std::string_view arg) {
485-
xla::Shape shape = mat->shape();
485+
absl::Status CheckInputIsMatrix(const XLATensorPtr& tensor,
486+
const std::string_view op,
487+
const std::string_view arg = "") {
488+
xla::Shape shape = tensor->shape();
486489
if (shape.dimensions().size() != 2) {
487-
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
488-
absl::StrCat("mm(): expected the ", arg, " input tensor ",
489-
shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));
490+
const std::string arg_with_trailing_space =
491+
arg.empty() ? "" : std::string(arg) + " ";
492+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
493+
op, "(): expected the ", arg_with_trailing_space, "input tensor ",
494+
shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));
490495
}
491496
return absl::OkStatus();
492497
}
@@ -2420,8 +2425,8 @@ XLATensorPtr mish(const XLATensorPtr& input) {
24202425

24212426
absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
24222427
const XLATensorPtr& weight) {
2423-
XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first"));
2424-
XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second"));
2428+
XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "mm", "first"));
2429+
XLA_RETURN_IF_ERROR(CheckInputIsMatrix(weight, "mm", "second"));
24252430
XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight));
24262431
return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue()));
24272432
}
@@ -3617,13 +3622,11 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
36173622
return std::make_tuple(t1, t2);
36183623
}
36193624

3620-
XLATensorPtr trace(const XLATensorPtr& input) {
3621-
auto input_shape_ref = input->shape();
3622-
XLA_CHECK_EQ((*input_shape_ref).dimensions_size(), 2)
3623-
<< "invalid argument for trace: expected a matrix";
3624-
torch::lazy::NodePtr eye = Identity((*input_shape_ref).dimensions(0),
3625-
(*input_shape_ref).dimensions(1),
3626-
(*input_shape_ref).element_type());
3625+
absl::StatusOr<absl_nonnull XLATensorPtr> trace(const XLATensorPtr& input) {
3626+
XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "trace"));
3627+
xla::Shape shape = input->shape();
3628+
torch::lazy::NodePtr eye =
3629+
Identity(shape.dimensions(0), shape.dimensions(1), shape.element_type());
36273630
return sum(input->CreateFrom(eye * input->GetIrValue()), {0, 1}, false,
36283631
input->dtype());
36293632
}

torch_xla/csrc/tensor_methods.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,7 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
972972
bool stable);
973973

974974
// Returns the sum of the elements of the diagonal of the input 2-D matrix.
975-
XLATensorPtr trace(const XLATensorPtr& input);
975+
absl::StatusOr<absl_nonnull XLATensorPtr> trace(const XLATensorPtr& input);
976976

977977
// Swap given dimensions of the input.
978978
XLATensorPtr transpose(const XLATensorPtr& input, int64_t dim0, int64_t dim1);

0 commit comments

Comments
 (0)