diff --git a/src/lean_spec/subspecs/xmss/containers.py b/src/lean_spec/subspecs/xmss/containers.py index 4ad096ff..662494fc 100644 --- a/src/lean_spec/subspecs/xmss/containers.py +++ b/src/lean_spec/subspecs/xmss/containers.py @@ -72,12 +72,32 @@ def verify( public_key: PublicKey, epoch: "Uint64", message: bytes, - scheme: GeneralizedXmssScheme, + scheme: "GeneralizedXmssScheme", ) -> bool: - """Verify the signature using XMSS verification algorithm.""" + """ + Verify the signature using XMSS verification algorithm. + + This is a convenience method that delegates to `scheme.verify()`. + + Invalid or malformed signatures return `False`. + + Expected exceptions: + - `ValueError` for invalid epochs, + - `IndexError` for malformed signatures + are caught and converted to `False`. + + Args: + public_key: The public key to verify against. + epoch: The epoch the signature corresponds to. + message: The message that was supposedly signed. + scheme: The XMSS scheme instance to use for verification. + + Returns: + `True` if the signature is valid, `False` otherwise. + """ try: return scheme.verify(public_key, epoch, message, self) - except Exception: + except (ValueError, IndexError): return False diff --git a/src/lean_spec/subspecs/xmss/hypercube.py b/src/lean_spec/subspecs/xmss/hypercube.py index 420d2cea..f1bbf86d 100644 --- a/src/lean_spec/subspecs/xmss/hypercube.py +++ b/src/lean_spec/subspecs/xmss/hypercube.py @@ -33,30 +33,29 @@ import bisect import math +from dataclasses import dataclass from functools import lru_cache from itertools import accumulate -from typing import List, Tuple - -from lean_spec.types import StrictBaseModel MAX_DIMENSION = 100 """The maximum dimension `v` for which layer sizes will be precomputed.""" -class LayerInfo(StrictBaseModel): +@dataclass(frozen=True, slots=True) +class LayerInfo: """ - A data structure to store precomputed sizes and cumulative sums for the - layers of a single hypercube configuration (fixed `w` and `v`). + Precomputed sizes and cumulative sums for a hypercube configuration. - This object makes subsequent calculations, like finding the total size of a - range of layers, highly efficient. + This immutable data structure enables O(1) lookups for layer sizes and + range sums, which is critical for efficient hypercube mapping. """ - sizes: List[int] - """A list where `sizes[d]` is the number of vertices in layer `d`.""" - prefix_sums: List[int] + sizes: tuple[int, ...] + """Tuple where `sizes[d]` is the number of vertices in layer `d`.""" + + prefix_sums: tuple[int, ...] """ - A list where `prefix_sums[d]` is the cumulative number of vertices from + Tuple where `prefix_sums[d]` is the cumulative number of vertices from layer 0 up to and including layer `d`. Mathematically: `prefix_sums[d] = sizes[0] + ... + sizes[d]`. @@ -125,7 +124,7 @@ def _calculate_layer_size(w: int, v: int, d: int) -> int: @lru_cache(maxsize=None) -def prepare_layer_info(w: int) -> List[LayerInfo]: +def prepare_layer_info(w: int) -> tuple[LayerInfo, ...]: """ Precomputes and caches layer information using a direct combinatorial formula. @@ -138,24 +137,25 @@ def prepare_layer_info(w: int) -> List[LayerInfo]: w: The base of the hypercube. Returns: - A list where `list[v]` is a `LayerInfo` object for a `v`-dim hypercube. + A tuple where `tuple[v]` is a `LayerInfo` object for a `v`-dim hypercube. """ - all_info = [LayerInfo(sizes=[], prefix_sums=[])] * (MAX_DIMENSION + 1) + # Initialize with empty placeholder for index 0 + all_info: list[LayerInfo] = [LayerInfo(sizes=(), prefix_sums=())] * (MAX_DIMENSION + 1) for v in range(1, MAX_DIMENSION + 1): # The maximum possible distance `d` in a v-dimensional hypercube. max_d = (w - 1) * v # Directly compute the size of each layer using the helper function. - sizes = [_calculate_layer_size(w, v, d) for d in range(max_d + 1)] + sizes = tuple(_calculate_layer_size(w, v, d) for d in range(max_d + 1)) - # Compute the cumulative sums from the list of sizes. - prefix_sums = list(accumulate(sizes)) + # Compute the cumulative sums from the tuple of sizes. + prefix_sums = tuple(accumulate(sizes)) # Store the complete layer info for the current dimension `v`. all_info[v] = LayerInfo(sizes=sizes, prefix_sums=prefix_sums) - return all_info + return tuple(all_info) def get_layer_size(w: int, v: int, d: int) -> int: @@ -168,7 +168,7 @@ def hypercube_part_size(w: int, v: int, d: int) -> int: return prepare_layer_info(w)[v].prefix_sums[d] -def hypercube_find_layer(w: int, v: int, x: int) -> Tuple[int, int]: +def hypercube_find_layer(w: int, v: int, x: int) -> tuple[int, int]: """ Given a global index `x`, finds its layer `d` and local offset `remainder`. @@ -203,7 +203,7 @@ def hypercube_find_layer(w: int, v: int, x: int) -> Tuple[int, int]: return d, remainder -def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]: +def map_to_vertex(w: int, v: int, d: int, x: int) -> list[int]: """ Maps an integer index `x` to a unique vertex in a specific hypercube layer. @@ -228,7 +228,7 @@ def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]: if x >= layer_size: raise ValueError("Index x is out of bounds for the given layer.") - vertex: List[int] = [] + vertex: list[int] = [] # Track remaining distance and index. d_curr, x_curr = d, x @@ -239,20 +239,17 @@ def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]: # This loop finds which block of sub-hypercubes the index `x_curr` falls into. # - # It skips over full blocks by subtracting their size - # from `x_curr` until the correct one is found. - ji = -1 # Sentinel value + # It skips over full blocks by subtracting their size from `x_curr` until found. + ji = None range_start = max(0, d_curr - (w - 1) * dim_remaining) for j in range(range_start, min(w, d_curr + 1)): count = prev_dim_layer_info.sizes[d_curr - j] - if x_curr >= count: - x_curr -= count - else: - # Found the correct block. + if x_curr < count: ji = j break + x_curr -= count - if ji == -1: + if ji is None: raise RuntimeError("Internal logic error: failed to find coordinate") # Convert the block's distance contribution `ji` to a coordinate `ai`. diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 26679d20..7a7916ad 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -8,8 +8,6 @@ from __future__ import annotations -from typing import List, Tuple - from pydantic import model_validator from lean_spec.config import LEAN_ENV @@ -28,13 +26,13 @@ from .containers import PublicKey, SecretKey, Signature from .prf import PROD_PRF, TEST_PRF, Prf from .rand import PROD_RAND, TEST_RAND, Rand -from .subtree import HashSubTree +from .subtree import HashSubTree, combined_path, verify_path from .tweak_hash import ( PROD_TWEAK_HASHER, TEST_TWEAK_HASHER, TweakHasher, ) -from .types import HashDigestVector +from .types import HashDigestList, HashDigestVector from .utils import expand_activation_time @@ -42,9 +40,8 @@ class GeneralizedXmssScheme(StrictBaseModel): """ Instance of the Generalized XMSS signature scheme for a given config. - This class enforces strict type checking to ensure only approved component - implementations are used. Subclasses of the base component types (such as - SeededPrf or SeededRand) are explicitly rejected. + This class holds the configuration and component instances needed to + perform key generation, signing, and verification operations. """ config: XmssConfig @@ -64,25 +61,22 @@ class GeneralizedXmssScheme(StrictBaseModel): @model_validator(mode="after") def enforce_strict_types(self) -> "GeneralizedXmssScheme": - """Validates that only exact approved types are used (rejects subclasses).""" - checks = { - "config": XmssConfig, - "prf": Prf, - "hasher": TweakHasher, - "encoder": TargetSumEncoder, - "rand": Rand, - } - for field, expected in checks.items(): - if type(getattr(self, field)) is not expected: - raise TypeError( - f"{field} must be exactly {expected.__name__}, " - f"got {type(getattr(self, field)).__name__}" - ) + """Reject subclasses to prevent type confusion attacks.""" + if type(self.config) is not XmssConfig: + raise TypeError("config must be exactly XmssConfig, not a subclass") + if type(self.prf) is not Prf: + raise TypeError("prf must be exactly Prf, not a subclass") + if type(self.hasher) is not TweakHasher: + raise TypeError("hasher must be exactly TweakHasher, not a subclass") + if type(self.encoder) is not TargetSumEncoder: + raise TypeError("encoder must be exactly TargetSumEncoder, not a subclass") + if type(self.rand) is not Rand: + raise TypeError("rand must be exactly Rand, not a subclass") return self def key_gen( self, activation_epoch: Uint64, num_active_epochs: Uint64 - ) -> Tuple[PublicKey, SecretKey]: + ) -> tuple[PublicKey, SecretKey]: """ Generates a new cryptographic key pair for a specified range of epochs. @@ -184,7 +178,7 @@ def key_gen( ) # Collect roots for building the top tree. - bottom_tree_roots: List[HashDigestVector] = [ + bottom_tree_roots: list[HashDigestVector] = [ left_bottom_tree.root(), right_bottom_tree.root(), ] @@ -277,11 +271,9 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature: config = self.config # Verify that the secret key is currently active for the requested signing epoch. - # Note: range() requires int, so we convert only for the range check + epoch_int = int(epoch) activation_int = int(sk.activation_epoch) - num_epochs_int = int(sk.num_active_epochs) - active_range = range(activation_int, activation_int + num_epochs_int) - if int(epoch) not in active_range: + if not (activation_int <= epoch_int < activation_int + int(sk.num_active_epochs)): raise ValueError("Key is not active for the specified epoch.") # Verify that the epoch is within the prepared interval (covered by loaded bottom trees). @@ -290,11 +282,13 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature: # signed without computing additional bottom trees. # # If the epoch is outside this range, we need to slide the window forward. - prepared_interval = self.get_prepared_interval(sk) - if int(epoch) not in prepared_interval: + leafs_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) + prepared_start = int(sk.left_bottom_tree_index) * leafs_per_bottom_tree + prepared_end = prepared_start + 2 * leafs_per_bottom_tree + if not (prepared_start <= epoch_int < prepared_end): raise ValueError( f"Epoch {epoch} is outside the prepared interval " - f"[{prepared_interval.start}, {prepared_interval.stop}). " + f"[{prepared_start}, {prepared_end}). " f"Call advance_preparation() to slide the window forward." ) @@ -328,7 +322,7 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature: raise RuntimeError("Encoding is broken: returned too many or too few chunks.") # Compute the one-time signature hashes based on the codeword. - ots_hashes: List[HashDigestVector] = [] + ots_hashes: list[HashDigestVector] = [] for chain_index, steps in enumerate(codeword): # Derive the secret start of the current chain using the master PRF key. start_digest = self.prf.apply(sk.prf_key, epoch, Uint64(chain_index)) @@ -350,16 +344,9 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature: # With top-bottom tree traversal, we use combined_path to merge paths from # the bottom tree and top tree. - # Determine which bottom tree contains this epoch. - leafs_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - boundary = (sk.left_bottom_tree_index + Uint64(1)) * Uint64(leafs_per_bottom_tree) - - if epoch < boundary: - # Use left bottom tree - bottom_tree = sk.left_bottom_tree - else: - # Use right bottom tree - bottom_tree = sk.right_bottom_tree + # Determine which bottom tree contains this epoch (reuse leafs_per_bottom_tree from above). + boundary = (int(sk.left_bottom_tree_index) + 1) * leafs_per_bottom_tree + bottom_tree = sk.left_bottom_tree if epoch_int < boundary else sk.right_bottom_tree # Ensure bottom tree exists if bottom_tree is None: @@ -370,16 +357,12 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature: ) # Generate the combined authentication path - from .subtree import combined_path - path = combined_path(sk.top_tree, bottom_tree, epoch) # Assemble and return the final signature, which contains: # - The OTS, # - The Merkle path, # - The randomness `rho` needed for verification. - from .types import HashDigestList - return Signature(path=path, rho=rho, hashes=HashDigestList(data=ots_hashes)) def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) -> bool: @@ -437,7 +420,7 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) - return False # Reconstruct the one-time public key (the list of chain endpoints). - chain_ends: List[HashDigestVector] = [] + chain_ends: list[HashDigestVector] = [] for chain_index, xi in enumerate(codeword): # The signature provides `start_digest`, which is the hash value after `xi` steps. start_digest = sig.hashes[chain_index] @@ -460,8 +443,6 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) - # - Hashes the `chain_ends` to get the leaf node for the epoch, # - Uses the `opening` path from the signature to compute a candidate root. # - It returns true if and only if this candidate root matches the public key's root. - from .subtree import verify_path - return verify_path( hasher=self.hasher, parameter=pk.parameter, @@ -510,9 +491,8 @@ def get_prepared_interval(self, sk: SecretKey) -> range: ValueError: If the secret key is missing top-bottom tree structures. """ leafs_per_bottom_tree = 1 << (self.config.LOG_LIFETIME // 2) - start = int(sk.left_bottom_tree_index * Uint64(leafs_per_bottom_tree)) - end = start + (2 * leafs_per_bottom_tree) - return range(start, end) + start = int(sk.left_bottom_tree_index) * leafs_per_bottom_tree + return range(start, start + 2 * leafs_per_bottom_tree) def advance_preparation(self, sk: SecretKey) -> SecretKey: """ @@ -540,26 +520,23 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: ValueError: If advancing would exceed the activation interval. """ leafs_per_bottom_tree = 1 << (self.config.LOG_LIFETIME // 2) + left_index = int(sk.left_bottom_tree_index) # Check if advancing would exceed the activation interval - next_prepared_end_epoch = int( - sk.left_bottom_tree_index * Uint64(leafs_per_bottom_tree) - + Uint64(3 * leafs_per_bottom_tree) - ) - activation_interval = self.get_activation_interval(sk) - if next_prepared_end_epoch > activation_interval.stop: + next_prepared_end_epoch = (left_index + 3) * leafs_per_bottom_tree + activation_end = int(sk.activation_epoch) + int(sk.num_active_epochs) + if next_prepared_end_epoch > activation_end: # Nothing to do - we're already at the end of the activation interval return sk # Compute the next bottom tree (the one after the current right tree) - new_right_tree_index = sk.left_bottom_tree_index + Uint64(2) new_right_bottom_tree = HashSubTree.from_prf_key( prf=self.prf, hasher=self.hasher, rand=self.rand, config=self.config, prf_key=sk.prf_key, - bottom_tree_index=new_right_tree_index, + bottom_tree_index=Uint64(left_index + 2), parameter=sk.parameter, ) @@ -568,7 +545,7 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: update={ "left_bottom_tree": sk.right_bottom_tree, "right_bottom_tree": new_right_bottom_tree, - "left_bottom_tree_index": sk.left_bottom_tree_index + Uint64(1), + "left_bottom_tree_index": Uint64(left_index + 1), } ) diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py index 2a26ae2e..23d1a72b 100644 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -29,8 +29,6 @@ from __future__ import annotations -from typing import List - from pydantic import model_validator from lean_spec.subspecs.xmss.poseidon import ( @@ -67,17 +65,14 @@ class MessageHasher(StrictBaseModel): @model_validator(mode="after") def enforce_strict_types(self) -> "MessageHasher": - """Validates that only exact approved types are used (rejects subclasses).""" - checks = {"config": XmssConfig, "poseidon": PoseidonXmss} - for field, expected in checks.items(): - if type(getattr(self, field)) is not expected: - raise TypeError( - f"{field} must be exactly {expected.__name__}, " - f"got {type(getattr(self, field)).__name__}" - ) + """Reject subclasses to prevent type confusion attacks.""" + if type(self.config) is not XmssConfig: + raise TypeError("config must be exactly XmssConfig, not a subclass") + if type(self.poseidon) is not PoseidonXmss: + raise TypeError("poseidon must be exactly PoseidonXmss, not a subclass") return self - def encode_message(self, message: bytes) -> List[Fp]: + def encode_message(self, message: bytes) -> list[Fp]: """ Encodes a 32-byte message into a list of field elements. @@ -92,7 +87,7 @@ def encode_message(self, message: bytes) -> List[Fp]: # Decompose the integer into a list of field elements (base-P). return int_to_base_p(acc, self.config.MSG_LEN_FE) - def encode_epoch(self, epoch: Uint64) -> List[Fp]: + def encode_epoch(self, epoch: Uint64) -> list[Fp]: """ Encodes the epoch and a domain separator prefix into field elements. @@ -106,7 +101,7 @@ def encode_epoch(self, epoch: Uint64) -> List[Fp]: # Decompose the integer into its base-P representation. return int_to_base_p(acc, self.config.TWEAK_LEN_FE) - def _map_into_hypercube_part(self, field_elements: List[Fp]) -> List[int]: + def _map_into_hypercube_part(self, field_elements: list[Fp]) -> list[int]: """ Maps a long, pseudorandom digest to a unique vertex within the top layers of the signature hypercube. @@ -153,7 +148,7 @@ def apply( epoch: Uint64, rho: Randomness, message: bytes, - ) -> List[int]: + ) -> list[int]: """ Applies the full "Top Level" message hash and mapping procedure. @@ -184,21 +179,16 @@ def apply( epoch_fe = self.encode_epoch(epoch) # Iteratively call Poseidon2 to generate a long hash output. - poseidon_outputs: List[Fp] = [] + # + # The base input (rho || P || epoch || message) is reused each iteration. + base_input = list(rho.data) + list(parameter.data) + epoch_fe + message_fe + poseidon_outputs: list[Fp] = [] + output_len = self.config.POS_OUTPUT_LEN_PER_INV_FE for i in range(self.config.POS_INVOCATIONS): - # Use the iteration number as a domain separator for each hash call. - iteration_separator = [Fp(value=i)] - - # The input is: rho || P || epoch || message || iteration. - combined_input = ( - list(rho.data) + list(parameter.data) + epoch_fe + message_fe + iteration_separator - ) - - # Hash the combined input using Poseidon2 compression mode. - iteration_output = self.poseidon.compress( - combined_input, 24, self.config.POS_OUTPUT_LEN_PER_INV_FE + # Append iteration number as domain separator and hash. + poseidon_outputs.extend( + self.poseidon.compress(base_input + [Fp(value=i)], 24, output_len) ) - poseidon_outputs.extend(iteration_output) # Map the final aggregated list of field elements into a hypercube vertex. return self._map_into_hypercube_part(poseidon_outputs) diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 3cbff51c..e57a2431 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -22,8 +22,6 @@ from __future__ import annotations -from typing import List - from pydantic import model_validator from lean_spec.types import StrictBaseModel @@ -49,17 +47,14 @@ class PoseidonXmss(StrictBaseModel): @model_validator(mode="after") def enforce_strict_types(self) -> "PoseidonXmss": - """Validates that only exact approved types are used (rejects subclasses).""" - checks = {"params16": Poseidon2Params, "params24": Poseidon2Params} - for field, expected in checks.items(): - if type(getattr(self, field)) is not expected: - raise TypeError( - f"{field} must be exactly {expected.__name__}, " - f"got {type(getattr(self, field)).__name__}" - ) + """Reject subclasses to prevent type confusion attacks.""" + if type(self.params16) is not Poseidon2Params: + raise TypeError("params16 must be exactly Poseidon2Params, not a subclass") + if type(self.params24) is not Poseidon2Params: + raise TypeError("params24 must be exactly Poseidon2Params, not a subclass") return self - def compress(self, input_vec: List[Fp], width: int, output_len: int) -> List[Fp]: + def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]: """ Implements the Poseidon2 hash in **compression mode**. @@ -93,9 +88,8 @@ def compress(self, input_vec: List[Fp], width: int, output_len: int) -> List[Fp] assert width in (16, 24), "Width must be 16 or 24" params = self.params16 if width == 16 else self.params24 - # Create a fixed-width buffer and copy the input, padding with zeros. - padded_input = [Fp(value=0)] * width - padded_input[: len(input_vec)] = input_vec + # Create a padded input by extending with zeros to match the state width. + padded_input = list(input_vec) + [Fp(value=0)] * (width - len(input_vec)) # Apply the Poseidon2 permutation. permuted_state = permute(padded_input, params) @@ -106,7 +100,7 @@ def compress(self, input_vec: List[Fp], width: int, output_len: int) -> List[Fp] # Truncate the state to the desired output length and return. return final_state[:output_len] - def safe_domain_separator(self, lengths: List[int], capacity_len: int) -> List[Fp]: + def safe_domain_separator(self, lengths: list[int], capacity_len: int) -> list[Fp]: """ Computes a unique domain separator for the sponge construction (SAFE API). @@ -139,11 +133,11 @@ def safe_domain_separator(self, lengths: List[int], capacity_len: int) -> List[F def sponge( self, - input_vec: List[Fp], - capacity_value: List[Fp], + input_vec: list[Fp], + capacity_value: list[Fp], output_len: int, width: int, - ) -> List[Fp]: + ) -> list[Fp]: """ Implements the Poseidon2 hash using the **sponge construction**. @@ -203,7 +197,7 @@ def sponge( state = permute(state, params) # Squeeze the output until enough elements have been generated. - output: List[Fp] = [] + output: list[Fp] = [] while len(output) < output_len: # Extract the rate part of the state as output. output.extend(state[:rate]) diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index 7a8c3fd5..babb7463 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -80,6 +80,28 @@ """ +def _bytes_to_field_elements(data: bytes, count: int) -> list[Fp]: + """ + Convert PRF output bytes into a list of field elements. + + Each field element is derived from `PRF_BYTES_PER_FE` bytes, + interpreted as a big-endian integer and reduced modulo the field prime. + + The extra bits provide statistical uniformity. + + Args: + data: Raw bytes from SHAKE128 output. Must be exactly `count * PRF_BYTES_PER_FE` bytes. + count: Number of field elements to extract. + + Returns: + List of `count` field elements. + """ + return [ + Fp(value=int.from_bytes(data[i : i + PRF_BYTES_PER_FE], "big")) + for i in range(0, count * PRF_BYTES_PER_FE, PRF_BYTES_PER_FE) + ] + + class Prf(StrictBaseModel): """An instance of the SHAKE128-based PRF for a given config.""" @@ -88,9 +110,9 @@ class Prf(StrictBaseModel): @model_validator(mode="after") def enforce_strict_types(self) -> "Prf": - """Validates that only exact approved types are used (rejects subclasses).""" + """Reject subclasses to prevent type confusion attacks.""" if type(self.config) is not XmssConfig: - raise TypeError(f"config must be exactly XmssConfig, got {type(self.config).__name__}") + raise TypeError("config must be exactly XmssConfig, not a subclass") return self def key_gen(self) -> PRFKey: @@ -153,22 +175,7 @@ def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> HashDigestVe prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) # Convert the raw byte output into a list of field elements. - # - # For each required field element, this performs the following steps: - # - Slice an 8-byte (64-bit) chunk from the `prf_output_bytes`. - # - Convert that chunk from a big-endian byte representation to an integer. - # - Create a field element from the integer (the Fp constructor handles the modulo). - return HashDigestVector( - data=[ - Fp( - value=int.from_bytes( - prf_output_bytes[i * PRF_BYTES_PER_FE : (i + 1) * PRF_BYTES_PER_FE], - "big", - ) - ) - for i in range(config.HASH_LEN_FE) - ] - ) + return HashDigestVector(data=_bytes_to_field_elements(prf_output_bytes, config.HASH_LEN_FE)) def get_randomness( self, key: PRFKey, epoch: Uint64, message: bytes, counter: Uint64 @@ -225,17 +232,7 @@ def get_randomness( prf_output_bytes = hashlib.shake_128(input_data).digest(num_bytes_to_read) # Convert to field elements and wrap in Randomness - return Randomness( - data=[ - Fp( - value=int.from_bytes( - prf_output_bytes[i * PRF_BYTES_PER_FE : (i + 1) * PRF_BYTES_PER_FE], - "big", - ) - ) - for i in range(config.RAND_LEN_FE) - ] - ) + return Randomness(data=_bytes_to_field_elements(prf_output_bytes, config.RAND_LEN_FE)) PROD_PRF = Prf(config=PROD_CONFIG) diff --git a/src/lean_spec/subspecs/xmss/rand.py b/src/lean_spec/subspecs/xmss/rand.py index bf0e0a4d..488b47a5 100644 --- a/src/lean_spec/subspecs/xmss/rand.py +++ b/src/lean_spec/subspecs/xmss/rand.py @@ -1,7 +1,6 @@ """Random data generator for the XMSS signature scheme.""" import secrets -from typing import List from pydantic import model_validator @@ -20,12 +19,12 @@ class Rand(StrictBaseModel): @model_validator(mode="after") def enforce_strict_types(self) -> "Rand": - """Validates that only exact approved types are used (rejects subclasses).""" + """Reject subclasses to prevent type confusion attacks.""" if type(self.config) is not XmssConfig: - raise TypeError(f"config must be exactly XmssConfig, got {type(self.config).__name__}") + raise TypeError("config must be exactly XmssConfig, not a subclass") return self - def field_elements(self, length: int) -> List[Fp]: + def field_elements(self, length: int) -> list[Fp]: """Generates a random list of field elements.""" # For each element, generate a secure random integer in the range [0, P-1]. return [Fp(value=secrets.randbelow(P)) for _ in range(length)] diff --git a/src/lean_spec/subspecs/xmss/subtree.py b/src/lean_spec/subspecs/xmss/subtree.py index 34c3650a..393e59ca 100644 --- a/src/lean_spec/subspecs/xmss/subtree.py +++ b/src/lean_spec/subspecs/xmss/subtree.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING from lean_spec.types import Uint64 from lean_spec.types.container import Container @@ -107,7 +107,7 @@ def new( depth: int, start_index: Uint64, parameter: Parameter, - lowest_layer_nodes: List[HashDigestVector], + lowest_layer_nodes: list[HashDigestVector], ) -> HashSubTree: """ Builds a new sparse Merkle subtree starting from a specified layer. @@ -154,7 +154,7 @@ def new( ) # Initialize with padded input layer. - layers: List[HashTreeLayer] = [] + layers: list[HashTreeLayer] = [] current = get_padded_layer(rand, lowest_layer_nodes, start_index) layers.append(current) @@ -162,14 +162,16 @@ def new( for level in range(lowest_layer, depth): parent_start = current.start_index // Uint64(2) - # Hash each pair of siblings into their parent. + # Hash each pair of siblings into their parent using zip for cleaner indexing. + parent_start_int = int(parent_start) + node_pairs = zip(current.nodes[::2], current.nodes[1::2], strict=True) parents = [ hasher.apply( parameter, - TreeTweak(level=level + 1, index=Uint64(int(parent_start) + i)), - [current.nodes[2 * i], current.nodes[2 * i + 1]], + TreeTweak(level=level + 1, index=Uint64(parent_start_int + i)), + [left, right], ) - for i in range(len(current.nodes) // 2) + for i, (left, right) in enumerate(node_pairs) ] # Pad and store the new layer. @@ -190,7 +192,7 @@ def new_top_tree( depth: int, start_bottom_tree_index: Uint64, parameter: Parameter, - bottom_tree_roots: List[HashDigestVector], + bottom_tree_roots: list[HashDigestVector], ) -> HashSubTree: """ Constructs a top tree from the roots of bottom trees. @@ -246,7 +248,7 @@ def new_bottom_tree( depth: int, bottom_tree_index: Uint64, parameter: Parameter, - leaves: List[HashDigestVector], + leaves: list[HashDigestVector], ) -> HashSubTree: """ Constructs a single bottom tree from leaf hashes. @@ -372,11 +374,11 @@ def from_prf_key( end_epoch = start_epoch + Uint64(leafs_per_bottom_tree) # Generate leaf hashes for all epochs in this bottom tree. - leaf_hashes: List[HashDigestVector] = [] + leaf_hashes: list[HashDigestVector] = [] for epoch in range(int(start_epoch), int(end_epoch)): # For each epoch, compute the one-time public key (chain endpoints). - chain_ends: List[HashDigestVector] = [] + chain_ends: list[HashDigestVector] = [] for chain_index in range(config.DIMENSION): # Derive the secret start of the chain from the PRF key. @@ -456,7 +458,7 @@ def path(self, position: Uint64) -> HashTreeOpening: raise ValueError(f"Position {position} out of bounds.") # Collect sibling at each layer (except root). - siblings: List[HashDigestVector] = [] + siblings: list[HashDigestVector] = [] pos = position # Iterate over all layers except the last (root). @@ -547,7 +549,7 @@ def verify_path( parameter: Parameter, root: HashDigestVector, position: Uint64, - leaf_parts: List[HashDigestVector], + leaf_parts: list[HashDigestVector], opening: HashTreeOpening, ) -> bool: """ @@ -608,11 +610,7 @@ def verify_path( # Walk up: hash current with each sibling. for level, sibling in enumerate(opening.siblings): # Left child has even position, right child has odd. - if pos % 2 == 0: - left, right = current, sibling - else: - left, right = sibling, current - + left, right = (current, sibling) if pos % 2 == 0 else (sibling, current) pos //= 2 # Parent position. current = hasher.apply( parameter, diff --git a/src/lean_spec/subspecs/xmss/target_sum.py b/src/lean_spec/subspecs/xmss/target_sum.py index 955edfe4..4ab9f0ff 100644 --- a/src/lean_spec/subspecs/xmss/target_sum.py +++ b/src/lean_spec/subspecs/xmss/target_sum.py @@ -6,8 +6,6 @@ top of the message hash output. """ -from typing import List, Optional - from pydantic import model_validator from lean_spec.types import StrictBaseModel, Uint64 @@ -37,19 +35,16 @@ class TargetSumEncoder(StrictBaseModel): @model_validator(mode="after") def enforce_strict_types(self) -> "TargetSumEncoder": - """Validates that only exact approved types are used (rejects subclasses).""" - checks = {"config": XmssConfig, "message_hasher": MessageHasher} - for field, expected in checks.items(): - if type(getattr(self, field)) is not expected: - raise TypeError( - f"{field} must be exactly {expected.__name__}, " - f"got {type(getattr(self, field)).__name__}" - ) + """Reject subclasses to prevent type confusion attacks.""" + if type(self.config) is not XmssConfig: + raise TypeError("config must be exactly XmssConfig, not a subclass") + if type(self.message_hasher) is not MessageHasher: + raise TypeError("message_hasher must be exactly MessageHasher, not a subclass") return self def encode( self, parameter: Parameter, message: bytes, rho: Randomness, epoch: Uint64 - ) -> Optional[List[int]]: + ) -> list[int] | None: """ Encodes a message into a codeword if it meets the target sum criteria. diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py index b3b7ca8a..54a0181f 100644 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -26,8 +26,6 @@ from __future__ import annotations -from typing import List, Union - from pydantic import Field, model_validator from lean_spec.types import StrictBaseModel, Uint64 @@ -89,19 +87,14 @@ class TweakHasher(StrictBaseModel): @model_validator(mode="after") def enforce_strict_types(self) -> "TweakHasher": - """Validates that only exact approved types are used (rejects subclasses).""" - from .poseidon import PoseidonXmss - - checks = {"config": XmssConfig, "poseidon": PoseidonXmss} - for field, expected in checks.items(): - if type(getattr(self, field)) is not expected: - raise TypeError( - f"{field} must be exactly {expected.__name__}, " - f"got {type(getattr(self, field)).__name__}" - ) + """Reject subclasses to prevent type confusion attacks.""" + if type(self.config) is not XmssConfig: + raise TypeError("config must be exactly XmssConfig, not a subclass") + if type(self.poseidon) is not PoseidonXmss: + raise TypeError("poseidon must be exactly PoseidonXmss, not a subclass") return self - def _encode_tweak(self, tweak: Union[TreeTweak, ChainTweak], length: int) -> List[Fp]: + def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]: """ Encodes a structured tweak object into a list of field elements. @@ -148,8 +141,8 @@ def _encode_tweak(self, tweak: Union[TreeTweak, ChainTweak], length: int) -> Lis def apply( self, parameter: Parameter, - tweak: Union[TreeTweak, ChainTweak], - message_parts: List[HashDigestVector], + tweak: TreeTweak | ChainTweak, + message_parts: list[HashDigestVector], ) -> HashDigestVector: """ Applies the tweakable Poseidon2 hash function to a message. @@ -208,9 +201,7 @@ def apply( # # We use the robust sponge mode. # First, flatten the list of message parts into a single vector. - flattened_message: List[Fp] = [] - for part in message_parts: - flattened_message.extend(part.elements) + flattened_message = [elem for part in message_parts for elem in part.elements] input_vec = parameter.elements + encoded_tweak + flattened_message # Create a domain separator for the sponge mode based on the input dimensions. diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py index 0a1d205a..35bdd745 100644 --- a/src/lean_spec/subspecs/xmss/utils.py +++ b/src/lean_spec/subspecs/xmss/utils.py @@ -1,7 +1,5 @@ """Utility functions for the XMSS signature scheme.""" -from typing import List - from ...types.uint import Uint64 from ..koalabear import Fp, P from .rand import Rand @@ -9,7 +7,7 @@ def get_padded_layer( - rand: Rand, nodes: List[HashDigestVector], start_index: Uint64 + rand: Rand, nodes: list[HashDigestVector], start_index: Uint64 ) -> HashTreeLayer: """ Pads a layer of nodes with random hashes to simplify tree construction. @@ -28,7 +26,7 @@ def get_padded_layer( Returns: A new `HashTreeLayer` with the necessary padding applied. """ - nodes_with_padding: List[HashDigestVector] = [] + nodes_with_padding: list[HashDigestVector] = [] end_index = start_index + Uint64(len(nodes)) - Uint64(1) # Prepend random padding if the layer starts at an odd index. @@ -51,7 +49,7 @@ def get_padded_layer( ) -def int_to_base_p(value: int, num_limbs: int) -> List[Fp]: +def int_to_base_p(value: int, num_limbs: int) -> list[Fp]: """ Decomposes a large integer into a list of base-P field elements. @@ -65,7 +63,7 @@ def int_to_base_p(value: int, num_limbs: int) -> List[Fp]: Returns: A list of `num_limbs` field elements representing the integer. """ - limbs: List[Fp] = [] + limbs: list[Fp] = [] acc = value for _ in range(num_limbs): limbs.append(Fp(value=acc)) diff --git a/tests/lean_spec/subspecs/xmss/test_hypercube.py b/tests/lean_spec/subspecs/xmss/test_hypercube.py index cfc6ffd4..a5ead775 100644 --- a/tests/lean_spec/subspecs/xmss/test_hypercube.py +++ b/tests/lean_spec/subspecs/xmss/test_hypercube.py @@ -115,7 +115,7 @@ def test_prepare_layer_sizes_against_reference() -> None: for v in range(1, MAX_DIMENSION + 1): # Note: The reference implementation returns reversed layer sizes. # Layer `d` in our spec corresponds to sum `k = v*(w-1) - d`. - expected_sizes_reordered = list(reversed(expected_sizes_by_v[v])) + expected_sizes_reordered = tuple(reversed(expected_sizes_by_v[v])) actual_sizes = actual_info_by_v[v].sizes assert expected_sizes_reordered == actual_sizes