Files
Kecalek_python/db.py
Filip f42ecf5c5b Add message retention and hide emails by default
- db: cleanup_old_messages(days) purges messages older than N days in
  batches; recipients/reads/deliveries/reactions follow via ON DELETE
  CASCADE. Returns attachment file_ids no longer referenced by any
  surviving message (forwarded copies keep their files) and removes
  their image_uploads rows
- server: MESSAGE_RETENTION_DAYS env var (default 0 = keep forever);
  hourly cleanup deletes expired messages and securely removes orphaned
  attachment blobs from the upload dir
- schema: email_visible now defaults to 0 — previously any logged-in
  user who knew a UUID could read another user's email via get_profile
- migrations: SQL script to apply the new default and reset the flag on
  existing databases (run manually, see file header)
- docker-compose: document MESSAGE_RETENTION_DAYS

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 10:30:42 +02:00

1808 lines
61 KiB
Python

"""MySQL database layer for the encrypted chat server."""
import os
import uuid
import logging
import mysql.connector
from mysql.connector import pooling
from dotenv import load_dotenv
from crypto_utils import (
generate_identity_keypair,
serialize_ed25519_public,
generate_signed_prekey,
serialize_x25519_public,
generate_one_time_prekeys,
)
# Runs at import time, before the os.getenv() lookups in _get_pool().
# presumably loads settings from a .env file (python-dotenv) — confirm.
load_dotenv()
# Sentinel device_id for self-encrypted copies and legacy (pre-multi-device) rows
SELF_DEVICE_ID = "00000000-0000-0000-0000-000000000000"
# Module-level logger named after this module.
_logger = logging.getLogger(__name__)
# Connection pool singleton; stays None until the first _get_pool() call.
_pool = None
def _get_pool():
    """Return the shared MySQL connection pool, building it on first use.

    Pool size and connection settings are read from environment variables
    (DB_POOL_SIZE, MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD,
    MYSQL_DATABASE). TLS options are passed only when MYSQL_SSL_CA is set;
    MYSQL_SSL_CERT / MYSQL_SSL_KEY are optional additions to it.
    """
    global _pool
    if _pool is not None:
        return _pool

    size = int(os.getenv("DB_POOL_SIZE", "10"))
    settings = {
        "pool_name": "chat_pool",
        "pool_size": size,
        "pool_reset_session": True,
        "host": os.getenv("MYSQL_HOST", "localhost"),
        "port": int(os.getenv("MYSQL_PORT", "3306")),
        "user": os.getenv("MYSQL_USER", "root"),
        "password": os.getenv("MYSQL_PASSWORD", ""),
        "database": os.getenv("MYSQL_DATABASE", "encrypted_chat"),
    }

    # Optional MySQL TLS (M7): the CA path switches TLS on, cert/key are extras.
    tls_options = {
        "ssl_ca": os.getenv("MYSQL_SSL_CA", "").strip(),
        "ssl_cert": os.getenv("MYSQL_SSL_CERT", "").strip(),
        "ssl_key": os.getenv("MYSQL_SSL_KEY", "").strip(),
    }
    if tls_options["ssl_ca"]:
        for option, value in tls_options.items():
            if value:
                settings[option] = value
        _logger.info("MySQL TLS enabled (CA: %s)", tls_options["ssl_ca"])

    _pool = pooling.MySQLConnectionPool(**settings)
    _logger.info("DB connection pool created (size=%d)", size)
    return _pool
def get_connection():
    """Borrow a connection from the lazily created pool.

    Callers are expected to call ``close()`` on the result when done, as every
    function in this module does in its ``finally`` clause.
    """
    pool = _get_pool()
    return pool.get_connection()
def generate_uuid() -> str:
    """Return a new random (version 4) UUID in canonical string form."""
    fresh = uuid.uuid4()
    return str(fresh)
# --- Devices ---
def create_device(user_id: str, device_name: str | None = None) -> str:
    """Insert a device row for ``user_id`` and return the generated device ID.

    ``device_name`` is stored as given and may be ``None``.
    """
    insert_sql = "INSERT INTO devices (id, user_id, device_name) VALUES (%s, %s, %s)"
    db = get_connection()
    try:
        cur = db.cursor()
        new_device_id = generate_uuid()
        cur.execute(insert_sql, (new_device_id, user_id, device_name))
        db.commit()
        return new_device_id
    finally:
        db.close()
def get_user_devices(user_id: str) -> list[dict]:
    """List every device registered to ``user_id``, ordered by ``created_at``.

    Each row is a dict with id, user_id, device_name, created_at, last_seen_at.
    """
    query = (
        "SELECT id, user_id, device_name, created_at, last_seen_at "
        "FROM devices WHERE user_id = %s ORDER BY created_at"
    )
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (user_id,))
        devices = cur.fetchall()
        return devices
    finally:
        db.close()
def get_device(device_id: str) -> dict | None:
    """Look up one device by its ID; returns ``None`` when no row matches."""
    query = (
        "SELECT id, user_id, device_name, created_at, last_seen_at "
        "FROM devices WHERE id = %s"
    )
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (device_id,))
        device = cur.fetchone()
        return device
    finally:
        db.close()
def update_device_last_seen(device_id: str):
    """Stamp the device's ``last_seen_at`` column with the database's NOW()."""
    touch_sql = "UPDATE devices SET last_seen_at = NOW() WHERE id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(touch_sql, (device_id,))
        db.commit()
    finally:
        db.close()
def delete_device(device_id: str):
    """Delete a device together with its signed and one-time prekeys.

    The prekey rows are removed *before* the device row. The foreign-key
    action on ``device_id`` is not visible from this module, and child-first
    ordering is correct under all of them:

    * ON DELETE CASCADE  - the explicit deletes just do the work early;
    * RESTRICT/NO ACTION - deleting the device first would raise;
    * SET NULL           - deleting the device first would turn its prekeys
      into NULL-device ("legacy") keys that the explicit deletes then miss.

    All three statements are committed together.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        # Children first, parent last (see docstring).
        cursor.execute("DELETE FROM signed_prekeys WHERE device_id = %s", (device_id,))
        cursor.execute("DELETE FROM one_time_prekeys WHERE device_id = %s", (device_id,))
        cursor.execute("DELETE FROM devices WHERE id = %s", (device_id,))
        conn.commit()
    finally:
        conn.close()
# --- Users ---
def create_user(username: str, email: str, rsa_public_key_pem: str, identity_key: bytes) -> str:
    """Insert a new row into ``users`` and return the generated user ID."""
    insert_sql = (
        "INSERT INTO users (id, username, email, rsa_public_key, identity_key) "
        "VALUES (%s, %s, %s, %s, %s)"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        new_user_id = generate_uuid()
        values = (new_user_id, username, email, rsa_public_key_pem, identity_key)
        cur.execute(insert_sql, values)
        db.commit()
        return new_user_id
    finally:
        db.close()
def get_user_by_email(email: str) -> dict | None:
    """Fetch the user row whose email equals ``email``; ``None`` if absent.

    The dict carries id, username, rsa_public_key, email and identity_key.
    """
    query = "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE email = %s"
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (email,))
        user_row = cur.fetchone()
        return user_row
    finally:
        db.close()
def get_user_by_id(user_id: str) -> dict | None:
    """Fetch the user row with primary key ``user_id``; ``None`` if absent.

    The dict carries id, username, rsa_public_key, email and identity_key.
    """
    query = "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE id = %s"
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (user_id,))
        user_row = cur.fetchone()
        return user_row
    finally:
        db.close()
def shares_conversation(user_id_a: str, user_id_b: str) -> bool:
    """Return True when both users are members of at least one common conversation."""
    query = (
        "SELECT 1 FROM conversation_members cm1 "
        "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id "
        "WHERE cm1.user_id = %s AND cm2.user_id = %s LIMIT 1"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (user_id_a, user_id_b))
        hit = cur.fetchone()
        return hit is not None
    finally:
        db.close()
def get_user_contacts(user_id: str) -> list[str]:
    """Return the distinct IDs of other users who share a conversation with ``user_id``."""
    query = (
        "SELECT DISTINCT cm2.user_id "
        "FROM conversation_members cm1 "
        "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id "
        "WHERE cm1.user_id = %s AND cm2.user_id != %s"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (user_id, user_id))
        contact_ids = []
        for record in cur.fetchall():
            contact_ids.append(record[0])
        return contact_ids
    finally:
        db.close()
def update_user_rsa_key(user_id: str, rsa_public_key_pem: str):
    """Overwrite the stored RSA public key (PEM text) of ``user_id`` (used for login)."""
    update_sql = "UPDATE users SET rsa_public_key = %s WHERE id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(update_sql, (rsa_public_key_pem, user_id))
        db.commit()
    finally:
        db.close()
def update_username(user_id: str, new_username: str):
    """Set the ``username`` column of ``user_id`` to ``new_username``."""
    update_sql = "UPDATE users SET username = %s WHERE id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(update_sql, (new_username, user_id))
        db.commit()
    finally:
        db.close()
# --- Pre-keys ---
def store_signed_prekey(user_id: str, spk_id: str, public_key: bytes, signature: bytes,
                        device_id: str | None = None):
    """Replace the signed pre-key of a user's device with the given one.

    Existing rows for the same user and device are deleted first (rows with a
    NULL device_id when ``device_id`` is falsy), then the new key is inserted;
    both statements are committed together.
    """
    if device_id:
        purge_sql = "DELETE FROM signed_prekeys WHERE user_id = %s AND device_id = %s"
        purge_args = (user_id, device_id)
    else:
        purge_sql = "DELETE FROM signed_prekeys WHERE user_id = %s AND device_id IS NULL"
        purge_args = (user_id,)
    insert_sql = (
        "INSERT INTO signed_prekeys (id, user_id, device_id, public_key, signature) "
        "VALUES (%s, %s, %s, %s, %s)"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        # Drop the previous SPK(s) for this user+device before inserting.
        cur.execute(purge_sql, purge_args)
        cur.execute(insert_sql, (spk_id, user_id, device_id, public_key, signature))
        db.commit()
    finally:
        db.close()
def get_signed_prekey(user_id: str, device_id: str | None = None) -> dict | None:
    """Return the newest signed pre-key row for a user, or ``None``.

    With a truthy ``device_id`` only that device's keys are considered;
    otherwise the newest key across all of the user's rows is returned.
    """
    query = (
        "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys "
        "WHERE user_id = %s "
    )
    if device_id:
        query += "AND device_id = %s "
        args = (user_id, device_id)
    else:
        args = (user_id,)
    query += "ORDER BY created_at DESC LIMIT 1"
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, args)
        newest = cur.fetchone()
        return newest
    finally:
        db.close()
def store_one_time_prekeys(user_id: str, prekeys: list[dict], device_id: str | None = None):
"""Store a batch of one-time pre-keys. Each dict has {id, public_key (bytes)}."""
conn = get_connection()
try:
cursor = conn.cursor()
for pk in prekeys:
cursor.execute(
"INSERT INTO one_time_prekeys (id, user_id, device_id, public_key) "
"VALUES (%s, %s, %s, %s)",
(pk["id"], user_id, device_id, pk["public_key"]),
)
conn.commit()
finally:
conn.close()
def consume_one_time_prekey(user_id: str, device_id: str | None = None) -> dict | None:
    """Take one one-time pre-key out of the table and hand it to the caller.

    The row is locked with SELECT ... FOR UPDATE and deleted in the same
    transaction. Returns ``{id, public_key}``, or ``None`` when no key is
    left. With a truthy ``device_id`` only that device's keys are eligible.
    On any error the transaction is rolled back and the error re-raised.
    """
    select_sql = "SELECT id, public_key FROM one_time_prekeys WHERE user_id = %s "
    if device_id:
        select_sql += "AND device_id = %s "
        select_args = (user_id, device_id)
    else:
        select_args = (user_id,)
    select_sql += "LIMIT 1 FOR UPDATE"
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        db.start_transaction()
        cur.execute(select_sql, select_args)
        picked = cur.fetchone()
        if picked:
            cur.execute("DELETE FROM one_time_prekeys WHERE id = %s", (picked["id"],))
        db.commit()
        return picked
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
def count_one_time_prekeys(user_id: str, device_id: str | None = None) -> int:
    """Return how many one-time pre-keys remain for a user.

    A truthy ``device_id`` restricts the count to that device.
    """
    query = "SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s"
    args = (user_id,)
    if device_id:
        query += " AND device_id = %s"
        args = (user_id, device_id)
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, args)
        (remaining,) = cur.fetchone()
        return remaining
    finally:
        db.close()
def get_key_bundle(user_id: str) -> dict | None:
    """Get a complete X3DH key bundle for a user (single device — legacy compat).

    Returns:
        A dict with ``identity_key``, ``signed_prekey_id``, ``signed_prekey``
        and ``spk_signature``. ``one_time_prekey_id`` and ``one_time_prekey``
        are added only when a one-time prekey was available. ``None`` when
        the user does not exist or has no signed prekey.

    The one-time prekey is selected ``FOR UPDATE`` and deleted inside one
    explicit transaction, so it is consumed atomically. Any error rolls the
    transaction back (best effort) and is re-raised.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        # Get user identity key
        cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,))
        user = cursor.fetchone()
        if not user:
            return None
        # Get the most recent signed prekey
        cursor.execute(
            "SELECT id, public_key, signature, device_id FROM signed_prekeys WHERE user_id = %s "
            "ORDER BY created_at DESC LIMIT 1",
            (user_id,),
        )
        spk = cursor.fetchone()
        if not spk:
            return None
        # End the implicit transaction opened by the read-only queries above;
        # otherwise start_transaction() fails because one is already in
        # progress (same pattern as get_key_bundles_for_user).
        conn.commit()
        # Consume one OTP (may be None) — use transaction for atomicity (H12 fix)
        conn.start_transaction()
        cursor.execute(
            "SELECT id, public_key FROM one_time_prekeys WHERE user_id = %s LIMIT 1 FOR UPDATE",
            (user_id,),
        )
        opk = cursor.fetchone()
        if opk:
            cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],))
        conn.commit()
        result = {
            "identity_key": user["identity_key"],
            "signed_prekey_id": spk["id"],
            "signed_prekey": spk["public_key"],
            "spk_signature": spk["signature"],
        }
        if opk:
            result["one_time_prekey_id"] = opk["id"]
            result["one_time_prekey"] = opk["public_key"]
        return result
    except Exception:
        # Best-effort rollback: a failure here must not mask the original error.
        try:
            conn.rollback()
        except Exception:
            pass
        raise
    finally:
        conn.close()
def get_key_bundles_for_user(user_id: str) -> dict | None:
    """Build one X3DH bundle per device of a user.

    Returns ``{identity_key, device_bundles}`` where each bundle holds
    ``device_id``, ``signed_prekey_id``, ``signed_prekey_pub`` and
    ``spk_signature``, plus ``opk_id`` / ``opk_pub`` when a one-time prekey
    was available for that device. Returns ``None`` if the user is unknown or
    has no signed prekey. One OPK per device is consumed, all inside a single
    transaction.
    """
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)

        # Identity key of the user
        cur.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,))
        owner = cur.fetchone()
        if not owner:
            return None

        # Signed prekeys, newest first
        cur.execute(
            "SELECT id, public_key, signature, device_id FROM signed_prekeys "
            "WHERE user_id = %s ORDER BY created_at DESC",
            (user_id,),
        )
        signed_rows = cur.fetchall()
        if not signed_rows:
            return None

        # Rows are newest-first, so the first row seen per device is the one
        # to keep; rows without a device_id share the "__legacy__" slot.
        newest_per_device = {}
        for signed in signed_rows:
            slot = signed.get("device_id") or "__legacy__"
            newest_per_device.setdefault(slot, signed)

        # Commit the implicit transaction from the read-only queries above
        # so an explicit transaction can be started for OPK consumption.
        db.commit()
        db.start_transaction()

        bundles = []
        for signed in newest_per_device.values():
            target_device = signed.get("device_id")
            if target_device:
                opk_sql = (
                    "SELECT id, public_key FROM one_time_prekeys "
                    "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE"
                )
                opk_args = (user_id, target_device)
            else:
                opk_sql = (
                    "SELECT id, public_key FROM one_time_prekeys "
                    "WHERE user_id = %s AND device_id IS NULL LIMIT 1 FOR UPDATE"
                )
                opk_args = (user_id,)
            cur.execute(opk_sql, opk_args)
            one_time = cur.fetchone()
            if one_time:
                cur.execute("DELETE FROM one_time_prekeys WHERE id = %s", (one_time["id"],))

            entry = {
                "device_id": target_device,
                "signed_prekey_id": signed["id"],
                "signed_prekey_pub": signed["public_key"],
                "spk_signature": signed["signature"],
            }
            if one_time:
                entry["opk_id"] = one_time["id"]
                entry["opk_pub"] = one_time["public_key"]
            bundles.append(entry)

        db.commit()
        return {"identity_key": owner["identity_key"], "device_bundles": bundles}
    except Exception:
        try:
            db.rollback()
        except Exception:
            pass
        raise
    finally:
        db.close()
# --- Conversations ---
def create_conversation(member_user_ids: list[str], joined_at=None, name=None, created_by=None) -> str:
    """Create a conversation plus one membership row per given user.

    ``joined_at`` is written unchanged into every membership row. All inserts
    are committed together. Returns the generated conversation ID.
    """
    member_sql = "INSERT INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)"
    db = get_connection()
    try:
        cur = db.cursor()
        new_conv_id = generate_uuid()
        cur.execute(
            "INSERT INTO conversations (id, name, created_by) VALUES (%s, %s, %s)",
            (new_conv_id, name, created_by),
        )
        for member_id in member_user_ids:
            cur.execute(member_sql, (new_conv_id, member_id, joined_at))
        db.commit()
        return new_conv_id
    finally:
        db.close()
def add_conversation_member(conversation_id: str, user_id: str, joined_at=None):
    """Add ``user_id`` to a conversation using INSERT IGNORE."""
    insert_sql = "INSERT IGNORE INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(insert_sql, (conversation_id, user_id, joined_at))
        db.commit()
    finally:
        db.close()
def remove_conversation_member(conversation_id: str, user_id: str):
    """Delete the membership row linking ``user_id`` to ``conversation_id``."""
    delete_sql = "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(delete_sql, (conversation_id, user_id))
        db.commit()
    finally:
        db.close()
def count_conversation_members(conversation_id: str) -> int:
    """Return the number of membership rows for ``conversation_id``."""
    query = "SELECT COUNT(*) FROM conversation_members WHERE conversation_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (conversation_id,))
        (member_total,) = cur.fetchone()
        return member_total
    finally:
        db.close()
def get_conversation_file_ids(conversation_id: str) -> list[str]:
    """Return the ``file_id`` of every ``image_uploads`` row of a conversation.

    Per the original docstring this covers both images and other files.
    """
    query = "SELECT file_id FROM image_uploads WHERE conversation_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (conversation_id,))
        file_ids = []
        for record in cur.fetchall():
            file_ids.append(record[0])
        return file_ids
    finally:
        db.close()
def delete_conversation(conversation_id: str):
    """Delete the conversation row.

    Only the ``conversations`` row is deleted here; dependent rows (members,
    messages, ...) are left to the schema's ON DELETE CASCADE rules.
    """
    delete_sql = "DELETE FROM conversations WHERE id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(delete_sql, (conversation_id,))
        db.commit()
    finally:
        db.close()
def get_conversation_members(conversation_id: str) -> list[dict]:
    """Return id, username, email and identity_key of every conversation member."""
    query = (
        "SELECT u.id, u.username, u.email, u.identity_key FROM conversation_members cm "
        "JOIN users u ON cm.user_id = u.id "
        "WHERE cm.conversation_id = %s"
    )
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (conversation_id,))
        members = cur.fetchall()
        return members
    finally:
        db.close()
def find_direct_conversation(user_id_a: str, user_id_b: str) -> str | None:
    """Return the ID of a conversation whose only two members are the given users.

    The sub-select restricts matches to conversations with exactly two
    membership rows. Returns ``None`` when no such conversation exists.
    """
    query = (
        "SELECT cm1.conversation_id FROM conversation_members cm1 "
        "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id "
        "WHERE cm1.user_id = %s AND cm2.user_id = %s "
        "AND (SELECT COUNT(*) FROM conversation_members cm3 "
        " WHERE cm3.conversation_id = cm1.conversation_id) = 2 "
        "LIMIT 1"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (user_id_a, user_id_b))
        match = cur.fetchone()
        if match:
            return match[0]
        return None
    finally:
        db.close()
def update_conversation_creator(conversation_id: str, new_creator_id: str):
    """Set ``created_by`` of the conversation to ``new_creator_id``."""
    update_sql = "UPDATE conversations SET created_by = %s WHERE id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(update_sql, (new_creator_id, conversation_id))
        db.commit()
    finally:
        db.close()
def get_conversation(conversation_id: str) -> dict | None:
    """Fetch one conversation row, or ``None`` if the ID is unknown.

    The dict carries id, created_at, name, created_by and avatar_file.
    """
    query = "SELECT id, created_at, name, created_by, avatar_file FROM conversations WHERE id = %s"
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (conversation_id,))
        conversation = cur.fetchone()
        return conversation
    finally:
        db.close()
def update_conversation_avatar(conversation_id: str, avatar_file: str):
    """Store ``avatar_file`` in the conversation's ``avatar_file`` column."""
    update_sql = "UPDATE conversations SET avatar_file = %s WHERE id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(update_sql, (avatar_file, conversation_id))
        db.commit()
    finally:
        db.close()
def update_conversation_name(conversation_id: str, name: str):
    """Rename a conversation by overwriting its ``name`` column."""
    update_sql = "UPDATE conversations SET name = %s WHERE id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(update_sql, (name, conversation_id))
        db.commit()
    finally:
        db.close()
def is_conversation_member(conversation_id: str, user_id: str) -> bool:
    """Return True when a membership row exists for this user and conversation."""
    query = "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (conversation_id, user_id))
        hit = cur.fetchone()
        return hit is not None
    finally:
        db.close()
def list_user_conversations(user_id: str) -> list[dict]:
    """List the conversations ``user_id`` belongs to, newest first.

    Each conversation dict (id, created_at, name, created_by, avatar_file)
    gets a ``members`` list of ``{user_id, username, email}`` dicts. Members
    of all conversations are loaded with one extra query rather than one
    query per conversation.
    """
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(
            "SELECT c.id, c.created_at, c.name, c.created_by, c.avatar_file FROM conversations c "
            "JOIN conversation_members cm ON c.id = cm.conversation_id "
            "WHERE cm.user_id = %s ORDER BY c.created_at DESC",
            (user_id,),
        )
        conversations = cur.fetchall()
        if not conversations:
            return conversations

        # Batch-fetch all members for all conversations in one query (N+1 fix)
        ids = [conversation["id"] for conversation in conversations]
        marks = ",".join("%s" for _ in ids)
        member_sql = (
            "SELECT cm.conversation_id, u.id AS user_id, u.username, u.email "
            "FROM conversation_members cm JOIN users u ON cm.user_id = u.id "
            f"WHERE cm.conversation_id IN ({marks})"
        )
        cur.execute(member_sql, ids)

        grouped: dict[str, list[dict]] = {}
        for member in cur.fetchall():
            owner_conv = member.pop("conversation_id")
            if owner_conv not in grouped:
                grouped[owner_conv] = []
            grouped[owner_conv].append(member)

        for conversation in conversations:
            conversation["members"] = grouped.get(conversation["id"], [])
        return conversations
    finally:
        db.close()
# --- Group Invitations ---
def create_invitation(conversation_id: str, user_id: str, invited_by: str):
    """Record a pending group invitation using INSERT IGNORE.

    A fresh UUID is generated as the invitation ID; nothing is returned.
    """
    insert_sql = (
        "INSERT IGNORE INTO group_invitations (id, conversation_id, user_id, invited_by) "
        "VALUES (%s, %s, %s, %s)"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        invitation_id = generate_uuid()
        cur.execute(insert_sql, (invitation_id, conversation_id, user_id, invited_by))
        db.commit()
    finally:
        db.close()
def get_pending_invitations(user_id: str) -> list[dict]:
    """Return the invitations addressed to ``user_id``, newest first.

    Each row has id, conversation_id, invited_by, created_at, plus the
    joined-in ``conversation_name`` and ``invited_by_username``.
    """
    query = (
        "SELECT gi.id, gi.conversation_id, gi.invited_by, gi.created_at, "
        "c.name AS conversation_name, u.username AS invited_by_username "
        "FROM group_invitations gi "
        "JOIN conversations c ON gi.conversation_id = c.id "
        "JOIN users u ON gi.invited_by = u.id "
        "WHERE gi.user_id = %s "
        "ORDER BY gi.created_at DESC"
    )
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (user_id,))
        invitations = cur.fetchall()
        return invitations
    finally:
        db.close()
def delete_invitation(conversation_id: str, user_id: str):
    """Remove the invitation of ``user_id`` for ``conversation_id``, if any."""
    delete_sql = "DELETE FROM group_invitations WHERE conversation_id = %s AND user_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(delete_sql, (conversation_id, user_id))
        db.commit()
    finally:
        db.close()
def has_pending_invitation(conversation_id: str, user_id: str) -> bool:
    """Return True when an invitation row exists for this user and conversation."""
    query = "SELECT 1 FROM group_invitations WHERE conversation_id = %s AND user_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (conversation_id, user_id))
        hit = cur.fetchone()
        return hit is not None
    finally:
        db.close()
# --- Messages ---
def store_message(
    conversation_id: str,
    sender_id: str,
    ratchet_header: bytes,
    recipients: list[dict],
    x3dh_header: bytes | None = None,
    sender_chain_id: bytes | None = None,
    sender_chain_n: int | None = None,
    image_file_id: str | None = None,
    sender_device_id: str | None = None,
) -> tuple[str, str | None]:
    """Store an encrypted message with per-recipient ciphertext.

    Args:
        recipients: ``[{user_id, encrypted_content (bytes), nonce (bytes),
            device_id (str, optional), ratchet_header (bytes, optional),
            x3dh_header (bytes, optional)}]``. A missing or empty
            ``device_id`` is stored as ``SELF_DEVICE_ID``.

    Returns:
        ``(message_id, created_at)``. ``created_at`` is the ISO-8601 form of
        the timestamp read back from the database, or ``None`` if the row
        could not be read back.

    The message row and all recipient rows are committed together.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        msg_id = generate_uuid()
        cursor.execute(
            "INSERT INTO messages (id, conversation_id, sender_id, sender_device_id, "
            "ratchet_header, x3dh_header, sender_chain_id, sender_chain_n, image_file_id) "
            "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)",
            (msg_id, conversation_id, sender_id, sender_device_id, ratchet_header,
             x3dh_header, sender_chain_id, sender_chain_n, image_file_id),
        )
        for r in recipients:
            # `or` rather than a .get() default: an explicit device_id of None
            # must also fall back to the sentinel (as store_sender_key does).
            device_id = r.get("device_id") or SELF_DEVICE_ID
            cursor.execute(
                "INSERT INTO message_recipients (message_id, user_id, device_id, "
                "encrypted_content, nonce, ratchet_header, x3dh_header) "
                "VALUES (%s, %s, %s, %s, %s, %s, %s)",
                (msg_id, r["user_id"], device_id, r["encrypted_content"], r["nonce"],
                 r.get("ratchet_header"), r.get("x3dh_header")),
            )
        conn.commit()
        # Read back the timestamp the database recorded for the row.
        cursor.execute("SELECT created_at FROM messages WHERE id = %s", (msg_id,))
        row = cursor.fetchone()
        created_at = row[0].isoformat() if row else None
        return msg_id, created_at
    finally:
        conn.close()
def get_messages(conversation_id: str, user_id: str, limit: int = 50, offset: int = 0,
                 device_id: str | None = None, after_ts: str | None = None) -> list[dict]:
    """Get messages for a user in a conversation, JOINing their per-recipient ciphertext.

    Args:
        conversation_id: Conversation to read from.
        user_id: Recipient whose ciphertext rows and membership are joined.
        limit: Maximum number of rows returned.
        offset: Number of rows skipped.
        device_id: If set, only recipient rows whose device_id matches OR is
            the sentinel (self-encrypted / legacy) are returned. This may
            yield duplicate message IDs when both kinds of row exist — the
            caller should deduplicate (prefer device-specific).
        after_ts: If set, only messages created after this timestamp (ISO
            format) are returned.

    Messages created before the member's ``joined_at`` are excluded (a NULL
    ``joined_at`` excludes nothing). Results are ordered by ``created_at``
    ASC when ``after_ts`` is used, DESC otherwise.
    """
    # Parameters are appended in the order their placeholders appear in the SQL.
    params = [user_id]
    device_filter = ""
    if device_id:
        device_filter = " AND (mr.device_id = %s OR mr.device_id = %s) "
        params += [device_id, SELF_DEVICE_ID]
    params += [user_id, conversation_id]

    where_parts = ["m.conversation_id = %s",
                   "(cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"]
    if after_ts:
        where_parts.append("m.created_at > %s")
        params.append(after_ts)
    where_clause = " AND ".join(where_parts)
    order = "ASC" if after_ts else "DESC"

    # Only the optional device filter differs between the two query shapes.
    query = (
        "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, "
        "m.ratchet_header, m.x3dh_header, "
        "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, "
        "m.pinned_at, m.pinned_by, "
        "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, "
        "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header "
        "FROM messages m "
        "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s "
        f"{device_filter}"
        "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s "
        f"WHERE {where_clause} "
        f"ORDER BY m.created_at {order} LIMIT %s OFFSET %s"
    )

    conn = get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        cursor.execute(query, (*params, limit, offset))
        return cursor.fetchall()
    finally:
        conn.close()
def count_messages(conversation_id: str, user_id: str) -> int:
    """Count the distinct messages a user can see in a conversation.

    Only messages with a recipient row for the user are counted, and only
    those created at or after the member's ``joined_at`` (a NULL
    ``joined_at`` excludes nothing).
    """
    query = (
        "SELECT COUNT(DISTINCT m.id) "
        "FROM messages m "
        "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s "
        "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s "
        "WHERE m.conversation_id = %s AND (cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (user_id, user_id, conversation_id))
        result = cur.fetchone()
        if result:
            return result[0]
        return 0
    finally:
        db.close()
def get_message_conversation(message_id: str) -> str | None:
    """Return the conversation ID a message belongs to, or ``None`` if unknown."""
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute("SELECT conversation_id FROM messages WHERE id = %s", (message_id,))
        found = cur.fetchone()
        if found:
            return found[0]
        return None
    finally:
        db.close()
def get_message_sender(message_id: str) -> str | None:
    """Return the sender's user ID for a message, or ``None`` if unknown."""
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute("SELECT sender_id FROM messages WHERE id = %s", (message_id,))
        found = cur.fetchone()
        if found:
            return found[0]
        return None
    finally:
        db.close()
def get_deleted_messages_since(conversation_id: str, user_id: str, since_ts: str) -> list[str]:
    """List IDs of messages whose ``deleted_at`` is later than ``since_ts``.

    The join on ``conversation_members`` yields rows only when ``user_id`` is
    a member of the conversation.
    """
    query = (
        "SELECT m.id FROM messages m "
        "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s "
        "WHERE m.conversation_id = %s AND m.deleted_at IS NOT NULL AND m.deleted_at > %s"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (user_id, conversation_id, since_ts))
        deleted_ids = []
        for record in cur.fetchall():
            deleted_ids.append(record[0])
        return deleted_ids
    finally:
        db.close()
# --- Reactions ---
# Set of recognised reaction names.
# NOTE(review): add_reaction() below does not check against this set —
# presumably callers validate before calling; confirm.
ALLOWED_REACTIONS = {"thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"}
def add_reaction(message_id: str, user_id: str, reaction: str) -> tuple[bool, str | None]:
    """Set a user's reaction on a message, replacing any earlier one.

    Returns ``(False, None)`` when the stored reaction already equals
    ``reaction``; otherwise ``(True, previous)``, where ``previous`` is the
    replaced reaction or ``None`` if there was none.
    """
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(
            "SELECT reaction FROM message_reactions WHERE message_id = %s AND user_id = %s",
            (message_id, user_id),
        )
        existing = cur.fetchone()
        previous = existing["reaction"] if existing else None

        # Nothing to change: the same reaction is already stored.
        if previous == reaction:
            return False, None

        if previous:
            write_sql = (
                "UPDATE message_reactions SET reaction = %s, created_at = CURRENT_TIMESTAMP "
                "WHERE message_id = %s AND user_id = %s"
            )
            write_args = (reaction, message_id, user_id)
        else:
            write_sql = (
                "INSERT INTO message_reactions (id, message_id, user_id, reaction) "
                "VALUES (%s, %s, %s, %s)"
            )
            write_args = (generate_uuid(), message_id, user_id, reaction)
        cur.execute(write_sql, write_args)
        db.commit()
        return True, previous
    finally:
        db.close()
def remove_reaction(message_id: str, user_id: str) -> bool:
    """Delete a user's reaction on a message; True if a row was removed."""
    delete_sql = "DELETE FROM message_reactions WHERE message_id = %s AND user_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(delete_sql, (message_id, user_id))
        db.commit()
        removed = cur.rowcount
        return removed > 0
    finally:
        db.close()
def get_reactions(message_ids: list[str]) -> dict[str, list[dict]]:
    """Load the reactions of several messages in one query.

    Returns ``{message_id: [{user_id, reaction, created_at}, ...]}`` with each
    list ordered by ``created_at``. ``created_at`` is converted with
    ``isoformat()`` when available, otherwise ``str()``. Messages without
    reactions are absent from the result; an empty input yields ``{}``
    without touching the database.
    """
    if not message_ids:
        return {}
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        marks = ",".join("%s" for _ in message_ids)
        query = (
            "SELECT message_id, user_id, reaction, created_at "
            f"FROM message_reactions WHERE message_id IN ({marks}) "
            "ORDER BY created_at"
        )
        cur.execute(query, tuple(message_ids))

        by_message = {}
        for record in cur.fetchall():
            stamp = record["created_at"]
            if hasattr(stamp, "isoformat"):
                stamp_text = stamp.isoformat()
            else:
                stamp_text = str(stamp)
            by_message.setdefault(record["message_id"], []).append({
                "user_id": record["user_id"],
                "reaction": record["reaction"],
                "created_at": stamp_text,
            })
        return by_message
    finally:
        db.close()
# --- Pins ---
def pin_message(message_id: str, user_id: str, conversation_id: str) -> bool:
    """Pin a message in its conversation.

    Sets ``pinned_at`` to NOW() and ``pinned_by`` to ``user_id``. Only a
    message in ``conversation_id`` that is not yet pinned is updated.
    Returns True when a row was changed.
    """
    pin_sql = (
        "UPDATE messages SET pinned_at = NOW(), pinned_by = %s "
        "WHERE id = %s AND conversation_id = %s AND pinned_at IS NULL"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(pin_sql, (user_id, message_id, conversation_id))
        db.commit()
        changed = cur.rowcount
        return changed > 0
    finally:
        db.close()
def unpin_message(message_id: str, conversation_id: str) -> bool:
    """Clear the pin of a message.

    Only a currently pinned message in ``conversation_id`` is updated.
    Returns True when a row was changed.
    """
    unpin_sql = (
        "UPDATE messages SET pinned_at = NULL, pinned_by = NULL "
        "WHERE id = %s AND conversation_id = %s AND pinned_at IS NOT NULL"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(unpin_sql, (message_id, conversation_id))
        db.commit()
        changed = cur.rowcount
        return changed > 0
    finally:
        db.close()
def get_pinned_messages(conversation_id: str, user_id: str) -> list[dict]:
    """Return pinned, non-deleted messages of a conversation, newest pin first.

    The join on ``conversation_members`` yields rows only when ``user_id`` is
    a member. ``pinned_at`` and ``created_at`` are converted to ISO strings
    when they support ``isoformat()``.
    """
    query = (
        "SELECT m.id AS message_id, m.sender_id, m.pinned_at, m.pinned_by, m.created_at "
        "FROM messages m "
        "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s "
        "WHERE m.conversation_id = %s AND m.pinned_at IS NOT NULL AND m.deleted_at IS NULL "
        "ORDER BY m.pinned_at DESC"
    )
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (user_id, conversation_id))
        pinned = cur.fetchall()
        for entry in pinned:
            for field in ("pinned_at", "created_at"):
                value = entry.get(field)
                if value and hasattr(value, "isoformat"):
                    entry[field] = value.isoformat()
        return pinned
    finally:
        db.close()
# --- Group Sender Keys ---
def store_sender_key(conversation_id: str, sender_id: str, chain_id: bytes,
                     device_id: str | None = None):
    """Write the sender-key chain ID for a group member's device with REPLACE INTO.

    A falsy ``device_id`` is stored as ``SELF_DEVICE_ID``.
    """
    upsert_sql = (
        "REPLACE INTO group_sender_keys (conversation_id, sender_id, device_id, chain_id) "
        "VALUES (%s, %s, %s, %s)"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        effective_device = device_id if device_id else SELF_DEVICE_ID
        cur.execute(upsert_sql, (conversation_id, sender_id, effective_device, chain_id))
        db.commit()
    finally:
        db.close()
def get_sender_key(conversation_id: str, sender_id: str,
                   device_id: str | None = None) -> dict | None:
    """Fetch ``{chain_id, created_at}`` for a sender's device in a conversation.

    A falsy ``device_id`` is looked up as ``SELF_DEVICE_ID``. Returns ``None``
    when no row matches.
    """
    query = (
        "SELECT chain_id, created_at FROM group_sender_keys "
        "WHERE conversation_id = %s AND sender_id = %s AND device_id = %s"
    )
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        effective_device = device_id if device_id else SELF_DEVICE_ID
        cur.execute(query, (conversation_id, sender_id, effective_device))
        key_row = cur.fetchone()
        return key_row
    finally:
        db.close()
# --- Read Receipts ---
def filter_message_ids_by_conversation(conversation_id: str, message_ids: list[str]) -> list[str]:
    """Keep only the ids from ``message_ids`` that exist in ``conversation_id``.

    An empty input short-circuits to an empty list without touching the
    database. The order of the result is whatever the database returns.
    """
    if not message_ids:
        return []
    db = get_connection()
    try:
        cur = db.cursor()
        # One "%s" marker per id; the ids themselves are bound as parameters.
        markers = ",".join(["%s"] * len(message_ids))
        query = f"SELECT id FROM messages WHERE id IN ({markers}) AND conversation_id = %s"
        cur.execute(query, (*message_ids, conversation_id))
        matching = []
        for record in cur.fetchall():
            matching.append(record[0])
        return matching
    finally:
        db.close()
def mark_messages_read(conversation_id: str, user_id: str, message_ids: list[str]):
    """Record read receipts for ``user_id`` on the given messages.

    Does nothing for an empty ``message_ids``. Rows are inserted with
    ``INSERT IGNORE`` from a SELECT over ``messages``, so only ids that
    actually belong to ``conversation_id`` produce a receipt.
    """
    if not message_ids:
        return
    db = get_connection()
    try:
        cur = db.cursor()
        # M1 fix: select from messages so foreign message_ids are filtered out
        markers = ",".join(["%s"] * len(message_ids))
        query = (
            "INSERT IGNORE INTO message_reads (message_id, user_id) "
            "SELECT m.id, %s FROM messages m "
            f"WHERE m.id IN ({markers}) AND m.conversation_id = %s"
        )
        cur.execute(query, (user_id, *message_ids, conversation_id))
        db.commit()
    finally:
        db.close()
def mark_conversation_read(conversation_id: str, user_id: str) -> int:
    """Insert read receipts for every unread message of a conversation.

    Considers messages that have a ``message_recipients`` row for the user,
    were not sent by the user, are not soft-deleted and have no existing
    ``message_reads`` row for the user. Returns the cursor's row count for
    the insert.
    """
    query = (
        "INSERT IGNORE INTO message_reads (message_id, user_id) "
        "SELECT m.id, %s "
        "FROM messages m "
        "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s "
        "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s "
        "WHERE m.conversation_id = %s AND m.sender_id != %s "
        "AND m.deleted_at IS NULL AND mrd.message_id IS NULL"
    )
    args = (user_id, user_id, user_id, conversation_id, user_id)
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, args)
        marked = cur.rowcount
        db.commit()
        return marked
    finally:
        db.close()
def get_unread_counts(user_id: str, max_age_days: int = 0) -> dict[str, int]:
    """Return ``{conversation_id: unread_count}`` for the user.

    A message counts as unread when the user has a recipient row for it, did
    not send it, it is not soft-deleted and it has no read receipt from the
    user.

    max_age_days: if > 0, only count messages younger than this many days.
    Must match METADATA_RETENTION_DAYS to avoid phantom unreads after read
    cleanup.
    """
    limit_age = max_age_days > 0
    age_clause = (
        " AND m.created_at >= DATE_SUB(NOW(), INTERVAL %s DAY)" if limit_age else ""
    )
    args = [user_id] * 3
    if limit_age:
        args.append(max_age_days)
    query = (
        "SELECT m.conversation_id, COUNT(DISTINCT m.id) AS cnt "
        "FROM messages m "
        "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s "
        "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s "
        "WHERE m.sender_id != %s AND m.deleted_at IS NULL AND mrd.message_id IS NULL"
        f"{age_clause} "
        "GROUP BY m.conversation_id"
    )
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, args)
        unread: dict[str, int] = {}
        for record in cur.fetchall():
            unread[record["conversation_id"]] = record["cnt"]
        return unread
    finally:
        db.close()
def get_message_read_status(message_ids: list[str]) -> dict:
    """Return read receipts grouped by message.

    The result maps each message id that has at least one receipt to a list
    of ``{"user_id", "read_at"}`` dicts; ``read_at`` is an ISO-8601 string
    when the stored value supports ``isoformat`` and ``str(value)`` otherwise.
    An empty ``message_ids`` yields ``{}`` without a database round trip.
    """
    if not message_ids:
        return {}
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        markers = ",".join(["%s"] * len(message_ids))
        query = (
            "SELECT mr.message_id, mr.user_id, mr.read_at "
            "FROM message_reads mr "
            f"WHERE mr.message_id IN ({markers})"
        )
        cur.execute(query, tuple(message_ids))
        receipts: dict = {}
        for record in cur.fetchall():
            when = record["read_at"]
            stamp = when.isoformat() if hasattr(when, "isoformat") else str(when)
            receipts.setdefault(record["message_id"], []).append(
                {"user_id": record["user_id"], "read_at": stamp}
            )
        return receipts
    finally:
        db.close()
# --- Delivery Receipts ---
def mark_messages_delivered(conversation_id: str, user_id: str, message_ids: list[str]):
    """Record delivery receipts for ``user_id`` on the given messages.

    Does nothing for an empty ``message_ids``. ``INSERT IGNORE`` makes repeat
    calls harmless, and selecting from ``messages`` restricts the insert to
    ids that belong to ``conversation_id``.
    """
    if not message_ids:
        return
    db = get_connection()
    try:
        cur = db.cursor()
        # M1 fix: select from messages so foreign message_ids are filtered out
        markers = ",".join(["%s"] * len(message_ids))
        query = (
            "INSERT IGNORE INTO message_deliveries (message_id, user_id) "
            "SELECT m.id, %s FROM messages m "
            f"WHERE m.id IN ({markers}) AND m.conversation_id = %s"
        )
        cur.execute(query, (user_id, *message_ids, conversation_id))
        db.commit()
    finally:
        db.close()
def get_message_delivery_status(message_ids: list[str]) -> dict:
    """Return delivery receipts grouped by message.

    The result has the shape ``{msg_id: [{user_id, delivered_at}]}`` and only
    contains messages with at least one receipt. ``delivered_at`` is an
    ISO-8601 string when the stored value supports ``isoformat`` and
    ``str(value)`` otherwise. An empty ``message_ids`` yields ``{}``.
    """
    if not message_ids:
        return {}
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        markers = ",".join(["%s"] * len(message_ids))
        query = (
            "SELECT md.message_id, md.user_id, md.delivered_at "
            "FROM message_deliveries md "
            f"WHERE md.message_id IN ({markers})"
        )
        cur.execute(query, tuple(message_ids))
        receipts: dict = {}
        for record in cur.fetchall():
            when = record["delivered_at"]
            stamp = when.isoformat() if hasattr(when, "isoformat") else str(when)
            receipts.setdefault(record["message_id"], []).append(
                {"user_id": record["user_id"], "delivered_at": stamp}
            )
        return receipts
    finally:
        db.close()
# --- Delete ---
def soft_delete_message(message_id: str, sender_id: str) -> dict | None:
    """Soft-delete a message on behalf of its sender.

    Returns ``{'image_file_id': ...}`` when the message exists, is not yet
    deleted and was sent by ``sender_id``; returns None otherwise. On success
    ``deleted_at`` is stamped and every recipient's ``encrypted_content`` is
    overwritten with empty bytes.
    """
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(
            "SELECT sender_id, image_file_id FROM messages WHERE id = %s AND deleted_at IS NULL",
            (message_id,),
        )
        found = cur.fetchone()
        if found and found["sender_id"] == sender_id:
            cur.execute(
                "UPDATE messages SET deleted_at = NOW() WHERE id = %s",
                (message_id,),
            )
            # Wipe the per-recipient ciphertext as well
            cur.execute(
                "UPDATE message_recipients SET encrypted_content = %s WHERE message_id = %s",
                (b"", message_id),
            )
            db.commit()
            return {"image_file_id": found.get("image_file_id")}
        # Unknown, already deleted, or owned by someone else.
        return None
    finally:
        db.close()
def set_message_image_file_id(message_id: str, file_id: str):
    """Set ``image_file_id`` on the message row with the given id."""
    query = "UPDATE messages SET image_file_id = %s WHERE id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (file_id, message_id))
        db.commit()
    finally:
        db.close()
# --- Image Uploads ---
def create_image_upload(file_id: str, conversation_id: str, uploader_id: str, file_size: int):
    """Insert an ``image_uploads`` row describing a new upload."""
    query = (
        "INSERT INTO image_uploads (file_id, conversation_id, uploader_id, file_size) "
        "VALUES (%s, %s, %s, %s)"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (file_id, conversation_id, uploader_id, file_size))
        db.commit()
    finally:
        db.close()
def complete_image_upload(file_id: str):
    """Flag the upload identified by ``file_id`` as completed."""
    query = "UPDATE image_uploads SET completed = TRUE WHERE file_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (file_id,))
        db.commit()
    finally:
        db.close()
def get_image_upload(file_id: str) -> dict | None:
    """Return the ``image_uploads`` row for ``file_id`` as a dict, or None if absent."""
    query = (
        "SELECT file_id, conversation_id, uploader_id, file_size, completed, created_at "
        "FROM image_uploads WHERE file_id = %s"
    )
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (file_id,))
        upload = cur.fetchone()
        return upload
    finally:
        db.close()
def delete_image_upload(file_id: str):
    """Delete the ``image_uploads`` row for ``file_id`` (no-op if it does not exist)."""
    query = "DELETE FROM image_uploads WHERE file_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (file_id,))
        db.commit()
    finally:
        db.close()
# --- User Profiles ---
def create_default_profile(user_id: str):
    """Ensure a ``user_profiles`` row exists for the user.

    Only ``user_id`` is supplied, so every other column takes its schema
    default; ``INSERT IGNORE`` leaves an existing profile untouched.
    """
    query = "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (user_id,))
        db.commit()
    finally:
        db.close()
def get_user_profile(user_id: str, viewer_id: str | None = None) -> dict | None:
    """Return a user's account data merged with their profile row.

    Returns None when the user does not exist. When ``viewer_id`` is given
    and differs from ``user_id``, ``email``, ``phone`` and ``location`` are
    blanked to None unless the matching ``*_visible`` flag is truthy. With no
    ``viewer_id`` (or when viewing one's own profile) nothing is hidden.
    """
    query = (
        "SELECT u.id AS user_id, u.username, u.email, u.created_at, "
        "p.phone, p.phone_visible, p.email_visible, p.location, "
        "p.location_visible, p.avatar_file, p.updated_at "
        "FROM users u LEFT JOIN user_profiles p ON u.id = p.user_id "
        "WHERE u.id = %s"
    )
    # (visibility flag, field it protects)
    guarded_fields = (
        ("email_visible", "email"),
        ("phone_visible", "phone"),
        ("location_visible", "location"),
    )
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (user_id,))
        profile = cur.fetchone()
        if not profile:
            return None
        viewing_other = bool(viewer_id) and viewer_id != user_id
        if viewing_other:
            for flag, field in guarded_fields:
                if not profile.get(flag):
                    profile[field] = None
        return profile
    finally:
        db.close()
def update_user_profile(user_id: str, **fields):
    """Create the profile row if needed, then update the supplied fields.

    Only phone, phone_visible, email_visible, location, location_visible and
    avatar_file are accepted; any other keyword is silently dropped. If
    nothing acceptable remains, the function returns without touching the
    database.
    """
    permitted = frozenset((
        "phone", "phone_visible", "email_visible",
        "location", "location_visible", "avatar_file",
    ))
    columns = [name for name in fields if name in permitted]
    if not columns:
        return
    # Column names come from the whitelist above; values are bound parameters.
    assignments = ", ".join(f"{name} = %s" for name in columns)
    args = [fields[name] for name in columns]
    args.append(user_id)
    db = get_connection()
    try:
        cur = db.cursor()
        # Upsert in two steps: make sure the row exists, then update it
        cur.execute(
            "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)",
            (user_id,),
        )
        cur.execute(
            f"UPDATE user_profiles SET {assignments} WHERE user_id = %s",
            args,
        )
        db.commit()
    finally:
        db.close()
def batch_reencrypt_messages(user_id: str, updates: list[dict]):
    """Write self-encrypted copies of messages for ``user_id`` in one transaction.

    Each update: {message_id, encrypted_content (bytes), nonce (bytes)}.
    Rows are stored under ``SELF_DEVICE_ID`` with ratchet_header set to
    '{"self":true}' and x3dh_header cleared. INSERT ... ON DUPLICATE KEY
    UPDATE covers both sent messages (which already have a SELF_DEVICE_ID
    row) and received messages (which don't). An empty ``updates`` is a no-op.
    """
    if not updates:
        return
    query = (
        "INSERT INTO message_recipients "
        "(message_id, user_id, device_id, encrypted_content, nonce, ratchet_header, x3dh_header) "
        "VALUES (%s, %s, %s, %s, %s, %s, NULL) "
        "ON DUPLICATE KEY UPDATE encrypted_content = VALUES(encrypted_content), "
        "nonce = VALUES(nonce), ratchet_header = VALUES(ratchet_header), x3dh_header = NULL"
    )
    db = get_connection()
    try:
        cur = db.cursor()
        marker = b'{"self":true}'
        for item in updates:
            args = (
                item["message_id"], user_id, SELF_DEVICE_ID,
                item["encrypted_content"], item["nonce"], marker,
            )
            cur.execute(query, args)
        # Single commit: either every row lands or none does.
        db.commit()
    finally:
        db.close()
# --- Phantom Users ---
def create_phantom_user(email: str) -> dict:
    """Create a phantom user with valid crypto keys for X3DH.

    The username is the part of ``email`` before the first "@". The row is
    written with rsa_public_key = 'PHANTOM' as a marker, together with one
    signed prekey and five one-time prekeys generated here, all in a single
    transaction.

    Returns user dict: {id, username, email, identity_key}.
    """
    local_part = email.partition("@")[0]
    new_id = generate_uuid()
    # Generate real key material so X3DH works on the client side
    identity_private, identity_public = generate_identity_keypair()
    identity_bytes = serialize_ed25519_public(identity_public)
    signed_prekey = generate_signed_prekey(identity_private)
    signed_prekey_bytes = serialize_x25519_public(signed_prekey["public"])
    signed_prekey_sig = signed_prekey["signature"]
    one_time_prekeys = generate_one_time_prekeys(count=5)
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            "INSERT INTO users (id, username, email, rsa_public_key, identity_key) "
            "VALUES (%s, %s, %s, %s, %s)",
            (new_id, local_part, email, "PHANTOM", identity_bytes),
        )
        cur.execute(
            "INSERT INTO signed_prekeys (id, user_id, public_key, signature) VALUES (%s, %s, %s, %s)",
            (signed_prekey["id"], new_id, signed_prekey_bytes, signed_prekey_sig),
        )
        prekey_sql = "INSERT INTO one_time_prekeys (id, user_id, public_key) VALUES (%s, %s, %s)"
        for prekey in one_time_prekeys:
            public_bytes = serialize_x25519_public(prekey["public"])
            cur.execute(prekey_sql, (prekey["id"], new_id, public_bytes))
        db.commit()
        return {
            "id": new_id,
            "username": local_part,
            "email": email,
            "identity_key": identity_bytes,
        }
    finally:
        db.close()
def is_phantom_user(user_id: str) -> bool:
    """Return True when the user exists and its rsa_public_key is the 'PHANTOM' marker."""
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute("SELECT rsa_public_key FROM users WHERE id = %s", (user_id,))
        record = cur.fetchone()
        if record is None:
            return False
        return record[0] == "PHANTOM"
    finally:
        db.close()
def delete_phantom_user(user_id: str):
    """Delete a user row, but only if it still carries the 'PHANTOM' marker.

    Dependent rows (signed_prekeys, one_time_prekeys, conversation_members,
    message_recipients, etc.) are expected to be removed by the schema's
    ON DELETE CASCADE rules, which are defined outside this function.
    """
    query = "DELETE FROM users WHERE id = %s AND rsa_public_key = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (user_id, "PHANTOM"))
        db.commit()
    finally:
        db.close()
def upgrade_phantom_user(phantom_id: str, username: str, rsa_public_key_pem: str,
                         identity_key: bytes) -> str | None:
    """Turn a phantom user into a real user without changing its id.

    Keeping the id preserves all FK references (conversation_members,
    group_invitations, etc.). The phantom's server-generated prekeys are
    deleted; the real user is expected to upload their own on first login.

    Returns ``phantom_id``, or None (after a rollback) when no row with that
    id still carries the 'PHANTOM' marker.
    """
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            "UPDATE users SET username = %s, rsa_public_key = %s, identity_key = %s "
            "WHERE id = %s AND rsa_public_key = 'PHANTOM'",
            (username, rsa_public_key_pem, identity_key, phantom_id),
        )
        if cur.rowcount == 0:
            db.rollback()
            return None
        # Drop the server-generated key material — the real user uploads their own
        for statement in (
            "DELETE FROM signed_prekeys WHERE user_id = %s",
            "DELETE FROM one_time_prekeys WHERE user_id = %s",
        ):
            cur.execute(statement, (phantom_id,))
        db.commit()
        return phantom_id
    finally:
        db.close()
def get_all_phantom_user_ids() -> set[str]:
    """Return the ids of all users marked 'PHANTOM' (for server startup cache)."""
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute("SELECT id FROM users WHERE rsa_public_key = %s", ("PHANTOM",))
        phantom_ids = set()
        for record in cur.fetchall():
            phantom_ids.add(record[0])
        return phantom_ids
    finally:
        db.close()
def cleanup_stale_phantoms(max_age_days: int = 30) -> int:
    """Delete stale phantom users and return how many rows were removed.

    A phantom is stale when it was created more than ``max_age_days`` days
    ago and shares no conversation with any non-phantom user.

    The work is done in two steps (SELECT ids, then DELETE) because MySQL
    error 1093 forbids deleting from a table that the same statement's
    subquery reads. The DELETE re-checks the 'PHANTOM' marker so that a
    phantom upgraded to a real account between the two statements is never
    removed — the same guard ``delete_phantom_user`` applies.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT u.id FROM users u
            WHERE u.rsa_public_key = 'PHANTOM'
            AND u.created_at < DATE_SUB(NOW(), INTERVAL %s DAY)
            AND NOT EXISTS (
                SELECT 1 FROM conversation_members cm1
                JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id
                JOIN users u2 ON cm2.user_id = u2.id
                WHERE cm1.user_id = u.id
                AND u2.rsa_public_key != 'PHANTOM'
            )
            """,
            (max_age_days,),
        )
        ids = [row[0] for row in cursor.fetchall()]
        if not ids:
            return 0
        placeholders = ",".join(["%s"] * len(ids))
        cursor.execute(
            # Marker re-check: never delete an account that was upgraded
            # after the SELECT above picked it.
            f"DELETE FROM users WHERE id IN ({placeholders}) AND rsa_public_key = %s",
            (*ids, "PHANTOM"),
        )
        deleted = cursor.rowcount
        conn.commit()
        return deleted
    finally:
        conn.close()
def remove_conversation_member_atomic(conversation_id: str, user_id: str) -> bool:
    """Delete a membership row and report whether one was actually removed.

    The answer comes from the DELETE's own row count rather than a prior
    existence check (M6 TOCTOU fix).
    """
    query = "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(query, (conversation_id, user_id))
        db.commit()
        was_member = cur.rowcount > 0
        return was_member
    finally:
        db.close()
def get_stale_uploads(max_age_seconds: int = 3600) -> list[dict]:
    """List uploads that are still incomplete after ``max_age_seconds``.

    Returns a list of ``{"file_id": ...}`` dicts for ``image_uploads`` rows
    with ``completed = FALSE`` whose ``created_at`` is older than the cutoff.
    """
    query = (
        "SELECT file_id FROM image_uploads "
        "WHERE completed = FALSE AND created_at < DATE_SUB(NOW(), INTERVAL %s SECOND)"
    )
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(query, (max_age_seconds,))
        stale = cur.fetchall()
        return stale
    finally:
        db.close()
# ---------------------------------------------------------------------------
# Metadata retention cleanup
# ---------------------------------------------------------------------------
def cleanup_old_reads(days: int = 90, batch_size: int = 10000) -> int:
    """Purge old read receipts in batches and return the total removed.

    A receipt is removed only when both its ``read_at`` and the ``created_at``
    of the message it belongs to are older than ``days`` days. This prevents
    phantom unreads: get_unread_counts uses the same time window
    (max_age_days), so messages outside the window aren't counted.

    Each batch runs on its own connection and is committed separately; the
    loop stops after the first batch that deletes fewer than ``batch_size``
    rows.
    """
    query = (
        "DELETE FROM message_reads "
        "WHERE read_at < DATE_SUB(NOW(), INTERVAL %s DAY) "
        "AND message_id IN ("
        "  SELECT id FROM messages "
        "  WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY)"
        ") LIMIT %s"
    )
    removed_total = 0
    removed = batch_size  # forces the first pass
    while removed >= batch_size:
        db = get_connection()
        try:
            cur = db.cursor()
            cur.execute(query, (days, days, batch_size))
            removed = cur.rowcount
            db.commit()
            removed_total += removed
        finally:
            db.close()
    return removed_total
def cleanup_old_reactions(days: int = 90, batch_size: int = 10000) -> int:
    """Purge reactions older than ``days`` days in batches; return the total removed.

    Each batch runs on its own connection and is committed separately; the
    loop stops after the first batch that deletes fewer than ``batch_size``
    rows.
    """
    query = (
        "DELETE FROM message_reactions WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY) LIMIT %s"
    )
    removed_total = 0
    removed = batch_size  # forces the first pass
    while removed >= batch_size:
        db = get_connection()
        try:
            cur = db.cursor()
            cur.execute(query, (days, batch_size))
            removed = cur.rowcount
            db.commit()
            removed_total += removed
        finally:
            db.close()
    return removed_total
def cleanup_old_messages(days: int, batch_size: int = 1000) -> tuple[int, list[str]]:
    """Delete messages older than N days in batches.

    message_recipients / message_reads / message_deliveries / message_reactions
    rows go with them via ON DELETE CASCADE.

    A non-positive ``days`` is treated as "retention disabled" and deletes
    nothing: without this guard the cutoff would be now (or in the future)
    and every message in the table would match.

    The cutoff timestamp is resolved once and reused by every query. If each
    statement evaluated ``NOW()`` on its own, a message crossing the
    threshold after the attachment scan would be deleted without its file
    being considered, leaking the blob permanently.

    Returns (deleted_count, orphaned_file_ids) — file_ids whose encrypted
    blobs are no longer referenced by any surviving message. Their
    ``image_uploads`` rows are removed here; the caller is responsible for
    removing the files from the upload directory (db layer does not touch
    the filesystem).
    """
    if days <= 0:
        return 0, []

    # Fix the cutoff, then collect attachment file_ids referenced by the
    # messages that are about to be deleted.
    conn = get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT DATE_SUB(NOW(), INTERVAL %s DAY)", (days,))
        cutoff = cursor.fetchall()[0][0]
        cursor.execute(
            "SELECT DISTINCT image_file_id FROM messages "
            "WHERE created_at < %s "
            "AND image_file_id IS NOT NULL",
            (cutoff,),
        )
        candidate_files = [row[0] for row in cursor.fetchall()]
    finally:
        conn.close()

    # Delete in batches, one short transaction per batch.
    total = 0
    while True:
        conn = get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "DELETE FROM messages WHERE created_at < %s LIMIT %s",
                (cutoff, batch_size),
            )
            count = cursor.rowcount
            conn.commit()
            total += count
            if count < batch_size:
                break
        finally:
            conn.close()

    # A file is orphaned only if no surviving (newer) message still references
    # it (e.g. a forwarded copy)
    orphaned: list[str] = []
    if candidate_files:
        still_referenced: set[str] = set()
        conn = get_connection()
        try:
            cursor = conn.cursor()
            # Chunked to keep the IN (...) lists bounded.
            for i in range(0, len(candidate_files), 500):
                chunk = candidate_files[i:i + 500]
                placeholders = ", ".join(["%s"] * len(chunk))
                cursor.execute(
                    f"SELECT DISTINCT image_file_id FROM messages "
                    f"WHERE image_file_id IN ({placeholders})",
                    chunk,
                )
                still_referenced.update(row[0] for row in cursor.fetchall())
        finally:
            conn.close()
        orphaned = [f for f in candidate_files if f not in still_referenced]

    # Drop the upload bookkeeping rows for files nobody references any more.
    if orphaned:
        conn = get_connection()
        try:
            cursor = conn.cursor()
            for i in range(0, len(orphaned), 500):
                chunk = orphaned[i:i + 500]
                placeholders = ", ".join(["%s"] * len(chunk))
                cursor.execute(
                    f"DELETE FROM image_uploads WHERE file_id IN ({placeholders})",
                    chunk,
                )
            conn.commit()
        finally:
            conn.close()
    return total, orphaned