"""Asyncio TCP server — stores and relays encrypted blobs without seeing content.""" import asyncio from concurrent.futures import ThreadPoolExecutor import hashlib import hmac import ipaddress import json import logging import os import re import secrets import signal import smtplib import ssl import subprocess import sys from email.mime.text import MIMEText from pathlib import Path from datetime import datetime, timezone from dotenv import load_dotenv load_dotenv() import db from crypto_utils import load_public_key, rsa_verify, load_ed25519_public, ed25519_verify, serialize_x25519_public from protocol import VERSION, MIN_CLIENT_VERSION, version_gte, ProtocolReader, ProtocolWriter, encode_binary, decode_binary, MAX_MESSAGE_BYTES, MAX_IMAGE_BYTES, MAX_FILE_BYTES, IMAGE_CHUNK_SIZE class _AsyncDB: """Async proxy — offloads every synchronous db.* call to a thread via asyncio.to_thread(). This prevents blocking the asyncio event loop during MySQL I/O. Wrapper functions are cached after first access for efficiency. """ def __getattr__(self, name: str): func = getattr(db, name) async def wrapper(*args, **kwargs): return await asyncio.to_thread(func, *args, **kwargs) wrapper.__name__ = name wrapper.__qualname__ = f"_AsyncDB.{name}" setattr(self, name, wrapper) return wrapper adb = _AsyncDB() # Connected clients: user_id -> list[ProtocolWriter] connected_clients: dict[str, list[ProtocolWriter]] = {} # Writer -> device_id mapping (id(writer) -> device_id) writer_device_map: dict[int, str] = {} # Pairing sessions: code -> data pairing_sessions: dict[str, dict] = {} pending_registrations: dict[str, dict] = {} # Used PoW challenges (prevents replay within validity window) _used_pow_challenges: dict[str, float] = {} # challenge -> used_at # Pending image uploads: file_id -> {temp_path, received_bytes, file_size, conv_id} pending_uploads: dict[str, dict] = {} # Phantom user IDs (loaded at startup, updated on create/delete) phantom_user_ids: set[str] = set() # Locks for shared mutable state (H4 race condition fix) _clients_lock = asyncio.Lock() # Protects: connected_clients, writer_device_map, phantom_user_ids _conn_lock = asyncio.Lock() # Protects: connection_counts, current_connections, rate_limits _pairing_lock = asyncio.Lock() # Protects: pairing_sessions, pending_registrations, _used_pow_challenges _uploads_lock = asyncio.Lock() # Protects: pending_uploads _phantom_lock = asyncio.Lock() # Serializes phantom user creation (cap check + DB create + set add) UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "uploads")) def _secure_delete(p: Path): """Overwrite file with random data before deletion (anti-forensic wipe).""" try: if not p.exists(): return size = p.stat().st_size if size > 0: with open(p, "r+b") as f: f.write(os.urandom(size)) f.flush() os.fsync(f.fileno()) p.unlink() except Exception: try: p.unlink(missing_ok=True) except Exception: pass # C6 fix: UUID validation + safe path construction to prevent path traversal _UUID_RE = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE) def _valid_uuid(value: str) -> bool: """Validate that value is a canonical UUID (no path components).""" return bool(_UUID_RE.match(value)) # L8 fix: email validation to prevent phantom DB inflation _EMAIL_RE = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") def _valid_email(email: str) -> bool: """Validate basic email format (L8).""" return bool(_EMAIL_RE.match(email)) and len(email) <= 254 # C2 fix: ratchet/x3dh header validation _RATCHET_HEADER_KEYS = {"dh_pub", "n", "pn"} _MAX_HEADER_BYTES = 4096 def _validate_header(raw, name: str) -> bytes | None: """Validate and serialize a ratchet/x3dh header. Accepts only dict with expected keys, rejects str/bytes to prevent poisoned headers from being stored. Validates that ratchet headers contain the required keys (dh_pub, n, pn) with correct types. Returns UTF-8 encoded JSON bytes or None if invalid. """ if not isinstance(raw, dict): return None serialized = json.dumps(raw) if len(serialized) > _MAX_HEADER_BYTES: return None # Validate ratchet header keys/types if this looks like one if name in ("ratchet_header", "recipient_ratchet_header"): # Accept self-encrypted marker {"self": true} if raw.get("self") is True and len(raw) == 1: return serialized.encode() if not _RATCHET_HEADER_KEYS.issubset(raw.keys()): return None if not isinstance(raw.get("dh_pub"), str): return None if type(raw.get("n")) is not int or type(raw.get("pn")) is not int: return None return serialized.encode() def _append_file(path: Path, data: bytes): """Append data to file (runs in thread pool to avoid blocking event loop).""" with open(path, "ab") as f: f.write(data) def _read_file_chunk(path: Path, offset: int, size: int) -> bytes: """Read a chunk from file (runs in thread pool to avoid blocking event loop).""" with open(path, "rb") as f: f.seek(offset) return f.read(size) def _safe_upload_path(file_id: str, suffix: str) -> Path | None: """Return resolved path inside UPLOAD_DIR, or None if traversal detected.""" p = (UPLOAD_DIR / f"{file_id}{suffix}").resolve() if not p.is_relative_to(UPLOAD_DIR.resolve()): return None return p def _safe_avatar_path(filename: str) -> Path | None: """Return resolved avatar path inside UPLOAD_DIR/avatars, or None if traversal detected.""" avatar_dir = (UPLOAD_DIR / "avatars").resolve() p = (UPLOAD_DIR / "avatars" / filename).resolve() if not p.is_relative_to(avatar_dir): return None return p PAIRING_TTL_SECONDS = 120 REGISTER_TTL_SECONDS = 600 # 10 min (was 3600) — faster slot release under load PAIRING_MAX_POLL_ATTEMPTS = 90 PAIRING_MAX_SESSIONS = 100 # global cap on concurrent pairing sessions MAX_PENDING_REGISTRATIONS = 1000 # global cap on pending registration codes MAX_PENDING_PER_IP = 5 # per-IP cap on pending registrations MAX_PENDING_PER_SUBNET = 20 # per-/24 (IPv4) or /64 (IPv6) cap REGISTRATION_PRESSURE_THRESHOLD = 0.8 # 80% → tighten limits + require PoW POW_DIFFICULTY = 20 # leading zero bits in SHA-256 (~1M hashes, ~0.5-2s) SMTP_RATE_GLOBAL = 30 # registration emails per minute (global) SMTP_RATE_PER_IP = 3 # registration emails per minute (per IP) SMTP_RATE_PER_TARGET = 2 # registration emails per minute (per target email) MAX_PHANTOM_USERS = 500 # global cap on phantom user count MAX_UPLOADS_GLOBAL = 200 # global cap on concurrent in-flight uploads MAX_UPLOADS_PER_USER = 5 # per-user cap on concurrent in-flight uploads UPLOAD_STALE_SECONDS = 600 # stale upload threshold (10 min) # SMTP configuration for registration codes SMTP_HOST = os.getenv("SMTP_HOST", "") SMTP_PORT = int(os.getenv("SMTP_PORT", "587")) SMTP_USER = os.getenv("SMTP_USER", "") SMTP_PASS = os.getenv("SMTP_PASS", "") SMTP_FROM = os.getenv("SMTP_FROM", "") RATE_LIMIT_WINDOW = 60.0 # seconds CONNECTION_RL_WINDOW = 1.0 # seconds CONNECTION_RL_MAX = 20 # max requests per window per connection MAX_CONNECTIONS_PER_IP = 10 MAX_CONNECTIONS_GLOBAL = 200 METADATA_RETENTION_DAYS = int(os.getenv("METADATA_RETENTION_DAYS", "90")) def setup_logging(): level_name = os.getenv("LOG_LEVEL", "INFO").upper() level = getattr(logging, level_name, logging.WARNING) logging.basicConfig(level=level, format="%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S") logger = logging.getLogger("encrypted_chat.server") def _who(session: dict | None) -> str: """Format session info for logging: truncated user_id + device prefix. Avoids leaking usernames and emails into log files. """ if not session: return "" uid = session.get("user_id", "?")[:8] dev = session.get("device_id", "")[:8] if session.get("device_id") else "" return f"u={uid} d={dev}" if dev else f"u={uid}" rate_limits: dict[str, list[float]] = {} connection_counts: dict[str, int] = {} current_connections = 0 def _rate_limit_key(action: str, addr: str, email: str | None = None) -> str: if email: return f"{action}|{addr}|{email.lower()}" return f"{action}|{addr}" async def _is_rate_limited(key: str, limit: int) -> bool: async with _conn_lock: now = asyncio.get_event_loop().time() window_start = now - RATE_LIMIT_WINDOW times = rate_limits.get(key, []) times = [t for t in times if t >= window_start] if len(times) >= limit: rate_limits[key] = times return True times.append(now) rate_limits[key] = times return False async def _create_phantom_guarded(email: str, addr: str, user_id: str) -> tuple[dict | None, str]: """Check limits + create phantom user atomically (serialized via _phantom_lock). Returns (user_dict, error_message). user_dict is None on rejection. """ # Rate limit checks outside _phantom_lock (they acquire _conn_lock) if await _is_rate_limited(f"phantom_create|{user_id}", 10): return None, "Too many new contacts. Try later." if await _is_rate_limited(f"phantom_create_ip|{addr}", 10): return None, "Too many new contacts. Try later." async with _phantom_lock: async with _clients_lock: phantom_count = len(phantom_user_ids) if phantom_count >= MAX_PHANTOM_USERS: return None, "Server limit reached. Try later." u = await adb.create_phantom_user(email) async with _clients_lock: phantom_user_ids.add(u["id"]) return u, "" def _get_peer_addr(writer: ProtocolWriter) -> str: try: return str(writer._writer.get_extra_info("peername")[0]) except Exception: return "unknown" async def _notify_users(user_ids, msg_type, data, exclude_writer=None): """Snapshot writers under lock, send notifications outside lock.""" targets = [] async with _clients_lock: for uid in user_ids: for w in connected_clients.get(uid, []): targets.append(w) for w in targets: if w is exclude_writer: continue try: await w.send_response(msg_type, "ok", data) except Exception: pass async def _notify_users_individual(notifications, exclude_writer=None): """Send per-user data. notifications: list of (user_id, msg_type, data).""" targets = [] async with _clients_lock: for uid, mt, d in notifications: for w in connected_clients.get(uid, []): targets.append((w, mt, d)) for w, mt, d in targets: if w is exclude_writer: continue try: await w.send_response(mt, "ok", d) except Exception: pass async def _cleanup_pairings(): async with _pairing_lock: now = asyncio.get_event_loop().time() expired = [code for code, p in pairing_sessions.items() if now - p["created_at"] > PAIRING_TTL_SECONDS] for code in expired: pairing_sessions.pop(code, None) async def _cleanup_registrations(): async with _pairing_lock: now = asyncio.get_event_loop().time() expired = [code for code, p in pending_registrations.items() if now - p["created_at"] > REGISTER_TTL_SECONDS] for code in expired: pending_registrations.pop(code, None) # Purge used PoW challenges older than 120s (validity window) stale = [ch for ch, ts in _used_pow_challenges.items() if now - ts > 120] for ch in stale: _used_pow_challenges.pop(ch, None) def _generate_pairing_code() -> str: for _ in range(10): code = f"{int.from_bytes(os.urandom(4), 'big') % 100000000:08d}" if code not in pairing_sessions: return code return f"{int.from_bytes(os.urandom(4), 'big') % 100000000:08d}" def _generate_register_code() -> str: for _ in range(10): code = f"{int.from_bytes(os.urandom(3), 'big') % 1000000:06d}" if code not in pending_registrations: return code return f"{int.from_bytes(os.urandom(3), 'big') % 1000000:06d}" def _validate_public_key_pem(pem_str: str) -> bool: """Validate that a string is a valid RSA public key PEM.""" try: key = load_public_key(pem_str.encode("utf-8")) if key.key_size < 2048: return False return True except Exception: return False def _send_registration_email(to_email: str, code: str) -> bool: """Send registration code via SMTP. Returns True on success.""" if not SMTP_HOST: return False try: msg = MIMEText(f"Your registration code is: {code}\n\nThis code expires in 10 minutes.") msg["Subject"] = "Encrypted Chat - Registration Code" msg["From"] = SMTP_FROM or SMTP_USER msg["To"] = to_email with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=10) as server: server.starttls() if SMTP_USER: server.login(SMTP_USER, SMTP_PASS) server.send_message(msg) return True except Exception as e: logger.warning("Failed to send registration email: %s", e) return False async def send_resp(msg: dict, writer: ProtocolWriter, msg_type: str, status: str, data: dict | None = None): await writer.send_response(msg_type, status, data, request_id=msg.get("request_id")) # --- Registration admission control --- _POW_SECRET = os.urandom(32) # per-process; restarts invalidate outstanding challenges def _get_subnet(addr: str) -> str: """Extract /24 for IPv4, /64 for IPv6.""" try: ip = ipaddress.ip_address(addr) if ip.version == 4: return str(ipaddress.ip_network(f"{ip}/24", strict=False)) return str(ipaddress.ip_network(f"{ip}/64", strict=False)) except ValueError: return addr def _pending_counts_by_origin(addr: str) -> tuple[int, int]: """Count pending registrations by IP and subnet. Caller must hold _pairing_lock.""" subnet = _get_subnet(addr) ip_count = 0 subnet_count = 0 for p in pending_registrations.values(): p_addr = p.get("addr", "") if p_addr == addr: ip_count += 1 if _get_subnet(p_addr) == subnet: subnet_count += 1 return ip_count, subnet_count def _generate_pow_challenge() -> tuple[str, str]: """Generate a stateless PoW challenge (challenge, mac). The challenge embeds a timestamp so the server can reject stale solutions. The HMAC proves the challenge was issued by this server instance. """ ts = str(int(asyncio.get_event_loop().time())) nonce = secrets.token_hex(16) challenge = f"{ts}:{nonce}" mac = hmac.new(_POW_SECRET, challenge.encode(), hashlib.sha256).hexdigest() return challenge, mac def _verify_pow(challenge: str, mac: str, nonce: str, difficulty: int) -> bool: """Verify a PoW solution: HMAC authentic, timestamp fresh, hash has leading zeros.""" # Verify HMAC expected = hmac.new(_POW_SECRET, challenge.encode(), hashlib.sha256).hexdigest() if not hmac.compare_digest(expected, mac): return False # Check timestamp freshness (120s window) try: ts = int(challenge.split(":")[0]) except (ValueError, IndexError): return False now = int(asyncio.get_event_loop().time()) if abs(now - ts) > 120: return False # Verify PoW: SHA-256(challenge + nonce) must have `difficulty` leading zero bits digest = hashlib.sha256(f"{challenge}{nonce}".encode()).digest() # Check leading zero bits bits_needed = difficulty for byte in digest: if bits_needed <= 0: break if bits_needed >= 8: if byte != 0: return False bits_needed -= 8 else: mask = (0xFF << (8 - bits_needed)) & 0xFF if byte & mask: return False bits_needed = 0 return True async def handle_register_start(msg: dict, writer: ProtocolWriter) -> dict | None: await _cleanup_registrations() username = msg.get("username", "").strip() public_key = msg.get("public_key", "").strip() identity_key_b64 = msg.get("identity_key", "").strip() email = msg.get("email", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("register_start", addr, email), 3): await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."}) return None # Per-IP limit (regardless of email) to prevent SMTP spam via email rotation if await _is_rate_limited(f"register_start_ip|{addr}", 6): await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."}) return None if not username or not public_key or not email or not identity_key_b64: await send_resp(msg, writer, "register_start", "error", {"message": "Missing fields"}) return None if not _validate_public_key_pem(public_key): await send_resp(msg, writer, "register_start", "error", {"message": "Invalid public key format"}) return None # Validate identity key is 32 bytes try: ik_bytes = decode_binary(identity_key_b64) if len(ik_bytes) != 32: raise ValueError("Identity key must be 32 bytes") load_ed25519_public(ik_bytes) except Exception: await send_resp(msg, writer, "register_start", "error", {"message": "Invalid identity key"}) return None existing_email = await adb.get_user_by_email(email) phantom_id = None is_existing_real_user = False if existing_email: if existing_email.get("rsa_public_key") == "PHANTOM": phantom_id = existing_email["id"] else: is_existing_real_user = True # --- Admission control (all checks under lock, I/O outside) --- # Existing-email goes through the same path so responses are # indistinguishable from new-email (H3 anti-enumeration). # Both allocate a slot so per-IP/subnet cap counting is identical. async with _pairing_lock: total = len(pending_registrations) # Hard cap if total >= MAX_PENDING_REGISTRATIONS: reject_reason = "cap" else: # Per-IP / per-subnet slot limits ip_count, subnet_count = _pending_counts_by_origin(addr) if ip_count >= MAX_PENDING_PER_IP: reject_reason = "ip" elif subnet_count >= MAX_PENDING_PER_SUBNET: reject_reason = "subnet" else: reject_reason = None # Pressure mode: require PoW when >80% full under_pressure = total >= MAX_PENDING_REGISTRATIONS * REGISTRATION_PRESSURE_THRESHOLD need_pow = under_pressure and reject_reason is None # If PoW required, verify the client's solution (one-time use) pow_ok = False if need_pow: pow_challenge = msg.get("pow_challenge", "") pow_mac = msg.get("pow_mac", "") pow_nonce = msg.get("pow_nonce", "") if pow_challenge and pow_mac and pow_nonce: if pow_challenge in _used_pow_challenges: pow_ok = False # replay elif _verify_pow(pow_challenge, pow_mac, pow_nonce, POW_DIFFICULTY): _used_pow_challenges[pow_challenge] = asyncio.get_event_loop().time() pow_ok = True # Decide: admit, challenge, or reject if reject_reason: admit = False send_challenge = False code = None elif need_pow and not pow_ok: admit = False send_challenge = True code = None else: # Both existing and new emails allocate a slot so per-IP/subnet # counting behaves identically (anti-enumeration via slot side-channel). # Existing-email slots are inert — register_confirm silently fails. admit = True send_challenge = False code = _generate_register_code() pending_registrations[code] = { "username": username, "public_key": public_key, "identity_key": ik_bytes, "email": email, "created_at": asyncio.get_event_loop().time(), "phantom_id": phantom_id, "addr": addr, "fake": is_existing_real_user, } # --- I/O outside lock --- if not admit: if send_challenge: challenge, mac = _generate_pow_challenge() await send_resp(msg, writer, "register_start", "pow_required", { "challenge": challenge, "mac": mac, "difficulty": POW_DIFFICULTY, }) else: await send_resp(msg, writer, "register_start", "error", {"message": "Server busy. Try later."}) return None logger.info("[REGISTER] registration started") is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") # SMTP rate limiting smtp_blocked = False if SMTP_HOST: if await _is_rate_limited("smtp_send|global", SMTP_RATE_GLOBAL): smtp_blocked = True elif await _is_rate_limited(f"smtp_send_ip|{addr}", SMTP_RATE_PER_IP): smtp_blocked = True elif await _is_rate_limited(f"smtp_send_target|{email.lower()}", SMTP_RATE_PER_TARGET): smtp_blocked = True if smtp_blocked: if is_dev: logger.warning("[REGISTER] SMTP rate limit hit — returning code (dev mode)") await send_resp(msg, writer, "register_start", "ok", {"code": code}) else: logger.warning("[REGISTER] SMTP rate limit hit — revoking slot silently") async with _pairing_lock: pending_registrations.pop(code, None) await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."}) return None # Send registration email in a thread (non-blocking) for both real # and fake registrations. For existing emails we still call SMTP so # the response timing is indistinguishable (anti-enumeration). # The email goes to the real address either way — existing users just # won't be able to confirm (code is for a fake slot). email_sent = await asyncio.to_thread(_send_registration_email, email, code) if email_sent: await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."}) elif is_dev: logger.warning("[REGISTER] No SMTP / send failed — returning code (dev mode)") await send_resp(msg, writer, "register_start", "ok", {"code": code}) else: logger.warning("[REGISTER] SMTP send failed — revoking slot silently") async with _pairing_lock: pending_registrations.pop(code, None) await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."}) return None async def handle_register_confirm(msg: dict, writer: ProtocolWriter) -> dict | None: await _cleanup_registrations() email = msg.get("email", "").strip() code = msg.get("code", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("register_confirm", addr, email), 3): await send_resp(msg, writer, "register_confirm", "error", {"message": "Too many attempts. Try later."}) return None if not email or not code: await send_resp(msg, writer, "register_confirm", "error", {"message": "Missing email or code"}) return None async with _pairing_lock: pending = pending_registrations.get(code) if pending and pending.get("email") == email: pending_registrations.pop(code, None) else: pending = None if not pending: await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"}) return None # H3 anti-enumeration: fake slot (existing email) — reject with same # generic message so attacker can't distinguish from a wrong code if pending.get("fake"): await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"}) return None phantom_id = pending.get("phantom_id") if phantom_id: # Upgrade phantom in-place — preserves FK references (invitations, memberships) user_id = await adb.upgrade_phantom_user( phantom_id, pending["username"], pending["public_key"], pending["identity_key"], ) if user_id: async with _clients_lock: phantom_user_ids.discard(phantom_id) else: # Phantom was deleted concurrently — fall back to normal create user_id = await adb.create_user( pending["username"], pending["email"], pending["public_key"], pending["identity_key"], ) else: user_id = await adb.create_user( pending["username"], pending["email"], pending["public_key"], pending["identity_key"], ) await adb.create_default_profile(user_id) logger.info("[REGISTER] confirmed (user_id=%s)", user_id[:8]) await send_resp(msg, writer, "register_confirm", "ok", {"user_id": user_id}) return None async def handle_login_start(msg: dict, writer: ProtocolWriter, state: dict): email = msg.get("email", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("login_start", addr, email), 10): await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."}) return if await _is_rate_limited(f"login_start_ip|{addr}", 20): await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."}) return if not email: await send_resp(msg, writer, "login_start", "error", {"message": "Missing email"}) return user = await adb.get_user_by_email(email) challenge = os.urandom(32) state["login_email"] = email state["login_challenge"] = challenge if not user: # H3 anti-enumeration: return a fake challenge so attacker can't distinguish # "user not found" from "user exists". login_finish will fail with generic error. state["_login_fake"] = True await send_resp(msg, writer, "login_start", "ok", {"challenge": encode_binary(challenge)}) async def handle_login_finish(msg: dict, writer: ProtocolWriter, state: dict) -> dict | None: email = msg.get("email", "").strip() signature_b64 = msg.get("signature", "") challenge = state.get("login_challenge") expected_email = state.get("login_email") addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("login_finish", addr, email), 10): await send_resp(msg, writer, "login_finish", "error", {"message": "Too many attempts. Try later."}) return None if not email or not signature_b64: await send_resp(msg, writer, "login_finish", "error", {"message": "Missing email or signature"}) return None if not challenge or expected_email != email: await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) return None # H3: if login_start was for a non-existent user, fail with generic error is_fake = state.pop("_login_fake", False) try: if is_fake: await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) return None user = await adb.get_user_by_email(email) if not user: await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) return None public_key = load_public_key(user["rsa_public_key"].encode("utf-8")) signature = decode_binary(signature_b64) if not rsa_verify(public_key, signature, challenge): await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) return None except ValueError: # H5: invalid base64 in signature await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) return None finally: state.pop("login_challenge", None) state.pop("login_email", None) user_id = user["id"] # Version check: reject outdated clients client_version = msg.get("client_version", "") if client_version and not version_gte(client_version, MIN_CLIENT_VERSION): await send_resp(msg, writer, "login_finish", "error", { "message": f"Client version {client_version} is too old. Minimum required: {MIN_CLIENT_VERSION}", "min_version": MIN_CLIENT_VERSION, "server_version": VERSION, }) return None # Device registration: client may send device_id to reuse an existing device client_device_id = msg.get("device_id") device_id = None if client_device_id: dev = await adb.get_device(client_device_id) if dev and dev["user_id"] == user_id: device_id = client_device_id if not device_id: device_name = msg.get("device_name", "Unknown") device_id = await adb.create_device(user_id, device_name) await adb.update_device_last_seen(device_id) async with _clients_lock: if user_id not in connected_clients: connected_clients[user_id] = [] connected_clients[user_id].append(writer) writer_device_map[id(writer)] = device_id logger.info("[LOGIN] u=%s d=%s client_v=%s", user_id[:8], device_id[:8] if device_id else "?", client_version or "unknown") await send_resp(msg, writer, "login_finish", "ok", { "user_id": user_id, "username": user["username"], "email": user["email"], "device_id": device_id, "server_version": VERSION, }) # Send online status notifications contacts = await adb.get_user_contacts(user_id) online_targets = [] async with _clients_lock: online_contacts = [cid for cid in contacts if cid in connected_clients and connected_clients[cid]] # Always notify contacts (handles reconnect where old writer is still lingering) for contact_id in contacts: for cw in connected_clients.get(contact_id, []): online_targets.append(cw) await writer.send_response("online_users", "ok", {"user_ids": online_contacts}) # Send online notifications outside lock for cw in online_targets: try: await cw.send_response("user_online", "ok", {"user_id": user_id}) except Exception: pass return {"user_id": user_id, "username": user["username"], "email": user["email"], "device_id": device_id} async def handle_get_user_info(msg: dict, session: dict, writer: ProtocolWriter): """Get user info including identity key (for X3DH). Requires login.""" email = msg.get("email", "").strip() user_id = msg.get("user_id", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("get_user_info", addr, email or user_id), 30): await send_resp(msg, writer, "get_user_info", "error", {"message": "Too many attempts. Try later."}) return if user_id and not _valid_uuid(user_id): await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) return user = None if email: user = await adb.get_user_by_email(email) elif user_id: user = await adb.get_user_by_id(user_id) if not user: await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) return # H4 fix: restrict lookups to self or contacts (shared conversation) target_id = user["id"] if target_id != session["user_id"]: if not await adb.shares_conversation(session["user_id"], target_id): await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) return ik = user.get("identity_key") await send_resp(msg, writer, "get_user_info", "ok", { "user_id": user["id"], "username": user["username"], "email": user["email"], "identity_key": encode_binary(ik) if ik else "", }) async def handle_upload_prekeys(msg: dict, session: dict, writer: ProtocolWriter): """Upload signed prekey + batch of one-time prekeys.""" if await _is_rate_limited(f"upload_prekeys|{session['user_id']}", 5): await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Too many requests. Try later."}) return spk_data = msg.get("signed_prekey") otps = msg.get("one_time_prekeys", []) if not spk_data: await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Missing signed_prekey"}) return spk_id = spk_data.get("id", "") spk_pub_b64 = spk_data.get("public_key", "") spk_sig_b64 = spk_data.get("signature", "") if not spk_id or not spk_pub_b64 or not spk_sig_b64: await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Incomplete signed_prekey"}) return spk_pub = decode_binary(spk_pub_b64) spk_sig = decode_binary(spk_sig_b64) # Verify SPK signature with user's identity key user = await adb.get_user_by_id(session["user_id"]) if not user or not user.get("identity_key"): await send_resp(msg, writer, "upload_prekeys", "error", {"message": "No identity key"}) return ik_pub = load_ed25519_public(user["identity_key"]) if not ed25519_verify(ik_pub, spk_sig, spk_pub): await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid SPK signature"}) return device_id = session.get("device_id") await adb.store_signed_prekey(session["user_id"], spk_id, spk_pub, spk_sig, device_id=device_id) # Store OTPs otp_records = [] for otp in otps: otp_id = otp.get("id", "") otp_pub_b64 = otp.get("public_key", "") if otp_id and otp_pub_b64: otp_records.append({"id": otp_id, "public_key": decode_binary(otp_pub_b64)}) if otp_records: await adb.store_one_time_prekeys(session["user_id"], otp_records, device_id=device_id) logger.info("[PREKEYS] %s uploaded 1 SPK + %d OTPs", _who(session), len(otp_records)) await send_resp(msg, writer, "upload_prekeys", "ok", {"message": "OK"}) async def handle_get_key_bundle(msg: dict, session: dict, writer: ProtocolWriter): """Fetch key bundle for X3DH. Returns per-device bundles. Consumes one OTP per device.""" target_user_id = msg.get("user_id", "").strip() if not target_user_id: await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Missing user_id"}) return if not _valid_uuid(target_user_id): await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Invalid user_id"}) return # M4: rate limit + authorization (prevents OPK depletion) if await _is_rate_limited(f"get_key_bundle|{session['user_id']}", 10): await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Too many requests. Try later."}) return # Auth check before per-target rate limit so unauthorized requests don't burn target's bucket if target_user_id != session["user_id"]: if not await adb.shares_conversation(session["user_id"], target_user_id): await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Key bundle not available"}) return if await _is_rate_limited(f"get_key_bundle_target|{target_user_id}", 20): await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Too many requests. Try later."}) return result = await adb.get_key_bundles_for_user(target_user_id) if not result or not result.get("device_bundles"): await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Key bundle not available"}) return device_bundles_data = [] for b in result["device_bundles"]: entry = { "device_id": b.get("device_id"), "signed_prekey_id": b["signed_prekey_id"], "signed_prekey": encode_binary(b["signed_prekey_pub"]), "spk_signature": encode_binary(b["spk_signature"]), } if b.get("opk_pub"): entry["one_time_prekey_id"] = b["opk_id"] entry["one_time_prekey"] = encode_binary(b["opk_pub"]) device_bundles_data.append(entry) # Build response with both new multi-device format and legacy flat fields first = device_bundles_data[0] if device_bundles_data else {} data = { "identity_key": encode_binary(result["identity_key"]), "device_bundles": device_bundles_data, # Legacy flat fields from first device bundle (backward compat) "signed_prekey_id": first.get("signed_prekey_id", ""), "signed_prekey": first.get("signed_prekey", ""), "spk_signature": first.get("spk_signature", ""), } if first.get("one_time_prekey"): data["one_time_prekey_id"] = first["one_time_prekey_id"] data["one_time_prekey"] = first["one_time_prekey"] logger.info("[X3DH] %s fetched key bundle for user=%s (%d devices)", _who(session), target_user_id[:8], len(device_bundles_data)) await send_resp(msg, writer, "get_key_bundle", "ok", data) async def handle_get_prekey_count(msg: dict, session: dict, writer: ProtocolWriter): """How many OPKs does user have left (for this device)? Also returns SPK age for rotation.""" device_id = session.get("device_id") count = await adb.count_one_time_prekeys(session["user_id"], device_id=device_id) spk_created_at = "" spk = await adb.get_signed_prekey(session["user_id"], device_id=device_id) if spk and spk.get("created_at"): spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) await send_resp(msg, writer, "get_prekey_count", "ok", {"count": count, "spk_created_at": spk_created_at}) async def handle_ensure_prekeys(msg: dict, session: dict, writer: ProtocolWriter): """Combined get_prekey_count + upload_prekeys in one round-trip. Client sends current OPK/SPK data; server checks count and SPK age, stores new keys if provided, and returns the current status. """ if await _is_rate_limited(f"ensure_prekeys|{session['user_id']}", 5): await send_resp(msg, writer, "ensure_prekeys", "error", {"message": "Too many requests. Try later."}) return device_id = session.get("device_id") user_id = session["user_id"] # Step 1: Get current count + SPK age count = await adb.count_one_time_prekeys(user_id, device_id=device_id) spk_created_at = "" spk = await adb.get_signed_prekey(user_id, device_id=device_id) if spk and spk.get("created_at"): spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) # Step 2: If client included new keys, store them uploaded_spk = False uploaded_otps = 0 spk_data = msg.get("signed_prekey") if spk_data: spk_id = spk_data.get("id", "") spk_pub_b64 = spk_data.get("public_key", "") spk_sig_b64 = spk_data.get("signature", "") if spk_id and spk_pub_b64 and spk_sig_b64: spk_pub = decode_binary(spk_pub_b64) spk_sig = decode_binary(spk_sig_b64) user = await adb.get_user_by_id(user_id) if user and user.get("identity_key"): ik_pub = load_ed25519_public(user["identity_key"]) if ed25519_verify(ik_pub, spk_sig, spk_pub): await adb.store_signed_prekey(user_id, spk_id, spk_pub, spk_sig, device_id=device_id) uploaded_spk = True otps = msg.get("one_time_prekeys", []) if otps: otp_records = [] for otp in otps: otp_id = otp.get("id", "") otp_pub_b64 = otp.get("public_key", "") if otp_id and otp_pub_b64: otp_records.append({"id": otp_id, "public_key": decode_binary(otp_pub_b64)}) if otp_records: await adb.store_one_time_prekeys(user_id, otp_records, device_id=device_id) uploaded_otps = len(otp_records) # Recount after upload if uploaded_spk or uploaded_otps: count = await adb.count_one_time_prekeys(user_id, device_id=device_id) spk = await adb.get_signed_prekey(user_id, device_id=device_id) if spk and spk.get("created_at"): spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) logger.info("[PREKEYS] %s ensure_prekeys: uploaded SPK=%s, OTPs=%d, new count=%d", _who(session), uploaded_spk, uploaded_otps, count) await send_resp(msg, writer, "ensure_prekeys", "ok", {"count": count, "spk_created_at": spk_created_at, "uploaded_spk": uploaded_spk, "uploaded_otps": uploaded_otps}) async def handle_rotate_keys(msg: dict, session: dict, writer: ProtocolWriter): if await _is_rate_limited(f"rotate_keys|{session['user_id']}", 3): await send_resp(msg, writer, "rotate_keys", "error", {"message": "Too many requests. Try later."}) return public_key = msg.get("public_key", "").strip() if not public_key: await send_resp(msg, writer, "rotate_keys", "error", {"message": "Missing public_key"}) return if not _validate_public_key_pem(public_key): await send_resp(msg, writer, "rotate_keys", "error", {"message": "Invalid public key format"}) return await adb.update_user_rsa_key(session["user_id"], public_key) logger.info("[ROTATE] %s rotated RSA key", _who(session)) await send_resp(msg, writer, "rotate_keys", "ok", {"message": "OK"}) # Disconnect other sessions async with _clients_lock: writers = connected_clients.get(session["user_id"], []) others = [w for w in writers if w is not writer] connected_clients[session["user_id"]] = [writer] for w in others: try: w.close() except Exception: pass async def handle_change_username(msg: dict, session: dict, writer: ProtocolWriter): if await _is_rate_limited(f"change_username|{session['user_id']}", 5): await send_resp(msg, writer, "change_username", "error", {"message": "Too many requests. Try later."}) return new_username = msg.get("username", "").strip() if not new_username or len(new_username) > 100: await send_resp(msg, writer, "change_username", "error", {"message": "Invalid username (1-100 chars)"}) return user_id = session["user_id"] await adb.update_username(user_id, new_username) session["username"] = new_username logger.info("[ACCOUNT] %s changed username", _who(session)) await send_resp(msg, writer, "change_username", "ok", {"username": new_username}) # Notify contacts contacts = await adb.get_user_contacts(user_id) targets = [] async with _clients_lock: for cid in contacts: for cw in connected_clients.get(cid, []): targets.append(cw) for cw in targets: try: await cw.send_response("username_changed", "ok", { "user_id": user_id, "username": new_username, }) except Exception: pass async def handle_pairing_start(msg: dict, writer: ProtocolWriter): await _cleanup_pairings() email = msg.get("email", "").strip() temp_public_key = msg.get("temp_public_key", "").strip() addr = _get_peer_addr(writer) # H4 fix: rate limit per IP only (not per email) to prevent enumeration via email rotation if await _is_rate_limited(_rate_limit_key("pairing_start", addr), 10): await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."}) return if not email or not temp_public_key: await send_resp(msg, writer, "pairing_start", "error", {"message": "Missing email or temp_public_key"}) return poll_token = secrets.token_hex(16) cap_hit = False async with _pairing_lock: # H4 fix: global cap prevents memory exhaustion from dummy sessions if len(pairing_sessions) >= PAIRING_MAX_SESSIONS: cap_hit = True else: code = _generate_pairing_code() # H4 fix: always create session (anti-enumeration). For non-existent users # the session behaves identically (poll returns ready:false, claim never matches # because no real account can log in to claim it). TTL cleanup handles expiry. pairing_sessions[code] = { "email": email, "temp_public_key": temp_public_key, "created_at": asyncio.get_event_loop().time(), "payload": None, "poll_token": poll_token, } if cap_hit: await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."}) return await send_resp(msg, writer, "pairing_start", "ok", {"code": code, "poll_token": poll_token}) async def handle_pairing_claim(msg: dict, session: dict, writer: ProtocolWriter): await _cleanup_pairings() code = msg.get("code", "").strip() if not code: await send_resp(msg, writer, "pairing_claim", "error", {"message": "Missing code"}) return async with _pairing_lock: p = pairing_sessions.get(code) p_email = p["email"] if p else None temp_pub = p["temp_public_key"] if p else None if p: # Extend TTL — re-encryption may run between claim and send p["created_at"] = asyncio.get_event_loop().time() # H4 fix: unified error message (anti-enumeration) if not p or p_email != session.get("email"): await send_resp(msg, writer, "pairing_claim", "error", {"message": "Invalid or expired code"}) return await send_resp(msg, writer, "pairing_claim", "ok", {"temp_public_key": temp_pub}) async def handle_pairing_send(msg: dict, session: dict, writer: ProtocolWriter): await _cleanup_pairings() code = msg.get("code", "").strip() payload = msg.get("payload") if not code or not payload: await send_resp(msg, writer, "pairing_send", "error", {"message": "Missing code or payload"}) return error_msg = None async with _pairing_lock: p = pairing_sessions.get(code) # H4 fix: unified error message (anti-enumeration) if not p or p["email"] != session.get("email"): error_msg = "Invalid or expired code" else: p["payload"] = payload if error_msg: await send_resp(msg, writer, "pairing_send", "error", {"message": error_msg}) else: await send_resp(msg, writer, "pairing_send", "ok", {"message": "OK"}) async def handle_pairing_poll(msg: dict, writer: ProtocolWriter): await _cleanup_pairings() code = msg.get("code", "").strip() poll_token = msg.get("poll_token", "").strip() addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("pairing_poll", addr), 120): await send_resp(msg, writer, "pairing_poll", "error", {"message": "Too many attempts. Try later."}) return if not code: await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing code"}) return if not poll_token: await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing poll_token"}) return error_msg = None ready = False payload = None async with _pairing_lock: p = pairing_sessions.get(code) if not p: error_msg = "Invalid or expired code" elif not secrets.compare_digest(p.get("poll_token", ""), poll_token): error_msg = "Invalid poll_token" else: poll_attempts = p.get("poll_attempts", 0) + 1 p["poll_attempts"] = poll_attempts if poll_attempts > PAIRING_MAX_POLL_ATTEMPTS and not p.get("payload"): pairing_sessions.pop(code, None) error_msg = "Code invalidated due to too many attempts" elif p.get("payload"): ready = True payload = p["payload"] pairing_sessions.pop(code, None) if error_msg: await send_resp(msg, writer, "pairing_poll", "error", {"message": error_msg}) elif ready: await send_resp(msg, writer, "pairing_poll", "ok", {"ready": True, "payload": payload}) else: await send_resp(msg, writer, "pairing_poll", "ok", {"ready": False}) async def handle_create_conversation(msg: dict, session: dict, writer: ProtocolWriter): member_emails = msg.get("members", []) name = msg.get("name") addr = _get_peer_addr(writer) if await _is_rate_limited(f"create_conversation|{session['user_id']}", 10): await send_resp(msg, writer, "create_conversation", "error", {"message": "Too many attempts. Try later."}) return # Resolve all member user IDs other_users = [] for email in member_emails: u = await adb.get_user_by_email(email) if not u: if not _valid_email(email): await send_resp(msg, writer, "create_conversation", "error", {"message": f"Invalid email format: {email}"}) return # H5: atomic phantom creation (cap check + DB create + set add) u, err_msg = await _create_phantom_guarded(email, addr, session["user_id"]) if u is None: await send_resp(msg, writer, "create_conversation", "error", {"message": err_msg}) return if u["id"] != session["user_id"]: other_users.append(u) is_dm = len(other_users) == 1 and not name joined_at = datetime.now(timezone.utc) if is_dm: # DMs: add both members directly (no invitation) all_ids = [session["user_id"]] + [u["id"] for u in other_users] conv_id = await adb.create_conversation(all_ids, joined_at=joined_at, name=name, created_by=session["user_id"]) logger.info("[CONV] %s created DM conv=%s", _who(session), conv_id[:8]) await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id}) # Notify the other member members_info = await adb.get_conversation_members(conv_id) member_list = [{"user_id": m["id"], "username": m["username"], "email": m["email"]} for m in members_info] notif_data = { "conversation_id": conv_id, "name": name, "created_by": session["user_id"], "members": member_list, } await _notify_users([u["id"] for u in other_users], "conversation_created", notif_data) else: # Groups: only add creator, create invitations for others conv_id = await adb.create_conversation([session["user_id"]], joined_at=joined_at, name=name, created_by=session["user_id"]) logger.info("[CONV] %s created group conv=%s", _who(session), conv_id[:8]) # Create invitations for other members creator_user = await adb.get_user_by_id(session["user_id"]) creator_name = creator_user["username"] if creator_user else "Unknown" invited_ids = [] async with _clients_lock: phantom_snapshot = set(phantom_user_ids) for u in other_users: await adb.create_invitation(conv_id, u["id"], session["user_id"]) if u["id"] not in phantom_snapshot: invited_ids.append(u["id"]) # only notify non-phantoms inv_notif = { "conversation_id": conv_id, "conversation_name": name, "invited_by": session["user_id"], "invited_by_username": creator_name, } await _notify_users(invited_ids, "group_invitation", inv_notif) await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id}) async def handle_find_conversation(msg: dict, session: dict, writer: ProtocolWriter): email = msg.get("email", "").strip() if not email: await send_resp(msg, writer, "find_conversation", "error", {"message": "Invalid request"}) return addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("find_conversation", addr, email), 30): await send_resp(msg, writer, "find_conversation", "error", {"message": "Too many attempts. Try later."}) return other = await adb.get_user_by_email(email) if not other: if not _valid_email(email): await send_resp(msg, writer, "find_conversation", "error", {"message": "Invalid email format"}) return # H5: atomic phantom creation (cap check + DB create + set add) other, err_msg = await _create_phantom_guarded(email, addr, session["user_id"]) if other is None: await send_resp(msg, writer, "find_conversation", "error", {"message": err_msg}) return conv_id = await adb.find_direct_conversation(session["user_id"], other["id"]) await send_resp(msg, writer, "find_conversation", "ok", { "conversation_id": conv_id, "user_id": other["id"], }) async def handle_add_member(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") email = msg.get("email", "").strip() if not conv_id or not email: await send_resp(msg, writer, "add_member", "error", {"message": "Invalid request"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "add_member", "error", {"message": "Invalid conversation_id"}) return # L8: validate email format before phantom creation addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("add_member", addr, email), 10): await send_resp(msg, writer, "add_member", "error", {"message": "Too many attempts. Try later."}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "add_member", "error", {"message": "Not a member"}) return user = await adb.get_user_by_email(email) if not user: # Create phantom for unregistered email (same as create_conversation) if not _valid_email(email): await send_resp(msg, writer, "add_member", "error", {"message": "Invalid email format"}) return # H5: atomic phantom creation (cap check + DB create + set add) user, err_msg = await _create_phantom_guarded(email, addr, session["user_id"]) if user is None: await send_resp(msg, writer, "add_member", "error", {"message": err_msg}) return if await adb.is_conversation_member(conv_id, user["id"]): await send_resp(msg, writer, "add_member", "error", {"message": "Already a member"}) return if await adb.has_pending_invitation(conv_id, user["id"]): await send_resp(msg, writer, "add_member", "error", {"message": "Invitation already pending"}) return # Create invitation (for both real and phantom users) await adb.create_invitation(conv_id, user["id"], session["user_id"]) logger.info("[INVITE] %s invited u=%s to conv=%s", _who(session), user["id"][:8], conv_id[:8]) await send_resp(msg, writer, "add_member", "ok", {"user_id": user["id"]}) # Push invitation notification only to non-phantom users async with _clients_lock: is_phantom = user["id"] in phantom_user_ids if not is_phantom: conv = await adb.get_conversation(conv_id) creator_user = await adb.get_user_by_id(session["user_id"]) creator_name = creator_user["username"] if creator_user else "Unknown" inv_notif = { "conversation_id": conv_id, "conversation_name": conv.get("name") if conv else None, "invited_by": session["user_id"], "invited_by_username": creator_name, } await _notify_users([user["id"]], "group_invitation", inv_notif) async def handle_accept_invitation(msg: dict, session: dict, writer: ProtocolWriter): """Accept a group invitation — add user to conversation members.""" conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "accept_invitation", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "accept_invitation", "error", {"message": "Invalid conversation_id"}) return if not await adb.has_pending_invitation(conv_id, session["user_id"]): await send_resp(msg, writer, "accept_invitation", "error", {"message": "No pending invitation"}) return joined_at = datetime.now(timezone.utc) await adb.add_conversation_member(conv_id, session["user_id"], joined_at=joined_at) await adb.delete_invitation(conv_id, session["user_id"]) logger.info("[INVITE] %s accepted invitation to conv=%s", _who(session), conv_id[:8]) await send_resp(msg, writer, "accept_invitation", "ok", {"conversation_id": conv_id}) # Notify existing members about the new member user = await adb.get_user_by_id(session["user_id"]) notif_data = { "conversation_id": conv_id, "user_id": session["user_id"], "username": user["username"] if user else "", "email": user["email"] if user else "", } members = await adb.get_conversation_members(conv_id) member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "member_added", notif_data) async def handle_decline_invitation(msg: dict, session: dict, writer: ProtocolWriter): """Decline a group invitation.""" conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "decline_invitation", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "decline_invitation", "error", {"message": "Invalid conversation_id"}) return if not await adb.has_pending_invitation(conv_id, session["user_id"]): await send_resp(msg, writer, "decline_invitation", "error", {"message": "No pending invitation"}) return await adb.delete_invitation(conv_id, session["user_id"]) logger.info("[INVITE] %s declined invitation to conv=%s", _who(session), conv_id[:8]) await send_resp(msg, writer, "decline_invitation", "ok", {"message": "OK"}) async def handle_list_invitations(msg: dict, session: dict, writer: ProtocolWriter): """List pending group invitations for the current user.""" invitations = await adb.get_pending_invitations(session["user_id"]) result = [] for inv in invitations: entry = { "conversation_id": inv["conversation_id"], "conversation_name": inv.get("conversation_name"), "invited_by": inv["invited_by"], "invited_by_username": inv.get("invited_by_username", ""), "created_at": inv["created_at"].isoformat() if hasattr(inv["created_at"], "isoformat") else str(inv["created_at"]), } result.append(entry) await send_resp(msg, writer, "list_invitations", "ok", {"invitations": result}) async def handle_list_conversations(msg: dict, session: dict, writer: ProtocolWriter): convs = await adb.list_user_conversations(session["user_id"]) unread = await adb.get_unread_counts(session["user_id"], max_age_days=METADATA_RETENTION_DAYS) result = [] for c in convs: result.append({ "conversation_id": c["id"], "created_at": c["created_at"].isoformat() if hasattr(c["created_at"], "isoformat") else str(c["created_at"]), "members": c["members"], "name": c.get("name"), "created_by": c.get("created_by"), "avatar_file": c.get("avatar_file"), "unread_count": unread.get(c["id"], 0), }) logger.info("[LIST] %s listed %d conversations", _who(session), len(result)) await send_resp(msg, writer, "list_conversations", "ok", {"conversations": result}) async def handle_send_message(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "send_message", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "send_message", "error", {"message": "Invalid conversation_id"}) return addr = _get_peer_addr(writer) if await _is_rate_limited(_rate_limit_key("send_message", addr, session.get("email")), 20): await send_resp(msg, writer, "send_message", "error", {"message": "Too many attempts. Try later."}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "send_message", "error", {"message": "Not a member"}) return # New protocol: ratchet_header + recipients[] with per-user ciphertext ratchet_header_raw = msg.get("ratchet_header") recipients_raw = msg.get("recipients") if not ratchet_header_raw or not recipients_raw: await send_resp(msg, writer, "send_message", "error", {"message": "Missing ratchet_header or recipients"}) return # C2 fix: validate header is a dict (reject raw str/bytes) ratchet_header = _validate_header(ratchet_header_raw, "ratchet_header") if ratchet_header is None: await send_resp(msg, writer, "send_message", "error", {"message": "Invalid ratchet_header format"}) return x3dh_header_raw = msg.get("x3dh_header") x3dh_header = None if x3dh_header_raw: x3dh_header = _validate_header(x3dh_header_raw, "x3dh_header") if x3dh_header is None: await send_resp(msg, writer, "send_message", "error", {"message": "Invalid x3dh_header format"}) return sender_chain_id_b64 = msg.get("sender_chain_id") sender_chain_id = decode_binary(sender_chain_id_b64) if sender_chain_id_b64 else None sender_chain_n = msg.get("sender_chain_n") # Validate recipients are actual members conv_members = await adb.get_conversation_members(conv_id) member_ids = {m["id"] for m in conv_members} async with _clients_lock: phantom_snapshot = set(phantom_user_ids) db_recipients = [] for r in recipients_raw: uid = r.get("user_id", "") if uid not in member_ids: continue if uid in phantom_snapshot: continue ct_b64 = r.get("encrypted_content", "") nonce_b64 = r.get("nonce", "") if not ct_b64 or not nonce_b64: continue entry = { "user_id": uid, "encrypted_content": decode_binary(ct_b64), "nonce": decode_binary(nonce_b64), } # Per-recipient device_id (multi-device support) r_device_id = r.get("device_id") if r_device_id: entry["device_id"] = r_device_id # Per-recipient ratchet header and x3dh header (C2 fix: validate dict) r_rh = r.get("ratchet_header") if r_rh: r_rh_bytes = _validate_header(r_rh, "recipient_ratchet_header") if r_rh_bytes: entry["ratchet_header"] = r_rh_bytes r_x3dh = r.get("x3dh_header") if r_x3dh: r_x3dh_bytes = _validate_header(r_x3dh, "recipient_x3dh_header") if r_x3dh_bytes: entry["x3dh_header"] = r_x3dh_bytes db_recipients.append(entry) if not db_recipients: await send_resp(msg, writer, "send_message", "error", {"message": "No valid recipients"}) return image_file_id = msg.get("image_file_id") # Metadata privacy: for group messages (sender_chain_id present), store chain # metadata in per-recipient ratchet_header instead of the messages table. # This avoids persisting sender correlation data at the message level. # Skip sender's own self-copy entry — it uses a different decrypt path # (self-encryption key) and must keep its own ratchet_header ({"self":true}). db_sender_chain_id = None db_sender_chain_n = None if sender_chain_id: chain_meta = json.dumps({ "chain_id": encode_binary(sender_chain_id), "chain_n": sender_chain_n, }).encode() sender_uid = session["user_id"] for r in db_recipients: # Skip self-copy (sender's own entry) — uses self-encryption, not sender key if r["user_id"] == sender_uid: continue if not r.get("ratchet_header"): r["ratchet_header"] = chain_meta msg_id, created_at = await adb.store_message( conv_id, session["user_id"], ratchet_header, db_recipients, x3dh_header=x3dh_header, sender_chain_id=db_sender_chain_id, sender_chain_n=db_sender_chain_n, image_file_id=image_file_id, sender_device_id=session.get("device_id"), ) # Link image upload to message if present if image_file_id: upload = await adb.get_image_upload(image_file_id) if upload and upload["completed"] and upload["uploader_id"] == session["user_id"]: await adb.set_message_image_file_id(msg_id, image_file_id) logger.info("[MSG] %s msg=%s conv=%s", _who(session), msg_id[:8], conv_id[:8]) await send_resp(msg, writer, "send_message", "ok", {"message_id": msg_id, "created_at": created_at}) # Notify connected recipients — group all per-device entries by user_id # Use validated db_recipients (not raw input) to prevent unvalidated headers in push msg_ratchet_header_dict = json.loads(ratchet_header.decode()) msg_x3dh_header_dict = json.loads(x3dh_header.decode()) if x3dh_header else None from collections import defaultdict user_entries = defaultdict(list) for r in db_recipients: uid = r["user_id"] # Per-recipient headers are stored as bytes; decode back to dict for notification JSON r_rh = r.get("ratchet_header") r_rh_dict = json.loads(r_rh.decode()) if r_rh else None r_x3dh = r.get("x3dh_header") r_x3dh_dict = json.loads(r_x3dh.decode()) if r_x3dh else None user_entries[uid].append({ "device_id": r.get("device_id", db.SELF_DEVICE_ID), "encrypted_content": encode_binary(r["encrypted_content"]), "nonce": encode_binary(r["nonce"]), "ratchet_header": r_rh_dict or msg_ratchet_header_dict, "x3dh_header": r_x3dh_dict or msg_x3dh_header_dict, }) notifications = [] for uid, entries in user_entries.items(): notif_data = { "message_id": msg_id, "conversation_id": conv_id, "sender_id": session["user_id"], "sender_device_id": session.get("device_id"), "device_entries": entries, } if sender_chain_id_b64: notif_data["sender_chain_id"] = sender_chain_id_b64 if sender_chain_n is not None: notif_data["sender_chain_n"] = sender_chain_n # Also include flat fields for backward compat with old clients # (first entry's data as fallback) if entries: first = entries[0] notif_data["ratchet_header"] = first.get("ratchet_header") or msg_ratchet_header_dict notif_data["encrypted_content"] = first.get("encrypted_content", "") notif_data["nonce"] = first.get("nonce", "") if first.get("x3dh_header"): notif_data["x3dh_header"] = first["x3dh_header"] notifications.append((uid, "new_message", notif_data)) await _notify_users_individual(notifications, exclude_writer=writer) async def handle_get_messages(msg: dict, session: dict, writer: ProtocolWriter): if await _is_rate_limited(f"get_messages|{session['user_id']}", 30): await send_resp(msg, writer, "get_messages", "error", {"message": "Too many requests. Try later."}) return conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "get_messages", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid conversation_id"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "get_messages", "error", {"message": "Not a member"}) return limit = min(max(int(msg.get("limit", 50)), 1), 200) offset = max(int(msg.get("offset", 0)), 0) device_id = session.get("device_id") after_ts = msg.get("after_ts") # ISO timestamp string or None messages = await adb.get_messages(conv_id, session["user_id"], limit, offset, device_id=device_id, after_ts=after_ts) # Deduplicate: when both device-specific and SELF_DEVICE_ID rows exist for the # same message, prefer device-specific (non-sentinel). Keep first seen per message_id. seen_ids = {} deduped = [] for m in messages: mid = m["id"] mr_dev = m.get("mr_device_id", "") if mid not in seen_ids: seen_ids[mid] = len(deduped) deduped.append(m) elif mr_dev != db.SELF_DEVICE_ID: # Replace SELF_DEVICE_ID entry with device-specific one deduped[seen_ids[mid]] = m messages = deduped result = [] message_ids = [m["id"] for m in messages] read_status = await adb.get_message_read_status(message_ids) if message_ids else {} delivery_status = await adb.get_message_delivery_status(message_ids) if message_ids else {} reactions_map = await adb.get_reactions(message_ids) if message_ids else {} for m in messages: read_by = read_status.get(m["id"], []) # Prefer per-recipient headers (mr_*) over message-level headers rh_raw = m.get("mr_ratchet_header") or m.get("ratchet_header") x3dh_raw = m.get("mr_x3dh_header") or m.get("x3dh_header") # C2 fix: defensive JSON parsing — corrupted headers don't break fetch try: rh_parsed = json.loads(rh_raw) if rh_raw else {} except (json.JSONDecodeError, TypeError, UnicodeDecodeError): logger.warning("[FETCH] Corrupted ratchet_header in message %s, skipping", m["id"]) rh_parsed = {} try: x3dh_parsed = json.loads(x3dh_raw) if x3dh_raw else None except (json.JSONDecodeError, TypeError, UnicodeDecodeError): logger.warning("[FETCH] Corrupted x3dh_header in message %s, skipping", m["id"]) x3dh_parsed = None entry = { "message_id": m["id"], "sender_id": m.get("sender_id") or "", "ratchet_header": rh_parsed, "encrypted_content": encode_binary(m["encrypted_content"]) if m.get("encrypted_content") else "", "nonce": encode_binary(m["nonce"]) if m.get("nonce") else "", "created_at": m["created_at"].isoformat() if hasattr(m["created_at"], "isoformat") else str(m["created_at"]), "read_by": read_by, "delivered_to": delivery_status.get(m["id"], []), } if x3dh_parsed: entry["x3dh_header"] = x3dh_parsed # Sender chain metadata: check message-level first (backward compat), # then per-recipient ratchet_header (new metadata-private format). # Only extract from per-recipient header if message-level ratchet_header # is the group dummy (dh_pub all-zeros) — prevents DM header injection. if m.get("sender_chain_id"): entry["sender_chain_id"] = encode_binary(m["sender_chain_id"]) elif isinstance(rh_parsed, dict) and rh_parsed.get("chain_id"): # Verify this is a group message by checking the message-level header msg_rh_raw = m.get("ratchet_header") is_group = False if msg_rh_raw: try: msg_rh = json.loads(msg_rh_raw) if isinstance(msg_rh_raw, (bytes, str)) else msg_rh_raw is_group = isinstance(msg_rh, dict) and msg_rh.get("dh_pub") == "00" * 32 except (json.JSONDecodeError, TypeError, UnicodeDecodeError): pass if is_group: entry["sender_chain_id"] = rh_parsed["chain_id"] if m.get("sender_chain_n") is not None: entry["sender_chain_n"] = m["sender_chain_n"] elif isinstance(rh_parsed, dict) and rh_parsed.get("chain_n") is not None: # Same group-only guard if "sender_chain_id" in entry: entry["sender_chain_n"] = rh_parsed["chain_n"] if m.get("sender_device_id"): entry["sender_device_id"] = m["sender_device_id"] if m.get("deleted_at"): entry["deleted_at"] = m["deleted_at"].isoformat() if hasattr(m["deleted_at"], "isoformat") else str(m["deleted_at"]) # Pin metadata if m.get("pinned_at"): entry["pinned_at"] = m["pinned_at"].isoformat() if hasattr(m["pinned_at"], "isoformat") else str(m["pinned_at"]) entry["pinned_by"] = m.get("pinned_by") or "" # Reactions msg_reactions = reactions_map.get(m["id"]) if msg_reactions: entry["reactions"] = msg_reactions result.append(entry) total_count = await adb.count_messages(conv_id, session["user_id"]) logger.info("[FETCH] %s fetched %d/%d msgs from conv=%s (limit=%d, offset=%d%s)", _who(session), len(result), total_count, conv_id[:8], limit, offset, f", after={after_ts}" if after_ts else "") await send_resp(msg, writer, "get_messages", "ok", {"messages": result, "total_count": total_count}) async def handle_remove_member(msg: dict, session: dict, writer: ProtocolWriter): if await _is_rate_limited(f"remove_member|{session['user_id']}", 10): await send_resp(msg, writer, "remove_member", "error", {"message": "Too many requests. Try later."}) return conv_id = msg.get("conversation_id", "") user_id = msg.get("user_id", "") if not conv_id or not user_id: await send_resp(msg, writer, "remove_member", "error", {"message": "Missing conversation_id or user_id"}) return if not _valid_uuid(conv_id) or not _valid_uuid(user_id): await send_resp(msg, writer, "remove_member", "error", {"message": "Invalid conversation_id or user_id"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "remove_member", "error", {"message": "Not a member"}) return convs = await adb.list_user_conversations(session["user_id"]) conv_data = None for c in convs: if c["id"] == conv_id: conv_data = c break if not conv_data or conv_data.get("created_by") != session["user_id"]: await send_resp(msg, writer, "remove_member", "error", {"message": "Only the group creator can remove members"}) return if user_id == session["user_id"]: await send_resp(msg, writer, "remove_member", "error", {"message": "Cannot remove yourself"}) return # Get remaining members before removing (to notify them) members_before = await adb.get_conversation_members(conv_id) # M6: atomic removal — return value confirms row existed removed = await adb.remove_conversation_member_atomic(conv_id, user_id) if not removed: await send_resp(msg, writer, "remove_member", "error", {"message": "Member already removed"}) return logger.info("[MEMBER] %s removed user=%s from conv=%s", _who(session), user_id[:8], conv_id[:8]) await send_resp(msg, writer, "remove_member", "ok", {"message": "OK"}) # Notify removed member and remaining members notif_data = { "conversation_id": conv_id, "user_id": user_id, } member_ids = [m["id"] for m in members_before if m["id"] != session["user_id"]] await _notify_users(member_ids, "member_removed", notif_data) async def handle_leave_group(msg: dict, session: dict, writer: ProtocolWriter): """Leave a group conversation voluntarily.""" conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "leave_group", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "leave_group", "error", {"message": "Invalid conversation_id"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"}) return # Don't allow leaving DMs (2 members without a name) conv = await adb.get_conversation(conv_id) members = await adb.get_conversation_members(conv_id) if len(members) <= 2 and not (conv and conv.get("name")): await send_resp(msg, writer, "leave_group", "error", {"message": "Cannot leave a DM conversation"}) return # If creator is leaving, transfer to first remaining member if conv and conv.get("created_by") == session["user_id"]: remaining = [m for m in members if m["id"] != session["user_id"]] if remaining: await adb.update_conversation_creator(conv_id, remaining[0]["id"]) # M6: atomic removal await adb.remove_conversation_member_atomic(conv_id, session["user_id"]) logger.info("[LEAVE] %s left group conv=%s", _who(session), conv_id[:8]) await send_resp(msg, writer, "leave_group", "ok", {"message": "OK"}) # Notify remaining members notif_data = { "conversation_id": conv_id, "user_id": session["user_id"], } member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "member_removed", notif_data) async def handle_rename_conversation(msg: dict, session: dict, writer: ProtocolWriter): """Rename a group conversation (creator only).""" if await _is_rate_limited(f"rename_conv|{session['user_id']}", 5): await send_resp(msg, writer, "rename_conversation", "error", {"message": "Too many requests. Try later."}) return conv_id = msg.get("conversation_id", "") new_name = msg.get("name", "").strip() if not conv_id or not new_name: await send_resp(msg, writer, "rename_conversation", "error", {"message": "Missing conversation_id or name"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "rename_conversation", "error", {"message": "Invalid conversation_id"}) return if len(new_name) > 100: await send_resp(msg, writer, "rename_conversation", "error", {"message": "Name too long (max 100)"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "rename_conversation", "error", {"message": "Not a member"}) return conv = await adb.get_conversation(conv_id) if not conv or not conv.get("name"): await send_resp(msg, writer, "rename_conversation", "error", {"message": "Cannot rename a DM conversation"}) return if conv.get("created_by") != session["user_id"]: await send_resp(msg, writer, "rename_conversation", "error", {"message": "Only the group creator can rename"}) return await adb.update_conversation_name(conv_id, new_name) logger.info("[RENAME] %s renamed conv=%s", _who(session), conv_id[:8]) await send_resp(msg, writer, "rename_conversation", "ok", {"message": "OK"}) # Notify all members members = await adb.get_conversation_members(conv_id) member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "conversation_renamed", { "conversation_id": conv_id, "name": new_name, "renamed_by": session["user_id"], }) async def handle_delete_conversation(msg: dict, session: dict, writer: ProtocolWriter): """Delete a conversation for the current user. Removes user from members, deletes the conversation if no members remain.""" if await _is_rate_limited(f"delete_conv|{session['user_id']}", 5): await send_resp(msg, writer, "delete_conversation", "error", {"message": "Too many requests. Try later."}) return conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "delete_conversation", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "delete_conversation", "error", {"message": "Invalid conversation_id"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "delete_conversation", "error", {"message": "Not a member"}) return conv = await adb.get_conversation(conv_id) members = await adb.get_conversation_members(conv_id) is_group = len(members) > 2 or (conv and conv.get("name")) # Groups can only be deleted by the creator (admin) if is_group and (not conv or conv.get("created_by") != session["user_id"]): await send_resp(msg, writer, "delete_conversation", "error", {"message": "Only the group creator can delete this conversation"}) return if is_group: # Group: creator deletes for everyone — remove all members, clean up, delete for member in members: await adb.remove_conversation_member(conv_id, member["id"]) else: # DM: only remove self; other user keeps the conversation await adb.remove_conversation_member(conv_id, session["user_id"]) remaining_count = await adb.count_conversation_members(conv_id) if remaining_count == 0: # Clean up uploaded files from disk file_ids = await adb.get_conversation_file_ids(conv_id) for fid in file_ids: for ext in (".enc", ".tmp"): p = _safe_upload_path(fid, ext) if not p: continue _secure_delete(p) await adb.delete_conversation(conv_id) logger.info("[DELETE] %s deleted conv=%s", _who(session), conv_id[:8]) await send_resp(msg, writer, "delete_conversation", "ok", {"message": "OK"}) # Notify other members they were removed notif_data = { "conversation_id": conv_id, "user_id": session["user_id"], } member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "member_removed", notif_data) async def handle_mark_read(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") message_ids = msg.get("message_ids", []) if not conv_id or not message_ids: await send_resp(msg, writer, "mark_read", "error", {"message": "Missing conversation_id or message_ids"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid conversation_id"}) return if len(message_ids) > 500: await send_resp(msg, writer, "mark_read", "error", {"message": "Too many message_ids (max 500)"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "mark_read", "error", {"message": "Not a member"}) return # M1 fix: filter to only message_ids that belong to this conversation valid_ids = await adb.filter_message_ids_by_conversation(conv_id, message_ids) if not valid_ids: await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"}) return await adb.mark_messages_read(conv_id, session["user_id"], valid_ids) logger.info("[READ] %s marked %d msgs read in conv=%s", _who(session), len(valid_ids), conv_id[:8]) await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"}) members = await adb.get_conversation_members(conv_id) notif_data = { "conversation_id": conv_id, "user_id": session["user_id"], "message_ids": valid_ids, } member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "messages_read", notif_data) async def handle_mark_conversation_read(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") if not conv_id: await send_resp(msg, writer, "mark_conversation_read", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "mark_conversation_read", "error", {"message": "Invalid conversation_id"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "mark_conversation_read", "error", {"message": "Not a member"}) return count = await adb.mark_conversation_read(conv_id, session["user_id"]) logger.info("[READ] %s marked conv=%s all-read (%d msgs)", _who(session), conv_id[:8], count) await send_resp(msg, writer, "mark_conversation_read", "ok", {"marked_count": count}) if count > 0: members = await adb.get_conversation_members(conv_id) member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "messages_read", { "conversation_id": conv_id, "user_id": session["user_id"], "message_ids": [], }) async def handle_confirm_delivery(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") message_ids = msg.get("message_ids", []) if not conv_id or not message_ids: await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Missing conversation_id or message_ids"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Invalid conversation_id"}) return if len(message_ids) > 500: await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Too many message_ids (max 500)"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Not a member"}) return # M1 fix: filter to only message_ids that belong to this conversation valid_ids = await adb.filter_message_ids_by_conversation(conv_id, message_ids) if not valid_ids: await send_resp(msg, writer, "confirm_delivery", "ok", {"message": "OK"}) return await adb.mark_messages_delivered(conv_id, session["user_id"], valid_ids) logger.info("[DELIVERY] %s confirmed %d msgs delivered in conv=%s", _who(session), len(valid_ids), conv_id[:8]) await send_resp(msg, writer, "confirm_delivery", "ok", {"message": "OK"}) # Notify senders — batch lookup sender_id per message, push to each sender sender_msgs: dict[str, list[str]] = {} for mid in valid_ids: sid = await adb.get_message_sender(mid) if sid and sid != session["user_id"]: sender_msgs.setdefault(sid, []).append(mid) for sender_id, mids in sender_msgs.items(): await _notify_users([sender_id], "message_delivered", { "conversation_id": conv_id, "user_id": session["user_id"], "message_ids": mids, }) async def handle_delete_message(msg: dict, session: dict, writer: ProtocolWriter): if await _is_rate_limited(f"delete_msg|{session['user_id']}", 20): await send_resp(msg, writer, "delete_message", "error", {"message": "Too many requests. Try later."}) return message_id = msg.get("message_id", "") if not message_id: await send_resp(msg, writer, "delete_message", "error", {"message": "Missing message_id"}) return if not _valid_uuid(message_id): await send_resp(msg, writer, "delete_message", "error", {"message": "Invalid message_id"}) return conv_id = await adb.get_message_conversation(message_id) if not conv_id: await send_resp(msg, writer, "delete_message", "error", {"message": "Message not found"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "delete_message", "error", {"message": "Not a member"}) return result = await adb.soft_delete_message(message_id, session["user_id"]) if result is None: await send_resp(msg, writer, "delete_message", "error", {"message": "Cannot delete this message"}) return image_file_id = result.get("image_file_id") if image_file_id: image_path = _safe_upload_path(image_file_id, ".enc") if image_path: _secure_delete(image_path) await adb.delete_image_upload(image_file_id) logger.info("[MSG] %s deleted message=%s", _who(session), message_id[:8]) await send_resp(msg, writer, "delete_message", "ok", {"message_id": message_id}) members = await adb.get_conversation_members(conv_id) notif_data = {"message_id": message_id, "conversation_id": conv_id} member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, "message_deleted", notif_data) async def handle_react_message(msg: dict, session: dict, writer: ProtocolWriter): if await _is_rate_limited(f"react|{session['user_id']}", 20): await send_resp(msg, writer, "react_message", "error", {"message": "Too many requests. Try later."}) return message_id = msg.get("message_id", "") reaction = msg.get("reaction", "") action = msg.get("action", "add") # "add" or "remove" if not message_id or not reaction: await send_resp(msg, writer, "react_message", "error", {"message": "Missing fields"}) return if not _valid_uuid(message_id): await send_resp(msg, writer, "react_message", "error", {"message": "Invalid message_id"}) return if reaction not in db.ALLOWED_REACTIONS: await send_resp(msg, writer, "react_message", "error", {"message": "Invalid reaction"}) return if action not in ("add", "remove"): await send_resp(msg, writer, "react_message", "error", {"message": "Invalid action"}) return conv_id = await adb.get_message_conversation(message_id) if not conv_id: await send_resp(msg, writer, "react_message", "error", {"message": "Message not found"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "react_message", "error", {"message": "Not a member"}) return old_reaction = None if action == "add": changed, old_reaction = await adb.add_reaction(message_id, session["user_id"], reaction) if not changed: await send_resp(msg, writer, "react_message", "ok", {"message_id": message_id}) return else: await adb.remove_reaction(message_id, session["user_id"]) logger.info("[MSG] %s %s reaction '%s' on message=%s", _who(session), action, reaction, message_id[:8]) resp_data = {"message_id": message_id} if old_reaction: resp_data["old_reaction"] = old_reaction await send_resp(msg, writer, "react_message", "ok", resp_data) members = await adb.get_conversation_members(conv_id) member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] # If replacing an old reaction, notify removal first if old_reaction: remove_data = { "message_id": message_id, "conversation_id": conv_id, "user_id": session["user_id"], "username": session.get("username", ""), "reaction": old_reaction, "action": "remove", } await _notify_users(member_ids, "message_reacted", remove_data) notif_data = { "message_id": message_id, "conversation_id": conv_id, "user_id": session["user_id"], "username": session.get("username", ""), "reaction": reaction, "action": action, } await _notify_users(member_ids, "message_reacted", notif_data) async def handle_pin_message(msg: dict, session: dict, writer: ProtocolWriter): message_id = msg.get("message_id", "") action = msg.get("action", "pin") # "pin" or "unpin" conversation_id = msg.get("conversation_id", "") if not message_id or not conversation_id: await send_resp(msg, writer, "pin_message", "error", {"message": "Missing fields"}) return if not _valid_uuid(message_id) or not _valid_uuid(conversation_id): await send_resp(msg, writer, "pin_message", "error", {"message": "Invalid ID"}) return if action not in ("pin", "unpin"): await send_resp(msg, writer, "pin_message", "error", {"message": "Invalid action"}) return if not await adb.is_conversation_member(conversation_id, session["user_id"]): await send_resp(msg, writer, "pin_message", "error", {"message": "Not a member"}) return if action == "pin": ok = await adb.pin_message(message_id, session["user_id"], conversation_id) else: ok = await adb.unpin_message(message_id, conversation_id) if not ok: await send_resp(msg, writer, "pin_message", "error", {"message": "Already pinned" if action == "pin" else "Not pinned"}) return logger.info("[MSG] %s %s message=%s in conv=%s", _who(session), action, message_id[:8], conversation_id[:8]) await send_resp(msg, writer, "pin_message", "ok", {"message_id": message_id, "action": action}) members = await adb.get_conversation_members(conversation_id) notif_type = "message_pinned" if action == "pin" else "message_unpinned" notif_data = { "message_id": message_id, "conversation_id": conversation_id, "user_id": session["user_id"], "username": session.get("username", ""), } member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] await _notify_users(member_ids, notif_type, notif_data) async def handle_get_pinned_messages(msg: dict, session: dict, writer: ProtocolWriter): conversation_id = msg.get("conversation_id", "") if not conversation_id: await send_resp(msg, writer, "get_pinned_messages", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conversation_id): await send_resp(msg, writer, "get_pinned_messages", "error", {"message": "Invalid conversation_id"}) return if not await adb.is_conversation_member(conversation_id, session["user_id"]): await send_resp(msg, writer, "get_pinned_messages", "error", {"message": "Not a member"}) return pinned = await adb.get_pinned_messages(conversation_id, session["user_id"]) await send_resp(msg, writer, "get_pinned_messages", "ok", {"messages": pinned}) async def handle_upload_image_start(msg: dict, session: dict, writer: ProtocolWriter): conv_id = msg.get("conversation_id", "") file_size = msg.get("file_size", 0) file_id = msg.get("file_id", "") file_type = msg.get("file_type", "image") # "image" or "file" if not conv_id or not file_id: await send_resp(msg, writer, "upload_image_start", "error", {"message": "Missing fields"}) return if not _valid_uuid(file_id): await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"}) return # M5: rate limit + caps on in-flight uploads addr = _get_peer_addr(writer) if await _is_rate_limited(f"upload_start|{session['user_id']}", 10): await send_resp(msg, writer, "upload_image_start", "error", {"message": "Too many uploads. Try later."}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "upload_image_start", "error", {"message": "Not a member"}) return max_bytes = MAX_FILE_BYTES if file_type == "file" else MAX_IMAGE_BYTES if max_bytes > 0 and file_size > max_bytes: await send_resp(msg, writer, "upload_image_start", "error", {"message": f"File too large (max {max_bytes} bytes)"}) return UPLOAD_DIR.mkdir(parents=True, exist_ok=True) temp_path = _safe_upload_path(file_id, ".tmp") if not temp_path: await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"}) return # M5: atomic cap check + insert under single lock acquisition cap_error = "" async with _uploads_lock: total = len(pending_uploads) user_count = sum(1 for u in pending_uploads.values() if u.get("uploader_id") == session["user_id"]) if total >= MAX_UPLOADS_GLOBAL: cap_error = "Server upload limit reached. Try later." elif user_count >= MAX_UPLOADS_PER_USER: cap_error = "Too many active uploads. Finish or cancel existing ones." else: temp_path.write_bytes(b"") pending_uploads[file_id] = { "temp_path": str(temp_path), "received_bytes": 0, "file_size": file_size, "max_bytes": max_bytes, "conv_id": conv_id, "uploader_id": session["user_id"], } if cap_error: await send_resp(msg, writer, "upload_image_start", "error", {"message": cap_error}) return try: await adb.create_image_upload(file_id, conv_id, session["user_id"], file_size) except Exception: # Rollback: remove from pending_uploads + delete temp file async with _uploads_lock: pending_uploads.pop(file_id, None) _secure_delete(temp_path) logger.exception("[UPLOAD] DB create failed for file=%s", file_id[:8]) await send_resp(msg, writer, "upload_image_start", "error", {"message": "Upload failed"}) return logger.info("[UPLOAD] %s started upload file=%s (%s, %d bytes)", _who(session), file_id[:8], file_type, file_size) await send_resp(msg, writer, "upload_image_start", "ok", {"file_id": file_id}) async def handle_upload_image_chunk(msg: dict, session: dict, writer: ProtocolWriter): file_id = msg.get("file_id", "") chunk_data = msg.get("data", "") if not file_id or not chunk_data: await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Missing fields"}) return async with _uploads_lock: upload = pending_uploads.get(file_id) if not upload or upload["uploader_id"] != session["user_id"]: upload = None else: temp_path_str = upload["temp_path"] upload_max = upload.get("max_bytes", 0) if not upload: await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"}) return raw = decode_binary(chunk_data) temp_path = Path(temp_path_str) await asyncio.to_thread(_append_file, temp_path, raw) over_limit = False async with _uploads_lock: upload = pending_uploads.get(file_id) if upload: upload["received_bytes"] += len(raw) if upload_max > 0 and upload["received_bytes"] > upload_max: pending_uploads.pop(file_id, None) over_limit = True received = upload["received_bytes"] if over_limit: _secure_delete(temp_path) await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Upload exceeds size limit"}) return await send_resp(msg, writer, "upload_image_chunk", "ok", {"received": received}) async def handle_upload_image_end(msg: dict, session: dict, writer: ProtocolWriter): file_id = msg.get("file_id", "") if not file_id: await send_resp(msg, writer, "upload_image_end", "error", {"message": "Missing file_id"}) return async with _uploads_lock: upload = pending_uploads.pop(file_id, None) if not upload or upload["uploader_id"] != session["user_id"]: await send_resp(msg, writer, "upload_image_end", "error", {"message": "No active upload"}) return temp_path = Path(upload["temp_path"]) if upload["received_bytes"] != upload["file_size"]: _secure_delete(temp_path) await send_resp(msg, writer, "upload_image_end", "error", {"message": f"Incomplete upload: received {upload['received_bytes']} of {upload['file_size']} bytes"}) return final_path = _safe_upload_path(file_id, ".enc") if not final_path: _secure_delete(temp_path) await send_resp(msg, writer, "upload_image_end", "error", {"message": "Invalid file_id"}) return def _move_file(): try: temp_path.rename(final_path) except Exception: import shutil shutil.move(str(temp_path), str(final_path)) await asyncio.to_thread(_move_file) await adb.complete_image_upload(file_id) logger.info("[UPLOAD] %s completed upload file=%s (%d bytes)", _who(session), file_id[:8], upload["received_bytes"]) await send_resp(msg, writer, "upload_image_end", "ok", {"file_id": file_id}) async def handle_download_image(msg: dict, session: dict, writer: ProtocolWriter): file_id = msg.get("file_id", "") offset = msg.get("offset", 0) if not file_id: await send_resp(msg, writer, "download_image", "error", {"message": "Missing file_id"}) return if not _valid_uuid(file_id): await send_resp(msg, writer, "download_image", "error", {"message": "Invalid file_id"}) return upload = await adb.get_image_upload(file_id) if not upload or not upload["completed"]: await send_resp(msg, writer, "download_image", "error", {"message": "File not found"}) return if not await adb.is_conversation_member(upload["conversation_id"], session["user_id"]): await send_resp(msg, writer, "download_image", "error", {"message": "Not a member"}) return file_path = _safe_upload_path(file_id, ".enc") if not file_path or not file_path.exists(): await send_resp(msg, writer, "download_image", "error", {"message": "File not found"}) return file_size = file_path.stat().st_size chunk = await asyncio.to_thread(_read_file_chunk, file_path, offset, IMAGE_CHUNK_SIZE) done = (offset + len(chunk)) >= file_size if offset == 0: logger.info("[DOWNLOAD] %s downloading file=%s (%d bytes)", _who(session), file_id[:8], file_size) await send_resp(msg, writer, "download_image", "ok", { "file_id": file_id, "data": encode_binary(chunk), "offset": offset, "done": done, "total_size": file_size, }) MAX_AVATAR_BYTES = 2 * 1024 * 1024 # 2 MB async def handle_get_profile(msg: dict, session: dict, writer: ProtocolWriter): """Get user profile (respects visibility for other users).""" target_user_id = msg.get("user_id", "").strip() if not target_user_id: target_user_id = session["user_id"] elif not _valid_uuid(target_user_id): await send_resp(msg, writer, "get_profile", "error", {"message": "Invalid user_id"}) return profile = await adb.get_user_profile(target_user_id, viewer_id=session["user_id"]) if not profile: await send_resp(msg, writer, "get_profile", "error", {"message": "User not found"}) return # Serialize datetime fields for key in ("created_at", "updated_at"): if profile.get(key) and hasattr(profile[key], "isoformat"): profile[key] = profile[key].isoformat() await send_resp(msg, writer, "get_profile", "ok", profile) async def handle_update_profile(msg: dict, session: dict, writer: ProtocolWriter): """Update own profile fields.""" fields = {} for key in ("phone", "phone_visible", "email_visible", "location", "location_visible"): if key in msg: fields[key] = msg[key] if not fields: await send_resp(msg, writer, "update_profile", "error", {"message": "No fields to update"}) return await adb.update_user_profile(session["user_id"], **fields) await send_resp(msg, writer, "update_profile", "ok", {"message": "OK"}) async def handle_update_avatar(msg: dict, session: dict, writer: ProtocolWriter): """Upload avatar (base64 in single message, max 2MB).""" if await _is_rate_limited(f"update_avatar|{session['user_id']}", 5): await send_resp(msg, writer, "update_avatar", "error", {"message": "Too many requests. Try later."}) return avatar_b64 = msg.get("data", "") if not avatar_b64: await send_resp(msg, writer, "update_avatar", "error", {"message": "Missing data"}) return avatar_data = decode_binary(avatar_b64) if len(avatar_data) > MAX_AVATAR_BYTES: await send_resp(msg, writer, "update_avatar", "error", {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"}) return # Detect format from magic bytes ext = "jpg" if avatar_data[:8] == b'\x89PNG\r\n\x1a\n': ext = "png" avatar_dir = UPLOAD_DIR / "avatars" avatar_dir.mkdir(parents=True, exist_ok=True) os.chmod(avatar_dir, 0o700) filename = f"{session['user_id']}.{ext}" avatar_path = _safe_avatar_path(filename) if not avatar_path: await send_resp(msg, writer, "update_avatar", "error", {"message": "Invalid path"}) return await asyncio.to_thread(avatar_path.write_bytes, avatar_data) await adb.update_user_profile(session["user_id"], avatar_file=filename) logger.info("[AVATAR] %s updated their avatar", _who(session)) await send_resp(msg, writer, "update_avatar", "ok", {"avatar_file": filename}) async def handle_get_avatar(msg: dict, session: dict, writer: ProtocolWriter): """Download avatar for a user.""" target_user_id = msg.get("user_id", "").strip() if not target_user_id: await send_resp(msg, writer, "get_avatar", "error", {"message": "Missing user_id"}) return if not _valid_uuid(target_user_id): await send_resp(msg, writer, "get_avatar", "error", {"message": "Invalid user_id"}) return profile = await adb.get_user_profile(target_user_id) if not profile or not profile.get("avatar_file"): await send_resp(msg, writer, "get_avatar", "error", {"message": "No avatar"}) return avatar_path = _safe_avatar_path(profile["avatar_file"]) if not avatar_path or not avatar_path.exists(): await send_resp(msg, writer, "get_avatar", "error", {"message": "Avatar file not found"}) return avatar_data = await asyncio.to_thread(avatar_path.read_bytes) await send_resp(msg, writer, "get_avatar", "ok", { "user_id": target_user_id, "data": encode_binary(avatar_data), "filename": profile["avatar_file"], }) async def handle_update_group_avatar(msg: dict, session: dict, writer: ProtocolWriter): """Upload avatar for a group conversation (base64, max 2MB). Only members can set it.""" if await _is_rate_limited(f"update_avatar|{session['user_id']}", 5): await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Too many requests. Try later."}) return conv_id = msg.get("conversation_id", "").strip() avatar_b64 = msg.get("data", "") if not conv_id or not avatar_b64: await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Missing fields"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid conversation_id"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Not a member"}) return avatar_data = decode_binary(avatar_b64) if len(avatar_data) > MAX_AVATAR_BYTES: await send_resp(msg, writer, "update_group_avatar", "error", {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"}) return ext = "jpg" if avatar_data[:8] == b'\x89PNG\r\n\x1a\n': ext = "png" avatar_dir = UPLOAD_DIR / "avatars" avatar_dir.mkdir(parents=True, exist_ok=True) os.chmod(avatar_dir, 0o700) filename = f"group_{conv_id}.{ext}" avatar_path = _safe_avatar_path(filename) if not avatar_path: await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid path"}) return await asyncio.to_thread(avatar_path.write_bytes, avatar_data) await adb.update_conversation_avatar(conv_id, filename) logger.info("[AVATAR] %s updated group avatar for conv=%s", _who(session), conv_id[:8]) await send_resp(msg, writer, "update_group_avatar", "ok", {"avatar_file": filename}) async def handle_get_group_avatar(msg: dict, session: dict, writer: ProtocolWriter): """Download avatar for a group conversation.""" conv_id = msg.get("conversation_id", "").strip() if not conv_id: await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Missing conversation_id"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Invalid conversation_id"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Not a member"}) return conv = await adb.get_conversation(conv_id) if not conv or not conv.get("avatar_file"): await send_resp(msg, writer, "get_group_avatar", "error", {"message": "No avatar"}) return avatar_path = _safe_avatar_path(conv["avatar_file"]) if not avatar_path or not avatar_path.exists(): await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Avatar file not found"}) return avatar_data = await asyncio.to_thread(avatar_path.read_bytes) await send_resp(msg, writer, "get_group_avatar", "ok", { "conversation_id": conv_id, "data": encode_binary(avatar_data), "filename": conv["avatar_file"], }) async def handle_list_devices(msg: dict, session: dict, writer: ProtocolWriter): """List all devices for the current user.""" devices = await adb.get_user_devices(session["user_id"]) result = [] for d in devices: entry = { "device_id": d["id"], "device_name": d.get("device_name"), "created_at": d["created_at"].isoformat() if hasattr(d["created_at"], "isoformat") else str(d["created_at"]), "last_seen_at": d["last_seen_at"].isoformat() if d.get("last_seen_at") and hasattr(d["last_seen_at"], "isoformat") else (str(d["last_seen_at"]) if d.get("last_seen_at") else None), "is_current": d["id"] == session.get("device_id"), } result.append(entry) await send_resp(msg, writer, "list_devices", "ok", {"devices": result}) async def handle_remove_device(msg: dict, session: dict, writer: ProtocolWriter): """Remove a device (cannot remove current device).""" device_id = msg.get("device_id", "").strip() if not device_id: await send_resp(msg, writer, "remove_device", "error", {"message": "Missing device_id"}) return if not _valid_uuid(device_id): await send_resp(msg, writer, "remove_device", "error", {"message": "Invalid device_id"}) return if device_id == session.get("device_id"): await send_resp(msg, writer, "remove_device", "error", {"message": "Cannot remove current device"}) return dev = await adb.get_device(device_id) if not dev or dev["user_id"] != session["user_id"]: await send_resp(msg, writer, "remove_device", "error", {"message": "Device not found"}) return await adb.delete_device(device_id) logger.info("[DEVICE] %s removed device=%s", _who(session), device_id[:8]) await send_resp(msg, writer, "remove_device", "ok", {"message": "OK"}) async def handle_session_reset(msg: dict, session: dict, writer: ProtocolWriter): """Notify peer to reset a corrupted Double Ratchet session.""" peer_user_id = msg.get("peer_user_id", "").strip() peer_device_id = msg.get("peer_device_id", "").strip() or None if not peer_user_id or not _valid_uuid(peer_user_id): await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_user_id"}) return if peer_device_id and not _valid_uuid(peer_device_id): await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_device_id"}) return # H3 fix: rate limit (5/min per user, keyed by user_id only — IP-independent) if await _is_rate_limited(f"session_reset|{session['user_id']}", 5): await send_resp(msg, writer, "session_reset", "error", {"message": "Rate limit exceeded"}) return # H3 fix: verify users share at least one conversation if not await adb.shares_conversation(session["user_id"], peer_user_id): await send_resp(msg, writer, "session_reset", "error", {"message": "No shared conversation"}) return # Push notification to peer (target specific device if specified) notif_data = { "from_user_id": session["user_id"], "from_device_id": session.get("device_id"), } if peer_device_id: # Send only to the specific device targets = [] async with _clients_lock: for w in connected_clients.get(peer_user_id, []): if writer_device_map.get(id(w)) == peer_device_id: targets.append(w) for w in targets: try: await w.send_response("session_reset", "ok", notif_data) except Exception: pass else: await _notify_users([peer_user_id], "session_reset", notif_data) logger.info("[SESSION] %s reset session with peer=%s", _who(session), peer_user_id[:8]) await send_resp(msg, writer, "session_reset", "ok", {}) async def handle_get_deleted_since(msg: dict, session: dict, writer: ProtocolWriter): """Return message IDs deleted since a given timestamp.""" conv_id = msg.get("conversation_id", "") since_ts = msg.get("since_ts", "") if not conv_id or not since_ts: await send_resp(msg, writer, "get_deleted_since", "error", {"message": "Missing parameters"}) return if not _valid_uuid(conv_id): await send_resp(msg, writer, "get_deleted_since", "error", {"message": "Invalid conversation_id"}) return if not await adb.is_conversation_member(conv_id, session["user_id"]): await send_resp(msg, writer, "get_deleted_since", "error", {"message": "Not a member"}) return deleted_ids = await adb.get_deleted_messages_since(conv_id, session["user_id"], since_ts) await send_resp(msg, writer, "get_deleted_since", "ok", {"deleted_ids": deleted_ids}) async def handle_reencrypt_messages(msg: dict, session: dict, writer: ProtocolWriter): """Re-encrypt message history with self-encryption key (for device pairing).""" if await _is_rate_limited(f"reencrypt|{session['user_id']}", 10): await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "Too many requests. Try later."}) return updates_raw = msg.get("updates", []) if not updates_raw: await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No updates"}) return if len(updates_raw) > 500: await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "Too many updates (max 500 per request)"}) return updates = [] for u in updates_raw: mid = u.get("message_id", "") ct_b64 = u.get("encrypted_content", "") nonce_b64 = u.get("nonce", "") if not mid or not ct_b64 or not nonce_b64: continue updates.append({ "message_id": mid, "encrypted_content": decode_binary(ct_b64), "nonce": decode_binary(nonce_b64), }) if not updates: await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No valid updates"}) return await adb.batch_reencrypt_messages(session["user_id"], updates) logger.info("[REENCRYPT] %s re-encrypted %d messages", _who(session), len(updates)) await send_resp(msg, writer, "reencrypt_messages", "ok", {"count": len(updates)}) async def _cleanup_uploads(): stale = await adb.get_stale_uploads(UPLOAD_STALE_SECONDS) for s in stale: fid = s["file_id"] for ext in (".tmp", ".enc"): p = _safe_upload_path(fid, ext) if not p: continue _secure_delete(p) await adb.delete_image_upload(fid) async with _uploads_lock: pending_uploads.pop(fid, None) if stale: logger.info("Cleaned up %d stale uploads.", len(stale)) async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): global current_connections addr = _get_peer_addr(ProtocolWriter(writer)) async with _conn_lock: current_connections += 1 connection_counts[addr] = connection_counts.get(addr, 0) + 1 over_limit = (current_connections > MAX_CONNECTIONS_GLOBAL or connection_counts[addr] > MAX_CONNECTIONS_PER_IP) if over_limit: try: writer.close() except Exception: pass async with _conn_lock: current_connections = max(0, current_connections - 1) connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1) return logger.info("[CONN] Client connected from %s", addr) proto_reader = ProtocolReader(reader) proto_writer = ProtocolWriter(writer) session = None state = {"_req_times": []} try: while True: try: msg = await proto_reader.read_message() except ValueError as e: try: await proto_writer.send_response("protocol_error", "error", {"message": str(e)}) except Exception: pass break if msg is None: break msg_type = msg.get("type", "") now = asyncio.get_event_loop().time() times = [t for t in state["_req_times"] if now - t <= CONNECTION_RL_WINDOW] if len(times) >= CONNECTION_RL_MAX: await send_resp(msg, proto_writer, msg_type, "error", {"message": "Too many requests. Slow down."}) state["_req_times"] = times continue times.append(now) state["_req_times"] = times try: if msg_type == "register": await handle_register_start(msg, proto_writer) elif msg_type == "register_confirm": await handle_register_confirm(msg, proto_writer) elif msg_type == "login_start": await handle_login_start(msg, proto_writer, state) elif msg_type == "login_finish": result = await handle_login_finish(msg, proto_writer, state) if result: session = result elif msg_type == "pairing_start": await handle_pairing_start(msg, proto_writer) elif msg_type == "pairing_poll": await handle_pairing_poll(msg, proto_writer) elif session is None: await send_resp(msg, proto_writer, msg_type, "error", {"message": "Not logged in"}) elif msg_type == "get_user_info": await handle_get_user_info(msg, session, proto_writer) elif msg_type == "upload_prekeys": await handle_upload_prekeys(msg, session, proto_writer) elif msg_type == "get_key_bundle": await handle_get_key_bundle(msg, session, proto_writer) elif msg_type == "get_prekey_count": await handle_get_prekey_count(msg, session, proto_writer) elif msg_type == "ensure_prekeys": await handle_ensure_prekeys(msg, session, proto_writer) elif msg_type == "create_conversation": await handle_create_conversation(msg, session, proto_writer) elif msg_type == "find_conversation": await handle_find_conversation(msg, session, proto_writer) elif msg_type == "add_member": await handle_add_member(msg, session, proto_writer) elif msg_type == "accept_invitation": await handle_accept_invitation(msg, session, proto_writer) elif msg_type == "decline_invitation": await handle_decline_invitation(msg, session, proto_writer) elif msg_type == "list_invitations": await handle_list_invitations(msg, session, proto_writer) elif msg_type == "list_conversations": await handle_list_conversations(msg, session, proto_writer) elif msg_type == "send_message": await handle_send_message(msg, session, proto_writer) elif msg_type == "get_messages": await handle_get_messages(msg, session, proto_writer) elif msg_type == "rotate_keys": await handle_rotate_keys(msg, session, proto_writer) elif msg_type == "change_username": await handle_change_username(msg, session, proto_writer) elif msg_type == "remove_member": await handle_remove_member(msg, session, proto_writer) elif msg_type == "leave_group": await handle_leave_group(msg, session, proto_writer) elif msg_type == "rename_conversation": await handle_rename_conversation(msg, session, proto_writer) elif msg_type == "delete_conversation": await handle_delete_conversation(msg, session, proto_writer) elif msg_type == "mark_read": await handle_mark_read(msg, session, proto_writer) elif msg_type == "mark_conversation_read": await handle_mark_conversation_read(msg, session, proto_writer) elif msg_type == "confirm_delivery": await handle_confirm_delivery(msg, session, proto_writer) elif msg_type == "pairing_claim": await handle_pairing_claim(msg, session, proto_writer) elif msg_type == "pairing_send": await handle_pairing_send(msg, session, proto_writer) elif msg_type == "delete_message": await handle_delete_message(msg, session, proto_writer) elif msg_type == "upload_image_start": await handle_upload_image_start(msg, session, proto_writer) elif msg_type == "upload_image_chunk": await handle_upload_image_chunk(msg, session, proto_writer) elif msg_type == "upload_image_end": await handle_upload_image_end(msg, session, proto_writer) elif msg_type == "download_image": await handle_download_image(msg, session, proto_writer) elif msg_type == "get_profile": await handle_get_profile(msg, session, proto_writer) elif msg_type == "update_profile": await handle_update_profile(msg, session, proto_writer) elif msg_type == "update_avatar": await handle_update_avatar(msg, session, proto_writer) elif msg_type == "get_avatar": await handle_get_avatar(msg, session, proto_writer) elif msg_type == "update_group_avatar": await handle_update_group_avatar(msg, session, proto_writer) elif msg_type == "get_group_avatar": await handle_get_group_avatar(msg, session, proto_writer) elif msg_type == "get_deleted_since": await handle_get_deleted_since(msg, session, proto_writer) elif msg_type == "reencrypt_messages": await handle_reencrypt_messages(msg, session, proto_writer) elif msg_type == "list_devices": await handle_list_devices(msg, session, proto_writer) elif msg_type == "remove_device": await handle_remove_device(msg, session, proto_writer) elif msg_type == "session_reset": await handle_session_reset(msg, session, proto_writer) elif msg_type == "react_message": await handle_react_message(msg, session, proto_writer) elif msg_type == "pin_message": await handle_pin_message(msg, session, proto_writer) elif msg_type == "get_pinned_messages": await handle_get_pinned_messages(msg, session, proto_writer) else: await send_resp(msg, proto_writer, msg_type, "error", {"message": "Unknown type"}) except Exception as e: logger.warning("[ERROR] %s handler '%s' failed: %s", _who(session), msg_type, e, exc_info=True) try: await send_resp(msg, proto_writer, msg_type, "error", {"message": "Internal server error"}) except Exception: break # Can't send response — connection is dead except Exception as e: logger.warning("Client connection error: %s", e) finally: async with _conn_lock: current_connections = max(0, current_connections - 1) connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1) offline_targets = [] if session: uid = session["user_id"] contacts = await adb.get_user_contacts(uid) async with _clients_lock: writer_device_map.pop(id(proto_writer), None) if uid in connected_clients: remaining = [w for w in connected_clients[uid] if w is not proto_writer] if remaining: connected_clients[uid] = remaining else: del connected_clients[uid] # User fully offline — snapshot targets under lock for contact_id in contacts: for cw in connected_clients.get(contact_id, []): offline_targets.append(cw) # Send offline notifications outside lock for cw in offline_targets: try: await cw.send_response("user_offline", "ok", {"user_id": uid}) except Exception: pass writer.close() logger.info("[CONN] %s disconnected", _who(session) if session else addr) async def main(): setup_logging() host = os.getenv("SERVER_HOST", "127.0.0.1") port = int(os.getenv("SERVER_PORT", "9999")) tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes") tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes") tls_autogen = os.getenv("TLS_AUTOGEN", "false").lower() in ("1", "true", "yes") is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") ssl_context = None if tls_required and not tls_enabled: raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.") if tls_enabled: cert_file = os.getenv("TLS_CERT_FILE", "").strip() key_file = os.getenv("TLS_KEY_FILE", "").strip() if not cert_file or not key_file: if tls_autogen: if not is_dev: raise RuntimeError("TLS_AUTOGEN is only allowed when ENVIRONMENT=dev") cert_dir = Path(__file__).resolve().parent / "certs" cert_dir.mkdir(parents=True, exist_ok=True) cert_file = str(cert_dir / "server.crt") key_file = str(cert_dir / "server.key") if not (os.path.exists(cert_file) and os.path.exists(key_file)): try: subprocess.run( [ "openssl", "req", "-x509", "-newkey", "rsa:4096", "-keyout", key_file, "-out", cert_file, "-days", "365", "-nodes", "-subj", "/CN=localhost", ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) os.chmod(key_file, 0o600) except FileNotFoundError: raise RuntimeError("OpenSSL not found.") except subprocess.CalledProcessError: raise RuntimeError("Failed to auto-generate TLS cert.") logger.warning("Using auto-generated self-signed certificate — not for production use.") else: raise RuntimeError("TLS is enabled but TLS_CERT_FILE or TLS_KEY_FILE is missing.") ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file) else: logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.") UPLOAD_DIR.mkdir(parents=True, exist_ok=True) # Thread pool for asyncio.to_thread() — DB calls + file I/O pool_workers = int(os.getenv("THREAD_POOL_SIZE", "40")) asyncio.get_event_loop().set_default_executor(ThreadPoolExecutor(max_workers=pool_workers)) logger.info("Thread pool executor: %d workers", pool_workers) # Load phantom user IDs from DB into in-memory cache phantom_user_ids.update(await adb.get_all_phantom_user_ids()) if phantom_user_ids: logger.info("Loaded %d phantom user IDs.", len(phantom_user_ids)) server = await asyncio.start_server( handle_client, host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context, ) logger.info("Encrypted chat server v%s listening on %s:%s", VERSION, host, port) async def _cleanup_rate_limits(): async with _conn_lock: now = asyncio.get_event_loop().time() window_start = now - RATE_LIMIT_WINDOW stale_keys = [k for k, times in rate_limits.items() if not any(t >= window_start for t in times)] for k in stale_keys: del rate_limits[k] stale_conns = [k for k, v in connection_counts.items() if v <= 0] for k in stale_conns: del connection_counts[k] _cleanup_cycle = 0 async def _periodic_cleanup(): nonlocal _cleanup_cycle while True: await asyncio.sleep(120) _cleanup_cycle += 1 try: await _cleanup_uploads() except Exception as e: logger.warning("Upload cleanup error: %s", e) try: await _cleanup_rate_limits() except Exception as e: logger.warning("Rate limit cleanup error: %s", e) try: await _cleanup_registrations() except Exception as e: logger.warning("Registration cleanup error: %s", e) # L8: clean up stale phantom users (>30 days, no real conversations) try: deleted = await adb.cleanup_stale_phantoms(30) if deleted: async with _clients_lock: phantom_user_ids.clear() phantom_user_ids.update(await adb.get_all_phantom_user_ids()) logger.info("Cleaned up %d stale phantom users.", deleted) except Exception as e: logger.warning("Phantom cleanup error: %s", e) # Metadata retention: purge old reads and reactions (every 30 cycles = ~1 hour) if _cleanup_cycle % 30 == 0: try: reads_del = await adb.cleanup_old_reads(METADATA_RETENTION_DAYS) reactions_del = await adb.cleanup_old_reactions(METADATA_RETENTION_DAYS) if reads_del or reactions_del: logger.info("Metadata cleanup: %d reads, %d reactions purged", reads_del, reactions_del) except Exception as e: logger.warning("Metadata cleanup error: %s", e) asyncio.create_task(_periodic_cleanup()) loop = asyncio.get_running_loop() stop = loop.create_future() def signal_handler(): if not stop.done(): stop.set_result(None) for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, signal_handler) async with server: await stop logger.info("Shutting down — closing %d client connections...", sum(len(ws) for ws in connected_clients.values())) # Stop accepting new connections server.close() # Force-close all connected client writers async with _clients_lock: all_writers = [w for writers in connected_clients.values() for w in writers] connected_clients.clear() writer_device_map.clear() for w in all_writers: try: w.close() except Exception: pass # Give handle_client loops a moment to notice closed connections await asyncio.sleep(0.1) # Cancel any remaining handle_client tasks that are still blocked tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] for t in tasks: t.cancel() await asyncio.gather(*tasks, return_exceptions=True) logger.info("Server shut down.") if __name__ == "__main__": asyncio.run(main())