E2E encrypted chat (X3DH + Double Ratchet, Signal Protocol). Server: asyncio TCP + TLS, MySQL. Clients: PyQt6 GUI + CLI. Secrets (.env, TLS keys, Cloudflare token), runtime data and mobile clients (separate repos) are gitignored. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
989 lines
38 KiB
Python
989 lines
38 KiB
Python
"""Cryptographic utilities: Ed25519, X25519, AES-256-GCM, Double Ratchet, Sender Keys.

RSA functions retained for login challenge-response only.
"""
|
|
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import os
|
|
import struct
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
|
|
from cryptography.hazmat.primitives import hashes, serialization
|
|
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
|
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey
|
|
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
|
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Password-based key encryption (M3: PBKDF2 600k iterations + AES-256-GCM)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# PBKDF2-HMAC-SHA256 iteration count used when deriving the AES key from a password.
PBKDF2_ITERATIONS = 600_000
# 4-byte prefix of every blob written by _encrypt_private_key; also bound as GCM AAD.
_ECP1_MAGIC = b"ECP1" # Encrypted Chat PBKDF v1 format marker
|
|
|
|
|
|
def _encrypt_private_key(raw_bytes: bytes, password: bytes) -> bytes:
    """Wrap raw key material in the password-protected ECP1 container.

    A fresh 16-byte salt feeds PBKDF2-HMAC-SHA256 (PBKDF2_ITERATIONS rounds)
    to produce a 32-byte AES key; the key material is then sealed with
    AES-256-GCM under a fresh 12-byte nonce, with the magic marker as AAD.

    Result layout: MAGIC(4) + salt(16) + nonce(12) + ciphertext_with_tag(N+16)
    """
    salt = os.urandom(16)
    aes_key = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=PBKDF2_ITERATIONS,
    ).derive(password)
    nonce = os.urandom(12)
    # The format marker is authenticated (AAD) so it cannot be swapped unnoticed.
    sealed = AESGCM(aes_key).encrypt(nonce, raw_bytes, _ECP1_MAGIC)
    return b"".join((_ECP1_MAGIC, salt, nonce, sealed))
|
|
|
|
|
|
def _decrypt_private_key(data: bytes, password: bytes) -> bytes:
    """Decrypt key bytes produced by _encrypt_private_key.

    Expected layout: MAGIC(4) + salt(16) + nonce(12) + ciphertext_with_tag(N+16).

    Args:
        data: The ECP1 blob.
        password: Password the blob was encrypted with.

    Returns:
        The raw key bytes that were originally encrypted.

    Raises:
        ValueError: If the blob does not start with the ECP1 marker, is too
            short to contain salt, nonce and GCM tag, or no password is given.
        cryptography.exceptions.InvalidTag: If the password is wrong or the
            blob was modified.
    """
    if not data.startswith(_ECP1_MAGIC):
        raise ValueError("Not ECP1 format")
    # 4 (magic) + 16 (salt) + 12 (nonce) + 16 (GCM tag of an empty plaintext)
    if len(data) < 48:
        raise ValueError("Truncated ECP1 data")
    # Public loaders default password to None; fail clearly instead of letting
    # PBKDF2 raise an opaque TypeError.
    if password is None:
        raise ValueError("Password required for ECP1-encrypted key")

    salt = data[4:20]
    nonce = data[20:32]
    ct = data[32:]
    kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32,
                     salt=salt, iterations=PBKDF2_ITERATIONS)
    derived = kdf.derive(password)
    # The magic marker was bound as AAD at encryption time.
    return AESGCM(derived).decrypt(nonce, ct, _ECP1_MAGIC)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RSA (login challenge-response ONLY)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def generate_rsa_keypair(key_size: int = 4096) -> tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]:
    """Create a new RSA key pair (public exponent 65537) and return (private, public)."""
    secret_key = rsa.generate_private_key(key_size=key_size, public_exponent=65537)
    public_key = secret_key.public_key()
    return secret_key, public_key
|
|
|
|
|
|
def serialize_private_key(key: rsa.RSAPrivateKey, password: bytes | None = None) -> bytes:
    """Serialize an RSA private key.

    With a non-empty password the key is exported as PKCS#8 DER and wrapped in
    the ECP1 container; otherwise it is returned as unencrypted PKCS#8 PEM.
    """
    pkcs8 = serialization.PrivateFormat.PKCS8
    plain = serialization.NoEncryption()
    if not password:
        return key.private_bytes(serialization.Encoding.PEM, pkcs8, plain)
    der = key.private_bytes(serialization.Encoding.DER, pkcs8, plain)
    return _encrypt_private_key(der, password)
|
|
|
|
|
|
def serialize_public_key(key: rsa.RSAPublicKey) -> bytes:
    """Return the RSA public key as PEM-encoded SubjectPublicKeyInfo."""
    encoding = serialization.Encoding.PEM
    layout = serialization.PublicFormat.SubjectPublicKeyInfo
    return key.public_bytes(encoding, layout)
|
|
|
|
|
|
def load_private_key(data: bytes, password: bytes | None = None) -> rsa.RSAPrivateKey:
    """Load an RSA private key from an ECP1 container or from legacy PEM."""
    if not data.startswith(_ECP1_MAGIC):
        # Legacy PEM format (old BestAvailableEncryption or unencrypted)
        return serialization.load_pem_private_key(data, password=password)
    # ECP1 payload is unencrypted PKCS#8 DER once the container is opened.
    der = _decrypt_private_key(data, password)
    return serialization.load_der_private_key(der, password=None)
|
|
|
|
|
|
def load_public_key(pem: bytes) -> rsa.RSAPublicKey:
    """Parse a PEM-encoded public key."""
    public_key = serialization.load_pem_public_key(pem)
    return public_key
|
|
|
|
|
|
def compute_pairing_fingerprint(public_key_data: bytes | str) -> str:
    """Turn a temporary pairing public key into a human-verifiable fingerprint.

    Text input is UTF-8 encoded first. PEM-looking input (contains
    "-----BEGIN") has CRLF line endings converted to LF and surrounding
    whitespace stripped, so equivalent PEM encodings hash identically. The
    SHA-256 of a domain-separation prefix plus the canonical bytes is handed
    to format_fingerprint.
    """
    if isinstance(public_key_data, str):
        key_bytes = public_key_data.encode("utf-8")
    else:
        key_bytes = public_key_data

    if b"-----BEGIN" in key_bytes:
        canonical = key_bytes.replace(b"\r\n", b"\n").strip()
    else:
        canonical = key_bytes

    hasher = hashlib.sha256()
    hasher.update(b"EncryptedChat_PairingKey_v1\x00")
    hasher.update(canonical)
    return format_fingerprint(hasher.digest())
|
|
|
|
|
|
def normalize_pairing_fingerprint(value: str) -> str:
    """Normalize a user-entered pairing fingerprint for comparison.

    Keeps only the ASCII digits 0-9 and drops everything else (spaces,
    newlines, separators). str.isdigit() is deliberately not used: it also
    accepts characters such as superscripts and non-Latin digits, which would
    survive normalization and can never match a generated fingerprint.
    """
    return "".join(ch for ch in value if "0" <= ch <= "9")
|
|
|
|
|
|
def encode_pairing_qr(code: str, fingerprint: str) -> bytes:
    """Encode pairing code + fingerprint for QR transport.

    Format: magic(5='PAIR1') + code(8 ASCII digits) + fingerprint(30 ASCII digits)

    Non-digit characters (spaces, newlines, separators) in either argument are
    ignored. Only the ASCII digits 0-9 count as digits.

    Raises:
        ValueError: If the code does not contain exactly 8 digits or the
            fingerprint does not contain exactly 30 digits.
    """
    # Filter on the ASCII range so .encode("ascii") below cannot raise
    # UnicodeEncodeError for characters like superscripts or non-Latin digits.
    code_digits = "".join(ch for ch in code if "0" <= ch <= "9")
    fp_digits = "".join(
        ch for ch in normalize_pairing_fingerprint(fingerprint) if "0" <= ch <= "9"
    )
    if len(code_digits) != 8:
        raise ValueError("Pairing code must contain 8 digits")
    if len(fp_digits) != 30:
        raise ValueError("Pairing fingerprint must contain 30 digits")
    return b"PAIR1" + code_digits.encode("ascii") + fp_digits.encode("ascii")
|
|
|
|
|
|
def decode_pairing_qr(data: bytes) -> tuple[str, str]:
    """Parse a pairing QR payload into (code, formatted_fingerprint).

    The payload must be exactly 43 bytes: b"PAIR1", 8 code digits, 30
    fingerprint digits. The fingerprint is returned as two lines of three
    space-separated 5-digit groups.

    Raises:
        ValueError: If the payload has the wrong size or prefix, or the code
            or fingerprint part is not all digits.
    """
    if not (len(data) == 43 and data.startswith(b"PAIR1")):
        raise ValueError("Invalid pairing QR payload")

    code = data[5:13].decode("ascii")
    fp_digits = data[13:43].decode("ascii")
    if not (code.isdigit() and fp_digits.isdigit()):
        raise ValueError("Invalid pairing QR payload")

    groups = [fp_digits[start:start + 5] for start in range(0, 30, 5)]
    first_row = " ".join(groups[:3])
    second_row = " ".join(groups[3:])
    return code, first_row + "\n" + second_row
|
|
|
|
|
|
def rsa_sign(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes:
    """Sign data with RSA-PSS (MGF1/SHA-256, maximum salt length) over SHA-256."""
    pss = padding.PSS(
        mgf=padding.MGF1(hashes.SHA256()),
        salt_length=padding.PSS.MAX_LENGTH,
    )
    digest_algorithm = hashes.SHA256()
    return private_key.sign(data, pss, digest_algorithm)
|
|
|
|
|
|
def rsa_verify(public_key: rsa.RSAPublicKey, signature: bytes, data: bytes) -> bool:
    """Check an RSA-PSS/SHA-256 signature; return False on any failure.

    The salt length is auto-detected (PSS.AUTO). Every exception raised while
    building the padding or verifying is reported as False.
    """
    try:
        pss = padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.AUTO,
        )
        public_key.verify(signature, data, pss, hashes.SHA256())
    except Exception:
        return False
    else:
        return True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# AES-256-GCM (symmetric encryption — used by ratchet message keys & images)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def aes_encrypt(plaintext: bytes, key: bytes | None = None) -> tuple[bytes, bytes, bytes, bytes]:
    """Encrypt with AES-256-GCM and return (key, nonce, ciphertext, tag).

    A random 256-bit key is generated when none is supplied. The nonce is 12
    random bytes; the 16-byte tag is split off the end of the GCM output. No
    associated data is used.
    """
    if key is None:
        key = AESGCM.generate_key(bit_length=256)
    nonce = os.urandom(12)
    sealed = AESGCM(key).encrypt(nonce, plaintext, None)
    ciphertext, tag = sealed[:-16], sealed[-16:]
    return key, nonce, ciphertext, tag
|
|
|
|
|
|
def aes_decrypt(key: bytes, nonce: bytes, ciphertext: bytes, tag: bytes) -> bytes:
    """Decrypt AES-256-GCM output given as separate ciphertext and tag."""
    # The AEAD API expects the tag appended to the ciphertext.
    sealed = ciphertext + tag
    cipher = AESGCM(key)
    return cipher.decrypt(nonce, sealed, None)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Ed25519 Identity Keys
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def generate_identity_keypair() -> tuple[Ed25519PrivateKey, Ed25519PublicKey]:
    """Generate a fresh Ed25519 identity key pair as (private, public)."""
    identity_private = Ed25519PrivateKey.generate()
    identity_public = identity_private.public_key()
    return identity_private, identity_public
|
|
|
|
|
|
def serialize_ed25519_private(key: Ed25519PrivateKey, password: bytes | None = None) -> bytes:
    """Serialize an Ed25519 private key.

    With a non-empty password the 32 raw bytes are wrapped in the ECP1
    container; otherwise the 32 raw bytes are returned unencrypted.
    """
    raw = serialize_ed25519_private_raw(key)  # 32 bytes
    if not password:
        return raw
    return _encrypt_private_key(raw, password)
|
|
|
|
|
|
def serialize_ed25519_private_raw(key: Ed25519PrivateKey) -> bytes:
    """Return the Ed25519 private key as 32 raw, unencrypted bytes."""
    return key.private_bytes(
        serialization.Encoding.Raw,
        serialization.PrivateFormat.Raw,
        serialization.NoEncryption(),
    )
|
|
|
|
|
|
def serialize_ed25519_public(key: Ed25519PublicKey) -> bytes:
    """Return the Ed25519 public key as 32 raw bytes."""
    raw_encoding = serialization.Encoding.Raw
    raw_format = serialization.PublicFormat.Raw
    return key.public_bytes(raw_encoding, raw_format)
|
|
|
|
|
|
def load_ed25519_private(data: bytes, password: bytes | None = None) -> Ed25519PrivateKey:
    """Load an Ed25519 private key from any supported storage format.

    Formats are tried in this order: ECP1 container, password-protected PEM
    (legacy), 32 raw bytes, unencrypted PEM (legacy).
    """
    if data.startswith(_ECP1_MAGIC):
        return Ed25519PrivateKey.from_private_bytes(_decrypt_private_key(data, password))

    # Legacy formats: PEM (old BestAvailableEncryption) or 32-byte raw
    if not password and len(data) == 32:
        return Ed25519PrivateKey.from_private_bytes(data)
    pem_password = password if password else None
    return serialization.load_pem_private_key(data, password=pem_password)
|
|
|
|
|
|
def load_ed25519_public(data: bytes) -> Ed25519PublicKey:
    """Load an Ed25519 public key from 32 raw bytes, or from PEM otherwise."""
    if len(data) != 32:
        return serialization.load_pem_public_key(data)
    return Ed25519PublicKey.from_public_bytes(data)
|
|
|
|
|
|
def ed25519_sign(private_key: Ed25519PrivateKey, data: bytes) -> bytes:
    """Produce the 64-byte Ed25519 signature of data."""
    signature = private_key.sign(data)
    return signature
|
|
|
|
|
|
def ed25519_verify(public_key: Ed25519PublicKey, signature: bytes, data: bytes) -> bool:
    """Check an Ed25519 signature; any exception during verification yields False."""
    try:
        public_key.verify(signature, data)
    except Exception:
        return False
    else:
        return True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# X25519 Key Exchange
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def generate_x25519_keypair() -> tuple[X25519PrivateKey, X25519PublicKey]:
    """Generate a fresh X25519 key pair as (private, public)."""
    secret_key = X25519PrivateKey.generate()
    public_key = secret_key.public_key()
    return secret_key, public_key
|
|
|
|
|
|
def serialize_x25519_private(key: X25519PrivateKey) -> bytes:
    """Return the X25519 private key as 32 raw, unencrypted bytes."""
    return key.private_bytes(
        serialization.Encoding.Raw,
        serialization.PrivateFormat.Raw,
        serialization.NoEncryption(),
    )
|
|
|
|
|
|
def serialize_x25519_public(key: X25519PublicKey) -> bytes:
    """Return the X25519 public key as 32 raw bytes."""
    raw_encoding = serialization.Encoding.Raw
    raw_format = serialization.PublicFormat.Raw
    return key.public_bytes(raw_encoding, raw_format)
|
|
|
|
|
|
def load_x25519_private(data: bytes) -> X25519PrivateKey:
    """Build an X25519 private key from its raw byte encoding."""
    private_key = X25519PrivateKey.from_private_bytes(data)
    return private_key
|
|
|
|
|
|
def load_x25519_public(data: bytes) -> X25519PublicKey:
    """Build an X25519 public key from its raw byte encoding."""
    public_key = X25519PublicKey.from_public_bytes(data)
    return public_key
|
|
|
|
|
|
def x25519_dh(private_key: X25519PrivateKey, public_key: X25519PublicKey) -> bytes:
    """Run X25519 Diffie-Hellman and return the 32-byte shared secret."""
    shared_secret = private_key.exchange(public_key)
    return shared_secret
|
|
|
|
|
|
def derive_pairing_shared_key(shared_secret: bytes, public_key_a: bytes, public_key_b: bytes) -> bytes:
    """Derive the 32-byte symmetric bootstrap key used for device pairing.

    The two public keys are put in lexicographic order before being hashed
    into the HKDF salt, so both peers obtain the same key regardless of which
    one they consider "a" and which "b".
    """
    if public_key_b < public_key_a:
        low, high = public_key_b, public_key_a
    else:
        low, high = public_key_a, public_key_b

    salt_hash = hashlib.sha256()
    salt_hash.update(b"EncryptedChat_PairingSalt_v1\x00")
    salt_hash.update(low)
    salt_hash.update(high)
    return hkdf_derive(
        shared_secret,
        salt=salt_hash.digest(),
        info=b"EncryptedChat_PairingBootstrap",
        length=32,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Ed25519 <-> X25519 conversion (for Identity Key dual use)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def ed25519_private_to_x25519(ed_private: Ed25519PrivateKey) -> X25519PrivateKey:
    """Derive the X25519 private key that corresponds to an Ed25519 private key.

    The scalar is the first 32 bytes of SHA-512(seed), clamped per RFC 7748.
    """
    seed = ed_private.private_bytes(
        serialization.Encoding.Raw,
        serialization.PrivateFormat.Raw,
        serialization.NoEncryption(),
    )
    scalar = bytearray(hashlib.sha512(seed).digest()[:32])
    # RFC 7748 clamping: clear the 3 low bits, clear the top bit, set bit 254.
    scalar[0] &= 0xF8
    scalar[31] = (scalar[31] & 0x7F) | 0x40
    return X25519PrivateKey.from_private_bytes(bytes(scalar))
|
|
|
|
|
|
def ed25519_public_to_x25519(ed_public: Ed25519PublicKey) -> X25519PublicKey:
    """Convert an Ed25519 public key to the matching X25519 public key.

    Pure-Python conversion: the encoded point's y coordinate (sign bit
    cleared) is mapped to the Montgomery u coordinate with
    u = (1 + y) / (1 - y) mod p, where p = 2^255 - 19. The modular inverse is
    computed as a p-2 exponentiation.
    """
    prime = (1 << 255) - 19
    encoded = ed_public.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw)

    # The top bit of the little-endian encoding is the sign of x, not part of y.
    y = int.from_bytes(encoded, "little") & ((1 << 255) - 1)

    numerator = (1 + y) % prime
    denominator = (1 - y) % prime
    u = (numerator * pow(denominator, prime - 2, prime)) % prime
    return X25519PublicKey.from_public_bytes(u.to_bytes(32, "little"))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HKDF
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# HKDF "info" label for the static self-encryption key (derive_self_encryption_key).
_HKDF_INFO_SELF = b"EncryptedChat_SelfKey"
# HKDF "info" label for the Double Ratchet root-key step (kdf_rk).
_HKDF_INFO_RK = b"EncryptedChat_RootKey"
|
|
|
|
|
|
def derive_self_encryption_key(identity_private: Ed25519PrivateKey) -> bytes:
    """Derive the static AES-256 key used to encrypt copies of our own sent messages.

    This key never ratchets: it is an HKDF output of the raw identity private
    key, so only the identity owner can reproduce it, and self-copies do not
    need forward secrecy.
    """
    seed = identity_private.private_bytes(
        serialization.Encoding.Raw,
        serialization.PrivateFormat.Raw,
        serialization.NoEncryption(),
    )
    key = hkdf_derive(seed, salt=b"self_encryption", info=_HKDF_INFO_SELF, length=32)
    return key
|
|
|
|
|
|
# HKDF "info" label used by derive_local_storage_key.
_HKDF_INFO_LOCAL = b"EncryptedChat_LocalStorage"
|
|
|
|
|
|
def derive_local_storage_key(identity_private: Ed25519PrivateKey) -> bytes:
    """Derive the AES-256 key that protects local session/sender key files.

    The key is an HKDF output of the raw identity private key with a
    storage-specific salt and info label.
    """
    seed = identity_private.private_bytes(
        serialization.Encoding.Raw,
        serialization.PrivateFormat.Raw,
        serialization.NoEncryption(),
    )
    key = hkdf_derive(seed, salt=b"local_storage", info=_HKDF_INFO_LOCAL, length=32)
    return key
|
|
|
|
|
|
# Single-byte inputs that kdf_ck feeds to HMAC-SHA256 keyed with the chain key.
_HKDF_INFO_CK_MSG = b"\x01" # chain key -> message key
_HKDF_INFO_CK_NEXT = b"\x02" # chain key -> next chain key
|
|
|
|
|
|
def hkdf_derive(input_key: bytes, salt: bytes, info: bytes, length: int = 32) -> bytes:
    """Run HKDF-SHA256 over input_key and return `length` bytes of output."""
    kdf = HKDF(algorithm=hashes.SHA256(), length=length, salt=salt, info=info)
    return kdf.derive(input_key)
|
|
|
|
|
|
def kdf_rk(root_key: bytes, dh_output: bytes) -> tuple[bytes, bytes]:
    """Root-key KDF step: return (new_root_key, chain_key).

    HKDF is keyed with the current root key as salt and the DH output as
    input key material. Of the 64 derived bytes, the first half becomes the
    new root key and the second half the chain key.
    """
    okm = hkdf_derive(dh_output, salt=root_key, info=_HKDF_INFO_RK, length=64)
    new_root_key, chain_key = okm[:32], okm[32:]
    return new_root_key, chain_key
|
|
|
|
|
|
def kdf_ck(chain_key: bytes) -> tuple[bytes, bytes]:
    """Chain-key KDF step: return (new_chain_key, message_key).

    Both outputs are HMAC-SHA256 values keyed with the current chain key:
        message_key   = HMAC(chain_key, 0x01)
        new_chain_key = HMAC(chain_key, 0x02)
    """
    message_key, new_chain_key = (
        hmac.new(chain_key, label, hashlib.sha256).digest()
        for label in (_HKDF_INFO_CK_MSG, _HKDF_INFO_CK_NEXT)
    )
    return new_chain_key, message_key
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# X3DH
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# HKDF "info" label for the X3DH shared-secret derivation (x3dh_initiate / x3dh_respond).
_X3DH_INFO = b"EncryptedChat_X3DH"
|
|
|
|
|
|
def generate_signed_prekey(identity_private: Ed25519PrivateKey) -> dict:
    """Create a signed pre-key (SPK).

    A new X25519 pair is generated and its raw public key is signed with the
    Ed25519 identity key.

    Returns {private: X25519PrivateKey, public: X25519PublicKey, signature: bytes, id: str}.
    """
    spk_priv, spk_pub = generate_x25519_keypair()
    signature = ed25519_sign(identity_private, serialize_x25519_public(spk_pub))
    prekey = {
        "private": spk_priv,
        "public": spk_pub,
        "signature": signature,
        "id": str(uuid.uuid4()),
    }
    return prekey
|
|
|
|
|
|
def generate_one_time_prekeys(count: int = 50) -> list[dict]:
    """Create `count` one-time pre-keys, each with a random UUID as its id.

    Returns [{private: X25519PrivateKey, public: X25519PublicKey, id: str}, ...].
    """
    keypairs = (generate_x25519_keypair() for _ in range(count))
    return [
        {"private": priv, "public": pub, "id": str(uuid.uuid4())}
        for priv, pub in keypairs
    ]
|
|
|
|
|
|
def x3dh_initiate(
    ik_private_ed: Ed25519PrivateKey,
    ik_public_remote_ed: Ed25519PublicKey,
    spk_remote: X25519PublicKey,
    spk_signature: bytes,
    opk_remote: X25519PublicKey | None = None,
) -> tuple[bytes, X25519PrivateKey, X25519PublicKey]:
    """Run the initiator side of X3DH.

    Args:
        ik_private_ed: Our Ed25519 identity private key
        ik_public_remote_ed: Remote Ed25519 identity public key
        spk_remote: Remote signed pre-key (X25519 public)
        spk_signature: Ed25519 signature of spk_remote by ik_public_remote_ed
        opk_remote: Optional one-time pre-key (X25519 public)

    Returns:
        (shared_secret, ephemeral_private, ephemeral_public)

    Raises:
        ValueError: If spk_signature does not verify against the remote identity key.
    """
    # Refuse to proceed with a signed pre-key the remote identity did not sign.
    if not ed25519_verify(ik_public_remote_ed, spk_signature, serialize_x25519_public(spk_remote)):
        raise ValueError("Invalid SPK signature")

    # Identity keys are Ed25519; DH needs their X25519 counterparts.
    our_identity = ed25519_private_to_x25519(ik_private_ed)
    their_identity = ed25519_public_to_x25519(ik_public_remote_ed)

    ek_priv, ek_pub = generate_x25519_keypair()

    dh_outputs = [
        x25519_dh(our_identity, spk_remote),   # IK_A, SPK_B
        x25519_dh(ek_priv, their_identity),    # EK_A, IK_B
        x25519_dh(ek_priv, spk_remote),        # EK_A, SPK_B
    ]
    if opk_remote is not None:
        dh_outputs.append(x25519_dh(ek_priv, opk_remote))  # EK_A, OPK_B

    shared_secret = hkdf_derive(b"".join(dh_outputs), salt=bytes(32), info=_X3DH_INFO, length=32)
    return shared_secret, ek_priv, ek_pub
|
|
|
|
|
|
def x3dh_respond(
    ik_private_ed: Ed25519PrivateKey,
    spk_private: X25519PrivateKey,
    ik_remote_ed: Ed25519PublicKey,
    ek_remote: X25519PublicKey,
    opk_private: X25519PrivateKey | None = None,
) -> bytes:
    """Run the responder side of X3DH.

    Args:
        ik_private_ed: Our Ed25519 identity private key
        spk_private: Our signed pre-key private (X25519)
        ik_remote_ed: Remote Ed25519 identity public key
        ek_remote: Remote ephemeral key (X25519 public)
        opk_private: Our one-time pre-key private (X25519), if used

    Returns:
        shared_secret (32 bytes)
    """
    # Identity keys are Ed25519; DH needs their X25519 counterparts.
    our_identity = ed25519_private_to_x25519(ik_private_ed)
    their_identity = ed25519_public_to_x25519(ik_remote_ed)

    # Same DH pairings as the initiator, in the same order.
    dh_outputs = [
        x25519_dh(spk_private, their_identity),  # SPK_B, IK_A
        x25519_dh(our_identity, ek_remote),      # IK_B, EK_A
        x25519_dh(spk_private, ek_remote),       # SPK_B, EK_A
    ]
    if opk_private is not None:
        dh_outputs.append(x25519_dh(opk_private, ek_remote))  # OPK_B, EK_A

    return hkdf_derive(b"".join(dh_outputs), salt=bytes(32), info=_X3DH_INFO, length=32)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Double Ratchet
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Limit checked by DoubleRatchet._skip_messages; a larger gap raises RuntimeError.
MAX_SKIP = 256 # max messages to skip in a single chain (out-of-order tolerance)
|
|
|
|
|
|
@dataclass
class RatchetHeader:
    """Per-message Double Ratchet header."""

    dh_pub: bytes  # sender's current ratchet public key (32 bytes)
    n: int         # message number in current sending chain
    pn: int        # number of messages in previous sending chain

    def serialize(self) -> bytes:
        """Return the JSON encoding of the header (used as AEAD associated data)."""
        if isinstance(self.dh_pub, bytes):
            # Round-trip through the key loader so malformed bytes are rejected.
            pub_hex = serialize_x25519_public(load_x25519_public(self.dh_pub)).hex()
        else:
            pub_hex = serialize_x25519_public(self.dh_pub).hex()
        fields = {"dh_pub": pub_hex, "n": self.n, "pn": self.pn}
        return json.dumps(fields).encode()

    def to_dict(self) -> dict:
        """Return the header as a JSON-friendly dict with a hex-encoded key."""
        if isinstance(self.dh_pub, bytes):
            pub_hex = self.dh_pub.hex()
        else:
            pub_hex = serialize_x25519_public(self.dh_pub).hex()
        return {"dh_pub": pub_hex, "n": self.n, "pn": self.pn}

    @classmethod
    def from_dict(cls, d: dict) -> "RatchetHeader":
        """Rebuild a header from the dict form produced by to_dict."""
        raw_pub = bytes.fromhex(d["dh_pub"])
        return cls(dh_pub=raw_pub, n=d["n"], pn=d["pn"])
|
|
|
|
|
|
class DoubleRatchet:
    """Double Ratchet session state (Signal protocol) for one conversation."""

    def __init__(self):
        # Our current ratchet key pair as (private, public).
        self.dh_pair: tuple[X25519PrivateKey, X25519PublicKey] | None = None
        # The peer's most recent ratchet public key.
        self.dh_remote: X25519PublicKey | None = None
        self.root_key: bytes = b""
        self.send_chain_key: bytes | None = None
        self.recv_chain_key: bytes | None = None
        self.send_n: int = 0
        self.recv_n: int = 0
        self.prev_send_n: int = 0
        # (dh_pub_hex, n) -> message_key for out-of-order messages
        self.skipped: dict[tuple[str, int], bytes] = {}

    @classmethod
    def init_alice(cls, shared_secret: bytes, bob_spk_pub: X25519PublicKey) -> "DoubleRatchet":
        """Create the initiator's (Alice's) session after X3DH.

        Alice generates a ratchet key pair and immediately performs the first
        DH ratchet against Bob's signed pre-key, which gives her a send chain.
        """
        session = cls()
        session.dh_pair = generate_x25519_keypair()
        session.dh_remote = bob_spk_pub

        # First DH ratchet step: yields the new root key and the send chain.
        first_dh = x25519_dh(session.dh_pair[0], session.dh_remote)
        session.root_key, session.send_chain_key = kdf_rk(shared_secret, first_dh)

        session.recv_chain_key = None
        session.send_n = session.recv_n = session.prev_send_n = 0
        return session

    @classmethod
    def init_bob(cls, shared_secret: bytes, spk_pair: tuple[X25519PrivateKey, X25519PublicKey]) -> "DoubleRatchet":
        """Create the responder's (Bob's) session after X3DH.

        Bob's signed pre-key pair serves as his initial ratchet key pair; both
        chains stay unset until the first message triggers a DH ratchet.
        """
        session = cls()
        session.dh_pair = spk_pair
        session.root_key = shared_secret
        session.send_chain_key = session.recv_chain_key = None
        session.send_n = session.recv_n = session.prev_send_n = 0
        return session

    def encrypt(self, plaintext: bytes) -> dict:
        """Encrypt one message with the next key of the send chain.

        Returns {header: {dh_pub, n, pn}, ciphertext: bytes, nonce: bytes}.

        Raises:
            RuntimeError: If no send chain exists yet.
        """
        if self.send_chain_key is None:
            raise RuntimeError("Send chain not initialized")

        self.send_chain_key, message_key = kdf_ck(self.send_chain_key)
        header = RatchetHeader(
            dh_pub=serialize_x25519_public(self.dh_pair[1]),
            n=self.send_n,
            pn=self.prev_send_n,
        )

        # AES-256-GCM under the message key; the serialized header is the AAD,
        # which binds the ciphertext to this header.
        nonce = os.urandom(12)
        sealed = AESGCM(message_key).encrypt(nonce, plaintext, header.serialize())

        self.send_n += 1
        return {
            "header": header.to_dict(),
            "ciphertext": sealed,  # includes 16-byte tag
            "nonce": nonce,
        }

    def decrypt(self, header_dict: dict, ciphertext: bytes, nonce: bytes) -> bytes:
        """Decrypt one message, performing a DH ratchet step for a new dh_pub.

        Ratchet state is captured before any change and put back if anything
        fails (M9 fix), so a bad message leaves the session unchanged.
        """
        header = RatchetHeader.from_dict(header_dict)
        sender_pub_bytes = header.dh_pub

        # Out-of-order message whose key was stored earlier: no ratchet state changes.
        stash_id = (sender_pub_bytes.hex(), header.n)
        if stash_id in self.skipped:
            stashed_key = self.skipped.pop(stash_id)
            aad = header.serialize()
            cipher = AESGCM(stashed_key)
            try:
                return cipher.decrypt(nonce, ciphertext, aad)
            except Exception:
                self.skipped[stash_id] = stashed_key  # keep the key for a later retry
                raise

        checkpoint = self._snapshot()
        try:
            sender_pub = load_x25519_public(sender_pub_bytes)
            known_remote = serialize_x25519_public(self.dh_remote) if self.dh_remote else None

            if known_remote is None or sender_pub_bytes != known_remote:
                # New DH ratchet step: finish the old receive chain first.
                self._skip_messages(header.pn)
                self._dh_ratchet(sender_pub)

            self._skip_messages(header.n)

            # Next key of the receive chain belongs to this message.
            self.recv_chain_key, message_key = kdf_ck(self.recv_chain_key)
            self.recv_n += 1

            aad = header.serialize()
            return AESGCM(message_key).decrypt(nonce, ciphertext, aad)
        except Exception:
            self._restore(checkpoint)
            raise

    def _snapshot(self) -> dict:
        """Copy the mutable session state so a failed decrypt can be rolled back."""
        snap = {
            name: getattr(self, name)
            for name in (
                "dh_pair",
                "dh_remote",
                "root_key",
                "send_chain_key",
                "recv_chain_key",
                "send_n",
                "recv_n",
                "prev_send_n",
            )
        }
        # Shallow copy: later insertions must not leak into the snapshot.
        snap["skipped"] = dict(self.skipped)
        return snap

    def _restore(self, snap: dict):
        """Put back the state captured by _snapshot."""
        for name in (
            "dh_pair",
            "dh_remote",
            "root_key",
            "send_chain_key",
            "recv_chain_key",
            "send_n",
            "recv_n",
            "prev_send_n",
            "skipped",
        ):
            setattr(self, name, snap[name])

    def _skip_messages(self, until: int):
        """Advance the receive chain to `until`, storing the keys passed over.

        Does nothing when there is no receive chain yet.

        Raises:
            RuntimeError: If more than MAX_SKIP messages would be skipped.
        """
        if self.recv_chain_key is None:
            return
        gap = until - self.recv_n
        if gap > MAX_SKIP:
            raise RuntimeError(f"Too many skipped messages ({gap} > {MAX_SKIP})")
        while self.recv_n < until:
            self.recv_chain_key, skipped_key = kdf_ck(self.recv_chain_key)
            remote_hex = serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else ""
            self.skipped[(remote_hex, self.recv_n)] = skipped_key
            self.recv_n += 1

    def _dh_ratchet(self, remote_dh_pub: X25519PublicKey):
        """DH ratchet step: new receive chain, new own key pair, new send chain."""
        self.prev_send_n, self.send_n, self.recv_n = self.send_n, 0, 0
        self.dh_remote = remote_dh_pub

        # Receive chain: our current private key with the peer's new public key.
        recv_secret = x25519_dh(self.dh_pair[0], self.dh_remote)
        self.root_key, self.recv_chain_key = kdf_rk(self.root_key, recv_secret)

        # Send chain: a brand-new key pair of ours with the same peer key.
        self.dh_pair = generate_x25519_keypair()
        send_secret = x25519_dh(self.dh_pair[0], self.dh_remote)
        self.root_key, self.send_chain_key = kdf_rk(self.root_key, send_secret)

    def export_state(self) -> bytes:
        """Serialize the full session state to JSON bytes for persistent storage.

        Keys are hex-encoded; unset keys become null. Skipped message keys are
        stored under "<dh_pub_hex>:<n>".
        """
        if self.dh_pair:
            dh_priv_hex = serialize_x25519_private(self.dh_pair[0]).hex()
            dh_pub_hex = serialize_x25519_public(self.dh_pair[1]).hex()
        else:
            dh_priv_hex = dh_pub_hex = None
        dh_remote_hex = serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else None

        state = {
            "dh_priv": dh_priv_hex,
            "dh_pub": dh_pub_hex,
            "dh_remote": dh_remote_hex,
            "root_key": self.root_key.hex(),
            "send_ck": self.send_chain_key.hex() if self.send_chain_key else None,
            "recv_ck": self.recv_chain_key.hex() if self.recv_chain_key else None,
            "send_n": self.send_n,
            "recv_n": self.recv_n,
            "prev_send_n": self.prev_send_n,
            "skipped": {
                f"{dh_hex}:{index}": key.hex()
                for (dh_hex, index), key in self.skipped.items()
            },
        }
        return json.dumps(state).encode()

    @classmethod
    def import_state(cls, data: bytes) -> "DoubleRatchet":
        """Rebuild a session from the JSON bytes produced by export_state."""
        state = json.loads(data)
        session = cls()

        if state["dh_priv"] and state["dh_pub"]:
            own_private = load_x25519_private(bytes.fromhex(state["dh_priv"]))
            own_public = load_x25519_public(bytes.fromhex(state["dh_pub"]))
            session.dh_pair = (own_private, own_public)
        if state["dh_remote"]:
            session.dh_remote = load_x25519_public(bytes.fromhex(state["dh_remote"]))

        session.root_key = bytes.fromhex(state["root_key"])
        session.send_chain_key = bytes.fromhex(state["send_ck"]) if state["send_ck"] else None
        session.recv_chain_key = bytes.fromhex(state["recv_ck"]) if state["recv_ck"] else None
        session.send_n = state["send_n"]
        session.recv_n = state["recv_n"]
        session.prev_send_n = state["prev_send_n"]

        session.skipped = {}
        for label, key_hex in state.get("skipped", {}).items():
            # Split on the last ":" so the counter is always the final field.
            parts = label.rsplit(":", 1)
            dh_hex = parts[0]
            index = int(parts[1])
            session.skipped[(dh_hex, index)] = bytes.fromhex(key_hex)
        return session
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sender Keys (group messaging)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class SenderKeyState:
|
|
"""Sender key chain for group messaging.
|
|
|
|
Each sender in a group has their own sender key chain.
|
|
Other group members receive the initial sender_key via pairwise Double Ratchet.
|
|
"""
|
|
|
|
def __init__(self, sender_key: bytes | None = None):
|
|
if sender_key is None:
|
|
sender_key = os.urandom(32)
|
|
self.sender_key = sender_key
|
|
self.chain_id = hashlib.sha256(sender_key).digest()
|
|
self.chain_key = hkdf_derive(sender_key, salt=b"\x00" * 32, info=b"SenderKeyChain", length=32)
|
|
self.n = 0
|
|
# For receivers: track chain state to allow fast-forward
|
|
self._known_keys: dict[int, bytes] = {}
|
|
|
|
def encrypt(self, plaintext: bytes) -> dict:
|
|
"""Encrypt with current chain key.
|
|
|
|
Returns {chain_id: hex, n: int, ciphertext: bytes, nonce: bytes}.
|
|
"""
|
|
self.chain_key, message_key = kdf_ck(self.chain_key)
|
|
nonce = os.urandom(12)
|
|
aesgcm = AESGCM(message_key)
|
|
# AAD includes chain_id and message number
|
|
aad = self.chain_id + struct.pack(">I", self.n)
|
|
ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad)
|
|
result = {
|
|
"chain_id": self.chain_id.hex(),
|
|
"n": self.n,
|
|
"ciphertext": ct_with_tag,
|
|
"nonce": nonce,
|
|
}
|
|
self.n += 1
|
|
return result
|
|
|
|
MAX_SENDER_KEY_SKIP = 256
|
|
|
|
def decrypt(self, chain_id_hex: str, n: int, ciphertext: bytes, nonce: bytes) -> bytes:
|
|
"""Decrypt a group message. Fast-forwards the chain if needed.
|
|
|
|
State is snapshotted before modification and restored on failure (M9 fix).
|
|
"""
|
|
chain_id = bytes.fromhex(chain_id_hex)
|
|
if chain_id != self.chain_id:
|
|
raise ValueError("Chain ID mismatch")
|
|
|
|
if n - self.n > self.MAX_SENDER_KEY_SKIP:
|
|
raise ValueError(f"Sender key skip too large ({n - self.n} > {self.MAX_SENDER_KEY_SKIP})")
|
|
|
|
# Snapshot before fast-forward
|
|
snap_chain_key = self.chain_key
|
|
snap_n = self.n
|
|
snap_known = dict(self._known_keys)
|
|
|
|
try:
|
|
# Fast-forward the chain to reach message n
|
|
while self.n <= n:
|
|
self.chain_key, mk = kdf_ck(self.chain_key)
|
|
self._known_keys[self.n] = mk
|
|
self.n += 1
|
|
|
|
mk = self._known_keys.pop(n, None)
|
|
if mk is None:
|
|
raise ValueError(f"Message key for n={n} not available (already consumed)")
|
|
|
|
aad = chain_id + struct.pack(">I", n)
|
|
aesgcm = AESGCM(mk)
|
|
return aesgcm.decrypt(nonce, ciphertext, aad)
|
|
except Exception:
|
|
self.chain_key = snap_chain_key
|
|
self.n = snap_n
|
|
self._known_keys = snap_known
|
|
raise
|
|
|
|
def export_key(self) -> bytes:
    """Serialize the sender key for distribution to other group members.

    Produces UTF-8 JSON of the form ``{"sender_key": "<hex>"}``, which is
    the payload ``from_key`` reads to build a receiving SenderKeyState.
    """
    payload = {"sender_key": self.sender_key.hex()}
    return json.dumps(payload).encode()
|
|
|
|
def export_state(self) -> bytes:
    """Dump the complete chain state as UTF-8 JSON for persistent storage.

    Byte fields are hex-encoded. Cached message keys are written under
    ``known_keys`` with their message number converted to a string key.
    """
    cached = {}
    for number, key in self._known_keys.items():
        cached[str(number)] = key.hex()
    snapshot = {
        "sender_key": self.sender_key.hex(),
        "chain_id": self.chain_id.hex(),
        "chain_key": self.chain_key.hex(),
        "n": self.n,
        "known_keys": cached,
    }
    return json.dumps(snapshot).encode()
|
|
|
|
@classmethod
def import_state(cls, data: bytes) -> "SenderKeyState":
    """Rebuild a state object from the JSON produced by ``export_state``.

    The instance is created with ``cls.__new__`` so ``__init__`` is not
    run; every field is taken from the serialized data. A missing
    ``known_keys`` entry is treated as an empty cache.
    """
    fields = json.loads(data)
    restored = cls.__new__(cls)
    for attr in ("sender_key", "chain_id", "chain_key"):
        setattr(restored, attr, bytes.fromhex(fields[attr]))
    restored.n = fields["n"]
    restored._known_keys = {
        int(number): bytes.fromhex(key_hex)
        for number, key_hex in fields.get("known_keys", {}).items()
    }
    return restored
|
|
|
|
@classmethod
def from_key(cls, exported_key: bytes) -> "SenderKeyState":
    """Create a receiving SenderKeyState from the output of ``export_key``.

    Parses the JSON payload, hex-decodes its ``sender_key`` field and hands
    the result to the class constructor.
    """
    payload = json.loads(exported_key)
    raw_key = bytes.fromhex(payload["sender_key"])
    return cls(sender_key=raw_key)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Contact Key Verification (Safety Numbers / Fingerprints / QR Codes)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Fingerprint format version; compute_fingerprint() prepends it to the hash
# seed as 2 big-endian bytes.
FINGERPRINT_VERSION = 0
|
|
|
|
|
|
def compute_fingerprint(user_id: str, identity_key_bytes: bytes, iterations: int = 5200) -> bytes:
    """Derive a 32-byte fingerprint binding *user_id* to an identity key.

    Iterated SHA-512 (Signal's NumericFingerprint algorithm):
      seed   = version (2 bytes, big-endian) + identity_key + user_id (UTF-8)
      round  = SHA-512(previous + identity_key), applied *iterations* times
      result = first 32 bytes of the final value

    If *iterations* is 0 (or negative) no hashing takes place and the first
    32 bytes of the seed are returned as they are.
    """
    digest = (
        FINGERPRINT_VERSION.to_bytes(2, "big")
        + identity_key_bytes
        + user_id.encode("utf-8")
    )
    for _ in range(iterations):
        hasher = hashlib.sha512(digest)
        # Re-mixing the key every round is equivalent to hashing digest + key.
        hasher.update(identity_key_bytes)
        digest = hasher.digest()
    return digest[:32]
|
|
|
|
|
|
def format_fingerprint(fp_bytes: bytes) -> str:
    """Render a fingerprint as 30 decimal digits for side-by-side comparison.

    The first 30 bytes are consumed in six 5-byte chunks; each chunk is read
    as a big-endian integer, reduced modulo 100000 and zero-padded to five
    digits. Bytes beyond the first 30 (such as the last two of a 32-byte
    fingerprint) do not influence the output.

    Returns two lines of three space-separated groups.

    Raises:
        ValueError: if fewer than 30 bytes are supplied. Without this check
            the missing chunks would silently render as "00000", so
            truncated or empty fingerprints could appear to match.
    """
    group_count, chunk_size = 6, 5
    needed = group_count * chunk_size
    if len(fp_bytes) < needed:
        raise ValueError(
            f"Fingerprint too short: need at least {needed} bytes, got {len(fp_bytes)}"
        )
    groups = [
        f"{int.from_bytes(fp_bytes[start:start + chunk_size], 'big') % 100000:05d}"
        for start in range(0, needed, chunk_size)
    ]
    return " ".join(groups[:3]) + "\n" + " ".join(groups[3:])
|
|
|
|
|
|
def compute_safety_number(my_uid: str, my_ik_bytes: bytes,
                          their_uid: str, their_ik_bytes: bytes) -> str:
    """Build the 60-digit safety number shared by a pair of users.

    Each side's fingerprint comes from ``compute_fingerprint``; the one that
    belongs to the lower user ID (string comparison) is placed first, so
    both users derive the same number no matter who computes it.

    The first 60 bytes of the concatenation are read in twelve 5-byte
    chunks, each reduced modulo 100000 and zero-padded to five digits. With
    32-byte fingerprints the seventh chunk (bytes 30-34) spans the boundary
    between the two fingerprints and the last four bytes are unused.

    Returns three lines of four space-separated groups.
    """
    own = compute_fingerprint(my_uid, my_ik_bytes)
    peer = compute_fingerprint(their_uid, their_ik_bytes)
    # Order by user ID so the concatenation is identical on both devices.
    first, second = (own, peer) if my_uid < their_uid else (peer, own)
    material = first + second

    digits = [
        f"{int.from_bytes(material[offset:offset + 5], 'big') % 100000:05d}"
        for offset in range(0, 60, 5)
    ]
    rows = (" ".join(digits[row:row + 4]) for row in range(0, 12, 4))
    return "\n".join(rows)
|
|
|
|
|
|
def encode_verification_qr(user_id: str, identity_key_bytes: bytes) -> bytes:
    """Pack a user ID and identity key into the QR verification payload.

    Layout: version byte 0x01, one length byte, the UTF-8 user ID, then the
    identity key bytes appended exactly as given.

    Raises OverflowError if the encoded user ID is longer than 255 bytes,
    because its length has to fit in a single byte.
    """
    encoded_uid = user_id.encode("utf-8")
    header = b"\x01" + len(encoded_uid).to_bytes(1, "big")
    return header + encoded_uid + identity_key_bytes
|
|
|
|
|
|
def decode_verification_qr(data: bytes) -> tuple[str, bytes]:
    """Parse a QR verification payload into ``(user_id, identity_key_bytes)``.

    Expects version byte 0x01, a one-byte user ID length, the UTF-8 user ID
    and a 32-byte identity key. Bytes after the key are ignored.

    Raises ValueError when the payload is shorter than 3 bytes, carries an
    unknown version, or is too short for the declared user ID plus key. A
    user ID that is not valid UTF-8 raises UnicodeDecodeError, which is a
    ValueError subclass.
    """
    if len(data) < 3:
        raise ValueError("QR data too short")
    version, uid_len = data[0], data[1]
    if version != 0x01:
        raise ValueError(f"Unknown QR version: {version}")
    uid_end = 2 + uid_len
    key_end = uid_end + 32
    if len(data) < key_end:
        raise ValueError("QR data truncated")
    return data[2:uid_end].decode("utf-8"), data[uid_end:key_end]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Message Padding (metadata privacy — hide plaintext length)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Marker byte that pad_plaintext() puts in front of every padded message;
# unpad_plaintext() returns input that does not start with it unchanged.
_PAD_MAGIC = b"\x01"
|
|
# Padded message sizes in bytes, in ascending order; pad_plaintext() uses the
# first entry large enough for marker + plaintext + 4-byte length suffix.
_PAD_BUCKETS = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
|
|
|
|
|
|
def pad_plaintext(plaintext: bytes) -> bytes:
    """Pad *plaintext* up to a bucket size so its true length is obscured.

    Layout: 0x01 marker + plaintext + random filler + pad_len (4 bytes,
    big-endian), where pad_len counts the filler plus the 4-byte suffix.
    The 0x01 prefix distinguishes padded messages from legacy unpadded ones
    (which start with '{').

    When marker + plaintext + suffix is larger than the biggest bucket, no
    filler is added and only the 4-byte suffix is appended.
    """
    framed = _PAD_MAGIC + plaintext
    required = len(framed) + 4  # leave room for the trailing length field
    padded_size = required
    for bucket in _PAD_BUCKETS:
        if bucket >= required:
            padded_size = bucket
            break
    total_pad = padded_size - len(framed)
    filler = os.urandom(total_pad - 4)
    return framed + filler + struct.pack(">I", total_pad)
|
|
|
|
|
|
def unpad_plaintext(data: bytes) -> bytes:
    """Strip padding added by pad_plaintext; pass legacy messages through.

    The input is returned unchanged when it is empty, does not start with
    the pad marker, is shorter than 5 bytes, or carries a pad length outside
    the range 4 .. len(data) - 1. Otherwise the marker byte and the trailing
    pad_len bytes are removed.
    """
    if not (data and data[0:1] == _PAD_MAGIC):
        return data  # legacy unpadded message (starts with '{' for JSON)
    size = len(data)
    if size < 5:
        return data  # cannot hold marker plus 4-byte length suffix
    (pad_len,) = struct.unpack(">I", data[-4:])
    if not 4 <= pad_len <= size - 1:
        return data  # implausible pad length, treat as legacy
    return data[1:size - pad_len]
|