"""Cryptographic utilities: Ed25519, X25519, AES-256-GCM, Double Ratchet, Sender Keys. RSA functions retained for login challenge-response only. """ import hashlib import hmac import json import os import struct import uuid from dataclasses import dataclass, field from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding, rsa from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC # --------------------------------------------------------------------------- # Password-based key encryption (M3: PBKDF2 600k iterations + AES-256-GCM) # --------------------------------------------------------------------------- PBKDF2_ITERATIONS = 600_000 _ECP1_MAGIC = b"ECP1" # Encrypted Chat PBKDF v1 format marker def _encrypt_private_key(raw_bytes: bytes, password: bytes) -> bytes: """Encrypt raw key bytes with PBKDF2-HMAC-SHA256 (600k iterations) + AES-256-GCM. 
Output format: MAGIC(4) + salt(16) + nonce(12) + ciphertext_with_tag(N+16) """ salt = os.urandom(16) kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, salt=salt, iterations=PBKDF2_ITERATIONS) derived = kdf.derive(password) nonce = os.urandom(12) aesgcm = AESGCM(derived) ct = aesgcm.encrypt(nonce, raw_bytes, _ECP1_MAGIC) # AAD = magic bytes return _ECP1_MAGIC + salt + nonce + ct def _decrypt_private_key(data: bytes, password: bytes) -> bytes: """Decrypt key bytes encrypted with _encrypt_private_key.""" if not data.startswith(_ECP1_MAGIC): raise ValueError("Not ECP1 format") salt = data[4:20] nonce = data[20:32] ct = data[32:] kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, salt=salt, iterations=PBKDF2_ITERATIONS) derived = kdf.derive(password) aesgcm = AESGCM(derived) return aesgcm.decrypt(nonce, ct, _ECP1_MAGIC) # --------------------------------------------------------------------------- # RSA (login challenge-response ONLY) # --------------------------------------------------------------------------- def generate_rsa_keypair(key_size: int = 4096) -> tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) return private_key, private_key.public_key() def serialize_private_key(key: rsa.RSAPrivateKey, password: bytes | None = None) -> bytes: if password: raw = key.private_bytes(serialization.Encoding.DER, serialization.PrivateFormat.PKCS8, serialization.NoEncryption()) return _encrypt_private_key(raw, password) return key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption()) def serialize_public_key(key: rsa.RSAPublicKey) -> bytes: return key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo) def load_private_key(data: bytes, password: bytes | None = None) -> rsa.RSAPrivateKey: if data.startswith(_ECP1_MAGIC): raw = _decrypt_private_key(data, password) return 
serialization.load_der_private_key(raw, password=None) # Legacy PEM format (old BestAvailableEncryption or unencrypted) return serialization.load_pem_private_key(data, password=password) def load_public_key(pem: bytes) -> rsa.RSAPublicKey: return serialization.load_pem_public_key(pem) def compute_pairing_fingerprint(public_key_data: bytes | str) -> str: """Format a temporary pairing public key as a human-verifiable fingerprint.""" if isinstance(public_key_data, str): key_bytes = public_key_data.encode("utf-8") else: key_bytes = public_key_data canonical = key_bytes.replace(b"\r\n", b"\n").strip() if b"-----BEGIN" in key_bytes else key_bytes digest = hashlib.sha256(b"EncryptedChat_PairingKey_v1\x00" + canonical).digest() return format_fingerprint(digest) def normalize_pairing_fingerprint(value: str) -> str: """Normalize user-entered pairing fingerprints for comparison.""" return "".join(ch for ch in value if ch.isdigit()) def encode_pairing_qr(code: str, fingerprint: str) -> bytes: """Encode pairing code + fingerprint for QR transport. Format: magic(5='PAIR1') + code(8 ASCII digits) + fingerprint(30 ASCII digits) """ code_digits = "".join(ch for ch in code if ch.isdigit()) fp_digits = normalize_pairing_fingerprint(fingerprint) if len(code_digits) != 8: raise ValueError("Pairing code must contain 8 digits") if len(fp_digits) != 30: raise ValueError("Pairing fingerprint must contain 30 digits") return b"PAIR1" + code_digits.encode("ascii") + fp_digits.encode("ascii") def decode_pairing_qr(data: bytes) -> tuple[str, str]: """Decode pairing QR payload. 
Returns (code, formatted_fingerprint).""" if len(data) != 43 or not data.startswith(b"PAIR1"): raise ValueError("Invalid pairing QR payload") code = data[5:13].decode("ascii") fp_digits = data[13:43].decode("ascii") if not code.isdigit() or not fp_digits.isdigit(): raise ValueError("Invalid pairing QR payload") groups = [fp_digits[i:i + 5] for i in range(0, 30, 5)] return code, " ".join(groups[:3]) + "\n" + " ".join(groups[3:]) def rsa_sign(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: return private_key.sign( data, padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), hashes.SHA256(), ) def rsa_verify(public_key: rsa.RSAPublicKey, signature: bytes, data: bytes) -> bool: try: public_key.verify( signature, data, padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.AUTO), hashes.SHA256(), ) return True except Exception: return False # --------------------------------------------------------------------------- # AES-256-GCM (symmetric encryption — used by ratchet message keys & images) # --------------------------------------------------------------------------- def aes_encrypt(plaintext: bytes, key: bytes | None = None) -> tuple[bytes, bytes, bytes, bytes]: """Encrypt with AES-256-GCM. 
Returns (key, nonce, ciphertext, tag).""" if key is None: key = AESGCM.generate_key(bit_length=256) nonce = os.urandom(12) aesgcm = AESGCM(key) ct_with_tag = aesgcm.encrypt(nonce, plaintext, None) ciphertext = ct_with_tag[:-16] tag = ct_with_tag[-16:] return key, nonce, ciphertext, tag def aes_decrypt(key: bytes, nonce: bytes, ciphertext: bytes, tag: bytes) -> bytes: """Decrypt with AES-256-GCM.""" aesgcm = AESGCM(key) return aesgcm.decrypt(nonce, ciphertext + tag, None) # --------------------------------------------------------------------------- # Ed25519 Identity Keys # --------------------------------------------------------------------------- def generate_identity_keypair() -> tuple[Ed25519PrivateKey, Ed25519PublicKey]: priv = Ed25519PrivateKey.generate() return priv, priv.public_key() def serialize_ed25519_private(key: Ed25519PrivateKey, password: bytes | None = None) -> bytes: if password: raw = serialize_ed25519_private_raw(key) # 32 bytes return _encrypt_private_key(raw, password) return serialize_ed25519_private_raw(key) # 32 bytes, no password def serialize_ed25519_private_raw(key: Ed25519PrivateKey) -> bytes: """Serialize Ed25519 private key to 32 raw bytes (unencrypted).""" return key.private_bytes(serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()) def serialize_ed25519_public(key: Ed25519PublicKey) -> bytes: """Serialize Ed25519 public key to 32 raw bytes.""" return key.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) def load_ed25519_private(data: bytes, password: bytes | None = None) -> Ed25519PrivateKey: if data.startswith(_ECP1_MAGIC): raw = _decrypt_private_key(data, password) return Ed25519PrivateKey.from_private_bytes(raw) # Legacy formats: PEM (old BestAvailableEncryption) or 32-byte raw if password: return serialization.load_pem_private_key(data, password=password) if len(data) == 32: return Ed25519PrivateKey.from_private_bytes(data) return 
serialization.load_pem_private_key(data, password=None) def load_ed25519_public(data: bytes) -> Ed25519PublicKey: if len(data) == 32: return Ed25519PublicKey.from_public_bytes(data) return serialization.load_pem_public_key(data) def ed25519_sign(private_key: Ed25519PrivateKey, data: bytes) -> bytes: """Sign data with Ed25519. Returns 64-byte signature.""" return private_key.sign(data) def ed25519_verify(public_key: Ed25519PublicKey, signature: bytes, data: bytes) -> bool: """Verify Ed25519 signature.""" try: public_key.verify(signature, data) return True except Exception: return False # --------------------------------------------------------------------------- # X25519 Key Exchange # --------------------------------------------------------------------------- def generate_x25519_keypair() -> tuple[X25519PrivateKey, X25519PublicKey]: priv = X25519PrivateKey.generate() return priv, priv.public_key() def serialize_x25519_private(key: X25519PrivateKey) -> bytes: """Serialize X25519 private key to 32 raw bytes.""" return key.private_bytes(serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()) def serialize_x25519_public(key: X25519PublicKey) -> bytes: """Serialize X25519 public key to 32 raw bytes.""" return key.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) def load_x25519_private(data: bytes) -> X25519PrivateKey: return X25519PrivateKey.from_private_bytes(data) def load_x25519_public(data: bytes) -> X25519PublicKey: return X25519PublicKey.from_public_bytes(data) def x25519_dh(private_key: X25519PrivateKey, public_key: X25519PublicKey) -> bytes: """Perform X25519 Diffie-Hellman. Returns 32-byte shared secret.""" return private_key.exchange(public_key) def derive_pairing_shared_key(shared_secret: bytes, public_key_a: bytes, public_key_b: bytes) -> bytes: """Derive a symmetric bootstrap key for device pairing. 
The key derivation is direction-agnostic: both peers sort the two public keys lexicographically before binding them into HKDF salt. """ pub1, pub2 = sorted((public_key_a, public_key_b)) salt = hashlib.sha256(b"EncryptedChat_PairingSalt_v1\x00" + pub1 + pub2).digest() return hkdf_derive(shared_secret, salt=salt, info=b"EncryptedChat_PairingBootstrap", length=32) # --------------------------------------------------------------------------- # Ed25519 <-> X25519 conversion (for Identity Key dual use) # --------------------------------------------------------------------------- def ed25519_private_to_x25519(ed_private: Ed25519PrivateKey) -> X25519PrivateKey: """Derive X25519 private key from Ed25519 private key via RFC 7748 clamping.""" raw = ed_private.private_bytes( serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() ) # SHA-512 hash of the seed, take first 32 bytes, clamp per RFC 7748 h = hashlib.sha512(raw).digest()[:32] clamped = bytearray(h) clamped[0] &= 248 clamped[31] &= 127 clamped[31] |= 64 return X25519PrivateKey.from_private_bytes(bytes(clamped)) def ed25519_public_to_x25519(ed_public: Ed25519PublicKey) -> X25519PublicKey: """Derive X25519 public key from Ed25519 public key. Uses the cryptography library's internal conversion. For production use, we compute the X25519 public key from the converted private key when possible. For remote keys (where we don't have the private key), we use a pure-Python implementation of the Ed25519->X25519 point conversion. 
""" # Montgomery u = (1 + y) / (1 - y) mod p, where p = 2^255 - 19 raw = ed_public.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) y = int.from_bytes(raw, "little") # Clear the sign bit y &= (1 << 255) - 1 p = (1 << 255) - 19 # u = (1 + y) * inverse(1 - y) mod p one_plus_y = (1 + y) % p one_minus_y = (1 - y) % p inv = pow(one_minus_y, p - 2, p) u = (one_plus_y * inv) % p x25519_bytes = u.to_bytes(32, "little") return X25519PublicKey.from_public_bytes(x25519_bytes) # --------------------------------------------------------------------------- # HKDF # --------------------------------------------------------------------------- _HKDF_INFO_SELF = b"EncryptedChat_SelfKey" _HKDF_INFO_RK = b"EncryptedChat_RootKey" def derive_self_encryption_key(identity_private: Ed25519PrivateKey) -> bytes: """Derive a static AES-256 key from identity key for encrypting own sent messages. This is NOT a ratchet — it's a static key. Safe because only the owner has the identity private key, and self-copies don't need forward secrecy. 
""" raw = identity_private.private_bytes( serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() ) return hkdf_derive(raw, salt=b"self_encryption", info=_HKDF_INFO_SELF, length=32) _HKDF_INFO_LOCAL = b"EncryptedChat_LocalStorage" def derive_local_storage_key(identity_private: Ed25519PrivateKey) -> bytes: """Derive AES-256 key for encrypting local session/sender key files.""" raw = identity_private.private_bytes( serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() ) return hkdf_derive(raw, salt=b"local_storage", info=_HKDF_INFO_LOCAL, length=32) _HKDF_INFO_CK_MSG = b"\x01" # chain key -> message key _HKDF_INFO_CK_NEXT = b"\x02" # chain key -> next chain key def hkdf_derive(input_key: bytes, salt: bytes, info: bytes, length: int = 32) -> bytes: return HKDF(algorithm=hashes.SHA256(), length=length, salt=salt, info=info).derive(input_key) def kdf_rk(root_key: bytes, dh_output: bytes) -> tuple[bytes, bytes]: """Root key KDF. Returns (new_root_key, chain_key). Uses HKDF with the root key as salt and DH output as input key material. Derives 64 bytes: first 32 = new root key, last 32 = chain key. """ derived = hkdf_derive(dh_output, salt=root_key, info=_HKDF_INFO_RK, length=64) return derived[:32], derived[32:] def kdf_ck(chain_key: bytes) -> tuple[bytes, bytes]: """Chain key KDF. Returns (new_chain_key, message_key). 
Uses HMAC-SHA256: message_key = HMAC(chain_key, 0x01) new_chain_key = HMAC(chain_key, 0x02) """ message_key = hmac.new(chain_key, _HKDF_INFO_CK_MSG, hashlib.sha256).digest() new_chain_key = hmac.new(chain_key, _HKDF_INFO_CK_NEXT, hashlib.sha256).digest() return new_chain_key, message_key # --------------------------------------------------------------------------- # X3DH # --------------------------------------------------------------------------- _X3DH_INFO = b"EncryptedChat_X3DH" def generate_signed_prekey(identity_private: Ed25519PrivateKey) -> dict: """Generate a signed pre-key (SPK). Returns {private: X25519PrivateKey, public: X25519PublicKey, signature: bytes, id: str}. """ spk_priv, spk_pub = generate_x25519_keypair() spk_pub_bytes = serialize_x25519_public(spk_pub) signature = ed25519_sign(identity_private, spk_pub_bytes) return { "private": spk_priv, "public": spk_pub, "signature": signature, "id": str(uuid.uuid4()), } def generate_one_time_prekeys(count: int = 50) -> list[dict]: """Generate a batch of one-time pre-keys. Returns [{private: X25519PrivateKey, public: X25519PublicKey, id: str}, ...]. """ result = [] for _ in range(count): priv, pub = generate_x25519_keypair() result.append({"private": priv, "public": pub, "id": str(uuid.uuid4())}) return result def x3dh_initiate( ik_private_ed: Ed25519PrivateKey, ik_public_remote_ed: Ed25519PublicKey, spk_remote: X25519PublicKey, spk_signature: bytes, opk_remote: X25519PublicKey | None = None, ) -> tuple[bytes, X25519PrivateKey, X25519PublicKey]: """Initiator side of X3DH. 
Args: ik_private_ed: Our Ed25519 identity private key ik_public_remote_ed: Remote Ed25519 identity public key spk_remote: Remote signed pre-key (X25519 public) spk_signature: Ed25519 signature of spk_remote by ik_public_remote_ed opk_remote: Optional one-time pre-key (X25519 public) Returns: (shared_secret, ephemeral_private, ephemeral_public) """ # Verify SPK signature spk_remote_bytes = serialize_x25519_public(spk_remote) if not ed25519_verify(ik_public_remote_ed, spk_signature, spk_remote_bytes): raise ValueError("Invalid SPK signature") # Convert identity keys to X25519 ik_x25519_private = ed25519_private_to_x25519(ik_private_ed) ik_x25519_remote = ed25519_public_to_x25519(ik_public_remote_ed) # Generate ephemeral keypair ek_priv, ek_pub = generate_x25519_keypair() # DH computations dh1 = x25519_dh(ik_x25519_private, spk_remote) # IK_A, SPK_B dh2 = x25519_dh(ek_priv, ik_x25519_remote) # EK_A, IK_B dh3 = x25519_dh(ek_priv, spk_remote) # EK_A, SPK_B dh_concat = dh1 + dh2 + dh3 if opk_remote is not None: dh4 = x25519_dh(ek_priv, opk_remote) # EK_A, OPK_B dh_concat += dh4 # Derive shared secret shared_secret = hkdf_derive(dh_concat, salt=b"\x00" * 32, info=_X3DH_INFO, length=32) return shared_secret, ek_priv, ek_pub def x3dh_respond( ik_private_ed: Ed25519PrivateKey, spk_private: X25519PrivateKey, ik_remote_ed: Ed25519PublicKey, ek_remote: X25519PublicKey, opk_private: X25519PrivateKey | None = None, ) -> bytes: """Responder side of X3DH. 
Args: ik_private_ed: Our Ed25519 identity private key spk_private: Our signed pre-key private (X25519) ik_remote_ed: Remote Ed25519 identity public key ek_remote: Remote ephemeral key (X25519 public) opk_private: Our one-time pre-key private (X25519), if used Returns: shared_secret (32 bytes) """ ik_x25519_private = ed25519_private_to_x25519(ik_private_ed) ik_x25519_remote = ed25519_public_to_x25519(ik_remote_ed) dh1 = x25519_dh(spk_private, ik_x25519_remote) # SPK_B, IK_A dh2 = x25519_dh(ik_x25519_private, ek_remote) # IK_B, EK_A dh3 = x25519_dh(spk_private, ek_remote) # SPK_B, EK_A dh_concat = dh1 + dh2 + dh3 if opk_private is not None: dh4 = x25519_dh(opk_private, ek_remote) # OPK_B, EK_A dh_concat += dh4 shared_secret = hkdf_derive(dh_concat, salt=b"\x00" * 32, info=_X3DH_INFO, length=32) return shared_secret # --------------------------------------------------------------------------- # Double Ratchet # --------------------------------------------------------------------------- MAX_SKIP = 256 # max messages to skip in a single chain (out-of-order tolerance) @dataclass class RatchetHeader: """Header sent with each ratchet message.""" dh_pub: bytes # sender's current ratchet public key (32 bytes) n: int # message number in current sending chain pn: int # number of messages in previous sending chain def serialize(self) -> bytes: return json.dumps({ "dh_pub": serialize_x25519_public(load_x25519_public(self.dh_pub)).hex() if isinstance(self.dh_pub, bytes) else serialize_x25519_public(self.dh_pub).hex(), "n": self.n, "pn": self.pn, }).encode() def to_dict(self) -> dict: pub_hex = self.dh_pub.hex() if isinstance(self.dh_pub, bytes) else \ serialize_x25519_public(self.dh_pub).hex() return {"dh_pub": pub_hex, "n": self.n, "pn": self.pn} @classmethod def from_dict(cls, d: dict) -> "RatchetHeader": return cls(dh_pub=bytes.fromhex(d["dh_pub"]), n=d["n"], pn=d["pn"]) class DoubleRatchet: """Signal Double Ratchet implementation.""" def __init__(self): self.dh_pair: 
tuple[X25519PrivateKey, X25519PublicKey] | None = None self.dh_remote: X25519PublicKey | None = None self.root_key: bytes = b"" self.send_chain_key: bytes | None = None self.recv_chain_key: bytes | None = None self.send_n: int = 0 self.recv_n: int = 0 self.prev_send_n: int = 0 # (dh_pub_hex, n) -> message_key for out-of-order messages self.skipped: dict[tuple[str, int], bytes] = {} @classmethod def init_alice(cls, shared_secret: bytes, bob_spk_pub: X25519PublicKey) -> "DoubleRatchet": """Initialize as initiator (Alice) after X3DH. Alice performs the first DH ratchet step immediately. """ ratchet = cls() ratchet.dh_pair = generate_x25519_keypair() ratchet.dh_remote = bob_spk_pub # Perform DH ratchet to derive send chain dh_output = x25519_dh(ratchet.dh_pair[0], ratchet.dh_remote) ratchet.root_key, ratchet.send_chain_key = kdf_rk(shared_secret, dh_output) ratchet.recv_chain_key = None ratchet.send_n = 0 ratchet.recv_n = 0 ratchet.prev_send_n = 0 return ratchet @classmethod def init_bob(cls, shared_secret: bytes, spk_pair: tuple[X25519PrivateKey, X25519PublicKey]) -> "DoubleRatchet": """Initialize as responder (Bob) after X3DH. Bob uses his SPK as the initial ratchet key pair. """ ratchet = cls() ratchet.dh_pair = spk_pair ratchet.root_key = shared_secret ratchet.send_chain_key = None ratchet.recv_chain_key = None ratchet.send_n = 0 ratchet.recv_n = 0 ratchet.prev_send_n = 0 return ratchet def encrypt(self, plaintext: bytes) -> dict: """Encrypt a message. Returns {header: {dh_pub, n, pn}, ciphertext: bytes, nonce: bytes}. 
""" if self.send_chain_key is None: raise RuntimeError("Send chain not initialized") self.send_chain_key, message_key = kdf_ck(self.send_chain_key) header = RatchetHeader( dh_pub=serialize_x25519_public(self.dh_pair[1]), n=self.send_n, pn=self.prev_send_n, ) # Encrypt with AES-256-GCM using the message key nonce = os.urandom(12) aesgcm = AESGCM(message_key) # Include header as AAD to bind ciphertext to header aad = header.serialize() ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad) self.send_n += 1 return { "header": header.to_dict(), "ciphertext": ct_with_tag, # includes 16-byte tag "nonce": nonce, } def decrypt(self, header_dict: dict, ciphertext: bytes, nonce: bytes) -> bytes: """Decrypt a message. Handles DH ratchet step if new dh_pub. State is snapshotted before modification and restored on failure (M9 fix). """ header = RatchetHeader.from_dict(header_dict) remote_dh_pub_bytes = header.dh_pub # Check if this is from a skipped message (no state modification needed) skip_key = (remote_dh_pub_bytes.hex(), header.n) if skip_key in self.skipped: mk = self.skipped.pop(skip_key) aad = header.serialize() aesgcm = AESGCM(mk) try: return aesgcm.decrypt(nonce, ciphertext, aad) except Exception: self.skipped[skip_key] = mk # restore skipped key raise # Snapshot state before modifications snap = self._snapshot() try: remote_dh_pub = load_x25519_public(remote_dh_pub_bytes) current_remote_bytes = serialize_x25519_public(self.dh_remote) if self.dh_remote else None if current_remote_bytes is None or remote_dh_pub_bytes != current_remote_bytes: # New DH ratchet step self._skip_messages(header.pn) self._dh_ratchet(remote_dh_pub) self._skip_messages(header.n) # Derive message key from receive chain self.recv_chain_key, mk = kdf_ck(self.recv_chain_key) self.recv_n += 1 aad = header.serialize() aesgcm = AESGCM(mk) return aesgcm.decrypt(nonce, ciphertext, aad) except Exception: self._restore(snap) raise def _snapshot(self) -> dict: """Capture mutable state for rollback on decrypt 
failure.""" return { "dh_pair": self.dh_pair, "dh_remote": self.dh_remote, "root_key": self.root_key, "send_chain_key": self.send_chain_key, "recv_chain_key": self.recv_chain_key, "send_n": self.send_n, "recv_n": self.recv_n, "prev_send_n": self.prev_send_n, "skipped": dict(self.skipped), } def _restore(self, snap: dict): """Restore state from snapshot.""" self.dh_pair = snap["dh_pair"] self.dh_remote = snap["dh_remote"] self.root_key = snap["root_key"] self.send_chain_key = snap["send_chain_key"] self.recv_chain_key = snap["recv_chain_key"] self.send_n = snap["send_n"] self.recv_n = snap["recv_n"] self.prev_send_n = snap["prev_send_n"] self.skipped = snap["skipped"] def _skip_messages(self, until: int): """Skip ahead in the receive chain, storing message keys for out-of-order delivery.""" if self.recv_chain_key is None: return if until - self.recv_n > MAX_SKIP: raise RuntimeError(f"Too many skipped messages ({until - self.recv_n} > {MAX_SKIP})") while self.recv_n < until: self.recv_chain_key, mk = kdf_ck(self.recv_chain_key) remote_hex = serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else "" self.skipped[(remote_hex, self.recv_n)] = mk self.recv_n += 1 def _dh_ratchet(self, remote_dh_pub: X25519PublicKey): """Perform a DH ratchet step: update receive chain, generate new DH pair, update send chain.""" self.prev_send_n = self.send_n self.send_n = 0 self.recv_n = 0 self.dh_remote = remote_dh_pub # Derive new receive chain key dh_output = x25519_dh(self.dh_pair[0], self.dh_remote) self.root_key, self.recv_chain_key = kdf_rk(self.root_key, dh_output) # Generate new DH pair and derive new send chain key self.dh_pair = generate_x25519_keypair() dh_output = x25519_dh(self.dh_pair[0], self.dh_remote) self.root_key, self.send_chain_key = kdf_rk(self.root_key, dh_output) def export_state(self) -> bytes: """Serialize full ratchet state for persistent storage.""" state = { "dh_priv": serialize_x25519_private(self.dh_pair[0]).hex() if self.dh_pair else None, 
"dh_pub": serialize_x25519_public(self.dh_pair[1]).hex() if self.dh_pair else None, "dh_remote": serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else None, "root_key": self.root_key.hex(), "send_ck": self.send_chain_key.hex() if self.send_chain_key else None, "recv_ck": self.recv_chain_key.hex() if self.recv_chain_key else None, "send_n": self.send_n, "recv_n": self.recv_n, "prev_send_n": self.prev_send_n, "skipped": {f"{k[0]}:{k[1]}": v.hex() for k, v in self.skipped.items()}, } return json.dumps(state).encode() @classmethod def import_state(cls, data: bytes) -> "DoubleRatchet": """Deserialize ratchet state.""" state = json.loads(data) r = cls() if state["dh_priv"] and state["dh_pub"]: priv = load_x25519_private(bytes.fromhex(state["dh_priv"])) pub = load_x25519_public(bytes.fromhex(state["dh_pub"])) r.dh_pair = (priv, pub) if state["dh_remote"]: r.dh_remote = load_x25519_public(bytes.fromhex(state["dh_remote"])) r.root_key = bytes.fromhex(state["root_key"]) r.send_chain_key = bytes.fromhex(state["send_ck"]) if state["send_ck"] else None r.recv_chain_key = bytes.fromhex(state["recv_ck"]) if state["recv_ck"] else None r.send_n = state["send_n"] r.recv_n = state["recv_n"] r.prev_send_n = state["prev_send_n"] r.skipped = {} for k_str, v_hex in state.get("skipped", {}).items(): parts = k_str.rsplit(":", 1) dh_hex = parts[0] n = int(parts[1]) r.skipped[(dh_hex, n)] = bytes.fromhex(v_hex) return r # --------------------------------------------------------------------------- # Sender Keys (group messaging) # --------------------------------------------------------------------------- class SenderKeyState: """Sender key chain for group messaging. Each sender in a group has their own sender key chain. Other group members receive the initial sender_key via pairwise Double Ratchet. 
""" def __init__(self, sender_key: bytes | None = None): if sender_key is None: sender_key = os.urandom(32) self.sender_key = sender_key self.chain_id = hashlib.sha256(sender_key).digest() self.chain_key = hkdf_derive(sender_key, salt=b"\x00" * 32, info=b"SenderKeyChain", length=32) self.n = 0 # For receivers: track chain state to allow fast-forward self._known_keys: dict[int, bytes] = {} def encrypt(self, plaintext: bytes) -> dict: """Encrypt with current chain key. Returns {chain_id: hex, n: int, ciphertext: bytes, nonce: bytes}. """ self.chain_key, message_key = kdf_ck(self.chain_key) nonce = os.urandom(12) aesgcm = AESGCM(message_key) # AAD includes chain_id and message number aad = self.chain_id + struct.pack(">I", self.n) ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad) result = { "chain_id": self.chain_id.hex(), "n": self.n, "ciphertext": ct_with_tag, "nonce": nonce, } self.n += 1 return result MAX_SENDER_KEY_SKIP = 256 def decrypt(self, chain_id_hex: str, n: int, ciphertext: bytes, nonce: bytes) -> bytes: """Decrypt a group message. Fast-forwards the chain if needed. State is snapshotted before modification and restored on failure (M9 fix). 
""" chain_id = bytes.fromhex(chain_id_hex) if chain_id != self.chain_id: raise ValueError("Chain ID mismatch") if n - self.n > self.MAX_SENDER_KEY_SKIP: raise ValueError(f"Sender key skip too large ({n - self.n} > {self.MAX_SENDER_KEY_SKIP})") # Snapshot before fast-forward snap_chain_key = self.chain_key snap_n = self.n snap_known = dict(self._known_keys) try: # Fast-forward the chain to reach message n while self.n <= n: self.chain_key, mk = kdf_ck(self.chain_key) self._known_keys[self.n] = mk self.n += 1 mk = self._known_keys.pop(n, None) if mk is None: raise ValueError(f"Message key for n={n} not available (already consumed)") aad = chain_id + struct.pack(">I", n) aesgcm = AESGCM(mk) return aesgcm.decrypt(nonce, ciphertext, aad) except Exception: self.chain_key = snap_chain_key self.n = snap_n self._known_keys = snap_known raise def export_key(self) -> bytes: """Export sender key for distribution to group members. Contains everything needed to initialize a receiving SenderKeyState. """ return json.dumps({ "sender_key": self.sender_key.hex(), }).encode() def export_state(self) -> bytes: """Serialize full state for persistent storage.""" return json.dumps({ "sender_key": self.sender_key.hex(), "chain_id": self.chain_id.hex(), "chain_key": self.chain_key.hex(), "n": self.n, "known_keys": {str(k): v.hex() for k, v in self._known_keys.items()}, }).encode() @classmethod def import_state(cls, data: bytes) -> "SenderKeyState": state = json.loads(data) obj = cls.__new__(cls) obj.sender_key = bytes.fromhex(state["sender_key"]) obj.chain_id = bytes.fromhex(state["chain_id"]) obj.chain_key = bytes.fromhex(state["chain_key"]) obj.n = state["n"] obj._known_keys = {int(k): bytes.fromhex(v) for k, v in state.get("known_keys", {}).items()} return obj @classmethod def from_key(cls, exported_key: bytes) -> "SenderKeyState": """Initialize a receiving SenderKeyState from an exported key.""" data = json.loads(exported_key) return cls(sender_key=bytes.fromhex(data["sender_key"])) # 
# ---------------------------------------------------------------------------
# Contact Key Verification (Safety Numbers / Fingerprints / QR Codes)
# ---------------------------------------------------------------------------

FINGERPRINT_VERSION = 0


def compute_fingerprint(user_id: str, identity_key_bytes: bytes, iterations: int = 5200) -> bytes:
    """Compute a 32-byte fingerprint for a user's identity key.

    Uses iterated SHA-512 (Signal's NumericFingerprint algorithm).
    Seed: version(2B) + identity_key(32B) + user_id(UTF-8).
    Each iteration: SHA-512(previous_hash + identity_key).
    Output: first 32 bytes of final hash.
    """
    data = FINGERPRINT_VERSION.to_bytes(2, "big") + identity_key_bytes + user_id.encode("utf-8")
    for _ in range(iterations):
        data = hashlib.sha512(data + identity_key_bytes).digest()
    return data[:32]


def _five_digit_groups(data: bytes, count: int) -> list[str]:
    """Render *count* consecutive 5-byte chunks of *data* as zero-padded 5-digit numbers."""
    # Each chunk is read big-endian and reduced mod 100000.
    return [f"{int.from_bytes(data[i * 5:(i + 1) * 5], 'big') % 100000:05d}" for i in range(count)]


def format_fingerprint(fp_bytes: bytes) -> str:
    """Format 32-byte fingerprint as 6 groups of 5 zero-padded digits (30 digits).

    Each group: int(bytes[i*5:(i+1)*5], big-endian) % 100000.
    Output: two lines of 3 groups each, space-separated.
    """
    groups = _five_digit_groups(fp_bytes, 6)
    return " ".join(groups[:3]) + "\n" + " ".join(groups[3:])


def compute_safety_number(my_uid: str, my_ik_bytes: bytes,
                          their_uid: str, their_ik_bytes: bytes) -> str:
    """Compute a 60-digit safety number for a pair of users.

    Both users see the same number regardless of who computes it.
    Lower user_id's fingerprint comes first (deterministic ordering); if the user_ids
    are equal (e.g. two keys of the same account), the lower fingerprint comes first.
    The 12 groups are read from the 64-byte concatenation of both fingerprints, so group 7
    spans the boundary between them — kept as-is for compatibility with existing numbers.
    Output: 12 groups of 5 digits, formatted as 3 lines of 4 groups.
    """
    fp_mine = compute_fingerprint(my_uid, my_ik_bytes)
    fp_theirs = compute_fingerprint(their_uid, their_ik_bytes)

    # Tuple comparison equals the old uid-only ordering whenever the uids differ; the
    # fingerprint tie-break keeps both sides in agreement when they are equal.
    if (my_uid, fp_mine) <= (their_uid, fp_theirs):
        combined = fp_mine + fp_theirs
    else:
        combined = fp_theirs + fp_mine

    groups = _five_digit_groups(combined, 12)
    return "\n".join(" ".join(groups[row:row + 4]) for row in range(0, 12, 4))


def encode_verification_qr(user_id: str, identity_key_bytes: bytes) -> bytes:
    """Encode user identity for QR code verification.

    Format: version(1B=0x01) + uid_len(1B) + uid(UTF-8) + identity_key(32B).

    Raises:
        ValueError: the UTF-8 user_id exceeds 255 bytes, or the key is not 32 bytes
            (either would produce a payload that decode_verification_qr mis-parses).
    """
    uid_bytes = user_id.encode("utf-8")
    if len(uid_bytes) > 255:
        raise ValueError("User ID too long for verification QR (max 255 UTF-8 bytes)")
    if len(identity_key_bytes) != 32:
        raise ValueError("Identity key must be 32 bytes")
    return b"\x01" + bytes([len(uid_bytes)]) + uid_bytes + identity_key_bytes


def decode_verification_qr(data: bytes) -> tuple[str, bytes]:
    """Decode QR code verification payload.

    Returns (user_id, identity_key_bytes). Bytes after the 32-byte key are ignored.
    Raises ValueError on invalid format (including non-UTF-8 user_id).
    """
    if len(data) < 3:
        raise ValueError("QR data too short")
    if data[0] != 0x01:
        raise ValueError(f"Unknown QR version: {data[0]}")
    uid_len = data[1]
    if len(data) < 2 + uid_len + 32:
        raise ValueError("QR data truncated")
    user_id = data[2:2 + uid_len].decode("utf-8")
    identity_key = data[2 + uid_len:2 + uid_len + 32]
    return user_id, identity_key


# ---------------------------------------------------------------------------
# Message Padding (metadata privacy — hide plaintext length)
# ---------------------------------------------------------------------------

_PAD_MAGIC = b"\x01"
_PAD_BUCKETS = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]


def pad_plaintext(plaintext: bytes) -> bytes:
    """Pad plaintext to nearest bucket size to hide message length.

    Format: 0x01 + plaintext + random_padding + pad_length(4B big-endian)
    pad_length counts the random padding plus the 4-byte suffix itself. Messages
    larger than the biggest bucket get only the 4-byte suffix.
    Prefix 0x01 distinguishes padded messages from legacy unpadded (which start with '{').
    """
    content = _PAD_MAGIC + plaintext
    # +4 for the length suffix
    min_size = len(content) + 4
    target = next((b for b in _PAD_BUCKETS if b >= min_size), min_size)
    pad_len = target - len(content)
    return content + os.urandom(pad_len - 4) + struct.pack(">I", pad_len)


def unpad_plaintext(data: bytes) -> bytes:
    """Remove padding. Returns raw plaintext for both padded and legacy unpadded messages.

    Input without the 0x01 prefix, or with an inconsistent length suffix, is returned unchanged.
    """
    if not data or data[0:1] != _PAD_MAGIC:
        return data  # legacy unpadded message (starts with '{' for JSON)
    if len(data) < 5:
        return data  # too short to be validly padded
    pad_len = struct.unpack(">I", data[-4:])[0]
    if pad_len < 4 or pad_len > len(data) - 1:
        return data  # invalid padding metadata, treat as legacy
    return data[1:len(data) - pad_len]