diff --git a/src/uint/boxed/rand.rs b/src/uint/boxed/rand.rs index a2fd4ffd..51252b99 100644 --- a/src/uint/boxed/rand.rs +++ b/src/uint/boxed/rand.rs @@ -28,7 +28,7 @@ impl RandomBits for BoxedUint { } let mut ret = BoxedUint::zero_with_precision(bits_precision); - random_bits_core(rng, &mut ret.limbs, bit_length).map_err(RandomBitsError::RandCore)?; + random_bits_core(rng, &mut ret.limbs, bit_length)?; Ok(ret) } } diff --git a/src/uint/rand.rs b/src/uint/rand.rs index ae060c70..d36dc26c 100644 --- a/src/uint/rand.rs +++ b/src/uint/rand.rs @@ -1,7 +1,7 @@ //! Random number generator support use super::{Uint, Word}; -use crate::{Limb, NonZero, Random, RandomBits, RandomBitsError, RandomMod, Zero}; +use crate::{Encoding, Limb, NonZero, Random, RandomBits, RandomBitsError, RandomMod, Zero}; use rand_core::{RngCore, TryRngCore}; use subtle::ConstantTimeLess; @@ -30,7 +30,7 @@ pub(crate) fn random_bits_core( rng: &mut R, zeroed_limbs: &mut [Limb], bit_length: u32, -) -> Result<(), R::Error> { +) -> Result<(), RandomBitsError> { if bit_length == 0 { return Ok(()); } @@ -43,7 +43,8 @@ pub(crate) fn random_bits_core( let mask = Word::MAX >> ((Word::BITS - partial_limb) % Word::BITS); for i in 0..nonzero_limbs - 1 { - rng.try_fill_bytes(&mut buffer)?; + rng.try_fill_bytes(&mut buffer) + .map_err(RandomBitsError::RandCore)?; zeroed_limbs[i] = Limb(Word::from_le_bytes(buffer)); } @@ -61,7 +62,8 @@ pub(crate) fn random_bits_core( buffer.as_mut_slice() }; - rng.try_fill_bytes(slice)?; + rng.try_fill_bytes(slice) + .map_err(RandomBitsError::RandCore)?; zeroed_limbs[nonzero_limbs - 1] = Limb(Word::from_le_bytes(buffer) & mask); Ok(()) @@ -93,7 +95,7 @@ impl RandomBits for Uint { }); } let mut limbs = [Limb::ZERO; LIMBS]; - random_bits_core(rng, &mut limbs, bit_length).map_err(RandomBitsError::RandCore)?; + random_bits_core(rng, &mut limbs, bit_length)?; Ok(Self::from(limbs)) } } @@ -126,19 +128,43 @@ pub(super) fn random_mod_core( where T: AsMut<[Limb]> + AsRef<[Limb]> + ConstantTimeLess + Zero, { - loop { - random_bits_core(rng, n.as_mut(), n_bits)?; + #[cfg(target_pointer_width = "64")] + let mut next_word = || rng.try_next_u64(); + #[cfg(target_pointer_width = "32")] + let mut next_word = || rng.try_next_u32(); + + let n_limbs = n_bits.div_ceil(Limb::BITS) as usize; + + let hi_word_modulus = modulus.as_ref().as_ref()[n_limbs - 1].0; + let mask = !0 >> hi_word_modulus.leading_zeros(); + let mut hi_word = next_word()? & mask; + loop { + while hi_word > hi_word_modulus { + hi_word = next_word()? & mask; + } + // Set high limb + n.as_mut()[n_limbs - 1] = Limb::from_le_bytes(hi_word.to_le_bytes()); + // Set low limbs + for i in 0..n_limbs - 1 { + // Need to deserialize from little-endian to make sure that two 32-bit limbs + // deserialized sequentially are equal to one 64-bit limb produced from the same + // byte stream. + n.as_mut()[i] = Limb::from_le_bytes(next_word()?.to_le_bytes()); + } + // If the high limb is equal to the modulus' high limb, it's still possible + // that the full uint is too big so we check and repeat if it is. if n.ct_lt(modulus).into() { break; } + hi_word = next_word()? & mask; } Ok(()) } #[cfg(test)] mod tests { - use crate::uint::rand::{random_bits_core, random_mod_core}; + use crate::uint::rand::random_bits_core; use crate::{Limb, NonZero, Random, RandomBits, RandomMod, U256, U1024, Uint}; use chacha20::ChaCha8Rng; use rand_core::{RngCore, SeedableRng}; @@ -262,32 +288,6 @@ mod tests { ); } - /// Make sure random_mod output is consistent across platforms - #[test] - fn random_mod_platform_independence() { - let mut rng = get_four_sequential_rng(); - - let modulus = NonZero::new(U256::from_u32(8192)).unwrap(); - let mut vals = [U256::ZERO, U256::ZERO, U256::ZERO, U256::ZERO, U256::ZERO]; - for val in &mut vals { - random_mod_core(&mut rng, val, &modulus, modulus.bits_vartime()).unwrap(); - } - let expected = [55, 3378, 2172, 1657, 5323]; - for (want, got) in expected.into_iter().zip(vals.into_iter()) { - assert_eq!(got, U256::from_u32(want)); - } - - let mut state = [0u8; 16]; - rng.fill_bytes(&mut state); - - assert_eq!( - state, - [ - 60, 146, 46, 106, 157, 83, 56, 212, 186, 104, 211, 210, 125, 28, 120, 239 - ], - ); - } - /// Test that random bytes are sampled consecutively. #[test] fn random_bits_4_bytes_sequential() {