11#ifndef XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_
22#define XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_
33
4+ #include < c10/util/ArrayRef.h>
5+
6+ #include < cstddef>
7+ #include < cstdint>
48#include < memory>
59#include < optional>
610#include < string>
7- #include < string_view>
811#include < unordered_map>
9- #include < utility >
12+ #include < unordered_set >
1013#include < vector>
1114
1215#include < torch/csrc/lazy/backend/backend_data.h>
16+ #include < torch/csrc/lazy/backend/backend_device.h>
1317#include < torch/csrc/lazy/backend/lowering_context.h>
18+ #include < torch/csrc/lazy/core/ir.h>
19+ #include < torch/csrc/lazy/core/ir_metadata.h>
1420#include < torch/csrc/lazy/core/ir_util.h>
1521
1622#include " absl/status/status.h"
17- #include " absl/types/span.h"
18- #include " tsl/platform/macros.h"
23+ #include " absl/status/statusor.h"
1924#include " xla/hlo/builder/xla_builder.h"
20- #include " xla/types .h"
25+ #include " xla/hlo/builder/xla_computation .h"
2126
22- #include " torch_xla/csrc/device.h"
2327#include " torch_xla/csrc/ir.h"
24- #include " torch_xla/csrc/runtime/computation_client.h"
2528
2629namespace torch_xla {
2730
@@ -74,10 +77,23 @@ class LoweringContext : public torch::lazy::LoweringContext {
7477 // operands among the emitted outputs.
7578 void AssignOutputOp (const torch::lazy::Output& output, xla::XlaOp op);
7679
77- // Retrieves the lowered operation for a output. If the requested output is
78- // not available yet, the graph behind the output's XlaNode is lowered, and
79- // the corresponding XLA operation returned.
80- xla::XlaOp GetOutputOp (const torch::lazy::Output& output);
80+ // Retrieves the lowered operation for a output.
81+ //
82+ // If the requested output is not available yet, the graph behind the output's
83+ // XlaNode is lowered, and the corresponding XLA operation returned.
84+ [[deprecated(" Use SafeGetOutputOp for better error handling." )]] xla::XlaOp
85+ GetOutputOp (const torch::lazy::Output& output);
86+ // Retrieves the lowered operation for a output.
87+ //
88+ // If the requested output is not available yet, the graph behind the output's
89+ // XlaNode is lowered, and the corresponding XLA operation returned.
90+ //
91+ // This function shall return an error status if the lowering the underlying
92+ // `output`, or any other dependent nodes, returns an error status.
93+ // Additionally, it might abort if after the lowering of `output` and its
94+ // dependent nodes, the lowered node for `output` is not available, i.e. not
95+ // in `emitted_outputs_`.
96+ absl::StatusOr<xla::XlaOp> SafeGetOutputOp (const torch::lazy::Output& output);
8197
8298 // Build the XLA computation capturing all the operations created with the
8399 // embedded XLA builder (returned by the builder() API).
@@ -110,7 +126,7 @@ class LoweringContext : public torch::lazy::LoweringContext {
110126
111127 torch::lazy::ComputationPtr Build () override ;
112128
113- const OutputMap<xla::XlaOp> GetEmittedOutputs () const {
129+ const torch::lazy:: OutputMap<xla::XlaOp> GetEmittedOutputs () const {
114130 return emitted_outputs_;
115131 }
116132
@@ -124,11 +140,15 @@ class LoweringContext : public torch::lazy::LoweringContext {
124140 size_t index = 0 ;
125141 };
126142
143+ // Checks whether the given output is already emitted. In other words, whether
144+ // we can find it inside `emitted_outputs_`.
145+ absl::Status CheckOutputIsEmitted (const torch::lazy::Output& output) const ;
146+
127147 xla::XlaBuilder builder_;
128148 std::unordered_map<torch::lazy::BackendData::Handle, Parameter>
129149 parameters_map_;
130150 std::vector<xla::XlaOp> root_tuple_;
131- OutputMap<xla::XlaOp> emitted_outputs_;
151+ torch::lazy:: OutputMap<xla::XlaOp> emitted_outputs_;
132152 std::string name_;
133153
134154 std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
0 commit comments