diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index bf7a4003acc..536d3b5b6a4 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1147,8 +1147,8 @@ at::Tensor XLANativeFunctions::as_strided_scatter(
   XLA_ASSIGN_OR_THROW(XLATensorPtr xla_base, bridge::GetXlaTensor(base));
   auto xsize = XlaHelpers::I64List(size);
   auto xstride = XlaHelpers::I64List(stride);
-  if (!AsStrided::StrideIsSupported(xla_base->shape(), xsize, xstride,
-                                    storage_offset.value_or(0))) {
+  if (!IsAsStridedSupported(xla_base->shape(), xsize, xstride,
+                            storage_offset.value_or(0))) {
     return at::native::call_fallback_fn<
         &xla_fallback,
         ATEN_OP(as_strided_scatter)>::call(base, mutated_view, size, stride,
@@ -4567,8 +4567,8 @@ at::Tensor XLANativeFunctions::as_strided(
   XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
   auto xsize = XlaHelpers::I64List(size);
   auto xstride = XlaHelpers::I64List(stride);
-  if (!AsStrided::StrideIsSupported(xla_self->shape(), xsize, xstride,
-                                    storage_offset.value_or(0))) {
+  if (!IsAsStridedSupported(xla_self->shape(), xsize, xstride,
+                            storage_offset.value_or(0))) {
     return at::native::call_fallback_fn<
         &xla_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
                                                   storage_offset);
@@ -4585,8 +4585,8 @@ const at::Tensor& XLANativeFunctions::as_strided_(
   XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
   auto xsize = XlaHelpers::I64List(size);
   auto xstride = XlaHelpers::I64List(stride);
-  if (!AsStrided::StrideIsSupported(xla_self->shape(), xsize, xstride,
-                                    storage_offset.value_or(0))) {
+  if (!IsAsStridedSupported(xla_self->shape(), xsize, xstride,
+                            storage_offset.value_or(0))) {
     return at::native::call_fallback_fn<
         &xla_fallback, ATEN_OP(as_strided_)>::call(self, size, stride,
                                                    storage_offset);
diff --git a/torch_xla/csrc/ops/as_strided.cpp b/torch_xla/csrc/ops/as_strided.cpp
index fcccd95b5a1..78708bc0497 100644
--- a/torch_xla/csrc/ops/as_strided.cpp
+++ b/torch_xla/csrc/ops/as_strided.cpp
@@ -1,68 +1,62 @@
 #include "torch_xla/csrc/ops/as_strided.h"
 
-#include
-
-#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
+#include "xla/hlo/builder/xla_builder.h"
+#include "xla/permutation_util.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
-#include "xla/util.h"
-#include "torch_xla/csrc/data_ops.h"
 #include "torch_xla/csrc/helpers.h"
+#include "torch_xla/csrc/ir.h"
 #include "torch_xla/csrc/lowering_context.h"
-#include "torch_xla/csrc/runtime/util.h"
 #include "torch_xla/csrc/shape_helper.h"
-#include "torch_xla/csrc/tensor_util.h"
-#include "torch_xla/csrc/torch_util.h"
+#include "torch_xla/csrc/status.h"
 
 namespace torch_xla {
 namespace {
 
-xla::XlaOp LowerAsStrided(xla::XlaOp input, absl::Span<const int64_t> size,
-                          absl::Span<const int64_t> stride,
-                          int64_t storage_offset) {
-  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
-  int64_t input_element_count = xla::ShapeUtil::ElementsIn(input_shape);
-  int64_t slice_size = torch_xla::runtime::util::Multiply<int64_t>(size);
-  XLA_CHECK_LE(storage_offset + slice_size, input_element_count);
-
-  xla::XlaOp off_input = input;
-  if (storage_offset > 0 || slice_size < input_element_count) {
-    xla::XlaOp r1_input = XlaHelpers::Flatten(input);
-    off_input = xla::SliceInDim(r1_input, storage_offset,
-                                storage_offset + slice_size, 1, 0);
-  }
-
-  std::vector<int64_t> permutation = xla::InversePermutation(
-      AsStrided::GetArrayStridePermutation(stride, size));
-  std::vector<int64_t> new_sizes = xla::PermuteInverse(size, permutation);
-  xla::XlaOp reshaped_input = XlaHelpers::DynamicReshape(off_input, new_sizes);
-  return xla::IsIdentityPermutation(permutation)
-             ? reshaped_input
-             : xla::Transpose(reshaped_input, permutation);
+xla::Shape AsStridedOutputShape(const torch::lazy::Value& input,
+                                absl::Span<const int64_t> size) {
+  return xla::ShapeUtil::MakeShape(GetXlaShape(input).element_type(), size);
 }
 
 }  // namespace
 
-AsStrided::AsStrided(const torch::lazy::Value& input, std::vector<int64_t> size,
-                     std::vector<int64_t> stride, int64_t storage_offset)
+AsStrided::AsStrided(const torch::lazy::Value& input,
+                     const std::vector<int64_t>& size,
+                     const std::vector<int64_t>& stride, int64_t storage_offset)
     : XlaNode(
           torch::lazy::OpKind(at::aten::as_strided), {input},
-          [&]() {
-            return xla::ShapeUtil::MakeShape(GetXlaShape(input).element_type(),
-                                             size);
-          },
-          /*num_outputs=*/1, torch::lazy::MHash(size, stride, storage_offset)),
-      size_(std::move(size)),
-      stride_(std::move(stride)),
-      storage_offset_(storage_offset) {}
+          AsStridedOutputShape(input, size),
+          /* num_outputs= */ 1,
+          /* hash_seed= */ torch::lazy::MHash(size, stride, storage_offset)),
+      size_(size),
+      stride_(stride),
+      storage_offset_(storage_offset) {
+  // Make sure `input` has enough elements to fit the given spec.
+  XLA_CHECK_OK(CheckSpecFitsInput(input));
+}
 
 std::string AsStrided::ToString() const {
-  std::stringstream ss;
-  ss << XlaNode::ToString() << ", size=(" << absl::StrJoin(size_, ", ")
-     << "), stride=(" << absl::StrJoin(stride_, ", ")
-     << "), storage_offset=" << storage_offset_;
-  return ss.str();
+  return absl::StrCat(XlaNode::ToString(), ", size=(",
+                      absl::StrJoin(size_, ", "), "), stride=(",
+                      absl::StrJoin(stride_, ", "),
+                      "), storage_offset=", storage_offset_);
 }
 
 torch::lazy::NodePtr AsStrided::Clone(torch::lazy::OpList operands) const {
@@ -70,26 +64,105 @@ torch::lazy::NodePtr AsStrided::Clone(torch::lazy::OpList operands) const {
                                         storage_offset_);
 }
 
-XlaOpVector AsStrided::Lower(LoweringContext* loctx) const {
-  xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  return ReturnOp(LowerAsStrided(input, size_, stride_, storage_offset_),
-                  loctx);
+absl::StatusOr<XlaOpVector> AsStrided::SafeLower(
+    LoweringContext* loctx) const {
+  XLA_ASSIGN_OR_RETURN(xla::XlaOp input, loctx->SafeGetOutputOp(operand(0)));
+
+  XLA_ASSIGN_OR_RETURN(const xla::Shape* absl_nonnull input_shape_ptr,
+                       GetShape(input));
+
+  int64_t input_element_count = xla::ShapeUtil::ElementsIn(*input_shape_ptr);
+  int64_t spec_element_count = GetSpecElementCount();
+
+  // Preprocess `input` so that it:
+  // 1. Starts from `storage_offset_`
+  // 2. Has the same element count as the spec
+  //
+  // This preprocessing should only be done if:
+  // 1. There's actually a `storage_offset_` to start from; or
+  // 2. The element count of the input differs from that of the spec
+  if (storage_offset_ > 0 || input_element_count != spec_element_count) {
+    XLA_ASSIGN_OR_RETURN(xla::XlaOp flattened, XlaHelpers::SafeFlatten(input));
+    input = xla::SliceInDim(flattened, storage_offset_,
+                            storage_offset_ + spec_element_count, 1, 0);
+  }
+
+  // Since PyTorch/XLA has no concept of strides in a tensor (i.e. all tensors
+  // are contiguous), we need a way to compute a contiguous tensor that accesses
+  // the same elements that a similarly spec'd strided tensor would. In order to
+  // do that, we need to:
+  //
+  // 1. Reshape the `input`, so that the dimensions with larger strides come
+  //    first. This should yield the correct contiguous tensor, but with
+  //    permuted dimensions.
+  std::vector<int64_t> permutation =
+      xla::InversePermutation(GetDescendingOrderPermutation(stride_));
+  std::vector<int64_t> permuted_sizes = xla::PermuteInverse(size_, permutation);
+  XLA_ASSIGN_OR_RETURN(xla::XlaOp input_reshaped_with_permuted_sizes,
+                       XlaHelpers::SafeDynamicReshape(input, permuted_sizes));
+
+  // 2. Reverse the dimension permutation we did on `input` in the previous
+  //    step.
+  xla::XlaOp output =
+      xla::IsIdentityPermutation(permutation)
+          ? input_reshaped_with_permuted_sizes
+          : xla::Transpose(input_reshaped_with_permuted_sizes, permutation);
+
+  return ReturnOp(output, loctx);
+}
+
+int64_t AsStrided::GetSpecElementCount() const {
+  return runtime::util::Multiply<int64_t>(size_);
+}
+
+absl::Status AsStrided::CheckSpecFitsInput(xla::XlaOp input) const {
+  XLA_ASSIGN_OR_RETURN(const xla::Shape* absl_nonnull shape_ptr,
+                       GetShape(input));
+  XLA_RETURN_IF_ERROR(
+      CheckSpecFitsInputImpl(*shape_ptr, xla::ShapeUtil::ElementsIn(*shape_ptr),
+                             GetSpecElementCount()));
+  return absl::OkStatus();
+}
+
+absl::Status AsStrided::CheckSpecFitsInput(
+    const torch::lazy::Value& input) const {
+  const xla::Shape& shape = GetXlaShape(input);
+  XLA_RETURN_IF_ERROR(CheckSpecFitsInputImpl(
+      shape, xla::ShapeUtil::ElementsIn(shape), GetSpecElementCount()));
+  return absl::OkStatus();
+}
+
+absl::Status AsStrided::CheckSpecFitsInputImpl(
+    const xla::Shape& input_shape, int64_t input_element_count,
+    int64_t spec_element_count) const {
+  if (input_element_count < storage_offset_ + spec_element_count) {
+    return XLA_ERROR_WITH_LOCATION(absl::InternalError(absl::StrCat(
+        "as_strided(): expected input ", input_shape.ToString(),
+        " (elements=", input_element_count,
+        ") to have enough elements to fit the given spec of size=[",
+        absl::StrJoin(size_, /* separator= */ ", "), "], stride=[",
+        absl::StrJoin(stride_, /* separator= */ ", "),
+        "], and storage_offset=", storage_offset_,
+        " (elements=", spec_element_count, ")")));
+  }
+  return absl::OkStatus();
 }
 
-bool AsStrided::StrideIsSupported(const xla::Shape& input_shape,
-                                  absl::Span<const int64_t> size,
-                                  absl::Span<const int64_t> stride,
-                                  int64_t storage_offset) {
+bool IsAsStridedSupported(const xla::Shape& input_shape,
+                          absl::Span<const int64_t> size,
+                          absl::Span<const int64_t> stride,
+                          int64_t storage_offset) {
   std::vector<int64_t> sorted_stride(stride.begin(), stride.end());
   std::sort(sorted_stride.begin(), sorted_stride.end());
   return stride.empty() || sorted_stride.front() == 1;
 }
 
-std::vector<int64_t> AsStrided::GetArrayStridePermutation(
-    absl::Span<const int64_t> stride, absl::Span<const int64_t> size) {
-  std::vector<int64_t> permutation = torch::lazy::Iota<int64_t>(stride.size());
+std::vector<int64_t> GetDescendingOrderPermutation(
+    absl::Span<const int64_t> v) {
+  std::vector<int64_t> permutation(v.size());
+  std::iota(permutation.begin(), permutation.end(), 0);
   std::sort(permutation.begin(), permutation.end(),
-            [&](int64_t a, int64_t b) { return stride[a] > stride[b]; });
+            [&](int64_t a, int64_t b) { return v[a] > v[b]; });
   return permutation;
 }
diff --git a/torch_xla/csrc/ops/as_strided.h b/torch_xla/csrc/ops/as_strided.h
index 0b7556352da..785900048c5 100644
--- a/torch_xla/csrc/ops/as_strided.h
+++ b/torch_xla/csrc/ops/as_strided.h
@@ -1,45 +1,60 @@
 #ifndef XLA_TORCH_XLA_CSRC_OPS_AS_STRIDED_H_
 #define XLA_TORCH_XLA_CSRC_OPS_AS_STRIDED_H_
 
+#include
+#include
 #include
 
-#include "xla/types.h"
+#include
+
+#include "absl/status/statusor.h"
+#include "absl/types/span.h"
 #include "torch_xla/csrc/ir.h"
+#include "torch_xla/csrc/lowering_context.h"
"torch_xla/csrc/lowering_context.h" namespace torch_xla { class AsStrided : public XlaNode { public: - AsStrided(const torch::lazy::Value& input, std::vector size, - std::vector stride, int64_t storage_offset); + AsStrided(const torch::lazy::Value& input, const std::vector& size, + const std::vector& stride, int64_t storage_offset); std::string ToString() const override; torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; - XlaOpVector Lower(LoweringContext* loctx) const override; - - const std::vector& size() const { return size_; } - - const std::vector& stride() const { return stride_; } - - int64_t storage_offset() const { return storage_offset_; } - - static bool StrideIsSupported(const xla::Shape& input_shape, - absl::Span size, - absl::Span stride, - int64_t storage_offset); - - static std::vector GetArrayStridePermutation( - absl::Span stride, absl::Span size); + absl::StatusOr SafeLower(LoweringContext* loctx) const override; private: std::vector size_; std::vector stride_; int64_t storage_offset_; + + // Convenient function for retrieving the number of elements given by `size_`. + int64_t GetSpecElementCount() const; + // Check that the given as_strided arguments (i.e. spec) are actually within + // the input tensor bounds. i.e. whether it's actually possible to retrieve a + // non-overlapping tensor given by spec from the input tensor. + absl::Status CheckSpecFitsInputImpl(const xla::Shape& input_shape, + int64_t input_element_count, + int64_t spec_element_count) const; + // Lowering check that calls `CheckSpecFitsInputImpl()`. + absl::Status CheckSpecFitsInput(xla::XlaOp input) const; + // Tracing check that calls `CheckSpecFitsInputImpl()`. + absl::Status CheckSpecFitsInput(const torch::lazy::Value& input) const; }; +// Legacy function that checks whether the given specs are supported by this +// lowering of the `as_strided` operation. +bool IsAsStridedSupported(const xla::Shape& input_shape, + absl::Span size, + absl::Span stride, + int64_t storage_offset); + +// Retrieves the permutation for sorting `v` in descending order. +std::vector GetDescendingOrderPermutation(absl::Span v); + } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_OPS_AS_STRIDED_H_ \ No newline at end of file +#endif // XLA_TORCH_XLA_CSRC_OPS_AS_STRIDED_H_ diff --git a/torch_xla/csrc/ops/as_strided_view_update.cpp b/torch_xla/csrc/ops/as_strided_view_update.cpp index 973d891af06..bfaddcc4d1d 100644 --- a/torch_xla/csrc/ops/as_strided_view_update.cpp +++ b/torch_xla/csrc/ops/as_strided_view_update.cpp @@ -24,8 +24,7 @@ xla::XlaOp LowerAsStridedViewUpdate(xla::XlaOp target, xla::XlaOp input, int64_t slice_size = torch_xla::runtime::util::Multiply(size); XLA_CHECK_LE(storage_offset + input_element_count, slice_size); - std::vector permutation = - AsStrided::GetArrayStridePermutation(stride, input_shape.dimensions()); + std::vector permutation = GetDescendingOrderPermutation(stride); xla::XlaOp transposed_input = xla::IsIdentityPermutation(permutation) ? input : xla::Transpose(input, permutation);