From a05ff11f16d0a24c6ef919a81fbf71305936ad83 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 26 Nov 2025 15:02:04 -0300 Subject: [PATCH 1/5] Improve error handling of LoweringContext::GetOutputOp --- torch_xla/csrc/lowering_context.cpp | 53 +++++++++++++++++++++-------- torch_xla/csrc/lowering_context.h | 46 ++++++++++++++++++------- torch_xla/csrc/ops/custom_call.cpp | 1 + torch_xla/csrc/ops/dot_general.cpp | 1 + torch_xla/csrc/ops/symeig.cpp | 1 + 5 files changed, 75 insertions(+), 27 deletions(-) diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 3b557dffa823..f2ea4d6ea458 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -1,24 +1,36 @@ #include "torch_xla/csrc/lowering_context.h" +#include + +#include +#include #include #include -#include -#include +#include #include #include +#include +#include +#include +#include +#include +#include #include +#include -#include "absl/log/absl_check.h" -#include "absl/log/absl_log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/stack_frame_index_builder.h" @@ -242,21 +254,23 @@ void LoweringContext::AssignOutputOp(const torch::lazy::Output& output, } xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) { - auto it = emitted_outputs_.find(output); + XLA_ASSIGN_OR_THROW(xla::XlaOp op, SafeGetOutputOp(output)); + return op; +} - if (it == emitted_outputs_.end()) { - const auto post_order = +absl::StatusOr 
LoweringContext::SafeGetOutputOp( + const torch::lazy::Output& output) { + if (CheckOutputIsEmitted(output).ok()) { + const std::vector post_order = torch::lazy::Util::ComputePostOrder(output.node, &emit_status_); - for (const auto* const node : post_order) { - XLA_THROW_IF_ERROR(LowerNode(*node)); + for (const torch::lazy::Node* const node : post_order) { + XLA_RETURN_IF_ERROR(LowerNode(*node)); } // At this point the output better be present, otherwise there is an issue // with the lowering code. - it = emitted_outputs_.find(output); - ABSL_CHECK(it != emitted_outputs_.end()) - << "No XLA operation emitted for output: " << output; + XLA_CHECK_OK(CheckOutputIsEmitted(output)); } - return it->second; + return emitted_outputs_.at(output); } absl::StatusOr LoweringContext::LowerNode( @@ -329,4 +343,15 @@ torch::lazy::ComputationPtr LoweringContext::Build() { builder_.name(), std::move(xla_computation), device_); } +absl::Status LoweringContext::CheckOutputIsEmitted( + const torch::lazy::Output& output) const { + torch::lazy::OutputMap::const_iterator it = + emitted_outputs_.find(output); + if (it == emitted_outputs_.end()) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("could not find output: ", output.ToString()))); + } + return absl::OkStatus(); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 75a68aac6f51..699ec41bfc5e 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -1,27 +1,30 @@ #ifndef XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_ #define XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_ +#include + +#include +#include #include #include #include -#include #include -#include +#include #include #include +#include #include +#include +#include #include #include "absl/status/status.h" -#include "absl/types/span.h" -#include "tsl/platform/macros.h" +#include "absl/status/statusor.h" #include "xla/hlo/builder/xla_builder.h" -#include 
"xla/types.h" +#include "xla/hlo/builder/xla_computation.h" -#include "torch_xla/csrc/device.h" #include "torch_xla/csrc/ir.h" -#include "torch_xla/csrc/runtime/computation_client.h" namespace torch_xla { @@ -74,10 +77,23 @@ class LoweringContext : public torch::lazy::LoweringContext { // operands among the emitted outputs. void AssignOutputOp(const torch::lazy::Output& output, xla::XlaOp op); - // Retrieves the lowered operation for a output. If the requested output is - // not available yet, the graph behind the output's XlaNode is lowered, and - // the corresponding XLA operation returned. - xla::XlaOp GetOutputOp(const torch::lazy::Output& output); + // Retrieves the lowered operation for a output. + // + // If the requested output is not available yet, the graph behind the output's + // XlaNode is lowered, and the corresponding XLA operation returned. + [[deprecated("Use SafeGetOutputOp for better error handling.")]] xla::XlaOp + GetOutputOp(const torch::lazy::Output& output); + // Retrieves the lowered operation for a output. + // + // If the requested output is not available yet, the graph behind the output's + // XlaNode is lowered, and the corresponding XLA operation returned. + // + // This function shall return an error status if the lowering the underlying + // `output`, or any other dependent nodes, returns an error status. + // Additionally, it might abort if after the lowering of `output` and its + // dependent nodes, the lowered node for `output` is not available, i.e. not + // in `emitted_outputs_`. + absl::StatusOr SafeGetOutputOp(const torch::lazy::Output& output); // Build the XLA computation capturing all the operations created with the // embedded XLA builder (returned by the builder() API). 
@@ -110,7 +126,7 @@ class LoweringContext : public torch::lazy::LoweringContext { torch::lazy::ComputationPtr Build() override; - const OutputMap GetEmittedOutputs() const { + const torch::lazy::OutputMap GetEmittedOutputs() const { return emitted_outputs_; } @@ -124,11 +140,15 @@ class LoweringContext : public torch::lazy::LoweringContext { size_t index = 0; }; + // Checks whether the given output is already emitted. In other words, whether + // we can find it inside `emitted_outputs_`. + absl::Status CheckOutputIsEmitted(const torch::lazy::Output& output) const; + xla::XlaBuilder builder_; std::unordered_map parameters_map_; std::vector root_tuple_; - OutputMap emitted_outputs_; + torch::lazy::OutputMap emitted_outputs_; std::string name_; std::shared_ptr stack_frame_index_builder_; diff --git a/torch_xla/csrc/ops/custom_call.cpp b/torch_xla/csrc/ops/custom_call.cpp index b7980f75f573..53032fd9cee4 100644 --- a/torch_xla/csrc/ops/custom_call.cpp +++ b/torch_xla/csrc/ops/custom_call.cpp @@ -4,6 +4,7 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/shape_helper.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/dot_general.cpp b/torch_xla/csrc/ops/dot_general.cpp index 345bfc3a08ff..a0abd21199b3 100644 --- a/torch_xla/csrc/ops/dot_general.cpp +++ b/torch_xla/csrc/ops/dot_general.cpp @@ -7,6 +7,7 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/runtime/debug_macros.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/symeig.cpp b/torch_xla/csrc/ops/symeig.cpp index 996f91ba5cc2..7242d8a25e79 100644 --- a/torch_xla/csrc/ops/symeig.cpp +++ b/torch_xla/csrc/ops/symeig.cpp @@ -5,6 +5,7 @@ #include "xla/hlo/builder/lib/self_adjoint_eig.h" #include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/runtime/debug_macros.h" 
#include "torch_xla/csrc/shape_helper.h" namespace torch_xla { From 51c7bb43eda554a0aad8e1d55e40ed509e595170 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 26 Nov 2025 15:20:10 -0300 Subject: [PATCH 2/5] Remove OutputMap declaration. --- torch_xla/csrc/ir.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index eea82a8fe2bd..36345e72b88b 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -33,10 +33,6 @@ class LoweringContext; using XlaOpVector = absl::InlinedVector; -template -using OutputMap = - std::unordered_map; - void DetectDynamicShape(torch::lazy::NodePtr node); template From f5f186e2975a49cc08491f762c85cb56ac4cb046 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Dec 2025 09:54:18 -0300 Subject: [PATCH 3/5] Rebase and fix lint issues. --- torch_xla/csrc/lowering_context.cpp | 3 +-- torch_xla/csrc/lowering_context.h | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index f2ea4d6ea458..b00afb03f392 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -1,7 +1,5 @@ #include "torch_xla/csrc/lowering_context.h" -#include - #include #include #include @@ -11,6 +9,7 @@ #include #include +#include #include #include #include diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 699ec41bfc5e..bcb0a35e6d31 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -1,8 +1,6 @@ #ifndef XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_ #define XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_ -#include - #include #include #include @@ -12,6 +10,7 @@ #include #include +#include #include #include #include From c4a25468eedb2718c36693680e24357c477feb8a Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Dec 2025 13:31:01 -0300 Subject: [PATCH 4/5] Add include. 
--- torch_xla/csrc/ir.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index c63b772ac45d..37dc9bee2f05 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -17,6 +17,7 @@ #include "torch_xla/csrc/runtime/cache.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/status.h" namespace torch_xla { namespace { From cdaf1c8f1ed7d3348cc1b8778bc48eea5e4afca3 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Dec 2025 14:32:55 -0300 Subject: [PATCH 5/5] Fix condition. --- torch_xla/csrc/lowering_context.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index b00afb03f392..14b812d8d3d3 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -259,7 +259,7 @@ xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) { absl::StatusOr LoweringContext::SafeGetOutputOp( const torch::lazy::Output& output) { - if (CheckOutputIsEmitted(output).ok()) { + if (!CheckOutputIsEmitted(output).ok()) { const std::vector post_order = torch::lazy::Util::ComputePostOrder(output.node, &emit_status_); for (const torch::lazy::Node* const node : post_order) {