Skip to content

Commit b3e8b3e

Browse files
committed
feat: add support for custom compile options in torch_xla.compile and PJRT backend
This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code.

Key changes:

* Python API
  * Added a `custom_compile_options` parameter to `torch_xla.compile` for passing compile-time options as a dict (supports `bool`, `float`, `int`, and `str` values).
  * Added a `torch_xla.set_custom_compile_options()` utility for setting compile options globally.
  * Added an internal binding `_XLAC._set_custom_compile_options()`.
* C++ Runtime
  * Added a `SetCustomCompileOptions()` virtual method to `ComputationClient` and implemented it in `PjRtComputationClient`.
  * `PjRtComputationClient` now stores `custom_compile_options_` and injects them into `xla::CompileOptions::env_option_overrides` during compilation.
  * Options are stringified before being passed to XLA for compatibility.

Motivation:
This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows.
1 parent 0f56dec commit b3e8b3e

File tree

6 files changed

+55
-1
lines changed

6 files changed

+55
-1
lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3244,6 +3244,16 @@ void InitXlaModuleBindings(py::module m) {
32443244
XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
32453245
"without a data handle or an IR.";
32463246
})
3247+
.def("_set_custom_compile_options",
     // Accepts a Python dict of {str: value} pairs and forwards them to the
     // active computation client as a string->string option map. Each value
     // is stringified via py::str, so bool/int/float/str inputs all pass
     // through. NOTE(review): Python bools stringify as "True"/"False" —
     // confirm XLA's option parsing accepts that capitalization.
     [](const py::dict& compile_options) {
       std::unordered_map<std::string, std::string> options;
       for (const auto& [key, value] : compile_options) {
         options.emplace(key.cast<std::string>(),
                         py::str(value).cast<std::string>());
       }
       runtime::GetComputationClientOrDie()->SetCustomCompileOptions(options);
     })
32473257
.def(
32483258
// from an XLA tensor to a PyCapsule.
32493259
// When consuming the PyCapsule, we should synchronize

torch_xla/csrc/runtime/computation_client.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,9 @@ class ComputationClient {
447447
// after the last ':' character of the device string.
448448
static int64_t GetDeviceOrdinal(const std::string& device);
449449

450+
virtual void SetCustomCompileOptions(
451+
const std::unordered_map<std::string, std::string>& options) = 0;
452+
450453
protected:
451454
static constexpr auto spmd_device_str = "SPMD:0";
452455

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ class IfrtComputationClient : public ComputationClient {
176176
XLA_ERROR() << __FUNCTION__ << " not implemented";
177177
}
178178

179+
  // Custom compile options are not supported by the IFRT client; calling
  // this always raises an error. The options map is intentionally unused.
  void SetCustomCompileOptions(
      const std::unordered_map<std::string, std::string>& options) override {
    XLA_ERROR() << __FUNCTION__ << " not implemented";
  }
183+
179184
// Creates a new instance of IfrtComputationClient and initializes it.
180185
static absl::StatusOr<absl_nonnull std::unique_ptr<IfrtComputationClient>>
181186
Create();

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,13 +554,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
554554

555555
for (auto& instance : instances) {
556556
xla::CompileOptions compile_options;
557+
for (auto& option : custom_compile_options_) {
558+
compile_options.env_option_overrides.push_back(
559+
{option.first, option.second});
560+
}
557561
if (enable_cm_in_mp) {
558562
compile_options.executable_build_options.set_use_spmd_partitioning(true);
559563
compile_options.env_option_overrides.push_back(
560564
{"xla_tpu_decompose_all_gather_einsum", true});
561565
compile_options.env_option_overrides.push_back(
562566
{"xla_tpu_decompose_einsum_reduce_scatter", true});
563567
}
568+
564569
if (instance.is_sharded) {
565570
// TODO(yeounoh) multi-host, multi-slice configurations
566571
compile_options.executable_build_options.set_use_spmd_partitioning(true);
@@ -1088,5 +1093,14 @@ void PjRtComputationClient::OnReadyCallback(
10881093
[callback](absl::Status unused) { callback(); });
10891094
}
10901095

1096+
void PjRtComputationClient::SetCustomCompileOptions(
1097+
const std::unordered_map<std::string, std::string>& options) {
1098+
// Stringfy values
1099+
custom_compile_options_.clear();
1100+
for (const auto& [key, value] : options) {
1101+
custom_compile_options_[key] = value;
1102+
}
1103+
}
1104+
10911105
} // namespace runtime
10921106
} // namespace torch_xla

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ class PjRtComputationClient : public ComputationClient {
174174
void OnReadyCallback(DataPtr data,
175175
const std::function<void()>& callback) override;
176176

177+
void SetCustomCompileOptions(
178+
const std::unordered_map<std::string, std::string>& options) override;
179+
177180
// Creates a new instance of PjRtComputationClient and initializes it.
178181
static absl::StatusOr<absl_nonnull std::unique_ptr<PjRtComputationClient>>
179182
Create();
@@ -206,6 +209,7 @@ class PjRtComputationClient : public ComputationClient {
206209
// If not nullptr, invoke this instead of the actual XLA compilation. Used
207210
// only for testing.
208211
std::function<absl::Status()> fake_xla_compile_ = nullptr;
212+
std::unordered_map<std::string, std::string> custom_compile_options_;
209213

210214
xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
211215

torch_xla/torch_xla.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def compile(
116116
full_graph: Optional[bool] = False,
117117
name: Optional[str] = None,
118118
max_different_graphs: Optional[int] = None,
119+
custom_compile_options: Optional[dict] = None,
119120
):
120121
"""
121122
Optimizes given model/function using torch_xla's LazyTensor tracing mode.
@@ -136,6 +137,8 @@ def compile(
136137
max_different_graphs (Optional[int]): number of different traced graphs of the given
137138
model/function that we are allowed to have. An error will be raised in case this limit
138139
is exceeded.
140+
custom_compile_options (Optional[dict]): A dictionary of custom compile options to be set.
141+
The keys are strings and the values can be of type bool, float, int, or str.
139142
140143
Example::
141144
@@ -214,7 +217,8 @@ def _compile():
214217
sync()
215218
torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
216219
torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)
217-
220+
if custom_compile_options is not None and len(custom_compile_options) > 0:
221+
torch_xla._XLAC._set_custom_compile_options(custom_compile_options)
218222
return _compile() if f is None else _compile()(f)
219223

220224

@@ -264,3 +268,17 @@ def launch(
264268
fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
265269
else:
266270
xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
271+
272+
def set_custom_compile_options(
    options: Optional[dict] = None,
):
  """Sets custom compile options for the XLA compilation.

  Passing ``None`` is equivalent to passing an empty dict.

  Args:
    options: A dictionary of custom compile options to be set.
      The keys are strings and the values can be of type bool, float, int, or str.
  """
  torch_xla._XLAC._set_custom_compile_options(
      {} if options is None else options)
284+

0 commit comments

Comments
 (0)