diff --git a/Cargo.toml b/Cargo.toml index 49cf64ed..7c79a111 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,6 @@ exclude = [ "benches", ] resolver = "2" + +[patch.crates-io] +rand_core = { git = "https://github.com/rust-random/rand_core.git", branch = "remaining_results" } diff --git a/rand_isaac/Cargo.toml b/rand_isaac/Cargo.toml index cf30165d..ee77f415 100644 --- a/rand_isaac/Cargo.toml +++ b/rand_isaac/Cargo.toml @@ -19,7 +19,7 @@ rust-version = "1.85" all-features = true [features] -serde = ["dep:serde", "rand_core/serde"] +serde = ["dep:serde"] [dependencies] rand_core = "0.10.0-rc-2" diff --git a/rand_isaac/src/isaac.rs b/rand_isaac/src/isaac.rs index 88771beb..94aa10ee 100644 --- a/rand_isaac/src/isaac.rs +++ b/rand_isaac/src/isaac.rs @@ -90,7 +90,6 @@ const RAND_SIZE: usize = 1 << RAND_SIZE_LEN; /// /// [`rand_hc`]: https://docs.rs/rand_hc #[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct IsaacRng(BlockRng); impl RngCore for IsaacRng { @@ -143,6 +142,124 @@ impl SeedableRng for IsaacRng { } } +#[cfg(feature = "serde")] +mod serde_impls { + use super::{IsaacArray, IsaacRng, IsaacCore}; + use rand_core::block::BlockRng; + use serde::ser::{Serialize, Serializer, SerializeStruct}; + use serde::de::{Deserialize, Deserializer, Visitor, SeqAccess, MapAccess, Error}; + use core::fmt; + + impl Serialize for IsaacRng { + fn serialize(&self, serializer: S) -> Result { + let mut state = serializer.serialize_struct("IsaacRng", 2)?; + state.serialize_field("core", &self.0.core)?; + state.serialize_field("results", self.0.remaining_results())?; + state.end() + } + } + + struct Results { + results: IsaacArray, + len: usize, + } + impl Results { + fn to_rng(&self, core: IsaacCore) -> IsaacRng { + let results = &self.results[..self.len]; + IsaacRng(BlockRng::from_core_and_remaining_results(core, results).unwrap()) + } + } + struct ResultsVisitor; + impl<'de> Visitor<'de> for ResultsVisitor { + type Value = Results; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "") // TODO + } + + fn visit_seq>(self, mut seq: A) -> Result { + let mut results = IsaacArray::::default(); + let mut len = 0; + while let Some(value) = seq.next_element()? { + if len >= results.len() { + return Err(Error::invalid_length(len + 1, &("up to 256 elements" as &str))); + } + + results[len] = value; + len += 1; + } + + Ok(Results { results, len }) + } + } + + impl<'de> Deserialize<'de> for Results { + fn deserialize>(deserializer: D) -> Result { + deserializer.deserialize_seq(ResultsVisitor) + } + } + + #[derive(serde::Deserialize)] + #[serde(field_identifier, rename_all = "lowercase")] + enum Field { Core, Results } + + struct IsaacVisitor; + impl<'de> Visitor<'de> for IsaacVisitor { + type Value = IsaacRng; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "") // TODO + } + + fn visit_seq(self, mut seq: V) -> Result + where + V: SeqAccess<'de>, + { + let core = seq.next_element()? + .ok_or_else(|| Error::invalid_length(0, &self))?; + let results: Results = seq.next_element()? + .ok_or_else(|| Error::invalid_length(1, &self))?; + + Ok(results.to_rng(core)) + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut core = None; + let mut results: Option = None; + while let Some(key) = map.next_key()? { + match key { + Field::Core => { + if core.is_some() { + return Err(Error::duplicate_field("core")); + } + core = Some(map.next_value()?); + } + Field::Results => { + if results.is_some() { + return Err(Error::duplicate_field("results")); + } + results = Some(map.next_value()?); + } + } + } + let core = core.ok_or_else(|| Error::missing_field("core"))?; + let results = results.ok_or_else(|| Error::missing_field("results"))?; + + Ok(results.to_rng(core)) + } + } + + impl<'de> Deserialize<'de> for IsaacRng { + fn deserialize>(deserializer: D) -> Result { + const FIELDS: &[&str] = &["core", "results"]; + deserializer.deserialize_struct("IsaacRng", FIELDS, IsaacVisitor) + } + } +} + /// The core of [`IsaacRng`], used with [`BlockRng`]. #[derive(Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -554,6 +671,10 @@ mod test { ]; let mut rng = IsaacRng::from_seed(seed); + // discard some results + let _ = rng.next_u64(); + let _ = rng.next_u32(); + let buf: Vec = Vec::new(); let mut buf = BufWriter::new(buf); bincode::serialize_into(&mut buf, &rng).expect("Could not serialize"); diff --git a/rand_isaac/src/isaac64.rs b/rand_isaac/src/isaac64.rs index 713cf0df..016fd945 100644 --- a/rand_isaac/src/isaac64.rs +++ b/rand_isaac/src/isaac64.rs @@ -14,8 +14,6 @@ use core::num::Wrapping as w; use core::{fmt, slice}; use rand_core::block::{BlockRng64, BlockRngCore}; use rand_core::{RngCore, SeedableRng, TryRngCore, le}; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; #[allow(non_camel_case_types)] type w64 = w; @@ -81,7 +79,6 @@ const RAND_SIZE: usize = 1 << RAND_SIZE_LEN; /// [`rand_hc`]: https://docs.rs/rand_hc /// [`BlockRng64`]: rand_core::block::BlockRng64 #[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Isaac64Rng(BlockRng64); impl RngCore for Isaac64Rng { @@ -136,12 +133,7 @@ impl SeedableRng for Isaac64Rng { /// The core of `Isaac64Rng`, used with `BlockRng`. #[derive(Clone)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Isaac64Core { - #[cfg_attr( - feature = "serde", - serde(with = "super::isaac_array::isaac_array_serde") - )] mem: [w64; RAND_SIZE], a: w64, b: w64, @@ -546,31 +538,4 @@ mod test { assert_eq!(rng1.next_u64(), rng2.next_u64()); } } - - #[test] - #[cfg(feature = "serde")] - fn test_isaac64_serde() { - use bincode; - use std::io::{BufReader, BufWriter}; - - let seed = [ - 1, 0, 0, 0, 23, 0, 0, 0, 200, 1, 0, 0, 210, 30, 0, 0, 57, 48, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - ]; - let mut rng = Isaac64Rng::from_seed(seed); - - let buf: Vec = Vec::new(); - let mut buf = BufWriter::new(buf); - bincode::serialize_into(&mut buf, &rng).expect("Could not serialize"); - - let buf = buf.into_inner().unwrap(); - let mut read = BufReader::new(&buf[..]); - let mut deserialized: Isaac64Rng = - bincode::deserialize_from(&mut read).expect("Could not deserialize"); - - // more than the 256 buffered results - for _ in 0..300 { - assert_eq!(rng.next_u64(), deserialized.next_u64()); - } - } }