E2E encrypted chat (X3DH + Double Ratchet, Signal Protocol). Server: asyncio TCP + TLS, MySQL. Clients: PyQt6 GUI + CLI. Secrets (.env, TLS keys, Cloudflare token), runtime data and mobile clients (separate repos) are gitignored. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
3205 lines
144 KiB
Python
3205 lines
144 KiB
Python
"""Asyncio TCP server — stores and relays encrypted blobs without seeing content."""
|
|
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
import hashlib
|
|
import hmac
|
|
import ipaddress
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import secrets
|
|
import signal
|
|
import smtplib
|
|
import socket
|
|
import ssl
|
|
import subprocess
|
|
import sys
|
|
from email.mime.text import MIMEText
|
|
from pathlib import Path
|
|
from datetime import datetime, timezone
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()
|
|
|
|
import db
|
|
from crypto_utils import load_public_key, rsa_verify, load_ed25519_public, ed25519_verify, serialize_x25519_public, load_x25519_public
|
|
from protocol import VERSION, MIN_CLIENT_VERSION, version_gte, ProtocolReader, ProtocolWriter, encode_binary, decode_binary, MAX_MESSAGE_BYTES, MAX_IMAGE_BYTES, MAX_FILE_BYTES, IMAGE_CHUNK_SIZE
|
|
|
|
|
|
class _AsyncDB:
    """Async facade over the synchronous ``db`` module.

    Looking up ``adb.<name>`` yields a coroutine function that runs
    ``db.<name>`` in a worker thread through ``asyncio.to_thread()``, so MySQL
    I/O never blocks the event loop. The generated wrapper is stored on the
    instance, so ``__getattr__`` only runs the first time a name is used.
    """

    def __getattr__(self, name: str):
        target = getattr(db, name)

        async def _call_in_thread(*args, **kwargs):
            return await asyncio.to_thread(target, *args, **kwargs)

        _call_in_thread.__name__ = name
        _call_in_thread.__qualname__ = f"_AsyncDB.{name}"
        # Cache on the instance: later lookups find it in __dict__ directly.
        setattr(self, name, _call_in_thread)
        return _call_in_thread


# Shared proxy used by the handlers for all database access.
adb = _AsyncDB()
|
|
|
|
|
|
# Connected clients: user_id -> list[ProtocolWriter]
# (a list because one user can be connected from several devices at once).
connected_clients: dict[str, list[ProtocolWriter]] = {}

# Writer -> device_id mapping (id(writer) -> device_id); keyed by object id,
# not by the writer object itself.
writer_device_map: dict[int, str] = {}

# Pairing sessions: code -> data
pairing_sessions: dict[str, dict] = {}

# Registrations awaiting email-code confirmation: code -> pending data
# (filled by handle_register_start, consumed by handle_register_confirm).
pending_registrations: dict[str, dict] = {}

# Used PoW challenges (prevents replay within validity window)
_used_pow_challenges: dict[str, float] = {}  # challenge -> used_at

# Pending image uploads: file_id -> {temp_path, received_bytes, file_size, conv_id}
pending_uploads: dict[str, dict] = {}

# Phantom user IDs (loaded at startup, updated on create/delete)
phantom_user_ids: set[str] = set()

# Locks for shared mutable state (H4 race condition fix)
_clients_lock = asyncio.Lock()  # Protects: connected_clients, writer_device_map, phantom_user_ids
_conn_lock = asyncio.Lock()  # Protects: connection_counts, current_connections, rate_limits
_pairing_lock = asyncio.Lock()  # Protects: pairing_sessions, pending_registrations, _used_pow_challenges
_uploads_lock = asyncio.Lock()  # Protects: pending_uploads
_phantom_lock = asyncio.Lock()  # Serializes phantom user creation (cap check + DB create + set add)

# Root directory for uploaded files; overridable via the UPLOAD_DIR env var
# (default: "uploads", relative to the working directory).
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "uploads"))
|
|
|
|
|
|
def _secure_delete(p: Path, chunk_size: int = 1024 * 1024):
    """Overwrite a file with random data, then delete it (anti-forensic wipe).

    The overwrite is written in ``chunk_size`` pieces, so wiping a large
    upload never allocates a buffer the size of the whole file. This is
    best-effort. If any step of the wipe fails, the file is still unlinked,
    and errors from that fallback unlink are ignored.

    Args:
        p: File to wipe and remove. A missing file is a no-op.
        chunk_size: Random bytes written per ``write()`` call (minimum 1).
    """
    try:
        if not p.exists():
            return
        remaining = p.stat().st_size
        if remaining > 0:
            step = max(1, chunk_size)
            with open(p, "r+b") as f:
                while remaining > 0:
                    n = min(step, remaining)
                    f.write(os.urandom(n))
                    remaining -= n
                f.flush()
                os.fsync(f.fileno())
        p.unlink()
    except Exception:
        # Wipe failed part-way; still make sure the file is gone.
        try:
            p.unlink(missing_ok=True)
        except Exception:
            pass
|
|
|
|
|
|
# C6 fix: UUID validation + safe path construction to prevent path traversal
_UUID_RE = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE)


def _valid_uuid(value: str) -> bool:
    """Return True only if *value* is exactly an 8-4-4-4-12 hex UUID string.

    ``fullmatch`` is required here. With ``re.match``, ``$`` also matches just
    before a trailing newline, so ``"<uuid>\\n"`` would pass. Non-string input
    (for example a number from a JSON message) returns False instead of
    raising.
    """
    return isinstance(value, str) and _UUID_RE.fullmatch(value) is not None
|
|
|
|
|
|
# L8 fix: email validation to prevent phantom DB inflation
_EMAIL_RE = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")


def _valid_email(email: str) -> bool:
    """Validate basic email format (L8).

    Returns True for strings of at most 254 characters that fully match
    ``_EMAIL_RE``. ``fullmatch`` stops a trailing newline from slipping past
    the ``$`` anchor. The length is checked before the regex runs.
    Non-string input returns False rather than raising.
    """
    return (
        isinstance(email, str)
        and len(email) <= 254
        and _EMAIL_RE.fullmatch(email) is not None
    )
|
|
|
|
|
|
# C2 fix: ratchet/x3dh header validation
|
|
_RATCHET_HEADER_KEYS = {"dh_pub", "n", "pn"}
|
|
_MAX_HEADER_BYTES = 4096
|
|
|
|
|
|
def _validate_header(raw, name: str) -> bytes | None:
|
|
"""Validate and serialize a ratchet/x3dh header.
|
|
|
|
Accepts only dict with expected keys, rejects str/bytes to prevent
|
|
poisoned headers from being stored. Validates that ratchet headers
|
|
contain the required keys (dh_pub, n, pn) with correct types.
|
|
Returns UTF-8 encoded JSON bytes or None if invalid.
|
|
"""
|
|
if not isinstance(raw, dict):
|
|
return None
|
|
serialized = json.dumps(raw)
|
|
if len(serialized) > _MAX_HEADER_BYTES:
|
|
return None
|
|
# Validate ratchet header keys/types if this looks like one
|
|
if name in ("ratchet_header", "recipient_ratchet_header"):
|
|
# Accept self-encrypted marker {"self": true}
|
|
if raw.get("self") is True and len(raw) == 1:
|
|
return serialized.encode()
|
|
if not _RATCHET_HEADER_KEYS.issubset(raw.keys()):
|
|
return None
|
|
if not isinstance(raw.get("dh_pub"), str):
|
|
return None
|
|
if type(raw.get("n")) is not int or type(raw.get("pn")) is not int:
|
|
return None
|
|
return serialized.encode()
|
|
|
|
|
|
def _append_file(path: Path, data: bytes):
    """Write *data* at the end of *path*, creating the file if it does not exist.

    Blocking call; meant to be run in a worker thread so disk I/O stays off
    the event loop.
    """
    with open(path, mode="ab") as sink:
        sink.write(data)
|
|
|
|
|
|
def _read_file_chunk(path: Path, offset: int, size: int) -> bytes:
    """Return up to *size* bytes of *path*, starting at byte *offset*.

    Blocking call; meant to be run in a worker thread. Fewer bytes (possibly
    ``b""``) come back when the range runs past end of file.
    """
    with open(path, "rb") as src:
        src.seek(offset)
        chunk = src.read(size)
    return chunk
|
|
|
|
|
|
def _safe_upload_path(file_id: str, suffix: str) -> Path | None:
|
|
"""Return resolved path inside UPLOAD_DIR, or None if traversal detected."""
|
|
p = (UPLOAD_DIR / f"{file_id}{suffix}").resolve()
|
|
if not p.is_relative_to(UPLOAD_DIR.resolve()):
|
|
return None
|
|
return p
|
|
|
|
|
|
def _safe_avatar_path(filename: str) -> Path | None:
|
|
"""Return resolved avatar path inside UPLOAD_DIR/avatars, or None if traversal detected."""
|
|
avatar_dir = (UPLOAD_DIR / "avatars").resolve()
|
|
p = (UPLOAD_DIR / "avatars" / filename).resolve()
|
|
if not p.is_relative_to(avatar_dir):
|
|
return None
|
|
return p
|
|
|
|
|
|
PAIRING_TTL_SECONDS = 300  # pairing sessions older than this are dropped by _cleanup_pairings
REGISTER_TTL_SECONDS = 600  # 10 min (was 3600) — faster slot release under load
PAIRING_MAX_POLL_ATTEMPTS = 90
PAIRING_MAX_SESSIONS = 100  # global cap on concurrent pairing sessions
MAX_PENDING_REGISTRATIONS = 1000  # global cap on pending registration codes
MAX_PENDING_PER_IP = 5  # per-IP cap on pending registrations
MAX_PENDING_PER_SUBNET = 20  # per-/24 (IPv4) or /64 (IPv6) cap
REGISTRATION_PRESSURE_THRESHOLD = 0.8  # at 80% of MAX_PENDING_REGISTRATIONS, register_start requires PoW
POW_DIFFICULTY = 20  # leading zero bits in SHA-256 (~1M hashes, ~0.5-2s)
SMTP_RATE_GLOBAL = 30  # registration emails per minute (global)
SMTP_RATE_PER_IP = 3  # registration emails per minute (per IP)
SMTP_RATE_PER_TARGET = 2  # registration emails per minute (per target email)
MAX_PHANTOM_USERS = 500  # global cap on phantom user count
MAX_UPLOADS_GLOBAL = 200  # global cap on concurrent in-flight uploads
MAX_UPLOADS_PER_USER = 5  # per-user cap on concurrent in-flight uploads
UPLOAD_STALE_SECONDS = 600  # stale upload threshold (10 min)

# SMTP configuration for registration codes (an empty SMTP_HOST disables sending)
SMTP_HOST = os.getenv("SMTP_HOST", "")
SMTP_PORT = int(os.getenv("SMTP_PORT", "587"))  # plain connect, upgraded with STARTTLS
SMTP_USER = os.getenv("SMTP_USER", "")  # login is skipped when empty
SMTP_PASS = os.getenv("SMTP_PASS", "")
SMTP_FROM = os.getenv("SMTP_FROM", "")  # falls back to SMTP_USER when empty
RATE_LIMIT_WINDOW = 60.0  # seconds; sliding window used by _is_rate_limited
CONNECTION_RL_WINDOW = 1.0  # seconds
CONNECTION_RL_MAX = 20  # max requests per window per connection
MAX_CONNECTIONS_PER_IP = 10
MAX_CONNECTIONS_GLOBAL = 200
METADATA_RETENTION_DAYS = int(os.getenv("METADATA_RETENTION_DAYS", "90"))
# TCP keepalive settings (seconds)
TCP_KEEPALIVE_IDLE = 25  # Start keepalive probes after 25s of idle
TCP_KEEPALIVE_INTERVAL = 10  # Send probes every 10s
TCP_KEEPALIVE_COUNT = 3  # Mark dead after 3 missed probes (25 + 3*10 = 55s max)
|
|
|
|
|
|
def _resolve_log_level(level_name: str, default: int = logging.WARNING) -> int:
    """Map a level name such as ``"debug"`` to its numeric logging level.

    ``getattr(logging, NAME)`` alone is not enough. Some upper-case module
    attributes are not levels; for example ``BASIC_FORMAT`` is a str, and
    passing it to ``basicConfig`` raises. Any non-integer result, including
    bools, falls back to *default*.
    """
    level = getattr(logging, level_name.upper(), None)
    if isinstance(level, int) and not isinstance(level, bool):
        return level
    return default


def setup_logging():
    """Configure root logging from the LOG_LEVEL environment variable.

    LOG_LEVEL defaults to ``INFO``; unknown or non-level names fall back to
    WARNING.
    """
    level = _resolve_log_level(os.getenv("LOG_LEVEL", "INFO"))
    logging.basicConfig(level=level, format="%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")


logger = logging.getLogger("encrypted_chat.server")
|
|
|
|
|
|
def _who(session: dict | None) -> str:
|
|
"""Format session info for logging: truncated user_id + device prefix.
|
|
|
|
Avoids leaking usernames and emails into log files.
|
|
"""
|
|
if not session:
|
|
return "<anon>"
|
|
uid = session.get("user_id", "?")[:8]
|
|
dev = session.get("device_id", "")[:8] if session.get("device_id") else ""
|
|
return f"u={uid} d={dev}" if dev else f"u={uid}"
|
|
|
|
|
|
# Sliding-window rate-limit log used by _is_rate_limited:
# key -> event-loop timestamps of recent allowed hits.
rate_limits: dict[str, list[float]] = {}
# NOTE(review): presumably open-connection count per peer IP; maintained
# outside this section (guarded by _conn_lock).
connection_counts: dict[str, int] = {}
# NOTE(review): presumably total open connections; maintained outside this section.
current_connections = 0
|
|
|
|
|
|
def _rate_limit_key(action: str, addr: str, email: str | None = None) -> str:
|
|
if email:
|
|
return f"{action}|{addr}|{email.lower()}"
|
|
return f"{action}|{addr}"
|
|
|
|
|
|
def _normalize_email(email: str | None) -> str:
|
|
if not email:
|
|
return ""
|
|
return email.strip().lower()
|
|
|
|
|
|
async def _is_rate_limited(key: str, limit: int) -> bool:
    """Record a hit for *key* and report whether it exceeds *limit* per window.

    Uses a sliding window of RATE_LIMIT_WINDOW seconds on the event-loop
    clock. A hit is recorded only when it is allowed; rejected hits do not
    extend the penalty.

    Keys embed client-controlled data (IP, email), so the table would grow
    without bound if idle keys were never removed. At most once per window,
    every key whose newest hit is older than the window is dropped.

    Returns:
        True if the caller should be rejected, False if the hit was allowed.
    """
    async with _conn_lock:
        now = asyncio.get_event_loop().time()
        window_start = now - RATE_LIMIT_WINDOW
        last_sweep = getattr(_is_rate_limited, "_last_sweep", float("-inf"))
        if now - last_sweep >= RATE_LIMIT_WINDOW:
            # Timestamps are appended in order, so ts[-1] is the newest hit.
            idle = [k for k, ts in rate_limits.items() if not ts or ts[-1] < window_start]
            for k in idle:
                del rate_limits[k]
            _is_rate_limited._last_sweep = now
        times = [t for t in rate_limits.get(key, []) if t >= window_start]
        if len(times) >= limit:
            rate_limits[key] = times
            return True
        times.append(now)
        rate_limits[key] = times
        return False
|
|
|
|
|
|
async def _create_phantom_guarded(email: str, addr: str, user_id: str) -> tuple[dict | None, str]:
    """Create a phantom user for *email* unless a limit blocks it.

    Checks run in this order:

    1. Per-requester (*user_id*) and per-IP (*addr*) rate limits, 10 per
       window each. These run before _phantom_lock is taken because they
       acquire _conn_lock.
    2. The global MAX_PHANTOM_USERS cap, checked together with the DB insert
       while _phantom_lock is held, so concurrent callers cannot overshoot it.

    Returns:
        ``(user, "")`` on success, ``(None, reason)`` when rejected.
    """
    too_many = "Too many new contacts. Try later."
    for rl_key in (f"phantom_create|{user_id}", f"phantom_create_ip|{addr}"):
        if await _is_rate_limited(rl_key, 10):
            return None, too_many
    async with _phantom_lock:
        async with _clients_lock:
            at_cap = len(phantom_user_ids) >= MAX_PHANTOM_USERS
        if at_cap:
            return None, "Server limit reached. Try later."
        created = await adb.create_phantom_user(email)
        async with _clients_lock:
            phantom_user_ids.add(created["id"])
    return created, ""
|
|
|
|
|
|
def _get_peer_addr(writer: ProtocolWriter) -> str:
    """Best-effort remote address of *writer*: the first item of the transport's peername.

    Returns "unknown" when there is no peername or it cannot be indexed.
    """
    try:
        peername = writer._writer.get_extra_info("peername")
        host = str(peername[0])
    except Exception:
        return "unknown"
    return host
|
|
|
|
|
|
async def _remove_dead_writer(w: ProtocolWriter):
    """Forget writer *w*: drop its device mapping and remove it from every user's list.

    Users left with no writers are deleted from connected_clients.
    """
    async with _clients_lock:
        writer_device_map.pop(id(w), None)
        emptied = []
        for uid in list(connected_clients):
            survivors = [other for other in connected_clients[uid] if other is not w]
            if survivors:
                connected_clients[uid] = survivors
            else:
                emptied.append(uid)
        for uid in emptied:
            del connected_clients[uid]
|
|
|
|
|
|
async def _notify_users(user_ids, msg_type, data, exclude_writer=None):
    """Push a ``msg_type`` / "ok" message with *data* to every writer of *user_ids*.

    Writers are snapshotted under _clients_lock, and the messages are sent
    after the lock is released. *exclude_writer* is skipped. Writers that are
    closing, or whose send raises, are removed via _remove_dead_writer
    afterwards.
    """
    async with _clients_lock:
        recipients = [w for uid in user_ids for w in connected_clients.get(uid, [])]
    failed = []
    for w in recipients:
        if w is exclude_writer:
            continue
        try:
            if w.is_closing():
                failed.append(w)
            else:
                await w.send_response(msg_type, "ok", data)
        except Exception:
            logger.debug("[NOTIFY] Failed to send %s, marking writer dead", msg_type)
            failed.append(w)
    for w in failed:
        await _remove_dead_writer(w)
|
|
|
|
|
|
async def _notify_users_individual(notifications, exclude_writer=None):
    """Send per-user payloads to every connected writer of each user.

    Args:
        notifications: Iterable of ``(user_id, msg_type, data)`` tuples.
        exclude_writer: Writer to skip (e.g. the originating connection).

    Returns:
        Set of user_ids for which at least one send succeeded.

    Writers are snapshotted under _clients_lock and messaged outside it. A
    writer that is closing, or whose send raises, is removed exactly once via
    _remove_dead_writer. It is not retried for later notifications in the
    same batch.
    """
    async with _clients_lock:
        targets = [
            (w, mt, d, uid)
            for uid, mt, d in notifications
            for w in connected_clients.get(uid, [])
        ]
    dead: dict[int, object] = {}  # id(writer) -> writer; dedupes failures/removals
    sent = 0
    skipped = 0
    delivered_users = set()
    saw_new_message = False
    for w, mt, d, uid in targets:
        if mt == "new_message":
            saw_new_message = True
        if w is exclude_writer:
            skipped += 1
            continue
        if id(w) in dead:
            # Already failed earlier in this batch; don't hit it again.
            continue
        try:
            if w.is_closing():
                dead[id(w)] = w
                logger.warning("[NOTIFY] Writer for u=%s is closing, removing", uid[:8])
                continue
            await w.send_response(mt, "ok", d)
            sent += 1
            delivered_users.add(uid)
        except Exception as e:
            logger.warning("[NOTIFY] Failed to send %s to u=%s: %s", mt, uid[:8], e)
            dead[id(w)] = w
    for w in dead.values():
        await _remove_dead_writer(w)
    if saw_new_message:
        logger.debug("[NOTIFY] %s: sent=%d skipped=%d dead=%d", "new_message", sent, skipped, len(dead))
    return delivered_users
|
|
|
|
|
|
async def _cleanup_pairings():
    """Remove pairing sessions whose age exceeds PAIRING_TTL_SECONDS (event-loop clock)."""
    async with _pairing_lock:
        now = asyncio.get_event_loop().time()
        stale_codes = [
            code
            for code, session in pairing_sessions.items()
            if now - session["created_at"] > PAIRING_TTL_SECONDS
        ]
        while stale_codes:
            pairing_sessions.pop(stale_codes.pop(), None)
|
|
|
|
|
|
async def _cleanup_registrations():
    """Expire old pending registrations and forget old used-PoW records.

    Registrations older than REGISTER_TTL_SECONDS are dropped. Used-challenge
    entries older than 120 s are purged; 120 s is the freshness window that
    _verify_pow enforces on challenge timestamps.
    """
    async with _pairing_lock:
        now = asyncio.get_event_loop().time()
        expired_codes = [
            code
            for code, entry in pending_registrations.items()
            if now - entry["created_at"] > REGISTER_TTL_SECONDS
        ]
        for code in expired_codes:
            pending_registrations.pop(code, None)
        old_challenges = [ch for ch, used_at in _used_pow_challenges.items() if now - used_at > 120]
        for ch in old_challenges:
            _used_pow_challenges.pop(ch, None)
|
|
|
|
|
|
def _generate_pairing_code() -> str:
    """Return a random 8-digit pairing code, preferring one not in pairing_sessions.

    ``secrets.randbelow`` gives a uniform draw. Reducing random bytes with
    ``% 10**8`` would make low codes more likely. Up to 10 draws are tried to
    avoid a code already in use; after that a final draw is returned
    unchecked.
    """
    for _ in range(10):
        code = f"{secrets.randbelow(10**8):08d}"
        if code not in pairing_sessions:
            return code
    return f"{secrets.randbelow(10**8):08d}"
|
|
|
|
|
|
def _generate_register_code() -> str:
    """Return a random 6-digit registration code, preferring one not already pending.

    ``secrets.randbelow`` gives a uniform draw. ``urandom(3) % 10**6`` would
    make codes below 777216 about 6% more likely, shrinking the effective
    search space of an emailed secret. Up to 10 draws are tried to avoid a
    pending code; after that a final draw is returned unchecked.
    """
    for _ in range(10):
        code = f"{secrets.randbelow(10**6):06d}"
        if code not in pending_registrations:
            return code
    return f"{secrets.randbelow(10**6):06d}"
|
|
|
|
def _validate_public_key_pem(pem_str: str) -> bool:
    """Return True if *pem_str* loads via load_public_key with ``key_size >= 2048``.

    Any failure (unparsable PEM, a key object without ``key_size``) yields
    False. NOTE(review): the RSA-only restriction depends on
    crypto_utils.load_public_key; confirm it rejects non-RSA keys.
    """
    try:
        return load_public_key(pem_str.encode("utf-8")).key_size >= 2048
    except Exception:
        return False
|
|
|
|
|
|
def _send_registration_email(to_email: str, code: str) -> bool:
    """Email a registration *code* to *to_email* over SMTP with STARTTLS.

    Returns True once the SMTP server accepts the message. Returns False if
    SMTP is not configured (empty SMTP_HOST) or any step fails; failures are
    logged, not raised.

    The recipient is passed to ``send_message`` explicitly as
    ``[to_email]``. Without this, smtplib derives the envelope recipients by
    parsing the To: header, and a comma-separated value would fan out to
    several mailboxes. The expiry text in the body is derived from
    REGISTER_TTL_SECONDS.
    """
    if not SMTP_HOST:
        return False
    try:
        ttl_minutes = REGISTER_TTL_SECONDS // 60
        msg = MIMEText(f"Your registration code is: {code}\n\nThis code expires in {ttl_minutes} minutes.")
        msg["Subject"] = "Encrypted Chat - Registration Code"
        msg["From"] = SMTP_FROM or SMTP_USER
        msg["To"] = to_email
        with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=10) as server:
            # RFC-style STARTTLS flow: advertise capabilities pre/post TLS upgrade.
            server.ehlo()
            server.starttls(context=ssl.create_default_context())
            server.ehlo()
            if SMTP_USER:
                server.login(SMTP_USER, SMTP_PASS)
            server.send_message(msg, to_addrs=[to_email])
        return True
    except Exception as e:
        logger.warning("Failed to send registration email: %s", e)
        return False
|
|
|
|
|
|
async def send_resp(msg: dict, writer: ProtocolWriter, msg_type: str, status: str, data: dict | None = None):
    """Reply on *writer*, echoing the request's ``request_id`` (None when the request had none)."""
    request_id = msg.get("request_id")
    await writer.send_response(msg_type, status, data, request_id=request_id)
|
|
|
|
|
|
# --- Registration admission control ---

# HMAC key that authenticates stateless PoW challenges (see _generate_pow_challenge /
# _verify_pow). Random per process, never persisted.
_POW_SECRET = os.urandom(32)  # per-process; restarts invalidate outstanding challenges
|
|
|
|
|
|
def _get_subnet(addr: str) -> str:
    """Return the /24 (IPv4) or /64 (IPv6) network containing *addr*.

    IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``, as reported by dual-stack
    listeners) are unwrapped to IPv4 first. Otherwise every IPv4 client would
    land in the single network ``::/64`` and share one per-subnet quota.
    Strings that are not IP addresses (e.g. ``"unknown"``) are returned
    unchanged.
    """
    try:
        ip = ipaddress.ip_address(addr)
    except ValueError:
        return addr
    if ip.version == 6 and ip.ipv4_mapped is not None:
        ip = ip.ipv4_mapped
    prefix = 24 if ip.version == 4 else 64
    try:
        return str(ipaddress.ip_network(f"{ip}/{prefix}", strict=False))
    except ValueError:
        return addr
|
|
|
|
|
|
def _pending_counts_by_origin(addr: str) -> tuple[int, int]:
    """Count pending registrations from *addr* and from its subnet.

    Returns ``(same_ip, same_subnet)``; the subnet count includes the
    same-IP entries. Caller must hold _pairing_lock.
    """
    own_subnet = _get_subnet(addr)
    origins = [entry.get("addr", "") for entry in pending_registrations.values()]
    same_ip = sum(1 for origin in origins if origin == addr)
    same_subnet = sum(1 for origin in origins if _get_subnet(origin) == own_subnet)
    return same_ip, same_subnet
|
|
|
|
|
|
def _generate_pow_challenge() -> tuple[str, str]:
    """Issue a stateless proof-of-work challenge.

    Returns ``(challenge, mac)``:

    - ``challenge`` is ``"<whole seconds of event-loop time>:<32 hex chars>"``.
      The timestamp lets _verify_pow reject stale solutions.
    - ``mac`` is the hex HMAC-SHA256 of the challenge under _POW_SECRET. It
      proves this server process issued the challenge.
    """
    issued_at = int(asyncio.get_event_loop().time())
    challenge = f"{issued_at}:{secrets.token_hex(16)}"
    mac = hmac.new(_POW_SECRET, challenge.encode(), hashlib.sha256).hexdigest()
    return challenge, mac
|
|
|
|
|
|
def _verify_pow(challenge: str, mac: str, nonce: str, difficulty: int) -> bool:
    """Verify a PoW solution: HMAC authentic, timestamp fresh, hash has leading zeros.

    Args:
        challenge: ``"<ts>:<random>"`` string from _generate_pow_challenge.
        mac: Hex HMAC-SHA256 of *challenge* issued alongside it.
        nonce: Client-chosen value; ``sha256(challenge + nonce)`` must begin
            with *difficulty* zero bits.
        difficulty: Required leading zero bits (clamped to 0..256).

    Returns:
        True only if every check passes. The fields come straight from the
        client, so malformed input returns False rather than raising. This
        covers non-str values, a non-ASCII MAC (``compare_digest`` raises
        TypeError on non-ASCII str) and un-encodable surrogates.
    """
    if not all(isinstance(v, str) for v in (challenge, mac, nonce)):
        return False
    try:
        challenge_b = challenge.encode()
        mac_b = mac.encode()
        nonce_b = nonce.encode()
    except UnicodeEncodeError:
        # Lone surrogates (possible after JSON decoding) can't be a valid solution.
        return False
    # Verify HMAC (compare bytes so arbitrary client input cannot raise).
    expected = hmac.new(_POW_SECRET, challenge_b, hashlib.sha256).hexdigest().encode()
    if not hmac.compare_digest(expected, mac_b):
        return False
    # Check timestamp freshness (120s window)
    try:
        ts = int(challenge.split(":")[0])
    except (ValueError, IndexError):
        return False
    now = int(asyncio.get_event_loop().time())
    if abs(now - ts) > 120:
        return False
    # SHA-256(challenge + nonce) must start with `difficulty` zero bits.
    digest = hashlib.sha256(challenge_b + nonce_b).digest()
    total_bits = 8 * len(digest)
    zero_bits = min(max(difficulty, 0), total_bits)
    return int.from_bytes(digest, "big") >> (total_bits - zero_bits) == 0
|
|
|
|
|
|
async def handle_register_start(msg: dict, writer: ProtocolWriter) -> dict | None:
    """Start a registration: validate input, reserve a code slot, email the code.

    Flow:

    1. Rate limits: 3 per window per (IP, email) and 6 per window per IP.
    2. Field validation:
       - all fields are present;
       - the email is well-formed (_valid_email);
       - the RSA public key PEM is acceptable;
       - the identity key is a 32-byte Ed25519 key.
    3. Admission control under _pairing_lock: global cap, per-IP and
       per-subnet caps, plus a proof-of-work requirement once the table
       reaches REGISTRATION_PRESSURE_THRESHOLD of MAX_PENDING_REGISTRATIONS.
    4. SMTP rate limits, then the code is emailed. In dev mode the code is
       returned directly when mail cannot be sent.

    The email must pass _valid_email before anything is stored or mailed.
    Without that check, a value such as ``"a@x.com, b@y.com"`` could make one
    request mail many recipients while counting as a single target for the
    per-target SMTP limit.

    Emails that already belong to a real user take the same path and get the
    same responses. Their slot is marked ``fake`` so register_confirm always
    rejects it (H3 anti-enumeration).

    Always returns None; the outcome is sent to the client via send_resp.
    """
    await _cleanup_registrations()
    username = msg.get("username", "").strip()
    public_key = msg.get("public_key", "").strip()
    identity_key_b64 = msg.get("identity_key", "").strip()
    email = msg.get("email", "").strip()
    addr = _get_peer_addr(writer)

    if await _is_rate_limited(_rate_limit_key("register_start", addr, email), 3):
        await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."})
        return None
    # Per-IP limit (regardless of email) to prevent SMTP spam via email rotation
    if await _is_rate_limited(f"register_start_ip|{addr}", 6):
        await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."})
        return None

    if not username or not public_key or not email or not identity_key_b64:
        await send_resp(msg, writer, "register_start", "error", {"message": "Missing fields"})
        return None
    # Reject address lists, embedded newlines, etc. before the email is stored or mailed.
    if not _valid_email(email):
        await send_resp(msg, writer, "register_start", "error", {"message": "Invalid email"})
        return None
    if not _validate_public_key_pem(public_key):
        await send_resp(msg, writer, "register_start", "error", {"message": "Invalid public key format"})
        return None

    # Validate identity key is 32 bytes
    try:
        ik_bytes = decode_binary(identity_key_b64)
        if len(ik_bytes) != 32:
            raise ValueError("Identity key must be 32 bytes")
        load_ed25519_public(ik_bytes)
    except Exception:
        await send_resp(msg, writer, "register_start", "error", {"message": "Invalid identity key"})
        return None

    existing_email = await adb.get_user_by_email(email)
    phantom_id = None
    is_existing_real_user = False
    if existing_email:
        if existing_email.get("rsa_public_key") == "PHANTOM":
            phantom_id = existing_email["id"]
        else:
            is_existing_real_user = True

    # --- Admission control (all checks under lock, I/O outside) ---
    # Existing-email goes through the same path so responses are
    # indistinguishable from new-email (H3 anti-enumeration).
    # Both allocate a slot so per-IP/subnet cap counting is identical.
    async with _pairing_lock:
        total = len(pending_registrations)
        # Hard cap
        if total >= MAX_PENDING_REGISTRATIONS:
            reject_reason = "cap"
        else:
            # Per-IP / per-subnet slot limits
            ip_count, subnet_count = _pending_counts_by_origin(addr)
            if ip_count >= MAX_PENDING_PER_IP:
                reject_reason = "ip"
            elif subnet_count >= MAX_PENDING_PER_SUBNET:
                reject_reason = "subnet"
            else:
                reject_reason = None

        # Pressure mode: require PoW when >80% full
        under_pressure = total >= MAX_PENDING_REGISTRATIONS * REGISTRATION_PRESSURE_THRESHOLD
        need_pow = under_pressure and reject_reason is None

        # If PoW required, verify the client's solution (one-time use)
        pow_ok = False
        if need_pow:
            pow_challenge = msg.get("pow_challenge", "")
            pow_mac = msg.get("pow_mac", "")
            pow_nonce = msg.get("pow_nonce", "")
            # Only non-empty strings are considered; other JSON types (e.g. a
            # list) would be unhashable for the replay lookup below.
            if all(isinstance(v, str) and v for v in (pow_challenge, pow_mac, pow_nonce)):
                if pow_challenge in _used_pow_challenges:
                    pow_ok = False  # replay
                elif _verify_pow(pow_challenge, pow_mac, pow_nonce, POW_DIFFICULTY):
                    _used_pow_challenges[pow_challenge] = asyncio.get_event_loop().time()
                    pow_ok = True

        # Decide: admit, challenge, or reject
        if reject_reason:
            admit = False
            send_challenge = False
            code = None
        elif need_pow and not pow_ok:
            admit = False
            send_challenge = True
            code = None
        else:
            # Both existing and new emails allocate a slot so per-IP/subnet
            # counting behaves identically (anti-enumeration via slot side-channel).
            # Existing-email slots are inert — register_confirm silently fails.
            admit = True
            send_challenge = False
            code = _generate_register_code()
            pending_registrations[code] = {
                "username": username,
                "public_key": public_key,
                "identity_key": ik_bytes,
                "email": email,
                "created_at": asyncio.get_event_loop().time(),
                "phantom_id": phantom_id,
                "addr": addr,
                "fake": is_existing_real_user,
            }

    # --- I/O outside lock ---
    if not admit:
        if send_challenge:
            challenge, mac = _generate_pow_challenge()
            await send_resp(msg, writer, "register_start", "pow_required", {
                "challenge": challenge, "mac": mac, "difficulty": POW_DIFFICULTY,
            })
        else:
            await send_resp(msg, writer, "register_start", "error", {"message": "Server busy. Try later."})
        return None

    logger.info("[REGISTER] registration started")
    is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development")

    # SMTP rate limiting
    smtp_blocked = False
    if SMTP_HOST:
        if await _is_rate_limited("smtp_send|global", SMTP_RATE_GLOBAL):
            smtp_blocked = True
        elif await _is_rate_limited(f"smtp_send_ip|{addr}", SMTP_RATE_PER_IP):
            smtp_blocked = True
        elif await _is_rate_limited(f"smtp_send_target|{email.lower()}", SMTP_RATE_PER_TARGET):
            smtp_blocked = True
    if smtp_blocked:
        if is_dev:
            logger.warning("[REGISTER] SMTP rate limit hit — returning code (dev mode)")
            await send_resp(msg, writer, "register_start", "ok", {"code": code})
        else:
            logger.warning("[REGISTER] SMTP rate limit hit — revoking slot silently")
            async with _pairing_lock:
                pending_registrations.pop(code, None)
            await send_resp(msg, writer, "register_start", "ok",
                            {"message": "Code sent to your email."})
        return None

    # Send registration email in a thread (non-blocking) for both real
    # and fake registrations. For existing emails we still call SMTP so
    # the response timing is indistinguishable (anti-enumeration).
    # The email goes to the real address either way — existing users just
    # won't be able to confirm (code is for a fake slot).
    email_sent = await asyncio.to_thread(_send_registration_email, email, code)
    if email_sent:
        await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."})
    elif is_dev:
        logger.warning("[REGISTER] No SMTP / send failed — returning code (dev mode)")
        await send_resp(msg, writer, "register_start", "ok", {"code": code})
    else:
        logger.warning("[REGISTER] SMTP send failed — revoking slot silently")
        async with _pairing_lock:
            pending_registrations.pop(code, None)
        await send_resp(msg, writer, "register_start", "ok",
                        {"message": "Code sent to your email."})
    return None
|
|
|
|
|
|
async def handle_register_confirm(msg: dict, writer: ProtocolWriter) -> dict | None:
    """Finish a registration by redeeming the emailed code.

    The code must exist in pending_registrations and belong to the submitted
    email. It is removed on the first matching lookup, so each code is single
    use. Slots created for an email that already belongs to a real user
    (``fake``) are rejected with the same message as a wrong code (H3).

    If the slot recorded a phantom user, that row is upgraded in place. If
    there is no phantom, or it has since disappeared, a new user is created.
    In either case a default profile is then created.

    Rate limits:

    - 3 per window per (IP, email);
    - 10 per window per email, regardless of IP.

    The per-email limit matters because the code is only 6 digits. IP-keyed
    limits alone can be sidestepped by rotating source addresses (e.g. across
    an IPv6 /64), which would make guessing another person's pending code
    feasible within REGISTER_TTL_SECONDS.

    Always returns None; the result is sent via send_resp.
    """
    await _cleanup_registrations()
    email = msg.get("email", "").strip()
    code = msg.get("code", "").strip()
    addr = _get_peer_addr(writer)

    if await _is_rate_limited(_rate_limit_key("register_confirm", addr, email), 3):
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Too many attempts. Try later."})
        return None
    if not email or not code:
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Missing email or code"})
        return None
    # Per-email cap independent of source address (anti brute-force, see docstring).
    # Trade-off: a flood can delay a legitimate confirm for up to one window.
    if await _is_rate_limited(f"register_confirm_email|{_normalize_email(email)}", 10):
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Too many attempts. Try later."})
        return None

    async with _pairing_lock:
        pending = pending_registrations.get(code)
        if pending and pending.get("email") == email:
            pending_registrations.pop(code, None)
        else:
            pending = None
    if not pending:
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"})
        return None

    # H3 anti-enumeration: fake slot (existing email) — reject with same
    # generic message so attacker can't distinguish from a wrong code
    if pending.get("fake"):
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"})
        return None

    phantom_id = pending.get("phantom_id")
    if phantom_id:
        # Upgrade phantom in-place — preserves FK references (invitations, memberships)
        user_id = await adb.upgrade_phantom_user(
            phantom_id,
            pending["username"],
            pending["public_key"],
            pending["identity_key"],
        )
        if user_id:
            async with _clients_lock:
                phantom_user_ids.discard(phantom_id)
        else:
            # Phantom was deleted concurrently — fall back to normal create
            user_id = await adb.create_user(
                pending["username"],
                pending["email"],
                pending["public_key"],
                pending["identity_key"],
            )
    else:
        user_id = await adb.create_user(
            pending["username"],
            pending["email"],
            pending["public_key"],
            pending["identity_key"],
        )
    await adb.create_default_profile(user_id)
    logger.info("[REGISTER] confirmed (user_id=%s)", user_id[:8])
    await send_resp(msg, writer, "register_confirm", "ok", {"user_id": user_id})
    return None
|
|
|
|
|
|
async def handle_login_start(msg: dict, writer: ProtocolWriter, state: dict):
    """Begin challenge-response login: store a fresh 32-byte challenge in *state* and send it.

    Unknown emails receive an indistinguishable challenge (H3 anti-enumeration).
    ``state["_login_fake"]`` records that case so handle_login_finish fails
    with its generic error. The flag is set or cleared on every call, so an
    earlier attempt for an unknown email (e.g. a typo) cannot make a later
    login for a real account on the same connection fail.
    """
    email = msg.get("email", "").strip()
    normalized_email = _normalize_email(email)
    addr = _get_peer_addr(writer)

    if await _is_rate_limited(_rate_limit_key("login_start", addr, email), 10):
        await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."})
        return
    if await _is_rate_limited(f"login_start_ip|{addr}", 20):
        await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."})
        return
    if not email:
        await send_resp(msg, writer, "login_start", "error", {"message": "Missing email"})
        return

    user = await adb.get_user_by_email(normalized_email)
    challenge = os.urandom(32)
    state["login_email"] = email
    state["login_challenge"] = challenge
    if user:
        # Clear any flag left over from a previous login_start on this connection.
        state.pop("_login_fake", None)
    else:
        # H3 anti-enumeration: return a fake challenge so attacker can't distinguish
        # "user not found" from "user exists". login_finish will fail with generic error.
        state["_login_fake"] = True
    await send_resp(msg, writer, "login_start", "ok", {"challenge": encode_binary(challenge)})
|
|
|
|
|
|
async def handle_login_finish(msg: dict, writer: ProtocolWriter, state: dict) -> dict | None:
    """Complete challenge-response login and register this connection.

    Steps:

    1. Verify the RSA signature over the challenge that handle_login_start
       stored in *state*.
    2. Reject clients older than MIN_CLIENT_VERSION.
    3. Reuse the client's device (if it belongs to this user) or create one.
    4. Add *writer* to connected_clients and notify online contacts; for a
       new device, also notify the user's other sessions.

    Every authentication failure returns the same "Invalid credentials"
    message (H3). The stored challenge is discarded once a signature check is
    attempted, so each challenge is usable at most once. The user is looked
    up by the normalized email, matching handle_login_start.

    Returns:
        Session dict ``{user_id, username, email, device_id}`` on success,
        otherwise None (an error response has already been sent).
    """
    email = msg.get("email", "").strip()
    signature_b64 = msg.get("signature", "")
    challenge = state.get("login_challenge")
    expected_email = state.get("login_email")
    addr = _get_peer_addr(writer)

    if await _is_rate_limited(_rate_limit_key("login_finish", addr, email), 10):
        await send_resp(msg, writer, "login_finish", "error", {"message": "Too many attempts. Try later."})
        return None
    if not email or not signature_b64:
        await send_resp(msg, writer, "login_finish", "error", {"message": "Missing email or signature"})
        return None
    if not challenge or expected_email != email:
        await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
        return None

    # H3: if login_start was for a non-existent user, fail with generic error
    is_fake = state.pop("_login_fake", False)
    try:
        if is_fake:
            await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
            return None
        # Same lookup key as handle_login_start, which decided is_fake.
        user = await adb.get_user_by_email(_normalize_email(email))
        if not user:
            await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
            return None
        public_key = load_public_key(user["rsa_public_key"].encode("utf-8"))
        signature = decode_binary(signature_b64)
        if not rsa_verify(public_key, signature, challenge):
            await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
            return None
    except ValueError:
        # H5: invalid base64 in signature
        await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
        return None
    finally:
        # One attempt per challenge, successful or not.
        state.pop("login_challenge", None)
        state.pop("login_email", None)

    user_id = user["id"]

    # Version check: reject outdated clients
    client_version = msg.get("client_version", "")
    if client_version and not version_gte(client_version, MIN_CLIENT_VERSION):
        await send_resp(msg, writer, "login_finish", "error", {
            "message": f"Client version {client_version} is too old. Minimum required: {MIN_CLIENT_VERSION}",
            "min_version": MIN_CLIENT_VERSION,
            "server_version": VERSION,
        })
        return None

    # Device registration: client may send device_id to reuse an existing device
    client_device_id = msg.get("device_id")
    device_id = None
    new_device_created = False
    # device_name is client-controlled: it is stored and broadcast to the
    # user's other sessions, so require a str and bound its length.
    # NOTE(review): 128 is an assumed bound — align with the devices table column size.
    device_name = msg.get("device_name", "Unknown")
    if isinstance(device_name, str):
        device_name = device_name.strip()[:128]
    if not isinstance(device_name, str) or not device_name:
        device_name = "Unknown"
    if isinstance(client_device_id, str) and client_device_id:
        dev = await adb.get_device(client_device_id)
        if dev and dev["user_id"] == user_id:
            device_id = client_device_id
    if not device_id:
        device_id = await adb.create_device(user_id, device_name)
        new_device_created = True
    await adb.update_device_last_seen(device_id)

    async with _clients_lock:
        if user_id not in connected_clients:
            connected_clients[user_id] = []
        connected_clients[user_id].append(writer)
        writer_device_map[id(writer)] = device_id
    logger.info("[LOGIN] u=%s d=%s client_v=%s",
                user_id[:8], device_id[:8] if device_id else "?", client_version or "unknown")
    await send_resp(msg, writer, "login_finish", "ok", {
        "user_id": user_id, "username": user["username"], "email": user["email"],
        "device_id": device_id, "server_version": VERSION,
    })

    # Send online status notifications
    contacts = await adb.get_user_contacts(user_id)
    online_targets = []
    async with _clients_lock:
        online_contacts = [cid for cid in contacts if cid in connected_clients and connected_clients[cid]]
        # Always notify contacts (handles reconnect where old writer is still lingering)
        for contact_id in contacts:
            for cw in connected_clients.get(contact_id, []):
                online_targets.append(cw)
    await writer.send_response("online_users", "ok", {"user_ids": online_contacts})
    # Send online notifications outside lock
    for cw in online_targets:
        try:
            await cw.send_response("user_online", "ok", {"user_id": user_id})
        except Exception:
            pass

    if new_device_created:
        await _notify_users([user_id], "device_added", {
            "device_id": device_id,
            "device_name": device_name,
            "ip": addr,
            "added_at": datetime.now(timezone.utc).isoformat(),
        }, exclude_writer=writer)

    return {"user_id": user_id, "username": user["username"], "email": user["email"],
            "device_id": device_id}
|
|
|
|
|
|
async def handle_get_user_info(msg: dict, session: dict, writer: ProtocolWriter):
    """Return a user's profile and identity key (used to bootstrap X3DH). Requires login.

    The target is looked up by ``email`` if given, otherwise by ``user_id``.
    Callers may only see themselves or users they share a conversation with.
    Every negative outcome (malformed id, unknown user, no shared conversation)
    answers with the same "User not found" message.
    """
    lookup_email = msg.get("email", "").strip()
    lookup_id = msg.get("user_id", "").strip()
    peer = _get_peer_addr(writer)

    async def _not_found():
        await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"})

    limit_key = _rate_limit_key("get_user_info", peer, lookup_email or lookup_id)
    if await _is_rate_limited(limit_key, 30):
        await send_resp(msg, writer, "get_user_info", "error", {"message": "Too many attempts. Try later."})
        return
    if lookup_id and not _valid_uuid(lookup_id):
        await _not_found()
        return

    if lookup_email:
        target = await adb.get_user_by_email(lookup_email)
    elif lookup_id:
        target = await adb.get_user_by_id(lookup_id)
    else:
        target = None
    if not target:
        await _not_found()
        return

    # H4 fix: restrict lookups to self or contacts (shared conversation)
    requester = session["user_id"]
    if target["id"] != requester and not await adb.shares_conversation(requester, target["id"]):
        await _not_found()
        return

    identity = target.get("identity_key")
    await send_resp(msg, writer, "get_user_info", "ok", {
        "user_id": target["id"],
        "username": target["username"],
        "email": target["email"],
        "identity_key": encode_binary(identity) if identity else "",
    })
|
|
|
|
|
|
async def handle_upload_prekeys(msg: dict, session: dict, writer: ProtocolWriter):
    """Upload signed prekey + batch of one-time prekeys.

    Request fields:
        signed_prekey: dict with ``id``, ``public_key`` (base64) and
            ``signature`` (base64). The signature is verified against the
            caller's stored Ed25519 identity key.
        one_time_prekeys: optional list of dicts with ``id`` and
            ``public_key`` (base64). Entries missing either field are skipped.

    Keys are stored for the session's device.

    All fields are decoded before anything is written. Malformed input gets an
    error reply and leaves the stored keys untouched. Malformed means a
    non-dict SPK, a non-list OTP batch, or undecodable key data. Previously
    such input raised out of the handler, possibly after the SPK had already
    been stored.
    """
    user_id = session["user_id"]
    if await _is_rate_limited(f"upload_prekeys|{user_id}", 5):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Too many requests. Try later."})
        return

    spk_data = msg.get("signed_prekey")
    otps = msg.get("one_time_prekeys", [])
    if not spk_data:
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Missing signed_prekey"})
        return
    if not isinstance(spk_data, dict):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid signed_prekey"})
        return
    if otps is None:
        otps = []
    if not isinstance(otps, list):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid one_time_prekeys"})
        return

    spk_id = spk_data.get("id", "")
    spk_pub_b64 = spk_data.get("public_key", "")
    spk_sig_b64 = spk_data.get("signature", "")
    if not spk_id or not spk_pub_b64 or not spk_sig_b64:
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Incomplete signed_prekey"})
        return

    # Decode everything before touching the DB so a bad entry cannot leave a
    # half-applied upload. The login handler treats ValueError from decoding as
    # invalid base64; TypeError covers fields that are not strings at all.
    try:
        spk_pub = decode_binary(spk_pub_b64)
        spk_sig = decode_binary(spk_sig_b64)
        otp_records = [
            {"id": otp["id"], "public_key": decode_binary(otp["public_key"])}
            for otp in otps
            if isinstance(otp, dict) and otp.get("id", "") and otp.get("public_key", "")
        ]
    except (ValueError, TypeError):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid key encoding"})
        return

    # Verify SPK signature with user's identity key
    user = await adb.get_user_by_id(user_id)
    if not user or not user.get("identity_key"):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "No identity key"})
        return
    ik_pub = load_ed25519_public(user["identity_key"])
    if not ed25519_verify(ik_pub, spk_sig, spk_pub):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid SPK signature"})
        return

    device_id = session.get("device_id")
    await adb.store_signed_prekey(user_id, spk_id, spk_pub, spk_sig, device_id=device_id)
    if otp_records:
        await adb.store_one_time_prekeys(user_id, otp_records, device_id=device_id)

    logger.info("[PREKEYS] %s uploaded 1 SPK + %d OTPs", _who(session), len(otp_records))
    await send_resp(msg, writer, "upload_prekeys", "ok", {"message": "OK"})
|
|
|
|
|
|
async def handle_get_key_bundle(msg: dict, session: dict, writer: ProtocolWriter):
    """Fetch key bundle for X3DH. Returns per-device bundles. Consumes one OTP per device.

    The caller may fetch their own bundle or that of a user they share a
    conversation with. Requests are rate-limited per requester and, once
    authorized, per target, which keeps one-time prekeys from being drained.
    The reply holds a ``device_bundles`` list. For older clients it also
    mirrors the first device's bundle in flat legacy fields.
    """
    async def _deny(text):
        await send_resp(msg, writer, "get_key_bundle", "error", {"message": text})

    def _device_entry(bundle):
        # One entry per device; the OTP fields appear only when an OTP was handed out
        item = {
            "device_id": bundle.get("device_id"),
            "signed_prekey_id": bundle["signed_prekey_id"],
            "signed_prekey": encode_binary(bundle["signed_prekey_pub"]),
            "spk_signature": encode_binary(bundle["spk_signature"]),
        }
        if bundle.get("opk_pub"):
            item["one_time_prekey_id"] = bundle["opk_id"]
            item["one_time_prekey"] = encode_binary(bundle["opk_pub"])
        return item

    target = msg.get("user_id", "").strip()
    if not target:
        await _deny("Missing user_id")
        return
    if not _valid_uuid(target):
        await _deny("Invalid user_id")
        return

    requester = session["user_id"]
    # M4: rate limit + authorization (prevents OPK depletion)
    if await _is_rate_limited(f"get_key_bundle|{requester}", 10):
        await _deny("Too many requests. Try later.")
        return
    # Auth check before per-target rate limit so unauthorized requests don't burn target's bucket
    if target != requester and not await adb.shares_conversation(requester, target):
        await _deny("Key bundle not available")
        return
    if await _is_rate_limited(f"get_key_bundle_target|{target}", 20):
        await _deny("Too many requests. Try later.")
        return

    fetched = await adb.get_key_bundles_for_user(target)
    if not fetched or not fetched.get("device_bundles"):
        await _deny("Key bundle not available")
        return

    bundles = [_device_entry(b) for b in fetched["device_bundles"]]
    lead = bundles[0] if bundles else {}

    payload = {
        "identity_key": encode_binary(fetched["identity_key"]),
        "device_bundles": bundles,
        # Legacy flat fields from first device bundle (backward compat)
        "signed_prekey_id": lead.get("signed_prekey_id", ""),
        "signed_prekey": lead.get("signed_prekey", ""),
        "spk_signature": lead.get("spk_signature", ""),
    }
    if lead.get("one_time_prekey"):
        payload["one_time_prekey_id"] = lead["one_time_prekey_id"]
        payload["one_time_prekey"] = lead["one_time_prekey"]

    logger.info("[X3DH] %s fetched key bundle for user=%s (%d devices)",
                _who(session), target[:8], len(bundles))
    await send_resp(msg, writer, "get_key_bundle", "ok", payload)
|
|
|
|
|
|
async def handle_get_prekey_count(msg: dict, session: dict, writer: ProtocolWriter):
    """Report the session device's remaining OPK count and its SPK creation time.

    ``spk_created_at`` is the timestamp's ``isoformat()`` if it has one,
    otherwise its ``str()``. It is ``""`` when no SPK or timestamp is stored.
    Clients use it to decide when to rotate the SPK.
    """
    uid = session["user_id"]
    dev = session.get("device_id")

    remaining = await adb.count_one_time_prekeys(uid, device_id=dev)
    signed = await adb.get_signed_prekey(uid, device_id=dev)

    created = signed.get("created_at") if signed else None
    if not created:
        created_text = ""
    elif hasattr(created, "isoformat"):
        created_text = created.isoformat()
    else:
        created_text = str(created)

    await send_resp(msg, writer, "get_prekey_count", "ok",
                    {"count": remaining, "spk_created_at": created_text})
|
|
|
|
|
|
async def handle_ensure_prekeys(msg: dict, session: dict, writer: ProtocolWriter):
    """Combined get_prekey_count + upload_prekeys in one round-trip.

    Client sends current OPK/SPK data; server checks count and SPK age,
    stores new keys if provided, and returns the current status.

    Uploading is optional and lenient:
      * An SPK with missing fields, or whose signature does not verify against
        the caller's identity key, is ignored and reported as
        ``uploaded_spk: False``.
      * OTP entries missing ``id`` or ``public_key`` are skipped.

    Input that cannot be decoded at all is rejected with an error reply before
    anything is stored; previously it raised out of the handler. That covers a
    non-dict ``signed_prekey``, a non-list ``one_time_prekeys``, or undecodable
    key data.

    Reply fields: ``count`` (OTPs left for this device), ``spk_created_at``
    (timestamp string or ``""``), ``uploaded_spk``, ``uploaded_otps``.
    """
    if await _is_rate_limited(f"ensure_prekeys|{session['user_id']}", 5):
        await send_resp(msg, writer, "ensure_prekeys", "error", {"message": "Too many requests. Try later."})
        return

    device_id = session.get("device_id")
    user_id = session["user_id"]

    def _spk_created_at(spk) -> str:
        """Render a stored SPK row's created_at for the reply ("" when absent)."""
        if spk and spk.get("created_at"):
            ts = spk["created_at"]
            return ts.isoformat() if hasattr(ts, "isoformat") else str(ts)
        return ""

    async def _reject(text: str) -> None:
        await send_resp(msg, writer, "ensure_prekeys", "error", {"message": text})

    spk_data = msg.get("signed_prekey")
    otps = msg.get("one_time_prekeys", []) or []
    if spk_data and not isinstance(spk_data, dict):
        await _reject("Invalid signed_prekey")
        return
    if not isinstance(otps, list):
        await _reject("Invalid one_time_prekeys")
        return

    # Decode up front so malformed data is rejected before any DB write.
    # ValueError = bad base64 (as the login handler assumes); TypeError = non-string field.
    spk_candidate = None  # (id, pub, sig) when a complete SPK was supplied
    try:
        if spk_data:
            spk_id = spk_data.get("id", "")
            spk_pub_b64 = spk_data.get("public_key", "")
            spk_sig_b64 = spk_data.get("signature", "")
            if spk_id and spk_pub_b64 and spk_sig_b64:
                spk_candidate = (spk_id, decode_binary(spk_pub_b64), decode_binary(spk_sig_b64))
        otp_records = [
            {"id": otp["id"], "public_key": decode_binary(otp["public_key"])}
            for otp in otps
            if isinstance(otp, dict) and otp.get("id", "") and otp.get("public_key", "")
        ]
    except (ValueError, TypeError):
        await _reject("Invalid key encoding")
        return

    # Step 1: Get current count + SPK age
    count = await adb.count_one_time_prekeys(user_id, device_id=device_id)
    spk_created_at = _spk_created_at(await adb.get_signed_prekey(user_id, device_id=device_id))

    # Step 2: If client included new keys, store them
    uploaded_spk = False
    if spk_candidate:
        spk_id, spk_pub, spk_sig = spk_candidate
        user = await adb.get_user_by_id(user_id)
        if user and user.get("identity_key"):
            ik_pub = load_ed25519_public(user["identity_key"])
            if ed25519_verify(ik_pub, spk_sig, spk_pub):
                await adb.store_signed_prekey(user_id, spk_id, spk_pub, spk_sig, device_id=device_id)
                uploaded_spk = True

    uploaded_otps = 0
    if otp_records:
        await adb.store_one_time_prekeys(user_id, otp_records, device_id=device_id)
        uploaded_otps = len(otp_records)

    # Recount after upload
    if uploaded_spk or uploaded_otps:
        count = await adb.count_one_time_prekeys(user_id, device_id=device_id)
        refreshed = _spk_created_at(await adb.get_signed_prekey(user_id, device_id=device_id))
        if refreshed:
            spk_created_at = refreshed
        logger.info("[PREKEYS] %s ensure_prekeys: uploaded SPK=%s, OTPs=%d, new count=%d",
                    _who(session), uploaded_spk, uploaded_otps, count)

    await send_resp(msg, writer, "ensure_prekeys", "ok",
                    {"count": count, "spk_created_at": spk_created_at,
                     "uploaded_spk": uploaded_spk, "uploaded_otps": uploaded_otps})
|
|
|
|
|
|
async def handle_rotate_keys(msg: dict, session: dict, writer: ProtocolWriter):
    """Replace the caller's stored RSA public key and drop their other connections.

    The new PEM must pass ``_validate_public_key_pem``. After the "ok" reply,
    this connection becomes the user's only registered writer in
    ``connected_clients``. Every other writer of theirs is closed, and errors
    while closing are ignored.
    """
    uid = session["user_id"]
    if await _is_rate_limited(f"rotate_keys|{uid}", 3):
        await send_resp(msg, writer, "rotate_keys", "error", {"message": "Too many requests. Try later."})
        return

    new_pem = msg.get("public_key", "").strip()
    problem = None
    if not new_pem:
        problem = "Missing public_key"
    elif not _validate_public_key_pem(new_pem):
        problem = "Invalid public key format"
    if problem is not None:
        await send_resp(msg, writer, "rotate_keys", "error", {"message": problem})
        return

    await adb.update_user_rsa_key(uid, new_pem)
    logger.info("[ROTATE] %s rotated RSA key", _who(session))
    await send_resp(msg, writer, "rotate_keys", "ok", {"message": "OK"})

    # Disconnect other sessions
    async with _clients_lock:
        stale = [w for w in connected_clients.get(uid, []) if w is not writer]
        connected_clients[uid] = [writer]
    for old_writer in stale:
        try:
            old_writer.close()
        except Exception:
            pass
|
|
|
|
|
|
async def handle_change_username(msg: dict, session: dict, writer: ProtocolWriter):
    """Rename the caller and push ``username_changed`` to their connected contacts.

    The requested name is whitespace-stripped and must be 1-100 characters.
    The stored row and the in-memory session are updated before the reply.
    The push goes to every registered connection of each user returned by
    ``adb.get_user_contacts``, and individual send failures are ignored.
    """
    user_id = session["user_id"]
    if await _is_rate_limited(f"change_username|{user_id}", 5):
        await send_resp(msg, writer, "change_username", "error", {"message": "Too many requests. Try later."})
        return

    wanted = msg.get("username", "").strip()
    if not 1 <= len(wanted) <= 100:
        await send_resp(msg, writer, "change_username", "error", {"message": "Invalid username (1-100 chars)"})
        return

    await adb.update_username(user_id, wanted)
    session["username"] = wanted
    logger.info("[ACCOUNT] %s changed username", _who(session))
    await send_resp(msg, writer, "change_username", "ok", {"username": wanted})

    # Notify contacts: snapshot their writers under the lock, send after releasing it
    contact_ids = await adb.get_user_contacts(user_id)
    async with _clients_lock:
        recipients = [cw for cid in contact_ids for cw in connected_clients.get(cid, [])]
    for recipient in recipients:
        try:
            await recipient.send_response("username_changed", "ok", {
                "user_id": user_id, "username": wanted,
            })
        except Exception:
            pass
|
|
|
|
|
|
async def handle_pairing_start(msg: dict, writer: ProtocolWriter):
    """Open a pairing session for a new device (no login required).

    The new device supplies the account email and a temporary X25519 public
    key, which must decode to exactly 32 bytes. It receives a pairing ``code``
    plus a secret ``poll_token`` used later by ``pairing_poll``.

    A session is created even when no account has that email (H4
    anti-enumeration). Such a session can never be claimed.

    A freshly generated code is checked against live sessions and regenerated
    on collision (a bounded number of times). Previously a colliding code
    silently overwrote, and so hijacked, another device's pending pairing.
    """
    await _cleanup_pairings()
    email = msg.get("email", "").strip()
    normalized_email = _normalize_email(email)
    temp_public_key = msg.get("temp_public_key", "").strip()
    temp_key_type = msg.get("temp_key_type", "x25519").strip()
    addr = _get_peer_addr(writer)

    # H4 fix: rate limit per IP only (not per email) to prevent enumeration via email rotation
    if await _is_rate_limited(_rate_limit_key("pairing_start", addr), 10):
        await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."})
        return
    if not email or not temp_public_key:
        await send_resp(msg, writer, "pairing_start", "error", {"message": "Missing email or temp_public_key"})
        return
    if temp_key_type != "x25519":
        await send_resp(msg, writer, "pairing_start", "error", {"message": "Unsupported temp_key_type"})
        return

    try:
        temp_pub_raw = decode_binary(temp_public_key)
        if len(temp_pub_raw) != 32:
            raise ValueError("bad length")
        load_x25519_public(temp_pub_raw)
    except Exception:
        await send_resp(msg, writer, "pairing_start", "error", {"message": "Invalid temp_public_key"})
        return

    user = await adb.get_user_by_email(normalized_email)
    target_user_id = user["id"] if user else None
    poll_token = secrets.token_hex(16)

    code = None  # stays None when the global cap is hit or no free code was found
    async with _pairing_lock:
        # H4 fix: global cap prevents memory exhaustion from dummy sessions
        if len(pairing_sessions) < PAIRING_MAX_SESSIONS:
            # Never overwrite a live session: retry a few times on collision
            for _ in range(8):
                candidate = _generate_pairing_code()
                if candidate not in pairing_sessions:
                    code = candidate
                    break
            if code is not None:
                # H4 fix: always create session (anti-enumeration). For non-existent users
                # the session behaves identically (poll returns ready:false, claim never matches
                # because no real account can log in to claim it). TTL cleanup handles expiry.
                pairing_sessions[code] = {
                    "email": normalized_email,
                    "user_id": target_user_id,
                    "temp_public_key": temp_public_key,
                    "temp_key_type": temp_key_type,
                    "created_at": asyncio.get_event_loop().time(),
                    "payload": None,
                    "poll_token": poll_token,
                }

    if code is None:
        await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."})
        return

    logger.info(
        "[PAIR] start code=%s user=%s pid=%s",
        code[:8],
        target_user_id[:8] if target_user_id else "<none>",
        os.getpid(),
    )
    await send_resp(msg, writer, "pairing_start", "ok", {"code": code, "poll_token": poll_token})
|
|
|
|
|
|
async def handle_pairing_claim(msg: dict, session: dict, writer: ProtocolWriter):
    """Give the logged-in account owner the new device's temporary public key.

    Only the user the pairing code was issued for may claim it. An unknown
    code, a code for a non-existent account, or a code for another account all
    get the same "Invalid or expired code" reply (H4 anti-enumeration).

    Looking up an existing code refreshes its ``created_at``, even if the claim
    is then rejected.
    """
    await _cleanup_pairings()
    code = msg.get("code", "").strip()
    if not code:
        await send_resp(msg, writer, "pairing_claim", "error", {"message": "Missing code"})
        return

    async with _pairing_lock:
        entry = pairing_sessions.get(code)
        if entry:
            owner_id = entry.get("user_id")
            temp_pub = entry["temp_public_key"]
            temp_key_type = entry.get("temp_key_type", "x25519")
            # Extend TTL — re-encryption may run between claim and send
            entry["created_at"] = asyncio.get_event_loop().time()
        else:
            owner_id = None
            temp_pub = None
            temp_key_type = "x25519"

    caller_id = session.get("user_id")
    # H4 fix: unified error message (anti-enumeration)
    if not (entry and owner_id and owner_id == caller_id):
        logger.warning(
            "[PAIR] claim rejected code=%s pid=%s exists=%s target=%s session=%s",
            code[:8],
            os.getpid(),
            bool(entry),
            owner_id[:8] if owner_id else "<none>",
            caller_id[:8] if caller_id else "<none>",
        )
        await send_resp(msg, writer, "pairing_claim", "error", {"message": "Invalid or expired code"})
        return

    logger.info("[PAIR] claim ok code=%s user=%s pid=%s", code[:8], caller_id[:8], os.getpid())
    await send_resp(msg, writer, "pairing_claim", "ok", {
        "temp_public_key": temp_pub,
        "temp_key_type": temp_key_type,
    })
|
|
|
|
|
|
async def handle_pairing_send(msg: dict, session: dict, writer: ProtocolWriter):
    """Attach the existing device's (already encrypted) payload to a pairing code.

    Only the account the code was issued for may send. Any mismatch, missing
    account, or unknown code yields "Invalid or expired code" (H4
    anti-enumeration). The stored payload is what ``pairing_poll`` later hands
    to the new device.
    """
    await _cleanup_pairings()
    code = msg.get("code", "").strip()
    payload = msg.get("payload")
    if not code or not payload:
        await send_resp(msg, writer, "pairing_send", "error", {"message": "Missing code or payload"})
        return

    caller_id = session.get("user_id")
    stored = False
    async with _pairing_lock:
        entry = pairing_sessions.get(code)
        owner_id = entry.get("user_id") if entry else None
        # H4 fix: unified error message (anti-enumeration)
        if entry and owner_id and owner_id == caller_id:
            entry["payload"] = payload
            stored = True
            logger.info("[PAIR] send ok code=%s user=%s pid=%s", code[:8], caller_id[:8], os.getpid())
        else:
            logger.warning(
                "[PAIR] send rejected code=%s pid=%s exists=%s target=%s session=%s",
                code[:8],
                os.getpid(),
                bool(entry),
                owner_id[:8] if owner_id else "<none>",
                caller_id[:8] if caller_id else "<none>",
            )

    if stored:
        await send_resp(msg, writer, "pairing_send", "ok", {"message": "OK"})
    else:
        await send_resp(msg, writer, "pairing_send", "error", {"message": "Invalid or expired code"})
|
|
|
|
|
|
async def handle_pairing_poll(msg: dict, writer: ProtocolWriter):
    """Unauthenticated poll by the new device for its pairing payload.

    The caller must present the ``poll_token`` returned by ``pairing_start``.
    The reply is ``{"ready": False}`` until a payload has been sent. After that
    it returns ``{"ready": True, "payload": ...}`` once and deletes the session.

    A session that exceeds ``PAIRING_MAX_POLL_ATTEMPTS`` without a payload is
    invalidated.

    The token is compared as UTF-8 bytes in constant time.
    ``secrets.compare_digest`` raises TypeError on non-ASCII ``str``
    arguments, so comparing raw strings let any client crash the handler.
    Non-string ``code``/``poll_token`` values are treated as missing.
    """
    await _cleanup_pairings()
    raw_code = msg.get("code", "")
    raw_token = msg.get("poll_token", "")
    code = raw_code.strip() if isinstance(raw_code, str) else ""
    poll_token = raw_token.strip() if isinstance(raw_token, str) else ""
    addr = _get_peer_addr(writer)

    if await _is_rate_limited(_rate_limit_key("pairing_poll", addr), 120):
        await send_resp(msg, writer, "pairing_poll", "error", {"message": "Too many attempts. Try later."})
        return
    if not code:
        await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing code"})
        return
    if not poll_token:
        await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing poll_token"})
        return

    error_msg = None
    ready = False
    payload = None
    async with _pairing_lock:
        p = pairing_sessions.get(code)
        if not p:
            error_msg = "Invalid or expired code"
            logger.warning("[PAIR] poll rejected code=%s pid=%s exists=false", code[:8], os.getpid())
        elif not secrets.compare_digest(
            str(p.get("poll_token", "")).encode("utf-8"), poll_token.encode("utf-8")
        ):
            error_msg = "Invalid poll_token"
            logger.warning("[PAIR] poll token mismatch code=%s pid=%s", code[:8], os.getpid())
        else:
            poll_attempts = p.get("poll_attempts", 0) + 1
            p["poll_attempts"] = poll_attempts
            if poll_attempts > PAIRING_MAX_POLL_ATTEMPTS and not p.get("payload"):
                pairing_sessions.pop(code, None)
                error_msg = "Code invalidated due to too many attempts"
                logger.warning("[PAIR] poll invalidated code=%s pid=%s attempts=%s", code[:8], os.getpid(), poll_attempts)
            elif p.get("payload"):
                # One-shot delivery: the session is consumed by the first successful poll
                ready = True
                payload = p["payload"]
                pairing_sessions.pop(code, None)
                logger.info("[PAIR] poll ready code=%s pid=%s", code[:8], os.getpid())

    if error_msg:
        await send_resp(msg, writer, "pairing_poll", "error", {"message": error_msg})
    elif ready:
        await send_resp(msg, writer, "pairing_poll", "ok", {"ready": True, "payload": payload})
    else:
        await send_resp(msg, writer, "pairing_poll", "ok", {"ready": False})
|
|
|
|
|
|
async def handle_create_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Create a direct message or a group conversation.

    ``members`` must be a list of email strings. Each address is stripped.
    Unknown but well-formed addresses become phantom users via
    ``_create_phantom_guarded``.

    Duplicate members and the caller's own address are collapsed. A repeated
    address therefore no longer turns a DM request into a group, and no longer
    triggers a second ``create_invitation`` for the same user.

    * Exactly one other member and no ``name`` gives a DM. Both users are added
      directly and the other member is notified with ``conversation_created``.
    * Anything else gives a group. Only the creator is added, every other
      member gets an invitation, and non-phantom invitees are notified with
      ``group_invitation``.
    """
    member_emails = msg.get("members", [])
    name = msg.get("name")
    addr = _get_peer_addr(writer)
    creator_id = session["user_id"]
    if await _is_rate_limited(f"create_conversation|{creator_id}", 10):
        await send_resp(msg, writer, "create_conversation", "error", {"message": "Too many attempts. Try later."})
        return
    # A non-list would otherwise be iterated character by character
    if not isinstance(member_emails, list) or not all(isinstance(e, str) for e in member_emails):
        await send_resp(msg, writer, "create_conversation", "error", {"message": "Invalid request"})
        return

    # Resolve all member user IDs, skipping the creator and duplicates
    other_users = []
    seen_ids = {creator_id}
    for raw_email in member_emails:
        email = raw_email.strip()
        u = await adb.get_user_by_email(email)
        if not u:
            if not _valid_email(email):
                await send_resp(msg, writer, "create_conversation", "error", {"message": f"Invalid email format: {email}"})
                return
            # H5: atomic phantom creation (cap check + DB create + set add)
            u, err_msg = await _create_phantom_guarded(email, addr, creator_id)
            if u is None:
                await send_resp(msg, writer, "create_conversation", "error", {"message": err_msg})
                return
        if u["id"] in seen_ids:
            continue
        seen_ids.add(u["id"])
        other_users.append(u)

    is_dm = len(other_users) == 1 and not name
    joined_at = datetime.now(timezone.utc)

    if is_dm:
        # DMs: add both members directly (no invitation)
        all_ids = [creator_id] + [u["id"] for u in other_users]
        conv_id = await adb.create_conversation(all_ids, joined_at=joined_at, name=name, created_by=creator_id)
        logger.info("[CONV] %s created DM conv=%s", _who(session), conv_id[:8])
        await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id})
        # Notify the other member
        members_info = await adb.get_conversation_members(conv_id)
        member_list = [{"user_id": m["id"], "username": m["username"], "email": m["email"]} for m in members_info]
        notif_data = {
            "conversation_id": conv_id,
            "name": name,
            "created_by": creator_id,
            "members": member_list,
        }
        await _notify_users([u["id"] for u in other_users], "conversation_created", notif_data)
        return

    # Groups: only add creator, create invitations for others
    conv_id = await adb.create_conversation([creator_id], joined_at=joined_at, name=name, created_by=creator_id)
    logger.info("[CONV] %s created group conv=%s", _who(session), conv_id[:8])
    # Create invitations for other members
    creator_user = await adb.get_user_by_id(creator_id)
    creator_name = creator_user["username"] if creator_user else "Unknown"
    async with _clients_lock:
        phantom_snapshot = set(phantom_user_ids)
    invited_ids = []
    for u in other_users:
        await adb.create_invitation(conv_id, u["id"], creator_id)
        if u["id"] not in phantom_snapshot:
            invited_ids.append(u["id"])  # only notify non-phantoms
    inv_notif = {
        "conversation_id": conv_id,
        "conversation_name": name,
        "invited_by": creator_id,
        "invited_by_username": creator_name,
    }
    await _notify_users(invited_ids, "group_invitation", inv_notif)
    await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id})
|
|
|
|
|
|
async def handle_find_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Look up the direct conversation between the caller and ``email``.

    An unknown but well-formed address is turned into a phantom user via
    ``_create_phantom_guarded``, so the reply always carries a ``user_id``.
    ``conversation_id`` is whatever ``adb.find_direct_conversation`` returns.
    NOTE(review): presumably that is ``None`` when no DM exists yet; confirm.
    """
    async def _err(text):
        await send_resp(msg, writer, "find_conversation", "error", {"message": text})

    email = msg.get("email", "").strip()
    if not email:
        await _err("Invalid request")
        return

    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("find_conversation", addr, email), 30):
        await _err("Too many attempts. Try later.")
        return

    peer_user = await adb.get_user_by_email(email)
    if not peer_user:
        if not _valid_email(email):
            await _err("Invalid email format")
            return
        # H5: atomic phantom creation (cap check + DB create + set add)
        peer_user, err_msg = await _create_phantom_guarded(email, addr, session["user_id"])
        if peer_user is None:
            await _err(err_msg)
            return

    existing = await adb.find_direct_conversation(session["user_id"], peer_user["id"])
    await send_resp(msg, writer, "find_conversation", "ok", {
        "conversation_id": existing,
        "user_id": peer_user["id"],
    })
|
|
|
|
|
|
async def handle_add_member(msg: dict, session: dict, writer: ProtocolWriter):
    """Invite the user behind ``email`` into a conversation the caller belongs to.

    Unknown but well-formed addresses become phantom users, and invitations are
    created for both real and phantom users. The ``group_invitation`` push is
    sent only to non-phantom invitees. Existing members and users with a
    pending invitation are rejected.
    """
    def _error(text):
        return send_resp(msg, writer, "add_member", "error", {"message": text})

    conv_id = msg.get("conversation_id", "")
    email = msg.get("email", "").strip()
    if not (conv_id and email):
        await _error("Invalid request")
        return
    if not _valid_uuid(conv_id):
        await _error("Invalid conversation_id")
        return

    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("add_member", addr, email), 10):
        await _error("Too many attempts. Try later.")
        return

    inviter_id = session["user_id"]
    if not await adb.is_conversation_member(conv_id, inviter_id):
        await _error("Not a member")
        return

    invitee = await adb.get_user_by_email(email)
    if not invitee:
        # L8: validate email format before phantom creation
        if not _valid_email(email):
            await _error("Invalid email format")
            return
        # H5: atomic phantom creation (cap check + DB create + set add)
        invitee, err_msg = await _create_phantom_guarded(email, addr, inviter_id)
        if invitee is None:
            await _error(err_msg)
            return

    invitee_id = invitee["id"]
    if await adb.is_conversation_member(conv_id, invitee_id):
        await _error("Already a member")
        return
    if await adb.has_pending_invitation(conv_id, invitee_id):
        await _error("Invitation already pending")
        return

    await adb.create_invitation(conv_id, invitee_id, inviter_id)
    logger.info("[INVITE] %s invited u=%s to conv=%s", _who(session), invitee_id[:8], conv_id[:8])
    await send_resp(msg, writer, "add_member", "ok", {"user_id": invitee_id})

    # phantom_user_ids is read under _clients_lock; the push happens outside it
    async with _clients_lock:
        invitee_is_phantom = invitee_id in phantom_user_ids
    if invitee_is_phantom:
        return

    conv = await adb.get_conversation(conv_id)
    inviter = await adb.get_user_by_id(inviter_id)
    await _notify_users([invitee_id], "group_invitation", {
        "conversation_id": conv_id,
        "conversation_name": conv.get("name") if conv else None,
        "invited_by": inviter_id,
        "invited_by_username": inviter["username"] if inviter else "Unknown",
    })
|
|
|
|
|
|
async def handle_accept_invitation(msg: dict, session: dict, writer: ProtocolWriter):
    """Accept a group invitation — add user to conversation members.

    Requires a pending invitation for the caller. The caller is added with the
    current UTC time as ``joined_at``, and the invitation is deleted. Then every
    other member receives a ``member_added`` notification carrying the new
    member's id, username and email (``""`` if the profile lookup fails).
    """
    me = session["user_id"]
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "accept_invitation", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "accept_invitation", "error", {"message": "Invalid conversation_id"})
        return
    if not await adb.has_pending_invitation(conv_id, me):
        await send_resp(msg, writer, "accept_invitation", "error", {"message": "No pending invitation"})
        return

    await adb.add_conversation_member(conv_id, me, joined_at=datetime.now(timezone.utc))
    await adb.delete_invitation(conv_id, me)
    logger.info("[INVITE] %s accepted invitation to conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "accept_invitation", "ok", {"conversation_id": conv_id})

    # Tell everyone else in the conversation who just joined
    profile = await adb.get_user_by_id(me)
    announcement = {
        "conversation_id": conv_id,
        "user_id": me,
        "username": profile["username"] if profile else "",
        "email": profile["email"] if profile else "",
    }
    roster = await adb.get_conversation_members(conv_id)
    await _notify_users([m["id"] for m in roster if m["id"] != me], "member_added", announcement)
|
|
|
|
|
|
async def handle_decline_invitation(msg: dict, session: dict, writer: ProtocolWriter):
    """Decline a group invitation.

    Deletes the caller's pending invitation for ``conversation_id``. Nobody
    else is notified.
    """
    me = session["user_id"]
    conv_id = msg.get("conversation_id", "")

    rejection = None
    if not conv_id:
        rejection = "Missing conversation_id"
    elif not _valid_uuid(conv_id):
        rejection = "Invalid conversation_id"
    elif not await adb.has_pending_invitation(conv_id, me):
        rejection = "No pending invitation"
    if rejection is not None:
        await send_resp(msg, writer, "decline_invitation", "error", {"message": rejection})
        return

    await adb.delete_invitation(conv_id, me)
    logger.info("[INVITE] %s declined invitation to conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "decline_invitation", "ok", {"message": "OK"})
|
|
|
|
|
|
async def handle_list_invitations(msg: dict, session: dict, writer: ProtocolWriter):
    """List pending group invitations for the current user.

    ``created_at`` uses ``isoformat()`` when the value supports it and ``str()``
    otherwise. A missing ``invited_by_username`` becomes ``""``.
    """
    def _as_text(ts):
        return ts.isoformat() if hasattr(ts, "isoformat") else str(ts)

    pending = await adb.get_pending_invitations(session["user_id"])
    invitations = [
        {
            "conversation_id": inv["conversation_id"],
            "conversation_name": inv.get("conversation_name"),
            "invited_by": inv["invited_by"],
            "invited_by_username": inv.get("invited_by_username", ""),
            "created_at": _as_text(inv["created_at"]),
        }
        for inv in pending
    ]
    await send_resp(msg, writer, "list_invitations", "ok", {"invitations": invitations})
|
|
|
|
|
|
async def handle_list_conversations(msg: dict, session: dict, writer: ProtocolWriter):
    """List the caller's conversations together with per-conversation unread counts.

    Unread counts come from ``adb.get_unread_counts`` limited to
    ``METADATA_RETENTION_DAYS``. Conversations absent from that mapping report
    0. ``created_at`` uses ``isoformat()`` when available, ``str()`` otherwise.
    """
    def _as_text(ts):
        return ts.isoformat() if hasattr(ts, "isoformat") else str(ts)

    uid = session["user_id"]
    rows = await adb.list_user_conversations(uid)
    unread_by_conv = await adb.get_unread_counts(uid, max_age_days=METADATA_RETENTION_DAYS)

    conversations = [
        {
            "conversation_id": row["id"],
            "created_at": _as_text(row["created_at"]),
            "members": row["members"],
            "name": row.get("name"),
            "created_by": row.get("created_by"),
            "avatar_file": row.get("avatar_file"),
            "unread_count": unread_by_conv.get(row["id"], 0),
        }
        for row in rows
    ]
    logger.info("[LIST] %s listed %d conversations", _who(session), len(conversations))
    await send_resp(msg, writer, "list_conversations", "ok", {"conversations": conversations})
|
|
|
|
|
|
async def handle_send_message(msg: dict, session: dict, writer: ProtocolWriter):
    """Store an end-to-end encrypted message and push it to online recipients.

    Request fields:
        conversation_id: UUID of a conversation the sender belongs to.
        ratchet_header: message-level header, validated by ``_validate_header``.
        recipients: list of objects, one per recipient user/device. Each has
            ``user_id``, ``encrypted_content`` and ``nonce`` (encoded for
            ``decode_binary``), plus optional ``device_id``,
            ``ratchet_header`` and ``x3dh_header``.
        x3dh_header, sender_chain_id, sender_chain_n, image_file_id: optional.

    Recipient entries are dropped if they:
        * are not objects;
        * are addressed to non-members or phantom users;
        * lack ciphertext or nonce;
        * fail to decode.
    If no entry survives, the request fails. ``image_file_id``, when given,
    must name a completed upload owned by the sender.

    On success the handler:
        * replies ``{"message_id", "created_at"}``;
        * pushes ``new_message`` to connected recipients, excluding this
          connection;
        * sends the sender one ``message_delivered`` receipt for each other
          user whose device received the push.
    """
    from collections import defaultdict

    sender_uid = session["user_id"]
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "send_message", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "send_message", "error", {"message": "Invalid conversation_id"})
        return
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("send_message", addr, session.get("email")), 20):
        await send_resp(msg, writer, "send_message", "error", {"message": "Too many attempts. Try later."})
        return
    if not await adb.is_conversation_member(conv_id, sender_uid):
        await send_resp(msg, writer, "send_message", "error", {"message": "Not a member"})
        return

    # New protocol: ratchet_header + recipients[] with per-user ciphertext
    ratchet_header_raw = msg.get("ratchet_header")
    recipients_raw = msg.get("recipients")
    if not ratchet_header_raw or not recipients_raw:
        await send_resp(msg, writer, "send_message", "error", {"message": "Missing ratchet_header or recipients"})
        return
    # recipients must be an array. Any other truthy value used to blow up in
    # the loop below (a dict iterates its str keys, which have no .get()).
    if not isinstance(recipients_raw, list):
        await send_resp(msg, writer, "send_message", "error", {"message": "Invalid recipients format"})
        return

    # C2 fix: validate header is a dict (reject raw str/bytes)
    ratchet_header = _validate_header(ratchet_header_raw, "ratchet_header")
    if ratchet_header is None:
        await send_resp(msg, writer, "send_message", "error", {"message": "Invalid ratchet_header format"})
        return

    x3dh_header = None
    x3dh_header_raw = msg.get("x3dh_header")
    if x3dh_header_raw:
        x3dh_header = _validate_header(x3dh_header_raw, "x3dh_header")
        if x3dh_header is None:
            await send_resp(msg, writer, "send_message", "error", {"message": "Invalid x3dh_header format"})
            return

    # Assumes decode_binary raises ValueError (e.g. binascii.Error) or
    # TypeError on malformed input. TODO confirm against protocol.py.
    sender_chain_id_b64 = msg.get("sender_chain_id")
    sender_chain_id = None
    if sender_chain_id_b64:
        try:
            sender_chain_id = decode_binary(sender_chain_id_b64)
        except (TypeError, ValueError):
            await send_resp(msg, writer, "send_message", "error", {"message": "Invalid sender_chain_id"})
            return
    sender_chain_n = msg.get("sender_chain_n")

    # Keep only entries addressed to real, non-phantom conversation members.
    conv_members = await adb.get_conversation_members(conv_id)
    member_ids = {m["id"] for m in conv_members}
    async with _clients_lock:
        phantom_snapshot = set(phantom_user_ids)

    db_recipients = []
    for r in recipients_raw:
        if not isinstance(r, dict):
            continue
        uid = r.get("user_id", "")
        # Check the type first: an unhashable user_id (list/dict) would make
        # the set membership tests raise TypeError.
        if not isinstance(uid, str) or uid not in member_ids or uid in phantom_snapshot:
            continue
        ct_b64 = r.get("encrypted_content", "")
        nonce_b64 = r.get("nonce", "")
        if not ct_b64 or not nonce_b64:
            continue
        try:
            ciphertext = decode_binary(ct_b64)
            nonce = decode_binary(nonce_b64)
        except (TypeError, ValueError):
            # Malformed encoding: drop the entry like any other invalid one.
            continue
        entry = {
            "user_id": uid,
            "encrypted_content": ciphertext,
            "nonce": nonce,
        }
        # Per-recipient device_id (multi-device support)
        r_device_id = r.get("device_id")
        if r_device_id:
            entry["device_id"] = r_device_id
        # Per-recipient ratchet header and x3dh header (C2 fix: validate dict)
        r_rh = r.get("ratchet_header")
        if r_rh:
            r_rh_bytes = _validate_header(r_rh, "recipient_ratchet_header")
            if r_rh_bytes:
                entry["ratchet_header"] = r_rh_bytes
        r_x3dh = r.get("x3dh_header")
        if r_x3dh:
            r_x3dh_bytes = _validate_header(r_x3dh, "recipient_x3dh_header")
            if r_x3dh_bytes:
                entry["x3dh_header"] = r_x3dh_bytes
        db_recipients.append(entry)
    if not db_recipients:
        await send_resp(msg, writer, "send_message", "error", {"message": "No valid recipients"})
        return

    # The attachment must be a completed upload owned by the sender. Check this
    # BEFORE store_message. The id used to be persisted on the message
    # unconditionally, with only the later link step guarded. A client could
    # therefore attach another user's upload id, and handle_delete_message
    # removes the attached file from disk when the message is deleted.
    image_file_id = msg.get("image_file_id") or None
    if image_file_id is not None:
        upload = None
        if isinstance(image_file_id, str):
            upload = await adb.get_image_upload(image_file_id)
        if not upload or not upload["completed"] or upload["uploader_id"] != sender_uid:
            await send_resp(msg, writer, "send_message", "error", {"message": "Invalid image_file_id"})
            return

    # Metadata privacy: for group messages (sender_chain_id present), chain
    # metadata is stored in the per-recipient ratchet_header instead of the
    # messages table, so no sender-correlation data is persisted per message.
    # The sender's own self-copy is skipped: it uses a different decrypt path
    # (self-encryption key) and must keep its own ratchet_header ({"self":true}).
    db_sender_chain_id = None
    db_sender_chain_n = None
    if sender_chain_id:
        chain_meta = json.dumps({
            "chain_id": encode_binary(sender_chain_id),
            "chain_n": sender_chain_n,
        }).encode()
        for r in db_recipients:
            if r["user_id"] == sender_uid:
                continue
            if not r.get("ratchet_header"):
                r["ratchet_header"] = chain_meta

    msg_id, created_at = await adb.store_message(
        conv_id, sender_uid, ratchet_header, db_recipients,
        x3dh_header=x3dh_header,
        sender_chain_id=db_sender_chain_id,
        sender_chain_n=db_sender_chain_n,
        image_file_id=image_file_id,
        sender_device_id=session.get("device_id"),
    )

    # Link the (already verified) image upload to the message.
    if image_file_id:
        await adb.set_message_image_file_id(msg_id, image_file_id)

    logger.info("[MSG] %s msg=%s conv=%s", _who(session), msg_id[:8], conv_id[:8])
    await send_resp(msg, writer, "send_message", "ok", {"message_id": msg_id, "created_at": created_at})

    # Notify connected recipients: group all per-device entries by user_id.
    # Built from validated db_recipients (not raw input) so unvalidated headers
    # never reach the push.
    msg_ratchet_header_dict = json.loads(ratchet_header.decode())
    msg_x3dh_header_dict = json.loads(x3dh_header.decode()) if x3dh_header else None

    user_entries = defaultdict(list)
    for r in db_recipients:
        # Per-recipient headers are stored as bytes; decode back to a dict for
        # the notification JSON.
        r_rh = r.get("ratchet_header")
        r_rh_dict = json.loads(r_rh.decode()) if r_rh else None
        r_x3dh = r.get("x3dh_header")
        r_x3dh_dict = json.loads(r_x3dh.decode()) if r_x3dh else None
        user_entries[r["user_id"]].append({
            "device_id": r.get("device_id", db.SELF_DEVICE_ID),
            "encrypted_content": encode_binary(r["encrypted_content"]),
            "nonce": encode_binary(r["nonce"]),
            "ratchet_header": r_rh_dict or msg_ratchet_header_dict,
            "x3dh_header": r_x3dh_dict or msg_x3dh_header_dict,
        })

    notifications = []
    for uid, entries in user_entries.items():
        notif_data = {
            "message_id": msg_id,
            "conversation_id": conv_id,
            "sender_id": sender_uid,
            "sender_device_id": session.get("device_id"),
            "device_entries": entries,
        }
        if sender_chain_id_b64:
            notif_data["sender_chain_id"] = sender_chain_id_b64
        if sender_chain_n is not None:
            notif_data["sender_chain_n"] = sender_chain_n
        # Flat fields for backward compat with old clients (first entry's data)
        if entries:
            first = entries[0]
            notif_data["ratchet_header"] = first.get("ratchet_header") or msg_ratchet_header_dict
            notif_data["encrypted_content"] = first.get("encrypted_content", "")
            notif_data["nonce"] = first.get("nonce", "")
            if first.get("x3dh_header"):
                notif_data["x3dh_header"] = first["x3dh_header"]
        notifications.append((uid, "new_message", notif_data))

    # Log notification targets for debugging delivery issues
    async with _clients_lock:
        targets_info = [
            f"{uid[:8]}({len(connected_clients.get(uid, []))}w)"
            for uid, _, _ in notifications
        ]
    logger.info("[PUSH] msg=%s conv=%s targets=[%s] exclude_sender=%s",
                msg_id[:8], conv_id[:8], ", ".join(targets_info), "yes")
    delivered_users = await _notify_users_individual(notifications, exclude_writer=writer)

    # Delivery receipt: acknowledge to the sender for each other user with at
    # least one device that received the push.
    delivered_users.discard(sender_uid)
    for delivered_uid in delivered_users:
        await _notify_users([sender_uid], "message_delivered", {
            "conversation_id": conv_id,
            "user_id": delivered_uid,
            "message_ids": [msg_id],
        })
|
|
|
|
|
|
async def handle_get_messages(msg: dict, session: dict, writer: ProtocolWriter):
    """Return a page of stored (still encrypted) messages for one conversation.

    Request fields:
        conversation_id: required UUID.
        limit: clamped to 1..200, default 50.
        offset: at least 0, default 0.
        after_ts: optional timestamp string, forwarded to ``adb.get_messages``.
    A non-numeric ``limit``/``offset`` or a non-string ``after_ts`` gets an
    error response instead of raising inside the handler.

    The reply has two fields:
        messages: per message, the ciphertext for the caller's device, headers,
            read/delivery status, pin and reaction metadata.
        total_count: the value from ``adb.count_messages``.
    """
    def _iso(value):
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    if await _is_rate_limited(f"get_messages|{session['user_id']}", 30):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Too many requests. Try later."})
        return
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "get_messages", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid conversation_id"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Not a member"})
        return

    # int() raises:
    #   ValueError    - non-numeric strings or NaN;
    #   TypeError     - None, lists or dicts;
    #   OverflowError - infinite floats.
    # All of these used to escape the handler.
    try:
        limit = min(max(int(msg.get("limit", 50)), 1), 200)
        offset = max(int(msg.get("offset", 0)), 0)
    except (TypeError, ValueError, OverflowError):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid limit or offset"})
        return
    after_ts = msg.get("after_ts")  # ISO timestamp string or None
    if after_ts is not None and not isinstance(after_ts, str):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid after_ts"})
        return
    device_id = session.get("device_id")
    messages = await adb.get_messages(conv_id, session["user_id"], limit, offset,
                                      device_id=device_id, after_ts=after_ts)

    # Deduplicate. When both a device-specific row and a SELF_DEVICE_ID row
    # exist for the same message, prefer the device-specific one; otherwise
    # keep the first row seen per message_id.
    seen_ids = {}
    deduped = []
    for m in messages:
        mid = m["id"]
        if mid not in seen_ids:
            seen_ids[mid] = len(deduped)
            deduped.append(m)
        elif m.get("mr_device_id", "") != db.SELF_DEVICE_ID:
            deduped[seen_ids[mid]] = m
    messages = deduped

    message_ids = [m["id"] for m in messages]
    read_status = await adb.get_message_read_status(message_ids) if message_ids else {}
    delivery_status = await adb.get_message_delivery_status(message_ids) if message_ids else {}
    reactions_map = await adb.get_reactions(message_ids) if message_ids else {}

    result = []
    for m in messages:
        # Prefer per-recipient headers (mr_*) over message-level headers
        rh_raw = m.get("mr_ratchet_header") or m.get("ratchet_header")
        x3dh_raw = m.get("mr_x3dh_header") or m.get("x3dh_header")
        # C2 fix: defensive JSON parsing; corrupted headers don't break the fetch
        try:
            rh_parsed = json.loads(rh_raw) if rh_raw else {}
        except (json.JSONDecodeError, TypeError, UnicodeDecodeError):
            logger.warning("[FETCH] Corrupted ratchet_header in message %s, skipping", m["id"])
            rh_parsed = {}
        try:
            x3dh_parsed = json.loads(x3dh_raw) if x3dh_raw else None
        except (json.JSONDecodeError, TypeError, UnicodeDecodeError):
            logger.warning("[FETCH] Corrupted x3dh_header in message %s, skipping", m["id"])
            x3dh_parsed = None
        entry = {
            "message_id": m["id"],
            "sender_id": m.get("sender_id") or "",
            "ratchet_header": rh_parsed,
            "encrypted_content": encode_binary(m["encrypted_content"]) if m.get("encrypted_content") else "",
            "nonce": encode_binary(m["nonce"]) if m.get("nonce") else "",
            "created_at": _iso(m["created_at"]),
            "read_by": read_status.get(m["id"], []),
            "delivered_to": delivery_status.get(m["id"], []),
        }
        if x3dh_parsed:
            entry["x3dh_header"] = x3dh_parsed
        # Sender chain metadata: check the message level first (backward
        # compat), then the per-recipient ratchet_header (new metadata-private
        # format). The per-recipient value is only trusted when the
        # message-level ratchet_header is the group dummy (dh_pub all zeros).
        # This prevents DM header injection.
        if m.get("sender_chain_id"):
            entry["sender_chain_id"] = encode_binary(m["sender_chain_id"])
        elif isinstance(rh_parsed, dict) and rh_parsed.get("chain_id"):
            msg_rh_raw = m.get("ratchet_header")
            is_group = False
            if msg_rh_raw:
                try:
                    msg_rh = json.loads(msg_rh_raw) if isinstance(msg_rh_raw, (bytes, str)) else msg_rh_raw
                    is_group = isinstance(msg_rh, dict) and msg_rh.get("dh_pub") == "00" * 32
                except (json.JSONDecodeError, TypeError, UnicodeDecodeError):
                    pass
            if is_group:
                entry["sender_chain_id"] = rh_parsed["chain_id"]
        if m.get("sender_chain_n") is not None:
            entry["sender_chain_n"] = m["sender_chain_n"]
        elif isinstance(rh_parsed, dict) and rh_parsed.get("chain_n") is not None:
            # Same group-only guard: only attach chain_n alongside a chain_id
            if "sender_chain_id" in entry:
                entry["sender_chain_n"] = rh_parsed["chain_n"]
        if m.get("sender_device_id"):
            entry["sender_device_id"] = m["sender_device_id"]
        if m.get("deleted_at"):
            entry["deleted_at"] = _iso(m["deleted_at"])
        # Pin metadata
        if m.get("pinned_at"):
            entry["pinned_at"] = _iso(m["pinned_at"])
            entry["pinned_by"] = m.get("pinned_by") or ""
        # Reactions
        msg_reactions = reactions_map.get(m["id"])
        if msg_reactions:
            entry["reactions"] = msg_reactions
        result.append(entry)

    total_count = await adb.count_messages(conv_id, session["user_id"])
    logger.info("[FETCH] %s fetched %d/%d msgs from conv=%s (limit=%d, offset=%d%s)",
                _who(session), len(result), total_count, conv_id[:8], limit, offset,
                f", after={after_ts}" if after_ts else "")
    await send_resp(msg, writer, "get_messages", "ok",
                    {"messages": result, "total_count": total_count})
|
|
|
|
|
|
async def _handle_typing_event(msg_type: str, msg: dict, session: dict, writer: ProtocolWriter):
    """Relay a typing indicator of kind ``msg_type`` to the other members.

    Steps:
        1. Validate the conversation id.
        2. Throttle per user and conversation (limit 120 via
           ``_is_rate_limited``).
        3. Require membership.
        4. Fan the event out to every other member, skipping the originating
           connection.
        5. Acknowledge the caller.
    """
    async def _fail(text: str) -> None:
        await send_resp(msg, writer, msg_type, "error", {"message": text})

    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        return await _fail("Missing conversation_id")
    if not _valid_uuid(conv_id):
        return await _fail("Invalid conversation_id")

    me = session["user_id"]
    if await _is_rate_limited(f"{msg_type}|{me}|{conv_id}", 120):
        return await _fail("Too many typing events. Slow down.")
    if not await adb.is_conversation_member(conv_id, me):
        return await _fail("Not a member")

    everyone = await adb.get_conversation_members(conv_id)
    others = [member["id"] for member in everyone if member["id"] != me]
    event = {
        "conversation_id": conv_id,
        "user_id": me,
        "username": session.get("username", ""),
    }
    await _notify_users(others, msg_type, event, exclude_writer=writer)
    await send_resp(msg, writer, msg_type, "ok", {"message": "OK"})
|
|
|
|
|
|
async def handle_typing_start(msg: dict, session: dict, writer: ProtocolWriter):
    """Broadcast a ``typing_start`` indicator via ``_handle_typing_event``."""
    return await _handle_typing_event("typing_start", msg, session, writer)
|
|
|
|
|
|
async def handle_typing_stop(msg: dict, session: dict, writer: ProtocolWriter):
    """Broadcast a ``typing_stop`` indicator via ``_handle_typing_event``."""
    return await _handle_typing_event("typing_stop", msg, session, writer)
|
|
|
|
|
|
async def handle_remove_member(msg: dict, session: dict, writer: ProtocolWriter):
    """Remove another user from a group conversation (group creator only).

    Request fields are ``conversation_id`` and ``user_id``, both UUIDs. The
    caller must be a member and the conversation's ``created_by``, and cannot
    remove themselves (see ``handle_leave_group``). On success, every member
    other than the caller receives ``member_removed``; that includes the
    removed user.
    """
    requester_id = session["user_id"]
    if await _is_rate_limited(f"remove_member|{requester_id}", 10):
        await send_resp(msg, writer, "remove_member", "error", {"message": "Too many requests. Try later."})
        return
    conv_id = msg.get("conversation_id", "")
    user_id = msg.get("user_id", "")
    if not conv_id or not user_id:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Missing conversation_id or user_id"})
        return
    if not _valid_uuid(conv_id) or not _valid_uuid(user_id):
        await send_resp(msg, writer, "remove_member", "error", {"message": "Invalid conversation_id or user_id"})
        return
    if not await adb.is_conversation_member(conv_id, requester_id):
        await send_resp(msg, writer, "remove_member", "error", {"message": "Not a member"})
        return

    # Look the conversation up directly, as the leave/rename/delete handlers
    # do, instead of listing every conversation the requester belongs to and
    # scanning for this one.
    conv = await adb.get_conversation(conv_id)
    if not conv or conv.get("created_by") != requester_id:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Only the group creator can remove members"})
        return
    if user_id == requester_id:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Cannot remove yourself"})
        return

    # Snapshot members before removal so the removed user is notified too.
    members_before = await adb.get_conversation_members(conv_id)
    # M6: atomic removal; the return value confirms the row existed.
    removed = await adb.remove_conversation_member_atomic(conv_id, user_id)
    if not removed:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Member already removed"})
        return
    logger.info("[MEMBER] %s removed user=%s from conv=%s", _who(session), user_id[:8], conv_id[:8])
    await send_resp(msg, writer, "remove_member", "ok", {"message": "OK"})

    # Notify the removed member and the remaining members.
    notif_data = {
        "conversation_id": conv_id,
        "user_id": user_id,
    }
    member_ids = [m["id"] for m in members_before if m["id"] != requester_id]
    await _notify_users(member_ids, "member_removed", notif_data)
|
|
|
|
|
|
async def handle_leave_group(msg: dict, session: dict, writer: ProtocolWriter):
    """Leave a group conversation voluntarily.

    DMs (two or fewer members and no name) cannot be left. The caller's
    membership is removed atomically first. Only if that succeeds does the
    handler:
        * hand the group to the first remaining member, if the caller was the
          creator;
        * acknowledge the caller;
        * notify the remaining members with ``member_removed``.
    """
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "leave_group", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "leave_group", "error", {"message": "Invalid conversation_id"})
        return
    me = session["user_id"]
    if not await adb.is_conversation_member(conv_id, me):
        await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"})
        return

    # Don't allow leaving DMs (2 members without a name)
    conv = await adb.get_conversation(conv_id)
    members = await adb.get_conversation_members(conv_id)
    if len(members) <= 2 and not (conv and conv.get("name")):
        await send_resp(msg, writer, "leave_group", "error", {"message": "Cannot leave a DM conversation"})
        return

    # M6: atomic removal, done BEFORE the ownership transfer and with its
    # result checked. Previously ownership moved first and the result was
    # ignored, so a request racing with another leave/removal (which finds
    # no row to delete) still transferred the group and sent a second
    # member_removed broadcast.
    removed = await adb.remove_conversation_member_atomic(conv_id, me)
    if not removed:
        await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"})
        return

    remaining = [m for m in members if m["id"] != me]
    # If the creator left, transfer the group to the first remaining member.
    if conv and conv.get("created_by") == me and remaining:
        await adb.update_conversation_creator(conv_id, remaining[0]["id"])
    # NOTE(review): when the caller was the last member, the conversation row
    # and its uploads are not cleaned up here. handle_delete_conversation does
    # that when no members remain. Confirm whether leave should too.

    logger.info("[LEAVE] %s left group conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "leave_group", "ok", {"message": "OK"})

    # Notify remaining members
    notif_data = {
        "conversation_id": conv_id,
        "user_id": me,
    }
    await _notify_users([m["id"] for m in remaining], "member_removed", notif_data)
|
|
|
|
|
|
async def handle_rename_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Rename a group conversation (creator only).

    Request fields:
        conversation_id: UUID.
        name: string, 1-100 chars after stripping surrounding whitespace.
    DMs (conversations without a name) cannot be renamed. On success, the
    other members receive ``conversation_renamed`` with the new name.
    """
    if await _is_rate_limited(f"rename_conv|{session['user_id']}", 5):
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Too many requests. Try later."})
        return
    conv_id = msg.get("conversation_id", "")
    raw_name = msg.get("name")
    # Accept only strings. A JSON null, number or list used to crash on
    # .strip() with AttributeError; now it is reported as a missing name.
    new_name = raw_name.strip() if isinstance(raw_name, str) else ""
    if not conv_id or not new_name:
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Missing conversation_id or name"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Invalid conversation_id"})
        return
    if len(new_name) > 100:
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Name too long (max 100)"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Not a member"})
        return
    conv = await adb.get_conversation(conv_id)
    if not conv or not conv.get("name"):
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Cannot rename a DM conversation"})
        return
    if conv.get("created_by") != session["user_id"]:
        await send_resp(msg, writer, "rename_conversation", "error", {"message": "Only the group creator can rename"})
        return

    await adb.update_conversation_name(conv_id, new_name)
    logger.info("[RENAME] %s renamed conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "rename_conversation", "ok", {"message": "OK"})

    # Notify all other members
    members = await adb.get_conversation_members(conv_id)
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "conversation_renamed", {
        "conversation_id": conv_id,
        "name": new_name,
        "renamed_by": session["user_id"],
    })
|
|
|
|
|
|
async def handle_delete_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Delete a conversation for the current user.

    * Group (more than two members, or a named conversation): only the
      creator may delete it, and doing so removes every member.
    * DM: only the caller is removed; the other participant keeps it.

    If no members remain afterwards, the conversation's uploaded files are
    securely deleted and the conversation row is removed. The other members
    are then notified with ``member_removed``.
    """
    if await _is_rate_limited(f"delete_conv|{session['user_id']}", 5):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Too many requests. Try later."})
        return
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Invalid conversation_id"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Not a member"})
        return

    conv = await adb.get_conversation(conv_id)
    members = await adb.get_conversation_members(conv_id)
    is_group = len(members) > 2 or bool(conv and conv.get("name"))
    # Groups can only be deleted by the creator (admin)
    if is_group and (not conv or conv.get("created_by") != session["user_id"]):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Only the group creator can delete this conversation"})
        return

    if is_group:
        # Group: creator deletes for everyone; remove all members, then clean up
        for member in members:
            await adb.remove_conversation_member(conv_id, member["id"])
    else:
        # DM: only remove self; the other user keeps the conversation
        await adb.remove_conversation_member(conv_id, session["user_id"])

    remaining_count = await adb.count_conversation_members(conv_id)
    if remaining_count == 0:
        # Clean up uploaded files from disk. _secure_delete is synchronous (it
        # is called without await), so it runs on a worker thread. This is the
        # same approach _AsyncDB uses for DB calls, and it keeps wiping many
        # or large files from stalling every other client on the event loop.
        file_ids = await adb.get_conversation_file_ids(conv_id)
        for fid in file_ids:
            for ext in (".enc", ".tmp"):
                path = _safe_upload_path(fid, ext)
                if path:
                    await asyncio.to_thread(_secure_delete, path)
        await adb.delete_conversation(conv_id)

    logger.info("[DELETE] %s deleted conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "delete_conversation", "ok", {"message": "OK"})

    # Notify other members they were removed
    notif_data = {
        "conversation_id": conv_id,
        "user_id": session["user_id"],
    }
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "member_removed", notif_data)
|
|
|
|
|
|
async def handle_mark_read(msg: dict, session: dict, writer: ProtocolWriter):
    """Mark specific messages as read by the caller and notify other members.

    Request fields:
        conversation_id: UUID.
        message_ids: list of up to 500 strings; duplicates are ignored.
    Ids not belonging to the conversation are dropped by
    ``adb.filter_message_ids_by_conversation``. Other members receive
    ``messages_read`` listing the ids actually marked.
    """
    conv_id = msg.get("conversation_id", "")
    message_ids = msg.get("message_ids", [])
    if not conv_id or not message_ids:
        await send_resp(msg, writer, "mark_read", "error", {"message": "Missing conversation_id or message_ids"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid conversation_id"})
        return
    # Require a list of strings. Previously a bare string or dict passed the
    # length check and reached the DB layer in place of a list.
    if not isinstance(message_ids, list) or not all(isinstance(mid, str) for mid in message_ids):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid message_ids"})
        return
    if len(message_ids) > 500:
        await send_resp(msg, writer, "mark_read", "error", {"message": "Too many message_ids (max 500)"})
        return
    # Drop duplicate ids while keeping their order.
    message_ids = list(dict.fromkeys(message_ids))
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Not a member"})
        return

    # M1 fix: filter to only message_ids that belong to this conversation
    valid_ids = await adb.filter_message_ids_by_conversation(conv_id, message_ids)
    if not valid_ids:
        await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"})
        return
    await adb.mark_messages_read(conv_id, session["user_id"], valid_ids)
    logger.info("[READ] %s marked %d msgs read in conv=%s", _who(session), len(valid_ids), conv_id[:8])
    await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"})

    members = await adb.get_conversation_members(conv_id)
    notif_data = {
        "conversation_id": conv_id,
        "user_id": session["user_id"],
        "message_ids": valid_ids,
    }
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "messages_read", notif_data)
|
|
|
|
|
|
async def handle_mark_conversation_read(msg: dict, session: dict, writer: ProtocolWriter):
    """Mark every message in a conversation as read by the caller.

    Replies with ``marked_count`` as returned by ``adb.mark_conversation_read``.
    When that count is positive, the other members get ``messages_read`` with
    an empty ``message_ids`` list.
    """
    conv_id = msg.get("conversation_id", "")

    problem = None
    if not conv_id:
        problem = "Missing conversation_id"
    elif not _valid_uuid(conv_id):
        problem = "Invalid conversation_id"
    elif not await adb.is_conversation_member(conv_id, session["user_id"]):
        problem = "Not a member"
    if problem is not None:
        await send_resp(msg, writer, "mark_conversation_read", "error", {"message": problem})
        return

    reader = session["user_id"]
    marked = await adb.mark_conversation_read(conv_id, reader)
    logger.info("[READ] %s marked conv=%s all-read (%d msgs)", _who(session), conv_id[:8], marked)
    await send_resp(msg, writer, "mark_conversation_read", "ok", {"marked_count": marked})

    if marked <= 0:
        return
    everyone = await adb.get_conversation_members(conv_id)
    others = [member["id"] for member in everyone if member["id"] != reader]
    await _notify_users(others, "messages_read", {
        "conversation_id": conv_id,
        "user_id": reader,
        "message_ids": [],
    })
|
|
|
|
|
|
async def handle_confirm_delivery(msg: dict, session: dict, writer: ProtocolWriter):
    """Record that the caller's client received messages and tell their senders.

    Request fields:
        conversation_id: UUID.
        message_ids: list of up to 500 strings; duplicates are ignored.
    Ids outside the conversation are dropped by
    ``adb.filter_message_ids_by_conversation``. Each distinct sender other
    than the caller then receives one ``message_delivered`` event listing
    their messages.
    """
    conv_id = msg.get("conversation_id", "")
    message_ids = msg.get("message_ids", [])
    if not conv_id or not message_ids:
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Missing conversation_id or message_ids"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Invalid conversation_id"})
        return
    # Require a list of strings. Previously a bare string or dict passed the
    # length check and reached the DB layer in place of a list.
    if not isinstance(message_ids, list) or not all(isinstance(mid, str) for mid in message_ids):
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Invalid message_ids"})
        return
    if len(message_ids) > 500:
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Too many message_ids (max 500)"})
        return
    # Drop duplicates (keeping order) so repeated ids don't each cost a
    # sender lookup below.
    message_ids = list(dict.fromkeys(message_ids))
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Not a member"})
        return

    # M1 fix: filter to only message_ids that belong to this conversation
    valid_ids = await adb.filter_message_ids_by_conversation(conv_id, message_ids)
    if not valid_ids:
        await send_resp(msg, writer, "confirm_delivery", "ok", {"message": "OK"})
        return
    await adb.mark_messages_delivered(conv_id, session["user_id"], valid_ids)
    logger.info("[DELIVERY] %s confirmed %d msgs delivered in conv=%s", _who(session), len(valid_ids), conv_id[:8])
    await send_resp(msg, writer, "confirm_delivery", "ok", {"message": "OK"})

    # Notify senders: look up sender_id per message, then push one grouped
    # event to each sender.
    sender_msgs: dict[str, list[str]] = {}
    for mid in valid_ids:
        sid = await adb.get_message_sender(mid)
        if sid and sid != session["user_id"]:
            sender_msgs.setdefault(sid, []).append(mid)
    for sender_id, mids in sender_msgs.items():
        await _notify_users([sender_id], "message_delivered", {
            "conversation_id": conv_id,
            "user_id": session["user_id"],
            "message_ids": mids,
        })
|
|
|
|
|
|
async def handle_delete_message(msg: dict, session: dict, writer: ProtocolWriter):
    """Soft-delete a message and remove its attached image, if any.

    The caller must be a member of the message's conversation, and
    ``adb.soft_delete_message`` must accept the deletion (it returns ``None``
    otherwise). The attached upload is wiped only when it belongs to the
    message's sender, so a message that references someone else's upload
    cannot destroy that file. Other members receive ``message_deleted``.
    """
    if await _is_rate_limited(f"delete_msg|{session['user_id']}", 20):
        await send_resp(msg, writer, "delete_message", "error", {"message": "Too many requests. Try later."})
        return
    message_id = msg.get("message_id", "")
    if not message_id:
        await send_resp(msg, writer, "delete_message", "error", {"message": "Missing message_id"})
        return
    if not _valid_uuid(message_id):
        await send_resp(msg, writer, "delete_message", "error", {"message": "Invalid message_id"})
        return
    conv_id = await adb.get_message_conversation(message_id)
    if not conv_id:
        await send_resp(msg, writer, "delete_message", "error", {"message": "Message not found"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "delete_message", "error", {"message": "Not a member"})
        return

    # Capture the sender before the soft delete, in case deletion alters the
    # row. It is used to verify ownership of the attachment below.
    sender_id = await adb.get_message_sender(message_id)
    result = await adb.soft_delete_message(message_id, session["user_id"])
    if result is None:
        await send_resp(msg, writer, "delete_message", "error", {"message": "Cannot delete this message"})
        return

    image_file_id = result.get("image_file_id")
    if image_file_id:
        upload = await adb.get_image_upload(image_file_id)
        if upload is not None and upload["uploader_id"] != sender_id:
            # The message points at an upload someone else owns. Don't wipe it.
            logger.warning("[MSG] message=%s references an upload not owned by its sender; file kept",
                           message_id[:8])
        else:
            image_path = _safe_upload_path(image_file_id, ".enc")
            if image_path:
                # _secure_delete is synchronous; keep disk I/O off the event loop.
                await asyncio.to_thread(_secure_delete, image_path)
            await adb.delete_image_upload(image_file_id)

    logger.info("[MSG] %s deleted message=%s", _who(session), message_id[:8])
    await send_resp(msg, writer, "delete_message", "ok", {"message_id": message_id})

    members = await adb.get_conversation_members(conv_id)
    notif_data = {"message_id": message_id, "conversation_id": conv_id}
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "message_deleted", notif_data)
|
|
|
|
|
|
async def handle_react_message(msg: dict, session: dict, writer: ProtocolWriter):
    """Add or remove the caller's reaction on a message and notify members.

    Request fields:
        message_id: UUID.
        reaction: a string in ``db.ALLOWED_REACTIONS``.
        action: ``"add"`` (default) or ``"remove"``.

    When an add replaces an earlier reaction:
        * the reply includes ``old_reaction``;
        * other members first receive a ``remove`` event for the old
          reaction, then the event for the new one.
    An add that changes nothing is acknowledged without broadcasting.
    """
    if await _is_rate_limited(f"react|{session['user_id']}", 20):
        await send_resp(msg, writer, "react_message", "error", {"message": "Too many requests. Try later."})
        return
    message_id = msg.get("message_id", "")
    reaction = msg.get("reaction", "")
    action = msg.get("action", "add")  # "add" or "remove"

    if not message_id or not reaction:
        await send_resp(msg, writer, "react_message", "error", {"message": "Missing fields"})
        return
    if not _valid_uuid(message_id):
        await send_resp(msg, writer, "react_message", "error", {"message": "Invalid message_id"})
        return
    # Check the type first. An unhashable value (list/dict) raises TypeError
    # in the membership test when ALLOWED_REACTIONS is a set or dict.
    if not isinstance(reaction, str) or reaction not in db.ALLOWED_REACTIONS:
        await send_resp(msg, writer, "react_message", "error", {"message": "Invalid reaction"})
        return
    if action not in ("add", "remove"):
        await send_resp(msg, writer, "react_message", "error", {"message": "Invalid action"})
        return

    conv_id = await adb.get_message_conversation(message_id)
    if not conv_id:
        await send_resp(msg, writer, "react_message", "error", {"message": "Message not found"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "react_message", "error", {"message": "Not a member"})
        return

    old_reaction = None
    if action == "add":
        changed, old_reaction = await adb.add_reaction(message_id, session["user_id"], reaction)
        if not changed:
            await send_resp(msg, writer, "react_message", "ok", {"message_id": message_id})
            return
    else:
        # NOTE(review): remove_reaction is not passed `reaction`, yet the
        # broadcast below reports the client-supplied value as the one
        # removed. The broadcast is sent whether or not anything was stored.
        # Confirm clients tolerate this.
        await adb.remove_reaction(message_id, session["user_id"])

    logger.info("[MSG] %s %s reaction '%s' on message=%s", _who(session), action, reaction, message_id[:8])
    resp_data = {"message_id": message_id}
    if old_reaction:
        resp_data["old_reaction"] = old_reaction
    await send_resp(msg, writer, "react_message", "ok", resp_data)

    members = await adb.get_conversation_members(conv_id)
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]

    # If an old reaction was replaced, broadcast its removal first.
    if old_reaction:
        remove_data = {
            "message_id": message_id,
            "conversation_id": conv_id,
            "user_id": session["user_id"],
            "username": session.get("username", ""),
            "reaction": old_reaction,
            "action": "remove",
        }
        await _notify_users(member_ids, "message_reacted", remove_data)

    notif_data = {
        "message_id": message_id,
        "conversation_id": conv_id,
        "user_id": session["user_id"],
        "username": session.get("username", ""),
        "reaction": reaction,
        "action": action,
    }
    await _notify_users(member_ids, "message_reacted", notif_data)
|
|
|
|
|
|
async def handle_pin_message(msg: dict, session: dict, writer: ProtocolWriter):
    """Pin or unpin a message in a conversation the caller belongs to.

    Request fields:
        message_id: UUID of the message to (un)pin.
        conversation_id: UUID of the conversation the message lives in.
        action: ``"pin"`` (default) or ``"unpin"``.

    The message must actually belong to ``conversation_id``. Without that
    check, a member of one conversation could pin a message ID taken from a
    conversation they are not part of. On success the caller receives
    ``{"message_id", "action"}`` and every other member receives
    ``message_pinned`` / ``message_unpinned``.
    """
    message_id = msg.get("message_id", "")
    action = msg.get("action", "pin")  # "pin" or "unpin"
    conversation_id = msg.get("conversation_id", "")

    if not message_id or not conversation_id:
        await send_resp(msg, writer, "pin_message", "error", {"message": "Missing fields"})
        return
    if not _valid_uuid(message_id) or not _valid_uuid(conversation_id):
        await send_resp(msg, writer, "pin_message", "error", {"message": "Invalid ID"})
        return
    if action not in ("pin", "unpin"):
        await send_resp(msg, writer, "pin_message", "error", {"message": "Invalid action"})
        return
    if not await adb.is_conversation_member(conversation_id, session["user_id"]):
        await send_resp(msg, writer, "pin_message", "error", {"message": "Not a member"})
        return

    # Checked after membership so non-members learn nothing about message IDs.
    # Compared case-insensitively in case the client's UUID text and the
    # stored value differ only in letter case.
    owning_conv = await adb.get_message_conversation(message_id)
    if not owning_conv or str(owning_conv).lower() != str(conversation_id).lower():
        await send_resp(msg, writer, "pin_message", "error", {"message": "Message not found"})
        return

    if action == "pin":
        ok = await adb.pin_message(message_id, session["user_id"], conversation_id)
    else:
        ok = await adb.unpin_message(message_id, conversation_id)

    if not ok:
        await send_resp(msg, writer, "pin_message", "error",
                        {"message": "Already pinned" if action == "pin" else "Not pinned"})
        return

    logger.info("[MSG] %s %s message=%s in conv=%s", _who(session), action, message_id[:8], conversation_id[:8])
    await send_resp(msg, writer, "pin_message", "ok", {"message_id": message_id, "action": action})

    members = await adb.get_conversation_members(conversation_id)
    notif_type = "message_pinned" if action == "pin" else "message_unpinned"
    notif_data = {
        "message_id": message_id,
        "conversation_id": conversation_id,
        "user_id": session["user_id"],
        "username": session.get("username", ""),
    }
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, notif_type, notif_data)
|
|
|
|
|
|
async def handle_get_pinned_messages(msg: dict, session: dict, writer: ProtocolWriter):
    """Return the pinned messages of a conversation the caller is a member of.

    Replies ``{"messages": [...]}`` on success. On failure it replies with an
    error message: "Missing conversation_id", "Invalid conversation_id" or
    "Not a member".
    """
    conv = msg.get("conversation_id", "")
    problem = None
    if not conv:
        problem = "Missing conversation_id"
    elif not _valid_uuid(conv):
        problem = "Invalid conversation_id"
    elif not await adb.is_conversation_member(conv, session["user_id"]):
        problem = "Not a member"

    if problem is not None:
        await send_resp(msg, writer, "get_pinned_messages", "error", {"message": problem})
        return

    pinned = await adb.get_pinned_messages(conv, session["user_id"])
    await send_resp(msg, writer, "get_pinned_messages", "ok", {"messages": pinned})
|
|
|
|
|
|
async def handle_upload_image_start(msg: dict, session: dict, writer: ProtocolWriter):
    """Begin a chunked upload of an encrypted image/file into a conversation.

    Request fields:
        conversation_id: UUID of the target conversation (caller must be a member).
        file_id: client-chosen UUID identifying the upload.
        file_size: declared total size in bytes (non-negative int).
        file_type: ``"file"`` uses the ``MAX_FILE_BYTES`` cap; any other value
            is treated as an image and uses ``MAX_IMAGE_BYTES`` (0 = no cap).

    On success an empty ``<file_id>.tmp`` (mode 0600) is created and the upload
    is registered in ``pending_uploads`` and in the DB; the client then sends
    ``upload_image_chunk`` messages followed by ``upload_image_end``.

    Because ``file_id`` comes from the client, it is refused when another
    user already has it in flight or when it names a completed upload.
    Otherwise a second uploader could truncate or replace someone else's
    file (and the DB-failure rollback would delete their pending entry).
    """
    user_id = session["user_id"]
    conv_id = msg.get("conversation_id", "")
    file_size = msg.get("file_size", 0)
    file_id = msg.get("file_id", "")
    file_type = msg.get("file_type", "image")  # "image" or "file"

    if not conv_id or not file_id:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Missing fields"})
        return
    if not _valid_uuid(file_id):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid conversation_id"})
        return
    # bool is an int subclass, so exclude it explicitly. A non-int size would
    # make the cap comparison below raise. A negative size can never equal
    # the received byte count checked at upload_image_end.
    if isinstance(file_size, bool) or not isinstance(file_size, int) or file_size < 0:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_size"})
        return

    # M5: rate limit + caps on in-flight uploads
    if await _is_rate_limited(f"upload_start|{user_id}", 10):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Too many uploads. Try later."})
        return
    if not await adb.is_conversation_member(conv_id, user_id):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Not a member"})
        return

    max_bytes = MAX_FILE_BYTES if file_type == "file" else MAX_IMAGE_BYTES
    if max_bytes > 0 and file_size > max_bytes:
        await send_resp(msg, writer, "upload_image_start", "error",
                        {"message": f"File too large (max {max_bytes} bytes)"})
        return

    # Never let a new upload be finalised over an already-completed file.
    existing = await adb.get_image_upload(file_id)
    if existing and existing["completed"]:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "File already exists"})
        return

    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    os.chmod(UPLOAD_DIR, 0o700)
    temp_path = _safe_upload_path(file_id, ".tmp")
    if not temp_path:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"})
        return

    # M5: atomic duplicate/cap check + insert under a single lock acquisition
    cap_error = ""
    async with _uploads_lock:
        current = pending_uploads.get(file_id)
        user_count = sum(1 for u in pending_uploads.values() if u.get("uploader_id") == user_id)
        if current is not None and current.get("uploader_id") != user_id:
            # Someone else's in-flight upload: do not truncate or overwrite it.
            cap_error = "Upload already in progress"
        elif len(pending_uploads) >= MAX_UPLOADS_GLOBAL:
            cap_error = "Server upload limit reached. Try later."
        elif user_count >= MAX_UPLOADS_PER_USER:
            cap_error = "Too many active uploads. Finish or cancel existing ones."
        else:
            temp_path.write_bytes(b"")
            os.chmod(temp_path, 0o600)
            pending_uploads[file_id] = {
                "temp_path": str(temp_path),
                "received_bytes": 0,
                "file_size": file_size,
                "max_bytes": max_bytes,
                "conv_id": conv_id,
                "uploader_id": user_id,
            }
    if cap_error:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": cap_error})
        return

    try:
        await adb.create_image_upload(file_id, conv_id, user_id, file_size)
    except Exception:
        # Rollback: remove our pending entry + delete our temp file
        async with _uploads_lock:
            pending_uploads.pop(file_id, None)
        _secure_delete(temp_path)
        logger.exception("[UPLOAD] DB create failed for file=%s", file_id[:8])
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Upload failed"})
        return

    logger.info("[UPLOAD] %s started upload file=%s (%s, %d bytes)",
                _who(session), file_id[:8], file_type, file_size)
    await send_resp(msg, writer, "upload_image_start", "ok", {"file_id": file_id})
|
|
|
|
|
|
async def handle_upload_image_chunk(msg: dict, session: dict, writer: ProtocolWriter):
    """Append one encoded chunk to the caller's in-flight upload.

    Request fields: ``file_id`` and ``data`` (the chunk, decoded with
    ``decode_binary``). Replies ``{"received": <total bytes so far>}``.

    Errors:
        * "Missing fields": either field is empty.
        * "No active upload": no pending upload with this id is owned by the
          caller, or it was finalised, aborted or cleaned up while this chunk
          was being written to disk.
        * "Upload exceeds size limit": the running total passed the cap
          recorded at start; the upload is dropped and its temp file deleted.
    """
    file_id = msg.get("file_id", "")
    chunk_data = msg.get("data", "")
    if not file_id or not chunk_data:
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Missing fields"})
        return

    async with _uploads_lock:
        upload = pending_uploads.get(file_id)
        if upload is not None and upload["uploader_id"] == session["user_id"]:
            temp_path = Path(upload["temp_path"])
            upload_max = upload.get("max_bytes", 0)
        else:
            upload = None
    if upload is None:
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"})
        return

    raw = decode_binary(chunk_data)
    # The append runs without the lock held, so the upload may be finalised,
    # aborted or cleaned up (or its id reused) before we account for it.
    await asyncio.to_thread(_append_file, temp_path, raw)

    received = None
    over_limit = False
    orphaned = False
    async with _uploads_lock:
        current = pending_uploads.get(file_id)
        if current is upload:
            current["received_bytes"] += len(raw)
            received = current["received_bytes"]
            if upload_max > 0 and received > upload_max:
                pending_uploads.pop(file_id, None)
                over_limit = True
        elif current is None:
            orphaned = True

    if received is None:
        if orphaned:
            # Nobody tracks this upload any more; our append may have
            # re-created a stray temp file.
            _secure_delete(temp_path)
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"})
        return
    if over_limit:
        _secure_delete(temp_path)
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Upload exceeds size limit"})
        return
    await send_resp(msg, writer, "upload_image_chunk", "ok", {"received": received})
|
|
|
|
|
|
async def handle_upload_image_end(msg: dict, session: dict, writer: ProtocolWriter):
    """Finalise the caller's upload.

    The upload's received byte count must equal the size declared at start.
    The ``.tmp`` file is then moved to ``<file_id>.enc`` (mode 0600) and the
    DB row is marked complete.

    The pending entry is removed only after confirming the caller owns it.
    Previously any user who knew a ``file_id`` could cancel someone else's
    upload this way. If the move fails, the temp file is deleted and
    "Upload failed" is returned.
    """
    file_id = msg.get("file_id", "")
    if not file_id:
        await send_resp(msg, writer, "upload_image_end", "error", {"message": "Missing file_id"})
        return

    async with _uploads_lock:
        upload = pending_uploads.get(file_id)
        if upload is not None and upload["uploader_id"] == session["user_id"]:
            pending_uploads.pop(file_id, None)
        else:
            upload = None
    if upload is None:
        await send_resp(msg, writer, "upload_image_end", "error", {"message": "No active upload"})
        return

    temp_path = Path(upload["temp_path"])
    if upload["received_bytes"] != upload["file_size"]:
        _secure_delete(temp_path)
        await send_resp(msg, writer, "upload_image_end", "error",
                        {"message": f"Incomplete upload: received {upload['received_bytes']} of {upload['file_size']} bytes"})
        return

    final_path = _safe_upload_path(file_id, ".enc")
    if not final_path:
        _secure_delete(temp_path)
        await send_resp(msg, writer, "upload_image_end", "error", {"message": "Invalid file_id"})
        return

    def _move_file():
        try:
            temp_path.rename(final_path)
        except OSError:
            # rename() cannot cross filesystems; shutil.move copies instead.
            import shutil
            shutil.move(str(temp_path), str(final_path))
        os.chmod(final_path, 0o600)

    try:
        await asyncio.to_thread(_move_file)
    except OSError:  # shutil.Error is an OSError subclass
        _secure_delete(temp_path)
        logger.exception("[UPLOAD] Failed to finalise file=%s", file_id[:8])
        await send_resp(msg, writer, "upload_image_end", "error", {"message": "Upload failed"})
        return

    await adb.complete_image_upload(file_id)
    logger.info("[UPLOAD] %s completed upload file=%s (%d bytes)",
                _who(session), file_id[:8], upload["received_bytes"])
    await send_resp(msg, writer, "upload_image_end", "ok", {"file_id": file_id})
|
|
|
|
|
|
async def _validate_download(msg: dict, session: dict, writer: ProtocolWriter, resp_type: str):
    """Validate a download request and locate the stored ``.enc`` file.

    Checks, in order:
        * ``file_id`` is present and a valid UUID.
        * The DB has a completed upload for it.
        * The caller is a member of the upload's conversation.
        * The file exists on disk.

    On any failure an error response of type *resp_type* is sent and ``None``
    is returned. On success, returns ``(file_path, file_size)``. The size
    comes from a single ``stat()``, so a file removed concurrently yields
    "File not found" rather than an unhandled error.
    """
    file_id = msg.get("file_id", "")
    if not file_id:
        await send_resp(msg, writer, resp_type, "error", {"message": "Missing file_id"})
        return None
    if not _valid_uuid(file_id):
        await send_resp(msg, writer, resp_type, "error", {"message": "Invalid file_id"})
        return None

    upload = await adb.get_image_upload(file_id)
    if not upload or not upload["completed"]:
        await send_resp(msg, writer, resp_type, "error", {"message": "File not found"})
        return None
    if not await adb.is_conversation_member(upload["conversation_id"], session["user_id"]):
        await send_resp(msg, writer, resp_type, "error", {"message": "Not a member"})
        return None

    file_path = _safe_upload_path(file_id, ".enc")
    file_size = None
    if file_path:
        try:
            file_size = file_path.stat().st_size
        except OSError:  # missing (or unreadable) file
            file_size = None
    if file_size is None:
        await send_resp(msg, writer, resp_type, "error", {"message": "File not found"})
        return None
    return file_path, file_size
|
|
|
|
|
|
async def handle_download_image(msg: dict, session: dict, writer: ProtocolWriter):
    """Return one ``IMAGE_CHUNK_SIZE`` slice of a stored file.

    Request fields: ``file_id`` and ``offset`` (non-negative int, default 0).
    Replies with ``file_id``, the encoded ``data`` chunk, ``offset``,
    ``total_size`` and ``done`` (true once ``offset + len(chunk)`` reaches
    the file size). The client pages through the file by advancing
    ``offset``.
    """
    file_id = msg.get("file_id", "")
    offset = msg.get("offset", 0)
    # Reject non-integers (bool excluded explicitly) and negative values
    # before they reach _read_file_chunk.
    if isinstance(offset, bool) or not isinstance(offset, int) or offset < 0:
        await send_resp(msg, writer, "download_image", "error", {"message": "Invalid offset"})
        return

    result = await _validate_download(msg, session, writer, "download_image")
    if not result:
        return
    file_path, file_size = result

    chunk = await asyncio.to_thread(_read_file_chunk, file_path, offset, IMAGE_CHUNK_SIZE)
    done = (offset + len(chunk)) >= file_size
    if offset == 0:
        logger.info("[DOWNLOAD] %s downloading file=%s (%d bytes)", _who(session), file_id[:8], file_size)
    await send_resp(msg, writer, "download_image", "ok", {
        "file_id": file_id,
        "data": encode_binary(chunk),
        "offset": offset,
        "done": done,
        "total_size": file_size,
    })
|
|
|
|
|
|
async def handle_download_stream(msg: dict, session: dict, writer: ProtocolWriter):
    """Stream an entire file in chunks after a single auth check (no per-chunk round-trip).

    The server sends multiple ``download_stream`` responses that all carry
    the original ``request_id``. The client collects them until it sees
    ``done=True``. Each response has an increasing ``seq`` number so the
    client can reassemble chunks in order even if delivery is reordered.

    Exactly one frame carries ``done=True``. A zero-byte file produces one
    empty frame, so the client is never left waiting. If the file turns out
    shorter than the size seen at validation time, an ``error`` response is
    sent instead of silently stopping. If the peer disconnects mid-stream,
    streaming stops.
    """
    file_id = msg.get("file_id", "")
    result = await _validate_download(msg, session, writer, "download_stream")
    if not result:
        return
    file_path, file_size = result
    logger.info("[DOWNLOAD] %s streaming file=%s (%d bytes)", _who(session), file_id[:8], file_size)

    req_id = msg.get("request_id")
    offset = 0
    seq = 0
    while True:
        chunk = await asyncio.to_thread(_read_file_chunk, file_path, offset, IMAGE_CHUNK_SIZE)
        done = (offset + len(chunk)) >= file_size
        if not chunk and not done:
            logger.warning("[DOWNLOAD] file=%s ended early (offset %d of %d bytes)",
                           file_id[:8], offset, file_size)
            await send_resp(msg, writer, "download_stream", "error", {"message": "File read error"})
            return

        # Build response manually so every frame reuses the original request_id
        resp = {
            "type": "download_stream",
            "status": "ok",
            "data": {
                "file_id": file_id,
                "data": encode_binary(chunk),
                "offset": offset,
                "seq": seq,
                "done": done,
                "total_size": file_size,
            },
        }
        if req_id:
            resp["request_id"] = req_id
        data = json.dumps(resp, ensure_ascii=False).encode("utf-8") + b"\n"
        try:
            # Write straight to the underlying StreamWriter so drain() applies
            # back-pressure between chunks.
            writer._writer.write(data)
            await writer._writer.drain()
        except Exception:
            return  # peer went away; nothing left to deliver to

        if done:
            return
        offset += len(chunk)
        seq += 1
|
|
|
|
|
|
MAX_AVATAR_BYTES = 2 * 1024 ** 2  # 2 MiB cap on a single avatar upload
|
|
|
|
|
|
async def handle_get_profile(msg: dict, session: dict, writer: ProtocolWriter):
    """Send a user's profile, filtered by visibility for the viewing user.

    An empty or missing ``user_id`` means "my own profile". Datetime fields
    ``created_at`` / ``updated_at`` are converted to ISO-8601 strings so the
    reply can be serialised.
    """
    requested = msg.get("user_id", "").strip()
    if requested and not _valid_uuid(requested):
        await send_resp(msg, writer, "get_profile", "error", {"message": "Invalid user_id"})
        return
    target = requested or session["user_id"]

    profile = await adb.get_user_profile(target, viewer_id=session["user_id"])
    if not profile:
        await send_resp(msg, writer, "get_profile", "error", {"message": "User not found"})
        return

    for field in ("created_at", "updated_at"):
        value = profile.get(field)
        if value and hasattr(value, "isoformat"):
            profile[field] = value.isoformat()
    await send_resp(msg, writer, "get_profile", "ok", profile)
|
|
|
|
|
|
async def handle_update_profile(msg: dict, session: dict, writer: ProtocolWriter):
    """Update the caller's own profile with whichever editable fields are present.

    Only ``phone``, ``phone_visible``, ``email_visible``, ``location`` and
    ``location_visible`` are read from the request. Their values are passed
    to the DB layer unchanged.
    """
    editable = ("phone", "phone_visible", "email_visible", "location", "location_visible")
    changes = {name: msg[name] for name in editable if name in msg}
    if changes:
        await adb.update_user_profile(session["user_id"], **changes)
        await send_resp(msg, writer, "update_profile", "ok", {"message": "OK"})
    else:
        await send_resp(msg, writer, "update_profile", "error", {"message": "No fields to update"})
|
|
|
|
|
|
async def handle_update_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Store the caller's avatar (encoded ``data`` in a single message, max 2 MB).

    Saved as ``<user_id>.png`` when the payload starts with the PNG signature
    and as ``<user_id>.jpg`` otherwise. The profile's ``avatar_file`` is
    updated and the caller's contacts (except this connection) receive
    ``avatar_changed``. Switching format would otherwise leave the previous
    picture on disk under the other extension, so that file is deleted once
    the DB points at the new one.
    """
    if await _is_rate_limited(f"update_avatar|{session['user_id']}", 5):
        await send_resp(msg, writer, "update_avatar", "error", {"message": "Too many requests. Try later."})
        return

    avatar_b64 = msg.get("data", "")
    if not avatar_b64:
        await send_resp(msg, writer, "update_avatar", "error", {"message": "Missing data"})
        return
    avatar_data = decode_binary(avatar_b64)
    if len(avatar_data) > MAX_AVATAR_BYTES:
        await send_resp(msg, writer, "update_avatar", "error",
                        {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"})
        return

    # Detect format from magic bytes
    ext = "png" if avatar_data[:8] == b'\x89PNG\r\n\x1a\n' else "jpg"

    avatar_dir = UPLOAD_DIR / "avatars"
    avatar_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(avatar_dir, 0o700)
    filename = f"{session['user_id']}.{ext}"
    avatar_path = _safe_avatar_path(filename)
    if not avatar_path:
        await send_resp(msg, writer, "update_avatar", "error", {"message": "Invalid path"})
        return

    await asyncio.to_thread(avatar_path.write_bytes, avatar_data)
    os.chmod(avatar_path, 0o600)
    await adb.update_user_profile(session["user_id"], avatar_file=filename)

    # Drop a previous avatar saved under the other extension (best-effort).
    stale_ext = "jpg" if ext == "png" else "png"
    stale_path = _safe_avatar_path(f"{session['user_id']}.{stale_ext}")
    if stale_path:
        try:
            await asyncio.to_thread(stale_path.unlink, missing_ok=True)
        except OSError:
            logger.warning("[AVATAR] could not remove stale avatar %s", stale_path.name)

    logger.info("[AVATAR] %s updated their avatar", _who(session))
    await send_resp(msg, writer, "update_avatar", "ok", {"avatar_file": filename})

    # Notify contacts about avatar change
    contacts = await adb.get_user_contacts(session["user_id"])
    if contacts:
        await _notify_users(contacts, "avatar_changed", {
            "user_id": session["user_id"],
        }, exclude_writer=writer)
|
|
|
|
|
|
async def handle_get_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Send a user's stored avatar (encoded with ``encode_binary``) to the caller.

    ``user_id`` is required. No membership or contact check is made here.
    The reply carries ``user_id``, ``data`` and the stored ``filename``.
    """
    wanted = msg.get("user_id", "").strip()
    if not wanted or not _valid_uuid(wanted):
        reason = "Invalid user_id" if wanted else "Missing user_id"
        await send_resp(msg, writer, "get_avatar", "error", {"message": reason})
        return

    profile = await adb.get_user_profile(wanted)
    stored_name = profile.get("avatar_file") if profile else None
    if not stored_name:
        logger.debug("[AVATAR] get_avatar for %s — no avatar_file in profile", wanted[:8])
        await send_resp(msg, writer, "get_avatar", "error", {"message": "No avatar"})
        return

    path = _safe_avatar_path(stored_name)
    if not path or not path.exists():
        logger.warning("[AVATAR] get_avatar for %s — file missing: %s", wanted[:8], stored_name)
        await send_resp(msg, writer, "get_avatar", "error", {"message": "Avatar file not found"})
        return

    blob = await asyncio.to_thread(path.read_bytes)
    reply = {
        "user_id": wanted,
        "data": encode_binary(blob),
        "filename": stored_name,
    }
    await send_resp(msg, writer, "get_avatar", "ok", reply)
|
|
|
|
|
|
async def handle_update_group_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Store the avatar of a conversation (encoded ``data``, max 2 MB).

    Any member of the conversation may set it; no role check is made here.
    The avatar is saved as ``group_<conversation_id>.png`` for PNG payloads
    and ``.jpg`` otherwise, and recorded via ``update_conversation_avatar``.
    A previous avatar stored under the other extension is deleted so a
    replaced image does not linger on disk. This handler shares the
    ``update_avatar|<user_id>`` rate-limit bucket with personal avatar
    updates.
    """
    if await _is_rate_limited(f"update_avatar|{session['user_id']}", 5):
        await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Too many requests. Try later."})
        return

    conv_id = msg.get("conversation_id", "").strip()
    avatar_b64 = msg.get("data", "")
    if not conv_id or not avatar_b64:
        await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Missing fields"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid conversation_id"})
        return
    if not await adb.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Not a member"})
        return

    avatar_data = decode_binary(avatar_b64)
    if len(avatar_data) > MAX_AVATAR_BYTES:
        await send_resp(msg, writer, "update_group_avatar", "error",
                        {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"})
        return

    ext = "png" if avatar_data[:8] == b'\x89PNG\r\n\x1a\n' else "jpg"
    avatar_dir = UPLOAD_DIR / "avatars"
    avatar_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(avatar_dir, 0o700)
    filename = f"group_{conv_id}.{ext}"
    avatar_path = _safe_avatar_path(filename)
    if not avatar_path:
        await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid path"})
        return

    await asyncio.to_thread(avatar_path.write_bytes, avatar_data)
    os.chmod(avatar_path, 0o600)
    await adb.update_conversation_avatar(conv_id, filename)

    # Drop a previous avatar saved under the other extension (best-effort).
    stale_ext = "jpg" if ext == "png" else "png"
    stale_path = _safe_avatar_path(f"group_{conv_id}.{stale_ext}")
    if stale_path:
        try:
            await asyncio.to_thread(stale_path.unlink, missing_ok=True)
        except OSError:
            logger.warning("[AVATAR] could not remove stale group avatar %s", stale_path.name)

    logger.info("[AVATAR] %s updated group avatar for conv=%s", _who(session), conv_id[:8])
    await send_resp(msg, writer, "update_group_avatar", "ok", {"avatar_file": filename})
|
|
|
|
|
|
async def handle_get_group_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Send a conversation's stored avatar to a member of that conversation.

    The reply carries ``conversation_id``, the encoded ``data`` and the
    stored ``filename``.
    """
    group = msg.get("conversation_id", "").strip()
    rejection = None
    if not group:
        rejection = "Missing conversation_id"
    elif not _valid_uuid(group):
        rejection = "Invalid conversation_id"
    elif not await adb.is_conversation_member(group, session["user_id"]):
        rejection = "Not a member"
    if rejection:
        await send_resp(msg, writer, "get_group_avatar", "error", {"message": rejection})
        return

    conv = await adb.get_conversation(group)
    stored_name = conv.get("avatar_file") if conv else None
    if not stored_name:
        await send_resp(msg, writer, "get_group_avatar", "error", {"message": "No avatar"})
        return

    path = _safe_avatar_path(stored_name)
    if not path or not path.exists():
        await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Avatar file not found"})
        return

    blob = await asyncio.to_thread(path.read_bytes)
    reply = {
        "conversation_id": group,
        "data": encode_binary(blob),
        "filename": stored_name,
    }
    await send_resp(msg, writer, "get_group_avatar", "ok", reply)
|
|
|
|
|
|
async def handle_list_devices(msg: dict, session: dict, writer: ProtocolWriter):
    """List the caller's devices, flagging the one this session runs on.

    Timestamps are rendered with ``isoformat()`` when available and ``str()``
    otherwise. A missing or empty ``last_seen_at`` becomes ``None``.
    """
    def _stamp(value):
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    devices = await adb.get_user_devices(session["user_id"])
    current = session.get("device_id")
    listing = [
        {
            "device_id": dev["id"],
            "device_name": dev.get("device_name"),
            "created_at": _stamp(dev["created_at"]),
            "last_seen_at": _stamp(dev["last_seen_at"]) if dev.get("last_seen_at") else None,
            "is_current": dev["id"] == current,
        }
        for dev in devices
    ]
    await send_resp(msg, writer, "list_devices", "ok", {"devices": listing})
|
|
|
|
|
|
async def handle_remove_device(msg: dict, session: dict, writer: ProtocolWriter):
    """Remove one of the caller's other devices (the current one cannot be removed).

    After the device is deleted from the DB, any live connections registered
    for it in ``writer_device_map`` are closed. Otherwise a removed (e.g.
    lost or stolen) device would keep its already-authenticated session. The
    closed connections go through the normal disconnect cleanup in
    ``handle_client``.
    """
    device_id = msg.get("device_id", "").strip()
    if not device_id:
        await send_resp(msg, writer, "remove_device", "error", {"message": "Missing device_id"})
        return
    if not _valid_uuid(device_id):
        await send_resp(msg, writer, "remove_device", "error", {"message": "Invalid device_id"})
        return
    if device_id == session.get("device_id"):
        await send_resp(msg, writer, "remove_device", "error", {"message": "Cannot remove current device"})
        return

    dev = await adb.get_device(device_id)
    if not dev or dev["user_id"] != session["user_id"]:
        await send_resp(msg, writer, "remove_device", "error", {"message": "Device not found"})
        return

    await adb.delete_device(device_id)

    # Revoke the removed device's live connections (snapshot under the lock,
    # close outside it).
    async with _clients_lock:
        revoked = [
            w for w in connected_clients.get(session["user_id"], [])
            if writer_device_map.get(id(w)) == device_id
        ]
    for w in revoked:
        try:
            # Closing the underlying StreamWriter ends that connection's read loop.
            w._writer.close()
        except Exception:
            logger.debug("[DEVICE] failed to close connection of removed device=%s",
                         device_id[:8], exc_info=True)

    logger.info("[DEVICE] %s removed device=%s (%d live connection(s) closed)",
                _who(session), device_id[:8], len(revoked))
    await send_resp(msg, writer, "remove_device", "ok", {"message": "OK"})
|
|
|
|
|
|
async def handle_session_reset(msg: dict, session: dict, writer: ProtocolWriter):
    """Ask a peer to rebuild a corrupted Double Ratchet session with the caller.

    If ``peer_device_id`` is given, only that device's live connections are
    notified; otherwise every connection of ``peer_user_id`` is. The request
    is rate-limited per user (H3: 5 per window, IP-independent) and only
    allowed between users who share a conversation.
    """
    peer = msg.get("peer_user_id", "").strip()
    peer_device = msg.get("peer_device_id", "").strip() or None

    if not peer or not _valid_uuid(peer):
        await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_user_id"})
        return
    if peer_device and not _valid_uuid(peer_device):
        await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_device_id"})
        return
    # H3 fix: rate limit keyed by user_id only
    if await _is_rate_limited(f"session_reset|{session['user_id']}", 5):
        await send_resp(msg, writer, "session_reset", "error", {"message": "Rate limit exceeded"})
        return
    # H3 fix: peers must share at least one conversation
    if not await adb.shares_conversation(session["user_id"], peer):
        await send_resp(msg, writer, "session_reset", "error", {"message": "No shared conversation"})
        return

    notice = {
        "from_user_id": session["user_id"],
        "from_device_id": session.get("device_id"),
    }
    if peer_device is None:
        await _notify_users([peer], "session_reset", notice)
    else:
        async with _clients_lock:
            recipients = [
                w for w in connected_clients.get(peer, [])
                if writer_device_map.get(id(w)) == peer_device
            ]
        for w in recipients:
            try:
                await w.send_response("session_reset", "ok", notice)
            except Exception:
                pass  # best-effort: a dead peer connection is not our error

    logger.info("[SESSION] %s reset session with peer=%s", _who(session), peer[:8])
    await send_resp(msg, writer, "session_reset", "ok", {})
|
|
|
|
|
|
async def handle_get_deleted_since(msg: dict, session: dict, writer: ProtocolWriter):
    """Report the IDs of messages in a conversation deleted after ``since_ts``.

    ``since_ts`` is forwarded to the DB layer as received; it is not parsed
    here. The caller must be a member of the conversation.
    """
    conversation = msg.get("conversation_id", "")
    since = msg.get("since_ts", "")
    failure = None
    if not (conversation and since):
        failure = "Missing parameters"
    elif not _valid_uuid(conversation):
        failure = "Invalid conversation_id"
    elif not await adb.is_conversation_member(conversation, session["user_id"]):
        failure = "Not a member"
    if failure:
        await send_resp(msg, writer, "get_deleted_since", "error", {"message": failure})
        return

    ids = await adb.get_deleted_messages_since(conversation, session["user_id"], since)
    await send_resp(msg, writer, "get_deleted_since", "ok", {"deleted_ids": ids})
|
|
|
|
|
|
async def handle_reencrypt_messages(msg: dict, session: dict, writer: ProtocolWriter):
    """Re-encrypt message history with self-encryption key (for device pairing).

    ``updates`` must be a list of at most 500 objects, each with:
        * ``message_id``: a UUID string.
        * ``encrypted_content``: decoded with ``decode_binary``.
        * ``nonce``: decoded with ``decode_binary``.

    Entries that are not objects, lack a field, or carry an invalid message
    ID are skipped. The remaining ones are handed to
    ``batch_reencrypt_messages`` for the caller's user ID, and the reply
    reports how many were submitted.
    """
    if await _is_rate_limited(f"reencrypt|{session['user_id']}", 10):
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "Too many requests. Try later."})
        return

    updates_raw = msg.get("updates", [])
    if not updates_raw:
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No updates"})
        return
    if not isinstance(updates_raw, list):
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "Invalid updates"})
        return
    if len(updates_raw) > 500:
        await send_resp(msg, writer, "reencrypt_messages", "error",
                        {"message": "Too many updates (max 500 per request)"})
        return

    updates = []
    for u in updates_raw:
        if not isinstance(u, dict):
            continue  # malformed entry; previously crashed the whole request
        mid = u.get("message_id", "")
        ct_b64 = u.get("encrypted_content", "")
        nonce_b64 = u.get("nonce", "")
        if not (isinstance(mid, str) and _valid_uuid(mid)):
            continue
        if not ct_b64 or not nonce_b64:
            continue
        updates.append({
            "message_id": mid,
            "encrypted_content": decode_binary(ct_b64),
            "nonce": decode_binary(nonce_b64),
        })

    if not updates:
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No valid updates"})
        return

    await adb.batch_reencrypt_messages(session["user_id"], updates)
    logger.info("[REENCRYPT] %s re-encrypted %d messages", _who(session), len(updates))
    await send_resp(msg, writer, "reencrypt_messages", "ok", {"count": len(updates)})
|
|
|
|
|
|
async def _cleanup_uploads():
    """Purge uploads the DB reports as stale (per ``UPLOAD_STALE_SECONDS``).

    For each stale record:
        1. The in-memory ``pending_uploads`` entry is dropped first, so no
           further chunk is accepted (and re-creates the temp file) mid-cleanup.
        2. Its ``.tmp`` and ``.enc`` files are deleted.
        3. The DB row is removed.

    A failure on one record is logged and does not abort the rest of the
    sweep.
    """
    stale = await adb.get_stale_uploads(UPLOAD_STALE_SECONDS)
    cleaned = 0
    for s in stale:
        fid = s["file_id"]
        try:
            async with _uploads_lock:
                pending_uploads.pop(fid, None)
            for ext in (".tmp", ".enc"):
                p = _safe_upload_path(fid, ext)
                if p:
                    _secure_delete(p)
            await adb.delete_image_upload(fid)
            cleaned += 1
        except Exception:
            logger.exception("Failed to clean up stale upload file=%s", str(fid)[:8])
    if cleaned:
        logger.info("Cleaned up %d stale uploads.", cleaned)
|
|
|
|
|
|
async def _release_connection_slot(addr) -> None:
    """Undo the connection accounting done for *addr* when it was accepted.

    Per-address counters that fall to zero are removed, so
    ``connection_counts`` does not keep one entry for every peer address
    ever seen.
    """
    global current_connections
    async with _conn_lock:
        current_connections = max(0, current_connections - 1)
        remaining = connection_counts.get(addr, 1) - 1
        if remaining > 0:
            connection_counts[addr] = remaining
        else:
            connection_counts.pop(addr, None)


async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Serve one client connection from accept to disconnect.

    * Connections beyond ``MAX_CONNECTIONS_GLOBAL`` or
      ``MAX_CONNECTIONS_PER_IP`` are closed immediately. Accepted sockets get
      TCP keepalive where the platform supports it.
    * A sliding-window limit allows ``CONNECTION_RL_MAX`` requests per
      ``CONNECTION_RL_WINDOW`` seconds (event-loop clock) on each
      connection; ``upload_image_chunk`` is exempt.
    * Messages are dispatched by ``type``. Registration, login and pairing
      bootstrap work without a session; everything else needs a successful
      ``login_finish`` on this connection.
    * On disconnect, the connection slot is released and this writer is
      removed from ``connected_clients``. If it was the user's last
      connection, ``user_offline`` is sent to the user's connected contacts.
      A failed contacts lookup no longer skips the registry cleanup.
    """
    global current_connections
    proto_writer = ProtocolWriter(writer)
    addr = _get_peer_addr(proto_writer)

    async with _conn_lock:
        current_connections += 1
        connection_counts[addr] = connection_counts.get(addr, 0) + 1
        over_limit = (current_connections > MAX_CONNECTIONS_GLOBAL or
                      connection_counts[addr] > MAX_CONNECTIONS_PER_IP)
    if over_limit:
        try:
            writer.close()
        except Exception:
            pass
        await _release_connection_slot(addr)
        return

    logger.info("[CONN] Client connected from %s", addr)

    # Enable TCP keepalive on the socket to detect dead connections
    sock = writer.get_extra_info("socket")
    if sock is not None:
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            if hasattr(socket, "TCP_KEEPIDLE"):
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, TCP_KEEPALIVE_IDLE)
            if hasattr(socket, "TCP_KEEPINTVL"):
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, TCP_KEEPALIVE_INTERVAL)
            if hasattr(socket, "TCP_KEEPCNT"):
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, TCP_KEEPALIVE_COUNT)
        except OSError:
            pass  # Some platforms/TLS wrappers don't support these options

    proto_reader = ProtocolReader(reader)
    session = None
    state = {"_req_times": []}
    loop = asyncio.get_running_loop()

    try:
        while True:
            try:
                msg = await proto_reader.read_message()
            except ValueError as e:
                try:
                    await proto_writer.send_response("protocol_error", "error", {"message": str(e)})
                except Exception:
                    pass
                break
            if msg is None:
                break

            msg_type = msg.get("type", "")
            now = loop.time()
            # Upload chunks are exempt from per-connection rate limit —
            # a single file upload can legitimately send 20+ chunks in rapid
            # succession. The upload subsystem has its own guards (per-user
            # upload cap, per-user rate limit on upload_image_start, and
            # file-size validation) so double-throttling is unnecessary.
            _rl_exempt = msg_type == "upload_image_chunk"
            if not _rl_exempt:
                times = [t for t in state["_req_times"] if now - t <= CONNECTION_RL_WINDOW]
                if len(times) >= CONNECTION_RL_MAX:
                    await send_resp(msg, proto_writer, msg_type, "error", {"message": "Too many requests. Slow down."})
                    state["_req_times"] = times
                    continue
                times.append(now)
                state["_req_times"] = times

            try:
                if msg_type == "register":
                    await handle_register_start(msg, proto_writer)
                elif msg_type == "register_confirm":
                    await handle_register_confirm(msg, proto_writer)
                elif msg_type == "login_start":
                    await handle_login_start(msg, proto_writer, state)
                elif msg_type == "login_finish":
                    result = await handle_login_finish(msg, proto_writer, state)
                    if result:
                        session = result
                elif msg_type == "pairing_start":
                    await handle_pairing_start(msg, proto_writer)
                elif msg_type == "pairing_poll":
                    await handle_pairing_poll(msg, proto_writer)
                elif session is None:
                    await send_resp(msg, proto_writer, msg_type, "error", {"message": "Not logged in"})
                elif msg_type == "get_user_info":
                    await handle_get_user_info(msg, session, proto_writer)
                elif msg_type == "upload_prekeys":
                    await handle_upload_prekeys(msg, session, proto_writer)
                elif msg_type == "get_key_bundle":
                    await handle_get_key_bundle(msg, session, proto_writer)
                elif msg_type == "get_prekey_count":
                    await handle_get_prekey_count(msg, session, proto_writer)
                elif msg_type == "ensure_prekeys":
                    await handle_ensure_prekeys(msg, session, proto_writer)
                elif msg_type == "create_conversation":
                    await handle_create_conversation(msg, session, proto_writer)
                elif msg_type == "find_conversation":
                    await handle_find_conversation(msg, session, proto_writer)
                elif msg_type == "add_member":
                    await handle_add_member(msg, session, proto_writer)
                elif msg_type == "accept_invitation":
                    await handle_accept_invitation(msg, session, proto_writer)
                elif msg_type == "decline_invitation":
                    await handle_decline_invitation(msg, session, proto_writer)
                elif msg_type == "list_invitations":
                    await handle_list_invitations(msg, session, proto_writer)
                elif msg_type == "list_conversations":
                    await handle_list_conversations(msg, session, proto_writer)
                elif msg_type == "send_message":
                    await handle_send_message(msg, session, proto_writer)
                elif msg_type == "get_messages":
                    await handle_get_messages(msg, session, proto_writer)
                elif msg_type == "rotate_keys":
                    await handle_rotate_keys(msg, session, proto_writer)
                elif msg_type == "change_username":
                    await handle_change_username(msg, session, proto_writer)
                elif msg_type == "remove_member":
                    await handle_remove_member(msg, session, proto_writer)
                elif msg_type == "leave_group":
                    await handle_leave_group(msg, session, proto_writer)
                elif msg_type == "rename_conversation":
                    await handle_rename_conversation(msg, session, proto_writer)
                elif msg_type == "delete_conversation":
                    await handle_delete_conversation(msg, session, proto_writer)
                elif msg_type == "mark_read":
                    await handle_mark_read(msg, session, proto_writer)
                elif msg_type == "mark_conversation_read":
                    await handle_mark_conversation_read(msg, session, proto_writer)
                elif msg_type == "confirm_delivery":
                    await handle_confirm_delivery(msg, session, proto_writer)
                elif msg_type == "typing_start":
                    await handle_typing_start(msg, session, proto_writer)
                elif msg_type == "typing_stop":
                    await handle_typing_stop(msg, session, proto_writer)
                elif msg_type == "pairing_claim":
                    await handle_pairing_claim(msg, session, proto_writer)
                elif msg_type == "pairing_send":
                    await handle_pairing_send(msg, session, proto_writer)
                elif msg_type == "delete_message":
                    await handle_delete_message(msg, session, proto_writer)
                elif msg_type == "upload_image_start":
                    await handle_upload_image_start(msg, session, proto_writer)
                elif msg_type == "upload_image_chunk":
                    await handle_upload_image_chunk(msg, session, proto_writer)
                elif msg_type == "upload_image_end":
                    await handle_upload_image_end(msg, session, proto_writer)
                elif msg_type == "download_image":
                    await handle_download_image(msg, session, proto_writer)
                elif msg_type == "download_stream":
                    await handle_download_stream(msg, session, proto_writer)
                elif msg_type == "get_profile":
                    await handle_get_profile(msg, session, proto_writer)
                elif msg_type == "update_profile":
                    await handle_update_profile(msg, session, proto_writer)
                elif msg_type == "update_avatar":
                    await handle_update_avatar(msg, session, proto_writer)
                elif msg_type == "get_avatar":
                    await handle_get_avatar(msg, session, proto_writer)
                elif msg_type == "update_group_avatar":
                    await handle_update_group_avatar(msg, session, proto_writer)
                elif msg_type == "get_group_avatar":
                    await handle_get_group_avatar(msg, session, proto_writer)
                elif msg_type == "get_deleted_since":
                    await handle_get_deleted_since(msg, session, proto_writer)
                elif msg_type == "reencrypt_messages":
                    await handle_reencrypt_messages(msg, session, proto_writer)
                elif msg_type == "list_devices":
                    await handle_list_devices(msg, session, proto_writer)
                elif msg_type == "remove_device":
                    await handle_remove_device(msg, session, proto_writer)
                elif msg_type == "session_reset":
                    await handle_session_reset(msg, session, proto_writer)
                elif msg_type == "react_message":
                    await handle_react_message(msg, session, proto_writer)
                elif msg_type == "pin_message":
                    await handle_pin_message(msg, session, proto_writer)
                elif msg_type == "get_pinned_messages":
                    await handle_get_pinned_messages(msg, session, proto_writer)
                else:
                    await send_resp(msg, proto_writer, msg_type, "error", {"message": "Unknown type"})
            except Exception as e:
                logger.warning("[ERROR] %s handler '%s' failed: %s", _who(session), msg_type, e, exc_info=True)
                try:
                    await send_resp(msg, proto_writer, msg_type, "error", {"message": "Internal server error"})
                except Exception:
                    break  # Can't send response — connection is dead
    except Exception as e:
        logger.warning("Client connection error: %s", e)
    finally:
        await _release_connection_slot(addr)

        offline_targets = []
        if session:
            uid = session["user_id"]
            # Contacts are fetched before taking the lock so the snapshot
            # below is atomic with the removal. A DB failure here must not
            # skip the removal: a stale writer would keep receiving relays and
            # the user would appear online forever.
            try:
                contacts = await adb.get_user_contacts(uid)
            except Exception:
                logger.warning("[CONN] could not load contacts of %s for offline notice",
                               _who(session), exc_info=True)
                contacts = []
            async with _clients_lock:
                writer_device_map.pop(id(proto_writer), None)
                if uid in connected_clients:
                    remaining = [w for w in connected_clients[uid] if w is not proto_writer]
                    if remaining:
                        connected_clients[uid] = remaining
                    else:
                        del connected_clients[uid]
                        # User fully offline — snapshot targets under lock
                        for contact_id in contacts:
                            offline_targets.extend(connected_clients.get(contact_id, []))
            # Send offline notifications outside lock
            for cw in offline_targets:
                try:
                    await cw.send_response("user_offline", "ok", {"user_id": uid})
                except Exception:
                    pass  # best-effort: contact's connection may be gone
        writer.close()
        logger.info("[CONN] %s disconnected", _who(session) if session else addr)
|
|
|
|
|
|
async def main():
    """Configure TLS, storage and caches, then serve clients until SIGINT/SIGTERM.

    Environment variables read here:
      SERVER_HOST / SERVER_PORT     -- listen address (default 127.0.0.1:9999)
      TLS_ENABLED                   -- wrap the listener in TLS
      TLS_REQUIRED                  -- refuse to start unless TLS_ENABLED is set
      TLS_CERT_FILE / TLS_KEY_FILE  -- certificate chain and private key files
      TLS_AUTOGEN                   -- only with ENVIRONMENT=dev/development: create a
                                       self-signed cert in a ``certs/`` directory next
                                       to this file when cert/key paths are not given
      ENVIRONMENT                   -- "dev"/"development" (case-insensitive)
      THREAD_POOL_SIZE              -- worker count of the default executor (default 40)

    Boolean variables accept "1", "true" or "yes" (case-insensitive).

    Raises:
        RuntimeError: inconsistent TLS settings, TLS_AUTOGEN outside dev, or a
            failed self-signed certificate generation.
    """
    setup_logging()

    def env_flag(name: str) -> bool:
        # Shared parsing for boolean env vars; unset means False.
        return os.getenv(name, "false").lower() in ("1", "true", "yes")

    host = os.getenv("SERVER_HOST", "127.0.0.1")
    port = int(os.getenv("SERVER_PORT", "9999"))
    tls_enabled = env_flag("TLS_ENABLED")
    tls_required = env_flag("TLS_REQUIRED")
    tls_autogen = env_flag("TLS_AUTOGEN")

    is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development")
    ssl_context = None
    if tls_required and not tls_enabled:
        raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.")

    if tls_enabled:
        cert_file = os.getenv("TLS_CERT_FILE", "").strip()
        key_file = os.getenv("TLS_KEY_FILE", "").strip()
        if not cert_file or not key_file:
            if tls_autogen:
                if not is_dev:
                    raise RuntimeError("TLS_AUTOGEN is only allowed when ENVIRONMENT=dev")
                cert_dir = Path(__file__).resolve().parent / "certs"
                cert_dir.mkdir(parents=True, exist_ok=True)
                cert_file = str(cert_dir / "server.crt")
                key_file = str(cert_dir / "server.key")
                # Reuse a previously generated pair; only create one if either file is missing.
                if not (os.path.exists(cert_file) and os.path.exists(key_file)):
                    try:
                        subprocess.run(
                            [
                                "openssl", "req", "-x509", "-newkey", "rsa:4096",
                                "-keyout", key_file, "-out", cert_file,
                                "-days", "365", "-nodes", "-subj", "/CN=localhost",
                            ],
                            check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                        )
                    except FileNotFoundError as err:
                        # The openssl binary itself could not be executed.
                        raise RuntimeError("OpenSSL not found.") from err
                    except subprocess.CalledProcessError as err:
                        raise RuntimeError("Failed to auto-generate TLS cert.") from err
                    else:
                        # "-nodes" leaves the key unencrypted: restrict it to the owner.
                        os.chmod(key_file, 0o600)
                logger.warning("Using auto-generated self-signed certificate — not for production use.")
            else:
                raise RuntimeError("TLS is enabled but TLS_CERT_FILE or TLS_KEY_FILE is missing.")
        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file)
    else:
        logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.")

    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    os.chmod(UPLOAD_DIR, 0o700)

    loop = asyncio.get_running_loop()

    # Thread pool for asyncio.to_thread() — DB calls + file I/O.
    # to_thread() runs on the loop's default executor, so size that executor here.
    pool_workers = int(os.getenv("THREAD_POOL_SIZE", "40"))
    loop.set_default_executor(ThreadPoolExecutor(max_workers=pool_workers))
    logger.info("Thread pool executor: %d workers", pool_workers)

    # Load phantom user IDs from DB into in-memory cache
    phantom_user_ids.update(await adb.get_all_phantom_user_ids())
    if phantom_user_ids:
        logger.info("Loaded %d phantom user IDs.", len(phantom_user_ids))

    server = await asyncio.start_server(
        handle_client, host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context,
    )
    logger.info("Encrypted chat server v%s listening on %s:%s", VERSION, host, port)

    async def _cleanup_rate_limits():
        """Drop rate-limit keys with no hit inside the window, and non-positive connection counters."""
        async with _conn_lock:
            window_start = loop.time() - RATE_LIMIT_WINDOW
            stale_keys = [k for k, times in rate_limits.items()
                          if not any(t >= window_start for t in times)]
            for k in stale_keys:
                del rate_limits[k]
            stale_conns = [k for k, v in connection_counts.items() if v <= 0]
            for k in stale_conns:
                del connection_counts[k]

    async def _refresh_phantom_cache():
        """Resync the in-memory phantom_user_ids set with the DB."""
        before = set(phantom_user_ids)
        # Query before taking the lock: holding _clients_lock across a DB
        # round-trip would stall every connect/disconnect.
        fresh = set(await adb.get_all_phantom_user_ids())
        async with _clients_lock:
            # No await between these two mutations, so other coroutines never
            # observe a half-built (e.g. empty) set. Only IDs that were cached
            # before the query and are absent from the DB now are removed;
            # IDs added concurrently while the query ran are preserved.
            phantom_user_ids.difference_update(before - fresh)
            phantom_user_ids.update(fresh)

    async def _periodic_cleanup():
        """Run housekeeping every 120 s; each step is isolated so one failure
        does not skip the others. Metadata retention runs every 30th cycle (~1 hour)."""
        cycle = 0
        while True:
            await asyncio.sleep(120)
            cycle += 1
            try:
                await _cleanup_uploads()
            except Exception as e:
                logger.warning("Upload cleanup error: %s", e)
            try:
                await _cleanup_rate_limits()
            except Exception as e:
                logger.warning("Rate limit cleanup error: %s", e)
            try:
                await _cleanup_registrations()
            except Exception as e:
                logger.warning("Registration cleanup error: %s", e)
            # L8: clean up stale phantom users (>30 days, no real conversations)
            try:
                deleted = await adb.cleanup_stale_phantoms(30)
                if deleted:
                    await _refresh_phantom_cache()
                    logger.info("Cleaned up %d stale phantom users.", deleted)
            except Exception as e:
                logger.warning("Phantom cleanup error: %s", e)
            # Metadata retention: purge old reads and reactions (every 30 cycles = ~1 hour)
            if cycle % 30 == 0:
                try:
                    reads_del = await adb.cleanup_old_reads(METADATA_RETENTION_DAYS)
                    reactions_del = await adb.cleanup_old_reactions(METADATA_RETENTION_DAYS)
                    if reads_del or reactions_del:
                        logger.info("Metadata cleanup: %d reads, %d reactions purged",
                                    reads_del, reactions_del)
                except Exception as e:
                    logger.warning("Metadata cleanup error: %s", e)

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    cleanup_task = asyncio.create_task(_periodic_cleanup())

    stop = loop.create_future()

    def signal_handler():
        # Either signal resolves `stop` exactly once, triggering shutdown below.
        if not stop.done():
            stop.set_result(None)

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, signal_handler)
        except NotImplementedError:
            # e.g. Windows event loops: keep running and rely on asyncio.run()'s
            # default Ctrl+C handling instead of failing at startup.
            logger.warning("Cannot install %s handler on this platform; graceful shutdown unavailable.",
                           sig.name)

    async with server:
        await stop
        logger.info("Shutting down — closing %d client connections...",
                    sum(len(ws) for ws in connected_clients.values()))
        # Stop accepting new connections
        server.close()
        cleanup_task.cancel()

        # Force-close all connected client writers
        async with _clients_lock:
            all_writers = [w for writers in connected_clients.values() for w in writers]
            connected_clients.clear()
            writer_device_map.clear()
        for w in all_writers:
            try:
                w.close()
            except Exception:
                pass  # best-effort: the transport may already be closed

        # Give handle_client loops a moment to notice closed connections
        await asyncio.sleep(0.1)

        # Cancel any remaining tasks (blocked handle_client loops, housekeeping)
        # so that leaving `async with server` does not wait on open connections.
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
    logger.info("Server shut down.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run main() on a fresh event loop until it returns
    # (after SIGINT/SIGTERM triggers the graceful shutdown sequence).
    asyncio.run(main())
|