Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions test/test_ops_error_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,19 @@ def test():
expect="""roll(): expected `dims` [0] (size=1) to match the size of `shifts` [2, 2] (size=2)."""
)

def test_trace_raises_error_on_non_matrix_input(self):
device = torch_xla.device()
a = torch.rand(2, 2, 2, device=device)

def test():
torch.trace(a)

self.assertExpectedRaisesInline(
exc_type=RuntimeError,
callable=test,
expect="""trace(): expected the input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor)."""
)


if __name__ == "__main__":
unittest.main()
7 changes: 5 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3873,8 +3873,11 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::topk(

at::Tensor XLANativeFunctions::trace(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
return bridge::AtenFromXlaTensor(tensor_methods::trace(xla_self));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
bridge::GetXlaTensor(self));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
tensor_methods::trace(xla_self));
return bridge::AtenFromXlaTensor(std::move(output));
}

at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self,
Expand Down
32 changes: 17 additions & 15 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <functional>
#include <iterator>

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
Expand Down Expand Up @@ -480,13 +481,16 @@ absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input,
return absl::OkStatus();
}

absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat,
const std::string_view arg) {
xla::Shape shape = mat->shape();
absl::Status CheckInputIsMatrix(const XLATensorPtr& tensor,
const std::string_view op,
const std::string_view arg = "") {
xla::Shape shape = tensor->shape();
if (shape.dimensions().size() != 2) {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
absl::StrCat("mm(): expected the ", arg, " input tensor ",
shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));
const std::string arg_with_trailing_space =
arg.empty() ? std::string("") : absl::StrCat(arg, " ");
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
op, "(): expected the ", arg_with_trailing_space, "input tensor ",
shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));
}
return absl::OkStatus();
}
Expand Down Expand Up @@ -2420,8 +2424,8 @@ XLATensorPtr mish(const XLATensorPtr& input) {

absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
const XLATensorPtr& weight) {
XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first"));
XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second"));
XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "mm", "first"));
XLA_RETURN_IF_ERROR(CheckInputIsMatrix(weight, "mm", "second"));
XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight));
return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue()));
}
Expand Down Expand Up @@ -3617,13 +3621,11 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
return std::make_tuple(t1, t2);
}

XLATensorPtr trace(const XLATensorPtr& input) {
auto input_shape_ref = input->shape();
XLA_CHECK_EQ((*input_shape_ref).dimensions_size(), 2)
<< "invalid argument for trace: expected a matrix";
torch::lazy::NodePtr eye = Identity((*input_shape_ref).dimensions(0),
(*input_shape_ref).dimensions(1),
(*input_shape_ref).element_type());
absl::StatusOr<absl_nonnull XLATensorPtr> trace(const XLATensorPtr& input) {
XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "trace"));
xla::Shape shape = input->shape();
torch::lazy::NodePtr eye =
Identity(shape.dimensions(0), shape.dimensions(1), shape.element_type());
return sum(input->CreateFrom(eye * input->GetIrValue()), {0, 1}, false,
input->dtype());
}
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,7 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
bool stable);

// Returns the sum of the elements of the diagonal of the input 2-D matrix.
XLATensorPtr trace(const XLATensorPtr& input);
absl::StatusOr<absl_nonnull XLATensorPtr> trace(const XLATensorPtr& input);

// Swap given dimensions of the input.
XLATensorPtr transpose(const XLATensorPtr& input, int64_t dim0, int64_t dim1);
Expand Down
Loading