From 2e7b72307d17c5dd9efa1304ab816e68183b83da Mon Sep 17 00:00:00 2001 From: Filip Date: Thu, 11 Jun 2026 18:22:39 -0400 Subject: [PATCH] =?UTF-8?q?Initial=20commit=20=E2=80=94=20encrypted=20chat?= =?UTF-8?q?=20server=20+=20Python=20clients=20(v0.8.5)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit E2E encrypted chat (X3DH + Double Ratchet, Signal Protocol). Server: asyncio TCP + TLS, MySQL. Clients: PyQt6 GUI + CLI. Secrets (.env, TLS keys, Cloudflare token), runtime data and mobile clients (separate repos) are gitignored. Co-Authored-By: Claude Fable 5 --- .gitignore | 19 + AGENTS.md | 42 + CLAUDE.md | 1136 +++++ Dockerfile | 44 + KEC-18_operational_cost_analysis.md | 289 ++ README.md | 326 ++ SECURITY_AUDIT.md | 491 ++ TODO.md | 138 + certs/README.md | 101 + certs/reload-server.sh | 28 + certs/setup-tls.sh | 108 + chat_core.py | 3907 +++++++++++++++ client.py | 928 ++++ crypto_utils.py | 988 ++++ db.py | 1726 +++++++ docker-compose.yml | 70 + gemini.md | 152 + gui_client.py | 6987 +++++++++++++++++++++++++++ protocol.py | 146 + requirements.txt | 11 + scaling.md | 252 + schema.sql | 189 + server.py | 3204 ++++++++++++ theme.py | 539 +++ 24 files changed, 21821 insertions(+) create mode 100644 .gitignore create mode 100644 AGENTS.md create mode 100644 CLAUDE.md create mode 100644 Dockerfile create mode 100644 KEC-18_operational_cost_analysis.md create mode 100644 README.md create mode 100644 SECURITY_AUDIT.md create mode 100644 TODO.md create mode 100644 certs/README.md create mode 100755 certs/reload-server.sh create mode 100755 certs/setup-tls.sh create mode 100644 chat_core.py create mode 100644 client.py create mode 100644 crypto_utils.py create mode 100644 db.py create mode 100644 docker-compose.yml create mode 100644 gemini.md create mode 100644 gui_client.py create mode 100644 protocol.py create mode 100644 requirements.txt create mode 100644 scaling.md create mode 100644 schema.sql create mode 100644 server.py create mode 100644 theme.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..58ca186 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +# Secrets & runtime config +.env +certs/*.pem +certs/cloudflare.ini + +# Python +.venv/ +__pycache__/ +*.pyc + +# Runtime data +uploads/ + +# Local tooling +.claude/settings.local.json + +# Mobile clients (separate repos) +ios/ +Android/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..fbcf12d --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,42 @@ +# Repository Guidelines + +## Project Structure & Module Organization + +The main Python modules live in the repository root. `server.py` contains the asyncio TCP server, request handlers, rate limiting, and upload flows. `chat_core.py` holds shared client logic, crypto workflows, and local key handling. `client.py` is the CLI, `gui_client.py` is the PyQt6 GUI, `db.py` is the MySQL layer, `protocol.py` defines the newline-delimited JSON protocol, and `crypto_utils.py` contains X3DH, Double Ratchet, Sender Keys, and local encryption helpers. Use `schema.sql` for a clean database bootstrap. Security and architecture notes are tracked in `SECURITY_AUDIT.md`, `README.md`, `scaling.md`, and `CLAUDE.md`. Put new test tooling under `tests/`. Treat `zaloha/` as archive code, not an active source directory. + +## Build, Test, and Development Commands + +Use the project virtualenv and MySQL schema: + +```bash +.venv/bin/pip install -r requirements.txt +mysql -u -p < schema.sql +.venv/bin/python server.py +.venv/bin/python client.py +.venv/bin/python gui_client.py +``` + +For quick validation, run: + +```bash +.venv/bin/python -m py_compile server.py chat_core.py client.py gui_client.py db.py +.venv/bin/python tests/pentest_client.py --server-host --member-email ... --peer-email ... --outsider-email ... +``` + +There is no full `pytest` suite yet; current regression coverage is mainly protocol-level through `tests/pentest_client.py`. + +## Coding Style & Naming Conventions + +Follow existing Python conventions: 4-space indentation, `snake_case` for functions and variables, `PascalCase` for classes, and type hints on new or changed code. Keep handlers non-blocking: DB, file, or SMTP work that can block should be moved behind async helpers or `asyncio.to_thread()`. Reuse central validation helpers instead of duplicating checks, and keep logs free of secrets, emails, or raw user-controlled text where possible. + +## Testing Guidelines + +Add tests in `tests/` with descriptive names. Prefer `test_.py` for focused checks and `_client.py` for protocol or penetration probes. Every security fix should include a regression path that covers malformed input, authorization, replay, rate limiting, or multi-device behavior. + +## Commit & Pull Request Guidelines + +Git history is not available in this workspace snapshot, so use short imperative commit messages. Conventional Commit style is preferred, for example `fix: reject invalid ratchet headers`. PRs should summarize behavior changes, mention schema or `.env` updates, link related issues, and include CLI or GUI evidence for user-visible changes. + +## Security & Configuration Tips + +Do not commit `.env`, TLS private keys, uploaded files, or local key material from `~/.encrypted_chat/`. When testing TLS, remember that `0.0.0.0` is a server bind address, not a valid client hostname. Use a host or IP that matches the certificate SAN or CN. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..1909492 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,1136 @@ +# Encrypted Chat — Project Context + +End-to-end encrypted chat with forward secrecy (X3DH + Double Ratchet, Signal Protocol). +Server stores and relays opaque blobs — never sees plaintext. RSA retained for login only. + +## Files + +| File | Lines | Purpose | +|------|-------|---------| +| `schema.sql` | ~172 | MySQL schema (users, devices, signed_prekeys, one_time_prekeys, conversations, conversation_members, group_invitations, messages, message_recipients, message_reactions, group_sender_keys, message_reads, image_uploads, user_profiles) | +| `db.py` | ~1500 | MySQL CRUD — one connection per call, `dictionary=True` cursors, returns dicts. Includes profile CRUD, `get_user_contacts()`, `update_conversation_creator()`, `get_conversation()`. Phantom user CRUD + `upgrade_phantom_user()`. Invitation CRUD. Group avatar. Device CRUD. Per-device prekey/session management. Reactions CRUD (`add_reaction`, `remove_reaction`, `get_reactions`). Pins CRUD (`pin_message`, `unpin_message`, `get_pinned_messages`). | +| `server.py` | ~2200 | Asyncio TCP server, handler dispatch, rate limiting, real-time notifications via `connected_clients` dict. Profile + avatar handlers. Online/offline status push. Leave group, delete conversation, group invitations, group avatar handlers. Phantom user support. Graceful shutdown. 4 asyncio.Lock guards (H4 fix). Device registration + per-device key bundles + per-device notifications. SPK age reporting in `get_prekey_count`. Reactions, pins, pinned messages handlers. TCP keepalive (SO_KEEPALIVE) + dead writer cleanup. Streaming download (`download_stream` handler). | +| `protocol.py` | ~114 | Newline-delimited JSON protocol, `ProtocolReader`/`ProtocolWriter`, `encode_binary`/`decode_binary` (base64). Constants: `VERSION`, `MAX_MESSAGE_BYTES`, `MAX_IMAGE_BYTES`, `MAX_FILE_BYTES`, `IMAGE_CHUNK_SIZE`. | +| `crypto_utils.py` | ~950 | Ed25519, X25519, AES-256-GCM, HKDF, PBKDF2, X3DH, `DoubleRatchet` (with state snapshot/rollback), `SenderKeyState` (with state snapshot/rollback). RSA for login only. ECP1 password-based key encryption format (600k PBKDF2 iterations). Contact key verification: `compute_fingerprint()`, `format_fingerprint()`, `compute_safety_number()`, `encode_verification_qr()`, `decode_verification_qr()`. Message padding: `pad_plaintext()`, `unpad_plaintext()`. | +| `chat_core.py` | ~2850 | `ChatClient` class — session management, X3DH/ratchet encryption, local key storage, reconnect, profiles, file sharing, leave group, delete conversation, invitations, group avatar. Multi-device: per-device sessions, device_id persistence, device bundle cache. SPK rotation (7-day) with grace period. Reactions, pins, forwarding methods. Contact key verification: TOFU registry, explicit verification, safety numbers, QR code verify. Used by CLI + GUI | +| `client.py` | ~520 | Interactive CLI client — reactions, pin, forward, pinned messages, verify contact, show fingerprint commands | +| `gui_client.py` | ~3600 | PyQt6 GUI — `AsyncBridge` QThread bridges asyncio <-> Qt signals, `MainWindow`, `UserProfileDialog`, `VerificationDialog`, connection indicator + auto-reconnect, online status, file sharing, leave group, unread badges, circular avatars in conv list, online green dot overlay, group invitations UI, delete conversation, group avatar support, message reactions (emoji badges), forwarding with dialog, pin/unpin with indicator, pinned messages list, @mentions autocomplete, contact verification indicators (conv list green checkmark, E2E label status, key change warning dialog) | +| `ios_client/` | ~6200 | Native iOS client (Swift/SwiftUI) — wire-compatible with Python server. 47 Swift files: CryptoKit crypto (AES-GCM, HKDF, Ed25519, X25519), pure Swift GF(2^255-19) field arithmetic for Ed→X conversion, Security.framework RSA-4096, Network.framework TCP+TLS, Signal Protocol (X3DH + Double Ratchet + Sender Keys), SwiftUI views. Uses `project.yml` (XcodeGen). | + +## Architecture & Data Flow + +### Encryption: X3DH + Double Ratchet (Signal Protocol) + +**Keys per user:** +- **RSA-4096** — Login challenge-response only (server stores public key). Password-encrypted with ECP1 format (PBKDF2 600k iterations + AES-256-GCM). +- **Identity Key (IK)** — Ed25519 (signing) + converted to X25519 (for DH in X3DH). Password-encrypted with ECP1 format. +- **Signed Pre-Key (SPK)** — X25519, signed by IK, uploaded to server. **Rotates every 7 days** (M4). Previous SPK kept for grace period (in-flight X3DH). +- **One-Time Pre-Keys (OPK)** — X25519, consumed on X3DH initiation, auto-replenished when count < 20 + +**DM flow:** +1. Alice fetches Bob's per-device key bundles (IK, SPK per device, OPK per device) -> X3DH per device -> shared secret per device +2. Double Ratchet initialized from shared secret — one session per (user, device) pair +3. Each message: symmetric ratchet (HMAC chain) -> message key -> AES-256-GCM +4. Each reply direction change: DH ratchet (new X25519 keypair) -> new root + chain keys +5. Per-device ciphertext — each recipient device gets individually encrypted blob +6. Self-encrypted copy uses SELF_DEVICE_ID sentinel, readable by all own devices + +**Group flow (Sender Keys):** +1. Each sender has own SenderKeyState per group +2. Sender key distributed to members via pairwise Double Ratchet (as control DM with `_sender_key` field) +3. Group messages: symmetric ratchet on sender key -> AES-256-GCM +4. Same ciphertext replicated to all recipients (efficient) + +### Contact Key Verification (Out-of-Band Trust) + +**Problem:** Server stores identity keys but could MITM new sessions by substituting keys. Users need a way to verify keys out-of-band. + +**Design:** Entirely client-side — zero server changes. Server already stores identity keys in `users.identity_key` (32B Ed25519) and returns them via `get_user_info`/`get_key_bundle`. Verification state is local-only (privacy by design — server never learns who verified whom). + +**Trust model:** +``` +First contact → TOFU (Trust On First Use) → "trusted" (key recorded) + ↓ + Out-of-band verify → "verified" (explicit confirmation) + ↓ + Key change detected → "changed" / "changed_verified" (WARNING) + ↓ + Accept new key → "trusted" (verification reset) +``` + +**Verification methods:** +1. **Safety numbers** — 60-digit number (12 groups × 5 digits), deterministic for each pair (lower user_id's fingerprint first). Both users see the same number — compare in person or over trusted channel. +2. **QR codes** — Encode `0x01 + uid_len + uid + identity_key` (70 bytes). One user shows QR, other scans → automatic verification. +3. **Fingerprints** — Per-user 30-digit number (6 groups × 5 digits). Visual comparison. + +**Algorithm (Signal NumericFingerprint compatible):** +- Fingerprint: SHA-512 iterated 5200× on `version(2B) + identity_key(32B) + user_id(UTF-8)`, truncated to 32 bytes +- Display: `int(bytes[i*5:(i+1)*5], big-endian) % 100000`, zero-padded to 5 digits + +**Local storage** (encrypted with `_local_key`, AES-256-GCM): +- `known_identity_keys.bin` — TOFU registry: `{user_id → {identity_key hex, first_seen, last_seen}}` +- `verified_contacts.bin` — Explicit verification: `{user_id → {identity_key hex, verified_at, method}}` + +**Tamper resistance:** Oba soubory jsou šifrovány AES-256-GCM (klíč odvozen z identity key přes HKDF). Na rozdíl od session/sender key souborů (které mají plaintext migration fallback kvůli zpětné kompatibilitě) verifikační soubory **nemají žádný plaintext fallback** — pokud dešifrování selže, vrátí se prázdný dict. Útočník s přístupem k disku (ale bez znalosti hesla/identity key) tedy nemůže: +- Podvrhnout falešný "verified" status (injekce do `verified_contacts.bin`) +- Potlačit TOFU key-change warning (injekce do `known_identity_keys.bin`) +- Nejhorší případ při manipulaci = verifikace se resetuje (prázdný stav), nikdy se nepřijme podvržená hodnota + +**Threat model:** +- **DNS únos / podvržený server (existující kontakt):** TOFU detekuje key change → warning dialog. Útočník nemůže tiše podvrhnout klíč. +- **DNS únos / podvržený server (první kontakt):** TOFU věří prvnímu klíči — MITM úspěšný. Obrana: out-of-band verifikace safety number přes jiný kanál (telefon, osobně). Pokud se čísla neshodnou → odhaleno. +- **Kompromitovaný reálný server:** Server podmění klíč v `get_key_bundle` → TOFU detekuje změnu (stejné jako DNS únos). Server neví kdo je verified (stav je lokální) → nemůže cíleně obejít. +- **Disk access bez hesla:** Verifikační soubory šifrovány AES-256-GCM, žádný plaintext fallback → nelze podvrhnout. Viz tamper resistance výše. +- **TLS je první linie obrany:** S platným certifikátem DNS únos nestačí. Bez TLS / s `TLS_INSECURE` je MITM triviální. Verifikace kontaktů je druhá linie — chrání i při kompromitovaném serveru. + +**Data flow:** +1. `_get_user_info()` calls `check_identity_key()` → records TOFU or detects change +2. Key change → `_key_change_cb` fires → GUI shows warning dialog +3. User opens VerificationDialog → sees safety number + QR → marks verified +4. `_rebuild_conv_list()` queries `get_verification_status()` → shows green checkmark for verified DMs +5. E2E label in chat header shows "Verified" (green) or "Encrypted" (muted) + +**QR code encoding detail:** +- Raw binary payload (`0x01 + uid_len + uid + identity_key`) je před vložením do QR zakódován jako **base64** (ASCII-safe) +- Důvod: QR čtečky (pyzbar/zbar) re-kódují binární data přes UTF-8 → byty > 127 se zkomolí +- Při skenování se base64 dekóduje zpět na raw bytes → `decode_verification_qr()` +- iOS implementace musí použít stejný base64 wrapper (viz iOS spec níže) + +**Self-verification exclusion:** +- Vlastní user_id se nikdy nezobrazuje jako "unverified" — TOFU registr neobsahuje vlastní klíč (ten je na disku) +- GUI: security section v UserProfileDialog se nezobrazuje pro vlastní profil, verified badge v group info přeskakuje vlastní uid + +### Protocol + +Newline-delimited JSON over TCP (optional TLS). Fields: `type`, `status`, `data`, `request_id`. +Binary data encoded as base64 via `encode_binary()`/`decode_binary()`. + +**Request/response pattern:** Client sends `{"type": "...", "request_id": "uuid", ...}`, server responds with same `request_id`. Notifications (push) have no `request_id`. + +### Server notifications (push to connected clients) +- `new_message` — per-recipient ciphertext included +- `messages_read` — conversation_id + user_id + message_ids +- `message_deleted` — message_id + conversation_id +- `conversation_created` — conversation_id, name, created_by, members[] (pushed to added members) +- `member_added` — conversation_id, user_id, username, email (pushed to all members except requester) +- `member_removed` — conversation_id, user_id (pushed to removed member + remaining members) +- `group_invitation` — conversation_id, conversation_name, invited_by, invited_by_username (pushed to invited user) +- `conversation_renamed` — conversation_id, name, renamed_by (pushed to all members except renamer) +- `session_reset` — from_user_id, from_device_id (pushed to peer when session reset requested) +- `message_reacted` — message_id, conversation_id, user_id, username, reaction, action (pushed to members) +- `message_pinned` — message_id, conversation_id, user_id, username (pushed to members) +- `message_unpinned` — message_id, conversation_id, user_id, username (pushed to members) +- `user_online` — user_id (pushed to contacts when user connects) +- `user_offline` — user_id (pushed to contacts when user's last connection drops) +- `online_users` — user_ids[] (sent to user on login — list of currently online contacts) + +## DB Schema (schema.sql) + +``` +users: id, username, email (UNIQUE), rsa_public_key (TEXT), identity_key (BLOB 32B Ed25519), created_at +devices: id, user_id FK, device_name (nullable), created_at, last_seen_at +signed_prekeys: id, user_id FK, device_id (nullable), public_key (BLOB 32B), signature (BLOB 64B), created_at +one_time_prekeys: id, user_id FK, device_id (nullable), public_key (BLOB 32B) +conversations: id, created_at, name (nullable), created_by (nullable), avatar_file (nullable) +conversation_members: conversation_id + user_id (composite PK), joined_at +group_invitations: id, conversation_id FK, user_id FK, invited_by FK, created_at, UNIQUE(conversation_id, user_id) +messages: id, conversation_id FK, sender_id FK, sender_device_id (nullable), ratchet_header (BLOB JSON), + x3dh_header (BLOB JSON nullable), sender_chain_id (BLOB nullable), sender_chain_n (INT nullable), + created_at, deleted_at, image_file_id, pinned_at (nullable), pinned_by (nullable) +message_recipients: message_id + user_id + device_id (composite PK), encrypted_content (BLOB), nonce (BLOB), + ratchet_header (BLOB nullable), x3dh_header (BLOB nullable) +message_reactions: id, message_id FK, user_id FK, reaction VARCHAR(32), created_at, UNIQUE(message_id, user_id, reaction) +group_sender_keys: conversation_id + sender_id + device_id (composite PK), chain_id (BLOB 32B), created_at +message_reads: message_id + user_id (composite PK), read_at +image_uploads: file_id (PK), conversation_id FK, uploader_id FK, file_size, completed, created_at +user_profiles: user_id (PK FK), phone, phone_visible, email_visible, location, location_visible, avatar_file, updated_at +``` + +Constant: `SELF_DEVICE_ID = "00000000-0000-0000-0000-000000000000"` — sentinel for self-encrypted copies and legacy rows. + +Index: `messages(conversation_id, created_at)` for query performance. + +## Server Protocol — All Message Types + +### Pre-login (no session required) +| Type | Handler | Purpose | +|------|---------|---------| +| `register` | `handle_register_start` | Start registration (username, email, public_key, identity_key) | +| `register_confirm` | `handle_register_confirm` | Confirm with 6-digit code | +| `login_start` | `handle_login_start` | Get RSA challenge | +| `login_finish` | `handle_login_finish` | Respond with RSA signature -> session. Client sends `client_version`, server returns `server_version` in response. Also sends `online_users` and `user_online` notifications. | +| `get_user_info` | `handle_get_user_info` | Get user info + identity_key (by email or user_id) | +| `pairing_start` | `handle_pairing_start` | New device starts pairing (gets 8-digit code) | +| `pairing_poll` | `handle_pairing_poll` | New device polls for key payload | + +### Post-login (session required) +| Type | Handler | Purpose | +|------|---------|---------| +| `upload_prekeys` | `handle_upload_prekeys` | Upload SPK + batch of OPKs (server verifies SPK signature) | +| `get_key_bundle` | `handle_get_key_bundle` | Fetch key bundle for X3DH (consumes one OPK) | +| `get_prekey_count` | `handle_get_prekey_count` | Check remaining OPK count + SPK age (`spk_created_at`) for rotation | +| `ensure_prekeys` | `handle_ensure_prekeys` | Combined get_prekey_count + upload_prekeys in single round-trip. Returns count + spk_created_at + upload status. | +| `create_conversation` | `handle_create_conversation` | Create conversation — DMs auto-add both; groups add creator only + create invitations for others | +| `find_conversation` | `handle_find_conversation` | Find existing DM by email | +| `add_member` | `handle_add_member` | Create invitation for user to join group (was: direct add) | +| `remove_member` | `handle_remove_member` | Remove member (creator only) | +| `leave_group` | `handle_leave_group` | Voluntarily leave a group (transfers creator if needed, blocks DM leave) | +| `rename_conversation` | `handle_rename_conversation` | Rename group conversation (creator only, max 100 chars), pushes `conversation_renamed` to members | +| `delete_conversation` | `handle_delete_conversation` | Delete conversation — DMs: remove self; groups: creator-only, removes all members + files | +| `accept_invitation` | `handle_accept_invitation` | Accept pending group invitation → add to members, notify others | +| `decline_invitation` | `handle_decline_invitation` | Decline pending group invitation | +| `list_invitations` | `handle_list_invitations` | List user's pending invitations (with conv name + inviter username) | +| `list_conversations` | `handle_list_conversations` | List all user's conversations (includes avatar_file) | +| `send_message` | `handle_send_message` | Send encrypted message (ratchet_header + recipients[]) | +| `get_messages` | `handle_get_messages` | Get messages (returns per-user ciphertext, JOINs message_recipients). Supports `after_ts` for incremental sync. ROW_NUMBER dedup when both device-specific and SELF_DEVICE_ID rows exist. | +| `mark_read` | `handle_mark_read` | Mark messages as read | +| `delete_message` | `handle_delete_message` | Soft-delete message (sender only) | +| `rotate_keys` | `handle_rotate_keys` | Rotate RSA login key, disconnect other sessions | +| `pairing_claim` | `handle_pairing_claim` | Authorized device claims pairing code | +| `pairing_send` | `handle_pairing_send` | Authorized device sends encrypted key payload | +| `upload_image_start/chunk/end` | Image/file upload | Chunked encrypted upload (256KB chunks). `file_type` param: `"image"` (5MB limit) or `"file"` (50MB limit). | +| `download_image` | `handle_download_image` | Legacy chunked download with offset (one chunk per request, 2 DB queries per chunk) | +| `download_stream` | `handle_download_stream` | Streaming download: single request → server sends all chunks with same `request_id` + incremental `seq`. One DB auth check via `_validate_download()`, no per-chunk round-trip. Response fields: `file_id`, `data` (base64), `offset`, `seq`, `done`, `total_size`. | +| `get_profile` | `handle_get_profile` | Get user profile (respects visibility for other users) | +| `update_profile` | `handle_update_profile` | Update own profile (phone, location, visibility toggles) | +| `update_avatar` | `handle_update_avatar` | Upload user avatar (base64, max 2MB, JPEG/PNG) | +| `get_avatar` | `handle_get_avatar` | Download user's avatar | +| `update_group_avatar` | `handle_update_group_avatar` | Upload group avatar (base64, max 2MB, JPEG/PNG, creator only) | +| `get_group_avatar` | `handle_get_group_avatar` | Download group avatar | +| `get_deleted_since` | `handle_get_deleted_since` | Get message IDs deleted since a given timestamp (for incremental sync) | +| `reencrypt_messages` | `handle_reencrypt_messages` | Batch upsert message history with self-key (max 500/request, for device pairing + received msg self-encryption) | +| `list_devices` | `handle_list_devices` | List all devices for current user | +| `remove_device` | `handle_remove_device` | Remove a device (not current device) | +| `session_reset` | `handle_session_reset` | Notify peer to reset corrupted Double Ratchet session (push `session_reset` to peer) | +| `react_message` | `handle_react_message` | Add/remove emoji reaction on a message. Push `message_reacted` to members. | +| `pin_message` | `handle_pin_message` | Pin/unpin a message. Push `message_pinned`/`message_unpinned` to members. | +| `get_pinned_messages` | `handle_get_pinned_messages` | Get list of pinned messages for a conversation. | + +## Key Classes & Functions + +### crypto_utils.py + +**Password-based key encryption (ECP1 format):** +- `PBKDF2_ITERATIONS = 600_000` — OWASP 2023 compliant +- `_encrypt_private_key(raw_bytes, password) -> bytes` — PBKDF2-HMAC-SHA256 + AES-256-GCM. Format: `_ECP1_MAGIC(4) + salt(16) + nonce(12) + ciphertext_with_tag` +- `_decrypt_private_key(data, password) -> bytes` — Detects ECP1 magic prefix, derives key, decrypts + +**RSA (login only):** `generate_rsa_keypair()`, `serialize_private_key()` (ECP1 with password, PEM without), `serialize_public_key()`, `load_private_key()` (auto-detects ECP1 vs legacy PEM), `load_public_key()`, `rsa_sign()`, `rsa_verify()` + +**AES-256-GCM:** `aes_encrypt(plaintext, key=None) -> (key, nonce, ct, tag)`, `aes_decrypt(key, nonce, ct, tag) -> plaintext` + +**Ed25519:** `generate_identity_keypair()`, `serialize_ed25519_private()` (ECP1 with password, 32-byte raw without), `serialize_ed25519_private_raw()`, `serialize_ed25519_public()`, `load_ed25519_private()` (auto-detects ECP1 vs legacy PEM vs raw), `load_ed25519_public()`, `ed25519_sign()`, `ed25519_verify()` + +**X25519:** `generate_x25519_keypair()`, `serialize_x25519_private()`, `serialize_x25519_public()`, `load_x25519_private()`, `load_x25519_public()`, `x25519_dh()` + +**Key conversion:** `ed25519_private_to_x25519()` (SHA-512 + clamp), `ed25519_public_to_x25519()` (Montgomery u-coordinate) + +**HKDF:** `hkdf_derive()`, `kdf_rk(root_key, dh_output) -> (new_root_key, chain_key)`, `kdf_ck(chain_key) -> (new_chain_key, message_key)` + +**X3DH:** `generate_signed_prekey(identity_private) -> {private, public, signature, id}`, `generate_one_time_prekeys(count=50) -> [{private, public, id}]`, `x3dh_initiate(ik_private_ed, ik_public_remote_ed, spk_remote, spk_signature, opk_remote?) -> (shared_secret, ek_priv, ek_pub)`, `x3dh_respond(ik_private_ed, spk_private, ik_remote_ed, ek_remote, opk_private?) -> shared_secret` + +**DoubleRatchet class:** +- `init_alice(shared_secret, bob_spk_pub)` — initiator, performs first DH ratchet +- `init_bob(shared_secret, spk_pair)` — responder, waits for first message +- `encrypt(plaintext) -> {header: {dh_pub, n, pn}, ciphertext, nonce}` — AAD = serialized header +- `decrypt(header_dict, ciphertext, nonce)` — handles DH ratchet step if new dh_pub, skipped messages. **State snapshot/rollback on failure (M9):** `_snapshot()` captures all mutable state before modifications, `_restore()` rolls back on any exception. +- `_snapshot() -> dict` / `_restore(snap)` — Snapshot: dh_pair, dh_remote, root_key, send/recv chain keys, counters, skipped dict. Used internally by `decrypt()`. +- `export_state() -> bytes` / `import_state(data) -> DoubleRatchet` — JSON serialization + +**SenderKeyState class:** +- `__init__(sender_key=None)` — generates random 32B key if None +- `encrypt(plaintext) -> {chain_id, n, ciphertext, nonce}` — AAD = chain_id + message number +- `decrypt(chain_id_hex, n, ciphertext, nonce)` — fast-forwards chain if needed. **State snapshot/rollback on failure (M9):** snapshots chain_key, n, _known_keys before fast-forward, restores on exception. +- `export_key() -> bytes` — for distribution to group members +- `from_key(exported_key) -> SenderKeyState` — receiver initializes from exported key +- `export_state() / import_state()` — full state persistence + +**Contact Key Verification:** +- `FINGERPRINT_VERSION = 0` — version byte for fingerprint algorithm +- `compute_fingerprint(user_id, identity_key_bytes, iterations=5200) -> bytes` — iterated SHA-512, truncated to 32 bytes. Matches Signal's NumericFingerprint. +- `format_fingerprint(fp_bytes) -> str` — 32 bytes → 6 groups of 5 digits (30 digits), 2 lines +- `compute_safety_number(my_uid, my_ik, their_uid, their_ik) -> str` — 60 digits (12 groups of 5), deterministic ordering (lower uid first), 3 lines of 4 groups +- `encode_verification_qr(user_id, identity_key_bytes) -> bytes` — `0x01 + uid_len(1B) + uid(UTF-8) + ik(32B)` +- `decode_verification_qr(data) -> (user_id, identity_key_bytes)` — inverse of encode + +**Message Padding:** +- `_PAD_MAGIC = b"\x01"` — prefix byte distinguishing padded from legacy unpadded messages +- `_PAD_BUCKETS = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]` — target sizes +- `pad_plaintext(plaintext) -> bytes` — Pad to nearest bucket. Format: `0x01 + plaintext + random_padding + pad_length(4B big-endian)`. +- `unpad_plaintext(data) -> bytes` — Remove padding. Legacy unpadded messages (starting with `{`) returned unchanged. + +### chat_core.py + +**Local key storage** (`~/.encrypted_chat/{email}/`): +``` +private.pem / public.pem — RSA (login, ECP1 format when password-encrypted) +identity_private.bin / identity_public.bin — Ed25519 (ECP1 format when password-encrypted, 32B raw otherwise) +device_id.txt — This device's UUID +spk_private.bin / spk_id.txt — Current signed prekey (AES-256-GCM via local_key) +prev_spk_private.bin / prev_spk_id.txt — Previous SPK for grace period (AES-256-GCM via local_key) +opk_private/{opk_id}.bin — One-time prekeys (AES-256-GCM via local_key) +login_lockout.json — Brute-force lockout state (failed_attempts, locked_until) +sessions/{user_id}_{device_id}.bin — Double Ratchet states (per peer device) +sender_keys/{conv_id}.bin — Own sender key states +sender_keys_recv/{conv_id}_{sender_id}_{device_id}.bin — Received sender keys (per sender device) +known_identity_keys.bin — TOFU registry: {user_id -> {identity_key hex, first_seen, last_seen}} +verified_contacts.bin — Explicit verification: {user_id -> {identity_key hex, verified_at, method}} +``` + +Storage functions: `save_keys()`, `load_keys()`, `_save_identity_keys()`, `_load_identity_keys()`, `_save_spk(local_key=)`, `_load_spk(local_key=)`, `_save_prev_spk(local_key=)`, `_load_prev_spk(local_key=)`, `_save_opk_private(local_key=)`, `_load_opk_private(local_key=)`, `_delete_opk_private()`, `_save_session()`, `_load_session()`, `_save_sender_key_state()`, `_load_sender_key_state()`, `_save_recv_sender_key()`, `_load_recv_sender_key()`, `_save_known_identity_keys()`, `_load_known_identity_keys()`, `_save_verified_contacts()`, `_load_verified_contacts()` + +Lockout functions: `_check_lockout(email) -> float`, `_record_failed_attempt(email)`, `_clear_lockout(email)`. Constants: `_LOCKOUT_BASE_SECONDS=2`, `_LOCKOUT_MAX_SECONDS=300`. + +**ChatClient attributes:** +- `private_key` / `public_key` — RSA (login) +- `identity_private` / `identity_public` — Ed25519 +- `spk_private` / `spk_id` — Current SPK +- `_prev_spk_private` / `_prev_spk_id` — Previous SPK for grace period (M4) +- `opk_privates: dict[str, X25519PrivateKey]` — OPK private keys by ID +- `device_id: str | None` — this device's UUID (persisted to disk) +- `sessions: dict[str, DoubleRatchet]` — "user_id:device_id" -> ratchet (per peer device) +- `sender_key_states: dict[str, SenderKeyState]` — conv_id -> own sender key +- `recv_sender_keys: dict[str, SenderKeyState]` — "conv_id:sender_id:device_id" -> their key +- `_device_bundle_cache: dict[str, tuple[float, list]]` — user_id -> (timestamp, device_bundles) with 5-min TTL +- `_pending_self_encrypt: list[dict]` — queue of received messages to self-encrypt for multi-device access +- `_user_cache: dict[str, dict]` — user_id -> {identity_key, username, email, identity_key_status} +- `connected: bool` — current connection state +- `_known_identity_keys: dict` — TOFU registry (user_id -> {identity_key hex, first_seen, last_seen}) +- `_verified_contacts: dict` — explicit verification (user_id -> {identity_key hex, verified_at, method}) +- `_key_change_cb: Callable | None` — callback fired on identity key change (GUI wires to warning dialog) + +**Key methods:** +- `register()` — Generates RSA + Ed25519, sends to server +- `confirm_registration()` — Confirms code, uploads prekeys (SPK + 50 OPKs) +- `login()` — Loads keys from disk (including prev_spk for grace period), RSA challenge-response, auto `_ensure_prekeys()` +- `_ensure_prekeys()` — Checks OPK count AND SPK age. Replenishes OPKs if < 20, **rotates SPK if >= 7 days old** (M4). Saves old SPK as grace period before generating new one. +- `_get_device_bundles(peer_user_id)` — Fetches per-device key bundles with 5-min TTL cache +- `_get_or_create_session(peer_user_id, peer_device_id, bundle)` — Loads from memory/disk or creates via X3DH, keyed by "user:device" +- `_process_x3dh_header(sender_id, x3dh_header, sender_device_id, spk_override?)` — Bob side of X3DH. `spk_override` param allows using previous SPK for grace period (M4). +- `send_message(conv_id, text, members, reply_to?)` — Routes to `_send_dm` or `_send_group_message` +- `_send_dm()` — Per-device Double Ratchet (encrypts for each of recipient's devices), self-encrypted copy with SELF_DEVICE_ID +- `_send_group_message()` — Sender Keys, distributes key if new (per-device) +- `_distribute_sender_key()` — Sends sender key as control message via per-device pairwise ratchet, includes sender_device_id +- `_decrypt_dm()` — Handles X3DH header for new sessions, returns None for control messages. On X3DH decrypt failure, retries with previous SPK (M4 grace period). +- `_decrypt_group()` — Uses received sender key chain +- `decrypt_notification()` — Returns None for control messages (sender key distribution) +- `get_messages()` — Cache-first with incremental sync (`after_ts`). Decrypts new messages, self-encrypts received messages for multi-device, syncs deletions via `get_deleted_since`. Returns merged cache + new messages. +- `_build_messages_from_cache()` — Builds sorted message list from cache dict +- `_queue_self_encrypt()` / `_flush_self_encrypt()` — Queue and upload self-encrypted copies of received messages +- `authorize_device()` — Exports RSA + Ed25519 only (simplified for multi-device — no session/sender key transfer) +- `pairing_wait()` — Imports RSA + identity key from paired device (new device generates own SPK + OPKs on login) +- `reconnect()` — Closes connection, re-establishes TCP + RSA login using in-memory keys +- `get_profile(user_id?)` — Gets user profile from server +- `update_profile(**fields)` — Updates own profile (phone, location, visibility) +- `update_avatar(image_data)` — Uploads avatar +- `get_avatar(user_id)` — Downloads avatar bytes +- `send_file(conv_id, file_path, members, reply_to?)` — Encrypt + chunked upload + send message with `file` payload +- `download_file(file_id, file_info)` — Chunked download + AES-GCM decrypt +- `leave_group(conv_id)` — Leave group, clean up local sender keys +- `rename_conversation(conv_id, name)` — Rename group (creator only) +- `delete_conversation(conv_id)` — Delete conversation, clean up local sender keys +- `accept_invitation(conv_id)` — Accept group invitation +- `decline_invitation(conv_id)` — Decline group invitation +- `list_invitations()` — Fetch pending invitations +- `update_group_avatar(conv_id, image_data)` — Upload group avatar +- `get_group_avatar(conv_id)` — Download group avatar +- `search_messages(conv_id, query)` — Search decrypted message cache (client-side only) +- `reset_session(peer_user_id, peer_device_id?)` — Delete local session + notify peer via server +- `handle_session_reset_notification(from_user_id, from_device_id?)` — Handle incoming session reset +- `_load_verification_stores()` — Load TOFU + verified contacts from disk (called on login/registration/pairing) +- `check_identity_key(user_id, ik_bytes) -> str` — TOFU check, returns "new"/"trusted"/"verified"/"changed"/"changed_verified" +- `verify_contact(user_id, ik_bytes, method)` — Mark contact as explicitly verified +- `unverify_contact(user_id)` — Remove explicit verification +- `accept_key_change(user_id, new_ik_bytes)` — Accept changed key, remove old verification +- `get_verification_status(user_id) -> str` — Returns "verified"/"trusted"/"unverified" +- `get_safety_number(peer_user_id) -> str` — Formatted 60-digit safety number +- `get_my_fingerprint() -> str` — Formatted 30-digit own fingerprint +- `get_peer_fingerprint(peer_user_id) -> str` — Formatted 30-digit peer fingerprint +- `get_verification_qr_data() -> bytes` — QR code payload for own identity +- `verify_qr_code(qr_data) -> (ok, user_id, message)` — Decode + verify scanned QR code + +### gui_client.py + +**AsyncBridge (QThread):** Runs asyncio event loop, `schedule(coro)` queues coroutines, pyqtSignals emit results back to Qt main thread. + +**Key signals:** `login_result`, `conversations_loaded`, `messages_loaded`, `message_sent`, `new_notification`, `messages_read_notification`, `message_deleted_notification`, `conversation_updated`, `connection_state_changed`, `profile_loaded`, `profile_updated`, `avatar_loaded`, `online_status_changed`, `online_users_loaded`, `file_sent`, `file_downloaded`, `group_left`, `conversation_deleted`, `invitations_loaded`, `invitation_result`, `invitation_received`, `group_avatar_loaded`, `group_avatar_updated`, `session_reset_notification`, `key_change_warning` + +**MainWindow:** Dark theme (Catppuccin Mocha), conversation list with circular avatars + online green dot overlay + unread count badges + verification checkmark, message bubbles with colored left border, context menu (reply, delete, view image, download file), image thumbnails via QTextDocument resources (`thumb://{file_id}`), file cards with download links (`file://{file_id}`), connection indicator dot (green/red/orange), profile button, attach menu (Image/File), Leave Group button in group info, delete conversation button (trash icon in header), group avatar display + change in group info dialog, invitation list (amber border) above conversation list with right-click accept/decline. E2E label clickable → opens VerificationDialog. Key change warning dialog on identity key change. + +**UserProfileDialog:** View (read-only) and edit (own profile) modes. Fields: avatar (circular crop), username, email, phone, location, visibility toggles. Avatar upload/download. Security section (viewing others): verification status + fingerprint. Opened from "My Profile" button or user info button in group info dialog. + +**VerificationDialog:** Frameless dialog for contact verification. Shows: peer name, verification status, safety number (monospace), QR code image, both fingerprints. Buttons: "Mark as Verified" / "Remove Verification" / "Scan QR Code" (via pyzbar). QR generation via `qrcode` library. + +**Avatar system in conversation list:** +- `_avatar_cache: dict[str, QPixmap]` — user avatars by user_id +- `_group_avatar_cache: dict[str, QPixmap]` — group avatars by conv_id +- `_avatar_requested: set[str]` / `_group_avatar_requested: set[str]` — dedup download requests +- `_make_circular_avatar(pixmap, size=32)` — QPainter circular crop +- `_make_default_avatar(username, size=32)` — colored circle with initial letter (deterministic color from username hash) +- `_add_online_dot(avatar)` — green dot overlay bottom-right +- `_get_conv_avatar(conv)` — returns QIcon (DM: user avatar + online dot; group: group avatar or default) +- Periodic refresh every 2 minutes via `_refresh_timer` / `_on_periodic_refresh()` + +## Important Implementation Details + +### X3DH Header Caching +When `_get_or_create_session()` creates a new session via X3DH, it attaches the X3DH header as `ratchet._x3dh_header`. The next `_send_dm()` reads and deletes it. This ensures the X3DH header is only sent with the first message. + +### Self-Encryption for DMs +Sender uses `derive_self_encryption_key(identity_private)` to encrypt their own copy of sent messages with a static AES key. Uses `SELF_DEVICE_ID` sentinel so all own devices can read it. This allows reading own sent messages when fetching history from any device. **Received messages** are also self-encrypted after decryption (via `_pending_self_encrypt` queue + `_flush_self_encrypt()`), creating SELF_DEVICE_ID copies so other devices of the same user can read them. `batch_reencrypt_messages()` uses INSERT ON DUPLICATE KEY UPDATE (upsert) to handle both cases. + +### Sender Key Distribution as Control Messages +Sender keys are distributed via normal `send_message` protocol (per-device pairwise ratchet). The payload contains `_sender_key: {conv_id, key, sender_device_id}` field. On decryption, `_decrypt_dm()` detects this field, stores the sender key keyed by `"conv_id:sender_id:sender_device_id"`, and returns `None` (not shown to user). + +### Group Messages: Dummy Ratchet Header +Group messages use `{"dh_pub": "00"*32, "n": 0, "pn": 0}` as ratchet_header because the server requires it, but groups use sender keys instead of Double Ratchet. + +### Multi-Device Architecture +Each device has independent Double Ratchet sessions. Sessions are keyed by `"peer_user_id:peer_device_id"`. When sending a DM, the client fetches per-device key bundles via `_get_device_bundles()` and encrypts separately for each device. The server registers devices at login (`handle_login_finish`), assigns device IDs, and routes notifications with `device_entries` arrays (one entry per recipient device). Device IDs are persisted to `~/.encrypted_chat/{email}/device_id.txt`. Old session files (`{user_id}.bin`) are automatically migrated to `{user_id}_{device_id}.bin` on first load. + +### Server Session Model +`connected_clients: dict[str, list[ProtocolWriter]]` — one user can have multiple connections (multi-device). `writer_device_map: dict[int, str]` maps `id(writer)` to `device_id`. Notifications are pushed to all connections except the sender's current one. + +### Device Registration +On `login_finish`, server checks for `device_id` in the request. If present and valid (belongs to user), reuses it. Otherwise creates a new device. Device ID returned in response and stored on client disk. `list_devices` and `remove_device` handlers for device management. + +### Simplified Pairing (Multi-Device) +`authorize_device()` only exports RSA + identity key (no sessions/sender keys). New device generates its own SPK + OPKs on first login, creates independent sessions via X3DH. Old messages readable via self-encryption (shared identity key). `reencrypt_history()` still runs to ensure all messages have self-encrypted copies. + +### Real-time Conversation Notifications +`handle_create_conversation`, `handle_add_member`, `handle_remove_member`, `handle_leave_group`, `handle_delete_conversation`, `handle_accept_invitation` push notifications to affected members via `connected_clients`. Types: `conversation_created`, `member_added`, `member_removed`, `group_invitation`. GUI handles these via `conversation_updated` signal -> refreshes conversation list. + +### Connection State + Auto-Reconnect + TCP Keepalive +`ChatClient.connected` flag tracks TCP connection state. `_background_listener` sets `connected = False` when server closes connection and **fails all pending futures** with `ConnectionError` (prevents `send_and_recv` from hanging forever). `send_and_recv` has a 30s timeout via `asyncio.wait_for` and catches `ConnectionError`/`TimeoutError`. `reconnect()` re-establishes TCP + RSA challenge-response using in-memory keys (no password needed, includes `device_id`). GUI `_notification_loop` detects listener death -> triggers `_auto_reconnect` with exponential backoff (1s->2s->4s->...->30s). Connection indicator dot: green (connected), red (disconnected), orange (reconnecting). + +**TCP Keepalive (OS-level):** Both server and client enable `SO_KEEPALIVE` on the TCP socket with `TCP_KEEPIDLE=25s`, `TCP_KEEPINTVL=10s`, `TCP_KEEPCNT=3`. After 25s of idle the OS sends probe packets every 10s; if 3 probes go unanswered (25+3×10 = 55s) the OS marks the connection dead → `read_message()` returns `None` → auto-reconnect triggers. This prevents silent connection death through NAT/firewalls. + +**Dead writer cleanup:** `_notify_users()` and `_notify_users_individual()` in server.py check `w.is_closing()` before sending and catch exceptions on send. Failed writers are removed from `connected_clients` via `_remove_dead_writer()` instead of being silently ignored. This ensures stale connections don't accumulate and block notification delivery. + +**iOS implementation:** `NWProtocolTCP.Options` — set `keepaliveIdle = 25`, `keepaliveInterval = 10`, `keepaliveCount = 3` and pass to `NWParameters(tls:tcp:)`. No application-level ping/pong — connection is maintained purely via OS-level TCP keepalive. On `NWConnection.stateUpdateHandler` receiving `.failed` or `.waiting`, trigger reconnect with exponential backoff. + +### Server Per-Message Error Handling +Server dispatch loop wraps each handler call in individual try/except. Handler crashes return "Internal server error" response instead of killing the entire connection. Errors logged with `exc_info=True` for full tracebacks. GUI `_do_send_message`/`_do_find_or_create_and_send` catch exceptions and emit error signal (prevents silent hang when send fails). + +### Online/Offline Status +- `db.get_user_contacts(user_id)` returns all user IDs sharing at least one conversation +- On login (`handle_login_finish`): server sends `online_users` list to new user + `user_online` to all contacts (only if user was fully offline before) +- On disconnect (`handle_client` finally block): if last connection drops, server sends `user_offline` to all contacts +- `_background_listener` routes `user_online`, `user_offline`, `online_users` to notification queue +- GUI: `_online_users: set[str]` tracks online users, green dot overlay on circular avatar in DM conversation list + green circle emoji in group info member list + +### Leave Group +- `handle_leave_group` in server.py: validates membership, blocks DM leave (len<=2 and no name), transfers creator to first remaining member if creator leaves, removes member, notifies remaining via `member_removed` +- `ChatClient.leave_group()`: sends request, cleans up local sender key states on success +- GUI: red "Leave Group" button in group info dialog, confirmation dialog, resets view on success + +### Delete Conversation +- **DMs:** Any member can delete. Only removes the deleting user from `conversation_members`. If both users delete, 0 members remain → conversation + files cleaned up. +- **Groups:** Only the creator (admin) can delete. Removes ALL members, cleans up `.enc` files from disk, deletes conversation via CASCADE. +- Server notifies remaining members via `member_removed` push. +- GUI: trash icon button in conversation header. Visible for DMs always, for groups only when user is creator. +- `chat_core.py`: cleans up local sender key states after successful delete. + +### Group Invitations +- **Flow:** `create_conversation` (group) or `add_member` → creates invitation → pushes `group_invitation` notification → invitee sees in invitation list → Accept (adds to members, notifies) / Decline (deletes invitation) +- **DMs are unaffected:** `create_conversation` for DMs still auto-adds both members +- **DB:** `group_invitations` table with UNIQUE(conversation_id, user_id) to prevent duplicates +- **Server:** `handle_accept_invitation` verifies invitation exists, adds member, deletes invitation, notifies existing members via `member_added`. `handle_decline_invitation` just deletes. +- **GUI:** `inv_list` QListWidget (max 120px, amber border) above `conv_list`. Right-click → Accept/Decline. `invitation_received` signal triggers refresh + notification banner. +- **Routing fix (IMPORTANT):** `group_invitation` must be in the notification types list in `chat_core.py:_background_listener` (~line 304). Without it, invitations get routed to `_response_queue` and the GUI never sees them. + +### Group Avatars +- Stored as files in `UPLOAD_DIR/avatars/group_{conv_id}.{ext}` (PNG or JPEG detected from magic bytes) +- `conversations.avatar_file` column stores the filename +- `list_conversations` response includes `avatar_file` so GUI knows which groups have avatars +- GUI: `_group_avatar_cache` dict, `_get_conv_avatar()` returns group avatar icon or default letter circle +- Group Info dialog shows 64px circular avatar + "Change Avatar" button (creator only) +- Periodic refresh every 2 minutes re-downloads all known group avatars + +### File Sharing +- Reuses image upload/download infrastructure (`upload_image_start/chunk/end`, `download_stream`/`download_image`) +- **Chunk size:** `IMAGE_CHUNK_SIZE = 262144` (256 KiB). `MAX_MESSAGE_BYTES = 1048576` (1 MiB) — StreamReader limit raised to accommodate base64-encoded 256KB chunks. +- `upload_image_start` accepts optional `file_type` param: `"image"` (MAX_IMAGE_BYTES=5MB) or `"file"` (MAX_FILE_BYTES=50MB) +- `ChatClient.send_file()`: reads raw file, AES-256-GCM encrypts, pipelined upload (256KB chunks), sends message with `file` field in payload (`{file_id, aes_key, iv, filename, size, mime_type}`) +- **Download:** `_stream_download()` (preferred) sends single `download_stream` request, server streams all chunks back in sequence with same `request_id`. Fallback to `_legacy_download()` (per-chunk `download_image`) for older servers. +- **Media cache:** Decrypted files cached in `~/.encrypted_chat/{email}/media_cache/{file_id}.bin` (chmod 0o600). Cache-first: checked before any server call. Populated by both sender (after upload) and receiver (after download). +- GUI: attach button is dropdown menu (Image / File), file messages render as styled cards with paperclip icon (transparent background, border) and clickable download link (`file://{file_id}`), context menu "Download file" option. `image_download_failed` signal clears pending state on failure. +- Files stored as `.enc` in UPLOAD_DIR, same as images + +### Unread Count Badges +- `_unread_counts: dict[str, int]` replaces old `_unread_convs: set` +- `_on_notification()` increments count per conversation +- `_on_conv_selected()` clears count for selected conversation +- Display: `(3) Username` with bold font, instead of old `● Username` + +### User Profiles +`user_profiles` table separated from `users` (clean separation, users = auth only). Default profile created on registration (`db.create_default_profile`). Visibility rules applied server-side in `db.get_user_profile(viewer_id)`. Avatars stored as files in `UPLOAD_DIR/avatars/{user_id}.{ext}` (not in DB). Format detection from magic bytes (PNG header vs default JPEG). UserProfileDialog shows circular cropped avatar (QPainter). + +### Prekey Replenishment + SPK Rotation +After login, `_ensure_prekeys()` is called as a background task. Checks two things: +1. **OPK count** — if < 20, generates and uploads a new batch of 50 +2. **SPK age** — server returns `spk_created_at` in `get_prekey_count` response. If SPK is >= 7 days old (`SPK_ROTATION_DAYS`), triggers rotation: saves current SPK as `prev_spk_private.bin`/`prev_spk_id.txt` (grace period), generates new SPK, uploads to server. + +### Password-Based Key Encryption (ECP1 Format) — M3 +Private keys (RSA, Ed25519) are encrypted with a custom envelope instead of `BestAvailableEncryption`: +- **Key derivation:** PBKDF2-HMAC-SHA256 with 600,000 iterations (OWASP 2023 compliant) +- **Encryption:** AES-256-GCM with the derived key, magic bytes as AAD +- **Format:** `_ECP1_MAGIC("ECP1", 4B) + salt(16B) + nonce(12B) + ciphertext_with_tag(N+16B)` +- **Backward compatibility:** `load_private_key()` and `load_ed25519_private()` detect ECP1 magic prefix. If absent, fall back to legacy PEM parsing (old `BestAvailableEncryption` format). On next save, files are re-encrypted in ECP1 format. +- **Functions:** `_encrypt_private_key()`, `_decrypt_private_key()` in `crypto_utils.py` +- **Applied to:** `serialize_private_key()` (RSA), `serialize_ed25519_private()` (Ed25519) + +### SPK Rotation (7-Day Cycle) — M4 +Signed Pre-Keys rotate periodically to limit exposure from a compromised SPK: +- **Rotation interval:** `SPK_ROTATION_DAYS = 7` (constant in `chat_core.py`) +- **Trigger:** `_ensure_prekeys()` checks `spk_created_at` from `get_prekey_count` response. If age >= 7 days, calls `_generate_and_upload_prekeys()`. +- **Grace period:** Before generating a new SPK, the current one is saved as `prev_spk_private.bin` / `prev_spk_id.txt`. Loaded on login into `_prev_spk_private` / `_prev_spk_id`. +- **Fallback on decrypt:** When `_decrypt_dm()` processes an X3DH header and decryption fails with the current SPK, it retries with the previous SPK via `_process_x3dh_header(..., spk_override=self._prev_spk_private)`. This handles in-flight X3DH initiated before rotation. +- **Server side:** `get_signed_prekey()` in `db.py` returns `created_at` column. `handle_get_prekey_count` includes `spk_created_at` (ISO format) in response. +- **Server SPK replacement:** `store_signed_prekey()` deletes old SPK and inserts new one — only one active SPK per device on server. + +### Ratchet State Rollback on Decrypt Failure — M9 +Both `DoubleRatchet.decrypt()` and `SenderKeyState.decrypt()` modify internal state (chain keys, counters, DH keys) before attempting AES-GCM decryption. If decryption fails (corrupted data, wrong key, AAD mismatch), the state would be permanently corrupted. + +**DoubleRatchet fix:** +- `_snapshot()` captures all mutable fields: `dh_pair`, `dh_remote`, `root_key`, `send_chain_key`, `recv_chain_key`, `send_n`, `recv_n`, `prev_send_n`, `skipped` dict (shallow copy) +- `decrypt()` takes snapshot before any state modification, wraps the entire DH ratchet + chain advance + AES-GCM decrypt in try/except, calls `_restore()` on failure +- Special case: skipped message decryption (no state modification needed) — if AES-GCM fails, the popped key is restored to `skipped` dict + +**SenderKeyState fix:** +- Before fast-forwarding the chain, snapshots `chain_key`, `n`, `_known_keys` (shallow copy) +- On any exception during fast-forward or AES-GCM decrypt, all three are restored + +### Message Reactions +- `ALLOWED_REACTIONS = {"thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"}` in `db.py` +- `message_reactions` table with UNIQUE(message_id, user_id, reaction) — one reaction type per user per message +- `handle_react_message`: validates UUID, reaction, membership. Adds/removes reaction. Pushes `message_reacted` to all members. +- `get_messages` response includes `reactions: [{user_id, reaction, created_at}]` per message (batch query via `db.get_reactions()`) +- GUI: React submenu in context menu with toggle (add if not present, remove if already reacted). Badges below message text showing emoji + count, highlighted border if own reaction. Real-time update via `_on_reaction_notification`. +- Forwarding uses `forwarded_from` metadata in plaintext payload — no new server protocol needed, just client convention. + +### Pinned Messages +- `messages.pinned_at` (DATETIME) and `messages.pinned_by` (CHAR(36)) columns +- `handle_pin_message`: validates UUID, membership, pins/unpins. Pushes `message_pinned`/`message_unpinned`. +- `handle_get_pinned_messages`: returns pinned message metadata for a conversation +- `get_messages` response includes `pinned_at` and `pinned_by` per message +- GUI: Pin/Unpin in context menu, pin emoji in message header, "Pinned" button in chat header opens dialog with scrollable list (double-click scrolls to message in chat) + +### @Mentions +- Client-side only — no server-side handling or special notifications +- GUI `MentionCompleter`: popup `QListWidget` shown when `@` detected at cursor, filters conversation members by prefix, inserts `@username ` on selection +- Rendering: `re.sub(r'@(\w+)', ...)` highlights mentions in blue bold (`#89b4fa`) +- Triggered from `_on_input_changed()` -> `_check_mention_trigger()` + +### Contact Key Verification (Safety Numbers / Fingerprints / QR Codes) +- **Zero server changes** — entirely client-side. Server already stores identity keys in `users.identity_key`. +- **TOFU (Trust On First Use):** `known_identity_keys.bin` stores first-seen identity key per user. Encrypted with `_local_key` (AES-256-GCM). Loaded on login/registration/pairing. +- **Explicit verification:** `verified_contacts.bin` stores verified contacts with method + timestamp. Encrypted with `_local_key`. +- **Fingerprint algorithm:** Iterated SHA-512 (5200 iterations), seed = `version(2B) + identity_key(32B) + user_id(UTF-8)`. Each iteration: `SHA-512(prev + identity_key)`. Output: first 32 bytes → 6 groups of 5 zero-padded digits. +- **Safety number:** Both users' fingerprints concatenated (lower user_id first → deterministic ordering). 64 bytes → 12 groups of 5 digits, displayed as 3 lines of 4 groups. Both sides see the same number. +- **QR code format:** `0x01 + uid_len(1B) + uid(UTF-8) + identity_key(32B)`. Generated via `qrcode` library, decoded via `pyzbar` (optional). +- **Identity key status:** `check_identity_key()` returns `"new"` | `"trusted"` | `"verified"` | `"changed"` | `"changed_verified"`. Called from `_get_user_info()`, result stored in `_user_cache` as `identity_key_status`. +- **Key change callback:** `_key_change_cb(user_id, username, old_key_hex, was_verified)` fires on key change. GUI wires to `key_change_warning` signal → warning dialog. +- **GUI indicators:** Green checkmark badge in conversation list for verified DMs (`ROLE_VERIFIED` data role, painted by `ConversationDelegate`). E2E label in chat header shows "Verified" (green) or "Encrypted" (muted) — clickable → opens `VerificationDialog`. Green checkmark next to verified members in group info. Security section in `UserProfileDialog` showing fingerprint + status. +- **VerificationDialog:** Frameless dialog showing safety number (monospace), QR code, both fingerprints, verification status. Buttons: "Mark as Verified" / "Remove Verification" / "Scan QR Code". +- **CLI:** Option 20 (Verify contact) — show safety number + mark verified. Option 21 (Show my fingerprint). +- **Tamper resistance:** `_load_known_identity_keys()` a `_load_verified_contacts()` nemají plaintext migration fallback (na rozdíl od sessions/sender keys). Soubory nikdy neexistovaly v nešifrované podobě — feature přidána po implementaci lokálního šifrování. Pokud útočník s přístupem k disku nahradí šifrovaný soubor plaintextovým JSONem, `_decrypt_local()` selže a load vrátí `{}` (prázdný stav). Útočník nemůže podvrhnout falešný verified status ani potlačit key-change warning. +- **iOS implementation spec (pro implementaci v ios_client/):** + +| Položka | Specifikace | +|---------|-------------| +| **Fingerprint algoritmus** | SHA-512 iterated 5200×. Seed = `version(2B big-endian, hodnota 0) + identity_key(32B) + user_id(UTF-8)`. Každá iterace: `SHA-512(prev_result + identity_key)`. Výstup: prvních 32 bytes. | +| **Fingerprint display** | 6 skupin × 5 číslic: `UInt64(bytes[i*5..": {"identity_key": "", "first_seen": "ISO8601", "last_seen": "ISO8601"}}}` | +| **Verified storage** | Keychain nebo šifrovaný soubor. JSON schema: `{"version": 1, "contacts": {"": {"identity_key": "", "verified_at": "ISO8601", "method": "safety_number\|qr_code\|manual"}}}` | +| **Šifrování storage** | AES-256-GCM klíčem z HKDF(identity_key, salt="local_storage_key", info="encrypted_chat_local"). **Žádný plaintext fallback** — pokud decrypt selže, vrátit prázdný dict. | +| **Identity key status** | `checkIdentityKey(userId, ikBytes) → "new" \| "trusted" \| "verified" \| "changed" \| "changed_verified"`. Volat při každém `getUserInfo()`. | +| **Self exclusion** | Vlastní user_id nikdy nezobrazovat jako unverified — přeskočit v conv listu i group info. | +| **UI: Conversation list** | Zelená fajfka (SF Symbol `checkmark.seal.fill`) vedle jména pro verified DM kontakty. | +| **UI: Chat header** | "Verified" (zelená) nebo "Encrypted" (šedá) pod jménem. Tap → otevře VerificationView. | +| **UI: VerificationView** | Safety number (monospace), QR kód (CIQRCodeGenerator), oba fingerprints, status. Tlačítka: "Mark as Verified" / "Remove Verification" / "Scan QR Code". | +| **UI: Key change alert** | `.alert()` s textem "Identity key for [name] has changed!" + "Accept" / "View Details". | +| **UI: Group info** | Zelená fajfka vedle verified členů (ne u sebe). | +| **Cross-platform test** | Python klient a iOS klient musí pro stejný pár (user_id_A, ik_A, user_id_B, ik_B) vypočítat identický safety number. QR vygenerovaný na jedné platformě musí být čitelný na druhé. | + +**iOS: TCP Keepalive + Optimistic Send + Cache-First spec:** + +| Položka | Specifikace | +|---------|-------------| +| **TCP Keepalive** | `NWProtocolTCP.Options()` → `keepaliveIdle = 25`, `keepaliveInterval = 10`, `keepaliveCount = 3` → předat do `NWParameters(tls:tcp:)`. Žádný app-level ping/pong. Při `.failed`/`.waiting` → reconnect s exponenciálním backoffem (1→2→4→...→30s). | +| **Optimistic send** | Po tapnutí Send: (1) vytvořit `Message(isOptimistic: true, id: tempUUID, text: text, sender: me)`, (2) přidat do messages array + zobrazit v UI + scroll to bottom, (3) async `sendMessage()` → při úspěchu nahradit optimistickou zprávu potvrzenou (match přes text+sender, nové `message_id`+`created_at` ze serveru), při chybě odstranit z UI + alert. | +| **Cache-first loading** | Při otevření konverzace: (1) okamžitě `getCachedMessages(convId)` → zobrazit z lokálního úložiště, (2) async `getMessages(convId, afterTs:)` → doplnit nové zprávy. Uživatel vidí zprávy okamžitě, server sync na pozadí. | +| **Fetch dedup** | `inflightFetches: Set` — pokud pro convId už běží fetch, nový nespouštět. Zabraňuje duplicitním requestům při rychlém přepínání. | +| **Dead writer** | Server-side: `_remove_dead_writer()` automaticky odstraní mrtvé writery. iOS klient nemusí nic speciálního. | +| **`[PUSH]` log** | Server loguje `[PUSH] msg=... targets=[uid(Nw)]` — N = počet writerů per příjemce. Pro debug pokud zprávy nedorazí. | + +**iOS: Media Transfer spec (v0.8.5):** + +| Položka | Specifikace | +|---------|-------------| +| **Protokol verze** | `VERSION = "0.8.5"`, `MIN_CLIENT_VERSION = "0.8.5"`. Server odmítá klienty starší než 0.8.5 — `client_version` v `login_finish` requestu musí být >= 0.8.5. | +| **Chunk size** | `IMAGE_CHUNK_SIZE = 262144` (256 KiB). Platí pro upload i download. Starý 32KB chunk size je nekompatibilní — server čte/píše v 256KB blocích. | +| **Buffer limit** | `MAX_MESSAGE_BYTES = 1_048_576` (1 MiB). NWConnection musí zvládnout přijmout JSON zprávu do 1 MiB (base64-encoded 256KB chunk ≈ 341KB + JSON overhead ≈ 343KB, bezpečně pod limitem). Pokud NWConnection parsuje po newline, nastavit buffer >= 1 MiB. | +| **Upload flow** | Beze změny — pipelined `upload_image_start` → N × `upload_image_chunk` (256KB, base64) → `upload_image_end`. Chunk size větší = méně chunků (5MB obrázek = ~20 chunků místo 160). | +| **Download** | iOS používá výhradně legacy `download_image` per-chunk (jeden request na chunk, offset inkrementuje po `len(chunk)`). `download_stream` existuje na serveru, ale iOS ho nepoužívá — actor-based architektura způsobovala 60s timeouty. | +| **Sender media cache** | Po úspěšném `upload_image_end` cachovat **dešifrovaný** obrázek/soubor lokálně: `{cacheDir}/media_cache/{file_id}.bin`. Sender pak při kliknutí na vlastní obrázek nemusí stahovat ze serveru — okamžité zobrazení. | +| **Receiver media cache** | Po úspěšném download + decrypt cachovat výsledek na disk: `{cacheDir}/media_cache/{file_id}.bin`. Další zobrazení = okamžité z cache, žádný server call. | +| **Cache kontrola** | Před zahájením downloadu vždy zkontrolovat `FileManager.default.fileExists(atPath: cachePath)`. Pokud soubor existuje, vrátit `Data(contentsOf:)` bez server volání. | +| **Download failure handling** | Pokud download selže (timeout, error, disconnection), vyčistit pending stav (UI nesmí zůstat zablokované). Další kliknutí na obrázek musí spustit nový download. | +| **Timeout** | Upload chunk future: 30s per chunk. Download (legacy per-chunk): 30s per chunk. | +| **Šifrování souborů** | Beze změny: AES-256-GCM. Upload: `aes_encrypt(raw_data) → (key, iv, ct, tag)`, upload `ct + tag`. Download: stáhnout encrypted blob, `ct = blob[:-16]`, `tag = blob[-16:]`, `aes_decrypt(key, iv, ct, tag)`. Klíč a IV v message payloadu (`image.aes_key`, `image.iv`, base64). | + +**iOS: Send Queue & Background Upload (v0.8.5+):** + +| Položka | Specifikace | +|---------|-------------| +| **Architektura** | Priority-based send queue uvnitř ChatClient actoru. Text zprávy (priority 0) se prokládají mezi upload chunky (priority 1). UI input je vždy aktivní — `isSending` se nenastavuje na `true`. | +| **Optimistický thumbnail** | Před enqueuováním uploadu se vytvoří lokální JPEG thumbnail (max 6KB, progresivní snižování kvality 0.4→0.3→0.2→0.15, fallback 80×80px). Thumbnail se uloží na `Message.optimisticThumbnail` a zobrazí se okamžitě s upload spinnerem. | +| **Text interleaving** | Po každém chunk response se zkontroluje fronta — pokud čeká textová zpráva, odešle se před dalším chunkem. | +| **Upload failure** | Při selhání uploadu se pokračuje dalšími položkami ve frontě. Uživatel dostane info o chybě, ale následné zprávy se odešlou normálně. | +| **Android TODO** | Send queue je implementován pouze v iOS. Android potřebuje port — viz `ARCHITECTURE.md` v Android projektu. | + +### Rate Limits +- Per-IP+email window (60s): register 3/min, login 10/min, send_message 20/min +- Per-connection: 20 req/s — **`upload_image_chunk` is exempt** (a single 5MB image needs ~20 chunks in rapid succession; the upload subsystem has its own guards: per-user upload cap, per-user rate limit on `upload_image_start`, and file-size validation) +- Per-IP: max 10 connections, global max 200 +- Pairing: TTL 120s, max 90 poll attempts, pairing_start 10/min, pairing_poll 120/min, client polls every 2s + +### GUI Font Handling (IMPORTANT) +All widget stylesheet `font-size` declarations use `pt` (not `px`). Using `px` in Qt stylesheets sets `pixelSize` and leaves `pointSize=-1`, which causes `QFont::setPointSize: Point size <= 0` warnings on Windows. Conversion: `pt ~= px * 0.75` at 96 DPI. HTML styles inside QTextBrowser (`_render_single_message_html`) still use `px` — that's fine, QTextBrowser uses its own HTML renderer. Bold fonts for list items use `_bold_font()` helper + `item.setData(FontRole)` to avoid the same issue. + +### Phantom Users (Anti User-Enumeration) +- When a user creates a conversation with an unregistered email, the server creates a "phantom" user with `rsa_public_key = 'PHANTOM'` marker +- Phantom users have real crypto keys (Ed25519 IK, X25519 SPK + 5 OPKs) so X3DH works on the client side +- `handle_find_conversation` and `handle_create_conversation` create phantoms instead of returning "User not found" +- `handle_send_message` skips phantom recipients when storing `message_recipients` — only sender's self-encrypted copy is saved +- `phantom_user_ids: set[str]` in-memory cache loaded at startup from DB, updated on create/delete +- On registration (`handle_register_confirm`): if email belongs to a phantom, the phantom is **upgraded in-place** via `db.upgrade_phantom_user()` — preserves user_id and all FK references (invitations, conversation_members). Phantom's server-generated prekeys are deleted (real user uploads own). +- `handle_create_conversation` (groups) and `handle_add_member` create invitations for phantom users too. Push notifications only sent to non-phantom users. When phantom registers and logs in, they see pending invitations. +- Messages sent to phantom users are NOT stored and NOT recoverable after registration — this is by design (prevents user enumeration, sender sees own messages via self-encryption) +- DB functions: `db.create_phantom_user(email)`, `db.is_phantom_user(user_id)`, `db.delete_phantom_user(user_id)`, `db.upgrade_phantom_user(phantom_id, username, rsa_public_key_pem, identity_key)`, `db.get_all_phantom_user_ids()` + +### Logout/Login Fix +- `_is_logout` flag in MainWindow prevents `closeEvent()` from calling `bridge.stop()` which killed the asyncio loop +- On logout: set `_is_logout = True`, call `bridge.logout()`, then `close()` +- `closeEvent()` only calls `bridge.stop()` if `not self._is_logout` +- This allows `main()` to re-create the login/main windows after logout + +### Server Graceful Shutdown +- SIGINT handler force-closes all writers in `connected_clients` before the asyncio server context manager exits +- Without this, `async with server:` waited forever for `handle_client` loops to finish + +### Version Negotiation +- `VERSION = "0.8.4"` constant in `protocol.py` (shared between client and server) +- `MIN_CLIENT_VERSION = "0.8.3"` — server rejects clients older than this +- Client sends `client_version` in `login_finish` request (both `login()` and `reconnect()`) +- Server logs `client_version`, returns `server_version` in `login_finish` response +- Server startup log includes version: `"Encrypted chat server v0.8.4 listening on ..."` +- iOS client (0.8.3) stays compatible — new notification types (`message_reacted`, `message_pinned`, `message_unpinned`) and payload fields (`reactions`, `pinned_at`, `forwarded_from`) are ignored by older clients + +### Privacy Overlay (Lock Screen) +Anti-forensic privacy feature for PyQt GUI client: +- **Immediate overlay:** On `QEvent.Type.ActivationChange` (window loses focus), dark overlay (`rgba(30,30,46,245)`) with lock icon covers entire window. Protects against Alt+Tab thumbnails, screen sharing, shoulder surfing. +- **Timed lock:** After `_LOCK_TIMEOUT_MS` (30s) unfocused, overlay transitions to locked state — requires password to dismiss. +- **Password verification:** `_on_unlock_attempt()` reads `identity_private.bin` from disk and calls `_decrypt_private_key(data, password)` (ECP1 format: PBKDF2-600k + AES-256-GCM). Successful decryption = correct password. +- **Lock capability detection:** `_lock_capable` flag checks if identity key file starts with `b"ECP1"`. If key is not password-encrypted (legacy/no-password), lock timer never fires (overlay still works as visual privacy screen). +- **Toggle:** Ctrl+Shift+P enables/disables the feature (default: enabled). +- **Notification handling during lock:** `_show_tray_notification()` checks `self._privacy_locked` — tray toasts continue while locked. `_on_notification()` increments unread counts and skips `mark_read` while locked. +- **Components:** `_privacy_overlay` (QWidget), `_lock_input` (QLineEdit password), `_lock_error` (QLabel), `_lock_timer` (QTimer single-shot), `_lock_hint` (QLabel status text). + +### Secure Deletion (Anti-Forensic Wipe) +`_secure_delete(path)` helper in both `chat_core.py` and `server.py`: +- Opens file with `r+b`, overwrites entire content with `os.urandom(size)`, calls `f.flush()` + `os.fsync(f.fileno())`, then `p.unlink()`. +- Fallback: if overwrite fails (permissions, etc.), falls back to standard `p.unlink(missing_ok=True)`. +- **Applied to (chat_core.py):** `_delete_opk_private()`, `_delete_session_file()`, session migration cleanup, sender key migration cleanup, message cache migration (plaintext JSON → encrypted). +- **Applied to (server.py):** Conversation delete (all `.enc`+`.tmp`), message delete (`.enc`), oversized upload chunk cleanup (`.tmp`), incomplete/invalid upload end (`.tmp`), stale upload periodic cleanup (`.enc`+`.tmp`). + +### Metadata Privacy +Four measures to minimize metadata leakage: +- **Message Padding:** `pad_plaintext()`/`unpad_plaintext()` in `crypto_utils.py`. Plaintext padded to nearest bucket size (64B..64KB) before encryption. Format: `0x01 + plaintext + random + length(4B)`. Legacy unpadded messages (prefix `{`) auto-detected by `unpad_plaintext()`. Applied on all 6 send paths + 2 decrypt paths in `chat_core.py`. +- **Log Sanitization:** `_who(session)` returns `u=XXXXXXXX d=YYYYYYYY` (truncated user_id + device_id). Group names, recipient counts, emails, and usernames removed from all server log lines (register, login, DM create, invite, rename, send_message). +- **Metadata Retention:** `db.cleanup_old_reads(days)` and `db.cleanup_old_reactions(days)` delete old interaction data in batches. Default `METADATA_RETENTION_DAYS=90`. Runs every ~1 hour (every 30th cycle of `_periodic_cleanup`). `cleanup_old_reads` joins on `messages.created_at` to only delete reads for old messages. `get_unread_counts(max_age_days)` excludes messages older than retention window — prevents phantom unreads after read cleanup. Indexes: `idx_reads_read_at` on `message_reads.read_at`, `idx_reactions_created_at` on `message_reactions.created_at`. +- **Sender Chain Minimization:** For new group messages, `sender_chain_id`/`sender_chain_n` stored in per-recipient `message_recipients.ratchet_header` instead of `messages` table. Removes persistent sender correlation from message-level DB rows. Server verifies group-only context (dummy `dh_pub` all-zeros) before extracting chain data from per-recipient header — prevents DM injection. Self-copy entries (sender's `user_id == session["user_id"]`) are skipped during chain_meta injection. `_validate_header` now accepts `{"self": true}` for per-recipient ratchet_header (fixes self-copy storage in DB). `handle_get_messages` extracts from both locations (backward compat). Push notifications still include chain data for live decrypt. +- **Architectural limitation (known):** Server still stores `messages.sender_id` and `message_recipients.user_id` — the communication graph (who talks to whom, when) remains visible to the server. Full metadata hiding (e.g. Signal's Sealed Sender, onion routing) is a fundamentally different architecture requiring sender anonymity at the protocol level. Current metadata privacy measures reduce *unnecessary* metadata exposure (message length, log PII, interaction history retention, group chain correlation) but do not hide the communication graph itself. + +## Conventions + +- Server handlers: `handle_(msg, session, writer)` — registered in dispatch table in `handle_client()` +- DB functions: one `get_connection()` per call, `cursor(dictionary=True)`, returns dicts +- Binary data: always base64 in protocol (`encode_binary`/`decode_binary`) +- GUI signals: bridge emits `pyqtSignal`, MainWindow connects in `_connect_signals()` +- Error responses: `{"status": "error", "data": {"message": "..."}}` +- Notification decrypt returning `None` = control message, skip silently +- GUI stylesheet font sizes: always `pt`, never `px` (see Font Handling section above) +- File sharing reuses image upload infrastructure with `file_type` parameter +- Avatar files stored in `UPLOAD_DIR/avatars/` — user: `{user_id}.{ext}`, group: `group_{conv_id}.{ext}` + +## Aktuální stav práce + +### ✅ Dokončeno (tato session) +- **iOS client (Swift/SwiftUI)** — Full native iOS port in `ios_client/` (47 files, ~6200 lines). Wire-compatible with Python server. Crypto layer: CryptoKit (AES-GCM, HKDF, Ed25519, X25519), pure Swift GF(2^255-19) for Ed→X conversion, Security.framework RSA-4096, ECP1 key encryption (PBKDF2 600k + AES-GCM). Protocol: Network.framework TCP+TLS, newline-delimited JSON. Core: ChatClient actor (1644 lines) with X3DH, Double Ratchet, Sender Keys, per-device sessions, SPK rotation. UI: SwiftUI views (login/register, conversation list, chat with message bubbles, group info, profile, search). Server-side RSA PSS compatibility fix: `PSS.MAX_LENGTH` → `PSS.AUTO` in `rsa_verify()` to accept both Python (max salt) and iOS (hash-length salt) signatures. +- Logout/login bug fix — `_is_logout` flag prevents bridge.stop() on logout +- Hover text readability — `color: #cdd6f4;` added to `QListWidget::item:hover` +- File card background — `background:transparent; border:1px solid #45475a` +- Delete conversations — full stack (db, server, chat_core, gui), DMs + groups (creator-only), file cleanup from disk +- Group invitation system — full stack (schema, db, server, chat_core, gui), create/accept/decline, real-time notification, invitation list UI +- Circular avatars in conversation list — QPainter circular crop, default letter avatars, online green dot overlay +- Group avatar support — upload/download, display in group info dialog, "Change Avatar" button (creator only) +- Server graceful shutdown — force-close connected clients on SIGINT +- Profile dialog avatar circular crop — QPainter in UserProfileDialog._on_avatar_loaded +- Periodic refresh timer — 2-minute QTimer re-downloads avatars + invitations +- Group invitation notification fix — `group_invitation` added to `_background_listener` notification types +- Delete button in conversation header — trash icon for DMs always, groups creator-only +- File cleanup on conversation delete — `db.get_conversation_file_ids()` + unlink `.enc` files +- Removed right-click "Delete conversation" from conversation list context menu +- README.md updated +- **H4 Race conditions fix** — 4 asyncio.Lock guards (`_clients_lock`, `_conn_lock`, `_pairing_lock`, `_uploads_lock`) pro všechny sdílené mutable struktury v server.py. `_notify_users()` + `_notify_users_individual()` helpery. Rate limit memory cleanup v periodic task. Všechny I/O operace mimo kritické sekce. +- **Unread counts pro offline uživatele** — `db.get_unread_counts()` dotaz přes `message_reads` + `message_recipients`, server vrací `unread_count` v `list_conversations`, GUI populuje `_unread_counts` ze serverových dat (max z server vs local). Opravuje bug kdy offline uživatel po přihlášení neviděl nepřečtené zprávy. +- **C6 Path traversal fix** — `_UUID_RE` regex + `_valid_file_id()`, `_safe_upload_path()`, `_safe_avatar_path()` helpery v server.py. UUID validace v `handle_upload_image_start`, `handle_download_image`. `is_relative_to()` guard ve všech path konstrukcích: upload start/end, download, delete_message file cleanup, delete_conversation file cleanup, _cleanup_uploads, get/update avatar, get/update group avatar. Celkem 10 guardovaných míst. +- **C3+H1+M13 Lokální šifrování + permissions** — `derive_local_storage_key()` v crypto_utils.py (HKDF z identity key, odlišný salt/info od self-encryption key). `_encrypt_local()`/`_decrypt_local()` helpery v chat_core.py (AES-256-GCM, formát: nonce(12)+tag(16)+ct). `_save_session`/`_load_session`, `_save_sender_key_state`/`_load_sender_key_state`, `_save_recv_sender_key`/`_load_recv_sender_key` — volitelný `local_key` parametr, při nastavení šifruje/dešifruje, `chmod 0o600` na soubory. `ChatClient._local_key` derivováno při login/registraci/pairingu. Transparentní migrace: pokud dešifrování selže, zkusí plaintext a re-uloží šifrovaně. `os.chmod(d, 0o700)` na všechny `mkdir()` v get_key_dir, opk_private, sessions, sender_keys, sender_keys_recv, message_cache. `os.chmod(p, 0o600)` na plaintext fallback message cache. +- **H7 Avatar path traversal** — `_safe_avatar_path()` guard na handle_get_avatar, handle_get_group_avatar + defense-in-depth na handle_update_avatar, handle_update_group_avatar. +- **Multi-device support (per-device sessions)** — `devices` table, `device_id` columns on prekeys/messages/recipients/sender_keys. Server: device registration at login, `writer_device_map`, per-device key bundles (`device_bundles` array), per-device notification routing (`device_entries`), `list_devices`/`remove_device` handlers. Client: `device_id` persistence, sessions keyed by `"user_id:device_id"`, `_get_device_bundles()` with 5-min TTL cache, per-device encryption in `_send_dm`/`send_image`/`send_file`/`_distribute_sender_key`, `sender_device_id` in decrypt routing, `decrypt_notification()` handles `device_entries` format. Pairing simplified: only RSA + identity key transfer, new device generates own SPK + OPKs. Migration: old session files auto-migrated, backward compat with old clients/servers. H12 OPK race condition fixed (SELECT FOR UPDATE). +- **Connection resilience fixes** — `_background_listener` fails all pending futures with `ConnectionError` on disconnect (prevents hang). `send_and_recv` has 30s timeout + catches `ConnectionError`. Server dispatch has per-message try/except (handler crash no longer kills connection). GUI `_do_send_message`/`_do_find_or_create_and_send` catch exceptions and emit error signal. +- **DB transaction fix** — `db.get_key_bundles_for_user()` had "Transaction already in progress" error because mysql-connector starts implicit transactions. Fixed with `conn.commit()` before `conn.start_transaction()`. +- **H5+H6 Protocol error handling** — `decode_binary()` catches `binascii.Error` → `ValueError`. `parse_message()` catches `JSONDecodeError`/`UnicodeDecodeError` → `ValueError`. Server dispatch already handles `ValueError` from `read_message()` gracefully. +- **H3+H13 Anti-enumeration** — `handle_register_start` returns same "ok" response for existing email (no "Email already in use" leak). `handle_login_start` returns fake challenge for non-existent email. `handle_login_finish` returns generic "Invalid credentials" for all failure cases. `get_user_info` moved behind auth barrier (requires login). +- **H8 Password memory cleanup** — `register()`, `login()`, `pairing_wait()` convert password to `bytearray`, zero out in `finally` block after key derivation. +- **H10 Image validation** — `_safe_load_image()` helper validates size (<10MB) and dimensions (<8192px) before `QImage.fromData()`. Applied to all 6 image loading locations in gui_client.py. +- **H11 Filename sanitization** — `_safe_filename()` helper strips path components via `os.path.basename()`. Applied to save dialogs and image dialog title. +- **C1+C2+C5 DoS hardening** — C1: `LimitOverrunError` now drains buffer and raises `ValueError` (server sends error response instead of silent disconnect; memory already protected by `limit=` on StreamReader). C2: `MAX_SENDER_KEY_SKIP` reduced from 1000 to 256 (matches DoubleRatchet `MAX_SKIP`). C5: `handle_upload_image_end` validates `received_bytes == file_size` before completing upload. M12 (upload end size validation) also fixed by C5. +- **M2+M8+M10+M11 Security hardening batch** — M2: SenderKeyState HKDF salt changed from `b""` to `b"\x00" * 32` (matches X3DH convention). M8: `_valid_file_id()` renamed to `_valid_uuid()`, UUID validation added to all handlers accepting client-provided `conv_id`, `user_id`, `message_id`, `device_id`. M10: `handle_mark_read` caps `message_ids` to 500 (prevents slow SQL DoS). M11: `handle_pairing_start` generates `poll_token` (secrets.token_hex(16)), `handle_pairing_poll` requires and validates it via `secrets.compare_digest()` — prevents unauthorized poll/payload extraction. +- **H2+H14 TLS hardening** — `TLS_INSECURE` a `TLS_AUTOGEN` nyní vyžadují `ENVIRONMENT=dev` (RuntimeError bez toho). Warning log na serveru i klientovi když TLS vypnuté. C4 (OPK file permissions) bylo již opraveno v C3+H1+M13 batchi. +- **Online dot fix + sorting** — Fixed timing issue where `online_users` signal processed before conv list populated. `_rebuild_conv_list()` sorts: favorites first → online DMs → rest alphabetically. +- **Favorites system** — Right-click context menu on conversation list → Add/Remove from favorites. Star indicator (★). Persisted to `favorites.json` in user key directory. +- **Group renaming** — Full stack: `db.update_conversation_name()`, `handle_rename_conversation` (server, creator-only, max 100 chars), `rename_conversation()` (chat_core), "Rename" button in group info dialog (GUI), `conversation_renamed` push notification to all members. +- **M3+M4+M9 Security hardening** — M3: PBKDF2 600k iterations (`_encrypt_private_key`/`_decrypt_private_key` s ECP1 formátem, backward compat pro PEM). M4: SPK rotace každých 7 dní, `spk_created_at` v `get_prekey_count`, grace period s `prev_spk_private.bin`, fallback v `_process_x3dh_header`/`_decrypt_dm`. M9: `_snapshot()`/`_restore()` v `DoubleRatchet.decrypt()`, snapshot/restore v `SenderKeyState.decrypt()`. +- **Phantom invitation fix** — Phantom users now receive group invitations. `handle_create_conversation` and `handle_add_member` create invitations for phantoms (no push notification). `handle_register_confirm` upgrades phantom in-place via `db.upgrade_phantom_user()` (preserves user_id + FK references). `handle_add_member` creates phantom for unregistered emails (same as `create_conversation`). After registration, user sees pending invitations on login. +- **Message Search** — Client-side search through decrypted message cache. `ChatClient.search_messages()` searches local cache. GUI: collapsible search bar (Ctrl+F) with prev/next navigation, match count, yellow/orange highlighting in message HTML. Escape closes search. Search button in chat header. +- **Session Recovery** — `session_reset` protocol message + `handle_session_reset` server handler. `ChatClient.reset_session()` deletes local session + notifies peer. Peer handles `session_reset` notification by deleting their session. Next message auto-creates new session via X3DH. GUI: "Reset session with sender" context menu on undecryptable messages, status bar notification on incoming reset. +- **L8 Phantom user DB inflation fix** — `_valid_email()` helper validates email format before phantom creation in `handle_find_conversation`, `handle_create_conversation`, `handle_add_member`. `db.cleanup_stale_phantoms(30)` deletes phantom users older than 30 days with no active conversations with real users. Runs in `_periodic_cleanup` every 10 minutes, refreshes in-memory `phantom_user_ids` cache. +- **M6 TOCTOU race fix** — `db.remove_conversation_member_atomic()` returns bool (True if row existed). Used in `handle_remove_member` (checks return value, returns error if already removed) and `handle_leave_group`. Defense-in-depth: pre-checks remain for user-friendly errors, atomic operation prevents double-removal. +- **Local-first messages + multi-device received messages fix** — Self-encrypted copies for received messages (not just sent). `batch_reencrypt_messages()` changed to INSERT ON DUPLICATE KEY UPDATE (upsert) — allows creating SELF_DEVICE_ID rows for received messages. Server-side Python dedup in `handle_get_messages` (prefers device-specific over SELF_DEVICE_ID). `after_ts` parameter for incremental sync. `get_deleted_since` protocol + handler for deletion sync. `_pending_self_encrypt` queue + `_flush_self_encrypt()` for background self-encryption of received notifications via `asyncio.ensure_future()`. +- **Detailed server logging** — `_who(session)` helper for consistent log formatting. 25+ tagged log lines: `[CONN]`, `[LOGIN]`, `[REGISTER]`, `[MSG]`, `[FETCH]`, `[READ]`, `[UPLOAD]`, `[DOWNLOAD]`, `[CONV]`, `[INVITE]`, `[LEAVE]`, `[RENAME]`, `[DELETE]`, `[AVATAR]`, `[PREKEYS]`, `[X3DH]`, `[ROTATE]`, `[DEVICE]`, `[REENCRYPT]`, `[SESSION]`, `[MEMBER]`, `[LIST]`, `[ERROR]`. +- **Version bump to 0.8.4** — `VERSION = "0.8.4"`, `MIN_CLIENT_VERSION = "0.8.3"` in protocol.py. Server rejects clients older than 0.8.3. iOS client (0.8.3) still works — new notification types and payload fields are silently ignored. +- **Message Reactions** — Emoji reactions on messages (`thumbsup`, `heart`, `laugh`, `surprised`, `sad`, `thumbsdown`). `message_reactions` DB table. `db.add_reaction()`/`db.remove_reaction()`/`db.get_reactions()`. Server: `handle_react_message` + `message_reacted` push notification. `get_messages` response includes `reactions` array per message. GUI: reaction submenu in context menu, toggle (add/remove), emoji badges below messages (grouped by reaction, highlighted if own), real-time updates via notification. CLI: option 16. +- **Message Forwarding** — Forward messages to other conversations. `ChatClient.forward_message()` sends as normal `send_message` with `forwarded_from` metadata field (`{sender, conversation_id, message_id}`). No new server endpoint needed. GUI: "Forward" in context menu, conversation picker dialog, "Forwarded from" header with blue left border in message rendering. CLI: option 19. +- **Pinned Messages** — Pin/unpin messages in conversations. `messages.pinned_at`/`pinned_by` columns. `db.pin_message()`/`db.unpin_message()`/`db.get_pinned_messages()`. Server: `handle_pin_message` + `handle_get_pinned_messages` + `message_pinned`/`message_unpinned` push notifications. GUI: "Pin"/"Unpin" in context menu, pin emoji indicator in message header, "Pinned" button in header opens dialog with list (double-click scrolls to message), real-time updates. CLI: options 17-18. +- **@Mentions** — `@username` highlighting in messages. Client-side only (no server-side handling). GUI: `MentionCompleter` popup autocomplete activated when typing `@` in message input, filters current conversation members, inserts `@username` on selection. Message rendering: `@word` patterns highlighted in blue bold (`#89b4fa`). CLI: visible as plain `@username` text. +- **Drag & Drop souborů/obrázků** — Přetažení souboru na chat oblast (message input nebo message area) automaticky odešle. Obrázky (`.png`, `.jpg`, `.jpeg`, `.gif`, `.bmp`, `.webp`) se posílají přes `send_image`, ostatní přes `send_file`. Vizuální feedback: modrý dashed border při drag-over. Drop je aktivní pouze když je vybraná konverzace (`drop_enabled` flag na `MessageInput`, `current_conv_id` check v event filteru na `message_area`). +- **Registration prekey upload fix** — `_ensure_prekeys()` now detects missing SPK on server (empty `spk_created_at`) and forces upload with `need_new_spk=True`. Previously, registration uploaded prekeys before login (no session → server rejected with "Not logged in"), and the login flow only uploaded OPKs without SPK. Also added warning log in `_generate_and_upload_prekeys()` when upload fails. +- **Message sender identification fix** — `_render_single_message_html()` now uses `sender_id` (user_id UUID) instead of `sender` (username) for `is_me` detection. Fixes incorrect message alignment when two users share the same display name. +- **Privacy Overlay (lock screen)** — Anti-forensic privacy feature (PyQt only). On window deactivation: immediate dark overlay with lock icon hides all chat content (protects against Alt+Tab thumbnails, screen sharing, shoulder surfing). After 30 seconds unfocused: locks and requires login password to unlock (verified by decrypting identity key from disk via ECP1/PBKDF2-600k). Ctrl+Shift+P toggles on/off. Tray notifications continue working while locked. Messages arriving during lock are counted as unread and NOT marked as read until user unlocks. `_lock_capable` flag auto-detects if identity key is password-encrypted (ECP1 format) — lock requires password only when available. +- **Secure Deletion (anti-forensic wipe)** — `_secure_delete(path)` helper overwrites file with `os.urandom()` + `fsync` before `unlink`. Applied to all sensitive file deletions: OPK private keys, session files (reset + migration), received sender key migration, message cache migration (plaintext JSON cleanup) in chat_core.py; `.enc` and `.tmp` files on conversation delete, message delete, oversized upload cleanup, incomplete upload cleanup, stale upload periodic cleanup in server.py. Fallback to standard `unlink` if overwrite fails. +- **UI redesign (Signal/Telegram look)** — `theme.py` theme systém (ThemeColors dataclass, DARK_THEME Catppuccin Mocha, LIGHT_THEME Signal-inspired, ThemeManager singleton s persistence + live switching). Widget-based message bubbles (MessageBubble QFrame s QPainter rounded rect). ConversationDelegate (custom painting: avatar, name, preview, timestamp, unread badge). Redesigned chat header (avatar, name, status, action buttons). Pill-shaped input (auto-resize, 2-line min). Tabbed login (Login/Register/Link Device). Frameless dialogs (`_make_frameless`). Theme toggle (sun/moon) v settings. +- **Reakce/piny persistentní v cache** — `ChatClient.update_message_in_cache()` synchronně aktualizuje pole zprávy v šifrovaném message cache na disku. Volá se při přidání/odebrání reakce, pin/unpin (z kontextového menu i z push notifikací). Opravuje bug kdy reakce a piny mizely po přepnutí konverzace a návratu (incremental sync nenačítal stará data ze serveru). +- **Context menu fallback na zprávách** — `MessageBubble` context menu policy změněna z `CustomContextMenu` na `DefaultContextMenu`. S `CustomContextMenu` Qt emitoval nepřipojený signál místo volání `contextMenuEvent()` override — kontext menu se neukázalo. Event filter zůstává jako primární handler, `contextMenuEvent` je fallback. +- **Forward obrázků a souborů** — `forward_message()` v chat_core.py přeposílá kompletní `image`/`file` metadata (file_id, AES klíč, IV, thumbnail, filename, size). Dříve se posílal jen textový popis `[Forwarded image: ...]`. Šifrovaný soubor je na serveru, stačí přeposlat metadata. +- **Delete message okamžitý** — Smazání zprávy se projeví lokálně ihned (bez čekání na reload). Server posílá `message_deleted` notifikaci jen ostatním, ne odesílateli — opraveno lokálním označením zprávy jako smazané + uložením do cache + re-renderem před odesláním na server. +- **Frameless confirmation dialogy** — `_confirm_dialog()` helper nahrazuje `QMessageBox.question()` — frameless dialog bez systémové lišty, s červeným "Delete" tlačítkem. Aplikováno na Delete Message a Reset Session dialogy. +- **Reakce — explicitní barva textu** — Reaction badge QLabel nyní má explicitní `color: {t.text_primary}` aby byl text viditelný v obou režimech (předtím spoléhal na CSS dědičnost). +- **SPK/OPK encryption + brute-force lockout** — SPK/OPK private keys now encrypted with AES-256-GCM via `_local_key` (derived from identity key via HKDF). Transparent migration from plaintext. Client-side brute-force lockout: exponential backoff `min(2^N, 300)` seconds after N failed password attempts. Lockout state in `login_lockout.json`. Applied to `ChatClient.login()` and GUI privacy overlay unlock (`_on_unlock_attempt`). `_clear_lockout()` on success. +- **Contact Key Verification** — Signal-style safety numbers, fingerprints, QR codes. TOFU key tracking (`known_identity_keys.bin`) + explicit verification (`verified_contacts.bin`), both encrypted with `_local_key`. `crypto_utils.py`: `compute_fingerprint()` (iterated SHA-512, 5200x), `format_fingerprint()` (30 digits), `compute_safety_number()` (60 digits, symmetric), `encode_verification_qr()`/`decode_verification_qr()`. `chat_core.py`: `check_identity_key()` (TOFU with "new"/"trusted"/"verified"/"changed"/"changed_verified"), `verify_contact()`, `unverify_contact()`, `accept_key_change()`, `get_verification_status()`, `get_safety_number()`, `get_my_fingerprint()`, `get_peer_fingerprint()`, `get_verification_qr_data()`, `verify_qr_code()`. GUI: `VerificationDialog` (safety number, QR code, fingerprints, verify/unverify buttons, QR scan), green checkmark in conversation list for verified DMs (`ROLE_VERIFIED`), E2E label clickable with verification status, key change warning dialog, security section in `UserProfileDialog`, verified badge in group info member list. CLI: options 20 (verify contact) and 21 (show fingerprint). Zero server changes. +- **Metadata Privacy (Ochrana metadat)** — Čtyři opatření pro minimalizaci metadat: **(A) Message Padding** — `pad_plaintext()`/`unpad_plaintext()` v crypto_utils.py. Bucketed padding (64/128/256/512/1K/2K/4K/8K/16K/32K/64K). Formát: `0x01 + plaintext + random_padding + pad_length(4B)`. Prefix `0x01` rozliší od legacy JSON (začíná `{`). Aplikováno na všech 6 send cest v chat_core.py (send_message, distribute_sender_key, forward_message, send_image, send_file, reencrypt_history) + unpad na 2 decrypt cestách (_decrypt_dm, _decrypt_group). **(B) Log Sanitizace** — `_who()` vrací `u=XXXXXXXX d=YYYYYYYY` místo `username (email) [device]`. Group names a recipient counts odstraněny z logů. **(C) Metadata Retention** — `db.cleanup_old_reads()`/`db.cleanup_old_reactions()` mažou záznamy starší N dní (default `METADATA_RETENTION_DAYS=90`). Batch delete (10k/iterace). Spouští se v `_periodic_cleanup()` každé 2 minuty. **(D) Sender Chain Přesun** — `sender_chain_id`/`sender_chain_n` se pro nové group zprávy ukládají do `message_recipients.ratchet_header` místo `messages` tabulky. Server `handle_get_messages` extrahuje z obou míst (backward compat). Notifikace stále posílají chain data pro live decrypt. + +- **TCP Keepalive + Dead Writer Cleanup** — Oprava tichého umírání TCP spojení přes NAT/firewall, které způsobovalo nedoručení push notifikací (zprávy se zobrazily až po manuálním vstupu do konverzace). **(A)** Server + klient nastavují `SO_KEEPALIVE` s `TCP_KEEPIDLE=25s`, `TCP_KEEPINTVL=10s`, `TCP_KEEPCNT=3` na TCP socketu — OS posílá probe pakety a po 55s bez odpovědi označí spojení jako mrtvé. **(B)** `_notify_users()`/`_notify_users_individual()` v server.py kontrolují `w.is_closing()` před odesláním, při selhání logují a odstraňují mrtvý writer z `connected_clients` přes nový `_remove_dead_writer()` helper (dříve `except Exception: pass` tiše spolkl chybu). **(C)** `ProtocolWriter.is_closing()` nový helper v protocol.py. +- **Optimistic Message Send** — Zpráva se zobrazí v UI okamžitě po stisknutí Send (bez čekání na server). `_on_send()` v gui_client.py vytvoří optimistický payload s `_optimistic: True` a přidá ho do `current_messages` + UI. Po server response `_on_message_sent_payload()` najde optimistickou zprávu (match přes text+sender) a nahradí ji potvrzenou verzí (s `message_id`, `created_at`). Při chybě odeslání `_on_message_sent()` optimistickou zprávu odstraní z UI a zobrazí error. +- **Cache-First Message Loading** — Při přepnutí konverzace se okamžitě zobrazí zprávy z lokálního cache (disk), server fetch běží na pozadí. `chat_core.get_cached_messages(conv_id)` čte z message_cache bez server callu. `_on_conv_selected()` volá `get_cached_messages()` synchronně → zobrazí → poté `bridge.load_messages()` async doplní nové. Fetch deduplication: `_messages_inflight` set v AsyncBridge zabraňuje duplicitním fetchům stejné konverzace. +- **Notification Push Logging** — Server loguje `[PUSH] msg=... conv=... targets=[uid(Nw)]` s počtem writerů per příjemce. `_notify_users_individual()` loguje warning při selhání doručení s user_id a chybou. +- **Image/File Transfer Performance Overhaul** — Drastické zrychlení downloadu obrázků a souborů: **(A)** Chunk size zvětšen z 32KB na 256KB (8× méně chunků, méně JSON/base64 overhead). **(B)** `MAX_MESSAGE_BYTES` zvětšen z 64KB na 1MB (nutné pro větší chunky). **(C)** Nový `download_stream` handler na serveru — jedna DB autorizace, pak server streamuje všechny chunky bez čekání na per-chunk request (dříve 2 DB queries × N chunků). Klient sbírá stream chunky přes `asyncio.Queue` v `_background_listener`. **(D)** Fallback na legacy `download_image` pro starší servery. **(E)** `image_download_failed` signál v GUI — `_pending_image_download` se vyčistí při selhání (dříve zůstal navždy a blokoval další downloads). **(F)** Sender cache: obrázek se cachuje lokálně po uploadu (`media_cache/{file_id}.bin`), sender vidí obrázek okamžitě bez server round-trip. + +### 🐛 Známé bugy a problémy +- **Sender Key Redistribution (High Priority):** New group member can't decrypt old messages. On `add_member`, existing members should re-create and redistribute sender keys. +- ~~**Pomalé přepínání konverzací (High Priority):**~~ ✅ OPRAVENO — cache-first loading + fetch deduplication. Server round-trip už jen na pozadí pro sync nových zpráv. +- ~~**Database Connection Pooling:**~~ ✅ OPRAVENO — `MySQLConnectionPool(pool_size=10)`. +- ~~**Duplicate FETCH after send (GUI):**~~ ✅ OPRAVENO — `send_message` vrací payload lokálně, GUI appenduje bez re-fetch. +- ~~**Group delete confirmation message is generic**~~ ✅ OPRAVENO — frameless dialog s kontextovým textem. +- ~~**Reakce a piny mizely po přepnutí konverzace:**~~ ✅ OPRAVENO — `update_message_in_cache()` ukládá na disk. +- ~~**Forward obrázku/souboru posílal jen text:**~~ ✅ OPRAVENO — přeposílá kompletní image/file metadata. +- ~~**Delete message se neprojevil okamžitě:**~~ ✅ OPRAVENO — lokální smazání + cache update před serverovým voláním. +- ~~**Delete/Reset dialogy měly systémovou lištu:**~~ ✅ OPRAVENO — frameless `_confirm_dialog()`. + +### ⏭️ Další kroky (TODO) + +#### Bezpečnostní opravy (priorita dle auditu) +1. **C6 (CRITICAL): Path traversal přes file_id** — `handle_upload_image_start` vytváří soubor `UPLOAD_DIR / f"{file_id}.tmp"` bez validace. Útočník může `../../...` a zapisovat/mazat mimo UPLOAD_DIR. Řešení: validovat UUID formát, ověřit `path.resolve().is_relative_to(UPLOAD_DIR.resolve())`. +2. ~~**H12 (HIGH): OPK race condition v db.get_key_bundle()**~~ ✅ OPRAVENO (součást multi-device — SELECT FOR UPDATE v consume_one_time_prekey + get_key_bundle) +3. **H3+H13: User enumeration** — `get_user_info` dostupné bez auth, vrací identity_key pro libovolný email. `register_start`/`login_start` vrací jednoznačné chyby. Řešení: auth pro `get_user_info`, generické odpovědi pro register/login. +4. ~~**H2+H14: TLS hardening**~~ ✅ OPRAVENO — `TLS_INSECURE` a `TLS_AUTOGEN` vyžadují `ENVIRONMENT=dev`. Warning log při vypnutém TLS. +5. ~~**C1+C2+C5**~~ ✅ OPRAVENO — DoS vektory (LimitOverrunError → ValueError, MAX_SENDER_KEY_SKIP 256, upload completeness check) +6. **C3+C4+H1** — Šifrování dat na disku (message cache, sessions, OPK permissions, `chmod 0o700` pro adresáře) +7. **H5+H6** — Error handling v protokolu (base64, JSON) +8. **H7** — Path traversal v avatar souborech (`resolved_path.is_relative_to()`) +9. ~~**M11 (MEDIUM): Pairing poll DoS**~~ ✅ OPRAVENO — poll_token binding (secrets.token_hex(16) + secrets.compare_digest) +10. ~~**M12: Upload end bez validace velikosti**~~ ✅ OPRAVENO (součást C5 fixu — `handle_upload_image_end` validuje `received_bytes == file_size`) +11. ~~**L8: Phantom user DB inflation**~~ ✅ OPRAVENO — email validace + periodic cleanup stale phantoms (30 dní) +12. **Version negotiation** — `VERSION = "0.8"` v protocol.py, klient posílá `client_version` při loginu, server loguje a vrací `server_version` + +#### Před nasazením do produkce (checklist) +1. **TLS certifikáty** — Získat certifikát (Let's Encrypt / vlastní CA). Nastavit `TLS_ENABLED=true`, `TLS_CERT_FILE`, `TLS_KEY_FILE` v `.env`. Ověřit že `TLS_INSECURE` a `TLS_AUTOGEN` NEJSOU nastavené (vyžadují `ENVIRONMENT=dev`). Na klientovi nastavit `TLS_ENABLED=true` a případně `TLS_CA_FILE` pokud vlastní CA. +2. **Email validace** — Zapnout `_valid_email()` kontrolu v `handle_find_conversation`, `handle_create_conversation`, `handle_add_member` (kód existuje v server.py, volání zakomentována). Teď vypnuto protože dev prostředí používá emaily bez @. +3. **MySQL TLS** — Přidat SSL parametry do `db.get_connection()` (`ssl_ca`, `ssl_cert`, `ssl_key`) pokud DB běží na jiném stroji. +4. **Connection pooling** — Nahradit `get_connection()` za `mysql.connector.pooling.MySQLConnectionPool(pool_size=10)`. +5. **SMTP** — Nastavit reálný SMTP server pro registrační kódy (`SMTP_HOST`, `SMTP_PORT`, `SMTP_USER`, `SMTP_PASS`, `SMTP_FROM`). +6. **UPLOAD_DIR** — Ověřit že `UPLOAD_DIR` je na persistentním disku s dostatkem místa, správnými právy (0o700). +7. **Rate limity** — Přezkoumat limity pro produkční zátěž (registrace 3/min, login 10/min, send_message 20/min, max 10 spojení/IP). +8. **Packaging** — Zabalit klienta (pyinstaller / cx_Freeze) pro distribuci. Po zabalení zvážit auto-update mechanismus a `get_version` endpoint. +9. **Penetrační testy** — Provést před ostrým nasazením (viz sekce níže). +10. **Backup** — Nastavit pravidelný backup MySQL databáze + `UPLOAD_DIR`. + +#### Penetrační testy +- Naplánovat a provést manuální penetrační testy zaměřené na: + - Path traversal (file_id, avatar_file) + - DoS vektory (readuntil, sender key fast-forward, upload flooding) + - Race conditions (OPK reuse, membership TOCTOU) + - User enumeration (register, login, get_user_info) + - TLS downgrade / MITM bez TLS + - Pairing session hijacking + - Memory exhaustion (rate_limits, phantom users, message_ids) +- Vytvořit testovací skripty pro automatizované security testy +- Zdokumentovat výsledky a opravit nalezené problémy + +#### Ke zvážení +- **Auto-update klientů** — distribuce aktualizovaných souborů klientům před login/registrací. Řešit až po kompilaci/packagingu (pyinstaller apod.). Mechanismus: server verze check → klient stáhne nové soubory → restart. +- **Server version check endpoint** — po packagingu mít jednoduchý endpoint (např. `get_version`), který vrací min/aktuální podporovanou verzi klienta + URL/metadata pro update; klient může před loginem ověřit kompatibilitu a nabídnout update. Vhodné i pro postupné vypínání starých klientů. + +#### Monetizace — plán + +**Princip:** Oddělený platební server (KYC/AML compliant) od chat serveru (anonymní). Platba generuje jednorázový premium kód, chat server zná jen "user_id aktivoval kód". Žádný přímý link platba↔chat identita. + +**Architektura:** +``` +Platební server (Stripe/Paddle): zákazník → platba → vygeneruje premium_code + sdílí jen: premium_code (jednorázový string) +Chat server: user_id → aktivuje premium_code → premium do data X +``` +- Platební server neví jaký user_id kód použil +- Chat server neví kdo kód koupil +- AML splněno na platební straně, privacy zachováno na chat straně + +**Free vs Premium tier:** +| | Free | Premium | +|---|---|---| +| Konverzace | 5 | neomezeno | +| Soubory | 10 MB | 50 MB | +| Zařízení | 1 | 5 | +| Message retention | 30 dní | neomezeno | +| Skupiny | max 10 členů | neomezeno | + +**Implementace (chat server):** +- Tabulka `premium_codes(code VARCHAR(64) PK, plan ENUM('premium_30','premium_365'), created_at, redeemed_by CHAR(36) nullable FK, redeemed_at nullable)` +- Tabulka `user_plans(user_id PK FK, plan ENUM('free','premium'), expires_at DATETIME nullable, message_retention_days INT DEFAULT 30)` +- Sloupec `messages.expires_at` (nullable DATETIME) — NULL = neomezená retence +- Handler `redeem_code` — validuje kód, aktivuje plán, nastaví expires_at +- `handle_send_message` — kontroluje limity (konverzace, velikost) +- Periodic cleanup: `DELETE FROM messages WHERE expires_at IS NOT NULL AND expires_at < NOW()` + CASCADE + `.enc` smazání +- Retence podle konverzace: pokud alespoň jeden člen premium → vyšší retence +- Klient: indikace expirace zpráv, upozornění na limity, "Upgrade" tlačítko + +**Implementace (platební server — oddělený deployment):** +- Jednoduchý web (Flask/FastAPI) s Stripe Checkout +- Generuje premium_code, uloží do sdílené DB nebo pošle přes API +- AML/KYC řeší Stripe (PCI DSS, SCA, reporting) + +**Další revenue streams:** +- **Enterprise/B2B licence** — self-hosted deployment, LDAP/SSO, admin dashboard, SLA. Faktura na IČO. +- **Šifrovaný cloud backup** — export/import historie šifrované uživatelským heslem (nezávisle na retenci) + +#### Funkční vylepšení +1. **Sender Key Redistribution** — on add_member, redistribute sender keys to all members including new one +2. ~~**Device Linking fix**~~ ✅ — replaced with true multi-device (per-device sessions, simplified pairing) +3. ~~**SPK Rotation**~~ ✅ — periodic rotation with grace period (implemented in M4 fix) +4. **Typing Indicators** (budoucí) — `typing_start`/`typing_stop` protocol + GUI indicator (3s timeout, debounce) +5. ~~**CLI support**~~ ✅ — profiles, file sharing, invitations, leave/rename/delete, search, devices in `client.py` +6. ~~**Message search**~~ ✅ — client-side search through decrypted cache, Ctrl+F toggle, highlight + navigation +7. ~~**Session Recovery**~~ ✅ — `session_reset` protocol, auto-recreate via X3DH on next message +8. ~~**Connection Pooling**~~ ✅ — `MySQLConnectionPool(pool_size=10)`, lazy init, `DB_POOL_SIZE` env var. `conn.close()` vrací do poolu. +9. ~~**Version negotiation**~~ ✅ — `VERSION = "0.8.4"` in protocol.py, client sends `client_version` at login, server rejects clients < MIN_CLIENT_VERSION +10. **Delivery Receipts** — `message_delivered` notifikace po přijetí na zařízení (1 fajfka = odesláno, 2 fajfky = doručeno, modré = přečteno). Nová tabulka `message_deliveries` nebo rozšíření `message_reads`. +11. ~~**Reakce na zprávy (Message Reactions)**~~ ✅ — emoji reakce na zprávy. Tabulka `message_reactions`. Push notifikace `message_reacted`. GUI: emoji badges pod zprávou, submenu v kontext menu. +12. ~~**Přeposílání zpráv (Message Forwarding)**~~ ✅ — kontext menu "Forward", výběr konverzace, odeslání s `forwarded_from` metadatem. GUI: "Forwarded from" header. +13. ~~**Připnuté zprávy (Pinned Messages)**~~ ✅ — `messages.pinned_at`/`pinned_by` sloupce. `pin_message`/`unpin_message`/`get_pinned_messages` protokol. GUI: pin ikona + dialog s připnutými zprávami. +14. ~~**Zmínky (@mentions)**~~ ✅ — parsování `@username` v textu, autocomplete při psaní @, zvýraznění zmínek v modré. Klient-side only (bez server notifikací). +15. ~~**Contact Key Verification**~~ ✅ — Signal-style safety numbers (60 digits, symmetric), fingerprints (30 digits), QR codes. TOFU key tracking + explicit verification. GUI: VerificationDialog, green checkmark v conv listu, E2E label status, key change warning. CLI: options 20+21. Zero server changes. + +#### Message Loading + Send Optimalizace (✅ částečně implementováno) + +**Implementováno:** + +1. ✅ **Cache-first zobrazení (gui_client.py: `_on_conv_selected()`):** + - `chat_core.get_cached_messages(conv_id)` čte zprávy z lokálního disku (message_cache), žádný server call + - `_on_conv_selected()` volá `get_cached_messages()` → okamžitě zobrazí v UI → poté `bridge.load_messages()` async syncne se serverem na pozadí + - Výsledek: zprávy se zobrazí **okamžitě** při kliknutí na konverzaci + +2. ✅ **Fetch deduplication (gui_client.py: `AsyncBridge._messages_inflight`):** + - `_messages_inflight: set[str]` — pokud pro conv_id už běží server fetch, nový se neprovede + - Eliminuje duplicitní round-tripy (dříve stejná konverzace fetchována 4× za sebou) + +3. ✅ **Optimistic message send (gui_client.py: `_on_send()`):** + - Zpráva se zobrazí v UI okamžitě po stisknutí Send (optimistický payload s `_optimistic: True`) + - Server potvrzení na pozadí → `_on_message_sent_payload()` nahradí optimistickou zprávu potvrzenou + - Při chybě → `_on_message_sent()` odstraní optimistickou zprávu + zobrazí error + +**Aktuální flow:** +``` +_on_conv_selected() [gui_client.py] + ├─ get_cached_messages() → zobrazit OKAMŽITĚ z disku + └─ bridge.load_messages() → ASYNC na pozadí: + └─ chat_core.get_messages(conv_id) + ├─ send_and_recv("get_messages", after_ts=...) + ├─ send_and_recv("get_deleted_since", ...) + └─ mark_conversation_read(conv_id) + → messages_loaded signal → aktualizovat UI pokud nové zprávy + +_on_send() [gui_client.py] + ├─ Optimistický payload → zobrazit OKAMŽITĚ v UI + └─ bridge.send_message() → ASYNC na pozadí: + └─ chat_core.send_message() → šifrování + server + → message_sent_payload signal → nahradit optimistickou zprávu +``` + +**TODO (zbývající optimalizace):** +- [ ] TTL cache (5s) — skip server fetch pokud nedávno syncováno +- [ ] Skip `get_deleted_since` při 0 nových zpráv +- [ ] Skip `mark_conversation_read` při 0 nepřečtených +- [ ] Debounce přepínání konverzací (150ms timer) + +#### Optimalizace serveru +1. ~~**DB Connection Pooling**~~ ✅ — viz bod 8 výše. +2. ~~**Oprava duplicitních FETCH v GUI**~~ ✅ — `send_message` vrací payload lokálně, GUI appenduje bez re-fetch. Dedup guard v `_on_notification`. +3. ~~**Batch prekey replenishment**~~ ✅ — `ensure_prekeys` handler na serveru (get_prekey_count + upload v jednom roundtripu). `_generate_and_upload_prekeys_batch()` v chat_core. +4. ~~**Server-side message count**~~ ✅ — `get_messages` response obsahuje `total_count`. `db.count_messages()` funkce. +5. **Prepared statements / query cache** — pro často opakované dotazy (get_messages, list_conversations) připravit prepared statements. +6. **WebSocket upgrade** — dlouhodobě nahradit raw TCP za WebSocket pro lepší kompatibilitu s firewally, load balancery, a web klienty. + +#### Mobilní push notifikace (budoucí — iOS + Android) +- Tabulka `push_tokens(user_id, device_id, platform ENUM('ios','android'), token, created_at)` +- Server: při `new_message` pokud cílový uživatel nemá aktivní TCP spojení → odeslat přes APNs (iOS) / FCM (Android) +- Obsah push: jen "Nová zpráva od X" (E2EE = server nezná plaintext) +- iOS: `UserNotifications` framework, registrace tokenu při loginu, `didReceiveRemoteNotification` +- Android: Firebase Cloud Messaging, `FirebaseMessagingService` +- Server-side: `aioapns` (Python APNs library) + `firebase-admin` (FCM SDK) + +## Bezpečnostní audit (Security Audit) + +Kompletní audit provedený přes všechny soubory projektu. Nálezy seřazené podle závažnosti. + +### 🔴 CRITICAL — Okamžitě řešit před nasazením + +#### ~~C1. readuntil() bez limitu → memory exhaustion (protocol.py:62)~~ ✅ OPRAVENO +`ProtocolReader.read_message()` volá `readuntil(b"\n")`, které načte CELOU zprávu do paměti PŘED kontrolou velikosti. Útočník pošle gigabyty dat bez newline → server spadne na out-of-memory. +```python +line = await self._reader.readuntil(b"\n") # buffers everything first! +if len(line) > MAX_MESSAGE_BYTES: # too late +``` +**Řešení:** Implementovat framing s hlavičkou obsahující velikost zprávy, nebo použít `readuntil()` s `limit` parametrem (asyncio StreamReader nemá nativně — nutno obalit vlastním čtením po částech). + +#### ~~C2. SenderKeyState — neomezený fast-forward DoS (crypto_utils.py:642-645)~~ ✅ OPRAVENO +Při dešifrování skupinové zprávy s libovolně vysokým `n` se smyčka `while self.n <= n` provede milionkrát — derivuje milion klíčů, spotřebuje stovky MB RAM. +```python +while self.n <= n: + self.chain_key, mk = kdf_ck(self.chain_key) + self._known_keys[self.n] = mk # unbounded dict growth + self.n += 1 +``` +**Řešení:** Přidat `MAX_FORWARD_SKIP` limit (např. 1000) — stejně jako Double Ratchet má `MAX_SKIP=256`. + +#### C3. Dešifrované zprávy uložené jako plaintext na disku (chat_core.py:222-239) +Message cache v `~/.encrypted_chat/{email}/message_cache/{conv_id}.json` obsahuje plný obsah dešifrovaných zpráv v nešifrovaném JSONu. Bez nastavení `chmod 0o600`. Kdokoliv s přístupem k disku přečte kompletní historii. +**Řešení:** Šifrovat cache klíčem odvozeným z identity key + nastavit `chmod 0o600` na soubory. + +#### ~~C4. OPK private keys bez file permissions (chat_core.py:153-156)~~ ✅ OPRAVENO +OPK privátní klíče se ukládají bez `os.chmod(0o600)`. RSA klíče (řádek 87) a identity key (řádek 121) mají `chmod` — OPK ne. Na sdílených systémech může jiný uživatel přečíst ephemeral klíče. +**Opraveno:** Součást C3+H1+M13 fixu — `_save_opk_private()` nyní volá `os.chmod(path, 0o600)` + `os.chmod(dir, 0o700)`. + +#### ~~C5. Chunked upload nevaliduje celkovou velikost (server.py:1138-1142)~~ ✅ OPRAVENO +`handle_upload_image_chunk` akumuluje `received_bytes` ale nekontroluje limit. Útočník deklaruje `file_size=5MB`, pak posílá chunky donekonečna → disk exhaustion. +```python +upload["received_bytes"] += len(raw) # no check against file_size! +``` +**Řešení:** Přidat `if upload["received_bytes"] > upload["file_size"]: reject`. + +### 🟠 HIGH — Řešit před production nasazením + +#### H1. Session + sender key soubory nešifrované na disku (chat_core.py:176-215) +Double Ratchet session state (DH privátní klíče, root key, chain keys) a sender key state se ukládají jako plaintext hex JSON v `sessions/` a `sender_keys/`. Bez šifrování, bez `chmod 0o600`. Kompromitace disku = dešifrování celé historie. +**Řešení:** Šifrovat state klíčem z identity key, nastavit `chmod 0o600`. + +#### ~~H2. TLS vypnuté ve výchozím stavu (chat_core.py:274-291, server.py)~~ ✅ OPRAVENO (hardening) +`TLS_ENABLED` je defaultně `false`. Bez TLS jdou po síti RSA challenge-response, session tokeny a metadata v plaintextu. `TLS_INSECURE=true` vypíná certificate verification → MITM. +**Opraveno:** `TLS_INSECURE` a `TLS_AUTOGEN` nyní vyžadují `ENVIRONMENT=dev` — v produkci RuntimeError. Warning log při vypnutém TLS na serveru i klientovi. TLS_ENABLED zůstává default false (uživatel nemá certifikát), ale po nasazení Let's Encrypt stačí flip na true. + +#### H3. User enumeration přes registraci (server.py:182-189) +Registrace vrací "Email already in use" pro existující uživatele vs. tiché vytvoření phantoma pro neexistující. Útočník může enumerovat platné emaily. +**Řešení:** Vrátit generickou odpověď "Check your email for verification code" i když email existuje. + +#### ~~H4. Race conditions v in-memory strukturách (server.py: multiple)~~ ✅ OPRAVENO +`connected_clients` dict, `phantom_user_ids` set, `pairing_sessions` dict — čteny a zapisovány z více concurrent koroutin bez synchronizace. Asyncio je single-threaded, ale yieldy uvnitř handlerů (await) mohou způsobit nekonzistentní stav. +**Opraveno:** 4 asyncio.Lock guards: `_clients_lock` (connected_clients, phantom_user_ids), `_conn_lock` (connection_counts, current_connections, rate_limits), `_pairing_lock` (pairing_sessions, pending_registrations), `_uploads_lock` (pending_uploads). Helper funkce `_notify_users()` / `_notify_users_individual()` — snapshot under lock, send outside. Rate limit memory cleanup v periodic task. Žádný handler nedrží dva locky současně → deadlock impossible. + +#### H5. base64 decode bez error handling (protocol.py:14-16, server.py + chat_core.py) +`decode_binary()` volá `base64.b64decode()` bez try-except. Nevalidní base64 od klienta → unhandled `binascii.Error` → handler crash. Mnoho callsites v server.py (řádky 357, 378, 783) nemá catch. +**Řešení:** Obalit `decode_binary()` try-except, nebo validovat base64 vstup před dekódováním. + +#### H6. JSON parsing bez exception handling (protocol.py:48-50) +`parse_message()` volá `json.loads()` bez try-catch. Malformovaný JSON = neošetřený `JSONDecodeError`. Server handler catch (řádek 1399) to odchytí, ale není to explicitní. +**Řešení:** Obalit `json.loads()` v `parse_message()` try-except s explicitní chybovou zprávou. + +#### H7. Path traversal v avatar souborech (server.py:1265, 1318) +`avatar_file` ze serveru (z DB) se přímo joinuje s `UPLOAD_DIR / "avatars"` bez validace. Pokud DB obsahuje `../../etc/passwd`, server přečte libovolný soubor. +**Řešení:** Přidat `resolved_path.resolve().is_relative_to(UPLOAD_DIR)` check. + +#### H8. Heslo v paměti jako Python string (chat_core.py, gui_client.py) +Python stringy jsou immutable — nelze je bezpečně vymazat z paměti. Heslo zůstává v paměti dokud garbage collector neuklidí. Memory dump = plaintext heslo. +**Řešení:** Použít `bytearray` (mutable), po použití přepsat nulami: `pwd[:] = b'\x00' * len(pwd)`. + +#### H9. Self-encryption key je statický a deterministický (chat_core.py:904+, crypto_utils.py:224-233) +`derive_self_encryption_key(identity_private)` produkuje vždy stejný klíč. Kompromitace identity klíče = dešifrování všech vlastních kopií zpráv navždy. Žádná forward secrecy pro self-copies. +**Poznámka:** Toto je by-design (nutné pro cross-device čtení), ale je to architektonické omezení. + +#### H10. Malicious image data → QImage crash (gui_client.py) +`QImage.fromData(data)` zpracovává nevalidované binární data. Speciálně vytvořený obrázek může způsobit crash, memory exhaustion, nebo v extrémním případě RCE přes Qt image codec. +**Řešení:** Validovat velikost dat před parsováním, limit na max rozlišení. + +#### H11. Filename z serveru v save dialogu (gui_client.py:2389, 2460) +Server-controlled `filename` se předává jako default do `QFileDialog.getSaveFileName()`. Pokud server pošle `"../../../.bashrc"`, dialog to navrhne. +**Řešení:** Sanitizovat filename — odstranit `../`, `\`, absolutní cesty. Použít jen `os.path.basename()`. + +### 🟡 MEDIUM — Zvážit pro hardening + +#### M1. Inconsistentní Ed25519 serializace (crypto_utils.py:99-102) +Bez hesla: 32 raw bytes. S heslem: PEM PKCS8 (~302 bytes). Dva různé formáty mohou způsobit problém při migraci nebo obnově klíčů. +**Poznámka:** M3 fix částečně řeší — s heslem je nyní ECP1 formát (ne PEM), ale `load_ed25519_private()` stále detekuje 3 formáty (ECP1, PEM, raw). Legacy PEM soubory se automaticky migrují při dalším uložení. + +#### ~~M2. Prázdný salt v SenderKeyState HKDF (crypto_utils.py:610)~~ ✅ OPRAVENO +`hkdf_derive(sender_key, salt=b"", ...)` — RFC 5869 doporučuje nenulový salt. X3DH správně používá `b"\x00" * 32`. +**Opraveno:** Změněno `salt=b""` → `salt=b"\x00" * 32` aby odpovídalo X3DH konvenci. + +#### ~~M3. PBKDF2 iterace pod doporučeným minimem (crypto_utils.py)~~ ✅ OPRAVENO +`BestAvailableEncryption` používá ~100k iterací PBKDF2. OWASP 2023 doporučuje 480k+. +**Opraveno:** Nahrazeno vlastním `_encrypt_private_key`/`_decrypt_private_key` s PBKDF2-HMAC-SHA256 (600k iterací) + AES-256-GCM. ECP1 formát (magic prefix) s backward compat pro staré PEM soubory. Aplikováno na RSA (`serialize_private_key`/`load_private_key`) i Ed25519 (`serialize_ed25519_private`/`load_ed25519_private`). + +#### ~~M4. SPK bez replay protection a bez rotace (server.py:360-368)~~ ✅ OPRAVENO +Stejný SPK lze nahrát opakovaně. Žádný nonce/timestamp v podpisu. SPK se nikdy nerotuje → kompromitovaný SPK = trvalé dešifrování nových sessions. +**Opraveno:** SPK rotace každých 7 dní (`SPK_ROTATION_DAYS`). Server vrací `spk_created_at` v `handle_get_prekey_count`. `_ensure_prekeys()` kontroluje stáří SPK a rotuje pokud >= 7 dní. Předchozí SPK uložen jako grace period (`prev_spk_private.bin`/`prev_spk_id.txt`) pro in-flight X3DH. `_process_x3dh_header` přijímá `spk_override`, `_decrypt_dm` retry s předchozím SPK při selhání dešifrování. + +#### ~~M5. Rate limit unbounded memory (server.py:73-83)~~ ✅ OPRAVENO (součást H4 fixu) +Staré záznamy se nikdy nečistí pokud klíč přestane být aktivní → útočník vytvoří miliony klíčů → memory leak. +**Opraveno:** `_cleanup_rate_limits()` v periodic cleanup (každých 10 min) maže stale entries z `rate_limits` i `connection_counts`. + +#### ~~M6. TOCTOU race v membership checks (db.py)~~ ✅ OPRAVENO +`is_conversation_member()` → `remove_conversation_member()` — mezi kontrolou a akcí může jiný klient stav změnit. +**Opraveno:** `db.remove_conversation_member_atomic()` vrací bool (True pokud řádek existoval). Použito v `handle_remove_member` a `handle_leave_group`. + +#### M7. MySQL spojení bez TLS (db.py:20-28) +`get_connection()` nepředává SSL parametry. Na vzdáleném serveru jdou credentials v plaintextu. + +#### ~~M8. Chybějící validace UUID formátu (server.py: throughout)~~ ✅ OPRAVENO +`conv_id`, `user_id` — kontrola jen na neprázdnost, ne na formát UUID v4. +**Opraveno:** `_valid_file_id()` přejmenováno na `_valid_uuid()`. UUID validace přidána do všech handlerů přijímajících klientem poskytnuté `conv_id`, `user_id`, `message_id`, `device_id`. + +#### ~~M9. Ratchet state corruption recovery (chat_core.py:1088-1104)~~ ✅ OPRAVENO +Pokud `decrypt()` změní chain keys ale selže na AAD verification, backup/restore mechanismus funguje, ale pokud backup selže (out-of-memory), stav zůstane corrupted. +**Opraveno:** `DoubleRatchet.decrypt()` nyní snapshotuje stav přes `_snapshot()`/`_restore()` a rollbackuje při jakékoliv výjimce (včetně skipped message key restore). `SenderKeyState.decrypt()` stejně snapshotuje `chain_key`, `n`, `_known_keys` před fast-forward a rollbackuje při selhání. + +#### ~~M10. Chybějící validace velikosti message_ids listu (db.py:641-646)~~ ✅ OPRAVENO +Klient může poslat tisíce message_ids v jednom požadavku → pomalý SQL dotaz, DoS. +**Opraveno:** `handle_mark_read` nyní odmítá požadavky s více než 500 message_ids. + +### 🟢 LOW — Dobrá praxe, nízké riziko + +- **L1.** Hex string keys v skipped messages dict — timing side-channel po úspěšné autentikaci (crypto_utils.py:425) +- **L2.** RatchetHeader serializace redundantně konvertuje typy (crypto_utils.py:394-405) +- **L3.** `notif_label.setText()` bezpečné proti XSS (Qt neinterpretuje HTML v setText), ale křehké — přepnutí na `setHtml()` by to rozbilo (gui_client.py:1524, 2259) +- **L4.** SQL column interpolation v `update_user_profile` — whitelist chrání, ale pattern je nebezpečný při kopírování (db.py:818-822) +- **L5.** Chybějící TLS cipher suite hardening — Python defaulty jsou rozumné, ale ne explicitně nastavené (protocol.py) +- **L6.** Temporary pairing key není bezpečně vymazán z paměti (chat_core.py:581) +- **L7.** `_user_cache` ukládá public identity keys indefinitely — memory leak pro hodně kontaktů + +### Druhý bezpečnostní review (zaměření na návrh, DB, komunikaci, lokální tmp/cache) + +#### C6. Path traversal → libovolný zápis/smazání souborů přes file_id (server.py) +`handle_upload_image_start` vytváří soubor `UPLOAD_DIR / f"{file_id}.tmp"` bez validace `file_id`. Útočník může poslat `../../...` a zapisovat mimo UPLOAD_DIR. Následné rename, cleanup, `delete_message` a `delete_conversation` pak mohou mazat libovolné soubory. +**Řešení:** Striktně validovat file_id (UUID hex/kanonický formát), odmítnout cokoliv s `/`, `\`, `..`. Ověřit `path.resolve().is_relative_to(UPLOAD_DIR.resolve())`. Ideálně ukládat do podadresářů podle hash/UUID. + +#### ~~H12. OPK race condition — reuse one-time pre-keys (db.py)~~ ✅ OPRAVENO +V `db.get_key_bundle()` se OPK vybírá SELECT → DELETE bez transakčního zámku. Při souběhu může být stejný OPK vydán vícekrát → porušení bezpečnostních předpokladů X3DH. +**Opraveno:** `consume_one_time_prekey()` a `get_key_bundle()` nyní používají `SELECT ... FOR UPDATE` + DELETE v jedné transakci (součást multi-device implementace). + +#### H13. Neautentizované get_user_info + identity key exfiltrace (server.py) +`get_user_info` je dostupné bez přihlášení a vrací username, email a identity_key pro libovolný email/user_id. Umožňuje enumeraci uživatelů a sběr metadata/klíčů. +**Řešení:** Vyžadovat auth, nebo omezit na "kontakty v konverzaci". + +#### ~~H14. TLS_INSECURE umožňuje MITM i v produkci (chat_core.py, server.py)~~ ✅ OPRAVENO +`TLS_INSECURE=true` vypíná verifikaci certifikátu → útočník může podvrhnout key bundle. Přímo ohrožuje E2EE integritu. +**Opraveno:** `TLS_INSECURE` vyžaduje `ENVIRONMENT=dev`, jinak RuntimeError. Součást H2 fixu. + +#### ~~M11. Pairing poll DoS — neautentizovaný přístup k payload (server.py)~~ ✅ OPRAVENO +Kdokoli s 8-místným kódem může pollovat a "vyzvednout" payload (smazán po vyzvednutí). I když je šifrovaný, jde o snadný DoS (reálnému zařízení pairing selže). +**Opraveno:** `handle_pairing_start` generuje `poll_token` (secrets.token_hex(16)), vrací klientovi. `handle_pairing_poll` vyžaduje `poll_token` a porovnává přes `secrets.compare_digest()`. Klient ukládá token v `pairing_start()` a posílá v `pairing_wait()`. + +#### ~~M12. Upload end bez validace received_bytes == file_size (server.py)~~ ✅ OPRAVENO +`upload_image_end` neověřuje, že `received_bytes == file_size`. Může zůstat nedokončený/nevalidní soubor. +**Řešení:** Kontrola délky před `complete_image_upload`. + +#### M13. Klíčové adresáře bez chmod 700 (chat_core.py) +`get_key_dir` a podadresáře (`sessions`, `sender_keys`, `opk_private`) se vytvářejí bez explicitních práv; spoléhá se na umask. +**Řešení:** Po `mkdir` vždy `chmod 0o700` pro adresář, `0o600` pro soubory. + +#### ~~L8. Phantom users — DB inflation (server.py, db.py)~~ ✅ OPRAVENO +`find_conversation` vytváří phantom usery pro libovolné emaily. I s rate limit lze DB časem nafouknout. +**Opraveno:** `_valid_email()` validace před vytvořením phantomu. `db.cleanup_stale_phantoms(30)` v periodic cleanup — maže phantomy starší 30 dní bez aktivních konverzací s reálnými uživateli. + +### Bezpečnostní matice (souhrn) + +| Soubor | CRITICAL | HIGH | MEDIUM | LOW | +|--------|----------|------|--------|-----| +| `protocol.py` | 1 (C1) | 2 (H5, H6) | 0 | 1 (L5) | +| `crypto_utils.py` | 1 (C2) | 0 | 3 (M1-M3) | 2 (L1, L2) | +| `server.py` | 2 (C5, C6) | 3 (H3, H7, H13) | 4 (M4, M8, M11, M12) | 0 | +| `chat_core.py` | 2 (C3, C4) | 4 (H1, H2/H14, H8, H9) | 2 (M9, M13) | 1 (L6) | +| `gui_client.py` | 0 | 2 (H10, H11) | 0 | 2 (L3, L7) | +| `db.py` | 0 | 1 (H12) | 3 (M6, M7, M10) | 2 (L4, L8) | +| **Celkem** | **6** | **12** | **12** | **8** | +| **Opraveno** | 6 (~~C1~~, ~~C2~~, ~~C3~~, ~~C4~~, ~~C5~~, ~~C6~~) | 11 (~~H1~~, ~~H2~~, ~~H3~~, ~~H4~~, ~~H5~~, ~~H6~~, ~~H7~~, ~~H8~~, ~~H10~~, ~~H11~~, ~~H12~~, ~~H13~~, ~~H14~~) | 11 (~~M2~~, ~~M3~~, ~~M4~~, ~~M5~~, ~~M6~~, ~~M8~~, ~~M9~~, ~~M10~~, ~~M11~~, ~~M12~~, ~~M13~~) | 1 (~~L8~~) | +| **Zbývá** | **0** | **1** | **1** | **7** | + +### Doporučené pořadí oprav (aktualizováno) +1. ~~**C6**~~ ✅ — Path traversal přes file_id — DONE (UUID validace + is_relative_to) +2. ~~**C1 + C2 + C5**~~ ✅ — DoS vektory — DONE (LimitOverrunError → ValueError, MAX_SENDER_KEY_SKIP 256, upload completeness check) +3. ~~**H12**~~ ✅ — OPK race condition — DONE (SELECT FOR UPDATE, součást multi-device) +4. ~~**C3 + H1 + H7 + M13**~~ ✅ — Šifrování dat na disku + file/dir permissions + avatar path traversal — DONE +5. ~~**H2/H14**~~ ✅ — TLS hardening (TLS_INSECURE + TLS_AUTOGEN vyžadují ENVIRONMENT=dev, warning log) — DONE +6. ~~**H5 + H6**~~ ✅ — Error handling v protokolu (base64, JSON) — DONE +7. ~~**H3 + H13**~~ ✅ — User enumeration (generické odpovědi, auth pro get_user_info) — DONE +8. ~~**H4**~~ ✅ — Race conditions (asyncio.Lock) — DONE +9. ~~**H8 + H10 + H11**~~ ✅ — Paměť hesel, image parsing, filename sanitizace — DONE +10. ~~**M2 + M8 + M10 + M11**~~ ✅ — Hardening batch (HKDF salt, UUID validace, message_ids cap, pairing poll token) — DONE +11. ~~**M3, M4, M9**~~ ✅ — PBKDF2 600k iterations, SPK rotace 7 dní s grace period, ratchet state rollback — DONE +12. **M1, ~~M6~~, M7** — Remaining hardening (Ed25519 serialization, ~~TOCTOU~~ ✅, MySQL TLS) +13. **Penetrační testy** — manuální + automatizované security testy + +## Důležitá rozhodnutí a kontext +- **Invitations replace direct add for groups:** `handle_add_member` and `handle_create_conversation` (for groups) now create invitations instead of directly adding members. DMs still auto-add both users. This was a design decision to give users control over joining groups. +- **Group delete = total destruction:** When creator deletes a group, ALL members are removed and the conversation is fully deleted. This is different from "leave group" which only removes the leaving user. +- **DM delete is per-user:** Deleting a DM only removes the deleting user. The other user still sees the conversation until they also delete it. +- **Avatar caching in GUI is pixmap-based:** `_avatar_cache` and `_group_avatar_cache` store QPixmap objects, not raw bytes. The `_on_avatar_for_conv_list` and `_on_group_avatar_for_conv_list` signals convert bytes → QImage → QPixmap on receipt. +- **No context menu on conversation list anymore:** Delete was the only action. Now handled by header buttons. `conv_list.setContextMenuPolicy(DefaultContextMenu)`. + +## Environment Variables +See README.md for full list. Key: `SERVER_HOST`, `SERVER_PORT`, `MYSQL_*`, `TLS_*`, `SMTP_*`, `LOG_LEVEL`, `MAX_INPUT_CHARS`, `UPLOAD_DIR`, `MAX_IMAGE_BYTES`, `MAX_FILE_BYTES`, `MAX_MESSAGE_BYTES`, `METADATA_RETENTION_DAYS` (default 90). + +## Commands & Workflow + +- Start server: `python server.py` +- Start GUI client: `python gui_client.py` +- Start CLI client: `python client.py` +- Environment: `.env` file in project root (loaded by `dotenv`) +- Dependencies: `PyQt6`, `mysql-connector-python`, `cryptography`, `Pillow` (for image sharing), `python-dotenv` +- Check syntax: `python3 -m py_compile .py` +- All files on both server AND client side: `crypto_utils.py`, `protocol.py`, `chat_core.py`, `gui_client.py` (or `client.py`) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..d35affc --- /dev/null +++ b/Dockerfile @@ -0,0 +1,44 @@ +# Encrypted Chat Server — Docker image +# Builds only the server-side components (server.py, db.py, crypto_utils.py, protocol.py) +# GUI/iOS client files are not included. + +FROM python:3.12-slim + +# Install system deps needed by pyzbar (libzbar) and Pillow +RUN apt-get update && apt-get install -y --no-install-recommends \ + libzbar0 \ + libjpeg62-turbo \ + libpng16-16 \ + default-libmysqlclient-dev \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Install Python deps — separate layer so code changes don't bust the cache +COPY requirements.txt . +# Install server-only deps (skip PyQt6, pyzbar, qrcode — not needed server-side) +RUN pip install --no-cache-dir \ + cryptography \ + "mysql-connector-python>=8.3.0" \ + "python-dotenv>=1.0.0" \ + "Pillow>=10.0.0" + +# Copy server source files +COPY server.py db.py crypto_utils.py protocol.py schema.sql ./ + +# Optional: copy .env if it exists (overridden at runtime via env vars or mounted file) +# COPY .env . + +# Create uploads directory +RUN mkdir -p /app/uploads && chmod 700 /app/uploads + +# Expose the default server port +EXPOSE 5000 + +# Health check: attempt TCP connection to the server port +HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \ + CMD python -c "import socket,sys; s=socket.socket(); s.settimeout(3); s.connect(('localhost', int(__import__('os').getenv('SERVER_PORT','5000')))); s.close()" || exit 1 + +ENV PYTHONUNBUFFERED=1 + +CMD ["python", "server.py"] diff --git a/KEC-18_operational_cost_analysis.md b/KEC-18_operational_cost_analysis.md new file mode 100644 index 0000000..b5af2ed --- /dev/null +++ b/KEC-18_operational_cost_analysis.md @@ -0,0 +1,289 @@ +# KEC-18 Operational Cost Analysis + +Date: 2026-03-27 + +## Executive Summary + +- The absolute cheapest workable production setup for roughly 100 concurrent users is one Contabo `Cloud VPS 10` at `EUR4.50/mo`, plus a domain, free Let's Encrypt TLS, free-tier SMTP, and low-cost object backup. That lands around `EUR5.8-6.5/mo` (`USD6.7-7.5/mo`) if email volume stays inside a free tier. +- The more realistic "do not hate yourself later" floor is about `EUR11-16/mo` (`USD12.7-18.5/mo`) once you include a paid SMTP plan or more backup headroom. +- For 500 to 1,000 total users, Contabo remains extremely cheap. The main constraint is not raw VPS price; it is operational headroom, backup discipline, and the app's current default connection cap. +- For 5,000 total users, I would stop using a single-node layout. Split app and MySQL onto separate VPSes and keep uploads/backups external or on a storage-optimized node. + +## What I Sized From The Codebase + +Relevant defaults in the current code: + +- `protocol.py` defaults to `MAX_MESSAGE_BYTES=1 MiB`, `MAX_IMAGE_BYTES=5 MiB`, `MAX_FILE_BYTES=50 MiB`. +- `server.py` defaults to `MAX_CONNECTIONS_GLOBAL=200`, `MAX_CONNECTIONS_PER_IP=10`, `MAX_UPLOADS_GLOBAL=200`, `MAX_UPLOADS_PER_USER=5`. +- `db.py` defaults to `DB_POOL_SIZE=10`. +- `server.py` defaults to `THREAD_POOL_SIZE=40`. +- `README.md` and `scaling.md` indicate the server is an asyncio TCP relay with synchronous MySQL calls pushed behind `asyncio.to_thread()`, which is lightweight for small deployments but still means DB latency and connection pooling matter. + +Implication: + +- `100 concurrent users` is feasible on a very small VPS. +- `More than 200 simultaneously connected devices` is not feasible with the current default connection limit unless configuration is raised and load-tested. +- File storage, not CPU, becomes the first recurring cost question if usage shifts from chat-heavy to attachment-heavy. + +## Assumptions + +To make the tiers comparable, I used these planning assumptions: + +- The `500`, `1,000`, and `5,000` tiers are treated as total users, not fully concurrent users. +- Peak concurrent devices: + - 100-user tier: about 100 + - 500-user tier: about 75 to 125 + - 1,000-user tier: about 125 to 175 + - 5,000-user tier: about 300 to 500 +- Average retained encrypted upload footprint: + - 100 users: 100 GB + - 500 users: 250 GB + - 1,000 users: 500 GB + - 5,000 users: 1 TB +- SMTP use is limited to registration codes and lightweight transactional mail, not marketing mail. +- EUR to USD conversion uses the ECB reference `1 EUR = 1.1539 USD` visible on 2026-03-27 in the ECB currency converter. + +## Current Vendor Pricing Used + +### Contabo + +Official Contabo pricing page shows: + +- `Cloud VPS 10`: `3 vCPU`, `8 GB RAM`, `75 GB NVMe`, `32 TB traffic`, `EUR4.50/mo` +- `Cloud VPS 20`: `6 vCPU`, `12 GB RAM`, `100 GB NVMe`, `32 TB traffic`, `EUR7.00/mo` +- `Cloud VPS 30`: `8 vCPU`, `24 GB RAM`, `200 GB NVMe`, `32 TB traffic`, `EUR14.00/mo` +- `Cloud VPS 40`: `12 vCPU`, `48 GB RAM`, `250 GB NVMe`, `32 TB traffic`, `EUR25.00/mo` +- `Storage VPS 10`: `2 vCPU`, `4 GB RAM`, `300 GB SSD`, `EUR4.50/mo` +- `Storage VPS 20`: `3 vCPU`, `8 GB RAM`, `400 GB SSD`, `EUR7.00/mo` + +Note: Contabo also publishes separate location-fee pricing. For example, the location-fee page shows `Cloud VPS 10` in `United States (Central)` at `EUR0.95/mo` extra, for `EUR5.45/mo` total. Base prices above are the standard pricing page numbers. + +### Domain, TLS, SMTP, Backup, Monitoring, Agent Costs + +- Domain: Porkbun shows `.com` at `USD11.08/yr`, which is about `EUR9.60/yr` or `EUR0.80/mo`. +- TLS: Let's Encrypt certificates are free. +- SMTP: + - MailerSend free plan: `500 emails/month` + - MailerSend Hobby: `EUR5.15/mo` for `5,000 emails/month` + - MailerSend Starter: pricing page shows `EUR25/mo` and `50,000 emails/month` +- Backup/object storage: + - Backblaze B2 pricing page shows `USD6/TB/mo` pay-as-you-go. + - First `10 GB` is free. +- Monitoring: + - Self-hosted Uptime Kuma can be run on your own server at zero direct license cost. + - Managed alternative: UptimeRobot free plan exists; paid plans start at about `USD8/mo`. +- OpenAI / Codex API: + - OpenAI pricing page currently shows `gpt-5.4` standard at `USD2.50 / 1M input tokens` and `USD15.00 / 1M output tokens`. + - `gpt-5.4-mini` standard is `USD0.75 / 1M input` and `USD4.50 / 1M output`. + +## Recommended Infrastructure By Tier + +### Tier A: Minimum Viable, about 100 concurrent users + +Recommended stack: + +- 1 x `Cloud VPS 10` +- Let's Encrypt +- 1 `.com` domain +- Backblaze B2 for backups +- MailerSend free or Hobby depending email volume + +Why this is enough: + +- 8 GB RAM is adequate for Python app + MySQL on one box at this size. +- 75 GB NVMe is enough if uploads are modest and older media is backed up externally. +- 32 TB traffic is far above what this workload should consume. + +Estimated monthly cost: + +- VPS: `EUR4.50` / `USD5.19` +- Domain amortized monthly: `EUR0.80` / `USD0.92` +- Backup at about 100 GB retained: about `EUR0.47` / `USD0.54` +- TLS: `EUR0` +- SMTP: + - Free-tier case: `EUR0` + - Safer paid case: `EUR5.15` / `USD5.94` + +Total: + +- Absolute floor: about `EUR5.77/mo` / `USD6.65/mo` +- Safer operating floor: about `EUR10.92/mo` / `USD12.60/mo` + +### Tier B: About 500 total users + +Recommended stack: + +- 1 x `Cloud VPS 20` +- Backblaze B2 backups +- MailerSend Hobby + +Why: + +- More CPU and RAM headroom for MySQL buffering, background cleanup, and multi-device behavior. +- 100 GB NVMe is enough for DB + hot uploads if colder data is backed up externally. + +Estimated monthly cost: + +- VPS: `EUR7.00` / `USD8.08` +- Domain: `EUR0.80` / `USD0.92` +- Backup at about 250 GB retained: about `EUR1.25` / `USD1.44` +- TLS: `EUR0` +- SMTP Hobby: `EUR5.15` / `USD5.94` + +Total: + +- About `EUR14.20/mo` / `USD16.38/mo` + +### Tier C: About 1,000 total users + +Recommended stack: + +- 1 x `Cloud VPS 30` +- Backblaze B2 backups +- MailerSend Hobby or Starter + +Why: + +- `24 GB RAM` gives useful cache headroom for MySQL and smoother bursts. +- This is the point where a single node is still cheap, but monitoring and restore discipline matter more than raw VPS price. + +Estimated monthly cost: + +- VPS: `EUR14.00` / `USD16.15` +- Domain: `EUR0.80` / `USD0.92` +- Backup at about 500 GB retained: about `EUR2.55` / `USD2.94` +- TLS: `EUR0` +- SMTP Hobby: `EUR5.15` / `USD5.94` + +Total: + +- Lean setup: about `EUR22.50/mo` / `USD25.95/mo` + +If you want higher mail headroom: + +- Swap SMTP to Starter at `EUR25/mo` +- New total: about `EUR42.35/mo` / `USD48.87/mo` + +### Tier D: About 5,000 total users + +Recommended stack: + +- 1 x `Cloud VPS 20` for app server +- 1 x `Cloud VPS 20` for MySQL +- Backblaze B2 backups for media + DB dumps +- MailerSend Starter + +Why I would split here: + +- The current codebase is still operationally simple. A two-node layout buys more reliability than buying one oversized single VPS. +- Separate failure domains help during DB spikes, backup jobs, and incident response. +- This tier likely exceeds the current default `MAX_CONNECTIONS_GLOBAL=200` if user concurrency climbs, so configuration and load testing become mandatory. + +Estimated monthly cost: + +- App VPS: `EUR7.00` / `USD8.08` +- DB VPS: `EUR7.00` / `USD8.08` +- Domain: `EUR0.80` / `USD0.92` +- Backup at about 1 TB retained: about `EUR5.15` / `USD5.94` +- TLS: `EUR0` +- SMTP Starter: `EUR25.00` / `USD28.85` + +Total: + +- About `EUR44.95/mo` / `USD51.87/mo` + +Alternative: + +- If you strongly prefer a single-node layout, `Cloud VPS 40` plus backups is still cheap, but I would consider it worse operationally than two smaller nodes. + +## Minimum Viable Budget Answer + +If the question is "what is the absolute minimum monthly spend to run this for about 100 concurrent users," the answer is: + +- Roughly `EUR5.8-6.5/mo` (`USD6.7-7.5/mo`) with: + - `Cloud VPS 10` + - one cheap domain + - free TLS + - free SMTP tier + - minimal external backup + +If the question is "what is the minimum I would actually recommend for production without pretending backups and mail do not exist," the answer is: + +- Roughly `EUR11-16/mo` (`USD12.7-18.5/mo`) + +## Additional Infrastructure Recommendations + +### TLS certificates + +- Use Let's Encrypt. +- Direct recurring certificate cost: `EUR0`. + +### Domain + +- Budget about `EUR10/yr` to `EUR15/yr`. +- Using current Porkbun `.com` pricing, a normal `.com` is about `EUR9.60/yr`. + +### Backups + +- Do not rely only on local VPS storage. +- Cheapest clean option: nightly MySQL dumps + uploaded file backup to Backblaze B2. +- Ballpark backup cost at current B2 pricing: + - 100 GB: about `EUR0.47/mo` + - 250 GB: about `EUR1.25/mo` + - 500 GB: about `EUR2.55/mo` + - 1 TB: about `EUR5.15/mo` + +### SMTP relay + +- Free tier is enough for early registration-code traffic. +- Move to Hobby quickly once real users arrive; it is still cheap and removes needless friction. + +### Monitoring + +- Cheapest option: self-host Uptime Kuma. +- Managed option: UptimeRobot free or paid. +- I would treat managed monitoring as optional until there is paying traffic. + +## Agent Operational Cost Estimate + +These costs depend entirely on token volume, not server size. + +Using current OpenAI standard pricing: + +- `gpt-5.4`: `USD2.50 / 1M input`, `USD15.00 / 1M output` +- `gpt-5.4-mini`: `USD0.75 / 1M input`, `USD4.50 / 1M output` + +Illustrative monthly spend per active engineering agent: + +- Light usage, `gpt-5.4-mini`: + - 10M input + 2M output + - about `USD16.50/mo` / `EUR14.30/mo` +- Moderate usage, `gpt-5.4`: + - 10M input + 2M output + - about `USD55.00/mo` / `EUR47.66/mo` +- Heavy usage, `gpt-5.4`: + - 40M input + 8M output + - about `USD220.00/mo` / `EUR190.66/mo` + +For a small team of 3 active agents, a realistic monthly AI tooling band is: + +- Lean: about `EUR43-50/mo` +- Moderate: about `EUR143/mo` +- Heavy: about `EUR570+/mo` + +## Risks And Constraints + +- The code currently defaults to `MAX_CONNECTIONS_GLOBAL=200`. If "500 users" or "1,000 users" means concurrent devices, current defaults are not enough. +- The cheapest single-node layout mixes app, MySQL, and hot uploads on one VPS. That is acceptable early, but it increases recovery risk during disk or instance failure. +- Attachment-heavy usage can outgrow cheap NVMe faster than message traffic will outgrow CPU. +- SMTP, domain, and monitoring are trivial costs compared with the cost of not having backups. + +## Final Recommendation + +If I had to choose one path now: + +- Launch on `Cloud VPS 10` if the immediate target is only about `100 concurrent users` and budget is extremely tight. +- Launch on `Cloud VPS 20` if you want a safer early-production baseline without materially changing cost. +- Move to a split app/DB layout by the time you are targeting `5,000 total users` or any scenario above `200 concurrently connected devices`. + +In short: Contabo pricing is not the bottleneck here. Operational discipline, connection-limit tuning, and backup/storage policy are the real budget drivers once the app starts seeing real usage. diff --git a/README.md b/README.md new file mode 100644 index 0000000..d0facd9 --- /dev/null +++ b/README.md @@ -0,0 +1,326 @@ +# Encrypted Chat + +End-to-end encrypted chat s forward secrecy (X3DH + Double Ratchet, Signal Protocol). +Server ukládá a přeposílá šifrované bloby — nikdy nevidí plaintext. + +## Architektura + +``` +┌─────────────┐ TLS/TCP ┌─────────────┐ MySQL ┌─────────┐ +│ GUI/CLI │◄───────────────►│ Server │◄──────────────►│ DB │ +│ klient │ JSON + base64 │ (asyncio) │ │ │ +└─────────────┘ └─────────────┘ └─────────┘ + │ │ + │ X3DH + Double Ratchet │ Opaque blobs + │ Sender Keys (skupiny) │ (server nevidí plaintext) + ▼ ▼ + Lokální klíče Šifrované zprávy + (~/.encrypted_chat/) + metadata +``` + +## Soubory + +### Server +| Soubor | Řádky | Účel | +|--------|-------|------| +| `server.py` | ~2 900 | Asyncio TCP server, 45 handlerů, rate limiting, 5 asyncio.Lock guardů, real-time notifikace | +| `db.py` | ~1 700 | MySQL CRUD, connection pooling (pool_size=10), phantom users, reactions/pins CRUD | +| `schema.sql` | ~190 | MySQL schéma (14 tabulek) | + +### Klient +| Soubor | Řádky | Účel | +|--------|-------|------| +| `gui_client.py` | ~6 300 | PyQt6 GUI — dark/light téma, widget-based message bubbles, verifikace kontaktů, privacy overlay | +| `client.py` | ~900 | CLI klient — 23 menu opcí | +| `chat_core.py` | ~3 500 | Sdílená logika — session management, X3DH/ratchet šifrování, lokální klíče, multi-device | +| `theme.py` | ~540 | Catppuccin dark + Signal-inspired light téma, live switching | + +### Sdílené (server + klient) +| Soubor | Účel | +|--------|------| +| `crypto_utils.py` (~935 ř.) | Ed25519, X25519, AES-256-GCM, HKDF, PBKDF2, X3DH, Double Ratchet (state rollback), Sender Keys (state rollback), ECP1 key encryption, contact verification (fingerprints, safety numbers, QR), message padding | +| `protocol.py` (~140 ř.) | Newline-delimited JSON protokol, base64 encoding, verze (0.8.4) | + +### iOS klient +| Složka | Účel | +|--------|------| +| `ios_client/` (47 Swift souborů, ~5 000 ř.) | Nativní iOS port — CryptoKit + pure Swift GF(2^255-19) + Security.framework RSA, SwiftUI views, wire-kompatibilní s Python serverem | + +### Testy +| Soubor | Účel | +|--------|------| +| `tests/pentest_client.py` (~340 ř.) | Automatizované security regresní testy (AuthZ, malformed headers, session reset, rate limits) | + +## Quick Start + +1. `pip install -r requirements.txt` +2. Spustit `schema.sql` v MySQL +3. `python server.py` +4. Klient: `python gui_client.py` (GUI) nebo `python client.py` (CLI) + +## Jak funguje šifrování + +### Klíče na uživatele +| Klíč | Typ | Účel | +|------|-----|------| +| RSA-4096 | Asymetrický | Pouze login challenge-response. Šifrovaný ECP1 (PBKDF2 600k + AES-256-GCM). | +| Identity Key (IK) | Ed25519 | Podpisy, konverze na X25519 pro X3DH. Šifrovaný ECP1. | +| Signed Pre-Key (SPK) | X25519 | DH v X3DH, podepsaný IK. **Rotuje se každých 7 dní** s grace periodem. | +| One-Time Pre-Keys (OPK) | X25519 | Jednorázové, spotřebuje se při X3DH, automaticky doplňované (< 20 → +50). | + +### DM (1:1 zprávy) — X3DH + Double Ratchet +1. Alice stáhne Bobovy per-device key bundles (IK, SPK, OPK) → X3DH per device → shared secret per device. +2. Double Ratchet inicializován ze shared secret — jedna session per (user, device). +3. Každá zpráva: symmetric ratchet (HMAC chain) → message key → AES-256-GCM. +4. Každá odpověď: DH ratchet (nový X25519 keypair) → nový root key + chain key. +5. Per-device ciphertext — každé zařízení příjemce dostane individuálně šifrovaný blob. +6. Self-encrypted kopie s SELF_DEVICE_ID sentinel, čitelná všemi vlastními zařízeními. + +### Skupiny — Sender Keys +1. Každý odesílatel má vlastní SenderKeyState per group. +2. Sender key distribuován členům přes pairwise Double Ratchet (jako control DM). +3. Skupinové zprávy: symmetric ratchet na sender key → AES-256-GCM. +4. Stejný ciphertext pro všechny příjemce (efektivní). + +### Kontaktní verifikace (Signal-style) +- **Safety numbers** — 60-digit číslo (12 skupin × 5 číslic), deterministické pro každý pár. +- **QR kódy** — binární payload zakódovaný jako base64. +- **Fingerprints** — 30-digit per-user číslo. +- **TOFU** — Trust On First Use + explicit verification + key change warning. + +### Lokální úložiště klíčů +``` +~/.encrypted_chat/{email}/ + private.pem / public.pem — RSA (login, ECP1 formát) + identity_private.bin / _public.bin — Ed25519 (ECP1 formát) + device_id.txt — UUID tohoto zařízení + spk_private.bin / spk_id.txt — Aktuální SPK (AES-256-GCM) + prev_spk_private.bin / prev_spk_id.txt — Předchozí SPK, grace period + opk_private/{opk_id}.bin — One-time prekeys (AES-256-GCM) + sessions/{uid}_{did}.bin — Double Ratchet stavy (AES-256-GCM) + sender_keys/{conv_id}.bin — Vlastní sender keys + sender_keys_recv/{conv}_{uid}_{did}.bin — Přijaté sender keys + known_identity_keys.bin — TOFU registr (AES-256-GCM) + verified_contacts.bin — Explicitní verifikace (AES-256-GCM) + message_cache/{conv_id}.bin — Šifrovaný message cache + login_lockout.json — Brute-force lockout stav +``` + +## Bezpečnostní hardening + +### Šifrování privátních klíčů (ECP1 formát) +- **PBKDF2-HMAC-SHA256** s 600 000 iteracemi (OWASP 2023) +- **AES-256-GCM**, magic bytes "ECP1" jako AAD +- **Formát:** `ECP1(4B) + salt(16B) + nonce(12B) + ciphertext+tag` +- Zpětná kompatibilita: staré PEM se migrují automaticky + +### Lokální šifrování dat +- Session/sender key soubory, OPK, SPK, message cache, verifikační soubory — AES-256-GCM klíčem z HKDF(identity_key) +- `chmod 0o700` na adresáře, `0o600` na soubory + +### Brute-force ochrana +- Exponenciální backoff: `min(2^N, 300)` sekund po N chybných pokusech +- Aplikováno na login + privacy overlay unlock + +### SPK rotace (7 dní) +- Automatická rotace s grace periodem pro in-flight X3DH +- Omezuje dopad kompromitace SPK + +### Ratchet state rollback +- Snapshot/restore při selhání dešifrování (DoubleRatchet + SenderKeyState) + +### Secure deletion +- Overwrite `os.urandom()` + `fsync` + `unlink` na smazané citlivé soubory + +### Message padding +- Bucketed padding (64B–64KB) maskuje délku zpráv + +### Metadata privacy +- Log sanitizace (žádná PII), metadata retention (90 dní), sender chain minimalizace + +### Anti-enumeration +- Phantom users pro neregistrované emaily +- Generické odpovědi na register/login/get_user_info + +## Multi-Device Support + +Pravý multi-device (Signal-like) — každé zařízení má nezávislé Double Ratchet sessions. + +- **Devices tabulka** — každé přihlášení registruje device (UUID) +- **Per-device prekeys** — každé zařízení má vlastní SPK + OPKs +- **Per-device sessions** — klíčované `"user_id:device_id"` +- **Self-encryption** — statický klíč z identity key (čitelné všemi vlastními zařízeními) +- **Pairing** — přenos RSA + Ed25519, nové zařízení generuje vlastní SPK + OPKs + +## Features + +### Protokol & šifrování +- X3DH + Double Ratchet (DM) s forward secrecy +- Sender Keys (skupiny) s distribucí přes pairwise ratchet +- Per-device šifrování (multi-device) +- SPK rotace (7 dní) + grace period +- Ratchet state rollback při selhání +- ECP1 šifrování klíčů (PBKDF2 600k) +- Message padding (bucketed 64B–64KB) +- Kontaktní verifikace (safety numbers, fingerprints, QR kódy) + +### Komunikace +- DM + skupinové konverzace +- Reakce na zprávy (thumbsup, heart, laugh, surprised, sad, thumbsdown) +- Přeposílání zpráv (text, obrázky, soubory) +- Připnuté zprávy (pin/unpin + dialog) +- @Mentions s autocomplete +- Odpovědi na zprávy (reply_to) +- Hledání zpráv (client-side, Ctrl+F) +- Šifrované obrázky (AES-256-GCM, chunked upload, thumbnail) +- Šifrované soubory (až 50 MB, chunked upload) +- Read receipts (real-time) + +### Skupiny +- Skupinové pozvánky (accept/decline) +- Leave group + přenos creatora +- Rename group (creator only) +- Delete conversation (DMs per-user, groups creator-only) +- Group avatar + +### Správa +- Multi-device support (per-device sessions, pairing) +- User profily (telefon, lokace, avatar, viditelnost) +- Online/offline status +- Session reset (při poškození ratchetu) +- Key rotation (revokace zařízení) +- Brute-force lockout + +### GUI (PyQt6) +- Dark (Catppuccin Mocha) + Light (Signal) téma s live switching +- Widget-based message bubbles s ConversationDelegate +- Cirkulární avatary + online zelená tečka +- Unread count badges +- Privacy overlay / lock screen (30s timeout + heslo) +- Drag & drop souborů +- Frameless dialogy +- Connection indicator (green/red/orange) + auto-reconnect +- VerificationDialog (safety numbers, QR, fingerprints) +- Key change warning dialog + +### CLI +- 23 menu opcí (DM, skupiny, soubory, reakce, piny, forwarding, verifikace, zařízení, search) + +### iOS (SwiftUI) +- Wire-kompatibilní s Python serverem +- Kompletní Signal Protocol (X3DH, Double Ratchet, Sender Keys) +- CryptoKit + pure Swift field arithmetic + Security.framework RSA +- SwiftUI views (login, chat, groups, profiles, search) + +## Konfigurace + +### Server + DB +- `SERVER_HOST` (default `127.0.0.1`), `SERVER_PORT` (default `9999`) +- `MYSQL_HOST`, `MYSQL_PORT`, `MYSQL_USER`, `MYSQL_PASSWORD`, `MYSQL_DATABASE` +- `DB_POOL_SIZE` (default `10`) + +### TLS +- `TLS_ENABLED` — zapne TLS (default `false`) +- `TLS_REQUIRED` — vyžaduje TLS_ENABLED +- `TLS_CERT_FILE`, `TLS_KEY_FILE` — cesty k certifikátu (PEM) +- `TLS_AUTOGEN` — auto-generuje self-signed cert (**jen s `ENVIRONMENT=dev`**) +- `TLS_CA_FILE` (klient) — vlastní CA certifikát +- `TLS_INSECURE` (klient) — vypne ověření certifikátu (**jen s `ENVIRONMENT=dev`**) + +### SMTP +- `SMTP_HOST`, `SMTP_PORT`, `SMTP_USER`, `SMTP_PASS`, `SMTP_FROM` +- Bez SMTP = dev mód (kód se vrací přímo klientovi) + +### Limity +- `MAX_MESSAGE_BYTES` (default `65536`), `MAX_IMAGE_BYTES` (5 MB), `MAX_FILE_BYTES` (50 MB) +- `MAX_INPUT_CHARS` (GUI, default `2000`) +- `METADATA_RETENTION_DAYS` (default `90`) +- Rate limity: register 3/min, login 10/min, send_message 20/min +- Connection: 20 req/s, max 10/IP, 200 global + +### Logging +- `LOG_LEVEL` (default `INFO`) + +## Bezpečnostní audit + +Dva bezpečnostní audity provedeny (kód review). Nalezeno 6 CRITICAL, 12 HIGH, 12 MEDIUM, 8 LOW nálezů. + +| Závažnost | Celkem | Opraveno | Zbývá | +|-----------|--------|----------|-------| +| CRITICAL | 6 | **6** | 0 | +| HIGH | 12 | **11** | 1 (H9 — by-design) | +| MEDIUM | 12 | **11** | 1 (M7) | +| LOW | 8 | **1** | 7 | + +Detaily viz `SECURITY_AUDIT.md` a `CLAUDE.md`. + +## Known Issues + +- **Sender Key Redistribution:** Nový člen skupiny nedešifruje staré skupinové zprávy (sender keys se nedistribuují znovu při přidání). +- **iOS: Contact Key Verification** — safety numbers, QR kódy, TOFU zatím neimplementovány v iOS klientu. + +## Docker a CI/CD + +### Lokální vývoj s Docker Compose + +```bash +# Spustit server + MySQL +docker compose up + +# Rebuild po změně kódu +docker compose up --build + +# Zastavit a smazat data +docker compose down -v +``` + +Server bude dostupný na `localhost:5000`. MySQL na `localhost:3306`. +Schéma se automaticky importuje při prvním spuštění. + +### Ruční build Docker image + +```bash +docker build -t encrypted-chat-server . +docker run -p 5000:5000 \ + -e MYSQL_HOST=host.docker.internal \ + -e MYSQL_USER=chat \ + -e MYSQL_PASSWORD=chatpassword \ + -e MYSQL_DATABASE=encrypted_chat \ + -e ENVIRONMENT=dev \ + encrypted-chat-server +``` + +### Produkční deployment + +1. Získat TLS certifikát (Let's Encrypt / vlastní CA) +2. Nastavit env vars — viz `.env.example` +3. Spustit: + ```bash + docker compose -f docker-compose.yml up -d + ``` +4. Ověřit health: `docker compose ps` + +Kritické produkční proměnné: +- `TLS_ENABLED=true`, `TLS_CERT_FILE`, `TLS_KEY_FILE` +- `MYSQL_PASSWORD` — silné heslo +- `ENVIRONMENT=production` (ne `dev`) +- `SMTP_*` — pro registrační emaily + +### CI/CD (GitHub Actions) + +Pipeline v `.github/workflows/ci.yml` spouští při každém push/PR: + +1. **Lint** — `ruff check` na všechny Python soubory +2. **Crypto testy** — `test_crypto_integration.py` (bez serveru) +3. **Integration testy** — spustí MySQL + server, pak `test_server_integration.py` +4. **Docker build** — ověří že se image builduje bez chyb + +## Závislosti + +- `cryptography` — Ed25519, X25519, AES-GCM, RSA, HKDF, PBKDF2 +- `mysql-connector-python` — MySQL s connection pooling +- `python-dotenv` — env vars +- `PyQt6` — GUI +- `Pillow` — resize/thumbnail obrázků +- `qrcode` — generování QR kódů +- `pyzbar` (volitelné) — skenování QR kódů diff --git a/SECURITY_AUDIT.md b/SECURITY_AUDIT.md new file mode 100644 index 0000000..bf82bea --- /dev/null +++ b/SECURITY_AUDIT.md @@ -0,0 +1,491 @@ +# Security Audit (Encrypted Chat) + +Aktualizace: 2026-03-27 +Scope: `server.py`, `db.py`, `chat_core.py`, `gui_client.py`, `client.py`, `protocol.py`, `schema.sql`, `.env.example`, markdown dokumentace. + +Metodika: statický audit kódu + konfigurace. Nebyl proveden aktivní penetrační test ani fuzzing. + +## Refresh 2026-03-27 + +Při re-review aktuálního stavu kódu byly uzavřeny tyto nálezy (KEC-26): + +- MySQL TLS konfigurace je podporovaná přes `MYSQL_SSL_CA`, `MYSQL_SSL_CERT`, `MYSQL_SSL_KEY` (`db.py`, `.env.example`). +- SMTP `STARTTLS` používá explicitní `ssl.create_default_context()` a `EHLO` před/po TLS upgrade (`server.py`). +- Avatar upload flow nastavuje explicitně `chmod(0o600)` i pro user/group avatary (`server.py`). + +Scope limitation aktuálního workspace: + +- V repozitáři chybí `ios_client/` i jakýkoli Android klient, přestože jsou zmiňované v zadání i README. Tento refresh proto pokrývá pouze server a Python klienty. +- Historický nález o plaintext secrets v `.env` nelze v tomto snapshotu reprodukovat; workspace obsahuje pouze `.env.example`. + +Reziduální architektonické riziko: + +- Self-encryption klíč je z definice statický a deterministický; kompromitace identity private key proto zpřístupní všechny self-copies napříč historií (`crypto_utils.py:329-341`). To je tradeoff současného cross-device designu, ne implementační bug. + +## Executive Summary + +Nejzávažnější aktuálně otevřené nálezy: + +- Reziduální architektonický tradeoff: statický/deterministický self-encryption klíč pro self-copies. +- Mobilní klienti deklarovaní v dokumentaci nejsou součástí tohoto workspace, takže jejich security stav zůstává neověřený. + +## CRITICAL + +### ~~C1. TOFU / verifikace identity klíče se obchází při běžném X3DH flow~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- TOFU kontrola existuje jen v `_get_user_info()` (`chat_core.py:799-803`). +- Při navazování session (`_get_or_create_session`) se `identity_key` z bundle bere přímo bez TOFU kontroly (`chat_core.py:1497-1501`, `chat_core.py:1534-1538`). +- U příchozího X3DH (`_process_x3dh_header`) se remote IK také uloží bez TOFU kontroly (`chat_core.py:1551-1553`, `chat_core.py:1580-1584`). + +**Dopad** + +- Pokud server nebo MITM podstrčí jiný identity key, klient může navázat session bez varování. +- Prakticky to obchází uživatelskou verifikaci kontaktu ve výchozím messaging flow. + +**Oprava** + +- Nová výjimka `IdentityKeyChanged(user_id, new_key_bytes, status)` v `chat_core.py` — hard-fail při změně identity klíče. +- `_get_or_create_session()`: TOFU check přes `check_identity_key()` před X3DH initiate. Při `changed`/`changed_verified` vyhodí `IdentityKeyChanged` — session se nenaváže. +- `_process_x3dh_header()`: TOFU check před X3DH respond. Stejný hard-fail — příchozí zpráva s podvrženým klíčem je odmítnuta. +- GUI: `IdentityKeyChanged` zachycena v notification loopu (emituje `key_change_warning` signál místo pádu loopu) a v `_do_send_message` (zobrazí error + warning dialog). +- Session je blokována dokud uživatel explicitně neakceptuje key change přes `accept_key_change()`. +- `decrypt_notification()`: explicitní `except IdentityKeyChanged: raise` před generickým `except Exception` — výjimka se propaguje do notification loopu místo tichého spolknutí. +- `key_change_warning` signál rozšířen o 5. parametr `new_key_bytes: bytes` — "Accept New Key" dialog předává nový klíč přímo z výjimky, ne z cache (která mohla obsahovat starý klíč). +- `IdentityKeyChanged` ošetřena ve všech GUI send cestách: `_do_send_image`, `_do_send_file`, `_do_forward_message`, `_do_find_or_create_and_send` — zobrazí warning dialog + error message. +- CLI (`client.py`): `IdentityKeyChanged` ošetřena ve všech 6 send cestách (send_message ×3, send_image, send_file, forward_message). + +--- + +### ~~C2. Perzistentní DoS konverzace přes nevalidní message headers~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- Server přijímá `ratchet_header` / `x3dh_header` i jako raw `str/bytes` bez JSON schema validace (`server.py:1105-1112`, `server.py:1146-1151`). +- Při `get_messages` se hodnoty bez ochrany parsují `json.loads(...)` (`server.py:1266`, `server.py:1274`). + +**Dopad** + +- Útočník v konverzaci může uložit “poisoned” hlavičku a rozbít načtení historie ostatním členům (`Internal server error`). +- Chyba je perzistentní, dokud je vadná zpráva v historii. + +**Oprava** + +- Nový helper `_validate_header(raw, name)` v `server.py` — přijímá pouze `dict`, odmítá `str`/`bytes`, limit 4096 bajtů. +- `handle_send_message`: message-level i per-recipient headers procházejí `_validate_header()`. Nevalidní hlavička → error response, zpráva se neuloží. +- `handle_get_messages`: `json.loads()` obaleno `try/except` (JSONDecodeError, TypeError, UnicodeDecodeError). Corrupted header → prázdný dict `{}` + warning log, ostatní zprávy se načtou normálně. +- `_validate_header()` rozšířena o validaci očekávaných klíčů a typů pro ratchet headers (`dh_pub`: str, `n`: int, `pn`: int) a používá striktní kontrolu typu pro `n/pn` (`type(...) is int`) — `bool` je explicitně odmítnut. +- Realtime push notifikace nyní čtou data z validovaných `db_recipients` (ne z `recipients_raw`). Per-recipient hlavičky se dekódují z validovaných bytes zpět do `dict` pro JSON notifikaci. +- `encrypted_content` a `nonce` v push notifikacích se skládají z validovaných raw bytes a serializují se přes `encode_binary()` — untrusted hodnoty z raw requestu se do push větve nepropíší. + +--- + +### C3. Historický nález: plaintext tajemství v `.env` a `zaloha/.env` + +Status 2026-03-27: v aktuálním workspace nereprodukovatelné. + +**Evidence** + +- Tento snapshot obsahuje `.env.example`, ale neobsahuje `.env` ani `zaloha/`. +- Původní nález tedy nelze znovu ověřit bez jiného artefaktu nebo deploy prostředí. + +**Dopad** + +- Pokud jsou runtime secrets stále ukládány v reálném `.env` mimo tento snapshot, únik takového souboru by stále znamenal okamžitý přístup do DB. + +**Doporučení** + +1. Okamžitě rotovat DB heslo. +2. Nahradit repozitářové `.env` šablonou (`.env.example`) bez tajemství. +3. Použít secrets manager / deployment-level secret injection. + +## HIGH + +### ~~H1. Chybí TLS mezi aplikací a MySQL~~ ✅ OPRAVENO (2026-03-27) + +**Evidence** + +- DB pool načítá volitelné TLS parametry `MYSQL_SSL_CA`, `MYSQL_SSL_CERT`, `MYSQL_SSL_KEY` a předává je do `MySQLConnectionPool` (`db.py`). +- Konfigurační šablona je doplněná o stejné proměnné (`.env.example`). + +**Dopad** + +- Odposlech nebo MITM na trase app<->DB může odhalit credentials i data. + +**Oprava** + +1. Přidána podpora TLS parametrů v DB vrstvě. +2. Přidány dokumentované env proměnné pro CA/client cert/client key. + +--- + +### ~~H2. Upload/avatary na disku mají slabá oprávnění~~ ✅ OPRAVENO (2026-03-27) + +**Evidence** + +- Avatar upload flow nyní po zápisu explicitně nastavuje `chmod(0o600)` pro user i group avatary (`server.py`). + +**Dopad** + +- Lokální uživatelé na stejném hostu mohou číst citlivá data (včetně avatarů v plaintextu). + +**Oprava** + +1. Explicitní `chmod(0o600)` po zápisu avatar souborů. +2. Adresář `uploads/avatars` zůstává s `0700`. + +--- + +### ~~H3. `session_reset` nemá autorizační vazbu na vztah mezi uživateli~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- Handler přijme libovolné validní `peer_user_id` a pošle notifikaci (`server.py:2040-2052`). +- Neověřuje, že uživatelé sdílí konverzaci nebo existuje session. + +**Dopad** + +- Možnost spam/DoS reset notifikací na cílové uživatele. + +**Oprava** + +- Nová DB funkce `db.shares_conversation(user_id_a, user_id_b)` — `SELECT 1 ... LIMIT 1` přes `conversation_members` JOIN. +- `handle_session_reset`: před push notifikací ověřuje `shares_conversation()`. Pokud uživatelé nesdílí žádnou konverzaci → error response. +- Rate limit 5 požadavků/min na `session_reset` per user (`session_reset|{user_id}`) — IP adresa není součást klíče, takže změnou IP nejde limit obejít. +- Pokud je předán `peer_device_id`, reset notifikace se doručí pouze cílovému zařízení (filtr přes `writer_device_map`). Bez `peer_device_id` zůstává broadcast na všechna zařízení peera. + +--- + +### ~~H4. User enumeration přes pairing a user-info endpointy~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- `pairing_start` vrací explicitně `User not found` (`server.py:763-766`). +- `get_user_info` vrací metadata uživatele při lookupu přes email/user_id (`server.py:551-564`). + +**Dopad** + +- Snadné mapování existence účtů. + +**Oprava** + +- `handle_pairing_start`: vždy vrací `ok` s platně vypadajícím kódem a session se vytvoří vždy (i pro neexistující email), takže `pairing_poll` vrací nerozlišitelné `ready: false`. +- Přidán globální cap `PAIRING_MAX_SESSIONS = 100` pro omezení počtu současných pairing sessions (DoS hardening). +- `pairing_start` rate limit je per-IP (bez email komponenty), aby nešel obcházet rotací emailů. +- `pairing_claim` i `pairing_send`: sjednocená chyba `Invalid or expired code` (žádné rozlišení "neexistuje" vs "patří jinému účtu"). +- V pairing flow se síťové I/O (`send_resp`) volá až po uvolnění `_pairing_lock`. +- `handle_get_user_info`: přidán parametr `session` (vyžaduje login). Lookupovat lze jen sebe nebo kontakty (ověřeno přes `shares_conversation()`). Pro neexistující i nepovolené cíle vrací neutrální "User not found". +- Doplňuje dřívější anti-enumeration opravy: `register_start` (generická odpověď), `login_start` (fake challenge), `login_finish` (generická chyba). + +--- + +### ~~H5. Phantom user inflation přes `create_conversation` / `find_conversation` / `add_member` (DoS)~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- `create_conversation` vytváří phantom účty pro neznámé emaily bez dedikovaného rate limitu (`server.py:906-920`). +- `find_conversation` a `add_member` rate limitují přes `_rate_limit_key(..., addr, email)`, takže rotace emailů obchází limit (`server.py:972`, `server.py:1001`, `server.py:209-212`). +- `create_phantom_user()` pro každý nový email generuje IK+SPK+OTP a zapisuje více řádků do DB (`db.py:1470-1507`). + +**Dopad** + +- Útočník může nafukovat DB a CPU náklady (kryptografická generace + zápisy), případně degradovat výkon serveru. + +**Oprava** + +1. `_can_create_phantom(addr, user_id)` helper kontroluje 3 limity před každým `create_phantom_user()`: + - Globální cap: `MAX_PHANTOM_USERS = 500` (počet v `phantom_user_ids` setu) + - Per-user rate: `phantom_create|{user_id}` — 10/min (email-nezávislé, neobejitelné rotací) + - Per-IP rate: `phantom_create_ip|{addr}` — 10/min (email-nezávislé) +2. `create_conversation` nově má per-user rate limit 10/min + phantom check před každým členem. +3. `find_conversation` a `add_member` — existující per-addr+email limit zůstává (brání hammering jednoho emailu), přidán `_can_create_phantom` check před vytvořením phantomu. +4. Stávající `cleanup_stale_phantoms(30)` v periodic cleanup (10 min) zajišťuje garbage collection. + +--- + +### ~~H6. `pending_registrations` nemá hard cap (memory/SMTP abuse)~~ ✅ OPRAVENO (2026-03-16) + +**Evidence** + +- `pending_registrations` byl původně globální in-memory dict bez horního limitu. +- Cleanup expirovaných registrací se dříve spouštěl jen v `register_*` flow. +- Rate limit `register_start` byl vázaný i na email, takže šel obcházet rotací emailových adres. + +**Dopad** + +- Riziko růstu paměti, zaplnění slotů a SMTP abuse při masivním `register_start`. + +**Oprava** + +1. Přidán globální cap `MAX_PENDING_REGISTRATIONS = 1000` (`server.py:180`). +2. Přidány slot limity `MAX_PENDING_PER_IP = 5` a `MAX_PENDING_PER_SUBNET = 20` (`server.py:181-182`). +3. `_cleanup_registrations()` běží i v periodickém cleanup tasku (`server.py:327`, `server.py:2920`). +4. Přidán per-IP rate limit `register_start_ip|{addr}` (`server.py:476`) a pressure mode s PoW při vysokém zaplnění (`server.py:183`, `server.py:521`). +5. SMTP throttling je vícevrstvý: global, per-IP a per-target (`server.py:185-187`, `server.py:580`). + +**Poznámka** + +- Residual risk zůstává při multi-process nasazení nad stejnou DB: caps a rate-limity jsou in-memory per process, ne globálně distribuované. + +--- + +### ~~H7. Pairing flow důvěřuje serverem vrácenému `temp_public_key` (exfiltrace account private keys)~~ ✅ OPRAVENO (2026-03-16) + +**Evidence** + +- Nové zařízení pošle do `pairing_start` svůj `temp_public_key`; server si ho uloží do `pairing_sessions` (`server.py:1067`, `server.py:1089-1095`). +- Staré zařízení v `pairing_claim` získá `temp_public_key` zpět ze serveru a bez další autentizace ho načte (`server.py:1116-1140`, `chat_core.py:1441-1446`). +- Staré zařízení následně do payloadu zabalí `rsa_private` a `identity_private` a zašifruje je právě pod tento `temp_public_key` (`chat_core.py:1451-1477`). + +**Dopad** + +- Kompromitovaný nebo aktivně škodlivý server může v odpovědi na `pairing_claim` podvrhnout vlastní `temp_public_key`. +- Staré zařízení pak zašifruje exportovaný payload pod klíč útočníka/serveru, který tím získá `rsa_private` i `identity_private` oběti. +- To znamená plné převzetí účtu, možnost přihlášení jako oběť a přístup k self-encrypted historii (derivace `self/local` klíčů z identity private key). + +**Oprava** + +1. Přidán `compute_pairing_fingerprint()` helper nad raw dočasným veřejným klíčem (30 číslic pro ruční porovnání). +2. Pairing bootstrap už nepoužívá dočasný RSA transport, ale `X25519 + HKDF + AES-GCM`: nové zařízení pošle dočasný X25519 public key, staré zařízení vygeneruje jednorázový X25519 sender key a obě strany odvodí stejný symmetric bootstrap key z DH shared secret. +3. Nové zařízení po `pairing_start` zobrazuje 8místný kód i fingerprint dočasného pairing klíče. +4. Staré zařízení při `authorize_device` vyžaduje fingerprint opsaný z nového zařízení; před `pairing_send` vypočítá fingerprint klíče vráceného serverem a při neshodě celý pairing odmítne. +5. Tím se zavádí povinná out-of-band vazba mezi oběma zařízeními a server už nemůže nepozorovaně podvrhnout vlastní pairing key ani získat bootstrap secret. + +--- + +## MEDIUM + +### ~~M6. Chybí auditní notifikace po přidání nového zařízení~~ ✅ OPRAVENO (2026-03-16) + +**Evidence** + +- Po úspěšném `pairing_send` server pouze uloží payload a vrátí `OK` (`server.py:1143-1175`). +- V serveru ani klientech není samostatný notif type typu `device_added` / `device_linked`; zařízení lze zjistit až dodatečně přes `list_devices`. + +**Dopad** + +- Pokud uživatel omylem nebo po sociálním inženýrství schválí cizí pairing kód, ostatní aktivní zařízení nedostanou okamžitý auditní signál. +- Zhoršuje to detekci zneužití a forenzní dohledatelnost. + +**Oprava** + +1. Server po prvním loginu nově vytvořeného zařízení posílá na ostatní zařízení účtu push notifikaci `device_added`. +2. Payload notifikace obsahuje `device_id`, `device_name`, zdrojovou IP a čas přidání. +3. GUI zobrazuje bezpečnostní alert a zvýrazněný status bar. +4. CLI vypisuje explicitní auditní hlášku s doporučením okamžité rotace klíčů, pokud zařízení uživatel nepoznává. + +--- + +### ~~M1. `mark_read` a `confirm_delivery` neověřují, že `message_ids` patří do dané konverzace~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- Handler validuje členství jen v `conversation_id` (`server.py:1464-1479`, `server.py:1516-1531`). +- DB insert metody pro receipts neváží `message_id` na konverzaci (`db.py:1102-1113`, `db.py:1188-1200`). + +**Dopad** + +- Možná manipulace read/delivery stavu cizích zpráv (integrita metadat). + +**Oprava** + +- `db.mark_messages_read()` a `db.mark_messages_delivered()` nahrazeny z per-row `INSERT IGNORE` na batch `INSERT IGNORE ... SELECT m.id, %s FROM messages m WHERE m.id IN (...) AND m.conversation_id = %s`. +- Message IDs, které nepatří do dané konverzace, jsou tiše přeskočeny (SELECT je nevrátí). + +--- + +### ~~M2. SMTP STARTTLS bez explicitního TLS contextu~~ ✅ OPRAVENO (2026-03-27) + +**Evidence** + +- SMTP flow používá `server.starttls(context=ssl.create_default_context())` a `EHLO` před/po TLS upgrade (`server.py`). + +**Dopad** + +- Slabší kontrola TLS parametrů/verifikace dle runtime prostředí. + +**Oprava** + +1. Přidán explicitní TLS context pro STARTTLS. +2. Přidán `EHLO` před i po TLS upgrade. + +--- + +### ~~M3. CLI klient: několik lokálních hardening mezer~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- Heslo se zadává přes `input()` (echo on) (`client.py:730`, `client.py:749`, `client.py:754`). +- Zprávy se tisknou bez sanitace escape sekvencí (`client.py:491`). +- Default save path při downloadu je převzat z remote `filename` (`client.py:523-530`). + +**Dopad** + +- Shoulder-surfing hesla, terminal escape spoofing, riskantní defaultní save path. + +**Oprava** + +- Všechny password prompty (register, login, pairing, authorize device, rotate keys) nahrazeny `prompt_password()` wrapping `getpass.getpass()` — heslo se nezobrazuje na terminálu. +- `_sanitize_text()` helper stripuje control znaky (`\x00-\x1f` kromě `\t`/`\n`/`\r`) a ANSI escape sekvence. Aplikováno na `sender`, `text`, `filename` při výpisu zpráv v `_print_messages()`. +- Follow-up: `_sanitize_text()` nyní bezpečně přijímá i non-string vstupy (`None -> ""`, jinak `str(...)`), čímž se eliminuje `TypeError` při neočekávaném typu z payloadu (`client.py:32-36`). +- Follow-up: sanitace rozšířena na zbývající user-controlled CLI výpisy — seznam konverzací (`client.py:63-67`), search výsledky (`client.py:293-300`), seznam pozvánek (`client.py:612-614`), profil (`client.py:637-644`), seznam zařízení (`client.py:709-713`), verify view (`client.py:435`, `client.py:446`) a notifikace včetně reaction hodnoty (`client.py:752-774`). +- `_safe_filename()` helper: `os.path.basename()` + odstranění NUL + fallback na `"download"` pro prázdné/tečkové názvy. Aplikováno na default save path při downloadu. + +--- + +### ~~M4. `get_key_bundle` umožňuje OPK depletion (availability)~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- `handle_get_key_bundle` nemá rate limit ani authorizační vazbu na vztah mezi uživateli (`server.py:648-660`). +- DB vrstva při každém volání spotřebovává one-time prekeys (`get_key_bundles_for_user` — „Consumes one OPK per device atomically”, `db.py:394-450`). +- `target_user_id` lze získat přes `find_conversation` lookup (`server.py:966-987`). + +**Dopad** + +- Útočník může opakovanými dotazy vyčerpat OPK oběti, zhoršit doručitelnost a vynutit časté doplňování prekeys. + +**Oprava** + +1. Per-caller rate limit: `get_key_bundle|{user_id}` — 10/min (omezuje celkový počet fetchů jednoho uživatele). +2. Per-target rate limit: `get_key_bundle_target|{target_user_id}` — 20/min (omezuje rychlost vyčerpávání OPK konkrétní oběti). Autorizace probíhá před per-target RL (neautorizovaný request nespálí bucket cíle). +3. Autorizace: `shares_conversation()` — caller musí sdílet konverzaci s cílem (self-fetch povolen vždy). +4. Chybová zpráva pro neautorizovaný přístup je neutrální (`”Key bundle not available”`) — shodná s neexistujícím uživatelem. +5. **Doplňující per-user rate limity** na všechny zbývající výpočetně/DB náročné handlery (celkem 29 RL checks): + - Crypto+DB: `upload_prekeys` 5/min, `ensure_prekeys` 5/min, `rotate_keys` 3/min, `reencrypt` 10/min + - DB-heavy: `get_messages` 30/min, `delete_conv` 5/min, `delete_msg` 20/min, `react` 20/min, `remove_member` 10/min, `rename_conv` 5/min + - File I/O: `update_avatar` 5/min (sdílený bucket pro user i group avatar) + +--- + +### ~~M5. `upload_image_start` nemá anti-DoS cap na in-flight uploady~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- `upload_image_start` nevynucuje request rate limit ani limit počtu aktivních uploadů na user/IP (`server.py:1786-1823`). +- In-memory `pending_uploads` je bez explicitního capu (`server.py:58`, `server.py:1812-1819`). +- Cleanup stale uploadů běží periodicky (600s) a DB stale threshold je 3600s (`server.py:2488-2490`, `db.py:1626-1633`). + +**Dopad** + +- Útočník může zahájit mnoho uploadů a vytvářet dočasné soubory/záznamy, což zvyšuje memory/disk tlak. + +**Oprava** + +1. Per-user rate limit: `upload_start|{user_id}` — 10/min. +2. Globální cap: `MAX_UPLOADS_GLOBAL = 200` (kontrola `len(pending_uploads)` pod `_uploads_lock`). +3. Per-user cap: `MAX_UPLOADS_PER_USER = 5` (počet záznamů s `uploader_id == user_id`). +4. Stale threshold snížen z 3600s na `UPLOAD_STALE_SECONDS = 600` (10 min). +5. Periodic cleanup interval snížen z 600s na 120s (2 min). + +## LOW + +### ~~L1. `decode_binary` není strict base64~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- `base64.b64decode(data)` bez `validate=True` (`protocol.py:18`). + +**Dopad** + +- Méně striktní input parsing (robustnost), ne přímý průnik. + +**Oprava** + +- `decode_binary()` nyní volá `base64.b64decode(data, validate=True)` — odmítá neplatné base64 znaky (whitespace, non-alphabet). + +--- + +### ~~L2. Pairing používá dočasný RSA-2048 klíč, zatímco zbytek aplikace standardně generuje RSA-4096~~ ✅ OPRAVENO (2026-03-16) + +**Evidence** + +- `pairing_start()` generuje dočasný RSA klíč přes `generate_rsa_keypair(2048)` (`chat_core.py:1359`). +- Default helper `generate_rsa_keypair()` používá 4096 bitů (`crypto_utils.py:64`). + +**Dopad** + +- Není to primární slabina pairing flow; hlavní problém je autenticita `temp_public_key` (H7). +- Přesto jde o zbytečně slabší parametr pro přenos payloadu obsahujícího account private keys. + +**Oprava** + +1. Pairing bootstrap už dočasné RSA vůbec nepoužívá; byl nahrazen `X25519 + HKDF + AES-GCM`. +2. Tím odpadá původní důvod pro sjednocení na RSA-4096. + +--- + +### ~~L3. Pairing UX explicitně nevaruje, že párovací kód se nesmí sdílet~~ ✅ OPRAVENO (2026-03-16) + +**Evidence** + +- GUI po vygenerování kódu zobrazuje instrukci k jeho schválení, ale bez výrazného warningu proti sdílení kódu. +- Aktuální pairing model spoléhá na to, že uživatel 8místný kód neprozradí třetí straně. + +**Dopad** + +- Zvyšuje se riziko sociálního inženýrství ("nadiktujte mi kód z nového zařízení"), i když brute-force samotného kódu je při současných limitech nepraktický. + +**Oprava** + +1. GUI i CLI nyní při párování zobrazují explicitní warning `Never share this pairing code.` +2. Fingerprint nového zařízení se zobrazuje spolu s kódem, takže uživatel dostává zároveň instrukci k bezpečnému ručnímu ověření. +3. GUI nově zobrazuje pairing QR a staré zařízení ho může načíst ze souboru místo ručního opisování kódu a fingerprintu. + +--- + +### ~~L4. `reencrypt_history()` po pairingu prozrazuje serveru timing a rozsah self-history~~ ✅ MITIGOVÁNO (2026-03-17) + +**Evidence** + +- Po úspěšném `pairing_send` staré zařízení asynchronně spouští `reencrypt_history()` (`chat_core.py:1477-1485`). +- Server z batch operací vidí, že právě proběhlo párování, a přibližně kolik self-encrypted zpráv bylo potřeba přegenerovat. + +**Dopad** + +- Jde o metadata leak, nikoli o únik obsahu zpráv. +- Server může odhadnout velikost historie a intenzitu používání účtu. + +**Oprava** + +1. Post-pairing history resync už nezačíná okamžitě; běží po náhodném odkladu. +2. Pořadí konverzací i pořadí zpráv se před fetch/upload fází míchá. +3. Mezi fetch cykly i mezi upload batchi je náhodný jitter, takže pairing už negeneruje tak snadno korelovatelný burst. +4. Residual leak zůstává nízký: server stále ví, že nějaký history resync proběhl, ale výrazně hůř z něj odvodí přesný okamžik pairingu a strukturu resyncu po konverzacích. + +## Positive Findings + +- Dev-only guardy: `TLS_INSECURE` a `TLS_AUTOGEN` jsou blokovány mimo `ENVIRONMENT=dev`. +- Server používá UUID validace v řadě handlerů. +- Upload/download ověřuje členství v konverzaci. +- Klientské private keys/storage používají PBKDF2 + AES-GCM a restriktivní perms (`0700`/`0600`) v key storage. +- Přítomný client-side lockout na opakované chybné login pokusy. +- Pairing má anti-enumeration ochrany: generická odpověď, per-IP rate limit, `poll_token`, sjednocené chyby `Invalid or expired code` a krátkodobé pairing sessions. + +## Prioritní plán oprav + +### 0-48 hodin + +1. Rotace DB hesla + odstranění tajemství z `.env`. +2. ~~Zavést OOB autentizaci pairingu (fingerprint / QR / SAS) a nepovažovat serverem vrácený `temp_public_key` za důvěryhodný.~~ ✅ DONE +3. ~~Oprava TOFU bypassu v obou X3DH cestách.~~ ✅ DONE +4. ~~Zablokování nevalidních message headers na vstupu.~~ ✅ DONE +5. Přepnutí upload storage perms na `0700/0600`. +6. ~~Omezit phantom creation (rate limit bez emailu + cap).~~ ✅ DONE +7. ~~Zavést cap pro `pending_registrations` a čistit je i v periodickém cleanupu.~~ ✅ DONE +8. ~~Přidat cap/rate limit na in-flight uploady.~~ ✅ DONE + +### 7 dní + +1. Zapnout TLS mezi app a MySQL (mTLS nebo minimálně server cert verify). +2. ~~Opravit autorizaci `session_reset`.~~ ✅ DONE +3. ~~Opravit vazbu `message_ids` na `conversation_id` pro receipts.~~ ✅ DONE +4. ~~Omezit `get_key_bundle` (rate limit + policy sdílené konverzace).~~ ✅ DONE +5. ~~Přidat `device_added` audit notifikaci a zobrazit ji v GUI/CLI.~~ ✅ DONE + +### 30 dní + +1. ~~Anti-enumeration sjednotit napříč endpointy.~~ ✅ DONE +2. ~~CLI hardening (`getpass`, output sanitace, filename sanitace).~~ ✅ DONE +3. Doplnit integrační testy pro bezpečnostní regresi (TOFU, poisoned headers, receipt authz, session_reset device targeting, anti-enumeration, DoS caps, pairing MITM / device-added audit). diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..bd87597 --- /dev/null +++ b/TODO.md @@ -0,0 +1,138 @@ +# TODO + +## Zbývající bezpečnostní nálezy + +### HIGH +- [ ] **H9: Self-encryption key** — statický/deterministický klíč z identity key (by-design pro cross-device čtení, architektonické omezení — žádná forward secrecy pro self-copies) + +### MEDIUM +- [x] **M7: MySQL TLS** — `db.get_connection()` podporuje SSL parametry (`MYSQL_SSL_CA`, `MYSQL_SSL_CERT`, `MYSQL_SSL_KEY`). + +### LOW (nízké riziko) +- [ ] L1: Hex string keys v skipped messages dict — timing side-channel (post-auth) +- [ ] L2: RatchetHeader redundantní type konverze +- [ ] L3: `notif_label.setText()` vs `setHtml()` křehkost +- [ ] L4: SQL column interpolation v `update_user_profile` (whitelist chrání) +- [ ] L5: TLS cipher suite hardening (Python defaults rozumné, ne explicitní) +- [ ] L6: Temporary pairing key cleanup z paměti +- [ ] L7: `_user_cache` indefinite growth + +## Funkční TODO + +### High Priority +- [ ] **Sender Key Redistribution** — při `add_member` redistribuovat sender keys všem členům včetně nového. Nový člen skupiny momentálně nedešifruje staré zprávy. + +### Medium Priority +- [ ] **iOS: Contact Key Verification** — safety numbers, fingerprints, QR kódy, TOFU registr. Spec viz CLAUDE.md (iOS implementation spec). +- [x] Typing indicators (`typing_start`/`typing_stop` + 3s timeout, debounce) +- [x] Delivery receipts (`message_delivered` notifikace — 1 fajfka odesláno, 2 fajfky doručeno, modré přečteno) +- [ ] Group admin roles (více adminů) +- [ ] Edit sent messages + +### Low Priority +- [ ] Desktop notifications (system tray) +- [ ] Image gallery view +- [ ] Systemd + Docker deployment + +## Před nasazením do produkce + +- [ ] **TLS certifikáty** — Let's Encrypt nebo vlastní CA. `TLS_ENABLED=true`, `TLS_CERT_FILE`, `TLS_KEY_FILE`. +- [ ] **SMTP** — reálný SMTP server pro registrační kódy. +- [x] **MySQL TLS** — SSL parametry v `db.get_connection()` pokud DB na jiném stroji. +- [ ] **UPLOAD_DIR** — persistentní disk, dostatečná kapacita, správná práva (0o700). +- [ ] **Backup** — pravidelný backup MySQL + UPLOAD_DIR. +- [ ] **Packaging** — pyinstaller / cx_Freeze pro distribuci klientů. +- [ ] **Penetrační testy** — manuální + automatizované (path traversal, DoS, race conditions, enumeration, TLS downgrade, pairing hijacking). + +## Budoucí plány + +- [ ] WebSocket upgrade (nahradit raw TCP pro lepší kompatibilitu) +- [ ] Mobilní push notifikace (APNs + FCM) +- [ ] Auto-update klientů (po packagingu) +- [ ] Monetizace — oddělený platební server (Stripe), premium kódy, free/premium tier. Detaily viz CLAUDE.md. + +## Phantom Users — Distributed Cap + +Pro multi-process deployment: +1. DB-backed quota (`system_quotas` tabulka, `SELECT ... FOR UPDATE`) +2. Same-email races přes `UNIQUE(email)` +3. Periodic reconciliation job +4. Shared rate-limits (Redis nebo DB atomic counters) +5. Concurrency testy + +## Hotovo + +### Security (všechny CRITICAL + většina HIGH/MEDIUM opraveny) +- [x] C1: readuntil DoS → LimitOverrunError handling +- [x] C2: SenderKeyState fast-forward DoS → MAX_SENDER_KEY_SKIP=256 +- [x] C3: Plaintext message cache → AES-256-GCM šifrování +- [x] C4: OPK file permissions → chmod 0o600 +- [x] C5: Upload size validation → received_bytes == file_size check +- [x] C6: Path traversal → UUID validace + is_relative_to +- [x] H1: Session/sender key šifrování → AES-256-GCM via _local_key +- [x] H2+H14: TLS hardening → ENVIRONMENT=dev guard +- [x] H3+H13: Anti-enumeration → generické odpovědi, auth pro get_user_info +- [x] H4: Race conditions → 5 asyncio.Lock guardů +- [x] H5+H6: Protocol error handling → base64/JSON exception handling +- [x] H7: Avatar path traversal → _safe_avatar_path +- [x] H8: Password memory → bytearray + zero-out +- [x] H10: Image validation → size + dimensions check +- [x] H11: Filename sanitization → os.path.basename +- [x] H12: OPK race condition → SELECT FOR UPDATE +- [x] M2: HKDF salt → b"\x00"*32 +- [x] M2 (SMTP): STARTTLS s explicitním `ssl.create_default_context()` + `EHLO` před/po TLS upgrade +- [x] M3: PBKDF2 600k iterations (ECP1 formát) +- [x] M4: SPK rotace 7 dní + grace period +- [x] M5: Rate limit cleanup +- [x] M6: TOCTOU → remove_conversation_member_atomic +- [x] M8: UUID validace všech handlerů +- [x] M9: Ratchet state rollback (snapshot/restore) +- [x] M10: message_ids cap (500) +- [x] M11: Pairing poll token (secrets.token_hex) +- [x] M12: Upload end size validation +- [x] M13: chmod 0o700/0o600 na klíčové adresáře/soubory +- [x] Avatar file perms: explicitní `chmod(0o600)` pro user/group avatary +- [x] L8: Phantom user cleanup (30 dní + email validace) +- [x] SPK/OPK šifrování na disku +- [x] Brute-force lockout (exponenciální backoff) + +### Features +- [x] X3DH + Double Ratchet (Signal Protocol) +- [x] Sender Keys pro skupiny +- [x] Multi-device support (per-device sessions, pairing) +- [x] Kontaktní verifikace (safety numbers, fingerprints, QR kódy) — Python klienti +- [x] Message padding (bucketed 64B–64KB) +- [x] Metadata privacy (log sanitizace, retention, sender chain minimalizace) +- [x] Secure deletion (overwrite + fsync + unlink) +- [x] Reakce na zprávy (6 emoji typů) +- [x] Přeposílání zpráv (text, obrázky, soubory) +- [x] Připnuté zprávy (pin/unpin + dialog) +- [x] @Mentions s autocomplete +- [x] Hledání zpráv (client-side, Ctrl+F) +- [x] Šifrované obrázky + soubory (chunked upload, až 50 MB) +- [x] Skupinové pozvánky (accept/decline) +- [x] Leave group + přenos creatora +- [x] Rename group (creator only) +- [x] Delete conversation +- [x] Group avatar +- [x] User profily (telefon, lokace, avatar, viditelnost) +- [x] Online/offline status +- [x] Unread count badges (server-side pro offline uživatele) +- [x] Privacy overlay / lock screen +- [x] Dark/light téma (Catppuccin + Signal) s live switching +- [x] Session recovery (reset + auto X3DH) +- [x] Connection indicator + auto-reconnect +- [x] TCP keepalive (SO_KEEPALIVE idle=25s, interval=10s, count=3) + dead writer cleanup +- [x] Optimistic message send (okamžité zobrazení v UI, server potvrzení na pozadí) +- [x] Cache-first message loading (okamžité zobrazení z disku, server sync na pozadí) +- [x] Fetch deduplication (_messages_inflight set) +- [x] Notification push logging ([PUSH] s počtem writerů per příjemce) +- [x] Drag & drop souborů +- [x] Favorites (GUI) +- [x] Phantom users (anti-enumeration) +- [x] DB connection pooling (pool_size=10) +- [x] Version negotiation (0.8.4, MIN_CLIENT_VERSION=0.8.3) +- [x] Graceful server shutdown +- [x] iOS klient (47 Swift souborů, ~5 000 řádků) +- [x] CLI klient (23 menu opcí) +- [x] Pentest harness (4 test kategorií) diff --git a/certs/README.md b/certs/README.md new file mode 100644 index 0000000..7c074d6 --- /dev/null +++ b/certs/README.md @@ -0,0 +1,101 @@ +# TLS Setup — Let's Encrypt + Cloudflare DNS + +TLS certifikát přes Let's Encrypt bez nutnosti otevírat port 80. +Ověření domény probíhá přes DNS TXT záznam (Cloudflare API). + +## Předpoklady + +- Doména s DNS na Cloudflare (free tier stačí) +- Cloudflare API token s oprávněním "Edit zone DNS" +- Root přístup na serveru (certbot potřebuje `/etc/letsencrypt/`) + +## Postup + +### 1. Cloudflare API token + +1. Jdi na https://dash.cloudflare.com/profile/api-tokens +2. **Create Token** → Use template **"Edit zone DNS"** +3. Zone Resources → vybrat svou doménu +4. Zkopíruj vygenerovaný token + +### 2. Credentials soubor + +```bash +cp cloudflare.ini.example cloudflare.ini +nano cloudflare.ini # vlož API token +chmod 600 cloudflare.ini +``` + +### 3. Získání certifikátu + +```bash +sudo ./setup-tls.sh chat.example.com +``` + +Skript nainstaluje certbot + Cloudflare plugin, získá certifikát a vytvoří symlinky v tomto adresáři. + +### 4. Konfigurace serveru + +Přidej do `.env` v kořenovém adresáři projektu: + +```env +TLS_ENABLED=true +TLS_CERT_FILE=/etc/letsencrypt/live/chat.example.com/fullchain.pem +TLS_KEY_FILE=/etc/letsencrypt/live/chat.example.com/privkey.pem +``` + +### 5. Konfigurace klienta + +Na klientovi stačí: + +```env +TLS_ENABLED=true +``` + +Let's Encrypt je v systémovém trust store — klient ověří certifikát automaticky. + +## Obnova certifikátu + +Certbot obnovuje certifikát automaticky přes systemd timer (každých ~60 dní, cert platí 90). + +```bash +# Ověřit že timer běží +systemctl status certbot.timer + +# Ruční obnova (test) +sudo certbot renew --dry-run +``` + +Po úspěšné obnově se spustí `reload-server.sh` (deploy hook) — restartuje chat server aby načetl nový certifikát. + +## Soubory + +| Soubor | Účel | +|--------|------| +| `setup-tls.sh` | Instalace certbot + získání certifikátu | +| `reload-server.sh` | Deploy hook — restartuje server po renew | +| `cloudflare.ini.example` | Šablona pro Cloudflare API token | +| `cloudflare.ini` | Tvůj API token (gitignored) | + +## FAQ + +**Funguje certifikát na nestandardním portu (např. 9999)?** +Ano. Certifikát je vázaný na doménu, ne na port. `chat.example.com:9999` funguje. + +**Musím otevírat port 80?** +Ne. DNS challenge ověřuje doménu přes DNS TXT záznam, žádný HTTP požadavek na server. + +**Co když nemám Cloudflare?** +Můžeš použít ruční DNS challenge (bez automatického renew): +```bash +sudo certbot certonly --manual --preferred-challenges dns -d chat.example.com +``` +Certbot ti řekne jaký TXT záznam přidat. Při renew to musíš opakovat ručně. + +**Dev/testování bez certifikátu?** +```env +ENVIRONMENT=dev +TLS_ENABLED=true +TLS_AUTOGEN=true # server vygeneruje self-signed cert +TLS_INSECURE=true # klient přeskočí ověření +``` diff --git a/certs/reload-server.sh b/certs/reload-server.sh new file mode 100755 index 0000000..abc7449 --- /dev/null +++ b/certs/reload-server.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +# Deploy hook — spustí se automaticky po úspěšném renew certifikátu +# Certbot volá tento skript s RENEWED_LINEAGE a RENEWED_DOMAINS env vars +# +# Restartuje chat server aby načetl nový certifikát. +# Přizpůsob podle toho jak server spouštíš (systemd / screen / přímý proces). + +set -euo pipefail + +echo "Certifikát obnoven pro: ${RENEWED_DOMAINS:-unknown}" + +# Varianta 1: Systemd service +if systemctl is-active --quiet encrypted-chat 2>/dev/null; then + systemctl restart encrypted-chat + echo "Server restartován (systemd)." + exit 0 +fi + +# Varianta 2: Poslat SIGINT procesu (graceful shutdown + ruční restart) +PID=$(pgrep -f "python.*server.py" || true) +if [ -n "$PID" ]; then + echo "Posílám SIGINT procesu $PID (server.py)" + kill -INT "$PID" + echo "Server zastaven. Spusť ho znovu ručně nebo přes systemd." + exit 0 +fi + +echo "VAROVÁNÍ: Server proces nenalezen. Restartuj server ručně." diff --git a/certs/setup-tls.sh b/certs/setup-tls.sh new file mode 100755 index 0000000..1c1f7c6 --- /dev/null +++ b/certs/setup-tls.sh @@ -0,0 +1,108 @@ +#!/usr/bin/env bash +# Setup TLS certifikátu přes Let's Encrypt + Cloudflare DNS challenge +# Nevyžaduje otevřený port 80 — ověření přes DNS TXT záznam +# +# Použití: +# 1. Přesuň DNS domény na Cloudflare (free tier stačí) +# 2. Vytvoř API token: https://dash.cloudflare.com/profile/api-tokens +# -> Use template "Edit zone DNS" -> vybrat doménu +# 3. cp cloudflare.ini.example cloudflare.ini +# Vlož token, chmod 600 cloudflare.ini +# 4. sudo ./setup-tls.sh chat.example.com +# +# Po úspěšném získání certifikátu přidej do .env: +# TLS_ENABLED=true +# TLS_CERT_FILE=/etc/letsencrypt/live/DOMENA/fullchain.pem +# TLS_KEY_FILE=/etc/letsencrypt/live/DOMENA/privkey.pem + +set -euo pipefail + +DOMAIN="${1:-}" +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +CREDS="$SCRIPT_DIR/cloudflare.ini" +DEPLOY_HOOK="$SCRIPT_DIR/reload-server.sh" + +if [ -z "$DOMAIN" ]; then + echo "Použití: sudo $0 " + echo "Příklad: sudo $0 chat.example.com" + exit 1 +fi + +if [ "$EUID" -ne 0 ]; then + echo "Spusť jako root: sudo $0 $DOMAIN" + exit 1 +fi + +if [ ! -f "$CREDS" ]; then + echo "Chybí $CREDS" + echo "Zkopíruj cloudflare.ini.example -> cloudflare.ini a vlož API token." + exit 1 +fi + +# Ověř oprávnění credentials souboru +PERMS=$(stat -c %a "$CREDS" 2>/dev/null || stat -f %Lp "$CREDS" 2>/dev/null) +if [ "$PERMS" != "600" ]; then + echo "VAROVÁNÍ: $CREDS má oprávnění $PERMS, nastavuji 600" + chmod 600 "$CREDS" +fi + +echo "=== Instalace certbot + Cloudflare pluginu ===" +if ! command -v certbot &>/dev/null; then + apt-get update + apt-get install -y certbot python3-certbot-dns-cloudflare + echo "Certbot nainstalován." +else + echo "Certbot již nainstalován." + # Doinstaluj plugin pokud chybí + if ! python3 -c "import certbot_dns_cloudflare" 2>/dev/null; then + apt-get install -y python3-certbot-dns-cloudflare + fi +fi + +echo "" +echo "=== Získání certifikátu pro $DOMAIN ===" +DEPLOY_ARGS="" +if [ -f "$DEPLOY_HOOK" ] && [ -x "$DEPLOY_HOOK" ]; then + DEPLOY_ARGS="--deploy-hook $DEPLOY_HOOK" + echo "Deploy hook: $DEPLOY_HOOK" +fi + +certbot certonly \ + --dns-cloudflare \ + --dns-cloudflare-credentials "$CREDS" \ + --dns-cloudflare-propagation-seconds 30 \ + -d "$DOMAIN" \ + --non-interactive \ + --agree-tos \ + --register-unsafely-without-email \ + $DEPLOY_ARGS + +CERT_DIR="/etc/letsencrypt/live/$DOMAIN" +if [ -d "$CERT_DIR" ]; then + echo "" + echo "=== Certifikát úspěšně získán ===" + echo "" + echo "Soubory:" + echo " Certifikát: $CERT_DIR/fullchain.pem" + echo " Klíč: $CERT_DIR/privkey.pem" + echo "" + echo "Přidej do .env:" + echo " TLS_ENABLED=true" + echo " TLS_CERT_FILE=$CERT_DIR/fullchain.pem" + echo " TLS_KEY_FILE=$CERT_DIR/privkey.pem" + echo "" + echo "Na klientovi stačí:" + echo " TLS_ENABLED=true" + echo "" + echo "Automatický renew: certbot timer (systemd) nebo cron" + echo " systemctl status certbot.timer" + echo "" + + # Symlinky pro snadný přístup + ln -sf "$CERT_DIR/fullchain.pem" "$SCRIPT_DIR/fullchain.pem" + ln -sf "$CERT_DIR/privkey.pem" "$SCRIPT_DIR/privkey.pem" + echo "Symlinky vytvořeny v $SCRIPT_DIR/" +else + echo "CHYBA: Certifikát nebyl vytvořen." + exit 1 +fi diff --git a/chat_core.py b/chat_core.py new file mode 100644 index 0000000..59d7260 --- /dev/null +++ b/chat_core.py @@ -0,0 +1,3907 @@ +"""Shared network layer and ChatClient class for CLI and GUI clients. + +Uses X3DH + Double Ratchet for message encryption, Sender Keys for groups. +RSA retained for login challenge-response only. +""" + +import asyncio +import collections +import hashlib +import json +import logging +import os +import random +import ssl +import time +import uuid +from datetime import datetime, timezone +from pathlib import Path + +from dotenv import load_dotenv + +load_dotenv() + +from crypto_utils import ( + # RSA (login only) + generate_rsa_keypair, + serialize_private_key, + serialize_public_key, + load_private_key, + load_public_key, + rsa_sign, + # Ed25519 + generate_identity_keypair, + serialize_ed25519_private, + serialize_ed25519_private_raw, + serialize_ed25519_public, + load_ed25519_private, + load_ed25519_public, + ed25519_sign, + # X25519 + generate_x25519_keypair, + serialize_x25519_private, + serialize_x25519_public, + load_x25519_private, + load_x25519_public, + x25519_dh, + derive_pairing_shared_key, + # X3DH + generate_signed_prekey, + generate_one_time_prekeys, + x3dh_initiate, + x3dh_respond, + # Double Ratchet + DoubleRatchet, + # Sender Keys + SenderKeyState, + # AES + aes_encrypt, + aes_decrypt, + # Self-encryption + derive_self_encryption_key, + # Local storage encryption + derive_local_storage_key, + # Contact verification + compute_fingerprint, + compute_pairing_fingerprint, + encode_pairing_qr, + format_fingerprint, + normalize_pairing_fingerprint, + compute_safety_number, + encode_verification_qr, + decode_verification_qr, + # Message padding + pad_plaintext, + unpad_plaintext, +) +from protocol import ( + VERSION, + ProtocolReader, + ProtocolWriter, + encode_binary, + decode_binary, + build_request, + MAX_MESSAGE_BYTES, + MAX_IMAGE_BYTES, + IMAGE_CHUNK_SIZE, +) + + +KEY_DIR = Path.home() / ".encrypted_chat" +OPK_REPLENISH_THRESHOLD = 20 +OPK_BATCH_SIZE = 50 +SPK_ROTATION_DAYS = 7 +PAIRING_REENCRYPT_INITIAL_DELAY_RANGE = (20.0, 75.0) +PAIRING_REENCRYPT_INTER_BATCH_DELAY_RANGE = (1.0, 3.0) +PAIRING_REENCRYPT_INTER_FETCH_DELAY_RANGE = (0.15, 0.5) +PAIRING_REENCRYPT_BATCH_SIZE = 500 + + +def _encrypt_local(data: bytes, key: bytes) -> bytes: + """Encrypt data with AES-256-GCM for local storage. Format: nonce(12) + tag(16) + ciphertext.""" + _, nonce, ct, tag = aes_encrypt(data, key=key) + return nonce + tag + ct + + +def _decrypt_local(raw: bytes, key: bytes) -> bytes: + """Decrypt data encrypted by _encrypt_local.""" + nonce, tag, ct = raw[:12], raw[12:28], raw[28:] + return aes_decrypt(key, nonce, ct, tag) + + +def get_key_dir(email: str) -> Path: + d = KEY_DIR / email + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + return d + + +# --------------------------------------------------------------------------- +# RSA key storage (login only — unchanged interface) +# --------------------------------------------------------------------------- + +def save_keys(email: str, private_key, public_key, password: bytes | None = None): + d = get_key_dir(email) + (d / "private.pem").write_bytes(serialize_private_key(private_key, password=password)) + (d / "public.pem").write_bytes(serialize_public_key(public_key)) + os.chmod(d / "private.pem", 0o600) + + +def load_keys(email: str, password: bytes | None = None): + d = get_key_dir(email) + priv_path = d / "private.pem" + pub_path = d / "public.pem" + if not priv_path.exists(): + return None, None, "No local keys found." + pem = priv_path.read_bytes() + try: + private_key = load_private_key(pem, password=password) + except Exception: + try: + private_key = load_private_key(pem, password=None) + if password: + save_keys(email, private_key, load_public_key(pub_path.read_bytes()), password=password) + except Exception: + return None, None, "Invalid or missing password." + public_key = load_public_key(pub_path.read_bytes()) + return private_key, public_key, None + + +# --------------------------------------------------------------------------- +# Identity + prekey storage +# --------------------------------------------------------------------------- + +def _save_identity_keys(email: str, ed_priv, ed_pub, password: bytes | None = None): + d = get_key_dir(email) + if password: + (d / "identity_private.bin").write_bytes(serialize_ed25519_private(ed_priv, password=password)) + else: + (d / "identity_private.bin").write_bytes(serialize_ed25519_private_raw(ed_priv)) + (d / "identity_public.bin").write_bytes(serialize_ed25519_public(ed_pub)) + os.chmod(d / "identity_private.bin", 0o600) + + +def _load_identity_keys(email: str, password: bytes | None = None): + d = get_key_dir(email) + priv_path = d / "identity_private.bin" + pub_path = d / "identity_public.bin" + if not priv_path.exists(): + return None, None + priv = load_ed25519_private(priv_path.read_bytes(), password=password) + pub = load_ed25519_public(pub_path.read_bytes()) + return priv, pub + + +def _save_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None): + d = get_key_dir(email) + raw = serialize_x25519_private(spk_priv) + data = _encrypt_local(raw, local_key) if local_key else raw + (d / "spk_private.bin").write_bytes(data) + (d / "spk_id.txt").write_text(spk_id) + os.chmod(d / "spk_private.bin", 0o600) + + +def _load_spk(email: str, local_key: bytes | None = None): + d = get_key_dir(email) + priv_path = d / "spk_private.bin" + id_path = d / "spk_id.txt" + if not priv_path.exists(): + return None, None + raw = priv_path.read_bytes() + if local_key: + try: + raw = _decrypt_local(raw, local_key) + except Exception: + # Plaintext fallback (migration) — re-save encrypted + pass + priv = load_x25519_private(raw) + spk_id = id_path.read_text().strip() if id_path.exists() else "" + if local_key: + _save_spk(email, priv, spk_id, local_key) + return priv, spk_id + + +def _save_prev_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None): + """Save previous SPK for grace period (in-flight X3DH may reference old SPK).""" + d = get_key_dir(email) + raw = serialize_x25519_private(spk_priv) + data = _encrypt_local(raw, local_key) if local_key else raw + (d / "prev_spk_private.bin").write_bytes(data) + (d / "prev_spk_id.txt").write_text(spk_id) + os.chmod(d / "prev_spk_private.bin", 0o600) + + +def _load_prev_spk(email: str, local_key: bytes | None = None): + """Load previous SPK (grace period). Returns (private_key, spk_id) or (None, None).""" + d = get_key_dir(email) + priv_path = d / "prev_spk_private.bin" + id_path = d / "prev_spk_id.txt" + if not priv_path.exists(): + return None, None + raw = priv_path.read_bytes() + if local_key: + try: + raw = _decrypt_local(raw, local_key) + except Exception: + pass + priv = load_x25519_private(raw) + spk_id = id_path.read_text().strip() if id_path.exists() else "" + if local_key: + _save_prev_spk(email, priv, spk_id, local_key) + return priv, spk_id + + +def _save_opk_private(email: str, opk_id: str, opk_priv, local_key: bytes | None = None): + d = get_key_dir(email) / "opk_private" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + raw = serialize_x25519_private(opk_priv) + data = _encrypt_local(raw, local_key) if local_key else raw + (d / f"{opk_id}.bin").write_bytes(data) + os.chmod(d / f"{opk_id}.bin", 0o600) + + +def _load_opk_private(email: str, opk_id: str, local_key: bytes | None = None): + d = get_key_dir(email) / "opk_private" + p = d / f"{opk_id}.bin" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + raw = _decrypt_local(raw, local_key) + except Exception: + pass + priv = load_x25519_private(raw) + # Migration: re-save encrypted if local_key provided + if local_key: + _save_opk_private(email, opk_id, priv, local_key) + return priv + + +def _secure_delete(p: Path): + """Overwrite file with random data before deletion (anti-forensic wipe).""" + try: + if not p.exists(): + return + size = p.stat().st_size + if size > 0: + with open(p, "r+b") as f: + f.write(os.urandom(size)) + f.flush() + os.fsync(f.fileno()) + p.unlink() + except Exception: + try: + p.unlink(missing_ok=True) + except Exception: + pass + + +def _delete_opk_private(email: str, opk_id: str): + d = get_key_dir(email) / "opk_private" + p = d / f"{opk_id}.bin" + _secure_delete(p) + + +def _save_device_id(email: str, device_id: str): + d = get_key_dir(email) + p = d / "device_id.txt" + p.write_text(device_id) + os.chmod(p, 0o600) + + +def _load_device_id(email: str) -> str | None: + d = get_key_dir(email) + p = d / "device_id.txt" + if not p.exists(): + return None + return p.read_text().strip() or None + + +# ------------------------------------------------------------------ +# Expiring LRU cache for user info (max_size + TTL eviction) +# ------------------------------------------------------------------ + +class _ExpiringLRUCache: + """Dict-like cache with max size (LRU eviction) and per-entry TTL.""" + + def __init__(self, max_size: int = 10_000, ttl: float = 3600.0): + self._max_size = max_size + self._ttl = ttl + self._data: collections.OrderedDict = collections.OrderedDict() # key -> (value, ts) + + def get(self, key, default=None): + entry = self._data.get(key) + if entry is None: + return default + value, ts = entry + if time.monotonic() - ts > self._ttl: + del self._data[key] + return default + # Move to end (most recently used) + self._data.move_to_end(key) + return value + + def __setitem__(self, key, value): + if key in self._data: + self._data.move_to_end(key) + self._data[key] = (value, time.monotonic()) + while len(self._data) > self._max_size: + self._data.popitem(last=False) + + def __getitem__(self, key): + result = self.get(key) + if result is None and key not in self._data: + raise KeyError(key) + return result + + def __contains__(self, key): + return self.get(key) is not None + + +# ------------------------------------------------------------------ +# Identity key change exception (TOFU hard-fail) +# ------------------------------------------------------------------ + +class IdentityKeyChanged(Exception): + """Raised when a peer's identity key has changed (TOFU violation). + + Session creation is blocked until the user explicitly accepts the new key. + """ + def __init__(self, user_id: str, new_key_bytes: bytes, status: str): + self.user_id = user_id + self.new_key_bytes = new_key_bytes + self.status = status # "changed" or "changed_verified" + super().__init__( + f"Identity key changed for user {user_id} (status={status}). " + f"Accept the new key before communicating." + ) + + +# ------------------------------------------------------------------ +# Client-side brute-force lockout +# ------------------------------------------------------------------ + +_LOCKOUT_BASE_SECONDS = 2 +_LOCKOUT_MAX_SECONDS = 300 # 5 min cap + + +def _get_lockout_path(email: str) -> Path: + return get_key_dir(email) / "login_lockout.json" + + +def _check_lockout(email: str) -> float: + """Return seconds remaining until next attempt allowed. 0 = can try now.""" + p = _get_lockout_path(email) + if not p.exists(): + return 0.0 + try: + data = json.loads(p.read_text()) + locked_until = data.get("locked_until", 0.0) + remaining = locked_until - time.time() + return max(0.0, remaining) + except Exception: + return 0.0 + + +def _record_failed_attempt(email: str): + """Increment failed counter, update locked_until.""" + p = _get_lockout_path(email) + failed = 0 + try: + if p.exists(): + data = json.loads(p.read_text()) + failed = data.get("failed_attempts", 0) + except Exception: + pass + failed += 1 + delay = min(_LOCKOUT_BASE_SECONDS ** failed, _LOCKOUT_MAX_SECONDS) + locked_until = time.time() + delay + p.write_text(json.dumps({"failed_attempts": failed, "locked_until": locked_until})) + os.chmod(p, 0o600) + + +def _clear_lockout(email: str): + """Reset on successful login.""" + p = _get_lockout_path(email) + if p.exists(): + try: + p.unlink() + except Exception: + pass + + +def _save_session(email: str, peer_user_id: str, ratchet: DoubleRatchet, + local_key: bytes | None = None, peer_device_id: str | None = None): + d = get_key_dir(email) / "sessions" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + if peer_device_id: + filename = f"{peer_user_id}_{peer_device_id}.bin" + else: + filename = f"{peer_user_id}.bin" + p = d / filename + data = ratchet.export_state() + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_session(email: str, peer_user_id: str, + local_key: bytes | None = None, + peer_device_id: str | None = None) -> DoubleRatchet | None: + d = get_key_dir(email) / "sessions" + if peer_device_id: + p = d / f"{peer_user_id}_{peer_device_id}.bin" + if not p.exists(): + # Fallback: try old format (no device_id) and migrate + p_old = d / f"{peer_user_id}.bin" + if p_old.exists(): + ratchet = _load_session_file(p_old, local_key) + if ratchet: + _save_session(email, peer_user_id, ratchet, local_key, + peer_device_id=peer_device_id) + _secure_delete(p_old) + return ratchet + return None + else: + p = d / f"{peer_user_id}.bin" + if not p.exists(): + return None + return _load_session_file(p, local_key) + + +def _load_session_file(p: Path, local_key: bytes | None = None) -> DoubleRatchet | None: + """Load a session from a specific file path.""" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + data = _decrypt_local(raw, local_key) + except Exception: + # Migration: try loading as plaintext (old unencrypted format) + try: + ratchet = DoubleRatchet.import_state(raw) + return ratchet + except Exception: + return None + return DoubleRatchet.import_state(data) + return DoubleRatchet.import_state(raw) + + +def _delete_session_file(email: str, peer_user_id: str, peer_device_id: str | None = None): + """Securely delete a session file from disk (for session reset).""" + d = get_key_dir(email) / "sessions" + if peer_device_id: + p = d / f"{peer_user_id}_{peer_device_id}.bin" + else: + p = d / f"{peer_user_id}.bin" + _secure_delete(p) + + +def _save_sender_key_state(email: str, conv_id: str, state: SenderKeyState, + local_key: bytes | None = None): + d = get_key_dir(email) / "sender_keys" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + p = d / f"{conv_id}.bin" + data = state.export_state() + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_sender_key_state(email: str, conv_id: str, + local_key: bytes | None = None) -> SenderKeyState | None: + d = get_key_dir(email) / "sender_keys" + p = d / f"{conv_id}.bin" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + data = _decrypt_local(raw, local_key) + except Exception: + try: + sk = SenderKeyState.import_state(raw) + _save_sender_key_state(email, conv_id, sk, local_key) + return sk + except Exception: + return None + return SenderKeyState.import_state(data) + return SenderKeyState.import_state(raw) + + +def _save_sender_key_recipients(email: str, conv_id: str, recipients: set[str], + local_key: bytes | None = None): + d = get_key_dir(email) / "sender_keys" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + p = d / f"{conv_id}.members.bin" + data = json.dumps(sorted(recipients), ensure_ascii=False).encode("utf-8") + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_sender_key_recipients(email: str, conv_id: str, + local_key: bytes | None = None) -> set[str]: + d = get_key_dir(email) / "sender_keys" + p = d / f"{conv_id}.members.bin" + if not p.exists(): + return set() + raw = p.read_bytes() + if local_key: + try: + data = _decrypt_local(raw, local_key) + except Exception: + # Migration: previous plaintext storage, re-save encrypted on success. + try: + parsed = json.loads(raw.decode("utf-8")) + recipients = {str(uid) for uid in parsed if isinstance(uid, str)} + _save_sender_key_recipients(email, conv_id, recipients, local_key) + return recipients + except Exception: + return set() + else: + data = raw + try: + parsed = json.loads(data.decode("utf-8")) + if not isinstance(parsed, list): + return set() + return {str(uid) for uid in parsed if isinstance(uid, str)} + except Exception: + return set() + + +def _save_recv_sender_key(email: str, conv_id: str, sender_id: str, state: SenderKeyState, + local_key: bytes | None = None, + sender_device_id: str | None = None): + d = get_key_dir(email) / "sender_keys_recv" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + if sender_device_id: + filename = f"{conv_id}_{sender_id}_{sender_device_id}.bin" + else: + filename = f"{conv_id}_{sender_id}.bin" + p = d / filename + data = state.export_state() + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_recv_sender_key(email: str, conv_id: str, sender_id: str, + local_key: bytes | None = None, + sender_device_id: str | None = None) -> SenderKeyState | None: + d = get_key_dir(email) / "sender_keys_recv" + if sender_device_id: + p = d / f"{conv_id}_{sender_id}_{sender_device_id}.bin" + if not p.exists(): + # Fallback: try old format and migrate + p_old = d / f"{conv_id}_{sender_id}.bin" + if p_old.exists(): + sk = _load_recv_sender_key_file(p_old, local_key) + if sk: + _save_recv_sender_key(email, conv_id, sender_id, sk, local_key, + sender_device_id=sender_device_id) + _secure_delete(p_old) + return sk + return None + else: + p = d / f"{conv_id}_{sender_id}.bin" + if not p.exists(): + return None + return _load_recv_sender_key_file(p, local_key) + + +def _load_recv_sender_key_file(p: Path, local_key: bytes | None = None) -> SenderKeyState | None: + """Load a recv sender key from a specific file path.""" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + data = _decrypt_local(raw, local_key) + except Exception: + try: + sk = SenderKeyState.import_state(raw) + return sk + except Exception: + return None + return SenderKeyState.import_state(data) + return SenderKeyState.import_state(raw) + + +# --------------------------------------------------------------------------- +# Local decrypted message cache (Double Ratchet keys are one-time use) +# --------------------------------------------------------------------------- + +def _load_message_cache(email: str, conv_id: str, cache_key: bytes | None = None) -> dict: + d = get_key_dir(email) / "message_cache" + p_bin = d / f"{conv_id}.bin" + p_json = d / f"{conv_id}.json" + + # Migration: if old plaintext .json exists but encrypted .bin doesn't + if p_json.exists() and not p_bin.exists(): + try: + cache = json.loads(p_json.read_text("utf-8")) + if cache_key: + _save_message_cache_full(d, conv_id, cache, cache_key) + _secure_delete(p_json) + return cache + except Exception: + return {} + + if not p_bin.exists(): + return {} + if not cache_key: + return {} + try: + raw = p_bin.read_bytes() + # Format: nonce (12) + tag (16) + ciphertext + nonce = raw[:12] + tag = raw[12:28] + ct = raw[28:] + plaintext = aes_decrypt(cache_key, nonce, ct, tag) + return json.loads(plaintext.decode("utf-8")) + except Exception: + return {} + + +def _save_message_cache_full(d: Path, conv_id: str, cache: dict, cache_key: bytes): + """Write the full cache dict encrypted to disk.""" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + p = d / f"{conv_id}.bin" + plaintext = json.dumps(cache, ensure_ascii=False).encode("utf-8") + _key, nonce, ct, tag = aes_encrypt(plaintext, key=cache_key) + p.write_bytes(nonce + tag + ct) + os.chmod(p, 0o600) + + +def _save_message_to_cache(email: str, conv_id: str, message_id: str, payload: dict, + cache_key: bytes | None = None): + d = get_key_dir(email) / "message_cache" + cache = _load_message_cache(email, conv_id, cache_key) + cache[message_id] = payload + if cache_key: + _save_message_cache_full(d, conv_id, cache, cache_key) + else: + # Fallback: plaintext (no identity key available yet) + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + p = d / f"{conv_id}.json" + p.write_text(json.dumps(cache, ensure_ascii=False), "utf-8") + os.chmod(p, 0o600) + + +# --------------------------------------------------------------------------- +# Verification storage (TOFU + explicit verification) +# --------------------------------------------------------------------------- + +def _save_known_identity_keys(email: str, keys: dict, local_key: bytes | None = None): + """Save TOFU identity key registry (encrypted with local_key).""" + p = get_key_dir(email) / "known_identity_keys.bin" + data = json.dumps({"version": 1, "keys": keys}).encode("utf-8") + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_known_identity_keys(email: str, local_key: bytes | None = None) -> dict: + """Load TOFU identity key registry. Returns empty dict on error. + + No plaintext fallback — these files were never stored unencrypted + (feature introduced after local encryption was implemented). + Accepting plaintext would allow an attacker with disk access to + inject fake identity keys and bypass TOFU warnings. + """ + p = get_key_dir(email) / "known_identity_keys.bin" + if not p.exists(): + return {} + raw = p.read_bytes() + try: + if local_key: + data = _decrypt_local(raw, local_key) + else: + data = raw + obj = json.loads(data) + return obj.get("keys", {}) + except Exception: + return {} + + +def _save_verified_contacts(email: str, contacts: dict, local_key: bytes | None = None): + """Save explicit verification state (encrypted with local_key).""" + p = get_key_dir(email) / "verified_contacts.bin" + data = json.dumps({"version": 1, "contacts": contacts}).encode("utf-8") + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_verified_contacts(email: str, local_key: bytes | None = None) -> dict: + """Load explicit verification state. Returns empty dict on error. + + No plaintext fallback — these files were never stored unencrypted. + Accepting plaintext would allow an attacker with disk access to + inject fake verification records (mark attacker as "verified"). + """ + p = get_key_dir(email) / "verified_contacts.bin" + if not p.exists(): + return {} + raw = p.read_bytes() + try: + if local_key: + data = _decrypt_local(raw, local_key) + else: + data = raw + obj = json.loads(data) + return obj.get("contacts", {}) + except Exception: + return {} + + +def _solve_pow(challenge: str, difficulty: int) -> str: + """Solve a proof-of-work challenge by finding a nonce with enough leading zero bits.""" + target_bytes = difficulty // 8 + target_bits = difficulty % 8 + mask = (0xFF << (8 - target_bits)) & 0xFF if target_bits else 0 + nonce = 0 + while True: + digest = hashlib.sha256(f"{challenge}{nonce}".encode()).digest() + # Fast path: check full zero bytes first + ok = True + for i in range(target_bytes): + if digest[i] != 0: + ok = False + break + if ok and target_bits: + if digest[target_bytes] & mask: + ok = False + if ok: + return str(nonce) + nonce += 1 + + +class ChatClient: + def __init__(self): + self.reader: ProtocolReader | None = None + self.writer: ProtocolWriter | None = None + self.raw_writer: asyncio.StreamWriter | None = None + self.session: dict | None = None + self.private_key = None # RSA private key (login only) + self.public_key = None # RSA public key (login only) + self.username: str = "" + self.email: str = "" + self._listener_task: asyncio.Task | None = None + self._response_queue: asyncio.Queue = asyncio.Queue() + self._notification_queue: asyncio.Queue = asyncio.Queue() + self._pending: dict[str, asyncio.Future] = {} + self._pairing_temp_private_key = None + self._pairing_fingerprint: str = "" + self._pairing_code: str = "" + self._reencrypt_progress_cb = None + self._logger = logging.getLogger("encrypted_chat.client") + + # Signal Protocol keys + self.identity_private = None # Ed25519PrivateKey + self.identity_public = None # Ed25519PublicKey + self.spk_private = None # X25519PrivateKey (current signed prekey) + self.spk_id: str = "" + self._prev_spk_private = None # Previous SPK for grace period (M4) + self._prev_spk_id: str = "" + self.opk_privates: dict[str, object] = {} # id -> X25519PrivateKey + self.sessions: dict[str, DoubleRatchet] = {} # "user_id:device_id" -> ratchet + self.sender_key_states: dict[str, SenderKeyState] = {} # conv_id -> own sender key + self.recv_sender_keys: dict[str, SenderKeyState] = {} # "conv_id:sender_id:device_id" -> their key + # Cache: user_id -> {identity_key (Ed25519PublicKey), username, email} + # Bounded to 10K entries with 1-hour TTL to prevent unbounded growth (L7) + self._user_cache: _ExpiringLRUCache = _ExpiringLRUCache(max_size=10_000, ttl=3600.0) + self.connected: bool = False + self.login_rejected: bool = False + self._cache_key: bytes | None = None # AES key for encrypting message cache on disk + self._local_key: bytes | None = None # AES key for encrypting session/sender key files + # Multi-device support + self.device_id: str | None = None # This device's UUID + self._device_bundle_cache: dict[str, tuple[float, list[dict]]] = {} # user_id -> (ts, bundles) + # Queue of received messages to self-encrypt for multi-device access + self._pending_self_encrypt: list[dict] = [] + self._typing_active: dict[str, bool] = {} + self._typing_last_sent: dict[str, float] = {} + self._typing_stop_tasks: dict[str, asyncio.Task] = {} + # Contact key verification (TOFU + explicit) + self._known_identity_keys: dict = {} # user_id -> {identity_key hex, first_seen, last_seen} + self._verified_contacts: dict = {} # user_id -> {identity_key hex, verified_at, method} + self._key_change_cb = None # callback(user_id, username, old_key_hex, was_verified) + + async def connect(self): + host = os.getenv("SERVER_HOST", "127.0.0.1") + port = int(os.getenv("SERVER_PORT", "9999")) + tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes") + tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes") + ssl_context = None + if tls_required and not tls_enabled: + raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.") + if tls_enabled: + insecure = os.getenv("TLS_INSECURE", "false").lower() in ("1", "true", "yes") + is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") + if insecure and not is_dev: + raise RuntimeError("TLS_INSECURE is only allowed when ENVIRONMENT=dev") + ssl_context = ssl.create_default_context() + ca_file = os.getenv("TLS_CA_FILE", "").strip() + if ca_file: + ssl_context.load_verify_locations(cafile=ca_file) + elif insecure: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + else: + self._logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.") + r, w = await asyncio.open_connection(host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context) + # Enable TCP keepalive to detect dead connections through NAT/firewalls + sock = w.get_extra_info("socket") + if sock is not None: + import socket + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if hasattr(socket, "TCP_KEEPIDLE"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 25) + if hasattr(socket, "TCP_KEEPINTVL"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10) + if hasattr(socket, "TCP_KEEPCNT"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) + except OSError: + pass + self.reader = ProtocolReader(r) + self.writer = ProtocolWriter(w) + self.raw_writer = w + self.connected = True + self._logger.info("Connected to %s:%s (tls=%s)", host, port, "on" if tls_enabled else "off") + + def server_endpoint(self) -> str: + host = os.getenv("SERVER_HOST", "127.0.0.1") + port = os.getenv("SERVER_PORT", "9999") + return f"{host}:{port}" + + def pairing_fingerprint(self) -> str: + return self._pairing_fingerprint + + def pairing_qr_data(self) -> bytes | None: + if not self._pairing_code or not self._pairing_fingerprint: + return None + return encode_pairing_qr(self._pairing_code, self._pairing_fingerprint) + + async def _background_listener(self): + """Read messages from server, routing responses vs notifications.""" + while True: + msg = await self.reader.read_message() + if msg is None: + self.connected = False + # Fail all pending futures so send_and_recv doesn't hang + pending = dict(self._pending) + self._pending.clear() + err = ConnectionError("Server connection lost") + for obj in pending.values(): + if isinstance(obj, asyncio.Queue): + # Signal stream consumers that connection died + obj.put_nowait({"status": "error", "data": {"message": "Connection lost"}}) + elif not obj.done(): + obj.set_exception(err) + break + # Responses to our own requests (have request_id matching a pending future) + # must be routed to the pending future, even if the type matches a notification name. + req_id = msg.get("request_id") + if req_id and req_id in self._pending: + pending_obj = self._pending[req_id] + if isinstance(pending_obj, asyncio.Queue): + await pending_obj.put(msg) + else: + self._pending.pop(req_id) + if not pending_obj.done(): + pending_obj.set_result(msg) + elif msg.get("type") in ("new_message", "messages_read", "message_deleted", + "conversation_created", "member_added", "member_removed", + "user_online", "user_offline", "online_users", + "group_invitation", "conversation_renamed", + "device_added", + "session_reset", + "message_reacted", "message_pinned", "message_unpinned", + "message_delivered", "username_changed", + "avatar_changed", + "typing_start", "typing_stop"): + await self._notification_queue.put(msg) + else: + await self._response_queue.put(msg) + + async def send_and_recv(self, msg_type: str, timeout: float = 30.0, **kwargs) -> dict: + try: + request_id = str(uuid.uuid4()) + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._pending[request_id] = fut + await self.writer.send_request(msg_type, request_id=request_id, **kwargs) + except (ValueError, ConnectionError, OSError) as e: + self._pending.pop(request_id, None) + return { + "type": msg_type, + "status": "error", + "data": {"message": str(e) or "Connection lost."}, + } + try: + return await asyncio.wait_for(fut, timeout=timeout) + except asyncio.TimeoutError: + self._logger.warning("send_and_recv timeout for '%s' after %.0fs", msg_type, timeout) + return { + "type": msg_type, + "status": "error", + "data": {"message": f"Request timed out ({msg_type})"}, + } + except ConnectionError: + return { + "type": msg_type, + "status": "error", + "data": {"message": "Connection lost."}, + } + finally: + self._pending.pop(request_id, None) + + # ------------------------------------------------------------------ + # User info / identity key cache + # ------------------------------------------------------------------ + + async def _get_user_info(self, user_id: str = "", email: str = "") -> dict | None: + """Get user info from server, cache identity key. Performs TOFU check.""" + cached = self._user_cache.get(user_id) + if cached: + return cached + kwargs = {} + if user_id: + kwargs["user_id"] = user_id + elif email: + kwargs["email"] = email + else: + return None + resp = await self.send_and_recv("get_user_info", **kwargs) + if resp["status"] != "ok": + return None + data = resp["data"] + ik_bytes = decode_binary(data["identity_key"]) if data.get("identity_key") else None + info = { + "user_id": data["user_id"], + "username": data["username"], + "email": data["email"], + "identity_key": load_ed25519_public(ik_bytes) if ik_bytes else None, + "identity_key_bytes": ik_bytes, + } + # TOFU: check identity key against known keys + if ik_bytes: + status = self.check_identity_key(data["user_id"], ik_bytes) + info["identity_key_status"] = status + self._user_cache[data["user_id"]] = info + return info + + # ------------------------------------------------------------------ + # Contact Key Verification + # ------------------------------------------------------------------ + + def _load_verification_stores(self): + """Load TOFU and verification stores from disk.""" + if not self.email: + return + self._known_identity_keys = _load_known_identity_keys(self.email, self._local_key) + self._verified_contacts = _load_verified_contacts(self.email, self._local_key) + + def check_identity_key(self, user_id: str, identity_key_bytes: bytes) -> str: + """Check a user's identity key against TOFU registry. + + Returns: + "new" — first contact, key recorded (TOFU) + "trusted" — key matches previously seen, not explicitly verified + "verified" — key matches and explicitly verified + "changed" — key differs from recorded (WARNING) + "changed_verified" — key changed AND was previously verified (CRITICAL) + """ + ik_hex = identity_key_bytes.hex() + now = datetime.now(timezone.utc).isoformat() + known = self._known_identity_keys.get(user_id) + + if known is None: + # First time seeing this user — TOFU: trust on first use + self._known_identity_keys[user_id] = { + "identity_key": ik_hex, + "first_seen": now, + "last_seen": now, + } + if self.email: + _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) + return "new" + + if known["identity_key"] == ik_hex: + # Key matches — update last_seen + known["last_seen"] = now + if self.email: + _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) + # Check if explicitly verified + verified = self._verified_contacts.get(user_id) + if verified and verified.get("identity_key") == ik_hex: + return "verified" + return "trusted" + + # Key has CHANGED + was_verified = user_id in self._verified_contacts + old_key_hex = known["identity_key"] + + # Invoke callback for GUI/CLI warning + if self._key_change_cb: + username = "" + cached = self._user_cache.get(user_id) + if cached: + username = cached.get("username", "") + try: + self._key_change_cb(user_id, username, old_key_hex, was_verified, identity_key_bytes) + except Exception: + pass + + return "changed_verified" if was_verified else "changed" + + def verify_contact(self, user_id: str, identity_key_bytes: bytes, method: str = "manual"): + """Mark a contact's identity key as explicitly verified.""" + ik_hex = identity_key_bytes.hex() + now = datetime.now(timezone.utc).isoformat() + self._verified_contacts[user_id] = { + "identity_key": ik_hex, + "verified_at": now, + "method": method, + } + # Also ensure TOFU registry is up to date + if user_id not in self._known_identity_keys: + self._known_identity_keys[user_id] = { + "identity_key": ik_hex, + "first_seen": now, + "last_seen": now, + } + else: + self._known_identity_keys[user_id]["last_seen"] = now + if self.email: + _save_verified_contacts(self.email, self._verified_contacts, self._local_key) + _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) + # Update user cache status + cached = self._user_cache.get(user_id) + if cached: + cached["identity_key_status"] = "verified" + + def unverify_contact(self, user_id: str): + """Remove explicit verification for a contact.""" + self._verified_contacts.pop(user_id, None) + if self.email: + _save_verified_contacts(self.email, self._verified_contacts, self._local_key) + cached = self._user_cache.get(user_id) + if cached and cached.get("identity_key_status") == "verified": + cached["identity_key_status"] = "trusted" + + def accept_key_change(self, user_id: str, new_ik_bytes: bytes): + """Accept a changed identity key — update TOFU, remove old verification.""" + ik_hex = new_ik_bytes.hex() + now = datetime.now(timezone.utc).isoformat() + self._known_identity_keys[user_id] = { + "identity_key": ik_hex, + "first_seen": now, + "last_seen": now, + } + # Remove old verification — user must re-verify + self._verified_contacts.pop(user_id, None) + if self.email: + _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) + _save_verified_contacts(self.email, self._verified_contacts, self._local_key) + # Update cache + cached = self._user_cache.get(user_id) + if cached: + cached["identity_key_status"] = "trusted" + + def get_verification_status(self, user_id: str) -> str: + """Get verification status for a user. + + Returns: "verified", "trusted", or "unverified". + """ + verified = self._verified_contacts.get(user_id) + if verified: + # Check key still matches + known = self._known_identity_keys.get(user_id) + if known and known.get("identity_key") == verified.get("identity_key"): + return "verified" + if user_id in self._known_identity_keys: + return "trusted" + return "unverified" + + def get_safety_number(self, peer_user_id: str) -> str | None: + """Get formatted safety number for a peer (requires both identity keys).""" + if not self.identity_public or not self.session: + return None + my_uid = self.session.get("user_id", "") + my_ik_bytes = serialize_ed25519_public(self.identity_public) + cached = self._user_cache.get(peer_user_id) + if not cached or not cached.get("identity_key_bytes"): + return None + return compute_safety_number(my_uid, my_ik_bytes, + peer_user_id, cached["identity_key_bytes"]) + + def get_my_fingerprint(self) -> str | None: + """Get formatted fingerprint for own identity key.""" + if not self.identity_public or not self.session: + return None + my_uid = self.session.get("user_id", "") + my_ik_bytes = serialize_ed25519_public(self.identity_public) + fp = compute_fingerprint(my_uid, my_ik_bytes) + return format_fingerprint(fp) + + def get_peer_fingerprint(self, peer_user_id: str) -> str | None: + """Get formatted fingerprint for a peer's identity key.""" + cached = self._user_cache.get(peer_user_id) + if not cached or not cached.get("identity_key_bytes"): + return None + fp = compute_fingerprint(peer_user_id, cached["identity_key_bytes"]) + return format_fingerprint(fp) + + def get_verification_qr_data(self) -> bytes | None: + """Get QR code payload bytes for own identity (for peer to scan).""" + if not self.identity_public or not self.session: + return None + my_uid = self.session.get("user_id", "") + my_ik_bytes = serialize_ed25519_public(self.identity_public) + return encode_verification_qr(my_uid, my_ik_bytes) + + def verify_qr_code(self, qr_data: bytes) -> tuple[bool, str, str]: + """Verify a scanned QR code against known identity keys. + + Returns (success, user_id, message). + """ + try: + user_id, ik_bytes = decode_verification_qr(qr_data) + except ValueError as e: + return False, "", f"Invalid QR code: {e}" + cached = self._user_cache.get(user_id) + if not cached: + return False, user_id, "Unknown user — not in your contacts." + if not cached.get("identity_key_bytes"): + return False, user_id, "No identity key on record for this user." + if cached["identity_key_bytes"] != ik_bytes: + return False, user_id, "Identity key MISMATCH — verification failed!" + # Keys match — mark as verified + self.verify_contact(user_id, ik_bytes, method="qr_code") + username = cached.get("username", user_id[:8]) + return True, user_id, f"Verified {username} via QR code." + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + async def register(self, username: str, password: str, email: str) -> tuple[bool, str]: + """Register user. Generates RSA + Ed25519 in memory (saved to disk + only after server confirms registration via confirm_registration).""" + self.username = username + self.email = email + pwd_bytes = bytearray(password.encode("utf-8")) if password else None + + try: + pwd = bytes(pwd_bytes) if pwd_bytes else None + # Try loading existing keys (previous successful registration) + priv, pub, err = load_keys(email, password=pwd) + if priv is None: + priv, pub = generate_rsa_keypair() + self.private_key = priv + self.public_key = pub + + try: + ed_priv, ed_pub = _load_identity_keys(email, password=pwd) + except Exception: + ed_priv, ed_pub = None, None + if ed_priv is None: + ed_priv, ed_pub = generate_identity_keypair() + self.identity_private = ed_priv + self.identity_public = ed_pub + self._cache_key = derive_self_encryption_key(ed_priv) + self._local_key = derive_local_storage_key(ed_priv) + + # Store password for saving keys after confirm + self._reg_password = pwd + finally: + if pwd_bytes: + pwd_bytes[:] = b'\x00' * len(pwd_bytes) + + pub_pem = serialize_public_key(pub).decode("utf-8") + ik_b64 = encode_binary(serialize_ed25519_public(ed_pub)) + + extra_fields: dict = {} + start = await self.send_and_recv( + "register", + username=username, + public_key=pub_pem, + email=email, + identity_key=ik_b64, + ) + # Handle PoW challenge (server under pressure) + if start.get("status") == "pow_required": + challenge = start["data"]["challenge"] + mac = start["data"]["mac"] + difficulty = start["data"]["difficulty"] + logger.info("Server requires proof-of-work (difficulty %d), solving...", difficulty) + nonce = _solve_pow(challenge, difficulty) + extra_fields = {"pow_challenge": challenge, "pow_mac": mac, "pow_nonce": nonce} + start = await self.send_and_recv( + "register", + username=username, + public_key=pub_pem, + email=email, + identity_key=ik_b64, + **extra_fields, + ) + if start["status"] != "ok": + self._reg_password = None + return False, start["data"]["message"] + code = start["data"].get("code") + if code: + return True, code + return True, start["data"].get("message", "Check your email for the code.") + + async def confirm_registration(self, email: str, username: str, code: str) -> tuple[bool, str]: + confirm = await self.send_and_recv("register_confirm", email=email, code=code) + if confirm["status"] == "ok": + # Registration confirmed — NOW save keys to disk + pwd = getattr(self, "_reg_password", None) + save_keys(email, self.private_key, self.public_key, password=pwd) + _save_identity_keys(email, self.identity_private, self.identity_public, password=pwd) + self._reg_password = None + self._load_verification_stores() + return True, f"Registered as '{username}' (ID: {confirm['data']['user_id']})" + return False, confirm["data"]["message"] + + async def _generate_and_upload_prekeys(self, keep_spk: bool = False): + """Generate SPK + OPKs and upload to server. + + If keep_spk=True, re-sign the existing SPK instead of generating a new + one. This is used after device pairing so both devices share the same + SPK and either can respond to X3DH. + """ + if not self.identity_private: + return + + if keep_spk and self.spk_private and self.spk_id: + # Re-sign existing SPK (both devices share the identity key) + spk_pub_bytes = serialize_x25519_public(self.spk_private.public_key()) + spk_sig = ed25519_sign(self.identity_private, spk_pub_bytes) + spk_data = { + "id": self.spk_id, + "public_key": encode_binary(spk_pub_bytes), + "signature": encode_binary(spk_sig), + } + else: + # Save current SPK as previous for grace period (M4: in-flight X3DH) + if self.spk_private and self.spk_id: + self._prev_spk_private = self.spk_private + self._prev_spk_id = self.spk_id + _save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key) + # Generate a brand-new signed prekey + spk = generate_signed_prekey(self.identity_private) + self.spk_private = spk["private"] + self.spk_id = spk["id"] + _save_spk(self.email, spk["private"], spk["id"], self._local_key) + spk_data = { + "id": spk["id"], + "public_key": encode_binary(serialize_x25519_public(spk["public"])), + "signature": encode_binary(spk["signature"]), + } + + # Generate one-time prekeys + opks = generate_one_time_prekeys(OPK_BATCH_SIZE) + for opk in opks: + self.opk_privates[opk["id"]] = opk["private"] + _save_opk_private(self.email, opk["id"], opk["private"], self._local_key) + + # Upload to server + otp_data = [ + {"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))} + for opk in opks + ] + resp = await self.send_and_recv( + "upload_prekeys", + signed_prekey=spk_data, + one_time_prekeys=otp_data, + ) + if resp.get("status") != "ok": + self._logger.warning("upload_prekeys failed: %s (will retry on login)", + resp.get("data", {}).get("message", "unknown")) + + async def _ensure_prekeys(self): + """Check OPK count and SPK age, replenish/rotate if needed. + + Uses single-roundtrip `ensure_prekeys` handler when available, + falls back to legacy two-step flow (get_prekey_count + upload_prekeys). + """ + resp = await self.send_and_recv("get_prekey_count") + if resp["status"] != "ok": + return + count = resp["data"].get("count", 0) + spk_created_at = resp["data"].get("spk_created_at", "") + + need_new_spk = False + if spk_created_at: + try: + created = datetime.fromisoformat(spk_created_at) + if created.tzinfo is None: + created = created.replace(tzinfo=timezone.utc) + age_days = (datetime.now(timezone.utc) - created).days + if age_days >= SPK_ROTATION_DAYS: + need_new_spk = True + self._logger.info("SPK is %d days old, rotating...", age_days) + except Exception: + need_new_spk = True + else: + # No SPK on server for this device — must upload one + need_new_spk = True + self._logger.info("No SPK on server for this device, uploading...") + + if count < OPK_REPLENISH_THRESHOLD or need_new_spk: + if count >= OPK_REPLENISH_THRESHOLD: + self._logger.info("SPK rotation triggered (OPK count OK: %d)", count) + else: + self._logger.info("OPK count low (%d), replenishing...", count) + await self._generate_and_upload_prekeys_batch(need_new_spk) + + async def _generate_and_upload_prekeys_batch(self, need_new_spk: bool = False): + """Generate and upload prekeys in a single round-trip via ensure_prekeys.""" + if not self.identity_private: + return + + kwargs: dict = {} + + # SPK + if need_new_spk: + if self.spk_private and self.spk_id: + self._prev_spk_private = self.spk_private + self._prev_spk_id = self.spk_id + _save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key) + spk = generate_signed_prekey(self.identity_private) + self.spk_private = spk["private"] + self.spk_id = spk["id"] + _save_spk(self.email, spk["private"], spk["id"], self._local_key) + kwargs["signed_prekey"] = { + "id": spk["id"], + "public_key": encode_binary(serialize_x25519_public(spk["public"])), + "signature": encode_binary(spk["signature"]), + } + + # OPKs + opks = generate_one_time_prekeys(OPK_BATCH_SIZE) + for opk in opks: + self.opk_privates[opk["id"]] = opk["private"] + _save_opk_private(self.email, opk["id"], opk["private"], self._local_key) + kwargs["one_time_prekeys"] = [ + {"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))} + for opk in opks + ] + + resp = await self.send_and_recv("ensure_prekeys", **kwargs) + if resp["status"] == "ok": + data = resp.get("data", {}) + self._logger.info("ensure_prekeys: count=%d, spk_uploaded=%s, otps_uploaded=%d", + data.get("count", 0), data.get("uploaded_spk", False), + data.get("uploaded_otps", 0)) + + # ------------------------------------------------------------------ + # Login + # ------------------------------------------------------------------ + + async def login(self, email: str, password: str) -> tuple[bool, str]: + """Login user. Returns (success, message).""" + self.email = email + + # Brute-force lockout check + remaining = _check_lockout(email) + if remaining > 0: + return False, f"Too many failed attempts. Try again in {remaining:.0f}s." + + pwd_bytes = bytearray(password.encode("utf-8")) if password else None + + try: + # Load RSA keys + priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) + if priv is None: + if err and "password" in err.lower(): + _record_failed_attempt(email) + return False, err or "No local keys found. Register first." + self.private_key = priv + self.public_key = pub + + # Load identity keys + ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) + finally: + if pwd_bytes: + pwd_bytes[:] = b'\x00' * len(pwd_bytes) + + if ed_priv is not None: + self.identity_private = ed_priv + self.identity_public = ed_pub + self._cache_key = derive_self_encryption_key(ed_priv) + self._local_key = derive_local_storage_key(ed_priv) + self._load_verification_stores() + + # Load SPK + spk_priv, spk_id = _load_spk(email, self._local_key) + if spk_priv: + self.spk_private = spk_priv + self.spk_id = spk_id + + # Load previous SPK for grace period (M4) + prev_spk_priv, prev_spk_id = _load_prev_spk(email, self._local_key) + if prev_spk_priv: + self._prev_spk_private = prev_spk_priv + self._prev_spk_id = prev_spk_id + + # Load device_id from disk + self.device_id = _load_device_id(email) + + # RSA challenge-response login + start = await self.send_and_recv("login_start", email=email) + if start["status"] != "ok": + return False, start["data"]["message"] + + challenge = decode_binary(start["data"]["challenge"]) + signature = rsa_sign(self.private_key, challenge) + login_kwargs = {"email": email, "signature": encode_binary(signature), + "client_version": VERSION} + if self.device_id: + login_kwargs["device_id"] = self.device_id + finish = await self.send_and_recv("login_finish", **login_kwargs) + if finish["status"] == "ok": + self.session = finish["data"] + self.username = self.session.get("username", "") + # Store device_id from server + self.device_id = finish["data"].get("device_id") + if self.device_id: + _save_device_id(email, self.device_id) + # Replenish prekeys in background — after pairing, the new device + # has no local OPK private keys so we must generate fresh ones + # (server-side OPKs have no matching private keys on this device). + # Use keep_spk=True to preserve the shared SPK so both devices + # can respond to X3DH. + opk_dir = get_key_dir(self.email) / "opk_private" + has_local_opks = opk_dir.exists() and any(opk_dir.iterdir()) + if has_local_opks: + asyncio.create_task(self._ensure_prekeys()) + else: + self._logger.info("No local OPKs (likely new device). Generating fresh OPKs, keeping SPK.") + asyncio.create_task(self._generate_and_upload_prekeys(keep_spk=True)) + _clear_lockout(email) + return True, f"Logged in as '{self.username}' (ID: {self.session['user_id']})" + return False, finish["data"]["message"] + + # ------------------------------------------------------------------ + # Pairing (device pairing — transfers RSA + identity keys) + # ------------------------------------------------------------------ + + async def pairing_start(self, email: str) -> tuple[bool, str]: + """Start device pairing. Returns (success, code/message).""" + self._logger.info("pairing_start via %s for %s", self.server_endpoint(), email.strip().lower()) + temp_priv, temp_pub = generate_x25519_keypair() + self._pairing_temp_private_key = temp_priv + temp_pub_raw = serialize_x25519_public(temp_pub) + self._pairing_fingerprint = compute_pairing_fingerprint(temp_pub_raw) + resp = await self.send_and_recv( + "pairing_start", + email=email, + temp_public_key=encode_binary(temp_pub_raw), + temp_key_type="x25519", + ) + if resp["status"] == "ok": + self._pairing_code = resp["data"]["code"] + self._pairing_poll_token = resp["data"].get("poll_token", "") + return True, resp["data"]["code"] + self._pairing_fingerprint = "" + self._pairing_code = "" + return False, resp["data"]["message"] + + async def pairing_wait(self, code: str, email: str, password: str, timeout: int = 300) -> tuple[bool, str]: + """Wait for pairing payload and import keys. Returns (success, message).""" + if not self._pairing_temp_private_key: + return False, "Pairing not started." + from crypto_utils import aes_decrypt as _aes_decrypt + poll_token = getattr(self, "_pairing_poll_token", "") + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + resp = await self.send_and_recv("pairing_poll", code=code, poll_token=poll_token) + if resp["status"] != "ok": + self._pairing_fingerprint = "" + self._pairing_code = "" + return False, resp["data"]["message"] + if not resp["data"].get("ready"): + await asyncio.sleep(2.0) + continue + payload = resp["data"]["payload"] + try: + sender_pub_raw = decode_binary(payload["sender_public_key"]) + sender_pub = load_x25519_public(sender_pub_raw) + my_pub_raw = serialize_x25519_public(self._pairing_temp_private_key.public_key()) + shared_secret = x25519_dh(self._pairing_temp_private_key, sender_pub) + aes_key = derive_pairing_shared_key(shared_secret, my_pub_raw, sender_pub_raw) + nonce = decode_binary(payload["iv"]) + ct = decode_binary(payload["ciphertext"]) + tag = decode_binary(payload["tag"]) + keys_json = _aes_decrypt(aes_key, nonce, ct, tag) + keys_data = json.loads(keys_json) + + pwd_bytes = bytearray(password.encode("utf-8")) if password else None + + try: + # Import RSA key + rsa_priv = load_private_key(keys_data["rsa_private"].encode(), password=None) + rsa_pub = rsa_priv.public_key() + save_keys(email, rsa_priv, rsa_pub, password=bytes(pwd_bytes) if pwd_bytes else None) + + # Import identity keys + ed_priv = load_ed25519_private(bytes.fromhex(keys_data["identity_private"])) + ed_pub = ed_priv.public_key() + _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None) + finally: + if pwd_bytes: + pwd_bytes[:] = b'\x00' * len(pwd_bytes) + + self.email = email + self.private_key = rsa_priv + self.public_key = rsa_pub + self.identity_private = ed_priv + self.identity_public = ed_pub + self._cache_key = derive_self_encryption_key(ed_priv) + self._local_key = derive_local_storage_key(ed_priv) + self._load_verification_stores() + self._pairing_temp_private_key = None + self._pairing_fingerprint = "" + self._pairing_code = "" + + # Multi-device: new device generates own SPK + OPKs on first + # login. No session/sender key import needed — each device + # has independent Double Ratchet sessions. + + return True, "Pairing complete." + except Exception as e: + self._pairing_fingerprint = "" + self._pairing_code = "" + return False, f"Failed to import keys: {e}" + self._pairing_fingerprint = "" + self._pairing_code = "" + return False, "Pairing timed out." + + async def authorize_device(self, code: str, expected_fingerprint: str) -> tuple[bool, str]: + """Authorize a new device by sending all keys to it.""" + if not self.private_key or not self.identity_private: + return False, "Not logged in." + expected_digits = normalize_pairing_fingerprint(expected_fingerprint) + if len(expected_digits) != 30: + return False, "Pairing fingerprint must contain 30 digits." + claim = await self.send_and_recv("pairing_claim", code=code) + if claim["status"] != "ok": + return False, claim["data"]["message"] + + if claim["data"].get("temp_key_type") != "x25519": + return False, "Unsupported pairing key type. Update both devices and try again." + + temp_pub_raw = decode_binary(claim["data"]["temp_public_key"]) + actual_fp = compute_pairing_fingerprint(temp_pub_raw) + if normalize_pairing_fingerprint(actual_fp) != expected_digits: + self._logger.warning("Pairing fingerprint mismatch for code %s", code[:8]) + return False, ( + "Pairing fingerprint mismatch. Verify the new device fingerprint and try again.\n" + f"Expected: {expected_fingerprint}\n" + f"Received: {actual_fp}" + ) + temp_pub = load_x25519_public(temp_pub_raw) + + # Build keys payload — only RSA + identity key. + # Multi-device: new device generates own SPK + OPKs, creates independent + # sessions. No session/sender key transfer needed. + keys_data = { + "rsa_private": serialize_private_key(self.private_key, password=None).decode(), + "identity_private": serialize_ed25519_private_raw(self.identity_private).hex(), + } + + # Send keys to the new device first. Re-encrypting history can take a + # while on large accounts; doing it before pairing_send can make a valid + # code expire during authorization. + plaintext = json.dumps(keys_data).encode() + sender_priv, sender_pub = generate_x25519_keypair() + sender_pub_raw = serialize_x25519_public(sender_pub) + shared_secret = x25519_dh(sender_priv, temp_pub) + pairing_key = derive_pairing_shared_key(shared_secret, sender_pub_raw, temp_pub_raw) + _, nonce, ct, tag = aes_encrypt(plaintext, key=pairing_key) + payload = { + "sender_public_key": encode_binary(sender_pub_raw), + "iv": encode_binary(nonce), + "ciphertext": encode_binary(ct), + "tag": encode_binary(tag), + } + resp = await self.send_and_recv("pairing_send", code=code, payload=payload) + if resp["status"] == "ok": + async def _reencrypt_after_pairing(): + try: + delay = random.uniform(*PAIRING_REENCRYPT_INITIAL_DELAY_RANGE) + self._logger.info("Delaying post-pairing history resync by %.1fs", delay) + await asyncio.sleep(delay) + await self.reencrypt_history() + except Exception as e: + self._logger.warning("Post-pairing re-encryption failed: %s", e) + + asyncio.create_task(_reencrypt_after_pairing()) + return True, "Device authorized." + return False, resp["data"]["message"] + + # ------------------------------------------------------------------ + # Password change (local key re-encryption only) + # ------------------------------------------------------------------ + + def change_password(self, old_password: str, new_password: str) -> tuple[bool, str]: + """Change password for local key encryption (RSA + identity key). + + Returns (success, message). + """ + if not self.email: + return False, "Not logged in." + + old_pwd = bytearray(old_password.encode("utf-8")) + new_pwd = bytearray(new_password.encode("utf-8")) + try: + # 1. Verify old password by loading keys + priv, pub, err = load_keys(self.email, password=bytes(old_pwd)) + if priv is None: + return False, "Wrong current password." + + ed_priv, ed_pub = _load_identity_keys(self.email, password=bytes(old_pwd)) + if ed_priv is None: + return False, "Failed to load identity key." + + # 2. Re-save with new password + save_keys(self.email, priv, pub, password=bytes(new_pwd)) + _save_identity_keys(self.email, ed_priv, ed_pub, password=bytes(new_pwd)) + + return True, "Password changed successfully." + finally: + old_pwd[:] = b'\x00' * len(old_pwd) + new_pwd[:] = b'\x00' * len(new_pwd) + + async def change_username(self, new_username: str) -> tuple[bool, str]: + """Change display name on server.""" + if not self.session: + return False, "Not logged in." + new_username = new_username.strip() + if not new_username or len(new_username) > 100: + return False, "Username must be 1-100 characters." + resp = await self.send_and_recv("change_username", username=new_username) + if resp["status"] == "ok": + self.username = resp["data"]["username"] + if self.session: + self.session["username"] = self.username + return True, "Username changed." + return False, resp["data"].get("message", "Unknown error") + + # ------------------------------------------------------------------ + # Key rotation (RSA login key only) + # ------------------------------------------------------------------ + + async def rotate_keys(self, username: str, password: str) -> tuple[bool, str]: + """Rotate RSA keypair to revoke other devices.""" + if not self.session or self.session.get("username") != username: + return False, "Not logged in." + pwd_bytes = password.encode("utf-8") if password else None + priv, pub = generate_rsa_keypair() + save_keys(self.email, priv, pub, password=pwd_bytes) + self.private_key = priv + self.public_key = pub + pub_pem = serialize_public_key(pub).decode("utf-8") + resp = await self.send_and_recv("rotate_keys", public_key=pub_pem) + if resp["status"] == "ok": + return True, "RSA login keys rotated." + return False, resp["data"]["message"] + + # ------------------------------------------------------------------ + # Session management (X3DH + Double Ratchet) + # ------------------------------------------------------------------ + + async def _get_device_bundles(self, peer_user_id: str) -> list[dict]: + """Get per-device key bundles for a peer. Caches for 5 minutes.""" + import time + cached = self._device_bundle_cache.get(peer_user_id) + if cached: + ts, bundles = cached + if time.time() - ts < 300: + return bundles + + resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id) + if resp["status"] != "ok": + raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}") + + data = resp["data"] + ik_b64 = data.get("identity_key", "") + + device_bundles = data.get("device_bundles") + if device_bundles: + # Attach identity_key to each bundle + for b in device_bundles: + b["identity_key"] = ik_b64 + else: + # Old server: wrap flat response as single-entry list + device_bundles = [{ + "device_id": None, + "identity_key": ik_b64, + "signed_prekey_id": data.get("signed_prekey_id", ""), + "signed_prekey": data.get("signed_prekey", ""), + "spk_signature": data.get("spk_signature", ""), + "one_time_prekey_id": data.get("one_time_prekey_id"), + "one_time_prekey": data.get("one_time_prekey"), + }] + + self._device_bundle_cache[peer_user_id] = (time.time(), device_bundles) + return device_bundles + + async def _get_or_create_session(self, peer_user_id: str, + peer_device_id: str | None = None, + bundle: dict | None = None) -> DoubleRatchet: + """Load existing session or create one via X3DH. + + If peer_device_id is set, sessions are keyed by "user_id:device_id". + If bundle is provided, it's used instead of fetching from server. + """ + session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id + + # Check in-memory cache + if session_key in self.sessions: + return self.sessions[session_key] + + # Check on disk + ratchet = _load_session(self.email, peer_user_id, self._local_key, + peer_device_id=peer_device_id) + if ratchet: + self.sessions[session_key] = ratchet + return ratchet + + # Create new session via X3DH + if not bundle: + resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id) + if resp["status"] != "ok": + raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}") + bundle = resp["data"] + + ik_remote_bytes = decode_binary(bundle["identity_key"]) + ik_remote = load_ed25519_public(ik_remote_bytes) + + # TOFU: verify identity key before using it in X3DH + ik_status = self.check_identity_key(peer_user_id, ik_remote_bytes) + if ik_status in ("changed", "changed_verified"): + raise IdentityKeyChanged(peer_user_id, ik_remote_bytes, ik_status) + + spk_remote = load_x25519_public(decode_binary(bundle["signed_prekey"])) + spk_sig = decode_binary(bundle["spk_signature"]) + + opk_remote = None + opk_id = bundle.get("one_time_prekey_id") + if bundle.get("one_time_prekey"): + opk_remote = load_x25519_public(decode_binary(bundle["one_time_prekey"])) + + # Perform X3DH + shared_secret, ek_priv, ek_pub = x3dh_initiate( + self.identity_private, + ik_remote, + spk_remote, + spk_sig, + opk_remote, + ) + + # Initialize Double Ratchet as Alice + ratchet = DoubleRatchet.init_alice(shared_secret, spk_remote) + self.sessions[session_key] = ratchet + _save_session(self.email, peer_user_id, ratchet, self._local_key, + peer_device_id=peer_device_id) + + # Build X3DH header for first message + x3dh_header = { + "ik": encode_binary(serialize_ed25519_public(self.identity_public)), + "ek": encode_binary(serialize_x25519_public(ek_pub)), + } + if opk_id: + x3dh_header["opk_id"] = opk_id + + # Cache the x3dh header for the next send_message call + ratchet._x3dh_header = x3dh_header + + # Cache remote user info + self._user_cache[peer_user_id] = { + "user_id": peer_user_id, + "identity_key": ik_remote, + "identity_key_bytes": ik_remote_bytes, + "identity_key_status": ik_status, + } + + return ratchet + + def _process_x3dh_header(self, sender_id: str, x3dh_header: dict, + sender_device_id: str | None = None, + spk_override=None) -> DoubleRatchet: + """Process an incoming X3DH header to establish session as Bob. + + Args: + spk_override: If provided, use this SPK private key instead of self.spk_private. + Used for grace period fallback (M4). + """ + ik_remote_bytes = decode_binary(x3dh_header["ik"]) + ik_remote = load_ed25519_public(ik_remote_bytes) + + # TOFU: verify identity key before using it in X3DH + ik_status = self.check_identity_key(sender_id, ik_remote_bytes) + if ik_status in ("changed", "changed_verified"): + raise IdentityKeyChanged(sender_id, ik_remote_bytes, ik_status) + + ek_remote = load_x25519_public(decode_binary(x3dh_header["ek"])) + + opk_id = x3dh_header.get("opk_id") + opk_priv = None + if opk_id: + opk_priv = _load_opk_private(self.email, opk_id, self._local_key) + if opk_priv: + _delete_opk_private(self.email, opk_id) + + spk_priv = spk_override if spk_override else self.spk_private + + shared_secret = x3dh_respond( + self.identity_private, + spk_priv, + ik_remote, + ek_remote, + opk_priv, + ) + + spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None + ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub)) + + session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id + self.sessions[session_key] = ratchet + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + + self._user_cache[sender_id] = { + "user_id": sender_id, + "identity_key": ik_remote, + "identity_key_bytes": ik_remote_bytes, + "identity_key_status": ik_status, + } + + return ratchet + + # ------------------------------------------------------------------ + # Conversations + # ------------------------------------------------------------------ + + async def create_conversation(self, member_emails: list[str], name: str | None = None) -> tuple[str | None, str]: + kwargs = {"members": member_emails} + if name: + kwargs["name"] = name + resp = await self.send_and_recv("create_conversation", **kwargs) + if resp["status"] == "ok": + return resp["data"]["conversation_id"], "OK" + return None, resp["data"]["message"] + + async def remove_member(self, conv_id: str, user_id: str) -> tuple[bool, str]: + resp = await self.send_and_recv("remove_member", conversation_id=conv_id, user_id=user_id) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def leave_group(self, conv_id: str) -> tuple[bool, str]: + """Leave a group conversation.""" + resp = await self.send_and_recv("leave_group", conversation_id=conv_id) + if resp["status"] == "ok": + # Clean up local sender key state for this group + self.sender_key_states.pop(conv_id, None) + # Remove received sender keys for this conversation + to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")] + for k in to_remove: + self.recv_sender_keys.pop(k, None) + return True, "OK" + return False, resp["data"]["message"] + + async def rename_conversation(self, conv_id: str, name: str) -> tuple[bool, str]: + """Rename a group conversation (creator only).""" + resp = await self.send_and_recv("rename_conversation", conversation_id=conv_id, name=name) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def delete_conversation(self, conv_id: str) -> tuple[bool, str]: + """Delete a conversation (leave + server cleans up if empty).""" + resp = await self.send_and_recv("delete_conversation", conversation_id=conv_id) + if resp["status"] == "ok": + # Clean up local sender key state + self.sender_key_states.pop(conv_id, None) + to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")] + for k in to_remove: + self.recv_sender_keys.pop(k, None) + return True, "OK" + return False, resp["data"]["message"] + + async def add_member(self, conv_id: str, email: str) -> tuple[bool, str]: + resp = await self.send_and_recv("add_member", conversation_id=conv_id, email=email) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def accept_invitation(self, conv_id: str) -> tuple[bool, str]: + """Accept a group invitation.""" + resp = await self.send_and_recv("accept_invitation", conversation_id=conv_id) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def decline_invitation(self, conv_id: str) -> tuple[bool, str]: + """Decline a group invitation.""" + resp = await self.send_and_recv("decline_invitation", conversation_id=conv_id) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def list_invitations(self) -> list[dict]: + """List pending group invitations.""" + resp = await self.send_and_recv("list_invitations") + if resp["status"] == "ok": + return resp["data"]["invitations"] + return [] + + async def list_conversations(self) -> list[dict]: + resp = await self.send_and_recv("list_conversations") + if resp["status"] == "ok": + return resp["data"]["conversations"] + return [] + + async def find_or_create_conversation(self, email: str) -> tuple[str | None, str]: + resp = await self.send_and_recv("find_conversation", email=email) + if resp["status"] != "ok": + return None, resp["data"]["message"] + conv_id = resp["data"]["conversation_id"] + if conv_id: + return conv_id, "OK" + return await self.create_conversation([email]) + + # ------------------------------------------------------------------ + # Send message + # ------------------------------------------------------------------ + + def _cancel_typing_timer(self, conv_id: str): + task = self._typing_stop_tasks.pop(conv_id, None) + if task and not task.done(): + task.cancel() + + async def _typing_stop_after_delay(self, conv_id: str, delay: float): + try: + await asyncio.sleep(delay) + await self.typing_stop(conv_id) + except asyncio.CancelledError: + return + + async def typing_start(self, conv_id: str): + """Debounced typing_start with 3s inactivity timeout.""" + if not conv_id or not self.session: + return + + now = time.monotonic() + was_active = self._typing_active.get(conv_id, False) + last_sent = self._typing_last_sent.get(conv_id, 0.0) + should_send = (not was_active) or (now - last_sent >= 1.0) + + self._typing_active[conv_id] = True + self._cancel_typing_timer(conv_id) + self._typing_stop_tasks[conv_id] = asyncio.create_task( + self._typing_stop_after_delay(conv_id, 3.0) + ) + + if not should_send: + return + self._typing_last_sent[conv_id] = now + try: + await self.send_and_recv("typing_start", timeout=5.0, conversation_id=conv_id) + except Exception: + pass + + async def typing_stop(self, conv_id: str, force: bool = False): + if not conv_id or not self.session: + return + self._cancel_typing_timer(conv_id) + + was_active = self._typing_active.get(conv_id, False) + self._typing_active[conv_id] = False + if not was_active and not force: + return + try: + await self.send_and_recv("typing_stop", timeout=5.0, conversation_id=conv_id) + except Exception: + pass + + def _is_group(self, members: list[dict]) -> bool: + return len(members) > 2 + + async def send_message(self, conv_id: str, text: str, members: list[dict], + reply_to: str | None = None) -> tuple[bool, str | dict]: + """Encrypt and send a message. DM: per-recipient Double Ratchet. Group: Sender Keys. + + Returns (True, msg_dict) on success or (False, error_string) on failure. + msg_dict contains the full decrypted payload ready for display. + """ + my_user_id = self.session["user_id"] + await self.typing_stop(conv_id, force=True) + + # Build plaintext payload + payload = { + "sender": self.username, + "text": text, + "reply_to": reply_to, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + if self._is_group(members): + return await self._send_group_message(conv_id, plaintext, members, payload) + else: + return await self._send_dm(conv_id, plaintext, members, payload) + + async def _send_dm(self, conv_id: str, plaintext: bytes, members: list[dict], + payload: dict | None = None) -> tuple[bool, str | dict]: + """Encrypt DM with per-device Double Ratchet.""" + my_user_id = self.session["user_id"] + recipients = [] + first_ratchet_header = None + + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + # Get all device bundles for this user + try: + device_bundles = await self._get_device_bundles(uid) + self._logger.debug("Got %d device bundles for %s", len(device_bundles), uid) + except Exception as e: + self._logger.warning("Failed to get device bundles for %s: %s", uid, e) + device_bundles = [] + + if not device_bundles: + # Fallback: try single session (legacy peer) + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_hdr = getattr(ratchet, "_x3dh_header", None) + if x3dh_hdr: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_hdr: + entry["x3dh_header"] = x3dh_hdr + recipients.append(entry) + if first_ratchet_header is None: + first_ratchet_header = result["header"] + _save_session(self.email, uid, ratchet, self._local_key) + continue + + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_hdr = getattr(ratchet, "_x3dh_header", None) + if x3dh_hdr: + delattr(ratchet, "_x3dh_header") + + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_hdr: + entry["x3dh_header"] = x3dh_hdr + recipients.append(entry) + + if first_ratchet_header is None: + first_ratchet_header = result["header"] + + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + # Encrypt self-copy with static key derived from identity (not ratchet) + # Uses SELF_DEVICE_ID so all own devices can read it + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + if not recipients: + return False, "No recipients." + + kwargs = { + "conversation_id": conv_id, + "ratchet_header": first_ratchet_header, + "recipients": recipients, + } + + resp = await self.send_and_recv("send_message", **kwargs) + if resp["status"] == "ok": + msg_data = resp.get("data", {}) + if payload is not None: + result = { + **payload, + "message_id": msg_data.get("message_id", ""), + "created_at": msg_data.get("created_at", ""), + "sender_id": self.session["user_id"], + "conversation_id": conv_id, + "read_by": [], + } + _save_message_to_cache(self.email, conv_id, result["message_id"], result, self._cache_key) + return True, result + return True, "Message sent." + return False, resp["data"]["message"] + + async def _send_group_message(self, conv_id: str, plaintext: bytes, + members: list[dict], + payload: dict | None = None) -> tuple[bool, str | dict]: + """Encrypt group message with Sender Keys.""" + my_user_id = self.session["user_id"] + + # Get or create sender key for this group + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if not sk: + sk = SenderKeyState() + self.sender_key_states[conv_id] = sk + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + + self.sender_key_states[conv_id] = sk + + await self._catch_up_sender_key_distribution(conv_id, members, sk) + + # Encrypt with sender key + result = sk.encrypt(plaintext) + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + + # Build per-recipient entries (same ciphertext for all except self) + recipients = [] + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + recipients.append({ + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + }) + + # Self-encrypted copy (so other devices + history fetch can decrypt) + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + ratchet_header = {"dh_pub": "00" * 32, "n": 0, "pn": 0} # Dummy for groups + + kwargs = { + "conversation_id": conv_id, + "ratchet_header": ratchet_header, + "recipients": recipients, + "sender_chain_id": encode_binary(bytes.fromhex(result["chain_id"])), + "sender_chain_n": result["n"], + } + + resp = await self.send_and_recv("send_message", **kwargs) + if resp["status"] == "ok": + msg_data = resp.get("data", {}) + if payload is not None: + result_msg = { + **payload, + "message_id": msg_data.get("message_id", ""), + "created_at": msg_data.get("created_at", ""), + "sender_id": self.session["user_id"], + "conversation_id": conv_id, + "read_by": [], + } + _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key) + return True, result_msg + return True, "Message sent." + return False, resp["data"]["message"] + + async def _catch_up_sender_key_distribution(self, conv_id: str, members: list[dict], + sk: SenderKeyState): + """Ensure all current members have our existing sender key.""" + my_user_id = self.session["user_id"] + current_member_ids = { + uid for uid in (member.get("user_id") for member in members) + if uid and uid != my_user_id + } + if not current_member_ids: + return + + distributed_to = _load_sender_key_recipients(self.email, conv_id, self._local_key) + missing_ids = sorted(current_member_ids - distributed_to) + if not missing_ids: + return + + distributed_now = await self._distribute_sender_key( + conv_id, + [{"user_id": uid} for uid in missing_ids], + sk, + ) + if distributed_now: + distributed_to.update(distributed_now) + _save_sender_key_recipients(self.email, conv_id, distributed_to, self._local_key) + + async def _distribute_sender_key(self, conv_id: str, members: list[dict], + sk: SenderKeyState) -> set[str]: + """Send own sender key to all group members via pairwise Double Ratchet (per-device).""" + my_user_id = self.session["user_id"] + distributed_to: set[str] = set() + exported_key = sk.export_key() + + # Build a special "sender_key_distribution" payload + payload = { + "sender": self.username, + "text": "", + "reply_to": None, + "timestamp": datetime.now(timezone.utc).isoformat(), + "_sender_key": { + "conv_id": conv_id, + "key": encode_binary(exported_key), + "sender_device_id": self.device_id, + }, + } + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + # Send as DM to each member's devices (per-device encryption) + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + try: + # Get all device bundles for this user + try: + device_bundles = await self._get_device_bundles(uid) + except Exception: + device_bundles = [] + + if not device_bundles: + # Fallback: legacy single-device + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_header = getattr(ratchet, "_x3dh_header", None) + if x3dh_header: + delattr(ratchet, "_x3dh_header") + + recipient_entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_header: + recipient_entry["x3dh_header"] = x3dh_header + kwargs = { + "conversation_id": conv_id, + "ratchet_header": result["header"], + "recipients": [recipient_entry], + } + await self.send_and_recv("send_message", **kwargs) + _save_session(self.email, uid, ratchet, self._local_key) + distributed_to.add(uid) + else: + # Per-device encryption + recipients = [] + first_rh = None + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_header = getattr(ratchet, "_x3dh_header", None) + if x3dh_header: + delattr(ratchet, "_x3dh_header") + + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_header: + entry["x3dh_header"] = x3dh_header + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + kwargs = { + "conversation_id": conv_id, + "ratchet_header": first_rh, + "recipients": recipients, + } + await self.send_and_recv("send_message", **kwargs) + distributed_to.add(uid) + except Exception as e: + self._logger.warning("Failed to distribute sender key to %s: %s", uid, e) + return distributed_to + + async def redistribute_sender_key_to_member(self, conv_id: str, new_user_id: str): + """Redistribute our existing sender key to a newly joined group member. + + Called when we receive a member_added notification — the new member needs our + sender key so they can decrypt future messages we send in the group. + """ + if not self.session: + return + sk = self.sender_key_states.get(conv_id) + if sk is None: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if sk is None: + # We haven't sent anything in this group yet — no key to redistribute + return + try: + distributed = await self._distribute_sender_key(conv_id, [{"user_id": new_user_id}], sk) + if new_user_id in distributed: + recipients = _load_sender_key_recipients(self.email, conv_id, self._local_key) + recipients.add(new_user_id) + _save_sender_key_recipients(self.email, conv_id, recipients, self._local_key) + self._logger.info("Redistributed sender key for conv=%s to new member %s", + conv_id[:8], new_user_id[:8]) + except Exception as e: + self._logger.warning("Failed to redistribute sender key to %s: %s", new_user_id, e) + + # ------------------------------------------------------------------ + # Decrypt messages + # ------------------------------------------------------------------ + + def _decrypt_message(self, msg_data: dict) -> dict: + """Decrypt a single message (DM or group).""" + # Check for self-encrypted marker FIRST — after re-encryption, + # group messages will have {"self": true} ratchet_header but still + # have sender_chain_id at message level. + rh = msg_data.get("ratchet_header", {}) + if isinstance(rh, dict) and rh.get("self"): + return self._decrypt_dm(msg_data) + + if msg_data.get("sender_chain_id"): + return self._decrypt_group(msg_data) + else: + return self._decrypt_dm(msg_data) + + def _decrypt_dm(self, msg_data: dict) -> dict: + """Decrypt DM using Double Ratchet with sender, or static key for self-copies.""" + sender_id = msg_data.get("sender_id", "") + sender_device_id = msg_data.get("sender_device_id") + ratchet_header = msg_data.get("ratchet_header", {}) + ct_b64 = msg_data.get("encrypted_content", "") + nonce_b64 = msg_data.get("nonce", "") + + if not ct_b64 or not nonce_b64: + raise ValueError("Missing ciphertext or nonce") + + ciphertext = decode_binary(ct_b64) + nonce = decode_binary(nonce_b64) + + # Self-encrypted message (own sent message copy) + if isinstance(ratchet_header, dict) and ratchet_header.get("self"): + self_key = derive_self_encryption_key(self.identity_private) + ct = ciphertext[:-16] + tag = ciphertext[-16:] + plaintext = aes_decrypt(self_key, nonce, ct, tag) + else: + x3dh_header = msg_data.get("x3dh_header") + + # Session key: "sender_id:sender_device_id" or just "sender_id" for legacy + session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id + + # Try to load existing session + ratchet = self.sessions.get(session_key) + if not ratchet: + ratchet = _load_session(self.email, sender_id, self._local_key, + peer_device_id=sender_device_id) + if ratchet: + self.sessions[session_key] = ratchet + + if ratchet and not x3dh_header: + # Normal case: existing session, no X3DH header + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + elif x3dh_header: + if ratchet: + # Existing session + X3DH header: sender may have reset. + backup = ratchet.export_state() + try: + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + except Exception: + restored = DoubleRatchet.import_state(backup) + self.sessions[session_key] = restored + _save_session(self.email, sender_id, restored, self._local_key, + peer_device_id=sender_device_id) + ratchet = self._process_x3dh_header(sender_id, x3dh_header, + sender_device_id=sender_device_id) + try: + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + except Exception: + if self._prev_spk_private: + ratchet = self._process_x3dh_header( + sender_id, x3dh_header, + sender_device_id=sender_device_id, + spk_override=self._prev_spk_private) + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + else: + raise + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + else: + ratchet = self._process_x3dh_header(sender_id, x3dh_header, + sender_device_id=sender_device_id) + try: + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + except Exception: + if self._prev_spk_private: + ratchet = self._process_x3dh_header( + sender_id, x3dh_header, + sender_device_id=sender_device_id, + spk_override=self._prev_spk_private) + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + else: + raise + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + else: + raise ValueError(f"No session for sender {sender_id}") + + plaintext = unpad_plaintext(plaintext) + payload = json.loads(plaintext) + + # Handle sender key distribution messages + if "_sender_key" in payload: + sk_data = payload["_sender_key"] + sk_conv_id = sk_data["conv_id"] + sk_key = decode_binary(sk_data["key"]) + sk_sender_device_id = sk_data.get("sender_device_id") + recv_sk = SenderKeyState.from_key(sk_key) + if sk_sender_device_id: + cache_key = f"{sk_conv_id}:{sender_id}:{sk_sender_device_id}" + else: + cache_key = f"{sk_conv_id}:{sender_id}" + self.recv_sender_keys[cache_key] = recv_sk + _save_recv_sender_key(self.email, sk_conv_id, sender_id, recv_sk, self._local_key, + sender_device_id=sk_sender_device_id) + # Return empty — this is a control message, not user-visible + return None + + return payload + + def _decrypt_group(self, msg_data: dict) -> dict: + """Decrypt group message using sender's Sender Key.""" + sender_id = msg_data.get("sender_id", "") + sender_device_id = msg_data.get("sender_device_id") + conv_id = msg_data.get("conversation_id", "") + chain_id_b64 = msg_data.get("sender_chain_id", "") + chain_n = msg_data.get("sender_chain_n", 0) + ct_b64 = msg_data.get("encrypted_content", "") + nonce_b64 = msg_data.get("nonce", "") + + if not ct_b64 or not nonce_b64 or not chain_id_b64: + raise ValueError("Missing group message fields") + + ciphertext = decode_binary(ct_b64) + nonce = decode_binary(nonce_b64) + chain_id = decode_binary(chain_id_b64) + + my_user_id = self.session["user_id"] + + # If we sent this message, use our own sender key + if sender_id == my_user_id: + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if sk: + self.sender_key_states[conv_id] = sk + if not sk: + raise ValueError("Own sender key not found") + # For our own messages, we can't decrypt from sender key (it's already advanced) + # Return a placeholder — the server echoed our ciphertext + raise ValueError("Cannot decrypt own group message from sender key") + + # Use received sender key — try with sender_device_id first, fall back to without + sk = None + if sender_device_id: + cache_key = f"{conv_id}:{sender_id}:{sender_device_id}" + sk = self.recv_sender_keys.get(cache_key) + if not sk: + sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key, + sender_device_id=sender_device_id) + if sk: + self.recv_sender_keys[cache_key] = sk + + if not sk: + # Fallback: try without device_id (legacy or same-device) + cache_key = f"{conv_id}:{sender_id}" + sk = self.recv_sender_keys.get(cache_key) + if not sk: + sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key) + if sk: + self.recv_sender_keys[cache_key] = sk + + if not sk: + raise ValueError(f"No sender key for {sender_id} in conversation {conv_id}") + + plaintext = unpad_plaintext(sk.decrypt(chain_id.hex(), chain_n, ciphertext, nonce)) + _save_recv_sender_key(self.email, conv_id, sender_id, sk, self._local_key, + sender_device_id=sender_device_id) + + return json.loads(plaintext) + + # ------------------------------------------------------------------ + # Get/decrypt messages (batch) + # ------------------------------------------------------------------ + + async def get_messages(self, conv_id: str, limit: int = 50, offset: int = 0) -> list[dict]: + cache = _load_message_cache(self.email, conv_id, self._cache_key) + my_user_id = self.session["user_id"] if self.session else "" + + # Incremental sync: use stored server timestamp from last successful fetch. + after_ts = None + if cache and offset == 0: + after_ts = cache.get("__last_server_ts", {}).get("ts") + + req_params = {"conversation_id": conv_id, "limit": limit, "offset": offset} + if after_ts: + req_params["after_ts"] = after_ts + resp = await self.send_and_recv("get_messages", **req_params) + + if resp["status"] != "ok": + # Offline fallback: return from cache if available + if cache and offset == 0: + return self._build_from_cache(cache) + return [] + + raw_messages = resp["data"]["messages"] + raw_messages.reverse() # Server returns DESC, reverse to ASC + + # Save latest server timestamp for next incremental sync + if raw_messages: + # raw_messages are now ASC; last one is newest + newest_ts = raw_messages[-1].get("created_at", "") + if newest_ts: + cache["__last_server_ts"] = {"ts": newest_ts} + _save_message_to_cache(self.email, conv_id, "__last_server_ts", + {"ts": newest_ts}, cache_key=self._cache_key) + + # Decrypt new messages from server + new_decrypted = self._decrypt_raw_messages(raw_messages, cache, conv_id, my_user_id) + + # All non-critical ops fire-and-forget to avoid blocking message display + # Confirm delivery for messages from others + deliver_ids = [m["message_id"] for m in new_decrypted + if m.get("sender_id") and m["sender_id"] != my_user_id + and not m.get("deleted")] + if deliver_ids: + asyncio.ensure_future(self.confirm_delivery(conv_id, deliver_ids)) + + # Mark entire conversation as read (fire-and-forget) + asyncio.ensure_future(self.mark_conversation_read(conv_id)) + + # Flush self-encryption queue in background + if self._pending_self_encrypt: + asyncio.ensure_future(self._flush_self_encrypt()) + + if after_ts: + # Incremental: sync deletions in background, build from cache NOW + asyncio.ensure_future(self._sync_deletions(conv_id, after_ts)) + return self._build_from_cache(cache) + + return new_decrypted + + async def _sync_deletions(self, conv_id: str, after_ts: str): + """Sync message deletions from server (background, non-blocking).""" + try: + del_resp = await self.send_and_recv("get_deleted_since", + conversation_id=conv_id, since=after_ts) + if del_resp.get("status") == "ok": + for del_id in del_resp.get("data", {}).get("message_ids", []): + _save_message_to_cache(self.email, conv_id, del_id, {"deleted": True}, + cache_key=self._cache_key) + except Exception: + pass + + def get_cached_messages(self, conv_id: str) -> list[dict]: + """Return messages from local disk cache only (no server call). Instant.""" + if not self.email: + return [] + cache = _load_message_cache(self.email, conv_id, self._cache_key) + if not cache: + return [] + return self._build_from_cache(cache) + + def _build_from_cache(self, cache: dict) -> list[dict]: + """Build sorted message list from local cache (all messages).""" + messages = [] + for msg_id, p in cache.items(): + if p.get("_control") or msg_id.startswith("__"): + continue + entry = dict(p) + entry.setdefault("message_id", msg_id) + entry.setdefault("read_by", []) + entry.setdefault("delivered_to", []) + messages.append(entry) + messages.sort(key=lambda m: m.get("created_at", "")) + return messages + + def _decrypt_raw_messages(self, raw_messages: list, cache: dict, + conv_id: str, my_user_id: str) -> list[dict]: + """Decrypt server messages, update cache. Returns list of decrypted dicts.""" + decrypted = [] + for m in raw_messages: + msg_id = m["message_id"] + + if m.get("deleted_at"): + decrypted.append({ + "message_id": msg_id, + "sender": "", + "text": "", + "created_at": m["created_at"], + "read_by": [], + "sender_id": m.get("sender_id", ""), + "deleted": True, + }) + cache[msg_id] = {"deleted": True, "created_at": m["created_at"]} + continue + + # Check local cache first (ratchet keys are one-time use) + cached = cache.get(msg_id) + if cached and not cached.get("_control"): + cached["read_by"] = m.get("read_by", []) + cached["delivered_to"] = m.get("delivered_to", []) + cached["created_at"] = m["created_at"] + if m.get("reactions"): + cached["reactions"] = m["reactions"] + if m.get("pinned_at"): + cached["pinned_at"] = m["pinned_at"] + cached["pinned_by"] = m.get("pinned_by", "") + else: + cached.pop("pinned_at", None) + cached.pop("pinned_by", None) + decrypted.append(cached) + continue + if cached and cached.get("_control"): + continue + + try: + msg_data = { + "sender_id": m.get("sender_id", ""), + "sender_device_id": m.get("sender_device_id"), + "conversation_id": conv_id, + "ratchet_header": m.get("ratchet_header", {}), + "encrypted_content": m.get("encrypted_content", ""), + "nonce": m.get("nonce", ""), + "x3dh_header": m.get("x3dh_header"), + "sender_chain_id": m.get("sender_chain_id"), + "sender_chain_n": m.get("sender_chain_n"), + } + payload = self._decrypt_message(msg_data) + if payload is None: + _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True}, + cache_key=self._cache_key) + cache[msg_id] = {"_control": True} + continue + payload["message_id"] = msg_id + payload["created_at"] = m["created_at"] + payload["read_by"] = m.get("read_by", []) + payload["delivered_to"] = m.get("delivered_to", []) + payload["sender_id"] = m.get("sender_id", "") + if m.get("reactions"): + payload["reactions"] = m["reactions"] + if m.get("pinned_at"): + payload["pinned_at"] = m["pinned_at"] + payload["pinned_by"] = m.get("pinned_by", "") + decrypted.append(payload) + _save_message_to_cache(self.email, conv_id, msg_id, payload, + cache_key=self._cache_key) + cache[msg_id] = payload + if m.get("sender_id", "") != my_user_id: + self._pending_self_encrypt.append({ + "message_id": msg_id, + "payload": {k: v for k, v in payload.items() + if k not in ("message_id", "created_at", "read_by", + "delivered_to", "sender_id", "deleted")}, + }) + except Exception as e: + decrypted.append({ + "message_id": msg_id, + "sender": "???", + "text": f"[Decryption failed: {e}]", + "created_at": m["created_at"], + "read_by": [], + }) + return decrypted + + async def _flush_self_encrypt(self): + """Upload self-encrypted copies of received messages for multi-device access.""" + if not self._pending_self_encrypt or not self.identity_private: + return + self_key = derive_self_encryption_key(self.identity_private) + updates = [] + for item in list(self._pending_self_encrypt): + try: + plaintext = json.dumps(item["payload"], ensure_ascii=False).encode("utf-8") + _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key) + updates.append({ + "message_id": item["message_id"], + "encrypted_content": encode_binary(ct + tag), + "nonce": encode_binary(nonce), + }) + except Exception: + pass + self._pending_self_encrypt.clear() + if updates: + try: + for i in range(0, len(updates), 500): + batch = updates[i:i + 500] + await self.send_and_recv("reencrypt_messages", updates=batch) + except Exception as e: + self._logger.warning("Failed to self-encrypt received messages: %s", e) + + async def mark_read(self, conv_id: str, message_ids: list[str]): + if not message_ids: + return + await self.send_and_recv("mark_read", conversation_id=conv_id, message_ids=message_ids) + + async def mark_conversation_read(self, conv_id: str): + """Mark ALL unread messages in a conversation as read (server-side bulk).""" + try: + await self.send_and_recv("mark_conversation_read", conversation_id=conv_id) + except Exception: + pass # non-critical — don't fail message loading + + async def confirm_delivery(self, conv_id: str, message_ids: list[str]): + """Confirm delivery of messages (fire-and-forget, non-critical).""" + if not message_ids: + return + try: + await self.send_and_recv("confirm_delivery", + conversation_id=conv_id, message_ids=message_ids) + except Exception: + pass # non-critical + + def search_messages(self, conv_id: str, query: str) -> list[dict]: + """Search cached messages in a conversation. Returns matching messages.""" + cache = _load_message_cache(self.email, conv_id, self._cache_key) + query_lower = query.lower() + results = [] + for msg_id, payload in cache.items(): + if payload.get("deleted") or payload.get("_control") or payload.get("_sender_key"): + continue + text = payload.get("text", "") + if query_lower in text.lower(): + entry = dict(payload) + entry["message_id"] = msg_id + results.append(entry) + results.sort(key=lambda m: m.get("created_at", "")) + return results + + async def reset_session(self, peer_user_id: str, peer_device_id: str | None = None): + """Delete local session and notify peer to do the same.""" + if peer_device_id: + session_key = f"{peer_user_id}:{peer_device_id}" + else: + session_key = peer_user_id + self.sessions.pop(session_key, None) + _delete_session_file(self.email, peer_user_id, peer_device_id) + await self.send_and_recv("session_reset", + peer_user_id=peer_user_id, + peer_device_id=peer_device_id or "") + + def handle_session_reset_notification(self, from_user_id: str, from_device_id: str | None = None): + """Handle incoming session reset notification — delete the matching session.""" + if from_device_id: + session_key = f"{from_user_id}:{from_device_id}" + else: + session_key = from_user_id + self.sessions.pop(session_key, None) + _delete_session_file(self.email, from_user_id, from_device_id) + + # ------------------------------------------------------------------ + # Local message cache updates + # ------------------------------------------------------------------ + + def load_message_cache(self, conv_id: str) -> dict: + """Load cached messages for a conversation. Returns {msg_id: payload}.""" + if not self.email: + return {} + return _load_message_cache(self.email, conv_id, self._cache_key) + + def update_message_in_cache(self, conv_id: str, message_id: str, updates: dict): + """Update fields of a cached message on disk (synchronous).""" + if not self.email: + return + cache = _load_message_cache(self.email, conv_id, self._cache_key) + if message_id not in cache or cache[message_id].get("_control"): + return + for key, value in updates.items(): + if value is None: + cache[message_id].pop(key, None) + else: + cache[message_id][key] = value + d = get_key_dir(self.email) / "message_cache" + if self._cache_key: + _save_message_cache_full(d, conv_id, cache, self._cache_key) + + # ------------------------------------------------------------------ + # Reactions, Pins, Forwarding + # ------------------------------------------------------------------ + + async def react_message(self, message_id: str, reaction: str, action: str = "add") -> tuple[bool, str]: + """Add or remove a reaction on a message.""" + resp = await self.send_and_recv("react_message", + message_id=message_id, reaction=reaction, action=action) + if resp["status"] == "ok": + return True, "OK" + return False, resp.get("data", {}).get("message", "Failed") + + async def pin_message(self, message_id: str, conversation_id: str, action: str = "pin") -> tuple[bool, str]: + """Pin or unpin a message.""" + resp = await self.send_and_recv("pin_message", + message_id=message_id, conversation_id=conversation_id, action=action) + if resp["status"] == "ok": + return True, "OK" + return False, resp.get("data", {}).get("message", "Failed") + + async def get_pinned_messages(self, conversation_id: str) -> list[dict]: + """Get list of pinned messages for a conversation.""" + resp = await self.send_and_recv("get_pinned_messages", conversation_id=conversation_id) + if resp["status"] == "ok": + return resp["data"].get("messages", []) + return [] + + async def forward_message(self, target_conv_id: str, original_msg: dict, + target_members: list[dict]) -> tuple[bool, str | dict]: + """Forward a message to another conversation.""" + text = original_msg.get("text", "") + + payload = { + "sender": self.username, + "text": text, + "forwarded_from": { + "sender": original_msg.get("sender", ""), + "conversation_id": original_msg.get("conversation_id", ""), + "message_id": original_msg.get("message_id", ""), + }, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + # Forward image/file metadata (the encrypted blob is already on the server) + if original_msg.get("image"): + payload["image"] = original_msg["image"] + if not text: + payload["text"] = "" + if original_msg.get("file"): + payload["file"] = original_msg["file"] + if not text: + payload["text"] = "" + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + if self._is_group(target_members): + return await self._send_group_message(target_conv_id, plaintext, target_members, payload) + else: + return await self._send_dm(target_conv_id, plaintext, target_members, payload) + + # ------------------------------------------------------------------ + # Decrypt notification + # ------------------------------------------------------------------ + + def decrypt_notification(self, notif_data: dict) -> dict | None: + """Decrypt a new_message notification. Returns parsed payload or None. + + Supports new multi-device format (device_entries array) and legacy flat format. + """ + try: + conv_id = notif_data.get("conversation_id", "") + msg_id = notif_data.get("message_id", "") + sender_id = notif_data.get("sender_id", "") + sender_device_id = notif_data.get("sender_device_id") + my_user_id = self.session["user_id"] if self.session else "" + + # Extract per-device encrypted content from device_entries or flat fields + encrypted_content = "" + nonce = "" + ratchet_header = {} + x3dh_header = None + + device_entries = notif_data.get("device_entries") + if device_entries: + # Multi-device format: pick entry matching our device_id or SELF_DEVICE_ID + chosen = None + self_entry = None + for entry in device_entries: + eid = entry.get("device_id", "") + if eid == self.device_id: + chosen = entry + break + if eid == "00000000-0000-0000-0000-000000000000": + self_entry = entry + + # If sender is us, prefer self-encrypted entry + if sender_id == my_user_id: + chosen = self_entry or chosen + elif not chosen: + chosen = self_entry + + if not chosen: + self._logger.warning("No matching device_entry for device %s", self.device_id) + return None + + encrypted_content = chosen.get("encrypted_content", "") + nonce = chosen.get("nonce", "") + ratchet_header = chosen.get("ratchet_header") or notif_data.get("ratchet_header", {}) + x3dh_header = chosen.get("x3dh_header") or notif_data.get("x3dh_header") + else: + # Legacy flat format + encrypted_content = notif_data.get("encrypted_content", "") + nonce = notif_data.get("nonce", "") + ratchet_header = notif_data.get("ratchet_header", {}) + x3dh_header = notif_data.get("x3dh_header") + + msg_data = { + "sender_id": sender_id, + "sender_device_id": sender_device_id, + "conversation_id": conv_id, + "ratchet_header": ratchet_header, + "encrypted_content": encrypted_content, + "nonce": nonce, + "x3dh_header": x3dh_header, + "sender_chain_id": notif_data.get("sender_chain_id"), + "sender_chain_n": notif_data.get("sender_chain_n"), + } + payload = self._decrypt_message(msg_data) + if payload is None: + # Cache control message so get_messages skips it + if msg_id and conv_id: + _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True}, + cache_key=self._cache_key) + return None + payload["conversation_id"] = conv_id + payload["message_id"] = msg_id + payload["sender_id"] = sender_id + # Use server-compatible timestamp (no timezone suffix) for cache consistency + _ts = payload.get("timestamp", "") + if _ts: + # Strip timezone suffix (+00:00 or Z) to match server DATETIME format + _ts = _ts.replace("+00:00", "").replace("Z", "") + # Strip microseconds if present + if "." in _ts: + _ts = _ts[:_ts.index(".")] + payload["created_at"] = _ts + payload["read_by"] = [] + payload["delivered_to"] = [] + # Cache so get_messages doesn't re-decrypt (ratchet keys are one-time) + if msg_id and conv_id: + _save_message_to_cache(self.email, conv_id, msg_id, payload, + cache_key=self._cache_key) + # Queue self-encryption for received messages (multi-device access) + if sender_id != my_user_id and msg_id: + self._pending_self_encrypt.append({ + "message_id": msg_id, + "payload": {k: v for k, v in payload.items() + if k not in ("conversation_id", "message_id", "created_at", + "read_by", "delivered_to", "sender_id", "deleted")}, + }) + return payload + except IdentityKeyChanged: + raise # Must propagate to caller for key-change UI + except Exception as e: + self._logger.warning("Failed to decrypt notification: %s", e) + return None + + # ------------------------------------------------------------------ + # Delete message + # ------------------------------------------------------------------ + + async def delete_message(self, message_id: str) -> tuple[bool, str]: + resp = await self.send_and_recv("delete_message", message_id=message_id) + if resp["status"] == "ok": + return True, "Message deleted." + return False, resp["data"]["message"] + + # ------------------------------------------------------------------ + # Image sharing + # ------------------------------------------------------------------ + + async def send_image(self, conv_id: str, image_path: str, members: list[dict], + reply_to: str | None = None) -> tuple[bool, str]: + """Encrypt and upload an image, then send as a message.""" + await self.typing_stop(conv_id, force=True) + try: + from PIL import Image + import io + except ImportError: + return False, "Pillow is required for image sharing. Install with: pip install Pillow" + + path = Path(image_path) + if not path.exists(): + return False, "File not found." + + try: + img = Image.open(path) + img.load() + except Exception as e: + return False, f"Cannot open image: {e}" + + # Prepare image bytes — offload CPU-heavy work to thread + def _prepare_image(img, path): + """PIL resize + thumbnail + AES encrypt (runs in thread).""" + original_format = img.format or "JPEG" + if original_format.upper() not in ("JPEG", "PNG", "WEBP", "GIF", "BMP"): + original_format = "JPEG" + + image_bytes = path.read_bytes() + + # AES-GCM overhead: 16 bytes tag. Check raw size as proxy. + if MAX_IMAGE_BYTES > 0 and len(image_bytes) + 16 > MAX_IMAGE_BYTES: + if img.mode not in ("RGB", "L"): + img = img.convert("RGB") + for quality in (92, 85, 75, 60): + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=quality) + image_bytes = buf.getvalue() + if len(image_bytes) + 16 <= MAX_IMAGE_BYTES: + break + else: + for max_dim in (3840, 2560, 1920, 1280): + if max(img.size) > max_dim: + img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=75) + image_bytes = buf.getvalue() + if len(image_bytes) + 16 <= MAX_IMAGE_BYTES: + break + + # Generate thumbnail + thumb = img.copy() + thumb.thumbnail((200, 200), Image.Resampling.LANCZOS) + if thumb.mode not in ("RGB", "L"): + thumb = thumb.convert("RGB") + thumb_buf = io.BytesIO() + thumb.save(thumb_buf, format="JPEG", quality=60) + thumbnail_b64 = encode_binary(thumb_buf.getvalue()) + + # Encrypt + img_aes_key, img_iv, img_ct, img_tag = aes_encrypt(image_bytes) + encrypted_image = img_ct + img_tag + return image_bytes, thumbnail_b64, img_aes_key, img_iv, encrypted_image + + image_bytes, thumbnail_b64, img_aes_key, img_iv, encrypted_image = \ + await asyncio.get_event_loop().run_in_executor(None, _prepare_image, img, path) + + file_id = str(uuid.uuid4()) + file_size = len(encrypted_image) + + # Chunked upload + resp = await self.send_and_recv( + "upload_image_start", + conversation_id=conv_id, + file_id=file_id, + file_size=file_size, + ) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + ok, err = await self._pipelined_upload(file_id, encrypted_image) + if not ok: + return False, err + + resp = await self.send_and_recv("upload_image_end", file_id=file_id) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + # Cache decrypted image locally so sender never re-downloads + cache_path = self._media_cache_path(file_id) + if cache_path: + try: + cache_path.write_bytes(image_bytes) + os.chmod(cache_path, 0o600) + except OSError: + pass + + # Build message payload with image info + image_info = { + "file_id": file_id, + "aes_key": encode_binary(img_aes_key), + "iv": encode_binary(img_iv), + "thumbnail": thumbnail_b64, + "filename": path.name, + "size": len(image_bytes), + } + + payload = { + "sender": self.username, + "text": "", + "reply_to": reply_to, + "timestamp": datetime.now(timezone.utc).isoformat(), + "image": image_info, + } + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + my_user_id = self.session["user_id"] + + if self._is_group(members): + # Group image: use sender key + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if not sk: + sk = SenderKeyState() + self.sender_key_states[conv_id] = sk + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + self.sender_key_states[conv_id] = sk + await self._catch_up_sender_key_distribution(conv_id, members, sk) + + result = sk.encrypt(plaintext) + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + + recipients = [] + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + recipients.append({ + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + }) + + # Self-encrypted copy for sender + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0}, + recipients=recipients, + sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])), + sender_chain_n=result["n"], + image_file_id=file_id, + ) + else: + # DM image: per-device ratchet (same pattern as _send_dm) + recipients = [] + first_rh = None + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + try: + device_bundles = await self._get_device_bundles(uid) + except Exception: + device_bundles = [] + + if not device_bundles: + # Fallback: legacy single-device + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key) + else: + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + # Encrypt self-copy with static key + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header=first_rh, + recipients=recipients, + image_file_id=file_id, + ) + + if resp["status"] == "ok": + msg_data = resp.get("data", {}) + result_msg = { + **payload, + "message_id": msg_data.get("message_id", ""), + "created_at": msg_data.get("created_at", ""), + "sender_id": self.session["user_id"], + "conversation_id": conv_id, + "read_by": [], + } + _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key) + return True, result_msg + return False, resp["data"]["message"] + + async def _pipelined_upload(self, file_id: str, encrypted_data: bytes) -> tuple[bool, str]: + """Upload encrypted data in pipelined chunks (no per-chunk round-trip wait).""" + file_size = len(encrypted_data) + chunk_futures = [] + upload_offset = 0 + while upload_offset < file_size: + chunk = encrypted_data[upload_offset:upload_offset + IMAGE_CHUNK_SIZE] + request_id = str(uuid.uuid4()) + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._pending[request_id] = fut + try: + await self.writer.send_request( + "upload_image_chunk", + request_id=request_id, + file_id=file_id, + data=encode_binary(chunk), + ) + except Exception as e: + self._pending.pop(request_id, None) + return False, f"Upload failed: {e}" + chunk_futures.append((request_id, fut)) + upload_offset += len(chunk) + + for request_id, fut in chunk_futures: + try: + resp = await asyncio.wait_for(fut, timeout=30.0) + except (asyncio.TimeoutError, ConnectionError): + self._pending.pop(request_id, None) + return False, "Upload chunk timed out." + finally: + self._pending.pop(request_id, None) + if resp["status"] != "ok": + return False, resp["data"]["message"] + return True, "" + + async def send_file(self, conv_id: str, file_path: str, members: list[dict], + reply_to: str | None = None) -> tuple[bool, str | dict]: + """Encrypt and upload a file, then send as a message.""" + await self.typing_stop(conv_id, force=True) + import mimetypes + + path = Path(file_path) + if not path.exists(): + return False, "File not found." + + try: + file_bytes = path.read_bytes() + except Exception as e: + return False, f"Cannot read file: {e}" + + mime_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream" + + # Encrypt file with AES-256-GCM (offload to thread) + file_aes_key, file_iv, file_ct, file_tag = await asyncio.get_event_loop().run_in_executor( + None, aes_encrypt, file_bytes) + encrypted_file = file_ct + file_tag + + file_id = str(uuid.uuid4()) + file_size = len(encrypted_file) + + # Chunked upload (reuse image upload infrastructure with file_type="file") + resp = await self.send_and_recv( + "upload_image_start", + conversation_id=conv_id, + file_id=file_id, + file_size=file_size, + file_type="file", + ) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + ok, err = await self._pipelined_upload(file_id, encrypted_file) + if not ok: + return False, err + + resp = await self.send_and_recv("upload_image_end", file_id=file_id) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + # Cache decrypted file locally so sender never re-downloads + cache_path = self._media_cache_path(file_id) + if cache_path: + try: + cache_path.write_bytes(file_bytes) + os.chmod(cache_path, 0o600) + except OSError: + pass + + # Build message payload with file info + file_info = { + "file_id": file_id, + "aes_key": encode_binary(file_aes_key), + "iv": encode_binary(file_iv), + "filename": path.name, + "size": len(file_bytes), + "mime_type": mime_type, + } + + payload = { + "sender": self.username, + "text": "", + "reply_to": reply_to, + "timestamp": datetime.now(timezone.utc).isoformat(), + "file": file_info, + } + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + my_user_id = self.session["user_id"] + + if self._is_group(members): + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if not sk: + sk = SenderKeyState() + self.sender_key_states[conv_id] = sk + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + self.sender_key_states[conv_id] = sk + await self._catch_up_sender_key_distribution(conv_id, members, sk) + + result = sk.encrypt(plaintext) + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + + recipients = [] + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + recipients.append({ + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + }) + + # Self-encrypted copy for sender + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0}, + recipients=recipients, + sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])), + sender_chain_n=result["n"], + image_file_id=file_id, + ) + else: + # DM file: per-device ratchet (same pattern as _send_dm) + recipients = [] + first_rh = None + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + try: + device_bundles = await self._get_device_bundles(uid) + except Exception: + device_bundles = [] + + if not device_bundles: + # Fallback: legacy single-device + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key) + else: + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + # Encrypt self-copy with static key + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header=first_rh, + recipients=recipients, + image_file_id=file_id, + ) + + if resp["status"] == "ok": + msg_data = resp.get("data", {}) + result_msg = { + **payload, + "message_id": msg_data.get("message_id", ""), + "created_at": msg_data.get("created_at", ""), + "sender_id": self.session["user_id"], + "conversation_id": conv_id, + "read_by": [], + } + _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key) + return True, result_msg + return False, resp["data"]["message"] + + async def download_file(self, file_id: str, file_info: dict) -> bytes | None: + """Download and decrypt a file. Returns decrypted file bytes or None.""" + return await self._download_and_decrypt(file_id, file_info) + + def _media_cache_path(self, file_id: str) -> Path | None: + """Return path for cached decrypted media file, or None if no email.""" + if not self.email: + return None + d = get_key_dir(self.email) / "media_cache" + d.mkdir(parents=True, exist_ok=True) + try: + os.chmod(d, 0o700) + except OSError: + pass + return d / f"{file_id}.bin" + + async def _stream_download(self, file_id: str) -> bytes | None: + """Download file via streaming (single request, server sends all chunks). + + Falls back to legacy pipelined download if server doesn't support streaming. + """ + request_id = str(uuid.uuid4()) + q: asyncio.Queue = asyncio.Queue() + self._pending[request_id] = q + try: + await self.writer.send_request( + "download_stream", request_id=request_id, file_id=file_id, + ) + except Exception: + self._pending.pop(request_id, None) + return None + + chunks: dict[int, bytes] = {} + try: + while True: + try: + resp = await asyncio.wait_for(q.get(), timeout=60.0) + except (asyncio.TimeoutError, ConnectionError): + return None + if resp.get("status") != "ok": + # Server may not support download_stream — fall back + return await self._legacy_download(file_id) + data = resp["data"] + chunk_data = decode_binary(data["data"]) + chunks[data["offset"]] = chunk_data + if data.get("done"): + break + finally: + self._pending.pop(request_id, None) + + # Reassemble in order + parts = [] + for off in sorted(chunks.keys()): + parts.append(chunks[off]) + return b"".join(parts) + + async def _legacy_download(self, file_id: str) -> bytes | None: + """Fallback: download file chunk by chunk (for older servers).""" + resp = await self.send_and_recv("download_image", file_id=file_id, offset=0) + if resp["status"] != "ok": + return None + data = resp["data"] + first_chunk = decode_binary(data["data"]) + if data.get("done"): + return first_chunk + + chunk_size = len(first_chunk) + chunks = {0: first_chunk} + + # Pipeline remaining chunks + futures = [] + offset = chunk_size + total_size = data.get("total_size", 0) + # Calculate how many chunks we need + while offset < total_size: + request_id = str(uuid.uuid4()) + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._pending[request_id] = fut + try: + await self.writer.send_request( + "download_image", request_id=request_id, + file_id=file_id, offset=offset, + ) + except Exception: + self._pending.pop(request_id, None) + return None + futures.append((request_id, offset, fut)) + offset += chunk_size + + for request_id, off, fut in futures: + try: + resp = await asyncio.wait_for(fut, timeout=30.0) + except (asyncio.TimeoutError, ConnectionError): + self._pending.pop(request_id, None) + return None + finally: + self._pending.pop(request_id, None) + if resp["status"] != "ok": + return None + chunk_data = decode_binary(resp["data"]["data"]) + chunks[off] = chunk_data + + parts = [] + for off in sorted(chunks.keys()): + parts.append(chunks[off]) + return b"".join(parts) + + async def _download_and_decrypt(self, file_id: str, info: dict) -> bytes | None: + """Download, decrypt, and cache a media file. Used by both image and file download.""" + # Check local cache first + cache_path = self._media_cache_path(file_id) + if cache_path and cache_path.exists(): + try: + return cache_path.read_bytes() + except OSError: + pass + + encrypted_data = await self._stream_download(file_id) + if not encrypted_data or len(encrypted_data) < 16: + return None + + ciphertext = encrypted_data[:-16] + tag = encrypted_data[-16:] + + try: + aes_key = decode_binary(info["aes_key"]) + iv = decode_binary(info["iv"]) + decrypted = aes_decrypt(aes_key, iv, ciphertext, tag) + except Exception: + return None + + # Cache decrypted result to disk + if cache_path and decrypted: + try: + cache_path.write_bytes(decrypted) + os.chmod(cache_path, 0o600) + except OSError: + pass + + return decrypted + + async def download_image(self, file_id: str, image_info: dict) -> bytes | None: + """Download and decrypt an image. Returns decrypted image bytes or None.""" + return await self._download_and_decrypt(file_id, image_info) + + # ------------------------------------------------------------------ + # Re-encrypt history (for device pairing) + # ------------------------------------------------------------------ + + async def reencrypt_history(self): + """Re-encrypt all cached messages with self-encryption key. + + After device pairing, the new device shares the same identity key + but cannot decrypt old messages (Double Ratchet keys are one-time use). + This re-encrypts all cached messages so they can be read using the + self-encryption key derived from the shared identity key. + """ + if not self.identity_private or not self.session: + return + + self_key = derive_self_encryption_key(self.identity_private) + + # Phase 1: Fetch & decrypt all messages to populate cache + # (messages the old device never opened won't be in cache yet) + try: + convs = await self.list_conversations() + convs = list(convs) + random.shuffle(convs) + total_convs = len(convs) + for ci, conv in enumerate(convs): + cid = conv.get("id") or conv.get("conversation_id") + if not cid: + continue + if self._reencrypt_progress_cb: + self._reencrypt_progress_cb( + f"Fetching messages: {ci + 1}/{total_convs} conversations..." + ) + offset = 0 + while True: + msgs = await self.get_messages(cid, limit=200, offset=offset) + if not msgs or len(msgs) < 200: + break + offset += len(msgs) + await asyncio.sleep(random.uniform(*PAIRING_REENCRYPT_INTER_FETCH_DELAY_RANGE)) + except Exception as e: + self._logger.warning("Failed to fetch messages for re-encryption: %s", e) + + # Phase 2: Read cache and re-encrypt + cache_dir = get_key_dir(self.email) / "message_cache" + if not cache_dir.exists(): + self._logger.info("No message cache to re-encrypt.") + return + + all_updates = [] + conv_ids = set() + for f in cache_dir.iterdir(): + if f.suffix in (".json", ".bin"): + conv_ids.add(f.stem) + + total_files = len(conv_ids) + conv_order = sorted(conv_ids) + random.shuffle(conv_order) + for i, conv_id in enumerate(conv_order): + cache = _load_message_cache(self.email, conv_id, self._cache_key) + if not cache: + continue + + items = list(cache.items()) + random.shuffle(items) + for msg_id, entry in items: + # Skip control messages (sender key distribution) + if entry.get("_control"): + continue + # Skip entries with no useful content + text = entry.get("text", "") + if not text and not entry.get("image") and not entry.get("file"): + continue + + # Rebuild plaintext from cached payload + payload = {k: v for k, v in entry.items() + if k not in ("message_id", "created_at", "read_by", "sender_id", "deleted")} + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + # Re-encrypt with self-encryption key + _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key) + all_updates.append({ + "message_id": msg_id, + "encrypted_content": encode_binary(ct + tag), + "nonce": encode_binary(nonce), + }) + + if self._reencrypt_progress_cb: + self._reencrypt_progress_cb(f"Re-encrypting history: {i + 1}/{total_files} conversations...") + + if not all_updates: + self._logger.info("No messages to re-encrypt.") + return + + # Send in batches of 500 + batch_size = PAIRING_REENCRYPT_BATCH_SIZE + random.shuffle(all_updates) + total = len(all_updates) + for start in range(0, total, batch_size): + batch = all_updates[start:start + batch_size] + resp = await self.send_and_recv("reencrypt_messages", updates=batch) + if resp["status"] != "ok": + self._logger.warning("Re-encrypt batch failed: %s", resp.get("data", {}).get("message", "")) + else: + self._logger.info("Re-encrypted %d/%d messages.", min(start + batch_size, total), total) + if start + batch_size < total: + await asyncio.sleep(random.uniform(*PAIRING_REENCRYPT_INTER_BATCH_DELAY_RANGE)) + + if self._reencrypt_progress_cb: + self._reencrypt_progress_cb(f"Re-encryption complete: {total} messages uploaded.") + + # ------------------------------------------------------------------ + # User Profiles + # ------------------------------------------------------------------ + + async def get_profile(self, user_id: str | None = None) -> dict | None: + """Get user profile. If user_id is None, returns own profile.""" + kwargs = {} + if user_id: + kwargs["user_id"] = user_id + resp = await self.send_and_recv("get_profile", **kwargs) + if resp["status"] == "ok": + return resp["data"] + return None + + async def update_profile(self, **fields) -> tuple[bool, str]: + """Update own profile (phone, location, *_visible).""" + resp = await self.send_and_recv("update_profile", **fields) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def update_avatar(self, image_data: bytes) -> tuple[bool, str]: + """Upload avatar image.""" + resp = await self.send_and_recv("update_avatar", data=encode_binary(image_data)) + if resp["status"] == "ok": + return True, resp["data"].get("avatar_file", "") + return False, resp["data"]["message"] + + async def get_avatar(self, user_id: str) -> bytes | None: + """Download avatar for a user.""" + resp = await self.send_and_recv("get_avatar", 10.0, user_id=user_id) + if resp["status"] == "ok": + return decode_binary(resp["data"]["data"]) + return None + + async def update_group_avatar(self, conv_id: str, image_data: bytes) -> tuple[bool, str]: + """Upload avatar for a group conversation.""" + resp = await self.send_and_recv("update_group_avatar", + conversation_id=conv_id, data=encode_binary(image_data)) + if resp["status"] == "ok": + return True, resp["data"].get("avatar_file", "") + return False, resp["data"]["message"] + + async def get_group_avatar(self, conv_id: str) -> bytes | None: + """Download avatar for a group conversation.""" + resp = await self.send_and_recv("get_group_avatar", 10.0, conversation_id=conv_id) + if resp["status"] == "ok": + return decode_binary(resp["data"]["data"]) + return None + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + async def close(self): + self.connected = False + for conv_id in list(self._typing_stop_tasks.keys()): + self._cancel_typing_timer(conv_id) + if self._listener_task: + self._listener_task.cancel() + if self.raw_writer: + self.raw_writer.close() + + async def reconnect(self): + """Close existing connection and re-establish: connect + re-login using in-memory keys.""" + try: + await self.close() + except Exception: + pass + # Reset reader/writer but keep keys and sessions + self.reader = None + self.writer = None + self.raw_writer = None + self._listener_task = None + self._pending.clear() + self.login_rejected = False + # Drain queues + while not self._response_queue.empty(): + try: + self._response_queue.get_nowait() + except Exception: + break + while not self._notification_queue.empty(): + try: + self._notification_queue.get_nowait() + except Exception: + break + await self.connect() + self._listener_task = asyncio.create_task(self._background_listener()) + if self.email and self.private_key: + # RSA challenge-response login (keys already in memory) + start = await self.send_and_recv("login_start", email=self.email) + if start["status"] == "ok": + challenge = decode_binary(start["data"]["challenge"]) + signature = rsa_sign(self.private_key, challenge) + login_kwargs = { + "email": self.email, + "signature": encode_binary(signature), + "client_version": VERSION, + } + if self.device_id: + login_kwargs["device_id"] = self.device_id + finish = await self.send_and_recv("login_finish", **login_kwargs) + if finish["status"] == "ok": + self.session = finish["data"] + asyncio.create_task(self._ensure_prekeys()) + else: + # Login rejected — keys were likely rotated on another device + self.session = None + self.connected = False + self.login_rejected = True diff --git a/client.py b/client.py new file mode 100644 index 0000000..f897615 --- /dev/null +++ b/client.py @@ -0,0 +1,928 @@ +"""Interactive CLI client for encrypted chat (X3DH + Double Ratchet).""" + +import asyncio +import getpass +import logging +import os +import re + +from chat_core import ChatClient, IdentityKeyChanged + + +def setup_logging(): + level_name = os.getenv("LOG_LEVEL", "WARNING").upper() + level = getattr(logging, level_name, logging.WARNING) + logging.basicConfig(level=level, format="%(levelname)s: %(message)s") + + +async def prompt(text: str) -> str: + """Non-blocking terminal input.""" + return await asyncio.get_event_loop().run_in_executor(None, lambda: input(text).strip()) + + +async def prompt_password(text: str = "Password: ") -> str: + """Non-blocking hidden password input (M3 fix).""" + return await asyncio.get_event_loop().run_in_executor(None, lambda: getpass.getpass(text)) + + +# M3 fix: strip terminal control/escape sequences from untrusted text +_CONTROL_RE = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]|\x1b\[[0-9;]*[A-Za-z]") + + +def _sanitize_text(s) -> str: + """Remove control characters and ANSI escape sequences.""" + if not isinstance(s, str): + s = str(s) if s is not None else "" + return _CONTROL_RE.sub("", s) + + +def _safe_filename(name: str) -> str: + """Sanitize remote filename: basename only, no path traversal, no NUL.""" + name = os.path.basename(name) + name = name.replace("\x00", "") + if not name or name.startswith("."): + name = "download" + return name + + +def _human_size(n: int) -> str: + if n >= 1024 * 1024: + return f"{n / (1024*1024):.1f} MB" + if n >= 1024: + return f"{n / 1024:.0f} KB" + return f"{n} B" + + +async def _select_conversation(client: ChatClient, label: str = "Select conversation") -> tuple[dict | None, list[dict]]: + """List conversations and let user pick one. Returns (conv, convs) or (None, []).""" + convs = await client.list_conversations() + if not convs: + print("[*] No conversations.") + return None, [] + + def conv_label(c): + if c.get("name"): + return _sanitize_text(c["name"]) + others = [_sanitize_text(m.get("username") or m.get("email") or "?") for m in c["members"] if m.get("email") != client.email] + return ", ".join(others) if others else _sanitize_text(client.username) + + print() + for i, c in enumerate(convs): + print(f" {i+1}) {conv_label(c)}") + choice = await prompt(f"{label}: ") + try: + idx = int(choice) - 1 + if not (0 <= idx < len(convs)): + print("[!] Invalid selection.") + return None, convs + except ValueError: + print("[!] Invalid selection.") + return None, convs + return convs[idx], convs + + +async def interactive_menu(client: ChatClient): + """Interactive terminal menu.""" + while True: + print("\n--- Encrypted Chat ---") + print("1) Send direct message") + print("2) Send to conversation") + print("3) Read messages") + print("4) Create group conversation") + print("5) Add member to group") + print("6) Send image") + print("7) Send file") + print("8) Invitations") + print("9) Leave group") + print("10) Rename group") + print("11) Delete conversation") + print("12) Search messages") + print("13) My profile") + print("14) View user profile") + print("15) Manage devices") + print("16) React to message") + print("17) Pin/Unpin message") + print("18) View pinned messages") + print("19) Forward message") + print("20) Verify contact") + print("21) Show my fingerprint") + print("22) Change password") + print("23) Change username") + print("q) Quit") + + choice = await prompt("> ") + + if choice == "1": + email = await prompt("To (email): ") + if not email: + continue + text = await prompt("Message: ") + if not text: + continue + conv_id, msg = await client.find_or_create_conversation(email) + if not conv_id: + print(f"[!] {msg}") + continue + convs = await client.list_conversations() + members = [] + for c in convs: + if c["conversation_id"] == conv_id: + members = c["members"] + break + try: + ok, result = await client.send_message(conv_id, text, members) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {'Message sent.' if ok else result}") + + elif choice == "2": + conv, _ = await _select_conversation(client) + if not conv: + continue + text = await prompt("Message: ") + if not text: + continue + try: + ok, result = await client.send_message(conv["conversation_id"], text, conv["members"]) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {'Message sent.' if ok else result}") + + elif choice == "3": + conv, _ = await _select_conversation(client) + if not conv: + continue + messages = await client.get_messages(conv["conversation_id"]) + if not messages: + print("[*] No messages.") + continue + _print_messages(messages, client, conv) + + action = await prompt("\nAction (r=reply, d=delete, dl=download file, empty=back): ") + if not action: + continue + if action.lower().startswith("dl"): + await _download_file_action(client, messages) + continue + if action.lower().startswith("d"): + await _delete_message_action(client, messages) + continue + if action.lower().startswith("r"): + reply_choice = await prompt("Reply to message #: ") + else: + reply_choice = action + try: + reply_idx = int(reply_choice) - 1 + if not (0 <= reply_idx < len(messages)): + print("[!] Invalid message number.") + continue + except ValueError: + print("[!] Invalid number.") + continue + reply_to_id = messages[reply_idx]["message_id"] + text = await prompt("Message: ") + if not text: + continue + try: + ok, result = await client.send_message(conv["conversation_id"], text, conv["members"], reply_to=reply_to_id) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {'Message sent.' if ok else result}") + + elif choice == "4": + name = await prompt("Group name (empty for none): ") + members_input = await prompt("Member emails (comma-separated): ") + members = [m.strip() for m in members_input.split(",") if m.strip()] + if not members: + continue + conv_id, msg = await client.create_conversation(members, name=name.strip() or None) + if conv_id: + print(f"[+] Group created with: {', '.join(members)}") + else: + print(f"[!] {msg}") + + elif choice == "5": + conv, _ = await _select_conversation(client) + if not conv: + continue + email = await prompt("Email to add: ") + ok, msg = await client.add_member(conv["conversation_id"], email) + print(f"[{'+'if ok else '!'}] {msg or 'Invitation sent.'}") + + elif choice == "6": + conv, _ = await _select_conversation(client) + if not conv: + continue + image_path = await prompt("Image path: ") + if not image_path: + continue + try: + ok, msg = await client.send_image(conv["conversation_id"], image_path, conv["members"]) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "7": + conv, _ = await _select_conversation(client) + if not conv: + continue + file_path = await prompt("File path: ") + if not file_path: + continue + if not os.path.isfile(file_path): + print("[!] File not found.") + continue + try: + ok, msg = await client.send_file(conv["conversation_id"], file_path, conv["members"]) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "8": + await _invitations_menu(client) + + elif choice == "9": + conv, _ = await _select_conversation(client, "Select group to leave") + if not conv: + continue + confirm = await prompt(f"Leave '{conv.get('name', 'this conversation')}'? (y/n): ") + if confirm.lower() != "y": + continue + ok, msg = await client.leave_group(conv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "10": + conv, _ = await _select_conversation(client, "Select group to rename") + if not conv: + continue + name = await prompt("New name: ") + if not name: + continue + ok, msg = await client.rename_conversation(conv["conversation_id"], name.strip()) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "11": + conv, _ = await _select_conversation(client, "Select conversation to delete") + if not conv: + continue + confirm = await prompt("Delete this conversation? This cannot be undone. (y/n): ") + if confirm.lower() != "y": + continue + ok, msg = await client.delete_conversation(conv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "12": + conv, _ = await _select_conversation(client, "Select conversation to search") + if not conv: + continue + query = await prompt("Search query: ") + if not query: + continue + # First ensure we have messages cached by fetching them + await client.get_messages(conv["conversation_id"]) + results = client.search_messages(conv["conversation_id"], query) + if not results: + print("[*] No matches found.") + continue + print(f"\n[*] {len(results)} match(es):") + for r in results: + sender = _sanitize_text(r.get("sender", "???")) + text = _sanitize_text(r.get("text", "")) + ts = r.get("created_at", "")[:16] + # Highlight match in text + idx = text.lower().find(query.lower()) + if idx >= 0: + text = text[:idx] + "\033[33m" + text[idx:idx+len(query)] + "\033[0m" + text[idx+len(query):] + print(f" [{ts}] {sender}: {text}") + + elif choice == "13": + await _my_profile_menu(client) + + elif choice == "14": + email = await prompt("User email: ") + if not email: + continue + # Need to find user_id from email — try via conversation members + user_id = None + convs = await client.list_conversations() + for c in convs: + for m in c.get("members", []): + if m.get("email") == email: + user_id = m.get("user_id") or m.get("id") + break + if user_id: + break + if not user_id: + print("[!] User not found in your conversations.") + continue + profile = await client.get_profile(user_id) + if not profile: + print("[!] Could not load profile.") + continue + _print_profile(profile) + + elif choice == "15": + await _devices_menu(client) + + elif choice == "16": + conv, _ = await _select_conversation(client) + if not conv: + continue + messages = await client.get_messages(conv["conversation_id"]) + if not messages: + print("[*] No messages.") + continue + _print_messages(messages, client, conv) + msg_choice = await prompt("React to message #: ") + try: + msg_idx = int(msg_choice) - 1 + if not (0 <= msg_idx < len(messages)): + print("[!] Invalid message number.") + continue + except ValueError: + print("[!] Invalid number.") + continue + print("Reactions: thumbsup, heart, laugh, surprised, sad, thumbsdown") + reaction = await prompt("Reaction: ").strip().lower() + if reaction not in ("thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"): + print("[!] Invalid reaction.") + continue + ok, msg = await client.react_message(messages[msg_idx]["message_id"], reaction, "add") + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "17": + conv, _ = await _select_conversation(client) + if not conv: + continue + messages = await client.get_messages(conv["conversation_id"]) + if not messages: + print("[*] No messages.") + continue + _print_messages(messages, client, conv) + msg_choice = await prompt("Pin/Unpin message #: ") + try: + msg_idx = int(msg_choice) - 1 + if not (0 <= msg_idx < len(messages)): + print("[!] Invalid message number.") + continue + except ValueError: + print("[!] Invalid number.") + continue + m = messages[msg_idx] + action = "unpin" if m.get("pinned_at") else "pin" + ok, msg = await client.pin_message(m["message_id"], conv["conversation_id"], action) + print(f"[{'+'if ok else '!'}] {action.capitalize()}: {msg}") + + elif choice == "18": + conv, _ = await _select_conversation(client) + if not conv: + continue + pinned = await client.get_pinned_messages(conv["conversation_id"]) + if not pinned: + print("[*] No pinned messages.") + continue + print(f"\n[*] {len(pinned)} pinned message(s):") + for p in pinned: + print(f" {p.get('message_id', '?')[:8]}... pinned at {p.get('pinned_at', '?')}") + + elif choice == "19": + conv, _ = await _select_conversation(client, "Select source conversation") + if not conv: + continue + messages = await client.get_messages(conv["conversation_id"]) + if not messages: + print("[*] No messages.") + continue + _print_messages(messages, client, conv) + msg_choice = await prompt("Forward message #: ") + try: + msg_idx = int(msg_choice) - 1 + if not (0 <= msg_idx < len(messages)): + print("[!] Invalid message number.") + continue + except ValueError: + print("[!] Invalid number.") + continue + target_conv, _ = await _select_conversation(client, "Select target conversation") + if not target_conv: + continue + fwd_msg = messages[msg_idx] + fwd_msg["conversation_id"] = conv["conversation_id"] + try: + ok, result = await client.forward_message( + target_conv["conversation_id"], fwd_msg, target_conv["members"] + ) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {'Forwarded.' if ok else result}") + + elif choice == "20": + # Verify contact — show safety number for a DM conversation + conv, _ = await _select_conversation(client, "Select DM to verify") + if not conv: + continue + # Find peer user_id + peer_uid = "" + peer_name = "" + for m in conv.get("members", []): + if m.get("email") != client.email: + peer_uid = m.get("user_id") or m.get("id") or "" + peer_name = _sanitize_text(m.get("username") or m.get("email") or "?") + break + if not peer_uid: + print("[!] Could not identify peer user.") + continue + # Ensure we have their identity key in cache + info = await client._get_user_info(user_id=peer_uid) + if not info or not info.get("identity_key_bytes"): + print("[!] Could not retrieve identity key for this user.") + continue + status = client.get_verification_status(peer_uid) + print(f"\n--- Verification: {peer_name} ---") + print(f"Status: {status.upper()}") + safety = client.get_safety_number(peer_uid) + if safety: + print(f"\nSafety Number:\n{safety}") + fp = client.get_peer_fingerprint(peer_uid) + if fp: + print(f"\nTheir Fingerprint:\n{fp}") + my_fp = client.get_my_fingerprint() + if my_fp: + print(f"\nYour Fingerprint:\n{my_fp}") + if status != "verified": + action = await prompt("\nMark as verified? (y/n): ") + if action.lower() == "y": + client.verify_contact(peer_uid, info["identity_key_bytes"], + method="safety_number") + print("[+] Contact marked as verified.") + else: + action = await prompt("\nRemove verification? (y/n): ") + if action.lower() == "y": + client.unverify_contact(peer_uid) + print("[+] Verification removed.") + + elif choice == "21": + # Show own fingerprint + fp = client.get_my_fingerprint() + if fp: + print(f"\n--- Your Fingerprint ---\n{fp}") + else: + print("[!] Not logged in or identity key not available.") + + elif choice == "22": + # Change password + old_pw = getpass.getpass("Current password: ") + new_pw = getpass.getpass("New password: ") + confirm_pw = getpass.getpass("Confirm new password: ") + if new_pw != confirm_pw: + print("[!] Passwords do not match.") + elif not new_pw: + print("[!] Password cannot be empty.") + else: + ok, msg = client.change_password(old_pw, new_pw) + if ok: + print(f"[+] {msg}") + else: + print(f"[!] {msg}") + + elif choice == "23": + new_un = await prompt("New username: ") + new_un = new_un.strip() if new_un else "" + if not new_un: + print("[!] Username cannot be empty.") + else: + ok, msg = await client.change_username(new_un) + if ok: + print(f"[+] {msg}") + else: + print(f"[!] {msg}") + + elif choice in ("q", "Q", "quit", "exit"): + print("[*] Bye.") + break + + +def _print_messages(messages, client, conv): + """Print messages to terminal.""" + print() + for i, m in enumerate(messages): + if m.get("deleted"): + print(f" #{i+1} [Message deleted]") + continue + reply_info = "" + if m.get("reply_to"): + for j, orig in enumerate(messages): + if orig["message_id"] == m["reply_to"]: + reply_info = f" (reply to #{j+1})" + break + else: + reply_info = " (reply to older message)" + image_info = "" + if m.get("image"): + img = m["image"] + image_info = f" [Image: {_sanitize_text(img.get('filename', '?'))} ({_human_size(img.get('size', 0))})]" + file_info = "" + if m.get("file"): + fi = m["file"] + file_info = f" [File: {_sanitize_text(fi.get('filename', '?'))} ({_human_size(fi.get('size', 0))})]" + read_info = "" + if m.get("sender") == client.username: + read_by = m.get("read_by", []) + delivered_to = m.get("delivered_to", []) + member_map = {} + for mem in conv.get("members", []): + uid = mem.get("user_id") or mem.get("id", "") + if uid: + member_map[uid] = _sanitize_text(mem.get("username") or mem.get("email") or "?") + my_uid = client.session.get("user_id", "") if client.session else "" + others_read = [r for r in read_by if r.get("user_id") != my_uid] + others_delivered = [d for d in delivered_to if d.get("user_id") != my_uid] + if others_read: + names = ", ".join(member_map.get(r["user_id"], r["user_id"][:8]) for r in others_read) + read_info = f" [\u2713\u2713 Read by {names}]" + elif others_delivered: + read_info = " [\u2713\u2713 Delivered]" + else: + read_info = " [\u2713 Sent]" + pin_info = "" + if m.get("pinned_at"): + pin_info = " \U0001f4cc" + reaction_info = "" + reactions = m.get("reactions", []) + if reactions: + grouped = {} + for r in reactions: + grouped.setdefault(r["reaction"], 0) + grouped[r["reaction"]] += 1 + _REMOJI = {"thumbsup": "\U0001f44d", "heart": "\u2764\ufe0f", "laugh": "\U0001f602", + "surprised": "\U0001f62e", "sad": "\U0001f622", "thumbsdown": "\U0001f44e"} + parts = [f"{_REMOJI.get(k, k)}{v}" for k, v in grouped.items()] + reaction_info = " [" + " ".join(parts) + "]" + fwd_info = "" + if m.get("forwarded_from"): + fwd_sender = _sanitize_text(m["forwarded_from"].get("sender", "?")) + fwd_info = f" (fwd from {fwd_sender})" + text = _sanitize_text(m.get("text", "")) + sender = _sanitize_text(m.get("sender", "?")) + print(f" #{i+1} {sender}: {text}{image_info}{file_info}{reply_info}{read_info}{pin_info}{reaction_info}{fwd_info}") + + +async def _delete_message_action(client, messages): + del_choice = await prompt("Delete message #: ") + try: + del_idx = int(del_choice) - 1 + if not (0 <= del_idx < len(messages)): + print("[!] Invalid message number.") + return + except ValueError: + print("[!] Invalid number.") + return + ok, msg = await client.delete_message(messages[del_idx]["message_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + +async def _download_file_action(client, messages): + dl_choice = await prompt("Download from message #: ") + try: + dl_idx = int(dl_choice) - 1 + if not (0 <= dl_idx < len(messages)): + print("[!] Invalid message number.") + return + except ValueError: + print("[!] Invalid number.") + return + m = messages[dl_idx] + file_info = m.get("file") or m.get("image") + if not file_info: + print("[!] No file/image in this message.") + return + filename = _safe_filename(file_info.get("filename", "download")) + save_path = await prompt(f"Save as [{filename}]: ") + if not save_path: + save_path = filename + data = await client.download_file(file_info["file_id"], file_info) + if data: + with open(save_path, "wb") as f: + f.write(data) + print(f"[+] Saved to {save_path} ({_human_size(len(data))})") + else: + print("[!] Download failed.") + + +async def _invitations_menu(client): + invitations = await client.list_invitations() + if not invitations: + print("[*] No pending invitations.") + return + print("\nPending invitations:") + for i, inv in enumerate(invitations): + inv_name = _sanitize_text(inv.get("conversation_name") or inv.get("conversation_id", "")[:8]) + invited_by = _sanitize_text(inv.get("invited_by_username") or inv.get("invited_by", "")[:8]) + print(f" {i+1}) {inv_name} (invited by {invited_by})") + choice = await prompt("Select invitation (or empty to go back): ") + if not choice: + return + try: + idx = int(choice) - 1 + if not (0 <= idx < len(invitations)): + print("[!] Invalid selection.") + return + except ValueError: + print("[!] Invalid selection.") + return + inv = invitations[idx] + action = await prompt("(a)ccept or (d)ecline? ") + if action.lower().startswith("a"): + ok, msg = await client.accept_invitation(inv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + elif action.lower().startswith("d"): + ok, msg = await client.decline_invitation(inv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + +def _print_profile(profile): + print(f"\n Username: {_sanitize_text(profile.get('username', '?'))}") + print(f" Email: {_sanitize_text(profile.get('email', '?'))}") + phone = profile.get("phone") + if phone: + print(f" Phone: {_sanitize_text(phone)}") + location = profile.get("location") + if location: + print(f" Location: {_sanitize_text(location)}") + has_avatar = profile.get("avatar_file") + print(f" Avatar: {'Yes' if has_avatar else 'No'}") + + +async def _my_profile_menu(client): + profile = await client.get_profile() + if not profile: + print("[!] Could not load profile.") + return + print("\n--- My Profile ---") + _print_profile(profile) + print(f" Phone visible: {profile.get('phone_visible', False)}") + print(f" Email visible: {profile.get('email_visible', False)}") + print(f" Location visible: {profile.get('location_visible', False)}") + + action = await prompt("\n(e)dit, (a)vatar upload, or empty to go back: ") + if not action: + return + if action.lower().startswith("e"): + print("[*] Leave fields empty to keep current value.") + phone = await prompt(f"Phone [{profile.get('phone', '')}]: ") + location = await prompt(f"Location [{profile.get('location', '')}]: ") + phone_vis = await prompt(f"Phone visible [{profile.get('phone_visible', False)}] (y/n): ") + email_vis = await prompt(f"Email visible [{profile.get('email_visible', False)}] (y/n): ") + loc_vis = await prompt(f"Location visible [{profile.get('location_visible', False)}] (y/n): ") + + fields = {} + if phone: + fields["phone"] = phone + if location: + fields["location"] = location + if phone_vis.lower() in ("y", "n"): + fields["phone_visible"] = phone_vis.lower() == "y" + if email_vis.lower() in ("y", "n"): + fields["email_visible"] = email_vis.lower() == "y" + if loc_vis.lower() in ("y", "n"): + fields["location_visible"] = loc_vis.lower() == "y" + if fields: + ok, msg = await client.update_profile(**fields) + print(f"[{'+'if ok else '!'}] {msg}") + else: + print("[*] No changes.") + elif action.lower().startswith("a"): + path = await prompt("Avatar image path: ") + if not path or not os.path.isfile(path): + print("[!] File not found.") + return + data = open(path, "rb").read() + ok, msg = await client.update_avatar(data) + print(f"[{'+'if ok else '!'}] {msg}") + + +async def _devices_menu(client): + resp = await client.send_and_recv("list_devices") + if resp.get("status") != "ok": + print(f"[!] {resp.get('data', {}).get('message', 'Failed')}") + return + devices = resp["data"].get("devices", []) + if not devices: + print("[*] No devices found.") + return + current_device_id = client.device_id + print("\nYour devices:") + for i, d in enumerate(devices): + name = _sanitize_text(d.get("device_name") or "Unnamed") + did = d.get("device_id", "?") + last_seen = _sanitize_text(d.get("last_seen_at", "?")) + current = " (this device)" if did == current_device_id else "" + print(f" {i+1}) {name} — {did[:8]}... — last seen: {last_seen}{current}") + action = await prompt("\n(r)emove a device, or empty to go back: ") + if not action or not action.lower().startswith("r"): + return + choice = await prompt("Remove device #: ") + try: + idx = int(choice) - 1 + if not (0 <= idx < len(devices)): + print("[!] Invalid selection.") + return + except ValueError: + print("[!] Invalid selection.") + return + d = devices[idx] + if d.get("device_id") == current_device_id: + print("[!] Cannot remove current device.") + return + resp = await client.send_and_recv("remove_device", device_id=d["device_id"]) + if resp.get("status") == "ok": + print("[+] Device removed.") + else: + print(f"[!] {resp.get('data', {}).get('message', 'Failed')}") + + +async def notification_printer(client: ChatClient): + """Print real-time notifications with sender name.""" + while True: + notif = await client._notification_queue.get() + notif_type = notif.get("type", "") + data = notif.get("data", {}) + if notif_type == "messages_read": + continue # Silent - read receipts shown when reading messages + if notif_type == "typing_start": + who = _sanitize_text(data.get("username") or data.get("user_id", "")[:8] or "Someone") + print(f"\n[*] {who} is typing...") + continue + if notif_type == "typing_stop": + continue + if notif_type == "session_reset": + from_uid = data.get("from_user_id", "")[:8] + client.handle_session_reset_notification( + data.get("from_user_id", ""), + data.get("from_device_id") or None, + ) + print(f"\n[*] Session with {from_uid}... was reset. New session will be created on next message.") + continue + if notif_type == "group_invitation": + inv_name = _sanitize_text(data.get("conversation_name", "?")) + invited_by = _sanitize_text(data.get("invited_by_username", "?")) + print(f"\n[*] New invitation to '{inv_name}' from {invited_by}. Use option 8 to accept/decline.") + continue + if notif_type == "device_added": + device_name = _sanitize_text(data.get("device_name", "Unknown device")) + device_id = _sanitize_text((data.get("device_id", "") or "")[:8]) + ip = _sanitize_text(data.get("ip", "unknown")) + print( + f"\n[!] New device added to your account: {device_name} ({device_id}) from {ip}.\n" + f" If this was not you, rotate keys immediately." + ) + continue + if notif_type in ("conversation_created", "member_removed", "conversation_renamed"): + print(f"\n[*] Conversation updated ({notif_type}).") + continue + if notif_type == "member_added": + print(f"\n[*] Conversation updated (member_added).") + conv_id = data.get("conversation_id", "") + new_user_id = data.get("user_id", "") + if conv_id and new_user_id: + asyncio.ensure_future( + client.redistribute_sender_key_to_member(conv_id, new_user_id) + ) + continue + if notif_type == "message_reacted": + username = _sanitize_text(data.get("username", data.get("user_id", "?")[:8])) + reaction = _sanitize_text(data.get("reaction", "?")) + action = data.get("action", "add") + print(f"\n[*] {username} {'added' if action == 'add' else 'removed'} reaction '{reaction}'") + continue + if notif_type in ("message_pinned", "message_unpinned"): + username = _sanitize_text(data.get("username", data.get("user_id", "?")[:8])) + act = "pinned" if notif_type == "message_pinned" else "unpinned" + print(f"\n[*] {username} {act} a message") + continue + if notif_type in ("user_online", "user_offline", "online_users"): + continue # Silent for CLI + payload = client.decrypt_notification(data) + if payload: + print(f"\n[*] New message from {_sanitize_text(payload['sender'])} in conversation {data.get('conversation_id', '?')[:8]}...") + # None = control message (sender key distribution), skip silently + + +async def main(): + setup_logging() + client = ChatClient() + await client.connect() + + client._listener_task = asyncio.create_task(client._background_listener()) + notif_task = asyncio.create_task(notification_printer(client)) + + print("=== Encrypted Chat Client ===") + print("1) Register") + print("2) Login") + print("3) Link new device (this device)") + print("4) Authorize new device (from this device)") + print("5) Rotate keys (revoke other devices)") + choice = await prompt("> ") + + if choice == "1": + username = await prompt("Username (display): ") + email = await prompt("Email: ") + password = await prompt_password("Password (for private key): ") + if not email or not password: + print("[!] Email and password required.") + await client.close() + return + ok, code_or_msg = await client.register(username, password, email=email) + if not ok: + print(f"[!] {code_or_msg}") + await client.close() + return + print(f"[*] Registration code: {code_or_msg}") + code = await prompt("Enter code: ") + ok2, msg2 = await client.confirm_registration(email, username, code) + print(f"[{'+'if ok2 else '!'}] {msg2}") + if ok2: + ok3, msg3 = await client.login(email, password) + print(f"[{'+'if ok3 else '!'}] {msg3}") + elif choice == "2": + email = await prompt("Email: ") + password = await prompt_password("Password (for private key): ") + ok, msg = await client.login(email, password) + print(f"[{'+'if ok else '!'}] {msg}") + elif choice == "3": + email = await prompt("Email: ") + password = await prompt_password("Password (for private key): ") + if not password: + print("[!] Password required.") + await client.close() + return + ok, code_or_msg = await client.pairing_start(email) + if not ok: + print(f"[!] {code_or_msg}") + await client.close() + return + code = code_or_msg + fingerprint = client.pairing_fingerprint() + print(f"[*] Pairing code: {code}") + print("[*] Pairing fingerprint:") + print(fingerprint) + print("[*] Approve this code on an already-logged-in device.") + print("[!] Never share this pairing code.") + ok2, msg2 = await client.pairing_wait(code, email, password) + if not ok2: + print(f"[!] {msg2}") + await client.close() + return + print(f"[+] {msg2}") + ok3, msg3 = await client.login(email, password) + print(f"[{'+'if ok3 else '!'}] {msg3}") + elif choice == "4": + email = await prompt("Email: ") + password = await prompt_password("Password (for private key): ") + ok, msg = await client.login(email, password) + print(f"[{'+'if ok else '!'}] {msg}") + if not ok: + await client.close() + return + code = await prompt("Pairing code: ") + fingerprint = await prompt("Fingerprint shown on the new device: ") + ok2, msg2 = await client.authorize_device(code, fingerprint) + print(f"[{'+'if ok2 else '!'}] {msg2}") + elif choice == "5": + email = await prompt("Email: ") + password = await prompt_password("Password (for private key): ") + ok, msg = await client.login(email, password) + print(f"[{'+'if ok else '!'}] {msg}") + if not ok: + await client.close() + return + confirm = await prompt("This will revoke other devices. Type 'YES' to continue: ") + if confirm != "YES": + print("[*] Cancelled.") + await client.close() + return + ok2, msg2 = await client.rotate_keys(client.username, password) + print(f"[{'+'if ok2 else '!'}] {msg2}") + else: + print("[!] Invalid choice.") + await client.close() + return + + if client.session: + await interactive_menu(client) + + notif_task.cancel() + await client.close() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\n[*] Bye.") diff --git a/crypto_utils.py b/crypto_utils.py new file mode 100644 index 0000000..02c82ec --- /dev/null +++ b/crypto_utils.py @@ -0,0 +1,988 @@ +"""Cryptographic utilities: Ed25519, X25519, AES-256-GCM, Double Ratchet, Sender Keys. + +RSA functions retained for login challenge-response only. +""" + +import hashlib +import hmac +import json +import os +import struct +import uuid +from dataclasses import dataclass, field + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + +# --------------------------------------------------------------------------- +# Password-based key encryption (M3: PBKDF2 600k iterations + AES-256-GCM) +# --------------------------------------------------------------------------- + +PBKDF2_ITERATIONS = 600_000 +_ECP1_MAGIC = b"ECP1" # Encrypted Chat PBKDF v1 format marker + + +def _encrypt_private_key(raw_bytes: bytes, password: bytes) -> bytes: + """Encrypt raw key bytes with PBKDF2-HMAC-SHA256 (600k iterations) + AES-256-GCM. + + Output format: MAGIC(4) + salt(16) + nonce(12) + ciphertext_with_tag(N+16) + """ + salt = os.urandom(16) + kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, + salt=salt, iterations=PBKDF2_ITERATIONS) + derived = kdf.derive(password) + nonce = os.urandom(12) + aesgcm = AESGCM(derived) + ct = aesgcm.encrypt(nonce, raw_bytes, _ECP1_MAGIC) # AAD = magic bytes + return _ECP1_MAGIC + salt + nonce + ct + + +def _decrypt_private_key(data: bytes, password: bytes) -> bytes: + """Decrypt key bytes encrypted with _encrypt_private_key.""" + if not data.startswith(_ECP1_MAGIC): + raise ValueError("Not ECP1 format") + salt = data[4:20] + nonce = data[20:32] + ct = data[32:] + kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, + salt=salt, iterations=PBKDF2_ITERATIONS) + derived = kdf.derive(password) + aesgcm = AESGCM(derived) + return aesgcm.decrypt(nonce, ct, _ECP1_MAGIC) + + +# --------------------------------------------------------------------------- +# RSA (login challenge-response ONLY) +# --------------------------------------------------------------------------- + +def generate_rsa_keypair(key_size: int = 4096) -> tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: + private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) + return private_key, private_key.public_key() + + +def serialize_private_key(key: rsa.RSAPrivateKey, password: bytes | None = None) -> bytes: + if password: + raw = key.private_bytes(serialization.Encoding.DER, serialization.PrivateFormat.PKCS8, + serialization.NoEncryption()) + return _encrypt_private_key(raw, password) + return key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, + serialization.NoEncryption()) + + +def serialize_public_key(key: rsa.RSAPublicKey) -> bytes: + return key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo) + + +def load_private_key(data: bytes, password: bytes | None = None) -> rsa.RSAPrivateKey: + if data.startswith(_ECP1_MAGIC): + raw = _decrypt_private_key(data, password) + return serialization.load_der_private_key(raw, password=None) + # Legacy PEM format (old BestAvailableEncryption or unencrypted) + return serialization.load_pem_private_key(data, password=password) + + +def load_public_key(pem: bytes) -> rsa.RSAPublicKey: + return serialization.load_pem_public_key(pem) + + +def compute_pairing_fingerprint(public_key_data: bytes | str) -> str: + """Format a temporary pairing public key as a human-verifiable fingerprint.""" + if isinstance(public_key_data, str): + key_bytes = public_key_data.encode("utf-8") + else: + key_bytes = public_key_data + canonical = key_bytes.replace(b"\r\n", b"\n").strip() if b"-----BEGIN" in key_bytes else key_bytes + digest = hashlib.sha256(b"EncryptedChat_PairingKey_v1\x00" + canonical).digest() + return format_fingerprint(digest) + + +def normalize_pairing_fingerprint(value: str) -> str: + """Normalize user-entered pairing fingerprints for comparison.""" + return "".join(ch for ch in value if ch.isdigit()) + + +def encode_pairing_qr(code: str, fingerprint: str) -> bytes: + """Encode pairing code + fingerprint for QR transport. + + Format: magic(5='PAIR1') + code(8 ASCII digits) + fingerprint(30 ASCII digits) + """ + code_digits = "".join(ch for ch in code if ch.isdigit()) + fp_digits = normalize_pairing_fingerprint(fingerprint) + if len(code_digits) != 8: + raise ValueError("Pairing code must contain 8 digits") + if len(fp_digits) != 30: + raise ValueError("Pairing fingerprint must contain 30 digits") + return b"PAIR1" + code_digits.encode("ascii") + fp_digits.encode("ascii") + + +def decode_pairing_qr(data: bytes) -> tuple[str, str]: + """Decode pairing QR payload. Returns (code, formatted_fingerprint).""" + if len(data) != 43 or not data.startswith(b"PAIR1"): + raise ValueError("Invalid pairing QR payload") + code = data[5:13].decode("ascii") + fp_digits = data[13:43].decode("ascii") + if not code.isdigit() or not fp_digits.isdigit(): + raise ValueError("Invalid pairing QR payload") + groups = [fp_digits[i:i + 5] for i in range(0, 30, 5)] + return code, " ".join(groups[:3]) + "\n" + " ".join(groups[3:]) + + +def rsa_sign(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: + return private_key.sign( + data, + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), + hashes.SHA256(), + ) + + +def rsa_verify(public_key: rsa.RSAPublicKey, signature: bytes, data: bytes) -> bool: + try: + public_key.verify( + signature, data, + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.AUTO), + hashes.SHA256(), + ) + return True + except Exception: + return False + + +# --------------------------------------------------------------------------- +# AES-256-GCM (symmetric encryption — used by ratchet message keys & images) +# --------------------------------------------------------------------------- + +def aes_encrypt(plaintext: bytes, key: bytes | None = None) -> tuple[bytes, bytes, bytes, bytes]: + """Encrypt with AES-256-GCM. Returns (key, nonce, ciphertext, tag).""" + if key is None: + key = AESGCM.generate_key(bit_length=256) + nonce = os.urandom(12) + aesgcm = AESGCM(key) + ct_with_tag = aesgcm.encrypt(nonce, plaintext, None) + ciphertext = ct_with_tag[:-16] + tag = ct_with_tag[-16:] + return key, nonce, ciphertext, tag + + +def aes_decrypt(key: bytes, nonce: bytes, ciphertext: bytes, tag: bytes) -> bytes: + """Decrypt with AES-256-GCM.""" + aesgcm = AESGCM(key) + return aesgcm.decrypt(nonce, ciphertext + tag, None) + + +# --------------------------------------------------------------------------- +# Ed25519 Identity Keys +# --------------------------------------------------------------------------- + +def generate_identity_keypair() -> tuple[Ed25519PrivateKey, Ed25519PublicKey]: + priv = Ed25519PrivateKey.generate() + return priv, priv.public_key() + + +def serialize_ed25519_private(key: Ed25519PrivateKey, password: bytes | None = None) -> bytes: + if password: + raw = serialize_ed25519_private_raw(key) # 32 bytes + return _encrypt_private_key(raw, password) + return serialize_ed25519_private_raw(key) # 32 bytes, no password + + +def serialize_ed25519_private_raw(key: Ed25519PrivateKey) -> bytes: + """Serialize Ed25519 private key to 32 raw bytes (unencrypted).""" + return key.private_bytes(serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()) + + +def serialize_ed25519_public(key: Ed25519PublicKey) -> bytes: + """Serialize Ed25519 public key to 32 raw bytes.""" + return key.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + + +def load_ed25519_private(data: bytes, password: bytes | None = None) -> Ed25519PrivateKey: + if data.startswith(_ECP1_MAGIC): + raw = _decrypt_private_key(data, password) + return Ed25519PrivateKey.from_private_bytes(raw) + # Legacy formats: PEM (old BestAvailableEncryption) or 32-byte raw + if password: + return serialization.load_pem_private_key(data, password=password) + if len(data) == 32: + return Ed25519PrivateKey.from_private_bytes(data) + return serialization.load_pem_private_key(data, password=None) + + +def load_ed25519_public(data: bytes) -> Ed25519PublicKey: + if len(data) == 32: + return Ed25519PublicKey.from_public_bytes(data) + return serialization.load_pem_public_key(data) + + +def ed25519_sign(private_key: Ed25519PrivateKey, data: bytes) -> bytes: + """Sign data with Ed25519. Returns 64-byte signature.""" + return private_key.sign(data) + + +def ed25519_verify(public_key: Ed25519PublicKey, signature: bytes, data: bytes) -> bool: + """Verify Ed25519 signature.""" + try: + public_key.verify(signature, data) + return True + except Exception: + return False + + +# --------------------------------------------------------------------------- +# X25519 Key Exchange +# --------------------------------------------------------------------------- + +def generate_x25519_keypair() -> tuple[X25519PrivateKey, X25519PublicKey]: + priv = X25519PrivateKey.generate() + return priv, priv.public_key() + + +def serialize_x25519_private(key: X25519PrivateKey) -> bytes: + """Serialize X25519 private key to 32 raw bytes.""" + return key.private_bytes(serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()) + + +def serialize_x25519_public(key: X25519PublicKey) -> bytes: + """Serialize X25519 public key to 32 raw bytes.""" + return key.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + + +def load_x25519_private(data: bytes) -> X25519PrivateKey: + return X25519PrivateKey.from_private_bytes(data) + + +def load_x25519_public(data: bytes) -> X25519PublicKey: + return X25519PublicKey.from_public_bytes(data) + + +def x25519_dh(private_key: X25519PrivateKey, public_key: X25519PublicKey) -> bytes: + """Perform X25519 Diffie-Hellman. Returns 32-byte shared secret.""" + return private_key.exchange(public_key) + + +def derive_pairing_shared_key(shared_secret: bytes, public_key_a: bytes, public_key_b: bytes) -> bytes: + """Derive a symmetric bootstrap key for device pairing. + + The key derivation is direction-agnostic: both peers sort the two public + keys lexicographically before binding them into HKDF salt. + """ + pub1, pub2 = sorted((public_key_a, public_key_b)) + salt = hashlib.sha256(b"EncryptedChat_PairingSalt_v1\x00" + pub1 + pub2).digest() + return hkdf_derive(shared_secret, salt=salt, info=b"EncryptedChat_PairingBootstrap", length=32) + + +# --------------------------------------------------------------------------- +# Ed25519 <-> X25519 conversion (for Identity Key dual use) +# --------------------------------------------------------------------------- + +def ed25519_private_to_x25519(ed_private: Ed25519PrivateKey) -> X25519PrivateKey: + """Derive X25519 private key from Ed25519 private key via RFC 7748 clamping.""" + raw = ed_private.private_bytes( + serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() + ) + # SHA-512 hash of the seed, take first 32 bytes, clamp per RFC 7748 + h = hashlib.sha512(raw).digest()[:32] + clamped = bytearray(h) + clamped[0] &= 248 + clamped[31] &= 127 + clamped[31] |= 64 + return X25519PrivateKey.from_private_bytes(bytes(clamped)) + + +def ed25519_public_to_x25519(ed_public: Ed25519PublicKey) -> X25519PublicKey: + """Derive X25519 public key from Ed25519 public key. + + Uses the cryptography library's internal conversion. For production use, + we compute the X25519 public key from the converted private key when possible. + For remote keys (where we don't have the private key), we use a pure-Python + implementation of the Ed25519->X25519 point conversion. + """ + # Montgomery u = (1 + y) / (1 - y) mod p, where p = 2^255 - 19 + raw = ed_public.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + y = int.from_bytes(raw, "little") + # Clear the sign bit + y &= (1 << 255) - 1 + p = (1 << 255) - 19 + # u = (1 + y) * inverse(1 - y) mod p + one_plus_y = (1 + y) % p + one_minus_y = (1 - y) % p + inv = pow(one_minus_y, p - 2, p) + u = (one_plus_y * inv) % p + x25519_bytes = u.to_bytes(32, "little") + return X25519PublicKey.from_public_bytes(x25519_bytes) + + +# --------------------------------------------------------------------------- +# HKDF +# --------------------------------------------------------------------------- + +_HKDF_INFO_SELF = b"EncryptedChat_SelfKey" +_HKDF_INFO_RK = b"EncryptedChat_RootKey" + + +def derive_self_encryption_key(identity_private: Ed25519PrivateKey) -> bytes: + """Derive a static AES-256 key from identity key for encrypting own sent messages. + + This is NOT a ratchet — it's a static key. Safe because only the owner + has the identity private key, and self-copies don't need forward secrecy. + """ + raw = identity_private.private_bytes( + serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() + ) + return hkdf_derive(raw, salt=b"self_encryption", info=_HKDF_INFO_SELF, length=32) + + +_HKDF_INFO_LOCAL = b"EncryptedChat_LocalStorage" + + +def derive_local_storage_key(identity_private: Ed25519PrivateKey) -> bytes: + """Derive AES-256 key for encrypting local session/sender key files.""" + raw = identity_private.private_bytes( + serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() + ) + return hkdf_derive(raw, salt=b"local_storage", info=_HKDF_INFO_LOCAL, length=32) + + +_HKDF_INFO_CK_MSG = b"\x01" # chain key -> message key +_HKDF_INFO_CK_NEXT = b"\x02" # chain key -> next chain key + + +def hkdf_derive(input_key: bytes, salt: bytes, info: bytes, length: int = 32) -> bytes: + return HKDF(algorithm=hashes.SHA256(), length=length, salt=salt, info=info).derive(input_key) + + +def kdf_rk(root_key: bytes, dh_output: bytes) -> tuple[bytes, bytes]: + """Root key KDF. Returns (new_root_key, chain_key). + + Uses HKDF with the root key as salt and DH output as input key material. + Derives 64 bytes: first 32 = new root key, last 32 = chain key. + """ + derived = hkdf_derive(dh_output, salt=root_key, info=_HKDF_INFO_RK, length=64) + return derived[:32], derived[32:] + + +def kdf_ck(chain_key: bytes) -> tuple[bytes, bytes]: + """Chain key KDF. Returns (new_chain_key, message_key). + + Uses HMAC-SHA256: + message_key = HMAC(chain_key, 0x01) + new_chain_key = HMAC(chain_key, 0x02) + """ + message_key = hmac.new(chain_key, _HKDF_INFO_CK_MSG, hashlib.sha256).digest() + new_chain_key = hmac.new(chain_key, _HKDF_INFO_CK_NEXT, hashlib.sha256).digest() + return new_chain_key, message_key + + +# --------------------------------------------------------------------------- +# X3DH +# --------------------------------------------------------------------------- + +_X3DH_INFO = b"EncryptedChat_X3DH" + + +def generate_signed_prekey(identity_private: Ed25519PrivateKey) -> dict: + """Generate a signed pre-key (SPK). + + Returns {private: X25519PrivateKey, public: X25519PublicKey, signature: bytes, id: str}. + """ + spk_priv, spk_pub = generate_x25519_keypair() + spk_pub_bytes = serialize_x25519_public(spk_pub) + signature = ed25519_sign(identity_private, spk_pub_bytes) + return { + "private": spk_priv, + "public": spk_pub, + "signature": signature, + "id": str(uuid.uuid4()), + } + + +def generate_one_time_prekeys(count: int = 50) -> list[dict]: + """Generate a batch of one-time pre-keys. + + Returns [{private: X25519PrivateKey, public: X25519PublicKey, id: str}, ...]. + """ + result = [] + for _ in range(count): + priv, pub = generate_x25519_keypair() + result.append({"private": priv, "public": pub, "id": str(uuid.uuid4())}) + return result + + +def x3dh_initiate( + ik_private_ed: Ed25519PrivateKey, + ik_public_remote_ed: Ed25519PublicKey, + spk_remote: X25519PublicKey, + spk_signature: bytes, + opk_remote: X25519PublicKey | None = None, +) -> tuple[bytes, X25519PrivateKey, X25519PublicKey]: + """Initiator side of X3DH. + + Args: + ik_private_ed: Our Ed25519 identity private key + ik_public_remote_ed: Remote Ed25519 identity public key + spk_remote: Remote signed pre-key (X25519 public) + spk_signature: Ed25519 signature of spk_remote by ik_public_remote_ed + opk_remote: Optional one-time pre-key (X25519 public) + + Returns: + (shared_secret, ephemeral_private, ephemeral_public) + """ + # Verify SPK signature + spk_remote_bytes = serialize_x25519_public(spk_remote) + if not ed25519_verify(ik_public_remote_ed, spk_signature, spk_remote_bytes): + raise ValueError("Invalid SPK signature") + + # Convert identity keys to X25519 + ik_x25519_private = ed25519_private_to_x25519(ik_private_ed) + ik_x25519_remote = ed25519_public_to_x25519(ik_public_remote_ed) + + # Generate ephemeral keypair + ek_priv, ek_pub = generate_x25519_keypair() + + # DH computations + dh1 = x25519_dh(ik_x25519_private, spk_remote) # IK_A, SPK_B + dh2 = x25519_dh(ek_priv, ik_x25519_remote) # EK_A, IK_B + dh3 = x25519_dh(ek_priv, spk_remote) # EK_A, SPK_B + + dh_concat = dh1 + dh2 + dh3 + if opk_remote is not None: + dh4 = x25519_dh(ek_priv, opk_remote) # EK_A, OPK_B + dh_concat += dh4 + + # Derive shared secret + shared_secret = hkdf_derive(dh_concat, salt=b"\x00" * 32, info=_X3DH_INFO, length=32) + return shared_secret, ek_priv, ek_pub + + +def x3dh_respond( + ik_private_ed: Ed25519PrivateKey, + spk_private: X25519PrivateKey, + ik_remote_ed: Ed25519PublicKey, + ek_remote: X25519PublicKey, + opk_private: X25519PrivateKey | None = None, +) -> bytes: + """Responder side of X3DH. + + Args: + ik_private_ed: Our Ed25519 identity private key + spk_private: Our signed pre-key private (X25519) + ik_remote_ed: Remote Ed25519 identity public key + ek_remote: Remote ephemeral key (X25519 public) + opk_private: Our one-time pre-key private (X25519), if used + + Returns: + shared_secret (32 bytes) + """ + ik_x25519_private = ed25519_private_to_x25519(ik_private_ed) + ik_x25519_remote = ed25519_public_to_x25519(ik_remote_ed) + + dh1 = x25519_dh(spk_private, ik_x25519_remote) # SPK_B, IK_A + dh2 = x25519_dh(ik_x25519_private, ek_remote) # IK_B, EK_A + dh3 = x25519_dh(spk_private, ek_remote) # SPK_B, EK_A + + dh_concat = dh1 + dh2 + dh3 + if opk_private is not None: + dh4 = x25519_dh(opk_private, ek_remote) # OPK_B, EK_A + dh_concat += dh4 + + shared_secret = hkdf_derive(dh_concat, salt=b"\x00" * 32, info=_X3DH_INFO, length=32) + return shared_secret + + +# --------------------------------------------------------------------------- +# Double Ratchet +# --------------------------------------------------------------------------- + +MAX_SKIP = 256 # max messages to skip in a single chain (out-of-order tolerance) + + +@dataclass +class RatchetHeader: + """Header sent with each ratchet message.""" + dh_pub: bytes # sender's current ratchet public key (32 bytes) + n: int # message number in current sending chain + pn: int # number of messages in previous sending chain + + def serialize(self) -> bytes: + return json.dumps({ + "dh_pub": serialize_x25519_public(load_x25519_public(self.dh_pub)).hex() + if isinstance(self.dh_pub, bytes) else serialize_x25519_public(self.dh_pub).hex(), + "n": self.n, + "pn": self.pn, + }).encode() + + def to_dict(self) -> dict: + pub_hex = self.dh_pub.hex() if isinstance(self.dh_pub, bytes) else \ + serialize_x25519_public(self.dh_pub).hex() + return {"dh_pub": pub_hex, "n": self.n, "pn": self.pn} + + @classmethod + def from_dict(cls, d: dict) -> "RatchetHeader": + return cls(dh_pub=bytes.fromhex(d["dh_pub"]), n=d["n"], pn=d["pn"]) + + +class DoubleRatchet: + """Signal Double Ratchet implementation.""" + + def __init__(self): + self.dh_pair: tuple[X25519PrivateKey, X25519PublicKey] | None = None + self.dh_remote: X25519PublicKey | None = None + self.root_key: bytes = b"" + self.send_chain_key: bytes | None = None + self.recv_chain_key: bytes | None = None + self.send_n: int = 0 + self.recv_n: int = 0 + self.prev_send_n: int = 0 + # (dh_pub_hex, n) -> message_key for out-of-order messages + self.skipped: dict[tuple[str, int], bytes] = {} + + @classmethod + def init_alice(cls, shared_secret: bytes, bob_spk_pub: X25519PublicKey) -> "DoubleRatchet": + """Initialize as initiator (Alice) after X3DH. + + Alice performs the first DH ratchet step immediately. + """ + ratchet = cls() + ratchet.dh_pair = generate_x25519_keypair() + ratchet.dh_remote = bob_spk_pub + + # Perform DH ratchet to derive send chain + dh_output = x25519_dh(ratchet.dh_pair[0], ratchet.dh_remote) + ratchet.root_key, ratchet.send_chain_key = kdf_rk(shared_secret, dh_output) + ratchet.recv_chain_key = None + ratchet.send_n = 0 + ratchet.recv_n = 0 + ratchet.prev_send_n = 0 + return ratchet + + @classmethod + def init_bob(cls, shared_secret: bytes, spk_pair: tuple[X25519PrivateKey, X25519PublicKey]) -> "DoubleRatchet": + """Initialize as responder (Bob) after X3DH. + + Bob uses his SPK as the initial ratchet key pair. + """ + ratchet = cls() + ratchet.dh_pair = spk_pair + ratchet.root_key = shared_secret + ratchet.send_chain_key = None + ratchet.recv_chain_key = None + ratchet.send_n = 0 + ratchet.recv_n = 0 + ratchet.prev_send_n = 0 + return ratchet + + def encrypt(self, plaintext: bytes) -> dict: + """Encrypt a message. + + Returns {header: {dh_pub, n, pn}, ciphertext: bytes, nonce: bytes}. + """ + if self.send_chain_key is None: + raise RuntimeError("Send chain not initialized") + + self.send_chain_key, message_key = kdf_ck(self.send_chain_key) + + header = RatchetHeader( + dh_pub=serialize_x25519_public(self.dh_pair[1]), + n=self.send_n, + pn=self.prev_send_n, + ) + + # Encrypt with AES-256-GCM using the message key + nonce = os.urandom(12) + aesgcm = AESGCM(message_key) + # Include header as AAD to bind ciphertext to header + aad = header.serialize() + ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad) + + self.send_n += 1 + + return { + "header": header.to_dict(), + "ciphertext": ct_with_tag, # includes 16-byte tag + "nonce": nonce, + } + + def decrypt(self, header_dict: dict, ciphertext: bytes, nonce: bytes) -> bytes: + """Decrypt a message. Handles DH ratchet step if new dh_pub. + + State is snapshotted before modification and restored on failure (M9 fix). + """ + header = RatchetHeader.from_dict(header_dict) + remote_dh_pub_bytes = header.dh_pub + + # Check if this is from a skipped message (no state modification needed) + skip_key = (remote_dh_pub_bytes.hex(), header.n) + if skip_key in self.skipped: + mk = self.skipped.pop(skip_key) + aad = header.serialize() + aesgcm = AESGCM(mk) + try: + return aesgcm.decrypt(nonce, ciphertext, aad) + except Exception: + self.skipped[skip_key] = mk # restore skipped key + raise + + # Snapshot state before modifications + snap = self._snapshot() + + try: + remote_dh_pub = load_x25519_public(remote_dh_pub_bytes) + current_remote_bytes = serialize_x25519_public(self.dh_remote) if self.dh_remote else None + + if current_remote_bytes is None or remote_dh_pub_bytes != current_remote_bytes: + # New DH ratchet step + self._skip_messages(header.pn) + self._dh_ratchet(remote_dh_pub) + + self._skip_messages(header.n) + + # Derive message key from receive chain + self.recv_chain_key, mk = kdf_ck(self.recv_chain_key) + self.recv_n += 1 + + aad = header.serialize() + aesgcm = AESGCM(mk) + return aesgcm.decrypt(nonce, ciphertext, aad) + except Exception: + self._restore(snap) + raise + + def _snapshot(self) -> dict: + """Capture mutable state for rollback on decrypt failure.""" + return { + "dh_pair": self.dh_pair, + "dh_remote": self.dh_remote, + "root_key": self.root_key, + "send_chain_key": self.send_chain_key, + "recv_chain_key": self.recv_chain_key, + "send_n": self.send_n, + "recv_n": self.recv_n, + "prev_send_n": self.prev_send_n, + "skipped": dict(self.skipped), + } + + def _restore(self, snap: dict): + """Restore state from snapshot.""" + self.dh_pair = snap["dh_pair"] + self.dh_remote = snap["dh_remote"] + self.root_key = snap["root_key"] + self.send_chain_key = snap["send_chain_key"] + self.recv_chain_key = snap["recv_chain_key"] + self.send_n = snap["send_n"] + self.recv_n = snap["recv_n"] + self.prev_send_n = snap["prev_send_n"] + self.skipped = snap["skipped"] + + def _skip_messages(self, until: int): + """Skip ahead in the receive chain, storing message keys for out-of-order delivery.""" + if self.recv_chain_key is None: + return + if until - self.recv_n > MAX_SKIP: + raise RuntimeError(f"Too many skipped messages ({until - self.recv_n} > {MAX_SKIP})") + while self.recv_n < until: + self.recv_chain_key, mk = kdf_ck(self.recv_chain_key) + remote_hex = serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else "" + self.skipped[(remote_hex, self.recv_n)] = mk + self.recv_n += 1 + + def _dh_ratchet(self, remote_dh_pub: X25519PublicKey): + """Perform a DH ratchet step: update receive chain, generate new DH pair, update send chain.""" + self.prev_send_n = self.send_n + self.send_n = 0 + self.recv_n = 0 + self.dh_remote = remote_dh_pub + + # Derive new receive chain key + dh_output = x25519_dh(self.dh_pair[0], self.dh_remote) + self.root_key, self.recv_chain_key = kdf_rk(self.root_key, dh_output) + + # Generate new DH pair and derive new send chain key + self.dh_pair = generate_x25519_keypair() + dh_output = x25519_dh(self.dh_pair[0], self.dh_remote) + self.root_key, self.send_chain_key = kdf_rk(self.root_key, dh_output) + + def export_state(self) -> bytes: + """Serialize full ratchet state for persistent storage.""" + state = { + "dh_priv": serialize_x25519_private(self.dh_pair[0]).hex() if self.dh_pair else None, + "dh_pub": serialize_x25519_public(self.dh_pair[1]).hex() if self.dh_pair else None, + "dh_remote": serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else None, + "root_key": self.root_key.hex(), + "send_ck": self.send_chain_key.hex() if self.send_chain_key else None, + "recv_ck": self.recv_chain_key.hex() if self.recv_chain_key else None, + "send_n": self.send_n, + "recv_n": self.recv_n, + "prev_send_n": self.prev_send_n, + "skipped": {f"{k[0]}:{k[1]}": v.hex() for k, v in self.skipped.items()}, + } + return json.dumps(state).encode() + + @classmethod + def import_state(cls, data: bytes) -> "DoubleRatchet": + """Deserialize ratchet state.""" + state = json.loads(data) + r = cls() + if state["dh_priv"] and state["dh_pub"]: + priv = load_x25519_private(bytes.fromhex(state["dh_priv"])) + pub = load_x25519_public(bytes.fromhex(state["dh_pub"])) + r.dh_pair = (priv, pub) + if state["dh_remote"]: + r.dh_remote = load_x25519_public(bytes.fromhex(state["dh_remote"])) + r.root_key = bytes.fromhex(state["root_key"]) + r.send_chain_key = bytes.fromhex(state["send_ck"]) if state["send_ck"] else None + r.recv_chain_key = bytes.fromhex(state["recv_ck"]) if state["recv_ck"] else None + r.send_n = state["send_n"] + r.recv_n = state["recv_n"] + r.prev_send_n = state["prev_send_n"] + r.skipped = {} + for k_str, v_hex in state.get("skipped", {}).items(): + parts = k_str.rsplit(":", 1) + dh_hex = parts[0] + n = int(parts[1]) + r.skipped[(dh_hex, n)] = bytes.fromhex(v_hex) + return r + + +# --------------------------------------------------------------------------- +# Sender Keys (group messaging) +# --------------------------------------------------------------------------- + +class SenderKeyState: + """Sender key chain for group messaging. + + Each sender in a group has their own sender key chain. + Other group members receive the initial sender_key via pairwise Double Ratchet. + """ + + def __init__(self, sender_key: bytes | None = None): + if sender_key is None: + sender_key = os.urandom(32) + self.sender_key = sender_key + self.chain_id = hashlib.sha256(sender_key).digest() + self.chain_key = hkdf_derive(sender_key, salt=b"\x00" * 32, info=b"SenderKeyChain", length=32) + self.n = 0 + # For receivers: track chain state to allow fast-forward + self._known_keys: dict[int, bytes] = {} + + def encrypt(self, plaintext: bytes) -> dict: + """Encrypt with current chain key. + + Returns {chain_id: hex, n: int, ciphertext: bytes, nonce: bytes}. + """ + self.chain_key, message_key = kdf_ck(self.chain_key) + nonce = os.urandom(12) + aesgcm = AESGCM(message_key) + # AAD includes chain_id and message number + aad = self.chain_id + struct.pack(">I", self.n) + ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad) + result = { + "chain_id": self.chain_id.hex(), + "n": self.n, + "ciphertext": ct_with_tag, + "nonce": nonce, + } + self.n += 1 + return result + + MAX_SENDER_KEY_SKIP = 256 + + def decrypt(self, chain_id_hex: str, n: int, ciphertext: bytes, nonce: bytes) -> bytes: + """Decrypt a group message. Fast-forwards the chain if needed. + + State is snapshotted before modification and restored on failure (M9 fix). + """ + chain_id = bytes.fromhex(chain_id_hex) + if chain_id != self.chain_id: + raise ValueError("Chain ID mismatch") + + if n - self.n > self.MAX_SENDER_KEY_SKIP: + raise ValueError(f"Sender key skip too large ({n - self.n} > {self.MAX_SENDER_KEY_SKIP})") + + # Snapshot before fast-forward + snap_chain_key = self.chain_key + snap_n = self.n + snap_known = dict(self._known_keys) + + try: + # Fast-forward the chain to reach message n + while self.n <= n: + self.chain_key, mk = kdf_ck(self.chain_key) + self._known_keys[self.n] = mk + self.n += 1 + + mk = self._known_keys.pop(n, None) + if mk is None: + raise ValueError(f"Message key for n={n} not available (already consumed)") + + aad = chain_id + struct.pack(">I", n) + aesgcm = AESGCM(mk) + return aesgcm.decrypt(nonce, ciphertext, aad) + except Exception: + self.chain_key = snap_chain_key + self.n = snap_n + self._known_keys = snap_known + raise + + def export_key(self) -> bytes: + """Export sender key for distribution to group members. + + Contains everything needed to initialize a receiving SenderKeyState. + """ + return json.dumps({ + "sender_key": self.sender_key.hex(), + }).encode() + + def export_state(self) -> bytes: + """Serialize full state for persistent storage.""" + return json.dumps({ + "sender_key": self.sender_key.hex(), + "chain_id": self.chain_id.hex(), + "chain_key": self.chain_key.hex(), + "n": self.n, + "known_keys": {str(k): v.hex() for k, v in self._known_keys.items()}, + }).encode() + + @classmethod + def import_state(cls, data: bytes) -> "SenderKeyState": + state = json.loads(data) + obj = cls.__new__(cls) + obj.sender_key = bytes.fromhex(state["sender_key"]) + obj.chain_id = bytes.fromhex(state["chain_id"]) + obj.chain_key = bytes.fromhex(state["chain_key"]) + obj.n = state["n"] + obj._known_keys = {int(k): bytes.fromhex(v) for k, v in state.get("known_keys", {}).items()} + return obj + + @classmethod + def from_key(cls, exported_key: bytes) -> "SenderKeyState": + """Initialize a receiving SenderKeyState from an exported key.""" + data = json.loads(exported_key) + return cls(sender_key=bytes.fromhex(data["sender_key"])) + + +# --------------------------------------------------------------------------- +# Contact Key Verification (Safety Numbers / Fingerprints / QR Codes) +# --------------------------------------------------------------------------- + +FINGERPRINT_VERSION = 0 + + +def compute_fingerprint(user_id: str, identity_key_bytes: bytes, iterations: int = 5200) -> bytes: + """Compute a 32-byte fingerprint for a user's identity key. + + Uses iterated SHA-512 (Signal's NumericFingerprint algorithm). + Seed: version(2B) + identity_key(32B) + user_id(UTF-8). + Each iteration: SHA-512(previous_hash + identity_key). + Output: first 32 bytes of final hash. + """ + version_bytes = FINGERPRINT_VERSION.to_bytes(2, "big") + data = version_bytes + identity_key_bytes + user_id.encode("utf-8") + for _ in range(iterations): + data = hashlib.sha512(data + identity_key_bytes).digest() + return data[:32] + + +def format_fingerprint(fp_bytes: bytes) -> str: + """Format 32-byte fingerprint as 6 groups of 5 zero-padded digits (30 digits). + + Each group: int(bytes[i*5:(i+1)*5], big-endian) % 100000. + Output: two lines of 3 groups each, space-separated. + """ + groups = [] + for i in range(6): + num = int.from_bytes(fp_bytes[i * 5:(i + 1) * 5], "big") % 100000 + groups.append(f"{num:05d}") + return " ".join(groups[:3]) + "\n" + " ".join(groups[3:]) + + +def compute_safety_number(my_uid: str, my_ik_bytes: bytes, + their_uid: str, their_ik_bytes: bytes) -> str: + """Compute a 60-digit safety number for a pair of users. + + Both users see the same number regardless of who computes it. + Lower user_id's fingerprint comes first (deterministic ordering). + Output: 12 groups of 5 digits, formatted as 3 lines of 4 groups. + """ + fp_mine = compute_fingerprint(my_uid, my_ik_bytes) + fp_theirs = compute_fingerprint(their_uid, their_ik_bytes) + if my_uid < their_uid: + combined = fp_mine + fp_theirs + else: + combined = fp_theirs + fp_mine + # 64 bytes -> 12 groups of 5 digits + groups = [] + for i in range(12): + num = int.from_bytes(combined[i * 5:(i + 1) * 5], "big") % 100000 + groups.append(f"{num:05d}") + lines = [ + " ".join(groups[0:4]), + " ".join(groups[4:8]), + " ".join(groups[8:12]), + ] + return "\n".join(lines) + + +def encode_verification_qr(user_id: str, identity_key_bytes: bytes) -> bytes: + """Encode user identity for QR code verification. + + Format: version(1B=0x01) + uid_len(1B) + uid(UTF-8) + identity_key(32B). + """ + uid_bytes = user_id.encode("utf-8") + return b"\x01" + len(uid_bytes).to_bytes(1, "big") + uid_bytes + identity_key_bytes + + +def decode_verification_qr(data: bytes) -> tuple[str, bytes]: + """Decode QR code verification payload. + + Returns (user_id, identity_key_bytes). + Raises ValueError on invalid format. + """ + if len(data) < 3: + raise ValueError("QR data too short") + if data[0] != 0x01: + raise ValueError(f"Unknown QR version: {data[0]}") + uid_len = data[1] + if len(data) < 2 + uid_len + 32: + raise ValueError("QR data truncated") + user_id = data[2:2 + uid_len].decode("utf-8") + identity_key = data[2 + uid_len:2 + uid_len + 32] + return user_id, identity_key + + +# --------------------------------------------------------------------------- +# Message Padding (metadata privacy — hide plaintext length) +# --------------------------------------------------------------------------- + +_PAD_MAGIC = b"\x01" +_PAD_BUCKETS = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536] + + +def pad_plaintext(plaintext: bytes) -> bytes: + """Pad plaintext to nearest bucket size to hide message length. + + Format: 0x01 + plaintext + random_padding + pad_length(4B big-endian) + Prefix 0x01 distinguishes padded messages from legacy unpadded (which start with '{'). + """ + content = _PAD_MAGIC + plaintext + # +4 for the length suffix + min_size = len(content) + 4 + target = next((b for b in _PAD_BUCKETS if b >= min_size), min_size) + pad_len = target - len(content) + return content + os.urandom(pad_len - 4) + struct.pack(">I", pad_len) + + +def unpad_plaintext(data: bytes) -> bytes: + """Remove padding. Returns raw plaintext for both padded and legacy unpadded messages.""" + if not data or data[0:1] != _PAD_MAGIC: + return data # legacy unpadded message (starts with '{' for JSON) + if len(data) < 5: + return data # too short to be validly padded + pad_len = struct.unpack(">I", data[-4:])[0] + if pad_len < 4 or pad_len > len(data) - 1: + return data # invalid padding metadata, treat as legacy + return data[1:len(data) - pad_len] diff --git a/db.py b/db.py new file mode 100644 index 0000000..5ab5196 --- /dev/null +++ b/db.py @@ -0,0 +1,1726 @@ +"""MySQL database layer for the encrypted chat server.""" + +import os +import uuid + +import logging +import mysql.connector +from mysql.connector import pooling +from dotenv import load_dotenv + +from crypto_utils import ( + generate_identity_keypair, + serialize_ed25519_public, + generate_signed_prekey, + serialize_x25519_public, + generate_one_time_prekeys, +) + +load_dotenv() + +# Sentinel device_id for self-encrypted copies and legacy (pre-multi-device) rows +SELF_DEVICE_ID = "00000000-0000-0000-0000-000000000000" + + +_logger = logging.getLogger(__name__) + +_pool = None + + +def _get_pool(): + """Get or create the connection pool (lazy init).""" + global _pool + if _pool is None: + pool_size = int(os.getenv("DB_POOL_SIZE", "10")) + pool_kwargs = dict( + pool_name="chat_pool", + pool_size=pool_size, + pool_reset_session=True, + host=os.getenv("MYSQL_HOST", "localhost"), + port=int(os.getenv("MYSQL_PORT", "3306")), + user=os.getenv("MYSQL_USER", "root"), + password=os.getenv("MYSQL_PASSWORD", ""), + database=os.getenv("MYSQL_DATABASE", "encrypted_chat"), + ) + # Optional MySQL TLS (M7): set MYSQL_SSL_CA (and optionally MYSQL_SSL_CERT/KEY) + ssl_ca = os.getenv("MYSQL_SSL_CA", "").strip() + ssl_cert = os.getenv("MYSQL_SSL_CERT", "").strip() + ssl_key = os.getenv("MYSQL_SSL_KEY", "").strip() + if ssl_ca: + pool_kwargs["ssl_ca"] = ssl_ca + if ssl_cert: + pool_kwargs["ssl_cert"] = ssl_cert + if ssl_key: + pool_kwargs["ssl_key"] = ssl_key + _logger.info("MySQL TLS enabled (CA: %s)", ssl_ca) + _pool = pooling.MySQLConnectionPool(**pool_kwargs) + _logger.info("DB connection pool created (size=%d)", pool_size) + return _pool + + +def get_connection(): + """Get a connection from the pool.""" + return _get_pool().get_connection() + + +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +# --- Devices --- + +def create_device(user_id: str, device_name: str | None = None) -> str: + """Create a new device for a user. Returns device_id.""" + conn = get_connection() + try: + cursor = conn.cursor() + device_id = generate_uuid() + cursor.execute( + "INSERT INTO devices (id, user_id, device_name) VALUES (%s, %s, %s)", + (device_id, user_id, device_name), + ) + conn.commit() + return device_id + finally: + conn.close() + + +def get_user_devices(user_id: str) -> list[dict]: + """Get all devices for a user.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, user_id, device_name, created_at, last_seen_at " + "FROM devices WHERE user_id = %s ORDER BY created_at", + (user_id,), + ) + return cursor.fetchall() + finally: + conn.close() + + +def get_device(device_id: str) -> dict | None: + """Get a single device by ID.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, user_id, device_name, created_at, last_seen_at " + "FROM devices WHERE id = %s", + (device_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def update_device_last_seen(device_id: str): + """Update last_seen_at timestamp for a device.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE devices SET last_seen_at = NOW() WHERE id = %s", + (device_id,), + ) + conn.commit() + finally: + conn.close() + + +def delete_device(device_id: str): + """Delete a device. CASCADE removes its prekeys.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM devices WHERE id = %s", (device_id,)) + # Also clean up prekeys explicitly for device_id column + cursor.execute("DELETE FROM signed_prekeys WHERE device_id = %s", (device_id,)) + cursor.execute("DELETE FROM one_time_prekeys WHERE device_id = %s", (device_id,)) + conn.commit() + finally: + conn.close() + + +# --- Users --- + +def create_user(username: str, email: str, rsa_public_key_pem: str, identity_key: bytes) -> str: + """Register a new user. Returns user ID.""" + conn = get_connection() + try: + cursor = conn.cursor() + user_id = generate_uuid() + cursor.execute( + "INSERT INTO users (id, username, email, rsa_public_key, identity_key) " + "VALUES (%s, %s, %s, %s, %s)", + (user_id, username, email, rsa_public_key_pem, identity_key), + ) + conn.commit() + return user_id + finally: + conn.close() + + +def get_user_by_email(email: str) -> dict | None: + """Get user by email.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE email = %s", + (email,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def get_user_by_id(user_id: str) -> dict | None: + """Get user by ID.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE id = %s", + (user_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def shares_conversation(user_id_a: str, user_id_b: str) -> bool: + """Check if two users share at least one conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM conversation_members cm1 " + "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " + "WHERE cm1.user_id = %s AND cm2.user_id = %s LIMIT 1", + (user_id_a, user_id_b), + ) + return cursor.fetchone() is not None + finally: + conn.close() + + +def get_user_contacts(user_id: str) -> list[str]: + """Get all user IDs that share at least one conversation with the given user.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT DISTINCT cm2.user_id " + "FROM conversation_members cm1 " + "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " + "WHERE cm1.user_id = %s AND cm2.user_id != %s", + (user_id, user_id), + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + +def update_user_rsa_key(user_id: str, rsa_public_key_pem: str): + """Update user's RSA public key (for login).""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("UPDATE users SET rsa_public_key = %s WHERE id = %s", (rsa_public_key_pem, user_id)) + conn.commit() + finally: + conn.close() + + +def update_username(user_id: str, new_username: str): + """Update user's display name.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("UPDATE users SET username = %s WHERE id = %s", (new_username, user_id)) + conn.commit() + finally: + conn.close() + + +# --- Pre-keys --- + +def store_signed_prekey(user_id: str, spk_id: str, public_key: bytes, signature: bytes, + device_id: str | None = None): + """Store (or replace) a signed pre-key for a user's device.""" + conn = get_connection() + try: + cursor = conn.cursor() + # Remove old SPKs for this user+device + if device_id: + cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id = %s", + (user_id, device_id)) + else: + cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id IS NULL", + (user_id,)) + cursor.execute( + "INSERT INTO signed_prekeys (id, user_id, device_id, public_key, signature) " + "VALUES (%s, %s, %s, %s, %s)", + (spk_id, user_id, device_id, public_key, signature), + ) + conn.commit() + finally: + conn.close() + + +def get_signed_prekey(user_id: str, device_id: str | None = None) -> dict | None: + """Get the current signed pre-key for a user (optionally per device).""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + if device_id: + cursor.execute( + "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys " + "WHERE user_id = %s AND device_id = %s " + "ORDER BY created_at DESC LIMIT 1", + (user_id, device_id), + ) + else: + cursor.execute( + "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys " + "WHERE user_id = %s ORDER BY created_at DESC LIMIT 1", + (user_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def store_one_time_prekeys(user_id: str, prekeys: list[dict], device_id: str | None = None): + """Store a batch of one-time pre-keys. Each dict has {id, public_key (bytes)}.""" + conn = get_connection() + try: + cursor = conn.cursor() + for pk in prekeys: + cursor.execute( + "INSERT INTO one_time_prekeys (id, user_id, device_id, public_key) " + "VALUES (%s, %s, %s, %s)", + (pk["id"], user_id, device_id, pk["public_key"]), + ) + conn.commit() + finally: + conn.close() + + +def consume_one_time_prekey(user_id: str, device_id: str | None = None) -> dict | None: + """Atomically consume one OTP: SELECT FOR UPDATE + DELETE. + Returns {id, public_key} or None.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + conn.start_transaction() + if device_id: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE", + (user_id, device_id), + ) + else: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s LIMIT 1 FOR UPDATE", + (user_id,), + ) + row = cursor.fetchone() + if row: + cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (row["id"],)) + conn.commit() + return row + except Exception: + conn.rollback() + raise + finally: + conn.close() + + +def count_one_time_prekeys(user_id: str, device_id: str | None = None) -> int: + """Count remaining OTPs for a user (optionally per device).""" + conn = get_connection() + try: + cursor = conn.cursor() + if device_id: + cursor.execute( + "SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s AND device_id = %s", + (user_id, device_id), + ) + else: + cursor.execute("SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s", (user_id,)) + return cursor.fetchone()[0] + finally: + conn.close() + + +def get_key_bundle(user_id: str) -> dict | None: + """Get complete key bundle for X3DH (single device — legacy compat). + + Returns {identity_key, signed_prekey_id, signed_prekey, spk_signature, + one_time_prekey_id, one_time_prekey} or None. + OTP is consumed atomically. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + # Get user identity key + cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,)) + user = cursor.fetchone() + if not user: + return None + + # Get signed prekey + cursor.execute( + "SELECT id, public_key, signature, device_id FROM signed_prekeys WHERE user_id = %s " + "ORDER BY created_at DESC LIMIT 1", + (user_id,), + ) + spk = cursor.fetchone() + if not spk: + return None + + # Consume one OTP (may be None) — use transaction for atomicity (H12 fix) + conn.start_transaction() + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys WHERE user_id = %s LIMIT 1 FOR UPDATE", + (user_id,), + ) + opk = cursor.fetchone() + if opk: + cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],)) + conn.commit() + + result = { + "identity_key": user["identity_key"], + "signed_prekey_id": spk["id"], + "signed_prekey": spk["public_key"], + "spk_signature": spk["signature"], + } + if opk: + result["one_time_prekey_id"] = opk["id"] + result["one_time_prekey"] = opk["public_key"] + return result + except Exception: + try: + conn.rollback() + except Exception: + pass + raise + finally: + conn.close() + + +def get_key_bundles_for_user(user_id: str) -> dict | None: + """Get key bundles for ALL devices of a user. Returns + {identity_key, device_bundles: [{device_id, signed_prekey_id, signed_prekey_pub, + spk_signature, opk_id, opk_pub}]} or None. + Consumes one OPK per device atomically. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + # Get user identity key + cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,)) + user = cursor.fetchone() + if not user: + return None + + # Get all signed prekeys (one per device, most recent) + cursor.execute( + "SELECT id, public_key, signature, device_id FROM signed_prekeys " + "WHERE user_id = %s ORDER BY created_at DESC", + (user_id,), + ) + all_spks = cursor.fetchall() + if not all_spks: + return None + + # De-duplicate: keep only the most recent SPK per device_id + seen_devices = set() + spks_by_device = [] + for spk in all_spks: + dev = spk.get("device_id") or "__legacy__" + if dev not in seen_devices: + seen_devices.add(dev) + spks_by_device.append(spk) + + device_bundles = [] + # Commit the implicit transaction from the read-only queries above + # so we can start an explicit transaction for atomic OPK consumption. + conn.commit() + conn.start_transaction() + for spk in spks_by_device: + dev_id = spk.get("device_id") + # Consume one OPK for this device + if dev_id: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE", + (user_id, dev_id), + ) + else: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s AND device_id IS NULL LIMIT 1 FOR UPDATE", + (user_id,), + ) + opk = cursor.fetchone() + if opk: + cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],)) + + bundle = { + "device_id": dev_id, + "signed_prekey_id": spk["id"], + "signed_prekey_pub": spk["public_key"], + "spk_signature": spk["signature"], + } + if opk: + bundle["opk_id"] = opk["id"] + bundle["opk_pub"] = opk["public_key"] + device_bundles.append(bundle) + conn.commit() + + return { + "identity_key": user["identity_key"], + "device_bundles": device_bundles, + } + except Exception: + try: + conn.rollback() + except Exception: + pass + raise + finally: + conn.close() + + +# --- Conversations --- + +def create_conversation(member_user_ids: list[str], joined_at=None, name=None, created_by=None) -> str: + conn = get_connection() + try: + cursor = conn.cursor() + conv_id = generate_uuid() + cursor.execute("INSERT INTO conversations (id, name, created_by) VALUES (%s, %s, %s)", + (conv_id, name, created_by)) + for uid in member_user_ids: + cursor.execute( + "INSERT INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)", + (conv_id, uid, joined_at), + ) + conn.commit() + return conv_id + finally: + conn.close() + + +def add_conversation_member(conversation_id: str, user_id: str, joined_at=None): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT IGNORE INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)", + (conversation_id, user_id, joined_at), + ) + conn.commit() + finally: + conn.close() + + +def remove_conversation_member(conversation_id: str, user_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + conn.commit() + finally: + conn.close() + + +def count_conversation_members(conversation_id: str) -> int: + """Count members in a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT COUNT(*) FROM conversation_members WHERE conversation_id = %s", + (conversation_id,), + ) + return cursor.fetchone()[0] + finally: + conn.close() + + +def get_conversation_file_ids(conversation_id: str) -> list[str]: + """Get all file IDs (images + files) uploaded to a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT file_id FROM image_uploads WHERE conversation_id = %s", + (conversation_id,), + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + +def delete_conversation(conversation_id: str): + """Delete a conversation entirely. CASCADE cleans up members, messages, etc.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM conversations WHERE id = %s", (conversation_id,)) + conn.commit() + finally: + conn.close() + + +def get_conversation_members(conversation_id: str) -> list[dict]: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT u.id, u.username, u.email, u.identity_key FROM conversation_members cm " + "JOIN users u ON cm.user_id = u.id " + "WHERE cm.conversation_id = %s", + (conversation_id,), + ) + return cursor.fetchall() + finally: + conn.close() + + +def find_direct_conversation(user_id_a: str, user_id_b: str) -> str | None: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT cm1.conversation_id FROM conversation_members cm1 " + "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " + "WHERE cm1.user_id = %s AND cm2.user_id = %s " + "AND (SELECT COUNT(*) FROM conversation_members cm3 " + " WHERE cm3.conversation_id = cm1.conversation_id) = 2 " + "LIMIT 1", + (user_id_a, user_id_b), + ) + row = cursor.fetchone() + return row[0] if row else None + finally: + conn.close() + + +def update_conversation_creator(conversation_id: str, new_creator_id: str): + """Transfer group creator role to another member.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE conversations SET created_by = %s WHERE id = %s", + (new_creator_id, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def get_conversation(conversation_id: str) -> dict | None: + """Get conversation by ID.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, created_at, name, created_by, avatar_file FROM conversations WHERE id = %s", + (conversation_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def update_conversation_avatar(conversation_id: str, avatar_file: str): + """Set avatar file for a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE conversations SET avatar_file = %s WHERE id = %s", + (avatar_file, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def update_conversation_name(conversation_id: str, name: str): + """Update the name of a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE conversations SET name = %s WHERE id = %s", + (name, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def is_conversation_member(conversation_id: str, user_id: str) -> bool: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + return cursor.fetchone() is not None + finally: + conn.close() + + +def list_user_conversations(user_id: str) -> list[dict]: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT c.id, c.created_at, c.name, c.created_by, c.avatar_file FROM conversations c " + "JOIN conversation_members cm ON c.id = cm.conversation_id " + "WHERE cm.user_id = %s ORDER BY c.created_at DESC", + (user_id,), + ) + convs = cursor.fetchall() + if not convs: + return convs + # Batch-fetch all members for all conversations in one query (N+1 fix) + conv_ids = [c["id"] for c in convs] + placeholders = ",".join(["%s"] * len(conv_ids)) + cursor.execute( + f"SELECT cm.conversation_id, u.id AS user_id, u.username, u.email " + f"FROM conversation_members cm JOIN users u ON cm.user_id = u.id " + f"WHERE cm.conversation_id IN ({placeholders})", + conv_ids, + ) + members_by_conv: dict[str, list[dict]] = {} + for row in cursor.fetchall(): + cid = row.pop("conversation_id") + members_by_conv.setdefault(cid, []).append(row) + for conv in convs: + conv["members"] = members_by_conv.get(conv["id"], []) + return convs + finally: + conn.close() + + +# --- Group Invitations --- + +def create_invitation(conversation_id: str, user_id: str, invited_by: str): + """Create a pending group invitation.""" + conn = get_connection() + try: + cursor = conn.cursor() + inv_id = generate_uuid() + cursor.execute( + "INSERT IGNORE INTO group_invitations (id, conversation_id, user_id, invited_by) " + "VALUES (%s, %s, %s, %s)", + (inv_id, conversation_id, user_id, invited_by), + ) + conn.commit() + finally: + conn.close() + + +def get_pending_invitations(user_id: str) -> list[dict]: + """Get all pending invitations for a user, joined with conversation and inviter info.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT gi.id, gi.conversation_id, gi.invited_by, gi.created_at, " + "c.name AS conversation_name, u.username AS invited_by_username " + "FROM group_invitations gi " + "JOIN conversations c ON gi.conversation_id = c.id " + "JOIN users u ON gi.invited_by = u.id " + "WHERE gi.user_id = %s " + "ORDER BY gi.created_at DESC", + (user_id,), + ) + return cursor.fetchall() + finally: + conn.close() + + +def delete_invitation(conversation_id: str, user_id: str): + """Delete a pending invitation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM group_invitations WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + conn.commit() + finally: + conn.close() + + +def has_pending_invitation(conversation_id: str, user_id: str) -> bool: + """Check if a user has a pending invitation for a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM group_invitations WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + return cursor.fetchone() is not None + finally: + conn.close() + + +# --- Messages --- + +def store_message( + conversation_id: str, + sender_id: str, + ratchet_header: bytes, + recipients: list[dict], + x3dh_header: bytes | None = None, + sender_chain_id: bytes | None = None, + sender_chain_n: int | None = None, + image_file_id: str | None = None, + sender_device_id: str | None = None, +) -> str: + """Store an encrypted message with per-recipient ciphertext. + + recipients: [{user_id, encrypted_content (bytes), nonce (bytes), + device_id (str, optional), ratchet_header (bytes, optional), + x3dh_header (bytes, optional)}] + """ + conn = get_connection() + try: + cursor = conn.cursor() + msg_id = generate_uuid() + cursor.execute( + "INSERT INTO messages (id, conversation_id, sender_id, sender_device_id, " + "ratchet_header, x3dh_header, sender_chain_id, sender_chain_n, image_file_id) " + "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)", + (msg_id, conversation_id, sender_id, sender_device_id, ratchet_header, + x3dh_header, sender_chain_id, sender_chain_n, image_file_id), + ) + for r in recipients: + device_id = r.get("device_id", SELF_DEVICE_ID) + cursor.execute( + "INSERT INTO message_recipients (message_id, user_id, device_id, " + "encrypted_content, nonce, ratchet_header, x3dh_header) " + "VALUES (%s, %s, %s, %s, %s, %s, %s)", + (msg_id, r["user_id"], device_id, r["encrypted_content"], r["nonce"], + r.get("ratchet_header"), r.get("x3dh_header")), + ) + conn.commit() + cursor.execute("SELECT created_at FROM messages WHERE id = %s", (msg_id,)) + row = cursor.fetchone() + created_at = row[0].isoformat() if row else None + return msg_id, created_at + finally: + conn.close() + + +def get_messages(conversation_id: str, user_id: str, limit: int = 50, offset: int = 0, + device_id: str | None = None, after_ts: str | None = None) -> list[dict]: + """Get messages for a user in a conversation, JOINing their per-recipient ciphertext. + + If device_id is set, returns rows where mr.device_id matches OR is the sentinel + (self-encrypted / legacy). May return duplicate message IDs when both device-specific + and self-encrypted rows exist — caller should deduplicate (prefer device-specific). + + If after_ts is set, only returns messages created after that timestamp (ISO format). + Results are ordered ASC when after_ts is used, DESC otherwise. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + if device_id: + where_parts = ["m.conversation_id = %s", + "(cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"] + params = [user_id, device_id, SELF_DEVICE_ID, user_id, conversation_id] + + if after_ts: + where_parts.append("m.created_at > %s") + params.append(after_ts) + + where_clause = " AND ".join(where_parts) + order = "ASC" if after_ts else "DESC" + + cursor.execute( + "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, " + "m.ratchet_header, m.x3dh_header, " + "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, " + "m.pinned_at, m.pinned_by, " + "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, " + "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header " + "FROM messages m " + "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " + " AND (mr.device_id = %s OR mr.device_id = %s) " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + f"WHERE {where_clause} " + f"ORDER BY m.created_at {order} LIMIT %s OFFSET %s", + (*params, limit, offset), + ) + else: + where_parts = ["m.conversation_id = %s", + "(cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"] + params = [user_id, user_id, conversation_id] + + if after_ts: + where_parts.append("m.created_at > %s") + params.append(after_ts) + + where_clause = " AND ".join(where_parts) + order = "ASC" if after_ts else "DESC" + + cursor.execute( + "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, " + "m.ratchet_header, m.x3dh_header, " + "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, " + "m.pinned_at, m.pinned_by, " + "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, " + "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header " + "FROM messages m " + "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + f"WHERE {where_clause} " + f"ORDER BY m.created_at {order} LIMIT %s OFFSET %s", + (*params, limit, offset), + ) + return cursor.fetchall() + finally: + conn.close() + + +def count_messages(conversation_id: str, user_id: str) -> int: + """Count total messages visible to a user in a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT COUNT(DISTINCT m.id) " + "FROM messages m " + "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + "WHERE m.conversation_id = %s AND (cm.joined_at IS NULL OR m.created_at >= cm.joined_at)", + (user_id, user_id, conversation_id), + ) + row = cursor.fetchone() + return row[0] if row else 0 + finally: + conn.close() + + +def get_message_conversation(message_id: str) -> str | None: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT conversation_id FROM messages WHERE id = %s", (message_id,)) + row = cursor.fetchone() + return row[0] if row else None + finally: + conn.close() + + +def get_message_sender(message_id: str) -> str | None: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT sender_id FROM messages WHERE id = %s", (message_id,)) + row = cursor.fetchone() + return row[0] if row else None + finally: + conn.close() + + +def get_deleted_messages_since(conversation_id: str, user_id: str, since_ts: str) -> list[str]: + """Return message IDs that were soft-deleted since the given timestamp.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT m.id FROM messages m " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + "WHERE m.conversation_id = %s AND m.deleted_at IS NOT NULL AND m.deleted_at > %s", + (user_id, conversation_id, since_ts), + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + +# --- Reactions --- + +ALLOWED_REACTIONS = {"thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"} + + +def add_reaction(message_id: str, user_id: str, reaction: str) -> tuple[bool, str | None]: + """Add or replace a reaction. Returns (changed, old_reaction_or_None).""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT reaction FROM message_reactions WHERE message_id = %s AND user_id = %s", + (message_id, user_id), + ) + row = cursor.fetchone() + old_reaction = row["reaction"] if row else None + + if old_reaction == reaction: + return False, None # already same reaction + + if old_reaction: + cursor.execute( + "UPDATE message_reactions SET reaction = %s, created_at = CURRENT_TIMESTAMP " + "WHERE message_id = %s AND user_id = %s", + (reaction, message_id, user_id), + ) + else: + cursor.execute( + "INSERT INTO message_reactions (id, message_id, user_id, reaction) " + "VALUES (%s, %s, %s, %s)", + (generate_uuid(), message_id, user_id, reaction), + ) + conn.commit() + return True, old_reaction + finally: + conn.close() + + +def remove_reaction(message_id: str, user_id: str) -> bool: + """Remove a user's reaction. Returns True if deleted.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM message_reactions WHERE message_id = %s AND user_id = %s", + (message_id, user_id), + ) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def get_reactions(message_ids: list[str]) -> dict[str, list[dict]]: + """Get reactions for multiple messages. Returns {msg_id: [{user_id, reaction, created_at}]}.""" + if not message_ids: + return {} + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"SELECT message_id, user_id, reaction, created_at " + f"FROM message_reactions WHERE message_id IN ({placeholders}) " + f"ORDER BY created_at", + tuple(message_ids), + ) + result = {} + for row in cursor.fetchall(): + mid = row["message_id"] + if mid not in result: + result[mid] = [] + result[mid].append({ + "user_id": row["user_id"], + "reaction": row["reaction"], + "created_at": row["created_at"].isoformat() if hasattr(row["created_at"], "isoformat") else str(row["created_at"]), + }) + return result + finally: + conn.close() + + +# --- Pins --- + +def pin_message(message_id: str, user_id: str, conversation_id: str) -> bool: + """Pin a message. Returns True on success.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE messages SET pinned_at = NOW(), pinned_by = %s " + "WHERE id = %s AND conversation_id = %s AND pinned_at IS NULL", + (user_id, message_id, conversation_id), + ) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def unpin_message(message_id: str, conversation_id: str) -> bool: + """Unpin a message. Returns True on success.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE messages SET pinned_at = NULL, pinned_by = NULL " + "WHERE id = %s AND conversation_id = %s AND pinned_at IS NOT NULL", + (message_id, conversation_id), + ) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def get_pinned_messages(conversation_id: str, user_id: str) -> list[dict]: + """Get pinned messages for a conversation (membership verified via JOIN).""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT m.id AS message_id, m.sender_id, m.pinned_at, m.pinned_by, m.created_at " + "FROM messages m " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + "WHERE m.conversation_id = %s AND m.pinned_at IS NOT NULL AND m.deleted_at IS NULL " + "ORDER BY m.pinned_at DESC", + (user_id, conversation_id), + ) + rows = cursor.fetchall() + for r in rows: + for k in ("pinned_at", "created_at"): + if r.get(k) and hasattr(r[k], "isoformat"): + r[k] = r[k].isoformat() + return rows + finally: + conn.close() + + +# --- Group Sender Keys --- + +def store_sender_key(conversation_id: str, sender_id: str, chain_id: bytes, + device_id: str | None = None): + """Store or update a sender key chain ID for a group member's device.""" + conn = get_connection() + try: + cursor = conn.cursor() + dev = device_id or SELF_DEVICE_ID + cursor.execute( + "REPLACE INTO group_sender_keys (conversation_id, sender_id, device_id, chain_id) " + "VALUES (%s, %s, %s, %s)", + (conversation_id, sender_id, dev, chain_id), + ) + conn.commit() + finally: + conn.close() + + +def get_sender_key(conversation_id: str, sender_id: str, + device_id: str | None = None) -> dict | None: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + dev = device_id or SELF_DEVICE_ID + cursor.execute( + "SELECT chain_id, created_at FROM group_sender_keys " + "WHERE conversation_id = %s AND sender_id = %s AND device_id = %s", + (conversation_id, sender_id, dev), + ) + return cursor.fetchone() + finally: + conn.close() + + +# --- Read Receipts --- + +def filter_message_ids_by_conversation(conversation_id: str, message_ids: list[str]) -> list[str]: + """Return only message_ids that belong to the given conversation.""" + if not message_ids: + return [] + conn = get_connection() + try: + cursor = conn.cursor() + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"SELECT id FROM messages WHERE id IN ({placeholders}) AND conversation_id = %s", + (*message_ids, conversation_id), + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + +def mark_messages_read(conversation_id: str, user_id: str, message_ids: list[str]): + if not message_ids: + return + conn = get_connection() + try: + cursor = conn.cursor() + # M1 fix: JOIN messages to verify message_ids belong to conversation_id + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"INSERT IGNORE INTO message_reads (message_id, user_id) " + f"SELECT m.id, %s FROM messages m " + f"WHERE m.id IN ({placeholders}) AND m.conversation_id = %s", + (user_id, *message_ids, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def mark_conversation_read(conversation_id: str, user_id: str) -> int: + """Mark ALL unread messages in a conversation as read for user. Returns count marked.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT IGNORE INTO message_reads (message_id, user_id) " + "SELECT m.id, %s " + "FROM messages m " + "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s " + "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s " + "WHERE m.conversation_id = %s AND m.sender_id != %s " + "AND m.deleted_at IS NULL AND mrd.message_id IS NULL", + (user_id, user_id, user_id, conversation_id, user_id), + ) + count = cursor.rowcount + conn.commit() + return count + finally: + conn.close() + + +def get_unread_counts(user_id: str, max_age_days: int = 0) -> dict[str, int]: + """Return {conversation_id: unread_count} for all conversations the user is in. + + max_age_days: if > 0, only count messages younger than this many days. + Must match METADATA_RETENTION_DAYS to avoid phantom unreads after read cleanup. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + age_filter = "" + params = [user_id, user_id, user_id] + if max_age_days > 0: + age_filter = " AND m.created_at >= DATE_SUB(NOW(), INTERVAL %s DAY)" + params.append(max_age_days) + cursor.execute( + "SELECT m.conversation_id, COUNT(DISTINCT m.id) AS cnt " + "FROM messages m " + "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s " + "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s " + "WHERE m.sender_id != %s AND m.deleted_at IS NULL AND mrd.message_id IS NULL" + f"{age_filter} " + "GROUP BY m.conversation_id", + params, + ) + return {row["conversation_id"]: row["cnt"] for row in cursor.fetchall()} + finally: + conn.close() + + +def get_message_read_status(message_ids: list[str]) -> dict: + if not message_ids: + return {} + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"SELECT mr.message_id, mr.user_id, mr.read_at " + f"FROM message_reads mr " + f"WHERE mr.message_id IN ({placeholders})", + tuple(message_ids), + ) + result = {} + for row in cursor.fetchall(): + mid = row["message_id"] + if mid not in result: + result[mid] = [] + result[mid].append({ + "user_id": row["user_id"], + "read_at": row["read_at"].isoformat() if hasattr(row["read_at"], "isoformat") else str(row["read_at"]), + }) + return result + finally: + conn.close() + + +# --- Delivery Receipts --- + +def mark_messages_delivered(conversation_id: str, user_id: str, message_ids: list[str]): + """Batch insert delivery receipts (INSERT IGNORE — idempotent).""" + if not message_ids: + return + conn = get_connection() + try: + cursor = conn.cursor() + # M1 fix: JOIN messages to verify message_ids belong to conversation_id + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"INSERT IGNORE INTO message_deliveries (message_id, user_id) " + f"SELECT m.id, %s FROM messages m " + f"WHERE m.id IN ({placeholders}) AND m.conversation_id = %s", + (user_id, *message_ids, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def get_message_delivery_status(message_ids: list[str]) -> dict: + """Get delivery status for messages. Returns {msg_id: [{user_id, delivered_at}]}.""" + if not message_ids: + return {} + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"SELECT md.message_id, md.user_id, md.delivered_at " + f"FROM message_deliveries md " + f"WHERE md.message_id IN ({placeholders})", + tuple(message_ids), + ) + result = {} + for row in cursor.fetchall(): + mid = row["message_id"] + if mid not in result: + result[mid] = [] + result[mid].append({ + "user_id": row["user_id"], + "delivered_at": row["delivered_at"].isoformat() if hasattr(row["delivered_at"], "isoformat") else str(row["delivered_at"]), + }) + return result + finally: + conn.close() + + +# --- Delete --- + +def soft_delete_message(message_id: str, sender_id: str) -> dict | None: + """Soft-delete a message if sender matches. Returns {'image_file_id': ...} or None.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT sender_id, image_file_id FROM messages WHERE id = %s AND deleted_at IS NULL", + (message_id,), + ) + row = cursor.fetchone() + if not row or row["sender_id"] != sender_id: + return None + cursor.execute( + "UPDATE messages SET deleted_at = NOW() WHERE id = %s", + (message_id,), + ) + # Clear per-recipient ciphertext + cursor.execute( + "UPDATE message_recipients SET encrypted_content = %s WHERE message_id = %s", + (b"", message_id), + ) + conn.commit() + return {"image_file_id": row.get("image_file_id")} + finally: + conn.close() + + +def set_message_image_file_id(message_id: str, file_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE messages SET image_file_id = %s WHERE id = %s", + (file_id, message_id), + ) + conn.commit() + finally: + conn.close() + + +# --- Image Uploads --- + +def create_image_upload(file_id: str, conversation_id: str, uploader_id: str, file_size: int): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO image_uploads (file_id, conversation_id, uploader_id, file_size) " + "VALUES (%s, %s, %s, %s)", + (file_id, conversation_id, uploader_id, file_size), + ) + conn.commit() + finally: + conn.close() + + +def complete_image_upload(file_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE image_uploads SET completed = TRUE WHERE file_id = %s", + (file_id,), + ) + conn.commit() + finally: + conn.close() + + +def get_image_upload(file_id: str) -> dict | None: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT file_id, conversation_id, uploader_id, file_size, completed, created_at " + "FROM image_uploads WHERE file_id = %s", + (file_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def delete_image_upload(file_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM image_uploads WHERE file_id = %s", (file_id,)) + conn.commit() + finally: + conn.close() + + +# --- User Profiles --- + +def create_default_profile(user_id: str): + """Create a default profile for a new user.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)", + (user_id,), + ) + conn.commit() + finally: + conn.close() + + +def get_user_profile(user_id: str, viewer_id: str | None = None) -> dict | None: + """Get user profile joined with user info. Respects visibility if viewer is different user.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT u.id AS user_id, u.username, u.email, u.created_at, " + "p.phone, p.phone_visible, p.email_visible, p.location, " + "p.location_visible, p.avatar_file, p.updated_at " + "FROM users u LEFT JOIN user_profiles p ON u.id = p.user_id " + "WHERE u.id = %s", + (user_id,), + ) + row = cursor.fetchone() + if not row: + return None + # If viewing someone else's profile, apply visibility rules + if viewer_id and viewer_id != user_id: + if not row.get("email_visible"): + row["email"] = None + if not row.get("phone_visible"): + row["phone"] = None + if not row.get("location_visible"): + row["location"] = None + return row + finally: + conn.close() + + +def update_user_profile(user_id: str, **fields): + """Upsert user profile fields. Allowed: phone, phone_visible, email_visible, + location, location_visible, avatar_file.""" + allowed = {"phone", "phone_visible", "email_visible", "location", + "location_visible", "avatar_file"} + filtered = {k: v for k, v in fields.items() if k in allowed} + if not filtered: + return + conn = get_connection() + try: + cursor = conn.cursor() + # Upsert: insert default then update + cursor.execute( + "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)", + (user_id,), + ) + set_clause = ", ".join(f"{k} = %s" for k in filtered) + values = list(filtered.values()) + [user_id] + cursor.execute( + f"UPDATE user_profiles SET {set_clause} WHERE user_id = %s", + values, + ) + conn.commit() + finally: + conn.close() + + +def batch_reencrypt_messages(user_id: str, updates: list[dict]): + """Batch upsert message_recipients rows with self-encryption key data. + + Each update: {message_id, encrypted_content (bytes), nonce (bytes)}. + Sets ratchet_header to '{"self":true}' and clears x3dh_header. + Uses INSERT ... ON DUPLICATE KEY UPDATE so it works for both sent messages + (which already have a SELF_DEVICE_ID row) and received messages (which don't). + """ + if not updates: + return + conn = get_connection() + try: + cursor = conn.cursor() + self_header = b'{"self":true}' + for u in updates: + cursor.execute( + "INSERT INTO message_recipients " + "(message_id, user_id, device_id, encrypted_content, nonce, ratchet_header, x3dh_header) " + "VALUES (%s, %s, %s, %s, %s, %s, NULL) " + "ON DUPLICATE KEY UPDATE encrypted_content = VALUES(encrypted_content), " + "nonce = VALUES(nonce), ratchet_header = VALUES(ratchet_header), x3dh_header = NULL", + (u["message_id"], user_id, SELF_DEVICE_ID, + u["encrypted_content"], u["nonce"], self_header), + ) + conn.commit() + finally: + conn.close() + + +# --- Phantom Users --- + +def create_phantom_user(email: str) -> dict: + """Create a phantom user with valid crypto keys for X3DH. + + Phantom users have rsa_public_key = 'PHANTOM' as a marker. + Returns user dict: {id, username, email, identity_key}. + """ + username = email.split("@")[0] + user_id = generate_uuid() + + # Generate real crypto keys so X3DH works on the client side + ik_private, ik_public = generate_identity_keypair() + ik_public_bytes = serialize_ed25519_public(ik_public) + + spk = generate_signed_prekey(ik_private) + spk_pub_bytes = serialize_x25519_public(spk["public"]) + spk_sig = spk["signature"] + + opks = generate_one_time_prekeys(count=5) + + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO users (id, username, email, rsa_public_key, identity_key) " + "VALUES (%s, %s, %s, %s, %s)", + (user_id, username, email, "PHANTOM", ik_public_bytes), + ) + cursor.execute( + "INSERT INTO signed_prekeys (id, user_id, public_key, signature) VALUES (%s, %s, %s, %s)", + (spk["id"], user_id, spk_pub_bytes, spk_sig), + ) + for opk in opks: + cursor.execute( + "INSERT INTO one_time_prekeys (id, user_id, public_key) VALUES (%s, %s, %s)", + (opk["id"], user_id, serialize_x25519_public(opk["public"])), + ) + conn.commit() + return {"id": user_id, "username": username, "email": email, "identity_key": ik_public_bytes} + finally: + conn.close() + + +def is_phantom_user(user_id: str) -> bool: + """Check if a user is a phantom (rsa_public_key == 'PHANTOM').""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT rsa_public_key FROM users WHERE id = %s", (user_id,)) + row = cursor.fetchone() + return row is not None and row[0] == "PHANTOM" + finally: + conn.close() + + +def delete_phantom_user(user_id: str): + """Delete a phantom user. CASCADE removes signed_prekeys, one_time_prekeys, + conversation_members, message_recipients, etc.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM users WHERE id = %s AND rsa_public_key = %s", + (user_id, "PHANTOM"), + ) + conn.commit() + finally: + conn.close() + + +def upgrade_phantom_user(phantom_id: str, username: str, rsa_public_key_pem: str, + identity_key: bytes) -> str | None: + """Upgrade a phantom user to a real user in-place. + + Preserves user_id and all FK references (conversation_members, group_invitations, etc.). + Deletes phantom's server-generated prekeys (real user will upload own on first login). + Returns phantom_id as the new user_id, or None if phantom no longer exists. + """ + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET username = %s, rsa_public_key = %s, identity_key = %s " + "WHERE id = %s AND rsa_public_key = 'PHANTOM'", + (username, rsa_public_key_pem, identity_key, phantom_id), + ) + if cursor.rowcount == 0: + conn.rollback() + return None + # Remove phantom's server-generated crypto keys — real user uploads own + cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s", (phantom_id,)) + cursor.execute("DELETE FROM one_time_prekeys WHERE user_id = %s", (phantom_id,)) + conn.commit() + return phantom_id + finally: + conn.close() + + +def get_all_phantom_user_ids() -> set[str]: + """Return set of all phantom user IDs (for server startup cache).""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT id FROM users WHERE rsa_public_key = %s", ("PHANTOM",)) + return {row[0] for row in cursor.fetchall()} + finally: + conn.close() + + +def cleanup_stale_phantoms(max_age_days: int = 30) -> int: + """Delete phantom users older than max_age_days with no active conversations with real users.""" + conn = get_connection() + try: + cursor = conn.cursor() + # Two-step: SELECT ids first, then DELETE. + # MySQL error 1093: can't DELETE from table referenced in subquery. + cursor.execute(""" + SELECT u.id FROM users u + WHERE u.rsa_public_key = 'PHANTOM' + AND u.created_at < DATE_SUB(NOW(), INTERVAL %s DAY) + AND NOT EXISTS ( + SELECT 1 FROM conversation_members cm1 + JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id + JOIN users u2 ON cm2.user_id = u2.id + WHERE cm1.user_id = u.id + AND u2.rsa_public_key != 'PHANTOM' + ) + """, (max_age_days,)) + ids = [row[0] for row in cursor.fetchall()] + if not ids: + return 0 + cursor.execute( + "DELETE FROM users WHERE id IN (%s)" % ",".join(["%s"] * len(ids)), + ids, + ) + deleted = cursor.rowcount + conn.commit() + return deleted + finally: + conn.close() + + +def remove_conversation_member_atomic(conversation_id: str, user_id: str) -> bool: + """Remove member and return True if actually removed (row existed). M6 TOCTOU fix.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def get_stale_uploads(max_age_seconds: int = 3600) -> list[dict]: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT file_id FROM image_uploads " + "WHERE completed = FALSE AND created_at < DATE_SUB(NOW(), INTERVAL %s SECOND)", + (max_age_seconds,), + ) + return cursor.fetchall() + finally: + conn.close() + + +# --------------------------------------------------------------------------- +# Metadata retention cleanup +# --------------------------------------------------------------------------- + +def cleanup_old_reads(days: int = 90, batch_size: int = 10000) -> int: + """Delete message_reads older than N days in batches. + + Only deletes reads for messages whose created_at is also past the retention + window. This prevents phantom unreads: get_unread_counts uses the same + time window (max_age_days) so messages outside the window aren't counted. + """ + total = 0 + while True: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM message_reads " + "WHERE read_at < DATE_SUB(NOW(), INTERVAL %s DAY) " + "AND message_id IN (" + " SELECT id FROM messages " + " WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY)" + ") LIMIT %s", + (days, days, batch_size), + ) + count = cursor.rowcount + conn.commit() + total += count + if count < batch_size: + break + finally: + conn.close() + return total + + +def cleanup_old_reactions(days: int = 90, batch_size: int = 10000) -> int: + """Delete message_reactions older than N days in batches.""" + total = 0 + while True: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM message_reactions WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY) LIMIT %s", + (days, batch_size), + ) + count = cursor.rowcount + conn.commit() + total += count + if count < batch_size: + break + finally: + conn.close() + return total diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..776fe42 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,70 @@ +version: "3.9" + +# Local development stack: encrypted-chat server + MySQL +# Usage: +# docker compose up — start server + db +# docker compose up --build — rebuild server image first +# docker compose down -v — stop and remove volumes (wipes DB data) + +services: + db: + image: mysql:8.0 + restart: unless-stopped + environment: + MYSQL_ROOT_PASSWORD: devpassword + MYSQL_DATABASE: encrypted_chat + MYSQL_USER: chat + MYSQL_PASSWORD: chatpassword + volumes: + # Persist DB data between restarts + - db_data:/var/lib/mysql + # Auto-import schema on first start + - ./schema.sql:/docker-entrypoint-initdb.d/01_schema.sql:ro + ports: + - "3306:3306" + healthcheck: + test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-pdevpassword"] + interval: 10s + timeout: 5s + retries: 10 + start_period: 30s + + server: + build: . + restart: unless-stopped + depends_on: + db: + condition: service_healthy + ports: + - "5000:5000" + volumes: + - uploads:/app/uploads + environment: + # MySQL connection + MYSQL_HOST: db + MYSQL_PORT: 3306 + MYSQL_USER: chat + MYSQL_PASSWORD: chatpassword + MYSQL_DATABASE: encrypted_chat + DB_POOL_SIZE: 10 + + # Server config + SERVER_HOST: 0.0.0.0 + SERVER_PORT: 5000 + UPLOAD_DIR: /app/uploads + + # Dev mode: registration codes returned in response (no SMTP needed) + ENVIRONMENT: dev + + # TLS: disabled by default for local dev (set TLS_ENABLED=true for prod) + TLS_ENABLED: "false" + + # Logging + LOG_LEVEL: INFO + + # Metadata retention (days) + METADATA_RETENTION_DAYS: 90 + +volumes: + db_data: + uploads: diff --git a/gemini.md b/gemini.md new file mode 100644 index 0000000..48b422b --- /dev/null +++ b/gemini.md @@ -0,0 +1,152 @@ +# Gemini Advanced Roadmap: Beyond the Basics + +Tento dokument obsahuje pokročilé návrhy na vylepšení bezpečnosti, architektury a UX aplikace `encrypted_chat`. Tyto body jdou nad rámec běžného "best practice" a směřují k funkcionalitě profesionálních secure messengerů (Signal, Threema, Wire) se zaměřením na ochranu metadat a anti-forenzní techniky. + +--- + +## 1. Ochrana Metadat & Traffic Analysis Resistance +*Cíl: Server by neměl vědět, KDO s KÝM komunikuje, ani JAKÝ typ dat si posílají.* + +### Sealed Sender (Odesílatel v obálce) +- **Koncept:** Server zná pouze `recipient_id`. Identita odesílatele (`sender_id`) je zašifrována uvnitř zprávy (v "obálce"), kterou server nedokáže přečíst. +- **Implementace:** + 1. Odesílatel vygeneruje klíč pro obálku (např. z profilu příjemce). + 2. Zabalí `sender_id` a payload do šifrovaného bloku. + 3. Server doručí blob příjemci bez ověření odesílatele (ověření proběhne až na klientovi po rozbalení). + 4. **Výhoda:** Při kompromitaci serveru útočník nevidí sociální graf (kdo se s kým baví). + +### Traffic Padding & Constant Bitrate +- **Problém:** Délka paketu prozrazuje obsah (krátký paket = "Ahoj", dlouhý paket = obrázek/klíč). Intervaly prozrazují aktivitu. +- **Řešení:** + 1. **Padding:** Všechny zprávy doplňovat náhodnými daty na fixní velikosti (např. bloky 4KB). + 2. **Dummy Traffic (Chaff):** Klient náhodně odesílá "falešné" pakety na server, které server zahodí nebo vrátí (echo). + 3. **Výhoda:** Pro síťového analytika (ISP) vypadá tok dat jako konstantní šum. + +--- + +## 2. Anti-Forenzní Ochrana (Client-side) +*Cíl: Minimalizovat dopad fyzického zabavení zařízení nebo vynuceného odemčení.* + +### Duress Password (Heslo pod nátlakem) +- **Funkce:** Uživatel si nastaví *druhé* heslo. +- **Chování:** Pokud se přihlásí tímto heslem: + - **Varianta A (Decoy):** Odemkne se prázdná nebo falešná databáze s neškodnými konverzacemi. + - **Varianta B (Panic):** Aplikace na pozadí tiše provede **secure wipe** (přepis) privátních klíčů a reálné DB, zatímco uživateli zobrazí "Connection Error". + +### Secure Deletion & DB Vacuuming +- **Problém:** SQL `DELETE` data nesmaže fyzicky, jen označí místo jako volné. +- **Řešení:** + 1. Před smazáním zprávy přepsat obsah náhodnými byty (`UPDATE messages SET content = random_blob WHERE id = ...`). + 2. Pravidelně spouštět `VACUUM` (u SQLite) nebo optimalizaci tabulek. + 3. Pro soubory (obrázky) použít bezpečné mazání (overwrite passes) před `os.unlink()`. + +### Disappearing Messages (TTL) +- **Funkce:** Odesílatel nastaví životnost zprávy (např. 1 minuta). +- **Implementace:** Odpočet začíná okamžikem zobrazení (Read Receipt). Po uplynutí času klient data nenávratně smaže z disku (včetně secure wipe). Server maže ihned po doručení. + +--- + +## 3. Infrastruktura & Škálování +*Cíl: Odlehčit Python procesu a databázi pro podporu tisíců uživatelů.* + +### Object Storage (MinIO / S3) + Presigned URLs +- **Problém:** `server.py` blokuje I/O při příjmu velkých souborů. +- **Řešení:** + 1. Klient požádá server o upload. + 2. Server vygeneruje **Presigned PUT URL** (časově omezený token pro přímý upload do MinIO/S3). + 3. Klient nahrává data přímo do úložiště (obchází aplikační server). + 4. Server ukládá pouze odkaz (URL/Key). +- **Výhoda:** Masivní zrychlení, server řeší jen metadata. + +### Read/Write Splitting (MySQL Replication) +- **Architektura:** + - **Master DB:** Pouze pro `INSERT`, `UPDATE`, `DELETE`. + - **Read Replicas (Slaves):** Pro těžké `SELECT` dotazy (historie zpráv, hledání). +- **Implementace v `db.py`:** Router, který podle typu dotazu volí connection pool. + +--- + +## 4. Protokol & Funkce +*Cíl: Rozšíření možností komunikace bez nutnosti centralizovaného streamování.* + +### P2P Volání (WebRTC Signalizace) +- **Koncept:** Využít existující bezpečný kanál (Double Ratchet) pro výměnu SDP (Session Description Protocol) paketů. +- **Flow:** + 1. Alice pošle Bobovi zašifrovanou zprávu typu `CALL_OFFER` s parametry WebRTC. + 2. Bob odpoví `CALL_ANSWER`. + 3. Klienti si vymění `ICE_CANDIDATES` (IP adresy/porty) a naváží přímé P2P spojení (UDP). + 4. Audio/Video stream (SRTP) jde mimo server. + +### Diferenciální Synchronizace (Merkle Trees) +- **Problém:** Stahování seznamu kontaktů (`get_user_contacts`) je pomalé při velkém množství dat. +- **Řešení:** Klient a server si udržují Hash Tree (Merkle Tree) stavu. Při synchronizaci porovnají pouze root hash. Pokud se liší, stahují se jen změněné větve stromu (delta update). + +--- + +## 5. UI/UX (PyQt Speciality) +*Cíl: Ochrana soukromí na úrovni OS a skrytá komunikace.* + +### Privacy Overlay (Task Switcher) +- **Funkce:** Detekovat událost ztráty fokusu okna (`QEvent.WindowDeactivate`) nebo minimalizace. +- **Akce:** Překrýt obsah okna rozmazaným efektem (`QGraphicsBlurEffect`) nebo logem aplikace. +- **Důvod:** Zabrání operačnímu systému (Windows/Linux/macOS) vytvořit čitelný náhled okna v Alt+Tab menu nebo v historii aktivit. + +### Steganografie +- **Funkce:** Ukrýt šifrovanou zprávu do obrazových dat nevinného obrázku (např. meme kočky). +- **Implementace:** Modifikace LSB (Least Significant Bit) pixelů obrázku. +- **Výhoda:** Pro síťového admina nebo forenzní analýzu to vypadá jako běžné posílání obrázků, přítomnost šifrované komunikace je popiratelná. + +--- + +## 6. High Availability Architecture (Distribuovaný Cluster) +*Cíl: Zajištění provozu i při výpadku/napadení serveru (Active-Active "RAID 1 přes síť").* + +### Architektura: Geograficky Distribuovaný "Zero-Trust" Cluster + +#### 1. Vstupní brána (Global Traffic Manager) +- **Funkce:** Rozděluje klienty mezi dostupné servery (Round Robin / Geo-DNS). +- **Self-Healing:** Při výpadku Serveru A okamžitě přesměruje provoz na Server B. Uživatel nic nepozná. + +#### 2. Aplikační vrstva (Stateless Servers) +- **Stav:** Servery jsou **bezstavové**. `server.py` neukládá nic důležitého v RAM. +- **Škálování:** Můžete spustit N instancí serveru. Je jedno, ke kterému se uživatel připojí. +- **Komunikace:** Servery spolu mluví přes rychlý Message Bus (Redis Pub/Sub) pro doručování real-time zpráv mezi uživateli na různých uzlech. + +#### 3. Datová vrstva (Zrcadlení Dat - "RAID 1") +- **Databáze (MySQL Galera Cluster):** Synchronní multi-master replikace. Zápis na Serveru A se potvrdí, až když je fyzicky zapsán i na Serveru B (a C). + - *Efekt:* Ztráta serveru neznamená ztrátu dat (klíčů, zpráv). +- **Soubory (MinIO Cluster):** Distribuovaný Object Storage s Erasure Coding. Soubory jsou matematicky rozprostřeny přes všechny servery. Výpadek disku/serveru nevadí. + +#### 4. Bezpečnostní pojistky ("Poisoned Node") +- **Soft Delete:** Databáze nemaže data ihned, ale označuje je jako smazané (ochrana proti `DELETE *` od útočníka). +- **Client-Side Verification:** I kdyby kompromitovaný server posílal podvržené klíče, klienti ověřují digitální podpisy (Identity Keys). Server nemůže zfalšovat identitu uživatelů. + +--- + +## 7. Technický Upgrade pro Stateless Architekturu +*Cíl: Odstranit závislost na paměti procesu (RAM) pro umožnění horizontálního škálování.* + +### 1. Redis jako Distribuovaná Paměť +Nahrazení Python `dict` struktur, které jsou lokální pro jeden proces, za centrální Redis úložiště přístupné všem serverům. + +* **Párovací Session:** + * *Stav:* `pairing_sessions` (dict) -> Redis Key `pair:{code}` (Hash/String s TTL). + * *Efekt:* Uživatel může vyžádat kód na Serveru A a potvrdit ho na Serveru B. +* **Rate Limiting:** + * *Stav:* `rate_limits` (dict) -> Redis Key `rl:{ip}:{action}` (Counter s EXPIRE). + * *Efekt:* Limity platí globálně pro celý cluster, ne jen per server. + +### 2. Redis Pub/Sub pro Real-Time Routing +Doručení zprávy uživateli, který je připojen k JINÉMU serveru než odesílatel. + +* **Princip:** + 1. Server A (odesílatel) zjistí, že příjemce Bob není připojen lokálně. + 2. Server A publikuje zprávu do Redis kanálu `user:{bob_user_id}`. + 3. Server B (kde je Bob připojen) tento kanál poslouchá (subscribe). + 4. Server B přijme zprávu z Redisu a pošle ji Bobovi do otevřeného TCP socketu. + +### 3. Session Sticky vs. Stateless Uploads +Řešení pro nahrávání souborů po částech (chunks). + +* **Varianta A (Infrastructure - Sticky Sessions):** Load Balancer (HAProxy/Nginx) zajistí, že všechny požadavky od jedné IP jdou vždy na stejný server. Nejjednodušší, nevyžaduje změnu kódu. +* **Varianta B (Architectural - Direct Upload):** Viz bod 3 "Object Storage + Presigned URLs". Server vůbec nepřijímá data souboru, pouze vygeneruje token. Plně stateless řešení. diff --git a/gui_client.py b/gui_client.py new file mode 100644 index 0000000..a4aded3 --- /dev/null +++ b/gui_client.py @@ -0,0 +1,6987 @@ +"""PyQt6 GUI client for encrypted chat.""" + +import asyncio +import base64 +import json +import logging +import os +from pathlib import Path +from collections import OrderedDict + +logger = logging.getLogger(__name__) +import re +import sys +from functools import partial + +from PyQt6.QtCore import QThread, pyqtSignal, Qt, QTimer, QUrl, QSize, QRect, QPoint, QPointF, QBuffer +from PyQt6.QtWidgets import ( + QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, + QLineEdit, QLabel, QListWidget, QListWidgetItem, QTextEdit, + QSplitter, QMessageBox, QInputDialog, QMenu, QStackedWidget, + QDialog, QFileDialog, QScrollArea, QFrame, QSystemTrayIcon, + QSizePolicy, QStyledItemDelegate, +) +from PyQt6.QtGui import QFont, QFontMetricsF, QAction, QPixmap, QImage, QDesktopServices, QIcon, QPainter, QColor, QBrush, QPen, QShortcut, QKeySequence +from PyQt6.QtWidgets import QGraphicsDropShadowEffect +from PyQt6.QtWidgets import QStyle + +from chat_core import ChatClient, IdentityKeyChanged +from crypto_utils import decode_pairing_qr +from theme import ThemeManager, c, qss, tm + +# H10: Image validation limits +MAX_IMAGE_DATA_SIZE = 10 * 1024 * 1024 # 10 MB max raw image data +MAX_IMAGE_DIMENSION = 8192 # 8K pixels max + + +def _safe_load_image(data: bytes) -> QImage | None: + """Load image with size and dimension validation (H10).""" + if not data or len(data) > MAX_IMAGE_DATA_SIZE: + return None + qimg = QImage.fromData(data) + if qimg.isNull(): + return None + if qimg.width() > MAX_IMAGE_DIMENSION or qimg.height() > MAX_IMAGE_DIMENSION: + return None + return qimg + + +def _safe_filename(name: str, default: str = "file") -> str: + """Sanitize filename — strip path components, prevent traversal (H11).""" + name = os.path.basename(name) + name = name.replace("\x00", "") + return name if name else default + + +def _generate_qr_pixmap_bytes(data: bytes) -> QPixmap | None: + """Generate a QR code QPixmap from raw bytes (base64-wrapped for scanner compatibility).""" + try: + import qrcode + from io import BytesIO + qr = qrcode.QRCode( + version=1, + error_correction=qrcode.constants.ERROR_CORRECT_L, + box_size=4, + border=2, + ) + qr.add_data(base64.b64encode(data).decode("ascii")) + qr.make(fit=True) + img = qr.make_image(fill_color="black", back_color="white") + buf = BytesIO() + img.save(buf, format="PNG") + buf.seek(0) + qimg = QImage() + qimg.loadFromData(buf.getvalue()) + if qimg.isNull(): + return None + return QPixmap.fromImage(qimg) + except ImportError: + return None + except Exception: + return None + + +def _decode_qr_image_file(path: str) -> bytes | None: + """Decode base64-wrapped binary payload from a QR image file.""" + try: + from PIL import Image + pil_img = Image.open(path) + except Exception: + return None + try: + from pyzbar.pyzbar import decode as pyzbar_decode + results = pyzbar_decode(pil_img) + if not results: + return None + qr_text = results[0].data + except ImportError: + return None + except Exception: + return None + try: + return base64.b64decode(qr_text) + except Exception: + return None + + +_AVATAR_CACHE_MAX = 512 # max cached avatar pixmaps per cache + + +class _LRUPixmapCache: + """Simple LRU cache for QPixmap objects with a fixed max size.""" + + def __init__(self, maxsize: int = _AVATAR_CACHE_MAX): + self._data: OrderedDict[str, QPixmap] = OrderedDict() + self._maxsize = maxsize + + def get(self, key: str) -> QPixmap | None: + if key in self._data: + self._data.move_to_end(key) + return self._data[key] + return None + + def put(self, key: str, value: QPixmap) -> None: + if key in self._data: + self._data.move_to_end(key) + self._data[key] = value + else: + self._data[key] = value + if len(self._data) > self._maxsize: + self._data.popitem(last=False) + + def __contains__(self, key: str) -> bool: + return key in self._data + + def __getitem__(self, key: str) -> QPixmap: + self._data.move_to_end(key) + return self._data[key] + + def __setitem__(self, key: str, value: QPixmap) -> None: + self.put(key, value) + + def clear(self) -> None: + self._data.clear() + + +# URL regex: matches http:// and https:// URLs in raw (not-yet-escaped) text +_URL_RE = re.compile( + r'(https?://[^\s<>"\')\]]+)', + re.IGNORECASE, +) +_URL_TRAILING_PUNCT = re.compile(r'[.,;:!?]+$') + + +def _linkify_urls(raw_text: str, https_color: str | None = None, + http_color: str | None = None) -> str: + """HTML-escape text and convert URLs into clickable tags. + + HTTPS links get blue styling. HTTP links get orange + unlock icon warning. + Processes raw (unescaped) text — returns HTML-safe string. + """ + _https = https_color or c().link_https + _http = http_color or c().link_http + + def _esc(s): + return s.replace("&", "&").replace("<", "<").replace(">", ">") + + parts = _URL_RE.split(raw_text) + result = [] + for i, part in enumerate(parts): + if i % 2 == 1: + # URL match — strip trailing sentence punctuation + trail_m = _URL_TRAILING_PUNCT.search(part) + if trail_m: + url = part[:trail_m.start()] + trail = part[trail_m.start():] + else: + url = part + trail = "" + url_esc = _esc(url) + if url.lower().startswith("http://"): + result.append( + f'' + f'\U0001f513 {url_esc}' + ) + else: + result.append( + f'' + f'{url_esc}' + ) + if trail: + result.append(_esc(trail)) + else: + result.append(_esc(part)) + return "".join(result) + + +def setup_logging(): + level_name = os.getenv("LOG_LEVEL", "WARNING").upper() + level = getattr(logging, level_name, logging.WARNING) + logging.basicConfig(level=level, format="%(levelname)s: %(message)s") + + + +MAX_INPUT_CHARS = int(os.getenv("MAX_INPUT_CHARS", "2000")) + +# Custom item data roles for conversation list delegate +ROLE_CONV_ID = Qt.ItemDataRole.UserRole +ROLE_DISPLAY_NAME = Qt.ItemDataRole.UserRole + 1 +ROLE_PREVIEW = Qt.ItemDataRole.UserRole + 2 # last message preview text +ROLE_TIMESTAMP = Qt.ItemDataRole.UserRole + 3 # last message relative time +ROLE_UNREAD = Qt.ItemDataRole.UserRole + 4 # int unread count +ROLE_IS_FAV = Qt.ItemDataRole.UserRole + 5 # bool +ROLE_AVATAR = Qt.ItemDataRole.UserRole + 6 # QPixmap (circular, 44px) +ROLE_VERIFIED = Qt.ItemDataRole.UserRole + 7 # str: "verified", "trusted", "" (DMs only) +ROLE_RECEIPT = Qt.ItemDataRole.UserRole + 8 # str: "read", "delivered", "sent", "" (own msgs only) + + +def _relative_time(ts: str) -> str: + """Convert ISO timestamp to relative time string for conversation list.""" + if not ts or len(ts) < 16: + return "" + try: + from datetime import datetime, timezone + # Parse "YYYY-MM-DD HH:MM:SS" or "YYYY-MM-DDTHH:MM:SS" + clean = ts.replace("T", " ")[:19] + msg_time = datetime.strptime(clean, "%Y-%m-%d %H:%M:%S").replace( + tzinfo=timezone.utc + ) + now = datetime.now(timezone.utc) + diff = now - msg_time + secs = int(diff.total_seconds()) + if secs < 60: + return "now" + if secs < 3600: + return f"{secs // 60}m" + if secs < 86400: + return f"{secs // 3600}h" + days = secs // 86400 + if days == 1: + return "Yesterday" + if days < 7: + return msg_time.strftime("%a") # Mon, Tue, ... + return msg_time.strftime("%d/%m") + except Exception: + return ts[11:16] if len(ts) >= 16 else "" + + +def _format_msg_time(ts: str) -> str: + """Format timestamp for message bubbles: HH:MM today, 'Yesterday HH:MM', + 'Mon HH:MM' this week, 'DD.MM. HH:MM' older.""" + if not ts or len(ts) < 16: + return "" + try: + from datetime import datetime, timezone + clean = ts.replace("T", " ")[:19] + msg_time = datetime.strptime(clean, "%Y-%m-%d %H:%M:%S").replace( + tzinfo=timezone.utc + ) + now = datetime.now(timezone.utc) + hhmm = msg_time.strftime("%H:%M") + if msg_time.date() == now.date(): + return hhmm + diff_days = (now.date() - msg_time.date()).days + if diff_days == 1: + return f"Yesterday {hhmm}" + if diff_days < 7: + return f"{msg_time.strftime('%a')} {hhmm}" + if msg_time.year == now.year: + return f"{msg_time.day}.{msg_time.month}. {hhmm}" + return f"{msg_time.day}.{msg_time.month}.{msg_time.year} {hhmm}" + except Exception: + return ts[11:16] if len(ts) >= 16 else "" + + +class ConversationDelegate(QStyledItemDelegate): + """Custom delegate that paints Signal/Telegram-style conversation rows.""" + + ITEM_HEIGHT = 68 + AVATAR_SIZE = 44 + BADGE_SIZE = 20 + HPAD = 10 + VPAD = 10 + + def sizeHint(self, option, index): + return QSize(option.rect.width(), self.ITEM_HEIGHT) + + def paint(self, painter, option, index): + painter.save() + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + t = c() + + rect = option.rect + is_selected = bool(option.state & QStyle.StateFlag.State_Selected) + is_hover = bool(option.state & QStyle.StateFlag.State_MouseOver) + + # Background + if is_selected: + painter.fillRect(rect, QColor(t.bg_selected)) + elif is_hover: + painter.fillRect(rect, QColor(t.bg_hover)) + + # Data from item roles + name = index.data(ROLE_DISPLAY_NAME) or "" + preview = index.data(ROLE_PREVIEW) or "" + timestamp = index.data(ROLE_TIMESTAMP) or "" + unread = index.data(ROLE_UNREAD) or 0 + is_fav = index.data(ROLE_IS_FAV) or False + avatar_pix = index.data(ROLE_AVATAR) + verified = index.data(ROLE_VERIFIED) or "" + + x = rect.x() + self.HPAD + y = rect.y() + self.VPAD + avail_w = rect.width() - 2 * self.HPAD + + # -- Avatar (left) -- + av_y = rect.y() + (rect.height() - self.AVATAR_SIZE) // 2 + if avatar_pix and not avatar_pix.isNull(): + painter.drawPixmap(x, av_y, self.AVATAR_SIZE, self.AVATAR_SIZE, + avatar_pix) + else: + painter.setBrush(QBrush(QColor(t.bg_secondary))) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(x, av_y, self.AVATAR_SIZE, self.AVATAR_SIZE) + painter.setPen(QColor(t.text_muted)) + f = QFont() + f.setPointSize(14) + f.setBold(True) + painter.setFont(f) + painter.drawText( + QRect(x, av_y, self.AVATAR_SIZE, self.AVATAR_SIZE), + Qt.AlignmentFlag.AlignCenter, + name[0].upper() if name else "?", + ) + + text_x = x + self.AVATAR_SIZE + 10 + text_w = avail_w - self.AVATAR_SIZE - 10 + + # -- Timestamp (top-right) -- + ts_w = 50 + painter.setPen(QColor(t.text_muted)) + ts_font = QFont() + ts_font.setPointSize(8) + painter.setFont(ts_font) + ts_rect = QRect( + rect.x() + rect.width() - self.HPAD - ts_w, + y, ts_w, 18, + ) + painter.drawText(ts_rect, Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter, timestamp) + + name_w = text_w - ts_w - 8 + + # -- Name (line 1) -- + name_font = QFont() + name_font.setPointSize(10) + if unread > 0: + name_font.setBold(True) + painter.setFont(name_font) + painter.setPen(QColor(t.text_primary)) + display_name = name + if is_fav: + display_name = f"\u2605 {name}" + # Elide name to fit + fm = painter.fontMetrics() + elided_name = fm.elidedText(display_name, Qt.TextElideMode.ElideRight, name_w) + painter.drawText(text_x, y, name_w, 22, + Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter, + elided_name) + + # -- Verification badge (after name) -- + if verified == "verified": + name_text_w = fm.horizontalAdvance(elided_name) + badge_x = text_x + name_text_w + 4 + badge_y_center = y + 11 + painter.setPen(Qt.PenStyle.NoPen) + painter.setBrush(QBrush(QColor(t.success))) + painter.drawEllipse(badge_x, badge_y_center - 5, 10, 10) + # Checkmark inside circle + painter.setPen(QPen(QColor(t.bg_primary), 1.5)) + painter.drawLine(badge_x + 2, badge_y_center, badge_x + 4, badge_y_center + 2) + painter.drawLine(badge_x + 4, badge_y_center + 2, badge_x + 8, badge_y_center - 3) + + # -- Preview (line 2) -- + preview_y = y + 24 + preview_font = QFont() + preview_font.setPointSize(9) + painter.setFont(preview_font) + preview_w = text_w - (self.BADGE_SIZE + 8 if unread > 0 else 0) + fm2 = painter.fontMetrics() + receipt = index.data(ROLE_RECEIPT) or "" + preview_x = text_x + if receipt: + check_color = QColor(t.success) if receipt == "read" else QColor(t.text_muted) + painter.setPen(check_color) + single_w = fm2.horizontalAdvance("\u2713") + overlap = single_w * 0.4 + cy = preview_y + 10 # vertical center of 20px row + # First check + painter.drawText(int(preview_x), preview_y, int(single_w), 20, + Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter, + "\u2713") + if receipt in ("delivered", "read"): + # Second check, overlapping + x2 = preview_x + single_w - overlap + painter.drawText(int(x2), preview_y, int(single_w), 20, + Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter, + "\u2713") + total_w = single_w * 2 - overlap + 4 + else: + total_w = single_w + 4 + preview_x += total_w + preview_w -= total_w + preview_x = int(preview_x) + preview_w = int(preview_w) + painter.setPen(QColor(t.text_muted)) + elided_preview = fm2.elidedText(preview, Qt.TextElideMode.ElideRight, preview_w) + painter.drawText(preview_x, preview_y, preview_w, 20, + Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter, + elided_preview) + + # -- Unread badge (bottom-right) -- + if unread > 0: + badge_x = rect.x() + rect.width() - self.HPAD - self.BADGE_SIZE + badge_y = preview_y + 1 + painter.setBrush(QBrush(QColor(t.accent))) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawRoundedRect(badge_x, badge_y, self.BADGE_SIZE, self.BADGE_SIZE, 10, 10) + painter.setPen(QColor(t.accent_text)) + badge_font = QFont() + badge_font.setPointSize(7) + badge_font.setBold(True) + painter.setFont(badge_font) + badge_text = str(unread) if unread < 100 else "99+" + painter.drawText( + QRect(badge_x, badge_y, self.BADGE_SIZE, self.BADGE_SIZE), + Qt.AlignmentFlag.AlignCenter, badge_text, + ) + + # -- Bottom separator line -- + painter.setPen(QColor(t.separator)) + painter.drawLine(text_x, rect.bottom(), rect.right() - self.HPAD, rect.bottom()) + + painter.restore() + + +class _ReceiptFooter(QWidget): + """Tiny widget that draws timestamp + receipt checkmarks with tight spacing.""" + + def __init__(self, time_str: str, status: str, + time_color: str, check_color: str, read_color: str, + parent=None): + super().__init__(parent) + self._time = time_str + self._status = status # "", "sent", "delivered", "read" + self._time_color = QColor(time_color) + self._check_color = QColor(check_color) + self._read_color = QColor(read_color) + self._font = QFont() + self._font.setPointSize(8) + fm = QFontMetricsF(self._font) + tw = fm.horizontalAdvance(self._time + " ") + cw = fm.horizontalAdvance("\u2713") + # 2nd check overlaps 1st by 40% of its width + overlap = cw * 0.4 + checks_w = 0.0 + if status == "sent": + checks_w = cw + elif status in ("delivered", "read"): + checks_w = cw * 2 - overlap + total_w = tw + checks_w + 2 + h = fm.height() + 2 + self.setFixedSize(int(total_w + 1), int(h + 1)) + + def paintEvent(self, event): + p = QPainter(self) + p.setRenderHint(QPainter.RenderHint.Antialiasing) + p.setFont(self._font) + fm = QFontMetricsF(self._font) + y_base = fm.ascent() + 1 + + # Draw time + p.setPen(self._time_color) + p.drawText(QPointF(0, y_base), self._time) + x = fm.horizontalAdvance(self._time) + 4 + + if not self._status: + p.end() + return + + cw = fm.horizontalAdvance("\u2713") + overlap = cw * 0.4 + color = self._read_color if self._status == "read" else self._check_color + + # First check + p.setPen(color) + p.drawText(QPointF(x, y_base), "\u2713") + + # Second check (tight overlap) + if self._status in ("delivered", "read"): + p.drawText(QPointF(x + cw - overlap, y_base), "\u2713") + + p.end() + + +class MessageInput(QTextEdit): + """Multiline message input: Enter sends, Shift+Enter inserts newline.""" + send_requested = pyqtSignal() + file_dropped = pyqtSignal(str) + + @staticmethod + def _style_normal(): + t = c() + return ( + f"QTextEdit {{ background-color: {t.bg_secondary}; border: 1px solid {t.border}; " + f"border-radius: 18px; padding: 8px 14px; color: {t.text_primary}; }}" + f"QTextEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + + @staticmethod + def _style_drop(): + t = c() + return ( + f"QTextEdit {{ background-color: {t.bg_secondary}; border: 2px dashed {t.accent}; " + f"border-radius: 18px; padding: 8px 14px; color: {t.text_primary}; }}" + ) + + def __init__(self, parent=None): + super().__init__(parent) + self.setAcceptRichText(False) + self.setPlaceholderText("Type a message...") + self.setMinimumHeight(52) + self.setMaximumHeight(120) + self.setAcceptDrops(True) + self.drop_enabled = False + self.setStyleSheet(self._style_normal()) + self.textChanged.connect(self._auto_resize) + # Tight line spacing — set on document default cursor format + from PyQt6.QtGui import QTextBlockFormat + fmt = QTextBlockFormat() + fmt.setTopMargin(0) + fmt.setBottomMargin(0) + fmt.setLineHeight(0, 0) # 0 = SingleHeight + self._block_fmt = fmt + # Apply to default block format + cursor = self.textCursor() + cursor.setBlockFormat(fmt) + self.setTextCursor(cursor) + + def _auto_resize(self): + doc_height = int(self.document().size().height()) + 16 # padding + new_h = max(52, min(doc_height, 120)) + self.setFixedHeight(new_h) + + def keyPressEvent(self, event): + if event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter): + if event.modifiers() & Qt.KeyboardModifier.ShiftModifier: + # Insert plain newline instead of new paragraph + self.textCursor().insertText("\n") + return + else: + self.send_requested.emit() + return + super().keyPressEvent(event) + + def dragEnterEvent(self, event): + if not self.drop_enabled: + event.ignore() + return + if event.mimeData().hasUrls() and any(u.isLocalFile() for u in event.mimeData().urls()): + event.acceptProposedAction() + self.setStyleSheet(self._style_drop()) + else: + super().dragEnterEvent(event) + + def dragMoveEvent(self, event): + if event.mimeData().hasUrls(): + event.acceptProposedAction() + else: + super().dragMoveEvent(event) + + def dragLeaveEvent(self, event): + self.setStyleSheet(self._style_normal()) + super().dragLeaveEvent(event) + + def dropEvent(self, event): + self.setStyleSheet(self._style_normal()) + if event.mimeData().hasUrls(): + for url in event.mimeData().urls(): + if url.isLocalFile(): + self.file_dropped.emit(url.toLocalFile()) + event.acceptProposedAction() + else: + super().dropEvent(event) + + +class AsyncBridge(QThread): + """Runs asyncio event loop in a background thread, emits Qt signals.""" + connected = pyqtSignal() + connection_error = pyqtSignal(str) + login_result = pyqtSignal(bool, str) + register_result = pyqtSignal(bool, str) + conversations_loaded = pyqtSignal(list) + messages_loaded = pyqtSignal(str, list) # conv_id, messages + older_messages_loaded = pyqtSignal(str, list) # conv_id, older messages + message_sent = pyqtSignal(bool, str) + message_sent_payload = pyqtSignal(str, dict) # conv_id, message dict (for local append) + new_notification = pyqtSignal(dict) # decrypted payload + pairing_code = pyqtSignal(str, str, bytes) + pairing_complete = pyqtSignal(bool, str) + add_member_result = pyqtSignal(bool, str) + remove_member_result = pyqtSignal(bool, str) + authorize_result = pyqtSignal(bool, str) + rotate_result = pyqtSignal(bool, str) + reencrypt_status = pyqtSignal(str) + messages_read_notification = pyqtSignal(dict) + message_delivered_notification = pyqtSignal(dict) + typing_start_notification = pyqtSignal(dict) + typing_stop_notification = pyqtSignal(dict) + message_deleted_notification = pyqtSignal(dict) + image_sent = pyqtSignal(bool, str) + image_downloaded = pyqtSignal(str, bytes) # file_id, decrypted bytes + image_download_failed = pyqtSignal(str) # file_id + delete_message_result = pyqtSignal(bool, str) + reconnected = pyqtSignal() + conversation_updated = pyqtSignal() + connection_state_changed = pyqtSignal(str) # "connected", "disconnected", "reconnecting" + profile_loaded = pyqtSignal(dict) + profile_updated = pyqtSignal(bool, str) + avatar_loaded = pyqtSignal(str, bytes) # user_id, avatar_bytes + avatar_fetch_failed = pyqtSignal(str) # user_id — allows retry on next refresh + _avatar_changed_signal = pyqtSignal(str) # user_id — pushed when contact changes their avatar + online_status_changed = pyqtSignal(str, bool) # user_id, is_online + online_users_loaded = pyqtSignal(list) # list of user_ids + invitations_loaded = pyqtSignal(list) # list of invitation dicts + invitation_result = pyqtSignal(bool, str) # ok, message + invitation_received = pyqtSignal(dict) # invitation notification data + group_avatar_loaded = pyqtSignal(str, bytes) # conv_id, avatar_bytes + group_avatar_fetch_failed = pyqtSignal(str) # conv_id — allows retry on next refresh + group_avatar_updated = pyqtSignal(bool, str) # ok, message + device_added_notification = pyqtSignal(dict) + session_reset_notification = pyqtSignal(str, str) # from_user_id, from_device_id + reaction_result = pyqtSignal(bool, str) # ok, message + reaction_notification = pyqtSignal(dict) # {message_id, conversation_id, user_id, reaction, action} + pin_notification = pyqtSignal(dict) # {message_id, conversation_id, user_id, action=pin} + unpin_notification = pyqtSignal(dict) # {message_id, conversation_id, user_id, action=unpin} + pinned_messages_loaded = pyqtSignal(str, list) # conv_id, list of pinned msg dicts + forward_result = pyqtSignal(bool, str) # ok, message + key_change_warning = pyqtSignal(str, str, str, bool, bytes) # user_id, username, old_key_hex, was_verified, new_key_bytes + password_changed = pyqtSignal(bool, str) # ok, message + username_changed = pyqtSignal(bool, str) # ok, message + message_queued = pyqtSignal(int) # number of messages currently in retry queue (0 = queue cleared) + + def __init__(self): + super().__init__() + self.client = ChatClient() + self.loop: asyncio.AbstractEventLoop | None = None + self._running = True + self.client._reencrypt_progress_cb = self._emit_reencrypt_status + self.client._key_change_cb = self._emit_key_change_warning + self._ready: asyncio.Event | None = None + self._avatar_inflight: set[str] = set() + self._group_avatar_inflight: set[str] = set() + self._messages_inflight: set[str] = set() # conv_ids currently being fetched + self._invitations_inflight = False + self._pending_send_queue: list[dict] = [] # messages queued for retry after reconnect + + def _emit_reencrypt_status(self, message: str): + self.reencrypt_status.emit(message) + + def _emit_key_change_warning(self, user_id: str, username: str, old_key_hex: str, was_verified: bool, new_key_bytes: bytes = b""): + self.key_change_warning.emit(user_id, username, old_key_hex, was_verified, new_key_bytes) + + def run(self): + if sys.platform == "win32": + self.loop = asyncio.SelectorEventLoop() + else: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self._ready = asyncio.Event() + try: + self.loop.run_until_complete(self._run()) + except Exception as e: + logger.error("AsyncBridge loop crashed: %s", e, exc_info=True) + finally: + self.loop.close() + + async def _run(self): + try: + await self.client.connect() + self.client._listener_task = asyncio.create_task(self.client._background_listener()) + if self._ready: + self._ready.set() + self.connected.emit() + self.connection_state_changed.emit("connected") + except Exception as e: + self.connection_error.emit(str(e)) + return + + # Process notifications + await self._notification_loop() + + async def _notification_loop(self): + while self._running: + try: + # Check if listener task died (connection lost) + if (self.client._listener_task and self.client._listener_task.done() + and not self.client.connected): + self.connection_state_changed.emit("disconnected") + if self.client.session: + await self._auto_reconnect() + continue + + notif = await asyncio.wait_for( + self.client._notification_queue.get(), timeout=0.5 + ) + notif_type = notif.get("type", "") + data = notif.get("data", {}) + if notif_type in ("conversation_created", "member_removed", + "conversation_renamed"): + self.conversation_updated.emit() + elif notif_type == "member_added": + self.conversation_updated.emit() + # Redistribute our sender key to the newly joined member + conv_id = data.get("conversation_id", "") + new_user_id = data.get("user_id", "") + if conv_id and new_user_id: + asyncio.ensure_future( + self.client.redistribute_sender_key_to_member(conv_id, new_user_id) + ) + elif notif_type == "group_invitation": + self.invitation_received.emit(data) + elif notif_type == "device_added": + self.device_added_notification.emit(data) + elif notif_type == "user_online": + self.online_status_changed.emit(data.get("user_id", ""), True) + elif notif_type == "user_offline": + self.online_status_changed.emit(data.get("user_id", ""), False) + elif notif_type == "online_users": + self.online_users_loaded.emit(data.get("user_ids", [])) + elif notif_type == "messages_read": + self.messages_read_notification.emit(data) + elif notif_type == "message_delivered": + self.message_delivered_notification.emit(data) + elif notif_type == "typing_start": + self.typing_start_notification.emit(data) + elif notif_type == "typing_stop": + self.typing_stop_notification.emit(data) + elif notif_type == "message_deleted": + self.message_deleted_notification.emit(data) + elif notif_type == "session_reset": + from_uid = data.get("from_user_id", "") + from_did = data.get("from_device_id", "") + self.client.handle_session_reset_notification(from_uid, from_did or None) + self.session_reset_notification.emit(from_uid, from_did) + elif notif_type == "username_changed": + self.conversation_updated.emit() + elif notif_type == "avatar_changed": + uid = data.get("user_id", "") + if uid: + self._avatar_changed_signal.emit(uid) + elif notif_type == "message_reacted": + self.reaction_notification.emit(data) + elif notif_type == "message_pinned": + self.pin_notification.emit(data) + elif notif_type == "message_unpinned": + self.unpin_notification.emit(data) + elif notif_type == "new_message": + try: + payload = self.client.decrypt_notification(data) + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + continue + if payload: + self.new_notification.emit(payload) + # None = control message (e.g. sender key distribution), skip silently + except asyncio.TimeoutError: + continue + except Exception as e: + logger.error("Notification loop exception: %s", e, exc_info=True) + break + + async def _auto_reconnect(self): + """Auto-reconnect with exponential backoff.""" + delay = 1 + while self._running and not self.client.connected: + self.connection_state_changed.emit("reconnecting") + try: + await self.client.reconnect() + if self.client.connected and self.client.session: + self.connection_state_changed.emit("connected") + self.conversation_updated.emit() + asyncio.ensure_future(self._flush_send_queue()) + return + if self.client.login_rejected: + self.connection_state_changed.emit("revoked") + return + except Exception: + pass + await asyncio.sleep(delay) + delay = min(delay * 2, 30) + + async def _flush_send_queue(self): + """Retry all queued messages after reconnect.""" + if not self._pending_send_queue: + return + queue = self._pending_send_queue[:] + self._pending_send_queue.clear() + self.message_queued.emit(0) # Signal queue cleared + for item in queue: + await self._do_send_message( + item["conv_id"], item["text"], item["members"], item.get("reply_to") + ) + + def schedule(self, coro): + """Schedule a coroutine on the asyncio loop from the Qt thread.""" + if self.loop and self.loop.is_running(): + asyncio.run_coroutine_threadsafe(coro, self.loop) + else: + # Avoid "coroutine was never awaited" warnings if loop is down. + try: + coro.close() + except Exception: + pass + + async def _do_register(self, username, password, email): + if self._ready: + await self._ready.wait() + ok, code_or_msg = await self.client.register(username, password, email=email) + self.register_result.emit(ok, code_or_msg) + + async def _do_login(self, email, password): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.login(email, password) + self.login_result.emit(ok, msg) + + async def _do_logout(self): + if self._ready: + self._ready.clear() + try: + await self.client.close() + except Exception: + pass + self.client = ChatClient() + self.client._reencrypt_progress_cb = self._emit_reencrypt_status + try: + await self.client.connect() + self.client._listener_task = asyncio.create_task(self.client._background_listener()) + if self._ready: + self._ready.set() + self.reconnected.emit() + except Exception as e: + self.connection_error.emit(str(e)) + + async def _do_load_conversations(self): + if self._ready: + await self._ready.wait() + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + async def _do_load_messages(self, conv_id): + if conv_id in self._messages_inflight: + return # Already fetching this conversation + self._messages_inflight.add(conv_id) + try: + if self._ready: + await self._ready.wait() + msgs = await self.client.get_messages(conv_id) + self.messages_loaded.emit(conv_id, msgs) + finally: + self._messages_inflight.discard(conv_id) + + async def _do_load_older_messages(self, conv_id, offset): + if self._ready: + await self._ready.wait() + msgs = await self.client.get_messages(conv_id, limit=50, offset=offset) + self.older_messages_loaded.emit(conv_id, msgs) + + async def _do_send_message(self, conv_id, text, members, reply_to=None): + if self._ready: + await self._ready.wait() + try: + ok, result = await self.client.send_message(conv_id, text, members, reply_to=reply_to) + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + self.message_sent.emit(False, "Identity key changed — accept new key first.") + return + except (ConnectionError, asyncio.TimeoutError) as e: + logger.warning("send_message network error, queuing for retry: %s", e) + self._pending_send_queue.append({ + "conv_id": conv_id, + "text": text, + "members": members, + "reply_to": reply_to, + }) + self.message_queued.emit(len(self._pending_send_queue)) + return + except Exception as e: + logger.error("send_message exception: %s", e, exc_info=True) + self.message_sent.emit(False, str(e)) + return + if ok and isinstance(result, dict): + self.message_sent.emit(True, "Message sent.") + self.message_sent_payload.emit(conv_id, result) + else: + self.message_sent.emit(ok, result if isinstance(result, str) else "Message sent.") + + async def _do_find_or_create_and_send(self, username, text): + if self._ready: + await self._ready.wait() + try: + conv_id, msg = await self.client.find_or_create_conversation(username) + if not conv_id: + self.message_sent.emit(False, msg) + return + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + members = [] + for cv in convs: + if cv["conversation_id"] == conv_id: + members = cv["members"] + break + ok, result = await self.client.send_message(conv_id, text, members) + if ok and isinstance(result, dict): + self.message_sent.emit(True, "Message sent.") + self.message_sent_payload.emit(conv_id, result) + else: + self.message_sent.emit(ok, result if isinstance(result, str) else "Message sent.") + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + self.message_sent.emit(False, "Identity key changed — accept new key first.") + except Exception as e: + logger.error("find_or_create_and_send exception: %s", e, exc_info=True) + self.message_sent.emit(False, str(e)) + + async def _do_create_group(self, members, name=None): + if self._ready: + await self._ready.wait() + conv_id, msg = await self.client.create_conversation(members, name=name) + if conv_id: + self.message_sent.emit(True, f"Group created") + else: + self.message_sent.emit(False, msg) + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + async def _do_link_device(self, username, password): + try: + if self._ready: + await self._ready.wait() + ok, code_or_msg = await self.client.pairing_start(username) + if not ok: + self.pairing_complete.emit(False, code_or_msg) + return + code = code_or_msg + self.pairing_code.emit(code, self.client.pairing_fingerprint(), self.client.pairing_qr_data() or b"") + ok2, msg2 = await self.client.pairing_wait(code, username, password) + self.pairing_complete.emit(ok2, msg2) + except Exception as e: + logger.error("link_device exception: %s", e, exc_info=True) + self.pairing_complete.emit(False, str(e)) + + async def _do_authorize_device(self, code, fingerprint): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.authorize_device(code, fingerprint) + self.authorize_result.emit(ok, msg) + + async def _do_rotate_keys(self, username, password): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.rotate_keys(username, password) + self.rotate_result.emit(ok, msg) + + def do_register(self, username, password, email): + self.schedule(self._do_register(username, password, email)) + + def do_login(self, email, password): + self.schedule(self._do_login(email, password)) + + def load_conversations(self): + self.schedule(self._do_load_conversations()) + + def load_messages(self, conv_id): + self.schedule(self._do_load_messages(conv_id)) + + def load_older_messages(self, conv_id, offset): + self.schedule(self._do_load_older_messages(conv_id, offset)) + + def send_message(self, conv_id, text, members, reply_to=None): + self.schedule(self._do_send_message(conv_id, text, members, reply_to)) + + def send_typing_start(self, conv_id): + self.schedule(self.client.typing_start(conv_id)) + + def send_typing_stop(self, conv_id, force=False): + self.schedule(self.client.typing_stop(conv_id, force=force)) + + def send_new_chat(self, username, text): + self.schedule(self._do_find_or_create_and_send(username, text)) + + def create_group(self, members, name=None): + self.schedule(self._do_create_group(members, name=name)) + + async def _do_add_member(self, conv_id, email): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.add_member(conv_id, email) + self.add_member_result.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def add_member(self, conv_id, email): + self.schedule(self._do_add_member(conv_id, email)) + + async def _do_remove_member(self, conv_id, user_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.remove_member(conv_id, user_id) + self.remove_member_result.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def remove_member(self, conv_id, user_id): + self.schedule(self._do_remove_member(conv_id, user_id)) + + group_left = pyqtSignal(bool, str) + group_renamed = pyqtSignal(bool, str) + conversation_deleted = pyqtSignal(bool, str) + + async def _do_leave_group(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.leave_group(conv_id) + self.group_left.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def leave_group(self, conv_id): + self.schedule(self._do_leave_group(conv_id)) + + async def _do_rename_conversation(self, conv_id, name): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.rename_conversation(conv_id, name) + self.group_renamed.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def rename_conversation(self, conv_id, name): + self.schedule(self._do_rename_conversation(conv_id, name)) + + async def _do_delete_conversation(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.delete_conversation(conv_id) + self.conversation_deleted.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def delete_conversation(self, conv_id): + self.schedule(self._do_delete_conversation(conv_id)) + + def link_device(self, username, password): + self.schedule(self._do_link_device(username, password)) + + def authorize_device(self, code, fingerprint): + self.schedule(self._do_authorize_device(code, fingerprint)) + + def rotate_keys(self, username, password): + self.schedule(self._do_rotate_keys(username, password)) + + async def _do_change_password(self, old_password, new_password): + if self._ready: + await self._ready.wait() + ok, msg = self.client.change_password(old_password, new_password) + self.password_changed.emit(ok, msg) + + def change_password(self, old_password, new_password): + self.schedule(self._do_change_password(old_password, new_password)) + + async def _do_change_username(self, new_username): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.change_username(new_username) + self.username_changed.emit(ok, msg) + + def change_username(self, new_username): + self.schedule(self._do_change_username(new_username)) + + async def _do_delete_message(self, message_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.delete_message(message_id) + self.delete_message_result.emit(ok, msg) + + def delete_message(self, message_id): + self.schedule(self._do_delete_message(message_id)) + + def reset_session(self, peer_user_id, peer_device_id=None): + self.schedule(self.client.reset_session(peer_user_id, peer_device_id)) + + async def _do_send_image(self, conv_id, image_path, members, reply_to=None): + if self._ready: + await self._ready.wait() + try: + ok, result = await self.client.send_image(conv_id, image_path, members, reply_to=reply_to) + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + self.image_sent.emit(False, "Identity key changed — accept new key first.") + return + except Exception as e: + logger.error("send_image exception: %s", e, exc_info=True) + self.image_sent.emit(False, str(e)) + return + if ok and isinstance(result, dict): + self.image_sent.emit(True, "Image sent.") + self.message_sent_payload.emit(conv_id, result) + else: + self.image_sent.emit(ok, result if isinstance(result, str) else "Image sent.") + + def send_image(self, conv_id, image_path, members, reply_to=None): + self.schedule(self._do_send_image(conv_id, image_path, members, reply_to)) + + async def _do_download_image(self, file_id, image_info): + if self._ready: + await self._ready.wait() + data = await self.client.download_image(file_id, image_info) + if data: + self.image_downloaded.emit(file_id, data) + else: + self.image_download_failed.emit(file_id) + + def download_image(self, file_id, image_info): + self.schedule(self._do_download_image(file_id, image_info)) + + file_sent = pyqtSignal(bool, str) + file_downloaded = pyqtSignal(bytes, dict) # decrypted_bytes, file_info + + async def _do_send_file(self, conv_id, file_path, members, reply_to=None): + if self._ready: + await self._ready.wait() + try: + ok, result = await self.client.send_file(conv_id, file_path, members, reply_to=reply_to) + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + self.file_sent.emit(False, "Identity key changed — accept new key first.") + return + except Exception as e: + logger.error("send_file exception: %s", e, exc_info=True) + self.file_sent.emit(False, str(e)) + return + if ok and isinstance(result, dict): + self.file_sent.emit(True, "File sent.") + self.message_sent_payload.emit(conv_id, result) + else: + self.file_sent.emit(ok, result if isinstance(result, str) else "File sent.") + + def send_file(self, conv_id, file_path, members, reply_to=None): + self.schedule(self._do_send_file(conv_id, file_path, members, reply_to)) + + async def _do_download_file(self, file_id, file_info): + if self._ready: + await self._ready.wait() + data = await self.client.download_file(file_id, file_info) + if data: + self.file_downloaded.emit(data, file_info) + + def download_file(self, file_id, file_info): + self.schedule(self._do_download_file(file_id, file_info)) + + async def _do_get_profile(self, user_id=None): + if self._ready: + await self._ready.wait() + profile = await self.client.get_profile(user_id) + if profile: + self.profile_loaded.emit(profile) + + def get_profile(self, user_id=None): + self.schedule(self._do_get_profile(user_id)) + + async def _do_update_profile(self, **fields): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.update_profile(**fields) + self.profile_updated.emit(ok, msg) + + def update_profile(self, **fields): + self.schedule(self._do_update_profile(**fields)) + + async def _do_update_avatar(self, image_data): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.update_avatar(image_data) + if ok: + # Cache own avatar locally so it appears immediately + my_uid = self.client.session.get("user_id", "") if self.client.session else "" + if my_uid: + from protocol import decode_binary, encode_binary + self.avatar_loaded.emit(my_uid, image_data) + self.profile_updated.emit(ok, msg) + + def update_avatar(self, image_data): + self.schedule(self._do_update_avatar(image_data)) + + _MAX_AVATAR_CONCURRENT = 3 + + async def _do_get_avatar(self, user_id): + if self._ready: + await self._ready.wait() + if not user_id or user_id in self._avatar_inflight: + return + # Limit concurrent avatar requests to avoid pipeline congestion + if len(self._avatar_inflight) >= self._MAX_AVATAR_CONCURRENT: + return + self._avatar_inflight.add(user_id) + try: + data = await self.client.get_avatar(user_id) + if data: + self.avatar_loaded.emit(user_id, data) + else: + self.avatar_fetch_failed.emit(user_id) + except Exception: + self.avatar_fetch_failed.emit(user_id) + finally: + self._avatar_inflight.discard(user_id) + + def get_avatar(self, user_id): + self.schedule(self._do_get_avatar(user_id)) + + async def _do_list_invitations(self): + if self._ready: + await self._ready.wait() + if self._invitations_inflight: + return + self._invitations_inflight = True + try: + invitations = await self.client.list_invitations() + self.invitations_loaded.emit(invitations) + finally: + self._invitations_inflight = False + + def list_invitations(self): + self.schedule(self._do_list_invitations()) + + async def _do_accept_invitation(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.accept_invitation(conv_id) + self.invitation_result.emit(ok, msg) + if ok: + invitations = await self.client.list_invitations() + self.invitations_loaded.emit(invitations) + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def accept_invitation(self, conv_id): + self.schedule(self._do_accept_invitation(conv_id)) + + async def _do_decline_invitation(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.decline_invitation(conv_id) + self.invitation_result.emit(ok, msg) + if ok: + invitations = await self.client.list_invitations() + self.invitations_loaded.emit(invitations) + + def decline_invitation(self, conv_id): + self.schedule(self._do_decline_invitation(conv_id)) + + async def _do_update_group_avatar(self, conv_id, image_data): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.update_group_avatar(conv_id, image_data) + self.group_avatar_updated.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def update_group_avatar(self, conv_id, image_data): + self.schedule(self._do_update_group_avatar(conv_id, image_data)) + + async def _do_get_group_avatar(self, conv_id): + if self._ready: + await self._ready.wait() + if not conv_id or conv_id in self._group_avatar_inflight: + return + if len(self._group_avatar_inflight) >= self._MAX_AVATAR_CONCURRENT: + return + self._group_avatar_inflight.add(conv_id) + try: + data = await self.client.get_group_avatar(conv_id) + if data: + self.group_avatar_loaded.emit(conv_id, data) + else: + self.group_avatar_fetch_failed.emit(conv_id) + except Exception: + self.group_avatar_fetch_failed.emit(conv_id) + finally: + self._group_avatar_inflight.discard(conv_id) + + def get_group_avatar(self, conv_id): + self.schedule(self._do_get_group_avatar(conv_id)) + + # --- Reactions, Pins, Forwarding --- + + async def _do_react_message(self, message_id, reaction, action): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.react_message(message_id, reaction, action) + self.reaction_result.emit(ok, msg) + + def react_message(self, message_id, reaction, action="add"): + self.schedule(self._do_react_message(message_id, reaction, action)) + + async def _do_pin_message(self, message_id, conversation_id, action): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.pin_message(message_id, conversation_id, action) + if not ok: + self.reaction_result.emit(False, msg) + + def pin_message(self, message_id, conversation_id, action="pin"): + self.schedule(self._do_pin_message(message_id, conversation_id, action)) + + async def _do_get_pinned_messages(self, conv_id): + if self._ready: + await self._ready.wait() + pinned = await self.client.get_pinned_messages(conv_id) + self.pinned_messages_loaded.emit(conv_id, pinned) + + def get_pinned_messages(self, conv_id): + self.schedule(self._do_get_pinned_messages(conv_id)) + + async def _do_forward_message(self, target_conv_id, original_msg, target_members): + if self._ready: + await self._ready.wait() + try: + ok, result = await self.client.forward_message(target_conv_id, original_msg, target_members) + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + self.forward_result.emit(False, "Identity key changed — accept new key first.") + return + except Exception as e: + logger.error("forward_message exception: %s", e, exc_info=True) + self.forward_result.emit(False, str(e)) + return + if ok and isinstance(result, dict): + self.forward_result.emit(True, "Message forwarded.") + self.message_sent_payload.emit(target_conv_id, result) + else: + self.forward_result.emit(ok, result if isinstance(result, str) else "Forwarded.") + + def forward_message(self, target_conv_id, original_msg, target_members): + self.schedule(self._do_forward_message(target_conv_id, original_msg, target_members)) + + def logout(self): + self.schedule(self._do_logout()) + + def stop(self): + self._running = False + if self.loop: + asyncio.run_coroutine_threadsafe(self.client.close(), self.loop) + + +def _make_frameless(dlg: QDialog, title_text: str = ""): + """Configure a QDialog as frameless with custom title bar, rounded container, + and drop shadow. Returns a QVBoxLayout for the dialog content area — + callers just add their widgets to the returned layout. + + Usage:: + + dlg = QDialog(self) + dlg.setMinimumWidth(380) + content_layout = _make_frameless(dlg, "My Title") + content_layout.addWidget(QLabel("Hello!")) + dlg.exec() + """ + dlg.setWindowFlags( + Qt.WindowType.FramelessWindowHint | Qt.WindowType.Dialog + ) + dlg.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground) + dlg._drag_pos = None + + t = c() + + # -- Outer layout (transparent, holds container with margins for shadow) -- + outer = QVBoxLayout(dlg) + outer.setContentsMargins(12, 12, 12, 12) + outer.setSpacing(0) + + # -- Rounded container -- + container = QWidget() + container.setObjectName("_framelessContainer") + container.setStyleSheet( + f"#_framelessContainer {{ background-color: {t.bg_primary}; border-radius: 12px; }}" + ) + shadow = QGraphicsDropShadowEffect(container) + shadow.setBlurRadius(24) + shadow.setOffset(0, 4) + shadow.setColor(QColor(0, 0, 0, 80)) + container.setGraphicsEffect(shadow) + + container_lay = QVBoxLayout(container) + container_lay.setContentsMargins(0, 0, 0, 0) + container_lay.setSpacing(0) + + # -- Title bar -- + title_bar = QWidget() + title_bar.setFixedHeight(40) + title_bar.setStyleSheet( + f"background-color: {t.bg_secondary}; " + f"border-top-left-radius: 12px; border-top-right-radius: 12px;" + ) + bar_layout = QHBoxLayout(title_bar) + bar_layout.setContentsMargins(16, 0, 8, 0) + bar_layout.setSpacing(0) + + title_label = QLabel(title_text) + title_label.setStyleSheet( + f"color: {t.text_primary}; font-size: 11pt; font-weight: bold; " + f"background: transparent;" + ) + bar_layout.addWidget(title_label) + bar_layout.addStretch() + + close_btn = QPushButton("\u2715") + close_btn.setFixedSize(28, 28) + close_btn.setCursor(Qt.CursorShape.PointingHandCursor) + close_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.text_muted}; " + f"border: none; border-radius: 14px; font-size: 12pt; }}" + f"QPushButton:hover {{ background-color: {t.error}; color: {t.accent_text}; }}" + ) + close_btn.clicked.connect(dlg.reject) + bar_layout.addWidget(close_btn) + + # Dragging via title bar + def _mouse_press(event): + if event.button() == Qt.MouseButton.LeftButton: + dlg._drag_pos = event.globalPosition().toPoint() - dlg.frameGeometry().topLeft() + event.accept() + def _mouse_move(event): + if dlg._drag_pos is not None and event.buttons() & Qt.MouseButton.LeftButton: + dlg.move(event.globalPosition().toPoint() - dlg._drag_pos) + event.accept() + def _mouse_release(event): + dlg._drag_pos = None + title_bar.mousePressEvent = _mouse_press + title_bar.mouseMoveEvent = _mouse_move + title_bar.mouseReleaseEvent = _mouse_release + + container_lay.addWidget(title_bar) + + # -- Content widget -- + content = QWidget() + content_layout = QVBoxLayout(content) + content_layout.setContentsMargins(16, 12, 16, 16) + content_layout.setSpacing(8) + container_lay.addWidget(content) + + outer.addWidget(container) + + # Store refs for later theming + dlg._frameless_container = container + dlg._frameless_title_bar = title_bar + dlg._frameless_title_label = title_label + return content_layout + + +class UserProfileDialog(QDialog): + """Dialog for viewing/editing user profiles.""" + + def __init__(self, bridge: AsyncBridge, user_id: str, editable: bool = False, parent=None): + super().__init__(parent) + self.bridge = bridge + self.user_id = user_id + self.editable = editable + self.setMinimumWidth(400) + self._build_ui() + self._connect_signals() + self.bridge.get_profile(user_id) + + def _build_ui(self): + t = c() + title_text = "Edit Profile" if self.editable else "User Profile" + self.layout_main = _make_frameless(self, title_text) + self.layout_main.setSpacing(12) + + # Avatar + self.avatar_label = QLabel() + self.avatar_label.setFixedSize(80, 80) + self.avatar_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + t = c() + self.avatar_label.setStyleSheet( + f"background-color: {t.bg_secondary}; border-radius: 40px; " + f"font-size: 21pt; color: {t.accent};" + ) + self.avatar_label.setText("?") + self.layout_main.addWidget(self.avatar_label, alignment=Qt.AlignmentFlag.AlignCenter) + + if self.editable: + avatar_btn = QPushButton("Change Avatar") + avatar_btn.setObjectName("secondaryBtn") + avatar_btn.clicked.connect(self._on_change_avatar) + self.layout_main.addWidget(avatar_btn, alignment=Qt.AlignmentFlag.AlignCenter) + + # Info fields + self.username_label = QLabel("") + self.username_label.setStyleSheet(f"font-size: 14pt; font-weight: bold; color: {t.accent};") + self.username_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.layout_main.addWidget(self.username_label) + + self.info_area = QVBoxLayout() + self.layout_main.addLayout(self.info_area) + + # Editable fields (only shown in edit mode) + if self.editable: + self.layout_main.addSpacing(8) + + form_label = QLabel("Profile Settings") + form_label.setStyleSheet(f"font-weight: bold; color: {t.accent};") + self.layout_main.addWidget(form_label) + + self.phone_input = QLineEdit() + self.phone_input.setPlaceholderText("Phone number") + self.layout_main.addWidget(self.phone_input) + + self.location_input = QLineEdit() + self.location_input.setPlaceholderText("Location") + self.layout_main.addWidget(self.location_input) + + from PyQt6.QtWidgets import QCheckBox + self.email_visible_cb = QCheckBox("Email visible to others") + self.email_visible_cb.setStyleSheet(f"color: {t.text_primary};") + self.layout_main.addWidget(self.email_visible_cb) + + self.phone_visible_cb = QCheckBox("Phone visible to others") + self.phone_visible_cb.setStyleSheet(f"color: {t.text_primary};") + self.layout_main.addWidget(self.phone_visible_cb) + + self.location_visible_cb = QCheckBox("Location visible to others") + self.location_visible_cb.setStyleSheet(f"color: {t.text_primary};") + self.layout_main.addWidget(self.location_visible_cb) + + save_btn = QPushButton("Save") + save_btn.clicked.connect(self._on_save) + self.layout_main.addWidget(save_btn) + + # Security section (only when viewing another user, not own profile) + if not self.editable: + self._security_section = QVBoxLayout() + self.layout_main.addLayout(self._security_section) + + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(self.accept) + self.layout_main.addWidget(close_btn) + + def _connect_signals(self): + self.bridge.profile_loaded.connect(self._on_profile_loaded) + self.bridge.avatar_loaded.connect(self._on_avatar_loaded) + self.bridge.profile_updated.connect(self._on_profile_updated) + + def _on_profile_loaded(self, profile): + if profile.get("user_id") != self.user_id: + return + username = profile.get("username", "?") + self.username_label.setText(username) + # Set avatar initial + self.avatar_label.setText(username[0].upper() if username else "?") + + # Clear info area + while self.info_area.count(): + item = self.info_area.takeAt(0) + if item.widget(): + item.widget().deleteLater() + + # Email + email = profile.get("email") + if email: + self.info_area.addWidget(QLabel(f"Email: {email}")) + + # Phone + phone = profile.get("phone") + if phone: + self.info_area.addWidget(QLabel(f"Phone: {phone}")) + + # Location + location = profile.get("location") + if location: + self.info_area.addWidget(QLabel(f"Location: {location}")) + + # Member since + created_at = profile.get("created_at", "") + if created_at: + date_str = created_at[:10] if len(created_at) >= 10 else created_at + label = QLabel(f"Member since: {date_str}") + label.setStyleSheet(f"color: {c().text_muted};") + self.info_area.addWidget(label) + + # Populate editable fields + if self.editable: + self.phone_input.setText(phone or "") + self.location_input.setText(location or "") + self.email_visible_cb.setChecked(bool(profile.get("email_visible", 1))) + self.phone_visible_cb.setChecked(bool(profile.get("phone_visible", 0))) + self.location_visible_cb.setChecked(bool(profile.get("location_visible", 0))) + + # Try to load avatar + if profile.get("avatar_file"): + self.bridge.get_avatar(self.user_id) + + # Security section (viewing another user, not self) + my_uid = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + if not self.editable and hasattr(self, "_security_section") and self.user_id != my_uid: + self._populate_security_section() + + def _populate_security_section(self): + """Populate security/verification info for a peer user.""" + t = c() + sec = self._security_section + # Clear previous contents + while sec.count(): + item = sec.takeAt(0) + if item.widget(): + item.widget().deleteLater() + + sec.addSpacing(8) + header = QLabel("Security") + header.setStyleSheet(f"font-weight: bold; color: {t.accent};") + sec.addWidget(header) + + status = self.bridge.client.get_verification_status(self.user_id) + if status == "verified": + status_label = QLabel("\u2705 Identity verified") + status_label.setStyleSheet(f"color: {t.success}; font-weight: bold;") + elif status == "trusted": + status_label = QLabel("\U0001f512 Trusted (first use)") + status_label.setStyleSheet(f"color: {t.warning};") + else: + status_label = QLabel("\u26A0 Unverified") + status_label.setStyleSheet(f"color: {t.text_muted};") + sec.addWidget(status_label) + + fp = self.bridge.client.get_peer_fingerprint(self.user_id) + if fp: + fp_label = QLabel(f"Fingerprint:\n{fp}") + fp_label.setStyleSheet( + f"font-family: monospace; font-size: 8pt; color: {t.text_primary}; " + f"background: {t.bg_secondary}; padding: 4px; border-radius: 4px;" + ) + fp_label.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + sec.addWidget(fp_label) + + def _on_avatar_loaded(self, user_id, data): + if user_id != self.user_id: + return + qimg = _safe_load_image(data) + if qimg is not None: + pixmap = QPixmap.fromImage(qimg) + # Circular crop + size = 80 + scaled = pixmap.scaled(size, size, Qt.AspectRatioMode.KeepAspectRatioByExpanding, + Qt.TransformationMode.SmoothTransformation) + result = QPixmap(size, size) + result.fill(QColor(0, 0, 0, 0)) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + painter.setBrush(QBrush(scaled)) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(0, 0, size, size) + painter.end() + self.avatar_label.setPixmap(result) + + def _on_change_avatar(self): + path, _ = QFileDialog.getOpenFileName( + self, "Select Avatar", "", + "Images (*.png *.jpg *.jpeg);;All Files (*)", + ) + if not path: + return + try: + with open(path, "rb") as f: + data = f.read() + if len(data) > 2 * 1024 * 1024: + QMessageBox.warning(self, "Error", "Avatar too large (max 2 MB).") + return + self.bridge.update_avatar(data) + except Exception as e: + QMessageBox.warning(self, "Error", f"Failed to read file: {e}") + + def _on_save(self): + fields = { + "phone": self.phone_input.text().strip() or None, + "location": self.location_input.text().strip() or None, + "email_visible": 1 if self.email_visible_cb.isChecked() else 0, + "phone_visible": 1 if self.phone_visible_cb.isChecked() else 0, + "location_visible": 1 if self.location_visible_cb.isChecked() else 0, + } + self.bridge.update_profile(**fields) + + def _on_profile_updated(self, ok, msg): + if ok: + # Refresh profile + self.bridge.get_profile(self.user_id) + else: + QMessageBox.warning(self, "Error", msg) + + def closeEvent(self, event): + # Disconnect signals to avoid stale references + try: + self.bridge.profile_loaded.disconnect(self._on_profile_loaded) + self.bridge.avatar_loaded.disconnect(self._on_avatar_loaded) + self.bridge.profile_updated.disconnect(self._on_profile_updated) + except Exception: + pass + super().closeEvent(event) + + def reject(self): + try: + self.bridge.profile_loaded.disconnect(self._on_profile_loaded) + self.bridge.avatar_loaded.disconnect(self._on_avatar_loaded) + self.bridge.profile_updated.disconnect(self._on_profile_updated) + except Exception: + pass + super().reject() + + def accept(self): + try: + self.bridge.profile_loaded.disconnect(self._on_profile_loaded) + self.bridge.avatar_loaded.disconnect(self._on_avatar_loaded) + self.bridge.profile_updated.disconnect(self._on_profile_updated) + except Exception: + pass + super().accept() + + +class VerificationDialog(QDialog): + """Dialog for viewing safety numbers, fingerprints, QR codes, and verifying contacts.""" + + def __init__(self, bridge: AsyncBridge, peer_user_id: str, peer_name: str, parent=None): + super().__init__(parent) + self.bridge = bridge + self.peer_user_id = peer_user_id + self.peer_name = peer_name + self.setMinimumWidth(420) + self._build_ui() + + def _build_ui(self): + t = c() + lay = _make_frameless(self, "Verify Contact") + lay.setSpacing(10) + + # Peer name + name_label = QLabel(self.peer_name) + name_label.setStyleSheet(f"font-size: 14pt; font-weight: bold; color: {t.accent};") + name_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + lay.addWidget(name_label) + + # Verification status + status = self.bridge.client.get_verification_status(self.peer_user_id) + if status == "verified": + status_text = "\u2705 Verified" + status_color = t.success + elif status == "trusted": + status_text = "\U0001f512 Trusted (TOFU)" + status_color = t.warning + else: + status_text = "\u26A0 Unverified" + status_color = t.error + self._status_label = QLabel(status_text) + self._status_label.setStyleSheet(f"font-size: 11pt; font-weight: bold; color: {status_color};") + self._status_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + lay.addWidget(self._status_label) + + lay.addSpacing(4) + + # Safety Number + safety = self.bridge.client.get_safety_number(self.peer_user_id) + if safety: + sn_header = QLabel("Safety Number") + sn_header.setStyleSheet(f"font-weight: bold; color: {t.text_secondary};") + lay.addWidget(sn_header) + + sn_label = QLabel(safety) + sn_label.setStyleSheet( + f"font-family: monospace; font-size: 13pt; letter-spacing: 2px; " + f"color: {t.text_primary}; background: {t.bg_secondary}; " + f"padding: 12px; border-radius: 8px;" + ) + sn_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + sn_label.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + lay.addWidget(sn_label) + + lay.addSpacing(4) + + # QR Code + qr_data = self.bridge.client.get_verification_qr_data() + if qr_data: + qr_header = QLabel("Your QR Code (for peer to scan)") + qr_header.setStyleSheet(f"font-weight: bold; color: {t.text_secondary};") + lay.addWidget(qr_header) + + qr_pixmap = self._generate_qr_pixmap(qr_data) + if qr_pixmap: + qr_label = QLabel() + qr_label.setPixmap(qr_pixmap) + qr_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + lay.addWidget(qr_label) + + save_qr_btn = QPushButton("Save QR Code") + save_qr_btn.setObjectName("secondaryBtn") + save_qr_btn.clicked.connect(lambda: self._save_qr(qr_pixmap)) + lay.addWidget(save_qr_btn) + + lay.addSpacing(4) + + # Fingerprints + my_fp = self.bridge.client.get_my_fingerprint() + peer_fp = self.bridge.client.get_peer_fingerprint(self.peer_user_id) + + if my_fp or peer_fp: + fp_header = QLabel("Fingerprints") + fp_header.setStyleSheet(f"font-weight: bold; color: {t.text_secondary};") + lay.addWidget(fp_header) + + if my_fp: + my_fp_label = QLabel(f"Yours:\n{my_fp}") + my_fp_label.setStyleSheet( + f"font-family: monospace; font-size: 9pt; color: {t.text_primary}; " + f"background: {t.bg_secondary}; padding: 6px; border-radius: 4px;" + ) + my_fp_label.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + lay.addWidget(my_fp_label) + + if peer_fp: + peer_name_esc = self.peer_name.replace("&", "&").replace("<", "<") + peer_fp_label = QLabel(f"{peer_name_esc}:\n{peer_fp}") + peer_fp_label.setStyleSheet( + f"font-family: monospace; font-size: 9pt; color: {t.text_primary}; " + f"background: {t.bg_secondary}; padding: 6px; border-radius: 4px;" + ) + peer_fp_label.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + lay.addWidget(peer_fp_label) + + lay.addSpacing(8) + + # Action buttons + btn_row = QHBoxLayout() + + if status != "verified": + verify_btn = QPushButton("Mark as Verified") + verify_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.success}; color: {t.bg_primary}; " + f"font-weight: bold; padding: 8px 16px; border-radius: 6px; }}" + f"QPushButton:hover {{ opacity: 0.9; }}" + ) + verify_btn.clicked.connect(self._on_verify) + btn_row.addWidget(verify_btn) + else: + unverify_btn = QPushButton("Remove Verification") + unverify_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.warning}; color: {t.bg_primary}; " + f"font-weight: bold; padding: 8px 16px; border-radius: 6px; }}" + ) + unverify_btn.clicked.connect(self._on_unverify) + btn_row.addWidget(unverify_btn) + + # Scan QR button + scan_btn = QPushButton("Scan QR Code") + scan_btn.setObjectName("secondaryBtn") + scan_btn.clicked.connect(self._on_scan_qr) + btn_row.addWidget(scan_btn) + + lay.addLayout(btn_row) + + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(self.accept) + lay.addWidget(close_btn) + + def _generate_qr_pixmap(self, data: bytes) -> QPixmap | None: + """Generate a QR code QPixmap from raw bytes (base64-encoded for scanner compat).""" + try: + import qrcode + import base64 + from io import BytesIO + # Encode as base64 — raw binary gets corrupted by QR readers (UTF-8 re-encoding) + qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, + box_size=4, border=2) + qr.add_data(base64.b64encode(data).decode("ascii")) + qr.make(fit=True) + img = qr.make_image(fill_color="black", back_color="white") + buf = BytesIO() + img.save(buf, format="PNG") + buf.seek(0) + qimg = QImage() + qimg.loadFromData(buf.getvalue()) + if qimg.isNull(): + return None + return QPixmap.fromImage(qimg) + except ImportError: + return None + except Exception: + return None + + def _save_qr(self, pixmap: QPixmap): + """Save QR code image to file.""" + path, _ = QFileDialog.getSaveFileName( + self, "Save QR Code", "verification_qr.png", + "PNG Images (*.png);;All Files (*)", + ) + if path: + pixmap.save(path, "PNG") + + def _on_verify(self): + cached = self.bridge.client._user_cache.get(self.peer_user_id) + if cached and cached.get("identity_key_bytes"): + self.bridge.client.verify_contact( + self.peer_user_id, cached["identity_key_bytes"], method="safety_number" + ) + t = c() + self._status_label.setText("\u2705 Verified") + self._status_label.setStyleSheet(f"font-size: 11pt; font-weight: bold; color: {t.success};") + + def _on_unverify(self): + self.bridge.client.unverify_contact(self.peer_user_id) + t = c() + self._status_label.setText("\U0001f512 Trusted (TOFU)") + self._status_label.setStyleSheet(f"font-size: 11pt; font-weight: bold; color: {t.warning};") + + def _on_scan_qr(self): + """Open file picker for QR code image, decode and verify.""" + path, _ = QFileDialog.getOpenFileName( + self, "Select QR Code Image", "", + "Images (*.png *.jpg *.jpeg *.bmp);;All Files (*)", + ) + if not path: + return + try: + from PIL import Image + pil_img = Image.open(path) + except Exception as e: + QMessageBox.warning(self, "Error", f"Failed to open image: {e}") + return + # Try pyzbar first, fall back to manual decode + qr_text = None + try: + from pyzbar.pyzbar import decode as pyzbar_decode + results = pyzbar_decode(pil_img) + if results: + qr_text = results[0].data + except ImportError: + pass + if qr_text is None: + QMessageBox.information( + self, "QR Scan", + "Could not decode QR code. Install 'pyzbar' for QR scanning support, " + "or verify manually using the safety number above." + ) + return + # QR contains base64-encoded binary payload + import base64 + try: + qr_data = base64.b64decode(qr_text) + except Exception: + QMessageBox.warning(self, "Error", "Invalid QR code format.") + return + ok, user_id, message = self.bridge.client.verify_qr_code(qr_data) + if ok: + t = c() + self._status_label.setText("\u2705 Verified") + self._status_label.setStyleSheet(f"font-size: 11pt; font-weight: bold; color: {t.success};") + QMessageBox.information(self, "Verification", message) + else: + QMessageBox.warning(self, "Verification Failed", message) + + +class LoginWindow(QWidget): + def __init__(self, bridge: AsyncBridge): + super().__init__() + self.bridge = bridge + self.setWindowTitle("Encrypted Chat - Login") + self.setFixedSize(500, 540) + self._pair_email = "" + self._pair_password = "" + self._build_ui() + tm().on_change(self._apply_theme) + + def _login_card_qss(self): + t = c() + return ( + f"#loginCard {{ background-color: {t.bg_primary}; border-radius: 16px; }}" + f"#loginCard QWidget {{ background: transparent; }}" + f"#loginCard QLabel {{ background: transparent; color: {t.text_primary}; }}" + f"#loginCard QLineEdit {{" + f" background-color: {t.bg_secondary}; color: {t.text_primary};" + f" border: 1px solid {t.border}; border-radius: 6px; padding: 8px;" + f"}}" + f"#loginCard QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + f"#loginCard QPushButton {{" + f" background-color: {t.accent}; color: {t.accent_text};" + f" border: none; border-radius: 6px; padding: 8px 16px; font-weight: bold;" + f"}}" + f"#loginCard QPushButton:hover {{ background-color: {t.accent_hover}; }}" + ) + + def _tab_bar_qss(self): + t = c() + return ( + f"QTabBar {{ background: transparent; border: none; }}" + f"QTabBar::tab {{ background: transparent; color: {t.text_muted}; " + f"padding: 10px 24px; font-size: 10pt; border: none; " + f"border-bottom: 2px solid transparent; }}" + f"QTabBar::tab:selected {{ color: {t.accent}; font-weight: bold; " + f"border-bottom: 2px solid {t.accent}; }}" + f"QTabBar::tab:hover {{ color: {t.text_primary}; }}" + f"QTabWidget::pane {{ border: none; background: transparent; }}" + ) + + def _apply_theme(self): + QApplication.instance().setStyleSheet(qss()) + t = c() + self.setStyleSheet(f"background-color: {t.bg_tertiary};") + self._card.setStyleSheet(self._login_card_qss()) + self._subtitle.setStyleSheet(f"color: {t.text_muted}; font-size: 9pt; margin-bottom: 8px;") + self._theme_btn.setText("\u2600" if tm().is_dark else "\U0001f319") + self._theme_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.text_muted}; " + f"border: none; font-size: 14pt; }}" + f"QPushButton:hover {{ color: {t.accent}; }}" + ) + self._tabs.setStyleSheet(self._tab_bar_qss()) + # Verification page + self._step_label.setStyleSheet(f"color: {t.accent}; font-weight: bold; font-size: 10pt;") + self._info_label.setStyleSheet(f"color: {t.text_primary}; font-size: 10pt;") + self.code_input.setStyleSheet( + f"QLineEdit {{ font-size: 16pt; letter-spacing: 8px; text-align: center; " + f"background-color: {t.bg_secondary}; border: 1px solid {t.border}; " + f"border-radius: 6px; padding: 12px; color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + + def _build_ui(self): + from PyQt6.QtWidgets import QTabWidget + outer = QVBoxLayout(self) + outer.setContentsMargins(0, 0, 0, 0) + t = c() + + # Background fills entire window + self.setStyleSheet(f"background-color: {t.bg_tertiary};") + + # Theme toggle in top-right corner + top_row = QHBoxLayout() + top_row.setContentsMargins(12, 8, 12, 0) + top_row.addStretch() + self._theme_btn = QPushButton("\u2600" if tm().is_dark else "\U0001f319") + self._theme_btn.setFixedSize(36, 36) + self._theme_btn.setToolTip("Toggle light/dark mode") + self._theme_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.text_muted}; " + f"border: none; font-size: 14pt; }}" + f"QPushButton:hover {{ color: {t.accent}; }}" + ) + self._theme_btn.clicked.connect(tm().toggle) + top_row.addWidget(self._theme_btn) + outer.addLayout(top_row) + + self.stack = QStackedWidget() + outer.addWidget(self.stack) + + # --- Page 0: Login / Register form (card) --- + page0 = QWidget() + page0_layout = QVBoxLayout(page0) + page0_layout.setContentsMargins(40, 0, 40, 20) + page0_layout.setAlignment(Qt.AlignmentFlag.AlignCenter) + + self._card = QWidget() + self._card.setObjectName("loginCard") + self._card.setStyleSheet(self._login_card_qss()) + card_layout = QVBoxLayout(self._card) + card_layout.setSpacing(10) + card_layout.setContentsMargins(36, 28, 36, 28) + + title = QLabel("Encrypted Chat") + title.setObjectName("title") + title.setAlignment(Qt.AlignmentFlag.AlignCenter) + card_layout.addWidget(title) + + self._subtitle = QLabel("End-to-end encrypted messaging") + self._subtitle.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._subtitle.setStyleSheet(f"color: {t.text_muted}; font-size: 9pt; margin-bottom: 8px;") + card_layout.addWidget(self._subtitle) + + # --- Tabs: Login | Register | Link Device --- + self._tabs = QTabWidget() + self._tabs.setMinimumHeight(220) + self._tabs.setStyleSheet(self._tab_bar_qss()) + + # == Login tab == + login_tab = QWidget() + login_lay = QVBoxLayout(login_tab) + login_lay.setSpacing(12) + login_lay.setContentsMargins(0, 12, 0, 4) + + self.email_input = QLineEdit() + self.email_input.setPlaceholderText("Email") + login_lay.addWidget(self.email_input) + + self.password_input = QLineEdit() + self.password_input.setPlaceholderText("Password") + self.password_input.setEchoMode(QLineEdit.EchoMode.Password) + self.password_input.returnPressed.connect(self._on_login) + login_lay.addWidget(self.password_input) + + login_lay.addStretch() + self.login_btn = QPushButton("Login") + self.login_btn.setMinimumHeight(40) + self.login_btn.clicked.connect(self._on_login) + login_lay.addWidget(self.login_btn) + + self._tabs.addTab(login_tab, "Login") + + # == Register tab == + reg_tab = QWidget() + reg_lay = QVBoxLayout(reg_tab) + reg_lay.setSpacing(12) + reg_lay.setContentsMargins(0, 12, 0, 4) + + self.username_input = QLineEdit() + self.username_input.setPlaceholderText("Username (display name)") + reg_lay.addWidget(self.username_input) + + self._reg_email_input = QLineEdit() + self._reg_email_input.setPlaceholderText("Email") + reg_lay.addWidget(self._reg_email_input) + + self._reg_password_input = QLineEdit() + self._reg_password_input.setPlaceholderText("Password") + self._reg_password_input.setEchoMode(QLineEdit.EchoMode.Password) + self._reg_password_input.returnPressed.connect(self._on_register) + reg_lay.addWidget(self._reg_password_input) + + reg_lay.addStretch() + self.register_btn = QPushButton("Register") + self.register_btn.setMinimumHeight(40) + self.register_btn.clicked.connect(self._on_register) + reg_lay.addWidget(self.register_btn) + + self._tabs.addTab(reg_tab, "Register") + + # == Link Device tab == + link_tab = QWidget() + link_lay = QVBoxLayout(link_tab) + link_lay.setSpacing(12) + link_lay.setContentsMargins(0, 12, 0, 4) + + self._link_email_input = QLineEdit() + self._link_email_input.setPlaceholderText("Email") + link_lay.addWidget(self._link_email_input) + + self._link_password_input = QLineEdit() + self._link_password_input.setPlaceholderText("Password") + self._link_password_input.setEchoMode(QLineEdit.EchoMode.Password) + self._link_password_input.returnPressed.connect(self._on_link_device) + link_lay.addWidget(self._link_password_input) + + link_lay.addStretch() + self.link_btn = QPushButton("Link Device") + self.link_btn.setMinimumHeight(40) + self.link_btn.clicked.connect(self._on_link_device) + link_lay.addWidget(self.link_btn) + + self._tabs.addTab(link_tab, "Link Device") + + card_layout.addWidget(self._tabs) + + self.status_label = QLabel("") + self.status_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.status_label.setWordWrap(True) + card_layout.addWidget(self.status_label) + + page0_layout.addWidget(self._card) + self.stack.addWidget(page0) + + # --- Page 1: Verification code form --- + page1 = QWidget() + vl = QVBoxLayout(page1) + vl.setSpacing(14) + vl.setContentsMargins(50, 40, 50, 40) + + self._step_label = QLabel("Step 2 of 2") + self._step_label.setStyleSheet(f"color: {t.accent}; font-weight: bold; font-size: 10pt;") + self._step_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + vl.addWidget(self._step_label) + + self._info_label = QLabel("Enter the 6-digit verification code sent to your email") + self._info_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._info_label.setWordWrap(True) + self._info_label.setStyleSheet(f"color: {t.text_primary}; font-size: 10pt;") + vl.addWidget(self._info_label) + + vl.addSpacing(12) + + self.code_input = QLineEdit() + self.code_input.setPlaceholderText("000000") + self.code_input.setMaxLength(6) + self.code_input.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.code_input.setStyleSheet( + f"QLineEdit {{ font-size: 16pt; letter-spacing: 8px; text-align: center; " + f"background-color: {t.bg_secondary}; border: 1px solid {t.border}; " + f"border-radius: 6px; padding: 12px; color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + self.code_input.returnPressed.connect(self._on_confirm_code) + vl.addWidget(self.code_input) + + vl.addSpacing(8) + + self.code_status_label = QLabel("") + self.code_status_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.code_status_label.setWordWrap(True) + vl.addWidget(self.code_status_label) + + code_btn_row = QHBoxLayout() + self.back_btn = QPushButton("Back") + self.back_btn.setObjectName("secondaryBtn") + self.back_btn.clicked.connect(self._on_back_to_login) + code_btn_row.addWidget(self.back_btn) + + self.confirm_btn = QPushButton("Confirm") + self.confirm_btn.clicked.connect(self._on_confirm_code) + code_btn_row.addWidget(self.confirm_btn) + vl.addLayout(code_btn_row) + + vl.addStretch() + self.stack.addWidget(page1) + + def show_verification_page(self, message=""): + """Switch to verification code page.""" + self.code_input.clear() + self.code_status_label.setText(message) + self.code_status_label.setStyleSheet(f"color: {c().success};") + self.stack.setCurrentIndex(1) + self.code_input.setFocus() + + def _on_confirm_code(self): + code = self.code_input.text().strip() + if not code: + self.code_status_label.setText("Please enter the code.") + self.code_status_label.setStyleSheet(f"color: {c().error};") + return + self.code_status_label.setText("Confirming...") + self.code_status_label.setStyleSheet(f"color: {c().success};") + self.confirm_btn.setEnabled(False) + self.back_btn.setEnabled(False) + # Callback set by main() to handle confirmation + if hasattr(self, '_confirm_callback'): + self._confirm_callback(code) + + def _on_back_to_login(self): + self.stack.setCurrentIndex(0) + self._set_enabled(True) + self.status_label.setText("") + self.status_label.setStyleSheet("") + + def _on_register(self): + username = self.username_input.text().strip() + password = self._reg_password_input.text() + email = self._reg_email_input.text().strip() + if not username: + self.show_error("Username required.") + return + if not email or not password: + self.show_error("Email and password required.") + return + self.status_label.setText("Registering...") + self._set_enabled(False) + self.bridge.do_register(username, password, email) + + def _on_login(self): + email = self.email_input.text().strip() + password = self.password_input.text() + if not email or not password: + self.show_error("Email and password required.") + return + self.status_label.setText("Logging in...") + self._set_enabled(False) + self.bridge.do_login(email, password) + + def _on_link_device(self): + email = self._link_email_input.text().strip() + password = self._link_password_input.text() + if not email or not password: + self.show_error("Email and password required.") + return + self._pair_email = email + self._pair_password = password + self.status_label.setText("Generating pairing code...") + self._set_enabled(False) + self.bridge.link_device(email, password) + + def _set_enabled(self, enabled): + self._tabs.setEnabled(enabled) + self.username_input.setEnabled(enabled) + self.email_input.setEnabled(enabled) + self.password_input.setEnabled(enabled) + self._reg_email_input.setEnabled(enabled) + self._reg_password_input.setEnabled(enabled) + self._link_email_input.setEnabled(enabled) + self._link_password_input.setEnabled(enabled) + self.register_btn.setEnabled(enabled) + self.login_btn.setEnabled(enabled) + self.link_btn.setEnabled(enabled) + + def show_error(self, msg): + if self.stack.currentIndex() == 1: + self.code_status_label.setText(msg) + self.code_status_label.setStyleSheet(f"color: {c().error};") + self.confirm_btn.setEnabled(True) + self.back_btn.setEnabled(True) + else: + self.status_label.setText(msg) + self.status_label.setStyleSheet(f"color: {c().error};") + self._set_enabled(True) + + def show_success(self, msg): + if self.stack.currentIndex() == 1: + self.code_status_label.setText(msg) + self.code_status_label.setStyleSheet(f"color: {c().success};") + else: + self.status_label.setText(msg) + self.status_label.setStyleSheet(f"color: {c().success};") + + def show_pairing_qr(self, code: str, fingerprint: str, qr_data: bytes | None): + if not qr_data: + return + pixmap = _generate_qr_pixmap_bytes(qr_data) + if not pixmap: + return + dlg = QDialog(self) + dlg.setMinimumWidth(340) + t = c() + lay = _make_frameless(dlg, "Pairing QR") + info = QLabel( + "Scan this QR on the already logged-in device, or save it and load the image there." + ) + info.setWordWrap(True) + info.setStyleSheet(f"color: {t.text_primary}; font-size: 9pt;") + lay.addWidget(info) + + qr_label = QLabel() + qr_label.setPixmap(pixmap) + qr_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + lay.addWidget(qr_label) + + details = QLabel(f"Code: {code}\nFingerprint:\n{fingerprint}") + details.setAlignment(Qt.AlignmentFlag.AlignCenter) + details.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + details.setStyleSheet( + f"font-family: monospace; font-size: 9pt; color: {t.text_primary}; " + f"background: {t.bg_secondary}; padding: 8px; border-radius: 6px;" + ) + lay.addWidget(details) + + save_btn = QPushButton("Save QR Code") + save_btn.setObjectName("secondaryBtn") + save_btn.clicked.connect(lambda: self._save_pairing_qr(pixmap)) + lay.addWidget(save_btn) + + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(dlg.accept) + lay.addWidget(close_btn) + dlg.exec() + + def _save_pairing_qr(self, pixmap: QPixmap): + path, _ = QFileDialog.getSaveFileName( + self, "Save Pairing QR Code", "pairing_qr.png", + "PNG Images (*.png);;All Files (*)", + ) + if path: + pixmap.save(path, "PNG") + + def reset(self): + self.stack.setCurrentIndex(0) + self._tabs.setCurrentIndex(0) + self.status_label.setText("") + self.status_label.setStyleSheet("") + self.code_status_label.setText("") + self.code_status_label.setStyleSheet("") + self.username_input.clear() + self.email_input.clear() + self.password_input.clear() + self._reg_email_input.clear() + self._reg_password_input.clear() + self._link_email_input.clear() + self._link_password_input.clear() + self.code_input.clear() + self._set_enabled(True) + self.confirm_btn.setEnabled(True) + self.back_btn.setEnabled(True) + + +class MessageBubble(QFrame): + """Chat message bubble with rounded corners drawn via QPainter.""" + + def __init__(self, bg_color: str, parent=None): + super().__init__(parent) + self._bg_color = QColor(bg_color) + self.setStyleSheet("background: transparent; border: none;") + self.setContextMenuPolicy(Qt.ContextMenuPolicy.DefaultContextMenu) + + def set_bg_color(self, color: str): + self._bg_color = QColor(color) + self.update() + + def paintEvent(self, event): + p = QPainter(self) + p.setRenderHint(QPainter.RenderHint.Antialiasing) + p.setBrush(QBrush(self._bg_color)) + p.setPen(Qt.PenStyle.NoPen) + p.drawRoundedRect(self.rect(), 14, 14) + p.end() + + def contextMenuEvent(self, event): + # Walk up to find _msg_index, then call main window handler + idx = getattr(self, '_msg_index', None) + if idx is None: + p = self.parentWidget() + while p: + idx = getattr(p, '_msg_index', None) + if idx is not None: + break + p = p.parentWidget() + main_win = self.window() + if idx is not None and hasattr(main_win, '_show_msg_context_menu'): + main_win._show_msg_context_menu(idx, event.globalPos()) + event.accept() + + +class MainWindow(QWidget): + _AVATAR_REFRESH_BATCH = 8 + _GROUP_AVATAR_REFRESH_BATCH = 4 + _show_verification_dialog_signal = pyqtSignal(str, str) # peer_uid, peer_name + + def __init__(self, bridge: AsyncBridge, on_logout): + super().__init__() + self.bridge = bridge + self._on_logout_cb = on_logout + self.setWindowTitle(f"Encrypted Chat - {bridge.client.username}") + self.resize(900, 600) + + self.conversations: list[dict] = [] + self.current_conv_id: str | None = None + self.current_messages: list[dict] = [] + self.reply_to_id: str | None = None + self._unread_counts: dict[str, int] = {} + self._has_more_messages: bool = True + self._pending_image_download: dict | None = None # {file_id, image_info} + self._is_dm: bool = False + self._online_users: set[str] = set() + self._is_logout = False + self._avatar_cache = _LRUPixmapCache() # user_id -> pixmap + self._group_avatar_cache = _LRUPixmapCache() # conv_id -> pixmap + self._avatar_requested: set[str] = set() + self._group_avatar_requested: set[str] = set() + self._pending_invitations: list[dict] = [] + self._favorites: set[str] = self._load_favorites() + # Search state + self._search_results: list[int] = [] # indices into current_messages + self._search_current: int = -1 + self._search_query: str = "" + self._search_active: bool = False + + self._privacy_enabled: bool = True # Privacy overlay on/off + self._last_message_cache: dict[str, tuple[str, str, str]] = {} # conv_id -> (text, ts, receipt) + self._typing_users: dict[str, dict[str, str]] = {} # conv_id -> {user_id: username} + self._typing_timers: dict[tuple[str, str], QTimer] = {} # (conv_id, user_id) -> timer + + self._build_ui() + self._connect_signals() + self._setup_tray_icon() + self._setup_privacy_overlay() + + # Keyboard shortcuts + QShortcut(QKeySequence("Ctrl+F"), self).activated.connect(self._toggle_search) + QShortcut(QKeySequence("Ctrl+Shift+P"), self).activated.connect(self._toggle_privacy) + + self._preload_avatars_from_disk() + self.bridge.load_conversations() + self.bridge.list_invitations() + + # Periodic refresh: re-download avatars and conversation data + self._refresh_timer = QTimer(self) + self._refresh_timer.timeout.connect(self._on_periodic_refresh) + self._refresh_timer.start(120_000) # every 2 minutes + + tm().on_change(self._apply_theme) + + # -- Theme switching ------------------------------------------------------- + + def _apply_theme(self): + """Re-apply theme colours to all widgets after theme toggle.""" + app = QApplication.instance() + if app: + app.setStyleSheet(qss()) + t = c() + # Sidebar + self._sidebar_panel.setStyleSheet(f"#sidebarPanel {{ background-color: {t.bg_tertiary}; }}") + self._conv_label.setStyleSheet(f"font-weight: bold; font-size: 12pt; color: {t.accent}; background: transparent;") + self._settings_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.text_secondary}; border: none; border-radius: 6px; padding: 8px 16px; }}" + f"QPushButton:hover {{ background-color: {t.bg_hover}; }}" + ) + self.conv_list.setStyleSheet( + f"QListWidget {{ background-color: {t.bg_tertiary}; border: none; padding: 0px; }}" + f"QListWidget::item {{ padding: 0px; border: none; }}" + f"QListWidget::item:selected {{ background: transparent; border: none; }}" + f"QListWidget::item:hover {{ background: transparent; }}" + ) + # Invitation list + self.inv_label.setStyleSheet(f"font-weight: bold; font-size: 9pt; color: {t.warning}; margin-top: 4px;") + self.inv_list.setStyleSheet( + f"QListWidget {{ background-color: {t.bg_primary}; border: 1px solid {t.warning}; border-radius: 6px; padding: 2px; }}" + f"QListWidget::item {{ padding: 6px; color: {t.text_primary}; }}" + f"QListWidget::item:hover {{ background-color: {t.bg_hover}; color: {t.text_primary}; }}" + ) + # Chat header + self._chat_header_widget.setStyleSheet( + f"#chatHeader {{ border-bottom: 1px solid {t.separator}; }}" + f"#chatHeader QLabel, #chatHeader QPushButton {{ border: none; }}" + ) + self.chat_header.setStyleSheet(f"font-weight: bold; font-size: 12pt; color: {t.accent};") + self._chat_header_status.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt;") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.text_muted}; background: transparent;") + self.connection_dot.setStyleSheet(f"color: {t.success}; font-size: 11pt;") + self._logout_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.error}; border: none; " + f"border-radius: 16px; font-size: 13pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.error}; color: {t.accent_text}; }}" + ) + self.delete_conv_btn.setStyleSheet( + f"QPushButton {{ background: transparent; border: none; border-radius: 4px; padding: 4px; }}" + f"QPushButton:hover {{ background-color: {t.error}; }}" + ) + # Search bar + self.search_input.setStyleSheet( + f"QLineEdit {{ background-color: {t.bg_secondary}; color: {t.text_primary}; " + f"border: 1px solid {t.border}; border-radius: 4px; padding: 4px 8px; font-size: 10pt; }}" + ) + self.search_count_label.setStyleSheet(f"color: {t.text_muted}; font-size: 9pt; min-width: 40px;") + # Pin banner + self._pin_banner.setStyleSheet(f"background-color:{t.border}; border-bottom:2px solid {t.pin_color};") + self._pin_banner_label.setStyleSheet(f"color:{t.text_primary}; font-size:9pt; background:transparent; border:none;") + # Jump button + self.jump_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.accent}; color: {t.accent_text}; border-radius: 18px; " + f"font-size: 14pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.accent_hover}; }}" + ) + # Reply label + self.reply_label.setStyleSheet( + f"color: {t.accent}; font-style: italic; font-size: 9pt; " + f"padding: 2px 4px; background: transparent;" + ) + # Input area + self._attach_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.bg_secondary}; border: none; " + f"border-radius: 20px; font-size: 14pt; }}" + f"QPushButton:hover {{ background-color: {t.bg_hover}; }}" + ) + self._send_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.accent}; color: {t.accent_text}; " + f"border: none; border-radius: 20px; font-size: 14pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.accent_hover}; }}" + ) + self.msg_input.setStyleSheet(self.msg_input._style_normal()) + # Counters + self.char_counter.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt; padding: 0 4px;") + self.reencrypt_label.setStyleSheet( + f"background-color: {t.bg_secondary}; border-radius: 6px; " + f"padding: 8px 12px; color: {t.success}; font-weight: bold;" + ) + # Status bar + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.success}; font-size: 8pt;" + ) + # Privacy overlay + if hasattr(self, "_privacy_overlay"): + self._privacy_overlay.setStyleSheet(f"background-color: {t.overlay};") + self._lock_hint.setStyleSheet(f"font-size: 12pt; color: {t.text_muted}; background: transparent;") + self._lock_input.setStyleSheet( + f"QLineEdit {{ font-size: 11pt; background-color: {t.bg_secondary}; " + f"border: 1px solid {t.border}; border-radius: 6px; padding: 8px; " + f"color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + self._lock_error.setStyleSheet(f"font-size: 9pt; color: {t.error}; background: transparent;") + # Mention popup + if hasattr(self, "_mention_popup"): + self._mention_popup.setStyleSheet( + f"QListWidget {{ background:{t.bg_secondary}; color:{t.text_primary}; border:1px solid {t.border}; " + f"border-radius:4px; font-size:10pt; }}" + f"QListWidget::item {{ padding:4px 8px; }}" + f"QListWidget::item:selected {{ background:{t.border}; }}" + ) + # Message scroll area + if hasattr(self, '_msg_scroll_area'): + self._msg_scroll_area.setStyleSheet( + f"QScrollArea {{ background-color: {t.bg_primary}; border: none; }}" + ) + self._msg_container.setStyleSheet( + f"background-color: {t.bg_primary};" + ) + # Re-render messages and conversation list + self._rebuild_conv_list() + if self.current_messages: + self._render_messages(scroll_to_bottom=False) + + # -- Privacy Overlay (lock screen) ---------------------------------------- + + _LOCK_TIMEOUT_MS = 300_000 # 5 min unfocused → lock (require password) + + def _setup_privacy_overlay(self): + """Create overlay that hides content on focus loss; locks after timeout.""" + self._privacy_locked = False + # Check if identity key is password-encrypted (ECP1 format) + # If not, lock feature is disabled (no password to verify against) + self._lock_capable = False + try: + from chat_core import get_key_dir + key_path = get_key_dir(self.bridge.client.email) / "identity_private.bin" + if key_path.exists(): + self._lock_capable = key_path.read_bytes()[:4] == b"ECP1" + except Exception: + pass + + t = c() + # -- overlay widget -- + self._privacy_overlay = QWidget(self) + self._privacy_overlay.setStyleSheet( + f"background-color: {t.overlay};" + ) + self._privacy_overlay.hide() + + overlay_layout = QVBoxLayout(self._privacy_overlay) + overlay_layout.setAlignment(Qt.AlignmentFlag.AlignCenter) + + lock_icon = QLabel("\U0001f512") + lock_icon.setStyleSheet("font-size: 36pt; background: transparent;") + lock_icon.setAlignment(Qt.AlignmentFlag.AlignCenter) + overlay_layout.addWidget(lock_icon) + + self._lock_hint = QLabel("Encrypted Chat") + self._lock_hint.setStyleSheet( + f"font-size: 12pt; color: {t.text_muted}; background: transparent;" + ) + self._lock_hint.setAlignment(Qt.AlignmentFlag.AlignCenter) + overlay_layout.addWidget(self._lock_hint) + + # Password input (hidden until locked) + self._lock_input = QLineEdit() + self._lock_input.setPlaceholderText("Enter password to unlock") + self._lock_input.setEchoMode(QLineEdit.EchoMode.Password) + self._lock_input.setMaximumWidth(280) + self._lock_input.setStyleSheet( + f"QLineEdit {{ font-size: 11pt; background-color: {t.bg_secondary}; " + f"border: 1px solid {t.border}; border-radius: 6px; padding: 8px; " + f"color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + self._lock_input.returnPressed.connect(self._on_unlock_attempt) + self._lock_input.hide() + overlay_layout.addWidget(self._lock_input, alignment=Qt.AlignmentFlag.AlignCenter) + + self._lock_error = QLabel("") + self._lock_error.setStyleSheet( + f"font-size: 9pt; color: {t.error}; background: transparent;" + ) + self._lock_error.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._lock_error.hide() + overlay_layout.addWidget(self._lock_error) + + # Timer: after N seconds unfocused → require password + self._lock_timer = QTimer(self) + self._lock_timer.setSingleShot(True) + self._lock_timer.timeout.connect(self._on_lock_timeout) + + def _toggle_privacy(self): + """Toggle privacy overlay on/off (Ctrl+Shift+P).""" + self._privacy_enabled = not self._privacy_enabled + if not self._privacy_enabled: + self._privacy_locked = False + self._lock_timer.stop() + self._hide_privacy_overlay() + state = "ON" if self._privacy_enabled else "OFF" + base_title = f"Encrypted Chat - {self.bridge.client.username}" + self.setWindowTitle(f"{base_title} [Privacy: {state}]") + QTimer.singleShot(2000, lambda: self.setWindowTitle(base_title)) + + def _show_privacy_overlay(self): + if not self._privacy_enabled: + return + if not self._privacy_overlay.isVisible(): + self._privacy_overlay.setGeometry(self.rect()) + self._privacy_overlay.raise_() + self._privacy_overlay.show() + + def _hide_privacy_overlay(self): + self._lock_timer.stop() + self._lock_input.hide() + self._lock_input.clear() + self._lock_error.hide() + self._lock_hint.setText("Encrypted Chat") + if self._privacy_overlay.isVisible(): + self._privacy_overlay.hide() + + def _on_lock_timeout(self): + """Window unfocused too long — show overlay and require password.""" + if not self._privacy_enabled or not self._lock_capable: + return + if self.isActiveWindow(): + return # User came back before timeout fired + self._show_privacy_overlay() + self._privacy_locked = True + self._lock_hint.setText("Session locked") + self._lock_input.show() + self._lock_input.setFocus() + + def _on_unlock_attempt(self): + """Verify password by decrypting identity key from disk.""" + from chat_core import get_key_dir, _check_lockout, _record_failed_attempt, _clear_lockout + from crypto_utils import _decrypt_private_key + pwd = self._lock_input.text() + if not pwd: + return + email = self.bridge.client.email + remaining = _check_lockout(email) + if remaining > 0: + self._lock_error.setText(f"Too many attempts. Wait {remaining:.0f}s.") + self._lock_error.show() + self._lock_input.clear() + self._lock_input.setFocus() + return + try: + key_path = get_key_dir(email) / "identity_private.bin" + data = key_path.read_bytes() + _decrypt_private_key(data, pwd.encode("utf-8")) + # Success — unlock + _clear_lockout(email) + self._privacy_locked = False + self._hide_privacy_overlay() + except Exception: + _record_failed_attempt(email) + remaining = _check_lockout(email) + if remaining > 0: + self._lock_error.setText(f"Wrong password. Wait {remaining:.0f}s.") + else: + self._lock_error.setText("Wrong password") + self._lock_error.show() + self._lock_input.clear() + self._lock_input.setFocus() + + def changeEvent(self, event): + """Handle window state changes — tray minimize + privacy overlay.""" + from PyQt6.QtCore import QEvent + if event.type() == QEvent.Type.WindowStateChange: + if self.isMinimized() and self._tray_icon is not None: + event.ignore() + self.hide() + return + if event.type() == QEvent.Type.ActivationChange: + if self.isActiveWindow(): + # Window got focus back + self._lock_timer.stop() + if not self._privacy_locked: + self._hide_privacy_overlay() + else: + # Locked — keep overlay, focus password input + self._lock_input.setFocus() + else: + # Window lost focus — only start lock timer, no immediate overlay + if self._privacy_enabled and self._lock_capable: + self._lock_timer.start(self._LOCK_TIMEOUT_MS) + super().changeEvent(event) + + def resizeEvent(self, event): + """Keep privacy overlay sized to window.""" + super().resizeEvent(event) + if hasattr(self, "_privacy_overlay"): + self._privacy_overlay.setGeometry(self.rect()) + + # -- System Tray ---------------------------------------------------------- + + def _make_tray_icon(self) -> QIcon: + """Create a simple app icon (blue chat bubble) for the system tray.""" + px = QPixmap(64, 64) + px.fill(QColor(0, 0, 0, 0)) + p = QPainter(px) + p.setRenderHint(QPainter.RenderHint.Antialiasing) + p.setBrush(QBrush(QColor(c().accent))) + p.setPen(Qt.PenStyle.NoPen) + p.drawRoundedRect(4, 4, 56, 48, 12, 12) + # small triangle (tail) + from PyQt6.QtGui import QPolygon + from PyQt6.QtCore import QPoint + p.drawPolygon(QPolygon([QPoint(14, 52), QPoint(24, 44), QPoint(30, 52)])) + # lock icon (E2E indicator) + p.setBrush(QBrush(QColor(c().accent_text))) + p.drawRoundedRect(24, 16, 16, 14, 3, 3) + p.setPen(QPen(QColor(c().accent_text), 3)) + p.setBrush(Qt.BrushStyle.NoBrush) + p.drawArc(27, 10, 10, 12, 0, 180 * 16) + p.end() + return QIcon(px) + + def _setup_tray_icon(self): + """Initialize the system tray icon with context menu.""" + if not QSystemTrayIcon.isSystemTrayAvailable(): + self._tray_icon = None + return + self._tray_icon = QSystemTrayIcon(self._make_tray_icon(), self) + self._tray_icon.setToolTip(f"Encrypted Chat - {self.bridge.client.username}") + self._tray_icon.activated.connect(self._on_tray_activated) + + tray_menu = QMenu() + show_action = tray_menu.addAction("Show") + show_action.triggered.connect(self._restore_from_tray) + tray_menu.addSeparator() + quit_action = tray_menu.addAction("Quit") + quit_action.triggered.connect(self._quit_from_tray) + self._tray_icon.setContextMenu(tray_menu) + self._tray_icon.show() + + def _on_tray_activated(self, reason): + """Handle tray icon click — restore window on double-click or single click.""" + if reason in (QSystemTrayIcon.ActivationReason.Trigger, + QSystemTrayIcon.ActivationReason.DoubleClick): + self._restore_from_tray() + + def _restore_from_tray(self): + """Restore window from system tray.""" + self.showNormal() + self.activateWindow() + self.raise_() + + def _quit_from_tray(self): + """Quit the application from tray menu.""" + if self._tray_icon: + self._tray_icon.hide() + self._is_logout = False + self.close() + + def _show_tray_notification(self, title: str, text: str): + """Show a system tray toast notification if the window is not in the foreground.""" + if not self._tray_icon: + logger.debug("Tray notification skipped: no tray icon") + return + if self.isVisible() and self.isActiveWindow() and not self._privacy_locked: + return # user is looking at the app (and it's not locked) + if len(text) > 120: + text = text[:117] + "..." + logger.info("Tray notification: %s — %s", title, text[:50]) + self._tray_icon.showMessage( + title, text, + QSystemTrayIcon.MessageIcon.Information, 4000, + ) + + # -- End Tray --------------------------------------------------------------- + + def _bold_font(self) -> QFont: + """Return a bold font with a valid size (avoids QFont pointSize=-1 warnings).""" + f = QFont(self.conv_list.font()) + f.setBold(True) + # Stylesheet sets font-size in px so pointSize is -1; fix by using pixelSize + if f.pointSize() <= 0: + px = f.pixelSize() + if px > 0: + f.setPixelSize(px) + else: + f.setPointSize(10) + return f + + def _make_circular_avatar(self, pixmap: QPixmap, size: int = 32) -> QPixmap: + """Crop a pixmap into a circle.""" + scaled = pixmap.scaled(size, size, Qt.AspectRatioMode.KeepAspectRatioByExpanding, + Qt.TransformationMode.SmoothTransformation) + result = QPixmap(size, size) + result.fill(QColor(0, 0, 0, 0)) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + painter.setBrush(QBrush(scaled)) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(0, 0, size, size) + painter.end() + return result + + def _make_default_avatar(self, username: str, size: int = 32) -> QPixmap: + """Generate a colored circle with the first letter of the username.""" + # Deterministic color from username — higher saturation in light mode + hue = (hash(username) % 360) + sat = 160 if not tm().is_dark else 120 + val = 180 if not tm().is_dark else 200 + color = QColor.fromHsv(hue, sat, val) + result = QPixmap(size, size) + result.fill(QColor(0, 0, 0, 0)) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + painter.setBrush(QBrush(color)) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(0, 0, size, size) + # Draw letter + painter.setPen(QColor(255, 255, 255)) + font = QFont("Segoe UI Variable", int(size * 0.4)) + if not font.exactMatch(): + font = QFont("Segoe UI", int(size * 0.4)) + font.setBold(True) + painter.setFont(font) + letter = username[0].upper() if username else "?" + painter.drawText(0, 0, size, size, Qt.AlignmentFlag.AlignCenter, letter) + painter.end() + return result + + def _add_online_dot(self, avatar: QPixmap) -> QPixmap: + """Overlay a green dot on the bottom-right of an avatar pixmap.""" + result = QPixmap(avatar) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + dot_size = max(8, avatar.width() // 4) + x = avatar.width() - dot_size + y = avatar.height() - dot_size + # Border ring (matches sidebar background) + painter.setBrush(QBrush(QColor(c().online_dot_border))) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(x - 1, y - 1, dot_size + 2, dot_size + 2) + # Green dot + painter.setBrush(QBrush(QColor(c().online_dot))) + painter.drawEllipse(x, y, dot_size, dot_size) + painter.end() + return result + + def _get_conv_avatar(self, conv: dict) -> QIcon: + """Get avatar icon for a conversation list item.""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + if is_dm: + other = None + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + other = m + break + if other: + uid = other.get("user_id") or other.get("id") or "" + uname = other.get("username") or other.get("email") or "?" + if uid in self._avatar_cache: + avatar = self._make_circular_avatar(self._avatar_cache[uid]) + else: + avatar = self._make_default_avatar(uname) + # Request avatar download if not yet requested + if uid and uid not in self._avatar_requested: + self._avatar_requested.add(uid) + self.bridge.get_avatar(uid) + if uid in self._online_users: + avatar = self._add_online_dot(avatar) + return QIcon(avatar) + # Group: use group avatar if available + conv_id = conv.get("conversation_id") or "" + if conv_id in self._group_avatar_cache: + return QIcon(self._make_circular_avatar(self._group_avatar_cache[conv_id])) + gname = conv.get("name") or "G" + # Request group avatar download if has avatar_file + if conv.get("avatar_file") and conv_id and conv_id not in self._group_avatar_requested: + self._group_avatar_requested.add(conv_id) + self.bridge.get_group_avatar(conv_id) + return QIcon(self._make_default_avatar(gname)) + + def _get_conv_avatar_pixmap(self, conv: dict, size: int = 44) -> QPixmap: + """Get avatar QPixmap for delegate painting.""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + if is_dm: + other = None + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + other = m + break + if other: + uid = other.get("user_id") or other.get("id") or "" + uname = other.get("username") or other.get("email") or "?" + if uid in self._avatar_cache: + avatar = self._make_circular_avatar(self._avatar_cache[uid], size) + else: + avatar = self._make_default_avatar(uname, size) + if uid and uid not in self._avatar_requested: + self._avatar_requested.add(uid) + self.bridge.get_avatar(uid) + if uid in self._online_users: + avatar = self._add_online_dot(avatar) + return avatar + conv_id = conv.get("conversation_id") or "" + if conv_id in self._group_avatar_cache: + return self._make_circular_avatar(self._group_avatar_cache[conv_id], size) + gname = conv.get("name") or "G" + if conv.get("avatar_file") and conv_id and conv_id not in self._group_avatar_requested: + self._group_avatar_requested.add(conv_id) + self.bridge.get_group_avatar(conv_id) + return self._make_default_avatar(gname, size) + + def _build_ui(self): + main_layout = QHBoxLayout(self) + main_layout.setContentsMargins(0, 0, 0, 0) + main_layout.setSpacing(0) + + splitter = QSplitter(Qt.Orientation.Horizontal) + + # Left panel - conversations + left = QWidget() + left_layout = QVBoxLayout(left) + left_layout.setContentsMargins(8, 8, 4, 8) + + t = c() + left.setObjectName("sidebarPanel") + left.setStyleSheet(f"#sidebarPanel {{ background-color: {t.bg_tertiary}; }}") + self._sidebar_panel = left + + header_row = QHBoxLayout() + self._conv_label = QLabel("Conversations") + self._conv_label.setStyleSheet(f"font-weight: bold; font-size: 12pt; color: {t.accent}; background: transparent;") + header_row.addWidget(self._conv_label) + header_row.addStretch() + + new_chat_btn = QPushButton("") + new_chat_btn.setFixedSize(32, 32) + new_chat_btn.setObjectName("toolBtn") + new_chat_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogNewFolder)) + new_chat_btn.setToolTip("New Chat") + new_chat_btn.clicked.connect(self._on_new_chat) + header_row.addWidget(new_chat_btn) + + group_btn = QPushButton("") + group_btn.setFixedSize(32, 32) + group_btn.setObjectName("toolBtn") + group_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DirIcon)) + group_btn.setToolTip("New Group") + group_btn.clicked.connect(self._on_new_group) + header_row.addWidget(group_btn) + + profile_btn = QPushButton("") + profile_btn.setFixedSize(32, 32) + profile_btn.setObjectName("toolBtn") + profile_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogInfoView)) + profile_btn.setToolTip("My Profile") + profile_btn.clicked.connect(self._on_my_profile) + header_row.addWidget(profile_btn) + + left_layout.addLayout(header_row) + + # Invitation section (hidden when empty) + self.inv_label = QLabel("Pending Invitations") + self.inv_label.setStyleSheet(f"font-weight: bold; font-size: 9pt; color: {t.warning}; margin-top: 4px;") + self.inv_label.setVisible(False) + left_layout.addWidget(self.inv_label) + + self.inv_list = QListWidget() + self.inv_list.setMaximumHeight(120) + self.inv_list.setVisible(False) + self.inv_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.inv_list.customContextMenuRequested.connect(self._on_inv_context_menu) + self.inv_list.setStyleSheet( + f"QListWidget {{ background-color: {t.bg_primary}; border: 1px solid {t.warning}; border-radius: 6px; padding: 2px; }}" + f"QListWidget::item {{ padding: 6px; color: {t.text_primary}; }}" + f"QListWidget::item:hover {{ background-color: {t.bg_hover}; color: {t.text_primary}; }}" + ) + left_layout.addWidget(self.inv_list) + + self.conv_list = QListWidget() + self.conv_list.setIconSize(QSize(44, 44)) + # Override global QSS item styles — delegate handles all painting + self.conv_list.setStyleSheet( + f"QListWidget {{ background-color: {t.bg_tertiary}; border: none; padding: 0px; }}" + f"QListWidget::item {{ padding: 0px; border: none; }}" + f"QListWidget::item:selected {{ background: transparent; border: none; }}" + f"QListWidget::item:hover {{ background: transparent; }}" + ) + self._conv_delegate = ConversationDelegate(self.conv_list) + self.conv_list.setItemDelegate(self._conv_delegate) + self.conv_list.setMouseTracking(True) # for hover painting + self.conv_list.currentRowChanged.connect(self._on_conv_selected) + self.conv_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.conv_list.customContextMenuRequested.connect(self._on_conv_list_context_menu) + left_layout.addWidget(self.conv_list) + + # Bottom toolbar row: settings + logout + bottom_row = QHBoxLayout() + bottom_row.setContentsMargins(0, 4, 0, 0) + + settings_btn = QPushButton("\u2699 Settings") + settings_btn.setObjectName("sidebarBtn") + settings_btn.setToolTip("Open settings") + settings_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.text_secondary}; border: none; border-radius: 6px; padding: 8px 16px; }}" + f"QPushButton:hover {{ background-color: {t.bg_hover}; }}" + ) + settings_btn.clicked.connect(self._on_open_settings) + self._settings_btn = settings_btn + bottom_row.addWidget(settings_btn) + + logout_btn = QPushButton("\u2715") + logout_btn.setFixedSize(32, 32) + logout_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.error}; border: none; " + f"border-radius: 16px; font-size: 13pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.error}; color: {t.accent_text}; }}" + ) + logout_btn.setToolTip("Logout") + logout_btn.clicked.connect(self._on_logout) + self._logout_btn = logout_btn + bottom_row.addWidget(logout_btn) + + left_layout.addLayout(bottom_row) + + # Right panel - messages + right = QWidget() + right_layout = QVBoxLayout(right) + right_layout.setContentsMargins(4, 8, 8, 8) + + # Chat header bar (56px height) + chat_header_widget = QWidget() + chat_header_widget.setFixedHeight(56) + self._chat_header_widget = chat_header_widget + chat_header_widget.setObjectName("chatHeader") + chat_header_widget.setStyleSheet( + f"#chatHeader {{ border-bottom: 1px solid {t.separator}; }}" + f"#chatHeader QLabel, #chatHeader QPushButton {{ border: none; }}" + ) + chat_header_row = QHBoxLayout(chat_header_widget) + chat_header_row.setContentsMargins(12, 4, 8, 4) + + self.chat_header_avatar = QLabel() + self.chat_header_avatar.setFixedSize(40, 40) + self.chat_header_avatar.setStyleSheet("background: transparent;") + self.chat_header_avatar.setVisible(False) + chat_header_row.addWidget(self.chat_header_avatar) + + # Name + status text vertical stack + name_status_layout = QVBoxLayout() + name_status_layout.setSpacing(0) + name_status_layout.setContentsMargins(6, 0, 0, 0) + self.chat_header = QLabel("Select a conversation") + self.chat_header.setStyleSheet(f"font-weight: bold; font-size: 12pt; color: {t.accent};") + name_status_layout.addWidget(self.chat_header) + + self._chat_header_status = QLabel("") + self._chat_header_status.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt;") + self._chat_header_status.setVisible(False) + name_status_layout.addWidget(self._chat_header_status) + chat_header_row.addLayout(name_status_layout) + + # E2E lock indicator + self._e2e_label = QLabel("\U0001f512 End-to-end encrypted") + self._e2e_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._e2e_label.setToolTip("End-to-end encrypted") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.text_muted}; background: transparent;") + self._e2e_label.setCursor(Qt.CursorShape.PointingHandCursor) + self._e2e_label.mousePressEvent = self._on_e2e_label_clicked + self._e2e_label.setVisible(False) + chat_header_row.addWidget(self._e2e_label) + + self.connection_dot = QLabel("\u25cf") + self.connection_dot.setFixedSize(16, 16) + self.connection_dot.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.connection_dot.setStyleSheet(f"color: {t.success}; font-size: 11pt;") + self.connection_dot.setToolTip("Connected") + chat_header_row.addWidget(self.connection_dot) + chat_header_row.addStretch() + + self.group_info_btn = QPushButton("") + self.group_info_btn.setFixedSize(32, 32) + self.group_info_btn.setObjectName("toolBtn") + self.group_info_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_MessageBoxInformation)) + self.group_info_btn.setToolTip("Group Info") + self.group_info_btn.clicked.connect(self._on_group_info) + self.group_info_btn.setVisible(False) + chat_header_row.addWidget(self.group_info_btn) + + self.user_info_btn = QPushButton("") + self.user_info_btn.setFixedSize(32, 32) + self.user_info_btn.setObjectName("toolBtn") + self.user_info_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogInfoView)) + self.user_info_btn.setToolTip("User Info") + self.user_info_btn.clicked.connect(self._on_dm_user_info) + self.user_info_btn.setVisible(False) + chat_header_row.addWidget(self.user_info_btn) + + self.delete_conv_btn = QPushButton("") + self.delete_conv_btn.setFixedSize(32, 32) + self.delete_conv_btn.setStyleSheet( + f"QPushButton {{ background: transparent; border: none; border-radius: 4px; padding: 4px; }}" + f"QPushButton:hover {{ background-color: {t.error}; }}" + ) + self.delete_conv_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_TrashIcon)) + self.delete_conv_btn.setToolTip("Delete conversation") + self.delete_conv_btn.clicked.connect(self._on_delete_conv_btn) + self.delete_conv_btn.setVisible(False) + chat_header_row.addWidget(self.delete_conv_btn) + + self.add_member_btn = QPushButton("") + self.add_member_btn.setFixedSize(32, 32) + self.add_member_btn.setObjectName("toolBtn") + self.add_member_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogNewFolder)) + self.add_member_btn.setToolTip("Add Member") + self.add_member_btn.clicked.connect(self._on_add_member) + self.add_member_btn.setVisible(False) + chat_header_row.addWidget(self.add_member_btn) + + self.pin_list_btn = QPushButton("\U0001f4cc") + self.pin_list_btn.setFixedSize(32, 32) + self.pin_list_btn.setObjectName("toolBtn") + self.pin_list_btn.setToolTip("Pinned messages") + self.pin_list_btn.clicked.connect(self._show_pinned_messages) + self.pin_list_btn.setVisible(False) + chat_header_row.addWidget(self.pin_list_btn) + + self.search_btn = QPushButton("") + self.search_btn.setFixedSize(32, 32) + self.search_btn.setObjectName("toolBtn") + self.search_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogContentsView)) + self.search_btn.setToolTip("Search messages (Ctrl+F)") + self.search_btn.clicked.connect(self._toggle_search) + self.search_btn.setVisible(False) + chat_header_row.addWidget(self.search_btn) + + right_layout.addWidget(chat_header_widget) + + # Search bar (hidden by default) + self.search_widget = QWidget() + search_row = QHBoxLayout(self.search_widget) + search_row.setContentsMargins(0, 2, 0, 2) + self.search_input = QLineEdit() + self.search_input.setPlaceholderText("Search messages...") + self.search_input.setStyleSheet( + f"QLineEdit {{ background-color: {t.bg_secondary}; color: {t.text_primary}; " + f"border: 1px solid {t.border}; border-radius: 4px; padding: 4px 8px; font-size: 10pt; }}" + ) + self.search_input.textChanged.connect(self._on_search_text_changed) + self.search_input.returnPressed.connect(self._on_search_next) + # Escape in search input closes search + QShortcut(QKeySequence("Escape"), self.search_input).activated.connect(self._close_search) + search_row.addWidget(self.search_input, stretch=1) + self.search_prev_btn = QPushButton("\u25b2") + self.search_prev_btn.setFixedSize(28, 28) + self.search_prev_btn.setObjectName("toolBtn") + self.search_prev_btn.setToolTip("Previous match") + self.search_prev_btn.clicked.connect(self._on_search_prev) + search_row.addWidget(self.search_prev_btn) + self.search_next_btn = QPushButton("\u25bc") + self.search_next_btn.setFixedSize(28, 28) + self.search_next_btn.setObjectName("toolBtn") + self.search_next_btn.setToolTip("Next match") + self.search_next_btn.clicked.connect(self._on_search_next) + search_row.addWidget(self.search_next_btn) + self.search_count_label = QLabel("0/0") + self.search_count_label.setStyleSheet(f"color: {t.text_muted}; font-size: 9pt; min-width: 40px;") + self.search_count_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + search_row.addWidget(self.search_count_label) + self.search_close_btn = QPushButton("\u2715") + self.search_close_btn.setFixedSize(28, 28) + self.search_close_btn.setObjectName("toolBtn") + self.search_close_btn.setToolTip("Close search") + self.search_close_btn.clicked.connect(self._close_search) + search_row.addWidget(self.search_close_btn) + self.search_widget.setVisible(False) + right_layout.addWidget(self.search_widget) + + # --- Pinned message banner --- + self._pin_banner = QWidget() + self._pin_banner.setStyleSheet( + f"background-color:{t.border}; border-bottom:2px solid {t.pin_color};" + ) + pin_banner_layout = QHBoxLayout(self._pin_banner) + pin_banner_layout.setContentsMargins(10, 6, 10, 6) + pin_icon = QLabel("\U0001f4cc") + pin_icon.setStyleSheet("font-size:12pt; background:transparent; border:none;") + pin_banner_layout.addWidget(pin_icon) + self._pin_banner_label = QLabel("") + self._pin_banner_label.setStyleSheet( + f"color:{t.text_primary}; font-size:9pt; background:transparent; border:none;" + ) + self._pin_banner_label.setCursor(Qt.CursorShape.PointingHandCursor) + self._pin_banner_label.setWordWrap(False) + pin_banner_layout.addWidget(self._pin_banner_label, stretch=1) + pin_banner_close = QPushButton("\u2715") + pin_banner_close.setFixedSize(20, 20) + pin_banner_close.setStyleSheet( + f"QPushButton {{ background:transparent; color:{t.text_muted}; border:none; font-size:10pt; }}" + f"QPushButton:hover {{ color:{t.text_primary}; }}" + ) + pin_banner_close.clicked.connect(lambda: self._pin_banner.setVisible(False)) + pin_banner_layout.addWidget(pin_banner_close) + self._pin_banner.setVisible(False) + self._pin_banner.setCursor(Qt.CursorShape.PointingHandCursor) + self._pin_banner.mousePressEvent = self._on_pin_banner_clicked + self._pin_banner_msg_id = None + right_layout.addWidget(self._pin_banner) + + self.load_more_btn = QPushButton("Load older messages") + self.load_more_btn.setObjectName("secondaryBtn") + self.load_more_btn.clicked.connect(self._on_load_more) + self.load_more_btn.setVisible(False) + right_layout.addWidget(self.load_more_btn) + + # Message display area — QScrollArea with widget-based bubbles + self._msg_scroll_area = QScrollArea() + self._msg_scroll_area.setWidgetResizable(True) + self._msg_scroll_area.setHorizontalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAlwaysOff + ) + self._msg_scroll_area.setStyleSheet( + f"QScrollArea {{ background-color: {t.bg_primary}; border: none; }}" + ) + self._msg_scroll_area.setAcceptDrops(True) + self._msg_scroll_area.installEventFilter(self) + self._msg_scroll_area.viewport().setContextMenuPolicy( + Qt.ContextMenuPolicy.NoContextMenu + ) + self._msg_container = QWidget() + self._msg_container.setStyleSheet(f"background-color: {t.bg_primary};") + self._msg_layout = QVBoxLayout(self._msg_container) + self._msg_layout.setAlignment(Qt.AlignmentFlag.AlignTop) + self._msg_layout.setContentsMargins(0, 8, 0, 8) + self._msg_layout.setSpacing(2) + self._msg_scroll_area.setWidget(self._msg_container) + self._msg_widgets = [] + self.message_area = self._msg_scroll_area # alias for scroll/jump/drop + right_layout.addWidget(self._msg_scroll_area, stretch=1) + + # Smart scroll: track if user is near bottom + self._is_near_bottom = True + self._msg_scroll_area.verticalScrollBar().valueChanged.connect( + self._on_scroll_changed + ) + + # Scroll-to-bottom floating button (hidden by default) + self.jump_btn = QPushButton("\u2193") + self.jump_btn.setParent(self.message_area) + self.jump_btn.setVisible(False) + self.jump_btn.setFixedSize(36, 36) + self.jump_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.accent}; color: {t.accent_text}; border-radius: 18px; " + f"font-size: 14pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.accent_hover}; }}" + ) + self.jump_btn.clicked.connect(self._scroll_to_bottom) + + # Reply preview (above input, blue left bar) + self._reply_widget = QWidget() + self._reply_widget.setVisible(False) + reply_row = QHBoxLayout(self._reply_widget) + reply_row.setContentsMargins(8, 4, 8, 0) + reply_row.setSpacing(4) + reply_bar = QFrame() + reply_bar.setFixedWidth(3) + reply_bar.setStyleSheet(f"background-color: {t.accent}; border-radius: 1px;") + reply_row.addWidget(reply_bar) + self.reply_label = QLabel("") + self.reply_label.setStyleSheet( + f"color: {t.accent}; font-style: italic; font-size: 9pt; " + f"padding: 2px 4px; background: transparent;" + ) + self.reply_label.setWordWrap(True) + reply_row.addWidget(self.reply_label, stretch=1) + reply_dismiss = QPushButton("\u2715") + reply_dismiss.setFixedSize(20, 20) + reply_dismiss.setStyleSheet( + f"QPushButton {{ background:transparent; color:{t.text_muted}; border:none; font-size:10pt; }}" + f"QPushButton:hover {{ color:{t.text_primary}; }}" + ) + reply_dismiss.clicked.connect(self._cancel_reply) + reply_row.addWidget(reply_dismiss) + right_layout.addWidget(self._reply_widget) + + # Input row: [attach] [input] [send] + input_row = QHBoxLayout() + input_row.setSpacing(8) + input_row.setContentsMargins(8, 4, 8, 4) + + self._attach_btn = QPushButton("\U0001f4ce") + attach_btn = self._attach_btn + attach_btn.setFixedSize(40, 40) + attach_btn.setCursor(Qt.CursorShape.PointingHandCursor) + attach_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.bg_secondary}; border: none; " + f"border-radius: 20px; font-size: 14pt; }}" + f"QPushButton:hover {{ background-color: {t.bg_hover}; }}" + ) + self._attach_menu = QMenu(attach_btn) + self._attach_menu.addAction("\U0001f5bc Image", self._on_attach_image) + self._attach_menu.addAction("\U0001f4c4 File", self._on_attach_file) + attach_btn.clicked.connect(lambda: self._attach_menu.exec( + attach_btn.mapToGlobal(attach_btn.rect().topLeft() - QPoint(0, self._attach_menu.sizeHint().height())) + )) + input_row.addWidget(attach_btn) + + self.msg_input = MessageInput() + self.msg_input.send_requested.connect(self._on_send) + self.msg_input.textChanged.connect(self._on_input_changed) + self.msg_input.file_dropped.connect(self._on_file_dropped) + input_row.addWidget(self.msg_input, stretch=1) + + self._send_btn = QPushButton("\u27a4") + send_btn = self._send_btn + send_btn.setFixedSize(40, 40) + send_btn.setCursor(Qt.CursorShape.PointingHandCursor) + send_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.accent}; color: {t.accent_text}; " + f"border: none; border-radius: 20px; font-size: 14pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.accent_hover}; }}" + ) + send_btn.clicked.connect(self._on_send) + input_row.addWidget(send_btn) + right_layout.addLayout(input_row) + + self.typing_label = QLabel("") + self.typing_label.setStyleSheet( + f"color: {t.text_muted}; font-size: 9pt; padding: 0 10px; font-style: italic;" + ) + self.typing_label.setVisible(False) + right_layout.addWidget(self.typing_label) + + self.char_counter = QLabel(f"0 / {MAX_INPUT_CHARS}") + self.char_counter.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt; padding: 0 4px;") + self.char_counter.setAlignment(Qt.AlignmentFlag.AlignRight) + right_layout.addWidget(self.char_counter) + + self.reencrypt_label = QLabel("") + self.reencrypt_label.setStyleSheet( + f"background-color: {t.bg_secondary}; border-radius: 6px; " + f"padding: 8px 12px; color: {t.success}; font-weight: bold;" + ) + self.reencrypt_label.setVisible(False) + right_layout.addWidget(self.reencrypt_label) + + splitter.addWidget(left) + splitter.addWidget(right) + splitter.setStretchFactor(0, 1) + splitter.setStretchFactor(1, 3) + + # Wrap splitter + status bar in vertical layout for full-width status bar + wrapper = QVBoxLayout() + wrapper.setContentsMargins(0, 0, 0, 0) + wrapper.setSpacing(0) + wrapper.addWidget(splitter) + + # Status bar (permanent, fixed height, full width — no layout jumping) + self.status_bar = QLabel("") + self.status_bar.setFixedHeight(24) + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.success}; font-size: 8pt;" + ) + self.status_bar.setCursor(Qt.CursorShape.PointingHandCursor) + self.status_bar.mousePressEvent = self._on_status_bar_click + self._status_bar_conv_id = None + wrapper.addWidget(self.status_bar) + + main_layout.addLayout(wrapper) + + def _connect_signals(self): + self.bridge.conversations_loaded.connect(self._on_conversations_loaded) + self.bridge.messages_loaded.connect(self._on_messages_loaded) + self.bridge.older_messages_loaded.connect(self._on_older_messages_loaded) + self.bridge.message_sent.connect(self._on_message_sent) + self.bridge.message_sent_payload.connect(self._on_message_sent_payload) + self.bridge.new_notification.connect(self._on_notification) + self.bridge.add_member_result.connect(self._on_add_member_result) + self.bridge.authorize_result.connect(self._on_authorize_result) + self.bridge.rotate_result.connect(self._on_rotate_result) + self.bridge.password_changed.connect(self._on_password_changed) + self.bridge.username_changed.connect(self._on_username_changed) + self.bridge.reencrypt_status.connect(self._on_reencrypt_status) + self.bridge.messages_read_notification.connect(self._on_messages_read) + self.bridge.message_delivered_notification.connect(self._on_message_delivered) + self.bridge.typing_start_notification.connect(self._on_typing_start) + self.bridge.typing_stop_notification.connect(self._on_typing_stop) + self.bridge.remove_member_result.connect(self._on_remove_member_result) + self.bridge.message_deleted_notification.connect(self._on_message_deleted) + self.bridge.delete_message_result.connect(self._on_delete_message_result) + self.bridge.image_sent.connect(self._on_image_sent) + self.bridge.image_downloaded.connect(self._on_image_downloaded) + self.bridge.image_download_failed.connect(self._on_image_download_failed) + self.bridge.file_sent.connect(self._on_file_sent) + self.bridge.file_downloaded.connect(self._on_file_downloaded) + self.bridge.conversation_updated.connect(self._on_conversation_updated) + self.bridge.connection_state_changed.connect(self._on_connection_state_changed) + self.bridge.group_left.connect(self._on_group_left) + self.bridge.group_renamed.connect(self._on_group_renamed) + self.bridge.conversation_deleted.connect(self._on_conversation_deleted) + self.bridge.avatar_loaded.connect(self._on_avatar_for_conv_list) + self.bridge.avatar_fetch_failed.connect(self._on_avatar_fetch_failed) + self.bridge._avatar_changed_signal.connect(self._on_avatar_changed_push) + self.bridge.invitations_loaded.connect(self._on_invitations_loaded) + self.bridge.invitation_result.connect(self._on_invitation_result) + self.bridge.invitation_received.connect(self._on_invitation_received) + self.bridge.device_added_notification.connect(self._on_device_added) + self.bridge.online_status_changed.connect(self._on_online_status_changed) + self.bridge.online_users_loaded.connect(self._on_online_users_loaded) + self.bridge.group_avatar_loaded.connect(self._on_group_avatar_for_conv_list) + self.bridge.group_avatar_fetch_failed.connect( + lambda cid: self._group_avatar_requested.discard(cid)) + self.bridge.group_avatar_updated.connect(self._on_group_avatar_updated) + self.bridge.session_reset_notification.connect(self._on_session_reset) + self.bridge.reaction_notification.connect(self._on_reaction_notification) + self.bridge.pin_notification.connect(self._on_pin_notification) + self.bridge.unpin_notification.connect(self._on_unpin_notification) + self.bridge.pinned_messages_loaded.connect(self._on_pinned_messages_loaded) + self.bridge.forward_result.connect(self._on_forward_result) + self.bridge.key_change_warning.connect(self._on_key_change_warning) + self.bridge.message_queued.connect(self._on_message_queued) + self._show_verification_dialog_signal.connect(self._show_verification_dialog) + + # ------------------------------------------------------------------ + # Favorites + # ------------------------------------------------------------------ + + def _favorites_path(self): + from chat_core import get_key_dir + return get_key_dir(self.bridge.client.email) / "favorites.json" + + def _load_favorites(self) -> set[str]: + try: + p = self._favorites_path() + if p.exists(): + return set(json.loads(p.read_text())) + except Exception: + pass + return set() + + def _save_favorites(self): + try: + self._favorites_path().write_text(json.dumps(list(self._favorites))) + except Exception: + pass + + def _on_conv_list_context_menu(self, pos): + item = self.conv_list.itemAt(pos) + if not item: + return + conv_id = item.data(Qt.ItemDataRole.UserRole) + if not conv_id: + return + from PyQt6.QtWidgets import QMenu + menu = QMenu(self) + is_fav = conv_id in self._favorites + action = menu.addAction("Odebrat z oblibených" if is_fav else "Přidat do oblíbených") + result = menu.exec(self.conv_list.mapToGlobal(pos)) + if result == action: + if is_fav: + self._favorites.discard(conv_id) + else: + self._favorites.add(conv_id) + self._save_favorites() + self._rebuild_conv_list() + + # ------------------------------------------------------------------ + # Conversation list helpers + # ------------------------------------------------------------------ + + def _get_conv_display_name(self, conv: dict) -> str: + """Get display name for a conversation (used for sorting and labels).""" + others = [m.get("username") or m.get("email") or "?" for m in conv["members"] + if m.get("email") != self.bridge.client.email] + return conv.get("name") or (", ".join(others) if others else self.bridge.client.username) + + def _get_conv_other_user_id(self, conv: dict) -> str: + """Get the other user's ID in a DM conversation (empty string for groups).""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + if not is_dm: + return "" + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + return m.get("user_id") or m.get("id") or "" + return "" + + def _get_conv_sort_key(self, conv: dict) -> tuple: + """Sort key: favorites first, then online DMs, then rest — alphabetically within each.""" + conv_id = conv.get("conversation_id", "") + is_fav = 0 if conv_id in self._favorites else 1 + other_uid = self._get_conv_other_user_id(conv) + is_online = 0 if other_uid and other_uid in self._online_users else 1 + name = self._get_conv_display_name(conv).lower() + return (is_fav, is_online, name) + + def _on_conversations_loaded(self, convs): + self.conversations = convs + # Populate unread counts from server (covers messages received while offline) + for cv in convs: + cid = cv["conversation_id"] + server_unread = cv.get("unread_count", 0) + # Use the higher of server vs local (local may have newer real-time notifications) + if server_unread > self._unread_counts.get(cid, 0): + self._unread_counts[cid] = server_unread + self._rebuild_conv_list() + + def _rebuild_conv_list(self): + """Sort and rebuild the conversation list widget.""" + if not self.conversations: + return + # Sort: favorites first, then online DMs, then rest — alphabetically within each group + self.conversations.sort(key=self._get_conv_sort_key) + prev_id = self.current_conv_id + self.conv_list.blockSignals(True) + self.conv_list.clear() + select_row = -1 + for i, cv in enumerate(self.conversations): + conv_id = cv["conversation_id"] + name = self._get_conv_display_name(cv) + count = self._unread_counts.get(conv_id, 0) + is_fav = conv_id in self._favorites + # Preview + timestamp from last-message cache + preview_text, preview_ts, receipt_st = self._last_message_cache.get(conv_id, ("", "", "")) + rel_ts = _relative_time(preview_ts) + # Avatar pixmap + avatar_pix = self._get_conv_avatar_pixmap(cv) + + # Verification status (DMs only) + verified_status = "" + is_dm = len(cv["members"]) == 2 and not cv.get("name") + if is_dm: + peer_uid = self._get_conv_other_user_id(cv) + if peer_uid: + verified_status = self.bridge.client.get_verification_status(peer_uid) + + item = QListWidgetItem() + item.setSizeHint(QSize(0, ConversationDelegate.ITEM_HEIGHT)) + item.setData(ROLE_CONV_ID, conv_id) + item.setData(ROLE_DISPLAY_NAME, name) + item.setData(ROLE_PREVIEW, preview_text) + item.setData(ROLE_TIMESTAMP, rel_ts) + item.setData(ROLE_UNREAD, count) + item.setData(ROLE_IS_FAV, is_fav) + item.setData(ROLE_AVATAR, avatar_pix) + item.setData(ROLE_VERIFIED, verified_status) + item.setData(ROLE_RECEIPT, receipt_st) + self.conv_list.addItem(item) + if conv_id == prev_id: + select_row = i + self.conv_list.blockSignals(False) + if select_row >= 0: + self.conv_list.setCurrentRow(select_row) + + def _on_conversation_updated(self): + """Refresh conversation list when a conversation is created/member added/removed.""" + self.bridge.load_conversations() + + def _on_periodic_refresh(self): + """Periodic refresh: reload invitations only. Avatars are fetched on-demand.""" + self.bridge.list_invitations() + + def _on_online_users_loaded(self, user_ids): + self._online_users = set(user_ids) + self._rebuild_conv_list() + + def _on_online_status_changed(self, user_id, is_online): + if is_online: + self._online_users.add(user_id) + else: + self._online_users.discard(user_id) + self._rebuild_conv_list() + + def _avatar_disk_cache_dir(self) -> Path | None: + """Return path to avatar disk cache directory.""" + try: + from chat_core import get_key_dir + d = get_key_dir(self.bridge.client.email) / "avatar_cache" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + return d + except Exception: + return None + + def _get_local_key(self) -> bytes | None: + """Get local encryption key from ChatClient (derived from identity key).""" + try: + return self.bridge.client._local_key + except Exception: + return None + + def _save_avatar_to_disk(self, key: str, data: bytes): + """Save avatar bytes to disk cache (encrypted with local_key).""" + d = self._avatar_disk_cache_dir() + if not d: + return + try: + from chat_core import _encrypt_local + local_key = self._get_local_key() + if local_key: + enc = _encrypt_local(data, local_key) + p = d / f"{key}.bin" + p.write_bytes(enc) + os.chmod(p, 0o600) + else: + p = d / f"{key}.bin" + p.write_bytes(data) + os.chmod(p, 0o600) + except Exception: + pass + + def _load_avatar_from_disk(self, key: str) -> bytes | None: + """Load avatar bytes from disk cache (decrypts if encrypted).""" + d = self._avatar_disk_cache_dir() + if not d: + return None + p = d / f"{key}.bin" + if not p.exists(): + return None + try: + raw = p.read_bytes() + local_key = self._get_local_key() + if local_key and len(raw) > 28: + from chat_core import _decrypt_local + try: + return _decrypt_local(raw, local_key) + except Exception: + # Transparent migration: try as plaintext, re-save encrypted + qimg = _safe_load_image(raw) + if qimg is not None: + self._save_avatar_to_disk(key, raw) + return raw + return None + return raw + except Exception: + return None + + def _preload_avatars_from_disk(self): + """Load cached avatars from disk into memory on startup.""" + d = self._avatar_disk_cache_dir() + if not d: + return + for p in d.glob("*.bin"): + key = p.stem + try: + data = self._load_avatar_from_disk(key) + if data is None: + continue + qimg = _safe_load_image(data) + if qimg is None: + continue + pix = QPixmap.fromImage(qimg) + if key.startswith("g_"): + self._group_avatar_cache[key[2:]] = pix + else: + self._avatar_cache[key] = pix + except Exception: + pass + + def _on_avatar_fetch_failed(self, user_id): + """Allow retry by removing from requested set.""" + self._avatar_requested.discard(user_id) + + def _on_avatar_changed_push(self, user_id): + """Server pushed avatar_changed — re-download this user's avatar.""" + self._avatar_requested.discard(user_id) + self.bridge.get_avatar(user_id) + + def _on_avatar_for_conv_list(self, user_id, data): + """Cache downloaded avatar and refresh conversation list icons + chat header.""" + qimg = _safe_load_image(data) + if qimg is not None: + self._avatar_cache[user_id] = QPixmap.fromImage(qimg) + self._save_avatar_to_disk(user_id, data) + self._update_conv_list_styles() + # Refresh chat header avatar if current conv uses this user's avatar + self._refresh_chat_header_avatar() + + def _on_group_avatar_for_conv_list(self, conv_id, data): + """Cache downloaded group avatar and refresh conversation list icons + chat header.""" + qimg = _safe_load_image(data) + if qimg is not None: + self._group_avatar_cache[conv_id] = QPixmap.fromImage(qimg) + self._save_avatar_to_disk(f"g_{conv_id}", data) + self._update_conv_list_styles() + # Refresh chat header avatar if current conv is this group + self._refresh_chat_header_avatar() + + def _on_group_avatar_updated(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Group Avatar", msg) + + def _on_invitations_loaded(self, invitations): + self._pending_invitations = invitations + self.inv_list.clear() + if not invitations: + self.inv_label.setVisible(False) + self.inv_list.setVisible(False) + return + self.inv_label.setVisible(True) + self.inv_list.setVisible(True) + for inv in invitations: + conv_name = inv.get("conversation_name") or "Unnamed group" + inviter = inv.get("invited_by_username", "someone") + label = f"{conv_name} (from {inviter})" + item = QListWidgetItem(label) + item.setData(Qt.ItemDataRole.UserRole, inv["conversation_id"]) + self.inv_list.addItem(item) + + def _on_invitation_result(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Invitation", msg) + + def _on_invitation_received(self, data): + """New invitation received via push notification.""" + self.bridge.list_invitations() + conv_name = data.get("conversation_name") or "a group" + inviter = data.get("invited_by_username", "Someone") + t = c() + self.status_bar.setText(f"{inviter} invited you to {conv_name}") + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.warning}; font-size: 8pt; font-weight: bold;" + ) + self._status_bar_conv_id = None + QTimer.singleShot(5000, self._clear_status_bar) + + def _on_device_added(self, data): + device_name = data.get("device_name") or "Unknown device" + device_id = (data.get("device_id") or "")[:8] + ip = data.get("ip") or "unknown" + t = c() + self.status_bar.setText(f"New device added: {device_name} ({device_id}) from {ip}") + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.warning}; font-size: 8pt; font-weight: bold;" + ) + self._status_bar_conv_id = None + QTimer.singleShot(10000, self._clear_status_bar) + QMessageBox.warning( + self, + "New Device Added", + f"A new device was added to this account:\n\n" + f"Name: {device_name}\n" + f"Device ID: {device_id}\n" + f"IP: {ip}\n\n" + f"If this was not you, rotate keys immediately.", + ) + + def _on_inv_context_menu(self, pos): + item = self.inv_list.itemAt(pos) + if not item: + return + conv_id = item.data(Qt.ItemDataRole.UserRole) + if not conv_id: + return + menu = QMenu(self) + accept_action = menu.addAction("Accept") + decline_action = menu.addAction("Decline") + chosen = menu.exec(self.inv_list.mapToGlobal(pos)) + if chosen == accept_action: + self.bridge.accept_invitation(conv_id) + elif chosen == decline_action: + self.bridge.decline_invitation(conv_id) + + def _on_connection_state_changed(self, state): + t = c() + if state == "connected": + self.connection_dot.setStyleSheet(f"color: {t.success}; font-size: 11pt;") + self.connection_dot.setToolTip("Connected") + self.status_bar.setText("Connected") + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.success}; font-size: 8pt;" + ) + QTimer.singleShot(3000, self._clear_status_bar) + elif state == "disconnected": + self.connection_dot.setStyleSheet(f"color: {t.error}; font-size: 11pt;") + self.connection_dot.setToolTip("Disconnected") + self.status_bar.setText("Disconnected from server") + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.error}; font-size: 8pt; font-weight: bold;" + ) + self._status_bar_conv_id = None + elif state == "reconnecting": + self.connection_dot.setStyleSheet(f"color: {t.warning}; font-size: 11pt;") + self.connection_dot.setToolTip("Reconnecting...") + self.status_bar.setText("Reconnecting...") + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.warning}; font-size: 8pt;" + ) + self._status_bar_conv_id = None + elif state == "revoked": + self.connection_dot.setStyleSheet(f"color: {t.error}; font-size: 11pt;") + self.connection_dot.setToolTip("Access revoked") + # Clear conversation list + self.conv_list.clear() + self.conversations = [] + self._unread_counts.clear() + # Clear open conversation + self.current_conv_id = None + self.msg_input.drop_enabled = False + self.chat_header.setText("Select a conversation") + self.chat_header_avatar.setVisible(False) + self._clear_message_area() + self.msg_input.setEnabled(False) + self.group_info_btn.setVisible(False) + self.user_info_btn.setVisible(False) + self.add_member_btn.setVisible(False) + self.delete_conv_btn.setVisible(False) + QMessageBox.warning(self, "Access Revoked", + "Your keys were rotated on another device. " + "This session is no longer valid.") + + def _on_scroll_changed(self, value): + sb = self.message_area.verticalScrollBar() + self._is_near_bottom = (sb.maximum() - value) < 60 + if self._is_near_bottom: + self.jump_btn.setVisible(False) + elif sb.maximum() > 0: + self.jump_btn.setText("\u2193") + self.jump_btn.setVisible(True) + self._position_jump_btn() + + def _scroll_to_bottom(self): + sb = self.message_area.verticalScrollBar() + sb.setValue(sb.maximum()) + self.jump_btn.setVisible(False) + + def _position_jump_btn(self): + w = self.message_area.width() + self.jump_btn.move((w - 36) // 2, self.message_area.height() - 48) + + def _clear_status_bar(self): + self.status_bar.setText("") + t = c() + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.success}; font-size: 8pt;" + ) + self._status_bar_conv_id = None + + def _on_status_bar_click(self, event): + conv_id = self._status_bar_conv_id + if conv_id: + for i, c in enumerate(self.conversations): + if c["conversation_id"] == conv_id: + self.conv_list.setCurrentRow(i) + self._clear_status_bar() + break + + def _update_chat_header_avatar(self, conv): + """Set the circular avatar next to the conversation name in the chat header.""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + size = 40 + t = c() + if is_dm: + other = None + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + other = m + break + if other: + uid = other.get("user_id") or other.get("id") or "" + uname = other.get("username") or other.get("email") or "?" + if uid in self._avatar_cache: + avatar = self._make_circular_avatar(self._avatar_cache[uid], size) + else: + avatar = self._make_default_avatar(uname, size) + self.chat_header_avatar.setPixmap(avatar) + self.chat_header_avatar.setVisible(True) + # Online status text + if uid in self._online_users: + self._chat_header_status.setText("Online") + self._chat_header_status.setStyleSheet(f"color: {t.success}; font-size: 8pt;") + else: + self._chat_header_status.setText("Offline") + self._chat_header_status.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt;") + self._chat_header_status.setVisible(True) + else: + self.chat_header_avatar.setVisible(False) + self._chat_header_status.setVisible(False) + else: + conv_id = conv.get("conversation_id") or "" + gname = conv.get("name") or "G" + if conv_id in self._group_avatar_cache: + avatar = self._make_circular_avatar(self._group_avatar_cache[conv_id], size) + else: + avatar = self._make_default_avatar(gname, size) + self.chat_header_avatar.setPixmap(avatar) + self.chat_header_avatar.setVisible(True) + # Member count for groups + member_count = len(conv.get("members", [])) + self._chat_header_status.setText(f"{member_count} members") + self._chat_header_status.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt;") + self._chat_header_status.setVisible(True) + # Show E2E indicator with verification status + if is_dm: + peer_uid = "" + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + peer_uid = m.get("user_id") or m.get("id") or "" + break + v_status = self.bridge.client.get_verification_status(peer_uid) if peer_uid else "" + if v_status == "verified": + self._e2e_label.setText("\u2705 Verified") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.success}; background: transparent;") + self._e2e_label.setToolTip("Identity verified — click to view safety number") + elif v_status == "trusted": + self._e2e_label.setText("\U0001f512 Encrypted") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.text_muted}; background: transparent;") + self._e2e_label.setToolTip("End-to-end encrypted (not verified) — click to verify") + else: + self._e2e_label.setText("\U0001f512 Encrypted") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.text_muted}; background: transparent;") + self._e2e_label.setToolTip("End-to-end encrypted — click to verify") + else: + self._e2e_label.setText("\U0001f512 End-to-end encrypted") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.text_muted}; background: transparent;") + self._e2e_label.setToolTip("End-to-end encrypted") + self._e2e_label.setVisible(True) + + def _refresh_chat_header_avatar(self): + """Re-render chat header avatar for the currently selected conversation.""" + if not self.current_conv_id: + return + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + self._update_chat_header_avatar(cv) + return + + def _update_conv_list_styles(self): + """Update delegate data roles for all items (after avatar/unread changes).""" + for i in range(self.conv_list.count()): + item = self.conv_list.item(i) + conv_id = item.data(ROLE_CONV_ID) + count = self._unread_counts.get(conv_id, 0) + conv = None + for cv in self.conversations: + if cv["conversation_id"] == conv_id: + conv = cv + break + if conv: + item.setData(ROLE_DISPLAY_NAME, self._get_conv_display_name(conv)) + item.setData(ROLE_AVATAR, self._get_conv_avatar_pixmap(conv)) + item.setData(ROLE_IS_FAV, conv_id in self._favorites) + item.setData(ROLE_UNREAD, count) + # Update preview/timestamp from cache + preview_text, preview_ts, receipt_st = self._last_message_cache.get(conv_id, ("", "", "")) + item.setData(ROLE_PREVIEW, preview_text) + item.setData(ROLE_TIMESTAMP, _relative_time(preview_ts)) + item.setData(ROLE_RECEIPT, receipt_st) + + def _on_conv_selected(self, row): + if row < 0 or row >= len(self.conversations): + return + prev_conv_id = self.current_conv_id + if prev_conv_id: + self.bridge.send_typing_stop(prev_conv_id, force=True) + conv = self.conversations[row] + self.current_conv_id = conv["conversation_id"] + self._typing_label_refresh() + self.msg_input.drop_enabled = True + others = [m.get("username") or m.get("email") or "?" for m in conv["members"] + if m.get("email") != self.bridge.client.email] + header = conv.get("name") or (", ".join(others) if others else self.bridge.client.username) + self.chat_header.setText(header) + # Set avatar in chat header + self._update_chat_header_avatar(conv) + is_group = len(conv["members"]) > 2 or conv.get("name") + self._is_dm = not is_group + self.add_member_btn.setVisible(bool(is_group)) + self.group_info_btn.setVisible(bool(is_group)) + self.user_info_btn.setVisible(self._is_dm) + # DMs: always show delete. Groups: only show for creator. + if self._is_dm: + self.delete_conv_btn.setVisible(True) + else: + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + self.delete_conv_btn.setVisible(conv.get("created_by") == my_user_id) + self.reply_to_id = None + self._reply_widget.setVisible(False) + self._has_more_messages = True + self.load_more_btn.setVisible(False) + self._unread_counts.pop(self.current_conv_id, None) + self._update_conv_list_styles() + self.search_btn.setVisible(True) + self.pin_list_btn.setVisible(True) + self._pin_banner.setVisible(False) + self._close_search() + # Show cached messages instantly, then sync with server in background + cached = self.bridge.client.get_cached_messages(self.current_conv_id) + if cached: + self._on_messages_loaded(self.current_conv_id, cached) + self.bridge.load_messages(self.current_conv_id) + + def _on_e2e_label_clicked(self, event): + """Open VerificationDialog when E2E label is clicked (DMs only).""" + if not self.current_conv_id or not getattr(self, "_is_dm", False): + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + peer_uid = self._get_conv_other_user_id(conv) + if not peer_uid: + return + peer_name = "" + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + peer_name = m.get("username") or m.get("email") or "?" + break + # Ensure identity key is in cache + self.bridge.schedule(self._ensure_and_show_verification(peer_uid, peer_name)) + + async def _ensure_and_show_verification(self, peer_uid: str, peer_name: str): + """Ensure we have the peer's identity key, then show verification dialog.""" + await self.bridge.client._get_user_info(user_id=peer_uid) + # Emit signal back to Qt thread + self._show_verification_dialog_signal.emit(peer_uid, peer_name) + + def _show_verification_dialog(self, peer_uid: str, peer_name: str): + dlg = VerificationDialog(self.bridge, peer_uid, peer_name, parent=self) + dlg.exec() + # Refresh conv list and header to reflect any verification changes + self._rebuild_conv_list() + if self.current_conv_id: + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + self._update_chat_header_avatar(cv) + break + + def _on_key_change_warning(self, user_id: str, username: str, old_key_hex: str, was_verified: bool, new_key_bytes: bytes = b""): + """Show warning dialog when a contact's identity key has changed.""" + t = c() + severity = "CRITICAL" if was_verified else "WARNING" + name = username or user_id[:8] + msg = ( + f"The identity key for {name} has changed!\n\n" + f"This could mean:\n" + f"- They re-installed the app or got a new device\n" + f"- Someone may be intercepting your messages\n\n" + ) + if was_verified: + msg += "This contact was previously verified. You should re-verify." + + dlg = QDialog(self) + dlg.setMinimumWidth(400) + lay = _make_frameless(dlg, f"Identity Key Changed ({severity})") + + warning_label = QLabel(msg) + warning_label.setWordWrap(True) + warning_label.setStyleSheet(f"color: {t.text_primary};") + lay.addWidget(warning_label) + + btn_row = QHBoxLayout() + accept_btn = QPushButton("Accept New Key") + accept_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.warning}; color: {t.bg_primary}; " + f"font-weight: bold; padding: 8px 16px; border-radius: 6px; }}" + ) + accept_btn.clicked.connect(lambda: self._accept_key_change(user_id, new_key_bytes, dlg)) + btn_row.addWidget(accept_btn) + + close_btn = QPushButton("Dismiss") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(dlg.accept) + btn_row.addWidget(close_btn) + + lay.addLayout(btn_row) + dlg.exec() + + def _accept_key_change(self, user_id: str, new_key_bytes: bytes, dlg: QDialog): + if new_key_bytes: + self.bridge.client.accept_key_change(user_id, new_key_bytes) + dlg.accept() + self._rebuild_conv_list() + + def _on_messages_loaded(self, conv_id, messages): + if conv_id != self.current_conv_id: + return + self.current_messages = messages + # Update last-message cache for conversation list preview + if messages: + last = messages[-1] + preview = last.get("text", "") + if last.get("deleted"): + preview = "Message deleted" + elif last.get("image") and not preview: + preview = "Sent an image" + elif last.get("file") and not preview: + preview = "Sent a file" + receipt = self._compute_receipt_status(last) + self._last_message_cache[conv_id] = (preview[:60], last.get("created_at", ""), receipt) + self._update_conv_list_styles() + # Show "Load older" if we got a full batch (there may be more) + self._has_more_messages = len(messages) >= 50 + self.load_more_btn.setVisible(self._has_more_messages) + self._render_messages() + self._update_pin_banner() + + def _compute_receipt_status(self, m) -> str: + """Return receipt status for a message: 'read', 'delivered', 'sent', or ''.""" + my_uid = (self.bridge.client.session.get("user_id", "") + if self.bridge and self.bridge.client and self.bridge.client.session else "") + if not my_uid or m.get("sender_id") != my_uid: + return "" + read_by = m.get("read_by", []) + if any(r.get("user_id") != my_uid for r in read_by): + return "read" + delivered_to = m.get("delivered_to", []) + if any(d.get("user_id") != my_uid for d in delivered_to): + return "delivered" + return "sent" + + def _render_messages(self, scroll_to_bottom=True): + """Clear and rebuild all message bubble widgets.""" + self._msg_widgets = [] + while self._msg_layout.count() > 0: + item = self._msg_layout.takeAt(0) + w = item.widget() + if w: + w.deleteLater() + for i, m in enumerate(self.current_messages): + w = self._create_message_widget(m, i) + self._msg_layout.addWidget(w) + self._msg_widgets.append(w) + if scroll_to_bottom: + QTimer.singleShot(10, self._scroll_to_bottom) + + def _clear_message_area(self): + """Remove all message widgets from the scroll area.""" + self._msg_widgets = [] + self.current_messages = [] + while self._msg_layout.count() > 0: + item = self._msg_layout.takeAt(0) + w = item.widget() + if w: + w.deleteLater() + + def _decode_thumbnail(self, image_info): + """Decode base64 thumbnail to QPixmap.""" + thumbnail_b64 = image_info.get("thumbnail", "") + if not thumbnail_b64: + return None + from protocol import decode_binary + try: + thumb_bytes = decode_binary(thumbnail_b64) + qimg = _safe_load_image(thumb_bytes) + if qimg is not None: + return QPixmap.fromImage(qimg) + except Exception: + pass + return None + + def _create_message_widget(self, m, index): + """Create a widget tree for a single message bubble.""" + t = c() + my_uid = (self.bridge.client.session.get("user_id", "") + if self.bridge.client.session else "") + is_me = (m.get("sender_id") == my_uid if my_uid + else m.get("sender") == self.bridge.client.username) + + # -- Wrapper for left/right alignment -- + wrapper = QWidget() + wrapper.setStyleSheet("background: transparent;") + wrapper._msg_index = index + wlay = QHBoxLayout(wrapper) + wlay.setContentsMargins(12, 2, 12, 2) + wlay.setSpacing(0) + + # -- Deleted message -- + if m.get("deleted"): + ts = m.get("created_at", "") + time_str = _format_msg_time(ts) + del_bubble = MessageBubble(t.bg_secondary) + del_bubble._msg_index = index + dlay = QVBoxLayout(del_bubble) + dlay.setContentsMargins(14, 8, 14, 8) + dl = QLabel(f"\u00b7 {time_str} Message deleted") + dl.setStyleSheet( + f"color: {t.text_muted}; font-style: italic; " + f"font-size: 10pt; background: transparent;" + ) + dlay.addWidget(dl) + if is_me: + wlay.addStretch(1) + wlay.addWidget(del_bubble) + if not is_me: + wlay.addStretch(1) + return wrapper + + sender = m.get("sender", "???") + text = m.get("text", "") + + # Determine colours + if is_me: + bubble_bg = t.bubble_sent_bg + text_color = t.bubble_sent_text + meta_color = t.bubble_sent_meta + sender_color = t.accent + else: + bubble_bg = t.bubble_recv_bg + text_color = t.bubble_recv_text + meta_color = t.bubble_recv_meta + sender_color = t.sender_name_other + + if is_me: + wlay.addStretch(1) + + # -- Bubble -- + bubble = MessageBubble(bubble_bg) + bubble._msg_index = index + bubble.setSizePolicy( + QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum + ) + bubble.setMaximumWidth(600) + bubble.setMinimumWidth(80) + + blay = QVBoxLayout(bubble) + blay.setContentsMargins(14, 8, 14, 8) + blay.setSpacing(3) + + # -- Forwarded header -- + fwd = m.get("forwarded_from") + if fwd: + fwd_sender = fwd.get("sender", "???") + fwd_esc = (fwd_sender.replace("&", "&") + .replace("<", "<").replace(">", ">")) + fwd_label = QLabel( + f'' + f'Forwarded from {fwd_esc}' + f'' + ) + fwd_label.setTextFormat(Qt.TextFormat.RichText) + fwd_label.setStyleSheet( + f"background: transparent; border-left: 2px solid {t.info}; " + f"padding-left: 6px; margin-bottom: 2px;" + ) + blay.addWidget(fwd_label) + + # -- Reply context -- + if m.get("reply_to"): + for orig in self.current_messages: + if orig.get("message_id") == m.get("reply_to"): + orig_sender = orig.get("sender", "???") + orig_esc = (orig_sender.replace("&", "&") + .replace("<", "<").replace(">", ">")) + orig_text = orig.get("text", "")[:50] + orig_text_esc = (orig_text.replace("&", "&") + .replace("<", "<").replace(">", ">")) + reply_lbl = QLabel( + f'{orig_esc}
' + f'' + f'{orig_text_esc}' + ) + reply_lbl.setTextFormat(Qt.TextFormat.RichText) + reply_lbl.setWordWrap(True) + reply_lbl.setStyleSheet( + f"background: transparent; " + f"border-left: 2px solid {t.scrollbar}; " + f"padding-left: 6px; margin-bottom: 2px;" + ) + blay.addWidget(reply_lbl) + break + + # -- Header (sender name + pin — groups only) -- + timestamp = m.get("created_at", "") + sender_esc = (sender.replace("&", "&") + .replace("<", "<").replace(">", ">")) + pin_html = "" + if m.get("pinned_at"): + pin_html = f' \U0001f4cc' + + is_dm = self._is_dm + if not is_dm: + header_html = ( + f'{sender_esc}' + f'{pin_html}' + ) + header_label = QLabel(header_html) + header_label.setTextFormat(Qt.TextFormat.RichText) + header_label.setStyleSheet("background: transparent;") + blay.addWidget(header_label) + elif pin_html: + pin_label = QLabel(pin_html) + pin_label.setTextFormat(Qt.TextFormat.RichText) + pin_label.setStyleSheet("background: transparent;") + blay.addWidget(pin_label) + + # -- Suppress image placeholder text -- + image_info = m.get("image") + if image_info: + fname_raw = image_info.get("filename", "image") + if text == f"[Image: {fname_raw}]": + text = "" + + # -- Message text -- + if text: + text_html = _linkify_urls(text) + text_html = text_html.replace("\n", "
") + # Search highlighting + if (self._search_active and self._search_query + and index in self._search_results): + is_current = ( + 0 <= self._search_current < len(self._search_results) + and self._search_results[self._search_current] == index + ) + bg = t.search_current if is_current else t.search_highlight + text_html = self._highlight_search_text( + text_html, self._search_query, bg + ) + # @Mentions highlighting + import re as _re + mention_c = t.mention + text_html = _re.sub( + r'@(\w+)', + lambda mt: ( + f'' + f'@{mt.group(1)}' + ), + text_html, + ) + + text_label = QLabel(text_html) + text_label.setTextFormat(Qt.TextFormat.RichText) + text_label.setWordWrap(True) + text_label.setStyleSheet( + f"color: {text_color}; background: transparent; " + f"font-size: 11pt;" + ) + text_label.setTextInteractionFlags( + Qt.TextInteractionFlag.LinksAccessibleByMouse + ) + text_label.linkActivated.connect(self._on_link_clicked) + blay.addWidget(text_label) + + # -- Image thumbnail -- + is_uploading = m.get("_uploading", False) + if image_info: + thumb_pixmap = self._decode_thumbnail(image_info) + file_id = image_info.get("file_id", "") + filename = image_info.get("filename", "image") + size_bytes = image_info.get("size", 0) + size_str = self._human_file_size(size_bytes) + if thumb_pixmap: + scaled = thumb_pixmap.scaledToWidth( + min(200, thumb_pixmap.width()), + Qt.TransformationMode.SmoothTransformation, + ) + if is_uploading: + # Semi-transparent overlay + spinner text on thumbnail + overlay_pm = QPixmap(scaled.size()) + overlay_pm.fill(Qt.GlobalColor.transparent) + p = QPainter(overlay_pm) + p.drawPixmap(0, 0, scaled) + p.fillRect(overlay_pm.rect(), QColor(0, 0, 0, 100)) + p.setPen(QColor(255, 255, 255)) + p.setFont(QFont("sans-serif", 11)) + p.drawText(overlay_pm.rect(), Qt.AlignmentFlag.AlignCenter, + "\u2B06 Uploading...") + p.end() + img_label = QLabel() + img_label.setPixmap(overlay_pm) + img_label.setStyleSheet("background: transparent;") + else: + img_label = QLabel() + img_label.setPixmap(scaled) + img_label.setStyleSheet("background: transparent;") + img_label.setCursor(Qt.CursorShape.PointingHandCursor) + img_label.mousePressEvent = ( + lambda e, fid=file_id: self._on_image_click(fid) + ) + blay.addWidget(img_label) + + if is_uploading: + upload_label = QLabel( + f'' + f'Encrypting and uploading...' + ) + upload_label.setTextFormat(Qt.TextFormat.RichText) + upload_label.setStyleSheet("background: transparent;") + blay.addWidget(upload_label) + else: + link_color = text_color if is_me else t.accent + fname_esc = (filename.replace("&", "&") + .replace("<", "<").replace(">", ">")) + info_label = QLabel( + f'' + f'{fname_esc} ({size_str}) \u2014 Click to view' + ) + info_label.setTextFormat(Qt.TextFormat.RichText) + info_label.setStyleSheet( + f"font-size: 9pt; background: transparent;" + ) + info_label.setTextInteractionFlags( + Qt.TextInteractionFlag.LinksAccessibleByMouse + ) + info_label.linkActivated.connect(self._on_link_clicked) + blay.addWidget(info_label) + + # -- File card -- + file_info = m.get("file") + if file_info: + fname = file_info.get("filename", "file") + fname_esc = (fname.replace("&", "&") + .replace("<", "<").replace(">", ">")) + fsize = file_info.get("size", 0) + size_str = self._human_file_size(fsize) + f_id = file_info.get("file_id", "") + icon = self._file_icon(fname) + file_link_color = text_color if is_me else t.accent + + file_frame = QFrame() + file_frame.setStyleSheet( + f"QFrame {{ background: transparent; " + f"border: 1px solid {meta_color}; " + f"border-radius: 6px; padding: 8px; }}" + ) + flay = QHBoxLayout(file_frame) + flay.setContentsMargins(0, 0, 0, 0) + file_label = QLabel( + f'{icon} {fname_esc}' + f' ({size_str})' + ) + file_label.setTextFormat(Qt.TextFormat.RichText) + file_label.setStyleSheet("background: transparent; border: none;") + file_label.setTextInteractionFlags( + Qt.TextInteractionFlag.LinksAccessibleByMouse + ) + file_label.linkActivated.connect(self._on_link_clicked) + flay.addWidget(file_label) + blay.addWidget(file_frame) + + # -- Reaction badges -- + reactions = m.get("reactions", []) + if reactions: + _REACTION_EMOJI = { + "thumbsup": "\U0001f44d", "heart": "\u2764\ufe0f", + "laugh": "\U0001f602", "surprised": "\U0001f62e", + "sad": "\U0001f622", "thumbsdown": "\U0001f44e", + } + grouped = {} + for r in reactions: + grouped.setdefault(r["reaction"], []).append(r["user_id"]) + my_id = (self.bridge.client.session.get("user_id", "") + if self.bridge.client.session else "") + react_widget = QWidget() + react_widget.setStyleSheet("background: transparent;") + react_lay = QHBoxLayout(react_widget) + react_lay.setContentsMargins(0, 2, 0, 0) + react_lay.setSpacing(4) + for rkey, uids in grouped.items(): + emoji = _REACTION_EMOJI.get(rkey, rkey) + count = len(uids) + is_mine = my_id in uids + bg = t.reaction_bg_own if is_mine else t.reaction_bg + bdr = t.reaction_border_own if is_mine else t.reaction_border + badge = QLabel(f"{emoji} {count}") + badge.setStyleSheet( + f"background-color: {bg}; color: {t.text_primary}; " + f"border: 1px solid {bdr}; " + f"border-radius: 10px; padding: 2px 6px; font-size: 10pt;" + ) + react_lay.addWidget(badge) + react_lay.addStretch() + blay.addWidget(react_widget) + + # -- Footer (timestamp + receipt, right-aligned, below content) -- + time_str = _format_msg_time(timestamp) + receipt_status = "" + if is_me: + read_by = m.get("read_by", []) + others_read = [r for r in read_by if r.get("user_id") != my_uid] + delivered_to = m.get("delivered_to", []) + others_delivered = [d for d in delivered_to if d.get("user_id") != my_uid] + if others_read: + receipt_status = "read" + elif others_delivered: + receipt_status = "delivered" + else: + receipt_status = "sent" + footer_w = _ReceiptFooter(time_str, receipt_status, meta_color, + t.bubble_sent_text, t.success) + footer_w.setStyleSheet("background: transparent;") + blay.addWidget(footer_w, alignment=Qt.AlignmentFlag.AlignRight) + + wlay.addWidget(bubble) + if not is_me: + wlay.addStretch(1) + + # Install event filter on all child widgets for context menu propagation + for child in bubble.findChildren(QWidget): + child.setContextMenuPolicy(Qt.ContextMenuPolicy.NoContextMenu) + child.installEventFilter(self) + bubble.installEventFilter(self) + wrapper.installEventFilter(self) + + return wrapper + + def _on_image_click(self, file_id): + """Handle click on an image thumbnail in a message bubble.""" + for msg in self.current_messages: + image_info = msg.get("image") + if image_info and image_info.get("file_id") == file_id: + self._view_image(msg) + return + + def _find_msg_index_at_widget(self, widget): + """Walk up the widget tree to find the _msg_index attribute.""" + while widget: + if hasattr(widget, '_msg_index'): + return widget._msg_index + widget = widget.parentWidget() + return None + + def _on_older_messages_loaded(self, conv_id, messages): + if conv_id != self.current_conv_id: + return + if not messages: + self._has_more_messages = False + self.load_more_btn.setVisible(False) + return + self._has_more_messages = len(messages) >= 50 + self.load_more_btn.setVisible(self._has_more_messages) + # Prepend older messages and re-render + self.current_messages = messages + self.current_messages + self._render_messages(scroll_to_bottom=False) + + def _on_load_more(self): + if not self.current_conv_id or not self._has_more_messages: + return + offset = len(self.current_messages) + self.bridge.load_older_messages(self.current_conv_id, offset) + + def _on_message_context_menu(self, pos): + if not self.current_messages: + return + global_pos = self._msg_scroll_area.viewport().mapToGlobal(pos) + widget = QApplication.widgetAt(global_pos) + idx = self._find_msg_index_at_widget(widget) + if idx is not None: + self._show_msg_context_menu(idx, global_pos) + + def _show_msg_context_menu(self, idx, global_pos): + if idx < 0 or idx >= len(self.current_messages): + return + m = self.current_messages[idx] + if m.get("deleted"): + return + + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + menu = QMenu(self) + + reply_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_ArrowBack) + reply_action = menu.addAction(reply_icon, "Reply") + + # Reaction submenu + react_menu = menu.addMenu("React") + _REACTION_LABELS = { + "thumbsup": "\U0001f44d +1", + "heart": "\u2764\ufe0f Heart", + "laugh": "\U0001f602 Haha", + "surprised": "\U0001f62e Wow", + "sad": "\U0001f622 Sad", + "thumbsdown": "\U0001f44e -1", + } + react_actions = {} + for rkey, rlabel in _REACTION_LABELS.items(): + act = react_menu.addAction(rlabel) + react_actions[act] = rkey + + # Forward + fwd_action = menu.addAction("Forward") + + # Pin / Unpin + pin_action = None + if m.get("pinned_at"): + pin_action = menu.addAction("\U0001f4cc Unpin") + else: + pin_action = menu.addAction("\U0001f4cc Pin") + + menu.addSeparator() + + # Delete option for own messages + del_action = None + if m.get("sender_id") == my_user_id: + del_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_TrashIcon) + del_action = menu.addAction(del_icon, "Delete") + + # View image option + img_action = None + if m.get("image"): + img_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogContentsView) + img_action = menu.addAction(img_icon, "View image") + + # Download file option + file_action = None + if m.get("file"): + file_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton) + file_action = menu.addAction(file_icon, "Download file") + + # Reset session option for undecryptable messages + reset_action = None + if m.get("text", "").startswith("[Decryption failed"): + reset_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload) + reset_action = menu.addAction(reset_icon, "Reset session with sender") + + chosen = menu.exec(global_pos) + if not chosen: + return + if chosen == reply_action: + self.reply_to_id = m["message_id"] + sender = m.get("sender", "???") + preview = m.get("text", "")[:40] + self.reply_label.setText(f"Replying to {sender}: {preview}") + self._reply_widget.setVisible(True) + self.msg_input.setFocus() + elif chosen in react_actions: + rkey = react_actions[chosen] + existing = m.get("reactions", []) + has_it = any(r["user_id"] == my_user_id and r["reaction"] == rkey for r in existing) + if has_it: + m["reactions"] = [r for r in existing if r["user_id"] != my_user_id] + self.bridge.react_message(m["message_id"], rkey, "remove") + else: + new_reactions = [r for r in existing if r["user_id"] != my_user_id] + new_reactions.append({"user_id": my_user_id, "reaction": rkey, "created_at": ""}) + m["reactions"] = new_reactions + self.bridge.react_message(m["message_id"], rkey, "add") + # Persist to local cache so reactions survive conversation switch + self.bridge.client.update_message_in_cache( + self.current_conv_id, m["message_id"], + {"reactions": m["reactions"]}) + self._render_messages(scroll_to_bottom=self._is_near_bottom) + elif chosen == fwd_action: + self._show_forward_dialog(m) + elif chosen == pin_action: + if m.get("pinned_at"): + m.pop("pinned_at", None) + m.pop("pinned_by", None) + self.bridge.pin_message(m["message_id"], self.current_conv_id, "unpin") + self.bridge.client.update_message_in_cache( + self.current_conv_id, m["message_id"], + {"pinned_at": None, "pinned_by": None}) + else: + my_user_id_pin = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + m["pinned_at"] = "now" + m["pinned_by"] = my_user_id_pin + self.bridge.pin_message(m["message_id"], self.current_conv_id, "pin") + self.bridge.client.update_message_in_cache( + self.current_conv_id, m["message_id"], + {"pinned_at": "now", "pinned_by": my_user_id_pin}) + self._render_messages(scroll_to_bottom=self._is_near_bottom) + self._update_pin_banner() + elif chosen == del_action: + if self._confirm_dialog("Delete Message", + "Delete this message? This cannot be undone."): + # Apply locally immediately (server notification only goes to others) + m["deleted"] = True + m["text"] = "" + m["image"] = None + m["file"] = None + self.bridge.client.update_message_in_cache( + self.current_conv_id, m["message_id"], {"deleted": True}) + self._render_messages(scroll_to_bottom=self._is_near_bottom) + self.bridge.delete_message(m["message_id"]) + elif chosen == img_action: + self._view_image(m) + elif chosen == file_action: + file_info = m.get("file") + if file_info: + self.bridge.download_file(file_info["file_id"], file_info) + elif chosen == reset_action: + sender_id = m.get("sender_id", "") + if sender_id: + if self._confirm_dialog("Reset Session", + "Reset encryption session with this sender? " + "A new session will be created on the next message."): + self.bridge.reset_session(sender_id) + + # ------------------------------------------------------------------ + # Search + # ------------------------------------------------------------------ + + def _toggle_search(self): + if self._search_active: + self._close_search() + else: + if not self.current_conv_id: + return + self._search_active = True + self.search_widget.setVisible(True) + self.search_input.setFocus() + self.search_input.selectAll() + + def _close_search(self): + self._search_active = False + self._search_query = "" + self._search_results = [] + self._search_current = -1 + self.search_widget.setVisible(False) + self.search_input.clear() + self.search_count_label.setText("0/0") + # Re-render to remove highlights + if self.current_messages: + self._render_messages(scroll_to_bottom=False) + + def _on_search_text_changed(self, text): + self._search_query = text.strip() + if not self._search_query: + self._search_results = [] + self._search_current = -1 + self.search_count_label.setText("0/0") + self._render_messages(scroll_to_bottom=False) + return + query_lower = self._search_query.lower() + self._search_results = [] + for i, m in enumerate(self.current_messages): + if m.get("deleted"): + continue + msg_text = m.get("text", "") + if query_lower in msg_text.lower(): + self._search_results.append(i) + if self._search_results: + self._search_current = 0 + self.search_count_label.setText(f"1/{len(self._search_results)}") + else: + self._search_current = -1 + self.search_count_label.setText("0/0") + self._render_messages(scroll_to_bottom=False) + if self._search_results: + self._scroll_to_message(self._search_results[self._search_current]) + + def _on_search_next(self): + if not self._search_results: + return + self._search_current = (self._search_current + 1) % len(self._search_results) + self.search_count_label.setText(f"{self._search_current + 1}/{len(self._search_results)}") + self._render_messages(scroll_to_bottom=False) + self._scroll_to_message(self._search_results[self._search_current]) + + def _on_search_prev(self): + if not self._search_results: + return + self._search_current = (self._search_current - 1) % len(self._search_results) + self.search_count_label.setText(f"{self._search_current + 1}/{len(self._search_results)}") + self._render_messages(scroll_to_bottom=False) + self._scroll_to_message(self._search_results[self._search_current]) + + def _scroll_to_message(self, index): + if 0 <= index < len(self._msg_widgets): + self._msg_scroll_area.ensureWidgetVisible( + self._msg_widgets[index], 0, 50 + ) + + @staticmethod + def _highlight_search_text(html_text: str, query: str, bg_color: str) -> str: + """Highlight matching text in HTML, skipping content inside tags.""" + query_esc = query.replace("&", "&").replace("<", "<").replace(">", ">") + if not query_esc: + return html_text + result = [] + i = 0 + q_lower = query_esc.lower() + q_len = len(query_esc) + while i < len(html_text): + if html_text[i] == '<': + # Skip HTML tags + end = html_text.find('>', i) + if end == -1: + result.append(html_text[i:]) + break + result.append(html_text[i:end + 1]) + i = end + 1 + else: + # Look for match in text content + chunk_end = html_text.find('<', i) + if chunk_end == -1: + chunk_end = len(html_text) + chunk = html_text[i:chunk_end] + # Case-insensitive replace within this chunk + chunk_lower = chunk.lower() + out = [] + j = 0 + while j < len(chunk): + pos = chunk_lower.find(q_lower, j) + if pos == -1: + out.append(chunk[j:]) + break + out.append(chunk[j:pos]) + matched = chunk[pos:pos + q_len] + out.append(f'{matched}') + j = pos + q_len + result.append("".join(out)) + i = chunk_end + return "".join(result) + + # ------------------------------------------------------------------ + # Session reset + # ------------------------------------------------------------------ + + def _on_session_reset(self, from_user_id, from_device_id): + # Find username for the user + username = from_user_id[:8] + for cv in self.conversations: + for m in cv.get("members", []): + uid = m.get("user_id") or m.get("id") + if uid == from_user_id: + username = m.get("username") or m.get("email") or username + break + self.status_bar.setText(f"Session with {username} was reset. New session will be created on next message.") + t = c() + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.warning}; font-size: 8pt; font-weight: bold;" + ) + QTimer.singleShot(8000, self._clear_status_bar) + + # ------------------------------------------------------------------ + # Reactions / Pins / Forward notification handlers + # ------------------------------------------------------------------ + + def _on_reaction_notification(self, data): + """Handle message_reacted push notification.""" + conv_id = data.get("conversation_id", "") + msg_id = data.get("message_id", "") + user_id = data.get("user_id", "") + reaction = data.get("reaction", "") + action = data.get("action", "add") + + if conv_id != self.current_conv_id: + return + for msg in self.current_messages: + if msg.get("message_id") == msg_id: + reactions = msg.get("reactions", []) + if action == "add": + reactions = [r for r in reactions if r["user_id"] != user_id] + reactions.append({"user_id": user_id, "reaction": reaction, "created_at": ""}) + else: + reactions = [r for r in reactions if r["user_id"] != user_id] + msg["reactions"] = reactions + self.bridge.client.update_message_in_cache( + conv_id, msg_id, {"reactions": reactions}) + break + self._render_messages(scroll_to_bottom=self._is_near_bottom) + + def _on_pin_notification(self, data): + """Handle message_pinned push notification.""" + conv_id = data.get("conversation_id", "") + msg_id = data.get("message_id", "") + user_id = data.get("user_id", "") + if conv_id != self.current_conv_id: + return + for msg in self.current_messages: + if msg.get("message_id") == msg_id: + msg["pinned_at"] = "now" + msg["pinned_by"] = user_id + self.bridge.client.update_message_in_cache( + conv_id, msg_id, {"pinned_at": "now", "pinned_by": user_id}) + break + self._render_messages(scroll_to_bottom=self._is_near_bottom) + self._update_pin_banner() + username = data.get("username", user_id[:8] if user_id else "?") + self.status_bar.setText(f"{username} pinned a message") + QTimer.singleShot(3000, self._clear_status_bar) + + def _on_unpin_notification(self, data): + """Handle message_unpinned push notification.""" + conv_id = data.get("conversation_id", "") + msg_id = data.get("message_id", "") + if conv_id != self.current_conv_id: + return + for msg in self.current_messages: + if msg.get("message_id") == msg_id: + msg.pop("pinned_at", None) + msg.pop("pinned_by", None) + self.bridge.client.update_message_in_cache( + conv_id, msg_id, {"pinned_at": None, "pinned_by": None}) + break + self._render_messages(scroll_to_bottom=self._is_near_bottom) + self._update_pin_banner() + + def _on_pinned_messages_loaded(self, conv_id, pinned): + """Show pinned messages dialog.""" + if conv_id != self.current_conv_id: + return + if not pinned: + dlg = QDialog(self) + dlg.setMinimumWidth(300) + lay = _make_frameless(dlg, "Pinned Messages") + t = c() + lbl = QLabel("No pinned messages in this conversation.") + lbl.setStyleSheet(f"color: {t.text_primary}; font-size: 10pt;") + lbl.setWordWrap(True) + lay.addWidget(lbl) + ok_btn = QPushButton("OK") + ok_btn.clicked.connect(dlg.accept) + lay.addWidget(ok_btn) + dlg.exec() + return + dlg = QDialog(self) + dlg.setMinimumSize(360, 300) + t = c() + layout = _make_frameless(dlg, "Pinned Messages") + lw = QListWidget() + lw.setStyleSheet( + f"QListWidget {{ background-color:{t.bg_secondary}; border:1px solid {t.border}; border-radius:6px; }}" + f"QListWidget::item {{ padding:8px; color:{t.text_primary}; border-bottom:1px solid {t.border}; }}" + f"QListWidget::item:selected {{ background-color:{t.bg_hover}; }}" + ) + # Build a sender map from current messages + sender_map = {} + for m in self.current_messages: + sid = m.get("sender_id", "") + if sid and sid not in sender_map: + sender_map[sid] = m.get("sender", sid[:8]) + for p in pinned: + sender = sender_map.get(p.get("sender_id", ""), p.get("sender_id", "?")[:8]) + # Find message text from current_messages + text_preview = "" + for m in self.current_messages: + if m.get("message_id") == p.get("message_id"): + text_preview = m.get("text", "")[:60] + if not sender_map.get(p.get("sender_id", "")): + sender = m.get("sender", sender) + break + ts = p.get("pinned_at", "")[:16] if p.get("pinned_at") else "" + item = QListWidgetItem(f"\U0001f4cc {sender}: {text_preview}\n Pinned: {ts}") + item.setData(Qt.ItemDataRole.UserRole, p.get("message_id")) + lw.addItem(item) + layout.addWidget(lw) + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(dlg.accept) + layout.addWidget(close_btn) + + def _on_item_clicked(item): + target_id = item.data(Qt.ItemDataRole.UserRole) + if target_id: + for i, msg in enumerate(self.current_messages): + if msg.get("message_id") == target_id: + self._scroll_to_message(i) + break + dlg.accept() + lw.itemDoubleClicked.connect(_on_item_clicked) + dlg.exec() + + def _update_pin_banner(self): + """Update the pinned message banner from current_messages.""" + pinned = [m for m in self.current_messages if m.get("pinned_at")] + if not pinned: + self._pin_banner.setVisible(False) + self._pin_banner_msg_id = None + return + # Show the most recently pinned message + latest = pinned[-1] + sender = latest.get("sender", "?") + text = latest.get("text", "") + if len(text) > 80: + text = text[:80] + "..." + text = text.replace("\n", " ") + # HTML-escape user-controlled data to prevent injection + sender_esc = sender.replace("&", "&").replace("<", "<").replace(">", ">") + text_esc = text.replace("&", "&").replace("<", "<").replace(">", ">") + count_str = f" ({len(pinned)} pinned)" if len(pinned) > 1 else "" + self._pin_banner_label.setText(f"{sender_esc}: {text_esc}{count_str}") + self._pin_banner_msg_id = latest.get("message_id") + self._pin_banner.setVisible(True) + + def _on_pin_banner_clicked(self, event): + """Scroll to the pinned message when banner is clicked.""" + if self._pin_banner_msg_id: + for i, msg in enumerate(self.current_messages): + if msg.get("message_id") == self._pin_banner_msg_id: + self._scroll_to_message(i) + break + + def _show_pinned_messages(self): + """Fetch and show pinned messages for current conversation.""" + if self.current_conv_id: + self.bridge.get_pinned_messages(self.current_conv_id) + + def _on_forward_result(self, ok, msg): + if ok: + self.status_bar.setText("Message forwarded") + QTimer.singleShot(3000, self._clear_status_bar) + else: + self.status_bar.setText(f"Forward failed: {msg}") + QTimer.singleShot(5000, self._clear_status_bar) + + def _confirm_dialog(self, title: str, text: str) -> bool: + """Show a frameless confirmation dialog. Returns True if user confirmed.""" + dlg = QDialog(self) + dlg.setMinimumWidth(340) + layout = _make_frameless(dlg, title) + t = c() + label = QLabel(text) + label.setWordWrap(True) + label.setStyleSheet(f"color: {t.text_primary}; font-size: 10pt;") + layout.addWidget(label) + btn_lay = QHBoxLayout() + btn_lay.setSpacing(8) + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("secondaryBtn") + confirm_btn = QPushButton("Delete") + confirm_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.error}; color: {t.accent_text}; " + f"border: none; border-radius: 6px; padding: 8px 16px; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.warning}; }}" + ) + btn_lay.addStretch() + btn_lay.addWidget(cancel_btn) + btn_lay.addWidget(confirm_btn) + layout.addLayout(btn_lay) + cancel_btn.clicked.connect(dlg.reject) + confirm_btn.clicked.connect(dlg.accept) + return dlg.exec() == QDialog.DialogCode.Accepted + + def _show_forward_dialog(self, msg): + """Show dialog to pick a conversation to forward a message to.""" + dlg = QDialog(self) + dlg.setMinimumSize(320, 360) + t = c() + layout = _make_frameless(dlg, "Forward message") + label = QLabel("Select conversation:") + label.setStyleSheet(f"color:{t.text_primary}; font-size:10pt;") + layout.addWidget(label) + conv_list = QListWidget() + conv_list.setStyleSheet( + f"QListWidget {{ background-color:{t.bg_secondary}; border:1px solid {t.border}; border-radius:6px; }}" + f"QListWidget::item {{ padding:8px; color:{t.text_primary}; }}" + f"QListWidget::item:selected {{ background-color:{t.bg_hover}; }}" + ) + for cv in self.conversations: + if cv["conversation_id"] != self.current_conv_id: + name = cv.get("name") + if not name: + others = [m.get("username") or m.get("email") or "?" for m in cv["members"] + if m.get("email") != self.bridge.client.email] + name = ", ".join(others) if others else "?" + item = QListWidgetItem(name) + item.setData(Qt.ItemDataRole.UserRole, cv) + conv_list.addItem(item) + layout.addWidget(conv_list) + fwd_btn = QPushButton("Forward") + fwd_btn.setObjectName("secondaryBtn") + layout.addWidget(fwd_btn) + + def _do_forward(): + sel = conv_list.currentItem() + if not sel: + return + target_conv = sel.data(Qt.ItemDataRole.UserRole) + if target_conv: + fwd_msg = dict(msg) + fwd_msg["conversation_id"] = self.current_conv_id + self.bridge.forward_message( + target_conv["conversation_id"], fwd_msg, target_conv["members"] + ) + dlg.accept() + + fwd_btn.clicked.connect(_do_forward) + conv_list.itemDoubleClicked.connect(lambda _: _do_forward()) + dlg.exec() + + # ------------------------------------------------------------------ + # @Mentions autocomplete + # ------------------------------------------------------------------ + + def _setup_mention_completer(self): + """Set up the mention autocomplete popup for msg_input.""" + self._mention_popup = QListWidget(self) + self._mention_popup.setWindowFlags(Qt.WindowType.Popup) + t = c() + self._mention_popup.setStyleSheet( + f"QListWidget {{ background:{t.bg_secondary}; color:{t.text_primary}; border:1px solid {t.border}; " + f"border-radius:4px; font-size:10pt; }}" + f"QListWidget::item {{ padding:4px 8px; }}" + f"QListWidget::item:selected {{ background:{t.bg_hover}; }}" + ) + self._mention_popup.setMaximumHeight(150) + self._mention_popup.itemClicked.connect(self._on_mention_selected) + self._mention_popup.hide() + self._mention_active = False + + def _check_mention_trigger(self): + """Check if user is typing @mention and show autocomplete.""" + if not hasattr(self, '_mention_popup'): + self._setup_mention_completer() + text = self.msg_input.toPlainText() + cursor_pos = self.msg_input.textCursor().position() + # Find @word at cursor position + before = text[:cursor_pos] + import re as _re + match = _re.search(r'@(\w*)$', before) + if not match: + self._mention_popup.hide() + self._mention_active = False + return + prefix = match.group(1).lower() + self._mention_start = match.start() + + # Get members of current conversation + members = [] + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + members = cv.get("members", []) + break + my_email = self.bridge.client.email if self.bridge.client else "" + candidates = [] + for m in members: + uname = m.get("username") or m.get("email") or "" + if m.get("email") == my_email: + continue + if prefix == "" or uname.lower().startswith(prefix): + candidates.append(uname) + + if not candidates: + self._mention_popup.hide() + self._mention_active = False + return + + self._mention_popup.clear() + for cand in candidates[:6]: + self._mention_popup.addItem(cand) + # Position popup above the input + cursor_rect = self.msg_input.cursorRect() + global_pos = self.msg_input.mapToGlobal(cursor_rect.bottomLeft()) + self._mention_popup.move(global_pos.x(), global_pos.y() - self._mention_popup.sizeHint().height() - 5) + self._mention_popup.setFixedWidth(max(200, self.msg_input.width() // 2)) + self._mention_popup.show() + self._mention_active = True + + def _on_mention_selected(self, item): + """Insert the selected @mention into msg_input.""" + username = item.text() + text = self.msg_input.toPlainText() + cursor_pos = self.msg_input.textCursor().position() + # Replace @prefix with @username + new_text = text[:self._mention_start] + f"@{username} " + text[cursor_pos:] + self.msg_input.setPlainText(new_text) + cursor = self.msg_input.textCursor() + cursor.setPosition(self._mention_start + len(username) + 2) + self.msg_input.setTextCursor(cursor) + self._mention_popup.hide() + self._mention_active = False + self.msg_input.setFocus() + + def _on_input_changed(self): + text = self.msg_input.toPlainText() + count = len(text) + if count > MAX_INPUT_CHARS: + cursor = self.msg_input.textCursor() + self.msg_input.setPlainText(text[:MAX_INPUT_CHARS]) + cursor.movePosition(cursor.MoveOperation.End) + self.msg_input.setTextCursor(cursor) + count = MAX_INPUT_CHARS + color = c().error if count > MAX_INPUT_CHARS * 0.9 else c().text_muted + self.char_counter.setStyleSheet(f"color: {color}; font-size: 8pt; padding: 0 4px;") + self.char_counter.setText(f"{count} / {MAX_INPUT_CHARS}") + # @Mention autocomplete check + if self.current_conv_id: + if text.strip(): + self.bridge.send_typing_start(self.current_conv_id) + else: + self.bridge.send_typing_stop(self.current_conv_id) + self._check_mention_trigger() + + def _cancel_reply(self): + """Dismiss the reply preview.""" + self.reply_to_id = None + self.reply_label.setText("") + self._reply_widget.setVisible(False) + + def _on_send(self): + text = self.msg_input.toPlainText().strip() + if not text or not self.current_conv_id: + return + if len(text) > MAX_INPUT_CHARS: + QMessageBox.warning(self, "Message Too Long", + f"Message too long (max {MAX_INPUT_CHARS} characters).") + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + conv_id = self.current_conv_id + reply_to = self.reply_to_id + self.msg_input.clear() + self.bridge.send_typing_stop(conv_id, force=True) + self.reply_to_id = None + self._reply_widget.setVisible(False) + + # Optimistic display — show message immediately before server confirms + my_uid = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + optimistic_payload = { + "sender": self.bridge.client.username, + "text": text, + "reply_to": reply_to, + "message_id": "", # Filled when server responds + "created_at": "", + "sender_id": my_uid, + "conversation_id": conv_id, + "read_by": [], + "delivered_to": [], + "_optimistic": True, # Marker for optimistic message + } + if conv_id == self.current_conv_id: + self.current_messages.append(optimistic_payload) + idx = len(self.current_messages) - 1 + w = self._create_message_widget(optimistic_payload, idx) + self._msg_layout.addWidget(w) + self._msg_widgets.append(w) + if self._is_near_bottom: + QTimer.singleShot(10, self._scroll_to_bottom) + + self.bridge.send_message(conv_id, text, conv["members"], reply_to=reply_to) + + def _on_message_sent(self, ok, msg): + if not ok: + self._remove_optimistic() + QMessageBox.warning(self, "Error", msg) + + def _on_message_queued(self, count: int): + """Show status bar feedback when messages are queued for retry after reconnect.""" + t = c() + if count > 0: + label = "1 message" if count == 1 else f"{count} messages" + self.status_bar.setText(f"{label} queued — will send when reconnected") + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.warning}; font-size: 8pt;" + ) + self._status_bar_conv_id = None + else: + # Queue was flushed (messages retried after reconnect) + self._clear_status_bar() + + def _remove_optimistic(self, opt_type: str = ""): + """Remove the last optimistic message of given type from UI.""" + for i in range(len(self.current_messages) - 1, -1, -1): + m = self.current_messages[i] + if m.get("_optimistic") and (not opt_type or m.get("_optimistic_type", "") == opt_type): + self.current_messages.pop(i) + if i < len(self._msg_widgets): + w = self._msg_widgets.pop(i) + self._msg_layout.removeWidget(w) + w.deleteLater() + break + + def _show_optimistic_attachment(self, conv_id: str, is_image: bool, filename: str = "", + reply_to: str | None = None, + image_path: str | None = None): + """Show an optimistic placeholder for image/file being uploaded.""" + my_uid = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + optimistic_payload = { + "sender": self.bridge.client.username, + "text": "" if is_image else f"Sending {filename}...", + "reply_to": reply_to, + "message_id": "", + "created_at": "", + "sender_id": my_uid, + "conversation_id": conv_id, + "read_by": [], + "delivered_to": [], + "_optimistic": True, + "_optimistic_type": "image" if is_image else "file", + } + # For images, generate a local thumbnail for immediate preview + if is_image and image_path: + try: + raw = open(image_path, "rb").read() + qimg = QImage() + qimg.loadFromData(raw) + if not qimg.isNull(): + thumb = qimg.scaled( + 200, 200, + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, + ) + buf = QBuffer() + buf.open(QBuffer.OpenModeFlag.WriteOnly) + thumb.save(buf, "JPEG", 60) + from protocol import encode_binary + thumb_b64 = encode_binary(bytes(buf.data())) + display_name = filename or os.path.basename(image_path) + optimistic_payload["image"] = { + "file_id": "", + "thumbnail": thumb_b64, + "filename": display_name, + "size": os.path.getsize(image_path), + } + optimistic_payload["text"] = "" + optimistic_payload["_uploading"] = True + else: + logger.warning("QImage failed to load: %s", image_path) + optimistic_payload["text"] = "Sending image..." + except Exception as exc: + logger.warning("Optimistic thumbnail failed: %s", exc) + optimistic_payload["text"] = "Sending image..." + if conv_id == self.current_conv_id: + self.current_messages.append(optimistic_payload) + idx = len(self.current_messages) - 1 + w = self._create_message_widget(optimistic_payload, idx) + self._msg_layout.addWidget(w) + self._msg_widgets.append(w) + if self._is_near_bottom: + QTimer.singleShot(10, self._scroll_to_bottom) + + def _on_message_sent_payload(self, conv_id, payload): + """Update optimistic message with server-confirmed data, or append if no optimistic.""" + # Update last-message cache + preview = payload.get("text", "") + if payload.get("image") and not preview: + preview = "Sent an image" + elif payload.get("file") and not preview: + preview = "Sent a file" + self._last_message_cache[conv_id] = (preview[:60], payload.get("created_at", ""), "sent") + self._update_conv_list_styles() + + if conv_id != self.current_conv_id: + return + msg_id = payload.get("message_id", "") + # Avoid duplicate if notification arrived first (race) + if msg_id: + for m in self.current_messages: + if m.get("message_id") == msg_id: + return + # Find and update the optimistic message + # Match by: text content (text msgs), or _optimistic_type (image/file) + is_image_payload = bool(payload.get("image")) + is_file_payload = bool(payload.get("file")) + for i, m in enumerate(self.current_messages): + if not m.get("_optimistic"): + continue + opt_type = m.get("_optimistic_type", "") + matched = False + if opt_type == "image" and is_image_payload: + matched = True + elif opt_type == "file" and is_file_payload: + matched = True + elif not opt_type and m.get("text") == payload.get("text"): + matched = True + if matched: + # Upgrade optimistic → confirmed + self.current_messages[i] = payload + # Rebuild the widget for this message + if i < len(self._msg_widgets): + old_w = self._msg_widgets[i] + new_w = self._create_message_widget(payload, i) + self._msg_layout.replaceWidget(old_w, new_w) + old_w.deleteLater() + self._msg_widgets[i] = new_w + return + # No optimistic found — append normally (e.g. image/file sends) + self.current_messages.append(payload) + idx = len(self.current_messages) - 1 + w = self._create_message_widget(payload, idx) + self._msg_layout.addWidget(w) + self._msg_widgets.append(w) + if self._is_near_bottom: + QTimer.singleShot(10, self._scroll_to_bottom) + + def _on_new_chat(self): + dlg = QDialog(self) + dlg.setMinimumWidth(400) + t = c() + lay = _make_frameless(dlg, "New Chat") + lay.setSpacing(12) + + email_label = QLabel("Recipient email") + email_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(email_label) + email_input = QLineEdit() + email_input.setPlaceholderText("user@example.com") + email_input.setMinimumHeight(36) + lay.addWidget(email_input) + + msg_label = QLabel("First message") + msg_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(msg_label) + msg_input = QLineEdit() + msg_input.setPlaceholderText("Type a message...") + msg_input.setMinimumHeight(36) + lay.addWidget(msg_input) + + btn_row = QHBoxLayout() + btn_row.addStretch() + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("secondaryBtn") + cancel_btn.clicked.connect(dlg.reject) + btn_row.addWidget(cancel_btn) + send_btn = QPushButton("Send") + send_btn.clicked.connect(dlg.accept) + btn_row.addWidget(send_btn) + lay.addLayout(btn_row) + + msg_input.returnPressed.connect(dlg.accept) + email_input.returnPressed.connect(lambda: msg_input.setFocus()) + + if dlg.exec() != QDialog.DialogCode.Accepted: + return + email = email_input.text().strip() + text = msg_input.text().strip() + if not email or not text: + return + if len(text) > MAX_INPUT_CHARS: + QMessageBox.warning(self, "Message Too Long", + f"Message too long (max {MAX_INPUT_CHARS} characters).") + return + self.bridge.send_new_chat(email, text) + + def _on_new_group(self): + dlg = QDialog(self) + dlg.setMinimumWidth(400) + t = c() + lay = _make_frameless(dlg, "New Group") + lay.setSpacing(12) + + name_label = QLabel("Group name") + name_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(name_label) + name_input = QLineEdit() + name_input.setPlaceholderText("My Group") + name_input.setMinimumHeight(36) + lay.addWidget(name_input) + + members_label = QLabel("Member emails (comma-separated)") + members_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(members_label) + members_input = QLineEdit() + members_input.setPlaceholderText("alice@example.com, bob@example.com") + members_input.setMinimumHeight(36) + lay.addWidget(members_input) + + btn_row = QHBoxLayout() + btn_row.addStretch() + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("secondaryBtn") + cancel_btn.clicked.connect(dlg.reject) + btn_row.addWidget(cancel_btn) + create_btn = QPushButton("Create") + create_btn.clicked.connect(dlg.accept) + btn_row.addWidget(create_btn) + lay.addLayout(btn_row) + + members_input.returnPressed.connect(dlg.accept) + name_input.returnPressed.connect(lambda: members_input.setFocus()) + + if dlg.exec() != QDialog.DialogCode.Accepted: + return + members = members_input.text().strip() + if not members: + return + member_list = [m.strip() for m in members.split(",") if m.strip()] + if member_list: + self.bridge.create_group(member_list, name=name_input.text().strip() or None) + + def _on_add_member(self): + if not self.current_conv_id: + return + dlg = QDialog(self) + dlg.setMinimumWidth(360) + t = c() + lay = _make_frameless(dlg, "Add Member") + lay.setSpacing(12) + lbl = QLabel("Email to invite") + lbl.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(lbl) + email_input = QLineEdit() + email_input.setPlaceholderText("user@example.com") + email_input.setMinimumHeight(36) + email_input.returnPressed.connect(dlg.accept) + lay.addWidget(email_input) + btn_row = QHBoxLayout() + btn_row.addStretch() + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("secondaryBtn") + cancel_btn.clicked.connect(dlg.reject) + btn_row.addWidget(cancel_btn) + add_btn = QPushButton("Add") + add_btn.clicked.connect(dlg.accept) + btn_row.addWidget(add_btn) + lay.addLayout(btn_row) + if dlg.exec() != QDialog.DialogCode.Accepted: + return + email = email_input.text().strip() + if not email: + return + self.bridge.add_member(self.current_conv_id, email) + + def _on_add_member_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Add Member", "Invitation sent.") + else: + QMessageBox.warning(self, "Add Member", msg) + + def _on_group_info(self): + if not self.current_conv_id: + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + is_creator = conv.get("created_by") == my_user_id + group_name = conv.get("name") or "Group" + members = conv["members"] + + dlg = QDialog(self) + dlg.setMinimumWidth(380) + t = c() + dlg_layout = _make_frameless(dlg, "Group Info") + + # Group avatar + avatar_row = QHBoxLayout() + avatar_label = QLabel() + avatar_label.setFixedSize(64, 64) + avatar_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + conv_id = conv["conversation_id"] + if conv_id in self._group_avatar_cache: + avatar_pix = self._make_circular_avatar(self._group_avatar_cache[conv_id], size=64) + else: + avatar_pix = self._make_default_avatar(group_name, size=64) + avatar_label.setPixmap(avatar_pix) + avatar_row.addWidget(avatar_label) + + group_name_esc = group_name.replace("&", "&").replace("<", "<").replace(">", ">") + title = QLabel(f"{group_name_esc}") + avatar_row.addWidget(title, stretch=1) + + if is_creator: + change_avatar_btn = QPushButton("Change Avatar") + change_avatar_btn.setObjectName("secondaryBtn") + change_avatar_btn.clicked.connect(lambda: self._do_change_group_avatar(conv_id, dlg)) + avatar_row.addWidget(change_avatar_btn) + + rename_btn = QPushButton("Rename") + rename_btn.setObjectName("secondaryBtn") + rename_btn.clicked.connect(lambda: self._do_rename_group(conv_id, group_name, dlg)) + avatar_row.addWidget(rename_btn) + + dlg_layout.addLayout(avatar_row) + + count_label = QLabel(f"Members ({len(members)}):") + count_label.setStyleSheet("margin-top: 8px;") + dlg_layout.addWidget(count_label) + + for mem in members: + uname = mem.get("username") or mem.get("email") or "?" + email = mem.get("email", "") + uid = mem.get("user_id") or mem.get("id") or "" + is_mem_creator = uid == conv.get("created_by") + + row = QHBoxLayout() + uname_esc = uname.replace("&", "&").replace("<", "<").replace(">", ">") + email_esc = email.replace("&", "&").replace("<", "<").replace(">", ">") + is_online = uid in self._online_users + online_dot = "\U0001f7e2 " if is_online else "" + verified_dot = "" + my_uid = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + if uid and uid != my_uid and self.bridge.client.get_verification_status(uid) == "verified": + verified_dot = " \u2705" + name_text = f"{online_dot}{uname_esc}{verified_dot}" + if email: + name_text += f" {email_esc}" + if is_mem_creator: + name_text += f" creator" + name_label = QLabel(name_text) + name_label.setWordWrap(True) + row.addWidget(name_label, stretch=1) + + info_btn = QPushButton("") + info_btn.setFixedSize(28, 28) + info_btn.setObjectName("secondaryBtn") + info_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_MessageBoxInformation)) + info_btn.setToolTip(f"View profile of {uname}") + info_btn.clicked.connect(lambda checked, u=uid, d=dlg: (d.accept(), self._show_user_profile(u))) + row.addWidget(info_btn) + + # Remove button (only for creator, not on self) + if is_creator and uid != my_user_id: + remove_btn = QPushButton("") + remove_btn.setFixedSize(28, 28) + remove_btn.setObjectName("secondaryBtn") + remove_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogCloseButton)) + remove_btn.setToolTip(f"Remove {uname}") + remove_btn.clicked.connect(lambda checked, u=uid, n=uname, d=dlg: self._do_remove_member_action(u, n, d)) + row.addWidget(remove_btn) + + dlg_layout.addLayout(row) + + dlg_layout.addSpacing(12) + + # Leave Group button + leave_btn = QPushButton("Leave Group") + leave_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.error}; color: {t.accent_text}; font-weight: bold; }}" + f"QPushButton:hover {{ opacity: 0.8; }}" + ) + leave_btn.clicked.connect(lambda: self._do_leave_group_action(dlg)) + dlg_layout.addWidget(leave_btn) + + close_btn = QPushButton("Close") + close_btn.clicked.connect(dlg.accept) + dlg_layout.addWidget(close_btn) + dlg.exec() + + def _do_remove_member_action(self, user_id, username, dialog): + if not self.current_conv_id: + return + confirm = QMessageBox.question( + self, "Remove Member", + f"Remove {username} from the group?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + dialog.accept() + self.bridge.remove_member(self.current_conv_id, user_id) + + def _do_change_group_avatar(self, conv_id, dialog): + path, _ = QFileDialog.getOpenFileName( + dialog, "Select Group Avatar", "", + "Images (*.png *.jpg *.jpeg);;All Files (*)", + ) + if not path: + return + try: + with open(path, "rb") as f: + image_data = f.read() + if len(image_data) > 2 * 1024 * 1024: + QMessageBox.warning(dialog, "Too Large", "Avatar must be under 2 MB.") + return + dialog.accept() + self.bridge.update_group_avatar(conv_id, image_data) + except Exception as e: + QMessageBox.warning(dialog, "Error", f"Failed to read image: {e}") + + def _do_rename_group(self, conv_id, current_name, dialog): + from PyQt6.QtWidgets import QInputDialog + new_name, ok = QInputDialog.getText( + dialog, "Rename Group", "New group name:", + text=current_name, + ) + if ok and new_name.strip(): + new_name = new_name.strip() + if new_name != current_name: + dialog.accept() + self.bridge.rename_conversation(conv_id, new_name) + + def _on_group_renamed(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Rename Group", msg) + + def _do_leave_group_action(self, dialog): + if not self.current_conv_id: + return + confirm = QMessageBox.question( + self, "Leave Group", + "Leave this group? You will no longer receive messages.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + dialog.accept() + self.bridge.leave_group(self.current_conv_id) + + def _on_group_left(self, ok, msg): + if ok: + self.current_conv_id = None + self.msg_input.drop_enabled = False + self.chat_header.setText("Select a conversation") + self.chat_header_avatar.setVisible(False) + self._clear_message_area() + self.group_info_btn.setVisible(False) + self.user_info_btn.setVisible(False) + self.add_member_btn.setVisible(False) + self.delete_conv_btn.setVisible(False) + else: + QMessageBox.warning(self, "Leave Group", msg) + + def _on_delete_conv_btn(self): + if not self.current_conv_id: + return + confirm = QMessageBox.question( + self, "Delete Conversation", + "Delete this conversation? This cannot be undone.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + self.bridge.delete_conversation(self.current_conv_id) + + def _on_conversation_deleted(self, ok, msg): + if ok: + self.current_conv_id = None + self.msg_input.drop_enabled = False + self.chat_header.setText("Select a conversation") + self.chat_header_avatar.setVisible(False) + self._clear_message_area() + self.group_info_btn.setVisible(False) + self.user_info_btn.setVisible(False) + self.add_member_btn.setVisible(False) + self.delete_conv_btn.setVisible(False) + else: + QMessageBox.warning(self, "Delete Conversation", msg) + + def _on_remove_member_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Remove Member", "Member removed.") + if self.current_conv_id: + self.bridge.load_messages(self.current_conv_id) + else: + QMessageBox.warning(self, "Remove Member", msg) + + def _on_my_profile(self): + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + if not my_user_id: + return + dlg = UserProfileDialog(self.bridge, my_user_id, editable=True, parent=self) + dlg.exec() + + def _on_dm_user_info(self): + """Show profile of the other user in a DM conversation.""" + if not self.current_conv_id: + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + my_email = self.bridge.client.email + for m in conv["members"]: + if m.get("email") != my_email: + uid = m.get("user_id") or m.get("id") + if uid: + self._show_user_profile(uid) + return + + def _show_user_profile(self, user_id): + dlg = UserProfileDialog(self.bridge, user_id, editable=False, parent=self) + dlg.exec() + + def _on_open_settings(self): + t = c() + dlg = QDialog(self) + dlg.setMinimumWidth(360) + lay = _make_frameless(dlg, "Settings") + lay.setSpacing(16) + + # -- Appearance section -- + sec_appearance = QLabel("Appearance") + sec_appearance.setStyleSheet( + f"font-size: 10pt; font-weight: bold; color: {t.text_secondary}; " + f"margin-top: 4px;" + ) + lay.addWidget(sec_appearance) + + theme_row = QHBoxLayout() + theme_label = QLabel("Theme") + theme_label.setStyleSheet(f"font-size: 11pt; color: {t.text_primary};") + theme_row.addWidget(theme_label) + theme_row.addStretch() + theme_btn = QPushButton( + "\u2600 Light mode" if tm().is_dark else "\U0001f319 Dark mode" + ) + theme_btn.setObjectName("secondaryBtn") + theme_btn.setFixedWidth(140) + theme_row.addWidget(theme_btn) + lay.addLayout(theme_row) + + # Separator + sep1 = QFrame() + sep1.setFrameShape(QFrame.Shape.HLine) + sep1.setStyleSheet(f"background-color: {t.separator}; max-height: 1px;") + lay.addWidget(sep1) + + # -- Security section -- + sec_security = QLabel("Security") + sec_security.setStyleSheet( + f"font-size: 10pt; font-weight: bold; color: {t.text_secondary};" + ) + lay.addWidget(sec_security) + + # Rotate Keys + rotate_row = QHBoxLayout() + rotate_info = QVBoxLayout() + rotate_title = QLabel("Rotate Keys") + rotate_title.setStyleSheet(f"font-size: 11pt; color: {t.text_primary};") + rotate_info.addWidget(rotate_title) + rotate_desc = QLabel("Generate new RSA keys. Revokes other devices.") + rotate_desc.setStyleSheet(f"font-size: 8pt; color: {t.text_muted};") + rotate_desc.setWordWrap(True) + rotate_info.addWidget(rotate_desc) + rotate_row.addLayout(rotate_info, stretch=1) + rotate_btn = QPushButton("Rotate") + rotate_btn.setObjectName("secondaryBtn") + rotate_btn.setFixedWidth(100) + rotate_btn.clicked.connect(lambda: (dlg.close(), self._on_rotate_keys())) + rotate_row.addWidget(rotate_btn) + lay.addLayout(rotate_row) + + # Change Username + chun_row = QHBoxLayout() + chun_info = QVBoxLayout() + chun_title = QLabel("Change Username") + chun_title.setStyleSheet(f"font-size: 11pt; color: {t.text_primary};") + chun_info.addWidget(chun_title) + chun_desc = QLabel("Change your display name.") + chun_desc.setStyleSheet(f"font-size: 8pt; color: {t.text_muted};") + chun_desc.setWordWrap(True) + chun_info.addWidget(chun_desc) + chun_row.addLayout(chun_info, stretch=1) + chun_btn = QPushButton("Change") + chun_btn.setObjectName("secondaryBtn") + chun_btn.setFixedWidth(100) + chun_btn.clicked.connect(lambda: (dlg.close(), self._on_change_username())) + chun_row.addWidget(chun_btn) + lay.addLayout(chun_row) + + # Change Password + chpw_row = QHBoxLayout() + chpw_info = QVBoxLayout() + chpw_title = QLabel("Change Password") + chpw_title.setStyleSheet(f"font-size: 11pt; color: {t.text_primary};") + chpw_info.addWidget(chpw_title) + chpw_desc = QLabel("Change password for local key encryption.") + chpw_desc.setStyleSheet(f"font-size: 8pt; color: {t.text_muted};") + chpw_desc.setWordWrap(True) + chpw_info.addWidget(chpw_desc) + chpw_row.addLayout(chpw_info, stretch=1) + chpw_btn = QPushButton("Change") + chpw_btn.setObjectName("secondaryBtn") + chpw_btn.setFixedWidth(100) + chpw_btn.clicked.connect(lambda: (dlg.close(), self._on_change_password())) + chpw_row.addWidget(chpw_btn) + lay.addLayout(chpw_row) + + # Separator + sep2 = QFrame() + sep2.setFrameShape(QFrame.Shape.HLine) + sep2.setStyleSheet(f"background-color: {t.separator}; max-height: 1px;") + lay.addWidget(sep2) + + # -- Devices section -- + sec_devices = QLabel("Devices") + sec_devices.setStyleSheet( + f"font-size: 10pt; font-weight: bold; color: {t.text_secondary};" + ) + lay.addWidget(sec_devices) + + # Link Device (new) + link_row = QHBoxLayout() + link_info = QVBoxLayout() + link_title = QLabel("Link New Device") + link_title.setStyleSheet(f"font-size: 11pt; color: {t.text_primary};") + link_info.addWidget(link_title) + link_desc = QLabel("Authorize another device to access your account.") + link_desc.setStyleSheet(f"font-size: 8pt; color: {t.text_muted};") + link_desc.setWordWrap(True) + link_info.addWidget(link_desc) + link_row.addLayout(link_info, stretch=1) + link_btn = QPushButton("Link") + link_btn.setObjectName("secondaryBtn") + link_btn.setFixedWidth(100) + link_btn.clicked.connect(lambda: (dlg.close(), self._on_authorize_device())) + link_row.addWidget(link_btn) + lay.addLayout(link_row) + + # Wire up theme toggle now that all widgets exist + _s_labels = [sec_appearance, sec_security, sec_devices] + _t_labels = [theme_label, rotate_title, chpw_title, link_title] + _d_labels = [rotate_desc, chpw_desc, link_desc] + _seps = [sep1, sep2] + + def _toggle_and_update(): + tm().toggle() + theme_btn.setText( + "\u2600 Light mode" if tm().is_dark else "\U0001f319 Dark mode" + ) + t2 = c() + dlg._frameless_container.setStyleSheet( + f"#_framelessContainer {{ background-color: {t2.bg_primary}; border-radius: 12px; }}" + ) + dlg._frameless_title_bar.setStyleSheet( + f"background-color: {t2.bg_secondary}; " + f"border-top-left-radius: 12px; border-top-right-radius: 12px;" + ) + dlg._frameless_title_label.setStyleSheet( + f"font-weight: bold; font-size: 10pt; color: {t2.text_primary}; background: transparent;" + ) + for lbl in _s_labels: + lbl.setStyleSheet(f"font-size: 10pt; font-weight: bold; color: {t2.text_secondary};") + for lbl in _t_labels: + lbl.setStyleSheet(f"font-size: 11pt; color: {t2.text_primary};") + for lbl in _d_labels: + lbl.setStyleSheet(f"font-size: 8pt; color: {t2.text_muted};") + for s in _seps: + s.setStyleSheet(f"background-color: {t2.separator}; max-height: 1px;") + theme_btn.clicked.connect(_toggle_and_update) + + lay.addStretch() + + # Close button + close_btn = QPushButton("Close") + close_btn.clicked.connect(dlg.close) + lay.addWidget(close_btn) + + dlg.exec() + + def _on_authorize_device(self): + use_qr = QMessageBox.question( + self, + "Authorize Device", + "Load a pairing QR image?\n\n" + "Choose Yes to select an image file with the QR from the new device.\n" + "Choose No to enter the pairing code and fingerprint manually.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No | QMessageBox.StandardButton.Cancel, + QMessageBox.StandardButton.Yes, + ) + code = "" + fingerprint = "" + if use_qr == QMessageBox.StandardButton.Cancel: + return + if use_qr == QMessageBox.StandardButton.Yes: + path, _ = QFileDialog.getOpenFileName( + self, + "Select Pairing QR Image", + "", + "Images (*.png *.jpg *.jpeg *.bmp);;All Files (*)", + ) + if not path: + return + qr_data = _decode_qr_image_file(path) + if qr_data is None: + QMessageBox.warning( + self, + "Authorize Device", + "Could not decode pairing QR code. Install 'pyzbar' or use manual entry.", + ) + return + try: + code, fingerprint = decode_pairing_qr(qr_data) + except ValueError as e: + QMessageBox.warning(self, "Authorize Device", f"Invalid pairing QR code: {e}") + return + else: + code, ok = QInputDialog.getText(self, "Authorize Device", "Pairing code:") + if not ok or not code.strip(): + return + fingerprint, ok = QInputDialog.getText( + self, + "Authorize Device", + "Fingerprint shown on the new device\n(spaces optional):", + ) + if not ok or not fingerprint.strip(): + return + self.bridge.authorize_device(code.strip(), fingerprint.strip()) + + def _on_rotate_keys(self): + confirm = QMessageBox.question( + self, + "Rotate Keys", + "This will revoke other devices. Continue?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm != QMessageBox.StandardButton.Yes: + return + password, ok = QInputDialog.getText(self, "Rotate Keys", "Password:", QLineEdit.EchoMode.Password) + if not ok or not password: + return + self.bridge.rotate_keys(self.bridge.client.username, password) + + def _on_change_password(self): + t = c() + dlg = QDialog(self) + dlg.setMinimumWidth(360) + lay = _make_frameless(dlg, "Change Password") + + old_label = QLabel("Current Password") + old_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(old_label) + old_input = QLineEdit() + old_input.setEchoMode(QLineEdit.EchoMode.Password) + old_input.setPlaceholderText("Enter current password") + old_input.setStyleSheet( + f"QLineEdit {{ font-size: 11pt; background-color: {t.bg_secondary}; " + f"border: 1px solid {t.border}; border-radius: 6px; padding: 8px; " + f"color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + lay.addWidget(old_input) + + new_label = QLabel("New Password") + new_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(new_label) + new_input = QLineEdit() + new_input.setEchoMode(QLineEdit.EchoMode.Password) + new_input.setPlaceholderText("Enter new password") + new_input.setStyleSheet(old_input.styleSheet()) + lay.addWidget(new_input) + + confirm_label = QLabel("Confirm New Password") + confirm_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(confirm_label) + confirm_input = QLineEdit() + confirm_input.setEchoMode(QLineEdit.EchoMode.Password) + confirm_input.setPlaceholderText("Re-enter new password") + confirm_input.setStyleSheet(old_input.styleSheet()) + lay.addWidget(confirm_input) + + error_label = QLabel("") + error_label.setStyleSheet(f"color: {t.error}; font-size: 9pt;") + error_label.hide() + lay.addWidget(error_label) + + btn_lay = QHBoxLayout() + btn_lay.addStretch() + change_btn = QPushButton("Change Password") + change_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.accent}; color: {t.accent_text}; " + f"border: none; border-radius: 6px; padding: 8px 16px; font-weight: bold; }}" + f"QPushButton:hover {{ opacity: 0.9; }}" + ) + btn_lay.addWidget(change_btn) + lay.addLayout(btn_lay) + + def _do_change(): + old_pw = old_input.text() + new_pw = new_input.text() + conf_pw = confirm_input.text() + if not old_pw: + error_label.setText("Current password is required.") + error_label.show() + return + if not new_pw: + error_label.setText("New password cannot be empty.") + error_label.show() + return + if new_pw != conf_pw: + error_label.setText("New passwords do not match.") + error_label.show() + return + dlg.accept() + self.bridge.change_password(old_pw, new_pw) + + change_btn.clicked.connect(_do_change) + confirm_input.returnPressed.connect(_do_change) + dlg.exec() + + def _on_logout(self): + confirm = QMessageBox.question( + self, + "Logout", + "Log out and return to the login screen?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm != QMessageBox.StandardButton.Yes: + return + self._is_logout = True + self.bridge.logout() + self.close() + if self._on_logout_cb: + self._on_logout_cb() + + def _on_notification(self, payload): + sender = payload.get("sender", "???") + conv_id = payload.get("conversation_id", "") + sender_id = payload.get("sender_id", "") + + if conv_id and sender_id: + self._expire_typing_user(conv_id, sender_id) + timer = self._typing_timers.get((conv_id, sender_id)) + if timer: + timer.stop() + + # Update last-message cache for conversation list preview + if conv_id: + preview = payload.get("text", "") + if payload.get("image") and not preview: + preview = "Sent an image" + elif payload.get("file") and not preview: + preview = "Sent a file" + if preview: + self._last_message_cache[conv_id] = ( + f"{sender}: {preview}"[:60], + payload.get("created_at", ""), + "", # incoming message — no receipt status for others' msgs + ) + + # Resolve conversation name for notifications + conv_name = sender + is_notif_dm = False + if conv_id: + for cv in self.conversations: + if cv["conversation_id"] == conv_id: + is_notif_dm = len(cv["members"]) == 2 and not cv.get("name") + if not is_notif_dm: + conv_name = cv.get("name") or sender + break + + # System tray toast when window is not visible or not focused + if conv_id: + if is_notif_dm: + notif_title = sender + notif_text = payload.get("text", "New message") + else: + notif_title = conv_name + notif_text = f"{sender}: {payload.get('text', 'New message')}" + if payload.get("image"): + notif_text = notif_text or "Sent an image" + elif payload.get("file"): + notif_text = notif_text or "Sent a file" + self._show_tray_notification(notif_title, notif_text) + + # Show notification in status bar (for non-current conversations) + if conv_id and conv_id != self.current_conv_id: + if is_notif_dm: + self.status_bar.setText(f"New message from {sender}") + else: + self.status_bar.setText(f"New message from {sender} in {conv_name}") + t = c() + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.success}; font-size: 8pt; font-weight: bold;" + ) + self._status_bar_conv_id = conv_id + QTimer.singleShot(5000, self._clear_status_bar) + + # Confirm delivery for all incoming messages (always, regardless of current view) + msg_id = payload.get("message_id", "") + if conv_id and msg_id: + self.bridge.schedule( + self.bridge.client.confirm_delivery(conv_id, [msg_id]) + ) + + # Increment unread count if not currently viewing this conversation + # (or if privacy overlay is locked — user can't see messages) + viewing = conv_id == self.current_conv_id and not self._privacy_locked + if conv_id and not viewing: + self._unread_counts[conv_id] = self._unread_counts.get(conv_id, 0) + 1 + self._update_conv_list_styles() + + # Append directly to current conversation instead of re-fetching + if conv_id == self.current_conv_id: + # Avoid duplicate if local send already appended this message + if msg_id: + for m in self.current_messages: + if m.get("message_id") == msg_id: + return + self.current_messages.append(payload) + idx = len(self.current_messages) - 1 + w = self._create_message_widget(payload, idx) + self._msg_layout.addWidget(w) + self._msg_widgets.append(w) + if self._is_near_bottom: + QTimer.singleShot(10, self._scroll_to_bottom) + else: + self.jump_btn.setText("\u2193 New") + self.jump_btn.setVisible(True) + self._position_jump_btn() + # Mark as read (only if not locked) + if msg_id and not self._privacy_locked: + self.bridge.schedule( + self.bridge.client.mark_read(conv_id, [msg_id]) + ) + + def _display_name_for_user(self, conv_id: str, user_id: str, fallback: str = "") -> str: + for conv in self.conversations: + if conv.get("conversation_id") != conv_id: + continue + for member in conv.get("members", []): + member_id = member.get("user_id") or member.get("id") + if member_id == user_id: + return member.get("username") or member.get("email") or fallback or "Someone" + return fallback or "Someone" + + def _typing_label_refresh(self): + if not self.current_conv_id: + self.typing_label.setVisible(False) + self.typing_label.setText("") + return + users = list(self._typing_users.get(self.current_conv_id, {}).values()) + if not users: + self.typing_label.setVisible(False) + self.typing_label.setText("") + return + if len(users) == 1: + text = f"{users[0]} is typing..." + elif len(users) == 2: + text = f"{users[0]} and {users[1]} are typing..." + else: + text = f"{users[0]} and {len(users) - 1} others are typing..." + self.typing_label.setText(text) + self.typing_label.setVisible(True) + + def _on_typing_start(self, data): + conv_id = data.get("conversation_id", "") + user_id = data.get("user_id", "") + if not conv_id or not user_id: + return + if self.bridge.client.session and user_id == self.bridge.client.session.get("user_id"): + return + + username = data.get("username") or self._display_name_for_user(conv_id, user_id, "Someone") + conv_users = self._typing_users.setdefault(conv_id, {}) + conv_users[user_id] = username + + key = (conv_id, user_id) + timer = self._typing_timers.get(key) + if timer is None: + timer = QTimer(self) + timer.setSingleShot(True) + timer.timeout.connect(lambda c=conv_id, u=user_id: self._expire_typing_user(c, u)) + self._typing_timers[key] = timer + timer.start(3000) + + if conv_id == self.current_conv_id: + self._typing_label_refresh() + + def _expire_typing_user(self, conv_id: str, user_id: str): + users = self._typing_users.get(conv_id) + if users and user_id in users: + users.pop(user_id, None) + if not users: + self._typing_users.pop(conv_id, None) + if conv_id == self.current_conv_id: + self._typing_label_refresh() + + def _on_typing_stop(self, data): + conv_id = data.get("conversation_id", "") + user_id = data.get("user_id", "") + if not conv_id or not user_id: + return + self._expire_typing_user(conv_id, user_id) + key = (conv_id, user_id) + timer = self._typing_timers.get(key) + if timer: + timer.stop() + + def _on_messages_read(self, data): + conv_id = data.get("conversation_id", "") + user_id = data.get("user_id", "") + message_ids = set(data.get("message_ids", [])) + my_uid = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + + # Persist to cache for ALL conversations (not just current) + if conv_id == self.current_conv_id: + for msg in self.current_messages: + if message_ids and msg.get("message_id") not in message_ids: + continue + if not message_ids and msg.get("sender_id") != my_uid: + continue + read_by = msg.get("read_by", []) + if not any(r.get("user_id") == user_id for r in read_by): + read_by.append({"user_id": user_id}) + msg["read_by"] = read_by + self.bridge.client.update_message_in_cache( + conv_id, msg.get("message_id"), {"read_by": read_by}) + self._render_messages(scroll_to_bottom=self._is_near_bottom) + else: + # Non-current conversation: update cache only (no in-memory messages to update) + cached = self.bridge.client.load_message_cache(conv_id) + if cached: + for msg_id_key, entry in cached.items(): + if message_ids and msg_id_key not in message_ids: + continue + if not message_ids and entry.get("sender_id") != my_uid: + continue + read_by = entry.get("read_by", []) + if not any(r.get("user_id") == user_id for r in read_by): + read_by.append({"user_id": user_id}) + self.bridge.client.update_message_in_cache( + conv_id, msg_id_key, {"read_by": read_by}) + # Update conv list receipt status to "read" + if conv_id in self._last_message_cache: + prev_text, prev_ts, prev_receipt = self._last_message_cache[conv_id] + if prev_receipt in ("sent", "delivered"): + self._last_message_cache[conv_id] = (prev_text, prev_ts, "read") + self._update_conv_list_styles() + + def _on_message_delivered(self, data): + conv_id = data.get("conversation_id", "") + user_id = data.get("user_id", "") + message_ids = set(data.get("message_ids", [])) + + if conv_id == self.current_conv_id: + for msg in self.current_messages: + if msg.get("message_id") in message_ids: + delivered_to = msg.get("delivered_to", []) + if not any(d.get("user_id") == user_id for d in delivered_to): + delivered_to.append({"user_id": user_id}) + msg["delivered_to"] = delivered_to + self.bridge.client.update_message_in_cache( + conv_id, msg.get("message_id"), {"delivered_to": delivered_to}) + self._render_messages(scroll_to_bottom=self._is_near_bottom) + else: + # Non-current conversation: update cache only + cached = self.bridge.client.load_message_cache(conv_id) + if cached: + for msg_id_key, entry in cached.items(): + if msg_id_key in message_ids: + delivered_to = entry.get("delivered_to", []) + if not any(d.get("user_id") == user_id for d in delivered_to): + delivered_to.append({"user_id": user_id}) + self.bridge.client.update_message_in_cache( + conv_id, msg_id_key, {"delivered_to": delivered_to}) + # Update conv list receipt status to "delivered" + if conv_id in self._last_message_cache: + prev_text, prev_ts, prev_receipt = self._last_message_cache[conv_id] + if prev_receipt == "sent": + self._last_message_cache[conv_id] = (prev_text, prev_ts, "delivered") + self._update_conv_list_styles() + + def _on_link_clicked(self, url_str): + """Handle link clicks from message bubble labels.""" + if url_str.startswith("image://"): + file_id = url_str[len("image://"):] + for msg in self.current_messages: + image_info = msg.get("image") + if image_info and image_info.get("file_id") == file_id: + self._view_image(msg) + return + elif url_str.startswith("file://"): + file_id = url_str[len("file://"):] + for msg in self.current_messages: + file_info = msg.get("file") + if file_info and file_info.get("file_id") == file_id: + self.bridge.download_file(file_id, file_info) + return + elif url_str.startswith("https://"): + QDesktopServices.openUrl(QUrl(url_str)) + elif url_str.startswith("http://"): + reply = QMessageBox.warning( + self, + "Insecure link", + f"This link uses unencrypted HTTP.\n\n{url_str}\n\nContinue anyway?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if reply == QMessageBox.StandardButton.Yes: + QDesktopServices.openUrl(QUrl(url_str)) + + # -- Drag & drop -------------------------------------------------------- + + _IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"} + def _msg_area_normal_style(self): + return f"QScrollArea {{ background-color: {c().bg_primary}; border: none; }}" + + def eventFilter(self, obj, event): + """Handle drag-and-drop on message scroll area + context menu on messages.""" + from PyQt6.QtCore import QEvent + # Context menu from any child of message container + if event.type() == QEvent.Type.ContextMenu and obj is not self._msg_scroll_area: + widget = obj + idx = self._find_msg_index_at_widget(widget) + if idx is not None: + self._show_msg_context_menu(idx, event.globalPos()) + return True + if obj is self._msg_scroll_area: + if event.type() == QEvent.Type.DragEnter: + if not self.current_conv_id: + event.ignore() + return True + if event.mimeData().hasUrls() and any(u.isLocalFile() for u in event.mimeData().urls()): + event.acceptProposedAction() + self._msg_scroll_area.setStyleSheet( + f"QScrollArea {{ border: 2px dashed {c().accent}; }}" + ) + return True + elif event.type() == QEvent.Type.DragMove: + if event.mimeData().hasUrls(): + event.acceptProposedAction() + return True + elif event.type() == QEvent.Type.DragLeave: + self._msg_scroll_area.setStyleSheet(self._msg_area_normal_style()) + return True + elif event.type() == QEvent.Type.Drop: + self._msg_scroll_area.setStyleSheet(self._msg_area_normal_style()) + if event.mimeData().hasUrls(): + for url in event.mimeData().urls(): + if url.isLocalFile(): + self._on_file_dropped(url.toLocalFile()) + event.acceptProposedAction() + return True + return super().eventFilter(obj, event) + + def _on_file_dropped(self, path: str): + """Send a dropped file as image or file attachment.""" + if not self.current_conv_id: + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + import os + ext = os.path.splitext(path)[1].lower() + reply_to = self.reply_to_id + is_image = ext in self._IMAGE_EXTENSIONS + filename = os.path.basename(path) + self._show_optimistic_attachment(self.current_conv_id, is_image=is_image, + filename=filename, reply_to=reply_to, + image_path=path if is_image else None) + if is_image: + self.bridge.send_image( + self.current_conv_id, path, conv["members"], + reply_to=reply_to, + ) + else: + self.bridge.send_file( + self.current_conv_id, path, conv["members"], + reply_to=reply_to, + ) + self.reply_to_id = None + self._reply_widget.setVisible(False) + + # -- Attach menu ------------------------------------------------------- + + def _on_attach_image(self): + if not self.current_conv_id: + return + path, _ = QFileDialog.getOpenFileName( + self, "Select Image", "", + "Images (*.png *.jpg *.jpeg *.gif *.bmp *.webp);;All Files (*)", + ) + if not path: + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + reply_to = self.reply_to_id + self._show_optimistic_attachment(self.current_conv_id, is_image=True, + reply_to=reply_to, image_path=path) + self.bridge.send_image( + self.current_conv_id, path, conv["members"], + reply_to=reply_to, + ) + self.reply_to_id = None + self._reply_widget.setVisible(False) + + @staticmethod + def _human_file_size(size_bytes): + if size_bytes >= 1024 * 1024: + return f"{size_bytes / (1024 * 1024):.1f} MB" + elif size_bytes >= 1024: + return f"{size_bytes / 1024:.0f} KB" + return f"{size_bytes} B" + + @staticmethod + def _file_icon(filename: str) -> str: + """Return an emoji icon based on file extension.""" + ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else "" + _icons = { + "pdf": "\U0001f4d5", # red book + "doc": "\U0001f4d8", # blue book + "docx": "\U0001f4d8", + "odt": "\U0001f4d8", + "xls": "\U0001f4ca", # bar chart + "xlsx": "\U0001f4ca", + "ods": "\U0001f4ca", + "csv": "\U0001f4ca", + "ppt": "\U0001f4d9", # orange book + "pptx": "\U0001f4d9", + "odp": "\U0001f4d9", + "zip": "\U0001f4e6", # package + "rar": "\U0001f4e6", + "7z": "\U0001f4e6", + "tar": "\U0001f4e6", + "gz": "\U0001f4e6", + "mp3": "\U0001f3b5", # music note + "wav": "\U0001f3b5", + "flac": "\U0001f3b5", + "ogg": "\U0001f3b5", + "m4a": "\U0001f3b5", + "mp4": "\U0001f3ac", # clapper board + "mkv": "\U0001f3ac", + "avi": "\U0001f3ac", + "mov": "\U0001f3ac", + "webm": "\U0001f3ac", + "py": "\U0001f40d", # snake + "js": "\U0001f4dc", # scroll + "ts": "\U0001f4dc", + "html": "\U0001f310", # globe + "css": "\U0001f3a8", # palette + "json": "\U0001f4cb", # clipboard + "xml": "\U0001f4cb", + "yaml": "\U0001f4cb", + "yml": "\U0001f4cb", + "txt": "\U0001f4c4", # page facing up + "log": "\U0001f4c4", + "md": "\U0001f4c4", + } + return _icons.get(ext, "\U0001f4ce") # default: paperclip + + def _on_attach_file(self): + if not self.current_conv_id: + return + path, _ = QFileDialog.getOpenFileName( + self, "Select File", "", + "All Files (*)", + ) + if not path: + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + import os + reply_to = self.reply_to_id + filename = os.path.basename(path) + self._show_optimistic_attachment(self.current_conv_id, is_image=False, + filename=filename, reply_to=reply_to) + self.bridge.send_file( + self.current_conv_id, path, conv["members"], + reply_to=reply_to, + ) + self.reply_to_id = None + self._reply_widget.setVisible(False) + + def _on_file_sent(self, ok, msg): + if not ok: + self._remove_optimistic("file") + QMessageBox.warning(self, "File Error", msg) + + def _on_file_downloaded(self, data, file_info): + filename = _safe_filename(file_info.get("filename", "file"), "file") + path, _ = QFileDialog.getSaveFileName(self, "Save File", filename) + if path: + try: + with open(path, "wb") as f: + f.write(data) + QMessageBox.information(self, "Saved", f"File saved to {path}") + except Exception as e: + QMessageBox.warning(self, "Error", f"Failed to save: {e}") + + def _on_image_sent(self, ok, msg): + if not ok: + self._remove_optimistic("image") + QMessageBox.warning(self, "Image Error", msg) + + def _view_image(self, msg): + image_info = msg.get("image") + if not image_info: + return + file_id = image_info.get("file_id", "") + self._pending_image_download = {"file_id": file_id, "image_info": image_info} + self.bridge.download_image(file_id, image_info) + + def _on_image_downloaded(self, file_id, data): + if not self._pending_image_download or self._pending_image_download["file_id"] != file_id: + return + image_info = self._pending_image_download["image_info"] + self._pending_image_download = None + self._show_image_dialog(data, image_info) + + def _on_image_download_failed(self, file_id): + if self._pending_image_download and self._pending_image_download["file_id"] == file_id: + self._pending_image_download = None + self.statusBar().showMessage("Image download failed.", 5000) + + def _show_image_dialog(self, image_data, image_info): + dlg = QDialog(self) + dlg.setMinimumSize(400, 300) + img_title = _safe_filename(image_info.get("filename", "Image"), "Image") + layout = _make_frameless(dlg, img_title) + + qimg = _safe_load_image(image_data) + if qimg is None: + layout.addWidget(QLabel("Failed to load image.")) + else: + pixmap = QPixmap.fromImage(qimg) + label = QLabel() + # Scale down if larger than screen + screen_size = self.screen().availableSize() + max_w = int(screen_size.width() * 0.8) + max_h = int(screen_size.height() * 0.8) + if pixmap.width() > max_w or pixmap.height() > max_h: + pixmap = pixmap.scaled(max_w, max_h, Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation) + label.setPixmap(pixmap) + label.setAlignment(Qt.AlignmentFlag.AlignCenter) + + scroll = QScrollArea() + scroll.setWidget(label) + scroll.setWidgetResizable(True) + layout.addWidget(scroll) + + btn_row = QHBoxLayout() + save_btn = QPushButton("Save") + save_btn.clicked.connect(lambda: self._save_image(image_data, image_info, dlg)) + btn_row.addWidget(save_btn) + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(dlg.accept) + btn_row.addWidget(close_btn) + layout.addLayout(btn_row) + + if qimg is not None and not qimg.isNull(): + dlg.resize(min(pixmap.width() + 40, max_w), + min(pixmap.height() + 80, max_h)) + dlg.exec() + + def _save_image(self, image_data, image_info, dialog): + filename = _safe_filename(image_info.get("filename", "image.jpg"), "image.jpg") + path, _ = QFileDialog.getSaveFileName(dialog, "Save Image", filename) + if path: + try: + with open(path, "wb") as f: + f.write(image_data) + QMessageBox.information(dialog, "Saved", f"Image saved to {path}") + except Exception as e: + QMessageBox.warning(dialog, "Error", f"Failed to save: {e}") + + def _on_message_deleted(self, data): + message_id = data.get("message_id", "") + conv_id = data.get("conversation_id", "") + if conv_id == self.current_conv_id: + for msg in self.current_messages: + if msg.get("message_id") == message_id: + msg["deleted"] = True + msg["text"] = "" + msg["image"] = None + break + self._render_messages() + + def _on_delete_message_result(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Delete Error", msg) + return + # No need to reload — _on_message_deleted() already updates in-place via notification + + def _on_authorize_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Authorize Device", msg) + else: + QMessageBox.warning(self, "Authorize Device", msg) + + def _on_rotate_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Rotate Keys", msg) + else: + QMessageBox.warning(self, "Rotate Keys", msg) + + def _on_password_changed(self, ok, msg): + if ok: + QMessageBox.information(self, "Change Password", msg) + else: + QMessageBox.warning(self, "Change Password", msg) + + def _on_change_username(self): + t = c() + current = "" + if self.bridge and self.bridge.client: + current = getattr(self.bridge.client, "username", "") or "" + dlg = QDialog(self) + dlg.setMinimumWidth(360) + lay = _make_frameless(dlg, "Change Username") + + label = QLabel("New Username") + label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(label) + name_input = QLineEdit() + name_input.setText(current) + name_input.setPlaceholderText("Enter new username") + name_input.setMaxLength(100) + name_input.setStyleSheet( + f"QLineEdit {{ font-size: 11pt; background-color: {t.bg_secondary}; " + f"border: 1px solid {t.border}; border-radius: 6px; padding: 8px; " + f"color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + lay.addWidget(name_input) + + btn_row = QHBoxLayout() + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("secondaryBtn") + cancel_btn.clicked.connect(dlg.reject) + btn_row.addWidget(cancel_btn) + save_btn = QPushButton("Save") + save_btn.setObjectName("primaryBtn") + save_btn.clicked.connect(dlg.accept) + btn_row.addWidget(save_btn) + lay.addLayout(btn_row) + + name_input.setFocus() + if dlg.exec() == QDialog.DialogCode.Accepted: + new_name = name_input.text().strip() + if new_name and new_name != current: + self.bridge.change_username(new_name) + + def _on_username_changed(self, ok, msg): + if ok: + QMessageBox.information(self, "Change Username", msg) + self.bridge.schedule(self.bridge._do_load_conversations()) + else: + QMessageBox.warning(self, "Change Username", msg) + + def _on_reencrypt_status(self, msg): + self.reencrypt_label.setText(msg) + self.reencrypt_label.setVisible(True) + if msg.lower().startswith("re-encryption complete"): + QTimer.singleShot(4000, lambda: self.reencrypt_label.setVisible(False)) + + def closeEvent(self, event): + if self._tray_icon: + self._tray_icon.hide() + if not self._is_logout: + self.bridge.stop() + self.bridge.wait(2000) + event.accept() + + +def main(): + setup_logging() + + # Suppress Qt screen enumeration warnings on Windows (monitor sleep/wake) + os.environ.setdefault("QT_LOGGING_RULES", "qt.qpa.screen=false") + + # Windows 10+ requires AppUserModelID for system tray notifications to work + if sys.platform == "win32": + try: + import ctypes + ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("com.encrypted-chat.client") + except Exception: + pass + + app = QApplication(sys.argv) + app.setStyleSheet(qss()) + + bridge = AsyncBridge() + + login_win = LoginWindow(bridge) + main_win = [None] # mutable ref + + def on_connected(): + login_win.reset() + login_win.status_label.setText(f"Connected to {bridge.client.server_endpoint()}") + login_win.show() + + def on_conn_error(msg): + QMessageBox.critical(None, "Connection Error", f"Cannot connect to server:\n{msg}") + sys.exit(1) + + def on_register_result(ok, msg): + if ok: + # Show verification code page inline + hint = "" + if msg and len(msg) <= 6 and msg.isdigit(): + hint = f"Code: {msg}" + elif msg: + hint = msg + login_win.show_verification_page(hint) + + def do_confirm(code): + async def _confirm(): + okc, msgc = await bridge.client.confirm_registration( + login_win._reg_email_input.text().strip(), + login_win.username_input.text().strip(), + code.strip(), + ) + if okc: + login_win.show_success(msgc) + bridge.do_login( + login_win._reg_email_input.text().strip(), + login_win._reg_password_input.text(), + ) + else: + login_win.show_error(msgc) + bridge.schedule(_confirm()) + + login_win._confirm_callback = do_confirm + else: + login_win.show_error(msg) + + def on_pairing_code(code, fingerprint, qr_data): + login_win._set_enabled(True) + login_win.show_success( + "Pairing code: " + f"{code}\n" + f"Pairing fingerprint:\n{fingerprint}\n" + f"Server: {bridge.client.server_endpoint()}\n" + "Approve it on an already logged-in device.\n" + "Never share this pairing code." + ) + login_win.show_pairing_qr(code, fingerprint, qr_data) + + def on_pairing_complete(ok, msg): + if ok: + login_win.show_success(msg) + bridge.do_login(login_win._pair_email, login_win._pair_password) + else: + login_win.show_error(msg) + + def on_login_result(ok, msg): + if ok: + login_win.show_success(msg) + login_win.hide() + tm().set_email(bridge.client.email) + app.setStyleSheet(qss()) + main_win[0] = MainWindow(bridge, on_logout=lambda: (login_win.reset(), login_win.show())) + main_win[0].show() + else: + login_win.show_error(msg) + + bridge.connected.connect(on_connected) + bridge.connection_error.connect(on_conn_error) + bridge.register_result.connect(on_register_result) + bridge.login_result.connect(on_login_result) + bridge.pairing_code.connect(on_pairing_code) + bridge.pairing_complete.connect(on_pairing_complete) + bridge.reconnected.connect(lambda: (login_win.reset(), login_win.show())) + + bridge.start() + + sys.exit(app.exec()) + + +if __name__ == "__main__": + main() diff --git a/protocol.py b/protocol.py new file mode 100644 index 0000000..5c9f15a --- /dev/null +++ b/protocol.py @@ -0,0 +1,146 @@ +"""Newline-delimited JSON protocol with base64 encoding for binary data.""" + +import asyncio +import base64 +import binascii +import json +import os + + +def encode_binary(data: bytes) -> str: + """Encode bytes to base64 string.""" + return base64.b64encode(data).decode("ascii") + + +def decode_binary(data: str) -> bytes: + """Decode base64 string to bytes.""" + try: + return base64.b64decode(data, validate=True) + except (TypeError, binascii.Error) as e: + raise ValueError(f"Invalid base64: {e}") + + +VERSION = "0.8.6" +MIN_CLIENT_VERSION = "0.8.6" # server rejects clients older than this + + +def version_gte(version: str, minimum: str) -> bool: + """Return True if version >= minimum (compares numeric tuples, e.g. '0.8.1' >= '0.8'). + + Returns False for malformed version strings (instead of silently treating them as 0). + """ + def _parse(v: str) -> tuple[int, ...] | None: + if not isinstance(v, str) or not v: + return None + parts = v.split(".") + try: + return tuple(int(x) for x in parts) + except (ValueError, AttributeError): + return None + parsed_ver = _parse(version) + parsed_min = _parse(minimum) + if parsed_ver is None or parsed_min is None: + return False + return parsed_ver >= parsed_min + + +MAX_MESSAGE_BYTES = int(os.getenv("MAX_MESSAGE_BYTES", str(1024 * 1024))) # 1 MiB default (was 64K, raised for 256K media chunks) +MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", str(5 * 1024 * 1024))) # 5 MiB default, 0 = no limit +MAX_FILE_BYTES = int(os.getenv("MAX_FILE_BYTES", str(50 * 1024 * 1024))) # 50 MiB default +IMAGE_CHUNK_SIZE = 262144 # 256 KiB raw chunk size for image upload/download + + +def build_request(msg_type: str, request_id: str | None = None, **kwargs) -> bytes: + """Build a protocol message (newline-terminated JSON).""" + msg = {"type": msg_type, **kwargs} + if request_id: + msg["request_id"] = request_id + return json.dumps(msg, ensure_ascii=False).encode("utf-8") + b"\n" + + +def build_response( + msg_type: str, + status: str, + data: dict | None = None, + request_id: str | None = None, +) -> bytes: + """Build a server response.""" + msg = {"type": msg_type, "status": status} + if data is not None: + msg["data"] = data + if request_id: + msg["request_id"] = request_id + return json.dumps(msg, ensure_ascii=False).encode("utf-8") + b"\n" + + +def parse_message(line: bytes) -> dict: + """Parse a single protocol message from bytes.""" + try: + return json.loads(line.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise ValueError(f"Invalid message: {e}") + + +class ProtocolReader: + """Read newline-delimited JSON messages from an asyncio StreamReader.""" + + def __init__(self, reader: asyncio.StreamReader): + self._reader = reader + + async def read_message(self) -> dict | None: + """Read and parse one message. Returns None on EOF.""" + try: + line = await self._reader.readuntil(b"\n") + except (asyncio.IncompleteReadError, ConnectionError): + return None + except asyncio.LimitOverrunError as e: + # Message exceeded StreamReader limit — drain oversized data + # using public read() API (consumed=e.consumed bytes before limit). + # Read in chunks until newline found or EOF, then signal error. + remaining = e.consumed + while True: + chunk = await self._reader.read(max(remaining, 4096)) + if not chunk: + return None # EOF while draining + if b"\n" in chunk: + break # found delimiter, oversized message fully drained + raise ValueError("Message exceeds maximum size") + if not line: + return None + return parse_message(line.strip()) + + +class ProtocolWriter: + """Write newline-delimited JSON messages to an asyncio StreamWriter.""" + + def __init__(self, writer: asyncio.StreamWriter): + self._writer = writer + + async def send_request(self, msg_type: str, request_id: str | None = None, **kwargs): + """Send a request message.""" + payload = build_request(msg_type, request_id=request_id, **kwargs) + if len(payload) > MAX_MESSAGE_BYTES: + raise ValueError(f"Message exceeds limit ({len(payload)} > {MAX_MESSAGE_BYTES})") + self._writer.write(payload) + await self._writer.drain() + + async def send_response( + self, + msg_type: str, + status: str, + data: dict | None = None, + request_id: str | None = None, + ): + """Send a response message.""" + payload = build_response(msg_type, status, data, request_id=request_id) + if len(payload) > MAX_MESSAGE_BYTES: + raise ValueError(f"Message exceeds limit ({len(payload)} > {MAX_MESSAGE_BYTES})") + self._writer.write(payload) + await self._writer.drain() + + def is_closing(self) -> bool: + """Check if the underlying transport is closing or closed.""" + return self._writer.is_closing() + + def close(self): + self._writer.close() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b4d7cb7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +cryptography>=42.0.0 +mysql-connector-python>=8.3.0 +python-dotenv>=1.0.0 +# GUI client (optional, needed for gui_client.py) +PyQt6>=6.6.0 +# Image sharing (optional, needed for send_image feature) +Pillow>=10.0.0 +# QR code generation for contact verification (optional) +qrcode[pil]>=7.4 +# QR code scanning (needed for gui_client.py QR scan feature) +pyzbar>=0.1.9 diff --git a/scaling.md b/scaling.md new file mode 100644 index 0000000..742df17 --- /dev/null +++ b/scaling.md @@ -0,0 +1,252 @@ +# Škálování serveru — plán kapacitního růstu + +## Cílový hardware + +- **CPU:** Intel Xeon E5-2630v4 (10 cores / 20 threads, 2.2 GHz) +- **RAM:** 256 GB REG ECC +- **Disk:** 500 GB SSD (boot/OS/DB) + 4 TB HDD (soubory) +- **Síť:** 1 Gbit + +Odhadovaná kapacita po optimalizaci: **10 000–20 000 uživatelů**, **2000–5000 zpráv/s** + +--- + +## Krok 1: Okamžité změny (hotovo v kódu) + +### 1a. Thread pool — `server.py` + +```env +THREAD_POOL_SIZE=40 +``` + +Nastavuje `ThreadPoolExecutor(max_workers=40)` jako default executor pro `asyncio.to_thread()`. +S 20 HW thready a DB latencí ~2–5ms je 40 workerů optimální (2x HW threads — workery čekají na I/O). + +### 1b. DB pool — `.env` + +```env +DB_POOL_SIZE=30 +``` + +30 simultánních MySQL spojení. S 40 thread workers a ~2ms query je 30 pool konexí dostatek. + +### 1c. Chybějící DB indexy — `schema.sql` + +Přidány 5 nových indexů pro nejčastější dotazy: + +| Index | Tabulka | Dotaz který zrychlí | +|-------|---------|---------------------| +| `idx_cm_user (user_id)` | `conversation_members` | `list_user_conversations` — **kritický**, bez něj full table scan | +| `idx_inv_user (user_id)` | `group_invitations` | `get_pending_invitations` | +| `idx_messages_deleted (conversation_id, deleted_at)` | `messages` | `get_deleted_messages_since` | +| `idx_messages_pinned (conversation_id, pinned_at)` | `messages` | `get_pinned_messages` | +| `idx_reads_user (user_id)` | `message_reads` | `get_unread_counts` | + +**SQL migrace pro existující databázi:** + +```sql +ALTER TABLE conversation_members ADD INDEX idx_cm_user (user_id); +ALTER TABLE group_invitations ADD INDEX idx_inv_user (user_id); +ALTER TABLE messages ADD INDEX idx_messages_deleted (conversation_id, deleted_at); +ALTER TABLE messages ADD INDEX idx_messages_pinned (conversation_id, pinned_at); +ALTER TABLE message_reads ADD INDEX idx_reads_user (user_id); +``` + +### 1d. Upload adresář na HDD + +```env +UPLOAD_DIR=/mnt/hdd/encrypted_chat/uploads +``` + +Šifrované soubory a avatary na 4TB HDD — SSD zůstane pro OS a MySQL data. + +```bash +mkdir -p /mnt/hdd/encrypted_chat/uploads +chmod 700 /mnt/hdd/encrypted_chat/uploads +``` + +--- + +## Krok 2: MySQL tuning pro 256 GB RAM + +### `/etc/mysql/mysql.conf.d/tuning.cnf` (nebo ekvivalent v Dockeru) + +```ini +[mysqld] +# === Buffer Pool — hlavní cache pro data + indexy === +# 96 GB = ~37% RAM (MySQL + app na stejném stroji) +innodb_buffer_pool_size = 96G +innodb_buffer_pool_instances = 16 + +# === Redo Log — větší = méně I/O, rychlejší zápisy === +innodb_redo_log_capacity = 4G + +# === Flush strategie === +# 2 = flush do OS cache každou sekundu (ne každý commit) +# Ztráta max 1s dat při pádu OS, ale 10x rychlejší zápisy +innodb_flush_log_at_trx_commit = 2 +# O_DIRECT = bypass OS page cache (InnoDB má vlastní) +innodb_flush_method = O_DIRECT + +# === I/O kapacita (SSD) === +innodb_io_capacity = 2000 +innodb_io_capacity_max = 4000 + +# === Connections === +max_connections = 200 + +# === Sort/Join buffery === +sort_buffer_size = 4M +join_buffer_size = 4M +read_buffer_size = 2M +read_rnd_buffer_size = 2M + +# === Temporary tables === +tmp_table_size = 256M +max_heap_table_size = 256M + +# === Query cache (MySQL 8.0+ nemá, pro 5.7) === +# query_cache_type = 0 + +# === Thread cache === +thread_cache_size = 64 + +# === Binary logging (pro budoucí repliky) === +# server-id = 1 +# log_bin = /var/log/mysql/mysql-bin +# binlog_expire_logs_seconds = 604800 +# max_binlog_size = 256M +``` + +**Pokud MySQL běží v Dockeru:** + +```yaml +# docker-compose.yml +services: + mysql: + image: mysql:8.0 + volumes: + - /var/lib/mysql:/var/lib/mysql # data na SSD + - ./tuning.cnf:/etc/mysql/conf.d/tuning.cnf + deploy: + resources: + limits: + memory: 128G # limitovat aby zbylo pro app + environment: + MYSQL_DATABASE: encrypted_chat +``` + +### Po aplikaci restartovat MySQL a ověřit: + +```sql +SHOW VARIABLES LIKE 'innodb_buffer_pool_size'; +SHOW VARIABLES LIKE 'innodb_flush_log_at_trx_commit'; +SHOW ENGINE INNODB STATUS\G +``` + +--- + +## Krok 3: Doporučená `.env` pro produkci + +```env +# Server +SERVER_HOST=0.0.0.0 +SERVER_PORT=9999 + +# MySQL +MYSQL_HOST=127.0.0.1 +MYSQL_PORT=3306 +MYSQL_USER=sifrator +MYSQL_PASSWORD= +MYSQL_DATABASE=encrypted_chat +DB_POOL_SIZE=30 + +# Performance +THREAD_POOL_SIZE=40 + +# Storage +UPLOAD_DIR=/mnt/hdd/encrypted_chat/uploads + +# TLS (zapnout pro produkci) +TLS_ENABLED=true +TLS_CERT_FILE=/etc/letsencrypt/live/chat.example.com/fullchain.pem +TLS_KEY_FILE=/etc/letsencrypt/live/chat.example.com/privkey.pem + +# Logging +LOG_LEVEL=INFO +``` + +--- + +## Krok 4: Monitoring (doporučeno) + +### Jednoduché metriky bez externích nástrojů + +Přidat do serveru periodické logování: + +```python +# V _periodic_cleanup() (každých 10 min): +async with _clients_lock: + total_connections = sum(len(v) for v in connected_clients.values()) + unique_users = len(connected_clients) +logger.info("[STATS] users=%d connections=%d", unique_users, total_connections) +``` + +### S externími nástroji (volitelně) + +- **htop** — CPU / RAM využití procesu +- **mysqladmin status** — queries/s, slow queries, connections +- **Prometheus + Grafana** — dlouhodobé trendy (přidat až při potřebě) + +--- + +## Budoucí škálování + +### Fáze A: Separace MySQL (15K+ uživatelů) + +MySQL na separátní stroj (nebo managed DB). App server + Redis na jednom, DB na druhém. + +``` +[Server: App + Redis] ──TCP──▶ [Server: MySQL] + │ + └──▶ [HDD/S3: soubory] +``` + +### Fáze B: Horizontální škálování (50K+ uživatelů) + +Více app serverů za load balancerem + Redis Pub/Sub pro cross-server notifikace. + +``` + ┌─── App server 1 ───┐ +Client ──▶ │ connected_clients │──┐ + └─────────────────────┘ │ + ├──▶ Redis Pub/Sub ──▶ MySQL + ┌─── App server 2 ───┐ │ +Client ──▶ │ connected_clients │──┘ + └─────────────────────┘ + ▲ + Load Balancer (HAProxy / nginx stream) + (sticky sessions by user_id) +``` + +Hlavní změna: `_notify_users()` posílá do Redis místo lokálního `connected_clients` pokud uživatel není na tomto serveru. + +### Fáze C: DB škálování (100K+ uživatelů) + +- Read replicas pro SELECT dotazy +- Partitioning tabulky `messages` podle měsíce +- Sharding podle `conversation_id` + +--- + +## Přehled — co je hotovo + +| Krok | Stav | Popis | +|------|------|-------| +| asyncio.to_thread() pro DB | **Hotovo** | 131 DB volání offloadováno do thread poolu | +| ThreadPoolExecutor(40) | **Hotovo** | Konfigurovatelný přes `THREAD_POOL_SIZE` | +| DB indexy (5 nových) | **Hotovo** | Schema + SQL migrace připraveny | +| UPLOAD_DIR na HDD | **Konfigurace** | Nastavit v `.env` | +| MySQL tuning | **Konfigurace** | Aplikovat `tuning.cnf` | +| TLS certifikát | **TODO** | Let's Encrypt nebo vlastní CA | +| Monitoring | **Volitelné** | Periodické logování stats | diff --git a/schema.sql b/schema.sql new file mode 100644 index 0000000..b7d8426 --- /dev/null +++ b/schema.sql @@ -0,0 +1,189 @@ +CREATE DATABASE IF NOT EXISTS encrypted_chat + CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +USE encrypted_chat; + +-- Users: identity_key is Ed25519 (32B), rsa_public_key for login challenge only +CREATE TABLE IF NOT EXISTS users ( + id CHAR(36) NOT NULL PRIMARY KEY, + username VARCHAR(255) NOT NULL, + email VARCHAR(255) NOT NULL UNIQUE, + rsa_public_key TEXT NOT NULL, + identity_key BLOB NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +) ENGINE=InnoDB; + +-- Devices: each user can have multiple devices +CREATE TABLE IF NOT EXISTS devices ( + id CHAR(36) NOT NULL PRIMARY KEY, + user_id CHAR(36) NOT NULL, + device_name VARCHAR(255) DEFAULT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_seen_at DATETIME DEFAULT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_devices_user (user_id) +) ENGINE=InnoDB; + +-- Signed Pre-Keys (X25519, signed by Ed25519 identity key) — per device +CREATE TABLE IF NOT EXISTS signed_prekeys ( + id CHAR(36) NOT NULL PRIMARY KEY, + user_id CHAR(36) NOT NULL, + device_id CHAR(36) DEFAULT NULL, + public_key BLOB NOT NULL, + signature BLOB NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_spk_user_device (user_id, device_id) +) ENGINE=InnoDB; + +-- One-Time Pre-Keys (consumed on use) — per device +CREATE TABLE IF NOT EXISTS one_time_prekeys ( + id CHAR(36) NOT NULL PRIMARY KEY, + user_id CHAR(36) NOT NULL, + device_id CHAR(36) DEFAULT NULL, + public_key BLOB NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_opk_user_device (user_id, device_id) +) ENGINE=InnoDB; + +-- Conversations +CREATE TABLE IF NOT EXISTS conversations ( + id CHAR(36) NOT NULL PRIMARY KEY, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + name VARCHAR(255) DEFAULT NULL, + created_by CHAR(36) DEFAULT NULL, + avatar_file VARCHAR(255) DEFAULT NULL +) ENGINE=InnoDB; + +CREATE TABLE IF NOT EXISTS conversation_members ( + conversation_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + joined_at DATETIME NULL, + PRIMARY KEY (conversation_id, user_id), + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_cm_user (user_id) +) ENGINE=InnoDB; + +-- Group invitations (pending invitations to join a group) +CREATE TABLE IF NOT EXISTS group_invitations ( + id CHAR(36) NOT NULL PRIMARY KEY, + conversation_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + invited_by CHAR(36) NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uq_conv_user (conversation_id, user_id), + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + FOREIGN KEY (invited_by) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_inv_user (user_id) +) ENGINE=InnoDB; + +-- Messages: per-recipient ciphertext (Double Ratchet = each recipient has different ciphertext) +CREATE TABLE IF NOT EXISTS messages ( + id CHAR(36) NOT NULL PRIMARY KEY, + conversation_id CHAR(36) NOT NULL, + sender_id CHAR(36) NOT NULL, + sender_device_id CHAR(36) DEFAULT NULL, + ratchet_header BLOB NOT NULL, + x3dh_header BLOB DEFAULT NULL, + sender_chain_id BLOB DEFAULT NULL, + sender_chain_n INT DEFAULT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted_at DATETIME DEFAULT NULL, + image_file_id CHAR(36) DEFAULT NULL, + pinned_at DATETIME DEFAULT NULL, + pinned_by CHAR(36) DEFAULT NULL, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (sender_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_messages_conv_created (conversation_id, created_at), + INDEX idx_messages_deleted (conversation_id, deleted_at), + INDEX idx_messages_pinned (conversation_id, pinned_at) +) ENGINE=InnoDB; + +-- Per-recipient encrypted content — per device +-- device_id '00000000-0000-0000-0000-000000000000' = self-encrypted / legacy +CREATE TABLE IF NOT EXISTS message_recipients ( + message_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + device_id CHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000000', + encrypted_content BLOB NOT NULL, + nonce BLOB NOT NULL, + ratchet_header BLOB DEFAULT NULL, + x3dh_header BLOB DEFAULT NULL, + PRIMARY KEY (message_id, user_id, device_id), + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Sender Keys for groups (distributed via pairwise ratchet) — per device +CREATE TABLE IF NOT EXISTS group_sender_keys ( + conversation_id CHAR(36) NOT NULL, + sender_id CHAR(36) NOT NULL, + device_id CHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000000', + chain_id BLOB NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (conversation_id, sender_id, device_id), + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (sender_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Read receipts +CREATE TABLE IF NOT EXISTS message_reads ( + message_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + read_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (message_id, user_id), + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_reads_user (user_id), + INDEX idx_reads_read_at (read_at) +) ENGINE=InnoDB; + +-- Delivery receipts +CREATE TABLE IF NOT EXISTS message_deliveries ( + message_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + delivered_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (message_id, user_id), + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- User profiles +CREATE TABLE IF NOT EXISTS user_profiles ( + user_id CHAR(36) NOT NULL PRIMARY KEY, + phone VARCHAR(50) DEFAULT NULL, + phone_visible TINYINT(1) NOT NULL DEFAULT 0, + email_visible TINYINT(1) NOT NULL DEFAULT 1, + location VARCHAR(255) DEFAULT NULL, + location_visible TINYINT(1) NOT NULL DEFAULT 0, + avatar_file VARCHAR(255) DEFAULT NULL, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Message reactions (emoji reactions on messages) +CREATE TABLE IF NOT EXISTS message_reactions ( + id CHAR(36) NOT NULL PRIMARY KEY, + message_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + reaction VARCHAR(32) NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uq_reaction (message_id, user_id), + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_reactions_created_at (created_at) +) ENGINE=InnoDB; + +-- Image uploads +CREATE TABLE IF NOT EXISTS image_uploads ( + file_id CHAR(36) NOT NULL PRIMARY KEY, + conversation_id CHAR(36) NOT NULL, + uploader_id CHAR(36) NOT NULL, + file_size BIGINT NOT NULL DEFAULT 0, + completed BOOLEAN NOT NULL DEFAULT FALSE, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (uploader_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; diff --git a/server.py b/server.py new file mode 100644 index 0000000..3caed14 --- /dev/null +++ b/server.py @@ -0,0 +1,3204 @@ +"""Asyncio TCP server — stores and relays encrypted blobs without seeing content.""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +import hashlib +import hmac +import ipaddress +import json +import logging +import os +import re +import secrets +import signal +import smtplib +import socket +import ssl +import subprocess +import sys +from email.mime.text import MIMEText +from pathlib import Path +from datetime import datetime, timezone + +from dotenv import load_dotenv + +load_dotenv() + +import db +from crypto_utils import load_public_key, rsa_verify, load_ed25519_public, ed25519_verify, serialize_x25519_public, load_x25519_public +from protocol import VERSION, MIN_CLIENT_VERSION, version_gte, ProtocolReader, ProtocolWriter, encode_binary, decode_binary, MAX_MESSAGE_BYTES, MAX_IMAGE_BYTES, MAX_FILE_BYTES, IMAGE_CHUNK_SIZE + + +class _AsyncDB: + """Async proxy — offloads every synchronous db.* call to a thread via asyncio.to_thread(). + + This prevents blocking the asyncio event loop during MySQL I/O. + Wrapper functions are cached after first access for efficiency. + """ + + def __getattr__(self, name: str): + func = getattr(db, name) + + async def wrapper(*args, **kwargs): + return await asyncio.to_thread(func, *args, **kwargs) + + wrapper.__name__ = name + wrapper.__qualname__ = f"_AsyncDB.{name}" + setattr(self, name, wrapper) + return wrapper + + +adb = _AsyncDB() + + +# Connected clients: user_id -> list[ProtocolWriter] +connected_clients: dict[str, list[ProtocolWriter]] = {} +# Writer -> device_id mapping (id(writer) -> device_id) +writer_device_map: dict[int, str] = {} +# Pairing sessions: code -> data +pairing_sessions: dict[str, dict] = {} +pending_registrations: dict[str, dict] = {} +# Used PoW challenges (prevents replay within validity window) +_used_pow_challenges: dict[str, float] = {} # challenge -> used_at +# Pending image uploads: file_id -> {temp_path, received_bytes, file_size, conv_id} +pending_uploads: dict[str, dict] = {} +# Phantom user IDs (loaded at startup, updated on create/delete) +phantom_user_ids: set[str] = set() + +# Locks for shared mutable state (H4 race condition fix) +_clients_lock = asyncio.Lock() # Protects: connected_clients, writer_device_map, phantom_user_ids +_conn_lock = asyncio.Lock() # Protects: connection_counts, current_connections, rate_limits +_pairing_lock = asyncio.Lock() # Protects: pairing_sessions, pending_registrations, _used_pow_challenges +_uploads_lock = asyncio.Lock() # Protects: pending_uploads +_phantom_lock = asyncio.Lock() # Serializes phantom user creation (cap check + DB create + set add) + +UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "uploads")) + + +def _secure_delete(p: Path): + """Overwrite file with random data before deletion (anti-forensic wipe).""" + try: + if not p.exists(): + return + size = p.stat().st_size + if size > 0: + with open(p, "r+b") as f: + f.write(os.urandom(size)) + f.flush() + os.fsync(f.fileno()) + p.unlink() + except Exception: + try: + p.unlink(missing_ok=True) + except Exception: + pass + + +# C6 fix: UUID validation + safe path construction to prevent path traversal +_UUID_RE = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE) + + +def _valid_uuid(value: str) -> bool: + """Validate that value is a canonical UUID (no path components).""" + return bool(_UUID_RE.match(value)) + + +# L8 fix: email validation to prevent phantom DB inflation +_EMAIL_RE = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") + + +def _valid_email(email: str) -> bool: + """Validate basic email format (L8).""" + return bool(_EMAIL_RE.match(email)) and len(email) <= 254 + + +# C2 fix: ratchet/x3dh header validation +_RATCHET_HEADER_KEYS = {"dh_pub", "n", "pn"} +_MAX_HEADER_BYTES = 4096 + + +def _validate_header(raw, name: str) -> bytes | None: + """Validate and serialize a ratchet/x3dh header. + + Accepts only dict with expected keys, rejects str/bytes to prevent + poisoned headers from being stored. Validates that ratchet headers + contain the required keys (dh_pub, n, pn) with correct types. + Returns UTF-8 encoded JSON bytes or None if invalid. + """ + if not isinstance(raw, dict): + return None + serialized = json.dumps(raw) + if len(serialized) > _MAX_HEADER_BYTES: + return None + # Validate ratchet header keys/types if this looks like one + if name in ("ratchet_header", "recipient_ratchet_header"): + # Accept self-encrypted marker {"self": true} + if raw.get("self") is True and len(raw) == 1: + return serialized.encode() + if not _RATCHET_HEADER_KEYS.issubset(raw.keys()): + return None + if not isinstance(raw.get("dh_pub"), str): + return None + if type(raw.get("n")) is not int or type(raw.get("pn")) is not int: + return None + return serialized.encode() + + +def _append_file(path: Path, data: bytes): + """Append data to file (runs in thread pool to avoid blocking event loop).""" + with open(path, "ab") as f: + f.write(data) + + +def _read_file_chunk(path: Path, offset: int, size: int) -> bytes: + """Read a chunk from file (runs in thread pool to avoid blocking event loop).""" + with open(path, "rb") as f: + f.seek(offset) + return f.read(size) + + +def _safe_upload_path(file_id: str, suffix: str) -> Path | None: + """Return resolved path inside UPLOAD_DIR, or None if traversal detected.""" + p = (UPLOAD_DIR / f"{file_id}{suffix}").resolve() + if not p.is_relative_to(UPLOAD_DIR.resolve()): + return None + return p + + +def _safe_avatar_path(filename: str) -> Path | None: + """Return resolved avatar path inside UPLOAD_DIR/avatars, or None if traversal detected.""" + avatar_dir = (UPLOAD_DIR / "avatars").resolve() + p = (UPLOAD_DIR / "avatars" / filename).resolve() + if not p.is_relative_to(avatar_dir): + return None + return p + + +PAIRING_TTL_SECONDS = 300 +REGISTER_TTL_SECONDS = 600 # 10 min (was 3600) — faster slot release under load +PAIRING_MAX_POLL_ATTEMPTS = 90 +PAIRING_MAX_SESSIONS = 100 # global cap on concurrent pairing sessions +MAX_PENDING_REGISTRATIONS = 1000 # global cap on pending registration codes +MAX_PENDING_PER_IP = 5 # per-IP cap on pending registrations +MAX_PENDING_PER_SUBNET = 20 # per-/24 (IPv4) or /64 (IPv6) cap +REGISTRATION_PRESSURE_THRESHOLD = 0.8 # 80% → tighten limits + require PoW +POW_DIFFICULTY = 20 # leading zero bits in SHA-256 (~1M hashes, ~0.5-2s) +SMTP_RATE_GLOBAL = 30 # registration emails per minute (global) +SMTP_RATE_PER_IP = 3 # registration emails per minute (per IP) +SMTP_RATE_PER_TARGET = 2 # registration emails per minute (per target email) +MAX_PHANTOM_USERS = 500 # global cap on phantom user count +MAX_UPLOADS_GLOBAL = 200 # global cap on concurrent in-flight uploads +MAX_UPLOADS_PER_USER = 5 # per-user cap on concurrent in-flight uploads +UPLOAD_STALE_SECONDS = 600 # stale upload threshold (10 min) + +# SMTP configuration for registration codes +SMTP_HOST = os.getenv("SMTP_HOST", "") +SMTP_PORT = int(os.getenv("SMTP_PORT", "587")) +SMTP_USER = os.getenv("SMTP_USER", "") +SMTP_PASS = os.getenv("SMTP_PASS", "") +SMTP_FROM = os.getenv("SMTP_FROM", "") +RATE_LIMIT_WINDOW = 60.0 # seconds +CONNECTION_RL_WINDOW = 1.0 # seconds +CONNECTION_RL_MAX = 20 # max requests per window per connection +MAX_CONNECTIONS_PER_IP = 10 +MAX_CONNECTIONS_GLOBAL = 200 +METADATA_RETENTION_DAYS = int(os.getenv("METADATA_RETENTION_DAYS", "90")) +# TCP keepalive settings (seconds) +TCP_KEEPALIVE_IDLE = 25 # Start keepalive probes after 25s of idle +TCP_KEEPALIVE_INTERVAL = 10 # Send probes every 10s +TCP_KEEPALIVE_COUNT = 3 # Mark dead after 3 missed probes (30+3*10 = 60s max) + + +def setup_logging(): + level_name = os.getenv("LOG_LEVEL", "INFO").upper() + level = getattr(logging, level_name, logging.WARNING) + logging.basicConfig(level=level, format="%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S") + + +logger = logging.getLogger("encrypted_chat.server") + + +def _who(session: dict | None) -> str: + """Format session info for logging: truncated user_id + device prefix. + + Avoids leaking usernames and emails into log files. + """ + if not session: + return "" + uid = session.get("user_id", "?")[:8] + dev = session.get("device_id", "")[:8] if session.get("device_id") else "" + return f"u={uid} d={dev}" if dev else f"u={uid}" + + +rate_limits: dict[str, list[float]] = {} +connection_counts: dict[str, int] = {} +current_connections = 0 + + +def _rate_limit_key(action: str, addr: str, email: str | None = None) -> str: + if email: + return f"{action}|{addr}|{email.lower()}" + return f"{action}|{addr}" + + +def _normalize_email(email: str | None) -> str: + if not email: + return "" + return email.strip().lower() + + +async def _is_rate_limited(key: str, limit: int) -> bool: + async with _conn_lock: + now = asyncio.get_event_loop().time() + window_start = now - RATE_LIMIT_WINDOW + times = rate_limits.get(key, []) + times = [t for t in times if t >= window_start] + if len(times) >= limit: + rate_limits[key] = times + return True + times.append(now) + rate_limits[key] = times + return False + + +async def _create_phantom_guarded(email: str, addr: str, user_id: str) -> tuple[dict | None, str]: + """Check limits + create phantom user atomically (serialized via _phantom_lock). + + Returns (user_dict, error_message). user_dict is None on rejection. + """ + # Rate limit checks outside _phantom_lock (they acquire _conn_lock) + if await _is_rate_limited(f"phantom_create|{user_id}", 10): + return None, "Too many new contacts. Try later." + if await _is_rate_limited(f"phantom_create_ip|{addr}", 10): + return None, "Too many new contacts. Try later." + async with _phantom_lock: + async with _clients_lock: + phantom_count = len(phantom_user_ids) + if phantom_count >= MAX_PHANTOM_USERS: + return None, "Server limit reached. Try later." + u = await adb.create_phantom_user(email) + async with _clients_lock: + phantom_user_ids.add(u["id"]) + return u, "" + + +def _get_peer_addr(writer: ProtocolWriter) -> str: + try: + return str(writer._writer.get_extra_info("peername")[0]) + except Exception: + return "unknown" + + +async def _remove_dead_writer(w: ProtocolWriter): + """Remove a dead writer from connected_clients (best-effort).""" + async with _clients_lock: + wid = id(w) + writer_device_map.pop(wid, None) + for uid, writers in list(connected_clients.items()): + remaining = [wr for wr in writers if wr is not w] + if remaining: + connected_clients[uid] = remaining + else: + del connected_clients[uid] + + +async def _notify_users(user_ids, msg_type, data, exclude_writer=None): + """Snapshot writers under lock, send notifications outside lock.""" + targets = [] + async with _clients_lock: + for uid in user_ids: + for w in connected_clients.get(uid, []): + targets.append(w) + dead = [] + for w in targets: + if w is exclude_writer: + continue + try: + if w.is_closing(): + dead.append(w) + continue + await w.send_response(msg_type, "ok", data) + except Exception: + logger.debug("[NOTIFY] Failed to send %s, marking writer dead", msg_type) + dead.append(w) + for w in dead: + await _remove_dead_writer(w) + + +async def _notify_users_individual(notifications, exclude_writer=None): + """Send per-user data. notifications: list of (user_id, msg_type, data).""" + targets = [] + async with _clients_lock: + for uid, mt, d in notifications: + for w in connected_clients.get(uid, []): + targets.append((w, mt, d, uid)) + dead = [] + sent = 0 + skipped = 0 + delivered_users = set() + for w, mt, d, uid in targets: + if w is exclude_writer: + skipped += 1 + continue + try: + if w.is_closing(): + dead.append(w) + logger.warning("[NOTIFY] Writer for u=%s is closing, removing", uid[:8]) + continue + await w.send_response(mt, "ok", d) + sent += 1 + delivered_users.add(uid) + except Exception as e: + logger.warning("[NOTIFY] Failed to send %s to u=%s: %s", mt, uid[:8], e) + dead.append(w) + if dead: + for w in dead: + await _remove_dead_writer(w) + if mt == "new_message" if targets else False: + logger.debug("[NOTIFY] %s: sent=%d skipped=%d dead=%d", mt, sent, skipped, len(dead)) + return delivered_users + + +async def _cleanup_pairings(): + async with _pairing_lock: + now = asyncio.get_event_loop().time() + expired = [code for code, p in pairing_sessions.items() if now - p["created_at"] > PAIRING_TTL_SECONDS] + for code in expired: + pairing_sessions.pop(code, None) + + +async def _cleanup_registrations(): + async with _pairing_lock: + now = asyncio.get_event_loop().time() + expired = [code for code, p in pending_registrations.items() if now - p["created_at"] > REGISTER_TTL_SECONDS] + for code in expired: + pending_registrations.pop(code, None) + # Purge used PoW challenges older than 120s (validity window) + stale = [ch for ch, ts in _used_pow_challenges.items() if now - ts > 120] + for ch in stale: + _used_pow_challenges.pop(ch, None) + + +def _generate_pairing_code() -> str: + for _ in range(10): + code = f"{int.from_bytes(os.urandom(4), 'big') % 100000000:08d}" + if code not in pairing_sessions: + return code + return f"{int.from_bytes(os.urandom(4), 'big') % 100000000:08d}" + + +def _generate_register_code() -> str: + for _ in range(10): + code = f"{int.from_bytes(os.urandom(3), 'big') % 1000000:06d}" + if code not in pending_registrations: + return code + return f"{int.from_bytes(os.urandom(3), 'big') % 1000000:06d}" + +def _validate_public_key_pem(pem_str: str) -> bool: + """Validate that a string is a valid RSA public key PEM.""" + try: + key = load_public_key(pem_str.encode("utf-8")) + if key.key_size < 2048: + return False + return True + except Exception: + return False + + +def _send_registration_email(to_email: str, code: str) -> bool: + """Send registration code via SMTP. Returns True on success.""" + if not SMTP_HOST: + return False + try: + msg = MIMEText(f"Your registration code is: {code}\n\nThis code expires in 10 minutes.") + msg["Subject"] = "Encrypted Chat - Registration Code" + msg["From"] = SMTP_FROM or SMTP_USER + msg["To"] = to_email + with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=10) as server: + # RFC-style STARTTLS flow: advertise capabilities pre/post TLS upgrade. + server.ehlo() + server.starttls(context=ssl.create_default_context()) + server.ehlo() + if SMTP_USER: + server.login(SMTP_USER, SMTP_PASS) + server.send_message(msg) + return True + except Exception as e: + logger.warning("Failed to send registration email: %s", e) + return False + + +async def send_resp(msg: dict, writer: ProtocolWriter, msg_type: str, status: str, data: dict | None = None): + await writer.send_response(msg_type, status, data, request_id=msg.get("request_id")) + + +# --- Registration admission control --- + +_POW_SECRET = os.urandom(32) # per-process; restarts invalidate outstanding challenges + + +def _get_subnet(addr: str) -> str: + """Extract /24 for IPv4, /64 for IPv6.""" + try: + ip = ipaddress.ip_address(addr) + if ip.version == 4: + return str(ipaddress.ip_network(f"{ip}/24", strict=False)) + return str(ipaddress.ip_network(f"{ip}/64", strict=False)) + except ValueError: + return addr + + +def _pending_counts_by_origin(addr: str) -> tuple[int, int]: + """Count pending registrations by IP and subnet. Caller must hold _pairing_lock.""" + subnet = _get_subnet(addr) + ip_count = 0 + subnet_count = 0 + for p in pending_registrations.values(): + p_addr = p.get("addr", "") + if p_addr == addr: + ip_count += 1 + if _get_subnet(p_addr) == subnet: + subnet_count += 1 + return ip_count, subnet_count + + +def _generate_pow_challenge() -> tuple[str, str]: + """Generate a stateless PoW challenge (challenge, mac). + + The challenge embeds a timestamp so the server can reject stale solutions. + The HMAC proves the challenge was issued by this server instance. + """ + ts = str(int(asyncio.get_event_loop().time())) + nonce = secrets.token_hex(16) + challenge = f"{ts}:{nonce}" + mac = hmac.new(_POW_SECRET, challenge.encode(), hashlib.sha256).hexdigest() + return challenge, mac + + +def _verify_pow(challenge: str, mac: str, nonce: str, difficulty: int) -> bool: + """Verify a PoW solution: HMAC authentic, timestamp fresh, hash has leading zeros.""" + # Verify HMAC + expected = hmac.new(_POW_SECRET, challenge.encode(), hashlib.sha256).hexdigest() + if not hmac.compare_digest(expected, mac): + return False + # Check timestamp freshness (120s window) + try: + ts = int(challenge.split(":")[0]) + except (ValueError, IndexError): + return False + now = int(asyncio.get_event_loop().time()) + if abs(now - ts) > 120: + return False + # Verify PoW: SHA-256(challenge + nonce) must have `difficulty` leading zero bits + digest = hashlib.sha256(f"{challenge}{nonce}".encode()).digest() + # Check leading zero bits + bits_needed = difficulty + for byte in digest: + if bits_needed <= 0: + break + if bits_needed >= 8: + if byte != 0: + return False + bits_needed -= 8 + else: + mask = (0xFF << (8 - bits_needed)) & 0xFF + if byte & mask: + return False + bits_needed = 0 + return True + + +async def handle_register_start(msg: dict, writer: ProtocolWriter) -> dict | None: + await _cleanup_registrations() + username = msg.get("username", "").strip() + public_key = msg.get("public_key", "").strip() + identity_key_b64 = msg.get("identity_key", "").strip() + email = msg.get("email", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("register_start", addr, email), 3): + await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."}) + return None + # Per-IP limit (regardless of email) to prevent SMTP spam via email rotation + if await _is_rate_limited(f"register_start_ip|{addr}", 6): + await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."}) + return None + if not username or not public_key or not email or not identity_key_b64: + await send_resp(msg, writer, "register_start", "error", {"message": "Missing fields"}) + return None + if not _validate_public_key_pem(public_key): + await send_resp(msg, writer, "register_start", "error", {"message": "Invalid public key format"}) + return None + # Validate identity key is 32 bytes + try: + ik_bytes = decode_binary(identity_key_b64) + if len(ik_bytes) != 32: + raise ValueError("Identity key must be 32 bytes") + load_ed25519_public(ik_bytes) + except Exception: + await send_resp(msg, writer, "register_start", "error", {"message": "Invalid identity key"}) + return None + existing_email = await adb.get_user_by_email(email) + phantom_id = None + is_existing_real_user = False + if existing_email: + if existing_email.get("rsa_public_key") == "PHANTOM": + phantom_id = existing_email["id"] + else: + is_existing_real_user = True + # --- Admission control (all checks under lock, I/O outside) --- + # Existing-email goes through the same path so responses are + # indistinguishable from new-email (H3 anti-enumeration). + # Both allocate a slot so per-IP/subnet cap counting is identical. + async with _pairing_lock: + total = len(pending_registrations) + # Hard cap + if total >= MAX_PENDING_REGISTRATIONS: + reject_reason = "cap" + else: + # Per-IP / per-subnet slot limits + ip_count, subnet_count = _pending_counts_by_origin(addr) + if ip_count >= MAX_PENDING_PER_IP: + reject_reason = "ip" + elif subnet_count >= MAX_PENDING_PER_SUBNET: + reject_reason = "subnet" + else: + reject_reason = None + # Pressure mode: require PoW when >80% full + under_pressure = total >= MAX_PENDING_REGISTRATIONS * REGISTRATION_PRESSURE_THRESHOLD + need_pow = under_pressure and reject_reason is None + # If PoW required, verify the client's solution (one-time use) + pow_ok = False + if need_pow: + pow_challenge = msg.get("pow_challenge", "") + pow_mac = msg.get("pow_mac", "") + pow_nonce = msg.get("pow_nonce", "") + if pow_challenge and pow_mac and pow_nonce: + if pow_challenge in _used_pow_challenges: + pow_ok = False # replay + elif _verify_pow(pow_challenge, pow_mac, pow_nonce, POW_DIFFICULTY): + _used_pow_challenges[pow_challenge] = asyncio.get_event_loop().time() + pow_ok = True + # Decide: admit, challenge, or reject + if reject_reason: + admit = False + send_challenge = False + code = None + elif need_pow and not pow_ok: + admit = False + send_challenge = True + code = None + else: + # Both existing and new emails allocate a slot so per-IP/subnet + # counting behaves identically (anti-enumeration via slot side-channel). + # Existing-email slots are inert — register_confirm silently fails. + admit = True + send_challenge = False + code = _generate_register_code() + pending_registrations[code] = { + "username": username, + "public_key": public_key, + "identity_key": ik_bytes, + "email": email, + "created_at": asyncio.get_event_loop().time(), + "phantom_id": phantom_id, + "addr": addr, + "fake": is_existing_real_user, + } + # --- I/O outside lock --- + if not admit: + if send_challenge: + challenge, mac = _generate_pow_challenge() + await send_resp(msg, writer, "register_start", "pow_required", { + "challenge": challenge, "mac": mac, "difficulty": POW_DIFFICULTY, + }) + else: + await send_resp(msg, writer, "register_start", "error", {"message": "Server busy. Try later."}) + return None + logger.info("[REGISTER] registration started") + is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") + # SMTP rate limiting + smtp_blocked = False + if SMTP_HOST: + if await _is_rate_limited("smtp_send|global", SMTP_RATE_GLOBAL): + smtp_blocked = True + elif await _is_rate_limited(f"smtp_send_ip|{addr}", SMTP_RATE_PER_IP): + smtp_blocked = True + elif await _is_rate_limited(f"smtp_send_target|{email.lower()}", SMTP_RATE_PER_TARGET): + smtp_blocked = True + if smtp_blocked: + if is_dev: + logger.warning("[REGISTER] SMTP rate limit hit — returning code (dev mode)") + await send_resp(msg, writer, "register_start", "ok", {"code": code}) + else: + logger.warning("[REGISTER] SMTP rate limit hit — revoking slot silently") + async with _pairing_lock: + pending_registrations.pop(code, None) + await send_resp(msg, writer, "register_start", "ok", + {"message": "Code sent to your email."}) + return None + # Send registration email in a thread (non-blocking) for both real + # and fake registrations. For existing emails we still call SMTP so + # the response timing is indistinguishable (anti-enumeration). + # The email goes to the real address either way — existing users just + # won't be able to confirm (code is for a fake slot). + email_sent = await asyncio.to_thread(_send_registration_email, email, code) + if email_sent: + await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."}) + elif is_dev: + logger.warning("[REGISTER] No SMTP / send failed — returning code (dev mode)") + await send_resp(msg, writer, "register_start", "ok", {"code": code}) + else: + logger.warning("[REGISTER] SMTP send failed — revoking slot silently") + async with _pairing_lock: + pending_registrations.pop(code, None) + await send_resp(msg, writer, "register_start", "ok", + {"message": "Code sent to your email."}) + return None + + +async def handle_register_confirm(msg: dict, writer: ProtocolWriter) -> dict | None: + await _cleanup_registrations() + email = msg.get("email", "").strip() + code = msg.get("code", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("register_confirm", addr, email), 3): + await send_resp(msg, writer, "register_confirm", "error", {"message": "Too many attempts. Try later."}) + return None + if not email or not code: + await send_resp(msg, writer, "register_confirm", "error", {"message": "Missing email or code"}) + return None + async with _pairing_lock: + pending = pending_registrations.get(code) + if pending and pending.get("email") == email: + pending_registrations.pop(code, None) + else: + pending = None + if not pending: + await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"}) + return None + # H3 anti-enumeration: fake slot (existing email) — reject with same + # generic message so attacker can't distinguish from a wrong code + if pending.get("fake"): + await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"}) + return None + phantom_id = pending.get("phantom_id") + if phantom_id: + # Upgrade phantom in-place — preserves FK references (invitations, memberships) + user_id = await adb.upgrade_phantom_user( + phantom_id, + pending["username"], + pending["public_key"], + pending["identity_key"], + ) + if user_id: + async with _clients_lock: + phantom_user_ids.discard(phantom_id) + else: + # Phantom was deleted concurrently — fall back to normal create + user_id = await adb.create_user( + pending["username"], + pending["email"], + pending["public_key"], + pending["identity_key"], + ) + else: + user_id = await adb.create_user( + pending["username"], + pending["email"], + pending["public_key"], + pending["identity_key"], + ) + await adb.create_default_profile(user_id) + logger.info("[REGISTER] confirmed (user_id=%s)", user_id[:8]) + await send_resp(msg, writer, "register_confirm", "ok", {"user_id": user_id}) + return None + + +async def handle_login_start(msg: dict, writer: ProtocolWriter, state: dict): + email = msg.get("email", "").strip() + normalized_email = _normalize_email(email) + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("login_start", addr, email), 10): + await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."}) + return + if await _is_rate_limited(f"login_start_ip|{addr}", 20): + await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."}) + return + if not email: + await send_resp(msg, writer, "login_start", "error", {"message": "Missing email"}) + return + user = await adb.get_user_by_email(normalized_email) + challenge = os.urandom(32) + state["login_email"] = email + state["login_challenge"] = challenge + if not user: + # H3 anti-enumeration: return a fake challenge so attacker can't distinguish + # "user not found" from "user exists". login_finish will fail with generic error. + state["_login_fake"] = True + await send_resp(msg, writer, "login_start", "ok", {"challenge": encode_binary(challenge)}) + + +async def handle_login_finish(msg: dict, writer: ProtocolWriter, state: dict) -> dict | None: + email = msg.get("email", "").strip() + signature_b64 = msg.get("signature", "") + challenge = state.get("login_challenge") + expected_email = state.get("login_email") + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("login_finish", addr, email), 10): + await send_resp(msg, writer, "login_finish", "error", {"message": "Too many attempts. Try later."}) + return None + if not email or not signature_b64: + await send_resp(msg, writer, "login_finish", "error", {"message": "Missing email or signature"}) + return None + if not challenge or expected_email != email: + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + + # H3: if login_start was for a non-existent user, fail with generic error + is_fake = state.pop("_login_fake", False) + + try: + if is_fake: + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + + user = await adb.get_user_by_email(email) + if not user: + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + + public_key = load_public_key(user["rsa_public_key"].encode("utf-8")) + signature = decode_binary(signature_b64) + if not rsa_verify(public_key, signature, challenge): + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + except ValueError: + # H5: invalid base64 in signature + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + finally: + state.pop("login_challenge", None) + state.pop("login_email", None) + + user_id = user["id"] + + # Version check: reject outdated clients + client_version = msg.get("client_version", "") + if client_version and not version_gte(client_version, MIN_CLIENT_VERSION): + await send_resp(msg, writer, "login_finish", "error", { + "message": f"Client version {client_version} is too old. Minimum required: {MIN_CLIENT_VERSION}", + "min_version": MIN_CLIENT_VERSION, + "server_version": VERSION, + }) + return None + + # Device registration: client may send device_id to reuse an existing device + client_device_id = msg.get("device_id") + device_id = None + new_device_created = False + device_name = msg.get("device_name", "Unknown") + if client_device_id: + dev = await adb.get_device(client_device_id) + if dev and dev["user_id"] == user_id: + device_id = client_device_id + if not device_id: + device_id = await adb.create_device(user_id, device_name) + new_device_created = True + await adb.update_device_last_seen(device_id) + + async with _clients_lock: + if user_id not in connected_clients: + connected_clients[user_id] = [] + connected_clients[user_id].append(writer) + writer_device_map[id(writer)] = device_id + logger.info("[LOGIN] u=%s d=%s client_v=%s", + user_id[:8], device_id[:8] if device_id else "?", client_version or "unknown") + await send_resp(msg, writer, "login_finish", "ok", { + "user_id": user_id, "username": user["username"], "email": user["email"], + "device_id": device_id, "server_version": VERSION, + }) + + # Send online status notifications + contacts = await adb.get_user_contacts(user_id) + online_targets = [] + async with _clients_lock: + online_contacts = [cid for cid in contacts if cid in connected_clients and connected_clients[cid]] + # Always notify contacts (handles reconnect where old writer is still lingering) + for contact_id in contacts: + for cw in connected_clients.get(contact_id, []): + online_targets.append(cw) + await writer.send_response("online_users", "ok", {"user_ids": online_contacts}) + # Send online notifications outside lock + for cw in online_targets: + try: + await cw.send_response("user_online", "ok", {"user_id": user_id}) + except Exception: + pass + + if new_device_created: + await _notify_users([user_id], "device_added", { + "device_id": device_id, + "device_name": device_name, + "ip": addr, + "added_at": datetime.now(timezone.utc).isoformat(), + }, exclude_writer=writer) + + return {"user_id": user_id, "username": user["username"], "email": user["email"], + "device_id": device_id} + + +async def handle_get_user_info(msg: dict, session: dict, writer: ProtocolWriter): + """Get user info including identity key (for X3DH). Requires login.""" + email = msg.get("email", "").strip() + user_id = msg.get("user_id", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("get_user_info", addr, email or user_id), 30): + await send_resp(msg, writer, "get_user_info", "error", {"message": "Too many attempts. Try later."}) + return + if user_id and not _valid_uuid(user_id): + await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) + return + user = None + if email: + user = await adb.get_user_by_email(email) + elif user_id: + user = await adb.get_user_by_id(user_id) + if not user: + await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) + return + # H4 fix: restrict lookups to self or contacts (shared conversation) + target_id = user["id"] + if target_id != session["user_id"]: + if not await adb.shares_conversation(session["user_id"], target_id): + await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) + return + ik = user.get("identity_key") + await send_resp(msg, writer, "get_user_info", "ok", { + "user_id": user["id"], + "username": user["username"], + "email": user["email"], + "identity_key": encode_binary(ik) if ik else "", + }) + + +async def handle_upload_prekeys(msg: dict, session: dict, writer: ProtocolWriter): + """Upload signed prekey + batch of one-time prekeys.""" + if await _is_rate_limited(f"upload_prekeys|{session['user_id']}", 5): + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Too many requests. Try later."}) + return + spk_data = msg.get("signed_prekey") + otps = msg.get("one_time_prekeys", []) + if not spk_data: + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Missing signed_prekey"}) + return + + spk_id = spk_data.get("id", "") + spk_pub_b64 = spk_data.get("public_key", "") + spk_sig_b64 = spk_data.get("signature", "") + if not spk_id or not spk_pub_b64 or not spk_sig_b64: + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Incomplete signed_prekey"}) + return + + spk_pub = decode_binary(spk_pub_b64) + spk_sig = decode_binary(spk_sig_b64) + + # Verify SPK signature with user's identity key + user = await adb.get_user_by_id(session["user_id"]) + if not user or not user.get("identity_key"): + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "No identity key"}) + return + ik_pub = load_ed25519_public(user["identity_key"]) + if not ed25519_verify(ik_pub, spk_sig, spk_pub): + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid SPK signature"}) + return + + device_id = session.get("device_id") + await adb.store_signed_prekey(session["user_id"], spk_id, spk_pub, spk_sig, device_id=device_id) + + # Store OTPs + otp_records = [] + for otp in otps: + otp_id = otp.get("id", "") + otp_pub_b64 = otp.get("public_key", "") + if otp_id and otp_pub_b64: + otp_records.append({"id": otp_id, "public_key": decode_binary(otp_pub_b64)}) + if otp_records: + await adb.store_one_time_prekeys(session["user_id"], otp_records, device_id=device_id) + + logger.info("[PREKEYS] %s uploaded 1 SPK + %d OTPs", _who(session), len(otp_records)) + await send_resp(msg, writer, "upload_prekeys", "ok", {"message": "OK"}) + + +async def handle_get_key_bundle(msg: dict, session: dict, writer: ProtocolWriter): + """Fetch key bundle for X3DH. Returns per-device bundles. Consumes one OTP per device.""" + target_user_id = msg.get("user_id", "").strip() + if not target_user_id: + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Missing user_id"}) + return + if not _valid_uuid(target_user_id): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Invalid user_id"}) + return + # M4: rate limit + authorization (prevents OPK depletion) + if await _is_rate_limited(f"get_key_bundle|{session['user_id']}", 10): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Too many requests. Try later."}) + return + # Auth check before per-target rate limit so unauthorized requests don't burn target's bucket + if target_user_id != session["user_id"]: + if not await adb.shares_conversation(session["user_id"], target_user_id): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Key bundle not available"}) + return + if await _is_rate_limited(f"get_key_bundle_target|{target_user_id}", 20): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Too many requests. Try later."}) + return + result = await adb.get_key_bundles_for_user(target_user_id) + if not result or not result.get("device_bundles"): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Key bundle not available"}) + return + + device_bundles_data = [] + for b in result["device_bundles"]: + entry = { + "device_id": b.get("device_id"), + "signed_prekey_id": b["signed_prekey_id"], + "signed_prekey": encode_binary(b["signed_prekey_pub"]), + "spk_signature": encode_binary(b["spk_signature"]), + } + if b.get("opk_pub"): + entry["one_time_prekey_id"] = b["opk_id"] + entry["one_time_prekey"] = encode_binary(b["opk_pub"]) + device_bundles_data.append(entry) + + # Build response with both new multi-device format and legacy flat fields + first = device_bundles_data[0] if device_bundles_data else {} + data = { + "identity_key": encode_binary(result["identity_key"]), + "device_bundles": device_bundles_data, + # Legacy flat fields from first device bundle (backward compat) + "signed_prekey_id": first.get("signed_prekey_id", ""), + "signed_prekey": first.get("signed_prekey", ""), + "spk_signature": first.get("spk_signature", ""), + } + if first.get("one_time_prekey"): + data["one_time_prekey_id"] = first["one_time_prekey_id"] + data["one_time_prekey"] = first["one_time_prekey"] + logger.info("[X3DH] %s fetched key bundle for user=%s (%d devices)", + _who(session), target_user_id[:8], len(device_bundles_data)) + await send_resp(msg, writer, "get_key_bundle", "ok", data) + + +async def handle_get_prekey_count(msg: dict, session: dict, writer: ProtocolWriter): + """How many OPKs does user have left (for this device)? Also returns SPK age for rotation.""" + device_id = session.get("device_id") + count = await adb.count_one_time_prekeys(session["user_id"], device_id=device_id) + spk_created_at = "" + spk = await adb.get_signed_prekey(session["user_id"], device_id=device_id) + if spk and spk.get("created_at"): + spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) + await send_resp(msg, writer, "get_prekey_count", "ok", + {"count": count, "spk_created_at": spk_created_at}) + + +async def handle_ensure_prekeys(msg: dict, session: dict, writer: ProtocolWriter): + """Combined get_prekey_count + upload_prekeys in one round-trip. + + Client sends current OPK/SPK data; server checks count and SPK age, + stores new keys if provided, and returns the current status. + """ + if await _is_rate_limited(f"ensure_prekeys|{session['user_id']}", 5): + await send_resp(msg, writer, "ensure_prekeys", "error", {"message": "Too many requests. Try later."}) + return + device_id = session.get("device_id") + user_id = session["user_id"] + + # Step 1: Get current count + SPK age + count = await adb.count_one_time_prekeys(user_id, device_id=device_id) + spk_created_at = "" + spk = await adb.get_signed_prekey(user_id, device_id=device_id) + if spk and spk.get("created_at"): + spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) + + # Step 2: If client included new keys, store them + uploaded_spk = False + uploaded_otps = 0 + spk_data = msg.get("signed_prekey") + if spk_data: + spk_id = spk_data.get("id", "") + spk_pub_b64 = spk_data.get("public_key", "") + spk_sig_b64 = spk_data.get("signature", "") + if spk_id and spk_pub_b64 and spk_sig_b64: + spk_pub = decode_binary(spk_pub_b64) + spk_sig = decode_binary(spk_sig_b64) + user = await adb.get_user_by_id(user_id) + if user and user.get("identity_key"): + ik_pub = load_ed25519_public(user["identity_key"]) + if ed25519_verify(ik_pub, spk_sig, spk_pub): + await adb.store_signed_prekey(user_id, spk_id, spk_pub, spk_sig, device_id=device_id) + uploaded_spk = True + + otps = msg.get("one_time_prekeys", []) + if otps: + otp_records = [] + for otp in otps: + otp_id = otp.get("id", "") + otp_pub_b64 = otp.get("public_key", "") + if otp_id and otp_pub_b64: + otp_records.append({"id": otp_id, "public_key": decode_binary(otp_pub_b64)}) + if otp_records: + await adb.store_one_time_prekeys(user_id, otp_records, device_id=device_id) + uploaded_otps = len(otp_records) + + # Recount after upload + if uploaded_spk or uploaded_otps: + count = await adb.count_one_time_prekeys(user_id, device_id=device_id) + spk = await adb.get_signed_prekey(user_id, device_id=device_id) + if spk and spk.get("created_at"): + spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) + logger.info("[PREKEYS] %s ensure_prekeys: uploaded SPK=%s, OTPs=%d, new count=%d", + _who(session), uploaded_spk, uploaded_otps, count) + + await send_resp(msg, writer, "ensure_prekeys", "ok", + {"count": count, "spk_created_at": spk_created_at, + "uploaded_spk": uploaded_spk, "uploaded_otps": uploaded_otps}) + + +async def handle_rotate_keys(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"rotate_keys|{session['user_id']}", 3): + await send_resp(msg, writer, "rotate_keys", "error", {"message": "Too many requests. Try later."}) + return + public_key = msg.get("public_key", "").strip() + if not public_key: + await send_resp(msg, writer, "rotate_keys", "error", {"message": "Missing public_key"}) + return + if not _validate_public_key_pem(public_key): + await send_resp(msg, writer, "rotate_keys", "error", {"message": "Invalid public key format"}) + return + await adb.update_user_rsa_key(session["user_id"], public_key) + logger.info("[ROTATE] %s rotated RSA key", _who(session)) + await send_resp(msg, writer, "rotate_keys", "ok", {"message": "OK"}) + # Disconnect other sessions + async with _clients_lock: + writers = connected_clients.get(session["user_id"], []) + others = [w for w in writers if w is not writer] + connected_clients[session["user_id"]] = [writer] + for w in others: + try: + w.close() + except Exception: + pass + + +async def handle_change_username(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"change_username|{session['user_id']}", 5): + await send_resp(msg, writer, "change_username", "error", {"message": "Too many requests. Try later."}) + return + new_username = msg.get("username", "").strip() + if not new_username or len(new_username) > 100: + await send_resp(msg, writer, "change_username", "error", {"message": "Invalid username (1-100 chars)"}) + return + user_id = session["user_id"] + await adb.update_username(user_id, new_username) + session["username"] = new_username + logger.info("[ACCOUNT] %s changed username", _who(session)) + await send_resp(msg, writer, "change_username", "ok", {"username": new_username}) + # Notify contacts + contacts = await adb.get_user_contacts(user_id) + targets = [] + async with _clients_lock: + for cid in contacts: + for cw in connected_clients.get(cid, []): + targets.append(cw) + for cw in targets: + try: + await cw.send_response("username_changed", "ok", { + "user_id": user_id, "username": new_username, + }) + except Exception: + pass + + +async def handle_pairing_start(msg: dict, writer: ProtocolWriter): + await _cleanup_pairings() + email = msg.get("email", "").strip() + normalized_email = _normalize_email(email) + temp_public_key = msg.get("temp_public_key", "").strip() + temp_key_type = msg.get("temp_key_type", "x25519").strip() + addr = _get_peer_addr(writer) + # H4 fix: rate limit per IP only (not per email) to prevent enumeration via email rotation + if await _is_rate_limited(_rate_limit_key("pairing_start", addr), 10): + await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."}) + return + if not email or not temp_public_key: + await send_resp(msg, writer, "pairing_start", "error", {"message": "Missing email or temp_public_key"}) + return + if temp_key_type != "x25519": + await send_resp(msg, writer, "pairing_start", "error", {"message": "Unsupported temp_key_type"}) + return + try: + temp_pub_raw = decode_binary(temp_public_key) + if len(temp_pub_raw) != 32: + raise ValueError("bad length") + load_x25519_public(temp_pub_raw) + except Exception: + await send_resp(msg, writer, "pairing_start", "error", {"message": "Invalid temp_public_key"}) + return + user = await adb.get_user_by_email(normalized_email) + target_user_id = user["id"] if user else None + poll_token = secrets.token_hex(16) + cap_hit = False + async with _pairing_lock: + # H4 fix: global cap prevents memory exhaustion from dummy sessions + if len(pairing_sessions) >= PAIRING_MAX_SESSIONS: + cap_hit = True + else: + code = _generate_pairing_code() + # H4 fix: always create session (anti-enumeration). For non-existent users + # the session behaves identically (poll returns ready:false, claim never matches + # because no real account can log in to claim it). TTL cleanup handles expiry. + pairing_sessions[code] = { + "email": normalized_email, + "user_id": target_user_id, + "temp_public_key": temp_public_key, + "temp_key_type": temp_key_type, + "created_at": asyncio.get_event_loop().time(), + "payload": None, + "poll_token": poll_token, + } + if cap_hit: + await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."}) + return + logger.info( + "[PAIR] start code=%s user=%s pid=%s", + code[:8], + target_user_id[:8] if target_user_id else "", + os.getpid(), + ) + await send_resp(msg, writer, "pairing_start", "ok", {"code": code, "poll_token": poll_token}) + + +async def handle_pairing_claim(msg: dict, session: dict, writer: ProtocolWriter): + await _cleanup_pairings() + code = msg.get("code", "").strip() + if not code: + await send_resp(msg, writer, "pairing_claim", "error", {"message": "Missing code"}) + return + async with _pairing_lock: + p = pairing_sessions.get(code) + p_user_id = p.get("user_id") if p else None + temp_pub = p["temp_public_key"] if p else None + temp_key_type = p.get("temp_key_type", "x25519") if p else "x25519" + if p: + # Extend TTL — re-encryption may run between claim and send + p["created_at"] = asyncio.get_event_loop().time() + # H4 fix: unified error message (anti-enumeration) + if not p or not p_user_id or p_user_id != session.get("user_id"): + logger.warning( + "[PAIR] claim rejected code=%s pid=%s exists=%s target=%s session=%s", + code[:8], + os.getpid(), + bool(p), + p_user_id[:8] if p_user_id else "", + session.get("user_id", "")[:8] if session.get("user_id") else "", + ) + await send_resp(msg, writer, "pairing_claim", "error", {"message": "Invalid or expired code"}) + return + logger.info( + "[PAIR] claim ok code=%s user=%s pid=%s", + code[:8], + session.get("user_id", "")[:8], + os.getpid(), + ) + await send_resp(msg, writer, "pairing_claim", "ok", { + "temp_public_key": temp_pub, + "temp_key_type": temp_key_type, + }) + + +async def handle_pairing_send(msg: dict, session: dict, writer: ProtocolWriter): + await _cleanup_pairings() + code = msg.get("code", "").strip() + payload = msg.get("payload") + if not code or not payload: + await send_resp(msg, writer, "pairing_send", "error", {"message": "Missing code or payload"}) + return + error_msg = None + async with _pairing_lock: + p = pairing_sessions.get(code) + # H4 fix: unified error message (anti-enumeration) + if not p or not p.get("user_id") or p["user_id"] != session.get("user_id"): + error_msg = "Invalid or expired code" + logger.warning( + "[PAIR] send rejected code=%s pid=%s exists=%s target=%s session=%s", + code[:8], + os.getpid(), + bool(p), + p.get("user_id", "")[:8] if p and p.get("user_id") else "", + session.get("user_id", "")[:8] if session.get("user_id") else "", + ) + else: + p["payload"] = payload + logger.info( + "[PAIR] send ok code=%s user=%s pid=%s", + code[:8], + session.get("user_id", "")[:8], + os.getpid(), + ) + if error_msg: + await send_resp(msg, writer, "pairing_send", "error", {"message": error_msg}) + else: + await send_resp(msg, writer, "pairing_send", "ok", {"message": "OK"}) + + +async def handle_pairing_poll(msg: dict, writer: ProtocolWriter): + await _cleanup_pairings() + code = msg.get("code", "").strip() + poll_token = msg.get("poll_token", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("pairing_poll", addr), 120): + await send_resp(msg, writer, "pairing_poll", "error", {"message": "Too many attempts. Try later."}) + return + if not code: + await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing code"}) + return + if not poll_token: + await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing poll_token"}) + return + error_msg = None + ready = False + payload = None + async with _pairing_lock: + p = pairing_sessions.get(code) + if not p: + error_msg = "Invalid or expired code" + logger.warning("[PAIR] poll rejected code=%s pid=%s exists=false", code[:8], os.getpid()) + elif not secrets.compare_digest(p.get("poll_token", ""), poll_token): + error_msg = "Invalid poll_token" + logger.warning("[PAIR] poll token mismatch code=%s pid=%s", code[:8], os.getpid()) + else: + poll_attempts = p.get("poll_attempts", 0) + 1 + p["poll_attempts"] = poll_attempts + if poll_attempts > PAIRING_MAX_POLL_ATTEMPTS and not p.get("payload"): + pairing_sessions.pop(code, None) + error_msg = "Code invalidated due to too many attempts" + logger.warning("[PAIR] poll invalidated code=%s pid=%s attempts=%s", code[:8], os.getpid(), poll_attempts) + elif p.get("payload"): + ready = True + payload = p["payload"] + pairing_sessions.pop(code, None) + logger.info("[PAIR] poll ready code=%s pid=%s", code[:8], os.getpid()) + if error_msg: + await send_resp(msg, writer, "pairing_poll", "error", {"message": error_msg}) + elif ready: + await send_resp(msg, writer, "pairing_poll", "ok", {"ready": True, "payload": payload}) + else: + await send_resp(msg, writer, "pairing_poll", "ok", {"ready": False}) + + +async def handle_create_conversation(msg: dict, session: dict, writer: ProtocolWriter): + member_emails = msg.get("members", []) + name = msg.get("name") + addr = _get_peer_addr(writer) + if await _is_rate_limited(f"create_conversation|{session['user_id']}", 10): + await send_resp(msg, writer, "create_conversation", "error", {"message": "Too many attempts. Try later."}) + return + # Resolve all member user IDs + other_users = [] + for email in member_emails: + u = await adb.get_user_by_email(email) + if not u: + if not _valid_email(email): + await send_resp(msg, writer, "create_conversation", "error", {"message": f"Invalid email format: {email}"}) + return + # H5: atomic phantom creation (cap check + DB create + set add) + u, err_msg = await _create_phantom_guarded(email, addr, session["user_id"]) + if u is None: + await send_resp(msg, writer, "create_conversation", "error", {"message": err_msg}) + return + if u["id"] != session["user_id"]: + other_users.append(u) + is_dm = len(other_users) == 1 and not name + joined_at = datetime.now(timezone.utc) + if is_dm: + # DMs: add both members directly (no invitation) + all_ids = [session["user_id"]] + [u["id"] for u in other_users] + conv_id = await adb.create_conversation(all_ids, joined_at=joined_at, name=name, created_by=session["user_id"]) + logger.info("[CONV] %s created DM conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id}) + # Notify the other member + members_info = await adb.get_conversation_members(conv_id) + member_list = [{"user_id": m["id"], "username": m["username"], "email": m["email"]} for m in members_info] + notif_data = { + "conversation_id": conv_id, + "name": name, + "created_by": session["user_id"], + "members": member_list, + } + await _notify_users([u["id"] for u in other_users], "conversation_created", notif_data) + else: + # Groups: only add creator, create invitations for others + conv_id = await adb.create_conversation([session["user_id"]], joined_at=joined_at, name=name, created_by=session["user_id"]) + logger.info("[CONV] %s created group conv=%s", + _who(session), conv_id[:8]) + # Create invitations for other members + creator_user = await adb.get_user_by_id(session["user_id"]) + creator_name = creator_user["username"] if creator_user else "Unknown" + invited_ids = [] + async with _clients_lock: + phantom_snapshot = set(phantom_user_ids) + for u in other_users: + await adb.create_invitation(conv_id, u["id"], session["user_id"]) + if u["id"] not in phantom_snapshot: + invited_ids.append(u["id"]) # only notify non-phantoms + inv_notif = { + "conversation_id": conv_id, + "conversation_name": name, + "invited_by": session["user_id"], + "invited_by_username": creator_name, + } + await _notify_users(invited_ids, "group_invitation", inv_notif) + await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id}) + + +async def handle_find_conversation(msg: dict, session: dict, writer: ProtocolWriter): + email = msg.get("email", "").strip() + if not email: + await send_resp(msg, writer, "find_conversation", "error", {"message": "Invalid request"}) + return + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("find_conversation", addr, email), 30): + await send_resp(msg, writer, "find_conversation", "error", {"message": "Too many attempts. Try later."}) + return + other = await adb.get_user_by_email(email) + if not other: + if not _valid_email(email): + await send_resp(msg, writer, "find_conversation", "error", {"message": "Invalid email format"}) + return + # H5: atomic phantom creation (cap check + DB create + set add) + other, err_msg = await _create_phantom_guarded(email, addr, session["user_id"]) + if other is None: + await send_resp(msg, writer, "find_conversation", "error", {"message": err_msg}) + return + conv_id = await adb.find_direct_conversation(session["user_id"], other["id"]) + await send_resp(msg, writer, "find_conversation", "ok", { + "conversation_id": conv_id, + "user_id": other["id"], + }) + + +async def handle_add_member(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + email = msg.get("email", "").strip() + if not conv_id or not email: + await send_resp(msg, writer, "add_member", "error", {"message": "Invalid request"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "add_member", "error", {"message": "Invalid conversation_id"}) + return + # L8: validate email format before phantom creation + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("add_member", addr, email), 10): + await send_resp(msg, writer, "add_member", "error", {"message": "Too many attempts. Try later."}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "add_member", "error", {"message": "Not a member"}) + return + user = await adb.get_user_by_email(email) + if not user: + # Create phantom for unregistered email (same as create_conversation) + if not _valid_email(email): + await send_resp(msg, writer, "add_member", "error", {"message": "Invalid email format"}) + return + # H5: atomic phantom creation (cap check + DB create + set add) + user, err_msg = await _create_phantom_guarded(email, addr, session["user_id"]) + if user is None: + await send_resp(msg, writer, "add_member", "error", {"message": err_msg}) + return + if await adb.is_conversation_member(conv_id, user["id"]): + await send_resp(msg, writer, "add_member", "error", {"message": "Already a member"}) + return + if await adb.has_pending_invitation(conv_id, user["id"]): + await send_resp(msg, writer, "add_member", "error", {"message": "Invitation already pending"}) + return + # Create invitation (for both real and phantom users) + await adb.create_invitation(conv_id, user["id"], session["user_id"]) + logger.info("[INVITE] %s invited u=%s to conv=%s", _who(session), user["id"][:8], conv_id[:8]) + await send_resp(msg, writer, "add_member", "ok", {"user_id": user["id"]}) + # Push invitation notification only to non-phantom users + async with _clients_lock: + is_phantom = user["id"] in phantom_user_ids + if not is_phantom: + conv = await adb.get_conversation(conv_id) + creator_user = await adb.get_user_by_id(session["user_id"]) + creator_name = creator_user["username"] if creator_user else "Unknown" + inv_notif = { + "conversation_id": conv_id, + "conversation_name": conv.get("name") if conv else None, + "invited_by": session["user_id"], + "invited_by_username": creator_name, + } + await _notify_users([user["id"]], "group_invitation", inv_notif) + + +async def handle_accept_invitation(msg: dict, session: dict, writer: ProtocolWriter): + """Accept a group invitation — add user to conversation members.""" + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "accept_invitation", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "accept_invitation", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.has_pending_invitation(conv_id, session["user_id"]): + await send_resp(msg, writer, "accept_invitation", "error", {"message": "No pending invitation"}) + return + joined_at = datetime.now(timezone.utc) + await adb.add_conversation_member(conv_id, session["user_id"], joined_at=joined_at) + await adb.delete_invitation(conv_id, session["user_id"]) + logger.info("[INVITE] %s accepted invitation to conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "accept_invitation", "ok", {"conversation_id": conv_id}) + # Notify existing members about the new member + user = await adb.get_user_by_id(session["user_id"]) + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + "username": user["username"] if user else "", + "email": user["email"] if user else "", + } + members = await adb.get_conversation_members(conv_id) + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_added", notif_data) + + +async def handle_decline_invitation(msg: dict, session: dict, writer: ProtocolWriter): + """Decline a group invitation.""" + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "decline_invitation", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "decline_invitation", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.has_pending_invitation(conv_id, session["user_id"]): + await send_resp(msg, writer, "decline_invitation", "error", {"message": "No pending invitation"}) + return + await adb.delete_invitation(conv_id, session["user_id"]) + logger.info("[INVITE] %s declined invitation to conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "decline_invitation", "ok", {"message": "OK"}) + + +async def handle_list_invitations(msg: dict, session: dict, writer: ProtocolWriter): + """List pending group invitations for the current user.""" + invitations = await adb.get_pending_invitations(session["user_id"]) + result = [] + for inv in invitations: + entry = { + "conversation_id": inv["conversation_id"], + "conversation_name": inv.get("conversation_name"), + "invited_by": inv["invited_by"], + "invited_by_username": inv.get("invited_by_username", ""), + "created_at": inv["created_at"].isoformat() if hasattr(inv["created_at"], "isoformat") else str(inv["created_at"]), + } + result.append(entry) + await send_resp(msg, writer, "list_invitations", "ok", {"invitations": result}) + + +async def handle_list_conversations(msg: dict, session: dict, writer: ProtocolWriter): + convs = await adb.list_user_conversations(session["user_id"]) + unread = await adb.get_unread_counts(session["user_id"], max_age_days=METADATA_RETENTION_DAYS) + result = [] + for c in convs: + result.append({ + "conversation_id": c["id"], + "created_at": c["created_at"].isoformat() if hasattr(c["created_at"], "isoformat") else str(c["created_at"]), + "members": c["members"], + "name": c.get("name"), + "created_by": c.get("created_by"), + "avatar_file": c.get("avatar_file"), + "unread_count": unread.get(c["id"], 0), + }) + logger.info("[LIST] %s listed %d conversations", _who(session), len(result)) + await send_resp(msg, writer, "list_conversations", "ok", {"conversations": result}) + + +async def handle_send_message(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "send_message", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "send_message", "error", {"message": "Invalid conversation_id"}) + return + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("send_message", addr, session.get("email")), 20): + await send_resp(msg, writer, "send_message", "error", {"message": "Too many attempts. Try later."}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "send_message", "error", {"message": "Not a member"}) + return + + # New protocol: ratchet_header + recipients[] with per-user ciphertext + ratchet_header_raw = msg.get("ratchet_header") + recipients_raw = msg.get("recipients") + if not ratchet_header_raw or not recipients_raw: + await send_resp(msg, writer, "send_message", "error", {"message": "Missing ratchet_header or recipients"}) + return + + # C2 fix: validate header is a dict (reject raw str/bytes) + ratchet_header = _validate_header(ratchet_header_raw, "ratchet_header") + if ratchet_header is None: + await send_resp(msg, writer, "send_message", "error", {"message": "Invalid ratchet_header format"}) + return + + x3dh_header_raw = msg.get("x3dh_header") + x3dh_header = None + if x3dh_header_raw: + x3dh_header = _validate_header(x3dh_header_raw, "x3dh_header") + if x3dh_header is None: + await send_resp(msg, writer, "send_message", "error", {"message": "Invalid x3dh_header format"}) + return + + sender_chain_id_b64 = msg.get("sender_chain_id") + sender_chain_id = decode_binary(sender_chain_id_b64) if sender_chain_id_b64 else None + sender_chain_n = msg.get("sender_chain_n") + + # Validate recipients are actual members + conv_members = await adb.get_conversation_members(conv_id) + member_ids = {m["id"] for m in conv_members} + async with _clients_lock: + phantom_snapshot = set(phantom_user_ids) + db_recipients = [] + for r in recipients_raw: + uid = r.get("user_id", "") + if uid not in member_ids: + continue + if uid in phantom_snapshot: + continue + ct_b64 = r.get("encrypted_content", "") + nonce_b64 = r.get("nonce", "") + if not ct_b64 or not nonce_b64: + continue + entry = { + "user_id": uid, + "encrypted_content": decode_binary(ct_b64), + "nonce": decode_binary(nonce_b64), + } + # Per-recipient device_id (multi-device support) + r_device_id = r.get("device_id") + if r_device_id: + entry["device_id"] = r_device_id + # Per-recipient ratchet header and x3dh header (C2 fix: validate dict) + r_rh = r.get("ratchet_header") + if r_rh: + r_rh_bytes = _validate_header(r_rh, "recipient_ratchet_header") + if r_rh_bytes: + entry["ratchet_header"] = r_rh_bytes + r_x3dh = r.get("x3dh_header") + if r_x3dh: + r_x3dh_bytes = _validate_header(r_x3dh, "recipient_x3dh_header") + if r_x3dh_bytes: + entry["x3dh_header"] = r_x3dh_bytes + db_recipients.append(entry) + if not db_recipients: + await send_resp(msg, writer, "send_message", "error", {"message": "No valid recipients"}) + return + + image_file_id = msg.get("image_file_id") + + # Metadata privacy: for group messages (sender_chain_id present), store chain + # metadata in per-recipient ratchet_header instead of the messages table. + # This avoids persisting sender correlation data at the message level. + # Skip sender's own self-copy entry — it uses a different decrypt path + # (self-encryption key) and must keep its own ratchet_header ({"self":true}). + db_sender_chain_id = None + db_sender_chain_n = None + if sender_chain_id: + chain_meta = json.dumps({ + "chain_id": encode_binary(sender_chain_id), + "chain_n": sender_chain_n, + }).encode() + sender_uid = session["user_id"] + for r in db_recipients: + # Skip self-copy (sender's own entry) — uses self-encryption, not sender key + if r["user_id"] == sender_uid: + continue + if not r.get("ratchet_header"): + r["ratchet_header"] = chain_meta + + msg_id, created_at = await adb.store_message( + conv_id, session["user_id"], ratchet_header, db_recipients, + x3dh_header=x3dh_header, + sender_chain_id=db_sender_chain_id, + sender_chain_n=db_sender_chain_n, + image_file_id=image_file_id, + sender_device_id=session.get("device_id"), + ) + + # Link image upload to message if present + if image_file_id: + upload = await adb.get_image_upload(image_file_id) + if upload and upload["completed"] and upload["uploader_id"] == session["user_id"]: + await adb.set_message_image_file_id(msg_id, image_file_id) + + logger.info("[MSG] %s msg=%s conv=%s", _who(session), msg_id[:8], conv_id[:8]) + await send_resp(msg, writer, "send_message", "ok", {"message_id": msg_id, "created_at": created_at}) + + # Notify connected recipients — group all per-device entries by user_id + # Use validated db_recipients (not raw input) to prevent unvalidated headers in push + msg_ratchet_header_dict = json.loads(ratchet_header.decode()) + msg_x3dh_header_dict = json.loads(x3dh_header.decode()) if x3dh_header else None + + from collections import defaultdict + user_entries = defaultdict(list) + for r in db_recipients: + uid = r["user_id"] + # Per-recipient headers are stored as bytes; decode back to dict for notification JSON + r_rh = r.get("ratchet_header") + r_rh_dict = json.loads(r_rh.decode()) if r_rh else None + r_x3dh = r.get("x3dh_header") + r_x3dh_dict = json.loads(r_x3dh.decode()) if r_x3dh else None + user_entries[uid].append({ + "device_id": r.get("device_id", db.SELF_DEVICE_ID), + "encrypted_content": encode_binary(r["encrypted_content"]), + "nonce": encode_binary(r["nonce"]), + "ratchet_header": r_rh_dict or msg_ratchet_header_dict, + "x3dh_header": r_x3dh_dict or msg_x3dh_header_dict, + }) + + notifications = [] + for uid, entries in user_entries.items(): + notif_data = { + "message_id": msg_id, + "conversation_id": conv_id, + "sender_id": session["user_id"], + "sender_device_id": session.get("device_id"), + "device_entries": entries, + } + if sender_chain_id_b64: + notif_data["sender_chain_id"] = sender_chain_id_b64 + if sender_chain_n is not None: + notif_data["sender_chain_n"] = sender_chain_n + # Also include flat fields for backward compat with old clients + # (first entry's data as fallback) + if entries: + first = entries[0] + notif_data["ratchet_header"] = first.get("ratchet_header") or msg_ratchet_header_dict + notif_data["encrypted_content"] = first.get("encrypted_content", "") + notif_data["nonce"] = first.get("nonce", "") + if first.get("x3dh_header"): + notif_data["x3dh_header"] = first["x3dh_header"] + notifications.append((uid, "new_message", notif_data)) + # Log notification targets for debugging delivery issues + async with _clients_lock: + targets_info = [] + for uid, _, _ in notifications: + n_writers = len(connected_clients.get(uid, [])) + targets_info.append(f"{uid[:8]}({n_writers}w)") + logger.info("[PUSH] msg=%s conv=%s targets=[%s] exclude_sender=%s", + msg_id[:8], conv_id[:8], ", ".join(targets_info), "yes") + delivered_users = await _notify_users_individual(notifications, exclude_writer=writer) + + # Delivery receipt: if at least one recipient device got the push, acknowledge + # delivery to sender immediately. + delivered_users.discard(session["user_id"]) + if delivered_users: + for delivered_uid in delivered_users: + await _notify_users([session["user_id"]], "message_delivered", { + "conversation_id": conv_id, + "user_id": delivered_uid, + "message_ids": [msg_id], + }) + + +async def handle_get_messages(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"get_messages|{session['user_id']}", 30): + await send_resp(msg, writer, "get_messages", "error", {"message": "Too many requests. Try later."}) + return + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "get_messages", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "get_messages", "error", {"message": "Not a member"}) + return + + limit = min(max(int(msg.get("limit", 50)), 1), 200) + offset = max(int(msg.get("offset", 0)), 0) + device_id = session.get("device_id") + after_ts = msg.get("after_ts") # ISO timestamp string or None + messages = await adb.get_messages(conv_id, session["user_id"], limit, offset, + device_id=device_id, after_ts=after_ts) + + # Deduplicate: when both device-specific and SELF_DEVICE_ID rows exist for the + # same message, prefer device-specific (non-sentinel). Keep first seen per message_id. + seen_ids = {} + deduped = [] + for m in messages: + mid = m["id"] + mr_dev = m.get("mr_device_id", "") + if mid not in seen_ids: + seen_ids[mid] = len(deduped) + deduped.append(m) + elif mr_dev != db.SELF_DEVICE_ID: + # Replace SELF_DEVICE_ID entry with device-specific one + deduped[seen_ids[mid]] = m + messages = deduped + + result = [] + message_ids = [m["id"] for m in messages] + read_status = await adb.get_message_read_status(message_ids) if message_ids else {} + delivery_status = await adb.get_message_delivery_status(message_ids) if message_ids else {} + reactions_map = await adb.get_reactions(message_ids) if message_ids else {} + for m in messages: + read_by = read_status.get(m["id"], []) + # Prefer per-recipient headers (mr_*) over message-level headers + rh_raw = m.get("mr_ratchet_header") or m.get("ratchet_header") + x3dh_raw = m.get("mr_x3dh_header") or m.get("x3dh_header") + # C2 fix: defensive JSON parsing — corrupted headers don't break fetch + try: + rh_parsed = json.loads(rh_raw) if rh_raw else {} + except (json.JSONDecodeError, TypeError, UnicodeDecodeError): + logger.warning("[FETCH] Corrupted ratchet_header in message %s, skipping", m["id"]) + rh_parsed = {} + try: + x3dh_parsed = json.loads(x3dh_raw) if x3dh_raw else None + except (json.JSONDecodeError, TypeError, UnicodeDecodeError): + logger.warning("[FETCH] Corrupted x3dh_header in message %s, skipping", m["id"]) + x3dh_parsed = None + entry = { + "message_id": m["id"], + "sender_id": m.get("sender_id") or "", + "ratchet_header": rh_parsed, + "encrypted_content": encode_binary(m["encrypted_content"]) if m.get("encrypted_content") else "", + "nonce": encode_binary(m["nonce"]) if m.get("nonce") else "", + "created_at": m["created_at"].isoformat() if hasattr(m["created_at"], "isoformat") else str(m["created_at"]), + "read_by": read_by, + "delivered_to": delivery_status.get(m["id"], []), + } + if x3dh_parsed: + entry["x3dh_header"] = x3dh_parsed + # Sender chain metadata: check message-level first (backward compat), + # then per-recipient ratchet_header (new metadata-private format). + # Only extract from per-recipient header if message-level ratchet_header + # is the group dummy (dh_pub all-zeros) — prevents DM header injection. + if m.get("sender_chain_id"): + entry["sender_chain_id"] = encode_binary(m["sender_chain_id"]) + elif isinstance(rh_parsed, dict) and rh_parsed.get("chain_id"): + # Verify this is a group message by checking the message-level header + msg_rh_raw = m.get("ratchet_header") + is_group = False + if msg_rh_raw: + try: + msg_rh = json.loads(msg_rh_raw) if isinstance(msg_rh_raw, (bytes, str)) else msg_rh_raw + is_group = isinstance(msg_rh, dict) and msg_rh.get("dh_pub") == "00" * 32 + except (json.JSONDecodeError, TypeError, UnicodeDecodeError): + pass + if is_group: + entry["sender_chain_id"] = rh_parsed["chain_id"] + if m.get("sender_chain_n") is not None: + entry["sender_chain_n"] = m["sender_chain_n"] + elif isinstance(rh_parsed, dict) and rh_parsed.get("chain_n") is not None: + # Same group-only guard + if "sender_chain_id" in entry: + entry["sender_chain_n"] = rh_parsed["chain_n"] + if m.get("sender_device_id"): + entry["sender_device_id"] = m["sender_device_id"] + if m.get("deleted_at"): + entry["deleted_at"] = m["deleted_at"].isoformat() if hasattr(m["deleted_at"], "isoformat") else str(m["deleted_at"]) + # Pin metadata + if m.get("pinned_at"): + entry["pinned_at"] = m["pinned_at"].isoformat() if hasattr(m["pinned_at"], "isoformat") else str(m["pinned_at"]) + entry["pinned_by"] = m.get("pinned_by") or "" + # Reactions + msg_reactions = reactions_map.get(m["id"]) + if msg_reactions: + entry["reactions"] = msg_reactions + result.append(entry) + total_count = await adb.count_messages(conv_id, session["user_id"]) + logger.info("[FETCH] %s fetched %d/%d msgs from conv=%s (limit=%d, offset=%d%s)", + _who(session), len(result), total_count, conv_id[:8], limit, offset, + f", after={after_ts}" if after_ts else "") + await send_resp(msg, writer, "get_messages", "ok", + {"messages": result, "total_count": total_count}) + + +async def _handle_typing_event(msg_type: str, msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, msg_type, "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, msg_type, "error", {"message": "Invalid conversation_id"}) + return + if await _is_rate_limited(f"{msg_type}|{session['user_id']}|{conv_id}", 120): + await send_resp(msg, writer, msg_type, "error", {"message": "Too many typing events. Slow down."}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, msg_type, "error", {"message": "Not a member"}) + return + + members = await adb.get_conversation_members(conv_id) + targets = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(targets, msg_type, { + "conversation_id": conv_id, + "user_id": session["user_id"], + "username": session.get("username", ""), + }, exclude_writer=writer) + await send_resp(msg, writer, msg_type, "ok", {"message": "OK"}) + + +async def handle_typing_start(msg: dict, session: dict, writer: ProtocolWriter): + await _handle_typing_event("typing_start", msg, session, writer) + + +async def handle_typing_stop(msg: dict, session: dict, writer: ProtocolWriter): + await _handle_typing_event("typing_stop", msg, session, writer) + + +async def handle_remove_member(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"remove_member|{session['user_id']}", 10): + await send_resp(msg, writer, "remove_member", "error", {"message": "Too many requests. Try later."}) + return + conv_id = msg.get("conversation_id", "") + user_id = msg.get("user_id", "") + if not conv_id or not user_id: + await send_resp(msg, writer, "remove_member", "error", {"message": "Missing conversation_id or user_id"}) + return + if not _valid_uuid(conv_id) or not _valid_uuid(user_id): + await send_resp(msg, writer, "remove_member", "error", {"message": "Invalid conversation_id or user_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "remove_member", "error", {"message": "Not a member"}) + return + convs = await adb.list_user_conversations(session["user_id"]) + conv_data = None + for c in convs: + if c["id"] == conv_id: + conv_data = c + break + if not conv_data or conv_data.get("created_by") != session["user_id"]: + await send_resp(msg, writer, "remove_member", "error", {"message": "Only the group creator can remove members"}) + return + if user_id == session["user_id"]: + await send_resp(msg, writer, "remove_member", "error", {"message": "Cannot remove yourself"}) + return + # Get remaining members before removing (to notify them) + members_before = await adb.get_conversation_members(conv_id) + # M6: atomic removal — return value confirms row existed + removed = await adb.remove_conversation_member_atomic(conv_id, user_id) + if not removed: + await send_resp(msg, writer, "remove_member", "error", {"message": "Member already removed"}) + return + logger.info("[MEMBER] %s removed user=%s from conv=%s", _who(session), user_id[:8], conv_id[:8]) + await send_resp(msg, writer, "remove_member", "ok", {"message": "OK"}) + + # Notify removed member and remaining members + notif_data = { + "conversation_id": conv_id, + "user_id": user_id, + } + member_ids = [m["id"] for m in members_before if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_removed", notif_data) + + +async def handle_leave_group(msg: dict, session: dict, writer: ProtocolWriter): + """Leave a group conversation voluntarily.""" + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "leave_group", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "leave_group", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"}) + return + # Don't allow leaving DMs (2 members without a name) + conv = await adb.get_conversation(conv_id) + members = await adb.get_conversation_members(conv_id) + if len(members) <= 2 and not (conv and conv.get("name")): + await send_resp(msg, writer, "leave_group", "error", {"message": "Cannot leave a DM conversation"}) + return + # If creator is leaving, transfer to first remaining member + if conv and conv.get("created_by") == session["user_id"]: + remaining = [m for m in members if m["id"] != session["user_id"]] + if remaining: + await adb.update_conversation_creator(conv_id, remaining[0]["id"]) + # M6: atomic removal + await adb.remove_conversation_member_atomic(conv_id, session["user_id"]) + logger.info("[LEAVE] %s left group conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "leave_group", "ok", {"message": "OK"}) + # Notify remaining members + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_removed", notif_data) + + +async def handle_rename_conversation(msg: dict, session: dict, writer: ProtocolWriter): + """Rename a group conversation (creator only).""" + if await _is_rate_limited(f"rename_conv|{session['user_id']}", 5): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Too many requests. Try later."}) + return + conv_id = msg.get("conversation_id", "") + new_name = msg.get("name", "").strip() + if not conv_id or not new_name: + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Missing conversation_id or name"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Invalid conversation_id"}) + return + if len(new_name) > 100: + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Name too long (max 100)"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Not a member"}) + return + conv = await adb.get_conversation(conv_id) + if not conv or not conv.get("name"): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Cannot rename a DM conversation"}) + return + if conv.get("created_by") != session["user_id"]: + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Only the group creator can rename"}) + return + await adb.update_conversation_name(conv_id, new_name) + logger.info("[RENAME] %s renamed conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "rename_conversation", "ok", {"message": "OK"}) + # Notify all members + members = await adb.get_conversation_members(conv_id) + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "conversation_renamed", { + "conversation_id": conv_id, + "name": new_name, + "renamed_by": session["user_id"], + }) + + +async def handle_delete_conversation(msg: dict, session: dict, writer: ProtocolWriter): + """Delete a conversation for the current user. Removes user from members, + deletes the conversation if no members remain.""" + if await _is_rate_limited(f"delete_conv|{session['user_id']}", 5): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Too many requests. Try later."}) + return + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Not a member"}) + return + conv = await adb.get_conversation(conv_id) + members = await adb.get_conversation_members(conv_id) + is_group = len(members) > 2 or (conv and conv.get("name")) + # Groups can only be deleted by the creator (admin) + if is_group and (not conv or conv.get("created_by") != session["user_id"]): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Only the group creator can delete this conversation"}) + return + if is_group: + # Group: creator deletes for everyone — remove all members, clean up, delete + for member in members: + await adb.remove_conversation_member(conv_id, member["id"]) + else: + # DM: only remove self; other user keeps the conversation + await adb.remove_conversation_member(conv_id, session["user_id"]) + remaining_count = await adb.count_conversation_members(conv_id) + if remaining_count == 0: + # Clean up uploaded files from disk + file_ids = await adb.get_conversation_file_ids(conv_id) + for fid in file_ids: + for ext in (".enc", ".tmp"): + p = _safe_upload_path(fid, ext) + if not p: + continue + _secure_delete(p) + await adb.delete_conversation(conv_id) + logger.info("[DELETE] %s deleted conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "delete_conversation", "ok", {"message": "OK"}) + # Notify other members they were removed + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_removed", notif_data) + + +async def handle_mark_read(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + message_ids = msg.get("message_ids", []) + if not conv_id or not message_ids: + await send_resp(msg, writer, "mark_read", "error", {"message": "Missing conversation_id or message_ids"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid conversation_id"}) + return + if len(message_ids) > 500: + await send_resp(msg, writer, "mark_read", "error", {"message": "Too many message_ids (max 500)"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "mark_read", "error", {"message": "Not a member"}) + return + # M1 fix: filter to only message_ids that belong to this conversation + valid_ids = await adb.filter_message_ids_by_conversation(conv_id, message_ids) + if not valid_ids: + await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"}) + return + await adb.mark_messages_read(conv_id, session["user_id"], valid_ids) + logger.info("[READ] %s marked %d msgs read in conv=%s", _who(session), len(valid_ids), conv_id[:8]) + await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"}) + members = await adb.get_conversation_members(conv_id) + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + "message_ids": valid_ids, + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "messages_read", notif_data) + + +async def handle_mark_conversation_read(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "mark_conversation_read", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "mark_conversation_read", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "mark_conversation_read", "error", {"message": "Not a member"}) + return + count = await adb.mark_conversation_read(conv_id, session["user_id"]) + logger.info("[READ] %s marked conv=%s all-read (%d msgs)", _who(session), conv_id[:8], count) + await send_resp(msg, writer, "mark_conversation_read", "ok", {"marked_count": count}) + if count > 0: + members = await adb.get_conversation_members(conv_id) + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "messages_read", { + "conversation_id": conv_id, + "user_id": session["user_id"], + "message_ids": [], + }) + + +async def handle_confirm_delivery(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + message_ids = msg.get("message_ids", []) + if not conv_id or not message_ids: + await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Missing conversation_id or message_ids"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Invalid conversation_id"}) + return + if len(message_ids) > 500: + await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Too many message_ids (max 500)"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Not a member"}) + return + # M1 fix: filter to only message_ids that belong to this conversation + valid_ids = await adb.filter_message_ids_by_conversation(conv_id, message_ids) + if not valid_ids: + await send_resp(msg, writer, "confirm_delivery", "ok", {"message": "OK"}) + return + await adb.mark_messages_delivered(conv_id, session["user_id"], valid_ids) + logger.info("[DELIVERY] %s confirmed %d msgs delivered in conv=%s", _who(session), len(valid_ids), conv_id[:8]) + await send_resp(msg, writer, "confirm_delivery", "ok", {"message": "OK"}) + + # Notify senders — batch lookup sender_id per message, push to each sender + sender_msgs: dict[str, list[str]] = {} + for mid in valid_ids: + sid = await adb.get_message_sender(mid) + if sid and sid != session["user_id"]: + sender_msgs.setdefault(sid, []).append(mid) + for sender_id, mids in sender_msgs.items(): + await _notify_users([sender_id], "message_delivered", { + "conversation_id": conv_id, + "user_id": session["user_id"], + "message_ids": mids, + }) + + +async def handle_delete_message(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"delete_msg|{session['user_id']}", 20): + await send_resp(msg, writer, "delete_message", "error", {"message": "Too many requests. Try later."}) + return + message_id = msg.get("message_id", "") + if not message_id: + await send_resp(msg, writer, "delete_message", "error", {"message": "Missing message_id"}) + return + if not _valid_uuid(message_id): + await send_resp(msg, writer, "delete_message", "error", {"message": "Invalid message_id"}) + return + conv_id = await adb.get_message_conversation(message_id) + if not conv_id: + await send_resp(msg, writer, "delete_message", "error", {"message": "Message not found"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "delete_message", "error", {"message": "Not a member"}) + return + result = await adb.soft_delete_message(message_id, session["user_id"]) + if result is None: + await send_resp(msg, writer, "delete_message", "error", {"message": "Cannot delete this message"}) + return + image_file_id = result.get("image_file_id") + if image_file_id: + image_path = _safe_upload_path(image_file_id, ".enc") + if image_path: + _secure_delete(image_path) + await adb.delete_image_upload(image_file_id) + logger.info("[MSG] %s deleted message=%s", _who(session), message_id[:8]) + await send_resp(msg, writer, "delete_message", "ok", {"message_id": message_id}) + members = await adb.get_conversation_members(conv_id) + notif_data = {"message_id": message_id, "conversation_id": conv_id} + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "message_deleted", notif_data) + + +async def handle_react_message(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"react|{session['user_id']}", 20): + await send_resp(msg, writer, "react_message", "error", {"message": "Too many requests. Try later."}) + return + message_id = msg.get("message_id", "") + reaction = msg.get("reaction", "") + action = msg.get("action", "add") # "add" or "remove" + + if not message_id or not reaction: + await send_resp(msg, writer, "react_message", "error", {"message": "Missing fields"}) + return + if not _valid_uuid(message_id): + await send_resp(msg, writer, "react_message", "error", {"message": "Invalid message_id"}) + return + if reaction not in db.ALLOWED_REACTIONS: + await send_resp(msg, writer, "react_message", "error", {"message": "Invalid reaction"}) + return + if action not in ("add", "remove"): + await send_resp(msg, writer, "react_message", "error", {"message": "Invalid action"}) + return + + conv_id = await adb.get_message_conversation(message_id) + if not conv_id: + await send_resp(msg, writer, "react_message", "error", {"message": "Message not found"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "react_message", "error", {"message": "Not a member"}) + return + + old_reaction = None + if action == "add": + changed, old_reaction = await adb.add_reaction(message_id, session["user_id"], reaction) + if not changed: + await send_resp(msg, writer, "react_message", "ok", {"message_id": message_id}) + return + else: + await adb.remove_reaction(message_id, session["user_id"]) + + logger.info("[MSG] %s %s reaction '%s' on message=%s", _who(session), action, reaction, message_id[:8]) + resp_data = {"message_id": message_id} + if old_reaction: + resp_data["old_reaction"] = old_reaction + await send_resp(msg, writer, "react_message", "ok", resp_data) + + members = await adb.get_conversation_members(conv_id) + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + + # If replacing an old reaction, notify removal first + if old_reaction: + remove_data = { + "message_id": message_id, + "conversation_id": conv_id, + "user_id": session["user_id"], + "username": session.get("username", ""), + "reaction": old_reaction, + "action": "remove", + } + await _notify_users(member_ids, "message_reacted", remove_data) + + notif_data = { + "message_id": message_id, + "conversation_id": conv_id, + "user_id": session["user_id"], + "username": session.get("username", ""), + "reaction": reaction, + "action": action, + } + await _notify_users(member_ids, "message_reacted", notif_data) + + +async def handle_pin_message(msg: dict, session: dict, writer: ProtocolWriter): + message_id = msg.get("message_id", "") + action = msg.get("action", "pin") # "pin" or "unpin" + conversation_id = msg.get("conversation_id", "") + + if not message_id or not conversation_id: + await send_resp(msg, writer, "pin_message", "error", {"message": "Missing fields"}) + return + if not _valid_uuid(message_id) or not _valid_uuid(conversation_id): + await send_resp(msg, writer, "pin_message", "error", {"message": "Invalid ID"}) + return + if action not in ("pin", "unpin"): + await send_resp(msg, writer, "pin_message", "error", {"message": "Invalid action"}) + return + if not await adb.is_conversation_member(conversation_id, session["user_id"]): + await send_resp(msg, writer, "pin_message", "error", {"message": "Not a member"}) + return + + if action == "pin": + ok = await adb.pin_message(message_id, session["user_id"], conversation_id) + else: + ok = await adb.unpin_message(message_id, conversation_id) + + if not ok: + await send_resp(msg, writer, "pin_message", "error", + {"message": "Already pinned" if action == "pin" else "Not pinned"}) + return + + logger.info("[MSG] %s %s message=%s in conv=%s", _who(session), action, message_id[:8], conversation_id[:8]) + await send_resp(msg, writer, "pin_message", "ok", {"message_id": message_id, "action": action}) + + members = await adb.get_conversation_members(conversation_id) + notif_type = "message_pinned" if action == "pin" else "message_unpinned" + notif_data = { + "message_id": message_id, + "conversation_id": conversation_id, + "user_id": session["user_id"], + "username": session.get("username", ""), + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, notif_type, notif_data) + + +async def handle_get_pinned_messages(msg: dict, session: dict, writer: ProtocolWriter): + conversation_id = msg.get("conversation_id", "") + if not conversation_id: + await send_resp(msg, writer, "get_pinned_messages", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conversation_id): + await send_resp(msg, writer, "get_pinned_messages", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conversation_id, session["user_id"]): + await send_resp(msg, writer, "get_pinned_messages", "error", {"message": "Not a member"}) + return + + pinned = await adb.get_pinned_messages(conversation_id, session["user_id"]) + await send_resp(msg, writer, "get_pinned_messages", "ok", {"messages": pinned}) + + +async def handle_upload_image_start(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + file_size = msg.get("file_size", 0) + file_id = msg.get("file_id", "") + file_type = msg.get("file_type", "image") # "image" or "file" + if not conv_id or not file_id: + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Missing fields"}) + return + if not _valid_uuid(file_id): + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"}) + return + # M5: rate limit + caps on in-flight uploads + addr = _get_peer_addr(writer) + if await _is_rate_limited(f"upload_start|{session['user_id']}", 10): + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Too many uploads. Try later."}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Not a member"}) + return + max_bytes = MAX_FILE_BYTES if file_type == "file" else MAX_IMAGE_BYTES + if max_bytes > 0 and file_size > max_bytes: + await send_resp(msg, writer, "upload_image_start", "error", + {"message": f"File too large (max {max_bytes} bytes)"}) + return + UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + os.chmod(UPLOAD_DIR, 0o700) + temp_path = _safe_upload_path(file_id, ".tmp") + if not temp_path: + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"}) + return + # M5: atomic cap check + insert under single lock acquisition + cap_error = "" + async with _uploads_lock: + total = len(pending_uploads) + user_count = sum(1 for u in pending_uploads.values() if u.get("uploader_id") == session["user_id"]) + if total >= MAX_UPLOADS_GLOBAL: + cap_error = "Server upload limit reached. Try later." + elif user_count >= MAX_UPLOADS_PER_USER: + cap_error = "Too many active uploads. Finish or cancel existing ones." + else: + temp_path.write_bytes(b"") + os.chmod(temp_path, 0o600) + pending_uploads[file_id] = { + "temp_path": str(temp_path), + "received_bytes": 0, + "file_size": file_size, + "max_bytes": max_bytes, + "conv_id": conv_id, + "uploader_id": session["user_id"], + } + if cap_error: + await send_resp(msg, writer, "upload_image_start", "error", {"message": cap_error}) + return + try: + await adb.create_image_upload(file_id, conv_id, session["user_id"], file_size) + except Exception: + # Rollback: remove from pending_uploads + delete temp file + async with _uploads_lock: + pending_uploads.pop(file_id, None) + _secure_delete(temp_path) + logger.exception("[UPLOAD] DB create failed for file=%s", file_id[:8]) + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Upload failed"}) + return + logger.info("[UPLOAD] %s started upload file=%s (%s, %d bytes)", + _who(session), file_id[:8], file_type, file_size) + await send_resp(msg, writer, "upload_image_start", "ok", {"file_id": file_id}) + + +async def handle_upload_image_chunk(msg: dict, session: dict, writer: ProtocolWriter): + file_id = msg.get("file_id", "") + chunk_data = msg.get("data", "") + if not file_id or not chunk_data: + await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Missing fields"}) + return + async with _uploads_lock: + upload = pending_uploads.get(file_id) + if not upload or upload["uploader_id"] != session["user_id"]: + upload = None + else: + temp_path_str = upload["temp_path"] + upload_max = upload.get("max_bytes", 0) + if not upload: + await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"}) + return + raw = decode_binary(chunk_data) + temp_path = Path(temp_path_str) + await asyncio.to_thread(_append_file, temp_path, raw) + over_limit = False + async with _uploads_lock: + upload = pending_uploads.get(file_id) + if upload: + upload["received_bytes"] += len(raw) + if upload_max > 0 and upload["received_bytes"] > upload_max: + pending_uploads.pop(file_id, None) + over_limit = True + received = upload["received_bytes"] + if over_limit: + _secure_delete(temp_path) + await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Upload exceeds size limit"}) + return + await send_resp(msg, writer, "upload_image_chunk", "ok", {"received": received}) + + +async def handle_upload_image_end(msg: dict, session: dict, writer: ProtocolWriter): + file_id = msg.get("file_id", "") + if not file_id: + await send_resp(msg, writer, "upload_image_end", "error", {"message": "Missing file_id"}) + return + async with _uploads_lock: + upload = pending_uploads.pop(file_id, None) + if not upload or upload["uploader_id"] != session["user_id"]: + await send_resp(msg, writer, "upload_image_end", "error", {"message": "No active upload"}) + return + temp_path = Path(upload["temp_path"]) + if upload["received_bytes"] != upload["file_size"]: + _secure_delete(temp_path) + await send_resp(msg, writer, "upload_image_end", "error", + {"message": f"Incomplete upload: received {upload['received_bytes']} of {upload['file_size']} bytes"}) + return + final_path = _safe_upload_path(file_id, ".enc") + if not final_path: + _secure_delete(temp_path) + await send_resp(msg, writer, "upload_image_end", "error", {"message": "Invalid file_id"}) + return + def _move_file(): + try: + temp_path.rename(final_path) + except Exception: + import shutil + shutil.move(str(temp_path), str(final_path)) + os.chmod(final_path, 0o600) + await asyncio.to_thread(_move_file) + await adb.complete_image_upload(file_id) + logger.info("[UPLOAD] %s completed upload file=%s (%d bytes)", + _who(session), file_id[:8], upload["received_bytes"]) + await send_resp(msg, writer, "upload_image_end", "ok", {"file_id": file_id}) + + +async def _validate_download(msg: dict, session: dict, writer: ProtocolWriter, resp_type: str): + """Validate file_id, check DB, return (file_path, file_size) or None on error.""" + file_id = msg.get("file_id", "") + if not file_id: + await send_resp(msg, writer, resp_type, "error", {"message": "Missing file_id"}) + return None + if not _valid_uuid(file_id): + await send_resp(msg, writer, resp_type, "error", {"message": "Invalid file_id"}) + return None + upload = await adb.get_image_upload(file_id) + if not upload or not upload["completed"]: + await send_resp(msg, writer, resp_type, "error", {"message": "File not found"}) + return None + if not await adb.is_conversation_member(upload["conversation_id"], session["user_id"]): + await send_resp(msg, writer, resp_type, "error", {"message": "Not a member"}) + return None + file_path = _safe_upload_path(file_id, ".enc") + if not file_path or not file_path.exists(): + await send_resp(msg, writer, resp_type, "error", {"message": "File not found"}) + return None + return file_path, file_path.stat().st_size + + +async def handle_download_image(msg: dict, session: dict, writer: ProtocolWriter): + file_id = msg.get("file_id", "") + offset = msg.get("offset", 0) + result = await _validate_download(msg, session, writer, "download_image") + if not result: + return + file_path, file_size = result + chunk = await asyncio.to_thread(_read_file_chunk, file_path, offset, IMAGE_CHUNK_SIZE) + done = (offset + len(chunk)) >= file_size + if offset == 0: + logger.info("[DOWNLOAD] %s downloading file=%s (%d bytes)", _who(session), file_id[:8], file_size) + await send_resp(msg, writer, "download_image", "ok", { + "file_id": file_id, + "data": encode_binary(chunk), + "offset": offset, + "done": done, + "total_size": file_size, + }) + + +async def handle_download_stream(msg: dict, session: dict, writer: ProtocolWriter): + """Stream entire file in chunks after a single auth check. No per-chunk round-trip. + + Server sends multiple responses with the same request_id. Client collects + them until it sees done=True. Each response has a unique ``seq`` number + so the client can reassemble chunks in order even if delivery is reordered. + """ + file_id = msg.get("file_id", "") + result = await _validate_download(msg, session, writer, "download_stream") + if not result: + return + file_path, file_size = result + logger.info("[DOWNLOAD] %s streaming file=%s (%d bytes)", _who(session), file_id[:8], file_size) + offset = 0 + seq = 0 + while offset < file_size: + chunk = await asyncio.to_thread(_read_file_chunk, file_path, offset, IMAGE_CHUNK_SIZE) + if not chunk: + break + done = (offset + len(chunk)) >= file_size + # Build response manually so we can reuse the original request_id + resp = { + "type": "download_stream", + "status": "ok", + "data": { + "file_id": file_id, + "data": encode_binary(chunk), + "offset": offset, + "seq": seq, + "done": done, + "total_size": file_size, + }, + } + req_id = msg.get("request_id") + if req_id: + resp["request_id"] = req_id + data = json.dumps(resp, ensure_ascii=False).encode("utf-8") + b"\n" + try: + writer._writer.write(data) + await writer._writer.drain() + except Exception: + break + offset += len(chunk) + seq += 1 + + +MAX_AVATAR_BYTES = 2 * 1024 * 1024 # 2 MB + + +async def handle_get_profile(msg: dict, session: dict, writer: ProtocolWriter): + """Get user profile (respects visibility for other users).""" + target_user_id = msg.get("user_id", "").strip() + if not target_user_id: + target_user_id = session["user_id"] + elif not _valid_uuid(target_user_id): + await send_resp(msg, writer, "get_profile", "error", {"message": "Invalid user_id"}) + return + profile = await adb.get_user_profile(target_user_id, viewer_id=session["user_id"]) + if not profile: + await send_resp(msg, writer, "get_profile", "error", {"message": "User not found"}) + return + # Serialize datetime fields + for key in ("created_at", "updated_at"): + if profile.get(key) and hasattr(profile[key], "isoformat"): + profile[key] = profile[key].isoformat() + await send_resp(msg, writer, "get_profile", "ok", profile) + + +async def handle_update_profile(msg: dict, session: dict, writer: ProtocolWriter): + """Update own profile fields.""" + fields = {} + for key in ("phone", "phone_visible", "email_visible", "location", "location_visible"): + if key in msg: + fields[key] = msg[key] + if not fields: + await send_resp(msg, writer, "update_profile", "error", {"message": "No fields to update"}) + return + await adb.update_user_profile(session["user_id"], **fields) + await send_resp(msg, writer, "update_profile", "ok", {"message": "OK"}) + + +async def handle_update_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Upload avatar (base64 in single message, max 2MB).""" + if await _is_rate_limited(f"update_avatar|{session['user_id']}", 5): + await send_resp(msg, writer, "update_avatar", "error", {"message": "Too many requests. Try later."}) + return + avatar_b64 = msg.get("data", "") + if not avatar_b64: + await send_resp(msg, writer, "update_avatar", "error", {"message": "Missing data"}) + return + avatar_data = decode_binary(avatar_b64) + if len(avatar_data) > MAX_AVATAR_BYTES: + await send_resp(msg, writer, "update_avatar", "error", + {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"}) + return + # Detect format from magic bytes + ext = "jpg" + if avatar_data[:8] == b'\x89PNG\r\n\x1a\n': + ext = "png" + avatar_dir = UPLOAD_DIR / "avatars" + avatar_dir.mkdir(parents=True, exist_ok=True) + os.chmod(avatar_dir, 0o700) + filename = f"{session['user_id']}.{ext}" + avatar_path = _safe_avatar_path(filename) + if not avatar_path: + await send_resp(msg, writer, "update_avatar", "error", {"message": "Invalid path"}) + return + await asyncio.to_thread(avatar_path.write_bytes, avatar_data) + os.chmod(avatar_path, 0o600) + await adb.update_user_profile(session["user_id"], avatar_file=filename) + logger.info("[AVATAR] %s updated their avatar", _who(session)) + await send_resp(msg, writer, "update_avatar", "ok", {"avatar_file": filename}) + # Notify contacts about avatar change + contacts = await adb.get_user_contacts(session["user_id"]) + if contacts: + await _notify_users(contacts, "avatar_changed", { + "user_id": session["user_id"], + }, exclude_writer=writer) + + +async def handle_get_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Download avatar for a user.""" + target_user_id = msg.get("user_id", "").strip() + if not target_user_id: + await send_resp(msg, writer, "get_avatar", "error", {"message": "Missing user_id"}) + return + if not _valid_uuid(target_user_id): + await send_resp(msg, writer, "get_avatar", "error", {"message": "Invalid user_id"}) + return + profile = await adb.get_user_profile(target_user_id) + if not profile or not profile.get("avatar_file"): + logger.debug("[AVATAR] get_avatar for %s — no avatar_file in profile", target_user_id[:8]) + await send_resp(msg, writer, "get_avatar", "error", {"message": "No avatar"}) + return + avatar_path = _safe_avatar_path(profile["avatar_file"]) + if not avatar_path or not avatar_path.exists(): + logger.warning("[AVATAR] get_avatar for %s — file missing: %s", target_user_id[:8], profile["avatar_file"]) + await send_resp(msg, writer, "get_avatar", "error", {"message": "Avatar file not found"}) + return + avatar_data = await asyncio.to_thread(avatar_path.read_bytes) + await send_resp(msg, writer, "get_avatar", "ok", { + "user_id": target_user_id, + "data": encode_binary(avatar_data), + "filename": profile["avatar_file"], + }) + + +async def handle_update_group_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Upload avatar for a group conversation (base64, max 2MB). Only members can set it.""" + if await _is_rate_limited(f"update_avatar|{session['user_id']}", 5): + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Too many requests. Try later."}) + return + conv_id = msg.get("conversation_id", "").strip() + avatar_b64 = msg.get("data", "") + if not conv_id or not avatar_b64: + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Missing fields"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Not a member"}) + return + avatar_data = decode_binary(avatar_b64) + if len(avatar_data) > MAX_AVATAR_BYTES: + await send_resp(msg, writer, "update_group_avatar", "error", + {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"}) + return + ext = "jpg" + if avatar_data[:8] == b'\x89PNG\r\n\x1a\n': + ext = "png" + avatar_dir = UPLOAD_DIR / "avatars" + avatar_dir.mkdir(parents=True, exist_ok=True) + os.chmod(avatar_dir, 0o700) + filename = f"group_{conv_id}.{ext}" + avatar_path = _safe_avatar_path(filename) + if not avatar_path: + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid path"}) + return + await asyncio.to_thread(avatar_path.write_bytes, avatar_data) + os.chmod(avatar_path, 0o600) + await adb.update_conversation_avatar(conv_id, filename) + logger.info("[AVATAR] %s updated group avatar for conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "update_group_avatar", "ok", {"avatar_file": filename}) + + +async def handle_get_group_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Download avatar for a group conversation.""" + conv_id = msg.get("conversation_id", "").strip() + if not conv_id: + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Not a member"}) + return + conv = await adb.get_conversation(conv_id) + if not conv or not conv.get("avatar_file"): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "No avatar"}) + return + avatar_path = _safe_avatar_path(conv["avatar_file"]) + if not avatar_path or not avatar_path.exists(): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Avatar file not found"}) + return + avatar_data = await asyncio.to_thread(avatar_path.read_bytes) + await send_resp(msg, writer, "get_group_avatar", "ok", { + "conversation_id": conv_id, + "data": encode_binary(avatar_data), + "filename": conv["avatar_file"], + }) + + +async def handle_list_devices(msg: dict, session: dict, writer: ProtocolWriter): + """List all devices for the current user.""" + devices = await adb.get_user_devices(session["user_id"]) + result = [] + for d in devices: + entry = { + "device_id": d["id"], + "device_name": d.get("device_name"), + "created_at": d["created_at"].isoformat() if hasattr(d["created_at"], "isoformat") else str(d["created_at"]), + "last_seen_at": d["last_seen_at"].isoformat() if d.get("last_seen_at") and hasattr(d["last_seen_at"], "isoformat") else (str(d["last_seen_at"]) if d.get("last_seen_at") else None), + "is_current": d["id"] == session.get("device_id"), + } + result.append(entry) + await send_resp(msg, writer, "list_devices", "ok", {"devices": result}) + + +async def handle_remove_device(msg: dict, session: dict, writer: ProtocolWriter): + """Remove a device (cannot remove current device).""" + device_id = msg.get("device_id", "").strip() + if not device_id: + await send_resp(msg, writer, "remove_device", "error", {"message": "Missing device_id"}) + return + if not _valid_uuid(device_id): + await send_resp(msg, writer, "remove_device", "error", {"message": "Invalid device_id"}) + return + if device_id == session.get("device_id"): + await send_resp(msg, writer, "remove_device", "error", {"message": "Cannot remove current device"}) + return + dev = await adb.get_device(device_id) + if not dev or dev["user_id"] != session["user_id"]: + await send_resp(msg, writer, "remove_device", "error", {"message": "Device not found"}) + return + await adb.delete_device(device_id) + logger.info("[DEVICE] %s removed device=%s", _who(session), device_id[:8]) + await send_resp(msg, writer, "remove_device", "ok", {"message": "OK"}) + + +async def handle_session_reset(msg: dict, session: dict, writer: ProtocolWriter): + """Notify peer to reset a corrupted Double Ratchet session.""" + peer_user_id = msg.get("peer_user_id", "").strip() + peer_device_id = msg.get("peer_device_id", "").strip() or None + if not peer_user_id or not _valid_uuid(peer_user_id): + await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_user_id"}) + return + if peer_device_id and not _valid_uuid(peer_device_id): + await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_device_id"}) + return + # H3 fix: rate limit (5/min per user, keyed by user_id only — IP-independent) + if await _is_rate_limited(f"session_reset|{session['user_id']}", 5): + await send_resp(msg, writer, "session_reset", "error", {"message": "Rate limit exceeded"}) + return + # H3 fix: verify users share at least one conversation + if not await adb.shares_conversation(session["user_id"], peer_user_id): + await send_resp(msg, writer, "session_reset", "error", {"message": "No shared conversation"}) + return + # Push notification to peer (target specific device if specified) + notif_data = { + "from_user_id": session["user_id"], + "from_device_id": session.get("device_id"), + } + if peer_device_id: + # Send only to the specific device + targets = [] + async with _clients_lock: + for w in connected_clients.get(peer_user_id, []): + if writer_device_map.get(id(w)) == peer_device_id: + targets.append(w) + for w in targets: + try: + await w.send_response("session_reset", "ok", notif_data) + except Exception: + pass + else: + await _notify_users([peer_user_id], "session_reset", notif_data) + logger.info("[SESSION] %s reset session with peer=%s", _who(session), peer_user_id[:8]) + await send_resp(msg, writer, "session_reset", "ok", {}) + + +async def handle_get_deleted_since(msg: dict, session: dict, writer: ProtocolWriter): + """Return message IDs deleted since a given timestamp.""" + conv_id = msg.get("conversation_id", "") + since_ts = msg.get("since_ts", "") + if not conv_id or not since_ts: + await send_resp(msg, writer, "get_deleted_since", "error", {"message": "Missing parameters"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "get_deleted_since", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "get_deleted_since", "error", {"message": "Not a member"}) + return + deleted_ids = await adb.get_deleted_messages_since(conv_id, session["user_id"], since_ts) + await send_resp(msg, writer, "get_deleted_since", "ok", {"deleted_ids": deleted_ids}) + + +async def handle_reencrypt_messages(msg: dict, session: dict, writer: ProtocolWriter): + """Re-encrypt message history with self-encryption key (for device pairing).""" + if await _is_rate_limited(f"reencrypt|{session['user_id']}", 10): + await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "Too many requests. Try later."}) + return + updates_raw = msg.get("updates", []) + if not updates_raw: + await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No updates"}) + return + if len(updates_raw) > 500: + await send_resp(msg, writer, "reencrypt_messages", "error", + {"message": "Too many updates (max 500 per request)"}) + return + updates = [] + for u in updates_raw: + mid = u.get("message_id", "") + ct_b64 = u.get("encrypted_content", "") + nonce_b64 = u.get("nonce", "") + if not mid or not ct_b64 or not nonce_b64: + continue + updates.append({ + "message_id": mid, + "encrypted_content": decode_binary(ct_b64), + "nonce": decode_binary(nonce_b64), + }) + if not updates: + await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No valid updates"}) + return + await adb.batch_reencrypt_messages(session["user_id"], updates) + logger.info("[REENCRYPT] %s re-encrypted %d messages", _who(session), len(updates)) + await send_resp(msg, writer, "reencrypt_messages", "ok", {"count": len(updates)}) + + +async def _cleanup_uploads(): + stale = await adb.get_stale_uploads(UPLOAD_STALE_SECONDS) + for s in stale: + fid = s["file_id"] + for ext in (".tmp", ".enc"): + p = _safe_upload_path(fid, ext) + if not p: + continue + _secure_delete(p) + await adb.delete_image_upload(fid) + async with _uploads_lock: + pending_uploads.pop(fid, None) + if stale: + logger.info("Cleaned up %d stale uploads.", len(stale)) + + +async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + global current_connections + addr = _get_peer_addr(ProtocolWriter(writer)) + async with _conn_lock: + current_connections += 1 + connection_counts[addr] = connection_counts.get(addr, 0) + 1 + over_limit = (current_connections > MAX_CONNECTIONS_GLOBAL or + connection_counts[addr] > MAX_CONNECTIONS_PER_IP) + if over_limit: + try: + writer.close() + except Exception: + pass + async with _conn_lock: + current_connections = max(0, current_connections - 1) + connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1) + return + logger.info("[CONN] Client connected from %s", addr) + + # Enable TCP keepalive on the socket to detect dead connections + sock = writer.get_extra_info("socket") + if sock is not None: + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if hasattr(socket, "TCP_KEEPIDLE"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, TCP_KEEPALIVE_IDLE) + if hasattr(socket, "TCP_KEEPINTVL"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, TCP_KEEPALIVE_INTERVAL) + if hasattr(socket, "TCP_KEEPCNT"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, TCP_KEEPALIVE_COUNT) + except OSError: + pass # Some platforms/TLS wrappers don't support these options + + proto_reader = ProtocolReader(reader) + proto_writer = ProtocolWriter(writer) + session = None + state = {"_req_times": []} + + try: + while True: + try: + msg = await proto_reader.read_message() + except ValueError as e: + try: + await proto_writer.send_response("protocol_error", "error", {"message": str(e)}) + except Exception: + pass + break + if msg is None: + break + + msg_type = msg.get("type", "") + now = asyncio.get_event_loop().time() + # Upload chunks are exempt from per-connection rate limit — + # a single file upload can legitimately send 20+ chunks in rapid + # succession. The upload subsystem has its own guards (per-user + # upload cap, per-user rate limit on upload_image_start, and + # file-size validation) so double-throttling is unnecessary. + _rl_exempt = msg_type == "upload_image_chunk" + if not _rl_exempt: + times = [t for t in state["_req_times"] if now - t <= CONNECTION_RL_WINDOW] + if len(times) >= CONNECTION_RL_MAX: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Too many requests. Slow down."}) + state["_req_times"] = times + continue + times.append(now) + state["_req_times"] = times + + try: + if msg_type == "register": + await handle_register_start(msg, proto_writer) + elif msg_type == "register_confirm": + await handle_register_confirm(msg, proto_writer) + elif msg_type == "login_start": + await handle_login_start(msg, proto_writer, state) + elif msg_type == "login_finish": + result = await handle_login_finish(msg, proto_writer, state) + if result: + session = result + elif msg_type == "pairing_start": + await handle_pairing_start(msg, proto_writer) + elif msg_type == "pairing_poll": + await handle_pairing_poll(msg, proto_writer) + elif session is None: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Not logged in"}) + elif msg_type == "get_user_info": + await handle_get_user_info(msg, session, proto_writer) + elif msg_type == "upload_prekeys": + await handle_upload_prekeys(msg, session, proto_writer) + elif msg_type == "get_key_bundle": + await handle_get_key_bundle(msg, session, proto_writer) + elif msg_type == "get_prekey_count": + await handle_get_prekey_count(msg, session, proto_writer) + elif msg_type == "ensure_prekeys": + await handle_ensure_prekeys(msg, session, proto_writer) + elif msg_type == "create_conversation": + await handle_create_conversation(msg, session, proto_writer) + elif msg_type == "find_conversation": + await handle_find_conversation(msg, session, proto_writer) + elif msg_type == "add_member": + await handle_add_member(msg, session, proto_writer) + elif msg_type == "accept_invitation": + await handle_accept_invitation(msg, session, proto_writer) + elif msg_type == "decline_invitation": + await handle_decline_invitation(msg, session, proto_writer) + elif msg_type == "list_invitations": + await handle_list_invitations(msg, session, proto_writer) + elif msg_type == "list_conversations": + await handle_list_conversations(msg, session, proto_writer) + elif msg_type == "send_message": + await handle_send_message(msg, session, proto_writer) + elif msg_type == "get_messages": + await handle_get_messages(msg, session, proto_writer) + elif msg_type == "rotate_keys": + await handle_rotate_keys(msg, session, proto_writer) + elif msg_type == "change_username": + await handle_change_username(msg, session, proto_writer) + elif msg_type == "remove_member": + await handle_remove_member(msg, session, proto_writer) + elif msg_type == "leave_group": + await handle_leave_group(msg, session, proto_writer) + elif msg_type == "rename_conversation": + await handle_rename_conversation(msg, session, proto_writer) + elif msg_type == "delete_conversation": + await handle_delete_conversation(msg, session, proto_writer) + elif msg_type == "mark_read": + await handle_mark_read(msg, session, proto_writer) + elif msg_type == "mark_conversation_read": + await handle_mark_conversation_read(msg, session, proto_writer) + elif msg_type == "confirm_delivery": + await handle_confirm_delivery(msg, session, proto_writer) + elif msg_type == "typing_start": + await handle_typing_start(msg, session, proto_writer) + elif msg_type == "typing_stop": + await handle_typing_stop(msg, session, proto_writer) + elif msg_type == "pairing_claim": + await handle_pairing_claim(msg, session, proto_writer) + elif msg_type == "pairing_send": + await handle_pairing_send(msg, session, proto_writer) + elif msg_type == "delete_message": + await handle_delete_message(msg, session, proto_writer) + elif msg_type == "upload_image_start": + await handle_upload_image_start(msg, session, proto_writer) + elif msg_type == "upload_image_chunk": + await handle_upload_image_chunk(msg, session, proto_writer) + elif msg_type == "upload_image_end": + await handle_upload_image_end(msg, session, proto_writer) + elif msg_type == "download_image": + await handle_download_image(msg, session, proto_writer) + elif msg_type == "download_stream": + await handle_download_stream(msg, session, proto_writer) + elif msg_type == "get_profile": + await handle_get_profile(msg, session, proto_writer) + elif msg_type == "update_profile": + await handle_update_profile(msg, session, proto_writer) + elif msg_type == "update_avatar": + await handle_update_avatar(msg, session, proto_writer) + elif msg_type == "get_avatar": + await handle_get_avatar(msg, session, proto_writer) + elif msg_type == "update_group_avatar": + await handle_update_group_avatar(msg, session, proto_writer) + elif msg_type == "get_group_avatar": + await handle_get_group_avatar(msg, session, proto_writer) + elif msg_type == "get_deleted_since": + await handle_get_deleted_since(msg, session, proto_writer) + elif msg_type == "reencrypt_messages": + await handle_reencrypt_messages(msg, session, proto_writer) + elif msg_type == "list_devices": + await handle_list_devices(msg, session, proto_writer) + elif msg_type == "remove_device": + await handle_remove_device(msg, session, proto_writer) + elif msg_type == "session_reset": + await handle_session_reset(msg, session, proto_writer) + elif msg_type == "react_message": + await handle_react_message(msg, session, proto_writer) + elif msg_type == "pin_message": + await handle_pin_message(msg, session, proto_writer) + elif msg_type == "get_pinned_messages": + await handle_get_pinned_messages(msg, session, proto_writer) + else: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Unknown type"}) + except Exception as e: + logger.warning("[ERROR] %s handler '%s' failed: %s", _who(session), msg_type, e, exc_info=True) + try: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Internal server error"}) + except Exception: + break # Can't send response — connection is dead + except Exception as e: + logger.warning("Client connection error: %s", e) + finally: + async with _conn_lock: + current_connections = max(0, current_connections - 1) + connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1) + offline_targets = [] + if session: + uid = session["user_id"] + contacts = await adb.get_user_contacts(uid) + async with _clients_lock: + writer_device_map.pop(id(proto_writer), None) + if uid in connected_clients: + remaining = [w for w in connected_clients[uid] if w is not proto_writer] + if remaining: + connected_clients[uid] = remaining + else: + del connected_clients[uid] + # User fully offline — snapshot targets under lock + for contact_id in contacts: + for cw in connected_clients.get(contact_id, []): + offline_targets.append(cw) + # Send offline notifications outside lock + for cw in offline_targets: + try: + await cw.send_response("user_offline", "ok", {"user_id": uid}) + except Exception: + pass + writer.close() + logger.info("[CONN] %s disconnected", _who(session) if session else addr) + + +async def main(): + setup_logging() + host = os.getenv("SERVER_HOST", "127.0.0.1") + port = int(os.getenv("SERVER_PORT", "9999")) + tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes") + tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes") + tls_autogen = os.getenv("TLS_AUTOGEN", "false").lower() in ("1", "true", "yes") + + is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") + ssl_context = None + if tls_required and not tls_enabled: + raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.") + if tls_enabled: + cert_file = os.getenv("TLS_CERT_FILE", "").strip() + key_file = os.getenv("TLS_KEY_FILE", "").strip() + if not cert_file or not key_file: + if tls_autogen: + if not is_dev: + raise RuntimeError("TLS_AUTOGEN is only allowed when ENVIRONMENT=dev") + cert_dir = Path(__file__).resolve().parent / "certs" + cert_dir.mkdir(parents=True, exist_ok=True) + cert_file = str(cert_dir / "server.crt") + key_file = str(cert_dir / "server.key") + if not (os.path.exists(cert_file) and os.path.exists(key_file)): + try: + subprocess.run( + [ + "openssl", "req", "-x509", "-newkey", "rsa:4096", + "-keyout", key_file, "-out", cert_file, + "-days", "365", "-nodes", "-subj", "/CN=localhost", + ], + check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, + ) + os.chmod(key_file, 0o600) + except FileNotFoundError: + raise RuntimeError("OpenSSL not found.") + except subprocess.CalledProcessError: + raise RuntimeError("Failed to auto-generate TLS cert.") + logger.warning("Using auto-generated self-signed certificate — not for production use.") + else: + raise RuntimeError("TLS is enabled but TLS_CERT_FILE or TLS_KEY_FILE is missing.") + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file) + else: + logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.") + + UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + os.chmod(UPLOAD_DIR, 0o700) + + # Thread pool for asyncio.to_thread() — DB calls + file I/O + pool_workers = int(os.getenv("THREAD_POOL_SIZE", "40")) + asyncio.get_event_loop().set_default_executor(ThreadPoolExecutor(max_workers=pool_workers)) + logger.info("Thread pool executor: %d workers", pool_workers) + + # Load phantom user IDs from DB into in-memory cache + phantom_user_ids.update(await adb.get_all_phantom_user_ids()) + if phantom_user_ids: + logger.info("Loaded %d phantom user IDs.", len(phantom_user_ids)) + + server = await asyncio.start_server( + handle_client, host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context, + ) + logger.info("Encrypted chat server v%s listening on %s:%s", VERSION, host, port) + + async def _cleanup_rate_limits(): + async with _conn_lock: + now = asyncio.get_event_loop().time() + window_start = now - RATE_LIMIT_WINDOW + stale_keys = [k for k, times in rate_limits.items() + if not any(t >= window_start for t in times)] + for k in stale_keys: + del rate_limits[k] + stale_conns = [k for k, v in connection_counts.items() if v <= 0] + for k in stale_conns: + del connection_counts[k] + + _cleanup_cycle = 0 + + async def _periodic_cleanup(): + nonlocal _cleanup_cycle + while True: + await asyncio.sleep(120) + _cleanup_cycle += 1 + try: + await _cleanup_uploads() + except Exception as e: + logger.warning("Upload cleanup error: %s", e) + try: + await _cleanup_rate_limits() + except Exception as e: + logger.warning("Rate limit cleanup error: %s", e) + try: + await _cleanup_registrations() + except Exception as e: + logger.warning("Registration cleanup error: %s", e) + # L8: clean up stale phantom users (>30 days, no real conversations) + try: + deleted = await adb.cleanup_stale_phantoms(30) + if deleted: + async with _clients_lock: + phantom_user_ids.clear() + phantom_user_ids.update(await adb.get_all_phantom_user_ids()) + logger.info("Cleaned up %d stale phantom users.", deleted) + except Exception as e: + logger.warning("Phantom cleanup error: %s", e) + # Metadata retention: purge old reads and reactions (every 30 cycles = ~1 hour) + if _cleanup_cycle % 30 == 0: + try: + reads_del = await adb.cleanup_old_reads(METADATA_RETENTION_DAYS) + reactions_del = await adb.cleanup_old_reactions(METADATA_RETENTION_DAYS) + if reads_del or reactions_del: + logger.info("Metadata cleanup: %d reads, %d reactions purged", + reads_del, reactions_del) + except Exception as e: + logger.warning("Metadata cleanup error: %s", e) + + asyncio.create_task(_periodic_cleanup()) + + loop = asyncio.get_running_loop() + stop = loop.create_future() + + def signal_handler(): + if not stop.done(): + stop.set_result(None) + + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, signal_handler) + + async with server: + await stop + logger.info("Shutting down — closing %d client connections...", sum(len(ws) for ws in connected_clients.values())) + # Stop accepting new connections + server.close() + # Force-close all connected client writers + async with _clients_lock: + all_writers = [w for writers in connected_clients.values() for w in writers] + connected_clients.clear() + writer_device_map.clear() + for w in all_writers: + try: + w.close() + except Exception: + pass + # Give handle_client loops a moment to notice closed connections + await asyncio.sleep(0.1) + # Cancel any remaining handle_client tasks that are still blocked + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + logger.info("Server shut down.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/theme.py b/theme.py new file mode 100644 index 0000000..15ec5c1 --- /dev/null +++ b/theme.py @@ -0,0 +1,539 @@ +"""Theme system for Encrypted Chat GUI — light + dark mode with live switching.""" + +from __future__ import annotations + +import json +import logging +import os +from dataclasses import dataclass, fields +from pathlib import Path +from typing import Callable + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ThemeColors: + """All colour tokens for one theme.""" + + # Surface hierarchy + bg_primary: str # Main background (messages area, right panel) + bg_secondary: str # Cards, inputs, elevated surfaces + bg_tertiary: str # Sidebar, deeper surfaces + bg_hover: str # Hover state on list items + bg_selected: str # Selected list item + + # Text + text_primary: str # Main text + text_secondary: str # Secondary / muted text + text_muted: str # Timestamps, counters, hints + + # Accent (brand blue) + accent: str + accent_hover: str + accent_text: str # Text on accent background + + # Message bubbles + bubble_sent_bg: str + bubble_sent_text: str + bubble_recv_bg: str + bubble_recv_text: str + bubble_sent_meta: str # Timestamp/read inside sent bubble + bubble_recv_meta: str # Timestamp inside received bubble + + # Semantic colours + success: str + warning: str + error: str + info: str + + # Chrome / borders + border: str + border_focus: str + scrollbar: str + separator: str + overlay: str # Privacy overlay background (rgba) + + # Links + link_https: str + link_http: str # Insecure link (orange) + + # Mentions & search + mention: str + search_highlight: str + search_current: str + + # Reactions + reaction_bg: str + reaction_bg_own: str + reaction_border: str + reaction_border_own: str + + # Misc + online_dot: str + online_dot_border: str + pin_color: str + sender_name_other: str # Non-self sender name colour in groups + receipt_read: str # Read receipt checkmarks (must contrast with sent bubble bg) + + +# --------------------------------------------------------------------------- +# Dark theme — Catppuccin Mocha palette +# --------------------------------------------------------------------------- + +DARK_THEME = ThemeColors( + bg_primary="#1e1e2e", + bg_secondary="#313244", + bg_tertiary="#181825", + bg_hover="#252536", + bg_selected="#313244", + + text_primary="#cdd6f4", + text_secondary="#bac2de", + text_muted="#6c7086", + + accent="#89b4fa", + accent_hover="#74c7ec", + accent_text="#1e1e2e", + + bubble_sent_bg="#2a4a7f", + bubble_sent_text="#cdd6f4", + bubble_recv_bg="#2c2c3e", + bubble_recv_text="#cdd6f4", + bubble_sent_meta="#8899bb", + bubble_recv_meta="#6c7086", + + success="#a6e3a1", + warning="#f9e2af", + error="#f38ba8", + info="#74c7ec", + + border="#45475a", + border_focus="#89b4fa", + scrollbar="#45475a", + separator="#45475a", + overlay="rgba(30, 30, 46, 245)", + + link_https="#89b4fa", + link_http="#fab387", + + mention="#89b4fa", + search_highlight="#f9e2af", + search_current="#fab387", + + reaction_bg="#313244", + reaction_bg_own="#45475a", + reaction_border="#45475a", + reaction_border_own="#585b70", + + online_dot="#a6e3a1", + online_dot_border="#181825", + pin_color="#f9e2af", + sender_name_other="#f9e2af", + receipt_read="#74c7ec", +) + +# --------------------------------------------------------------------------- +# Light theme — Signal-inspired palette +# --------------------------------------------------------------------------- + +LIGHT_THEME = ThemeColors( + bg_primary="#ffffff", + bg_secondary="#f2f2f7", + bg_tertiary="#e5e5ea", + bg_hover="#dcdce4", + bg_selected="#c7c7d1", + + text_primary="#1c1c1e", + text_secondary="#3a3a3c", + text_muted="#8a8a8e", + + accent="#3478f6", + accent_hover="#2563eb", + accent_text="#ffffff", + + bubble_sent_bg="#3478f6", + bubble_sent_text="#ffffff", + bubble_recv_bg="#e5e5ea", + bubble_recv_text="#1c1c1e", + bubble_sent_meta="#a3c4ff", + bubble_recv_meta="#8a8a8e", + + success="#34c759", + warning="#ff9500", + error="#ff3b30", + info="#5ac8fa", + + border="#c6c6c8", + border_focus="#3478f6", + scrollbar="#aeaeb2", + separator="#c6c6c8", + overlay="rgba(0, 0, 0, 200)", + + link_https="#2563eb", + link_http="#ea580c", + + mention="#2563eb", + search_highlight="#fde68a", + search_current="#fb923c", + + reaction_bg="#e5e5ea", + reaction_bg_own="#c7c7d1", + reaction_border="#c6c6c8", + reaction_border_own="#a0a0a8", + + online_dot="#34c759", + online_dot_border="#e5e5ea", + pin_color="#ff9500", + sender_name_other="#7c3aed", + receipt_read="#d0e8ff", +) + + +# --------------------------------------------------------------------------- +# ThemeManager singleton +# --------------------------------------------------------------------------- + +class ThemeManager: + """Manages the active theme, persistence and change notification.""" + + _instance: ThemeManager | None = None + + @classmethod + def instance(cls) -> ThemeManager: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + self._is_dark: bool = True + self._listeners: list[Callable[[], None]] = [] + self._email: str | None = None + self._load_global() + + # -- Public API -- + + @property + def is_dark(self) -> bool: + return self._is_dark + + @property + def colors(self) -> ThemeColors: + return DARK_THEME if self._is_dark else LIGHT_THEME + + def toggle(self): + self._is_dark = not self._is_dark + self._save() + self._notify() + + def set_dark(self, dark: bool): + if dark == self._is_dark: + return + self._is_dark = dark + self._save() + self._notify() + + def set_email(self, email: str): + """After login, bind to user-specific preference file.""" + self._email = email + self._load_user() + + def on_change(self, callback: Callable[[], None]): + self._listeners.append(callback) + + def remove_listener(self, callback: Callable[[], None]): + try: + self._listeners.remove(callback) + except ValueError: + pass + + def generate_qss(self) -> str: + return _build_qss(self.colors) + + # -- Persistence -- + + def _global_path(self) -> Path: + p = Path.home() / ".encrypted_chat" + p.mkdir(parents=True, exist_ok=True) + return p / "global_settings.json" + + def _user_path(self) -> Path | None: + if not self._email: + return None + p = Path.home() / ".encrypted_chat" / self._email + if not p.exists(): + return None + return p / "theme.json" + + def _load_global(self): + try: + p = self._global_path() + if p.exists(): + data = json.loads(p.read_text()) + self._is_dark = data.get("dark", True) + except Exception: + pass + + def _load_user(self): + try: + p = self._user_path() + if p and p.exists(): + data = json.loads(p.read_text()) + self._is_dark = data.get("dark", self._is_dark) + except Exception: + pass + + def _save(self): + data = {"dark": self._is_dark} + try: + self._global_path().write_text(json.dumps(data)) + except Exception: + pass + try: + p = self._user_path() + if p: + p.write_text(json.dumps(data)) + except Exception: + pass + + def _notify(self): + for cb in list(self._listeners): + try: + cb() + except Exception: + logger.debug("Theme listener error", exc_info=True) + + +# --------------------------------------------------------------------------- +# Convenience accessors +# --------------------------------------------------------------------------- + +def c() -> ThemeColors: + """Shorthand for ThemeManager.instance().colors.""" + return ThemeManager.instance().colors + + +def qss() -> str: + """Shorthand for ThemeManager.instance().generate_qss().""" + return ThemeManager.instance().generate_qss() + + +def tm() -> ThemeManager: + """Shorthand for ThemeManager.instance().""" + return ThemeManager.instance() + + +# --------------------------------------------------------------------------- +# QSS generator +# --------------------------------------------------------------------------- + +_FONT_STACK = ( + '"Segoe UI Variable", "Segoe UI", "Helvetica Neue", ' + '"SF Pro Text", "Calibri", sans-serif' +) + + +def _build_qss(t: ThemeColors) -> str: + return f""" +/* ── Global ──────────────────────────────────────────────── */ +QWidget {{ + background-color: {t.bg_primary}; + color: {t.text_primary}; + font-family: {_FONT_STACK}; + font-size: 11pt; +}} + +/* ── Input fields ────────────────────────────────────────── */ +QLineEdit {{ + background-color: {t.bg_secondary}; + border: 1px solid {t.border}; + border-radius: 6px; + padding: 8px; + color: {t.text_primary}; +}} +QLineEdit:focus {{ + border: 1px solid {t.border_focus}; +}} + +/* ── Buttons ─────────────────────────────────────────────── */ +QPushButton {{ + background-color: {t.accent}; + color: {t.accent_text}; + border: none; + border-radius: 6px; + padding: 8px 16px; + font-weight: bold; +}} +QPushButton:hover {{ + background-color: {t.accent_hover}; +}} +QPushButton:pressed {{ + background-color: {t.accent_hover}; +}} +QPushButton#secondaryBtn {{ + background-color: {t.bg_secondary}; + color: {t.text_primary}; + font-weight: normal; +}} +QPushButton#secondaryBtn:hover {{ + background-color: {t.bg_hover}; +}} +QPushButton#toolBtn {{ + background-color: transparent; + border: none; + border-radius: 4px; + padding: 4px; +}} +QPushButton#toolBtn:hover {{ + background-color: {t.bg_hover}; +}} + +/* ── Lists ───────────────────────────────────────────────── */ +QListWidget {{ + background-color: {t.bg_tertiary}; + border: none; + border-radius: 6px; + padding: 4px; +}} +QListWidget::item {{ + padding: 10px; + border-radius: 4px; +}} +QListWidget::item:selected {{ + background-color: {t.bg_selected}; + border-left: 3px solid {t.accent}; +}} +QListWidget::item:hover {{ + background-color: {t.bg_hover}; + color: {t.text_primary}; +}} + +/* ── Text areas ──────────────────────────────────────────── */ +QTextEdit, QTextBrowser {{ + background-color: {t.bg_primary}; + border: none; + border-radius: 6px; + padding: 8px; + color: {t.text_primary}; +}} + +/* ── Scrollbar ───────────────────────────────────────────── */ +QScrollBar:vertical {{ + background: transparent; + width: 8px; + margin: 0; +}} +QScrollBar::handle:vertical {{ + background: {t.scrollbar}; + border-radius: 4px; + min-height: 30px; +}} +QScrollBar::handle:vertical:hover {{ + background: {t.text_muted}; +}} +QScrollBar::add-line:vertical, QScrollBar::sub-line:vertical {{ + height: 0; +}} +QScrollBar::add-page:vertical, QScrollBar::sub-page:vertical {{ + background: transparent; +}} +QScrollBar:horizontal {{ + background: transparent; + height: 8px; + margin: 0; +}} +QScrollBar::handle:horizontal {{ + background: {t.scrollbar}; + border-radius: 4px; + min-width: 30px; +}} +QScrollBar::handle:horizontal:hover {{ + background: {t.text_muted}; +}} +QScrollBar::add-line:horizontal, QScrollBar::sub-line:horizontal {{ + width: 0; +}} +QScrollBar::add-page:horizontal, QScrollBar::sub-page:horizontal {{ + background: transparent; +}} + +/* ── Title label ─────────────────────────────────────────── */ +QLabel#title {{ + font-size: 15pt; + font-weight: bold; + color: {t.accent}; +}} + +/* ── Sidebar panel ───────────────────────────────────────── */ +#sidebarPanel {{ + background-color: {t.bg_tertiary}; +}} + +/* ── Splitter ────────────────────────────────────────────── */ +QSplitter::handle {{ + background-color: {t.separator}; + width: 1px; +}} + +/* ── Checkbox ────────────────────────────────────────────── */ +QCheckBox {{ + color: {t.text_primary}; +}} + +/* ── Menus ───────────────────────────────────────────────── */ +QMenu {{ + background-color: {t.bg_secondary}; + border: 1px solid {t.border}; + border-radius: 6px; + padding: 4px; +}} +QMenu::item {{ + padding: 6px 20px; + color: {t.text_primary}; + border-radius: 4px; +}} +QMenu::item:selected {{ + background-color: {t.bg_hover}; +}} +QMenu::separator {{ + height: 1px; + background: {t.separator}; + margin: 4px 8px; +}} + +/* ── Dialogs ─────────────────────────────────────────────── */ +QDialog {{ + background-color: {t.bg_primary}; + color: {t.text_primary}; +}} + +/* ── MessageBox ──────────────────────────────────────────── */ +QMessageBox {{ + background-color: {t.bg_primary}; + color: {t.text_primary}; +}} +QMessageBox QLabel {{ + color: {t.text_primary}; +}} + +/* ── InputDialog ─────────────────────────────────────────── */ +QInputDialog {{ + background-color: {t.bg_primary}; + color: {t.text_primary}; +}} + +/* ── ScrollArea ──────────────────────────────────────────── */ +QScrollArea {{ + background-color: {t.bg_primary}; + border: none; +}} + +/* ── ToolTip ─────────────────────────────────────────────── */ +QToolTip {{ + background-color: {t.bg_secondary}; + color: {t.text_primary}; + border: 1px solid {t.border}; + padding: 4px 8px; + font-size: 9pt; +}} +"""