Files
Kecalek_python/server.py
2026-03-11 16:54:14 +01:00

2934 lines
133 KiB
Python

"""Asyncio TCP server — stores and relays encrypted blobs without seeing content."""
import asyncio
from concurrent.futures import ThreadPoolExecutor
import hashlib
import hmac
import ipaddress
import json
import logging
import os
import re
import secrets
import signal
import smtplib
import ssl
import subprocess
import sys
from email.mime.text import MIMEText
from pathlib import Path
from datetime import datetime, timezone
from dotenv import load_dotenv
load_dotenv()
import db
from crypto_utils import load_public_key, rsa_verify, load_ed25519_public, ed25519_verify, serialize_x25519_public
from protocol import VERSION, MIN_CLIENT_VERSION, version_gte, ProtocolReader, ProtocolWriter, encode_binary, decode_binary, MAX_MESSAGE_BYTES, MAX_IMAGE_BYTES, MAX_FILE_BYTES, IMAGE_CHUNK_SIZE
class _AsyncDB:
    """Awaitable facade over the synchronous ``db`` module.

    Accessing ``adb.some_func`` yields a coroutine function that executes
    ``db.some_func`` in a worker thread through ``asyncio.to_thread()``, so
    MySQL I/O never blocks the event loop. The generated wrapper is stored on
    the instance, so ``__getattr__`` only runs on the first lookup of a name.
    """

    def __getattr__(self, name: str):
        target = getattr(db, name)

        async def wrapper(*args, **kwargs):
            return await asyncio.to_thread(target, *args, **kwargs)

        wrapper.__name__ = name
        wrapper.__qualname__ = f"_AsyncDB.{name}"
        # Cache on the instance: later lookups find it directly and skip __getattr__.
        self.__dict__[name] = wrapper
        return wrapper


adb = _AsyncDB()
# Connected clients: user_id -> list[ProtocolWriter]. handle_login_finish
# appends the connection's writer on each successful login.
connected_clients: dict[str, list[ProtocolWriter]] = {}
# Writer -> device_id mapping, keyed by id(writer) (id(writer) -> device_id)
writer_device_map: dict[int, str] = {}
# Pairing sessions: code -> data (purged after PAIRING_TTL_SECONDS by _cleanup_pairings)
pairing_sessions: dict[str, dict] = {}
# Pending email registrations: code -> registration data (purged after
# REGISTER_TTL_SECONDS by _cleanup_registrations)
pending_registrations: dict[str, dict] = {}
# Used PoW challenges (prevents replay within validity window)
_used_pow_challenges: dict[str, float] = {} # challenge -> used_at (event-loop time)
# Pending image uploads: file_id -> {temp_path, received_bytes, file_size, conv_id}
pending_uploads: dict[str, dict] = {}
# Phantom user IDs (loaded at startup, updated on create/delete)
phantom_user_ids: set[str] = set()
# Locks for shared mutable state (H4 race condition fix).
# Where locks are nested in this module, _phantom_lock is taken before _clients_lock.
_clients_lock = asyncio.Lock() # Protects: connected_clients, writer_device_map, phantom_user_ids
_conn_lock = asyncio.Lock() # Protects: connection_counts, current_connections, rate_limits
_pairing_lock = asyncio.Lock() # Protects: pairing_sessions, pending_registrations, _used_pow_challenges
_uploads_lock = asyncio.Lock() # Protects: pending_uploads
_phantom_lock = asyncio.Lock() # Serializes phantom user creation (cap check + DB create + set add)
# Root directory for uploaded files; overridable via the UPLOAD_DIR env var.
# The default "uploads" is relative to the process working directory.
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "uploads"))
def _secure_delete(p: Path):
"""Overwrite file with random data before deletion (anti-forensic wipe)."""
try:
if not p.exists():
return
size = p.stat().st_size
if size > 0:
with open(p, "r+b") as f:
f.write(os.urandom(size))
f.flush()
os.fsync(f.fileno())
p.unlink()
except Exception:
try:
p.unlink(missing_ok=True)
except Exception:
pass
# C6 fix: UUID validation + safe path construction to prevent path traversal
_UUID_RE = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE)
def _valid_uuid(value: str) -> bool:
"""Validate that value is a canonical UUID (no path components)."""
return bool(_UUID_RE.match(value))
# L8 fix: email validation to prevent phantom DB inflation
_EMAIL_RE = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
def _valid_email(email: str) -> bool:
"""Validate basic email format (L8)."""
return bool(_EMAIL_RE.match(email)) and len(email) <= 254
# C2 fix: ratchet/x3dh header validation
_RATCHET_HEADER_KEYS = {"dh_pub", "n", "pn"}
_MAX_HEADER_BYTES = 4096
def _validate_header(raw, name: str) -> bytes | None:
"""Validate and serialize a ratchet/x3dh header.
Accepts only dict with expected keys, rejects str/bytes to prevent
poisoned headers from being stored. Validates that ratchet headers
contain the required keys (dh_pub, n, pn) with correct types.
Returns UTF-8 encoded JSON bytes or None if invalid.
"""
if not isinstance(raw, dict):
return None
serialized = json.dumps(raw)
if len(serialized) > _MAX_HEADER_BYTES:
return None
# Validate ratchet header keys/types if this looks like one
if name in ("ratchet_header", "recipient_ratchet_header"):
# Accept self-encrypted marker {"self": true}
if raw.get("self") is True and len(raw) == 1:
return serialized.encode()
if not _RATCHET_HEADER_KEYS.issubset(raw.keys()):
return None
if not isinstance(raw.get("dh_pub"), str):
return None
if type(raw.get("n")) is not int or type(raw.get("pn")) is not int:
return None
return serialized.encode()
def _append_file(path: Path, data: bytes):
"""Append data to file (runs in thread pool to avoid blocking event loop)."""
with open(path, "ab") as f:
f.write(data)
def _read_file_chunk(path: Path, offset: int, size: int) -> bytes:
"""Read a chunk from file (runs in thread pool to avoid blocking event loop)."""
with open(path, "rb") as f:
f.seek(offset)
return f.read(size)
def _safe_upload_path(file_id: str, suffix: str) -> Path | None:
    """Map ``file_id + suffix`` to an absolute path inside UPLOAD_DIR.

    Both the candidate and UPLOAD_DIR are resolved (``..`` and symlinks
    collapsed) before comparing. None means the name would escape the upload
    directory.
    """
    root = UPLOAD_DIR.resolve()
    candidate = (UPLOAD_DIR / f"{file_id}{suffix}").resolve()
    return candidate if candidate.is_relative_to(root) else None
def _safe_avatar_path(filename: str) -> Path | None:
    """Map ``filename`` to an absolute path inside UPLOAD_DIR/avatars.

    Returns None when the resolved path would land outside the avatars
    directory (path traversal).
    """
    avatars = UPLOAD_DIR / "avatars"
    candidate = (avatars / filename).resolve()
    if candidate.is_relative_to(avatars.resolve()):
        return candidate
    return None
# --- Pairing / registration limits ---
PAIRING_TTL_SECONDS = 120 # pairing sessions older than this are purged by _cleanup_pairings
# NOTE: the registration email text says the code "expires in 10 minutes";
# keep it in sync with this value.
REGISTER_TTL_SECONDS = 600 # 10 min (was 3600) — faster slot release under load
PAIRING_MAX_POLL_ATTEMPTS = 90
PAIRING_MAX_SESSIONS = 100 # global cap on concurrent pairing sessions
MAX_PENDING_REGISTRATIONS = 1000 # global cap on pending registration codes
MAX_PENDING_PER_IP = 5 # per-IP cap on pending registrations
MAX_PENDING_PER_SUBNET = 20 # per-/24 (IPv4) or /64 (IPv6) cap
REGISTRATION_PRESSURE_THRESHOLD = 0.8 # 80% → tighten limits + require PoW
POW_DIFFICULTY = 20 # leading zero bits in SHA-256 (~1M hashes, ~0.5-2s)
# Registration-email send limits, enforced per RATE_LIMIT_WINDOW in handle_register_start
SMTP_RATE_GLOBAL = 30 # registration emails per minute (global)
SMTP_RATE_PER_IP = 3 # registration emails per minute (per IP)
SMTP_RATE_PER_TARGET = 2 # registration emails per minute (per target email)
# --- Resource caps ---
MAX_PHANTOM_USERS = 500 # global cap on phantom user count
MAX_UPLOADS_GLOBAL = 200 # global cap on concurrent in-flight uploads
MAX_UPLOADS_PER_USER = 5 # per-user cap on concurrent in-flight uploads
UPLOAD_STALE_SECONDS = 600 # stale upload threshold (10 min)
# SMTP configuration for registration codes. An empty SMTP_HOST disables
# sending (_send_registration_email then returns False).
SMTP_HOST = os.getenv("SMTP_HOST", "")
SMTP_PORT = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER = os.getenv("SMTP_USER", "")
SMTP_PASS = os.getenv("SMTP_PASS", "")
SMTP_FROM = os.getenv("SMTP_FROM", "")
# --- Rate limiting / connection limits ---
RATE_LIMIT_WINDOW = 60.0 # seconds; sliding window used by _is_rate_limited
CONNECTION_RL_WINDOW = 1.0 # seconds
CONNECTION_RL_MAX = 20 # max requests per window per connection
MAX_CONNECTIONS_PER_IP = 10
MAX_CONNECTIONS_GLOBAL = 200
# Presumably consumed by a metadata-retention cleanup elsewhere in this file — TODO confirm.
METADATA_RETENTION_DAYS = int(os.getenv("METADATA_RETENTION_DAYS", "90"))
def setup_logging():
level_name = os.getenv("LOG_LEVEL", "INFO").upper()
level = getattr(logging, level_name, logging.WARNING)
logging.basicConfig(level=level, format="%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger("encrypted_chat.server")
def _who(session: dict | None) -> str:
"""Format session info for logging: truncated user_id + device prefix.
Avoids leaking usernames and emails into log files.
"""
if not session:
return "<anon>"
uid = session.get("user_id", "?")[:8]
dev = session.get("device_id", "")[:8] if session.get("device_id") else ""
return f"u={uid} d={dev}" if dev else f"u={uid}"
# Sliding-window log for _is_rate_limited: rate-limit key -> event-loop
# timestamps of admitted requests. Guarded by _conn_lock.
rate_limits: dict[str, list[float]] = {}
# Per-IP and global connection counters, guarded by _conn_lock. Presumably
# maintained by the connection handler elsewhere in this file against
# MAX_CONNECTIONS_PER_IP / MAX_CONNECTIONS_GLOBAL — TODO confirm.
connection_counts: dict[str, int] = {}
current_connections = 0
def _rate_limit_key(action: str, addr: str, email: str | None = None) -> str:
if email:
return f"{action}|{addr}|{email.lower()}"
return f"{action}|{addr}"
async def _is_rate_limited(key: str, limit: int) -> bool:
    """Sliding-window check: at most ``limit`` hits per RATE_LIMIT_WINDOW for ``key``.

    Returns True, recording nothing, when ``key`` already has ``limit`` hits
    inside the window. Otherwise records this hit and returns False.
    Expired timestamps are dropped only for the key being checked; other keys
    are not swept here.
    """
    async with _conn_lock:
        now = asyncio.get_event_loop().time()
        cutoff = now - RATE_LIMIT_WINDOW
        recent = [stamp for stamp in rate_limits.get(key, []) if stamp >= cutoff]
        limited = len(recent) >= limit
        if not limited:
            recent.append(now)
        rate_limits[key] = recent
        return limited
async def _create_phantom_guarded(email: str, addr: str, user_id: str) -> tuple[dict | None, str]:
    """Create a phantom user for ``email``, subject to rate limits and a global cap.

    ``user_id`` (presumably the requesting user) and ``addr`` each get 10
    phantom creations per rate-limit window. The cap check, the DB insert and
    the update of ``phantom_user_ids`` are serialized by ``_phantom_lock``, so
    concurrent callers cannot push the count past MAX_PHANTOM_USERS.

    Returns:
        ``(user_dict, "")`` on success, ``(None, error_message)`` on rejection.
    """
    # Rate limits acquire _conn_lock, so evaluate them before taking _phantom_lock.
    for bucket in (f"phantom_create|{user_id}", f"phantom_create_ip|{addr}"):
        if await _is_rate_limited(bucket, 10):
            return None, "Too many new contacts. Try later."
    async with _phantom_lock:
        async with _clients_lock:
            at_cap = len(phantom_user_ids) >= MAX_PHANTOM_USERS
        if at_cap:
            return None, "Server limit reached. Try later."
        new_user = await adb.create_phantom_user(email)
        async with _clients_lock:
            phantom_user_ids.add(new_user["id"])
        return new_user, ""
def _get_peer_addr(writer: ProtocolWriter) -> str:
    """Return the remote host of ``writer``'s connection as a string, or "unknown".

    Reads the "peername" extra info from ``writer._writer`` (presumably the
    wrapped asyncio StreamWriter). Any failure, such as a missing attribute,
    a None peername or an unexpected shape, yields "unknown".
    """
    try:
        peername = writer._writer.get_extra_info("peername")
        return str(peername[0])
    except Exception:
        return "unknown"
async def _notify_users(user_ids, msg_type, data, exclude_writer=None):
    """Push ``(msg_type, "ok", data)`` to every connection of each user in ``user_ids``.

    The writer list is copied while holding _clients_lock, and sending
    happens after the lock is released, so a slow client cannot stall other
    lock users. ``exclude_writer`` is skipped. Send errors are ignored
    (best-effort delivery).
    """
    async with _clients_lock:
        recipients = [w for uid in user_ids for w in connected_clients.get(uid, [])]
    for w in (r for r in recipients if r is not exclude_writer):
        try:
            await w.send_response(msg_type, "ok", data)
        except Exception:
            pass
async def _notify_users_individual(notifications, exclude_writer=None):
    """Deliver a per-user message to each user's connections.

    ``notifications`` yields ``(user_id, msg_type, data)`` tuples. Every
    connected writer of ``user_id`` gets ``send_response(msg_type, "ok", data)``.
    Targets are collected under _clients_lock and sent after it is released.
    ``exclude_writer`` is skipped and send errors are ignored (best-effort).
    """
    async with _clients_lock:
        deliveries = [
            (w, kind, payload)
            for uid, kind, payload in notifications
            for w in connected_clients.get(uid, [])
        ]
    for w, kind, payload in deliveries:
        if w is exclude_writer:
            continue
        try:
            await w.send_response(kind, "ok", payload)
        except Exception:
            pass
async def _cleanup_pairings():
    """Drop pairing sessions older than PAIRING_TTL_SECONDS (event-loop clock)."""
    async with _pairing_lock:
        now = asyncio.get_event_loop().time()
        stale_codes = [
            c for c, sess in pairing_sessions.items()
            if now - sess["created_at"] > PAIRING_TTL_SECONDS
        ]
        for c in stale_codes:
            pairing_sessions.pop(c, None)
async def _cleanup_registrations():
    """Expire pending registrations and forget old used PoW challenges.

    Registrations older than REGISTER_TTL_SECONDS are removed. Used-challenge
    entries older than 120 s are purged; 120 s matches the freshness window
    enforced by _verify_pow.
    """
    async with _pairing_lock:
        now = asyncio.get_event_loop().time()
        for c in [c for c, reg in pending_registrations.items() if now - reg["created_at"] > REGISTER_TTL_SECONDS]:
            pending_registrations.pop(c, None)
        # Purge used PoW challenges older than 120s (validity window)
        for ch in [ch for ch, used_at in _used_pow_challenges.items() if now - used_at > 120]:
            _used_pow_challenges.pop(ch, None)
def _generate_pairing_code() -> str:
    """Return a random zero-padded 8-digit pairing code.

    ``secrets.randbelow`` makes every code in 00000000-99999999 equally
    likely. The previous ``urandom(4) % 10**8`` made the lowest ~95M codes
    ~2% more likely. Up to 10 draws try to avoid codes already in
    ``pairing_sessions``; after that a fresh (possibly colliding) code is
    returned.
    """
    for _ in range(10):
        code = f"{secrets.randbelow(100_000_000):08d}"
        if code not in pairing_sessions:
            return code
    return f"{secrets.randbelow(100_000_000):08d}"
def _generate_register_code() -> str:
    """Return a random zero-padded 6-digit registration code.

    ``secrets.randbelow`` gives a uniform distribution. The previous
    ``urandom(3) % 10**6`` made codes below 777216 ~6% more likely, which
    helps a guesser. Up to 10 draws try to avoid codes already in
    ``pending_registrations``; after that a fresh (possibly colliding) code
    is returned.
    """
    for _ in range(10):
        code = f"{secrets.randbelow(1_000_000):06d}"
        if code not in pending_registrations:
            return code
    return f"{secrets.randbelow(1_000_000):06d}"
def _validate_public_key_pem(pem_str: str) -> bool:
    """True iff ``pem_str`` loads via load_public_key and the key is >= 2048 bits.

    Any load error, or a key object without ``key_size``, counts as invalid.
    NOTE(review): the key type is not checked here. Unless load_public_key
    restricts it to RSA, a large non-RSA key (e.g. DSA) would pass — confirm.
    """
    try:
        return load_public_key(pem_str.encode("utf-8")).key_size >= 2048
    except Exception:
        return False
def _send_registration_email(to_email: str, code: str) -> bool:
    """Email the registration ``code`` to ``to_email`` over SMTP with STARTTLS.

    Blocking; callers run it in a worker thread. STARTTLS uses
    ``ssl.create_default_context()`` so the server certificate and hostname
    are verified. smtplib's default STARTTLS context performs no verification,
    which would let a network attacker intercept the code. SMTP servers with
    self-signed certificates therefore need a trusted CA installed.

    Returns:
        True once the message was accepted by the SMTP server. False when
        SMTP_HOST is unset or any step fails (the failure is logged).
    """
    if not SMTP_HOST:
        return False
    try:
        msg = MIMEText(f"Your registration code is: {code}\n\nThis code expires in 10 minutes.")
        msg["Subject"] = "Encrypted Chat - Registration Code"
        msg["From"] = SMTP_FROM or SMTP_USER
        msg["To"] = to_email
        tls_context = ssl.create_default_context()
        with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=10) as server:
            server.starttls(context=tls_context)
            if SMTP_USER:
                server.login(SMTP_USER, SMTP_PASS)
            server.send_message(msg)
        return True
    except Exception as e:
        logger.warning("Failed to send registration email: %s", e)
        return False
async def send_resp(msg: dict, writer: ProtocolWriter, msg_type: str, status: str, data: dict | None = None):
    """Reply to request ``msg`` on ``writer``, echoing its ``request_id`` (None if absent)."""
    request_id = msg.get("request_id")
    await writer.send_response(msg_type, status, data, request_id=request_id)
# --- Registration admission control ---
# HMAC key that authenticates stateless PoW challenges
# (_generate_pow_challenge / _verify_pow).
_POW_SECRET = os.urandom(32) # per-process; restarts invalidate outstanding challenges
def _get_subnet(addr: str) -> str:
"""Extract /24 for IPv4, /64 for IPv6."""
try:
ip = ipaddress.ip_address(addr)
if ip.version == 4:
return str(ipaddress.ip_network(f"{ip}/24", strict=False))
return str(ipaddress.ip_network(f"{ip}/64", strict=False))
except ValueError:
return addr
def _pending_counts_by_origin(addr: str) -> tuple[int, int]:
    """Return (pending registrations from ``addr``, pending from its subnet).

    The subnet figure includes the exact-IP matches. Every pending entry is
    scanned, so the cost is linear in len(pending_registrations).
    Caller must hold _pairing_lock.
    """
    target_subnet = _get_subnet(addr)
    origins = [entry.get("addr", "") for entry in pending_registrations.values()]
    same_ip = sum(1 for origin in origins if origin == addr)
    same_subnet = sum(1 for origin in origins if _get_subnet(origin) == target_subnet)
    return same_ip, same_subnet
def _generate_pow_challenge() -> tuple[str, str]:
    """Issue a stateless proof-of-work challenge as ``(challenge, mac)``.

    ``challenge`` is "<whole event-loop seconds>:<32 hex chars>" and ``mac`` is
    its hex HMAC-SHA256 under _POW_SECRET. This lets _verify_pow check both
    authenticity and age without keeping server-side state.
    """
    issued_at = int(asyncio.get_event_loop().time())
    challenge = f"{issued_at}:{secrets.token_hex(16)}"
    mac = hmac.new(_POW_SECRET, challenge.encode(), hashlib.sha256).hexdigest()
    return challenge, mac
def _verify_pow(challenge: str, mac: str, nonce: str, difficulty: int) -> bool:
    """Check a client's proof-of-work solution.

    A solution is accepted only if:
      * challenge, mac and nonce are strings (they come straight from client JSON),
      * ``mac`` equals the HMAC-SHA256 of ``challenge`` under _POW_SECRET,
      * the challenge timestamp is within 120 s of the event-loop clock, and
      * SHA-256(challenge + nonce) starts with ``difficulty`` zero bits.

    Malformed input returns False rather than raising. Before, non-str values
    raised AttributeError, and a non-ASCII ``mac`` made
    ``hmac.compare_digest`` raise TypeError.
    """
    if not (isinstance(challenge, str) and isinstance(mac, str) and isinstance(nonce, str)):
        return False
    try:
        challenge_b = challenge.encode()
        mac_b = mac.encode()
        nonce_b = nonce.encode()
    except UnicodeEncodeError:  # e.g. lone surrogates decoded from JSON
        return False
    # Verify HMAC. Compare bytes: compare_digest rejects non-ASCII str operands.
    expected = hmac.new(_POW_SECRET, challenge_b, hashlib.sha256).hexdigest().encode()
    if not hmac.compare_digest(expected, mac_b):
        return False
    # Check timestamp freshness (120s window)
    try:
        ts = int(challenge.split(":")[0])
    except ValueError:
        return False
    if abs(int(asyncio.get_event_loop().time()) - ts) > 120:
        return False
    if difficulty <= 0:
        return True
    # The leading `difficulty` bits are zero iff the 256-bit digest, shifted
    # right by (256 - difficulty), is zero.
    digest = hashlib.sha256(challenge_b + nonce_b).digest()
    return int.from_bytes(digest, "big") >> max(0, 256 - difficulty) == 0
async def handle_register_start(msg: dict, writer: ProtocolWriter) -> dict | None:
    """Start an email-verified registration and deliver a one-time code.

    Validates the request (username, RSA public-key PEM, 32-byte Ed25519
    identity key, email) and applies rate limits and admission control. It
    then reserves a code in ``pending_registrations`` and emails it.

    Responses sent to the client (msg_type "register_start"):
      * "error"        — rate-limited, missing/invalid fields, or no slot available.
      * "pow_required" — the table is at least REGISTRATION_PRESSURE_THRESHOLD
                         full and the request carried no valid, unused PoW
                         solution. Carries a fresh challenge/mac/difficulty.
      * "ok"           — otherwise. In dev mode (ENVIRONMENT "dev"/"development")
                         the code itself is returned when sending is rate-limited
                         or fails. In other modes the reply always reads
                         "Code sent to your email.", even when the slot was
                         silently revoked, so clients cannot probe SMTP or
                         account state.

    An email that already belongs to a real account gets an inert ("fake")
    slot and the same response as a new email (H3 anti-enumeration). A
    phantom account is recorded via ``phantom_id`` so register_confirm can
    upgrade it.

    Always returns None.
    """
    await _cleanup_registrations()
    # NOTE(review): .strip() assumes string fields; a non-str JSON value
    # (including null) raises AttributeError here.
    username = msg.get("username", "").strip()
    public_key = msg.get("public_key", "").strip()
    identity_key_b64 = msg.get("identity_key", "").strip()
    email = msg.get("email", "").strip()
    addr = _get_peer_addr(writer)
    # Per-(IP, email) limit: 3 starts per rate-limit window
    if await _is_rate_limited(_rate_limit_key("register_start", addr, email), 3):
        await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."})
        return None
    # Per-IP limit (regardless of email) to prevent SMTP spam via email rotation
    if await _is_rate_limited(f"register_start_ip|{addr}", 6):
        await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."})
        return None
    # NOTE(review): only emptiness is checked for email; _valid_email is not applied here.
    if not username or not public_key or not email or not identity_key_b64:
        await send_resp(msg, writer, "register_start", "error", {"message": "Missing fields"})
        return None
    if not _validate_public_key_pem(public_key):
        await send_resp(msg, writer, "register_start", "error", {"message": "Invalid public key format"})
        return None
    # Validate identity key is 32 bytes and loads as an Ed25519 public key
    try:
        ik_bytes = decode_binary(identity_key_b64)
        if len(ik_bytes) != 32:
            raise ValueError("Identity key must be 32 bytes")
        load_ed25519_public(ik_bytes)
    except Exception:
        await send_resp(msg, writer, "register_start", "error", {"message": "Invalid identity key"})
        return None
    # Classify the email: unknown, phantom placeholder ("PHANTOM" key), or real account
    existing_email = await adb.get_user_by_email(email)
    phantom_id = None
    is_existing_real_user = False
    if existing_email:
        if existing_email.get("rsa_public_key") == "PHANTOM":
            phantom_id = existing_email["id"]
        else:
            is_existing_real_user = True
    # --- Admission control (all checks under lock, I/O outside) ---
    # Existing-email goes through the same path so responses are
    # indistinguishable from new-email (H3 anti-enumeration).
    # Both allocate a slot so per-IP/subnet cap counting is identical.
    async with _pairing_lock:
        total = len(pending_registrations)
        # Hard cap
        if total >= MAX_PENDING_REGISTRATIONS:
            reject_reason = "cap"
        else:
            # Per-IP / per-subnet slot limits
            ip_count, subnet_count = _pending_counts_by_origin(addr)
            if ip_count >= MAX_PENDING_PER_IP:
                reject_reason = "ip"
            elif subnet_count >= MAX_PENDING_PER_SUBNET:
                reject_reason = "subnet"
            else:
                reject_reason = None
        # Pressure mode: require PoW when >80% full
        under_pressure = total >= MAX_PENDING_REGISTRATIONS * REGISTRATION_PRESSURE_THRESHOLD
        need_pow = under_pressure and reject_reason is None
        # If PoW required, verify the client's solution (one-time use)
        pow_ok = False
        if need_pow:
            pow_challenge = msg.get("pow_challenge", "")
            pow_mac = msg.get("pow_mac", "")
            pow_nonce = msg.get("pow_nonce", "")
            if pow_challenge and pow_mac and pow_nonce:
                if pow_challenge in _used_pow_challenges:
                    pow_ok = False # replay
                elif _verify_pow(pow_challenge, pow_mac, pow_nonce, POW_DIFFICULTY):
                    # Mark as used so the same solution cannot admit a second slot
                    _used_pow_challenges[pow_challenge] = asyncio.get_event_loop().time()
                    pow_ok = True
        # Decide: admit, challenge, or reject
        if reject_reason:
            admit = False
            send_challenge = False
            code = None
        elif need_pow and not pow_ok:
            admit = False
            send_challenge = True
            code = None
        else:
            # Both existing and new emails allocate a slot so per-IP/subnet
            # counting behaves identically (anti-enumeration via slot side-channel).
            # Existing-email slots are inert — register_confirm silently fails.
            admit = True
            send_challenge = False
            code = _generate_register_code()
            pending_registrations[code] = {
                "username": username,
                "public_key": public_key,
                "identity_key": ik_bytes,  # decoded 32-byte Ed25519 key
                "email": email,
                "created_at": asyncio.get_event_loop().time(),
                "phantom_id": phantom_id,  # set when upgrading a phantom account
                "addr": addr,  # used for per-IP/subnet slot counting
                "fake": is_existing_real_user,  # inert slot for an existing real account
            }
    # --- I/O outside lock ---
    if not admit:
        if send_challenge:
            challenge, mac = _generate_pow_challenge()
            await send_resp(msg, writer, "register_start", "pow_required", {
                "challenge": challenge, "mac": mac, "difficulty": POW_DIFFICULTY,
            })
        else:
            await send_resp(msg, writer, "register_start", "error", {"message": "Server busy. Try later."})
        return None
    logger.info("[REGISTER] registration started")
    is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development")
    # SMTP rate limiting (global, per IP, per target address)
    smtp_blocked = False
    if SMTP_HOST:
        if await _is_rate_limited("smtp_send|global", SMTP_RATE_GLOBAL):
            smtp_blocked = True
        elif await _is_rate_limited(f"smtp_send_ip|{addr}", SMTP_RATE_PER_IP):
            smtp_blocked = True
        elif await _is_rate_limited(f"smtp_send_target|{email.lower()}", SMTP_RATE_PER_TARGET):
            smtp_blocked = True
    if smtp_blocked:
        if is_dev:
            logger.warning("[REGISTER] SMTP rate limit hit — returning code (dev mode)")
            await send_resp(msg, writer, "register_start", "ok", {"code": code})
        else:
            # Free the slot but answer exactly like a successful send
            logger.warning("[REGISTER] SMTP rate limit hit — revoking slot silently")
            async with _pairing_lock:
                pending_registrations.pop(code, None)
            await send_resp(msg, writer, "register_start", "ok",
                            {"message": "Code sent to your email."})
        return None
    # Send registration email in a thread (non-blocking) for both real
    # and fake registrations. For existing emails we still call SMTP so
    # the response timing is indistinguishable (anti-enumeration).
    # The email goes to the real address either way — existing users just
    # won't be able to confirm (code is for a fake slot).
    email_sent = await asyncio.to_thread(_send_registration_email, email, code)
    if email_sent:
        await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."})
    elif is_dev:
        logger.warning("[REGISTER] No SMTP / send failed — returning code (dev mode)")
        await send_resp(msg, writer, "register_start", "ok", {"code": code})
    else:
        # Free the slot but answer exactly like a successful send
        logger.warning("[REGISTER] SMTP send failed — revoking slot silently")
        async with _pairing_lock:
            pending_registrations.pop(code, None)
        await send_resp(msg, writer, "register_start", "ok",
                        {"message": "Code sent to your email."})
    return None
async def handle_register_confirm(msg: dict, writer: ProtocolWriter) -> dict | None:
    """Redeem a registration code and create (or upgrade into) a real account.

    The code is consumed only if it exists and was issued for the same
    ``email``. Three cases get the identical "Invalid or expired code" error
    (H3 anti-enumeration): unknown or expired codes, email mismatches, and
    inert slots issued for already-registered emails.

    A slot that refers to a phantom user upgrades that row in place, which
    keeps its FK references. If the upgrade returns no id (the phantom was
    deleted concurrently), a new user is created instead. A default profile
    is then created and the new user_id is sent to the client.

    Always returns None.
    """
    await _cleanup_registrations()
    email = msg.get("email", "").strip()
    code = msg.get("code", "").strip()
    addr = _get_peer_addr(writer)

    if await _is_rate_limited(_rate_limit_key("register_confirm", addr, email), 3):
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Too many attempts. Try later."})
        return None
    if not (email and code):
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Missing email or code"})
        return None

    # Claim the slot atomically; a code issued for another email is left untouched.
    async with _pairing_lock:
        pending = pending_registrations.get(code)
        matched = bool(pending) and pending.get("email") == email
        if matched:
            pending_registrations.pop(code, None)

    # Same reply for "no such code" and for an inert slot of an existing account.
    if not matched or pending.get("fake"):
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"})
        return None

    user_id = None
    phantom_id = pending.get("phantom_id")
    if phantom_id:
        user_id = await adb.upgrade_phantom_user(
            phantom_id,
            pending["username"],
            pending["public_key"],
            pending["identity_key"],
        )
        if user_id:
            async with _clients_lock:
                phantom_user_ids.discard(phantom_id)
    if not user_id:
        # No phantom to upgrade, or it vanished concurrently: create a fresh user.
        user_id = await adb.create_user(
            pending["username"],
            pending["email"],
            pending["public_key"],
            pending["identity_key"],
        )
    await adb.create_default_profile(user_id)
    logger.info("[REGISTER] confirmed (user_id=%s)", user_id[:8])
    await send_resp(msg, writer, "register_confirm", "ok", {"user_id": user_id})
    return None
async def handle_login_start(msg: dict, writer: ProtocolWriter, state: dict):
    """Begin challenge-response login for ``email``.

    Stores the email and a fresh 32-byte random challenge in the
    per-connection ``state`` and sends the encoded challenge to the client,
    which signs it for login_finish. Unknown emails receive an
    indistinguishable challenge, but ``state["_login_fake"]`` is set so the
    subsequent login_finish fails generically (H3 anti-enumeration).

    Rate limits: 10 per window per (IP, email), 20 per window per IP.
    """
    email = msg.get("email", "").strip()
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("login_start", addr, email), 10):
        await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."})
        return
    if await _is_rate_limited(f"login_start_ip|{addr}", 20):
        await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."})
        return
    if not email:
        await send_resp(msg, writer, "login_start", "error", {"message": "Missing email"})
        return
    user = await adb.get_user_by_email(email)
    challenge = os.urandom(32)
    state["login_email"] = email
    state["login_challenge"] = challenge
    if user:
        # Clear a flag left by an earlier login_start for an unknown email on
        # this connection; otherwise this genuine login would be rejected.
        state.pop("_login_fake", None)
    else:
        # H3 anti-enumeration: return a fake challenge so attacker can't distinguish
        # "user not found" from "user exists". login_finish will fail with generic error.
        state["_login_fake"] = True
    await send_resp(msg, writer, "login_start", "ok", {"challenge": encode_binary(challenge)})
async def handle_login_finish(msg: dict, writer: ProtocolWriter, state: dict) -> dict | None:
    """Complete challenge-response login and register this connection.

    Verifies the client's RSA signature over the challenge that
    handle_login_start stored in ``state``; the email must match the one used
    there. Once the signature check is reached, the challenge and email are
    removed from ``state`` (finally block), so a challenge can be answered
    only once.

    On success it then:
      * rejects clients older than MIN_CLIENT_VERSION,
      * reuses the client's device_id if it belongs to this user, otherwise
        creates a device,
      * adds ``writer`` to connected_clients,
      * replies with the account details,
      * sends the caller its online contacts, and
      * notifies those contacts' connections that this user is online.

    Returns:
        A session dict (user_id, username, email, device_id) on success,
        otherwise None after an error response has been sent.
    """
    email = msg.get("email", "").strip()
    signature_b64 = msg.get("signature", "")
    challenge = state.get("login_challenge")
    expected_email = state.get("login_email")
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("login_finish", addr, email), 10):
        await send_resp(msg, writer, "login_finish", "error", {"message": "Too many attempts. Try later."})
        return None
    if not email or not signature_b64:
        await send_resp(msg, writer, "login_finish", "error", {"message": "Missing email or signature"})
        return None
    # No outstanding challenge, or it was issued for a different email
    if not challenge or expected_email != email:
        await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
        return None
    # H3: if login_start was for a non-existent user, fail with generic error
    is_fake = state.pop("_login_fake", False)
    try:
        if is_fake:
            await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
            return None
        user = await adb.get_user_by_email(email)
        if not user:
            await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
            return None
        # NOTE(review): presumably load_public_key raises ValueError for an
        # unparsable key (e.g. the "PHANTOM" placeholder); other exception
        # types would escape this handler — confirm in crypto_utils.
        public_key = load_public_key(user["rsa_public_key"].encode("utf-8"))
        signature = decode_binary(signature_b64)
        if not rsa_verify(public_key, signature, challenge):
            await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
            return None
    except ValueError:
        # H5: invalid base64 in signature
        await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
        return None
    finally:
        # One attempt per challenge, whatever the outcome
        state.pop("login_challenge", None)
        state.pop("login_email", None)
    user_id = user["id"]
    # Version check: reject outdated clients (an absent client_version is allowed)
    client_version = msg.get("client_version", "")
    if client_version and not version_gte(client_version, MIN_CLIENT_VERSION):
        await send_resp(msg, writer, "login_finish", "error", {
            "message": f"Client version {client_version} is too old. Minimum required: {MIN_CLIENT_VERSION}",
            "min_version": MIN_CLIENT_VERSION,
            "server_version": VERSION,
        })
        return None
    # Device registration: client may send device_id to reuse an existing device
    client_device_id = msg.get("device_id")
    device_id = None
    if client_device_id:
        dev = await adb.get_device(client_device_id)
        # Only reuse a device that belongs to the authenticated user
        if dev and dev["user_id"] == user_id:
            device_id = client_device_id
    if not device_id:
        device_name = msg.get("device_name", "Unknown")
        device_id = await adb.create_device(user_id, device_name)
    await adb.update_device_last_seen(device_id)
    # NOTE(review): a second successful login_finish on the same connection
    # would append this writer twice — confirm the dispatcher prevents re-login.
    async with _clients_lock:
        if user_id not in connected_clients:
            connected_clients[user_id] = []
        connected_clients[user_id].append(writer)
        writer_device_map[id(writer)] = device_id
    logger.info("[LOGIN] u=%s d=%s client_v=%s",
                user_id[:8], device_id[:8] if device_id else "?", client_version or "unknown")
    await send_resp(msg, writer, "login_finish", "ok", {
        "user_id": user_id, "username": user["username"], "email": user["email"],
        "device_id": device_id, "server_version": VERSION,
    })
    # Send online status notifications
    contacts = await adb.get_user_contacts(user_id)
    online_targets = []
    async with _clients_lock:
        # Contacts with at least one live connection
        online_contacts = [cid for cid in contacts if cid in connected_clients and connected_clients[cid]]
        # Always notify contacts (handles reconnect where old writer is still lingering)
        for contact_id in contacts:
            for cw in connected_clients.get(contact_id, []):
                online_targets.append(cw)
    await writer.send_response("online_users", "ok", {"user_ids": online_contacts})
    # Send online notifications outside lock (best-effort; send errors ignored)
    for cw in online_targets:
        try:
            await cw.send_response("user_online", "ok", {"user_id": user_id})
        except Exception:
            pass
    return {"user_id": user_id, "username": user["username"], "email": user["email"],
            "device_id": device_id}
async def handle_get_user_info(msg: dict, session: dict, writer: ProtocolWriter):
    """Look up a user by ``email`` (preferred) or ``user_id`` for X3DH. Requires login.

    Only the caller or users who share a conversation with the caller are
    visible (H4). A malformed id, an unknown user and a non-contact all get
    the same "User not found" error. Limited to 30 lookups per window per
    (peer address, lookup target).
    """
    email = msg.get("email", "").strip()
    user_id = msg.get("user_id", "").strip()
    addr = _get_peer_addr(writer)

    async def not_found():
        await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"})

    if await _is_rate_limited(_rate_limit_key("get_user_info", addr, email or user_id), 30):
        await send_resp(msg, writer, "get_user_info", "error", {"message": "Too many attempts. Try later."})
        return
    if user_id and not _valid_uuid(user_id):
        await not_found()
        return

    if email:
        user = await adb.get_user_by_email(email)
    elif user_id:
        user = await adb.get_user_by_id(user_id)
    else:
        user = None
    if not user:
        await not_found()
        return

    # H4 fix: restrict lookups to self or contacts (shared conversation)
    target_id = user["id"]
    me = session["user_id"]
    if target_id != me and not await adb.shares_conversation(me, target_id):
        await not_found()
        return

    identity_key = user.get("identity_key")
    await send_resp(msg, writer, "get_user_info", "ok", {
        "user_id": target_id,
        "username": user["username"],
        "email": user["email"],
        "identity_key": encode_binary(identity_key) if identity_key else "",
    })
async def handle_upload_prekeys(msg: dict, session: dict, writer: ProtocolWriter):
    """Store a new signed prekey (SPK) and a batch of one-time prekeys (OTPs).

    The SPK signature must verify against the caller's Ed25519 identity key.
    Keys are stored for the session's device. Limited to 5 calls per window
    per user.

    Malformed payloads get an error response instead of an unhandled
    exception: a non-object ``signed_prekey``, a non-list
    ``one_time_prekeys``, or undecodable key material. Every key is decoded
    before the first DB write, so a bad OTP can no longer leave a new SPK
    stored without its OTP batch. OTP entries that are not objects, or that
    lack id/public_key, are skipped.
    """
    if await _is_rate_limited(f"upload_prekeys|{session['user_id']}", 5):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Too many requests. Try later."})
        return
    spk_data = msg.get("signed_prekey")
    # A missing, null or empty value means "no OTPs in this batch".
    otps = msg.get("one_time_prekeys") or []
    if not spk_data:
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Missing signed_prekey"})
        return
    if not isinstance(spk_data, dict) or not isinstance(otps, list):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Malformed prekey payload"})
        return
    spk_id = spk_data.get("id", "")
    spk_pub_b64 = spk_data.get("public_key", "")
    spk_sig_b64 = spk_data.get("signature", "")
    if not spk_id or not spk_pub_b64 or not spk_sig_b64:
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Incomplete signed_prekey"})
        return
    # Decode everything up front so bad input aborts before any DB write.
    # decode_binary raises ValueError on invalid base64, and presumably
    # TypeError on non-string input.
    try:
        spk_pub = decode_binary(spk_pub_b64)
        spk_sig = decode_binary(spk_sig_b64)
        otp_records = [
            {"id": otp["id"], "public_key": decode_binary(otp["public_key"])}
            for otp in otps
            if isinstance(otp, dict) and otp.get("id") and otp.get("public_key")
        ]
    except (ValueError, TypeError):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid prekey encoding"})
        return
    # Verify SPK signature with user's identity key
    user = await adb.get_user_by_id(session["user_id"])
    if not user or not user.get("identity_key"):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "No identity key"})
        return
    ik_pub = load_ed25519_public(user["identity_key"])
    if not ed25519_verify(ik_pub, spk_sig, spk_pub):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid SPK signature"})
        return
    device_id = session.get("device_id")
    await adb.store_signed_prekey(session["user_id"], spk_id, spk_pub, spk_sig, device_id=device_id)
    if otp_records:
        await adb.store_one_time_prekeys(session["user_id"], otp_records, device_id=device_id)
    logger.info("[PREKEYS] %s uploaded 1 SPK + %d OTPs", _who(session), len(otp_records))
    await send_resp(msg, writer, "upload_prekeys", "ok", {"message": "OK"})
async def handle_get_key_bundle(msg: dict, session: dict, writer: ProtocolWriter):
    """Return the X3DH prekey bundles of user ``user_id``, one per device.

    Fetching consumes one one-time prekey per device, so access is
    restricted (M4):
      * 10 requests per window per caller;
      * only the caller or users sharing a conversation with the caller;
      * 20 requests per window per target.

    The reply carries ``device_bundles`` plus legacy flat fields copied from
    the first bundle for older clients.
    """
    async def fail(message: str):
        await send_resp(msg, writer, "get_key_bundle", "error", {"message": message})

    target_user_id = msg.get("user_id", "").strip()
    if not target_user_id:
        await fail("Missing user_id")
        return
    if not _valid_uuid(target_user_id):
        await fail("Invalid user_id")
        return
    caller_id = session["user_id"]
    # M4: rate limit + authorization (prevents OPK depletion)
    if await _is_rate_limited(f"get_key_bundle|{caller_id}", 10):
        await fail("Too many requests. Try later.")
        return
    # Authorize before charging the per-target bucket so unauthorized callers
    # cannot exhaust another user's quota.
    if target_user_id != caller_id and not await adb.shares_conversation(caller_id, target_user_id):
        await fail("Key bundle not available")
        return
    if await _is_rate_limited(f"get_key_bundle_target|{target_user_id}", 20):
        await fail("Too many requests. Try later.")
        return
    result = await adb.get_key_bundles_for_user(target_user_id)
    if not result or not result.get("device_bundles"):
        await fail("Key bundle not available")
        return

    bundles = []
    for raw in result["device_bundles"]:
        entry = {
            "device_id": raw.get("device_id"),
            "signed_prekey_id": raw["signed_prekey_id"],
            "signed_prekey": encode_binary(raw["signed_prekey_pub"]),
            "spk_signature": encode_binary(raw["spk_signature"]),
        }
        if raw.get("opk_pub"):
            entry.update(
                one_time_prekey_id=raw["opk_id"],
                one_time_prekey=encode_binary(raw["opk_pub"]),
            )
        bundles.append(entry)

    first = bundles[0] if bundles else {}
    data = {
        "identity_key": encode_binary(result["identity_key"]),
        "device_bundles": bundles,
    }
    # Legacy flat fields from the first device bundle (backward compat)
    for field in ("signed_prekey_id", "signed_prekey", "spk_signature"):
        data[field] = first.get(field, "")
    if first.get("one_time_prekey"):
        for field in ("one_time_prekey_id", "one_time_prekey"):
            data[field] = first[field]
    logger.info("[X3DH] %s fetched key bundle for user=%s (%d devices)",
                _who(session), target_user_id[:8], len(bundles))
    await send_resp(msg, writer, "get_key_bundle", "ok", data)
async def handle_get_prekey_count(msg: dict, session: dict, writer: ProtocolWriter):
    """Report this device's remaining one-time prekeys and its signed-prekey age.

    The reply carries ``count`` (OPKs left for the session's device) and
    ``spk_created_at`` (an ISO string, or "" when no SPK row or timestamp is
    stored), which the client uses to decide whether to replenish or rotate.
    """
    uid = session["user_id"]
    dev = session.get("device_id")
    remaining = await adb.count_one_time_prekeys(uid, device_id=dev)
    signed = await adb.get_signed_prekey(uid, device_id=dev)
    stamp = signed.get("created_at") if signed else None
    if not stamp:
        stamp_text = ""
    elif hasattr(stamp, "isoformat"):
        stamp_text = stamp.isoformat()
    else:
        stamp_text = str(stamp)
    await send_resp(msg, writer, "get_prekey_count", "ok",
                    {"count": remaining, "spk_created_at": stamp_text})
async def handle_ensure_prekeys(msg: dict, session: dict, writer: ProtocolWriter):
    """Combined get_prekey_count + upload_prekeys in one round-trip.

    The client may include:
      * ``signed_prekey``: dict with ``id``, ``public_key``, ``signature``
        (binary fields encoded for ``decode_binary``). It is stored only if the
        signature verifies against the user's stored Ed25519 identity key.
      * ``one_time_prekeys``: list of dicts with ``id`` and ``public_key``.

    Malformed containers or entries (wrong type, missing fields) are ignored
    instead of aborting the request. The reply reports the OPK count and SPK
    timestamp for the session's device (recomputed after any upload) plus
    what was actually stored.
    """
    if await _is_rate_limited(f"ensure_prekeys|{session['user_id']}", 5):
        await send_resp(msg, writer, "ensure_prekeys", "error", {"message": "Too many requests. Try later."})
        return
    device_id = session.get("device_id")
    user_id = session["user_id"]

    def _spk_timestamp(spk_row) -> str:
        # "" when there is no SPK row or it has no timestamp.
        if not spk_row or not spk_row.get("created_at"):
            return ""
        ts = spk_row["created_at"]
        return ts.isoformat() if hasattr(ts, "isoformat") else str(ts)

    # Step 1: Get current count + SPK age
    count = await adb.count_one_time_prekeys(user_id, device_id=device_id)
    spk_created_at = _spk_timestamp(await adb.get_signed_prekey(user_id, device_id=device_id))

    # Step 2: If client included new keys, store them
    uploaded_spk = False
    uploaded_otps = 0
    spk_data = msg.get("signed_prekey")
    # isinstance guard: a non-dict value would otherwise crash on .get()
    if isinstance(spk_data, dict):
        spk_id = spk_data.get("id", "")
        spk_pub_b64 = spk_data.get("public_key", "")
        spk_sig_b64 = spk_data.get("signature", "")
        if spk_id and spk_pub_b64 and spk_sig_b64:
            spk_pub = decode_binary(spk_pub_b64)
            spk_sig = decode_binary(spk_sig_b64)
            user = await adb.get_user_by_id(user_id)
            if user and user.get("identity_key"):
                ik_pub = load_ed25519_public(user["identity_key"])
                if ed25519_verify(ik_pub, spk_sig, spk_pub):
                    await adb.store_signed_prekey(user_id, spk_id, spk_pub, spk_sig, device_id=device_id)
                    uploaded_spk = True
                else:
                    # Surface rejected SPKs instead of dropping them silently.
                    logger.warning("[PREKEYS] %s ensure_prekeys: SPK signature invalid, not stored",
                                   _who(session))

    otps = msg.get("one_time_prekeys", [])
    if isinstance(otps, list):
        otp_records = [
            {"id": otp["id"], "public_key": decode_binary(otp["public_key"])}
            for otp in otps
            if isinstance(otp, dict) and otp.get("id") and otp.get("public_key")
        ]
        if otp_records:
            await adb.store_one_time_prekeys(user_id, otp_records, device_id=device_id)
            uploaded_otps = len(otp_records)

    # Recount after upload
    if uploaded_spk or uploaded_otps:
        count = await adb.count_one_time_prekeys(user_id, device_id=device_id)
        spk_created_at = _spk_timestamp(await adb.get_signed_prekey(user_id, device_id=device_id))
    logger.info("[PREKEYS] %s ensure_prekeys: uploaded SPK=%s, OTPs=%d, new count=%d",
                _who(session), uploaded_spk, uploaded_otps, count)
    await send_resp(msg, writer, "ensure_prekeys", "ok",
                    {"count": count, "spk_created_at": spk_created_at,
                     "uploaded_spk": uploaded_spk, "uploaded_otps": uploaded_otps})
async def handle_rotate_keys(msg: dict, session: dict, writer: ProtocolWriter):
    """Replace the caller's RSA public key and drop their other live sessions.

    The new key arrives as a PEM string in ``public_key`` and must pass
    ``_validate_public_key_pem``. After the key is stored and the caller has
    been answered, every other connection registered for the same user is
    removed from ``connected_clients`` and closed (best effort). The requesting
    connection is left as the only one registered.
    """
    uid = session["user_id"]
    if await _is_rate_limited(f"rotate_keys|{uid}", 3):
        await send_resp(msg, writer, "rotate_keys", "error", {"message": "Too many requests. Try later."})
        return
    pem = msg.get("public_key", "").strip()
    problem = None
    if not pem:
        problem = "Missing public_key"
    elif not _validate_public_key_pem(pem):
        problem = "Invalid public key format"
    if problem is not None:
        await send_resp(msg, writer, "rotate_keys", "error", {"message": problem})
        return
    await adb.update_user_rsa_key(uid, pem)
    logger.info("[ROTATE] %s rotated RSA key", _who(session))
    await send_resp(msg, writer, "rotate_keys", "ok", {"message": "OK"})
    # Swap the registry entry under the lock; close the stale writers after
    # releasing it.
    async with _clients_lock:
        stale = [w for w in connected_clients.get(uid, []) if w is not writer]
        connected_clients[uid] = [writer]
    for w in stale:
        try:
            w.close()
        except Exception:
            # Best effort: the peer may already be gone.
            pass
async def handle_change_username(msg: dict, session: dict, writer: ProtocolWriter):
    """Rename the caller and broadcast the new name to their contacts.

    The requested name is whitespace-stripped and must be 1-100 characters.
    The session's cached username is updated as well. Every live connection of
    each contact returned by ``get_user_contacts`` then receives a
    ``username_changed`` push; delivery failures are ignored.
    """
    uid = session["user_id"]
    if await _is_rate_limited(f"change_username|{uid}", 5):
        await send_resp(msg, writer, "change_username", "error", {"message": "Too many requests. Try later."})
        return
    requested = msg.get("username", "").strip()
    if not 1 <= len(requested) <= 100:
        await send_resp(msg, writer, "change_username", "error", {"message": "Invalid username (1-100 chars)"})
        return
    await adb.update_username(uid, requested)
    session["username"] = requested
    logger.info("[ACCOUNT] %s changed username", _who(session))
    await send_resp(msg, writer, "change_username", "ok", {"username": requested})
    # Snapshot the contacts' writers under the lock, push outside it.
    contact_ids = await adb.get_user_contacts(uid)
    async with _clients_lock:
        live = [conn for cid in contact_ids for conn in connected_clients.get(cid, [])]
    for conn in live:
        try:
            await conn.send_response("username_changed", "ok", {
                "user_id": uid, "username": requested,
            })
        except Exception:
            # Best effort: a dead contact connection must not fail the rename.
            pass
async def handle_pairing_start(msg: dict, writer: ProtocolWriter):
    """Unauthenticated first step of device pairing, issued by the new device.

    Registers a pairing session under a freshly generated code. The session
    holds the claimed ``email``, the device's ``temp_public_key``, a creation
    time on the event-loop clock, an empty payload, and a random
    ``poll_token``. The reply returns the code and the poll token.

    A session is created whether or not the email belongs to a real account,
    so the response reveals nothing about account existence. Requests are rate
    limited per peer address, and the number of live sessions is capped by
    ``PAIRING_MAX_SESSIONS``.
    """
    await _cleanup_pairings()
    email = msg.get("email", "").strip()
    temp_public_key = msg.get("temp_public_key", "").strip()
    peer = _get_peer_addr(writer)
    # H4 fix: rate limit per IP only (not per email) to prevent enumeration via email rotation
    if await _is_rate_limited(_rate_limit_key("pairing_start", peer), 10):
        await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."})
        return
    if not (email and temp_public_key):
        await send_resp(msg, writer, "pairing_start", "error", {"message": "Missing email or temp_public_key"})
        return
    poll_token = secrets.token_hex(16)
    created = False
    async with _pairing_lock:
        # H4 fix: global cap prevents memory exhaustion from dummy sessions
        if len(pairing_sessions) < PAIRING_MAX_SESSIONS:
            code = _generate_pairing_code()
            # Created unconditionally (anti-enumeration). For an unknown email
            # no real account can log in to claim it, and TTL cleanup expires it.
            pairing_sessions[code] = {
                "email": email,
                "temp_public_key": temp_public_key,
                "created_at": asyncio.get_running_loop().time(),
                "payload": None,
                "poll_token": poll_token,
            }
            created = True
    if not created:
        await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."})
        return
    await send_resp(msg, writer, "pairing_start", "ok", {"code": code, "poll_token": poll_token})
async def handle_pairing_claim(msg: dict, session: dict, writer: ProtocolWriter):
    """An authenticated device claims a pairing code and gets the new device's temp key.

    The code must exist and must have been started with the same email as the
    caller's session. On success the pairing session's ``created_at`` is
    refreshed, so TTL cleanup does not expire it while the claiming device
    prepares the payload for ``pairing_send``. Every failure returns the same
    error message (anti-enumeration).
    """
    await _cleanup_pairings()
    code = msg.get("code", "").strip()
    if not code:
        await send_resp(msg, writer, "pairing_claim", "error", {"message": "Missing code"})
        return
    matched = False
    temp_pub = None
    async with _pairing_lock:
        p = pairing_sessions.get(code)
        if p and p["email"] == session.get("email"):
            matched = True
            temp_pub = p["temp_public_key"]
            # Extend TTL only for the legitimate claimant (re-encryption may run
            # between claim and send). A caller whose email does not match must
            # not be able to keep someone else's code alive.
            p["created_at"] = asyncio.get_event_loop().time()
    # H4 fix: unified error message (anti-enumeration)
    if not matched:
        await send_resp(msg, writer, "pairing_claim", "error", {"message": "Invalid or expired code"})
        return
    await send_resp(msg, writer, "pairing_claim", "ok", {"temp_public_key": temp_pub})
async def handle_pairing_send(msg: dict, session: dict, writer: ProtocolWriter):
    """An authenticated device deposits the encrypted pairing payload for a code.

    The payload is stored on the pairing session only if the session exists
    and was started with the caller's email. The new device later retrieves it
    via ``pairing_poll``. Every mismatch returns the same error message
    (anti-enumeration).
    """
    await _cleanup_pairings()
    code = msg.get("code", "").strip()
    payload = msg.get("payload")
    if not (code and payload):
        await send_resp(msg, writer, "pairing_send", "error", {"message": "Missing code or payload"})
        return
    accepted = False
    async with _pairing_lock:
        entry = pairing_sessions.get(code)
        # H4 fix: unified error message (anti-enumeration)
        if entry and entry["email"] == session.get("email"):
            entry["payload"] = payload
            accepted = True
    if accepted:
        await send_resp(msg, writer, "pairing_send", "ok", {"message": "OK"})
    else:
        await send_resp(msg, writer, "pairing_send", "error", {"message": "Invalid or expired code"})
async def handle_pairing_poll(msg: dict, writer: ProtocolWriter):
    """Unauthenticated poll by the new device for its pairing payload.

    Requires the pairing ``code`` and the ``poll_token`` issued by
    ``pairing_start``. Replies ``ready: False`` until a payload has been
    deposited. Once a payload exists it is returned exactly once and the
    session is removed. A session that exceeds ``PAIRING_MAX_POLL_ATTEMPTS``
    polls without a payload is invalidated. Rate limited per peer address.
    """
    await _cleanup_pairings()
    raw_code = msg.get("code", "")
    raw_token = msg.get("poll_token", "")
    # Non-string values are treated as missing rather than crashing on .strip().
    code = raw_code.strip() if isinstance(raw_code, str) else ""
    poll_token = raw_token.strip() if isinstance(raw_token, str) else ""
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("pairing_poll", addr), 120):
        await send_resp(msg, writer, "pairing_poll", "error", {"message": "Too many attempts. Try later."})
        return
    if not code:
        await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing code"})
        return
    if not poll_token:
        await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing poll_token"})
        return
    error_msg = None
    ready = False
    payload = None
    async with _pairing_lock:
        p = pairing_sessions.get(code)
        if not p:
            error_msg = "Invalid or expired code"
        # compare_digest only accepts ASCII str, so a non-ASCII client token
        # would raise TypeError. Comparing UTF-8 bytes stays constant-time and
        # accepts any input; surrogatepass covers lone surrogates that JSON
        # decoding can produce.
        elif not secrets.compare_digest(
            p.get("poll_token", "").encode("utf-8", "surrogatepass"),
            poll_token.encode("utf-8", "surrogatepass"),
        ):
            error_msg = "Invalid poll_token"
        else:
            poll_attempts = p.get("poll_attempts", 0) + 1
            p["poll_attempts"] = poll_attempts
            if poll_attempts > PAIRING_MAX_POLL_ATTEMPTS and not p.get("payload"):
                pairing_sessions.pop(code, None)
                error_msg = "Code invalidated due to too many attempts"
            elif p.get("payload"):
                ready = True
                payload = p["payload"]
                # One-shot delivery: drop the session once the payload is handed over.
                pairing_sessions.pop(code, None)
    if error_msg:
        await send_resp(msg, writer, "pairing_poll", "error", {"message": error_msg})
    elif ready:
        await send_resp(msg, writer, "pairing_poll", "ok", {"ready": True, "payload": payload})
    else:
        await send_resp(msg, writer, "pairing_poll", "ok", {"ready": False})
async def handle_create_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Create a direct message or a group conversation.

    ``members`` is a list of email addresses. Unknown but well-formed addresses
    get a phantom account via ``_create_phantom_guarded``. The caller is always
    a member.

    Exactly one distinct other member and no ``name`` produces a DM: both users
    are added immediately and the other member is notified. Any other request
    produces a group: only the creator becomes a member, the others receive
    invitations, and only non-phantom invitees get a push.

    Duplicate addresses, and the caller's own address, are collapsed by user
    id. The same user is therefore never invited twice, and ``[a, a]`` is a DM
    with ``a`` rather than a two-person group.
    """
    member_emails = msg.get("members", [])
    name = msg.get("name")
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(f"create_conversation|{session['user_id']}", 10):
        await send_resp(msg, writer, "create_conversation", "error", {"message": "Too many attempts. Try later."})
        return
    # Reject malformed payloads up front. Iterating a string would treat each
    # character as an email, and non-string values would reach the DB layer.
    if (not isinstance(member_emails, list)
            or not all(isinstance(e, str) for e in member_emails)
            or (name is not None and not isinstance(name, str))):
        await send_resp(msg, writer, "create_conversation", "error", {"message": "Invalid request"})
        return
    # Resolve all member user IDs (deduplicated, caller excluded)
    other_users = []
    seen_ids = {session["user_id"]}
    for raw_email in member_emails:
        email = raw_email.strip()  # consistent with find_conversation / add_member
        u = await adb.get_user_by_email(email)
        if not u:
            if not _valid_email(email):
                await send_resp(msg, writer, "create_conversation", "error", {"message": f"Invalid email format: {email}"})
                return
            # H5: atomic phantom creation (cap check + DB create + set add)
            u, err_msg = await _create_phantom_guarded(email, addr, session["user_id"])
            if u is None:
                await send_resp(msg, writer, "create_conversation", "error", {"message": err_msg})
                return
        if u["id"] not in seen_ids:
            seen_ids.add(u["id"])
            other_users.append(u)
    is_dm = len(other_users) == 1 and not name
    joined_at = datetime.now(timezone.utc)
    if is_dm:
        # DMs: add both members directly (no invitation)
        all_ids = [session["user_id"]] + [u["id"] for u in other_users]
        conv_id = await adb.create_conversation(all_ids, joined_at=joined_at, name=name, created_by=session["user_id"])
        logger.info("[CONV] %s created DM conv=%s", _who(session), conv_id[:8])
        await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id})
        # Notify the other member
        members_info = await adb.get_conversation_members(conv_id)
        member_list = [{"user_id": m["id"], "username": m["username"], "email": m["email"]} for m in members_info]
        notif_data = {
            "conversation_id": conv_id,
            "name": name,
            "created_by": session["user_id"],
            "members": member_list,
        }
        await _notify_users([u["id"] for u in other_users], "conversation_created", notif_data)
    else:
        # Groups: only add creator, create invitations for others
        conv_id = await adb.create_conversation([session["user_id"]], joined_at=joined_at, name=name, created_by=session["user_id"])
        logger.info("[CONV] %s created group conv=%s",
                    _who(session), conv_id[:8])
        # Create invitations for other members
        creator_user = await adb.get_user_by_id(session["user_id"])
        creator_name = creator_user["username"] if creator_user else "Unknown"
        invited_ids = []
        async with _clients_lock:
            phantom_snapshot = set(phantom_user_ids)
        for u in other_users:
            await adb.create_invitation(conv_id, u["id"], session["user_id"])
            if u["id"] not in phantom_snapshot:
                invited_ids.append(u["id"])  # only notify non-phantoms
        inv_notif = {
            "conversation_id": conv_id,
            "conversation_name": name,
            "invited_by": session["user_id"],
            "invited_by_username": creator_name,
        }
        await _notify_users(invited_ids, "group_invitation", inv_notif)
        await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id})
async def handle_find_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Resolve an email to a user and look up the caller's DM with that user.

    Unknown but well-formed addresses get a phantom account (the same path as
    conversation creation), so a successful reply always includes ``user_id``.
    ``conversation_id`` is whatever ``find_direct_conversation`` returns
    (presumably None when no DM exists yet; confirm in the db layer).
    """
    action = "find_conversation"
    target_email = msg.get("email", "").strip()
    if not target_email:
        await send_resp(msg, writer, action, "error", {"message": "Invalid request"})
        return
    peer_addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key(action, peer_addr, target_email), 30):
        await send_resp(msg, writer, action, "error", {"message": "Too many attempts. Try later."})
        return
    target = await adb.get_user_by_email(target_email)
    if not target:
        if not _valid_email(target_email):
            await send_resp(msg, writer, action, "error", {"message": "Invalid email format"})
            return
        # H5: atomic phantom creation (cap check + DB create + set add)
        target, failure = await _create_phantom_guarded(target_email, peer_addr, session["user_id"])
        if target is None:
            await send_resp(msg, writer, action, "error", {"message": failure})
            return
    dm_id = await adb.find_direct_conversation(session["user_id"], target["id"])
    await send_resp(msg, writer, action, "ok", {
        "conversation_id": dm_id,
        "user_id": target["id"],
    })
async def handle_add_member(msg: dict, session: dict, writer: ProtocolWriter):
    """Invite a user, by email, into an existing conversation.

    Any current member may invite. Unknown but well-formed addresses become
    phantom accounts. The invitee is not added directly: an invitation row is
    created, and a ``group_invitation`` push is sent to non-phantom users only.
    Targets that are already members or already invited are rejected.
    """
    inviter_id = session["user_id"]

    async def _reject(text: str) -> None:
        await send_resp(msg, writer, "add_member", "error", {"message": text})

    conv_id = msg.get("conversation_id", "")
    email = msg.get("email", "").strip()
    if not (conv_id and email):
        return await _reject("Invalid request")
    if not _valid_uuid(conv_id):
        return await _reject("Invalid conversation_id")
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("add_member", addr, email), 10):
        return await _reject("Too many attempts. Try later.")
    if not await adb.is_conversation_member(conv_id, inviter_id):
        return await _reject("Not a member")
    invitee = await adb.get_user_by_email(email)
    if not invitee:
        # L8: validate email format before phantom creation
        if not _valid_email(email):
            return await _reject("Invalid email format")
        # H5: atomic phantom creation (cap check + DB create + set add)
        invitee, err_msg = await _create_phantom_guarded(email, addr, inviter_id)
        if invitee is None:
            return await _reject(err_msg)
    if await adb.is_conversation_member(conv_id, invitee["id"]):
        return await _reject("Already a member")
    if await adb.has_pending_invitation(conv_id, invitee["id"]):
        return await _reject("Invitation already pending")
    # Invitation is created for both real and phantom users
    await adb.create_invitation(conv_id, invitee["id"], inviter_id)
    logger.info("[INVITE] %s invited u=%s to conv=%s", _who(session), invitee["id"][:8], conv_id[:8])
    await send_resp(msg, writer, "add_member", "ok", {"user_id": invitee["id"]})
    # Push only to real users; phantoms have no client to receive it
    async with _clients_lock:
        invitee_is_phantom = invitee["id"] in phantom_user_ids
    if invitee_is_phantom:
        return
    conv = await adb.get_conversation(conv_id)
    inviter = await adb.get_user_by_id(inviter_id)
    await _notify_users([invitee["id"]], "group_invitation", {
        "conversation_id": conv_id,
        "conversation_name": conv.get("name") if conv else None,
        "invited_by": inviter_id,
        "invited_by_username": inviter["username"] if inviter else "Unknown",
    })
async def handle_accept_invitation(msg: dict, session: dict, writer: ProtocolWriter):
    """Accept a pending group invitation and announce the new member.

    Adds the caller to the conversation (joined now, in UTC) and deletes the
    invitation. It then replies with the conversation id and pushes
    ``member_added`` to every other member, carrying the caller's id,
    username and email.
    """
    me = session["user_id"]
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        problem = "Missing conversation_id"
    elif not _valid_uuid(conv_id):
        problem = "Invalid conversation_id"
    elif not await adb.has_pending_invitation(conv_id, me):
        problem = "No pending invitation"
    else:
        problem = None
    if problem is not None:
        await send_resp(msg, writer, "accept_invitation", "error", {"message": problem})
        return
    await adb.add_conversation_member(conv_id, me, joined_at=datetime.now(timezone.utc))
    await adb.delete_invitation(conv_id, me)
    logger.info("[INVITE] %s accepted invitation to conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "accept_invitation", "ok", {"conversation_id": conv_id})
    # Tell everyone else in the conversation who just joined
    profile = await adb.get_user_by_id(me)
    notif_data = {
        "conversation_id": conv_id,
        "user_id": me,
        "username": profile["username"] if profile else "",
        "email": profile["email"] if profile else "",
    }
    members = await adb.get_conversation_members(conv_id)
    await _notify_users([m["id"] for m in members if m["id"] != me], "member_added", notif_data)
async def handle_decline_invitation(msg: dict, session: dict, writer: ProtocolWriter):
    """Decline the caller's pending invitation to a conversation by deleting it."""
    me = session["user_id"]
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        problem = "Missing conversation_id"
    elif not _valid_uuid(conv_id):
        problem = "Invalid conversation_id"
    elif not await adb.has_pending_invitation(conv_id, me):
        problem = "No pending invitation"
    else:
        problem = None
    if problem is not None:
        await send_resp(msg, writer, "decline_invitation", "error", {"message": problem})
        return
    await adb.delete_invitation(conv_id, me)
    logger.info("[INVITE] %s declined invitation to conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "decline_invitation", "ok", {"message": "OK"})
async def handle_list_invitations(msg: dict, session: dict, writer: ProtocolWriter):
    """Return the caller's pending group invitations.

    Each item carries the conversation id and name, the inviter's id and
    username ("" when absent), and ``created_at``. The timestamp is rendered
    with ``isoformat()`` when available, otherwise with ``str()``.
    """
    def _stamp(value) -> str:
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    pending = await adb.get_pending_invitations(session["user_id"])
    payload = [
        {
            "conversation_id": inv["conversation_id"],
            "conversation_name": inv.get("conversation_name"),
            "invited_by": inv["invited_by"],
            "invited_by_username": inv.get("invited_by_username", ""),
            "created_at": _stamp(inv["created_at"]),
        }
        for inv in pending
    ]
    await send_resp(msg, writer, "list_invitations", "ok", {"invitations": payload})
async def handle_list_conversations(msg: dict, session: dict, writer: ProtocolWriter):
    """Reply with every conversation the caller belongs to, plus unread counts.

    Unread counts come from ``get_unread_counts`` restricted to
    ``METADATA_RETENTION_DAYS``. Conversations missing from that map report 0.
    """
    owner = session["user_id"]
    conversations = await adb.list_user_conversations(owner)
    unread_map = await adb.get_unread_counts(owner, max_age_days=METADATA_RETENTION_DAYS)

    def _summary(conv: dict) -> dict:
        stamp = conv["created_at"]
        return {
            "conversation_id": conv["id"],
            "created_at": stamp.isoformat() if hasattr(stamp, "isoformat") else str(stamp),
            "members": conv["members"],
            "name": conv.get("name"),
            "created_by": conv.get("created_by"),
            "avatar_file": conv.get("avatar_file"),
            "unread_count": unread_map.get(conv["id"], 0),
        }

    summaries = [_summary(c) for c in conversations]
    logger.info("[LIST] %s listed %d conversations", _who(session), len(summaries))
    await send_resp(msg, writer, "list_conversations", "ok", {"conversations": summaries})
async def handle_send_message(msg: dict, session: dict, writer: ProtocolWriter):
    """Store an end-to-end encrypted message and push it to online recipients.

    Request fields:
      * ``conversation_id``: the caller must be a member.
      * ``ratchet_header``: message-level header, validated by ``_validate_header``.
      * ``recipients``: a list of dicts, each with ``user_id`` plus encoded
        ``encrypted_content`` and ``nonce`` (decoded by ``decode_binary``).
        Optional per-entry fields are ``device_id``, ``ratchet_header`` and
        ``x3dh_header``.

    Recipient entries are dropped silently when they are not dicts, belong to
    non-members or phantom users, or lack ciphertext or nonce.

    ``image_file_id`` is attached only when it refers to a completed upload
    made by the sender. Otherwise the message is stored without it.
    """
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "send_message", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "send_message", "error", {"message": "Invalid conversation_id"})
        return
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("send_message", addr, session.get("email")), 20):
        await send_resp(msg, writer, "send_message", "error", {"message": "Too many attempts. Try later."})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "send_message", "error", {"message": "Not a member"})
        return
    # New protocol: ratchet_header + recipients[] with per-user ciphertext
    ratchet_header_raw = msg.get("ratchet_header")
    recipients_raw = msg.get("recipients")
    if not ratchet_header_raw or not recipients_raw:
        await send_resp(msg, writer, "send_message", "error", {"message": "Missing ratchet_header or recipients"})
        return
    # A dict or string here would be iterated key-by-key / char-by-char and
    # crash on .get() below.
    if not isinstance(recipients_raw, list):
        await send_resp(msg, writer, "send_message", "error", {"message": "Invalid recipients format"})
        return
    # C2 fix: validate header is a dict (reject raw str/bytes)
    ratchet_header = _validate_header(ratchet_header_raw, "ratchet_header")
    if ratchet_header is None:
        await send_resp(msg, writer, "send_message", "error", {"message": "Invalid ratchet_header format"})
        return
    x3dh_header_raw = msg.get("x3dh_header")
    x3dh_header = None
    if x3dh_header_raw:
        x3dh_header = _validate_header(x3dh_header_raw, "x3dh_header")
        if x3dh_header is None:
            await send_resp(msg, writer, "send_message", "error", {"message": "Invalid x3dh_header format"})
            return
    sender_chain_id_b64 = msg.get("sender_chain_id")
    sender_chain_id = decode_binary(sender_chain_id_b64) if sender_chain_id_b64 else None
    sender_chain_n = msg.get("sender_chain_n")
    # Validate recipients are actual members
    conv_members = await adb.get_conversation_members(conv_id)
    member_ids = {m["id"] for m in conv_members}
    async with _clients_lock:
        phantom_snapshot = set(phantom_user_ids)
    db_recipients = []
    for r in recipients_raw:
        if not isinstance(r, dict):
            continue
        uid = r.get("user_id", "")
        if uid not in member_ids:
            continue
        if uid in phantom_snapshot:
            continue
        ct_b64 = r.get("encrypted_content", "")
        nonce_b64 = r.get("nonce", "")
        if not ct_b64 or not nonce_b64:
            continue
        entry = {
            "user_id": uid,
            "encrypted_content": decode_binary(ct_b64),
            "nonce": decode_binary(nonce_b64),
        }
        # Per-recipient device_id (multi-device support)
        r_device_id = r.get("device_id")
        if r_device_id:
            entry["device_id"] = r_device_id
        # Per-recipient ratchet header and x3dh header (C2 fix: validate dict)
        r_rh = r.get("ratchet_header")
        if r_rh:
            r_rh_bytes = _validate_header(r_rh, "recipient_ratchet_header")
            if r_rh_bytes:
                entry["ratchet_header"] = r_rh_bytes
        r_x3dh = r.get("x3dh_header")
        if r_x3dh:
            r_x3dh_bytes = _validate_header(r_x3dh, "recipient_x3dh_header")
            if r_x3dh_bytes:
                entry["x3dh_header"] = r_x3dh_bytes
        db_recipients.append(entry)
    if not db_recipients:
        await send_resp(msg, writer, "send_message", "error", {"message": "No valid recipients"})
        return
    # Validate the image reference BEFORE storing. store_message persists the
    # id it is given, so checking only afterwards could not stop a message
    # from referencing someone else's upload or an incomplete one.
    image_file_id = msg.get("image_file_id") or None
    if image_file_id is not None:
        upload = await adb.get_image_upload(image_file_id) if isinstance(image_file_id, str) else None
        if not (upload and upload["completed"] and upload["uploader_id"] == session["user_id"]):
            logger.warning("[MSG] %s sent unusable image_file_id in conv=%s; ignoring it",
                           _who(session), conv_id[:8])
            image_file_id = None
    # Metadata privacy: for group messages (sender_chain_id present), store chain
    # metadata in per-recipient ratchet_header instead of the messages table.
    # This avoids persisting sender correlation data at the message level.
    # Skip sender's own self-copy entry — it uses a different decrypt path
    # (self-encryption key) and must keep its own ratchet_header ({"self":true}).
    db_sender_chain_id = None
    db_sender_chain_n = None
    if sender_chain_id:
        chain_meta = json.dumps({
            "chain_id": encode_binary(sender_chain_id),
            "chain_n": sender_chain_n,
        }).encode()
        sender_uid = session["user_id"]
        for r in db_recipients:
            # Skip self-copy (sender's own entry) — uses self-encryption, not sender key
            if r["user_id"] == sender_uid:
                continue
            if not r.get("ratchet_header"):
                r["ratchet_header"] = chain_meta
    msg_id, created_at = await adb.store_message(
        conv_id, session["user_id"], ratchet_header, db_recipients,
        x3dh_header=x3dh_header,
        sender_chain_id=db_sender_chain_id,
        sender_chain_n=db_sender_chain_n,
        image_file_id=image_file_id,
        sender_device_id=session.get("device_id"),
    )
    # Link the (already validated) image upload to the message
    if image_file_id:
        await adb.set_message_image_file_id(msg_id, image_file_id)
    logger.info("[MSG] %s msg=%s conv=%s", _who(session), msg_id[:8], conv_id[:8])
    await send_resp(msg, writer, "send_message", "ok", {"message_id": msg_id, "created_at": created_at})
    # Notify connected recipients — group all per-device entries by user_id
    # Use validated db_recipients (not raw input) to prevent unvalidated headers in push
    msg_ratchet_header_dict = json.loads(ratchet_header.decode())
    msg_x3dh_header_dict = json.loads(x3dh_header.decode()) if x3dh_header else None
    from collections import defaultdict
    user_entries = defaultdict(list)
    for r in db_recipients:
        uid = r["user_id"]
        # Per-recipient headers are stored as bytes; decode back to dict for notification JSON
        r_rh = r.get("ratchet_header")
        r_rh_dict = json.loads(r_rh.decode()) if r_rh else None
        r_x3dh = r.get("x3dh_header")
        r_x3dh_dict = json.loads(r_x3dh.decode()) if r_x3dh else None
        user_entries[uid].append({
            "device_id": r.get("device_id", db.SELF_DEVICE_ID),
            "encrypted_content": encode_binary(r["encrypted_content"]),
            "nonce": encode_binary(r["nonce"]),
            "ratchet_header": r_rh_dict or msg_ratchet_header_dict,
            "x3dh_header": r_x3dh_dict or msg_x3dh_header_dict,
        })
    notifications = []
    for uid, entries in user_entries.items():
        notif_data = {
            "message_id": msg_id,
            "conversation_id": conv_id,
            "sender_id": session["user_id"],
            "sender_device_id": session.get("device_id"),
            "device_entries": entries,
        }
        if sender_chain_id_b64:
            notif_data["sender_chain_id"] = sender_chain_id_b64
        if sender_chain_n is not None:
            notif_data["sender_chain_n"] = sender_chain_n
        # Also include flat fields for backward compat with old clients
        # (first entry's data as fallback)
        if entries:
            first = entries[0]
            notif_data["ratchet_header"] = first.get("ratchet_header") or msg_ratchet_header_dict
            notif_data["encrypted_content"] = first.get("encrypted_content", "")
            notif_data["nonce"] = first.get("nonce", "")
            if first.get("x3dh_header"):
                notif_data["x3dh_header"] = first["x3dh_header"]
        notifications.append((uid, "new_message", notif_data))
    await _notify_users_individual(notifications, exclude_writer=writer)
async def handle_get_messages(msg: dict, session: dict, writer: ProtocolWriter):
    """Return one page of a conversation's messages, shaped for the caller's device.

    Request fields:
      * ``conversation_id``: UUID; the caller must be a member.
      * ``limit``: optional, clamped to 1..200, default 50.
      * ``offset``: optional, >= 0, default 0.
      * ``after_ts``: optional timestamp string, passed through to ``get_messages``.

    Non-numeric ``limit``/``offset`` (including JSON null, Infinity and NaN) or a
    non-string ``after_ts`` produce an error reply instead of an unhandled
    exception.

    Each returned entry carries the ciphertext rows for the caller's device,
    deduplicated per message id with device-specific rows preferred over
    ``SELF_DEVICE_ID``. It also carries the parsed ratchet/x3dh headers
    (per-recipient headers preferred), read and delivery receipts, reactions,
    pin and delete metadata, and, for group messages, sender-chain metadata.
    """
    if await _is_rate_limited(f"get_messages|{session['user_id']}", 30):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Too many requests. Try later."})
        return
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "get_messages", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid conversation_id"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Not a member"})
        return
    try:
        limit = min(max(int(msg.get("limit", 50)), 1), 200)
        offset = max(int(msg.get("offset", 0)), 0)
    except (TypeError, ValueError, OverflowError):
        # e.g. "abc", null, a list, or JSON Infinity/NaN (int(inf) overflows)
        await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid limit or offset"})
        return
    device_id = session.get("device_id")
    after_ts = msg.get("after_ts")  # ISO timestamp string or None
    if after_ts is not None and not isinstance(after_ts, str):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid after_ts"})
        return
    messages = await adb.get_messages(conv_id, session["user_id"], limit, offset,
                                      device_id=device_id, after_ts=after_ts)
    # Deduplicate: when both device-specific and SELF_DEVICE_ID rows exist for the
    # same message, prefer device-specific (non-sentinel). Keep first seen per message_id.
    seen_ids = {}
    deduped = []
    for m in messages:
        mid = m["id"]
        mr_dev = m.get("mr_device_id", "")
        if mid not in seen_ids:
            seen_ids[mid] = len(deduped)
            deduped.append(m)
        elif mr_dev != db.SELF_DEVICE_ID:
            # Replace SELF_DEVICE_ID entry with device-specific one
            deduped[seen_ids[mid]] = m
    messages = deduped
    result = []
    message_ids = [m["id"] for m in messages]
    read_status = await adb.get_message_read_status(message_ids) if message_ids else {}
    delivery_status = await adb.get_message_delivery_status(message_ids) if message_ids else {}
    reactions_map = await adb.get_reactions(message_ids) if message_ids else {}
    for m in messages:
        read_by = read_status.get(m["id"], [])
        # Prefer per-recipient headers (mr_*) over message-level headers
        rh_raw = m.get("mr_ratchet_header") or m.get("ratchet_header")
        x3dh_raw = m.get("mr_x3dh_header") or m.get("x3dh_header")
        # C2 fix: defensive JSON parsing — corrupted headers don't break fetch
        try:
            rh_parsed = json.loads(rh_raw) if rh_raw else {}
        except (json.JSONDecodeError, TypeError, UnicodeDecodeError):
            logger.warning("[FETCH] Corrupted ratchet_header in message %s, skipping", m["id"])
            rh_parsed = {}
        try:
            x3dh_parsed = json.loads(x3dh_raw) if x3dh_raw else None
        except (json.JSONDecodeError, TypeError, UnicodeDecodeError):
            logger.warning("[FETCH] Corrupted x3dh_header in message %s, skipping", m["id"])
            x3dh_parsed = None
        entry = {
            "message_id": m["id"],
            "sender_id": m.get("sender_id") or "",
            "ratchet_header": rh_parsed,
            "encrypted_content": encode_binary(m["encrypted_content"]) if m.get("encrypted_content") else "",
            "nonce": encode_binary(m["nonce"]) if m.get("nonce") else "",
            "created_at": m["created_at"].isoformat() if hasattr(m["created_at"], "isoformat") else str(m["created_at"]),
            "read_by": read_by,
            "delivered_to": delivery_status.get(m["id"], []),
        }
        if x3dh_parsed:
            entry["x3dh_header"] = x3dh_parsed
        # Sender chain metadata: check message-level first (backward compat),
        # then per-recipient ratchet_header (new metadata-private format).
        # Only extract from per-recipient header if message-level ratchet_header
        # is the group dummy (dh_pub all-zeros) — prevents DM header injection.
        if m.get("sender_chain_id"):
            entry["sender_chain_id"] = encode_binary(m["sender_chain_id"])
        elif isinstance(rh_parsed, dict) and rh_parsed.get("chain_id"):
            # Verify this is a group message by checking the message-level header
            msg_rh_raw = m.get("ratchet_header")
            is_group = False
            if msg_rh_raw:
                try:
                    msg_rh = json.loads(msg_rh_raw) if isinstance(msg_rh_raw, (bytes, str)) else msg_rh_raw
                    is_group = isinstance(msg_rh, dict) and msg_rh.get("dh_pub") == "00" * 32
                except (json.JSONDecodeError, TypeError, UnicodeDecodeError):
                    pass
            if is_group:
                entry["sender_chain_id"] = rh_parsed["chain_id"]
        if m.get("sender_chain_n") is not None:
            entry["sender_chain_n"] = m["sender_chain_n"]
        elif isinstance(rh_parsed, dict) and rh_parsed.get("chain_n") is not None:
            # Same group-only guard
            if "sender_chain_id" in entry:
                entry["sender_chain_n"] = rh_parsed["chain_n"]
        if m.get("sender_device_id"):
            entry["sender_device_id"] = m["sender_device_id"]
        if m.get("deleted_at"):
            entry["deleted_at"] = m["deleted_at"].isoformat() if hasattr(m["deleted_at"], "isoformat") else str(m["deleted_at"])
        # Pin metadata
        if m.get("pinned_at"):
            entry["pinned_at"] = m["pinned_at"].isoformat() if hasattr(m["pinned_at"], "isoformat") else str(m["pinned_at"])
            entry["pinned_by"] = m.get("pinned_by") or ""
        # Reactions
        msg_reactions = reactions_map.get(m["id"])
        if msg_reactions:
            entry["reactions"] = msg_reactions
        result.append(entry)
    total_count = await adb.count_messages(conv_id, session["user_id"])
    logger.info("[FETCH] %s fetched %d/%d msgs from conv=%s (limit=%d, offset=%d%s)",
                _who(session), len(result), total_count, conv_id[:8], limit, offset,
                f", after={after_ts}" if after_ts else "")
    await send_resp(msg, writer, "get_messages", "ok",
                    {"messages": result, "total_count": total_count})
async def handle_remove_member(msg: dict, session: dict, writer: ProtocolWriter):
    """Remove another member from a conversation (creator only).

    Expects ``conversation_id`` and ``user_id`` (the member to remove) in *msg*.
    The caller must be a member and the conversation's ``created_by``, and may
    not remove themselves. On success the caller gets ``ok`` and every member
    present before the removal (including the removed user), except the
    caller, receives a ``member_removed`` push.
    """
    if await _is_rate_limited(f"remove_member|{session['user_id']}", 10):
        await send_resp(msg, writer, "remove_member", "error", {"message": "Too many requests. Try later."})
        return
    conv_id = msg.get("conversation_id", "")
    user_id = msg.get("user_id", "")
    if not conv_id or not user_id:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Missing conversation_id or user_id"})
        return
    if not _valid_uuid(conv_id) or not _valid_uuid(user_id):
        await send_resp(msg, writer, "remove_member", "error", {"message": "Invalid conversation_id or user_id"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "remove_member", "error", {"message": "Not a member"})
        return
    # Direct lookup (as in leave/rename/delete) instead of scanning every
    # conversation the caller belongs to just to find this one.
    conv_data = await adb.get_conversation(conv_id)
    if not conv_data or conv_data.get("created_by") != session["user_id"]:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Only the group creator can remove members"})
        return
    if user_id == session["user_id"]:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Cannot remove yourself"})
        return
    # Snapshot membership before removing so the removed user is notified too.
    members_before = await adb.get_conversation_members(conv_id)
    # M6: atomic removal — return value confirms row existed
    removed = await adb.remove_conversation_member_atomic(conv_id, user_id)
    if not removed:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Member already removed"})
        return
    logger.info("[MEMBER] %s removed user=%s from conv=%s", _who(session), user_id[:8], conv_id[:8])
    await send_resp(msg, writer, "remove_member", "ok", {"message": "OK"})
    # Notify removed member and remaining members
    notif_data = {
        "conversation_id": conv_id,
        "user_id": user_id,
    }
    member_ids = [m["id"] for m in members_before if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "member_removed", notif_data)
async def handle_leave_group(msg: dict, session: dict, writer: ProtocolWriter):
    """Leave a group conversation voluntarily.

    DMs (two or fewer members and no name) cannot be left. If the caller is
    the conversation's creator, ownership passes to the first other member
    returned by ``get_conversation_members``. The remaining members receive a
    ``member_removed`` push naming the caller.
    """
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "leave_group", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "leave_group", "error", {"message": "Invalid conversation_id"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"})
        return
    # Don't allow leaving DMs (2 members without a name)
    conv = await adb.get_conversation(conv_id)
    members = await adb.get_conversation_members(conv_id)
    if len(members) <= 2 and not (conv and conv.get("name")):
        await send_resp(msg, writer, "leave_group", "error", {"message": "Cannot leave a DM conversation"})
        return
    # M6: atomic removal. A falsy result means the membership row was already
    # gone (e.g. a concurrent leave/remove): report it instead of claiming success.
    removed = await adb.remove_conversation_member_atomic(conv_id, session["user_id"])
    if not removed:
        await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"})
        return
    remaining = [m for m in members if m["id"] != session["user_id"]]
    # If the creator left, hand the group to the first remaining member. This
    # runs only after the removal succeeded, so a failed leave can't transfer
    # ownership while the caller stays in the group.
    if conv and conv.get("created_by") == session["user_id"] and remaining:
        await adb.update_conversation_creator(conv_id, remaining[0]["id"])
    logger.info("[LEAVE] %s left group conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "leave_group", "ok", {"message": "OK"})
    # Notify remaining members
    notif_data = {
        "conversation_id": conv_id,
        "user_id": session["user_id"],
    }
    member_ids = [m["id"] for m in remaining]
    await _notify_users(member_ids, "member_removed", notif_data)
async def handle_rename_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Rename a group conversation (creator only).

    Expects ``conversation_id`` and a non-blank ``name`` (at most 100 chars
    after stripping). DMs (conversations without a name) cannot be renamed.
    Other members receive a ``conversation_renamed`` push.
    """
    if await _is_rate_limited(f"rename_conv|{session['user_id']}", 5):
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Too many requests. Try later."})
        return
    conv_id = msg.get("conversation_id", "")
    raw_name = msg.get("name", "")
    # A JSON null or non-string "name" would crash on .strip(); treat it as missing.
    new_name = raw_name.strip() if isinstance(raw_name, str) else ""
    if not conv_id or not new_name:
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Missing conversation_id or name"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Invalid conversation_id"})
        return
    if len(new_name) > 100:
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Name too long (max 100)"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Not a member"})
        return
    conv = await adb.get_conversation(conv_id)
    if not conv or not conv.get("name"):
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Cannot rename a DM conversation"})
        return
    if conv.get("created_by") != session["user_id"]:
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Only the group creator can rename"})
        return
    await adb.update_conversation_name(conv_id, new_name)
    logger.info("[RENAME] %s renamed conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "rename_conversation", "ok", {"message": "OK"})
    # Notify all members
    members = await adb.get_conversation_members(conv_id)
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "conversation_renamed", {
        "conversation_id": conv_id,
        "name": new_name,
        "renamed_by": session["user_id"],
    })
async def handle_delete_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Delete a conversation for the current user.

    For a DM (two or fewer members and no name) only the caller's membership
    is removed, and the other user keeps the conversation. For a group, only the
    creator may delete, and doing so removes every member. Once no members
    remain, the conversation's uploaded files (``.enc`` and ``.tmp``) are
    deleted from disk and the conversation row is dropped.
    """
    if await _is_rate_limited(f"delete_conv|{session['user_id']}", 5):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Too many requests. Try later."})
        return
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Invalid conversation_id"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Not a member"})
        return
    conv = await adb.get_conversation(conv_id)
    members = await adb.get_conversation_members(conv_id)
    is_group = len(members) > 2 or (conv and conv.get("name"))
    # Groups can only be deleted by the creator (admin)
    if is_group and (not conv or conv.get("created_by") != session["user_id"]):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Only the group creator can delete this conversation"})
        return
    if is_group:
        # Group: creator deletes for everyone — remove all members, clean up, delete
        for member in members:
            await adb.remove_conversation_member(conv_id, member["id"])
    else:
        # DM: only remove self; other user keeps the conversation
        await adb.remove_conversation_member(conv_id, session["user_id"])
    remaining_count = await adb.count_conversation_members(conv_id)
    if remaining_count == 0:
        file_ids = await adb.get_conversation_file_ids(conv_id)

        def _wipe_files() -> None:
            # Remove both finished (.enc) and partial (.tmp) uploads.
            for fid in file_ids:
                for ext in (".enc", ".tmp"):
                    path = _safe_upload_path(fid, ext)
                    if path:
                        _secure_delete(path)

        # _secure_delete is synchronous file I/O, called once per attachment;
        # run the whole sweep in a worker thread so a conversation with many
        # or large files doesn't stall the event loop for every client.
        await asyncio.to_thread(_wipe_files)
        await adb.delete_conversation(conv_id)
    logger.info("[DELETE] %s deleted conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "delete_conversation", "ok", {"message": "OK"})
    # Notify other members they were removed.
    # NOTE(review): for a group deleted for everyone, the push still names only
    # the caller as the removed user — confirm clients treat this as deletion.
    notif_data = {
        "conversation_id": conv_id,
        "user_id": session["user_id"],
    }
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "member_removed", notif_data)
async def handle_mark_read(msg: dict, session: dict, writer: ProtocolWriter):
    """Mark specific messages as read by the caller.

    Expects ``conversation_id`` and a non-empty list of string
    ``message_ids`` (max 500). IDs that don't belong to the conversation are
    silently dropped. Other members receive a ``messages_read`` push listing
    the IDs that were actually marked.
    """
    conv_id = msg.get("conversation_id", "")
    message_ids = msg.get("message_ids", [])
    if not conv_id or not message_ids:
        await send_resp(msg, writer, "mark_read", "error", {"message": "Missing conversation_id or message_ids"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid conversation_id"})
        return
    # A non-empty string or dict would slip past the checks above and the
    # length cap below, then reach the DB filter; require an actual list.
    if not isinstance(message_ids, list):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid message_ids"})
        return
    if len(message_ids) > 500:
        await send_resp(msg, writer, "mark_read", "error", {"message": "Too many message_ids (max 500)"})
        return
    if not all(isinstance(mid, str) for mid in message_ids):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid message_ids"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Not a member"})
        return
    # M1 fix: filter to only message_ids that belong to this conversation
    valid_ids = await adb.filter_message_ids_by_conversation(conv_id, message_ids)
    if not valid_ids:
        await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"})
        return
    await adb.mark_messages_read(conv_id, session["user_id"], valid_ids)
    logger.info("[READ] %s marked %d msgs read in conv=%s", _who(session), len(valid_ids), conv_id[:8])
    await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"})
    members = await adb.get_conversation_members(conv_id)
    notif_data = {
        "conversation_id": conv_id,
        "user_id": session["user_id"],
        "message_ids": valid_ids,
    }
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "messages_read", notif_data)
async def handle_mark_conversation_read(msg: dict, session: dict, writer: ProtocolWriter):
    """Mark every message in a conversation as read for the calling user.

    Replies with ``{"marked_count": n}``. When ``n`` is positive, every other
    member receives a ``messages_read`` push whose ``message_ids`` list is
    empty (i.e. "the whole conversation").
    """
    kind = "mark_conversation_read"
    conversation = msg.get("conversation_id", "")

    def reject(reason: str):
        return send_resp(msg, writer, kind, "error", {"message": reason})

    if not conversation:
        await reject("Missing conversation_id")
        return
    if not _valid_uuid(conversation):
        await reject("Invalid conversation_id")
        return
    me = session["user_id"]
    if not await adb.is_conversation_member(conversation, me):
        await reject("Not a member")
        return

    marked = await adb.mark_conversation_read(conversation, me)
    logger.info("[READ] %s marked conv=%s all-read (%d msgs)", _who(session), conversation[:8], marked)
    await send_resp(msg, writer, kind, "ok", {"marked_count": marked})
    if marked <= 0:
        return

    # Only bother the other members when something actually changed.
    everyone = await adb.get_conversation_members(conversation)
    others = [row["id"] for row in everyone if row["id"] != me]
    await _notify_users(others, "messages_read", {
        "conversation_id": conversation,
        "user_id": me,
        "message_ids": [],
    })
async def handle_confirm_delivery(msg: dict, session: dict, writer: ProtocolWriter):
    """Record that the caller's device received the given messages.

    Expects ``conversation_id`` and a non-empty list of string
    ``message_ids`` (max 500). IDs that don't belong to the conversation are
    silently dropped. Each original sender other than the caller receives one
    ``message_delivered`` push listing only their own messages.
    """
    conv_id = msg.get("conversation_id", "")
    message_ids = msg.get("message_ids", [])
    if not conv_id or not message_ids:
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Missing conversation_id or message_ids"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Invalid conversation_id"})
        return
    # A non-empty string or dict would slip past the checks above and the
    # length cap below, then reach the DB filter; require an actual list.
    if not isinstance(message_ids, list):
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Invalid message_ids"})
        return
    if len(message_ids) > 500:
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Too many message_ids (max 500)"})
        return
    if not all(isinstance(mid, str) for mid in message_ids):
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Invalid message_ids"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Not a member"})
        return
    # M1 fix: filter to only message_ids that belong to this conversation
    valid_ids = await adb.filter_message_ids_by_conversation(conv_id, message_ids)
    if not valid_ids:
        await send_resp(msg, writer, "confirm_delivery", "ok", {"message": "OK"})
        return
    await adb.mark_messages_delivered(conv_id, session["user_id"], valid_ids)
    logger.info("[DELIVERY] %s confirmed %d msgs delivered in conv=%s", _who(session), len(valid_ids), conv_id[:8])
    await send_resp(msg, writer, "confirm_delivery", "ok", {"message": "OK"})
    # Notify senders. The sender is looked up one message at a time (sequential
    # DB round-trips, at most 500); results are grouped so each sender gets a
    # single push listing their delivered messages.
    sender_msgs: dict[str, list[str]] = {}
    for mid in valid_ids:
        sid = await adb.get_message_sender(mid)
        if sid and sid != session["user_id"]:
            sender_msgs.setdefault(sid, []).append(mid)
    for sender_id, mids in sender_msgs.items():
        await _notify_users([sender_id], "message_delivered", {
            "conversation_id": conv_id,
            "user_id": session["user_id"],
            "message_ids": mids,
        })
async def handle_delete_message(msg: dict, session: dict, writer: ProtocolWriter):
    """Soft-delete a message on behalf of the caller.

    The caller must be a member of the message's conversation, and
    ``soft_delete_message`` must accept the request (it returns ``None`` when
    it refuses). If the deleted message referenced an upload
    (``image_file_id``), the encrypted file is removed from disk and its upload
    row deleted. Other members receive a ``message_deleted`` push.
    """
    me = session["user_id"]
    kind = "delete_message"

    def reject(reason: str):
        return send_resp(msg, writer, kind, "error", {"message": reason})

    if await _is_rate_limited(f"delete_msg|{me}", 20):
        await reject("Too many requests. Try later.")
        return
    target = msg.get("message_id", "")
    if not target:
        await reject("Missing message_id")
        return
    if not _valid_uuid(target):
        await reject("Invalid message_id")
        return

    conv_id = await adb.get_message_conversation(target)
    if not conv_id:
        await reject("Message not found")
        return
    if not await adb.is_conversation_member(conv_id, me):
        await reject("Not a member")
        return

    outcome = await adb.soft_delete_message(target, me)
    if outcome is None:
        await reject("Cannot delete this message")
        return

    # Drop any attached upload along with the message.
    attachment = outcome.get("image_file_id")
    if attachment:
        on_disk = _safe_upload_path(attachment, ".enc")
        if on_disk:
            _secure_delete(on_disk)
        await adb.delete_image_upload(attachment)

    logger.info("[MSG] %s deleted message=%s", _who(session), target[:8])
    await send_resp(msg, writer, kind, "ok", {"message_id": target})
    everyone = await adb.get_conversation_members(conv_id)
    others = [row["id"] for row in everyone if row["id"] != me]
    await _notify_users(others, "message_deleted", {"message_id": target, "conversation_id": conv_id})
async def handle_react_message(msg: dict, session: dict, writer: ProtocolWriter):
    """Add or remove the caller's reaction on a message.

    ``action`` defaults to ``"add"``; ``reaction`` must be one of
    ``db.ALLOWED_REACTIONS``. For ``add``, ``add_reaction`` returns
    ``(changed, old_reaction)``: if nothing changed the caller gets a bare
    ``ok`` and nobody is notified. Otherwise the response carries
    ``old_reaction`` when one was replaced, and other members receive
    ``message_reacted`` pushes (a ``remove`` for the replaced reaction first,
    then the requested action).
    """
    me = session["user_id"]
    kind = "react_message"

    def reject(reason: str):
        return send_resp(msg, writer, kind, "error", {"message": reason})

    if await _is_rate_limited(f"react|{me}", 20):
        await reject("Too many requests. Try later.")
        return
    message_id = msg.get("message_id", "")
    reaction = msg.get("reaction", "")
    action = msg.get("action", "add")  # "add" or "remove"
    if not message_id or not reaction:
        await reject("Missing fields")
        return
    if not _valid_uuid(message_id):
        await reject("Invalid message_id")
        return
    if reaction not in db.ALLOWED_REACTIONS:
        await reject("Invalid reaction")
        return
    if action not in ("add", "remove"):
        await reject("Invalid action")
        return

    conv_id = await adb.get_message_conversation(message_id)
    if not conv_id:
        await reject("Message not found")
        return
    if not await adb.is_conversation_member(conv_id, me):
        await reject("Not a member")
        return

    previous = None
    if action == "remove":
        await adb.remove_reaction(message_id, me)
    else:
        changed, previous = await adb.add_reaction(message_id, me, reaction)
        if not changed:
            # Same reaction already present: acknowledge, broadcast nothing.
            await send_resp(msg, writer, kind, "ok", {"message_id": message_id})
            return

    logger.info("[MSG] %s %s reaction '%s' on message=%s", _who(session), action, reaction, message_id[:8])
    payload = {"message_id": message_id}
    if previous:
        payload["old_reaction"] = previous
    await send_resp(msg, writer, kind, "ok", payload)

    everyone = await adb.get_conversation_members(conv_id)
    others = [row["id"] for row in everyone if row["id"] != me]

    def push(which: str, verb: str):
        return _notify_users(others, "message_reacted", {
            "message_id": message_id,
            "conversation_id": conv_id,
            "user_id": me,
            "username": session.get("username", ""),
            "reaction": which,
            "action": verb,
        })

    # A replaced reaction is announced as removed before the new one is added.
    if previous:
        await push(previous, "remove")
    await push(reaction, action)
async def handle_pin_message(msg: dict, session: dict, writer: ProtocolWriter):
    """Pin or unpin a message in a conversation the caller belongs to.

    ``action`` defaults to ``"pin"``. A falsy result from the DB call is
    reported as "Already pinned" / "Not pinned". On success the other members
    receive a ``message_pinned`` or ``message_unpinned`` push.
    """
    kind = "pin_message"

    def reject(reason: str):
        return send_resp(msg, writer, kind, "error", {"message": reason})

    message_id = msg.get("message_id", "")
    action = msg.get("action", "pin")  # "pin" or "unpin"
    conversation_id = msg.get("conversation_id", "")
    if not (message_id and conversation_id):
        await reject("Missing fields")
        return
    if not (_valid_uuid(message_id) and _valid_uuid(conversation_id)):
        await reject("Invalid ID")
        return
    if action not in ("pin", "unpin"):
        await reject("Invalid action")
        return
    me = session["user_id"]
    if not await adb.is_conversation_member(conversation_id, me):
        await reject("Not a member")
        return

    pinning = action == "pin"
    if pinning:
        applied = await adb.pin_message(message_id, me, conversation_id)
    else:
        applied = await adb.unpin_message(message_id, conversation_id)
    if not applied:
        await reject("Already pinned" if pinning else "Not pinned")
        return

    logger.info("[MSG] %s %s message=%s in conv=%s", _who(session), action, message_id[:8], conversation_id[:8])
    await send_resp(msg, writer, kind, "ok", {"message_id": message_id, "action": action})
    everyone = await adb.get_conversation_members(conversation_id)
    others = [row["id"] for row in everyone if row["id"] != me]
    await _notify_users(others, "message_pinned" if pinning else "message_unpinned", {
        "message_id": message_id,
        "conversation_id": conversation_id,
        "user_id": me,
        "username": session.get("username", ""),
    })
async def handle_get_pinned_messages(msg: dict, session: dict, writer: ProtocolWriter):
    """Return the pinned messages of a conversation the caller belongs to."""
    kind = "get_pinned_messages"
    conversation_id = msg.get("conversation_id", "")

    problem = None
    if not conversation_id:
        problem = "Missing conversation_id"
    elif not _valid_uuid(conversation_id):
        problem = "Invalid conversation_id"
    elif not await adb.is_conversation_member(conversation_id, session["user_id"]):
        problem = "Not a member"
    if problem is not None:
        await send_resp(msg, writer, kind, "error", {"message": problem})
        return

    pinned = await adb.get_pinned_messages(conversation_id, session["user_id"])
    await send_resp(msg, writer, kind, "ok", {"messages": pinned})
async def handle_upload_image_start(msg: dict, session: dict, writer: ProtocolWriter):
    """Begin a chunked upload of an encrypted image or file.

    Expects ``conversation_id``, a client-chosen UUID ``file_id``, the
    declared total ``file_size`` in bytes, and optionally ``file_type``
    (``"file"`` selects ``MAX_FILE_BYTES``, anything else ``MAX_IMAGE_BYTES``).
    Registers the upload in ``pending_uploads``, creates an empty ``.tmp`` file
    and a DB upload row. If the DB insert fails, the pending entry and temp
    file are rolled back.
    """
    conv_id = msg.get("conversation_id", "")
    file_size = msg.get("file_size", 0)
    file_id = msg.get("file_id", "")
    file_type = msg.get("file_type", "image")  # "image" or "file"
    if not conv_id or not file_id:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Missing fields"})
        return
    if not _valid_uuid(file_id):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid conversation_id"})
        return
    # file_size is compared against the size cap here and against the received
    # byte count at upload end. A non-integer would raise, and a negative value
    # could never complete. bool is excluded explicitly because it subclasses int.
    if isinstance(file_size, bool) or not isinstance(file_size, int) or file_size < 0:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_size"})
        return
    # M5: rate limit + caps on in-flight uploads
    if await _is_rate_limited(f"upload_start|{session['user_id']}", 10):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Too many uploads. Try later."})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Not a member"})
        return
    max_bytes = MAX_FILE_BYTES if file_type == "file" else MAX_IMAGE_BYTES
    if max_bytes > 0 and file_size > max_bytes:
        await send_resp(msg, writer, "upload_image_start", "error",
                        {"message": f"File too large (max {max_bytes} bytes)"})
        return
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    temp_path = _safe_upload_path(file_id, ".tmp")
    if not temp_path:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"})
        return
    # M5: atomic cap check + insert under single lock acquisition
    cap_error = ""
    async with _uploads_lock:
        total = len(pending_uploads)
        user_count = sum(1 for u in pending_uploads.values() if u.get("uploader_id") == session["user_id"])
        if file_id in pending_uploads:
            # Never overwrite an in-flight upload, which may belong to another
            # user. Overwriting would truncate its temp file, and the rollback
            # below would then pop the other entry and delete that temp file.
            cap_error = "Upload already in progress"
        elif total >= MAX_UPLOADS_GLOBAL:
            cap_error = "Server upload limit reached. Try later."
        elif user_count >= MAX_UPLOADS_PER_USER:
            cap_error = "Too many active uploads. Finish or cancel existing ones."
        else:
            temp_path.write_bytes(b"")
            pending_uploads[file_id] = {
                "temp_path": str(temp_path),
                "received_bytes": 0,
                "file_size": file_size,
                "max_bytes": max_bytes,
                "conv_id": conv_id,
                "uploader_id": session["user_id"],
            }
    if cap_error:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": cap_error})
        return
    try:
        await adb.create_image_upload(file_id, conv_id, session["user_id"], file_size)
    except Exception:
        # Rollback: remove from pending_uploads + delete temp file
        async with _uploads_lock:
            pending_uploads.pop(file_id, None)
        _secure_delete(temp_path)
        logger.exception("[UPLOAD] DB create failed for file=%s", file_id[:8])
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Upload failed"})
        return
    logger.info("[UPLOAD] %s started upload file=%s (%s, %d bytes)",
                _who(session), file_id[:8], file_type, file_size)
    await send_resp(msg, writer, "upload_image_start", "ok", {"file_id": file_id})
async def handle_upload_image_chunk(msg: dict, session: dict, writer: ProtocolWriter):
    """Append one chunk to the caller's in-flight upload.

    Expects ``file_id`` and ``data`` (decoded with ``decode_binary``). The
    chunk is appended to the upload's temp file outside the uploads lock, and
    the byte counter is then updated under the lock. If the running total
    exceeds the upload's ``max_bytes`` (when positive), the upload is dropped
    and its temp file deleted. Replies with the running ``received`` count.
    """
    file_id = msg.get("file_id", "")
    chunk_data = msg.get("data", "")
    if not file_id or not chunk_data:
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Missing fields"})
        return
    async with _uploads_lock:
        upload = pending_uploads.get(file_id)
        if not upload or upload["uploader_id"] != session["user_id"]:
            upload = None
        else:
            temp_path_str = upload["temp_path"]
            upload_max = upload.get("max_bytes", 0)
    if not upload:
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"})
        return
    raw = decode_binary(chunk_data)
    temp_path = Path(temp_path_str)
    # Disk write happens without the lock so large chunks don't serialize all uploads.
    await asyncio.to_thread(_append_file, temp_path, raw)
    over_limit = False
    received = None
    async with _uploads_lock:
        upload = pending_uploads.get(file_id)
        if upload is not None and upload["uploader_id"] == session["user_id"]:
            upload["received_bytes"] += len(raw)
            received = upload["received_bytes"]
            if upload_max > 0 and received > upload_max:
                pending_uploads.pop(file_id, None)
                over_limit = True
    if over_limit:
        _secure_delete(temp_path)
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Upload exceeds size limit"})
        return
    if received is None:
        # The upload was finished, dropped or replaced by another request while
        # this chunk was being written (the lock is released during the append).
        # Previously this path fell through with `received` unbound and raised
        # UnboundLocalError.
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"})
        return
    await send_resp(msg, writer, "upload_image_chunk", "ok", {"received": received})
async def handle_upload_image_end(msg: dict, session: dict, writer: ProtocolWriter):
    """Finish the caller's chunked upload.

    Verifies the received byte count equals the declared ``file_size``, moves
    the ``.tmp`` file to its final ``.enc`` path and marks the DB row
    complete. The pending entry is removed first, so any failure after that
    ends the upload. On a size mismatch or an invalid path the temp file is
    deleted.
    """
    file_id = msg.get("file_id", "")
    if not file_id:
        await send_resp(msg, writer, "upload_image_end", "error", {"message": "Missing file_id"})
        return
    async with _uploads_lock:
        upload = pending_uploads.get(file_id)
        # Check ownership BEFORE popping. Previously a non-owner who knew the
        # file_id could pop (and thereby destroy) someone else's upload.
        if upload and upload["uploader_id"] == session["user_id"]:
            pending_uploads.pop(file_id, None)
        else:
            upload = None
    if not upload:
        await send_resp(msg, writer, "upload_image_end", "error", {"message": "No active upload"})
        return
    temp_path = Path(upload["temp_path"])
    if upload["received_bytes"] != upload["file_size"]:
        _secure_delete(temp_path)
        await send_resp(msg, writer, "upload_image_end", "error",
                        {"message": f"Incomplete upload: received {upload['received_bytes']} of {upload['file_size']} bytes"})
        return
    final_path = _safe_upload_path(file_id, ".enc")
    if not final_path:
        _secure_delete(temp_path)
        await send_resp(msg, writer, "upload_image_end", "error", {"message": "Invalid file_id"})
        return

    def _move_file():
        try:
            temp_path.rename(final_path)
        except OSError:
            # rename() fails across filesystems (and on Windows if the target
            # exists); fall back to shutil.move.
            import shutil
            shutil.move(str(temp_path), str(final_path))

    await asyncio.to_thread(_move_file)
    await adb.complete_image_upload(file_id)
    logger.info("[UPLOAD] %s completed upload file=%s (%d bytes)",
                _who(session), file_id[:8], upload["received_bytes"])
    await send_resp(msg, writer, "upload_image_end", "ok", {"file_id": file_id})
async def handle_download_image(msg: dict, session: dict, writer: ProtocolWriter):
    """Return one chunk (up to ``IMAGE_CHUNK_SIZE`` bytes) of a completed upload.

    Expects ``file_id`` and an optional non-negative integer byte ``offset``
    (default 0). The caller must be a member of the upload's conversation.
    The reply carries the chunk, the echoed offset, ``total_size`` and a
    ``done`` flag that is set once ``offset + len(chunk)`` reaches the file size.
    """
    file_id = msg.get("file_id", "")
    offset = msg.get("offset", 0)
    if not file_id:
        await send_resp(msg, writer, "download_image", "error", {"message": "Missing file_id"})
        return
    if not _valid_uuid(file_id):
        await send_resp(msg, writer, "download_image", "error", {"message": "Invalid file_id"})
        return
    # offset comes straight from the client. A negative or non-integer value
    # would make the file read fail with an exception instead of an error
    # reply. bool is excluded explicitly because it subclasses int.
    if isinstance(offset, bool) or not isinstance(offset, int) or offset < 0:
        await send_resp(msg, writer, "download_image", "error", {"message": "Invalid offset"})
        return
    upload = await adb.get_image_upload(file_id)
    if not upload or not upload["completed"]:
        await send_resp(msg, writer, "download_image", "error", {"message": "File not found"})
        return
    if not await adb.is_conversation_member(upload["conversation_id"], session["user_id"]):
        await send_resp(msg, writer, "download_image", "error", {"message": "Not a member"})
        return
    file_path = _safe_upload_path(file_id, ".enc")
    if not file_path or not file_path.exists():
        await send_resp(msg, writer, "download_image", "error", {"message": "File not found"})
        return
    file_size = file_path.stat().st_size
    chunk = await asyncio.to_thread(_read_file_chunk, file_path, offset, IMAGE_CHUNK_SIZE)
    done = (offset + len(chunk)) >= file_size
    if offset == 0:
        logger.info("[DOWNLOAD] %s downloading file=%s (%d bytes)", _who(session), file_id[:8], file_size)
    await send_resp(msg, writer, "download_image", "ok", {
        "file_id": file_id,
        "data": encode_binary(chunk),
        "offset": offset,
        "done": done,
        "total_size": file_size,
    })
MAX_AVATAR_BYTES = 2 << 20  # 2 MiB cap on a decoded avatar image (user or group)
async def handle_get_profile(msg: dict, session: dict, writer: ProtocolWriter):
    """Return a user's profile, defaulting to the caller's own.

    The caller is passed to ``get_user_profile`` as ``viewer_id`` (presumably
    so it can apply the target's visibility settings — see db). Datetime-like
    ``created_at`` / ``updated_at`` values are sent as ISO-8601 strings.
    """
    requested = msg.get("user_id", "").strip()
    if requested and not _valid_uuid(requested):
        await send_resp(msg, writer, "get_profile", "error", {"message": "Invalid user_id"})
        return
    target = requested or session["user_id"]

    profile = await adb.get_user_profile(target, viewer_id=session["user_id"])
    if not profile:
        await send_resp(msg, writer, "get_profile", "error", {"message": "User not found"})
        return

    for stamp in ("created_at", "updated_at"):
        value = profile.get(stamp)
        if value and hasattr(value, "isoformat"):
            profile[stamp] = value.isoformat()
    await send_resp(msg, writer, "get_profile", "ok", profile)
async def handle_update_profile(msg: dict, session: dict, writer: ProtocolWriter):
    """Update the caller's own profile.

    Only the whitelisted keys below are forwarded to ``update_user_profile``;
    any other keys in *msg* are ignored. At least one of them must be present.
    """
    # NOTE(review): values are forwarded without type/length checks here —
    # confirm update_user_profile validates them.
    editable = ("phone", "phone_visible", "email_visible", "location", "location_visible")
    changes = {key: msg[key] for key in editable if key in msg}
    if not changes:
        await send_resp(msg, writer, "update_profile", "error", {"message": "No fields to update"})
        return
    await adb.update_user_profile(session["user_id"], **changes)
    await send_resp(msg, writer, "update_profile", "ok", {"message": "OK"})
async def handle_update_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Store the caller's avatar sent in a single message (max MAX_AVATAR_BYTES).

    The decoded data is written to ``UPLOAD_DIR/avatars/<user_id>.png`` when it
    starts with the PNG signature, otherwise to ``<user_id>.jpg``. The avatars
    directory is chmod'ed to 0700, and the profile's ``avatar_file`` is set to
    the new filename.
    """
    me = session["user_id"]
    kind = "update_avatar"

    def reject(reason: str):
        return send_resp(msg, writer, kind, "error", {"message": reason})

    if await _is_rate_limited(f"update_avatar|{me}", 5):
        await reject("Too many requests. Try later.")
        return
    encoded = msg.get("data", "")
    if not encoded:
        await reject("Missing data")
        return
    blob = decode_binary(encoded)
    if len(blob) > MAX_AVATAR_BYTES:
        await reject(f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)")
        return

    # Only the 8-byte PNG signature is recognised; anything else is stored as .jpg.
    ext = "png" if blob[:8] == b'\x89PNG\r\n\x1a\n' else "jpg"
    folder = UPLOAD_DIR / "avatars"
    folder.mkdir(parents=True, exist_ok=True)
    os.chmod(folder, 0o700)

    filename = f"{me}.{ext}"
    destination = _safe_avatar_path(filename)
    if not destination:
        await reject("Invalid path")
        return
    await asyncio.to_thread(destination.write_bytes, blob)
    await adb.update_user_profile(me, avatar_file=filename)
    logger.info("[AVATAR] %s updated their avatar", _who(session))
    await send_resp(msg, writer, kind, "ok", {"avatar_file": filename})
async def handle_get_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Send the stored avatar image of the requested user.

    Looks up ``avatar_file`` via ``get_user_profile`` (called here without a
    ``viewer_id``) and returns the file's bytes with the stored filename.
    """
    kind = "get_avatar"

    def reject(reason: str):
        return send_resp(msg, writer, kind, "error", {"message": reason})

    who = msg.get("user_id", "").strip()
    if not who:
        await reject("Missing user_id")
        return
    if not _valid_uuid(who):
        await reject("Invalid user_id")
        return

    profile = await adb.get_user_profile(who)
    stored_name = profile.get("avatar_file") if profile else None
    if not stored_name:
        await reject("No avatar")
        return
    location = _safe_avatar_path(profile["avatar_file"])
    if not location or not location.exists():
        await reject("Avatar file not found")
        return

    blob = await asyncio.to_thread(location.read_bytes)
    await send_resp(msg, writer, kind, "ok", {
        "user_id": who,
        "data": encode_binary(blob),
        "filename": profile["avatar_file"],
    })
async def handle_update_group_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Store a conversation's avatar sent in one message (max MAX_AVATAR_BYTES).

    Any member of the conversation may set it. It shares the ``update_avatar``
    rate-limit bucket with personal avatars. The file is written as
    ``UPLOAD_DIR/avatars/group_<conversation_id>.png`` (PNG signature) or
    ``.jpg`` (anything else), and recorded via ``update_conversation_avatar``.
    """
    kind = "update_group_avatar"

    def reject(reason: str):
        return send_resp(msg, writer, kind, "error", {"message": reason})

    if await _is_rate_limited(f"update_avatar|{session['user_id']}", 5):
        await reject("Too many requests. Try later.")
        return
    conv_id = msg.get("conversation_id", "").strip()
    encoded = msg.get("data", "")
    if not (conv_id and encoded):
        await reject("Missing fields")
        return
    if not _valid_uuid(conv_id):
        await reject("Invalid conversation_id")
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await reject("Not a member")
        return

    blob = decode_binary(encoded)
    if len(blob) > MAX_AVATAR_BYTES:
        await reject(f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)")
        return
    ext = "png" if blob[:8] == b'\x89PNG\r\n\x1a\n' else "jpg"
    folder = UPLOAD_DIR / "avatars"
    folder.mkdir(parents=True, exist_ok=True)
    os.chmod(folder, 0o700)

    filename = f"group_{conv_id}.{ext}"
    destination = _safe_avatar_path(filename)
    if not destination:
        await reject("Invalid path")
        return
    await asyncio.to_thread(destination.write_bytes, blob)
    await adb.update_conversation_avatar(conv_id, filename)
    logger.info("[AVATAR] %s updated group avatar for conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, kind, "ok", {"avatar_file": filename})
async def handle_get_group_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Send a conversation's avatar image to one of its members."""
    kind = "get_group_avatar"

    def reject(reason: str):
        return send_resp(msg, writer, kind, "error", {"message": reason})

    conv_id = msg.get("conversation_id", "").strip()
    if not conv_id:
        await reject("Missing conversation_id")
        return
    if not _valid_uuid(conv_id):
        await reject("Invalid conversation_id")
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await reject("Not a member")
        return

    conversation = await adb.get_conversation(conv_id)
    stored_name = conversation.get("avatar_file") if conversation else None
    if not stored_name:
        await reject("No avatar")
        return
    location = _safe_avatar_path(conversation["avatar_file"])
    if not location or not location.exists():
        await reject("Avatar file not found")
        return

    blob = await asyncio.to_thread(location.read_bytes)
    await send_resp(msg, writer, kind, "ok", {
        "conversation_id": conv_id,
        "data": encode_binary(blob),
        "filename": conversation["avatar_file"],
    })
async def handle_list_devices(msg: dict, session: dict, writer: ProtocolWriter):
    """Reply with every device registered to the calling user.

    Each entry contains ``device_id``, ``device_name``, ``created_at`` and
    ``last_seen_at`` as strings (``isoformat()`` when the value has it, else
    ``str()``; ``last_seen_at`` is ``None`` when unset/falsy), and
    ``is_current``, which marks the device bound to this session.
    """
    def as_text(value):
        # datetime-like values serialize via isoformat(); anything else via str()
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    rows = await adb.get_user_devices(session["user_id"])
    own_device = session.get("device_id")
    devices = []
    for row in rows:
        seen = row.get("last_seen_at")
        devices.append({
            "device_id": row["id"],
            "device_name": row.get("device_name"),
            "created_at": as_text(row["created_at"]),
            "last_seen_at": as_text(seen) if seen else None,
            "is_current": row["id"] == own_device,
        })
    await send_resp(msg, writer, "list_devices", "ok", {"devices": devices})
async def handle_remove_device(msg: dict, session: dict, writer: ProtocolWriter):
    """Remove one of the caller's other devices and drop its live connections.

    The target must be a valid UUID, belong to the calling user, and differ
    from the device of the current session. After the database row is deleted,
    every connection still registered under that device id is closed. Sessions
    are held in memory per connection, so deleting the row alone would leave a
    removed (e.g. lost) device fully logged in until it disconnected by itself.
    """
    device_id = msg.get("device_id", "").strip()
    if not device_id:
        await send_resp(msg, writer, "remove_device", "error", {"message": "Missing device_id"})
        return
    if not _valid_uuid(device_id):
        await send_resp(msg, writer, "remove_device", "error", {"message": "Invalid device_id"})
        return
    if device_id == session.get("device_id"):
        await send_resp(msg, writer, "remove_device", "error", {"message": "Cannot remove current device"})
        return
    dev = await adb.get_device(device_id)
    if not dev or dev["user_id"] != session["user_id"]:
        await send_resp(msg, writer, "remove_device", "error", {"message": "Device not found"})
        return
    await adb.delete_device(device_id)

    # Snapshot the removed device's writers under the lock, close them outside it.
    async with _clients_lock:
        live_writers = [
            w for w in connected_clients.get(session["user_id"], [])
            if writer_device_map.get(id(w)) == device_id
        ]
    for w in live_writers:
        try:
            # Closing the transport ends that connection's read loop, which then
            # runs its normal disconnect cleanup.
            w.close()
        except Exception as e:
            logger.debug("[DEVICE] failed to close connection of removed device: %s", e)

    logger.info("[DEVICE] %s removed device=%s", _who(session), device_id[:8])
    if live_writers:
        logger.info("[DEVICE] closed %d live connection(s) of removed device=%s",
                    len(live_writers), device_id[:8])
    await send_resp(msg, writer, "remove_device", "ok", {"message": "OK"})
async def handle_session_reset(msg: dict, session: dict, writer: ProtocolWriter):
    """Ask a peer to discard and rebuild a corrupted Double Ratchet session.

    ``peer_user_id`` is required; ``peer_device_id`` is optional and, when
    given, restricts the notification to that device's live connections.
    Both must be valid UUIDs. The request is rate limited to 5 per window per
    user (keyed on user id only) and only allowed when the two users share at
    least one conversation. Delivery is best-effort; the caller always gets
    ``ok`` once the checks pass.
    """
    reply_type = "session_reset"
    target_user = msg.get("peer_user_id", "").strip()
    target_device = msg.get("peer_device_id", "").strip() or None

    if not (target_user and _valid_uuid(target_user)):
        await send_resp(msg, writer, reply_type, "error", {"message": "Invalid peer_user_id"})
        return
    if target_device is not None and not _valid_uuid(target_device):
        await send_resp(msg, writer, reply_type, "error", {"message": "Invalid peer_device_id"})
        return

    me = session["user_id"]
    # H3 fix: per-user rate limit, independent of client IP
    if await _is_rate_limited(f"session_reset|{me}", 5):
        await send_resp(msg, writer, reply_type, "error", {"message": "Rate limit exceeded"})
        return
    # H3 fix: only peers sharing a conversation may be asked to reset
    if not await adb.shares_conversation(me, target_user):
        await send_resp(msg, writer, reply_type, "error", {"message": "No shared conversation"})
        return

    payload = {
        "from_user_id": me,
        "from_device_id": session.get("device_id"),
    }
    if target_device is None:
        await _notify_users([target_user], reply_type, payload)
    else:
        # Collect the matching device's writers under the lock, send outside it.
        async with _clients_lock:
            recipients = [
                w for w in connected_clients.get(target_user, [])
                if writer_device_map.get(id(w)) == target_device
            ]
        for w in recipients:
            try:
                await w.send_response(reply_type, "ok", payload)
            except Exception:
                pass  # best-effort push; a dead peer connection is not our error

    logger.info("[SESSION] %s reset session with peer=%s", _who(session), target_user[:8])
    await send_resp(msg, writer, reply_type, "ok", {})
async def handle_get_deleted_since(msg: dict, session: dict, writer: ProtocolWriter):
    """Reply with the ids of messages in a conversation deleted after ``since_ts``.

    Both ``conversation_id`` and ``since_ts`` are required, the conversation id
    must be a valid UUID, and the caller must be a member. ``since_ts`` is passed
    through to ``adb.get_deleted_messages_since`` unchanged (its expected format
    is defined there — not visible here).
    """
    reply_type = "get_deleted_since"
    conversation_id = msg.get("conversation_id", "")
    since = msg.get("since_ts", "")

    problem = None
    if not (conversation_id and since):
        problem = "Missing parameters"
    elif not _valid_uuid(conversation_id):
        problem = "Invalid conversation_id"
    elif not await adb.is_conversation_member(conversation_id, session["user_id"]):
        problem = "Not a member"
    if problem is not None:
        await send_resp(msg, writer, reply_type, "error", {"message": problem})
        return

    removed = await adb.get_deleted_messages_since(conversation_id, session["user_id"], since)
    await send_resp(msg, writer, reply_type, "ok", {"deleted_ids": removed})
async def handle_reencrypt_messages(msg: dict, session: dict, writer: ProtocolWriter):
    """Replace stored ciphertext of the caller's messages (used for device pairing).

    ``msg["updates"]`` must be a non-empty list of at most 500 objects, each
    with ``message_id``, ``encrypted_content`` and ``nonce``. The last two are
    decoded with ``decode_binary``. Entries that are not objects, miss a field,
    or fail to decode are skipped; the reply's ``count`` is the number of
    entries handed to ``adb.batch_reencrypt_messages``. Rate limited to 10
    requests per window per user.
    """
    if await _is_rate_limited(f"reencrypt|{session['user_id']}", 10):
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "Too many requests. Try later."})
        return
    updates_raw = msg.get("updates", [])
    if not updates_raw:
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No updates"})
        return
    # Untrusted input: a dict or string would otherwise be iterated as keys/chars.
    if not isinstance(updates_raw, list):
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "Invalid updates"})
        return
    if len(updates_raw) > 500:
        await send_resp(msg, writer, "reencrypt_messages", "error",
                        {"message": "Too many updates (max 500 per request)"})
        return
    updates = []
    for u in updates_raw:
        if not isinstance(u, dict):
            continue
        mid = u.get("message_id", "")
        ct_b64 = u.get("encrypted_content", "")
        nonce_b64 = u.get("nonce", "")
        if not mid or not ct_b64 or not nonce_b64:
            continue
        try:
            ciphertext = decode_binary(ct_b64)
            nonce = decode_binary(nonce_b64)
        except (ValueError, TypeError):
            # NOTE(review): assumes decode_binary signals bad input with
            # ValueError/TypeError (true for base64 decoding) — confirm in protocol.py.
            # Previously one bad entry failed the whole batch as an internal error.
            continue
        updates.append({
            "message_id": mid,
            "encrypted_content": ciphertext,
            "nonce": nonce,
        })
    if not updates:
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No valid updates"})
        return
    await adb.batch_reencrypt_messages(session["user_id"], updates)
    logger.info("[REENCRYPT] %s re-encrypted %d messages", _who(session), len(updates))
    await send_resp(msg, writer, "reencrypt_messages", "ok", {"count": len(updates)})
async def _cleanup_uploads():
    """Purge uploads that have been idle longer than ``UPLOAD_STALE_SECONDS``.

    For every stale upload record:
    - delete its ``.tmp`` and ``.enc`` files via ``_secure_delete``;
    - drop the DB record;
    - forget any in-memory ``pending_uploads`` entry.

    ``_secure_delete`` is synchronous, so it runs in a worker thread instead of
    on the event loop. A failure on one upload is logged and does not stop the
    remaining ones from being processed.
    """
    stale = await adb.get_stale_uploads(UPLOAD_STALE_SECONDS)
    cleaned = 0
    for s in stale:
        fid = s["file_id"]
        try:
            for ext in (".tmp", ".enc"):
                p = _safe_upload_path(fid, ext)
                if not p:
                    continue
                await asyncio.to_thread(_secure_delete, p)
            await adb.delete_image_upload(fid)
        except Exception as e:
            # The DB record is kept on failure, so the next cycle retries it.
            logger.warning("Stale upload cleanup failed for %s: %s", str(fid)[:8], e)
        else:
            cleaned += 1
        async with _uploads_lock:
            pending_uploads.pop(fid, None)
    if cleaned:
        logger.info("Cleaned up %d stale uploads.", cleaned)
async def _unregister_client_writer(uid: str, proto_writer: ProtocolWriter) -> None:
    """Remove *proto_writer* from the live-client registry of user *uid*.

    If it was the user's last registered connection, push ``user_offline`` to
    the user's currently connected contacts. Registry cleanup happens first and
    never depends on the database. A failing ``get_user_contacts`` therefore
    only skips the offline notice and cannot leave a dead writer registered.
    """
    went_offline = False
    async with _clients_lock:
        writer_device_map.pop(id(proto_writer), None)
        writers = connected_clients.get(uid)
        if writers is not None:
            remaining = [w for w in writers if w is not proto_writer]
            if remaining:
                connected_clients[uid] = remaining
            else:
                del connected_clients[uid]
                went_offline = True
    if not went_offline:
        return
    try:
        contacts = await adb.get_user_contacts(uid)
    except Exception as e:
        logger.warning("[CONN] could not load contacts for offline notice of %s: %s", uid[:8], e)
        return
    async with _clients_lock:
        if uid in connected_clients:
            return  # user reconnected while contacts were loading — not offline
        offline_targets = [cw for contact_id in contacts
                           for cw in connected_clients.get(contact_id, [])]
    # Send offline notifications outside the lock
    for cw in offline_targets:
        try:
            await cw.send_response("user_offline", "ok", {"user_id": uid})
        except Exception:
            pass  # best-effort notice; a dead contact connection is not our error


async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Serve one TCP client connection until it disconnects.

    - Enforces the global and per-IP connection limits (over-limit connections
      are closed immediately).
    - Applies a per-connection request rate limit (``CONNECTION_RL_MAX``
      requests per ``CONNECTION_RL_WINDOW`` of event-loop time).
    - Dispatches each message by ``type``. Registration, login and pairing
      start/poll work without a session; every other type requires a
      successful ``login_finish`` on this connection.
    - On exit, releases the connection slot, unregisters the writer (notifying
      contacts when this was the user's last connection) and closes the socket.
    """
    global current_connections
    proto_writer = ProtocolWriter(writer)
    addr = _get_peer_addr(proto_writer)
    async with _conn_lock:
        current_connections += 1
        connection_counts[addr] = connection_counts.get(addr, 0) + 1
        over_limit = (current_connections > MAX_CONNECTIONS_GLOBAL or
                      connection_counts[addr] > MAX_CONNECTIONS_PER_IP)
    if over_limit:
        try:
            writer.close()
        except Exception:
            pass
        async with _conn_lock:
            current_connections = max(0, current_connections - 1)
            connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1)
        return
    logger.info("[CONN] Client connected from %s", addr)
    proto_reader = ProtocolReader(reader)
    session = None
    state = {"_req_times": []}
    loop = asyncio.get_running_loop()
    try:
        while True:
            try:
                msg = await proto_reader.read_message()
            except ValueError as e:
                try:
                    await proto_writer.send_response("protocol_error", "error", {"message": str(e)})
                except Exception:
                    pass
                break
            if msg is None:
                break
            msg_type = msg.get("type", "")
            # Sliding-window request limit for this connection
            now = loop.time()
            times = [t for t in state["_req_times"] if now - t <= CONNECTION_RL_WINDOW]
            if len(times) >= CONNECTION_RL_MAX:
                await send_resp(msg, proto_writer, msg_type, "error", {"message": "Too many requests. Slow down."})
                state["_req_times"] = times
                continue
            times.append(now)
            state["_req_times"] = times
            try:
                if msg_type == "register":
                    await handle_register_start(msg, proto_writer)
                elif msg_type == "register_confirm":
                    await handle_register_confirm(msg, proto_writer)
                elif msg_type == "login_start":
                    await handle_login_start(msg, proto_writer, state)
                elif msg_type == "login_finish":
                    result = await handle_login_finish(msg, proto_writer, state)
                    if result:
                        session = result
                elif msg_type == "pairing_start":
                    await handle_pairing_start(msg, proto_writer)
                elif msg_type == "pairing_poll":
                    await handle_pairing_poll(msg, proto_writer)
                elif session is None:
                    await send_resp(msg, proto_writer, msg_type, "error", {"message": "Not logged in"})
                elif msg_type == "get_user_info":
                    await handle_get_user_info(msg, session, proto_writer)
                elif msg_type == "upload_prekeys":
                    await handle_upload_prekeys(msg, session, proto_writer)
                elif msg_type == "get_key_bundle":
                    await handle_get_key_bundle(msg, session, proto_writer)
                elif msg_type == "get_prekey_count":
                    await handle_get_prekey_count(msg, session, proto_writer)
                elif msg_type == "ensure_prekeys":
                    await handle_ensure_prekeys(msg, session, proto_writer)
                elif msg_type == "create_conversation":
                    await handle_create_conversation(msg, session, proto_writer)
                elif msg_type == "find_conversation":
                    await handle_find_conversation(msg, session, proto_writer)
                elif msg_type == "add_member":
                    await handle_add_member(msg, session, proto_writer)
                elif msg_type == "accept_invitation":
                    await handle_accept_invitation(msg, session, proto_writer)
                elif msg_type == "decline_invitation":
                    await handle_decline_invitation(msg, session, proto_writer)
                elif msg_type == "list_invitations":
                    await handle_list_invitations(msg, session, proto_writer)
                elif msg_type == "list_conversations":
                    await handle_list_conversations(msg, session, proto_writer)
                elif msg_type == "send_message":
                    await handle_send_message(msg, session, proto_writer)
                elif msg_type == "get_messages":
                    await handle_get_messages(msg, session, proto_writer)
                elif msg_type == "rotate_keys":
                    await handle_rotate_keys(msg, session, proto_writer)
                elif msg_type == "change_username":
                    await handle_change_username(msg, session, proto_writer)
                elif msg_type == "remove_member":
                    await handle_remove_member(msg, session, proto_writer)
                elif msg_type == "leave_group":
                    await handle_leave_group(msg, session, proto_writer)
                elif msg_type == "rename_conversation":
                    await handle_rename_conversation(msg, session, proto_writer)
                elif msg_type == "delete_conversation":
                    await handle_delete_conversation(msg, session, proto_writer)
                elif msg_type == "mark_read":
                    await handle_mark_read(msg, session, proto_writer)
                elif msg_type == "mark_conversation_read":
                    await handle_mark_conversation_read(msg, session, proto_writer)
                elif msg_type == "confirm_delivery":
                    await handle_confirm_delivery(msg, session, proto_writer)
                elif msg_type == "pairing_claim":
                    await handle_pairing_claim(msg, session, proto_writer)
                elif msg_type == "pairing_send":
                    await handle_pairing_send(msg, session, proto_writer)
                elif msg_type == "delete_message":
                    await handle_delete_message(msg, session, proto_writer)
                elif msg_type == "upload_image_start":
                    await handle_upload_image_start(msg, session, proto_writer)
                elif msg_type == "upload_image_chunk":
                    await handle_upload_image_chunk(msg, session, proto_writer)
                elif msg_type == "upload_image_end":
                    await handle_upload_image_end(msg, session, proto_writer)
                elif msg_type == "download_image":
                    await handle_download_image(msg, session, proto_writer)
                elif msg_type == "get_profile":
                    await handle_get_profile(msg, session, proto_writer)
                elif msg_type == "update_profile":
                    await handle_update_profile(msg, session, proto_writer)
                elif msg_type == "update_avatar":
                    await handle_update_avatar(msg, session, proto_writer)
                elif msg_type == "get_avatar":
                    await handle_get_avatar(msg, session, proto_writer)
                elif msg_type == "update_group_avatar":
                    await handle_update_group_avatar(msg, session, proto_writer)
                elif msg_type == "get_group_avatar":
                    await handle_get_group_avatar(msg, session, proto_writer)
                elif msg_type == "get_deleted_since":
                    await handle_get_deleted_since(msg, session, proto_writer)
                elif msg_type == "reencrypt_messages":
                    await handle_reencrypt_messages(msg, session, proto_writer)
                elif msg_type == "list_devices":
                    await handle_list_devices(msg, session, proto_writer)
                elif msg_type == "remove_device":
                    await handle_remove_device(msg, session, proto_writer)
                elif msg_type == "session_reset":
                    await handle_session_reset(msg, session, proto_writer)
                elif msg_type == "react_message":
                    await handle_react_message(msg, session, proto_writer)
                elif msg_type == "pin_message":
                    await handle_pin_message(msg, session, proto_writer)
                elif msg_type == "get_pinned_messages":
                    await handle_get_pinned_messages(msg, session, proto_writer)
                else:
                    await send_resp(msg, proto_writer, msg_type, "error", {"message": "Unknown type"})
            except Exception as e:
                # Pre-login handlers can fail while session is still None.
                logger.warning("[ERROR] %s handler '%s' failed: %s",
                               _who(session) if session else addr, msg_type, e, exc_info=True)
                try:
                    await send_resp(msg, proto_writer, msg_type, "error", {"message": "Internal server error"})
                except Exception:
                    break  # Can't send response — connection is dead
    except Exception as e:
        logger.warning("Client connection error: %s", e)
    finally:
        try:
            async with _conn_lock:
                current_connections = max(0, current_connections - 1)
                connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1)
            if session:
                await _unregister_client_writer(session["user_id"], proto_writer)
        finally:
            # Always release the socket, even if cleanup above failed or was cancelled.
            try:
                writer.close()
            except Exception:
                pass
            logger.info("[CONN] %s disconnected", _who(session) if session else addr)
def _build_ssl_context() -> "ssl.SSLContext | None":
    """Create the server-side TLS context described by the environment.

    Reads TLS_ENABLED, TLS_REQUIRED, TLS_AUTOGEN (truthy: "1"/"true"/"yes",
    case-insensitive), ENVIRONMENT ("dev"/"development" permits auto-generation),
    TLS_CERT_FILE and TLS_KEY_FILE.

    Returns:
        A configured ``ssl.SSLContext``, or ``None`` when TLS is disabled
        (a warning is logged in that case).

    Raises:
        RuntimeError: TLS_REQUIRED without TLS_ENABLED; missing cert/key paths
            without TLS_AUTOGEN; TLS_AUTOGEN outside a dev environment; or
            ``openssl`` missing or failing during certificate auto-generation.
    """
    def _flag(name: str) -> bool:
        return os.getenv(name, "false").lower() in ("1", "true", "yes")

    tls_enabled = _flag("TLS_ENABLED")
    tls_required = _flag("TLS_REQUIRED")
    tls_autogen = _flag("TLS_AUTOGEN")
    is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development")

    if tls_required and not tls_enabled:
        raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.")
    if not tls_enabled:
        logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.")
        return None

    cert_file = os.getenv("TLS_CERT_FILE", "").strip()
    key_file = os.getenv("TLS_KEY_FILE", "").strip()
    if not cert_file or not key_file:
        if not tls_autogen:
            raise RuntimeError("TLS is enabled but TLS_CERT_FILE or TLS_KEY_FILE is missing.")
        if not is_dev:
            raise RuntimeError("TLS_AUTOGEN is only allowed when ENVIRONMENT=dev")
        cert_dir = Path(__file__).resolve().parent / "certs"
        cert_dir.mkdir(parents=True, exist_ok=True)
        cert_file = str(cert_dir / "server.crt")
        key_file = str(cert_dir / "server.key")
        # Reuse a previously generated pair; only call openssl when one is missing.
        if not (os.path.exists(cert_file) and os.path.exists(key_file)):
            try:
                subprocess.run(
                    [
                        "openssl", "req", "-x509", "-newkey", "rsa:4096",
                        "-keyout", key_file, "-out", cert_file,
                        "-days", "365", "-nodes", "-subj", "/CN=localhost",
                    ],
                    check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                )
                os.chmod(key_file, 0o600)
            except FileNotFoundError as exc:
                raise RuntimeError("OpenSSL not found.") from exc
            except subprocess.CalledProcessError as exc:
                raise RuntimeError("Failed to auto-generate TLS cert.") from exc
        logger.warning("Using auto-generated self-signed certificate — not for production use.")

    ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file)
    return ssl_context


async def main():
    """Start the chat server, run until SIGINT/SIGTERM, then shut down.

    Startup:
    - configure logging and build the optional TLS context;
    - ensure UPLOAD_DIR exists;
    - install a ThreadPoolExecutor (THREAD_POOL_SIZE workers, default 40) as
      the loop's default executor, which ``asyncio.to_thread`` uses;
    - preload phantom user ids and start listening;
    - launch a background maintenance task that runs every 120 s.

    Shutdown: stop accepting connections, close every registered client
    writer, then cancel and await all remaining tasks.
    """
    setup_logging()
    host = os.getenv("SERVER_HOST", "127.0.0.1")
    port = int(os.getenv("SERVER_PORT", "9999"))
    ssl_context = _build_ssl_context()
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    loop = asyncio.get_running_loop()
    # Thread pool for asyncio.to_thread() — DB calls + file I/O
    pool_workers = int(os.getenv("THREAD_POOL_SIZE", "40"))
    loop.set_default_executor(ThreadPoolExecutor(max_workers=pool_workers))
    logger.info("Thread pool executor: %d workers", pool_workers)
    # Load phantom user IDs from DB into in-memory cache
    phantom_user_ids.update(await adb.get_all_phantom_user_ids())
    if phantom_user_ids:
        logger.info("Loaded %d phantom user IDs.", len(phantom_user_ids))
    server = await asyncio.start_server(
        handle_client, host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context,
    )
    logger.info("Encrypted chat server v%s listening on %s:%s", VERSION, host, port)

    async def _cleanup_rate_limits():
        # Drop rate-limit keys with no hit inside the window and IPs with no connections.
        async with _conn_lock:
            window_start = loop.time() - RATE_LIMIT_WINDOW
            stale_keys = [k for k, times in rate_limits.items()
                          if not any(t >= window_start for t in times)]
            for k in stale_keys:
                del rate_limits[k]
            stale_conns = [k for k, v in connection_counts.items() if v <= 0]
            for k in stale_conns:
                del connection_counts[k]

    cleanup_cycle = 0

    async def _periodic_cleanup():
        nonlocal cleanup_cycle
        while True:
            await asyncio.sleep(120)
            cleanup_cycle += 1
            try:
                await _cleanup_uploads()
            except Exception as e:
                logger.warning("Upload cleanup error: %s", e)
            try:
                await _cleanup_rate_limits()
            except Exception as e:
                logger.warning("Rate limit cleanup error: %s", e)
            try:
                await _cleanup_registrations()
            except Exception as e:
                logger.warning("Registration cleanup error: %s", e)
            # L8: clean up stale phantom users (>30 days, no real conversations)
            try:
                deleted = await adb.cleanup_stale_phantoms(30)
                if deleted:
                    # Fetch first, then swap under the lock. Clearing before the
                    # await left the cache empty during the query — and for good
                    # if the query failed.
                    fresh_ids = await adb.get_all_phantom_user_ids()
                    async with _clients_lock:
                        phantom_user_ids.clear()
                        phantom_user_ids.update(fresh_ids)
                    logger.info("Cleaned up %d stale phantom users.", deleted)
            except Exception as e:
                logger.warning("Phantom cleanup error: %s", e)
            # Metadata retention: purge old reads and reactions (every 30 cycles = ~1 hour)
            if cleanup_cycle % 30 == 0:
                try:
                    reads_del = await adb.cleanup_old_reads(METADATA_RETENTION_DAYS)
                    reactions_del = await adb.cleanup_old_reactions(METADATA_RETENTION_DAYS)
                    if reads_del or reactions_del:
                        logger.info("Metadata cleanup: %d reads, %d reactions purged",
                                    reads_del, reactions_del)
                except Exception as e:
                    logger.warning("Metadata cleanup error: %s", e)

    # Keep a strong reference: the event loop holds only weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-run.
    cleanup_task = asyncio.create_task(_periodic_cleanup())

    stop = loop.create_future()

    def signal_handler():
        if not stop.done():
            stop.set_result(None)

    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, signal_handler)

    async with server:
        await stop
        logger.info("Shutting down — closing %d client connections...", sum(len(ws) for ws in connected_clients.values()))
        # Stop accepting new connections
        server.close()
        cleanup_task.cancel()
        # Force-close all connected client writers
        async with _clients_lock:
            all_writers = [w for writers in connected_clients.values() for w in writers]
            connected_clients.clear()
            writer_device_map.clear()
        for w in all_writers:
            try:
                w.close()
            except Exception:
                pass
        # Give handle_client loops a moment to notice closed connections
        await asyncio.sleep(0.1)
        # Cancel any remaining handle_client tasks that are still blocked
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
    logger.info("Server shut down.")
if __name__ == "__main__":
    # Script entry point: run main() on a fresh event loop; it returns after
    # the SIGINT/SIGTERM-triggered shutdown sequence completes.
    asyncio.run(main())