|
10 | 10 | #include <functional>
|
11 | 11 | #include <iterator>
|
12 | 12 |
|
| 13 | +#include "absl/base/nullability.h" |
13 | 14 | #include "absl/log/absl_check.h"
|
14 | 15 | #include "absl/status/status.h"
|
15 | 16 | #include "absl/strings/str_cat.h"
|
16 | 17 | #include "absl/strings/str_join.h"
|
17 | 18 | #include "absl/strings/str_split.h"
|
| 19 | +#include "status.h" |
18 | 20 | #include "torch_xla/csrc/LazyIr.h"
|
19 | 21 | #include "torch_xla/csrc/aten_xla_bridge.h"
|
20 | 22 | #include "torch_xla/csrc/data_ops.h"
|
@@ -480,13 +482,16 @@ absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input,
|
480 | 482 | return absl::OkStatus();
|
481 | 483 | }
|
482 | 484 |
|
483 |
| -absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat, |
484 |
| - const std::string_view arg) { |
485 |
| - xla::Shape shape = mat->shape(); |
| 485 | +absl::Status CheckInputIsMatrix(const XLATensorPtr& tensor, |
| 486 | + const std::string_view op, |
| 487 | + const std::string_view arg = "") { |
| 488 | + xla::Shape shape = tensor->shape(); |
486 | 489 | if (shape.dimensions().size() != 2) {
|
487 |
| - return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( |
488 |
| - absl::StrCat("mm(): expected the ", arg, " input tensor ", |
489 |
| - shape.ToString(), " to be a matrix (i.e. a 2D tensor)."))); |
| 490 | + const std::string arg_with_trailing_space = |
| 491 | + arg.empty() ? "" : std::string(arg) + " "; |
| 492 | + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( |
| 493 | + op, "(): expected the ", arg_with_trailing_space, "input tensor ", |
| 494 | + shape.ToString(), " to be a matrix (i.e. a 2D tensor)."))); |
490 | 495 | }
|
491 | 496 | return absl::OkStatus();
|
492 | 497 | }
|
@@ -2420,8 +2425,8 @@ XLATensorPtr mish(const XLATensorPtr& input) {
|
2420 | 2425 |
|
2421 | 2426 | absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
|
2422 | 2427 | const XLATensorPtr& weight) {
|
2423 |
| - XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first")); |
2424 |
| - XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second")); |
| 2428 | + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "mm", "first")); |
| 2429 | + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(weight, "mm", "second")); |
2425 | 2430 | XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight));
|
2426 | 2431 | return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue()));
|
2427 | 2432 | }
|
@@ -3617,13 +3622,11 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
|
3617 | 3622 | return std::make_tuple(t1, t2);
|
3618 | 3623 | }
|
3619 | 3624 |
|
3620 |
| -XLATensorPtr trace(const XLATensorPtr& input) { |
3621 |
| - auto input_shape_ref = input->shape(); |
3622 |
| - XLA_CHECK_EQ((*input_shape_ref).dimensions_size(), 2) |
3623 |
| - << "invalid argument for trace: expected a matrix"; |
3624 |
| - torch::lazy::NodePtr eye = Identity((*input_shape_ref).dimensions(0), |
3625 |
| - (*input_shape_ref).dimensions(1), |
3626 |
| - (*input_shape_ref).element_type()); |
| 3625 | +absl::StatusOr<absl_nonnull XLATensorPtr> trace(const XLATensorPtr& input) { |
| 3626 | + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "trace")); |
| 3627 | + xla::Shape shape = input->shape(); |
| 3628 | + torch::lazy::NodePtr eye = |
| 3629 | + Identity(shape.dimensions(0), shape.dimensions(1), shape.element_type()); |
3627 | 3630 | return sum(input->CreateFrom(eye * input->GetIrValue()), {0, 1}, false,
|
3628 | 3631 | input->dtype());
|
3629 | 3632 | }
|
|
0 commit comments