Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chacha20/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ pub use legacy::{ChaCha20Legacy, LegacyNonce};
#[cfg(feature = "rng")]
pub use rand_core;
#[cfg(feature = "rng")]
pub use rng::{ChaCha8Rng, ChaCha12Rng, ChaCha20Rng, Seed};
pub use rng::{ChaCha8Rng, ChaCha12Rng, ChaCha20Rng, Seed, SerializedRngState};
#[cfg(feature = "xchacha")]
pub use xchacha::{XChaCha8, XChaCha12, XChaCha20, XNonce, hchacha};

Expand Down
60 changes: 54 additions & 6 deletions chacha20/src/rng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ use cfg_if::cfg_if;
/// Seed value used to initialize ChaCha-based RNGs.
pub type Seed = [u8; 32];

/// Serialized RNG state.
pub type SerializedRngState = [u8; 49];

/// Number of 32-bit words per ChaCha block (fixed by algorithm definition).
pub(crate) const BLOCK_WORDS: u8 = 16;

Expand Down Expand Up @@ -277,16 +280,61 @@ macro_rules! impl_chacha_rng {
#[inline]
#[must_use]
pub fn get_seed(&self) -> [u8; 32] {
let seed = &self.core.core.state[4..12];
let mut result = [0u8; 32];
for (i, &big) in self.core.core.state[4..12].iter().enumerate() {
let index = i * 4;
result[index + 0] = big as u8;
result[index + 1] = (big >> 8) as u8;
result[index + 2] = (big >> 16) as u8;
result[index + 3] = (big >> 24) as u8;
for (src, dst) in seed.iter().zip(result.chunks_exact_mut(4)) {
dst.copy_from_slice(&src.to_le_bytes())
}
result
}

/// Serialize RNG state.
///
/// # Warning
/// Leaking serialized RNG state to an attacker defeats security properties
/// provided by the RNG.
#[inline]
pub fn serialize_state(&self) -> SerializedRngState {
let seed = self.get_seed();
let stream = self.get_stream().to_le_bytes();
let word_pos = self.get_word_pos().to_le_bytes();

let mut res = [0u8; 49];
let (seed_dst, res_rem) = res.split_at_mut(32);
let (stream_dst, word_pos_dst) = res_rem.split_at_mut(8);

seed_dst.copy_from_slice(&seed);
stream_dst.copy_from_slice(&stream);
word_pos_dst.copy_from_slice(&word_pos[..9]);

debug_assert_eq!(&word_pos[9..], &[0u8; 7]);

res
}

/// Deserialize RNG state.
#[inline]
pub fn deserialize_state(state: &SerializedRngState) -> Self {
let (seed, state_rem) = state.split_at(32);
let (stream, word_pos_raw) = state_rem.split_at(8);

let seed: &[u8; 32] = seed.try_into().expect("seed.len() is equal to 32");
let stream: &[u8; 8] = stream.try_into().expect("stream.len() is equal to 8");

// Note that we use only 68 bits from `word_pos_raw`, i.e. 4 remaining bits
// get ignored and should be equal to zero in practice.
let mut word_pos_buf = [0u8; 16];
word_pos_buf[..9].copy_from_slice(word_pos_raw);
let word_pos = u128::from_le_bytes(word_pos_buf);

let core = ChaChaCore::new_internal(seed, stream);
let mut res = Self {
core: BlockRng::new(core),
};

res.set_word_pos(word_pos);
res
}
}
};
}
Expand Down
84 changes: 83 additions & 1 deletion chacha20/tests/rng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
#![cfg(feature = "rng")]

use chacha20::{
ChaCha20Rng,
ChaCha8Rng, ChaCha12Rng, ChaCha20Rng, SerializedRngState,
rand_core::{Rng, SeedableRng},
};
use hex_literal::hex;

const KEY: [u8; 32] = hex!("0102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F20");
const STREAM: u64 = 0xF0F1F2F3_F4F5F6F7;
const BLOCK_WORDS: u8 = 16;

#[test]
Expand Down Expand Up @@ -467,3 +468,84 @@ fn counter_not_wrapping_at_32_bits() {
);
assert_ne!(&first_blocks[0..64 * 4], &result[64..]);
}

#[test]
fn test_chacha8rng_serde_roundtrip() {
for skip_words in 0..100 {
let mut rng = ChaCha8Rng::from_seed(KEY);
rng.set_stream(STREAM);
for _ in 0..skip_words {
let _ = rng.next_u32();
}
let state = rng.serialize_state();
let mut rng2 = ChaCha8Rng::deserialize_state(&state);
for _ in 0..100 {
assert_eq!(rng.next_u32(), rng2.next_u32());
}
}
}

#[test]
fn test_chacha12rng_serde_roundtrip() {
for skip_words in 0..100 {
let mut rng = ChaCha12Rng::from_seed(KEY);
rng.set_stream(STREAM);
for _ in 0..skip_words {
let _ = rng.next_u32();
}
let state = rng.serialize_state();
let mut rng2 = ChaCha12Rng::deserialize_state(&state);
for _ in 0..100 {
assert_eq!(rng.next_u32(), rng2.next_u32());
}
}
}

#[test]
fn test_chacha20rng_serde_roundtrip() {
for skip_words in 0..100 {
let mut rng = ChaCha20Rng::from_seed(KEY);
rng.set_stream(STREAM);
for _ in 0..skip_words {
let _ = rng.next_u32();
}
let state = rng.serialize_state();
let mut rng2 = ChaCha20Rng::deserialize_state(&state);
for _ in 0..100 {
assert_eq!(rng.next_u32(), rng2.next_u32());
}
}
}

#[test]
fn test_rng_serialized_state_stability() {
const EXPECTED: SerializedRngState = hex!(
"0102030405060708090A0B0C0D0E0F10"
"1112131415161718191A1B1C1D1E1F20"
"F7F6F5F4F3F2F1F06400000000000000"
"00"
);
let mut rng = ChaCha8Rng::from_seed(KEY);
rng.set_stream(STREAM);
for _ in 0..100 {
let _ = rng.next_u32();
}
let state = rng.serialize_state();
assert_eq!(state, EXPECTED);

let mut rng = ChaCha12Rng::from_seed(KEY);
rng.set_stream(STREAM);
for _ in 0..100 {
let _ = rng.next_u32();
}
let state = rng.serialize_state();
assert_eq!(state, EXPECTED);

let mut rng = ChaCha20Rng::from_seed(KEY);
rng.set_stream(STREAM);
for _ in 0..100 {
let _ = rng.next_u32();
}
let state = rng.serialize_state();
assert_eq!(state, EXPECTED);
}