Skip to content

Commit a99b5c3

Browse files
committed
Improve error handling of LoweringContext::GetOutputOp
1 parent e8d46ef commit a99b5c3

File tree

5 files changed

+75
-27
lines changed

5 files changed

+75
-27
lines changed

torch_xla/csrc/lowering_context.cpp

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,36 @@
11
#include "torch_xla/csrc/lowering_context.h"
22

3+
#include <c10/util/ArrayRef.h>
4+
5+
#include <cstddef>
6+
#include <cstdint>
37
#include <memory>
48
#include <optional>
5-
#include <sstream>
6-
#include <stdexcept>
9+
#include <string>
710
#include <unordered_set>
811
#include <utility>
12+
#include <vector>
913

14+
#include <torch/csrc/lazy/backend/backend_data.h>
15+
#include <torch/csrc/lazy/backend/backend_device.h>
16+
#include <torch/csrc/lazy/backend/lowering_context.h>
17+
#include <torch/csrc/lazy/core/config.h>
18+
#include <torch/csrc/lazy/core/ir.h>
1019
#include <torch/csrc/lazy/core/ir_metadata.h>
20+
#include <torch/csrc/lazy/core/ir_util.h>
1121

12-
#include "absl/log/absl_check.h"
13-
#include "absl/log/absl_log.h"
1422
#include "absl/status/status.h"
23+
#include "absl/status/statusor.h"
1524
#include "absl/strings/str_cat.h"
1625
#include "absl/strings/str_join.h"
1726
#include "absl/strings/str_replace.h"
27+
#include "xla/hlo/builder/xla_builder.h"
28+
#include "xla/hlo/builder/xla_computation.h"
29+
#include "xla/shape.h"
30+
#include "xla/xla_data.pb.h"
1831

1932
#include "torch_xla/csrc/ir.h"
2033
#include "torch_xla/csrc/runtime/computation_client.h"
21-
#include "torch_xla/csrc/runtime/debug_macros.h"
2234
#include "torch_xla/csrc/runtime/sys_util.h"
2335
#include "torch_xla/csrc/shape_helper.h"
2436
#include "torch_xla/csrc/stack_frame_index_builder.h"
@@ -242,21 +254,23 @@ void LoweringContext::AssignOutputOp(const torch::lazy::Output& output,
242254
}
243255

244256
xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) {
245-
auto it = emitted_outputs_.find(output);
257+
XLA_ASSIGN_OR_THROW(xla::XlaOp op, SafeGetOutputOp(output));
258+
return op;
259+
}
246260

247-
if (it == emitted_outputs_.end()) {
248-
const auto post_order =
261+
absl::StatusOr<xla::XlaOp> LoweringContext::SafeGetOutputOp(
262+
const torch::lazy::Output& output) {
263+
if (CheckOutputIsEmitted(output).ok()) {
264+
const std::vector<const torch::lazy::Node*> post_order =
249265
torch::lazy::Util::ComputePostOrder(output.node, &emit_status_);
250-
for (const auto* const node : post_order) {
251-
XLA_THROW_IF_ERROR(LowerNode(*node));
266+
for (const torch::lazy::Node* const node : post_order) {
267+
XLA_RETURN_IF_ERROR(LowerNode(*node));
252268
}
253269
// At this point the output better be present, otherwise there is an issue
254270
// with the lowering code.
255-
it = emitted_outputs_.find(output);
256-
ABSL_CHECK(it != emitted_outputs_.end())
257-
<< "No XLA operation emitted for output: " << output;
271+
XLA_CHECK_OK(CheckOutputIsEmitted(output));
258272
}
259-
return it->second;
273+
return emitted_outputs_.at(output);
260274
}
261275

262276
absl::StatusOr<XlaOpVector> LoweringContext::LowerNode(
@@ -329,4 +343,15 @@ torch::lazy::ComputationPtr LoweringContext::Build() {
329343
builder_.name(), std::move(xla_computation), device_);
330344
}
331345

346+
absl::Status LoweringContext::CheckOutputIsEmitted(
347+
const torch::lazy::Output& output) const {
348+
torch::lazy::OutputMap<xla::XlaOp>::const_iterator it =
349+
emitted_outputs_.find(output);
350+
if (it == emitted_outputs_.end()) {
351+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
352+
absl::StrCat("could not find output: ", output.ToString())));
353+
}
354+
return absl::OkStatus();
355+
}
356+
332357
} // namespace torch_xla

torch_xla/csrc/lowering_context.h

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,30 @@
11
#ifndef XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_
22
#define XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_
33

4+
#include <c10/util/ArrayRef.h>
5+
6+
#include <cstddef>
7+
#include <cstdint>
48
#include <memory>
59
#include <optional>
610
#include <string>
7-
#include <string_view>
811
#include <unordered_map>
9-
#include <utility>
12+
#include <unordered_set>
1013
#include <vector>
1114

1215
#include <torch/csrc/lazy/backend/backend_data.h>
16+
#include <torch/csrc/lazy/backend/backend_device.h>
1317
#include <torch/csrc/lazy/backend/lowering_context.h>
18+
#include <torch/csrc/lazy/core/ir.h>
19+
#include <torch/csrc/lazy/core/ir_metadata.h>
1420
#include <torch/csrc/lazy/core/ir_util.h>
1521

1622
#include "absl/status/status.h"
17-
#include "absl/types/span.h"
18-
#include "tsl/platform/macros.h"
23+
#include "absl/status/statusor.h"
1924
#include "xla/hlo/builder/xla_builder.h"
20-
#include "xla/types.h"
25+
#include "xla/hlo/builder/xla_computation.h"
2126

22-
#include "torch_xla/csrc/device.h"
2327
#include "torch_xla/csrc/ir.h"
24-
#include "torch_xla/csrc/runtime/computation_client.h"
2528

2629
namespace torch_xla {
2730

@@ -74,10 +77,23 @@ class LoweringContext : public torch::lazy::LoweringContext {
7477
// operands among the emitted outputs.
7578
void AssignOutputOp(const torch::lazy::Output& output, xla::XlaOp op);
7679

77-
// Retrieves the lowered operation for a output. If the requested output is
78-
// not available yet, the graph behind the output's XlaNode is lowered, and
79-
// the corresponding XLA operation returned.
80-
xla::XlaOp GetOutputOp(const torch::lazy::Output& output);
80+
// Retrieves the lowered operation for a output.
81+
//
82+
// If the requested output is not available yet, the graph behind the output's
83+
// XlaNode is lowered, and the corresponding XLA operation returned.
84+
[[deprecated("Use SafeGetOutputOp for better error handling.")]] xla::XlaOp
85+
GetOutputOp(const torch::lazy::Output& output);
86+
// Retrieves the lowered operation for a output.
87+
//
88+
// If the requested output is not available yet, the graph behind the output's
89+
// XlaNode is lowered, and the corresponding XLA operation returned.
90+
//
91+
// This function shall return an error status if the lowering the underlying
92+
// `output`, or any other dependent nodes, returns an error status.
93+
// Additionally, it might abort if after the lowering of `output` and its
94+
// dependent nodes, the lowered node for `output` is not available, i.e. not
95+
// in `emitted_outputs_`.
96+
absl::StatusOr<xla::XlaOp> SafeGetOutputOp(const torch::lazy::Output& output);
8197

8298
// Build the XLA computation capturing all the operations created with the
8399
// embedded XLA builder (returned by the builder() API).
@@ -110,7 +126,7 @@ class LoweringContext : public torch::lazy::LoweringContext {
110126

111127
torch::lazy::ComputationPtr Build() override;
112128

113-
const OutputMap<xla::XlaOp> GetEmittedOutputs() const {
129+
const torch::lazy::OutputMap<xla::XlaOp> GetEmittedOutputs() const {
114130
return emitted_outputs_;
115131
}
116132

@@ -124,11 +140,15 @@ class LoweringContext : public torch::lazy::LoweringContext {
124140
size_t index = 0;
125141
};
126142

143+
// Checks whether the given output is already emitted. In other words, whether
144+
// we can find it inside `emitted_outputs_`.
145+
absl::Status CheckOutputIsEmitted(const torch::lazy::Output& output) const;
146+
127147
xla::XlaBuilder builder_;
128148
std::unordered_map<torch::lazy::BackendData::Handle, Parameter>
129149
parameters_map_;
130150
std::vector<xla::XlaOp> root_tuple_;
131-
OutputMap<xla::XlaOp> emitted_outputs_;
151+
torch::lazy::OutputMap<xla::XlaOp> emitted_outputs_;
132152
std::string name_;
133153

134154
std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;

torch_xla/csrc/ops/custom_call.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "torch_xla/csrc/lowering_context.h"
66
#include "torch_xla/csrc/ops/xla_ops.h"
7+
#include "torch_xla/csrc/runtime/debug_macros.h"
78
#include "torch_xla/csrc/shape_helper.h"
89

910
namespace torch_xla {

torch_xla/csrc/ops/dot_general.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "torch_xla/csrc/lowering_context.h"
99
#include "torch_xla/csrc/ops/infer_output_shape.h"
1010
#include "torch_xla/csrc/ops/xla_ops.h"
11+
#include "torch_xla/csrc/runtime/debug_macros.h"
1112

1213
namespace torch_xla {
1314

torch_xla/csrc/ops/symeig.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "xla/hlo/builder/lib/self_adjoint_eig.h"
66

77
#include "torch_xla/csrc/lowering_context.h"
8+
#include "torch_xla/csrc/runtime/debug_macros.h"
89
#include "torch_xla/csrc/shape_helper.h"
910

1011
namespace torch_xla {

0 commit comments

Comments
 (0)