4 changes: 0 additions & 4 deletions torch_xla/csrc/ir.h
@@ -34,10 +34,6 @@ class LoweringContext;

using XlaOpVector = absl::InlinedVector<xla::XlaOp, 1>;

template <typename T>
using OutputMap =
std::unordered_map<torch::lazy::Output, T, torch::lazy::Output::Hasher>;

void DetectDynamicShape(torch::lazy::NodePtr node);

template <typename T, typename... Args>
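For context, the removed local alias is superseded by the upstream `torch::lazy::OutputMap`, which has the same shape, so call sites only need a namespace change. A minimal, hypothetical sketch of the migration (the variable name is illustrative; the upstream alias is assumed to come from torch/csrc/lazy/core/ir.h):

#include <torch/csrc/lazy/core/ir.h>

#include "xla/hlo/builder/xla_builder.h"

// Before this PR, members used the local alias declared in torch_xla/csrc/ir.h:
//   OutputMap<xla::XlaOp> emitted_outputs_;
// After this PR, the same declaration spells the upstream alias instead:
torch::lazy::OutputMap<xla::XlaOp> emitted_outputs;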
53 changes: 39 additions & 14 deletions torch_xla/csrc/lowering_context.cpp
@@ -1,24 +1,36 @@
#include "torch_xla/csrc/lowering_context.h"

#include <c10/util/ArrayRef.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/core/ir_util.h>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/shape.h"
#include "xla/xla_data.pb.h"

#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/stack_frame_index_builder.h"
@@ -242,21 +254,23 @@ void LoweringContext::AssignOutputOp(const torch::lazy::Output& output,
}

xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) {
auto it = emitted_outputs_.find(output);
XLA_ASSIGN_OR_THROW(xla::XlaOp op, SafeGetOutputOp(output));
return op;
}

if (it == emitted_outputs_.end()) {
const auto post_order =
absl::StatusOr<xla::XlaOp> LoweringContext::SafeGetOutputOp(
const torch::lazy::Output& output) {
if (!CheckOutputIsEmitted(output).ok()) {
const std::vector<const torch::lazy::Node*> post_order =
torch::lazy::Util::ComputePostOrder(output.node, &emit_status_);
for (const auto* const node : post_order) {
XLA_THROW_IF_ERROR(LowerNode(*node));
for (const torch::lazy::Node* const node : post_order) {
XLA_RETURN_IF_ERROR(LowerNode(*node));
}
// At this point the output better be present, otherwise there is an issue
// with the lowering code.
it = emitted_outputs_.find(output);
ABSL_CHECK(it != emitted_outputs_.end())
<< "No XLA operation emitted for output: " << output;
XLA_CHECK_OK(CheckOutputIsEmitted(output));
}
return it->second;
return emitted_outputs_.at(output);
}

absl::StatusOr<XlaOpVector> LoweringContext::LowerNode(
@@ -329,4 +343,15 @@ torch::lazy::ComputationPtr LoweringContext::Build() {
builder_.name(), std::move(xla_computation), device_);
}

absl::Status LoweringContext::CheckOutputIsEmitted(
const torch::lazy::Output& output) const {
torch::lazy::OutputMap<xla::XlaOp>::const_iterator it =
emitted_outputs_.find(output);
if (it == emitted_outputs_.end()) {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
absl::StrCat("could not find output: ", output.ToString())));
}
return absl::OkStatus();
}

} // namespace torch_xla
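As a usage note (not part of the diff), callers that previously relied on the throwing `GetOutputOp` can migrate to the status-returning `SafeGetOutputOp` and propagate failures explicitly. A minimal sketch; the `LowerOperand` helper below is hypothetical:

#include <torch/csrc/lazy/core/ir.h>

#include "absl/status/statusor.h"
#include "torch_xla/csrc/lowering_context.h"
#include "xla/hlo/builder/xla_builder.h"

// Hypothetical helper: errors from lowering `output`, or any of its
// dependencies, surface as an absl::Status instead of being thrown.
absl::StatusOr<xla::XlaOp> LowerOperand(torch_xla::LoweringContext* loctx,
                                        const torch::lazy::Output& output) {
  absl::StatusOr<xla::XlaOp> op = loctx->SafeGetOutputOp(output);
  if (!op.ok()) {
    return op.status();
  }
  return *op;
}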
46 changes: 33 additions & 13 deletions torch_xla/csrc/lowering_context.h
@@ -1,27 +1,30 @@
#ifndef XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_
#define XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_

#include <c10/util/ArrayRef.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <unordered_set>
#include <vector>

#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/core/ir_util.h>

#include "absl/status/status.h"
#include "absl/types/span.h"
#include "tsl/platform/macros.h"
#include "absl/status/statusor.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/types.h"
#include "xla/hlo/builder/xla_computation.h"

#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/runtime/computation_client.h"

namespace torch_xla {

@@ -74,10 +77,23 @@ class LoweringContext : public torch::lazy::LoweringContext {
// operands among the emitted outputs.
void AssignOutputOp(const torch::lazy::Output& output, xla::XlaOp op);

// Retrieves the lowered operation for a output. If the requested output is
// not available yet, the graph behind the output's XlaNode is lowered, and
// the corresponding XLA operation returned.
xla::XlaOp GetOutputOp(const torch::lazy::Output& output);
// Retrieves the lowered operation for an output.
//
// If the requested output is not available yet, the graph behind the output's
// XlaNode is lowered, and the corresponding XLA operation returned.
[[deprecated("Use SafeGetOutputOp for better error handling.")]] xla::XlaOp
GetOutputOp(const torch::lazy::Output& output);
// Retrieves the lowered operation for an output.
//
// If the requested output is not available yet, the graph behind the output's
// XlaNode is lowered, and the corresponding XLA operation returned.
//
// This function returns an error status if lowering the underlying
// `output`, or any of its dependent nodes, returns an error status.
// Additionally, it may abort if, after lowering `output` and its dependent
// nodes, the lowered node for `output` is still not available, i.e. not in
// `emitted_outputs_`.
absl::StatusOr<xla::XlaOp> SafeGetOutputOp(const torch::lazy::Output& output);

// Build the XLA computation capturing all the operations created with the
// embedded XLA builder (returned by the builder() API).
@@ -110,7 +126,7 @@ class LoweringContext : public torch::lazy::LoweringContext {

torch::lazy::ComputationPtr Build() override;

const OutputMap<xla::XlaOp> GetEmittedOutputs() const {
const torch::lazy::OutputMap<xla::XlaOp> GetEmittedOutputs() const {
return emitted_outputs_;
}

@@ -124,11 +140,15 @@ class LoweringContext : public torch::lazy::LoweringContext {
size_t index = 0;
};

// Checks whether the given output is already emitted. In other words, whether
// we can find it inside `emitted_outputs_`.
absl::Status CheckOutputIsEmitted(const torch::lazy::Output& output) const;

xla::XlaBuilder builder_;
std::unordered_map<torch::lazy::BackendData::Handle, Parameter>
parameters_map_;
std::vector<xla::XlaOp> root_tuple_;
OutputMap<xla::XlaOp> emitted_outputs_;
torch::lazy::OutputMap<xla::XlaOp> emitted_outputs_;
std::string name_;

std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
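A small illustration of the `GetEmittedOutputs` return-type change above: iterating the map is unchanged apart from the alias's namespace. The function below is illustrative only and assumes the getter is publicly accessible:

#include <cstddef>

#include <torch/csrc/lazy/core/ir.h>

#include "torch_xla/csrc/lowering_context.h"
#include "xla/hlo/builder/xla_builder.h"

// Illustrative only: counts the outputs that have already been lowered.
std::size_t CountEmittedOutputs(const torch_xla::LoweringContext& loctx) {
  // The getter now returns torch::lazy::OutputMap<xla::XlaOp> (effectively a
  // std::unordered_map keyed by torch::lazy::Output, matching the removed
  // local alias).
  const torch::lazy::OutputMap<xla::XlaOp> outputs = loctx.GetEmittedOutputs();
  return outputs.size();
}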
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/custom_call.cpp
@@ -4,6 +4,7 @@

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/shape_helper.h"

namespace torch_xla {
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/dot_general.cpp
@@ -8,6 +8,7 @@
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/debug_macros.h"

namespace torch_xla {

1 change: 1 addition & 0 deletions torch_xla/csrc/ops/symeig.cpp
@@ -5,6 +5,7 @@
#include "xla/hlo/builder/lib/self_adjoint_eig.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/shape_helper.h"

namespace torch_xla {