"""
ch15_pq_hybrid_tls.py
Vol II · Chapter 15 — Hybrid X25519 + ML-KEM-768 Key Exchange
==============================================================
Simulates the X25519MLKEM768 hybrid key exchange used in TLS 1.3
to protect the settlement message bus during the PQC transition.

The hybrid construction provides:
  - Classical security (X25519) against today's adversaries
  - Post-quantum security (ML-KEM-768) against a CRQC
  An attacker must break both to compromise the combined key.

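The shared secret is derived by concatenating both component secrets
and passing them through HKDF-SHA256 (RFC 5869) — the concatenation
combiner described in the IETF draft-ietf-tls-hybrid-design. RFC 9180
(HPKE) builds on the same HKDF primitives but does not itself define
this hybrid combiner.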

Requirements:
    pip install liboqs-python cryptography
"""

import hashlib
import hmac
import os
import struct
import time

import oqs
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.x25519 import (
    X25519PrivateKey,
    X25519PublicKey,
)
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF

KEM_ALG = "ML-KEM-768"


# ── HKDF helper ───────────────────────────────────────────────────────────────

def hkdf_extract_expand(
    ikm: bytes,
    info: bytes = b"tls13 hybrid settlement",
    length: int = 32
) -> bytes:
    """
    Run HKDF-SHA256 over ``ikm``: one extract step followed by one expand step.

    Args:
        ikm:    input keying material (in this file: the concatenated
                component shared secrets).
        info:   context label bound into the expand step.
        length: number of output bytes requested.

    Returns:
        ``length`` bytes of derived key material. No salt is supplied
        (``salt=None``).
    """
    kdf = HKDF(
        algorithm=hashes.SHA256(),
        length=length,
        salt=None,
        info=info,
    )
    okm = kdf.derive(ikm)
    return okm


# ── Hybrid key exchange ───────────────────────────────────────────────────────

class HybridKeyExchange:
    """
    Simulates one side of a hybrid X25519 + ML-KEM-768 key exchange.
    Naming follows TLS 1.3 terminology: 'server' (CSD) and 'client' (participant).

    Message flow (one instance per side):
        1. server.server_generate_key_share()               -> server_share
        2. client.client_process_server_share(server_share) -> client_share
        3. server.server_process_client_share(client_share)
    After steps 2 and 3 each side holds the same 32-byte ``shared_key``.

    Simplification: in the real X25519MLKEM768 TLS 1.3 group the client
    generates the ML-KEM key pair (ClientHello key_share) and the server
    encapsulates to it (ServerHello key_share). Here the roles are reversed:
    the server publishes the ML-KEM public key and the client returns the
    ciphertext. The combined key is derived identically on both sides.
    """

    def __init__(self, role: str):
        # Explicit raise rather than `assert`: asserts are stripped under -O.
        if role not in ("server", "client"):
            raise ValueError("role must be 'server' or 'client'")
        self.role = role
        # Handshake state. `shared_key` is set by whichever method completes
        # the handshake for this side.
        self._x25519_priv = None   # this side's ephemeral X25519 private key
        self._kem = None           # server only: liboqs KEM object holding the secret key

    @staticmethod
    def _combine_secrets(mlkem_secret: bytes, x25519_secret: bytes) -> bytes:
        """
        Derive the 32-byte session key from the two component secrets.

        X25519MLKEM768 concatenates the ML-KEM-768 shared secret first and
        the X25519 shared secret second; the result is fed to HKDF-SHA256.
        Both sides call this helper so the derivation cannot diverge.
        """
        return hkdf_extract_expand(mlkem_secret + x25519_secret)

    # ── Server: generate key share ────────────────────────────────────────────

    def server_generate_key_share(self) -> dict:
        """
        Server (CSD) generates:
          - X25519 key pair
          - ML-KEM-768 key pair
        Returns the public key share to send in ServerHello:
        ``{"x25519_public": bytes, "mlkem_public": bytes}``.

        The ML-KEM secret key stays inside ``self._kem`` until
        ``server_process_client_share`` uses it for decapsulation.
        """
        # X25519
        self._x25519_priv = X25519PrivateKey.generate()
        x25519_pub = self._x25519_priv.public_key().public_bytes_raw()

        # ML-KEM-768. If a key pair was already generated on this instance,
        # release its native resources before replacing it.
        if self._kem is not None:
            self._kem.free()
        self._kem = oqs.KeyEncapsulation(KEM_ALG)
        mlkem_pub = self._kem.generate_keypair()

        self._server_key_share = {
            "x25519_public":  x25519_pub,
            "mlkem_public":   mlkem_pub,
        }
        return self._server_key_share

    # ── Client: process server share and generate client share ────────────────

    def client_process_server_share(self, server_share: dict) -> dict:
        """
        Client (participant) receives server's key share and:
          - Performs X25519 key agreement
          - Encapsulates ML-KEM-768 shared secret
        Sets ``self.shared_key`` and returns the client key share
        ``{"x25519_public": bytes, "mlkem_ct": bytes}`` for the server.
        """
        # X25519: client generates an ephemeral key and computes the shared secret
        self._x25519_priv = X25519PrivateKey.generate()
        x25519_pub = self._x25519_priv.public_key().public_bytes_raw()
        server_x25519_pub = X25519PublicKey.from_public_bytes(server_share["x25519_public"])
        x25519_secret = self._x25519_priv.exchange(server_x25519_pub)

        # ML-KEM-768: encapsulate to the server's public key. The client needs
        # no KEM key pair of its own; the with-block frees the native object.
        with oqs.KeyEncapsulation(KEM_ALG) as kem:
            mlkem_ct, mlkem_secret = kem.encap_secret(server_share["mlkem_public"])

        self.shared_key = self._combine_secrets(mlkem_secret, x25519_secret)

        return {
            "x25519_public": x25519_pub,
            "mlkem_ct":      mlkem_ct,
        }

    # ── Server: process client share ─────────────────────────────────────────

    def server_process_client_share(self, client_share: dict) -> None:
        """
        Server (CSD) receives client's key share:
          - Completes X25519 key agreement
          - Decapsulates ML-KEM-768 shared secret
        Derives the same combined key into ``self.shared_key``.

        Raises:
            RuntimeError: if ``server_generate_key_share`` has not been called
                on this instance (or this method already consumed the key pair).
        """
        if self._kem is None:
            raise RuntimeError(
                "server_generate_key_share() must be called before "
                "server_process_client_share()"
            )

        client_x25519_pub = X25519PublicKey.from_public_bytes(client_share["x25519_public"])
        x25519_secret = self._x25519_priv.exchange(client_x25519_pub)

        # ML-KEM-768: decapsulate with the same KEM object that generated the
        # key pair (it still holds the secret key). Detach it from the instance
        # first so it is never reused after the with-block frees it.
        kem, self._kem = self._kem, None
        with kem:
            mlkem_secret = kem.decap_secret(client_share["mlkem_ct"])

        self.shared_key = self._combine_secrets(mlkem_secret, x25519_secret)


# ── Demo: full handshake ─────────────────────────────────────────────────────

print("=" * 60)
print("Hybrid X25519 + ML-KEM-768 Key Exchange")
print("(simulating TLS 1.3 ClientKeyShare / ServerHello)")
print("=" * 60)
print()

server = HybridKeyExchange("server")
client = HybridKeyExchange("client")

# 1. Server generates key share (would appear in ServerHello)
server_share = server.server_generate_key_share()
print("Server key share (ServerHello):")
print(f"  X25519 public key  :     {len(server_share['x25519_public'])} bytes")
print(f"  ML-KEM-768 public  : {len(server_share['mlkem_public']):,} bytes")
total_server = len(server_share['x25519_public']) + len(server_share['mlkem_public'])
print(f"  Total              : {total_server:,} bytes")
print()

# 2. Client processes server share and returns client share. The demo labels
#    it "ClientKeyShare"; in TLS 1.3 it would travel in a key_share extension.
client_share = client.client_process_server_share(server_share)
print("Client key share (ClientKeyShare extension):")
print(f"  X25519 public key  :     {len(client_share['x25519_public'])} bytes")
print(f"  ML-KEM-768 CT      : {len(client_share['mlkem_ct']):,} bytes")
total_client = len(client_share['x25519_public']) + len(client_share['mlkem_ct'])
print(f"  Total              : {total_client:,} bytes")
print()

# 3. Server processes client share
server.server_process_client_share(client_share)

# 4. Verify both sides hold the same key. Use an explicit check instead of
#    `assert` so a mismatch is still caught under `python -O`.
#    hmac.compare_digest compares the key bytes in constant time.
keys_match = hmac.compare_digest(server.shared_key, client.shared_key)
if not keys_match:
    raise RuntimeError("Key mismatch — handshake failed")
print(f"Shared key ({len(server.shared_key)} bytes): {server.shared_key.hex()}")
print(f"Keys match:            {keys_match}  — secure channel established")
print()


# ── Wire-size impact ──────────────────────────────────────────────────────────

# Section banner: rule, title, rule, then a blank line.
print("=" * 60, "TLS 1.3 Handshake Size Impact", "=" * 60, "", sep="\n")

# Baseline: one 32-byte X25519 public key from each side.
print("Classical TLS 1.3 (X25519 only):")
x25519_hs = 2 * 32
print(f"  Key exchange data: {x25519_hs} bytes", end="\n\n")

# Hybrid: measured sizes of both key shares from the handshake above.
print("Hybrid TLS 1.3 (X25519MLKEM768):")
hybrid_hs = total_client + total_server
print(f"  Key exchange data: {hybrid_hs:,} bytes")
print(
    f"  Overhead:         +{hybrid_hs - x25519_hs:,} bytes  "
    f"(+{(hybrid_hs/x25519_hs - 1)*100:.0f}%)",
    end="\n\n",
)

# Scale the per-handshake overhead to a daily connection volume.
print("At 100,000 connections/day on the settlement bus:")
daily_overhead = 100_000 * (hybrid_hs - x25519_hs)
print(f"  Additional data: {daily_overhead/1e6:.1f} MB/day — negligible.", end="\n\n")


# ── Encrypt a settlement message with the derived key ────────────────────────

print("=" * 60)
print("Settlement Message Encryption (AES-256-GCM)")
print("=" * 60)
print()

# AES-GCM requires a nonce that is never reused under the same key; a repeat
# breaks both confidentiality and authenticity. Draw a fresh random 96-bit
# nonce instead of a fixed value so this pattern is safe to copy.
nonce   = os.urandom(12)
aes_key = server.shared_key   # 32 bytes = AES-256
aesgcm  = AESGCM(aes_key)
aad     = b"settlement-bus-v1"   # associated data: authenticated but not encrypted

plaintext = b'{"isin":"FR0010242511","qty":10000,"amount":652400.0,"ccy":"EUR"}'
ciphertext_msg = aesgcm.encrypt(nonce, plaintext, aad)

print(f"Plaintext ({len(plaintext)} B):   {plaintext.decode()}")
print(f"Ciphertext ({len(ciphertext_msg)} B):  {ciphertext_msg.hex()[:48]}…")

# Decrypt (raises if the ciphertext, nonce or AAD was tampered with)
recovered = aesgcm.decrypt(nonce, ciphertext_msg, aad)
print(f"Decrypted:              {recovered.decode()}")
print(f"Round-trip OK:          {recovered == plaintext}")
