E2E encrypted chat (X3DH + Double Ratchet, Signal Protocol). Server: asyncio TCP + TLS, MySQL. Clients: PyQt6 GUI + CLI. Secrets (.env, TLS keys, Cloudflare token), runtime data and mobile clients (separate repos) are gitignored. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
1727 lines
58 KiB
Python
1727 lines
58 KiB
Python
"""MySQL database layer for the encrypted chat server."""
|
|
|
|
import os
|
|
import uuid
|
|
|
|
import logging
|
|
import mysql.connector
|
|
from mysql.connector import pooling
|
|
from dotenv import load_dotenv
|
|
|
|
from crypto_utils import (
|
|
generate_identity_keypair,
|
|
serialize_ed25519_public,
|
|
generate_signed_prekey,
|
|
serialize_x25519_public,
|
|
generate_one_time_prekeys,
|
|
)
|
|
|
|
# Populate os.environ from a .env file (if one is found) so the MYSQL_* and
# DB_POOL_SIZE settings read by _get_pool() can come from it.
load_dotenv()

# Sentinel device_id for self-encrypted copies and legacy (pre-multi-device) rows.
# store_message() falls back to it for recipient entries without a device_id,
# and get_messages() matches it alongside the requested device.
SELF_DEVICE_ID = "00000000-0000-0000-0000-000000000000"

_logger = logging.getLogger(__name__)

# Module-wide MySQL connection pool; stays None until _get_pool() builds it.
_pool = None
|
|
|
|
|
|
def _get_pool():
    """Return the module-wide MySQL connection pool, building it on the first call.

    Connection settings come from MYSQL_HOST / MYSQL_PORT / MYSQL_USER /
    MYSQL_PASSWORD / MYSQL_DATABASE, and the pool size from DB_POOL_SIZE.
    TLS is configured only when MYSQL_SSL_CA is non-empty; MYSQL_SSL_CERT and
    MYSQL_SSL_KEY are applied only in that case.
    """
    global _pool
    if _pool is not None:
        return _pool

    size = int(os.getenv("DB_POOL_SIZE", "10"))
    settings = {
        "pool_name": "chat_pool",
        "pool_size": size,
        "pool_reset_session": True,
        "host": os.getenv("MYSQL_HOST", "localhost"),
        "port": int(os.getenv("MYSQL_PORT", "3306")),
        "user": os.getenv("MYSQL_USER", "root"),
        "password": os.getenv("MYSQL_PASSWORD", ""),
        "database": os.getenv("MYSQL_DATABASE", "encrypted_chat"),
    }

    # Optional MySQL TLS (M7): the CA enables it; client cert/key are extras.
    ca_path = os.getenv("MYSQL_SSL_CA", "").strip()
    if ca_path:
        settings["ssl_ca"] = ca_path
        for env_name, option in (("MYSQL_SSL_CERT", "ssl_cert"), ("MYSQL_SSL_KEY", "ssl_key")):
            value = os.getenv(env_name, "").strip()
            if value:
                settings[option] = value
        _logger.info("MySQL TLS enabled (CA: %s)", ca_path)

    _pool = pooling.MySQLConnectionPool(**settings)
    _logger.info("DB connection pool created (size=%d)", size)
    return _pool
|
|
|
|
|
|
def get_connection():
    """Borrow a connection from the shared pool, creating the pool if needed.

    Every caller in this module releases it with close() in a finally block.
    """
    pool = _get_pool()
    return pool.get_connection()
|
|
|
|
|
|
def generate_uuid() -> str:
    """Return a new random (version 4) UUID in its canonical 36-character text form."""
    return f"{uuid.uuid4()}"
|
|
|
|
|
|
# --- Devices ---
|
|
|
|
def create_device(user_id: str, device_name: str | None = None) -> str:
    """Register a new device for ``user_id``.

    Args:
        user_id: Owner of the device.
        device_name: Optional human-readable label.

    Returns:
        The generated device id (UUID string).
    """
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        new_device_id = generate_uuid()
        sql = "INSERT INTO devices (id, user_id, device_name) VALUES (%s, %s, %s)"
        cur.execute(sql, (new_device_id, user_id, device_name))
        cnx.commit()
    finally:
        cnx.close()
    return new_device_id
|
|
|
|
|
|
def get_user_devices(user_id: str) -> list[dict]:
    """List every device registered to ``user_id``, oldest first.

    Each row carries id, user_id, device_name, created_at and last_seen_at.
    """
    sql = (
        "SELECT id, user_id, device_name, created_at, last_seen_at "
        "FROM devices WHERE user_id = %s ORDER BY created_at"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)
        cur.execute(sql, (user_id,))
        rows = cur.fetchall()
    finally:
        cnx.close()
    return rows
|
|
|
|
|
|
def get_device(device_id: str) -> dict | None:
    """Fetch one device row by its id, or None when it does not exist."""
    sql = (
        "SELECT id, user_id, device_name, created_at, last_seen_at "
        "FROM devices WHERE id = %s"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)
        cur.execute(sql, (device_id,))
        device = cur.fetchone()
    finally:
        cnx.close()
    return device
|
|
|
|
|
|
def update_device_last_seen(device_id: str):
    """Set the device's last_seen_at column to the database's NOW()."""
    sql = "UPDATE devices SET last_seen_at = NOW() WHERE id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (device_id,))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def delete_device(device_id: str):
    """Remove a device together with the pre-keys stored for it.

    The schema is expected to CASCADE the device's pre-keys. The signed and
    one-time pre-keys are also deleted explicitly by device_id, and all three
    deletes are committed together.
    """
    statements = (
        "DELETE FROM devices WHERE id = %s",
        "DELETE FROM signed_prekeys WHERE device_id = %s",
        "DELETE FROM one_time_prekeys WHERE device_id = %s",
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        for sql in statements:
            cur.execute(sql, (device_id,))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
# --- Users ---
|
|
|
|
def create_user(username: str, email: str, rsa_public_key_pem: str, identity_key: bytes) -> str:
    """Insert a new user row and return its generated id.

    Args:
        username: Display name.
        email: The user's email address.
        rsa_public_key_pem: RSA public key in PEM text form.
        identity_key: Identity key bytes, stored in users.identity_key.
    """
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        new_user_id = generate_uuid()
        sql = (
            "INSERT INTO users (id, username, email, rsa_public_key, identity_key) "
            "VALUES (%s, %s, %s, %s, %s)"
        )
        cur.execute(sql, (new_user_id, username, email, rsa_public_key_pem, identity_key))
        cnx.commit()
    finally:
        cnx.close()
    return new_user_id
|
|
|
|
|
|
def get_user_by_email(email: str) -> dict | None:
    """Look up a user by email address.

    Returns a dict with id, username, rsa_public_key, email and identity_key,
    or None when no row matches.
    """
    sql = "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE email = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)
        cur.execute(sql, (email,))
        found = cur.fetchone()
    finally:
        cnx.close()
    return found
|
|
|
|
|
|
def get_user_by_id(user_id: str) -> dict | None:
    """Look up a user by primary key.

    Returns a dict with id, username, rsa_public_key, email and identity_key,
    or None when no row matches.
    """
    sql = "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)
        cur.execute(sql, (user_id,))
        found = cur.fetchone()
    finally:
        cnx.close()
    return found
|
|
|
|
|
|
def shares_conversation(user_id_a: str, user_id_b: str) -> bool:
    """Return True when both users are members of at least one common conversation."""
    sql = (
        "SELECT 1 FROM conversation_members cm1 "
        "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id "
        "WHERE cm1.user_id = %s AND cm2.user_id = %s LIMIT 1"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (user_id_a, user_id_b))
        hit = cur.fetchone()
    finally:
        cnx.close()
    return hit is not None
|
|
|
|
|
|
def get_user_contacts(user_id: str) -> list[str]:
    """Return the distinct ids of users who share any conversation with ``user_id``.

    The user themself is excluded.
    """
    sql = (
        "SELECT DISTINCT cm2.user_id "
        "FROM conversation_members cm1 "
        "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id "
        "WHERE cm1.user_id = %s AND cm2.user_id != %s"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (user_id, user_id))
        return [contact for (contact,) in cur.fetchall()]
    finally:
        cnx.close()
|
|
|
|
|
|
def update_user_rsa_key(user_id: str, rsa_public_key_pem: str):
    """Overwrite the user's stored RSA public key (used at login)."""
    sql = "UPDATE users SET rsa_public_key = %s WHERE id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (rsa_public_key_pem, user_id))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def update_username(user_id: str, new_username: str):
    """Change the display name stored for ``user_id``."""
    sql = "UPDATE users SET username = %s WHERE id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (new_username, user_id))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
# --- Pre-keys ---
|
|
|
|
def store_signed_prekey(user_id: str, spk_id: str, public_key: bytes, signature: bytes,
                        device_id: str | None = None):
    """Replace the signed pre-key of one of a user's devices.

    Existing SPK rows for (user_id, device_id) are deleted first. When
    ``device_id`` is falsy, the user's rows with a NULL device_id are deleted
    instead. The new row is then inserted, and both statements commit together.
    """
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        if not device_id:
            cur.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id IS NULL",
                        (user_id,))
        else:
            cur.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id = %s",
                        (user_id, device_id))
        insert_sql = (
            "INSERT INTO signed_prekeys (id, user_id, device_id, public_key, signature) "
            "VALUES (%s, %s, %s, %s, %s)"
        )
        cur.execute(insert_sql, (spk_id, user_id, device_id, public_key, signature))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def get_signed_prekey(user_id: str, device_id: str | None = None) -> dict | None:
    """Return the newest signed pre-key of a user, or None if there is none.

    With ``device_id`` the search is limited to that device. Otherwise the
    newest SPK across all of the user's rows is returned. The row has id,
    public_key, signature, device_id and created_at.
    """
    base = "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys "
    if device_id:
        sql = base + "WHERE user_id = %s AND device_id = %s " "ORDER BY created_at DESC LIMIT 1"
        args = (user_id, device_id)
    else:
        sql = base + "WHERE user_id = %s ORDER BY created_at DESC LIMIT 1"
        args = (user_id,)
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)
        cur.execute(sql, args)
        return cur.fetchone()
    finally:
        cnx.close()
|
|
|
|
|
|
def store_one_time_prekeys(user_id: str, prekeys: list[dict], device_id: str | None = None):
    """Insert a batch of one-time pre-keys for a user (and optionally a device).

    Each entry of ``prekeys`` must provide "id" and "public_key" (bytes). All
    rows are committed in one go after the last insert.
    """
    sql = (
        "INSERT INTO one_time_prekeys (id, user_id, device_id, public_key) "
        "VALUES (%s, %s, %s, %s)"
    )
    # Lazily built so a malformed entry fails at the same point as before.
    rows = ((entry["id"], user_id, device_id, entry["public_key"]) for entry in prekeys)
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        for params in rows:
            cur.execute(sql, params)
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def consume_one_time_prekey(user_id: str, device_id: str | None = None) -> dict | None:
    """Atomically consume one one-time pre-key: SELECT ... FOR UPDATE + DELETE.

    With ``device_id`` only that device's OTPs are considered. Without it, any
    OTP belonging to the user may be taken.

    Returns:
        {id, public_key} of the consumed key, or None if none remain.

    Raises:
        Whatever the database driver raised. The transaction is rolled back
        first, on a best-effort basis.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        conn.start_transaction()
        if device_id:
            cursor.execute(
                "SELECT id, public_key FROM one_time_prekeys "
                "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE",
                (user_id, device_id),
            )
        else:
            cursor.execute(
                "SELECT id, public_key FROM one_time_prekeys "
                "WHERE user_id = %s LIMIT 1 FOR UPDATE",
                (user_id,),
            )
        row = cursor.fetchone()
        if row:
            cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (row["id"],))
        conn.commit()
        return row
    except Exception:
        # A failing rollback (e.g. on a dead connection) must not replace the
        # original error that is re-raised below.
        try:
            conn.rollback()
        except Exception:
            _logger.warning("Rollback failed in consume_one_time_prekey", exc_info=True)
        raise
    finally:
        conn.close()
|
|
|
|
|
|
def count_one_time_prekeys(user_id: str, device_id: str | None = None) -> int:
    """Return how many one-time pre-keys are left for a user.

    When ``device_id`` is given, only that device's keys are counted.
    """
    if device_id:
        sql = "SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s AND device_id = %s"
        args = (user_id, device_id)
    else:
        sql = "SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s"
        args = (user_id,)
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, args)
        count_row = cur.fetchone()
        return count_row[0]
    finally:
        cnx.close()
|
|
|
|
|
|
def get_key_bundle(user_id: str) -> dict | None:
    """Get a complete key bundle for X3DH (single device — legacy compat).

    Returns {identity_key, signed_prekey_id, signed_prekey, spk_signature} plus
    {one_time_prekey_id, one_time_prekey} when an OTP was available. Returns
    None if the user or a signed pre-key does not exist.

    The newest signed pre-key is used. The OTP is taken from the same device as
    that SPK, so the recipient device holds both private keys. An SPK without a
    device_id (legacy row) falls back to any OTP of the user. The OTP is
    consumed atomically inside an explicit transaction.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor(dictionary=True)

        cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,))
        user = cursor.fetchone()
        if not user:
            return None

        cursor.execute(
            "SELECT id, public_key, signature, device_id FROM signed_prekeys WHERE user_id = %s "
            "ORDER BY created_at DESC LIMIT 1",
            (user_id,),
        )
        spk = cursor.fetchone()
        if not spk:
            return None

        # The reads above opened an implicit transaction. Commit it first,
        # otherwise start_transaction() fails with "Transaction already in
        # progress" (H12 atomic OTP consumption).
        conn.commit()
        conn.start_transaction()

        spk_device = spk.get("device_id")
        if spk_device:
            # Pair the OTP with the SPK's device; an OTP from another device
            # would be unusable by the device that owns this SPK.
            cursor.execute(
                "SELECT id, public_key FROM one_time_prekeys "
                "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE",
                (user_id, spk_device),
            )
        else:
            cursor.execute(
                "SELECT id, public_key FROM one_time_prekeys WHERE user_id = %s LIMIT 1 FOR UPDATE",
                (user_id,),
            )
        opk = cursor.fetchone()
        if opk:
            cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],))
        conn.commit()

        result = {
            "identity_key": user["identity_key"],
            "signed_prekey_id": spk["id"],
            "signed_prekey": spk["public_key"],
            "spk_signature": spk["signature"],
        }
        if opk:
            result["one_time_prekey_id"] = opk["id"]
            result["one_time_prekey"] = opk["public_key"]
        return result
    except Exception:
        # Best-effort rollback; keep the original exception as the one raised.
        try:
            conn.rollback()
        except Exception:
            _logger.warning("Rollback failed in get_key_bundle", exc_info=True)
        raise
    finally:
        conn.close()
|
|
|
|
|
|
def get_key_bundles_for_user(user_id: str) -> dict | None:
    """Fetch an X3DH key bundle for every device of a user.

    Returns {"identity_key": ..., "device_bundles": [...]}. Each device bundle
    has device_id, signed_prekey_id, signed_prekey_pub and spk_signature, plus
    opk_id and opk_pub when a one-time pre-key was available. Returns None when
    the user does not exist or has no signed pre-keys.

    Only the newest SPK per device_id is used; empty or NULL device_ids form a
    single legacy group. One OPK per device is consumed (SELECT ... FOR UPDATE
    followed by DELETE), all within one explicit transaction.
    """
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)

        cur.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,))
        owner = cur.fetchone()
        if not owner:
            return None

        cur.execute(
            "SELECT id, public_key, signature, device_id FROM signed_prekeys "
            "WHERE user_id = %s ORDER BY created_at DESC",
            (user_id,),
        )
        spk_rows = cur.fetchall()
        if not spk_rows:
            return None

        # Rows arrive newest-first, so the first row seen per device wins;
        # dict insertion order keeps that first-seen ordering.
        newest_spk: dict = {}
        for row in spk_rows:
            newest_spk.setdefault(row.get("device_id") or "__legacy__", row)

        # The reads above left an implicit transaction open; finish it so an
        # explicit one can be started for the OPK consumption.
        cnx.commit()
        cnx.start_transaction()

        bundles = []
        for spk in newest_spk.values():
            dev = spk.get("device_id")
            if not dev:
                cur.execute(
                    "SELECT id, public_key FROM one_time_prekeys "
                    "WHERE user_id = %s AND device_id IS NULL LIMIT 1 FOR UPDATE",
                    (user_id,),
                )
            else:
                cur.execute(
                    "SELECT id, public_key FROM one_time_prekeys "
                    "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE",
                    (user_id, dev),
                )
            opk = cur.fetchone()

            entry = {
                "device_id": dev,
                "signed_prekey_id": spk["id"],
                "signed_prekey_pub": spk["public_key"],
                "spk_signature": spk["signature"],
            }
            if opk:
                cur.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],))
                entry["opk_id"] = opk["id"]
                entry["opk_pub"] = opk["public_key"]
            bundles.append(entry)

        cnx.commit()
        return {"identity_key": owner["identity_key"], "device_bundles": bundles}
    except Exception:
        # Deliberate best-effort rollback; the original error is re-raised.
        try:
            cnx.rollback()
        except Exception:
            pass
        raise
    finally:
        cnx.close()
|
|
|
|
|
|
# --- Conversations ---
|
|
|
|
def create_conversation(member_user_ids: list[str], joined_at=None, name=None, created_by=None) -> str:
    """Create a conversation and enroll the given users as members.

    Every member gets the same ``joined_at`` value. The conversation row and
    the membership rows are committed together.

    Returns:
        The new conversation id.
    """
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        new_conv_id = generate_uuid()
        cur.execute("INSERT INTO conversations (id, name, created_by) VALUES (%s, %s, %s)",
                    (new_conv_id, name, created_by))
        member_sql = (
            "INSERT INTO conversation_members (conversation_id, user_id, joined_at) "
            "VALUES (%s, %s, %s)"
        )
        for member in member_user_ids:
            cur.execute(member_sql, (new_conv_id, member, joined_at))
        cnx.commit()
    finally:
        cnx.close()
    return new_conv_id
|
|
|
|
|
|
def add_conversation_member(conversation_id: str, user_id: str, joined_at=None):
    """Add a user to a conversation.

    Uses INSERT IGNORE, so a row the database rejects as a duplicate is
    silently skipped.
    """
    sql = (
        "INSERT IGNORE INTO conversation_members (conversation_id, user_id, joined_at) "
        "VALUES (%s, %s, %s)"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (conversation_id, user_id, joined_at))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def remove_conversation_member(conversation_id: str, user_id: str):
    """Delete the membership row linking ``user_id`` to ``conversation_id``."""
    sql = "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (conversation_id, user_id))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def count_conversation_members(conversation_id: str) -> int:
    """Return the number of membership rows for a conversation."""
    sql = "SELECT COUNT(*) FROM conversation_members WHERE conversation_id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (conversation_id,))
        count_row = cur.fetchone()
        return count_row[0]
    finally:
        cnx.close()
|
|
|
|
|
|
def get_conversation_file_ids(conversation_id: str) -> list[str]:
    """Return the file_id of every image_uploads row for the conversation.

    NOTE(review): intended to cover both images and other files, which
    presumably means both kinds are recorded in image_uploads. Confirm.
    """
    sql = "SELECT file_id FROM image_uploads WHERE conversation_id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (conversation_id,))
        return [record[0] for record in cur.fetchall()]
    finally:
        cnx.close()
|
|
|
|
|
|
def delete_conversation(conversation_id: str):
    """Delete the conversation row.

    Dependent members, messages and similar rows are expected to be removed by
    ON DELETE CASCADE in the schema.
    """
    sql = "DELETE FROM conversations WHERE id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (conversation_id,))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def get_conversation_members(conversation_id: str) -> list[dict]:
    """Return id, username, email and identity_key for each member of a conversation."""
    sql = (
        "SELECT u.id, u.username, u.email, u.identity_key FROM conversation_members cm "
        "JOIN users u ON cm.user_id = u.id "
        "WHERE cm.conversation_id = %s"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)
        cur.execute(sql, (conversation_id,))
        members = cur.fetchall()
    finally:
        cnx.close()
    return members
|
|
|
|
|
|
def find_direct_conversation(user_id_a: str, user_id_b: str) -> str | None:
    """Return the id of a conversation consisting of exactly these two users, or None.

    A conversation qualifies when both users are members and it has exactly two
    members in total. Only one match is returned (LIMIT 1, no ORDER BY).
    NOTE(review): the conversation name is not checked, so a two-member named
    group would also match. Confirm this is intended.
    """
    sql = (
        "SELECT cm1.conversation_id FROM conversation_members cm1 "
        "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id "
        "WHERE cm1.user_id = %s AND cm2.user_id = %s "
        "AND (SELECT COUNT(*) FROM conversation_members cm3 "
        "     WHERE cm3.conversation_id = cm1.conversation_id) = 2 "
        "LIMIT 1"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (user_id_a, user_id_b))
        hit = cur.fetchone()
    finally:
        cnx.close()
    return hit[0] if hit else None
|
|
|
|
|
|
def update_conversation_creator(conversation_id: str, new_creator_id: str):
    """Hand the group-creator role to ``new_creator_id`` by rewriting created_by."""
    sql = "UPDATE conversations SET created_by = %s WHERE id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (new_creator_id, conversation_id))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def get_conversation(conversation_id: str) -> dict | None:
    """Fetch id, created_at, name, created_by and avatar_file of a conversation, or None."""
    sql = "SELECT id, created_at, name, created_by, avatar_file FROM conversations WHERE id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)
        cur.execute(sql, (conversation_id,))
        conv = cur.fetchone()
    finally:
        cnx.close()
    return conv
|
|
|
|
|
|
def update_conversation_avatar(conversation_id: str, avatar_file: str):
    """Store ``avatar_file`` as the conversation's avatar."""
    sql = "UPDATE conversations SET avatar_file = %s WHERE id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (avatar_file, conversation_id))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def update_conversation_name(conversation_id: str, name: str):
    """Rename a conversation."""
    sql = "UPDATE conversations SET name = %s WHERE id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (name, conversation_id))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def is_conversation_member(conversation_id: str, user_id: str) -> bool:
    """Return True if a membership row exists for (conversation_id, user_id)."""
    sql = "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (conversation_id, user_id))
        hit = cur.fetchone()
    finally:
        cnx.close()
    return hit is not None
|
|
|
|
|
|
def list_user_conversations(user_id: str) -> list[dict]:
    """List the user's conversations, newest first, each with its member list.

    Every conversation dict gains a "members" key holding dicts with user_id,
    username and email. All members are loaded with a single IN (...) query
    rather than one query per conversation.
    """
    conv_sql = (
        "SELECT c.id, c.created_at, c.name, c.created_by, c.avatar_file FROM conversations c "
        "JOIN conversation_members cm ON c.id = cm.conversation_id "
        "WHERE cm.user_id = %s ORDER BY c.created_at DESC"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)
        cur.execute(conv_sql, (user_id,))
        conversations = cur.fetchall()
        if not conversations:
            return conversations

        ids = [conv["id"] for conv in conversations]
        marks = ",".join(["%s"] * len(ids))
        cur.execute(
            f"SELECT cm.conversation_id, u.id AS user_id, u.username, u.email "
            f"FROM conversation_members cm JOIN users u ON cm.user_id = u.id "
            f"WHERE cm.conversation_id IN ({marks})",
            ids,
        )
        grouped: dict[str, list[dict]] = {}
        for member in cur.fetchall():
            owner_conv = member.pop("conversation_id")
            grouped.setdefault(owner_conv, []).append(member)

        for conv in conversations:
            conv["members"] = grouped.get(conv["id"], [])
        return conversations
    finally:
        cnx.close()
|
|
|
|
|
|
# --- Group Invitations ---
|
|
|
|
def create_invitation(conversation_id: str, user_id: str, invited_by: str):
    """Record a pending group invitation for ``user_id`` sent by ``invited_by``.

    Uses INSERT IGNORE, so a row the database rejects as a duplicate is
    silently skipped.
    """
    sql = (
        "INSERT IGNORE INTO group_invitations (id, conversation_id, user_id, invited_by) "
        "VALUES (%s, %s, %s, %s)"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        invitation_id = generate_uuid()
        cur.execute(sql, (invitation_id, conversation_id, user_id, invited_by))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def get_pending_invitations(user_id: str) -> list[dict]:
    """Return the user's pending invitations, newest first.

    Each row has id, conversation_id, invited_by, created_at, plus
    conversation_name and invited_by_username taken from the joined tables.
    """
    sql = (
        "SELECT gi.id, gi.conversation_id, gi.invited_by, gi.created_at, "
        "c.name AS conversation_name, u.username AS invited_by_username "
        "FROM group_invitations gi "
        "JOIN conversations c ON gi.conversation_id = c.id "
        "JOIN users u ON gi.invited_by = u.id "
        "WHERE gi.user_id = %s "
        "ORDER BY gi.created_at DESC"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)
        cur.execute(sql, (user_id,))
        invitations = cur.fetchall()
    finally:
        cnx.close()
    return invitations
|
|
|
|
|
|
def delete_invitation(conversation_id: str, user_id: str):
    """Remove the invitation rows for (conversation_id, user_id)."""
    sql = "DELETE FROM group_invitations WHERE conversation_id = %s AND user_id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (conversation_id, user_id))
        cnx.commit()
    finally:
        cnx.close()
|
|
|
|
|
|
def has_pending_invitation(conversation_id: str, user_id: str) -> bool:
    """Return True if an invitation exists for ``user_id`` to ``conversation_id``."""
    sql = "SELECT 1 FROM group_invitations WHERE conversation_id = %s AND user_id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (conversation_id, user_id))
        hit = cur.fetchone()
    finally:
        cnx.close()
    return hit is not None
|
|
|
|
|
|
# --- Messages ---
|
|
|
|
def store_message(
    conversation_id: str,
    sender_id: str,
    ratchet_header: bytes,
    recipients: list[dict],
    x3dh_header: bytes | None = None,
    sender_chain_id: bytes | None = None,
    sender_chain_n: int | None = None,
    image_file_id: str | None = None,
    sender_device_id: str | None = None,
) -> tuple[str, str | None]:
    """Store an encrypted message with per-recipient ciphertext.

    recipients: [{user_id, encrypted_content (bytes), nonce (bytes),
                  device_id (str, optional), ratchet_header (bytes, optional),
                  x3dh_header (bytes, optional)}]
    A recipient whose device_id is missing, None or empty is stored under
    SELF_DEVICE_ID, the sentinel for self-encrypted / legacy rows.

    The message row and all recipient rows are committed together. On any
    error the transaction is rolled back and the exception re-raised.

    Returns:
        (message_id, created_at). created_at is the ISO-8601 string of the
        stored row's timestamp, or None if the row could not be read back.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        msg_id = generate_uuid()
        try:
            cursor.execute(
                "INSERT INTO messages (id, conversation_id, sender_id, sender_device_id, "
                "ratchet_header, x3dh_header, sender_chain_id, sender_chain_n, image_file_id) "
                "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)",
                (msg_id, conversation_id, sender_id, sender_device_id, ratchet_header,
                 x3dh_header, sender_chain_id, sender_chain_n, image_file_id),
            )
            for r in recipients:
                # `or` rather than a .get() default, so an explicit None/"" also maps
                # to the sentinel. A NULL device_id would never match the
                # device-filtered read (mr.device_id = ? OR mr.device_id = sentinel).
                device_id = r.get("device_id") or SELF_DEVICE_ID
                cursor.execute(
                    "INSERT INTO message_recipients (message_id, user_id, device_id, "
                    "encrypted_content, nonce, ratchet_header, x3dh_header) "
                    "VALUES (%s, %s, %s, %s, %s, %s, %s)",
                    (msg_id, r["user_id"], device_id, r["encrypted_content"], r["nonce"],
                     r.get("ratchet_header"), r.get("x3dh_header")),
                )
            conn.commit()
        except Exception:
            # Undo a partially inserted message; never mask the original error.
            try:
                conn.rollback()
            except Exception:
                _logger.warning("Rollback failed while storing message %s", msg_id, exc_info=True)
            raise

        # Read back the server-assigned timestamp of the committed row.
        cursor.execute("SELECT created_at FROM messages WHERE id = %s", (msg_id,))
        row = cursor.fetchone()
        created_at = row[0].isoformat() if row else None
        return msg_id, created_at
    finally:
        conn.close()
|
|
|
|
|
|
def get_messages(conversation_id: str, user_id: str, limit: int = 50, offset: int = 0,
                 device_id: str | None = None, after_ts: str | None = None) -> list[dict]:
    """Get messages for a user in a conversation, JOINing their per-recipient ciphertext.

    If device_id is set, returns rows where mr.device_id matches OR is the sentinel
    (self-encrypted / legacy). May return duplicate message IDs when both device-specific
    and self-encrypted rows exist — caller should deduplicate (prefer device-specific).

    Only messages created at or after the user's joined_at are returned; a NULL
    joined_at means all messages.

    If after_ts is set, only returns messages created after that timestamp (ISO format).
    Results are ordered ASC when after_ts is used, DESC otherwise.
    """
    # Both variants share the query; only the recipient JOIN differs, gaining
    # a device filter (and its two params) when device_id is given.
    recipient_join = "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s "
    params: list = [user_id]
    if device_id:
        recipient_join += " AND (mr.device_id = %s OR mr.device_id = %s) "
        params += [device_id, SELF_DEVICE_ID]
    params += [user_id, conversation_id]

    where_parts = ["m.conversation_id = %s",
                   "(cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"]
    if after_ts:
        where_parts.append("m.created_at > %s")
        params.append(after_ts)
    where_clause = " AND ".join(where_parts)
    # Only constant fragments are interpolated; all values stay parameterized.
    order = "ASC" if after_ts else "DESC"

    query = (
        "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, "
        "m.ratchet_header, m.x3dh_header, "
        "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, "
        "m.pinned_at, m.pinned_by, "
        "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, "
        "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header "
        "FROM messages m "
        f"{recipient_join}"
        "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s "
        f"WHERE {where_clause} "
        f"ORDER BY m.created_at {order} LIMIT %s OFFSET %s"
    )

    conn = get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        cursor.execute(query, (*params, limit, offset))
        return cursor.fetchall()
    finally:
        conn.close()
|
|
|
|
|
|
def count_messages(conversation_id: str, user_id: str) -> int:
    """Count distinct messages in a conversation that have a recipient row for the user.

    Messages older than the user's joined_at are excluded; a NULL joined_at
    means all messages count.
    """
    sql = (
        "SELECT COUNT(DISTINCT m.id) "
        "FROM messages m "
        "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s "
        "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s "
        "WHERE m.conversation_id = %s AND (cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (user_id, user_id, conversation_id))
        result = cur.fetchone()
        if not result:
            return 0
        return result[0]
    finally:
        cnx.close()
|
|
|
|
|
|
def get_message_conversation(message_id: str) -> str | None:
    """Return the conversation_id a message belongs to, or None if the message is unknown."""
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute("SELECT conversation_id FROM messages WHERE id = %s", (message_id,))
        hit = cur.fetchone()
    finally:
        cnx.close()
    return hit[0] if hit else None
|
|
|
|
|
|
def get_message_sender(message_id: str) -> str | None:
    """Return the sender_id of a message, or None if the message is unknown."""
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute("SELECT sender_id FROM messages WHERE id = %s", (message_id,))
        hit = cur.fetchone()
    finally:
        cnx.close()
    return hit[0] if hit else None
|
|
|
|
|
|
def get_deleted_messages_since(conversation_id: str, user_id: str, since_ts: str) -> list[str]:
    """Return ids of the conversation's messages whose deleted_at is later than ``since_ts``.

    Rows are returned only if ``user_id`` is a member of the conversation
    (enforced through the JOIN).
    """
    sql = (
        "SELECT m.id FROM messages m "
        "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s "
        "WHERE m.conversation_id = %s AND m.deleted_at IS NOT NULL AND m.deleted_at > %s"
    )
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (user_id, conversation_id, since_ts))
        return [record[0] for record in cur.fetchall()]
    finally:
        cnx.close()
|
|
|
|
|
|
# --- Reactions ---
|
|
|
|
# Reaction identifiers the chat recognises. add_reaction() does not check
# against this set itself; presumably callers validate first. Confirm.
ALLOWED_REACTIONS = {"thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"}
|
|
|
|
|
|
def add_reaction(message_id: str, user_id: str, reaction: str) -> tuple[bool, str | None]:
    """Set the user's reaction on a message, replacing any previous one.

    Returns:
        (False, None) if the user already had exactly this reaction (nothing is
        written). Otherwise (True, previous_reaction), where previous_reaction is
        None when the user had not reacted before.
    """
    cnx = get_connection()
    try:
        cur = cnx.cursor(dictionary=True)
        cur.execute(
            "SELECT reaction FROM message_reactions WHERE message_id = %s AND user_id = %s",
            (message_id, user_id),
        )
        existing = cur.fetchone()
        previous = existing["reaction"] if existing else None

        if previous == reaction:
            return False, None  # already same reaction

        if not previous:
            cur.execute(
                "INSERT INTO message_reactions (id, message_id, user_id, reaction) "
                "VALUES (%s, %s, %s, %s)",
                (generate_uuid(), message_id, user_id, reaction),
            )
        else:
            cur.execute(
                "UPDATE message_reactions SET reaction = %s, created_at = CURRENT_TIMESTAMP "
                "WHERE message_id = %s AND user_id = %s",
                (reaction, message_id, user_id),
            )
        cnx.commit()
        return True, previous
    finally:
        cnx.close()
|
|
|
|
|
|
def remove_reaction(message_id: str, user_id: str) -> bool:
    """Delete the user's reaction on a message; True if a row was removed."""
    sql = "DELETE FROM message_reactions WHERE message_id = %s AND user_id = %s"
    cnx = get_connection()
    try:
        cur = cnx.cursor()
        cur.execute(sql, (message_id, user_id))
        cnx.commit()
        removed = cur.rowcount > 0
    finally:
        cnx.close()
    return removed
|
|
|
|
|
|
def get_reactions(message_ids: list[str]) -> dict[str, list[dict]]:
    """Fetch reactions for a batch of messages.

    Returns ``{message_id: [{"user_id", "reaction", "created_at"}, ...]}`` with
    each list ordered by ``created_at``. Messages with no reactions are absent
    from the result. ``created_at`` is rendered via ``isoformat()`` when the
    value supports it, otherwise via ``str()``.
    """
    if not message_ids:
        return {}

    def _stamp(value):
        # Timestamps normally come back as datetime; fall back to str otherwise.
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        marks = ",".join("%s" for _ in message_ids)
        cur.execute(
            f"SELECT message_id, user_id, reaction, created_at "
            f"FROM message_reactions WHERE message_id IN ({marks}) "
            f"ORDER BY created_at",
            tuple(message_ids),
        )
        grouped: dict[str, list[dict]] = {}
        for rec in cur.fetchall():
            grouped.setdefault(rec["message_id"], []).append({
                "user_id": rec["user_id"],
                "reaction": rec["reaction"],
                "created_at": _stamp(rec["created_at"]),
            })
        return grouped
    finally:
        db.close()
|
|
|
|
|
|
# --- Pins ---
|
|
|
|
def pin_message(message_id: str, user_id: str, conversation_id: str) -> bool:
    """Pin *message_id* inside *conversation_id*, recording *user_id* as pinner.

    Only a message that is not already pinned is affected. Returns True when
    a row was updated, False otherwise.
    """
    args = (user_id, message_id, conversation_id)
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            "UPDATE messages SET pinned_at = NOW(), pinned_by = %s "
            "WHERE id = %s AND conversation_id = %s AND pinned_at IS NULL",
            args,
        )
        db.commit()
        return cur.rowcount > 0
    finally:
        db.close()
|
|
|
|
|
|
def unpin_message(message_id: str, conversation_id: str) -> bool:
    """Clear the pin on *message_id* in *conversation_id*.

    Only a currently pinned message is affected. Returns True when a row was
    updated, False otherwise.
    """
    args = (message_id, conversation_id)
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            "UPDATE messages SET pinned_at = NULL, pinned_by = NULL "
            "WHERE id = %s AND conversation_id = %s AND pinned_at IS NOT NULL",
            args,
        )
        db.commit()
        return cur.rowcount > 0
    finally:
        db.close()
|
|
|
|
|
|
def get_pinned_messages(conversation_id: str, user_id: str) -> list[dict]:
    """List pinned, non-deleted messages of a conversation, newest pin first.

    The JOIN on conversation_members yields rows only if *user_id* is a member,
    so non-members get an empty list. ``pinned_at`` and ``created_at`` are
    converted to ISO strings when they are truthy datetime-like values.
    """
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(
            "SELECT m.id AS message_id, m.sender_id, m.pinned_at, m.pinned_by, m.created_at "
            "FROM messages m "
            "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s "
            "WHERE m.conversation_id = %s AND m.pinned_at IS NOT NULL AND m.deleted_at IS NULL "
            "ORDER BY m.pinned_at DESC",
            (user_id, conversation_id),
        )
        pinned = cur.fetchall()
        for entry in pinned:
            for field in ("pinned_at", "created_at"):
                stamp = entry.get(field)
                if stamp and hasattr(stamp, "isoformat"):
                    entry[field] = stamp.isoformat()
        return pinned
    finally:
        db.close()
|
|
|
|
|
|
# --- Group Sender Keys ---
|
|
|
|
def store_sender_key(conversation_id: str, sender_id: str, chain_id: bytes,
                     device_id: str | None = None):
    """Insert or overwrite the group sender-key chain id for one sender device.

    A falsy *device_id* is stored under SELF_DEVICE_ID. Uses REPLACE INTO, so
    an existing row with the same key is replaced.
    """
    dev = device_id if device_id else SELF_DEVICE_ID
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            "REPLACE INTO group_sender_keys (conversation_id, sender_id, device_id, chain_id) "
            "VALUES (%s, %s, %s, %s)",
            (conversation_id, sender_id, dev, chain_id),
        )
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
def get_sender_key(conversation_id: str, sender_id: str,
                   device_id: str | None = None) -> dict | None:
    """Look up the stored sender-key chain for one sender device in a group.

    A falsy *device_id* is looked up as SELF_DEVICE_ID. Returns a dict with
    ``chain_id`` and ``created_at``, or None if no row exists.
    """
    dev = device_id if device_id else SELF_DEVICE_ID
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(
            "SELECT chain_id, created_at FROM group_sender_keys "
            "WHERE conversation_id = %s AND sender_id = %s AND device_id = %s",
            (conversation_id, sender_id, dev),
        )
        found = cur.fetchone()
        return found
    finally:
        db.close()
|
|
|
|
|
|
# --- Read Receipts ---
|
|
|
|
def filter_message_ids_by_conversation(conversation_id: str, message_ids: list[str]) -> list[str]:
    """Keep only the ids from *message_ids* whose message is in *conversation_id*.

    Returns an empty list for empty input without touching the database.
    """
    if not message_ids:
        return []
    marks = ",".join("%s" for _ in message_ids)
    args = tuple(message_ids) + (conversation_id,)
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            f"SELECT id FROM messages WHERE id IN ({marks}) AND conversation_id = %s",
            args,
        )
        return [mid for (mid,) in cur.fetchall()]
    finally:
        db.close()
|
|
|
|
|
|
def mark_messages_read(conversation_id: str, user_id: str, message_ids: list[str]):
    """Record read receipts by *user_id* for the given messages (idempotent).

    Ids that do not belong to *conversation_id* are silently dropped: rows are
    produced by selecting from ``messages`` filtered on the conversation
    (M1 fix). Existing receipts are kept thanks to INSERT IGNORE.
    """
    if not message_ids:
        return
    marks = ",".join("%s" for _ in message_ids)
    args = (user_id,) + tuple(message_ids) + (conversation_id,)
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            f"INSERT IGNORE INTO message_reads (message_id, user_id) "
            f"SELECT m.id, %s FROM messages m "
            f"WHERE m.id IN ({marks}) AND m.conversation_id = %s",
            args,
        )
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
def mark_conversation_read(conversation_id: str, user_id: str) -> int:
    """Mark every unread message in a conversation as read for *user_id*.

    Only messages addressed to the user (present in message_recipients), not
    sent by them, and not soft-deleted are considered. Already-read messages
    are excluded by the anti-join on message_reads.

    Returns the number of receipts actually inserted.

    message_recipients holds one row per recipient *device*, so the JOIN fans
    out to one row per device for the same message. SELECT DISTINCT collapses
    that fan-out, matching get_unread_counts' COUNT(DISTINCT m.id). This spares
    INSERT IGNORE from discarding duplicates as warnings.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute(
            "INSERT IGNORE INTO message_reads (message_id, user_id) "
            "SELECT DISTINCT m.id, %s "
            "FROM messages m "
            "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s "
            "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s "
            "WHERE m.conversation_id = %s AND m.sender_id != %s "
            "AND m.deleted_at IS NULL AND mrd.message_id IS NULL",
            (user_id, user_id, user_id, conversation_id, user_id),
        )
        count = cursor.rowcount
        conn.commit()
        return count
    finally:
        conn.close()
|
|
|
|
|
|
def get_unread_counts(user_id: str, max_age_days: int = 0) -> dict[str, int]:
    """Count unread incoming messages per conversation for *user_id*.

    Returns ``{conversation_id: count}``. Conversations with no unread messages
    are absent from the result.

    max_age_days: when positive, only messages created within that many days
    are counted. Keep it equal to METADATA_RETENTION_DAYS; otherwise receipts
    purged by the read cleanup make old messages reappear as unread.
    """
    windowed = max_age_days > 0
    window_sql = " AND m.created_at >= DATE_SUB(NOW(), INTERVAL %s DAY)" if windowed else ""
    args = [user_id, user_id, user_id]
    if windowed:
        args.append(max_age_days)

    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(
            "SELECT m.conversation_id, COUNT(DISTINCT m.id) AS cnt "
            "FROM messages m "
            "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s "
            "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s "
            "WHERE m.sender_id != %s AND m.deleted_at IS NULL AND mrd.message_id IS NULL"
            f"{window_sql} "
            "GROUP BY m.conversation_id",
            args,
        )
        counts = {}
        for rec in cur.fetchall():
            counts[rec["conversation_id"]] = rec["cnt"]
        return counts
    finally:
        db.close()
|
|
|
|
|
|
def get_message_read_status(message_ids: list[str]) -> dict:
    """Collect read receipts for a batch of messages.

    Returns ``{message_id: [{"user_id", "read_at"}, ...]}``. Messages nobody
    has read are absent. ``read_at`` is rendered via ``isoformat()`` when
    available, otherwise via ``str()``.
    """
    if not message_ids:
        return {}

    def _stamp(value):
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        marks = ",".join("%s" for _ in message_ids)
        cur.execute(
            f"SELECT mr.message_id, mr.user_id, mr.read_at "
            f"FROM message_reads mr "
            f"WHERE mr.message_id IN ({marks})",
            tuple(message_ids),
        )
        receipts: dict = {}
        for rec in cur.fetchall():
            receipts.setdefault(rec["message_id"], []).append({
                "user_id": rec["user_id"],
                "read_at": _stamp(rec["read_at"]),
            })
        return receipts
    finally:
        db.close()
|
|
|
|
|
|
# --- Delivery Receipts ---
|
|
|
|
def mark_messages_delivered(conversation_id: str, user_id: str, message_ids: list[str]):
    """Record delivery receipts by *user_id* for the given messages (idempotent).

    Ids outside *conversation_id* are dropped because rows are selected from
    ``messages`` filtered on the conversation (M1 fix). INSERT IGNORE keeps
    existing receipts untouched.
    """
    if not message_ids:
        return
    marks = ",".join("%s" for _ in message_ids)
    args = (user_id,) + tuple(message_ids) + (conversation_id,)
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            f"INSERT IGNORE INTO message_deliveries (message_id, user_id) "
            f"SELECT m.id, %s FROM messages m "
            f"WHERE m.id IN ({marks}) AND m.conversation_id = %s",
            args,
        )
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
def get_message_delivery_status(message_ids: list[str]) -> dict:
    """Collect delivery receipts for a batch of messages.

    Returns ``{message_id: [{"user_id", "delivered_at"}, ...]}``. Undelivered
    messages are absent. ``delivered_at`` is rendered via ``isoformat()`` when
    available, otherwise via ``str()``.
    """
    if not message_ids:
        return {}

    def _stamp(value):
        return value.isoformat() if hasattr(value, "isoformat") else str(value)

    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        marks = ",".join("%s" for _ in message_ids)
        cur.execute(
            f"SELECT md.message_id, md.user_id, md.delivered_at "
            f"FROM message_deliveries md "
            f"WHERE md.message_id IN ({marks})",
            tuple(message_ids),
        )
        receipts: dict = {}
        for rec in cur.fetchall():
            receipts.setdefault(rec["message_id"], []).append({
                "user_id": rec["user_id"],
                "delivered_at": _stamp(rec["delivered_at"]),
            })
        return receipts
    finally:
        db.close()
|
|
|
|
|
|
# --- Delete ---
|
|
|
|
def soft_delete_message(message_id: str, sender_id: str) -> dict | None:
    """Soft-delete a message on behalf of its sender.

    Sets ``deleted_at`` on the message and blanks ``encrypted_content`` on every
    per-recipient row.

    Returns ``{"image_file_id": ...}`` (value may be None) on success, or None
    if the message does not exist, is already deleted, or belongs to another
    sender.

    The ownership / not-yet-deleted check lives in the UPDATE's WHERE clause
    and success is judged by rowcount. With a separate SELECT followed by an
    unguarded UPDATE, two concurrent deletes could both pass the check, both
    report success (and hand back the image id twice), and overwrite
    ``deleted_at``.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        cursor.execute(
            "UPDATE messages SET deleted_at = NOW() "
            "WHERE id = %s AND sender_id = %s AND deleted_at IS NULL",
            (message_id, sender_id),
        )
        if cursor.rowcount == 0:
            conn.rollback()
            return None

        # The UPDATE above locked this row for our transaction; read the
        # attachment id so the caller can clean up the stored file.
        cursor.execute(
            "SELECT image_file_id FROM messages WHERE id = %s",
            (message_id,),
        )
        row = cursor.fetchone()

        # Clear per-recipient ciphertext
        cursor.execute(
            "UPDATE message_recipients SET encrypted_content = %s WHERE message_id = %s",
            (b"", message_id),
        )
        conn.commit()
        return {"image_file_id": row.get("image_file_id") if row else None}
    finally:
        conn.close()
|
|
|
|
|
|
def set_message_image_file_id(message_id: str, file_id: str):
    """Attach uploaded image *file_id* to message *message_id*."""
    db = get_connection()
    try:
        cur = db.cursor()
        statement = "UPDATE messages SET image_file_id = %s WHERE id = %s"
        cur.execute(statement, (file_id, message_id))
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
# --- Image Uploads ---
|
|
|
|
def create_image_upload(file_id: str, conversation_id: str, uploader_id: str, file_size: int):
    """Register a new image upload row for the given conversation and uploader."""
    record = (file_id, conversation_id, uploader_id, file_size)
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            "INSERT INTO image_uploads (file_id, conversation_id, uploader_id, file_size) "
            "VALUES (%s, %s, %s, %s)",
            record,
        )
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
def complete_image_upload(file_id: str):
    """Flag the upload identified by *file_id* as completed."""
    db = get_connection()
    try:
        cur = db.cursor()
        statement = "UPDATE image_uploads SET completed = TRUE WHERE file_id = %s"
        cur.execute(statement, (file_id,))
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
def get_image_upload(file_id: str) -> dict | None:
    """Return the image_uploads row for *file_id* as a dict, or None if absent."""
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(
            "SELECT file_id, conversation_id, uploader_id, file_size, completed, created_at "
            "FROM image_uploads WHERE file_id = %s",
            (file_id,),
        )
        upload = cur.fetchone()
        return upload
    finally:
        db.close()
|
|
|
|
|
|
def delete_image_upload(file_id: str):
    """Remove the image_uploads row for *file_id* (no-op if it does not exist)."""
    statement = "DELETE FROM image_uploads WHERE file_id = %s"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(statement, (file_id,))
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
# --- User Profiles ---
|
|
|
|
def create_default_profile(user_id: str):
    """Ensure a user_profiles row exists for *user_id*, using column defaults.

    INSERT IGNORE makes this a no-op if the profile already exists.
    """
    statement = "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)"
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(statement, (user_id,))
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
def get_user_profile(user_id: str, viewer_id: str | None = None) -> dict | None:
    """Fetch a user's account info merged with their (optional) profile row.

    If *viewer_id* is given and differs from *user_id*, each of email, phone
    and location is replaced by None unless its ``*_visible`` flag is truthy.
    A missing profile row therefore hides all three. Returns None for an
    unknown user.
    """
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(
            "SELECT u.id AS user_id, u.username, u.email, u.created_at, "
            "p.phone, p.phone_visible, p.email_visible, p.location, "
            "p.location_visible, p.avatar_file, p.updated_at "
            "FROM users u LEFT JOIN user_profiles p ON u.id = p.user_id "
            "WHERE u.id = %s",
            (user_id,),
        )
        profile = cur.fetchone()
        if not profile:
            return None

        viewing_other = bool(viewer_id) and viewer_id != user_id
        if viewing_other:
            # (visibility flag, field it guards)
            for flag, field in (
                ("email_visible", "email"),
                ("phone_visible", "phone"),
                ("location_visible", "location"),
            ):
                if not profile.get(flag):
                    profile[field] = None
        return profile
    finally:
        db.close()
|
|
|
|
|
|
def update_user_profile(user_id: str, **fields):
    """Upsert a subset of profile columns for *user_id*.

    Accepted keyword names: phone, phone_visible, email_visible, location,
    location_visible, avatar_file. Any other keyword is ignored. If no
    accepted keyword is supplied, nothing is done. Column names come only
    from this whitelist, which keeps the dynamic SET clause safe.
    """
    permitted = ("phone", "phone_visible", "email_visible", "location",
                 "location_visible", "avatar_file")
    changes = {name: value for name, value in fields.items() if name in permitted}
    if not changes:
        return

    assignments = ", ".join(f"{name} = %s" for name in changes)
    args = [*changes.values(), user_id]

    db = get_connection()
    try:
        cur = db.cursor()
        # Make sure a row exists so the UPDATE below has something to modify.
        cur.execute(
            "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)",
            (user_id,),
        )
        cur.execute(
            f"UPDATE user_profiles SET {assignments} WHERE user_id = %s",
            args,
        )
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
def batch_reencrypt_messages(user_id: str, updates: list[dict]):
    """Write self-encrypted copies of messages for *user_id*, one row per update.

    Each item of *updates* supplies ``message_id``, ``encrypted_content``
    (bytes) and ``nonce`` (bytes). Rows are keyed on SELF_DEVICE_ID and get
    ratchet_header ``{"self":true}`` with x3dh_header NULL. The
    INSERT ... ON DUPLICATE KEY UPDATE form overwrites an existing self row
    (sent messages) or creates one (received messages). All rows are
    committed together.
    """
    if not updates:
        return

    marker = b'{"self":true}'
    statement = (
        "INSERT INTO message_recipients "
        "(message_id, user_id, device_id, encrypted_content, nonce, ratchet_header, x3dh_header) "
        "VALUES (%s, %s, %s, %s, %s, %s, NULL) "
        "ON DUPLICATE KEY UPDATE encrypted_content = VALUES(encrypted_content), "
        "nonce = VALUES(nonce), ratchet_header = VALUES(ratchet_header), x3dh_header = NULL"
    )

    db = get_connection()
    try:
        cur = db.cursor()
        for item in updates:
            row = (
                item["message_id"], user_id, SELF_DEVICE_ID,
                item["encrypted_content"], item["nonce"], marker,
            )
            cur.execute(statement, row)
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
# --- Phantom Users ---
|
|
|
|
def create_phantom_user(email: str) -> dict:
    """Provision a placeholder ("phantom") account for *email*.

    The account is tagged with rsa_public_key = 'PHANTOM' yet receives genuine
    key material (an identity key, one signed prekey and five one-time
    prekeys) so clients can run X3DH against it. The username is the part of
    *email* before the first '@' (the whole string when there is none).

    Returns ``{"id", "username", "email", "identity_key"}``.
    """
    username = email.partition("@")[0]
    user_id = generate_uuid()

    # Real keys are required: clients perform X3DH against this account.
    identity_priv, identity_pub = generate_identity_keypair()
    identity_key = serialize_ed25519_public(identity_pub)

    signed = generate_signed_prekey(identity_priv)
    signed_pub = serialize_x25519_public(signed["public"])
    signed_sig = signed["signature"]

    one_time = generate_one_time_prekeys(count=5)

    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            "INSERT INTO users (id, username, email, rsa_public_key, identity_key) "
            "VALUES (%s, %s, %s, %s, %s)",
            (user_id, username, email, "PHANTOM", identity_key),
        )
        cur.execute(
            "INSERT INTO signed_prekeys (id, user_id, public_key, signature) VALUES (%s, %s, %s, %s)",
            (signed["id"], user_id, signed_pub, signed_sig),
        )
        for prekey in one_time:
            cur.execute(
                "INSERT INTO one_time_prekeys (id, user_id, public_key) VALUES (%s, %s, %s)",
                (prekey["id"], user_id, serialize_x25519_public(prekey["public"])),
            )
        db.commit()
        return {
            "id": user_id,
            "username": username,
            "email": email,
            "identity_key": identity_key,
        }
    finally:
        db.close()
|
|
|
|
|
|
def is_phantom_user(user_id: str) -> bool:
    """Return True iff *user_id* exists and carries the 'PHANTOM' key marker."""
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute("SELECT rsa_public_key FROM users WHERE id = %s", (user_id,))
        found = cur.fetchone()
        if found is None:
            return False
        return found[0] == "PHANTOM"
    finally:
        db.close()
|
|
|
|
|
|
def delete_phantom_user(user_id: str):
    """Delete *user_id* only if it is still a phantom account.

    Dependent rows (signed_prekeys, one_time_prekeys, conversation_members,
    message_recipients, ...) are removed by ON DELETE CASCADE. A real user
    with this id is left untouched because of the rsa_public_key guard.
    """
    args = (user_id, "PHANTOM")
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            "DELETE FROM users WHERE id = %s AND rsa_public_key = %s",
            args,
        )
        db.commit()
    finally:
        db.close()
|
|
|
|
|
|
def upgrade_phantom_user(phantom_id: str, username: str, rsa_public_key_pem: str,
                         identity_key: bytes) -> str | None:
    """Turn a phantom account into a real one without changing its id.

    Keeping the id preserves every foreign-key reference (conversation_members,
    group_invitations, ...). The server-generated prekeys are deleted; the real
    client uploads its own after logging in. Returns *phantom_id* on success,
    or None (after rolling back) if no phantom with that id exists any more.
    """
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            "UPDATE users SET username = %s, rsa_public_key = %s, identity_key = %s "
            "WHERE id = %s AND rsa_public_key = 'PHANTOM'",
            (username, rsa_public_key_pem, identity_key, phantom_id),
        )
        if not cur.rowcount:
            db.rollback()
            return None

        # Drop the phantom's server-side keys; the real user supplies their own.
        for statement in (
            "DELETE FROM signed_prekeys WHERE user_id = %s",
            "DELETE FROM one_time_prekeys WHERE user_id = %s",
        ):
            cur.execute(statement, (phantom_id,))
        db.commit()
        return phantom_id
    finally:
        db.close()
|
|
|
|
|
|
def get_all_phantom_user_ids() -> set[str]:
    """Return the ids of every phantom account (used to warm a startup cache)."""
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute("SELECT id FROM users WHERE rsa_public_key = %s", ("PHANTOM",))
        rows = cur.fetchall()
        return {uid for (uid,) in rows}
    finally:
        db.close()
|
|
|
|
|
|
def cleanup_stale_phantoms(max_age_days: int = 30) -> int:
    """Delete old phantom users that share no conversation with a real user.

    A phantom qualifies if it was created more than *max_age_days* ago and no
    conversation it belongs to has a non-phantom member. Returns the number
    of user rows deleted; dependent rows go via the users-table cascades.

    Runs in two steps (SELECT candidate ids, then DELETE) because MySQL refuses
    a DELETE on ``users`` whose subquery also reads ``users`` (error 1093).
    Between the two statements, upgrade_phantom_user may convert a candidate
    into a real account in place: same id, rsa_public_key replaced. The DELETE
    therefore re-checks the PHANTOM marker so a user who has just registered
    is never deleted.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        # Two-step: SELECT ids first, then DELETE.
        # MySQL error 1093: can't DELETE from table referenced in subquery.
        cursor.execute("""
            SELECT u.id FROM users u
            WHERE u.rsa_public_key = 'PHANTOM'
              AND u.created_at < DATE_SUB(NOW(), INTERVAL %s DAY)
              AND NOT EXISTS (
                  SELECT 1 FROM conversation_members cm1
                  JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id
                  JOIN users u2 ON cm2.user_id = u2.id
                  WHERE cm1.user_id = u.id
                    AND u2.rsa_public_key != 'PHANTOM'
              )
        """, (max_age_days,))
        ids = [row[0] for row in cursor.fetchall()]
        if not ids:
            return 0

        placeholders = ",".join(["%s"] * len(ids))
        # Re-check the marker: a candidate upgraded since the SELECT is now a
        # real user and must survive.
        cursor.execute(
            f"DELETE FROM users WHERE id IN ({placeholders}) AND rsa_public_key = 'PHANTOM'",
            ids,
        )
        deleted = cursor.rowcount
        conn.commit()
        return deleted
    finally:
        conn.close()
|
|
|
|
|
|
def remove_conversation_member_atomic(conversation_id: str, user_id: str) -> bool:
    """Delete a membership row and report whether one was actually removed.

    Deciding from the DELETE's own rowcount, rather than a prior existence
    check, avoids the check-then-act race (M6 TOCTOU fix).
    """
    db = get_connection()
    try:
        cur = db.cursor()
        cur.execute(
            "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s",
            (conversation_id, user_id),
        )
        db.commit()
        removed = cur.rowcount > 0
        return removed
    finally:
        db.close()
|
|
|
|
|
|
def get_stale_uploads(max_age_seconds: int = 3600) -> list[dict]:
    """List uploads never marked completed and older than *max_age_seconds*.

    Each element is a dict with a single ``file_id`` key.
    """
    db = get_connection()
    try:
        cur = db.cursor(dictionary=True)
        cur.execute(
            "SELECT file_id FROM image_uploads "
            "WHERE completed = FALSE AND created_at < DATE_SUB(NOW(), INTERVAL %s SECOND)",
            (max_age_seconds,),
        )
        stale = cur.fetchall()
        return stale
    finally:
        db.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Metadata retention cleanup
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def cleanup_old_reads(days: int = 90, batch_size: int = 10000) -> int:
    """Purge read receipts older than *days*, *batch_size* rows at a time.

    A receipt is removed only when its message is also older than the
    retention window. get_unread_counts applies the same window through
    max_age_days, so purged receipts never turn old messages back into
    unreads. Each batch uses its own pooled connection and commit. The loop
    ends once a batch deletes fewer than *batch_size* rows. Returns the total
    number of rows deleted.
    """
    statement = (
        "DELETE FROM message_reads "
        "WHERE read_at < DATE_SUB(NOW(), INTERVAL %s DAY) "
        "AND message_id IN ("
        "  SELECT id FROM messages "
        "  WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY)"
        ") LIMIT %s"
    )
    removed = 0
    while True:
        db = get_connection()
        try:
            cur = db.cursor()
            cur.execute(statement, (days, days, batch_size))
            batch = cur.rowcount
            db.commit()
        finally:
            db.close()
        removed += batch
        if batch < batch_size:
            return removed
|
|
|
|
|
|
def cleanup_old_reactions(days: int = 90, batch_size: int = 10000) -> int:
    """Purge reactions created more than *days* ago, in batches.

    Each batch of up to *batch_size* rows uses its own pooled connection and
    commit. The loop stops after a short batch. Returns the total number of
    rows deleted.
    """
    statement = (
        "DELETE FROM message_reactions WHERE created_at < "
        "DATE_SUB(NOW(), INTERVAL %s DAY) LIMIT %s"
    )
    removed = 0
    while True:
        db = get_connection()
        try:
            cur = db.cursor()
            cur.execute(statement, (days, batch_size))
            batch = cur.rowcount
            db.commit()
        finally:
            db.close()
        removed += batch
        if batch < batch_size:
            return removed
|