diff --git a/mlx/random.h b/mlx/random.h index a23c25572a..0c11de6372 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -19,20 +19,20 @@ class MLX_API KeySequence { void seed(uint64_t seed); array next(); - // static default + // Each thread has its own random key to avoid race condition. static KeySequence& default_() { - static KeySequence ks(get_current_time_seed()); + static auto time_seed = []() { + auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast( + now.time_since_epoch()) + .count(); + }(); + static thread_local KeySequence ks(time_seed); return ks; } private: array key_; - static uint64_t get_current_time_seed() { - auto now = std::chrono::system_clock::now(); - return std::chrono::duration_cast( - now.time_since_epoch()) - .count(); - } }; /** Get a PRNG key from a seed. */ diff --git a/python/src/random.cpp b/python/src/random.cpp index c832c5a9ed..d944504a30 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -18,41 +18,43 @@ using namespace nb::literals; class PyKeySequence { public: - explicit PyKeySequence(uint64_t seed) { - state_.append(mx::random::key(seed)); + PyKeySequence() { + // Destroy state before the python interpreter exits. + auto atexit = nb::module_::import_("atexit"); + atexit.attr("register")(nb::cpp_function([this]() { state_.reset(); })); } void seed(uint64_t seed) { - state_[0] = mx::random::key(seed); + state()[0] = mx::random::key(seed); } mx::array next() { - auto out = mx::random::split(nb::cast(state_[0])); - state_[0] = out.first; + auto out = mx::random::split(nb::cast(state()[0])); + state()[0] = out.first; return out.second; } - nb::list state() { - return state_; - } - - void release() { - nb::gil_scoped_acquire gil; - state_.release().dec_ref(); + nb::list& state() { + if (!state_) { + static auto time_seed = []() { + auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast( + now.time_since_epoch()) + .count(); + }(); + state_ = nb::list(); + state_->append(mx::random::key(time_seed)); + } + return *state_; } private: - nb::list state_; + std::optional state_; }; PyKeySequence& default_key() { - auto get_current_time_seed = []() { - auto now = std::chrono::system_clock::now(); - return std::chrono::duration_cast( - now.time_since_epoch()) - .count(); - }; - static PyKeySequence ks(get_current_time_seed()); + // Each thread has its own random key to avoid race condition. + static thread_local PyKeySequence ks; return ks; } @@ -61,7 +63,16 @@ void init_random(nb::module_& parent_module) { "random", "mlx.core.random: functionality related to random number generation"); - m.attr("state") = default_key().state(); + m.def("__getattr__", [&](nb::handle key) -> nb::object { + // Create random.state lazily to avoid initializing device during import. + if (nb::isinstance(key) && nb::cast(key) == "state") { + return default_key().state(); + } + return nb::steal(PyErr_Format( + PyExc_AttributeError, + "Module 'random' has no attribute %R", + key.ptr())); + }); m.def( "seed", [](uint64_t seed) { default_key().seed(seed); }, @@ -510,7 +521,4 @@ void init_random(nb::module_& parent_module) { array: The generated random permutation or randomly permuted input array. )pbdoc"); - // Register static Python object cleanup before the interpreter exits - auto atexit = nb::module_::import_("atexit"); - atexit.attr("register")(nb::cpp_function([]() { default_key().release(); })); }