diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index ce6012c1073..c163ac07491 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1618,6 +1618,17 @@ void InitXlaModuleBindings(py::module m) {
            []() {
              return XLAGraphExecutor::Get()->IsComputationCacheInitialized();
            })
+      .def("_xla_computation_cache_clear",
+           []() {
+             // Wait for any in-flight computations first: they may still
+             // hold references to cached computations.
+             WaitDeviceOps();
+             XLAGraphExecutor::ComputationCache* cache =
+                 XLAGraphExecutor::Get()->GetComputationCache();
+             if (cache != nullptr) {
+               cache->Clear();
+             }
+           })
       .def("_get_git_revs",  //
            &GetRevisions)
       .def("_get_xla_tensor_dimension_size",
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 4e6352bc552..25c2e707af9 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -274,3 +274,16 @@ def get_num_cached_compilation_graph():
     the compilation graph will be fetched into the in-memory cache.
   """
   return torch_xla._XLAC._xla_get_num_cached_compilation_graph()
+
+
+def clear_computation_cache():
+  """Clears the XLA computation cache contents, if the cache is initialized.
+
+  Returns:
+    bool: whether the cache was cleared successfully.
+  """
+  if not torch_xla._XLAC._xla_computation_cache_is_initialized():
+    warnings.warn("Computation cache must be initialized to clear it.")
+    return False
+  torch_xla._XLAC._xla_computation_cache_clear()
+  return True