Skip to content

Commit 1871a66

Browse files
authored
xmss: make message Bytes32 (#335)
1 parent 15081d3 commit 1871a66

11 files changed

Lines changed: 37 additions & 39 deletions

File tree

src/lean_spec/subspecs/validator/service.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,10 @@ def _sign_block(
332332
if entry is None:
333333
raise ValueError(f"No secret key for validator {validator_index}")
334334

335-
message_bytes = proposer_attestation_data.data_root_bytes()
336335
proposer_signature = TARGET_SIGNATURE_SCHEME.sign(
337336
entry.secret_key,
338337
block.slot,
339-
bytes(message_bytes),
338+
proposer_attestation_data.data_root_bytes(),
340339
)
341340

342341
# Create the message wrapper.
@@ -385,11 +384,10 @@ def _sign_attestation(
385384
# Sign the attestation data root.
386385
#
387386
# Uses XMSS one-time signature for the current epoch (slot).
388-
message_bytes = attestation_data.data_root_bytes()
389387
signature = TARGET_SIGNATURE_SCHEME.sign(
390388
entry.secret_key,
391389
attestation_data.slot,
392-
bytes(message_bytes),
390+
attestation_data.data_root_bytes(),
393391
)
394392

395393
return SignedAttestation(

src/lean_spec/subspecs/xmss/aggregation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def aggregate(
6565
participants: AggregationBits,
6666
public_keys: Sequence[PublicKey],
6767
signatures: Sequence[Signature],
68-
message: bytes,
68+
message: Bytes32,
6969
epoch: Uint64,
7070
mode: str | None = None,
7171
) -> Self:
@@ -106,7 +106,7 @@ def aggregate(
106106
def verify(
107107
self,
108108
public_keys: Sequence[PublicKey],
109-
message: bytes,
109+
message: Bytes32,
110110
epoch: Uint64,
111111
mode: str | None = None,
112112
) -> None:

src/lean_spec/subspecs/xmss/containers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from typing import TYPE_CHECKING, Mapping, NamedTuple
1111

12-
from ...types import Uint64
12+
from ...types import Bytes32, Uint64
1313
from ...types.container import Container
1414
from .subtree import HashSubTree
1515
from .types import (
@@ -71,7 +71,7 @@ def verify(
7171
self,
7272
public_key: PublicKey,
7373
epoch: "Uint64",
74-
message: bytes,
74+
message: "Bytes32",
7575
scheme: "GeneralizedXmssScheme",
7676
) -> bool:
7777
"""

src/lean_spec/subspecs/xmss/interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
TEST_TARGET_SUM_ENCODER,
1717
TargetSumEncoder,
1818
)
19-
from lean_spec.types import StrictBaseModel, Uint64
19+
from lean_spec.types import Bytes32, StrictBaseModel, Uint64
2020

2121
from ._validation import enforce_strict_types
2222
from .constants import (
@@ -220,7 +220,7 @@ def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPai
220220
)
221221
return KeyPair(public=pk, secret=sk)
222222

223-
def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature:
223+
def sign(self, sk: SecretKey, epoch: Uint64, message: Bytes32) -> Signature:
224224
"""
225225
Produces a digital signature for a given message at a specific epoch.
226226
@@ -362,7 +362,7 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature:
362362
# - The randomness `rho` needed for verification.
363363
return Signature(path=path, rho=rho, hashes=HashDigestList(data=ots_hashes))
364364

365-
def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) -> bool:
365+
def verify(self, pk: PublicKey, epoch: Uint64, message: Bytes32, sig: Signature) -> bool:
366366
r"""
367367
Verifies a digital signature against a public key, message, and epoch.
368368

src/lean_spec/subspecs/xmss/message_hash.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
TEST_POSEIDON,
3737
PoseidonXmss,
3838
)
39-
from lean_spec.types import StrictBaseModel, Uint64
39+
from lean_spec.types import Bytes32, StrictBaseModel, Uint64
4040

4141
from ..koalabear import Fp, P
4242
from ._validation import enforce_strict_types
@@ -70,7 +70,7 @@ def _validate_strict_types(self) -> "MessageHasher":
7070
enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss)
7171
return self
7272

73-
def encode_message(self, message: bytes) -> list[Fp]:
73+
def encode_message(self, message: Bytes32) -> list[Fp]:
7474
"""
7575
Encodes a 32-byte message into a list of field elements.
7676
@@ -145,7 +145,7 @@ def apply(
145145
parameter: Parameter,
146146
epoch: Uint64,
147147
rho: Randomness,
148-
message: bytes,
148+
message: Bytes32,
149149
) -> list[int]:
150150
"""
151151
Applies the full "Top Level" message hash and mapping procedure.

src/lean_spec/subspecs/xmss/prf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pydantic import model_validator
1616

1717
from lean_spec.subspecs.koalabear import Fp
18-
from lean_spec.types import StrictBaseModel, Uint64
18+
from lean_spec.types import Bytes32, StrictBaseModel, Uint64
1919

2020
from ._validation import enforce_strict_types
2121
from .constants import (
@@ -178,7 +178,7 @@ def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> HashDigestVe
178178
return HashDigestVector(data=_bytes_to_field_elements(prf_output_bytes, config.HASH_LEN_FE))
179179

180180
def get_randomness(
181-
self, key: PRFKey, epoch: Uint64, message: bytes, counter: Uint64
181+
self, key: PRFKey, epoch: Uint64, message: Bytes32, counter: Uint64
182182
) -> Randomness:
183183
"""
184184
Derives pseudorandom field elements for use in deterministic signing.

src/lean_spec/subspecs/xmss/target_sum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from pydantic import model_validator
1010

11-
from lean_spec.types import StrictBaseModel, Uint64
11+
from lean_spec.types import Bytes32, StrictBaseModel, Uint64
1212

1313
from ._validation import enforce_strict_types
1414
from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig
@@ -41,7 +41,7 @@ def _validate_strict_types(self) -> "TargetSumEncoder":
4141
return self
4242

4343
def encode(
44-
self, parameter: Parameter, message: bytes, rho: Randomness, epoch: Uint64
44+
self, parameter: Parameter, message: Bytes32, rho: Randomness, epoch: Uint64
4545
) -> list[int] | None:
4646
"""
4747
Encodes a message into a codeword if it meets the target sum criteria.

tests/lean_spec/subspecs/validator/test_service.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ async def produce_block() -> None:
601601
is_valid = TARGET_SIGNATURE_SCHEME.verify(
602602
pk=proposer_public_key,
603603
epoch=signed_block.message.block.slot,
604-
message=bytes(message_bytes),
604+
message=message_bytes,
605605
sig=signed_block.signature.proposer_signature,
606606
)
607607
assert is_valid, "Proposer signature failed verification"
@@ -648,7 +648,7 @@ async def produce_attestations() -> None:
648648
is_valid = TARGET_SIGNATURE_SCHEME.verify(
649649
pk=public_key,
650650
epoch=signed_att.message.slot,
651-
message=bytes(message_bytes),
651+
message=message_bytes,
652652
sig=signed_att.signature,
653653
)
654654
assert is_valid, f"Attestation signature for validator {validator_id} failed"
@@ -750,7 +750,7 @@ async def produce_block() -> None:
750750
is_valid = TARGET_SIGNATURE_SCHEME.verify(
751751
pk=public_key,
752752
epoch=signed_block.message.block.slot,
753-
message=bytes(message_bytes),
753+
message=message_bytes,
754754
sig=signed_block.signature.proposer_signature,
755755
)
756756
assert is_valid
@@ -1005,7 +1005,7 @@ async def produce_attestations() -> None:
10051005
is_valid = TARGET_SIGNATURE_SCHEME.verify(
10061006
pk=public_key,
10071007
epoch=test_slot, # Must match the signing slot
1008-
message=bytes(message_bytes),
1008+
message=message_bytes,
10091009
sig=signed_att.signature,
10101010
)
10111011
assert is_valid, f"Signature for validator {validator_id} at slot {test_slot} failed"
@@ -1015,7 +1015,7 @@ async def produce_attestations() -> None:
10151015
is_invalid = TARGET_SIGNATURE_SCHEME.verify(
10161016
pk=public_key,
10171017
epoch=wrong_epoch,
1018-
message=bytes(message_bytes),
1018+
message=message_bytes,
10191019
sig=signed_att.signature,
10201020
)
10211021
assert not is_invalid, "Signature should fail with wrong epoch"

tests/lean_spec/subspecs/xmss/test_interface.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
TEST_SIGNATURE_SCHEME,
99
GeneralizedXmssScheme,
1010
)
11-
from lean_spec.types import Uint64
11+
from lean_spec.types import Bytes32, Uint64
1212

1313

1414
def _test_correctness_roundtrip(
@@ -32,7 +32,7 @@ def _test_correctness_roundtrip(
3232
#
3333
# Pick a sample epoch within the active range to test signing.
3434
test_epoch = Uint64(activation_epoch + num_active_epochs // 2)
35-
message = b"\x42" * scheme.config.MESSAGE_LENGTH
35+
message = Bytes32(b"\x42" * 32)
3636

3737
# Sign the message at the chosen epoch.
3838
#
@@ -46,7 +46,7 @@ def _test_correctness_roundtrip(
4646
# TEST INVALID CASES
4747
#
4848
# Verification must fail if the message is tampered with.
49-
tampered_message = b"\x43" * scheme.config.MESSAGE_LENGTH
49+
tampered_message = Bytes32(b"\x43" * 32)
5050

5151
# With small test parameters (test configuration), there's a small chance that
5252
# the tampered message produces the same codeword as the original due to
@@ -176,7 +176,7 @@ def test_sign_requires_prepared_interval() -> None:
176176
assert int(outside_epoch) not in prepared_interval
177177

178178
# Signing should fail
179-
message = b"\x42" * scheme.config.MESSAGE_LENGTH
179+
message = Bytes32(b"\x42" * 32)
180180
with pytest.raises(ValueError, match="outside the prepared interval"):
181181
scheme.sign(sk, outside_epoch, message)
182182

@@ -189,7 +189,7 @@ def test_deterministic_signing() -> None:
189189

190190
# Use epoch within prepared interval
191191
epoch = Uint64(4)
192-
message = b"\x42" * scheme.config.MESSAGE_LENGTH
192+
message = Bytes32(b"\x42" * 32)
193193

194194
# Sign twice
195195
sig1 = scheme.sign(sk, epoch, message)
@@ -218,7 +218,7 @@ def test_rejects_epoch_beyond_lifetime(self) -> None:
218218

219219
# Sign a valid message at a valid epoch.
220220
valid_epoch = Uint64(4)
221-
message = b"\x42" * scheme.config.MESSAGE_LENGTH
221+
message = Bytes32(b"\x42" * 32)
222222
signature = scheme.sign(sk, valid_epoch, message)
223223

224224
# Verify with an epoch beyond LIFETIME.
@@ -234,7 +234,7 @@ def test_rejects_very_large_epoch(self) -> None:
234234
pk, sk = scheme.key_gen(Uint64(0), Uint64(scheme.config.LIFETIME))
235235

236236
valid_epoch = Uint64(4)
237-
message = b"\x42" * scheme.config.MESSAGE_LENGTH
237+
message = Bytes32(b"\x42" * 32)
238238
signature = scheme.sign(sk, valid_epoch, message)
239239

240240
# Try to verify with a huge epoch.

tests/lean_spec/subspecs/xmss/test_message_hash.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from lean_spec.subspecs.xmss.rand import TEST_RAND
1414
from lean_spec.subspecs.xmss.utils import int_to_base_p
15-
from lean_spec.types import Uint64
15+
from lean_spec.types import Bytes32, Uint64
1616

1717

1818
def test_encode_message() -> None:
@@ -21,13 +21,13 @@ def test_encode_message() -> None:
2121
hasher = TEST_MESSAGE_HASHER
2222

2323
# All-zero message
24-
msg_zeros = b"\x00" * config.MESSAGE_LENGTH
24+
msg_zeros = Bytes32(b"\x00" * 32)
2525
encoded_zeros = hasher.encode_message(msg_zeros)
2626
assert len(encoded_zeros) == config.MSG_LEN_FE
2727
assert all(fe.value == 0 for fe in encoded_zeros)
2828

2929
# All-max message (0xff)
30-
msg_max = b"\xff" * config.MESSAGE_LENGTH
30+
msg_max = Bytes32(b"\xff" * 32)
3131
acc = int.from_bytes(msg_max, "little")
3232
expected_max = int_to_base_p(acc, config.MSG_LEN_FE)
3333
assert hasher.encode_message(msg_max) == expected_max
@@ -70,7 +70,7 @@ def test_apply_output_is_in_correct_hypercube_part() -> None:
7070
parameter = rand.parameter()
7171
epoch = Uint64(313)
7272
randomness = rand.rho()
73-
message = b"\xaa" * config.MESSAGE_LENGTH
73+
message = Bytes32(b"\xaa" * 32)
7474

7575
# Call the message hash function.
7676
vertex = hasher.apply(parameter, epoch, randomness, message)

0 commit comments

Comments
 (0)