Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions mlx/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@ class MLX_API KeySequence {
void seed(uint64_t seed);
array next();

// static default
// Each thread has its own random key to avoid race condition.
static KeySequence& default_() {
static KeySequence ks(get_current_time_seed());
static auto time_seed = []() {
auto now = std::chrono::system_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
}();
static thread_local KeySequence ks(time_seed);
return ks;
}

private:
array key_;
static uint64_t get_current_time_seed() {
auto now = std::chrono::system_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
}
};

/** Get a PRNG key from a seed. */
Expand Down
56 changes: 32 additions & 24 deletions python/src/random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,43 @@ using namespace nb::literals;

class PyKeySequence {
public:
explicit PyKeySequence(uint64_t seed) {
state_.append(mx::random::key(seed));
PyKeySequence() {
// Destroy state before the python interpreter exits.
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([this]() { state_.reset(); }));
}

void seed(uint64_t seed) {
state_[0] = mx::random::key(seed);
state()[0] = mx::random::key(seed);
}

mx::array next() {
auto out = mx::random::split(nb::cast<mx::array>(state_[0]));
state_[0] = out.first;
auto out = mx::random::split(nb::cast<mx::array>(state()[0]));
state()[0] = out.first;
return out.second;
}

nb::list state() {
return state_;
}

void release() {
nb::gil_scoped_acquire gil;
state_.release().dec_ref();
nb::list& state() {
if (!state_) {
static auto time_seed = []() {
auto now = std::chrono::system_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
}();
state_ = nb::list();
state_->append(mx::random::key(time_seed));
}
return *state_;
}

private:
nb::list state_;
std::optional<nb::list> state_;
};

PyKeySequence& default_key() {
auto get_current_time_seed = []() {
auto now = std::chrono::system_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
};
static PyKeySequence ks(get_current_time_seed());
// Each thread has its own random key to avoid race condition.
static thread_local PyKeySequence ks;
return ks;
}

Expand All @@ -61,7 +63,16 @@ void init_random(nb::module_& parent_module) {
"random",
"mlx.core.random: functionality related to random number generation");

m.attr("state") = default_key().state();
m.def("__getattr__", [&](nb::handle key) -> nb::object {
// Create random.state lazily to avoid initializing device during import.
if (nb::isinstance<nb::str>(key) && nb::cast<std::string>(key) == "state") {
return default_key().state();
}
return nb::steal(PyErr_Format(
PyExc_AttributeError,
"Module 'random' has no attribute %R",
key.ptr()));
});
m.def(
"seed",
[](uint64_t seed) { default_key().seed(seed); },
Expand Down Expand Up @@ -510,7 +521,4 @@ void init_random(nb::module_& parent_module) {
array:
The generated random permutation or randomly permuted input array.
)pbdoc");
// Register static Python object cleanup before the interpreter exits
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));
}
Loading