@@ -43,16 +43,24 @@ The current defaults are:
4343 - Julia version is < 1.7: `Random.GLOBAL_RNG`
4444 - Julia version is >= 1.7: `Random.default_rng()`
4545"""
46- _rng_from_array (:: AbstractArray ) = _rng_from_array ()
47- _rng_from_array (:: CuArray ) = CUDA. default_rng ()
46+ rng_from_array (:: AbstractArray ) = default_rng_value ()
47+ rng_from_array (:: CuArray ) = CUDA. default_rng ()
48+
4849if VERSION >= v " 1.7"
49- _rng_from_array () = Random. default_rng ()
50+ @doc """
51+ default_rng_value()
52+
53+ Create an instance of the default RNG depending on Julia's version.
54+ - Julia version is < 1.7: `Random.GLOBAL_RNG`
55+ - Julia version is >= 1.7: `Random.default_rng()`
56+ """
57+ default_rng_value () = Random. default_rng ()
5058else
51- _rng_from_array () = Random. GLOBAL_RNG
59+ default_rng_value () = Random. GLOBAL_RNG
5260end
5361
5462"""
55- glorot_uniform([rng=GLOBAL_RNG ], size...; gain = 1) -> Array
63+ glorot_uniform([rng = default_rng_value() ], size...; gain = 1) -> Array
5664 glorot_uniform([rng]; kw...) -> Function
5765
5866Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform
@@ -91,13 +99,13 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
9199 scale = Float32 (gain) * sqrt (24.0f0 / sum (nfan (dims... )))
92100 (rand (rng, Float32, dims... ) .- 0.5f0 ) .* scale
93101end
94- glorot_uniform (dims:: Integer... ; kw... ) = glorot_uniform (_rng_from_array (), dims... ; kw... )
95- glorot_uniform (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_uniform (rng, dims... ; init_kwargs... , kwargs... )
102+ glorot_uniform (dims:: Integer... ; kw... ) = glorot_uniform (default_rng_value (), dims... ; kw... )
103+ glorot_uniform (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_uniform (rng, dims... ; init_kwargs... , kwargs... )
96104
97105ChainRulesCore. @non_differentiable glorot_uniform (:: Any... )
98106
99107"""
100- glorot_normal([rng=GLOBAL_RNG] , size...; gain = 1) -> Array
108+ glorot_normal([rng = default_rng_value() , size...; gain = 1) -> Array
101109 glorot_normal([rng]; kw...) -> Function
102110
103111Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal
@@ -134,13 +142,13 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
134142 std = Float32 (gain) * sqrt (2.0f0 / sum (nfan (dims... )))
135143 randn (rng, Float32, dims... ) .* std
136144end
137- glorot_normal (dims:: Integer... ; kwargs... ) = glorot_normal (_rng_from_array (), dims... ; kwargs... )
138- glorot_normal (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_normal (rng, dims... ; init_kwargs... , kwargs... )
145+ glorot_normal (dims:: Integer... ; kwargs... ) = glorot_normal (default_rng_value (), dims... ; kwargs... )
146+ glorot_normal (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_normal (rng, dims... ; init_kwargs... , kwargs... )
139147
140148ChainRulesCore. @non_differentiable glorot_normal (:: Any... )
141149
142150"""
143- kaiming_uniform([rng=GLOBAL_RNG ], size...; gain = √2) -> Array
151+ kaiming_uniform([rng = default_rng_value() ], size...; gain = √2) -> Array
144152 kaiming_uniform([rng]; kw...) -> Function
145153
146154Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform distribution
@@ -169,13 +177,13 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real = √2)
169177 return (rand (rng, Float32, dims... ) .- 0.5f0 ) .* 2 bound
170178end
171179
172- kaiming_uniform (dims:: Integer... ; kwargs... ) = kaiming_uniform (_rng_from_array (), dims... ; kwargs... )
173- kaiming_uniform (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_uniform (rng, dims... ; init_kwargs... , kwargs... )
180+ kaiming_uniform (dims:: Integer... ; kwargs... ) = kaiming_uniform (default_rng_value (), dims... ; kwargs... )
181+ kaiming_uniform (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_uniform (rng, dims... ; init_kwargs... , kwargs... )
174182
175183ChainRulesCore. @non_differentiable kaiming_uniform (:: Any... )
176184
177185"""
178- kaiming_normal([rng=GLOBAL_RNG ], size...; gain = √2) -> Array
186+ kaiming_normal([rng = default_rng_value() ], size...; gain = √2) -> Array
179187 kaiming_normal([rng]; kw...) -> Function
180188
181189Return an `Array{Float32}` of the given `size` containing random numbers taken from a normal
@@ -206,13 +214,13 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real = √2f0)
206214 return randn (rng, Float32, dims... ) .* std
207215end
208216
209- kaiming_normal (dims:: Integer... ; kwargs... ) = kaiming_normal (_rng_from_array (), dims... ; kwargs... )
217+ kaiming_normal (dims:: Integer... ; kwargs... ) = kaiming_normal (default_rng_value (), dims... ; kwargs... )
210218kaiming_normal (rng:: AbstractRNG ; init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_normal (rng, dims... ; init_kwargs... , kwargs... )
211219
212220ChainRulesCore. @non_differentiable kaiming_normal (:: Any... )
213221
214222"""
215- truncated_normal([rng=GLOBAL_RNG ], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
223+ truncated_normal([rng = default_rng_value() ], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
216224 truncated_normal([rng]; kw...) -> Function
217225
218226Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution.
@@ -252,13 +260,13 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean = 0, std = 1,
252260 return xs
253261end
254262
255- truncated_normal (dims:: Integer... ; kwargs... ) = truncated_normal (_rng_from_array (), dims... ; kwargs... )
256- truncated_normal (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> truncated_normal (rng, dims... ; init_kwargs... , kwargs... )
263+ truncated_normal (dims:: Integer... ; kwargs... ) = truncated_normal (default_rng_value (), dims... ; kwargs... )
264+ truncated_normal (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims... ; kwargs... ) -> truncated_normal (rng, dims... ; init_kwargs... , kwargs... )
257265
258266ChainRulesCore. @non_differentiable truncated_normal (:: Any... )
259267
260268"""
261- orthogonal([rng=GLOBAL_RNG ], size...; gain = 1) -> Array
269+ orthogonal([rng = default_rng_value() ], size...; gain = 1) -> Array
262270 orthogonal([rng]; kw...) -> Function
263271
264272Return an `Array{Float32}` of the given `size` which is a (semi) orthogonal matrix, as described in [1].
@@ -313,13 +321,13 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
313321 return reshape (orthogonal (rng, rows, cols; kwargs... ), dims)
314322end
315323
316- orthogonal (dims:: Integer... ; kwargs... ) = orthogonal (_rng_from_array (), dims... ; kwargs... )
317- orthogonal (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims:: Integer... ; kwargs... ) -> orthogonal (rng, dims... ; init_kwargs... , kwargs... )
324+ orthogonal (dims:: Integer... ; kwargs... ) = orthogonal (default_rng_value (), dims... ; kwargs... )
325+ orthogonal (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims:: Integer... ; kwargs... ) -> orthogonal (rng, dims... ; init_kwargs... , kwargs... )
318326
319327ChainRulesCore. @non_differentiable orthogonal (:: Any... )
320328
321329"""
322- sparse_init([rng=GLOBAL_RNG ], rows, cols; sparsity, std = 0.01) -> Array
330+ sparse_init([rng = default_rng_value() ], rows, cols; sparsity, std = 0.01) -> Array
323331 sparse_init([rng]; kw...) -> Function
324332
325333Return a `Matrix{Float32}` of size `rows, cols` where each column contains a fixed fraction of
@@ -361,8 +369,8 @@ function sparse_init(rng::AbstractRNG, dims::Integer...; sparsity, std = 0.01)
361369 return mapslices (shuffle, sparse_array, dims= 1 )
362370end
363371
364- sparse_init (dims:: Integer... ; kwargs... ) = sparse_init (_rng_from_array (), dims... ; kwargs... )
365- sparse_init (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> sparse_init (rng, dims... ; init_kwargs... , kwargs... )
372+ sparse_init (dims:: Integer... ; kwargs... ) = sparse_init (default_rng_value (), dims... ; kwargs... )
373+ sparse_init (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims... ; kwargs... ) -> sparse_init (rng, dims... ; init_kwargs... , kwargs... )
366374
367375ChainRulesCore. @non_differentiable sparse_init (:: Any... )
368376
452460
453461# For consistency, it accepts an RNG, but ignores it:
454462identity_init (:: AbstractRNG , dims:: Integer... ; kwargs... ) = identity_init (dims... ; kwargs... )
455- identity_init (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (args... ;kwargs... ) -> identity_init (rng, args... ; init_kwargs... , kwargs... )
463+ identity_init (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (args... ;kwargs... ) -> identity_init (rng, args... ; init_kwargs... , kwargs... )
456464
457465ChainRulesCore. @non_differentiable identity_init (:: Any... )
458466
0 commit comments