Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 161 additions & 44 deletions packages/testing/src/consensus_testing/keys.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""XMSS key management utilities for testing."""

from typing import NamedTuple, Optional
from __future__ import annotations

from typing import Any, NamedTuple, Optional

from lean_spec.subspecs.containers import Attestation, Signature
from lean_spec.subspecs.containers.slot import Slot
Expand All @@ -10,7 +12,7 @@
TEST_SIGNATURE_SCHEME,
GeneralizedXmssScheme,
)
from lean_spec.types import ValidatorIndex
from lean_spec.types import Uint64, ValidatorIndex


class KeyPair(NamedTuple):
Expand All @@ -23,11 +25,11 @@ class KeyPair(NamedTuple):
"""The validator's secret key (used for signing)."""


_KEY_CACHE: dict[tuple[int, int], KeyPair] = {}
_KEY_CACHE: dict[tuple[int, int, int, int], KeyPair] = {}
"""
Cache keys across tests to avoid regenerating them for the same validator/lifetime combo.

Key: (validator_index, num_active_epochs) -> KeyPair
Key: (validator_index, activation_epoch, num_active_epochs, seed) -> KeyPair
"""


Expand All @@ -36,9 +38,17 @@ class XmssKeyManager:

DEFAULT_MAX_SLOT = Slot(100)
"""Default maximum slot horizon if not specified."""
DEFAULT_ACTIVATION_EPOCH = Uint64(0)
"""Default activation epoch when none is provided."""
DEFAULT_SEED = 0
"""Default deterministic seed when none is provided."""

def __init__(
self,
activation_epoch: Optional[Uint64 | Slot | int] = None,
*,
default_activation_epoch: Optional[Uint64 | Slot | int] = None,
default_seed: Optional[int] = None,
max_slot: Optional[Slot] = None,
scheme: GeneralizedXmssScheme = TEST_SIGNATURE_SCHEME,
) -> None:
Expand All @@ -47,6 +57,12 @@ def __init__(

Parameters
----------
activation_epoch : Uint64 | Slot | int, optional
Deprecated alias for `default_activation_epoch`.
default_activation_epoch : Uint64 | Slot | int, optional
Activation epoch used when none is provided for key generation.
default_seed : int, optional
Seed value used when none is provided for key generation.
max_slot : Slot, optional
Highest slot number for which keys must remain valid.
Defaults to `Slot(100)`.
Expand All @@ -58,14 +74,118 @@ def __init__(
-----
Internally, keys are stored in a single dictionary:
`{ValidatorIndex → KeyPair}`.

This class manages stateful XMSS keys for testing, handling the complexity of
epoch updates and key evolution that stateless helpers cannot provide.
"""
self.max_slot = max_slot if max_slot is not None else self.DEFAULT_MAX_SLOT
self.scheme = scheme
if activation_epoch is not None and default_activation_epoch is not None:
raise ValueError("Use either activation_epoch or default_activation_epoch, not both.")
effective_activation = (
default_activation_epoch if default_activation_epoch is not None else activation_epoch
)
activation_value = (
self.DEFAULT_ACTIVATION_EPOCH
if effective_activation is None
else self._coerce_uint64(effective_activation)
)
self._default_activation_epoch = activation_value
self._default_seed = int(default_seed) if default_seed is not None else self.DEFAULT_SEED
self._key_pairs: dict[ValidatorIndex, KeyPair] = {}
self._key_metadata: dict[ValidatorIndex, dict[str, Any]] = {}

@staticmethod
def _coerce_uint64(value: Uint64 | Slot | int) -> Uint64:
"""Convert supported numeric inputs to Uint64."""
if isinstance(value, Uint64):
return Uint64(int(value))
if isinstance(value, Slot):
return Uint64(value.as_int())
return Uint64(int(value))

@property
def default_max_epoch(self) -> int:
"""Default lifetime derived from the manager's configured max_slot."""
return self.default_num_active_epochs

@property
def default_num_active_epochs(self) -> int:
"""Number of epochs keys stay active when not overridden."""
return self.max_slot.as_int() + 1

@property
def default_activation_epoch(self) -> int:
"""Default activation epoch as an int."""
return int(self._default_activation_epoch)

@property
def default_seed(self) -> int:
"""Default seed used when none is provided."""
return self._default_seed

def create_and_store_key_pair(
self,
validator_index: ValidatorIndex,
*,
activation_epoch: Optional[Uint64 | Slot | int] = None,
num_active_epochs: Optional[Uint64 | Slot | int] = None,
seed: Optional[int] = None,
) -> KeyPair:
"""
Generate and store a key pair with explicit control over key generation.

Parameters
----------
validator_index : ValidatorIndex
The validator for whom a key pair should be generated.
activation_epoch : Uint64 | Slot | int, optional
First epoch for which the key is valid. Defaults to the manager's
configured `default_activation_epoch`.
num_active_epochs : Uint64 | Slot | int, optional
Number of consecutive epochs the key should remain active.
Defaults to `default_num_active_epochs` (derived from `max_slot` to include genesis).
seed : int, optional
Deterministic seed for caching/reuse. Defaults to manager's `default_seed`.
"""
activation_epoch_val = (
self._coerce_uint64(activation_epoch)
if activation_epoch is not None
else self._default_activation_epoch
)
num_active_epochs_val = (
self._coerce_uint64(num_active_epochs)
if num_active_epochs is not None
else self._coerce_uint64(self.default_num_active_epochs)
)
seed_val = int(seed) if seed is not None else self.default_seed

cache_key = (
int(validator_index),
int(activation_epoch_val),
int(num_active_epochs_val),
seed_val,
)

if cache_key in _KEY_CACHE:
key_pair = _KEY_CACHE[cache_key]
else:
pk, sk = self.scheme.key_gen(activation_epoch_val, num_active_epochs_val)
key_pair = KeyPair(public=pk, secret=sk)
_KEY_CACHE[cache_key] = key_pair

self._key_pairs[validator_index] = key_pair
self._key_metadata[validator_index] = {
"activation_epoch": int(activation_epoch_val),
"num_active_epochs": int(num_active_epochs_val),
"seed": seed_val,
}
# TODO: support multiple keys per validator keyed by activation_epoch.
return key_pair

def __getitem__(self, validator_index: ValidatorIndex) -> KeyPair:
"""
Retrieve or lazily generate a validators key pair.
Retrieve or lazily generate a validator's key pair.

Parameters
----------
Expand All @@ -75,45 +195,18 @@ def __getitem__(self, validator_index: ValidatorIndex) -> KeyPair:
Returns:
-------
KeyPair
The validator’s XMSS key pair.
XMSS key pair associated with the validator.

Notes:
-----
- Generates a new key if none exists.
- Keys are deterministic for testing (`seed=0`).
- Lifetime = `max_slot + 1` to include the genesis slot.
- Lifetime defaults to `default_num_active_epochs` to include the genesis slot.
"""
# Return cached keys if they exist.
if validator_index in self._key_pairs:
return self._key_pairs[validator_index]

# Generate New Key Pair
#
# XMSS requires knowing the total number of signatures in advance.
# We use max_slot + 1 as the lifetime since:
# - Validators may sign once per slot (attestations)
# - We include slot 0 (genesis) in the count
num_active_epochs = self.max_slot.as_int() + 1

# Check global cache first (keys are reused across tests)
cache_key = (int(validator_index), num_active_epochs)
if cache_key in _KEY_CACHE:
key_pair = _KEY_CACHE[cache_key]
self._key_pairs[validator_index] = key_pair
return key_pair

# Generate the key pair using the default XMSS scheme.
#
# The seed is set to 0 for deterministic test keys.
from lean_spec.types import Uint64

pk, sk = self.scheme.key_gen(Uint64(0), Uint64(num_active_epochs))

# Store as a cohesive unit and return.
key_pair = KeyPair(public=pk, secret=sk)
_KEY_CACHE[cache_key] = key_pair # Cache globally for reuse across tests
self._key_pairs[validator_index] = key_pair
return key_pair
return self.create_and_store_key_pair(validator_index)

def sign_attestation(self, attestation: Attestation) -> Signature:
"""
Expand Down Expand Up @@ -177,16 +270,40 @@ def sign_attestation(self, attestation: Attestation) -> Signature:
# Generate the XMSS signature using the validator's (now prepared) secret key.
xmss_sig = self.scheme.sign(sk, epoch, message)

# Convert the signature to the wire format (byte array).
signature_bytes = xmss_sig.to_bytes(self.scheme.config)
# Convert to the consensus Signature container (handles padding internally).
return Signature.from_xmss(xmss_sig, self.scheme)

# Ensure the signature meets the consensus spec length (3116 bytes).
#
# This is necessary when using TEST_CONFIG (796 bytes) vs PROD_CONFIG.
# Padding with zeros on the right maintains compatibility.
padded_bytes = signature_bytes.ljust(Signature.LENGTH, b"\x00")
def export_test_vectors(self, include_private_keys: bool = False) -> list[dict[str, Any]]:
"""
Export generated keys as dictionaries suitable for JSON test vectors.

Parameters
----------
include_private_keys : bool, optional
When True, include SecretKey contents for debugging fixtures.

return Signature(padded_bytes)
Returns:
-------
list[dict[str, Any]]
A list of entries keyed by validator_index with metadata and hex keys.
"""
vectors: list[dict[str, Any]] = []
for validator_index in sorted(self._key_pairs.keys(), key=int):
key_pair = self._key_pairs[validator_index]
metadata = self._key_metadata.get(validator_index, {})
entry: dict[str, Any] = {
"validator_index": int(validator_index),
"activation_epoch": metadata.get("activation_epoch", self.default_activation_epoch),
"num_active_epochs": metadata.get(
"num_active_epochs", self.default_num_active_epochs
),
"seed": metadata.get("seed", self.default_seed),
"public_key": key_pair.public.to_bytes(self.scheme.config).hex(),
}
if include_private_keys:
entry["secret_key"] = key_pair.secret.model_dump()
vectors.append(entry)
return vectors

def get_public_key(self, validator_index: ValidatorIndex) -> PublicKey:
"""
Expand All @@ -199,7 +316,7 @@ def get_public_key(self, validator_index: ValidatorIndex) -> PublicKey:
Returns:
-------
PublicKey
The validator’s public key.
Public key for the validator.
"""
return self[validator_index].public

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@ class ForkChoiceTest(BaseConsensusFixture):
valid up to the highest slot used in any block or attestation.
"""

key_manager_seed: int | None = None
"""
Optional deterministic seed to pass to the XMSS key manager.

When set, validators' keys and signatures become reproducible across runs.
"""

key_manager_activation_epoch: Slot | None = None
"""
Optional activation epoch to use when generating keys.

Defaults to the key manager's own default (0) when unset.
"""

@model_validator(mode="after")
def set_anchor_block_default(self) -> ForkChoiceTest:
"""
Expand Down Expand Up @@ -183,10 +197,24 @@ def make_fixture(self) -> ForkChoiceTest:
# Use shared key manager if it has sufficient capacity, otherwise create a new one
# This optimizes performance by reusing keys across tests when possible
shared_key_manager = _get_shared_key_manager()
use_shared = (
self.key_manager_seed is None
and self.key_manager_activation_epoch is None
and self.max_slot <= shared_key_manager.max_slot
)
key_manager = (
shared_key_manager
if self.max_slot <= shared_key_manager.max_slot
else XmssKeyManager(max_slot=self.max_slot, scheme=TEST_SIGNATURE_SCHEME)
if use_shared
else XmssKeyManager(
max_slot=self.max_slot,
scheme=TEST_SIGNATURE_SCHEME,
default_seed=self.key_manager_seed,
default_activation_epoch=(
self.key_manager_activation_epoch
if self.key_manager_activation_epoch is not None
else XmssKeyManager.DEFAULT_ACTIVATION_EPOCH
),
)
)

# Update validator pubkeys to match key_manager's generated keys
Expand Down
19 changes: 19 additions & 0 deletions src/lean_spec/subspecs/containers/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,22 @@ def verify(
return scheme.verify(public_key, epoch, message, signature)
except Exception:
return False

@classmethod
def from_xmss(
cls, xmss_signature: XmssSignature, scheme: GeneralizedXmssScheme = TEST_SIGNATURE_SCHEME
) -> Signature:
"""
Create a consensus `Signature` container from an XMSS signature object.

Applies the consensus-layer fixed-length padding, delegating all encoding
details to the XMSS container itself.
"""
raw = xmss_signature.to_bytes(scheme.config)
if len(raw) > cls.LENGTH:
raise ValueError(
f"XMSS signature length {len(raw)} exceeds container size {cls.LENGTH}"
)

# Pad on the right to the fixed-length container expected by consensus.
return cls(raw.ljust(cls.LENGTH, b"\x00"))
29 changes: 29 additions & 0 deletions tests/lean_spec/subspecs/containers/test_signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Tests for consensus Signature container."""

from lean_spec.subspecs.containers import Signature
from lean_spec.subspecs.xmss.interface import TEST_SIGNATURE_SCHEME
from lean_spec.types import Uint64


class TestSignatureFromXmss:
"""Tests for Signature.from_xmss conversion method."""

def test_from_xmss_roundtrip_with_verify(self) -> None:
"""Test that a signature created via from_xmss can be verified."""

# Generate a test key pair
pk, sk = TEST_SIGNATURE_SCHEME.key_gen(Uint64(0), Uint64(10))

# Create a test message (must be exactly 32 bytes)
message = b"test message for signing123456\x00\x00" # 32 bytes
assert len(message) == 32
epoch = Uint64(0)

# Sign the message
xmss_sig = TEST_SIGNATURE_SCHEME.sign(sk, epoch, message)

# Convert to consensus signature
consensus_sig = Signature.from_xmss(xmss_sig, TEST_SIGNATURE_SCHEME)

# Verify using the consensus signature's verify method
assert consensus_sig.verify(pk, epoch, message, TEST_SIGNATURE_SCHEME)
Loading
Loading