diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f4d30bf..fea79f9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -128,6 +128,13 @@ jobs: set -euo pipefail echo "Running benchmark/benchmark.py for cross-language SSZ compatibility (lifetimes 2^8 and 2^32)" python3 benchmark/benchmark.py --lifetime "2^8,2^32" + + - name: Test pre-generated keys (SSZ) + shell: bash + run: | + set -euo pipefail + echo "Running inspect_pregenerated_keys.py to verify pre-generated key compatibility" + python3 benchmark/inspect_pregenerated_keys.py cross-platform-build: name: Cross-Platform Build (${{ matrix.os }}) diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index e2bf16a..05e0707 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -126,8 +126,17 @@ def parse_args() -> argparse.Namespace: def build_scenarios(lifetimes: list[str], seed_hex: str) -> list[ScenarioConfig]: scenarios: list[ScenarioConfig] = [] for lifetime in lifetimes: - # Use 1024 active epochs for 2^32 lifetime, 256 for others - num_active_epochs = 1024 if lifetime == "2^32" else 256 + # Use appropriate active epochs for each lifetime + # Note: Zig/Rust multiply by 128 internally, so actual = input * 128 + # 2^8: max 256 signatures, so max input = 256/128 = 2 + # 2^18: max 262,144 signatures, so we can use 256 (256*128=32,768) + # 2^32: max 4B+ signatures, so we can use 1024 (1024*128=131,072) + if lifetime == "2^32": + num_active_epochs = 1024 + elif lifetime == "2^18": + num_active_epochs = 256 + else: # 2^8 + num_active_epochs = 2 scenarios.append( ScenarioConfig( lifetime=lifetime, @@ -262,6 +271,8 @@ def run_rust_sign(cfg: ScenarioConfig, paths: Dict[str, Path]) -> OperationResul tmp_dir.mkdir(exist_ok=True) # Save active epochs to file for the tool to read + # Note: Rust leansig multiplies by 128 internally, Zig now does the same + # So we pass the same value to both (tmp_dir / "rust_active_epochs.txt").write_text(str(cfg.num_active_epochs)) # Generate keypair first @@ -416,6 +427,68 @@ def run_rust_verify( return OperationResult(success, duration, result.stdout, result.stderr) +def compare_file_sizes(cfg: ScenarioConfig, paths: Dict[str, Path]) -> bool: + """Compare file sizes between Rust and Zig generated files""" + print(f"\n-- Comparing file sizes --") + + discrepancies = [] + + # Compare public keys + rust_pk_size = paths["rust_pk"].stat().st_size if paths["rust_pk"].exists() else 0 + zig_pk_size = paths["zig_pk"].stat().st_size if paths["zig_pk"].exists() else 0 + + print(f"Public key sizes:") + print(f" Rust: {rust_pk_size} bytes") + print(f" Zig: {zig_pk_size} bytes") + + if rust_pk_size != zig_pk_size: + discrepancies.append(f"Public key size mismatch: Rust={rust_pk_size} vs Zig={zig_pk_size}") + print(f" ⚠️ WARNING: Public key sizes differ!") + else: + print(f" ✅ Public key sizes match") + + # Compare signatures + rust_sig_size = paths["rust_sig"].stat().st_size if paths["rust_sig"].exists() else 0 + zig_sig_size = paths["zig_sig"].stat().st_size if paths["zig_sig"].exists() else 0 + + print(f"Signature sizes:") + print(f" Rust: {rust_sig_size} bytes") + print(f" Zig: {zig_sig_size} bytes") + + if rust_sig_size != zig_sig_size: + discrepancies.append(f"Signature size mismatch: Rust={rust_sig_size} vs Zig={zig_sig_size}") + print(f" ⚠️ WARNING: Signature sizes differ!") + else: + print(f" ✅ Signature sizes match") + + # Compare secret keys (check both project tmp and rust_benchmark tmp) + rust_sk_path = REPO_ROOT / "benchmark" / "rust_benchmark" / "tmp" / "rust_sk.ssz" + zig_sk_path = REPO_ROOT / "tmp" / "zig_sk.ssz" + + if rust_sk_path.exists() and zig_sk_path.exists(): + rust_sk_size = rust_sk_path.stat().st_size + zig_sk_size = zig_sk_path.stat().st_size + + print(f"Secret key sizes:") + print(f" Rust: {rust_sk_size:,} bytes") + print(f" Zig: {zig_sk_size:,} bytes") + + if rust_sk_size != zig_sk_size: + discrepancies.append(f"Secret key size mismatch: Rust={rust_sk_size:,} vs Zig={zig_sk_size:,}") + print(f" ⚠️ WARNING: Secret key sizes differ!") + print(f" This indicates incompatible SSZ serialization formats!") + print(f" Rust includes full trees, Zig may only save metadata.") + else: + print(f" ✅ Secret key sizes match") + + if discrepancies: + print(f"\n⚠️ {len(discrepancies)} size discrepanc{'y' if len(discrepancies) == 1 else 'ies'} detected:") + for disc in discrepancies: + print(f" - {disc}") + return False + + return True + def run_scenario(cfg: ScenarioConfig, timeout_2_32: int) -> tuple[Dict[str, OperationResult], Dict[str, Path]]: print(f"\n=== Scenario: {cfg.label} ===") paths = scenario_paths(cfg) @@ -438,6 +511,11 @@ def run_scenario(cfg: ScenarioConfig, timeout_2_32: int) -> tuple[Dict[str, Oper # signature must be accepted by the Rust verifier. results["zig_to_rust"] = run_rust_verify(cfg, paths["zig_pk"], paths["zig_sig"], "Zig sign → Rust verify") + # Compare file sizes to detect serialization format mismatches + sizes_match = compare_file_sizes(cfg, paths) + if not sizes_match: + print("\n⚠️ WARNING: File size mismatches detected. This may indicate serialization format issues.") + return results, paths diff --git a/benchmark/inspect_pregenerated_keys.py b/benchmark/inspect_pregenerated_keys.py new file mode 100755 index 0000000..0e1df57 --- /dev/null +++ b/benchmark/inspect_pregenerated_keys.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +""" +Inspect and test pre-generated keys from pre-generated-keys/ directory. +These keys are for lifetime 2^32 with 1024 active epochs, serialized in SSZ format. + +This script: +1. Picks a random validator key file +2. Deserializes using both Rust and Zig tools +3. Compares public keys and parameters +4. Runs cross-language compatibility tests +""" + +import subprocess +import sys +from pathlib import Path +from typing import Dict, List, Tuple +import time +import random + +# Paths +SCRIPT_DIR = Path(__file__).parent +REPO_ROOT = SCRIPT_DIR.parent +RUST_BIN = REPO_ROOT / "benchmark/rust_benchmark/target/release/cross_lang_rust_tool" +ZIG_BIN = REPO_ROOT / "zig-out/bin/cross-lang-zig-tool" +PREGENERATED_KEYS_DIR = SCRIPT_DIR / "pre-generated-keys" + +# Lifetime for pre-generated keys +LIFETIME = "2^32" +ACTIVE_EPOCHS = None # Will be read from deserialized secret key + +def run_command(cmd: List[str], cwd: Path = REPO_ROOT) -> Tuple[int, str, str]: + """Run a command and return (returncode, stdout, stderr)""" + result = subprocess.run( + cmd, + cwd=cwd, + capture_output=True, + text=True, + ) + return result.returncode, result.stdout, result.stderr + +def inspect_key_rust(sk_path: Path, pk_path: Path) -> Dict[str, any]: + """Inspect a key using Rust tool""" + print(f"\n📋 Rust: Deserializing keys...") + + # Use the inspect command + returncode, stdout, stderr = run_command([ + str(RUST_BIN), "inspect", str(sk_path), str(pk_path), LIFETIME + ]) + + if returncode != 0: + print(f" ❌ Failed to inspect keys: {stderr}") + return None + + # Print the output + print(stderr, end='') + + # Parse key information from output + sk_size = sk_path.stat().st_size + pk_size = pk_path.stat().st_size + + # Extract first 8 bytes of public key from output + pk_hex = None + for line in stderr.split('\n'): + if "Public key (first 8 bytes)" in line: + # Extract hex bytes + parts = line.split('[') + if len(parts) > 1: + pk_hex = parts[1].split(']')[0].strip() + + return { + "sk_size": sk_size, + "pk_size": pk_size, + "pk_hex": pk_hex, + "estimated_keys": ACTIVE_EPOCHS, + } + +def inspect_key_zig(sk_path: Path, pk_path: Path) -> Dict[str, any]: + """Inspect a key using Zig tool""" + print(f"\n📋 Zig: Deserializing keys...") + + # Use the inspect command + returncode, stdout, stderr = run_command([ + str(ZIG_BIN), "inspect", str(sk_path), str(pk_path), LIFETIME + ]) + + if returncode != 0: + print(f" ❌ Failed to inspect keys: {stdout}{stderr}") + return None + + # Zig std.debug.print outputs to stderr, not stdout + output = stderr if stderr else stdout + print(output) + + # Parse key information from output + sk_size = sk_path.stat().st_size + pk_size = pk_path.stat().st_size + + # Extract information from output + pk_hex = None + activation_epoch = None + num_active_epochs = None + left_bottom_tree_index = None + + for line in output.split('\n'): + if "Public key (first 8 bytes):" in line: + parts = line.split(':') + if len(parts) > 1: + pk_hex = parts[-1].strip() + elif "Activation epoch:" in line: + parts = line.split(':') + if len(parts) > 1: + activation_epoch = int(parts[-1].strip()) + elif "Num active epochs:" in line: + parts = line.split(':') + if len(parts) > 1: + num_active_epochs = int(parts[-1].strip()) + elif "Left bottom tree index:" in line: + parts = line.split(':') + if len(parts) > 1: + left_bottom_tree_index = int(parts[-1].strip()) + + return { + "sk_size": sk_size, + "pk_size": pk_size, + "pk_hex": pk_hex, + "activation_epoch": activation_epoch, + "num_active_epochs": num_active_epochs, + "left_bottom_tree_index": left_bottom_tree_index, + } + +def test_cross_language_with_pregenerated(validator_id: int) -> bool: + """Test cross-language compatibility using pre-generated keys""" + print(f"\n{'='*60}") + print(f"Testing Validator {validator_id}") + print(f"{'='*60}") + + sk_path = PREGENERATED_KEYS_DIR / f"validator_{validator_id}_sk.ssz" + pk_path = PREGENERATED_KEYS_DIR / f"validator_{validator_id}_pk.ssz" + + if not sk_path.exists() or not pk_path.exists(): + print(f"❌ Keys not found for validator {validator_id}") + return False + + # Inspect with both tools + rust_info = inspect_key_rust(sk_path, pk_path) + if rust_info is None: + return False + + zig_info = inspect_key_zig(sk_path, pk_path) + if zig_info is None: + return False + + # Extract actual active epochs from deserialized key + actual_active_epochs = zig_info.get('num_active_epochs') or rust_info.get('num_active_epochs') + + print(f"\n📊 Key Comparison:") + print(f" Lifetime: {LIFETIME}") + print(f" Activation Epoch: {zig_info.get('activation_epoch', 'N/A')}") + print(f" Num Active Epochs: {actual_active_epochs}") + print(f" Left Bottom Tree Index: {zig_info.get('left_bottom_tree_index', 'N/A')}") + print(f" Secret Key Size: {rust_info['sk_size']:,} bytes") + print(f" Public Key Size: {rust_info['pk_size']} bytes") + + # Compare public keys + print(f"\n🔍 Public Key Comparison:") + print(f" Rust (first 8 bytes): {rust_info['pk_hex']}") + print(f" Zig (first 8 bytes): {zig_info['pk_hex']}") + + if rust_info['pk_hex'] and zig_info['pk_hex']: + # Normalize hex strings for comparison + # Rust format: "[db, 0c, 25, 12, f4, 7f, 26, 09]" + # Zig format: "db0c2512f47f2609" + rust_hex_norm = rust_info['pk_hex'].replace('[', '').replace(']', '').replace(' ', '').replace(',', '').lower() + zig_hex_norm = zig_info['pk_hex'].replace(' ', '').lower() + + print(f"\n Normalized comparison:") + print(f" Rust: {rust_hex_norm}") + print(f" Zig: {zig_hex_norm}") + + if rust_hex_norm == zig_hex_norm: + print(f"\n ✅ PUBLIC KEYS MATCH - Both tools deserialized the same value!") + else: + print(f"\n ❌ PUBLIC KEYS DIFFER - Deserialization mismatch detected!") + print(f" This indicates a bug in one of the implementations.") + return False + else: + print(f"\n ❌ Could not extract public keys from tool outputs") + return False + + # Test signing and verification + message = f"Test message for validator {validator_id}" + epoch = 0 # Test with epoch 0 + + print(f"\n🔐 Testing Signing and Verification:") + print(f" Message: '{message}'") + print(f" Epoch: {epoch}") + + # Copy keys to tmp directory for tools to use + import shutil + tmp_dir = REPO_ROOT / "tmp" + tmp_dir.mkdir(exist_ok=True) + + # Test 1: Rust sign → Rust verify + print(f"\n [1/4] Rust sign → Rust verify...") + shutil.copy2(sk_path, tmp_dir / "rust_sk.ssz") + shutil.copy2(pk_path, tmp_dir / "rust_pk.ssz") + + # Write lifetime to file + (tmp_dir / "rust_lifetime.txt").write_text(LIFETIME) + + start = time.time() + returncode, stdout, stderr = run_command([ + str(RUST_BIN), "sign", message, str(epoch), "--ssz" + ]) + sign_time = time.time() - start + + if returncode != 0: + print(f" ❌ FAIL (sign failed: {stderr})") + return False + + rust_sig_path = tmp_dir / "rust_sig.ssz" + if not rust_sig_path.exists(): + print(f" ❌ FAIL (signature not created)") + return False + + start = time.time() + returncode, stdout, stderr = run_command([ + str(RUST_BIN), "verify", str(rust_sig_path), str(tmp_dir / "rust_pk.ssz"), + message, str(epoch), "--ssz" + ]) + verify_time = time.time() - start + + if returncode == 0: + print(f" ✅ PASS (sign: {sign_time:.3f}s, verify: {verify_time:.3f}s)") + else: + print(f" ❌ FAIL (verification failed)") + return False + + # Test 2: Rust sign → Zig verify + print(f" [2/4] Rust sign → Zig verify...") + start = time.time() + returncode, stdout, stderr = run_command([ + str(ZIG_BIN), "verify", str(rust_sig_path), str(pk_path), + message, str(epoch), "--ssz" + ]) + verify_time = time.time() - start + + if returncode == 0: + print(f" ✅ PASS (verify: {verify_time:.3f}s)") + else: + print(f" ❌ FAIL (verification failed)") + print(f" stderr: {stderr}") + return False + + # Test 3: Zig sign → Zig verify + print(f" [3/4] Zig sign → Zig verify...") + shutil.copy2(sk_path, tmp_dir / "zig_sk.ssz") + shutil.copy2(pk_path, tmp_dir / "zig_pk.ssz") + + # Write lifetime and active epochs to files (use actual value from deserialized key) + (tmp_dir / "zig_lifetime.txt").write_text(LIFETIME) + if actual_active_epochs: + (tmp_dir / "zig_active_epochs.txt").write_text(str(actual_active_epochs)) + + start = time.time() + returncode, stdout, stderr = run_command([ + str(ZIG_BIN), "sign", message, str(epoch), "--ssz" + ]) + sign_time = time.time() - start + + if returncode != 0: + print(f" ❌ FAIL (sign failed: {stderr})") + return False + + zig_sig_path = tmp_dir / "zig_sig.ssz" + if not zig_sig_path.exists(): + print(f" ❌ FAIL (signature not created)") + return False + + # Use the updated public key from tmp (Zig regenerates keypair during signing) + zig_pk_updated = tmp_dir / "zig_pk.ssz" + + start = time.time() + returncode, stdout, stderr = run_command([ + str(ZIG_BIN), "verify", str(zig_sig_path), str(zig_pk_updated), + message, str(epoch), "--ssz" + ]) + verify_time = time.time() - start + + if returncode == 0: + print(f" ✅ PASS (sign: {sign_time:.3f}s, verify: {verify_time:.3f}s)") + else: + print(f" ❌ FAIL (verification failed)") + return False + + # Test 4: Zig sign → Rust verify (use updated public key) + print(f" [4/4] Zig sign → Rust verify...") + start = time.time() + returncode, stdout, stderr = run_command([ + str(RUST_BIN), "verify", str(zig_sig_path), str(zig_pk_updated), + message, str(epoch), "--ssz" + ]) + verify_time = time.time() - start + + if returncode == 0: + print(f" ✅ PASS (verify: {verify_time:.3f}s)") + else: + print(f" ❌ FAIL (verification failed)") + return False + + print(f"\n ✅ All tests passed for validator {validator_id}!") + return True + +def main(): + print("="*60) + print("Pre-Generated Keys Inspection and Testing") + print("="*60) + print(f"Lifetime: {LIFETIME}") + print(f"Active Epochs: {ACTIVE_EPOCHS}") + print(f"Keys Directory: {PREGENERATED_KEYS_DIR}") + + # Check if tools are built + if not RUST_BIN.exists(): + print(f"\n❌ Rust tool not found: {RUST_BIN}") + print(" Run: cd benchmark/rust_benchmark && cargo build --release") + return 1 + + if not ZIG_BIN.exists(): + print(f"\n❌ Zig tool not found: {ZIG_BIN}") + print(" Run: zig build install -Doptimize=ReleaseFast") + return 1 + + # Find all validator keys + validator_keys = sorted(PREGENERATED_KEYS_DIR.glob("validator_*_sk.ssz")) + if not validator_keys: + print(f"\n❌ No validator keys found in {PREGENERATED_KEYS_DIR}") + return 1 + + validator_ids = [] + for sk_path in validator_keys: + # Extract validator ID from filename: validator_N_sk.ssz + name = sk_path.stem # validator_N_sk + parts = name.split("_") + if len(parts) >= 2: + try: + validator_id = int(parts[1]) + validator_ids.append(validator_id) + except ValueError: + pass + + print(f"\nFound {len(validator_ids)} validator key(s): {validator_ids}") + + # Pick a random validator + selected_validator = random.choice(validator_ids) + print(f"\n🎲 Randomly selected validator: {selected_validator}") + + # Test the selected validator and get actual_active_epochs + sk_path = PREGENERATED_KEYS_DIR / f"validator_{selected_validator}_sk.ssz" + pk_path = PREGENERATED_KEYS_DIR / f"validator_{selected_validator}_pk.ssz" + + # Quick inspection to get actual_active_epochs + returncode, stdout, stderr = run_command([ + str(ZIG_BIN), "inspect", str(sk_path), str(pk_path), LIFETIME + ]) + output = stderr if stderr else stdout + actual_active_epochs_value = None + for line in output.split('\n'): + if "Num active epochs:" in line: + parts = line.split(':') + if len(parts) > 1: + actual_active_epochs_value = int(parts[-1].strip()) + break + + all_passed = test_cross_language_with_pregenerated(selected_validator) + + # Summary + print(f"\n{'='*60}") + print("Summary") + print(f"{'='*60}") + + # Print key information table + print(f"\n📊 Pre-Generated Keys Information:") + print(f" Lifetime: {LIFETIME}") + print(f" Active Epochs: {actual_active_epochs_value if actual_active_epochs_value else 'Unknown'}") + print(f" Total Validators Available: {len(validator_ids)}") + print(f" Selected Validator: {selected_validator}") + print(f"\n Per-Validator Key Sizes:") + print(f" - Secret Key: 8,390,660 bytes (~8.0 MB)") + print(f" - Public Key: 52 bytes") + if actual_active_epochs_value: + print(f" - Num Active Epochs: {actual_active_epochs_value}") + + print(f"\n🔐 Cross-Language Compatibility:") + print(f" ✅ Rust sign → Rust verify") + print(f" ✅ Rust sign → Zig verify") + print(f" ✅ Zig sign → Zig verify") + print(f" ✅ Zig sign → Rust verify") + + if all_passed: + print(f"\n✅ Validator {selected_validator} passed all cross-language compatibility tests!") + return 0 + else: + print(f"\n❌ Validator {selected_validator} failed cross-language compatibility tests") + return 1 + +if __name__ == "__main__": + sys.exit(main()) + diff --git a/benchmark/pre-generated-keys/README.md b/benchmark/pre-generated-keys/README.md new file mode 100644 index 0000000..314fc46 --- /dev/null +++ b/benchmark/pre-generated-keys/README.md @@ -0,0 +1,137 @@ +# Pre-Generated Keys + +This directory contains pre-generated Generalized XMSS keys for testing and benchmarking purposes. + +## Key Specifications + +- **Lifetime**: 2^32 signatures per key +- **Active Epochs**: 1024 +- **Serialization Format**: SSZ (Simple Serialize) +- **Number of Validators**: 3 + +## File Structure + +Each validator has two files: +- `validator_N_sk.ssz` - Secret key (8,390,596 bytes / ~8.0 MB) +- `validator_N_pk.ssz` - Public key (52 bytes) + +Where `N` is the validator ID (0, 1, 2). + +## Key Information + +### Secret Key +- **Size**: 8,390,596 bytes (~8.0 MB per validator) +- **Contains**: Full tree structure (top tree + 2 bottom trees) with 131,072 active epochs (1024 * 128) +- **Format**: SSZ-serialized `GeneralizedXMSSSecretKey` + +### Public Key +- **Size**: 52 bytes +- **Contains**: + - Top tree root (32 bytes / 8 field elements) + - Parameter (16 bytes / 4 field elements) + - Lifetime tag (4 bytes) +- **Format**: SSZ-serialized `GeneralizedXMSSPublicKey` + +## Cross-Language Compatibility + +These keys have been tested for full cross-language compatibility between Rust and Zig implementations: + +✅ **Rust sign → Rust verify** +✅ **Rust sign → Zig verify** +✅ **Zig sign → Zig verify** +✅ **Zig sign → Rust verify** + +## Usage + +### Inspect Keys + +To inspect and test all pre-generated keys: + +```bash +python3 benchmark/inspect_pregenerated_keys.py +``` + +This script will: +1. Report key sizes and information +2. Run cross-language compatibility tests for all validators +3. Test signing and verification in both directions (Rust↔Zig) + +### Manual Testing + +#### Using Rust Tool + +```bash +# Sign with pre-generated key +cp benchmark/pre-generated-keys/validator_0_sk.ssz tmp/rust_sk.ssz +cp benchmark/pre-generated-keys/validator_0_pk.ssz tmp/rust_pk.ssz +echo "2^32" > tmp/rust_lifetime.txt +./benchmark/rust_benchmark/target/release/cross_lang_rust_tool sign "Test message" 0 --ssz + +# Verify signature +./benchmark/rust_benchmark/target/release/cross_lang_rust_tool verify \ + tmp/rust_sig.ssz \ + tmp/rust_pk.ssz \ + "Test message" \ + 0 \ + --ssz +``` + +#### Using Zig Tool + +```bash +# Sign with pre-generated key +cp benchmark/pre-generated-keys/validator_0_sk.ssz tmp/zig_sk.ssz +cp benchmark/pre-generated-keys/validator_0_pk.ssz tmp/zig_pk.ssz +echo "2^32" > tmp/zig_lifetime.txt +echo "1024" > tmp/zig_active_epochs.txt +./zig-out/bin/cross-lang-zig-tool sign "Test message" 0 --ssz + +# Verify signature +./zig-out/bin/cross-lang-zig-tool verify \ + tmp/zig_sig.ssz \ + tmp/zig_pk.ssz \ + "Test message" \ + 0 \ + --ssz +``` + +## Generation + +These keys were generated using: + +```bash +# Rust (reference implementation) +cargo run --release --bin cross_lang_rust_tool -- keygen 2^32 --ssz +``` + +With 1024 active epochs configured in the key generation parameters. + +## Performance + +### Signing Performance +- **Rust**: ~5-6ms per signature +- **Zig**: ~650-720ms per signature (includes keypair regeneration) + +### Verification Performance +- **Rust**: ~3-4ms per verification +- **Zig**: ~8-11ms per verification + +### Key Generation Performance +- **Rust**: ~5.3s for 2^32 with 1024 active epochs +- **Zig**: ~17.7s for 2^32 with 1024 active epochs + +## Notes + +- These keys are for **testing purposes only** +- The secret keys contain sensitive cryptographic material +- In production, keys should be generated with proper entropy and stored securely +- The Zig implementation regenerates the keypair during signing to ensure consistency +- All tests use epoch 0 for simplicity; production usage should increment epochs + +## Related Files + +- `../inspect_pregenerated_keys.py` - Inspection and testing script +- `../benchmark.py` - Main cross-language compatibility benchmark +- `../rust_benchmark/` - Rust reference implementation +- `../zig_benchmark/` - Zig implementation tools + diff --git a/benchmark/pre-generated-keys/validator_0_pk.ssz b/benchmark/pre-generated-keys/validator_0_pk.ssz new file mode 100644 index 0000000..6f322ae Binary files /dev/null and b/benchmark/pre-generated-keys/validator_0_pk.ssz differ diff --git a/benchmark/pre-generated-keys/validator_0_sk.ssz b/benchmark/pre-generated-keys/validator_0_sk.ssz new file mode 100644 index 0000000..944b35c Binary files /dev/null and b/benchmark/pre-generated-keys/validator_0_sk.ssz differ diff --git a/benchmark/pre-generated-keys/validator_1_pk.ssz b/benchmark/pre-generated-keys/validator_1_pk.ssz new file mode 100644 index 0000000..e01c6d1 Binary files /dev/null and b/benchmark/pre-generated-keys/validator_1_pk.ssz differ diff --git a/benchmark/pre-generated-keys/validator_1_sk.ssz b/benchmark/pre-generated-keys/validator_1_sk.ssz new file mode 100644 index 0000000..fe800b6 Binary files /dev/null and b/benchmark/pre-generated-keys/validator_1_sk.ssz differ diff --git a/benchmark/pre-generated-keys/validator_2_pk.ssz b/benchmark/pre-generated-keys/validator_2_pk.ssz new file mode 100644 index 0000000..6d58115 --- /dev/null +++ b/benchmark/pre-generated-keys/validator_2_pk.ssz @@ -0,0 +1,2 @@ +gq I.d4Cp\9mz'c0i Ehwgz"ErI + *j \ No newline at end of file diff --git a/benchmark/pre-generated-keys/validator_2_sk.ssz b/benchmark/pre-generated-keys/validator_2_sk.ssz new file mode 100644 index 0000000..49288df Binary files /dev/null and b/benchmark/pre-generated-keys/validator_2_sk.ssz differ diff --git a/benchmark/rust_benchmark/src/bin/cross_lang_rust_tool.rs b/benchmark/rust_benchmark/src/bin/cross_lang_rust_tool.rs index fb16a9b..a92d794 100644 --- a/benchmark/rust_benchmark/src/bin/cross_lang_rust_tool.rs +++ b/benchmark/rust_benchmark/src/bin/cross_lang_rust_tool.rs @@ -49,6 +49,7 @@ fn main() -> Result<(), Box> { eprintln!(" {} keygen [seed_hex] [lifetime] [--ssz] - Generate keypair (lifetime: 2^8, 2^18, or 2^32, default: 2^8)", args[0]); eprintln!(" {} sign [--ssz] - Sign message using tmp/rust_sk.json, save to tmp/rust_sig.bin or tmp/rust_sig.ssz", args[0]); eprintln!(" {} verify [--ssz] - Verify Zig signature", args[0]); + eprintln!(" {} inspect - Inspect SSZ keys and report public key", args[0]); eprintln!("\n --ssz: Use SSZ serialization instead of JSON/bincode"); std::process::exit(1); } @@ -57,6 +58,22 @@ fn main() -> Result<(), Box> { let use_ssz = args.iter().any(|arg| arg == "--ssz"); match args[1].as_str() { + "inspect" => { + if args.len() < 5 { + eprintln!("Usage: {} inspect ", args[0]); + eprintln!("Example: {} inspect validator_0_sk.ssz validator_0_pk.ssz 2^32", args[0]); + std::process::exit(1); + } + let sk_path = &args[2]; + let pk_path = &args[3]; + let lifetime_str = &args[4]; + let lifetime = LifetimeTag::parse(Some(&lifetime_str.to_string()))?; + + if let Err(e) = inspect_command(sk_path, pk_path, lifetime) { + eprintln!("❌ Error: {}", e); + std::process::exit(1); + } + } "keygen" => { let seed_hex = args.get(2); let lifetime_str = args.get(3); @@ -114,6 +131,8 @@ fn keygen_command(seed_hex: Option<&String>, lifetime: LifetimeTag, use_ssz: boo .and_then(|s| s.trim().parse().ok()) .unwrap_or(256); + eprintln!("Using num_active_epochs: {}", num_active_epochs); + let seed = if let Some(hex) = seed_hex { let bytes = hex::decode(hex)?; if bytes.len() != 32 { @@ -333,6 +352,72 @@ fn sign_command(message: &str, epoch: u32, lifetime: LifetimeTag, use_ssz: bool) Ok(()) } +fn inspect_command(sk_path: &str, pk_path: &str, lifetime: LifetimeTag) -> Result<(), Box> { + eprintln!("🔍 Rust: Inspecting keys..."); + eprintln!(" Secret key: {}", sk_path); + eprintln!(" Public key: {}", pk_path); + + match lifetime { + LifetimeTag::Pow8 => { + type SkType = ::SecretKey; + type PkType = ::PublicKey; + + let sk_bytes = fs::read(sk_path)?; + let _secret_key: SkType = Decode::from_ssz_bytes(&sk_bytes) + .map_err(|e: DecodeError| format!("Failed to decode secret key: {:?}", e))?; + + let pk_bytes = fs::read(pk_path)?; + let public_key: PkType = Decode::from_ssz_bytes(&pk_bytes) + .map_err(|e: DecodeError| format!("Failed to decode public key: {:?}", e))?; + + eprintln!("✅ Successfully deserialized keys for lifetime 2^8"); + eprintln!(" Public key size: {} bytes", pk_bytes.len()); + eprintln!(" Secret key size: {} bytes", sk_bytes.len()); + eprintln!(" Public key (first 8 bytes): {:02x?}", &pk_bytes[..8.min(pk_bytes.len())]); + + Ok(()) + } + LifetimeTag::Pow18 => { + type SkType = ::SecretKey; + type PkType = ::PublicKey; + + let sk_bytes = fs::read(sk_path)?; + let _secret_key: SkType = Decode::from_ssz_bytes(&sk_bytes) + .map_err(|e: DecodeError| format!("Failed to decode secret key: {:?}", e))?; + + let pk_bytes = fs::read(pk_path)?; + let public_key: PkType = Decode::from_ssz_bytes(&pk_bytes) + .map_err(|e: DecodeError| format!("Failed to decode public key: {:?}", e))?; + + eprintln!("✅ Successfully deserialized keys for lifetime 2^18"); + eprintln!(" Public key size: {} bytes", pk_bytes.len()); + eprintln!(" Secret key size: {} bytes", sk_bytes.len()); + eprintln!(" Public key (first 8 bytes): {:02x?}", &pk_bytes[..8.min(pk_bytes.len())]); + + Ok(()) + } + LifetimeTag::Pow32 => { + type SkType = ::SecretKey; + type PkType = ::PublicKey; + + let sk_bytes = fs::read(sk_path)?; + let _secret_key: SkType = Decode::from_ssz_bytes(&sk_bytes) + .map_err(|e: DecodeError| format!("Failed to decode secret key: {:?}", e))?; + + let pk_bytes = fs::read(pk_path)?; + let public_key: PkType = Decode::from_ssz_bytes(&pk_bytes) + .map_err(|e: DecodeError| format!("Failed to decode public key: {:?}", e))?; + + eprintln!("✅ Successfully deserialized keys for lifetime 2^32"); + eprintln!(" Public key size: {} bytes", pk_bytes.len()); + eprintln!(" Secret key size: {} bytes", sk_bytes.len()); + eprintln!(" Public key (first 8 bytes): {:02x?}", &pk_bytes[..8.min(pk_bytes.len())]); + + Ok(()) + } + } +} + fn verify_command(sig_path: &str, pk_path: &str, message: &str, epoch: u32, lifetime: LifetimeTag, use_ssz: bool) -> Result<(), Box> { eprintln!("Verifying signature from Zig..."); eprintln!(" Signature: {}", sig_path); diff --git a/benchmark/zig_benchmark/src/cross_lang_zig_tool.zig b/benchmark/zig_benchmark/src/cross_lang_zig_tool.zig index 0e40aa1..093bcae 100644 --- a/benchmark/zig_benchmark/src/cross_lang_zig_tool.zig +++ b/benchmark/zig_benchmark/src/cross_lang_zig_tool.zig @@ -1,5 +1,5 @@ //! Zig tool for cross-language compatibility testing -//! +//! //! This tool provides: //! - Key generation (supports lifetime 2^8, 2^18, 2^32) //! - Serialization of secret/public keys to bincode JSON @@ -33,13 +33,13 @@ fn readLifetimeFromFile(allocator: Allocator) !KeyLifetime { return err; }; defer allocator.free(lifetime_json); - + // Remove trailing newline if present var lifetime_str = lifetime_json; if (lifetime_str.len > 0 and lifetime_str[lifetime_str.len - 1] == '\n') { - lifetime_str = lifetime_str[0..lifetime_str.len - 1]; + lifetime_str = lifetime_str[0 .. lifetime_str.len - 1]; } - + return parseLifetime(lifetime_str); } @@ -56,10 +56,11 @@ pub fn main() !void { std.debug.print(" {s} keygen [seed_hex] [lifetime] [--ssz] - Generate keypair (lifetime: 2^8, 2^18, or 2^32, default: 2^8)\n", .{args[0]}); std.debug.print(" {s} sign [--ssz] - Sign message using tmp/zig_sk.json, save to tmp/zig_sig.bin or tmp/zig_sig.ssz\n", .{args[0]}); std.debug.print(" {s} verify [--ssz] - Verify Rust signature\n", .{args[0]}); + std.debug.print(" {s} inspect - Inspect SSZ keys and report public key\n", .{args[0]}); std.debug.print("\n --ssz: Use SSZ serialization instead of JSON/bincode\n", .{}); std.process.exit(1); } - + // Check for --ssz flag var use_ssz = false; for (args) |arg| { @@ -69,7 +70,21 @@ pub fn main() !void { } } - if (std.mem.eql(u8, args[1], "keygen")) { + if (std.mem.eql(u8, args[1], "inspect")) { + if (args.len < 5) { + std.debug.print("Usage: {s} inspect \n", .{args[0]}); + std.debug.print("Example: {s} inspect validator_0_sk.ssz validator_0_pk.ssz 2^32\n", .{args[0]}); + std.process.exit(1); + } + const sk_path = args[2]; + const pk_path = args[3]; + const lifetime_str = args[4]; + const lifetime = parseLifetime(lifetime_str) catch { + std.debug.print("Error: Invalid lifetime '{s}'. Must be one of: 2^8, 2^18, 2^32\n", .{lifetime_str}); + std.process.exit(1); + }; + try inspectCommand(allocator, sk_path, pk_path, lifetime); + } else if (std.mem.eql(u8, args[1], "keygen")) { const seed_hex = if (args.len > 2) args[2] else null; const lifetime_str = if (args.len > 3) args[3] else "2^8"; const lifetime = parseLifetime(lifetime_str) catch { @@ -170,40 +185,17 @@ fn keygenCommand(allocator: Allocator, seed_hex: ?[]const u8, lifetime: KeyLifet // Remove trailing newline if present var active_epochs_str = active_epochs_file; if (active_epochs_str.len > 0 and active_epochs_str[active_epochs_str.len - 1] == '\n') { - active_epochs_str = active_epochs_str[0..active_epochs_str.len - 1]; + active_epochs_str = active_epochs_str[0 .. active_epochs_str.len - 1]; } break :blk try std.fmt.parseUnsigned(u32, active_epochs_str, 10); }; // Generate keypair - // Debug: Log RNG state before keyGen - const rng_state_before = scheme.getRngState(); - log.print("ZIG_KEYGEN_DEBUG: RNG state before keyGen: ", .{}); - for (rng_state_before) |val| { - log.print("0x{x:0>8} ", .{val}); - } - log.print("\n", .{}); - var keypair = scheme.keyGen(0, num_active_epochs) catch |err| { - log.print("ZIG_KEYGEN_ERROR: keyGen failed with error {s}\n", .{@errorName(err)}); + std.debug.print("Error: keyGen failed with {s}\n", .{@errorName(err)}); return err; }; defer keypair.secret_key.deinit(); - - // Debug: Log RNG state after keyGen - const rng_state_after_keygen = scheme.getRngState(); - log.print("ZIG_KEYGEN_DEBUG: RNG state after keyGen: ", .{}); - for (rng_state_after_keygen) |val| { - log.print("0x{x:0>8} ", .{val}); - } - log.print("\n", .{}); - - // Debug: Log the generated public key root - log.print("ZIG_KEYGEN_DEBUG: Generated public key root (canonical): ", .{}); - for (keypair.public_key.root) |fe| { - log.print("0x{x:0>8} ", .{fe.toCanonical()}); - } - log.print("\n", .{}); if (use_ssz) { // Serialize secret key to SSZ @@ -233,18 +225,7 @@ fn keygenCommand(allocator: Allocator, seed_hex: ?[]const u8, lifetime: KeyLifet // Serialize public key to JSON const pk_json = try hash_zig.serialization.serializePublicKey(allocator, &keypair.public_key); defer allocator.free(pk_json); - - // Debug: print parameter that will be written to public key file - log.print("ZIG_KEYGEN_DEBUG: parameter to be written to public key file (canonical): ", .{}); - for (0..5) |i| { - log.print("0x{x:0>8} ", .{keypair.public_key.parameter[i].toCanonical()}); - } - log.print("(Montgomery: ", .{}); - for (0..5) |i| { - log.print("0x{x:0>8} ", .{keypair.public_key.parameter[i].toMontgomery()}); - } - log.print(")\n", .{}); - + var pk_file = try std.fs.cwd().createFile("tmp/zig_pk.json", .{}); defer pk_file.close(); try pk_file.writeAll(pk_json); @@ -257,12 +238,88 @@ fn keygenCommand(allocator: Allocator, seed_hex: ?[]const u8, lifetime: KeyLifet fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime: KeyLifetime, use_ssz: bool) !void { std.debug.print("Signing message: '{s}' (epoch: {})\n", .{ message, epoch }); - // Prefer deserialization path (PRF key + parameter) for reliable reconstruction. - // This ensures we use the exact parameter and PRF key from the original keygen, - // which should produce identical trees. The seed-based path can have RNG state - // synchronization issues, especially for 2^32 lifetime. var scheme: *hash_zig.GeneralizedXMSSSignatureScheme = undefined; const keypair: hash_zig.GeneralizedXMSSSignatureScheme.KeyGenResult = blk: { + // Try to load SSZ secret key first if use_ssz is true and file exists + if (use_ssz) { + if (std.fs.cwd().readFileAlloc(allocator, "tmp/zig_sk.ssz", std.math.maxInt(usize))) |sk_ssz| { + defer allocator.free(sk_ssz); + + // Check if this is a full secret key (with trees) or minimal (just metadata) + // Minimal SSZ: exactly 68 bytes (prf_key:32 + parameter:20 + activation_epoch:8 + num_active_epochs:8) + // Full SSZ: 88+ bytes header + tree data (at least several KB even for smallest lifetime) + // A full key for 2^8 with 2 active epochs is ~3KB, so 500 bytes is a safe threshold + const is_full_secret_key = sk_ssz.len >= 500; // Threshold: >= 500 bytes means full key + + if (!is_full_secret_key) { + std.debug.print("⚠️ SSZ secret key is minimal ({} bytes), falling back to regeneration\n", .{sk_ssz.len}); + // Fall through to JSON/regeneration path + } else { + std.debug.print("✅ Loaded pre-generated full SSZ secret key ({} bytes)\n", .{sk_ssz.len}); + + // Extract lifetime from tree depth field in SSZ + // SSZ format: [header:88][top_tree_data...] + // Tree structure: [depth:8][lowest_layer:8]... + if (sk_ssz.len < 96) return error.InvalidLength; // Need at least header + tree depth + const top_tree_offset = std.mem.readInt(u32, sk_ssz[68..72], .little); + + // Validate top_tree_offset to prevent overflow and out-of-bounds access + if (top_tree_offset < 88 or top_tree_offset >= sk_ssz.len) return error.InvalidOffset; + if (sk_ssz.len - top_tree_offset < 8) return error.InvalidLength; + + const tree_depth = std.mem.readInt(u64, sk_ssz[top_tree_offset .. top_tree_offset + 8][0..8], .little); + + const actual_lifetime: KeyLifetime = switch (tree_depth) { + 8 => .lifetime_2_8, + 18 => .lifetime_2_18, + 32 => .lifetime_2_32, + else => return error.InvalidLifetime, + }; + + if (actual_lifetime != lifetime) { + std.debug.print("⚠️ Using lifetime from SSZ file (log={d}) instead of provided lifetime\n", .{tree_depth}); + } + + // Allocate secret key on heap + const secret_key = try allocator.create(hash_zig.GeneralizedXMSSSecretKey); + errdefer allocator.destroy(secret_key); + + // Deserialize the full secret key (including trees) from SSZ + try hash_zig.GeneralizedXMSSSecretKey.sszDecode(sk_ssz, secret_key, allocator); + + // Derive public key from secret key's top tree root (not from file!) + // The public key is: root = top_tree.root(), parameter = secret_key.parameter + const top_tree_root = secret_key.top_tree.root(); + const hash_len_fe: usize = switch (actual_lifetime) { + .lifetime_2_8 => 8, + .lifetime_2_18 => 7, + .lifetime_2_32 => 8, + }; + const public_key = hash_zig.GeneralizedXMSSPublicKey.init(top_tree_root, secret_key.parameter, hash_len_fe); + + std.debug.print("✅ Loaded pre-generated key (lifetime 2^{}, {} active epochs)\n", .{tree_depth, secret_key.num_active_epochs}); + + // Initialize scheme with just the lifetime - we don't need to pass PRF key as seed! + // The secret key already contains the PRF key, parameter, and all trees. + // We just need a minimal scheme with the right lifetime_params and poseidon2 for hashing. + scheme = try hash_zig.GeneralizedXMSSSignatureScheme.init(allocator, actual_lifetime); + + // Return the loaded keypair + break :blk hash_zig.GeneralizedXMSSSignatureScheme.KeyGenResult{ + .secret_key = secret_key, + .public_key = public_key, + }; + } + } else |err| { + // If SSZ file not found or error reading, fall through to JSON/regeneration path + if (err != error.FileNotFound) { + std.debug.print("⚠️ Error reading SSZ secret key: {}, falling back to JSON/regeneration\n", .{err}); + } + // Fall through to JSON path below + } + } + + // JSON/regeneration path const sk_json = std.fs.cwd().readFileAlloc(allocator, "tmp/zig_sk.json", std.math.maxInt(usize)) catch |err| { // Fallback to seed-based path if secret key file is missing const seed_file = std.fs.cwd().openFile("tmp/zig_seed.hex", .{}) catch { @@ -277,7 +334,6 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime: var seed: [32]u8 = undefined; if (hex_slice.len != 64) { - log.print("ZIG_SIGN_DEBUG: Invalid seed length in tmp/zig_seed.hex (got {}, expected 64)\n", .{hex_slice.len}); return error.InvalidSeed; } _ = try std.fmt.hexToBytes(&seed, hex_slice); @@ -297,19 +353,18 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime: // Remove trailing newline if present var active_epochs_str = active_epochs_file; if (active_epochs_str.len > 0 and active_epochs_str[active_epochs_str.len - 1] == '\n') { - active_epochs_str = active_epochs_str[0..active_epochs_str.len - 1]; + active_epochs_str = active_epochs_str[0 .. active_epochs_str.len - 1]; } break :blk2 try std.fmt.parseUnsigned(u32, active_epochs_str, 10); }; const kp = try scheme.keyGen(0, num_active_epochs); - log.print("ZIG_SIGN_DEBUG: Reconstructed keypair from seed (fallback path)\n", .{}); break :blk kp; }; defer allocator.free(sk_json); const sk_data = try hash_zig.serialization.deserializeSecretKeyData(allocator, sk_json); - + // Use the original seed (not PRF key) to ensure RNG state matches original keygen // The PRF key was generated from the seed, so we need to start from the seed // and consume RNG state to match where we were after generating parameter and PRF key @@ -317,7 +372,6 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime: // If seed file is missing, fall back to using PRF key as seed (may not match exactly) scheme = try hash_zig.GeneralizedXMSSSignatureScheme.initWithSeed(allocator, lifetime, sk_data.prf_key); const kp = try scheme.keyGenWithParameter(sk_data.activation_epoch, sk_data.num_active_epochs, sk_data.parameter, sk_data.prf_key, false); - log.print("ZIG_SIGN_DEBUG: Reconstructed keypair from PRF key + parameter (no seed file)\n", .{}); break :blk kp; }; defer seed_file.close(); @@ -329,14 +383,13 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime: var seed: [32]u8 = undefined; if (seed_hex_slice.len != 64) { - log.print("ZIG_SIGN_DEBUG: Invalid seed length in tmp/zig_seed.hex (got {}, expected 64)\n", .{seed_hex_slice.len}); return error.InvalidSeed; } _ = try std.fmt.hexToBytes(&seed, seed_hex_slice); // Initialize with original seed to match RNG state from keygen scheme = try hash_zig.GeneralizedXMSSSignatureScheme.initWithSeed(allocator, lifetime, seed); - + // CRITICAL: We need to match the RNG state exactly as it was when keyGenWithParameter // was called from keyGen(). In keyGen(), the flow is: // 1. generateRandomParameter() - peeks 20 bytes (doesn't consume) @@ -376,61 +429,14 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime: // We've already consumed 32 bytes to match PRF key generation, so pass true const kp = try scheme.keyGenWithParameter(sk_data.activation_epoch, sk_data.num_active_epochs, sk_data.parameter, sk_data.prf_key, true); - - log.print("ZIG_SIGN_DEBUG: Reconstructed keypair from PRF key + parameter with original seed (preferred path)\n", .{}); break :blk kp; }; - + // Keep scheme alive for signing - it's needed for the sign() call defer scheme.deinit(); defer keypair.secret_key.deinit(); - + const secret_key = keypair.secret_key; - - // CRITICAL DEBUG: Verify the secret key's top tree root matches the public key root - const top_tree_root = secret_key.top_tree.root(); - log.print("ZIG_SIGN_DEBUG: Top tree root from secret key (canonical): ", .{}); - for (top_tree_root) |fe| { - log.print("0x{x:0>8} ", .{fe.toCanonical()}); - } - log.print("\n", .{}); - log.print("ZIG_SIGN_DEBUG: Public key root (canonical): ", .{}); - for (keypair.public_key.root) |fe| { - log.print("0x{x:0>8} ", .{fe.toCanonical()}); - } - log.print("\n", .{}); - - var root_match = true; - for (0..8) |i| { - if (!top_tree_root[i].eql(keypair.public_key.root[i])) { - log.debugPrint("ZIG_SIGN_ERROR: Top tree root[{}] mismatch: computed=0x{x:0>8} (canonical) / 0x{x:0>8} (monty) expected=0x{x:0>8} (canonical) / 0x{x:0>8} (monty)\n", .{ - i, - top_tree_root[i].toCanonical(), - top_tree_root[i].toMontgomery(), - keypair.public_key.root[i].toCanonical(), - keypair.public_key.root[i].toMontgomery(), - }); - root_match = false; - } - } - if (!root_match) { - log.debugPrint("ZIG_SIGN_ERROR: Top tree root does not match public key root! This indicates the regenerated keypair is inconsistent.\n", .{}); - log.debugPrint("ZIG_SIGN_ERROR: This will cause verification to fail. The signature will be generated with trees that don't match the public key.\n", .{}); - // Continue anyway to see the full error - } else { - log.debugPrint("ZIG_SIGN_DEBUG: Top tree root matches public key root ✓\n", .{}); - } - - // CRITICAL DEBUG: Verify the secret key has the correct parameter - log.print("ZIG_SIGN_DEBUG_STEP4: Secret key parameter after keyGenWithParameter (canonical): ", .{}); - for (0..5) |i| { - log.print("0x{x:0>8} ", .{secret_key.getParameter()[i].toCanonical()}); - } - log.print("(Montgomery: ", .{}); - for (0..5) |i| { - log.print("0x{x:0>8} ", .{secret_key.getParameter()[i].toMontgomery()}); - } - log.print(")\n", .{}); // Convert message to 32 bytes var msg_bytes: [32]u8 = undefined; @@ -438,56 +444,21 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime: @memset(msg_bytes[0..], 0); @memcpy(msg_bytes[0..len], message[0..len]); - // CRITICAL: Verify parameter match before signing - std.debug.print("ZIG_SIGN_DEBUG: Checking parameter match before signing:\n", .{}); - std.debug.print("ZIG_SIGN_DEBUG: secret_key.parameter (canonical): ", .{}); - for (0..5) |i| { - std.debug.print("0x{x:0>8} ", .{secret_key.getParameter()[i].toCanonical()}); - } - std.debug.print("\nZIG_SIGN_DEBUG: public_key.parameter (canonical): ", .{}); - for (0..5) |i| { - std.debug.print("0x{x:0>8} ", .{keypair.public_key.parameter[i].toCanonical()}); - } - std.debug.print("\n", .{}); - // Verify parameters match - var param_match = true; for (0..5) |i| { if (!secret_key.getParameter()[i].eql(keypair.public_key.parameter[i])) { - log.debugPrint("ZIG_SIGN_ERROR: Parameter mismatch at index {}!\n", .{i}); - param_match = false; + return error.ParameterMismatch; } } - if (!param_match) { - log.debugPrint("ZIG_SIGN_ERROR: secret_key.parameter does not match public_key.parameter!\n", .{}); - return error.ParameterMismatch; - } - std.debug.print("ZIG_SIGN_DEBUG: Parameters match ✓\n", .{}); // Sign the message var signature = try scheme.sign(secret_key, epoch, msg_bytes); defer signature.deinit(); - // In-memory self-check: verify immediately using the same keypair and message. - // Debug: Print stored hash from signature before verification - const stored_hashes = signature.getHashes(); - if (stored_hashes.len > 0) { - log.debugPrint("ZIG_SIGN_DEBUG: Stored hash[0] before verification (Montgomery): ", .{}); - for (0..@min(8, stored_hashes[0].len)) |h| { - std.debug.print("0x{x:0>8} ", .{stored_hashes[0][h].value}); - } - std.debug.print("\n", .{}); - log.debugPrint("ZIG_SIGN_DEBUG: Stored hash[0] before verification (Canonical): ", .{}); - for (0..@min(8, stored_hashes[0].len)) |h| { - std.debug.print("0x{x:0>8} ", .{stored_hashes[0][h].toCanonical()}); - } - std.debug.print("\n", .{}); - } + // In-memory self-check: verify immediately using the same keypair and message const in_memory_valid = try scheme.verify(&keypair.public_key, epoch, msg_bytes, signature); - if (in_memory_valid) { - log.debugPrint("ZIG_SIGN_DEBUG: In-memory sign→verify PASSED for epoch {}\n", .{epoch}); - } else { - log.debugPrint("ZIG_SIGN_DEBUG: In-memory sign→verify FAILED for epoch {}\n", .{epoch}); + if (!in_memory_valid) { + std.debug.print("Warning: In-memory verification failed\n", .{}); } if (use_ssz) { @@ -497,7 +468,7 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime: var pk_file = try std.fs.cwd().createFile("tmp/zig_pk.ssz", .{}); defer pk_file.close(); try pk_file.writeAll(pk_bytes); - std.debug.print("✅ Public key updated to tmp/zig_pk.ssz ({} bytes, from regenerated keypair)\n", .{pk_bytes.len}); + std.debug.print("✅ Public key saved to tmp/zig_pk.ssz ({} bytes)\n", .{pk_bytes.len}); // Serialize signature to SSZ const sig_bytes = try signature.toBytes(allocator); @@ -529,7 +500,7 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime: const rand_len = scheme.lifetime_params.rand_len_fe; const hash_len = scheme.lifetime_params.hash_len_fe; try remote_hash_tool.writeSignatureBincode("tmp/zig_sig.bin", signature, rand_len, hash_len); - + // Pad to exactly 3116 bytes as per leanSignature spec (Bytes3116 container) const SIG_LEN: usize = 3116; var sig_file = try std.fs.cwd().openFile("tmp/zig_sig.bin", .{ .mode = .read_write }); @@ -556,6 +527,96 @@ fn signCommand(allocator: Allocator, message: []const u8, epoch: u32, lifetime: std.debug.print("Message signed successfully!\n", .{}); } +fn inspectCommand(allocator: Allocator, sk_path: []const u8, pk_path: []const u8, lifetime: KeyLifetime) !void { + std.debug.print("🔍 Zig: Inspecting keys...\n", .{}); + std.debug.print(" Secret key: {s}\n", .{sk_path}); + std.debug.print(" Public key: {s}\n", .{pk_path}); + + // Read secret key + const sk_bytes = try std.fs.cwd().readFileAlloc(allocator, sk_path, std.math.maxInt(usize)); + defer allocator.free(sk_bytes); + + // Read public key + const pk_bytes = try std.fs.cwd().readFileAlloc(allocator, pk_path, std.math.maxInt(usize)); + defer allocator.free(pk_bytes); + + // Parse metadata from SSZ header without fully deserializing trees + // SSZ format: [prf_key:32][parameter:20][activation_epoch:8][num_active_epochs:8][top_tree_offset:4][left_bottom_tree_index:8]... + if (sk_bytes.len < 88) return error.InvalidLength; + + var offset: usize = 0; + + // Skip prf_key (32 bytes) + offset += 32; + + // Skip parameter (20 bytes for 5 u32s) + offset += 20; + + // Read activation_epoch (u64) + const activation_epoch = std.mem.readInt(u64, sk_bytes[offset .. offset + 8][0..8], .little); + offset += 8; + + // Read num_active_epochs (u64) + const num_active_epochs = std.mem.readInt(u64, sk_bytes[offset .. offset + 8][0..8], .little); + offset += 8; + + // Read top_tree_offset (u32) + const top_tree_offset = std.mem.readInt(u32, sk_bytes[offset .. offset + 4][0..4], .little); + offset += 4; + + // Validate top_tree_offset to prevent overflow and out-of-bounds access + if (top_tree_offset < 88 or top_tree_offset >= sk_bytes.len) return error.InvalidOffset; + if (sk_bytes.len - top_tree_offset < 8) return error.InvalidLength; + + // Read left_bottom_tree_index (u64) + const left_bottom_tree_index = std.mem.readInt(u64, sk_bytes[offset .. offset + 8][0..8], .little); + + // Extract lifetime from tree depth field + // Tree structure: [depth:8][lowest_layer:8][layers_offset:4]... + const tree_depth = std.mem.readInt(u64, sk_bytes[top_tree_offset .. top_tree_offset + 8][0..8], .little); + + // Determine actual lifetime from tree depth (log_lifetime) + const actual_lifetime: KeyLifetime = switch (tree_depth) { + 8 => .lifetime_2_8, + 18 => .lifetime_2_18, + 32 => .lifetime_2_32, + else => return error.InvalidLifetime, + }; + + // Verify it matches the provided lifetime parameter + if (actual_lifetime != lifetime) { + std.debug.print("⚠️ WARNING: Provided lifetime {s} doesn't match SSZ file lifetime (log={d})\n", .{ + switch (lifetime) { + .lifetime_2_8 => "2^8", + .lifetime_2_18 => "2^18", + .lifetime_2_32 => "2^32", + }, + tree_depth, + }); + } + + // Deserialize public key to verify it's valid + _ = try hash_zig.GeneralizedXMSSPublicKey.fromBytes(pk_bytes, null); + + const lifetime_str = switch (actual_lifetime) { + .lifetime_2_8 => "2^8", + .lifetime_2_18 => "2^18", + .lifetime_2_32 => "2^32", + }; + + std.debug.print("✅ Successfully deserialized keys for lifetime {s}\n", .{lifetime_str}); + std.debug.print(" Public key size: {} bytes\n", .{pk_bytes.len}); + std.debug.print(" Secret key size: {} bytes\n", .{sk_bytes.len}); + std.debug.print(" Activation epoch: {}\n", .{activation_epoch}); + std.debug.print(" Num active epochs: {}\n", .{num_active_epochs}); + std.debug.print(" Left bottom tree index: {}\n", .{left_bottom_tree_index}); + std.debug.print(" Public key (first 8 bytes): ", .{}); + for (pk_bytes[0..@min(8, pk_bytes.len)]) |byte| { + std.debug.print("{x:0>2}", .{byte}); + } + std.debug.print("\n", .{}); +} + fn verifyCommand(allocator: Allocator, sig_path: []const u8, pk_path: []const u8, message: []const u8, epoch: u32, lifetime: KeyLifetime, use_ssz: bool) !void { std.debug.print("Verifying signature from Rust...\n", .{}); std.debug.print(" Signature: {s}\n", .{sig_path}); @@ -568,11 +629,11 @@ fn verifyCommand(allocator: Allocator, sig_path: []const u8, pk_path: []const u8 var scheme = try hash_zig.GeneralizedXMSSSignatureScheme.init(allocator, lifetime); defer scheme.deinit(); - + var signature: *hash_zig.GeneralizedXMSSSignature = undefined; var public_key: hash_zig.GeneralizedXMSSPublicKey = undefined; var sig_bytes_opt: ?[]u8 = null; // Keep sig_bytes alive if using SSZ - + if (use_ssz) { // Load signature from SSZ format const sig_bytes = try std.fs.cwd().readFileAlloc(allocator, sig_path, std.math.maxInt(usize)); @@ -580,14 +641,14 @@ fn verifyCommand(allocator: Allocator, sig_path: []const u8, pk_path: []const u8 std.debug.print("DEBUG: sig_bytes allocated at 0x{x}, len={}\n", .{ @intFromPtr(sig_bytes.ptr), sig_bytes.len }); signature = try hash_zig.GeneralizedXMSSSignature.fromBytes(sig_bytes, allocator); std.debug.print("DEBUG: signature allocated at 0x{x}\n", .{@intFromPtr(signature)}); - + // Validate signature struct is accessible before verify std.debug.print("DEBUG: Signature struct at 0x{x}, path=0x{x}, rho[0]=0x{x}\n", .{ @intFromPtr(signature), @intFromPtr(signature.path), signature.rho[0].toCanonical(), }); - + // Load public key from SSZ format const pk_bytes = try std.fs.cwd().readFileAlloc(allocator, pk_path, std.math.maxInt(usize)); defer allocator.free(pk_bytes); @@ -600,39 +661,15 @@ fn verifyCommand(allocator: Allocator, sig_path: []const u8, pk_path: []const u8 const max_path_len: usize = scheme.lifetime_params.final_layer; const hash_len = scheme.lifetime_params.hash_len_fe; const max_hashes: usize = scheme.lifetime_params.dimension; - + // The readSignatureBincode function reads from file path directly signature = try remote_hash_tool.readSignatureBincode(sig_path, allocator, rand_len, max_path_len, hash_len, max_hashes); - - // Debug: print rho from signature right after reading (before verify) - const rho_after_read = signature.getRho(); - log.print("ZIG_VERIFY_DEBUG: rho from signature.getRho() RIGHT AFTER READ (Montgomery): ", .{}); - for (0..rand_len) |i| { - log.print("0x{x:0>8} ", .{rho_after_read[i].toMontgomery()}); - } - log.print("\n", .{}); - - // Debug: print which public key file we're reading from - log.print("ZIG_VERIFY_DEBUG: Reading public key from file: {s}\n", .{pk_path}); // Load public key from Rust const pk_json = try std.fs.cwd().readFileAlloc(allocator, pk_path, std.math.maxInt(usize)); defer allocator.free(pk_json); public_key = try hash_zig.serialization.deserializePublicKey(pk_json); } - - // Debug: print parameter from public key right after reading - log.print("ZIG_VERIFY_DEBUG: parameter from public key file (canonical): ", .{}); - for (0..5) |i| { - log.print("0x{x:0>8} ", .{public_key.parameter[i].toCanonical()}); - } - log.print("(Montgomery: ", .{}); - for (0..5) |i| { - log.print("0x{x:0>8} ", .{public_key.parameter[i].toMontgomery()}); - } - log.print(")\n", .{}); - - // Scheme already initialized above // Convert message to 32 bytes var msg_bytes: [32]u8 = undefined; @@ -640,10 +677,6 @@ fn verifyCommand(allocator: Allocator, sig_path: []const u8, pk_path: []const u8 @memset(msg_bytes[0..], 0); @memcpy(msg_bytes[0..len], message[0..len]); - // Debug: verify signature struct is still valid before calling verify - log.print("ZIG_VERIFY_DEBUG: Before verify call - signature=0x{x}\n", .{@intFromPtr(signature)}); - // Don't access struct fields here - let verify() handle it - // Verify the signature const is_valid = try scheme.verify(&public_key, epoch, msg_bytes, signature); @@ -660,4 +693,3 @@ fn verifyCommand(allocator: Allocator, sig_path: []const u8, pk_path: []const u8 std.process.exit(1); } } - diff --git a/src/prf/shake_prf_to_field.zig b/src/prf/shake_prf_to_field.zig index 82ff88e..a897bc5 100644 --- a/src/prf/shake_prf_to_field.zig +++ b/src/prf/shake_prf_to_field.zig @@ -6,7 +6,8 @@ const crypto = std.crypto; const plonky3_field = @import("../poseidon2/plonky3_field.zig"); // Constants matching Rust implementation -const PRF_BYTES_PER_FE: usize = 8; +// CRITICAL: Rust uses 16 bytes per FE (reads as u128), not 8! +const PRF_BYTES_PER_FE: usize = 16; const KEY_LENGTH: usize = 32; // 32 bytes const MESSAGE_LENGTH: usize = 32; // From Rust hash-sig @@ -71,9 +72,9 @@ pub fn ShakePRFtoF(comptime DOMAIN_LENGTH_FE: usize, comptime RAND_LENGTH_FE: us const chunk_start = i * PRF_BYTES_PER_FE; const chunk_end = chunk_start + PRF_BYTES_PER_FE; - // Convert big-endian bytes to u64 + // Convert big-endian bytes to u128 (matching Rust's from_u128) const bytes_array: [PRF_BYTES_PER_FE]u8 = prf_output[chunk_start..chunk_end][0..PRF_BYTES_PER_FE].*; - const integer_value = std.mem.readInt(u64, &bytes_array, .big); + const integer_value = std.mem.readInt(u128, &bytes_array, .big); // Reduce modulo KoalaBear field order and map into Montgomery form const reduced: u32 = @intCast(integer_value % KOALA_BEAR_MODULUS); @@ -119,9 +120,9 @@ pub fn ShakePRFtoF(comptime DOMAIN_LENGTH_FE: usize, comptime RAND_LENGTH_FE: us const chunk_start = i * PRF_BYTES_PER_FE; const chunk_end = chunk_start + PRF_BYTES_PER_FE; - // Convert big-endian bytes to u64 + // Convert big-endian bytes to u128 (matching Rust's from_u128) const bytes_array: [PRF_BYTES_PER_FE]u8 = prf_output[chunk_start..chunk_end][0..PRF_BYTES_PER_FE].*; - const integer_value = std.mem.readInt(u64, &bytes_array, .big); + const integer_value = std.mem.readInt(u128, &bytes_array, .big); // Reduce modulo KoalaBear field order and map into Montgomery form const reduced: u32 = @intCast(integer_value % KOALA_BEAR_MODULUS); diff --git a/src/signature/native/scheme.zig b/src/signature/native/scheme.zig index 0bee474..abfed2f 100644 --- a/src/signature/native/scheme.zig +++ b/src/signature/native/scheme.zig @@ -167,6 +167,7 @@ pub const HashSubTree = struct { root_value: [8]FieldElement, layers: ?[]PaddedLayer, allocator: std.mem.Allocator, + depth: usize, // The tree depth (log_lifetime for full trees, log_lifetime/2 for bottom trees) pub fn init(allocator: std.mem.Allocator, root_value: [8]FieldElement) !*HashSubTree { const self = try allocator.create(HashSubTree); @@ -174,6 +175,7 @@ pub const HashSubTree = struct { .root_value = root_value, .layers = null, .allocator = allocator, + .depth = 0, // Default depth for trees without layers }; return self; } @@ -182,12 +184,14 @@ pub const HashSubTree = struct { allocator: std.mem.Allocator, root_value: [8]FieldElement, layers: []PaddedLayer, + depth: usize, ) !*HashSubTree { const self = try allocator.create(HashSubTree); self.* = HashSubTree{ .root_value = root_value, .layers = layers, .allocator = allocator, + .depth = depth, }; return self; } @@ -215,6 +219,92 @@ pub const HashSubTree = struct { }; /// Helper function to deserialize HashSubTree from leansig SSZ format +/// Serialize HashSubTree to SSZ format (matching Rust leansig) +fn serializeHashSubTree(tree: *const HashSubTree, l: *std.ArrayList(u8)) !void { + // Format: [depth:8][lowest_layer:8][layers_offset:4][layers_data] + const layers = tree.getLayers() orelse return error.NoLayers; + + // Write depth (u64) - Rust stores the tree depth (log_lifetime), not layers.len + try ssz.serialize(u64, @as(u64, @intCast(tree.depth)), l); + + // Write lowest_layer (u64) - always 0 for our trees + try ssz.serialize(u64, @as(u64, 0), l); + + // Write layers_offset (u32) - points to start of layers array + const layers_offset: u32 = 20; // 8 + 8 + 4 = 20 + try ssz.serialize(u32, layers_offset, l); + + // Now serialize layers array + // First, serialize all layers to get their sizes + var layer_bytes_list = try tree.allocator.alloc(std.ArrayList(u8), layers.len); + defer { + for (layer_bytes_list) |*lb| { + lb.deinit(); + } + tree.allocator.free(layer_bytes_list); + } + + for (layers, 0..) |layer, i| { + layer_bytes_list[i] = std.ArrayList(u8).init(tree.allocator); + // Check if this is a padded single-node root layer: + // - It's the last layer + // - It has exactly 2 nodes + // - The layer 2 levels down (i-2) has MORE than 2 nodes (indicating natural convergence to 1 root) + // For top tree: layer i-2 has 2 nodes → layer i has 2 real nodes (no padding skip) + // For bottom tree: layer i-2 has 4+ nodes → layer i has 1 real node + padding (skip padding) + const is_padded_single_root = if (i == layers.len - 1 and layer.nodes.len == 2 and i >= 2) blk: { + const two_below = layers[i - 2]; + break :blk two_below.nodes.len > 2; // More than 2 nodes → converges to 1 → padded + } else false; + try serializePaddedLayer(&layer, &layer_bytes_list[i], is_padded_single_root); + } + + // Write layer offsets (relative to start of layers array) + var current_offset: u32 = @as(u32, @intCast(layers.len * 4)); // Space for offsets + for (layer_bytes_list) |*lb| { + try ssz.serialize(u32, current_offset, l); + current_offset += @as(u32, @intCast(lb.items.len)); + } + + // Write layer data + for (layer_bytes_list) |*lb| { + try l.appendSlice(lb.items); + } +} + +/// Serialize PaddedLayer to SSZ format (matching Rust leansig) +fn serializePaddedLayer(layer: *const PaddedLayer, l: *std.ArrayList(u8), skip_padding: bool) !void { + // Format: [start_index:8][nodes_offset:4][nodes_data] + + // Write start_index (u64) + try ssz.serialize(u64, @as(u64, @intCast(layer.start_index)), l); + + // Write nodes_offset (u32) - points to start of nodes array + const nodes_offset: u32 = 12; // 8 + 4 = 12 + try ssz.serialize(u32, nodes_offset, l); + + // CRITICAL: For padded single-node roots, Rust only serializes the real node (not padding) + // - If start_index is even (0), padding is at the back: serialize nodes[0] + // - If start_index is odd (1), padding is at the front: serialize nodes[1] + const nodes_to_serialize = if (skip_padding) blk: { + if ((layer.start_index & 1) == 0) { + // Even start_index: real node is first, padding is last + break :blk layer.nodes[0..1]; + } else { + // Odd start_index: padding is first, real node is last + break :blk layer.nodes[1..2]; + } + } else layer.nodes; + + // Write nodes as raw field element arrays (no length prefix, Vec<[FE; 8]> in Rust) + for (nodes_to_serialize) |node| { + // Each node is [8]FieldElement, serialize as 8 u32s in canonical form + for (node) |fe| { + try ssz.serialize(u32, fe.toCanonical(), l); + } + } +} + fn deserializeHashSubTree(allocator: std.mem.Allocator, serialized: []const u8) !*HashSubTree { if (serialized.len < 20) return error.InvalidLength; @@ -228,7 +318,6 @@ fn deserializeHashSubTree(allocator: std.mem.Allocator, serialized: []const u8) const lowest_layer = std.mem.readInt(u64, serialized[offset .. offset + 8][0..8], .little); offset += 8; _ = lowest_layer; - _ = depth; // Decode layers_offset (u32) const layers_offset = std.mem.readInt(u32, serialized[offset .. offset + 4][0..4], .little); @@ -268,13 +357,14 @@ fn deserializeHashSubTree(allocator: std.mem.Allocator, serialized: []const u8) layers[i] = try deserializePaddedLayer(allocator, layer_bytes); } - // Extract root from the last layer's last node - const root_value = if (layers.len > 0 and layers[layers.len - 1].nodes.len > 0) - layers[layers.len - 1].nodes[layers[layers.len - 1].nodes.len - 1] - else - [_]FieldElement{FieldElement{ .value = 0 }} ** 8; + // Extract root from the last layer's FIRST node (not last!) + // In leansig's tree structure, the root is stored as the first node of the last layer + const root_value = if (layers.len > 0 and layers[layers.len - 1].nodes.len > 0) blk: { + const root_node = layers[layers.len - 1].nodes[0]; // FIRST node, not last! + break :blk root_node; + } else [_]FieldElement{FieldElement{ .value = 0 }} ** 8; - return try HashSubTree.initWithLayers(allocator, root_value, layers); + return try HashSubTree.initWithLayers(allocator, root_value, layers, @intCast(depth)); } /// Helper function to deserialize PaddedLayer from leansig SSZ format @@ -296,22 +386,23 @@ fn deserializePaddedLayer(allocator: std.mem.Allocator, serialized: []const u8) errdefer allocator.free(nodes); // Deserialize nodes - // CRITICAL: Leansig stores field elements in SSZ as Montgomery form, NOT canonical! - // This is different from how we encode (we use canonical), but we must match leansig's format + // Leansig stores field elements in SSZ as CANONICAL form (confirmed by inspecting bytes) for (0..num_nodes) |i| { for (0..8) |j| { const val = std.mem.readInt(u32, nodes_data[i * 32 + j * 4 .. i * 32 + j * 4 + 4][0..4], .little); - // Leansig stores as Montgomery values directly, so use fromMontgomery - nodes[i][j] = FieldElement.fromMontgomery(val); - - if (i == num_nodes - 1 and j == 0) { - // Debug last node first element (the root) - std.debug.print("TREE_SSZ_DECODE: Last layer last node first element: raw_u32=0x{x:0>8}, as_montgomery=0x{x:0>8}, as_canonical=0x{x:0>8}\n", .{ - val, - nodes[i][j].value, - nodes[i][j].toCanonical(), - }); - } + nodes[i][j] = FieldElement.fromCanonical(val); + } + } + + // Verify first node canonical matches raw value + if (num_nodes > 0) { + const first_node_raw = std.mem.readInt(u32, nodes_data[0..4], .little); + const first_node_canonical = nodes[0][0].toCanonical(); + if (first_node_raw != first_node_canonical) { + std.debug.print("TREE_DECODE_ERROR: First node mismatch! raw=0x{x:0>8}, canonical=0x{x:0>8}\n", .{ + first_node_raw, + first_node_canonical, + }); } } @@ -509,7 +600,8 @@ const BottomTreeCache = struct { }; } - return try HashSubTree.initWithLayers(allocator, root_value, layers); + // Cache doesn't store depth, so we use 0 as a placeholder + return try HashSubTree.initWithLayers(allocator, root_value, layers, 0); } pub fn store( @@ -1177,11 +1269,6 @@ pub const GeneralizedXMSSPublicKey = struct { // Direct little-endian read instead of ssz.deserialize which may have issues const bytes = serialized[root_offset .. root_offset + 4]; const val = std.mem.readInt(u32, bytes[0..4], .little); - if (i == 0) { - std.debug.print("PK_SSZ_DECODE: First 4 bytes: {x:0>2}{x:0>2}{x:0>2}{x:0>2} -> u32=0x{x:0>8}\n", .{ - bytes[0], bytes[1], bytes[2], bytes[3], val, - }); - } root_canonical[i] = val; root_offset += 4; } @@ -1328,10 +1415,12 @@ pub const GeneralizedXMSSSecretKey = struct { // SSZ serialization methods pub fn sszEncode(self: *const GeneralizedXMSSSecretKey, l: *std.ArrayList(u8)) !void { - // Note: This is a simplified encoding that doesn't include trees (for compatibility with existing code) - // For full leansig-compatible encoding with trees, use a separate method + // Full leansig-compatible encoding with trees + // Format: [prf_key:32][parameter:20][activation_epoch:8][num_active_epochs:8] + // [top_tree_offset:4][left_bottom_tree_index:8][left_bottom_tree_offset:4][right_bottom_tree_offset:4] + // [top_tree_data][left_bottom_tree_data][right_bottom_tree_data] - // Encode prf_key (32 bytes, fixed-size) + // Encode fixed-size fields try ssz.serialize([32]u8, self.prf_key, l); // Convert parameter to canonical u32 array and serialize (20 bytes for 5 u32s) @@ -1346,6 +1435,40 @@ pub const GeneralizedXMSSSecretKey = struct { // Encode num_active_epochs as u64 (8 bytes) try ssz.serialize(u64, @as(u64, @intCast(self.num_active_epochs)), l); + + // Now we're at offset 68 (32+20+8+8) + // Encode offsets for variable-size fields + const fixed_part_end: u32 = 88; // 68 + 4 + 8 + 4 + 4 = 88 + + // Serialize top_tree to get its size + var top_tree_bytes = std.ArrayList(u8).init(self.allocator); + defer top_tree_bytes.deinit(); + try serializeHashSubTree(self.top_tree, &top_tree_bytes); + + // Serialize left_bottom_tree to get its size + var left_bottom_tree_bytes = std.ArrayList(u8).init(self.allocator); + defer left_bottom_tree_bytes.deinit(); + try serializeHashSubTree(self.left_bottom_tree, &left_bottom_tree_bytes); + + // Serialize right_bottom_tree to get its size + var right_bottom_tree_bytes = std.ArrayList(u8).init(self.allocator); + defer right_bottom_tree_bytes.deinit(); + try serializeHashSubTree(self.right_bottom_tree, &right_bottom_tree_bytes); + + // Write offsets + const top_tree_offset = fixed_part_end; + const left_bottom_tree_offset = top_tree_offset + @as(u32, @intCast(top_tree_bytes.items.len)); + const right_bottom_tree_offset = left_bottom_tree_offset + @as(u32, @intCast(left_bottom_tree_bytes.items.len)); + + try ssz.serialize(u32, top_tree_offset, l); + try ssz.serialize(u64, @as(u64, @intCast(self.left_bottom_tree_index)), l); + try ssz.serialize(u32, left_bottom_tree_offset, l); + try ssz.serialize(u32, right_bottom_tree_offset, l); + + // Write tree data + try l.appendSlice(top_tree_bytes.items); + try l.appendSlice(left_bottom_tree_bytes.items); + try l.appendSlice(right_bottom_tree_bytes.items); } pub fn sszDecode(serialized: []const u8, out: *GeneralizedXMSSSecretKey, allocator: ?std.mem.Allocator) !void { @@ -1397,12 +1520,21 @@ pub const GeneralizedXMSSSecretKey = struct { // Deserialize top_tree const top_tree = try deserializeHashSubTree(alloc, serialized[top_tree_offset..left_bottom_tree_offset]); + errdefer top_tree.deinit(); // Deserialize left_bottom_tree - const left_bottom_tree = try deserializeHashSubTree(alloc, serialized[left_bottom_tree_offset..right_bottom_tree_offset]); + const left_bottom_tree = deserializeHashSubTree(alloc, serialized[left_bottom_tree_offset..right_bottom_tree_offset]) catch |err| { + top_tree.deinit(); + return err; + }; + errdefer left_bottom_tree.deinit(); // Deserialize right_bottom_tree - const right_bottom_tree = try deserializeHashSubTree(alloc, serialized[right_bottom_tree_offset..]); + const right_bottom_tree = deserializeHashSubTree(alloc, serialized[right_bottom_tree_offset..]) catch |err| { + left_bottom_tree.deinit(); + top_tree.deinit(); + return err; + }; // Initialize the secret key out.* = GeneralizedXMSSSecretKey{ @@ -2261,7 +2393,9 @@ pub const GeneralizedXMSSSignatureScheme = struct { } // Store layers in HashSubTree so they can be reused during signing (major optimization!) - return try HashSubTree.initWithLayers(self.allocator, bottom_root, bottom_layers); + // Bottom tree depth is log_lifetime (32 for 2^32), matching Rust's encoding + const tree_depth = self.lifetime_params.log_lifetime; + return try HashSubTree.initWithLayers(self.allocator, bottom_root, bottom_layers, tree_depth); } /// Compute hash chain (matching Rust chain function) @@ -4666,11 +4800,16 @@ pub const GeneralizedXMSSSignatureScheme = struct { activation_epoch: usize, num_active_epochs: usize, ) !KeyGenResult { + // CRITICAL: Rust leansig library multiplies num_active_epochs by 128 internally + // To match Rust's behavior exactly, we multiply by 128 here + // Example: Input 1024 -> Rust stores 131072 (1024 * 128) in SSZ + const rust_compatible_num_active_epochs = num_active_epochs * 128; + // Generate random parameter and PRF key (matching Rust order exactly) const parameter = try self.generateRandomParameter(); const prf_key = try self.generateRandomPRFKey(); // RNG has already been consumed by generateRandomPRFKey() (32 bytes) - return self.keyGenWithParameter(activation_epoch, num_active_epochs, parameter, prf_key, true); + return self.keyGenWithParameter(activation_epoch, rust_compatible_num_active_epochs, parameter, prf_key, true); } /// Key generation with provided parameter and PRF key (for reconstructing keys from serialized data) @@ -5029,7 +5168,9 @@ pub const GeneralizedXMSSSignatureScheme = struct { } // Create a top tree for the secret key, preserving the layered structure for future path computation - const top_tree = try HashSubTree.initWithLayers(self.allocator, root_array, top_layers); + // Top tree depth is log_lifetime (32 for 2^32), matching Rust's encoding + const tree_depth = self.lifetime_params.log_lifetime; + const top_tree = try HashSubTree.initWithLayers(self.allocator, root_array, top_layers, tree_depth); top_layers = top_layers[0..0]; // Create public and secret keys (store root in Montgomery form to match Rust)