diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index 4af68d0cfc..f3b522332a 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -1,4 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
+
+#include <atomic>
 #include <cstdlib>
 #include <map>
 #include <sstream>
@@ -211,7 +213,7 @@ std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {
 
 namespace detail {
 
-CompileMode& compile_mode() {
+std::atomic<CompileMode>& compile_mode() {
   auto get_val = []() {
     if (std::getenv("MLX_DISABLE_COMPILE")) {
       return CompileMode::disabled;
@@ -219,7 +221,7 @@ CompileMode& compile_mode() {
       return CompileMode::enabled;
     }
   };
-  static CompileMode compile_mode_ = get_val();
+  static std::atomic<CompileMode> compile_mode_ = get_val();
   return compile_mode_;
 }
 
@@ -384,7 +386,7 @@ class CompilerCache {
 };
 
 CompilerCache& compiler_cache() {
-  static CompilerCache compiler_cache_;
+  static thread_local CompilerCache compiler_cache_;
   return compiler_cache_;
 }
 
@@ -1133,14 +1135,15 @@ ArrayFnWithExtra compile(
           compile_dfs(entry.inputs, entry.outputs, inputs);
 
       // Simplify the tape
-      if (compile_mode() != CompileMode::no_simplify) {
+      auto mode = compile_mode().load();
+      if (mode != CompileMode::no_simplify) {
         compile_simplify(
             entry.tape, parents_map, entry.outputs, /* passes */ 3);
       }
 
       // Kernel fusion to generate Compiled primitives. The tape and
       // new outputs must be updated accordingly
-      if (compile_mode() != CompileMode::no_fuse) {
+      if (mode != CompileMode::no_fuse) {
         compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
       }
     }
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index c615b12ec5..6c82d95458 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -1461,6 +1461,14 @@ void init_transforms(nb::module_& m) {
          const nb::object& inputs,
          const nb::object& outputs,
          bool shapeless) {
+        // Make sure each thread that uses mx.compile clears its compile cache
+        // before the Python interpreter exits.
+        [[maybe_unused]] static thread_local auto clear_cache = []() {
+          auto atexit = nb::module_::import_("atexit");
+          atexit.attr("register")(
+              nb::cpp_function(&mx::detail::compile_clear_cache));
+          return true; // Value is unused; only the one-time side effect matters.
+        }(); // Invoked here so the atexit hook is registered once per thread.
         return mlx_func(
             nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}),
             fun,
@@ -1534,9 +1542,4 @@ void init_transforms(nb::module_& m) {
             A callable that recomputes intermediate states during gradient
             computation.
       )pbdoc");
-
-  // Register static Python object cleanup before the interpreter exits
-  auto atexit = nb::module_::import_("atexit");
-  atexit.attr("register")(
-      nb::cpp_function([]() { mx::detail::compile_clear_cache(); }));
 }