Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions src/lean_spec/subspecs/xmss/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,32 @@ def verify(
public_key: PublicKey,
epoch: "Uint64",
message: bytes,
scheme: GeneralizedXmssScheme,
scheme: "GeneralizedXmssScheme",
) -> bool:
"""Verify the signature using XMSS verification algorithm."""
"""
Verify the signature using XMSS verification algorithm.

This is a convenience method that delegates to `scheme.verify()`.

Invalid or malformed signatures return `False`.

Expected exceptions:
- `ValueError` for invalid epochs,
- `IndexError` for malformed signatures
are caught and converted to `False`.

Args:
public_key: The public key to verify against.
epoch: The epoch the signature corresponds to.
message: The message that was supposedly signed.
scheme: The XMSS scheme instance to use for verification.

Returns:
`True` if the signature is valid, `False` otherwise.
"""
try:
return scheme.verify(public_key, epoch, message, self)
except Exception:
except (ValueError, IndexError):
return False


Expand Down
57 changes: 27 additions & 30 deletions src/lean_spec/subspecs/xmss/hypercube.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,30 +33,29 @@

import bisect
import math
from dataclasses import dataclass
from functools import lru_cache
from itertools import accumulate
from typing import List, Tuple

from lean_spec.types import StrictBaseModel

MAX_DIMENSION = 100
"""The maximum dimension `v` for which layer sizes will be precomputed."""


class LayerInfo(StrictBaseModel):
@dataclass(frozen=True, slots=True)
class LayerInfo:
"""
A data structure to store precomputed sizes and cumulative sums for the
layers of a single hypercube configuration (fixed `w` and `v`).
Precomputed sizes and cumulative sums for a hypercube configuration.

This object makes subsequent calculations, like finding the total size of a
range of layers, highly efficient.
This immutable data structure enables O(1) lookups for layer sizes and
range sums, which is critical for efficient hypercube mapping.
"""

sizes: List[int]
"""A list where `sizes[d]` is the number of vertices in layer `d`."""
prefix_sums: List[int]
sizes: tuple[int, ...]
"""Tuple where `sizes[d]` is the number of vertices in layer `d`."""

prefix_sums: tuple[int, ...]
"""
A list where `prefix_sums[d]` is the cumulative number of vertices from
Tuple where `prefix_sums[d]` is the cumulative number of vertices from
layer 0 up to and including layer `d`.

Mathematically: `prefix_sums[d] = sizes[0] + ... + sizes[d]`.
Expand Down Expand Up @@ -125,7 +124,7 @@ def _calculate_layer_size(w: int, v: int, d: int) -> int:


@lru_cache(maxsize=None)
def prepare_layer_info(w: int) -> List[LayerInfo]:
def prepare_layer_info(w: int) -> tuple[LayerInfo, ...]:
"""
Precomputes and caches layer information using a direct combinatorial formula.

Expand All @@ -138,24 +137,25 @@ def prepare_layer_info(w: int) -> List[LayerInfo]:
w: The base of the hypercube.

Returns:
A list where `list[v]` is a `LayerInfo` object for a `v`-dim hypercube.
A tuple where `tuple[v]` is a `LayerInfo` object for a `v`-dim hypercube.
"""
all_info = [LayerInfo(sizes=[], prefix_sums=[])] * (MAX_DIMENSION + 1)
# Initialize with empty placeholder for index 0
all_info: list[LayerInfo] = [LayerInfo(sizes=(), prefix_sums=())] * (MAX_DIMENSION + 1)

for v in range(1, MAX_DIMENSION + 1):
# The maximum possible distance `d` in a v-dimensional hypercube.
max_d = (w - 1) * v

# Directly compute the size of each layer using the helper function.
sizes = [_calculate_layer_size(w, v, d) for d in range(max_d + 1)]
sizes = tuple(_calculate_layer_size(w, v, d) for d in range(max_d + 1))

# Compute the cumulative sums from the list of sizes.
prefix_sums = list(accumulate(sizes))
# Compute the cumulative sums from the tuple of sizes.
prefix_sums = tuple(accumulate(sizes))

# Store the complete layer info for the current dimension `v`.
all_info[v] = LayerInfo(sizes=sizes, prefix_sums=prefix_sums)

return all_info
return tuple(all_info)


def get_layer_size(w: int, v: int, d: int) -> int:
Expand All @@ -168,7 +168,7 @@ def hypercube_part_size(w: int, v: int, d: int) -> int:
return prepare_layer_info(w)[v].prefix_sums[d]


def hypercube_find_layer(w: int, v: int, x: int) -> Tuple[int, int]:
def hypercube_find_layer(w: int, v: int, x: int) -> tuple[int, int]:
"""
Given a global index `x`, finds its layer `d` and local offset `remainder`.

Expand Down Expand Up @@ -203,7 +203,7 @@ def hypercube_find_layer(w: int, v: int, x: int) -> Tuple[int, int]:
return d, remainder


def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]:
def map_to_vertex(w: int, v: int, d: int, x: int) -> list[int]:
"""
Maps an integer index `x` to a unique vertex in a specific hypercube layer.

Expand All @@ -228,7 +228,7 @@ def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]:
if x >= layer_size:
raise ValueError("Index x is out of bounds for the given layer.")

vertex: List[int] = []
vertex: list[int] = []
# Track remaining distance and index.
d_curr, x_curr = d, x

Expand All @@ -239,20 +239,17 @@ def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]:

# This loop finds which block of sub-hypercubes the index `x_curr` falls into.
#
# It skips over full blocks by subtracting their size
# from `x_curr` until the correct one is found.
ji = -1 # Sentinel value
# It skips over full blocks by subtracting their size from `x_curr` until found.
ji = None
range_start = max(0, d_curr - (w - 1) * dim_remaining)
for j in range(range_start, min(w, d_curr + 1)):
count = prev_dim_layer_info.sizes[d_curr - j]
if x_curr >= count:
x_curr -= count
else:
# Found the correct block.
if x_curr < count:
ji = j
break
x_curr -= count

if ji == -1:
if ji is None:
raise RuntimeError("Internal logic error: failed to find coordinate")

# Convert the block's distance contribution `ji` to a coordinate `ai`.
Expand Down
Loading
Loading