Files
Kecalek_python/zaloha/crypto_utils.py
2026-03-11 16:54:14 +01:00

813 lines
31 KiB
Python

"""Cryptographic utilities: Ed25519, X25519, AES-256-GCM, Double Ratchet, Sender Keys.
RSA functions retained for login challenge-response only.
"""
import hashlib
import hmac
import json
import os
import struct
import uuid
from dataclasses import dataclass, field
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
# ---------------------------------------------------------------------------
# Password-based key encryption (M3: PBKDF2 600k iterations + AES-256-GCM)
# ---------------------------------------------------------------------------
PBKDF2_ITERATIONS = 600_000
_ECP1_MAGIC = b"ECP1" # Encrypted Chat PBKDF v1 format marker
def _encrypt_private_key(raw_bytes: bytes, password: bytes) -> bytes:
"""Encrypt raw key bytes with PBKDF2-HMAC-SHA256 (600k iterations) + AES-256-GCM.
Output format: MAGIC(4) + salt(16) + nonce(12) + ciphertext_with_tag(N+16)
"""
salt = os.urandom(16)
kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32,
salt=salt, iterations=PBKDF2_ITERATIONS)
derived = kdf.derive(password)
nonce = os.urandom(12)
aesgcm = AESGCM(derived)
ct = aesgcm.encrypt(nonce, raw_bytes, _ECP1_MAGIC) # AAD = magic bytes
return _ECP1_MAGIC + salt + nonce + ct
def _decrypt_private_key(data: bytes, password: bytes) -> bytes:
"""Decrypt key bytes encrypted with _encrypt_private_key."""
if not data.startswith(_ECP1_MAGIC):
raise ValueError("Not ECP1 format")
salt = data[4:20]
nonce = data[20:32]
ct = data[32:]
kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32,
salt=salt, iterations=PBKDF2_ITERATIONS)
derived = kdf.derive(password)
aesgcm = AESGCM(derived)
return aesgcm.decrypt(nonce, ct, _ECP1_MAGIC)
# ---------------------------------------------------------------------------
# RSA (login challenge-response ONLY)
# ---------------------------------------------------------------------------
def generate_rsa_keypair(key_size: int = 4096) -> tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]:
    """Create a fresh RSA key pair (public exponent 65537) for login challenge-response.

    Returns:
        (private_key, public_key)
    """
    private = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
    public = private.public_key()
    return private, public
def serialize_private_key(key: rsa.RSAPrivateKey, password: bytes | None = None) -> bytes:
    """Serialize an RSA private key.

    Without a (truthy) password the key is written as unencrypted PKCS#8 PEM.
    With a password the PKCS#8 DER bytes are wrapped in the ECP1 format
    (PBKDF2 + AES-256-GCM) by _encrypt_private_key.
    """
    if not password:
        return key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8,
                                 serialization.NoEncryption())
    der = key.private_bytes(serialization.Encoding.DER, serialization.PrivateFormat.PKCS8,
                            serialization.NoEncryption())
    return _encrypt_private_key(der, password)
def serialize_public_key(key: rsa.RSAPublicKey) -> bytes:
    """Serialize an RSA public key as SubjectPublicKeyInfo PEM."""
    encoding = serialization.Encoding.PEM
    fmt = serialization.PublicFormat.SubjectPublicKeyInfo
    return key.public_bytes(encoding, fmt)
def load_private_key(data: bytes, password: bytes | None = None) -> rsa.RSAPrivateKey:
    """Load an RSA private key from ECP1 (current) or PEM (legacy) bytes.

    ECP1 blobs are decrypted with *password* and parsed as PKCS#8 DER. Anything
    else is treated as legacy PEM (BestAvailableEncryption or unencrypted).
    """
    if not data.startswith(_ECP1_MAGIC):
        return serialization.load_pem_private_key(data, password=password)
    der = _decrypt_private_key(data, password)
    return serialization.load_der_private_key(der, password=None)
def load_public_key(pem: bytes) -> rsa.RSAPublicKey:
    """Parse a PEM-encoded public key."""
    public = serialization.load_pem_public_key(pem)
    return public
def rsa_sign(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes:
    """Sign *data* with RSA-PSS (MGF1-SHA256, maximum salt length) over SHA-256."""
    pss = padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH)
    return private_key.sign(data, pss, hashes.SHA256())
def rsa_verify(public_key: rsa.RSAPublicKey, signature: bytes, data: bytes) -> bool:
    """Check an RSA-PSS/SHA-256 signature produced by rsa_sign.

    Returns True when valid. Any failure, including malformed arguments,
    yields False rather than an exception.
    """
    try:
        pss = padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH)
        public_key.verify(signature, data, pss, hashes.SHA256())
    except Exception:
        return False
    else:
        return True
# ---------------------------------------------------------------------------
# AES-256-GCM (symmetric encryption — used by ratchet message keys & images)
# ---------------------------------------------------------------------------
def aes_encrypt(plaintext: bytes, key: bytes | None = None) -> tuple[bytes, bytes, bytes, bytes]:
    """Encrypt with AES-256-GCM. Returns (key, nonce, ciphertext, tag).

    A random 256-bit key is generated when *key* is None. The nonce is 12 random
    bytes. No associated data is used.
    """
    key = AESGCM.generate_key(bit_length=256) if key is None else key
    nonce = os.urandom(12)
    sealed = AESGCM(key).encrypt(nonce, plaintext, None)
    # AESGCM appends the 16-byte tag; split it off so callers can store it separately.
    body, tag = sealed[:-16], sealed[-16:]
    return key, nonce, body, tag
def aes_decrypt(key: bytes, nonce: bytes, ciphertext: bytes, tag: bytes) -> bytes:
    """Reverse aes_encrypt: re-attach the tag and authenticate-decrypt with AES-256-GCM."""
    sealed = ciphertext + tag
    return AESGCM(key).decrypt(nonce, sealed, None)
# ---------------------------------------------------------------------------
# Ed25519 Identity Keys
# ---------------------------------------------------------------------------
def generate_identity_keypair() -> tuple[Ed25519PrivateKey, Ed25519PublicKey]:
    """Create a new Ed25519 identity key pair as (private, public)."""
    identity = Ed25519PrivateKey.generate()
    return identity, identity.public_key()
def serialize_ed25519_private(key: Ed25519PrivateKey, password: bytes | None = None) -> bytes:
    """Serialize an Ed25519 private key.

    Returns the 32 raw seed bytes, or an ECP1-encrypted blob of them when a
    (truthy) password is given.
    """
    seed = serialize_ed25519_private_raw(key)
    return _encrypt_private_key(seed, password) if password else seed
def serialize_ed25519_private_raw(key: Ed25519PrivateKey) -> bytes:
    """Return the unencrypted 32-byte raw form of an Ed25519 private key."""
    raw_enc = serialization.Encoding.Raw
    raw_fmt = serialization.PrivateFormat.Raw
    return key.private_bytes(raw_enc, raw_fmt, serialization.NoEncryption())
def serialize_ed25519_public(key: Ed25519PublicKey) -> bytes:
    """Return the 32-byte raw encoding of an Ed25519 public key."""
    raw_enc = serialization.Encoding.Raw
    raw_fmt = serialization.PublicFormat.Raw
    return key.public_bytes(raw_enc, raw_fmt)
def load_ed25519_private(data: bytes, password: bytes | None = None) -> Ed25519PrivateKey:
    """Load an Ed25519 private key from any supported storage format.

    Formats are tried in this order:
      1. ECP1 (PBKDF2 + AES-GCM), decrypted with *password*;
      2. legacy encrypted PEM, when a truthy password is given;
      3. 32 raw seed bytes;
      4. legacy unencrypted PEM.
    """
    if data.startswith(_ECP1_MAGIC):
        return Ed25519PrivateKey.from_private_bytes(_decrypt_private_key(data, password))
    if password:
        return serialization.load_pem_private_key(data, password=password)
    return (Ed25519PrivateKey.from_private_bytes(data) if len(data) == 32
            else serialization.load_pem_private_key(data, password=None))
def load_ed25519_public(data: bytes) -> Ed25519PublicKey:
    """Load an Ed25519 public key from 32 raw bytes or, otherwise, PEM."""
    if len(data) != 32:
        return serialization.load_pem_public_key(data)
    return Ed25519PublicKey.from_public_bytes(data)
def ed25519_sign(private_key: Ed25519PrivateKey, data: bytes) -> bytes:
    """Return the 64-byte Ed25519 signature of *data*."""
    signature = private_key.sign(data)
    return signature
def ed25519_verify(public_key: Ed25519PublicKey, signature: bytes, data: bytes) -> bool:
    """Report whether *signature* is a valid Ed25519 signature of *data*.

    Any failure, including malformed input, yields False rather than an exception.
    """
    try:
        public_key.verify(signature, data)
    except Exception:
        return False
    else:
        return True
# ---------------------------------------------------------------------------
# X25519 Key Exchange
# ---------------------------------------------------------------------------
def generate_x25519_keypair() -> tuple[X25519PrivateKey, X25519PublicKey]:
    """Create a fresh X25519 key pair as (private, public)."""
    secret = X25519PrivateKey.generate()
    return secret, secret.public_key()
def serialize_x25519_private(key: X25519PrivateKey) -> bytes:
    """Return the unencrypted 32-byte raw form of an X25519 private key."""
    raw_enc = serialization.Encoding.Raw
    raw_fmt = serialization.PrivateFormat.Raw
    return key.private_bytes(raw_enc, raw_fmt, serialization.NoEncryption())
def serialize_x25519_public(key: X25519PublicKey) -> bytes:
    """Return the 32-byte raw encoding of an X25519 public key."""
    raw_enc = serialization.Encoding.Raw
    raw_fmt = serialization.PublicFormat.Raw
    return key.public_bytes(raw_enc, raw_fmt)
def load_x25519_private(data: bytes) -> X25519PrivateKey:
    """Rebuild an X25519 private key from its 32 raw bytes."""
    key = X25519PrivateKey.from_private_bytes(data)
    return key
def load_x25519_public(data: bytes) -> X25519PublicKey:
    """Rebuild an X25519 public key from its 32 raw bytes."""
    key = X25519PublicKey.from_public_bytes(data)
    return key
def x25519_dh(private_key: X25519PrivateKey, public_key: X25519PublicKey) -> bytes:
    """Run X25519 Diffie-Hellman and return the 32-byte shared secret."""
    shared = private_key.exchange(public_key)
    return shared
# ---------------------------------------------------------------------------
# Ed25519 <-> X25519 conversion (for Identity Key dual use)
# ---------------------------------------------------------------------------
def ed25519_private_to_x25519(ed_private: Ed25519PrivateKey) -> X25519PrivateKey:
    """Derive the X25519 private key that corresponds to an Ed25519 private key.

    The Ed25519 seed is hashed with SHA-512. The low 32 bytes form the scalar,
    which is clamped as in RFC 7748: clear the 3 low bits and the top bit, and
    set bit 254.
    """
    seed = ed_private.private_bytes(
        serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()
    )
    scalar = bytearray(hashlib.sha512(seed).digest()[:32])
    scalar[0] &= 0xF8
    scalar[31] = (scalar[31] & 0x7F) | 0x40
    return X25519PrivateKey.from_private_bytes(bytes(scalar))
def _ed25519_y_to_x25519_u(raw: bytes) -> bytes:
    """Map a 32-byte Ed25519 public key encoding to a 32-byte X25519 u-coordinate.

    Applies the birational map u = (1 + y) / (1 - y) mod p, where p = 2^255 - 19.

    Raises:
        ValueError: if *raw* is not 32 bytes, if y is a non-canonical encoding
            (y >= p), or if y == 1 (the neutral element, for which 1 - y has no
            inverse).
    """
    if len(raw) != 32:
        raise ValueError("Ed25519 public key must be 32 bytes")
    p = (1 << 255) - 19
    # Bit 255 carries the sign of x; the Montgomery u-coordinate depends only on y.
    y = int.from_bytes(raw, "little") & ((1 << 255) - 1)
    if y >= p:
        raise ValueError("Non-canonical Ed25519 public key encoding")
    if y == 1:
        # 1 - y == 0: the map is undefined. The old code silently produced u = 0.
        raise ValueError("Ed25519 public key is the neutral element")
    # Modular inverse via Fermat's little theorem (p is prime).
    inv = pow((1 - y) % p, p - 2, p)
    u = ((1 + y) * inv) % p
    return u.to_bytes(32, "little")


def ed25519_public_to_x25519(ed_public: "Ed25519PublicKey") -> "X25519PublicKey":
    """Derive the X25519 public key matching an Ed25519 public key.

    Works on remote keys, where the private key is unavailable, by converting
    the Edwards y-coordinate to the Montgomery u-coordinate in pure Python
    (see _ed25519_y_to_x25519_u).

    Raises:
        ValueError: for non-canonical or neutral-element encodings.
    """
    raw = ed_public.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)
    return X25519PublicKey.from_public_bytes(_ed25519_y_to_x25519_u(raw))
# ---------------------------------------------------------------------------
# HKDF
# ---------------------------------------------------------------------------
_HKDF_INFO_SELF = b"EncryptedChat_SelfKey"  # HKDF info label for the self-copy key
_HKDF_INFO_RK = b"EncryptedChat_RootKey"  # HKDF info label used by kdf_rk


def derive_self_encryption_key(identity_private: Ed25519PrivateKey) -> bytes:
    """Derive the static AES-256 key used to encrypt copies of our own sent messages.

    This is a fixed HKDF-SHA256 derivation from the identity seed (salt
    b"self_encryption"), not a ratchet. It is acceptable because only the owner
    holds the identity private key, and self-copies do not need forward secrecy.
    """
    seed = identity_private.private_bytes(
        serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()
    )
    return hkdf_derive(seed, salt=b"self_encryption", info=_HKDF_INFO_SELF, length=32)
_HKDF_INFO_LOCAL = b"EncryptedChat_LocalStorage"  # HKDF info label for the local-storage key


def derive_local_storage_key(identity_private: Ed25519PrivateKey) -> bytes:
    """Derive the AES-256 key for encrypting local session/sender key files.

    HKDF-SHA256 over the raw identity seed, with salt b"local_storage".
    """
    seed = identity_private.private_bytes(
        serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()
    )
    return hkdf_derive(seed, salt=b"local_storage", info=_HKDF_INFO_LOCAL, length=32)
# Single-byte HMAC inputs used by kdf_ck (despite the _HKDF_ prefix they are
# HMAC-SHA256 messages keyed by the chain key, not HKDF info labels).
_HKDF_INFO_CK_MSG = b"\x01" # chain key -> message key
_HKDF_INFO_CK_NEXT = b"\x02" # chain key -> next chain key
def hkdf_derive(input_key: bytes, salt: bytes, info: bytes, length: int = 32) -> bytes:
    """Expand *input_key* into *length* bytes with HKDF-SHA256 using *salt* and *info*."""
    kdf = HKDF(algorithm=hashes.SHA256(), length=length, salt=salt, info=info)
    return kdf.derive(input_key)
def kdf_rk(root_key: bytes, dh_output: bytes) -> tuple[bytes, bytes]:
    """Root-key KDF for the Double Ratchet. Returns (new_root_key, chain_key).

    HKDF-SHA256 with the current root key as salt and the DH output as input
    keying material yields 64 bytes. The first half becomes the new root key and
    the second half the new chain key.
    """
    okm = hkdf_derive(dh_output, salt=root_key, info=_HKDF_INFO_RK, length=64)
    new_root_key, chain_key = okm[:32], okm[32:]
    return new_root_key, chain_key
def kdf_ck(chain_key: bytes) -> tuple[bytes, bytes]:
    """Symmetric chain KDF. Returns (new_chain_key, message_key).

    Both outputs are HMAC-SHA256 keyed by the current chain key:
      message_key   = HMAC(chain_key, 0x01)
      new_chain_key = HMAC(chain_key, 0x02)
    """
    message_key, next_chain_key = (
        hmac.new(chain_key, label, hashlib.sha256).digest()
        for label in (_HKDF_INFO_CK_MSG, _HKDF_INFO_CK_NEXT)
    )
    return next_chain_key, message_key
# ---------------------------------------------------------------------------
# X3DH
# ---------------------------------------------------------------------------
# HKDF info label for the X3DH shared secret (used by x3dh_initiate and x3dh_respond).
_X3DH_INFO = b"EncryptedChat_X3DH"
def generate_signed_prekey(identity_private: Ed25519PrivateKey) -> dict:
    """Create a signed pre-key (SPK).

    The raw 32-byte X25519 public key is signed with the Ed25519 identity key.

    Returns:
        {"private": X25519PrivateKey, "public": X25519PublicKey,
         "signature": bytes, "id": str (random UUID4)}
    """
    private, public = generate_x25519_keypair()
    sig = ed25519_sign(identity_private, serialize_x25519_public(public))
    return {"private": private, "public": public, "signature": sig, "id": str(uuid.uuid4())}
def generate_one_time_prekeys(count: int = 50) -> list[dict]:
    """Create *count* one-time pre-keys (OPKs).

    Returns:
        [{"private": X25519PrivateKey, "public": X25519PublicKey, "id": str}, ...]
    """
    def make_prekey() -> dict:
        priv, pub = generate_x25519_keypair()
        return {"private": priv, "public": pub, "id": str(uuid.uuid4())}

    return [make_prekey() for _ in range(count)]
def x3dh_initiate(
    ik_private_ed: Ed25519PrivateKey,
    ik_public_remote_ed: Ed25519PublicKey,
    spk_remote: X25519PublicKey,
    spk_signature: bytes,
    opk_remote: X25519PublicKey | None = None,
) -> tuple[bytes, X25519PrivateKey, X25519PublicKey]:
    """Initiator (Alice) side of X3DH.

    Args:
        ik_private_ed: our Ed25519 identity private key.
        ik_public_remote_ed: the responder's Ed25519 identity public key.
        spk_remote: the responder's signed pre-key (X25519 public).
        spk_signature: Ed25519 signature over spk_remote's raw bytes by ik_public_remote_ed.
        opk_remote: the responder's one-time pre-key, if one was handed out.

    Returns:
        (shared_secret, ephemeral_private, ephemeral_public)

    Raises:
        ValueError: if the SPK signature does not verify.
    """
    # Refuse an SPK that the claimed identity key did not sign.
    if not ed25519_verify(ik_public_remote_ed, spk_signature, serialize_x25519_public(spk_remote)):
        raise ValueError("Invalid SPK signature")

    our_ik = ed25519_private_to_x25519(ik_private_ed)
    their_ik = ed25519_public_to_x25519(ik_public_remote_ed)
    ek_priv, ek_pub = generate_x25519_keypair()

    # DH inputs in protocol order; the responder must concatenate in the same order.
    pairs = [
        (our_ik, spk_remote),   # DH1: IK_A, SPK_B
        (ek_priv, their_ik),    # DH2: EK_A, IK_B
        (ek_priv, spk_remote),  # DH3: EK_A, SPK_B
    ]
    if opk_remote is not None:
        pairs.append((ek_priv, opk_remote))  # DH4: EK_A, OPK_B

    key_material = b"".join(x25519_dh(priv, pub) for priv, pub in pairs)
    shared_secret = hkdf_derive(key_material, salt=bytes(32), info=_X3DH_INFO, length=32)
    return shared_secret, ek_priv, ek_pub
def x3dh_respond(
    ik_private_ed: Ed25519PrivateKey,
    spk_private: X25519PrivateKey,
    ik_remote_ed: Ed25519PublicKey,
    ek_remote: X25519PublicKey,
    opk_private: X25519PrivateKey | None = None,
) -> bytes:
    """Responder (Bob) side of X3DH.

    Args:
        ik_private_ed: our Ed25519 identity private key.
        spk_private: our signed pre-key private key (X25519).
        ik_remote_ed: the initiator's Ed25519 identity public key.
        ek_remote: the initiator's ephemeral public key (X25519).
        opk_private: our one-time pre-key private key, if the initiator used one.

    Returns:
        The 32-byte shared secret.
    """
    our_ik = ed25519_private_to_x25519(ik_private_ed)
    their_ik = ed25519_public_to_x25519(ik_remote_ed)

    # Mirror of the initiator's DH order so both sides concatenate identical bytes.
    pairs = [
        (spk_private, their_ik),   # DH1: SPK_B, IK_A
        (our_ik, ek_remote),       # DH2: IK_B, EK_A
        (spk_private, ek_remote),  # DH3: SPK_B, EK_A
    ]
    if opk_private is not None:
        pairs.append((opk_private, ek_remote))  # DH4: OPK_B, EK_A

    key_material = b"".join(x25519_dh(priv, pub) for priv, pub in pairs)
    return hkdf_derive(key_material, salt=bytes(32), info=_X3DH_INFO, length=32)
# ---------------------------------------------------------------------------
# Double Ratchet
# ---------------------------------------------------------------------------
# Upper bound on (until - recv_n) accepted by DoubleRatchet._skip_messages, i.e. how many
# message keys one incoming message may force us to derive ahead in a receive chain.
MAX_SKIP = 256 # max messages to skip in a single chain (out-of-order tolerance)
@dataclass
class RatchetHeader:
    """Header sent with each ratchet message.

    The bytes produced by serialize() are used as AES-GCM associated data by
    DoubleRatchet, so their exact layout (JSON key order and hex encoding) must
    stay stable between sender and receiver.
    """
    dh_pub: bytes  # sender's current ratchet public key: 32 raw bytes (an X25519PublicKey is also accepted)
    n: int  # message number in current sending chain
    pn: int  # number of messages in previous sending chain

    def _dh_pub_hex(self) -> str:
        """Return dh_pub as lowercase hex, accepting raw bytes or an X25519PublicKey."""
        if isinstance(self.dh_pub, bytes):
            return self.dh_pub.hex()
        return serialize_x25519_public(self.dh_pub).hex()

    def serialize(self) -> bytes:
        """Return the canonical JSON bytes of this header (used as AEAD associated data).

        Raises:
            ValueError: if dh_pub is raw bytes of the wrong length.
        """
        if isinstance(self.dh_pub, bytes) and len(self.dh_pub) != 32:
            raise ValueError("Ratchet header dh_pub must be 32 bytes")
        return json.dumps(self.to_dict()).encode()

    def to_dict(self) -> dict:
        """Return {"dh_pub": hex, "n": int, "pn": int} for transport."""
        return {"dh_pub": self._dh_pub_hex(), "n": self.n, "pn": self.pn}

    @classmethod
    def from_dict(cls, d: dict) -> "RatchetHeader":
        """Parse a transport dict (untrusted input from the peer).

        Raises:
            KeyError: if a field is missing.
            ValueError: if dh_pub is not 32 bytes of hex, or n/pn are not
                non-negative integers.
        """
        dh_pub = bytes.fromhex(d["dh_pub"])
        n, pn = d["n"], d["pn"]
        if len(dh_pub) != 32:
            raise ValueError("Ratchet header dh_pub must be 32 bytes")
        for field_name, value in (("n", n), ("pn", pn)):
            # bool is a subclass of int; JSON true/false must not count as a counter.
            if isinstance(value, bool) or not isinstance(value, int) or value < 0:
                raise ValueError(f"Ratchet header field {field_name!r} must be a non-negative integer")
        return cls(dh_pub=dh_pub, n=n, pn=pn)
class DoubleRatchet:
    """Signal Double Ratchet implementation.

    Per-session state:
      dh_pair          -- our current ratchet key pair (X25519 private, public)
      dh_remote        -- the peer's current ratchet public key
      root_key         -- root key, advanced by kdf_rk on every DH ratchet step
      send_chain_key   -- sending chain key, advanced by kdf_ck per sent message
      recv_chain_key   -- receiving chain key, advanced by kdf_ck per received message
      send_n / recv_n  -- message numbers within the current sending / receiving chain
      prev_send_n      -- length of our previous sending chain (sent as header.pn)
      skipped          -- message keys derived ahead of time for out-of-order delivery
    """
    def __init__(self):
        """Create an empty ratchet; populate it via init_alice, init_bob or import_state."""
        self.dh_pair: tuple[X25519PrivateKey, X25519PublicKey] | None = None
        self.dh_remote: X25519PublicKey | None = None
        self.root_key: bytes = b""
        self.send_chain_key: bytes | None = None
        self.recv_chain_key: bytes | None = None
        self.send_n: int = 0
        self.recv_n: int = 0
        self.prev_send_n: int = 0
        # (dh_pub_hex, n) -> message_key for out-of-order messages
        self.skipped: dict[tuple[str, int], bytes] = {}

    @classmethod
    def init_alice(cls, shared_secret: bytes, bob_spk_pub: X25519PublicKey) -> "DoubleRatchet":
        """Initialize as initiator (Alice) after X3DH.
        Alice performs the first DH ratchet step immediately.

        A fresh ratchet pair is generated and combined with Bob's SPK, so Alice
        can encrypt right away. Her receive chain stays None until Bob's first
        reply arrives with a new ratchet key.
        """
        ratchet = cls()
        ratchet.dh_pair = generate_x25519_keypair()
        ratchet.dh_remote = bob_spk_pub
        # Perform DH ratchet to derive send chain
        dh_output = x25519_dh(ratchet.dh_pair[0], ratchet.dh_remote)
        ratchet.root_key, ratchet.send_chain_key = kdf_rk(shared_secret, dh_output)
        ratchet.recv_chain_key = None
        ratchet.send_n = 0
        ratchet.recv_n = 0
        ratchet.prev_send_n = 0
        return ratchet

    @classmethod
    def init_bob(cls, shared_secret: bytes, spk_pair: tuple[X25519PrivateKey, X25519PublicKey]) -> "DoubleRatchet":
        """Initialize as responder (Bob) after X3DH.
        Bob uses his SPK as the initial ratchet key pair.

        Both chains start as None: Bob cannot encrypt (encrypt raises RuntimeError)
        until decrypting Alice's first message performs a DH ratchet step.
        """
        ratchet = cls()
        ratchet.dh_pair = spk_pair
        ratchet.root_key = shared_secret
        ratchet.send_chain_key = None
        ratchet.recv_chain_key = None
        ratchet.send_n = 0
        ratchet.recv_n = 0
        ratchet.prev_send_n = 0
        return ratchet

    def encrypt(self, plaintext: bytes) -> dict:
        """Encrypt a message.
        Returns {header: {dh_pub, n, pn}, ciphertext: bytes, nonce: bytes}.

        Raises:
            RuntimeError: if no sending chain exists yet.
        """
        if self.send_chain_key is None:
            raise RuntimeError("Send chain not initialized")
        # Advance the sending chain; each message key is used exactly once.
        self.send_chain_key, message_key = kdf_ck(self.send_chain_key)
        header = RatchetHeader(
            dh_pub=serialize_x25519_public(self.dh_pair[1]),
            n=self.send_n,
            pn=self.prev_send_n,
        )
        # Encrypt with AES-256-GCM using the message key
        nonce = os.urandom(12)
        aesgcm = AESGCM(message_key)
        # Include header as AAD to bind ciphertext to header
        aad = header.serialize()
        ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad)
        self.send_n += 1
        return {
            "header": header.to_dict(),
            "ciphertext": ct_with_tag, # includes 16-byte tag
            "nonce": nonce,
        }

    def decrypt(self, header_dict: dict, ciphertext: bytes, nonce: bytes) -> bytes:
        """Decrypt a message. Handles DH ratchet step if new dh_pub.
        State is snapshotted before modification and restored on failure (M9 fix).

        Args:
            header_dict: the header dict produced by the sender's encrypt()
                (parsed with RatchetHeader.from_dict).
            ciphertext: AES-256-GCM ciphertext including the 16-byte tag.
            nonce: the nonce sent alongside the ciphertext.
        Returns:
            The plaintext.
        Raises:
            Whatever AESGCM.decrypt raises on authentication failure, RuntimeError
            when more than MAX_SKIP messages would have to be skipped, and errors
            from malformed input. On failure the main path restores the full
            snapshot; the skipped-key path puts the popped key back if AES-GCM
            decryption fails.
        """
        header = RatchetHeader.from_dict(header_dict)
        remote_dh_pub_bytes = header.dh_pub
        # Check if this is from a skipped message (no state modification needed)
        skip_key = (remote_dh_pub_bytes.hex(), header.n)
        if skip_key in self.skipped:
            mk = self.skipped.pop(skip_key)
            aad = header.serialize()
            aesgcm = AESGCM(mk)
            try:
                return aesgcm.decrypt(nonce, ciphertext, aad)
            except Exception:
                self.skipped[skip_key] = mk # restore skipped key
                raise
        # Snapshot state before modifications
        snap = self._snapshot()
        try:
            remote_dh_pub = load_x25519_public(remote_dh_pub_bytes)
            current_remote_bytes = serialize_x25519_public(self.dh_remote) if self.dh_remote else None
            if current_remote_bytes is None or remote_dh_pub_bytes != current_remote_bytes:
                # New DH ratchet step
                # First store keys for the rest of the old receive chain (header.pn = its length),
                # then ratchet to the peer's new key.
                self._skip_messages(header.pn)
                self._dh_ratchet(remote_dh_pub)
            # Store keys for any earlier, not-yet-received messages in the current chain.
            self._skip_messages(header.n)
            # Derive message key from receive chain
            self.recv_chain_key, mk = kdf_ck(self.recv_chain_key)
            self.recv_n += 1
            aad = header.serialize()
            aesgcm = AESGCM(mk)
            return aesgcm.decrypt(nonce, ciphertext, aad)
        except Exception:
            # Forged or corrupt messages must not advance or corrupt the ratchet.
            self._restore(snap)
            raise

    def _snapshot(self) -> dict:
        """Capture mutable state for rollback on decrypt failure."""
        # Key objects and bytes are never mutated in place, so a shallow copy of
        # `skipped` is enough.
        return {
            "dh_pair": self.dh_pair,
            "dh_remote": self.dh_remote,
            "root_key": self.root_key,
            "send_chain_key": self.send_chain_key,
            "recv_chain_key": self.recv_chain_key,
            "send_n": self.send_n,
            "recv_n": self.recv_n,
            "prev_send_n": self.prev_send_n,
            "skipped": dict(self.skipped),
        }

    def _restore(self, snap: dict):
        """Restore state from snapshot."""
        self.dh_pair = snap["dh_pair"]
        self.dh_remote = snap["dh_remote"]
        self.root_key = snap["root_key"]
        self.send_chain_key = snap["send_chain_key"]
        self.recv_chain_key = snap["recv_chain_key"]
        self.send_n = snap["send_n"]
        self.recv_n = snap["recv_n"]
        self.prev_send_n = snap["prev_send_n"]
        self.skipped = snap["skipped"]

    def _skip_messages(self, until: int):
        """Skip ahead in the receive chain, storing message keys for out-of-order delivery.

        Derives and stores keys for message numbers recv_n .. until-1 of the current
        receive chain. It is a no-op if no receive chain exists yet.

        Raises:
            RuntimeError: if more than MAX_SKIP keys would be derived.
        """
        if self.recv_chain_key is None:
            return
        if until - self.recv_n > MAX_SKIP:
            raise RuntimeError(f"Too many skipped messages ({until - self.recv_n} > {MAX_SKIP})")
        while self.recv_n < until:
            self.recv_chain_key, mk = kdf_ck(self.recv_chain_key)
            # NOTE: loop-invariant; keyed by the hex of the remote ratchet key, like skip_key in decrypt.
            remote_hex = serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else ""
            self.skipped[(remote_hex, self.recv_n)] = mk
            self.recv_n += 1

    def _dh_ratchet(self, remote_dh_pub: X25519PublicKey):
        """Perform a DH ratchet step: update receive chain, generate new DH pair, update send chain."""
        self.prev_send_n = self.send_n
        self.send_n = 0
        self.recv_n = 0
        self.dh_remote = remote_dh_pub
        # Derive new receive chain key
        dh_output = x25519_dh(self.dh_pair[0], self.dh_remote)
        self.root_key, self.recv_chain_key = kdf_rk(self.root_key, dh_output)
        # Generate new DH pair and derive new send chain key
        self.dh_pair = generate_x25519_keypair()
        dh_output = x25519_dh(self.dh_pair[0], self.dh_remote)
        self.root_key, self.send_chain_key = kdf_rk(self.root_key, dh_output)

    def export_state(self) -> bytes:
        """Serialize full ratchet state for persistent storage.

        The JSON contains the ratchet private key, chain keys and skipped message
        keys as plain hex. NOTE(review): presumably the caller encrypts it (e.g.
        with the key from derive_local_storage_key) before writing to disk; verify.
        """
        state = {
            "dh_priv": serialize_x25519_private(self.dh_pair[0]).hex() if self.dh_pair else None,
            "dh_pub": serialize_x25519_public(self.dh_pair[1]).hex() if self.dh_pair else None,
            "dh_remote": serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else None,
            "root_key": self.root_key.hex(),
            "send_ck": self.send_chain_key.hex() if self.send_chain_key else None,
            "recv_ck": self.recv_chain_key.hex() if self.recv_chain_key else None,
            "send_n": self.send_n,
            "recv_n": self.recv_n,
            "prev_send_n": self.prev_send_n,
            # Tuple keys are flattened to "dh_hex:n" strings for JSON.
            "skipped": {f"{k[0]}:{k[1]}": v.hex() for k, v in self.skipped.items()},
        }
        return json.dumps(state).encode()

    @classmethod
    def import_state(cls, data: bytes) -> "DoubleRatchet":
        """Deserialize ratchet state."""
        state = json.loads(data)
        r = cls()
        if state["dh_priv"] and state["dh_pub"]:
            priv = load_x25519_private(bytes.fromhex(state["dh_priv"]))
            pub = load_x25519_public(bytes.fromhex(state["dh_pub"]))
            r.dh_pair = (priv, pub)
        if state["dh_remote"]:
            r.dh_remote = load_x25519_public(bytes.fromhex(state["dh_remote"]))
        r.root_key = bytes.fromhex(state["root_key"])
        r.send_chain_key = bytes.fromhex(state["send_ck"]) if state["send_ck"] else None
        r.recv_chain_key = bytes.fromhex(state["recv_ck"]) if state["recv_ck"] else None
        r.send_n = state["send_n"]
        r.recv_n = state["recv_n"]
        r.prev_send_n = state["prev_send_n"]
        r.skipped = {}
        for k_str, v_hex in state.get("skipped", {}).items():
            # Split on the last ':' only, reversing the "dh_hex:n" encoding of export_state.
            parts = k_str.rsplit(":", 1)
            dh_hex = parts[0]
            n = int(parts[1])
            r.skipped[(dh_hex, n)] = bytes.fromhex(v_hex)
        return r
# ---------------------------------------------------------------------------
# Sender Keys (group messaging)
# ---------------------------------------------------------------------------
class SenderKeyState:
    """Sender key chain for group messaging.

    Each sender in a group has their own sender key chain.
    Other group members receive the initial sender_key via pairwise Double Ratchet.

    The chain key is seeded from sender_key with HKDF and advanced with kdf_ck
    once per message. Message n is encrypted with the n-th derived message key.
    chain_id (SHA-256 of sender_key) identifies the chain and, together with n,
    forms the AES-GCM associated data.

    NOTE(review): messages carry no per-sender signature, so any holder of this
    sender_key can produce ciphertexts that decrypt under this chain. Confirm
    whether sender authenticity is enforced elsewhere.
    """

    # Upper bound on how far decrypt() fast-forwards the chain for one message.
    # This limits the work and stored keys a single incoming message can force.
    MAX_SENDER_KEY_SKIP = 256

    def __init__(self, sender_key: bytes | None = None):
        """Create a chain from *sender_key*, or from 32 random bytes when None."""
        if sender_key is None:
            sender_key = os.urandom(32)
        self.sender_key = sender_key
        self.chain_id = hashlib.sha256(sender_key).digest()
        self.chain_key = hkdf_derive(sender_key, salt=b"\x00" * 32, info=b"SenderKeyChain", length=32)
        self.n = 0
        # For receivers: track chain state to allow fast-forward
        self._known_keys: dict[int, bytes] = {}

    def encrypt(self, plaintext: bytes) -> dict:
        """Encrypt with current chain key.

        Returns {chain_id: hex, n: int, ciphertext: bytes (tag included), nonce: bytes}.
        """
        self.chain_key, message_key = kdf_ck(self.chain_key)
        nonce = os.urandom(12)
        aesgcm = AESGCM(message_key)
        # AAD includes chain_id and message number
        aad = self.chain_id + struct.pack(">I", self.n)
        ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad)
        result = {
            "chain_id": self.chain_id.hex(),
            "n": self.n,
            "ciphertext": ct_with_tag,
            "nonce": nonce,
        }
        self.n += 1
        return result

    def decrypt(self, chain_id_hex: str, n: int, ciphertext: bytes, nonce: bytes) -> bytes:
        """Decrypt a group message. Fast-forwards the chain if needed.

        Keys for messages between the current position and *n* are kept in
        _known_keys so that later out-of-order messages can still be decrypted.
        State is snapshotted before modification and restored on failure (M9 fix).

        Raises:
            ValueError: on chain ID mismatch, a skip beyond MAX_SENDER_KEY_SKIP,
                or a message key that was already consumed.
            Whatever AESGCM.decrypt raises on authentication failure.
        """
        chain_id = bytes.fromhex(chain_id_hex)
        if chain_id != self.chain_id:
            raise ValueError("Chain ID mismatch")
        if n - self.n > self.MAX_SENDER_KEY_SKIP:
            raise ValueError(f"Sender key skip too large ({n - self.n} > {self.MAX_SENDER_KEY_SKIP})")
        # Snapshot before fast-forward
        snap_chain_key = self.chain_key
        snap_n = self.n
        snap_known = dict(self._known_keys)
        try:
            # Fast-forward the chain to reach message n
            while self.n <= n:
                self.chain_key, mk = kdf_ck(self.chain_key)
                self._known_keys[self.n] = mk
                self.n += 1
            mk = self._known_keys.pop(n, None)
            if mk is None:
                raise ValueError(f"Message key for n={n} not available (already consumed)")
            aad = chain_id + struct.pack(">I", n)
            aesgcm = AESGCM(mk)
            return aesgcm.decrypt(nonce, ciphertext, aad)
        except Exception:
            self.chain_key = snap_chain_key
            self.n = snap_n
            self._known_keys = snap_known
            raise

    def export_key(self, *, include_position: bool = False) -> bytes:
        """Export sender key for distribution to group members.

        By default only sender_key is exported (the original format), so a
        receiver built with from_key() starts at message 0. Because decrypt()
        refuses to fast-forward more than MAX_SENDER_KEY_SKIP messages at once,
        such a receiver cannot decrypt a message numbered more than
        MAX_SENDER_KEY_SKIP beyond 0, e.g. when it joins after that many messages.

        Args:
            include_position: also export the current chain_key and message
                number, so the receiver starts at this sender's current position.
                It then cannot decrypt messages numbered below that position.

        Returns:
            JSON bytes accepted by from_key().
        """
        payload = {"sender_key": self.sender_key.hex()}
        if include_position:
            payload["chain_key"] = self.chain_key.hex()
            payload["n"] = self.n
        return json.dumps(payload).encode()

    def export_state(self) -> bytes:
        """Serialize full state (including unused message keys) for persistent storage."""
        return json.dumps({
            "sender_key": self.sender_key.hex(),
            "chain_id": self.chain_id.hex(),
            "chain_key": self.chain_key.hex(),
            "n": self.n,
            "known_keys": {str(k): v.hex() for k, v in self._known_keys.items()},
        }).encode()

    @classmethod
    def import_state(cls, data: bytes) -> "SenderKeyState":
        """Rebuild a state saved by export_state (bypasses __init__ derivation)."""
        state = json.loads(data)
        obj = cls.__new__(cls)
        obj.sender_key = bytes.fromhex(state["sender_key"])
        obj.chain_id = bytes.fromhex(state["chain_id"])
        obj.chain_key = bytes.fromhex(state["chain_key"])
        obj.n = state["n"]
        obj._known_keys = {int(k): bytes.fromhex(v) for k, v in state.get("known_keys", {}).items()}
        return obj

    @classmethod
    def from_key(cls, exported_key: bytes) -> "SenderKeyState":
        """Initialize a receiving SenderKeyState from an exported key.

        Accepts both the sender_key-only format and exports made with
        include_position=True. The latter position the chain at the exported
        chain_key and message number.

        Raises:
            ValueError: if an exported position is malformed.
        """
        data = json.loads(exported_key)
        state = cls(sender_key=bytes.fromhex(data["sender_key"]))
        if "chain_key" in data:
            chain_key = bytes.fromhex(data["chain_key"])
            n = data.get("n")
            if (len(chain_key) != 32 or isinstance(n, bool)
                    or not isinstance(n, int) or n < 0):
                raise ValueError("Malformed sender key position")
            state.chain_key = chain_key
            state.n = n
        return state