"""Asyncio TCP server — stores and relays encrypted blobs without seeing content.""" import asyncio import json import logging import os import re import secrets import signal import smtplib import ssl import subprocess import sys from email.mime.text import MIMEText from pathlib import Path from datetime import datetime, timezone from dotenv import load_dotenv load_dotenv() import db from crypto_utils import load_public_key, rsa_verify, load_ed25519_public, ed25519_verify, serialize_x25519_public from protocol import VERSION, MIN_CLIENT_VERSION, version_gte, ProtocolReader, ProtocolWriter, encode_binary, decode_binary, MAX_MESSAGE_BYTES, MAX_IMAGE_BYTES, MAX_FILE_BYTES, IMAGE_CHUNK_SIZE # Connected clients: user_id -> list[ProtocolWriter] connected_clients: dict[str, list[ProtocolWriter]] = {} # Writer -> device_id mapping (id(writer) -> device_id) writer_device_map: dict[int, str] = {} # Pairing sessions: code -> data pairing_sessions: dict[str, dict] = {} pending_registrations: dict[str, dict] = {} # Pending image uploads: file_id -> {temp_path, received_bytes, file_size, conv_id} pending_uploads: dict[str, dict] = {} # Phantom user IDs (loaded at startup, updated on create/delete) phantom_user_ids: set[str] = set() # Locks for shared mutable state (H4 race condition fix) _clients_lock = asyncio.Lock() # Protects: connected_clients, writer_device_map, phantom_user_ids _conn_lock = asyncio.Lock() # Protects: connection_counts, current_connections, rate_limits _pairing_lock = asyncio.Lock() # Protects: pairing_sessions, pending_registrations _uploads_lock = asyncio.Lock() # Protects: pending_uploads UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "uploads")) # C6 fix: UUID validation + safe path construction to prevent path traversal _UUID_RE = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE) def _valid_uuid(value: str) -> bool: """Validate that value is a canonical UUID (no path components).""" return bool(_UUID_RE.match(value)) # L8 fix: email validation to prevent phantom DB inflation _EMAIL_RE = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") def _valid_email(email: str) -> bool: """Validate basic email format (L8).""" return bool(_EMAIL_RE.match(email)) and len(email) <= 254 def _append_file(path: Path, data: bytes): """Append data to file (runs in thread pool to avoid blocking event loop).""" with open(path, "ab") as f: f.write(data) def _read_file_chunk(path: Path, offset: int, size: int) -> bytes: """Read a chunk from file (runs in thread pool to avoid blocking event loop).""" with open(path, "rb") as f: f.seek(offset) return f.read(size) def _safe_upload_path(file_id: str, suffix: str) -> Path | None: """Return resolved path inside UPLOAD_DIR, or None if traversal detected.""" p = (UPLOAD_DIR / f"{file_id}{suffix}").resolve() if not p.is_relative_to(UPLOAD_DIR.resolve()): return None return p def _safe_avatar_path(filename: str) -> Path | None: """Return resolved avatar path inside UPLOAD_DIR/avatars, or None if traversal detected.""" avatar_dir = (UPLOAD_DIR / "avatars").resolve() p = (UPLOAD_DIR / "avatars" / filename).resolve() if not p.is_relative_to(avatar_dir): return None return p PAIRING_TTL_SECONDS = 120 REGISTER_TTL_SECONDS = 3600 PAIRING_MAX_POLL_ATTEMPTS = 90 # SMTP configuration for registration codes SMTP_HOST = os.getenv("SMTP_HOST", "") SMTP_PORT = int(os.getenv("SMTP_PORT", "587")) SMTP_USER = os.getenv("SMTP_USER", "") SMTP_PASS = os.getenv("SMTP_PASS", "") SMTP_FROM = os.getenv("SMTP_FROM", "") RATE_LIMIT_WINDOW = 60.0 # seconds CONNECTION_RL_WINDOW = 1.0 # seconds CONNECTION_RL_MAX = 20 # max requests per window per connection MAX_CONNECTIONS_PER_IP = 10 MAX_CONNECTIONS_GLOBAL = 200 def setup_logging(): level_name = os.getenv("LOG_LEVEL", "INFO").upper() level = getattr(logging, level_name, logging.WARNING) logging.basicConfig(level=level, format="%(levelname)s: %(message)s") logger = logging.getLogger("encrypted_chat.server") rate_limits: dict[str, list[float]] = {} connection_counts: dict[str, int] = {} current_connections = 0 def _rate_limit_key(action: str, addr: str, email: str | None = None) -> str: if email: return f"{action}|{addr}|{email}" return f"{action}|{addr}" async def _is_rate_limited(key: str, limit: int) -> bool: async with _conn_lock: now = asyncio.get_event_loop().time() window_start = now - RATE_LIMIT_WINDOW times = rate_limits.get(key, []) times = [t for t in times if t >= window_start] if len(times) >= limit: rate_limits[key] = times return True times.append(now) rate_limits[key] = times return False def _get_peer_addr(writer: ProtocolWriter) -> str: try: return str(writer._writer.get_extra_info("peername")[0]) except Exception: return "unknown" async def _notify_users(user_ids, msg_type, data, exclude_writer=None): """Snapshot writers under lock, send notifications outside lock.""" targets = [] async with _clients_lock: for uid in user_ids: for w in connected_clients.get(uid, []): targets.append(w) for w in targets: if w is exclude_writer: continue try: await w.send_response(msg_type, "ok", data) except Exception: pass async def _notify_users_individual(notifications, exclude_writer=None): """Send per-user data. notifications: list of (user_id, msg_type, data).""" targets = [] async with _clients_lock: for uid, mt, d in notifications: for w in connected_clients.get(uid, []): targets.append((w, mt, d)) for w, mt, d in targets: if w is exclude_writer: continue try: await w.send_response(mt, "ok", d) except Exception: pass async def _cleanup_pairings(): async with _pairing_lock: now = asyncio.get_event_loop().time() expired = [code for code, p in pairing_sessions.items() if now - p["created_at"] > PAIRING_TTL_SECONDS] for code in expired: pairing_sessions.pop(code, None) async def _cleanup_registrations(): async with _pairing_lock: now = asyncio.get_event_loop().time() expired = [code for code, p in pending_registrations.items() if now - p["created_at"] > REGISTER_TTL_SECONDS] for code in expired: pending_registrations.pop(code, None) def _generate_pairing_code() -> str: for _ in range(10): code = f"{int.from_bytes(os.urandom(4), 'big') % 100000000:08d}" if code not in pairing_sessions: return code return f"{int.from_bytes(os.urandom(4), 'big') % 100000000:08d}" def _generate_register_code() -> str: for _ in range(10): code = f"{int.from_bytes(os.urandom(3), 'big') % 1000000:06d}" if code not in pending_registrations: return code return f"{int.from_bytes(os.urandom(3), 'big') % 1000000:06d}" def _validate_public_key_pem(pem_str: str) -> bool: """Validate that a string is a valid RSA public key PEM.""" try: key = load_public_key(pem_str.encode("utf-8")) if key.key_size < 2048: return False return True except Exception: return False def _send_registration_email(to_email: str, code: str) -> bool: """Send registration code via SMTP. Returns True on success.""" if not SMTP_HOST: return False try: msg = MIMEText(f"Your registration code is: {code}\n\nThis code expires in 1 hour.") msg["Subject"] = "Encrypted Chat - Registration Code" msg["From"] = SMTP_FROM or SMTP_USER msg["To"] = to_email with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=10) as server: server.starttls() if SMTP_USER: server.login(SMTP_USER, SMTP_PASS) server.send_message(msg) return True except Exception as e: logger.warning("Failed to send registration email: %s", e) return False async def send_resp(msg: dict, writer: ProtocolWriter, msg_type: str, status: str, data: dict | None = None): await writer.send_response(msg_type, status, data, request_id=msg.get("request_id")) async def handle_register_start(msg: dict, writer: ProtocolWriter) -> dict | None: await _cleanup_registrations() username = msg.get("username", "").strip() public_key = msg.get("public_key", "").strip() identity_key_b64 = msg.get("identity_key", "").strip() email = msg.get("email", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("register_start", addr, email), 3): await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."}) return None if not username or not public_key or not email or not identity_key_b64: await send_resp(msg, writer, "register_start", "error", {"message": "Missing fields"}) return None if not _validate_public_key_pem(public_key): await send_resp(msg, writer, "register_start", "error", {"message": "Invalid public key format"}) return None # Validate identity key is 32 bytes try: ik_bytes = decode_binary(identity_key_b64) if len(ik_bytes) != 32: raise ValueError("Identity key must be 32 bytes") load_ed25519_public(ik_bytes) except Exception: await send_resp(msg, writer, "register_start", "error", {"message": "Invalid identity key"}) return None existing_email = db.get_user_by_email(email) phantom_id = None if existing_email: if existing_email.get("rsa_public_key") == "PHANTOM": # Don't delete — will be upgraded in register_confirm to preserve # FK references (group_invitations, conversation_members, etc.) phantom_id = existing_email["id"] else: # H3 anti-enumeration: return same response as success to prevent # attackers from discovering valid emails. User won't receive a code # via email, so they can't confirm — silent failure. logger.debug("Registration attempt for existing email (hidden from client).") await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."}) return None async with _pairing_lock: code = _generate_register_code() pending_registrations[code] = { "username": username, "public_key": public_key, "identity_key": ik_bytes, "email": email, "created_at": asyncio.get_event_loop().time(), "phantom_id": phantom_id, } logger.info("Registration started.") email_sent = _send_registration_email(email, code) if email_sent: await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."}) else: if SMTP_HOST: logger.warning("SMTP configured but email failed for %s", email) else: logger.warning("No SMTP configured — returning code directly (dev mode).") await send_resp(msg, writer, "register_start", "ok", {"code": code}) return None async def handle_register_confirm(msg: dict, writer: ProtocolWriter) -> dict | None: await _cleanup_registrations() email = msg.get("email", "").strip() code = msg.get("code", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("register_confirm", addr, email), 3): await send_resp(msg, writer, "register_confirm", "error", {"message": "Too many attempts. Try later."}) return None if not email or not code: await send_resp(msg, writer, "register_confirm", "error", {"message": "Missing email or code"}) return None async with _pairing_lock: pending = pending_registrations.get(code) if pending and pending.get("email") == email: pending_registrations.pop(code, None) else: pending = None if not pending: await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"}) return None phantom_id = pending.get("phantom_id") if phantom_id: # Upgrade phantom in-place — preserves FK references (invitations, memberships) user_id = db.upgrade_phantom_user( phantom_id, pending["username"], pending["public_key"], pending["identity_key"], ) if user_id: async with _clients_lock: phantom_user_ids.discard(phantom_id) else: # Phantom was deleted concurrently — fall back to normal create user_id = db.create_user( pending["username"], pending["email"], pending["public_key"], pending["identity_key"], ) else: user_id = db.create_user( pending["username"], pending["email"], pending["public_key"], pending["identity_key"], ) db.create_default_profile(user_id) logger.info("User registered.") await send_resp(msg, writer, "register_confirm", "ok", {"user_id": user_id}) return None async def handle_login_start(msg: dict, writer: ProtocolWriter, state: dict): email = msg.get("email", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("login_start", addr, email), 10): await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."}) return if not email: await send_resp(msg, writer, "login_start", "error", {"message": "Missing email"}) return user = db.get_user_by_email(email) challenge = os.urandom(32) state["login_email"] = email state["login_challenge"] = challenge if not user: # H3 anti-enumeration: return a fake challenge so attacker can't distinguish # "user not found" from "user exists". login_finish will fail with generic error. state["_login_fake"] = True await send_resp(msg, writer, "login_start", "ok", {"challenge": encode_binary(challenge)}) async def handle_login_finish(msg: dict, writer: ProtocolWriter, state: dict) -> dict | None: email = msg.get("email", "").strip() signature_b64 = msg.get("signature", "") challenge = state.get("login_challenge") expected_email = state.get("login_email") addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("login_finish", addr, email), 10): await send_resp(msg, writer, "login_finish", "error", {"message": "Too many attempts. Try later."}) return None if not email or not signature_b64: await send_resp(msg, writer, "login_finish", "error", {"message": "Missing email or signature"}) return None if not challenge or expected_email != email: await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) return None # H3: if login_start was for a non-existent user, fail with generic error is_fake = state.pop("_login_fake", False) try: if is_fake: await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) return None user = db.get_user_by_email(email) if not user: await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) return None public_key = load_public_key(user["rsa_public_key"].encode("utf-8")) signature = decode_binary(signature_b64) if not rsa_verify(public_key, signature, challenge): await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) return None except ValueError: # H5: invalid base64 in signature await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) return None finally: state.pop("login_challenge", None) state.pop("login_email", None) user_id = user["id"] # Version check: reject outdated clients client_version = msg.get("client_version", "") if client_version and not version_gte(client_version, MIN_CLIENT_VERSION): await send_resp(msg, writer, "login_finish", "error", { "message": f"Client version {client_version} is too old. Minimum required: {MIN_CLIENT_VERSION}", "min_version": MIN_CLIENT_VERSION, "server_version": VERSION, }) return None # Device registration: client may send device_id to reuse an existing device client_device_id = msg.get("device_id") device_id = None if client_device_id: dev = db.get_device(client_device_id) if dev and dev["user_id"] == user_id: device_id = client_device_id if not device_id: device_name = msg.get("device_name", "Unknown") device_id = db.create_device(user_id, device_name) db.update_device_last_seen(device_id) async with _clients_lock: was_offline = user_id not in connected_clients or not connected_clients[user_id] if user_id not in connected_clients: connected_clients[user_id] = [] connected_clients[user_id].append(writer) writer_device_map[id(writer)] = device_id logger.info("User logged in (device %s, client v%s).", device_id, client_version or "unknown") await send_resp(msg, writer, "login_finish", "ok", { "user_id": user_id, "username": user["username"], "email": user["email"], "device_id": device_id, "server_version": VERSION, }) # Send online status notifications contacts = db.get_user_contacts(user_id) online_targets = [] async with _clients_lock: online_contacts = [cid for cid in contacts if cid in connected_clients and connected_clients[cid]] if was_offline: for contact_id in contacts: for cw in connected_clients.get(contact_id, []): online_targets.append(cw) await writer.send_response("online_users", "ok", {"user_ids": online_contacts}) # Send online notifications outside lock for cw in online_targets: try: await cw.send_response("user_online", "ok", {"user_id": user_id}) except Exception: pass return {"user_id": user_id, "username": user["username"], "email": user["email"], "device_id": device_id} async def handle_get_user_info(msg: dict, writer: ProtocolWriter): """Get user info including identity key (for X3DH).""" email = msg.get("email", "").strip() user_id = msg.get("user_id", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("get_user_info", addr, email or user_id), 30): await send_resp(msg, writer, "get_user_info", "error", {"message": "Too many attempts. Try later."}) return if user_id and not _valid_uuid(user_id): await send_resp(msg, writer, "get_user_info", "error", {"message": "Invalid user_id"}) return user = None if email: user = db.get_user_by_email(email) elif user_id: user = db.get_user_by_id(user_id) if not user: await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) return ik = user.get("identity_key") await send_resp(msg, writer, "get_user_info", "ok", { "user_id": user["id"], "username": user["username"], "email": user["email"], "identity_key": encode_binary(ik) if ik else "", }) async def handle_upload_prekeys(msg: dict, session: dict, writer: ProtocolWriter): """Upload signed prekey + batch of one-time prekeys.""" spk_data = msg.get("signed_prekey") otps = msg.get("one_time_prekeys", []) if not spk_data: await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Missing signed_prekey"}) return spk_id = spk_data.get("id", "") spk_pub_b64 = spk_data.get("public_key", "") spk_sig_b64 = spk_data.get("signature", "") if not spk_id or not spk_pub_b64 or not spk_sig_b64: await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Incomplete signed_prekey"}) return spk_pub = decode_binary(spk_pub_b64) spk_sig = decode_binary(spk_sig_b64) # Verify SPK signature with user's identity key user = db.get_user_by_id(session["user_id"]) if not user or not user.get("identity_key"): await send_resp(msg, writer, "upload_prekeys", "error", {"message": "No identity key"}) return ik_pub = load_ed25519_public(user["identity_key"]) if not ed25519_verify(ik_pub, spk_sig, spk_pub): await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid SPK signature"}) return device_id = session.get("device_id") db.store_signed_prekey(session["user_id"], spk_id, spk_pub, spk_sig, device_id=device_id) # Store OTPs otp_records = [] for otp in otps: otp_id = otp.get("id", "") otp_pub_b64 = otp.get("public_key", "") if otp_id and otp_pub_b64: otp_records.append({"id": otp_id, "public_key": decode_binary(otp_pub_b64)}) if otp_records: db.store_one_time_prekeys(session["user_id"], otp_records, device_id=device_id) logger.info("Prekeys uploaded: 1 SPK + %d OTPs (device %s)", len(otp_records), device_id) await send_resp(msg, writer, "upload_prekeys", "ok", {"message": "OK"}) async def handle_get_key_bundle(msg: dict, session: dict, writer: ProtocolWriter): """Fetch key bundle for X3DH. Returns per-device bundles. Consumes one OTP per device.""" target_user_id = msg.get("user_id", "").strip() if not target_user_id: await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Missing user_id"}) return if not _valid_uuid(target_user_id): await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Invalid user_id"}) return result = db.get_key_bundles_for_user(target_user_id) if not result or not result.get("device_bundles"): await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Key bundle not available"}) return device_bundles_data = [] for b in result["device_bundles"]: entry = { "device_id": b.get("device_id"), "signed_prekey_id": b["signed_prekey_id"], "signed_prekey": encode_binary(b["signed_prekey_pub"]), "spk_signature": encode_binary(b["spk_signature"]), } if b.get("opk_pub"): entry["one_time_prekey_id"] = b["opk_id"] entry["one_time_prekey"] = encode_binary(b["opk_pub"]) device_bundles_data.append(entry) # Build response with both new multi-device format and legacy flat fields first = device_bundles_data[0] if device_bundles_data else {} data = { "identity_key": encode_binary(result["identity_key"]), "device_bundles": device_bundles_data, # Legacy flat fields from first device bundle (backward compat) "signed_prekey_id": first.get("signed_prekey_id", ""), "signed_prekey": first.get("signed_prekey", ""), "spk_signature": first.get("spk_signature", ""), } if first.get("one_time_prekey"): data["one_time_prekey_id"] = first["one_time_prekey_id"] data["one_time_prekey"] = first["one_time_prekey"] await send_resp(msg, writer, "get_key_bundle", "ok", data) async def handle_get_prekey_count(msg: dict, session: dict, writer: ProtocolWriter): """How many OPKs does user have left (for this device)? Also returns SPK age for rotation.""" device_id = session.get("device_id") count = db.count_one_time_prekeys(session["user_id"], device_id=device_id) spk_created_at = "" spk = db.get_signed_prekey(session["user_id"], device_id=device_id) if spk and spk.get("created_at"): spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) await send_resp(msg, writer, "get_prekey_count", "ok", {"count": count, "spk_created_at": spk_created_at}) async def handle_rotate_keys(msg: dict, session: dict, writer: ProtocolWriter): public_key = msg.get("public_key", "").strip() if not public_key: await send_resp(msg, writer, "rotate_keys", "error", {"message": "Missing public_key"}) return if not _validate_public_key_pem(public_key): await send_resp(msg, writer, "rotate_keys", "error", {"message": "Invalid public key format"}) return db.update_user_rsa_key(session["user_id"], public_key) logger.info("RSA key rotated.") await send_resp(msg, writer, "rotate_keys", "ok", {"message": "OK"}) # Disconnect other sessions async with _clients_lock: writers = connected_clients.get(session["user_id"], []) others = [w for w in writers if w is not writer] connected_clients[session["user_id"]] = [writer] for w in others: try: w.close() except Exception: pass async def handle_pairing_start(msg: dict, writer: ProtocolWriter): await _cleanup_pairings() email = msg.get("email", "").strip() temp_public_key = msg.get("temp_public_key", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("pairing_start", addr, email), 10): await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."}) return if not email or not temp_public_key: await send_resp(msg, writer, "pairing_start", "error", {"message": "Missing email or temp_public_key"}) return user = db.get_user_by_email(email) if not user: await send_resp(msg, writer, "pairing_start", "error", {"message": "User not found"}) return poll_token = secrets.token_hex(16) async with _pairing_lock: code = _generate_pairing_code() pairing_sessions[code] = { "email": email, "temp_public_key": temp_public_key, "created_at": asyncio.get_event_loop().time(), "payload": None, "poll_token": poll_token, } await send_resp(msg, writer, "pairing_start", "ok", {"code": code, "poll_token": poll_token}) async def handle_pairing_claim(msg: dict, session: dict, writer: ProtocolWriter): await _cleanup_pairings() code = msg.get("code", "").strip() if not code: await send_resp(msg, writer, "pairing_claim", "error", {"message": "Missing code"}) return async with _pairing_lock: p = pairing_sessions.get(code) p_email = p["email"] if p else None temp_pub = p["temp_public_key"] if p else None if p: # Extend TTL — re-encryption may run between claim and send p["created_at"] = asyncio.get_event_loop().time() if not p: await send_resp(msg, writer, "pairing_claim", "error", {"message": "Invalid or expired code"}) return if p_email != session.get("email"): await send_resp(msg, writer, "pairing_claim", "error", {"message": "Not authorized for this code"}) return await send_resp(msg, writer, "pairing_claim", "ok", {"temp_public_key": temp_pub}) async def handle_pairing_send(msg: dict, session: dict, writer: ProtocolWriter): await _cleanup_pairings() code = msg.get("code", "").strip() payload = msg.get("payload") if not code or not payload: await send_resp(msg, writer, "pairing_send", "error", {"message": "Missing code or payload"}) return error_msg = None async with _pairing_lock: p = pairing_sessions.get(code) if not p: error_msg = "Invalid or expired code" elif p["email"] != session.get("email"): error_msg = "Not authorized for this code" else: p["payload"] = payload if error_msg: await send_resp(msg, writer, "pairing_send", "error", {"message": error_msg}) else: await send_resp(msg, writer, "pairing_send", "ok", {"message": "OK"}) async def handle_pairing_poll(msg: dict, writer: ProtocolWriter): await _cleanup_pairings() code = msg.get("code", "").strip() poll_token = msg.get("poll_token", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("pairing_poll", addr), 120): await send_resp(msg, writer, "pairing_poll", "error", {"message": "Too many attempts. Try later."}) return if not code: await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing code"}) return if not poll_token: await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing poll_token"}) return error_msg = None ready = False payload = None async with _pairing_lock: p = pairing_sessions.get(code) if not p: error_msg = "Invalid or expired code" elif not secrets.compare_digest(p.get("poll_token", ""), poll_token): error_msg = "Invalid poll_token" else: poll_attempts = p.get("poll_attempts", 0) + 1 p["poll_attempts"] = poll_attempts if poll_attempts > PAIRING_MAX_POLL_ATTEMPTS and not p.get("payload"): pairing_sessions.pop(code, None) error_msg = "Code invalidated due to too many attempts" elif p.get("payload"): ready = True payload = p["payload"] pairing_sessions.pop(code, None) if error_msg: await send_resp(msg, writer, "pairing_poll", "error", {"message": error_msg}) elif ready: await send_resp(msg, writer, "pairing_poll", "ok", {"ready": True, "payload": payload}) else: await send_resp(msg, writer, "pairing_poll", "ok", {"ready": False}) async def handle_create_conversation(msg: dict, session: dict, writer: ProtocolWriter): member_emails = msg.get("members", []) name = msg.get("name") # Resolve all member user IDs other_users = [] for email in member_emails: u = db.get_user_by_email(email) if not u: u = db.create_phantom_user(email) async with _clients_lock: phantom_user_ids.add(u["id"]) if u["id"] != session["user_id"]: other_users.append(u) is_dm = len(other_users) == 1 and not name joined_at = datetime.now(timezone.utc) if is_dm: # DMs: add both members directly (no invitation) all_ids = [session["user_id"]] + [u["id"] for u in other_users] conv_id = db.create_conversation(all_ids, joined_at=joined_at, name=name, created_by=session["user_id"]) logger.info("DM conversation created.") await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id}) # Notify the other member members_info = db.get_conversation_members(conv_id) member_list = [{"user_id": m["id"], "username": m["username"], "email": m["email"]} for m in members_info] notif_data = { "conversation_id": conv_id, "name": name, "created_by": session["user_id"], "members": member_list, } await _notify_users([u["id"] for u in other_users], "conversation_created", notif_data) else: # Groups: only add creator, create invitations for others conv_id = db.create_conversation([session["user_id"]], joined_at=joined_at, name=name, created_by=session["user_id"]) logger.info("Group conversation created with invitations.") # Create invitations for other members creator_user = db.get_user_by_id(session["user_id"]) creator_name = creator_user["username"] if creator_user else "Unknown" invited_ids = [] async with _clients_lock: phantom_snapshot = set(phantom_user_ids) for u in other_users: db.create_invitation(conv_id, u["id"], session["user_id"]) if u["id"] not in phantom_snapshot: invited_ids.append(u["id"]) # only notify non-phantoms inv_notif = { "conversation_id": conv_id, "conversation_name": name, "invited_by": session["user_id"], "invited_by_username": creator_name, } await _notify_users(invited_ids, "group_invitation", inv_notif) await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id}) async def handle_find_conversation(msg: dict, session: dict, writer: ProtocolWriter): email = msg.get("email", "").strip() if not email: await send_resp(msg, writer, "find_conversation", "error", {"message": "Invalid request"}) return addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("find_conversation", addr, email), 30): await send_resp(msg, writer, "find_conversation", "error", {"message": "Too many attempts. Try later."}) return other = db.get_user_by_email(email) if not other: other = db.create_phantom_user(email) async with _clients_lock: phantom_user_ids.add(other["id"]) conv_id = db.find_direct_conversation(session["user_id"], other["id"]) await send_resp(msg, writer, "find_conversation", "ok", { "conversation_id": conv_id, "user_id": other["id"], }) async def handle_add_member(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") email = msg.get("email", "").strip() if not conv_id or not email: await send_resp(msg, writer, "add_member", "error", {"message": "Invalid request"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "add_member", "error", {"message": "Invalid conversation_id"}) return # L8: validate email format before phantom creation addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("add_member", addr, email), 10): await send_resp(msg, writer, "add_member", "error", {"message": "Too many attempts. Try later."}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "add_member", "error", {"message": "Not a member"}) return user = db.get_user_by_email(email) if not user: # Create phantom for unregistered email (same as create_conversation) user = db.create_phantom_user(email) async with _clients_lock: phantom_user_ids.add(user["id"]) if db.is_conversation_member(conv_id, user["id"]): await send_resp(msg, writer, "add_member", "error", {"message": "Already a member"}) return if db.has_pending_invitation(conv_id, user["id"]): await send_resp(msg, writer, "add_member", "error", {"message": "Invitation already pending"}) return # Create invitation (for both real and phantom users) db.create_invitation(conv_id, user["id"], session["user_id"]) logger.info("Group invitation created.") await send_resp(msg, writer, "add_member", "ok", {"user_id": user["id"]}) # Push invitation notification only to non-phantom users async with _clients_lock: is_phantom = user["id"] in phantom_user_ids if not is_phantom: conv = db.get_conversation(conv_id) creator_user = db.get_user_by_id(session["user_id"]) creator_name = creator_user["username"] if creator_user else "Unknown" inv_notif = { "conversation_id": conv_id, "conversation_name": conv.get("name") if conv else None, "invited_by": session["user_id"], "invited_by_username": creator_name, } await _notify_users([user["id"]], "group_invitation", inv_notif) async def handle_accept_invitation(msg: dict, session: dict, writer: ProtocolWriter): """Accept a group invitation — add user to conversation members.""" conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "accept_invitation", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "accept_invitation", "error", {"message": "Invalid conversation_id"}) return if not db.has_pending_invitation(conv_id, session["user_id"]): await send_resp(msg, writer, "accept_invitation", "error", {"message": "No pending invitation"}) return joined_at = datetime.now(timezone.utc) db.add_conversation_member(conv_id, session["user_id"], joined_at=joined_at) db.delete_invitation(conv_id, session["user_id"]) logger.info("Invitation accepted.") await send_resp(msg, writer, "accept_invitation", "ok", {"conversation_id": conv_id}) # Notify existing members about the new member user = db.get_user_by_id(session["user_id"]) notif_data = { "conversation_id": conv_id, "user_id": session["user_id"], "username": user["username"] if user else "", "email": user["email"] if user else "", } members = db.get_conversation_members(conv_id) member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "member_added", notif_data) async def handle_decline_invitation(msg: dict, session: dict, writer: ProtocolWriter): """Decline a group invitation.""" conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "decline_invitation", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "decline_invitation", "error", {"message": "Invalid conversation_id"}) return if not db.has_pending_invitation(conv_id, session["user_id"]): await send_resp(msg, writer, "decline_invitation", "error", {"message": "No pending invitation"}) return db.delete_invitation(conv_id, session["user_id"]) logger.info("Invitation declined.") await send_resp(msg, writer, "decline_invitation", "ok", {"message": "OK"}) async def handle_list_invitations(msg: dict, session: dict, writer: ProtocolWriter): """List pending group invitations for the current user.""" invitations = db.get_pending_invitations(session["user_id"]) result = [] for inv in invitations: entry = { "conversation_id": inv["conversation_id"], "conversation_name": inv.get("conversation_name"), "invited_by": inv["invited_by"], "invited_by_username": inv.get("invited_by_username", ""), "created_at": inv["created_at"].isoformat() if hasattr(inv["created_at"], "isoformat") else str(inv["created_at"]), } result.append(entry) await send_resp(msg, writer, "list_invitations", "ok", {"invitations": result}) async def handle_list_conversations(msg: dict, session: dict, writer: ProtocolWriter): convs = db.list_user_conversations(session["user_id"]) unread = db.get_unread_counts(session["user_id"]) result = [] for c in convs: result.append({ "conversation_id": c["id"], "created_at": c["created_at"].isoformat() if hasattr(c["created_at"], "isoformat") else str(c["created_at"]), "members": c["members"], "name": c.get("name"), "created_by": c.get("created_by"), "avatar_file": c.get("avatar_file"), "unread_count": unread.get(c["id"], 0), }) await send_resp(msg, writer, "list_conversations", "ok", {"conversations": result}) async def handle_send_message(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "send_message", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "send_message", "error", {"message": "Invalid conversation_id"}) return addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("send_message", addr, session.get("email")), 20): await send_resp(msg, writer, "send_message", "error", {"message": "Too many attempts. Try later."}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "send_message", "error", {"message": "Not a member"}) return # New protocol: ratchet_header + recipients[] with per-user ciphertext ratchet_header_raw = msg.get("ratchet_header") recipients_raw = msg.get("recipients") if not ratchet_header_raw or not recipients_raw: await send_resp(msg, writer, "send_message", "error", {"message": "Missing ratchet_header or recipients"}) return ratchet_header = json.dumps(ratchet_header_raw).encode() if isinstance(ratchet_header_raw, dict) else \ ratchet_header_raw.encode() if isinstance(ratchet_header_raw, str) else ratchet_header_raw x3dh_header_raw = msg.get("x3dh_header") x3dh_header = None if x3dh_header_raw: x3dh_header = json.dumps(x3dh_header_raw).encode() if isinstance(x3dh_header_raw, dict) else \ x3dh_header_raw.encode() if isinstance(x3dh_header_raw, str) else x3dh_header_raw sender_chain_id_b64 = msg.get("sender_chain_id") sender_chain_id = decode_binary(sender_chain_id_b64) if sender_chain_id_b64 else None sender_chain_n = msg.get("sender_chain_n") # Validate recipients are actual members member_ids = {m["id"] for m in db.get_conversation_members(conv_id)} async with _clients_lock: phantom_snapshot = set(phantom_user_ids) db_recipients = [] for r in recipients_raw: uid = r.get("user_id", "") if uid not in member_ids: continue if uid in phantom_snapshot: continue ct_b64 = r.get("encrypted_content", "") nonce_b64 = r.get("nonce", "") if not ct_b64 or not nonce_b64: continue entry = { "user_id": uid, "encrypted_content": decode_binary(ct_b64), "nonce": decode_binary(nonce_b64), } # Per-recipient device_id (multi-device support) r_device_id = r.get("device_id") if r_device_id: entry["device_id"] = r_device_id # Per-recipient ratchet header and x3dh header r_rh = r.get("ratchet_header") if r_rh: entry["ratchet_header"] = json.dumps(r_rh).encode() if isinstance(r_rh, dict) else \ r_rh.encode() if isinstance(r_rh, str) else r_rh r_x3dh = r.get("x3dh_header") if r_x3dh: entry["x3dh_header"] = json.dumps(r_x3dh).encode() if isinstance(r_x3dh, dict) else \ r_x3dh.encode() if isinstance(r_x3dh, str) else r_x3dh db_recipients.append(entry) if not db_recipients: await send_resp(msg, writer, "send_message", "error", {"message": "No valid recipients"}) return image_file_id = msg.get("image_file_id") msg_id = db.store_message( conv_id, session["user_id"], ratchet_header, db_recipients, x3dh_header=x3dh_header, sender_chain_id=sender_chain_id, sender_chain_n=sender_chain_n, image_file_id=image_file_id, sender_device_id=session.get("device_id"), ) # Link image upload to message if present if image_file_id: upload = db.get_image_upload(image_file_id) if upload and upload["completed"] and upload["uploader_id"] == session["user_id"]: db.set_message_image_file_id(msg_id, image_file_id) logger.info("Message stored.") await send_resp(msg, writer, "send_message", "ok", {"message_id": msg_id}) # Notify connected recipients — group all per-device entries by user_id from collections import defaultdict user_entries = defaultdict(list) for r in recipients_raw: uid = r.get("user_id", "") user_entries[uid].append({ "device_id": r.get("device_id", db.SELF_DEVICE_ID), "encrypted_content": r.get("encrypted_content", ""), "nonce": r.get("nonce", ""), "ratchet_header": r.get("ratchet_header") or ratchet_header_raw, "x3dh_header": r.get("x3dh_header") or x3dh_header_raw, }) notifications = [] for uid, entries in user_entries.items(): notif_data = { "message_id": msg_id, "conversation_id": conv_id, "sender_id": session["user_id"], "sender_device_id": session.get("device_id"), "device_entries": entries, } if sender_chain_id_b64: notif_data["sender_chain_id"] = sender_chain_id_b64 if sender_chain_n is not None: notif_data["sender_chain_n"] = sender_chain_n # Also include flat fields for backward compat with old clients # (first entry's data as fallback) if entries: first = entries[0] notif_data["ratchet_header"] = first.get("ratchet_header") or ratchet_header_raw notif_data["encrypted_content"] = first.get("encrypted_content", "") notif_data["nonce"] = first.get("nonce", "") if first.get("x3dh_header"): notif_data["x3dh_header"] = first["x3dh_header"] notifications.append((uid, "new_message", notif_data)) await _notify_users_individual(notifications, exclude_writer=writer) async def handle_get_messages(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "get_messages", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid conversation_id"}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "get_messages", "error", {"message": "Not a member"}) return limit = min(max(int(msg.get("limit", 50)), 1), 200) offset = max(int(msg.get("offset", 0)), 0) device_id = session.get("device_id") messages = db.get_messages(conv_id, session["user_id"], limit, offset, device_id=device_id) result = [] message_ids = [m["id"] for m in messages] read_status = db.get_message_read_status(message_ids) if message_ids else {} for m in messages: read_by = read_status.get(m["id"], []) # Prefer per-recipient headers (mr_*) over message-level headers rh_raw = m.get("mr_ratchet_header") or m.get("ratchet_header") x3dh_raw = m.get("mr_x3dh_header") or m.get("x3dh_header") entry = { "message_id": m["id"], "sender_id": m.get("sender_id") or "", "ratchet_header": json.loads(rh_raw) if rh_raw else {}, "encrypted_content": encode_binary(m["encrypted_content"]) if m.get("encrypted_content") else "", "nonce": encode_binary(m["nonce"]) if m.get("nonce") else "", "created_at": m["created_at"].isoformat() if hasattr(m["created_at"], "isoformat") else str(m["created_at"]), "read_by": read_by, } if x3dh_raw: entry["x3dh_header"] = json.loads(x3dh_raw) if m.get("sender_chain_id"): entry["sender_chain_id"] = encode_binary(m["sender_chain_id"]) if m.get("sender_chain_n") is not None: entry["sender_chain_n"] = m["sender_chain_n"] if m.get("sender_device_id"): entry["sender_device_id"] = m["sender_device_id"] if m.get("deleted_at"): entry["deleted_at"] = m["deleted_at"].isoformat() if hasattr(m["deleted_at"], "isoformat") else str(m["deleted_at"]) result.append(entry) await send_resp(msg, writer, "get_messages", "ok", {"messages": result}) async def handle_remove_member(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") user_id = msg.get("user_id", "") if not conv_id or not user_id: await send_resp(msg, writer, "remove_member", "error", {"message": "Missing conversation_id or user_id"}) return if not _valid_uuid(conv_id) or not _valid_uuid(user_id): await send_resp(msg, writer, "remove_member", "error", {"message": "Invalid conversation_id or user_id"}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "remove_member", "error", {"message": "Not a member"}) return convs = db.list_user_conversations(session["user_id"]) conv_data = None for c in convs: if c["id"] == conv_id: conv_data = c break if not conv_data or conv_data.get("created_by") != session["user_id"]: await send_resp(msg, writer, "remove_member", "error", {"message": "Only the group creator can remove members"}) return if user_id == session["user_id"]: await send_resp(msg, writer, "remove_member", "error", {"message": "Cannot remove yourself"}) return # Get remaining members before removing (to notify them) members_before = db.get_conversation_members(conv_id) # M6: atomic removal — return value confirms row existed removed = db.remove_conversation_member_atomic(conv_id, user_id) if not removed: await send_resp(msg, writer, "remove_member", "error", {"message": "Member already removed"}) return logger.info("Conversation member removed.") await send_resp(msg, writer, "remove_member", "ok", {"message": "OK"}) # Notify removed member and remaining members notif_data = { "conversation_id": conv_id, "user_id": user_id, } member_ids = [m["id"] for m in members_before if m["id"] != session["user_id"]] await _notify_users(member_ids, "member_removed", notif_data) async def handle_leave_group(msg: dict, session: dict, writer: ProtocolWriter): """Leave a group conversation voluntarily.""" conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "leave_group", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "leave_group", "error", {"message": "Invalid conversation_id"}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"}) return # Don't allow leaving DMs (2 members without a name) conv = db.get_conversation(conv_id) members = db.get_conversation_members(conv_id) if len(members) <= 2 and not (conv and conv.get("name")): await send_resp(msg, writer, "leave_group", "error", {"message": "Cannot leave a DM conversation"}) return # If creator is leaving, transfer to first remaining member if conv and conv.get("created_by") == session["user_id"]: remaining = [m for m in members if m["id"] != session["user_id"]] if remaining: db.update_conversation_creator(conv_id, remaining[0]["id"]) # M6: atomic removal db.remove_conversation_member_atomic(conv_id, session["user_id"]) logger.info("User left group.") await send_resp(msg, writer, "leave_group", "ok", {"message": "OK"}) # Notify remaining members notif_data = { "conversation_id": conv_id, "user_id": session["user_id"], } member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "member_removed", notif_data) async def handle_rename_conversation(msg: dict, session: dict, writer: ProtocolWriter): """Rename a group conversation (creator only).""" conv_id = msg.get("conversation_id", "") new_name = msg.get("name", "").strip() if not conv_id or not new_name: await send_resp(msg, writer, "rename_conversation", "error", {"message": "Missing conversation_id or name"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "rename_conversation", "error", {"message": "Invalid conversation_id"}) return if len(new_name) > 100: await send_resp(msg, writer, "rename_conversation", "error", {"message": "Name too long (max 100)"}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "rename_conversation", "error", {"message": "Not a member"}) return conv = db.get_conversation(conv_id) if not conv or not conv.get("name"): await send_resp(msg, writer, "rename_conversation", "error", {"message": "Cannot rename a DM conversation"}) return if conv.get("created_by") != session["user_id"]: await send_resp(msg, writer, "rename_conversation", "error", {"message": "Only the group creator can rename"}) return db.update_conversation_name(conv_id, new_name) logger.info("Group renamed: %s", conv_id) await send_resp(msg, writer, "rename_conversation", "ok", {"message": "OK"}) # Notify all members members = db.get_conversation_members(conv_id) member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "conversation_renamed", { "conversation_id": conv_id, "name": new_name, "renamed_by": session["user_id"], }) async def handle_delete_conversation(msg: dict, session: dict, writer: ProtocolWriter): """Delete a conversation for the current user. Removes user from members, deletes the conversation if no members remain.""" conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "delete_conversation", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "delete_conversation", "error", {"message": "Invalid conversation_id"}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "delete_conversation", "error", {"message": "Not a member"}) return conv = db.get_conversation(conv_id) members = db.get_conversation_members(conv_id) is_group = len(members) > 2 or (conv and conv.get("name")) # Groups can only be deleted by the creator (admin) if is_group and (not conv or conv.get("created_by") != session["user_id"]): await send_resp(msg, writer, "delete_conversation", "error", {"message": "Only the group creator can delete this conversation"}) return if is_group: # Group: creator deletes for everyone — remove all members, clean up, delete for member in members: db.remove_conversation_member(conv_id, member["id"]) else: # DM: only remove self; other user keeps the conversation db.remove_conversation_member(conv_id, session["user_id"]) remaining_count = db.count_conversation_members(conv_id) if remaining_count == 0: # Clean up uploaded files from disk file_ids = db.get_conversation_file_ids(conv_id) for fid in file_ids: for ext in (".enc", ".tmp"): p = _safe_upload_path(fid, ext) if not p: continue try: p.unlink(missing_ok=True) except Exception: pass db.delete_conversation(conv_id) logger.info("Conversation deleted for user.") await send_resp(msg, writer, "delete_conversation", "ok", {"message": "OK"}) # Notify other members they were removed notif_data = { "conversation_id": conv_id, "user_id": session["user_id"], } member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "member_removed", notif_data) async def handle_mark_read(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") message_ids = msg.get("message_ids", []) if not conv_id or not message_ids: await send_resp(msg, writer, "mark_read", "error", {"message": "Missing conversation_id or message_ids"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid conversation_id"}) return if len(message_ids) > 500: await send_resp(msg, writer, "mark_read", "error", {"message": "Too many message_ids (max 500)"}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "mark_read", "error", {"message": "Not a member"}) return db.mark_messages_read(conv_id, session["user_id"], message_ids) await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"}) members = db.get_conversation_members(conv_id) notif_data = { "conversation_id": conv_id, "user_id": session["user_id"], "message_ids": message_ids, } member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "messages_read", notif_data) async def handle_delete_message(msg: dict, session: dict, writer: ProtocolWriter): message_id = msg.get("message_id", "") if not message_id: await send_resp(msg, writer, "delete_message", "error", {"message": "Missing message_id"}) return if not _valid_uuid(message_id): await send_resp(msg, writer, "delete_message", "error", {"message": "Invalid message_id"}) return conv_id = db.get_message_conversation(message_id) if not conv_id: await send_resp(msg, writer, "delete_message", "error", {"message": "Message not found"}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "delete_message", "error", {"message": "Not a member"}) return result = db.soft_delete_message(message_id, session["user_id"]) if result is None: await send_resp(msg, writer, "delete_message", "error", {"message": "Cannot delete this message"}) return image_file_id = result.get("image_file_id") if image_file_id: image_path = _safe_upload_path(image_file_id, ".enc") if image_path: try: image_path.unlink(missing_ok=True) except Exception: pass db.delete_image_upload(image_file_id) logger.info("Message deleted.") await send_resp(msg, writer, "delete_message", "ok", {"message_id": message_id}) members = db.get_conversation_members(conv_id) notif_data = {"message_id": message_id, "conversation_id": conv_id} member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "message_deleted", notif_data) async def handle_upload_image_start(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") file_size = msg.get("file_size", 0) file_id = msg.get("file_id", "") file_type = msg.get("file_type", "image") # "image" or "file" if not conv_id or not file_id: await send_resp(msg, writer, "upload_image_start", "error", {"message": "Missing fields"}) return if not _valid_uuid(file_id): await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "upload_image_start", "error", {"message": "Not a member"}) return max_bytes = MAX_FILE_BYTES if file_type == "file" else MAX_IMAGE_BYTES if max_bytes > 0 and file_size > max_bytes: await send_resp(msg, writer, "upload_image_start", "error", {"message": f"File too large (max {max_bytes} bytes)"}) return UPLOAD_DIR.mkdir(parents=True, exist_ok=True) temp_path = _safe_upload_path(file_id, ".tmp") if not temp_path: await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"}) return temp_path.write_bytes(b"") async with _uploads_lock: pending_uploads[file_id] = { "temp_path": str(temp_path), "received_bytes": 0, "file_size": file_size, "max_bytes": max_bytes, "conv_id": conv_id, "uploader_id": session["user_id"], } db.create_image_upload(file_id, conv_id, session["user_id"], file_size) logger.info("Image upload started: %s", file_id) await send_resp(msg, writer, "upload_image_start", "ok", {"file_id": file_id}) async def handle_upload_image_chunk(msg: dict, session: dict, writer: ProtocolWriter): file_id = msg.get("file_id", "") chunk_data = msg.get("data", "") if not file_id or not chunk_data: await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Missing fields"}) return async with _uploads_lock: upload = pending_uploads.get(file_id) if not upload or upload["uploader_id"] != session["user_id"]: upload = None else: temp_path_str = upload["temp_path"] upload_max = upload.get("max_bytes", 0) if not upload: await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"}) return raw = decode_binary(chunk_data) temp_path = Path(temp_path_str) await asyncio.to_thread(_append_file, temp_path, raw) over_limit = False async with _uploads_lock: upload = pending_uploads.get(file_id) if upload: upload["received_bytes"] += len(raw) if upload_max > 0 and upload["received_bytes"] > upload_max: pending_uploads.pop(file_id, None) over_limit = True received = upload["received_bytes"] if over_limit: temp_path.unlink(missing_ok=True) await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Upload exceeds size limit"}) return await send_resp(msg, writer, "upload_image_chunk", "ok", {"received": received}) async def handle_upload_image_end(msg: dict, session: dict, writer: ProtocolWriter): file_id = msg.get("file_id", "") if not file_id: await send_resp(msg, writer, "upload_image_end", "error", {"message": "Missing file_id"}) return async with _uploads_lock: upload = pending_uploads.pop(file_id, None) if not upload or upload["uploader_id"] != session["user_id"]: await send_resp(msg, writer, "upload_image_end", "error", {"message": "No active upload"}) return temp_path = Path(upload["temp_path"]) if upload["received_bytes"] != upload["file_size"]: temp_path.unlink(missing_ok=True) await send_resp(msg, writer, "upload_image_end", "error", {"message": f"Incomplete upload: received {upload['received_bytes']} of {upload['file_size']} bytes"}) return final_path = _safe_upload_path(file_id, ".enc") if not final_path: temp_path.unlink(missing_ok=True) await send_resp(msg, writer, "upload_image_end", "error", {"message": "Invalid file_id"}) return def _move_file(): try: temp_path.rename(final_path) except Exception: import shutil shutil.move(str(temp_path), str(final_path)) await asyncio.to_thread(_move_file) db.complete_image_upload(file_id) logger.info("Image upload completed: %s (%d bytes)", file_id, upload["received_bytes"]) await send_resp(msg, writer, "upload_image_end", "ok", {"file_id": file_id}) async def handle_download_image(msg: dict, session: dict, writer: ProtocolWriter): file_id = msg.get("file_id", "") offset = msg.get("offset", 0) if not file_id: await send_resp(msg, writer, "download_image", "error", {"message": "Missing file_id"}) return if not _valid_uuid(file_id): await send_resp(msg, writer, "download_image", "error", {"message": "Invalid file_id"}) return upload = db.get_image_upload(file_id) if not upload or not upload["completed"]: await send_resp(msg, writer, "download_image", "error", {"message": "File not found"}) return if not db.is_conversation_member(upload["conversation_id"], session["user_id"]): await send_resp(msg, writer, "download_image", "error", {"message": "Not a member"}) return file_path = _safe_upload_path(file_id, ".enc") if not file_path or not file_path.exists(): await send_resp(msg, writer, "download_image", "error", {"message": "File not found"}) return file_size = file_path.stat().st_size chunk = await asyncio.to_thread(_read_file_chunk, file_path, offset, IMAGE_CHUNK_SIZE) done = (offset + len(chunk)) >= file_size await send_resp(msg, writer, "download_image", "ok", { "file_id": file_id, "data": encode_binary(chunk), "offset": offset, "done": done, "total_size": file_size, }) MAX_AVATAR_BYTES = 2 * 1024 * 1024 # 2 MB async def handle_get_profile(msg: dict, session: dict, writer: ProtocolWriter): """Get user profile (respects visibility for other users).""" target_user_id = msg.get("user_id", "").strip() if not target_user_id: target_user_id = session["user_id"] elif not _valid_uuid(target_user_id): await send_resp(msg, writer, "get_profile", "error", {"message": "Invalid user_id"}) return profile = db.get_user_profile(target_user_id, viewer_id=session["user_id"]) if not profile: await send_resp(msg, writer, "get_profile", "error", {"message": "User not found"}) return # Serialize datetime fields for key in ("created_at", "updated_at"): if profile.get(key) and hasattr(profile[key], "isoformat"): profile[key] = profile[key].isoformat() await send_resp(msg, writer, "get_profile", "ok", profile) async def handle_update_profile(msg: dict, session: dict, writer: ProtocolWriter): """Update own profile fields.""" fields = {} for key in ("phone", "phone_visible", "email_visible", "location", "location_visible"): if key in msg: fields[key] = msg[key] if not fields: await send_resp(msg, writer, "update_profile", "error", {"message": "No fields to update"}) return db.update_user_profile(session["user_id"], **fields) await send_resp(msg, writer, "update_profile", "ok", {"message": "OK"}) async def handle_update_avatar(msg: dict, session: dict, writer: ProtocolWriter): """Upload avatar (base64 in single message, max 2MB).""" avatar_b64 = msg.get("data", "") if not avatar_b64: await send_resp(msg, writer, "update_avatar", "error", {"message": "Missing data"}) return avatar_data = decode_binary(avatar_b64) if len(avatar_data) > MAX_AVATAR_BYTES: await send_resp(msg, writer, "update_avatar", "error", {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"}) return # Detect format from magic bytes ext = "jpg" if avatar_data[:8] == b'\x89PNG\r\n\x1a\n': ext = "png" avatar_dir = UPLOAD_DIR / "avatars" avatar_dir.mkdir(parents=True, exist_ok=True) filename = f"{session['user_id']}.{ext}" avatar_path = _safe_avatar_path(filename) if not avatar_path: await send_resp(msg, writer, "update_avatar", "error", {"message": "Invalid path"}) return await asyncio.to_thread(avatar_path.write_bytes, avatar_data) db.update_user_profile(session["user_id"], avatar_file=filename) logger.info("Avatar updated for user %s", session["user_id"]) await send_resp(msg, writer, "update_avatar", "ok", {"avatar_file": filename}) async def handle_get_avatar(msg: dict, session: dict, writer: ProtocolWriter): """Download avatar for a user.""" target_user_id = msg.get("user_id", "").strip() if not target_user_id: await send_resp(msg, writer, "get_avatar", "error", {"message": "Missing user_id"}) return if not _valid_uuid(target_user_id): await send_resp(msg, writer, "get_avatar", "error", {"message": "Invalid user_id"}) return profile = db.get_user_profile(target_user_id) if not profile or not profile.get("avatar_file"): await send_resp(msg, writer, "get_avatar", "error", {"message": "No avatar"}) return avatar_path = _safe_avatar_path(profile["avatar_file"]) if not avatar_path or not avatar_path.exists(): await send_resp(msg, writer, "get_avatar", "error", {"message": "Avatar file not found"}) return avatar_data = await asyncio.to_thread(avatar_path.read_bytes) await send_resp(msg, writer, "get_avatar", "ok", { "user_id": target_user_id, "data": encode_binary(avatar_data), "filename": profile["avatar_file"], }) async def handle_update_group_avatar(msg: dict, session: dict, writer: ProtocolWriter): """Upload avatar for a group conversation (base64, max 2MB). Only members can set it.""" conv_id = msg.get("conversation_id", "").strip() avatar_b64 = msg.get("data", "") if not conv_id or not avatar_b64: await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Missing fields"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid conversation_id"}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Not a member"}) return avatar_data = decode_binary(avatar_b64) if len(avatar_data) > MAX_AVATAR_BYTES: await send_resp(msg, writer, "update_group_avatar", "error", {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"}) return ext = "jpg" if avatar_data[:8] == b'\x89PNG\r\n\x1a\n': ext = "png" avatar_dir = UPLOAD_DIR / "avatars" avatar_dir.mkdir(parents=True, exist_ok=True) filename = f"group_{conv_id}.{ext}" avatar_path = _safe_avatar_path(filename) if not avatar_path: await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid path"}) return await asyncio.to_thread(avatar_path.write_bytes, avatar_data) db.update_conversation_avatar(conv_id, filename) logger.info("Group avatar updated for conversation %s", conv_id) await send_resp(msg, writer, "update_group_avatar", "ok", {"avatar_file": filename}) async def handle_get_group_avatar(msg: dict, session: dict, writer: ProtocolWriter): """Download avatar for a group conversation.""" conv_id = msg.get("conversation_id", "").strip() if not conv_id: await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Invalid conversation_id"}) return if not db.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Not a member"}) return conv = db.get_conversation(conv_id) if not conv or not conv.get("avatar_file"): await send_resp(msg, writer, "get_group_avatar", "error", {"message": "No avatar"}) return avatar_path = _safe_avatar_path(conv["avatar_file"]) if not avatar_path or not avatar_path.exists(): await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Avatar file not found"}) return avatar_data = await asyncio.to_thread(avatar_path.read_bytes) await send_resp(msg, writer, "get_group_avatar", "ok", { "conversation_id": conv_id, "data": encode_binary(avatar_data), "filename": conv["avatar_file"], }) async def handle_list_devices(msg: dict, session: dict, writer: ProtocolWriter): """List all devices for the current user.""" devices = db.get_user_devices(session["user_id"]) result = [] for d in devices: entry = { "device_id": d["id"], "device_name": d.get("device_name"), "created_at": d["created_at"].isoformat() if hasattr(d["created_at"], "isoformat") else str(d["created_at"]), "last_seen_at": d["last_seen_at"].isoformat() if d.get("last_seen_at") and hasattr(d["last_seen_at"], "isoformat") else (str(d["last_seen_at"]) if d.get("last_seen_at") else None), "is_current": d["id"] == session.get("device_id"), } result.append(entry) await send_resp(msg, writer, "list_devices", "ok", {"devices": result}) async def handle_remove_device(msg: dict, session: dict, writer: ProtocolWriter): """Remove a device (cannot remove current device).""" device_id = msg.get("device_id", "").strip() if not device_id: await send_resp(msg, writer, "remove_device", "error", {"message": "Missing device_id"}) return if not _valid_uuid(device_id): await send_resp(msg, writer, "remove_device", "error", {"message": "Invalid device_id"}) return if device_id == session.get("device_id"): await send_resp(msg, writer, "remove_device", "error", {"message": "Cannot remove current device"}) return dev = db.get_device(device_id) if not dev or dev["user_id"] != session["user_id"]: await send_resp(msg, writer, "remove_device", "error", {"message": "Device not found"}) return db.delete_device(device_id) logger.info("Device removed: %s", device_id) await send_resp(msg, writer, "remove_device", "ok", {"message": "OK"}) async def handle_session_reset(msg: dict, session: dict, writer: ProtocolWriter): """Notify peer to reset a corrupted Double Ratchet session.""" peer_user_id = msg.get("peer_user_id", "").strip() peer_device_id = msg.get("peer_device_id", "").strip() or None if not peer_user_id or not _valid_uuid(peer_user_id): await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_user_id"}) return if peer_device_id and not _valid_uuid(peer_device_id): await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_device_id"}) return # Push notification to peer await _notify_users([peer_user_id], "session_reset", { "from_user_id": session["user_id"], "from_device_id": session.get("device_id"), }) await send_resp(msg, writer, "session_reset", "ok", {}) async def handle_reencrypt_messages(msg: dict, session: dict, writer: ProtocolWriter): """Re-encrypt message history with self-encryption key (for device pairing).""" updates_raw = msg.get("updates", []) if not updates_raw: await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No updates"}) return if len(updates_raw) > 500: await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "Too many updates (max 500 per request)"}) return updates = [] for u in updates_raw: mid = u.get("message_id", "") ct_b64 = u.get("encrypted_content", "") nonce_b64 = u.get("nonce", "") if not mid or not ct_b64 or not nonce_b64: continue updates.append({ "message_id": mid, "encrypted_content": decode_binary(ct_b64), "nonce": decode_binary(nonce_b64), }) if not updates: await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No valid updates"}) return db.batch_reencrypt_messages(session["user_id"], updates) logger.info("Re-encrypted %d messages for user.", len(updates)) await send_resp(msg, writer, "reencrypt_messages", "ok", {"count": len(updates)}) async def _cleanup_uploads(): stale = db.get_stale_uploads(3600) for s in stale: fid = s["file_id"] for ext in (".tmp", ".enc"): p = _safe_upload_path(fid, ext) if not p: continue try: p.unlink(missing_ok=True) except Exception: pass db.delete_image_upload(fid) async with _uploads_lock: pending_uploads.pop(fid, None) if stale: logger.info("Cleaned up %d stale uploads.", len(stale)) async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): global current_connections addr = _get_peer_addr(ProtocolWriter(writer)) async with _conn_lock: current_connections += 1 connection_counts[addr] = connection_counts.get(addr, 0) + 1 over_limit = (current_connections > MAX_CONNECTIONS_GLOBAL or connection_counts[addr] > MAX_CONNECTIONS_PER_IP) if over_limit: try: writer.close() except Exception: pass async with _conn_lock: current_connections = max(0, current_connections - 1) connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1) return logger.debug("Client connected.") proto_reader = ProtocolReader(reader) proto_writer = ProtocolWriter(writer) session = None state = {"_req_times": []} try: while True: try: msg = await proto_reader.read_message() except ValueError as e: try: await proto_writer.send_response("protocol_error", "error", {"message": str(e)}) except Exception: pass break if msg is None: break msg_type = msg.get("type", "") now = asyncio.get_event_loop().time() times = [t for t in state["_req_times"] if now - t <= CONNECTION_RL_WINDOW] if len(times) >= CONNECTION_RL_MAX: await send_resp(msg, proto_writer, msg_type, "error", {"message": "Too many requests. Slow down."}) state["_req_times"] = times continue times.append(now) state["_req_times"] = times try: if msg_type == "register": await handle_register_start(msg, proto_writer) elif msg_type == "register_confirm": await handle_register_confirm(msg, proto_writer) elif msg_type == "login_start": await handle_login_start(msg, proto_writer, state) elif msg_type == "login_finish": result = await handle_login_finish(msg, proto_writer, state) if result: session = result elif msg_type == "pairing_start": await handle_pairing_start(msg, proto_writer) elif msg_type == "pairing_poll": await handle_pairing_poll(msg, proto_writer) elif session is None: await send_resp(msg, proto_writer, msg_type, "error", {"message": "Not logged in"}) elif msg_type == "get_user_info": await handle_get_user_info(msg, proto_writer) elif msg_type == "upload_prekeys": await handle_upload_prekeys(msg, session, proto_writer) elif msg_type == "get_key_bundle": await handle_get_key_bundle(msg, session, proto_writer) elif msg_type == "get_prekey_count": await handle_get_prekey_count(msg, session, proto_writer) elif msg_type == "create_conversation": await handle_create_conversation(msg, session, proto_writer) elif msg_type == "find_conversation": await handle_find_conversation(msg, session, proto_writer) elif msg_type == "add_member": await handle_add_member(msg, session, proto_writer) elif msg_type == "accept_invitation": await handle_accept_invitation(msg, session, proto_writer) elif msg_type == "decline_invitation": await handle_decline_invitation(msg, session, proto_writer) elif msg_type == "list_invitations": await handle_list_invitations(msg, session, proto_writer) elif msg_type == "list_conversations": await handle_list_conversations(msg, session, proto_writer) elif msg_type == "send_message": await handle_send_message(msg, session, proto_writer) elif msg_type == "get_messages": await handle_get_messages(msg, session, proto_writer) elif msg_type == "rotate_keys": await handle_rotate_keys(msg, session, proto_writer) elif msg_type == "remove_member": await handle_remove_member(msg, session, proto_writer) elif msg_type == "leave_group": await handle_leave_group(msg, session, proto_writer) elif msg_type == "rename_conversation": await handle_rename_conversation(msg, session, proto_writer) elif msg_type == "delete_conversation": await handle_delete_conversation(msg, session, proto_writer) elif msg_type == "mark_read": await handle_mark_read(msg, session, proto_writer) elif msg_type == "pairing_claim": await handle_pairing_claim(msg, session, proto_writer) elif msg_type == "pairing_send": await handle_pairing_send(msg, session, proto_writer) elif msg_type == "delete_message": await handle_delete_message(msg, session, proto_writer) elif msg_type == "upload_image_start": await handle_upload_image_start(msg, session, proto_writer) elif msg_type == "upload_image_chunk": await handle_upload_image_chunk(msg, session, proto_writer) elif msg_type == "upload_image_end": await handle_upload_image_end(msg, session, proto_writer) elif msg_type == "download_image": await handle_download_image(msg, session, proto_writer) elif msg_type == "get_profile": await handle_get_profile(msg, session, proto_writer) elif msg_type == "update_profile": await handle_update_profile(msg, session, proto_writer) elif msg_type == "update_avatar": await handle_update_avatar(msg, session, proto_writer) elif msg_type == "get_avatar": await handle_get_avatar(msg, session, proto_writer) elif msg_type == "update_group_avatar": await handle_update_group_avatar(msg, session, proto_writer) elif msg_type == "get_group_avatar": await handle_get_group_avatar(msg, session, proto_writer) elif msg_type == "reencrypt_messages": await handle_reencrypt_messages(msg, session, proto_writer) elif msg_type == "list_devices": await handle_list_devices(msg, session, proto_writer) elif msg_type == "remove_device": await handle_remove_device(msg, session, proto_writer) elif msg_type == "session_reset": await handle_session_reset(msg, session, proto_writer) else: await send_resp(msg, proto_writer, msg_type, "error", {"message": "Unknown type"}) except Exception as e: logger.warning("Handler error for '%s': %s", msg_type, e, exc_info=True) try: await send_resp(msg, proto_writer, msg_type, "error", {"message": "Internal server error"}) except Exception: break # Can't send response — connection is dead except Exception as e: logger.warning("Client connection error: %s", e) finally: async with _conn_lock: current_connections = max(0, current_connections - 1) connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1) offline_targets = [] if session: uid = session["user_id"] contacts = db.get_user_contacts(uid) async with _clients_lock: writer_device_map.pop(id(proto_writer), None) if uid in connected_clients: remaining = [w for w in connected_clients[uid] if w is not proto_writer] if remaining: connected_clients[uid] = remaining else: del connected_clients[uid] # User fully offline — snapshot targets under lock for contact_id in contacts: for cw in connected_clients.get(contact_id, []): offline_targets.append(cw) # Send offline notifications outside lock for cw in offline_targets: try: await cw.send_response("user_offline", "ok", {"user_id": uid}) except Exception: pass writer.close() logger.debug("Client disconnected.") async def main(): setup_logging() host = os.getenv("SERVER_HOST", "127.0.0.1") port = int(os.getenv("SERVER_PORT", "9999")) tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes") tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes") tls_autogen = os.getenv("TLS_AUTOGEN", "false").lower() in ("1", "true", "yes") is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") ssl_context = None if tls_required and not tls_enabled: raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.") if tls_enabled: cert_file = os.getenv("TLS_CERT_FILE", "").strip() key_file = os.getenv("TLS_KEY_FILE", "").strip() if not cert_file or not key_file: if tls_autogen: if not is_dev: raise RuntimeError("TLS_AUTOGEN is only allowed when ENVIRONMENT=dev") cert_dir = Path(__file__).resolve().parent / "certs" cert_dir.mkdir(parents=True, exist_ok=True) cert_file = str(cert_dir / "server.crt") key_file = str(cert_dir / "server.key") if not (os.path.exists(cert_file) and os.path.exists(key_file)): try: subprocess.run( [ "openssl", "req", "-x509", "-newkey", "rsa:4096", "-keyout", key_file, "-out", cert_file, "-days", "365", "-nodes", "-subj", "/CN=localhost", ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) os.chmod(key_file, 0o600) except FileNotFoundError: raise RuntimeError("OpenSSL not found.") except subprocess.CalledProcessError: raise RuntimeError("Failed to auto-generate TLS cert.") logger.warning("Using auto-generated self-signed certificate — not for production use.") else: raise RuntimeError("TLS is enabled but TLS_CERT_FILE or TLS_KEY_FILE is missing.") ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file) else: logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.") UPLOAD_DIR.mkdir(parents=True, exist_ok=True) # Load phantom user IDs from DB into in-memory cache phantom_user_ids.update(db.get_all_phantom_user_ids()) if phantom_user_ids: logger.info("Loaded %d phantom user IDs.", len(phantom_user_ids)) server = await asyncio.start_server( handle_client, host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context, ) logger.info("Encrypted chat server v%s listening on %s:%s", VERSION, host, port) async def _cleanup_rate_limits(): async with _conn_lock: now = asyncio.get_event_loop().time() window_start = now - RATE_LIMIT_WINDOW stale_keys = [k for k, times in rate_limits.items() if not any(t >= window_start for t in times)] for k in stale_keys: del rate_limits[k] stale_conns = [k for k, v in connection_counts.items() if v <= 0] for k in stale_conns: del connection_counts[k] async def _periodic_cleanup(): while True: await asyncio.sleep(600) try: await _cleanup_uploads() except Exception as e: logger.warning("Upload cleanup error: %s", e) try: await _cleanup_rate_limits() except Exception as e: logger.warning("Rate limit cleanup error: %s", e) # L8: clean up stale phantom users (>30 days, no real conversations) try: deleted = db.cleanup_stale_phantoms(30) if deleted: async with _clients_lock: phantom_user_ids.clear() phantom_user_ids.update(db.get_all_phantom_user_ids()) logger.info("Cleaned up %d stale phantom users.", deleted) except Exception as e: logger.warning("Phantom cleanup error: %s", e) asyncio.create_task(_periodic_cleanup()) loop = asyncio.get_running_loop() stop = loop.create_future() def signal_handler(): if not stop.done(): stop.set_result(None) for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, signal_handler) async with server: await stop # Force-close all connected clients BEFORE exiting context manager, # otherwise wait_closed() blocks forever waiting for handle_client tasks logger.info("Shutting down — closing %d client connections...", sum(len(ws) for ws in connected_clients.values())) async with _clients_lock: all_writers = [w for writers in connected_clients.values() for w in writers] connected_clients.clear() writer_device_map.clear() for w in all_writers: try: w.close() except Exception: pass logger.info("Server shut down.") if __name__ == "__main__": asyncio.run(main())