You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
// Port from pytorch, original license: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/LICENSE
11
+
// Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/core/TransformationHelper.h, for uniform_real
12
+
// Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/native/cpu/DistributionTemplates.h, for normal_kernel/normal_fill/normal_fill_16
13
+
// Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/core/MT19937RNGEngine.h, for mt19937_engine
14
+
// Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/core/DistributionsHelper.h, for uniform_real_distribution/normal_distribution
15
+
classMT19937RNG : publicRNG {
16
+
staticconstint N = 624;
17
+
staticconstint M = 397;
18
+
staticconstuint32_t MATRIX_A = 0x9908b0dfU;
19
+
staticconstuint32_t UMASK = 0x80000000U;
20
+
staticconstuint32_t LMASK = 0x7fffffffU;
21
+
22
+
structState {
23
+
uint64_t seed_;
24
+
int left_;
25
+
bool seeded_;
26
+
uint32_t next_;
27
+
std::array<uint32_t, N> state_;
28
+
bool has_next_gauss = false;
29
+
double next_gauss = 0.0f;
30
+
};
31
+
32
+
State s;
33
+
34
+
uint32_tmix_bits(uint32_t u, uint32_t v) { return (u & UMASK) | (v & LMASK); }
0 commit comments