"""
ch15_pq_settlement_ledger.py
Vol II · Chapter 15 — Post-Quantum Settlement Ledger
=====================================================
Assembles the chapter primitives into a complete example:

  1. QuantumSafeLedger  — append-only ML-DSA-65 signed chain
  2. Performance        — ECDSA-256 vs ML-DSA-65 benchmark
  3. SLH-DSA archival   — hash-based signature for regulatory records
  4. Cryptographic agility — configurable signing backend pattern

Requirements:
    pip install liboqs-python cryptography
"""

import copy
import hashlib
import json
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional

import oqs
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec


# =============================================================================
# SECTION 1 — Minimal Post-Quantum Settlement Ledger
# =============================================================================

@dataclass
class LedgerEntry:
    """A single entry in the settlement ledger.

    ``entry_hash`` is derived in ``__post_init__`` from every other field,
    including the signer's public key, so the hash chain commits to who
    signed an instruction as well as to what was signed. The dataclass is
    not frozen: fields can still be reassigned after construction, so the
    stored ``entry_hash`` is only trustworthy if it is recomputed
    (``compute_hash()``) and compared.
    """
    entry_id:      int
    timestamp:     float
    instruction:   dict
    signer_pubkey: bytes          # ML-DSA-65 public key (1,952 B)
    signature:     bytes          # ML-DSA-65 signature  (3,309 B)
    prev_hash:     str            # SHA-256 of previous entry
    entry_hash:    str = field(init=False)

    def __post_init__(self):
        self.entry_hash = self.compute_hash()

    def compute_hash(self) -> str:
        """Return the SHA-256 hex digest of this entry's canonical JSON form.

        Bytes fields are hex-encoded and keys are sorted so the digest is
        deterministic. ``entry_hash`` itself is excluded (it is the output).
        """
        content = json.dumps({
            "entry_id":      self.entry_id,
            "timestamp":     self.timestamp,
            "instruction":   self.instruction,
            "signer_pubkey": self.signer_pubkey.hex(),
            "signature":     self.signature.hex(),
            "prev_hash":     self.prev_hash,
        }, sort_keys=True).encode("utf-8")
        return hashlib.sha256(content).hexdigest()


class QuantumSafeLedger:
    """An append-only ledger of ML-DSA-65 signed settlement instructions.

    Participants register their public keys with ``register_participant``.
    ``submit`` verifies a signature over the canonical JSON encoding of an
    instruction (sorted keys, UTF-8) and appends a hash-chained
    ``LedgerEntry``. ``verify_chain`` re-checks every entry.
    """

    GENESIS_HASH = "0" * 64      # prev_hash of the first entry
    ALGORITHM = "ML-DSA-65"      # liboqs signature algorithm name

    def __init__(self):
        self.entries: List[LedgerEntry] = []
        self._key_registry: dict = {}    # public-key hex -> participant id

    @staticmethod
    def _canonical_bytes(instruction: dict) -> bytes:
        """Return the exact byte string that participants sign for *instruction*."""
        return json.dumps(instruction, sort_keys=True).encode("utf-8")

    @classmethod
    def _verify(cls, message: bytes, signature: bytes, public_key: bytes) -> bool:
        """Verify an ML-DSA-65 *signature* over *message* with *public_key*."""
        with oqs.Signature(cls.ALGORITHM) as verifier:
            return verifier.verify(message, signature, public_key)

    @staticmethod
    def _format_settlement(entry_id: int, instruction: dict, participant: str) -> str:
        """Build the one-line SETTLED log message for an appended entry.

        Non-numeric or missing ``quantity``/``amount`` values are rendered
        with ``str``: numeric format specs (``:,`` / ``:,.2f``) raise
        ``ValueError`` on the ``'?'`` placeholder.
        """
        quantity = instruction.get("quantity", "?")
        amount = instruction.get("amount", "?")
        qty_text = f"{quantity:,}" if isinstance(quantity, (int, float)) else str(quantity)
        amt_text = f"{amount:,.2f}" if isinstance(amount, (int, float)) else str(amount)
        return (f"SETTLED  entry #{entry_id} | "
                f"{instruction.get('isin','?')} "
                f"{qty_text} @ "
                f"{amt_text} {instruction.get('currency', 'EUR')} | "
                f"signer: {participant}")

    def register_participant(self, participant_id: str, public_key: bytes) -> None:
        """Register (or re-register) a participant's ML-DSA-65 public key."""
        key_hex = public_key.hex()
        self._key_registry[key_hex] = participant_id
        print(f"Registered: {participant_id} "
              f"(pubkey {key_hex[:16]}…)")

    def submit(self, instruction: dict, signer_pubkey: bytes,
               signature: bytes) -> Optional["LedgerEntry"]:
        """
        Validate and append a signed settlement instruction.

        The signer must be registered and *signature* must be a valid
        ML-DSA-65 signature over the canonical JSON of *instruction*.
        A private deep copy of the instruction is stored, so later changes
        to the caller's dict cannot alter the ledger.

        Returns the new LedgerEntry on success, None on failure.
        """
        key_hex = signer_pubkey.hex()
        if key_hex not in self._key_registry:
            print("REJECTED: unknown signer public key")
            return None

        # Snapshot first, then verify and store exactly what was verified.
        snapshot = copy.deepcopy(instruction)
        if not self._verify(self._canonical_bytes(snapshot), signature, signer_pubkey):
            print("REJECTED: invalid ML-DSA-65 signature")
            return None

        prev_hash = (self.entries[-1].entry_hash
                     if self.entries
                     else self.GENESIS_HASH)

        entry = LedgerEntry(
            entry_id      = len(self.entries),
            timestamp     = time.time(),
            instruction   = snapshot,
            signer_pubkey = signer_pubkey,
            signature     = signature,
            prev_hash     = prev_hash,
        )
        self.entries.append(entry)

        print(self._format_settlement(entry.entry_id, snapshot,
                                      self._key_registry[key_hex]))
        return entry

    def verify_chain(self) -> bool:
        """
        Verify the integrity of the entire ledger:
          - entry ids are sequential and each prev_hash links to the
            previous entry (the chain is unbroken)
          - each entry_hash matches a fresh hash of the entry's fields
          - each ML-DSA-65 signature is valid
        Returns True if every check passes, False at the first failure.
        """
        prev_hash = self.GENESIS_HASH
        for position, entry in enumerate(self.entries):
            if entry.entry_id != position or entry.prev_hash != prev_hash:
                print(f"CHAIN BROKEN at entry #{entry.entry_id}")
                return False

            # Rebuild the entry from its fields: comparing stored hashes
            # alone cannot detect edits to timestamp, entry_id, etc.
            recomputed = LedgerEntry(
                entry_id      = entry.entry_id,
                timestamp     = entry.timestamp,
                instruction   = entry.instruction,
                signer_pubkey = entry.signer_pubkey,
                signature     = entry.signature,
                prev_hash     = entry.prev_hash,
            ).entry_hash
            if recomputed != entry.entry_hash:
                print(f"HASH MISMATCH at entry #{entry.entry_id}")
                return False

            message = self._canonical_bytes(entry.instruction)
            if not self._verify(message, entry.signature, entry.signer_pubkey):
                print(f"INVALID SIG at entry #{entry.entry_id}")
                return False

            prev_hash = entry.entry_hash

        print(f"Chain intact: {len(self.entries)} entries verified")
        return True


# ── Demo ──────────────────────────────────────────────────────────────────────
# Signing with a stored secret key passes it to the oqs.Signature constructor
# (oqs.Signature(alg, secret_key)), which is how liboqs-python loads an
# existing key into a Signature object.

print("=" * 60)
print("SECTION 1: Post-Quantum Settlement Ledger")
print("=" * 60)

with oqs.Signature("ML-DSA-65") as s:
    bnp_pubkey  = s.generate_keypair()
    bnp_privkey = s.export_secret_key()

with oqs.Signature("ML-DSA-65") as s:
    soc_pubkey  = s.generate_keypair()
    soc_privkey = s.export_secret_key()

ledger = QuantumSafeLedger()
ledger.register_participant("BNP_PARIBAS_SECURITIES",   bnp_pubkey)
ledger.register_participant("SOCIETE_GENERALE_GESTION", soc_pubkey)
print()

instr_1 = {
    "isin":             "FR0010242511",
    "quantity":         10_000,
    "currency":         "EUR",
    "amount":           652_400.00,
    "settlement_date":  "2026-01-17",
    "delivering_party": "BNP_PARIBAS_SECURITIES",
    "receiving_party":  "SOCIETE_GENERALE_GESTION",
}
# Participants sign the same canonical encoding the ledger verifies.
msg_1 = json.dumps(instr_1, sort_keys=True).encode("utf-8")
with oqs.Signature("ML-DSA-65", bnp_privkey) as s:
    sig_1 = s.sign(msg_1)
ledger.submit(instr_1, bnp_pubkey, sig_1)

instr_2 = {
    "isin":             "DE0005140008",
    "quantity":         5_000,
    "currency":         "EUR",
    "amount":           318_750.00,
    "settlement_date":  "2026-01-17",
    "delivering_party": "SOCIETE_GENERALE_GESTION",
    "receiving_party":  "BNP_PARIBAS_SECURITIES",
}
msg_2 = json.dumps(instr_2, sort_keys=True).encode("utf-8")
with oqs.Signature("ML-DSA-65", soc_privkey) as s:
    sig_2 = s.sign(msg_2)
ledger.submit(instr_2, soc_pubkey, sig_2)

print()
ledger.verify_chain()


# =============================================================================
# SECTION 2 — Performance: ECDSA-256 vs ML-DSA-65
# =============================================================================
# Both loops time signing only. Key generation, algorithm-object setup and
# loading the ML-DSA secret key all happen before the clock starts.

print()
print("=" * 60)
print("SECTION 2: Performance Benchmark")
print("=" * 60)

N       = 1_000
message = b"settlement instruction payload " * 8  # 31 B × 8 = 248 bytes

ecdsa_key  = ec.generate_private_key(ec.SECP256R1())
ecdsa_algo = ec.ECDSA(hashes.SHA256())
t0 = time.perf_counter()
for _ in range(N):
    ecdsa_key.sign(message, ecdsa_algo)
ecdsa_time = time.perf_counter() - t0
print(f"ECDSA-256:   {N/ecdsa_time:>8,.0f} sign/s  ({ecdsa_time/N*1000:.3f} ms/op)")

with oqs.Signature("ML-DSA-65") as signer:
    pub  = signer.generate_keypair()
    priv = signer.export_secret_key()

# liboqs-python loads an existing secret key through the constructor.
with oqs.Signature("ML-DSA-65", priv) as signer:
    t0 = time.perf_counter()
    for _ in range(N):
        signer.sign(message)
    mldsa_time = time.perf_counter() - t0
mldsa_rate = N / mldsa_time
print(f"ML-DSA-65:   {mldsa_rate:>8,.0f} sign/s  ({mldsa_time/N*1000:.3f} ms/op)")
print(f"Slowdown:    ×{mldsa_time/ecdsa_time:.2f}")
print()

# Compare measured single-signer throughput with a busy settlement load.
required_rate = 100_000 / 3600   # 100,000 instructions/hour ≈ 27.8 per second
print("A settlement system at 100,000 instructions/hour (~28/s)")
print(f"needs ~{required_rate / mldsa_rate:.2%} of one ML-DSA-65 signer's "
      f"throughput ({mldsa_rate / required_rate:,.0f}× headroom).")


# =============================================================================
# SECTION 3 — SLH-DSA for Archival Records
# =============================================================================

print()
print("=" * 60)
print("SECTION 3: SLH-DSA-128f for Archival Records  (FIPS 205)")
print("=" * 60)

# liboqs name for the SPHINCS+ "SHA2-128f-simple" parameter set, the scheme
# that FIPS 205 standardised as SLH-DSA-SHA2-128f. FIPS 205 made small
# changes to SPHINCS+, so these signatures should not be assumed to
# interoperate with a FIPS 205 SLH-DSA implementation — confirm against the
# installed liboqs version before relying on FIPS 205 compliance.
SLH_ALG = "SPHINCS+-SHA2-128f-simple"

with oqs.Signature(SLH_ALG) as s:
    pub  = s.generate_keypair()
    priv = s.export_secret_key()

print(f"SLH-DSA-128f public key:  {len(pub)} bytes")    # 32 bytes
print(f"SLH-DSA-128f private key: {len(priv)} bytes")   # 64 bytes

record = b"END_OF_DAY_RECORD|DATE=20260115|TOTAL_SETTLED=EUR 12.4B|..."

# The secret key is loaded via the constructor; only the sign call is timed.
with oqs.Signature(SLH_ALG, priv) as s:
    t0 = time.perf_counter()
    sig = s.sign(record)
    elapsed = (time.perf_counter() - t0) * 1000

print(f"Signature size:   {len(sig):,} bytes")    # 17,088 bytes
print(f"Signing time:     {elapsed:.1f} ms")      # hardware-dependent

with oqs.Signature(SLH_ALG) as v:
    valid = v.verify(record, sig, pub)
print(f"Signature valid:  {valid}")

print()
print("SLH-DSA security relies only on SHA-256 — no lattice assumptions.")
print("Ideal for regulatory archives with 10–20 year retention horizons.")


# =============================================================================
# SECTION 4 — Cryptographic Agility
# =============================================================================

# Section banner: a blank line, then the title framed by two 60-char rules.
print("\n" + "=" * 60 + "\nSECTION 4: Cryptographic Agility Pattern\n" + "=" * 60)


class SigningBackend(ABC):
    """Algorithm-agnostic signing interface.

    Concrete backends supply key generation, signing, verification and a
    printable algorithm identifier, so callers can be written against this
    interface instead of a specific signature scheme.
    """

    @abstractmethod
    def generate_keypair(self) -> tuple:
        """Create a key pair; return ``(public_key_bytes, private_key_bytes)``."""
        ...

    @abstractmethod
    def sign(self, message: bytes, private_key: bytes) -> bytes:
        """Sign *message* with *private_key*; return the raw signature bytes."""
        ...

    @abstractmethod
    def verify(self, message: bytes, signature: bytes, public_key: bytes) -> bool:
        """Check *signature* over *message* against *public_key*; True if valid."""
        ...

    @property
    @abstractmethod
    def algorithm_id(self) -> str:
        """Human-readable name of the algorithm this backend implements."""
        ...


class MLDSABackend(SigningBackend):
    """Post-quantum signing using ML-DSA (liboqs / FIPS 204).

    ``param_set`` is passed unchanged to ``oqs.Signature`` (e.g. "ML-DSA-44",
    "ML-DSA-65", "ML-DSA-87"). A fresh liboqs object is created per call, so
    no key material is held by the backend itself.
    """
    def __init__(self, param_set: str = "ML-DSA-65"):
        self._alg = param_set

    def generate_keypair(self) -> tuple:
        """Return a new ``(public_key, secret_key)`` pair as bytes."""
        with oqs.Signature(self._alg) as s:
            pub = s.generate_keypair()
            priv = s.export_secret_key()
        return pub, priv

    def sign(self, message: bytes, private_key: bytes) -> bytes:
        """Sign *message* with *private_key* and return the signature bytes."""
        # liboqs-python loads an existing secret key through the constructor.
        with oqs.Signature(self._alg, private_key) as s:
            return s.sign(message)

    def verify(self, message: bytes, signature: bytes, public_key: bytes) -> bool:
        """Return True if *signature* over *message* verifies under *public_key*."""
        with oqs.Signature(self._alg) as v:
            return v.verify(message, signature, public_key)

    @property
    def algorithm_id(self) -> str:
        """The liboqs algorithm name this backend was configured with."""
        return self._alg


# Algorithm registry: maps a signing purpose to the backend that serves it.
# It is hard-coded here for the example; a deployment would typically build
# it from configuration at startup. Changing the algorithm for a purpose is
# a one-line change here, not in application code.
ALGORITHM_REGISTRY = {
    "transaction_signing": MLDSABackend("ML-DSA-65"),
    "certificate_signing": MLDSABackend("ML-DSA-87"),
    "archival_sealing":    MLDSABackend("ML-DSA-65"),
    # Future: migrate archival sealing to SLH-DSA. This requires an
    # SLHDSABackend (not defined in this file), e.g. a SigningBackend that
    # wraps oqs.Signature with a SPHINCS+/SLH-DSA parameter set:
    # "archival_sealing":  SLHDSABackend("SPHINCS+-SHA2-256s-simple"),
}


def sign_settlement_instruction(instruction_bytes: bytes,
                                private_key: bytes) -> dict:
    """Sign serialized instruction bytes with the transaction-signing backend.

    Returns a dict holding the backend's ``algorithm`` identifier and the
    signature as a hex string under ``signature``.
    """
    signer = ALGORITHM_REGISTRY["transaction_signing"]
    signature_hex = signer.sign(instruction_bytes, private_key).hex()
    return dict(algorithm=signer.algorithm_id, signature=signature_hex)


# Demo: round-trip a sample instruction through the registry-selected backend.
pub, priv = ALGORITHM_REGISTRY["transaction_signing"].generate_keypair()
test_msg = b'{"isin":"FR0010242511","qty":10000}'
result = sign_settlement_instruction(test_msg, priv)

valid = ALGORITHM_REGISTRY["transaction_signing"].verify(
    test_msg, bytes.fromhex(result["signature"]), pub)

print("\n".join((
    f"Algorithm in use:  {result['algorithm']}",
    f"Signature bytes:   {len(result['signature']) // 2:,}",
    f"Verification:      {valid}",
    "",
    "Upgrading from ML-DSA-65 to ML-DSA-87 requires one registry change.",
    "No application code is modified. This is cryptographic agility.",
)))
