Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 88 additions & 41 deletions crates/rec_aggregation/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use lean_vm::*;
use tracing::instrument;
use utils::{build_prover_state, poseidon_compress_slice, poseidon16_compress_pair};
use xmss::{
LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, SIG_SIZE_FE, XmssPublicKey, XmssSignature, slot_to_field_elements,
xmss_verify_with_poseidon_trace,
LOG_LIFETIME, MESSAGE_LEN_FE, Poseidon16History, RANDOMNESS_LEN_FE, SIG_SIZE_FE, XmssPublicKey, XmssSignature,
slot_to_field_elements, xmss_verify_with_poseidon_trace,
};

use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -36,18 +36,19 @@ pub(crate) fn count_signers(topology: &AggregationTopology, overlap: usize) -> u
}

pub fn hash_pubkeys(pub_keys: &[XmssPublicKey]) -> [F; DIGEST_LEN] {
let iv = [F::ZERO; DIGEST_LEN];
let flat: Vec<F> = iv
.iter()
.copied()
.chain(pub_keys.iter().flat_map(|pk| pk.merkle_root.iter().copied()))
.collect();
let total = DIGEST_LEN + pub_keys.len() * DIGEST_LEN;
let mut flat = vec![F::ZERO; total];
// First DIGEST_LEN elements are already zero. Write pub keys directly at offsets.
let mut off = DIGEST_LEN;
for pk in pub_keys {
flat[off..off + DIGEST_LEN].copy_from_slice(&pk.merkle_root);
off += DIGEST_LEN;
}
poseidon_compress_slice(&flat)
}

fn compute_merkle_chunks_for_slot(slot: u32) -> Vec<F> {
let mut chunks = Vec::with_capacity(N_MERKLE_CHUNKS_FOR_SLOT);
for chunk_idx in 0..N_MERKLE_CHUNKS_FOR_SLOT {
fn compute_merkle_chunks_for_slot(slot: u32) -> [F; N_MERKLE_CHUNKS_FOR_SLOT] {
std::array::from_fn(|chunk_idx| {
let mut nibble_val: usize = 0;
for bit in 0..4 {
let level = chunk_idx * 4 + bit;
Expand All @@ -56,9 +57,8 @@ fn compute_merkle_chunks_for_slot(slot: u32) -> Vec<F> {
nibble_val |= 1 << bit;
}
}
chunks.push(F::from_usize(nibble_val));
}
chunks
F::from_usize(nibble_val)
})
}

fn build_non_reserved_public_input(
Expand All @@ -68,26 +68,45 @@ fn build_non_reserved_public_input(
slot: u32,
bytecode_claim_output: &[F],
) -> Vec<F> {
let mut pi = vec![];
pi.push(F::from_usize(n_sigs));
pi.extend_from_slice(slice_hash);
pi.extend_from_slice(message);
let total = 1 + DIGEST_LEN + MESSAGE_LEN_FE + 2 + N_MERKLE_CHUNKS_FOR_SLOT + bytecode_claim_output.len();
let mut pi = Vec::with_capacity(total);
// SAFETY: capacity is `total`, all elements written below
unsafe { pi.set_len(total) };
let mut off = 0;
pi[off] = F::from_usize(n_sigs);
off += 1;
pi[off..off + DIGEST_LEN].copy_from_slice(slice_hash);
off += DIGEST_LEN;
pi[off..off + MESSAGE_LEN_FE].copy_from_slice(message);
off += MESSAGE_LEN_FE;
let [slot_lo, slot_hi] = slot_to_field_elements(slot);
pi.push(slot_lo);
pi.push(slot_hi);
pi.extend(compute_merkle_chunks_for_slot(slot));
pi.extend_from_slice(bytecode_claim_output);
pi[off] = slot_lo;
pi[off + 1] = slot_hi;
off += 2;
let chunks = compute_merkle_chunks_for_slot(slot);
pi[off..off + N_MERKLE_CHUNKS_FOR_SLOT].copy_from_slice(&chunks);
off += N_MERKLE_CHUNKS_FOR_SLOT;
pi[off..off + bytecode_claim_output.len()].copy_from_slice(bytecode_claim_output);
debug_assert_eq!(off + bytecode_claim_output.len(), total);
pi
}

fn encode_xmss_signature(sig: &XmssSignature) -> Vec<F> {
let mut data = vec![];
data.extend(sig.wots_signature.randomness.to_vec());
data.extend(sig.wots_signature.chain_tips.iter().flat_map(|digest| digest.to_vec()));
let mut data = Vec::with_capacity(SIG_SIZE_FE);
// SAFETY: capacity is SIG_SIZE_FE, all bytes written below via copy_from_slice
unsafe { data.set_len(SIG_SIZE_FE) };
let mut off = 0;
data[off..off + RANDOMNESS_LEN_FE].copy_from_slice(&sig.wots_signature.randomness);
off += RANDOMNESS_LEN_FE;
for digest in &sig.wots_signature.chain_tips {
data[off..off + DIGEST_LEN].copy_from_slice(digest);
off += DIGEST_LEN;
}
for neighbor in &sig.merkle_proof {
data.extend(neighbor.to_vec());
data[off..off + DIGEST_LEN].copy_from_slice(neighbor);
off += DIGEST_LEN;
}
assert_eq!(data.len(), SIG_SIZE_FE);
debug_assert_eq!(off, SIG_SIZE_FE);
data
}

Expand Down Expand Up @@ -342,26 +361,47 @@ pub fn xmss_aggregate(
}
let bytecode_sumcheck_proof_ptr = offset;

let mut private_input = vec![];
private_input.push(F::from_usize(n_recursions));
private_input.push(F::from_usize(n_dup));
private_input.push(F::from_usize(pubkeys_start));
let source_blocks_total: usize = source_blocks.iter().map(|b| b.len()).sum();
let private_total = header_size + pubkeys_block_size + source_blocks_total + final_sumcheck_transcript.len();
let mut private_input = Vec::with_capacity(private_total);
// SAFETY: capacity is private_total, all elements written below via indexed assignment + copy_from_slice
unsafe { private_input.set_len(private_total) };
let mut off = 0;

// Header
private_input[off] = F::from_usize(n_recursions);
private_input[off + 1] = F::from_usize(n_dup);
private_input[off + 2] = F::from_usize(pubkeys_start);
off += 3;
for &ptr in &source_ptrs {
private_input.push(F::from_usize(ptr));
private_input[off] = F::from_usize(ptr);
off += 1;
}
private_input.push(F::from_usize(bytecode_sumcheck_proof_ptr));
assert_eq!(private_input.len(), header_size);
private_input[off] = F::from_usize(bytecode_sumcheck_proof_ptr);
off += 1;
debug_assert_eq!(off, header_size);

// Pub keys
for pk in &global_pub_keys {
private_input.extend_from_slice(&pk.merkle_root);
private_input[off..off + DIGEST_LEN].copy_from_slice(&pk.merkle_root);
off += DIGEST_LEN;
}
for pk in &dup_pub_keys {
private_input.extend_from_slice(&pk.merkle_root);
private_input[off..off + DIGEST_LEN].copy_from_slice(&pk.merkle_root);
off += DIGEST_LEN;
}

// Source blocks
for block in &source_blocks {
private_input.extend_from_slice(block);
let blen = block.len();
private_input[off..off + blen].copy_from_slice(block);
off += blen;
}
private_input.extend_from_slice(&final_sumcheck_transcript);

// Bytecode sumcheck transcript
let tlen = final_sumcheck_transcript.len();
private_input[off..off + tlen].copy_from_slice(&final_sumcheck_transcript);
debug_assert_eq!(off + tlen, private_total);

// TODO precompute all the other poseidons
let xmss_poseidons_16_precomputed = precompute_poseidons(&raw_xmss, message);
Expand All @@ -372,9 +412,16 @@ pub fn xmss_aggregate(
.iter()
.flat_map(|p| p.merkle_openings.iter())
.flat_map(|o| {
let leaf = o.leaf_data.clone();
let path: Vec<F> = o.path.iter().flat_map(|d| d.iter().copied()).collect();
[leaf, path]
let path_len = o.path.len() * DIGEST_LEN;
let mut path = Vec::with_capacity(path_len);
// SAFETY: capacity is path_len, all elements written below
unsafe { path.set_len(path_len) };
let mut off = 0;
for d in &o.path {
path[off..off + DIGEST_LEN].copy_from_slice(d);
off += DIGEST_LEN;
}
[o.leaf_data.clone(), path]
})
.collect();

Expand Down
4 changes: 3 additions & 1 deletion crates/xmss/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ pub const SIG_SIZE_FE: usize = RANDOMNESS_LEN_FE + (V + LOG_LIFETIME) * DIGEST_S
pub type Poseidon16History = Vec<([F; 16], [F; 8])>;

fn poseidon16_compress_with_trace(a: &Digest, b: &Digest, poseidon_16_trace: &mut Vec<([F; 16], [F; 8])>) -> Digest {
let input: [F; 16] = [*a, *b].concat().try_into().unwrap();
let mut input = [F::default(); 16];
input[..8].copy_from_slice(a);
input[8..].copy_from_slice(b);
let output = poseidon16_compress(input);
poseidon_16_trace.push((input, output));
output
Expand Down
29 changes: 16 additions & 13 deletions crates/xmss/src/wots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,22 @@ pub fn wots_encode_with_poseidon_trace(
// ensures uniformity of encoding
return None;
}
let all_indices: Vec<_> = compressed
.iter()
.flat_map(|kb| to_little_endian_bits(kb.to_usize(), 24))
.collect::<Vec<_>>()
.chunks_exact(W)
.take(V + V_GRINDING)
.map(|chunk| {
chunk
.iter()
.enumerate()
.fold(0u8, |acc, (i, &bit)| acc | (u8::from(bit) << i))
})
.collect();
// 8 field elements * 24 bits each = 192 bits, packed into W=3-bit chunks.
let mut bits = [false; DIGEST_SIZE * 24];
let mut bit_idx = 0;
for kb in &compressed {
for b in to_little_endian_bits(kb.to_usize(), 24) {
bits[bit_idx] = b;
bit_idx += 1;
}
}
let mut all_indices = [0u8; V + V_GRINDING];
for (i, chunk) in bits.chunks_exact(W).take(V + V_GRINDING).enumerate() {
all_indices[i] = chunk
.iter()
.enumerate()
.fold(0u8, |acc, (j, &bit)| acc | (u8::from(bit) << j));
}
is_valid_encoding(&all_indices).then(|| all_indices[..V].try_into().unwrap())
}

Expand Down
6 changes: 3 additions & 3 deletions crates/xmss/src/xmss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ pub fn xmss_key_gen(
let left = if left_idx >= prev_base && left_idx <= prev_top {
prev[(left_idx - prev_base) as usize]
} else {
assert!(left_idx < 1u64 << 32);
debug_assert!(left_idx < 1u64 << 32);
gen_random_node(&seed, level - 1, left_idx as u32)
};
let right = if right_idx >= prev_base && right_idx <= prev_top {
prev[(right_idx - prev_base) as usize]
} else {
assert!(right_idx < 1u64 << 32);
debug_assert!(right_idx < 1u64 << 32);
gen_random_node(&seed, level - 1, right_idx as u32)
};
compress(&perm, [left, right])
Expand Down Expand Up @@ -189,7 +189,7 @@ pub fn xmss_verify_with_poseidon_trace(
message: &[F; MESSAGE_LEN_FE],
signature: &XmssSignature,
) -> Result<Poseidon16History, XmssVerifyError> {
let mut poseidon_16_trace = Vec::new();
let mut poseidon_16_trace = Vec::with_capacity(NUM_CHAIN_HASHES + V + LOG_LIFETIME);
let truncated_merkle_root = pub_key.merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap();
let wots_public_key = signature
.wots_signature
Expand Down