2054 lines
90 KiB
Python
2054 lines
90 KiB
Python
"""Asyncio TCP server — stores and relays encrypted blobs without seeing content."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import secrets
|
|
import signal
|
|
import smtplib
|
|
import ssl
|
|
import subprocess
|
|
import sys
|
|
from email.mime.text import MIMEText
|
|
from pathlib import Path
|
|
from datetime import datetime, timezone
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()
|
|
|
|
import db
|
|
from crypto_utils import load_public_key, rsa_verify, load_ed25519_public, ed25519_verify, serialize_x25519_public
|
|
from protocol import VERSION, MIN_CLIENT_VERSION, version_gte, ProtocolReader, ProtocolWriter, encode_binary, decode_binary, MAX_MESSAGE_BYTES, MAX_IMAGE_BYTES, MAX_FILE_BYTES, IMAGE_CHUNK_SIZE
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Shared in-memory server state
# ---------------------------------------------------------------------------

# Logged-in connections per account: user_id -> every ProtocolWriter that is
# currently authenticated as that user.
connected_clients: dict[str, list[ProtocolWriter]] = {}

# Device bound to each live connection: id(writer) -> device_id.
writer_device_map: dict[int, str] = {}

# Device-pairing sessions in flight: pairing code -> session data.
pairing_sessions: dict[str, dict] = {}

# Registrations awaiting confirmation: registration code -> pending data.
pending_registrations: dict[str, dict] = {}

# Uploads in progress: file_id -> {temp_path, received_bytes, file_size, conv_id}
pending_uploads: dict[str, dict] = {}

# Phantom user IDs (loaded at startup, updated on create/delete).
phantom_user_ids: set[str] = set()

# Locks for the shared mutable state above (H4 race-condition fix).
# Guards connected_clients, writer_device_map, phantom_user_ids.
_clients_lock = asyncio.Lock()
# Guards connection_counts, current_connections, rate_limits.
_conn_lock = asyncio.Lock()
# Guards pairing_sessions, pending_registrations.
_pairing_lock = asyncio.Lock()
# Guards pending_uploads.
_uploads_lock = asyncio.Lock()
|
|
|
|
UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "uploads"))
|
|
|
|
# C6 fix: UUID validation + safe path construction to prevent path traversal
|
|
_UUID_RE = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE)
|
|
|
|
|
|
def _valid_uuid(value: str) -> bool:
|
|
"""Validate that value is a canonical UUID (no path components)."""
|
|
return bool(_UUID_RE.match(value))
|
|
|
|
|
|
# L8 fix: email validation to prevent phantom DB inflation
|
|
_EMAIL_RE = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
|
|
|
|
|
|
def _valid_email(email: str) -> bool:
|
|
"""Validate basic email format (L8)."""
|
|
return bool(_EMAIL_RE.match(email)) and len(email) <= 254
|
|
|
|
|
|
def _append_file(path: Path, data: bytes):
|
|
"""Append data to file (runs in thread pool to avoid blocking event loop)."""
|
|
with open(path, "ab") as f:
|
|
f.write(data)
|
|
|
|
|
|
def _read_file_chunk(path: Path, offset: int, size: int) -> bytes:
|
|
"""Read a chunk from file (runs in thread pool to avoid blocking event loop)."""
|
|
with open(path, "rb") as f:
|
|
f.seek(offset)
|
|
return f.read(size)
|
|
|
|
|
|
def _safe_upload_path(file_id: str, suffix: str) -> Path | None:
    """Resolve ``UPLOAD_DIR/<file_id><suffix>`` and confine it to UPLOAD_DIR.

    Returns the resolved absolute path. Returns None when the name would
    escape UPLOAD_DIR (``..``, an absolute component, or a symlink pointing
    outside). It also returns None when the name resolves to UPLOAD_DIR
    itself (e.g. an empty or "." name), which is never a valid upload file.
    """
    base = UPLOAD_DIR.resolve()
    candidate = (UPLOAD_DIR / f"{file_id}{suffix}").resolve()
    if candidate == base or not candidate.is_relative_to(base):
        return None
    return candidate
|
|
|
|
|
|
def _safe_avatar_path(filename: str) -> Path | None:
    """Resolve ``UPLOAD_DIR/avatars/<filename>`` and confine it to that directory.

    Returns the resolved absolute path. Returns None when *filename* would
    escape the avatars directory (``..``, an absolute path, or a symlink
    pointing outside). It also returns None when *filename* resolves to the
    avatars directory itself (e.g. "" or ".").
    """
    avatar_dir = (UPLOAD_DIR / "avatars").resolve()
    candidate = (UPLOAD_DIR / "avatars" / filename).resolve()
    if candidate == avatar_dir or not candidate.is_relative_to(avatar_dir):
        return None
    return candidate
|
|
|
|
|
|
# Lifetimes and limits for in-memory pairing / registration state.
PAIRING_TTL_SECONDS = 120          # device-pairing code lifetime
REGISTER_TTL_SECONDS = 3600        # registration code lifetime (1 hour)
PAIRING_MAX_POLL_ATTEMPTS = 90     # polls allowed before a payload-less code is dropped

# SMTP settings used to email registration codes (empty host => dev mode,
# where the code is returned to the client directly).
SMTP_HOST = os.getenv("SMTP_HOST", "")
SMTP_PORT = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER = os.getenv("SMTP_USER", "")
SMTP_PASS = os.getenv("SMTP_PASS", "")
SMTP_FROM = os.getenv("SMTP_FROM", "")

# Rate limiting and connection caps.
RATE_LIMIT_WINDOW = 60.0           # seconds; sliding window for per-action limits
CONNECTION_RL_WINDOW = 1.0         # seconds
CONNECTION_RL_MAX = 20             # max requests per window per connection
MAX_CONNECTIONS_PER_IP = 10
MAX_CONNECTIONS_GLOBAL = 200
|
|
|
|
|
|
def setup_logging():
    """Configure root logging from the LOG_LEVEL environment variable.

    LOG_LEVEL defaults to INFO and is case-insensitive. A name that is not an
    integer level constant of the ``logging`` module falls back to WARNING.
    Without the isinstance check, values such as "basic_format" (which maps
    to the string ``logging.BASIC_FORMAT``) reached basicConfig and raised.
    """
    level_name = os.getenv("LOG_LEVEL", "INFO").upper()
    level = getattr(logging, level_name, None)
    if not isinstance(level, int):
        level = logging.WARNING
    logging.basicConfig(level=level, format="%(levelname)s: %(message)s")
|
|
|
|
|
|
logger = logging.getLogger("encrypted_chat.server")

# Sliding-window rate limiting: "<action>|<addr>[|<email>]" -> hit timestamps.
rate_limits: dict[str, list[float]] = {}
# Open-connection counters; presumably keyed by peer address (see
# MAX_CONNECTIONS_PER_IP) — confirm against the connection handler.
connection_counts: dict[str, int] = {}
current_connections = 0
|
|
|
|
|
|
def _rate_limit_key(action: str, addr: str, email: str | None = None) -> str:
|
|
if email:
|
|
return f"{action}|{addr}|{email}"
|
|
return f"{action}|{addr}"
|
|
|
|
|
|
# Loop time of the last sweep of idle rate-limit buckets (see _is_rate_limited).
_rate_limit_last_sweep = 0.0


async def _is_rate_limited(key: str, limit: int) -> bool:
    """Record a hit on bucket *key*; return True if *limit* was already reached.

    Uses a sliding window of RATE_LIMIT_WINDOW seconds. A rejected attempt
    is not recorded, so a client that keeps retrying is unblocked once its
    older hits age out.

    Keys embed client-chosen values (e.g. an email), so buckets used to pile
    up forever. At most once per window, buckets with no hit inside the
    current window are now evicted. Timestamps are appended in loop-time
    order, so a bucket's last entry is its newest.
    """
    global _rate_limit_last_sweep
    async with _conn_lock:
        now = asyncio.get_running_loop().time()
        window_start = now - RATE_LIMIT_WINDOW
        recent = [t for t in rate_limits.get(key, []) if t >= window_start]
        limited = len(recent) >= limit
        if not limited:
            recent.append(now)
        rate_limits[key] = recent

        if now - _rate_limit_last_sweep >= RATE_LIMIT_WINDOW:
            _rate_limit_last_sweep = now
            idle = [k for k, times in rate_limits.items()
                    if not times or times[-1] < window_start]
            for k in idle:
                del rate_limits[k]
        return limited
|
|
|
|
|
|
def _get_peer_addr(writer: ProtocolWriter) -> str:
    """Return the remote host of *writer*'s connection as a string.

    Reads ``peername`` from the wrapped ``writer._writer`` (presumably an
    asyncio StreamWriter — confirm in protocol.py). Any failure, such as a
    missing peername or a closed transport, yields "unknown".
    """
    try:
        host = writer._writer.get_extra_info("peername")[0]
        return str(host)
    except Exception:
        return "unknown"
|
|
|
|
|
|
async def _notify_users(user_ids, msg_type, data, exclude_writer=None):
    """Push a ``msg_type`` message carrying *data* to every live connection of *user_ids*.

    Writers are snapshotted under _clients_lock and written to after it is
    released, so a slow peer cannot hold the lock. Each writer is notified at
    most once, even if a user_id appears several times in *user_ids*.
    *exclude_writer* (typically the requester) is skipped. Delivery is
    best-effort: a failed send is logged at debug level and otherwise
    ignored.
    """
    targets = {}  # id(writer) -> writer; dedupes while keeping first-seen order
    async with _clients_lock:
        for uid in user_ids:
            for w in connected_clients.get(uid, []):
                if w is not exclude_writer:
                    targets.setdefault(id(w), w)
    for w in targets.values():
        try:
            await w.send_response(msg_type, "ok", data)
        except Exception as exc:  # peer likely disconnected mid-send
            logger.debug("Notification %s not delivered: %s", msg_type, exc)
|
|
|
|
|
|
async def _notify_users_individual(notifications, exclude_writer=None):
    """Send per-user messages to every live connection of each recipient.

    *notifications* is an iterable of ``(user_id, msg_type, data)`` tuples.
    Each tuple is delivered to all of that user's writers except
    *exclude_writer*. Writers are snapshotted under _clients_lock and written
    to after it is released. Best-effort: a failed send is logged at debug
    level only.
    """
    deliveries = []
    async with _clients_lock:
        for uid, msg_type, data in notifications:
            deliveries.extend(
                (w, msg_type, data)
                for w in connected_clients.get(uid, [])
                if w is not exclude_writer
            )
    for w, msg_type, data in deliveries:
        try:
            await w.send_response(msg_type, "ok", data)
        except Exception as exc:  # peer likely disconnected mid-send
            logger.debug("Notification %s not delivered: %s", msg_type, exc)
|
|
|
|
|
|
async def _cleanup_pairings():
    """Evict every pairing session older than PAIRING_TTL_SECONDS (in place)."""
    async with _pairing_lock:
        current = asyncio.get_event_loop().time()
        # Walk a snapshot of the keys so entries can be deleted while iterating.
        for pairing_code in tuple(pairing_sessions):
            age = current - pairing_sessions[pairing_code]["created_at"]
            if age > PAIRING_TTL_SECONDS:
                del pairing_sessions[pairing_code]
|
|
|
|
|
|
async def _cleanup_registrations():
    """Evict every pending registration older than REGISTER_TTL_SECONDS (in place)."""
    async with _pairing_lock:
        current = asyncio.get_event_loop().time()
        # Walk a snapshot of the keys so entries can be deleted while iterating.
        for reg_code in tuple(pending_registrations):
            age = current - pending_registrations[reg_code]["created_at"]
            if age > REGISTER_TTL_SECONDS:
                del pending_registrations[reg_code]
|
|
|
|
|
|
def _generate_pairing_code() -> str:
    """Return a random, zero-padded 8-digit pairing code.

    Uses ``secrets.randbelow`` for a uniform draw. The previous
    ``urandom(4) % 10**8`` favoured low codes because 2**32 is not a multiple
    of 10**8. The function retries up to 10 times to avoid codes already in
    pairing_sessions. If every retry collides, a fresh random code is
    returned anyway, and it may replace an existing session. The caller in
    handle_pairing_start holds _pairing_lock around this call.
    """
    for _ in range(10):
        code = f"{secrets.randbelow(100_000_000):08d}"
        if code not in pairing_sessions:
            return code
    return f"{secrets.randbelow(100_000_000):08d}"
|
|
|
|
|
|
def _generate_register_code() -> str:
    """Return a random, zero-padded 6-digit registration code.

    Uses ``secrets.randbelow`` for a uniform draw. The previous
    ``urandom(3) % 10**6`` made codes below 777216 about 6% more likely,
    because 2**24 is not a multiple of 10**6. The function retries up to 10
    times to avoid codes already in pending_registrations. If every retry
    collides, a fresh random code is returned anyway. The caller in
    handle_register_start holds _pairing_lock around this call.
    """
    for _ in range(10):
        code = f"{secrets.randbelow(1_000_000):06d}"
        if code not in pending_registrations:
            return code
    return f"{secrets.randbelow(1_000_000):06d}"
|
|
|
|
def _validate_public_key_pem(pem_str: str) -> bool:
    """Return True if *pem_str* loads as a public key of at least 2048 bits.

    Any loading error, or a key object lacking ``key_size``, counts as
    invalid.
    NOTE(review): only the size is checked here. Whether non-RSA PEMs (e.g.
    DSA) are rejected depends on crypto_utils.load_public_key — confirm.
    """
    try:
        parsed = load_public_key(pem_str.encode("utf-8"))
        return parsed.key_size >= 2048
    except Exception:
        return False
|
|
|
|
|
|
def _send_registration_email(to_email: str, code: str) -> bool:
    """Email the registration *code* to *to_email* over SMTP + STARTTLS.

    This is blocking network I/O (10 s socket timeout), so run it in a worker
    thread when called from the event loop.

    Returns:
        False without sending when SMTP_HOST is unset. False, after logging a
        warning, on any message, SMTP, TLS or socket error. True once the
        server has accepted the message.
    """
    if not SMTP_HOST:
        return False
    try:
        msg = MIMEText(f"Your registration code is: {code}\n\nThis code expires in 1 hour.")
        msg["Subject"] = "Encrypted Chat - Registration Code"
        msg["From"] = SMTP_FROM or SMTP_USER
        msg["To"] = to_email
        # smtplib's starttls() without a context does not verify the server
        # certificate. A network attacker could then read the code and the
        # SMTP credentials. Verify against the system trust store, with
        # hostname checking. (Servers with self-signed certs now need a
        # trusted CA.)
        tls_context = ssl.create_default_context()
        with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=10) as server:
            server.starttls(context=tls_context)
            if SMTP_USER:
                server.login(SMTP_USER, SMTP_PASS)
            server.send_message(msg)
        return True
    except Exception as e:
        logger.warning("Failed to send registration email: %s", e)
        return False
|
|
|
|
|
|
async def send_resp(msg: dict, writer: ProtocolWriter, msg_type: str, status: str, data: dict | None = None):
    """Reply to request *msg* on *writer*.

    Echoes the request's ``request_id`` (None when absent) alongside
    *msg_type*, *status* and *data*.
    """
    request_id = msg.get("request_id")
    await writer.send_response(msg_type, status, data, request_id=request_id)
|
|
|
|
|
|
async def handle_register_start(msg: dict, writer: ProtocolWriter) -> dict | None:
    """Start registration: validate keys, store a pending registration, send a code.

    Expected fields in *msg*: ``username``, ``email``, ``public_key`` (RSA
    PEM, at least 2048 bits) and ``identity_key`` (base64 of a 32-byte
    Ed25519 key).

    With SMTP configured, the code is emailed and the client only ever sees
    "Code sent to your email.". The same reply is given for
    already-registered emails (H3 anti-enumeration) and when delivery fails,
    so neither case lets a client probe accounts or obtain the code. Without
    SMTP (dev mode), the code is returned in the response.

    Always returns None (the connection stays unauthenticated).
    """
    await _cleanup_registrations()
    username = msg.get("username", "").strip()
    public_key = msg.get("public_key", "").strip()
    identity_key_b64 = msg.get("identity_key", "").strip()
    email = msg.get("email", "").strip()
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("register_start", addr, email), 3):
        await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."})
        return None
    if not username or not public_key or not email or not identity_key_b64:
        await send_resp(msg, writer, "register_start", "error", {"message": "Missing fields"})
        return None
    if not _validate_public_key_pem(public_key):
        await send_resp(msg, writer, "register_start", "error", {"message": "Invalid public key format"})
        return None

    # The identity key must decode to exactly 32 bytes and load as Ed25519.
    try:
        ik_bytes = decode_binary(identity_key_b64)
        if len(ik_bytes) != 32:
            raise ValueError("Identity key must be 32 bytes")
        load_ed25519_public(ik_bytes)
    except Exception:
        await send_resp(msg, writer, "register_start", "error", {"message": "Invalid identity key"})
        return None

    existing = db.get_user_by_email(email)
    phantom_id = None
    if existing:
        if existing.get("rsa_public_key") == "PHANTOM":
            # Don't delete. register_confirm upgrades the phantom in place so
            # FK references (group_invitations, conversation_members, etc.)
            # survive.
            phantom_id = existing["id"]
        else:
            # H3 anti-enumeration: answer exactly like the success path. No
            # code exists for this request, so it can never be confirmed.
            logger.debug("Registration attempt for existing email (hidden from client).")
            await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."})
            return None

    async with _pairing_lock:
        code = _generate_register_code()
        pending_registrations[code] = {
            "username": username,
            "public_key": public_key,
            "identity_key": ik_bytes,
            "email": email,
            "created_at": asyncio.get_event_loop().time(),
            "phantom_id": phantom_id,
        }
    logger.info("Registration started.")

    if not SMTP_HOST:
        logger.warning("No SMTP configured — returning code directly (dev mode).")
        await send_resp(msg, writer, "register_start", "ok", {"code": code})
        return None

    # smtplib blocks for up to its 10 s timeout per operation, so run it off
    # the event loop.
    email_sent = await asyncio.to_thread(_send_registration_email, email, code)
    if not email_sent:
        logger.warning("SMTP configured but email failed for %s", email)
        # Never fall back to revealing the code. That would let anyone who
        # can make delivery fail (e.g. with an undeliverable address) register
        # an email they do not control, or take over its phantom account. The
        # code can no longer reach anyone, so drop it; the client may retry.
        async with _pairing_lock:
            pending_registrations.pop(code, None)
    await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."})
    return None
|
|
|
|
|
|
async def handle_register_confirm(msg: dict, writer: ProtocolWriter) -> dict | None:
    """Finish registration: redeem the confirmation code and create the account.

    The code is single-use and must have been issued for the submitted
    email. If registration was started for a phantom (invited but
    unregistered) email, that phantom row is upgraded in place. Otherwise,
    or if the phantom was deleted meanwhile, a fresh user is created. Replies
    with the new ``user_id``.

    Always returns None.
    """
    await _cleanup_registrations()
    email = msg.get("email", "").strip()
    code = msg.get("code", "").strip()
    addr = _get_peer_addr(writer)
    bucket = _rate_limit_key("register_confirm", addr, email)
    if await _is_rate_limited(bucket, 3):
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Too many attempts. Try later."})
        return None
    if not (email and code):
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Missing email or code"})
        return None

    # Take the pending entry out under the lock so a code is redeemed only once.
    async with _pairing_lock:
        pending = pending_registrations.get(code)
        if pending and pending.get("email") == email:
            del pending_registrations[code]
        else:
            pending = None
    if pending is None:
        await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"})
        return None

    username = pending["username"]
    public_key = pending["public_key"]
    identity_key = pending["identity_key"]
    phantom_id = pending.get("phantom_id")

    user_id = None
    if phantom_id:
        # Upgrade in place to preserve FK references (invitations, memberships).
        user_id = db.upgrade_phantom_user(phantom_id, username, public_key, identity_key)
        if user_id:
            async with _clients_lock:
                phantom_user_ids.discard(phantom_id)
    if not user_id:
        # No phantom to upgrade, or it was deleted concurrently: create a user.
        user_id = db.create_user(username, pending["email"], public_key, identity_key)

    db.create_default_profile(user_id)
    logger.info("User registered.")
    await send_resp(msg, writer, "register_confirm", "ok", {"user_id": user_id})
    return None
|
|
|
|
|
|
async def handle_login_start(msg: dict, writer: ProtocolWriter, state: dict):
    """Issue a fresh 32-byte login challenge for ``msg["email"]``.

    The email and challenge are stored in the per-connection *state* for
    handle_login_finish. Unknown emails also receive a challenge (H3
    anti-enumeration), but ``state["_login_fake"]`` is set so the finish step
    fails with a generic error. The flag is cleared for known emails.
    Otherwise an earlier login_start for an unknown email on the same
    connection would make a later genuine login fail.
    """
    email = msg.get("email", "").strip()
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("login_start", addr, email), 10):
        await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."})
        return
    if not email:
        await send_resp(msg, writer, "login_start", "error", {"message": "Missing email"})
        return
    user = db.get_user_by_email(email)
    challenge = os.urandom(32)
    state["login_email"] = email
    state["login_challenge"] = challenge
    if user:
        state.pop("_login_fake", None)
    else:
        # H3: indistinguishable reply; login_finish will fail generically.
        state["_login_fake"] = True
    await send_resp(msg, writer, "login_start", "ok", {"challenge": encode_binary(challenge)})
|
|
|
|
|
|
async def handle_login_finish(msg: dict, writer: ProtocolWriter, state: dict) -> dict | None:
    """Complete challenge-response login and register this connection.

    Verifies ``msg["signature"]`` (base64) over the challenge issued by
    handle_login_start, using the account's stored RSA public key. Every
    authentication failure gets the same "Invalid credentials" reply (H3).
    Once the email matches the pending challenge, the challenge is discarded
    regardless of the outcome.

    On success it:
    - rejects clients older than MIN_CLIENT_VERSION;
    - reuses the client's device_id if it belongs to this user, otherwise
      creates a device;
    - adds *writer* to connected_clients (only once, even on a repeated login
      over the same connection) and replies;
    - sends this connection the list of online contacts;
    - if this is the user's first live connection, tells online contacts the
      user came online.

    Returns the session dict (user_id, username, email, device_id) on
    success, else None.
    """
    email = msg.get("email", "").strip()
    signature_b64 = msg.get("signature", "")
    challenge = state.get("login_challenge")
    expected_email = state.get("login_email")
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("login_finish", addr, email), 10):
        await send_resp(msg, writer, "login_finish", "error", {"message": "Too many attempts. Try later."})
        return None
    if not email or not signature_b64:
        await send_resp(msg, writer, "login_finish", "error", {"message": "Missing email or signature"})
        return None
    if not challenge or expected_email != email:
        await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
        return None

    # H3: if login_start was for a non-existent user, fail with the generic error.
    is_fake = state.pop("_login_fake", False)
    try:
        if is_fake:
            await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
            return None
        user = db.get_user_by_email(email)
        if not user:
            await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
            return None
        public_key = load_public_key(user["rsa_public_key"].encode("utf-8"))
        signature = decode_binary(signature_b64)
        if not rsa_verify(public_key, signature, challenge):
            await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
            return None
    except ValueError:
        # H5: invalid base64 in the signature (presumably also an unparsable
        # stored key, such as a phantom's placeholder — confirm).
        await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"})
        return None
    finally:
        # Challenges are single-use.
        state.pop("login_challenge", None)
        state.pop("login_email", None)

    user_id = user["id"]

    # Version check: reject outdated clients.
    client_version = msg.get("client_version", "")
    if client_version and not version_gte(client_version, MIN_CLIENT_VERSION):
        await send_resp(msg, writer, "login_finish", "error", {
            "message": f"Client version {client_version} is too old. Minimum required: {MIN_CLIENT_VERSION}",
            "min_version": MIN_CLIENT_VERSION,
            "server_version": VERSION,
        })
        return None

    # Device registration: reuse the client's device_id only if it is ours.
    device_id = None
    client_device_id = msg.get("device_id")
    if client_device_id:
        dev = db.get_device(client_device_id)
        if dev and dev["user_id"] == user_id:
            device_id = client_device_id
    if not device_id:
        device_id = db.create_device(user_id, msg.get("device_name", "Unknown"))
    db.update_device_last_seen(device_id)

    async with _clients_lock:
        writers = connected_clients.setdefault(user_id, [])
        was_offline = not writers
        # A repeated login on the same connection must not register the writer
        # twice. Duplicates cause double deliveries and would leave a stale
        # entry after disconnect.
        if not any(w is writer for w in writers):
            writers.append(writer)
        writer_device_map[id(writer)] = device_id
    logger.info("User logged in (device %s, client v%s).", device_id, client_version or "unknown")
    await send_resp(msg, writer, "login_finish", "ok", {
        "user_id": user_id, "username": user["username"], "email": user["email"],
        "device_id": device_id, "server_version": VERSION,
    })

    # Online-status notifications: snapshot under the lock, send outside it.
    contacts = db.get_user_contacts(user_id)
    online_targets = []
    async with _clients_lock:
        online_contacts = [cid for cid in contacts if connected_clients.get(cid)]
        if was_offline:
            for contact_id in contacts:
                online_targets.extend(connected_clients.get(contact_id, []))
    await writer.send_response("online_users", "ok", {"user_ids": online_contacts})
    for cw in online_targets:
        try:
            await cw.send_response("user_online", "ok", {"user_id": user_id})
        except Exception as exc:  # best-effort; the contact may have dropped
            logger.debug("user_online notification not delivered: %s", exc)

    return {"user_id": user_id, "username": user["username"], "email": user["email"],
            "device_id": device_id}
|
|
|
|
|
|
async def handle_get_user_info(msg: dict, writer: ProtocolWriter):
    """Return a user's public profile and identity key (needed for X3DH).

    The lookup is by ``email`` when given, else by ``user_id`` (which must be
    a UUID). ``identity_key`` is base64, or "" when the user has none.
    """
    email = msg.get("email", "").strip()
    user_id = msg.get("user_id", "").strip()
    addr = _get_peer_addr(writer)
    bucket = _rate_limit_key("get_user_info", addr, email or user_id)
    if await _is_rate_limited(bucket, 30):
        await send_resp(msg, writer, "get_user_info", "error", {"message": "Too many attempts. Try later."})
        return
    if user_id and not _valid_uuid(user_id):
        await send_resp(msg, writer, "get_user_info", "error", {"message": "Invalid user_id"})
        return

    if email:
        user = db.get_user_by_email(email)
    elif user_id:
        user = db.get_user_by_id(user_id)
    else:
        user = None
    if not user:
        await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"})
        return

    identity_key = user.get("identity_key")
    payload = {
        "user_id": user["id"],
        "username": user["username"],
        "email": user["email"],
        "identity_key": encode_binary(identity_key) if identity_key else "",
    }
    await send_resp(msg, writer, "get_user_info", "ok", payload)
|
|
|
|
|
|
async def handle_upload_prekeys(msg: dict, session: dict, writer: ProtocolWriter):
    """Store a signed prekey (SPK) plus a batch of one-time prekeys (OTPs).

    The keys are stored for the caller's current device.

    ``msg["signed_prekey"]`` is ``{"id", "public_key", "signature"}``, with
    the key and signature in base64. Its signature must verify under the
    account's Ed25519 identity key. ``msg["one_time_prekeys"]`` is a list of
    ``{"id", "public_key"}`` objects; entries missing either field are
    skipped.

    Everything is decoded and checked before anything is written. Malformed
    base64 or shapes get an error reply instead of an unhandled exception,
    and a bad OTP can no longer leave a new SPK stored without its OTPs.
    """
    spk_data = msg.get("signed_prekey")
    otps = msg.get("one_time_prekeys") or []
    if not spk_data:
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Missing signed_prekey"})
        return
    if not isinstance(spk_data, dict) or not isinstance(otps, list):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Malformed prekeys"})
        return

    spk_id = spk_data.get("id", "")
    spk_pub_b64 = spk_data.get("public_key", "")
    spk_sig_b64 = spk_data.get("signature", "")
    if not spk_id or not spk_pub_b64 or not spk_sig_b64:
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Incomplete signed_prekey"})
        return

    # Decode everything up front (decode_binary raises ValueError on bad base64).
    try:
        spk_pub = decode_binary(spk_pub_b64)
        spk_sig = decode_binary(spk_sig_b64)
        otp_records = []
        for otp in otps:
            if not isinstance(otp, dict):
                raise ValueError("one-time prekey entry is not an object")
            otp_id = otp.get("id", "")
            otp_pub_b64 = otp.get("public_key", "")
            if otp_id and otp_pub_b64:
                otp_records.append({"id": otp_id, "public_key": decode_binary(otp_pub_b64)})
    except (ValueError, TypeError):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid prekey encoding"})
        return

    # Verify the SPK signature with the user's identity key.
    user = db.get_user_by_id(session["user_id"])
    if not user or not user.get("identity_key"):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "No identity key"})
        return
    ik_pub = load_ed25519_public(user["identity_key"])
    if not ed25519_verify(ik_pub, spk_sig, spk_pub):
        await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid SPK signature"})
        return

    device_id = session.get("device_id")
    db.store_signed_prekey(session["user_id"], spk_id, spk_pub, spk_sig, device_id=device_id)
    if otp_records:
        db.store_one_time_prekeys(session["user_id"], otp_records, device_id=device_id)

    logger.info("Prekeys uploaded: 1 SPK + %d OTPs (device %s)", len(otp_records), device_id)
    await send_resp(msg, writer, "upload_prekeys", "ok", {"message": "OK"})
|
|
|
|
|
|
async def handle_get_key_bundle(msg: dict, session: dict, writer: ProtocolWriter):
    """Return per-device X3DH key bundles for ``msg["user_id"]``.

    Each bundle carries the device's signed prekey and, when
    db.get_key_bundles_for_user supplied one, a one-time prekey (one OTP is
    consumed per device). For legacy clients, the first device's bundle is
    also flattened into top-level fields. *session* is not used.
    """
    target_user_id = msg.get("user_id", "").strip()
    if not target_user_id:
        await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Missing user_id"})
        return
    if not _valid_uuid(target_user_id):
        await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Invalid user_id"})
        return

    result = db.get_key_bundles_for_user(target_user_id)
    if not result or not result.get("device_bundles"):
        await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Key bundle not available"})
        return

    bundles = []
    for raw in result["device_bundles"]:
        entry = {
            "device_id": raw.get("device_id"),
            "signed_prekey_id": raw["signed_prekey_id"],
            "signed_prekey": encode_binary(raw["signed_prekey_pub"]),
            "spk_signature": encode_binary(raw["spk_signature"]),
        }
        if raw.get("opk_pub"):
            entry.update(
                one_time_prekey_id=raw["opk_id"],
                one_time_prekey=encode_binary(raw["opk_pub"]),
            )
        bundles.append(entry)

    # New multi-device list plus legacy flat fields from the first bundle.
    legacy = bundles[0] if bundles else {}
    response = {
        "identity_key": encode_binary(result["identity_key"]),
        "device_bundles": bundles,
        "signed_prekey_id": legacy.get("signed_prekey_id", ""),
        "signed_prekey": legacy.get("signed_prekey", ""),
        "spk_signature": legacy.get("spk_signature", ""),
    }
    if legacy.get("one_time_prekey"):
        response["one_time_prekey_id"] = legacy["one_time_prekey_id"]
        response["one_time_prekey"] = legacy["one_time_prekey"]
    await send_resp(msg, writer, "get_key_bundle", "ok", response)
|
|
|
|
|
|
async def handle_get_prekey_count(msg: dict, session: dict, writer: ProtocolWriter):
    """Report the OTPs left for this device and its SPK creation time.

    The creation time lets the client decide when to rotate. It is rendered
    via ``isoformat()`` when available, otherwise ``str()``, and is "" when
    unknown.
    """
    device_id = session.get("device_id")
    user_id = session["user_id"]
    remaining = db.count_one_time_prekeys(user_id, device_id=device_id)
    spk = db.get_signed_prekey(user_id, device_id=device_id)
    created = spk.get("created_at") if spk else None
    if not created:
        spk_created_at = ""
    elif hasattr(created, "isoformat"):
        spk_created_at = created.isoformat()
    else:
        spk_created_at = str(created)
    await send_resp(msg, writer, "get_prekey_count", "ok",
                    {"count": remaining, "spk_created_at": spk_created_at})
|
|
|
|
|
|
async def handle_rotate_keys(msg: dict, session: dict, writer: ProtocolWriter):
    """Replace the caller's RSA public key, then drop their other connections.

    After a successful update, only the requesting connection stays in
    connected_clients for this user. Every other connection is closed,
    best-effort.
    """
    public_key = msg.get("public_key", "").strip()
    if not public_key:
        await send_resp(msg, writer, "rotate_keys", "error", {"message": "Missing public_key"})
        return
    if not _validate_public_key_pem(public_key):
        await send_resp(msg, writer, "rotate_keys", "error", {"message": "Invalid public key format"})
        return

    db.update_user_rsa_key(session["user_id"], public_key)
    logger.info("RSA key rotated.")
    await send_resp(msg, writer, "rotate_keys", "ok", {"message": "OK"})

    # Detach the other sessions under the lock; close them after releasing it.
    user_id = session["user_id"]
    async with _clients_lock:
        stale = [w for w in connected_clients.get(user_id, []) if w is not writer]
        connected_clients[user_id] = [writer]
    for old in stale:
        try:
            # NOTE(review): assumes ProtocolWriter.close() is synchronous — confirm.
            old.close()
        except Exception:
            pass
|
|
|
|
|
|
async def handle_pairing_start(msg: dict, writer: ProtocolWriter):
    """Open a device-pairing session for a new, not-yet-logged-in device.

    The new device's temporary public key is stored under a fresh 8-digit
    code. The reply carries the code plus a secret ``poll_token``, which the
    new device must present to handle_pairing_poll. A logged-in device of
    the same account later claims the code and sends the payload.
    """
    await _cleanup_pairings()
    email = msg.get("email", "").strip()
    temp_public_key = msg.get("temp_public_key", "").strip()
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("pairing_start", addr, email), 10):
        await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."})
        return
    if not (email and temp_public_key):
        await send_resp(msg, writer, "pairing_start", "error", {"message": "Missing email or temp_public_key"})
        return
    if not db.get_user_by_email(email):
        await send_resp(msg, writer, "pairing_start", "error", {"message": "User not found"})
        return

    poll_token = secrets.token_hex(16)
    async with _pairing_lock:
        code = _generate_pairing_code()
        pairing_sessions[code] = {
            "email": email,
            "temp_public_key": temp_public_key,
            "created_at": asyncio.get_event_loop().time(),
            "payload": None,
            "poll_token": poll_token,
        }
    await send_resp(msg, writer, "pairing_start", "ok", {"code": code, "poll_token": poll_token})
|
|
|
|
|
|
async def handle_pairing_claim(msg: dict, session: dict, writer: ProtocolWriter):
    """Let a logged-in device claim a pairing code started for its own account.

    Replies with the new device's ``temp_public_key``. A successful claim
    refreshes the session's TTL, because re-encryption may run between claim
    and send. The refresh now happens only after the ownership check.
    Previously, any authenticated user hitting a valid code, including one
    rejected as unauthorized, could keep another account's pairing session
    alive.
    """
    await _cleanup_pairings()
    code = msg.get("code", "").strip()
    if not code:
        await send_resp(msg, writer, "pairing_claim", "error", {"message": "Missing code"})
        return

    error_msg = None
    temp_pub = None
    async with _pairing_lock:
        p = pairing_sessions.get(code)
        if not p:
            error_msg = "Invalid or expired code"
        elif p["email"] != session.get("email"):
            error_msg = "Not authorized for this code"
        else:
            temp_pub = p["temp_public_key"]
            # Extend TTL — re-encryption may run between claim and send.
            p["created_at"] = asyncio.get_event_loop().time()

    if error_msg:
        await send_resp(msg, writer, "pairing_claim", "error", {"message": error_msg})
        return
    await send_resp(msg, writer, "pairing_claim", "ok", {"temp_public_key": temp_pub})
|
|
|
|
|
|
async def handle_pairing_send(msg: dict, session: dict, writer: ProtocolWriter):
    """Attach the pairing payload to *code*; only the owning account may do so.

    The payload is stored as received, for the new device to collect through
    handle_pairing_poll. A later send for the same code overwrites it.
    """
    await _cleanup_pairings()
    code = msg.get("code", "").strip()
    payload = msg.get("payload")
    if not code or not payload:
        await send_resp(msg, writer, "pairing_send", "error", {"message": "Missing code or payload"})
        return

    async with _pairing_lock:
        pairing = pairing_sessions.get(code)
        if not pairing:
            status, body = "error", {"message": "Invalid or expired code"}
        elif pairing["email"] != session.get("email"):
            status, body = "error", {"message": "Not authorized for this code"}
        else:
            pairing["payload"] = payload
            status, body = "ok", {"message": "OK"}
    # Reply only after the lock has been released.
    await send_resp(msg, writer, "pairing_send", status, body)
|
|
|
|
|
|
async def handle_pairing_poll(msg: dict, writer: ProtocolWriter):
    """Poll a pairing session from the new device.

    Requires the code and the ``poll_token`` issued by handle_pairing_start.
    Replies ``{"ready": False}`` until a payload has been sent. The payload is
    then returned exactly once and the session deleted. After
    PAIRING_MAX_POLL_ATTEMPTS polls without a payload, the code is
    invalidated.

    Tokens are compared as UTF-8 bytes: ``secrets.compare_digest`` raises
    TypeError for str arguments containing non-ASCII characters, which any
    client could otherwise trigger.
    """
    await _cleanup_pairings()
    code = msg.get("code", "").strip()
    poll_token = msg.get("poll_token", "").strip()
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("pairing_poll", addr), 120):
        await send_resp(msg, writer, "pairing_poll", "error", {"message": "Too many attempts. Try later."})
        return
    if not code:
        await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing code"})
        return
    if not poll_token:
        await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing poll_token"})
        return

    error_msg = None
    ready = False
    payload = None
    async with _pairing_lock:
        p = pairing_sessions.get(code)
        if not p:
            error_msg = "Invalid or expired code"
        elif not secrets.compare_digest(p.get("poll_token", "").encode("utf-8"),
                                        poll_token.encode("utf-8")):
            error_msg = "Invalid poll_token"
        else:
            poll_attempts = p.get("poll_attempts", 0) + 1
            p["poll_attempts"] = poll_attempts
            if poll_attempts > PAIRING_MAX_POLL_ATTEMPTS and not p.get("payload"):
                pairing_sessions.pop(code, None)
                error_msg = "Code invalidated due to too many attempts"
            elif p.get("payload"):
                ready = True
                payload = p["payload"]
                pairing_sessions.pop(code, None)  # single delivery

    if error_msg:
        await send_resp(msg, writer, "pairing_poll", "error", {"message": error_msg})
    elif ready:
        await send_resp(msg, writer, "pairing_poll", "ok", {"ready": True, "payload": payload})
    else:
        await send_resp(msg, writer, "pairing_poll", "ok", {"ready": False})
|
|
|
|
|
|
async def _resolve_other_members(member_emails: list, self_id: str) -> list | None:
    """Map member emails to user rows other than *self_id*, creating phantoms as needed.

    Returns None, before creating anything, if an email with no account
    fails _valid_email (L8). Exact duplicate emails, and distinct emails
    resolving to the same user, collapse to a single entry.
    """
    unique_emails = list(dict.fromkeys(member_emails))
    found = {}
    for email in unique_emails:
        user = db.get_user_by_email(email)
        if not user and not _valid_email(email):
            return None
        found[email] = user

    others = []
    seen_ids = {self_id}
    for email in unique_emails:
        user = found[email]
        if not user:
            # Re-check: a phantom created earlier in this loop may already
            # cover this address (e.g. if lookups are case-insensitive).
            user = db.get_user_by_email(email)
        if not user:
            user = db.create_phantom_user(email)
            async with _clients_lock:
                phantom_user_ids.add(user["id"])
        if user["id"] not in seen_ids:
            seen_ids.add(user["id"])
            others.append(user)
    return others


async def handle_create_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Create a DM or a group conversation.

    ``msg["members"]`` is a list of member emails, and ``msg["name"]`` is
    optional. Exactly one other member and no name makes a DM: both users
    join immediately and the peer is sent ``conversation_created``. Anything
    else is a group: only the creator joins, every other member gets an
    invitation, and registered (non-phantom) invitees are sent
    ``group_invitation``.

    Unknown emails become phantom users only if they pass _valid_email (L8,
    which prevents phantom DB inflation). Otherwise the request is rejected
    before anything is created. Duplicate emails and the creator's own email
    are ignored, so no duplicate invitations are created.
    """
    member_emails = msg.get("members", [])
    name = msg.get("name")
    if not isinstance(member_emails, list) or not all(isinstance(e, str) for e in member_emails):
        await send_resp(msg, writer, "create_conversation", "error", {"message": "Invalid request"})
        return

    creator_id = session["user_id"]
    other_users = await _resolve_other_members(member_emails, creator_id)
    if other_users is None:
        await send_resp(msg, writer, "create_conversation", "error", {"message": "Invalid email"})
        return

    is_dm = len(other_users) == 1 and not name
    joined_at = datetime.now(timezone.utc)

    if is_dm:
        # DMs: add both members directly (no invitation).
        all_ids = [creator_id] + [u["id"] for u in other_users]
        conv_id = db.create_conversation(all_ids, joined_at=joined_at, name=name, created_by=creator_id)
        logger.info("DM conversation created.")
        await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id})
        members_info = db.get_conversation_members(conv_id)
        notif_data = {
            "conversation_id": conv_id,
            "name": name,
            "created_by": creator_id,
            "members": [{"user_id": m["id"], "username": m["username"], "email": m["email"]}
                        for m in members_info],
        }
        await _notify_users([u["id"] for u in other_users], "conversation_created", notif_data)
        return

    # Groups: only the creator joins; everyone else is invited.
    conv_id = db.create_conversation([creator_id], joined_at=joined_at, name=name, created_by=creator_id)
    logger.info("Group conversation created with invitations.")
    creator_user = db.get_user_by_id(creator_id)
    creator_name = creator_user["username"] if creator_user else "Unknown"
    async with _clients_lock:
        phantom_snapshot = set(phantom_user_ids)
    invited_ids = []
    for u in other_users:
        db.create_invitation(conv_id, u["id"], creator_id)
        if u["id"] not in phantom_snapshot:
            invited_ids.append(u["id"])  # only notify non-phantoms
    inv_notif = {
        "conversation_id": conv_id,
        "conversation_name": name,
        "invited_by": creator_id,
        "invited_by_username": creator_name,
    }
    await _notify_users(invited_ids, "group_invitation", inv_notif)
    await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id})
|
|
|
|
|
|
async def handle_find_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Look up the caller's direct conversation with the user at ``msg["email"]``.

    Replies with ``conversation_id`` (whatever db.find_direct_conversation
    returns — presumably None when no DM exists yet) and the peer's
    ``user_id``. An email with no account becomes a phantom user. It must
    first pass _valid_email (L8, which prevents phantom DB inflation).
    """
    email = msg.get("email", "").strip()
    if not email:
        await send_resp(msg, writer, "find_conversation", "error", {"message": "Invalid request"})
        return
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("find_conversation", addr, email), 30):
        await send_resp(msg, writer, "find_conversation", "error", {"message": "Too many attempts. Try later."})
        return

    other = db.get_user_by_email(email)
    if not other:
        if not _valid_email(email):
            await send_resp(msg, writer, "find_conversation", "error", {"message": "Invalid email"})
            return
        other = db.create_phantom_user(email)
        async with _clients_lock:
            phantom_user_ids.add(other["id"])

    conv_id = db.find_direct_conversation(session["user_id"], other["id"])
    await send_resp(msg, writer, "find_conversation", "ok", {
        "conversation_id": conv_id,
        "user_id": other["id"],
    })
|
|
|
|
|
|
async def handle_add_member(msg: dict, session: dict, writer: ProtocolWriter):
    """Invite a user, identified by email, into an existing conversation.

    The caller must already be a member. An unknown email gets a phantom user
    so the invitation can be claimed after registration. The invitee is not
    added directly: an invitation row is created, and non-phantom invitees get
    a ``group_invitation`` push.

    Errors: invalid request / conversation_id / email, rate limited, caller
    not a member, invitee already a member, invitation already pending.
    Success responds with ``{"user_id": <invitee id>}``.
    """
    conv_id = msg.get("conversation_id", "")
    email = msg.get("email", "")
    # Guard types before use: .strip() on a non-string would raise.
    email = email.strip() if isinstance(email, str) else ""
    if not isinstance(conv_id, str) or not conv_id or not email:
        await send_resp(msg, writer, "add_member", "error", {"message": "Invalid request"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "add_member", "error", {"message": "Invalid conversation_id"})
        return
    # L8: validate email format before phantom creation
    if len(email) > 254 or not re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", email):
        await send_resp(msg, writer, "add_member", "error", {"message": "Invalid email"})
        return
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("add_member", addr, email), 10):
        await send_resp(msg, writer, "add_member", "error", {"message": "Too many attempts. Try later."})
        return
    if not db.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "add_member", "error", {"message": "Not a member"})
        return
    user = db.get_user_by_email(email)
    if not user:
        # Create phantom for unregistered email (same as create_conversation)
        user = db.create_phantom_user(email)
        async with _clients_lock:
            phantom_user_ids.add(user["id"])
    if db.is_conversation_member(conv_id, user["id"]):
        await send_resp(msg, writer, "add_member", "error", {"message": "Already a member"})
        return
    if db.has_pending_invitation(conv_id, user["id"]):
        await send_resp(msg, writer, "add_member", "error", {"message": "Invitation already pending"})
        return
    # Create invitation (for both real and phantom users)
    db.create_invitation(conv_id, user["id"], session["user_id"])
    logger.info("Group invitation created.")
    await send_resp(msg, writer, "add_member", "ok", {"user_id": user["id"]})
    # Push invitation notification only to non-phantom users
    async with _clients_lock:
        is_phantom = user["id"] in phantom_user_ids
    if not is_phantom:
        conv = db.get_conversation(conv_id)
        creator_user = db.get_user_by_id(session["user_id"])
        creator_name = creator_user["username"] if creator_user else "Unknown"
        inv_notif = {
            "conversation_id": conv_id,
            "conversation_name": conv.get("name") if conv else None,
            "invited_by": session["user_id"],
            "invited_by_username": creator_name,
        }
        await _notify_users([user["id"]], "group_invitation", inv_notif)
|
|
|
|
|
|
async def handle_accept_invitation(msg: dict, session: dict, writer: ProtocolWriter):
    """Accept a group invitation — add user to conversation members.

    On success the caller is added with a UTC ``joined_at`` of now, the
    invitation row is deleted, the caller gets ``{"conversation_id": ...}``,
    and every other member is sent a ``member_added`` event describing the
    joiner (username/email fall back to "" if the user row is missing).
    """
    conversation_id = msg.get("conversation_id", "")

    async def _fail(reason: str) -> None:
        await send_resp(msg, writer, "accept_invitation", "error", {"message": reason})

    if not conversation_id:
        return await _fail("Missing conversation_id")
    if not _valid_uuid(conversation_id):
        return await _fail("Invalid conversation_id")
    me = session["user_id"]
    if not db.has_pending_invitation(conversation_id, me):
        return await _fail("No pending invitation")

    db.add_conversation_member(conversation_id, me, joined_at=datetime.now(timezone.utc))
    db.delete_invitation(conversation_id, me)
    logger.info("Invitation accepted.")
    await send_resp(msg, writer, "accept_invitation", "ok", {"conversation_id": conversation_id})

    # Tell everyone already in the conversation who just joined.
    joiner = db.get_user_by_id(me)
    payload = {
        "conversation_id": conversation_id,
        "user_id": me,
        "username": joiner["username"] if joiner else "",
        "email": joiner["email"] if joiner else "",
    }
    everyone_else = [m["id"] for m in db.get_conversation_members(conversation_id) if m["id"] != me]
    await _notify_users(everyone_else, "member_added", payload)
|
|
|
|
|
|
async def handle_decline_invitation(msg: dict, session: dict, writer: ProtocolWriter):
    """Decline a group invitation.

    Deletes the caller's pending invitation for ``conversation_id``. Responds
    with an error if the id is missing, malformed, or no invitation exists.
    """
    conv_id = msg.get("conversation_id", "")

    # Each guard is evaluated lazily and in order; the first failing one wins.
    guards = (
        (lambda: not conv_id, "Missing conversation_id"),
        (lambda: not _valid_uuid(conv_id), "Invalid conversation_id"),
        (lambda: not db.has_pending_invitation(conv_id, session["user_id"]), "No pending invitation"),
    )
    for failed, reason in guards:
        if failed():
            await send_resp(msg, writer, "decline_invitation", "error", {"message": reason})
            return

    db.delete_invitation(conv_id, session["user_id"])
    logger.info("Invitation declined.")
    await send_resp(msg, writer, "decline_invitation", "ok", {"message": "OK"})
|
|
|
|
|
|
async def handle_list_invitations(msg: dict, session: dict, writer: ProtocolWriter):
    """List pending group invitations for the current user.

    ``created_at`` is sent as ISO-8601 when the value supports ``isoformat``,
    otherwise as its ``str()`` form.
    """
    def _stamp(value):
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    invitations = [
        {
            "conversation_id": inv["conversation_id"],
            "conversation_name": inv.get("conversation_name"),
            "invited_by": inv["invited_by"],
            "invited_by_username": inv.get("invited_by_username", ""),
            "created_at": _stamp(inv["created_at"]),
        }
        for inv in db.get_pending_invitations(session["user_id"])
    ]
    await send_resp(msg, writer, "list_invitations", "ok", {"invitations": invitations})
|
|
|
|
|
|
async def handle_list_conversations(msg: dict, session: dict, writer: ProtocolWriter):
    """Send the caller a summary of every conversation they belong to.

    Each summary carries id, creation time (ISO-8601 when possible), members,
    name, creator, avatar file and the caller's unread count (0 if absent).
    """
    conversations = db.list_user_conversations(session["user_id"])
    unread_by_conv = db.get_unread_counts(session["user_id"])

    def _summarize(conv):
        created = conv["created_at"]
        return {
            "conversation_id": conv["id"],
            "created_at": created.isoformat() if hasattr(created, "isoformat") else str(created),
            "members": conv["members"],
            "name": conv.get("name"),
            "created_by": conv.get("created_by"),
            "avatar_file": conv.get("avatar_file"),
            "unread_count": unread_by_conv.get(conv["id"], 0),
        }

    summaries = list(map(_summarize, conversations))
    await send_resp(msg, writer, "list_conversations", "ok", {"conversations": summaries})
|
|
|
|
|
|
def _encode_header_field(raw):
    """Normalize a client-supplied header value for storage.

    dicts are JSON-encoded and strings UTF-8 encoded; any other value is
    returned unchanged (same conversion the handler previously inlined).
    """
    if isinstance(raw, dict):
        return json.dumps(raw).encode()
    if isinstance(raw, str):
        return raw.encode()
    return raw


async def handle_send_message(msg: dict, session: dict, writer: ProtocolWriter):
    """Store a ratchet-encrypted message and push it to connected recipients.

    Request fields: ``conversation_id``; ``ratchet_header`` (message-level
    fallback); ``recipients`` — a list of entries with ``user_id`` and base64
    ``encrypted_content``/``nonce``, plus optional ``device_id``,
    ``ratchet_header`` and ``x3dh_header``; optional ``x3dh_header``,
    ``sender_chain_id`` (base64), ``sender_chain_n`` and ``image_file_id``.

    Only recipient entries for current, non-phantom members that carry both a
    ciphertext and a nonce are stored. Those same entries, and no others, are
    pushed as ``new_message`` notifications. ``image_file_id``, when given, must
    name a completed upload by the sender in this same conversation.

    Responds with ``{"message_id": ...}`` on success.
    """
    from collections import defaultdict

    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "send_message", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "send_message", "error", {"message": "Invalid conversation_id"})
        return
    addr = _get_peer_addr(writer)
    if await _is_rate_limited(_rate_limit_key("send_message", addr, session.get("email")), 20):
        await send_resp(msg, writer, "send_message", "error", {"message": "Too many attempts. Try later."})
        return
    if not db.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "send_message", "error", {"message": "Not a member"})
        return

    # New protocol: ratchet_header + recipients[] with per-user ciphertext
    ratchet_header_raw = msg.get("ratchet_header")
    recipients_raw = msg.get("recipients")
    if not ratchet_header_raw or not recipients_raw:
        await send_resp(msg, writer, "send_message", "error", {"message": "Missing ratchet_header or recipients"})
        return
    if not isinstance(recipients_raw, list):
        await send_resp(msg, writer, "send_message", "error", {"message": "Invalid recipients"})
        return

    ratchet_header = _encode_header_field(ratchet_header_raw)
    x3dh_header_raw = msg.get("x3dh_header")
    x3dh_header = _encode_header_field(x3dh_header_raw) if x3dh_header_raw else None

    sender_chain_id_b64 = msg.get("sender_chain_id")
    sender_chain_id = decode_binary(sender_chain_id_b64) if sender_chain_id_b64 else None
    sender_chain_n = msg.get("sender_chain_n")

    # Validate recipients are actual members
    member_ids = {m["id"] for m in db.get_conversation_members(conv_id)}
    async with _clients_lock:
        phantom_snapshot = set(phantom_user_ids)

    db_recipients = []
    accepted = []  # raw client entries that passed validation — the only ones notified
    for r in recipients_raw:
        if not isinstance(r, dict):
            continue
        uid = r.get("user_id", "")
        # Non-members must never receive pushes chosen by the sender; phantoms
        # have no account to deliver to.
        if not isinstance(uid, str) or uid not in member_ids or uid in phantom_snapshot:
            continue
        ct_b64 = r.get("encrypted_content", "")
        nonce_b64 = r.get("nonce", "")
        if not ct_b64 or not nonce_b64:
            continue
        entry = {
            "user_id": uid,
            "encrypted_content": decode_binary(ct_b64),
            "nonce": decode_binary(nonce_b64),
        }
        # Per-recipient device_id (multi-device support)
        r_device_id = r.get("device_id")
        if r_device_id:
            entry["device_id"] = r_device_id
        # Per-recipient ratchet header and x3dh header
        r_rh = r.get("ratchet_header")
        if r_rh:
            entry["ratchet_header"] = _encode_header_field(r_rh)
        r_x3dh = r.get("x3dh_header")
        if r_x3dh:
            entry["x3dh_header"] = _encode_header_field(r_x3dh)
        db_recipients.append(entry)
        accepted.append(r)
    if not db_recipients:
        await send_resp(msg, writer, "send_message", "error", {"message": "No valid recipients"})
        return

    # Validate the attachment BEFORE storing: deleting a message unlinks its
    # image file, so a message must never reference someone else's upload.
    image_file_id = msg.get("image_file_id")
    if image_file_id:
        upload = None
        if isinstance(image_file_id, str) and _valid_uuid(image_file_id):
            upload = db.get_image_upload(image_file_id)
        if not (
            upload
            and upload["completed"]
            and upload["uploader_id"] == session["user_id"]
            and str(upload["conversation_id"]).lower() == conv_id.lower()
        ):
            await send_resp(msg, writer, "send_message", "error", {"message": "Invalid image_file_id"})
            return

    msg_id = db.store_message(
        conv_id, session["user_id"], ratchet_header, db_recipients,
        x3dh_header=x3dh_header,
        sender_chain_id=sender_chain_id,
        sender_chain_n=sender_chain_n,
        image_file_id=image_file_id,
        sender_device_id=session.get("device_id"),
    )
    # Link image upload to message if present (already validated above)
    if image_file_id:
        db.set_message_image_file_id(msg_id, image_file_id)

    logger.info("Message stored.")
    await send_resp(msg, writer, "send_message", "ok", {"message_id": msg_id})

    # Notify connected recipients — group all validated per-device entries by user_id
    user_entries = defaultdict(list)
    for r in accepted:
        user_entries[r["user_id"]].append({
            "device_id": r.get("device_id", db.SELF_DEVICE_ID),
            "encrypted_content": r.get("encrypted_content", ""),
            "nonce": r.get("nonce", ""),
            "ratchet_header": r.get("ratchet_header") or ratchet_header_raw,
            "x3dh_header": r.get("x3dh_header") or x3dh_header_raw,
        })

    notifications = []
    for uid, entries in user_entries.items():
        notif_data = {
            "message_id": msg_id,
            "conversation_id": conv_id,
            "sender_id": session["user_id"],
            "sender_device_id": session.get("device_id"),
            "device_entries": entries,
        }
        if sender_chain_id_b64:
            notif_data["sender_chain_id"] = sender_chain_id_b64
        if sender_chain_n is not None:
            notif_data["sender_chain_n"] = sender_chain_n
        # Also include flat fields for backward compat with old clients
        # (first entry's data as fallback)
        if entries:
            first = entries[0]
            notif_data["ratchet_header"] = first.get("ratchet_header") or ratchet_header_raw
            notif_data["encrypted_content"] = first.get("encrypted_content", "")
            notif_data["nonce"] = first.get("nonce", "")
            if first.get("x3dh_header"):
                notif_data["x3dh_header"] = first["x3dh_header"]
        notifications.append((uid, "new_message", notif_data))
    await _notify_users_individual(notifications, exclude_writer=writer)
|
|
|
|
|
|
async def handle_get_messages(msg: dict, session: dict, writer: ProtocolWriter):
    """Return one page of stored messages in a conversation for the caller.

    ``limit`` is clamped to [1, 200] (default 50) and ``offset`` to >= 0
    (default 0). Non-numeric or non-finite values get an error response
    instead of raising. Per-recipient headers (``mr_*``) take precedence over
    message-level headers. Each entry includes ``read_by`` from
    ``db.get_message_read_status``.
    """
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "get_messages", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid conversation_id"})
        return
    if not db.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Not a member"})
        return

    # Client-controlled values: "abc" -> ValueError, null/list -> TypeError,
    # 1e999 (JSON decodes it as inf) -> OverflowError.
    try:
        limit = min(max(int(msg.get("limit", 50)), 1), 200)
        offset = max(int(msg.get("offset", 0)), 0)
    except (TypeError, ValueError, OverflowError):
        await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid limit or offset"})
        return

    def _iso(value):
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    device_id = session.get("device_id")
    messages = db.get_messages(conv_id, session["user_id"], limit, offset, device_id=device_id)

    message_ids = [m["id"] for m in messages]
    read_status = db.get_message_read_status(message_ids) if message_ids else {}
    result = []
    for m in messages:
        # Prefer per-recipient headers (mr_*) over message-level headers
        rh_raw = m.get("mr_ratchet_header") or m.get("ratchet_header")
        x3dh_raw = m.get("mr_x3dh_header") or m.get("x3dh_header")
        entry = {
            "message_id": m["id"],
            "sender_id": m.get("sender_id") or "",
            "ratchet_header": json.loads(rh_raw) if rh_raw else {},
            "encrypted_content": encode_binary(m["encrypted_content"]) if m.get("encrypted_content") else "",
            "nonce": encode_binary(m["nonce"]) if m.get("nonce") else "",
            "created_at": _iso(m["created_at"]),
            "read_by": read_status.get(m["id"], []),
        }
        if x3dh_raw:
            entry["x3dh_header"] = json.loads(x3dh_raw)
        if m.get("sender_chain_id"):
            entry["sender_chain_id"] = encode_binary(m["sender_chain_id"])
        if m.get("sender_chain_n") is not None:
            entry["sender_chain_n"] = m["sender_chain_n"]
        if m.get("sender_device_id"):
            entry["sender_device_id"] = m["sender_device_id"]
        if m.get("deleted_at"):
            entry["deleted_at"] = _iso(m["deleted_at"])
        result.append(entry)
    await send_resp(msg, writer, "get_messages", "ok", {"messages": result})
|
|
|
|
|
|
async def handle_remove_member(msg: dict, session: dict, writer: ProtocolWriter):
    """Remove ``user_id`` from a conversation; only the conversation creator may do this.

    The creator cannot remove themselves (see leave_group). Removal uses the
    atomic DB helper, whose falsy return means the row was already gone. Every
    pre-removal member except the caller — including the removed user — gets
    a ``member_removed`` event.
    """
    conv_id = msg.get("conversation_id", "")
    user_id = msg.get("user_id", "")
    if not conv_id or not user_id:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Missing conversation_id or user_id"})
        return
    if not _valid_uuid(conv_id) or not _valid_uuid(user_id):
        await send_resp(msg, writer, "remove_member", "error", {"message": "Invalid conversation_id or user_id"})
        return
    if not db.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "remove_member", "error", {"message": "Not a member"})
        return
    # Direct lookup (as leave_group/rename do) instead of loading and scanning
    # every conversation the caller belongs to.
    conv = db.get_conversation(conv_id)
    if not conv or conv.get("created_by") != session["user_id"]:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Only the group creator can remove members"})
        return
    if user_id == session["user_id"]:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Cannot remove yourself"})
        return
    # Get remaining members before removing (to notify them)
    members_before = db.get_conversation_members(conv_id)
    # M6: atomic removal — return value confirms row existed
    removed = db.remove_conversation_member_atomic(conv_id, user_id)
    if not removed:
        await send_resp(msg, writer, "remove_member", "error", {"message": "Member already removed"})
        return
    logger.info("Conversation member removed.")
    await send_resp(msg, writer, "remove_member", "ok", {"message": "OK"})

    # Notify removed member and remaining members
    notif_data = {
        "conversation_id": conv_id,
        "user_id": user_id,
    }
    member_ids = [m["id"] for m in members_before if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "member_removed", notif_data)
|
|
|
|
|
|
async def handle_leave_group(msg: dict, session: dict, writer: ProtocolWriter):
    """Leave a group conversation voluntarily.

    DMs (at most 2 members and no name) cannot be left. The caller is removed
    atomically. If that removal reports nothing was removed (e.g. a concurrent
    removal), an error is returned and nothing else changes. If the caller was
    the creator, ownership passes to the first remaining member. Remaining
    members receive ``member_removed``.

    NOTE(review): if the caller was the last member, the conversation row is
    left in place. In the handlers visible here, only delete_conversation
    removes empty conversations.
    """
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "leave_group", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "leave_group", "error", {"message": "Invalid conversation_id"})
        return
    me = session["user_id"]
    if not db.is_conversation_member(conv_id, me):
        await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"})
        return
    # Don't allow leaving DMs (2 members without a name)
    conv = db.get_conversation(conv_id)
    members = db.get_conversation_members(conv_id)
    if len(members) <= 2 and not (conv and conv.get("name")):
        await send_resp(msg, writer, "leave_group", "error", {"message": "Cannot leave a DM conversation"})
        return
    # M6: atomic removal — check the result (as remove_member does) so a lost
    # race does not report success or hand ownership away.
    if not db.remove_conversation_member_atomic(conv_id, me):
        await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"})
        return
    remaining = [m for m in members if m["id"] != me]
    # If creator is leaving, transfer to first remaining member
    if conv and conv.get("created_by") == me and remaining:
        db.update_conversation_creator(conv_id, remaining[0]["id"])
    logger.info("User left group.")
    await send_resp(msg, writer, "leave_group", "ok", {"message": "OK"})

    # Notify remaining members
    notif_data = {
        "conversation_id": conv_id,
        "user_id": me,
    }
    await _notify_users([m["id"] for m in remaining], "member_removed", notif_data)
|
|
|
|
|
|
async def handle_rename_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Rename a group conversation (creator only).

    The new name is whitespace-stripped and limited to 100 characters.
    Conversations without a name are treated as DMs and cannot be renamed.
    Other members receive a ``conversation_renamed`` event.
    """
    conversation_id = msg.get("conversation_id", "")
    proposed = msg.get("name", "").strip()

    async def reject(text: str) -> None:
        await send_resp(msg, writer, "rename_conversation", "error", {"message": text})

    if not (conversation_id and proposed):
        await reject("Missing conversation_id or name")
        return
    if not _valid_uuid(conversation_id):
        await reject("Invalid conversation_id")
        return
    if len(proposed) > 100:
        await reject("Name too long (max 100)")
        return

    requester = session["user_id"]
    if not db.is_conversation_member(conversation_id, requester):
        await reject("Not a member")
        return
    conv = db.get_conversation(conversation_id)
    if not conv or not conv.get("name"):
        await reject("Cannot rename a DM conversation")
        return
    if conv.get("created_by") != requester:
        await reject("Only the group creator can rename")
        return

    db.update_conversation_name(conversation_id, proposed)
    logger.info("Group renamed: %s", conversation_id)
    await send_resp(msg, writer, "rename_conversation", "ok", {"message": "OK"})

    # Everyone except the renamer learns the new name.
    others = [m["id"] for m in db.get_conversation_members(conversation_id) if m["id"] != requester]
    await _notify_users(others, "conversation_renamed", {
        "conversation_id": conversation_id,
        "name": proposed,
        "renamed_by": requester,
    })
|
|
|
|
|
|
async def handle_delete_conversation(msg: dict, session: dict, writer: ProtocolWriter):
    """Delete a conversation for the current user. Removes user from members,
    deletes the conversation if no members remain.

    A conversation is a group when it has more than 2 members or has a name.
    Groups may only be deleted by their creator, and doing so removes every
    member. For a DM only the caller is removed. When no members remain,
    in-progress uploads for the conversation are abandoned, its upload files
    are removed from disk (best effort, failures logged) and the conversation
    row is deleted. Other former members receive ``member_removed`` carrying
    the caller's user_id.
    """
    conv_id = msg.get("conversation_id", "")
    if not conv_id:
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Missing conversation_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Invalid conversation_id"})
        return
    if not db.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Not a member"})
        return
    conv = db.get_conversation(conv_id)
    members = db.get_conversation_members(conv_id)
    is_group = len(members) > 2 or (conv and conv.get("name"))
    # Groups can only be deleted by the creator (admin)
    if is_group and (not conv or conv.get("created_by") != session["user_id"]):
        await send_resp(msg, writer, "delete_conversation", "error", {"message": "Only the group creator can delete this conversation"})
        return
    if is_group:
        # Group: creator deletes for everyone — remove all members, clean up, delete
        for member in members:
            db.remove_conversation_member(conv_id, member["id"])
    else:
        # DM: only remove self; other user keeps the conversation
        db.remove_conversation_member(conv_id, session["user_id"])

    if db.count_conversation_members(conv_id) == 0:
        # Abandon in-progress uploads first, so no later chunk/end request can
        # recreate or finalize a file for a conversation that no longer exists.
        async with _uploads_lock:
            stale_ids = [fid for fid, up in pending_uploads.items() if up.get("conv_id") == conv_id]
            abandoned = [pending_uploads.pop(fid) for fid in stale_ids]
        doomed = [Path(up["temp_path"]) for up in abandoned]
        # Clean up uploaded files from disk
        for fid in db.get_conversation_file_ids(conv_id):
            for ext in (".enc", ".tmp"):
                p = _safe_upload_path(fid, ext)
                if p:
                    doomed.append(p)
        for p in doomed:
            try:
                p.unlink(missing_ok=True)
            except OSError:
                # Best effort: a leftover file must not block deletion, but don't hide it.
                logger.warning("Could not remove upload file %s", p.name, exc_info=True)
        db.delete_conversation(conv_id)
    logger.info("Conversation deleted for user.")
    await send_resp(msg, writer, "delete_conversation", "ok", {"message": "OK"})

    # Notify other members they were removed
    notif_data = {
        "conversation_id": conv_id,
        "user_id": session["user_id"],
    }
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "member_removed", notif_data)
|
|
|
|
|
|
async def handle_mark_read(msg: dict, session: dict, writer: ProtocolWriter):
    """Mark messages in a conversation as read by the caller and broadcast it.

    ``message_ids`` must be a list of at most 500 message UUIDs. Message IDs
    are UUIDs, as delete_message also validates. The list is validated before
    it reaches the DB or is echoed to other members in ``messages_read``.
    """
    conv_id = msg.get("conversation_id", "")
    message_ids = msg.get("message_ids", [])
    if not conv_id or not message_ids:
        await send_resp(msg, writer, "mark_read", "error", {"message": "Missing conversation_id or message_ids"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid conversation_id"})
        return
    # A bare string or number would otherwise slip through (len() of a string
    # works) and be forwarded verbatim to every member.
    if not isinstance(message_ids, list):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid message_ids"})
        return
    if len(message_ids) > 500:
        await send_resp(msg, writer, "mark_read", "error", {"message": "Too many message_ids (max 500)"})
        return
    if not all(isinstance(mid, str) and _valid_uuid(mid) for mid in message_ids):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid message_ids"})
        return
    if not db.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "mark_read", "error", {"message": "Not a member"})
        return
    db.mark_messages_read(conv_id, session["user_id"], message_ids)
    await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"})
    members = db.get_conversation_members(conv_id)
    notif_data = {
        "conversation_id": conv_id,
        "user_id": session["user_id"],
        "message_ids": message_ids,
    }
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "messages_read", notif_data)
|
|
|
|
|
|
async def handle_delete_message(msg: dict, session: dict, writer: ProtocolWriter):
    """Soft-delete a message and remove its attached image, if any.

    The caller must be a member of the message's conversation.
    ``db.soft_delete_message`` returns None when the caller may not delete the
    message. On success the attachment's ``.enc`` file is unlinked (best
    effort, failures logged), its upload row deleted, and other members
    receive ``message_deleted``.
    """
    message_id = msg.get("message_id", "")
    if not message_id:
        await send_resp(msg, writer, "delete_message", "error", {"message": "Missing message_id"})
        return
    if not _valid_uuid(message_id):
        await send_resp(msg, writer, "delete_message", "error", {"message": "Invalid message_id"})
        return
    conv_id = db.get_message_conversation(message_id)
    if not conv_id:
        await send_resp(msg, writer, "delete_message", "error", {"message": "Message not found"})
        return
    if not db.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "delete_message", "error", {"message": "Not a member"})
        return
    result = db.soft_delete_message(message_id, session["user_id"])
    if result is None:
        await send_resp(msg, writer, "delete_message", "error", {"message": "Cannot delete this message"})
        return
    image_file_id = result.get("image_file_id")
    if image_file_id:
        image_path = _safe_upload_path(image_file_id, ".enc")
        if image_path:
            try:
                image_path.unlink(missing_ok=True)
            except OSError:
                # Best effort: the message is already deleted; log instead of failing.
                logger.warning("Could not remove image file %s", image_path.name, exc_info=True)
        db.delete_image_upload(image_file_id)
    logger.info("Message deleted.")
    await send_resp(msg, writer, "delete_message", "ok", {"message_id": message_id})
    members = db.get_conversation_members(conv_id)
    notif_data = {"message_id": message_id, "conversation_id": conv_id}
    member_ids = [m["id"] for m in members if m["id"] != session["user_id"]]
    await _notify_users(member_ids, "message_deleted", notif_data)
|
|
|
|
|
|
async def handle_upload_image_start(msg: dict, session: dict, writer: ProtocolWriter):
    """Begin a chunked upload of an encrypted image or file into a conversation.

    Request fields: ``conversation_id``; ``file_id`` (client-chosen UUID);
    ``file_size`` (declared total bytes, non-negative int, default 0); and
    ``file_type``. ``"file"`` selects MAX_FILE_BYTES; anything else uses
    MAX_IMAGE_BYTES. A limit <= 0 disables the size check.

    ``file_id`` must not belong to an existing or in-progress upload:
    otherwise the later ``upload_image_end`` rename could overwrite a file
    someone else already stored under that id.
    """
    conv_id = msg.get("conversation_id", "")
    file_size = msg.get("file_size", 0)
    file_id = msg.get("file_id", "")
    file_type = msg.get("file_type", "image")  # "image" or "file"
    if not conv_id or not file_id:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Missing fields"})
        return
    if not _valid_uuid(file_id):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid conversation_id"})
        return
    # bool is an int subclass; a string would make the comparison below raise.
    if isinstance(file_size, bool) or not isinstance(file_size, int) or file_size < 0:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_size"})
        return
    if not db.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Not a member"})
        return
    max_bytes = MAX_FILE_BYTES if file_type == "file" else MAX_IMAGE_BYTES
    if max_bytes > 0 and file_size > max_bytes:
        await send_resp(msg, writer, "upload_image_start", "error",
                        {"message": f"File too large (max {max_bytes} bytes)"})
        return
    async with _uploads_lock:
        in_progress = file_id in pending_uploads
    if in_progress or db.get_image_upload(file_id):
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "file_id already in use"})
        return
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    temp_path = _safe_upload_path(file_id, ".tmp")
    if not temp_path:
        await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"})
        return
    # Persist first: if the insert fails (e.g. a racing start for the same id,
    # assuming the table enforces uniqueness), nothing has been registered in
    # memory or written to disk yet.
    db.create_image_upload(file_id, conv_id, session["user_id"], file_size)
    temp_path.write_bytes(b"")
    async with _uploads_lock:
        pending_uploads[file_id] = {
            "temp_path": str(temp_path),
            "received_bytes": 0,
            "file_size": file_size,
            "max_bytes": max_bytes,
            "conv_id": conv_id,
            "uploader_id": session["user_id"],
        }
    logger.info("Image upload started: %s", file_id)
    await send_resp(msg, writer, "upload_image_start", "ok", {"file_id": file_id})
|
|
|
|
|
|
async def handle_upload_image_chunk(msg: dict, session: dict, writer: ProtocolWriter):
    """Append one base64 chunk to the caller's in-progress upload.

    The chunk is written to the upload's temp file off the event loop. The
    byte counter is then updated under ``_uploads_lock``. Exceeding the
    upload's ``max_bytes`` (when > 0) aborts the upload and deletes the temp
    file. If the upload was finished or aborted by a concurrent request while
    this chunk was being written, an error is returned rather than raising.
    Success responds with ``{"received": <total bytes so far>}``.
    """
    file_id = msg.get("file_id", "")
    chunk_data = msg.get("data", "")
    if not file_id or not chunk_data:
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Missing fields"})
        return

    async with _uploads_lock:
        upload = pending_uploads.get(file_id)
        if upload and upload["uploader_id"] == session["user_id"]:
            temp_path = Path(upload["temp_path"])
            upload_max = upload.get("max_bytes", 0)
        else:
            upload = None
    if upload is None:
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"})
        return

    raw = decode_binary(chunk_data)
    await asyncio.to_thread(_append_file, temp_path, raw)

    received = None
    over_limit = False
    async with _uploads_lock:
        upload = pending_uploads.get(file_id)
        if upload is not None:
            upload["received_bytes"] += len(raw)
            received = upload["received_bytes"]
            if upload_max > 0 and received > upload_max:
                pending_uploads.pop(file_id, None)
                over_limit = True

    if over_limit:
        temp_path.unlink(missing_ok=True)
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Upload exceeds size limit"})
        return
    if received is None:
        # The upload vanished mid-write (completed or aborted concurrently); our
        # append may have recreated a stray temp file, so drop it.
        temp_path.unlink(missing_ok=True)
        await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"})
        return
    await send_resp(msg, writer, "upload_image_chunk", "ok", {"received": received})
|
|
|
|
|
|
async def handle_upload_image_end(msg: dict, session: dict, writer: ProtocolWriter):
    """Finish the caller's chunked upload and publish the encrypted file.

    Only the uploader can finish (and thereby remove from ``pending_uploads``)
    an upload. If the received byte count differs from the declared
    ``file_size``, the temp file is deleted and an error returned. Otherwise
    the ``.tmp`` file is moved to its ``.enc`` path and the DB row is marked
    completed.
    """
    file_id = msg.get("file_id", "")
    if not file_id:
        await send_resp(msg, writer, "upload_image_end", "error", {"message": "Missing file_id"})
        return
    async with _uploads_lock:
        upload = pending_uploads.get(file_id)
        # Check ownership BEFORE popping: otherwise any user who knows a file_id
        # could silently cancel someone else's in-progress upload.
        if upload is not None and upload["uploader_id"] == session["user_id"]:
            pending_uploads.pop(file_id, None)
        else:
            upload = None
    if upload is None:
        await send_resp(msg, writer, "upload_image_end", "error", {"message": "No active upload"})
        return
    temp_path = Path(upload["temp_path"])
    if upload["received_bytes"] != upload["file_size"]:
        temp_path.unlink(missing_ok=True)
        await send_resp(msg, writer, "upload_image_end", "error",
                        {"message": f"Incomplete upload: received {upload['received_bytes']} of {upload['file_size']} bytes"})
        return
    final_path = _safe_upload_path(file_id, ".enc")
    if not final_path:
        temp_path.unlink(missing_ok=True)
        await send_resp(msg, writer, "upload_image_end", "error", {"message": "Invalid file_id"})
        return

    def _move_file():
        try:
            temp_path.rename(final_path)
        except OSError:
            # e.g. cross-device rename; shutil.move falls back to copy+delete.
            import shutil
            shutil.move(str(temp_path), str(final_path))

    await asyncio.to_thread(_move_file)
    db.complete_image_upload(file_id)
    logger.info("Image upload completed: %s (%d bytes)", file_id, upload["received_bytes"])
    await send_resp(msg, writer, "upload_image_end", "ok", {"file_id": file_id})
|
|
|
|
|
|
async def handle_download_image(msg: dict, session: dict, writer: ProtocolWriter):
    """Serve one chunk of a completed upload to a member of its conversation.

    Reads up to IMAGE_CHUNK_SIZE bytes starting at ``offset`` (non-negative
    int, default 0) via ``_read_file_chunk``. ``done`` is true once
    ``offset + len(chunk)`` reaches the file size. ``total_size`` lets the
    client drive further requests.
    """
    file_id = msg.get("file_id", "")
    offset = msg.get("offset", 0)
    if not file_id:
        await send_resp(msg, writer, "download_image", "error", {"message": "Missing file_id"})
        return
    if not _valid_uuid(file_id):
        await send_resp(msg, writer, "download_image", "error", {"message": "Invalid file_id"})
        return
    # Client-controlled: a negative or non-integer offset would reach file I/O
    # and the arithmetic below.
    if isinstance(offset, bool) or not isinstance(offset, int) or offset < 0:
        await send_resp(msg, writer, "download_image", "error", {"message": "Invalid offset"})
        return
    upload = db.get_image_upload(file_id)
    if not upload or not upload["completed"]:
        await send_resp(msg, writer, "download_image", "error", {"message": "File not found"})
        return
    if not db.is_conversation_member(upload["conversation_id"], session["user_id"]):
        await send_resp(msg, writer, "download_image", "error", {"message": "Not a member"})
        return
    file_path = _safe_upload_path(file_id, ".enc")
    if not file_path or not file_path.exists():
        await send_resp(msg, writer, "download_image", "error", {"message": "File not found"})
        return
    file_size = file_path.stat().st_size
    chunk = await asyncio.to_thread(_read_file_chunk, file_path, offset, IMAGE_CHUNK_SIZE)
    done = (offset + len(chunk)) >= file_size
    await send_resp(msg, writer, "download_image", "ok", {
        "file_id": file_id,
        "data": encode_binary(chunk),
        "offset": offset,
        "done": done,
        "total_size": file_size,
    })
|
|
|
|
|
|
MAX_AVATAR_BYTES = 2 << 20  # 2 MiB: upper bound on a decoded avatar upload
|
|
|
|
|
|
async def handle_get_profile(msg: dict, session: dict, writer: ProtocolWriter):
    """Get user profile (respects visibility for other users).

    With no ``user_id`` the caller's own profile is returned. Visibility
    filtering is delegated to ``db.get_user_profile`` via ``viewer_id``.
    ``created_at``/``updated_at`` are sent as ISO-8601 strings when they are
    datetime-like.
    """
    requested = msg.get("user_id", "").strip()
    if requested and not _valid_uuid(requested):
        await send_resp(msg, writer, "get_profile", "error", {"message": "Invalid user_id"})
        return
    target = requested or session["user_id"]

    profile = db.get_user_profile(target, viewer_id=session["user_id"])
    if not profile:
        await send_resp(msg, writer, "get_profile", "error", {"message": "User not found"})
        return

    # Serialize datetime fields
    profile.update({
        field: profile[field].isoformat()
        for field in ("created_at", "updated_at")
        if profile.get(field) and hasattr(profile[field], "isoformat")
    })
    await send_resp(msg, writer, "get_profile", "ok", profile)
|
|
|
|
|
|
async def handle_update_profile(msg: dict, session: dict, writer: ProtocolWriter):
    """Update own profile fields.

    Only the whitelisted keys present in the request are forwarded to
    ``db.update_user_profile``. Values are passed through unvalidated.
    An error is returned when none of them is present.
    """
    editable = ("phone", "phone_visible", "email_visible", "location", "location_visible")
    changes = {key: msg[key] for key in editable if key in msg}
    if not changes:
        await send_resp(msg, writer, "update_profile", "error", {"message": "No fields to update"})
        return
    db.update_user_profile(session["user_id"], **changes)
    await send_resp(msg, writer, "update_profile", "ok", {"message": "OK"})
|
|
|
|
|
|
async def handle_update_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Store the caller's avatar from a single base64 ``data`` field (max 2 MB).

    The image is written to ``UPLOAD_DIR/avatars/<user_id>.png`` when the
    payload starts with the PNG signature, otherwise to ``<user_id>.jpg``.
    The filename is then recorded on the user's profile. Any previous avatar
    stored under the other extension is deleted (best effort), so switching
    formats does not leave an orphaned file on disk.
    """
    avatar_b64 = msg.get("data", "")
    if not avatar_b64:
        await send_resp(msg, writer, "update_avatar", "error", {"message": "Missing data"})
        return

    avatar_data = decode_binary(avatar_b64)
    if not avatar_data:
        # Non-empty input that decodes to nothing would otherwise be stored as an empty avatar.
        await send_resp(msg, writer, "update_avatar", "error", {"message": "Missing data"})
        return
    if len(avatar_data) > MAX_AVATAR_BYTES:
        await send_resp(msg, writer, "update_avatar", "error",
                        {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"})
        return

    # Detect format from magic bytes: only PNG is recognized; everything else is stored as .jpg.
    ext = "png" if avatar_data[:8] == b'\x89PNG\r\n\x1a\n' else "jpg"

    avatar_dir = UPLOAD_DIR / "avatars"
    avatar_dir.mkdir(parents=True, exist_ok=True)

    user_id = session["user_id"]
    filename = f"{user_id}.{ext}"
    avatar_path = _safe_avatar_path(filename)
    if not avatar_path:
        await send_resp(msg, writer, "update_avatar", "error", {"message": "Invalid path"})
        return

    await asyncio.to_thread(avatar_path.write_bytes, avatar_data)
    db.update_user_profile(user_id, avatar_file=filename)

    # The filename depends on the detected format, so a PNG->JPEG (or reverse)
    # change would otherwise leave the old file behind with nothing pointing to it.
    stale_ext = "jpg" if ext == "png" else "png"
    stale_path = _safe_avatar_path(f"{user_id}.{stale_ext}")
    if stale_path:
        try:
            await asyncio.to_thread(stale_path.unlink, missing_ok=True)
        except OSError as e:
            logger.warning("Could not remove old avatar file %s: %s", stale_path.name, e)

    logger.info("Avatar updated for user %s", user_id)
    await send_resp(msg, writer, "update_avatar", "ok", {"avatar_file": filename})
|
|
|
|
|
|
async def handle_get_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Send a user's avatar image to the caller.

    ``user_id`` must be present and a valid UUID. The avatar filename comes
    from the target's profile; the file contents are returned base64-encoded
    along with that filename, or an error message is sent.
    """
    async def reply_error(text: str):
        await send_resp(msg, writer, "get_avatar", "error", {"message": text})

    user_id = msg.get("user_id", "").strip()
    if not user_id:
        await reply_error("Missing user_id")
        return
    if not _valid_uuid(user_id):
        await reply_error("Invalid user_id")
        return

    profile = db.get_user_profile(user_id)
    stored_name = profile.get("avatar_file") if profile else None
    if not stored_name:
        await reply_error("No avatar")
        return

    path = _safe_avatar_path(stored_name)
    if not path or not path.exists():
        await reply_error("Avatar file not found")
        return

    raw = await asyncio.to_thread(path.read_bytes)
    payload = {
        "user_id": user_id,
        "data": encode_binary(raw),
        "filename": stored_name,
    }
    await send_resp(msg, writer, "get_avatar", "ok", payload)
|
|
|
|
|
|
async def handle_update_group_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Upload avatar for a group conversation (base64, max 2MB). Only members can set it.

    The image is written to ``UPLOAD_DIR/avatars/group_<conversation_id>.png``
    when it starts with the PNG signature, otherwise to ``.jpg``. The filename
    is then recorded on the conversation. Any previous avatar stored under the
    other extension is deleted (best effort), so switching formats does not
    leave an orphaned file on disk.
    """
    conv_id = msg.get("conversation_id", "").strip()
    avatar_b64 = msg.get("data", "")
    if not conv_id or not avatar_b64:
        await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Missing fields"})
        return
    if not _valid_uuid(conv_id):
        await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid conversation_id"})
        return
    if not db.is_conversation_member(conv_id, session["user_id"]):
        await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Not a member"})
        return

    avatar_data = decode_binary(avatar_b64)
    if not avatar_data:
        # Non-empty input that decodes to nothing would otherwise be stored as an empty avatar.
        await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Missing fields"})
        return
    if len(avatar_data) > MAX_AVATAR_BYTES:
        await send_resp(msg, writer, "update_group_avatar", "error",
                        {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"})
        return

    # Only PNG is recognized by signature; everything else is stored as .jpg.
    ext = "png" if avatar_data[:8] == b'\x89PNG\r\n\x1a\n' else "jpg"

    avatar_dir = UPLOAD_DIR / "avatars"
    avatar_dir.mkdir(parents=True, exist_ok=True)

    filename = f"group_{conv_id}.{ext}"
    avatar_path = _safe_avatar_path(filename)
    if not avatar_path:
        await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid path"})
        return

    await asyncio.to_thread(avatar_path.write_bytes, avatar_data)
    db.update_conversation_avatar(conv_id, filename)

    # Remove the avatar saved under the other extension, if any, so a format
    # change does not leave an unreferenced file behind.
    stale_ext = "jpg" if ext == "png" else "png"
    stale_path = _safe_avatar_path(f"group_{conv_id}.{stale_ext}")
    if stale_path:
        try:
            await asyncio.to_thread(stale_path.unlink, missing_ok=True)
        except OSError as e:
            logger.warning("Could not remove old group avatar file %s: %s", stale_path.name, e)

    logger.info("Group avatar updated for conversation %s", conv_id)
    await send_resp(msg, writer, "update_group_avatar", "ok", {"avatar_file": filename})
|
|
|
|
|
|
async def handle_get_group_avatar(msg: dict, session: dict, writer: ProtocolWriter):
    """Send a group conversation's avatar image to one of its members.

    ``conversation_id`` must be a valid UUID of a conversation the caller
    belongs to. Responds with the base64-encoded file contents and the stored
    filename, or with an error message.
    """
    action = "get_group_avatar"

    async def reject(reason: str):
        await send_resp(msg, writer, action, "error", {"message": reason})

    conversation_id = msg.get("conversation_id", "").strip()
    if not conversation_id:
        await reject("Missing conversation_id")
        return
    if not _valid_uuid(conversation_id):
        await reject("Invalid conversation_id")
        return
    if not db.is_conversation_member(conversation_id, session["user_id"]):
        await reject("Not a member")
        return

    conversation = db.get_conversation(conversation_id)
    stored_name = conversation.get("avatar_file") if conversation else None
    if not stored_name:
        await reject("No avatar")
        return

    file_path = _safe_avatar_path(stored_name)
    if not file_path or not file_path.exists():
        await reject("Avatar file not found")
        return

    contents = await asyncio.to_thread(file_path.read_bytes)
    await send_resp(msg, writer, action, "ok", {
        "conversation_id": conversation_id,
        "data": encode_binary(contents),
        "filename": stored_name,
    })
|
|
|
|
|
|
async def handle_list_devices(msg: dict, session: dict, writer: ProtocolWriter):
    """Send the caller the list of devices registered to their account.

    Each entry carries the device id and name, ``created_at`` and
    ``last_seen_at`` rendered as text, and ``is_current``, which marks the
    device of this session. ``last_seen_at`` is None when the device has no
    recorded value.
    """
    def as_text(value):
        # Datetimes become ISO-8601; anything else falls back to str().
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    current_device = session.get("device_id")
    devices = [
        {
            "device_id": dev["id"],
            "device_name": dev.get("device_name"),
            "created_at": as_text(dev["created_at"]),
            "last_seen_at": as_text(dev["last_seen_at"]) if dev.get("last_seen_at") else None,
            "is_current": dev["id"] == current_device,
        }
        for dev in db.get_user_devices(session["user_id"])
    ]
    await send_resp(msg, writer, "list_devices", "ok", {"devices": devices})
|
|
|
|
|
|
async def handle_remove_device(msg: dict, session: dict, writer: ProtocolWriter):
    """Remove one of the caller's other devices and close its live connections.

    The request must name a valid device UUID owned by the caller. The device
    of the current session cannot be removed. After the DB row is deleted, any
    connection that device still has open is closed. Without that step it
    would stay registered and keep receiving pushes until it disconnected on
    its own.
    """
    device_id = msg.get("device_id", "").strip()
    if not device_id:
        await send_resp(msg, writer, "remove_device", "error", {"message": "Missing device_id"})
        return
    if not _valid_uuid(device_id):
        await send_resp(msg, writer, "remove_device", "error", {"message": "Invalid device_id"})
        return
    if device_id == session.get("device_id"):
        await send_resp(msg, writer, "remove_device", "error", {"message": "Cannot remove current device"})
        return

    dev = db.get_device(device_id)
    if not dev or dev["user_id"] != session["user_id"]:
        await send_resp(msg, writer, "remove_device", "error", {"message": "Device not found"})
        return

    db.delete_device(device_id)

    # Snapshot the removed device's writers under the lock, then close them
    # outside it. Each connection's own handle_client cleanup removes the map
    # entries once its read loop ends.
    async with _clients_lock:
        live_writers = [
            w for w in connected_clients.get(session["user_id"], [])
            if writer_device_map.get(id(w)) == device_id
        ]
    for w in live_writers:
        try:
            w.close()
        except Exception as e:
            logger.warning("Failed to close connection of removed device: %s", e)

    logger.info("Device removed: %s", device_id)
    await send_resp(msg, writer, "remove_device", "ok", {"message": "OK"})
|
|
|
|
|
|
async def handle_session_reset(msg: dict, session: dict, writer: ProtocolWriter):
    """Notify peer to reset a corrupted Double Ratchet session.

    ``peer_user_id`` is required; ``peer_device_id`` is optional. Both must
    be UUIDs when given. The notification goes to the peer via
    ``_notify_users`` and carries the sender's user/device ids. It also
    carries ``to_device_id``: the requested peer device, or None when no
    device was named. Without that field the recipient could not tell which
    of its devices the reset is meant for.
    """
    peer_user_id = msg.get("peer_user_id", "").strip()
    peer_device_id = msg.get("peer_device_id", "").strip() or None
    if not peer_user_id or not _valid_uuid(peer_user_id):
        await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_user_id"})
        return
    if peer_device_id and not _valid_uuid(peer_device_id):
        await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_device_id"})
        return

    # Push notification to peer. to_device_id is an additive field; clients
    # that do not read it behave as before.
    await _notify_users([peer_user_id], "session_reset", {
        "from_user_id": session["user_id"],
        "from_device_id": session.get("device_id"),
        "to_device_id": peer_device_id,
    })
    await send_resp(msg, writer, "session_reset", "ok", {})
|
|
|
|
|
|
async def handle_reencrypt_messages(msg: dict, session: dict, writer: ProtocolWriter):
    """Re-encrypt message history with self-encryption key (for device pairing).

    Expects ``updates``: a list of at most 500 objects, each with a
    ``message_id``, a base64 ``encrypted_content`` and a base64 ``nonce``.
    Entries that are not objects, or that lack any of the three fields, are
    skipped. The decoded updates are passed to ``db.batch_reencrypt_messages``
    together with the caller's user_id. NOTE(review): presumably the DB
    restricts changes to messages this user may touch; verify there. Responds
    with ``count`` = the number of updates submitted.
    """
    updates_raw = msg.get("updates", [])
    if not updates_raw:
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No updates"})
        return
    if not isinstance(updates_raw, list):
        # A truthy non-list (object/string) would otherwise be iterated element
        # by element and fail on ``.get`` with an opaque internal error.
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "Invalid updates"})
        return
    if len(updates_raw) > 500:
        await send_resp(msg, writer, "reencrypt_messages", "error",
                        {"message": "Too many updates (max 500 per request)"})
        return

    updates = []
    for entry in updates_raw:
        if not isinstance(entry, dict):
            continue
        mid = entry.get("message_id", "")
        ct_b64 = entry.get("encrypted_content", "")
        nonce_b64 = entry.get("nonce", "")
        if not mid or not ct_b64 or not nonce_b64:
            continue
        updates.append({
            "message_id": mid,
            "encrypted_content": decode_binary(ct_b64),
            "nonce": decode_binary(nonce_b64),
        })

    if not updates:
        await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No valid updates"})
        return

    db.batch_reencrypt_messages(session["user_id"], updates)
    logger.info("Re-encrypted %d messages for user.", len(updates))
    await send_resp(msg, writer, "reencrypt_messages", "ok", {"count": len(updates)})
|
|
|
|
|
|
async def _cleanup_uploads():
    """Delete uploads that ``db.get_stale_uploads(3600)`` reports as stale.

    For each stale file_id this removes its ``.tmp`` and ``.enc`` files (if
    present), deletes the upload record from the DB, and drops any in-memory
    ``pending_uploads`` entry. NOTE(review): 3600 is presumably an age in
    seconds; confirm against db.get_stale_uploads. File deletion is best
    effort: a failure is logged and the DB/in-memory cleanup still proceeds.
    """
    stale = db.get_stale_uploads(3600)
    for s in stale:
        fid = s["file_id"]
        for ext in (".tmp", ".enc"):
            p = _safe_upload_path(fid, ext)
            if not p:
                continue
            try:
                p.unlink(missing_ok=True)
            except OSError as e:
                # Previously swallowed silently; keep going, but leave a trace
                # so leftover files on disk can be diagnosed.
                logger.warning("Could not delete stale upload file %s: %s", p.name, e)
        db.delete_image_upload(fid)
        async with _uploads_lock:
            pending_uploads.pop(fid, None)
    if stale:
        logger.info("Cleaned up %d stale uploads.", len(stale))
|
|
|
|
|
|
# Message types that require a logged-in session, mapped to coroutine
# functions awaited as handler(msg, session, writer).
_SESSION_HANDLERS = {
    # get_user_info does not take the session; adapt it to the common shape.
    "get_user_info": lambda msg, session, writer: handle_get_user_info(msg, writer),
    "upload_prekeys": handle_upload_prekeys,
    "get_key_bundle": handle_get_key_bundle,
    "get_prekey_count": handle_get_prekey_count,
    "create_conversation": handle_create_conversation,
    "find_conversation": handle_find_conversation,
    "add_member": handle_add_member,
    "accept_invitation": handle_accept_invitation,
    "decline_invitation": handle_decline_invitation,
    "list_invitations": handle_list_invitations,
    "list_conversations": handle_list_conversations,
    "send_message": handle_send_message,
    "get_messages": handle_get_messages,
    "rotate_keys": handle_rotate_keys,
    "remove_member": handle_remove_member,
    "leave_group": handle_leave_group,
    "rename_conversation": handle_rename_conversation,
    "delete_conversation": handle_delete_conversation,
    "mark_read": handle_mark_read,
    "pairing_claim": handle_pairing_claim,
    "pairing_send": handle_pairing_send,
    "delete_message": handle_delete_message,
    "upload_image_start": handle_upload_image_start,
    "upload_image_chunk": handle_upload_image_chunk,
    "upload_image_end": handle_upload_image_end,
    "download_image": handle_download_image,
    "get_profile": handle_get_profile,
    "update_profile": handle_update_profile,
    "update_avatar": handle_update_avatar,
    "get_avatar": handle_get_avatar,
    "update_group_avatar": handle_update_group_avatar,
    "get_group_avatar": handle_get_group_avatar,
    "reencrypt_messages": handle_reencrypt_messages,
    "list_devices": handle_list_devices,
    "remove_device": handle_remove_device,
    "session_reset": handle_session_reset,
}


async def _dispatch_client_message(msg_type, msg: dict, session, proto_writer: ProtocolWriter, state: dict):
    """Route one decoded message to its handler and return the (possibly new) session.

    Registration, login and pairing start/poll messages are accepted without a
    session. Every other type needs ``session`` to be set; otherwise the client
    gets "Not logged in". A successful ``login_finish`` returns its session
    object, which replaces the current one. Handler exceptions propagate to
    the caller.
    """
    if msg_type == "register":
        await handle_register_start(msg, proto_writer)
    elif msg_type == "register_confirm":
        await handle_register_confirm(msg, proto_writer)
    elif msg_type == "login_start":
        await handle_login_start(msg, proto_writer, state)
    elif msg_type == "login_finish":
        result = await handle_login_finish(msg, proto_writer, state)
        if result:
            return result
    elif msg_type == "pairing_start":
        await handle_pairing_start(msg, proto_writer)
    elif msg_type == "pairing_poll":
        await handle_pairing_poll(msg, proto_writer)
    elif session is None:
        await send_resp(msg, proto_writer, msg_type, "error", {"message": "Not logged in"})
    else:
        # A non-string "type" (e.g. a JSON list) is unhashable; treat it as unknown.
        handler = _SESSION_HANDLERS.get(msg_type) if isinstance(msg_type, str) else None
        if handler is None:
            await send_resp(msg, proto_writer, msg_type, "error", {"message": "Unknown type"})
        else:
            await handler(msg, session, proto_writer)
    return session


async def _drop_client_registration(uid: str, proto_writer: ProtocolWriter) -> None:
    """Unregister a closing connection; if it was the user's last, tell online contacts.

    Contacts are loaded before taking ``_clients_lock`` so the DB call does not
    run under the lock. If loading them fails, the error is logged and only
    the offline notifications are skipped. The writer is still removed from
    ``writer_device_map`` and ``connected_clients``.
    """
    try:
        contacts = db.get_user_contacts(uid)
    except Exception as e:
        logger.warning("Could not load contacts for offline notice: %s", e)
        contacts = []

    offline_targets = []
    async with _clients_lock:
        writer_device_map.pop(id(proto_writer), None)
        if uid in connected_clients:
            remaining = [w for w in connected_clients[uid] if w is not proto_writer]
            if remaining:
                connected_clients[uid] = remaining
            else:
                del connected_clients[uid]
                # User fully offline — snapshot targets under lock
                for contact_id in contacts:
                    offline_targets.extend(connected_clients.get(contact_id, []))

    # Send offline notifications outside lock
    for cw in offline_targets:
        try:
            await cw.send_response("user_offline", "ok", {"user_id": uid})
        except Exception:
            pass  # best effort: one dead contact connection must not stop the rest


async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Serve one client connection until it closes.

    - Enforces the global and per-IP connection limits; an over-limit socket
      is closed immediately.
    - Applies a per-connection sliding-window request limit of
      CONNECTION_RL_MAX requests per CONNECTION_RL_WINDOW.
    - Dispatches each decoded message. Handler exceptions are logged and
      answered with a generic "Internal server error".
    - On exit, releases the connection counters and unregisters the
      connection (notifying contacts when it was the user's last one).
    - Always closes the socket. Previously a failure while loading contacts
      skipped both the unregistration and the close.
    """
    global current_connections
    proto_writer = ProtocolWriter(writer)
    addr = _get_peer_addr(proto_writer)

    async with _conn_lock:
        current_connections += 1
        connection_counts[addr] = connection_counts.get(addr, 0) + 1
        over_limit = (current_connections > MAX_CONNECTIONS_GLOBAL or
                      connection_counts[addr] > MAX_CONNECTIONS_PER_IP)
    if over_limit:
        try:
            writer.close()
        except Exception:
            pass  # rejecting anyway; nothing useful to do if close fails
        async with _conn_lock:
            current_connections = max(0, current_connections - 1)
            connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1)
        return

    logger.debug("Client connected.")
    proto_reader = ProtocolReader(reader)
    session = None
    state = {"_req_times": []}

    try:
        while True:
            try:
                msg = await proto_reader.read_message()
            except ValueError as e:
                try:
                    await proto_writer.send_response("protocol_error", "error", {"message": str(e)})
                except Exception:
                    pass  # the connection is being dropped either way
                break
            if msg is None:
                break

            msg_type = msg.get("type", "")

            # Sliding-window rate limit: keep only timestamps inside the window.
            now = asyncio.get_running_loop().time()
            recent = [t for t in state["_req_times"] if now - t <= CONNECTION_RL_WINDOW]
            state["_req_times"] = recent
            if len(recent) >= CONNECTION_RL_MAX:
                await send_resp(msg, proto_writer, msg_type, "error", {"message": "Too many requests. Slow down."})
                continue
            recent.append(now)

            try:
                session = await _dispatch_client_message(msg_type, msg, session, proto_writer, state)
            except Exception as e:
                logger.warning("Handler error for '%s': %s", msg_type, e, exc_info=True)
                try:
                    await send_resp(msg, proto_writer, msg_type, "error", {"message": "Internal server error"})
                except Exception:
                    break  # Can't send response — connection is dead
    except Exception as e:
        logger.warning("Client connection error: %s", e)
    finally:
        async with _conn_lock:
            current_connections = max(0, current_connections - 1)
            connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1)
        try:
            if session:
                await _drop_client_registration(session["user_id"], proto_writer)
        finally:
            # Always release the socket, even if unregistration failed.
            writer.close()
            logger.debug("Client disconnected.")
|
|
|
|
|
|
def _build_server_ssl_context() -> "ssl.SSLContext | None":
    """Build the server TLS context from environment settings.

    Reads TLS_ENABLED, TLS_REQUIRED, TLS_AUTOGEN, ENVIRONMENT, TLS_CERT_FILE
    and TLS_KEY_FILE. Returns None (after logging a warning) when TLS is
    disabled. When TLS_AUTOGEN is set in a dev environment and no cert/key
    paths are given, a self-signed certificate is created with the
    ``openssl`` CLI under ``certs/`` next to this file, unless one already
    exists there.

    Raises:
        RuntimeError: TLS required but disabled, cert/key paths missing,
            autogen outside dev, or certificate generation failed.
    """
    def _env_flag(name: str) -> bool:
        return os.getenv(name, "false").lower() in ("1", "true", "yes")

    tls_enabled = _env_flag("TLS_ENABLED")
    tls_required = _env_flag("TLS_REQUIRED")
    tls_autogen = _env_flag("TLS_AUTOGEN")
    is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development")

    if tls_required and not tls_enabled:
        raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.")
    if not tls_enabled:
        logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.")
        return None

    cert_file = os.getenv("TLS_CERT_FILE", "").strip()
    key_file = os.getenv("TLS_KEY_FILE", "").strip()
    if not cert_file or not key_file:
        if not tls_autogen:
            raise RuntimeError("TLS is enabled but TLS_CERT_FILE or TLS_KEY_FILE is missing.")
        if not is_dev:
            raise RuntimeError("TLS_AUTOGEN is only allowed when ENVIRONMENT=dev")
        cert_dir = Path(__file__).resolve().parent / "certs"
        cert_dir.mkdir(parents=True, exist_ok=True)
        cert_file = str(cert_dir / "server.crt")
        key_file = str(cert_dir / "server.key")
        if not (os.path.exists(cert_file) and os.path.exists(key_file)):
            try:
                subprocess.run(
                    [
                        "openssl", "req", "-x509", "-newkey", "rsa:4096",
                        "-keyout", key_file, "-out", cert_file,
                        "-days", "365", "-nodes", "-subj", "/CN=localhost",
                    ],
                    check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                )
                os.chmod(key_file, 0o600)
            except FileNotFoundError as e:
                raise RuntimeError("OpenSSL not found.") from e
            except subprocess.CalledProcessError as e:
                raise RuntimeError("Failed to auto-generate TLS cert.") from e
        logger.warning("Using auto-generated self-signed certificate — not for production use.")

    ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file)
    return ssl_context


async def main():
    """Start the chat server and serve until SIGINT or SIGTERM.

    Configures logging and TLS, ensures UPLOAD_DIR exists, and loads phantom
    user IDs into memory. It then listens on SERVER_HOST:SERVER_PORT. A
    background task runs upload, rate-limit and stale-phantom cleanup every
    600 seconds. On shutdown the cleanup task is cancelled and every
    connected client writer is closed before the server context exits.
    """
    setup_logging()
    host = os.getenv("SERVER_HOST", "127.0.0.1")
    port = int(os.getenv("SERVER_PORT", "9999"))
    ssl_context = _build_server_ssl_context()

    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

    # Load phantom user IDs from DB into in-memory cache
    phantom_user_ids.update(db.get_all_phantom_user_ids())
    if phantom_user_ids:
        logger.info("Loaded %d phantom user IDs.", len(phantom_user_ids))

    server = await asyncio.start_server(
        handle_client, host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context,
    )
    logger.info("Encrypted chat server v%s listening on %s:%s", VERSION, host, port)

    async def _cleanup_rate_limits():
        async with _conn_lock:
            now = asyncio.get_running_loop().time()
            window_start = now - RATE_LIMIT_WINDOW
            stale_keys = [k for k, times in rate_limits.items()
                          if not any(t >= window_start for t in times)]
            for k in stale_keys:
                del rate_limits[k]
            stale_conns = [k for k, v in connection_counts.items() if v <= 0]
            for k in stale_conns:
                del connection_counts[k]

    async def _periodic_cleanup():
        while True:
            await asyncio.sleep(600)
            try:
                await _cleanup_uploads()
            except Exception as e:
                logger.warning("Upload cleanup error: %s", e)
            try:
                await _cleanup_rate_limits()
            except Exception as e:
                logger.warning("Rate limit cleanup error: %s", e)
            # L8: clean up stale phantom users (>30 days, no real conversations)
            try:
                deleted = db.cleanup_stale_phantoms(30)
                if deleted:
                    async with _clients_lock:
                        phantom_user_ids.clear()
                        phantom_user_ids.update(db.get_all_phantom_user_ids())
                    logger.info("Cleaned up %d stale phantom users.", deleted)
            except Exception as e:
                logger.warning("Phantom cleanup error: %s", e)

    # The event loop keeps only weak references to tasks, so hold a strong
    # one; it is also needed to cancel the loop cleanly on shutdown.
    cleanup_task = asyncio.create_task(_periodic_cleanup())

    loop = asyncio.get_running_loop()
    stop = loop.create_future()

    def signal_handler():
        if not stop.done():
            stop.set_result(None)

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, signal_handler)
        except NotImplementedError:
            # Loops without Unix signal support (e.g. on Windows) raise here.
            # Ctrl+C is then left to asyncio.run()'s default handling.
            pass

    async with server:
        await stop
        cleanup_task.cancel()
        # Force-close all connected clients BEFORE exiting context manager,
        # otherwise wait_closed() blocks forever waiting for handle_client tasks
        async with _clients_lock:
            all_writers = [w for writers in connected_clients.values() for w in writers]
            connected_clients.clear()
            writer_device_map.clear()
        logger.info("Shutting down — closing %d client connections...", len(all_writers))
        for w in all_writers:
            try:
                w.close()
            except Exception:
                pass  # shutting down; a writer that fails to close is already unusable
    logger.info("Server shut down.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async server until main() returns.
    asyncio.run(main())
|