"""MySQL database layer for the encrypted chat server.""" import os import uuid import logging import mysql.connector from mysql.connector import pooling from dotenv import load_dotenv from crypto_utils import ( generate_identity_keypair, serialize_ed25519_public, generate_signed_prekey, serialize_x25519_public, generate_one_time_prekeys, ) load_dotenv() # Sentinel device_id for self-encrypted copies and legacy (pre-multi-device) rows SELF_DEVICE_ID = "00000000-0000-0000-0000-000000000000" _logger = logging.getLogger(__name__) _pool = None def _get_pool(): """Get or create the connection pool (lazy init).""" global _pool if _pool is None: pool_size = int(os.getenv("DB_POOL_SIZE", "10")) pool_kwargs = dict( pool_name="chat_pool", pool_size=pool_size, pool_reset_session=True, host=os.getenv("MYSQL_HOST", "localhost"), port=int(os.getenv("MYSQL_PORT", "3306")), user=os.getenv("MYSQL_USER", "root"), password=os.getenv("MYSQL_PASSWORD", ""), database=os.getenv("MYSQL_DATABASE", "encrypted_chat"), ) # Optional MySQL TLS (M7): set MYSQL_SSL_CA (and optionally MYSQL_SSL_CERT/KEY) ssl_ca = os.getenv("MYSQL_SSL_CA", "").strip() ssl_cert = os.getenv("MYSQL_SSL_CERT", "").strip() ssl_key = os.getenv("MYSQL_SSL_KEY", "").strip() if ssl_ca: pool_kwargs["ssl_ca"] = ssl_ca if ssl_cert: pool_kwargs["ssl_cert"] = ssl_cert if ssl_key: pool_kwargs["ssl_key"] = ssl_key _logger.info("MySQL TLS enabled (CA: %s)", ssl_ca) _pool = pooling.MySQLConnectionPool(**pool_kwargs) _logger.info("DB connection pool created (size=%d)", pool_size) return _pool def get_connection(): """Get a connection from the pool.""" return _get_pool().get_connection() def generate_uuid() -> str: return str(uuid.uuid4()) # --- Devices --- def create_device(user_id: str, device_name: str | None = None) -> str: """Create a new device for a user. Returns device_id.""" conn = get_connection() try: cursor = conn.cursor() device_id = generate_uuid() cursor.execute( "INSERT INTO devices (id, user_id, device_name) VALUES (%s, %s, %s)", (device_id, user_id, device_name), ) conn.commit() return device_id finally: conn.close() def get_user_devices(user_id: str) -> list[dict]: """Get all devices for a user.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT id, user_id, device_name, created_at, last_seen_at " "FROM devices WHERE user_id = %s ORDER BY created_at", (user_id,), ) return cursor.fetchall() finally: conn.close() def get_device(device_id: str) -> dict | None: """Get a single device by ID.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT id, user_id, device_name, created_at, last_seen_at " "FROM devices WHERE id = %s", (device_id,), ) return cursor.fetchone() finally: conn.close() def update_device_last_seen(device_id: str): """Update last_seen_at timestamp for a device.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE devices SET last_seen_at = NOW() WHERE id = %s", (device_id,), ) conn.commit() finally: conn.close() def delete_device(device_id: str): """Delete a device. CASCADE removes its prekeys.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("DELETE FROM devices WHERE id = %s", (device_id,)) # Also clean up prekeys explicitly for device_id column cursor.execute("DELETE FROM signed_prekeys WHERE device_id = %s", (device_id,)) cursor.execute("DELETE FROM one_time_prekeys WHERE device_id = %s", (device_id,)) conn.commit() finally: conn.close() # --- Users --- def create_user(username: str, email: str, rsa_public_key_pem: str, identity_key: bytes) -> str: """Register a new user. Returns user ID.""" conn = get_connection() try: cursor = conn.cursor() user_id = generate_uuid() cursor.execute( "INSERT INTO users (id, username, email, rsa_public_key, identity_key) " "VALUES (%s, %s, %s, %s, %s)", (user_id, username, email, rsa_public_key_pem, identity_key), ) conn.commit() return user_id finally: conn.close() def get_user_by_email(email: str) -> dict | None: """Get user by email.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE email = %s", (email,), ) return cursor.fetchone() finally: conn.close() def get_user_by_id(user_id: str) -> dict | None: """Get user by ID.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE id = %s", (user_id,), ) return cursor.fetchone() finally: conn.close() def shares_conversation(user_id_a: str, user_id_b: str) -> bool: """Check if two users share at least one conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT 1 FROM conversation_members cm1 " "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " "WHERE cm1.user_id = %s AND cm2.user_id = %s LIMIT 1", (user_id_a, user_id_b), ) return cursor.fetchone() is not None finally: conn.close() def get_user_contacts(user_id: str) -> list[str]: """Get all user IDs that share at least one conversation with the given user.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT DISTINCT cm2.user_id " "FROM conversation_members cm1 " "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " "WHERE cm1.user_id = %s AND cm2.user_id != %s", (user_id, user_id), ) return [row[0] for row in cursor.fetchall()] finally: conn.close() def update_user_rsa_key(user_id: str, rsa_public_key_pem: str): """Update user's RSA public key (for login).""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("UPDATE users SET rsa_public_key = %s WHERE id = %s", (rsa_public_key_pem, user_id)) conn.commit() finally: conn.close() def update_username(user_id: str, new_username: str): """Update user's display name.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("UPDATE users SET username = %s WHERE id = %s", (new_username, user_id)) conn.commit() finally: conn.close() # --- Pre-keys --- def store_signed_prekey(user_id: str, spk_id: str, public_key: bytes, signature: bytes, device_id: str | None = None): """Store (or replace) a signed pre-key for a user's device.""" conn = get_connection() try: cursor = conn.cursor() # Remove old SPKs for this user+device if device_id: cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id = %s", (user_id, device_id)) else: cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id IS NULL", (user_id,)) cursor.execute( "INSERT INTO signed_prekeys (id, user_id, device_id, public_key, signature) " "VALUES (%s, %s, %s, %s, %s)", (spk_id, user_id, device_id, public_key, signature), ) conn.commit() finally: conn.close() def get_signed_prekey(user_id: str, device_id: str | None = None) -> dict | None: """Get the current signed pre-key for a user (optionally per device).""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) if device_id: cursor.execute( "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys " "WHERE user_id = %s AND device_id = %s " "ORDER BY created_at DESC LIMIT 1", (user_id, device_id), ) else: cursor.execute( "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys " "WHERE user_id = %s ORDER BY created_at DESC LIMIT 1", (user_id,), ) return cursor.fetchone() finally: conn.close() def store_one_time_prekeys(user_id: str, prekeys: list[dict], device_id: str | None = None): """Store a batch of one-time pre-keys. Each dict has {id, public_key (bytes)}.""" conn = get_connection() try: cursor = conn.cursor() for pk in prekeys: cursor.execute( "INSERT INTO one_time_prekeys (id, user_id, device_id, public_key) " "VALUES (%s, %s, %s, %s)", (pk["id"], user_id, device_id, pk["public_key"]), ) conn.commit() finally: conn.close() def consume_one_time_prekey(user_id: str, device_id: str | None = None) -> dict | None: """Atomically consume one OTP: SELECT FOR UPDATE + DELETE. Returns {id, public_key} or None.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) conn.start_transaction() if device_id: cursor.execute( "SELECT id, public_key FROM one_time_prekeys " "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE", (user_id, device_id), ) else: cursor.execute( "SELECT id, public_key FROM one_time_prekeys " "WHERE user_id = %s LIMIT 1 FOR UPDATE", (user_id,), ) row = cursor.fetchone() if row: cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (row["id"],)) conn.commit() return row except Exception: conn.rollback() raise finally: conn.close() def count_one_time_prekeys(user_id: str, device_id: str | None = None) -> int: """Count remaining OTPs for a user (optionally per device).""" conn = get_connection() try: cursor = conn.cursor() if device_id: cursor.execute( "SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s AND device_id = %s", (user_id, device_id), ) else: cursor.execute("SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s", (user_id,)) return cursor.fetchone()[0] finally: conn.close() def get_key_bundle(user_id: str) -> dict | None: """Get complete key bundle for X3DH (single device — legacy compat). Returns {identity_key, signed_prekey_id, signed_prekey, spk_signature, one_time_prekey_id, one_time_prekey} or None. OTP is consumed atomically. """ conn = get_connection() try: cursor = conn.cursor(dictionary=True) # Get user identity key cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,)) user = cursor.fetchone() if not user: return None # Get signed prekey cursor.execute( "SELECT id, public_key, signature, device_id FROM signed_prekeys WHERE user_id = %s " "ORDER BY created_at DESC LIMIT 1", (user_id,), ) spk = cursor.fetchone() if not spk: return None # Consume one OTP (may be None) — use transaction for atomicity (H12 fix) conn.start_transaction() cursor.execute( "SELECT id, public_key FROM one_time_prekeys WHERE user_id = %s LIMIT 1 FOR UPDATE", (user_id,), ) opk = cursor.fetchone() if opk: cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],)) conn.commit() result = { "identity_key": user["identity_key"], "signed_prekey_id": spk["id"], "signed_prekey": spk["public_key"], "spk_signature": spk["signature"], } if opk: result["one_time_prekey_id"] = opk["id"] result["one_time_prekey"] = opk["public_key"] return result except Exception: try: conn.rollback() except Exception: pass raise finally: conn.close() def get_key_bundles_for_user(user_id: str) -> dict | None: """Get key bundles for ALL devices of a user. Returns {identity_key, device_bundles: [{device_id, signed_prekey_id, signed_prekey_pub, spk_signature, opk_id, opk_pub}]} or None. Consumes one OPK per device atomically. """ conn = get_connection() try: cursor = conn.cursor(dictionary=True) # Get user identity key cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,)) user = cursor.fetchone() if not user: return None # Get all signed prekeys (one per device, most recent) cursor.execute( "SELECT id, public_key, signature, device_id FROM signed_prekeys " "WHERE user_id = %s ORDER BY created_at DESC", (user_id,), ) all_spks = cursor.fetchall() if not all_spks: return None # De-duplicate: keep only the most recent SPK per device_id seen_devices = set() spks_by_device = [] for spk in all_spks: dev = spk.get("device_id") or "__legacy__" if dev not in seen_devices: seen_devices.add(dev) spks_by_device.append(spk) device_bundles = [] # Commit the implicit transaction from the read-only queries above # so we can start an explicit transaction for atomic OPK consumption. conn.commit() conn.start_transaction() for spk in spks_by_device: dev_id = spk.get("device_id") # Consume one OPK for this device if dev_id: cursor.execute( "SELECT id, public_key FROM one_time_prekeys " "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE", (user_id, dev_id), ) else: cursor.execute( "SELECT id, public_key FROM one_time_prekeys " "WHERE user_id = %s AND device_id IS NULL LIMIT 1 FOR UPDATE", (user_id,), ) opk = cursor.fetchone() if opk: cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],)) bundle = { "device_id": dev_id, "signed_prekey_id": spk["id"], "signed_prekey_pub": spk["public_key"], "spk_signature": spk["signature"], } if opk: bundle["opk_id"] = opk["id"] bundle["opk_pub"] = opk["public_key"] device_bundles.append(bundle) conn.commit() return { "identity_key": user["identity_key"], "device_bundles": device_bundles, } except Exception: try: conn.rollback() except Exception: pass raise finally: conn.close() # --- Conversations --- def create_conversation(member_user_ids: list[str], joined_at=None, name=None, created_by=None) -> str: conn = get_connection() try: cursor = conn.cursor() conv_id = generate_uuid() cursor.execute("INSERT INTO conversations (id, name, created_by) VALUES (%s, %s, %s)", (conv_id, name, created_by)) for uid in member_user_ids: cursor.execute( "INSERT INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)", (conv_id, uid, joined_at), ) conn.commit() return conv_id finally: conn.close() def add_conversation_member(conversation_id: str, user_id: str, joined_at=None): conn = get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT IGNORE INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)", (conversation_id, user_id, joined_at), ) conn.commit() finally: conn.close() def remove_conversation_member(conversation_id: str, user_id: str): conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s", (conversation_id, user_id), ) conn.commit() finally: conn.close() def count_conversation_members(conversation_id: str) -> int: """Count members in a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT COUNT(*) FROM conversation_members WHERE conversation_id = %s", (conversation_id,), ) return cursor.fetchone()[0] finally: conn.close() def get_conversation_file_ids(conversation_id: str) -> list[str]: """Get all file IDs (images + files) uploaded to a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT file_id FROM image_uploads WHERE conversation_id = %s", (conversation_id,), ) return [row[0] for row in cursor.fetchall()] finally: conn.close() def delete_conversation(conversation_id: str): """Delete a conversation entirely. CASCADE cleans up members, messages, etc.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("DELETE FROM conversations WHERE id = %s", (conversation_id,)) conn.commit() finally: conn.close() def get_conversation_members(conversation_id: str) -> list[dict]: conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT u.id, u.username, u.email, u.identity_key FROM conversation_members cm " "JOIN users u ON cm.user_id = u.id " "WHERE cm.conversation_id = %s", (conversation_id,), ) return cursor.fetchall() finally: conn.close() def find_direct_conversation(user_id_a: str, user_id_b: str) -> str | None: conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT cm1.conversation_id FROM conversation_members cm1 " "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " "WHERE cm1.user_id = %s AND cm2.user_id = %s " "AND (SELECT COUNT(*) FROM conversation_members cm3 " " WHERE cm3.conversation_id = cm1.conversation_id) = 2 " "LIMIT 1", (user_id_a, user_id_b), ) row = cursor.fetchone() return row[0] if row else None finally: conn.close() def update_conversation_creator(conversation_id: str, new_creator_id: str): """Transfer group creator role to another member.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE conversations SET created_by = %s WHERE id = %s", (new_creator_id, conversation_id), ) conn.commit() finally: conn.close() def get_conversation(conversation_id: str) -> dict | None: """Get conversation by ID.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT id, created_at, name, created_by, avatar_file FROM conversations WHERE id = %s", (conversation_id,), ) return cursor.fetchone() finally: conn.close() def update_conversation_avatar(conversation_id: str, avatar_file: str): """Set avatar file for a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE conversations SET avatar_file = %s WHERE id = %s", (avatar_file, conversation_id), ) conn.commit() finally: conn.close() def update_conversation_name(conversation_id: str, name: str): """Update the name of a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE conversations SET name = %s WHERE id = %s", (name, conversation_id), ) conn.commit() finally: conn.close() def is_conversation_member(conversation_id: str, user_id: str) -> bool: conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s", (conversation_id, user_id), ) return cursor.fetchone() is not None finally: conn.close() def list_user_conversations(user_id: str) -> list[dict]: conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT c.id, c.created_at, c.name, c.created_by, c.avatar_file FROM conversations c " "JOIN conversation_members cm ON c.id = cm.conversation_id " "WHERE cm.user_id = %s ORDER BY c.created_at DESC", (user_id,), ) convs = cursor.fetchall() if not convs: return convs # Batch-fetch all members for all conversations in one query (N+1 fix) conv_ids = [c["id"] for c in convs] placeholders = ",".join(["%s"] * len(conv_ids)) cursor.execute( f"SELECT cm.conversation_id, u.id AS user_id, u.username, u.email " f"FROM conversation_members cm JOIN users u ON cm.user_id = u.id " f"WHERE cm.conversation_id IN ({placeholders})", conv_ids, ) members_by_conv: dict[str, list[dict]] = {} for row in cursor.fetchall(): cid = row.pop("conversation_id") members_by_conv.setdefault(cid, []).append(row) for conv in convs: conv["members"] = members_by_conv.get(conv["id"], []) return convs finally: conn.close() # --- Group Invitations --- def create_invitation(conversation_id: str, user_id: str, invited_by: str): """Create a pending group invitation.""" conn = get_connection() try: cursor = conn.cursor() inv_id = generate_uuid() cursor.execute( "INSERT IGNORE INTO group_invitations (id, conversation_id, user_id, invited_by) " "VALUES (%s, %s, %s, %s)", (inv_id, conversation_id, user_id, invited_by), ) conn.commit() finally: conn.close() def get_pending_invitations(user_id: str) -> list[dict]: """Get all pending invitations for a user, joined with conversation and inviter info.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT gi.id, gi.conversation_id, gi.invited_by, gi.created_at, " "c.name AS conversation_name, u.username AS invited_by_username " "FROM group_invitations gi " "JOIN conversations c ON gi.conversation_id = c.id " "JOIN users u ON gi.invited_by = u.id " "WHERE gi.user_id = %s " "ORDER BY gi.created_at DESC", (user_id,), ) return cursor.fetchall() finally: conn.close() def delete_invitation(conversation_id: str, user_id: str): """Delete a pending invitation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM group_invitations WHERE conversation_id = %s AND user_id = %s", (conversation_id, user_id), ) conn.commit() finally: conn.close() def has_pending_invitation(conversation_id: str, user_id: str) -> bool: """Check if a user has a pending invitation for a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT 1 FROM group_invitations WHERE conversation_id = %s AND user_id = %s", (conversation_id, user_id), ) return cursor.fetchone() is not None finally: conn.close() # --- Messages --- def store_message( conversation_id: str, sender_id: str, ratchet_header: bytes, recipients: list[dict], x3dh_header: bytes | None = None, sender_chain_id: bytes | None = None, sender_chain_n: int | None = None, image_file_id: str | None = None, sender_device_id: str | None = None, ) -> str: """Store an encrypted message with per-recipient ciphertext. recipients: [{user_id, encrypted_content (bytes), nonce (bytes), device_id (str, optional), ratchet_header (bytes, optional), x3dh_header (bytes, optional)}] """ conn = get_connection() try: cursor = conn.cursor() msg_id = generate_uuid() cursor.execute( "INSERT INTO messages (id, conversation_id, sender_id, sender_device_id, " "ratchet_header, x3dh_header, sender_chain_id, sender_chain_n, image_file_id) " "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)", (msg_id, conversation_id, sender_id, sender_device_id, ratchet_header, x3dh_header, sender_chain_id, sender_chain_n, image_file_id), ) for r in recipients: device_id = r.get("device_id", SELF_DEVICE_ID) cursor.execute( "INSERT INTO message_recipients (message_id, user_id, device_id, " "encrypted_content, nonce, ratchet_header, x3dh_header) " "VALUES (%s, %s, %s, %s, %s, %s, %s)", (msg_id, r["user_id"], device_id, r["encrypted_content"], r["nonce"], r.get("ratchet_header"), r.get("x3dh_header")), ) conn.commit() cursor.execute("SELECT created_at FROM messages WHERE id = %s", (msg_id,)) row = cursor.fetchone() created_at = row[0].isoformat() if row else None return msg_id, created_at finally: conn.close() def get_messages(conversation_id: str, user_id: str, limit: int = 50, offset: int = 0, device_id: str | None = None, after_ts: str | None = None) -> list[dict]: """Get messages for a user in a conversation, JOINing their per-recipient ciphertext. If device_id is set, returns rows where mr.device_id matches OR is the sentinel (self-encrypted / legacy). May return duplicate message IDs when both device-specific and self-encrypted rows exist — caller should deduplicate (prefer device-specific). If after_ts is set, only returns messages created after that timestamp (ISO format). Results are ordered ASC when after_ts is used, DESC otherwise. """ conn = get_connection() try: cursor = conn.cursor(dictionary=True) if device_id: where_parts = ["m.conversation_id = %s", "(cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"] params = [user_id, device_id, SELF_DEVICE_ID, user_id, conversation_id] if after_ts: where_parts.append("m.created_at > %s") params.append(after_ts) where_clause = " AND ".join(where_parts) order = "ASC" if after_ts else "DESC" cursor.execute( "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, " "m.ratchet_header, m.x3dh_header, " "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, " "m.pinned_at, m.pinned_by, " "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, " "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header " "FROM messages m " "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " " AND (mr.device_id = %s OR mr.device_id = %s) " "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " f"WHERE {where_clause} " f"ORDER BY m.created_at {order} LIMIT %s OFFSET %s", (*params, limit, offset), ) else: where_parts = ["m.conversation_id = %s", "(cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"] params = [user_id, user_id, conversation_id] if after_ts: where_parts.append("m.created_at > %s") params.append(after_ts) where_clause = " AND ".join(where_parts) order = "ASC" if after_ts else "DESC" cursor.execute( "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, " "m.ratchet_header, m.x3dh_header, " "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, " "m.pinned_at, m.pinned_by, " "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, " "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header " "FROM messages m " "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " f"WHERE {where_clause} " f"ORDER BY m.created_at {order} LIMIT %s OFFSET %s", (*params, limit, offset), ) return cursor.fetchall() finally: conn.close() def count_messages(conversation_id: str, user_id: str) -> int: """Count total messages visible to a user in a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT COUNT(DISTINCT m.id) " "FROM messages m " "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " "WHERE m.conversation_id = %s AND (cm.joined_at IS NULL OR m.created_at >= cm.joined_at)", (user_id, user_id, conversation_id), ) row = cursor.fetchone() return row[0] if row else 0 finally: conn.close() def get_message_conversation(message_id: str) -> str | None: conn = get_connection() try: cursor = conn.cursor() cursor.execute("SELECT conversation_id FROM messages WHERE id = %s", (message_id,)) row = cursor.fetchone() return row[0] if row else None finally: conn.close() def get_message_sender(message_id: str) -> str | None: conn = get_connection() try: cursor = conn.cursor() cursor.execute("SELECT sender_id FROM messages WHERE id = %s", (message_id,)) row = cursor.fetchone() return row[0] if row else None finally: conn.close() def get_deleted_messages_since(conversation_id: str, user_id: str, since_ts: str) -> list[str]: """Return message IDs that were soft-deleted since the given timestamp.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT m.id FROM messages m " "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " "WHERE m.conversation_id = %s AND m.deleted_at IS NOT NULL AND m.deleted_at > %s", (user_id, conversation_id, since_ts), ) return [row[0] for row in cursor.fetchall()] finally: conn.close() # --- Reactions --- ALLOWED_REACTIONS = {"thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"} def add_reaction(message_id: str, user_id: str, reaction: str) -> tuple[bool, str | None]: """Add or replace a reaction. Returns (changed, old_reaction_or_None).""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT reaction FROM message_reactions WHERE message_id = %s AND user_id = %s", (message_id, user_id), ) row = cursor.fetchone() old_reaction = row["reaction"] if row else None if old_reaction == reaction: return False, None # already same reaction if old_reaction: cursor.execute( "UPDATE message_reactions SET reaction = %s, created_at = CURRENT_TIMESTAMP " "WHERE message_id = %s AND user_id = %s", (reaction, message_id, user_id), ) else: cursor.execute( "INSERT INTO message_reactions (id, message_id, user_id, reaction) " "VALUES (%s, %s, %s, %s)", (generate_uuid(), message_id, user_id, reaction), ) conn.commit() return True, old_reaction finally: conn.close() def remove_reaction(message_id: str, user_id: str) -> bool: """Remove a user's reaction. Returns True if deleted.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM message_reactions WHERE message_id = %s AND user_id = %s", (message_id, user_id), ) conn.commit() return cursor.rowcount > 0 finally: conn.close() def get_reactions(message_ids: list[str]) -> dict[str, list[dict]]: """Get reactions for multiple messages. Returns {msg_id: [{user_id, reaction, created_at}]}.""" if not message_ids: return {} conn = get_connection() try: cursor = conn.cursor(dictionary=True) placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"SELECT message_id, user_id, reaction, created_at " f"FROM message_reactions WHERE message_id IN ({placeholders}) " f"ORDER BY created_at", tuple(message_ids), ) result = {} for row in cursor.fetchall(): mid = row["message_id"] if mid not in result: result[mid] = [] result[mid].append({ "user_id": row["user_id"], "reaction": row["reaction"], "created_at": row["created_at"].isoformat() if hasattr(row["created_at"], "isoformat") else str(row["created_at"]), }) return result finally: conn.close() # --- Pins --- def pin_message(message_id: str, user_id: str, conversation_id: str) -> bool: """Pin a message. Returns True on success.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE messages SET pinned_at = NOW(), pinned_by = %s " "WHERE id = %s AND conversation_id = %s AND pinned_at IS NULL", (user_id, message_id, conversation_id), ) conn.commit() return cursor.rowcount > 0 finally: conn.close() def unpin_message(message_id: str, conversation_id: str) -> bool: """Unpin a message. Returns True on success.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE messages SET pinned_at = NULL, pinned_by = NULL " "WHERE id = %s AND conversation_id = %s AND pinned_at IS NOT NULL", (message_id, conversation_id), ) conn.commit() return cursor.rowcount > 0 finally: conn.close() def get_pinned_messages(conversation_id: str, user_id: str) -> list[dict]: """Get pinned messages for a conversation (membership verified via JOIN).""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT m.id AS message_id, m.sender_id, m.pinned_at, m.pinned_by, m.created_at " "FROM messages m " "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " "WHERE m.conversation_id = %s AND m.pinned_at IS NOT NULL AND m.deleted_at IS NULL " "ORDER BY m.pinned_at DESC", (user_id, conversation_id), ) rows = cursor.fetchall() for r in rows: for k in ("pinned_at", "created_at"): if r.get(k) and hasattr(r[k], "isoformat"): r[k] = r[k].isoformat() return rows finally: conn.close() # --- Group Sender Keys --- def store_sender_key(conversation_id: str, sender_id: str, chain_id: bytes, device_id: str | None = None): """Store or update a sender key chain ID for a group member's device.""" conn = get_connection() try: cursor = conn.cursor() dev = device_id or SELF_DEVICE_ID cursor.execute( "REPLACE INTO group_sender_keys (conversation_id, sender_id, device_id, chain_id) " "VALUES (%s, %s, %s, %s)", (conversation_id, sender_id, dev, chain_id), ) conn.commit() finally: conn.close() def get_sender_key(conversation_id: str, sender_id: str, device_id: str | None = None) -> dict | None: conn = get_connection() try: cursor = conn.cursor(dictionary=True) dev = device_id or SELF_DEVICE_ID cursor.execute( "SELECT chain_id, created_at FROM group_sender_keys " "WHERE conversation_id = %s AND sender_id = %s AND device_id = %s", (conversation_id, sender_id, dev), ) return cursor.fetchone() finally: conn.close() # --- Read Receipts --- def filter_message_ids_by_conversation(conversation_id: str, message_ids: list[str]) -> list[str]: """Return only message_ids that belong to the given conversation.""" if not message_ids: return [] conn = get_connection() try: cursor = conn.cursor() placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"SELECT id FROM messages WHERE id IN ({placeholders}) AND conversation_id = %s", (*message_ids, conversation_id), ) return [row[0] for row in cursor.fetchall()] finally: conn.close() def mark_messages_read(conversation_id: str, user_id: str, message_ids: list[str]): if not message_ids: return conn = get_connection() try: cursor = conn.cursor() # M1 fix: JOIN messages to verify message_ids belong to conversation_id placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"INSERT IGNORE INTO message_reads (message_id, user_id) " f"SELECT m.id, %s FROM messages m " f"WHERE m.id IN ({placeholders}) AND m.conversation_id = %s", (user_id, *message_ids, conversation_id), ) conn.commit() finally: conn.close() def mark_conversation_read(conversation_id: str, user_id: str) -> int: """Mark ALL unread messages in a conversation as read for user. Returns count marked.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT IGNORE INTO message_reads (message_id, user_id) " "SELECT m.id, %s " "FROM messages m " "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s " "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s " "WHERE m.conversation_id = %s AND m.sender_id != %s " "AND m.deleted_at IS NULL AND mrd.message_id IS NULL", (user_id, user_id, user_id, conversation_id, user_id), ) count = cursor.rowcount conn.commit() return count finally: conn.close() def get_unread_counts(user_id: str, max_age_days: int = 0) -> dict[str, int]: """Return {conversation_id: unread_count} for all conversations the user is in. max_age_days: if > 0, only count messages younger than this many days. Must match METADATA_RETENTION_DAYS to avoid phantom unreads after read cleanup. """ conn = get_connection() try: cursor = conn.cursor(dictionary=True) age_filter = "" params = [user_id, user_id, user_id] if max_age_days > 0: age_filter = " AND m.created_at >= DATE_SUB(NOW(), INTERVAL %s DAY)" params.append(max_age_days) cursor.execute( "SELECT m.conversation_id, COUNT(DISTINCT m.id) AS cnt " "FROM messages m " "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s " "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s " "WHERE m.sender_id != %s AND m.deleted_at IS NULL AND mrd.message_id IS NULL" f"{age_filter} " "GROUP BY m.conversation_id", params, ) return {row["conversation_id"]: row["cnt"] for row in cursor.fetchall()} finally: conn.close() def get_message_read_status(message_ids: list[str]) -> dict: if not message_ids: return {} conn = get_connection() try: cursor = conn.cursor(dictionary=True) placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"SELECT mr.message_id, mr.user_id, mr.read_at " f"FROM message_reads mr " f"WHERE mr.message_id IN ({placeholders})", tuple(message_ids), ) result = {} for row in cursor.fetchall(): mid = row["message_id"] if mid not in result: result[mid] = [] result[mid].append({ "user_id": row["user_id"], "read_at": row["read_at"].isoformat() if hasattr(row["read_at"], "isoformat") else str(row["read_at"]), }) return result finally: conn.close() # --- Delivery Receipts --- def mark_messages_delivered(conversation_id: str, user_id: str, message_ids: list[str]): """Batch insert delivery receipts (INSERT IGNORE — idempotent).""" if not message_ids: return conn = get_connection() try: cursor = conn.cursor() # M1 fix: JOIN messages to verify message_ids belong to conversation_id placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"INSERT IGNORE INTO message_deliveries (message_id, user_id) " f"SELECT m.id, %s FROM messages m " f"WHERE m.id IN ({placeholders}) AND m.conversation_id = %s", (user_id, *message_ids, conversation_id), ) conn.commit() finally: conn.close() def get_message_delivery_status(message_ids: list[str]) -> dict: """Get delivery status for messages. Returns {msg_id: [{user_id, delivered_at}]}.""" if not message_ids: return {} conn = get_connection() try: cursor = conn.cursor(dictionary=True) placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"SELECT md.message_id, md.user_id, md.delivered_at " f"FROM message_deliveries md " f"WHERE md.message_id IN ({placeholders})", tuple(message_ids), ) result = {} for row in cursor.fetchall(): mid = row["message_id"] if mid not in result: result[mid] = [] result[mid].append({ "user_id": row["user_id"], "delivered_at": row["delivered_at"].isoformat() if hasattr(row["delivered_at"], "isoformat") else str(row["delivered_at"]), }) return result finally: conn.close() # --- Delete --- def soft_delete_message(message_id: str, sender_id: str) -> dict | None: """Soft-delete a message if sender matches. Returns {'image_file_id': ...} or None.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT sender_id, image_file_id FROM messages WHERE id = %s AND deleted_at IS NULL", (message_id,), ) row = cursor.fetchone() if not row or row["sender_id"] != sender_id: return None cursor.execute( "UPDATE messages SET deleted_at = NOW() WHERE id = %s", (message_id,), ) # Clear per-recipient ciphertext cursor.execute( "UPDATE message_recipients SET encrypted_content = %s WHERE message_id = %s", (b"", message_id), ) conn.commit() return {"image_file_id": row.get("image_file_id")} finally: conn.close() def set_message_image_file_id(message_id: str, file_id: str): conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE messages SET image_file_id = %s WHERE id = %s", (file_id, message_id), ) conn.commit() finally: conn.close() # --- Image Uploads --- def create_image_upload(file_id: str, conversation_id: str, uploader_id: str, file_size: int): conn = get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT INTO image_uploads (file_id, conversation_id, uploader_id, file_size) " "VALUES (%s, %s, %s, %s)", (file_id, conversation_id, uploader_id, file_size), ) conn.commit() finally: conn.close() def complete_image_upload(file_id: str): conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE image_uploads SET completed = TRUE WHERE file_id = %s", (file_id,), ) conn.commit() finally: conn.close() def get_image_upload(file_id: str) -> dict | None: conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT file_id, conversation_id, uploader_id, file_size, completed, created_at " "FROM image_uploads WHERE file_id = %s", (file_id,), ) return cursor.fetchone() finally: conn.close() def delete_image_upload(file_id: str): conn = get_connection() try: cursor = conn.cursor() cursor.execute("DELETE FROM image_uploads WHERE file_id = %s", (file_id,)) conn.commit() finally: conn.close() # --- User Profiles --- def create_default_profile(user_id: str): """Create a default profile for a new user.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)", (user_id,), ) conn.commit() finally: conn.close() def get_user_profile(user_id: str, viewer_id: str | None = None) -> dict | None: """Get user profile joined with user info. Respects visibility if viewer is different user.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT u.id AS user_id, u.username, u.email, u.created_at, " "p.phone, p.phone_visible, p.email_visible, p.location, " "p.location_visible, p.avatar_file, p.updated_at " "FROM users u LEFT JOIN user_profiles p ON u.id = p.user_id " "WHERE u.id = %s", (user_id,), ) row = cursor.fetchone() if not row: return None # If viewing someone else's profile, apply visibility rules if viewer_id and viewer_id != user_id: if not row.get("email_visible"): row["email"] = None if not row.get("phone_visible"): row["phone"] = None if not row.get("location_visible"): row["location"] = None return row finally: conn.close() def update_user_profile(user_id: str, **fields): """Upsert user profile fields. Allowed: phone, phone_visible, email_visible, location, location_visible, avatar_file.""" allowed = {"phone", "phone_visible", "email_visible", "location", "location_visible", "avatar_file"} filtered = {k: v for k, v in fields.items() if k in allowed} if not filtered: return conn = get_connection() try: cursor = conn.cursor() # Upsert: insert default then update cursor.execute( "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)", (user_id,), ) set_clause = ", ".join(f"{k} = %s" for k in filtered) values = list(filtered.values()) + [user_id] cursor.execute( f"UPDATE user_profiles SET {set_clause} WHERE user_id = %s", values, ) conn.commit() finally: conn.close() def batch_reencrypt_messages(user_id: str, updates: list[dict]): """Batch upsert message_recipients rows with self-encryption key data. Each update: {message_id, encrypted_content (bytes), nonce (bytes)}. Sets ratchet_header to '{"self":true}' and clears x3dh_header. Uses INSERT ... ON DUPLICATE KEY UPDATE so it works for both sent messages (which already have a SELF_DEVICE_ID row) and received messages (which don't). """ if not updates: return conn = get_connection() try: cursor = conn.cursor() self_header = b'{"self":true}' for u in updates: cursor.execute( "INSERT INTO message_recipients " "(message_id, user_id, device_id, encrypted_content, nonce, ratchet_header, x3dh_header) " "VALUES (%s, %s, %s, %s, %s, %s, NULL) " "ON DUPLICATE KEY UPDATE encrypted_content = VALUES(encrypted_content), " "nonce = VALUES(nonce), ratchet_header = VALUES(ratchet_header), x3dh_header = NULL", (u["message_id"], user_id, SELF_DEVICE_ID, u["encrypted_content"], u["nonce"], self_header), ) conn.commit() finally: conn.close() # --- Phantom Users --- def create_phantom_user(email: str) -> dict: """Create a phantom user with valid crypto keys for X3DH. Phantom users have rsa_public_key = 'PHANTOM' as a marker. Returns user dict: {id, username, email, identity_key}. """ username = email.split("@")[0] user_id = generate_uuid() # Generate real crypto keys so X3DH works on the client side ik_private, ik_public = generate_identity_keypair() ik_public_bytes = serialize_ed25519_public(ik_public) spk = generate_signed_prekey(ik_private) spk_pub_bytes = serialize_x25519_public(spk["public"]) spk_sig = spk["signature"] opks = generate_one_time_prekeys(count=5) conn = get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT INTO users (id, username, email, rsa_public_key, identity_key) " "VALUES (%s, %s, %s, %s, %s)", (user_id, username, email, "PHANTOM", ik_public_bytes), ) cursor.execute( "INSERT INTO signed_prekeys (id, user_id, public_key, signature) VALUES (%s, %s, %s, %s)", (spk["id"], user_id, spk_pub_bytes, spk_sig), ) for opk in opks: cursor.execute( "INSERT INTO one_time_prekeys (id, user_id, public_key) VALUES (%s, %s, %s)", (opk["id"], user_id, serialize_x25519_public(opk["public"])), ) conn.commit() return {"id": user_id, "username": username, "email": email, "identity_key": ik_public_bytes} finally: conn.close() def is_phantom_user(user_id: str) -> bool: """Check if a user is a phantom (rsa_public_key == 'PHANTOM').""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("SELECT rsa_public_key FROM users WHERE id = %s", (user_id,)) row = cursor.fetchone() return row is not None and row[0] == "PHANTOM" finally: conn.close() def delete_phantom_user(user_id: str): """Delete a phantom user. CASCADE removes signed_prekeys, one_time_prekeys, conversation_members, message_recipients, etc.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM users WHERE id = %s AND rsa_public_key = %s", (user_id, "PHANTOM"), ) conn.commit() finally: conn.close() def upgrade_phantom_user(phantom_id: str, username: str, rsa_public_key_pem: str, identity_key: bytes) -> str | None: """Upgrade a phantom user to a real user in-place. Preserves user_id and all FK references (conversation_members, group_invitations, etc.). Deletes phantom's server-generated prekeys (real user will upload own on first login). Returns phantom_id as the new user_id, or None if phantom no longer exists. """ conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE users SET username = %s, rsa_public_key = %s, identity_key = %s " "WHERE id = %s AND rsa_public_key = 'PHANTOM'", (username, rsa_public_key_pem, identity_key, phantom_id), ) if cursor.rowcount == 0: conn.rollback() return None # Remove phantom's server-generated crypto keys — real user uploads own cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s", (phantom_id,)) cursor.execute("DELETE FROM one_time_prekeys WHERE user_id = %s", (phantom_id,)) conn.commit() return phantom_id finally: conn.close() def get_all_phantom_user_ids() -> set[str]: """Return set of all phantom user IDs (for server startup cache).""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("SELECT id FROM users WHERE rsa_public_key = %s", ("PHANTOM",)) return {row[0] for row in cursor.fetchall()} finally: conn.close() def cleanup_stale_phantoms(max_age_days: int = 30) -> int: """Delete phantom users older than max_age_days with no active conversations with real users.""" conn = get_connection() try: cursor = conn.cursor() # Two-step: SELECT ids first, then DELETE. # MySQL error 1093: can't DELETE from table referenced in subquery. cursor.execute(""" SELECT u.id FROM users u WHERE u.rsa_public_key = 'PHANTOM' AND u.created_at < DATE_SUB(NOW(), INTERVAL %s DAY) AND NOT EXISTS ( SELECT 1 FROM conversation_members cm1 JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id JOIN users u2 ON cm2.user_id = u2.id WHERE cm1.user_id = u.id AND u2.rsa_public_key != 'PHANTOM' ) """, (max_age_days,)) ids = [row[0] for row in cursor.fetchall()] if not ids: return 0 cursor.execute( "DELETE FROM users WHERE id IN (%s)" % ",".join(["%s"] * len(ids)), ids, ) deleted = cursor.rowcount conn.commit() return deleted finally: conn.close() def remove_conversation_member_atomic(conversation_id: str, user_id: str) -> bool: """Remove member and return True if actually removed (row existed). M6 TOCTOU fix.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s", (conversation_id, user_id), ) conn.commit() return cursor.rowcount > 0 finally: conn.close() def get_stale_uploads(max_age_seconds: int = 3600) -> list[dict]: conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT file_id FROM image_uploads " "WHERE completed = FALSE AND created_at < DATE_SUB(NOW(), INTERVAL %s SECOND)", (max_age_seconds,), ) return cursor.fetchall() finally: conn.close() # --------------------------------------------------------------------------- # Metadata retention cleanup # --------------------------------------------------------------------------- def cleanup_old_reads(days: int = 90, batch_size: int = 10000) -> int: """Delete message_reads older than N days in batches. Only deletes reads for messages whose created_at is also past the retention window. This prevents phantom unreads: get_unread_counts uses the same time window (max_age_days) so messages outside the window aren't counted. """ total = 0 while True: conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM message_reads " "WHERE read_at < DATE_SUB(NOW(), INTERVAL %s DAY) " "AND message_id IN (" " SELECT id FROM messages " " WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY)" ") LIMIT %s", (days, days, batch_size), ) count = cursor.rowcount conn.commit() total += count if count < batch_size: break finally: conn.close() return total def cleanup_old_reactions(days: int = 90, batch_size: int = 10000) -> int: """Delete message_reactions older than N days in batches.""" total = 0 while True: conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM message_reactions WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY) LIMIT %s", (days, batch_size), ) count = cursor.rowcount conn.commit() total += count if count < batch_size: break finally: conn.close() return total