diff --git a/stdlib/Random/src/MersenneTwister.jl b/stdlib/Random/src/MersenneTwister.jl index 7caa75ddcd0a7..8cd8a03f1c4c8 100644 --- a/stdlib/Random/src/MersenneTwister.jl +++ b/stdlib/Random/src/MersenneTwister.jl @@ -9,9 +9,9 @@ const MT_CACHE_I = 501 << 4 # number of bytes in the UInt128 cache mutable struct MersenneTwister <: AbstractRNG seed::Any - state::DSFMT_state - vals::Vector{Float64} - ints::Vector{UInt128} + const state::DSFMT_state + const vals::Memory{Float64} + const ints::Vector{UInt128} # it's temporarily resized internally idxF::Int idxI::Int @@ -21,25 +21,13 @@ mutable struct MersenneTwister <: AbstractRNG adv_vals::Int64 # state of advance when vals is filled-up adv_ints::Int64 # state of advance when ints is filled-up - function MersenneTwister(seed, state, vals, ints, idxF, idxI, - adv, adv_jump, adv_vals, adv_ints) - length(vals) == MT_CACHE_F && 0 <= idxF <= MT_CACHE_F || - throw(DomainError((length(vals), idxF), - "`length(vals)` and `idxF` must be consistent with $MT_CACHE_F")) - length(ints) == MT_CACHE_I >> 4 && 0 <= idxI <= MT_CACHE_I || - throw(DomainError((length(ints), idxI), - "`length(ints)` and `idxI` must be consistent with $MT_CACHE_I")) - new(seed, state, vals, ints, idxF, idxI, - adv, adv_jump, adv_vals, adv_ints) - end + global _MersenneTwister(::UndefInitializer) = + new(nothing, DSFMT_state(), + Memory{Float64}(undef, MT_CACHE_F), + Vector{UInt128}(undef, MT_CACHE_I >> 4), + MT_CACHE_F, 0, 0, Base.GMP.ZERO, -1, -1) end -MersenneTwister(seed, state::DSFMT_state) = - MersenneTwister(seed, state, - Vector{Float64}(undef, MT_CACHE_F), - Vector{UInt128}(undef, MT_CACHE_I >> 4), - MT_CACHE_F, 0, 0, 0, -1, -1) - """ MersenneTwister(seed) MersenneTwister() @@ -72,8 +60,7 @@ julia> x1 == x2 true ``` """ -MersenneTwister(seed=nothing) = - seed!(MersenneTwister(Vector{UInt32}(), DSFMT_state()), seed) +MersenneTwister(seed=nothing) = seed!(_MersenneTwister(undef), seed) function copy!(dst::MersenneTwister, src::MersenneTwister) @@ -90,10 +77,7 @@ function copy!(dst::MersenneTwister, src::MersenneTwister) dst end -copy(src::MersenneTwister) = - MersenneTwister(src.seed, copy(src.state), copy(src.vals), copy(src.ints), - src.idxF, src.idxI, src.adv, src.adv_jump, src.adv_vals, src.adv_ints) - +copy(src::MersenneTwister) = copy!(_MersenneTwister(undef), src) ==(r1::MersenneTwister, r2::MersenneTwister) = r1.seed == r2.seed && r1.state == r2.state && @@ -250,7 +234,7 @@ function initstate!(r::MersenneTwister, data::StridedVector, seed) dsfmt_init_by_array(r.state, reinterpret(UInt32, data)) reset_caches!(r) r.adv = 0 - r.adv_jump = 0 + r.adv_jump = Base.GMP.ZERO return r end @@ -561,7 +545,9 @@ end function _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X) adv = r.adv adv_jump = r.adv_jump - s = MersenneTwister(r.seed, DSFMT.dsfmt_jump(r.state, jumppoly)) + s = _MersenneTwister(undef) + s.seed = r.seed + copy!(s.state, DSFMT.dsfmt_jump(r.state, jumppoly)) reset_caches!(s) s.adv = adv s.adv_jump = adv_jump diff --git a/stdlib/Random/test/runtests.jl b/stdlib/Random/test/runtests.jl index 55cbf02f6ad9d..dad8ecbd1acda 100644 --- a/stdlib/Random/test/runtests.jl +++ b/stdlib/Random/test/runtests.jl @@ -646,18 +646,6 @@ end # MersenneTwister initialization with invalid values @test_throws DomainError DSFMT.DSFMT_state(zeros(Int32, rand(0:DSFMT.JN32-1))) -@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(), - zeros(Float64, 10), zeros(UInt128, MT_CACHE_I>>4), 0, 0, 0, 0, -1, -1) - -@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(), - zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), -1, 0, 0, 0, -1, -1) - -@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(), - zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>3), 0, 0, 0, 0, -1, -1) - -@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(), - zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), 0, -1, 0, 0, -1, -1) - # seed is private to MersenneTwister let seed = rand(UInt32, 10) r = MersenneTwister(seed)