"""Shared network layer and ChatClient class for CLI and GUI clients. Uses X3DH + Double Ratchet for message encryption, Sender Keys for groups. RSA retained for login challenge-response only. """ import asyncio import json import logging import os import ssl import uuid from datetime import datetime, timezone from pathlib import Path from dotenv import load_dotenv load_dotenv() from crypto_utils import ( # RSA (login only) generate_rsa_keypair, serialize_private_key, serialize_public_key, load_private_key, load_public_key, rsa_sign, # Ed25519 generate_identity_keypair, serialize_ed25519_private, serialize_ed25519_private_raw, serialize_ed25519_public, load_ed25519_private, load_ed25519_public, ed25519_sign, # X25519 generate_x25519_keypair, serialize_x25519_private, serialize_x25519_public, load_x25519_private, load_x25519_public, # X3DH generate_signed_prekey, generate_one_time_prekeys, x3dh_initiate, x3dh_respond, # Double Ratchet DoubleRatchet, # Sender Keys SenderKeyState, # AES aes_encrypt, aes_decrypt, # Self-encryption derive_self_encryption_key, # Local storage encryption derive_local_storage_key, ) from protocol import ( VERSION, ProtocolReader, ProtocolWriter, encode_binary, decode_binary, build_request, MAX_MESSAGE_BYTES, IMAGE_CHUNK_SIZE, ) KEY_DIR = Path.home() / ".encrypted_chat" OPK_REPLENISH_THRESHOLD = 20 OPK_BATCH_SIZE = 50 SPK_ROTATION_DAYS = 7 def _encrypt_local(data: bytes, key: bytes) -> bytes: """Encrypt data with AES-256-GCM for local storage. Format: nonce(12) + tag(16) + ciphertext.""" _, nonce, ct, tag = aes_encrypt(data, key=key) return nonce + tag + ct def _decrypt_local(raw: bytes, key: bytes) -> bytes: """Decrypt data encrypted by _encrypt_local.""" nonce, tag, ct = raw[:12], raw[12:28], raw[28:] return aes_decrypt(key, nonce, ct, tag) def get_key_dir(email: str) -> Path: d = KEY_DIR / email d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) return d # --------------------------------------------------------------------------- # RSA key storage (login only — unchanged interface) # --------------------------------------------------------------------------- def save_keys(email: str, private_key, public_key, password: bytes | None = None): d = get_key_dir(email) (d / "private.pem").write_bytes(serialize_private_key(private_key, password=password)) (d / "public.pem").write_bytes(serialize_public_key(public_key)) os.chmod(d / "private.pem", 0o600) def load_keys(email: str, password: bytes | None = None): d = get_key_dir(email) priv_path = d / "private.pem" pub_path = d / "public.pem" if not priv_path.exists(): return None, None, "No local keys found." pem = priv_path.read_bytes() try: private_key = load_private_key(pem, password=password) except Exception: try: private_key = load_private_key(pem, password=None) if password: save_keys(email, private_key, load_public_key(pub_path.read_bytes()), password=password) except Exception: return None, None, "Invalid or missing password." public_key = load_public_key(pub_path.read_bytes()) return private_key, public_key, None # --------------------------------------------------------------------------- # Identity + prekey storage # --------------------------------------------------------------------------- def _save_identity_keys(email: str, ed_priv, ed_pub, password: bytes | None = None): d = get_key_dir(email) if password: (d / "identity_private.bin").write_bytes(serialize_ed25519_private(ed_priv, password=password)) else: (d / "identity_private.bin").write_bytes(serialize_ed25519_private_raw(ed_priv)) (d / "identity_public.bin").write_bytes(serialize_ed25519_public(ed_pub)) os.chmod(d / "identity_private.bin", 0o600) def _load_identity_keys(email: str, password: bytes | None = None): d = get_key_dir(email) priv_path = d / "identity_private.bin" pub_path = d / "identity_public.bin" if not priv_path.exists(): return None, None priv = load_ed25519_private(priv_path.read_bytes(), password=password) pub = load_ed25519_public(pub_path.read_bytes()) return priv, pub def _save_spk(email: str, spk_priv, spk_id: str): d = get_key_dir(email) (d / "spk_private.bin").write_bytes(serialize_x25519_private(spk_priv)) (d / "spk_id.txt").write_text(spk_id) os.chmod(d / "spk_private.bin", 0o600) def _load_spk(email: str): d = get_key_dir(email) priv_path = d / "spk_private.bin" id_path = d / "spk_id.txt" if not priv_path.exists(): return None, None priv = load_x25519_private(priv_path.read_bytes()) spk_id = id_path.read_text().strip() if id_path.exists() else "" return priv, spk_id def _save_prev_spk(email: str, spk_priv, spk_id: str): """Save previous SPK for grace period (in-flight X3DH may reference old SPK).""" d = get_key_dir(email) (d / "prev_spk_private.bin").write_bytes(serialize_x25519_private(spk_priv)) (d / "prev_spk_id.txt").write_text(spk_id) os.chmod(d / "prev_spk_private.bin", 0o600) def _load_prev_spk(email: str): """Load previous SPK (grace period). Returns (private_key, spk_id) or (None, None).""" d = get_key_dir(email) priv_path = d / "prev_spk_private.bin" id_path = d / "prev_spk_id.txt" if not priv_path.exists(): return None, None priv = load_x25519_private(priv_path.read_bytes()) spk_id = id_path.read_text().strip() if id_path.exists() else "" return priv, spk_id def _save_opk_private(email: str, opk_id: str, opk_priv): d = get_key_dir(email) / "opk_private" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) (d / f"{opk_id}.bin").write_bytes(serialize_x25519_private(opk_priv)) os.chmod(d / f"{opk_id}.bin", 0o600) def _load_opk_private(email: str, opk_id: str): d = get_key_dir(email) / "opk_private" p = d / f"{opk_id}.bin" if not p.exists(): return None return load_x25519_private(p.read_bytes()) def _delete_opk_private(email: str, opk_id: str): d = get_key_dir(email) / "opk_private" p = d / f"{opk_id}.bin" try: p.unlink(missing_ok=True) except Exception: pass def _save_device_id(email: str, device_id: str): d = get_key_dir(email) p = d / "device_id.txt" p.write_text(device_id) os.chmod(p, 0o600) def _load_device_id(email: str) -> str | None: d = get_key_dir(email) p = d / "device_id.txt" if not p.exists(): return None return p.read_text().strip() or None def _save_session(email: str, peer_user_id: str, ratchet: DoubleRatchet, local_key: bytes | None = None, peer_device_id: str | None = None): d = get_key_dir(email) / "sessions" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) if peer_device_id: filename = f"{peer_user_id}_{peer_device_id}.bin" else: filename = f"{peer_user_id}.bin" p = d / filename data = ratchet.export_state() if local_key: data = _encrypt_local(data, local_key) p.write_bytes(data) os.chmod(p, 0o600) def _load_session(email: str, peer_user_id: str, local_key: bytes | None = None, peer_device_id: str | None = None) -> DoubleRatchet | None: d = get_key_dir(email) / "sessions" if peer_device_id: p = d / f"{peer_user_id}_{peer_device_id}.bin" if not p.exists(): # Fallback: try old format (no device_id) and migrate p_old = d / f"{peer_user_id}.bin" if p_old.exists(): ratchet = _load_session_file(p_old, local_key) if ratchet: _save_session(email, peer_user_id, ratchet, local_key, peer_device_id=peer_device_id) try: p_old.unlink() except Exception: pass return ratchet return None else: p = d / f"{peer_user_id}.bin" if not p.exists(): return None return _load_session_file(p, local_key) def _load_session_file(p: Path, local_key: bytes | None = None) -> DoubleRatchet | None: """Load a session from a specific file path.""" if not p.exists(): return None raw = p.read_bytes() if local_key: try: data = _decrypt_local(raw, local_key) except Exception: # Migration: try loading as plaintext (old unencrypted format) try: ratchet = DoubleRatchet.import_state(raw) return ratchet except Exception: return None return DoubleRatchet.import_state(data) return DoubleRatchet.import_state(raw) def _delete_session_file(email: str, peer_user_id: str, peer_device_id: str | None = None): """Delete a session file from disk (for session reset).""" d = get_key_dir(email) / "sessions" if peer_device_id: p = d / f"{peer_user_id}_{peer_device_id}.bin" else: p = d / f"{peer_user_id}.bin" try: p.unlink(missing_ok=True) except Exception: pass def _save_sender_key_state(email: str, conv_id: str, state: SenderKeyState, local_key: bytes | None = None): d = get_key_dir(email) / "sender_keys" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) p = d / f"{conv_id}.bin" data = state.export_state() if local_key: data = _encrypt_local(data, local_key) p.write_bytes(data) os.chmod(p, 0o600) def _load_sender_key_state(email: str, conv_id: str, local_key: bytes | None = None) -> SenderKeyState | None: d = get_key_dir(email) / "sender_keys" p = d / f"{conv_id}.bin" if not p.exists(): return None raw = p.read_bytes() if local_key: try: data = _decrypt_local(raw, local_key) except Exception: try: sk = SenderKeyState.import_state(raw) _save_sender_key_state(email, conv_id, sk, local_key) return sk except Exception: return None return SenderKeyState.import_state(data) return SenderKeyState.import_state(raw) def _save_recv_sender_key(email: str, conv_id: str, sender_id: str, state: SenderKeyState, local_key: bytes | None = None, sender_device_id: str | None = None): d = get_key_dir(email) / "sender_keys_recv" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) if sender_device_id: filename = f"{conv_id}_{sender_id}_{sender_device_id}.bin" else: filename = f"{conv_id}_{sender_id}.bin" p = d / filename data = state.export_state() if local_key: data = _encrypt_local(data, local_key) p.write_bytes(data) os.chmod(p, 0o600) def _load_recv_sender_key(email: str, conv_id: str, sender_id: str, local_key: bytes | None = None, sender_device_id: str | None = None) -> SenderKeyState | None: d = get_key_dir(email) / "sender_keys_recv" if sender_device_id: p = d / f"{conv_id}_{sender_id}_{sender_device_id}.bin" if not p.exists(): # Fallback: try old format and migrate p_old = d / f"{conv_id}_{sender_id}.bin" if p_old.exists(): sk = _load_recv_sender_key_file(p_old, local_key) if sk: _save_recv_sender_key(email, conv_id, sender_id, sk, local_key, sender_device_id=sender_device_id) try: p_old.unlink() except Exception: pass return sk return None else: p = d / f"{conv_id}_{sender_id}.bin" if not p.exists(): return None return _load_recv_sender_key_file(p, local_key) def _load_recv_sender_key_file(p: Path, local_key: bytes | None = None) -> SenderKeyState | None: """Load a recv sender key from a specific file path.""" if not p.exists(): return None raw = p.read_bytes() if local_key: try: data = _decrypt_local(raw, local_key) except Exception: try: sk = SenderKeyState.import_state(raw) return sk except Exception: return None return SenderKeyState.import_state(data) return SenderKeyState.import_state(raw) # --------------------------------------------------------------------------- # Local decrypted message cache (Double Ratchet keys are one-time use) # --------------------------------------------------------------------------- def _load_message_cache(email: str, conv_id: str, cache_key: bytes | None = None) -> dict: d = get_key_dir(email) / "message_cache" p_bin = d / f"{conv_id}.bin" p_json = d / f"{conv_id}.json" # Migration: if old plaintext .json exists but encrypted .bin doesn't if p_json.exists() and not p_bin.exists(): try: cache = json.loads(p_json.read_text("utf-8")) if cache_key: _save_message_cache_full(d, conv_id, cache, cache_key) p_json.unlink(missing_ok=True) return cache except Exception: return {} if not p_bin.exists(): return {} if not cache_key: return {} try: raw = p_bin.read_bytes() # Format: nonce (12) + tag (16) + ciphertext nonce = raw[:12] tag = raw[12:28] ct = raw[28:] plaintext = aes_decrypt(cache_key, nonce, ct, tag) return json.loads(plaintext.decode("utf-8")) except Exception: return {} def _save_message_cache_full(d: Path, conv_id: str, cache: dict, cache_key: bytes): """Write the full cache dict encrypted to disk.""" d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) p = d / f"{conv_id}.bin" plaintext = json.dumps(cache, ensure_ascii=False).encode("utf-8") _key, nonce, ct, tag = aes_encrypt(plaintext, key=cache_key) p.write_bytes(nonce + tag + ct) os.chmod(p, 0o600) def _save_message_to_cache(email: str, conv_id: str, message_id: str, payload: dict, cache_key: bytes | None = None): d = get_key_dir(email) / "message_cache" cache = _load_message_cache(email, conv_id, cache_key) cache[message_id] = payload if cache_key: _save_message_cache_full(d, conv_id, cache, cache_key) else: # Fallback: plaintext (no identity key available yet) d.mkdir(parents=True, exist_ok=True) os.chmod(d, 0o700) p = d / f"{conv_id}.json" p.write_text(json.dumps(cache, ensure_ascii=False), "utf-8") os.chmod(p, 0o600) class ChatClient: def __init__(self): self.reader: ProtocolReader | None = None self.writer: ProtocolWriter | None = None self.raw_writer: asyncio.StreamWriter | None = None self.session: dict | None = None self.private_key = None # RSA private key (login only) self.public_key = None # RSA public key (login only) self.username: str = "" self.email: str = "" self._listener_task: asyncio.Task | None = None self._response_queue: asyncio.Queue = asyncio.Queue() self._notification_queue: asyncio.Queue = asyncio.Queue() self._pending: dict[str, asyncio.Future] = {} self._pairing_temp_private_key = None self._reencrypt_progress_cb = None self._logger = logging.getLogger("encrypted_chat.client") # Signal Protocol keys self.identity_private = None # Ed25519PrivateKey self.identity_public = None # Ed25519PublicKey self.spk_private = None # X25519PrivateKey (current signed prekey) self.spk_id: str = "" self._prev_spk_private = None # Previous SPK for grace period (M4) self._prev_spk_id: str = "" self.opk_privates: dict[str, object] = {} # id -> X25519PrivateKey self.sessions: dict[str, DoubleRatchet] = {} # "user_id:device_id" -> ratchet self.sender_key_states: dict[str, SenderKeyState] = {} # conv_id -> own sender key self.recv_sender_keys: dict[str, SenderKeyState] = {} # "conv_id:sender_id:device_id" -> their key # Cache: user_id -> {identity_key (Ed25519PublicKey), username, email} self._user_cache: dict[str, dict] = {} self.connected: bool = False self.login_rejected: bool = False self._cache_key: bytes | None = None # AES key for encrypting message cache on disk self._local_key: bytes | None = None # AES key for encrypting session/sender key files # Multi-device support self.device_id: str | None = None # This device's UUID self._device_bundle_cache: dict[str, tuple[float, list[dict]]] = {} # user_id -> (ts, bundles) async def connect(self): host = os.getenv("SERVER_HOST", "127.0.0.1") port = int(os.getenv("SERVER_PORT", "9999")) tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes") tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes") ssl_context = None if tls_required and not tls_enabled: raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.") if tls_enabled: insecure = os.getenv("TLS_INSECURE", "false").lower() in ("1", "true", "yes") is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") if insecure and not is_dev: raise RuntimeError("TLS_INSECURE is only allowed when ENVIRONMENT=dev") ssl_context = ssl.create_default_context() ca_file = os.getenv("TLS_CA_FILE", "").strip() if ca_file: ssl_context.load_verify_locations(cafile=ca_file) elif insecure: ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE else: self._logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.") r, w = await asyncio.open_connection(host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context) self.reader = ProtocolReader(r) self.writer = ProtocolWriter(w) self.raw_writer = w self.connected = True async def _background_listener(self): """Read messages from server, routing responses vs notifications.""" while True: msg = await self.reader.read_message() if msg is None: self.connected = False # Fail all pending futures so send_and_recv doesn't hang pending = dict(self._pending) self._pending.clear() err = ConnectionError("Server connection lost") for fut in pending.values(): if not fut.done(): fut.set_exception(err) break if msg.get("type") in ("new_message", "messages_read", "message_deleted", "conversation_created", "member_added", "member_removed", "user_online", "user_offline", "online_users", "group_invitation", "conversation_renamed", "session_reset"): await self._notification_queue.put(msg) else: req_id = msg.get("request_id") if req_id and req_id in self._pending: fut = self._pending.pop(req_id) if not fut.done(): fut.set_result(msg) else: await self._response_queue.put(msg) async def send_and_recv(self, msg_type: str, timeout: float = 30.0, **kwargs) -> dict: try: request_id = str(uuid.uuid4()) loop = asyncio.get_running_loop() fut = loop.create_future() self._pending[request_id] = fut await self.writer.send_request(msg_type, request_id=request_id, **kwargs) except (ValueError, ConnectionError, OSError) as e: self._pending.pop(request_id, None) return { "type": msg_type, "status": "error", "data": {"message": str(e) or "Connection lost."}, } try: return await asyncio.wait_for(fut, timeout=timeout) except asyncio.TimeoutError: self._logger.warning("send_and_recv timeout for '%s' after %.0fs", msg_type, timeout) return { "type": msg_type, "status": "error", "data": {"message": f"Request timed out ({msg_type})"}, } except ConnectionError: return { "type": msg_type, "status": "error", "data": {"message": "Connection lost."}, } finally: self._pending.pop(request_id, None) # ------------------------------------------------------------------ # User info / identity key cache # ------------------------------------------------------------------ async def _get_user_info(self, user_id: str = "", email: str = "") -> dict | None: """Get user info from server, cache identity key.""" cached = self._user_cache.get(user_id) if cached: return cached kwargs = {} if user_id: kwargs["user_id"] = user_id elif email: kwargs["email"] = email else: return None resp = await self.send_and_recv("get_user_info", **kwargs) if resp["status"] != "ok": return None data = resp["data"] ik_bytes = decode_binary(data["identity_key"]) if data.get("identity_key") else None info = { "user_id": data["user_id"], "username": data["username"], "email": data["email"], "identity_key": load_ed25519_public(ik_bytes) if ik_bytes else None, "identity_key_bytes": ik_bytes, } self._user_cache[data["user_id"]] = info return info # ------------------------------------------------------------------ # Registration # ------------------------------------------------------------------ async def register(self, username: str, password: str, email: str) -> tuple[bool, str]: """Register user. Generates RSA + Ed25519 + prekeys.""" self.username = username self.email = email pwd_bytes = bytearray(password.encode("utf-8")) if password else None try: # RSA keys for login priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) if priv is None: priv, pub = generate_rsa_keypair() save_keys(email, priv, pub, password=bytes(pwd_bytes) if pwd_bytes else None) self.private_key = priv self.public_key = pub # Ed25519 identity keys ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) if ed_priv is None: ed_priv, ed_pub = generate_identity_keypair() _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None) self.identity_private = ed_priv self.identity_public = ed_pub self._cache_key = derive_self_encryption_key(ed_priv) self._local_key = derive_local_storage_key(ed_priv) finally: if pwd_bytes: pwd_bytes[:] = b'\x00' * len(pwd_bytes) pub_pem = serialize_public_key(pub).decode("utf-8") ik_b64 = encode_binary(serialize_ed25519_public(ed_pub)) start = await self.send_and_recv( "register", username=username, public_key=pub_pem, email=email, identity_key=ik_b64, ) if start["status"] != "ok": return False, start["data"]["message"] code = start["data"].get("code") if code: return True, code return True, start["data"].get("message", "Check your email for the code.") async def confirm_registration(self, email: str, username: str, code: str) -> tuple[bool, str]: confirm = await self.send_and_recv("register_confirm", email=email, code=code) if confirm["status"] == "ok": # Upload prekeys immediately after registration await self._generate_and_upload_prekeys() return True, f"Registered as '{username}' (ID: {confirm['data']['user_id']})" return False, confirm["data"]["message"] async def _generate_and_upload_prekeys(self, keep_spk: bool = False): """Generate SPK + OPKs and upload to server. If keep_spk=True, re-sign the existing SPK instead of generating a new one. This is used after device pairing so both devices share the same SPK and either can respond to X3DH. """ if not self.identity_private: return if keep_spk and self.spk_private and self.spk_id: # Re-sign existing SPK (both devices share the identity key) spk_pub_bytes = serialize_x25519_public(self.spk_private.public_key()) spk_sig = ed25519_sign(self.identity_private, spk_pub_bytes) spk_data = { "id": self.spk_id, "public_key": encode_binary(spk_pub_bytes), "signature": encode_binary(spk_sig), } else: # Save current SPK as previous for grace period (M4: in-flight X3DH) if self.spk_private and self.spk_id: self._prev_spk_private = self.spk_private self._prev_spk_id = self.spk_id _save_prev_spk(self.email, self.spk_private, self.spk_id) # Generate a brand-new signed prekey spk = generate_signed_prekey(self.identity_private) self.spk_private = spk["private"] self.spk_id = spk["id"] _save_spk(self.email, spk["private"], spk["id"]) spk_data = { "id": spk["id"], "public_key": encode_binary(serialize_x25519_public(spk["public"])), "signature": encode_binary(spk["signature"]), } # Generate one-time prekeys opks = generate_one_time_prekeys(OPK_BATCH_SIZE) for opk in opks: self.opk_privates[opk["id"]] = opk["private"] _save_opk_private(self.email, opk["id"], opk["private"]) # Upload to server otp_data = [ {"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))} for opk in opks ] await self.send_and_recv( "upload_prekeys", signed_prekey=spk_data, one_time_prekeys=otp_data, ) async def _ensure_prekeys(self): """Check OPK count and SPK age, replenish/rotate if needed.""" resp = await self.send_and_recv("get_prekey_count") if resp["status"] != "ok": return count = resp["data"].get("count", 0) spk_created_at = resp["data"].get("spk_created_at", "") need_new_spk = False if spk_created_at: try: created = datetime.fromisoformat(spk_created_at) if created.tzinfo is None: created = created.replace(tzinfo=timezone.utc) age_days = (datetime.now(timezone.utc) - created).days if age_days >= SPK_ROTATION_DAYS: need_new_spk = True self._logger.info("SPK is %d days old, rotating...", age_days) except Exception: pass if count < OPK_REPLENISH_THRESHOLD or need_new_spk: if count >= OPK_REPLENISH_THRESHOLD: self._logger.info("SPK rotation triggered (OPK count OK: %d)", count) else: self._logger.info("OPK count low (%d), replenishing...", count) await self._generate_and_upload_prekeys() # ------------------------------------------------------------------ # Login # ------------------------------------------------------------------ async def login(self, email: str, password: str) -> tuple[bool, str]: """Login user. Returns (success, message).""" self.email = email pwd_bytes = bytearray(password.encode("utf-8")) if password else None try: # Load RSA keys priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) if priv is None: return False, err or "No local keys found. Register first." self.private_key = priv self.public_key = pub # Load identity keys ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) finally: if pwd_bytes: pwd_bytes[:] = b'\x00' * len(pwd_bytes) if ed_priv is not None: self.identity_private = ed_priv self.identity_public = ed_pub self._cache_key = derive_self_encryption_key(ed_priv) self._local_key = derive_local_storage_key(ed_priv) # Load SPK spk_priv, spk_id = _load_spk(email) if spk_priv: self.spk_private = spk_priv self.spk_id = spk_id # Load previous SPK for grace period (M4) prev_spk_priv, prev_spk_id = _load_prev_spk(email) if prev_spk_priv: self._prev_spk_private = prev_spk_priv self._prev_spk_id = prev_spk_id # Load device_id from disk self.device_id = _load_device_id(email) # RSA challenge-response login start = await self.send_and_recv("login_start", email=email) if start["status"] != "ok": return False, start["data"]["message"] challenge = decode_binary(start["data"]["challenge"]) signature = rsa_sign(self.private_key, challenge) login_kwargs = {"email": email, "signature": encode_binary(signature), "client_version": VERSION} if self.device_id: login_kwargs["device_id"] = self.device_id finish = await self.send_and_recv("login_finish", **login_kwargs) if finish["status"] == "ok": self.session = finish["data"] self.username = self.session.get("username", "") # Store device_id from server self.device_id = finish["data"].get("device_id") if self.device_id: _save_device_id(email, self.device_id) # Replenish prekeys in background — after pairing, the new device # has no local OPK private keys so we must generate fresh ones # (server-side OPKs have no matching private keys on this device). # Use keep_spk=True to preserve the shared SPK so both devices # can respond to X3DH. opk_dir = get_key_dir(self.email) / "opk_private" has_local_opks = opk_dir.exists() and any(opk_dir.iterdir()) if has_local_opks: asyncio.create_task(self._ensure_prekeys()) else: self._logger.info("No local OPKs (likely new device). Generating fresh OPKs, keeping SPK.") asyncio.create_task(self._generate_and_upload_prekeys(keep_spk=True)) return True, f"Logged in as '{self.username}' (ID: {self.session['user_id']})" return False, finish["data"]["message"] # ------------------------------------------------------------------ # Pairing (device pairing — transfers RSA + identity keys) # ------------------------------------------------------------------ async def pairing_start(self, email: str) -> tuple[bool, str]: """Start device pairing. Returns (success, code/message).""" temp_priv, temp_pub = generate_rsa_keypair(2048) self._pairing_temp_private_key = temp_priv temp_pub_pem = serialize_public_key(temp_pub).decode("utf-8") resp = await self.send_and_recv("pairing_start", email=email, temp_public_key=temp_pub_pem) if resp["status"] == "ok": self._pairing_poll_token = resp["data"].get("poll_token", "") return True, resp["data"]["code"] return False, resp["data"]["message"] async def pairing_wait(self, code: str, email: str, password: str, timeout: int = 300) -> tuple[bool, str]: """Wait for pairing payload and import keys. Returns (success, message).""" if not self._pairing_temp_private_key: return False, "Pairing not started." from crypto_utils import aes_decrypt as _aes_decrypt poll_token = getattr(self, "_pairing_poll_token", "") deadline = asyncio.get_event_loop().time() + timeout while asyncio.get_event_loop().time() < deadline: resp = await self.send_and_recv("pairing_poll", code=code, poll_token=poll_token) if resp["status"] != "ok": return False, resp["data"]["message"] if not resp["data"].get("ready"): await asyncio.sleep(2.0) continue payload = resp["data"]["payload"] try: # Decrypt AES key with temp RSA key from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding from cryptography.hazmat.primitives import hashes as rsa_hashes enc_aes_key = decode_binary(payload["encrypted_key"]) aes_key = self._pairing_temp_private_key.decrypt( enc_aes_key, rsa_padding.OAEP( mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()), algorithm=rsa_hashes.SHA256(), label=None, ), ) nonce = decode_binary(payload["iv"]) ct = decode_binary(payload["ciphertext"]) tag = decode_binary(payload["tag"]) keys_json = _aes_decrypt(aes_key, nonce, ct, tag) keys_data = json.loads(keys_json) pwd_bytes = bytearray(password.encode("utf-8")) if password else None try: # Import RSA key rsa_priv = load_private_key(keys_data["rsa_private"].encode(), password=None) rsa_pub = rsa_priv.public_key() save_keys(email, rsa_priv, rsa_pub, password=bytes(pwd_bytes) if pwd_bytes else None) # Import identity keys ed_priv = load_ed25519_private(bytes.fromhex(keys_data["identity_private"])) ed_pub = ed_priv.public_key() _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None) finally: if pwd_bytes: pwd_bytes[:] = b'\x00' * len(pwd_bytes) self.email = email self.private_key = rsa_priv self.public_key = rsa_pub self.identity_private = ed_priv self.identity_public = ed_pub self._cache_key = derive_self_encryption_key(ed_priv) self._local_key = derive_local_storage_key(ed_priv) self._pairing_temp_private_key = None # Multi-device: new device generates own SPK + OPKs on first # login. No session/sender key import needed — each device # has independent Double Ratchet sessions. return True, "Pairing complete." except Exception as e: return False, f"Failed to import keys: {e}" return False, "Pairing timed out." async def authorize_device(self, code: str) -> tuple[bool, str]: """Authorize a new device by sending all keys to it.""" if not self.private_key or not self.identity_private: return False, "Not logged in." claim = await self.send_and_recv("pairing_claim", code=code) if claim["status"] != "ok": return False, claim["data"]["message"] temp_pub_pem = claim["data"]["temp_public_key"].encode("utf-8") temp_pub = load_public_key(temp_pub_pem) # Phase 1: Re-encrypt message history so new device can read old # messages via self-encryption key. This also advances ratchet states # for any previously-unfetched messages. try: await self.reencrypt_history() except Exception as e: self._logger.warning("Re-encryption failed: %s", e) # Phase 2: Build keys payload — only RSA + identity key. # Multi-device: new device generates own SPK + OPKs, creates independent # sessions. No session/sender key transfer needed. keys_data = { "rsa_private": serialize_private_key(self.private_key, password=None).decode(), "identity_private": serialize_ed25519_private_raw(self.identity_private).hex(), } # Phase 3: Encrypt and send keys to new device from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding from cryptography.hazmat.primitives import hashes as rsa_hashes plaintext = json.dumps(keys_data).encode() aes_key, nonce, ct, tag = aes_encrypt(plaintext) enc_aes_key = temp_pub.encrypt( aes_key, rsa_padding.OAEP( mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()), algorithm=rsa_hashes.SHA256(), label=None, ), ) payload = { "encrypted_key": encode_binary(enc_aes_key), "iv": encode_binary(nonce), "ciphertext": encode_binary(ct), "tag": encode_binary(tag), } resp = await self.send_and_recv("pairing_send", code=code, payload=payload) if resp["status"] == "ok": return True, "Device authorized." return False, resp["data"]["message"] # ------------------------------------------------------------------ # Key rotation (RSA login key only) # ------------------------------------------------------------------ async def rotate_keys(self, username: str, password: str) -> tuple[bool, str]: """Rotate RSA keypair to revoke other devices.""" if not self.session or self.session.get("username") != username: return False, "Not logged in." pwd_bytes = password.encode("utf-8") if password else None priv, pub = generate_rsa_keypair() save_keys(self.email, priv, pub, password=pwd_bytes) self.private_key = priv self.public_key = pub pub_pem = serialize_public_key(pub).decode("utf-8") resp = await self.send_and_recv("rotate_keys", public_key=pub_pem) if resp["status"] == "ok": return True, "RSA login keys rotated." return False, resp["data"]["message"] # ------------------------------------------------------------------ # Session management (X3DH + Double Ratchet) # ------------------------------------------------------------------ async def _get_device_bundles(self, peer_user_id: str) -> list[dict]: """Get per-device key bundles for a peer. Caches for 5 minutes.""" import time cached = self._device_bundle_cache.get(peer_user_id) if cached: ts, bundles = cached if time.time() - ts < 300: return bundles resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id) if resp["status"] != "ok": raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}") data = resp["data"] ik_b64 = data.get("identity_key", "") device_bundles = data.get("device_bundles") if device_bundles: # Attach identity_key to each bundle for b in device_bundles: b["identity_key"] = ik_b64 else: # Old server: wrap flat response as single-entry list device_bundles = [{ "device_id": None, "identity_key": ik_b64, "signed_prekey_id": data.get("signed_prekey_id", ""), "signed_prekey": data.get("signed_prekey", ""), "spk_signature": data.get("spk_signature", ""), "one_time_prekey_id": data.get("one_time_prekey_id"), "one_time_prekey": data.get("one_time_prekey"), }] self._device_bundle_cache[peer_user_id] = (time.time(), device_bundles) return device_bundles async def _get_or_create_session(self, peer_user_id: str, peer_device_id: str | None = None, bundle: dict | None = None) -> DoubleRatchet: """Load existing session or create one via X3DH. If peer_device_id is set, sessions are keyed by "user_id:device_id". If bundle is provided, it's used instead of fetching from server. """ session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id # Check in-memory cache if session_key in self.sessions: return self.sessions[session_key] # Check on disk ratchet = _load_session(self.email, peer_user_id, self._local_key, peer_device_id=peer_device_id) if ratchet: self.sessions[session_key] = ratchet return ratchet # Create new session via X3DH if not bundle: resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id) if resp["status"] != "ok": raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}") bundle = resp["data"] ik_remote_bytes = decode_binary(bundle["identity_key"]) ik_remote = load_ed25519_public(ik_remote_bytes) spk_remote = load_x25519_public(decode_binary(bundle["signed_prekey"])) spk_sig = decode_binary(bundle["spk_signature"]) opk_remote = None opk_id = bundle.get("one_time_prekey_id") if bundle.get("one_time_prekey"): opk_remote = load_x25519_public(decode_binary(bundle["one_time_prekey"])) # Perform X3DH shared_secret, ek_priv, ek_pub = x3dh_initiate( self.identity_private, ik_remote, spk_remote, spk_sig, opk_remote, ) # Initialize Double Ratchet as Alice ratchet = DoubleRatchet.init_alice(shared_secret, spk_remote) self.sessions[session_key] = ratchet _save_session(self.email, peer_user_id, ratchet, self._local_key, peer_device_id=peer_device_id) # Build X3DH header for first message x3dh_header = { "ik": encode_binary(serialize_ed25519_public(self.identity_public)), "ek": encode_binary(serialize_x25519_public(ek_pub)), } if opk_id: x3dh_header["opk_id"] = opk_id # Cache the x3dh header for the next send_message call ratchet._x3dh_header = x3dh_header # Cache remote user info self._user_cache[peer_user_id] = { "user_id": peer_user_id, "identity_key": ik_remote, "identity_key_bytes": ik_remote_bytes, } return ratchet def _process_x3dh_header(self, sender_id: str, x3dh_header: dict, sender_device_id: str | None = None, spk_override=None) -> DoubleRatchet: """Process an incoming X3DH header to establish session as Bob. Args: spk_override: If provided, use this SPK private key instead of self.spk_private. Used for grace period fallback (M4). """ ik_remote_bytes = decode_binary(x3dh_header["ik"]) ik_remote = load_ed25519_public(ik_remote_bytes) ek_remote = load_x25519_public(decode_binary(x3dh_header["ek"])) opk_id = x3dh_header.get("opk_id") opk_priv = None if opk_id: opk_priv = _load_opk_private(self.email, opk_id) if opk_priv: _delete_opk_private(self.email, opk_id) spk_priv = spk_override if spk_override else self.spk_private shared_secret = x3dh_respond( self.identity_private, spk_priv, ik_remote, ek_remote, opk_priv, ) spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub)) session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id self.sessions[session_key] = ratchet _save_session(self.email, sender_id, ratchet, self._local_key, peer_device_id=sender_device_id) self._user_cache[sender_id] = { "user_id": sender_id, "identity_key": ik_remote, "identity_key_bytes": ik_remote_bytes, } return ratchet # ------------------------------------------------------------------ # Conversations # ------------------------------------------------------------------ async def create_conversation(self, member_emails: list[str], name: str | None = None) -> tuple[str | None, str]: kwargs = {"members": member_emails} if name: kwargs["name"] = name resp = await self.send_and_recv("create_conversation", **kwargs) if resp["status"] == "ok": return resp["data"]["conversation_id"], "OK" return None, resp["data"]["message"] async def remove_member(self, conv_id: str, user_id: str) -> tuple[bool, str]: resp = await self.send_and_recv("remove_member", conversation_id=conv_id, user_id=user_id) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def leave_group(self, conv_id: str) -> tuple[bool, str]: """Leave a group conversation.""" resp = await self.send_and_recv("leave_group", conversation_id=conv_id) if resp["status"] == "ok": # Clean up local sender key state for this group self.sender_key_states.pop(conv_id, None) # Remove received sender keys for this conversation to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")] for k in to_remove: self.recv_sender_keys.pop(k, None) return True, "OK" return False, resp["data"]["message"] async def rename_conversation(self, conv_id: str, name: str) -> tuple[bool, str]: """Rename a group conversation (creator only).""" resp = await self.send_and_recv("rename_conversation", conversation_id=conv_id, name=name) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def delete_conversation(self, conv_id: str) -> tuple[bool, str]: """Delete a conversation (leave + server cleans up if empty).""" resp = await self.send_and_recv("delete_conversation", conversation_id=conv_id) if resp["status"] == "ok": # Clean up local sender key state self.sender_key_states.pop(conv_id, None) to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")] for k in to_remove: self.recv_sender_keys.pop(k, None) return True, "OK" return False, resp["data"]["message"] async def add_member(self, conv_id: str, email: str) -> tuple[bool, str]: resp = await self.send_and_recv("add_member", conversation_id=conv_id, email=email) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def accept_invitation(self, conv_id: str) -> tuple[bool, str]: """Accept a group invitation.""" resp = await self.send_and_recv("accept_invitation", conversation_id=conv_id) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def decline_invitation(self, conv_id: str) -> tuple[bool, str]: """Decline a group invitation.""" resp = await self.send_and_recv("decline_invitation", conversation_id=conv_id) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def list_invitations(self) -> list[dict]: """List pending group invitations.""" resp = await self.send_and_recv("list_invitations") if resp["status"] == "ok": return resp["data"]["invitations"] return [] async def list_conversations(self) -> list[dict]: resp = await self.send_and_recv("list_conversations") if resp["status"] == "ok": return resp["data"]["conversations"] return [] async def find_or_create_conversation(self, email: str) -> tuple[str | None, str]: resp = await self.send_and_recv("find_conversation", email=email) if resp["status"] != "ok": return None, resp["data"]["message"] conv_id = resp["data"]["conversation_id"] if conv_id: return conv_id, "OK" return await self.create_conversation([email]) # ------------------------------------------------------------------ # Send message # ------------------------------------------------------------------ def _is_group(self, members: list[dict]) -> bool: return len(members) > 2 async def send_message(self, conv_id: str, text: str, members: list[dict], reply_to: str | None = None) -> tuple[bool, str]: """Encrypt and send a message. DM: per-recipient Double Ratchet. Group: Sender Keys.""" my_user_id = self.session["user_id"] # Build plaintext payload payload = { "sender": self.username, "text": text, "reply_to": reply_to, "timestamp": datetime.now(timezone.utc).isoformat(), } plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8") if self._is_group(members): return await self._send_group_message(conv_id, plaintext, members) else: return await self._send_dm(conv_id, plaintext, members) async def _send_dm(self, conv_id: str, plaintext: bytes, members: list[dict]) -> tuple[bool, str]: """Encrypt DM with per-device Double Ratchet.""" my_user_id = self.session["user_id"] recipients = [] first_ratchet_header = None for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue # Get all device bundles for this user try: device_bundles = await self._get_device_bundles(uid) self._logger.debug("Got %d device bundles for %s", len(device_bundles), uid) except Exception as e: self._logger.warning("Failed to get device bundles for %s: %s", uid, e) device_bundles = [] if not device_bundles: # Fallback: try single session (legacy peer) ratchet = await self._get_or_create_session(uid) result = ratchet.encrypt(plaintext) x3dh_hdr = getattr(ratchet, "_x3dh_header", None) if x3dh_hdr: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if x3dh_hdr: entry["x3dh_header"] = x3dh_hdr recipients.append(entry) if first_ratchet_header is None: first_ratchet_header = result["header"] _save_session(self.email, uid, ratchet, self._local_key) continue for bundle in device_bundles: dev_id = bundle.get("device_id") ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, bundle=bundle) result = ratchet.encrypt(plaintext) x3dh_hdr = getattr(ratchet, "_x3dh_header", None) if x3dh_hdr: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if dev_id: entry["device_id"] = dev_id if x3dh_hdr: entry["x3dh_header"] = x3dh_hdr recipients.append(entry) if first_ratchet_header is None: first_ratchet_header = result["header"] _save_session(self.email, uid, ratchet, self._local_key, peer_device_id=dev_id) # Encrypt self-copy with static key derived from identity (not ratchet) # Uses SELF_DEVICE_ID so all own devices can read it self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) if not recipients: return False, "No recipients." kwargs = { "conversation_id": conv_id, "ratchet_header": first_ratchet_header, "recipients": recipients, } resp = await self.send_and_recv("send_message", **kwargs) if resp["status"] == "ok": return True, "Message sent." return False, resp["data"]["message"] async def _send_group_message(self, conv_id: str, plaintext: bytes, members: list[dict]) -> tuple[bool, str]: """Encrypt group message with Sender Keys.""" my_user_id = self.session["user_id"] # Get or create sender key for this group sk = self.sender_key_states.get(conv_id) if not sk: sk = _load_sender_key_state(self.email, conv_id, self._local_key) if not sk: sk = SenderKeyState() self.sender_key_states[conv_id] = sk _save_sender_key_state(self.email, conv_id, sk, self._local_key) # Distribute sender key to all members via pairwise ratchet await self._distribute_sender_key(conv_id, members, sk) self.sender_key_states[conv_id] = sk # Encrypt with sender key result = sk.encrypt(plaintext) _save_sender_key_state(self.email, conv_id, sk, self._local_key) # Build per-recipient entries (same ciphertext for all except self) recipients = [] for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue recipients.append({ "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), }) # Self-encrypted copy (so other devices + history fetch can decrypt) self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) ratchet_header = {"dh_pub": "00" * 32, "n": 0, "pn": 0} # Dummy for groups kwargs = { "conversation_id": conv_id, "ratchet_header": ratchet_header, "recipients": recipients, "sender_chain_id": encode_binary(bytes.fromhex(result["chain_id"])), "sender_chain_n": result["n"], } resp = await self.send_and_recv("send_message", **kwargs) if resp["status"] == "ok": return True, "Message sent." return False, resp["data"]["message"] async def _distribute_sender_key(self, conv_id: str, members: list[dict], sk: SenderKeyState): """Send own sender key to all group members via pairwise Double Ratchet (per-device).""" my_user_id = self.session["user_id"] exported_key = sk.export_key() # Build a special "sender_key_distribution" payload payload = { "sender": self.username, "text": "", "reply_to": None, "timestamp": datetime.now(timezone.utc).isoformat(), "_sender_key": { "conv_id": conv_id, "key": encode_binary(exported_key), "sender_device_id": self.device_id, }, } plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8") # Send as DM to each member's devices (per-device encryption) for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue try: # Get all device bundles for this user try: device_bundles = await self._get_device_bundles(uid) except Exception: device_bundles = [] if not device_bundles: # Fallback: legacy single-device ratchet = await self._get_or_create_session(uid) result = ratchet.encrypt(plaintext) x3dh_header = getattr(ratchet, "_x3dh_header", None) if x3dh_header: delattr(ratchet, "_x3dh_header") recipient_entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if x3dh_header: recipient_entry["x3dh_header"] = x3dh_header kwargs = { "conversation_id": conv_id, "ratchet_header": result["header"], "recipients": [recipient_entry], } await self.send_and_recv("send_message", **kwargs) _save_session(self.email, uid, ratchet, self._local_key) else: # Per-device encryption recipients = [] first_rh = None for bundle in device_bundles: dev_id = bundle.get("device_id") ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, bundle=bundle) result = ratchet.encrypt(plaintext) x3dh_header = getattr(ratchet, "_x3dh_header", None) if x3dh_header: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if dev_id: entry["device_id"] = dev_id if x3dh_header: entry["x3dh_header"] = x3dh_header recipients.append(entry) if first_rh is None: first_rh = result["header"] _save_session(self.email, uid, ratchet, self._local_key, peer_device_id=dev_id) kwargs = { "conversation_id": conv_id, "ratchet_header": first_rh, "recipients": recipients, } await self.send_and_recv("send_message", **kwargs) except Exception as e: self._logger.warning("Failed to distribute sender key to %s: %s", uid, e) # ------------------------------------------------------------------ # Decrypt messages # ------------------------------------------------------------------ def _decrypt_message(self, msg_data: dict) -> dict: """Decrypt a single message (DM or group).""" # Check for self-encrypted marker FIRST — after re-encryption, # group messages will have {"self": true} ratchet_header but still # have sender_chain_id at message level. rh = msg_data.get("ratchet_header", {}) if isinstance(rh, dict) and rh.get("self"): return self._decrypt_dm(msg_data) if msg_data.get("sender_chain_id"): return self._decrypt_group(msg_data) else: return self._decrypt_dm(msg_data) def _decrypt_dm(self, msg_data: dict) -> dict: """Decrypt DM using Double Ratchet with sender, or static key for self-copies.""" sender_id = msg_data.get("sender_id", "") sender_device_id = msg_data.get("sender_device_id") ratchet_header = msg_data.get("ratchet_header", {}) ct_b64 = msg_data.get("encrypted_content", "") nonce_b64 = msg_data.get("nonce", "") if not ct_b64 or not nonce_b64: raise ValueError("Missing ciphertext or nonce") ciphertext = decode_binary(ct_b64) nonce = decode_binary(nonce_b64) # Self-encrypted message (own sent message copy) if isinstance(ratchet_header, dict) and ratchet_header.get("self"): self_key = derive_self_encryption_key(self.identity_private) ct = ciphertext[:-16] tag = ciphertext[-16:] plaintext = aes_decrypt(self_key, nonce, ct, tag) else: x3dh_header = msg_data.get("x3dh_header") # Session key: "sender_id:sender_device_id" or just "sender_id" for legacy session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id # Try to load existing session ratchet = self.sessions.get(session_key) if not ratchet: ratchet = _load_session(self.email, sender_id, self._local_key, peer_device_id=sender_device_id) if ratchet: self.sessions[session_key] = ratchet if ratchet and not x3dh_header: # Normal case: existing session, no X3DH header plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) _save_session(self.email, sender_id, ratchet, self._local_key, peer_device_id=sender_device_id) elif x3dh_header: if ratchet: # Existing session + X3DH header: sender may have reset. backup = ratchet.export_state() try: plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) _save_session(self.email, sender_id, ratchet, self._local_key, peer_device_id=sender_device_id) except Exception: restored = DoubleRatchet.import_state(backup) self.sessions[session_key] = restored _save_session(self.email, sender_id, restored, self._local_key, peer_device_id=sender_device_id) ratchet = self._process_x3dh_header(sender_id, x3dh_header, sender_device_id=sender_device_id) try: plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) except Exception: if self._prev_spk_private: ratchet = self._process_x3dh_header( sender_id, x3dh_header, sender_device_id=sender_device_id, spk_override=self._prev_spk_private) plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) else: raise _save_session(self.email, sender_id, ratchet, self._local_key, peer_device_id=sender_device_id) else: ratchet = self._process_x3dh_header(sender_id, x3dh_header, sender_device_id=sender_device_id) try: plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) except Exception: if self._prev_spk_private: ratchet = self._process_x3dh_header( sender_id, x3dh_header, sender_device_id=sender_device_id, spk_override=self._prev_spk_private) plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) else: raise _save_session(self.email, sender_id, ratchet, self._local_key, peer_device_id=sender_device_id) else: raise ValueError(f"No session for sender {sender_id}") payload = json.loads(plaintext) # Handle sender key distribution messages if "_sender_key" in payload: sk_data = payload["_sender_key"] sk_conv_id = sk_data["conv_id"] sk_key = decode_binary(sk_data["key"]) sk_sender_device_id = sk_data.get("sender_device_id") recv_sk = SenderKeyState.from_key(sk_key) if sk_sender_device_id: cache_key = f"{sk_conv_id}:{sender_id}:{sk_sender_device_id}" else: cache_key = f"{sk_conv_id}:{sender_id}" self.recv_sender_keys[cache_key] = recv_sk _save_recv_sender_key(self.email, sk_conv_id, sender_id, recv_sk, self._local_key, sender_device_id=sk_sender_device_id) # Return empty — this is a control message, not user-visible return None return payload def _decrypt_group(self, msg_data: dict) -> dict: """Decrypt group message using sender's Sender Key.""" sender_id = msg_data.get("sender_id", "") sender_device_id = msg_data.get("sender_device_id") conv_id = msg_data.get("conversation_id", "") chain_id_b64 = msg_data.get("sender_chain_id", "") chain_n = msg_data.get("sender_chain_n", 0) ct_b64 = msg_data.get("encrypted_content", "") nonce_b64 = msg_data.get("nonce", "") if not ct_b64 or not nonce_b64 or not chain_id_b64: raise ValueError("Missing group message fields") ciphertext = decode_binary(ct_b64) nonce = decode_binary(nonce_b64) chain_id = decode_binary(chain_id_b64) my_user_id = self.session["user_id"] # If we sent this message, use our own sender key if sender_id == my_user_id: sk = self.sender_key_states.get(conv_id) if not sk: sk = _load_sender_key_state(self.email, conv_id, self._local_key) if sk: self.sender_key_states[conv_id] = sk if not sk: raise ValueError("Own sender key not found") # For our own messages, we can't decrypt from sender key (it's already advanced) # Return a placeholder — the server echoed our ciphertext raise ValueError("Cannot decrypt own group message from sender key") # Use received sender key — try with sender_device_id first, fall back to without sk = None if sender_device_id: cache_key = f"{conv_id}:{sender_id}:{sender_device_id}" sk = self.recv_sender_keys.get(cache_key) if not sk: sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key, sender_device_id=sender_device_id) if sk: self.recv_sender_keys[cache_key] = sk if not sk: # Fallback: try without device_id (legacy or same-device) cache_key = f"{conv_id}:{sender_id}" sk = self.recv_sender_keys.get(cache_key) if not sk: sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key) if sk: self.recv_sender_keys[cache_key] = sk if not sk: raise ValueError(f"No sender key for {sender_id} in conversation {conv_id}") plaintext = sk.decrypt(chain_id.hex(), chain_n, ciphertext, nonce) _save_recv_sender_key(self.email, conv_id, sender_id, sk, self._local_key, sender_device_id=sender_device_id) return json.loads(plaintext) # ------------------------------------------------------------------ # Get/decrypt messages (batch) # ------------------------------------------------------------------ async def get_messages(self, conv_id: str, limit: int = 50, offset: int = 0) -> list[dict]: resp = await self.send_and_recv("get_messages", conversation_id=conv_id, limit=limit, offset=offset) if resp["status"] != "ok": return [] cache = _load_message_cache(self.email, conv_id, self._cache_key) decrypted = [] message_ids = [] raw_messages = resp["data"]["messages"] raw_messages.reverse() # Server returns DESC, reverse to ASC for m in raw_messages: msg_id = m["message_id"] message_ids.append(msg_id) if m.get("deleted_at"): decrypted.append({ "message_id": msg_id, "sender": "", "text": "", "created_at": m["created_at"], "read_by": [], "sender_id": m.get("sender_id", ""), "deleted": True, }) continue # Check local cache first (ratchet keys are one-time use) cached = cache.get(msg_id) if cached: cached["read_by"] = m.get("read_by", []) cached["created_at"] = m["created_at"] if cached.get("_control"): continue # Skip control messages decrypted.append(cached) continue try: msg_data = { "sender_id": m.get("sender_id", ""), "sender_device_id": m.get("sender_device_id"), "conversation_id": conv_id, "ratchet_header": m.get("ratchet_header", {}), "encrypted_content": m.get("encrypted_content", ""), "nonce": m.get("nonce", ""), "x3dh_header": m.get("x3dh_header"), "sender_chain_id": m.get("sender_chain_id"), "sender_chain_n": m.get("sender_chain_n"), } payload = self._decrypt_message(msg_data) if payload is None: # Control message (sender key distribution) — cache and skip _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True}, cache_key=self._cache_key) continue payload["message_id"] = msg_id payload["created_at"] = m["created_at"] payload["read_by"] = m.get("read_by", []) payload["sender_id"] = m.get("sender_id", "") decrypted.append(payload) # Cache the decrypted payload (without read_by which changes) cache_entry = {k: v for k, v in payload.items() if k != "read_by"} _save_message_to_cache(self.email, conv_id, msg_id, cache_entry, cache_key=self._cache_key) except Exception as e: decrypted.append({ "message_id": msg_id, "sender": "???", "text": f"[Decryption failed: {e}]", "created_at": m["created_at"], "read_by": [], }) if message_ids: await self.mark_read(conv_id, message_ids) return decrypted async def mark_read(self, conv_id: str, message_ids: list[str]): if not message_ids: return await self.send_and_recv("mark_read", conversation_id=conv_id, message_ids=message_ids) def search_messages(self, conv_id: str, query: str) -> list[dict]: """Search cached messages in a conversation. Returns matching messages.""" cache = _load_message_cache(self.email, conv_id, self._cache_key) query_lower = query.lower() results = [] for msg_id, payload in cache.items(): if payload.get("deleted") or payload.get("_control") or payload.get("_sender_key"): continue text = payload.get("text", "") if query_lower in text.lower(): entry = dict(payload) entry["message_id"] = msg_id results.append(entry) results.sort(key=lambda m: m.get("created_at", "")) return results async def reset_session(self, peer_user_id: str, peer_device_id: str | None = None): """Delete local session and notify peer to do the same.""" if peer_device_id: session_key = f"{peer_user_id}:{peer_device_id}" else: session_key = peer_user_id self.sessions.pop(session_key, None) _delete_session_file(self.email, peer_user_id, peer_device_id) await self.send_and_recv("session_reset", peer_user_id=peer_user_id, peer_device_id=peer_device_id or "") def handle_session_reset_notification(self, from_user_id: str, from_device_id: str | None = None): """Handle incoming session reset notification — delete the matching session.""" if from_device_id: session_key = f"{from_user_id}:{from_device_id}" else: session_key = from_user_id self.sessions.pop(session_key, None) _delete_session_file(self.email, from_user_id, from_device_id) # ------------------------------------------------------------------ # Decrypt notification # ------------------------------------------------------------------ def decrypt_notification(self, notif_data: dict) -> dict | None: """Decrypt a new_message notification. Returns parsed payload or None. Supports new multi-device format (device_entries array) and legacy flat format. """ try: conv_id = notif_data.get("conversation_id", "") msg_id = notif_data.get("message_id", "") sender_id = notif_data.get("sender_id", "") sender_device_id = notif_data.get("sender_device_id") my_user_id = self.session["user_id"] if self.session else "" # Extract per-device encrypted content from device_entries or flat fields encrypted_content = "" nonce = "" ratchet_header = {} x3dh_header = None device_entries = notif_data.get("device_entries") if device_entries: # Multi-device format: pick entry matching our device_id or SELF_DEVICE_ID chosen = None self_entry = None for entry in device_entries: eid = entry.get("device_id", "") if eid == self.device_id: chosen = entry break if eid == "00000000-0000-0000-0000-000000000000": self_entry = entry # If sender is us, prefer self-encrypted entry if sender_id == my_user_id: chosen = self_entry or chosen elif not chosen: chosen = self_entry if not chosen: self._logger.warning("No matching device_entry for device %s", self.device_id) return None encrypted_content = chosen.get("encrypted_content", "") nonce = chosen.get("nonce", "") ratchet_header = chosen.get("ratchet_header") or notif_data.get("ratchet_header", {}) x3dh_header = chosen.get("x3dh_header") or notif_data.get("x3dh_header") else: # Legacy flat format encrypted_content = notif_data.get("encrypted_content", "") nonce = notif_data.get("nonce", "") ratchet_header = notif_data.get("ratchet_header", {}) x3dh_header = notif_data.get("x3dh_header") msg_data = { "sender_id": sender_id, "sender_device_id": sender_device_id, "conversation_id": conv_id, "ratchet_header": ratchet_header, "encrypted_content": encrypted_content, "nonce": nonce, "x3dh_header": x3dh_header, "sender_chain_id": notif_data.get("sender_chain_id"), "sender_chain_n": notif_data.get("sender_chain_n"), } payload = self._decrypt_message(msg_data) if payload is None: # Cache control message so get_messages skips it if msg_id and conv_id: _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True}, cache_key=self._cache_key) return None payload["conversation_id"] = conv_id payload["message_id"] = msg_id payload["sender_id"] = sender_id payload["created_at"] = payload.get("timestamp", "") payload["read_by"] = [] # Cache so get_messages doesn't re-decrypt (ratchet keys are one-time) if msg_id and conv_id: cache_entry = {k: v for k, v in payload.items() if k != "read_by"} _save_message_to_cache(self.email, conv_id, msg_id, cache_entry, cache_key=self._cache_key) return payload except Exception as e: self._logger.warning("Failed to decrypt notification: %s", e) return None # ------------------------------------------------------------------ # Delete message # ------------------------------------------------------------------ async def delete_message(self, message_id: str) -> tuple[bool, str]: resp = await self.send_and_recv("delete_message", message_id=message_id) if resp["status"] == "ok": return True, "Message deleted." return False, resp["data"]["message"] # ------------------------------------------------------------------ # Image sharing # ------------------------------------------------------------------ async def send_image(self, conv_id: str, image_path: str, members: list[dict], reply_to: str | None = None) -> tuple[bool, str]: """Encrypt and upload an image, then send as a message.""" try: from PIL import Image import io except ImportError: return False, "Pillow is required for image sharing. Install with: pip install Pillow" path = Path(image_path) if not path.exists(): return False, "File not found." try: img = Image.open(path) img.load() except Exception as e: return False, f"Cannot open image: {e}" if img.mode not in ("RGB", "L"): img = img.convert("RGB") max_dim = 1920 if max(img.size) > max_dim: img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS) buf = io.BytesIO() img.save(buf, format="JPEG", quality=85) image_bytes = buf.getvalue() thumb = img.copy() thumb.thumbnail((200, 200), Image.Resampling.LANCZOS) thumb_buf = io.BytesIO() thumb.save(thumb_buf, format="JPEG", quality=60) thumbnail_b64 = encode_binary(thumb_buf.getvalue()) # Encrypt image with AES-256-GCM img_aes_key, img_iv, img_ct, img_tag = aes_encrypt(image_bytes) encrypted_image = img_ct + img_tag file_id = str(uuid.uuid4()) file_size = len(encrypted_image) # Chunked upload resp = await self.send_and_recv( "upload_image_start", conversation_id=conv_id, file_id=file_id, file_size=file_size, ) if resp["status"] != "ok": return False, resp["data"]["message"] upload_offset = 0 while upload_offset < file_size: chunk = encrypted_image[upload_offset:upload_offset + IMAGE_CHUNK_SIZE] resp = await self.send_and_recv( "upload_image_chunk", file_id=file_id, data=encode_binary(chunk), ) if resp["status"] != "ok": return False, resp["data"]["message"] upload_offset += len(chunk) resp = await self.send_and_recv("upload_image_end", file_id=file_id) if resp["status"] != "ok": return False, resp["data"]["message"] # Build message payload with image info image_info = { "file_id": file_id, "aes_key": encode_binary(img_aes_key), "iv": encode_binary(img_iv), "thumbnail": thumbnail_b64, "filename": path.name, "size": len(image_bytes), } payload = { "sender": self.username, "text": "", "reply_to": reply_to, "timestamp": datetime.now(timezone.utc).isoformat(), "image": image_info, } plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8") my_user_id = self.session["user_id"] if self._is_group(members): # Group image: use sender key sk = self.sender_key_states.get(conv_id) if not sk: sk = _load_sender_key_state(self.email, conv_id, self._local_key) if not sk: sk = SenderKeyState() self.sender_key_states[conv_id] = sk _save_sender_key_state(self.email, conv_id, sk, self._local_key) await self._distribute_sender_key(conv_id, members, sk) result = sk.encrypt(plaintext) _save_sender_key_state(self.email, conv_id, sk, self._local_key) recipients = [] for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue recipients.append({ "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), }) # Self-encrypted copy for sender self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) resp = await self.send_and_recv( "send_message", conversation_id=conv_id, ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0}, recipients=recipients, sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])), sender_chain_n=result["n"], image_file_id=file_id, ) else: # DM image: per-device ratchet (same pattern as _send_dm) recipients = [] first_rh = None for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue try: device_bundles = await self._get_device_bundles(uid) except Exception: device_bundles = [] if not device_bundles: # Fallback: legacy single-device ratchet = await self._get_or_create_session(uid) result = ratchet.encrypt(plaintext) x3dh_h = getattr(ratchet, "_x3dh_header", None) if x3dh_h: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if x3dh_h: entry["x3dh_header"] = x3dh_h recipients.append(entry) if first_rh is None: first_rh = result["header"] _save_session(self.email, uid, ratchet, self._local_key) else: for bundle in device_bundles: dev_id = bundle.get("device_id") ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, bundle=bundle) result = ratchet.encrypt(plaintext) x3dh_h = getattr(ratchet, "_x3dh_header", None) if x3dh_h: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if dev_id: entry["device_id"] = dev_id if x3dh_h: entry["x3dh_header"] = x3dh_h recipients.append(entry) if first_rh is None: first_rh = result["header"] _save_session(self.email, uid, ratchet, self._local_key, peer_device_id=dev_id) # Encrypt self-copy with static key self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) resp = await self.send_and_recv( "send_message", conversation_id=conv_id, ratchet_header=first_rh, recipients=recipients, image_file_id=file_id, ) if resp["status"] == "ok": return True, "Image sent." return False, resp["data"]["message"] async def send_file(self, conv_id: str, file_path: str, members: list[dict], reply_to: str | None = None) -> tuple[bool, str]: """Encrypt and upload a file, then send as a message.""" import mimetypes path = Path(file_path) if not path.exists(): return False, "File not found." try: file_bytes = path.read_bytes() except Exception as e: return False, f"Cannot read file: {e}" mime_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream" # Encrypt file with AES-256-GCM file_aes_key, file_iv, file_ct, file_tag = aes_encrypt(file_bytes) encrypted_file = file_ct + file_tag file_id = str(uuid.uuid4()) file_size = len(encrypted_file) # Chunked upload (reuse image upload infrastructure with file_type="file") resp = await self.send_and_recv( "upload_image_start", conversation_id=conv_id, file_id=file_id, file_size=file_size, file_type="file", ) if resp["status"] != "ok": return False, resp["data"]["message"] upload_offset = 0 while upload_offset < file_size: chunk = encrypted_file[upload_offset:upload_offset + IMAGE_CHUNK_SIZE] resp = await self.send_and_recv( "upload_image_chunk", file_id=file_id, data=encode_binary(chunk), ) if resp["status"] != "ok": return False, resp["data"]["message"] upload_offset += len(chunk) resp = await self.send_and_recv("upload_image_end", file_id=file_id) if resp["status"] != "ok": return False, resp["data"]["message"] # Build message payload with file info file_info = { "file_id": file_id, "aes_key": encode_binary(file_aes_key), "iv": encode_binary(file_iv), "filename": path.name, "size": len(file_bytes), "mime_type": mime_type, } payload = { "sender": self.username, "text": "", "reply_to": reply_to, "timestamp": datetime.now(timezone.utc).isoformat(), "file": file_info, } plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8") my_user_id = self.session["user_id"] if self._is_group(members): sk = self.sender_key_states.get(conv_id) if not sk: sk = _load_sender_key_state(self.email, conv_id, self._local_key) if not sk: sk = SenderKeyState() self.sender_key_states[conv_id] = sk _save_sender_key_state(self.email, conv_id, sk, self._local_key) await self._distribute_sender_key(conv_id, members, sk) result = sk.encrypt(plaintext) _save_sender_key_state(self.email, conv_id, sk, self._local_key) recipients = [] for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue recipients.append({ "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), }) # Self-encrypted copy for sender self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) resp = await self.send_and_recv( "send_message", conversation_id=conv_id, ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0}, recipients=recipients, sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])), sender_chain_n=result["n"], image_file_id=file_id, ) else: # DM file: per-device ratchet (same pattern as _send_dm) recipients = [] first_rh = None for member in members: uid = member.get("user_id") if not uid or uid == my_user_id: continue try: device_bundles = await self._get_device_bundles(uid) except Exception: device_bundles = [] if not device_bundles: # Fallback: legacy single-device ratchet = await self._get_or_create_session(uid) result = ratchet.encrypt(plaintext) x3dh_h = getattr(ratchet, "_x3dh_header", None) if x3dh_h: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if x3dh_h: entry["x3dh_header"] = x3dh_h recipients.append(entry) if first_rh is None: first_rh = result["header"] _save_session(self.email, uid, ratchet, self._local_key) else: for bundle in device_bundles: dev_id = bundle.get("device_id") ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, bundle=bundle) result = ratchet.encrypt(plaintext) x3dh_h = getattr(ratchet, "_x3dh_header", None) if x3dh_h: delattr(ratchet, "_x3dh_header") entry = { "user_id": uid, "encrypted_content": encode_binary(result["ciphertext"]), "nonce": encode_binary(result["nonce"]), "ratchet_header": result["header"], } if dev_id: entry["device_id"] = dev_id if x3dh_h: entry["x3dh_header"] = x3dh_h recipients.append(entry) if first_rh is None: first_rh = result["header"] _save_session(self.email, uid, ratchet, self._local_key, peer_device_id=dev_id) # Encrypt self-copy with static key self_key = derive_self_encryption_key(self.identity_private) _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) recipients.append({ "user_id": my_user_id, "encrypted_content": encode_binary(self_ct + self_tag), "nonce": encode_binary(self_nonce), "ratchet_header": {"self": True}, }) resp = await self.send_and_recv( "send_message", conversation_id=conv_id, ratchet_header=first_rh, recipients=recipients, image_file_id=file_id, ) if resp["status"] == "ok": return True, "File sent." return False, resp["data"]["message"] async def download_file(self, file_id: str, file_info: dict) -> bytes | None: """Download and decrypt a file. Returns decrypted file bytes or None.""" chunks = [] offset = 0 while True: resp = await self.send_and_recv( "download_image", file_id=file_id, offset=offset, ) if resp["status"] != "ok": return None data = resp["data"] chunk = decode_binary(data["data"]) chunks.append(chunk) offset += len(chunk) if data.get("done"): break encrypted_data = b"".join(chunks) if len(encrypted_data) < 16: return None ciphertext = encrypted_data[:-16] tag = encrypted_data[-16:] try: file_aes_key = decode_binary(file_info["aes_key"]) iv = decode_binary(file_info["iv"]) return aes_decrypt(file_aes_key, iv, ciphertext, tag) except Exception: return None async def download_image(self, file_id: str, image_info: dict) -> bytes | None: """Download and decrypt an image. Returns decrypted image bytes or None.""" chunks = [] offset = 0 while True: resp = await self.send_and_recv( "download_image", file_id=file_id, offset=offset, ) if resp["status"] != "ok": return None data = resp["data"] chunk = decode_binary(data["data"]) chunks.append(chunk) offset += len(chunk) if data.get("done"): break encrypted_data = b"".join(chunks) if len(encrypted_data) < 16: return None ciphertext = encrypted_data[:-16] tag = encrypted_data[-16:] try: img_aes_key = decode_binary(image_info["aes_key"]) iv = decode_binary(image_info["iv"]) return aes_decrypt(img_aes_key, iv, ciphertext, tag) except Exception: return None # ------------------------------------------------------------------ # Re-encrypt history (for device pairing) # ------------------------------------------------------------------ async def reencrypt_history(self): """Re-encrypt all cached messages with self-encryption key. After device pairing, the new device shares the same identity key but cannot decrypt old messages (Double Ratchet keys are one-time use). This re-encrypts all cached messages so they can be read using the self-encryption key derived from the shared identity key. """ if not self.identity_private or not self.session: return self_key = derive_self_encryption_key(self.identity_private) # Phase 1: Fetch & decrypt all messages to populate cache # (messages the old device never opened won't be in cache yet) try: convs = await self.list_conversations() total_convs = len(convs) for ci, conv in enumerate(convs): cid = conv.get("id") or conv.get("conversation_id") if not cid: continue if self._reencrypt_progress_cb: self._reencrypt_progress_cb( f"Fetching messages: {ci + 1}/{total_convs} conversations..." ) offset = 0 while True: msgs = await self.get_messages(cid, limit=200, offset=offset) if not msgs or len(msgs) < 200: break offset += len(msgs) except Exception as e: self._logger.warning("Failed to fetch messages for re-encryption: %s", e) # Phase 2: Read cache and re-encrypt cache_dir = get_key_dir(self.email) / "message_cache" if not cache_dir.exists(): self._logger.info("No message cache to re-encrypt.") return all_updates = [] conv_ids = set() for f in cache_dir.iterdir(): if f.suffix in (".json", ".bin"): conv_ids.add(f.stem) total_files = len(conv_ids) for i, conv_id in enumerate(sorted(conv_ids)): cache = _load_message_cache(self.email, conv_id, self._cache_key) if not cache: continue for msg_id, entry in cache.items(): # Skip control messages (sender key distribution) if entry.get("_control"): continue # Skip entries with no useful content text = entry.get("text", "") if not text and not entry.get("image") and not entry.get("file"): continue # Rebuild plaintext from cached payload payload = {k: v for k, v in entry.items() if k not in ("message_id", "created_at", "read_by", "sender_id", "deleted")} plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8") # Re-encrypt with self-encryption key _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key) all_updates.append({ "message_id": msg_id, "encrypted_content": encode_binary(ct + tag), "nonce": encode_binary(nonce), }) if self._reencrypt_progress_cb: self._reencrypt_progress_cb(f"Re-encrypting history: {i + 1}/{total_files} conversations...") if not all_updates: self._logger.info("No messages to re-encrypt.") return # Send in batches of 500 batch_size = 500 total = len(all_updates) for start in range(0, total, batch_size): batch = all_updates[start:start + batch_size] resp = await self.send_and_recv("reencrypt_messages", updates=batch) if resp["status"] != "ok": self._logger.warning("Re-encrypt batch failed: %s", resp.get("data", {}).get("message", "")) else: self._logger.info("Re-encrypted %d/%d messages.", min(start + batch_size, total), total) if self._reencrypt_progress_cb: self._reencrypt_progress_cb(f"Re-encryption complete: {total} messages uploaded.") # ------------------------------------------------------------------ # User Profiles # ------------------------------------------------------------------ async def get_profile(self, user_id: str | None = None) -> dict | None: """Get user profile. If user_id is None, returns own profile.""" kwargs = {} if user_id: kwargs["user_id"] = user_id resp = await self.send_and_recv("get_profile", **kwargs) if resp["status"] == "ok": return resp["data"] return None async def update_profile(self, **fields) -> tuple[bool, str]: """Update own profile (phone, location, *_visible).""" resp = await self.send_and_recv("update_profile", **fields) if resp["status"] == "ok": return True, "OK" return False, resp["data"]["message"] async def update_avatar(self, image_data: bytes) -> tuple[bool, str]: """Upload avatar image.""" resp = await self.send_and_recv("update_avatar", data=encode_binary(image_data)) if resp["status"] == "ok": return True, resp["data"].get("avatar_file", "") return False, resp["data"]["message"] async def get_avatar(self, user_id: str) -> bytes | None: """Download avatar for a user.""" resp = await self.send_and_recv("get_avatar", user_id=user_id) if resp["status"] == "ok": return decode_binary(resp["data"]["data"]) return None async def update_group_avatar(self, conv_id: str, image_data: bytes) -> tuple[bool, str]: """Upload avatar for a group conversation.""" resp = await self.send_and_recv("update_group_avatar", conversation_id=conv_id, data=encode_binary(image_data)) if resp["status"] == "ok": return True, resp["data"].get("avatar_file", "") return False, resp["data"]["message"] async def get_group_avatar(self, conv_id: str) -> bytes | None: """Download avatar for a group conversation.""" resp = await self.send_and_recv("get_group_avatar", conversation_id=conv_id) if resp["status"] == "ok": return decode_binary(resp["data"]["data"]) return None # ------------------------------------------------------------------ # Cleanup # ------------------------------------------------------------------ async def close(self): self.connected = False if self._listener_task: self._listener_task.cancel() if self.raw_writer: self.raw_writer.close() async def reconnect(self): """Close existing connection and re-establish: connect + re-login using in-memory keys.""" try: await self.close() except Exception: pass # Reset reader/writer but keep keys and sessions self.reader = None self.writer = None self.raw_writer = None self._listener_task = None self._pending.clear() self.login_rejected = False # Drain queues while not self._response_queue.empty(): try: self._response_queue.get_nowait() except Exception: break while not self._notification_queue.empty(): try: self._notification_queue.get_nowait() except Exception: break await self.connect() self._listener_task = asyncio.create_task(self._background_listener()) if self.email and self.private_key: # RSA challenge-response login (keys already in memory) start = await self.send_and_recv("login_start", email=self.email) if start["status"] == "ok": challenge = decode_binary(start["data"]["challenge"]) signature = rsa_sign(self.private_key, challenge) login_kwargs = { "email": self.email, "signature": encode_binary(signature), "client_version": VERSION, } if self.device_id: login_kwargs["device_id"] = self.device_id finish = await self.send_and_recv("login_finish", **login_kwargs) if finish["status"] == "ok": self.session = finish["data"] asyncio.create_task(self._ensure_prekeys()) else: # Login rejected — keys were likely rotated on another device self.session = None self.connected = False self.login_rejected = True