@@ -4,12 +4,9 @@ DEF RS_RNG_JUMPABLE = 1
4
4
cdef extern from " distributions.h" :
5
5
6
6
cdef struct s_mrg32k3a_state:
7
- int64_t s10
8
- int64_t s11
9
- int64_t s12
10
- int64_t s20
11
- int64_t s21
12
- int64_t s22
7
+ int64_t s1[3 ]
8
+ int64_t s2[3 ]
9
+ int loc
13
10
14
11
ctypedef s_mrg32k3a_state mrg32k3a_state
15
12
@@ -31,16 +28,18 @@ ctypedef mrg32k3a_state rng_t
31
28
ctypedef uint64_t rng_state_t
32
29
33
30
cdef object _get_state(aug_state state):
34
- return (state.rng.s10, state.rng.s11, state.rng.s12,
35
- state.rng.s20, state.rng.s21, state.rng.s22)
31
+ return (state.rng.s1[0 ], state.rng.s1[1 ], state.rng.s1[2 ],
32
+ state.rng.s2[0 ], state.rng.s2[1 ], state.rng.s2[2 ],
33
+ state.rng.loc)
36
34
37
35
cdef object _set_state(aug_state * state, object state_info):
38
- state.rng.s10 = state_info[0 ]
39
- state.rng.s11 = state_info[1 ]
40
- state.rng.s12 = state_info[2 ]
41
- state.rng.s20 = state_info[3 ]
42
- state.rng.s21 = state_info[4 ]
43
- state.rng.s22 = state_info[5 ]
36
+ state.rng.s1[0 ] = state_info[0 ]
37
+ state.rng.s1[1 ] = state_info[1 ]
38
+ state.rng.s1[2 ] = state_info[2 ]
39
+ state.rng.s2[0 ] = state_info[3 ]
40
+ state.rng.s2[1 ] = state_info[4 ]
41
+ state.rng.s2[2 ] = state_info[5 ]
42
+ state.rng.loc = state_info[6 ]
44
43
45
44
cdef object matrix_power_127(x, m):
46
45
n = x.shape[0 ]
@@ -68,21 +67,39 @@ A2_127 = matrix_power_127(A2p, m2)
68
67
69
68
cdef void jump_state(aug_state* state):
70
69
# vectors s1 and s2
71
- s1 = np.array([state.rng.s10,state.rng.s11,state.rng.s12], dtype = np.uint64)
72
- s2 = np.array([state.rng.s20,state.rng.s21,state.rng.s22], dtype = np.uint64)
70
+ loc = state.rng.loc
71
+
72
+ if loc == 0 :
73
+ loc_m1 = 2
74
+ loc_m2 = 1
75
+ elif loc == 1 :
76
+ loc_m1 = 0
77
+ loc_m2 = 2
78
+ else :
79
+ loc_m1 = 1
80
+ loc_m2 = 0
81
+
82
+ s1 = np.array([state.rng.s1[loc_m2],
83
+ state.rng.s1[loc_m1],
84
+ state.rng.s1[loc]], dtype = np.uint64)
85
+ s2 = np.array([state.rng.s2[loc_m2],
86
+ state.rng.s2[loc_m1],
87
+ state.rng.s2[loc]], dtype = np.uint64)
73
88
74
89
# Advance the state
75
90
s1 = np.mod(A1_127.dot(s1), m1)
76
91
s2 = np.mod(A1_127.dot(s2), m2)
77
92
78
93
# Restore state
79
- state.rng.s10 = s1[0 ]
80
- state.rng.s11 = s1[1 ]
81
- state.rng.s12 = s1[2 ]
94
+ state.rng.s1[ 0 ] = s1[0 ]
95
+ state.rng.s1[ 1 ] = s1[1 ]
96
+ state.rng.s1[ 2 ] = s1[2 ]
82
97
83
- state.rng.s20 = s2[0 ]
84
- state.rng.s21 = s2[1 ]
85
- state.rng.s22 = s2[2 ]
98
+ state.rng.s2[0 ] = s2[0 ]
99
+ state.rng.s2[1 ] = s2[1 ]
100
+ state.rng.s2[2 ] = s2[2 ]
101
+
102
+ state.rng.loc = 2
86
103
87
104
DEF CLASS_DOCSTRING = """
88
105
RandomState(seed=None)
0 commit comments