"""MySQL database layer for the encrypted chat server.""" import os import uuid import logging import mysql.connector from mysql.connector import pooling from dotenv import load_dotenv from crypto_utils import ( generate_identity_keypair, serialize_ed25519_public, generate_signed_prekey, serialize_x25519_public, generate_one_time_prekeys, ) load_dotenv() # Sentinel device_id for self-encrypted copies and legacy (pre-multi-device) rows SELF_DEVICE_ID = "00000000-0000-0000-0000-000000000000" _logger = logging.getLogger(__name__) _pool = None def _get_pool(): """Get or create the connection pool (lazy init).""" global _pool if _pool is None: pool_size = int(os.getenv("DB_POOL_SIZE", "10")) pool_kwargs = dict( pool_name="chat_pool", pool_size=pool_size, pool_reset_session=True, host=os.getenv("MYSQL_HOST", "localhost"), port=int(os.getenv("MYSQL_PORT", "3306")), user=os.getenv("MYSQL_USER", "root"), password=os.getenv("MYSQL_PASSWORD", ""), database=os.getenv("MYSQL_DATABASE", "encrypted_chat"), ) # Optional MySQL TLS (M7): set MYSQL_SSL_CA (and optionally MYSQL_SSL_CERT/KEY) ssl_ca = os.getenv("MYSQL_SSL_CA", "").strip() ssl_cert = os.getenv("MYSQL_SSL_CERT", "").strip() ssl_key = os.getenv("MYSQL_SSL_KEY", "").strip() if ssl_ca: pool_kwargs["ssl_ca"] = ssl_ca if ssl_cert: pool_kwargs["ssl_cert"] = ssl_cert if ssl_key: pool_kwargs["ssl_key"] = ssl_key _logger.info("MySQL TLS enabled (CA: %s)", ssl_ca) _pool = pooling.MySQLConnectionPool(**pool_kwargs) _logger.info("DB connection pool created (size=%d)", pool_size) return _pool def get_connection(): """Get a connection from the pool.""" return _get_pool().get_connection() def generate_uuid() -> str: return str(uuid.uuid4()) # --- Devices --- def create_device(user_id: str, device_name: str | None = None) -> str: """Create a new device for a user. Returns device_id.""" conn = get_connection() try: cursor = conn.cursor() device_id = generate_uuid() cursor.execute( "INSERT INTO devices (id, user_id, device_name) VALUES (%s, %s, %s)", (device_id, user_id, device_name), ) conn.commit() return device_id finally: conn.close() def get_user_devices(user_id: str) -> list[dict]: """Get all devices for a user.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT id, user_id, device_name, created_at, last_seen_at " "FROM devices WHERE user_id = %s ORDER BY created_at", (user_id,), ) return cursor.fetchall() finally: conn.close() def get_device(device_id: str) -> dict | None: """Get a single device by ID.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT id, user_id, device_name, created_at, last_seen_at " "FROM devices WHERE id = %s", (device_id,), ) return cursor.fetchone() finally: conn.close() def update_device_last_seen(device_id: str): """Update last_seen_at timestamp for a device.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE devices SET last_seen_at = NOW() WHERE id = %s", (device_id,), ) conn.commit() finally: conn.close() def delete_device(device_id: str): """Delete a device. CASCADE removes its prekeys.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("DELETE FROM devices WHERE id = %s", (device_id,)) # Also clean up prekeys explicitly for device_id column cursor.execute("DELETE FROM signed_prekeys WHERE device_id = %s", (device_id,)) cursor.execute("DELETE FROM one_time_prekeys WHERE device_id = %s", (device_id,)) conn.commit() finally: conn.close() # --- Users --- def create_user(username: str, email: str, rsa_public_key_pem: str, identity_key: bytes) -> str: """Register a new user. Returns user ID.""" conn = get_connection() try: cursor = conn.cursor() user_id = generate_uuid() cursor.execute( "INSERT INTO users (id, username, email, rsa_public_key, identity_key) " "VALUES (%s, %s, %s, %s, %s)", (user_id, username, email, rsa_public_key_pem, identity_key), ) conn.commit() return user_id finally: conn.close() def get_user_by_email(email: str) -> dict | None: """Get user by email.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE email = %s", (email,), ) return cursor.fetchone() finally: conn.close() def get_user_by_id(user_id: str) -> dict | None: """Get user by ID.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE id = %s", (user_id,), ) return cursor.fetchone() finally: conn.close() def shares_conversation(user_id_a: str, user_id_b: str) -> bool: """Check if two users share at least one conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT 1 FROM conversation_members cm1 " "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " "WHERE cm1.user_id = %s AND cm2.user_id = %s LIMIT 1", (user_id_a, user_id_b), ) return cursor.fetchone() is not None finally: conn.close() def get_user_contacts(user_id: str) -> list[str]: """Get all user IDs that share at least one conversation with the given user.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT DISTINCT cm2.user_id " "FROM conversation_members cm1 " "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " "WHERE cm1.user_id = %s AND cm2.user_id != %s", (user_id, user_id), ) return [row[0] for row in cursor.fetchall()] finally: conn.close() def update_user_rsa_key(user_id: str, rsa_public_key_pem: str): """Update user's RSA public key (for login).""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("UPDATE users SET rsa_public_key = %s WHERE id = %s", (rsa_public_key_pem, user_id)) conn.commit() finally: conn.close() def update_username(user_id: str, new_username: str): """Update user's display name.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("UPDATE users SET username = %s WHERE id = %s", (new_username, user_id)) conn.commit() finally: conn.close() # --- Pre-keys --- def store_signed_prekey(user_id: str, spk_id: str, public_key: bytes, signature: bytes, device_id: str | None = None): """Store (or replace) a signed pre-key for a user's device.""" conn = get_connection() try: cursor = conn.cursor() # Remove old SPKs for this user+device if device_id: cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id = %s", (user_id, device_id)) else: cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id IS NULL", (user_id,)) cursor.execute( "INSERT INTO signed_prekeys (id, user_id, device_id, public_key, signature) " "VALUES (%s, %s, %s, %s, %s)", (spk_id, user_id, device_id, public_key, signature), ) conn.commit() finally: conn.close() def get_signed_prekey(user_id: str, device_id: str | None = None) -> dict | None: """Get the current signed pre-key for a user (optionally per device).""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) if device_id: cursor.execute( "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys " "WHERE user_id = %s AND device_id = %s " "ORDER BY created_at DESC LIMIT 1", (user_id, device_id), ) else: cursor.execute( "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys " "WHERE user_id = %s ORDER BY created_at DESC LIMIT 1", (user_id,), ) return cursor.fetchone() finally: conn.close() def store_one_time_prekeys(user_id: str, prekeys: list[dict], device_id: str | None = None): """Store a batch of one-time pre-keys. Each dict has {id, public_key (bytes)}.""" conn = get_connection() try: cursor = conn.cursor() for pk in prekeys: cursor.execute( "INSERT INTO one_time_prekeys (id, user_id, device_id, public_key) " "VALUES (%s, %s, %s, %s)", (pk["id"], user_id, device_id, pk["public_key"]), ) conn.commit() finally: conn.close() def consume_one_time_prekey(user_id: str, device_id: str | None = None) -> dict | None: """Atomically consume one OTP: SELECT FOR UPDATE + DELETE. Returns {id, public_key} or None.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) conn.start_transaction() if device_id: cursor.execute( "SELECT id, public_key FROM one_time_prekeys " "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE", (user_id, device_id), ) else: cursor.execute( "SELECT id, public_key FROM one_time_prekeys " "WHERE user_id = %s LIMIT 1 FOR UPDATE", (user_id,), ) row = cursor.fetchone() if row: cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (row["id"],)) conn.commit() return row except Exception: conn.rollback() raise finally: conn.close() def count_one_time_prekeys(user_id: str, device_id: str | None = None) -> int: """Count remaining OTPs for a user (optionally per device).""" conn = get_connection() try: cursor = conn.cursor() if device_id: cursor.execute( "SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s AND device_id = %s", (user_id, device_id), ) else: cursor.execute("SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s", (user_id,)) return cursor.fetchone()[0] finally: conn.close() def get_key_bundle(user_id: str) -> dict | None: """Get complete key bundle for X3DH (single device — legacy compat). Returns {identity_key, signed_prekey_id, signed_prekey, spk_signature, one_time_prekey_id, one_time_prekey} or None. OTP is consumed atomically. """ conn = get_connection() try: cursor = conn.cursor(dictionary=True) # Get user identity key cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,)) user = cursor.fetchone() if not user: return None # Get signed prekey cursor.execute( "SELECT id, public_key, signature, device_id FROM signed_prekeys WHERE user_id = %s " "ORDER BY created_at DESC LIMIT 1", (user_id,), ) spk = cursor.fetchone() if not spk: return None # Consume one OTP (may be None) — use transaction for atomicity (H12 fix) conn.start_transaction() cursor.execute( "SELECT id, public_key FROM one_time_prekeys WHERE user_id = %s LIMIT 1 FOR UPDATE", (user_id,), ) opk = cursor.fetchone() if opk: cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],)) conn.commit() result = { "identity_key": user["identity_key"], "signed_prekey_id": spk["id"], "signed_prekey": spk["public_key"], "spk_signature": spk["signature"], } if opk: result["one_time_prekey_id"] = opk["id"] result["one_time_prekey"] = opk["public_key"] return result except Exception: try: conn.rollback() except Exception: pass raise finally: conn.close() def get_key_bundles_for_user(user_id: str) -> dict | None: """Get key bundles for ALL devices of a user. Returns {identity_key, device_bundles: [{device_id, signed_prekey_id, signed_prekey_pub, spk_signature, opk_id, opk_pub}]} or None. Consumes one OPK per device atomically. """ conn = get_connection() try: cursor = conn.cursor(dictionary=True) # Get user identity key cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,)) user = cursor.fetchone() if not user: return None # Get all signed prekeys (one per device, most recent) cursor.execute( "SELECT id, public_key, signature, device_id FROM signed_prekeys " "WHERE user_id = %s ORDER BY created_at DESC", (user_id,), ) all_spks = cursor.fetchall() if not all_spks: return None # De-duplicate: keep only the most recent SPK per device_id seen_devices = set() spks_by_device = [] for spk in all_spks: dev = spk.get("device_id") or "__legacy__" if dev not in seen_devices: seen_devices.add(dev) spks_by_device.append(spk) device_bundles = [] # Commit the implicit transaction from the read-only queries above # so we can start an explicit transaction for atomic OPK consumption. conn.commit() conn.start_transaction() for spk in spks_by_device: dev_id = spk.get("device_id") # Consume one OPK for this device if dev_id: cursor.execute( "SELECT id, public_key FROM one_time_prekeys " "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE", (user_id, dev_id), ) else: cursor.execute( "SELECT id, public_key FROM one_time_prekeys " "WHERE user_id = %s AND device_id IS NULL LIMIT 1 FOR UPDATE", (user_id,), ) opk = cursor.fetchone() if opk: cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],)) bundle = { "device_id": dev_id, "signed_prekey_id": spk["id"], "signed_prekey_pub": spk["public_key"], "spk_signature": spk["signature"], } if opk: bundle["opk_id"] = opk["id"] bundle["opk_pub"] = opk["public_key"] device_bundles.append(bundle) conn.commit() return { "identity_key": user["identity_key"], "device_bundles": device_bundles, } except Exception: try: conn.rollback() except Exception: pass raise finally: conn.close() # --- Conversations --- def create_conversation(member_user_ids: list[str], joined_at=None, name=None, created_by=None) -> str: conn = get_connection() try: cursor = conn.cursor() conv_id = generate_uuid() cursor.execute("INSERT INTO conversations (id, name, created_by) VALUES (%s, %s, %s)", (conv_id, name, created_by)) for uid in member_user_ids: cursor.execute( "INSERT INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)", (conv_id, uid, joined_at), ) conn.commit() return conv_id finally: conn.close() def add_conversation_member(conversation_id: str, user_id: str, joined_at=None): conn = get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT IGNORE INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)", (conversation_id, user_id, joined_at), ) conn.commit() finally: conn.close() def remove_conversation_member(conversation_id: str, user_id: str): conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s", (conversation_id, user_id), ) conn.commit() finally: conn.close() def count_conversation_members(conversation_id: str) -> int: """Count members in a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT COUNT(*) FROM conversation_members WHERE conversation_id = %s", (conversation_id,), ) return cursor.fetchone()[0] finally: conn.close() def get_conversation_file_ids(conversation_id: str) -> list[str]: """Get all file IDs (images + files) uploaded to a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT file_id FROM image_uploads WHERE conversation_id = %s", (conversation_id,), ) return [row[0] for row in cursor.fetchall()] finally: conn.close() def delete_conversation(conversation_id: str): """Delete a conversation entirely. CASCADE cleans up members, messages, etc.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("DELETE FROM conversations WHERE id = %s", (conversation_id,)) conn.commit() finally: conn.close() def get_conversation_members(conversation_id: str) -> list[dict]: conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT u.id, u.username, u.email, u.identity_key FROM conversation_members cm " "JOIN users u ON cm.user_id = u.id " "WHERE cm.conversation_id = %s", (conversation_id,), ) return cursor.fetchall() finally: conn.close() def find_direct_conversation(user_id_a: str, user_id_b: str) -> str | None: conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT cm1.conversation_id FROM conversation_members cm1 " "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " "WHERE cm1.user_id = %s AND cm2.user_id = %s " "AND (SELECT COUNT(*) FROM conversation_members cm3 " " WHERE cm3.conversation_id = cm1.conversation_id) = 2 " "LIMIT 1", (user_id_a, user_id_b), ) row = cursor.fetchone() return row[0] if row else None finally: conn.close() def update_conversation_creator(conversation_id: str, new_creator_id: str): """Transfer group creator role to another member.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE conversations SET created_by = %s WHERE id = %s", (new_creator_id, conversation_id), ) conn.commit() finally: conn.close() def get_conversation(conversation_id: str) -> dict | None: """Get conversation by ID.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT id, created_at, name, created_by, avatar_file FROM conversations WHERE id = %s", (conversation_id,), ) return cursor.fetchone() finally: conn.close() def update_conversation_avatar(conversation_id: str, avatar_file: str): """Set avatar file for a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE conversations SET avatar_file = %s WHERE id = %s", (avatar_file, conversation_id), ) conn.commit() finally: conn.close() def update_conversation_name(conversation_id: str, name: str): """Update the name of a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE conversations SET name = %s WHERE id = %s", (name, conversation_id), ) conn.commit() finally: conn.close() def is_conversation_member(conversation_id: str, user_id: str) -> bool: conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s", (conversation_id, user_id), ) return cursor.fetchone() is not None finally: conn.close() def list_user_conversations(user_id: str) -> list[dict]: conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT c.id, c.created_at, c.name, c.created_by, c.avatar_file FROM conversations c " "JOIN conversation_members cm ON c.id = cm.conversation_id " "WHERE cm.user_id = %s ORDER BY c.created_at DESC", (user_id,), ) convs = cursor.fetchall() if not convs: return convs # Batch-fetch all members for all conversations in one query (N+1 fix) conv_ids = [c["id"] for c in convs] placeholders = ",".join(["%s"] * len(conv_ids)) cursor.execute( f"SELECT cm.conversation_id, u.id AS user_id, u.username, u.email " f"FROM conversation_members cm JOIN users u ON cm.user_id = u.id " f"WHERE cm.conversation_id IN ({placeholders})", conv_ids, ) members_by_conv: dict[str, list[dict]] = {} for row in cursor.fetchall(): cid = row.pop("conversation_id") members_by_conv.setdefault(cid, []).append(row) for conv in convs: conv["members"] = members_by_conv.get(conv["id"], []) return convs finally: conn.close() # --- Group Invitations --- def create_invitation(conversation_id: str, user_id: str, invited_by: str): """Create a pending group invitation.""" conn = get_connection() try: cursor = conn.cursor() inv_id = generate_uuid() cursor.execute( "INSERT IGNORE INTO group_invitations (id, conversation_id, user_id, invited_by) " "VALUES (%s, %s, %s, %s)", (inv_id, conversation_id, user_id, invited_by), ) conn.commit() finally: conn.close() def get_pending_invitations(user_id: str) -> list[dict]: """Get all pending invitations for a user, joined with conversation and inviter info.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT gi.id, gi.conversation_id, gi.invited_by, gi.created_at, " "c.name AS conversation_name, u.username AS invited_by_username " "FROM group_invitations gi " "JOIN conversations c ON gi.conversation_id = c.id " "JOIN users u ON gi.invited_by = u.id " "WHERE gi.user_id = %s " "ORDER BY gi.created_at DESC", (user_id,), ) return cursor.fetchall() finally: conn.close() def delete_invitation(conversation_id: str, user_id: str): """Delete a pending invitation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM group_invitations WHERE conversation_id = %s AND user_id = %s", (conversation_id, user_id), ) conn.commit() finally: conn.close() def has_pending_invitation(conversation_id: str, user_id: str) -> bool: """Check if a user has a pending invitation for a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT 1 FROM group_invitations WHERE conversation_id = %s AND user_id = %s", (conversation_id, user_id), ) return cursor.fetchone() is not None finally: conn.close() # --- Messages --- def store_message( conversation_id: str, sender_id: str, ratchet_header: bytes, recipients: list[dict], x3dh_header: bytes | None = None, sender_chain_id: bytes | None = None, sender_chain_n: int | None = None, image_file_id: str | None = None, sender_device_id: str | None = None, ) -> str: """Store an encrypted message with per-recipient ciphertext. recipients: [{user_id, encrypted_content (bytes), nonce (bytes), device_id (str, optional), ratchet_header (bytes, optional), x3dh_header (bytes, optional)}] """ conn = get_connection() try: cursor = conn.cursor() msg_id = generate_uuid() cursor.execute( "INSERT INTO messages (id, conversation_id, sender_id, sender_device_id, " "ratchet_header, x3dh_header, sender_chain_id, sender_chain_n, image_file_id) " "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)", (msg_id, conversation_id, sender_id, sender_device_id, ratchet_header, x3dh_header, sender_chain_id, sender_chain_n, image_file_id), ) for r in recipients: device_id = r.get("device_id", SELF_DEVICE_ID) cursor.execute( "INSERT INTO message_recipients (message_id, user_id, device_id, " "encrypted_content, nonce, ratchet_header, x3dh_header) " "VALUES (%s, %s, %s, %s, %s, %s, %s)", (msg_id, r["user_id"], device_id, r["encrypted_content"], r["nonce"], r.get("ratchet_header"), r.get("x3dh_header")), ) conn.commit() cursor.execute("SELECT created_at FROM messages WHERE id = %s", (msg_id,)) row = cursor.fetchone() created_at = row[0].isoformat() if row else None return msg_id, created_at finally: conn.close() def get_messages(conversation_id: str, user_id: str, limit: int = 50, offset: int = 0, device_id: str | None = None, after_ts: str | None = None) -> list[dict]: """Get messages for a user in a conversation, JOINing their per-recipient ciphertext. If device_id is set, returns rows where mr.device_id matches OR is the sentinel (self-encrypted / legacy). May return duplicate message IDs when both device-specific and self-encrypted rows exist — caller should deduplicate (prefer device-specific). If after_ts is set, only returns messages created after that timestamp (ISO format). Results are ordered ASC when after_ts is used, DESC otherwise. """ conn = get_connection() try: cursor = conn.cursor(dictionary=True) if device_id: where_parts = ["m.conversation_id = %s", "(cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"] params = [user_id, device_id, SELF_DEVICE_ID, user_id, conversation_id] if after_ts: where_parts.append("m.created_at > %s") params.append(after_ts) where_clause = " AND ".join(where_parts) order = "ASC" if after_ts else "DESC" cursor.execute( "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, " "m.ratchet_header, m.x3dh_header, " "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, " "m.pinned_at, m.pinned_by, " "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, " "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header " "FROM messages m " "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " " AND (mr.device_id = %s OR mr.device_id = %s) " "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " f"WHERE {where_clause} " f"ORDER BY m.created_at {order} LIMIT %s OFFSET %s", (*params, limit, offset), ) else: where_parts = ["m.conversation_id = %s", "(cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"] params = [user_id, user_id, conversation_id] if after_ts: where_parts.append("m.created_at > %s") params.append(after_ts) where_clause = " AND ".join(where_parts) order = "ASC" if after_ts else "DESC" cursor.execute( "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, " "m.ratchet_header, m.x3dh_header, " "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, " "m.pinned_at, m.pinned_by, " "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, " "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header " "FROM messages m " "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " f"WHERE {where_clause} " f"ORDER BY m.created_at {order} LIMIT %s OFFSET %s", (*params, limit, offset), ) return cursor.fetchall() finally: conn.close() def count_messages(conversation_id: str, user_id: str) -> int: """Count total messages visible to a user in a conversation.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT COUNT(DISTINCT m.id) " "FROM messages m " "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " "WHERE m.conversation_id = %s AND (cm.joined_at IS NULL OR m.created_at >= cm.joined_at)", (user_id, user_id, conversation_id), ) row = cursor.fetchone() return row[0] if row else 0 finally: conn.close() def get_message_conversation(message_id: str) -> str | None: conn = get_connection() try: cursor = conn.cursor() cursor.execute("SELECT conversation_id FROM messages WHERE id = %s", (message_id,)) row = cursor.fetchone() return row[0] if row else None finally: conn.close() def get_message_sender(message_id: str) -> str | None: conn = get_connection() try: cursor = conn.cursor() cursor.execute("SELECT sender_id FROM messages WHERE id = %s", (message_id,)) row = cursor.fetchone() return row[0] if row else None finally: conn.close() def get_deleted_messages_since(conversation_id: str, user_id: str, since_ts: str) -> list[str]: """Return message IDs that were soft-deleted since the given timestamp.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT m.id FROM messages m " "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " "WHERE m.conversation_id = %s AND m.deleted_at IS NOT NULL AND m.deleted_at > %s", (user_id, conversation_id, since_ts), ) return [row[0] for row in cursor.fetchall()] finally: conn.close() # --- Reactions --- ALLOWED_REACTIONS = {"thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"} def add_reaction(message_id: str, user_id: str, reaction: str) -> tuple[bool, str | None]: """Add or replace a reaction. Returns (changed, old_reaction_or_None).""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT reaction FROM message_reactions WHERE message_id = %s AND user_id = %s", (message_id, user_id), ) row = cursor.fetchone() old_reaction = row["reaction"] if row else None if old_reaction == reaction: return False, None # already same reaction if old_reaction: cursor.execute( "UPDATE message_reactions SET reaction = %s, created_at = CURRENT_TIMESTAMP " "WHERE message_id = %s AND user_id = %s", (reaction, message_id, user_id), ) else: cursor.execute( "INSERT INTO message_reactions (id, message_id, user_id, reaction) " "VALUES (%s, %s, %s, %s)", (generate_uuid(), message_id, user_id, reaction), ) conn.commit() return True, old_reaction finally: conn.close() def remove_reaction(message_id: str, user_id: str) -> bool: """Remove a user's reaction. Returns True if deleted.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM message_reactions WHERE message_id = %s AND user_id = %s", (message_id, user_id), ) conn.commit() return cursor.rowcount > 0 finally: conn.close() def get_reactions(message_ids: list[str]) -> dict[str, list[dict]]: """Get reactions for multiple messages. Returns {msg_id: [{user_id, reaction, created_at}]}.""" if not message_ids: return {} conn = get_connection() try: cursor = conn.cursor(dictionary=True) placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"SELECT message_id, user_id, reaction, created_at " f"FROM message_reactions WHERE message_id IN ({placeholders}) " f"ORDER BY created_at", tuple(message_ids), ) result = {} for row in cursor.fetchall(): mid = row["message_id"] if mid not in result: result[mid] = [] result[mid].append({ "user_id": row["user_id"], "reaction": row["reaction"], "created_at": row["created_at"].isoformat() if hasattr(row["created_at"], "isoformat") else str(row["created_at"]), }) return result finally: conn.close() # --- Pins --- def pin_message(message_id: str, user_id: str, conversation_id: str) -> bool: """Pin a message. Returns True on success.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE messages SET pinned_at = NOW(), pinned_by = %s " "WHERE id = %s AND conversation_id = %s AND pinned_at IS NULL", (user_id, message_id, conversation_id), ) conn.commit() return cursor.rowcount > 0 finally: conn.close() def unpin_message(message_id: str, conversation_id: str) -> bool: """Unpin a message. Returns True on success.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE messages SET pinned_at = NULL, pinned_by = NULL " "WHERE id = %s AND conversation_id = %s AND pinned_at IS NOT NULL", (message_id, conversation_id), ) conn.commit() return cursor.rowcount > 0 finally: conn.close() def get_pinned_messages(conversation_id: str, user_id: str) -> list[dict]: """Get pinned messages for a conversation (membership verified via JOIN).""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT m.id AS message_id, m.sender_id, m.pinned_at, m.pinned_by, m.created_at " "FROM messages m " "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " "WHERE m.conversation_id = %s AND m.pinned_at IS NOT NULL AND m.deleted_at IS NULL " "ORDER BY m.pinned_at DESC", (user_id, conversation_id), ) rows = cursor.fetchall() for r in rows: for k in ("pinned_at", "created_at"): if r.get(k) and hasattr(r[k], "isoformat"): r[k] = r[k].isoformat() return rows finally: conn.close() # --- Group Sender Keys --- def store_sender_key(conversation_id: str, sender_id: str, chain_id: bytes, device_id: str | None = None): """Store or update a sender key chain ID for a group member's device.""" conn = get_connection() try: cursor = conn.cursor() dev = device_id or SELF_DEVICE_ID cursor.execute( "REPLACE INTO group_sender_keys (conversation_id, sender_id, device_id, chain_id) " "VALUES (%s, %s, %s, %s)", (conversation_id, sender_id, dev, chain_id), ) conn.commit() finally: conn.close() def get_sender_key(conversation_id: str, sender_id: str, device_id: str | None = None) -> dict | None: conn = get_connection() try: cursor = conn.cursor(dictionary=True) dev = device_id or SELF_DEVICE_ID cursor.execute( "SELECT chain_id, created_at FROM group_sender_keys " "WHERE conversation_id = %s AND sender_id = %s AND device_id = %s", (conversation_id, sender_id, dev), ) return cursor.fetchone() finally: conn.close() # --- Read Receipts --- def filter_message_ids_by_conversation(conversation_id: str, message_ids: list[str]) -> list[str]: """Return only message_ids that belong to the given conversation.""" if not message_ids: return [] conn = get_connection() try: cursor = conn.cursor() placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"SELECT id FROM messages WHERE id IN ({placeholders}) AND conversation_id = %s", (*message_ids, conversation_id), ) return [row[0] for row in cursor.fetchall()] finally: conn.close() def mark_messages_read(conversation_id: str, user_id: str, message_ids: list[str]): if not message_ids: return conn = get_connection() try: cursor = conn.cursor() # M1 fix: JOIN messages to verify message_ids belong to conversation_id placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"INSERT IGNORE INTO message_reads (message_id, user_id) " f"SELECT m.id, %s FROM messages m " f"WHERE m.id IN ({placeholders}) AND m.conversation_id = %s", (user_id, *message_ids, conversation_id), ) conn.commit() finally: conn.close() def mark_conversation_read(conversation_id: str, user_id: str) -> int: """Mark ALL unread messages in a conversation as read for user. Returns count marked.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT IGNORE INTO message_reads (message_id, user_id) " "SELECT m.id, %s " "FROM messages m " "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s " "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s " "WHERE m.conversation_id = %s AND m.sender_id != %s " "AND m.deleted_at IS NULL AND mrd.message_id IS NULL", (user_id, user_id, user_id, conversation_id, user_id), ) count = cursor.rowcount conn.commit() return count finally: conn.close() def get_unread_counts(user_id: str, max_age_days: int = 0) -> dict[str, int]: """Return {conversation_id: unread_count} for all conversations the user is in. max_age_days: if > 0, only count messages younger than this many days. Must match METADATA_RETENTION_DAYS to avoid phantom unreads after read cleanup. """ conn = get_connection() try: cursor = conn.cursor(dictionary=True) age_filter = "" params = [user_id, user_id, user_id] if max_age_days > 0: age_filter = " AND m.created_at >= DATE_SUB(NOW(), INTERVAL %s DAY)" params.append(max_age_days) cursor.execute( "SELECT m.conversation_id, COUNT(DISTINCT m.id) AS cnt " "FROM messages m " "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s " "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s " "WHERE m.sender_id != %s AND m.deleted_at IS NULL AND mrd.message_id IS NULL" f"{age_filter} " "GROUP BY m.conversation_id", params, ) return {row["conversation_id"]: row["cnt"] for row in cursor.fetchall()} finally: conn.close() def get_message_read_status(message_ids: list[str]) -> dict: if not message_ids: return {} conn = get_connection() try: cursor = conn.cursor(dictionary=True) placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"SELECT mr.message_id, mr.user_id, mr.read_at " f"FROM message_reads mr " f"WHERE mr.message_id IN ({placeholders})", tuple(message_ids), ) result = {} for row in cursor.fetchall(): mid = row["message_id"] if mid not in result: result[mid] = [] result[mid].append({ "user_id": row["user_id"], "read_at": row["read_at"].isoformat() if hasattr(row["read_at"], "isoformat") else str(row["read_at"]), }) return result finally: conn.close() # --- Delivery Receipts --- def mark_messages_delivered(conversation_id: str, user_id: str, message_ids: list[str]): """Batch insert delivery receipts (INSERT IGNORE — idempotent).""" if not message_ids: return conn = get_connection() try: cursor = conn.cursor() # M1 fix: JOIN messages to verify message_ids belong to conversation_id placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"INSERT IGNORE INTO message_deliveries (message_id, user_id) " f"SELECT m.id, %s FROM messages m " f"WHERE m.id IN ({placeholders}) AND m.conversation_id = %s", (user_id, *message_ids, conversation_id), ) conn.commit() finally: conn.close() def get_message_delivery_status(message_ids: list[str]) -> dict: """Get delivery status for messages. Returns {msg_id: [{user_id, delivered_at}]}.""" if not message_ids: return {} conn = get_connection() try: cursor = conn.cursor(dictionary=True) placeholders = ",".join(["%s"] * len(message_ids)) cursor.execute( f"SELECT md.message_id, md.user_id, md.delivered_at " f"FROM message_deliveries md " f"WHERE md.message_id IN ({placeholders})", tuple(message_ids), ) result = {} for row in cursor.fetchall(): mid = row["message_id"] if mid not in result: result[mid] = [] result[mid].append({ "user_id": row["user_id"], "delivered_at": row["delivered_at"].isoformat() if hasattr(row["delivered_at"], "isoformat") else str(row["delivered_at"]), }) return result finally: conn.close() # --- Delete --- def soft_delete_message(message_id: str, sender_id: str) -> dict | None: """Soft-delete a message if sender matches. Returns {'image_file_id': ...} or None.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT sender_id, image_file_id FROM messages WHERE id = %s AND deleted_at IS NULL", (message_id,), ) row = cursor.fetchone() if not row or row["sender_id"] != sender_id: return None cursor.execute( "UPDATE messages SET deleted_at = NOW() WHERE id = %s", (message_id,), ) # Clear per-recipient ciphertext cursor.execute( "UPDATE message_recipients SET encrypted_content = %s WHERE message_id = %s", (b"", message_id), ) conn.commit() return {"image_file_id": row.get("image_file_id")} finally: conn.close() def set_message_image_file_id(message_id: str, file_id: str): conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE messages SET image_file_id = %s WHERE id = %s", (file_id, message_id), ) conn.commit() finally: conn.close() # --- Image Uploads --- def create_image_upload(file_id: str, conversation_id: str, uploader_id: str, file_size: int): conn = get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT INTO image_uploads (file_id, conversation_id, uploader_id, file_size) " "VALUES (%s, %s, %s, %s)", (file_id, conversation_id, uploader_id, file_size), ) conn.commit() finally: conn.close() def complete_image_upload(file_id: str): conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE image_uploads SET completed = TRUE WHERE file_id = %s", (file_id,), ) conn.commit() finally: conn.close() def get_image_upload(file_id: str) -> dict | None: conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT file_id, conversation_id, uploader_id, file_size, completed, created_at " "FROM image_uploads WHERE file_id = %s", (file_id,), ) return cursor.fetchone() finally: conn.close() def delete_image_upload(file_id: str): conn = get_connection() try: cursor = conn.cursor() cursor.execute("DELETE FROM image_uploads WHERE file_id = %s", (file_id,)) conn.commit() finally: conn.close() # --- User Profiles --- def create_default_profile(user_id: str): """Create a default profile for a new user.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)", (user_id,), ) conn.commit() finally: conn.close() def get_user_profile(user_id: str, viewer_id: str | None = None) -> dict | None: """Get user profile joined with user info. Respects visibility if viewer is different user.""" conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT u.id AS user_id, u.username, u.email, u.created_at, " "p.phone, p.phone_visible, p.email_visible, p.location, " "p.location_visible, p.avatar_file, p.updated_at " "FROM users u LEFT JOIN user_profiles p ON u.id = p.user_id " "WHERE u.id = %s", (user_id,), ) row = cursor.fetchone() if not row: return None # If viewing someone else's profile, apply visibility rules if viewer_id and viewer_id != user_id: if not row.get("email_visible"): row["email"] = None if not row.get("phone_visible"): row["phone"] = None if not row.get("location_visible"): row["location"] = None return row finally: conn.close() def update_user_profile(user_id: str, **fields): """Upsert user profile fields. Allowed: phone, phone_visible, email_visible, location, location_visible, avatar_file.""" allowed = {"phone", "phone_visible", "email_visible", "location", "location_visible", "avatar_file"} filtered = {k: v for k, v in fields.items() if k in allowed} if not filtered: return conn = get_connection() try: cursor = conn.cursor() # Upsert: insert default then update cursor.execute( "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)", (user_id,), ) set_clause = ", ".join(f"{k} = %s" for k in filtered) values = list(filtered.values()) + [user_id] cursor.execute( f"UPDATE user_profiles SET {set_clause} WHERE user_id = %s", values, ) conn.commit() finally: conn.close() def batch_reencrypt_messages(user_id: str, updates: list[dict]): """Batch upsert message_recipients rows with self-encryption key data. Each update: {message_id, encrypted_content (bytes), nonce (bytes)}. Sets ratchet_header to '{"self":true}' and clears x3dh_header. Uses INSERT ... ON DUPLICATE KEY UPDATE so it works for both sent messages (which already have a SELF_DEVICE_ID row) and received messages (which don't). """ if not updates: return conn = get_connection() try: cursor = conn.cursor() self_header = b'{"self":true}' for u in updates: cursor.execute( "INSERT INTO message_recipients " "(message_id, user_id, device_id, encrypted_content, nonce, ratchet_header, x3dh_header) " "VALUES (%s, %s, %s, %s, %s, %s, NULL) " "ON DUPLICATE KEY UPDATE encrypted_content = VALUES(encrypted_content), " "nonce = VALUES(nonce), ratchet_header = VALUES(ratchet_header), x3dh_header = NULL", (u["message_id"], user_id, SELF_DEVICE_ID, u["encrypted_content"], u["nonce"], self_header), ) conn.commit() finally: conn.close() # --- Phantom Users --- def create_phantom_user(email: str) -> dict: """Create a phantom user with valid crypto keys for X3DH. Phantom users have rsa_public_key = 'PHANTOM' as a marker. Returns user dict: {id, username, email, identity_key}. """ username = email.split("@")[0] user_id = generate_uuid() # Generate real crypto keys so X3DH works on the client side ik_private, ik_public = generate_identity_keypair() ik_public_bytes = serialize_ed25519_public(ik_public) spk = generate_signed_prekey(ik_private) spk_pub_bytes = serialize_x25519_public(spk["public"]) spk_sig = spk["signature"] opks = generate_one_time_prekeys(count=5) conn = get_connection() try: cursor = conn.cursor() cursor.execute( "INSERT INTO users (id, username, email, rsa_public_key, identity_key) " "VALUES (%s, %s, %s, %s, %s)", (user_id, username, email, "PHANTOM", ik_public_bytes), ) cursor.execute( "INSERT INTO signed_prekeys (id, user_id, public_key, signature) VALUES (%s, %s, %s, %s)", (spk["id"], user_id, spk_pub_bytes, spk_sig), ) for opk in opks: cursor.execute( "INSERT INTO one_time_prekeys (id, user_id, public_key) VALUES (%s, %s, %s)", (opk["id"], user_id, serialize_x25519_public(opk["public"])), ) conn.commit() return {"id": user_id, "username": username, "email": email, "identity_key": ik_public_bytes} finally: conn.close() def is_phantom_user(user_id: str) -> bool: """Check if a user is a phantom (rsa_public_key == 'PHANTOM').""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("SELECT rsa_public_key FROM users WHERE id = %s", (user_id,)) row = cursor.fetchone() return row is not None and row[0] == "PHANTOM" finally: conn.close() def delete_phantom_user(user_id: str): """Delete a phantom user. CASCADE removes signed_prekeys, one_time_prekeys, conversation_members, message_recipients, etc.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM users WHERE id = %s AND rsa_public_key = %s", (user_id, "PHANTOM"), ) conn.commit() finally: conn.close() def upgrade_phantom_user(phantom_id: str, username: str, rsa_public_key_pem: str, identity_key: bytes) -> str | None: """Upgrade a phantom user to a real user in-place. Preserves user_id and all FK references (conversation_members, group_invitations, etc.). Deletes phantom's server-generated prekeys (real user will upload own on first login). Returns phantom_id as the new user_id, or None if phantom no longer exists. """ conn = get_connection() try: cursor = conn.cursor() cursor.execute( "UPDATE users SET username = %s, rsa_public_key = %s, identity_key = %s " "WHERE id = %s AND rsa_public_key = 'PHANTOM'", (username, rsa_public_key_pem, identity_key, phantom_id), ) if cursor.rowcount == 0: conn.rollback() return None # Remove phantom's server-generated crypto keys — real user uploads own cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s", (phantom_id,)) cursor.execute("DELETE FROM one_time_prekeys WHERE user_id = %s", (phantom_id,)) conn.commit() return phantom_id finally: conn.close() def get_all_phantom_user_ids() -> set[str]: """Return set of all phantom user IDs (for server startup cache).""" conn = get_connection() try: cursor = conn.cursor() cursor.execute("SELECT id FROM users WHERE rsa_public_key = %s", ("PHANTOM",)) return {row[0] for row in cursor.fetchall()} finally: conn.close() def cleanup_stale_phantoms(max_age_days: int = 30) -> int: """Delete phantom users older than max_age_days with no active conversations with real users.""" conn = get_connection() try: cursor = conn.cursor() # Two-step: SELECT ids first, then DELETE. # MySQL error 1093: can't DELETE from table referenced in subquery. cursor.execute(""" SELECT u.id FROM users u WHERE u.rsa_public_key = 'PHANTOM' AND u.created_at < DATE_SUB(NOW(), INTERVAL %s DAY) AND NOT EXISTS ( SELECT 1 FROM conversation_members cm1 JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id JOIN users u2 ON cm2.user_id = u2.id WHERE cm1.user_id = u.id AND u2.rsa_public_key != 'PHANTOM' ) """, (max_age_days,)) ids = [row[0] for row in cursor.fetchall()] if not ids: return 0 cursor.execute( "DELETE FROM users WHERE id IN (%s)" % ",".join(["%s"] * len(ids)), ids, ) deleted = cursor.rowcount conn.commit() return deleted finally: conn.close() def remove_conversation_member_atomic(conversation_id: str, user_id: str) -> bool: """Remove member and return True if actually removed (row existed). M6 TOCTOU fix.""" conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s", (conversation_id, user_id), ) conn.commit() return cursor.rowcount > 0 finally: conn.close() def get_stale_uploads(max_age_seconds: int = 3600) -> list[dict]: conn = get_connection() try: cursor = conn.cursor(dictionary=True) cursor.execute( "SELECT file_id FROM image_uploads " "WHERE completed = FALSE AND created_at < DATE_SUB(NOW(), INTERVAL %s SECOND)", (max_age_seconds,), ) return cursor.fetchall() finally: conn.close() # --------------------------------------------------------------------------- # Metadata retention cleanup # --------------------------------------------------------------------------- def cleanup_old_reads(days: int = 90, batch_size: int = 10000) -> int: """Delete message_reads older than N days in batches. Only deletes reads for messages whose created_at is also past the retention window. This prevents phantom unreads: get_unread_counts uses the same time window (max_age_days) so messages outside the window aren't counted. """ total = 0 while True: conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM message_reads " "WHERE read_at < DATE_SUB(NOW(), INTERVAL %s DAY) " "AND message_id IN (" " SELECT id FROM messages " " WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY)" ") LIMIT %s", (days, days, batch_size), ) count = cursor.rowcount conn.commit() total += count if count < batch_size: break finally: conn.close() return total def cleanup_old_reactions(days: int = 90, batch_size: int = 10000) -> int: """Delete message_reactions older than N days in batches.""" total = 0 while True: conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM message_reactions WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY) LIMIT %s", (days, batch_size), ) count = cursor.rowcount conn.commit() total += count if count < batch_size: break finally: conn.close() return total def cleanup_old_messages(days: int, batch_size: int = 1000) -> tuple[int, list[str]]: """Delete messages older than N days in batches. message_recipients / message_reads / message_deliveries / message_reactions rows go with them via ON DELETE CASCADE. Returns (deleted_count, orphaned_file_ids) — file_ids whose encrypted blobs are no longer referenced by any surviving message. The caller is responsible for removing those files from the upload directory (db layer does not touch the filesystem). """ # Collect attachment file_ids referenced by messages about to be deleted conn = get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT DISTINCT image_file_id FROM messages " "WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY) " "AND image_file_id IS NOT NULL", (days,), ) candidate_files = [row[0] for row in cursor.fetchall()] finally: conn.close() total = 0 while True: conn = get_connection() try: cursor = conn.cursor() cursor.execute( "DELETE FROM messages WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY) LIMIT %s", (days, batch_size), ) count = cursor.rowcount conn.commit() total += count if count < batch_size: break finally: conn.close() # A file is orphaned only if no surviving (newer) message still references # it (e.g. a forwarded copy) orphaned: list[str] = [] if candidate_files: still_referenced: set[str] = set() conn = get_connection() try: cursor = conn.cursor() for i in range(0, len(candidate_files), 500): chunk = candidate_files[i:i + 500] placeholders = ", ".join(["%s"] * len(chunk)) cursor.execute( f"SELECT DISTINCT image_file_id FROM messages " f"WHERE image_file_id IN ({placeholders})", chunk, ) still_referenced.update(row[0] for row in cursor.fetchall()) finally: conn.close() orphaned = [f for f in candidate_files if f not in still_referenced] if orphaned: conn = get_connection() try: cursor = conn.cursor() for i in range(0, len(orphaned), 500): chunk = orphaned[i:i + 500] placeholders = ", ".join(["%s"] * len(chunk)) cursor.execute( f"DELETE FROM image_uploads WHERE file_id IN ({placeholders})", chunk, ) conn.commit() finally: conn.close() return total, orphaned