Skip to content

Commit 81e9f20

Browse files
committed
Added RawDistribution for better DistributionSampler impl self-documentation
1 parent ca0188d commit 81e9f20

File tree

5 files changed

+67
-44
lines changed

5 files changed

+67
-44
lines changed

necsim/core/src/cogs/distribution.rs

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,43 @@ impl<D: DistributionCore> Distribution for D {
3636
}
3737

3838
#[allow(clippy::module_name_repetitions)]
39-
pub trait DistributionSampler<M: MathsCore, R: RngCore, S, D: DistributionCore> {
39+
pub trait RawDistribution: DistributionCore {
40+
fn sample_raw_with<M: MathsCore, R: RngCore, S: DistributionSampler<M, R, S, Self>>(
41+
rng: &mut R,
42+
samplers: &S,
43+
params: Self::Parameters,
44+
) -> Self::Sample;
45+
46+
fn sample_raw<M: MathsCore, R: RngCore, S: DistributionSampler<M, R, S, Self>>(
47+
rng: &mut R,
48+
samplers: &S,
49+
) -> Self::Sample
50+
where
51+
Self: DistributionCore<Parameters = ()>,
52+
{
53+
Self::sample_raw_with(rng, samplers, ())
54+
}
55+
}
56+
57+
impl<D: DistributionCore> RawDistribution for D {
58+
fn sample_raw_with<M: MathsCore, R: RngCore, S: DistributionSampler<M, R, S, Self>>(
59+
rng: &mut R,
60+
samplers: &S,
61+
params: Self::Parameters,
62+
) -> Self::Sample {
63+
samplers.sample_distribution(rng, samplers, params)
64+
}
65+
}
66+
67+
#[allow(clippy::module_name_repetitions)]
68+
pub trait DistributionSampler<M: MathsCore, R: RngCore, S, D: DistributionCore + ?Sized> {
4069
type ConcreteSampler: DistributionSampler<M, R, S, D>;
4170

4271
#[must_use]
4372
fn concrete(&self) -> &Self::ConcreteSampler;
4473

4574
#[must_use]
46-
fn sample_with(&self, rng: &mut R, samplers: &S, params: D::Parameters) -> D::Sample;
47-
48-
#[must_use]
49-
fn sample(&self, rng: &mut R, samplers: &S) -> D::Sample
50-
where
51-
D: DistributionCore<Parameters = ()>,
52-
{
53-
self.sample_with(rng, samplers, ())
54-
}
75+
fn sample_distribution(&self, rng: &mut R, samplers: &S, params: D::Parameters) -> D::Sample;
5576
}
5677

5778
pub enum UniformClosedOpenUnit {}

necsim/core/src/cogs/rng.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,6 @@ where
121121
{
122122
#[must_use]
123123
fn sample_with(&mut self, params: D::Parameters) -> D::Sample {
124-
self.with(|rng, samplers| samplers.sample_with(rng, samplers, params))
124+
self.with(|rng, samplers| samplers.sample_distribution(rng, samplers, params))
125125
}
126126
}

necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/sampler/indexed/tests.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,7 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
11411141
self
11421142
}
11431143

1144-
fn sample_with(
1144+
fn sample_distribution(
11451145
&self,
11461146
rng: &mut DummyRng,
11471147
_samplers: &DummyDistributionSamplers,
@@ -1170,7 +1170,7 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
11701170
self
11711171
}
11721172

1173-
fn sample_with(
1173+
fn sample_distribution(
11741174
&self,
11751175
rng: &mut DummyRng,
11761176
_samplers: &DummyDistributionSamplers,
@@ -1199,7 +1199,7 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
11991199
self
12001200
}
12011201

1202-
fn sample_with(
1202+
fn sample_distribution(
12031203
&self,
12041204
rng: &mut DummyRng,
12051205
_samplers: &DummyDistributionSamplers,

necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/sampler/stack/tests.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
640640
self
641641
}
642642

643-
fn sample_with(
643+
fn sample_distribution(
644644
&self,
645645
rng: &mut DummyRng,
646646
_samplers: &DummyDistributionSamplers,
@@ -669,7 +669,7 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
669669
self
670670
}
671671

672-
fn sample_with(
672+
fn sample_distribution(
673673
&self,
674674
rng: &mut DummyRng,
675675
_samplers: &DummyDistributionSamplers,
@@ -698,7 +698,7 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
698698
self
699699
}
700700

701-
fn sample_with(
701+
fn sample_distribution(
702702
&self,
703703
rng: &mut DummyRng,
704704
_samplers: &DummyDistributionSamplers,

necsim/impls/no-std/src/cogs/rng/simple.rs

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use core::{
66
use necsim_core::cogs::{
77
distribution::{
88
Bernoulli, Exponential, IndexU128, IndexU32, IndexU64, IndexUsize, Lambda, Length, Normal,
9-
Normal2D, Poisson, StandardNormal2D, UniformClosedOpenUnit, UniformOpenClosedUnit,
9+
Normal2D, Poisson, RawDistribution, StandardNormal2D, UniformClosedOpenUnit,
10+
UniformOpenClosedUnit,
1011
},
1112
Backup, DistributionSampler, MathsCore, Rng, RngCore,
1213
};
@@ -86,7 +87,7 @@ impl<M: MathsCore, R: RngCore, S> DistributionSampler<M, R, S, UniformClosedOpen
8687
self
8788
}
8889

89-
fn sample_with(&self, rng: &mut R, _samplers: &S, _params: ()) -> ClosedOpenUnitF64 {
90+
fn sample_distribution(&self, rng: &mut R, _samplers: &S, _params: ()) -> ClosedOpenUnitF64 {
9091
// http://prng.di.unimi.it -> Generating uniform doubles in the unit interval
9192
#[allow(clippy::cast_precision_loss)]
9293
let u01 = ((rng.sample_u64() >> 11) as f64) * f64::from_bits(0x3CA0_0000_0000_0000_u64); // 0x1.0p-53
@@ -104,7 +105,7 @@ impl<M: MathsCore, R: RngCore, S> DistributionSampler<M, R, S, UniformOpenClosed
104105
self
105106
}
106107

107-
fn sample_with(&self, rng: &mut R, _samplers: &S, _params: ()) -> OpenClosedUnitF64 {
108+
fn sample_distribution(&self, rng: &mut R, _samplers: &S, _params: ()) -> OpenClosedUnitF64 {
108109
// http://prng.di.unimi.it -> Generating uniform doubles in the unit interval
109110
#[allow(clippy::cast_precision_loss)]
110111
let u01 =
@@ -123,10 +124,15 @@ impl<M: MathsCore, R: RngCore, S: DistributionSampler<M, R, S, UniformClosedOpen
123124
self
124125
}
125126

126-
fn sample_with(&self, rng: &mut R, samplers: &S, params: Length<NonZeroUsize>) -> usize {
127+
fn sample_distribution(
128+
&self,
129+
rng: &mut R,
130+
samplers: &S,
131+
params: Length<NonZeroUsize>,
132+
) -> usize {
127133
let length = params.0;
128134

129-
let u01: ClosedOpenUnitF64 = samplers.sample(rng, samplers);
135+
let u01 = UniformClosedOpenUnit::sample_raw(rng, samplers);
130136

131137
#[allow(
132138
clippy::cast_precision_loss,
@@ -149,10 +155,10 @@ impl<M: MathsCore, R: RngCore, S: DistributionSampler<M, R, S, UniformClosedOpen
149155
self
150156
}
151157

152-
fn sample_with(&self, rng: &mut R, samplers: &S, params: Length<NonZeroU32>) -> u32 {
158+
fn sample_distribution(&self, rng: &mut R, samplers: &S, params: Length<NonZeroU32>) -> u32 {
153159
let length = params.0;
154160

155-
let u01: ClosedOpenUnitF64 = samplers.sample(rng, samplers);
161+
let u01 = UniformClosedOpenUnit::sample_raw(rng, samplers);
156162

157163
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
158164
let index = M::floor(u01.get() * f64::from(length.get())) as u32;
@@ -171,10 +177,10 @@ impl<M: MathsCore, R: RngCore, S: DistributionSampler<M, R, S, UniformClosedOpen
171177
self
172178
}
173179

174-
fn sample_with(&self, rng: &mut R, samplers: &S, params: Length<NonZeroU64>) -> u64 {
180+
fn sample_distribution(&self, rng: &mut R, samplers: &S, params: Length<NonZeroU64>) -> u64 {
175181
let length = params.0;
176182

177-
let u01: ClosedOpenUnitF64 = samplers.sample(rng, samplers);
183+
let u01 = UniformClosedOpenUnit::sample_raw(rng, samplers);
178184

179185
#[allow(
180186
clippy::cast_precision_loss,
@@ -197,10 +203,10 @@ impl<M: MathsCore, R: RngCore, S: DistributionSampler<M, R, S, UniformClosedOpen
197203
self
198204
}
199205

200-
fn sample_with(&self, rng: &mut R, samplers: &S, params: Length<NonZeroU128>) -> u128 {
206+
fn sample_distribution(&self, rng: &mut R, samplers: &S, params: Length<NonZeroU128>) -> u128 {
201207
let length = params.0;
202208

203-
let u01: ClosedOpenUnitF64 = samplers.sample(rng, samplers);
209+
let u01 = UniformClosedOpenUnit::sample_raw(rng, samplers);
204210

205211
#[allow(
206212
clippy::cast_precision_loss,
@@ -223,10 +229,10 @@ impl<M: MathsCore, R: RngCore, S: DistributionSampler<M, R, S, UniformOpenClosed
223229
self
224230
}
225231

226-
fn sample_with(&self, rng: &mut R, samplers: &S, params: Lambda) -> NonNegativeF64 {
232+
fn sample_distribution(&self, rng: &mut R, samplers: &S, params: Lambda) -> NonNegativeF64 {
227233
let lambda = params.0;
228234

229-
let u01: OpenClosedUnitF64 = samplers.sample(rng, samplers);
235+
let u01 = UniformOpenClosedUnit::sample_raw(rng, samplers);
230236

231237
// Inverse transform sample: X = -ln(U(0,1]) / lambda
232238
-u01.ln::<M>() / lambda
@@ -246,15 +252,14 @@ impl<
246252
self
247253
}
248254

249-
fn sample_with(&self, rng: &mut R, samplers: &S, params: Lambda) -> u64 {
255+
fn sample_distribution(&self, rng: &mut R, samplers: &S, params: Lambda) -> u64 {
250256
let lambda = params.0;
251257
let no_event_probability = M::exp(-lambda.get());
252258

253259
if no_event_probability <= 0.0_f64 {
254260
// Fallback in case no_event_probability_per_step underflows
255261
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
256-
let normal_as_poisson = DistributionSampler::<M, R, S, Normal2D>::sample_with(
257-
samplers,
262+
let normal_as_poisson = Normal2D::sample_raw_with(
258263
rng,
259264
samplers,
260265
Normal {
@@ -273,8 +278,7 @@ impl<
273278
let mut prod = no_event_probability;
274279
let mut acc = no_event_probability;
275280

276-
let u =
277-
DistributionSampler::<M, R, S, UniformClosedOpenUnit>::sample(samplers, rng, samplers);
281+
let u = UniformClosedOpenUnit::sample_raw(rng, samplers);
278282

279283
#[allow(clippy::cast_precision_loss)]
280284
while u > acc && prod > 0.0_f64 {
@@ -296,10 +300,10 @@ impl<M: MathsCore, R: RngCore, S: DistributionSampler<M, R, S, UniformClosedOpen
296300
self
297301
}
298302

299-
fn sample_with(&self, rng: &mut R, samplers: &S, params: ClosedUnitF64) -> bool {
303+
fn sample_distribution(&self, rng: &mut R, samplers: &S, params: ClosedUnitF64) -> bool {
300304
let probability = params;
301305

302-
let u01: ClosedOpenUnitF64 = samplers.sample(rng, samplers);
306+
let u01 = UniformClosedOpenUnit::sample_raw(rng, samplers);
303307

304308
// if probability == 1, then U[0, 1) always < 1.0
305309
// if probability == 0, then U[0, 1) never < 0.0
@@ -320,12 +324,10 @@ impl<
320324
self
321325
}
322326

323-
fn sample_with(&self, rng: &mut R, samplers: &S, _params: ()) -> (f64, f64) {
327+
fn sample_distribution(&self, rng: &mut R, samplers: &S, _params: ()) -> (f64, f64) {
324328
// Basic Box-Muller transform
325-
let u0 =
326-
DistributionSampler::<M, R, S, UniformOpenClosedUnit>::sample(samplers, rng, samplers);
327-
let u1 =
328-
DistributionSampler::<M, R, S, UniformClosedOpenUnit>::sample(samplers, rng, samplers);
329+
let u0 = UniformOpenClosedUnit::sample_raw(rng, samplers);
330+
let u1 = UniformClosedOpenUnit::sample_raw(rng, samplers);
329331

330332
let r = M::sqrt(-2.0_f64 * M::ln(u0.get()));
331333
let theta = -core::f64::consts::TAU * u1.get();
@@ -343,8 +345,8 @@ impl<M: MathsCore, R: RngCore, S: DistributionSampler<M, R, S, StandardNormal2D>
343345
self
344346
}
345347

346-
fn sample_with(&self, rng: &mut R, samplers: &S, params: Normal) -> (f64, f64) {
347-
let (z0, z1) = samplers.sample(rng, samplers);
348+
fn sample_distribution(&self, rng: &mut R, samplers: &S, params: Normal) -> (f64, f64) {
349+
let (z0, z1) = StandardNormal2D::sample_raw(rng, samplers);
348350

349351
(
350352
z0 * params.sigma.get() + params.mu,

0 commit comments

Comments
 (0)