diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index c63b772ac45..37dc9bee2f0 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -17,6 +17,7 @@ #include "torch_xla/csrc/runtime/cache.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/status.h" namespace torch_xla { namespace { diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index eea82a8fe2b..36345e72b88 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -33,10 +33,6 @@ class LoweringContext; using XlaOpVector = absl::InlinedVector; -template -using OutputMap = - std::unordered_map; - void DetectDynamicShape(torch::lazy::NodePtr node); template diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 3b557dffa82..14b812d8d3d 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -1,24 +1,35 @@ #include "torch_xla/csrc/lowering_context.h" +#include +#include #include #include -#include -#include +#include #include #include - +#include + +#include +#include +#include +#include +#include +#include #include +#include -#include "absl/log/absl_check.h" -#include "absl/log/absl_log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/stack_frame_index_builder.h" @@ -242,21 +253,23 @@ void LoweringContext::AssignOutputOp(const torch::lazy::Output& output, } xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) { - auto it = emitted_outputs_.find(output); + XLA_ASSIGN_OR_THROW(xla::XlaOp op, SafeGetOutputOp(output)); + return op; +} - if (it == emitted_outputs_.end()) { - const auto post_order = +absl::StatusOr LoweringContext::SafeGetOutputOp( + const torch::lazy::Output& output) { + if (!CheckOutputIsEmitted(output).ok()) { + const std::vector post_order = torch::lazy::Util::ComputePostOrder(output.node, &emit_status_); - for (const auto* const node : post_order) { - XLA_THROW_IF_ERROR(LowerNode(*node)); + for (const torch::lazy::Node* const node : post_order) { + XLA_RETURN_IF_ERROR(LowerNode(*node)); } // At this point the output better be present, otherwise there is an issue // with the lowering code. - it = emitted_outputs_.find(output); - ABSL_CHECK(it != emitted_outputs_.end()) - << "No XLA operation emitted for output: " << output; + XLA_CHECK_OK(CheckOutputIsEmitted(output)); } - return it->second; + return emitted_outputs_.at(output); } absl::StatusOr LoweringContext::LowerNode( @@ -329,4 +342,15 @@ torch::lazy::ComputationPtr LoweringContext::Build() { builder_.name(), std::move(xla_computation), device_); } +absl::Status LoweringContext::CheckOutputIsEmitted( + const torch::lazy::Output& output) const { + torch::lazy::OutputMap::const_iterator it = + emitted_outputs_.find(output); + if (it == emitted_outputs_.end()) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("could not find output: ", output.ToString()))); + } + return absl::OkStatus(); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 75a68aac6f5..bcb0a35e6d3 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -1,27 +1,29 @@ #ifndef XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_ #define XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_ +#include +#include #include #include #include -#include #include -#include +#include #include +#include #include +#include #include +#include +#include #include #include "absl/status/status.h" -#include "absl/types/span.h" -#include "tsl/platform/macros.h" +#include "absl/status/statusor.h" #include "xla/hlo/builder/xla_builder.h" -#include "xla/types.h" +#include "xla/hlo/builder/xla_computation.h" -#include "torch_xla/csrc/device.h" #include "torch_xla/csrc/ir.h" -#include "torch_xla/csrc/runtime/computation_client.h" namespace torch_xla { @@ -74,10 +76,23 @@ class LoweringContext : public torch::lazy::LoweringContext { // operands among the emitted outputs. void AssignOutputOp(const torch::lazy::Output& output, xla::XlaOp op); - // Retrieves the lowered operation for a output. If the requested output is - // not available yet, the graph behind the output's XlaNode is lowered, and - // the corresponding XLA operation returned. - xla::XlaOp GetOutputOp(const torch::lazy::Output& output); + // Retrieves the lowered operation for a output. + // + // If the requested output is not available yet, the graph behind the output's + // XlaNode is lowered, and the corresponding XLA operation returned. + [[deprecated("Use SafeGetOutputOp for better error handling.")]] xla::XlaOp + GetOutputOp(const torch::lazy::Output& output); + // Retrieves the lowered operation for a output. + // + // If the requested output is not available yet, the graph behind the output's + // XlaNode is lowered, and the corresponding XLA operation returned. + // + // This function shall return an error status if the lowering the underlying + // `output`, or any other dependent nodes, returns an error status. + // Additionally, it might abort if after the lowering of `output` and its + // dependent nodes, the lowered node for `output` is not available, i.e. not + // in `emitted_outputs_`. + absl::StatusOr SafeGetOutputOp(const torch::lazy::Output& output); // Build the XLA computation capturing all the operations created with the // embedded XLA builder (returned by the builder() API). @@ -110,7 +125,7 @@ class LoweringContext : public torch::lazy::LoweringContext { torch::lazy::ComputationPtr Build() override; - const OutputMap GetEmittedOutputs() const { + const torch::lazy::OutputMap GetEmittedOutputs() const { return emitted_outputs_; } @@ -124,11 +139,15 @@ class LoweringContext : public torch::lazy::LoweringContext { size_t index = 0; }; + // Checks whether the given output is already emitted. In other words, whether + // we can find it inside `emitted_outputs_`. + absl::Status CheckOutputIsEmitted(const torch::lazy::Output& output) const; + xla::XlaBuilder builder_; std::unordered_map parameters_map_; std::vector root_tuple_; - OutputMap emitted_outputs_; + torch::lazy::OutputMap emitted_outputs_; std::string name_; std::shared_ptr stack_frame_index_builder_; diff --git a/torch_xla/csrc/ops/custom_call.cpp b/torch_xla/csrc/ops/custom_call.cpp index b7980f75f57..53032fd9cee 100644 --- a/torch_xla/csrc/ops/custom_call.cpp +++ b/torch_xla/csrc/ops/custom_call.cpp @@ -4,6 +4,7 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/shape_helper.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/dot_general.cpp b/torch_xla/csrc/ops/dot_general.cpp index 345bfc3a08f..a0abd21199b 100644 --- a/torch_xla/csrc/ops/dot_general.cpp +++ b/torch_xla/csrc/ops/dot_general.cpp @@ -7,6 +7,7 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/runtime/debug_macros.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/symeig.cpp b/torch_xla/csrc/ops/symeig.cpp index 996f91ba5cc..7242d8a25e7 100644 --- a/torch_xla/csrc/ops/symeig.cpp +++ b/torch_xla/csrc/ops/symeig.cpp @@ -5,6 +5,7 @@ #include "xla/hlo/builder/lib/self_adjoint_eig.h" #include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/shape_helper.h" namespace torch_xla {