|
10 | 10 | #include <functional>
|
11 | 11 | #include <iterator>
|
12 | 12 |
|
| 13 | +#include "absl/base/nullability.h" |
13 | 14 | #include "absl/log/absl_check.h"
|
14 | 15 | #include "absl/status/status.h"
|
15 | 16 | #include "absl/strings/str_cat.h"
|
16 | 17 | #include "absl/strings/str_join.h"
|
17 | 18 | #include "absl/strings/str_split.h"
|
| 19 | +#include "absl/types/span.h" |
18 | 20 | #include "torch_xla/csrc/LazyIr.h"
|
19 | 21 | #include "torch_xla/csrc/aten_xla_bridge.h"
|
20 | 22 | #include "torch_xla/csrc/data_ops.h"
|
@@ -234,16 +236,45 @@ void CheckBmmDimension(const std::string& tag, const XLATensorPtr& batch1,
|
234 | 236 | "batch2", 2);
|
235 | 237 | }
|
236 | 238 |
|
237 |
| -std::vector<int64_t> GetExpandDimensions(const xla::Shape& shape, |
238 |
| - std::vector<int64_t> dimensions) { |
239 |
| - XLA_CHECK_GE(dimensions.size(), shape.dimensions_size()) << shape; |
240 |
| - int64_t base = dimensions.size() - shape.dimensions_size(); |
241 |
| - for (size_t i = 0; i < shape.dimensions_size(); ++i) { |
242 |
| - if (dimensions[base + i] == -1) { |
243 |
| - dimensions[base + i] = shape.dimensions(i); |
| 239 | +absl::Status CheckExpandValidRank(const XLATensorPtr& input, |
| 240 | + const absl::Span<const int64_t> sizes) { |
| 241 | + xla::Shape shape = input->shape(); |
| 242 | + int64_t rank = shape.dimensions().size(); |
| 243 | + if (rank > sizes.size()) { |
| 244 | + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( |
| 245 | + "expand(): expected the `input` tensor ", shape.ToString(), " (rank: ", |
| 246 | + rank, ") to have a rank smaller or equal to the given `sizes` [", |
| 247 | + absl::StrJoin(sizes, /* sep= */ ", "), "] (rank: ", sizes.size(), |
| 248 | + ")."))); |
| 249 | + } |
| 250 | + return absl::OkStatus(); |
| 251 | +} |
| 252 | + |
| 253 | +absl::StatusOr<std::vector<int64_t>> GetExpandDimensions( |
| 254 | + const XLATensorPtr& input, const absl::Span<const int64_t> sizes) { |
| 255 | + XLA_RETURN_IF_ERROR(CheckExpandValidRank(input, sizes)); |
| 256 | + |
| 257 | + xla::Shape shape = input->shape(); |
| 258 | + const int64_t rank = shape.dimensions().size(); |
| 259 | + const int64_t base = sizes.size() - rank; |
| 260 | + |
| 261 | + std::vector<int64_t> expanded_dimensions(sizes.begin(), sizes.end()); |
| 262 | + for (size_t i = 0; i < shape.dimensions().size(); ++i) { |
| 263 | + const int64_t dim = base + i; |
| 264 | + const int64_t size = sizes[dim]; |
| 265 | + if (size == -1) { |
| 266 | + expanded_dimensions[dim] = shape.dimensions(i); |
| 267 | + } else if (shape.dimensions(i) != 1 && size != shape.dimensions(i)) { |
| 268 | + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( |
| 269 | + "expand(): expected dimension ", dim, " of the given `sizes` [", |
| 270 | + absl::StrJoin(sizes, /* sep= */ ", "), "] (", size, |
| 271 | + ") to be -1, or equal to the size of the `input` tensor ", |
| 272 | + shape.ToString(), " at dimension ", i, " (", shape.dimensions(i), |
| 273 | + ")."))); |
244 | 274 | }
|
245 | 275 | }
|
246 |
| - return dimensions; |
| 276 | + |
| 277 | + return expanded_dimensions; |
247 | 278 | }
|
248 | 279 |
|
249 | 280 | // Resizes and / or checks whether a list is of the given size. The list is only
|
@@ -1791,11 +1822,12 @@ XLATensorPtr exp(const XLATensorPtr& input) {
|
1791 | 1822 | return input->CreateFrom(Exp(input->GetIrValue()));
|
1792 | 1823 | }
|
1793 | 1824 |
|
1794 |
| -XLATensorPtr expand(const XLATensorPtr& input, std::vector<int64_t> size) { |
1795 |
| - auto input_shape = input->shape(); |
1796 |
| - auto output = input->CreateFrom(torch_xla::MakeNode<Expand>( |
1797 |
| - input->GetIrValue(), |
1798 |
| - GetExpandDimensions(input_shape.get(), std::move(size)))); |
| 1825 | +absl::StatusOr<absl_nonnull XLATensorPtr> expand( |
| 1826 | + const XLATensorPtr& input, const absl::Span<const int64_t> sizes) { |
| 1827 | + XLA_ASSIGN_OR_RETURN(std::vector<int64_t> expanded_dimensions, |
| 1828 | + GetExpandDimensions(input, sizes)); |
| 1829 | + auto output = input->CreateFrom( |
| 1830 | + torch_xla::MakeNode<Expand>(input->GetIrValue(), expanded_dimensions)); |
1799 | 1831 | output->SetStorage(input->Storage());
|
1800 | 1832 | return output;
|
1801 | 1833 | }
|
@@ -2927,15 +2959,14 @@ XLATensorPtr cast_int4(const XLATensorPtr& weight,
|
2927 | 2959 | // Dynamic Reshape ops here.
|
2928 | 2960 | //////////////////////////////////////////////////////////////////////////////
|
2929 | 2961 |
|
2930 |
| -XLATensorPtr dynamic_expand(const XLATensorPtr& input, |
2931 |
| - const std::vector<int64_t>& size, |
2932 |
| - const XLATensorPtr& src_tensor, int src_dim, |
2933 |
| - int target_dim) { |
2934 |
| - std::vector<int64_t> expanded_size = |
2935 |
| - GetExpandDimensions(input->shape().get(), size); |
| 2962 | +absl::StatusOr<absl_nonnull XLATensorPtr> dynamic_expand( |
| 2963 | + const XLATensorPtr& input, const absl::Span<const int64_t> sizes, |
| 2964 | + const XLATensorPtr& src_tensor, int src_dim, int target_dim) { |
| 2965 | + XLA_ASSIGN_OR_RETURN(std::vector<int64_t> expanded_dimensions, |
| 2966 | + GetExpandDimensions(input, sizes)); |
2936 | 2967 | torch::lazy::NodePtr node = torch_xla::MakeNode<DynamicExpand>(
|
2937 |
| - input->GetIrValue(), expanded_size, src_tensor->GetIrValue(), src_dim, |
2938 |
| - target_dim); |
| 2968 | + input->GetIrValue(), expanded_dimensions, src_tensor->GetIrValue(), |
| 2969 | + src_dim, target_dim); |
2939 | 2970 | return input->CreateFrom(torch::lazy::Value(node));
|
2940 | 2971 | }
|
2941 | 2972 |
|
|
0 commit comments