Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1147,8 +1147,8 @@ at::Tensor XLANativeFunctions::as_strided_scatter(
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_base, bridge::GetXlaTensor(base));
auto xsize = XlaHelpers::I64List(size);
auto xstride = XlaHelpers::I64List(stride);
if (!AsStrided::StrideIsSupported(xla_base->shape(), xsize, xstride,
storage_offset.value_or(0))) {
if (!IsAsStridedSupported(xla_base->shape(), xsize, xstride,
storage_offset.value_or(0))) {
return at::native::call_fallback_fn<
&xla_fallback, ATEN_OP(as_strided_scatter)>::call(base, mutated_view,
size, stride,
Expand Down Expand Up @@ -4567,8 +4567,8 @@ at::Tensor XLANativeFunctions::as_strided(
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
auto xsize = XlaHelpers::I64List(size);
auto xstride = XlaHelpers::I64List(stride);
if (!AsStrided::StrideIsSupported(xla_self->shape(), xsize, xstride,
storage_offset.value_or(0))) {
if (!IsAsStridedSupported(xla_self->shape(), xsize, xstride,
storage_offset.value_or(0))) {
return at::native::call_fallback_fn<
&xla_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
storage_offset);
Expand All @@ -4585,8 +4585,8 @@ const at::Tensor& XLANativeFunctions::as_strided_(
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
auto xsize = XlaHelpers::I64List(size);
auto xstride = XlaHelpers::I64List(stride);
if (!AsStrided::StrideIsSupported(xla_self->shape(), xsize, xstride,
storage_offset.value_or(0))) {
if (!IsAsStridedSupported(xla_self->shape(), xsize, xstride,
storage_offset.value_or(0))) {
return at::native::call_fallback_fn<
&xla_fallback, ATEN_OP(as_strided_)>::call(self, size, stride,
storage_offset);
Expand Down
187 changes: 130 additions & 57 deletions torch_xla/csrc/ops/as_strided.cpp
Original file line number Diff line number Diff line change
@@ -1,95 +1,168 @@
#include "torch_xla/csrc/ops/as_strided.h"

#include <algorithm>

#include <torch/csrc/lazy/core/util.h>
#include <ATen/core/aten_interned_strings.h>

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <string>
#include <vector>

#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/ir.h>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/permutation_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"

#include "torch_xla/csrc/data_ops.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"
#include "torch_xla/csrc/status.h"

namespace torch_xla {
namespace {

xla::XlaOp LowerAsStrided(xla::XlaOp input, absl::Span<const int64_t> size,
absl::Span<const int64_t> stride,
int64_t storage_offset) {
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
int64_t input_element_count = xla::ShapeUtil::ElementsIn(input_shape);
int64_t slice_size = torch_xla::runtime::util::Multiply<int64_t>(size);
XLA_CHECK_LE(storage_offset + slice_size, input_element_count);

xla::XlaOp off_input = input;
if (storage_offset > 0 || slice_size < input_element_count) {
xla::XlaOp r1_input = XlaHelpers::Flatten(input);
off_input = xla::SliceInDim(r1_input, storage_offset,
storage_offset + slice_size, 1, 0);
}

std::vector<int64_t> permutation = xla::InversePermutation(
AsStrided::GetArrayStridePermutation(stride, size));
std::vector<int64_t> new_sizes = xla::PermuteInverse(size, permutation);
xla::XlaOp reshaped_input = XlaHelpers::DynamicReshape(off_input, new_sizes);
return xla::IsIdentityPermutation(permutation)
? reshaped_input
: xla::Transpose(reshaped_input, permutation);
xla::Shape AsStridedOutputShape(const torch::lazy::Value& input,
absl::Span<const int64_t> size) {
return xla::ShapeUtil::MakeShape(GetXlaShape(input).element_type(), size);
}

} // namespace

AsStrided::AsStrided(const torch::lazy::Value& input, std::vector<int64_t> size,
std::vector<int64_t> stride, int64_t storage_offset)
AsStrided::AsStrided(const torch::lazy::Value& input,
const std::vector<int64_t>& size,
const std::vector<int64_t>& stride, int64_t storage_offset)
: XlaNode(
torch::lazy::OpKind(at::aten::as_strided), {input},
[&]() {
return xla::ShapeUtil::MakeShape(GetXlaShape(input).element_type(),
size);
},
/*num_outputs=*/1, torch::lazy::MHash(size, stride, storage_offset)),
size_(std::move(size)),
stride_(std::move(stride)),
storage_offset_(storage_offset) {}
AsStridedOutputShape(input, size),
/* num_outputs= */ 1,
/* hash_seed= */ torch::lazy::MHash(size, stride, storage_offset)),
size_(size),
stride_(stride),
storage_offset_(storage_offset) {
// Make sure `input` has enough elements to fit the given spec.
XLA_CHECK_OK(CheckSpecFitsInput(input));
}

std::string AsStrided::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", size=(" << absl::StrJoin(size_, ", ")
<< "), stride=(" << absl::StrJoin(stride_, ", ")
<< "), storage_offset=" << storage_offset_;
return ss.str();
return absl::StrCat(XlaNode::ToString(), ", size=(",
absl::StrJoin(size_, ", "), "), stride=(",
absl::StrJoin(stride_, ", "),
"), storage_offset=", storage_offset_);
}

torch::lazy::NodePtr AsStrided::Clone(torch::lazy::OpList operands) const {
return torch_xla::MakeNode<AsStrided>(operands.at(0), size_, stride_,
storage_offset_);
}

XlaOpVector AsStrided::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
return ReturnOp(LowerAsStrided(input, size_, stride_, storage_offset_),
loctx);
absl::StatusOr<XlaOpVector> AsStrided::SafeLower(LoweringContext* loctx) const {
XLA_ASSIGN_OR_RETURN(xla::XlaOp input, loctx->SafeGetOutputOp(operand(0)));

XLA_ASSIGN_OR_RETURN(const xla::Shape* absl_nonnull input_shape_ptr,
GetShape(input));

XLA_ASSIGN_OR_RETURN(int64_t input_element_count,
xla::ShapeUtil::ElementsIn(*input_shape_ptr));

int64_t spec_element_count = GetSpecElementCount();

// Preprocess `input` so that it:
// 1. Starts from `storage_offset_`
// 2. Has the same element count as the spec
//
// This preprocessing should only be done if:
// 1. There's actually a `storage_offset_` to start from; or
// 2. The element count of the input is different from the spec
if (storage_offset_ > 0 || input_element_count != spec_element_count) {
XLA_ASSIGN_OR_RETURN(xla::XlaOp flattened, XlaHelpers::SafeFlatten(input));
input = xla::SliceInDim(flattened, storage_offset_,
storage_offset_ + spec_element_count, 1, 0);
}

// Since PyTorch/XLA has no concept of strides in a tensor (i.e. all tensors
// are contiguous), we need a way to compute a contiguous tensor that accesses
// the same elements that a similarly spec'd strided tensor would. In order to
// do that, we need to:
//
// 1. Reshape the `input`, so that the dimensions with larger strides come
// first. This should yield the correct contiguous tensor, but with
// permuted dimensions.
std::vector<int64_t> permutation =
xla::InversePermutation(GetDescendingOrderPermutation(stride_));
std::vector<int64_t> permuted_sizes = xla::PermuteInverse(size_, permutation);
XLA_ASSIGN_OR_RETURN(xla::XlaOp input_reshaped_with_permuted_sizes,
XlaHelpers::SafeDynamicReshape(input, permuted_sizes));

// 2. Reverse the dimension permutation we did on `input` in the previous
// step.
xla::XlaOp output =
xla::IsIdentityPermutation(permutation)
? input_reshaped_with_permuted_sizes
: xla::Transpose(input_reshaped_with_permuted_sizes, permutation);

return ReturnOp(output, loctx);
}

int64_t AsStrided::GetSpecElementCount() const {
return runtime::util::Multiply<int64_t>(size_);
}

absl::Status AsStrided::CheckSpecFitsInput(xla::XlaOp input) const {
XLA_ASSIGN_OR_RETURN(const xla::Shape* absl_nonnull shape_ptr,
GetShape(input));
XLA_RETURN_IF_ERROR(
CheckSpecFitsInputImpl(*shape_ptr, xla::ShapeUtil::ElementsIn(*shape_ptr),
GetSpecElementCount()));
return absl::OkStatus();
}

absl::Status AsStrided::CheckSpecFitsInput(
const torch::lazy::Value& input) const {
const xla::Shape& shape = GetXlaShape(input);
XLA_RETURN_IF_ERROR(CheckSpecFitsInputImpl(
shape, xla::ShapeUtil::ElementsIn(shape), GetSpecElementCount()));
return absl::OkStatus();
}

absl::Status AsStrided::CheckSpecFitsInputImpl(
const xla::Shape& input_shape, int64_t input_element_count,
int64_t spec_element_count) const {
if (input_element_count < storage_offset_ + spec_element_count) {
return XLA_ERROR_WITH_LOCATION(absl::InternalError(absl::StrCat(
"as_strided(): expected input ", input_shape.ToString(),
" (elements=", input_element_count,
") to have enough elements to fit the given spec of size=[",
absl::StrJoin(size_, /* separator= */ ", "), "], stride=[",
absl::StrJoin(stride_, /* separator= */ ", "), "], and storage_offset=",
storage_offset_, " (elements=", spec_element_count, ")")));
}
return absl::OkStatus();
}

bool AsStrided::StrideIsSupported(const xla::Shape& input_shape,
absl::Span<const int64_t> size,
absl::Span<const int64_t> stride,
int64_t storage_offset) {
bool AsStridedIsSupported(const xla::Shape& input_shape,
absl::Span<const int64_t> size,
absl::Span<const int64_t> stride,
int64_t storage_offset) {
std::vector<int64_t> sorted_stride(stride.begin(), stride.end());
std::sort(sorted_stride.begin(), sorted_stride.end());
return stride.empty() || sorted_stride.front() == 1;
}

std::vector<int64_t> AsStrided::GetArrayStridePermutation(
absl::Span<const int64_t> stride, absl::Span<const int64_t> size) {
std::vector<int64_t> permutation = torch::lazy::Iota<int64_t>(stride.size());
std::vector<int64_t> GetDescendingOrderPermutation(
absl::Span<const int64_t> v) {
std::vector<int64_t> permutation(v.size());
std::iota(permutation.begin(), permutation.end(), 0);
std::sort(permutation.begin(), permutation.end(),
[&](int64_t a, int64_t b) { return stride[a] > stride[b]; });
[&](int64_t a, int64_t b) { return v[a] > v[b]; });
return permutation;
}

Expand Down
53 changes: 34 additions & 19 deletions torch_xla/csrc/ops/as_strided.h
Original file line number Diff line number Diff line change
@@ -1,45 +1,60 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_AS_STRIDED_H_
#define XLA_TORCH_XLA_CSRC_OPS_AS_STRIDED_H_

#include <cstdint>
#include <string>
#include <vector>

#include "xla/types.h"
#include <torch/csrc/lazy/core/ir.h>

#include "absl/status/statusor.h"
#include "absl/types/span.h"

#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/lowering_context.h"

namespace torch_xla {

class AsStrided : public XlaNode {
public:
AsStrided(const torch::lazy::Value& input, std::vector<int64_t> size,
std::vector<int64_t> stride, int64_t storage_offset);
AsStrided(const torch::lazy::Value& input, const std::vector<int64_t>& size,
const std::vector<int64_t>& stride, int64_t storage_offset);

std::string ToString() const override;

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

const std::vector<int64_t>& size() const { return size_; }

const std::vector<int64_t>& stride() const { return stride_; }

int64_t storage_offset() const { return storage_offset_; }

static bool StrideIsSupported(const xla::Shape& input_shape,
absl::Span<const int64_t> size,
absl::Span<const int64_t> stride,
int64_t storage_offset);

static std::vector<int64_t> GetArrayStridePermutation(
absl::Span<const int64_t> stride, absl::Span<const int64_t> size);
absl::StatusOr<XlaOpVector> SafeLower(LoweringContext* loctx) const override;

private:
std::vector<int64_t> size_;
std::vector<int64_t> stride_;
int64_t storage_offset_;

// Convenient function for retrieving the number of elements given by `size_`.
int64_t GetSpecElementCount() const;
// Check that the given as_strided arguments (i.e. spec) are actually within
// the input tensor bounds. i.e. whether it's actually possible to retrieve a
// non-overlapping tensor given by spec from the input tensor.
absl::Status CheckSpecFitsInputImpl(const xla::Shape& input_shape,
int64_t input_element_count,
int64_t spec_element_count) const;
// Lowering check that calls `CheckSpecFitsInputImpl()`.
absl::Status CheckSpecFitsInput(xla::XlaOp input) const;
// Tracing check that calls `CheckSpecFitsInputImpl()`.
absl::Status CheckSpecFitsInput(const torch::lazy::Value& input) const;
};

// Legacy function that checks whether the given specs are supported by this
// lowering of the `as_strided` operation.
bool IsAsStridedSupported(const xla::Shape& input_shape,
absl::Span<const int64_t> size,
absl::Span<const int64_t> stride,
int64_t storage_offset);

// Retrieves the permutation for sorting `v` in descending order.
std::vector<int64_t> GetDescendingOrderPermutation(absl::Span<const int64_t> v);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_AS_STRIDED_H_
#endif // XLA_TORCH_XLA_CSRC_OPS_AS_STRIDED_H_
3 changes: 1 addition & 2 deletions torch_xla/csrc/ops/as_strided_view_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ xla::XlaOp LowerAsStridedViewUpdate(xla::XlaOp target, xla::XlaOp input,
int64_t slice_size = torch_xla::runtime::util::Multiply<int64_t>(size);
XLA_CHECK_LE(storage_offset + input_element_count, slice_size);

std::vector<int64_t> permutation =
AsStrided::GetArrayStridePermutation(stride, input_shape.dimensions());
std::vector<int64_t> permutation = GetDescendingOrderPermutation(stride);
xla::XlaOp transposed_input = xla::IsIdentityPermutation(permutation)
? input
: xla::Transpose(input, permutation);
Expand Down
Loading