"""Shared network layer and ChatClient class for CLI and GUI clients. Uses X3DH + Double Ratchet for message encryption, Sender Keys for groups. RSA retained for login challenge-response only. """ import asyncio import collections import hashlib import json import logging import os import random import ssl import time import uuid from datetime import datetime, timezone from pathlib import Path from dotenv import load_dotenv load_dotenv() from crypto_utils import ( # RSA (login only) generate_rsa_keypair, serialize_private_key, serialize_public_key, load_private_key, load_public_key, rsa_sign, # Ed25519 generate_identity_keypair, serialize_ed25519_private, serialize_ed25519_private_raw, serialize_ed25519_public, load_ed25519_private, load_ed25519_public, ed25519_sign, # X25519 generate_x25519_keypair, serialize_x25519_private, serialize_x25519_public, load_x25519_private, load_x25519_public, x25519_dh, derive_pairing_shared_key, # X3DH generate_signed_prekey, generate_one_time_prekeys, x3dh_initiate, x3dh_respond, # Double Ratchet DoubleRatchet, # Sender Keys SenderKeyState, # AES aes_encrypt, aes_decrypt, # Self-encryption derive_self_encryption_key, # Local storage encryption derive_local_storage_key, # Contact verification compute_fingerprint, compute_pairing_fingerprint, encode_pairing_qr, format_fingerprint, normalize_pairing_fingerprint, compute_safety_number, encode_verification_qr, decode_verification_qr, # Message padding pad_plaintext, unpad_plaintext, ) from protocol import ( VERSION, ProtocolReader, ProtocolWriter, encode_binary, decode_binary, build_request, MAX_MESSAGE_BYTES, MAX_IMAGE_BYTES, IMAGE_CHUNK_SIZE, ) KEY_DIR = Path.home() / ".encrypted_chat" OPK_REPLENISH_THRESHOLD = 20 OPK_BATCH_SIZE = 50 SPK_ROTATION_DAYS = 7 PAIRING_REENCRYPT_INITIAL_DELAY_RANGE = (20.0, 75.0) PAIRING_REENCRYPT_INTER_BATCH_DELAY_RANGE = (1.0, 3.0) PAIRING_REENCRYPT_INTER_FETCH_DELAY_RANGE = (0.15, 0.5) PAIRING_REENCRYPT_BATCH_SIZE = 500 def _encrypt_local(data: bytes, key: bytes) -> bytes: """Encrypt data with AES-256-GCM for local storage. Format: nonce(12) + tag(16) + ciphertext.""" _, nonce, ct, tag = aes_encrypt(data, key=key) return nonce + tag + ct def _decrypt_local(raw: bytes, key: bytes) -> bytes: """Decrypt data encrypted by _encrypt_local.""" nonce, tag, ct = raw[:12], raw[12:28], raw[28:] return aes_decrypt(key, nonce, ct, tag) def get_key_dir(email: str) -> Path: d = KEY_DIR / email d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) return d # --------------------------------------------------------------------------- # RSA key storage (login only — unchanged interface) # --------------------------------------------------------------------------- def save_keys(email: str, private_key, public_key, password: bytes | None = None): d = get_key_dir(email) (d / "private.pem").write_bytes(serialize_private_key(private_key, password=password)) (d / "public.pem").write_bytes(serialize_public_key(public_key)) os.chmod(d / "private.pem", 0o600) def load_keys(email: str, password: bytes | None = None): d = get_key_dir(email) priv_path = d / "private.pem" pub_path = d / "public.pem" if not priv_path.exists(): return None, None, "No local keys found." pem = priv_path.read_bytes() try: private_key = load_private_key(pem, password=password) except Exception: try: private_key = load_private_key(pem, password=None) if password: save_keys(email, private_key, load_public_key(pub_path.read_bytes()), password=password) except Exception: return None, None, "Invalid or missing password." public_key = load_public_key(pub_path.read_bytes()) return private_key, public_key, None # --------------------------------------------------------------------------- # Identity + prekey storage # --------------------------------------------------------------------------- def _save_identity_keys(email: str, ed_priv, ed_pub, password: bytes | None = None): d = get_key_dir(email) if password: (d / "identity_private.bin").write_bytes(serialize_ed25519_private(ed_priv, password=password)) else: (d / "identity_private.bin").write_bytes(serialize_ed25519_private_raw(ed_priv)) (d / "identity_public.bin").write_bytes(serialize_ed25519_public(ed_pub)) os.chmod(d / "identity_private.bin", 0o600) def _load_identity_keys(email: str, password: bytes | None = None): d = get_key_dir(email) priv_path = d / "identity_private.bin" pub_path = d / "identity_public.bin" if not priv_path.exists(): return None, None priv = load_ed25519_private(priv_path.read_bytes(), password=password) pub = load_ed25519_public(pub_path.read_bytes()) return priv, pub def _save_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None): d = get_key_dir(email) raw = serialize_x25519_private(spk_priv) data = _encrypt_local(raw, local_key) if local_key else raw (d / "spk_private.bin").write_bytes(data) (d / "spk_id.txt").write_text(spk_id) os.chmod(d / "spk_private.bin", 0o600) def _load_spk(email: str, local_key: bytes | None = None): d = get_key_dir(email) priv_path = d / "spk_private.bin" id_path = d / "spk_id.txt" if not priv_path.exists(): return None, None raw = priv_path.read_bytes() if local_key: try: raw = _decrypt_local(raw, local_key) except Exception: # Plaintext fallback (migration) — re-save encrypted pass priv = load_x25519_private(raw) spk_id = id_path.read_text().strip() if id_path.exists() else "" if local_key: _save_spk(email, priv, spk_id, local_key) return priv, spk_id def _save_prev_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None): """Save previous SPK for grace period (in-flight X3DH may reference old SPK).""" d = get_key_dir(email) raw = serialize_x25519_private(spk_priv) data = _encrypt_local(raw, local_key) if local_key else raw (d / "prev_spk_private.bin").write_bytes(data) (d / "prev_spk_id.txt").write_text(spk_id) os.chmod(d / "prev_spk_private.bin", 0o600) def _load_prev_spk(email: str, local_key: bytes | None = None): """Load previous SPK (grace period). Returns (private_key, spk_id) or (None, None).""" d = get_key_dir(email) priv_path = d / "prev_spk_private.bin" id_path = d / "prev_spk_id.txt" if not priv_path.exists(): return None, None raw = priv_path.read_bytes() if local_key: try: raw = _decrypt_local(raw, local_key) except Exception: pass priv = load_x25519_private(raw) spk_id = id_path.read_text().strip() if id_path.exists() else "" if local_key: _save_prev_spk(email, priv, spk_id, local_key) return priv, spk_id def _save_opk_private(email: str, opk_id: str, opk_priv, local_key: bytes | None = None): d = get_key_dir(email) / "opk_private" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) raw = serialize_x25519_private(opk_priv) data = _encrypt_local(raw, local_key) if local_key else raw (d / f"{opk_id}.bin").write_bytes(data) os.chmod(d / f"{opk_id}.bin", 0o600) def _load_opk_private(email: str, opk_id: str, local_key: bytes | None = None): d = get_key_dir(email) / "opk_private" p = d / f"{opk_id}.bin" if not p.exists(): return None raw = p.read_bytes() if local_key: try: raw = _decrypt_local(raw, local_key) except Exception: pass priv = load_x25519_private(raw) # Migration: re-save encrypted if local_key provided if local_key: _save_opk_private(email, opk_id, priv, local_key) return priv def _secure_delete(p: Path): """Overwrite file with random data before deletion (anti-forensic wipe).""" try: if not p.exists(): return size = p.stat().st_size if size > 0: with open(p, "r+b") as f: f.write(os.urandom(size)) f.flush() os.fsync(f.fileno()) p.unlink() except Exception: try: p.unlink(missing_ok=True) except Exception: pass def _delete_opk_private(email: str, opk_id: str): d = get_key_dir(email) / "opk_private" p = d / f"{opk_id}.bin" _secure_delete(p) def _save_device_id(email: str, device_id: str): d = get_key_dir(email) p = d / "device_id.txt" p.write_text(device_id) os.chmod(p, 0o600) def _load_device_id(email: str) -> str | None: d = get_key_dir(email) p = d / "device_id.txt" if not p.exists(): return None return p.read_text().strip() or None # ------------------------------------------------------------------ # Expiring LRU cache for user info (max_size + TTL eviction) # ------------------------------------------------------------------ class _ExpiringLRUCache: """Dict-like cache with max size (LRU eviction) and per-entry TTL.""" def __init__(self, max_size: int = 10_000, ttl: float = 3600.0): self._max_size = max_size self._ttl = ttl self._data: collections.OrderedDict = collections.OrderedDict() # key -> (value, ts) def get(self, key, default=None): entry = self._data.get(key) if entry is None: return default value, ts = entry if time.monotonic() - ts > self._ttl: del self._data[key] return default # Move to end (most recently used) self._data.move_to_end(key) return value def __setitem__(self, key, value): if key in self._data: self._data.move_to_end(key) self._data[key] = (value, time.monotonic()) while len(self._data) > self._max_size: self._data.popitem(last=False) def __getitem__(self, key): result = self.get(key) if result is None and key not in self._data: raise KeyError(key) return result def __contains__(self, key): return self.get(key) is not None # ------------------------------------------------------------------ # Identity key change exception (TOFU hard-fail) # ------------------------------------------------------------------ class IdentityKeyChanged(Exception): """Raised when a peer's identity key has changed (TOFU violation). Session creation is blocked until the user explicitly accepts the new key. """ def __init__(self, user_id: str, new_key_bytes: bytes, status: str): self.user_id = user_id self.new_key_bytes = new_key_bytes self.status = status # "changed" or "changed_verified" super().__init__( f"Identity key changed for user {user_id} (status={status}). " f"Accept the new key before communicating." ) # ------------------------------------------------------------------ # Client-side brute-force lockout # ------------------------------------------------------------------ _LOCKOUT_BASE_SECONDS = 2 _LOCKOUT_MAX_SECONDS = 300 # 5 min cap def _get_lockout_path(email: str) -> Path: return get_key_dir(email) / "login_lockout.json" def _check_lockout(email: str) -> float: """Return seconds remaining until next attempt allowed. 0 = can try now.""" p = _get_lockout_path(email) if not p.exists(): return 0.0 try: data = json.loads(p.read_text()) locked_until = data.get("locked_until", 0.0) remaining = locked_until - time.time() return max(0.0, remaining) except Exception: return 0.0 def _record_failed_attempt(email: str): """Increment failed counter, update locked_until.""" p = _get_lockout_path(email) failed = 0 try: if p.exists(): data = json.loads(p.read_text()) failed = data.get("failed_attempts", 0) except Exception: pass failed += 1 delay = min(_LOCKOUT_BASE_SECONDS ** failed, _LOCKOUT_MAX_SECONDS) locked_until = time.time() + delay p.write_text(json.dumps({"failed_attempts": failed, "locked_until": locked_until})) os.chmod(p, 0o600) def _clear_lockout(email: str): """Reset on successful login.""" p = _get_lockout_path(email) if p.exists(): try: p.unlink() except Exception: pass def _save_session(email: str, peer_user_id: str, ratchet: DoubleRatchet, local_key: bytes | None = None, peer_device_id: str | None = None): d = get_key_dir(email) / "sessions" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) if peer_device_id: filename = f"{peer_user_id}_{peer_device_id}.bin" else: filename = f"{peer_user_id}.bin" p = d / filename data = ratchet.export_state() if local_key: data = _encrypt_local(data, local_key) p.write_bytes(data) os.chmod(p, 0o600) def _load_session(email: str, peer_user_id: str, local_key: bytes | None = None, peer_device_id: str | None = None) -> DoubleRatchet | None: d = get_key_dir(email) / "sessions" if peer_device_id: p = d / f"{peer_user_id}_{peer_device_id}.bin" if not p.exists(): # Fallback: try old format (no device_id) and migrate p_old = d / f"{peer_user_id}.bin" if p_old.exists(): ratchet = _load_session_file(p_old, local_key) if ratchet: _save_session(email, peer_user_id, ratchet, local_key, peer_device_id=peer_device_id) _secure_delete(p_old) return ratchet return None else: p = d / f"{peer_user_id}.bin" if not p.exists(): return None return _load_session_file(p, local_key) def _load_session_file(p: Path, local_key: bytes | None = None) -> DoubleRatchet | None: """Load a session from a specific file path.""" if not p.exists(): return None raw = p.read_bytes() if local_key: try: data = _decrypt_local(raw, local_key) except Exception: # Migration: try loading as plaintext (old unencrypted format) try: ratchet = DoubleRatchet.import_state(raw) return ratchet except Exception: return None return DoubleRatchet.import_state(data) return DoubleRatchet.import_state(raw) def _delete_session_file(email: str, peer_user_id: str, peer_device_id: str | None = None): """Securely delete a session file from disk (for session reset).""" d = get_key_dir(email) / "sessions" if peer_device_id: p = d / f"{peer_user_id}_{peer_device_id}.bin" else: p = d / f"{peer_user_id}.bin" _secure_delete(p) def _save_sender_key_state(email: str, conv_id: str, state: SenderKeyState, local_key: bytes | None = None): d = get_key_dir(email) / "sender_keys" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) p = d / f"{conv_id}.bin" data = state.export_state() if local_key: data = _encrypt_local(data, local_key) p.write_bytes(data) os.chmod(p, 0o600) def _load_sender_key_state(email: str, conv_id: str, local_key: bytes | None = None) -> SenderKeyState | None: d = get_key_dir(email) / "sender_keys" p = d / f"{conv_id}.bin" if not p.exists(): return None raw = p.read_bytes() if local_key: try: data = _decrypt_local(raw, local_key) except Exception: try: sk = SenderKeyState.import_state(raw) _save_sender_key_state(email, conv_id, sk, local_key) return sk except Exception: return None return SenderKeyState.import_state(data) return SenderKeyState.import_state(raw) def _save_sender_key_recipients(email: str, conv_id: str, recipients: set[str], local_key: bytes | None = None): d = get_key_dir(email) / "sender_keys" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) p = d / f"{conv_id}.members.bin" data = json.dumps(sorted(recipients), ensure_ascii=False).encode("utf-8") if local_key: data = _encrypt_local(data, local_key) p.write_bytes(data) os.chmod(p, 0o600) def _load_sender_key_recipients(email: str, conv_id: str, local_key: bytes | None = None) -> set[str]: d = get_key_dir(email) / "sender_keys" p = d / f"{conv_id}.members.bin" if not p.exists(): return set() raw = p.read_bytes() if local_key: try: data = _decrypt_local(raw, local_key) except Exception: # Migration: previous plaintext storage, re-save encrypted on success. try: parsed = json.loads(raw.decode("utf-8")) recipients = {str(uid) for uid in parsed if isinstance(uid, str)} _save_sender_key_recipients(email, conv_id, recipients, local_key) return recipients except Exception: return set() else: data = raw try: parsed = json.loads(data.decode("utf-8")) if not isinstance(parsed, list): return set() return {str(uid) for uid in parsed if isinstance(uid, str)} except Exception: return set() def _save_recv_sender_key(email: str, conv_id: str, sender_id: str, state: SenderKeyState, local_key: bytes | None = None, sender_device_id: str | None = None): d = get_key_dir(email) / "sender_keys_recv" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) if sender_device_id: filename = f"{conv_id}_{sender_id}_{sender_device_id}.bin" else: filename = f"{conv_id}_{sender_id}.bin" p = d / filename data = state.export_state() if local_key: data = _encrypt_local(data, local_key) p.write_bytes(data) os.chmod(p, 0o600) def _load_recv_sender_key(email: str, conv_id: str, sender_id: str, local_key: bytes | None = None, sender_device_id: str | None = None) -> SenderKeyState | None: d = get_key_dir(email) / "sender_keys_recv" if sender_device_id: p = d / f"{conv_id}_{sender_id}_{sender_device_id}.bin" if not p.exists(): # Fallback: try old format and migrate p_old = d / f"{conv_id}_{sender_id}.bin" if p_old.exists(): sk = _load_recv_sender_key_file(p_old, local_key) if sk: _save_recv_sender_key(email, conv_id, sender_id, sk, local_key, sender_device_id=sender_device_id) _secure_delete(p_old) return sk return None else: p = d / f"{conv_id}_{sender_id}.bin" if not p.exists(): return None return _load_recv_sender_key_file(p, local_key) def _load_recv_sender_key_file(p: Path, local_key: bytes | None = None) -> SenderKeyState | None: """Load a recv sender key from a specific file path.""" if not p.exists(): return None raw = p.read_bytes() if local_key: try: data = _decrypt_local(raw, local_key) except Exception: try: sk = SenderKeyState.import_state(raw) return sk except Exception: return None return SenderKeyState.import_state(data) return SenderKeyState.import_state(raw) # --------------------------------------------------------------------------- # Local decrypted message cache (Double Ratchet keys are one-time use) # --------------------------------------------------------------------------- def _load_message_cache(email: str, conv_id: str, cache_key: bytes | None = None) -> dict: d = get_key_dir(email) / "message_cache" p_bin = d / f"{conv_id}.bin" p_json = d / f"{conv_id}.json" # Migration: if old plaintext .json exists but encrypted .bin doesn't if p_json.exists() and not p_bin.exists(): try: cache = json.loads(p_json.read_text("utf-8")) if cache_key: _save_message_cache_full(d, conv_id, cache, cache_key) _secure_delete(p_json) return cache except Exception: return {} if not p_bin.exists(): return {} if not cache_key: return {} try: raw = p_bin.read_bytes() # Format: nonce (12) + tag (16) + ciphertext nonce = raw[:12] tag = raw[12:28] ct = raw[28:] plaintext = aes_decrypt(cache_key, nonce, ct, tag) return json.loads(plaintext.decode("utf-8")) except Exception: return {} def _save_message_cache_full(d: Path, conv_id: str, cache: dict, cache_key: bytes): """Write the full cache dict encrypted to disk.""" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) p = d / f"{conv_id}.bin" plaintext = json.dumps(cache, ensure_ascii=False).encode("utf-8") _key, nonce, ct, tag = aes_encrypt(plaintext, key=cache_key) p.write_bytes(nonce + tag + ct) os.chmod(p, 0o600) def _save_message_to_cache(email: str, conv_id: str, message_id: str, payload: dict, cache_key: bytes | None = None): d = get_key_dir(email) / "message_cache" cache = _load_message_cache(email, conv_id, cache_key) cache[message_id] = payload if cache_key: _save_message_cache_full(d, conv_id, cache, cache_key) else: # Fallback: plaintext (no identity key available yet) d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) p = d / f"{conv_id}.json" p.write_text(json.dumps(cache, ensure_ascii=False), "utf-8") os.chmod(p, 0o600) # --------------------------------------------------------------------------- # Verification storage (TOFU + explicit verification) # --------------------------------------------------------------------------- def _save_known_identity_keys(email: str, keys: dict, local_key: bytes | None = None): """Save TOFU identity key registry (encrypted with local_key).""" p = get_key_dir(email) / "known_identity_keys.bin" data = json.dumps({"version": 1, "keys": keys}).encode("utf-8") if local_key: data = _encrypt_local(data, local_key) p.write_bytes(data) os.chmod(p, 0o600) def _load_known_identity_keys(email: str, local_key: bytes | None = None) -> dict: """Load TOFU identity key registry. Returns empty dict on error. No plaintext fallback — these files were never stored unencrypted (feature introduced after local encryption was implemented). Accepting plaintext would allow an attacker with disk access to inject fake identity keys and bypass TOFU warnings. """ p = get_key_dir(email) / "known_identity_keys.bin" if not p.exists(): return {} raw = p.read_bytes() try: if local_key: data = _decrypt_local(raw, local_key) else: data = raw obj = json.loads(data) return obj.get("keys", {}) except Exception: return {} def _save_verified_contacts(email: str, contacts: dict, local_key: bytes | None = None): """Save explicit verification state (encrypted with local_key).""" p = get_key_dir(email) / "verified_contacts.bin" data = json.dumps({"version": 1, "contacts": contacts}).encode("utf-8") if local_key: data = _encrypt_local(data, local_key) p.write_bytes(data) os.chmod(p, 0o600) def _load_verified_contacts(email: str, local_key: bytes | None = None) -> dict: """Load explicit verification state. Returns empty dict on error. No plaintext fallback — these files were never stored unencrypted. Accepting plaintext would allow an attacker with disk access to inject fake verification records (mark attacker as "verified"). """ p = get_key_dir(email) / "verified_contacts.bin" if not p.exists(): return {} raw = p.read_bytes() try: if local_key: data = _decrypt_local(raw, local_key) else: data = raw obj = json.loads(data) return obj.get("contacts", {}) except Exception: return {} # How many sync cycles a message that fails to decrypt is retried before it # is recorded as permanently failed and the sync watermark moves past it. _MAX_DECRYPT_RETRIES = 3 def _solve_pow(challenge: str, difficulty: int) -> str: """Solve a proof-of-work challenge by finding a nonce with enough leading zero bits.""" target_bytes = difficulty // 8 target_bits = difficulty % 8 mask = (0xFF << (8 - target_bits)) & 0xFF if target_bits else 0 nonce = 0 while True: digest = hashlib.sha256(f"{challenge}{nonce}".encode()).digest() # Fast path: check full zero bytes first ok = True for i in range(target_bytes): if digest[i] != 0: ok = False break if ok and target_bits: if digest[target_bytes] & mask: ok = False if ok: return str(nonce) nonce += 1 class ChatClient: def __init__(self): self.reader: ProtocolReader | None = None self.writer: ProtocolWriter | None = None self.raw_writer: asyncio.StreamWriter | None = None self.session: dict | None = None self.private_key = None # RSA private key (login only) self.public_key = None # RSA public key (login only) self.username: str = "" self.email: str = "" self._listener_task: asyncio.Task | None = None self._response_queue: asyncio.Queue = asyncio.Queue() self._notification_queue: asyncio.Queue = asyncio.Queue() self._pending: dict[str, asyncio.Future] = {} self._pairing_temp_private_key = None self._pairing_fingerprint: str = "" self._pairing_code: str = "" self._reencrypt_progress_cb = None self._logger = logging.getLogger("encrypted_chat.client") # Signal Protocol keys self.identity_private = None # Ed25519PrivateKey self.identity_public = None # Ed25519PublicKey self.spk_private = None # X25519PrivateKey (current signed prekey) self.spk_id: str = "" self._prev_spk_private = None # Previous SPK for grace period (M4) self._prev_spk_id: str = "" self.opk_privates: dict[str, object] = {} # id -> X25519PrivateKey self.sessions: dict[str, DoubleRatchet] = {} # "user_id:device_id" -> ratchet self.sender_key_states: dict[str, SenderKeyState] = {} # conv_id -> own sender key self.recv_sender_keys: dict[str, SenderKeyState] = {} # "conv_id:sender_id:device_id" -> their key # Cache: user_id -> {identity_key (Ed25519PublicKey), username, email} # Bounded to 10K entries with 1-hour TTL to prevent unbounded growth (L7) self._user_cache: _ExpiringLRUCache = _ExpiringLRUCache(max_size=10_000, ttl=3600.0) self.connected: bool = False self.login_rejected: bool = False self._cache_key: bytes | None = None # AES key for encrypting message cache on disk self._local_key: bytes | None = None # AES key for encrypting session/sender key files # Multi-device support self.device_id: str | None = None # This device's UUID self._device_bundle_cache: dict[str, tuple[float, list[dict]]] = {} # user_id -> (ts, bundles) # Queue of received messages to self-encrypt for multi-device access self._pending_self_encrypt: list[dict] = [] self._typing_active: dict[str, bool] = {} self._typing_last_sent: dict[str, float] = {} self._typing_stop_tasks: dict[str, asyncio.Task] = {} # Contact key verification (TOFU + explicit) self._known_identity_keys: dict = {} # user_id -> {identity_key hex, first_seen, last_seen} self._verified_contacts: dict = {} # user_id -> {identity_key hex, verified_at, method} self._key_change_cb = None # callback(user_id, username, old_key_hex, was_verified) async def connect(self): host = os.getenv("SERVER_HOST", "127.0.0.1") port = int(os.getenv("SERVER_PORT", "9999")) tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes") tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes") ssl_context = None if tls_required and not tls_enabled: raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.") if tls_enabled: insecure = os.getenv("TLS_INSECURE", "false").lower() in ("1", "true", "yes") is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") if insecure and not is_dev: raise RuntimeError("TLS_INSECURE is only allowed when ENVIRONMENT=dev") ssl_context = ssl.create_default_context() ca_file = os.getenv("TLS_CA_FILE", "").strip() if ca_file: ssl_context.load_verify_locations(cafile=ca_file) elif insecure: ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE else: self._logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.") r, w = await asyncio.open_connection(host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context) # Enable TCP keepalive to detect dead connections through NAT/firewalls sock = w.get_extra_info("socket") if sock is not None: import socket try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if hasattr(socket, "TCP_KEEPIDLE"): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 25) if hasattr(socket, "TCP_KEEPINTVL"): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10) if hasattr(socket, "TCP_KEEPCNT"): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) except OSError: pass self.reader = ProtocolReader(r) self.writer = ProtocolWriter(w) self.raw_writer = w self.connected = True self._logger.info("Connected to %s:%s (tls=%s)", host, port, "on" if tls_enabled else "off") def server_endpoint(self) -> str: host = os.getenv("SERVER_HOST", "127.0.0.1") port = os.getenv("SERVER_PORT", "9999") return f"{host}:{port}" def pairing_fingerprint(self) -> str: return self._pairing_fingerprint def pairing_qr_data(self) -> bytes | None: if not self._pairing_code or not self._pairing_fingerprint: return None return encode_pairing_qr(self._pairing_code, self._pairing_fingerprint) async def _background_listener(self): """Read messages from server, routing responses vs notifications.""" while True: msg = await self.reader.read_message() if msg is None: self.connected = False # Fail all pending futures so send_and_recv doesn't hang pending = dict(self._pending) self._pending.clear() err = ConnectionError("Server connection lost") for obj in pending.values(): if isinstance(obj, asyncio.Queue): # Signal stream consumers that connection died obj.put_nowait({"status": "error", "data": {"message": "Connection lost"}}) elif not obj.done(): obj.set_exception(err) break # Responses to our own requests (have request_id matching a pending future) # must be routed to the pending future, even if the type matches a notification name. req_id = msg.get("request_id") if req_id and req_id in self._pending: pending_obj = self._pending[req_id] if isinstance(pending_obj, asyncio.Queue): await pending_obj.put(msg) else: self._pending.pop(req_id) if not pending_obj.done(): pending_obj.set_result(msg) elif msg.get("type") in ("new_message", "messages_read", "message_deleted", "conversation_created", "member_added", "member_removed", "user_online", "user_offline", "online_users", "group_invitation", "conversation_renamed", "device_added", "session_reset", "message_reacted", "message_pinned", "message_unpinned", "message_delivered", "username_changed", "avatar_changed", "keys_updated", "typing_start", "typing_stop"): await self._notification_queue.put(msg) else: await self._response_queue.put(msg) async def send_and_recv(self, msg_type: str, timeout: float = 30.0, **kwargs) -> dict: try: request_id = str(uuid.uuid4()) loop = asyncio.get_running_loop() fut = loop.create_future() self._pending[request_id] = fut await self.writer.send_request(msg_type, request_id=request_id, **kwargs) except (ValueError, ConnectionError, OSError) as e: self._pending.pop(request_id, None) return { "type": msg_type, "status": "error", "data": {"message": str(e) or "Connection lost."}, } try: return await asyncio.wait_for(fut, timeout=timeout) except asyncio.TimeoutError: self._logger.warning("send_and_recv timeout for '%s' after %.0fs", msg_type, timeout) return { "type": msg_type, "status": "error", "data": {"message": f"Request timed out ({msg_type})"}, } except ConnectionError: return { "type": msg_type, "status": "error", "data": {"message": "Connection lost."}, } finally: self._pending.pop(request_id, None) # ------------------------------------------------------------------ # User info / identity key cache # ------------------------------------------------------------------ async def _get_user_info(self, user_id: str = "", email: str = "") -> dict | None: """Get user info from server, cache identity key. Performs TOFU check.""" cached = self._user_cache.get(user_id) if cached: return cached kwargs = {} if user_id: kwargs["user_id"] = user_id elif email: kwargs["email"] = email else: return None resp = await self.send_and_recv("get_user_info", **kwargs) if resp["status"] != "ok": return None data = resp["data"] ik_bytes = decode_binary(data["identity_key"]) if data.get("identity_key") else None info = { "user_id": data["user_id"], "username": data["username"], "email": data["email"], "identity_key": load_ed25519_public(ik_bytes) if ik_bytes else None, "identity_key_bytes": ik_bytes, } # TOFU: check identity key against known keys if ik_bytes: status = self.check_identity_key(data["user_id"], ik_bytes) info["identity_key_status"] = status self._user_cache[data["user_id"]] = info return info # ------------------------------------------------------------------ # Contact Key Verification # ------------------------------------------------------------------ def _load_verification_stores(self): """Load TOFU and verification stores from disk.""" if not self.email: return self._known_identity_keys = _load_known_identity_keys(self.email, self._local_key) self._verified_contacts = _load_verified_contacts(self.email, self._local_key) def check_identity_key(self, user_id: str, identity_key_bytes: bytes) -> str: """Check a user's identity key against TOFU registry. Returns: "new" — first contact, key recorded (TOFU) "trusted" — key matches previously seen, not explicitly verified "verified" — key matches and explicitly verified "changed" — key differs from recorded (WARNING) "changed_verified" — key changed AND was previously verified (CRITICAL) """ ik_hex = identity_key_bytes.hex() now = datetime.now(timezone.utc).isoformat() known = self._known_identity_keys.get(user_id) if known is None: # First time seeing this user — TOFU: trust on first use self._known_identity_keys[user_id] = { "identity_key": ik_hex, "first_seen": now, "last_seen": now, } if self.email: _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) return "new" if known["identity_key"] == ik_hex: # Key matches — update last_seen known["last_seen"] = now if self.email: _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) # Check if explicitly verified verified = self._verified_contacts.get(user_id) if verified and verified.get("identity_key") == ik_hex: return "verified" return "trusted" # Key has CHANGED was_verified = user_id in self._verified_contacts old_key_hex = known["identity_key"] # Invoke callback for GUI/CLI warning if self._key_change_cb: username = "" cached = self._user_cache.get(user_id) if cached: username = cached.get("username", "") try: self._key_change_cb(user_id, username, old_key_hex, was_verified, identity_key_bytes) except Exception: pass return "changed_verified" if was_verified else "changed" def verify_contact(self, user_id: str, identity_key_bytes: bytes, method: str = "manual"): """Mark a contact's identity key as explicitly verified.""" ik_hex = identity_key_bytes.hex() now = datetime.now(timezone.utc).isoformat() self._verified_contacts[user_id] = { "identity_key": ik_hex, "verified_at": now, "method": method, } # Also ensure TOFU registry is up to date if user_id not in self._known_identity_keys: self._known_identity_keys[user_id] = { "identity_key": ik_hex, "first_seen": now, "last_seen": now, } else: self._known_identity_keys[user_id]["last_seen"] = now if self.email: _save_verified_contacts(self.email, self._verified_contacts, self._local_key) _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) # Update user cache status cached = self._user_cache.get(user_id) if cached: cached["identity_key_status"] = "verified" def unverify_contact(self, user_id: str): """Remove explicit verification for a contact.""" self._verified_contacts.pop(user_id, None) if self.email: _save_verified_contacts(self.email, self._verified_contacts, self._local_key) cached = self._user_cache.get(user_id) if cached and cached.get("identity_key_status") == "verified": cached["identity_key_status"] = "trusted" def accept_key_change(self, user_id: str, new_ik_bytes: bytes): """Accept a changed identity key — update TOFU, remove old verification.""" ik_hex = new_ik_bytes.hex() now = datetime.now(timezone.utc).isoformat() self._known_identity_keys[user_id] = { "identity_key": ik_hex, "first_seen": now, "last_seen": now, } # Remove old verification — user must re-verify self._verified_contacts.pop(user_id, None) if self.email: _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) _save_verified_contacts(self.email, self._verified_contacts, self._local_key) # Update cache cached = self._user_cache.get(user_id) if cached: cached["identity_key_status"] = "trusted" def get_verification_status(self, user_id: str) -> str: """Get verification status for a user. Returns: "verified", "trusted", or "unverified". """ verified = self._verified_contacts.get(user_id) if verified: # Check key still matches known = self._known_identity_keys.get(user_id) if known and known.get("identity_key") == verified.get("identity_key"): return "verified" if user_id in self._known_identity_keys: return "trusted" return "unverified" def get_safety_number(self, peer_user_id: str) -> str | None: """Get formatted safety number for a peer (requires both identity keys).""" if not self.identity_public or not self.session: return None my_uid = self.session.get("user_id", "") my_ik_bytes = serialize_ed25519_public(self.identity_public) cached = self._user_cache.get(peer_user_id) if not cached or not cached.get("identity_key_bytes"): return None return compute_safety_number(my_uid, my_ik_bytes, peer_user_id, cached["identity_key_bytes"]) def get_my_fingerprint(self) -> str | None: """Get formatted fingerprint for own identity key.""" if not self.identity_public or not self.session: return None my_uid = self.session.get("user_id", "") my_ik_bytes = serialize_ed25519_public(self.identity_public) fp = compute_fingerprint(my_uid, my_ik_bytes) return format_fingerprint(fp) def get_peer_fingerprint(self, peer_user_id: str) -> str | None: """Get formatted fingerprint for a peer's identity key.""" cached = self._user_cache.get(peer_user_id) if not cached or not cached.get("identity_key_bytes"): return None fp = compute_fingerprint(peer_user_id, cached["identity_key_bytes"]) return format_fingerprint(fp) def get_verification_qr_data(self) -> bytes | None: """Get QR code payload bytes for own identity (for peer to scan).""" if not self.identity_public or not self.session: return None my_uid = self.session.get("user_id", "") my_ik_bytes = serialize_ed25519_public(self.identity_public) return encode_verification_qr(my_uid, my_ik_bytes) def verify_qr_code(self, qr_data: bytes) -> tuple[bool, str, str]: """Verify a scanned QR code against known identity keys. Returns (success, user_id, message). """ try: user_id, ik_bytes = decode_verification_qr(qr_data) except ValueError as e: return False, "", f"Invalid QR code: {e}" cached = self._user_cache.get(user_id) if not cached: return False, user_id, "Unknown user — not in your contacts." if not cached.get("identity_key_bytes"): return False, user_id, "No identity key on record for this user." if cached["identity_key_bytes"] != ik_bytes: return False, user_id, "Identity key MISMATCH — verification failed!" # Keys match — mark as verified self.verify_contact(user_id, ik_bytes, method="qr_code") username = cached.get("username", user_id[:8]) return True, user_id, f"Verified {username} via QR code." # ------------------------------------------------------------------ # Registration # ------------------------------------------------------------------ async def register(self, username: str, password: str, email: str) -> tuple[bool, str]: """Register user. Generates RSA + Ed25519 in memory (saved to disk only after server confirms registration via confirm_registration).""" self.username = username self.email = email pwd_bytes = bytearray(password.encode("utf-8")) if password else None try: pwd = bytes(pwd_bytes) if pwd_bytes else None # Try loading existing keys (previous successful registration) priv, pub, err = load_keys(email, password=pwd) if priv is None: priv, pub = generate_rsa_keypair() self.private_key = priv self.public_key = pub try: ed_priv, ed_pub = _load_identity_keys(email, password=pwd) except Exception: ed_priv, ed_pub = None, None if ed_priv is None: ed_priv, ed_pub = generate_identity_keypair() self.identity_private = ed_priv self.identity_public = ed_pub self._cache_key = derive_self_encryption_key(ed_priv) self._local_key = derive_local_storage_key(ed_priv) # Store password for saving keys after confirm self._reg_password = pwd finally: if pwd_bytes: pwd_bytes[:] = b'\x00' * len(pwd_bytes) pub_pem = serialize_public_key(pub).decode("utf-8") ik_b64 = encode_binary(serialize_ed25519_public(ed_pub)) extra_fields: dict = {} start = await self.send_and_recv( "register", username=username, public_key=pub_pem, email=email, identity_key=ik_b64, ) # Handle PoW challenge (server under pressure) if start.get("status") == "pow_required": challenge = start["data"]["challenge"] mac = start["data"]["mac"] difficulty = start["data"]["difficulty"] self._logger.info("Server requires proof-of-work (difficulty %d), solving...", difficulty) nonce = await asyncio.get_running_loop().run_in_executor( None, _solve_pow, challenge, difficulty) extra_fields = {"pow_challenge": challenge, "pow_mac": mac, "pow_nonce": nonce} start = await self.send_and_recv( "register", username=username, public_key=pub_pem, email=email, identity_key=ik_b64, **extra_fields, ) if start["status"] != "ok": self._reg_password = None return False, start["data"]["message"] code = start["data"].get("code") if code: return True, code return True, start["data"].get("message", "Check your email for the code.") async def confirm_registration(self, email: str, username: str, code: str) -> tuple[bool, str]: confirm = await self.send_and_recv("register_confirm", email=email, code=code) if confirm["status"] == "ok": # Registration confirmed — NOW save keys to disk pwd = getattr(self, "_reg_password", None) save_keys(email, self.private_key, self.public_key, password=pwd) _save_identity_keys(email, self.identity_private, self.identity_public, password=pwd) self._reg_password = None self._load_verification_stores() return True, f"Registered as '{username}' (ID: {confirm['data']['user_id']})" return False, confirm["data"]["message"] async def _generate_and_upload_prekeys(self, keep_spk: bool = False): """Generate SPK + OPKs and upload to server. If keep_spk=True, re-sign the existing SPK instead of generating a new one. This is used after device pairing so both devices share the same SPK and either can respond to X3DH. """ if not self.identity_private: return if keep_spk and self.spk_private and self.spk_id: # Re-sign existing SPK (both devices share the identity key) spk_pub_bytes = serialize_x25519_public(self.spk_private.public_key()) spk_sig = ed25519_sign(self.identity_private, spk_pub_bytes) spk_data = { "id": self.spk_id, "public_key": encode_binary(spk_pub_bytes), "signature": encode_binary(spk_sig), } else: # Save current SPK as previous for grace period (M4: in-flight X3DH) if self.spk_private and self.spk_id: self._prev_spk_private = self.spk_private self._prev_spk_id = self.spk_id _save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key) # Generate a brand-new signed prekey spk = generate_signed_prekey(self.identity_private) self.spk_private = spk["private"] self.spk_id = spk["id"] _save_spk(self.email, spk["private"], spk["id"], self._local_key) spk_data = { "id": spk["id"], "public_key": encode_binary(serialize_x25519_public(spk["public"])), "signature": encode_binary(spk["signature"]), } # Generate one-time prekeys opks = generate_one_time_prekeys(OPK_BATCH_SIZE) for opk in opks: self.opk_privates[opk["id"]] = opk["private"] _save_opk_private(self.email, opk["id"], opk["private"], self._local_key) # Upload to server otp_data = [ {"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))} for opk in opks ] resp = await self.send_and_recv( "upload_prekeys", signed_prekey=spk_data, one_time_prekeys=otp_data, ) if resp.get("status") != "ok": self._logger.warning("upload_prekeys failed: %s (will retry on login)", resp.get("data", {}).get("message", "unknown")) async def _ensure_prekeys(self): """Check OPK count and SPK age, replenish/rotate if needed. Uses single-roundtrip `ensure_prekeys` handler when available, falls back to legacy two-step flow (get_prekey_count + upload_prekeys). """ resp = await self.send_and_recv("get_prekey_count") if resp["status"] != "ok": return count = resp["data"].get("count", 0) spk_created_at = resp["data"].get("spk_created_at", "") need_new_spk = False if spk_created_at: try: created = datetime.fromisoformat(spk_created_at) if created.tzinfo is None: created = created.replace(tzinfo=timezone.utc) age_days = (datetime.now(timezone.utc) - created).days if age_days >= SPK_ROTATION_DAYS: need_new_spk = True self._logger.info("SPK is %d days old, rotating...", age_days) except Exception: need_new_spk = True else: # No SPK on server for this device — must upload one need_new_spk = True self._logger.info("No SPK on server for this device, uploading...") if count < OPK_REPLENISH_THRESHOLD or need_new_spk: if count >= OPK_REPLENISH_THRESHOLD: self._logger.info("SPK rotation triggered (OPK count OK: %d)", count) else: self._logger.info("OPK count low (%d), replenishing...", count) await self._generate_and_upload_prekeys_batch(need_new_spk) async def _generate_and_upload_prekeys_batch(self, need_new_spk: bool = False): """Generate and upload prekeys in a single round-trip via ensure_prekeys.""" if not self.identity_private: return kwargs: dict = {} # SPK if need_new_spk: if self.spk_private and self.spk_id: self._prev_spk_private = self.spk_private self._prev_spk_id = self.spk_id _save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key) spk = generate_signed_prekey(self.identity_private) self.spk_private = spk["private"] self.spk_id = spk["id"] _save_spk(self.email, spk["private"], spk["id"], self._local_key) kwargs["signed_prekey"] = { "id": spk["id"], "public_key": encode_binary(serialize_x25519_public(spk["public"])), "signature": encode_binary(spk["signature"]), } # OPKs opks = generate_one_time_prekeys(OPK_BATCH_SIZE) for opk in opks: self.opk_privates[opk["id"]] = opk["private"] _save_opk_private(self.email, opk["id"], opk["private"], self._local_key) kwargs["one_time_prekeys"] = [ {"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))} for opk in opks ] resp = await self.send_and_recv("ensure_prekeys", **kwargs) if resp["status"] == "ok": data = resp.get("data", {}) self._logger.info("ensure_prekeys: count=%d, spk_uploaded=%s, otps_uploaded=%d", data.get("count", 0), data.get("uploaded_spk", False), data.get("uploaded_otps", 0)) # ------------------------------------------------------------------ # Login # ------------------------------------------------------------------ async def login(self, email: str, password: str) -> tuple[bool, str]: """Login user. Returns (success, message).""" self.email = email # Brute-force lockout check remaining = _check_lockout(email) if remaining > 0: return False, f"Too many failed attempts. Try again in {remaining:.0f}s." pwd_bytes = bytearray(password.encode("utf-8")) if password else None try: # Load RSA keys priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) if priv is None: if err and "password" in err.lower(): _record_failed_attempt(email) return False, err or "No local keys found. Register first." self.private_key = priv self.public_key = pub # Load identity keys ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) finally: if pwd_bytes: pwd_bytes[:] = b'\x00' * len(pwd_bytes) if ed_priv is not None: self.identity_private = ed_priv self.identity_public = ed_pub self._cache_key = derive_self_encryption_key(ed_priv) self._local_key = derive_local_storage_key(ed_priv) self._load_verification_stores() # Load SPK spk_priv, spk_id = _load_spk(email, self._local_key) if spk_priv: self.spk_private = spk_priv self.spk_id = spk_id # Load previous SPK for grace period (M4) prev_spk_priv, prev_spk_id = _load_prev_spk(email, self._local_key) if prev_spk_priv: self._prev_spk_private = prev_spk_priv self._prev_spk_id = prev_spk_id # Load device_id from disk self.device_id = _load_device_id(email) # RSA challenge-response login start = await self.send_and_recv("login_start", email=email) if start["status"] != "ok": return False, start["data"]["message"] challenge = decode_binary(start["data"]["challenge"]) signature = rsa_sign(self.private_key, challenge) login_kwargs = {"email": email, "signature": encode_binary(signature), "client_version": VERSION} if self.device_id: login_kwargs["device_id"] = self.device_id finish = await self.send_and_recv("login_finish", **login_kwargs) if finish["status"] == "ok": self.session = finish["data"] self.username = self.session.get("username", "") # Store device_id from server self.device_id = finish["data"].get("device_id") if self.device_id: _save_device_id(email, self.device_id) # Replenish prekeys in background — after pairing, the new device # has no local OPK private keys so we must generate fresh ones # (server-side OPKs have no matching private keys on this device). # Use keep_spk=True to preserve the shared SPK so both devices # can respond to X3DH. opk_dir = get_key_dir(self.email) / "opk_private" has_local_opks = opk_dir.exists() and any(opk_dir.iterdir()) if has_local_opks: asyncio.create_task(self._ensure_prekeys()) else: self._logger.info("No local OPKs (likely new device). Generating fresh OPKs, keeping SPK.") asyncio.create_task(self._generate_and_upload_prekeys(keep_spk=True)) _clear_lockout(email) return True, f"Logged in as '{self.username}' (ID: {self.session['user_id']})" return False, finish["data"]["message"] # ------------------------------------------------------------------ # Pairing (device pairing — transfers RSA + identity keys) # ------------------------------------------------------------------ async def pairing_start(self, email: str) -> tuple[bool, str]: """Start device pairing. Returns (success, code/message).""" self._logger.info("pairing_start via %s for %s", self.server_endpoint(), email.strip().lower()) temp_priv, temp_pub = generate_x25519_keypair() self._pairing_temp_private_key = temp_priv temp_pub_raw = serialize_x25519_public(temp_pub) self._pairing_fingerprint = compute_pairing_fingerprint(temp_pub_raw) resp = await self.send_and_recv( "pairing_start", email=email, temp_public_key=encode_binary(temp_pub_raw), temp_key_type="x25519", ) if resp["status"] == "ok": self._pairing_code = resp["data"]["code"] self._pairing_poll_token = resp["data"].get("poll_token", "") return True, resp["data"]["code"] self._pairing_fingerprint = "" self._pairing_code = "" return False, resp["data"]["message"] async def pairing_wait(self, code: str, email: str, password: str, timeout: int = 300) -> tuple[bool, str]: """Wait for pairing payload and import keys. Returns (success, message).""" if not self._pairing_temp_private_key: return False, "Pairing not started." from crypto_utils import aes_decrypt as _aes_decrypt poll_token = getattr(self, "_pairing_poll_token", "") deadline = asyncio.get_event_loop().time() + timeout while asyncio.get_event_loop().time() < deadline: resp = await self.send_and_recv("pairing_poll", code=code, poll_token=poll_token) if resp["status"] != "ok": self._pairing_fingerprint = "" self._pairing_code = "" return False, resp["data"]["message"] if not resp["data"].get("ready"): await asyncio.sleep(2.0) continue payload = resp["data"]["payload"] try: sender_pub_raw = decode_binary(payload["sender_public_key"]) sender_pub = load_x25519_public(sender_pub_raw) my_pub_raw = serialize_x25519_public(self._pairing_temp_private_key.public_key()) shared_secret = x25519_dh(self._pairing_temp_private_key, sender_pub) aes_key = derive_pairing_shared_key(shared_secret, my_pub_raw, sender_pub_raw) nonce = decode_binary(payload["iv"]) ct = decode_binary(payload["ciphertext"]) tag = decode_binary(payload["tag"]) keys_json = _aes_decrypt(aes_key, nonce, ct, tag) keys_data = json.loads(keys_json) pwd_bytes = bytearray(password.encode("utf-8")) if password else None try: # Import RSA key rsa_priv = load_private_key(keys_data["rsa_private"].encode(), password=None) rsa_pub = rsa_priv.public_key() save_keys(email, rsa_priv, rsa_pub, password=bytes(pwd_bytes) if pwd_bytes else None) # Import identity keys ed_priv = load_ed25519_private(bytes.fromhex(keys_data["identity_private"])) ed_pub = ed_priv.public_key() _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None) finally: if pwd_bytes: pwd_bytes[:] = b'\x00' * len(pwd_bytes) self.email = email self.private_key = rsa_priv self.public_key = rsa_pub self.identity_private = ed_priv self.identity_public = ed_pub self._cache_key = derive_self_encryption_key(ed_priv) self._local_key = derive_local_storage_key(ed_priv) self._load_verification_stores() self._pairing_temp_private_key = None self._pairing_fingerprint = "" self._pairing_code = "" # Multi-device: new device generates own SPK + OPKs on first # login. No session/sender key import needed — each device # has independent Double Ratchet sessions. return True, "Pairing complete." except Exception as e: self._pairing_fingerprint = "" self._pairing_code = "" return False, f"Failed to import keys: {e}" self._pairing_fingerprint = "" self._pairing_code = "" return False, "Pairing timed out." async def authorize_device(self, code: str, expected_fingerprint: str) -> tuple[bool, str]: """Authorize a new device by sending all keys to it.""" if not self.private_key or not self.identity_private: return False, "Not logged in." expected_digits = normalize_pairing_fingerprint(expected_fingerprint) if len(expected_digits) != 30: return False, "Pairing fingerprint must contain 30 digits." claim = await self.send_and_recv("pairing_claim", code=code) if claim["status"] != "ok": return False, claim["data"]["message"] if claim["data"].get("temp_key_type") != "x25519": return False, "Unsupported pairing key type. Update both devices and try again." temp_pub_raw = decode_binary(claim["data"]["temp_public_key"]) actual_fp = compute_pairing_fingerprint(temp_pub_raw) if normalize_pairing_fingerprint(actual_fp) != expected_digits: self._logger.warning("Pairing fingerprint mismatch for code %s", code[:8]) return False, ( "Pairing fingerprint mismatch. Verify the new device fingerprint and try again.\n" f"Expected: {expected_fingerprint}\n" f"Received: {actual_fp}" ) temp_pub = load_x25519_public(temp_pub_raw) # Build keys payload — only RSA + identity key. # Multi-device: new device generates own SPK + OPKs, creates independent # sessions. No session/sender key transfer needed. keys_data = { "rsa_private": serialize_private_key(self.private_key, password=None).decode(), "identity_private": serialize_ed25519_private_raw(self.identity_private).hex(), } # Send keys to the new device first. Re-encrypting history can take a # while on large accounts; doing it before pairing_send can make a valid # code expire during authorization. plaintext = json.dumps(keys_data).encode() sender_priv, sender_pub = generate_x25519_keypair() sender_pub_raw = serialize_x25519_public(sender_pub) shared_secret = x25519_dh(sender_priv, temp_pub) pairing_key = derive_pairing_shared_key(shared_secret, sender_pub_raw, temp_pub_raw) _, nonce, ct, tag = aes_encrypt(plaintext, key=pairing_key) payload = { "sender_public_key": encode_binary(sender_pub_raw), "iv": encode_binary(nonce), "ciphertext": encode_binary(ct), "tag": encode_binary(tag), } resp = await self.send_and_recv("pairing_send", code=code, payload=payload) if resp["status"] == "ok": async def _reencrypt_after_pairing(): try: delay = random.uniform(*PAIRING_REENCRYPT_INITIAL_DELAY_RANGE) self._logger.info("Delaying post-pairing history resync by %.1fs", delay) await asyncio.sleep(delay) await self.reencrypt_history() except Exception as e: self._logger.warning("Post-pairing re-encryption failed: %s", e) asyncio.create_task(_reencrypt_after_pairing()) return True, "Device authorized." return False, resp["data"]["message"] # ------------------------------------------------------------------ # Password change (local key re-encryption only) # ------------------------------------------------------------------ def change_password(self, old_password: str, new_password: str) -> tuple[bool, str]: """Change password for local key encryption (RSA + identity key). Returns (success, message). """ if not self.email: return False, "Not logged in." old_pwd = bytearray(old_password.encode("utf-8")) new_pwd = bytearray(new_password.encode("utf-8")) try: # 1. Verify old password by loading keys priv, pub, err = load_keys(self.email, password=bytes(old_pwd)) if priv is None: return False, "Wrong current password." ed_priv, ed_pub = _load_identity_keys(self.email, password=bytes(old_pwd)) if ed_priv is None: return False, "Failed to load identity key." # 2. Re-save with new password save_keys(self.email, priv, pub, password=bytes(new_pwd)) _save_identity_keys(self.email, ed_priv, ed_pub, password=bytes(new_pwd)) return True, "Password changed successfully." finally: old_pwd[:] = b'\x00' * len(old_pwd) new_pwd[:] = b'\x00' * len(new_pwd) async def change_username(self, new_username: str) -> tuple[bool, str]: """Change display name on server.""" if not self.session: return False, "Not logged in." new_username = new_username.strip() if not new_username or len(new_username) > 100: return False, "Username must be 1-100 characters." resp = await self.send_and_recv("change_username", username=new_username) if resp["status"] == "ok": self.username = resp["data"]["username"] if self.session: self.session["username"] = self.username return True, "Username changed." return False, resp["data"].get("message", "Unknown error") # ------------------------------------------------------------------ # Key rotation (RSA login key only) # ------------------------------------------------------------------ async def rotate_keys(self, username: str, password: str) -> tuple[bool, str]: """Rotate RSA keypair to revoke other devices.""" if not self.session or self.session.get("username") != username: return False, "Not logged in." pwd_bytes = password.encode("utf-8") if password else None priv, pub = generate_rsa_keypair() pub_pem = serialize_public_key(pub).decode("utf-8") # Persist the new key only after the server accepted it — overwriting # private.pem first would brick the account if rotation fails. resp = await self.send_and_recv("rotate_keys", public_key=pub_pem) if resp["status"] == "ok": save_keys(self.email, priv, pub, password=pwd_bytes) self.private_key = priv self.public_key = pub return True, "RSA login keys rotated." return False, resp["data"]["message"] # ------------------------------------------------------------------ # Session management (X3DH + Double Ratchet) # ------------------------------------------------------------------ async def _get_device_bundles(self, peer_user_id: str) -> list[dict]: """Get per-device key bundles for a peer. Caches for 5 minutes.""" import time cached = self._device_bundle_cache.get(peer_user_id) if cached: ts, bundles = cached if time.time() - ts < 300: return bundles resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id) if resp["status"] != "ok": raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}") data = resp["data"] ik_b64 = data.get("identity_key", "") device_bundles = data.get("device_bundles") if device_bundles: # Attach identity_key to each bundle for b in device_bundles: b["identity_key"] = ik_b64 else: # Old server: wrap flat response as single-entry list device_bundles = [{ "device_id": None, "identity_key": ik_b64, "signed_prekey_id": data.get("signed_prekey_id", ""), "signed_prekey": data.get("signed_prekey", ""), "spk_signature": data.get("spk_signature", ""), "one_time_prekey_id": data.get("one_time_prekey_id"), "one_time_prekey": data.get("one_time_prekey"), }] self._device_bundle_cache[peer_user_id] = (time.time(), device_bundles) return device_bundles async def _get_or_create_session(self, peer_user_id: str, peer_device_id: str | None = None, bundle: dict | None = None) -> DoubleRatchet: """Load existing session or create one via X3DH. If peer_device_id is set, sessions are keyed by "user_id:device_id". If bundle is provided, it's used instead of fetching from server. """ session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id # Check in-memory cache if session_key in self.sessions: return self.sessions[session_key] # Check on disk ratchet = _load_session(self.email, peer_user_id, self._local_key, peer_device_id=peer_device_id) if ratchet: self.sessions[session_key] = ratchet return ratchet # Create new session via X3DH if not bundle: resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id) if resp["status"] != "ok": raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}") bundle = resp["data"] ik_remote_bytes = decode_binary(bundle["identity_key"]) ik_remote = load_ed25519_public(ik_remote_bytes) # TOFU: verify identity key before using it in X3DH ik_status = self.check_identity_key(peer_user_id, ik_remote_bytes) if ik_status in ("changed", "changed_verified"): raise IdentityKeyChanged(peer_user_id, ik_remote_bytes, ik_status) spk_remote = load_x25519_public(decode_binary(bundle["signed_prekey"])) spk_sig = decode_binary(bundle["spk_signature"]) opk_remote = None opk_id = bundle.get("one_time_prekey_id") if bundle.get("one_time_prekey"): opk_remote = load_x25519_public(decode_binary(bundle["one_time_prekey"])) # Perform X3DH shared_secret, ek_priv, ek_pub = x3dh_initiate( self.identity_private, ik_remote, spk_remote, spk_sig, opk_remote, ) # Initialize Double Ratchet as Alice ratchet = DoubleRatchet.init_alice(shared_secret, spk_remote) self.sessions[session_key] = ratchet _save_session(self.email, peer_user_id, ratchet, self._local_key, peer_device_id=peer_device_id) # Build X3DH header for first message x3dh_header = { "ik": encode_binary(serialize_ed25519_public(self.identity_public)), "ek": encode_binary(serialize_x25519_public(ek_pub)), } if opk_id: x3dh_header["opk_id"] = opk_id # Cache the x3dh header for the next send_message call ratchet._x3dh_header = x3dh_header # Cache remote user info self._user_cache[peer_user_id] = { "user_id": peer_user_id, "identity_key": ik_remote, "identity_key_bytes": ik_remote_bytes, "identity_key_status": ik_status, } return ratchet def _process_x3dh_header(self, sender_id: str, x3dh_header: dict, sender_device_id: str | None = None, spk_override=None) -> DoubleRatchet: """Process an incoming X3DH header to establish session as Bob. Args: spk_override: If provided, use this SPK private key instead of self.spk_private. Used for grace period fallback (M4). """ ik_remote_bytes = decode_binary(x3dh_header["ik"]) ik_remote = load_ed25519_public(ik_remote_bytes) # TOFU: verify identity key before using it in X3DH ik_status = self.check_identity_key(sender_id, ik_remote_bytes) if ik_status in ("changed", "changed_verified"): raise IdentityKeyChanged(sender_id, ik_remote_bytes, ik_status) ek_remote = load_x25519_public(decode_binary(x3dh_header["ek"])) opk_id = x3dh_header.get("opk_id") opk_priv = None if opk_id: opk_priv = _load_opk_private(self.email, opk_id, self._local_key) # Deletion is deferred until the first message decrypts successfully # (_consume_pending_opk). Deleting here would break the SPK # grace-period retry: the second _process_x3dh_header call could no # longer load the OPK and the message would be lost permanently. spk_priv = spk_override if spk_override else self.spk_private shared_secret = x3dh_respond( self.identity_private, spk_priv, ik_remote, ek_remote, opk_priv, ) spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub)) ratchet._pending_opk_delete = opk_id if opk_priv else None # NOTE: the ratchet is intentionally NOT installed into self.sessions # nor saved to disk here. The caller does that only after the first # message decrypts successfully — otherwise a failed/forged X3DH # header would overwrite a working session. self._user_cache[sender_id] = { "user_id": sender_id, "identity_key": ik_remote, "identity_key_bytes": ik_remote_bytes, "identity_key_status": ik_status, } return ratchet def _consume_pending_opk(self, ratchet) -> None: """Delete the one-time prekey consumed by an X3DH handshake. Called only after the first message decrypted successfully, so a failed attempt (e.g. wrong SPK during the grace period) can still retry with the same OPK. """ opk_id = getattr(ratchet, "_pending_opk_delete", None) if opk_id: _delete_opk_private(self.email, opk_id) ratchet._pending_opk_delete = None # ------------------------------------------------------------------ # Conversations # ------------------------------------------------------------------ async def create_conversation(self, member_emails: list[str], name: str | None = None) -> tuple[str | None, str]: kwargs = {"members": member_emails} if name: kwargs["name"] = name resp = await self.send_and_recv("create_conversation", **kwargs) if resp["status"] == "ok": return resp["data"]["conversation_id"], "OK" return None, resp["data"]["message"] async def remove_member(self, conv_id: str, user_id: str) -> tuple[bool, str]: resp = await self.send_and_recv("remove_member", conversation_id=conv_id, user_id=user_id) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def leave_group(self, conv_id: str) -> tuple[bool, str]: """Leave a group conversation.""" resp = await self.send_and_recv("leave_group", conversation_id=conv_id) if resp["status"] == "ok": # Clean up local sender key state for this group self.sender_key_states.pop(conv_id, None) # Remove received sender keys for this conversation to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")] for k in to_remove: self.recv_sender_keys.pop(k, None) return True, "OK" return False, resp["data"]["message"] async def rename_conversation(self, conv_id: str, name: str) -> tuple[bool, str]: """Rename a group conversation (creator only).""" resp = await self.send_and_recv("rename_conversation", conversation_id=conv_id, name=name) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def delete_conversation(self, conv_id: str) -> tuple[bool, str]: """Delete a conversation (leave + server cleans up if empty).""" resp = await self.send_and_recv("delete_conversation", conversation_id=conv_id) if resp["status"] == "ok": # Clean up local sender key state self.sender_key_states.pop(conv_id, None) to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")] for k in to_remove: self.recv_sender_keys.pop(k, None) return True, "OK" return False, resp["data"]["message"] async def add_member(self, conv_id: str, email: str) -> tuple[bool, str]: resp = await self.send_and_recv("add_member", conversation_id=conv_id, email=email) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def accept_invitation(self, conv_id: str) -> tuple[bool, str]: """Accept a group invitation.""" resp = await self.send_and_recv("accept_invitation", conversation_id=conv_id) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def decline_invitation(self, conv_id: str) -> tuple[bool, str]: """Decline a group invitation.""" resp = await self.send_and_recv("decline_invitation", conversation_id=conv_id) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def list_invitations(self) -> list[dict]: """List pending group invitations.""" resp = await self.send_and_recv("list_invitations") if resp["status"] == "ok": return resp["data"]["invitations"] return [] async def list_conversations(self) -> list[dict]: resp = await self.send_and_recv("list_conversations") if resp["status"] == "ok": return resp["data"]["conversations"] return [] async def find_or_create_conversation(self, email: str) -> tuple[str | None, str]: resp = await self.send_and_recv("find_conversation", email=email) if resp["status"] != "ok": return None, resp["data"]["message"] conv_id = resp["data"]["conversation_id"] if conv_id: return conv_id, "OK" return await self.create_conversation([email]) # ------------------------------------------------------------------ # Send message # ------------------------------------------------------------------ def _cancel_typing_timer(self, conv_id: str): task = self._typing_stop_tasks.pop(conv_id, None) if task and not task.done(): task.cancel() async def _typing_stop_after_delay(self, conv_id: str, delay: float): try: await asyncio.sleep(delay) await self.typing_stop(conv_id) except asyncio.CancelledError: return async def typing_start(self, conv_id: str): """Debounced typing_start with 3s inactivity timeout.""" if not conv_id or not self.session: return now = time.monotonic() was_active = self._typing_active.get(conv_id, False) last_sent = self._typing_last_sent.get(conv_id, 0.0) should_send = (not was_active) or (now - last_sent >= 1.0) self._typing_active[conv_id] = True self._cancel_typing_timer(conv_id) self._typing_stop_tasks[conv_id] = asyncio.create_task( self._typing_stop_after_delay(conv_id, 3.0) ) if not should_send: return self._typing_last_sent[conv_id] = now try: await self.send_and_recv("typing_start", timeout=5.0, conversation_id=conv_id) except Exception: pass async def typing_stop(self, conv_id: str, force: bool = False): if not conv_id or not self.session: return self._cancel_typing_timer(conv_id) was_active = self._typing_active.get(conv_id, False) self._typing_active[conv_id] = False if not was_active and not force: return try: await self.send_and_recv("typing_stop", timeout=5.0, conversation_id=conv_id) except Exception: pass def _is_group(self, members: list[dict]) -> bool: return len(members) > 2 async def send_message(self, conv_id: str, text: str, members: list[dict], reply_to: str | None = None) -> tuple[bool, str | dict]: """Encrypt and send a message. DM: per-recipient Double Ratchet. Group: Sender Keys. Returns (True, msg_dict) on success or (False, error_string) on failure. msg_dict contains the full decrypted payload ready for display. """ my_user_id = self.session["user_id"] await self.typing_stop(conv_id, force=True) # Build plaintext payload payload = { "sender": self.username, "text": text, "reply_to": reply_to, "timestamp": datetime.now(timezone.utc).isoformat(), } plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) if self._is_group(members): return await self._send_group_message(conv_id, plaintext, members, payload) else: return await self._send_dm(conv_id, plaintext, members, payload) async def _send_dm(self, conv_id: str, plaintext: bytes, members: list[dict], payload: dict | None = None) -> tuple[bool, str | dict]: """Encrypt DM with per-device Double Ratchet.""" my_user_id = self.session["user_id"] recipients = [] first_ratchet_header = None for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue # Get all device bundles for this user try: device_bundles = await self._get_device_bundles(uid) self._logger.debug("Got %d device bundles for %s", len(device_bundles), uid) except Exception as e: self._logger.warning("Failed to get device bundles for %s: %s", uid, e) device_bundles = [] if not device_bundles: # Fallback: try single session (legacy peer) ratchet = await self._get_or_create_session(uid) result = ratchet.encrypt(plaintext) x3dh_hdr = getattr(ratchet, "_x3dh_header", None) if x3dh_hdr: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if x3dh_hdr: entry["x3dh_header"] = x3dh_hdr recipients.append(entry) if first_ratchet_header is None: first_ratchet_header = result["header"] _save_session(self.email, uid, ratchet, self._local_key) continue for bundle in device_bundles: dev_id = bundle.get("device_id") ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, bundle=bundle) result = ratchet.encrypt(plaintext) x3dh_hdr = getattr(ratchet, "_x3dh_header", None) if x3dh_hdr: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if dev_id: entry["device_id"] = dev_id if x3dh_hdr: entry["x3dh_header"] = x3dh_hdr recipients.append(entry) if first_ratchet_header is None: first_ratchet_header = result["header"] _save_session(self.email, uid, ratchet, self._local_key, peer_device_id=dev_id) # Encrypt self-copy with static key derived from identity (not ratchet) # Uses SELF_DEVICE_ID so all own devices can read it self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) if not recipients: return False, "No recipients." kwargs = { "conversation_id": conv_id, "ratchet_header": first_ratchet_header, "recipients": recipients, } resp = await self.send_and_recv("send_message", **kwargs) if resp["status"] == "ok": msg_data = resp.get("data", {}) if payload is not None: result = { **payload, "message_id": msg_data.get("message_id", ""), "created_at": msg_data.get("created_at", ""), "sender_id": self.session["user_id"], "conversation_id": conv_id, "read_by": [], } _save_message_to_cache(self.email, conv_id, result["message_id"], result, self._cache_key) return True, result return True, "Message sent." return False, resp["data"]["message"] async def _send_group_message(self, conv_id: str, plaintext: bytes, members: list[dict], payload: dict | None = None) -> tuple[bool, str | dict]: """Encrypt group message with Sender Keys.""" my_user_id = self.session["user_id"] # Get or create sender key for this group sk = self.sender_key_states.get(conv_id) if not sk: sk = _load_sender_key_state(self.email, conv_id, self._local_key) if not sk: sk = SenderKeyState() self.sender_key_states[conv_id] = sk _save_sender_key_state(self.email, conv_id, sk, self._local_key) self.sender_key_states[conv_id] = sk await self._catch_up_sender_key_distribution(conv_id, members, sk) # Encrypt with sender key result = sk.encrypt(plaintext) _save_sender_key_state(self.email, conv_id, sk, self._local_key) # Build per-recipient entries (same ciphertext for all except self) recipients = [] for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue recipients.append({ "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), }) # Self-encrypted copy (so other devices + history fetch can decrypt) self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) ratchet_header = {"dh_pub": "00" * 32, "n": 0, "pn": 0} # Dummy for groups kwargs = { "conversation_id": conv_id, "ratchet_header": ratchet_header, "recipients": recipients, "sender_chain_id": encode_binary(bytes.fromhex(result["chain_id"])), "sender_chain_n": result["n"], } resp = await self.send_and_recv("send_message", **kwargs) if resp["status"] == "ok": msg_data = resp.get("data", {}) if payload is not None: result_msg = { **payload, "message_id": msg_data.get("message_id", ""), "created_at": msg_data.get("created_at", ""), "sender_id": self.session["user_id"], "conversation_id": conv_id, "read_by": [], } _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key) return True, result_msg return True, "Message sent." return False, resp["data"]["message"] async def _catch_up_sender_key_distribution(self, conv_id: str, members: list[dict], sk: SenderKeyState): """Ensure all current members have our existing sender key.""" my_user_id = self.session["user_id"] current_member_ids = { uid for uid in (member.get("user_id") for member in members) if uid and uid != my_user_id } if not current_member_ids: return distributed_to = _load_sender_key_recipients(self.email, conv_id, self._local_key) missing_ids = sorted(current_member_ids - distributed_to) if not missing_ids: return distributed_now = await self._distribute_sender_key( conv_id, [{"user_id": uid} for uid in missing_ids], sk, ) if distributed_now: distributed_to.update(distributed_now) _save_sender_key_recipients(self.email, conv_id, distributed_to, self._local_key) async def _distribute_sender_key(self, conv_id: str, members: list[dict], sk: SenderKeyState) -> set[str]: """Send own sender key to all group members via pairwise Double Ratchet (per-device).""" my_user_id = self.session["user_id"] distributed_to: set[str] = set() exported_key = sk.export_key() # Build a special "sender_key_distribution" payload payload = { "sender": self.username, "text": "", "reply_to": None, "timestamp": datetime.now(timezone.utc).isoformat(), "_sender_key": { "conv_id": conv_id, "key": encode_binary(exported_key), "sender_device_id": self.device_id, }, } plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) # Send as DM to each member's devices (per-device encryption) for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue try: # Get all device bundles for this user try: device_bundles = await self._get_device_bundles(uid) except Exception: device_bundles = [] if not device_bundles: # Fallback: legacy single-device ratchet = await self._get_or_create_session(uid) result = ratchet.encrypt(plaintext) x3dh_header = getattr(ratchet, "_x3dh_header", None) if x3dh_header: delattr(ratchet, "_x3dh_header") recipient_entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if x3dh_header: recipient_entry["x3dh_header"] = x3dh_header kwargs = { "conversation_id": conv_id, "ratchet_header": result["header"], "recipients": [recipient_entry], } await self.send_and_recv("send_message", **kwargs) _save_session(self.email, uid, ratchet, self._local_key) distributed_to.add(uid) else: # Per-device encryption recipients = [] first_rh = None for bundle in device_bundles: dev_id = bundle.get("device_id") ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, bundle=bundle) result = ratchet.encrypt(plaintext) x3dh_header = getattr(ratchet, "_x3dh_header", None) if x3dh_header: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if dev_id: entry["device_id"] = dev_id if x3dh_header: entry["x3dh_header"] = x3dh_header recipients.append(entry) if first_rh is None: first_rh = result["header"] _save_session(self.email, uid, ratchet, self._local_key, peer_device_id=dev_id) kwargs = { "conversation_id": conv_id, "ratchet_header": first_rh, "recipients": recipients, } await self.send_and_recv("send_message", **kwargs) distributed_to.add(uid) except Exception as e: self._logger.warning("Failed to distribute sender key to %s: %s", uid, e) return distributed_to async def redistribute_sender_key_to_member(self, conv_id: str, new_user_id: str): """Redistribute our existing sender key to a newly joined group member. Called when we receive a member_added notification — the new member needs our sender key so they can decrypt future messages we send in the group. """ if not self.session: return sk = self.sender_key_states.get(conv_id) if sk is None: sk = _load_sender_key_state(self.email, conv_id, self._local_key) if sk is None: # We haven't sent anything in this group yet — no key to redistribute return try: distributed = await self._distribute_sender_key(conv_id, [{"user_id": new_user_id}], sk) if new_user_id in distributed: recipients = _load_sender_key_recipients(self.email, conv_id, self._local_key) recipients.add(new_user_id) _save_sender_key_recipients(self.email, conv_id, recipients, self._local_key) self._logger.info("Redistributed sender key for conv=%s to new member %s", conv_id[:8], new_user_id[:8]) except Exception as e: self._logger.warning("Failed to redistribute sender key to %s: %s", new_user_id, e) # ------------------------------------------------------------------ # Decrypt messages # ------------------------------------------------------------------ def _decrypt_message(self, msg_data: dict) -> dict: """Decrypt a single message (DM or group).""" # Check for self-encrypted marker FIRST — after re-encryption, # group messages will have {"self": true} ratchet_header but still # have sender_chain_id at message level. rh = msg_data.get("ratchet_header", {}) if isinstance(rh, dict) and rh.get("self"): return self._decrypt_dm(msg_data) if msg_data.get("sender_chain_id"): return self._decrypt_group(msg_data) else: return self._decrypt_dm(msg_data) def _decrypt_dm(self, msg_data: dict) -> dict: """Decrypt DM using Double Ratchet with sender, or static key for self-copies.""" sender_id = msg_data.get("sender_id", "") sender_device_id = msg_data.get("sender_device_id") ratchet_header = msg_data.get("ratchet_header", {}) ct_b64 = msg_data.get("encrypted_content", "") nonce_b64 = msg_data.get("nonce", "") if not ct_b64 or not nonce_b64: raise ValueError("Missing ciphertext or nonce") ciphertext = decode_binary(ct_b64) nonce = decode_binary(nonce_b64) # Self-encrypted message (own sent message copy) if isinstance(ratchet_header, dict) and ratchet_header.get("self"): self_key = derive_self_encryption_key(self.identity_private) ct = ciphertext[:-16] tag = ciphertext[-16:] plaintext = aes_decrypt(self_key, nonce, ct, tag) else: x3dh_header = msg_data.get("x3dh_header") # Session key: "sender_id:sender_device_id" or just "sender_id" for legacy session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id # Try to load existing session ratchet = self.sessions.get(session_key) if not ratchet: ratchet = _load_session(self.email, sender_id, self._local_key, peer_device_id=sender_device_id) if ratchet: self.sessions[session_key] = ratchet if ratchet and not x3dh_header: # Normal case: existing session, no X3DH header plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) _save_session(self.email, sender_id, ratchet, self._local_key, peer_device_id=sender_device_id) elif x3dh_header: if ratchet: # Existing session + X3DH header: sender may have reset. backup = ratchet.export_state() try: plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) _save_session(self.email, sender_id, ratchet, self._local_key, peer_device_id=sender_device_id) except Exception: # Restore the known-good session before attempting a # fresh X3DH; if the X3DH path fails too, this restored # session stays installed (in memory and on disk). restored = DoubleRatchet.import_state(backup) self.sessions[session_key] = restored _save_session(self.email, sender_id, restored, self._local_key, peer_device_id=sender_device_id) ratchet = self._process_x3dh_header(sender_id, x3dh_header, sender_device_id=sender_device_id) try: plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) except Exception: if self._prev_spk_private: ratchet = self._process_x3dh_header( sender_id, x3dh_header, sender_device_id=sender_device_id, spk_override=self._prev_spk_private) plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) else: raise # First decrypt succeeded — only now adopt the new session self.sessions[session_key] = ratchet _save_session(self.email, sender_id, ratchet, self._local_key, peer_device_id=sender_device_id) else: ratchet = self._process_x3dh_header(sender_id, x3dh_header, sender_device_id=sender_device_id) try: plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) except Exception: if self._prev_spk_private: ratchet = self._process_x3dh_header( sender_id, x3dh_header, sender_device_id=sender_device_id, spk_override=self._prev_spk_private) plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) else: raise # First decrypt succeeded — install + persist the session self.sessions[session_key] = ratchet _save_session(self.email, sender_id, ratchet, self._local_key, peer_device_id=sender_device_id) else: raise ValueError(f"No session for sender {sender_id}") self._consume_pending_opk(ratchet) plaintext = unpad_plaintext(plaintext) payload = json.loads(plaintext) # Handle sender key distribution messages if "_sender_key" in payload: sk_data = payload["_sender_key"] sk_conv_id = sk_data["conv_id"] sk_key = decode_binary(sk_data["key"]) sk_sender_device_id = sk_data.get("sender_device_id") recv_sk = SenderKeyState.from_key(sk_key) if sk_sender_device_id: cache_key = f"{sk_conv_id}:{sender_id}:{sk_sender_device_id}" else: cache_key = f"{sk_conv_id}:{sender_id}" self.recv_sender_keys[cache_key] = recv_sk _save_recv_sender_key(self.email, sk_conv_id, sender_id, recv_sk, self._local_key, sender_device_id=sk_sender_device_id) # Return empty — this is a control message, not user-visible return None return payload def _decrypt_group(self, msg_data: dict) -> dict: """Decrypt group message using sender's Sender Key.""" sender_id = msg_data.get("sender_id", "") sender_device_id = msg_data.get("sender_device_id") conv_id = msg_data.get("conversation_id", "") chain_id_b64 = msg_data.get("sender_chain_id", "") chain_n = msg_data.get("sender_chain_n", 0) ct_b64 = msg_data.get("encrypted_content", "") nonce_b64 = msg_data.get("nonce", "") if not ct_b64 or not nonce_b64 or not chain_id_b64: raise ValueError("Missing group message fields") ciphertext = decode_binary(ct_b64) nonce = decode_binary(nonce_b64) chain_id = decode_binary(chain_id_b64) my_user_id = self.session["user_id"] # If we sent this message, use our own sender key if sender_id == my_user_id: sk = self.sender_key_states.get(conv_id) if not sk: sk = _load_sender_key_state(self.email, conv_id, self._local_key) if sk: self.sender_key_states[conv_id] = sk if not sk: raise ValueError("Own sender key not found") # For our own messages, we can't decrypt from sender key (it's already advanced) # Return a placeholder — the server echoed our ciphertext raise ValueError("Cannot decrypt own group message from sender key") # Use received sender key — try with sender_device_id first, fall back to without sk = None if sender_device_id: cache_key = f"{conv_id}:{sender_id}:{sender_device_id}" sk = self.recv_sender_keys.get(cache_key) if not sk: sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key, sender_device_id=sender_device_id) if sk: self.recv_sender_keys[cache_key] = sk if not sk: # Fallback: try without device_id (legacy or same-device) cache_key = f"{conv_id}:{sender_id}" sk = self.recv_sender_keys.get(cache_key) if not sk: sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key) if sk: self.recv_sender_keys[cache_key] = sk if not sk: raise ValueError(f"No sender key for {sender_id} in conversation {conv_id}") plaintext = unpad_plaintext(sk.decrypt(chain_id.hex(), chain_n, ciphertext, nonce)) _save_recv_sender_key(self.email, conv_id, sender_id, sk, self._local_key, sender_device_id=sender_device_id) return json.loads(plaintext) # ------------------------------------------------------------------ # Get/decrypt messages (batch) # ------------------------------------------------------------------ async def get_messages(self, conv_id: str, limit: int = 50, offset: int = 0) -> list[dict]: cache = _load_message_cache(self.email, conv_id, self._cache_key) my_user_id = self.session["user_id"] if self.session else "" # Incremental sync: use stored server timestamp from last successful fetch. after_ts = None if cache and offset == 0: after_ts = cache.get("__last_server_ts", {}).get("ts") req_params = {"conversation_id": conv_id, "limit": limit, "offset": offset} if after_ts: req_params["after_ts"] = after_ts resp = await self.send_and_recv("get_messages", **req_params) if resp["status"] != "ok": # Offline fallback: return from cache if available if cache and offset == 0: return self._build_from_cache(cache) return [] raw_messages = resp["data"]["messages"] raw_messages.reverse() # Server returns DESC, reverse to ASC # Decrypt new messages from server new_decrypted = self._decrypt_raw_messages(raw_messages, cache, conv_id, my_user_id) # Advance the incremental-sync watermark only across the prefix of # messages that are settled in the cache (decrypted, control, deleted, # or failed too many times). Stopping at the first unsettled message # means a transiently undecryptable message (e.g. sender key not yet # received) is re-fetched and retried on the next sync instead of # being skipped forever. if raw_messages and offset == 0: newest_ts = "" for m in raw_messages: entry = cache.get(m["message_id"]) if entry is None: break fails = entry.get("_decrypt_failed", 0) if fails and fails < _MAX_DECRYPT_RETRIES: break newest_ts = m.get("created_at", "") or newest_ts prev_ts = cache.get("__last_server_ts", {}).get("ts", "") if newest_ts and newest_ts > prev_ts: cache["__last_server_ts"] = {"ts": newest_ts} _save_message_to_cache(self.email, conv_id, "__last_server_ts", {"ts": newest_ts}, cache_key=self._cache_key) # All non-critical ops fire-and-forget to avoid blocking message display # Confirm delivery for messages from others deliver_ids = [m["message_id"] for m in new_decrypted if m.get("sender_id") and m["sender_id"] != my_user_id and not m.get("deleted")] if deliver_ids: asyncio.ensure_future(self.confirm_delivery(conv_id, deliver_ids)) # Mark entire conversation as read (fire-and-forget) asyncio.ensure_future(self.mark_conversation_read(conv_id)) # Flush self-encryption queue in background if self._pending_self_encrypt: asyncio.ensure_future(self._flush_self_encrypt()) if after_ts: # Incremental: sync deletions in background, build from cache NOW asyncio.ensure_future(self._sync_deletions(conv_id, after_ts)) return self._build_from_cache(cache) return new_decrypted async def _sync_deletions(self, conv_id: str, after_ts: str): """Sync message deletions from server (background, non-blocking).""" try: del_resp = await self.send_and_recv("get_deleted_since", conversation_id=conv_id, since_ts=after_ts) if del_resp.get("status") == "ok": for del_id in del_resp.get("data", {}).get("deleted_ids", []): _save_message_to_cache(self.email, conv_id, del_id, {"deleted": True}, cache_key=self._cache_key) except Exception: pass def get_cached_messages(self, conv_id: str) -> list[dict]: """Return messages from local disk cache only (no server call). Instant.""" if not self.email: return [] cache = _load_message_cache(self.email, conv_id, self._cache_key) if not cache: return [] return self._build_from_cache(cache) def _build_from_cache(self, cache: dict) -> list[dict]: """Build sorted message list from local cache (all messages).""" messages = [] for msg_id, p in cache.items(): if p.get("_control") or msg_id.startswith("__"): continue if p.get("_decrypt_failed"): messages.append({ "message_id": msg_id, "sender": "???", "text": "[Decryption failed]", "created_at": p.get("created_at", ""), "read_by": [], "delivered_to": [], }) continue entry = dict(p) entry.setdefault("message_id", msg_id) entry.setdefault("read_by", []) entry.setdefault("delivered_to", []) messages.append(entry) messages.sort(key=lambda m: m.get("created_at", "")) return messages def _decrypt_raw_messages(self, raw_messages: list, cache: dict, conv_id: str, my_user_id: str) -> list[dict]: """Decrypt server messages, update cache. Returns list of decrypted dicts.""" decrypted = [] for m in raw_messages: msg_id = m["message_id"] if m.get("deleted_at"): decrypted.append({ "message_id": msg_id, "sender": "", "text": "", "created_at": m["created_at"], "read_by": [], "sender_id": m.get("sender_id", ""), "deleted": True, }) cache[msg_id] = {"deleted": True, "created_at": m["created_at"]} continue # Check local cache first (ratchet keys are one-time use) cached = cache.get(msg_id) if cached and cached.get("_decrypt_failed"): if cached["_decrypt_failed"] >= _MAX_DECRYPT_RETRIES: decrypted.append({ "message_id": msg_id, "sender": "???", "text": "[Decryption failed]", "created_at": m["created_at"], "read_by": [], "sender_id": m.get("sender_id", ""), }) continue cached = None # retry decryption below if cached and not cached.get("_control"): cached["read_by"] = m.get("read_by", []) cached["delivered_to"] = m.get("delivered_to", []) cached["created_at"] = m["created_at"] if m.get("reactions"): cached["reactions"] = m["reactions"] if m.get("pinned_at"): cached["pinned_at"] = m["pinned_at"] cached["pinned_by"] = m.get("pinned_by", "") else: cached.pop("pinned_at", None) cached.pop("pinned_by", None) decrypted.append(cached) continue if cached and cached.get("_control"): continue try: msg_data = { "sender_id": m.get("sender_id", ""), "sender_device_id": m.get("sender_device_id"), "conversation_id": conv_id, "ratchet_header": m.get("ratchet_header", {}), "encrypted_content": m.get("encrypted_content", ""), "nonce": m.get("nonce", ""), "x3dh_header": m.get("x3dh_header"), "sender_chain_id": m.get("sender_chain_id"), "sender_chain_n": m.get("sender_chain_n"), } payload = self._decrypt_message(msg_data) if payload is None: _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True}, cache_key=self._cache_key) cache[msg_id] = {"_control": True} continue payload["message_id"] = msg_id payload["created_at"] = m["created_at"] payload["read_by"] = m.get("read_by", []) payload["delivered_to"] = m.get("delivered_to", []) payload["sender_id"] = m.get("sender_id", "") if m.get("reactions"): payload["reactions"] = m["reactions"] if m.get("pinned_at"): payload["pinned_at"] = m["pinned_at"] payload["pinned_by"] = m.get("pinned_by", "") decrypted.append(payload) _save_message_to_cache(self.email, conv_id, msg_id, payload, cache_key=self._cache_key) cache[msg_id] = payload if m.get("sender_id", "") != my_user_id: self._pending_self_encrypt.append({ "message_id": msg_id, "payload": {k: v for k, v in payload.items() if k not in ("message_id", "created_at", "read_by", "delivered_to", "sender_id", "deleted")}, }) except Exception as e: # Record the failure (with retry count) so the sync watermark # stops here and the message is retried on the next fetch. fails = (cache.get(msg_id) or {}).get("_decrypt_failed", 0) + 1 fail_entry = {"_decrypt_failed": fails, "created_at": m["created_at"]} cache[msg_id] = fail_entry _save_message_to_cache(self.email, conv_id, msg_id, fail_entry, cache_key=self._cache_key) decrypted.append({ "message_id": msg_id, "sender": "???", "text": f"[Decryption failed: {e}]", "created_at": m["created_at"], "read_by": [], "sender_id": m.get("sender_id", ""), }) return decrypted async def _flush_self_encrypt(self): """Upload self-encrypted copies of received messages for multi-device access.""" if not self._pending_self_encrypt or not self.identity_private: return self_key = derive_self_encryption_key(self.identity_private) updates = [] for item in list(self._pending_self_encrypt): try: plaintext = json.dumps(item["payload"], ensure_ascii=False).encode("utf-8") _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key) updates.append({ "message_id": item["message_id"], "encrypted_content": encode_binary(ct + tag), "nonce": encode_binary(nonce), }) except Exception: pass self._pending_self_encrypt.clear() if updates: try: for i in range(0, len(updates), 500): batch = updates[i:i + 500] await self.send_and_recv("reencrypt_messages", updates=batch) except Exception as e: self._logger.warning("Failed to self-encrypt received messages: %s", e) async def mark_read(self, conv_id: str, message_ids: list[str]): if not message_ids: return await self.send_and_recv("mark_read", conversation_id=conv_id, message_ids=message_ids) async def mark_conversation_read(self, conv_id: str): """Mark ALL unread messages in a conversation as read (server-side bulk).""" try: await self.send_and_recv("mark_conversation_read", conversation_id=conv_id) except Exception: pass # non-critical — don't fail message loading async def confirm_delivery(self, conv_id: str, message_ids: list[str]): """Confirm delivery of messages (fire-and-forget, non-critical).""" if not message_ids: return try: await self.send_and_recv("confirm_delivery", conversation_id=conv_id, message_ids=message_ids) except Exception: pass # non-critical def search_messages(self, conv_id: str, query: str) -> list[dict]: """Search cached messages in a conversation. Returns matching messages.""" cache = _load_message_cache(self.email, conv_id, self._cache_key) query_lower = query.lower() results = [] for msg_id, payload in cache.items(): if payload.get("deleted") or payload.get("_control") or payload.get("_sender_key"): continue text = payload.get("text", "") if query_lower in text.lower(): entry = dict(payload) entry["message_id"] = msg_id results.append(entry) results.sort(key=lambda m: m.get("created_at", "")) return results async def reset_session(self, peer_user_id: str, peer_device_id: str | None = None): """Delete local session and notify peer to do the same.""" if peer_device_id: session_key = f"{peer_user_id}:{peer_device_id}" else: session_key = peer_user_id self.sessions.pop(session_key, None) _delete_session_file(self.email, peer_user_id, peer_device_id) await self.send_and_recv("session_reset", peer_user_id=peer_user_id, peer_device_id=peer_device_id or "") def handle_session_reset_notification(self, from_user_id: str, from_device_id: str | None = None): """Handle incoming session reset notification — delete the matching session.""" if from_device_id: session_key = f"{from_user_id}:{from_device_id}" else: session_key = from_user_id self.sessions.pop(session_key, None) _delete_session_file(self.email, from_user_id, from_device_id) # ------------------------------------------------------------------ # Local message cache updates # ------------------------------------------------------------------ def load_message_cache(self, conv_id: str) -> dict: """Load cached messages for a conversation. Returns {msg_id: payload}.""" if not self.email: return {} return _load_message_cache(self.email, conv_id, self._cache_key) def update_message_in_cache(self, conv_id: str, message_id: str, updates: dict): """Update fields of a cached message on disk (synchronous).""" if not self.email: return cache = _load_message_cache(self.email, conv_id, self._cache_key) if message_id not in cache or cache[message_id].get("_control"): return for key, value in updates.items(): if value is None: cache[message_id].pop(key, None) else: cache[message_id][key] = value d = get_key_dir(self.email) / "message_cache" if self._cache_key: _save_message_cache_full(d, conv_id, cache, self._cache_key) # ------------------------------------------------------------------ # Reactions, Pins, Forwarding # ------------------------------------------------------------------ async def react_message(self, message_id: str, reaction: str, action: str = "add") -> tuple[bool, str]: """Add or remove a reaction on a message.""" resp = await self.send_and_recv("react_message", message_id=message_id, reaction=reaction, action=action) if resp["status"] == "ok": return True, "OK" return False, resp.get("data", {}).get("message", "Failed") async def pin_message(self, message_id: str, conversation_id: str, action: str = "pin") -> tuple[bool, str]: """Pin or unpin a message.""" resp = await self.send_and_recv("pin_message", message_id=message_id, conversation_id=conversation_id, action=action) if resp["status"] == "ok": return True, "OK" return False, resp.get("data", {}).get("message", "Failed") async def get_pinned_messages(self, conversation_id: str) -> list[dict]: """Get list of pinned messages for a conversation.""" resp = await self.send_and_recv("get_pinned_messages", conversation_id=conversation_id) if resp["status"] == "ok": return resp["data"].get("messages", []) return [] async def forward_message(self, target_conv_id: str, original_msg: dict, target_members: list[dict]) -> tuple[bool, str | dict]: """Forward a message to another conversation.""" text = original_msg.get("text", "") payload = { "sender": self.username, "text": text, "forwarded_from": { "sender": original_msg.get("sender", ""), "conversation_id": original_msg.get("conversation_id", ""), "message_id": original_msg.get("message_id", ""), }, "timestamp": datetime.now(timezone.utc).isoformat(), } # Forward image/file metadata (the encrypted blob is already on the server) if original_msg.get("image"): payload["image"] = original_msg["image"] if not text: payload["text"] = "" if original_msg.get("file"): payload["file"] = original_msg["file"] if not text: payload["text"] = "" plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) if self._is_group(target_members): return await self._send_group_message(target_conv_id, plaintext, target_members, payload) else: return await self._send_dm(target_conv_id, plaintext, target_members, payload) # ------------------------------------------------------------------ # Decrypt notification # ------------------------------------------------------------------ def decrypt_notification(self, notif_data: dict) -> dict | None: """Decrypt a new_message notification. Returns parsed payload or None. Supports new multi-device format (device_entries array) and legacy flat format. """ try: conv_id = notif_data.get("conversation_id", "") msg_id = notif_data.get("message_id", "") sender_id = notif_data.get("sender_id", "") sender_device_id = notif_data.get("sender_device_id") my_user_id = self.session["user_id"] if self.session else "" # Extract per-device encrypted content from device_entries or flat fields encrypted_content = "" nonce = "" ratchet_header = {} x3dh_header = None device_entries = notif_data.get("device_entries") if device_entries: # Multi-device format: pick entry matching our device_id or SELF_DEVICE_ID chosen = None self_entry = None for entry in device_entries: eid = entry.get("device_id", "") if eid == self.device_id: chosen = entry break if eid == "00000000-0000-0000-0000-000000000000": self_entry = entry # If sender is us, prefer self-encrypted entry if sender_id == my_user_id: chosen = self_entry or chosen elif not chosen: chosen = self_entry if not chosen: self._logger.warning("No matching device_entry for device %s", self.device_id) return None encrypted_content = chosen.get("encrypted_content", "") nonce = chosen.get("nonce", "") ratchet_header = chosen.get("ratchet_header") or notif_data.get("ratchet_header", {}) x3dh_header = chosen.get("x3dh_header") or notif_data.get("x3dh_header") else: # Legacy flat format encrypted_content = notif_data.get("encrypted_content", "") nonce = notif_data.get("nonce", "") ratchet_header = notif_data.get("ratchet_header", {}) x3dh_header = notif_data.get("x3dh_header") msg_data = { "sender_id": sender_id, "sender_device_id": sender_device_id, "conversation_id": conv_id, "ratchet_header": ratchet_header, "encrypted_content": encrypted_content, "nonce": nonce, "x3dh_header": x3dh_header, "sender_chain_id": notif_data.get("sender_chain_id"), "sender_chain_n": notif_data.get("sender_chain_n"), } payload = self._decrypt_message(msg_data) if payload is None: # Cache control message so get_messages skips it if msg_id and conv_id: _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True}, cache_key=self._cache_key) return None payload["conversation_id"] = conv_id payload["message_id"] = msg_id payload["sender_id"] = sender_id # Use server-compatible timestamp (no timezone suffix) for cache consistency _ts = payload.get("timestamp", "") if _ts: # Strip timezone suffix (+00:00 or Z) to match server DATETIME format _ts = _ts.replace("+00:00", "").replace("Z", "") # Strip microseconds if present if "." in _ts: _ts = _ts[:_ts.index(".")] payload["created_at"] = _ts payload["read_by"] = [] payload["delivered_to"] = [] # Cache so get_messages doesn't re-decrypt (ratchet keys are one-time) if msg_id and conv_id: _save_message_to_cache(self.email, conv_id, msg_id, payload, cache_key=self._cache_key) # Queue self-encryption for received messages (multi-device access) if sender_id != my_user_id and msg_id: self._pending_self_encrypt.append({ "message_id": msg_id, "payload": {k: v for k, v in payload.items() if k not in ("conversation_id", "message_id", "created_at", "read_by", "delivered_to", "sender_id", "deleted")}, }) return payload except IdentityKeyChanged: raise # Must propagate to caller for key-change UI except Exception as e: self._logger.warning("Failed to decrypt notification: %s", e) return None # ------------------------------------------------------------------ # Delete message # ------------------------------------------------------------------ async def delete_message(self, message_id: str) -> tuple[bool, str]: resp = await self.send_and_recv("delete_message", message_id=message_id) if resp["status"] == "ok": return True, "Message deleted." return False, resp["data"]["message"] # ------------------------------------------------------------------ # Image sharing # ------------------------------------------------------------------ async def send_image(self, conv_id: str, image_path: str, members: list[dict], reply_to: str | None = None) -> tuple[bool, str]: """Encrypt and upload an image, then send as a message.""" await self.typing_stop(conv_id, force=True) try: from PIL import Image import io except ImportError: return False, "Pillow is required for image sharing. Install with: pip install Pillow" path = Path(image_path) if not path.exists(): return False, "File not found." try: img = Image.open(path) img.load() except Exception as e: return False, f"Cannot open image: {e}" # Prepare image bytes — offload CPU-heavy work to thread def _prepare_image(img, path): """PIL resize + thumbnail + AES encrypt (runs in thread).""" original_format = img.format or "JPEG" if original_format.upper() not in ("JPEG", "PNG", "WEBP", "GIF", "BMP"): original_format = "JPEG" image_bytes = path.read_bytes() # AES-GCM overhead: 16 bytes tag. Check raw size as proxy. if MAX_IMAGE_BYTES > 0 and len(image_bytes) + 16 > MAX_IMAGE_BYTES: if img.mode not in ("RGB", "L"): img = img.convert("RGB") for quality in (92, 85, 75, 60): buf = io.BytesIO() img.save(buf, format="JPEG", quality=quality) image_bytes = buf.getvalue() if len(image_bytes) + 16 <= MAX_IMAGE_BYTES: break else: for max_dim in (3840, 2560, 1920, 1280): if max(img.size) > max_dim: img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS) buf = io.BytesIO() img.save(buf, format="JPEG", quality=75) image_bytes = buf.getvalue() if len(image_bytes) + 16 <= MAX_IMAGE_BYTES: break # Generate thumbnail thumb = img.copy() thumb.thumbnail((200, 200), Image.Resampling.LANCZOS) if thumb.mode not in ("RGB", "L"): thumb = thumb.convert("RGB") thumb_buf = io.BytesIO() thumb.save(thumb_buf, format="JPEG", quality=60) thumbnail_b64 = encode_binary(thumb_buf.getvalue()) # Encrypt img_aes_key, img_iv, img_ct, img_tag = aes_encrypt(image_bytes) encrypted_image = img_ct + img_tag return image_bytes, thumbnail_b64, img_aes_key, img_iv, encrypted_image image_bytes, thumbnail_b64, img_aes_key, img_iv, encrypted_image = \ await asyncio.get_event_loop().run_in_executor(None, _prepare_image, img, path) file_id = str(uuid.uuid4()) file_size = len(encrypted_image) # Chunked upload resp = await self.send_and_recv( "upload_image_start", conversation_id=conv_id, file_id=file_id, file_size=file_size, ) if resp["status"] != "ok": return False, resp["data"]["message"] ok, err = await self._pipelined_upload(file_id, encrypted_image) if not ok: return False, err resp = await self.send_and_recv("upload_image_end", file_id=file_id) if resp["status"] != "ok": return False, resp["data"]["message"] # Cache decrypted image locally so sender never re-downloads cache_path = self._media_cache_path(file_id) if cache_path: try: cache_path.write_bytes(image_bytes) os.chmod(cache_path, 0o600) except OSError: pass # Build message payload with image info image_info = { "file_id": file_id, "aes_key": encode_binary(img_aes_key), "iv": encode_binary(img_iv), "thumbnail": thumbnail_b64, "filename": path.name, "size": len(image_bytes), } payload = { "sender": self.username, "text": "", "reply_to": reply_to, "timestamp": datetime.now(timezone.utc).isoformat(), "image": image_info, } plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) my_user_id = self.session["user_id"] if self._is_group(members): # Group image: use sender key sk = self.sender_key_states.get(conv_id) if not sk: sk = _load_sender_key_state(self.email, conv_id, self._local_key) if not sk: sk = SenderKeyState() self.sender_key_states[conv_id] = sk _save_sender_key_state(self.email, conv_id, sk, self._local_key) self.sender_key_states[conv_id] = sk await self._catch_up_sender_key_distribution(conv_id, members, sk) result = sk.encrypt(plaintext) _save_sender_key_state(self.email, conv_id, sk, self._local_key) recipients = [] for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue recipients.append({ "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), }) # Self-encrypted copy for sender self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) resp = await self.send_and_recv( "send_message", conversation_id=conv_id, ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0}, recipients=recipients, sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])), sender_chain_n=result["n"], image_file_id=file_id, ) else: # DM image: per-device ratchet (same pattern as _send_dm) recipients = [] first_rh = None for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue try: device_bundles = await self._get_device_bundles(uid) except Exception: device_bundles = [] if not device_bundles: # Fallback: legacy single-device ratchet = await self._get_or_create_session(uid) result = ratchet.encrypt(plaintext) x3dh_h = getattr(ratchet, "_x3dh_header", None) if x3dh_h: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if x3dh_h: entry["x3dh_header"] = x3dh_h recipients.append(entry) if first_rh is None: first_rh = result["header"] _save_session(self.email, uid, ratchet, self._local_key) else: for bundle in device_bundles: dev_id = bundle.get("device_id") ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, bundle=bundle) result = ratchet.encrypt(plaintext) x3dh_h = getattr(ratchet, "_x3dh_header", None) if x3dh_h: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if dev_id: entry["device_id"] = dev_id if x3dh_h: entry["x3dh_header"] = x3dh_h recipients.append(entry) if first_rh is None: first_rh = result["header"] _save_session(self.email, uid, ratchet, self._local_key, peer_device_id=dev_id) # Encrypt self-copy with static key self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) resp = await self.send_and_recv( "send_message", conversation_id=conv_id, ratchet_header=first_rh, recipients=recipients, image_file_id=file_id, ) if resp["status"] == "ok": msg_data = resp.get("data", {}) result_msg = { **payload, "message_id": msg_data.get("message_id", ""), "created_at": msg_data.get("created_at", ""), "sender_id": self.session["user_id"], "conversation_id": conv_id, "read_by": [], } _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key) return True, result_msg return False, resp["data"]["message"] async def _pipelined_upload(self, file_id: str, encrypted_data: bytes) -> tuple[bool, str]: """Upload encrypted data in pipelined chunks (no per-chunk round-trip wait).""" file_size = len(encrypted_data) chunk_futures = [] upload_offset = 0 while upload_offset < file_size: chunk = encrypted_data[upload_offset:upload_offset + IMAGE_CHUNK_SIZE] request_id = str(uuid.uuid4()) loop = asyncio.get_running_loop() fut = loop.create_future() self._pending[request_id] = fut try: await self.writer.send_request( "upload_image_chunk", request_id=request_id, file_id=file_id, data=encode_binary(chunk), ) except Exception as e: self._pending.pop(request_id, None) return False, f"Upload failed: {e}" chunk_futures.append((request_id, fut)) upload_offset += len(chunk) for request_id, fut in chunk_futures: try: resp = await asyncio.wait_for(fut, timeout=30.0) except (asyncio.TimeoutError, ConnectionError): self._pending.pop(request_id, None) return False, "Upload chunk timed out." finally: self._pending.pop(request_id, None) if resp["status"] != "ok": return False, resp["data"]["message"] return True, "" async def send_file(self, conv_id: str, file_path: str, members: list[dict], reply_to: str | None = None) -> tuple[bool, str | dict]: """Encrypt and upload a file, then send as a message.""" await self.typing_stop(conv_id, force=True) import mimetypes path = Path(file_path) if not path.exists(): return False, "File not found." try: file_bytes = path.read_bytes() except Exception as e: return False, f"Cannot read file: {e}" mime_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream" # Encrypt file with AES-256-GCM (offload to thread) file_aes_key, file_iv, file_ct, file_tag = await asyncio.get_event_loop().run_in_executor( None, aes_encrypt, file_bytes) encrypted_file = file_ct + file_tag file_id = str(uuid.uuid4()) file_size = len(encrypted_file) # Chunked upload (reuse image upload infrastructure with file_type="file") resp = await self.send_and_recv( "upload_image_start", conversation_id=conv_id, file_id=file_id, file_size=file_size, file_type="file", ) if resp["status"] != "ok": return False, resp["data"]["message"] ok, err = await self._pipelined_upload(file_id, encrypted_file) if not ok: return False, err resp = await self.send_and_recv("upload_image_end", file_id=file_id) if resp["status"] != "ok": return False, resp["data"]["message"] # Cache decrypted file locally so sender never re-downloads cache_path = self._media_cache_path(file_id) if cache_path: try: cache_path.write_bytes(file_bytes) os.chmod(cache_path, 0o600) except OSError: pass # Build message payload with file info file_info = { "file_id": file_id, "aes_key": encode_binary(file_aes_key), "iv": encode_binary(file_iv), "filename": path.name, "size": len(file_bytes), "mime_type": mime_type, } payload = { "sender": self.username, "text": "", "reply_to": reply_to, "timestamp": datetime.now(timezone.utc).isoformat(), "file": file_info, } plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) my_user_id = self.session["user_id"] if self._is_group(members): sk = self.sender_key_states.get(conv_id) if not sk: sk = _load_sender_key_state(self.email, conv_id, self._local_key) if not sk: sk = SenderKeyState() self.sender_key_states[conv_id] = sk _save_sender_key_state(self.email, conv_id, sk, self._local_key) self.sender_key_states[conv_id] = sk await self._catch_up_sender_key_distribution(conv_id, members, sk) result = sk.encrypt(plaintext) _save_sender_key_state(self.email, conv_id, sk, self._local_key) recipients = [] for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue recipients.append({ "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), }) # Self-encrypted copy for sender self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) resp = await self.send_and_recv( "send_message", conversation_id=conv_id, ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0}, recipients=recipients, sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])), sender_chain_n=result["n"], image_file_id=file_id, ) else: # DM file: per-device ratchet (same pattern as _send_dm) recipients = [] first_rh = None for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue try: device_bundles = await self._get_device_bundles(uid) except Exception: device_bundles = [] if not device_bundles: # Fallback: legacy single-device ratchet = await self._get_or_create_session(uid) result = ratchet.encrypt(plaintext) x3dh_h = getattr(ratchet, "_x3dh_header", None) if x3dh_h: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if x3dh_h: entry["x3dh_header"] = x3dh_h recipients.append(entry) if first_rh is None: first_rh = result["header"] _save_session(self.email, uid, ratchet, self._local_key) else: for bundle in device_bundles: dev_id = bundle.get("device_id") ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, bundle=bundle) result = ratchet.encrypt(plaintext) x3dh_h = getattr(ratchet, "_x3dh_header", None) if x3dh_h: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if dev_id: entry["device_id"] = dev_id if x3dh_h: entry["x3dh_header"] = x3dh_h recipients.append(entry) if first_rh is None: first_rh = result["header"] _save_session(self.email, uid, ratchet, self._local_key, peer_device_id=dev_id) # Encrypt self-copy with static key self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) resp = await self.send_and_recv( "send_message", conversation_id=conv_id, ratchet_header=first_rh, recipients=recipients, image_file_id=file_id, ) if resp["status"] == "ok": msg_data = resp.get("data", {}) result_msg = { **payload, "message_id": msg_data.get("message_id", ""), "created_at": msg_data.get("created_at", ""), "sender_id": self.session["user_id"], "conversation_id": conv_id, "read_by": [], } _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key) return True, result_msg return False, resp["data"]["message"] async def download_file(self, file_id: str, file_info: dict) -> bytes | None: """Download and decrypt a file. Returns decrypted file bytes or None.""" return await self._download_and_decrypt(file_id, file_info) def _media_cache_path(self, file_id: str) -> Path | None: """Return path for cached decrypted media file, or None if no email.""" if not self.email: return None d = get_key_dir(self.email) / "media_cache" d.mkdir(parents=True, exist_ok=True) try: os.chmod(d, 0o700) except OSError: pass return d / f"{file_id}.bin" async def _stream_download(self, file_id: str) -> bytes | None: """Download file via streaming (single request, server sends all chunks). Falls back to legacy pipelined download if server doesn't support streaming. """ request_id = str(uuid.uuid4()) q: asyncio.Queue = asyncio.Queue() self._pending[request_id] = q try: await self.writer.send_request( "download_stream", request_id=request_id, file_id=file_id, ) except Exception: self._pending.pop(request_id, None) return None chunks: dict[int, bytes] = {} try: while True: try: resp = await asyncio.wait_for(q.get(), timeout=60.0) except (asyncio.TimeoutError, ConnectionError): return None if resp.get("status") != "ok": # Server may not support download_stream — fall back return await self._legacy_download(file_id) data = resp["data"] chunk_data = decode_binary(data["data"]) chunks[data["offset"]] = chunk_data if data.get("done"): break finally: self._pending.pop(request_id, None) # Reassemble in order parts = [] for off in sorted(chunks.keys()): parts.append(chunks[off]) return b"".join(parts) async def _legacy_download(self, file_id: str) -> bytes | None: """Fallback: download file chunk by chunk (for older servers).""" resp = await self.send_and_recv("download_image", file_id=file_id, offset=0) if resp["status"] != "ok": return None data = resp["data"] first_chunk = decode_binary(data["data"]) if data.get("done"): return first_chunk chunk_size = len(first_chunk) chunks = {0: first_chunk} # Pipeline remaining chunks futures = [] offset = chunk_size total_size = data.get("total_size", 0) # Calculate how many chunks we need while offset < total_size: request_id = str(uuid.uuid4()) loop = asyncio.get_running_loop() fut = loop.create_future() self._pending[request_id] = fut try: await self.writer.send_request( "download_image", request_id=request_id, file_id=file_id, offset=offset, ) except Exception: self._pending.pop(request_id, None) return None futures.append((request_id, offset, fut)) offset += chunk_size for request_id, off, fut in futures: try: resp = await asyncio.wait_for(fut, timeout=30.0) except (asyncio.TimeoutError, ConnectionError): self._pending.pop(request_id, None) return None finally: self._pending.pop(request_id, None) if resp["status"] != "ok": return None chunk_data = decode_binary(resp["data"]["data"]) chunks[off] = chunk_data parts = [] for off in sorted(chunks.keys()): parts.append(chunks[off]) return b"".join(parts) async def _download_and_decrypt(self, file_id: str, info: dict) -> bytes | None: """Download, decrypt, and cache a media file. Used by both image and file download.""" # Check local cache first cache_path = self._media_cache_path(file_id) if cache_path and cache_path.exists(): try: return cache_path.read_bytes() except OSError: pass encrypted_data = await self._stream_download(file_id) if not encrypted_data or len(encrypted_data) < 16: return None ciphertext = encrypted_data[:-16] tag = encrypted_data[-16:] try: aes_key = decode_binary(info["aes_key"]) iv = decode_binary(info["iv"]) decrypted = aes_decrypt(aes_key, iv, ciphertext, tag) except Exception: return None # Cache decrypted result to disk if cache_path and decrypted: try: cache_path.write_bytes(decrypted) os.chmod(cache_path, 0o600) except OSError: pass return decrypted async def download_image(self, file_id: str, image_info: dict) -> bytes | None: """Download and decrypt an image. Returns decrypted image bytes or None.""" return await self._download_and_decrypt(file_id, image_info) # ------------------------------------------------------------------ # Re-encrypt history (for device pairing) # ------------------------------------------------------------------ async def reencrypt_history(self): """Re-encrypt all cached messages with self-encryption key. After device pairing, the new device shares the same identity key but cannot decrypt old messages (Double Ratchet keys are one-time use). This re-encrypts all cached messages so they can be read using the self-encryption key derived from the shared identity key. """ if not self.identity_private or not self.session: return self_key = derive_self_encryption_key(self.identity_private) # Phase 1: Fetch & decrypt all messages to populate cache # (messages the old device never opened won't be in cache yet) try: convs = await self.list_conversations() convs = list(convs) random.shuffle(convs) total_convs = len(convs) for ci, conv in enumerate(convs): cid = conv.get("id") or conv.get("conversation_id") if not cid: continue if self._reencrypt_progress_cb: self._reencrypt_progress_cb( f"Fetching messages: {ci + 1}/{total_convs} conversations..." ) offset = 0 while True: msgs = await self.get_messages(cid, limit=200, offset=offset) if not msgs or len(msgs) < 200: break offset += len(msgs) await asyncio.sleep(random.uniform(*PAIRING_REENCRYPT_INTER_FETCH_DELAY_RANGE)) except Exception as e: self._logger.warning("Failed to fetch messages for re-encryption: %s", e) # Phase 2: Read cache and re-encrypt cache_dir = get_key_dir(self.email) / "message_cache" if not cache_dir.exists(): self._logger.info("No message cache to re-encrypt.") return all_updates = [] conv_ids = set() for f in cache_dir.iterdir(): if f.suffix in (".json", ".bin"): conv_ids.add(f.stem) total_files = len(conv_ids) conv_order = sorted(conv_ids) random.shuffle(conv_order) for i, conv_id in enumerate(conv_order): cache = _load_message_cache(self.email, conv_id, self._cache_key) if not cache: continue items = list(cache.items()) random.shuffle(items) for msg_id, entry in items: # Skip control messages (sender key distribution) if entry.get("_control"): continue # Skip entries with no useful content text = entry.get("text", "") if not text and not entry.get("image") and not entry.get("file"): continue # Rebuild plaintext from cached payload payload = {k: v for k, v in entry.items() if k not in ("message_id", "created_at", "read_by", "sender_id", "deleted")} plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) # Re-encrypt with self-encryption key _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key) all_updates.append({ "message_id": msg_id, "encrypted_content": encode_binary(ct + tag), "nonce": encode_binary(nonce), }) if self._reencrypt_progress_cb: self._reencrypt_progress_cb(f"Re-encrypting history: {i + 1}/{total_files} conversations...") if not all_updates: self._logger.info("No messages to re-encrypt.") return # Send in batches of 500 batch_size = PAIRING_REENCRYPT_BATCH_SIZE random.shuffle(all_updates) total = len(all_updates) for start in range(0, total, batch_size): batch = all_updates[start:start + batch_size] resp = await self.send_and_recv("reencrypt_messages", updates=batch) if resp["status"] != "ok": self._logger.warning("Re-encrypt batch failed: %s", resp.get("data", {}).get("message", "")) else: self._logger.info("Re-encrypted %d/%d messages.", min(start + batch_size, total), total) if start + batch_size < total: await asyncio.sleep(random.uniform(*PAIRING_REENCRYPT_INTER_BATCH_DELAY_RANGE)) if self._reencrypt_progress_cb: self._reencrypt_progress_cb(f"Re-encryption complete: {total} messages uploaded.") # ------------------------------------------------------------------ # User Profiles # ------------------------------------------------------------------ async def get_profile(self, user_id: str | None = None) -> dict | None: """Get user profile. If user_id is None, returns own profile.""" kwargs = {} if user_id: kwargs["user_id"] = user_id resp = await self.send_and_recv("get_profile", **kwargs) if resp["status"] == "ok": return resp["data"] return None async def update_profile(self, **fields) -> tuple[bool, str]: """Update own profile (phone, location, *_visible).""" resp = await self.send_and_recv("update_profile", **fields) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def update_avatar(self, image_data: bytes) -> tuple[bool, str]: """Upload avatar image.""" resp = await self.send_and_recv("update_avatar", data=encode_binary(image_data)) if resp["status"] == "ok": return True, resp["data"].get("avatar_file", "") return False, resp["data"]["message"] async def get_avatar(self, user_id: str) -> bytes | None: """Download avatar for a user.""" resp = await self.send_and_recv("get_avatar", 10.0, user_id=user_id) if resp["status"] == "ok": return decode_binary(resp["data"]["data"]) return None async def update_group_avatar(self, conv_id: str, image_data: bytes) -> tuple[bool, str]: """Upload avatar for a group conversation.""" resp = await self.send_and_recv("update_group_avatar", conversation_id=conv_id, data=encode_binary(image_data)) if resp["status"] == "ok": return True, resp["data"].get("avatar_file", "") return False, resp["data"]["message"] async def get_group_avatar(self, conv_id: str) -> bytes | None: """Download avatar for a group conversation.""" resp = await self.send_and_recv("get_group_avatar", 10.0, conversation_id=conv_id) if resp["status"] == "ok": return decode_binary(resp["data"]["data"]) return None # ------------------------------------------------------------------ # Cleanup # ------------------------------------------------------------------ async def close(self): self.connected = False for conv_id in list(self._typing_stop_tasks.keys()): self._cancel_typing_timer(conv_id) if self._listener_task: self._listener_task.cancel() if self.raw_writer: self.raw_writer.close() async def reconnect(self): """Close existing connection and re-establish: connect + re-login using in-memory keys.""" try: await self.close() except Exception: pass # Reset reader/writer but keep keys and sessions self.reader = None self.writer = None self.raw_writer = None self._listener_task = None self._pending.clear() self.login_rejected = False # Drain queues while not self._response_queue.empty(): try: self._response_queue.get_nowait() except Exception: break while not self._notification_queue.empty(): try: self._notification_queue.get_nowait() except Exception: break await self.connect() self._listener_task = asyncio.create_task(self._background_listener()) if self.email and self.private_key: # RSA challenge-response login (keys already in memory) start = await self.send_and_recv("login_start", email=self.email) if start["status"] == "ok": challenge = decode_binary(start["data"]["challenge"]) signature = rsa_sign(self.private_key, challenge) login_kwargs = { "email": self.email, "signature": encode_binary(signature), "client_version": VERSION, } if self.device_id: login_kwargs["device_id"] = self.device_id finish = await self.send_and_recv("login_finish", **login_kwargs) if finish["status"] == "ok": self.session = finish["data"] asyncio.create_task(self._ensure_prekeys()) else: # Login rejected — keys were likely rotated on another device self.session = None self.connected = False self.login_rejected = True