Skip to content

Commit f6071d1

Browse files
committed
Small optimisations of rng usage
1 parent deb965b commit f6071d1

File tree

28 files changed

+230
-218
lines changed

28 files changed

+230
-218
lines changed

necsim/core/src/cogs/coalescence_sampler.rs

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
use core::cmp::{Ord, Ordering};
2-
3-
use necsim_core_bond::ClosedOpenUnitF64;
1+
use core::{
2+
cmp::{Ord, Ordering},
3+
num::NonZeroU32,
4+
};
45

56
use serde::{Deserialize, Serialize};
67

78
use crate::{
8-
cogs::{
9-
distribution::UniformClosedOpenUnit, Backup, Distribution, Habitat, LineageStore,
10-
MathsCore, Rng, Samples,
11-
},
9+
cogs::{Backup, Habitat, LineageStore, MathsCore, Rng, RngCore},
1210
landscape::{IndexedLocation, Location},
1311
lineage::LineageInteraction,
1412
};
@@ -32,7 +30,7 @@ pub trait CoalescenceSampler<M: MathsCore, H: Habitat<M>, S: LineageStore<M, H>>
3230
#[allow(clippy::unsafe_derive_deserialize)]
3331
#[derive(Debug, PartialEq, Serialize, Deserialize, TypeLayout)]
3432
#[repr(transparent)]
35-
pub struct CoalescenceRngSample(ClosedOpenUnitF64);
33+
pub struct CoalescenceRngSample(u64);
3634

3735
#[contract_trait]
3836
impl Backup for CoalescenceRngSample {
@@ -58,22 +56,20 @@ impl Eq for CoalescenceRngSample {}
5856
impl CoalescenceRngSample {
5957
#[must_use]
6058
#[inline]
61-
pub fn new<M: MathsCore, G: Rng<M> + Samples<M, UniformClosedOpenUnit>>(rng: &mut G) -> Self {
62-
Self(UniformClosedOpenUnit::sample(rng))
59+
pub fn new<M: MathsCore, G: Rng<M>>(rng: &mut G) -> Self {
60+
Self(rng.generator().sample_u64())
6361
}
6462

6563
#[must_use]
6664
#[inline]
67-
#[debug_ensures(ret < length, "samples U(0, length - 1)")]
68-
pub fn sample_coalescence_index<M: MathsCore>(self, length: u32) -> u32 {
69-
// attributes on expressions are experimental
70-
// see https://github.com/rust-lang/rust/issues/15701
71-
#[allow(
72-
clippy::cast_precision_loss,
73-
clippy::cast_possible_truncation,
74-
clippy::cast_sign_loss
75-
)]
76-
let index = M::floor(self.0.get() * f64::from(length)) as u32;
77-
index
65+
#[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
66+
pub fn sample_coalescence_index(self, length: NonZeroU32) -> u32 {
67+
// Sample U(0, length - 1) using a widening multiplication
68+
// Note: Some slight bias is traded for only needing one u64 sample
69+
// Note: Should optimise to a single 64 bit (high-only) multiplication
70+
#[allow(clippy::cast_possible_truncation)]
71+
{
72+
(((u128::from(self.0) * u128::from(length.get())) >> 64) & u128::from(!0_u32)) as u32
73+
}
7874
}
7975
}

necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/sampler/indexed/tests.rs

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,19 +1145,25 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
11451145
&self,
11461146
rng: &mut DummyRng,
11471147
_samplers: &DummyDistributionSamplers,
1148-
params: Length<NonZeroUsize>,
1148+
Length(length): Length<NonZeroUsize>,
11491149
) -> usize {
1150-
let length = params.0;
1151-
1152-
#[allow(
1153-
clippy::cast_precision_loss,
1154-
clippy::cast_possible_truncation,
1155-
clippy::cast_sign_loss
1156-
)]
1157-
let index = IntrinsicsMathsCore::floor(rng.sample_f64() * (length.get() as f64)) as usize;
1158-
1159-
// Safety in case of f64 rounding errors
1160-
index.min(length.get() - 1)
1150+
let u01 = rng.sample_f64();
1151+
1152+
// Safety: U[0, 1) * length in [0, 2^[32/64]) is a valid [u32/u64]
1153+
// since (1 - 2^-53) * 2^[32/64] <= (2^[32/64] - 1)
1154+
#[allow(clippy::cast_precision_loss)]
1155+
let index = unsafe {
1156+
IntrinsicsMathsCore::floor(u01 * (length.get() as f64)).to_int_unchecked::<usize>()
1157+
};
1158+
1159+
if cfg!(target_pointer_width = "32") {
1160+
// Note: [0, 2^32) is losslessly represented in f64
1161+
index
1162+
} else {
1163+
// Note: Ensure index < length despite
1164+
// usize->f64->usize precision loss
1165+
index.min(length.get() - 1)
1166+
}
11611167
}
11621168
}
11631169

@@ -1174,18 +1180,18 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
11741180
&self,
11751181
rng: &mut DummyRng,
11761182
_samplers: &DummyDistributionSamplers,
1177-
params: Length<NonZeroU64>,
1183+
Length(length): Length<NonZeroU64>,
11781184
) -> u64 {
1179-
let length = params.0;
1185+
let u01 = rng.sample_f64();
11801186

1181-
#[allow(
1182-
clippy::cast_precision_loss,
1183-
clippy::cast_possible_truncation,
1184-
clippy::cast_sign_loss
1185-
)]
1186-
let index = IntrinsicsMathsCore::floor(rng.sample_f64() * (length.get() as f64)) as u64;
1187+
// Safety: U[0, 1) * length in [0, 2^64) is a valid u64
1188+
// since (1 - 2^-53) * 2^64 <= (2^64 - 1)
1189+
#[allow(clippy::cast_precision_loss)]
1190+
let index = unsafe {
1191+
IntrinsicsMathsCore::floor(u01 * (length.get() as f64)).to_int_unchecked::<u64>()
1192+
};
11871193

1188-
// Safety in case of f64 rounding errors
1194+
// Note: Ensure index < length despite u64->f64->u64 precision loss
11891195
index.min(length.get() - 1)
11901196
}
11911197
}
@@ -1203,18 +1209,18 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
12031209
&self,
12041210
rng: &mut DummyRng,
12051211
_samplers: &DummyDistributionSamplers,
1206-
params: Length<NonZeroU128>,
1212+
Length(length): Length<NonZeroU128>,
12071213
) -> u128 {
1208-
let length = params.0;
1214+
let u01 = rng.sample_f64();
12091215

1210-
#[allow(
1211-
clippy::cast_precision_loss,
1212-
clippy::cast_possible_truncation,
1213-
clippy::cast_sign_loss
1214-
)]
1215-
let index = IntrinsicsMathsCore::floor(rng.sample_f64() * (length.get() as f64)) as u128;
1216+
// Safety: U[0, 1) * length in [0, 2^128) is a valid u128
1217+
// since (1 - 2^-53) * 2^128 <= (2^128 - 1)
1218+
#[allow(clippy::cast_precision_loss)]
1219+
let index = unsafe {
1220+
IntrinsicsMathsCore::floor(u01 * (length.get() as f64)).to_int_unchecked::<u128>()
1221+
};
12161222

1217-
// Safety in case of f64 rounding errors
1223+
// Note: Ensure index < length despite u128->f64->u128 precision loss
12181224
index.min(length.get() - 1)
12191225
}
12201226
}

necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/sampler/stack/tests.rs

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -644,19 +644,25 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
644644
&self,
645645
rng: &mut DummyRng,
646646
_samplers: &DummyDistributionSamplers,
647-
params: Length<NonZeroUsize>,
647+
Length(length): Length<NonZeroUsize>,
648648
) -> usize {
649-
let length = params.0;
650-
651-
#[allow(
652-
clippy::cast_precision_loss,
653-
clippy::cast_possible_truncation,
654-
clippy::cast_sign_loss
655-
)]
656-
let index = IntrinsicsMathsCore::floor(rng.sample_f64() * (length.get() as f64)) as usize;
657-
658-
// Safety in case of f64 rounding errors
659-
index.min(length.get() - 1)
649+
let u01 = rng.sample_f64();
650+
651+
// Safety: U[0, 1) * length in [0, 2^[32/64]) is a valid [u32/u64]
652+
// since (1 - 2^-53) * 2^[32/64] <= (2^[32/64] - 1)
653+
#[allow(clippy::cast_precision_loss)]
654+
let index = unsafe {
655+
IntrinsicsMathsCore::floor(u01 * (length.get() as f64)).to_int_unchecked::<usize>()
656+
};
657+
658+
if cfg!(target_pointer_width = "32") {
659+
// Note: [0, 2^32) is losslessly represented in f64
660+
index
661+
} else {
662+
// Note: Ensure index < length despite
663+
// usize->f64->usize precision loss
664+
index.min(length.get() - 1)
665+
}
660666
}
661667
}
662668

@@ -673,18 +679,18 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
673679
&self,
674680
rng: &mut DummyRng,
675681
_samplers: &DummyDistributionSamplers,
676-
params: Length<NonZeroU64>,
682+
Length(length): Length<NonZeroU64>,
677683
) -> u64 {
678-
let length = params.0;
684+
let u01 = rng.sample_f64();
679685

680-
#[allow(
681-
clippy::cast_precision_loss,
682-
clippy::cast_possible_truncation,
683-
clippy::cast_sign_loss
684-
)]
685-
let index = IntrinsicsMathsCore::floor(rng.sample_f64() * (length.get() as f64)) as u64;
686+
// Safety: U[0, 1) * length in [0, 2^64) is a valid u64
687+
// since (1 - 2^-53) * 2^64 <= (2^64 - 1)
688+
#[allow(clippy::cast_precision_loss)]
689+
let index = unsafe {
690+
IntrinsicsMathsCore::floor(u01 * (length.get() as f64)).to_int_unchecked::<u64>()
691+
};
686692

687-
// Safety in case of f64 rounding errors
693+
// Note: Ensure index < length despite u64->f64->u64 precision loss
688694
index.min(length.get() - 1)
689695
}
690696
}
@@ -702,18 +708,18 @@ impl DistributionSampler<IntrinsicsMathsCore, DummyRng, DummyDistributionSampler
702708
&self,
703709
rng: &mut DummyRng,
704710
_samplers: &DummyDistributionSamplers,
705-
params: Length<NonZeroU128>,
711+
Length(length): Length<NonZeroU128>,
706712
) -> u128 {
707-
let length = params.0;
713+
let u01 = rng.sample_f64();
708714

709-
#[allow(
710-
clippy::cast_precision_loss,
711-
clippy::cast_possible_truncation,
712-
clippy::cast_sign_loss
713-
)]
714-
let index = IntrinsicsMathsCore::floor(rng.sample_f64() * (length.get() as f64)) as u128;
715+
// Safety: U[0, 1) * length in [0, 2^128) is a valid u128
716+
// since (1 - 2^-53) * 2^128 <= (2^128 - 1)
717+
#[allow(clippy::cast_precision_loss)]
718+
let index = unsafe {
719+
IntrinsicsMathsCore::floor(u01 * (length.get() as f64)).to_int_unchecked::<u128>()
720+
};
715721

716-
// Safety in case of f64 rounding errors
722+
// Note: Ensure index < length despite u128->f64->u128 precision loss
717723
index.min(length.get() - 1)
718724
}
719725
}

necsim/impls/no-std/src/cogs/active_lineage_sampler/classical/mod.rs

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use alloc::vec::Vec;
22
use core::marker::PhantomData;
33

44
use necsim_core::cogs::{
5-
distribution::{Bernoulli, Exponential, IndexUsize, UniformClosedOpenUnit},
5+
distribution::{Bernoulli, Exponential, IndexUsize},
66
Backup, DispersalSampler, EmigrationExit, Habitat, ImmigrationEntry,
77
LocallyCoherentLineageStore, MathsCore, Rng, Samples, SpeciationProbability,
88
};
@@ -20,11 +20,7 @@ mod sampler;
2020
pub struct ClassicalActiveLineageSampler<
2121
M: MathsCore,
2222
H: Habitat<M>,
23-
G: Rng<M>
24-
+ Samples<M, Exponential>
25-
+ Samples<M, IndexUsize>
26-
+ Samples<M, Bernoulli>
27-
+ Samples<M, UniformClosedOpenUnit>,
23+
G: Rng<M> + Samples<M, Exponential> + Samples<M, IndexUsize> + Samples<M, Bernoulli>,
2824
S: LocallyCoherentLineageStore<M, H>,
2925
X: EmigrationExit<M, H, G, S>,
3026
D: DispersalSampler<M, H, G>,
@@ -40,11 +36,7 @@ pub struct ClassicalActiveLineageSampler<
4036
impl<
4137
M: MathsCore,
4238
H: Habitat<M>,
43-
G: Rng<M>
44-
+ Samples<M, Exponential>
45-
+ Samples<M, IndexUsize>
46-
+ Samples<M, Bernoulli>
47-
+ Samples<M, UniformClosedOpenUnit>,
39+
G: Rng<M> + Samples<M, Exponential> + Samples<M, IndexUsize> + Samples<M, Bernoulli>,
4840
S: LocallyCoherentLineageStore<M, H>,
4941
X: EmigrationExit<M, H, G, S>,
5042
D: DispersalSampler<M, H, G>,
@@ -135,11 +127,7 @@ impl<
135127
impl<
136128
M: MathsCore,
137129
H: Habitat<M>,
138-
G: Rng<M>
139-
+ Samples<M, Exponential>
140-
+ Samples<M, IndexUsize>
141-
+ Samples<M, Bernoulli>
142-
+ Samples<M, UniformClosedOpenUnit>,
130+
G: Rng<M> + Samples<M, Exponential> + Samples<M, IndexUsize> + Samples<M, Bernoulli>,
143131
S: LocallyCoherentLineageStore<M, H>,
144132
X: EmigrationExit<M, H, G, S>,
145133
D: DispersalSampler<M, H, G>,

necsim/impls/no-std/src/cogs/active_lineage_sampler/classical/sampler.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use core::{
55

66
use necsim_core::{
77
cogs::{
8-
distribution::{Bernoulli, Exponential, IndexUsize, Lambda, Length, UniformClosedOpenUnit},
8+
distribution::{Bernoulli, Exponential, IndexUsize, Lambda, Length},
99
ActiveLineageSampler, DispersalSampler, Distribution, EmigrationExit, Habitat,
1010
ImmigrationEntry, LocallyCoherentLineageStore, MathsCore, Rng, Samples,
1111
SpeciationProbability,
@@ -27,11 +27,7 @@ use super::ClassicalActiveLineageSampler;
2727
impl<
2828
M: MathsCore,
2929
H: Habitat<M>,
30-
G: Rng<M>
31-
+ Samples<M, Exponential>
32-
+ Samples<M, IndexUsize>
33-
+ Samples<M, Bernoulli>
34-
+ Samples<M, UniformClosedOpenUnit>,
30+
G: Rng<M> + Samples<M, Exponential> + Samples<M, IndexUsize> + Samples<M, Bernoulli>,
3531
S: LocallyCoherentLineageStore<M, H>,
3632
X: EmigrationExit<M, H, G, S>,
3733
D: DispersalSampler<M, H, G>,

necsim/impls/no-std/src/cogs/active_lineage_sampler/independent/event_time_sampler/exp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ impl<
5656
)
5757
};
5858

59+
// Note: rust clamps f64 as u64 to [0, 2^64 - 1]
5960
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
6061
let mut time_step = M::floor(time.get() / self.delta_t.get()) as u64;
6162

6263
let mut event_time = NonNegativeF64::from(time_step) * self.delta_t;
6364
let mut time_slice_end = NonNegativeF64::from(time_step + 1) * self.delta_t;
6465

65-
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
6666
rng.generator()
6767
.prime_with_habitat(habitat, indexed_location, time_step);
6868

necsim/impls/no-std/src/cogs/active_lineage_sampler/independent/event_time_sampler/fixed.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ impl<M: MathsCore, H: Habitat<M>, G: Rng<M, Generator: PrimeableRng>, T: Turnove
2727
let lambda =
2828
turnover_rate.get_turnover_rate_at_location(indexed_location.location(), habitat);
2929

30-
#[allow(clippy::cast_possible_truncation)]
31-
#[allow(clippy::cast_sign_loss)]
30+
// Note: rust clamps f64 as u64 to [0, 2^64 - 1]
31+
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
3232
let time_step = M::floor(time.get() * lambda.get()) as u64 + 1;
3333

3434
rng.generator()

necsim/impls/no-std/src/cogs/active_lineage_sampler/independent/event_time_sampler/geometric.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ impl<
4646
.neg_exp::<M>()
4747
.one_minus();
4848

49-
#[allow(clippy::cast_possible_truncation)]
50-
#[allow(clippy::cast_sign_loss)]
49+
// Note: rust clamps f64 as u64 to [0, 2^64 - 1]
50+
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
5151
let mut time_step = M::floor(time.get() / self.delta_t.get()) as u64 + 1;
5252

5353
loop {

necsim/impls/no-std/src/cogs/active_lineage_sampler/independent/event_time_sampler/poisson.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ impl<
5050
// location
5151
let lambda_per_step = unsafe { PositiveF64::new_unchecked(lambda.get()) } * self.delta_t;
5252

53-
#[allow(clippy::cast_possible_truncation)]
54-
#[allow(clippy::cast_sign_loss)]
53+
// Note: rust clamps f64 as u64 to [0, 2^64 - 1]
54+
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
5555
let mut time_step = M::floor(time.get() / self.delta_t.get()) as u64;
5656

5757
let (event_time, event_index) = loop {

necsim/impls/no-std/src/cogs/coalescence_sampler/conditional.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use core::marker::PhantomData;
1+
use core::{marker::PhantomData, num::NonZeroU32};
22

33
use necsim_core::{
44
cogs::{
@@ -73,11 +73,11 @@ impl<M: MathsCore, H: Habitat<M>, S: GloballyCoherentLineageStore<M, H>>
7373
let lineages_at_location =
7474
lineage_store.get_local_lineage_references_at_location_unordered(&location, habitat);
7575

76+
// Safety: individuals can only occupy habitable locations
7677
#[allow(clippy::cast_possible_truncation)]
77-
let population = lineages_at_location.len() as u32;
78+
let population = unsafe { NonZeroU32::new_unchecked(lineages_at_location.len() as u32) };
7879

79-
let chosen_coalescence_index =
80-
coalescence_rng_sample.sample_coalescence_index::<M>(population);
80+
let chosen_coalescence_index = coalescence_rng_sample.sample_coalescence_index(population);
8181
let chosen_coalescence = &lineages_at_location[chosen_coalescence_index as usize];
8282

8383
let lineage = &lineage_store[chosen_coalescence];

0 commit comments

Comments
 (0)