diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index 26bff46b..ac37b2f9 100644 --- a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -7,8 +7,8 @@ use lean_vm::*; use tracing::instrument; use utils::{build_prover_state, poseidon_compress_slice, poseidon16_compress_pair}; use xmss::{ - LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, SIG_SIZE_FE, XmssPublicKey, XmssSignature, slot_to_field_elements, - xmss_verify_with_poseidon_trace, + LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, RANDOMNESS_LEN_FE, SIG_SIZE_FE, XmssPublicKey, XmssSignature, + slot_to_field_elements, xmss_verify_with_poseidon_trace, }; use serde::{Deserialize, Serialize}; @@ -36,18 +36,19 @@ pub(crate) fn count_signers(topology: &AggregationTopology, overlap: usize) -> u } pub fn hash_pubkeys(pub_keys: &[XmssPublicKey]) -> [F; DIGEST_LEN] { - let iv = [F::ZERO; DIGEST_LEN]; - let flat: Vec = iv - .iter() - .copied() - .chain(pub_keys.iter().flat_map(|pk| pk.merkle_root.iter().copied())) - .collect(); + let total = DIGEST_LEN + pub_keys.len() * DIGEST_LEN; + let mut flat = vec![F::ZERO; total]; + // First DIGEST_LEN elements are already zero. Write pub keys directly at offsets. + let mut off = DIGEST_LEN; + for pk in pub_keys { + flat[off..off + DIGEST_LEN].copy_from_slice(&pk.merkle_root); + off += DIGEST_LEN; + } poseidon_compress_slice(&flat) } -fn compute_merkle_chunks_for_slot(slot: u32) -> Vec { - let mut chunks = Vec::with_capacity(N_MERKLE_CHUNKS_FOR_SLOT); - for chunk_idx in 0..N_MERKLE_CHUNKS_FOR_SLOT { +fn compute_merkle_chunks_for_slot(slot: u32) -> [F; N_MERKLE_CHUNKS_FOR_SLOT] { + std::array::from_fn(|chunk_idx| { let mut nibble_val: usize = 0; for bit in 0..4 { let level = chunk_idx * 4 + bit; @@ -56,9 +57,8 @@ fn compute_merkle_chunks_for_slot(slot: u32) -> Vec { nibble_val |= 1 << bit; } } - chunks.push(F::from_usize(nibble_val)); - } - chunks + F::from_usize(nibble_val) + }) } fn build_non_reserved_public_input( @@ -68,26 +68,45 @@ fn build_non_reserved_public_input( slot: u32, bytecode_claim_output: &[F], ) -> Vec { - let mut pi = vec![]; - pi.push(F::from_usize(n_sigs)); - pi.extend_from_slice(slice_hash); - pi.extend_from_slice(message); + let total = 1 + DIGEST_LEN + MESSAGE_LEN_FE + 2 + N_MERKLE_CHUNKS_FOR_SLOT + bytecode_claim_output.len(); + let mut pi = Vec::with_capacity(total); + // SAFETY: capacity is `total`, all elements written below + unsafe { pi.set_len(total) }; + let mut off = 0; + pi[off] = F::from_usize(n_sigs); + off += 1; + pi[off..off + DIGEST_LEN].copy_from_slice(slice_hash); + off += DIGEST_LEN; + pi[off..off + MESSAGE_LEN_FE].copy_from_slice(message); + off += MESSAGE_LEN_FE; let [slot_lo, slot_hi] = slot_to_field_elements(slot); - pi.push(slot_lo); - pi.push(slot_hi); - pi.extend(compute_merkle_chunks_for_slot(slot)); - pi.extend_from_slice(bytecode_claim_output); + pi[off] = slot_lo; + pi[off + 1] = slot_hi; + off += 2; + let chunks = compute_merkle_chunks_for_slot(slot); + pi[off..off + N_MERKLE_CHUNKS_FOR_SLOT].copy_from_slice(&chunks); + off += N_MERKLE_CHUNKS_FOR_SLOT; + pi[off..off + bytecode_claim_output.len()].copy_from_slice(bytecode_claim_output); + debug_assert_eq!(off + bytecode_claim_output.len(), total); pi } fn encode_xmss_signature(sig: &XmssSignature) -> Vec { - let mut data = vec![]; - data.extend(sig.wots_signature.randomness.to_vec()); - data.extend(sig.wots_signature.chain_tips.iter().flat_map(|digest| digest.to_vec())); + let mut data = Vec::with_capacity(SIG_SIZE_FE); + // SAFETY: capacity is SIG_SIZE_FE, all bytes written below via copy_from_slice + unsafe { data.set_len(SIG_SIZE_FE) }; + let mut off = 0; + data[off..off + RANDOMNESS_LEN_FE].copy_from_slice(&sig.wots_signature.randomness); + off += RANDOMNESS_LEN_FE; + for digest in &sig.wots_signature.chain_tips { + data[off..off + DIGEST_LEN].copy_from_slice(digest); + off += DIGEST_LEN; + } for neighbor in &sig.merkle_proof { - data.extend(neighbor.to_vec()); + data[off..off + DIGEST_LEN].copy_from_slice(neighbor); + off += DIGEST_LEN; } - assert_eq!(data.len(), SIG_SIZE_FE); + debug_assert_eq!(off, SIG_SIZE_FE); data } @@ -342,26 +361,47 @@ pub fn xmss_aggregate( } let bytecode_sumcheck_proof_ptr = offset; - let mut private_input = vec![]; - private_input.push(F::from_usize(n_recursions)); - private_input.push(F::from_usize(n_dup)); - private_input.push(F::from_usize(pubkeys_start)); + let source_blocks_total: usize = source_blocks.iter().map(|b| b.len()).sum(); + let private_total = header_size + pubkeys_block_size + source_blocks_total + final_sumcheck_transcript.len(); + let mut private_input = Vec::with_capacity(private_total); + // SAFETY: capacity is private_total, all elements written below via indexed assignment + copy_from_slice + unsafe { private_input.set_len(private_total) }; + let mut off = 0; + + // Header + private_input[off] = F::from_usize(n_recursions); + private_input[off + 1] = F::from_usize(n_dup); + private_input[off + 2] = F::from_usize(pubkeys_start); + off += 3; for &ptr in &source_ptrs { - private_input.push(F::from_usize(ptr)); + private_input[off] = F::from_usize(ptr); + off += 1; } - private_input.push(F::from_usize(bytecode_sumcheck_proof_ptr)); - assert_eq!(private_input.len(), header_size); + private_input[off] = F::from_usize(bytecode_sumcheck_proof_ptr); + off += 1; + debug_assert_eq!(off, header_size); + // Pub keys for pk in &global_pub_keys { - private_input.extend_from_slice(&pk.merkle_root); + private_input[off..off + DIGEST_LEN].copy_from_slice(&pk.merkle_root); + off += DIGEST_LEN; } for pk in &dup_pub_keys { - private_input.extend_from_slice(&pk.merkle_root); + private_input[off..off + DIGEST_LEN].copy_from_slice(&pk.merkle_root); + off += DIGEST_LEN; } + + // Source blocks for block in &source_blocks { - private_input.extend_from_slice(block); + let blen = block.len(); + private_input[off..off + blen].copy_from_slice(block); + off += blen; } - private_input.extend_from_slice(&final_sumcheck_transcript); + + // Bytecode sumcheck transcript + let tlen = final_sumcheck_transcript.len(); + private_input[off..off + tlen].copy_from_slice(&final_sumcheck_transcript); + debug_assert_eq!(off + tlen, private_total); // TODO precompute all the other poseidons let xmss_poseidons_16_precomputed = precompute_poseidons(&raw_xmss, message); @@ -372,9 +412,16 @@ pub fn xmss_aggregate( .iter() .flat_map(|p| p.merkle_openings.iter()) .flat_map(|o| { - let leaf = o.leaf_data.clone(); - let path: Vec = o.path.iter().flat_map(|d| d.iter().copied()).collect(); - [leaf, path] + let path_len = o.path.len() * DIGEST_LEN; + let mut path = Vec::with_capacity(path_len); + // SAFETY: capacity is path_len, all elements written below + unsafe { path.set_len(path_len) }; + let mut off = 0; + for d in &o.path { + path[off..off + DIGEST_LEN].copy_from_slice(d); + off += DIGEST_LEN; + } + [o.leaf_data.clone(), path] }) .collect(); diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 7c03f6e4..c8a288e3 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -29,7 +29,9 @@ pub const SIG_SIZE_FE: usize = RANDOMNESS_LEN_FE + (V + LOG_LIFETIME) * DIGEST_S pub type Poseidon16History = Vec<([F; 16], [F; 8])>; fn poseidon16_compress_with_trace(a: &Digest, b: &Digest, poseidon_16_trace: &mut Vec<([F; 16], [F; 8])>) -> Digest { - let input: [F; 16] = [*a, *b].concat().try_into().unwrap(); + let mut input = [F::default(); 16]; + input[..8].copy_from_slice(a); + input[8..].copy_from_slice(b); let output = poseidon16_compress(input); poseidon_16_trace.push((input, output)); output diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 26df9c63..3c27aa9e 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -174,19 +174,22 @@ pub fn wots_encode_with_poseidon_trace( // ensures uniformity of encoding return None; } - let all_indices: Vec<_> = compressed - .iter() - .flat_map(|kb| to_little_endian_bits(kb.to_usize(), 24)) - .collect::>() - .chunks_exact(W) - .take(V + V_GRINDING) - .map(|chunk| { - chunk - .iter() - .enumerate() - .fold(0u8, |acc, (i, &bit)| acc | (u8::from(bit) << i)) - }) - .collect(); + // 8 field elements * 24 bits each = 192 bits, packed into W=3-bit chunks. + let mut bits = [false; DIGEST_SIZE * 24]; + let mut bit_idx = 0; + for kb in &compressed { + for b in to_little_endian_bits(kb.to_usize(), 24) { + bits[bit_idx] = b; + bit_idx += 1; + } + } + let mut all_indices = [0u8; V + V_GRINDING]; + for (i, chunk) in bits.chunks_exact(W).take(V + V_GRINDING).enumerate() { + all_indices[i] = chunk + .iter() + .enumerate() + .fold(0u8, |acc, (j, &bit)| acc | (u8::from(bit) << j)); + } is_valid_encoding(&all_indices).then(|| all_indices[..V].try_into().unwrap()) } diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 364cf2a3..15ad9526 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -86,13 +86,13 @@ pub fn xmss_key_gen( let left = if left_idx >= prev_base && left_idx <= prev_top { prev[(left_idx - prev_base) as usize] } else { - assert!(left_idx < 1u64 << 32); + debug_assert!(left_idx < 1u64 << 32); gen_random_node(&seed, level - 1, left_idx as u32) }; let right = if right_idx >= prev_base && right_idx <= prev_top { prev[(right_idx - prev_base) as usize] } else { - assert!(right_idx < 1u64 << 32); + debug_assert!(right_idx < 1u64 << 32); gen_random_node(&seed, level - 1, right_idx as u32) }; compress(&perm, [left, right]) @@ -189,7 +189,7 @@ pub fn xmss_verify_with_poseidon_trace( message: &[F; MESSAGE_LEN_FE], signature: &XmssSignature, ) -> Result { - let mut poseidon_16_trace = Vec::new(); + let mut poseidon_16_trace = Vec::with_capacity(NUM_CHAIN_HASHES + V + LOG_LIFETIME); let truncated_merkle_root = pub_key.merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap(); let wots_public_key = signature .wots_signature