From 3bc5570d894836a1dab1109eba8599ecb20daa77 Mon Sep 17 00:00:00 2001
From: James Xu
Date: Wed, 12 Nov 2025 10:24:10 -0500
Subject: [PATCH 1/2] Add torch_xla.runtime.clear_computation_cache() binding
 (#16)

* Add torch_xla.runtime.clear_computation_cache() binding

* Run linters

* Add optional workflow dispatch to build wheel from branch instead of
  torch-xla main
---
 torch_xla/csrc/init_python_bindings.cpp | 8 ++++++++
 torch_xla/runtime.py                    | 7 +++++++
 2 files changed, 15 insertions(+)

diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index ce6012c1073d..c163ac074917 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1618,6 +1618,14 @@ void InitXlaModuleBindings(py::module m) {
           []() {
             return XLAGraphExecutor::Get()->IsComputationCacheInitialized();
           })
+      .def("_xla_computation_cache_clear",
+           []() {
+             WaitDeviceOps();  // Wait for any inflight computations which may hold references to cached computations.
+             XLAGraphExecutor::ComputationCache* cache = XLAGraphExecutor::Get()->GetComputationCache();
+             if (cache != nullptr) {
+               cache->Clear();
+             }
+           })
       .def("_get_git_revs",  //
            &GetRevisions)
       .def("_get_xla_tensor_dimension_size",
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 4e6352bc5527..ae4800a79dcf 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -274,3 +274,10 @@ def get_num_cached_compilation_graph():
     the compilation graph will be fetched into the in-memory cache.
   """
   return torch_xla._XLAC._xla_get_num_cached_compilation_graph()
+
+
+def clear_computation_cache():
+  """Clears the XLA computation cache contents."""
+  assert torch_xla._XLAC._xla_computation_cache_is_initialized(
+  ), "Computation cache must be initialized to clear it."
+  torch_xla._XLAC._xla_computation_cache_clear()

From fb11fedf82f6963b63c3c9c82b4a0d92404da137 Mon Sep 17 00:00:00 2001
From: James Xu
Date: Mon, 17 Nov 2025 16:50:27 -0500
Subject: [PATCH 2/2] Warn instead of asserting when clearing computation
 cache (#17)

* Change assertion to warning

* Return bool so caller can optionally check if cache was cleared
---
 torch_xla/runtime.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index ae4800a79dcf..25c2e707af97 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -277,7 +277,12 @@ def get_num_cached_compilation_graph():
 
 
 def clear_computation_cache():
-  """Clears the XLA computation cache contents."""
-  assert torch_xla._XLAC._xla_computation_cache_is_initialized(
-  ), "Computation cache must be initialized to clear it."
+  """Clears the XLA computation cache contents, if the computation cache is initialized.
+  Returns:
+    bool: whether the cache was cleared successfully.
+  """
+  if not torch_xla._XLAC._xla_computation_cache_is_initialized():
+    warnings.warn("Computation cache must be initialized to clear it.")
+    return False
   torch_xla._XLAC._xla_computation_cache_clear()
+  return True