diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..b4f0316 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,18 @@ +{ + "permissions": { + "allow": [ + "Bash(python3:*)", + "Bash(ls:*)", + "Bash(pip3 show:*)", + "Bash(.venv/bin/python3:*)", + "Bash(python:*)", + "Bash(wc:*)", + "Bash(grep:*)", + "Bash(chmod:*)", + "Bash(find:*)", + "Bash(fc-list:*)", + "Bash(sudo ls:*)", + "Bash(mkdir:*)" + ] + } +} diff --git a/.env b/.env new file mode 100644 index 0000000..855c0c3 --- /dev/null +++ b/.env @@ -0,0 +1,19 @@ +MYSQL_HOST=192.168.1.112 +MYSQL_PORT=3306 +MYSQL_USER=sifrator +MYSQL_PASSWORD=Brouk100+1 +MYSQL_DATABASE=encrypted_chat + +#SERVER_HOST=192.168.88.65 +SERVER_HOST=0.0.0.0 +SERVER_PORT=9999 + +TLS_ENABLED=true +TLS_CERT_FILE=/home/filip/encrypted_chat/certs/fullchain.pem +TLS_KEY_FILE=/home/filip/encrypted_chat/certs/privkey.pem + +SMTP_HOST=smtp.protonmail.ch +SMTP_PORT=587 +SMTP_USER=cryptedchat@dw-technics.com +SMTP_PASS=DBL5GKTJA28KQRZF +SMTP_FROM=cryptedchat@dw-technics.com diff --git a/.gitignore b/.gitignore index 36b13f1..fbc2bda 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +<<<<<<< HEAD # ---> Python # Byte-compiled / optimized / DLL files __pycache__/ @@ -174,3 +175,14 @@ cython_debug/ # PyPI configuration file .pypirc +======= +__pycache__/ +*.pyc +#.env +#.env.* +.encrypted_chat/ +certs/* +!certs/*.sh +!certs/*.example +!certs/README.md +>>>>>>> d506e65 (initial commit) diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..9fcae88 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,1041 @@ +# Encrypted Chat — Project Context + +End-to-end encrypted chat with forward secrecy (X3DH + Double Ratchet, Signal Protocol). +Server stores and relays opaque blobs — never sees plaintext. RSA retained for login only. + +## Files + +| File | Lines | Purpose | +|------|-------|---------| +| `schema.sql` | ~172 | MySQL schema (users, devices, signed_prekeys, one_time_prekeys, conversations, conversation_members, group_invitations, messages, message_recipients, message_reactions, group_sender_keys, message_reads, image_uploads, user_profiles) | +| `db.py` | ~1500 | MySQL CRUD — one connection per call, `dictionary=True` cursors, returns dicts. Includes profile CRUD, `get_user_contacts()`, `update_conversation_creator()`, `get_conversation()`. Phantom user CRUD + `upgrade_phantom_user()`. Invitation CRUD. Group avatar. Device CRUD. Per-device prekey/session management. Reactions CRUD (`add_reaction`, `remove_reaction`, `get_reactions`). Pins CRUD (`pin_message`, `unpin_message`, `get_pinned_messages`). | +| `server.py` | ~2200 | Asyncio TCP server, handler dispatch, rate limiting, real-time notifications via `connected_clients` dict. Profile + avatar handlers. Online/offline status push. Leave group, delete conversation, group invitations, group avatar handlers. Phantom user support. Graceful shutdown. 4 asyncio.Lock guards (H4 fix). Device registration + per-device key bundles + per-device notifications. SPK age reporting in `get_prekey_count`. Reactions, pins, pinned messages handlers. | +| `protocol.py` | ~114 | Newline-delimited JSON protocol, `ProtocolReader`/`ProtocolWriter`, `encode_binary`/`decode_binary` (base64). Constants: `VERSION`, `MAX_MESSAGE_BYTES`, `MAX_IMAGE_BYTES`, `MAX_FILE_BYTES`, `IMAGE_CHUNK_SIZE`. | +| `crypto_utils.py` | ~950 | Ed25519, X25519, AES-256-GCM, HKDF, PBKDF2, X3DH, `DoubleRatchet` (with state snapshot/rollback), `SenderKeyState` (with state snapshot/rollback). RSA for login only. ECP1 password-based key encryption format (600k PBKDF2 iterations). Contact key verification: `compute_fingerprint()`, `format_fingerprint()`, `compute_safety_number()`, `encode_verification_qr()`, `decode_verification_qr()`. Message padding: `pad_plaintext()`, `unpad_plaintext()`. | +| `chat_core.py` | ~2850 | `ChatClient` class — session management, X3DH/ratchet encryption, local key storage, reconnect, profiles, file sharing, leave group, delete conversation, invitations, group avatar. Multi-device: per-device sessions, device_id persistence, device bundle cache. SPK rotation (7-day) with grace period. Reactions, pins, forwarding methods. Contact key verification: TOFU registry, explicit verification, safety numbers, QR code verify. Used by CLI + GUI | +| `client.py` | ~520 | Interactive CLI client — reactions, pin, forward, pinned messages, verify contact, show fingerprint commands | +| `gui_client.py` | ~3600 | PyQt6 GUI — `AsyncBridge` QThread bridges asyncio <-> Qt signals, `MainWindow`, `UserProfileDialog`, `VerificationDialog`, connection indicator + auto-reconnect, online status, file sharing, leave group, unread badges, circular avatars in conv list, online green dot overlay, group invitations UI, delete conversation, group avatar support, message reactions (emoji badges), forwarding with dialog, pin/unpin with indicator, pinned messages list, @mentions autocomplete, contact verification indicators (conv list green checkmark, E2E label status, key change warning dialog) | +| `ios_client/` | ~6200 | Native iOS client (Swift/SwiftUI) — wire-compatible with Python server. 47 Swift files: CryptoKit crypto (AES-GCM, HKDF, Ed25519, X25519), pure Swift GF(2^255-19) field arithmetic for Ed→X conversion, Security.framework RSA-4096, Network.framework TCP+TLS, Signal Protocol (X3DH + Double Ratchet + Sender Keys), SwiftUI views. Uses `project.yml` (XcodeGen). | + +## Architecture & Data Flow + +### Encryption: X3DH + Double Ratchet (Signal Protocol) + +**Keys per user:** +- **RSA-4096** — Login challenge-response only (server stores public key). Password-encrypted with ECP1 format (PBKDF2 600k iterations + AES-256-GCM). +- **Identity Key (IK)** — Ed25519 (signing) + converted to X25519 (for DH in X3DH). Password-encrypted with ECP1 format. +- **Signed Pre-Key (SPK)** — X25519, signed by IK, uploaded to server. **Rotates every 7 days** (M4). Previous SPK kept for grace period (in-flight X3DH). +- **One-Time Pre-Keys (OPK)** — X25519, consumed on X3DH initiation, auto-replenished when count < 20 + +**DM flow:** +1. Alice fetches Bob's per-device key bundles (IK, SPK per device, OPK per device) -> X3DH per device -> shared secret per device +2. Double Ratchet initialized from shared secret — one session per (user, device) pair +3. Each message: symmetric ratchet (HMAC chain) -> message key -> AES-256-GCM +4. Each reply direction change: DH ratchet (new X25519 keypair) -> new root + chain keys +5. Per-device ciphertext — each recipient device gets individually encrypted blob +6. Self-encrypted copy uses SELF_DEVICE_ID sentinel, readable by all own devices + +**Group flow (Sender Keys):** +1. Each sender has own SenderKeyState per group +2. Sender key distributed to members via pairwise Double Ratchet (as control DM with `_sender_key` field) +3. Group messages: symmetric ratchet on sender key -> AES-256-GCM +4. Same ciphertext replicated to all recipients (efficient) + +### Contact Key Verification (Out-of-Band Trust) + +**Problem:** Server stores identity keys but could MITM new sessions by substituting keys. Users need a way to verify keys out-of-band. + +**Design:** Entirely client-side — zero server changes. Server already stores identity keys in `users.identity_key` (32B Ed25519) and returns them via `get_user_info`/`get_key_bundle`. Verification state is local-only (privacy by design — server never learns who verified whom). + +**Trust model:** +``` +First contact → TOFU (Trust On First Use) → "trusted" (key recorded) + ↓ + Out-of-band verify → "verified" (explicit confirmation) + ↓ + Key change detected → "changed" / "changed_verified" (WARNING) + ↓ + Accept new key → "trusted" (verification reset) +``` + +**Verification methods:** +1. **Safety numbers** — 60-digit number (12 groups × 5 digits), deterministic for each pair (lower user_id's fingerprint first). Both users see the same number — compare in person or over trusted channel. +2. **QR codes** — Encode `0x01 + uid_len + uid + identity_key` (70 bytes). One user shows QR, other scans → automatic verification. +3. **Fingerprints** — Per-user 30-digit number (6 groups × 5 digits). Visual comparison. + +**Algorithm (Signal NumericFingerprint compatible):** +- Fingerprint: SHA-512 iterated 5200× on `version(2B) + identity_key(32B) + user_id(UTF-8)`, truncated to 32 bytes +- Display: `int(bytes[i*5:(i+1)*5], big-endian) % 100000`, zero-padded to 5 digits + +**Local storage** (encrypted with `_local_key`, AES-256-GCM): +- `known_identity_keys.bin` — TOFU registry: `{user_id → {identity_key hex, first_seen, last_seen}}` +- `verified_contacts.bin` — Explicit verification: `{user_id → {identity_key hex, verified_at, method}}` + +**Tamper resistance:** Oba soubory jsou šifrovány AES-256-GCM (klíč odvozen z identity key přes HKDF). Na rozdíl od session/sender key souborů (které mají plaintext migration fallback kvůli zpětné kompatibilitě) verifikační soubory **nemají žádný plaintext fallback** — pokud dešifrování selže, vrátí se prázdný dict. Útočník s přístupem k disku (ale bez znalosti hesla/identity key) tedy nemůže: +- Podvrhnout falešný "verified" status (injekce do `verified_contacts.bin`) +- Potlačit TOFU key-change warning (injekce do `known_identity_keys.bin`) +- Nejhorší případ při manipulaci = verifikace se resetuje (prázdný stav), nikdy se nepřijme podvržená hodnota + +**Threat model:** +- **DNS únos / podvržený server (existující kontakt):** TOFU detekuje key change → warning dialog. Útočník nemůže tiše podvrhnout klíč. +- **DNS únos / podvržený server (první kontakt):** TOFU věří prvnímu klíči — MITM úspěšný. Obrana: out-of-band verifikace safety number přes jiný kanál (telefon, osobně). Pokud se čísla neshodnou → odhaleno. +- **Kompromitovaný reálný server:** Server podmění klíč v `get_key_bundle` → TOFU detekuje změnu (stejné jako DNS únos). Server neví kdo je verified (stav je lokální) → nemůže cíleně obejít. +- **Disk access bez hesla:** Verifikační soubory šifrovány AES-256-GCM, žádný plaintext fallback → nelze podvrhnout. Viz tamper resistance výše. +- **TLS je první linie obrany:** S platným certifikátem DNS únos nestačí. Bez TLS / s `TLS_INSECURE` je MITM triviální. Verifikace kontaktů je druhá linie — chrání i při kompromitovaném serveru. + +**Data flow:** +1. `_get_user_info()` calls `check_identity_key()` → records TOFU or detects change +2. Key change → `_key_change_cb` fires → GUI shows warning dialog +3. User opens VerificationDialog → sees safety number + QR → marks verified +4. `_rebuild_conv_list()` queries `get_verification_status()` → shows green checkmark for verified DMs +5. E2E label in chat header shows "Verified" (green) or "Encrypted" (muted) + +**QR code encoding detail:** +- Raw binary payload (`0x01 + uid_len + uid + identity_key`) je před vložením do QR zakódován jako **base64** (ASCII-safe) +- Důvod: QR čtečky (pyzbar/zbar) re-kódují binární data přes UTF-8 → byty > 127 se zkomolí +- Při skenování se base64 dekóduje zpět na raw bytes → `decode_verification_qr()` +- iOS implementace musí použít stejný base64 wrapper (viz iOS spec níže) + +**Self-verification exclusion:** +- Vlastní user_id se nikdy nezobrazuje jako "unverified" — TOFU registr neobsahuje vlastní klíč (ten je na disku) +- GUI: security section v UserProfileDialog se nezobrazuje pro vlastní profil, verified badge v group info přeskakuje vlastní uid + +### Protocol + +Newline-delimited JSON over TCP (optional TLS). Fields: `type`, `status`, `data`, `request_id`. +Binary data encoded as base64 via `encode_binary()`/`decode_binary()`. + +**Request/response pattern:** Client sends `{"type": "...", "request_id": "uuid", ...}`, server responds with same `request_id`. Notifications (push) have no `request_id`. + +### Server notifications (push to connected clients) +- `new_message` — per-recipient ciphertext included +- `messages_read` — conversation_id + user_id + message_ids +- `message_deleted` — message_id + conversation_id +- `conversation_created` — conversation_id, name, created_by, members[] (pushed to added members) +- `member_added` — conversation_id, user_id, username, email (pushed to all members except requester) +- `member_removed` — conversation_id, user_id (pushed to removed member + remaining members) +- `group_invitation` — conversation_id, conversation_name, invited_by, invited_by_username (pushed to invited user) +- `conversation_renamed` — conversation_id, name, renamed_by (pushed to all members except renamer) +- `session_reset` — from_user_id, from_device_id (pushed to peer when session reset requested) +- `message_reacted` — message_id, conversation_id, user_id, username, reaction, action (pushed to members) +- `message_pinned` — message_id, conversation_id, user_id, username (pushed to members) +- `message_unpinned` — message_id, conversation_id, user_id, username (pushed to members) +- `user_online` — user_id (pushed to contacts when user connects) +- `user_offline` — user_id (pushed to contacts when user's last connection drops) +- `online_users` — user_ids[] (sent to user on login — list of currently online contacts) + +## DB Schema (schema.sql) + +``` +users: id, username, email (UNIQUE), rsa_public_key (TEXT), identity_key (BLOB 32B Ed25519), created_at +devices: id, user_id FK, device_name (nullable), created_at, last_seen_at +signed_prekeys: id, user_id FK, device_id (nullable), public_key (BLOB 32B), signature (BLOB 64B), created_at +one_time_prekeys: id, user_id FK, device_id (nullable), public_key (BLOB 32B) +conversations: id, created_at, name (nullable), created_by (nullable), avatar_file (nullable) +conversation_members: conversation_id + user_id (composite PK), joined_at +group_invitations: id, conversation_id FK, user_id FK, invited_by FK, created_at, UNIQUE(conversation_id, user_id) +messages: id, conversation_id FK, sender_id FK, sender_device_id (nullable), ratchet_header (BLOB JSON), + x3dh_header (BLOB JSON nullable), sender_chain_id (BLOB nullable), sender_chain_n (INT nullable), + created_at, deleted_at, image_file_id, pinned_at (nullable), pinned_by (nullable) +message_recipients: message_id + user_id + device_id (composite PK), encrypted_content (BLOB), nonce (BLOB), + ratchet_header (BLOB nullable), x3dh_header (BLOB nullable) +message_reactions: id, message_id FK, user_id FK, reaction VARCHAR(32), created_at, UNIQUE(message_id, user_id, reaction) +group_sender_keys: conversation_id + sender_id + device_id (composite PK), chain_id (BLOB 32B), created_at +message_reads: message_id + user_id (composite PK), read_at +image_uploads: file_id (PK), conversation_id FK, uploader_id FK, file_size, completed, created_at +user_profiles: user_id (PK FK), phone, phone_visible, email_visible, location, location_visible, avatar_file, updated_at +``` + +Constant: `SELF_DEVICE_ID = "00000000-0000-0000-0000-000000000000"` — sentinel for self-encrypted copies and legacy rows. + +Index: `messages(conversation_id, created_at)` for query performance. + +## Server Protocol — All Message Types + +### Pre-login (no session required) +| Type | Handler | Purpose | +|------|---------|---------| +| `register` | `handle_register_start` | Start registration (username, email, public_key, identity_key) | +| `register_confirm` | `handle_register_confirm` | Confirm with 6-digit code | +| `login_start` | `handle_login_start` | Get RSA challenge | +| `login_finish` | `handle_login_finish` | Respond with RSA signature -> session. Client sends `client_version`, server returns `server_version` in response. Also sends `online_users` and `user_online` notifications. | +| `get_user_info` | `handle_get_user_info` | Get user info + identity_key (by email or user_id) | +| `pairing_start` | `handle_pairing_start` | New device starts pairing (gets 8-digit code) | +| `pairing_poll` | `handle_pairing_poll` | New device polls for key payload | + +### Post-login (session required) +| Type | Handler | Purpose | +|------|---------|---------| +| `upload_prekeys` | `handle_upload_prekeys` | Upload SPK + batch of OPKs (server verifies SPK signature) | +| `get_key_bundle` | `handle_get_key_bundle` | Fetch key bundle for X3DH (consumes one OPK) | +| `get_prekey_count` | `handle_get_prekey_count` | Check remaining OPK count + SPK age (`spk_created_at`) for rotation | +| `ensure_prekeys` | `handle_ensure_prekeys` | Combined get_prekey_count + upload_prekeys in single round-trip. Returns count + spk_created_at + upload status. | +| `create_conversation` | `handle_create_conversation` | Create conversation — DMs auto-add both; groups add creator only + create invitations for others | +| `find_conversation` | `handle_find_conversation` | Find existing DM by email | +| `add_member` | `handle_add_member` | Create invitation for user to join group (was: direct add) | +| `remove_member` | `handle_remove_member` | Remove member (creator only) | +| `leave_group` | `handle_leave_group` | Voluntarily leave a group (transfers creator if needed, blocks DM leave) | +| `rename_conversation` | `handle_rename_conversation` | Rename group conversation (creator only, max 100 chars), pushes `conversation_renamed` to members | +| `delete_conversation` | `handle_delete_conversation` | Delete conversation — DMs: remove self; groups: creator-only, removes all members + files | +| `accept_invitation` | `handle_accept_invitation` | Accept pending group invitation → add to members, notify others | +| `decline_invitation` | `handle_decline_invitation` | Decline pending group invitation | +| `list_invitations` | `handle_list_invitations` | List user's pending invitations (with conv name + inviter username) | +| `list_conversations` | `handle_list_conversations` | List all user's conversations (includes avatar_file) | +| `send_message` | `handle_send_message` | Send encrypted message (ratchet_header + recipients[]) | +| `get_messages` | `handle_get_messages` | Get messages (returns per-user ciphertext, JOINs message_recipients). Supports `after_ts` for incremental sync. ROW_NUMBER dedup when both device-specific and SELF_DEVICE_ID rows exist. | +| `mark_read` | `handle_mark_read` | Mark messages as read | +| `delete_message` | `handle_delete_message` | Soft-delete message (sender only) | +| `rotate_keys` | `handle_rotate_keys` | Rotate RSA login key, disconnect other sessions | +| `pairing_claim` | `handle_pairing_claim` | Authorized device claims pairing code | +| `pairing_send` | `handle_pairing_send` | Authorized device sends encrypted key payload | +| `upload_image_start/chunk/end` | Image/file upload | Chunked encrypted upload (32KB chunks). `file_type` param: `"image"` (5MB limit) or `"file"` (50MB limit). | +| `download_image` | Image/file download | Chunked download with offset | +| `get_profile` | `handle_get_profile` | Get user profile (respects visibility for other users) | +| `update_profile` | `handle_update_profile` | Update own profile (phone, location, visibility toggles) | +| `update_avatar` | `handle_update_avatar` | Upload user avatar (base64, max 2MB, JPEG/PNG) | +| `get_avatar` | `handle_get_avatar` | Download user's avatar | +| `update_group_avatar` | `handle_update_group_avatar` | Upload group avatar (base64, max 2MB, JPEG/PNG, creator only) | +| `get_group_avatar` | `handle_get_group_avatar` | Download group avatar | +| `get_deleted_since` | `handle_get_deleted_since` | Get message IDs deleted since a given timestamp (for incremental sync) | +| `reencrypt_messages` | `handle_reencrypt_messages` | Batch upsert message history with self-key (max 500/request, for device pairing + received msg self-encryption) | +| `list_devices` | `handle_list_devices` | List all devices for current user | +| `remove_device` | `handle_remove_device` | Remove a device (not current device) | +| `session_reset` | `handle_session_reset` | Notify peer to reset corrupted Double Ratchet session (push `session_reset` to peer) | +| `react_message` | `handle_react_message` | Add/remove emoji reaction on a message. Push `message_reacted` to members. | +| `pin_message` | `handle_pin_message` | Pin/unpin a message. Push `message_pinned`/`message_unpinned` to members. | +| `get_pinned_messages` | `handle_get_pinned_messages` | Get list of pinned messages for a conversation. | + +## Key Classes & Functions + +### crypto_utils.py + +**Password-based key encryption (ECP1 format):** +- `PBKDF2_ITERATIONS = 600_000` — OWASP 2023 compliant +- `_encrypt_private_key(raw_bytes, password) -> bytes` — PBKDF2-HMAC-SHA256 + AES-256-GCM. Format: `_ECP1_MAGIC(4) + salt(16) + nonce(12) + ciphertext_with_tag` +- `_decrypt_private_key(data, password) -> bytes` — Detects ECP1 magic prefix, derives key, decrypts + +**RSA (login only):** `generate_rsa_keypair()`, `serialize_private_key()` (ECP1 with password, PEM without), `serialize_public_key()`, `load_private_key()` (auto-detects ECP1 vs legacy PEM), `load_public_key()`, `rsa_sign()`, `rsa_verify()` + +**AES-256-GCM:** `aes_encrypt(plaintext, key=None) -> (key, nonce, ct, tag)`, `aes_decrypt(key, nonce, ct, tag) -> plaintext` + +**Ed25519:** `generate_identity_keypair()`, `serialize_ed25519_private()` (ECP1 with password, 32-byte raw without), `serialize_ed25519_private_raw()`, `serialize_ed25519_public()`, `load_ed25519_private()` (auto-detects ECP1 vs legacy PEM vs raw), `load_ed25519_public()`, `ed25519_sign()`, `ed25519_verify()` + +**X25519:** `generate_x25519_keypair()`, `serialize_x25519_private()`, `serialize_x25519_public()`, `load_x25519_private()`, `load_x25519_public()`, `x25519_dh()` + +**Key conversion:** `ed25519_private_to_x25519()` (SHA-512 + clamp), `ed25519_public_to_x25519()` (Montgomery u-coordinate) + +**HKDF:** `hkdf_derive()`, `kdf_rk(root_key, dh_output) -> (new_root_key, chain_key)`, `kdf_ck(chain_key) -> (new_chain_key, message_key)` + +**X3DH:** `generate_signed_prekey(identity_private) -> {private, public, signature, id}`, `generate_one_time_prekeys(count=50) -> [{private, public, id}]`, `x3dh_initiate(ik_private_ed, ik_public_remote_ed, spk_remote, spk_signature, opk_remote?) -> (shared_secret, ek_priv, ek_pub)`, `x3dh_respond(ik_private_ed, spk_private, ik_remote_ed, ek_remote, opk_private?) -> shared_secret` + +**DoubleRatchet class:** +- `init_alice(shared_secret, bob_spk_pub)` — initiator, performs first DH ratchet +- `init_bob(shared_secret, spk_pair)` — responder, waits for first message +- `encrypt(plaintext) -> {header: {dh_pub, n, pn}, ciphertext, nonce}` — AAD = serialized header +- `decrypt(header_dict, ciphertext, nonce)` — handles DH ratchet step if new dh_pub, skipped messages. **State snapshot/rollback on failure (M9):** `_snapshot()` captures all mutable state before modifications, `_restore()` rolls back on any exception. +- `_snapshot() -> dict` / `_restore(snap)` — Snapshot: dh_pair, dh_remote, root_key, send/recv chain keys, counters, skipped dict. Used internally by `decrypt()`. +- `export_state() -> bytes` / `import_state(data) -> DoubleRatchet` — JSON serialization + +**SenderKeyState class:** +- `__init__(sender_key=None)` — generates random 32B key if None +- `encrypt(plaintext) -> {chain_id, n, ciphertext, nonce}` — AAD = chain_id + message number +- `decrypt(chain_id_hex, n, ciphertext, nonce)` — fast-forwards chain if needed. **State snapshot/rollback on failure (M9):** snapshots chain_key, n, _known_keys before fast-forward, restores on exception. +- `export_key() -> bytes` — for distribution to group members +- `from_key(exported_key) -> SenderKeyState` — receiver initializes from exported key +- `export_state() / import_state()` — full state persistence + +**Contact Key Verification:** +- `FINGERPRINT_VERSION = 0` — version byte for fingerprint algorithm +- `compute_fingerprint(user_id, identity_key_bytes, iterations=5200) -> bytes` — iterated SHA-512, truncated to 32 bytes. Matches Signal's NumericFingerprint. +- `format_fingerprint(fp_bytes) -> str` — 32 bytes → 6 groups of 5 digits (30 digits), 2 lines +- `compute_safety_number(my_uid, my_ik, their_uid, their_ik) -> str` — 60 digits (12 groups of 5), deterministic ordering (lower uid first), 3 lines of 4 groups +- `encode_verification_qr(user_id, identity_key_bytes) -> bytes` — `0x01 + uid_len(1B) + uid(UTF-8) + ik(32B)` +- `decode_verification_qr(data) -> (user_id, identity_key_bytes)` — inverse of encode + +**Message Padding:** +- `_PAD_MAGIC = b"\x01"` — prefix byte distinguishing padded from legacy unpadded messages +- `_PAD_BUCKETS = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]` — target sizes +- `pad_plaintext(plaintext) -> bytes` — Pad to nearest bucket. Format: `0x01 + plaintext + random_padding + pad_length(4B big-endian)`. +- `unpad_plaintext(data) -> bytes` — Remove padding. Legacy unpadded messages (starting with `{`) returned unchanged. + +### chat_core.py + +**Local key storage** (`~/.encrypted_chat/{email}/`): +``` +private.pem / public.pem — RSA (login, ECP1 format when password-encrypted) +identity_private.bin / identity_public.bin — Ed25519 (ECP1 format when password-encrypted, 32B raw otherwise) +device_id.txt — This device's UUID +spk_private.bin / spk_id.txt — Current signed prekey (AES-256-GCM via local_key) +prev_spk_private.bin / prev_spk_id.txt — Previous SPK for grace period (AES-256-GCM via local_key) +opk_private/{opk_id}.bin — One-time prekeys (AES-256-GCM via local_key) +login_lockout.json — Brute-force lockout state (failed_attempts, locked_until) +sessions/{user_id}_{device_id}.bin — Double Ratchet states (per peer device) +sender_keys/{conv_id}.bin — Own sender key states +sender_keys_recv/{conv_id}_{sender_id}_{device_id}.bin — Received sender keys (per sender device) +known_identity_keys.bin — TOFU registry: {user_id -> {identity_key hex, first_seen, last_seen}} +verified_contacts.bin — Explicit verification: {user_id -> {identity_key hex, verified_at, method}} +``` + +Storage functions: `save_keys()`, `load_keys()`, `_save_identity_keys()`, `_load_identity_keys()`, `_save_spk(local_key=)`, `_load_spk(local_key=)`, `_save_prev_spk(local_key=)`, `_load_prev_spk(local_key=)`, `_save_opk_private(local_key=)`, `_load_opk_private(local_key=)`, `_delete_opk_private()`, `_save_session()`, `_load_session()`, `_save_sender_key_state()`, `_load_sender_key_state()`, `_save_recv_sender_key()`, `_load_recv_sender_key()`, `_save_known_identity_keys()`, `_load_known_identity_keys()`, `_save_verified_contacts()`, `_load_verified_contacts()` + +Lockout functions: `_check_lockout(email) -> float`, `_record_failed_attempt(email)`, `_clear_lockout(email)`. Constants: `_LOCKOUT_BASE_SECONDS=2`, `_LOCKOUT_MAX_SECONDS=300`. + +**ChatClient attributes:** +- `private_key` / `public_key` — RSA (login) +- `identity_private` / `identity_public` — Ed25519 +- `spk_private` / `spk_id` — Current SPK +- `_prev_spk_private` / `_prev_spk_id` — Previous SPK for grace period (M4) +- `opk_privates: dict[str, X25519PrivateKey]` — OPK private keys by ID +- `device_id: str | None` — this device's UUID (persisted to disk) +- `sessions: dict[str, DoubleRatchet]` — "user_id:device_id" -> ratchet (per peer device) +- `sender_key_states: dict[str, SenderKeyState]` — conv_id -> own sender key +- `recv_sender_keys: dict[str, SenderKeyState]` — "conv_id:sender_id:device_id" -> their key +- `_device_bundle_cache: dict[str, tuple[float, list]]` — user_id -> (timestamp, device_bundles) with 5-min TTL +- `_pending_self_encrypt: list[dict]` — queue of received messages to self-encrypt for multi-device access +- `_user_cache: dict[str, dict]` — user_id -> {identity_key, username, email, identity_key_status} +- `connected: bool` — current connection state +- `_known_identity_keys: dict` — TOFU registry (user_id -> {identity_key hex, first_seen, last_seen}) +- `_verified_contacts: dict` — explicit verification (user_id -> {identity_key hex, verified_at, method}) +- `_key_change_cb: Callable | None` — callback fired on identity key change (GUI wires to warning dialog) + +**Key methods:** +- `register()` — Generates RSA + Ed25519, sends to server +- `confirm_registration()` — Confirms code, uploads prekeys (SPK + 50 OPKs) +- `login()` — Loads keys from disk (including prev_spk for grace period), RSA challenge-response, auto `_ensure_prekeys()` +- `_ensure_prekeys()` — Checks OPK count AND SPK age. Replenishes OPKs if < 20, **rotates SPK if >= 7 days old** (M4). Saves old SPK as grace period before generating new one. +- `_get_device_bundles(peer_user_id)` — Fetches per-device key bundles with 5-min TTL cache +- `_get_or_create_session(peer_user_id, peer_device_id, bundle)` — Loads from memory/disk or creates via X3DH, keyed by "user:device" +- `_process_x3dh_header(sender_id, x3dh_header, sender_device_id, spk_override?)` — Bob side of X3DH. `spk_override` param allows using previous SPK for grace period (M4). +- `send_message(conv_id, text, members, reply_to?)` — Routes to `_send_dm` or `_send_group_message` +- `_send_dm()` — Per-device Double Ratchet (encrypts for each of recipient's devices), self-encrypted copy with SELF_DEVICE_ID +- `_send_group_message()` — Sender Keys, distributes key if new (per-device) +- `_distribute_sender_key()` — Sends sender key as control message via per-device pairwise ratchet, includes sender_device_id +- `_decrypt_dm()` — Handles X3DH header for new sessions, returns None for control messages. On X3DH decrypt failure, retries with previous SPK (M4 grace period). +- `_decrypt_group()` — Uses received sender key chain +- `decrypt_notification()` — Returns None for control messages (sender key distribution) +- `get_messages()` — Cache-first with incremental sync (`after_ts`). Decrypts new messages, self-encrypts received messages for multi-device, syncs deletions via `get_deleted_since`. Returns merged cache + new messages. +- `_build_messages_from_cache()` — Builds sorted message list from cache dict +- `_queue_self_encrypt()` / `_flush_self_encrypt()` — Queue and upload self-encrypted copies of received messages +- `authorize_device()` — Exports RSA + Ed25519 only (simplified for multi-device — no session/sender key transfer) +- `pairing_wait()` — Imports RSA + identity key from paired device (new device generates own SPK + OPKs on login) +- `reconnect()` — Closes connection, re-establishes TCP + RSA login using in-memory keys +- `get_profile(user_id?)` — Gets user profile from server +- `update_profile(**fields)` — Updates own profile (phone, location, visibility) +- `update_avatar(image_data)` — Uploads avatar +- `get_avatar(user_id)` — Downloads avatar bytes +- `send_file(conv_id, file_path, members, reply_to?)` — Encrypt + chunked upload + send message with `file` payload +- `download_file(file_id, file_info)` — Chunked download + AES-GCM decrypt +- `leave_group(conv_id)` — Leave group, clean up local sender keys +- `rename_conversation(conv_id, name)` — Rename group (creator only) +- `delete_conversation(conv_id)` — Delete conversation, clean up local sender keys +- `accept_invitation(conv_id)` — Accept group invitation +- `decline_invitation(conv_id)` — Decline group invitation +- `list_invitations()` — Fetch pending invitations +- `update_group_avatar(conv_id, image_data)` — Upload group avatar +- `get_group_avatar(conv_id)` — Download group avatar +- `search_messages(conv_id, query)` — Search decrypted message cache (client-side only) +- `reset_session(peer_user_id, peer_device_id?)` — Delete local session + notify peer via server +- `handle_session_reset_notification(from_user_id, from_device_id?)` — Handle incoming session reset +- `_load_verification_stores()` — Load TOFU + verified contacts from disk (called on login/registration/pairing) +- `check_identity_key(user_id, ik_bytes) -> str` — TOFU check, returns "new"/"trusted"/"verified"/"changed"/"changed_verified" +- `verify_contact(user_id, ik_bytes, method)` — Mark contact as explicitly verified +- `unverify_contact(user_id)` — Remove explicit verification +- `accept_key_change(user_id, new_ik_bytes)` — Accept changed key, remove old verification +- `get_verification_status(user_id) -> str` — Returns "verified"/"trusted"/"unverified" +- `get_safety_number(peer_user_id) -> str` — Formatted 60-digit safety number +- `get_my_fingerprint() -> str` — Formatted 30-digit own fingerprint +- `get_peer_fingerprint(peer_user_id) -> str` — Formatted 30-digit peer fingerprint +- `get_verification_qr_data() -> bytes` — QR code payload for own identity +- `verify_qr_code(qr_data) -> (ok, user_id, message)` — Decode + verify scanned QR code + +### gui_client.py + +**AsyncBridge (QThread):** Runs asyncio event loop, `schedule(coro)` queues coroutines, pyqtSignals emit results back to Qt main thread. + +**Key signals:** `login_result`, `conversations_loaded`, `messages_loaded`, `message_sent`, `new_notification`, `messages_read_notification`, `message_deleted_notification`, `conversation_updated`, `connection_state_changed`, `profile_loaded`, `profile_updated`, `avatar_loaded`, `online_status_changed`, `online_users_loaded`, `file_sent`, `file_downloaded`, `group_left`, `conversation_deleted`, `invitations_loaded`, `invitation_result`, `invitation_received`, `group_avatar_loaded`, `group_avatar_updated`, `session_reset_notification`, `key_change_warning` + +**MainWindow:** Dark theme (Catppuccin Mocha), conversation list with circular avatars + online green dot overlay + unread count badges + verification checkmark, message bubbles with colored left border, context menu (reply, delete, view image, download file), image thumbnails via QTextDocument resources (`thumb://{file_id}`), file cards with download links (`file://{file_id}`), connection indicator dot (green/red/orange), profile button, attach menu (Image/File), Leave Group button in group info, delete conversation button (trash icon in header), group avatar display + change in group info dialog, invitation list (amber border) above conversation list with right-click accept/decline. E2E label clickable → opens VerificationDialog. Key change warning dialog on identity key change. + +**UserProfileDialog:** View (read-only) and edit (own profile) modes. Fields: avatar (circular crop), username, email, phone, location, visibility toggles. Avatar upload/download. Security section (viewing others): verification status + fingerprint. Opened from "My Profile" button or user info button in group info dialog. + +**VerificationDialog:** Frameless dialog for contact verification. Shows: peer name, verification status, safety number (monospace), QR code image, both fingerprints. Buttons: "Mark as Verified" / "Remove Verification" / "Scan QR Code" (via pyzbar). QR generation via `qrcode` library. + +**Avatar system in conversation list:** +- `_avatar_cache: dict[str, QPixmap]` — user avatars by user_id +- `_group_avatar_cache: dict[str, QPixmap]` — group avatars by conv_id +- `_avatar_requested: set[str]` / `_group_avatar_requested: set[str]` — dedup download requests +- `_make_circular_avatar(pixmap, size=32)` — QPainter circular crop +- `_make_default_avatar(username, size=32)` — colored circle with initial letter (deterministic color from username hash) +- `_add_online_dot(avatar)` — green dot overlay bottom-right +- `_get_conv_avatar(conv)` — returns QIcon (DM: user avatar + online dot; group: group avatar or default) +- Periodic refresh every 2 minutes via `_refresh_timer` / `_on_periodic_refresh()` + +## Important Implementation Details + +### X3DH Header Caching +When `_get_or_create_session()` creates a new session via X3DH, it attaches the X3DH header as `ratchet._x3dh_header`. The next `_send_dm()` reads and deletes it. This ensures the X3DH header is only sent with the first message. + +### Self-Encryption for DMs +Sender uses `derive_self_encryption_key(identity_private)` to encrypt their own copy of sent messages with a static AES key. Uses `SELF_DEVICE_ID` sentinel so all own devices can read it. This allows reading own sent messages when fetching history from any device. **Received messages** are also self-encrypted after decryption (via `_pending_self_encrypt` queue + `_flush_self_encrypt()`), creating SELF_DEVICE_ID copies so other devices of the same user can read them. `batch_reencrypt_messages()` uses INSERT ON DUPLICATE KEY UPDATE (upsert) to handle both cases. + +### Sender Key Distribution as Control Messages +Sender keys are distributed via normal `send_message` protocol (per-device pairwise ratchet). The payload contains `_sender_key: {conv_id, key, sender_device_id}` field. On decryption, `_decrypt_dm()` detects this field, stores the sender key keyed by `"conv_id:sender_id:sender_device_id"`, and returns `None` (not shown to user). + +### Group Messages: Dummy Ratchet Header +Group messages use `{"dh_pub": "00"*32, "n": 0, "pn": 0}` as ratchet_header because the server requires it, but groups use sender keys instead of Double Ratchet. + +### Multi-Device Architecture +Each device has independent Double Ratchet sessions. Sessions are keyed by `"peer_user_id:peer_device_id"`. When sending a DM, the client fetches per-device key bundles via `_get_device_bundles()` and encrypts separately for each device. The server registers devices at login (`handle_login_finish`), assigns device IDs, and routes notifications with `device_entries` arrays (one entry per recipient device). Device IDs are persisted to `~/.encrypted_chat/{email}/device_id.txt`. Old session files (`{user_id}.bin`) are automatically migrated to `{user_id}_{device_id}.bin` on first load. + +### Server Session Model +`connected_clients: dict[str, list[ProtocolWriter]]` — one user can have multiple connections (multi-device). `writer_device_map: dict[int, str]` maps `id(writer)` to `device_id`. Notifications are pushed to all connections except the sender's current one. + +### Device Registration +On `login_finish`, server checks for `device_id` in the request. If present and valid (belongs to user), reuses it. Otherwise creates a new device. Device ID returned in response and stored on client disk. `list_devices` and `remove_device` handlers for device management. + +### Simplified Pairing (Multi-Device) +`authorize_device()` only exports RSA + identity key (no sessions/sender keys). New device generates its own SPK + OPKs on first login, creates independent sessions via X3DH. Old messages readable via self-encryption (shared identity key). `reencrypt_history()` still runs to ensure all messages have self-encrypted copies. + +### Real-time Conversation Notifications +`handle_create_conversation`, `handle_add_member`, `handle_remove_member`, `handle_leave_group`, `handle_delete_conversation`, `handle_accept_invitation` push notifications to affected members via `connected_clients`. Types: `conversation_created`, `member_added`, `member_removed`, `group_invitation`. GUI handles these via `conversation_updated` signal -> refreshes conversation list. + +### Connection State + Auto-Reconnect +`ChatClient.connected` flag tracks TCP connection state. `_background_listener` sets `connected = False` when server closes connection and **fails all pending futures** with `ConnectionError` (prevents `send_and_recv` from hanging forever). `send_and_recv` has a 30s timeout via `asyncio.wait_for` and catches `ConnectionError`/`TimeoutError`. `reconnect()` re-establishes TCP + RSA challenge-response using in-memory keys (no password needed, includes `device_id`). GUI `_notification_loop` detects listener death -> triggers `_auto_reconnect` with exponential backoff (1s->2s->4s->...->30s). Connection indicator dot: green (connected), red (disconnected), orange (reconnecting). + +### Server Per-Message Error Handling +Server dispatch loop wraps each handler call in individual try/except. Handler crashes return "Internal server error" response instead of killing the entire connection. Errors logged with `exc_info=True` for full tracebacks. GUI `_do_send_message`/`_do_find_or_create_and_send` catch exceptions and emit error signal (prevents silent hang when send fails). + +### Online/Offline Status +- `db.get_user_contacts(user_id)` returns all user IDs sharing at least one conversation +- On login (`handle_login_finish`): server sends `online_users` list to new user + `user_online` to all contacts (only if user was fully offline before) +- On disconnect (`handle_client` finally block): if last connection drops, server sends `user_offline` to all contacts +- `_background_listener` routes `user_online`, `user_offline`, `online_users` to notification queue +- GUI: `_online_users: set[str]` tracks online users, green dot overlay on circular avatar in DM conversation list + green circle emoji in group info member list + +### Leave Group +- `handle_leave_group` in server.py: validates membership, blocks DM leave (len<=2 and no name), transfers creator to first remaining member if creator leaves, removes member, notifies remaining via `member_removed` +- `ChatClient.leave_group()`: sends request, cleans up local sender key states on success +- GUI: red "Leave Group" button in group info dialog, confirmation dialog, resets view on success + +### Delete Conversation +- **DMs:** Any member can delete. Only removes the deleting user from `conversation_members`. If both users delete, 0 members remain → conversation + files cleaned up. +- **Groups:** Only the creator (admin) can delete. Removes ALL members, cleans up `.enc` files from disk, deletes conversation via CASCADE. +- Server notifies remaining members via `member_removed` push. +- GUI: trash icon button in conversation header. Visible for DMs always, for groups only when user is creator. +- `chat_core.py`: cleans up local sender key states after successful delete. + +### Group Invitations +- **Flow:** `create_conversation` (group) or `add_member` → creates invitation → pushes `group_invitation` notification → invitee sees in invitation list → Accept (adds to members, notifies) / Decline (deletes invitation) +- **DMs are unaffected:** `create_conversation` for DMs still auto-adds both members +- **DB:** `group_invitations` table with UNIQUE(conversation_id, user_id) to prevent duplicates +- **Server:** `handle_accept_invitation` verifies invitation exists, adds member, deletes invitation, notifies existing members via `member_added`. `handle_decline_invitation` just deletes. +- **GUI:** `inv_list` QListWidget (max 120px, amber border) above `conv_list`. Right-click → Accept/Decline. `invitation_received` signal triggers refresh + notification banner. +- **Routing fix (IMPORTANT):** `group_invitation` must be in the notification types list in `chat_core.py:_background_listener` (~line 304). Without it, invitations get routed to `_response_queue` and the GUI never sees them. + +### Group Avatars +- Stored as files in `UPLOAD_DIR/avatars/group_{conv_id}.{ext}` (PNG or JPEG detected from magic bytes) +- `conversations.avatar_file` column stores the filename +- `list_conversations` response includes `avatar_file` so GUI knows which groups have avatars +- GUI: `_group_avatar_cache` dict, `_get_conv_avatar()` returns group avatar icon or default letter circle +- Group Info dialog shows 64px circular avatar + "Change Avatar" button (creator only) +- Periodic refresh every 2 minutes re-downloads all known group avatars + +### File Sharing +- Reuses image upload/download infrastructure (`upload_image_start/chunk/end`, `download_image`) +- `upload_image_start` accepts optional `file_type` param: `"image"` (MAX_IMAGE_BYTES=5MB) or `"file"` (MAX_FILE_BYTES=50MB) +- `ChatClient.send_file()`: reads raw file, AES-256-GCM encrypts, chunked upload, sends message with `file` field in payload (`{file_id, aes_key, iv, filename, size, mime_type}`) +- `ChatClient.download_file()`: identical to `download_image()` — chunked download + AES-GCM decrypt +- GUI: attach button is dropdown menu (Image / File), file messages render as styled cards with paperclip icon (transparent background, border) and clickable download link (`file://{file_id}`), context menu "Download file" option +- Files stored as `.enc` in UPLOAD_DIR, same as images + +### Unread Count Badges +- `_unread_counts: dict[str, int]` replaces old `_unread_convs: set` +- `_on_notification()` increments count per conversation +- `_on_conv_selected()` clears count for selected conversation +- Display: `(3) Username` with bold font, instead of old `● Username` + +### User Profiles +`user_profiles` table separated from `users` (clean separation, users = auth only). Default profile created on registration (`db.create_default_profile`). Visibility rules applied server-side in `db.get_user_profile(viewer_id)`. Avatars stored as files in `UPLOAD_DIR/avatars/{user_id}.{ext}` (not in DB). Format detection from magic bytes (PNG header vs default JPEG). UserProfileDialog shows circular cropped avatar (QPainter). + +### Prekey Replenishment + SPK Rotation +After login, `_ensure_prekeys()` is called as a background task. Checks two things: +1. **OPK count** — if < 20, generates and uploads a new batch of 50 +2. **SPK age** — server returns `spk_created_at` in `get_prekey_count` response. If SPK is >= 7 days old (`SPK_ROTATION_DAYS`), triggers rotation: saves current SPK as `prev_spk_private.bin`/`prev_spk_id.txt` (grace period), generates new SPK, uploads to server. + +### Password-Based Key Encryption (ECP1 Format) — M3 +Private keys (RSA, Ed25519) are encrypted with a custom envelope instead of `BestAvailableEncryption`: +- **Key derivation:** PBKDF2-HMAC-SHA256 with 600,000 iterations (OWASP 2023 compliant) +- **Encryption:** AES-256-GCM with the derived key, magic bytes as AAD +- **Format:** `_ECP1_MAGIC("ECP1", 4B) + salt(16B) + nonce(12B) + ciphertext_with_tag(N+16B)` +- **Backward compatibility:** `load_private_key()` and `load_ed25519_private()` detect ECP1 magic prefix. If absent, fall back to legacy PEM parsing (old `BestAvailableEncryption` format). On next save, files are re-encrypted in ECP1 format. +- **Functions:** `_encrypt_private_key()`, `_decrypt_private_key()` in `crypto_utils.py` +- **Applied to:** `serialize_private_key()` (RSA), `serialize_ed25519_private()` (Ed25519) + +### SPK Rotation (7-Day Cycle) — M4 +Signed Pre-Keys rotate periodically to limit exposure from a compromised SPK: +- **Rotation interval:** `SPK_ROTATION_DAYS = 7` (constant in `chat_core.py`) +- **Trigger:** `_ensure_prekeys()` checks `spk_created_at` from `get_prekey_count` response. If age >= 7 days, calls `_generate_and_upload_prekeys()`. +- **Grace period:** Before generating a new SPK, the current one is saved as `prev_spk_private.bin` / `prev_spk_id.txt`. Loaded on login into `_prev_spk_private` / `_prev_spk_id`. +- **Fallback on decrypt:** When `_decrypt_dm()` processes an X3DH header and decryption fails with the current SPK, it retries with the previous SPK via `_process_x3dh_header(..., spk_override=self._prev_spk_private)`. This handles in-flight X3DH initiated before rotation. +- **Server side:** `get_signed_prekey()` in `db.py` returns `created_at` column. `handle_get_prekey_count` includes `spk_created_at` (ISO format) in response. +- **Server SPK replacement:** `store_signed_prekey()` deletes old SPK and inserts new one — only one active SPK per device on server. + +### Ratchet State Rollback on Decrypt Failure — M9 +Both `DoubleRatchet.decrypt()` and `SenderKeyState.decrypt()` modify internal state (chain keys, counters, DH keys) before attempting AES-GCM decryption. If decryption fails (corrupted data, wrong key, AAD mismatch), the state would be permanently corrupted. + +**DoubleRatchet fix:** +- `_snapshot()` captures all mutable fields: `dh_pair`, `dh_remote`, `root_key`, `send_chain_key`, `recv_chain_key`, `send_n`, `recv_n`, `prev_send_n`, `skipped` dict (shallow copy) +- `decrypt()` takes snapshot before any state modification, wraps the entire DH ratchet + chain advance + AES-GCM decrypt in try/except, calls `_restore()` on failure +- Special case: skipped message decryption (no state modification needed) — if AES-GCM fails, the popped key is restored to `skipped` dict + +**SenderKeyState fix:** +- Before fast-forwarding the chain, snapshots `chain_key`, `n`, `_known_keys` (shallow copy) +- On any exception during fast-forward or AES-GCM decrypt, all three are restored + +### Message Reactions +- `ALLOWED_REACTIONS = {"thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"}` in `db.py` +- `message_reactions` table with UNIQUE(message_id, user_id, reaction) — one reaction type per user per message +- `handle_react_message`: validates UUID, reaction, membership. Adds/removes reaction. Pushes `message_reacted` to all members. +- `get_messages` response includes `reactions: [{user_id, reaction, created_at}]` per message (batch query via `db.get_reactions()`) +- GUI: React submenu in context menu with toggle (add if not present, remove if already reacted). Badges below message text showing emoji + count, highlighted border if own reaction. Real-time update via `_on_reaction_notification`. +- Forwarding uses `forwarded_from` metadata in plaintext payload — no new server protocol needed, just client convention. + +### Pinned Messages +- `messages.pinned_at` (DATETIME) and `messages.pinned_by` (CHAR(36)) columns +- `handle_pin_message`: validates UUID, membership, pins/unpins. Pushes `message_pinned`/`message_unpinned`. +- `handle_get_pinned_messages`: returns pinned message metadata for a conversation +- `get_messages` response includes `pinned_at` and `pinned_by` per message +- GUI: Pin/Unpin in context menu, pin emoji in message header, "Pinned" button in chat header opens dialog with scrollable list (double-click scrolls to message in chat) + +### @Mentions +- Client-side only — no server-side handling or special notifications +- GUI `MentionCompleter`: popup `QListWidget` shown when `@` detected at cursor, filters conversation members by prefix, inserts `@username ` on selection +- Rendering: `re.sub(r'@(\w+)', ...)` highlights mentions in blue bold (`#89b4fa`) +- Triggered from `_on_input_changed()` -> `_check_mention_trigger()` + +### Contact Key Verification (Safety Numbers / Fingerprints / QR Codes) +- **Zero server changes** — entirely client-side. Server already stores identity keys in `users.identity_key`. +- **TOFU (Trust On First Use):** `known_identity_keys.bin` stores first-seen identity key per user. Encrypted with `_local_key` (AES-256-GCM). Loaded on login/registration/pairing. +- **Explicit verification:** `verified_contacts.bin` stores verified contacts with method + timestamp. Encrypted with `_local_key`. +- **Fingerprint algorithm:** Iterated SHA-512 (5200 iterations), seed = `version(2B) + identity_key(32B) + user_id(UTF-8)`. Each iteration: `SHA-512(prev + identity_key)`. Output: first 32 bytes → 6 groups of 5 zero-padded digits. +- **Safety number:** Both users' fingerprints concatenated (lower user_id first → deterministic ordering). 64 bytes → 12 groups of 5 digits, displayed as 3 lines of 4 groups. Both sides see the same number. +- **QR code format:** `0x01 + uid_len(1B) + uid(UTF-8) + identity_key(32B)`. Generated via `qrcode` library, decoded via `pyzbar` (optional). +- **Identity key status:** `check_identity_key()` returns `"new"` | `"trusted"` | `"verified"` | `"changed"` | `"changed_verified"`. Called from `_get_user_info()`, result stored in `_user_cache` as `identity_key_status`. +- **Key change callback:** `_key_change_cb(user_id, username, old_key_hex, was_verified)` fires on key change. GUI wires to `key_change_warning` signal → warning dialog. +- **GUI indicators:** Green checkmark badge in conversation list for verified DMs (`ROLE_VERIFIED` data role, painted by `ConversationDelegate`). E2E label in chat header shows "Verified" (green) or "Encrypted" (muted) — clickable → opens `VerificationDialog`. Green checkmark next to verified members in group info. Security section in `UserProfileDialog` showing fingerprint + status. +- **VerificationDialog:** Frameless dialog showing safety number (monospace), QR code, both fingerprints, verification status. Buttons: "Mark as Verified" / "Remove Verification" / "Scan QR Code". +- **CLI:** Option 20 (Verify contact) — show safety number + mark verified. Option 21 (Show my fingerprint). +- **Tamper resistance:** `_load_known_identity_keys()` a `_load_verified_contacts()` nemají plaintext migration fallback (na rozdíl od sessions/sender keys). Soubory nikdy neexistovaly v nešifrované podobě — feature přidána po implementaci lokálního šifrování. Pokud útočník s přístupem k disku nahradí šifrovaný soubor plaintextovým JSONem, `_decrypt_local()` selže a load vrátí `{}` (prázdný stav). Útočník nemůže podvrhnout falešný verified status ani potlačit key-change warning. +- **iOS implementation spec (pro implementaci v ios_client/):** + +| Položka | Specifikace | +|---------|-------------| +| **Fingerprint algoritmus** | SHA-512 iterated 5200×. Seed = `version(2B big-endian, hodnota 0) + identity_key(32B) + user_id(UTF-8)`. Každá iterace: `SHA-512(prev_result + identity_key)`. Výstup: prvních 32 bytes. | +| **Fingerprint display** | 6 skupin × 5 číslic: `UInt64(bytes[i*5..": {"identity_key": "", "first_seen": "ISO8601", "last_seen": "ISO8601"}}}` | +| **Verified storage** | Keychain nebo šifrovaný soubor. JSON schema: `{"version": 1, "contacts": {"": {"identity_key": "", "verified_at": "ISO8601", "method": "safety_number\|qr_code\|manual"}}}` | +| **Šifrování storage** | AES-256-GCM klíčem z HKDF(identity_key, salt="local_storage_key", info="encrypted_chat_local"). **Žádný plaintext fallback** — pokud decrypt selže, vrátit prázdný dict. | +| **Identity key status** | `checkIdentityKey(userId, ikBytes) → "new" \| "trusted" \| "verified" \| "changed" \| "changed_verified"`. Volat při každém `getUserInfo()`. | +| **Self exclusion** | Vlastní user_id nikdy nezobrazovat jako unverified — přeskočit v conv listu i group info. | +| **UI: Conversation list** | Zelená fajfka (SF Symbol `checkmark.seal.fill`) vedle jména pro verified DM kontakty. | +| **UI: Chat header** | "Verified" (zelená) nebo "Encrypted" (šedá) pod jménem. Tap → otevře VerificationView. | +| **UI: VerificationView** | Safety number (monospace), QR kód (CIQRCodeGenerator), oba fingerprints, status. Tlačítka: "Mark as Verified" / "Remove Verification" / "Scan QR Code". | +| **UI: Key change alert** | `.alert()` s textem "Identity key for [name] has changed!" + "Accept" / "View Details". | +| **UI: Group info** | Zelená fajfka vedle verified členů (ne u sebe). | +| **Cross-platform test** | Python klient a iOS klient musí pro stejný pár (user_id_A, ik_A, user_id_B, ik_B) vypočítat identický safety number. QR vygenerovaný na jedné platformě musí být čitelný na druhé. | + +### Rate Limits +- Per-IP+email window (60s): register 3/min, login 10/min, send_message 20/min +- Per-connection: 20 req/s +- Per-IP: max 10 connections, global max 200 +- Pairing: TTL 120s, max 90 poll attempts, pairing_start 10/min, pairing_poll 120/min, client polls every 2s + +### GUI Font Handling (IMPORTANT) +All widget stylesheet `font-size` declarations use `pt` (not `px`). Using `px` in Qt stylesheets sets `pixelSize` and leaves `pointSize=-1`, which causes `QFont::setPointSize: Point size <= 0` warnings on Windows. Conversion: `pt ~= px * 0.75` at 96 DPI. HTML styles inside QTextBrowser (`_render_single_message_html`) still use `px` — that's fine, QTextBrowser uses its own HTML renderer. Bold fonts for list items use `_bold_font()` helper + `item.setData(FontRole)` to avoid the same issue. + +### Phantom Users (Anti User-Enumeration) +- When a user creates a conversation with an unregistered email, the server creates a "phantom" user with `rsa_public_key = 'PHANTOM'` marker +- Phantom users have real crypto keys (Ed25519 IK, X25519 SPK + 5 OPKs) so X3DH works on the client side +- `handle_find_conversation` and `handle_create_conversation` create phantoms instead of returning "User not found" +- `handle_send_message` skips phantom recipients when storing `message_recipients` — only sender's self-encrypted copy is saved +- `phantom_user_ids: set[str]` in-memory cache loaded at startup from DB, updated on create/delete +- On registration (`handle_register_confirm`): if email belongs to a phantom, the phantom is **upgraded in-place** via `db.upgrade_phantom_user()` — preserves user_id and all FK references (invitations, conversation_members). Phantom's server-generated prekeys are deleted (real user uploads own). +- `handle_create_conversation` (groups) and `handle_add_member` create invitations for phantom users too. Push notifications only sent to non-phantom users. When phantom registers and logs in, they see pending invitations. +- Messages sent to phantom users are NOT stored and NOT recoverable after registration — this is by design (prevents user enumeration, sender sees own messages via self-encryption) +- DB functions: `db.create_phantom_user(email)`, `db.is_phantom_user(user_id)`, `db.delete_phantom_user(user_id)`, `db.upgrade_phantom_user(phantom_id, username, rsa_public_key_pem, identity_key)`, `db.get_all_phantom_user_ids()` + +### Logout/Login Fix +- `_is_logout` flag in MainWindow prevents `closeEvent()` from calling `bridge.stop()` which killed the asyncio loop +- On logout: set `_is_logout = True`, call `bridge.logout()`, then `close()` +- `closeEvent()` only calls `bridge.stop()` if `not self._is_logout` +- This allows `main()` to re-create the login/main windows after logout + +### Server Graceful Shutdown +- SIGINT handler force-closes all writers in `connected_clients` before the asyncio server context manager exits +- Without this, `async with server:` waited forever for `handle_client` loops to finish + +### Version Negotiation +- `VERSION = "0.8.4"` constant in `protocol.py` (shared between client and server) +- `MIN_CLIENT_VERSION = "0.8.3"` — server rejects clients older than this +- Client sends `client_version` in `login_finish` request (both `login()` and `reconnect()`) +- Server logs `client_version`, returns `server_version` in `login_finish` response +- Server startup log includes version: `"Encrypted chat server v0.8.4 listening on ..."` +- iOS client (0.8.3) stays compatible — new notification types (`message_reacted`, `message_pinned`, `message_unpinned`) and payload fields (`reactions`, `pinned_at`, `forwarded_from`) are ignored by older clients + +### Privacy Overlay (Lock Screen) +Anti-forensic privacy feature for PyQt GUI client: +- **Immediate overlay:** On `QEvent.Type.ActivationChange` (window loses focus), dark overlay (`rgba(30,30,46,245)`) with lock icon covers entire window. Protects against Alt+Tab thumbnails, screen sharing, shoulder surfing. +- **Timed lock:** After `_LOCK_TIMEOUT_MS` (30s) unfocused, overlay transitions to locked state — requires password to dismiss. +- **Password verification:** `_on_unlock_attempt()` reads `identity_private.bin` from disk and calls `_decrypt_private_key(data, password)` (ECP1 format: PBKDF2-600k + AES-256-GCM). Successful decryption = correct password. +- **Lock capability detection:** `_lock_capable` flag checks if identity key file starts with `b"ECP1"`. If key is not password-encrypted (legacy/no-password), lock timer never fires (overlay still works as visual privacy screen). +- **Toggle:** Ctrl+Shift+P enables/disables the feature (default: enabled). +- **Notification handling during lock:** `_show_tray_notification()` checks `self._privacy_locked` — tray toasts continue while locked. `_on_notification()` increments unread counts and skips `mark_read` while locked. +- **Components:** `_privacy_overlay` (QWidget), `_lock_input` (QLineEdit password), `_lock_error` (QLabel), `_lock_timer` (QTimer single-shot), `_lock_hint` (QLabel status text). + +### Secure Deletion (Anti-Forensic Wipe) +`_secure_delete(path)` helper in both `chat_core.py` and `server.py`: +- Opens file with `r+b`, overwrites entire content with `os.urandom(size)`, calls `f.flush()` + `os.fsync(f.fileno())`, then `p.unlink()`. +- Fallback: if overwrite fails (permissions, etc.), falls back to standard `p.unlink(missing_ok=True)`. +- **Applied to (chat_core.py):** `_delete_opk_private()`, `_delete_session_file()`, session migration cleanup, sender key migration cleanup, message cache migration (plaintext JSON → encrypted). +- **Applied to (server.py):** Conversation delete (all `.enc`+`.tmp`), message delete (`.enc`), oversized upload chunk cleanup (`.tmp`), incomplete/invalid upload end (`.tmp`), stale upload periodic cleanup (`.enc`+`.tmp`). + +### Metadata Privacy +Four measures to minimize metadata leakage: +- **Message Padding:** `pad_plaintext()`/`unpad_plaintext()` in `crypto_utils.py`. Plaintext padded to nearest bucket size (64B..64KB) before encryption. Format: `0x01 + plaintext + random + length(4B)`. Legacy unpadded messages (prefix `{`) auto-detected by `unpad_plaintext()`. Applied on all 6 send paths + 2 decrypt paths in `chat_core.py`. +- **Log Sanitization:** `_who(session)` returns `u=XXXXXXXX d=YYYYYYYY` (truncated user_id + device_id). Group names, recipient counts, emails, and usernames removed from all server log lines (register, login, DM create, invite, rename, send_message). +- **Metadata Retention:** `db.cleanup_old_reads(days)` and `db.cleanup_old_reactions(days)` delete old interaction data in batches. Default `METADATA_RETENTION_DAYS=90`. Runs every ~1 hour (every 30th cycle of `_periodic_cleanup`). `cleanup_old_reads` joins on `messages.created_at` to only delete reads for old messages. `get_unread_counts(max_age_days)` excludes messages older than retention window — prevents phantom unreads after read cleanup. Indexes: `idx_reads_read_at` on `message_reads.read_at`, `idx_reactions_created_at` on `message_reactions.created_at`. +- **Sender Chain Minimization:** For new group messages, `sender_chain_id`/`sender_chain_n` stored in per-recipient `message_recipients.ratchet_header` instead of `messages` table. Removes persistent sender correlation from message-level DB rows. Server verifies group-only context (dummy `dh_pub` all-zeros) before extracting chain data from per-recipient header — prevents DM injection. Self-copy entries (sender's `user_id == session["user_id"]`) are skipped during chain_meta injection. `_validate_header` now accepts `{"self": true}` for per-recipient ratchet_header (fixes self-copy storage in DB). `handle_get_messages` extracts from both locations (backward compat). Push notifications still include chain data for live decrypt. +- **Architectural limitation (known):** Server still stores `messages.sender_id` and `message_recipients.user_id` — the communication graph (who talks to whom, when) remains visible to the server. Full metadata hiding (e.g. Signal's Sealed Sender, onion routing) is a fundamentally different architecture requiring sender anonymity at the protocol level. Current metadata privacy measures reduce *unnecessary* metadata exposure (message length, log PII, interaction history retention, group chain correlation) but do not hide the communication graph itself. + +## Conventions + +- Server handlers: `handle_(msg, session, writer)` — registered in dispatch table in `handle_client()` +- DB functions: one `get_connection()` per call, `cursor(dictionary=True)`, returns dicts +- Binary data: always base64 in protocol (`encode_binary`/`decode_binary`) +- GUI signals: bridge emits `pyqtSignal`, MainWindow connects in `_connect_signals()` +- Error responses: `{"status": "error", "data": {"message": "..."}}` +- Notification decrypt returning `None` = control message, skip silently +- GUI stylesheet font sizes: always `pt`, never `px` (see Font Handling section above) +- File sharing reuses image upload infrastructure with `file_type` parameter +- Avatar files stored in `UPLOAD_DIR/avatars/` — user: `{user_id}.{ext}`, group: `group_{conv_id}.{ext}` + +## Aktuální stav práce + +### ✅ Dokončeno (tato session) +- **iOS client (Swift/SwiftUI)** — Full native iOS port in `ios_client/` (47 files, ~6200 lines). Wire-compatible with Python server. Crypto layer: CryptoKit (AES-GCM, HKDF, Ed25519, X25519), pure Swift GF(2^255-19) for Ed→X conversion, Security.framework RSA-4096, ECP1 key encryption (PBKDF2 600k + AES-GCM). Protocol: Network.framework TCP+TLS, newline-delimited JSON. Core: ChatClient actor (1644 lines) with X3DH, Double Ratchet, Sender Keys, per-device sessions, SPK rotation. UI: SwiftUI views (login/register, conversation list, chat with message bubbles, group info, profile, search). Server-side RSA PSS compatibility fix: `PSS.MAX_LENGTH` → `PSS.AUTO` in `rsa_verify()` to accept both Python (max salt) and iOS (hash-length salt) signatures. +- Logout/login bug fix — `_is_logout` flag prevents bridge.stop() on logout +- Hover text readability — `color: #cdd6f4;` added to `QListWidget::item:hover` +- File card background — `background:transparent; border:1px solid #45475a` +- Delete conversations — full stack (db, server, chat_core, gui), DMs + groups (creator-only), file cleanup from disk +- Group invitation system — full stack (schema, db, server, chat_core, gui), create/accept/decline, real-time notification, invitation list UI +- Circular avatars in conversation list — QPainter circular crop, default letter avatars, online green dot overlay +- Group avatar support — upload/download, display in group info dialog, "Change Avatar" button (creator only) +- Server graceful shutdown — force-close connected clients on SIGINT +- Profile dialog avatar circular crop — QPainter in UserProfileDialog._on_avatar_loaded +- Periodic refresh timer — 2-minute QTimer re-downloads avatars + invitations +- Group invitation notification fix — `group_invitation` added to `_background_listener` notification types +- Delete button in conversation header — trash icon for DMs always, groups creator-only +- File cleanup on conversation delete — `db.get_conversation_file_ids()` + unlink `.enc` files +- Removed right-click "Delete conversation" from conversation list context menu +- README.md updated +- **H4 Race conditions fix** — 4 asyncio.Lock guards (`_clients_lock`, `_conn_lock`, `_pairing_lock`, `_uploads_lock`) pro všechny sdílené mutable struktury v server.py. `_notify_users()` + `_notify_users_individual()` helpery. Rate limit memory cleanup v periodic task. Všechny I/O operace mimo kritické sekce. +- **Unread counts pro offline uživatele** — `db.get_unread_counts()` dotaz přes `message_reads` + `message_recipients`, server vrací `unread_count` v `list_conversations`, GUI populuje `_unread_counts` ze serverových dat (max z server vs local). Opravuje bug kdy offline uživatel po přihlášení neviděl nepřečtené zprávy. +- **C6 Path traversal fix** — `_UUID_RE` regex + `_valid_file_id()`, `_safe_upload_path()`, `_safe_avatar_path()` helpery v server.py. UUID validace v `handle_upload_image_start`, `handle_download_image`. `is_relative_to()` guard ve všech path konstrukcích: upload start/end, download, delete_message file cleanup, delete_conversation file cleanup, _cleanup_uploads, get/update avatar, get/update group avatar. Celkem 10 guardovaných míst. +- **C3+H1+M13 Lokální šifrování + permissions** — `derive_local_storage_key()` v crypto_utils.py (HKDF z identity key, odlišný salt/info od self-encryption key). `_encrypt_local()`/`_decrypt_local()` helpery v chat_core.py (AES-256-GCM, formát: nonce(12)+tag(16)+ct). `_save_session`/`_load_session`, `_save_sender_key_state`/`_load_sender_key_state`, `_save_recv_sender_key`/`_load_recv_sender_key` — volitelný `local_key` parametr, při nastavení šifruje/dešifruje, `chmod 0o600` na soubory. `ChatClient._local_key` derivováno při login/registraci/pairingu. Transparentní migrace: pokud dešifrování selže, zkusí plaintext a re-uloží šifrovaně. `os.chmod(d, 0o700)` na všechny `mkdir()` v get_key_dir, opk_private, sessions, sender_keys, sender_keys_recv, message_cache. `os.chmod(p, 0o600)` na plaintext fallback message cache. +- **H7 Avatar path traversal** — `_safe_avatar_path()` guard na handle_get_avatar, handle_get_group_avatar + defense-in-depth na handle_update_avatar, handle_update_group_avatar. +- **Multi-device support (per-device sessions)** — `devices` table, `device_id` columns on prekeys/messages/recipients/sender_keys. Server: device registration at login, `writer_device_map`, per-device key bundles (`device_bundles` array), per-device notification routing (`device_entries`), `list_devices`/`remove_device` handlers. Client: `device_id` persistence, sessions keyed by `"user_id:device_id"`, `_get_device_bundles()` with 5-min TTL cache, per-device encryption in `_send_dm`/`send_image`/`send_file`/`_distribute_sender_key`, `sender_device_id` in decrypt routing, `decrypt_notification()` handles `device_entries` format. Pairing simplified: only RSA + identity key transfer, new device generates own SPK + OPKs. Migration: old session files auto-migrated, backward compat with old clients/servers. H12 OPK race condition fixed (SELECT FOR UPDATE). +- **Connection resilience fixes** — `_background_listener` fails all pending futures with `ConnectionError` on disconnect (prevents hang). `send_and_recv` has 30s timeout + catches `ConnectionError`. Server dispatch has per-message try/except (handler crash no longer kills connection). GUI `_do_send_message`/`_do_find_or_create_and_send` catch exceptions and emit error signal. +- **DB transaction fix** — `db.get_key_bundles_for_user()` had "Transaction already in progress" error because mysql-connector starts implicit transactions. Fixed with `conn.commit()` before `conn.start_transaction()`. +- **H5+H6 Protocol error handling** — `decode_binary()` catches `binascii.Error` → `ValueError`. `parse_message()` catches `JSONDecodeError`/`UnicodeDecodeError` → `ValueError`. Server dispatch already handles `ValueError` from `read_message()` gracefully. +- **H3+H13 Anti-enumeration** — `handle_register_start` returns same "ok" response for existing email (no "Email already in use" leak). `handle_login_start` returns fake challenge for non-existent email. `handle_login_finish` returns generic "Invalid credentials" for all failure cases. `get_user_info` moved behind auth barrier (requires login). +- **H8 Password memory cleanup** — `register()`, `login()`, `pairing_wait()` convert password to `bytearray`, zero out in `finally` block after key derivation. +- **H10 Image validation** — `_safe_load_image()` helper validates size (<10MB) and dimensions (<8192px) before `QImage.fromData()`. Applied to all 6 image loading locations in gui_client.py. +- **H11 Filename sanitization** — `_safe_filename()` helper strips path components via `os.path.basename()`. Applied to save dialogs and image dialog title. +- **C1+C2+C5 DoS hardening** — C1: `LimitOverrunError` now drains buffer and raises `ValueError` (server sends error response instead of silent disconnect; memory already protected by `limit=` on StreamReader). C2: `MAX_SENDER_KEY_SKIP` reduced from 1000 to 256 (matches DoubleRatchet `MAX_SKIP`). C5: `handle_upload_image_end` validates `received_bytes == file_size` before completing upload. M12 (upload end size validation) also fixed by C5. +- **M2+M8+M10+M11 Security hardening batch** — M2: SenderKeyState HKDF salt changed from `b""` to `b"\x00" * 32` (matches X3DH convention). M8: `_valid_file_id()` renamed to `_valid_uuid()`, UUID validation added to all handlers accepting client-provided `conv_id`, `user_id`, `message_id`, `device_id`. M10: `handle_mark_read` caps `message_ids` to 500 (prevents slow SQL DoS). M11: `handle_pairing_start` generates `poll_token` (secrets.token_hex(16)), `handle_pairing_poll` requires and validates it via `secrets.compare_digest()` — prevents unauthorized poll/payload extraction. +- **H2+H14 TLS hardening** — `TLS_INSECURE` a `TLS_AUTOGEN` nyní vyžadují `ENVIRONMENT=dev` (RuntimeError bez toho). Warning log na serveru i klientovi když TLS vypnuté. C4 (OPK file permissions) bylo již opraveno v C3+H1+M13 batchi. +- **Online dot fix + sorting** — Fixed timing issue where `online_users` signal processed before conv list populated. `_rebuild_conv_list()` sorts: favorites first → online DMs → rest alphabetically. +- **Favorites system** — Right-click context menu on conversation list → Add/Remove from favorites. Star indicator (★). Persisted to `favorites.json` in user key directory. +- **Group renaming** — Full stack: `db.update_conversation_name()`, `handle_rename_conversation` (server, creator-only, max 100 chars), `rename_conversation()` (chat_core), "Rename" button in group info dialog (GUI), `conversation_renamed` push notification to all members. +- **M3+M4+M9 Security hardening** — M3: PBKDF2 600k iterations (`_encrypt_private_key`/`_decrypt_private_key` s ECP1 formátem, backward compat pro PEM). M4: SPK rotace každých 7 dní, `spk_created_at` v `get_prekey_count`, grace period s `prev_spk_private.bin`, fallback v `_process_x3dh_header`/`_decrypt_dm`. M9: `_snapshot()`/`_restore()` v `DoubleRatchet.decrypt()`, snapshot/restore v `SenderKeyState.decrypt()`. +- **Phantom invitation fix** — Phantom users now receive group invitations. `handle_create_conversation` and `handle_add_member` create invitations for phantoms (no push notification). `handle_register_confirm` upgrades phantom in-place via `db.upgrade_phantom_user()` (preserves user_id + FK references). `handle_add_member` creates phantom for unregistered emails (same as `create_conversation`). After registration, user sees pending invitations on login. +- **Message Search** — Client-side search through decrypted message cache. `ChatClient.search_messages()` searches local cache. GUI: collapsible search bar (Ctrl+F) with prev/next navigation, match count, yellow/orange highlighting in message HTML. Escape closes search. Search button in chat header. +- **Session Recovery** — `session_reset` protocol message + `handle_session_reset` server handler. `ChatClient.reset_session()` deletes local session + notifies peer. Peer handles `session_reset` notification by deleting their session. Next message auto-creates new session via X3DH. GUI: "Reset session with sender" context menu on undecryptable messages, status bar notification on incoming reset. +- **L8 Phantom user DB inflation fix** — `_valid_email()` helper validates email format before phantom creation in `handle_find_conversation`, `handle_create_conversation`, `handle_add_member`. `db.cleanup_stale_phantoms(30)` deletes phantom users older than 30 days with no active conversations with real users. Runs in `_periodic_cleanup` every 10 minutes, refreshes in-memory `phantom_user_ids` cache. +- **M6 TOCTOU race fix** — `db.remove_conversation_member_atomic()` returns bool (True if row existed). Used in `handle_remove_member` (checks return value, returns error if already removed) and `handle_leave_group`. Defense-in-depth: pre-checks remain for user-friendly errors, atomic operation prevents double-removal. +- **Local-first messages + multi-device received messages fix** — Self-encrypted copies for received messages (not just sent). `batch_reencrypt_messages()` changed to INSERT ON DUPLICATE KEY UPDATE (upsert) — allows creating SELF_DEVICE_ID rows for received messages. Server-side Python dedup in `handle_get_messages` (prefers device-specific over SELF_DEVICE_ID). `after_ts` parameter for incremental sync. `get_deleted_since` protocol + handler for deletion sync. `_pending_self_encrypt` queue + `_flush_self_encrypt()` for background self-encryption of received notifications via `asyncio.ensure_future()`. +- **Detailed server logging** — `_who(session)` helper for consistent log formatting. 25+ tagged log lines: `[CONN]`, `[LOGIN]`, `[REGISTER]`, `[MSG]`, `[FETCH]`, `[READ]`, `[UPLOAD]`, `[DOWNLOAD]`, `[CONV]`, `[INVITE]`, `[LEAVE]`, `[RENAME]`, `[DELETE]`, `[AVATAR]`, `[PREKEYS]`, `[X3DH]`, `[ROTATE]`, `[DEVICE]`, `[REENCRYPT]`, `[SESSION]`, `[MEMBER]`, `[LIST]`, `[ERROR]`. +- **Version bump to 0.8.4** — `VERSION = "0.8.4"`, `MIN_CLIENT_VERSION = "0.8.3"` in protocol.py. Server rejects clients older than 0.8.3. iOS client (0.8.3) still works — new notification types and payload fields are silently ignored. +- **Message Reactions** — Emoji reactions on messages (`thumbsup`, `heart`, `laugh`, `surprised`, `sad`, `thumbsdown`). `message_reactions` DB table. `db.add_reaction()`/`db.remove_reaction()`/`db.get_reactions()`. Server: `handle_react_message` + `message_reacted` push notification. `get_messages` response includes `reactions` array per message. GUI: reaction submenu in context menu, toggle (add/remove), emoji badges below messages (grouped by reaction, highlighted if own), real-time updates via notification. CLI: option 16. +- **Message Forwarding** — Forward messages to other conversations. `ChatClient.forward_message()` sends as normal `send_message` with `forwarded_from` metadata field (`{sender, conversation_id, message_id}`). No new server endpoint needed. GUI: "Forward" in context menu, conversation picker dialog, "Forwarded from" header with blue left border in message rendering. CLI: option 19. +- **Pinned Messages** — Pin/unpin messages in conversations. `messages.pinned_at`/`pinned_by` columns. `db.pin_message()`/`db.unpin_message()`/`db.get_pinned_messages()`. Server: `handle_pin_message` + `handle_get_pinned_messages` + `message_pinned`/`message_unpinned` push notifications. GUI: "Pin"/"Unpin" in context menu, pin emoji indicator in message header, "Pinned" button in header opens dialog with list (double-click scrolls to message), real-time updates. CLI: options 17-18. +- **@Mentions** — `@username` highlighting in messages. Client-side only (no server-side handling). GUI: `MentionCompleter` popup autocomplete activated when typing `@` in message input, filters current conversation members, inserts `@username` on selection. Message rendering: `@word` patterns highlighted in blue bold (`#89b4fa`). CLI: visible as plain `@username` text. +- **Drag & Drop souborů/obrázků** — Přetažení souboru na chat oblast (message input nebo message area) automaticky odešle. Obrázky (`.png`, `.jpg`, `.jpeg`, `.gif`, `.bmp`, `.webp`) se posílají přes `send_image`, ostatní přes `send_file`. Vizuální feedback: modrý dashed border při drag-over. Drop je aktivní pouze když je vybraná konverzace (`drop_enabled` flag na `MessageInput`, `current_conv_id` check v event filteru na `message_area`). +- **Registration prekey upload fix** — `_ensure_prekeys()` now detects missing SPK on server (empty `spk_created_at`) and forces upload with `need_new_spk=True`. Previously, registration uploaded prekeys before login (no session → server rejected with "Not logged in"), and the login flow only uploaded OPKs without SPK. Also added warning log in `_generate_and_upload_prekeys()` when upload fails. +- **Message sender identification fix** — `_render_single_message_html()` now uses `sender_id` (user_id UUID) instead of `sender` (username) for `is_me` detection. Fixes incorrect message alignment when two users share the same display name. +- **Privacy Overlay (lock screen)** — Anti-forensic privacy feature (PyQt only). On window deactivation: immediate dark overlay with lock icon hides all chat content (protects against Alt+Tab thumbnails, screen sharing, shoulder surfing). After 30 seconds unfocused: locks and requires login password to unlock (verified by decrypting identity key from disk via ECP1/PBKDF2-600k). Ctrl+Shift+P toggles on/off. Tray notifications continue working while locked. Messages arriving during lock are counted as unread and NOT marked as read until user unlocks. `_lock_capable` flag auto-detects if identity key is password-encrypted (ECP1 format) — lock requires password only when available. +- **Secure Deletion (anti-forensic wipe)** — `_secure_delete(path)` helper overwrites file with `os.urandom()` + `fsync` before `unlink`. Applied to all sensitive file deletions: OPK private keys, session files (reset + migration), received sender key migration, message cache migration (plaintext JSON cleanup) in chat_core.py; `.enc` and `.tmp` files on conversation delete, message delete, oversized upload cleanup, incomplete upload cleanup, stale upload periodic cleanup in server.py. Fallback to standard `unlink` if overwrite fails. +- **UI redesign (Signal/Telegram look)** — `theme.py` theme systém (ThemeColors dataclass, DARK_THEME Catppuccin Mocha, LIGHT_THEME Signal-inspired, ThemeManager singleton s persistence + live switching). Widget-based message bubbles (MessageBubble QFrame s QPainter rounded rect). ConversationDelegate (custom painting: avatar, name, preview, timestamp, unread badge). Redesigned chat header (avatar, name, status, action buttons). Pill-shaped input (auto-resize, 2-line min). Tabbed login (Login/Register/Link Device). Frameless dialogs (`_make_frameless`). Theme toggle (sun/moon) v settings. +- **Reakce/piny persistentní v cache** — `ChatClient.update_message_in_cache()` synchronně aktualizuje pole zprávy v šifrovaném message cache na disku. Volá se při přidání/odebrání reakce, pin/unpin (z kontextového menu i z push notifikací). Opravuje bug kdy reakce a piny mizely po přepnutí konverzace a návratu (incremental sync nenačítal stará data ze serveru). +- **Context menu fallback na zprávách** — `MessageBubble` context menu policy změněna z `CustomContextMenu` na `DefaultContextMenu`. S `CustomContextMenu` Qt emitoval nepřipojený signál místo volání `contextMenuEvent()` override — kontext menu se neukázalo. Event filter zůstává jako primární handler, `contextMenuEvent` je fallback. +- **Forward obrázků a souborů** — `forward_message()` v chat_core.py přeposílá kompletní `image`/`file` metadata (file_id, AES klíč, IV, thumbnail, filename, size). Dříve se posílal jen textový popis `[Forwarded image: ...]`. Šifrovaný soubor je na serveru, stačí přeposlat metadata. +- **Delete message okamžitý** — Smazání zprávy se projeví lokálně ihned (bez čekání na reload). Server posílá `message_deleted` notifikaci jen ostatním, ne odesílateli — opraveno lokálním označením zprávy jako smazané + uložením do cache + re-renderem před odesláním na server. +- **Frameless confirmation dialogy** — `_confirm_dialog()` helper nahrazuje `QMessageBox.question()` — frameless dialog bez systémové lišty, s červeným "Delete" tlačítkem. Aplikováno na Delete Message a Reset Session dialogy. +- **Reakce — explicitní barva textu** — Reaction badge QLabel nyní má explicitní `color: {t.text_primary}` aby byl text viditelný v obou režimech (předtím spoléhal na CSS dědičnost). +- **SPK/OPK encryption + brute-force lockout** — SPK/OPK private keys now encrypted with AES-256-GCM via `_local_key` (derived from identity key via HKDF). Transparent migration from plaintext. Client-side brute-force lockout: exponential backoff `min(2^N, 300)` seconds after N failed password attempts. Lockout state in `login_lockout.json`. Applied to `ChatClient.login()` and GUI privacy overlay unlock (`_on_unlock_attempt`). `_clear_lockout()` on success. +- **Contact Key Verification** — Signal-style safety numbers, fingerprints, QR codes. TOFU key tracking (`known_identity_keys.bin`) + explicit verification (`verified_contacts.bin`), both encrypted with `_local_key`. `crypto_utils.py`: `compute_fingerprint()` (iterated SHA-512, 5200x), `format_fingerprint()` (30 digits), `compute_safety_number()` (60 digits, symmetric), `encode_verification_qr()`/`decode_verification_qr()`. `chat_core.py`: `check_identity_key()` (TOFU with "new"/"trusted"/"verified"/"changed"/"changed_verified"), `verify_contact()`, `unverify_contact()`, `accept_key_change()`, `get_verification_status()`, `get_safety_number()`, `get_my_fingerprint()`, `get_peer_fingerprint()`, `get_verification_qr_data()`, `verify_qr_code()`. GUI: `VerificationDialog` (safety number, QR code, fingerprints, verify/unverify buttons, QR scan), green checkmark in conversation list for verified DMs (`ROLE_VERIFIED`), E2E label clickable with verification status, key change warning dialog, security section in `UserProfileDialog`, verified badge in group info member list. CLI: options 20 (verify contact) and 21 (show fingerprint). Zero server changes. +- **Metadata Privacy (Ochrana metadat)** — Čtyři opatření pro minimalizaci metadat: **(A) Message Padding** — `pad_plaintext()`/`unpad_plaintext()` v crypto_utils.py. Bucketed padding (64/128/256/512/1K/2K/4K/8K/16K/32K/64K). Formát: `0x01 + plaintext + random_padding + pad_length(4B)`. Prefix `0x01` rozliší od legacy JSON (začíná `{`). Aplikováno na všech 6 send cest v chat_core.py (send_message, distribute_sender_key, forward_message, send_image, send_file, reencrypt_history) + unpad na 2 decrypt cestách (_decrypt_dm, _decrypt_group). **(B) Log Sanitizace** — `_who()` vrací `u=XXXXXXXX d=YYYYYYYY` místo `username (email) [device]`. Group names a recipient counts odstraněny z logů. **(C) Metadata Retention** — `db.cleanup_old_reads()`/`db.cleanup_old_reactions()` mažou záznamy starší N dní (default `METADATA_RETENTION_DAYS=90`). Batch delete (10k/iterace). Spouští se v `_periodic_cleanup()` každé 2 minuty. **(D) Sender Chain Přesun** — `sender_chain_id`/`sender_chain_n` se pro nové group zprávy ukládají do `message_recipients.ratchet_header` místo `messages` tabulky. Server `handle_get_messages` extrahuje z obou míst (backward compat). Notifikace stále posílají chain data pro live decrypt. + +### 🐛 Známé bugy a problémy +- **Sender Key Redistribution (High Priority):** New group member can't decrypt old messages. On `add_member`, existing members should re-create and redistribute sender keys. +- ~~**Database Connection Pooling:**~~ ✅ OPRAVENO — `MySQLConnectionPool(pool_size=10)`. +- ~~**Duplicate FETCH after send (GUI):**~~ ✅ OPRAVENO — `send_message` vrací payload lokálně, GUI appenduje bez re-fetch. +- ~~**Group delete confirmation message is generic**~~ ✅ OPRAVENO — frameless dialog s kontextovým textem. +- ~~**Reakce a piny mizely po přepnutí konverzace:**~~ ✅ OPRAVENO — `update_message_in_cache()` ukládá na disk. +- ~~**Forward obrázku/souboru posílal jen text:**~~ ✅ OPRAVENO — přeposílá kompletní image/file metadata. +- ~~**Delete message se neprojevil okamžitě:**~~ ✅ OPRAVENO — lokální smazání + cache update před serverovým voláním. +- ~~**Delete/Reset dialogy měly systémovou lištu:**~~ ✅ OPRAVENO — frameless `_confirm_dialog()`. + +### ⏭️ Další kroky (TODO) + +#### Bezpečnostní opravy (priorita dle auditu) +1. **C6 (CRITICAL): Path traversal přes file_id** — `handle_upload_image_start` vytváří soubor `UPLOAD_DIR / f"{file_id}.tmp"` bez validace. Útočník může `../../...` a zapisovat/mazat mimo UPLOAD_DIR. Řešení: validovat UUID formát, ověřit `path.resolve().is_relative_to(UPLOAD_DIR.resolve())`. +2. ~~**H12 (HIGH): OPK race condition v db.get_key_bundle()**~~ ✅ OPRAVENO (součást multi-device — SELECT FOR UPDATE v consume_one_time_prekey + get_key_bundle) +3. **H3+H13: User enumeration** — `get_user_info` dostupné bez auth, vrací identity_key pro libovolný email. `register_start`/`login_start` vrací jednoznačné chyby. Řešení: auth pro `get_user_info`, generické odpovědi pro register/login. +4. ~~**H2+H14: TLS hardening**~~ ✅ OPRAVENO — `TLS_INSECURE` a `TLS_AUTOGEN` vyžadují `ENVIRONMENT=dev`. Warning log při vypnutém TLS. +5. ~~**C1+C2+C5**~~ ✅ OPRAVENO — DoS vektory (LimitOverrunError → ValueError, MAX_SENDER_KEY_SKIP 256, upload completeness check) +6. **C3+C4+H1** — Šifrování dat na disku (message cache, sessions, OPK permissions, `chmod 0o700` pro adresáře) +7. **H5+H6** — Error handling v protokolu (base64, JSON) +8. **H7** — Path traversal v avatar souborech (`resolved_path.is_relative_to()`) +9. ~~**M11 (MEDIUM): Pairing poll DoS**~~ ✅ OPRAVENO — poll_token binding (secrets.token_hex(16) + secrets.compare_digest) +10. ~~**M12: Upload end bez validace velikosti**~~ ✅ OPRAVENO (součást C5 fixu — `handle_upload_image_end` validuje `received_bytes == file_size`) +11. ~~**L8: Phantom user DB inflation**~~ ✅ OPRAVENO — email validace + periodic cleanup stale phantoms (30 dní) +12. **Version negotiation** — `VERSION = "0.8"` v protocol.py, klient posílá `client_version` při loginu, server loguje a vrací `server_version` + +#### Před nasazením do produkce (checklist) +1. **TLS certifikáty** — Získat certifikát (Let's Encrypt / vlastní CA). Nastavit `TLS_ENABLED=true`, `TLS_CERT_FILE`, `TLS_KEY_FILE` v `.env`. Ověřit že `TLS_INSECURE` a `TLS_AUTOGEN` NEJSOU nastavené (vyžadují `ENVIRONMENT=dev`). Na klientovi nastavit `TLS_ENABLED=true` a případně `TLS_CA_FILE` pokud vlastní CA. +2. **Email validace** — Zapnout `_valid_email()` kontrolu v `handle_find_conversation`, `handle_create_conversation`, `handle_add_member` (kód existuje v server.py, volání zakomentována). Teď vypnuto protože dev prostředí používá emaily bez @. +3. **MySQL TLS** — Přidat SSL parametry do `db.get_connection()` (`ssl_ca`, `ssl_cert`, `ssl_key`) pokud DB běží na jiném stroji. +4. **Connection pooling** — Nahradit `get_connection()` za `mysql.connector.pooling.MySQLConnectionPool(pool_size=10)`. +5. **SMTP** — Nastavit reálný SMTP server pro registrační kódy (`SMTP_HOST`, `SMTP_PORT`, `SMTP_USER`, `SMTP_PASS`, `SMTP_FROM`). +6. **UPLOAD_DIR** — Ověřit že `UPLOAD_DIR` je na persistentním disku s dostatkem místa, správnými právy (0o700). +7. **Rate limity** — Přezkoumat limity pro produkční zátěž (registrace 3/min, login 10/min, send_message 20/min, max 10 spojení/IP). +8. **Packaging** — Zabalit klienta (pyinstaller / cx_Freeze) pro distribuci. Po zabalení zvážit auto-update mechanismus a `get_version` endpoint. +9. **Penetrační testy** — Provést před ostrým nasazením (viz sekce níže). +10. **Backup** — Nastavit pravidelný backup MySQL databáze + `UPLOAD_DIR`. + +#### Penetrační testy +- Naplánovat a provést manuální penetrační testy zaměřené na: + - Path traversal (file_id, avatar_file) + - DoS vektory (readuntil, sender key fast-forward, upload flooding) + - Race conditions (OPK reuse, membership TOCTOU) + - User enumeration (register, login, get_user_info) + - TLS downgrade / MITM bez TLS + - Pairing session hijacking + - Memory exhaustion (rate_limits, phantom users, message_ids) +- Vytvořit testovací skripty pro automatizované security testy +- Zdokumentovat výsledky a opravit nalezené problémy + +#### Ke zvážení +- **Auto-update klientů** — distribuce aktualizovaných souborů klientům před login/registrací. Řešit až po kompilaci/packagingu (pyinstaller apod.). Mechanismus: server verze check → klient stáhne nové soubory → restart. +- **Server version check endpoint** — po packagingu mít jednoduchý endpoint (např. `get_version`), který vrací min/aktuální podporovanou verzi klienta + URL/metadata pro update; klient může před loginem ověřit kompatibilitu a nabídnout update. Vhodné i pro postupné vypínání starých klientů. + +#### Monetizace — plán + +**Princip:** Oddělený platební server (KYC/AML compliant) od chat serveru (anonymní). Platba generuje jednorázový premium kód, chat server zná jen "user_id aktivoval kód". Žádný přímý link platba↔chat identita. + +**Architektura:** +``` +Platební server (Stripe/Paddle): zákazník → platba → vygeneruje premium_code + sdílí jen: premium_code (jednorázový string) +Chat server: user_id → aktivuje premium_code → premium do data X +``` +- Platební server neví jaký user_id kód použil +- Chat server neví kdo kód koupil +- AML splněno na platební straně, privacy zachováno na chat straně + +**Free vs Premium tier:** +| | Free | Premium | +|---|---|---| +| Konverzace | 5 | neomezeno | +| Soubory | 10 MB | 50 MB | +| Zařízení | 1 | 5 | +| Message retention | 30 dní | neomezeno | +| Skupiny | max 10 členů | neomezeno | + +**Implementace (chat server):** +- Tabulka `premium_codes(code VARCHAR(64) PK, plan ENUM('premium_30','premium_365'), created_at, redeemed_by CHAR(36) nullable FK, redeemed_at nullable)` +- Tabulka `user_plans(user_id PK FK, plan ENUM('free','premium'), expires_at DATETIME nullable, message_retention_days INT DEFAULT 30)` +- Sloupec `messages.expires_at` (nullable DATETIME) — NULL = neomezená retence +- Handler `redeem_code` — validuje kód, aktivuje plán, nastaví expires_at +- `handle_send_message` — kontroluje limity (konverzace, velikost) +- Periodic cleanup: `DELETE FROM messages WHERE expires_at IS NOT NULL AND expires_at < NOW()` + CASCADE + `.enc` smazání +- Retence podle konverzace: pokud alespoň jeden člen premium → vyšší retence +- Klient: indikace expirace zpráv, upozornění na limity, "Upgrade" tlačítko + +**Implementace (platební server — oddělený deployment):** +- Jednoduchý web (Flask/FastAPI) s Stripe Checkout +- Generuje premium_code, uloží do sdílené DB nebo pošle přes API +- AML/KYC řeší Stripe (PCI DSS, SCA, reporting) + +**Další revenue streams:** +- **Enterprise/B2B licence** — self-hosted deployment, LDAP/SSO, admin dashboard, SLA. Faktura na IČO. +- **Šifrovaný cloud backup** — export/import historie šifrované uživatelským heslem (nezávisle na retenci) + +#### Funkční vylepšení +1. **Sender Key Redistribution** — on add_member, redistribute sender keys to all members including new one +2. ~~**Device Linking fix**~~ ✅ — replaced with true multi-device (per-device sessions, simplified pairing) +3. ~~**SPK Rotation**~~ ✅ — periodic rotation with grace period (implemented in M4 fix) +4. **Typing Indicators** (budoucí) — `typing_start`/`typing_stop` protocol + GUI indicator (3s timeout, debounce) +5. ~~**CLI support**~~ ✅ — profiles, file sharing, invitations, leave/rename/delete, search, devices in `client.py` +6. ~~**Message search**~~ ✅ — client-side search through decrypted cache, Ctrl+F toggle, highlight + navigation +7. ~~**Session Recovery**~~ ✅ — `session_reset` protocol, auto-recreate via X3DH on next message +8. ~~**Connection Pooling**~~ ✅ — `MySQLConnectionPool(pool_size=10)`, lazy init, `DB_POOL_SIZE` env var. `conn.close()` vrací do poolu. +9. ~~**Version negotiation**~~ ✅ — `VERSION = "0.8.4"` in protocol.py, client sends `client_version` at login, server rejects clients < MIN_CLIENT_VERSION +10. **Delivery Receipts** — `message_delivered` notifikace po přijetí na zařízení (1 fajfka = odesláno, 2 fajfky = doručeno, modré = přečteno). Nová tabulka `message_deliveries` nebo rozšíření `message_reads`. +11. ~~**Reakce na zprávy (Message Reactions)**~~ ✅ — emoji reakce na zprávy. Tabulka `message_reactions`. Push notifikace `message_reacted`. GUI: emoji badges pod zprávou, submenu v kontext menu. +12. ~~**Přeposílání zpráv (Message Forwarding)**~~ ✅ — kontext menu "Forward", výběr konverzace, odeslání s `forwarded_from` metadatem. GUI: "Forwarded from" header. +13. ~~**Připnuté zprávy (Pinned Messages)**~~ ✅ — `messages.pinned_at`/`pinned_by` sloupce. `pin_message`/`unpin_message`/`get_pinned_messages` protokol. GUI: pin ikona + dialog s připnutými zprávami. +14. ~~**Zmínky (@mentions)**~~ ✅ — parsování `@username` v textu, autocomplete při psaní @, zvýraznění zmínek v modré. Klient-side only (bez server notifikací). +15. ~~**Contact Key Verification**~~ ✅ — Signal-style safety numbers (60 digits, symmetric), fingerprints (30 digits), QR codes. TOFU key tracking + explicit verification. GUI: VerificationDialog, green checkmark v conv listu, E2E label status, key change warning. CLI: options 20+21. Zero server changes. + +#### Optimalizace serveru +1. ~~**DB Connection Pooling**~~ ✅ — viz bod 8 výše. +2. ~~**Oprava duplicitních FETCH v GUI**~~ ✅ — `send_message` vrací payload lokálně, GUI appenduje bez re-fetch. Dedup guard v `_on_notification`. +3. ~~**Batch prekey replenishment**~~ ✅ — `ensure_prekeys` handler na serveru (get_prekey_count + upload v jednom roundtripu). `_generate_and_upload_prekeys_batch()` v chat_core. +4. ~~**Server-side message count**~~ ✅ — `get_messages` response obsahuje `total_count`. `db.count_messages()` funkce. +5. **Prepared statements / query cache** — pro často opakované dotazy (get_messages, list_conversations) připravit prepared statements. +6. **WebSocket upgrade** — dlouhodobě nahradit raw TCP za WebSocket pro lepší kompatibilitu s firewally, load balancery, a web klienty. + +#### Mobilní push notifikace (budoucí — iOS + Android) +- Tabulka `push_tokens(user_id, device_id, platform ENUM('ios','android'), token, created_at)` +- Server: při `new_message` pokud cílový uživatel nemá aktivní TCP spojení → odeslat přes APNs (iOS) / FCM (Android) +- Obsah push: jen "Nová zpráva od X" (E2EE = server nezná plaintext) +- iOS: `UserNotifications` framework, registrace tokenu při loginu, `didReceiveRemoteNotification` +- Android: Firebase Cloud Messaging, `FirebaseMessagingService` +- Server-side: `aioapns` (Python APNs library) + `firebase-admin` (FCM SDK) + +## Bezpečnostní audit (Security Audit) + +Kompletní audit provedený přes všechny soubory projektu. Nálezy seřazené podle závažnosti. + +### 🔴 CRITICAL — Okamžitě řešit před nasazením + +#### ~~C1. readuntil() bez limitu → memory exhaustion (protocol.py:62)~~ ✅ OPRAVENO +`ProtocolReader.read_message()` volá `readuntil(b"\n")`, které načte CELOU zprávu do paměti PŘED kontrolou velikosti. Útočník pošle gigabyty dat bez newline → server spadne na out-of-memory. +```python +line = await self._reader.readuntil(b"\n") # buffers everything first! +if len(line) > MAX_MESSAGE_BYTES: # too late +``` +**Řešení:** Implementovat framing s hlavičkou obsahující velikost zprávy, nebo použít `readuntil()` s `limit` parametrem (asyncio StreamReader nemá nativně — nutno obalit vlastním čtením po částech). + +#### ~~C2. SenderKeyState — neomezený fast-forward DoS (crypto_utils.py:642-645)~~ ✅ OPRAVENO +Při dešifrování skupinové zprávy s libovolně vysokým `n` se smyčka `while self.n <= n` provede milionkrát — derivuje milion klíčů, spotřebuje stovky MB RAM. +```python +while self.n <= n: + self.chain_key, mk = kdf_ck(self.chain_key) + self._known_keys[self.n] = mk # unbounded dict growth + self.n += 1 +``` +**Řešení:** Přidat `MAX_FORWARD_SKIP` limit (např. 1000) — stejně jako Double Ratchet má `MAX_SKIP=256`. + +#### C3. Dešifrované zprávy uložené jako plaintext na disku (chat_core.py:222-239) +Message cache v `~/.encrypted_chat/{email}/message_cache/{conv_id}.json` obsahuje plný obsah dešifrovaných zpráv v nešifrovaném JSONu. Bez nastavení `chmod 0o600`. Kdokoliv s přístupem k disku přečte kompletní historii. +**Řešení:** Šifrovat cache klíčem odvozeným z identity key + nastavit `chmod 0o600` na soubory. + +#### ~~C4. OPK private keys bez file permissions (chat_core.py:153-156)~~ ✅ OPRAVENO +OPK privátní klíče se ukládají bez `os.chmod(0o600)`. RSA klíče (řádek 87) a identity key (řádek 121) mají `chmod` — OPK ne. Na sdílených systémech může jiný uživatel přečíst ephemeral klíče. +**Opraveno:** Součást C3+H1+M13 fixu — `_save_opk_private()` nyní volá `os.chmod(path, 0o600)` + `os.chmod(dir, 0o700)`. + +#### ~~C5. Chunked upload nevaliduje celkovou velikost (server.py:1138-1142)~~ ✅ OPRAVENO +`handle_upload_image_chunk` akumuluje `received_bytes` ale nekontroluje limit. Útočník deklaruje `file_size=5MB`, pak posílá chunky donekonečna → disk exhaustion. +```python +upload["received_bytes"] += len(raw) # no check against file_size! +``` +**Řešení:** Přidat `if upload["received_bytes"] > upload["file_size"]: reject`. + +### 🟠 HIGH — Řešit před production nasazením + +#### H1. Session + sender key soubory nešifrované na disku (chat_core.py:176-215) +Double Ratchet session state (DH privátní klíče, root key, chain keys) a sender key state se ukládají jako plaintext hex JSON v `sessions/` a `sender_keys/`. Bez šifrování, bez `chmod 0o600`. Kompromitace disku = dešifrování celé historie. +**Řešení:** Šifrovat state klíčem z identity key, nastavit `chmod 0o600`. + +#### ~~H2. TLS vypnuté ve výchozím stavu (chat_core.py:274-291, server.py)~~ ✅ OPRAVENO (hardening) +`TLS_ENABLED` je defaultně `false`. Bez TLS jdou po síti RSA challenge-response, session tokeny a metadata v plaintextu. `TLS_INSECURE=true` vypíná certificate verification → MITM. +**Opraveno:** `TLS_INSECURE` a `TLS_AUTOGEN` nyní vyžadují `ENVIRONMENT=dev` — v produkci RuntimeError. Warning log při vypnutém TLS na serveru i klientovi. TLS_ENABLED zůstává default false (uživatel nemá certifikát), ale po nasazení Let's Encrypt stačí flip na true. + +#### H3. User enumeration přes registraci (server.py:182-189) +Registrace vrací "Email already in use" pro existující uživatele vs. tiché vytvoření phantoma pro neexistující. Útočník může enumerovat platné emaily. +**Řešení:** Vrátit generickou odpověď "Check your email for verification code" i když email existuje. + +#### ~~H4. Race conditions v in-memory strukturách (server.py: multiple)~~ ✅ OPRAVENO +`connected_clients` dict, `phantom_user_ids` set, `pairing_sessions` dict — čteny a zapisovány z více concurrent koroutin bez synchronizace. Asyncio je single-threaded, ale yieldy uvnitř handlerů (await) mohou způsobit nekonzistentní stav. +**Opraveno:** 4 asyncio.Lock guards: `_clients_lock` (connected_clients, phantom_user_ids), `_conn_lock` (connection_counts, current_connections, rate_limits), `_pairing_lock` (pairing_sessions, pending_registrations), `_uploads_lock` (pending_uploads). Helper funkce `_notify_users()` / `_notify_users_individual()` — snapshot under lock, send outside. Rate limit memory cleanup v periodic task. Žádný handler nedrží dva locky současně → deadlock impossible. + +#### H5. base64 decode bez error handling (protocol.py:14-16, server.py + chat_core.py) +`decode_binary()` volá `base64.b64decode()` bez try-except. Nevalidní base64 od klienta → unhandled `binascii.Error` → handler crash. Mnoho callsites v server.py (řádky 357, 378, 783) nemá catch. +**Řešení:** Obalit `decode_binary()` try-except, nebo validovat base64 vstup před dekódováním. + +#### H6. JSON parsing bez exception handling (protocol.py:48-50) +`parse_message()` volá `json.loads()` bez try-catch. Malformovaný JSON = neošetřený `JSONDecodeError`. Server handler catch (řádek 1399) to odchytí, ale není to explicitní. +**Řešení:** Obalit `json.loads()` v `parse_message()` try-except s explicitní chybovou zprávou. + +#### H7. Path traversal v avatar souborech (server.py:1265, 1318) +`avatar_file` ze serveru (z DB) se přímo joinuje s `UPLOAD_DIR / "avatars"` bez validace. Pokud DB obsahuje `../../etc/passwd`, server přečte libovolný soubor. +**Řešení:** Přidat `resolved_path.resolve().is_relative_to(UPLOAD_DIR)` check. + +#### H8. Heslo v paměti jako Python string (chat_core.py, gui_client.py) +Python stringy jsou immutable — nelze je bezpečně vymazat z paměti. Heslo zůstává v paměti dokud garbage collector neuklidí. Memory dump = plaintext heslo. +**Řešení:** Použít `bytearray` (mutable), po použití přepsat nulami: `pwd[:] = b'\x00' * len(pwd)`. + +#### H9. Self-encryption key je statický a deterministický (chat_core.py:904+, crypto_utils.py:224-233) +`derive_self_encryption_key(identity_private)` produkuje vždy stejný klíč. Kompromitace identity klíče = dešifrování všech vlastních kopií zpráv navždy. Žádná forward secrecy pro self-copies. +**Poznámka:** Toto je by-design (nutné pro cross-device čtení), ale je to architektonické omezení. + +#### H10. Malicious image data → QImage crash (gui_client.py) +`QImage.fromData(data)` zpracovává nevalidované binární data. Speciálně vytvořený obrázek může způsobit crash, memory exhaustion, nebo v extrémním případě RCE přes Qt image codec. +**Řešení:** Validovat velikost dat před parsováním, limit na max rozlišení. + +#### H11. Filename z serveru v save dialogu (gui_client.py:2389, 2460) +Server-controlled `filename` se předává jako default do `QFileDialog.getSaveFileName()`. Pokud server pošle `"../../../.bashrc"`, dialog to navrhne. +**Řešení:** Sanitizovat filename — odstranit `../`, `\`, absolutní cesty. Použít jen `os.path.basename()`. + +### 🟡 MEDIUM — Zvážit pro hardening + +#### M1. Inconsistentní Ed25519 serializace (crypto_utils.py:99-102) +Bez hesla: 32 raw bytes. S heslem: PEM PKCS8 (~302 bytes). Dva různé formáty mohou způsobit problém při migraci nebo obnově klíčů. +**Poznámka:** M3 fix částečně řeší — s heslem je nyní ECP1 formát (ne PEM), ale `load_ed25519_private()` stále detekuje 3 formáty (ECP1, PEM, raw). Legacy PEM soubory se automaticky migrují při dalším uložení. + +#### ~~M2. Prázdný salt v SenderKeyState HKDF (crypto_utils.py:610)~~ ✅ OPRAVENO +`hkdf_derive(sender_key, salt=b"", ...)` — RFC 5869 doporučuje nenulový salt. X3DH správně používá `b"\x00" * 32`. +**Opraveno:** Změněno `salt=b""` → `salt=b"\x00" * 32` aby odpovídalo X3DH konvenci. + +#### ~~M3. PBKDF2 iterace pod doporučeným minimem (crypto_utils.py)~~ ✅ OPRAVENO +`BestAvailableEncryption` používá ~100k iterací PBKDF2. OWASP 2023 doporučuje 480k+. +**Opraveno:** Nahrazeno vlastním `_encrypt_private_key`/`_decrypt_private_key` s PBKDF2-HMAC-SHA256 (600k iterací) + AES-256-GCM. ECP1 formát (magic prefix) s backward compat pro staré PEM soubory. Aplikováno na RSA (`serialize_private_key`/`load_private_key`) i Ed25519 (`serialize_ed25519_private`/`load_ed25519_private`). + +#### ~~M4. SPK bez replay protection a bez rotace (server.py:360-368)~~ ✅ OPRAVENO +Stejný SPK lze nahrát opakovaně. Žádný nonce/timestamp v podpisu. SPK se nikdy nerotuje → kompromitovaný SPK = trvalé dešifrování nových sessions. +**Opraveno:** SPK rotace každých 7 dní (`SPK_ROTATION_DAYS`). Server vrací `spk_created_at` v `handle_get_prekey_count`. `_ensure_prekeys()` kontroluje stáří SPK a rotuje pokud >= 7 dní. Předchozí SPK uložen jako grace period (`prev_spk_private.bin`/`prev_spk_id.txt`) pro in-flight X3DH. `_process_x3dh_header` přijímá `spk_override`, `_decrypt_dm` retry s předchozím SPK při selhání dešifrování. + +#### ~~M5. Rate limit unbounded memory (server.py:73-83)~~ ✅ OPRAVENO (součást H4 fixu) +Staré záznamy se nikdy nečistí pokud klíč přestane být aktivní → útočník vytvoří miliony klíčů → memory leak. +**Opraveno:** `_cleanup_rate_limits()` v periodic cleanup (každých 10 min) maže stale entries z `rate_limits` i `connection_counts`. + +#### ~~M6. TOCTOU race v membership checks (db.py)~~ ✅ OPRAVENO +`is_conversation_member()` → `remove_conversation_member()` — mezi kontrolou a akcí může jiný klient stav změnit. +**Opraveno:** `db.remove_conversation_member_atomic()` vrací bool (True pokud řádek existoval). Použito v `handle_remove_member` a `handle_leave_group`. + +#### M7. MySQL spojení bez TLS (db.py:20-28) +`get_connection()` nepředává SSL parametry. Na vzdáleném serveru jdou credentials v plaintextu. + +#### ~~M8. Chybějící validace UUID formátu (server.py: throughout)~~ ✅ OPRAVENO +`conv_id`, `user_id` — kontrola jen na neprázdnost, ne na formát UUID v4. +**Opraveno:** `_valid_file_id()` přejmenováno na `_valid_uuid()`. UUID validace přidána do všech handlerů přijímajících klientem poskytnuté `conv_id`, `user_id`, `message_id`, `device_id`. + +#### ~~M9. Ratchet state corruption recovery (chat_core.py:1088-1104)~~ ✅ OPRAVENO +Pokud `decrypt()` změní chain keys ale selže na AAD verification, backup/restore mechanismus funguje, ale pokud backup selže (out-of-memory), stav zůstane corrupted. +**Opraveno:** `DoubleRatchet.decrypt()` nyní snapshotuje stav přes `_snapshot()`/`_restore()` a rollbackuje při jakékoliv výjimce (včetně skipped message key restore). `SenderKeyState.decrypt()` stejně snapshotuje `chain_key`, `n`, `_known_keys` před fast-forward a rollbackuje při selhání. + +#### ~~M10. Chybějící validace velikosti message_ids listu (db.py:641-646)~~ ✅ OPRAVENO +Klient může poslat tisíce message_ids v jednom požadavku → pomalý SQL dotaz, DoS. +**Opraveno:** `handle_mark_read` nyní odmítá požadavky s více než 500 message_ids. + +### 🟢 LOW — Dobrá praxe, nízké riziko + +- **L1.** Hex string keys v skipped messages dict — timing side-channel po úspěšné autentikaci (crypto_utils.py:425) +- **L2.** RatchetHeader serializace redundantně konvertuje typy (crypto_utils.py:394-405) +- **L3.** `notif_label.setText()` bezpečné proti XSS (Qt neinterpretuje HTML v setText), ale křehké — přepnutí na `setHtml()` by to rozbilo (gui_client.py:1524, 2259) +- **L4.** SQL column interpolation v `update_user_profile` — whitelist chrání, ale pattern je nebezpečný při kopírování (db.py:818-822) +- **L5.** Chybějící TLS cipher suite hardening — Python defaulty jsou rozumné, ale ne explicitně nastavené (protocol.py) +- **L6.** Temporary pairing key není bezpečně vymazán z paměti (chat_core.py:581) +- **L7.** `_user_cache` ukládá public identity keys indefinitely — memory leak pro hodně kontaktů + +### Druhý bezpečnostní review (zaměření na návrh, DB, komunikaci, lokální tmp/cache) + +#### C6. Path traversal → libovolný zápis/smazání souborů přes file_id (server.py) +`handle_upload_image_start` vytváří soubor `UPLOAD_DIR / f"{file_id}.tmp"` bez validace `file_id`. Útočník může poslat `../../...` a zapisovat mimo UPLOAD_DIR. Následné rename, cleanup, `delete_message` a `delete_conversation` pak mohou mazat libovolné soubory. +**Řešení:** Striktně validovat file_id (UUID hex/kanonický formát), odmítnout cokoliv s `/`, `\`, `..`. Ověřit `path.resolve().is_relative_to(UPLOAD_DIR.resolve())`. Ideálně ukládat do podadresářů podle hash/UUID. + +#### ~~H12. OPK race condition — reuse one-time pre-keys (db.py)~~ ✅ OPRAVENO +V `db.get_key_bundle()` se OPK vybírá SELECT → DELETE bez transakčního zámku. Při souběhu může být stejný OPK vydán vícekrát → porušení bezpečnostních předpokladů X3DH. +**Opraveno:** `consume_one_time_prekey()` a `get_key_bundle()` nyní používají `SELECT ... FOR UPDATE` + DELETE v jedné transakci (součást multi-device implementace). + +#### H13. Neautentizované get_user_info + identity key exfiltrace (server.py) +`get_user_info` je dostupné bez přihlášení a vrací username, email a identity_key pro libovolný email/user_id. Umožňuje enumeraci uživatelů a sběr metadata/klíčů. +**Řešení:** Vyžadovat auth, nebo omezit na "kontakty v konverzaci". + +#### ~~H14. TLS_INSECURE umožňuje MITM i v produkci (chat_core.py, server.py)~~ ✅ OPRAVENO +`TLS_INSECURE=true` vypíná verifikaci certifikátu → útočník může podvrhnout key bundle. Přímo ohrožuje E2EE integritu. +**Opraveno:** `TLS_INSECURE` vyžaduje `ENVIRONMENT=dev`, jinak RuntimeError. Součást H2 fixu. + +#### ~~M11. Pairing poll DoS — neautentizovaný přístup k payload (server.py)~~ ✅ OPRAVENO +Kdokoli s 8-místným kódem může pollovat a "vyzvednout" payload (smazán po vyzvednutí). I když je šifrovaný, jde o snadný DoS (reálnému zařízení pairing selže). +**Opraveno:** `handle_pairing_start` generuje `poll_token` (secrets.token_hex(16)), vrací klientovi. `handle_pairing_poll` vyžaduje `poll_token` a porovnává přes `secrets.compare_digest()`. Klient ukládá token v `pairing_start()` a posílá v `pairing_wait()`. + +#### ~~M12. Upload end bez validace received_bytes == file_size (server.py)~~ ✅ OPRAVENO +`upload_image_end` neověřuje, že `received_bytes == file_size`. Může zůstat nedokončený/nevalidní soubor. +**Řešení:** Kontrola délky před `complete_image_upload`. + +#### M13. Klíčové adresáře bez chmod 700 (chat_core.py) +`get_key_dir` a podadresáře (`sessions`, `sender_keys`, `opk_private`) se vytvářejí bez explicitních práv; spoléhá se na umask. +**Řešení:** Po `mkdir` vždy `chmod 0o700` pro adresář, `0o600` pro soubory. + +#### ~~L8. Phantom users — DB inflation (server.py, db.py)~~ ✅ OPRAVENO +`find_conversation` vytváří phantom usery pro libovolné emaily. I s rate limit lze DB časem nafouknout. +**Opraveno:** `_valid_email()` validace před vytvořením phantomu. `db.cleanup_stale_phantoms(30)` v periodic cleanup — maže phantomy starší 30 dní bez aktivních konverzací s reálnými uživateli. + +### Bezpečnostní matice (souhrn) + +| Soubor | CRITICAL | HIGH | MEDIUM | LOW | +|--------|----------|------|--------|-----| +| `protocol.py` | 1 (C1) | 2 (H5, H6) | 0 | 1 (L5) | +| `crypto_utils.py` | 1 (C2) | 0 | 3 (M1-M3) | 2 (L1, L2) | +| `server.py` | 2 (C5, C6) | 3 (H3, H7, H13) | 4 (M4, M8, M11, M12) | 0 | +| `chat_core.py` | 2 (C3, C4) | 4 (H1, H2/H14, H8, H9) | 2 (M9, M13) | 1 (L6) | +| `gui_client.py` | 0 | 2 (H10, H11) | 0 | 2 (L3, L7) | +| `db.py` | 0 | 1 (H12) | 3 (M6, M7, M10) | 2 (L4, L8) | +| **Celkem** | **6** | **12** | **12** | **8** | +| **Opraveno** | 6 (~~C1~~, ~~C2~~, ~~C3~~, ~~C4~~, ~~C5~~, ~~C6~~) | 11 (~~H1~~, ~~H2~~, ~~H3~~, ~~H4~~, ~~H5~~, ~~H6~~, ~~H7~~, ~~H8~~, ~~H10~~, ~~H11~~, ~~H12~~, ~~H13~~, ~~H14~~) | 11 (~~M2~~, ~~M3~~, ~~M4~~, ~~M5~~, ~~M6~~, ~~M8~~, ~~M9~~, ~~M10~~, ~~M11~~, ~~M12~~, ~~M13~~) | 1 (~~L8~~) | +| **Zbývá** | **0** | **1** | **1** | **7** | + +### Doporučené pořadí oprav (aktualizováno) +1. ~~**C6**~~ ✅ — Path traversal přes file_id — DONE (UUID validace + is_relative_to) +2. ~~**C1 + C2 + C5**~~ ✅ — DoS vektory — DONE (LimitOverrunError → ValueError, MAX_SENDER_KEY_SKIP 256, upload completeness check) +3. ~~**H12**~~ ✅ — OPK race condition — DONE (SELECT FOR UPDATE, součást multi-device) +4. ~~**C3 + H1 + H7 + M13**~~ ✅ — Šifrování dat na disku + file/dir permissions + avatar path traversal — DONE +5. ~~**H2/H14**~~ ✅ — TLS hardening (TLS_INSECURE + TLS_AUTOGEN vyžadují ENVIRONMENT=dev, warning log) — DONE +6. ~~**H5 + H6**~~ ✅ — Error handling v protokolu (base64, JSON) — DONE +7. ~~**H3 + H13**~~ ✅ — User enumeration (generické odpovědi, auth pro get_user_info) — DONE +8. ~~**H4**~~ ✅ — Race conditions (asyncio.Lock) — DONE +9. ~~**H8 + H10 + H11**~~ ✅ — Paměť hesel, image parsing, filename sanitizace — DONE +10. ~~**M2 + M8 + M10 + M11**~~ ✅ — Hardening batch (HKDF salt, UUID validace, message_ids cap, pairing poll token) — DONE +11. ~~**M3, M4, M9**~~ ✅ — PBKDF2 600k iterations, SPK rotace 7 dní s grace period, ratchet state rollback — DONE +12. **M1, ~~M6~~, M7** — Remaining hardening (Ed25519 serialization, ~~TOCTOU~~ ✅, MySQL TLS) +13. **Penetrační testy** — manuální + automatizované security testy + +## Důležitá rozhodnutí a kontext +- **Invitations replace direct add for groups:** `handle_add_member` and `handle_create_conversation` (for groups) now create invitations instead of directly adding members. DMs still auto-add both users. This was a design decision to give users control over joining groups. +- **Group delete = total destruction:** When creator deletes a group, ALL members are removed and the conversation is fully deleted. This is different from "leave group" which only removes the leaving user. +- **DM delete is per-user:** Deleting a DM only removes the deleting user. The other user still sees the conversation until they also delete it. +- **Avatar caching in GUI is pixmap-based:** `_avatar_cache` and `_group_avatar_cache` store QPixmap objects, not raw bytes. The `_on_avatar_for_conv_list` and `_on_group_avatar_for_conv_list` signals convert bytes → QImage → QPixmap on receipt. +- **No context menu on conversation list anymore:** Delete was the only action. Now handled by header buttons. `conv_list.setContextMenuPolicy(DefaultContextMenu)`. + +## Environment Variables +See README.md for full list. Key: `SERVER_HOST`, `SERVER_PORT`, `MYSQL_*`, `TLS_*`, `SMTP_*`, `LOG_LEVEL`, `MAX_INPUT_CHARS`, `UPLOAD_DIR`, `MAX_IMAGE_BYTES`, `MAX_FILE_BYTES`, `MAX_MESSAGE_BYTES`, `METADATA_RETENTION_DAYS` (default 90). + +## Commands & Workflow + +- Start server: `python server.py` +- Start GUI client: `python gui_client.py` +- Start CLI client: `python client.py` +- Environment: `.env` file in project root (loaded by `dotenv`) +- Dependencies: `PyQt6`, `mysql-connector-python`, `cryptography`, `Pillow` (for image sharing), `python-dotenv` +- Check syntax: `python3 -m py_compile .py` +- All files on both server AND client side: `crypto_utils.py`, `protocol.py`, `chat_core.py`, `gui_client.py` (or `client.py`) diff --git a/README.md b/README.md index c21fb9a..c0221cf 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,343 @@ +<<<<<<< HEAD # Kecalek_python +======= +# Encrypted Chat + +End-to-end encrypted chat s forward secrecy (X3DH + Double Ratchet, Signal Protocol). +Server ukládá a přeposílá šifrované bloby — nikdy nevidí plaintext. + +## Soubory + +### Server +| Soubor | Účel | +|--------|------| +| `server.py` | Asyncio TCP server, handler dispatch, rate limiting, notifikace | +| `db.py` | MySQL CRUD, jedna connection na volání | +| `schema.sql` | MySQL schéma (users, conversations, messages, ...) | + +### Klient +| Soubor | Účel | +|--------|------| +| `gui_client.py` | PyQt6 GUI | +| `client.py` | CLI klient | +| `chat_core.py` | Logika klienta — session management, šifrování, lokální klíče | + +### Sdílené (server + klient) +| Soubor | Účel | +|--------|------| +| `crypto_utils.py` | Ed25519, X25519, AES-256-GCM, HKDF, PBKDF2, X3DH, Double Ratchet (state rollback), Sender Keys (state rollback), ECP1 key encryption | +| `protocol.py` | Newline-delimited JSON protokol, base64 encoding | + +## Quick Start + +1. `pip install -r requirements.txt` +2. Spustit `schema.sql` v MySQL (kompletní clean start). Pro migraci existující DB: `migration_multi_device.sql`. +3. `python server.py` +4. Klient: `python client.py` (CLI) nebo `python gui_client.py` (GUI, PyQt6) + +## Jak funguje šifrování + +### Klíče na uživatele +| Klíč | Typ | Účel | +|------|-----|------| +| RSA-4096 | Asymetrický | Pouze login challenge-response. Šifrovaný PBKDF2 (600k iterací) + AES-256-GCM. | +| Identity Key (IK) | Ed25519 | Podpisy, konverze na X25519 pro X3DH. Šifrovaný PBKDF2 (600k iterací) + AES-256-GCM. | +| Signed Pre-Key (SPK) | X25519 | DH v X3DH, podepsaný IK. **Rotuje se každých 7 dní** s grace periodem pro in-flight X3DH. | +| One-Time Pre-Keys (OPK) | X25519 | Jednorázové, spotřebuje se při X3DH, automaticky doplňované (< 20 → +50) | + +### DM (1:1 zprávy) — X3DH + Double Ratchet +1. Alice chce napsat Bobovi poprvé → stáhne jeho key bundle (IK, SPK, OPK) ze serveru. +2. X3DH: 4 DH výpočty → shared secret. +3. Double Ratchet inicializován ze shared secret. +4. Každá zpráva: symmetric ratchet (HMAC chain) → message key → AES-256-GCM. +5. Každá odpověď: DH ratchet (nový X25519 keypair) → nový root key + chain key. +6. Per-recipient ciphertext — každý recipient má vlastní šifrovaný blob. +7. Při selhání dešifrování: automatický rollback stavu ratchetu (snapshot/restore). + +### Skupiny — Sender Keys +1. Každý člen má vlastní sender key chain pro skupinu. +2. Sender key se distribuuje ostatním členům přes pairwise Double Ratchet (jako DM). +3. Skupinové zprávy: symmetric ratchet na sender key → AES-256-GCM. +4. Jeden ciphertext pro celou skupinu (efektivní). + +### Lokální úložiště klíčů +``` +~/.encrypted_chat/{email}/ + private.pem # RSA (login) — ECP1 formát s heslem, PEM bez hesla + public.pem # RSA (login) + identity_private.bin # Ed25519 — ECP1 formát s heslem, 32B raw bez hesla + identity_public.bin # Ed25519 + device_id.txt # UUID tohoto zařízení + spk_private.bin # Aktuální signed prekey (šifrovaný AES-256-GCM) + spk_id.txt + prev_spk_private.bin # Předchozí SPK, grace period (šifrovaný AES-256-GCM) + prev_spk_id.txt + opk_private/ # One-time prekeys (šifrované AES-256-GCM) + {opk_id}.bin + login_lockout.json # Brute-force lockout stav (failed_attempts, locked_until) + sessions/ # Double Ratchet stavy (šifrované AES-256-GCM) + {user_id}_{device_id}.bin + sender_keys/ # Vlastní sender keys pro skupiny + {conv_id}.bin + sender_keys_recv/ # Přijaté sender keys od ostatních + {conv_id}_{sender_id}_{device_id}.bin +``` + +## Bezpečnostní hardening + +### Šifrování privátních klíčů na disku (ECP1 formát) +RSA a Ed25519 privátní klíče šifrované heslem používají vlastní formát ECP1 (Encrypted Chat PBKDF v1): +- **PBKDF2-HMAC-SHA256** s 600 000 iteracemi (OWASP 2023 doporučuje 480k+) +- **AES-256-GCM** pro šifrování, magic bytes "ECP1" jako AAD +- **Formát:** `ECP1(4B) + salt(16B) + nonce(12B) + ciphertext+tag` +- **Zpětná kompatibilita:** Staré PEM soubory (z `BestAvailableEncryption`) se načtou automaticky a při dalším uložení se přešifrují do ECP1. + +### Šifrování SPK/OPK na disku +SPK a OPK privátní klíče jsou šifrované AES-256-GCM klíčem `_local_key` (HKDF z Ed25519 identity key): +- Při save: `_encrypt_local(raw, local_key)` → `nonce(12B) + tag(16B) + ciphertext` +- Při load: `_decrypt_local()` s transparentní migrací — pokud dešifrování selže, načte jako plaintext a uloží šifrovaně +- Aplikováno na `spk_private.bin`, `prev_spk_private.bin`, `opk_private/*.bin` + +### Brute-force ochrana (client-side lockout) +Po chybném zadání hesla se prodlužuje čas do dalšího pokusu: +- **Vzorec:** `min(2^N, 300)` sekund, kde N = počet neúspěšných pokusů (2s, 4s, 8s, ... až 5 min) +- **Stav:** `login_lockout.json` v adresáři klíčů (`failed_attempts`, `locked_until`) +- **Aplikováno na:** `ChatClient.login()` (síťový login) + GUI privacy overlay unlock (`_on_unlock_attempt`) +- **Reset:** Úspěšné přihlášení smaže lockout soubor +- **Defense-in-depth:** Smazání souboru resetuje počítadlo, ale PBKDF2-600k stále zpomaluje každý pokus (~0.5s/pokus) + +### SPK rotace (7 dní) +Signed Pre-Key se rotuje periodicky: +- Po přihlášení `_ensure_prekeys()` zjistí stáří SPK ze serveru (`spk_created_at`) +- Pokud je SPK starší než 7 dní → vygeneruje nový, starý uloží jako grace period +- **Grace period:** `prev_spk_private.bin` — pokud příchozí X3DH selže s aktuálním SPK, zkusí předchozí +- Omezuje dopad kompromitace SPK — útočník může vytvářet nové sessions max 7 dní + +### Odolnost ratchetu (state rollback) +Double Ratchet i Sender Keys automaticky rollbackují stav při selhání dešifrování: +- Před modifikací chain keys/counters se vytvoří snapshot +- Pokud AES-GCM dešifrování selže (corrupted data, wrong key), stav se obnoví +- Session zůstane funkční i po zpracování poškozené zprávy + +## Registrace + +1. `register` → server pošle 6-místný kód na email (nebo vrátí přímo v dev módu bez SMTP). +2. `register_confirm` → potvrzení kódu. +3. Automaticky se vygenerují a uploadnou prekeys (1 SPK + 50 OPKs). +4. Login. + +## Multi-Device Support + +Pravý multi-device (Signal-like) — každé zařízení má nezávislé Double Ratchet sessions. +Při posílání DM se zpráva šifruje zvlášť pro každé zařízení příjemce. +Všechna zařízení uživatele sdílejí Ed25519 identity key (pro self-encryption kompatibilitu). + +### Architektura +- **Devices tabulka** — každé přihlášení registruje device (UUID), server mapuje writer→device +- **Per-device prekeys** — každé zařízení má vlastní SPK + OPKs, server vrací `device_bundles` pole +- **Per-device sessions** — sessions klíčované `"user_id:device_id"`, nezávislé Double Ratchet instance +- **Self-encryption** — odesílatel šifruje vlastní kopii statickým klíčem z identity key (čitelné všemi vlastními zařízeními) +- **Notifikace** — `device_entries` pole, klient vybere záznam odpovídající svému device_id + +### Device Pairing (zjednodušený) + +Nové zařízení získá RSA + Ed25519 identity klíče od existujícího zařízení. +Přenos šifrovaný RSA-OAEP + AES-GCM přes server (server nevidí klíče). +Nové zařízení si po přihlášení automaticky vygeneruje vlastní SPK + OPKs. + +1. Nové zařízení: `Link Device` → dostane 8-místný kód. +2. Existující zařízení: `Authorize Device` → zadá kód → odešle RSA + identity klíče. +3. Nové zařízení importuje klíče, přihlásí se, vygeneruje vlastní prekeys. + +### Migrace +- Existující DB: spustit `migration_multi_device.sql` (nebo `migration_multi_device_resume.sql` pro idempotentní re-run) +- Čistá DB: `schema.sql` již obsahuje všechny multi-device sloupce + +## Device Revocation (Key Rotation) + +Rotuje RSA login klíč. Odpojí ostatní sessions. Forward secrecy zajišťuje, že kompromitace +jednoho session klíče neodhalí historii — není potřeba re-encryption. + +## Konfigurace + +### Server + DB +- `SERVER_HOST` (default `127.0.0.1`), `SERVER_PORT` (default `9999`) +- `MYSQL_HOST`, `MYSQL_PORT`, `MYSQL_USER`, `MYSQL_PASSWORD`, `MYSQL_DATABASE` + +### TLS +- `TLS_ENABLED` — zapne TLS (default `false`) +- `TLS_REQUIRED` — vyžaduje TLS_ENABLED, jinak server odmítne start +- `TLS_CERT_FILE`, `TLS_KEY_FILE` — cesty k certifikátu a privátnímu klíči (PEM) +- `TLS_AUTOGEN` — auto-generuje self-signed cert (**jen s `ENVIRONMENT=dev`**) +- `TLS_CA_FILE` (klient) — vlastní CA certifikát pro ověření serveru +- `TLS_INSECURE` (klient) — vypne ověření certifikátu (**jen s `ENVIRONMENT=dev`**) +- `ENVIRONMENT` — `dev`/`development` povolí TLS_INSECURE a TLS_AUTOGEN + +#### Produkční nasazení s Let's Encrypt +```bash +# 1. Nainstalovat certbot +sudo apt install certbot + +# 2. Získat certifikát (port 80 musí být volný pro ověření) +sudo certbot certonly --standalone -d chat.example.com + +# 3. V .env nastavit: +TLS_ENABLED=true +TLS_CERT_FILE=/etc/letsencrypt/live/chat.example.com/fullchain.pem +TLS_KEY_FILE=/etc/letsencrypt/live/chat.example.com/privkey.pem + +# 4. Klient — stačí zapnout TLS (Let's Encrypt je v systémovém trust store): +TLS_ENABLED=true +``` +Certifikát funguje na jakémkoliv portu (9999, 443, ...) — je vázaný na doménu, ne port. Certbot automaticky obnovuje certifikát každých 90 dní. + +#### Dev/testování (self-signed) +```bash +ENVIRONMENT=dev +TLS_ENABLED=true +TLS_AUTOGEN=true # server auto-generuje self-signed cert +TLS_INSECURE=true # klient přeskočí ověření certifikátu +``` + +### SMTP +- `SMTP_HOST`, `SMTP_PORT`, `SMTP_USER`, `SMTP_PASS`, `SMTP_FROM` +- Bez SMTP = dev mód (kód se vrací přímo klientovi). + +### Obrázky +- `UPLOAD_DIR` (default `uploads`), `MAX_IMAGE_BYTES` (default 5 MB, `0` = bez limitu) + +### Limity +- `MAX_MESSAGE_BYTES` (default `65536`), `MAX_INPUT_CHARS` (GUI, default `2000`) +- Rate limity: register 3/min, login 10/min, send_message 20/min, pairing_poll 10/min +- Connection: 20 req/s per connection, max 10 per IP, 200 global +- Pairing TTL: 120s, max 5 failed poll pokusů + +### Logging +- `LOG_LEVEL` (default `INFO`) + +## Features + +- Registrace (2-step, SMTP), login (RSA challenge-response), key rotation +- **Multi-device** — per-device sessions (Signal-like), device pairing (RSA + identity key transfer), automatické prekey generování na novém zařízení +- DM s forward secrecy (X3DH + Double Ratchet) — per-device šifrování +- Skupiny se Sender Keys (distribuované přes pairwise ratchet) +- Skupinové pozvánky — přidání do skupiny vyžaduje souhlas (accept/decline) +- Odpovědi na zprávy (reply_to) +- Mazání zpráv (soft-delete pro všechny, real-time notifikace) +- Mazání konverzací (pravý klik → smaže pro uživatele, pokud nezbývají členové smaže celou konverzaci) +- Šifrované obrázky (AES-256-GCM, chunked upload, thumbnail v bublině) +- Šifrované soubory (PDF, ZIP, atd. až 50 MB, chunked upload) +- Read receipts (real-time, client-side resoluce) +- Prekey replenishment (automatické doplňování OPKs po loginu + SPK rotace každých 7 dní) +- Silné šifrování klíčů na disku (PBKDF2 600k iterací + AES-256-GCM, ECP1 formát) +- Odolný ratchet — automatický rollback stavu při selhání dešifrování +- TLS (volitelný, auto-gen self-signed) +- Real-time notifikace konverzací — nové konverzace, přidání/odebrání členů se zobrazí okamžitě bez re-loginu +- Connection state indicator — zelená/červená/oranžová tečka, automatický reconnect s exponential backoff +- Online/offline status — zelená tečka na avataru v seznamu konverzací + v group info +- User profily — telefon, lokace, avatar, nastavení viditelnosti (email, telefon, lokace) +- Phantom users — anti user-enumeration: konverzace s neregistrovaným emailem funguje normálně (odesílatel vidí své zprávy), zprávy pro phantom příjemce se neukládají, phantom se smaže při skutečné registraci +- Clickable links — HTTPS modré, HTTP oranžové s ikonou zámku + potvrzovací dialog + +### GUI (PyQt6) +- Dark theme (Catppuccin Mocha) +- Seznam konverzací s kulatými avatary a online indikátorem (zelená tečka) +- Unread count badge na konverzacích (číselný počet nepřečtených zpráv) +- Message bubliny s barevným left border, timestamp vedle jména +- Read receipts (checkmarks), group info dialog, add/remove member +- Context menu: reply, delete, view image, download file +- Attach button pro obrázky a soubory, thumbnail v bublině, full-size viewer + save +- Pagination ("Load older messages") +- Connection indicator (zelená=online, červená=offline, oranžová=reconnecting) +- Auto-reconnect s exponential backoff (1s → 2s → 4s → ... → max 30s) +- Tlačítko "My Profile" — editace vlastního profilu (telefon, lokace, avatar, viditelnost) +- User profil dialog — klik na info tlačítko v group info → read-only profil uživatele +- Avatar upload/download (JPEG/PNG, max 2 MB, kruhový výřez) +- Leave group (červené tlačítko v group info, přenos creatora) +- Pozvánky do skupin — seznam pending pozvánek nad konverzacemi, pravý klik → accept/decline +- Periodický refresh avatarů a pozvánek (každé 2 minuty) + +### CLI +- Základní funkcionalita (DM, skupiny, šifrování). Profily a soubory pouze přes GUI. + +## Závislosti + +- `cryptography` — Ed25519, X25519, AES-GCM, RSA, HKDF, PBKDF2 +- `mysql-connector-python` — MySQL +- `python-dotenv` — env vars +- `PyQt6` — GUI +- `Pillow` — resize/thumbnail obrázků + +## Known Issues + +- Sender Keys pro skupiny se nedistribuují znovu při přidání nového člena (nový člen neuvidí staré skupinové zprávy). + +## TODO + +### Security — Zbývající +- [ ] **H9: Self-encryption key** — statický/deterministický klíč (by-design pro cross-device, architektonické omezení) +- [ ] M1: Nekonzistentní Ed25519 serializace (částečně vyřešeno M3 — ECP1 formát, ale 3 legacy formáty) +- [ ] M6: TOCTOU race v membership checks +- [ ] M7: MySQL spojení bez TLS +- [ ] L1-L8: Low-priority hardening +- [ ] **Penetrační testy** — manuální + automatizované + +### Features — High Priority +- [ ] Redistribuce sender keys při přidání nového člena do skupiny +- [ ] Typing indicators + +### Features — Medium Priority +- [ ] Hledání zpráv v konverzacích +- [ ] Group admin roles (více adminů) +- [ ] Edit sent messages + +### Features — Low Priority +- [ ] Dark/light theme toggle +- [ ] Desktop notifications (system tray) +- [ ] Database connection pooling +- [ ] Image gallery view +- [ ] Systemd + Docker deployment + +### Monetizace +Oddělený platební server (Stripe, KYC/AML compliant) od chat serveru (anonymní). Platba → premium kód → aktivace na chat serveru. Žádný přímý link platba↔chat identita. + +- **Free tier:** 5 konverzací, 10 MB soubory, 1 zařízení, 30 dní retence, max 10 členů/skupina +- **Premium:** neomezeno — aktivace jednorázovým kódem z platebního serveru +- **Enterprise:** self-hosted, LDAP/SSO, admin dashboard, SLA — faktura na firmu +- Detaily implementace viz `CLAUDE.md` + +### Hotovo — Security +- [x] **C1-C6: Všechny CRITICAL opraveny** — readuntil DoS, sender key fast-forward, OPK permissions, upload size check, path traversal (UUID validace + is_relative_to) +- [x] **H1-H8, H10-H14: Většina HIGH opravena** — lokální šifrování dat (AES-256-GCM), TLS hardening (INSECURE/AUTOGEN jen v dev), anti-enumeration, race conditions (asyncio.Lock), protokol error handling, avatar path traversal, hesla v paměti (bytearray+zero), image validace, filename sanitizace, OPK race condition (SELECT FOR UPDATE) +- [x] **M2-M5+M8-M13: Většina MEDIUM opravena** — HKDF salt, PBKDF2 600k iterací (ECP1 formát), SPK rotace 7 dní s grace periodem, rate limit cleanup, UUID validace, ratchet state rollback, message_ids cap, pairing poll token, upload check, chmod 0o700/0o600 +- [x] **SPK/OPK šifrování + brute-force lockout** — všechny privátní klíče na disku šifrované (ECP1 nebo AES-256-GCM), exponenciální backoff po chybném hesle (2^N s, max 5 min) + +### Hotovo — Features +- [x] **Multi-device support** — per-device sessions (Signal-like), device pairing, automatické prekey generování +- [x] Unread counts pro offline uživatele +- [x] Clickable HTTP links — HTTPS modré, HTTP oranžové s varováním +- [x] User profily (telefon, lokace, avatar, viditelnost) +- [x] Connection state indicator + auto-reconnect +- [x] Encrypted file sharing (až 50 MB) +- [x] Leave group + přenos creatora +- [x] Unread count badge +- [x] User avatars (upload/download, kruhový výřez) +- [x] Online/offline status (zelená tečka na avataru) +- [x] Mazání konverzací +- [x] Skupinové pozvánky (accept/decline) +- [x] Graceful server shutdown + +## Bezpečnostní audit + +Dva bezpečnostní audity provedeny (kód review). Nalezeno 6 CRITICAL, 12 HIGH, 12 MEDIUM, 8 LOW nálezů. + +| Závažnost | Celkem | Opraveno | Zbývá | +|-----------|--------|----------|-------| +| CRITICAL | 6 | **6** | 0 | +| HIGH | 12 | **11** | 1 (H9 — by-design) | +| MEDIUM | 12 | **10** | 2 (M1 částečně, M6, M7) | +| LOW | 8 | 0 | 8 | + +Detaily viz `CLAUDE.md`. +>>>>>>> d506e65 (initial commit) diff --git a/SECURITY_AUDIT.md b/SECURITY_AUDIT.md new file mode 100644 index 0000000..3afb064 --- /dev/null +++ b/SECURITY_AUDIT.md @@ -0,0 +1,363 @@ +# Security Audit (Encrypted Chat) + +Aktualizace: 2026-03-08 +Scope: `server.py`, `db.py`, `chat_core.py`, `gui_client.py`, `client.py`, `protocol.py`, `schema.sql`, `.env`, markdown dokumentace. + +Metodika: statický audit kódu + konfigurace. Nebyl proveden aktivní penetrační test ani fuzzing. + +## Executive Summary + +Nejzávažnější aktuálně otevřené nálezy: + +- Plaintext DB heslo v `.env` souborech (C3). +- Chybějící TLS mezi aplikací a MySQL (H1). +- Slabá oprávnění upload/avatary souborů na disku (H2). +- DoS přes neomezené `pending_registrations` (H6). + +## CRITICAL + +### ~~C1. TOFU / verifikace identity klíče se obchází při běžném X3DH flow~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- TOFU kontrola existuje jen v `_get_user_info()` (`chat_core.py:799-803`). +- Při navazování session (`_get_or_create_session`) se `identity_key` z bundle bere přímo bez TOFU kontroly (`chat_core.py:1497-1501`, `chat_core.py:1534-1538`). +- U příchozího X3DH (`_process_x3dh_header`) se remote IK také uloží bez TOFU kontroly (`chat_core.py:1551-1553`, `chat_core.py:1580-1584`). + +**Dopad** + +- Pokud server nebo MITM podstrčí jiný identity key, klient může navázat session bez varování. +- Prakticky to obchází uživatelskou verifikaci kontaktu ve výchozím messaging flow. + +**Oprava** + +- Nová výjimka `IdentityKeyChanged(user_id, new_key_bytes, status)` v `chat_core.py` — hard-fail při změně identity klíče. +- `_get_or_create_session()`: TOFU check přes `check_identity_key()` před X3DH initiate. Při `changed`/`changed_verified` vyhodí `IdentityKeyChanged` — session se nenaváže. +- `_process_x3dh_header()`: TOFU check před X3DH respond. Stejný hard-fail — příchozí zpráva s podvrženým klíčem je odmítnuta. +- GUI: `IdentityKeyChanged` zachycena v notification loopu (emituje `key_change_warning` signál místo pádu loopu) a v `_do_send_message` (zobrazí error + warning dialog). +- Session je blokována dokud uživatel explicitně neakceptuje key change přes `accept_key_change()`. +- `decrypt_notification()`: explicitní `except IdentityKeyChanged: raise` před generickým `except Exception` — výjimka se propaguje do notification loopu místo tichého spolknutí. +- `key_change_warning` signál rozšířen o 5. parametr `new_key_bytes: bytes` — "Accept New Key" dialog předává nový klíč přímo z výjimky, ne z cache (která mohla obsahovat starý klíč). +- `IdentityKeyChanged` ošetřena ve všech GUI send cestách: `_do_send_image`, `_do_send_file`, `_do_forward_message`, `_do_find_or_create_and_send` — zobrazí warning dialog + error message. +- CLI (`client.py`): `IdentityKeyChanged` ošetřena ve všech 6 send cestách (send_message ×3, send_image, send_file, forward_message). + +--- + +### ~~C2. Perzistentní DoS konverzace přes nevalidní message headers~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- Server přijímá `ratchet_header` / `x3dh_header` i jako raw `str/bytes` bez JSON schema validace (`server.py:1105-1112`, `server.py:1146-1151`). +- Při `get_messages` se hodnoty bez ochrany parsují `json.loads(...)` (`server.py:1266`, `server.py:1274`). + +**Dopad** + +- Útočník v konverzaci může uložit “poisoned” hlavičku a rozbít načtení historie ostatním členům (`Internal server error`). +- Chyba je perzistentní, dokud je vadná zpráva v historii. + +**Oprava** + +- Nový helper `_validate_header(raw, name)` v `server.py` — přijímá pouze `dict`, odmítá `str`/`bytes`, limit 4096 bajtů. +- `handle_send_message`: message-level i per-recipient headers procházejí `_validate_header()`. Nevalidní hlavička → error response, zpráva se neuloží. +- `handle_get_messages`: `json.loads()` obaleno `try/except` (JSONDecodeError, TypeError, UnicodeDecodeError). Corrupted header → prázdný dict `{}` + warning log, ostatní zprávy se načtou normálně. +- `_validate_header()` rozšířena o validaci očekávaných klíčů a typů pro ratchet headers (`dh_pub`: str, `n`: int, `pn`: int) a používá striktní kontrolu typu pro `n/pn` (`type(...) is int`) — `bool` je explicitně odmítnut. +- Realtime push notifikace nyní čtou data z validovaných `db_recipients` (ne z `recipients_raw`). Per-recipient hlavičky se dekódují z validovaných bytes zpět do `dict` pro JSON notifikaci. +- `encrypted_content` a `nonce` v push notifikacích se skládají z validovaných raw bytes a serializují se přes `encode_binary()` — untrusted hodnoty z raw requestu se do push větve nepropíší. + +--- + +### C3. Plaintext tajemství v `.env` a `zaloha/.env` + +**Evidence** + +- `.env` obsahuje `MYSQL_PASSWORD` (`.env:4`). +- `zaloha/.env` obsahuje `MYSQL_PASSWORD` (`zaloha/.env:4`). + +**Dopad** + +- Únik souboru = okamžitý přístup do DB. +- Riziko přes backupy, sdílení projektu, malware, CI artefakty. + +**Doporučení** + +1. Okamžitě rotovat DB heslo. +2. Nahradit repozitářové `.env` šablonou (`.env.example`) bez tajemství. +3. Použít secrets manager / deployment-level secret injection. + +## HIGH + +### H1. Chybí TLS mezi aplikací a MySQL + +**Evidence** + +- `MySQLConnectionPool` je bez `ssl_ca`/`ssl_verify_cert` parametrů (`db.py:35-44`). +- Konfigurace používá síťovou DB (`MYSQL_HOST=192.168.1.112`, `.env:1`). + +**Dopad** + +- Odposlech nebo MITM na trase app<->DB může odhalit credentials i data. + +**Doporučení** + +1. Zapnout MySQL TLS na serveru. +2. Vynutit TLS verifikaci certifikátu v `db.py`. + +--- + +### H2. Upload/avatary na disku mají slabá oprávnění + +**Evidence** + +- Upload soubory jsou vytvářeny bez explicitního `chmod` na file (`server.py:1732`, `server.py:1806`, `server.py:1909`, `server.py:1969`). +- V prostředí auditu: `uploads` a `uploads/avatars` mají `775`, soubory typicky `664`. + +**Dopad** + +- Lokální uživatelé na stejném hostu mohou číst citlivá data (včetně avatarů v plaintextu). + +**Doporučení** + +1. Nastavit adresáře `0700`. +2. Po zápisu každého souboru nastavit `0600`. +3. Upload storage přesunout mimo project tree. + +--- + +### ~~H3. `session_reset` nemá autorizační vazbu na vztah mezi uživateli~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- Handler přijme libovolné validní `peer_user_id` a pošle notifikaci (`server.py:2040-2052`). +- Neověřuje, že uživatelé sdílí konverzaci nebo existuje session. + +**Dopad** + +- Možnost spam/DoS reset notifikací na cílové uživatele. + +**Oprava** + +- Nová DB funkce `db.shares_conversation(user_id_a, user_id_b)` — `SELECT 1 ... LIMIT 1` přes `conversation_members` JOIN. +- `handle_session_reset`: před push notifikací ověřuje `shares_conversation()`. Pokud uživatelé nesdílí žádnou konverzaci → error response. +- Rate limit 5 požadavků/min na `session_reset` per user (`session_reset|{user_id}`) — IP adresa není součást klíče, takže změnou IP nejde limit obejít. +- Pokud je předán `peer_device_id`, reset notifikace se doručí pouze cílovému zařízení (filtr přes `writer_device_map`). Bez `peer_device_id` zůstává broadcast na všechna zařízení peera. + +--- + +### ~~H4. User enumeration přes pairing a user-info endpointy~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- `pairing_start` vrací explicitně `User not found` (`server.py:763-766`). +- `get_user_info` vrací metadata uživatele při lookupu přes email/user_id (`server.py:551-564`). + +**Dopad** + +- Snadné mapování existence účtů. + +**Oprava** + +- `handle_pairing_start`: vždy vrací `ok` s platně vypadajícím kódem a session se vytvoří vždy (i pro neexistující email), takže `pairing_poll` vrací nerozlišitelné `ready: false`. +- Přidán globální cap `PAIRING_MAX_SESSIONS = 100` pro omezení počtu současných pairing sessions (DoS hardening). +- `pairing_start` rate limit je per-IP (bez email komponenty), aby nešel obcházet rotací emailů. +- `pairing_claim` i `pairing_send`: sjednocená chyba `Invalid or expired code` (žádné rozlišení "neexistuje" vs "patří jinému účtu"). +- V pairing flow se síťové I/O (`send_resp`) volá až po uvolnění `_pairing_lock`. +- `handle_get_user_info`: přidán parametr `session` (vyžaduje login). Lookupovat lze jen sebe nebo kontakty (ověřeno přes `shares_conversation()`). Pro neexistující i nepovolené cíle vrací neutrální "User not found". +- Doplňuje dřívější anti-enumeration opravy: `register_start` (generická odpověď), `login_start` (fake challenge), `login_finish` (generická chyba). + +--- + +### ~~H5. Phantom user inflation přes `create_conversation` / `find_conversation` / `add_member` (DoS)~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- `create_conversation` vytváří phantom účty pro neznámé emaily bez dedikovaného rate limitu (`server.py:906-920`). +- `find_conversation` a `add_member` rate limitují přes `_rate_limit_key(..., addr, email)`, takže rotace emailů obchází limit (`server.py:972`, `server.py:1001`, `server.py:209-212`). +- `create_phantom_user()` pro každý nový email generuje IK+SPK+OTP a zapisuje více řádků do DB (`db.py:1470-1507`). + +**Dopad** + +- Útočník může nafukovat DB a CPU náklady (kryptografická generace + zápisy), případně degradovat výkon serveru. + +**Oprava** + +1. `_can_create_phantom(addr, user_id)` helper kontroluje 3 limity před každým `create_phantom_user()`: + - Globální cap: `MAX_PHANTOM_USERS = 500` (počet v `phantom_user_ids` setu) + - Per-user rate: `phantom_create|{user_id}` — 10/min (email-nezávislé, neobejitelné rotací) + - Per-IP rate: `phantom_create_ip|{addr}` — 10/min (email-nezávislé) +2. `create_conversation` nově má per-user rate limit 10/min + phantom check před každým členem. +3. `find_conversation` a `add_member` — existující per-addr+email limit zůstává (brání hammering jednoho emailu), přidán `_can_create_phantom` check před vytvořením phantomu. +4. Stávající `cleanup_stale_phantoms(30)` v periodic cleanup (10 min) zajišťuje garbage collection. + +--- + +### H6. `pending_registrations` nemá hard cap (memory/SMTP abuse) + +**Evidence** + +- `pending_registrations` je globální in-memory dict bez horního limitu (`server.py:56`). +- `register_start` používá rate limit klíč s emailem (`register_start|addr|email`), rotace emailů limit obchází (`server.py:341`, `server.py:209-212`). +- `register_start` ukládá novou pending registraci do dict bez capu (`server.py:373-382`). +- Periodický cleanup nevolá `_cleanup_registrations()`; expirace se spouští jen při `register_*` flow (`server.py:2486-2508`, `server.py:276-281`). + +**Dopad** + +- Riziko růstu paměti a SMTP abuse (masivní register_start s různými emaily). + +**Doporučení** + +1. Přidat `REGISTER_MAX_PENDING` cap a odmítnout nové requesty po dosažení limitu. +2. Změnit rate limit na per-IP (bez emailu) + případně per-subnet. +3. Přidat `_cleanup_registrations()` i do periodického cleanup tasku. + +## MEDIUM + +### ~~M1. `mark_read` a `confirm_delivery` neověřují, že `message_ids` patří do dané konverzace~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- Handler validuje členství jen v `conversation_id` (`server.py:1464-1479`, `server.py:1516-1531`). +- DB insert metody pro receipts neváží `message_id` na konverzaci (`db.py:1102-1113`, `db.py:1188-1200`). + +**Dopad** + +- Možná manipulace read/delivery stavu cizích zpráv (integrita metadat). + +**Oprava** + +- `db.mark_messages_read()` a `db.mark_messages_delivered()` nahrazeny z per-row `INSERT IGNORE` na batch `INSERT IGNORE ... SELECT m.id, %s FROM messages m WHERE m.id IN (...) AND m.conversation_id = %s`. +- Message IDs, které nepatří do dané konverzace, jsou tiše přeskočeny (SELECT je nevrátí). + +--- + +### M2. SMTP STARTTLS bez explicitního TLS contextu + +**Evidence** + +- `server.starttls()` je voláno bez `ssl.create_default_context()` (`server.py:290`). + +**Dopad** + +- Slabší kontrola TLS parametrů/verifikace dle runtime prostředí. + +**Doporučení** + +1. Použít `server.starttls(context=ssl.create_default_context())`. +2. Přidat `EHLO` před/po STARTTLS. + +--- + +### ~~M3. CLI klient: několik lokálních hardening mezer~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- Heslo se zadává přes `input()` (echo on) (`client.py:730`, `client.py:749`, `client.py:754`). +- Zprávy se tisknou bez sanitace escape sekvencí (`client.py:491`). +- Default save path při downloadu je převzat z remote `filename` (`client.py:523-530`). + +**Dopad** + +- Shoulder-surfing hesla, terminal escape spoofing, riskantní defaultní save path. + +**Oprava** + +- Všechny password prompty (register, login, pairing, authorize device, rotate keys) nahrazeny `prompt_password()` wrapping `getpass.getpass()` — heslo se nezobrazuje na terminálu. +- `_sanitize_text()` helper stripuje control znaky (`\x00-\x1f` kromě `\t`/`\n`/`\r`) a ANSI escape sekvence. Aplikováno na `sender`, `text`, `filename` při výpisu zpráv v `_print_messages()`. +- Follow-up: `_sanitize_text()` nyní bezpečně přijímá i non-string vstupy (`None -> ""`, jinak `str(...)`), čímž se eliminuje `TypeError` při neočekávaném typu z payloadu (`client.py:32-36`). +- Follow-up: sanitace rozšířena na zbývající user-controlled CLI výpisy — seznam konverzací (`client.py:63-67`), search výsledky (`client.py:293-300`), seznam pozvánek (`client.py:612-614`), profil (`client.py:637-644`), seznam zařízení (`client.py:709-713`), verify view (`client.py:435`, `client.py:446`) a notifikace včetně reaction hodnoty (`client.py:752-774`). +- `_safe_filename()` helper: `os.path.basename()` + odstranění NUL + fallback na `"download"` pro prázdné/tečkové názvy. Aplikováno na default save path při downloadu. + +--- + +### ~~M4. `get_key_bundle` umožňuje OPK depletion (availability)~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- `handle_get_key_bundle` nemá rate limit ani authorizační vazbu na vztah mezi uživateli (`server.py:648-660`). +- DB vrstva při každém volání spotřebovává one-time prekeys (`get_key_bundles_for_user` — „Consumes one OPK per device atomically”, `db.py:394-450`). +- `target_user_id` lze získat přes `find_conversation` lookup (`server.py:966-987`). + +**Dopad** + +- Útočník může opakovanými dotazy vyčerpat OPK oběti, zhoršit doručitelnost a vynutit časté doplňování prekeys. + +**Oprava** + +1. Per-caller rate limit: `get_key_bundle|{user_id}` — 10/min (omezuje celkový počet fetchů jednoho uživatele). +2. Per-target rate limit: `get_key_bundle_target|{target_user_id}` — 20/min (omezuje rychlost vyčerpávání OPK konkrétní oběti). Autorizace probíhá před per-target RL (neautorizovaný request nespálí bucket cíle). +3. Autorizace: `shares_conversation()` — caller musí sdílet konverzaci s cílem (self-fetch povolen vždy). +4. Chybová zpráva pro neautorizovaný přístup je neutrální (`”Key bundle not available”`) — shodná s neexistujícím uživatelem. +5. **Doplňující per-user rate limity** na všechny zbývající výpočetně/DB náročné handlery (celkem 29 RL checks): + - Crypto+DB: `upload_prekeys` 5/min, `ensure_prekeys` 5/min, `rotate_keys` 3/min, `reencrypt` 10/min + - DB-heavy: `get_messages` 30/min, `delete_conv` 5/min, `delete_msg` 20/min, `react` 20/min, `remove_member` 10/min, `rename_conv` 5/min + - File I/O: `update_avatar` 5/min (sdílený bucket pro user i group avatar) + +--- + +### ~~M5. `upload_image_start` nemá anti-DoS cap na in-flight uploady~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- `upload_image_start` nevynucuje request rate limit ani limit počtu aktivních uploadů na user/IP (`server.py:1786-1823`). +- In-memory `pending_uploads` je bez explicitního capu (`server.py:58`, `server.py:1812-1819`). +- Cleanup stale uploadů běží periodicky (600s) a DB stale threshold je 3600s (`server.py:2488-2490`, `db.py:1626-1633`). + +**Dopad** + +- Útočník může zahájit mnoho uploadů a vytvářet dočasné soubory/záznamy, což zvyšuje memory/disk tlak. + +**Oprava** + +1. Per-user rate limit: `upload_start|{user_id}` — 10/min. +2. Globální cap: `MAX_UPLOADS_GLOBAL = 200` (kontrola `len(pending_uploads)` pod `_uploads_lock`). +3. Per-user cap: `MAX_UPLOADS_PER_USER = 5` (počet záznamů s `uploader_id == user_id`). +4. Stale threshold snížen z 3600s na `UPLOAD_STALE_SECONDS = 600` (10 min). +5. Periodic cleanup interval snížen z 600s na 120s (2 min). + +## LOW + +### ~~L1. `decode_binary` není strict base64~~ ✅ OPRAVENO (2026-03-08) + +**Evidence** + +- `base64.b64decode(data)` bez `validate=True` (`protocol.py:18`). + +**Dopad** + +- Méně striktní input parsing (robustnost), ne přímý průnik. + +**Oprava** + +- `decode_binary()` nyní volá `base64.b64decode(data, validate=True)` — odmítá neplatné base64 znaky (whitespace, non-alphabet). + +## Positive Findings + +- Dev-only guardy: `TLS_INSECURE` a `TLS_AUTOGEN` jsou blokovány mimo `ENVIRONMENT=dev`. +- Server používá UUID validace v řadě handlerů. +- Upload/download ověřuje členství v konverzaci. +- Klientské private keys/storage používají PBKDF2 + AES-GCM a restriktivní perms (`0700`/`0600`) v key storage. +- Přítomný client-side lockout na opakované chybné login pokusy. + +## Prioritní plán oprav + +### 0-48 hodin + +1. Rotace DB hesla + odstranění tajemství z `.env`. +2. ~~Oprava TOFU bypassu v obou X3DH cestách.~~ ✅ DONE +3. ~~Zablokování nevalidních message headers na vstupu.~~ ✅ DONE +4. Přepnutí upload storage perms na `0700/0600`. +5. ~~Omezit phantom creation (rate limit bez emailu + cap).~~ ✅ DONE +6. Zavést cap pro `pending_registrations` a čistit je i v periodickém cleanupu. +7. ~~Přidat cap/rate limit na in-flight uploady.~~ ✅ DONE + +### 7 dní + +1. Zapnout TLS mezi app a MySQL (mTLS nebo minimálně server cert verify). +2. ~~Opravit autorizaci `session_reset`.~~ ✅ DONE +3. ~~Opravit vazbu `message_ids` na `conversation_id` pro receipts.~~ ✅ DONE +4. Omezit `get_key_bundle` (rate limit + policy sdílené konverzace). + +### 30 dní + +1. ~~Anti-enumeration sjednotit napříč endpointy.~~ ✅ DONE +2. ~~CLI hardening (`getpass`, output sanitace, filename sanitace).~~ ✅ DONE +3. Doplnit integrační testy pro bezpečnostní regresi (TOFU, poisoned headers, receipt authz, session_reset device targeting, anti-enumeration, DoS caps). diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..302a5f5 --- /dev/null +++ b/TODO.md @@ -0,0 +1,22 @@ +# TODO + +## Distributed global cap for phantom users (multi-process safe) + +1. Add DB-backed quota as source of truth (`system_quotas` table, row `phantom_users` with `used` and `limit`). +2. Move cap enforcement into one DB transaction: + - lock quota row with `SELECT ... FOR UPDATE` + - check `used < limit` + - create phantom user + - increment `used` + - commit (or rollback on failure). +3. Handle same-email races using `UNIQUE(email)`: + - on duplicate key, do not increment quota + - return existing user (or unified error response). +4. Add periodic reconciliation job: + - recalculate phantom count from `users` + - repair `system_quotas.used` if drift is detected. +5. Move phantom creation rate-limits to shared backend (Redis or DB atomic counters), so all server processes enforce the same limits. +6. Add concurrency tests: + - multi-process create storm near cap boundary (499/500) + - duplicate-email storm + - assert `used <= limit` always holds. diff --git a/certs/README.md b/certs/README.md new file mode 100644 index 0000000..7c074d6 --- /dev/null +++ b/certs/README.md @@ -0,0 +1,101 @@ +# TLS Setup — Let's Encrypt + Cloudflare DNS + +TLS certifikát přes Let's Encrypt bez nutnosti otevírat port 80. +Ověření domény probíhá přes DNS TXT záznam (Cloudflare API). + +## Předpoklady + +- Doména s DNS na Cloudflare (free tier stačí) +- Cloudflare API token s oprávněním "Edit zone DNS" +- Root přístup na serveru (certbot potřebuje `/etc/letsencrypt/`) + +## Postup + +### 1. Cloudflare API token + +1. Jdi na https://dash.cloudflare.com/profile/api-tokens +2. **Create Token** → Use template **"Edit zone DNS"** +3. Zone Resources → vybrat svou doménu +4. Zkopíruj vygenerovaný token + +### 2. Credentials soubor + +```bash +cp cloudflare.ini.example cloudflare.ini +nano cloudflare.ini # vlož API token +chmod 600 cloudflare.ini +``` + +### 3. Získání certifikátu + +```bash +sudo ./setup-tls.sh chat.example.com +``` + +Skript nainstaluje certbot + Cloudflare plugin, získá certifikát a vytvoří symlinky v tomto adresáři. + +### 4. Konfigurace serveru + +Přidej do `.env` v kořenovém adresáři projektu: + +```env +TLS_ENABLED=true +TLS_CERT_FILE=/etc/letsencrypt/live/chat.example.com/fullchain.pem +TLS_KEY_FILE=/etc/letsencrypt/live/chat.example.com/privkey.pem +``` + +### 5. Konfigurace klienta + +Na klientovi stačí: + +```env +TLS_ENABLED=true +``` + +Let's Encrypt je v systémovém trust store — klient ověří certifikát automaticky. + +## Obnova certifikátu + +Certbot obnovuje certifikát automaticky přes systemd timer (každých ~60 dní, cert platí 90). + +```bash +# Ověřit že timer běží +systemctl status certbot.timer + +# Ruční obnova (test) +sudo certbot renew --dry-run +``` + +Po úspěšné obnově se spustí `reload-server.sh` (deploy hook) — restartuje chat server aby načetl nový certifikát. + +## Soubory + +| Soubor | Účel | +|--------|------| +| `setup-tls.sh` | Instalace certbot + získání certifikátu | +| `reload-server.sh` | Deploy hook — restartuje server po renew | +| `cloudflare.ini.example` | Šablona pro Cloudflare API token | +| `cloudflare.ini` | Tvůj API token (gitignored) | + +## FAQ + +**Funguje certifikát na nestandardním portu (např. 9999)?** +Ano. Certifikát je vázaný na doménu, ne na port. `chat.example.com:9999` funguje. + +**Musím otevírat port 80?** +Ne. DNS challenge ověřuje doménu přes DNS TXT záznam, žádný HTTP požadavek na server. + +**Co když nemám Cloudflare?** +Můžeš použít ruční DNS challenge (bez automatického renew): +```bash +sudo certbot certonly --manual --preferred-challenges dns -d chat.example.com +``` +Certbot ti řekne jaký TXT záznam přidat. Při renew to musíš opakovat ručně. + +**Dev/testování bez certifikátu?** +```env +ENVIRONMENT=dev +TLS_ENABLED=true +TLS_AUTOGEN=true # server vygeneruje self-signed cert +TLS_INSECURE=true # klient přeskočí ověření +``` diff --git a/certs/cloudflare.ini.example b/certs/cloudflare.ini.example new file mode 100644 index 0000000..2fd8c67 --- /dev/null +++ b/certs/cloudflare.ini.example @@ -0,0 +1,11 @@ +# Cloudflare API token pro certbot DNS challenge +# 1. Jdi na https://dash.cloudflare.com/profile/api-tokens +# 2. Create Token -> Edit zone DNS (template) +# 3. Zone Resources: vybrat svou doménu +# 4. Zkopírovat token sem +# +# Po vyplnění přejmenuj na cloudflare.ini a nastav práva: +# cp cloudflare.ini.example cloudflare.ini +# chmod 600 cloudflare.ini + +dns_cloudflare_api_token = VLOZ_SVUJ_CLOUDFLARE_API_TOKEN diff --git a/certs/reload-server.sh b/certs/reload-server.sh new file mode 100755 index 0000000..abc7449 --- /dev/null +++ b/certs/reload-server.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +# Deploy hook — spustí se automaticky po úspěšném renew certifikátu +# Certbot volá tento skript s RENEWED_LINEAGE a RENEWED_DOMAINS env vars +# +# Restartuje chat server aby načetl nový certifikát. +# Přizpůsob podle toho jak server spouštíš (systemd / screen / přímý proces). + +set -euo pipefail + +echo "Certifikát obnoven pro: ${RENEWED_DOMAINS:-unknown}" + +# Varianta 1: Systemd service +if systemctl is-active --quiet encrypted-chat 2>/dev/null; then + systemctl restart encrypted-chat + echo "Server restartován (systemd)." + exit 0 +fi + +# Varianta 2: Poslat SIGINT procesu (graceful shutdown + ruční restart) +PID=$(pgrep -f "python.*server.py" || true) +if [ -n "$PID" ]; then + echo "Posílám SIGINT procesu $PID (server.py)" + kill -INT "$PID" + echo "Server zastaven. Spusť ho znovu ručně nebo přes systemd." + exit 0 +fi + +echo "VAROVÁNÍ: Server proces nenalezen. Restartuj server ručně." diff --git a/certs/setup-tls.sh b/certs/setup-tls.sh new file mode 100755 index 0000000..1c1f7c6 --- /dev/null +++ b/certs/setup-tls.sh @@ -0,0 +1,108 @@ +#!/usr/bin/env bash +# Setup TLS certifikátu přes Let's Encrypt + Cloudflare DNS challenge +# Nevyžaduje otevřený port 80 — ověření přes DNS TXT záznam +# +# Použití: +# 1. Přesuň DNS domény na Cloudflare (free tier stačí) +# 2. Vytvoř API token: https://dash.cloudflare.com/profile/api-tokens +# -> Use template "Edit zone DNS" -> vybrat doménu +# 3. cp cloudflare.ini.example cloudflare.ini +# Vlož token, chmod 600 cloudflare.ini +# 4. sudo ./setup-tls.sh chat.example.com +# +# Po úspěšném získání certifikátu přidej do .env: +# TLS_ENABLED=true +# TLS_CERT_FILE=/etc/letsencrypt/live/DOMENA/fullchain.pem +# TLS_KEY_FILE=/etc/letsencrypt/live/DOMENA/privkey.pem + +set -euo pipefail + +DOMAIN="${1:-}" +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +CREDS="$SCRIPT_DIR/cloudflare.ini" +DEPLOY_HOOK="$SCRIPT_DIR/reload-server.sh" + +if [ -z "$DOMAIN" ]; then + echo "Použití: sudo $0 " + echo "Příklad: sudo $0 chat.example.com" + exit 1 +fi + +if [ "$EUID" -ne 0 ]; then + echo "Spusť jako root: sudo $0 $DOMAIN" + exit 1 +fi + +if [ ! -f "$CREDS" ]; then + echo "Chybí $CREDS" + echo "Zkopíruj cloudflare.ini.example -> cloudflare.ini a vlož API token." + exit 1 +fi + +# Ověř oprávnění credentials souboru +PERMS=$(stat -c %a "$CREDS" 2>/dev/null || stat -f %Lp "$CREDS" 2>/dev/null) +if [ "$PERMS" != "600" ]; then + echo "VAROVÁNÍ: $CREDS má oprávnění $PERMS, nastavuji 600" + chmod 600 "$CREDS" +fi + +echo "=== Instalace certbot + Cloudflare pluginu ===" +if ! command -v certbot &>/dev/null; then + apt-get update + apt-get install -y certbot python3-certbot-dns-cloudflare + echo "Certbot nainstalován." +else + echo "Certbot již nainstalován." + # Doinstaluj plugin pokud chybí + if ! python3 -c "import certbot_dns_cloudflare" 2>/dev/null; then + apt-get install -y python3-certbot-dns-cloudflare + fi +fi + +echo "" +echo "=== Získání certifikátu pro $DOMAIN ===" +DEPLOY_ARGS="" +if [ -f "$DEPLOY_HOOK" ] && [ -x "$DEPLOY_HOOK" ]; then + DEPLOY_ARGS="--deploy-hook $DEPLOY_HOOK" + echo "Deploy hook: $DEPLOY_HOOK" +fi + +certbot certonly \ + --dns-cloudflare \ + --dns-cloudflare-credentials "$CREDS" \ + --dns-cloudflare-propagation-seconds 30 \ + -d "$DOMAIN" \ + --non-interactive \ + --agree-tos \ + --register-unsafely-without-email \ + $DEPLOY_ARGS + +CERT_DIR="/etc/letsencrypt/live/$DOMAIN" +if [ -d "$CERT_DIR" ]; then + echo "" + echo "=== Certifikát úspěšně získán ===" + echo "" + echo "Soubory:" + echo " Certifikát: $CERT_DIR/fullchain.pem" + echo " Klíč: $CERT_DIR/privkey.pem" + echo "" + echo "Přidej do .env:" + echo " TLS_ENABLED=true" + echo " TLS_CERT_FILE=$CERT_DIR/fullchain.pem" + echo " TLS_KEY_FILE=$CERT_DIR/privkey.pem" + echo "" + echo "Na klientovi stačí:" + echo " TLS_ENABLED=true" + echo "" + echo "Automatický renew: certbot timer (systemd) nebo cron" + echo " systemctl status certbot.timer" + echo "" + + # Symlinky pro snadný přístup + ln -sf "$CERT_DIR/fullchain.pem" "$SCRIPT_DIR/fullchain.pem" + ln -sf "$CERT_DIR/privkey.pem" "$SCRIPT_DIR/privkey.pem" + echo "Symlinky vytvořeny v $SCRIPT_DIR/" +else + echo "CHYBA: Certifikát nebyl vytvořen." + exit 1 +fi diff --git a/chat_core.py b/chat_core.py new file mode 100644 index 0000000..7f19d07 --- /dev/null +++ b/chat_core.py @@ -0,0 +1,3481 @@ +"""Shared network layer and ChatClient class for CLI and GUI clients. + +Uses X3DH + Double Ratchet for message encryption, Sender Keys for groups. +RSA retained for login challenge-response only. +""" + +import asyncio +import hashlib +import json +import logging +import os +import ssl +import time +import uuid +from datetime import datetime, timezone +from pathlib import Path + +from dotenv import load_dotenv + +load_dotenv() + +from crypto_utils import ( + # RSA (login only) + generate_rsa_keypair, + serialize_private_key, + serialize_public_key, + load_private_key, + load_public_key, + rsa_sign, + # Ed25519 + generate_identity_keypair, + serialize_ed25519_private, + serialize_ed25519_private_raw, + serialize_ed25519_public, + load_ed25519_private, + load_ed25519_public, + ed25519_sign, + # X25519 + generate_x25519_keypair, + serialize_x25519_private, + serialize_x25519_public, + load_x25519_private, + load_x25519_public, + # X3DH + generate_signed_prekey, + generate_one_time_prekeys, + x3dh_initiate, + x3dh_respond, + # Double Ratchet + DoubleRatchet, + # Sender Keys + SenderKeyState, + # AES + aes_encrypt, + aes_decrypt, + # Self-encryption + derive_self_encryption_key, + # Local storage encryption + derive_local_storage_key, + # Contact verification + compute_fingerprint, + format_fingerprint, + compute_safety_number, + encode_verification_qr, + decode_verification_qr, + # Message padding + pad_plaintext, + unpad_plaintext, +) +from protocol import ( + VERSION, + ProtocolReader, + ProtocolWriter, + encode_binary, + decode_binary, + build_request, + MAX_MESSAGE_BYTES, + IMAGE_CHUNK_SIZE, +) + + +KEY_DIR = Path.home() / ".encrypted_chat" +OPK_REPLENISH_THRESHOLD = 20 +OPK_BATCH_SIZE = 50 +SPK_ROTATION_DAYS = 7 + + +def _encrypt_local(data: bytes, key: bytes) -> bytes: + """Encrypt data with AES-256-GCM for local storage. Format: nonce(12) + tag(16) + ciphertext.""" + _, nonce, ct, tag = aes_encrypt(data, key=key) + return nonce + tag + ct + + +def _decrypt_local(raw: bytes, key: bytes) -> bytes: + """Decrypt data encrypted by _encrypt_local.""" + nonce, tag, ct = raw[:12], raw[12:28], raw[28:] + return aes_decrypt(key, nonce, ct, tag) + + +def get_key_dir(email: str) -> Path: + d = KEY_DIR / email + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + return d + + +# --------------------------------------------------------------------------- +# RSA key storage (login only — unchanged interface) +# --------------------------------------------------------------------------- + +def save_keys(email: str, private_key, public_key, password: bytes | None = None): + d = get_key_dir(email) + (d / "private.pem").write_bytes(serialize_private_key(private_key, password=password)) + (d / "public.pem").write_bytes(serialize_public_key(public_key)) + os.chmod(d / "private.pem", 0o600) + + +def load_keys(email: str, password: bytes | None = None): + d = get_key_dir(email) + priv_path = d / "private.pem" + pub_path = d / "public.pem" + if not priv_path.exists(): + return None, None, "No local keys found." + pem = priv_path.read_bytes() + try: + private_key = load_private_key(pem, password=password) + except Exception: + try: + private_key = load_private_key(pem, password=None) + if password: + save_keys(email, private_key, load_public_key(pub_path.read_bytes()), password=password) + except Exception: + return None, None, "Invalid or missing password." + public_key = load_public_key(pub_path.read_bytes()) + return private_key, public_key, None + + +# --------------------------------------------------------------------------- +# Identity + prekey storage +# --------------------------------------------------------------------------- + +def _save_identity_keys(email: str, ed_priv, ed_pub, password: bytes | None = None): + d = get_key_dir(email) + if password: + (d / "identity_private.bin").write_bytes(serialize_ed25519_private(ed_priv, password=password)) + else: + (d / "identity_private.bin").write_bytes(serialize_ed25519_private_raw(ed_priv)) + (d / "identity_public.bin").write_bytes(serialize_ed25519_public(ed_pub)) + os.chmod(d / "identity_private.bin", 0o600) + + +def _load_identity_keys(email: str, password: bytes | None = None): + d = get_key_dir(email) + priv_path = d / "identity_private.bin" + pub_path = d / "identity_public.bin" + if not priv_path.exists(): + return None, None + priv = load_ed25519_private(priv_path.read_bytes(), password=password) + pub = load_ed25519_public(pub_path.read_bytes()) + return priv, pub + + +def _save_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None): + d = get_key_dir(email) + raw = serialize_x25519_private(spk_priv) + data = _encrypt_local(raw, local_key) if local_key else raw + (d / "spk_private.bin").write_bytes(data) + (d / "spk_id.txt").write_text(spk_id) + os.chmod(d / "spk_private.bin", 0o600) + + +def _load_spk(email: str, local_key: bytes | None = None): + d = get_key_dir(email) + priv_path = d / "spk_private.bin" + id_path = d / "spk_id.txt" + if not priv_path.exists(): + return None, None + raw = priv_path.read_bytes() + if local_key: + try: + raw = _decrypt_local(raw, local_key) + except Exception: + # Plaintext fallback (migration) — re-save encrypted + pass + priv = load_x25519_private(raw) + spk_id = id_path.read_text().strip() if id_path.exists() else "" + if local_key: + _save_spk(email, priv, spk_id, local_key) + return priv, spk_id + + +def _save_prev_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None): + """Save previous SPK for grace period (in-flight X3DH may reference old SPK).""" + d = get_key_dir(email) + raw = serialize_x25519_private(spk_priv) + data = _encrypt_local(raw, local_key) if local_key else raw + (d / "prev_spk_private.bin").write_bytes(data) + (d / "prev_spk_id.txt").write_text(spk_id) + os.chmod(d / "prev_spk_private.bin", 0o600) + + +def _load_prev_spk(email: str, local_key: bytes | None = None): + """Load previous SPK (grace period). Returns (private_key, spk_id) or (None, None).""" + d = get_key_dir(email) + priv_path = d / "prev_spk_private.bin" + id_path = d / "prev_spk_id.txt" + if not priv_path.exists(): + return None, None + raw = priv_path.read_bytes() + if local_key: + try: + raw = _decrypt_local(raw, local_key) + except Exception: + pass + priv = load_x25519_private(raw) + spk_id = id_path.read_text().strip() if id_path.exists() else "" + if local_key: + _save_prev_spk(email, priv, spk_id, local_key) + return priv, spk_id + + +def _save_opk_private(email: str, opk_id: str, opk_priv, local_key: bytes | None = None): + d = get_key_dir(email) / "opk_private" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + raw = serialize_x25519_private(opk_priv) + data = _encrypt_local(raw, local_key) if local_key else raw + (d / f"{opk_id}.bin").write_bytes(data) + os.chmod(d / f"{opk_id}.bin", 0o600) + + +def _load_opk_private(email: str, opk_id: str, local_key: bytes | None = None): + d = get_key_dir(email) / "opk_private" + p = d / f"{opk_id}.bin" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + raw = _decrypt_local(raw, local_key) + except Exception: + pass + priv = load_x25519_private(raw) + # Migration: re-save encrypted if local_key provided + if local_key: + _save_opk_private(email, opk_id, priv, local_key) + return priv + + +def _secure_delete(p: Path): + """Overwrite file with random data before deletion (anti-forensic wipe).""" + try: + if not p.exists(): + return + size = p.stat().st_size + if size > 0: + with open(p, "r+b") as f: + f.write(os.urandom(size)) + f.flush() + os.fsync(f.fileno()) + p.unlink() + except Exception: + try: + p.unlink(missing_ok=True) + except Exception: + pass + + +def _delete_opk_private(email: str, opk_id: str): + d = get_key_dir(email) / "opk_private" + p = d / f"{opk_id}.bin" + _secure_delete(p) + + +def _save_device_id(email: str, device_id: str): + d = get_key_dir(email) + p = d / "device_id.txt" + p.write_text(device_id) + os.chmod(p, 0o600) + + +def _load_device_id(email: str) -> str | None: + d = get_key_dir(email) + p = d / "device_id.txt" + if not p.exists(): + return None + return p.read_text().strip() or None + + +# ------------------------------------------------------------------ +# Identity key change exception (TOFU hard-fail) +# ------------------------------------------------------------------ + +class IdentityKeyChanged(Exception): + """Raised when a peer's identity key has changed (TOFU violation). + + Session creation is blocked until the user explicitly accepts the new key. + """ + def __init__(self, user_id: str, new_key_bytes: bytes, status: str): + self.user_id = user_id + self.new_key_bytes = new_key_bytes + self.status = status # "changed" or "changed_verified" + super().__init__( + f"Identity key changed for user {user_id} (status={status}). " + f"Accept the new key before communicating." + ) + + +# ------------------------------------------------------------------ +# Client-side brute-force lockout +# ------------------------------------------------------------------ + +_LOCKOUT_BASE_SECONDS = 2 +_LOCKOUT_MAX_SECONDS = 300 # 5 min cap + + +def _get_lockout_path(email: str) -> Path: + return get_key_dir(email) / "login_lockout.json" + + +def _check_lockout(email: str) -> float: + """Return seconds remaining until next attempt allowed. 0 = can try now.""" + p = _get_lockout_path(email) + if not p.exists(): + return 0.0 + try: + data = json.loads(p.read_text()) + locked_until = data.get("locked_until", 0.0) + remaining = locked_until - time.time() + return max(0.0, remaining) + except Exception: + return 0.0 + + +def _record_failed_attempt(email: str): + """Increment failed counter, update locked_until.""" + p = _get_lockout_path(email) + failed = 0 + try: + if p.exists(): + data = json.loads(p.read_text()) + failed = data.get("failed_attempts", 0) + except Exception: + pass + failed += 1 + delay = min(_LOCKOUT_BASE_SECONDS ** failed, _LOCKOUT_MAX_SECONDS) + locked_until = time.time() + delay + p.write_text(json.dumps({"failed_attempts": failed, "locked_until": locked_until})) + os.chmod(p, 0o600) + + +def _clear_lockout(email: str): + """Reset on successful login.""" + p = _get_lockout_path(email) + if p.exists(): + try: + p.unlink() + except Exception: + pass + + +def _save_session(email: str, peer_user_id: str, ratchet: DoubleRatchet, + local_key: bytes | None = None, peer_device_id: str | None = None): + d = get_key_dir(email) / "sessions" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + if peer_device_id: + filename = f"{peer_user_id}_{peer_device_id}.bin" + else: + filename = f"{peer_user_id}.bin" + p = d / filename + data = ratchet.export_state() + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_session(email: str, peer_user_id: str, + local_key: bytes | None = None, + peer_device_id: str | None = None) -> DoubleRatchet | None: + d = get_key_dir(email) / "sessions" + if peer_device_id: + p = d / f"{peer_user_id}_{peer_device_id}.bin" + if not p.exists(): + # Fallback: try old format (no device_id) and migrate + p_old = d / f"{peer_user_id}.bin" + if p_old.exists(): + ratchet = _load_session_file(p_old, local_key) + if ratchet: + _save_session(email, peer_user_id, ratchet, local_key, + peer_device_id=peer_device_id) + _secure_delete(p_old) + return ratchet + return None + else: + p = d / f"{peer_user_id}.bin" + if not p.exists(): + return None + return _load_session_file(p, local_key) + + +def _load_session_file(p: Path, local_key: bytes | None = None) -> DoubleRatchet | None: + """Load a session from a specific file path.""" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + data = _decrypt_local(raw, local_key) + except Exception: + # Migration: try loading as plaintext (old unencrypted format) + try: + ratchet = DoubleRatchet.import_state(raw) + return ratchet + except Exception: + return None + return DoubleRatchet.import_state(data) + return DoubleRatchet.import_state(raw) + + +def _delete_session_file(email: str, peer_user_id: str, peer_device_id: str | None = None): + """Securely delete a session file from disk (for session reset).""" + d = get_key_dir(email) / "sessions" + if peer_device_id: + p = d / f"{peer_user_id}_{peer_device_id}.bin" + else: + p = d / f"{peer_user_id}.bin" + _secure_delete(p) + + +def _save_sender_key_state(email: str, conv_id: str, state: SenderKeyState, + local_key: bytes | None = None): + d = get_key_dir(email) / "sender_keys" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + p = d / f"{conv_id}.bin" + data = state.export_state() + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_sender_key_state(email: str, conv_id: str, + local_key: bytes | None = None) -> SenderKeyState | None: + d = get_key_dir(email) / "sender_keys" + p = d / f"{conv_id}.bin" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + data = _decrypt_local(raw, local_key) + except Exception: + try: + sk = SenderKeyState.import_state(raw) + _save_sender_key_state(email, conv_id, sk, local_key) + return sk + except Exception: + return None + return SenderKeyState.import_state(data) + return SenderKeyState.import_state(raw) + + +def _save_recv_sender_key(email: str, conv_id: str, sender_id: str, state: SenderKeyState, + local_key: bytes | None = None, + sender_device_id: str | None = None): + d = get_key_dir(email) / "sender_keys_recv" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + if sender_device_id: + filename = f"{conv_id}_{sender_id}_{sender_device_id}.bin" + else: + filename = f"{conv_id}_{sender_id}.bin" + p = d / filename + data = state.export_state() + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_recv_sender_key(email: str, conv_id: str, sender_id: str, + local_key: bytes | None = None, + sender_device_id: str | None = None) -> SenderKeyState | None: + d = get_key_dir(email) / "sender_keys_recv" + if sender_device_id: + p = d / f"{conv_id}_{sender_id}_{sender_device_id}.bin" + if not p.exists(): + # Fallback: try old format and migrate + p_old = d / f"{conv_id}_{sender_id}.bin" + if p_old.exists(): + sk = _load_recv_sender_key_file(p_old, local_key) + if sk: + _save_recv_sender_key(email, conv_id, sender_id, sk, local_key, + sender_device_id=sender_device_id) + _secure_delete(p_old) + return sk + return None + else: + p = d / f"{conv_id}_{sender_id}.bin" + if not p.exists(): + return None + return _load_recv_sender_key_file(p, local_key) + + +def _load_recv_sender_key_file(p: Path, local_key: bytes | None = None) -> SenderKeyState | None: + """Load a recv sender key from a specific file path.""" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + data = _decrypt_local(raw, local_key) + except Exception: + try: + sk = SenderKeyState.import_state(raw) + return sk + except Exception: + return None + return SenderKeyState.import_state(data) + return SenderKeyState.import_state(raw) + + +# --------------------------------------------------------------------------- +# Local decrypted message cache (Double Ratchet keys are one-time use) +# --------------------------------------------------------------------------- + +def _load_message_cache(email: str, conv_id: str, cache_key: bytes | None = None) -> dict: + d = get_key_dir(email) / "message_cache" + p_bin = d / f"{conv_id}.bin" + p_json = d / f"{conv_id}.json" + + # Migration: if old plaintext .json exists but encrypted .bin doesn't + if p_json.exists() and not p_bin.exists(): + try: + cache = json.loads(p_json.read_text("utf-8")) + if cache_key: + _save_message_cache_full(d, conv_id, cache, cache_key) + _secure_delete(p_json) + return cache + except Exception: + return {} + + if not p_bin.exists(): + return {} + if not cache_key: + return {} + try: + raw = p_bin.read_bytes() + # Format: nonce (12) + tag (16) + ciphertext + nonce = raw[:12] + tag = raw[12:28] + ct = raw[28:] + plaintext = aes_decrypt(cache_key, nonce, ct, tag) + return json.loads(plaintext.decode("utf-8")) + except Exception: + return {} + + +def _save_message_cache_full(d: Path, conv_id: str, cache: dict, cache_key: bytes): + """Write the full cache dict encrypted to disk.""" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + p = d / f"{conv_id}.bin" + plaintext = json.dumps(cache, ensure_ascii=False).encode("utf-8") + _key, nonce, ct, tag = aes_encrypt(plaintext, key=cache_key) + p.write_bytes(nonce + tag + ct) + os.chmod(p, 0o600) + + +def _save_message_to_cache(email: str, conv_id: str, message_id: str, payload: dict, + cache_key: bytes | None = None): + d = get_key_dir(email) / "message_cache" + cache = _load_message_cache(email, conv_id, cache_key) + cache[message_id] = payload + if cache_key: + _save_message_cache_full(d, conv_id, cache, cache_key) + else: + # Fallback: plaintext (no identity key available yet) + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + p = d / f"{conv_id}.json" + p.write_text(json.dumps(cache, ensure_ascii=False), "utf-8") + os.chmod(p, 0o600) + + +# --------------------------------------------------------------------------- +# Verification storage (TOFU + explicit verification) +# --------------------------------------------------------------------------- + +def _save_known_identity_keys(email: str, keys: dict, local_key: bytes | None = None): + """Save TOFU identity key registry (encrypted with local_key).""" + p = get_key_dir(email) / "known_identity_keys.bin" + data = json.dumps({"version": 1, "keys": keys}).encode("utf-8") + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_known_identity_keys(email: str, local_key: bytes | None = None) -> dict: + """Load TOFU identity key registry. Returns empty dict on error. + + No plaintext fallback — these files were never stored unencrypted + (feature introduced after local encryption was implemented). + Accepting plaintext would allow an attacker with disk access to + inject fake identity keys and bypass TOFU warnings. + """ + p = get_key_dir(email) / "known_identity_keys.bin" + if not p.exists(): + return {} + raw = p.read_bytes() + try: + if local_key: + data = _decrypt_local(raw, local_key) + else: + data = raw + obj = json.loads(data) + return obj.get("keys", {}) + except Exception: + return {} + + +def _save_verified_contacts(email: str, contacts: dict, local_key: bytes | None = None): + """Save explicit verification state (encrypted with local_key).""" + p = get_key_dir(email) / "verified_contacts.bin" + data = json.dumps({"version": 1, "contacts": contacts}).encode("utf-8") + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_verified_contacts(email: str, local_key: bytes | None = None) -> dict: + """Load explicit verification state. Returns empty dict on error. + + No plaintext fallback — these files were never stored unencrypted. + Accepting plaintext would allow an attacker with disk access to + inject fake verification records (mark attacker as "verified"). + """ + p = get_key_dir(email) / "verified_contacts.bin" + if not p.exists(): + return {} + raw = p.read_bytes() + try: + if local_key: + data = _decrypt_local(raw, local_key) + else: + data = raw + obj = json.loads(data) + return obj.get("contacts", {}) + except Exception: + return {} + + +def _solve_pow(challenge: str, difficulty: int) -> str: + """Solve a proof-of-work challenge by finding a nonce with enough leading zero bits.""" + target_bytes = difficulty // 8 + target_bits = difficulty % 8 + mask = (0xFF << (8 - target_bits)) & 0xFF if target_bits else 0 + nonce = 0 + while True: + digest = hashlib.sha256(f"{challenge}{nonce}".encode()).digest() + # Fast path: check full zero bytes first + ok = True + for i in range(target_bytes): + if digest[i] != 0: + ok = False + break + if ok and target_bits: + if digest[target_bytes] & mask: + ok = False + if ok: + return str(nonce) + nonce += 1 + + +class ChatClient: + def __init__(self): + self.reader: ProtocolReader | None = None + self.writer: ProtocolWriter | None = None + self.raw_writer: asyncio.StreamWriter | None = None + self.session: dict | None = None + self.private_key = None # RSA private key (login only) + self.public_key = None # RSA public key (login only) + self.username: str = "" + self.email: str = "" + self._listener_task: asyncio.Task | None = None + self._response_queue: asyncio.Queue = asyncio.Queue() + self._notification_queue: asyncio.Queue = asyncio.Queue() + self._pending: dict[str, asyncio.Future] = {} + self._pairing_temp_private_key = None + self._reencrypt_progress_cb = None + self._logger = logging.getLogger("encrypted_chat.client") + + # Signal Protocol keys + self.identity_private = None # Ed25519PrivateKey + self.identity_public = None # Ed25519PublicKey + self.spk_private = None # X25519PrivateKey (current signed prekey) + self.spk_id: str = "" + self._prev_spk_private = None # Previous SPK for grace period (M4) + self._prev_spk_id: str = "" + self.opk_privates: dict[str, object] = {} # id -> X25519PrivateKey + self.sessions: dict[str, DoubleRatchet] = {} # "user_id:device_id" -> ratchet + self.sender_key_states: dict[str, SenderKeyState] = {} # conv_id -> own sender key + self.recv_sender_keys: dict[str, SenderKeyState] = {} # "conv_id:sender_id:device_id" -> their key + # Cache: user_id -> {identity_key (Ed25519PublicKey), username, email} + self._user_cache: dict[str, dict] = {} + self.connected: bool = False + self.login_rejected: bool = False + self._cache_key: bytes | None = None # AES key for encrypting message cache on disk + self._local_key: bytes | None = None # AES key for encrypting session/sender key files + # Multi-device support + self.device_id: str | None = None # This device's UUID + self._device_bundle_cache: dict[str, tuple[float, list[dict]]] = {} # user_id -> (ts, bundles) + # Queue of received messages to self-encrypt for multi-device access + self._pending_self_encrypt: list[dict] = [] + # Contact key verification (TOFU + explicit) + self._known_identity_keys: dict = {} # user_id -> {identity_key hex, first_seen, last_seen} + self._verified_contacts: dict = {} # user_id -> {identity_key hex, verified_at, method} + self._key_change_cb = None # callback(user_id, username, old_key_hex, was_verified) + + async def connect(self): + host = os.getenv("SERVER_HOST", "127.0.0.1") + port = int(os.getenv("SERVER_PORT", "9999")) + tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes") + tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes") + ssl_context = None + if tls_required and not tls_enabled: + raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.") + if tls_enabled: + insecure = os.getenv("TLS_INSECURE", "false").lower() in ("1", "true", "yes") + is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") + if insecure and not is_dev: + raise RuntimeError("TLS_INSECURE is only allowed when ENVIRONMENT=dev") + ssl_context = ssl.create_default_context() + ca_file = os.getenv("TLS_CA_FILE", "").strip() + if ca_file: + ssl_context.load_verify_locations(cafile=ca_file) + elif insecure: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + else: + self._logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.") + r, w = await asyncio.open_connection(host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context) + self.reader = ProtocolReader(r) + self.writer = ProtocolWriter(w) + self.raw_writer = w + self.connected = True + + async def _background_listener(self): + """Read messages from server, routing responses vs notifications.""" + while True: + msg = await self.reader.read_message() + if msg is None: + self.connected = False + # Fail all pending futures so send_and_recv doesn't hang + pending = dict(self._pending) + self._pending.clear() + err = ConnectionError("Server connection lost") + for fut in pending.values(): + if not fut.done(): + fut.set_exception(err) + break + if msg.get("type") in ("new_message", "messages_read", "message_deleted", + "conversation_created", "member_added", "member_removed", + "user_online", "user_offline", "online_users", + "group_invitation", "conversation_renamed", + "session_reset", + "message_reacted", "message_pinned", "message_unpinned", + "message_delivered", "username_changed"): + await self._notification_queue.put(msg) + else: + req_id = msg.get("request_id") + if req_id and req_id in self._pending: + fut = self._pending.pop(req_id) + if not fut.done(): + fut.set_result(msg) + else: + await self._response_queue.put(msg) + + async def send_and_recv(self, msg_type: str, timeout: float = 30.0, **kwargs) -> dict: + try: + request_id = str(uuid.uuid4()) + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._pending[request_id] = fut + await self.writer.send_request(msg_type, request_id=request_id, **kwargs) + except (ValueError, ConnectionError, OSError) as e: + self._pending.pop(request_id, None) + return { + "type": msg_type, + "status": "error", + "data": {"message": str(e) or "Connection lost."}, + } + try: + return await asyncio.wait_for(fut, timeout=timeout) + except asyncio.TimeoutError: + self._logger.warning("send_and_recv timeout for '%s' after %.0fs", msg_type, timeout) + return { + "type": msg_type, + "status": "error", + "data": {"message": f"Request timed out ({msg_type})"}, + } + except ConnectionError: + return { + "type": msg_type, + "status": "error", + "data": {"message": "Connection lost."}, + } + finally: + self._pending.pop(request_id, None) + + # ------------------------------------------------------------------ + # User info / identity key cache + # ------------------------------------------------------------------ + + async def _get_user_info(self, user_id: str = "", email: str = "") -> dict | None: + """Get user info from server, cache identity key. Performs TOFU check.""" + cached = self._user_cache.get(user_id) + if cached: + return cached + kwargs = {} + if user_id: + kwargs["user_id"] = user_id + elif email: + kwargs["email"] = email + else: + return None + resp = await self.send_and_recv("get_user_info", **kwargs) + if resp["status"] != "ok": + return None + data = resp["data"] + ik_bytes = decode_binary(data["identity_key"]) if data.get("identity_key") else None + info = { + "user_id": data["user_id"], + "username": data["username"], + "email": data["email"], + "identity_key": load_ed25519_public(ik_bytes) if ik_bytes else None, + "identity_key_bytes": ik_bytes, + } + # TOFU: check identity key against known keys + if ik_bytes: + status = self.check_identity_key(data["user_id"], ik_bytes) + info["identity_key_status"] = status + self._user_cache[data["user_id"]] = info + return info + + # ------------------------------------------------------------------ + # Contact Key Verification + # ------------------------------------------------------------------ + + def _load_verification_stores(self): + """Load TOFU and verification stores from disk.""" + if not self.email: + return + self._known_identity_keys = _load_known_identity_keys(self.email, self._local_key) + self._verified_contacts = _load_verified_contacts(self.email, self._local_key) + + def check_identity_key(self, user_id: str, identity_key_bytes: bytes) -> str: + """Check a user's identity key against TOFU registry. + + Returns: + "new" — first contact, key recorded (TOFU) + "trusted" — key matches previously seen, not explicitly verified + "verified" — key matches and explicitly verified + "changed" — key differs from recorded (WARNING) + "changed_verified" — key changed AND was previously verified (CRITICAL) + """ + ik_hex = identity_key_bytes.hex() + now = datetime.now(timezone.utc).isoformat() + known = self._known_identity_keys.get(user_id) + + if known is None: + # First time seeing this user — TOFU: trust on first use + self._known_identity_keys[user_id] = { + "identity_key": ik_hex, + "first_seen": now, + "last_seen": now, + } + if self.email: + _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) + return "new" + + if known["identity_key"] == ik_hex: + # Key matches — update last_seen + known["last_seen"] = now + if self.email: + _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) + # Check if explicitly verified + verified = self._verified_contacts.get(user_id) + if verified and verified.get("identity_key") == ik_hex: + return "verified" + return "trusted" + + # Key has CHANGED + was_verified = user_id in self._verified_contacts + old_key_hex = known["identity_key"] + + # Invoke callback for GUI/CLI warning + if self._key_change_cb: + username = "" + cached = self._user_cache.get(user_id) + if cached: + username = cached.get("username", "") + try: + self._key_change_cb(user_id, username, old_key_hex, was_verified, identity_key_bytes) + except Exception: + pass + + return "changed_verified" if was_verified else "changed" + + def verify_contact(self, user_id: str, identity_key_bytes: bytes, method: str = "manual"): + """Mark a contact's identity key as explicitly verified.""" + ik_hex = identity_key_bytes.hex() + now = datetime.now(timezone.utc).isoformat() + self._verified_contacts[user_id] = { + "identity_key": ik_hex, + "verified_at": now, + "method": method, + } + # Also ensure TOFU registry is up to date + if user_id not in self._known_identity_keys: + self._known_identity_keys[user_id] = { + "identity_key": ik_hex, + "first_seen": now, + "last_seen": now, + } + else: + self._known_identity_keys[user_id]["last_seen"] = now + if self.email: + _save_verified_contacts(self.email, self._verified_contacts, self._local_key) + _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) + # Update user cache status + cached = self._user_cache.get(user_id) + if cached: + cached["identity_key_status"] = "verified" + + def unverify_contact(self, user_id: str): + """Remove explicit verification for a contact.""" + self._verified_contacts.pop(user_id, None) + if self.email: + _save_verified_contacts(self.email, self._verified_contacts, self._local_key) + cached = self._user_cache.get(user_id) + if cached and cached.get("identity_key_status") == "verified": + cached["identity_key_status"] = "trusted" + + def accept_key_change(self, user_id: str, new_ik_bytes: bytes): + """Accept a changed identity key — update TOFU, remove old verification.""" + ik_hex = new_ik_bytes.hex() + now = datetime.now(timezone.utc).isoformat() + self._known_identity_keys[user_id] = { + "identity_key": ik_hex, + "first_seen": now, + "last_seen": now, + } + # Remove old verification — user must re-verify + self._verified_contacts.pop(user_id, None) + if self.email: + _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key) + _save_verified_contacts(self.email, self._verified_contacts, self._local_key) + # Update cache + cached = self._user_cache.get(user_id) + if cached: + cached["identity_key_status"] = "trusted" + + def get_verification_status(self, user_id: str) -> str: + """Get verification status for a user. + + Returns: "verified", "trusted", or "unverified". + """ + verified = self._verified_contacts.get(user_id) + if verified: + # Check key still matches + known = self._known_identity_keys.get(user_id) + if known and known.get("identity_key") == verified.get("identity_key"): + return "verified" + if user_id in self._known_identity_keys: + return "trusted" + return "unverified" + + def get_safety_number(self, peer_user_id: str) -> str | None: + """Get formatted safety number for a peer (requires both identity keys).""" + if not self.identity_public or not self.session: + return None + my_uid = self.session.get("user_id", "") + my_ik_bytes = serialize_ed25519_public(self.identity_public) + cached = self._user_cache.get(peer_user_id) + if not cached or not cached.get("identity_key_bytes"): + return None + return compute_safety_number(my_uid, my_ik_bytes, + peer_user_id, cached["identity_key_bytes"]) + + def get_my_fingerprint(self) -> str | None: + """Get formatted fingerprint for own identity key.""" + if not self.identity_public or not self.session: + return None + my_uid = self.session.get("user_id", "") + my_ik_bytes = serialize_ed25519_public(self.identity_public) + fp = compute_fingerprint(my_uid, my_ik_bytes) + return format_fingerprint(fp) + + def get_peer_fingerprint(self, peer_user_id: str) -> str | None: + """Get formatted fingerprint for a peer's identity key.""" + cached = self._user_cache.get(peer_user_id) + if not cached or not cached.get("identity_key_bytes"): + return None + fp = compute_fingerprint(peer_user_id, cached["identity_key_bytes"]) + return format_fingerprint(fp) + + def get_verification_qr_data(self) -> bytes | None: + """Get QR code payload bytes for own identity (for peer to scan).""" + if not self.identity_public or not self.session: + return None + my_uid = self.session.get("user_id", "") + my_ik_bytes = serialize_ed25519_public(self.identity_public) + return encode_verification_qr(my_uid, my_ik_bytes) + + def verify_qr_code(self, qr_data: bytes) -> tuple[bool, str, str]: + """Verify a scanned QR code against known identity keys. + + Returns (success, user_id, message). + """ + try: + user_id, ik_bytes = decode_verification_qr(qr_data) + except ValueError as e: + return False, "", f"Invalid QR code: {e}" + cached = self._user_cache.get(user_id) + if not cached: + return False, user_id, "Unknown user — not in your contacts." + if not cached.get("identity_key_bytes"): + return False, user_id, "No identity key on record for this user." + if cached["identity_key_bytes"] != ik_bytes: + return False, user_id, "Identity key MISMATCH — verification failed!" + # Keys match — mark as verified + self.verify_contact(user_id, ik_bytes, method="qr_code") + username = cached.get("username", user_id[:8]) + return True, user_id, f"Verified {username} via QR code." + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + async def register(self, username: str, password: str, email: str) -> tuple[bool, str]: + """Register user. Generates RSA + Ed25519 + prekeys.""" + self.username = username + self.email = email + pwd_bytes = bytearray(password.encode("utf-8")) if password else None + + try: + # RSA keys for login + priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) + if priv is None: + priv, pub = generate_rsa_keypair() + save_keys(email, priv, pub, password=bytes(pwd_bytes) if pwd_bytes else None) + self.private_key = priv + self.public_key = pub + + # Ed25519 identity keys + ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) + if ed_priv is None: + ed_priv, ed_pub = generate_identity_keypair() + _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None) + self.identity_private = ed_priv + self.identity_public = ed_pub + self._cache_key = derive_self_encryption_key(ed_priv) + self._local_key = derive_local_storage_key(ed_priv) + self._load_verification_stores() + finally: + if pwd_bytes: + pwd_bytes[:] = b'\x00' * len(pwd_bytes) + + pub_pem = serialize_public_key(pub).decode("utf-8") + ik_b64 = encode_binary(serialize_ed25519_public(ed_pub)) + + extra_fields: dict = {} + start = await self.send_and_recv( + "register", + username=username, + public_key=pub_pem, + email=email, + identity_key=ik_b64, + ) + # Handle PoW challenge (server under pressure) + if start.get("status") == "pow_required": + challenge = start["data"]["challenge"] + mac = start["data"]["mac"] + difficulty = start["data"]["difficulty"] + logger.info("Server requires proof-of-work (difficulty %d), solving...", difficulty) + nonce = _solve_pow(challenge, difficulty) + extra_fields = {"pow_challenge": challenge, "pow_mac": mac, "pow_nonce": nonce} + start = await self.send_and_recv( + "register", + username=username, + public_key=pub_pem, + email=email, + identity_key=ik_b64, + **extra_fields, + ) + if start["status"] != "ok": + return False, start["data"]["message"] + code = start["data"].get("code") + if code: + return True, code + return True, start["data"].get("message", "Check your email for the code.") + + async def confirm_registration(self, email: str, username: str, code: str) -> tuple[bool, str]: + confirm = await self.send_and_recv("register_confirm", email=email, code=code) + if confirm["status"] == "ok": + # Upload prekeys immediately after registration + await self._generate_and_upload_prekeys() + return True, f"Registered as '{username}' (ID: {confirm['data']['user_id']})" + return False, confirm["data"]["message"] + + async def _generate_and_upload_prekeys(self, keep_spk: bool = False): + """Generate SPK + OPKs and upload to server. + + If keep_spk=True, re-sign the existing SPK instead of generating a new + one. This is used after device pairing so both devices share the same + SPK and either can respond to X3DH. + """ + if not self.identity_private: + return + + if keep_spk and self.spk_private and self.spk_id: + # Re-sign existing SPK (both devices share the identity key) + spk_pub_bytes = serialize_x25519_public(self.spk_private.public_key()) + spk_sig = ed25519_sign(self.identity_private, spk_pub_bytes) + spk_data = { + "id": self.spk_id, + "public_key": encode_binary(spk_pub_bytes), + "signature": encode_binary(spk_sig), + } + else: + # Save current SPK as previous for grace period (M4: in-flight X3DH) + if self.spk_private and self.spk_id: + self._prev_spk_private = self.spk_private + self._prev_spk_id = self.spk_id + _save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key) + # Generate a brand-new signed prekey + spk = generate_signed_prekey(self.identity_private) + self.spk_private = spk["private"] + self.spk_id = spk["id"] + _save_spk(self.email, spk["private"], spk["id"], self._local_key) + spk_data = { + "id": spk["id"], + "public_key": encode_binary(serialize_x25519_public(spk["public"])), + "signature": encode_binary(spk["signature"]), + } + + # Generate one-time prekeys + opks = generate_one_time_prekeys(OPK_BATCH_SIZE) + for opk in opks: + self.opk_privates[opk["id"]] = opk["private"] + _save_opk_private(self.email, opk["id"], opk["private"], self._local_key) + + # Upload to server + otp_data = [ + {"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))} + for opk in opks + ] + resp = await self.send_and_recv( + "upload_prekeys", + signed_prekey=spk_data, + one_time_prekeys=otp_data, + ) + if resp.get("status") != "ok": + self._logger.warning("upload_prekeys failed: %s (will retry on login)", + resp.get("data", {}).get("message", "unknown")) + + async def _ensure_prekeys(self): + """Check OPK count and SPK age, replenish/rotate if needed. + + Uses single-roundtrip `ensure_prekeys` handler when available, + falls back to legacy two-step flow (get_prekey_count + upload_prekeys). + """ + resp = await self.send_and_recv("get_prekey_count") + if resp["status"] != "ok": + return + count = resp["data"].get("count", 0) + spk_created_at = resp["data"].get("spk_created_at", "") + + need_new_spk = False + if spk_created_at: + try: + created = datetime.fromisoformat(spk_created_at) + if created.tzinfo is None: + created = created.replace(tzinfo=timezone.utc) + age_days = (datetime.now(timezone.utc) - created).days + if age_days >= SPK_ROTATION_DAYS: + need_new_spk = True + self._logger.info("SPK is %d days old, rotating...", age_days) + except Exception: + need_new_spk = True + else: + # No SPK on server for this device — must upload one + need_new_spk = True + self._logger.info("No SPK on server for this device, uploading...") + + if count < OPK_REPLENISH_THRESHOLD or need_new_spk: + if count >= OPK_REPLENISH_THRESHOLD: + self._logger.info("SPK rotation triggered (OPK count OK: %d)", count) + else: + self._logger.info("OPK count low (%d), replenishing...", count) + await self._generate_and_upload_prekeys_batch(need_new_spk) + + async def _generate_and_upload_prekeys_batch(self, need_new_spk: bool = False): + """Generate and upload prekeys in a single round-trip via ensure_prekeys.""" + if not self.identity_private: + return + + kwargs: dict = {} + + # SPK + if need_new_spk: + if self.spk_private and self.spk_id: + self._prev_spk_private = self.spk_private + self._prev_spk_id = self.spk_id + _save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key) + spk = generate_signed_prekey(self.identity_private) + self.spk_private = spk["private"] + self.spk_id = spk["id"] + _save_spk(self.email, spk["private"], spk["id"], self._local_key) + kwargs["signed_prekey"] = { + "id": spk["id"], + "public_key": encode_binary(serialize_x25519_public(spk["public"])), + "signature": encode_binary(spk["signature"]), + } + + # OPKs + opks = generate_one_time_prekeys(OPK_BATCH_SIZE) + for opk in opks: + self.opk_privates[opk["id"]] = opk["private"] + _save_opk_private(self.email, opk["id"], opk["private"], self._local_key) + kwargs["one_time_prekeys"] = [ + {"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))} + for opk in opks + ] + + resp = await self.send_and_recv("ensure_prekeys", **kwargs) + if resp["status"] == "ok": + data = resp.get("data", {}) + self._logger.info("ensure_prekeys: count=%d, spk_uploaded=%s, otps_uploaded=%d", + data.get("count", 0), data.get("uploaded_spk", False), + data.get("uploaded_otps", 0)) + + # ------------------------------------------------------------------ + # Login + # ------------------------------------------------------------------ + + async def login(self, email: str, password: str) -> tuple[bool, str]: + """Login user. Returns (success, message).""" + self.email = email + + # Brute-force lockout check + remaining = _check_lockout(email) + if remaining > 0: + return False, f"Too many failed attempts. Try again in {remaining:.0f}s." + + pwd_bytes = bytearray(password.encode("utf-8")) if password else None + + try: + # Load RSA keys + priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) + if priv is None: + if err and "password" in err.lower(): + _record_failed_attempt(email) + return False, err or "No local keys found. Register first." + self.private_key = priv + self.public_key = pub + + # Load identity keys + ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) + finally: + if pwd_bytes: + pwd_bytes[:] = b'\x00' * len(pwd_bytes) + + if ed_priv is not None: + self.identity_private = ed_priv + self.identity_public = ed_pub + self._cache_key = derive_self_encryption_key(ed_priv) + self._local_key = derive_local_storage_key(ed_priv) + self._load_verification_stores() + + # Load SPK + spk_priv, spk_id = _load_spk(email, self._local_key) + if spk_priv: + self.spk_private = spk_priv + self.spk_id = spk_id + + # Load previous SPK for grace period (M4) + prev_spk_priv, prev_spk_id = _load_prev_spk(email, self._local_key) + if prev_spk_priv: + self._prev_spk_private = prev_spk_priv + self._prev_spk_id = prev_spk_id + + # Load device_id from disk + self.device_id = _load_device_id(email) + + # RSA challenge-response login + start = await self.send_and_recv("login_start", email=email) + if start["status"] != "ok": + return False, start["data"]["message"] + + challenge = decode_binary(start["data"]["challenge"]) + signature = rsa_sign(self.private_key, challenge) + login_kwargs = {"email": email, "signature": encode_binary(signature), + "client_version": VERSION} + if self.device_id: + login_kwargs["device_id"] = self.device_id + finish = await self.send_and_recv("login_finish", **login_kwargs) + if finish["status"] == "ok": + self.session = finish["data"] + self.username = self.session.get("username", "") + # Store device_id from server + self.device_id = finish["data"].get("device_id") + if self.device_id: + _save_device_id(email, self.device_id) + # Replenish prekeys in background — after pairing, the new device + # has no local OPK private keys so we must generate fresh ones + # (server-side OPKs have no matching private keys on this device). + # Use keep_spk=True to preserve the shared SPK so both devices + # can respond to X3DH. + opk_dir = get_key_dir(self.email) / "opk_private" + has_local_opks = opk_dir.exists() and any(opk_dir.iterdir()) + if has_local_opks: + asyncio.create_task(self._ensure_prekeys()) + else: + self._logger.info("No local OPKs (likely new device). Generating fresh OPKs, keeping SPK.") + asyncio.create_task(self._generate_and_upload_prekeys(keep_spk=True)) + _clear_lockout(email) + return True, f"Logged in as '{self.username}' (ID: {self.session['user_id']})" + return False, finish["data"]["message"] + + # ------------------------------------------------------------------ + # Pairing (device pairing — transfers RSA + identity keys) + # ------------------------------------------------------------------ + + async def pairing_start(self, email: str) -> tuple[bool, str]: + """Start device pairing. Returns (success, code/message).""" + temp_priv, temp_pub = generate_rsa_keypair(2048) + self._pairing_temp_private_key = temp_priv + temp_pub_pem = serialize_public_key(temp_pub).decode("utf-8") + resp = await self.send_and_recv("pairing_start", email=email, temp_public_key=temp_pub_pem) + if resp["status"] == "ok": + self._pairing_poll_token = resp["data"].get("poll_token", "") + return True, resp["data"]["code"] + return False, resp["data"]["message"] + + async def pairing_wait(self, code: str, email: str, password: str, timeout: int = 300) -> tuple[bool, str]: + """Wait for pairing payload and import keys. Returns (success, message).""" + if not self._pairing_temp_private_key: + return False, "Pairing not started." + from crypto_utils import aes_decrypt as _aes_decrypt + poll_token = getattr(self, "_pairing_poll_token", "") + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + resp = await self.send_and_recv("pairing_poll", code=code, poll_token=poll_token) + if resp["status"] != "ok": + return False, resp["data"]["message"] + if not resp["data"].get("ready"): + await asyncio.sleep(2.0) + continue + payload = resp["data"]["payload"] + try: + # Decrypt AES key with temp RSA key + from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding + from cryptography.hazmat.primitives import hashes as rsa_hashes + enc_aes_key = decode_binary(payload["encrypted_key"]) + aes_key = self._pairing_temp_private_key.decrypt( + enc_aes_key, + rsa_padding.OAEP( + mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()), + algorithm=rsa_hashes.SHA256(), + label=None, + ), + ) + nonce = decode_binary(payload["iv"]) + ct = decode_binary(payload["ciphertext"]) + tag = decode_binary(payload["tag"]) + keys_json = _aes_decrypt(aes_key, nonce, ct, tag) + keys_data = json.loads(keys_json) + + pwd_bytes = bytearray(password.encode("utf-8")) if password else None + + try: + # Import RSA key + rsa_priv = load_private_key(keys_data["rsa_private"].encode(), password=None) + rsa_pub = rsa_priv.public_key() + save_keys(email, rsa_priv, rsa_pub, password=bytes(pwd_bytes) if pwd_bytes else None) + + # Import identity keys + ed_priv = load_ed25519_private(bytes.fromhex(keys_data["identity_private"])) + ed_pub = ed_priv.public_key() + _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None) + finally: + if pwd_bytes: + pwd_bytes[:] = b'\x00' * len(pwd_bytes) + + self.email = email + self.private_key = rsa_priv + self.public_key = rsa_pub + self.identity_private = ed_priv + self.identity_public = ed_pub + self._cache_key = derive_self_encryption_key(ed_priv) + self._local_key = derive_local_storage_key(ed_priv) + self._load_verification_stores() + self._pairing_temp_private_key = None + + # Multi-device: new device generates own SPK + OPKs on first + # login. No session/sender key import needed — each device + # has independent Double Ratchet sessions. + + return True, "Pairing complete." + except Exception as e: + return False, f"Failed to import keys: {e}" + return False, "Pairing timed out." + + async def authorize_device(self, code: str) -> tuple[bool, str]: + """Authorize a new device by sending all keys to it.""" + if not self.private_key or not self.identity_private: + return False, "Not logged in." + claim = await self.send_and_recv("pairing_claim", code=code) + if claim["status"] != "ok": + return False, claim["data"]["message"] + + temp_pub_pem = claim["data"]["temp_public_key"].encode("utf-8") + temp_pub = load_public_key(temp_pub_pem) + + # Phase 1: Re-encrypt message history so new device can read old + # messages via self-encryption key. This also advances ratchet states + # for any previously-unfetched messages. + try: + await self.reencrypt_history() + except Exception as e: + self._logger.warning("Re-encryption failed: %s", e) + + # Phase 2: Build keys payload — only RSA + identity key. + # Multi-device: new device generates own SPK + OPKs, creates independent + # sessions. No session/sender key transfer needed. + keys_data = { + "rsa_private": serialize_private_key(self.private_key, password=None).decode(), + "identity_private": serialize_ed25519_private_raw(self.identity_private).hex(), + } + + # Phase 3: Encrypt and send keys to new device + from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding + from cryptography.hazmat.primitives import hashes as rsa_hashes + plaintext = json.dumps(keys_data).encode() + aes_key, nonce, ct, tag = aes_encrypt(plaintext) + enc_aes_key = temp_pub.encrypt( + aes_key, + rsa_padding.OAEP( + mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()), + algorithm=rsa_hashes.SHA256(), + label=None, + ), + ) + payload = { + "encrypted_key": encode_binary(enc_aes_key), + "iv": encode_binary(nonce), + "ciphertext": encode_binary(ct), + "tag": encode_binary(tag), + } + resp = await self.send_and_recv("pairing_send", code=code, payload=payload) + if resp["status"] == "ok": + return True, "Device authorized." + return False, resp["data"]["message"] + + # ------------------------------------------------------------------ + # Password change (local key re-encryption only) + # ------------------------------------------------------------------ + + def change_password(self, old_password: str, new_password: str) -> tuple[bool, str]: + """Change password for local key encryption (RSA + identity key). + + Returns (success, message). + """ + if not self.email: + return False, "Not logged in." + + old_pwd = bytearray(old_password.encode("utf-8")) + new_pwd = bytearray(new_password.encode("utf-8")) + try: + # 1. Verify old password by loading keys + priv, pub, err = load_keys(self.email, password=bytes(old_pwd)) + if priv is None: + return False, "Wrong current password." + + ed_priv, ed_pub = _load_identity_keys(self.email, password=bytes(old_pwd)) + if ed_priv is None: + return False, "Failed to load identity key." + + # 2. Re-save with new password + save_keys(self.email, priv, pub, password=bytes(new_pwd)) + _save_identity_keys(self.email, ed_priv, ed_pub, password=bytes(new_pwd)) + + return True, "Password changed successfully." + finally: + old_pwd[:] = b'\x00' * len(old_pwd) + new_pwd[:] = b'\x00' * len(new_pwd) + + async def change_username(self, new_username: str) -> tuple[bool, str]: + """Change display name on server.""" + if not self.session: + return False, "Not logged in." + new_username = new_username.strip() + if not new_username or len(new_username) > 100: + return False, "Username must be 1-100 characters." + resp = await self.send_and_recv("change_username", username=new_username) + if resp["status"] == "ok": + self.username = resp["data"]["username"] + if self.session: + self.session["username"] = self.username + return True, "Username changed." + return False, resp["data"].get("message", "Unknown error") + + # ------------------------------------------------------------------ + # Key rotation (RSA login key only) + # ------------------------------------------------------------------ + + async def rotate_keys(self, username: str, password: str) -> tuple[bool, str]: + """Rotate RSA keypair to revoke other devices.""" + if not self.session or self.session.get("username") != username: + return False, "Not logged in." + pwd_bytes = password.encode("utf-8") if password else None + priv, pub = generate_rsa_keypair() + save_keys(self.email, priv, pub, password=pwd_bytes) + self.private_key = priv + self.public_key = pub + pub_pem = serialize_public_key(pub).decode("utf-8") + resp = await self.send_and_recv("rotate_keys", public_key=pub_pem) + if resp["status"] == "ok": + return True, "RSA login keys rotated." + return False, resp["data"]["message"] + + # ------------------------------------------------------------------ + # Session management (X3DH + Double Ratchet) + # ------------------------------------------------------------------ + + async def _get_device_bundles(self, peer_user_id: str) -> list[dict]: + """Get per-device key bundles for a peer. Caches for 5 minutes.""" + import time + cached = self._device_bundle_cache.get(peer_user_id) + if cached: + ts, bundles = cached + if time.time() - ts < 300: + return bundles + + resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id) + if resp["status"] != "ok": + raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}") + + data = resp["data"] + ik_b64 = data.get("identity_key", "") + + device_bundles = data.get("device_bundles") + if device_bundles: + # Attach identity_key to each bundle + for b in device_bundles: + b["identity_key"] = ik_b64 + else: + # Old server: wrap flat response as single-entry list + device_bundles = [{ + "device_id": None, + "identity_key": ik_b64, + "signed_prekey_id": data.get("signed_prekey_id", ""), + "signed_prekey": data.get("signed_prekey", ""), + "spk_signature": data.get("spk_signature", ""), + "one_time_prekey_id": data.get("one_time_prekey_id"), + "one_time_prekey": data.get("one_time_prekey"), + }] + + self._device_bundle_cache[peer_user_id] = (time.time(), device_bundles) + return device_bundles + + async def _get_or_create_session(self, peer_user_id: str, + peer_device_id: str | None = None, + bundle: dict | None = None) -> DoubleRatchet: + """Load existing session or create one via X3DH. + + If peer_device_id is set, sessions are keyed by "user_id:device_id". + If bundle is provided, it's used instead of fetching from server. + """ + session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id + + # Check in-memory cache + if session_key in self.sessions: + return self.sessions[session_key] + + # Check on disk + ratchet = _load_session(self.email, peer_user_id, self._local_key, + peer_device_id=peer_device_id) + if ratchet: + self.sessions[session_key] = ratchet + return ratchet + + # Create new session via X3DH + if not bundle: + resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id) + if resp["status"] != "ok": + raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}") + bundle = resp["data"] + + ik_remote_bytes = decode_binary(bundle["identity_key"]) + ik_remote = load_ed25519_public(ik_remote_bytes) + + # TOFU: verify identity key before using it in X3DH + ik_status = self.check_identity_key(peer_user_id, ik_remote_bytes) + if ik_status in ("changed", "changed_verified"): + raise IdentityKeyChanged(peer_user_id, ik_remote_bytes, ik_status) + + spk_remote = load_x25519_public(decode_binary(bundle["signed_prekey"])) + spk_sig = decode_binary(bundle["spk_signature"]) + + opk_remote = None + opk_id = bundle.get("one_time_prekey_id") + if bundle.get("one_time_prekey"): + opk_remote = load_x25519_public(decode_binary(bundle["one_time_prekey"])) + + # Perform X3DH + shared_secret, ek_priv, ek_pub = x3dh_initiate( + self.identity_private, + ik_remote, + spk_remote, + spk_sig, + opk_remote, + ) + + # Initialize Double Ratchet as Alice + ratchet = DoubleRatchet.init_alice(shared_secret, spk_remote) + self.sessions[session_key] = ratchet + _save_session(self.email, peer_user_id, ratchet, self._local_key, + peer_device_id=peer_device_id) + + # Build X3DH header for first message + x3dh_header = { + "ik": encode_binary(serialize_ed25519_public(self.identity_public)), + "ek": encode_binary(serialize_x25519_public(ek_pub)), + } + if opk_id: + x3dh_header["opk_id"] = opk_id + + # Cache the x3dh header for the next send_message call + ratchet._x3dh_header = x3dh_header + + # Cache remote user info + self._user_cache[peer_user_id] = { + "user_id": peer_user_id, + "identity_key": ik_remote, + "identity_key_bytes": ik_remote_bytes, + "identity_key_status": ik_status, + } + + return ratchet + + def _process_x3dh_header(self, sender_id: str, x3dh_header: dict, + sender_device_id: str | None = None, + spk_override=None) -> DoubleRatchet: + """Process an incoming X3DH header to establish session as Bob. + + Args: + spk_override: If provided, use this SPK private key instead of self.spk_private. + Used for grace period fallback (M4). + """ + ik_remote_bytes = decode_binary(x3dh_header["ik"]) + ik_remote = load_ed25519_public(ik_remote_bytes) + + # TOFU: verify identity key before using it in X3DH + ik_status = self.check_identity_key(sender_id, ik_remote_bytes) + if ik_status in ("changed", "changed_verified"): + raise IdentityKeyChanged(sender_id, ik_remote_bytes, ik_status) + + ek_remote = load_x25519_public(decode_binary(x3dh_header["ek"])) + + opk_id = x3dh_header.get("opk_id") + opk_priv = None + if opk_id: + opk_priv = _load_opk_private(self.email, opk_id, self._local_key) + if opk_priv: + _delete_opk_private(self.email, opk_id) + + spk_priv = spk_override if spk_override else self.spk_private + + shared_secret = x3dh_respond( + self.identity_private, + spk_priv, + ik_remote, + ek_remote, + opk_priv, + ) + + spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None + ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub)) + + session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id + self.sessions[session_key] = ratchet + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + + self._user_cache[sender_id] = { + "user_id": sender_id, + "identity_key": ik_remote, + "identity_key_bytes": ik_remote_bytes, + "identity_key_status": ik_status, + } + + return ratchet + + # ------------------------------------------------------------------ + # Conversations + # ------------------------------------------------------------------ + + async def create_conversation(self, member_emails: list[str], name: str | None = None) -> tuple[str | None, str]: + kwargs = {"members": member_emails} + if name: + kwargs["name"] = name + resp = await self.send_and_recv("create_conversation", **kwargs) + if resp["status"] == "ok": + return resp["data"]["conversation_id"], "OK" + return None, resp["data"]["message"] + + async def remove_member(self, conv_id: str, user_id: str) -> tuple[bool, str]: + resp = await self.send_and_recv("remove_member", conversation_id=conv_id, user_id=user_id) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def leave_group(self, conv_id: str) -> tuple[bool, str]: + """Leave a group conversation.""" + resp = await self.send_and_recv("leave_group", conversation_id=conv_id) + if resp["status"] == "ok": + # Clean up local sender key state for this group + self.sender_key_states.pop(conv_id, None) + # Remove received sender keys for this conversation + to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")] + for k in to_remove: + self.recv_sender_keys.pop(k, None) + return True, "OK" + return False, resp["data"]["message"] + + async def rename_conversation(self, conv_id: str, name: str) -> tuple[bool, str]: + """Rename a group conversation (creator only).""" + resp = await self.send_and_recv("rename_conversation", conversation_id=conv_id, name=name) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def delete_conversation(self, conv_id: str) -> tuple[bool, str]: + """Delete a conversation (leave + server cleans up if empty).""" + resp = await self.send_and_recv("delete_conversation", conversation_id=conv_id) + if resp["status"] == "ok": + # Clean up local sender key state + self.sender_key_states.pop(conv_id, None) + to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")] + for k in to_remove: + self.recv_sender_keys.pop(k, None) + return True, "OK" + return False, resp["data"]["message"] + + async def add_member(self, conv_id: str, email: str) -> tuple[bool, str]: + resp = await self.send_and_recv("add_member", conversation_id=conv_id, email=email) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def accept_invitation(self, conv_id: str) -> tuple[bool, str]: + """Accept a group invitation.""" + resp = await self.send_and_recv("accept_invitation", conversation_id=conv_id) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def decline_invitation(self, conv_id: str) -> tuple[bool, str]: + """Decline a group invitation.""" + resp = await self.send_and_recv("decline_invitation", conversation_id=conv_id) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def list_invitations(self) -> list[dict]: + """List pending group invitations.""" + resp = await self.send_and_recv("list_invitations") + if resp["status"] == "ok": + return resp["data"]["invitations"] + return [] + + async def list_conversations(self) -> list[dict]: + resp = await self.send_and_recv("list_conversations") + if resp["status"] == "ok": + return resp["data"]["conversations"] + return [] + + async def find_or_create_conversation(self, email: str) -> tuple[str | None, str]: + resp = await self.send_and_recv("find_conversation", email=email) + if resp["status"] != "ok": + return None, resp["data"]["message"] + conv_id = resp["data"]["conversation_id"] + if conv_id: + return conv_id, "OK" + return await self.create_conversation([email]) + + # ------------------------------------------------------------------ + # Send message + # ------------------------------------------------------------------ + + def _is_group(self, members: list[dict]) -> bool: + return len(members) > 2 + + async def send_message(self, conv_id: str, text: str, members: list[dict], + reply_to: str | None = None) -> tuple[bool, str | dict]: + """Encrypt and send a message. DM: per-recipient Double Ratchet. Group: Sender Keys. + + Returns (True, msg_dict) on success or (False, error_string) on failure. + msg_dict contains the full decrypted payload ready for display. + """ + my_user_id = self.session["user_id"] + + # Build plaintext payload + payload = { + "sender": self.username, + "text": text, + "reply_to": reply_to, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + if self._is_group(members): + return await self._send_group_message(conv_id, plaintext, members, payload) + else: + return await self._send_dm(conv_id, plaintext, members, payload) + + async def _send_dm(self, conv_id: str, plaintext: bytes, members: list[dict], + payload: dict | None = None) -> tuple[bool, str | dict]: + """Encrypt DM with per-device Double Ratchet.""" + my_user_id = self.session["user_id"] + recipients = [] + first_ratchet_header = None + + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + # Get all device bundles for this user + try: + device_bundles = await self._get_device_bundles(uid) + self._logger.debug("Got %d device bundles for %s", len(device_bundles), uid) + except Exception as e: + self._logger.warning("Failed to get device bundles for %s: %s", uid, e) + device_bundles = [] + + if not device_bundles: + # Fallback: try single session (legacy peer) + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_hdr = getattr(ratchet, "_x3dh_header", None) + if x3dh_hdr: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_hdr: + entry["x3dh_header"] = x3dh_hdr + recipients.append(entry) + if first_ratchet_header is None: + first_ratchet_header = result["header"] + _save_session(self.email, uid, ratchet, self._local_key) + continue + + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_hdr = getattr(ratchet, "_x3dh_header", None) + if x3dh_hdr: + delattr(ratchet, "_x3dh_header") + + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_hdr: + entry["x3dh_header"] = x3dh_hdr + recipients.append(entry) + + if first_ratchet_header is None: + first_ratchet_header = result["header"] + + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + # Encrypt self-copy with static key derived from identity (not ratchet) + # Uses SELF_DEVICE_ID so all own devices can read it + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + if not recipients: + return False, "No recipients." + + kwargs = { + "conversation_id": conv_id, + "ratchet_header": first_ratchet_header, + "recipients": recipients, + } + + resp = await self.send_and_recv("send_message", **kwargs) + if resp["status"] == "ok": + msg_data = resp.get("data", {}) + if payload is not None: + result = { + **payload, + "message_id": msg_data.get("message_id", ""), + "created_at": msg_data.get("created_at", ""), + "sender_id": self.session["user_id"], + "conversation_id": conv_id, + "read_by": [], + } + _save_message_to_cache(self.email, conv_id, result["message_id"], result, self._cache_key) + return True, result + return True, "Message sent." + return False, resp["data"]["message"] + + async def _send_group_message(self, conv_id: str, plaintext: bytes, + members: list[dict], + payload: dict | None = None) -> tuple[bool, str | dict]: + """Encrypt group message with Sender Keys.""" + my_user_id = self.session["user_id"] + + # Get or create sender key for this group + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if not sk: + sk = SenderKeyState() + self.sender_key_states[conv_id] = sk + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + # Distribute sender key to all members via pairwise ratchet + await self._distribute_sender_key(conv_id, members, sk) + + self.sender_key_states[conv_id] = sk + + # Encrypt with sender key + result = sk.encrypt(plaintext) + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + + # Build per-recipient entries (same ciphertext for all except self) + recipients = [] + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + recipients.append({ + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + }) + + # Self-encrypted copy (so other devices + history fetch can decrypt) + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + ratchet_header = {"dh_pub": "00" * 32, "n": 0, "pn": 0} # Dummy for groups + + kwargs = { + "conversation_id": conv_id, + "ratchet_header": ratchet_header, + "recipients": recipients, + "sender_chain_id": encode_binary(bytes.fromhex(result["chain_id"])), + "sender_chain_n": result["n"], + } + + resp = await self.send_and_recv("send_message", **kwargs) + if resp["status"] == "ok": + msg_data = resp.get("data", {}) + if payload is not None: + result_msg = { + **payload, + "message_id": msg_data.get("message_id", ""), + "created_at": msg_data.get("created_at", ""), + "sender_id": self.session["user_id"], + "conversation_id": conv_id, + "read_by": [], + } + _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key) + return True, result_msg + return True, "Message sent." + return False, resp["data"]["message"] + + async def _distribute_sender_key(self, conv_id: str, members: list[dict], + sk: SenderKeyState): + """Send own sender key to all group members via pairwise Double Ratchet (per-device).""" + my_user_id = self.session["user_id"] + exported_key = sk.export_key() + + # Build a special "sender_key_distribution" payload + payload = { + "sender": self.username, + "text": "", + "reply_to": None, + "timestamp": datetime.now(timezone.utc).isoformat(), + "_sender_key": { + "conv_id": conv_id, + "key": encode_binary(exported_key), + "sender_device_id": self.device_id, + }, + } + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + # Send as DM to each member's devices (per-device encryption) + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + try: + # Get all device bundles for this user + try: + device_bundles = await self._get_device_bundles(uid) + except Exception: + device_bundles = [] + + if not device_bundles: + # Fallback: legacy single-device + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_header = getattr(ratchet, "_x3dh_header", None) + if x3dh_header: + delattr(ratchet, "_x3dh_header") + + recipient_entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_header: + recipient_entry["x3dh_header"] = x3dh_header + kwargs = { + "conversation_id": conv_id, + "ratchet_header": result["header"], + "recipients": [recipient_entry], + } + await self.send_and_recv("send_message", **kwargs) + _save_session(self.email, uid, ratchet, self._local_key) + else: + # Per-device encryption + recipients = [] + first_rh = None + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_header = getattr(ratchet, "_x3dh_header", None) + if x3dh_header: + delattr(ratchet, "_x3dh_header") + + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_header: + entry["x3dh_header"] = x3dh_header + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + kwargs = { + "conversation_id": conv_id, + "ratchet_header": first_rh, + "recipients": recipients, + } + await self.send_and_recv("send_message", **kwargs) + except Exception as e: + self._logger.warning("Failed to distribute sender key to %s: %s", uid, e) + + # ------------------------------------------------------------------ + # Decrypt messages + # ------------------------------------------------------------------ + + def _decrypt_message(self, msg_data: dict) -> dict: + """Decrypt a single message (DM or group).""" + # Check for self-encrypted marker FIRST — after re-encryption, + # group messages will have {"self": true} ratchet_header but still + # have sender_chain_id at message level. + rh = msg_data.get("ratchet_header", {}) + if isinstance(rh, dict) and rh.get("self"): + return self._decrypt_dm(msg_data) + + if msg_data.get("sender_chain_id"): + return self._decrypt_group(msg_data) + else: + return self._decrypt_dm(msg_data) + + def _decrypt_dm(self, msg_data: dict) -> dict: + """Decrypt DM using Double Ratchet with sender, or static key for self-copies.""" + sender_id = msg_data.get("sender_id", "") + sender_device_id = msg_data.get("sender_device_id") + ratchet_header = msg_data.get("ratchet_header", {}) + ct_b64 = msg_data.get("encrypted_content", "") + nonce_b64 = msg_data.get("nonce", "") + + if not ct_b64 or not nonce_b64: + raise ValueError("Missing ciphertext or nonce") + + ciphertext = decode_binary(ct_b64) + nonce = decode_binary(nonce_b64) + + # Self-encrypted message (own sent message copy) + if isinstance(ratchet_header, dict) and ratchet_header.get("self"): + self_key = derive_self_encryption_key(self.identity_private) + ct = ciphertext[:-16] + tag = ciphertext[-16:] + plaintext = aes_decrypt(self_key, nonce, ct, tag) + else: + x3dh_header = msg_data.get("x3dh_header") + + # Session key: "sender_id:sender_device_id" or just "sender_id" for legacy + session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id + + # Try to load existing session + ratchet = self.sessions.get(session_key) + if not ratchet: + ratchet = _load_session(self.email, sender_id, self._local_key, + peer_device_id=sender_device_id) + if ratchet: + self.sessions[session_key] = ratchet + + if ratchet and not x3dh_header: + # Normal case: existing session, no X3DH header + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + elif x3dh_header: + if ratchet: + # Existing session + X3DH header: sender may have reset. + backup = ratchet.export_state() + try: + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + except Exception: + restored = DoubleRatchet.import_state(backup) + self.sessions[session_key] = restored + _save_session(self.email, sender_id, restored, self._local_key, + peer_device_id=sender_device_id) + ratchet = self._process_x3dh_header(sender_id, x3dh_header, + sender_device_id=sender_device_id) + try: + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + except Exception: + if self._prev_spk_private: + ratchet = self._process_x3dh_header( + sender_id, x3dh_header, + sender_device_id=sender_device_id, + spk_override=self._prev_spk_private) + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + else: + raise + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + else: + ratchet = self._process_x3dh_header(sender_id, x3dh_header, + sender_device_id=sender_device_id) + try: + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + except Exception: + if self._prev_spk_private: + ratchet = self._process_x3dh_header( + sender_id, x3dh_header, + sender_device_id=sender_device_id, + spk_override=self._prev_spk_private) + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + else: + raise + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + else: + raise ValueError(f"No session for sender {sender_id}") + + plaintext = unpad_plaintext(plaintext) + payload = json.loads(plaintext) + + # Handle sender key distribution messages + if "_sender_key" in payload: + sk_data = payload["_sender_key"] + sk_conv_id = sk_data["conv_id"] + sk_key = decode_binary(sk_data["key"]) + sk_sender_device_id = sk_data.get("sender_device_id") + recv_sk = SenderKeyState.from_key(sk_key) + if sk_sender_device_id: + cache_key = f"{sk_conv_id}:{sender_id}:{sk_sender_device_id}" + else: + cache_key = f"{sk_conv_id}:{sender_id}" + self.recv_sender_keys[cache_key] = recv_sk + _save_recv_sender_key(self.email, sk_conv_id, sender_id, recv_sk, self._local_key, + sender_device_id=sk_sender_device_id) + # Return empty — this is a control message, not user-visible + return None + + return payload + + def _decrypt_group(self, msg_data: dict) -> dict: + """Decrypt group message using sender's Sender Key.""" + sender_id = msg_data.get("sender_id", "") + sender_device_id = msg_data.get("sender_device_id") + conv_id = msg_data.get("conversation_id", "") + chain_id_b64 = msg_data.get("sender_chain_id", "") + chain_n = msg_data.get("sender_chain_n", 0) + ct_b64 = msg_data.get("encrypted_content", "") + nonce_b64 = msg_data.get("nonce", "") + + if not ct_b64 or not nonce_b64 or not chain_id_b64: + raise ValueError("Missing group message fields") + + ciphertext = decode_binary(ct_b64) + nonce = decode_binary(nonce_b64) + chain_id = decode_binary(chain_id_b64) + + my_user_id = self.session["user_id"] + + # If we sent this message, use our own sender key + if sender_id == my_user_id: + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if sk: + self.sender_key_states[conv_id] = sk + if not sk: + raise ValueError("Own sender key not found") + # For our own messages, we can't decrypt from sender key (it's already advanced) + # Return a placeholder — the server echoed our ciphertext + raise ValueError("Cannot decrypt own group message from sender key") + + # Use received sender key — try with sender_device_id first, fall back to without + sk = None + if sender_device_id: + cache_key = f"{conv_id}:{sender_id}:{sender_device_id}" + sk = self.recv_sender_keys.get(cache_key) + if not sk: + sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key, + sender_device_id=sender_device_id) + if sk: + self.recv_sender_keys[cache_key] = sk + + if not sk: + # Fallback: try without device_id (legacy or same-device) + cache_key = f"{conv_id}:{sender_id}" + sk = self.recv_sender_keys.get(cache_key) + if not sk: + sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key) + if sk: + self.recv_sender_keys[cache_key] = sk + + if not sk: + raise ValueError(f"No sender key for {sender_id} in conversation {conv_id}") + + plaintext = unpad_plaintext(sk.decrypt(chain_id.hex(), chain_n, ciphertext, nonce)) + _save_recv_sender_key(self.email, conv_id, sender_id, sk, self._local_key, + sender_device_id=sender_device_id) + + return json.loads(plaintext) + + # ------------------------------------------------------------------ + # Get/decrypt messages (batch) + # ------------------------------------------------------------------ + + async def get_messages(self, conv_id: str, limit: int = 50, offset: int = 0) -> list[dict]: + cache = _load_message_cache(self.email, conv_id, self._cache_key) + my_user_id = self.session["user_id"] if self.session else "" + + # Incremental sync: use stored server timestamp from last successful fetch. + after_ts = None + if cache and offset == 0: + after_ts = cache.get("__last_server_ts", {}).get("ts") + + req_params = {"conversation_id": conv_id, "limit": limit, "offset": offset} + if after_ts: + req_params["after_ts"] = after_ts + resp = await self.send_and_recv("get_messages", **req_params) + + if resp["status"] != "ok": + # Offline fallback: return from cache if available + if cache and offset == 0: + return self._build_from_cache(cache) + return [] + + raw_messages = resp["data"]["messages"] + raw_messages.reverse() # Server returns DESC, reverse to ASC + + # Save latest server timestamp for next incremental sync + if raw_messages: + # raw_messages are now ASC; last one is newest + newest_ts = raw_messages[-1].get("created_at", "") + if newest_ts: + cache["__last_server_ts"] = {"ts": newest_ts} + _save_message_to_cache(self.email, conv_id, "__last_server_ts", + {"ts": newest_ts}, cache_key=self._cache_key) + + # Decrypt new messages from server + new_decrypted = self._decrypt_raw_messages(raw_messages, cache, conv_id, my_user_id) + + # Confirm delivery for messages from others (fire-and-forget) + deliver_ids = [m["message_id"] for m in new_decrypted + if m.get("sender_id") and m["sender_id"] != my_user_id + and not m.get("deleted")] + if deliver_ids: + asyncio.ensure_future(self.confirm_delivery(conv_id, deliver_ids)) + + # Mark entire conversation as read (bulk — server handles filtering) + await self.mark_conversation_read(conv_id) + + # Flush self-encryption queue in background + if self._pending_self_encrypt: + asyncio.ensure_future(self._flush_self_encrypt()) + + if after_ts: + # Incremental: sync deletions, then build from cache + try: + del_resp = await self.send_and_recv("get_deleted_since", + conversation_id=conv_id, since=after_ts) + if del_resp.get("status") == "ok": + for del_id in del_resp.get("data", {}).get("message_ids", []): + cache.pop(del_id, None) + _save_message_to_cache(self.email, conv_id, del_id, {"deleted": True}, + cache_key=self._cache_key) + except Exception: + pass + return self._build_from_cache(cache) + + return new_decrypted + + def _build_from_cache(self, cache: dict) -> list[dict]: + """Build sorted message list from local cache (all messages).""" + messages = [] + for msg_id, p in cache.items(): + if p.get("_control") or msg_id.startswith("__"): + continue + entry = dict(p) + entry.setdefault("message_id", msg_id) + entry.setdefault("read_by", []) + entry.setdefault("delivered_to", []) + messages.append(entry) + messages.sort(key=lambda m: m.get("created_at", "")) + return messages + + def _decrypt_raw_messages(self, raw_messages: list, cache: dict, + conv_id: str, my_user_id: str) -> list[dict]: + """Decrypt server messages, update cache. Returns list of decrypted dicts.""" + decrypted = [] + for m in raw_messages: + msg_id = m["message_id"] + + if m.get("deleted_at"): + decrypted.append({ + "message_id": msg_id, + "sender": "", + "text": "", + "created_at": m["created_at"], + "read_by": [], + "sender_id": m.get("sender_id", ""), + "deleted": True, + }) + cache[msg_id] = {"deleted": True, "created_at": m["created_at"]} + continue + + # Check local cache first (ratchet keys are one-time use) + cached = cache.get(msg_id) + if cached and not cached.get("_control"): + cached["read_by"] = m.get("read_by", []) + cached["delivered_to"] = m.get("delivered_to", []) + cached["created_at"] = m["created_at"] + if m.get("reactions"): + cached["reactions"] = m["reactions"] + if m.get("pinned_at"): + cached["pinned_at"] = m["pinned_at"] + cached["pinned_by"] = m.get("pinned_by", "") + else: + cached.pop("pinned_at", None) + cached.pop("pinned_by", None) + decrypted.append(cached) + continue + if cached and cached.get("_control"): + continue + + try: + msg_data = { + "sender_id": m.get("sender_id", ""), + "sender_device_id": m.get("sender_device_id"), + "conversation_id": conv_id, + "ratchet_header": m.get("ratchet_header", {}), + "encrypted_content": m.get("encrypted_content", ""), + "nonce": m.get("nonce", ""), + "x3dh_header": m.get("x3dh_header"), + "sender_chain_id": m.get("sender_chain_id"), + "sender_chain_n": m.get("sender_chain_n"), + } + payload = self._decrypt_message(msg_data) + if payload is None: + _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True}, + cache_key=self._cache_key) + cache[msg_id] = {"_control": True} + continue + payload["message_id"] = msg_id + payload["created_at"] = m["created_at"] + payload["read_by"] = m.get("read_by", []) + payload["delivered_to"] = m.get("delivered_to", []) + payload["sender_id"] = m.get("sender_id", "") + if m.get("reactions"): + payload["reactions"] = m["reactions"] + if m.get("pinned_at"): + payload["pinned_at"] = m["pinned_at"] + payload["pinned_by"] = m.get("pinned_by", "") + decrypted.append(payload) + _save_message_to_cache(self.email, conv_id, msg_id, payload, + cache_key=self._cache_key) + cache[msg_id] = payload + if m.get("sender_id", "") != my_user_id: + self._pending_self_encrypt.append({ + "message_id": msg_id, + "payload": {k: v for k, v in payload.items() + if k not in ("message_id", "created_at", "read_by", + "delivered_to", "sender_id", "deleted")}, + }) + except Exception as e: + decrypted.append({ + "message_id": msg_id, + "sender": "???", + "text": f"[Decryption failed: {e}]", + "created_at": m["created_at"], + "read_by": [], + }) + return decrypted + + async def _flush_self_encrypt(self): + """Upload self-encrypted copies of received messages for multi-device access.""" + if not self._pending_self_encrypt or not self.identity_private: + return + self_key = derive_self_encryption_key(self.identity_private) + updates = [] + for item in list(self._pending_self_encrypt): + try: + plaintext = json.dumps(item["payload"], ensure_ascii=False).encode("utf-8") + _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key) + updates.append({ + "message_id": item["message_id"], + "encrypted_content": encode_binary(ct + tag), + "nonce": encode_binary(nonce), + }) + except Exception: + pass + self._pending_self_encrypt.clear() + if updates: + try: + for i in range(0, len(updates), 500): + batch = updates[i:i + 500] + await self.send_and_recv("reencrypt_messages", updates=batch) + except Exception as e: + self._logger.warning("Failed to self-encrypt received messages: %s", e) + + async def mark_read(self, conv_id: str, message_ids: list[str]): + if not message_ids: + return + await self.send_and_recv("mark_read", conversation_id=conv_id, message_ids=message_ids) + + async def mark_conversation_read(self, conv_id: str): + """Mark ALL unread messages in a conversation as read (server-side bulk).""" + try: + await self.send_and_recv("mark_conversation_read", conversation_id=conv_id) + except Exception: + pass # non-critical — don't fail message loading + + async def confirm_delivery(self, conv_id: str, message_ids: list[str]): + """Confirm delivery of messages (fire-and-forget, non-critical).""" + if not message_ids: + return + try: + await self.send_and_recv("confirm_delivery", + conversation_id=conv_id, message_ids=message_ids) + except Exception: + pass # non-critical + + def search_messages(self, conv_id: str, query: str) -> list[dict]: + """Search cached messages in a conversation. Returns matching messages.""" + cache = _load_message_cache(self.email, conv_id, self._cache_key) + query_lower = query.lower() + results = [] + for msg_id, payload in cache.items(): + if payload.get("deleted") or payload.get("_control") or payload.get("_sender_key"): + continue + text = payload.get("text", "") + if query_lower in text.lower(): + entry = dict(payload) + entry["message_id"] = msg_id + results.append(entry) + results.sort(key=lambda m: m.get("created_at", "")) + return results + + async def reset_session(self, peer_user_id: str, peer_device_id: str | None = None): + """Delete local session and notify peer to do the same.""" + if peer_device_id: + session_key = f"{peer_user_id}:{peer_device_id}" + else: + session_key = peer_user_id + self.sessions.pop(session_key, None) + _delete_session_file(self.email, peer_user_id, peer_device_id) + await self.send_and_recv("session_reset", + peer_user_id=peer_user_id, + peer_device_id=peer_device_id or "") + + def handle_session_reset_notification(self, from_user_id: str, from_device_id: str | None = None): + """Handle incoming session reset notification — delete the matching session.""" + if from_device_id: + session_key = f"{from_user_id}:{from_device_id}" + else: + session_key = from_user_id + self.sessions.pop(session_key, None) + _delete_session_file(self.email, from_user_id, from_device_id) + + # ------------------------------------------------------------------ + # Local message cache updates + # ------------------------------------------------------------------ + + def load_message_cache(self, conv_id: str) -> dict: + """Load cached messages for a conversation. Returns {msg_id: payload}.""" + if not self.email: + return {} + return _load_message_cache(self.email, conv_id, self._cache_key) + + def update_message_in_cache(self, conv_id: str, message_id: str, updates: dict): + """Update fields of a cached message on disk (synchronous).""" + if not self.email: + return + cache = _load_message_cache(self.email, conv_id, self._cache_key) + if message_id not in cache or cache[message_id].get("_control"): + return + for key, value in updates.items(): + if value is None: + cache[message_id].pop(key, None) + else: + cache[message_id][key] = value + d = get_key_dir(self.email) / "message_cache" + if self._cache_key: + _save_message_cache_full(d, conv_id, cache, self._cache_key) + + # ------------------------------------------------------------------ + # Reactions, Pins, Forwarding + # ------------------------------------------------------------------ + + async def react_message(self, message_id: str, reaction: str, action: str = "add") -> tuple[bool, str]: + """Add or remove a reaction on a message.""" + resp = await self.send_and_recv("react_message", + message_id=message_id, reaction=reaction, action=action) + if resp["status"] == "ok": + return True, "OK" + return False, resp.get("data", {}).get("message", "Failed") + + async def pin_message(self, message_id: str, conversation_id: str, action: str = "pin") -> tuple[bool, str]: + """Pin or unpin a message.""" + resp = await self.send_and_recv("pin_message", + message_id=message_id, conversation_id=conversation_id, action=action) + if resp["status"] == "ok": + return True, "OK" + return False, resp.get("data", {}).get("message", "Failed") + + async def get_pinned_messages(self, conversation_id: str) -> list[dict]: + """Get list of pinned messages for a conversation.""" + resp = await self.send_and_recv("get_pinned_messages", conversation_id=conversation_id) + if resp["status"] == "ok": + return resp["data"].get("messages", []) + return [] + + async def forward_message(self, target_conv_id: str, original_msg: dict, + target_members: list[dict]) -> tuple[bool, str | dict]: + """Forward a message to another conversation.""" + text = original_msg.get("text", "") + + payload = { + "sender": self.username, + "text": text, + "forwarded_from": { + "sender": original_msg.get("sender", ""), + "conversation_id": original_msg.get("conversation_id", ""), + "message_id": original_msg.get("message_id", ""), + }, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + # Forward image/file metadata (the encrypted blob is already on the server) + if original_msg.get("image"): + payload["image"] = original_msg["image"] + if not text: + payload["text"] = "" + if original_msg.get("file"): + payload["file"] = original_msg["file"] + if not text: + payload["text"] = "" + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + if self._is_group(target_members): + return await self._send_group_message(target_conv_id, plaintext, target_members, payload) + else: + return await self._send_dm(target_conv_id, plaintext, target_members, payload) + + # ------------------------------------------------------------------ + # Decrypt notification + # ------------------------------------------------------------------ + + def decrypt_notification(self, notif_data: dict) -> dict | None: + """Decrypt a new_message notification. Returns parsed payload or None. + + Supports new multi-device format (device_entries array) and legacy flat format. + """ + try: + conv_id = notif_data.get("conversation_id", "") + msg_id = notif_data.get("message_id", "") + sender_id = notif_data.get("sender_id", "") + sender_device_id = notif_data.get("sender_device_id") + my_user_id = self.session["user_id"] if self.session else "" + + # Extract per-device encrypted content from device_entries or flat fields + encrypted_content = "" + nonce = "" + ratchet_header = {} + x3dh_header = None + + device_entries = notif_data.get("device_entries") + if device_entries: + # Multi-device format: pick entry matching our device_id or SELF_DEVICE_ID + chosen = None + self_entry = None + for entry in device_entries: + eid = entry.get("device_id", "") + if eid == self.device_id: + chosen = entry + break + if eid == "00000000-0000-0000-0000-000000000000": + self_entry = entry + + # If sender is us, prefer self-encrypted entry + if sender_id == my_user_id: + chosen = self_entry or chosen + elif not chosen: + chosen = self_entry + + if not chosen: + self._logger.warning("No matching device_entry for device %s", self.device_id) + return None + + encrypted_content = chosen.get("encrypted_content", "") + nonce = chosen.get("nonce", "") + ratchet_header = chosen.get("ratchet_header") or notif_data.get("ratchet_header", {}) + x3dh_header = chosen.get("x3dh_header") or notif_data.get("x3dh_header") + else: + # Legacy flat format + encrypted_content = notif_data.get("encrypted_content", "") + nonce = notif_data.get("nonce", "") + ratchet_header = notif_data.get("ratchet_header", {}) + x3dh_header = notif_data.get("x3dh_header") + + msg_data = { + "sender_id": sender_id, + "sender_device_id": sender_device_id, + "conversation_id": conv_id, + "ratchet_header": ratchet_header, + "encrypted_content": encrypted_content, + "nonce": nonce, + "x3dh_header": x3dh_header, + "sender_chain_id": notif_data.get("sender_chain_id"), + "sender_chain_n": notif_data.get("sender_chain_n"), + } + payload = self._decrypt_message(msg_data) + if payload is None: + # Cache control message so get_messages skips it + if msg_id and conv_id: + _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True}, + cache_key=self._cache_key) + return None + payload["conversation_id"] = conv_id + payload["message_id"] = msg_id + payload["sender_id"] = sender_id + # Use server-compatible timestamp (no timezone suffix) for cache consistency + _ts = payload.get("timestamp", "") + if _ts: + # Strip timezone suffix (+00:00 or Z) to match server DATETIME format + _ts = _ts.replace("+00:00", "").replace("Z", "") + # Strip microseconds if present + if "." in _ts: + _ts = _ts[:_ts.index(".")] + payload["created_at"] = _ts + payload["read_by"] = [] + payload["delivered_to"] = [] + # Cache so get_messages doesn't re-decrypt (ratchet keys are one-time) + if msg_id and conv_id: + _save_message_to_cache(self.email, conv_id, msg_id, payload, + cache_key=self._cache_key) + # Queue self-encryption for received messages (multi-device access) + if sender_id != my_user_id and msg_id: + self._pending_self_encrypt.append({ + "message_id": msg_id, + "payload": {k: v for k, v in payload.items() + if k not in ("conversation_id", "message_id", "created_at", + "read_by", "delivered_to", "sender_id", "deleted")}, + }) + return payload + except IdentityKeyChanged: + raise # Must propagate to caller for key-change UI + except Exception as e: + self._logger.warning("Failed to decrypt notification: %s", e) + return None + + # ------------------------------------------------------------------ + # Delete message + # ------------------------------------------------------------------ + + async def delete_message(self, message_id: str) -> tuple[bool, str]: + resp = await self.send_and_recv("delete_message", message_id=message_id) + if resp["status"] == "ok": + return True, "Message deleted." + return False, resp["data"]["message"] + + # ------------------------------------------------------------------ + # Image sharing + # ------------------------------------------------------------------ + + async def send_image(self, conv_id: str, image_path: str, members: list[dict], + reply_to: str | None = None) -> tuple[bool, str]: + """Encrypt and upload an image, then send as a message.""" + try: + from PIL import Image + import io + except ImportError: + return False, "Pillow is required for image sharing. Install with: pip install Pillow" + + path = Path(image_path) + if not path.exists(): + return False, "File not found." + + try: + img = Image.open(path) + img.load() + except Exception as e: + return False, f"Cannot open image: {e}" + + # Try sending in original format/quality first + original_format = img.format or "JPEG" + if original_format.upper() not in ("JPEG", "PNG", "WEBP", "GIF", "BMP"): + original_format = "JPEG" + + # Read raw file bytes for original quality + image_bytes = path.read_bytes() + + # If encrypted size exceeds limit, progressively downscale + if MAX_IMAGE_BYTES > 0: + img_aes_key_test, _, ct_test, tag_test = aes_encrypt(image_bytes) + if len(ct_test) + len(tag_test) > MAX_IMAGE_BYTES: + # Convert to RGB for JPEG compression + if img.mode not in ("RGB", "L"): + img = img.convert("RGB") + # Try JPEG at high quality first, then reduce quality/dimensions + for quality in (92, 85, 75, 60): + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=quality) + image_bytes = buf.getvalue() + _, _, ct_test, tag_test = aes_encrypt(image_bytes) + if len(ct_test) + len(tag_test) <= MAX_IMAGE_BYTES: + break + else: + # Still too large — downscale dimensions + for max_dim in (3840, 2560, 1920, 1280): + if max(img.size) > max_dim: + img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=75) + image_bytes = buf.getvalue() + _, _, ct_test, tag_test = aes_encrypt(image_bytes) + if len(ct_test) + len(tag_test) <= MAX_IMAGE_BYTES: + break + + # Generate thumbnail + thumb = img.copy() + thumb.thumbnail((200, 200), Image.Resampling.LANCZOS) + if thumb.mode not in ("RGB", "L"): + thumb = thumb.convert("RGB") + thumb_buf = io.BytesIO() + thumb.save(thumb_buf, format="JPEG", quality=60) + thumbnail_b64 = encode_binary(thumb_buf.getvalue()) + + # Encrypt image with AES-256-GCM + img_aes_key, img_iv, img_ct, img_tag = aes_encrypt(image_bytes) + encrypted_image = img_ct + img_tag + + file_id = str(uuid.uuid4()) + file_size = len(encrypted_image) + + # Chunked upload + resp = await self.send_and_recv( + "upload_image_start", + conversation_id=conv_id, + file_id=file_id, + file_size=file_size, + ) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + upload_offset = 0 + while upload_offset < file_size: + chunk = encrypted_image[upload_offset:upload_offset + IMAGE_CHUNK_SIZE] + resp = await self.send_and_recv( + "upload_image_chunk", + file_id=file_id, + data=encode_binary(chunk), + ) + if resp["status"] != "ok": + return False, resp["data"]["message"] + upload_offset += len(chunk) + + resp = await self.send_and_recv("upload_image_end", file_id=file_id) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + # Build message payload with image info + image_info = { + "file_id": file_id, + "aes_key": encode_binary(img_aes_key), + "iv": encode_binary(img_iv), + "thumbnail": thumbnail_b64, + "filename": path.name, + "size": len(image_bytes), + } + + payload = { + "sender": self.username, + "text": "", + "reply_to": reply_to, + "timestamp": datetime.now(timezone.utc).isoformat(), + "image": image_info, + } + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + my_user_id = self.session["user_id"] + + if self._is_group(members): + # Group image: use sender key + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if not sk: + sk = SenderKeyState() + self.sender_key_states[conv_id] = sk + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + await self._distribute_sender_key(conv_id, members, sk) + + result = sk.encrypt(plaintext) + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + + recipients = [] + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + recipients.append({ + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + }) + + # Self-encrypted copy for sender + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0}, + recipients=recipients, + sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])), + sender_chain_n=result["n"], + image_file_id=file_id, + ) + else: + # DM image: per-device ratchet (same pattern as _send_dm) + recipients = [] + first_rh = None + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + try: + device_bundles = await self._get_device_bundles(uid) + except Exception: + device_bundles = [] + + if not device_bundles: + # Fallback: legacy single-device + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key) + else: + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + # Encrypt self-copy with static key + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header=first_rh, + recipients=recipients, + image_file_id=file_id, + ) + + if resp["status"] == "ok": + msg_data = resp.get("data", {}) + result_msg = { + **payload, + "message_id": msg_data.get("message_id", ""), + "created_at": msg_data.get("created_at", ""), + "sender_id": self.session["user_id"], + "conversation_id": conv_id, + "read_by": [], + } + _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key) + return True, result_msg + return False, resp["data"]["message"] + + async def send_file(self, conv_id: str, file_path: str, members: list[dict], + reply_to: str | None = None) -> tuple[bool, str | dict]: + """Encrypt and upload a file, then send as a message.""" + import mimetypes + + path = Path(file_path) + if not path.exists(): + return False, "File not found." + + try: + file_bytes = path.read_bytes() + except Exception as e: + return False, f"Cannot read file: {e}" + + mime_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream" + + # Encrypt file with AES-256-GCM + file_aes_key, file_iv, file_ct, file_tag = aes_encrypt(file_bytes) + encrypted_file = file_ct + file_tag + + file_id = str(uuid.uuid4()) + file_size = len(encrypted_file) + + # Chunked upload (reuse image upload infrastructure with file_type="file") + resp = await self.send_and_recv( + "upload_image_start", + conversation_id=conv_id, + file_id=file_id, + file_size=file_size, + file_type="file", + ) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + upload_offset = 0 + while upload_offset < file_size: + chunk = encrypted_file[upload_offset:upload_offset + IMAGE_CHUNK_SIZE] + resp = await self.send_and_recv( + "upload_image_chunk", + file_id=file_id, + data=encode_binary(chunk), + ) + if resp["status"] != "ok": + return False, resp["data"]["message"] + upload_offset += len(chunk) + + resp = await self.send_and_recv("upload_image_end", file_id=file_id) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + # Build message payload with file info + file_info = { + "file_id": file_id, + "aes_key": encode_binary(file_aes_key), + "iv": encode_binary(file_iv), + "filename": path.name, + "size": len(file_bytes), + "mime_type": mime_type, + } + + payload = { + "sender": self.username, + "text": "", + "reply_to": reply_to, + "timestamp": datetime.now(timezone.utc).isoformat(), + "file": file_info, + } + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + my_user_id = self.session["user_id"] + + if self._is_group(members): + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if not sk: + sk = SenderKeyState() + self.sender_key_states[conv_id] = sk + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + await self._distribute_sender_key(conv_id, members, sk) + + result = sk.encrypt(plaintext) + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + + recipients = [] + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + recipients.append({ + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + }) + + # Self-encrypted copy for sender + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0}, + recipients=recipients, + sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])), + sender_chain_n=result["n"], + image_file_id=file_id, + ) + else: + # DM file: per-device ratchet (same pattern as _send_dm) + recipients = [] + first_rh = None + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + try: + device_bundles = await self._get_device_bundles(uid) + except Exception: + device_bundles = [] + + if not device_bundles: + # Fallback: legacy single-device + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key) + else: + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + # Encrypt self-copy with static key + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header=first_rh, + recipients=recipients, + image_file_id=file_id, + ) + + if resp["status"] == "ok": + msg_data = resp.get("data", {}) + result_msg = { + **payload, + "message_id": msg_data.get("message_id", ""), + "created_at": msg_data.get("created_at", ""), + "sender_id": self.session["user_id"], + "conversation_id": conv_id, + "read_by": [], + } + _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key) + return True, result_msg + return False, resp["data"]["message"] + + async def download_file(self, file_id: str, file_info: dict) -> bytes | None: + """Download and decrypt a file. Returns decrypted file bytes or None.""" + chunks = [] + offset = 0 + while True: + resp = await self.send_and_recv( + "download_image", + file_id=file_id, + offset=offset, + ) + if resp["status"] != "ok": + return None + data = resp["data"] + chunk = decode_binary(data["data"]) + chunks.append(chunk) + offset += len(chunk) + if data.get("done"): + break + + encrypted_data = b"".join(chunks) + if len(encrypted_data) < 16: + return None + ciphertext = encrypted_data[:-16] + tag = encrypted_data[-16:] + + try: + file_aes_key = decode_binary(file_info["aes_key"]) + iv = decode_binary(file_info["iv"]) + return aes_decrypt(file_aes_key, iv, ciphertext, tag) + except Exception: + return None + + async def download_image(self, file_id: str, image_info: dict) -> bytes | None: + """Download and decrypt an image. Returns decrypted image bytes or None.""" + chunks = [] + offset = 0 + while True: + resp = await self.send_and_recv( + "download_image", + file_id=file_id, + offset=offset, + ) + if resp["status"] != "ok": + return None + data = resp["data"] + chunk = decode_binary(data["data"]) + chunks.append(chunk) + offset += len(chunk) + if data.get("done"): + break + + encrypted_data = b"".join(chunks) + if len(encrypted_data) < 16: + return None + ciphertext = encrypted_data[:-16] + tag = encrypted_data[-16:] + + try: + img_aes_key = decode_binary(image_info["aes_key"]) + iv = decode_binary(image_info["iv"]) + return aes_decrypt(img_aes_key, iv, ciphertext, tag) + except Exception: + return None + + # ------------------------------------------------------------------ + # Re-encrypt history (for device pairing) + # ------------------------------------------------------------------ + + async def reencrypt_history(self): + """Re-encrypt all cached messages with self-encryption key. + + After device pairing, the new device shares the same identity key + but cannot decrypt old messages (Double Ratchet keys are one-time use). + This re-encrypts all cached messages so they can be read using the + self-encryption key derived from the shared identity key. + """ + if not self.identity_private or not self.session: + return + + self_key = derive_self_encryption_key(self.identity_private) + + # Phase 1: Fetch & decrypt all messages to populate cache + # (messages the old device never opened won't be in cache yet) + try: + convs = await self.list_conversations() + total_convs = len(convs) + for ci, conv in enumerate(convs): + cid = conv.get("id") or conv.get("conversation_id") + if not cid: + continue + if self._reencrypt_progress_cb: + self._reencrypt_progress_cb( + f"Fetching messages: {ci + 1}/{total_convs} conversations..." + ) + offset = 0 + while True: + msgs = await self.get_messages(cid, limit=200, offset=offset) + if not msgs or len(msgs) < 200: + break + offset += len(msgs) + except Exception as e: + self._logger.warning("Failed to fetch messages for re-encryption: %s", e) + + # Phase 2: Read cache and re-encrypt + cache_dir = get_key_dir(self.email) / "message_cache" + if not cache_dir.exists(): + self._logger.info("No message cache to re-encrypt.") + return + + all_updates = [] + conv_ids = set() + for f in cache_dir.iterdir(): + if f.suffix in (".json", ".bin"): + conv_ids.add(f.stem) + + total_files = len(conv_ids) + for i, conv_id in enumerate(sorted(conv_ids)): + cache = _load_message_cache(self.email, conv_id, self._cache_key) + if not cache: + continue + + for msg_id, entry in cache.items(): + # Skip control messages (sender key distribution) + if entry.get("_control"): + continue + # Skip entries with no useful content + text = entry.get("text", "") + if not text and not entry.get("image") and not entry.get("file"): + continue + + # Rebuild plaintext from cached payload + payload = {k: v for k, v in entry.items() + if k not in ("message_id", "created_at", "read_by", "sender_id", "deleted")} + plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + # Re-encrypt with self-encryption key + _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key) + all_updates.append({ + "message_id": msg_id, + "encrypted_content": encode_binary(ct + tag), + "nonce": encode_binary(nonce), + }) + + if self._reencrypt_progress_cb: + self._reencrypt_progress_cb(f"Re-encrypting history: {i + 1}/{total_files} conversations...") + + if not all_updates: + self._logger.info("No messages to re-encrypt.") + return + + # Send in batches of 500 + batch_size = 500 + total = len(all_updates) + for start in range(0, total, batch_size): + batch = all_updates[start:start + batch_size] + resp = await self.send_and_recv("reencrypt_messages", updates=batch) + if resp["status"] != "ok": + self._logger.warning("Re-encrypt batch failed: %s", resp.get("data", {}).get("message", "")) + else: + self._logger.info("Re-encrypted %d/%d messages.", min(start + batch_size, total), total) + + if self._reencrypt_progress_cb: + self._reencrypt_progress_cb(f"Re-encryption complete: {total} messages uploaded.") + + # ------------------------------------------------------------------ + # User Profiles + # ------------------------------------------------------------------ + + async def get_profile(self, user_id: str | None = None) -> dict | None: + """Get user profile. If user_id is None, returns own profile.""" + kwargs = {} + if user_id: + kwargs["user_id"] = user_id + resp = await self.send_and_recv("get_profile", **kwargs) + if resp["status"] == "ok": + return resp["data"] + return None + + async def update_profile(self, **fields) -> tuple[bool, str]: + """Update own profile (phone, location, *_visible).""" + resp = await self.send_and_recv("update_profile", **fields) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def update_avatar(self, image_data: bytes) -> tuple[bool, str]: + """Upload avatar image.""" + resp = await self.send_and_recv("update_avatar", data=encode_binary(image_data)) + if resp["status"] == "ok": + return True, resp["data"].get("avatar_file", "") + return False, resp["data"]["message"] + + async def get_avatar(self, user_id: str) -> bytes | None: + """Download avatar for a user.""" + resp = await self.send_and_recv("get_avatar", user_id=user_id) + if resp["status"] == "ok": + return decode_binary(resp["data"]["data"]) + return None + + async def update_group_avatar(self, conv_id: str, image_data: bytes) -> tuple[bool, str]: + """Upload avatar for a group conversation.""" + resp = await self.send_and_recv("update_group_avatar", + conversation_id=conv_id, data=encode_binary(image_data)) + if resp["status"] == "ok": + return True, resp["data"].get("avatar_file", "") + return False, resp["data"]["message"] + + async def get_group_avatar(self, conv_id: str) -> bytes | None: + """Download avatar for a group conversation.""" + resp = await self.send_and_recv("get_group_avatar", conversation_id=conv_id) + if resp["status"] == "ok": + return decode_binary(resp["data"]["data"]) + return None + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + async def close(self): + self.connected = False + if self._listener_task: + self._listener_task.cancel() + if self.raw_writer: + self.raw_writer.close() + + async def reconnect(self): + """Close existing connection and re-establish: connect + re-login using in-memory keys.""" + try: + await self.close() + except Exception: + pass + # Reset reader/writer but keep keys and sessions + self.reader = None + self.writer = None + self.raw_writer = None + self._listener_task = None + self._pending.clear() + self.login_rejected = False + # Drain queues + while not self._response_queue.empty(): + try: + self._response_queue.get_nowait() + except Exception: + break + while not self._notification_queue.empty(): + try: + self._notification_queue.get_nowait() + except Exception: + break + await self.connect() + self._listener_task = asyncio.create_task(self._background_listener()) + if self.email and self.private_key: + # RSA challenge-response login (keys already in memory) + start = await self.send_and_recv("login_start", email=self.email) + if start["status"] == "ok": + challenge = decode_binary(start["data"]["challenge"]) + signature = rsa_sign(self.private_key, challenge) + login_kwargs = { + "email": self.email, + "signature": encode_binary(signature), + "client_version": VERSION, + } + if self.device_id: + login_kwargs["device_id"] = self.device_id + finish = await self.send_and_recv("login_finish", **login_kwargs) + if finish["status"] == "ok": + self.session = finish["data"] + asyncio.create_task(self._ensure_prekeys()) + else: + # Login rejected — keys were likely rotated on another device + self.session = None + self.connected = False + self.login_rejected = True diff --git a/client.py b/client.py new file mode 100644 index 0000000..23ba082 --- /dev/null +++ b/client.py @@ -0,0 +1,899 @@ +"""Interactive CLI client for encrypted chat (X3DH + Double Ratchet).""" + +import asyncio +import getpass +import logging +import os +import re + +from chat_core import ChatClient, IdentityKeyChanged + + +def setup_logging(): + level_name = os.getenv("LOG_LEVEL", "WARNING").upper() + level = getattr(logging, level_name, logging.WARNING) + logging.basicConfig(level=level, format="%(levelname)s: %(message)s") + + +async def prompt(text: str) -> str: + """Non-blocking terminal input.""" + return await asyncio.get_event_loop().run_in_executor(None, lambda: input(text).strip()) + + +async def prompt_password(text: str = "Password: ") -> str: + """Non-blocking hidden password input (M3 fix).""" + return await asyncio.get_event_loop().run_in_executor(None, lambda: getpass.getpass(text)) + + +# M3 fix: strip terminal control/escape sequences from untrusted text +_CONTROL_RE = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]|\x1b\[[0-9;]*[A-Za-z]") + + +def _sanitize_text(s) -> str: + """Remove control characters and ANSI escape sequences.""" + if not isinstance(s, str): + s = str(s) if s is not None else "" + return _CONTROL_RE.sub("", s) + + +def _safe_filename(name: str) -> str: + """Sanitize remote filename: basename only, no path traversal, no NUL.""" + name = os.path.basename(name) + name = name.replace("\x00", "") + if not name or name.startswith("."): + name = "download" + return name + + +def _human_size(n: int) -> str: + if n >= 1024 * 1024: + return f"{n / (1024*1024):.1f} MB" + if n >= 1024: + return f"{n / 1024:.0f} KB" + return f"{n} B" + + +async def _select_conversation(client: ChatClient, label: str = "Select conversation") -> tuple[dict | None, list[dict]]: + """List conversations and let user pick one. Returns (conv, convs) or (None, []).""" + convs = await client.list_conversations() + if not convs: + print("[*] No conversations.") + return None, [] + + def conv_label(c): + if c.get("name"): + return _sanitize_text(c["name"]) + others = [_sanitize_text(m.get("username") or m.get("email") or "?") for m in c["members"] if m.get("email") != client.email] + return ", ".join(others) if others else _sanitize_text(client.username) + + print() + for i, c in enumerate(convs): + print(f" {i+1}) {conv_label(c)}") + choice = await prompt(f"{label}: ") + try: + idx = int(choice) - 1 + if not (0 <= idx < len(convs)): + print("[!] Invalid selection.") + return None, convs + except ValueError: + print("[!] Invalid selection.") + return None, convs + return convs[idx], convs + + +async def interactive_menu(client: ChatClient): + """Interactive terminal menu.""" + while True: + print("\n--- Encrypted Chat ---") + print("1) Send direct message") + print("2) Send to conversation") + print("3) Read messages") + print("4) Create group conversation") + print("5) Add member to group") + print("6) Send image") + print("7) Send file") + print("8) Invitations") + print("9) Leave group") + print("10) Rename group") + print("11) Delete conversation") + print("12) Search messages") + print("13) My profile") + print("14) View user profile") + print("15) Manage devices") + print("16) React to message") + print("17) Pin/Unpin message") + print("18) View pinned messages") + print("19) Forward message") + print("20) Verify contact") + print("21) Show my fingerprint") + print("22) Change password") + print("23) Change username") + print("q) Quit") + + choice = await prompt("> ") + + if choice == "1": + email = await prompt("To (email): ") + if not email: + continue + text = await prompt("Message: ") + if not text: + continue + conv_id, msg = await client.find_or_create_conversation(email) + if not conv_id: + print(f"[!] {msg}") + continue + convs = await client.list_conversations() + members = [] + for c in convs: + if c["conversation_id"] == conv_id: + members = c["members"] + break + try: + ok, result = await client.send_message(conv_id, text, members) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {'Message sent.' if ok else result}") + + elif choice == "2": + conv, _ = await _select_conversation(client) + if not conv: + continue + text = await prompt("Message: ") + if not text: + continue + try: + ok, result = await client.send_message(conv["conversation_id"], text, conv["members"]) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {'Message sent.' if ok else result}") + + elif choice == "3": + conv, _ = await _select_conversation(client) + if not conv: + continue + messages = await client.get_messages(conv["conversation_id"]) + if not messages: + print("[*] No messages.") + continue + _print_messages(messages, client, conv) + + action = await prompt("\nAction (r=reply, d=delete, dl=download file, empty=back): ") + if not action: + continue + if action.lower().startswith("dl"): + await _download_file_action(client, messages) + continue + if action.lower().startswith("d"): + await _delete_message_action(client, messages) + continue + if action.lower().startswith("r"): + reply_choice = await prompt("Reply to message #: ") + else: + reply_choice = action + try: + reply_idx = int(reply_choice) - 1 + if not (0 <= reply_idx < len(messages)): + print("[!] Invalid message number.") + continue + except ValueError: + print("[!] Invalid number.") + continue + reply_to_id = messages[reply_idx]["message_id"] + text = await prompt("Message: ") + if not text: + continue + try: + ok, result = await client.send_message(conv["conversation_id"], text, conv["members"], reply_to=reply_to_id) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {'Message sent.' if ok else result}") + + elif choice == "4": + name = await prompt("Group name (empty for none): ") + members_input = await prompt("Member emails (comma-separated): ") + members = [m.strip() for m in members_input.split(",") if m.strip()] + if not members: + continue + conv_id, msg = await client.create_conversation(members, name=name.strip() or None) + if conv_id: + print(f"[+] Group created with: {', '.join(members)}") + else: + print(f"[!] {msg}") + + elif choice == "5": + conv, _ = await _select_conversation(client) + if not conv: + continue + email = await prompt("Email to add: ") + ok, msg = await client.add_member(conv["conversation_id"], email) + print(f"[{'+'if ok else '!'}] {msg or 'Invitation sent.'}") + + elif choice == "6": + conv, _ = await _select_conversation(client) + if not conv: + continue + image_path = await prompt("Image path: ") + if not image_path: + continue + try: + ok, msg = await client.send_image(conv["conversation_id"], image_path, conv["members"]) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "7": + conv, _ = await _select_conversation(client) + if not conv: + continue + file_path = await prompt("File path: ") + if not file_path: + continue + if not os.path.isfile(file_path): + print("[!] File not found.") + continue + try: + ok, msg = await client.send_file(conv["conversation_id"], file_path, conv["members"]) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "8": + await _invitations_menu(client) + + elif choice == "9": + conv, _ = await _select_conversation(client, "Select group to leave") + if not conv: + continue + confirm = await prompt(f"Leave '{conv.get('name', 'this conversation')}'? (y/n): ") + if confirm.lower() != "y": + continue + ok, msg = await client.leave_group(conv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "10": + conv, _ = await _select_conversation(client, "Select group to rename") + if not conv: + continue + name = await prompt("New name: ") + if not name: + continue + ok, msg = await client.rename_conversation(conv["conversation_id"], name.strip()) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "11": + conv, _ = await _select_conversation(client, "Select conversation to delete") + if not conv: + continue + confirm = await prompt("Delete this conversation? This cannot be undone. (y/n): ") + if confirm.lower() != "y": + continue + ok, msg = await client.delete_conversation(conv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "12": + conv, _ = await _select_conversation(client, "Select conversation to search") + if not conv: + continue + query = await prompt("Search query: ") + if not query: + continue + # First ensure we have messages cached by fetching them + await client.get_messages(conv["conversation_id"]) + results = client.search_messages(conv["conversation_id"], query) + if not results: + print("[*] No matches found.") + continue + print(f"\n[*] {len(results)} match(es):") + for r in results: + sender = _sanitize_text(r.get("sender", "???")) + text = _sanitize_text(r.get("text", "")) + ts = r.get("created_at", "")[:16] + # Highlight match in text + idx = text.lower().find(query.lower()) + if idx >= 0: + text = text[:idx] + "\033[33m" + text[idx:idx+len(query)] + "\033[0m" + text[idx+len(query):] + print(f" [{ts}] {sender}: {text}") + + elif choice == "13": + await _my_profile_menu(client) + + elif choice == "14": + email = await prompt("User email: ") + if not email: + continue + # Need to find user_id from email — try via conversation members + user_id = None + convs = await client.list_conversations() + for c in convs: + for m in c.get("members", []): + if m.get("email") == email: + user_id = m.get("user_id") or m.get("id") + break + if user_id: + break + if not user_id: + print("[!] User not found in your conversations.") + continue + profile = await client.get_profile(user_id) + if not profile: + print("[!] Could not load profile.") + continue + _print_profile(profile) + + elif choice == "15": + await _devices_menu(client) + + elif choice == "16": + conv, _ = await _select_conversation(client) + if not conv: + continue + messages = await client.get_messages(conv["conversation_id"]) + if not messages: + print("[*] No messages.") + continue + _print_messages(messages, client, conv) + msg_choice = await prompt("React to message #: ") + try: + msg_idx = int(msg_choice) - 1 + if not (0 <= msg_idx < len(messages)): + print("[!] Invalid message number.") + continue + except ValueError: + print("[!] Invalid number.") + continue + print("Reactions: thumbsup, heart, laugh, surprised, sad, thumbsdown") + reaction = await prompt("Reaction: ").strip().lower() + if reaction not in ("thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"): + print("[!] Invalid reaction.") + continue + ok, msg = await client.react_message(messages[msg_idx]["message_id"], reaction, "add") + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "17": + conv, _ = await _select_conversation(client) + if not conv: + continue + messages = await client.get_messages(conv["conversation_id"]) + if not messages: + print("[*] No messages.") + continue + _print_messages(messages, client, conv) + msg_choice = await prompt("Pin/Unpin message #: ") + try: + msg_idx = int(msg_choice) - 1 + if not (0 <= msg_idx < len(messages)): + print("[!] Invalid message number.") + continue + except ValueError: + print("[!] Invalid number.") + continue + m = messages[msg_idx] + action = "unpin" if m.get("pinned_at") else "pin" + ok, msg = await client.pin_message(m["message_id"], conv["conversation_id"], action) + print(f"[{'+'if ok else '!'}] {action.capitalize()}: {msg}") + + elif choice == "18": + conv, _ = await _select_conversation(client) + if not conv: + continue + pinned = await client.get_pinned_messages(conv["conversation_id"]) + if not pinned: + print("[*] No pinned messages.") + continue + print(f"\n[*] {len(pinned)} pinned message(s):") + for p in pinned: + print(f" {p.get('message_id', '?')[:8]}... pinned at {p.get('pinned_at', '?')}") + + elif choice == "19": + conv, _ = await _select_conversation(client, "Select source conversation") + if not conv: + continue + messages = await client.get_messages(conv["conversation_id"]) + if not messages: + print("[*] No messages.") + continue + _print_messages(messages, client, conv) + msg_choice = await prompt("Forward message #: ") + try: + msg_idx = int(msg_choice) - 1 + if not (0 <= msg_idx < len(messages)): + print("[!] Invalid message number.") + continue + except ValueError: + print("[!] Invalid number.") + continue + target_conv, _ = await _select_conversation(client, "Select target conversation") + if not target_conv: + continue + fwd_msg = messages[msg_idx] + fwd_msg["conversation_id"] = conv["conversation_id"] + try: + ok, result = await client.forward_message( + target_conv["conversation_id"], fwd_msg, target_conv["members"] + ) + except IdentityKeyChanged as ikc: + print(f"[!] Identity key changed for {ikc.user_id[:8]}. Accept the new key before sending.") + continue + print(f"[{'+'if ok else '!'}] {'Forwarded.' if ok else result}") + + elif choice == "20": + # Verify contact — show safety number for a DM conversation + conv, _ = await _select_conversation(client, "Select DM to verify") + if not conv: + continue + # Find peer user_id + peer_uid = "" + peer_name = "" + for m in conv.get("members", []): + if m.get("email") != client.email: + peer_uid = m.get("user_id") or m.get("id") or "" + peer_name = _sanitize_text(m.get("username") or m.get("email") or "?") + break + if not peer_uid: + print("[!] Could not identify peer user.") + continue + # Ensure we have their identity key in cache + info = await client._get_user_info(user_id=peer_uid) + if not info or not info.get("identity_key_bytes"): + print("[!] Could not retrieve identity key for this user.") + continue + status = client.get_verification_status(peer_uid) + print(f"\n--- Verification: {peer_name} ---") + print(f"Status: {status.upper()}") + safety = client.get_safety_number(peer_uid) + if safety: + print(f"\nSafety Number:\n{safety}") + fp = client.get_peer_fingerprint(peer_uid) + if fp: + print(f"\nTheir Fingerprint:\n{fp}") + my_fp = client.get_my_fingerprint() + if my_fp: + print(f"\nYour Fingerprint:\n{my_fp}") + if status != "verified": + action = await prompt("\nMark as verified? (y/n): ") + if action.lower() == "y": + client.verify_contact(peer_uid, info["identity_key_bytes"], + method="safety_number") + print("[+] Contact marked as verified.") + else: + action = await prompt("\nRemove verification? (y/n): ") + if action.lower() == "y": + client.unverify_contact(peer_uid) + print("[+] Verification removed.") + + elif choice == "21": + # Show own fingerprint + fp = client.get_my_fingerprint() + if fp: + print(f"\n--- Your Fingerprint ---\n{fp}") + else: + print("[!] Not logged in or identity key not available.") + + elif choice == "22": + # Change password + old_pw = getpass.getpass("Current password: ") + new_pw = getpass.getpass("New password: ") + confirm_pw = getpass.getpass("Confirm new password: ") + if new_pw != confirm_pw: + print("[!] Passwords do not match.") + elif not new_pw: + print("[!] Password cannot be empty.") + else: + ok, msg = client.change_password(old_pw, new_pw) + if ok: + print(f"[+] {msg}") + else: + print(f"[!] {msg}") + + elif choice == "23": + new_un = await prompt("New username: ") + new_un = new_un.strip() if new_un else "" + if not new_un: + print("[!] Username cannot be empty.") + else: + ok, msg = await client.change_username(new_un) + if ok: + print(f"[+] {msg}") + else: + print(f"[!] {msg}") + + elif choice in ("q", "Q", "quit", "exit"): + print("[*] Bye.") + break + + +def _print_messages(messages, client, conv): + """Print messages to terminal.""" + print() + for i, m in enumerate(messages): + if m.get("deleted"): + print(f" #{i+1} [Message deleted]") + continue + reply_info = "" + if m.get("reply_to"): + for j, orig in enumerate(messages): + if orig["message_id"] == m["reply_to"]: + reply_info = f" (reply to #{j+1})" + break + else: + reply_info = " (reply to older message)" + image_info = "" + if m.get("image"): + img = m["image"] + image_info = f" [Image: {_sanitize_text(img.get('filename', '?'))} ({_human_size(img.get('size', 0))})]" + file_info = "" + if m.get("file"): + fi = m["file"] + file_info = f" [File: {_sanitize_text(fi.get('filename', '?'))} ({_human_size(fi.get('size', 0))})]" + read_info = "" + if m.get("sender") == client.username: + read_by = m.get("read_by", []) + delivered_to = m.get("delivered_to", []) + member_map = {} + for mem in conv.get("members", []): + uid = mem.get("user_id") or mem.get("id", "") + if uid: + member_map[uid] = _sanitize_text(mem.get("username") or mem.get("email") or "?") + my_uid = client.session.get("user_id", "") if client.session else "" + others_read = [r for r in read_by if r.get("user_id") != my_uid] + others_delivered = [d for d in delivered_to if d.get("user_id") != my_uid] + if others_read: + names = ", ".join(member_map.get(r["user_id"], r["user_id"][:8]) for r in others_read) + read_info = f" [\u2713\u2713 Read by {names}]" + elif others_delivered: + read_info = " [\u2713\u2713 Delivered]" + else: + read_info = " [\u2713 Sent]" + pin_info = "" + if m.get("pinned_at"): + pin_info = " \U0001f4cc" + reaction_info = "" + reactions = m.get("reactions", []) + if reactions: + grouped = {} + for r in reactions: + grouped.setdefault(r["reaction"], 0) + grouped[r["reaction"]] += 1 + _REMOJI = {"thumbsup": "\U0001f44d", "heart": "\u2764\ufe0f", "laugh": "\U0001f602", + "surprised": "\U0001f62e", "sad": "\U0001f622", "thumbsdown": "\U0001f44e"} + parts = [f"{_REMOJI.get(k, k)}{v}" for k, v in grouped.items()] + reaction_info = " [" + " ".join(parts) + "]" + fwd_info = "" + if m.get("forwarded_from"): + fwd_sender = _sanitize_text(m["forwarded_from"].get("sender", "?")) + fwd_info = f" (fwd from {fwd_sender})" + text = _sanitize_text(m.get("text", "")) + sender = _sanitize_text(m.get("sender", "?")) + print(f" #{i+1} {sender}: {text}{image_info}{file_info}{reply_info}{read_info}{pin_info}{reaction_info}{fwd_info}") + + +async def _delete_message_action(client, messages): + del_choice = await prompt("Delete message #: ") + try: + del_idx = int(del_choice) - 1 + if not (0 <= del_idx < len(messages)): + print("[!] Invalid message number.") + return + except ValueError: + print("[!] Invalid number.") + return + ok, msg = await client.delete_message(messages[del_idx]["message_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + +async def _download_file_action(client, messages): + dl_choice = await prompt("Download from message #: ") + try: + dl_idx = int(dl_choice) - 1 + if not (0 <= dl_idx < len(messages)): + print("[!] Invalid message number.") + return + except ValueError: + print("[!] Invalid number.") + return + m = messages[dl_idx] + file_info = m.get("file") or m.get("image") + if not file_info: + print("[!] No file/image in this message.") + return + filename = _safe_filename(file_info.get("filename", "download")) + save_path = await prompt(f"Save as [{filename}]: ") + if not save_path: + save_path = filename + data = await client.download_file(file_info["file_id"], file_info) + if data: + with open(save_path, "wb") as f: + f.write(data) + print(f"[+] Saved to {save_path} ({_human_size(len(data))})") + else: + print("[!] Download failed.") + + +async def _invitations_menu(client): + invitations = await client.list_invitations() + if not invitations: + print("[*] No pending invitations.") + return + print("\nPending invitations:") + for i, inv in enumerate(invitations): + inv_name = _sanitize_text(inv.get("conversation_name") or inv.get("conversation_id", "")[:8]) + invited_by = _sanitize_text(inv.get("invited_by_username") or inv.get("invited_by", "")[:8]) + print(f" {i+1}) {inv_name} (invited by {invited_by})") + choice = await prompt("Select invitation (or empty to go back): ") + if not choice: + return + try: + idx = int(choice) - 1 + if not (0 <= idx < len(invitations)): + print("[!] Invalid selection.") + return + except ValueError: + print("[!] Invalid selection.") + return + inv = invitations[idx] + action = await prompt("(a)ccept or (d)ecline? ") + if action.lower().startswith("a"): + ok, msg = await client.accept_invitation(inv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + elif action.lower().startswith("d"): + ok, msg = await client.decline_invitation(inv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + +def _print_profile(profile): + print(f"\n Username: {_sanitize_text(profile.get('username', '?'))}") + print(f" Email: {_sanitize_text(profile.get('email', '?'))}") + phone = profile.get("phone") + if phone: + print(f" Phone: {_sanitize_text(phone)}") + location = profile.get("location") + if location: + print(f" Location: {_sanitize_text(location)}") + has_avatar = profile.get("avatar_file") + print(f" Avatar: {'Yes' if has_avatar else 'No'}") + + +async def _my_profile_menu(client): + profile = await client.get_profile() + if not profile: + print("[!] Could not load profile.") + return + print("\n--- My Profile ---") + _print_profile(profile) + print(f" Phone visible: {profile.get('phone_visible', False)}") + print(f" Email visible: {profile.get('email_visible', False)}") + print(f" Location visible: {profile.get('location_visible', False)}") + + action = await prompt("\n(e)dit, (a)vatar upload, or empty to go back: ") + if not action: + return + if action.lower().startswith("e"): + print("[*] Leave fields empty to keep current value.") + phone = await prompt(f"Phone [{profile.get('phone', '')}]: ") + location = await prompt(f"Location [{profile.get('location', '')}]: ") + phone_vis = await prompt(f"Phone visible [{profile.get('phone_visible', False)}] (y/n): ") + email_vis = await prompt(f"Email visible [{profile.get('email_visible', False)}] (y/n): ") + loc_vis = await prompt(f"Location visible [{profile.get('location_visible', False)}] (y/n): ") + + fields = {} + if phone: + fields["phone"] = phone + if location: + fields["location"] = location + if phone_vis.lower() in ("y", "n"): + fields["phone_visible"] = phone_vis.lower() == "y" + if email_vis.lower() in ("y", "n"): + fields["email_visible"] = email_vis.lower() == "y" + if loc_vis.lower() in ("y", "n"): + fields["location_visible"] = loc_vis.lower() == "y" + if fields: + ok, msg = await client.update_profile(**fields) + print(f"[{'+'if ok else '!'}] {msg}") + else: + print("[*] No changes.") + elif action.lower().startswith("a"): + path = await prompt("Avatar image path: ") + if not path or not os.path.isfile(path): + print("[!] File not found.") + return + data = open(path, "rb").read() + ok, msg = await client.update_avatar(data) + print(f"[{'+'if ok else '!'}] {msg}") + + +async def _devices_menu(client): + resp = await client.send_and_recv("list_devices") + if resp.get("status") != "ok": + print(f"[!] {resp.get('data', {}).get('message', 'Failed')}") + return + devices = resp["data"].get("devices", []) + if not devices: + print("[*] No devices found.") + return + current_device_id = client.device_id + print("\nYour devices:") + for i, d in enumerate(devices): + name = _sanitize_text(d.get("device_name") or "Unnamed") + did = d.get("device_id", "?") + last_seen = _sanitize_text(d.get("last_seen_at", "?")) + current = " (this device)" if did == current_device_id else "" + print(f" {i+1}) {name} — {did[:8]}... — last seen: {last_seen}{current}") + action = await prompt("\n(r)emove a device, or empty to go back: ") + if not action or not action.lower().startswith("r"): + return + choice = await prompt("Remove device #: ") + try: + idx = int(choice) - 1 + if not (0 <= idx < len(devices)): + print("[!] Invalid selection.") + return + except ValueError: + print("[!] Invalid selection.") + return + d = devices[idx] + if d.get("device_id") == current_device_id: + print("[!] Cannot remove current device.") + return + resp = await client.send_and_recv("remove_device", device_id=d["device_id"]) + if resp.get("status") == "ok": + print("[+] Device removed.") + else: + print(f"[!] {resp.get('data', {}).get('message', 'Failed')}") + + +async def notification_printer(client: ChatClient): + """Print real-time notifications with sender name.""" + while True: + notif = await client._notification_queue.get() + notif_type = notif.get("type", "") + data = notif.get("data", {}) + if notif_type == "messages_read": + continue # Silent - read receipts shown when reading messages + if notif_type == "session_reset": + from_uid = data.get("from_user_id", "")[:8] + client.handle_session_reset_notification( + data.get("from_user_id", ""), + data.get("from_device_id") or None, + ) + print(f"\n[*] Session with {from_uid}... was reset. New session will be created on next message.") + continue + if notif_type == "group_invitation": + inv_name = _sanitize_text(data.get("conversation_name", "?")) + invited_by = _sanitize_text(data.get("invited_by_username", "?")) + print(f"\n[*] New invitation to '{inv_name}' from {invited_by}. Use option 8 to accept/decline.") + continue + if notif_type in ("conversation_created", "member_added", "member_removed", "conversation_renamed"): + print(f"\n[*] Conversation updated ({notif_type}).") + continue + if notif_type == "message_reacted": + username = _sanitize_text(data.get("username", data.get("user_id", "?")[:8])) + reaction = _sanitize_text(data.get("reaction", "?")) + action = data.get("action", "add") + print(f"\n[*] {username} {'added' if action == 'add' else 'removed'} reaction '{reaction}'") + continue + if notif_type in ("message_pinned", "message_unpinned"): + username = _sanitize_text(data.get("username", data.get("user_id", "?")[:8])) + act = "pinned" if notif_type == "message_pinned" else "unpinned" + print(f"\n[*] {username} {act} a message") + continue + if notif_type in ("user_online", "user_offline", "online_users"): + continue # Silent for CLI + payload = client.decrypt_notification(data) + if payload: + print(f"\n[*] New message from {_sanitize_text(payload['sender'])} in conversation {data.get('conversation_id', '?')[:8]}...") + # None = control message (sender key distribution), skip silently + + +async def main(): + setup_logging() + client = ChatClient() + await client.connect() + + client._listener_task = asyncio.create_task(client._background_listener()) + notif_task = asyncio.create_task(notification_printer(client)) + + print("=== Encrypted Chat Client ===") + print("1) Register") + print("2) Login") + print("3) Link new device (this device)") + print("4) Authorize new device (from this device)") + print("5) Rotate keys (revoke other devices)") + choice = await prompt("> ") + + if choice == "1": + username = await prompt("Username (display): ") + email = await prompt("Email: ") + password = await prompt_password("Password (for private key): ") + if not email or not password: + print("[!] Email and password required.") + await client.close() + return + ok, code_or_msg = await client.register(username, password, email=email) + if not ok: + print(f"[!] {code_or_msg}") + await client.close() + return + print(f"[*] Registration code: {code_or_msg}") + code = await prompt("Enter code: ") + ok2, msg2 = await client.confirm_registration(email, username, code) + print(f"[{'+'if ok2 else '!'}] {msg2}") + if ok2: + ok3, msg3 = await client.login(email, password) + print(f"[{'+'if ok3 else '!'}] {msg3}") + elif choice == "2": + email = await prompt("Email: ") + password = await prompt_password("Password (for private key): ") + ok, msg = await client.login(email, password) + print(f"[{'+'if ok else '!'}] {msg}") + elif choice == "3": + email = await prompt("Email: ") + password = await prompt_password("Password (for private key): ") + if not password: + print("[!] Password required.") + await client.close() + return + ok, code_or_msg = await client.pairing_start(email) + if not ok: + print(f"[!] {code_or_msg}") + await client.close() + return + code = code_or_msg + print(f"[*] Pairing code: {code}") + print("[*] Approve this code on an already-logged-in device.") + ok2, msg2 = await client.pairing_wait(code, email, password) + if not ok2: + print(f"[!] {msg2}") + await client.close() + return + print(f"[+] {msg2}") + ok3, msg3 = await client.login(email, password) + print(f"[{'+'if ok3 else '!'}] {msg3}") + elif choice == "4": + email = await prompt("Email: ") + password = await prompt_password("Password (for private key): ") + ok, msg = await client.login(email, password) + print(f"[{'+'if ok else '!'}] {msg}") + if not ok: + await client.close() + return + code = await prompt("Pairing code: ") + ok2, msg2 = await client.authorize_device(code) + print(f"[{'+'if ok2 else '!'}] {msg2}") + elif choice == "5": + email = await prompt("Email: ") + password = await prompt_password("Password (for private key): ") + ok, msg = await client.login(email, password) + print(f"[{'+'if ok else '!'}] {msg}") + if not ok: + await client.close() + return + confirm = await prompt("This will revoke other devices. Type 'YES' to continue: ") + if confirm != "YES": + print("[*] Cancelled.") + await client.close() + return + ok2, msg2 = await client.rotate_keys(client.username, password) + print(f"[{'+'if ok2 else '!'}] {msg2}") + else: + print("[!] Invalid choice.") + await client.close() + return + + if client.session: + await interactive_menu(client) + + notif_task.cancel() + await client.close() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\n[*] Bye.") diff --git a/crypto_utils.py b/crypto_utils.py new file mode 100644 index 0000000..1f6795d --- /dev/null +++ b/crypto_utils.py @@ -0,0 +1,935 @@ +"""Cryptographic utilities: Ed25519, X25519, AES-256-GCM, Double Ratchet, Sender Keys. + +RSA functions retained for login challenge-response only. +""" + +import hashlib +import hmac +import json +import os +import struct +import uuid +from dataclasses import dataclass, field + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + +# --------------------------------------------------------------------------- +# Password-based key encryption (M3: PBKDF2 600k iterations + AES-256-GCM) +# --------------------------------------------------------------------------- + +PBKDF2_ITERATIONS = 600_000 +_ECP1_MAGIC = b"ECP1" # Encrypted Chat PBKDF v1 format marker + + +def _encrypt_private_key(raw_bytes: bytes, password: bytes) -> bytes: + """Encrypt raw key bytes with PBKDF2-HMAC-SHA256 (600k iterations) + AES-256-GCM. + + Output format: MAGIC(4) + salt(16) + nonce(12) + ciphertext_with_tag(N+16) + """ + salt = os.urandom(16) + kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, + salt=salt, iterations=PBKDF2_ITERATIONS) + derived = kdf.derive(password) + nonce = os.urandom(12) + aesgcm = AESGCM(derived) + ct = aesgcm.encrypt(nonce, raw_bytes, _ECP1_MAGIC) # AAD = magic bytes + return _ECP1_MAGIC + salt + nonce + ct + + +def _decrypt_private_key(data: bytes, password: bytes) -> bytes: + """Decrypt key bytes encrypted with _encrypt_private_key.""" + if not data.startswith(_ECP1_MAGIC): + raise ValueError("Not ECP1 format") + salt = data[4:20] + nonce = data[20:32] + ct = data[32:] + kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, + salt=salt, iterations=PBKDF2_ITERATIONS) + derived = kdf.derive(password) + aesgcm = AESGCM(derived) + return aesgcm.decrypt(nonce, ct, _ECP1_MAGIC) + + +# --------------------------------------------------------------------------- +# RSA (login challenge-response ONLY) +# --------------------------------------------------------------------------- + +def generate_rsa_keypair(key_size: int = 4096) -> tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: + private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) + return private_key, private_key.public_key() + + +def serialize_private_key(key: rsa.RSAPrivateKey, password: bytes | None = None) -> bytes: + if password: + raw = key.private_bytes(serialization.Encoding.DER, serialization.PrivateFormat.PKCS8, + serialization.NoEncryption()) + return _encrypt_private_key(raw, password) + return key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, + serialization.NoEncryption()) + + +def serialize_public_key(key: rsa.RSAPublicKey) -> bytes: + return key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo) + + +def load_private_key(data: bytes, password: bytes | None = None) -> rsa.RSAPrivateKey: + if data.startswith(_ECP1_MAGIC): + raw = _decrypt_private_key(data, password) + return serialization.load_der_private_key(raw, password=None) + # Legacy PEM format (old BestAvailableEncryption or unencrypted) + return serialization.load_pem_private_key(data, password=password) + + +def load_public_key(pem: bytes) -> rsa.RSAPublicKey: + return serialization.load_pem_public_key(pem) + + +def rsa_sign(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: + return private_key.sign( + data, + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), + hashes.SHA256(), + ) + + +def rsa_verify(public_key: rsa.RSAPublicKey, signature: bytes, data: bytes) -> bool: + try: + public_key.verify( + signature, data, + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.AUTO), + hashes.SHA256(), + ) + return True + except Exception: + return False + + +# --------------------------------------------------------------------------- +# AES-256-GCM (symmetric encryption — used by ratchet message keys & images) +# --------------------------------------------------------------------------- + +def aes_encrypt(plaintext: bytes, key: bytes | None = None) -> tuple[bytes, bytes, bytes, bytes]: + """Encrypt with AES-256-GCM. Returns (key, nonce, ciphertext, tag).""" + if key is None: + key = AESGCM.generate_key(bit_length=256) + nonce = os.urandom(12) + aesgcm = AESGCM(key) + ct_with_tag = aesgcm.encrypt(nonce, plaintext, None) + ciphertext = ct_with_tag[:-16] + tag = ct_with_tag[-16:] + return key, nonce, ciphertext, tag + + +def aes_decrypt(key: bytes, nonce: bytes, ciphertext: bytes, tag: bytes) -> bytes: + """Decrypt with AES-256-GCM.""" + aesgcm = AESGCM(key) + return aesgcm.decrypt(nonce, ciphertext + tag, None) + + +# --------------------------------------------------------------------------- +# Ed25519 Identity Keys +# --------------------------------------------------------------------------- + +def generate_identity_keypair() -> tuple[Ed25519PrivateKey, Ed25519PublicKey]: + priv = Ed25519PrivateKey.generate() + return priv, priv.public_key() + + +def serialize_ed25519_private(key: Ed25519PrivateKey, password: bytes | None = None) -> bytes: + if password: + raw = serialize_ed25519_private_raw(key) # 32 bytes + return _encrypt_private_key(raw, password) + return serialize_ed25519_private_raw(key) # 32 bytes, no password + + +def serialize_ed25519_private_raw(key: Ed25519PrivateKey) -> bytes: + """Serialize Ed25519 private key to 32 raw bytes (unencrypted).""" + return key.private_bytes(serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()) + + +def serialize_ed25519_public(key: Ed25519PublicKey) -> bytes: + """Serialize Ed25519 public key to 32 raw bytes.""" + return key.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + + +def load_ed25519_private(data: bytes, password: bytes | None = None) -> Ed25519PrivateKey: + if data.startswith(_ECP1_MAGIC): + raw = _decrypt_private_key(data, password) + return Ed25519PrivateKey.from_private_bytes(raw) + # Legacy formats: PEM (old BestAvailableEncryption) or 32-byte raw + if password: + return serialization.load_pem_private_key(data, password=password) + if len(data) == 32: + return Ed25519PrivateKey.from_private_bytes(data) + return serialization.load_pem_private_key(data, password=None) + + +def load_ed25519_public(data: bytes) -> Ed25519PublicKey: + if len(data) == 32: + return Ed25519PublicKey.from_public_bytes(data) + return serialization.load_pem_public_key(data) + + +def ed25519_sign(private_key: Ed25519PrivateKey, data: bytes) -> bytes: + """Sign data with Ed25519. Returns 64-byte signature.""" + return private_key.sign(data) + + +def ed25519_verify(public_key: Ed25519PublicKey, signature: bytes, data: bytes) -> bool: + """Verify Ed25519 signature.""" + try: + public_key.verify(signature, data) + return True + except Exception: + return False + + +# --------------------------------------------------------------------------- +# X25519 Key Exchange +# --------------------------------------------------------------------------- + +def generate_x25519_keypair() -> tuple[X25519PrivateKey, X25519PublicKey]: + priv = X25519PrivateKey.generate() + return priv, priv.public_key() + + +def serialize_x25519_private(key: X25519PrivateKey) -> bytes: + """Serialize X25519 private key to 32 raw bytes.""" + return key.private_bytes(serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()) + + +def serialize_x25519_public(key: X25519PublicKey) -> bytes: + """Serialize X25519 public key to 32 raw bytes.""" + return key.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + + +def load_x25519_private(data: bytes) -> X25519PrivateKey: + return X25519PrivateKey.from_private_bytes(data) + + +def load_x25519_public(data: bytes) -> X25519PublicKey: + return X25519PublicKey.from_public_bytes(data) + + +def x25519_dh(private_key: X25519PrivateKey, public_key: X25519PublicKey) -> bytes: + """Perform X25519 Diffie-Hellman. Returns 32-byte shared secret.""" + return private_key.exchange(public_key) + + +# --------------------------------------------------------------------------- +# Ed25519 <-> X25519 conversion (for Identity Key dual use) +# --------------------------------------------------------------------------- + +def ed25519_private_to_x25519(ed_private: Ed25519PrivateKey) -> X25519PrivateKey: + """Derive X25519 private key from Ed25519 private key via RFC 7748 clamping.""" + raw = ed_private.private_bytes( + serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() + ) + # SHA-512 hash of the seed, take first 32 bytes, clamp per RFC 7748 + h = hashlib.sha512(raw).digest()[:32] + clamped = bytearray(h) + clamped[0] &= 248 + clamped[31] &= 127 + clamped[31] |= 64 + return X25519PrivateKey.from_private_bytes(bytes(clamped)) + + +def ed25519_public_to_x25519(ed_public: Ed25519PublicKey) -> X25519PublicKey: + """Derive X25519 public key from Ed25519 public key. + + Uses the cryptography library's internal conversion. For production use, + we compute the X25519 public key from the converted private key when possible. + For remote keys (where we don't have the private key), we use a pure-Python + implementation of the Ed25519->X25519 point conversion. + """ + # Montgomery u = (1 + y) / (1 - y) mod p, where p = 2^255 - 19 + raw = ed_public.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + y = int.from_bytes(raw, "little") + # Clear the sign bit + y &= (1 << 255) - 1 + p = (1 << 255) - 19 + # u = (1 + y) * inverse(1 - y) mod p + one_plus_y = (1 + y) % p + one_minus_y = (1 - y) % p + inv = pow(one_minus_y, p - 2, p) + u = (one_plus_y * inv) % p + x25519_bytes = u.to_bytes(32, "little") + return X25519PublicKey.from_public_bytes(x25519_bytes) + + +# --------------------------------------------------------------------------- +# HKDF +# --------------------------------------------------------------------------- + +_HKDF_INFO_SELF = b"EncryptedChat_SelfKey" +_HKDF_INFO_RK = b"EncryptedChat_RootKey" + + +def derive_self_encryption_key(identity_private: Ed25519PrivateKey) -> bytes: + """Derive a static AES-256 key from identity key for encrypting own sent messages. + + This is NOT a ratchet — it's a static key. Safe because only the owner + has the identity private key, and self-copies don't need forward secrecy. + """ + raw = identity_private.private_bytes( + serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() + ) + return hkdf_derive(raw, salt=b"self_encryption", info=_HKDF_INFO_SELF, length=32) + + +_HKDF_INFO_LOCAL = b"EncryptedChat_LocalStorage" + + +def derive_local_storage_key(identity_private: Ed25519PrivateKey) -> bytes: + """Derive AES-256 key for encrypting local session/sender key files.""" + raw = identity_private.private_bytes( + serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() + ) + return hkdf_derive(raw, salt=b"local_storage", info=_HKDF_INFO_LOCAL, length=32) + + +_HKDF_INFO_CK_MSG = b"\x01" # chain key -> message key +_HKDF_INFO_CK_NEXT = b"\x02" # chain key -> next chain key + + +def hkdf_derive(input_key: bytes, salt: bytes, info: bytes, length: int = 32) -> bytes: + return HKDF(algorithm=hashes.SHA256(), length=length, salt=salt, info=info).derive(input_key) + + +def kdf_rk(root_key: bytes, dh_output: bytes) -> tuple[bytes, bytes]: + """Root key KDF. Returns (new_root_key, chain_key). + + Uses HKDF with the root key as salt and DH output as input key material. + Derives 64 bytes: first 32 = new root key, last 32 = chain key. + """ + derived = hkdf_derive(dh_output, salt=root_key, info=_HKDF_INFO_RK, length=64) + return derived[:32], derived[32:] + + +def kdf_ck(chain_key: bytes) -> tuple[bytes, bytes]: + """Chain key KDF. Returns (new_chain_key, message_key). + + Uses HMAC-SHA256: + message_key = HMAC(chain_key, 0x01) + new_chain_key = HMAC(chain_key, 0x02) + """ + message_key = hmac.new(chain_key, _HKDF_INFO_CK_MSG, hashlib.sha256).digest() + new_chain_key = hmac.new(chain_key, _HKDF_INFO_CK_NEXT, hashlib.sha256).digest() + return new_chain_key, message_key + + +# --------------------------------------------------------------------------- +# X3DH +# --------------------------------------------------------------------------- + +_X3DH_INFO = b"EncryptedChat_X3DH" + + +def generate_signed_prekey(identity_private: Ed25519PrivateKey) -> dict: + """Generate a signed pre-key (SPK). + + Returns {private: X25519PrivateKey, public: X25519PublicKey, signature: bytes, id: str}. + """ + spk_priv, spk_pub = generate_x25519_keypair() + spk_pub_bytes = serialize_x25519_public(spk_pub) + signature = ed25519_sign(identity_private, spk_pub_bytes) + return { + "private": spk_priv, + "public": spk_pub, + "signature": signature, + "id": str(uuid.uuid4()), + } + + +def generate_one_time_prekeys(count: int = 50) -> list[dict]: + """Generate a batch of one-time pre-keys. + + Returns [{private: X25519PrivateKey, public: X25519PublicKey, id: str}, ...]. + """ + result = [] + for _ in range(count): + priv, pub = generate_x25519_keypair() + result.append({"private": priv, "public": pub, "id": str(uuid.uuid4())}) + return result + + +def x3dh_initiate( + ik_private_ed: Ed25519PrivateKey, + ik_public_remote_ed: Ed25519PublicKey, + spk_remote: X25519PublicKey, + spk_signature: bytes, + opk_remote: X25519PublicKey | None = None, +) -> tuple[bytes, X25519PrivateKey, X25519PublicKey]: + """Initiator side of X3DH. + + Args: + ik_private_ed: Our Ed25519 identity private key + ik_public_remote_ed: Remote Ed25519 identity public key + spk_remote: Remote signed pre-key (X25519 public) + spk_signature: Ed25519 signature of spk_remote by ik_public_remote_ed + opk_remote: Optional one-time pre-key (X25519 public) + + Returns: + (shared_secret, ephemeral_private, ephemeral_public) + """ + # Verify SPK signature + spk_remote_bytes = serialize_x25519_public(spk_remote) + if not ed25519_verify(ik_public_remote_ed, spk_signature, spk_remote_bytes): + raise ValueError("Invalid SPK signature") + + # Convert identity keys to X25519 + ik_x25519_private = ed25519_private_to_x25519(ik_private_ed) + ik_x25519_remote = ed25519_public_to_x25519(ik_public_remote_ed) + + # Generate ephemeral keypair + ek_priv, ek_pub = generate_x25519_keypair() + + # DH computations + dh1 = x25519_dh(ik_x25519_private, spk_remote) # IK_A, SPK_B + dh2 = x25519_dh(ek_priv, ik_x25519_remote) # EK_A, IK_B + dh3 = x25519_dh(ek_priv, spk_remote) # EK_A, SPK_B + + dh_concat = dh1 + dh2 + dh3 + if opk_remote is not None: + dh4 = x25519_dh(ek_priv, opk_remote) # EK_A, OPK_B + dh_concat += dh4 + + # Derive shared secret + shared_secret = hkdf_derive(dh_concat, salt=b"\x00" * 32, info=_X3DH_INFO, length=32) + return shared_secret, ek_priv, ek_pub + + +def x3dh_respond( + ik_private_ed: Ed25519PrivateKey, + spk_private: X25519PrivateKey, + ik_remote_ed: Ed25519PublicKey, + ek_remote: X25519PublicKey, + opk_private: X25519PrivateKey | None = None, +) -> bytes: + """Responder side of X3DH. + + Args: + ik_private_ed: Our Ed25519 identity private key + spk_private: Our signed pre-key private (X25519) + ik_remote_ed: Remote Ed25519 identity public key + ek_remote: Remote ephemeral key (X25519 public) + opk_private: Our one-time pre-key private (X25519), if used + + Returns: + shared_secret (32 bytes) + """ + ik_x25519_private = ed25519_private_to_x25519(ik_private_ed) + ik_x25519_remote = ed25519_public_to_x25519(ik_remote_ed) + + dh1 = x25519_dh(spk_private, ik_x25519_remote) # SPK_B, IK_A + dh2 = x25519_dh(ik_x25519_private, ek_remote) # IK_B, EK_A + dh3 = x25519_dh(spk_private, ek_remote) # SPK_B, EK_A + + dh_concat = dh1 + dh2 + dh3 + if opk_private is not None: + dh4 = x25519_dh(opk_private, ek_remote) # OPK_B, EK_A + dh_concat += dh4 + + shared_secret = hkdf_derive(dh_concat, salt=b"\x00" * 32, info=_X3DH_INFO, length=32) + return shared_secret + + +# --------------------------------------------------------------------------- +# Double Ratchet +# --------------------------------------------------------------------------- + +MAX_SKIP = 256 # max messages to skip in a single chain (out-of-order tolerance) + + +@dataclass +class RatchetHeader: + """Header sent with each ratchet message.""" + dh_pub: bytes # sender's current ratchet public key (32 bytes) + n: int # message number in current sending chain + pn: int # number of messages in previous sending chain + + def serialize(self) -> bytes: + return json.dumps({ + "dh_pub": serialize_x25519_public(load_x25519_public(self.dh_pub)).hex() + if isinstance(self.dh_pub, bytes) else serialize_x25519_public(self.dh_pub).hex(), + "n": self.n, + "pn": self.pn, + }).encode() + + def to_dict(self) -> dict: + pub_hex = self.dh_pub.hex() if isinstance(self.dh_pub, bytes) else \ + serialize_x25519_public(self.dh_pub).hex() + return {"dh_pub": pub_hex, "n": self.n, "pn": self.pn} + + @classmethod + def from_dict(cls, d: dict) -> "RatchetHeader": + return cls(dh_pub=bytes.fromhex(d["dh_pub"]), n=d["n"], pn=d["pn"]) + + +class DoubleRatchet: + """Signal Double Ratchet implementation.""" + + def __init__(self): + self.dh_pair: tuple[X25519PrivateKey, X25519PublicKey] | None = None + self.dh_remote: X25519PublicKey | None = None + self.root_key: bytes = b"" + self.send_chain_key: bytes | None = None + self.recv_chain_key: bytes | None = None + self.send_n: int = 0 + self.recv_n: int = 0 + self.prev_send_n: int = 0 + # (dh_pub_hex, n) -> message_key for out-of-order messages + self.skipped: dict[tuple[str, int], bytes] = {} + + @classmethod + def init_alice(cls, shared_secret: bytes, bob_spk_pub: X25519PublicKey) -> "DoubleRatchet": + """Initialize as initiator (Alice) after X3DH. + + Alice performs the first DH ratchet step immediately. + """ + ratchet = cls() + ratchet.dh_pair = generate_x25519_keypair() + ratchet.dh_remote = bob_spk_pub + + # Perform DH ratchet to derive send chain + dh_output = x25519_dh(ratchet.dh_pair[0], ratchet.dh_remote) + ratchet.root_key, ratchet.send_chain_key = kdf_rk(shared_secret, dh_output) + ratchet.recv_chain_key = None + ratchet.send_n = 0 + ratchet.recv_n = 0 + ratchet.prev_send_n = 0 + return ratchet + + @classmethod + def init_bob(cls, shared_secret: bytes, spk_pair: tuple[X25519PrivateKey, X25519PublicKey]) -> "DoubleRatchet": + """Initialize as responder (Bob) after X3DH. + + Bob uses his SPK as the initial ratchet key pair. + """ + ratchet = cls() + ratchet.dh_pair = spk_pair + ratchet.root_key = shared_secret + ratchet.send_chain_key = None + ratchet.recv_chain_key = None + ratchet.send_n = 0 + ratchet.recv_n = 0 + ratchet.prev_send_n = 0 + return ratchet + + def encrypt(self, plaintext: bytes) -> dict: + """Encrypt a message. + + Returns {header: {dh_pub, n, pn}, ciphertext: bytes, nonce: bytes}. + """ + if self.send_chain_key is None: + raise RuntimeError("Send chain not initialized") + + self.send_chain_key, message_key = kdf_ck(self.send_chain_key) + + header = RatchetHeader( + dh_pub=serialize_x25519_public(self.dh_pair[1]), + n=self.send_n, + pn=self.prev_send_n, + ) + + # Encrypt with AES-256-GCM using the message key + nonce = os.urandom(12) + aesgcm = AESGCM(message_key) + # Include header as AAD to bind ciphertext to header + aad = header.serialize() + ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad) + + self.send_n += 1 + + return { + "header": header.to_dict(), + "ciphertext": ct_with_tag, # includes 16-byte tag + "nonce": nonce, + } + + def decrypt(self, header_dict: dict, ciphertext: bytes, nonce: bytes) -> bytes: + """Decrypt a message. Handles DH ratchet step if new dh_pub. + + State is snapshotted before modification and restored on failure (M9 fix). + """ + header = RatchetHeader.from_dict(header_dict) + remote_dh_pub_bytes = header.dh_pub + + # Check if this is from a skipped message (no state modification needed) + skip_key = (remote_dh_pub_bytes.hex(), header.n) + if skip_key in self.skipped: + mk = self.skipped.pop(skip_key) + aad = header.serialize() + aesgcm = AESGCM(mk) + try: + return aesgcm.decrypt(nonce, ciphertext, aad) + except Exception: + self.skipped[skip_key] = mk # restore skipped key + raise + + # Snapshot state before modifications + snap = self._snapshot() + + try: + remote_dh_pub = load_x25519_public(remote_dh_pub_bytes) + current_remote_bytes = serialize_x25519_public(self.dh_remote) if self.dh_remote else None + + if current_remote_bytes is None or remote_dh_pub_bytes != current_remote_bytes: + # New DH ratchet step + self._skip_messages(header.pn) + self._dh_ratchet(remote_dh_pub) + + self._skip_messages(header.n) + + # Derive message key from receive chain + self.recv_chain_key, mk = kdf_ck(self.recv_chain_key) + self.recv_n += 1 + + aad = header.serialize() + aesgcm = AESGCM(mk) + return aesgcm.decrypt(nonce, ciphertext, aad) + except Exception: + self._restore(snap) + raise + + def _snapshot(self) -> dict: + """Capture mutable state for rollback on decrypt failure.""" + return { + "dh_pair": self.dh_pair, + "dh_remote": self.dh_remote, + "root_key": self.root_key, + "send_chain_key": self.send_chain_key, + "recv_chain_key": self.recv_chain_key, + "send_n": self.send_n, + "recv_n": self.recv_n, + "prev_send_n": self.prev_send_n, + "skipped": dict(self.skipped), + } + + def _restore(self, snap: dict): + """Restore state from snapshot.""" + self.dh_pair = snap["dh_pair"] + self.dh_remote = snap["dh_remote"] + self.root_key = snap["root_key"] + self.send_chain_key = snap["send_chain_key"] + self.recv_chain_key = snap["recv_chain_key"] + self.send_n = snap["send_n"] + self.recv_n = snap["recv_n"] + self.prev_send_n = snap["prev_send_n"] + self.skipped = snap["skipped"] + + def _skip_messages(self, until: int): + """Skip ahead in the receive chain, storing message keys for out-of-order delivery.""" + if self.recv_chain_key is None: + return + if until - self.recv_n > MAX_SKIP: + raise RuntimeError(f"Too many skipped messages ({until - self.recv_n} > {MAX_SKIP})") + while self.recv_n < until: + self.recv_chain_key, mk = kdf_ck(self.recv_chain_key) + remote_hex = serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else "" + self.skipped[(remote_hex, self.recv_n)] = mk + self.recv_n += 1 + + def _dh_ratchet(self, remote_dh_pub: X25519PublicKey): + """Perform a DH ratchet step: update receive chain, generate new DH pair, update send chain.""" + self.prev_send_n = self.send_n + self.send_n = 0 + self.recv_n = 0 + self.dh_remote = remote_dh_pub + + # Derive new receive chain key + dh_output = x25519_dh(self.dh_pair[0], self.dh_remote) + self.root_key, self.recv_chain_key = kdf_rk(self.root_key, dh_output) + + # Generate new DH pair and derive new send chain key + self.dh_pair = generate_x25519_keypair() + dh_output = x25519_dh(self.dh_pair[0], self.dh_remote) + self.root_key, self.send_chain_key = kdf_rk(self.root_key, dh_output) + + def export_state(self) -> bytes: + """Serialize full ratchet state for persistent storage.""" + state = { + "dh_priv": serialize_x25519_private(self.dh_pair[0]).hex() if self.dh_pair else None, + "dh_pub": serialize_x25519_public(self.dh_pair[1]).hex() if self.dh_pair else None, + "dh_remote": serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else None, + "root_key": self.root_key.hex(), + "send_ck": self.send_chain_key.hex() if self.send_chain_key else None, + "recv_ck": self.recv_chain_key.hex() if self.recv_chain_key else None, + "send_n": self.send_n, + "recv_n": self.recv_n, + "prev_send_n": self.prev_send_n, + "skipped": {f"{k[0]}:{k[1]}": v.hex() for k, v in self.skipped.items()}, + } + return json.dumps(state).encode() + + @classmethod + def import_state(cls, data: bytes) -> "DoubleRatchet": + """Deserialize ratchet state.""" + state = json.loads(data) + r = cls() + if state["dh_priv"] and state["dh_pub"]: + priv = load_x25519_private(bytes.fromhex(state["dh_priv"])) + pub = load_x25519_public(bytes.fromhex(state["dh_pub"])) + r.dh_pair = (priv, pub) + if state["dh_remote"]: + r.dh_remote = load_x25519_public(bytes.fromhex(state["dh_remote"])) + r.root_key = bytes.fromhex(state["root_key"]) + r.send_chain_key = bytes.fromhex(state["send_ck"]) if state["send_ck"] else None + r.recv_chain_key = bytes.fromhex(state["recv_ck"]) if state["recv_ck"] else None + r.send_n = state["send_n"] + r.recv_n = state["recv_n"] + r.prev_send_n = state["prev_send_n"] + r.skipped = {} + for k_str, v_hex in state.get("skipped", {}).items(): + parts = k_str.rsplit(":", 1) + dh_hex = parts[0] + n = int(parts[1]) + r.skipped[(dh_hex, n)] = bytes.fromhex(v_hex) + return r + + +# --------------------------------------------------------------------------- +# Sender Keys (group messaging) +# --------------------------------------------------------------------------- + +class SenderKeyState: + """Sender key chain for group messaging. + + Each sender in a group has their own sender key chain. + Other group members receive the initial sender_key via pairwise Double Ratchet. + """ + + def __init__(self, sender_key: bytes | None = None): + if sender_key is None: + sender_key = os.urandom(32) + self.sender_key = sender_key + self.chain_id = hashlib.sha256(sender_key).digest() + self.chain_key = hkdf_derive(sender_key, salt=b"\x00" * 32, info=b"SenderKeyChain", length=32) + self.n = 0 + # For receivers: track chain state to allow fast-forward + self._known_keys: dict[int, bytes] = {} + + def encrypt(self, plaintext: bytes) -> dict: + """Encrypt with current chain key. + + Returns {chain_id: hex, n: int, ciphertext: bytes, nonce: bytes}. + """ + self.chain_key, message_key = kdf_ck(self.chain_key) + nonce = os.urandom(12) + aesgcm = AESGCM(message_key) + # AAD includes chain_id and message number + aad = self.chain_id + struct.pack(">I", self.n) + ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad) + result = { + "chain_id": self.chain_id.hex(), + "n": self.n, + "ciphertext": ct_with_tag, + "nonce": nonce, + } + self.n += 1 + return result + + MAX_SENDER_KEY_SKIP = 256 + + def decrypt(self, chain_id_hex: str, n: int, ciphertext: bytes, nonce: bytes) -> bytes: + """Decrypt a group message. Fast-forwards the chain if needed. + + State is snapshotted before modification and restored on failure (M9 fix). + """ + chain_id = bytes.fromhex(chain_id_hex) + if chain_id != self.chain_id: + raise ValueError("Chain ID mismatch") + + if n - self.n > self.MAX_SENDER_KEY_SKIP: + raise ValueError(f"Sender key skip too large ({n - self.n} > {self.MAX_SENDER_KEY_SKIP})") + + # Snapshot before fast-forward + snap_chain_key = self.chain_key + snap_n = self.n + snap_known = dict(self._known_keys) + + try: + # Fast-forward the chain to reach message n + while self.n <= n: + self.chain_key, mk = kdf_ck(self.chain_key) + self._known_keys[self.n] = mk + self.n += 1 + + mk = self._known_keys.pop(n, None) + if mk is None: + raise ValueError(f"Message key for n={n} not available (already consumed)") + + aad = chain_id + struct.pack(">I", n) + aesgcm = AESGCM(mk) + return aesgcm.decrypt(nonce, ciphertext, aad) + except Exception: + self.chain_key = snap_chain_key + self.n = snap_n + self._known_keys = snap_known + raise + + def export_key(self) -> bytes: + """Export sender key for distribution to group members. + + Contains everything needed to initialize a receiving SenderKeyState. + """ + return json.dumps({ + "sender_key": self.sender_key.hex(), + }).encode() + + def export_state(self) -> bytes: + """Serialize full state for persistent storage.""" + return json.dumps({ + "sender_key": self.sender_key.hex(), + "chain_id": self.chain_id.hex(), + "chain_key": self.chain_key.hex(), + "n": self.n, + "known_keys": {str(k): v.hex() for k, v in self._known_keys.items()}, + }).encode() + + @classmethod + def import_state(cls, data: bytes) -> "SenderKeyState": + state = json.loads(data) + obj = cls.__new__(cls) + obj.sender_key = bytes.fromhex(state["sender_key"]) + obj.chain_id = bytes.fromhex(state["chain_id"]) + obj.chain_key = bytes.fromhex(state["chain_key"]) + obj.n = state["n"] + obj._known_keys = {int(k): bytes.fromhex(v) for k, v in state.get("known_keys", {}).items()} + return obj + + @classmethod + def from_key(cls, exported_key: bytes) -> "SenderKeyState": + """Initialize a receiving SenderKeyState from an exported key.""" + data = json.loads(exported_key) + return cls(sender_key=bytes.fromhex(data["sender_key"])) + + +# --------------------------------------------------------------------------- +# Contact Key Verification (Safety Numbers / Fingerprints / QR Codes) +# --------------------------------------------------------------------------- + +FINGERPRINT_VERSION = 0 + + +def compute_fingerprint(user_id: str, identity_key_bytes: bytes, iterations: int = 5200) -> bytes: + """Compute a 32-byte fingerprint for a user's identity key. + + Uses iterated SHA-512 (Signal's NumericFingerprint algorithm). + Seed: version(2B) + identity_key(32B) + user_id(UTF-8). + Each iteration: SHA-512(previous_hash + identity_key). + Output: first 32 bytes of final hash. + """ + version_bytes = FINGERPRINT_VERSION.to_bytes(2, "big") + data = version_bytes + identity_key_bytes + user_id.encode("utf-8") + for _ in range(iterations): + data = hashlib.sha512(data + identity_key_bytes).digest() + return data[:32] + + +def format_fingerprint(fp_bytes: bytes) -> str: + """Format 32-byte fingerprint as 6 groups of 5 zero-padded digits (30 digits). + + Each group: int(bytes[i*5:(i+1)*5], big-endian) % 100000. + Output: two lines of 3 groups each, space-separated. + """ + groups = [] + for i in range(6): + num = int.from_bytes(fp_bytes[i * 5:(i + 1) * 5], "big") % 100000 + groups.append(f"{num:05d}") + return " ".join(groups[:3]) + "\n" + " ".join(groups[3:]) + + +def compute_safety_number(my_uid: str, my_ik_bytes: bytes, + their_uid: str, their_ik_bytes: bytes) -> str: + """Compute a 60-digit safety number for a pair of users. + + Both users see the same number regardless of who computes it. + Lower user_id's fingerprint comes first (deterministic ordering). + Output: 12 groups of 5 digits, formatted as 3 lines of 4 groups. + """ + fp_mine = compute_fingerprint(my_uid, my_ik_bytes) + fp_theirs = compute_fingerprint(their_uid, their_ik_bytes) + if my_uid < their_uid: + combined = fp_mine + fp_theirs + else: + combined = fp_theirs + fp_mine + # 64 bytes -> 12 groups of 5 digits + groups = [] + for i in range(12): + num = int.from_bytes(combined[i * 5:(i + 1) * 5], "big") % 100000 + groups.append(f"{num:05d}") + lines = [ + " ".join(groups[0:4]), + " ".join(groups[4:8]), + " ".join(groups[8:12]), + ] + return "\n".join(lines) + + +def encode_verification_qr(user_id: str, identity_key_bytes: bytes) -> bytes: + """Encode user identity for QR code verification. + + Format: version(1B=0x01) + uid_len(1B) + uid(UTF-8) + identity_key(32B). + """ + uid_bytes = user_id.encode("utf-8") + return b"\x01" + len(uid_bytes).to_bytes(1, "big") + uid_bytes + identity_key_bytes + + +def decode_verification_qr(data: bytes) -> tuple[str, bytes]: + """Decode QR code verification payload. + + Returns (user_id, identity_key_bytes). + Raises ValueError on invalid format. + """ + if len(data) < 3: + raise ValueError("QR data too short") + if data[0] != 0x01: + raise ValueError(f"Unknown QR version: {data[0]}") + uid_len = data[1] + if len(data) < 2 + uid_len + 32: + raise ValueError("QR data truncated") + user_id = data[2:2 + uid_len].decode("utf-8") + identity_key = data[2 + uid_len:2 + uid_len + 32] + return user_id, identity_key + + +# --------------------------------------------------------------------------- +# Message Padding (metadata privacy — hide plaintext length) +# --------------------------------------------------------------------------- + +_PAD_MAGIC = b"\x01" +_PAD_BUCKETS = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536] + + +def pad_plaintext(plaintext: bytes) -> bytes: + """Pad plaintext to nearest bucket size to hide message length. + + Format: 0x01 + plaintext + random_padding + pad_length(4B big-endian) + Prefix 0x01 distinguishes padded messages from legacy unpadded (which start with '{'). + """ + content = _PAD_MAGIC + plaintext + # +4 for the length suffix + min_size = len(content) + 4 + target = next((b for b in _PAD_BUCKETS if b >= min_size), min_size) + pad_len = target - len(content) + return content + os.urandom(pad_len - 4) + struct.pack(">I", pad_len) + + +def unpad_plaintext(data: bytes) -> bytes: + """Remove padding. Returns raw plaintext for both padded and legacy unpadded messages.""" + if not data or data[0:1] != _PAD_MAGIC: + return data # legacy unpadded message (starts with '{' for JSON) + if len(data) < 5: + return data # too short to be validly padded + pad_len = struct.unpack(">I", data[-4:])[0] + if pad_len < 4 or pad_len > len(data) - 1: + return data # invalid padding metadata, treat as legacy + return data[1:len(data) - pad_len] diff --git a/db.py b/db.py new file mode 100644 index 0000000..76ecae0 --- /dev/null +++ b/db.py @@ -0,0 +1,1714 @@ +"""MySQL database layer for the encrypted chat server.""" + +import os +import uuid + +import logging +import mysql.connector +from mysql.connector import pooling +from dotenv import load_dotenv + +from crypto_utils import ( + generate_identity_keypair, + serialize_ed25519_public, + generate_signed_prekey, + serialize_x25519_public, + generate_one_time_prekeys, +) + +load_dotenv() + +# Sentinel device_id for self-encrypted copies and legacy (pre-multi-device) rows +SELF_DEVICE_ID = "00000000-0000-0000-0000-000000000000" + + +_logger = logging.getLogger(__name__) + +_pool = None + + +def _get_pool(): + """Get or create the connection pool (lazy init).""" + global _pool + if _pool is None: + pool_size = int(os.getenv("DB_POOL_SIZE", "10")) + _pool = pooling.MySQLConnectionPool( + pool_name="chat_pool", + pool_size=pool_size, + pool_reset_session=True, + host=os.getenv("MYSQL_HOST", "localhost"), + port=int(os.getenv("MYSQL_PORT", "3306")), + user=os.getenv("MYSQL_USER", "root"), + password=os.getenv("MYSQL_PASSWORD", ""), + database=os.getenv("MYSQL_DATABASE", "encrypted_chat"), + ) + _logger.info("DB connection pool created (size=%d)", pool_size) + return _pool + + +def get_connection(): + """Get a connection from the pool.""" + return _get_pool().get_connection() + + +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +# --- Devices --- + +def create_device(user_id: str, device_name: str | None = None) -> str: + """Create a new device for a user. Returns device_id.""" + conn = get_connection() + try: + cursor = conn.cursor() + device_id = generate_uuid() + cursor.execute( + "INSERT INTO devices (id, user_id, device_name) VALUES (%s, %s, %s)", + (device_id, user_id, device_name), + ) + conn.commit() + return device_id + finally: + conn.close() + + +def get_user_devices(user_id: str) -> list[dict]: + """Get all devices for a user.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, user_id, device_name, created_at, last_seen_at " + "FROM devices WHERE user_id = %s ORDER BY created_at", + (user_id,), + ) + return cursor.fetchall() + finally: + conn.close() + + +def get_device(device_id: str) -> dict | None: + """Get a single device by ID.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, user_id, device_name, created_at, last_seen_at " + "FROM devices WHERE id = %s", + (device_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def update_device_last_seen(device_id: str): + """Update last_seen_at timestamp for a device.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE devices SET last_seen_at = NOW() WHERE id = %s", + (device_id,), + ) + conn.commit() + finally: + conn.close() + + +def delete_device(device_id: str): + """Delete a device. CASCADE removes its prekeys.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM devices WHERE id = %s", (device_id,)) + # Also clean up prekeys explicitly for device_id column + cursor.execute("DELETE FROM signed_prekeys WHERE device_id = %s", (device_id,)) + cursor.execute("DELETE FROM one_time_prekeys WHERE device_id = %s", (device_id,)) + conn.commit() + finally: + conn.close() + + +# --- Users --- + +def create_user(username: str, email: str, rsa_public_key_pem: str, identity_key: bytes) -> str: + """Register a new user. Returns user ID.""" + conn = get_connection() + try: + cursor = conn.cursor() + user_id = generate_uuid() + cursor.execute( + "INSERT INTO users (id, username, email, rsa_public_key, identity_key) " + "VALUES (%s, %s, %s, %s, %s)", + (user_id, username, email, rsa_public_key_pem, identity_key), + ) + conn.commit() + return user_id + finally: + conn.close() + + +def get_user_by_email(email: str) -> dict | None: + """Get user by email.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE email = %s", + (email,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def get_user_by_id(user_id: str) -> dict | None: + """Get user by ID.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE id = %s", + (user_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def shares_conversation(user_id_a: str, user_id_b: str) -> bool: + """Check if two users share at least one conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM conversation_members cm1 " + "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " + "WHERE cm1.user_id = %s AND cm2.user_id = %s LIMIT 1", + (user_id_a, user_id_b), + ) + return cursor.fetchone() is not None + finally: + conn.close() + + +def get_user_contacts(user_id: str) -> list[str]: + """Get all user IDs that share at least one conversation with the given user.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT DISTINCT cm2.user_id " + "FROM conversation_members cm1 " + "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " + "WHERE cm1.user_id = %s AND cm2.user_id != %s", + (user_id, user_id), + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + +def update_user_rsa_key(user_id: str, rsa_public_key_pem: str): + """Update user's RSA public key (for login).""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("UPDATE users SET rsa_public_key = %s WHERE id = %s", (rsa_public_key_pem, user_id)) + conn.commit() + finally: + conn.close() + + +def update_username(user_id: str, new_username: str): + """Update user's display name.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("UPDATE users SET username = %s WHERE id = %s", (new_username, user_id)) + conn.commit() + finally: + conn.close() + + +# --- Pre-keys --- + +def store_signed_prekey(user_id: str, spk_id: str, public_key: bytes, signature: bytes, + device_id: str | None = None): + """Store (or replace) a signed pre-key for a user's device.""" + conn = get_connection() + try: + cursor = conn.cursor() + # Remove old SPKs for this user+device + if device_id: + cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id = %s", + (user_id, device_id)) + else: + cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id IS NULL", + (user_id,)) + cursor.execute( + "INSERT INTO signed_prekeys (id, user_id, device_id, public_key, signature) " + "VALUES (%s, %s, %s, %s, %s)", + (spk_id, user_id, device_id, public_key, signature), + ) + conn.commit() + finally: + conn.close() + + +def get_signed_prekey(user_id: str, device_id: str | None = None) -> dict | None: + """Get the current signed pre-key for a user (optionally per device).""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + if device_id: + cursor.execute( + "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys " + "WHERE user_id = %s AND device_id = %s " + "ORDER BY created_at DESC LIMIT 1", + (user_id, device_id), + ) + else: + cursor.execute( + "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys " + "WHERE user_id = %s ORDER BY created_at DESC LIMIT 1", + (user_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def store_one_time_prekeys(user_id: str, prekeys: list[dict], device_id: str | None = None): + """Store a batch of one-time pre-keys. Each dict has {id, public_key (bytes)}.""" + conn = get_connection() + try: + cursor = conn.cursor() + for pk in prekeys: + cursor.execute( + "INSERT INTO one_time_prekeys (id, user_id, device_id, public_key) " + "VALUES (%s, %s, %s, %s)", + (pk["id"], user_id, device_id, pk["public_key"]), + ) + conn.commit() + finally: + conn.close() + + +def consume_one_time_prekey(user_id: str, device_id: str | None = None) -> dict | None: + """Atomically consume one OTP: SELECT FOR UPDATE + DELETE. + Returns {id, public_key} or None.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + conn.start_transaction() + if device_id: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE", + (user_id, device_id), + ) + else: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s LIMIT 1 FOR UPDATE", + (user_id,), + ) + row = cursor.fetchone() + if row: + cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (row["id"],)) + conn.commit() + return row + except Exception: + conn.rollback() + raise + finally: + conn.close() + + +def count_one_time_prekeys(user_id: str, device_id: str | None = None) -> int: + """Count remaining OTPs for a user (optionally per device).""" + conn = get_connection() + try: + cursor = conn.cursor() + if device_id: + cursor.execute( + "SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s AND device_id = %s", + (user_id, device_id), + ) + else: + cursor.execute("SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s", (user_id,)) + return cursor.fetchone()[0] + finally: + conn.close() + + +def get_key_bundle(user_id: str) -> dict | None: + """Get complete key bundle for X3DH (single device — legacy compat). + + Returns {identity_key, signed_prekey_id, signed_prekey, spk_signature, + one_time_prekey_id, one_time_prekey} or None. + OTP is consumed atomically. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + # Get user identity key + cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,)) + user = cursor.fetchone() + if not user: + return None + + # Get signed prekey + cursor.execute( + "SELECT id, public_key, signature, device_id FROM signed_prekeys WHERE user_id = %s " + "ORDER BY created_at DESC LIMIT 1", + (user_id,), + ) + spk = cursor.fetchone() + if not spk: + return None + + # Consume one OTP (may be None) — use transaction for atomicity (H12 fix) + conn.start_transaction() + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys WHERE user_id = %s LIMIT 1 FOR UPDATE", + (user_id,), + ) + opk = cursor.fetchone() + if opk: + cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],)) + conn.commit() + + result = { + "identity_key": user["identity_key"], + "signed_prekey_id": spk["id"], + "signed_prekey": spk["public_key"], + "spk_signature": spk["signature"], + } + if opk: + result["one_time_prekey_id"] = opk["id"] + result["one_time_prekey"] = opk["public_key"] + return result + except Exception: + try: + conn.rollback() + except Exception: + pass + raise + finally: + conn.close() + + +def get_key_bundles_for_user(user_id: str) -> dict | None: + """Get key bundles for ALL devices of a user. Returns + {identity_key, device_bundles: [{device_id, signed_prekey_id, signed_prekey_pub, + spk_signature, opk_id, opk_pub}]} or None. + Consumes one OPK per device atomically. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + # Get user identity key + cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,)) + user = cursor.fetchone() + if not user: + return None + + # Get all signed prekeys (one per device, most recent) + cursor.execute( + "SELECT id, public_key, signature, device_id FROM signed_prekeys " + "WHERE user_id = %s ORDER BY created_at DESC", + (user_id,), + ) + all_spks = cursor.fetchall() + if not all_spks: + return None + + # De-duplicate: keep only the most recent SPK per device_id + seen_devices = set() + spks_by_device = [] + for spk in all_spks: + dev = spk.get("device_id") or "__legacy__" + if dev not in seen_devices: + seen_devices.add(dev) + spks_by_device.append(spk) + + device_bundles = [] + # Commit the implicit transaction from the read-only queries above + # so we can start an explicit transaction for atomic OPK consumption. + conn.commit() + conn.start_transaction() + for spk in spks_by_device: + dev_id = spk.get("device_id") + # Consume one OPK for this device + if dev_id: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE", + (user_id, dev_id), + ) + else: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s AND device_id IS NULL LIMIT 1 FOR UPDATE", + (user_id,), + ) + opk = cursor.fetchone() + if opk: + cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],)) + + bundle = { + "device_id": dev_id, + "signed_prekey_id": spk["id"], + "signed_prekey_pub": spk["public_key"], + "spk_signature": spk["signature"], + } + if opk: + bundle["opk_id"] = opk["id"] + bundle["opk_pub"] = opk["public_key"] + device_bundles.append(bundle) + conn.commit() + + return { + "identity_key": user["identity_key"], + "device_bundles": device_bundles, + } + except Exception: + try: + conn.rollback() + except Exception: + pass + raise + finally: + conn.close() + + +# --- Conversations --- + +def create_conversation(member_user_ids: list[str], joined_at=None, name=None, created_by=None) -> str: + conn = get_connection() + try: + cursor = conn.cursor() + conv_id = generate_uuid() + cursor.execute("INSERT INTO conversations (id, name, created_by) VALUES (%s, %s, %s)", + (conv_id, name, created_by)) + for uid in member_user_ids: + cursor.execute( + "INSERT INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)", + (conv_id, uid, joined_at), + ) + conn.commit() + return conv_id + finally: + conn.close() + + +def add_conversation_member(conversation_id: str, user_id: str, joined_at=None): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT IGNORE INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)", + (conversation_id, user_id, joined_at), + ) + conn.commit() + finally: + conn.close() + + +def remove_conversation_member(conversation_id: str, user_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + conn.commit() + finally: + conn.close() + + +def count_conversation_members(conversation_id: str) -> int: + """Count members in a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT COUNT(*) FROM conversation_members WHERE conversation_id = %s", + (conversation_id,), + ) + return cursor.fetchone()[0] + finally: + conn.close() + + +def get_conversation_file_ids(conversation_id: str) -> list[str]: + """Get all file IDs (images + files) uploaded to a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT file_id FROM image_uploads WHERE conversation_id = %s", + (conversation_id,), + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + +def delete_conversation(conversation_id: str): + """Delete a conversation entirely. CASCADE cleans up members, messages, etc.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM conversations WHERE id = %s", (conversation_id,)) + conn.commit() + finally: + conn.close() + + +def get_conversation_members(conversation_id: str) -> list[dict]: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT u.id, u.username, u.email, u.identity_key FROM conversation_members cm " + "JOIN users u ON cm.user_id = u.id " + "WHERE cm.conversation_id = %s", + (conversation_id,), + ) + return cursor.fetchall() + finally: + conn.close() + + +def find_direct_conversation(user_id_a: str, user_id_b: str) -> str | None: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT cm1.conversation_id FROM conversation_members cm1 " + "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " + "WHERE cm1.user_id = %s AND cm2.user_id = %s " + "AND (SELECT COUNT(*) FROM conversation_members cm3 " + " WHERE cm3.conversation_id = cm1.conversation_id) = 2 " + "LIMIT 1", + (user_id_a, user_id_b), + ) + row = cursor.fetchone() + return row[0] if row else None + finally: + conn.close() + + +def update_conversation_creator(conversation_id: str, new_creator_id: str): + """Transfer group creator role to another member.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE conversations SET created_by = %s WHERE id = %s", + (new_creator_id, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def get_conversation(conversation_id: str) -> dict | None: + """Get conversation by ID.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, created_at, name, created_by, avatar_file FROM conversations WHERE id = %s", + (conversation_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def update_conversation_avatar(conversation_id: str, avatar_file: str): + """Set avatar file for a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE conversations SET avatar_file = %s WHERE id = %s", + (avatar_file, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def update_conversation_name(conversation_id: str, name: str): + """Update the name of a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE conversations SET name = %s WHERE id = %s", + (name, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def is_conversation_member(conversation_id: str, user_id: str) -> bool: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + return cursor.fetchone() is not None + finally: + conn.close() + + +def list_user_conversations(user_id: str) -> list[dict]: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT c.id, c.created_at, c.name, c.created_by, c.avatar_file FROM conversations c " + "JOIN conversation_members cm ON c.id = cm.conversation_id " + "WHERE cm.user_id = %s ORDER BY c.created_at DESC", + (user_id,), + ) + convs = cursor.fetchall() + if not convs: + return convs + # Batch-fetch all members for all conversations in one query (N+1 fix) + conv_ids = [c["id"] for c in convs] + placeholders = ",".join(["%s"] * len(conv_ids)) + cursor.execute( + f"SELECT cm.conversation_id, u.id AS user_id, u.username, u.email " + f"FROM conversation_members cm JOIN users u ON cm.user_id = u.id " + f"WHERE cm.conversation_id IN ({placeholders})", + conv_ids, + ) + members_by_conv: dict[str, list[dict]] = {} + for row in cursor.fetchall(): + cid = row.pop("conversation_id") + members_by_conv.setdefault(cid, []).append(row) + for conv in convs: + conv["members"] = members_by_conv.get(conv["id"], []) + return convs + finally: + conn.close() + + +# --- Group Invitations --- + +def create_invitation(conversation_id: str, user_id: str, invited_by: str): + """Create a pending group invitation.""" + conn = get_connection() + try: + cursor = conn.cursor() + inv_id = generate_uuid() + cursor.execute( + "INSERT IGNORE INTO group_invitations (id, conversation_id, user_id, invited_by) " + "VALUES (%s, %s, %s, %s)", + (inv_id, conversation_id, user_id, invited_by), + ) + conn.commit() + finally: + conn.close() + + +def get_pending_invitations(user_id: str) -> list[dict]: + """Get all pending invitations for a user, joined with conversation and inviter info.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT gi.id, gi.conversation_id, gi.invited_by, gi.created_at, " + "c.name AS conversation_name, u.username AS invited_by_username " + "FROM group_invitations gi " + "JOIN conversations c ON gi.conversation_id = c.id " + "JOIN users u ON gi.invited_by = u.id " + "WHERE gi.user_id = %s " + "ORDER BY gi.created_at DESC", + (user_id,), + ) + return cursor.fetchall() + finally: + conn.close() + + +def delete_invitation(conversation_id: str, user_id: str): + """Delete a pending invitation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM group_invitations WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + conn.commit() + finally: + conn.close() + + +def has_pending_invitation(conversation_id: str, user_id: str) -> bool: + """Check if a user has a pending invitation for a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM group_invitations WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + return cursor.fetchone() is not None + finally: + conn.close() + + +# --- Messages --- + +def store_message( + conversation_id: str, + sender_id: str, + ratchet_header: bytes, + recipients: list[dict], + x3dh_header: bytes | None = None, + sender_chain_id: bytes | None = None, + sender_chain_n: int | None = None, + image_file_id: str | None = None, + sender_device_id: str | None = None, +) -> str: + """Store an encrypted message with per-recipient ciphertext. + + recipients: [{user_id, encrypted_content (bytes), nonce (bytes), + device_id (str, optional), ratchet_header (bytes, optional), + x3dh_header (bytes, optional)}] + """ + conn = get_connection() + try: + cursor = conn.cursor() + msg_id = generate_uuid() + cursor.execute( + "INSERT INTO messages (id, conversation_id, sender_id, sender_device_id, " + "ratchet_header, x3dh_header, sender_chain_id, sender_chain_n, image_file_id) " + "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)", + (msg_id, conversation_id, sender_id, sender_device_id, ratchet_header, + x3dh_header, sender_chain_id, sender_chain_n, image_file_id), + ) + for r in recipients: + device_id = r.get("device_id", SELF_DEVICE_ID) + cursor.execute( + "INSERT INTO message_recipients (message_id, user_id, device_id, " + "encrypted_content, nonce, ratchet_header, x3dh_header) " + "VALUES (%s, %s, %s, %s, %s, %s, %s)", + (msg_id, r["user_id"], device_id, r["encrypted_content"], r["nonce"], + r.get("ratchet_header"), r.get("x3dh_header")), + ) + conn.commit() + cursor.execute("SELECT created_at FROM messages WHERE id = %s", (msg_id,)) + row = cursor.fetchone() + created_at = row[0].isoformat() if row else None + return msg_id, created_at + finally: + conn.close() + + +def get_messages(conversation_id: str, user_id: str, limit: int = 50, offset: int = 0, + device_id: str | None = None, after_ts: str | None = None) -> list[dict]: + """Get messages for a user in a conversation, JOINing their per-recipient ciphertext. + + If device_id is set, returns rows where mr.device_id matches OR is the sentinel + (self-encrypted / legacy). May return duplicate message IDs when both device-specific + and self-encrypted rows exist — caller should deduplicate (prefer device-specific). + + If after_ts is set, only returns messages created after that timestamp (ISO format). + Results are ordered ASC when after_ts is used, DESC otherwise. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + if device_id: + where_parts = ["m.conversation_id = %s", + "(cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"] + params = [user_id, device_id, SELF_DEVICE_ID, user_id, conversation_id] + + if after_ts: + where_parts.append("m.created_at > %s") + params.append(after_ts) + + where_clause = " AND ".join(where_parts) + order = "ASC" if after_ts else "DESC" + + cursor.execute( + "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, " + "m.ratchet_header, m.x3dh_header, " + "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, " + "m.pinned_at, m.pinned_by, " + "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, " + "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header " + "FROM messages m " + "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " + " AND (mr.device_id = %s OR mr.device_id = %s) " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + f"WHERE {where_clause} " + f"ORDER BY m.created_at {order} LIMIT %s OFFSET %s", + (*params, limit, offset), + ) + else: + where_parts = ["m.conversation_id = %s", + "(cm.joined_at IS NULL OR m.created_at >= cm.joined_at)"] + params = [user_id, user_id, conversation_id] + + if after_ts: + where_parts.append("m.created_at > %s") + params.append(after_ts) + + where_clause = " AND ".join(where_parts) + order = "ASC" if after_ts else "DESC" + + cursor.execute( + "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, " + "m.ratchet_header, m.x3dh_header, " + "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, " + "m.pinned_at, m.pinned_by, " + "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, " + "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header " + "FROM messages m " + "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + f"WHERE {where_clause} " + f"ORDER BY m.created_at {order} LIMIT %s OFFSET %s", + (*params, limit, offset), + ) + return cursor.fetchall() + finally: + conn.close() + + +def count_messages(conversation_id: str, user_id: str) -> int: + """Count total messages visible to a user in a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT COUNT(DISTINCT m.id) " + "FROM messages m " + "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + "WHERE m.conversation_id = %s AND (cm.joined_at IS NULL OR m.created_at >= cm.joined_at)", + (user_id, user_id, conversation_id), + ) + row = cursor.fetchone() + return row[0] if row else 0 + finally: + conn.close() + + +def get_message_conversation(message_id: str) -> str | None: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT conversation_id FROM messages WHERE id = %s", (message_id,)) + row = cursor.fetchone() + return row[0] if row else None + finally: + conn.close() + + +def get_message_sender(message_id: str) -> str | None: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT sender_id FROM messages WHERE id = %s", (message_id,)) + row = cursor.fetchone() + return row[0] if row else None + finally: + conn.close() + + +def get_deleted_messages_since(conversation_id: str, user_id: str, since_ts: str) -> list[str]: + """Return message IDs that were soft-deleted since the given timestamp.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT m.id FROM messages m " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + "WHERE m.conversation_id = %s AND m.deleted_at IS NOT NULL AND m.deleted_at > %s", + (user_id, conversation_id, since_ts), + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + +# --- Reactions --- + +ALLOWED_REACTIONS = {"thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"} + + +def add_reaction(message_id: str, user_id: str, reaction: str) -> tuple[bool, str | None]: + """Add or replace a reaction. Returns (changed, old_reaction_or_None).""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT reaction FROM message_reactions WHERE message_id = %s AND user_id = %s", + (message_id, user_id), + ) + row = cursor.fetchone() + old_reaction = row["reaction"] if row else None + + if old_reaction == reaction: + return False, None # already same reaction + + if old_reaction: + cursor.execute( + "UPDATE message_reactions SET reaction = %s, created_at = CURRENT_TIMESTAMP " + "WHERE message_id = %s AND user_id = %s", + (reaction, message_id, user_id), + ) + else: + cursor.execute( + "INSERT INTO message_reactions (id, message_id, user_id, reaction) " + "VALUES (%s, %s, %s, %s)", + (generate_uuid(), message_id, user_id, reaction), + ) + conn.commit() + return True, old_reaction + finally: + conn.close() + + +def remove_reaction(message_id: str, user_id: str) -> bool: + """Remove a user's reaction. Returns True if deleted.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM message_reactions WHERE message_id = %s AND user_id = %s", + (message_id, user_id), + ) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def get_reactions(message_ids: list[str]) -> dict[str, list[dict]]: + """Get reactions for multiple messages. Returns {msg_id: [{user_id, reaction, created_at}]}.""" + if not message_ids: + return {} + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"SELECT message_id, user_id, reaction, created_at " + f"FROM message_reactions WHERE message_id IN ({placeholders}) " + f"ORDER BY created_at", + tuple(message_ids), + ) + result = {} + for row in cursor.fetchall(): + mid = row["message_id"] + if mid not in result: + result[mid] = [] + result[mid].append({ + "user_id": row["user_id"], + "reaction": row["reaction"], + "created_at": row["created_at"].isoformat() if hasattr(row["created_at"], "isoformat") else str(row["created_at"]), + }) + return result + finally: + conn.close() + + +# --- Pins --- + +def pin_message(message_id: str, user_id: str, conversation_id: str) -> bool: + """Pin a message. Returns True on success.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE messages SET pinned_at = NOW(), pinned_by = %s " + "WHERE id = %s AND conversation_id = %s AND pinned_at IS NULL", + (user_id, message_id, conversation_id), + ) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def unpin_message(message_id: str, conversation_id: str) -> bool: + """Unpin a message. Returns True on success.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE messages SET pinned_at = NULL, pinned_by = NULL " + "WHERE id = %s AND conversation_id = %s AND pinned_at IS NOT NULL", + (message_id, conversation_id), + ) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def get_pinned_messages(conversation_id: str, user_id: str) -> list[dict]: + """Get pinned messages for a conversation (membership verified via JOIN).""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT m.id AS message_id, m.sender_id, m.pinned_at, m.pinned_by, m.created_at " + "FROM messages m " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + "WHERE m.conversation_id = %s AND m.pinned_at IS NOT NULL AND m.deleted_at IS NULL " + "ORDER BY m.pinned_at DESC", + (user_id, conversation_id), + ) + rows = cursor.fetchall() + for r in rows: + for k in ("pinned_at", "created_at"): + if r.get(k) and hasattr(r[k], "isoformat"): + r[k] = r[k].isoformat() + return rows + finally: + conn.close() + + +# --- Group Sender Keys --- + +def store_sender_key(conversation_id: str, sender_id: str, chain_id: bytes, + device_id: str | None = None): + """Store or update a sender key chain ID for a group member's device.""" + conn = get_connection() + try: + cursor = conn.cursor() + dev = device_id or SELF_DEVICE_ID + cursor.execute( + "REPLACE INTO group_sender_keys (conversation_id, sender_id, device_id, chain_id) " + "VALUES (%s, %s, %s, %s)", + (conversation_id, sender_id, dev, chain_id), + ) + conn.commit() + finally: + conn.close() + + +def get_sender_key(conversation_id: str, sender_id: str, + device_id: str | None = None) -> dict | None: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + dev = device_id or SELF_DEVICE_ID + cursor.execute( + "SELECT chain_id, created_at FROM group_sender_keys " + "WHERE conversation_id = %s AND sender_id = %s AND device_id = %s", + (conversation_id, sender_id, dev), + ) + return cursor.fetchone() + finally: + conn.close() + + +# --- Read Receipts --- + +def filter_message_ids_by_conversation(conversation_id: str, message_ids: list[str]) -> list[str]: + """Return only message_ids that belong to the given conversation.""" + if not message_ids: + return [] + conn = get_connection() + try: + cursor = conn.cursor() + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"SELECT id FROM messages WHERE id IN ({placeholders}) AND conversation_id = %s", + (*message_ids, conversation_id), + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + +def mark_messages_read(conversation_id: str, user_id: str, message_ids: list[str]): + if not message_ids: + return + conn = get_connection() + try: + cursor = conn.cursor() + # M1 fix: JOIN messages to verify message_ids belong to conversation_id + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"INSERT IGNORE INTO message_reads (message_id, user_id) " + f"SELECT m.id, %s FROM messages m " + f"WHERE m.id IN ({placeholders}) AND m.conversation_id = %s", + (user_id, *message_ids, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def mark_conversation_read(conversation_id: str, user_id: str) -> int: + """Mark ALL unread messages in a conversation as read for user. Returns count marked.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT IGNORE INTO message_reads (message_id, user_id) " + "SELECT m.id, %s " + "FROM messages m " + "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s " + "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s " + "WHERE m.conversation_id = %s AND m.sender_id != %s " + "AND m.deleted_at IS NULL AND mrd.message_id IS NULL", + (user_id, user_id, user_id, conversation_id, user_id), + ) + count = cursor.rowcount + conn.commit() + return count + finally: + conn.close() + + +def get_unread_counts(user_id: str, max_age_days: int = 0) -> dict[str, int]: + """Return {conversation_id: unread_count} for all conversations the user is in. + + max_age_days: if > 0, only count messages younger than this many days. + Must match METADATA_RETENTION_DAYS to avoid phantom unreads after read cleanup. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + age_filter = "" + params = [user_id, user_id, user_id] + if max_age_days > 0: + age_filter = " AND m.created_at >= DATE_SUB(NOW(), INTERVAL %s DAY)" + params.append(max_age_days) + cursor.execute( + "SELECT m.conversation_id, COUNT(DISTINCT m.id) AS cnt " + "FROM messages m " + "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s " + "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s " + "WHERE m.sender_id != %s AND m.deleted_at IS NULL AND mrd.message_id IS NULL" + f"{age_filter} " + "GROUP BY m.conversation_id", + params, + ) + return {row["conversation_id"]: row["cnt"] for row in cursor.fetchall()} + finally: + conn.close() + + +def get_message_read_status(message_ids: list[str]) -> dict: + if not message_ids: + return {} + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"SELECT mr.message_id, mr.user_id, mr.read_at " + f"FROM message_reads mr " + f"WHERE mr.message_id IN ({placeholders})", + tuple(message_ids), + ) + result = {} + for row in cursor.fetchall(): + mid = row["message_id"] + if mid not in result: + result[mid] = [] + result[mid].append({ + "user_id": row["user_id"], + "read_at": row["read_at"].isoformat() if hasattr(row["read_at"], "isoformat") else str(row["read_at"]), + }) + return result + finally: + conn.close() + + +# --- Delivery Receipts --- + +def mark_messages_delivered(conversation_id: str, user_id: str, message_ids: list[str]): + """Batch insert delivery receipts (INSERT IGNORE — idempotent).""" + if not message_ids: + return + conn = get_connection() + try: + cursor = conn.cursor() + # M1 fix: JOIN messages to verify message_ids belong to conversation_id + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"INSERT IGNORE INTO message_deliveries (message_id, user_id) " + f"SELECT m.id, %s FROM messages m " + f"WHERE m.id IN ({placeholders}) AND m.conversation_id = %s", + (user_id, *message_ids, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def get_message_delivery_status(message_ids: list[str]) -> dict: + """Get delivery status for messages. Returns {msg_id: [{user_id, delivered_at}]}.""" + if not message_ids: + return {} + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"SELECT md.message_id, md.user_id, md.delivered_at " + f"FROM message_deliveries md " + f"WHERE md.message_id IN ({placeholders})", + tuple(message_ids), + ) + result = {} + for row in cursor.fetchall(): + mid = row["message_id"] + if mid not in result: + result[mid] = [] + result[mid].append({ + "user_id": row["user_id"], + "delivered_at": row["delivered_at"].isoformat() if hasattr(row["delivered_at"], "isoformat") else str(row["delivered_at"]), + }) + return result + finally: + conn.close() + + +# --- Delete --- + +def soft_delete_message(message_id: str, sender_id: str) -> dict | None: + """Soft-delete a message if sender matches. Returns {'image_file_id': ...} or None.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT sender_id, image_file_id FROM messages WHERE id = %s AND deleted_at IS NULL", + (message_id,), + ) + row = cursor.fetchone() + if not row or row["sender_id"] != sender_id: + return None + cursor.execute( + "UPDATE messages SET deleted_at = NOW() WHERE id = %s", + (message_id,), + ) + # Clear per-recipient ciphertext + cursor.execute( + "UPDATE message_recipients SET encrypted_content = %s WHERE message_id = %s", + (b"", message_id), + ) + conn.commit() + return {"image_file_id": row.get("image_file_id")} + finally: + conn.close() + + +def set_message_image_file_id(message_id: str, file_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE messages SET image_file_id = %s WHERE id = %s", + (file_id, message_id), + ) + conn.commit() + finally: + conn.close() + + +# --- Image Uploads --- + +def create_image_upload(file_id: str, conversation_id: str, uploader_id: str, file_size: int): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO image_uploads (file_id, conversation_id, uploader_id, file_size) " + "VALUES (%s, %s, %s, %s)", + (file_id, conversation_id, uploader_id, file_size), + ) + conn.commit() + finally: + conn.close() + + +def complete_image_upload(file_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE image_uploads SET completed = TRUE WHERE file_id = %s", + (file_id,), + ) + conn.commit() + finally: + conn.close() + + +def get_image_upload(file_id: str) -> dict | None: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT file_id, conversation_id, uploader_id, file_size, completed, created_at " + "FROM image_uploads WHERE file_id = %s", + (file_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def delete_image_upload(file_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM image_uploads WHERE file_id = %s", (file_id,)) + conn.commit() + finally: + conn.close() + + +# --- User Profiles --- + +def create_default_profile(user_id: str): + """Create a default profile for a new user.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)", + (user_id,), + ) + conn.commit() + finally: + conn.close() + + +def get_user_profile(user_id: str, viewer_id: str | None = None) -> dict | None: + """Get user profile joined with user info. Respects visibility if viewer is different user.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT u.id AS user_id, u.username, u.email, u.created_at, " + "p.phone, p.phone_visible, p.email_visible, p.location, " + "p.location_visible, p.avatar_file, p.updated_at " + "FROM users u LEFT JOIN user_profiles p ON u.id = p.user_id " + "WHERE u.id = %s", + (user_id,), + ) + row = cursor.fetchone() + if not row: + return None + # If viewing someone else's profile, apply visibility rules + if viewer_id and viewer_id != user_id: + if not row.get("email_visible"): + row["email"] = None + if not row.get("phone_visible"): + row["phone"] = None + if not row.get("location_visible"): + row["location"] = None + return row + finally: + conn.close() + + +def update_user_profile(user_id: str, **fields): + """Upsert user profile fields. Allowed: phone, phone_visible, email_visible, + location, location_visible, avatar_file.""" + allowed = {"phone", "phone_visible", "email_visible", "location", + "location_visible", "avatar_file"} + filtered = {k: v for k, v in fields.items() if k in allowed} + if not filtered: + return + conn = get_connection() + try: + cursor = conn.cursor() + # Upsert: insert default then update + cursor.execute( + "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)", + (user_id,), + ) + set_clause = ", ".join(f"{k} = %s" for k in filtered) + values = list(filtered.values()) + [user_id] + cursor.execute( + f"UPDATE user_profiles SET {set_clause} WHERE user_id = %s", + values, + ) + conn.commit() + finally: + conn.close() + + +def batch_reencrypt_messages(user_id: str, updates: list[dict]): + """Batch upsert message_recipients rows with self-encryption key data. + + Each update: {message_id, encrypted_content (bytes), nonce (bytes)}. + Sets ratchet_header to '{"self":true}' and clears x3dh_header. + Uses INSERT ... ON DUPLICATE KEY UPDATE so it works for both sent messages + (which already have a SELF_DEVICE_ID row) and received messages (which don't). + """ + if not updates: + return + conn = get_connection() + try: + cursor = conn.cursor() + self_header = b'{"self":true}' + for u in updates: + cursor.execute( + "INSERT INTO message_recipients " + "(message_id, user_id, device_id, encrypted_content, nonce, ratchet_header, x3dh_header) " + "VALUES (%s, %s, %s, %s, %s, %s, NULL) " + "ON DUPLICATE KEY UPDATE encrypted_content = VALUES(encrypted_content), " + "nonce = VALUES(nonce), ratchet_header = VALUES(ratchet_header), x3dh_header = NULL", + (u["message_id"], user_id, SELF_DEVICE_ID, + u["encrypted_content"], u["nonce"], self_header), + ) + conn.commit() + finally: + conn.close() + + +# --- Phantom Users --- + +def create_phantom_user(email: str) -> dict: + """Create a phantom user with valid crypto keys for X3DH. + + Phantom users have rsa_public_key = 'PHANTOM' as a marker. + Returns user dict: {id, username, email, identity_key}. + """ + username = email.split("@")[0] + user_id = generate_uuid() + + # Generate real crypto keys so X3DH works on the client side + ik_private, ik_public = generate_identity_keypair() + ik_public_bytes = serialize_ed25519_public(ik_public) + + spk = generate_signed_prekey(ik_private) + spk_pub_bytes = serialize_x25519_public(spk["public"]) + spk_sig = spk["signature"] + + opks = generate_one_time_prekeys(count=5) + + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO users (id, username, email, rsa_public_key, identity_key) " + "VALUES (%s, %s, %s, %s, %s)", + (user_id, username, email, "PHANTOM", ik_public_bytes), + ) + cursor.execute( + "INSERT INTO signed_prekeys (id, user_id, public_key, signature) VALUES (%s, %s, %s, %s)", + (spk["id"], user_id, spk_pub_bytes, spk_sig), + ) + for opk in opks: + cursor.execute( + "INSERT INTO one_time_prekeys (id, user_id, public_key) VALUES (%s, %s, %s)", + (opk["id"], user_id, serialize_x25519_public(opk["public"])), + ) + conn.commit() + return {"id": user_id, "username": username, "email": email, "identity_key": ik_public_bytes} + finally: + conn.close() + + +def is_phantom_user(user_id: str) -> bool: + """Check if a user is a phantom (rsa_public_key == 'PHANTOM').""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT rsa_public_key FROM users WHERE id = %s", (user_id,)) + row = cursor.fetchone() + return row is not None and row[0] == "PHANTOM" + finally: + conn.close() + + +def delete_phantom_user(user_id: str): + """Delete a phantom user. CASCADE removes signed_prekeys, one_time_prekeys, + conversation_members, message_recipients, etc.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM users WHERE id = %s AND rsa_public_key = %s", + (user_id, "PHANTOM"), + ) + conn.commit() + finally: + conn.close() + + +def upgrade_phantom_user(phantom_id: str, username: str, rsa_public_key_pem: str, + identity_key: bytes) -> str | None: + """Upgrade a phantom user to a real user in-place. + + Preserves user_id and all FK references (conversation_members, group_invitations, etc.). + Deletes phantom's server-generated prekeys (real user will upload own on first login). + Returns phantom_id as the new user_id, or None if phantom no longer exists. + """ + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET username = %s, rsa_public_key = %s, identity_key = %s " + "WHERE id = %s AND rsa_public_key = 'PHANTOM'", + (username, rsa_public_key_pem, identity_key, phantom_id), + ) + if cursor.rowcount == 0: + conn.rollback() + return None + # Remove phantom's server-generated crypto keys — real user uploads own + cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s", (phantom_id,)) + cursor.execute("DELETE FROM one_time_prekeys WHERE user_id = %s", (phantom_id,)) + conn.commit() + return phantom_id + finally: + conn.close() + + +def get_all_phantom_user_ids() -> set[str]: + """Return set of all phantom user IDs (for server startup cache).""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT id FROM users WHERE rsa_public_key = %s", ("PHANTOM",)) + return {row[0] for row in cursor.fetchall()} + finally: + conn.close() + + +def cleanup_stale_phantoms(max_age_days: int = 30) -> int: + """Delete phantom users older than max_age_days with no active conversations with real users.""" + conn = get_connection() + try: + cursor = conn.cursor() + # Two-step: SELECT ids first, then DELETE. + # MySQL error 1093: can't DELETE from table referenced in subquery. + cursor.execute(""" + SELECT u.id FROM users u + WHERE u.rsa_public_key = 'PHANTOM' + AND u.created_at < DATE_SUB(NOW(), INTERVAL %s DAY) + AND NOT EXISTS ( + SELECT 1 FROM conversation_members cm1 + JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id + JOIN users u2 ON cm2.user_id = u2.id + WHERE cm1.user_id = u.id + AND u2.rsa_public_key != 'PHANTOM' + ) + """, (max_age_days,)) + ids = [row[0] for row in cursor.fetchall()] + if not ids: + return 0 + cursor.execute( + "DELETE FROM users WHERE id IN (%s)" % ",".join(["%s"] * len(ids)), + ids, + ) + deleted = cursor.rowcount + conn.commit() + return deleted + finally: + conn.close() + + +def remove_conversation_member_atomic(conversation_id: str, user_id: str) -> bool: + """Remove member and return True if actually removed (row existed). M6 TOCTOU fix.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def get_stale_uploads(max_age_seconds: int = 3600) -> list[dict]: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT file_id FROM image_uploads " + "WHERE completed = FALSE AND created_at < DATE_SUB(NOW(), INTERVAL %s SECOND)", + (max_age_seconds,), + ) + return cursor.fetchall() + finally: + conn.close() + + +# --------------------------------------------------------------------------- +# Metadata retention cleanup +# --------------------------------------------------------------------------- + +def cleanup_old_reads(days: int = 90, batch_size: int = 10000) -> int: + """Delete message_reads older than N days in batches. + + Only deletes reads for messages whose created_at is also past the retention + window. This prevents phantom unreads: get_unread_counts uses the same + time window (max_age_days) so messages outside the window aren't counted. + """ + total = 0 + while True: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM message_reads " + "WHERE read_at < DATE_SUB(NOW(), INTERVAL %s DAY) " + "AND message_id IN (" + " SELECT id FROM messages " + " WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY)" + ") LIMIT %s", + (days, days, batch_size), + ) + count = cursor.rowcount + conn.commit() + total += count + if count < batch_size: + break + finally: + conn.close() + return total + + +def cleanup_old_reactions(days: int = 90, batch_size: int = 10000) -> int: + """Delete message_reactions older than N days in batches.""" + total = 0 + while True: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM message_reactions WHERE created_at < DATE_SUB(NOW(), INTERVAL %s DAY) LIMIT %s", + (days, batch_size), + ) + count = cursor.rowcount + conn.commit() + total += count + if count < batch_size: + break + finally: + conn.close() + return total diff --git a/gemini.md b/gemini.md new file mode 100644 index 0000000..48b422b --- /dev/null +++ b/gemini.md @@ -0,0 +1,152 @@ +# Gemini Advanced Roadmap: Beyond the Basics + +Tento dokument obsahuje pokročilé návrhy na vylepšení bezpečnosti, architektury a UX aplikace `encrypted_chat`. Tyto body jdou nad rámec běžného "best practice" a směřují k funkcionalitě profesionálních secure messengerů (Signal, Threema, Wire) se zaměřením na ochranu metadat a anti-forenzní techniky. + +--- + +## 1. Ochrana Metadat & Traffic Analysis Resistance +*Cíl: Server by neměl vědět, KDO s KÝM komunikuje, ani JAKÝ typ dat si posílají.* + +### Sealed Sender (Odesílatel v obálce) +- **Koncept:** Server zná pouze `recipient_id`. Identita odesílatele (`sender_id`) je zašifrována uvnitř zprávy (v "obálce"), kterou server nedokáže přečíst. +- **Implementace:** + 1. Odesílatel vygeneruje klíč pro obálku (např. z profilu příjemce). + 2. Zabalí `sender_id` a payload do šifrovaného bloku. + 3. Server doručí blob příjemci bez ověření odesílatele (ověření proběhne až na klientovi po rozbalení). + 4. **Výhoda:** Při kompromitaci serveru útočník nevidí sociální graf (kdo se s kým baví). + +### Traffic Padding & Constant Bitrate +- **Problém:** Délka paketu prozrazuje obsah (krátký paket = "Ahoj", dlouhý paket = obrázek/klíč). Intervaly prozrazují aktivitu. +- **Řešení:** + 1. **Padding:** Všechny zprávy doplňovat náhodnými daty na fixní velikosti (např. bloky 4KB). + 2. **Dummy Traffic (Chaff):** Klient náhodně odesílá "falešné" pakety na server, které server zahodí nebo vrátí (echo). + 3. **Výhoda:** Pro síťového analytika (ISP) vypadá tok dat jako konstantní šum. + +--- + +## 2. Anti-Forenzní Ochrana (Client-side) +*Cíl: Minimalizovat dopad fyzického zabavení zařízení nebo vynuceného odemčení.* + +### Duress Password (Heslo pod nátlakem) +- **Funkce:** Uživatel si nastaví *druhé* heslo. +- **Chování:** Pokud se přihlásí tímto heslem: + - **Varianta A (Decoy):** Odemkne se prázdná nebo falešná databáze s neškodnými konverzacemi. + - **Varianta B (Panic):** Aplikace na pozadí tiše provede **secure wipe** (přepis) privátních klíčů a reálné DB, zatímco uživateli zobrazí "Connection Error". + +### Secure Deletion & DB Vacuuming +- **Problém:** SQL `DELETE` data nesmaže fyzicky, jen označí místo jako volné. +- **Řešení:** + 1. Před smazáním zprávy přepsat obsah náhodnými byty (`UPDATE messages SET content = random_blob WHERE id = ...`). + 2. Pravidelně spouštět `VACUUM` (u SQLite) nebo optimalizaci tabulek. + 3. Pro soubory (obrázky) použít bezpečné mazání (overwrite passes) před `os.unlink()`. + +### Disappearing Messages (TTL) +- **Funkce:** Odesílatel nastaví životnost zprávy (např. 1 minuta). +- **Implementace:** Odpočet začíná okamžikem zobrazení (Read Receipt). Po uplynutí času klient data nenávratně smaže z disku (včetně secure wipe). Server maže ihned po doručení. + +--- + +## 3. Infrastruktura & Škálování +*Cíl: Odlehčit Python procesu a databázi pro podporu tisíců uživatelů.* + +### Object Storage (MinIO / S3) + Presigned URLs +- **Problém:** `server.py` blokuje I/O při příjmu velkých souborů. +- **Řešení:** + 1. Klient požádá server o upload. + 2. Server vygeneruje **Presigned PUT URL** (časově omezený token pro přímý upload do MinIO/S3). + 3. Klient nahrává data přímo do úložiště (obchází aplikační server). + 4. Server ukládá pouze odkaz (URL/Key). +- **Výhoda:** Masivní zrychlení, server řeší jen metadata. + +### Read/Write Splitting (MySQL Replication) +- **Architektura:** + - **Master DB:** Pouze pro `INSERT`, `UPDATE`, `DELETE`. + - **Read Replicas (Slaves):** Pro těžké `SELECT` dotazy (historie zpráv, hledání). +- **Implementace v `db.py`:** Router, který podle typu dotazu volí connection pool. + +--- + +## 4. Protokol & Funkce +*Cíl: Rozšíření možností komunikace bez nutnosti centralizovaného streamování.* + +### P2P Volání (WebRTC Signalizace) +- **Koncept:** Využít existující bezpečný kanál (Double Ratchet) pro výměnu SDP (Session Description Protocol) paketů. +- **Flow:** + 1. Alice pošle Bobovi zašifrovanou zprávu typu `CALL_OFFER` s parametry WebRTC. + 2. Bob odpoví `CALL_ANSWER`. + 3. Klienti si vymění `ICE_CANDIDATES` (IP adresy/porty) a naváží přímé P2P spojení (UDP). + 4. Audio/Video stream (SRTP) jde mimo server. + +### Diferenciální Synchronizace (Merkle Trees) +- **Problém:** Stahování seznamu kontaktů (`get_user_contacts`) je pomalé při velkém množství dat. +- **Řešení:** Klient a server si udržují Hash Tree (Merkle Tree) stavu. Při synchronizaci porovnají pouze root hash. Pokud se liší, stahují se jen změněné větve stromu (delta update). + +--- + +## 5. UI/UX (PyQt Speciality) +*Cíl: Ochrana soukromí na úrovni OS a skrytá komunikace.* + +### Privacy Overlay (Task Switcher) +- **Funkce:** Detekovat událost ztráty fokusu okna (`QEvent.WindowDeactivate`) nebo minimalizace. +- **Akce:** Překrýt obsah okna rozmazaným efektem (`QGraphicsBlurEffect`) nebo logem aplikace. +- **Důvod:** Zabrání operačnímu systému (Windows/Linux/macOS) vytvořit čitelný náhled okna v Alt+Tab menu nebo v historii aktivit. + +### Steganografie +- **Funkce:** Ukrýt šifrovanou zprávu do obrazových dat nevinného obrázku (např. meme kočky). +- **Implementace:** Modifikace LSB (Least Significant Bit) pixelů obrázku. +- **Výhoda:** Pro síťového admina nebo forenzní analýzu to vypadá jako běžné posílání obrázků, přítomnost šifrované komunikace je popiratelná. + +--- + +## 6. High Availability Architecture (Distribuovaný Cluster) +*Cíl: Zajištění provozu i při výpadku/napadení serveru (Active-Active "RAID 1 přes síť").* + +### Architektura: Geograficky Distribuovaný "Zero-Trust" Cluster + +#### 1. Vstupní brána (Global Traffic Manager) +- **Funkce:** Rozděluje klienty mezi dostupné servery (Round Robin / Geo-DNS). +- **Self-Healing:** Při výpadku Serveru A okamžitě přesměruje provoz na Server B. Uživatel nic nepozná. + +#### 2. Aplikační vrstva (Stateless Servers) +- **Stav:** Servery jsou **bezstavové**. `server.py` neukládá nic důležitého v RAM. +- **Škálování:** Můžete spustit N instancí serveru. Je jedno, ke kterému se uživatel připojí. +- **Komunikace:** Servery spolu mluví přes rychlý Message Bus (Redis Pub/Sub) pro doručování real-time zpráv mezi uživateli na různých uzlech. + +#### 3. Datová vrstva (Zrcadlení Dat - "RAID 1") +- **Databáze (MySQL Galera Cluster):** Synchronní multi-master replikace. Zápis na Serveru A se potvrdí, až když je fyzicky zapsán i na Serveru B (a C). + - *Efekt:* Ztráta serveru neznamená ztrátu dat (klíčů, zpráv). +- **Soubory (MinIO Cluster):** Distribuovaný Object Storage s Erasure Coding. Soubory jsou matematicky rozprostřeny přes všechny servery. Výpadek disku/serveru nevadí. + +#### 4. Bezpečnostní pojistky ("Poisoned Node") +- **Soft Delete:** Databáze nemaže data ihned, ale označuje je jako smazané (ochrana proti `DELETE *` od útočníka). +- **Client-Side Verification:** I kdyby kompromitovaný server posílal podvržené klíče, klienti ověřují digitální podpisy (Identity Keys). Server nemůže zfalšovat identitu uživatelů. + +--- + +## 7. Technický Upgrade pro Stateless Architekturu +*Cíl: Odstranit závislost na paměti procesu (RAM) pro umožnění horizontálního škálování.* + +### 1. Redis jako Distribuovaná Paměť +Nahrazení Python `dict` struktur, které jsou lokální pro jeden proces, za centrální Redis úložiště přístupné všem serverům. + +* **Párovací Session:** + * *Stav:* `pairing_sessions` (dict) -> Redis Key `pair:{code}` (Hash/String s TTL). + * *Efekt:* Uživatel může vyžádat kód na Serveru A a potvrdit ho na Serveru B. +* **Rate Limiting:** + * *Stav:* `rate_limits` (dict) -> Redis Key `rl:{ip}:{action}` (Counter s EXPIRE). + * *Efekt:* Limity platí globálně pro celý cluster, ne jen per server. + +### 2. Redis Pub/Sub pro Real-Time Routing +Doručení zprávy uživateli, který je připojen k JINÉMU serveru než odesílatel. + +* **Princip:** + 1. Server A (odesílatel) zjistí, že příjemce Bob není připojen lokálně. + 2. Server A publikuje zprávu do Redis kanálu `user:{bob_user_id}`. + 3. Server B (kde je Bob připojen) tento kanál poslouchá (subscribe). + 4. Server B přijme zprávu z Redisu a pošle ji Bobovi do otevřeného TCP socketu. + +### 3. Session Sticky vs. Stateless Uploads +Řešení pro nahrávání souborů po částech (chunks). + +* **Varianta A (Infrastructure - Sticky Sessions):** Load Balancer (HAProxy/Nginx) zajistí, že všechny požadavky od jedné IP jdou vždy na stejný server. Nejjednodušší, nevyžaduje změnu kódu. +* **Varianta B (Architectural - Direct Upload):** Viz bod 3 "Object Storage + Presigned URLs". Server vůbec nepřijímá data souboru, pouze vygeneruje token. Plně stateless řešení. diff --git a/gui_client.py b/gui_client.py new file mode 100644 index 0000000..2513061 --- /dev/null +++ b/gui_client.py @@ -0,0 +1,6338 @@ +"""PyQt6 GUI client for encrypted chat.""" + +import asyncio +import json +import logging +import os +from collections import OrderedDict + +logger = logging.getLogger(__name__) +import re +import sys +from functools import partial + +from PyQt6.QtCore import QThread, pyqtSignal, Qt, QTimer, QUrl, QSize, QRect, QPoint, QPointF +from PyQt6.QtWidgets import ( + QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, + QLineEdit, QLabel, QListWidget, QListWidgetItem, QTextEdit, + QSplitter, QMessageBox, QInputDialog, QMenu, QStackedWidget, + QDialog, QFileDialog, QScrollArea, QFrame, QSystemTrayIcon, + QSizePolicy, QStyledItemDelegate, +) +from PyQt6.QtGui import QFont, QFontMetricsF, QAction, QPixmap, QImage, QDesktopServices, QIcon, QPainter, QColor, QBrush, QPen, QShortcut, QKeySequence +from PyQt6.QtWidgets import QGraphicsDropShadowEffect +from PyQt6.QtWidgets import QStyle + +from chat_core import ChatClient, IdentityKeyChanged +from theme import ThemeManager, c, qss, tm + +# H10: Image validation limits +MAX_IMAGE_DATA_SIZE = 10 * 1024 * 1024 # 10 MB max raw image data +MAX_IMAGE_DIMENSION = 8192 # 8K pixels max + + +def _safe_load_image(data: bytes) -> QImage | None: + """Load image with size and dimension validation (H10).""" + if not data or len(data) > MAX_IMAGE_DATA_SIZE: + return None + qimg = QImage.fromData(data) + if qimg.isNull(): + return None + if qimg.width() > MAX_IMAGE_DIMENSION or qimg.height() > MAX_IMAGE_DIMENSION: + return None + return qimg + + +def _safe_filename(name: str, default: str = "file") -> str: + """Sanitize filename — strip path components, prevent traversal (H11).""" + name = os.path.basename(name) + name = name.replace("\x00", "") + return name if name else default + + +_AVATAR_CACHE_MAX = 512 # max cached avatar pixmaps per cache + + +class _LRUPixmapCache: + """Simple LRU cache for QPixmap objects with a fixed max size.""" + + def __init__(self, maxsize: int = _AVATAR_CACHE_MAX): + self._data: OrderedDict[str, QPixmap] = OrderedDict() + self._maxsize = maxsize + + def get(self, key: str) -> QPixmap | None: + if key in self._data: + self._data.move_to_end(key) + return self._data[key] + return None + + def put(self, key: str, value: QPixmap) -> None: + if key in self._data: + self._data.move_to_end(key) + self._data[key] = value + else: + self._data[key] = value + if len(self._data) > self._maxsize: + self._data.popitem(last=False) + + def __contains__(self, key: str) -> bool: + return key in self._data + + def __getitem__(self, key: str) -> QPixmap: + self._data.move_to_end(key) + return self._data[key] + + def __setitem__(self, key: str, value: QPixmap) -> None: + self.put(key, value) + + def clear(self) -> None: + self._data.clear() + + +# URL regex: matches http:// and https:// URLs in raw (not-yet-escaped) text +_URL_RE = re.compile( + r'(https?://[^\s<>"\')\]]+)', + re.IGNORECASE, +) +_URL_TRAILING_PUNCT = re.compile(r'[.,;:!?]+$') + + +def _linkify_urls(raw_text: str, https_color: str | None = None, + http_color: str | None = None) -> str: + """HTML-escape text and convert URLs into clickable tags. + + HTTPS links get blue styling. HTTP links get orange + unlock icon warning. + Processes raw (unescaped) text — returns HTML-safe string. + """ + _https = https_color or c().link_https + _http = http_color or c().link_http + + def _esc(s): + return s.replace("&", "&").replace("<", "<").replace(">", ">") + + parts = _URL_RE.split(raw_text) + result = [] + for i, part in enumerate(parts): + if i % 2 == 1: + # URL match — strip trailing sentence punctuation + trail_m = _URL_TRAILING_PUNCT.search(part) + if trail_m: + url = part[:trail_m.start()] + trail = part[trail_m.start():] + else: + url = part + trail = "" + url_esc = _esc(url) + if url.lower().startswith("http://"): + result.append( + f'' + f'\U0001f513 {url_esc}' + ) + else: + result.append( + f'' + f'{url_esc}' + ) + if trail: + result.append(_esc(trail)) + else: + result.append(_esc(part)) + return "".join(result) + + +def setup_logging(): + level_name = os.getenv("LOG_LEVEL", "WARNING").upper() + level = getattr(logging, level_name, logging.WARNING) + logging.basicConfig(level=level, format="%(levelname)s: %(message)s") + + + +MAX_INPUT_CHARS = int(os.getenv("MAX_INPUT_CHARS", "2000")) + +# Custom item data roles for conversation list delegate +ROLE_CONV_ID = Qt.ItemDataRole.UserRole +ROLE_DISPLAY_NAME = Qt.ItemDataRole.UserRole + 1 +ROLE_PREVIEW = Qt.ItemDataRole.UserRole + 2 # last message preview text +ROLE_TIMESTAMP = Qt.ItemDataRole.UserRole + 3 # last message relative time +ROLE_UNREAD = Qt.ItemDataRole.UserRole + 4 # int unread count +ROLE_IS_FAV = Qt.ItemDataRole.UserRole + 5 # bool +ROLE_AVATAR = Qt.ItemDataRole.UserRole + 6 # QPixmap (circular, 44px) +ROLE_VERIFIED = Qt.ItemDataRole.UserRole + 7 # str: "verified", "trusted", "" (DMs only) +ROLE_RECEIPT = Qt.ItemDataRole.UserRole + 8 # str: "read", "delivered", "sent", "" (own msgs only) + + +def _relative_time(ts: str) -> str: + """Convert ISO timestamp to relative time string for conversation list.""" + if not ts or len(ts) < 16: + return "" + try: + from datetime import datetime, timezone + # Parse "YYYY-MM-DD HH:MM:SS" or "YYYY-MM-DDTHH:MM:SS" + clean = ts.replace("T", " ")[:19] + msg_time = datetime.strptime(clean, "%Y-%m-%d %H:%M:%S").replace( + tzinfo=timezone.utc + ) + now = datetime.now(timezone.utc) + diff = now - msg_time + secs = int(diff.total_seconds()) + if secs < 60: + return "now" + if secs < 3600: + return f"{secs // 60}m" + if secs < 86400: + return f"{secs // 3600}h" + days = secs // 86400 + if days == 1: + return "Yesterday" + if days < 7: + return msg_time.strftime("%a") # Mon, Tue, ... + return msg_time.strftime("%d/%m") + except Exception: + return ts[11:16] if len(ts) >= 16 else "" + + +def _format_msg_time(ts: str) -> str: + """Format timestamp for message bubbles: HH:MM today, 'Yesterday HH:MM', + 'Mon HH:MM' this week, 'DD.MM. HH:MM' older.""" + if not ts or len(ts) < 16: + return "" + try: + from datetime import datetime, timezone + clean = ts.replace("T", " ")[:19] + msg_time = datetime.strptime(clean, "%Y-%m-%d %H:%M:%S").replace( + tzinfo=timezone.utc + ) + now = datetime.now(timezone.utc) + hhmm = msg_time.strftime("%H:%M") + if msg_time.date() == now.date(): + return hhmm + diff_days = (now.date() - msg_time.date()).days + if diff_days == 1: + return f"Yesterday {hhmm}" + if diff_days < 7: + return f"{msg_time.strftime('%a')} {hhmm}" + if msg_time.year == now.year: + return f"{msg_time.day}.{msg_time.month}. {hhmm}" + return f"{msg_time.day}.{msg_time.month}.{msg_time.year} {hhmm}" + except Exception: + return ts[11:16] if len(ts) >= 16 else "" + + +class ConversationDelegate(QStyledItemDelegate): + """Custom delegate that paints Signal/Telegram-style conversation rows.""" + + ITEM_HEIGHT = 68 + AVATAR_SIZE = 44 + BADGE_SIZE = 20 + HPAD = 10 + VPAD = 10 + + def sizeHint(self, option, index): + return QSize(option.rect.width(), self.ITEM_HEIGHT) + + def paint(self, painter, option, index): + painter.save() + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + t = c() + + rect = option.rect + is_selected = bool(option.state & QStyle.StateFlag.State_Selected) + is_hover = bool(option.state & QStyle.StateFlag.State_MouseOver) + + # Background + if is_selected: + painter.fillRect(rect, QColor(t.bg_selected)) + elif is_hover: + painter.fillRect(rect, QColor(t.bg_hover)) + + # Data from item roles + name = index.data(ROLE_DISPLAY_NAME) or "" + preview = index.data(ROLE_PREVIEW) or "" + timestamp = index.data(ROLE_TIMESTAMP) or "" + unread = index.data(ROLE_UNREAD) or 0 + is_fav = index.data(ROLE_IS_FAV) or False + avatar_pix = index.data(ROLE_AVATAR) + verified = index.data(ROLE_VERIFIED) or "" + + x = rect.x() + self.HPAD + y = rect.y() + self.VPAD + avail_w = rect.width() - 2 * self.HPAD + + # -- Avatar (left) -- + av_y = rect.y() + (rect.height() - self.AVATAR_SIZE) // 2 + if avatar_pix and not avatar_pix.isNull(): + painter.drawPixmap(x, av_y, self.AVATAR_SIZE, self.AVATAR_SIZE, + avatar_pix) + else: + painter.setBrush(QBrush(QColor(t.bg_secondary))) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(x, av_y, self.AVATAR_SIZE, self.AVATAR_SIZE) + painter.setPen(QColor(t.text_muted)) + f = QFont() + f.setPointSize(14) + f.setBold(True) + painter.setFont(f) + painter.drawText( + QRect(x, av_y, self.AVATAR_SIZE, self.AVATAR_SIZE), + Qt.AlignmentFlag.AlignCenter, + name[0].upper() if name else "?", + ) + + text_x = x + self.AVATAR_SIZE + 10 + text_w = avail_w - self.AVATAR_SIZE - 10 + + # -- Timestamp (top-right) -- + ts_w = 50 + painter.setPen(QColor(t.text_muted)) + ts_font = QFont() + ts_font.setPointSize(8) + painter.setFont(ts_font) + ts_rect = QRect( + rect.x() + rect.width() - self.HPAD - ts_w, + y, ts_w, 18, + ) + painter.drawText(ts_rect, Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter, timestamp) + + name_w = text_w - ts_w - 8 + + # -- Name (line 1) -- + name_font = QFont() + name_font.setPointSize(10) + if unread > 0: + name_font.setBold(True) + painter.setFont(name_font) + painter.setPen(QColor(t.text_primary)) + display_name = name + if is_fav: + display_name = f"\u2605 {name}" + # Elide name to fit + fm = painter.fontMetrics() + elided_name = fm.elidedText(display_name, Qt.TextElideMode.ElideRight, name_w) + painter.drawText(text_x, y, name_w, 22, + Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter, + elided_name) + + # -- Verification badge (after name) -- + if verified == "verified": + name_text_w = fm.horizontalAdvance(elided_name) + badge_x = text_x + name_text_w + 4 + badge_y_center = y + 11 + painter.setPen(Qt.PenStyle.NoPen) + painter.setBrush(QBrush(QColor(t.success))) + painter.drawEllipse(badge_x, badge_y_center - 5, 10, 10) + # Checkmark inside circle + painter.setPen(QPen(QColor(t.bg_primary), 1.5)) + painter.drawLine(badge_x + 2, badge_y_center, badge_x + 4, badge_y_center + 2) + painter.drawLine(badge_x + 4, badge_y_center + 2, badge_x + 8, badge_y_center - 3) + + # -- Preview (line 2) -- + preview_y = y + 24 + preview_font = QFont() + preview_font.setPointSize(9) + painter.setFont(preview_font) + preview_w = text_w - (self.BADGE_SIZE + 8 if unread > 0 else 0) + fm2 = painter.fontMetrics() + receipt = index.data(ROLE_RECEIPT) or "" + preview_x = text_x + if receipt: + check_color = QColor(t.success) if receipt == "read" else QColor(t.text_muted) + painter.setPen(check_color) + single_w = fm2.horizontalAdvance("\u2713") + overlap = single_w * 0.4 + cy = preview_y + 10 # vertical center of 20px row + # First check + painter.drawText(int(preview_x), preview_y, int(single_w), 20, + Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter, + "\u2713") + if receipt in ("delivered", "read"): + # Second check, overlapping + x2 = preview_x + single_w - overlap + painter.drawText(int(x2), preview_y, int(single_w), 20, + Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter, + "\u2713") + total_w = single_w * 2 - overlap + 4 + else: + total_w = single_w + 4 + preview_x += total_w + preview_w -= total_w + preview_x = int(preview_x) + preview_w = int(preview_w) + painter.setPen(QColor(t.text_muted)) + elided_preview = fm2.elidedText(preview, Qt.TextElideMode.ElideRight, preview_w) + painter.drawText(preview_x, preview_y, preview_w, 20, + Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter, + elided_preview) + + # -- Unread badge (bottom-right) -- + if unread > 0: + badge_x = rect.x() + rect.width() - self.HPAD - self.BADGE_SIZE + badge_y = preview_y + 1 + painter.setBrush(QBrush(QColor(t.accent))) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawRoundedRect(badge_x, badge_y, self.BADGE_SIZE, self.BADGE_SIZE, 10, 10) + painter.setPen(QColor(t.accent_text)) + badge_font = QFont() + badge_font.setPointSize(7) + badge_font.setBold(True) + painter.setFont(badge_font) + badge_text = str(unread) if unread < 100 else "99+" + painter.drawText( + QRect(badge_x, badge_y, self.BADGE_SIZE, self.BADGE_SIZE), + Qt.AlignmentFlag.AlignCenter, badge_text, + ) + + # -- Bottom separator line -- + painter.setPen(QColor(t.separator)) + painter.drawLine(text_x, rect.bottom(), rect.right() - self.HPAD, rect.bottom()) + + painter.restore() + + +class _ReceiptFooter(QWidget): + """Tiny widget that draws timestamp + receipt checkmarks with tight spacing.""" + + def __init__(self, time_str: str, status: str, + time_color: str, check_color: str, read_color: str, + parent=None): + super().__init__(parent) + self._time = time_str + self._status = status # "", "sent", "delivered", "read" + self._time_color = QColor(time_color) + self._check_color = QColor(check_color) + self._read_color = QColor(read_color) + self._font = QFont() + self._font.setPointSize(8) + fm = QFontMetricsF(self._font) + tw = fm.horizontalAdvance(self._time + " ") + cw = fm.horizontalAdvance("\u2713") + # 2nd check overlaps 1st by 40% of its width + overlap = cw * 0.4 + checks_w = 0.0 + if status == "sent": + checks_w = cw + elif status in ("delivered", "read"): + checks_w = cw * 2 - overlap + total_w = tw + checks_w + 2 + h = fm.height() + 2 + self.setFixedSize(int(total_w + 1), int(h + 1)) + + def paintEvent(self, event): + p = QPainter(self) + p.setRenderHint(QPainter.RenderHint.Antialiasing) + p.setFont(self._font) + fm = QFontMetricsF(self._font) + y_base = fm.ascent() + 1 + + # Draw time + p.setPen(self._time_color) + p.drawText(QPointF(0, y_base), self._time) + x = fm.horizontalAdvance(self._time) + 4 + + if not self._status: + p.end() + return + + cw = fm.horizontalAdvance("\u2713") + overlap = cw * 0.4 + color = self._read_color if self._status == "read" else self._check_color + + # First check + p.setPen(color) + p.drawText(QPointF(x, y_base), "\u2713") + + # Second check (tight overlap) + if self._status in ("delivered", "read"): + p.drawText(QPointF(x + cw - overlap, y_base), "\u2713") + + p.end() + + +class MessageInput(QTextEdit): + """Multiline message input: Enter sends, Shift+Enter inserts newline.""" + send_requested = pyqtSignal() + file_dropped = pyqtSignal(str) + + @staticmethod + def _style_normal(): + t = c() + return ( + f"QTextEdit {{ background-color: {t.bg_secondary}; border: 1px solid {t.border}; " + f"border-radius: 18px; padding: 8px 14px; color: {t.text_primary}; }}" + f"QTextEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + + @staticmethod + def _style_drop(): + t = c() + return ( + f"QTextEdit {{ background-color: {t.bg_secondary}; border: 2px dashed {t.accent}; " + f"border-radius: 18px; padding: 8px 14px; color: {t.text_primary}; }}" + ) + + def __init__(self, parent=None): + super().__init__(parent) + self.setAcceptRichText(False) + self.setPlaceholderText("Type a message...") + self.setMinimumHeight(52) + self.setMaximumHeight(120) + self.setAcceptDrops(True) + self.drop_enabled = False + self.setStyleSheet(self._style_normal()) + self.textChanged.connect(self._auto_resize) + # Tight line spacing — set on document default cursor format + from PyQt6.QtGui import QTextBlockFormat + fmt = QTextBlockFormat() + fmt.setTopMargin(0) + fmt.setBottomMargin(0) + fmt.setLineHeight(0, 0) # 0 = SingleHeight + self._block_fmt = fmt + # Apply to default block format + cursor = self.textCursor() + cursor.setBlockFormat(fmt) + self.setTextCursor(cursor) + + def _auto_resize(self): + doc_height = int(self.document().size().height()) + 16 # padding + new_h = max(52, min(doc_height, 120)) + self.setFixedHeight(new_h) + + def keyPressEvent(self, event): + if event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter): + if event.modifiers() & Qt.KeyboardModifier.ShiftModifier: + # Insert plain newline instead of new paragraph + self.textCursor().insertText("\n") + return + else: + self.send_requested.emit() + return + super().keyPressEvent(event) + + def dragEnterEvent(self, event): + if not self.drop_enabled: + event.ignore() + return + if event.mimeData().hasUrls() and any(u.isLocalFile() for u in event.mimeData().urls()): + event.acceptProposedAction() + self.setStyleSheet(self._style_drop()) + else: + super().dragEnterEvent(event) + + def dragMoveEvent(self, event): + if event.mimeData().hasUrls(): + event.acceptProposedAction() + else: + super().dragMoveEvent(event) + + def dragLeaveEvent(self, event): + self.setStyleSheet(self._style_normal()) + super().dragLeaveEvent(event) + + def dropEvent(self, event): + self.setStyleSheet(self._style_normal()) + if event.mimeData().hasUrls(): + for url in event.mimeData().urls(): + if url.isLocalFile(): + self.file_dropped.emit(url.toLocalFile()) + event.acceptProposedAction() + else: + super().dropEvent(event) + + +class AsyncBridge(QThread): + """Runs asyncio event loop in a background thread, emits Qt signals.""" + connected = pyqtSignal() + connection_error = pyqtSignal(str) + login_result = pyqtSignal(bool, str) + register_result = pyqtSignal(bool, str) + conversations_loaded = pyqtSignal(list) + messages_loaded = pyqtSignal(str, list) # conv_id, messages + older_messages_loaded = pyqtSignal(str, list) # conv_id, older messages + message_sent = pyqtSignal(bool, str) + message_sent_payload = pyqtSignal(str, dict) # conv_id, message dict (for local append) + new_notification = pyqtSignal(dict) # decrypted payload + pairing_code = pyqtSignal(str) + pairing_complete = pyqtSignal(bool, str) + add_member_result = pyqtSignal(bool, str) + remove_member_result = pyqtSignal(bool, str) + authorize_result = pyqtSignal(bool, str) + rotate_result = pyqtSignal(bool, str) + reencrypt_status = pyqtSignal(str) + messages_read_notification = pyqtSignal(dict) + message_delivered_notification = pyqtSignal(dict) + message_deleted_notification = pyqtSignal(dict) + image_sent = pyqtSignal(bool, str) + image_downloaded = pyqtSignal(str, bytes) # file_id, decrypted bytes + delete_message_result = pyqtSignal(bool, str) + reconnected = pyqtSignal() + conversation_updated = pyqtSignal() + connection_state_changed = pyqtSignal(str) # "connected", "disconnected", "reconnecting" + profile_loaded = pyqtSignal(dict) + profile_updated = pyqtSignal(bool, str) + avatar_loaded = pyqtSignal(str, bytes) # user_id, avatar_bytes + online_status_changed = pyqtSignal(str, bool) # user_id, is_online + online_users_loaded = pyqtSignal(list) # list of user_ids + invitations_loaded = pyqtSignal(list) # list of invitation dicts + invitation_result = pyqtSignal(bool, str) # ok, message + invitation_received = pyqtSignal(dict) # invitation notification data + group_avatar_loaded = pyqtSignal(str, bytes) # conv_id, avatar_bytes + group_avatar_updated = pyqtSignal(bool, str) # ok, message + session_reset_notification = pyqtSignal(str, str) # from_user_id, from_device_id + reaction_result = pyqtSignal(bool, str) # ok, message + reaction_notification = pyqtSignal(dict) # {message_id, conversation_id, user_id, reaction, action} + pin_notification = pyqtSignal(dict) # {message_id, conversation_id, user_id, action=pin} + unpin_notification = pyqtSignal(dict) # {message_id, conversation_id, user_id, action=unpin} + pinned_messages_loaded = pyqtSignal(str, list) # conv_id, list of pinned msg dicts + forward_result = pyqtSignal(bool, str) # ok, message + key_change_warning = pyqtSignal(str, str, str, bool, bytes) # user_id, username, old_key_hex, was_verified, new_key_bytes + password_changed = pyqtSignal(bool, str) # ok, message + username_changed = pyqtSignal(bool, str) # ok, message + + def __init__(self): + super().__init__() + self.client = ChatClient() + self.loop: asyncio.AbstractEventLoop | None = None + self._running = True + self.client._reencrypt_progress_cb = self._emit_reencrypt_status + self.client._key_change_cb = self._emit_key_change_warning + self._ready: asyncio.Event | None = None + self._avatar_inflight: set[str] = set() + self._group_avatar_inflight: set[str] = set() + self._invitations_inflight = False + + def _emit_reencrypt_status(self, message: str): + self.reencrypt_status.emit(message) + + def _emit_key_change_warning(self, user_id: str, username: str, old_key_hex: str, was_verified: bool, new_key_bytes: bytes = b""): + self.key_change_warning.emit(user_id, username, old_key_hex, was_verified, new_key_bytes) + + def run(self): + if sys.platform == "win32": + self.loop = asyncio.SelectorEventLoop() + else: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self._ready = asyncio.Event() + try: + self.loop.run_until_complete(self._run()) + except Exception as e: + logger.error("AsyncBridge loop crashed: %s", e, exc_info=True) + finally: + self.loop.close() + + async def _run(self): + try: + await self.client.connect() + self.client._listener_task = asyncio.create_task(self.client._background_listener()) + if self._ready: + self._ready.set() + self.connected.emit() + self.connection_state_changed.emit("connected") + except Exception as e: + self.connection_error.emit(str(e)) + return + + # Process notifications + await self._notification_loop() + + async def _notification_loop(self): + while self._running: + try: + # Check if listener task died (connection lost) + if (self.client._listener_task and self.client._listener_task.done() + and not self.client.connected): + self.connection_state_changed.emit("disconnected") + if self.client.session: + await self._auto_reconnect() + continue + + notif = await asyncio.wait_for( + self.client._notification_queue.get(), timeout=0.5 + ) + notif_type = notif.get("type", "") + data = notif.get("data", {}) + if notif_type in ("conversation_created", "member_added", "member_removed", + "conversation_renamed"): + self.conversation_updated.emit() + elif notif_type == "group_invitation": + self.invitation_received.emit(data) + elif notif_type == "user_online": + self.online_status_changed.emit(data.get("user_id", ""), True) + elif notif_type == "user_offline": + self.online_status_changed.emit(data.get("user_id", ""), False) + elif notif_type == "online_users": + self.online_users_loaded.emit(data.get("user_ids", [])) + elif notif_type == "messages_read": + self.messages_read_notification.emit(data) + elif notif_type == "message_delivered": + self.message_delivered_notification.emit(data) + elif notif_type == "message_deleted": + self.message_deleted_notification.emit(data) + elif notif_type == "session_reset": + from_uid = data.get("from_user_id", "") + from_did = data.get("from_device_id", "") + self.client.handle_session_reset_notification(from_uid, from_did or None) + self.session_reset_notification.emit(from_uid, from_did) + elif notif_type == "username_changed": + self.conversation_updated.emit() + elif notif_type == "message_reacted": + self.reaction_notification.emit(data) + elif notif_type == "message_pinned": + self.pin_notification.emit(data) + elif notif_type == "message_unpinned": + self.unpin_notification.emit(data) + elif notif_type == "new_message": + try: + payload = self.client.decrypt_notification(data) + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + continue + if payload: + self.new_notification.emit(payload) + # None = control message (e.g. sender key distribution), skip silently + except asyncio.TimeoutError: + continue + except Exception as e: + logger.error("Notification loop exception: %s", e, exc_info=True) + break + + async def _auto_reconnect(self): + """Auto-reconnect with exponential backoff.""" + delay = 1 + while self._running and not self.client.connected: + self.connection_state_changed.emit("reconnecting") + try: + await self.client.reconnect() + if self.client.connected and self.client.session: + self.connection_state_changed.emit("connected") + self.conversation_updated.emit() + return + if self.client.login_rejected: + self.connection_state_changed.emit("revoked") + return + except Exception: + pass + await asyncio.sleep(delay) + delay = min(delay * 2, 30) + + def schedule(self, coro): + """Schedule a coroutine on the asyncio loop from the Qt thread.""" + if self.loop and self.loop.is_running(): + asyncio.run_coroutine_threadsafe(coro, self.loop) + else: + # Avoid "coroutine was never awaited" warnings if loop is down. + try: + coro.close() + except Exception: + pass + + async def _do_register(self, username, password, email): + if self._ready: + await self._ready.wait() + ok, code_or_msg = await self.client.register(username, password, email=email) + self.register_result.emit(ok, code_or_msg) + + async def _do_login(self, email, password): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.login(email, password) + self.login_result.emit(ok, msg) + + async def _do_logout(self): + if self._ready: + self._ready.clear() + try: + await self.client.close() + except Exception: + pass + self.client = ChatClient() + self.client._reencrypt_progress_cb = self._emit_reencrypt_status + try: + await self.client.connect() + self.client._listener_task = asyncio.create_task(self.client._background_listener()) + if self._ready: + self._ready.set() + self.reconnected.emit() + except Exception as e: + self.connection_error.emit(str(e)) + + async def _do_load_conversations(self): + if self._ready: + await self._ready.wait() + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + async def _do_load_messages(self, conv_id): + if self._ready: + await self._ready.wait() + msgs = await self.client.get_messages(conv_id) + self.messages_loaded.emit(conv_id, msgs) + + async def _do_load_older_messages(self, conv_id, offset): + if self._ready: + await self._ready.wait() + msgs = await self.client.get_messages(conv_id, limit=50, offset=offset) + self.older_messages_loaded.emit(conv_id, msgs) + + async def _do_send_message(self, conv_id, text, members, reply_to=None): + if self._ready: + await self._ready.wait() + try: + ok, result = await self.client.send_message(conv_id, text, members, reply_to=reply_to) + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + self.message_sent.emit(False, "Identity key changed — accept new key first.") + return + except Exception as e: + logger.error("send_message exception: %s", e, exc_info=True) + self.message_sent.emit(False, str(e)) + return + if ok and isinstance(result, dict): + self.message_sent.emit(True, "Message sent.") + self.message_sent_payload.emit(conv_id, result) + else: + self.message_sent.emit(ok, result if isinstance(result, str) else "Message sent.") + + async def _do_find_or_create_and_send(self, username, text): + if self._ready: + await self._ready.wait() + try: + conv_id, msg = await self.client.find_or_create_conversation(username) + if not conv_id: + self.message_sent.emit(False, msg) + return + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + members = [] + for cv in convs: + if cv["conversation_id"] == conv_id: + members = cv["members"] + break + ok, result = await self.client.send_message(conv_id, text, members) + if ok and isinstance(result, dict): + self.message_sent.emit(True, "Message sent.") + self.message_sent_payload.emit(conv_id, result) + else: + self.message_sent.emit(ok, result if isinstance(result, str) else "Message sent.") + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + self.message_sent.emit(False, "Identity key changed — accept new key first.") + except Exception as e: + logger.error("find_or_create_and_send exception: %s", e, exc_info=True) + self.message_sent.emit(False, str(e)) + + async def _do_create_group(self, members, name=None): + if self._ready: + await self._ready.wait() + conv_id, msg = await self.client.create_conversation(members, name=name) + if conv_id: + self.message_sent.emit(True, f"Group created") + else: + self.message_sent.emit(False, msg) + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + async def _do_link_device(self, username, password): + if self._ready: + await self._ready.wait() + ok, code_or_msg = await self.client.pairing_start(username) + if not ok: + self.pairing_complete.emit(False, code_or_msg) + return + code = code_or_msg + self.pairing_code.emit(code) + ok2, msg2 = await self.client.pairing_wait(code, username, password) + self.pairing_complete.emit(ok2, msg2) + + async def _do_authorize_device(self, code): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.authorize_device(code) + self.authorize_result.emit(ok, msg) + + async def _do_rotate_keys(self, username, password): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.rotate_keys(username, password) + self.rotate_result.emit(ok, msg) + + def do_register(self, username, password, email): + self.schedule(self._do_register(username, password, email)) + + def do_login(self, email, password): + self.schedule(self._do_login(email, password)) + + def load_conversations(self): + self.schedule(self._do_load_conversations()) + + def load_messages(self, conv_id): + self.schedule(self._do_load_messages(conv_id)) + + def load_older_messages(self, conv_id, offset): + self.schedule(self._do_load_older_messages(conv_id, offset)) + + def send_message(self, conv_id, text, members, reply_to=None): + self.schedule(self._do_send_message(conv_id, text, members, reply_to)) + + def send_new_chat(self, username, text): + self.schedule(self._do_find_or_create_and_send(username, text)) + + def create_group(self, members, name=None): + self.schedule(self._do_create_group(members, name=name)) + + async def _do_add_member(self, conv_id, email): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.add_member(conv_id, email) + self.add_member_result.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def add_member(self, conv_id, email): + self.schedule(self._do_add_member(conv_id, email)) + + async def _do_remove_member(self, conv_id, user_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.remove_member(conv_id, user_id) + self.remove_member_result.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def remove_member(self, conv_id, user_id): + self.schedule(self._do_remove_member(conv_id, user_id)) + + group_left = pyqtSignal(bool, str) + group_renamed = pyqtSignal(bool, str) + conversation_deleted = pyqtSignal(bool, str) + + async def _do_leave_group(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.leave_group(conv_id) + self.group_left.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def leave_group(self, conv_id): + self.schedule(self._do_leave_group(conv_id)) + + async def _do_rename_conversation(self, conv_id, name): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.rename_conversation(conv_id, name) + self.group_renamed.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def rename_conversation(self, conv_id, name): + self.schedule(self._do_rename_conversation(conv_id, name)) + + async def _do_delete_conversation(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.delete_conversation(conv_id) + self.conversation_deleted.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def delete_conversation(self, conv_id): + self.schedule(self._do_delete_conversation(conv_id)) + + def link_device(self, username, password): + self.schedule(self._do_link_device(username, password)) + + def authorize_device(self, code): + self.schedule(self._do_authorize_device(code)) + + def rotate_keys(self, username, password): + self.schedule(self._do_rotate_keys(username, password)) + + async def _do_change_password(self, old_password, new_password): + if self._ready: + await self._ready.wait() + ok, msg = self.client.change_password(old_password, new_password) + self.password_changed.emit(ok, msg) + + def change_password(self, old_password, new_password): + self.schedule(self._do_change_password(old_password, new_password)) + + async def _do_change_username(self, new_username): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.change_username(new_username) + self.username_changed.emit(ok, msg) + + def change_username(self, new_username): + self.schedule(self._do_change_username(new_username)) + + async def _do_delete_message(self, message_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.delete_message(message_id) + self.delete_message_result.emit(ok, msg) + + def delete_message(self, message_id): + self.schedule(self._do_delete_message(message_id)) + + def reset_session(self, peer_user_id, peer_device_id=None): + self.schedule(self.client.reset_session(peer_user_id, peer_device_id)) + + async def _do_send_image(self, conv_id, image_path, members, reply_to=None): + if self._ready: + await self._ready.wait() + try: + ok, result = await self.client.send_image(conv_id, image_path, members, reply_to=reply_to) + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + self.image_sent.emit(False, "Identity key changed — accept new key first.") + return + except Exception as e: + logger.error("send_image exception: %s", e, exc_info=True) + self.image_sent.emit(False, str(e)) + return + if ok and isinstance(result, dict): + self.image_sent.emit(True, "Image sent.") + self.message_sent_payload.emit(conv_id, result) + else: + self.image_sent.emit(ok, result if isinstance(result, str) else "Image sent.") + + def send_image(self, conv_id, image_path, members, reply_to=None): + self.schedule(self._do_send_image(conv_id, image_path, members, reply_to)) + + async def _do_download_image(self, file_id, image_info): + if self._ready: + await self._ready.wait() + data = await self.client.download_image(file_id, image_info) + if data: + self.image_downloaded.emit(file_id, data) + + def download_image(self, file_id, image_info): + self.schedule(self._do_download_image(file_id, image_info)) + + file_sent = pyqtSignal(bool, str) + file_downloaded = pyqtSignal(bytes, dict) # decrypted_bytes, file_info + + async def _do_send_file(self, conv_id, file_path, members, reply_to=None): + if self._ready: + await self._ready.wait() + try: + ok, result = await self.client.send_file(conv_id, file_path, members, reply_to=reply_to) + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + self.file_sent.emit(False, "Identity key changed — accept new key first.") + return + except Exception as e: + logger.error("send_file exception: %s", e, exc_info=True) + self.file_sent.emit(False, str(e)) + return + if ok and isinstance(result, dict): + self.file_sent.emit(True, "File sent.") + self.message_sent_payload.emit(conv_id, result) + else: + self.file_sent.emit(ok, result if isinstance(result, str) else "File sent.") + + def send_file(self, conv_id, file_path, members, reply_to=None): + self.schedule(self._do_send_file(conv_id, file_path, members, reply_to)) + + async def _do_download_file(self, file_id, file_info): + if self._ready: + await self._ready.wait() + data = await self.client.download_file(file_id, file_info) + if data: + self.file_downloaded.emit(data, file_info) + + def download_file(self, file_id, file_info): + self.schedule(self._do_download_file(file_id, file_info)) + + async def _do_get_profile(self, user_id=None): + if self._ready: + await self._ready.wait() + profile = await self.client.get_profile(user_id) + if profile: + self.profile_loaded.emit(profile) + + def get_profile(self, user_id=None): + self.schedule(self._do_get_profile(user_id)) + + async def _do_update_profile(self, **fields): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.update_profile(**fields) + self.profile_updated.emit(ok, msg) + + def update_profile(self, **fields): + self.schedule(self._do_update_profile(**fields)) + + async def _do_update_avatar(self, image_data): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.update_avatar(image_data) + self.profile_updated.emit(ok, msg) + + def update_avatar(self, image_data): + self.schedule(self._do_update_avatar(image_data)) + + async def _do_get_avatar(self, user_id): + if self._ready: + await self._ready.wait() + if not user_id or user_id in self._avatar_inflight: + return + self._avatar_inflight.add(user_id) + try: + data = await self.client.get_avatar(user_id) + if data: + self.avatar_loaded.emit(user_id, data) + finally: + self._avatar_inflight.discard(user_id) + + def get_avatar(self, user_id): + self.schedule(self._do_get_avatar(user_id)) + + async def _do_list_invitations(self): + if self._ready: + await self._ready.wait() + if self._invitations_inflight: + return + self._invitations_inflight = True + try: + invitations = await self.client.list_invitations() + self.invitations_loaded.emit(invitations) + finally: + self._invitations_inflight = False + + def list_invitations(self): + self.schedule(self._do_list_invitations()) + + async def _do_accept_invitation(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.accept_invitation(conv_id) + self.invitation_result.emit(ok, msg) + if ok: + invitations = await self.client.list_invitations() + self.invitations_loaded.emit(invitations) + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def accept_invitation(self, conv_id): + self.schedule(self._do_accept_invitation(conv_id)) + + async def _do_decline_invitation(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.decline_invitation(conv_id) + self.invitation_result.emit(ok, msg) + if ok: + invitations = await self.client.list_invitations() + self.invitations_loaded.emit(invitations) + + def decline_invitation(self, conv_id): + self.schedule(self._do_decline_invitation(conv_id)) + + async def _do_update_group_avatar(self, conv_id, image_data): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.update_group_avatar(conv_id, image_data) + self.group_avatar_updated.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def update_group_avatar(self, conv_id, image_data): + self.schedule(self._do_update_group_avatar(conv_id, image_data)) + + async def _do_get_group_avatar(self, conv_id): + if self._ready: + await self._ready.wait() + if not conv_id or conv_id in self._group_avatar_inflight: + return + self._group_avatar_inflight.add(conv_id) + try: + data = await self.client.get_group_avatar(conv_id) + if data: + self.group_avatar_loaded.emit(conv_id, data) + finally: + self._group_avatar_inflight.discard(conv_id) + + def get_group_avatar(self, conv_id): + self.schedule(self._do_get_group_avatar(conv_id)) + + # --- Reactions, Pins, Forwarding --- + + async def _do_react_message(self, message_id, reaction, action): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.react_message(message_id, reaction, action) + self.reaction_result.emit(ok, msg) + + def react_message(self, message_id, reaction, action="add"): + self.schedule(self._do_react_message(message_id, reaction, action)) + + async def _do_pin_message(self, message_id, conversation_id, action): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.pin_message(message_id, conversation_id, action) + if not ok: + self.reaction_result.emit(False, msg) + + def pin_message(self, message_id, conversation_id, action="pin"): + self.schedule(self._do_pin_message(message_id, conversation_id, action)) + + async def _do_get_pinned_messages(self, conv_id): + if self._ready: + await self._ready.wait() + pinned = await self.client.get_pinned_messages(conv_id) + self.pinned_messages_loaded.emit(conv_id, pinned) + + def get_pinned_messages(self, conv_id): + self.schedule(self._do_get_pinned_messages(conv_id)) + + async def _do_forward_message(self, target_conv_id, original_msg, target_members): + if self._ready: + await self._ready.wait() + try: + ok, result = await self.client.forward_message(target_conv_id, original_msg, target_members) + except IdentityKeyChanged as ikc: + cached = self.client._user_cache.get(ikc.user_id) + uname = cached.get("username", "") if cached else "" + old_hex = "" + known = self.client._known_identity_keys.get(ikc.user_id) + if known: + old_hex = known.get("identity_key", "") + was_verified = ikc.status == "changed_verified" + self.key_change_warning.emit(ikc.user_id, uname, old_hex, was_verified, ikc.new_key_bytes) + self.forward_result.emit(False, "Identity key changed — accept new key first.") + return + except Exception as e: + logger.error("forward_message exception: %s", e, exc_info=True) + self.forward_result.emit(False, str(e)) + return + if ok and isinstance(result, dict): + self.forward_result.emit(True, "Message forwarded.") + self.message_sent_payload.emit(target_conv_id, result) + else: + self.forward_result.emit(ok, result if isinstance(result, str) else "Forwarded.") + + def forward_message(self, target_conv_id, original_msg, target_members): + self.schedule(self._do_forward_message(target_conv_id, original_msg, target_members)) + + def logout(self): + self.schedule(self._do_logout()) + + def stop(self): + self._running = False + if self.loop: + asyncio.run_coroutine_threadsafe(self.client.close(), self.loop) + + +def _make_frameless(dlg: QDialog, title_text: str = ""): + """Configure a QDialog as frameless with custom title bar, rounded container, + and drop shadow. Returns a QVBoxLayout for the dialog content area — + callers just add their widgets to the returned layout. + + Usage:: + + dlg = QDialog(self) + dlg.setMinimumWidth(380) + content_layout = _make_frameless(dlg, "My Title") + content_layout.addWidget(QLabel("Hello!")) + dlg.exec() + """ + dlg.setWindowFlags( + Qt.WindowType.FramelessWindowHint | Qt.WindowType.Dialog + ) + dlg.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground) + dlg._drag_pos = None + + t = c() + + # -- Outer layout (transparent, holds container with margins for shadow) -- + outer = QVBoxLayout(dlg) + outer.setContentsMargins(12, 12, 12, 12) + outer.setSpacing(0) + + # -- Rounded container -- + container = QWidget() + container.setObjectName("_framelessContainer") + container.setStyleSheet( + f"#_framelessContainer {{ background-color: {t.bg_primary}; border-radius: 12px; }}" + ) + shadow = QGraphicsDropShadowEffect(container) + shadow.setBlurRadius(24) + shadow.setOffset(0, 4) + shadow.setColor(QColor(0, 0, 0, 80)) + container.setGraphicsEffect(shadow) + + container_lay = QVBoxLayout(container) + container_lay.setContentsMargins(0, 0, 0, 0) + container_lay.setSpacing(0) + + # -- Title bar -- + title_bar = QWidget() + title_bar.setFixedHeight(40) + title_bar.setStyleSheet( + f"background-color: {t.bg_secondary}; " + f"border-top-left-radius: 12px; border-top-right-radius: 12px;" + ) + bar_layout = QHBoxLayout(title_bar) + bar_layout.setContentsMargins(16, 0, 8, 0) + bar_layout.setSpacing(0) + + title_label = QLabel(title_text) + title_label.setStyleSheet( + f"color: {t.text_primary}; font-size: 11pt; font-weight: bold; " + f"background: transparent;" + ) + bar_layout.addWidget(title_label) + bar_layout.addStretch() + + close_btn = QPushButton("\u2715") + close_btn.setFixedSize(28, 28) + close_btn.setCursor(Qt.CursorShape.PointingHandCursor) + close_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.text_muted}; " + f"border: none; border-radius: 14px; font-size: 12pt; }}" + f"QPushButton:hover {{ background-color: {t.error}; color: {t.accent_text}; }}" + ) + close_btn.clicked.connect(dlg.reject) + bar_layout.addWidget(close_btn) + + # Dragging via title bar + def _mouse_press(event): + if event.button() == Qt.MouseButton.LeftButton: + dlg._drag_pos = event.globalPosition().toPoint() - dlg.frameGeometry().topLeft() + event.accept() + def _mouse_move(event): + if dlg._drag_pos is not None and event.buttons() & Qt.MouseButton.LeftButton: + dlg.move(event.globalPosition().toPoint() - dlg._drag_pos) + event.accept() + def _mouse_release(event): + dlg._drag_pos = None + title_bar.mousePressEvent = _mouse_press + title_bar.mouseMoveEvent = _mouse_move + title_bar.mouseReleaseEvent = _mouse_release + + container_lay.addWidget(title_bar) + + # -- Content widget -- + content = QWidget() + content_layout = QVBoxLayout(content) + content_layout.setContentsMargins(16, 12, 16, 16) + content_layout.setSpacing(8) + container_lay.addWidget(content) + + outer.addWidget(container) + + # Store refs for later theming + dlg._frameless_container = container + dlg._frameless_title_bar = title_bar + dlg._frameless_title_label = title_label + return content_layout + + +class UserProfileDialog(QDialog): + """Dialog for viewing/editing user profiles.""" + + def __init__(self, bridge: AsyncBridge, user_id: str, editable: bool = False, parent=None): + super().__init__(parent) + self.bridge = bridge + self.user_id = user_id + self.editable = editable + self.setMinimumWidth(400) + self._build_ui() + self._connect_signals() + self.bridge.get_profile(user_id) + + def _build_ui(self): + t = c() + title_text = "Edit Profile" if self.editable else "User Profile" + self.layout_main = _make_frameless(self, title_text) + self.layout_main.setSpacing(12) + + # Avatar + self.avatar_label = QLabel() + self.avatar_label.setFixedSize(80, 80) + self.avatar_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + t = c() + self.avatar_label.setStyleSheet( + f"background-color: {t.bg_secondary}; border-radius: 40px; " + f"font-size: 21pt; color: {t.accent};" + ) + self.avatar_label.setText("?") + self.layout_main.addWidget(self.avatar_label, alignment=Qt.AlignmentFlag.AlignCenter) + + if self.editable: + avatar_btn = QPushButton("Change Avatar") + avatar_btn.setObjectName("secondaryBtn") + avatar_btn.clicked.connect(self._on_change_avatar) + self.layout_main.addWidget(avatar_btn, alignment=Qt.AlignmentFlag.AlignCenter) + + # Info fields + self.username_label = QLabel("") + self.username_label.setStyleSheet(f"font-size: 14pt; font-weight: bold; color: {t.accent};") + self.username_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.layout_main.addWidget(self.username_label) + + self.info_area = QVBoxLayout() + self.layout_main.addLayout(self.info_area) + + # Editable fields (only shown in edit mode) + if self.editable: + self.layout_main.addSpacing(8) + + form_label = QLabel("Profile Settings") + form_label.setStyleSheet(f"font-weight: bold; color: {t.accent};") + self.layout_main.addWidget(form_label) + + self.phone_input = QLineEdit() + self.phone_input.setPlaceholderText("Phone number") + self.layout_main.addWidget(self.phone_input) + + self.location_input = QLineEdit() + self.location_input.setPlaceholderText("Location") + self.layout_main.addWidget(self.location_input) + + from PyQt6.QtWidgets import QCheckBox + self.email_visible_cb = QCheckBox("Email visible to others") + self.email_visible_cb.setStyleSheet(f"color: {t.text_primary};") + self.layout_main.addWidget(self.email_visible_cb) + + self.phone_visible_cb = QCheckBox("Phone visible to others") + self.phone_visible_cb.setStyleSheet(f"color: {t.text_primary};") + self.layout_main.addWidget(self.phone_visible_cb) + + self.location_visible_cb = QCheckBox("Location visible to others") + self.location_visible_cb.setStyleSheet(f"color: {t.text_primary};") + self.layout_main.addWidget(self.location_visible_cb) + + save_btn = QPushButton("Save") + save_btn.clicked.connect(self._on_save) + self.layout_main.addWidget(save_btn) + + # Security section (only when viewing another user, not own profile) + if not self.editable: + self._security_section = QVBoxLayout() + self.layout_main.addLayout(self._security_section) + + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(self.accept) + self.layout_main.addWidget(close_btn) + + def _connect_signals(self): + self.bridge.profile_loaded.connect(self._on_profile_loaded) + self.bridge.avatar_loaded.connect(self._on_avatar_loaded) + self.bridge.profile_updated.connect(self._on_profile_updated) + + def _on_profile_loaded(self, profile): + if profile.get("user_id") != self.user_id: + return + username = profile.get("username", "?") + self.username_label.setText(username) + # Set avatar initial + self.avatar_label.setText(username[0].upper() if username else "?") + + # Clear info area + while self.info_area.count(): + item = self.info_area.takeAt(0) + if item.widget(): + item.widget().deleteLater() + + # Email + email = profile.get("email") + if email: + self.info_area.addWidget(QLabel(f"Email: {email}")) + + # Phone + phone = profile.get("phone") + if phone: + self.info_area.addWidget(QLabel(f"Phone: {phone}")) + + # Location + location = profile.get("location") + if location: + self.info_area.addWidget(QLabel(f"Location: {location}")) + + # Member since + created_at = profile.get("created_at", "") + if created_at: + date_str = created_at[:10] if len(created_at) >= 10 else created_at + label = QLabel(f"Member since: {date_str}") + label.setStyleSheet(f"color: {c().text_muted};") + self.info_area.addWidget(label) + + # Populate editable fields + if self.editable: + self.phone_input.setText(phone or "") + self.location_input.setText(location or "") + self.email_visible_cb.setChecked(bool(profile.get("email_visible", 1))) + self.phone_visible_cb.setChecked(bool(profile.get("phone_visible", 0))) + self.location_visible_cb.setChecked(bool(profile.get("location_visible", 0))) + + # Try to load avatar + if profile.get("avatar_file"): + self.bridge.get_avatar(self.user_id) + + # Security section (viewing another user, not self) + my_uid = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + if not self.editable and hasattr(self, "_security_section") and self.user_id != my_uid: + self._populate_security_section() + + def _populate_security_section(self): + """Populate security/verification info for a peer user.""" + t = c() + sec = self._security_section + # Clear previous contents + while sec.count(): + item = sec.takeAt(0) + if item.widget(): + item.widget().deleteLater() + + sec.addSpacing(8) + header = QLabel("Security") + header.setStyleSheet(f"font-weight: bold; color: {t.accent};") + sec.addWidget(header) + + status = self.bridge.client.get_verification_status(self.user_id) + if status == "verified": + status_label = QLabel("\u2705 Identity verified") + status_label.setStyleSheet(f"color: {t.success}; font-weight: bold;") + elif status == "trusted": + status_label = QLabel("\U0001f512 Trusted (first use)") + status_label.setStyleSheet(f"color: {t.warning};") + else: + status_label = QLabel("\u26A0 Unverified") + status_label.setStyleSheet(f"color: {t.text_muted};") + sec.addWidget(status_label) + + fp = self.bridge.client.get_peer_fingerprint(self.user_id) + if fp: + fp_label = QLabel(f"Fingerprint:\n{fp}") + fp_label.setStyleSheet( + f"font-family: monospace; font-size: 8pt; color: {t.text_primary}; " + f"background: {t.bg_secondary}; padding: 4px; border-radius: 4px;" + ) + fp_label.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + sec.addWidget(fp_label) + + def _on_avatar_loaded(self, user_id, data): + if user_id != self.user_id: + return + qimg = _safe_load_image(data) + if qimg is not None: + pixmap = QPixmap.fromImage(qimg) + # Circular crop + size = 80 + scaled = pixmap.scaled(size, size, Qt.AspectRatioMode.KeepAspectRatioByExpanding, + Qt.TransformationMode.SmoothTransformation) + result = QPixmap(size, size) + result.fill(QColor(0, 0, 0, 0)) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + painter.setBrush(QBrush(scaled)) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(0, 0, size, size) + painter.end() + self.avatar_label.setPixmap(result) + + def _on_change_avatar(self): + path, _ = QFileDialog.getOpenFileName( + self, "Select Avatar", "", + "Images (*.png *.jpg *.jpeg);;All Files (*)", + ) + if not path: + return + try: + with open(path, "rb") as f: + data = f.read() + if len(data) > 2 * 1024 * 1024: + QMessageBox.warning(self, "Error", "Avatar too large (max 2 MB).") + return + self.bridge.update_avatar(data) + except Exception as e: + QMessageBox.warning(self, "Error", f"Failed to read file: {e}") + + def _on_save(self): + fields = { + "phone": self.phone_input.text().strip() or None, + "location": self.location_input.text().strip() or None, + "email_visible": 1 if self.email_visible_cb.isChecked() else 0, + "phone_visible": 1 if self.phone_visible_cb.isChecked() else 0, + "location_visible": 1 if self.location_visible_cb.isChecked() else 0, + } + self.bridge.update_profile(**fields) + + def _on_profile_updated(self, ok, msg): + if ok: + # Refresh profile + self.bridge.get_profile(self.user_id) + else: + QMessageBox.warning(self, "Error", msg) + + def closeEvent(self, event): + # Disconnect signals to avoid stale references + try: + self.bridge.profile_loaded.disconnect(self._on_profile_loaded) + self.bridge.avatar_loaded.disconnect(self._on_avatar_loaded) + self.bridge.profile_updated.disconnect(self._on_profile_updated) + except Exception: + pass + super().closeEvent(event) + + def reject(self): + try: + self.bridge.profile_loaded.disconnect(self._on_profile_loaded) + self.bridge.avatar_loaded.disconnect(self._on_avatar_loaded) + self.bridge.profile_updated.disconnect(self._on_profile_updated) + except Exception: + pass + super().reject() + + def accept(self): + try: + self.bridge.profile_loaded.disconnect(self._on_profile_loaded) + self.bridge.avatar_loaded.disconnect(self._on_avatar_loaded) + self.bridge.profile_updated.disconnect(self._on_profile_updated) + except Exception: + pass + super().accept() + + +class VerificationDialog(QDialog): + """Dialog for viewing safety numbers, fingerprints, QR codes, and verifying contacts.""" + + def __init__(self, bridge: AsyncBridge, peer_user_id: str, peer_name: str, parent=None): + super().__init__(parent) + self.bridge = bridge + self.peer_user_id = peer_user_id + self.peer_name = peer_name + self.setMinimumWidth(420) + self._build_ui() + + def _build_ui(self): + t = c() + lay = _make_frameless(self, "Verify Contact") + lay.setSpacing(10) + + # Peer name + name_label = QLabel(self.peer_name) + name_label.setStyleSheet(f"font-size: 14pt; font-weight: bold; color: {t.accent};") + name_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + lay.addWidget(name_label) + + # Verification status + status = self.bridge.client.get_verification_status(self.peer_user_id) + if status == "verified": + status_text = "\u2705 Verified" + status_color = t.success + elif status == "trusted": + status_text = "\U0001f512 Trusted (TOFU)" + status_color = t.warning + else: + status_text = "\u26A0 Unverified" + status_color = t.error + self._status_label = QLabel(status_text) + self._status_label.setStyleSheet(f"font-size: 11pt; font-weight: bold; color: {status_color};") + self._status_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + lay.addWidget(self._status_label) + + lay.addSpacing(4) + + # Safety Number + safety = self.bridge.client.get_safety_number(self.peer_user_id) + if safety: + sn_header = QLabel("Safety Number") + sn_header.setStyleSheet(f"font-weight: bold; color: {t.text_secondary};") + lay.addWidget(sn_header) + + sn_label = QLabel(safety) + sn_label.setStyleSheet( + f"font-family: monospace; font-size: 13pt; letter-spacing: 2px; " + f"color: {t.text_primary}; background: {t.bg_secondary}; " + f"padding: 12px; border-radius: 8px;" + ) + sn_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + sn_label.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + lay.addWidget(sn_label) + + lay.addSpacing(4) + + # QR Code + qr_data = self.bridge.client.get_verification_qr_data() + if qr_data: + qr_header = QLabel("Your QR Code (for peer to scan)") + qr_header.setStyleSheet(f"font-weight: bold; color: {t.text_secondary};") + lay.addWidget(qr_header) + + qr_pixmap = self._generate_qr_pixmap(qr_data) + if qr_pixmap: + qr_label = QLabel() + qr_label.setPixmap(qr_pixmap) + qr_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + lay.addWidget(qr_label) + + save_qr_btn = QPushButton("Save QR Code") + save_qr_btn.setObjectName("secondaryBtn") + save_qr_btn.clicked.connect(lambda: self._save_qr(qr_pixmap)) + lay.addWidget(save_qr_btn) + + lay.addSpacing(4) + + # Fingerprints + my_fp = self.bridge.client.get_my_fingerprint() + peer_fp = self.bridge.client.get_peer_fingerprint(self.peer_user_id) + + if my_fp or peer_fp: + fp_header = QLabel("Fingerprints") + fp_header.setStyleSheet(f"font-weight: bold; color: {t.text_secondary};") + lay.addWidget(fp_header) + + if my_fp: + my_fp_label = QLabel(f"Yours:\n{my_fp}") + my_fp_label.setStyleSheet( + f"font-family: monospace; font-size: 9pt; color: {t.text_primary}; " + f"background: {t.bg_secondary}; padding: 6px; border-radius: 4px;" + ) + my_fp_label.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + lay.addWidget(my_fp_label) + + if peer_fp: + peer_name_esc = self.peer_name.replace("&", "&").replace("<", "<") + peer_fp_label = QLabel(f"{peer_name_esc}:\n{peer_fp}") + peer_fp_label.setStyleSheet( + f"font-family: monospace; font-size: 9pt; color: {t.text_primary}; " + f"background: {t.bg_secondary}; padding: 6px; border-radius: 4px;" + ) + peer_fp_label.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) + lay.addWidget(peer_fp_label) + + lay.addSpacing(8) + + # Action buttons + btn_row = QHBoxLayout() + + if status != "verified": + verify_btn = QPushButton("Mark as Verified") + verify_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.success}; color: {t.bg_primary}; " + f"font-weight: bold; padding: 8px 16px; border-radius: 6px; }}" + f"QPushButton:hover {{ opacity: 0.9; }}" + ) + verify_btn.clicked.connect(self._on_verify) + btn_row.addWidget(verify_btn) + else: + unverify_btn = QPushButton("Remove Verification") + unverify_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.warning}; color: {t.bg_primary}; " + f"font-weight: bold; padding: 8px 16px; border-radius: 6px; }}" + ) + unverify_btn.clicked.connect(self._on_unverify) + btn_row.addWidget(unverify_btn) + + # Scan QR button + scan_btn = QPushButton("Scan QR Code") + scan_btn.setObjectName("secondaryBtn") + scan_btn.clicked.connect(self._on_scan_qr) + btn_row.addWidget(scan_btn) + + lay.addLayout(btn_row) + + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(self.accept) + lay.addWidget(close_btn) + + def _generate_qr_pixmap(self, data: bytes) -> QPixmap | None: + """Generate a QR code QPixmap from raw bytes (base64-encoded for scanner compat).""" + try: + import qrcode + import base64 + from io import BytesIO + # Encode as base64 — raw binary gets corrupted by QR readers (UTF-8 re-encoding) + qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, + box_size=4, border=2) + qr.add_data(base64.b64encode(data).decode("ascii")) + qr.make(fit=True) + img = qr.make_image(fill_color="black", back_color="white") + buf = BytesIO() + img.save(buf, format="PNG") + buf.seek(0) + qimg = QImage() + qimg.loadFromData(buf.getvalue()) + if qimg.isNull(): + return None + return QPixmap.fromImage(qimg) + except ImportError: + return None + except Exception: + return None + + def _save_qr(self, pixmap: QPixmap): + """Save QR code image to file.""" + path, _ = QFileDialog.getSaveFileName( + self, "Save QR Code", "verification_qr.png", + "PNG Images (*.png);;All Files (*)", + ) + if path: + pixmap.save(path, "PNG") + + def _on_verify(self): + cached = self.bridge.client._user_cache.get(self.peer_user_id) + if cached and cached.get("identity_key_bytes"): + self.bridge.client.verify_contact( + self.peer_user_id, cached["identity_key_bytes"], method="safety_number" + ) + t = c() + self._status_label.setText("\u2705 Verified") + self._status_label.setStyleSheet(f"font-size: 11pt; font-weight: bold; color: {t.success};") + + def _on_unverify(self): + self.bridge.client.unverify_contact(self.peer_user_id) + t = c() + self._status_label.setText("\U0001f512 Trusted (TOFU)") + self._status_label.setStyleSheet(f"font-size: 11pt; font-weight: bold; color: {t.warning};") + + def _on_scan_qr(self): + """Open file picker for QR code image, decode and verify.""" + path, _ = QFileDialog.getOpenFileName( + self, "Select QR Code Image", "", + "Images (*.png *.jpg *.jpeg *.bmp);;All Files (*)", + ) + if not path: + return + try: + from PIL import Image + pil_img = Image.open(path) + except Exception as e: + QMessageBox.warning(self, "Error", f"Failed to open image: {e}") + return + # Try pyzbar first, fall back to manual decode + qr_text = None + try: + from pyzbar.pyzbar import decode as pyzbar_decode + results = pyzbar_decode(pil_img) + if results: + qr_text = results[0].data + except ImportError: + pass + if qr_text is None: + QMessageBox.information( + self, "QR Scan", + "Could not decode QR code. Install 'pyzbar' for QR scanning support, " + "or verify manually using the safety number above." + ) + return + # QR contains base64-encoded binary payload + import base64 + try: + qr_data = base64.b64decode(qr_text) + except Exception: + QMessageBox.warning(self, "Error", "Invalid QR code format.") + return + ok, user_id, message = self.bridge.client.verify_qr_code(qr_data) + if ok: + t = c() + self._status_label.setText("\u2705 Verified") + self._status_label.setStyleSheet(f"font-size: 11pt; font-weight: bold; color: {t.success};") + QMessageBox.information(self, "Verification", message) + else: + QMessageBox.warning(self, "Verification Failed", message) + + +class LoginWindow(QWidget): + def __init__(self, bridge: AsyncBridge): + super().__init__() + self.bridge = bridge + self.setWindowTitle("Encrypted Chat - Login") + self.setFixedSize(500, 540) + self._pair_email = "" + self._pair_password = "" + self._build_ui() + tm().on_change(self._apply_theme) + + def _login_card_qss(self): + t = c() + return ( + f"#loginCard {{ background-color: {t.bg_primary}; border-radius: 16px; }}" + f"#loginCard QWidget {{ background: transparent; }}" + f"#loginCard QLabel {{ background: transparent; color: {t.text_primary}; }}" + f"#loginCard QLineEdit {{" + f" background-color: {t.bg_secondary}; color: {t.text_primary};" + f" border: 1px solid {t.border}; border-radius: 6px; padding: 8px;" + f"}}" + f"#loginCard QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + f"#loginCard QPushButton {{" + f" background-color: {t.accent}; color: {t.accent_text};" + f" border: none; border-radius: 6px; padding: 8px 16px; font-weight: bold;" + f"}}" + f"#loginCard QPushButton:hover {{ background-color: {t.accent_hover}; }}" + ) + + def _tab_bar_qss(self): + t = c() + return ( + f"QTabBar {{ background: transparent; border: none; }}" + f"QTabBar::tab {{ background: transparent; color: {t.text_muted}; " + f"padding: 10px 24px; font-size: 10pt; border: none; " + f"border-bottom: 2px solid transparent; }}" + f"QTabBar::tab:selected {{ color: {t.accent}; font-weight: bold; " + f"border-bottom: 2px solid {t.accent}; }}" + f"QTabBar::tab:hover {{ color: {t.text_primary}; }}" + f"QTabWidget::pane {{ border: none; background: transparent; }}" + ) + + def _apply_theme(self): + QApplication.instance().setStyleSheet(qss()) + t = c() + self.setStyleSheet(f"background-color: {t.bg_tertiary};") + self._card.setStyleSheet(self._login_card_qss()) + self._subtitle.setStyleSheet(f"color: {t.text_muted}; font-size: 9pt; margin-bottom: 8px;") + self._theme_btn.setText("\u2600" if tm().is_dark else "\U0001f319") + self._theme_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.text_muted}; " + f"border: none; font-size: 14pt; }}" + f"QPushButton:hover {{ color: {t.accent}; }}" + ) + self._tabs.setStyleSheet(self._tab_bar_qss()) + # Verification page + self._step_label.setStyleSheet(f"color: {t.accent}; font-weight: bold; font-size: 10pt;") + self._info_label.setStyleSheet(f"color: {t.text_primary}; font-size: 10pt;") + self.code_input.setStyleSheet( + f"QLineEdit {{ font-size: 16pt; letter-spacing: 8px; text-align: center; " + f"background-color: {t.bg_secondary}; border: 1px solid {t.border}; " + f"border-radius: 6px; padding: 12px; color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + + def _build_ui(self): + from PyQt6.QtWidgets import QTabWidget + outer = QVBoxLayout(self) + outer.setContentsMargins(0, 0, 0, 0) + t = c() + + # Background fills entire window + self.setStyleSheet(f"background-color: {t.bg_tertiary};") + + # Theme toggle in top-right corner + top_row = QHBoxLayout() + top_row.setContentsMargins(12, 8, 12, 0) + top_row.addStretch() + self._theme_btn = QPushButton("\u2600" if tm().is_dark else "\U0001f319") + self._theme_btn.setFixedSize(36, 36) + self._theme_btn.setToolTip("Toggle light/dark mode") + self._theme_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.text_muted}; " + f"border: none; font-size: 14pt; }}" + f"QPushButton:hover {{ color: {t.accent}; }}" + ) + self._theme_btn.clicked.connect(tm().toggle) + top_row.addWidget(self._theme_btn) + outer.addLayout(top_row) + + self.stack = QStackedWidget() + outer.addWidget(self.stack) + + # --- Page 0: Login / Register form (card) --- + page0 = QWidget() + page0_layout = QVBoxLayout(page0) + page0_layout.setContentsMargins(40, 0, 40, 20) + page0_layout.setAlignment(Qt.AlignmentFlag.AlignCenter) + + self._card = QWidget() + self._card.setObjectName("loginCard") + self._card.setStyleSheet(self._login_card_qss()) + card_layout = QVBoxLayout(self._card) + card_layout.setSpacing(10) + card_layout.setContentsMargins(36, 28, 36, 28) + + title = QLabel("Encrypted Chat") + title.setObjectName("title") + title.setAlignment(Qt.AlignmentFlag.AlignCenter) + card_layout.addWidget(title) + + self._subtitle = QLabel("End-to-end encrypted messaging") + self._subtitle.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._subtitle.setStyleSheet(f"color: {t.text_muted}; font-size: 9pt; margin-bottom: 8px;") + card_layout.addWidget(self._subtitle) + + # --- Tabs: Login | Register | Link Device --- + self._tabs = QTabWidget() + self._tabs.setMinimumHeight(220) + self._tabs.setStyleSheet(self._tab_bar_qss()) + + # == Login tab == + login_tab = QWidget() + login_lay = QVBoxLayout(login_tab) + login_lay.setSpacing(12) + login_lay.setContentsMargins(0, 12, 0, 4) + + self.email_input = QLineEdit() + self.email_input.setPlaceholderText("Email") + login_lay.addWidget(self.email_input) + + self.password_input = QLineEdit() + self.password_input.setPlaceholderText("Password") + self.password_input.setEchoMode(QLineEdit.EchoMode.Password) + self.password_input.returnPressed.connect(self._on_login) + login_lay.addWidget(self.password_input) + + login_lay.addStretch() + self.login_btn = QPushButton("Login") + self.login_btn.setMinimumHeight(40) + self.login_btn.clicked.connect(self._on_login) + login_lay.addWidget(self.login_btn) + + self._tabs.addTab(login_tab, "Login") + + # == Register tab == + reg_tab = QWidget() + reg_lay = QVBoxLayout(reg_tab) + reg_lay.setSpacing(12) + reg_lay.setContentsMargins(0, 12, 0, 4) + + self.username_input = QLineEdit() + self.username_input.setPlaceholderText("Username (display name)") + reg_lay.addWidget(self.username_input) + + self._reg_email_input = QLineEdit() + self._reg_email_input.setPlaceholderText("Email") + reg_lay.addWidget(self._reg_email_input) + + self._reg_password_input = QLineEdit() + self._reg_password_input.setPlaceholderText("Password") + self._reg_password_input.setEchoMode(QLineEdit.EchoMode.Password) + self._reg_password_input.returnPressed.connect(self._on_register) + reg_lay.addWidget(self._reg_password_input) + + reg_lay.addStretch() + self.register_btn = QPushButton("Register") + self.register_btn.setMinimumHeight(40) + self.register_btn.clicked.connect(self._on_register) + reg_lay.addWidget(self.register_btn) + + self._tabs.addTab(reg_tab, "Register") + + # == Link Device tab == + link_tab = QWidget() + link_lay = QVBoxLayout(link_tab) + link_lay.setSpacing(12) + link_lay.setContentsMargins(0, 12, 0, 4) + + self._link_email_input = QLineEdit() + self._link_email_input.setPlaceholderText("Email") + link_lay.addWidget(self._link_email_input) + + self._link_password_input = QLineEdit() + self._link_password_input.setPlaceholderText("Password") + self._link_password_input.setEchoMode(QLineEdit.EchoMode.Password) + self._link_password_input.returnPressed.connect(self._on_link_device) + link_lay.addWidget(self._link_password_input) + + link_lay.addStretch() + self.link_btn = QPushButton("Link Device") + self.link_btn.setMinimumHeight(40) + self.link_btn.clicked.connect(self._on_link_device) + link_lay.addWidget(self.link_btn) + + self._tabs.addTab(link_tab, "Link Device") + + card_layout.addWidget(self._tabs) + + self.status_label = QLabel("") + self.status_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.status_label.setWordWrap(True) + card_layout.addWidget(self.status_label) + + page0_layout.addWidget(self._card) + self.stack.addWidget(page0) + + # --- Page 1: Verification code form --- + page1 = QWidget() + vl = QVBoxLayout(page1) + vl.setSpacing(14) + vl.setContentsMargins(50, 40, 50, 40) + + self._step_label = QLabel("Step 2 of 2") + self._step_label.setStyleSheet(f"color: {t.accent}; font-weight: bold; font-size: 10pt;") + self._step_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + vl.addWidget(self._step_label) + + self._info_label = QLabel("Enter the 6-digit verification code sent to your email") + self._info_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._info_label.setWordWrap(True) + self._info_label.setStyleSheet(f"color: {t.text_primary}; font-size: 10pt;") + vl.addWidget(self._info_label) + + vl.addSpacing(12) + + self.code_input = QLineEdit() + self.code_input.setPlaceholderText("000000") + self.code_input.setMaxLength(6) + self.code_input.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.code_input.setStyleSheet( + f"QLineEdit {{ font-size: 16pt; letter-spacing: 8px; text-align: center; " + f"background-color: {t.bg_secondary}; border: 1px solid {t.border}; " + f"border-radius: 6px; padding: 12px; color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + self.code_input.returnPressed.connect(self._on_confirm_code) + vl.addWidget(self.code_input) + + vl.addSpacing(8) + + self.code_status_label = QLabel("") + self.code_status_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.code_status_label.setWordWrap(True) + vl.addWidget(self.code_status_label) + + code_btn_row = QHBoxLayout() + self.back_btn = QPushButton("Back") + self.back_btn.setObjectName("secondaryBtn") + self.back_btn.clicked.connect(self._on_back_to_login) + code_btn_row.addWidget(self.back_btn) + + self.confirm_btn = QPushButton("Confirm") + self.confirm_btn.clicked.connect(self._on_confirm_code) + code_btn_row.addWidget(self.confirm_btn) + vl.addLayout(code_btn_row) + + vl.addStretch() + self.stack.addWidget(page1) + + def show_verification_page(self, message=""): + """Switch to verification code page.""" + self.code_input.clear() + self.code_status_label.setText(message) + self.code_status_label.setStyleSheet(f"color: {c().success};") + self.stack.setCurrentIndex(1) + self.code_input.setFocus() + + def _on_confirm_code(self): + code = self.code_input.text().strip() + if not code: + self.code_status_label.setText("Please enter the code.") + self.code_status_label.setStyleSheet(f"color: {c().error};") + return + self.code_status_label.setText("Confirming...") + self.code_status_label.setStyleSheet(f"color: {c().success};") + self.confirm_btn.setEnabled(False) + self.back_btn.setEnabled(False) + # Callback set by main() to handle confirmation + if hasattr(self, '_confirm_callback'): + self._confirm_callback(code) + + def _on_back_to_login(self): + self.stack.setCurrentIndex(0) + self._set_enabled(True) + self.status_label.setText("") + self.status_label.setStyleSheet("") + + def _on_register(self): + username = self.username_input.text().strip() + password = self._reg_password_input.text() + email = self._reg_email_input.text().strip() + if not username: + self.show_error("Username required.") + return + if not email or not password: + self.show_error("Email and password required.") + return + self.status_label.setText("Registering...") + self._set_enabled(False) + self.bridge.do_register(username, password, email) + + def _on_login(self): + email = self.email_input.text().strip() + password = self.password_input.text() + if not email or not password: + self.show_error("Email and password required.") + return + self.status_label.setText("Logging in...") + self._set_enabled(False) + self.bridge.do_login(email, password) + + def _on_link_device(self): + email = self._link_email_input.text().strip() + password = self._link_password_input.text() + if not email or not password: + self.show_error("Email and password required.") + return + self._pair_email = email + self._pair_password = password + self.status_label.setText("Generating pairing code...") + self._set_enabled(False) + self.bridge.link_device(email, password) + + def _set_enabled(self, enabled): + self._tabs.setEnabled(enabled) + self.username_input.setEnabled(enabled) + self.email_input.setEnabled(enabled) + self.password_input.setEnabled(enabled) + self._reg_email_input.setEnabled(enabled) + self._reg_password_input.setEnabled(enabled) + self._link_email_input.setEnabled(enabled) + self._link_password_input.setEnabled(enabled) + self.register_btn.setEnabled(enabled) + self.login_btn.setEnabled(enabled) + self.link_btn.setEnabled(enabled) + + def show_error(self, msg): + if self.stack.currentIndex() == 1: + self.code_status_label.setText(msg) + self.code_status_label.setStyleSheet(f"color: {c().error};") + self.confirm_btn.setEnabled(True) + self.back_btn.setEnabled(True) + else: + self.status_label.setText(msg) + self.status_label.setStyleSheet(f"color: {c().error};") + self._set_enabled(True) + + def show_success(self, msg): + if self.stack.currentIndex() == 1: + self.code_status_label.setText(msg) + self.code_status_label.setStyleSheet(f"color: {c().success};") + else: + self.status_label.setText(msg) + self.status_label.setStyleSheet(f"color: {c().success};") + + def reset(self): + self.stack.setCurrentIndex(0) + self._tabs.setCurrentIndex(0) + self.status_label.setText("") + self.status_label.setStyleSheet("") + self.code_status_label.setText("") + self.code_status_label.setStyleSheet("") + self.username_input.clear() + self.email_input.clear() + self.password_input.clear() + self._reg_email_input.clear() + self._reg_password_input.clear() + self._link_email_input.clear() + self._link_password_input.clear() + self.code_input.clear() + self._set_enabled(True) + self.confirm_btn.setEnabled(True) + self.back_btn.setEnabled(True) + + +class MessageBubble(QFrame): + """Chat message bubble with rounded corners drawn via QPainter.""" + + def __init__(self, bg_color: str, parent=None): + super().__init__(parent) + self._bg_color = QColor(bg_color) + self.setStyleSheet("background: transparent; border: none;") + self.setContextMenuPolicy(Qt.ContextMenuPolicy.DefaultContextMenu) + + def set_bg_color(self, color: str): + self._bg_color = QColor(color) + self.update() + + def paintEvent(self, event): + p = QPainter(self) + p.setRenderHint(QPainter.RenderHint.Antialiasing) + p.setBrush(QBrush(self._bg_color)) + p.setPen(Qt.PenStyle.NoPen) + p.drawRoundedRect(self.rect(), 14, 14) + p.end() + + def contextMenuEvent(self, event): + # Walk up to find _msg_index, then call main window handler + idx = getattr(self, '_msg_index', None) + if idx is None: + p = self.parentWidget() + while p: + idx = getattr(p, '_msg_index', None) + if idx is not None: + break + p = p.parentWidget() + main_win = self.window() + if idx is not None and hasattr(main_win, '_show_msg_context_menu'): + main_win._show_msg_context_menu(idx, event.globalPos()) + event.accept() + + +class MainWindow(QWidget): + _AVATAR_REFRESH_BATCH = 8 + _GROUP_AVATAR_REFRESH_BATCH = 4 + _show_verification_dialog_signal = pyqtSignal(str, str) # peer_uid, peer_name + + def __init__(self, bridge: AsyncBridge, on_logout): + super().__init__() + self.bridge = bridge + self._on_logout_cb = on_logout + self.setWindowTitle(f"Encrypted Chat - {bridge.client.username}") + self.resize(900, 600) + + self.conversations: list[dict] = [] + self.current_conv_id: str | None = None + self.current_messages: list[dict] = [] + self.reply_to_id: str | None = None + self._unread_counts: dict[str, int] = {} + self._has_more_messages: bool = True + self._pending_image_download: dict | None = None # {file_id, image_info} + self._is_dm: bool = False + self._online_users: set[str] = set() + self._is_logout = False + self._avatar_cache = _LRUPixmapCache() # user_id -> pixmap + self._group_avatar_cache = _LRUPixmapCache() # conv_id -> pixmap + self._avatar_requested: set[str] = set() + self._group_avatar_requested: set[str] = set() + self._avatar_refresh_cursor = 0 + self._group_avatar_refresh_cursor = 0 + self._pending_invitations: list[dict] = [] + self._favorites: set[str] = self._load_favorites() + # Search state + self._search_results: list[int] = [] # indices into current_messages + self._search_current: int = -1 + self._search_query: str = "" + self._search_active: bool = False + + self._privacy_enabled: bool = True # Privacy overlay on/off + self._last_message_cache: dict[str, tuple[str, str, str]] = {} # conv_id -> (text, ts, receipt) + + self._build_ui() + self._connect_signals() + self._setup_tray_icon() + self._setup_privacy_overlay() + + # Keyboard shortcuts + QShortcut(QKeySequence("Ctrl+F"), self).activated.connect(self._toggle_search) + QShortcut(QKeySequence("Ctrl+Shift+P"), self).activated.connect(self._toggle_privacy) + + self.bridge.load_conversations() + self.bridge.list_invitations() + + # Periodic refresh: re-download avatars and conversation data + self._refresh_timer = QTimer(self) + self._refresh_timer.timeout.connect(self._on_periodic_refresh) + self._refresh_timer.start(120_000) # every 2 minutes + + tm().on_change(self._apply_theme) + + # -- Theme switching ------------------------------------------------------- + + def _apply_theme(self): + """Re-apply theme colours to all widgets after theme toggle.""" + app = QApplication.instance() + if app: + app.setStyleSheet(qss()) + t = c() + # Sidebar + self._sidebar_panel.setStyleSheet(f"#sidebarPanel {{ background-color: {t.bg_tertiary}; }}") + self._conv_label.setStyleSheet(f"font-weight: bold; font-size: 12pt; color: {t.accent}; background: transparent;") + self._settings_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.text_secondary}; border: none; border-radius: 6px; padding: 8px 16px; }}" + f"QPushButton:hover {{ background-color: {t.bg_hover}; }}" + ) + self.conv_list.setStyleSheet( + f"QListWidget {{ background-color: {t.bg_tertiary}; border: none; padding: 0px; }}" + f"QListWidget::item {{ padding: 0px; border: none; }}" + f"QListWidget::item:selected {{ background: transparent; border: none; }}" + f"QListWidget::item:hover {{ background: transparent; }}" + ) + # Invitation list + self.inv_label.setStyleSheet(f"font-weight: bold; font-size: 9pt; color: {t.warning}; margin-top: 4px;") + self.inv_list.setStyleSheet( + f"QListWidget {{ background-color: {t.bg_primary}; border: 1px solid {t.warning}; border-radius: 6px; padding: 2px; }}" + f"QListWidget::item {{ padding: 6px; color: {t.text_primary}; }}" + f"QListWidget::item:hover {{ background-color: {t.bg_hover}; color: {t.text_primary}; }}" + ) + # Chat header + self._chat_header_widget.setStyleSheet( + f"#chatHeader {{ border-bottom: 1px solid {t.separator}; }}" + f"#chatHeader QLabel, #chatHeader QPushButton {{ border: none; }}" + ) + self.chat_header.setStyleSheet(f"font-weight: bold; font-size: 12pt; color: {t.accent};") + self._chat_header_status.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt;") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.text_muted}; background: transparent;") + self.connection_dot.setStyleSheet(f"color: {t.success}; font-size: 11pt;") + self._logout_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.error}; border: none; " + f"border-radius: 16px; font-size: 13pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.error}; color: {t.accent_text}; }}" + ) + self.delete_conv_btn.setStyleSheet( + f"QPushButton {{ background: transparent; border: none; border-radius: 4px; padding: 4px; }}" + f"QPushButton:hover {{ background-color: {t.error}; }}" + ) + # Search bar + self.search_input.setStyleSheet( + f"QLineEdit {{ background-color: {t.bg_secondary}; color: {t.text_primary}; " + f"border: 1px solid {t.border}; border-radius: 4px; padding: 4px 8px; font-size: 10pt; }}" + ) + self.search_count_label.setStyleSheet(f"color: {t.text_muted}; font-size: 9pt; min-width: 40px;") + # Pin banner + self._pin_banner.setStyleSheet(f"background-color:{t.border}; border-bottom:2px solid {t.pin_color};") + self._pin_banner_label.setStyleSheet(f"color:{t.text_primary}; font-size:9pt; background:transparent; border:none;") + # Jump button + self.jump_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.accent}; color: {t.accent_text}; border-radius: 18px; " + f"font-size: 14pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.accent_hover}; }}" + ) + # Reply label + self.reply_label.setStyleSheet( + f"color: {t.accent}; font-style: italic; font-size: 9pt; " + f"padding: 2px 4px; background: transparent;" + ) + # Input area + self._attach_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.bg_secondary}; border: none; " + f"border-radius: 20px; font-size: 14pt; }}" + f"QPushButton:hover {{ background-color: {t.bg_hover}; }}" + ) + self._send_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.accent}; color: {t.accent_text}; " + f"border: none; border-radius: 20px; font-size: 14pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.accent_hover}; }}" + ) + self.msg_input.setStyleSheet(self.msg_input._style_normal()) + # Counters + self.char_counter.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt; padding: 0 4px;") + self.reencrypt_label.setStyleSheet( + f"background-color: {t.bg_secondary}; border-radius: 6px; " + f"padding: 8px 12px; color: {t.success}; font-weight: bold;" + ) + # Status bar + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.success}; font-size: 8pt;" + ) + # Privacy overlay + if hasattr(self, "_privacy_overlay"): + self._privacy_overlay.setStyleSheet(f"background-color: {t.overlay};") + self._lock_hint.setStyleSheet(f"font-size: 12pt; color: {t.text_muted}; background: transparent;") + self._lock_input.setStyleSheet( + f"QLineEdit {{ font-size: 11pt; background-color: {t.bg_secondary}; " + f"border: 1px solid {t.border}; border-radius: 6px; padding: 8px; " + f"color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + self._lock_error.setStyleSheet(f"font-size: 9pt; color: {t.error}; background: transparent;") + # Mention popup + if hasattr(self, "_mention_popup"): + self._mention_popup.setStyleSheet( + f"QListWidget {{ background:{t.bg_secondary}; color:{t.text_primary}; border:1px solid {t.border}; " + f"border-radius:4px; font-size:10pt; }}" + f"QListWidget::item {{ padding:4px 8px; }}" + f"QListWidget::item:selected {{ background:{t.border}; }}" + ) + # Message scroll area + if hasattr(self, '_msg_scroll_area'): + self._msg_scroll_area.setStyleSheet( + f"QScrollArea {{ background-color: {t.bg_primary}; border: none; }}" + ) + self._msg_container.setStyleSheet( + f"background-color: {t.bg_primary};" + ) + # Re-render messages and conversation list + self._rebuild_conv_list() + if self.current_messages: + self._render_messages(scroll_to_bottom=False) + + # -- Privacy Overlay (lock screen) ---------------------------------------- + + _LOCK_TIMEOUT_MS = 30_000 # 30 s unfocused → lock (require password) + + def _setup_privacy_overlay(self): + """Create overlay that hides content on focus loss; locks after timeout.""" + self._privacy_locked = False + # Check if identity key is password-encrypted (ECP1 format) + # If not, lock feature is disabled (no password to verify against) + self._lock_capable = False + try: + from chat_core import get_key_dir + key_path = get_key_dir(self.bridge.client.email) / "identity_private.bin" + if key_path.exists(): + self._lock_capable = key_path.read_bytes()[:4] == b"ECP1" + except Exception: + pass + + t = c() + # -- overlay widget -- + self._privacy_overlay = QWidget(self) + self._privacy_overlay.setStyleSheet( + f"background-color: {t.overlay};" + ) + self._privacy_overlay.hide() + + overlay_layout = QVBoxLayout(self._privacy_overlay) + overlay_layout.setAlignment(Qt.AlignmentFlag.AlignCenter) + + lock_icon = QLabel("\U0001f512") + lock_icon.setStyleSheet("font-size: 36pt; background: transparent;") + lock_icon.setAlignment(Qt.AlignmentFlag.AlignCenter) + overlay_layout.addWidget(lock_icon) + + self._lock_hint = QLabel("Encrypted Chat") + self._lock_hint.setStyleSheet( + f"font-size: 12pt; color: {t.text_muted}; background: transparent;" + ) + self._lock_hint.setAlignment(Qt.AlignmentFlag.AlignCenter) + overlay_layout.addWidget(self._lock_hint) + + # Password input (hidden until locked) + self._lock_input = QLineEdit() + self._lock_input.setPlaceholderText("Enter password to unlock") + self._lock_input.setEchoMode(QLineEdit.EchoMode.Password) + self._lock_input.setMaximumWidth(280) + self._lock_input.setStyleSheet( + f"QLineEdit {{ font-size: 11pt; background-color: {t.bg_secondary}; " + f"border: 1px solid {t.border}; border-radius: 6px; padding: 8px; " + f"color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + self._lock_input.returnPressed.connect(self._on_unlock_attempt) + self._lock_input.hide() + overlay_layout.addWidget(self._lock_input, alignment=Qt.AlignmentFlag.AlignCenter) + + self._lock_error = QLabel("") + self._lock_error.setStyleSheet( + f"font-size: 9pt; color: {t.error}; background: transparent;" + ) + self._lock_error.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._lock_error.hide() + overlay_layout.addWidget(self._lock_error) + + # Timer: after N seconds unfocused → require password + self._lock_timer = QTimer(self) + self._lock_timer.setSingleShot(True) + self._lock_timer.timeout.connect(self._on_lock_timeout) + + def _toggle_privacy(self): + """Toggle privacy overlay on/off (Ctrl+Shift+P).""" + self._privacy_enabled = not self._privacy_enabled + if not self._privacy_enabled: + self._privacy_locked = False + self._lock_timer.stop() + self._hide_privacy_overlay() + state = "ON" if self._privacy_enabled else "OFF" + base_title = f"Encrypted Chat - {self.bridge.client.username}" + self.setWindowTitle(f"{base_title} [Privacy: {state}]") + QTimer.singleShot(2000, lambda: self.setWindowTitle(base_title)) + + def _show_privacy_overlay(self): + if not self._privacy_enabled: + return + if not self._privacy_overlay.isVisible(): + self._privacy_overlay.setGeometry(self.rect()) + self._privacy_overlay.raise_() + self._privacy_overlay.show() + # Start lock countdown + self._lock_timer.start(self._LOCK_TIMEOUT_MS) + + def _hide_privacy_overlay(self): + self._lock_timer.stop() + self._lock_input.hide() + self._lock_input.clear() + self._lock_error.hide() + self._lock_hint.setText("Encrypted Chat") + if self._privacy_overlay.isVisible(): + self._privacy_overlay.hide() + + def _on_lock_timeout(self): + """Window unfocused too long — require password.""" + if self._privacy_overlay.isVisible() and self._lock_capable: + self._privacy_locked = True + self._lock_hint.setText("Session locked") + self._lock_input.show() + self._lock_input.setFocus() + + def _on_unlock_attempt(self): + """Verify password by decrypting identity key from disk.""" + from chat_core import get_key_dir, _check_lockout, _record_failed_attempt, _clear_lockout + from crypto_utils import _decrypt_private_key + pwd = self._lock_input.text() + if not pwd: + return + email = self.bridge.client.email + remaining = _check_lockout(email) + if remaining > 0: + self._lock_error.setText(f"Too many attempts. Wait {remaining:.0f}s.") + self._lock_error.show() + self._lock_input.clear() + self._lock_input.setFocus() + return + try: + key_path = get_key_dir(email) / "identity_private.bin" + data = key_path.read_bytes() + _decrypt_private_key(data, pwd.encode("utf-8")) + # Success — unlock + _clear_lockout(email) + self._privacy_locked = False + self._hide_privacy_overlay() + except Exception: + _record_failed_attempt(email) + remaining = _check_lockout(email) + if remaining > 0: + self._lock_error.setText(f"Wrong password. Wait {remaining:.0f}s.") + else: + self._lock_error.setText("Wrong password") + self._lock_error.show() + self._lock_input.clear() + self._lock_input.setFocus() + + def changeEvent(self, event): + """Handle window state changes — tray minimize + privacy overlay.""" + from PyQt6.QtCore import QEvent + if event.type() == QEvent.Type.WindowStateChange: + if self.isMinimized() and self._tray_icon is not None: + event.ignore() + self.hide() + return + if event.type() == QEvent.Type.ActivationChange: + if self.isActiveWindow(): + if not self._privacy_locked: + self._hide_privacy_overlay() + else: + # Locked — keep overlay, focus password input + self._lock_input.setFocus() + else: + self._show_privacy_overlay() + super().changeEvent(event) + + def resizeEvent(self, event): + """Keep privacy overlay sized to window.""" + super().resizeEvent(event) + if hasattr(self, "_privacy_overlay"): + self._privacy_overlay.setGeometry(self.rect()) + + # -- System Tray ---------------------------------------------------------- + + def _make_tray_icon(self) -> QIcon: + """Create a simple app icon (blue chat bubble) for the system tray.""" + px = QPixmap(64, 64) + px.fill(QColor(0, 0, 0, 0)) + p = QPainter(px) + p.setRenderHint(QPainter.RenderHint.Antialiasing) + p.setBrush(QBrush(QColor(c().accent))) + p.setPen(Qt.PenStyle.NoPen) + p.drawRoundedRect(4, 4, 56, 48, 12, 12) + # small triangle (tail) + from PyQt6.QtGui import QPolygon + from PyQt6.QtCore import QPoint + p.drawPolygon(QPolygon([QPoint(14, 52), QPoint(24, 44), QPoint(30, 52)])) + # lock icon (E2E indicator) + p.setBrush(QBrush(QColor(c().accent_text))) + p.drawRoundedRect(24, 16, 16, 14, 3, 3) + p.setPen(QPen(QColor(c().accent_text), 3)) + p.setBrush(Qt.BrushStyle.NoBrush) + p.drawArc(27, 10, 10, 12, 0, 180 * 16) + p.end() + return QIcon(px) + + def _setup_tray_icon(self): + """Initialize the system tray icon with context menu.""" + if not QSystemTrayIcon.isSystemTrayAvailable(): + self._tray_icon = None + return + self._tray_icon = QSystemTrayIcon(self._make_tray_icon(), self) + self._tray_icon.setToolTip(f"Encrypted Chat - {self.bridge.client.username}") + self._tray_icon.activated.connect(self._on_tray_activated) + + tray_menu = QMenu() + show_action = tray_menu.addAction("Show") + show_action.triggered.connect(self._restore_from_tray) + tray_menu.addSeparator() + quit_action = tray_menu.addAction("Quit") + quit_action.triggered.connect(self._quit_from_tray) + self._tray_icon.setContextMenu(tray_menu) + self._tray_icon.show() + + def _on_tray_activated(self, reason): + """Handle tray icon click — restore window on double-click or single click.""" + if reason in (QSystemTrayIcon.ActivationReason.Trigger, + QSystemTrayIcon.ActivationReason.DoubleClick): + self._restore_from_tray() + + def _restore_from_tray(self): + """Restore window from system tray.""" + self.showNormal() + self.activateWindow() + self.raise_() + + def _quit_from_tray(self): + """Quit the application from tray menu.""" + if self._tray_icon: + self._tray_icon.hide() + self._is_logout = False + self.close() + + def _show_tray_notification(self, title: str, text: str): + """Show a system tray toast notification if the window is not in the foreground.""" + if not self._tray_icon: + logger.debug("Tray notification skipped: no tray icon") + return + if self.isVisible() and self.isActiveWindow() and not self._privacy_locked: + return # user is looking at the app (and it's not locked) + if len(text) > 120: + text = text[:117] + "..." + logger.info("Tray notification: %s — %s", title, text[:50]) + self._tray_icon.showMessage( + title, text, + QSystemTrayIcon.MessageIcon.Information, 4000, + ) + + # -- End Tray --------------------------------------------------------------- + + def _bold_font(self) -> QFont: + """Return a bold font with a valid size (avoids QFont pointSize=-1 warnings).""" + f = QFont(self.conv_list.font()) + f.setBold(True) + # Stylesheet sets font-size in px so pointSize is -1; fix by using pixelSize + if f.pointSize() <= 0: + px = f.pixelSize() + if px > 0: + f.setPixelSize(px) + else: + f.setPointSize(10) + return f + + def _make_circular_avatar(self, pixmap: QPixmap, size: int = 32) -> QPixmap: + """Crop a pixmap into a circle.""" + scaled = pixmap.scaled(size, size, Qt.AspectRatioMode.KeepAspectRatioByExpanding, + Qt.TransformationMode.SmoothTransformation) + result = QPixmap(size, size) + result.fill(QColor(0, 0, 0, 0)) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + painter.setBrush(QBrush(scaled)) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(0, 0, size, size) + painter.end() + return result + + def _make_default_avatar(self, username: str, size: int = 32) -> QPixmap: + """Generate a colored circle with the first letter of the username.""" + # Deterministic color from username — higher saturation in light mode + hue = (hash(username) % 360) + sat = 160 if not tm().is_dark else 120 + val = 180 if not tm().is_dark else 200 + color = QColor.fromHsv(hue, sat, val) + result = QPixmap(size, size) + result.fill(QColor(0, 0, 0, 0)) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + painter.setBrush(QBrush(color)) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(0, 0, size, size) + # Draw letter + painter.setPen(QColor(255, 255, 255)) + font = QFont("Segoe UI Variable", int(size * 0.4)) + if not font.exactMatch(): + font = QFont("Segoe UI", int(size * 0.4)) + font.setBold(True) + painter.setFont(font) + letter = username[0].upper() if username else "?" + painter.drawText(0, 0, size, size, Qt.AlignmentFlag.AlignCenter, letter) + painter.end() + return result + + def _add_online_dot(self, avatar: QPixmap) -> QPixmap: + """Overlay a green dot on the bottom-right of an avatar pixmap.""" + result = QPixmap(avatar) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + dot_size = max(8, avatar.width() // 4) + x = avatar.width() - dot_size + y = avatar.height() - dot_size + # Border ring (matches sidebar background) + painter.setBrush(QBrush(QColor(c().online_dot_border))) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(x - 1, y - 1, dot_size + 2, dot_size + 2) + # Green dot + painter.setBrush(QBrush(QColor(c().online_dot))) + painter.drawEllipse(x, y, dot_size, dot_size) + painter.end() + return result + + def _get_conv_avatar(self, conv: dict) -> QIcon: + """Get avatar icon for a conversation list item.""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + if is_dm: + other = None + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + other = m + break + if other: + uid = other.get("user_id") or other.get("id") or "" + uname = other.get("username") or other.get("email") or "?" + if uid in self._avatar_cache: + avatar = self._make_circular_avatar(self._avatar_cache[uid]) + else: + avatar = self._make_default_avatar(uname) + # Request avatar download if not yet requested + if uid and uid not in self._avatar_requested: + self._avatar_requested.add(uid) + self.bridge.get_avatar(uid) + if uid in self._online_users: + avatar = self._add_online_dot(avatar) + return QIcon(avatar) + # Group: use group avatar if available + conv_id = conv.get("conversation_id") or "" + if conv_id in self._group_avatar_cache: + return QIcon(self._make_circular_avatar(self._group_avatar_cache[conv_id])) + gname = conv.get("name") or "G" + # Request group avatar download if has avatar_file + if conv.get("avatar_file") and conv_id and conv_id not in self._group_avatar_requested: + self._group_avatar_requested.add(conv_id) + self.bridge.get_group_avatar(conv_id) + return QIcon(self._make_default_avatar(gname)) + + def _get_conv_avatar_pixmap(self, conv: dict, size: int = 44) -> QPixmap: + """Get avatar QPixmap for delegate painting.""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + if is_dm: + other = None + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + other = m + break + if other: + uid = other.get("user_id") or other.get("id") or "" + uname = other.get("username") or other.get("email") or "?" + if uid in self._avatar_cache: + avatar = self._make_circular_avatar(self._avatar_cache[uid], size) + else: + avatar = self._make_default_avatar(uname, size) + if uid and uid not in self._avatar_requested: + self._avatar_requested.add(uid) + self.bridge.get_avatar(uid) + if uid in self._online_users: + avatar = self._add_online_dot(avatar) + return avatar + conv_id = conv.get("conversation_id") or "" + if conv_id in self._group_avatar_cache: + return self._make_circular_avatar(self._group_avatar_cache[conv_id], size) + gname = conv.get("name") or "G" + if conv.get("avatar_file") and conv_id and conv_id not in self._group_avatar_requested: + self._group_avatar_requested.add(conv_id) + self.bridge.get_group_avatar(conv_id) + return self._make_default_avatar(gname, size) + + def _build_ui(self): + main_layout = QHBoxLayout(self) + main_layout.setContentsMargins(0, 0, 0, 0) + main_layout.setSpacing(0) + + splitter = QSplitter(Qt.Orientation.Horizontal) + + # Left panel - conversations + left = QWidget() + left_layout = QVBoxLayout(left) + left_layout.setContentsMargins(8, 8, 4, 8) + + t = c() + left.setObjectName("sidebarPanel") + left.setStyleSheet(f"#sidebarPanel {{ background-color: {t.bg_tertiary}; }}") + self._sidebar_panel = left + + header_row = QHBoxLayout() + self._conv_label = QLabel("Conversations") + self._conv_label.setStyleSheet(f"font-weight: bold; font-size: 12pt; color: {t.accent}; background: transparent;") + header_row.addWidget(self._conv_label) + header_row.addStretch() + + new_chat_btn = QPushButton("") + new_chat_btn.setFixedSize(32, 32) + new_chat_btn.setObjectName("toolBtn") + new_chat_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogNewFolder)) + new_chat_btn.setToolTip("New Chat") + new_chat_btn.clicked.connect(self._on_new_chat) + header_row.addWidget(new_chat_btn) + + group_btn = QPushButton("") + group_btn.setFixedSize(32, 32) + group_btn.setObjectName("toolBtn") + group_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DirIcon)) + group_btn.setToolTip("New Group") + group_btn.clicked.connect(self._on_new_group) + header_row.addWidget(group_btn) + + profile_btn = QPushButton("") + profile_btn.setFixedSize(32, 32) + profile_btn.setObjectName("toolBtn") + profile_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogInfoView)) + profile_btn.setToolTip("My Profile") + profile_btn.clicked.connect(self._on_my_profile) + header_row.addWidget(profile_btn) + + left_layout.addLayout(header_row) + + # Invitation section (hidden when empty) + self.inv_label = QLabel("Pending Invitations") + self.inv_label.setStyleSheet(f"font-weight: bold; font-size: 9pt; color: {t.warning}; margin-top: 4px;") + self.inv_label.setVisible(False) + left_layout.addWidget(self.inv_label) + + self.inv_list = QListWidget() + self.inv_list.setMaximumHeight(120) + self.inv_list.setVisible(False) + self.inv_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.inv_list.customContextMenuRequested.connect(self._on_inv_context_menu) + self.inv_list.setStyleSheet( + f"QListWidget {{ background-color: {t.bg_primary}; border: 1px solid {t.warning}; border-radius: 6px; padding: 2px; }}" + f"QListWidget::item {{ padding: 6px; color: {t.text_primary}; }}" + f"QListWidget::item:hover {{ background-color: {t.bg_hover}; color: {t.text_primary}; }}" + ) + left_layout.addWidget(self.inv_list) + + self.conv_list = QListWidget() + self.conv_list.setIconSize(QSize(44, 44)) + # Override global QSS item styles — delegate handles all painting + self.conv_list.setStyleSheet( + f"QListWidget {{ background-color: {t.bg_tertiary}; border: none; padding: 0px; }}" + f"QListWidget::item {{ padding: 0px; border: none; }}" + f"QListWidget::item:selected {{ background: transparent; border: none; }}" + f"QListWidget::item:hover {{ background: transparent; }}" + ) + self._conv_delegate = ConversationDelegate(self.conv_list) + self.conv_list.setItemDelegate(self._conv_delegate) + self.conv_list.setMouseTracking(True) # for hover painting + self.conv_list.currentRowChanged.connect(self._on_conv_selected) + self.conv_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.conv_list.customContextMenuRequested.connect(self._on_conv_list_context_menu) + left_layout.addWidget(self.conv_list) + + # Bottom toolbar row: settings + logout + bottom_row = QHBoxLayout() + bottom_row.setContentsMargins(0, 4, 0, 0) + + settings_btn = QPushButton("\u2699 Settings") + settings_btn.setObjectName("sidebarBtn") + settings_btn.setToolTip("Open settings") + settings_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.text_secondary}; border: none; border-radius: 6px; padding: 8px 16px; }}" + f"QPushButton:hover {{ background-color: {t.bg_hover}; }}" + ) + settings_btn.clicked.connect(self._on_open_settings) + self._settings_btn = settings_btn + bottom_row.addWidget(settings_btn) + + logout_btn = QPushButton("\u2715") + logout_btn.setFixedSize(32, 32) + logout_btn.setStyleSheet( + f"QPushButton {{ background: transparent; color: {t.error}; border: none; " + f"border-radius: 16px; font-size: 13pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.error}; color: {t.accent_text}; }}" + ) + logout_btn.setToolTip("Logout") + logout_btn.clicked.connect(self._on_logout) + self._logout_btn = logout_btn + bottom_row.addWidget(logout_btn) + + left_layout.addLayout(bottom_row) + + # Right panel - messages + right = QWidget() + right_layout = QVBoxLayout(right) + right_layout.setContentsMargins(4, 8, 8, 8) + + # Chat header bar (56px height) + chat_header_widget = QWidget() + chat_header_widget.setFixedHeight(56) + self._chat_header_widget = chat_header_widget + chat_header_widget.setObjectName("chatHeader") + chat_header_widget.setStyleSheet( + f"#chatHeader {{ border-bottom: 1px solid {t.separator}; }}" + f"#chatHeader QLabel, #chatHeader QPushButton {{ border: none; }}" + ) + chat_header_row = QHBoxLayout(chat_header_widget) + chat_header_row.setContentsMargins(12, 4, 8, 4) + + self.chat_header_avatar = QLabel() + self.chat_header_avatar.setFixedSize(40, 40) + self.chat_header_avatar.setStyleSheet("background: transparent;") + self.chat_header_avatar.setVisible(False) + chat_header_row.addWidget(self.chat_header_avatar) + + # Name + status text vertical stack + name_status_layout = QVBoxLayout() + name_status_layout.setSpacing(0) + name_status_layout.setContentsMargins(6, 0, 0, 0) + self.chat_header = QLabel("Select a conversation") + self.chat_header.setStyleSheet(f"font-weight: bold; font-size: 12pt; color: {t.accent};") + name_status_layout.addWidget(self.chat_header) + + self._chat_header_status = QLabel("") + self._chat_header_status.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt;") + self._chat_header_status.setVisible(False) + name_status_layout.addWidget(self._chat_header_status) + chat_header_row.addLayout(name_status_layout) + + # E2E lock indicator + self._e2e_label = QLabel("\U0001f512 End-to-end encrypted") + self._e2e_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._e2e_label.setToolTip("End-to-end encrypted") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.text_muted}; background: transparent;") + self._e2e_label.setCursor(Qt.CursorShape.PointingHandCursor) + self._e2e_label.mousePressEvent = self._on_e2e_label_clicked + self._e2e_label.setVisible(False) + chat_header_row.addWidget(self._e2e_label) + + self.connection_dot = QLabel("\u25cf") + self.connection_dot.setFixedSize(16, 16) + self.connection_dot.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.connection_dot.setStyleSheet(f"color: {t.success}; font-size: 11pt;") + self.connection_dot.setToolTip("Connected") + chat_header_row.addWidget(self.connection_dot) + chat_header_row.addStretch() + + self.group_info_btn = QPushButton("") + self.group_info_btn.setFixedSize(32, 32) + self.group_info_btn.setObjectName("toolBtn") + self.group_info_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_MessageBoxInformation)) + self.group_info_btn.setToolTip("Group Info") + self.group_info_btn.clicked.connect(self._on_group_info) + self.group_info_btn.setVisible(False) + chat_header_row.addWidget(self.group_info_btn) + + self.user_info_btn = QPushButton("") + self.user_info_btn.setFixedSize(32, 32) + self.user_info_btn.setObjectName("toolBtn") + self.user_info_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogInfoView)) + self.user_info_btn.setToolTip("User Info") + self.user_info_btn.clicked.connect(self._on_dm_user_info) + self.user_info_btn.setVisible(False) + chat_header_row.addWidget(self.user_info_btn) + + self.delete_conv_btn = QPushButton("") + self.delete_conv_btn.setFixedSize(32, 32) + self.delete_conv_btn.setStyleSheet( + f"QPushButton {{ background: transparent; border: none; border-radius: 4px; padding: 4px; }}" + f"QPushButton:hover {{ background-color: {t.error}; }}" + ) + self.delete_conv_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_TrashIcon)) + self.delete_conv_btn.setToolTip("Delete conversation") + self.delete_conv_btn.clicked.connect(self._on_delete_conv_btn) + self.delete_conv_btn.setVisible(False) + chat_header_row.addWidget(self.delete_conv_btn) + + self.add_member_btn = QPushButton("") + self.add_member_btn.setFixedSize(32, 32) + self.add_member_btn.setObjectName("toolBtn") + self.add_member_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogNewFolder)) + self.add_member_btn.setToolTip("Add Member") + self.add_member_btn.clicked.connect(self._on_add_member) + self.add_member_btn.setVisible(False) + chat_header_row.addWidget(self.add_member_btn) + + self.pin_list_btn = QPushButton("\U0001f4cc") + self.pin_list_btn.setFixedSize(32, 32) + self.pin_list_btn.setObjectName("toolBtn") + self.pin_list_btn.setToolTip("Pinned messages") + self.pin_list_btn.clicked.connect(self._show_pinned_messages) + self.pin_list_btn.setVisible(False) + chat_header_row.addWidget(self.pin_list_btn) + + self.search_btn = QPushButton("") + self.search_btn.setFixedSize(32, 32) + self.search_btn.setObjectName("toolBtn") + self.search_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogContentsView)) + self.search_btn.setToolTip("Search messages (Ctrl+F)") + self.search_btn.clicked.connect(self._toggle_search) + self.search_btn.setVisible(False) + chat_header_row.addWidget(self.search_btn) + + right_layout.addWidget(chat_header_widget) + + # Search bar (hidden by default) + self.search_widget = QWidget() + search_row = QHBoxLayout(self.search_widget) + search_row.setContentsMargins(0, 2, 0, 2) + self.search_input = QLineEdit() + self.search_input.setPlaceholderText("Search messages...") + self.search_input.setStyleSheet( + f"QLineEdit {{ background-color: {t.bg_secondary}; color: {t.text_primary}; " + f"border: 1px solid {t.border}; border-radius: 4px; padding: 4px 8px; font-size: 10pt; }}" + ) + self.search_input.textChanged.connect(self._on_search_text_changed) + self.search_input.returnPressed.connect(self._on_search_next) + # Escape in search input closes search + QShortcut(QKeySequence("Escape"), self.search_input).activated.connect(self._close_search) + search_row.addWidget(self.search_input, stretch=1) + self.search_prev_btn = QPushButton("\u25b2") + self.search_prev_btn.setFixedSize(28, 28) + self.search_prev_btn.setObjectName("toolBtn") + self.search_prev_btn.setToolTip("Previous match") + self.search_prev_btn.clicked.connect(self._on_search_prev) + search_row.addWidget(self.search_prev_btn) + self.search_next_btn = QPushButton("\u25bc") + self.search_next_btn.setFixedSize(28, 28) + self.search_next_btn.setObjectName("toolBtn") + self.search_next_btn.setToolTip("Next match") + self.search_next_btn.clicked.connect(self._on_search_next) + search_row.addWidget(self.search_next_btn) + self.search_count_label = QLabel("0/0") + self.search_count_label.setStyleSheet(f"color: {t.text_muted}; font-size: 9pt; min-width: 40px;") + self.search_count_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + search_row.addWidget(self.search_count_label) + self.search_close_btn = QPushButton("\u2715") + self.search_close_btn.setFixedSize(28, 28) + self.search_close_btn.setObjectName("toolBtn") + self.search_close_btn.setToolTip("Close search") + self.search_close_btn.clicked.connect(self._close_search) + search_row.addWidget(self.search_close_btn) + self.search_widget.setVisible(False) + right_layout.addWidget(self.search_widget) + + # --- Pinned message banner --- + self._pin_banner = QWidget() + self._pin_banner.setStyleSheet( + f"background-color:{t.border}; border-bottom:2px solid {t.pin_color};" + ) + pin_banner_layout = QHBoxLayout(self._pin_banner) + pin_banner_layout.setContentsMargins(10, 6, 10, 6) + pin_icon = QLabel("\U0001f4cc") + pin_icon.setStyleSheet("font-size:12pt; background:transparent; border:none;") + pin_banner_layout.addWidget(pin_icon) + self._pin_banner_label = QLabel("") + self._pin_banner_label.setStyleSheet( + f"color:{t.text_primary}; font-size:9pt; background:transparent; border:none;" + ) + self._pin_banner_label.setCursor(Qt.CursorShape.PointingHandCursor) + self._pin_banner_label.setWordWrap(False) + pin_banner_layout.addWidget(self._pin_banner_label, stretch=1) + pin_banner_close = QPushButton("\u2715") + pin_banner_close.setFixedSize(20, 20) + pin_banner_close.setStyleSheet( + f"QPushButton {{ background:transparent; color:{t.text_muted}; border:none; font-size:10pt; }}" + f"QPushButton:hover {{ color:{t.text_primary}; }}" + ) + pin_banner_close.clicked.connect(lambda: self._pin_banner.setVisible(False)) + pin_banner_layout.addWidget(pin_banner_close) + self._pin_banner.setVisible(False) + self._pin_banner.setCursor(Qt.CursorShape.PointingHandCursor) + self._pin_banner.mousePressEvent = self._on_pin_banner_clicked + self._pin_banner_msg_id = None + right_layout.addWidget(self._pin_banner) + + self.load_more_btn = QPushButton("Load older messages") + self.load_more_btn.setObjectName("secondaryBtn") + self.load_more_btn.clicked.connect(self._on_load_more) + self.load_more_btn.setVisible(False) + right_layout.addWidget(self.load_more_btn) + + # Message display area — QScrollArea with widget-based bubbles + self._msg_scroll_area = QScrollArea() + self._msg_scroll_area.setWidgetResizable(True) + self._msg_scroll_area.setHorizontalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAlwaysOff + ) + self._msg_scroll_area.setStyleSheet( + f"QScrollArea {{ background-color: {t.bg_primary}; border: none; }}" + ) + self._msg_scroll_area.setAcceptDrops(True) + self._msg_scroll_area.installEventFilter(self) + self._msg_scroll_area.viewport().setContextMenuPolicy( + Qt.ContextMenuPolicy.NoContextMenu + ) + self._msg_container = QWidget() + self._msg_container.setStyleSheet(f"background-color: {t.bg_primary};") + self._msg_layout = QVBoxLayout(self._msg_container) + self._msg_layout.setAlignment(Qt.AlignmentFlag.AlignTop) + self._msg_layout.setContentsMargins(0, 8, 0, 8) + self._msg_layout.setSpacing(2) + self._msg_scroll_area.setWidget(self._msg_container) + self._msg_widgets = [] + self.message_area = self._msg_scroll_area # alias for scroll/jump/drop + right_layout.addWidget(self._msg_scroll_area, stretch=1) + + # Smart scroll: track if user is near bottom + self._is_near_bottom = True + self._msg_scroll_area.verticalScrollBar().valueChanged.connect( + self._on_scroll_changed + ) + + # Scroll-to-bottom floating button (hidden by default) + self.jump_btn = QPushButton("\u2193") + self.jump_btn.setParent(self.message_area) + self.jump_btn.setVisible(False) + self.jump_btn.setFixedSize(36, 36) + self.jump_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.accent}; color: {t.accent_text}; border-radius: 18px; " + f"font-size: 14pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.accent_hover}; }}" + ) + self.jump_btn.clicked.connect(self._scroll_to_bottom) + + # Reply preview (above input, blue left bar) + self._reply_widget = QWidget() + self._reply_widget.setVisible(False) + reply_row = QHBoxLayout(self._reply_widget) + reply_row.setContentsMargins(8, 4, 8, 0) + reply_row.setSpacing(4) + reply_bar = QFrame() + reply_bar.setFixedWidth(3) + reply_bar.setStyleSheet(f"background-color: {t.accent}; border-radius: 1px;") + reply_row.addWidget(reply_bar) + self.reply_label = QLabel("") + self.reply_label.setStyleSheet( + f"color: {t.accent}; font-style: italic; font-size: 9pt; " + f"padding: 2px 4px; background: transparent;" + ) + self.reply_label.setWordWrap(True) + reply_row.addWidget(self.reply_label, stretch=1) + reply_dismiss = QPushButton("\u2715") + reply_dismiss.setFixedSize(20, 20) + reply_dismiss.setStyleSheet( + f"QPushButton {{ background:transparent; color:{t.text_muted}; border:none; font-size:10pt; }}" + f"QPushButton:hover {{ color:{t.text_primary}; }}" + ) + reply_dismiss.clicked.connect(self._cancel_reply) + reply_row.addWidget(reply_dismiss) + right_layout.addWidget(self._reply_widget) + + # Input row: [attach] [input] [send] + input_row = QHBoxLayout() + input_row.setSpacing(8) + input_row.setContentsMargins(8, 4, 8, 4) + + self._attach_btn = QPushButton("\U0001f4ce") + attach_btn = self._attach_btn + attach_btn.setFixedSize(40, 40) + attach_btn.setCursor(Qt.CursorShape.PointingHandCursor) + attach_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.bg_secondary}; border: none; " + f"border-radius: 20px; font-size: 14pt; }}" + f"QPushButton:hover {{ background-color: {t.bg_hover}; }}" + ) + self._attach_menu = QMenu(attach_btn) + self._attach_menu.addAction("\U0001f5bc Image", self._on_attach_image) + self._attach_menu.addAction("\U0001f4c4 File", self._on_attach_file) + attach_btn.clicked.connect(lambda: self._attach_menu.exec( + attach_btn.mapToGlobal(attach_btn.rect().topLeft() - QPoint(0, self._attach_menu.sizeHint().height())) + )) + input_row.addWidget(attach_btn) + + self.msg_input = MessageInput() + self.msg_input.send_requested.connect(self._on_send) + self.msg_input.textChanged.connect(self._on_input_changed) + self.msg_input.file_dropped.connect(self._on_file_dropped) + input_row.addWidget(self.msg_input, stretch=1) + + self._send_btn = QPushButton("\u27a4") + send_btn = self._send_btn + send_btn.setFixedSize(40, 40) + send_btn.setCursor(Qt.CursorShape.PointingHandCursor) + send_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.accent}; color: {t.accent_text}; " + f"border: none; border-radius: 20px; font-size: 14pt; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.accent_hover}; }}" + ) + send_btn.clicked.connect(self._on_send) + input_row.addWidget(send_btn) + right_layout.addLayout(input_row) + + self.char_counter = QLabel(f"0 / {MAX_INPUT_CHARS}") + self.char_counter.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt; padding: 0 4px;") + self.char_counter.setAlignment(Qt.AlignmentFlag.AlignRight) + right_layout.addWidget(self.char_counter) + + self.reencrypt_label = QLabel("") + self.reencrypt_label.setStyleSheet( + f"background-color: {t.bg_secondary}; border-radius: 6px; " + f"padding: 8px 12px; color: {t.success}; font-weight: bold;" + ) + self.reencrypt_label.setVisible(False) + right_layout.addWidget(self.reencrypt_label) + + splitter.addWidget(left) + splitter.addWidget(right) + splitter.setStretchFactor(0, 1) + splitter.setStretchFactor(1, 3) + + # Wrap splitter + status bar in vertical layout for full-width status bar + wrapper = QVBoxLayout() + wrapper.setContentsMargins(0, 0, 0, 0) + wrapper.setSpacing(0) + wrapper.addWidget(splitter) + + # Status bar (permanent, fixed height, full width — no layout jumping) + self.status_bar = QLabel("") + self.status_bar.setFixedHeight(24) + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.success}; font-size: 8pt;" + ) + self.status_bar.setCursor(Qt.CursorShape.PointingHandCursor) + self.status_bar.mousePressEvent = self._on_status_bar_click + self._status_bar_conv_id = None + wrapper.addWidget(self.status_bar) + + main_layout.addLayout(wrapper) + + def _connect_signals(self): + self.bridge.conversations_loaded.connect(self._on_conversations_loaded) + self.bridge.messages_loaded.connect(self._on_messages_loaded) + self.bridge.older_messages_loaded.connect(self._on_older_messages_loaded) + self.bridge.message_sent.connect(self._on_message_sent) + self.bridge.message_sent_payload.connect(self._on_message_sent_payload) + self.bridge.new_notification.connect(self._on_notification) + self.bridge.add_member_result.connect(self._on_add_member_result) + self.bridge.authorize_result.connect(self._on_authorize_result) + self.bridge.rotate_result.connect(self._on_rotate_result) + self.bridge.password_changed.connect(self._on_password_changed) + self.bridge.username_changed.connect(self._on_username_changed) + self.bridge.reencrypt_status.connect(self._on_reencrypt_status) + self.bridge.messages_read_notification.connect(self._on_messages_read) + self.bridge.message_delivered_notification.connect(self._on_message_delivered) + self.bridge.remove_member_result.connect(self._on_remove_member_result) + self.bridge.message_deleted_notification.connect(self._on_message_deleted) + self.bridge.delete_message_result.connect(self._on_delete_message_result) + self.bridge.image_sent.connect(self._on_image_sent) + self.bridge.image_downloaded.connect(self._on_image_downloaded) + self.bridge.file_sent.connect(self._on_file_sent) + self.bridge.file_downloaded.connect(self._on_file_downloaded) + self.bridge.conversation_updated.connect(self._on_conversation_updated) + self.bridge.connection_state_changed.connect(self._on_connection_state_changed) + self.bridge.group_left.connect(self._on_group_left) + self.bridge.group_renamed.connect(self._on_group_renamed) + self.bridge.conversation_deleted.connect(self._on_conversation_deleted) + self.bridge.avatar_loaded.connect(self._on_avatar_for_conv_list) + self.bridge.invitations_loaded.connect(self._on_invitations_loaded) + self.bridge.invitation_result.connect(self._on_invitation_result) + self.bridge.invitation_received.connect(self._on_invitation_received) + self.bridge.online_status_changed.connect(self._on_online_status_changed) + self.bridge.online_users_loaded.connect(self._on_online_users_loaded) + self.bridge.group_avatar_loaded.connect(self._on_group_avatar_for_conv_list) + self.bridge.group_avatar_updated.connect(self._on_group_avatar_updated) + self.bridge.session_reset_notification.connect(self._on_session_reset) + self.bridge.reaction_notification.connect(self._on_reaction_notification) + self.bridge.pin_notification.connect(self._on_pin_notification) + self.bridge.unpin_notification.connect(self._on_unpin_notification) + self.bridge.pinned_messages_loaded.connect(self._on_pinned_messages_loaded) + self.bridge.forward_result.connect(self._on_forward_result) + self.bridge.key_change_warning.connect(self._on_key_change_warning) + self._show_verification_dialog_signal.connect(self._show_verification_dialog) + + # ------------------------------------------------------------------ + # Favorites + # ------------------------------------------------------------------ + + def _favorites_path(self): + from chat_core import get_key_dir + return get_key_dir(self.bridge.client.email) / "favorites.json" + + def _load_favorites(self) -> set[str]: + try: + p = self._favorites_path() + if p.exists(): + return set(json.loads(p.read_text())) + except Exception: + pass + return set() + + def _save_favorites(self): + try: + self._favorites_path().write_text(json.dumps(list(self._favorites))) + except Exception: + pass + + def _on_conv_list_context_menu(self, pos): + item = self.conv_list.itemAt(pos) + if not item: + return + conv_id = item.data(Qt.ItemDataRole.UserRole) + if not conv_id: + return + from PyQt6.QtWidgets import QMenu + menu = QMenu(self) + is_fav = conv_id in self._favorites + action = menu.addAction("Odebrat z oblibených" if is_fav else "Přidat do oblíbených") + result = menu.exec(self.conv_list.mapToGlobal(pos)) + if result == action: + if is_fav: + self._favorites.discard(conv_id) + else: + self._favorites.add(conv_id) + self._save_favorites() + self._rebuild_conv_list() + + # ------------------------------------------------------------------ + # Conversation list helpers + # ------------------------------------------------------------------ + + def _get_conv_display_name(self, conv: dict) -> str: + """Get display name for a conversation (used for sorting and labels).""" + others = [m.get("username") or m.get("email") or "?" for m in conv["members"] + if m.get("email") != self.bridge.client.email] + return conv.get("name") or (", ".join(others) if others else self.bridge.client.username) + + def _get_conv_other_user_id(self, conv: dict) -> str: + """Get the other user's ID in a DM conversation (empty string for groups).""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + if not is_dm: + return "" + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + return m.get("user_id") or m.get("id") or "" + return "" + + def _get_conv_sort_key(self, conv: dict) -> tuple: + """Sort key: favorites first, then online DMs, then rest — alphabetically within each.""" + conv_id = conv.get("conversation_id", "") + is_fav = 0 if conv_id in self._favorites else 1 + other_uid = self._get_conv_other_user_id(conv) + is_online = 0 if other_uid and other_uid in self._online_users else 1 + name = self._get_conv_display_name(conv).lower() + return (is_fav, is_online, name) + + def _on_conversations_loaded(self, convs): + self.conversations = convs + # Populate unread counts from server (covers messages received while offline) + for cv in convs: + cid = cv["conversation_id"] + server_unread = cv.get("unread_count", 0) + # Use the higher of server vs local (local may have newer real-time notifications) + if server_unread > self._unread_counts.get(cid, 0): + self._unread_counts[cid] = server_unread + self._rebuild_conv_list() + + def _rebuild_conv_list(self): + """Sort and rebuild the conversation list widget.""" + if not self.conversations: + return + # Sort: favorites first, then online DMs, then rest — alphabetically within each group + self.conversations.sort(key=self._get_conv_sort_key) + prev_id = self.current_conv_id + self.conv_list.blockSignals(True) + self.conv_list.clear() + select_row = -1 + for i, cv in enumerate(self.conversations): + conv_id = cv["conversation_id"] + name = self._get_conv_display_name(cv) + count = self._unread_counts.get(conv_id, 0) + is_fav = conv_id in self._favorites + # Preview + timestamp from last-message cache + preview_text, preview_ts, receipt_st = self._last_message_cache.get(conv_id, ("", "", "")) + rel_ts = _relative_time(preview_ts) + # Avatar pixmap + avatar_pix = self._get_conv_avatar_pixmap(cv) + + # Verification status (DMs only) + verified_status = "" + is_dm = len(cv["members"]) == 2 and not cv.get("name") + if is_dm: + peer_uid = self._get_conv_other_user_id(cv) + if peer_uid: + verified_status = self.bridge.client.get_verification_status(peer_uid) + + item = QListWidgetItem() + item.setSizeHint(QSize(0, ConversationDelegate.ITEM_HEIGHT)) + item.setData(ROLE_CONV_ID, conv_id) + item.setData(ROLE_DISPLAY_NAME, name) + item.setData(ROLE_PREVIEW, preview_text) + item.setData(ROLE_TIMESTAMP, rel_ts) + item.setData(ROLE_UNREAD, count) + item.setData(ROLE_IS_FAV, is_fav) + item.setData(ROLE_AVATAR, avatar_pix) + item.setData(ROLE_VERIFIED, verified_status) + item.setData(ROLE_RECEIPT, receipt_st) + self.conv_list.addItem(item) + if conv_id == prev_id: + select_row = i + self.conv_list.blockSignals(False) + if select_row >= 0: + self.conv_list.setCurrentRow(select_row) + + def _on_conversation_updated(self): + """Refresh conversation list when a conversation is created/member added/removed.""" + self.bridge.load_conversations() + + def _on_periodic_refresh(self): + """Periodic refresh: reload invitations and refresh avatars in small batches.""" + # Keep this first so invitations are not queued behind a large avatar burst. + self.bridge.list_invitations() + + uids = list(self._avatar_requested) + if uids: + n = len(uids) + batch = min(self._AVATAR_REFRESH_BATCH, n) + start = self._avatar_refresh_cursor % n + for i in range(batch): + uid = uids[(start + i) % n] + self.bridge.get_avatar(uid) + self._avatar_refresh_cursor = (start + batch) % n + + conv_ids = list(self._group_avatar_requested) + if conv_ids: + n = len(conv_ids) + batch = min(self._GROUP_AVATAR_REFRESH_BATCH, n) + start = self._group_avatar_refresh_cursor % n + for i in range(batch): + conv_id = conv_ids[(start + i) % n] + self.bridge.get_group_avatar(conv_id) + self._group_avatar_refresh_cursor = (start + batch) % n + + def _on_online_users_loaded(self, user_ids): + self._online_users = set(user_ids) + self._rebuild_conv_list() + + def _on_online_status_changed(self, user_id, is_online): + if is_online: + self._online_users.add(user_id) + else: + self._online_users.discard(user_id) + self._rebuild_conv_list() + + def _on_avatar_for_conv_list(self, user_id, data): + """Cache downloaded avatar and refresh conversation list icons + chat header.""" + qimg = _safe_load_image(data) + if qimg is not None: + self._avatar_cache[user_id] = QPixmap.fromImage(qimg) + self._update_conv_list_styles() + # Refresh chat header avatar if current conv uses this user's avatar + self._refresh_chat_header_avatar() + + def _on_group_avatar_for_conv_list(self, conv_id, data): + """Cache downloaded group avatar and refresh conversation list icons + chat header.""" + qimg = _safe_load_image(data) + if qimg is not None: + self._group_avatar_cache[conv_id] = QPixmap.fromImage(qimg) + self._update_conv_list_styles() + # Refresh chat header avatar if current conv is this group + self._refresh_chat_header_avatar() + + def _on_group_avatar_updated(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Group Avatar", msg) + + def _on_invitations_loaded(self, invitations): + self._pending_invitations = invitations + self.inv_list.clear() + if not invitations: + self.inv_label.setVisible(False) + self.inv_list.setVisible(False) + return + self.inv_label.setVisible(True) + self.inv_list.setVisible(True) + for inv in invitations: + conv_name = inv.get("conversation_name") or "Unnamed group" + inviter = inv.get("invited_by_username", "someone") + label = f"{conv_name} (from {inviter})" + item = QListWidgetItem(label) + item.setData(Qt.ItemDataRole.UserRole, inv["conversation_id"]) + self.inv_list.addItem(item) + + def _on_invitation_result(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Invitation", msg) + + def _on_invitation_received(self, data): + """New invitation received via push notification.""" + self.bridge.list_invitations() + conv_name = data.get("conversation_name") or "a group" + inviter = data.get("invited_by_username", "Someone") + t = c() + self.status_bar.setText(f"{inviter} invited you to {conv_name}") + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.warning}; font-size: 8pt; font-weight: bold;" + ) + self._status_bar_conv_id = None + QTimer.singleShot(5000, self._clear_status_bar) + + def _on_inv_context_menu(self, pos): + item = self.inv_list.itemAt(pos) + if not item: + return + conv_id = item.data(Qt.ItemDataRole.UserRole) + if not conv_id: + return + menu = QMenu(self) + accept_action = menu.addAction("Accept") + decline_action = menu.addAction("Decline") + chosen = menu.exec(self.inv_list.mapToGlobal(pos)) + if chosen == accept_action: + self.bridge.accept_invitation(conv_id) + elif chosen == decline_action: + self.bridge.decline_invitation(conv_id) + + def _on_connection_state_changed(self, state): + t = c() + if state == "connected": + self.connection_dot.setStyleSheet(f"color: {t.success}; font-size: 11pt;") + self.connection_dot.setToolTip("Connected") + self.status_bar.setText("Connected") + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.success}; font-size: 8pt;" + ) + QTimer.singleShot(3000, self._clear_status_bar) + elif state == "disconnected": + self.connection_dot.setStyleSheet(f"color: {t.error}; font-size: 11pt;") + self.connection_dot.setToolTip("Disconnected") + self.status_bar.setText("Disconnected from server") + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.error}; font-size: 8pt; font-weight: bold;" + ) + self._status_bar_conv_id = None + elif state == "reconnecting": + self.connection_dot.setStyleSheet(f"color: {t.warning}; font-size: 11pt;") + self.connection_dot.setToolTip("Reconnecting...") + self.status_bar.setText("Reconnecting...") + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.warning}; font-size: 8pt;" + ) + self._status_bar_conv_id = None + elif state == "revoked": + self.connection_dot.setStyleSheet(f"color: {t.error}; font-size: 11pt;") + self.connection_dot.setToolTip("Access revoked") + # Clear conversation list + self.conv_list.clear() + self.conversations = [] + self._unread_counts.clear() + # Clear open conversation + self.current_conv_id = None + self.msg_input.drop_enabled = False + self.chat_header.setText("Select a conversation") + self.chat_header_avatar.setVisible(False) + self._clear_message_area() + self.msg_input.setEnabled(False) + self.group_info_btn.setVisible(False) + self.user_info_btn.setVisible(False) + self.add_member_btn.setVisible(False) + self.delete_conv_btn.setVisible(False) + QMessageBox.warning(self, "Access Revoked", + "Your keys were rotated on another device. " + "This session is no longer valid.") + + def _on_scroll_changed(self, value): + sb = self.message_area.verticalScrollBar() + self._is_near_bottom = (sb.maximum() - value) < 60 + if self._is_near_bottom: + self.jump_btn.setVisible(False) + elif sb.maximum() > 0: + self.jump_btn.setText("\u2193") + self.jump_btn.setVisible(True) + self._position_jump_btn() + + def _scroll_to_bottom(self): + sb = self.message_area.verticalScrollBar() + sb.setValue(sb.maximum()) + self.jump_btn.setVisible(False) + + def _position_jump_btn(self): + w = self.message_area.width() + self.jump_btn.move((w - 36) // 2, self.message_area.height() - 48) + + def _clear_status_bar(self): + self.status_bar.setText("") + t = c() + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.success}; font-size: 8pt;" + ) + self._status_bar_conv_id = None + + def _on_status_bar_click(self, event): + conv_id = self._status_bar_conv_id + if conv_id: + for i, c in enumerate(self.conversations): + if c["conversation_id"] == conv_id: + self.conv_list.setCurrentRow(i) + self._clear_status_bar() + break + + def _update_chat_header_avatar(self, conv): + """Set the circular avatar next to the conversation name in the chat header.""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + size = 40 + t = c() + if is_dm: + other = None + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + other = m + break + if other: + uid = other.get("user_id") or other.get("id") or "" + uname = other.get("username") or other.get("email") or "?" + if uid in self._avatar_cache: + avatar = self._make_circular_avatar(self._avatar_cache[uid], size) + else: + avatar = self._make_default_avatar(uname, size) + self.chat_header_avatar.setPixmap(avatar) + self.chat_header_avatar.setVisible(True) + # Online status text + if uid in self._online_users: + self._chat_header_status.setText("Online") + self._chat_header_status.setStyleSheet(f"color: {t.success}; font-size: 8pt;") + else: + self._chat_header_status.setText("Offline") + self._chat_header_status.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt;") + self._chat_header_status.setVisible(True) + else: + self.chat_header_avatar.setVisible(False) + self._chat_header_status.setVisible(False) + else: + conv_id = conv.get("conversation_id") or "" + gname = conv.get("name") or "G" + if conv_id in self._group_avatar_cache: + avatar = self._make_circular_avatar(self._group_avatar_cache[conv_id], size) + else: + avatar = self._make_default_avatar(gname, size) + self.chat_header_avatar.setPixmap(avatar) + self.chat_header_avatar.setVisible(True) + # Member count for groups + member_count = len(conv.get("members", [])) + self._chat_header_status.setText(f"{member_count} members") + self._chat_header_status.setStyleSheet(f"color: {t.text_muted}; font-size: 8pt;") + self._chat_header_status.setVisible(True) + # Show E2E indicator with verification status + if is_dm: + peer_uid = "" + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + peer_uid = m.get("user_id") or m.get("id") or "" + break + v_status = self.bridge.client.get_verification_status(peer_uid) if peer_uid else "" + if v_status == "verified": + self._e2e_label.setText("\u2705 Verified") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.success}; background: transparent;") + self._e2e_label.setToolTip("Identity verified — click to view safety number") + elif v_status == "trusted": + self._e2e_label.setText("\U0001f512 Encrypted") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.text_muted}; background: transparent;") + self._e2e_label.setToolTip("End-to-end encrypted (not verified) — click to verify") + else: + self._e2e_label.setText("\U0001f512 Encrypted") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.text_muted}; background: transparent;") + self._e2e_label.setToolTip("End-to-end encrypted — click to verify") + else: + self._e2e_label.setText("\U0001f512 End-to-end encrypted") + self._e2e_label.setStyleSheet(f"font-size: 8pt; color: {t.text_muted}; background: transparent;") + self._e2e_label.setToolTip("End-to-end encrypted") + self._e2e_label.setVisible(True) + + def _refresh_chat_header_avatar(self): + """Re-render chat header avatar for the currently selected conversation.""" + if not self.current_conv_id: + return + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + self._update_chat_header_avatar(cv) + return + + def _update_conv_list_styles(self): + """Update delegate data roles for all items (after avatar/unread changes).""" + for i in range(self.conv_list.count()): + item = self.conv_list.item(i) + conv_id = item.data(ROLE_CONV_ID) + count = self._unread_counts.get(conv_id, 0) + conv = None + for cv in self.conversations: + if cv["conversation_id"] == conv_id: + conv = cv + break + if conv: + item.setData(ROLE_DISPLAY_NAME, self._get_conv_display_name(conv)) + item.setData(ROLE_AVATAR, self._get_conv_avatar_pixmap(conv)) + item.setData(ROLE_IS_FAV, conv_id in self._favorites) + item.setData(ROLE_UNREAD, count) + # Update preview/timestamp from cache + preview_text, preview_ts, receipt_st = self._last_message_cache.get(conv_id, ("", "", "")) + item.setData(ROLE_PREVIEW, preview_text) + item.setData(ROLE_TIMESTAMP, _relative_time(preview_ts)) + item.setData(ROLE_RECEIPT, receipt_st) + + def _on_conv_selected(self, row): + if row < 0 or row >= len(self.conversations): + return + conv = self.conversations[row] + self.current_conv_id = conv["conversation_id"] + self.msg_input.drop_enabled = True + others = [m.get("username") or m.get("email") or "?" for m in conv["members"] + if m.get("email") != self.bridge.client.email] + header = conv.get("name") or (", ".join(others) if others else self.bridge.client.username) + self.chat_header.setText(header) + # Set avatar in chat header + self._update_chat_header_avatar(conv) + is_group = len(conv["members"]) > 2 or conv.get("name") + self._is_dm = not is_group + self.add_member_btn.setVisible(bool(is_group)) + self.group_info_btn.setVisible(bool(is_group)) + self.user_info_btn.setVisible(self._is_dm) + # DMs: always show delete. Groups: only show for creator. + if self._is_dm: + self.delete_conv_btn.setVisible(True) + else: + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + self.delete_conv_btn.setVisible(conv.get("created_by") == my_user_id) + self.reply_to_id = None + self._reply_widget.setVisible(False) + self._has_more_messages = True + self.load_more_btn.setVisible(False) + self._unread_counts.pop(self.current_conv_id, None) + self._update_conv_list_styles() + self.search_btn.setVisible(True) + self.pin_list_btn.setVisible(True) + self._pin_banner.setVisible(False) + self._close_search() + self.bridge.load_messages(self.current_conv_id) + + def _on_e2e_label_clicked(self, event): + """Open VerificationDialog when E2E label is clicked (DMs only).""" + if not self.current_conv_id or not getattr(self, "_is_dm", False): + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + peer_uid = self._get_conv_other_user_id(conv) + if not peer_uid: + return + peer_name = "" + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + peer_name = m.get("username") or m.get("email") or "?" + break + # Ensure identity key is in cache + self.bridge.schedule(self._ensure_and_show_verification(peer_uid, peer_name)) + + async def _ensure_and_show_verification(self, peer_uid: str, peer_name: str): + """Ensure we have the peer's identity key, then show verification dialog.""" + await self.bridge.client._get_user_info(user_id=peer_uid) + # Emit signal back to Qt thread + self._show_verification_dialog_signal.emit(peer_uid, peer_name) + + def _show_verification_dialog(self, peer_uid: str, peer_name: str): + dlg = VerificationDialog(self.bridge, peer_uid, peer_name, parent=self) + dlg.exec() + # Refresh conv list and header to reflect any verification changes + self._rebuild_conv_list() + if self.current_conv_id: + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + self._update_chat_header_avatar(cv) + break + + def _on_key_change_warning(self, user_id: str, username: str, old_key_hex: str, was_verified: bool, new_key_bytes: bytes = b""): + """Show warning dialog when a contact's identity key has changed.""" + t = c() + severity = "CRITICAL" if was_verified else "WARNING" + name = username or user_id[:8] + msg = ( + f"The identity key for {name} has changed!\n\n" + f"This could mean:\n" + f"- They re-installed the app or got a new device\n" + f"- Someone may be intercepting your messages\n\n" + ) + if was_verified: + msg += "This contact was previously verified. You should re-verify." + + dlg = QDialog(self) + dlg.setMinimumWidth(400) + lay = _make_frameless(dlg, f"Identity Key Changed ({severity})") + + warning_label = QLabel(msg) + warning_label.setWordWrap(True) + warning_label.setStyleSheet(f"color: {t.text_primary};") + lay.addWidget(warning_label) + + btn_row = QHBoxLayout() + accept_btn = QPushButton("Accept New Key") + accept_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.warning}; color: {t.bg_primary}; " + f"font-weight: bold; padding: 8px 16px; border-radius: 6px; }}" + ) + accept_btn.clicked.connect(lambda: self._accept_key_change(user_id, new_key_bytes, dlg)) + btn_row.addWidget(accept_btn) + + close_btn = QPushButton("Dismiss") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(dlg.accept) + btn_row.addWidget(close_btn) + + lay.addLayout(btn_row) + dlg.exec() + + def _accept_key_change(self, user_id: str, new_key_bytes: bytes, dlg: QDialog): + if new_key_bytes: + self.bridge.client.accept_key_change(user_id, new_key_bytes) + dlg.accept() + self._rebuild_conv_list() + + def _on_messages_loaded(self, conv_id, messages): + if conv_id != self.current_conv_id: + return + self.current_messages = messages + # Update last-message cache for conversation list preview + if messages: + last = messages[-1] + preview = last.get("text", "") + if last.get("deleted"): + preview = "Message deleted" + elif last.get("image") and not preview: + preview = "Sent an image" + elif last.get("file") and not preview: + preview = "Sent a file" + receipt = self._compute_receipt_status(last) + self._last_message_cache[conv_id] = (preview[:60], last.get("created_at", ""), receipt) + self._update_conv_list_styles() + # Show "Load older" if we got a full batch (there may be more) + self._has_more_messages = len(messages) >= 50 + self.load_more_btn.setVisible(self._has_more_messages) + self._render_messages() + self._update_pin_banner() + + def _compute_receipt_status(self, m) -> str: + """Return receipt status for a message: 'read', 'delivered', 'sent', or ''.""" + my_uid = (self.bridge.client.session.get("user_id", "") + if self.bridge and self.bridge.client and self.bridge.client.session else "") + if not my_uid or m.get("sender_id") != my_uid: + return "" + read_by = m.get("read_by", []) + if any(r.get("user_id") != my_uid for r in read_by): + return "read" + delivered_to = m.get("delivered_to", []) + if any(d.get("user_id") != my_uid for d in delivered_to): + return "delivered" + return "sent" + + def _render_messages(self, scroll_to_bottom=True): + """Clear and rebuild all message bubble widgets.""" + self._msg_widgets = [] + while self._msg_layout.count() > 0: + item = self._msg_layout.takeAt(0) + w = item.widget() + if w: + w.deleteLater() + for i, m in enumerate(self.current_messages): + w = self._create_message_widget(m, i) + self._msg_layout.addWidget(w) + self._msg_widgets.append(w) + if scroll_to_bottom: + QTimer.singleShot(10, self._scroll_to_bottom) + + def _clear_message_area(self): + """Remove all message widgets from the scroll area.""" + self._msg_widgets = [] + self.current_messages = [] + while self._msg_layout.count() > 0: + item = self._msg_layout.takeAt(0) + w = item.widget() + if w: + w.deleteLater() + + def _decode_thumbnail(self, image_info): + """Decode base64 thumbnail to QPixmap.""" + thumbnail_b64 = image_info.get("thumbnail", "") + if not thumbnail_b64: + return None + from protocol import decode_binary + try: + thumb_bytes = decode_binary(thumbnail_b64) + qimg = _safe_load_image(thumb_bytes) + if qimg is not None: + return QPixmap.fromImage(qimg) + except Exception: + pass + return None + + def _create_message_widget(self, m, index): + """Create a widget tree for a single message bubble.""" + t = c() + my_uid = (self.bridge.client.session.get("user_id", "") + if self.bridge.client.session else "") + is_me = (m.get("sender_id") == my_uid if my_uid + else m.get("sender") == self.bridge.client.username) + + # -- Wrapper for left/right alignment -- + wrapper = QWidget() + wrapper.setStyleSheet("background: transparent;") + wrapper._msg_index = index + wlay = QHBoxLayout(wrapper) + wlay.setContentsMargins(12, 2, 12, 2) + wlay.setSpacing(0) + + # -- Deleted message -- + if m.get("deleted"): + ts = m.get("created_at", "") + time_str = _format_msg_time(ts) + del_bubble = MessageBubble(t.bg_secondary) + del_bubble._msg_index = index + dlay = QVBoxLayout(del_bubble) + dlay.setContentsMargins(14, 8, 14, 8) + dl = QLabel(f"\u00b7 {time_str} Message deleted") + dl.setStyleSheet( + f"color: {t.text_muted}; font-style: italic; " + f"font-size: 10pt; background: transparent;" + ) + dlay.addWidget(dl) + if is_me: + wlay.addStretch(1) + wlay.addWidget(del_bubble) + if not is_me: + wlay.addStretch(1) + return wrapper + + sender = m.get("sender", "???") + text = m.get("text", "") + + # Determine colours + if is_me: + bubble_bg = t.bubble_sent_bg + text_color = t.bubble_sent_text + meta_color = t.bubble_sent_meta + sender_color = t.accent + else: + bubble_bg = t.bubble_recv_bg + text_color = t.bubble_recv_text + meta_color = t.bubble_recv_meta + sender_color = t.sender_name_other + + if is_me: + wlay.addStretch(1) + + # -- Bubble -- + bubble = MessageBubble(bubble_bg) + bubble._msg_index = index + bubble.setSizePolicy( + QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum + ) + bubble.setMaximumWidth(600) + bubble.setMinimumWidth(80) + + blay = QVBoxLayout(bubble) + blay.setContentsMargins(14, 8, 14, 8) + blay.setSpacing(3) + + # -- Forwarded header -- + fwd = m.get("forwarded_from") + if fwd: + fwd_sender = fwd.get("sender", "???") + fwd_esc = (fwd_sender.replace("&", "&") + .replace("<", "<").replace(">", ">")) + fwd_label = QLabel( + f'' + f'Forwarded from {fwd_esc}' + f'' + ) + fwd_label.setTextFormat(Qt.TextFormat.RichText) + fwd_label.setStyleSheet( + f"background: transparent; border-left: 2px solid {t.info}; " + f"padding-left: 6px; margin-bottom: 2px;" + ) + blay.addWidget(fwd_label) + + # -- Reply context -- + if m.get("reply_to"): + for orig in self.current_messages: + if orig.get("message_id") == m.get("reply_to"): + orig_sender = orig.get("sender", "???") + orig_esc = (orig_sender.replace("&", "&") + .replace("<", "<").replace(">", ">")) + orig_text = orig.get("text", "")[:50] + orig_text_esc = (orig_text.replace("&", "&") + .replace("<", "<").replace(">", ">")) + reply_lbl = QLabel( + f'{orig_esc}
' + f'' + f'{orig_text_esc}' + ) + reply_lbl.setTextFormat(Qt.TextFormat.RichText) + reply_lbl.setWordWrap(True) + reply_lbl.setStyleSheet( + f"background: transparent; " + f"border-left: 2px solid {t.scrollbar}; " + f"padding-left: 6px; margin-bottom: 2px;" + ) + blay.addWidget(reply_lbl) + break + + # -- Header (sender name + pin — groups only) -- + timestamp = m.get("created_at", "") + sender_esc = (sender.replace("&", "&") + .replace("<", "<").replace(">", ">")) + pin_html = "" + if m.get("pinned_at"): + pin_html = f' \U0001f4cc' + + is_dm = self._is_dm + if not is_dm: + header_html = ( + f'{sender_esc}' + f'{pin_html}' + ) + header_label = QLabel(header_html) + header_label.setTextFormat(Qt.TextFormat.RichText) + header_label.setStyleSheet("background: transparent;") + blay.addWidget(header_label) + elif pin_html: + pin_label = QLabel(pin_html) + pin_label.setTextFormat(Qt.TextFormat.RichText) + pin_label.setStyleSheet("background: transparent;") + blay.addWidget(pin_label) + + # -- Suppress image placeholder text -- + image_info = m.get("image") + if image_info: + fname_raw = image_info.get("filename", "image") + if text == f"[Image: {fname_raw}]": + text = "" + + # -- Message text -- + if text: + text_html = _linkify_urls(text) + text_html = text_html.replace("\n", "
") + # Search highlighting + if (self._search_active and self._search_query + and index in self._search_results): + is_current = ( + 0 <= self._search_current < len(self._search_results) + and self._search_results[self._search_current] == index + ) + bg = t.search_current if is_current else t.search_highlight + text_html = self._highlight_search_text( + text_html, self._search_query, bg + ) + # @Mentions highlighting + import re as _re + mention_c = t.mention + text_html = _re.sub( + r'@(\w+)', + lambda mt: ( + f'' + f'@{mt.group(1)}' + ), + text_html, + ) + + text_label = QLabel(text_html) + text_label.setTextFormat(Qt.TextFormat.RichText) + text_label.setWordWrap(True) + text_label.setStyleSheet( + f"color: {text_color}; background: transparent; " + f"font-size: 11pt;" + ) + text_label.setTextInteractionFlags( + Qt.TextInteractionFlag.LinksAccessibleByMouse + ) + text_label.linkActivated.connect(self._on_link_clicked) + blay.addWidget(text_label) + + # -- Image thumbnail -- + if image_info: + thumb_pixmap = self._decode_thumbnail(image_info) + file_id = image_info.get("file_id", "") + filename = image_info.get("filename", "image") + size_bytes = image_info.get("size", 0) + size_str = self._human_file_size(size_bytes) + if thumb_pixmap: + img_label = QLabel() + scaled = thumb_pixmap.scaledToWidth( + min(200, thumb_pixmap.width()), + Qt.TransformationMode.SmoothTransformation, + ) + img_label.setPixmap(scaled) + img_label.setStyleSheet("background: transparent;") + img_label.setCursor(Qt.CursorShape.PointingHandCursor) + img_label.mousePressEvent = ( + lambda e, fid=file_id: self._on_image_click(fid) + ) + blay.addWidget(img_label) + + link_color = text_color if is_me else t.accent + fname_esc = (filename.replace("&", "&") + .replace("<", "<").replace(">", ">")) + info_label = QLabel( + f'' + f'{fname_esc} ({size_str}) \u2014 Click to view' + ) + info_label.setTextFormat(Qt.TextFormat.RichText) + info_label.setStyleSheet( + f"font-size: 9pt; background: transparent;" + ) + info_label.setTextInteractionFlags( + Qt.TextInteractionFlag.LinksAccessibleByMouse + ) + info_label.linkActivated.connect(self._on_link_clicked) + blay.addWidget(info_label) + + # -- File card -- + file_info = m.get("file") + if file_info: + fname = file_info.get("filename", "file") + fname_esc = (fname.replace("&", "&") + .replace("<", "<").replace(">", ">")) + fsize = file_info.get("size", 0) + size_str = self._human_file_size(fsize) + f_id = file_info.get("file_id", "") + icon = self._file_icon(fname) + file_link_color = text_color if is_me else t.accent + + file_frame = QFrame() + file_frame.setStyleSheet( + f"QFrame {{ background: transparent; " + f"border: 1px solid {meta_color}; " + f"border-radius: 6px; padding: 8px; }}" + ) + flay = QHBoxLayout(file_frame) + flay.setContentsMargins(0, 0, 0, 0) + file_label = QLabel( + f'{icon} {fname_esc}' + f' ({size_str})' + ) + file_label.setTextFormat(Qt.TextFormat.RichText) + file_label.setStyleSheet("background: transparent; border: none;") + file_label.setTextInteractionFlags( + Qt.TextInteractionFlag.LinksAccessibleByMouse + ) + file_label.linkActivated.connect(self._on_link_clicked) + flay.addWidget(file_label) + blay.addWidget(file_frame) + + # -- Reaction badges -- + reactions = m.get("reactions", []) + if reactions: + _REACTION_EMOJI = { + "thumbsup": "\U0001f44d", "heart": "\u2764\ufe0f", + "laugh": "\U0001f602", "surprised": "\U0001f62e", + "sad": "\U0001f622", "thumbsdown": "\U0001f44e", + } + grouped = {} + for r in reactions: + grouped.setdefault(r["reaction"], []).append(r["user_id"]) + my_id = (self.bridge.client.session.get("user_id", "") + if self.bridge.client.session else "") + react_widget = QWidget() + react_widget.setStyleSheet("background: transparent;") + react_lay = QHBoxLayout(react_widget) + react_lay.setContentsMargins(0, 2, 0, 0) + react_lay.setSpacing(4) + for rkey, uids in grouped.items(): + emoji = _REACTION_EMOJI.get(rkey, rkey) + count = len(uids) + is_mine = my_id in uids + bg = t.reaction_bg_own if is_mine else t.reaction_bg + bdr = t.reaction_border_own if is_mine else t.reaction_border + badge = QLabel(f"{emoji} {count}") + badge.setStyleSheet( + f"background-color: {bg}; color: {t.text_primary}; " + f"border: 1px solid {bdr}; " + f"border-radius: 10px; padding: 2px 6px; font-size: 10pt;" + ) + react_lay.addWidget(badge) + react_lay.addStretch() + blay.addWidget(react_widget) + + # -- Footer (timestamp + receipt, right-aligned, below content) -- + time_str = _format_msg_time(timestamp) + receipt_status = "" + if is_me: + read_by = m.get("read_by", []) + others_read = [r for r in read_by if r.get("user_id") != my_uid] + delivered_to = m.get("delivered_to", []) + others_delivered = [d for d in delivered_to if d.get("user_id") != my_uid] + if others_read: + receipt_status = "read" + elif others_delivered: + receipt_status = "delivered" + else: + receipt_status = "sent" + footer_w = _ReceiptFooter(time_str, receipt_status, meta_color, + t.bubble_sent_text, t.success) + footer_w.setStyleSheet("background: transparent;") + blay.addWidget(footer_w, alignment=Qt.AlignmentFlag.AlignRight) + + wlay.addWidget(bubble) + if not is_me: + wlay.addStretch(1) + + # Install event filter on all child widgets for context menu propagation + for child in bubble.findChildren(QWidget): + child.setContextMenuPolicy(Qt.ContextMenuPolicy.NoContextMenu) + child.installEventFilter(self) + bubble.installEventFilter(self) + wrapper.installEventFilter(self) + + return wrapper + + def _on_image_click(self, file_id): + """Handle click on an image thumbnail in a message bubble.""" + for msg in self.current_messages: + image_info = msg.get("image") + if image_info and image_info.get("file_id") == file_id: + self._view_image(msg) + return + + def _find_msg_index_at_widget(self, widget): + """Walk up the widget tree to find the _msg_index attribute.""" + while widget: + if hasattr(widget, '_msg_index'): + return widget._msg_index + widget = widget.parentWidget() + return None + + def _on_older_messages_loaded(self, conv_id, messages): + if conv_id != self.current_conv_id: + return + if not messages: + self._has_more_messages = False + self.load_more_btn.setVisible(False) + return + self._has_more_messages = len(messages) >= 50 + self.load_more_btn.setVisible(self._has_more_messages) + # Prepend older messages and re-render + self.current_messages = messages + self.current_messages + self._render_messages(scroll_to_bottom=False) + + def _on_load_more(self): + if not self.current_conv_id or not self._has_more_messages: + return + offset = len(self.current_messages) + self.bridge.load_older_messages(self.current_conv_id, offset) + + def _on_message_context_menu(self, pos): + if not self.current_messages: + return + global_pos = self._msg_scroll_area.viewport().mapToGlobal(pos) + widget = QApplication.widgetAt(global_pos) + idx = self._find_msg_index_at_widget(widget) + if idx is not None: + self._show_msg_context_menu(idx, global_pos) + + def _show_msg_context_menu(self, idx, global_pos): + if idx < 0 or idx >= len(self.current_messages): + return + m = self.current_messages[idx] + if m.get("deleted"): + return + + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + menu = QMenu(self) + + reply_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_ArrowBack) + reply_action = menu.addAction(reply_icon, "Reply") + + # Reaction submenu + react_menu = menu.addMenu("React") + _REACTION_LABELS = { + "thumbsup": "\U0001f44d +1", + "heart": "\u2764\ufe0f Heart", + "laugh": "\U0001f602 Haha", + "surprised": "\U0001f62e Wow", + "sad": "\U0001f622 Sad", + "thumbsdown": "\U0001f44e -1", + } + react_actions = {} + for rkey, rlabel in _REACTION_LABELS.items(): + act = react_menu.addAction(rlabel) + react_actions[act] = rkey + + # Forward + fwd_action = menu.addAction("Forward") + + # Pin / Unpin + pin_action = None + if m.get("pinned_at"): + pin_action = menu.addAction("\U0001f4cc Unpin") + else: + pin_action = menu.addAction("\U0001f4cc Pin") + + menu.addSeparator() + + # Delete option for own messages + del_action = None + if m.get("sender_id") == my_user_id: + del_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_TrashIcon) + del_action = menu.addAction(del_icon, "Delete") + + # View image option + img_action = None + if m.get("image"): + img_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogContentsView) + img_action = menu.addAction(img_icon, "View image") + + # Download file option + file_action = None + if m.get("file"): + file_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton) + file_action = menu.addAction(file_icon, "Download file") + + # Reset session option for undecryptable messages + reset_action = None + if m.get("text", "").startswith("[Decryption failed"): + reset_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload) + reset_action = menu.addAction(reset_icon, "Reset session with sender") + + chosen = menu.exec(global_pos) + if not chosen: + return + if chosen == reply_action: + self.reply_to_id = m["message_id"] + sender = m.get("sender", "???") + preview = m.get("text", "")[:40] + self.reply_label.setText(f"Replying to {sender}: {preview}") + self._reply_widget.setVisible(True) + self.msg_input.setFocus() + elif chosen in react_actions: + rkey = react_actions[chosen] + existing = m.get("reactions", []) + has_it = any(r["user_id"] == my_user_id and r["reaction"] == rkey for r in existing) + if has_it: + m["reactions"] = [r for r in existing if r["user_id"] != my_user_id] + self.bridge.react_message(m["message_id"], rkey, "remove") + else: + new_reactions = [r for r in existing if r["user_id"] != my_user_id] + new_reactions.append({"user_id": my_user_id, "reaction": rkey, "created_at": ""}) + m["reactions"] = new_reactions + self.bridge.react_message(m["message_id"], rkey, "add") + # Persist to local cache so reactions survive conversation switch + self.bridge.client.update_message_in_cache( + self.current_conv_id, m["message_id"], + {"reactions": m["reactions"]}) + self._render_messages(scroll_to_bottom=self._is_near_bottom) + elif chosen == fwd_action: + self._show_forward_dialog(m) + elif chosen == pin_action: + if m.get("pinned_at"): + m.pop("pinned_at", None) + m.pop("pinned_by", None) + self.bridge.pin_message(m["message_id"], self.current_conv_id, "unpin") + self.bridge.client.update_message_in_cache( + self.current_conv_id, m["message_id"], + {"pinned_at": None, "pinned_by": None}) + else: + my_user_id_pin = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + m["pinned_at"] = "now" + m["pinned_by"] = my_user_id_pin + self.bridge.pin_message(m["message_id"], self.current_conv_id, "pin") + self.bridge.client.update_message_in_cache( + self.current_conv_id, m["message_id"], + {"pinned_at": "now", "pinned_by": my_user_id_pin}) + self._render_messages(scroll_to_bottom=self._is_near_bottom) + self._update_pin_banner() + elif chosen == del_action: + if self._confirm_dialog("Delete Message", + "Delete this message? This cannot be undone."): + # Apply locally immediately (server notification only goes to others) + m["deleted"] = True + m["text"] = "" + m["image"] = None + m["file"] = None + self.bridge.client.update_message_in_cache( + self.current_conv_id, m["message_id"], {"deleted": True}) + self._render_messages(scroll_to_bottom=self._is_near_bottom) + self.bridge.delete_message(m["message_id"]) + elif chosen == img_action: + self._view_image(m) + elif chosen == file_action: + file_info = m.get("file") + if file_info: + self.bridge.download_file(file_info["file_id"], file_info) + elif chosen == reset_action: + sender_id = m.get("sender_id", "") + if sender_id: + if self._confirm_dialog("Reset Session", + "Reset encryption session with this sender? " + "A new session will be created on the next message."): + self.bridge.reset_session(sender_id) + + # ------------------------------------------------------------------ + # Search + # ------------------------------------------------------------------ + + def _toggle_search(self): + if self._search_active: + self._close_search() + else: + if not self.current_conv_id: + return + self._search_active = True + self.search_widget.setVisible(True) + self.search_input.setFocus() + self.search_input.selectAll() + + def _close_search(self): + self._search_active = False + self._search_query = "" + self._search_results = [] + self._search_current = -1 + self.search_widget.setVisible(False) + self.search_input.clear() + self.search_count_label.setText("0/0") + # Re-render to remove highlights + if self.current_messages: + self._render_messages(scroll_to_bottom=False) + + def _on_search_text_changed(self, text): + self._search_query = text.strip() + if not self._search_query: + self._search_results = [] + self._search_current = -1 + self.search_count_label.setText("0/0") + self._render_messages(scroll_to_bottom=False) + return + query_lower = self._search_query.lower() + self._search_results = [] + for i, m in enumerate(self.current_messages): + if m.get("deleted"): + continue + msg_text = m.get("text", "") + if query_lower in msg_text.lower(): + self._search_results.append(i) + if self._search_results: + self._search_current = 0 + self.search_count_label.setText(f"1/{len(self._search_results)}") + else: + self._search_current = -1 + self.search_count_label.setText("0/0") + self._render_messages(scroll_to_bottom=False) + if self._search_results: + self._scroll_to_message(self._search_results[self._search_current]) + + def _on_search_next(self): + if not self._search_results: + return + self._search_current = (self._search_current + 1) % len(self._search_results) + self.search_count_label.setText(f"{self._search_current + 1}/{len(self._search_results)}") + self._render_messages(scroll_to_bottom=False) + self._scroll_to_message(self._search_results[self._search_current]) + + def _on_search_prev(self): + if not self._search_results: + return + self._search_current = (self._search_current - 1) % len(self._search_results) + self.search_count_label.setText(f"{self._search_current + 1}/{len(self._search_results)}") + self._render_messages(scroll_to_bottom=False) + self._scroll_to_message(self._search_results[self._search_current]) + + def _scroll_to_message(self, index): + if 0 <= index < len(self._msg_widgets): + self._msg_scroll_area.ensureWidgetVisible( + self._msg_widgets[index], 0, 50 + ) + + @staticmethod + def _highlight_search_text(html_text: str, query: str, bg_color: str) -> str: + """Highlight matching text in HTML, skipping content inside tags.""" + query_esc = query.replace("&", "&").replace("<", "<").replace(">", ">") + if not query_esc: + return html_text + result = [] + i = 0 + q_lower = query_esc.lower() + q_len = len(query_esc) + while i < len(html_text): + if html_text[i] == '<': + # Skip HTML tags + end = html_text.find('>', i) + if end == -1: + result.append(html_text[i:]) + break + result.append(html_text[i:end + 1]) + i = end + 1 + else: + # Look for match in text content + chunk_end = html_text.find('<', i) + if chunk_end == -1: + chunk_end = len(html_text) + chunk = html_text[i:chunk_end] + # Case-insensitive replace within this chunk + chunk_lower = chunk.lower() + out = [] + j = 0 + while j < len(chunk): + pos = chunk_lower.find(q_lower, j) + if pos == -1: + out.append(chunk[j:]) + break + out.append(chunk[j:pos]) + matched = chunk[pos:pos + q_len] + out.append(f'{matched}') + j = pos + q_len + result.append("".join(out)) + i = chunk_end + return "".join(result) + + # ------------------------------------------------------------------ + # Session reset + # ------------------------------------------------------------------ + + def _on_session_reset(self, from_user_id, from_device_id): + # Find username for the user + username = from_user_id[:8] + for cv in self.conversations: + for m in cv.get("members", []): + uid = m.get("user_id") or m.get("id") + if uid == from_user_id: + username = m.get("username") or m.get("email") or username + break + self.status_bar.setText(f"Session with {username} was reset. New session will be created on next message.") + t = c() + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.warning}; font-size: 8pt; font-weight: bold;" + ) + QTimer.singleShot(8000, self._clear_status_bar) + + # ------------------------------------------------------------------ + # Reactions / Pins / Forward notification handlers + # ------------------------------------------------------------------ + + def _on_reaction_notification(self, data): + """Handle message_reacted push notification.""" + conv_id = data.get("conversation_id", "") + msg_id = data.get("message_id", "") + user_id = data.get("user_id", "") + reaction = data.get("reaction", "") + action = data.get("action", "add") + + if conv_id != self.current_conv_id: + return + for msg in self.current_messages: + if msg.get("message_id") == msg_id: + reactions = msg.get("reactions", []) + if action == "add": + reactions = [r for r in reactions if r["user_id"] != user_id] + reactions.append({"user_id": user_id, "reaction": reaction, "created_at": ""}) + else: + reactions = [r for r in reactions if r["user_id"] != user_id] + msg["reactions"] = reactions + self.bridge.client.update_message_in_cache( + conv_id, msg_id, {"reactions": reactions}) + break + self._render_messages(scroll_to_bottom=self._is_near_bottom) + + def _on_pin_notification(self, data): + """Handle message_pinned push notification.""" + conv_id = data.get("conversation_id", "") + msg_id = data.get("message_id", "") + user_id = data.get("user_id", "") + if conv_id != self.current_conv_id: + return + for msg in self.current_messages: + if msg.get("message_id") == msg_id: + msg["pinned_at"] = "now" + msg["pinned_by"] = user_id + self.bridge.client.update_message_in_cache( + conv_id, msg_id, {"pinned_at": "now", "pinned_by": user_id}) + break + self._render_messages(scroll_to_bottom=self._is_near_bottom) + self._update_pin_banner() + username = data.get("username", user_id[:8] if user_id else "?") + self.status_bar.setText(f"{username} pinned a message") + QTimer.singleShot(3000, self._clear_status_bar) + + def _on_unpin_notification(self, data): + """Handle message_unpinned push notification.""" + conv_id = data.get("conversation_id", "") + msg_id = data.get("message_id", "") + if conv_id != self.current_conv_id: + return + for msg in self.current_messages: + if msg.get("message_id") == msg_id: + msg.pop("pinned_at", None) + msg.pop("pinned_by", None) + self.bridge.client.update_message_in_cache( + conv_id, msg_id, {"pinned_at": None, "pinned_by": None}) + break + self._render_messages(scroll_to_bottom=self._is_near_bottom) + self._update_pin_banner() + + def _on_pinned_messages_loaded(self, conv_id, pinned): + """Show pinned messages dialog.""" + if conv_id != self.current_conv_id: + return + if not pinned: + dlg = QDialog(self) + dlg.setMinimumWidth(300) + lay = _make_frameless(dlg, "Pinned Messages") + t = c() + lbl = QLabel("No pinned messages in this conversation.") + lbl.setStyleSheet(f"color: {t.text_primary}; font-size: 10pt;") + lbl.setWordWrap(True) + lay.addWidget(lbl) + ok_btn = QPushButton("OK") + ok_btn.clicked.connect(dlg.accept) + lay.addWidget(ok_btn) + dlg.exec() + return + dlg = QDialog(self) + dlg.setMinimumSize(360, 300) + t = c() + layout = _make_frameless(dlg, "Pinned Messages") + lw = QListWidget() + lw.setStyleSheet( + f"QListWidget {{ background-color:{t.bg_secondary}; border:1px solid {t.border}; border-radius:6px; }}" + f"QListWidget::item {{ padding:8px; color:{t.text_primary}; border-bottom:1px solid {t.border}; }}" + f"QListWidget::item:selected {{ background-color:{t.bg_hover}; }}" + ) + # Build a sender map from current messages + sender_map = {} + for m in self.current_messages: + sid = m.get("sender_id", "") + if sid and sid not in sender_map: + sender_map[sid] = m.get("sender", sid[:8]) + for p in pinned: + sender = sender_map.get(p.get("sender_id", ""), p.get("sender_id", "?")[:8]) + # Find message text from current_messages + text_preview = "" + for m in self.current_messages: + if m.get("message_id") == p.get("message_id"): + text_preview = m.get("text", "")[:60] + if not sender_map.get(p.get("sender_id", "")): + sender = m.get("sender", sender) + break + ts = p.get("pinned_at", "")[:16] if p.get("pinned_at") else "" + item = QListWidgetItem(f"\U0001f4cc {sender}: {text_preview}\n Pinned: {ts}") + item.setData(Qt.ItemDataRole.UserRole, p.get("message_id")) + lw.addItem(item) + layout.addWidget(lw) + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(dlg.accept) + layout.addWidget(close_btn) + + def _on_item_clicked(item): + target_id = item.data(Qt.ItemDataRole.UserRole) + if target_id: + for i, msg in enumerate(self.current_messages): + if msg.get("message_id") == target_id: + self._scroll_to_message(i) + break + dlg.accept() + lw.itemDoubleClicked.connect(_on_item_clicked) + dlg.exec() + + def _update_pin_banner(self): + """Update the pinned message banner from current_messages.""" + pinned = [m for m in self.current_messages if m.get("pinned_at")] + if not pinned: + self._pin_banner.setVisible(False) + self._pin_banner_msg_id = None + return + # Show the most recently pinned message + latest = pinned[-1] + sender = latest.get("sender", "?") + text = latest.get("text", "") + if len(text) > 80: + text = text[:80] + "..." + text = text.replace("\n", " ") + # HTML-escape user-controlled data to prevent injection + sender_esc = sender.replace("&", "&").replace("<", "<").replace(">", ">") + text_esc = text.replace("&", "&").replace("<", "<").replace(">", ">") + count_str = f" ({len(pinned)} pinned)" if len(pinned) > 1 else "" + self._pin_banner_label.setText(f"{sender_esc}: {text_esc}{count_str}") + self._pin_banner_msg_id = latest.get("message_id") + self._pin_banner.setVisible(True) + + def _on_pin_banner_clicked(self, event): + """Scroll to the pinned message when banner is clicked.""" + if self._pin_banner_msg_id: + for i, msg in enumerate(self.current_messages): + if msg.get("message_id") == self._pin_banner_msg_id: + self._scroll_to_message(i) + break + + def _show_pinned_messages(self): + """Fetch and show pinned messages for current conversation.""" + if self.current_conv_id: + self.bridge.get_pinned_messages(self.current_conv_id) + + def _on_forward_result(self, ok, msg): + if ok: + self.status_bar.setText("Message forwarded") + QTimer.singleShot(3000, self._clear_status_bar) + else: + self.status_bar.setText(f"Forward failed: {msg}") + QTimer.singleShot(5000, self._clear_status_bar) + + def _confirm_dialog(self, title: str, text: str) -> bool: + """Show a frameless confirmation dialog. Returns True if user confirmed.""" + dlg = QDialog(self) + dlg.setMinimumWidth(340) + layout = _make_frameless(dlg, title) + t = c() + label = QLabel(text) + label.setWordWrap(True) + label.setStyleSheet(f"color: {t.text_primary}; font-size: 10pt;") + layout.addWidget(label) + btn_lay = QHBoxLayout() + btn_lay.setSpacing(8) + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("secondaryBtn") + confirm_btn = QPushButton("Delete") + confirm_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.error}; color: {t.accent_text}; " + f"border: none; border-radius: 6px; padding: 8px 16px; font-weight: bold; }}" + f"QPushButton:hover {{ background-color: {t.warning}; }}" + ) + btn_lay.addStretch() + btn_lay.addWidget(cancel_btn) + btn_lay.addWidget(confirm_btn) + layout.addLayout(btn_lay) + cancel_btn.clicked.connect(dlg.reject) + confirm_btn.clicked.connect(dlg.accept) + return dlg.exec() == QDialog.DialogCode.Accepted + + def _show_forward_dialog(self, msg): + """Show dialog to pick a conversation to forward a message to.""" + dlg = QDialog(self) + dlg.setMinimumSize(320, 360) + t = c() + layout = _make_frameless(dlg, "Forward message") + label = QLabel("Select conversation:") + label.setStyleSheet(f"color:{t.text_primary}; font-size:10pt;") + layout.addWidget(label) + conv_list = QListWidget() + conv_list.setStyleSheet( + f"QListWidget {{ background-color:{t.bg_secondary}; border:1px solid {t.border}; border-radius:6px; }}" + f"QListWidget::item {{ padding:8px; color:{t.text_primary}; }}" + f"QListWidget::item:selected {{ background-color:{t.bg_hover}; }}" + ) + for cv in self.conversations: + if cv["conversation_id"] != self.current_conv_id: + name = cv.get("name") + if not name: + others = [m.get("username") or m.get("email") or "?" for m in cv["members"] + if m.get("email") != self.bridge.client.email] + name = ", ".join(others) if others else "?" + item = QListWidgetItem(name) + item.setData(Qt.ItemDataRole.UserRole, cv) + conv_list.addItem(item) + layout.addWidget(conv_list) + fwd_btn = QPushButton("Forward") + fwd_btn.setObjectName("secondaryBtn") + layout.addWidget(fwd_btn) + + def _do_forward(): + sel = conv_list.currentItem() + if not sel: + return + target_conv = sel.data(Qt.ItemDataRole.UserRole) + if target_conv: + fwd_msg = dict(msg) + fwd_msg["conversation_id"] = self.current_conv_id + self.bridge.forward_message( + target_conv["conversation_id"], fwd_msg, target_conv["members"] + ) + dlg.accept() + + fwd_btn.clicked.connect(_do_forward) + conv_list.itemDoubleClicked.connect(lambda _: _do_forward()) + dlg.exec() + + # ------------------------------------------------------------------ + # @Mentions autocomplete + # ------------------------------------------------------------------ + + def _setup_mention_completer(self): + """Set up the mention autocomplete popup for msg_input.""" + self._mention_popup = QListWidget(self) + self._mention_popup.setWindowFlags(Qt.WindowType.Popup) + t = c() + self._mention_popup.setStyleSheet( + f"QListWidget {{ background:{t.bg_secondary}; color:{t.text_primary}; border:1px solid {t.border}; " + f"border-radius:4px; font-size:10pt; }}" + f"QListWidget::item {{ padding:4px 8px; }}" + f"QListWidget::item:selected {{ background:{t.bg_hover}; }}" + ) + self._mention_popup.setMaximumHeight(150) + self._mention_popup.itemClicked.connect(self._on_mention_selected) + self._mention_popup.hide() + self._mention_active = False + + def _check_mention_trigger(self): + """Check if user is typing @mention and show autocomplete.""" + if not hasattr(self, '_mention_popup'): + self._setup_mention_completer() + text = self.msg_input.toPlainText() + cursor_pos = self.msg_input.textCursor().position() + # Find @word at cursor position + before = text[:cursor_pos] + import re as _re + match = _re.search(r'@(\w*)$', before) + if not match: + self._mention_popup.hide() + self._mention_active = False + return + prefix = match.group(1).lower() + self._mention_start = match.start() + + # Get members of current conversation + members = [] + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + members = cv.get("members", []) + break + my_email = self.bridge.client.email if self.bridge.client else "" + candidates = [] + for m in members: + uname = m.get("username") or m.get("email") or "" + if m.get("email") == my_email: + continue + if prefix == "" or uname.lower().startswith(prefix): + candidates.append(uname) + + if not candidates: + self._mention_popup.hide() + self._mention_active = False + return + + self._mention_popup.clear() + for cand in candidates[:6]: + self._mention_popup.addItem(cand) + # Position popup above the input + cursor_rect = self.msg_input.cursorRect() + global_pos = self.msg_input.mapToGlobal(cursor_rect.bottomLeft()) + self._mention_popup.move(global_pos.x(), global_pos.y() - self._mention_popup.sizeHint().height() - 5) + self._mention_popup.setFixedWidth(max(200, self.msg_input.width() // 2)) + self._mention_popup.show() + self._mention_active = True + + def _on_mention_selected(self, item): + """Insert the selected @mention into msg_input.""" + username = item.text() + text = self.msg_input.toPlainText() + cursor_pos = self.msg_input.textCursor().position() + # Replace @prefix with @username + new_text = text[:self._mention_start] + f"@{username} " + text[cursor_pos:] + self.msg_input.setPlainText(new_text) + cursor = self.msg_input.textCursor() + cursor.setPosition(self._mention_start + len(username) + 2) + self.msg_input.setTextCursor(cursor) + self._mention_popup.hide() + self._mention_active = False + self.msg_input.setFocus() + + def _on_input_changed(self): + text = self.msg_input.toPlainText() + count = len(text) + if count > MAX_INPUT_CHARS: + cursor = self.msg_input.textCursor() + self.msg_input.setPlainText(text[:MAX_INPUT_CHARS]) + cursor.movePosition(cursor.MoveOperation.End) + self.msg_input.setTextCursor(cursor) + count = MAX_INPUT_CHARS + color = c().error if count > MAX_INPUT_CHARS * 0.9 else c().text_muted + self.char_counter.setStyleSheet(f"color: {color}; font-size: 8pt; padding: 0 4px;") + self.char_counter.setText(f"{count} / {MAX_INPUT_CHARS}") + # @Mention autocomplete check + if self.current_conv_id: + self._check_mention_trigger() + + def _cancel_reply(self): + """Dismiss the reply preview.""" + self.reply_to_id = None + self.reply_label.setText("") + self._reply_widget.setVisible(False) + + def _on_send(self): + text = self.msg_input.toPlainText().strip() + if not text or not self.current_conv_id: + return + if len(text) > MAX_INPUT_CHARS: + QMessageBox.warning(self, "Message Too Long", + f"Message too long (max {MAX_INPUT_CHARS} characters).") + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + self.msg_input.clear() + self.bridge.send_message( + self.current_conv_id, text, conv["members"], + reply_to=self.reply_to_id, + ) + self.reply_to_id = None + self._reply_widget.setVisible(False) + + def _on_message_sent(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Error", msg) + + def _on_message_sent_payload(self, conv_id, payload): + """Append the just-sent message locally without re-fetching from server.""" + # Update last-message cache + preview = payload.get("text", "") + if payload.get("image") and not preview: + preview = "Sent an image" + elif payload.get("file") and not preview: + preview = "Sent a file" + self._last_message_cache[conv_id] = (preview[:60], payload.get("created_at", ""), "sent") + self._update_conv_list_styles() + + if conv_id != self.current_conv_id: + return + # Avoid duplicate if notification arrived first (race) + msg_id = payload.get("message_id", "") + if msg_id: + for m in self.current_messages: + if m.get("message_id") == msg_id: + return + self.current_messages.append(payload) + idx = len(self.current_messages) - 1 + w = self._create_message_widget(payload, idx) + self._msg_layout.addWidget(w) + self._msg_widgets.append(w) + if self._is_near_bottom: + QTimer.singleShot(10, self._scroll_to_bottom) + + def _on_new_chat(self): + dlg = QDialog(self) + dlg.setMinimumWidth(400) + t = c() + lay = _make_frameless(dlg, "New Chat") + lay.setSpacing(12) + + email_label = QLabel("Recipient email") + email_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(email_label) + email_input = QLineEdit() + email_input.setPlaceholderText("user@example.com") + email_input.setMinimumHeight(36) + lay.addWidget(email_input) + + msg_label = QLabel("First message") + msg_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(msg_label) + msg_input = QLineEdit() + msg_input.setPlaceholderText("Type a message...") + msg_input.setMinimumHeight(36) + lay.addWidget(msg_input) + + btn_row = QHBoxLayout() + btn_row.addStretch() + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("secondaryBtn") + cancel_btn.clicked.connect(dlg.reject) + btn_row.addWidget(cancel_btn) + send_btn = QPushButton("Send") + send_btn.clicked.connect(dlg.accept) + btn_row.addWidget(send_btn) + lay.addLayout(btn_row) + + msg_input.returnPressed.connect(dlg.accept) + email_input.returnPressed.connect(lambda: msg_input.setFocus()) + + if dlg.exec() != QDialog.DialogCode.Accepted: + return + email = email_input.text().strip() + text = msg_input.text().strip() + if not email or not text: + return + if len(text) > MAX_INPUT_CHARS: + QMessageBox.warning(self, "Message Too Long", + f"Message too long (max {MAX_INPUT_CHARS} characters).") + return + self.bridge.send_new_chat(email, text) + + def _on_new_group(self): + dlg = QDialog(self) + dlg.setMinimumWidth(400) + t = c() + lay = _make_frameless(dlg, "New Group") + lay.setSpacing(12) + + name_label = QLabel("Group name") + name_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(name_label) + name_input = QLineEdit() + name_input.setPlaceholderText("My Group") + name_input.setMinimumHeight(36) + lay.addWidget(name_input) + + members_label = QLabel("Member emails (comma-separated)") + members_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(members_label) + members_input = QLineEdit() + members_input.setPlaceholderText("alice@example.com, bob@example.com") + members_input.setMinimumHeight(36) + lay.addWidget(members_input) + + btn_row = QHBoxLayout() + btn_row.addStretch() + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("secondaryBtn") + cancel_btn.clicked.connect(dlg.reject) + btn_row.addWidget(cancel_btn) + create_btn = QPushButton("Create") + create_btn.clicked.connect(dlg.accept) + btn_row.addWidget(create_btn) + lay.addLayout(btn_row) + + members_input.returnPressed.connect(dlg.accept) + name_input.returnPressed.connect(lambda: members_input.setFocus()) + + if dlg.exec() != QDialog.DialogCode.Accepted: + return + members = members_input.text().strip() + if not members: + return + member_list = [m.strip() for m in members.split(",") if m.strip()] + if member_list: + self.bridge.create_group(member_list, name=name_input.text().strip() or None) + + def _on_add_member(self): + if not self.current_conv_id: + return + dlg = QDialog(self) + dlg.setMinimumWidth(360) + t = c() + lay = _make_frameless(dlg, "Add Member") + lay.setSpacing(12) + lbl = QLabel("Email to invite") + lbl.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(lbl) + email_input = QLineEdit() + email_input.setPlaceholderText("user@example.com") + email_input.setMinimumHeight(36) + email_input.returnPressed.connect(dlg.accept) + lay.addWidget(email_input) + btn_row = QHBoxLayout() + btn_row.addStretch() + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("secondaryBtn") + cancel_btn.clicked.connect(dlg.reject) + btn_row.addWidget(cancel_btn) + add_btn = QPushButton("Add") + add_btn.clicked.connect(dlg.accept) + btn_row.addWidget(add_btn) + lay.addLayout(btn_row) + if dlg.exec() != QDialog.DialogCode.Accepted: + return + email = email_input.text().strip() + if not email: + return + self.bridge.add_member(self.current_conv_id, email) + + def _on_add_member_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Add Member", "Invitation sent.") + else: + QMessageBox.warning(self, "Add Member", msg) + + def _on_group_info(self): + if not self.current_conv_id: + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + is_creator = conv.get("created_by") == my_user_id + group_name = conv.get("name") or "Group" + members = conv["members"] + + dlg = QDialog(self) + dlg.setMinimumWidth(380) + t = c() + dlg_layout = _make_frameless(dlg, "Group Info") + + # Group avatar + avatar_row = QHBoxLayout() + avatar_label = QLabel() + avatar_label.setFixedSize(64, 64) + avatar_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + conv_id = conv["conversation_id"] + if conv_id in self._group_avatar_cache: + avatar_pix = self._make_circular_avatar(self._group_avatar_cache[conv_id], size=64) + else: + avatar_pix = self._make_default_avatar(group_name, size=64) + avatar_label.setPixmap(avatar_pix) + avatar_row.addWidget(avatar_label) + + group_name_esc = group_name.replace("&", "&").replace("<", "<").replace(">", ">") + title = QLabel(f"{group_name_esc}") + avatar_row.addWidget(title, stretch=1) + + if is_creator: + change_avatar_btn = QPushButton("Change Avatar") + change_avatar_btn.setObjectName("secondaryBtn") + change_avatar_btn.clicked.connect(lambda: self._do_change_group_avatar(conv_id, dlg)) + avatar_row.addWidget(change_avatar_btn) + + rename_btn = QPushButton("Rename") + rename_btn.setObjectName("secondaryBtn") + rename_btn.clicked.connect(lambda: self._do_rename_group(conv_id, group_name, dlg)) + avatar_row.addWidget(rename_btn) + + dlg_layout.addLayout(avatar_row) + + count_label = QLabel(f"Members ({len(members)}):") + count_label.setStyleSheet("margin-top: 8px;") + dlg_layout.addWidget(count_label) + + for mem in members: + uname = mem.get("username") or mem.get("email") or "?" + email = mem.get("email", "") + uid = mem.get("user_id") or mem.get("id") or "" + is_mem_creator = uid == conv.get("created_by") + + row = QHBoxLayout() + uname_esc = uname.replace("&", "&").replace("<", "<").replace(">", ">") + email_esc = email.replace("&", "&").replace("<", "<").replace(">", ">") + is_online = uid in self._online_users + online_dot = "\U0001f7e2 " if is_online else "" + verified_dot = "" + my_uid = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + if uid and uid != my_uid and self.bridge.client.get_verification_status(uid) == "verified": + verified_dot = " \u2705" + name_text = f"{online_dot}{uname_esc}{verified_dot}" + if email: + name_text += f" {email_esc}" + if is_mem_creator: + name_text += f" creator" + name_label = QLabel(name_text) + name_label.setWordWrap(True) + row.addWidget(name_label, stretch=1) + + info_btn = QPushButton("") + info_btn.setFixedSize(28, 28) + info_btn.setObjectName("secondaryBtn") + info_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_MessageBoxInformation)) + info_btn.setToolTip(f"View profile of {uname}") + info_btn.clicked.connect(lambda checked, u=uid, d=dlg: (d.accept(), self._show_user_profile(u))) + row.addWidget(info_btn) + + # Remove button (only for creator, not on self) + if is_creator and uid != my_user_id: + remove_btn = QPushButton("") + remove_btn.setFixedSize(28, 28) + remove_btn.setObjectName("secondaryBtn") + remove_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogCloseButton)) + remove_btn.setToolTip(f"Remove {uname}") + remove_btn.clicked.connect(lambda checked, u=uid, n=uname, d=dlg: self._do_remove_member_action(u, n, d)) + row.addWidget(remove_btn) + + dlg_layout.addLayout(row) + + dlg_layout.addSpacing(12) + + # Leave Group button + leave_btn = QPushButton("Leave Group") + leave_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.error}; color: {t.accent_text}; font-weight: bold; }}" + f"QPushButton:hover {{ opacity: 0.8; }}" + ) + leave_btn.clicked.connect(lambda: self._do_leave_group_action(dlg)) + dlg_layout.addWidget(leave_btn) + + close_btn = QPushButton("Close") + close_btn.clicked.connect(dlg.accept) + dlg_layout.addWidget(close_btn) + dlg.exec() + + def _do_remove_member_action(self, user_id, username, dialog): + if not self.current_conv_id: + return + confirm = QMessageBox.question( + self, "Remove Member", + f"Remove {username} from the group?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + dialog.accept() + self.bridge.remove_member(self.current_conv_id, user_id) + + def _do_change_group_avatar(self, conv_id, dialog): + path, _ = QFileDialog.getOpenFileName( + dialog, "Select Group Avatar", "", + "Images (*.png *.jpg *.jpeg);;All Files (*)", + ) + if not path: + return + try: + with open(path, "rb") as f: + image_data = f.read() + if len(image_data) > 2 * 1024 * 1024: + QMessageBox.warning(dialog, "Too Large", "Avatar must be under 2 MB.") + return + dialog.accept() + self.bridge.update_group_avatar(conv_id, image_data) + except Exception as e: + QMessageBox.warning(dialog, "Error", f"Failed to read image: {e}") + + def _do_rename_group(self, conv_id, current_name, dialog): + from PyQt6.QtWidgets import QInputDialog + new_name, ok = QInputDialog.getText( + dialog, "Rename Group", "New group name:", + text=current_name, + ) + if ok and new_name.strip(): + new_name = new_name.strip() + if new_name != current_name: + dialog.accept() + self.bridge.rename_conversation(conv_id, new_name) + + def _on_group_renamed(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Rename Group", msg) + + def _do_leave_group_action(self, dialog): + if not self.current_conv_id: + return + confirm = QMessageBox.question( + self, "Leave Group", + "Leave this group? You will no longer receive messages.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + dialog.accept() + self.bridge.leave_group(self.current_conv_id) + + def _on_group_left(self, ok, msg): + if ok: + self.current_conv_id = None + self.msg_input.drop_enabled = False + self.chat_header.setText("Select a conversation") + self.chat_header_avatar.setVisible(False) + self._clear_message_area() + self.group_info_btn.setVisible(False) + self.user_info_btn.setVisible(False) + self.add_member_btn.setVisible(False) + self.delete_conv_btn.setVisible(False) + else: + QMessageBox.warning(self, "Leave Group", msg) + + def _on_delete_conv_btn(self): + if not self.current_conv_id: + return + confirm = QMessageBox.question( + self, "Delete Conversation", + "Delete this conversation? This cannot be undone.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + self.bridge.delete_conversation(self.current_conv_id) + + def _on_conversation_deleted(self, ok, msg): + if ok: + self.current_conv_id = None + self.msg_input.drop_enabled = False + self.chat_header.setText("Select a conversation") + self.chat_header_avatar.setVisible(False) + self._clear_message_area() + self.group_info_btn.setVisible(False) + self.user_info_btn.setVisible(False) + self.add_member_btn.setVisible(False) + self.delete_conv_btn.setVisible(False) + else: + QMessageBox.warning(self, "Delete Conversation", msg) + + def _on_remove_member_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Remove Member", "Member removed.") + if self.current_conv_id: + self.bridge.load_messages(self.current_conv_id) + else: + QMessageBox.warning(self, "Remove Member", msg) + + def _on_my_profile(self): + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + if not my_user_id: + return + dlg = UserProfileDialog(self.bridge, my_user_id, editable=True, parent=self) + dlg.exec() + + def _on_dm_user_info(self): + """Show profile of the other user in a DM conversation.""" + if not self.current_conv_id: + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + my_email = self.bridge.client.email + for m in conv["members"]: + if m.get("email") != my_email: + uid = m.get("user_id") or m.get("id") + if uid: + self._show_user_profile(uid) + return + + def _show_user_profile(self, user_id): + dlg = UserProfileDialog(self.bridge, user_id, editable=False, parent=self) + dlg.exec() + + def _on_open_settings(self): + t = c() + dlg = QDialog(self) + dlg.setMinimumWidth(360) + lay = _make_frameless(dlg, "Settings") + lay.setSpacing(16) + + # -- Appearance section -- + sec_appearance = QLabel("Appearance") + sec_appearance.setStyleSheet( + f"font-size: 10pt; font-weight: bold; color: {t.text_secondary}; " + f"margin-top: 4px;" + ) + lay.addWidget(sec_appearance) + + theme_row = QHBoxLayout() + theme_label = QLabel("Theme") + theme_label.setStyleSheet(f"font-size: 11pt; color: {t.text_primary};") + theme_row.addWidget(theme_label) + theme_row.addStretch() + theme_btn = QPushButton( + "\u2600 Light mode" if tm().is_dark else "\U0001f319 Dark mode" + ) + theme_btn.setObjectName("secondaryBtn") + theme_btn.setFixedWidth(140) + theme_row.addWidget(theme_btn) + lay.addLayout(theme_row) + + # Separator + sep1 = QFrame() + sep1.setFrameShape(QFrame.Shape.HLine) + sep1.setStyleSheet(f"background-color: {t.separator}; max-height: 1px;") + lay.addWidget(sep1) + + # -- Security section -- + sec_security = QLabel("Security") + sec_security.setStyleSheet( + f"font-size: 10pt; font-weight: bold; color: {t.text_secondary};" + ) + lay.addWidget(sec_security) + + # Rotate Keys + rotate_row = QHBoxLayout() + rotate_info = QVBoxLayout() + rotate_title = QLabel("Rotate Keys") + rotate_title.setStyleSheet(f"font-size: 11pt; color: {t.text_primary};") + rotate_info.addWidget(rotate_title) + rotate_desc = QLabel("Generate new RSA keys. Revokes other devices.") + rotate_desc.setStyleSheet(f"font-size: 8pt; color: {t.text_muted};") + rotate_desc.setWordWrap(True) + rotate_info.addWidget(rotate_desc) + rotate_row.addLayout(rotate_info, stretch=1) + rotate_btn = QPushButton("Rotate") + rotate_btn.setObjectName("secondaryBtn") + rotate_btn.setFixedWidth(100) + rotate_btn.clicked.connect(lambda: (dlg.close(), self._on_rotate_keys())) + rotate_row.addWidget(rotate_btn) + lay.addLayout(rotate_row) + + # Change Username + chun_row = QHBoxLayout() + chun_info = QVBoxLayout() + chun_title = QLabel("Change Username") + chun_title.setStyleSheet(f"font-size: 11pt; color: {t.text_primary};") + chun_info.addWidget(chun_title) + chun_desc = QLabel("Change your display name.") + chun_desc.setStyleSheet(f"font-size: 8pt; color: {t.text_muted};") + chun_desc.setWordWrap(True) + chun_info.addWidget(chun_desc) + chun_row.addLayout(chun_info, stretch=1) + chun_btn = QPushButton("Change") + chun_btn.setObjectName("secondaryBtn") + chun_btn.setFixedWidth(100) + chun_btn.clicked.connect(lambda: (dlg.close(), self._on_change_username())) + chun_row.addWidget(chun_btn) + lay.addLayout(chun_row) + + # Change Password + chpw_row = QHBoxLayout() + chpw_info = QVBoxLayout() + chpw_title = QLabel("Change Password") + chpw_title.setStyleSheet(f"font-size: 11pt; color: {t.text_primary};") + chpw_info.addWidget(chpw_title) + chpw_desc = QLabel("Change password for local key encryption.") + chpw_desc.setStyleSheet(f"font-size: 8pt; color: {t.text_muted};") + chpw_desc.setWordWrap(True) + chpw_info.addWidget(chpw_desc) + chpw_row.addLayout(chpw_info, stretch=1) + chpw_btn = QPushButton("Change") + chpw_btn.setObjectName("secondaryBtn") + chpw_btn.setFixedWidth(100) + chpw_btn.clicked.connect(lambda: (dlg.close(), self._on_change_password())) + chpw_row.addWidget(chpw_btn) + lay.addLayout(chpw_row) + + # Separator + sep2 = QFrame() + sep2.setFrameShape(QFrame.Shape.HLine) + sep2.setStyleSheet(f"background-color: {t.separator}; max-height: 1px;") + lay.addWidget(sep2) + + # -- Devices section -- + sec_devices = QLabel("Devices") + sec_devices.setStyleSheet( + f"font-size: 10pt; font-weight: bold; color: {t.text_secondary};" + ) + lay.addWidget(sec_devices) + + # Link Device (new) + link_row = QHBoxLayout() + link_info = QVBoxLayout() + link_title = QLabel("Link New Device") + link_title.setStyleSheet(f"font-size: 11pt; color: {t.text_primary};") + link_info.addWidget(link_title) + link_desc = QLabel("Authorize another device to access your account.") + link_desc.setStyleSheet(f"font-size: 8pt; color: {t.text_muted};") + link_desc.setWordWrap(True) + link_info.addWidget(link_desc) + link_row.addLayout(link_info, stretch=1) + link_btn = QPushButton("Link") + link_btn.setObjectName("secondaryBtn") + link_btn.setFixedWidth(100) + link_btn.clicked.connect(lambda: (dlg.close(), self._on_authorize_device())) + link_row.addWidget(link_btn) + lay.addLayout(link_row) + + # Wire up theme toggle now that all widgets exist + _s_labels = [sec_appearance, sec_security, sec_devices] + _t_labels = [theme_label, rotate_title, chpw_title, link_title] + _d_labels = [rotate_desc, chpw_desc, link_desc] + _seps = [sep1, sep2] + + def _toggle_and_update(): + tm().toggle() + theme_btn.setText( + "\u2600 Light mode" if tm().is_dark else "\U0001f319 Dark mode" + ) + t2 = c() + dlg._frameless_container.setStyleSheet( + f"#_framelessContainer {{ background-color: {t2.bg_primary}; border-radius: 12px; }}" + ) + dlg._frameless_title_bar.setStyleSheet( + f"background-color: {t2.bg_secondary}; " + f"border-top-left-radius: 12px; border-top-right-radius: 12px;" + ) + dlg._frameless_title_label.setStyleSheet( + f"font-weight: bold; font-size: 10pt; color: {t2.text_primary}; background: transparent;" + ) + for lbl in _s_labels: + lbl.setStyleSheet(f"font-size: 10pt; font-weight: bold; color: {t2.text_secondary};") + for lbl in _t_labels: + lbl.setStyleSheet(f"font-size: 11pt; color: {t2.text_primary};") + for lbl in _d_labels: + lbl.setStyleSheet(f"font-size: 8pt; color: {t2.text_muted};") + for s in _seps: + s.setStyleSheet(f"background-color: {t2.separator}; max-height: 1px;") + theme_btn.clicked.connect(_toggle_and_update) + + lay.addStretch() + + # Close button + close_btn = QPushButton("Close") + close_btn.clicked.connect(dlg.close) + lay.addWidget(close_btn) + + dlg.exec() + + def _on_authorize_device(self): + code, ok = QInputDialog.getText(self, "Authorize Device", "Pairing code:") + if not ok or not code.strip(): + return + self.bridge.authorize_device(code.strip()) + + def _on_rotate_keys(self): + confirm = QMessageBox.question( + self, + "Rotate Keys", + "This will revoke other devices. Continue?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm != QMessageBox.StandardButton.Yes: + return + password, ok = QInputDialog.getText(self, "Rotate Keys", "Password:", QLineEdit.EchoMode.Password) + if not ok or not password: + return + self.bridge.rotate_keys(self.bridge.client.username, password) + + def _on_change_password(self): + t = c() + dlg = QDialog(self) + dlg.setMinimumWidth(360) + lay = _make_frameless(dlg, "Change Password") + + old_label = QLabel("Current Password") + old_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(old_label) + old_input = QLineEdit() + old_input.setEchoMode(QLineEdit.EchoMode.Password) + old_input.setPlaceholderText("Enter current password") + old_input.setStyleSheet( + f"QLineEdit {{ font-size: 11pt; background-color: {t.bg_secondary}; " + f"border: 1px solid {t.border}; border-radius: 6px; padding: 8px; " + f"color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + lay.addWidget(old_input) + + new_label = QLabel("New Password") + new_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(new_label) + new_input = QLineEdit() + new_input.setEchoMode(QLineEdit.EchoMode.Password) + new_input.setPlaceholderText("Enter new password") + new_input.setStyleSheet(old_input.styleSheet()) + lay.addWidget(new_input) + + confirm_label = QLabel("Confirm New Password") + confirm_label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(confirm_label) + confirm_input = QLineEdit() + confirm_input.setEchoMode(QLineEdit.EchoMode.Password) + confirm_input.setPlaceholderText("Re-enter new password") + confirm_input.setStyleSheet(old_input.styleSheet()) + lay.addWidget(confirm_input) + + error_label = QLabel("") + error_label.setStyleSheet(f"color: {t.error}; font-size: 9pt;") + error_label.hide() + lay.addWidget(error_label) + + btn_lay = QHBoxLayout() + btn_lay.addStretch() + change_btn = QPushButton("Change Password") + change_btn.setStyleSheet( + f"QPushButton {{ background-color: {t.accent}; color: {t.accent_text}; " + f"border: none; border-radius: 6px; padding: 8px 16px; font-weight: bold; }}" + f"QPushButton:hover {{ opacity: 0.9; }}" + ) + btn_lay.addWidget(change_btn) + lay.addLayout(btn_lay) + + def _do_change(): + old_pw = old_input.text() + new_pw = new_input.text() + conf_pw = confirm_input.text() + if not old_pw: + error_label.setText("Current password is required.") + error_label.show() + return + if not new_pw: + error_label.setText("New password cannot be empty.") + error_label.show() + return + if new_pw != conf_pw: + error_label.setText("New passwords do not match.") + error_label.show() + return + dlg.accept() + self.bridge.change_password(old_pw, new_pw) + + change_btn.clicked.connect(_do_change) + confirm_input.returnPressed.connect(_do_change) + dlg.exec() + + def _on_logout(self): + confirm = QMessageBox.question( + self, + "Logout", + "Log out and return to the login screen?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm != QMessageBox.StandardButton.Yes: + return + self._is_logout = True + self.bridge.logout() + self.close() + if self._on_logout_cb: + self._on_logout_cb() + + def _on_notification(self, payload): + sender = payload.get("sender", "???") + conv_id = payload.get("conversation_id", "") + + # Update last-message cache for conversation list preview + if conv_id: + preview = payload.get("text", "") + if payload.get("image") and not preview: + preview = "Sent an image" + elif payload.get("file") and not preview: + preview = "Sent a file" + if preview: + self._last_message_cache[conv_id] = ( + f"{sender}: {preview}"[:60], + payload.get("created_at", ""), + "", # incoming message — no receipt status for others' msgs + ) + + # Resolve conversation name for notifications + conv_name = sender + is_notif_dm = False + if conv_id: + for cv in self.conversations: + if cv["conversation_id"] == conv_id: + is_notif_dm = len(cv["members"]) == 2 and not cv.get("name") + if not is_notif_dm: + conv_name = cv.get("name") or sender + break + + # System tray toast when window is not visible or not focused + if conv_id: + if is_notif_dm: + notif_title = sender + notif_text = payload.get("text", "New message") + else: + notif_title = conv_name + notif_text = f"{sender}: {payload.get('text', 'New message')}" + if payload.get("image"): + notif_text = notif_text or "Sent an image" + elif payload.get("file"): + notif_text = notif_text or "Sent a file" + self._show_tray_notification(notif_title, notif_text) + + # Show notification in status bar (for non-current conversations) + if conv_id and conv_id != self.current_conv_id: + if is_notif_dm: + self.status_bar.setText(f"New message from {sender}") + else: + self.status_bar.setText(f"New message from {sender} in {conv_name}") + t = c() + self.status_bar.setStyleSheet( + f"background-color: {t.bg_tertiary}; border-radius: 0px; " + f"padding: 0 8px; color: {t.success}; font-size: 8pt; font-weight: bold;" + ) + self._status_bar_conv_id = conv_id + QTimer.singleShot(5000, self._clear_status_bar) + + # Confirm delivery for all incoming messages (always, regardless of current view) + msg_id = payload.get("message_id", "") + if conv_id and msg_id: + self.bridge.schedule( + self.bridge.client.confirm_delivery(conv_id, [msg_id]) + ) + + # Increment unread count if not currently viewing this conversation + # (or if privacy overlay is locked — user can't see messages) + viewing = conv_id == self.current_conv_id and not self._privacy_locked + if conv_id and not viewing: + self._unread_counts[conv_id] = self._unread_counts.get(conv_id, 0) + 1 + self._update_conv_list_styles() + + # Append directly to current conversation instead of re-fetching + if conv_id == self.current_conv_id: + # Avoid duplicate if local send already appended this message + if msg_id: + for m in self.current_messages: + if m.get("message_id") == msg_id: + return + self.current_messages.append(payload) + idx = len(self.current_messages) - 1 + w = self._create_message_widget(payload, idx) + self._msg_layout.addWidget(w) + self._msg_widgets.append(w) + if self._is_near_bottom: + QTimer.singleShot(10, self._scroll_to_bottom) + else: + self.jump_btn.setText("\u2193 New") + self.jump_btn.setVisible(True) + self._position_jump_btn() + # Mark as read (only if not locked) + if msg_id and not self._privacy_locked: + self.bridge.schedule( + self.bridge.client.mark_read(conv_id, [msg_id]) + ) + + def _on_messages_read(self, data): + conv_id = data.get("conversation_id", "") + user_id = data.get("user_id", "") + message_ids = set(data.get("message_ids", [])) + my_uid = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + + # Persist to cache for ALL conversations (not just current) + if conv_id == self.current_conv_id: + for msg in self.current_messages: + if message_ids and msg.get("message_id") not in message_ids: + continue + if not message_ids and msg.get("sender_id") != my_uid: + continue + read_by = msg.get("read_by", []) + if not any(r.get("user_id") == user_id for r in read_by): + read_by.append({"user_id": user_id}) + msg["read_by"] = read_by + self.bridge.client.update_message_in_cache( + conv_id, msg.get("message_id"), {"read_by": read_by}) + self._render_messages(scroll_to_bottom=self._is_near_bottom) + else: + # Non-current conversation: update cache only (no in-memory messages to update) + cached = self.bridge.client.load_message_cache(conv_id) + if cached: + for msg_id_key, entry in cached.items(): + if message_ids and msg_id_key not in message_ids: + continue + if not message_ids and entry.get("sender_id") != my_uid: + continue + read_by = entry.get("read_by", []) + if not any(r.get("user_id") == user_id for r in read_by): + read_by.append({"user_id": user_id}) + self.bridge.client.update_message_in_cache( + conv_id, msg_id_key, {"read_by": read_by}) + # Update conv list receipt status to "read" + if conv_id in self._last_message_cache: + prev_text, prev_ts, prev_receipt = self._last_message_cache[conv_id] + if prev_receipt in ("sent", "delivered"): + self._last_message_cache[conv_id] = (prev_text, prev_ts, "read") + self._update_conv_list_styles() + + def _on_message_delivered(self, data): + conv_id = data.get("conversation_id", "") + user_id = data.get("user_id", "") + message_ids = set(data.get("message_ids", [])) + + if conv_id == self.current_conv_id: + for msg in self.current_messages: + if msg.get("message_id") in message_ids: + delivered_to = msg.get("delivered_to", []) + if not any(d.get("user_id") == user_id for d in delivered_to): + delivered_to.append({"user_id": user_id}) + msg["delivered_to"] = delivered_to + self.bridge.client.update_message_in_cache( + conv_id, msg.get("message_id"), {"delivered_to": delivered_to}) + self._render_messages(scroll_to_bottom=self._is_near_bottom) + else: + # Non-current conversation: update cache only + cached = self.bridge.client.load_message_cache(conv_id) + if cached: + for msg_id_key, entry in cached.items(): + if msg_id_key in message_ids: + delivered_to = entry.get("delivered_to", []) + if not any(d.get("user_id") == user_id for d in delivered_to): + delivered_to.append({"user_id": user_id}) + self.bridge.client.update_message_in_cache( + conv_id, msg_id_key, {"delivered_to": delivered_to}) + # Update conv list receipt status to "delivered" + if conv_id in self._last_message_cache: + prev_text, prev_ts, prev_receipt = self._last_message_cache[conv_id] + if prev_receipt == "sent": + self._last_message_cache[conv_id] = (prev_text, prev_ts, "delivered") + self._update_conv_list_styles() + + def _on_link_clicked(self, url_str): + """Handle link clicks from message bubble labels.""" + if url_str.startswith("image://"): + file_id = url_str[len("image://"):] + for msg in self.current_messages: + image_info = msg.get("image") + if image_info and image_info.get("file_id") == file_id: + self._view_image(msg) + return + elif url_str.startswith("file://"): + file_id = url_str[len("file://"):] + for msg in self.current_messages: + file_info = msg.get("file") + if file_info and file_info.get("file_id") == file_id: + self.bridge.download_file(file_id, file_info) + return + elif url_str.startswith("https://"): + QDesktopServices.openUrl(QUrl(url_str)) + elif url_str.startswith("http://"): + reply = QMessageBox.warning( + self, + "Insecure link", + f"This link uses unencrypted HTTP.\n\n{url_str}\n\nContinue anyway?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if reply == QMessageBox.StandardButton.Yes: + QDesktopServices.openUrl(QUrl(url_str)) + + # -- Drag & drop -------------------------------------------------------- + + _IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"} + def _msg_area_normal_style(self): + return f"QScrollArea {{ background-color: {c().bg_primary}; border: none; }}" + + def eventFilter(self, obj, event): + """Handle drag-and-drop on message scroll area + context menu on messages.""" + from PyQt6.QtCore import QEvent + # Context menu from any child of message container + if event.type() == QEvent.Type.ContextMenu and obj is not self._msg_scroll_area: + widget = obj + idx = self._find_msg_index_at_widget(widget) + if idx is not None: + self._show_msg_context_menu(idx, event.globalPos()) + return True + if obj is self._msg_scroll_area: + if event.type() == QEvent.Type.DragEnter: + if not self.current_conv_id: + event.ignore() + return True + if event.mimeData().hasUrls() and any(u.isLocalFile() for u in event.mimeData().urls()): + event.acceptProposedAction() + self._msg_scroll_area.setStyleSheet( + f"QScrollArea {{ border: 2px dashed {c().accent}; }}" + ) + return True + elif event.type() == QEvent.Type.DragMove: + if event.mimeData().hasUrls(): + event.acceptProposedAction() + return True + elif event.type() == QEvent.Type.DragLeave: + self._msg_scroll_area.setStyleSheet(self._msg_area_normal_style()) + return True + elif event.type() == QEvent.Type.Drop: + self._msg_scroll_area.setStyleSheet(self._msg_area_normal_style()) + if event.mimeData().hasUrls(): + for url in event.mimeData().urls(): + if url.isLocalFile(): + self._on_file_dropped(url.toLocalFile()) + event.acceptProposedAction() + return True + return super().eventFilter(obj, event) + + def _on_file_dropped(self, path: str): + """Send a dropped file as image or file attachment.""" + if not self.current_conv_id: + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + import os + ext = os.path.splitext(path)[1].lower() + if ext in self._IMAGE_EXTENSIONS: + self.bridge.send_image( + self.current_conv_id, path, conv["members"], + reply_to=self.reply_to_id, + ) + else: + self.bridge.send_file( + self.current_conv_id, path, conv["members"], + reply_to=self.reply_to_id, + ) + self.reply_to_id = None + self._reply_widget.setVisible(False) + + # -- Attach menu ------------------------------------------------------- + + def _on_attach_image(self): + if not self.current_conv_id: + return + path, _ = QFileDialog.getOpenFileName( + self, "Select Image", "", + "Images (*.png *.jpg *.jpeg *.gif *.bmp *.webp);;All Files (*)", + ) + if not path: + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + self.bridge.send_image( + self.current_conv_id, path, conv["members"], + reply_to=self.reply_to_id, + ) + self.reply_to_id = None + self._reply_widget.setVisible(False) + + @staticmethod + def _human_file_size(size_bytes): + if size_bytes >= 1024 * 1024: + return f"{size_bytes / (1024 * 1024):.1f} MB" + elif size_bytes >= 1024: + return f"{size_bytes / 1024:.0f} KB" + return f"{size_bytes} B" + + @staticmethod + def _file_icon(filename: str) -> str: + """Return an emoji icon based on file extension.""" + ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else "" + _icons = { + "pdf": "\U0001f4d5", # red book + "doc": "\U0001f4d8", # blue book + "docx": "\U0001f4d8", + "odt": "\U0001f4d8", + "xls": "\U0001f4ca", # bar chart + "xlsx": "\U0001f4ca", + "ods": "\U0001f4ca", + "csv": "\U0001f4ca", + "ppt": "\U0001f4d9", # orange book + "pptx": "\U0001f4d9", + "odp": "\U0001f4d9", + "zip": "\U0001f4e6", # package + "rar": "\U0001f4e6", + "7z": "\U0001f4e6", + "tar": "\U0001f4e6", + "gz": "\U0001f4e6", + "mp3": "\U0001f3b5", # music note + "wav": "\U0001f3b5", + "flac": "\U0001f3b5", + "ogg": "\U0001f3b5", + "m4a": "\U0001f3b5", + "mp4": "\U0001f3ac", # clapper board + "mkv": "\U0001f3ac", + "avi": "\U0001f3ac", + "mov": "\U0001f3ac", + "webm": "\U0001f3ac", + "py": "\U0001f40d", # snake + "js": "\U0001f4dc", # scroll + "ts": "\U0001f4dc", + "html": "\U0001f310", # globe + "css": "\U0001f3a8", # palette + "json": "\U0001f4cb", # clipboard + "xml": "\U0001f4cb", + "yaml": "\U0001f4cb", + "yml": "\U0001f4cb", + "txt": "\U0001f4c4", # page facing up + "log": "\U0001f4c4", + "md": "\U0001f4c4", + } + return _icons.get(ext, "\U0001f4ce") # default: paperclip + + def _on_attach_file(self): + if not self.current_conv_id: + return + path, _ = QFileDialog.getOpenFileName( + self, "Select File", "", + "All Files (*)", + ) + if not path: + return + conv = None + for cv in self.conversations: + if cv["conversation_id"] == self.current_conv_id: + conv = cv + break + if not conv: + return + self.bridge.send_file( + self.current_conv_id, path, conv["members"], + reply_to=self.reply_to_id, + ) + self.reply_to_id = None + self._reply_widget.setVisible(False) + + def _on_file_sent(self, ok, msg): + if not ok: + QMessageBox.warning(self, "File Error", msg) + + def _on_file_downloaded(self, data, file_info): + filename = _safe_filename(file_info.get("filename", "file"), "file") + path, _ = QFileDialog.getSaveFileName(self, "Save File", filename) + if path: + try: + with open(path, "wb") as f: + f.write(data) + QMessageBox.information(self, "Saved", f"File saved to {path}") + except Exception as e: + QMessageBox.warning(self, "Error", f"Failed to save: {e}") + + def _on_image_sent(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Image Error", msg) + + def _view_image(self, msg): + image_info = msg.get("image") + if not image_info: + return + file_id = image_info.get("file_id", "") + self._pending_image_download = {"file_id": file_id, "image_info": image_info} + self.bridge.download_image(file_id, image_info) + + def _on_image_downloaded(self, file_id, data): + if not self._pending_image_download or self._pending_image_download["file_id"] != file_id: + return + image_info = self._pending_image_download["image_info"] + self._pending_image_download = None + self._show_image_dialog(data, image_info) + + def _show_image_dialog(self, image_data, image_info): + dlg = QDialog(self) + dlg.setMinimumSize(400, 300) + img_title = _safe_filename(image_info.get("filename", "Image"), "Image") + layout = _make_frameless(dlg, img_title) + + qimg = _safe_load_image(image_data) + if qimg is None: + layout.addWidget(QLabel("Failed to load image.")) + else: + pixmap = QPixmap.fromImage(qimg) + label = QLabel() + # Scale down if larger than screen + screen_size = self.screen().availableSize() + max_w = int(screen_size.width() * 0.8) + max_h = int(screen_size.height() * 0.8) + if pixmap.width() > max_w or pixmap.height() > max_h: + pixmap = pixmap.scaled(max_w, max_h, Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation) + label.setPixmap(pixmap) + label.setAlignment(Qt.AlignmentFlag.AlignCenter) + + scroll = QScrollArea() + scroll.setWidget(label) + scroll.setWidgetResizable(True) + layout.addWidget(scroll) + + btn_row = QHBoxLayout() + save_btn = QPushButton("Save") + save_btn.clicked.connect(lambda: self._save_image(image_data, image_info, dlg)) + btn_row.addWidget(save_btn) + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(dlg.accept) + btn_row.addWidget(close_btn) + layout.addLayout(btn_row) + + if qimg is not None and not qimg.isNull(): + dlg.resize(min(pixmap.width() + 40, max_w), + min(pixmap.height() + 80, max_h)) + dlg.exec() + + def _save_image(self, image_data, image_info, dialog): + filename = _safe_filename(image_info.get("filename", "image.jpg"), "image.jpg") + path, _ = QFileDialog.getSaveFileName(dialog, "Save Image", filename) + if path: + try: + with open(path, "wb") as f: + f.write(image_data) + QMessageBox.information(dialog, "Saved", f"Image saved to {path}") + except Exception as e: + QMessageBox.warning(dialog, "Error", f"Failed to save: {e}") + + def _on_message_deleted(self, data): + message_id = data.get("message_id", "") + conv_id = data.get("conversation_id", "") + if conv_id == self.current_conv_id: + for msg in self.current_messages: + if msg.get("message_id") == message_id: + msg["deleted"] = True + msg["text"] = "" + msg["image"] = None + break + self._render_messages() + + def _on_delete_message_result(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Delete Error", msg) + return + # No need to reload — _on_message_deleted() already updates in-place via notification + + def _on_authorize_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Authorize Device", msg) + else: + QMessageBox.warning(self, "Authorize Device", msg) + + def _on_rotate_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Rotate Keys", msg) + else: + QMessageBox.warning(self, "Rotate Keys", msg) + + def _on_password_changed(self, ok, msg): + if ok: + QMessageBox.information(self, "Change Password", msg) + else: + QMessageBox.warning(self, "Change Password", msg) + + def _on_change_username(self): + t = c() + current = "" + if self.bridge and self.bridge.client: + current = getattr(self.bridge.client, "username", "") or "" + dlg = QDialog(self) + dlg.setMinimumWidth(360) + lay = _make_frameless(dlg, "Change Username") + + label = QLabel("New Username") + label.setStyleSheet(f"color: {t.text_secondary}; font-size: 9pt;") + lay.addWidget(label) + name_input = QLineEdit() + name_input.setText(current) + name_input.setPlaceholderText("Enter new username") + name_input.setMaxLength(100) + name_input.setStyleSheet( + f"QLineEdit {{ font-size: 11pt; background-color: {t.bg_secondary}; " + f"border: 1px solid {t.border}; border-radius: 6px; padding: 8px; " + f"color: {t.text_primary}; }}" + f"QLineEdit:focus {{ border: 1px solid {t.border_focus}; }}" + ) + lay.addWidget(name_input) + + btn_row = QHBoxLayout() + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("secondaryBtn") + cancel_btn.clicked.connect(dlg.reject) + btn_row.addWidget(cancel_btn) + save_btn = QPushButton("Save") + save_btn.setObjectName("primaryBtn") + save_btn.clicked.connect(dlg.accept) + btn_row.addWidget(save_btn) + lay.addLayout(btn_row) + + name_input.setFocus() + if dlg.exec() == QDialog.DialogCode.Accepted: + new_name = name_input.text().strip() + if new_name and new_name != current: + self.bridge.change_username(new_name) + + def _on_username_changed(self, ok, msg): + if ok: + QMessageBox.information(self, "Change Username", msg) + self.bridge.schedule(self.bridge._do_load_conversations()) + else: + QMessageBox.warning(self, "Change Username", msg) + + def _on_reencrypt_status(self, msg): + self.reencrypt_label.setText(msg) + self.reencrypt_label.setVisible(True) + if msg.lower().startswith("re-encryption complete"): + QTimer.singleShot(4000, lambda: self.reencrypt_label.setVisible(False)) + + def closeEvent(self, event): + if self._tray_icon: + self._tray_icon.hide() + if not self._is_logout: + self.bridge.stop() + self.bridge.wait(2000) + event.accept() + + +def main(): + setup_logging() + + # Suppress Qt screen enumeration warnings on Windows (monitor sleep/wake) + os.environ.setdefault("QT_LOGGING_RULES", "qt.qpa.screen=false") + + # Windows 10+ requires AppUserModelID for system tray notifications to work + if sys.platform == "win32": + try: + import ctypes + ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("com.encrypted-chat.client") + except Exception: + pass + + app = QApplication(sys.argv) + app.setStyleSheet(qss()) + + bridge = AsyncBridge() + + login_win = LoginWindow(bridge) + main_win = [None] # mutable ref + + def on_connected(): + login_win.reset() + login_win.show() + + def on_conn_error(msg): + QMessageBox.critical(None, "Connection Error", f"Cannot connect to server:\n{msg}") + sys.exit(1) + + def on_register_result(ok, msg): + if ok: + # Show verification code page inline + hint = "" + if msg and len(msg) <= 6 and msg.isdigit(): + hint = f"Code: {msg}" + elif msg: + hint = msg + login_win.show_verification_page(hint) + + def do_confirm(code): + async def _confirm(): + okc, msgc = await bridge.client.confirm_registration( + login_win.email_input.text().strip(), + login_win.username_input.text().strip(), + code.strip(), + ) + if okc: + login_win.show_success(msgc) + bridge.do_login(login_win.email_input.text().strip(), login_win.password_input.text()) + else: + login_win.show_error(msgc) + bridge.schedule(_confirm()) + + login_win._confirm_callback = do_confirm + else: + login_win.show_error(msg) + + def on_pairing_code(code): + login_win.show_success(f"Pairing code: {code}") + + def on_pairing_complete(ok, msg): + if ok: + login_win.show_success(msg) + bridge.do_login(login_win._pair_email, login_win._pair_password) + else: + login_win.show_error(msg) + + def on_login_result(ok, msg): + if ok: + login_win.show_success(msg) + login_win.hide() + tm().set_email(bridge.client.email) + app.setStyleSheet(qss()) + main_win[0] = MainWindow(bridge, on_logout=lambda: (login_win.reset(), login_win.show())) + main_win[0].show() + else: + login_win.show_error(msg) + + bridge.connected.connect(on_connected) + bridge.connection_error.connect(on_conn_error) + bridge.register_result.connect(on_register_result) + bridge.login_result.connect(on_login_result) + bridge.pairing_code.connect(on_pairing_code) + bridge.pairing_complete.connect(on_pairing_complete) + bridge.reconnected.connect(lambda: (login_win.reset(), login_win.show())) + + bridge.start() + + sys.exit(app.exec()) + + +if __name__ == "__main__": + main() diff --git a/ios_client/EncryptedChat/App/AppState.swift b/ios_client/EncryptedChat/App/AppState.swift new file mode 100644 index 0000000..b1aac68 --- /dev/null +++ b/ios_client/EncryptedChat/App/AppState.swift @@ -0,0 +1,18 @@ +import Foundation +import SwiftUI + +enum ConnectionStatus: Equatable { + case disconnected + case connecting + case connected +} + +@Observable +final class AppState { + var isLoggedIn = false + var currentUser: User? + var connectionStatus: ConnectionStatus = .disconnected + var email: String = "" + + let chatClient = ChatClient() +} diff --git a/ios_client/EncryptedChat/App/EncryptedChatApp.swift b/ios_client/EncryptedChat/App/EncryptedChatApp.swift new file mode 100644 index 0000000..8747fde --- /dev/null +++ b/ios_client/EncryptedChat/App/EncryptedChatApp.swift @@ -0,0 +1,36 @@ +import SwiftUI + +@main +struct EncryptedChatApp: App { + @State private var appState = AppState() + @State private var authViewModel = AuthViewModel() + + var body: some Scene { + WindowGroup { + if appState.isLoggedIn { + MainTabView(appState: appState) + } else { + LoginView(viewModel: authViewModel, appState: appState) + } + } + } +} + +struct MainTabView: View { + var appState: AppState + @State private var convListVM = ConversationListVM() + + var body: some View { + TabView { + ConversationListView(appState: appState, viewModel: convListVM) + .tabItem { + Label("Chats", systemImage: "bubble.left.and.bubble.right.fill") + } + + ProfileView(appState: appState, isOwnProfile: true) + .tabItem { + Label("Profile", systemImage: "person.fill") + } + } + } +} diff --git a/ios_client/EncryptedChat/Core/ChatClient.swift b/ios_client/EncryptedChat/Core/ChatClient.swift new file mode 100644 index 0000000..40de811 --- /dev/null +++ b/ios_client/EncryptedChat/Core/ChatClient.swift @@ -0,0 +1,1644 @@ +import Foundation +import CryptoKit + +/// Notification types from the server +enum ChatNotification { + case newMessage(data: [String: Any]) + case messagesRead(data: [String: Any]) + case messageDeleted(data: [String: Any]) + case conversationCreated(data: [String: Any]) + case memberAdded(data: [String: Any]) + case memberRemoved(data: [String: Any]) + case userOnline(userId: String) + case userOffline(userId: String) + case onlineUsers(userIds: [String]) + case groupInvitation(data: [String: Any]) + case conversationRenamed(data: [String: Any]) + case sessionReset(data: [String: Any]) + case connectionStateChanged(connected: Bool) +} + +/// Main chat client — handles all server communication and crypto operations. +/// Thread-safe via Swift actor isolation. +/// Port of Python ChatClient class from chat_core.py +actor ChatClient { + + // MARK: - Connection + + let connectionManager = ConnectionManager() + private(set) var isConnected = false + private(set) var sessionToken: String? + private(set) var userId: String? + private(set) var username: String = "" + private(set) var email: String = "" + private(set) var loginRejected = false + + // MARK: - Keys + + private var rsaPrivate: SecKey? + private var rsaPublic: SecKey? + private(set) var identityPrivate: Curve25519.Signing.PrivateKey? + private(set) var identityPublic: Curve25519.Signing.PublicKey? + private var spkPrivate: Curve25519.KeyAgreement.PrivateKey? + private var spkId: String = "" + private var prevSpkPrivate: Curve25519.KeyAgreement.PrivateKey? + private var prevSpkId: String = "" + private var opkPrivates: [String: Curve25519.KeyAgreement.PrivateKey] = [:] + + // MARK: - Sessions & Sender Keys + + private var sessions: [String: DoubleRatchet] = [:] // "userId:deviceId" -> ratchet + private var senderKeyStates: [String: SenderKeyState] = [:] // convId -> own sender key + private var recvSenderKeys: [String: SenderKeyState] = [:] // "convId:senderId:deviceId" -> their key + + // MARK: - Derived Keys + + private var cacheKey: Data? // for encrypting message cache + private var localKey: Data? // for encrypting session/sender key files + + // MARK: - Multi-Device + + private(set) var deviceId: String? + + // MARK: - Caches + + private var userCache: [String: User] = [:] + private var deviceBundleCache: [String: (timestamp: Date, bundles: [DeviceBundle])] = [:] + + // MARK: - Request/Response Tracking + + private var pendingRequests: [String: CheckedContinuation<[String: Any], Error>] = [:] + private var listenerTask: Task? + + // MARK: - Notification Stream + + private var notificationContinuation: AsyncStream.Continuation? + nonisolated let notifications: AsyncStream + + // MARK: - Init + + init() { + var continuation: AsyncStream.Continuation! + notifications = AsyncStream { cont in + continuation = cont + } + notificationContinuation = continuation + } + + // MARK: - Connection + + func connect(host: String = Constants.defaultHost, port: UInt16 = Constants.defaultPort, + useTLS: Bool = false, tlsInsecure: Bool = false) async throws { + try await connectionManager.connect(host: host, port: port, useTLS: useTLS, tlsInsecure: tlsInsecure) + isConnected = true + notificationContinuation?.yield(.connectionStateChanged(connected: true)) + } + + func disconnect() async { + listenerTask?.cancel() + listenerTask = nil + await connectionManager.disconnect() + isConnected = false + // Fail all pending requests + let pending = pendingRequests + pendingRequests.removeAll() + for (_, cont) in pending { + cont.resume(throwing: NetworkError.notConnected) + } + notificationContinuation?.yield(.connectionStateChanged(connected: false)) + } + + // MARK: - Send and Receive + + /// Send a request and wait for the matching response. + func sendAndReceive(type: String, timeout: TimeInterval = 30, params: [String: Any] = [:]) async -> [String: Any] { + let requestId = ProtocolHandler.newRequestId() + + do { + let response: [String: Any] = try await withCheckedThrowingContinuation { continuation in + pendingRequests[requestId] = continuation + + Task { + do { + try await connectionManager.sendMessage(type: type, requestId: requestId, params: params) + } catch { + if let cont = pendingRequests.removeValue(forKey: requestId) { + cont.resume(throwing: error) + } + } + } + } + return response + } catch { + pendingRequests.removeValue(forKey: requestId) + return [ + "type": type, + "status": "error", + "data": ["message": error.localizedDescription] + ] + } + } + + // MARK: - Background Listener + + func startBackgroundListener() { + listenerTask = Task { [weak self] in + guard let self = self else { return } + await self.backgroundListenerLoop() + } + } + + private func backgroundListenerLoop() async { + while !Task.isCancelled { + do { + guard let msg = try await connectionManager.readMessage() else { + // EOF — connection closed + await handleDisconnect() + break + } + await routeMessage(msg) + } catch { + await handleDisconnect() + break + } + } + } + + private func handleDisconnect() { + isConnected = false + // Fail all pending futures + let pending = pendingRequests + pendingRequests.removeAll() + for (_, cont) in pending { + cont.resume(throwing: NetworkError.notConnected) + } + notificationContinuation?.yield(.connectionStateChanged(connected: false)) + } + + private func routeMessage(_ msg: [String: Any]) { + let msgType = msg["type"] as? String ?? "" + + // Notification types (no request_id expected from client) + let notificationTypes = Set([ + "new_message", "messages_read", "message_deleted", + "conversation_created", "member_added", "member_removed", + "user_online", "user_offline", "online_users", + "group_invitation", "conversation_renamed", "session_reset" + ]) + + if notificationTypes.contains(msgType) { + let data = msg["data"] as? [String: Any] ?? msg + switch msgType { + case "new_message": + notificationContinuation?.yield(.newMessage(data: data)) + case "messages_read": + notificationContinuation?.yield(.messagesRead(data: data)) + case "message_deleted": + notificationContinuation?.yield(.messageDeleted(data: data)) + case "conversation_created": + notificationContinuation?.yield(.conversationCreated(data: data)) + case "member_added": + notificationContinuation?.yield(.memberAdded(data: data)) + case "member_removed": + notificationContinuation?.yield(.memberRemoved(data: data)) + case "user_online": + if let uid = data["user_id"] as? String { + notificationContinuation?.yield(.userOnline(userId: uid)) + } + case "user_offline": + if let uid = data["user_id"] as? String { + notificationContinuation?.yield(.userOffline(userId: uid)) + } + case "online_users": + if let uids = data["user_ids"] as? [String] { + notificationContinuation?.yield(.onlineUsers(userIds: uids)) + } + case "group_invitation": + notificationContinuation?.yield(.groupInvitation(data: data)) + case "conversation_renamed": + notificationContinuation?.yield(.conversationRenamed(data: data)) + case "session_reset": + notificationContinuation?.yield(.sessionReset(data: data)) + default: + break + } + } else { + // Response to a pending request + if let requestId = msg["request_id"] as? String, + let cont = pendingRequests.removeValue(forKey: requestId) { + cont.resume(returning: msg) + } + } + } + + // MARK: - User Info Cache + + func getUserInfo(userId: String = "", userEmail: String = "") async -> User? { + if !userId.isEmpty, let cached = userCache[userId] { + return cached + } + var params: [String: Any] = [:] + if !userId.isEmpty { params["user_id"] = userId } + else if !userEmail.isEmpty { params["email"] = userEmail } + else { return nil } + + let resp = await sendAndReceive(type: "get_user_info", params: params) + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data") else { return nil } + + var ikData: Data? + if let ikB64 = data["identity_key"] as? String { + ikData = try? ProtocolHandler.decodeBinary(ikB64) + } + + let user = User( + id: data.string(for: "user_id") ?? "", + username: data.string(for: "username") ?? "", + email: data.string(for: "email") ?? "", + identityKey: ikData + ) + userCache[user.id] = user + return user + } + + // MARK: - Registration + + func register(username: String, password: String, email: String) async -> (success: Bool, message: String) { + self.username = username + self.email = email + var pwdBytes = Array(password.utf8) + defer { pwdBytes.withUnsafeMutableBytes { ptr in _ = memset(ptr.baseAddress!, 0, ptr.count) } } + + let pwdData = Data(pwdBytes) + + do { + // RSA keys + let (rsaPriv, rsaPub, err) = KeyStorage.loadRSAKeys(email: email, password: pwdData) + if let rsaPriv = rsaPriv, let rsaPub = rsaPub { + self.rsaPrivate = rsaPriv + self.rsaPublic = rsaPub + } else { + let (newPriv, newPub) = try RSACrypto.generateKeypair() + try KeyStorage.saveRSAKeys(email: email, privateKey: newPriv, publicKey: newPub, password: pwdData) + self.rsaPrivate = newPriv + self.rsaPublic = newPub + } + + // Ed25519 identity keys + let (edPriv, edPub) = KeyStorage.loadIdentityKeys(email: email, password: pwdData) + if let edPriv = edPriv, let edPub = edPub { + self.identityPrivate = edPriv + self.identityPublic = edPub + } else { + let (newPriv, newPub) = Ed25519Crypto.generateKeypair() + try KeyStorage.saveIdentityKeys(email: email, privateKey: newPriv, publicKey: newPub, password: pwdData) + self.identityPrivate = newPriv + self.identityPublic = newPub + } + + self.cacheKey = CryptoUtils.deriveSelfEncryptionKey(identityPrivateRaw: identityPrivate!.rawData) + self.localKey = CryptoUtils.deriveLocalStorageKey(identityPrivateRaw: identityPrivate!.rawData) + } catch { + return (false, "Key generation failed: \(error.localizedDescription)") + } + + // Send registration request + let pubPem = String(data: try! RSACrypto.serializePublicKey(rsaPublic!), encoding: .utf8)! + let ikB64 = ProtocolHandler.encodeBinary(Ed25519Crypto.serializePublic(identityPublic!)) + + let resp = await sendAndReceive(type: "register", params: [ + "username": username, + "public_key": pubPem, + "email": email, + "identity_key": ikB64, + ]) + + guard resp.string(for: "status") == "ok" else { + let msg = resp.dict(for: "data")?.string(for: "message") ?? "Registration failed" + return (false, msg) + } + + let data = resp.dict(for: "data") ?? [:] + if let code = data.string(for: "code") { + return (true, code) + } + return (true, data.string(for: "message") ?? "Check your email for the code.") + } + + func confirmRegistration(email: String, username: String, code: String) async -> (success: Bool, message: String) { + let resp = await sendAndReceive(type: "register_confirm", params: [ + "email": email, + "code": code, + ]) + + guard resp.string(for: "status") == "ok" else { + let msg = resp.dict(for: "data")?.string(for: "message") ?? "Confirmation failed" + return (false, msg) + } + + // Upload prekeys + await generateAndUploadPrekeys() + + let uid = resp.dict(for: "data")?.string(for: "user_id") ?? "" + return (true, "Registered as '\(username)' (ID: \(uid))") + } + + // MARK: - Prekeys + + private func generateAndUploadPrekeys(keepSPK: Bool = false) async { + guard let identityPrivate = identityPrivate else { return } + + do { + let spkData: [String: Any] + + if keepSPK, let spkPriv = spkPrivate, !spkId.isEmpty { + let spkPubBytes = X25519Crypto.serializePublic(spkPriv.publicKey) + let sig = try Ed25519Crypto.sign(identityPrivate, data: spkPubBytes) + spkData = [ + "id": spkId, + "public_key": ProtocolHandler.encodeBinary(spkPubBytes), + "signature": ProtocolHandler.encodeBinary(sig), + ] + } else { + // Save current as previous (grace period) + if let spkPriv = spkPrivate, !spkId.isEmpty { + prevSpkPrivate = spkPriv + prevSpkId = spkId + try? KeyStorage.savePrevSPK(email: email, privateKey: spkPriv, spkId: spkId) + } + + let spk = try X3DH.generateSignedPrekey(identityPrivate: identityPrivate) + self.spkPrivate = spk.privateKey + self.spkId = spk.id + try? KeyStorage.saveSPK(email: email, privateKey: spk.privateKey, spkId: spk.id) + + spkData = [ + "id": spk.id, + "public_key": ProtocolHandler.encodeBinary(X25519Crypto.serializePublic(spk.publicKey)), + "signature": ProtocolHandler.encodeBinary(spk.signature), + ] + } + + // Generate OPKs + let opks = X3DH.generateOneTimePrekeys(count: Constants.opkBatchSize) + for opk in opks { + opkPrivates[opk.id] = opk.privateKey + try? KeyStorage.saveOPKPrivate(email: email, opkId: opk.id, privateKey: opk.privateKey) + } + + let otpData = opks.map { opk -> [String: Any] in + [ + "id": opk.id, + "public_key": ProtocolHandler.encodeBinary(X25519Crypto.serializePublic(opk.publicKey)), + ] + } + + _ = await sendAndReceive(type: "upload_prekeys", params: [ + "signed_prekey": spkData, + "one_time_prekeys": otpData, + ]) + } catch { + // Log error but don't fail + print("Prekey generation error: \(error)") + } + } + + private func ensurePrekeys() async { + let resp = await sendAndReceive(type: "get_prekey_count") + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data") else { return } + + let count = data.int(for: "count") ?? 0 + let spkCreatedAt = data.string(for: "spk_created_at") ?? "" + + var needNewSPK = false + if !spkCreatedAt.isEmpty { + let formatter = ISO8601DateFormatter() + formatter.formatOptions = [.withInternetDateTime, .withFractionalSeconds] + if let created = formatter.date(from: spkCreatedAt) ?? ISO8601DateFormatter().date(from: spkCreatedAt) { + let ageDays = Calendar.current.dateComponents([.day], from: created, to: Date()).day ?? 0 + if ageDays >= Constants.spkRotationDays { + needNewSPK = true + } + } + } + + if count < Constants.opkReplenishThreshold || needNewSPK { + await generateAndUploadPrekeys() + } + } + + // MARK: - Login + + func login(email: String, password: String) async -> (success: Bool, message: String) { + self.email = email + var pwdBytes = Array(password.utf8) + defer { pwdBytes.withUnsafeMutableBytes { ptr in _ = memset(ptr.baseAddress!, 0, ptr.count) } } + let pwdData = Data(pwdBytes) + + // Load RSA keys + let (rsaPriv, rsaPub, err) = KeyStorage.loadRSAKeys(email: email, password: pwdData) + guard let rsaPriv = rsaPriv, let rsaPub = rsaPub else { + return (false, err ?? "No local keys found. Register first.") + } + self.rsaPrivate = rsaPriv + self.rsaPublic = rsaPub + + // Load identity keys + let (edPriv, edPub) = KeyStorage.loadIdentityKeys(email: email, password: pwdData) + if let edPriv = edPriv, let edPub = edPub { + self.identityPrivate = edPriv + self.identityPublic = edPub + self.cacheKey = CryptoUtils.deriveSelfEncryptionKey(identityPrivateRaw: edPriv.rawData) + self.localKey = CryptoUtils.deriveLocalStorageKey(identityPrivateRaw: edPriv.rawData) + } + + // Load SPK + let (spkP, spkI) = KeyStorage.loadSPK(email: email) + if let spkP = spkP { + self.spkPrivate = spkP + self.spkId = spkI ?? "" + } + + // Load previous SPK (grace period) + let (prevP, prevI) = KeyStorage.loadPrevSPK(email: email) + if let prevP = prevP { + self.prevSpkPrivate = prevP + self.prevSpkId = prevI ?? "" + } + + // Load device ID + self.deviceId = KeyStorage.loadDeviceId(email: email) + + // RSA challenge-response login + let startResp = await sendAndReceive(type: "login_start", params: ["email": email]) + guard startResp.string(for: "status") == "ok", + let startData = startResp.dict(for: "data"), + let challengeB64 = startData.string(for: "challenge") else { + let msg = startResp.dict(for: "data")?.string(for: "message") ?? "Login failed" + return (false, msg) + } + + let challengeData: Data + do { + challengeData = try ProtocolHandler.decodeBinary(challengeB64) + } catch { + return (false, "Invalid challenge data") + } + + let signature: Data + do { + signature = try RSACrypto.sign(rsaPriv, data: challengeData) + } catch { + return (false, "RSA signing failed: \(error.localizedDescription)") + } + + var finishParams: [String: Any] = [ + "email": email, + "signature": ProtocolHandler.encodeBinary(signature), + "client_version": Constants.version, + ] + if let deviceId = deviceId { + finishParams["device_id"] = deviceId + } + + let finishResp = await sendAndReceive(type: "login_finish", params: finishParams) + guard finishResp.string(for: "status") == "ok", + let finishData = finishResp.dict(for: "data") else { + let msg = finishResp.dict(for: "data")?.string(for: "message") ?? "Login failed" + loginRejected = true + return (false, msg) + } + + self.userId = finishData.string(for: "user_id") + self.username = finishData.string(for: "username") ?? "" + self.sessionToken = finishData.string(for: "session_token") + + // Save device ID from server + if let newDeviceId = finishData.string(for: "device_id") { + self.deviceId = newDeviceId + try? KeyStorage.saveDeviceId(email: email, deviceId: newDeviceId) + } + + // Start background listener + startBackgroundListener() + + // Handle online_users if included + if let onlineUserIds = finishData["online_user_ids"] as? [String] { + notificationContinuation?.yield(.onlineUsers(userIds: onlineUserIds)) + } + + // Ensure prekeys in background + Task { await ensurePrekeys() } + + return (true, "Logged in as \(username)") + } + + // MARK: - Reconnect + + func reconnect() async -> Bool { + guard rsaPrivate != nil else { return false } + + await disconnect() + + do { + try await connect() + } catch { + return false + } + + // RSA challenge-response with in-memory keys + let startResp = await sendAndReceive(type: "login_start", params: ["email": email]) + guard startResp.string(for: "status") == "ok", + let startData = startResp.dict(for: "data"), + let challengeB64 = startData.string(for: "challenge"), + let challengeData = try? ProtocolHandler.decodeBinary(challengeB64), + let signature = try? RSACrypto.sign(rsaPrivate!, data: challengeData) else { + return false + } + + var finishParams: [String: Any] = [ + "email": email, + "signature": ProtocolHandler.encodeBinary(signature), + "client_version": Constants.version, + ] + if let deviceId = deviceId { + finishParams["device_id"] = deviceId + } + + let finishResp = await sendAndReceive(type: "login_finish", params: finishParams) + guard finishResp.string(for: "status") == "ok" else { return false } + + startBackgroundListener() + return true + } + + // MARK: - Device Bundles + + private func getDeviceBundles(userId: String) async throws -> [DeviceBundle] { + // Check cache (5-min TTL) + if let cached = deviceBundleCache[userId], + Date().timeIntervalSince(cached.timestamp) < Constants.deviceBundleCacheTTL { + return cached.bundles + } + + let resp = await sendAndReceive(type: "get_key_bundle", params: ["user_id": userId]) + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data") else { + throw ChatError.operationFailed("Failed to get key bundle") + } + + var bundles: [DeviceBundle] = [] + + // Per-device bundles (new format) + if let deviceBundlesRaw = data["device_bundles"] as? [[String: Any]] { + for bundleDict in deviceBundlesRaw { + if let bundle = try? DeviceBundle.fromDict(bundleDict) { + bundles.append(bundle) + } + } + } + // Legacy single bundle + else if let ikHex = data["identity_key"] as? String { + let bundle = try DeviceBundle.fromDict(data) + bundles.append(bundle) + } + + deviceBundleCache[userId] = (Date(), bundles) + return bundles + } + + // MARK: - Session Management + + private func getOrCreateSession( + peerUserId: String, + peerDeviceId: String, + bundle: DeviceBundle + ) async throws -> DoubleRatchet { + let sessionKey = "\(peerUserId):\(peerDeviceId)" + + // Check memory + if let session = sessions[sessionKey] { + return session + } + + // Check disk + if let session = KeyStorage.loadSession( + email: email, + peerUserId: peerUserId, + localKey: localKey, + peerDeviceId: peerDeviceId + ) { + sessions[sessionKey] = session + return session + } + + // Create new via X3DH + let remoteIkEd = try Ed25519Crypto.loadPublic(bundle.identityKey) + let spkRemote = try X25519Crypto.loadPublic(bundle.spk) + var opkRemote: Curve25519.KeyAgreement.PublicKey? + if let opkData = bundle.opk { + opkRemote = try X25519Crypto.loadPublic(opkData) + } + + let (sharedSecret, ekPriv, ekPub) = try X3DH.initiate( + ikPrivateEd: identityPrivate!, + ikPublicRemoteEd: remoteIkEd, + spkRemote: spkRemote, + spkSignature: bundle.spkSignature, + opkRemote: opkRemote + ) + + let ratchet = try DoubleRatchet.initAlice(sharedSecret: sharedSecret, bobSpkPub: spkRemote) + + // Build X3DH header for first message + var x3dhHeader: [String: Any] = [ + "ik": Ed25519Crypto.serializePublic(identityPublic!).hexString, + "ek": X25519Crypto.serializePublic(ekPub).hexString, + "spk_id": bundle.spkId, + ] + if let opkId = bundle.opkId { + x3dhHeader["opk_id"] = opkId + } + ratchet.x3dhHeader = x3dhHeader + + sessions[sessionKey] = ratchet + try? KeyStorage.saveSession(email: email, peerUserId: peerUserId, ratchet: ratchet, localKey: localKey, peerDeviceId: peerDeviceId) + + return ratchet + } + + // MARK: - X3DH Response (Bob Side) + + private func processX3DHHeader( + senderId: String, + x3dhHeader: [String: Any], + senderDeviceId: String, + spkOverride: Curve25519.KeyAgreement.PrivateKey? = nil + ) throws -> DoubleRatchet { + guard let ikHex = x3dhHeader["ik"] as? String, + let ikData = Data(hexString: ikHex), + let ekHex = x3dhHeader["ek"] as? String, + let ekData = Data(hexString: ekHex), + let spkIdStr = x3dhHeader["spk_id"] as? String else { + throw CryptoError.x3dhFailed("Invalid X3DH header") + } + + let remoteIkEd = try Ed25519Crypto.loadPublic(ikData) + let ekRemote = try X25519Crypto.loadPublic(ekData) + + // Determine which SPK to use + let spkToUse: Curve25519.KeyAgreement.PrivateKey + if let override = spkOverride { + spkToUse = override + } else if spkIdStr == spkId, let spk = spkPrivate { + spkToUse = spk + } else if spkIdStr == prevSpkId, let prevSpk = prevSpkPrivate { + spkToUse = prevSpk + } else { + throw CryptoError.x3dhFailed("SPK \(spkIdStr) not found") + } + + // OPK + var opkPriv: Curve25519.KeyAgreement.PrivateKey? + if let opkIdStr = x3dhHeader["opk_id"] as? String { + opkPriv = opkPrivates[opkIdStr] ?? KeyStorage.loadOPKPrivate(email: email, opkId: opkIdStr) + if opkPriv != nil { + opkPrivates.removeValue(forKey: opkIdStr) + KeyStorage.deleteOPKPrivate(email: email, opkId: opkIdStr) + } + } + + let sharedSecret = try X3DH.respond( + ikPrivateEd: identityPrivate!, + spkPrivate: spkToUse, + ikRemoteEd: remoteIkEd, + ekRemote: ekRemote, + opkPrivate: opkPriv + ) + + let ratchet = DoubleRatchet.initBob( + sharedSecret: sharedSecret, + spkPair: (spkToUse, spkToUse.publicKey) + ) + + let sessionKey = "\(senderId):\(senderDeviceId)" + sessions[sessionKey] = ratchet + try? KeyStorage.saveSession(email: email, peerUserId: senderId, ratchet: ratchet, localKey: localKey, peerDeviceId: senderDeviceId) + + return ratchet + } + + // MARK: - Send Message + + func sendMessage(convId: String, text: String, members: [ConversationMember], replyTo: String? = nil) async -> (success: Bool, message: String) { + let isGroup = members.count > 2 + + if isGroup { + return await sendGroupMessage(convId: convId, text: text, members: members, replyTo: replyTo) + } else { + return await sendDM(convId: convId, text: text, members: members, replyTo: replyTo) + } + } + + // MARK: - Send DM + + private func sendDM(convId: String, text: String, members: [ConversationMember], replyTo: String? = nil) async -> (success: Bool, message: String) { + guard let identityPrivate = identityPrivate else { + return (false, "Identity key not loaded") + } + + let plaintext = Data(text.utf8) + var payload: [String: Any] = ["text": text] + if let replyTo = replyTo { + payload["reply_to"] = replyTo + } + + var recipients: [[String: Any]] = [] + + // Encrypt for each member's devices + for member in members where member.userId != userId { + do { + let bundles = try await getDeviceBundles(userId: member.userId) + for bundle in bundles { + let ratchet = try await getOrCreateSession( + peerUserId: member.userId, + peerDeviceId: bundle.deviceId, + bundle: bundle + ) + + // Consume X3DH header if present (first message only) + let x3dhHeader = ratchet.x3dhHeader + ratchet.x3dhHeader = nil + + let encrypted = try ratchet.encrypt(plaintext) + try? KeyStorage.saveSession(email: email, peerUserId: member.userId, ratchet: ratchet, localKey: localKey, peerDeviceId: bundle.deviceId) + + var recipientEntry: [String: Any] = [ + "user_id": member.userId, + "device_id": bundle.deviceId, + "ciphertext": ProtocolHandler.encodeBinary(encrypted.ciphertext), + "nonce": ProtocolHandler.encodeBinary(encrypted.nonce), + "ratchet_header": encrypted.header, + ] + if let x3dh = x3dhHeader { + recipientEntry["x3dh_header"] = x3dh + } + recipients.append(recipientEntry) + } + } catch { + return (false, "Encryption failed for \(member.username): \(error.localizedDescription)") + } + } + + // Self-encrypted copy + let selfKey = CryptoUtils.deriveSelfEncryptionKey(identityPrivateRaw: identityPrivate.rawData) + if let (_, nonce, ct, tag) = try? CryptoUtils.aesEncrypt(plaintext, key: selfKey) { + let selfCiphertext = ct + tag + let dummyHeader: [String: Any] = [ + "dh_pub": String(repeating: "00", count: 32), + "n": 0, + "pn": 0, + ] + recipients.append([ + "user_id": userId!, + "device_id": Constants.selfDeviceId, + "ciphertext": ProtocolHandler.encodeBinary(selfCiphertext), + "nonce": ProtocolHandler.encodeBinary(nonce), + "ratchet_header": dummyHeader, + ]) + } + + // Build ratchet header for message table (use first recipient's or dummy) + let ratchetHeader: [String: Any] + if let first = recipients.first { + ratchetHeader = first["ratchet_header"] as? [String: Any] ?? [:] + } else { + ratchetHeader = ["dh_pub": String(repeating: "00", count: 32), "n": 0, "pn": 0] + } + + var params: [String: Any] = [ + "conversation_id": convId, + "ratchet_header": ratchetHeader, + "recipients": recipients, + ] + if let replyTo = replyTo { + params["reply_to"] = replyTo + } + + let resp = await sendAndReceive(type: "send_message", params: params) + guard resp.string(for: "status") == "ok" else { + let msg = resp.dict(for: "data")?.string(for: "message") ?? "Send failed" + return (false, msg) + } + + // Cache the sent message + if let msgData = resp.dict(for: "data"), let messageId = msgData.string(for: "message_id") { + var cacheEntry = payload + cacheEntry["sender_id"] = userId + cacheEntry["sender_username"] = username + cacheEntry["created_at"] = ISO8601DateFormatter().string(from: Date()) + try? MessageCache.save(email: email, convId: convId, messages: [cacheEntry], cacheKey: cacheKey) + } + + return (true, "Message sent") + } + + // MARK: - Send Group Message + + private func sendGroupMessage(convId: String, text: String, members: [ConversationMember], replyTo: String? = nil) async -> (success: Bool, message: String) { + guard let identityPrivate = identityPrivate, let userId = userId, let deviceId = deviceId else { + return (false, "Not properly logged in") + } + + // Get or create sender key for this group + var senderKeyState = senderKeyStates[convId] + if senderKeyState == nil { + senderKeyState = KeyStorage.loadSenderKeyState(email: email, convId: convId, localKey: localKey) + } + + var needDistribute = false + if senderKeyState == nil { + senderKeyState = SenderKeyState() + needDistribute = true + } + + senderKeyStates[convId] = senderKeyState + + // Distribute sender key if new + if needDistribute { + await distributeSenderKey(convId: convId, members: members) + } + + // Encrypt with sender key + let plaintext = Data(text.utf8) + do { + let encrypted = try senderKeyState!.encrypt(plaintext) + try? KeyStorage.saveSenderKeyState(email: email, convId: convId, state: senderKeyState!, localKey: localKey) + + // Build recipients (same ciphertext for all) + var recipients: [[String: Any]] = [] + for member in members where member.userId != userId { + recipients.append([ + "user_id": member.userId, + "device_id": Constants.selfDeviceId, // group messages use sentinel + "ciphertext": ProtocolHandler.encodeBinary(encrypted.ciphertext), + "nonce": ProtocolHandler.encodeBinary(encrypted.nonce), + ]) + } + + // Self copy + let selfKey = CryptoUtils.deriveSelfEncryptionKey(identityPrivateRaw: identityPrivate.rawData) + if let (_, nonce, ct, tag) = try? CryptoUtils.aesEncrypt(plaintext, key: selfKey) { + recipients.append([ + "user_id": userId, + "device_id": Constants.selfDeviceId, + "ciphertext": ProtocolHandler.encodeBinary(ct + tag), + "nonce": ProtocolHandler.encodeBinary(nonce), + ]) + } + + let dummyHeader: [String: Any] = [ + "dh_pub": String(repeating: "00", count: 32), + "n": 0, + "pn": 0, + ] + + var params: [String: Any] = [ + "conversation_id": convId, + "ratchet_header": dummyHeader, + "recipients": recipients, + "sender_chain_id": ProtocolHandler.encodeBinary(encrypted.ciphertext.prefix(0)), // placeholder + ] + + // Include sender key metadata for group routing + params["sender_chain_id"] = encrypted.chainIdHex + params["sender_chain_n"] = encrypted.n + + if let replyTo = replyTo { + params["reply_to"] = replyTo + } + + let resp = await sendAndReceive(type: "send_message", params: params) + guard resp.string(for: "status") == "ok" else { + let msg = resp.dict(for: "data")?.string(for: "message") ?? "Send failed" + return (false, msg) + } + + return (true, "Message sent") + } catch { + return (false, "Encryption failed: \(error.localizedDescription)") + } + } + + // MARK: - Distribute Sender Key + + private func distributeSenderKey(convId: String, members: [ConversationMember]) async { + guard let senderKeyState = senderKeyStates[convId], + let userId = userId, + let deviceId = deviceId else { return } + + let exportedKey = senderKeyState.exportKey() + + for member in members where member.userId != userId { + do { + let bundles = try await getDeviceBundles(userId: member.userId) + for bundle in bundles { + let ratchet = try await getOrCreateSession( + peerUserId: member.userId, + peerDeviceId: bundle.deviceId, + bundle: bundle + ) + + let x3dhHeader = ratchet.x3dhHeader + ratchet.x3dhHeader = nil + + // Payload includes sender key + metadata + let controlPayload: [String: Any] = [ + "_sender_key": [ + "conv_id": convId, + "key": ProtocolHandler.encodeBinary(exportedKey), + "sender_device_id": deviceId, + ] + ] + let controlData = try JSONSerialization.data(withJSONObject: controlPayload) + let encrypted = try ratchet.encrypt(controlData) + try? KeyStorage.saveSession(email: email, peerUserId: member.userId, ratchet: ratchet, localKey: localKey, peerDeviceId: bundle.deviceId) + + var recipientEntry: [String: Any] = [ + "user_id": member.userId, + "device_id": bundle.deviceId, + "ciphertext": ProtocolHandler.encodeBinary(encrypted.ciphertext), + "nonce": ProtocolHandler.encodeBinary(encrypted.nonce), + "ratchet_header": encrypted.header, + ] + if let x3dh = x3dhHeader { + recipientEntry["x3dh_header"] = x3dh + } + + let dummyHeader: [String: Any] = [ + "dh_pub": String(repeating: "00", count: 32), + "n": 0, + "pn": 0, + ] + + _ = await sendAndReceive(type: "send_message", params: [ + "conversation_id": convId, + "ratchet_header": dummyHeader, + "recipients": [recipientEntry], + ]) + } + } catch { + print("Failed to distribute sender key to \(member.userId): \(error)") + } + } + } + + // MARK: - Decrypt + + func decryptDMRecipientData( + senderData: [String: Any], + senderId: String, + senderDeviceId: String + ) -> Data? { + guard let ctB64 = senderData["ciphertext"] as? String, + let nonceB64 = senderData["nonce"] as? String, + let ct = try? ProtocolHandler.decodeBinary(ctB64), + let nonce = try? ProtocolHandler.decodeBinary(nonceB64) else { + return nil + } + + // Self-encrypted copy + if senderDeviceId == Constants.selfDeviceId || senderId == userId { + if let cacheKey = cacheKey { + // ct = ciphertext + tag(16) + guard ct.count >= 16 else { return nil } + let ciphertext = ct.prefix(ct.count - 16) + let tag = ct.suffix(16) + return try? CryptoUtils.aesDecrypt(key: cacheKey, nonce: nonce, ciphertext: Data(ciphertext), tag: Data(tag)) + } + return nil + } + + // Regular DM decryption + let headerDict = senderData["ratchet_header"] as? [String: Any] + let x3dhHeader = senderData["x3dh_header"] as? [String: Any] + + let sessionKey = "\(senderId):\(senderDeviceId)" + var ratchet = sessions[sessionKey] + ?? KeyStorage.loadSession(email: email, peerUserId: senderId, localKey: localKey, peerDeviceId: senderDeviceId) + + // Handle X3DH header (new session) + if let x3dh = x3dhHeader { + do { + ratchet = try processX3DHHeader( + senderId: senderId, + x3dhHeader: x3dh, + senderDeviceId: senderDeviceId + ) + } catch { + // Try with previous SPK (grace period) + if let prevSpk = prevSpkPrivate { + ratchet = try? processX3DHHeader( + senderId: senderId, + x3dhHeader: x3dh, + senderDeviceId: senderDeviceId, + spkOverride: prevSpk + ) + } + if ratchet == nil { return nil } + } + } + + guard let ratchet = ratchet, let header = headerDict else { return nil } + + do { + let plaintext = try ratchet.decrypt(headerDict: header, ciphertext: ct, nonce: nonce) + sessions[sessionKey] = ratchet + try? KeyStorage.saveSession(email: email, peerUserId: senderId, ratchet: ratchet, localKey: localKey, peerDeviceId: senderDeviceId) + + // Check for sender key distribution (control message) + if let jsonObj = try? JSONSerialization.jsonObject(with: plaintext) as? [String: Any], + let senderKeyInfo = jsonObj["_sender_key"] as? [String: Any] { + handleSenderKeyDistribution(senderKeyInfo, senderId: senderId) + return nil // Control message + } + + return plaintext + } catch { + return nil + } + } + + private func handleSenderKeyDistribution(_ info: [String: Any], senderId: String) { + guard let convId = info["conv_id"] as? String, + let keyB64 = info["key"] as? String, + let keyData = try? ProtocolHandler.decodeBinary(keyB64) else { return } + + let senderDeviceId = info["sender_device_id"] as? String ?? Constants.selfDeviceId + + do { + let senderKey = try SenderKeyState.fromKey(keyData) + let stateKey = "\(convId):\(senderId):\(senderDeviceId)" + recvSenderKeys[stateKey] = senderKey + try? KeyStorage.saveRecvSenderKey( + email: email, + convId: convId, + senderId: senderId, + senderDeviceId: senderDeviceId, + state: senderKey, + localKey: localKey + ) + } catch { + print("Failed to import sender key: \(error)") + } + } + + func decryptNotification(_ data: [String: Any]) -> Message? { + guard let senderId = data.string(for: "sender_id"), + let conversationId = data.string(for: "conversation_id"), + let messageId = data.string(for: "message_id") else { + return nil + } + + let senderDeviceId = data.string(for: "sender_device_id") ?? Constants.selfDeviceId + + // Find our device's entry + var recipientData: [String: Any]? + if let deviceEntries = data["device_entries"] as? [[String: Any]] { + recipientData = deviceEntries.first(where: { + ($0["device_id"] as? String) == deviceId || ($0["device_id"] as? String) == Constants.selfDeviceId + }) + } + // Fallback: use data directly if it has ciphertext + if recipientData == nil, data["ciphertext"] != nil { + recipientData = data + } + + guard let recipientData = recipientData else { return nil } + + // Try DM decryption + if let plaintext = decryptDMRecipientData( + senderData: recipientData, + senderId: senderId, + senderDeviceId: senderDeviceId + ) { + let text = String(data: plaintext, encoding: .utf8) + + // Parse JSON payload + var messageText = text + var replyTo: String? + var file: FileInfo? + if let jsonObj = try? JSONSerialization.jsonObject(with: plaintext) as? [String: Any] { + messageText = jsonObj["text"] as? String + replyTo = jsonObj["reply_to"] as? String + if let fileDict = jsonObj["file"] as? [String: Any] { + file = FileInfo( + fileId: fileDict["file_id"] as? String ?? "", + aesKey: fileDict["aes_key"] as? String ?? "", + iv: fileDict["iv"] as? String ?? "", + filename: fileDict["filename"] as? String ?? "", + size: fileDict["size"] as? Int ?? 0, + mimeType: fileDict["mime_type"] as? String ?? "" + ) + } + } + + let createdAt = data.string(for: "created_at").flatMap { ISO8601DateFormatter().date(from: $0) } ?? Date() + let senderUsername = data.string(for: "sender_username") ?? userCache[senderId]?.username ?? "Unknown" + + return Message( + id: messageId, + conversationId: conversationId, + senderId: senderId, + senderUsername: senderUsername, + createdAt: createdAt, + text: messageText, + replyTo: replyTo, + imageFileId: data.string(for: "image_file_id"), + file: file, + isDeleted: false, + readBy: [] + ) + } + + return nil + } + + // MARK: - Conversations + + func listConversations() async -> [Conversation] { + let resp = await sendAndReceive(type: "list_conversations") + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data"), + let convList = data["conversations"] as? [[String: Any]] else { + return [] + } + + return convList.compactMap { dict -> Conversation? in + guard let id = dict.string(for: "id") else { return nil } + + let membersRaw = dict["members"] as? [[String: Any]] ?? [] + let members = membersRaw.compactMap { m -> ConversationMember? in + guard let uid = m.string(for: "user_id"), + let uname = m.string(for: "username"), + let uemail = m.string(for: "email") else { return nil } + return ConversationMember(userId: uid, username: uname, email: uemail) + } + + let unreadCount = dict.int(for: "unread_count") ?? 0 + + return Conversation( + id: id, + name: dict.string(for: "name"), + members: members, + createdBy: dict.string(for: "created_by"), + avatarFile: dict.string(for: "avatar_file"), + unreadCount: unreadCount, + isFavorite: false, + lastMessageTime: nil + ) + } + } + + func createConversation(emails: [String], name: String? = nil) async -> (convId: String?, message: String) { + var params: [String: Any] = ["emails": emails] + if let name = name { + params["name"] = name + } + + let resp = await sendAndReceive(type: "create_conversation", params: params) + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data"), + let convId = data.string(for: "conversation_id") else { + let msg = resp.dict(for: "data")?.string(for: "message") ?? "Failed to create conversation" + return (nil, msg) + } + + return (convId, "Conversation created") + } + + func findConversation(email: String) async -> String? { + let resp = await sendAndReceive(type: "find_conversation", params: ["email": email]) + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data") else { return nil } + return data.string(for: "conversation_id") + } + + // MARK: - Messages + + func getMessages(convId: String, limit: Int = 50, offset: Int = 0) async -> [Message] { + let resp = await sendAndReceive(type: "get_messages", params: [ + "conversation_id": convId, + "limit": limit, + "offset": offset, + ]) + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data"), + let messagesRaw = data["messages"] as? [[String: Any]] else { + return [] + } + + var messages: [Message] = [] + for msgDict in messagesRaw { + guard let msgId = msgDict.string(for: "id"), + let senderId = msgDict.string(for: "sender_id") else { continue } + + let senderDeviceId = msgDict.string(for: "sender_device_id") ?? Constants.selfDeviceId + let isDeleted = msgDict["deleted_at"] != nil && !(msgDict["deleted_at"] is NSNull) + + if isDeleted { + let createdAt = msgDict.string(for: "created_at").flatMap { ISO8601DateFormatter().date(from: $0) } ?? Date() + messages.append(Message( + id: msgId, conversationId: convId, senderId: senderId, + senderUsername: msgDict.string(for: "sender_username") ?? "", + createdAt: createdAt, text: nil, isDeleted: true, readBy: [] + )) + continue + } + + // Try to decrypt + if let plaintext = decryptDMRecipientData( + senderData: msgDict, + senderId: senderId, + senderDeviceId: senderDeviceId + ) { + let text = String(data: plaintext, encoding: .utf8) + var messageText = text + var replyTo: String? + var file: FileInfo? + + if let jsonObj = try? JSONSerialization.jsonObject(with: plaintext) as? [String: Any] { + messageText = jsonObj["text"] as? String + replyTo = jsonObj["reply_to"] as? String + if let fileDict = jsonObj["file"] as? [String: Any] { + file = FileInfo( + fileId: fileDict["file_id"] as? String ?? "", + aesKey: fileDict["aes_key"] as? String ?? "", + iv: fileDict["iv"] as? String ?? "", + filename: fileDict["filename"] as? String ?? "", + size: fileDict["size"] as? Int ?? 0, + mimeType: fileDict["mime_type"] as? String ?? "" + ) + } + } + + if messageText == nil && file == nil { continue } // Control message + + let createdAt = msgDict.string(for: "created_at").flatMap { ISO8601DateFormatter().date(from: $0) } ?? Date() + messages.append(Message( + id: msgId, conversationId: convId, senderId: senderId, + senderUsername: msgDict.string(for: "sender_username") ?? "", + createdAt: createdAt, text: messageText, replyTo: replyTo, + imageFileId: msgDict.string(for: "image_file_id"), file: file, + isDeleted: false, readBy: [] + )) + } + } + + return messages + } + + func markRead(convId: String, messageIds: [String]) async { + _ = await sendAndReceive(type: "mark_read", params: [ + "conversation_id": convId, + "message_ids": messageIds, + ]) + } + + func deleteMessage(messageId: String, convId: String) async -> Bool { + let resp = await sendAndReceive(type: "delete_message", params: [ + "message_id": messageId, + "conversation_id": convId, + ]) + return resp.string(for: "status") == "ok" + } + + // MARK: - Group Operations + + func addMember(convId: String, email: String) async -> (success: Bool, message: String) { + let resp = await sendAndReceive(type: "add_member", params: [ + "conversation_id": convId, + "email": email, + ]) + let msg = resp.dict(for: "data")?.string(for: "message") ?? "" + return (resp.string(for: "status") == "ok", msg) + } + + func removeMember(convId: String, userId: String) async -> (success: Bool, message: String) { + let resp = await sendAndReceive(type: "remove_member", params: [ + "conversation_id": convId, + "user_id": userId, + ]) + let msg = resp.dict(for: "data")?.string(for: "message") ?? "" + return (resp.string(for: "status") == "ok", msg) + } + + func leaveGroup(convId: String) async -> (success: Bool, message: String) { + let resp = await sendAndReceive(type: "leave_group", params: [ + "conversation_id": convId, + ]) + if resp.string(for: "status") == "ok" { + // Clean up local sender keys + senderKeyStates.removeValue(forKey: convId) + KeyStorage.deleteSenderKeyState(email: email, convId: convId) + KeyStorage.deleteRecvSenderKeys(email: email, convId: convId) + return (true, "Left group") + } + let msg = resp.dict(for: "data")?.string(for: "message") ?? "Failed" + return (false, msg) + } + + func renameConversation(convId: String, name: String) async -> (success: Bool, message: String) { + let resp = await sendAndReceive(type: "rename_conversation", params: [ + "conversation_id": convId, + "name": name, + ]) + let msg = resp.dict(for: "data")?.string(for: "message") ?? "" + return (resp.string(for: "status") == "ok", msg) + } + + func deleteConversation(convId: String) async -> (success: Bool, message: String) { + let resp = await sendAndReceive(type: "delete_conversation", params: [ + "conversation_id": convId, + ]) + if resp.string(for: "status") == "ok" { + senderKeyStates.removeValue(forKey: convId) + KeyStorage.deleteSenderKeyState(email: email, convId: convId) + KeyStorage.deleteRecvSenderKeys(email: email, convId: convId) + return (true, "Deleted") + } + let msg = resp.dict(for: "data")?.string(for: "message") ?? "Failed" + return (false, msg) + } + + // MARK: - Invitations + + func acceptInvitation(convId: String) async -> (success: Bool, message: String) { + let resp = await sendAndReceive(type: "accept_invitation", params: [ + "conversation_id": convId, + ]) + let msg = resp.dict(for: "data")?.string(for: "message") ?? "" + return (resp.string(for: "status") == "ok", msg) + } + + func declineInvitation(convId: String) async -> (success: Bool, message: String) { + let resp = await sendAndReceive(type: "decline_invitation", params: [ + "conversation_id": convId, + ]) + let msg = resp.dict(for: "data")?.string(for: "message") ?? "" + return (resp.string(for: "status") == "ok", msg) + } + + func listInvitations() async -> [Invitation] { + let resp = await sendAndReceive(type: "list_invitations") + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data"), + let invList = data["invitations"] as? [[String: Any]] else { + return [] + } + + return invList.compactMap { dict -> Invitation? in + guard let convId = dict.string(for: "conversation_id") else { return nil } + return Invitation( + id: dict.string(for: "id") ?? convId, + conversationId: convId, + conversationName: dict.string(for: "conversation_name") ?? "Group", + invitedBy: dict.string(for: "invited_by") ?? "", + invitedByUsername: dict.string(for: "invited_by_username") ?? "" + ) + } + } + + // MARK: - Profile + + func getProfile(userId: String? = nil) async -> UserProfile? { + var params: [String: Any] = [:] + if let userId = userId { + params["user_id"] = userId + } + let resp = await sendAndReceive(type: "get_profile", params: params) + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data") else { return nil } + + return UserProfile( + userId: data.string(for: "user_id") ?? userId ?? self.userId ?? "", + username: data.string(for: "username"), + email: data.string(for: "email"), + phone: data.string(for: "phone"), + phoneVisible: data.bool(for: "phone_visible") ?? false, + location: data.string(for: "location"), + locationVisible: data.bool(for: "location_visible") ?? false, + avatarFile: data.string(for: "avatar_file") + ) + } + + func updateProfile(phone: String? = nil, phoneVisible: Bool? = nil, + location: String? = nil, locationVisible: Bool? = nil) async -> Bool { + var params: [String: Any] = [:] + if let phone = phone { params["phone"] = phone } + if let phoneVisible = phoneVisible { params["phone_visible"] = phoneVisible } + if let location = location { params["location"] = location } + if let locationVisible = locationVisible { params["location_visible"] = locationVisible } + + let resp = await sendAndReceive(type: "update_profile", params: params) + return resp.string(for: "status") == "ok" + } + + func updateAvatar(imageData: Data) async -> Bool { + let resp = await sendAndReceive(type: "update_avatar", params: [ + "avatar_data": ProtocolHandler.encodeBinary(imageData), + ]) + return resp.string(for: "status") == "ok" + } + + func getAvatar(userId: String) async -> Data? { + let resp = await sendAndReceive(type: "get_avatar", params: ["user_id": userId]) + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data"), + let avatarB64 = data.string(for: "avatar_data"), + let avatarData = try? ProtocolHandler.decodeBinary(avatarB64) else { + return nil + } + return avatarData + } + + // MARK: - Group Avatar + + func updateGroupAvatar(convId: String, imageData: Data) async -> Bool { + let resp = await sendAndReceive(type: "update_group_avatar", params: [ + "conversation_id": convId, + "avatar_data": ProtocolHandler.encodeBinary(imageData), + ]) + return resp.string(for: "status") == "ok" + } + + func getGroupAvatar(convId: String) async -> Data? { + let resp = await sendAndReceive(type: "get_group_avatar", params: ["conversation_id": convId]) + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data"), + let avatarB64 = data.string(for: "avatar_data"), + let avatarData = try? ProtocolHandler.decodeBinary(avatarB64) else { + return nil + } + return avatarData + } + + // MARK: - File Sharing + + func sendFile(convId: String, fileData: Data, filename: String, mimeType: String, + members: [ConversationMember], replyTo: String? = nil) async -> (success: Bool, message: String) { + // Encrypt file with AES-GCM + guard let (aesKey, nonce, ct, tag) = try? CryptoUtils.aesEncrypt(fileData) else { + return (false, "File encryption failed") + } + + let encryptedData = ct + tag + let fileType = mimeType.hasPrefix("image/") ? "image" : "file" + + // Start upload + let startResp = await sendAndReceive(type: "upload_image_start", params: [ + "conversation_id": convId, + "file_size": encryptedData.count, + "file_type": fileType, + ]) + guard startResp.string(for: "status") == "ok", + let startData = startResp.dict(for: "data"), + let fileId = startData.string(for: "file_id") else { + let msg = startResp.dict(for: "data")?.string(for: "message") ?? "Upload start failed" + return (false, msg) + } + + // Upload chunks + var offset = 0 + while offset < encryptedData.count { + let end = min(offset + Constants.imageChunkSize, encryptedData.count) + let chunk = encryptedData[offset.. Data? { + var allData = Data() + var offset = 0 + + while true { + let resp = await sendAndReceive(type: "download_image", params: [ + "file_id": fileId, + "offset": offset, + ]) + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data"), + let chunkB64 = data.string(for: "chunk_data"), + let chunk = try? ProtocolHandler.decodeBinary(chunkB64) else { + break + } + + if chunk.isEmpty { break } + allData.append(chunk) + offset += chunk.count + + if data.bool(for: "complete") == true { break } + } + + guard !allData.isEmpty else { return nil } + + // Decrypt: allData = ciphertext + tag(16) + guard allData.count >= 16 else { return nil } + let ct = allData.prefix(allData.count - 16) + let tag = allData.suffix(16) + return try? CryptoUtils.aesDecrypt(key: aesKey, nonce: iv, ciphertext: Data(ct), tag: Data(tag)) + } + + // MARK: - Devices + + func listDevices() async -> [[String: Any]] { + let resp = await sendAndReceive(type: "list_devices") + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data") else { return [] } + return data["devices"] as? [[String: Any]] ?? [] + } + + func removeDevice(deviceIdToRemove: String) async -> Bool { + let resp = await sendAndReceive(type: "remove_device", params: [ + "device_id": deviceIdToRemove, + ]) + return resp.string(for: "status") == "ok" + } + + // MARK: - Session Reset + + func resetSession(peerUserId: String, peerDeviceId: String? = nil) async { + if let peerDeviceId = peerDeviceId { + let sessionKey = "\(peerUserId):\(peerDeviceId)" + sessions.removeValue(forKey: sessionKey) + KeyStorage.deleteSession(email: email, peerUserId: peerUserId, peerDeviceId: peerDeviceId) + } else { + // Delete all sessions for this user + for key in sessions.keys where key.hasPrefix(peerUserId) { + sessions.removeValue(forKey: key) + } + KeyStorage.deleteSession(email: email, peerUserId: peerUserId) + } + + _ = await sendAndReceive(type: "session_reset", params: [ + "peer_user_id": peerUserId, + "peer_device_id": peerDeviceId ?? "", + ]) + } + + func handleSessionResetNotification(fromUserId: String, fromDeviceId: String?) { + if let deviceId = fromDeviceId { + let sessionKey = "\(fromUserId):\(deviceId)" + sessions.removeValue(forKey: sessionKey) + KeyStorage.deleteSession(email: email, peerUserId: fromUserId, peerDeviceId: deviceId) + } else { + for key in sessions.keys where key.hasPrefix(fromUserId) { + sessions.removeValue(forKey: key) + } + KeyStorage.deleteSession(email: email, peerUserId: fromUserId) + } + } + + // MARK: - Search + + func searchMessages(convId: String, query: String) -> [[String: Any]] { + MessageCache.search(email: email, convId: convId, query: query, cacheKey: cacheKey) + } +} diff --git a/ios_client/EncryptedChat/Core/KeyStorage.swift b/ios_client/EncryptedChat/Core/KeyStorage.swift new file mode 100644 index 0000000..e310090 --- /dev/null +++ b/ios_client/EncryptedChat/Core/KeyStorage.swift @@ -0,0 +1,397 @@ +import Foundation +import CryptoKit + +/// Local file storage for keys, sessions, and sender keys. +/// Matches Python: chat_core.py key storage functions. +/// +/// Base directory: Application Support / EncryptedChat / {email} +/// Same file names as Python client for cross-platform compatibility. +enum KeyStorage { + + // MARK: - Base Directory + + /// Get or create the key storage directory for a user + static func getKeyDir(email: String) throws -> URL { + let appSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first! + let dir = appSupport.appendingPathComponent("EncryptedChat").appendingPathComponent(email) + try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true) + // iOS file protection + try (dir as NSURL).setResourceValue(URLFileProtection.complete, forKey: .fileProtectionKey) + return dir + } + + // MARK: - RSA Keys + + /// Save RSA keypair + static func saveRSAKeys(email: String, privateKey: SecKey, publicKey: SecKey, password: Data? = nil) throws { + let dir = try getKeyDir(email: email) + let privData = try RSACrypto.serializePrivateKey(privateKey, password: password) + let pubData = try RSACrypto.serializePublicKey(publicKey) + try writeProtected(privData, to: dir.appendingPathComponent("private.pem")) + try pubData.write(to: dir.appendingPathComponent("public.pem")) + } + + /// Load RSA keypair. Returns (private, public, error). + static func loadRSAKeys(email: String, password: Data? = nil) -> (SecKey?, SecKey?, String?) { + guard let dir = try? getKeyDir(email: email) else { + return (nil, nil, "Cannot access key directory") + } + let privPath = dir.appendingPathComponent("private.pem") + let pubPath = dir.appendingPathComponent("public.pem") + + guard FileManager.default.fileExists(atPath: privPath.path) else { + return (nil, nil, "No local keys found.") + } + + guard let privData = try? Data(contentsOf: privPath), + let pubData = try? Data(contentsOf: pubPath) else { + return (nil, nil, "Cannot read key files.") + } + + do { + let privateKey = try RSACrypto.loadPrivateKey(privData, password: password) + let publicKey = try RSACrypto.loadPublicKey(pubData) + return (privateKey, publicKey, nil) + } catch { + // Try without password (unencrypted) + do { + let privateKey = try RSACrypto.loadPrivateKey(privData, password: nil) + let publicKey = try RSACrypto.loadPublicKey(pubData) + // Re-save with password if provided + if let password = password { + try? saveRSAKeys(email: email, privateKey: privateKey, publicKey: publicKey, password: password) + } + return (privateKey, publicKey, nil) + } catch { + return (nil, nil, "Invalid or missing password.") + } + } + } + + // MARK: - Identity Keys (Ed25519) + + static func saveIdentityKeys( + email: String, + privateKey: Curve25519.Signing.PrivateKey, + publicKey: Curve25519.Signing.PublicKey, + password: Data? = nil + ) throws { + let dir = try getKeyDir(email: email) + let privData = try Ed25519Crypto.serializePrivate(privateKey, password: password) + let pubData = Ed25519Crypto.serializePublic(publicKey) + try writeProtected(privData, to: dir.appendingPathComponent("identity_private.bin")) + try pubData.write(to: dir.appendingPathComponent("identity_public.bin")) + } + + static func loadIdentityKeys( + email: String, + password: Data? = nil + ) -> (Curve25519.Signing.PrivateKey?, Curve25519.Signing.PublicKey?) { + guard let dir = try? getKeyDir(email: email) else { return (nil, nil) } + let privPath = dir.appendingPathComponent("identity_private.bin") + let pubPath = dir.appendingPathComponent("identity_public.bin") + + guard FileManager.default.fileExists(atPath: privPath.path), + let privData = try? Data(contentsOf: privPath), + let pubData = try? Data(contentsOf: pubPath) else { + return (nil, nil) + } + + do { + let priv = try Ed25519Crypto.loadPrivate(privData, password: password) + let pub = try Ed25519Crypto.loadPublic(pubData) + return (priv, pub) + } catch { + return (nil, nil) + } + } + + // MARK: - Signed Pre-Key + + static func saveSPK(email: String, privateKey: Curve25519.KeyAgreement.PrivateKey, spkId: String) throws { + let dir = try getKeyDir(email: email) + try writeProtected(X25519Crypto.serializePrivate(privateKey), to: dir.appendingPathComponent("spk_private.bin")) + try spkId.write(to: dir.appendingPathComponent("spk_id.txt"), atomically: true, encoding: .utf8) + } + + static func loadSPK(email: String) -> (Curve25519.KeyAgreement.PrivateKey?, String?) { + guard let dir = try? getKeyDir(email: email) else { return (nil, nil) } + let privPath = dir.appendingPathComponent("spk_private.bin") + let idPath = dir.appendingPathComponent("spk_id.txt") + + guard FileManager.default.fileExists(atPath: privPath.path), + let privData = try? Data(contentsOf: privPath), + let priv = try? X25519Crypto.loadPrivate(privData) else { + return (nil, nil) + } + let spkId = (try? String(contentsOf: idPath, encoding: .utf8))?.trimmed ?? "" + return (priv, spkId) + } + + // MARK: - Previous SPK (Grace Period) + + static func savePrevSPK(email: String, privateKey: Curve25519.KeyAgreement.PrivateKey, spkId: String) throws { + let dir = try getKeyDir(email: email) + try writeProtected(X25519Crypto.serializePrivate(privateKey), to: dir.appendingPathComponent("prev_spk_private.bin")) + try spkId.write(to: dir.appendingPathComponent("prev_spk_id.txt"), atomically: true, encoding: .utf8) + } + + static func loadPrevSPK(email: String) -> (Curve25519.KeyAgreement.PrivateKey?, String?) { + guard let dir = try? getKeyDir(email: email) else { return (nil, nil) } + let privPath = dir.appendingPathComponent("prev_spk_private.bin") + let idPath = dir.appendingPathComponent("prev_spk_id.txt") + + guard FileManager.default.fileExists(atPath: privPath.path), + let privData = try? Data(contentsOf: privPath), + let priv = try? X25519Crypto.loadPrivate(privData) else { + return (nil, nil) + } + let spkId = (try? String(contentsOf: idPath, encoding: .utf8))?.trimmed ?? "" + return (priv, spkId) + } + + // MARK: - One-Time Pre-Keys + + static func saveOPKPrivate(email: String, opkId: String, privateKey: Curve25519.KeyAgreement.PrivateKey) throws { + let dir = try getKeyDir(email: email).appendingPathComponent("opk_private") + try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true) + try writeProtected(X25519Crypto.serializePrivate(privateKey), to: dir.appendingPathComponent("\(opkId).bin")) + } + + static func loadOPKPrivate(email: String, opkId: String) -> Curve25519.KeyAgreement.PrivateKey? { + guard let dir = try? getKeyDir(email: email) else { return nil } + let path = dir.appendingPathComponent("opk_private").appendingPathComponent("\(opkId).bin") + guard let data = try? Data(contentsOf: path) else { return nil } + return try? X25519Crypto.loadPrivate(data) + } + + static func deleteOPKPrivate(email: String, opkId: String) { + guard let dir = try? getKeyDir(email: email) else { return } + let path = dir.appendingPathComponent("opk_private").appendingPathComponent("\(opkId).bin") + try? FileManager.default.removeItem(at: path) + } + + // MARK: - Device ID + + static func saveDeviceId(email: String, deviceId: String) throws { + let dir = try getKeyDir(email: email) + try writeProtected(Data(deviceId.utf8), to: dir.appendingPathComponent("device_id.txt")) + } + + static func loadDeviceId(email: String) -> String? { + guard let dir = try? getKeyDir(email: email) else { return nil } + let path = dir.appendingPathComponent("device_id.txt") + guard let data = try? Data(contentsOf: path) else { return nil } + let str = String(data: data, encoding: .utf8)?.trimmed + return (str?.isEmpty ?? true) ? nil : str + } + + // MARK: - Sessions (Double Ratchet) + + static func saveSession( + email: String, + peerUserId: String, + ratchet: DoubleRatchet, + localKey: Data? = nil, + peerDeviceId: String? = nil + ) throws { + let dir = try getKeyDir(email: email).appendingPathComponent("sessions") + try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true) + + let filename: String + if let deviceId = peerDeviceId { + filename = "\(peerUserId)_\(deviceId).bin" + } else { + filename = "\(peerUserId).bin" + } + + var data = try ratchet.exportState() + if let localKey = localKey { + data = try CryptoUtils.encryptLocal(data, key: localKey) + } + try writeProtected(data, to: dir.appendingPathComponent(filename)) + } + + static func loadSession( + email: String, + peerUserId: String, + localKey: Data? = nil, + peerDeviceId: String? = nil + ) -> DoubleRatchet? { + guard let dir = try? getKeyDir(email: email) else { return nil } + let sessionsDir = dir.appendingPathComponent("sessions") + + let filename: String + if let deviceId = peerDeviceId { + filename = "\(peerUserId)_\(deviceId).bin" + } else { + filename = "\(peerUserId).bin" + } + + let path = sessionsDir.appendingPathComponent(filename) + return loadSessionFile(path, localKey: localKey) + } + + static func deleteSession(email: String, peerUserId: String, peerDeviceId: String? = nil) { + guard let dir = try? getKeyDir(email: email) else { return } + let sessionsDir = dir.appendingPathComponent("sessions") + + if let deviceId = peerDeviceId { + let path = sessionsDir.appendingPathComponent("\(peerUserId)_\(deviceId).bin") + try? FileManager.default.removeItem(at: path) + } else { + // Delete all sessions for this user + if let files = try? FileManager.default.contentsOfDirectory(atPath: sessionsDir.path) { + for file in files where file.hasPrefix(peerUserId) { + try? FileManager.default.removeItem(at: sessionsDir.appendingPathComponent(file)) + } + } + } + } + + private static func loadSessionFile(_ path: URL, localKey: Data?) -> DoubleRatchet? { + guard let raw = try? Data(contentsOf: path) else { return nil } + + if let localKey = localKey { + // Try encrypted first + if let decrypted = try? CryptoUtils.decryptLocal(raw, key: localKey) { + return try? DoubleRatchet.importState(decrypted) + } + // Fallback: plaintext (transparent migration) + if let ratchet = try? DoubleRatchet.importState(raw) { + // Re-save encrypted + try? writeProtected(CryptoUtils.encryptLocal(try ratchet.exportState(), key: localKey), to: path) + return ratchet + } + return nil + } + + return try? DoubleRatchet.importState(raw) + } + + // MARK: - Sender Keys + + static func saveSenderKeyState( + email: String, + convId: String, + state: SenderKeyState, + localKey: Data? = nil + ) throws { + let dir = try getKeyDir(email: email).appendingPathComponent("sender_keys") + try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true) + + var data = state.exportState() + if let localKey = localKey { + data = try CryptoUtils.encryptLocal(data, key: localKey) + } + try writeProtected(data, to: dir.appendingPathComponent("\(convId).bin")) + } + + static func loadSenderKeyState( + email: String, + convId: String, + localKey: Data? = nil + ) -> SenderKeyState? { + guard let dir = try? getKeyDir(email: email) else { return nil } + let path = dir.appendingPathComponent("sender_keys").appendingPathComponent("\(convId).bin") + guard let raw = try? Data(contentsOf: path) else { return nil } + + if let localKey = localKey { + if let decrypted = try? CryptoUtils.decryptLocal(raw, key: localKey) { + return try? SenderKeyState.importState(decrypted) + } + // Plaintext fallback + if let state = try? SenderKeyState.importState(raw) { + try? writeProtected(CryptoUtils.encryptLocal(state.exportState(), key: localKey), to: path) + return state + } + return nil + } + + return try? SenderKeyState.importState(raw) + } + + static func deleteSenderKeyState(email: String, convId: String) { + guard let dir = try? getKeyDir(email: email) else { return } + let path = dir.appendingPathComponent("sender_keys").appendingPathComponent("\(convId).bin") + try? FileManager.default.removeItem(at: path) + } + + // MARK: - Received Sender Keys + + static func saveRecvSenderKey( + email: String, + convId: String, + senderId: String, + senderDeviceId: String, + state: SenderKeyState, + localKey: Data? = nil + ) throws { + let dir = try getKeyDir(email: email).appendingPathComponent("sender_keys_recv") + try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true) + + var data = state.exportState() + if let localKey = localKey { + data = try CryptoUtils.encryptLocal(data, key: localKey) + } + try writeProtected(data, to: dir.appendingPathComponent("\(convId)_\(senderId)_\(senderDeviceId).bin")) + } + + static func loadRecvSenderKey( + email: String, + convId: String, + senderId: String, + senderDeviceId: String, + localKey: Data? = nil + ) -> SenderKeyState? { + guard let dir = try? getKeyDir(email: email) else { return nil } + let path = dir.appendingPathComponent("sender_keys_recv").appendingPathComponent("\(convId)_\(senderId)_\(senderDeviceId).bin") + guard let raw = try? Data(contentsOf: path) else { return nil } + + if let localKey = localKey { + if let decrypted = try? CryptoUtils.decryptLocal(raw, key: localKey) { + return try? SenderKeyState.importState(decrypted) + } + if let state = try? SenderKeyState.importState(raw) { + try? writeProtected(CryptoUtils.encryptLocal(state.exportState(), key: localKey), to: path) + return state + } + return nil + } + + return try? SenderKeyState.importState(raw) + } + + static func deleteRecvSenderKeys(email: String, convId: String) { + guard let dir = try? getKeyDir(email: email) else { return } + let recvDir = dir.appendingPathComponent("sender_keys_recv") + guard let files = try? FileManager.default.contentsOfDirectory(atPath: recvDir.path) else { return } + for file in files where file.hasPrefix(convId) { + try? FileManager.default.removeItem(at: recvDir.appendingPathComponent(file)) + } + } + + // MARK: - Favorites + + static func saveFavorites(email: String, favorites: Set) throws { + let dir = try getKeyDir(email: email) + let data = try JSONSerialization.data(withJSONObject: Array(favorites)) + try data.write(to: dir.appendingPathComponent("favorites.json")) + } + + static func loadFavorites(email: String) -> Set { + guard let dir = try? getKeyDir(email: email) else { return [] } + let path = dir.appendingPathComponent("favorites.json") + guard let data = try? Data(contentsOf: path), + let array = try? JSONSerialization.jsonObject(with: data) as? [String] else { + return [] + } + return Set(array) + } + + // MARK: - Helpers + + private static func writeProtected(_ data: Data, to url: URL) throws { + try data.write(to: url, options: .completeFileProtection) + } +} diff --git a/ios_client/EncryptedChat/Core/MessageCache.swift b/ios_client/EncryptedChat/Core/MessageCache.swift new file mode 100644 index 0000000..f88d9df --- /dev/null +++ b/ios_client/EncryptedChat/Core/MessageCache.swift @@ -0,0 +1,65 @@ +import Foundation + +/// Encrypted local message cache. +/// Matches Python: chat_core.py message cache (message_cache/{conv_id}.json) +enum MessageCache { + + /// Save messages for a conversation (encrypted with local storage key) + static func save(email: String, convId: String, messages: [[String: Any]], cacheKey: Data?) throws { + let dir = try KeyStorage.getKeyDir(email: email).appendingPathComponent("message_cache") + try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true) + + let jsonData = try JSONSerialization.data(withJSONObject: messages) + + let dataToWrite: Data + if let cacheKey = cacheKey { + dataToWrite = try CryptoUtils.encryptLocal(jsonData, key: cacheKey) + } else { + dataToWrite = jsonData + } + + try dataToWrite.write(to: dir.appendingPathComponent("\(convId).json"), options: .completeFileProtection) + } + + /// Load messages for a conversation + static func load(email: String, convId: String, cacheKey: Data?) -> [[String: Any]]? { + guard let dir = try? KeyStorage.getKeyDir(email: email) else { return nil } + let path = dir.appendingPathComponent("message_cache").appendingPathComponent("\(convId).json") + guard let raw = try? Data(contentsOf: path) else { return nil } + + let jsonData: Data + if let cacheKey = cacheKey { + if let decrypted = try? CryptoUtils.decryptLocal(raw, key: cacheKey) { + jsonData = decrypted + } else { + // Plaintext fallback (migration) + jsonData = raw + } + } else { + jsonData = raw + } + + return try? JSONSerialization.jsonObject(with: jsonData) as? [[String: Any]] + } + + /// Search messages in a conversation + static func search(email: String, convId: String, query: String, cacheKey: Data?) -> [[String: Any]] { + guard let messages = load(email: email, convId: convId, cacheKey: cacheKey) else { + return [] + } + let lowerQuery = query.lowercased() + return messages.filter { msg in + if let text = msg["text"] as? String, text.lowercased().contains(lowerQuery) { + return true + } + return false + } + } + + /// Delete cache for a conversation + static func delete(email: String, convId: String) { + guard let dir = try? KeyStorage.getKeyDir(email: email) else { return } + let path = dir.appendingPathComponent("message_cache").appendingPathComponent("\(convId).json") + try? FileManager.default.removeItem(at: path) + } +} diff --git a/ios_client/EncryptedChat/Crypto/CryptoErrors.swift b/ios_client/EncryptedChat/Crypto/CryptoErrors.swift new file mode 100644 index 0000000..bf45320 --- /dev/null +++ b/ios_client/EncryptedChat/Crypto/CryptoErrors.swift @@ -0,0 +1,95 @@ +import Foundation + +enum CryptoError: Error, LocalizedError { + case invalidBase64 + case invalidHex + case invalidKeyData(String) + case invalidSignature + case signatureVerificationFailed + case encryptionFailed(String) + case decryptionFailed(String) + case invalidECP1Format + case pbkdf2Failed + case rsaKeyGenerationFailed + case rsaOperationFailed(String) + case x3dhFailed(String) + case ratchetError(String) + case senderKeyError(String) + case maxSkipExceeded + case duplicateMessage + case invalidHeader(String) + case stateImportFailed(String) + case keyConversionFailed(String) + + var errorDescription: String? { + switch self { + case .invalidBase64: return "Invalid base64 encoding" + case .invalidHex: return "Invalid hex encoding" + case .invalidKeyData(let msg): return "Invalid key data: \(msg)" + case .invalidSignature: return "Invalid signature format" + case .signatureVerificationFailed: return "Signature verification failed" + case .encryptionFailed(let msg): return "Encryption failed: \(msg)" + case .decryptionFailed(let msg): return "Decryption failed: \(msg)" + case .invalidECP1Format: return "Invalid ECP1 key format" + case .pbkdf2Failed: return "PBKDF2 key derivation failed" + case .rsaKeyGenerationFailed: return "RSA key generation failed" + case .rsaOperationFailed(let msg): return "RSA operation failed: \(msg)" + case .x3dhFailed(let msg): return "X3DH failed: \(msg)" + case .ratchetError(let msg): return "Ratchet error: \(msg)" + case .senderKeyError(let msg): return "Sender key error: \(msg)" + case .maxSkipExceeded: return "Maximum message skip exceeded" + case .duplicateMessage: return "Duplicate message detected" + case .invalidHeader(let msg): return "Invalid header: \(msg)" + case .stateImportFailed(let msg): return "State import failed: \(msg)" + case .keyConversionFailed(let msg): return "Key conversion failed: \(msg)" + } + } +} + +enum NetworkError: Error, LocalizedError { + case notConnected + case connectionFailed(String) + case timeout + case serverError(String) + case protocolError(String) + case messageTooLarge + case invalidResponse(String) + case authenticationFailed(String) + case alreadyConnected + + var errorDescription: String? { + switch self { + case .notConnected: return "Not connected to server" + case .connectionFailed(let msg): return "Connection failed: \(msg)" + case .timeout: return "Request timed out" + case .serverError(let msg): return "Server error: \(msg)" + case .protocolError(let msg): return "Protocol error: \(msg)" + case .messageTooLarge: return "Message exceeds maximum size" + case .invalidResponse(let msg): return "Invalid response: \(msg)" + case .authenticationFailed(let msg): return "Authentication failed: \(msg)" + case .alreadyConnected: return "Already connected" + } + } +} + +enum ChatError: Error, LocalizedError { + case notLoggedIn + case conversationNotFound + case membershipRequired + case permissionDenied(String) + case operationFailed(String) + case fileError(String) + case invalidData(String) + + var errorDescription: String? { + switch self { + case .notLoggedIn: return "Not logged in" + case .conversationNotFound: return "Conversation not found" + case .membershipRequired: return "Must be a member of this conversation" + case .permissionDenied(let msg): return "Permission denied: \(msg)" + case .operationFailed(let msg): return "Operation failed: \(msg)" + case .fileError(let msg): return "File error: \(msg)" + case .invalidData(let msg): return "Invalid data: \(msg)" + } + } +} diff --git a/ios_client/EncryptedChat/Crypto/CryptoUtils.swift b/ios_client/EncryptedChat/Crypto/CryptoUtils.swift new file mode 100644 index 0000000..60e5abb --- /dev/null +++ b/ios_client/EncryptedChat/Crypto/CryptoUtils.swift @@ -0,0 +1,196 @@ +import Foundation +import CryptoKit + +/// Core cryptographic utilities: AES-GCM, HKDF, KDF helpers +enum CryptoUtils { + + // MARK: - AES-256-GCM + + /// Encrypt with AES-256-GCM. Returns (key, nonce, ciphertext, tag) — all as Data. + /// If key is nil, generates a random 256-bit key. + /// Matches Python: aes_encrypt(plaintext, key=None) + static func aesEncrypt(_ plaintext: Data, key: Data? = nil) throws -> (key: Data, nonce: Data, ciphertext: Data, tag: Data) { + let keyData = key ?? Data.randomBytes(32) + let symmetricKey = SymmetricKey(data: keyData) + let nonceData = Data.randomBytes(12) + let gcmNonce = try AES.GCM.Nonce(data: nonceData) + + let sealedBox = try AES.GCM.seal(plaintext, using: symmetricKey, nonce: gcmNonce) + + return ( + key: keyData, + nonce: nonceData, + ciphertext: Data(sealedBox.ciphertext), + tag: Data(sealedBox.tag) + ) + } + + /// Decrypt with AES-256-GCM. + /// Matches Python: aes_decrypt(key, nonce, ciphertext, tag) + static func aesDecrypt(key: Data, nonce: Data, ciphertext: Data, tag: Data) throws -> Data { + let symmetricKey = SymmetricKey(data: key) + let gcmNonce = try AES.GCM.Nonce(data: nonce) + + let sealedBox = try AES.GCM.SealedBox( + nonce: gcmNonce, + ciphertext: ciphertext, + tag: tag + ) + + do { + return try AES.GCM.open(sealedBox, using: symmetricKey) + } catch { + throw CryptoError.decryptionFailed("AES-GCM decryption failed") + } + } + + /// Encrypt with AES-256-GCM using AAD. Returns ciphertext with tag appended. + /// Used by Double Ratchet and Sender Keys. + static func aesGcmEncrypt(_ plaintext: Data, key: Data, nonce: Data, aad: Data) throws -> Data { + let symmetricKey = SymmetricKey(data: key) + let gcmNonce = try AES.GCM.Nonce(data: nonce) + + let sealedBox = try AES.GCM.seal( + plaintext, + using: symmetricKey, + nonce: gcmNonce, + authenticating: aad + ) + + // Return ciphertext + tag concatenated (matches Python AESGCM.encrypt) + return Data(sealedBox.ciphertext) + Data(sealedBox.tag) + } + + /// Decrypt AES-256-GCM with AAD. Input ciphertext has tag appended (last 16 bytes). + static func aesGcmDecrypt(_ ctWithTag: Data, key: Data, nonce: Data, aad: Data) throws -> Data { + guard ctWithTag.count >= 16 else { + throw CryptoError.decryptionFailed("Ciphertext too short") + } + + let ct = ctWithTag.prefix(ctWithTag.count - 16) + let tag = ctWithTag.suffix(16) + + let symmetricKey = SymmetricKey(data: key) + let gcmNonce = try AES.GCM.Nonce(data: nonce) + + let sealedBox = try AES.GCM.SealedBox( + nonce: gcmNonce, + ciphertext: ct, + tag: tag + ) + + do { + return try AES.GCM.open(sealedBox, using: symmetricKey, authenticating: aad) + } catch { + throw CryptoError.decryptionFailed("AES-GCM decryption with AAD failed") + } + } + + // MARK: - HKDF + + /// HKDF-SHA256 key derivation. + /// Matches Python: hkdf_derive(input_key, salt, info, length=32) + static func hkdfDerive(inputKey: Data, salt: Data, info: Data, length: Int = 32) -> Data { + let symmetricKey = SymmetricKey(data: inputKey) + let derived = HKDF.deriveKey( + inputKeyMaterial: symmetricKey, + salt: salt, + info: info, + outputByteCount: length + ) + return derived.withUnsafeBytes { Data($0) } + } + + // MARK: - KDF for Double Ratchet + + /// Root key KDF. Returns (newRootKey, chainKey). + /// HKDF with rootKey as salt and DH output as input. Derives 64 bytes, split in half. + /// Matches Python: kdf_rk(root_key, dh_output) + static func kdfRK(rootKey: Data, dhOutput: Data) -> (newRootKey: Data, chainKey: Data) { + let derived = hkdfDerive( + inputKey: dhOutput, + salt: rootKey, + info: Data(Constants.rootKeyInfo.utf8), + length: 64 + ) + return (derived.prefix(32), Data(derived.suffix(32))) + } + + /// Chain key KDF. Returns (newChainKey, messageKey). + /// HMAC-SHA256: messageKey = HMAC(chainKey, 0x01), newChainKey = HMAC(chainKey, 0x02) + /// Matches Python: kdf_ck(chain_key) + static func kdfCK(chainKey: Data) -> (newChainKey: Data, messageKey: Data) { + let symmetricKey = SymmetricKey(data: chainKey) + let messageKey = Data(HMAC.authenticationCode(for: Data([0x01]), using: symmetricKey)) + let newChainKey = Data(HMAC.authenticationCode(for: Data([0x02]), using: symmetricKey)) + return (newChainKey, messageKey) + } + + // MARK: - Self-Encryption Key + + /// Derive static AES-256 key from identity key for self-encrypted message copies. + /// Matches Python: derive_self_encryption_key(identity_private) + static func deriveSelfEncryptionKey(identityPrivateRaw: Data) -> Data { + hkdfDerive( + inputKey: identityPrivateRaw, + salt: Data(Constants.selfEncryptionSalt.utf8), + info: Data(Constants.selfEncryptionInfo.utf8), + length: 32 + ) + } + + // MARK: - Local Storage Key + + /// Derive AES-256 key for encrypting local session/sender key files. + /// Matches Python: derive_local_storage_key(identity_private) + static func deriveLocalStorageKey(identityPrivateRaw: Data) -> Data { + hkdfDerive( + inputKey: identityPrivateRaw, + salt: Data(Constants.localStorageSalt.utf8), + info: Data(Constants.localStorageInfo.utf8), + length: 32 + ) + } + + // MARK: - Local File Encryption + + /// Encrypt data for local storage. Format: nonce(12) + tag(16) + ciphertext + /// Matches Python: _encrypt_local(data, key) + static func encryptLocal(_ data: Data, key: Data) throws -> Data { + let symmetricKey = SymmetricKey(data: key) + let sealedBox = try AES.GCM.seal(data, using: symmetricKey) + + var result = Data() + result.append(Data(sealedBox.nonce)) // 12 bytes + result.append(Data(sealedBox.tag)) // 16 bytes + result.append(Data(sealedBox.ciphertext)) // N bytes + return result + } + + /// Decrypt locally stored data. Format: nonce(12) + tag(16) + ciphertext + /// Matches Python: _decrypt_local(raw, key) + static func decryptLocal(_ raw: Data, key: Data) throws -> Data { + guard raw.count >= 28 else { // 12 + 16 minimum + throw CryptoError.decryptionFailed("Local encrypted data too short") + } + + let nonce = raw[0..<12] + let tag = raw[12..<28] + let ct = raw[28...] + + let symmetricKey = SymmetricKey(data: key) + let gcmNonce = try AES.GCM.Nonce(data: nonce) + + let sealedBox = try AES.GCM.SealedBox( + nonce: gcmNonce, + ciphertext: ct, + tag: tag + ) + + do { + return try AES.GCM.open(sealedBox, using: symmetricKey) + } catch { + throw CryptoError.decryptionFailed("Local storage decryption failed") + } + } +} diff --git a/ios_client/EncryptedChat/Crypto/DoubleRatchet.swift b/ios_client/EncryptedChat/Crypto/DoubleRatchet.swift new file mode 100644 index 0000000..29c8add --- /dev/null +++ b/ios_client/EncryptedChat/Crypto/DoubleRatchet.swift @@ -0,0 +1,371 @@ +import Foundation +import CryptoKit + +/// Ratchet header sent with each message +struct RatchetHeader { + let dhPub: Data // sender's current ratchet public key (32 bytes) + let n: Int // message number in current sending chain + let pn: Int // number of messages in previous sending chain + + /// Serialize header to JSON bytes for use as AAD. + /// Matches Python: RatchetHeader.serialize() + func serialize() -> Data { + let dict: [String: Any] = [ + "dh_pub": dhPub.hexString, + "n": n, + "pn": pn, + ] + // Must produce consistent JSON — sorted keys for determinism + return try! JSONSerialization.data(withJSONObject: dict, options: .sortedKeys) + } + + /// Convert to dictionary for protocol. + /// Matches Python: RatchetHeader.to_dict() + func toDict() -> [String: Any] { + [ + "dh_pub": dhPub.hexString, + "n": n, + "pn": pn, + ] + } + + /// Parse from dictionary. + /// Matches Python: RatchetHeader.from_dict(d) + static func fromDict(_ d: [String: Any]) throws -> RatchetHeader { + guard let dhPubHex = d["dh_pub"] as? String, + let dhPub = Data(hexString: dhPubHex), + let n = d["n"] as? Int, + let pn = d["pn"] as? Int else { + throw CryptoError.invalidHeader("Missing or invalid header fields") + } + return RatchetHeader(dhPub: dhPub, n: n, pn: pn) + } +} + +/// Signal Double Ratchet implementation. +/// Matches Python: DoubleRatchet class in crypto_utils.py +class DoubleRatchet { + + private(set) var dhPair: (privateKey: Curve25519.KeyAgreement.PrivateKey, + publicKey: Curve25519.KeyAgreement.PublicKey)? + private(set) var dhRemote: Curve25519.KeyAgreement.PublicKey? + private(set) var rootKey: Data = Data() + private(set) var sendChainKey: Data? + private(set) var recvChainKey: Data? + private(set) var sendN: Int = 0 + private(set) var recvN: Int = 0 + private(set) var prevSendN: Int = 0 + // Skipped message keys: "dh_pub_hex:n" → message_key + private(set) var skipped: [String: Data] = [:] + + /// Attached X3DH header — set when creating a new session, consumed on first send. + /// Matches Python: ratchet._x3dh_header + var x3dhHeader: [String: Any]? + + init() {} + + // MARK: - Initialization + + /// Initialize as initiator (Alice) after X3DH. + /// Matches Python: DoubleRatchet.init_alice(shared_secret, bob_spk_pub) + static func initAlice(sharedSecret: Data, bobSpkPub: Curve25519.KeyAgreement.PublicKey) throws -> DoubleRatchet { + let ratchet = DoubleRatchet() + let (priv, pub) = X25519Crypto.generateKeypair() + ratchet.dhPair = (priv, pub) + ratchet.dhRemote = bobSpkPub + + // Perform DH ratchet to derive send chain + let dhOutput = try X25519Crypto.dh(priv, bobSpkPub) + let (newRK, sendCK) = CryptoUtils.kdfRK(rootKey: sharedSecret, dhOutput: dhOutput) + ratchet.rootKey = newRK + ratchet.sendChainKey = sendCK + ratchet.recvChainKey = nil + ratchet.sendN = 0 + ratchet.recvN = 0 + ratchet.prevSendN = 0 + return ratchet + } + + /// Initialize as responder (Bob) after X3DH. + /// Matches Python: DoubleRatchet.init_bob(shared_secret, spk_pair) + static func initBob( + sharedSecret: Data, + spkPair: (privateKey: Curve25519.KeyAgreement.PrivateKey, publicKey: Curve25519.KeyAgreement.PublicKey) + ) -> DoubleRatchet { + let ratchet = DoubleRatchet() + ratchet.dhPair = spkPair + ratchet.rootKey = sharedSecret + ratchet.sendChainKey = nil + ratchet.recvChainKey = nil + ratchet.sendN = 0 + ratchet.recvN = 0 + ratchet.prevSendN = 0 + return ratchet + } + + // MARK: - Encrypt + + /// Encrypt a message. + /// Returns (header dict, ciphertext with tag, nonce). + /// Matches Python: DoubleRatchet.encrypt(plaintext) + func encrypt(_ plaintext: Data) throws -> (header: [String: Any], ciphertext: Data, nonce: Data) { + guard sendChainKey != nil else { + throw CryptoError.ratchetError("Send chain not initialized") + } + guard let dhPair = dhPair else { + throw CryptoError.ratchetError("DH pair not set") + } + + let (newCK, messageKey) = CryptoUtils.kdfCK(chainKey: sendChainKey!) + sendChainKey = newCK + + let header = RatchetHeader( + dhPub: X25519Crypto.serializePublic(dhPair.publicKey), + n: sendN, + pn: prevSendN + ) + + let nonce = Data.randomBytes(12) + let aad = header.serialize() + let ctWithTag = try CryptoUtils.aesGcmEncrypt(plaintext, key: messageKey, nonce: nonce, aad: aad) + + sendN += 1 + + return (header.toDict(), ctWithTag, nonce) + } + + // MARK: - Decrypt + + /// Decrypt a message. Handles DH ratchet step if new dh_pub. + /// State is snapshotted before modification and restored on failure (M9 fix). + /// Matches Python: DoubleRatchet.decrypt(header_dict, ciphertext, nonce) + func decrypt(headerDict: [String: Any], ciphertext: Data, nonce: Data) throws -> Data { + let header = try RatchetHeader.fromDict(headerDict) + let remoteDhPubBytes = header.dhPub + + // Check if this is from a skipped message + let skipKey = "\(remoteDhPubBytes.hexString):\(header.n)" + if let mk = skipped[skipKey] { + skipped.removeValue(forKey: skipKey) + let aad = header.serialize() + do { + return try CryptoUtils.aesGcmDecrypt(ciphertext, key: mk, nonce: nonce, aad: aad) + } catch { + // Restore skipped key on failure + skipped[skipKey] = mk + throw error + } + } + + // Snapshot state before modifications + let snap = snapshot() + + do { + let remoteDhPub = try X25519Crypto.loadPublic(remoteDhPubBytes) + let currentRemoteBytes: Data? = dhRemote.map { X25519Crypto.serializePublic($0) } + + if currentRemoteBytes == nil || remoteDhPubBytes != currentRemoteBytes { + // New DH ratchet step + try skipMessages(until: header.pn) + try dhRatchet(remoteDhPub: remoteDhPub) + } + + try skipMessages(until: header.n) + + // Derive message key from receive chain + guard recvChainKey != nil else { + throw CryptoError.ratchetError("Receive chain key is nil") + } + let (newCK, mk) = CryptoUtils.kdfCK(chainKey: recvChainKey!) + recvChainKey = newCK + recvN += 1 + + let aad = header.serialize() + return try CryptoUtils.aesGcmDecrypt(ciphertext, key: mk, nonce: nonce, aad: aad) + } catch { + restore(snap) + throw error + } + } + + // MARK: - State Snapshot/Restore (M9) + + private struct Snapshot { + let dhPairPriv: Data? + let dhPairPub: Data? + let dhRemote: Data? + let rootKey: Data + let sendChainKey: Data? + let recvChainKey: Data? + let sendN: Int + let recvN: Int + let prevSendN: Int + let skipped: [String: Data] + } + + private func snapshot() -> Snapshot { + Snapshot( + dhPairPriv: dhPair.map { X25519Crypto.serializePrivate($0.privateKey) }, + dhPairPub: dhPair.map { X25519Crypto.serializePublic($0.publicKey) }, + dhRemote: dhRemote.map { X25519Crypto.serializePublic($0) }, + rootKey: rootKey, + sendChainKey: sendChainKey, + recvChainKey: recvChainKey, + sendN: sendN, + recvN: recvN, + prevSendN: prevSendN, + skipped: skipped + ) + } + + private func restore(_ snap: Snapshot) { + if let privData = snap.dhPairPriv, let pubData = snap.dhPairPub, + let priv = try? X25519Crypto.loadPrivate(privData), + let pub = try? X25519Crypto.loadPublic(pubData) { + dhPair = (priv, pub) + } else { + dhPair = nil + } + if let remoteData = snap.dhRemote, let remote = try? X25519Crypto.loadPublic(remoteData) { + dhRemote = remote + } else { + dhRemote = nil + } + rootKey = snap.rootKey + sendChainKey = snap.sendChainKey + recvChainKey = snap.recvChainKey + sendN = snap.sendN + recvN = snap.recvN + prevSendN = snap.prevSendN + skipped = snap.skipped + } + + // MARK: - Internal Ratchet Operations + + private func skipMessages(until: Int) throws { + guard recvChainKey != nil else { return } + if until - recvN > Constants.maxSkip { + throw CryptoError.maxSkipExceeded + } + while recvN < until { + let (newCK, mk) = CryptoUtils.kdfCK(chainKey: recvChainKey!) + recvChainKey = newCK + let remoteHex = dhRemote.map { X25519Crypto.serializePublic($0).hexString } ?? "" + skipped["\(remoteHex):\(recvN)"] = mk + recvN += 1 + } + } + + private func dhRatchet(remoteDhPub: Curve25519.KeyAgreement.PublicKey) throws { + prevSendN = sendN + sendN = 0 + recvN = 0 + dhRemote = remoteDhPub + + // Derive new receive chain key + guard let dhPair = dhPair else { + throw CryptoError.ratchetError("DH pair not set") + } + let dhOutput1 = try X25519Crypto.dh(dhPair.privateKey, remoteDhPub) + let (newRK1, recvCK) = CryptoUtils.kdfRK(rootKey: rootKey, dhOutput: dhOutput1) + rootKey = newRK1 + recvChainKey = recvCK + + // Generate new DH pair and derive new send chain key + let (newPriv, newPub) = X25519Crypto.generateKeypair() + self.dhPair = (newPriv, newPub) + let dhOutput2 = try X25519Crypto.dh(newPriv, remoteDhPub) + let (newRK2, sendCK) = CryptoUtils.kdfRK(rootKey: rootKey, dhOutput: dhOutput2) + rootKey = newRK2 + sendChainKey = sendCK + } + + // MARK: - State Export/Import + + /// Serialize full ratchet state for persistent storage. + /// Produces JSON matching Python's DoubleRatchet.export_state() exactly. + func exportState() throws -> Data { + var state: [String: Any] = [:] + + if let pair = dhPair { + state["dh_priv"] = X25519Crypto.serializePrivate(pair.privateKey).hexString + state["dh_pub"] = X25519Crypto.serializePublic(pair.publicKey).hexString + } else { + state["dh_priv"] = NSNull() + state["dh_pub"] = NSNull() + } + + if let remote = dhRemote { + state["dh_remote"] = X25519Crypto.serializePublic(remote).hexString + } else { + state["dh_remote"] = NSNull() + } + + state["root_key"] = rootKey.hexString + state["send_ck"] = sendChainKey?.hexString ?? NSNull() + state["recv_ck"] = recvChainKey?.hexString ?? NSNull() + state["send_n"] = sendN + state["recv_n"] = recvN + state["prev_send_n"] = prevSendN + + // Skipped keys: Python format is "dh_pub_hex:n" -> message_key_hex + var skippedDict: [String: String] = [:] + for (key, value) in skipped { + skippedDict[key] = value.hexString + } + state["skipped"] = skippedDict + + return try JSONSerialization.data(withJSONObject: state) + } + + /// Deserialize ratchet state. + /// Matches Python: DoubleRatchet.import_state(data) + static func importState(_ data: Data) throws -> DoubleRatchet { + guard let state = try JSONSerialization.jsonObject(with: data) as? [String: Any] else { + throw CryptoError.stateImportFailed("Invalid JSON") + } + + let r = DoubleRatchet() + + if let dhPrivHex = state["dh_priv"] as? String, + let dhPubHex = state["dh_pub"] as? String, + let privData = Data(hexString: dhPrivHex), + let pubData = Data(hexString: dhPubHex) { + let priv = try X25519Crypto.loadPrivate(privData) + let pub = try X25519Crypto.loadPublic(pubData) + r.dhPair = (priv, pub) + } + + if let dhRemoteHex = state["dh_remote"] as? String, + let remoteData = Data(hexString: dhRemoteHex) { + r.dhRemote = try X25519Crypto.loadPublic(remoteData) + } + + guard let rootKeyHex = state["root_key"] as? String, + let rootKey = Data(hexString: rootKeyHex) else { + throw CryptoError.stateImportFailed("Missing root_key") + } + r.rootKey = rootKey + + if let sendCKHex = state["send_ck"] as? String, let ck = Data(hexString: sendCKHex) { + r.sendChainKey = ck + } + if let recvCKHex = state["recv_ck"] as? String, let ck = Data(hexString: recvCKHex) { + r.recvChainKey = ck + } + + r.sendN = state["send_n"] as? Int ?? 0 + r.recvN = state["recv_n"] as? Int ?? 0 + r.prevSendN = state["prev_send_n"] as? Int ?? 0 + + if let skippedDict = state["skipped"] as? [String: String] { + for (key, valueHex) in skippedDict { + if let value = Data(hexString: valueHex) { + r.skipped[key] = value + } + } + } + + return r + } +} diff --git a/ios_client/EncryptedChat/Crypto/Ed25519Crypto.swift b/ios_client/EncryptedChat/Crypto/Ed25519Crypto.swift new file mode 100644 index 0000000..a71ad96 --- /dev/null +++ b/ios_client/EncryptedChat/Crypto/Ed25519Crypto.swift @@ -0,0 +1,73 @@ +import Foundation +import CryptoKit + +/// Ed25519 signing operations — Identity Key management +enum Ed25519Crypto { + + // MARK: - Key Generation + + /// Generate Ed25519 keypair + static func generateKeypair() -> (privateKey: Curve25519.Signing.PrivateKey, publicKey: Curve25519.Signing.PublicKey) { + let privateKey = Curve25519.Signing.PrivateKey() + return (privateKey, privateKey.publicKey) + } + + // MARK: - Serialization + + /// Serialize Ed25519 private key. With password: raw 32B → ECP1. Without: raw 32B. + /// Matches Python: serialize_ed25519_private(key, password=None) + static func serializePrivate(_ key: Curve25519.Signing.PrivateKey, password: Data? = nil) throws -> Data { + let raw = key.rawData // 32 bytes + if let password = password { + return try KeyEncryption.encrypt(raw, password: password) + } + return raw + } + + /// Serialize Ed25519 public key to 32 raw bytes. + /// Matches Python: serialize_ed25519_public(key) + static func serializePublic(_ key: Curve25519.Signing.PublicKey) -> Data { + key.rawData // 32 bytes + } + + // MARK: - Loading + + /// Load Ed25519 private key. Auto-detects ECP1 / raw 32B. + /// Matches Python: load_ed25519_private(data, password=None) + static func loadPrivate(_ data: Data, password: Data? = nil) throws -> Curve25519.Signing.PrivateKey { + if KeyEncryption.isECP1Format(data) { + guard let pwd = password else { + throw CryptoError.invalidKeyData("ECP1 key requires password") + } + let raw = try KeyEncryption.decrypt(data, password: pwd) + return try Curve25519.Signing.PrivateKey(rawRepresentation: raw) + } + if data.count == 32 { + return try Curve25519.Signing.PrivateKey(rawRepresentation: data) + } + throw CryptoError.invalidKeyData("Cannot parse Ed25519 private key (\(data.count) bytes)") + } + + /// Load Ed25519 public key from 32 raw bytes. + /// Matches Python: load_ed25519_public(data) + static func loadPublic(_ data: Data) throws -> Curve25519.Signing.PublicKey { + guard data.count == 32 else { + throw CryptoError.invalidKeyData("Ed25519 public key must be 32 bytes, got \(data.count)") + } + return try Curve25519.Signing.PublicKey(rawRepresentation: data) + } + + // MARK: - Sign / Verify + + /// Sign data with Ed25519. Returns 64-byte signature. + /// Matches Python: ed25519_sign(private_key, data) + static func sign(_ privateKey: Curve25519.Signing.PrivateKey, data: Data) throws -> Data { + Data(try privateKey.signature(for: data)) + } + + /// Verify Ed25519 signature. + /// Matches Python: ed25519_verify(public_key, signature, data) + static func verify(_ publicKey: Curve25519.Signing.PublicKey, signature: Data, data: Data) -> Bool { + publicKey.isValidSignature(signature, for: data) + } +} diff --git a/ios_client/EncryptedChat/Crypto/FieldArithmetic.swift b/ios_client/EncryptedChat/Crypto/FieldArithmetic.swift new file mode 100644 index 0000000..bf844ff --- /dev/null +++ b/ios_client/EncryptedChat/Crypto/FieldArithmetic.swift @@ -0,0 +1,231 @@ +import Foundation + +/// Pure Swift GF(2^255-19) arithmetic for Ed25519 → X25519 public key conversion. +/// +/// The conversion formula is: u = (1 + y) / (1 - y) mod p +/// where p = 2^255 - 19, and y is the Ed25519 public key's y-coordinate. +/// +/// Uses 4-limb UInt64 representation (little-endian). +enum FieldArithmetic { + + // p = 2^255 - 19 + static let p: [UInt64] = [ + 0xFFFF_FFFF_FFFF_FFED, // limb 0 (least significant) + 0xFFFF_FFFF_FFFF_FFFF, // limb 1 + 0xFFFF_FFFF_FFFF_FFFF, // limb 2 + 0x7FFF_FFFF_FFFF_FFFF, // limb 3 (most significant, 2^63 - 1 accounting for -19) + ] + + /// Load a 256-bit little-endian byte array into 4 UInt64 limbs + static func load(_ bytes: Data) -> [UInt64] { + precondition(bytes.count == 32) + var limbs = [UInt64](repeating: 0, count: 4) + for i in 0..<4 { + var val: UInt64 = 0 + for j in 0..<8 { + val |= UInt64(bytes[i * 8 + j]) << (j * 8) + } + limbs[i] = val + } + return limbs + } + + /// Store 4 UInt64 limbs as 32 little-endian bytes + static func store(_ limbs: [UInt64]) -> Data { + var bytes = Data(count: 32) + for i in 0..<4 { + for j in 0..<8 { + bytes[i * 8 + j] = UInt8((limbs[i] >> (j * 8)) & 0xFF) + } + } + return bytes + } + + /// a + b mod p + static func add(_ a: [UInt64], _ b: [UInt64]) -> [UInt64] { + var result = [UInt64](repeating: 0, count: 4) + var carry: UInt64 = 0 + for i in 0..<4 { + let (sum1, c1) = a[i].addingReportingOverflow(b[i]) + let (sum2, c2) = sum1.addingReportingOverflow(carry) + result[i] = sum2 + carry = (c1 ? 1 : 0) + (c2 ? 1 : 0) + } + // Reduce mod p + return reduceOnce(result, carry: carry) + } + + /// a - b mod p + static func sub(_ a: [UInt64], _ b: [UInt64]) -> [UInt64] { + var result = [UInt64](repeating: 0, count: 4) + var borrow: UInt64 = 0 + for i in 0..<4 { + let (diff1, b1) = a[i].subtractingReportingOverflow(b[i]) + let (diff2, b2) = diff1.subtractingReportingOverflow(borrow) + result[i] = diff2 + borrow = (b1 ? 1 : 0) + (b2 ? 1 : 0) + } + if borrow > 0 { + // Add p back + var c: UInt64 = 0 + for i in 0..<4 { + let (s1, c1) = result[i].addingReportingOverflow(p[i]) + let (s2, c2) = s1.addingReportingOverflow(c) + result[i] = s2 + c = (c1 ? 1 : 0) + (c2 ? 1 : 0) + } + } + return result + } + + /// Multiply two 256-bit numbers mod p using schoolbook multiplication + static func mul(_ a: [UInt64], _ b: [UInt64]) -> [UInt64] { + // Full 512-bit product in 8 limbs + var product = [UInt64](repeating: 0, count: 8) + + for i in 0..<4 { + var carry: UInt64 = 0 + for j in 0..<4 { + let (hi, lo) = a[i].multipliedFullWidth(by: b[j]) + let (sum1, c1) = product[i + j].addingReportingOverflow(lo) + let (sum2, c2) = sum1.addingReportingOverflow(carry) + product[i + j] = sum2 + carry = hi + (c1 ? 1 : 0) + (c2 ? 1 : 0) + } + product[i + 4] = carry + } + + // Reduce mod p using Barrett-like reduction + // Since p = 2^255 - 19, for a 512-bit number we can use: + // x mod p = (x_low + x_high * 2^256) mod p + // Since 2^255 ≡ 19 (mod p), 2^256 ≡ 38 (mod p) + return reduceFull(product) + } + + /// Reduce 512-bit product mod p using 2^256 ≡ 38 (mod p) + private static func reduceFull(_ product: [UInt64]) -> [UInt64] { + // Split: low = product[0..3], high = product[4..7] + // result = low + high * 38 + var result = [UInt64](repeating: 0, count: 5) + + // Start with low part + for i in 0..<4 { + result[i] = product[i] + } + + // Add high * 38 + var carry: UInt64 = 0 + for i in 0..<4 { + let (hi, lo) = product[i + 4].multipliedFullWidth(by: 38) + let (sum1, c1) = result[i].addingReportingOverflow(lo) + let (sum2, c2) = sum1.addingReportingOverflow(carry) + result[i] = sum2 + carry = hi + (c1 ? 1 : 0) + (c2 ? 1 : 0) + } + result[4] = carry + + // The result might still be >= p, so reduce once more + // result[4] * 2^256 ≡ result[4] * 38 (mod p) + var extra: UInt64 = result[4] + result[4] = 0 + if extra > 0 { + let (hi, lo) = extra.multipliedFullWidth(by: 38) + let (sum1, c1) = result[0].addingReportingOverflow(lo) + result[0] = sum1 + var c = hi + (c1 ? 1 : 0) + for i in 1..<4 { + let (s, cf) = result[i].addingReportingOverflow(c) + result[i] = s + c = cf ? 1 : 0 + } + // One more round if carry + if c > 0 { + let (s, _) = result[0].addingReportingOverflow(c * 38) + result[0] = s + } + } + + var out = Array(result[0..<4]) + // Final reduction: if >= p, subtract p + out = reduceOnce(out, carry: 0) + return out + } + + /// If the number >= p, subtract p + private static func reduceOnce(_ val: [UInt64], carry: UInt64) -> [UInt64] { + if carry > 0 || isGreaterOrEqual(val, p) { + var result = [UInt64](repeating: 0, count: 4) + var borrow: UInt64 = 0 + for i in 0..<4 { + let (diff1, b1) = val[i].subtractingReportingOverflow(p[i]) + let (diff2, b2) = diff1.subtractingReportingOverflow(borrow) + result[i] = diff2 + borrow = (b1 ? 1 : 0) + (b2 ? 1 : 0) + } + // If borrow after subtracting p, the original was fine (shouldn't happen with carry) + if borrow > 0 && carry == 0 { + return val + } + return result + } + return val + } + + /// Compare a >= b + private static func isGreaterOrEqual(_ a: [UInt64], _ b: [UInt64]) -> Bool { + for i in stride(from: 3, through: 0, by: -1) { + if a[i] > b[i] { return true } + if a[i] < b[i] { return false } + } + return true // equal + } + + /// Modular inverse using Fermat's little theorem: a^(-1) = a^(p-2) mod p + static func inverse(_ a: [UInt64]) -> [UInt64] { + // p - 2 = 2^255 - 21 + let pMinus2 = sub(p, [2, 0, 0, 0]) + return power(a, pMinus2) + } + + /// Modular exponentiation using square-and-multiply + static func power(_ base: [UInt64], _ exp: [UInt64]) -> [UInt64] { + var result: [UInt64] = [1, 0, 0, 0] // 1 + var b = base + + for i in 0..<4 { + var limb = exp[i] + let bits = (i == 3) ? 63 : 64 // top limb has 63 bits for p-2 + for _ in 0..>= 1 + } + } + return result + } + + // MARK: - Ed25519 → X25519 Public Key Conversion + + /// Convert Ed25519 public key (32 bytes) to X25519 public key (32 bytes). + /// Formula: u = (1 + y) * inverse(1 - y) mod p + static func ed25519PublicToX25519(_ ed25519Pub: Data) -> Data { + precondition(ed25519Pub.count == 32) + + // Ed25519 public key is the y-coordinate with sign bit in the top bit of byte 31 + var keyBytes = ed25519Pub + // Clear the sign bit + keyBytes[31] &= 0x7F + + let y = load(keyBytes) + let one: [UInt64] = [1, 0, 0, 0] + + let onePlusY = add(one, y) + let oneMinusY = sub(one, y) + let inv = inverse(oneMinusY) + let u = mul(onePlusY, inv) + + return store(u) + } +} diff --git a/ios_client/EncryptedChat/Crypto/KeyEncryption.swift b/ios_client/EncryptedChat/Crypto/KeyEncryption.swift new file mode 100644 index 0000000..16e8dd2 --- /dev/null +++ b/ios_client/EncryptedChat/Crypto/KeyEncryption.swift @@ -0,0 +1,106 @@ +import Foundation +import CryptoKit +import CommonCrypto + +/// ECP1 key encryption format: PBKDF2-HMAC-SHA256 (600k iterations) + AES-256-GCM +/// Wire format: magic(4) + salt(16) + nonce(12) + ciphertext_with_tag(N+16) +enum KeyEncryption { + + /// Encrypt raw key bytes with password using ECP1 format + static func encrypt(_ rawBytes: Data, password: Data) throws -> Data { + let salt = Data.randomBytes(16) + let derivedKey = try pbkdf2(password: password, salt: salt) + + let nonce = Data.randomBytes(12) + let symmetricKey = SymmetricKey(data: derivedKey) + let gcmNonce = try AES.GCM.Nonce(data: nonce) + + // AAD = ECP1 magic bytes (matching Python) + let sealedBox = try AES.GCM.seal( + rawBytes, + using: symmetricKey, + nonce: gcmNonce, + authenticating: Constants.ecp1Magic + ) + + // ciphertext + tag concatenated (matches Python's AESGCM.encrypt output) + var result = Data() + result.append(Constants.ecp1Magic) // 4 bytes + result.append(salt) // 16 bytes + result.append(nonce) // 12 bytes + result.append(sealedBox.ciphertext) // N bytes + result.append(sealedBox.tag) // 16 bytes + return result + } + + /// Decrypt ECP1-encrypted key bytes with password + static func decrypt(_ data: Data, password: Data) throws -> Data { + guard data.count >= 48 else { // 4 + 16 + 12 + 16 minimum + throw CryptoError.invalidECP1Format + } + guard data.prefix(4) == Constants.ecp1Magic else { + throw CryptoError.invalidECP1Format + } + + let salt = data[4..<20] + let nonce = data[20..<32] + let ctWithTag = data[32...] + + guard ctWithTag.count >= 16 else { + throw CryptoError.invalidECP1Format + } + + let derivedKey = try pbkdf2(password: password, salt: Data(salt)) + let symmetricKey = SymmetricKey(data: derivedKey) + let gcmNonce = try AES.GCM.Nonce(data: nonce) + + // Split ciphertext and tag + let ct = ctWithTag.prefix(ctWithTag.count - 16) + let tag = ctWithTag.suffix(16) + + let sealedBox = try AES.GCM.SealedBox( + nonce: gcmNonce, + ciphertext: ct, + tag: tag + ) + + do { + return try AES.GCM.open(sealedBox, using: symmetricKey, authenticating: Constants.ecp1Magic) + } catch { + throw CryptoError.decryptionFailed("ECP1 decryption failed - wrong password?") + } + } + + /// Check if data starts with ECP1 magic + static func isECP1Format(_ data: Data) -> Bool { + data.count >= 4 && data.prefix(4) == Constants.ecp1Magic + } + + // MARK: - PBKDF2 + + /// Derive 32-byte key using PBKDF2-HMAC-SHA256 with 600k iterations + static func pbkdf2(password: Data, salt: Data) throws -> Data { + var derivedKey = Data(count: 32) + let status = derivedKey.withUnsafeMutableBytes { derivedKeyPtr in + password.withUnsafeBytes { passwordPtr in + salt.withUnsafeBytes { saltPtr in + CCKeyDerivationPBKDF( + CCPBKDFAlgorithm(kCCPBKDF2), + passwordPtr.baseAddress?.assumingMemoryBound(to: Int8.self), + password.count, + saltPtr.baseAddress?.assumingMemoryBound(to: UInt8.self), + salt.count, + CCPseudoRandomAlgorithm(kCCPRFHmacAlgSHA256), + Constants.pbkdf2Iterations, + derivedKeyPtr.baseAddress?.assumingMemoryBound(to: UInt8.self), + 32 + ) + } + } + } + guard status == kCCSuccess else { + throw CryptoError.pbkdf2Failed + } + return derivedKey + } +} diff --git a/ios_client/EncryptedChat/Crypto/RSACrypto.swift b/ios_client/EncryptedChat/Crypto/RSACrypto.swift new file mode 100644 index 0000000..6e027a3 --- /dev/null +++ b/ios_client/EncryptedChat/Crypto/RSACrypto.swift @@ -0,0 +1,309 @@ +import Foundation +import Security + +/// RSA-4096 operations — used for login challenge-response ONLY +enum RSACrypto { + + // MARK: - Key Generation + + /// Generate RSA-4096 keypair + static func generateKeypair() throws -> (privateKey: SecKey, publicKey: SecKey) { + let attributes: [String: Any] = [ + kSecAttrKeyType as String: kSecAttrKeyTypeRSA, + kSecAttrKeySizeInBits as String: 4096, + ] + + var error: Unmanaged? + guard let privateKey = SecKeyCreateRandomKey(attributes as CFDictionary, &error) else { + throw CryptoError.rsaKeyGenerationFailed + } + guard let publicKey = SecKeyCopyPublicKey(privateKey) else { + throw CryptoError.rsaKeyGenerationFailed + } + return (privateKey, publicKey) + } + + // MARK: - Serialization + + /// Serialize RSA private key. With password: DER → ECP1. Without: PEM PKCS#8. + static func serializePrivateKey(_ key: SecKey, password: Data? = nil) throws -> Data { + var error: Unmanaged? + guard let derData = SecKeyCopyExternalRepresentation(key, &error) as Data? else { + throw CryptoError.rsaOperationFailed("Failed to export private key") + } + + // SecKey exports in PKCS#1 format on iOS — wrap in PKCS#8 for Python compat + let pkcs8 = wrapRSAPrivateKeyPKCS8(derData) + + if let password = password { + return try KeyEncryption.encrypt(pkcs8, password: password) + } + + // PEM encode for Python compatibility + return pemEncode(pkcs8, label: "PRIVATE KEY") + } + + /// Serialize RSA public key as PEM SubjectPublicKeyInfo (Python-compatible) + static func serializePublicKey(_ key: SecKey) throws -> Data { + var error: Unmanaged? + guard let derData = SecKeyCopyExternalRepresentation(key, &error) as Data? else { + throw CryptoError.rsaOperationFailed("Failed to export public key") + } + + // SecKey exports PKCS#1 on iOS — wrap in SubjectPublicKeyInfo + let spki = wrapRSAPublicKeySPKI(derData) + return pemEncode(spki, label: "PUBLIC KEY") + } + + /// Load RSA private key. Auto-detects ECP1 vs PEM format. + static func loadPrivateKey(_ data: Data, password: Data? = nil) throws -> SecKey { + let derData: Data + + if KeyEncryption.isECP1Format(data) { + guard let pwd = password else { + throw CryptoError.invalidKeyData("ECP1 key requires password") + } + let raw = try KeyEncryption.decrypt(data, password: pwd) + derData = unwrapPKCS8ToRSAPrivateKey(raw) + } else { + // PEM format + let pem = String(data: data, encoding: .utf8) ?? "" + derData = try pemDecode(pem, label: "PRIVATE KEY") + .flatMap { unwrapPKCS8ToRSAPrivateKey($0) } + ?? pemDecode(pem, label: "RSA PRIVATE KEY") + ?? { throw CryptoError.invalidKeyData("Cannot parse RSA private key PEM") }() + } + + let attributes: [String: Any] = [ + kSecAttrKeyType as String: kSecAttrKeyTypeRSA, + kSecAttrKeyClass as String: kSecAttrKeyClassPrivate, + ] + + var error: Unmanaged? + guard let key = SecKeyCreateWithData(derData as CFData, attributes as CFDictionary, &error) else { + throw CryptoError.invalidKeyData("Failed to create RSA private key from DER") + } + return key + } + + /// Load RSA public key from PEM + static func loadPublicKey(_ pemData: Data) throws -> SecKey { + let pem = String(data: pemData, encoding: .utf8) ?? "" + + // Try SubjectPublicKeyInfo (PUBLIC KEY), unwrap to PKCS#1 + let derData: Data + if let spki = pemDecode(pem, label: "PUBLIC KEY") { + derData = unwrapSPKIToRSAPublicKey(spki) + } else if let pkcs1 = pemDecode(pem, label: "RSA PUBLIC KEY") { + derData = pkcs1 + } else { + throw CryptoError.invalidKeyData("Cannot parse RSA public key PEM") + } + + let attributes: [String: Any] = [ + kSecAttrKeyType as String: kSecAttrKeyTypeRSA, + kSecAttrKeyClass as String: kSecAttrKeyClassPublic, + ] + + var error: Unmanaged? + guard let key = SecKeyCreateWithData(derData as CFData, attributes as CFDictionary, &error) else { + throw CryptoError.invalidKeyData("Failed to create RSA public key from DER") + } + return key + } + + // MARK: - Sign / Verify + + /// Sign data with RSA-PSS SHA-256. + /// Note: iOS uses salt_length = hash_length (32). Server must use PSS.AUTO to verify. + static func sign(_ privateKey: SecKey, data: Data) throws -> Data { + var error: Unmanaged? + guard let signature = SecKeyCreateSignature( + privateKey, + .rsaSignatureMessagePSSSHA256, + data as CFData, + &error + ) as Data? else { + throw CryptoError.rsaOperationFailed("RSA signing failed") + } + return signature + } + + /// Verify RSA-PSS SHA-256 signature + static func verify(_ publicKey: SecKey, signature: Data, data: Data) -> Bool { + SecKeyVerifySignature( + publicKey, + .rsaSignatureMessagePSSSHA256, + data as CFData, + signature as CFData, + nil + ) + } + + // MARK: - PEM Helpers + + private static func pemEncode(_ der: Data, label: String) -> Data { + let base64 = der.base64EncodedString(options: .lineLength64Characters) + let pem = "-----BEGIN \(label)-----\n\(base64)\n-----END \(label)-----\n" + return Data(pem.utf8) + } + + private static func pemDecode(_ pem: String, label: String) -> Data? { + let beginMarker = "-----BEGIN \(label)-----" + let endMarker = "-----END \(label)-----" + + guard let beginRange = pem.range(of: beginMarker), + let endRange = pem.range(of: endMarker) else { + return nil + } + + let base64String = pem[beginRange.upperBound.. Data { + // PrivateKeyInfo ::= SEQUENCE { + // version INTEGER (0), + // algorithm AlgorithmIdentifier, + // privateKey OCTET STRING (containing PKCS#1 key) + // } + let version = Data([0x02, 0x01, 0x00]) // INTEGER 0 + let algorithmSeq = asn1Sequence(Data(rsaOID) + Data(nullParam)) + let privateKeyOctet = asn1OctetString(pkcs1) + return asn1Sequence(version + algorithmSeq + privateKeyOctet) + } + + /// Unwrap PKCS#8 to get PKCS#1 RSA private key + private static func unwrapPKCS8ToRSAPrivateKey(_ pkcs8: Data) -> Data { + // Parse SEQUENCE, skip version + algorithm, extract OCTET STRING + guard pkcs8.count > 2 else { return pkcs8 } + + var offset = 0 + // Outer SEQUENCE + guard pkcs8[offset] == 0x30 else { return pkcs8 } + offset += 1 + offset = skipASN1Length(pkcs8, offset: offset) + + // Version INTEGER + guard offset < pkcs8.count, pkcs8[offset] == 0x02 else { return pkcs8 } + offset += 1 + let versionLen = readASN1Length(pkcs8, offset: &offset) + offset += versionLen + + // Algorithm SEQUENCE + guard offset < pkcs8.count, pkcs8[offset] == 0x30 else { return pkcs8 } + offset += 1 + let algoLen = readASN1Length(pkcs8, offset: &offset) + offset += algoLen + + // Private key OCTET STRING + guard offset < pkcs8.count, pkcs8[offset] == 0x04 else { return pkcs8 } + offset += 1 + let keyLen = readASN1Length(pkcs8, offset: &offset) + guard offset + keyLen <= pkcs8.count else { return pkcs8 } + return Data(pkcs8[offset..<(offset + keyLen)]) + } + + /// Wrap PKCS#1 RSA public key in SubjectPublicKeyInfo + private static func wrapRSAPublicKeySPKI(_ pkcs1: Data) -> Data { + // SubjectPublicKeyInfo ::= SEQUENCE { + // algorithm AlgorithmIdentifier, + // subjectPublicKey BIT STRING (containing PKCS#1 key) + // } + let algorithmSeq = asn1Sequence(Data(rsaOID) + Data(nullParam)) + let bitString = asn1BitString(pkcs1) + return asn1Sequence(algorithmSeq + bitString) + } + + /// Unwrap SubjectPublicKeyInfo to get PKCS#1 RSA public key + private static func unwrapSPKIToRSAPublicKey(_ spki: Data) -> Data { + guard spki.count > 2 else { return spki } + + var offset = 0 + // Outer SEQUENCE + guard spki[offset] == 0x30 else { return spki } + offset += 1 + offset = skipASN1Length(spki, offset: offset) + + // Algorithm SEQUENCE + guard offset < spki.count, spki[offset] == 0x30 else { return spki } + offset += 1 + let algoLen = readASN1Length(spki, offset: &offset) + offset += algoLen + + // BIT STRING + guard offset < spki.count, spki[offset] == 0x03 else { return spki } + offset += 1 + let bitLen = readASN1Length(spki, offset: &offset) + // Skip the unused bits byte + guard offset < spki.count, spki[offset] == 0x00 else { return spki } + offset += 1 + let keyLen = bitLen - 1 + guard offset + keyLen <= spki.count else { return spki } + return Data(spki[offset..<(offset + keyLen)]) + } + + // MARK: - ASN.1 Primitives + + private static func asn1Length(_ length: Int) -> Data { + if length < 0x80 { + return Data([UInt8(length)]) + } else if length <= 0xFF { + return Data([0x81, UInt8(length)]) + } else if length <= 0xFFFF { + return Data([0x82, UInt8(length >> 8), UInt8(length & 0xFF)]) + } else { + return Data([0x83, UInt8(length >> 16), UInt8((length >> 8) & 0xFF), UInt8(length & 0xFF)]) + } + } + + private static func asn1Sequence(_ content: Data) -> Data { + Data([0x30]) + asn1Length(content.count) + content + } + + private static func asn1OctetString(_ content: Data) -> Data { + Data([0x04]) + asn1Length(content.count) + content + } + + private static func asn1BitString(_ content: Data) -> Data { + // BIT STRING: tag + length + unused_bits(0) + content + Data([0x03]) + asn1Length(content.count + 1) + Data([0x00]) + content + } + + private static func readASN1Length(_ data: Data, offset: inout Int) -> Int { + guard offset < data.count else { return 0 } + let first = data[offset] + offset += 1 + if first < 0x80 { + return Int(first) + } + let numBytes = Int(first & 0x7F) + var length = 0 + for _ in 0.. Int { + var off = offset + _ = readASN1Length(data, offset: &off) + return off + } +} diff --git a/ios_client/EncryptedChat/Crypto/SenderKeyState.swift b/ios_client/EncryptedChat/Crypto/SenderKeyState.swift new file mode 100644 index 0000000..63cc53d --- /dev/null +++ b/ios_client/EncryptedChat/Crypto/SenderKeyState.swift @@ -0,0 +1,175 @@ +import Foundation +import CryptoKit + +/// Sender key chain for group messaging. +/// Each sender in a group has their own chain. Others receive the initial key via pairwise ratchet. +/// Matches Python: SenderKeyState class in crypto_utils.py +class SenderKeyState { + + let senderKey: Data + let chainId: Data + private(set) var chainKey: Data + private(set) var n: Int + private var knownKeys: [Int: Data] + + /// Initialize with optional sender key (generates random 32B if nil). + /// Matches Python: SenderKeyState.__init__(sender_key=None) + init(senderKey: Data? = nil) { + let key = senderKey ?? Data.randomBytes(32) + self.senderKey = key + self.chainId = Data(SHA256.hash(data: key)) + self.chainKey = CryptoUtils.hkdfDerive( + inputKey: key, + salt: Data(repeating: 0x00, count: 32), + info: Data(Constants.senderKeyChainInfo.utf8), + length: 32 + ) + self.n = 0 + self.knownKeys = [:] + } + + /// Private init for import + private init(senderKey: Data, chainId: Data, chainKey: Data, n: Int, knownKeys: [Int: Data]) { + self.senderKey = senderKey + self.chainId = chainId + self.chainKey = chainKey + self.n = n + self.knownKeys = knownKeys + } + + // MARK: - Encrypt + + /// Encrypt with current chain key. + /// Returns (chainId hex, n, ciphertext with tag, nonce). + /// Matches Python: SenderKeyState.encrypt(plaintext) + func encrypt(_ plaintext: Data) throws -> (chainIdHex: String, n: Int, ciphertext: Data, nonce: Data) { + let (newCK, messageKey) = CryptoUtils.kdfCK(chainKey: chainKey) + chainKey = newCK + + let nonce = Data.randomBytes(12) + // AAD = chainId + bigEndian(UInt32(n)) + let aad = chainId + UInt32(n).bigEndianData + let ctWithTag = try CryptoUtils.aesGcmEncrypt(plaintext, key: messageKey, nonce: nonce, aad: aad) + + let result = (chainIdHex: chainId.hexString, n: n, ciphertext: ctWithTag, nonce: nonce) + n += 1 + return result + } + + // MARK: - Decrypt + + /// Decrypt a group message. Fast-forwards the chain if needed. + /// State is snapshotted before modification and restored on failure. + /// Matches Python: SenderKeyState.decrypt(chain_id_hex, n, ciphertext, nonce) + func decrypt(chainIdHex: String, n: Int, ciphertext: Data, nonce: Data) throws -> Data { + guard let expectedChainId = Data(hexString: chainIdHex) else { + throw CryptoError.senderKeyError("Invalid chain ID hex") + } + guard expectedChainId == chainId else { + throw CryptoError.senderKeyError("Chain ID mismatch") + } + + if n - self.n > Constants.maxSenderKeySkip { + throw CryptoError.senderKeyError("Sender key skip too large (\(n - self.n) > \(Constants.maxSenderKeySkip))") + } + + // Snapshot before fast-forward + let snapChainKey = chainKey + let snapN = self.n + let snapKnown = knownKeys + + do { + // Fast-forward the chain to reach message n + while self.n <= n { + let (newCK, mk) = CryptoUtils.kdfCK(chainKey: chainKey) + chainKey = newCK + knownKeys[self.n] = mk + self.n += 1 + } + + guard let mk = knownKeys.removeValue(forKey: n) else { + throw CryptoError.senderKeyError("Message key for n=\(n) not available") + } + + let aad = chainId + UInt32(n).bigEndianData + return try CryptoUtils.aesGcmDecrypt(ciphertext, key: mk, nonce: nonce, aad: aad) + } catch { + // Restore state on failure + chainKey = snapChainKey + self.n = snapN + knownKeys = snapKnown + throw error + } + } + + // MARK: - Key Export/Import + + /// Export sender key for distribution to group members. + /// Matches Python: SenderKeyState.export_key() + func exportKey() -> Data { + let dict: [String: Any] = ["sender_key": senderKey.hexString] + return try! JSONSerialization.data(withJSONObject: dict) + } + + /// Initialize a receiving SenderKeyState from an exported key. + /// Matches Python: SenderKeyState.from_key(exported_key) + static func fromKey(_ exportedKey: Data) throws -> SenderKeyState { + guard let dict = try JSONSerialization.jsonObject(with: exportedKey) as? [String: Any], + let senderKeyHex = dict["sender_key"] as? String, + let senderKey = Data(hexString: senderKeyHex) else { + throw CryptoError.stateImportFailed("Invalid sender key export") + } + return SenderKeyState(senderKey: senderKey) + } + + // MARK: - Full State Export/Import + + /// Serialize full state for persistent storage. + /// Matches Python: SenderKeyState.export_state() + func exportState() -> Data { + var knownKeysDict: [String: String] = [:] + for (k, v) in knownKeys { + knownKeysDict[String(k)] = v.hexString + } + let state: [String: Any] = [ + "sender_key": senderKey.hexString, + "chain_id": chainId.hexString, + "chain_key": chainKey.hexString, + "n": n, + "known_keys": knownKeysDict, + ] + return try! JSONSerialization.data(withJSONObject: state) + } + + /// Deserialize full state. + /// Matches Python: SenderKeyState.import_state(data) + static func importState(_ data: Data) throws -> SenderKeyState { + guard let state = try JSONSerialization.jsonObject(with: data) as? [String: Any], + let senderKeyHex = state["sender_key"] as? String, + let senderKey = Data(hexString: senderKeyHex), + let chainIdHex = state["chain_id"] as? String, + let chainId = Data(hexString: chainIdHex), + let chainKeyHex = state["chain_key"] as? String, + let chainKey = Data(hexString: chainKeyHex), + let n = state["n"] as? Int else { + throw CryptoError.stateImportFailed("Invalid sender key state") + } + + var knownKeys: [Int: Data] = [:] + if let knownKeysDict = state["known_keys"] as? [String: String] { + for (k, v) in knownKeysDict { + if let idx = Int(k), let data = Data(hexString: v) { + knownKeys[idx] = data + } + } + } + + return SenderKeyState( + senderKey: senderKey, + chainId: chainId, + chainKey: chainKey, + n: n, + knownKeys: knownKeys + ) + } +} diff --git a/ios_client/EncryptedChat/Crypto/X25519Crypto.swift b/ios_client/EncryptedChat/Crypto/X25519Crypto.swift new file mode 100644 index 0000000..5431489 --- /dev/null +++ b/ios_client/EncryptedChat/Crypto/X25519Crypto.swift @@ -0,0 +1,77 @@ +import Foundation +import CryptoKit + +/// X25519 Diffie-Hellman key agreement +enum X25519Crypto { + + // MARK: - Key Generation + + /// Generate X25519 keypair + static func generateKeypair() -> (privateKey: Curve25519.KeyAgreement.PrivateKey, publicKey: Curve25519.KeyAgreement.PublicKey) { + let privateKey = Curve25519.KeyAgreement.PrivateKey() + return (privateKey, privateKey.publicKey) + } + + // MARK: - Serialization + + /// Serialize X25519 private key to 32 raw bytes + static func serializePrivate(_ key: Curve25519.KeyAgreement.PrivateKey) -> Data { + key.rawData // 32 bytes + } + + /// Serialize X25519 public key to 32 raw bytes + static func serializePublic(_ key: Curve25519.KeyAgreement.PublicKey) -> Data { + key.rawData // 32 bytes + } + + /// Load X25519 private key from 32 raw bytes + static func loadPrivate(_ data: Data) throws -> Curve25519.KeyAgreement.PrivateKey { + guard data.count == 32 else { + throw CryptoError.invalidKeyData("X25519 private key must be 32 bytes") + } + return try Curve25519.KeyAgreement.PrivateKey(rawRepresentation: data) + } + + /// Load X25519 public key from 32 raw bytes + static func loadPublic(_ data: Data) throws -> Curve25519.KeyAgreement.PublicKey { + guard data.count == 32 else { + throw CryptoError.invalidKeyData("X25519 public key must be 32 bytes") + } + return try Curve25519.KeyAgreement.PublicKey(rawRepresentation: data) + } + + // MARK: - Diffie-Hellman + + /// Perform X25519 DH key agreement. Returns 32-byte shared secret. + /// Matches Python: x25519_dh(private_key, public_key) + static func dh(_ privateKey: Curve25519.KeyAgreement.PrivateKey, _ publicKey: Curve25519.KeyAgreement.PublicKey) throws -> Data { + let sharedSecret = try privateKey.sharedSecretFromKeyAgreement(with: publicKey) + // Extract raw bytes from SharedSecret + return sharedSecret.withUnsafeBytes { Data($0) } + } + + // MARK: - Ed25519 → X25519 Key Conversion + + /// Convert Ed25519 private key to X25519 private key. + /// SHA-512(seed) → take first 32 bytes → clamp per RFC 7748 + /// Matches Python: ed25519_private_to_x25519(ed_private) + static func fromEd25519Private(_ edPrivate: Curve25519.Signing.PrivateKey) throws -> Curve25519.KeyAgreement.PrivateKey { + let raw = edPrivate.rawData // 32 bytes seed + // SHA-512 of the seed + let hash = SHA512.hash(data: raw) + var clamped = Data(hash.prefix(32)) + // Clamp per RFC 7748 + clamped[0] &= 248 + clamped[31] &= 127 + clamped[31] |= 64 + return try Curve25519.KeyAgreement.PrivateKey(rawRepresentation: clamped) + } + + /// Convert Ed25519 public key to X25519 public key. + /// Uses Montgomery birational map: u = (1+y)/(1-y) mod p + /// Matches Python: ed25519_public_to_x25519(ed_public) + static func fromEd25519Public(_ edPublic: Curve25519.Signing.PublicKey) throws -> Curve25519.KeyAgreement.PublicKey { + let x25519Bytes = FieldArithmetic.ed25519PublicToX25519(edPublic.rawData) + return try Curve25519.KeyAgreement.PublicKey(rawRepresentation: x25519Bytes) + } +} diff --git a/ios_client/EncryptedChat/Crypto/X3DH.swift b/ios_client/EncryptedChat/Crypto/X3DH.swift new file mode 100644 index 0000000..f99b602 --- /dev/null +++ b/ios_client/EncryptedChat/Crypto/X3DH.swift @@ -0,0 +1,118 @@ +import Foundation +import CryptoKit + +/// X3DH key agreement protocol (Signal Protocol) +enum X3DH { + + // MARK: - Pre-Key Generation + + /// Generate a signed pre-key (SPK). + /// Returns (private, public, signature, id). + /// Matches Python: generate_signed_prekey(identity_private) + static func generateSignedPrekey( + identityPrivate: Curve25519.Signing.PrivateKey + ) throws -> (privateKey: Curve25519.KeyAgreement.PrivateKey, + publicKey: Curve25519.KeyAgreement.PublicKey, + signature: Data, + id: String) { + let (spkPriv, spkPub) = X25519Crypto.generateKeypair() + let spkPubBytes = X25519Crypto.serializePublic(spkPub) + let signature = try Ed25519Crypto.sign(identityPrivate, data: spkPubBytes) + return (spkPriv, spkPub, signature, UUID().uuidString) + } + + /// Generate a batch of one-time pre-keys. + /// Matches Python: generate_one_time_prekeys(count=50) + static func generateOneTimePrekeys(count: Int = 50) -> [(privateKey: Curve25519.KeyAgreement.PrivateKey, + publicKey: Curve25519.KeyAgreement.PublicKey, + id: String)] { + (0.. (sharedSecret: Data, + ephemeralPrivate: Curve25519.KeyAgreement.PrivateKey, + ephemeralPublic: Curve25519.KeyAgreement.PublicKey) { + // Verify SPK signature + let spkRemoteBytes = X25519Crypto.serializePublic(spkRemote) + guard Ed25519Crypto.verify(ikPublicRemoteEd, signature: spkSignature, data: spkRemoteBytes) else { + throw CryptoError.x3dhFailed("Invalid SPK signature") + } + + // Convert identity keys to X25519 + let ikX25519Private = try X25519Crypto.fromEd25519Private(ikPrivateEd) + let ikX25519Remote = try X25519Crypto.fromEd25519Public(ikPublicRemoteEd) + + // Generate ephemeral keypair + let (ekPriv, ekPub) = X25519Crypto.generateKeypair() + + // DH computations + let dh1 = try X25519Crypto.dh(ikX25519Private, spkRemote) // IK_A, SPK_B + let dh2 = try X25519Crypto.dh(ekPriv, ikX25519Remote) // EK_A, IK_B + let dh3 = try X25519Crypto.dh(ekPriv, spkRemote) // EK_A, SPK_B + + var dhConcat = dh1 + dh2 + dh3 + if let opk = opkRemote { + let dh4 = try X25519Crypto.dh(ekPriv, opk) // EK_A, OPK_B + dhConcat += dh4 + } + + // Derive shared secret + let sharedSecret = CryptoUtils.hkdfDerive( + inputKey: dhConcat, + salt: Data(repeating: 0x00, count: 32), + info: Data(Constants.x3dhInfo.utf8), + length: 32 + ) + + return (sharedSecret, ekPriv, ekPub) + } + + // MARK: - X3DH Respond (Bob) + + /// Responder side of X3DH. + /// Returns sharedSecret. + /// Matches Python: x3dh_respond(ik_private_ed, spk_private, ik_remote_ed, ek_remote, opk_private?) + static func respond( + ikPrivateEd: Curve25519.Signing.PrivateKey, + spkPrivate: Curve25519.KeyAgreement.PrivateKey, + ikRemoteEd: Curve25519.Signing.PublicKey, + ekRemote: Curve25519.KeyAgreement.PublicKey, + opkPrivate: Curve25519.KeyAgreement.PrivateKey? = nil + ) throws -> Data { + let ikX25519Private = try X25519Crypto.fromEd25519Private(ikPrivateEd) + let ikX25519Remote = try X25519Crypto.fromEd25519Public(ikRemoteEd) + + let dh1 = try X25519Crypto.dh(spkPrivate, ikX25519Remote) // SPK_B, IK_A + let dh2 = try X25519Crypto.dh(ikX25519Private, ekRemote) // IK_B, EK_A + let dh3 = try X25519Crypto.dh(spkPrivate, ekRemote) // SPK_B, EK_A + + var dhConcat = dh1 + dh2 + dh3 + if let opk = opkPrivate { + let dh4 = try X25519Crypto.dh(opk, ekRemote) // OPK_B, EK_A + dhConcat += dh4 + } + + let sharedSecret = CryptoUtils.hkdfDerive( + inputKey: dhConcat, + salt: Data(repeating: 0x00, count: 32), + info: Data(Constants.x3dhInfo.utf8), + length: 32 + ) + + return sharedSecret + } +} diff --git a/ios_client/EncryptedChat/Models/Conversation.swift b/ios_client/EncryptedChat/Models/Conversation.swift new file mode 100644 index 0000000..c005883 --- /dev/null +++ b/ios_client/EncryptedChat/Models/Conversation.swift @@ -0,0 +1,46 @@ +import Foundation + +struct Conversation: Identifiable, Equatable { + let id: String + var name: String? + var members: [ConversationMember] + var createdBy: String? + var avatarFile: String? + var unreadCount: Int + var isFavorite: Bool + var lastMessageTime: Date? + + var isGroup: Bool { + name != nil || members.count > 2 + } + + /// Display name: group name, or DM partner username + func displayName(currentUserId: String) -> String { + if let name = name, !name.isEmpty { + return name + } + // DM: show the other person's name + if let other = members.first(where: { $0.userId != currentUserId }) { + return other.username + } + return "Unknown" + } + + /// DM partner user ID (nil for groups) + func dmPartnerId(currentUserId: String) -> String? { + guard !isGroup else { return nil } + return members.first(where: { $0.userId != currentUserId })?.userId + } + + static func == (lhs: Conversation, rhs: Conversation) -> Bool { + lhs.id == rhs.id + } +} + +struct ConversationMember: Identifiable, Equatable, Codable { + let userId: String + var username: String + var email: String + + var id: String { userId } +} diff --git a/ios_client/EncryptedChat/Models/DeviceBundle.swift b/ios_client/EncryptedChat/Models/DeviceBundle.swift new file mode 100644 index 0000000..1192c46 --- /dev/null +++ b/ios_client/EncryptedChat/Models/DeviceBundle.swift @@ -0,0 +1,43 @@ +import Foundation + +/// Key bundle for one device, used in X3DH +struct DeviceBundle { + let deviceId: String + let identityKey: Data // Ed25519 public key (32 bytes) + let spk: Data // X25519 public key (32 bytes) + let spkSignature: Data // Ed25519 signature (64 bytes) + let spkId: String + let opk: Data? // X25519 public key (32 bytes), optional + let opkId: String? + + /// Parse from server response dictionary + static func fromDict(_ dict: [String: Any]) throws -> DeviceBundle { + guard let deviceId = dict["device_id"] as? String, + let ikHex = dict["identity_key"] as? String, + let ik = Data(hexString: ikHex), + let spkHex = dict["spk"] as? String, + let spk = Data(hexString: spkHex), + let spkSigHex = dict["spk_signature"] as? String, + let spkSig = Data(hexString: spkSigHex), + let spkId = dict["spk_id"] as? String else { + throw ChatError.invalidData("Invalid device bundle") + } + + var opk: Data? + var opkId: String? + if let opkHex = dict["opk"] as? String, let opkData = Data(hexString: opkHex) { + opk = opkData + opkId = dict["opk_id"] as? String + } + + return DeviceBundle( + deviceId: deviceId, + identityKey: ik, + spk: spk, + spkSignature: spkSig, + spkId: spkId, + opk: opk, + opkId: opkId + ) + } +} diff --git a/ios_client/EncryptedChat/Models/Invitation.swift b/ios_client/EncryptedChat/Models/Invitation.swift new file mode 100644 index 0000000..6b0631c --- /dev/null +++ b/ios_client/EncryptedChat/Models/Invitation.swift @@ -0,0 +1,9 @@ +import Foundation + +struct Invitation: Identifiable { + let id: String // invitation id (from server) or conversationId + let conversationId: String + let conversationName: String + let invitedBy: String + let invitedByUsername: String +} diff --git a/ios_client/EncryptedChat/Models/Message.swift b/ios_client/EncryptedChat/Models/Message.swift new file mode 100644 index 0000000..ae95bd5 --- /dev/null +++ b/ios_client/EncryptedChat/Models/Message.swift @@ -0,0 +1,33 @@ +import Foundation + +struct Message: Identifiable, Equatable { + let id: String + let conversationId: String + let senderId: String + var senderUsername: String + let createdAt: Date + var text: String? + var replyTo: String? + var imageFileId: String? + var file: FileInfo? + var isDeleted: Bool + var readBy: Set + + /// Whether this is a self-sent message + func isMine(currentUserId: String) -> Bool { + senderId == currentUserId + } + + static func == (lhs: Message, rhs: Message) -> Bool { + lhs.id == rhs.id + } +} + +struct FileInfo: Equatable, Codable { + let fileId: String + let aesKey: String // hex + let iv: String // hex + let filename: String + let size: Int + let mimeType: String +} diff --git a/ios_client/EncryptedChat/Models/User.swift b/ios_client/EncryptedChat/Models/User.swift new file mode 100644 index 0000000..0e429b3 --- /dev/null +++ b/ios_client/EncryptedChat/Models/User.swift @@ -0,0 +1,19 @@ +import Foundation + +struct User: Identifiable, Equatable { + let id: String + var username: String + var email: String + var identityKey: Data? // Ed25519 public key (32 bytes) +} + +struct UserProfile: Equatable { + var userId: String + var username: String? + var email: String? + var phone: String? + var phoneVisible: Bool + var location: String? + var locationVisible: Bool + var avatarFile: String? +} diff --git a/ios_client/EncryptedChat/Network/ConnectionManager.swift b/ios_client/EncryptedChat/Network/ConnectionManager.swift new file mode 100644 index 0000000..e184729 --- /dev/null +++ b/ios_client/EncryptedChat/Network/ConnectionManager.swift @@ -0,0 +1,188 @@ +import Foundation +import Network + +/// TCP connection manager using Network.framework. +/// Handles connection lifecycle, TLS, buffered reading (newline-delimited), and writing. +actor ConnectionManager { + + enum ConnectionState: Equatable { + case disconnected + case connecting + case connected + case failed(String) + } + + private var connection: NWConnection? + private var receiveBuffer = Data() + private(set) var state: ConnectionState = .disconnected + private var stateCallback: ((ConnectionState) -> Void)? + private var messageStream: AsyncStream<[String: Any]>.Continuation? + + /// Set a callback for connection state changes + func onStateChange(_ callback: @escaping (ConnectionState) -> Void) { + stateCallback = callback + } + + // MARK: - Connect / Disconnect + + /// Connect to server + func connect(host: String, port: UInt16, useTLS: Bool = false, tlsInsecure: Bool = false) async throws { + guard state == .disconnected || state != .connected else { + throw NetworkError.alreadyConnected + } + + updateState(.connecting) + + let nwHost = NWEndpoint.Host(host) + let nwPort = NWEndpoint.Port(rawValue: port)! + + let params: NWParameters + if useTLS { + let tlsOptions = NWProtocolTLS.Options() + if tlsInsecure { + // Skip certificate verification (dev only) + sec_protocol_options_set_verify_block( + tlsOptions.securityProtocolOptions, + { _, _, completionHandler in completionHandler(true) }, + .main + ) + } + params = NWParameters(tls: tlsOptions, tcp: .init()) + } else { + params = .tcp + } + + let conn = NWConnection(host: nwHost, port: nwPort, using: params) + self.connection = conn + self.receiveBuffer = Data() + + return try await withCheckedThrowingContinuation { continuation in + conn.stateUpdateHandler = { [weak self] newState in + Task { [weak self] in + guard let self = self else { return } + switch newState { + case .ready: + await self.updateState(.connected) + continuation.resume() + case .failed(let error): + await self.updateState(.failed(error.localizedDescription)) + continuation.resume(throwing: NetworkError.connectionFailed(error.localizedDescription)) + case .cancelled: + await self.updateState(.disconnected) + case .waiting(let error): + await self.updateState(.failed(error.localizedDescription)) + continuation.resume(throwing: NetworkError.connectionFailed("Waiting: \(error.localizedDescription)")) + default: + break + } + } + } + conn.start(queue: .global(qos: .userInitiated)) + } + } + + /// Disconnect from server + func disconnect() { + connection?.cancel() + connection = nil + receiveBuffer = Data() + updateState(.disconnected) + messageStream?.finish() + messageStream = nil + } + + // MARK: - Send + + /// Send raw data over the connection + func send(_ data: Data) async throws { + guard let connection = connection, state == .connected else { + throw NetworkError.notConnected + } + + return try await withCheckedThrowingContinuation { continuation in + connection.send(content: data, completion: .contentProcessed { error in + if let error = error { + continuation.resume(throwing: NetworkError.connectionFailed(error.localizedDescription)) + } else { + continuation.resume() + } + }) + } + } + + /// Send a protocol message (builds JSON + newline, sends) + func sendMessage(type: String, requestId: String? = nil, params: [String: Any] = [:]) async throws { + let data = try ProtocolHandler.buildRequest(type: type, requestId: requestId, params: params) + try await send(data) + } + + // MARK: - Receive + + /// Read one newline-delimited JSON message. + /// Returns nil on EOF / connection close. + func readMessage() async throws -> [String: Any]? { + while true { + // Check buffer for a complete line + if let newlineIndex = receiveBuffer.firstIndex(of: 0x0A) { + let lineData = receiveBuffer.prefix(through: newlineIndex) + receiveBuffer.removeSubrange(...newlineIndex) + + // Check size + if lineData.count > Constants.maxMessageBytes { + throw NetworkError.messageTooLarge + } + + return try ProtocolHandler.parseMessage(Data(lineData)) + } + + // Buffer doesn't have a complete line — read more from the connection + guard let connection = connection else { + return nil + } + + let chunk = try await receiveChunk(connection: connection) + guard let chunk = chunk else { + return nil // EOF + } + + receiveBuffer.append(chunk) + + // Safety: if buffer exceeds max without a newline, drop it + if receiveBuffer.count > Constants.maxMessageBytes * 2 { + receiveBuffer = Data() + throw NetworkError.messageTooLarge + } + } + } + + /// Read a chunk of data from the connection + private func receiveChunk(connection: NWConnection) async throws -> Data? { + return try await withCheckedThrowingContinuation { continuation in + connection.receive(minimumIncompleteLength: 1, maximumLength: 65536) { content, _, isComplete, error in + if let error = error { + continuation.resume(throwing: NetworkError.connectionFailed(error.localizedDescription)) + return + } + if let content = content, !content.isEmpty { + continuation.resume(returning: content) + } else if isComplete { + continuation.resume(returning: nil) + } else { + // No data and not complete — shouldn't happen but return nil + continuation.resume(returning: nil) + } + } + } + } + + // MARK: - State + + var isConnected: Bool { + state == .connected + } + + private func updateState(_ newState: ConnectionState) { + state = newState + stateCallback?(newState) + } +} diff --git a/ios_client/EncryptedChat/Network/ProtocolHandler.swift b/ios_client/EncryptedChat/Network/ProtocolHandler.swift new file mode 100644 index 0000000..2ed65c9 --- /dev/null +++ b/ios_client/EncryptedChat/Network/ProtocolHandler.swift @@ -0,0 +1,90 @@ +import Foundation + +/// Newline-delimited JSON protocol handler. +/// Matches Python: protocol.py build_request, build_response, parse_message, encode_binary, decode_binary +enum ProtocolHandler { + + /// Build a request message (newline-terminated JSON). + /// Matches Python: build_request(msg_type, request_id=None, **kwargs) + static func buildRequest(type: String, requestId: String? = nil, params: [String: Any] = [:]) throws -> Data { + var msg: [String: Any] = ["type": type] + if let requestId = requestId { + msg["request_id"] = requestId + } + // Merge params into msg + for (key, value) in params { + msg[key] = value + } + + let jsonData = try JSONSerialization.data(withJSONObject: msg) + guard jsonData.count < Constants.maxMessageBytes else { + throw NetworkError.messageTooLarge + } + return jsonData + Data([0x0A]) // newline + } + + /// Build a response message (newline-terminated JSON). + static func buildResponse(type: String, status: String, data: [String: Any]? = nil, requestId: String? = nil) throws -> Data { + var msg: [String: Any] = ["type": type, "status": status] + if let data = data { + msg["data"] = data + } + if let requestId = requestId { + msg["request_id"] = requestId + } + + let jsonData = try JSONSerialization.data(withJSONObject: msg) + guard jsonData.count < Constants.maxMessageBytes else { + throw NetworkError.messageTooLarge + } + return jsonData + Data([0x0A]) + } + + /// Parse a single protocol message from bytes. + /// Matches Python: parse_message(line) + static func parseMessage(_ data: Data) throws -> [String: Any] { + let trimmed = data.trimmingNewlines() + guard !trimmed.isEmpty else { + throw NetworkError.protocolError("Empty message") + } + guard let obj = try JSONSerialization.jsonObject(with: trimmed) as? [String: Any] else { + throw NetworkError.protocolError("Message is not a JSON object") + } + return obj + } + + /// Encode bytes to base64 string. + /// Matches Python: encode_binary(data) + static func encodeBinary(_ data: Data) -> String { + data.base64EncodedString(options: []) + } + + /// Decode base64 string to bytes. + /// Matches Python: decode_binary(data) + static func decodeBinary(_ string: String) throws -> Data { + guard let data = Data(base64Encoded: string, options: .ignoreUnknownCharacters) else { + throw CryptoError.invalidBase64 + } + return data + } + + /// Generate a new request ID (UUID string). + static func newRequestId() -> String { + UUID().uuidString + } +} + +// MARK: - Data Helpers + +private extension Data { + func trimmingNewlines() -> Data { + var data = self + while let last = data.last, last == 0x0A || last == 0x0D { + data.removeLast() + } + while let first = data.first, first == 0x0A || first == 0x0D { + data.removeFirst() + } + return data + } +} diff --git a/ios_client/EncryptedChat/Utilities/Constants.swift b/ios_client/EncryptedChat/Utilities/Constants.swift new file mode 100644 index 0000000..b4dfc3c --- /dev/null +++ b/ios_client/EncryptedChat/Utilities/Constants.swift @@ -0,0 +1,38 @@ +import Foundation + +enum Constants { + static let version = "0.8.2" + static let maxMessageBytes = 65536 + static let maxImageBytes = 5 * 1024 * 1024 // 5 MB + static let maxFileBytes = 50 * 1024 * 1024 // 50 MB + static let imageChunkSize = 32768 // 32 KB + static let selfDeviceId = "00000000-0000-0000-0000-000000000000" + + static let opkReplenishThreshold = 20 + static let opkBatchSize = 50 + static let spkRotationDays = 7 + + static let maxSkip = 256 + static let maxSenderKeySkip = 256 + + static let deviceBundleCacheTTL: TimeInterval = 300 // 5 minutes + static let sendReceiveTimeout: TimeInterval = 30 + static let reconnectBaseDelay: TimeInterval = 1 + static let reconnectMaxDelay: TimeInterval = 30 + + static let pbkdf2Iterations: UInt32 = 600_000 + static let ecp1Magic = Data([0x45, 0x43, 0x50, 0x31]) // "ECP1" + + // HKDF info/salt strings matching Python + static let x3dhInfo = "EncryptedChat_X3DH" + static let rootKeyInfo = "EncryptedChat_RootKey" + static let selfEncryptionSalt = "self_encryption" + static let selfEncryptionInfo = "EncryptedChat_SelfKey" + static let localStorageSalt = "local_storage" + static let localStorageInfo = "EncryptedChat_LocalStorage" + static let senderKeyChainInfo = "SenderKeyChain" + + // Server connection defaults + static let defaultHost = "127.0.0.1" + static let defaultPort: UInt16 = 9999 +} diff --git a/ios_client/EncryptedChat/Utilities/Extensions.swift b/ios_client/EncryptedChat/Utilities/Extensions.swift new file mode 100644 index 0000000..84ca32e --- /dev/null +++ b/ios_client/EncryptedChat/Utilities/Extensions.swift @@ -0,0 +1,132 @@ +import Foundation +import CryptoKit + +// MARK: - Data ↔ Hex + +extension Data { + /// Convert data to lowercase hex string + var hexString: String { + map { String(format: "%02x", $0) }.joined() + } + + /// Initialize Data from a hex string + init?(hexString: String) { + let hex = hexString.lowercased() + guard hex.count % 2 == 0 else { return nil } + var data = Data(capacity: hex.count / 2) + var index = hex.startIndex + while index < hex.endIndex { + let nextIndex = hex.index(index, offsetBy: 2) + guard let byte = UInt8(hex[index.. Data { + var data = Data(count: count) + data.withUnsafeMutableBytes { ptr in + _ = SecRandomCopyBytes(kSecRandomDefault, count, ptr.baseAddress!) + } + return data + } +} + +// MARK: - Data ↔ Base64 (Protocol Wire Format) + +extension Data { + /// Encode to standard base64 string (matching Python's base64.b64encode) + func base64EncodedString() -> String { + self.base64EncodedString(options: []) + } + + /// Decode from base64 string + static func fromBase64(_ string: String) throws -> Data { + // Try standard base64 first, then URL-safe + if let data = Data(base64Encoded: string, options: .ignoreUnknownCharacters) { + return data + } + throw CryptoError.invalidBase64 + } +} + +// MARK: - UInt32 Big-Endian + +extension UInt32 { + var bigEndianData: Data { + var value = self.bigEndian + return Data(bytes: &value, count: 4) + } +} + +// MARK: - CryptoKit Key → Data + +extension Curve25519.KeyAgreement.PublicKey { + var rawData: Data { + Data(rawRepresentation) + } +} + +extension Curve25519.KeyAgreement.PrivateKey { + var rawData: Data { + Data(rawRepresentation) + } +} + +extension Curve25519.Signing.PublicKey { + var rawData: Data { + Data(rawRepresentation) + } +} + +extension Curve25519.Signing.PrivateKey { + var rawData: Data { + Data(rawRepresentation) + } +} + +// MARK: - String helpers + +extension String { + /// Trim whitespace and newlines + var trimmed: String { + trimmingCharacters(in: .whitespacesAndNewlines) + } +} + +// MARK: - Dictionary merge helper + +extension Dictionary where Key == String, Value == Any { + func string(for key: String) -> String? { + self[key] as? String + } + + func int(for key: String) -> Int? { + if let i = self[key] as? Int { return i } + if let s = self[key] as? String, let i = Int(s) { return i } + return nil + } + + func dict(for key: String) -> [String: Any]? { + self[key] as? [String: Any] + } + + func array(for key: String) -> [[String: Any]]? { + self[key] as? [[String: Any]] + } + + func data(for key: String) -> Data? { + if let hex = self[key] as? String { + return Data(hexString: hex) + } + return nil + } + + func bool(for key: String) -> Bool? { + if let b = self[key] as? Bool { return b } + if let i = self[key] as? Int { return i != 0 } + return nil + } +} diff --git a/ios_client/EncryptedChat/ViewModels/AuthViewModel.swift b/ios_client/EncryptedChat/ViewModels/AuthViewModel.swift new file mode 100644 index 0000000..33b98ac --- /dev/null +++ b/ios_client/EncryptedChat/ViewModels/AuthViewModel.swift @@ -0,0 +1,114 @@ +import Foundation +import SwiftUI + +@Observable +final class AuthViewModel { + var email = "" + var password = "" + var confirmPassword = "" + var username = "" + var confirmationCode = "" + + var isLoading = false + var errorMessage: String? + var showConfirmation = false + var registrationMessage: String? + + var serverHost = Constants.defaultHost + var serverPort = String(Constants.defaultPort) + var useTLS = false + + enum AuthMode { + case login, register, pairing + } + var mode: AuthMode = .login + + func login(appState: AppState) async { + guard !email.isEmpty, !password.isEmpty else { + errorMessage = "Email and password are required" + return + } + + isLoading = true + errorMessage = nil + + do { + let port = UInt16(serverPort) ?? Constants.defaultPort + try await appState.chatClient.connect(host: serverHost, port: port, useTLS: useTLS) + } catch { + isLoading = false + errorMessage = "Connection failed: \(error.localizedDescription)" + return + } + + let (success, message) = await appState.chatClient.login(email: email, password: password) + isLoading = false + + if success { + appState.email = email + appState.isLoggedIn = true + appState.connectionStatus = .connected + if let userId = await appState.chatClient.userId { + appState.currentUser = User(id: userId, username: await appState.chatClient.username, email: email) + } + } else { + errorMessage = message + } + } + + func register(appState: AppState) async { + guard !email.isEmpty, !password.isEmpty, !username.isEmpty else { + errorMessage = "All fields are required" + return + } + guard password == confirmPassword else { + errorMessage = "Passwords don't match" + return + } + + isLoading = true + errorMessage = nil + + do { + let port = UInt16(serverPort) ?? Constants.defaultPort + try await appState.chatClient.connect(host: serverHost, port: port, useTLS: useTLS) + } catch { + isLoading = false + errorMessage = "Connection failed: \(error.localizedDescription)" + return + } + + let (success, message) = await appState.chatClient.register(username: username, password: password, email: email) + isLoading = false + + if success { + registrationMessage = message + showConfirmation = true + } else { + errorMessage = message + } + } + + func confirmRegistration(appState: AppState) async { + guard !confirmationCode.isEmpty else { + errorMessage = "Enter the confirmation code" + return + } + + isLoading = true + errorMessage = nil + + let (success, message) = await appState.chatClient.confirmRegistration( + email: email, username: username, code: confirmationCode + ) + isLoading = false + + if success { + registrationMessage = message + // Auto-login after registration + await login(appState: appState) + } else { + errorMessage = message + } + } +} diff --git a/ios_client/EncryptedChat/ViewModels/ChatViewModel.swift b/ios_client/EncryptedChat/ViewModels/ChatViewModel.swift new file mode 100644 index 0000000..e47dfef --- /dev/null +++ b/ios_client/EncryptedChat/ViewModels/ChatViewModel.swift @@ -0,0 +1,131 @@ +import Foundation +import SwiftUI + +@Observable +final class ChatViewModel { + var messages: [Message] = [] + var isLoading = false + var isSending = false + var errorMessage: String? + var searchQuery = "" + var searchResults: [String] = [] // message IDs matching search + var currentSearchIndex = 0 + + private var notificationTask: Task? + + func loadMessages(convId: String, chatClient: ChatClient) async { + isLoading = true + messages = await chatClient.getMessages(convId: convId, limit: 50) + isLoading = false + + // Mark as read + let unreadIds = messages.filter { !$0.isMine(currentUserId: await chatClient.userId ?? "") }.map(\.id) + if !unreadIds.isEmpty { + await chatClient.markRead(convId: convId, messageIds: unreadIds) + } + } + + func loadOlderMessages(convId: String, chatClient: ChatClient) async { + let older = await chatClient.getMessages(convId: convId, limit: 50, offset: messages.count) + messages.insert(contentsOf: older, at: 0) + } + + func sendMessage(convId: String, text: String, members: [ConversationMember], + chatClient: ChatClient, replyTo: String? = nil) async { + guard !text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return } + + isSending = true + errorMessage = nil + + let (success, msg) = await chatClient.sendMessage( + convId: convId, text: text, members: members, replyTo: replyTo + ) + + isSending = false + + if !success { + errorMessage = msg + } else { + // Reload messages to get the sent message + await loadMessages(convId: convId, chatClient: chatClient) + } + } + + func deleteMessage(messageId: String, convId: String, chatClient: ChatClient) async { + let success = await chatClient.deleteMessage(messageId: messageId, convId: convId) + if success { + messages.removeAll { $0.id == messageId } + } + } + + func search(query: String) { + searchQuery = query + if query.isEmpty { + searchResults = [] + currentSearchIndex = 0 + return + } + let lower = query.lowercased() + searchResults = messages.filter { $0.text?.lowercased().contains(lower) == true }.map(\.id) + currentSearchIndex = searchResults.isEmpty ? 0 : searchResults.count - 1 + } + + func nextSearchResult() { + guard !searchResults.isEmpty else { return } + currentSearchIndex = (currentSearchIndex + 1) % searchResults.count + } + + func prevSearchResult() { + guard !searchResults.isEmpty else { return } + currentSearchIndex = (currentSearchIndex - 1 + searchResults.count) % searchResults.count + } + + func startNotificationListener(convId: String, chatClient: ChatClient) { + notificationTask?.cancel() + notificationTask = Task { + for await notification in await chatClient.notifications { + await handleNotification(notification, convId: convId, chatClient: chatClient) + } + } + } + + @MainActor + private func handleNotification(_ notification: ChatNotification, convId: String, chatClient: ChatClient) { + switch notification { + case .newMessage(let data): + if data["conversation_id"] as? String == convId { + if let msg = Task.detached(priority: .userInitiated, operation: { + await chatClient.decryptNotification(data) + }) as? Task { + Task { + if let message = await msg.value { + messages.append(message) + // Mark as read immediately since we're viewing this conv + await chatClient.markRead(convId: convId, messageIds: [message.id]) + } + } + } + } + case .messageDeleted(let data): + if let msgId = data["message_id"] as? String { + messages.removeAll { $0.id == msgId } + } + case .messagesRead(let data): + if let readUserId = data["user_id"] as? String, + let msgIds = data["message_ids"] as? [String] { + for i in messages.indices { + if msgIds.contains(messages[i].id) { + messages[i].readBy.insert(readUserId) + } + } + } + default: + break + } + } + + func stop() { + notificationTask?.cancel() + notificationTask = nil + } +} diff --git a/ios_client/EncryptedChat/ViewModels/ConversationListVM.swift b/ios_client/EncryptedChat/ViewModels/ConversationListVM.swift new file mode 100644 index 0000000..82e93ed --- /dev/null +++ b/ios_client/EncryptedChat/ViewModels/ConversationListVM.swift @@ -0,0 +1,127 @@ +import Foundation +import SwiftUI + +@Observable +final class ConversationListVM { + var conversations: [Conversation] = [] + var invitations: [Invitation] = [] + var onlineUsers: Set = [] + var unreadCounts: [String: Int] = [:] + var favorites: Set = [] + var isLoading = false + + private var notificationTask: Task? + + func load(chatClient: ChatClient, email: String) async { + isLoading = true + + // Load favorites from disk + favorites = KeyStorage.loadFavorites(email: email) + + // Fetch conversations + let convs = await chatClient.listConversations() + conversations = sortConversations(convs, currentUserId: await chatClient.userId ?? "") + + // Populate unread counts from server + for conv in conversations where conv.unreadCount > 0 { + unreadCounts[conv.id] = conv.unreadCount + } + + // Fetch invitations + invitations = await chatClient.listInvitations() + + isLoading = false + + // Start notification listener + startNotificationListener(chatClient: chatClient, email: email) + } + + func refresh(chatClient: ChatClient) async { + let convs = await chatClient.listConversations() + conversations = sortConversations(convs, currentUserId: await chatClient.userId ?? "") + invitations = await chatClient.listInvitations() + } + + func toggleFavorite(convId: String, email: String) { + if favorites.contains(convId) { + favorites.remove(convId) + } else { + favorites.insert(convId) + } + try? KeyStorage.saveFavorites(email: email, favorites: favorites) + + // Re-sort + let userId = conversations.first?.createdBy ?? "" + conversations = sortConversations(conversations, currentUserId: userId) + } + + func markConversationRead(convId: String) { + unreadCounts[convId] = 0 + } + + func incrementUnread(convId: String) { + unreadCounts[convId, default: 0] += 1 + } + + private func sortConversations(_ convs: [Conversation], currentUserId: String) -> [Conversation] { + var result = convs.map { conv -> Conversation in + var c = conv + c.isFavorite = favorites.contains(conv.id) + c.unreadCount = unreadCounts[conv.id] ?? conv.unreadCount + return c + } + + result.sort { a, b in + // Favorites first + if a.isFavorite != b.isFavorite { return a.isFavorite } + // Online DMs next + let aOnline = a.dmPartnerId(currentUserId: currentUserId).map { onlineUsers.contains($0) } ?? false + let bOnline = b.dmPartnerId(currentUserId: currentUserId).map { onlineUsers.contains($0) } ?? false + if aOnline != bOnline { return aOnline } + // Alphabetical + return a.displayName(currentUserId: currentUserId).lowercased() < b.displayName(currentUserId: currentUserId).lowercased() + } + + return result + } + + private func startNotificationListener(chatClient: ChatClient, email: String) { + notificationTask?.cancel() + notificationTask = Task { + for await notification in await chatClient.notifications { + await handleNotification(notification, chatClient: chatClient, email: email) + } + } + } + + @MainActor + private func handleNotification(_ notification: ChatNotification, chatClient: ChatClient, email: String) { + switch notification { + case .newMessage(let data): + if let convId = data["conversation_id"] as? String { + incrementUnread(convId: convId) + } + case .onlineUsers(let userIds): + onlineUsers = Set(userIds) + case .userOnline(let userId): + onlineUsers.insert(userId) + case .userOffline(let userId): + onlineUsers.remove(userId) + case .conversationCreated, .memberAdded, .memberRemoved, .conversationRenamed: + Task { await refresh(chatClient: chatClient) } + case .groupInvitation: + Task { invitations = await chatClient.listInvitations() } + case .connectionStateChanged(let connected): + if !connected { + // Could trigger auto-reconnect here + } + default: + break + } + } + + func stop() { + notificationTask?.cancel() + notificationTask = nil + } +} diff --git a/ios_client/EncryptedChat/ViewModels/ProfileViewModel.swift b/ios_client/EncryptedChat/ViewModels/ProfileViewModel.swift new file mode 100644 index 0000000..2408919 --- /dev/null +++ b/ios_client/EncryptedChat/ViewModels/ProfileViewModel.swift @@ -0,0 +1,66 @@ +import Foundation +import SwiftUI + +@Observable +final class ProfileViewModel { + var profile: UserProfile? + var avatarData: Data? + var isLoading = false + var isSaving = false + var errorMessage: String? + + // Editable fields + var phone = "" + var phoneVisible = false + var location = "" + var locationVisible = false + + func loadProfile(userId: String? = nil, chatClient: ChatClient) async { + isLoading = true + profile = await chatClient.getProfile(userId: userId) + isLoading = false + + if let p = profile { + phone = p.phone ?? "" + phoneVisible = p.phoneVisible + location = p.location ?? "" + locationVisible = p.locationVisible + } + + // Load avatar + let uid = userId ?? await chatClient.userId ?? "" + if !uid.isEmpty { + avatarData = await chatClient.getAvatar(userId: uid) + } + } + + func saveProfile(chatClient: ChatClient) async { + isSaving = true + errorMessage = nil + + let success = await chatClient.updateProfile( + phone: phone.isEmpty ? nil : phone, + phoneVisible: phoneVisible, + location: location.isEmpty ? nil : location, + locationVisible: locationVisible + ) + + isSaving = false + + if !success { + errorMessage = "Failed to update profile" + } + } + + func uploadAvatar(imageData: Data, chatClient: ChatClient) async { + isSaving = true + let success = await chatClient.updateAvatar(imageData: imageData) + isSaving = false + + if success { + avatarData = imageData + } else { + errorMessage = "Failed to upload avatar" + } + } +} diff --git a/ios_client/EncryptedChat/Views/Auth/LoginView.swift b/ios_client/EncryptedChat/Views/Auth/LoginView.swift new file mode 100644 index 0000000..4b8b839 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Auth/LoginView.swift @@ -0,0 +1,134 @@ +import SwiftUI + +struct LoginView: View { + @Bindable var viewModel: AuthViewModel + var appState: AppState + + var body: some View { + NavigationStack { + ScrollView { + VStack(spacing: 20) { + Image(systemName: "lock.shield.fill") + .font(.system(size: 60)) + .foregroundStyle(.blue) + .padding(.top, 40) + + Text("Encrypted Chat") + .font(.largeTitle.bold()) + + Text("End-to-end encrypted messaging") + .font(.subheadline) + .foregroundStyle(.secondary) + + VStack(spacing: 16) { + // Server config + DisclosureGroup("Server") { + TextField("Host", text: $viewModel.serverHost) + .textContentType(.URL) + .autocapitalization(.none) + TextField("Port", text: $viewModel.serverPort) + .keyboardType(.numberPad) + Toggle("Use TLS", isOn: $viewModel.useTLS) + } + .padding(.horizontal) + + TextField("Email", text: $viewModel.email) + .textContentType(.emailAddress) + .keyboardType(.emailAddress) + .autocapitalization(.none) + .textFieldStyle(.roundedBorder) + + SecureField("Password", text: $viewModel.password) + .textContentType(.password) + .textFieldStyle(.roundedBorder) + + if viewModel.mode == .register { + TextField("Username", text: $viewModel.username) + .textContentType(.username) + .autocapitalization(.none) + .textFieldStyle(.roundedBorder) + + SecureField("Confirm Password", text: $viewModel.confirmPassword) + .textFieldStyle(.roundedBorder) + } + + if let error = viewModel.errorMessage { + Text(error) + .foregroundStyle(.red) + .font(.caption) + .multilineTextAlignment(.center) + } + + Button(action: { + Task { + if viewModel.mode == .login { + await viewModel.login(appState: appState) + } else { + await viewModel.register(appState: appState) + } + } + }) { + if viewModel.isLoading { + ProgressView() + .frame(maxWidth: .infinity) + } else { + Text(viewModel.mode == .login ? "Login" : "Register") + .frame(maxWidth: .infinity) + } + } + .buttonStyle(.borderedProminent) + .disabled(viewModel.isLoading) + + Button(viewModel.mode == .login ? "Don't have an account? Register" : "Already have an account? Login") { + viewModel.mode = viewModel.mode == .login ? .register : .login + viewModel.errorMessage = nil + } + .font(.caption) + } + .padding(.horizontal, 32) + } + } + .sheet(isPresented: $viewModel.showConfirmation) { + ConfirmationSheet(viewModel: viewModel, appState: appState) + } + } + } +} + +struct ConfirmationSheet: View { + @Bindable var viewModel: AuthViewModel + var appState: AppState + + var body: some View { + VStack(spacing: 20) { + Text("Confirm Registration") + .font(.title2.bold()) + + if let msg = viewModel.registrationMessage { + Text(msg) + .font(.subheadline) + .foregroundStyle(.secondary) + .multilineTextAlignment(.center) + } + + TextField("Confirmation Code", text: $viewModel.confirmationCode) + .textFieldStyle(.roundedBorder) + .keyboardType(.numberPad) + + if let error = viewModel.errorMessage { + Text(error) + .foregroundStyle(.red) + .font(.caption) + } + + Button("Confirm") { + Task { + await viewModel.confirmRegistration(appState: appState) + } + } + .buttonStyle(.borderedProminent) + .disabled(viewModel.isLoading) + } + .padding(32) + } +} diff --git a/ios_client/EncryptedChat/Views/Auth/PairingView.swift b/ios_client/EncryptedChat/Views/Auth/PairingView.swift new file mode 100644 index 0000000..043fb5c --- /dev/null +++ b/ios_client/EncryptedChat/Views/Auth/PairingView.swift @@ -0,0 +1,49 @@ +import SwiftUI + +struct PairingView: View { + var appState: AppState + @State private var pairingCode = "" + @State private var isWaiting = false + @State private var statusMessage: String? + + var body: some View { + VStack(spacing: 24) { + Image(systemName: "iphone.and.arrow.forward") + .font(.system(size: 48)) + .foregroundStyle(.blue) + + Text("Device Pairing") + .font(.title2.bold()) + + Text("Enter the 8-digit pairing code shown on your other device.") + .font(.subheadline) + .foregroundStyle(.secondary) + .multilineTextAlignment(.center) + + TextField("Pairing Code", text: $pairingCode) + .textFieldStyle(.roundedBorder) + .keyboardType(.numberPad) + .frame(maxWidth: 200) + + if let status = statusMessage { + Text(status) + .font(.caption) + .foregroundStyle(status.contains("Error") ? .red : .secondary) + } + + if isWaiting { + ProgressView("Waiting for authorization...") + } + + Button("Start Pairing") { + Task { + // Pairing implementation would go here + statusMessage = "Pairing not yet implemented" + } + } + .buttonStyle(.borderedProminent) + .disabled(pairingCode.count != 8 || isWaiting) + } + .padding(32) + } +} diff --git a/ios_client/EncryptedChat/Views/Auth/RegisterView.swift b/ios_client/EncryptedChat/Views/Auth/RegisterView.swift new file mode 100644 index 0000000..04c0b57 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Auth/RegisterView.swift @@ -0,0 +1,4 @@ +import SwiftUI + +// Registration is handled within LoginView via mode toggle. +// This file exists for potential future separation. diff --git a/ios_client/EncryptedChat/Views/Chat/ChatView.swift b/ios_client/EncryptedChat/Views/Chat/ChatView.swift new file mode 100644 index 0000000..f1671e2 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Chat/ChatView.swift @@ -0,0 +1,164 @@ +import SwiftUI + +struct ChatView: View { + let conversation: Conversation + var appState: AppState + @State private var viewModel = ChatViewModel() + @State private var inputText = "" + @State private var replyTo: Message? + @State private var showGroupInfo = false + @State private var showSearch = false + @State private var showDeleteConfirm = false + + var body: some View { + VStack(spacing: 0) { + // Search bar + if showSearch { + SearchOverlayView( + query: $viewModel.searchQuery, + matchCount: viewModel.searchResults.count, + currentIndex: viewModel.currentSearchIndex, + onSearch: { viewModel.search(query: $0) }, + onNext: { viewModel.nextSearchResult() }, + onPrev: { viewModel.prevSearchResult() }, + onClose: { showSearch = false; viewModel.search(query: "") } + ) + } + + // Messages + ScrollViewReader { proxy in + ScrollView { + LazyVStack(spacing: 8) { + if viewModel.messages.count >= 50 { + Button("Load older messages") { + Task { + await viewModel.loadOlderMessages(convId: conversation.id, chatClient: appState.chatClient) + } + } + .font(.caption) + .padding() + } + + ForEach(viewModel.messages) { message in + MessageBubbleView( + message: message, + isMine: message.isMine(currentUserId: appState.currentUser?.id ?? ""), + isHighlighted: viewModel.searchResults.contains(message.id), + isCurrentSearchResult: viewModel.searchResults.indices.contains(viewModel.currentSearchIndex) && + viewModel.searchResults[viewModel.currentSearchIndex] == message.id, + onReply: { replyTo = message }, + onDelete: { + Task { + await viewModel.deleteMessage(messageId: message.id, convId: conversation.id, chatClient: appState.chatClient) + } + } + ) + .id(message.id) + } + } + .padding(.horizontal) + .padding(.vertical, 8) + } + .onChange(of: viewModel.messages.count) { + if let lastId = viewModel.messages.last?.id { + withAnimation { + proxy.scrollTo(lastId, anchor: .bottom) + } + } + } + } + + // Reply preview + if let reply = replyTo { + HStack { + Rectangle() + .fill(.blue) + .frame(width: 3) + VStack(alignment: .leading) { + Text(reply.senderUsername) + .font(.caption.bold()) + Text(reply.text ?? "") + .font(.caption) + .lineLimit(1) + } + Spacer() + Button(action: { replyTo = nil }) { + Image(systemName: "xmark.circle.fill") + .foregroundStyle(.secondary) + } + } + .padding(.horizontal) + .padding(.vertical, 6) + .background(.ultraThinMaterial) + } + + // Input + MessageInputView( + text: $inputText, + isSending: viewModel.isSending, + onSend: { + Task { + let text = inputText + inputText = "" + let reply = replyTo?.id + replyTo = nil + await viewModel.sendMessage( + convId: conversation.id, + text: text, + members: conversation.members, + chatClient: appState.chatClient, + replyTo: reply + ) + } + } + ) + } + .navigationTitle(conversation.displayName(currentUserId: appState.currentUser?.id ?? "")) + .navigationBarTitleDisplayMode(.inline) + .toolbar { + ToolbarItem(placement: .topBarTrailing) { + HStack(spacing: 16) { + Button(action: { showSearch.toggle() }) { + Image(systemName: "magnifyingglass") + } + + if conversation.isGroup { + Button(action: { showGroupInfo = true }) { + Image(systemName: "info.circle") + } + } + + // Delete button + if !conversation.isGroup || conversation.createdBy == appState.currentUser?.id { + Button(action: { showDeleteConfirm = true }) { + Image(systemName: "trash") + .foregroundStyle(.red) + } + } + } + } + } + .alert("Delete Conversation?", isPresented: $showDeleteConfirm) { + Button("Cancel", role: .cancel) {} + Button("Delete", role: .destructive) { + Task { + await appState.chatClient.deleteConversation(convId: conversation.id) + } + } + } message: { + Text(conversation.isGroup + ? "This will remove all members and delete the conversation." + : "This will remove you from the conversation.") + } + .sheet(isPresented: $showGroupInfo) { + GroupInfoView(conversation: conversation, appState: appState) + } + .task { + await viewModel.loadMessages(convId: conversation.id, chatClient: appState.chatClient) + viewModel.startNotificationListener(convId: conversation.id, chatClient: appState.chatClient) + } + .onDisappear { + viewModel.stop() + } + } +} diff --git a/ios_client/EncryptedChat/Views/Chat/ImageViewerView.swift b/ios_client/EncryptedChat/Views/Chat/ImageViewerView.swift new file mode 100644 index 0000000..0e35cfd --- /dev/null +++ b/ios_client/EncryptedChat/Views/Chat/ImageViewerView.swift @@ -0,0 +1,43 @@ +import SwiftUI + +struct ImageViewerView: View { + let imageData: Data + @State private var scale: CGFloat = 1.0 + @Environment(\.dismiss) private var dismiss + + var body: some View { + NavigationStack { + GeometryReader { geo in + if let uiImage = UIImage(data: imageData) { + Image(uiImage: uiImage) + .resizable() + .aspectRatio(contentMode: .fit) + .scaleEffect(scale) + .gesture( + MagnifyGesture() + .onChanged { value in + scale = value.magnification + } + .onEnded { _ in + withAnimation { + scale = max(1.0, min(scale, 5.0)) + } + } + ) + .onTapGesture(count: 2) { + withAnimation { + scale = scale > 1 ? 1 : 2 + } + } + .frame(width: geo.size.width, height: geo.size.height) + } + } + .toolbar { + ToolbarItem(placement: .topBarTrailing) { + Button("Done") { dismiss() } + } + } + .background(.black) + } + } +} diff --git a/ios_client/EncryptedChat/Views/Chat/MessageBubbleView.swift b/ios_client/EncryptedChat/Views/Chat/MessageBubbleView.swift new file mode 100644 index 0000000..e883b56 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Chat/MessageBubbleView.swift @@ -0,0 +1,123 @@ +import SwiftUI + +struct MessageBubbleView: View { + let message: Message + let isMine: Bool + var isHighlighted: Bool = false + var isCurrentSearchResult: Bool = false + var onReply: (() -> Void)? + var onDelete: (() -> Void)? + + var body: some View { + HStack { + if isMine { Spacer(minLength: 60) } + + VStack(alignment: isMine ? .trailing : .leading, spacing: 4) { + if !isMine { + Text(message.senderUsername) + .font(.caption.bold()) + .foregroundStyle(.secondary) + } + + if message.isDeleted { + Text("Message deleted") + .font(.body.italic()) + .foregroundStyle(.secondary) + .padding(12) + .background(Color(.systemGray6)) + .clipShape(RoundedRectangle(cornerRadius: 16)) + } else { + // Reply reference + if let replyTo = message.replyTo { + HStack(spacing: 4) { + Rectangle() + .fill(.blue.opacity(0.5)) + .frame(width: 2) + Text("Reply to message") + .font(.caption) + .foregroundStyle(.secondary) + } + .padding(.horizontal, 8) + } + + // File card + if let file = message.file { + VStack(alignment: .leading, spacing: 4) { + HStack { + Image(systemName: "paperclip") + Text(file.filename) + .lineLimit(1) + } + .font(.subheadline) + + Text(formatFileSize(file.size)) + .font(.caption) + .foregroundStyle(.secondary) + } + .padding(12) + .background(Color(.systemGray5)) + .clipShape(RoundedRectangle(cornerRadius: 12)) + } + + // Text content + if let text = message.text { + Text(text) + .padding(12) + .background( + isMine ? Color.blue : Color(.systemGray5) + ) + .foregroundStyle(isMine ? .white : .primary) + .clipShape(RoundedRectangle(cornerRadius: 16)) + } + + // Timestamp + Text(formatTime(message.createdAt)) + .font(.caption2) + .foregroundStyle(.secondary) + } + } + .background( + isCurrentSearchResult ? Color.orange.opacity(0.3) : + isHighlighted ? Color.yellow.opacity(0.2) : Color.clear + ) + .clipShape(RoundedRectangle(cornerRadius: 16)) + .contextMenu { + if !message.isDeleted { + Button(action: { onReply?() }) { + Label("Reply", systemImage: "arrowshape.turn.up.left") + } + + Button(action: { + UIPasteboard.general.string = message.text ?? "" + }) { + Label("Copy", systemImage: "doc.on.doc") + } + + if isMine { + Button(role: .destructive, action: { onDelete?() }) { + Label("Delete", systemImage: "trash") + } + } + } + } + + if !isMine { Spacer(minLength: 60) } + } + } + + private func formatTime(_ date: Date) -> String { + let formatter = DateFormatter() + if Calendar.current.isDateInToday(date) { + formatter.dateFormat = "HH:mm" + } else { + formatter.dateFormat = "MMM d, HH:mm" + } + return formatter.string(from: date) + } + + private func formatFileSize(_ bytes: Int) -> String { + if bytes < 1024 { return "\(bytes) B" } + if bytes < 1024 * 1024 { return "\(bytes / 1024) KB" } + return String(format: "%.1f MB", Double(bytes) / (1024 * 1024)) + } +} diff --git a/ios_client/EncryptedChat/Views/Chat/MessageInputView.swift b/ios_client/EncryptedChat/Views/Chat/MessageInputView.swift new file mode 100644 index 0000000..de53335 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Chat/MessageInputView.swift @@ -0,0 +1,55 @@ +import SwiftUI +import PhotosUI + +struct MessageInputView: View { + @Binding var text: String + let isSending: Bool + let onSend: () -> Void + + @State private var showAttachMenu = false + @State private var selectedPhoto: PhotosPickerItem? + + var body: some View { + HStack(spacing: 8) { + // Attach button + Menu { + Button(action: {}) { + Label("Photo", systemImage: "photo") + } + Button(action: {}) { + Label("File", systemImage: "doc") + } + } label: { + Image(systemName: "plus.circle.fill") + .font(.title2) + .foregroundStyle(.blue) + } + + // Text field + TextField("Message", text: $text, axis: .vertical) + .textFieldStyle(.roundedBorder) + .lineLimit(1...5) + .onSubmit { + if !text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + onSend() + } + } + + // Send button + Button(action: onSend) { + if isSending { + ProgressView() + .scaleEffect(0.8) + } else { + Image(systemName: "arrow.up.circle.fill") + .font(.title2) + .foregroundStyle(.blue) + } + } + .disabled(text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty || isSending) + } + .padding(.horizontal) + .padding(.vertical, 8) + .background(.ultraThinMaterial) + } +} diff --git a/ios_client/EncryptedChat/Views/Chat/SearchOverlayView.swift b/ios_client/EncryptedChat/Views/Chat/SearchOverlayView.swift new file mode 100644 index 0000000..98bc3b4 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Chat/SearchOverlayView.swift @@ -0,0 +1,46 @@ +import SwiftUI + +struct SearchOverlayView: View { + @Binding var query: String + let matchCount: Int + let currentIndex: Int + let onSearch: (String) -> Void + let onNext: () -> Void + let onPrev: () -> Void + let onClose: () -> Void + + var body: some View { + HStack(spacing: 8) { + Image(systemName: "magnifyingglass") + .foregroundStyle(.secondary) + + TextField("Search messages", text: $query) + .textFieldStyle(.roundedBorder) + .onChange(of: query) { _, newValue in + onSearch(newValue) + } + + if matchCount > 0 { + Text("\(currentIndex + 1)/\(matchCount)") + .font(.caption) + .foregroundStyle(.secondary) + .fixedSize() + + Button(action: onPrev) { + Image(systemName: "chevron.up") + } + Button(action: onNext) { + Image(systemName: "chevron.down") + } + } + + Button(action: onClose) { + Image(systemName: "xmark.circle.fill") + .foregroundStyle(.secondary) + } + } + .padding(.horizontal) + .padding(.vertical, 6) + .background(.ultraThinMaterial) + } +} diff --git a/ios_client/EncryptedChat/Views/Components/CircularAvatarView.swift b/ios_client/EncryptedChat/Views/Components/CircularAvatarView.swift new file mode 100644 index 0000000..787a349 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Components/CircularAvatarView.swift @@ -0,0 +1,46 @@ +import SwiftUI + +struct CircularAvatarView: View { + let name: String + var imageData: Data? + var size: CGFloat = 32 + var isGroup: Bool = false + + var body: some View { + if let imageData = imageData, let uiImage = UIImage(data: imageData) { + Image(uiImage: uiImage) + .resizable() + .aspectRatio(contentMode: .fill) + .frame(width: size, height: size) + .clipShape(Circle()) + } else { + // Default: colored circle with initial letter + ZStack { + Circle() + .fill(avatarColor) + .frame(width: size, height: size) + + Text(initial) + .font(.system(size: size * 0.4, weight: .semibold)) + .foregroundStyle(.white) + } + } + } + + private var initial: String { + String(name.prefix(1)).uppercased() + } + + /// Deterministic color from name hash (matching Python gui_client behavior) + private var avatarColor: Color { + let colors: [Color] = [ + .red, .orange, .yellow, .green, .mint, + .teal, .cyan, .blue, .indigo, .purple, .pink + ] + var hash = 0 + for char in name.unicodeScalars { + hash = hash &* 31 &+ Int(char.value) + } + return colors[abs(hash) % colors.count] + } +} diff --git a/ios_client/EncryptedChat/Views/Components/ConnectionIndicator.swift b/ios_client/EncryptedChat/Views/Components/ConnectionIndicator.swift new file mode 100644 index 0000000..4907e63 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Components/ConnectionIndicator.swift @@ -0,0 +1,35 @@ +import SwiftUI + +struct ConnectionIndicator: View { + let status: ConnectionStatus + + var body: some View { + HStack(spacing: 4) { + Circle() + .fill(statusColor) + .frame(width: 8, height: 8) + + if status != .connected { + Text(statusText) + .font(.caption2) + .foregroundStyle(.secondary) + } + } + } + + private var statusColor: Color { + switch status { + case .connected: return .green + case .connecting: return .orange + case .disconnected: return .red + } + } + + private var statusText: String { + switch status { + case .connected: return "" + case .connecting: return "Connecting..." + case .disconnected: return "Disconnected" + } + } +} diff --git a/ios_client/EncryptedChat/Views/Components/OnlineDotOverlay.swift b/ios_client/EncryptedChat/Views/Components/OnlineDotOverlay.swift new file mode 100644 index 0000000..4a2b645 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Components/OnlineDotOverlay.swift @@ -0,0 +1,15 @@ +import SwiftUI + +struct OnlineDotOverlay: View { + var size: CGFloat = 12 + + var body: some View { + Circle() + .fill(.green) + .frame(width: size, height: size) + .overlay( + Circle() + .stroke(.white, lineWidth: 2) + ) + } +} diff --git a/ios_client/EncryptedChat/Views/Conversations/ConversationListView.swift b/ios_client/EncryptedChat/Views/Conversations/ConversationListView.swift new file mode 100644 index 0000000..d387128 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Conversations/ConversationListView.swift @@ -0,0 +1,99 @@ +import SwiftUI + +struct ConversationListView: View { + var appState: AppState + @Bindable var viewModel: ConversationListVM + @State private var showNewConversation = false + @State private var showProfile = false + @State private var selectedConversation: Conversation? + + var body: some View { + NavigationStack { + List { + // Invitations section + if !viewModel.invitations.isEmpty { + Section { + ForEach(viewModel.invitations) { invitation in + InvitationBanner( + invitation: invitation, + onAccept: { + Task { + let (success, _) = await appState.chatClient.acceptInvitation(convId: invitation.conversationId) + if success { + await viewModel.refresh(chatClient: appState.chatClient) + } + } + }, + onDecline: { + Task { + await appState.chatClient.declineInvitation(convId: invitation.conversationId) + await viewModel.refresh(chatClient: appState.chatClient) + } + } + ) + } + } header: { + Text("Invitations") + } + } + + // Conversations section + Section { + ForEach(viewModel.conversations) { conversation in + NavigationLink(value: conversation) { + ConversationRowView( + conversation: conversation, + currentUserId: appState.currentUser?.id ?? "", + isOnline: conversation.dmPartnerId(currentUserId: appState.currentUser?.id ?? "") + .map { viewModel.onlineUsers.contains($0) } ?? false, + unreadCount: viewModel.unreadCounts[conversation.id] ?? 0 + ) + } + .contextMenu { + Button(conversation.isFavorite ? "Remove from Favorites" : "Add to Favorites") { + viewModel.toggleFavorite(convId: conversation.id, email: appState.email) + } + } + } + } + } + .navigationTitle("Chats") + .navigationDestination(for: Conversation.self) { conversation in + ChatView( + conversation: conversation, + appState: appState + ) + } + .toolbar { + ToolbarItem(placement: .topBarLeading) { + ConnectionIndicator(status: appState.connectionStatus) + } + ToolbarItem(placement: .topBarTrailing) { + HStack { + Button(action: { showProfile = true }) { + Image(systemName: "person.circle") + } + Button(action: { showNewConversation = true }) { + Image(systemName: "square.and.pencil") + } + } + } + } + .refreshable { + await viewModel.refresh(chatClient: appState.chatClient) + } + .sheet(isPresented: $showNewConversation) { + NewConversationSheet(appState: appState) { convId in + showNewConversation = false + await viewModel.refresh(chatClient: appState.chatClient) + } + } + .sheet(isPresented: $showProfile) { + ProfileView(appState: appState, isOwnProfile: true) + } + .task { + await viewModel.load(chatClient: appState.chatClient, email: appState.email) + } + } + } +} diff --git a/ios_client/EncryptedChat/Views/Conversations/ConversationRowView.swift b/ios_client/EncryptedChat/Views/Conversations/ConversationRowView.swift new file mode 100644 index 0000000..2708773 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Conversations/ConversationRowView.swift @@ -0,0 +1,58 @@ +import SwiftUI + +struct ConversationRowView: View { + let conversation: Conversation + let currentUserId: String + let isOnline: Bool + let unreadCount: Int + + var body: some View { + HStack(spacing: 12) { + // Avatar + ZStack(alignment: .bottomTrailing) { + CircularAvatarView( + name: conversation.displayName(currentUserId: currentUserId), + size: 44, + isGroup: conversation.isGroup + ) + + if isOnline && !conversation.isGroup { + OnlineDotOverlay(size: 12) + } + } + + VStack(alignment: .leading, spacing: 2) { + HStack { + if conversation.isFavorite { + Image(systemName: "star.fill") + .font(.caption2) + .foregroundStyle(.yellow) + } + + Text(conversation.displayName(currentUserId: currentUserId)) + .font(unreadCount > 0 ? .body.bold() : .body) + .lineLimit(1) + } + + if conversation.isGroup { + Text("\(conversation.members.count) members") + .font(.caption) + .foregroundStyle(.secondary) + } + } + + Spacer() + + if unreadCount > 0 { + Text("\(unreadCount)") + .font(.caption2.bold()) + .foregroundStyle(.white) + .padding(.horizontal, 8) + .padding(.vertical, 2) + .background(Color.blue) + .clipShape(Capsule()) + } + } + .padding(.vertical, 4) + } +} diff --git a/ios_client/EncryptedChat/Views/Conversations/NewConversationSheet.swift b/ios_client/EncryptedChat/Views/Conversations/NewConversationSheet.swift new file mode 100644 index 0000000..c0ffb31 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Conversations/NewConversationSheet.swift @@ -0,0 +1,100 @@ +import SwiftUI + +struct NewConversationSheet: View { + var appState: AppState + var onCreated: (String) async -> Void + + @State private var email = "" + @State private var groupName = "" + @State private var isGroup = false + @State private var memberEmails: [String] = [""] + @State private var isLoading = false + @State private var errorMessage: String? + @Environment(\.dismiss) private var dismiss + + var body: some View { + NavigationStack { + Form { + Section { + Toggle("Create Group", isOn: $isGroup) + + if isGroup { + TextField("Group Name", text: $groupName) + } + } + + Section(isGroup ? "Members" : "Recipient") { + if isGroup { + ForEach(memberEmails.indices, id: \.self) { index in + TextField("Email", text: $memberEmails[index]) + .textContentType(.emailAddress) + .keyboardType(.emailAddress) + .autocapitalization(.none) + } + Button("Add Member") { + memberEmails.append("") + } + } else { + TextField("Email", text: $email) + .textContentType(.emailAddress) + .keyboardType(.emailAddress) + .autocapitalization(.none) + } + } + + if let error = errorMessage { + Section { + Text(error) + .foregroundStyle(.red) + } + } + } + .navigationTitle("New Conversation") + .navigationBarTitleDisplayMode(.inline) + .toolbar { + ToolbarItem(placement: .cancellationAction) { + Button("Cancel") { dismiss() } + } + ToolbarItem(placement: .confirmationAction) { + Button("Create") { + Task { await create() } + } + .disabled(isLoading) + } + } + } + } + + private func create() async { + isLoading = true + errorMessage = nil + + let emails: [String] + if isGroup { + emails = memberEmails.map { $0.trimmed }.filter { !$0.isEmpty } + guard !emails.isEmpty else { + errorMessage = "Add at least one member" + isLoading = false + return + } + } else { + guard !email.trimmed.isEmpty else { + errorMessage = "Enter an email address" + isLoading = false + return + } + emails = [email.trimmed] + } + + let name = isGroup && !groupName.trimmed.isEmpty ? groupName.trimmed : nil + let (convId, message) = await appState.chatClient.createConversation(emails: emails, name: name) + + isLoading = false + + if let convId = convId { + await onCreated(convId) + } else { + errorMessage = message + } + } +} diff --git a/ios_client/EncryptedChat/Views/Groups/CreateGroupSheet.swift b/ios_client/EncryptedChat/Views/Groups/CreateGroupSheet.swift new file mode 100644 index 0000000..cbeaa75 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Groups/CreateGroupSheet.swift @@ -0,0 +1,4 @@ +import SwiftUI + +// Group creation is handled within NewConversationSheet via the isGroup toggle. +// This file exists for potential future separation. diff --git a/ios_client/EncryptedChat/Views/Groups/GroupInfoView.swift b/ios_client/EncryptedChat/Views/Groups/GroupInfoView.swift new file mode 100644 index 0000000..c9eb0c4 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Groups/GroupInfoView.swift @@ -0,0 +1,123 @@ +import SwiftUI + +struct GroupInfoView: View { + let conversation: Conversation + var appState: AppState + @State private var showRenameSheet = false + @State private var showLeaveConfirm = false + @State private var newName = "" + @Environment(\.dismiss) private var dismiss + + private var isCreator: Bool { + conversation.createdBy == appState.currentUser?.id + } + + var body: some View { + NavigationStack { + List { + // Avatar section + Section { + HStack { + Spacer() + VStack(spacing: 8) { + CircularAvatarView( + name: conversation.name ?? "Group", + size: 64, + isGroup: true + ) + + Text(conversation.name ?? "Group") + .font(.title2.bold()) + + Text("\(conversation.members.count) members") + .font(.subheadline) + .foregroundStyle(.secondary) + } + Spacer() + } + .listRowBackground(Color.clear) + } + + // Actions + if isCreator { + Section { + Button("Rename Group") { + newName = conversation.name ?? "" + showRenameSheet = true + } + + Button("Change Avatar") { + // Photo picker would go here + } + } + } + + // Members + Section("Members") { + ForEach(conversation.members) { member in + HStack { + CircularAvatarView(name: member.username, size: 32, isGroup: false) + + VStack(alignment: .leading) { + Text(member.username) + .font(.body) + Text(member.email) + .font(.caption) + .foregroundStyle(.secondary) + } + + Spacer() + + if member.userId == conversation.createdBy { + Text("Admin") + .font(.caption) + .foregroundStyle(.blue) + } + } + } + } + + // Leave / Delete + Section { + Button("Leave Group", role: .destructive) { + showLeaveConfirm = true + } + + if isCreator { + Button("Delete Group", role: .destructive) { + Task { + await appState.chatClient.deleteConversation(convId: conversation.id) + dismiss() + } + } + } + } + } + .navigationTitle("Group Info") + .navigationBarTitleDisplayMode(.inline) + .toolbar { + ToolbarItem(placement: .topBarTrailing) { + Button("Done") { dismiss() } + } + } + .alert("Leave Group?", isPresented: $showLeaveConfirm) { + Button("Cancel", role: .cancel) {} + Button("Leave", role: .destructive) { + Task { + await appState.chatClient.leaveGroup(convId: conversation.id) + dismiss() + } + } + } + .alert("Rename Group", isPresented: $showRenameSheet) { + TextField("Group Name", text: $newName) + Button("Cancel", role: .cancel) {} + Button("Rename") { + Task { + await appState.chatClient.renameConversation(convId: conversation.id, name: newName) + } + } + } + } + } +} diff --git a/ios_client/EncryptedChat/Views/Groups/InvitationBanner.swift b/ios_client/EncryptedChat/Views/Groups/InvitationBanner.swift new file mode 100644 index 0000000..5f9c877 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Groups/InvitationBanner.swift @@ -0,0 +1,41 @@ +import SwiftUI + +struct InvitationBanner: View { + let invitation: Invitation + let onAccept: () -> Void + let onDecline: () -> Void + + var body: some View { + VStack(alignment: .leading, spacing: 8) { + HStack { + Image(systemName: "envelope.badge") + .foregroundStyle(.orange) + + VStack(alignment: .leading) { + Text(invitation.conversationName) + .font(.body.bold()) + Text("Invited by \(invitation.invitedByUsername)") + .font(.caption) + .foregroundStyle(.secondary) + } + + Spacer() + } + + HStack(spacing: 12) { + Button("Accept") { + onAccept() + } + .buttonStyle(.borderedProminent) + .controlSize(.small) + + Button("Decline") { + onDecline() + } + .buttonStyle(.bordered) + .controlSize(.small) + } + } + .padding(.vertical, 4) + } +} diff --git a/ios_client/EncryptedChat/Views/Profile/EditProfileView.swift b/ios_client/EncryptedChat/Views/Profile/EditProfileView.swift new file mode 100644 index 0000000..b993ca0 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Profile/EditProfileView.swift @@ -0,0 +1,4 @@ +import SwiftUI + +// Profile editing is handled within ProfileView when isOwnProfile = true. +// This file exists for potential future separation. diff --git a/ios_client/EncryptedChat/Views/Profile/ProfileView.swift b/ios_client/EncryptedChat/Views/Profile/ProfileView.swift new file mode 100644 index 0000000..4f10289 --- /dev/null +++ b/ios_client/EncryptedChat/Views/Profile/ProfileView.swift @@ -0,0 +1,111 @@ +import SwiftUI + +struct ProfileView: View { + var appState: AppState + var isOwnProfile: Bool + var userId: String? + @State private var viewModel = ProfileViewModel() + @Environment(\.dismiss) private var dismiss + + var body: some View { + NavigationStack { + Form { + // Avatar + Section { + HStack { + Spacer() + VStack(spacing: 8) { + if let avatarData = viewModel.avatarData, + let uiImage = UIImage(data: avatarData) { + Image(uiImage: uiImage) + .resizable() + .aspectRatio(contentMode: .fill) + .frame(width: 80, height: 80) + .clipShape(Circle()) + } else { + CircularAvatarView( + name: viewModel.profile?.username ?? "?", + size: 80, + isGroup: false + ) + } + + if isOwnProfile { + Button("Change Photo") { + // Photo picker would go here + } + .font(.caption) + } + } + Spacer() + } + .listRowBackground(Color.clear) + } + + // Info + Section("Info") { + if let username = viewModel.profile?.username { + LabeledContent("Username", value: username) + } + if let email = viewModel.profile?.email { + LabeledContent("Email", value: email) + } + } + + if isOwnProfile { + // Editable fields + Section("Contact") { + TextField("Phone", text: $viewModel.phone) + .keyboardType(.phonePad) + Toggle("Phone visible to contacts", isOn: $viewModel.phoneVisible) + + TextField("Location", text: $viewModel.location) + Toggle("Location visible to contacts", isOn: $viewModel.locationVisible) + } + } else { + // Read-only view + if let phone = viewModel.profile?.phone, viewModel.profile?.phoneVisible == true { + Section("Contact") { + LabeledContent("Phone", value: phone) + } + } + if let location = viewModel.profile?.location, viewModel.profile?.locationVisible == true { + Section("Location") { + LabeledContent("Location", value: location) + } + } + } + + if let error = viewModel.errorMessage { + Section { + Text(error) + .foregroundStyle(.red) + } + } + } + .navigationTitle(isOwnProfile ? "My Profile" : "Profile") + .navigationBarTitleDisplayMode(.inline) + .toolbar { + ToolbarItem(placement: .topBarTrailing) { + if isOwnProfile { + Button("Save") { + Task { + await viewModel.saveProfile(chatClient: appState.chatClient) + dismiss() + } + } + .disabled(viewModel.isSaving) + } else { + Button("Done") { dismiss() } + } + } + ToolbarItem(placement: .cancellationAction) { + Button("Cancel") { dismiss() } + } + } + .task { + await viewModel.loadProfile(userId: userId, chatClient: appState.chatClient) + } + } + } +} diff --git a/ios_client/incremental_sync_changes.md b/ios_client/incremental_sync_changes.md new file mode 100644 index 0000000..0d4f142 --- /dev/null +++ b/ios_client/incremental_sync_changes.md @@ -0,0 +1,239 @@ +# iOS Client — Inkrementální sync zpráv + +## Problém + +Klient při každém otevření konverzace posílá `get_messages` a server vrací 50 zpráv (šifrované bloby + metadata). I když klient 49 z nich už má. Zbytečný přenos dat a zátěž serveru. + +## Řešení + +Server už podporuje parametr `after_ts` v `get_messages`. Klient si pamatuje timestamp poslední zprávy a posílá jen dotaz na novější. + +--- + +## Protokol — co posílat serveru + +### `get_messages` — nový volitelný parametr `after_ts` + +**Request:** +```json +{ + "type": "get_messages", + "request_id": "uuid", + "conversation_id": "conv-uuid", + "limit": 50, + "offset": 0, + "after_ts": "2026-02-15T22:15:45" +} +``` + +- `after_ts` (string, ISO 8601, volitelný) — server vrátí jen zprávy s `created_at > after_ts` +- Pokud `after_ts` chybí nebo je null, chová se jako dřív (vrátí posledních `limit` zpráv) + +**Response** — beze změny, jen méně zpráv: +```json +{ + "type": "get_messages", + "status": "ok", + "data": { + "messages": [...], + "total_count": 123 + } +} +``` + +### `get_deleted_since` — sync smazaných zpráv + +Po inkrementálním fetchi je nutné zjistit co bylo smazáno od posledního syncu. + +**Request:** +```json +{ + "type": "get_deleted_since", + "request_id": "uuid", + "conversation_id": "conv-uuid", + "since": "2026-02-15T22:15:45" +} +``` + +**Response:** +```json +{ + "type": "get_deleted_since", + "status": "ok", + "data": { + "message_ids": ["msg-uuid-1", "msg-uuid-2"] + } +} +``` + +### `mark_read` — optimalizace + +**Request** — beze změny, jen posílat méně ID: +```json +{ + "type": "mark_read", + "request_id": "uuid", + "conversation_id": "conv-uuid", + "message_ids": ["only-unread-msg-id-1"] +} +``` + +Filtrovat na klientovi: jen zprávy kde `sender_id != myId` **a** `myId` není v `read_by`. + +--- + +## Implementace na iOS klientovi + +### 1. Lokální cache zpráv + +Ukládat dešifrované zprávy na disk per konverzace. Klíč = `message_id`, hodnota = dešifrovaný payload (bez `read_by` — ten se mění). + +```swift +// MessageCache.swift nebo rozšíření ChatClient + +/// Uložit zprávu do lokální cache +func cacheMessage(convId: String, msgId: String, payload: [String: Any]) + +/// Načíst cache pro konverzaci → [msgId: payload] +func loadCache(convId: String) -> [String: [String: Any]] + +/// Smazat zprávu z cache +func removeCachedMessage(convId: String, msgId: String) +``` + +Formát na disku: JSON soubor v app sandbox, šifrovaný identity key (stejně jako Python klient). + +### 2. Logika v `getMessages()` + +``` +1. Načíst lokální cache pro conv_id +2. Pokud cache je neprázdná A offset == 0: + a. Najít nejnovější created_at v cache → after_ts + b. Poslat get_messages s after_ts (server vrátí jen nové) + c. Dešifrovat nové zprávy, přidat do cache + d. Poslat get_deleted_since s after_ts → smazat z cache + e. Sestavit výsledek z cache (seřadit, vzít posledních limit) +3. Pokud cache je prázdná NEBO offset > 0: + a. Plný fetch jako dřív (bez after_ts) + b. Dešifrovat, uložit do cache + c. Vrátit +4. mark_read: filtrovat jen sender_id != myId a myId not in read_by +``` + +### 3. Pseudokód + +```swift +func getMessages(convId: String, limit: Int = 50, offset: Int = 0) async -> [Message] { + var cache = loadCache(convId: convId) + let myId = userId ?? "" + + // Rozhodnout: inkrementální vs plný fetch + var afterTs: String? = nil + if !cache.isEmpty && offset == 0 { + afterTs = cache.values + .compactMap { $0["created_at"] as? String } + .filter { !($0.isEmpty) } + .max() + } + + // Fetch ze serveru + var params: [String: Any] = [ + "conversation_id": convId, + "limit": limit, + "offset": offset, + ] + if let ts = afterTs { + params["after_ts"] = ts + } + let resp = await sendAndReceive(type: "get_messages", params: params) + + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data"), + let rawMessages = data["messages"] as? [[String: Any]] else { + // Offline fallback — vrátit z cache + if !cache.isEmpty && offset == 0 { + return buildFromCache(cache, limit: limit) + } + return [] + } + + // Dešifrovat nové zprávy (existující logika) + let newMessages = decryptRawMessages(rawMessages, cache: &cache, convId: convId) + + // mark_read jen pro nepřečtené + let unreadIds = rawMessages.filter { msg in + let senderId = msg["sender_id"] as? String ?? "" + if senderId == myId { return false } + let readBy = msg["read_by"] as? [[String: Any]] ?? [] + return !readBy.contains { ($0["user_id"] as? String) == myId } + }.compactMap { $0["message_id"] as? String } + + if !unreadIds.isEmpty { + await markRead(convId: convId, messageIds: unreadIds) + } + + if afterTs != nil { + // Inkrementální: sync smazaných + let delResp = await sendAndReceive(type: "get_deleted_since", params: [ + "conversation_id": convId, + "since": afterTs!, + ]) + if let delData = delResp.dict(for: "data"), + let delIds = delData["message_ids"] as? [String] { + for id in delIds { + cache.removeValue(forKey: id) + removeCachedMessage(convId: convId, msgId: id) + } + } + return buildFromCache(cache, limit: limit) + } + + return newMessages +} + +/// Sestavit seřazený seznam z cache +func buildFromCache(_ cache: [String: [String: Any]], limit: Int) -> [Message] { + var messages: [Message] = [] + for (msgId, payload) in cache { + guard payload["_control"] == nil else { continue } + // Vytvořit Message z payload... + messages.append(messageFromPayload(msgId: msgId, payload: payload)) + } + messages.sort { $0.createdAt < $1.createdAt } + if messages.count > limit { + messages = Array(messages.suffix(limit)) + } + return messages +} +``` + +### 4. Co se změní v praxi + +| Situace | Dřív | Teď | +|---------|------|-----| +| Otevření konverzace kde jsem byl před 5 min | Server vrátí 50 zpráv (vše) | Server vrátí 0-2 nové zprávy | +| Otevření konverzace poprvé | Server vrátí 50 zpráv | Stejné (plný fetch) | +| Load older (scroll nahoru) | Server vrátí 50 starších | Stejné (offset > 0, plný fetch) | +| Po reconnectu | Server vrátí 50 zpráv | Server vrátí jen zprávy od odpojení | +| Offline | Nic (chyba) | Zobrazí cache | + +### 5. Metadata (read_by, reactions, pins) + +- **read_by** — neukládá se do cache (mění se často). Přichází v reálném čase přes `messages_read` notifikaci. Po reconnectu může být chvilku stale — přijatelné. +- **reactions** — server je vrací u každé zprávy. V cache se ukládají. Aktualizace přes `message_reacted` notifikaci v reálném čase. +- **pins** — stejně jako reactions. `message_pinned`/`message_unpinned` notifikace. +- Po inkrementálním fetchi jsou metadata aktuální jen pro NOVÉ zprávy. Starší mají stav z cache + real-time notifikací. Při plném fetchi (scroll nahoru / první load) jsou vždy aktuální. + +### 6. `ChatViewModel.loadMessages` — úprava + +```swift +func loadMessages(convId: String, chatClient: ChatClient) async { + isLoading = true + messages = await chatClient.getMessages(convId: convId, limit: 50) + isLoading = false + // mark_read se teď řeší uvnitř getMessages — tady nic + updatePinnedBanner() +} +``` + +`mark_read` volání se přesune z ViewModelu do `getMessages()` v ChatClientu (tam kde má přístup k `read_by` z response). diff --git a/ios_client/project.yml b/ios_client/project.yml new file mode 100644 index 0000000..5c6f78a --- /dev/null +++ b/ios_client/project.yml @@ -0,0 +1,33 @@ +name: EncryptedChat +options: + bundleIdPrefix: com.encryptedchat + deploymentTarget: + iOS: "16.0" + xcodeVersion: "15.0" + generateEmptyDirectories: true + +settings: + base: + SWIFT_VERSION: "5.9" + IPHONEOS_DEPLOYMENT_TARGET: "16.0" + ENABLE_PREVIEWS: "YES" + +targets: + EncryptedChat: + type: application + platform: iOS + sources: + - path: EncryptedChat + settings: + base: + GENERATE_INFOPLIST_FILE: "YES" + PRODUCT_BUNDLE_IDENTIFIER: com.encryptedchat.app + INFOPLIST_KEY_UIApplicationSceneManifest_Generation: "YES" + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents: "YES" + INFOPLIST_KEY_UILaunchScreen_Generation: "YES" + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone: "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight" + INFOPLIST_KEY_NSPhotoLibraryUsageDescription: "Select photos to share in chat" + INFOPLIST_KEY_NSCameraUsageDescription: "Take photos to share in chat" + INFOPLIST_KEY_CFBundleDisplayName: "Encrypted Chat" + INFOPLIST_KEY_LSApplicationCategoryType: "public.app-category.social-networking" + CODE_SIGN_STYLE: Automatic diff --git a/ios_client/v0.8.4_changes.md b/ios_client/v0.8.4_changes.md new file mode 100644 index 0000000..56ad6ab --- /dev/null +++ b/ios_client/v0.8.4_changes.md @@ -0,0 +1,1036 @@ +# iOS Client — v0.8.4 Changes + +Reakce (1 per user), Pinned Messages (banner), Forwarding, @Mentions, mark_read optimalizace. + +--- + +## 1. `Models/Message.swift` — Nové fieldy + +```swift +struct Message: Identifiable, Equatable { + let id: String + let conversationId: String + let senderId: String + var senderUsername: String + let createdAt: Date + var text: String? + var replyTo: String? + var imageFileId: String? + var file: FileInfo? + var isDeleted: Bool + var readBy: Set + + // --- v0.8.4 NEW --- + var reactions: [Reaction] // reakce na zprávu + var pinnedAt: String? // ISO timestamp pokud pinnutá, nil jinak + var pinnedBy: String? // user_id kdo pinnul + var forwardedFrom: ForwardInfo? // info o přeposlání + + func isMine(currentUserId: String) -> Bool { + senderId == currentUserId + } + + /// Vrátí reakci aktuálního uživatele (max 1 per user) + func myReaction(currentUserId: String) -> String? { + reactions.first(where: { $0.userId == currentUserId })?.reaction + } + + var isPinned: Bool { pinnedAt != nil } + + static func == (lhs: Message, rhs: Message) -> Bool { + lhs.id == rhs.id + } +} + +// --- v0.8.4 NEW structs --- + +struct Reaction: Equatable { + let userId: String + let reaction: String // "thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown" +} + +struct ForwardInfo: Equatable { + let sender: String // original sender username + let conversationId: String // original conversation + let messageId: String // original message id +} + +struct FileInfo: Equatable, Codable { + let fileId: String + let aesKey: String + let iv: String + let filename: String + let size: Int + let mimeType: String +} +``` + +--- + +## 2. `Core/ChatClient.swift` — Nové notifikace + metody + +### 2a. ChatNotification enum — přidat 3 nové case + +```swift +enum ChatNotification { + // ... existující cases ... + case sessionReset(data: [String: Any]) + case connectionStateChanged(connected: Bool) + + // --- v0.8.4 NEW --- + case messageReacted(data: [String: Any]) + case messagePinned(data: [String: Any]) + case messageUnpinned(data: [String: Any]) +} +``` + +### 2b. routeMessage() — přidat do notificationTypes setu a switch + +```swift +// V notificationTypes Set přidat: +let notificationTypes = Set([ + "new_message", "messages_read", "message_deleted", + "conversation_created", "member_added", "member_removed", + "user_online", "user_offline", "online_users", + "group_invitation", "conversation_renamed", "session_reset", + // v0.8.4: + "message_reacted", "message_pinned", "message_unpinned" +]) + +// V switch přidat: +case "message_reacted": + notificationContinuation?.yield(.messageReacted(data: data)) +case "message_pinned": + notificationContinuation?.yield(.messagePinned(data: data)) +case "message_unpinned": + notificationContinuation?.yield(.messageUnpinned(data: data)) +``` + +### 2c. getMessages() — parsovat reactions, pinned, forwarded_from + +V `getMessages()`, při vytváření Message objektu (cca řádek 1289), parsovat nová pole: + +```swift +// Po dekrypci JSON payloadu (jsonObj), před vytvořením Message: +// --- v0.8.4: Parse forwarded_from --- +var forwardedFrom: ForwardInfo? +if let fwd = jsonObj["forwarded_from"] as? [String: Any] { + forwardedFrom = ForwardInfo( + sender: fwd["sender"] as? String ?? "?", + conversationId: fwd["conversation_id"] as? String ?? "", + messageId: fwd["message_id"] as? String ?? "" + ) +} + +// --- v0.8.4: Parse reactions from server response (on msgDict, not jsonObj!) --- +var reactions: [Reaction] = [] +if let reactionsRaw = msgDict["reactions"] as? [[String: Any]] { + for r in reactionsRaw { + if let uid = r["user_id"] as? String, let rtype = r["reaction"] as? String { + reactions.append(Reaction(userId: uid, reaction: rtype)) + } + } +} + +// --- v0.8.4: Parse pinned --- +let pinnedAt = msgDict["pinned_at"] as? String // ISO string or nil +let pinnedBy = msgDict["pinned_by"] as? String + +// Pak v Message(...) přidat: +messages.append(Message( + id: msgId, conversationId: convId, senderId: senderId, + senderUsername: msgDict.string(for: "sender_username") ?? "", + createdAt: createdAt, text: messageText, replyTo: replyTo, + imageFileId: msgDict.string(for: "image_file_id"), file: file, + isDeleted: false, readBy: [], + reactions: reactions, // NEW + pinnedAt: pinnedAt, // NEW + pinnedBy: pinnedBy, // NEW + forwardedFrom: forwardedFrom // NEW +)) +``` + +**POZOR:** Taky pro deleted messages a fallback append volat s defaultními hodnotami: +```swift +reactions: [], pinnedAt: nil, pinnedBy: nil, forwardedFrom: nil +``` + +### 2d. Nové metody — react, pin, get_pinned, forward + +```swift +// MARK: - Reactions (v0.8.4) + +static let allowedReactions = ["thumbsup", "heart", "laugh", "surprised", "sad", "thumbsdown"] + +func reactMessage(messageId: String, reaction: String, action: String = "add") async -> Bool { + let resp = await sendAndReceive(type: "react_message", params: [ + "message_id": messageId, + "reaction": reaction, + "action": action, + ]) + return resp.string(for: "status") == "ok" +} + +// MARK: - Pins (v0.8.4) + +func pinMessage(messageId: String, conversationId: String, action: String = "pin") async -> Bool { + let resp = await sendAndReceive(type: "pin_message", params: [ + "message_id": messageId, + "conversation_id": conversationId, + "action": action, + ]) + return resp.string(for: "status") == "ok" +} + +func getPinnedMessages(conversationId: String) async -> [String] { + let resp = await sendAndReceive(type: "get_pinned_messages", params: [ + "conversation_id": conversationId, + ]) + guard resp.string(for: "status") == "ok", + let data = resp.dict(for: "data"), + let msgs = data["messages"] as? [[String: Any]] else { + return [] + } + return msgs.compactMap { $0["message_id"] as? String } +} + +// MARK: - Forward (v0.8.4) + +func forwardMessage(targetConvId: String, originalMsg: Message, + targetMembers: [ConversationMember]) async -> Bool { + // Forward = normální send_message s forwarded_from v payloadu + var text = originalMsg.text ?? "" + if originalMsg.imageFileId != nil { + text = "[Forwarded image]" + } + if originalMsg.file != nil { + text = "[Forwarded file: \(originalMsg.file?.filename ?? "file")]" + } + + // Sestavit payload JSON s forwarded_from + // Tady záleží na implementaci sendMessage — buď přidat parametr forwardedFrom, + // nebo vytvořit payload ručně. Nejjednodušší: přidat optional parametr do sendMessage. + let (success, _) = await sendMessage( + convId: targetConvId, text: text, members: targetMembers, + forwardedFrom: ForwardInfo( + sender: originalMsg.senderUsername, + conversationId: originalMsg.conversationId, + messageId: originalMsg.id + ) + ) + return success +} +``` + +### 2e. sendMessage() — přidat optional forwardedFrom parametr + +V existující `sendMessage()` funkci přidat parametr a propagovat do payloadu: + +```swift +func sendMessage(convId: String, text: String, members: [ConversationMember], + replyTo: String? = nil, + forwardedFrom: ForwardInfo? = nil // NEW +) async -> (Bool, String) { + // ... existující kód ... + + // Kde se sestavuje payload dict (jsonObj), přidat: + if let fwd = forwardedFrom { + payload["forwarded_from"] = [ + "sender": fwd.sender, + "conversation_id": fwd.conversationId, + "message_id": fwd.messageId, + ] + } + + // ... zbytek beze změny ... +} +``` + +### 2f. markRead optimalizace + +V `getMessages()`, **po** sestavení `messages` pole a **před** return, změnit mark_read logiku: + +```swift +// STARÉ (řádky 21-25 v ChatViewModel.loadMessages): +let unreadIds = messages.filter { !$0.isMine(currentUserId: ...) }.map(\.id) + +// NOVÉ — filtrovat jen zprávy co ještě nejsou přečtené: +// V getMessages() zpracovat read_by z server response: +let readByRaw = msgDict["read_by"] as? [[String: Any]] ?? [] +let readBySet = Set(readByRaw.compactMap { $0["user_id"] as? String }) +// ... a předat do Message(... readBy: readBySet ...) + +// Pak v ChatViewModel.loadMessages: +let myId = await chatClient.userId ?? "" +let unreadIds = messages.filter { + !$0.isMine(currentUserId: myId) && !$0.readBy.contains(myId) +}.map(\.id) +if !unreadIds.isEmpty { + await chatClient.markRead(convId: convId, messageIds: unreadIds) +} +``` + +--- + +## 3. `ViewModels/ChatViewModel.swift` — Notification handling + nové metody + +```swift +@Observable +final class ChatViewModel { + var messages: [Message] = [] + var isLoading = false + var isSending = false + var errorMessage: String? + var searchQuery = "" + var searchResults: [String] = [] + var currentSearchIndex = 0 + var pinnedMessage: Message? // NEW — pro banner + + private var notificationTask: Task? + + func loadMessages(convId: String, chatClient: ChatClient) async { + isLoading = true + messages = await chatClient.getMessages(convId: convId, limit: 50) + isLoading = false + + // v0.8.4: Jen nepřečtené zprávy od jiných + let myId = await chatClient.userId ?? "" + let unreadIds = messages.filter { + !$0.isMine(currentUserId: myId) && !$0.readBy.contains(myId) + }.map(\.id) + if !unreadIds.isEmpty { + await chatClient.markRead(convId: convId, messageIds: unreadIds) + } + + // v0.8.4: Update pin banner + updatePinnedBanner() + } + + // --- v0.8.4 NEW --- + + /// Aktualizovat pin banner z aktuálních zpráv + func updatePinnedBanner() { + pinnedMessage = messages.last(where: { $0.isPinned }) + } + + /// Reakce — optimistický update + server call + func react(messageId: String, reaction: String, chatClient: ChatClient) async { + let myId = await chatClient.userId ?? "" + + // Optimistický update + if let idx = messages.firstIndex(where: { $0.id == messageId }) { + let existing = messages[idx].myReaction(currentUserId: myId) + if existing == reaction { + // Toggle off + messages[idx].reactions.removeAll { $0.userId == myId } + await chatClient.reactMessage(messageId: messageId, reaction: reaction, action: "remove") + } else { + // Nahradit (1 per user) + messages[idx].reactions.removeAll { $0.userId == myId } + messages[idx].reactions.append(Reaction(userId: myId, reaction: reaction)) + await chatClient.reactMessage(messageId: messageId, reaction: reaction, action: "add") + } + } + } + + /// Pin/Unpin — optimistický update + server call + func togglePin(messageId: String, convId: String, chatClient: ChatClient) async { + let myId = await chatClient.userId ?? "" + if let idx = messages.firstIndex(where: { $0.id == messageId }) { + if messages[idx].isPinned { + messages[idx].pinnedAt = nil + messages[idx].pinnedBy = nil + await chatClient.pinMessage(messageId: messageId, conversationId: convId, action: "unpin") + } else { + messages[idx].pinnedAt = "now" + messages[idx].pinnedBy = myId + await chatClient.pinMessage(messageId: messageId, conversationId: convId, action: "pin") + } + updatePinnedBanner() + } + } + + // --- Notification handler — přidat nové cases --- + + @MainActor + private func handleNotification(_ notification: ChatNotification, convId: String, chatClient: ChatClient) { + switch notification { + // ... existující cases (newMessage, messageDeleted, messagesRead) ... + + // --- v0.8.4 NEW --- + case .messageReacted(let data): + guard data["conversation_id"] as? String == convId else { break } + let msgId = data["message_id"] as? String ?? "" + let userId = data["user_id"] as? String ?? "" + let reaction = data["reaction"] as? String ?? "" + let action = data["action"] as? String ?? "add" + + if let idx = messages.firstIndex(where: { $0.id == msgId }) { + if action == "add" { + // Remove old (1 per user) + add new + messages[idx].reactions.removeAll { $0.userId == userId } + messages[idx].reactions.append(Reaction(userId: userId, reaction: reaction)) + } else { + messages[idx].reactions.removeAll { $0.userId == userId } + } + } + + case .messagePinned(let data): + guard data["conversation_id"] as? String == convId else { break } + let msgId = data["message_id"] as? String ?? "" + let userId = data["user_id"] as? String ?? "" + if let idx = messages.firstIndex(where: { $0.id == msgId }) { + messages[idx].pinnedAt = "now" + messages[idx].pinnedBy = userId + updatePinnedBanner() + } + + case .messageUnpinned(let data): + guard data["conversation_id"] as? String == convId else { break } + let msgId = data["message_id"] as? String ?? "" + if let idx = messages.firstIndex(where: { $0.id == msgId }) { + messages[idx].pinnedAt = nil + messages[idx].pinnedBy = nil + updatePinnedBanner() + } + + default: + break + } + } +} +``` + +--- + +## 4. `Views/Chat/MessageBubbleView.swift` — Reakce, forwarded, pin, context menu + +```swift +import SwiftUI + +struct MessageBubbleView: View { + let message: Message + let isMine: Bool + let currentUserId: String // NEW — pro reaction check + var isHighlighted: Bool = false + var isCurrentSearchResult: Bool = false + var onReply: (() -> Void)? + var onDelete: (() -> Void)? + var onReact: ((String) -> Void)? // NEW — reaction callback + var onPin: (() -> Void)? // NEW — pin callback + var onForward: (() -> Void)? // NEW — forward callback + + // Emoji mapa + private static let reactionEmoji: [String: String] = [ + "thumbsup": "👍", "heart": "❤️", "laugh": "😂", + "surprised": "😮", "sad": "😢", "thumbsdown": "👎", + ] + + var body: some View { + HStack { + if isMine { Spacer(minLength: 60) } + + VStack(alignment: isMine ? .trailing : .leading, spacing: 4) { + if !isMine { + Text(message.senderUsername) + .font(.caption.bold()) + .foregroundStyle(.secondary) + } + + if message.isDeleted { + Text("Message deleted") + .font(.body.italic()) + .foregroundStyle(.secondary) + .padding(12) + .background(Color(.systemGray6)) + .clipShape(RoundedRectangle(cornerRadius: 16)) + } else { + // --- v0.8.4: Forwarded from header --- + if let fwd = message.forwardedFrom { + HStack(spacing: 4) { + Rectangle() + .fill(.cyan.opacity(0.6)) + .frame(width: 2) + VStack(alignment: .leading, spacing: 0) { + Text("Forwarded from") + .font(.caption2) + .foregroundStyle(.secondary) + Text(fwd.sender) + .font(.caption.bold()) + .foregroundStyle(.cyan) + } + } + .padding(.horizontal, 8) + .padding(.vertical, 2) + } + + // Reply reference + if let _ = message.replyTo { + HStack(spacing: 4) { + Rectangle() + .fill(.blue.opacity(0.5)) + .frame(width: 2) + Text("Reply to message") + .font(.caption) + .foregroundStyle(.secondary) + } + .padding(.horizontal, 8) + } + + // File card + if let file = message.file { + VStack(alignment: .leading, spacing: 4) { + HStack { + Image(systemName: "paperclip") + Text(file.filename).lineLimit(1) + } + .font(.subheadline) + Text(formatFileSize(file.size)) + .font(.caption) + .foregroundStyle(.secondary) + } + .padding(12) + .background(Color(.systemGray5)) + .clipShape(RoundedRectangle(cornerRadius: 12)) + } + + // Text content + pin indicator + if let text = message.text { + HStack(alignment: .top, spacing: 4) { + Text(highlightMentions(text)) + .padding(12) + + // --- v0.8.4: Pin indicator --- + if message.isPinned { + Text("📌") + .font(.caption2) + .padding(.top, 8) + } + } + .background(isMine ? Color.blue : Color(.systemGray5)) + .foregroundStyle(isMine ? .white : .primary) + .clipShape(RoundedRectangle(cornerRadius: 16)) + } + + // --- v0.8.4: Reaction badges --- + if !message.reactions.isEmpty { + reactionBadges + } + + // Timestamp + Text(formatTime(message.createdAt)) + .font(.caption2) + .foregroundStyle(.secondary) + } + } + .background( + isCurrentSearchResult ? Color.orange.opacity(0.3) : + isHighlighted ? Color.yellow.opacity(0.2) : Color.clear + ) + .clipShape(RoundedRectangle(cornerRadius: 16)) + .contextMenu { + if !message.isDeleted { + Button(action: { onReply?() }) { + Label("Reply", systemImage: "arrowshape.turn.up.left") + } + + Button(action: { UIPasteboard.general.string = message.text ?? "" }) { + Label("Copy", systemImage: "doc.on.doc") + } + + // --- v0.8.4: Forward --- + Button(action: { onForward?() }) { + Label("Forward", systemImage: "arrowshape.turn.up.right") + } + + // --- v0.8.4: Pin/Unpin --- + Button(action: { onPin?() }) { + Label(message.isPinned ? "Unpin" : "Pin", + systemImage: message.isPinned ? "pin.slash" : "pin") + } + + Divider() + + // --- v0.8.4: Reactions submenu --- + Menu { + ForEach(Array(Self.reactionEmoji.sorted(by: { $0.key < $1.key })), id: \.key) { key, emoji in + Button(action: { onReact?(key) }) { + let isMine = message.myReaction(currentUserId: currentUserId) == key + Label( + "\(emoji) \(isMine ? "✓" : "")", + systemImage: isMine ? "checkmark.circle.fill" : "face.smiling" + ) + } + } + } label: { + Label("React", systemImage: "face.smiling") + } + + if isMine { + Divider() + Button(role: .destructive, action: { onDelete?() }) { + Label("Delete", systemImage: "trash") + } + } + } + } + + if !isMine { Spacer(minLength: 60) } + } + } + + // --- v0.8.4: Reaction badges view --- + private var reactionBadges: some View { + // Seskupit reakce: [reaction: [userId]] + let grouped = Dictionary(grouping: message.reactions, by: \.reaction) + + return HStack(spacing: 4) { + ForEach(grouped.sorted(by: { $0.key < $1.key }), id: \.key) { reaction, users in + let emoji = Self.reactionEmoji[reaction] ?? reaction + let isMine = users.contains(where: { $0.userId == currentUserId }) + + HStack(spacing: 2) { + Text(emoji) + .font(.caption2) + if users.count > 1 { + Text("\(users.count)") + .font(.caption2) + } + } + .padding(.horizontal, 6) + .padding(.vertical, 2) + .background(isMine ? Color.blue.opacity(0.2) : Color(.systemGray5)) + .clipShape(Capsule()) + .overlay( + Capsule() + .stroke(isMine ? Color.blue.opacity(0.5) : Color.clear, lineWidth: 1) + ) + } + } + } + + // --- v0.8.4: @mention highlighting --- + private func highlightMentions(_ text: String) -> AttributedString { + var result = AttributedString(text) + // Najít @username patterny a zvýraznit modře + let pattern = try? NSRegularExpression(pattern: "@(\\w+)") + let nsText = text as NSString + let matches = pattern?.matches(in: text, range: NSRange(location: 0, length: nsText.length)) ?? [] + for match in matches.reversed() { + if let range = Range(match.range, in: text), + let attrRange = Range(range, in: result) { + result[attrRange].foregroundColor = .blue + result[attrRange].font = .body.bold() + } + } + return result + } + + private func formatTime(_ date: Date) -> String { + let formatter = DateFormatter() + if Calendar.current.isDateInToday(date) { + formatter.dateFormat = "HH:mm" + } else { + formatter.dateFormat = "MMM d, HH:mm" + } + return formatter.string(from: date) + } + + private func formatFileSize(_ bytes: Int) -> String { + if bytes < 1024 { return "\(bytes) B" } + if bytes < 1024 * 1024 { return "\(bytes / 1024) KB" } + return String(format: "%.1f MB", Double(bytes) / (1024 * 1024)) + } +} +``` + +--- + +## 5. `Views/Chat/ChatView.swift` — Pin banner, forward dialog, nové callbacky + +```swift +import SwiftUI + +struct ChatView: View { + let conversation: Conversation + var appState: AppState + @State private var viewModel = ChatViewModel() + @State private var inputText = "" + @State private var replyTo: Message? + @State private var showGroupInfo = false + @State private var showSearch = false + @State private var showDeleteConfirm = false + @State private var showForwardPicker: Message? // NEW — zpráva k přeposlání + @State private var showPinnedList = false // NEW — dialog pinnutých zpráv + + var body: some View { + VStack(spacing: 0) { + // Search bar + if showSearch { + SearchOverlayView( + query: $viewModel.searchQuery, + matchCount: viewModel.searchResults.count, + currentIndex: viewModel.currentSearchIndex, + onSearch: { viewModel.search(query: $0) }, + onNext: { viewModel.nextSearchResult() }, + onPrev: { viewModel.prevSearchResult() }, + onClose: { showSearch = false; viewModel.search(query: "") } + ) + } + + // --- v0.8.4: Pinned message banner --- + if let pinned = viewModel.pinnedMessage { + PinnedBannerView(message: pinned) { + // Scroll to pinned message + // (proxy reference needed — viz ScrollViewReader níže) + } + .onTapGesture { + showPinnedList = true + } + } + + // Messages + ScrollViewReader { proxy in + ScrollView { + LazyVStack(spacing: 8) { + if viewModel.messages.count >= 50 { + Button("Load older messages") { + Task { + await viewModel.loadOlderMessages( + convId: conversation.id, + chatClient: appState.chatClient + ) + } + } + .font(.caption) + .padding() + } + + ForEach(viewModel.messages) { message in + MessageBubbleView( + message: message, + isMine: message.isMine(currentUserId: appState.currentUser?.id ?? ""), + currentUserId: appState.currentUser?.id ?? "", // NEW + isHighlighted: viewModel.searchResults.contains(message.id), + isCurrentSearchResult: viewModel.searchResults.indices.contains(viewModel.currentSearchIndex) && + viewModel.searchResults[viewModel.currentSearchIndex] == message.id, + onReply: { replyTo = message }, + onDelete: { + Task { + await viewModel.deleteMessage( + messageId: message.id, + convId: conversation.id, + chatClient: appState.chatClient + ) + } + }, + // --- v0.8.4 NEW callbacks --- + onReact: { reaction in + Task { + await viewModel.react( + messageId: message.id, + reaction: reaction, + chatClient: appState.chatClient + ) + } + }, + onPin: { + Task { + await viewModel.togglePin( + messageId: message.id, + convId: conversation.id, + chatClient: appState.chatClient + ) + } + }, + onForward: { + showForwardPicker = message + } + ) + .id(message.id) + } + } + .padding(.horizontal) + .padding(.vertical, 8) + } + .onChange(of: viewModel.messages.count) { + if let lastId = viewModel.messages.last?.id { + withAnimation { + proxy.scrollTo(lastId, anchor: .bottom) + } + } + } + // --- v0.8.4: Scroll to pinned on banner tap --- + .onChange(of: showPinnedList) { + if !showPinnedList, let pinId = viewModel.pinnedMessage?.id { + withAnimation { + proxy.scrollTo(pinId, anchor: .center) + } + } + } + } + + // Reply preview + if let reply = replyTo { + HStack { + Rectangle().fill(.blue).frame(width: 3) + VStack(alignment: .leading) { + Text(reply.senderUsername).font(.caption.bold()) + Text(reply.text ?? "").font(.caption).lineLimit(1) + } + Spacer() + Button(action: { replyTo = nil }) { + Image(systemName: "xmark.circle.fill").foregroundStyle(.secondary) + } + } + .padding(.horizontal) + .padding(.vertical, 6) + .background(.ultraThinMaterial) + } + + // Input + MessageInputView( + text: $inputText, + isSending: viewModel.isSending, + onSend: { + Task { + let text = inputText + inputText = "" + let reply = replyTo?.id + replyTo = nil + await viewModel.sendMessage( + convId: conversation.id, + text: text, + members: conversation.members, + chatClient: appState.chatClient, + replyTo: reply + ) + } + } + ) + } + .navigationTitle(conversation.displayName(currentUserId: appState.currentUser?.id ?? "")) + .navigationBarTitleDisplayMode(.inline) + .toolbar { + ToolbarItem(placement: .topBarTrailing) { + HStack(spacing: 16) { + // --- v0.8.4: Pinned messages button --- + Button(action: { showPinnedList = true }) { + Image(systemName: "pin") + } + + Button(action: { showSearch.toggle() }) { + Image(systemName: "magnifyingglass") + } + if conversation.isGroup { + Button(action: { showGroupInfo = true }) { + Image(systemName: "info.circle") + } + } + if !conversation.isGroup || conversation.createdBy == appState.currentUser?.id { + Button(action: { showDeleteConfirm = true }) { + Image(systemName: "trash").foregroundStyle(.red) + } + } + } + } + } + .alert("Delete Conversation?", isPresented: $showDeleteConfirm) { + Button("Cancel", role: .cancel) {} + Button("Delete", role: .destructive) { + Task { await appState.chatClient.deleteConversation(convId: conversation.id) } + } + } message: { + Text(conversation.isGroup + ? "This will remove all members and delete the conversation." + : "This will remove you from the conversation.") + } + // --- v0.8.4: Forward picker sheet --- + .sheet(item: $showForwardPicker) { message in + ForwardPickerView( + message: message, + appState: appState, + onForward: { targetConv in + Task { + await appState.chatClient.forwardMessage( + targetConvId: targetConv.id, + originalMsg: message, + targetMembers: targetConv.members + ) + } + showForwardPicker = nil + } + ) + } + // --- v0.8.4: Pinned messages list sheet --- + .sheet(isPresented: $showPinnedList) { + PinnedMessagesView( + messages: viewModel.messages.filter(\.isPinned), + onSelect: { msg in + showPinnedList = false + // ScrollViewReader scroll handled by onChange above + } + ) + } + .sheet(isPresented: $showGroupInfo) { + GroupInfoView(conversation: conversation, appState: appState) + } + .task { + await viewModel.loadMessages(convId: conversation.id, chatClient: appState.chatClient) + viewModel.startNotificationListener(convId: conversation.id, chatClient: appState.chatClient) + } + .onDisappear { + viewModel.stop() + } + } +} +``` + +**POZNÁMKA:** `Message` musí být `Identifiable` (už je) pro `.sheet(item:)` na `showForwardPicker`. + +--- + +## 6. Nové pomocné views + +### `Views/Chat/PinnedBannerView.swift` (NOVÝ SOUBOR) + +```swift +import SwiftUI + +struct PinnedBannerView: View { + let message: Message + var onTap: (() -> Void)? + + var body: some View { + HStack(spacing: 8) { + Image(systemName: "pin.fill") + .foregroundStyle(.yellow) + .font(.caption) + + VStack(alignment: .leading, spacing: 1) { + Text(message.senderUsername) + .font(.caption.bold()) + Text(message.text ?? "") + .font(.caption) + .lineLimit(1) + .foregroundStyle(.secondary) + } + + Spacer() + } + .padding(.horizontal, 12) + .padding(.vertical, 6) + .background(Color(.systemGray5)) + .contentShape(Rectangle()) + .onTapGesture { onTap?() } + } +} +``` + +### `Views/Chat/ForwardPickerView.swift` (NOVÝ SOUBOR) + +```swift +import SwiftUI + +struct ForwardPickerView: View { + let message: Message + var appState: AppState + var onForward: (Conversation) -> Void + + @State private var conversations: [Conversation] = [] + @Environment(\.dismiss) private var dismiss + + var body: some View { + NavigationStack { + List(conversations) { conv in + Button(action: { onForward(conv) }) { + HStack { + Text(conv.displayName(currentUserId: appState.currentUser?.id ?? "")) + Spacer() + Image(systemName: "arrowshape.turn.up.right") + .foregroundStyle(.secondary) + } + } + } + .navigationTitle("Forward to...") + .navigationBarTitleDisplayMode(.inline) + .toolbar { + ToolbarItem(placement: .cancellationAction) { + Button("Cancel") { dismiss() } + } + } + .task { + conversations = await appState.chatClient.listConversations() + .filter { $0.id != message.conversationId } + } + } + } +} +``` + +### `Views/Chat/PinnedMessagesView.swift` (NOVÝ SOUBOR) + +```swift +import SwiftUI + +struct PinnedMessagesView: View { + let messages: [Message] + var onSelect: (Message) -> Void + + @Environment(\.dismiss) private var dismiss + + var body: some View { + NavigationStack { + Group { + if messages.isEmpty { + ContentUnavailableView( + "No Pinned Messages", + systemImage: "pin.slash", + description: Text("Pin important messages to find them easily.") + ) + } else { + List(messages) { msg in + Button(action: { onSelect(msg) }) { + VStack(alignment: .leading, spacing: 4) { + HStack { + Image(systemName: "pin.fill") + .foregroundStyle(.yellow) + .font(.caption) + Text(msg.senderUsername) + .font(.subheadline.bold()) + } + Text(msg.text ?? "") + .font(.subheadline) + .lineLimit(2) + .foregroundStyle(.secondary) + } + } + } + } + } + .navigationTitle("Pinned Messages") + .navigationBarTitleDisplayMode(.inline) + .toolbar { + ToolbarItem(placement: .cancellationAction) { + Button("Close") { dismiss() } + } + } + } + } +} +``` + +--- + +## 7. Shrnutí všech souborů k úpravě/přidání + +| Soubor | Akce | Popis | +|--------|------|-------| +| `Models/Message.swift` | EDIT | +Reaction, +ForwardInfo structs, +reactions/pinnedAt/forwardedFrom fieldy | +| `Core/ChatClient.swift` | EDIT | +3 notification types, +routeMessage dispatch, +getMessages parsing, +react/pin/forward metody, +sendMessage forwardedFrom param | +| `ViewModels/ChatViewModel.swift` | EDIT | +pinnedMessage, +updatePinnedBanner, +react(), +togglePin(), +3 notification cases, +mark_read optimalizace | +| `Views/Chat/MessageBubbleView.swift` | EDIT | +currentUserId, +onReact/onPin/onForward callbacky, +forwarded header, +pin indicator, +reaction badges, +@mention highlighting, +context menu items | +| `Views/Chat/ChatView.swift` | EDIT | +pin banner, +showForwardPicker, +showPinnedList, +nové callbacky, +pin toolbar button | +| `Views/Chat/PinnedBannerView.swift` | NEW | Pin banner component | +| `Views/Chat/ForwardPickerView.swift` | NEW | Forward conversation picker | +| `Views/Chat/PinnedMessagesView.swift` | NEW | Pinned messages list dialog | diff --git a/protocol.py b/protocol.py new file mode 100644 index 0000000..8135a6a --- /dev/null +++ b/protocol.py @@ -0,0 +1,142 @@ +"""Newline-delimited JSON protocol with base64 encoding for binary data.""" + +import asyncio +import base64 +import binascii +import json +import os + + +def encode_binary(data: bytes) -> str: + """Encode bytes to base64 string.""" + return base64.b64encode(data).decode("ascii") + + +def decode_binary(data: str) -> bytes: + """Decode base64 string to bytes.""" + try: + return base64.b64decode(data, validate=True) + except (TypeError, binascii.Error) as e: + raise ValueError(f"Invalid base64: {e}") + + +VERSION = "0.8.5" +MIN_CLIENT_VERSION = "0.8.5" # server rejects clients older than this + + +def version_gte(version: str, minimum: str) -> bool: + """Return True if version >= minimum (compares numeric tuples, e.g. '0.8.1' >= '0.8'). + + Returns False for malformed version strings (instead of silently treating them as 0). + """ + def _parse(v: str) -> tuple[int, ...] | None: + if not isinstance(v, str) or not v: + return None + parts = v.split(".") + try: + return tuple(int(x) for x in parts) + except (ValueError, AttributeError): + return None + parsed_ver = _parse(version) + parsed_min = _parse(minimum) + if parsed_ver is None or parsed_min is None: + return False + return parsed_ver >= parsed_min + + +MAX_MESSAGE_BYTES = int(os.getenv("MAX_MESSAGE_BYTES", "65536")) # 64 KiB default +MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", str(5 * 1024 * 1024))) # 5 MiB default, 0 = no limit +MAX_FILE_BYTES = int(os.getenv("MAX_FILE_BYTES", str(50 * 1024 * 1024))) # 50 MiB default +IMAGE_CHUNK_SIZE = 32768 # 32 KiB raw chunk size for image upload/download + + +def build_request(msg_type: str, request_id: str | None = None, **kwargs) -> bytes: + """Build a protocol message (newline-terminated JSON).""" + msg = {"type": msg_type, **kwargs} + if request_id: + msg["request_id"] = request_id + return json.dumps(msg, ensure_ascii=False).encode("utf-8") + b"\n" + + +def build_response( + msg_type: str, + status: str, + data: dict | None = None, + request_id: str | None = None, +) -> bytes: + """Build a server response.""" + msg = {"type": msg_type, "status": status} + if data is not None: + msg["data"] = data + if request_id: + msg["request_id"] = request_id + return json.dumps(msg, ensure_ascii=False).encode("utf-8") + b"\n" + + +def parse_message(line: bytes) -> dict: + """Parse a single protocol message from bytes.""" + try: + return json.loads(line.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise ValueError(f"Invalid message: {e}") + + +class ProtocolReader: + """Read newline-delimited JSON messages from an asyncio StreamReader.""" + + def __init__(self, reader: asyncio.StreamReader): + self._reader = reader + + async def read_message(self) -> dict | None: + """Read and parse one message. Returns None on EOF.""" + try: + line = await self._reader.readuntil(b"\n") + except (asyncio.IncompleteReadError, ConnectionError): + return None + except asyncio.LimitOverrunError as e: + # Message exceeded StreamReader limit — drain oversized data + # using public read() API (consumed=e.consumed bytes before limit). + # Read in chunks until newline found or EOF, then signal error. + remaining = e.consumed + while True: + chunk = await self._reader.read(max(remaining, 4096)) + if not chunk: + return None # EOF while draining + if b"\n" in chunk: + break # found delimiter, oversized message fully drained + raise ValueError("Message exceeds maximum size") + if not line: + return None + return parse_message(line.strip()) + + +class ProtocolWriter: + """Write newline-delimited JSON messages to an asyncio StreamWriter.""" + + def __init__(self, writer: asyncio.StreamWriter): + self._writer = writer + + async def send_request(self, msg_type: str, request_id: str | None = None, **kwargs): + """Send a request message.""" + payload = build_request(msg_type, request_id=request_id, **kwargs) + if len(payload) > MAX_MESSAGE_BYTES: + raise ValueError(f"Message exceeds limit ({len(payload)} > {MAX_MESSAGE_BYTES})") + self._writer.write(payload) + await self._writer.drain() + + async def send_response( + self, + msg_type: str, + status: str, + data: dict | None = None, + request_id: str | None = None, + ): + """Send a response message.""" + payload = build_response(msg_type, status, data, request_id=request_id) + if len(payload) > MAX_MESSAGE_BYTES: + raise ValueError(f"Message exceeds limit ({len(payload)} > {MAX_MESSAGE_BYTES})") + self._writer.write(payload) + await self._writer.drain() + + def close(self): + self._writer.close() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b4d7cb7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +cryptography>=42.0.0 +mysql-connector-python>=8.3.0 +python-dotenv>=1.0.0 +# GUI client (optional, needed for gui_client.py) +PyQt6>=6.6.0 +# Image sharing (optional, needed for send_image feature) +Pillow>=10.0.0 +# QR code generation for contact verification (optional) +qrcode[pil]>=7.4 +# QR code scanning (needed for gui_client.py QR scan feature) +pyzbar>=0.1.9 diff --git a/scaling.md b/scaling.md new file mode 100644 index 0000000..742df17 --- /dev/null +++ b/scaling.md @@ -0,0 +1,252 @@ +# Škálování serveru — plán kapacitního růstu + +## Cílový hardware + +- **CPU:** Intel Xeon E5-2630v4 (10 cores / 20 threads, 2.2 GHz) +- **RAM:** 256 GB REG ECC +- **Disk:** 500 GB SSD (boot/OS/DB) + 4 TB HDD (soubory) +- **Síť:** 1 Gbit + +Odhadovaná kapacita po optimalizaci: **10 000–20 000 uživatelů**, **2000–5000 zpráv/s** + +--- + +## Krok 1: Okamžité změny (hotovo v kódu) + +### 1a. Thread pool — `server.py` + +```env +THREAD_POOL_SIZE=40 +``` + +Nastavuje `ThreadPoolExecutor(max_workers=40)` jako default executor pro `asyncio.to_thread()`. +S 20 HW thready a DB latencí ~2–5ms je 40 workerů optimální (2x HW threads — workery čekají na I/O). + +### 1b. DB pool — `.env` + +```env +DB_POOL_SIZE=30 +``` + +30 simultánních MySQL spojení. S 40 thread workers a ~2ms query je 30 pool konexí dostatek. + +### 1c. Chybějící DB indexy — `schema.sql` + +Přidány 5 nových indexů pro nejčastější dotazy: + +| Index | Tabulka | Dotaz který zrychlí | +|-------|---------|---------------------| +| `idx_cm_user (user_id)` | `conversation_members` | `list_user_conversations` — **kritický**, bez něj full table scan | +| `idx_inv_user (user_id)` | `group_invitations` | `get_pending_invitations` | +| `idx_messages_deleted (conversation_id, deleted_at)` | `messages` | `get_deleted_messages_since` | +| `idx_messages_pinned (conversation_id, pinned_at)` | `messages` | `get_pinned_messages` | +| `idx_reads_user (user_id)` | `message_reads` | `get_unread_counts` | + +**SQL migrace pro existující databázi:** + +```sql +ALTER TABLE conversation_members ADD INDEX idx_cm_user (user_id); +ALTER TABLE group_invitations ADD INDEX idx_inv_user (user_id); +ALTER TABLE messages ADD INDEX idx_messages_deleted (conversation_id, deleted_at); +ALTER TABLE messages ADD INDEX idx_messages_pinned (conversation_id, pinned_at); +ALTER TABLE message_reads ADD INDEX idx_reads_user (user_id); +``` + +### 1d. Upload adresář na HDD + +```env +UPLOAD_DIR=/mnt/hdd/encrypted_chat/uploads +``` + +Šifrované soubory a avatary na 4TB HDD — SSD zůstane pro OS a MySQL data. + +```bash +mkdir -p /mnt/hdd/encrypted_chat/uploads +chmod 700 /mnt/hdd/encrypted_chat/uploads +``` + +--- + +## Krok 2: MySQL tuning pro 256 GB RAM + +### `/etc/mysql/mysql.conf.d/tuning.cnf` (nebo ekvivalent v Dockeru) + +```ini +[mysqld] +# === Buffer Pool — hlavní cache pro data + indexy === +# 96 GB = ~37% RAM (MySQL + app na stejném stroji) +innodb_buffer_pool_size = 96G +innodb_buffer_pool_instances = 16 + +# === Redo Log — větší = méně I/O, rychlejší zápisy === +innodb_redo_log_capacity = 4G + +# === Flush strategie === +# 2 = flush do OS cache každou sekundu (ne každý commit) +# Ztráta max 1s dat při pádu OS, ale 10x rychlejší zápisy +innodb_flush_log_at_trx_commit = 2 +# O_DIRECT = bypass OS page cache (InnoDB má vlastní) +innodb_flush_method = O_DIRECT + +# === I/O kapacita (SSD) === +innodb_io_capacity = 2000 +innodb_io_capacity_max = 4000 + +# === Connections === +max_connections = 200 + +# === Sort/Join buffery === +sort_buffer_size = 4M +join_buffer_size = 4M +read_buffer_size = 2M +read_rnd_buffer_size = 2M + +# === Temporary tables === +tmp_table_size = 256M +max_heap_table_size = 256M + +# === Query cache (MySQL 8.0+ nemá, pro 5.7) === +# query_cache_type = 0 + +# === Thread cache === +thread_cache_size = 64 + +# === Binary logging (pro budoucí repliky) === +# server-id = 1 +# log_bin = /var/log/mysql/mysql-bin +# binlog_expire_logs_seconds = 604800 +# max_binlog_size = 256M +``` + +**Pokud MySQL běží v Dockeru:** + +```yaml +# docker-compose.yml +services: + mysql: + image: mysql:8.0 + volumes: + - /var/lib/mysql:/var/lib/mysql # data na SSD + - ./tuning.cnf:/etc/mysql/conf.d/tuning.cnf + deploy: + resources: + limits: + memory: 128G # limitovat aby zbylo pro app + environment: + MYSQL_DATABASE: encrypted_chat +``` + +### Po aplikaci restartovat MySQL a ověřit: + +```sql +SHOW VARIABLES LIKE 'innodb_buffer_pool_size'; +SHOW VARIABLES LIKE 'innodb_flush_log_at_trx_commit'; +SHOW ENGINE INNODB STATUS\G +``` + +--- + +## Krok 3: Doporučená `.env` pro produkci + +```env +# Server +SERVER_HOST=0.0.0.0 +SERVER_PORT=9999 + +# MySQL +MYSQL_HOST=127.0.0.1 +MYSQL_PORT=3306 +MYSQL_USER=sifrator +MYSQL_PASSWORD= +MYSQL_DATABASE=encrypted_chat +DB_POOL_SIZE=30 + +# Performance +THREAD_POOL_SIZE=40 + +# Storage +UPLOAD_DIR=/mnt/hdd/encrypted_chat/uploads + +# TLS (zapnout pro produkci) +TLS_ENABLED=true +TLS_CERT_FILE=/etc/letsencrypt/live/chat.example.com/fullchain.pem +TLS_KEY_FILE=/etc/letsencrypt/live/chat.example.com/privkey.pem + +# Logging +LOG_LEVEL=INFO +``` + +--- + +## Krok 4: Monitoring (doporučeno) + +### Jednoduché metriky bez externích nástrojů + +Přidat do serveru periodické logování: + +```python +# V _periodic_cleanup() (každých 10 min): +async with _clients_lock: + total_connections = sum(len(v) for v in connected_clients.values()) + unique_users = len(connected_clients) +logger.info("[STATS] users=%d connections=%d", unique_users, total_connections) +``` + +### S externími nástroji (volitelně) + +- **htop** — CPU / RAM využití procesu +- **mysqladmin status** — queries/s, slow queries, connections +- **Prometheus + Grafana** — dlouhodobé trendy (přidat až při potřebě) + +--- + +## Budoucí škálování + +### Fáze A: Separace MySQL (15K+ uživatelů) + +MySQL na separátní stroj (nebo managed DB). App server + Redis na jednom, DB na druhém. + +``` +[Server: App + Redis] ──TCP──▶ [Server: MySQL] + │ + └──▶ [HDD/S3: soubory] +``` + +### Fáze B: Horizontální škálování (50K+ uživatelů) + +Více app serverů za load balancerem + Redis Pub/Sub pro cross-server notifikace. + +``` + ┌─── App server 1 ───┐ +Client ──▶ │ connected_clients │──┐ + └─────────────────────┘ │ + ├──▶ Redis Pub/Sub ──▶ MySQL + ┌─── App server 2 ───┐ │ +Client ──▶ │ connected_clients │──┘ + └─────────────────────┘ + ▲ + Load Balancer (HAProxy / nginx stream) + (sticky sessions by user_id) +``` + +Hlavní změna: `_notify_users()` posílá do Redis místo lokálního `connected_clients` pokud uživatel není na tomto serveru. + +### Fáze C: DB škálování (100K+ uživatelů) + +- Read replicas pro SELECT dotazy +- Partitioning tabulky `messages` podle měsíce +- Sharding podle `conversation_id` + +--- + +## Přehled — co je hotovo + +| Krok | Stav | Popis | +|------|------|-------| +| asyncio.to_thread() pro DB | **Hotovo** | 131 DB volání offloadováno do thread poolu | +| ThreadPoolExecutor(40) | **Hotovo** | Konfigurovatelný přes `THREAD_POOL_SIZE` | +| DB indexy (5 nových) | **Hotovo** | Schema + SQL migrace připraveny | +| UPLOAD_DIR na HDD | **Konfigurace** | Nastavit v `.env` | +| MySQL tuning | **Konfigurace** | Aplikovat `tuning.cnf` | +| TLS certifikát | **TODO** | Let's Encrypt nebo vlastní CA | +| Monitoring | **Volitelné** | Periodické logování stats | diff --git a/schema.sql b/schema.sql new file mode 100644 index 0000000..b7d8426 --- /dev/null +++ b/schema.sql @@ -0,0 +1,189 @@ +CREATE DATABASE IF NOT EXISTS encrypted_chat + CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +USE encrypted_chat; + +-- Users: identity_key is Ed25519 (32B), rsa_public_key for login challenge only +CREATE TABLE IF NOT EXISTS users ( + id CHAR(36) NOT NULL PRIMARY KEY, + username VARCHAR(255) NOT NULL, + email VARCHAR(255) NOT NULL UNIQUE, + rsa_public_key TEXT NOT NULL, + identity_key BLOB NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +) ENGINE=InnoDB; + +-- Devices: each user can have multiple devices +CREATE TABLE IF NOT EXISTS devices ( + id CHAR(36) NOT NULL PRIMARY KEY, + user_id CHAR(36) NOT NULL, + device_name VARCHAR(255) DEFAULT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_seen_at DATETIME DEFAULT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_devices_user (user_id) +) ENGINE=InnoDB; + +-- Signed Pre-Keys (X25519, signed by Ed25519 identity key) — per device +CREATE TABLE IF NOT EXISTS signed_prekeys ( + id CHAR(36) NOT NULL PRIMARY KEY, + user_id CHAR(36) NOT NULL, + device_id CHAR(36) DEFAULT NULL, + public_key BLOB NOT NULL, + signature BLOB NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_spk_user_device (user_id, device_id) +) ENGINE=InnoDB; + +-- One-Time Pre-Keys (consumed on use) — per device +CREATE TABLE IF NOT EXISTS one_time_prekeys ( + id CHAR(36) NOT NULL PRIMARY KEY, + user_id CHAR(36) NOT NULL, + device_id CHAR(36) DEFAULT NULL, + public_key BLOB NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_opk_user_device (user_id, device_id) +) ENGINE=InnoDB; + +-- Conversations +CREATE TABLE IF NOT EXISTS conversations ( + id CHAR(36) NOT NULL PRIMARY KEY, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + name VARCHAR(255) DEFAULT NULL, + created_by CHAR(36) DEFAULT NULL, + avatar_file VARCHAR(255) DEFAULT NULL +) ENGINE=InnoDB; + +CREATE TABLE IF NOT EXISTS conversation_members ( + conversation_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + joined_at DATETIME NULL, + PRIMARY KEY (conversation_id, user_id), + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_cm_user (user_id) +) ENGINE=InnoDB; + +-- Group invitations (pending invitations to join a group) +CREATE TABLE IF NOT EXISTS group_invitations ( + id CHAR(36) NOT NULL PRIMARY KEY, + conversation_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + invited_by CHAR(36) NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uq_conv_user (conversation_id, user_id), + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + FOREIGN KEY (invited_by) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_inv_user (user_id) +) ENGINE=InnoDB; + +-- Messages: per-recipient ciphertext (Double Ratchet = each recipient has different ciphertext) +CREATE TABLE IF NOT EXISTS messages ( + id CHAR(36) NOT NULL PRIMARY KEY, + conversation_id CHAR(36) NOT NULL, + sender_id CHAR(36) NOT NULL, + sender_device_id CHAR(36) DEFAULT NULL, + ratchet_header BLOB NOT NULL, + x3dh_header BLOB DEFAULT NULL, + sender_chain_id BLOB DEFAULT NULL, + sender_chain_n INT DEFAULT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted_at DATETIME DEFAULT NULL, + image_file_id CHAR(36) DEFAULT NULL, + pinned_at DATETIME DEFAULT NULL, + pinned_by CHAR(36) DEFAULT NULL, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (sender_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_messages_conv_created (conversation_id, created_at), + INDEX idx_messages_deleted (conversation_id, deleted_at), + INDEX idx_messages_pinned (conversation_id, pinned_at) +) ENGINE=InnoDB; + +-- Per-recipient encrypted content — per device +-- device_id '00000000-0000-0000-0000-000000000000' = self-encrypted / legacy +CREATE TABLE IF NOT EXISTS message_recipients ( + message_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + device_id CHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000000', + encrypted_content BLOB NOT NULL, + nonce BLOB NOT NULL, + ratchet_header BLOB DEFAULT NULL, + x3dh_header BLOB DEFAULT NULL, + PRIMARY KEY (message_id, user_id, device_id), + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Sender Keys for groups (distributed via pairwise ratchet) — per device +CREATE TABLE IF NOT EXISTS group_sender_keys ( + conversation_id CHAR(36) NOT NULL, + sender_id CHAR(36) NOT NULL, + device_id CHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000000', + chain_id BLOB NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (conversation_id, sender_id, device_id), + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (sender_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Read receipts +CREATE TABLE IF NOT EXISTS message_reads ( + message_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + read_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (message_id, user_id), + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_reads_user (user_id), + INDEX idx_reads_read_at (read_at) +) ENGINE=InnoDB; + +-- Delivery receipts +CREATE TABLE IF NOT EXISTS message_deliveries ( + message_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + delivered_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (message_id, user_id), + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- User profiles +CREATE TABLE IF NOT EXISTS user_profiles ( + user_id CHAR(36) NOT NULL PRIMARY KEY, + phone VARCHAR(50) DEFAULT NULL, + phone_visible TINYINT(1) NOT NULL DEFAULT 0, + email_visible TINYINT(1) NOT NULL DEFAULT 1, + location VARCHAR(255) DEFAULT NULL, + location_visible TINYINT(1) NOT NULL DEFAULT 0, + avatar_file VARCHAR(255) DEFAULT NULL, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Message reactions (emoji reactions on messages) +CREATE TABLE IF NOT EXISTS message_reactions ( + id CHAR(36) NOT NULL PRIMARY KEY, + message_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + reaction VARCHAR(32) NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uq_reaction (message_id, user_id), + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_reactions_created_at (created_at) +) ENGINE=InnoDB; + +-- Image uploads +CREATE TABLE IF NOT EXISTS image_uploads ( + file_id CHAR(36) NOT NULL PRIMARY KEY, + conversation_id CHAR(36) NOT NULL, + uploader_id CHAR(36) NOT NULL, + file_size BIGINT NOT NULL DEFAULT 0, + completed BOOLEAN NOT NULL DEFAULT FALSE, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (uploader_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; diff --git a/server.py b/server.py new file mode 100644 index 0000000..75a7c3f --- /dev/null +++ b/server.py @@ -0,0 +1,2933 @@ +"""Asyncio TCP server — stores and relays encrypted blobs without seeing content.""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +import hashlib +import hmac +import ipaddress +import json +import logging +import os +import re +import secrets +import signal +import smtplib +import ssl +import subprocess +import sys +from email.mime.text import MIMEText +from pathlib import Path +from datetime import datetime, timezone + +from dotenv import load_dotenv + +load_dotenv() + +import db +from crypto_utils import load_public_key, rsa_verify, load_ed25519_public, ed25519_verify, serialize_x25519_public +from protocol import VERSION, MIN_CLIENT_VERSION, version_gte, ProtocolReader, ProtocolWriter, encode_binary, decode_binary, MAX_MESSAGE_BYTES, MAX_IMAGE_BYTES, MAX_FILE_BYTES, IMAGE_CHUNK_SIZE + + +class _AsyncDB: + """Async proxy — offloads every synchronous db.* call to a thread via asyncio.to_thread(). + + This prevents blocking the asyncio event loop during MySQL I/O. + Wrapper functions are cached after first access for efficiency. + """ + + def __getattr__(self, name: str): + func = getattr(db, name) + + async def wrapper(*args, **kwargs): + return await asyncio.to_thread(func, *args, **kwargs) + + wrapper.__name__ = name + wrapper.__qualname__ = f"_AsyncDB.{name}" + setattr(self, name, wrapper) + return wrapper + + +adb = _AsyncDB() + + +# Connected clients: user_id -> list[ProtocolWriter] +connected_clients: dict[str, list[ProtocolWriter]] = {} +# Writer -> device_id mapping (id(writer) -> device_id) +writer_device_map: dict[int, str] = {} +# Pairing sessions: code -> data +pairing_sessions: dict[str, dict] = {} +pending_registrations: dict[str, dict] = {} +# Used PoW challenges (prevents replay within validity window) +_used_pow_challenges: dict[str, float] = {} # challenge -> used_at +# Pending image uploads: file_id -> {temp_path, received_bytes, file_size, conv_id} +pending_uploads: dict[str, dict] = {} +# Phantom user IDs (loaded at startup, updated on create/delete) +phantom_user_ids: set[str] = set() + +# Locks for shared mutable state (H4 race condition fix) +_clients_lock = asyncio.Lock() # Protects: connected_clients, writer_device_map, phantom_user_ids +_conn_lock = asyncio.Lock() # Protects: connection_counts, current_connections, rate_limits +_pairing_lock = asyncio.Lock() # Protects: pairing_sessions, pending_registrations, _used_pow_challenges +_uploads_lock = asyncio.Lock() # Protects: pending_uploads +_phantom_lock = asyncio.Lock() # Serializes phantom user creation (cap check + DB create + set add) + +UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "uploads")) + + +def _secure_delete(p: Path): + """Overwrite file with random data before deletion (anti-forensic wipe).""" + try: + if not p.exists(): + return + size = p.stat().st_size + if size > 0: + with open(p, "r+b") as f: + f.write(os.urandom(size)) + f.flush() + os.fsync(f.fileno()) + p.unlink() + except Exception: + try: + p.unlink(missing_ok=True) + except Exception: + pass + + +# C6 fix: UUID validation + safe path construction to prevent path traversal +_UUID_RE = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE) + + +def _valid_uuid(value: str) -> bool: + """Validate that value is a canonical UUID (no path components).""" + return bool(_UUID_RE.match(value)) + + +# L8 fix: email validation to prevent phantom DB inflation +_EMAIL_RE = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") + + +def _valid_email(email: str) -> bool: + """Validate basic email format (L8).""" + return bool(_EMAIL_RE.match(email)) and len(email) <= 254 + + +# C2 fix: ratchet/x3dh header validation +_RATCHET_HEADER_KEYS = {"dh_pub", "n", "pn"} +_MAX_HEADER_BYTES = 4096 + + +def _validate_header(raw, name: str) -> bytes | None: + """Validate and serialize a ratchet/x3dh header. + + Accepts only dict with expected keys, rejects str/bytes to prevent + poisoned headers from being stored. Validates that ratchet headers + contain the required keys (dh_pub, n, pn) with correct types. + Returns UTF-8 encoded JSON bytes or None if invalid. + """ + if not isinstance(raw, dict): + return None + serialized = json.dumps(raw) + if len(serialized) > _MAX_HEADER_BYTES: + return None + # Validate ratchet header keys/types if this looks like one + if name in ("ratchet_header", "recipient_ratchet_header"): + # Accept self-encrypted marker {"self": true} + if raw.get("self") is True and len(raw) == 1: + return serialized.encode() + if not _RATCHET_HEADER_KEYS.issubset(raw.keys()): + return None + if not isinstance(raw.get("dh_pub"), str): + return None + if type(raw.get("n")) is not int or type(raw.get("pn")) is not int: + return None + return serialized.encode() + + +def _append_file(path: Path, data: bytes): + """Append data to file (runs in thread pool to avoid blocking event loop).""" + with open(path, "ab") as f: + f.write(data) + + +def _read_file_chunk(path: Path, offset: int, size: int) -> bytes: + """Read a chunk from file (runs in thread pool to avoid blocking event loop).""" + with open(path, "rb") as f: + f.seek(offset) + return f.read(size) + + +def _safe_upload_path(file_id: str, suffix: str) -> Path | None: + """Return resolved path inside UPLOAD_DIR, or None if traversal detected.""" + p = (UPLOAD_DIR / f"{file_id}{suffix}").resolve() + if not p.is_relative_to(UPLOAD_DIR.resolve()): + return None + return p + + +def _safe_avatar_path(filename: str) -> Path | None: + """Return resolved avatar path inside UPLOAD_DIR/avatars, or None if traversal detected.""" + avatar_dir = (UPLOAD_DIR / "avatars").resolve() + p = (UPLOAD_DIR / "avatars" / filename).resolve() + if not p.is_relative_to(avatar_dir): + return None + return p + + +PAIRING_TTL_SECONDS = 120 +REGISTER_TTL_SECONDS = 600 # 10 min (was 3600) — faster slot release under load +PAIRING_MAX_POLL_ATTEMPTS = 90 +PAIRING_MAX_SESSIONS = 100 # global cap on concurrent pairing sessions +MAX_PENDING_REGISTRATIONS = 1000 # global cap on pending registration codes +MAX_PENDING_PER_IP = 5 # per-IP cap on pending registrations +MAX_PENDING_PER_SUBNET = 20 # per-/24 (IPv4) or /64 (IPv6) cap +REGISTRATION_PRESSURE_THRESHOLD = 0.8 # 80% → tighten limits + require PoW +POW_DIFFICULTY = 20 # leading zero bits in SHA-256 (~1M hashes, ~0.5-2s) +SMTP_RATE_GLOBAL = 30 # registration emails per minute (global) +SMTP_RATE_PER_IP = 3 # registration emails per minute (per IP) +SMTP_RATE_PER_TARGET = 2 # registration emails per minute (per target email) +MAX_PHANTOM_USERS = 500 # global cap on phantom user count +MAX_UPLOADS_GLOBAL = 200 # global cap on concurrent in-flight uploads +MAX_UPLOADS_PER_USER = 5 # per-user cap on concurrent in-flight uploads +UPLOAD_STALE_SECONDS = 600 # stale upload threshold (10 min) + +# SMTP configuration for registration codes +SMTP_HOST = os.getenv("SMTP_HOST", "") +SMTP_PORT = int(os.getenv("SMTP_PORT", "587")) +SMTP_USER = os.getenv("SMTP_USER", "") +SMTP_PASS = os.getenv("SMTP_PASS", "") +SMTP_FROM = os.getenv("SMTP_FROM", "") +RATE_LIMIT_WINDOW = 60.0 # seconds +CONNECTION_RL_WINDOW = 1.0 # seconds +CONNECTION_RL_MAX = 20 # max requests per window per connection +MAX_CONNECTIONS_PER_IP = 10 +MAX_CONNECTIONS_GLOBAL = 200 +METADATA_RETENTION_DAYS = int(os.getenv("METADATA_RETENTION_DAYS", "90")) + + +def setup_logging(): + level_name = os.getenv("LOG_LEVEL", "INFO").upper() + level = getattr(logging, level_name, logging.WARNING) + logging.basicConfig(level=level, format="%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S") + + +logger = logging.getLogger("encrypted_chat.server") + + +def _who(session: dict | None) -> str: + """Format session info for logging: truncated user_id + device prefix. + + Avoids leaking usernames and emails into log files. + """ + if not session: + return "" + uid = session.get("user_id", "?")[:8] + dev = session.get("device_id", "")[:8] if session.get("device_id") else "" + return f"u={uid} d={dev}" if dev else f"u={uid}" + + +rate_limits: dict[str, list[float]] = {} +connection_counts: dict[str, int] = {} +current_connections = 0 + + +def _rate_limit_key(action: str, addr: str, email: str | None = None) -> str: + if email: + return f"{action}|{addr}|{email.lower()}" + return f"{action}|{addr}" + + +async def _is_rate_limited(key: str, limit: int) -> bool: + async with _conn_lock: + now = asyncio.get_event_loop().time() + window_start = now - RATE_LIMIT_WINDOW + times = rate_limits.get(key, []) + times = [t for t in times if t >= window_start] + if len(times) >= limit: + rate_limits[key] = times + return True + times.append(now) + rate_limits[key] = times + return False + + +async def _create_phantom_guarded(email: str, addr: str, user_id: str) -> tuple[dict | None, str]: + """Check limits + create phantom user atomically (serialized via _phantom_lock). + + Returns (user_dict, error_message). user_dict is None on rejection. + """ + # Rate limit checks outside _phantom_lock (they acquire _conn_lock) + if await _is_rate_limited(f"phantom_create|{user_id}", 10): + return None, "Too many new contacts. Try later." + if await _is_rate_limited(f"phantom_create_ip|{addr}", 10): + return None, "Too many new contacts. Try later." + async with _phantom_lock: + async with _clients_lock: + phantom_count = len(phantom_user_ids) + if phantom_count >= MAX_PHANTOM_USERS: + return None, "Server limit reached. Try later." + u = await adb.create_phantom_user(email) + async with _clients_lock: + phantom_user_ids.add(u["id"]) + return u, "" + + +def _get_peer_addr(writer: ProtocolWriter) -> str: + try: + return str(writer._writer.get_extra_info("peername")[0]) + except Exception: + return "unknown" + + +async def _notify_users(user_ids, msg_type, data, exclude_writer=None): + """Snapshot writers under lock, send notifications outside lock.""" + targets = [] + async with _clients_lock: + for uid in user_ids: + for w in connected_clients.get(uid, []): + targets.append(w) + for w in targets: + if w is exclude_writer: + continue + try: + await w.send_response(msg_type, "ok", data) + except Exception: + pass + + +async def _notify_users_individual(notifications, exclude_writer=None): + """Send per-user data. notifications: list of (user_id, msg_type, data).""" + targets = [] + async with _clients_lock: + for uid, mt, d in notifications: + for w in connected_clients.get(uid, []): + targets.append((w, mt, d)) + for w, mt, d in targets: + if w is exclude_writer: + continue + try: + await w.send_response(mt, "ok", d) + except Exception: + pass + + +async def _cleanup_pairings(): + async with _pairing_lock: + now = asyncio.get_event_loop().time() + expired = [code for code, p in pairing_sessions.items() if now - p["created_at"] > PAIRING_TTL_SECONDS] + for code in expired: + pairing_sessions.pop(code, None) + + +async def _cleanup_registrations(): + async with _pairing_lock: + now = asyncio.get_event_loop().time() + expired = [code for code, p in pending_registrations.items() if now - p["created_at"] > REGISTER_TTL_SECONDS] + for code in expired: + pending_registrations.pop(code, None) + # Purge used PoW challenges older than 120s (validity window) + stale = [ch for ch, ts in _used_pow_challenges.items() if now - ts > 120] + for ch in stale: + _used_pow_challenges.pop(ch, None) + + +def _generate_pairing_code() -> str: + for _ in range(10): + code = f"{int.from_bytes(os.urandom(4), 'big') % 100000000:08d}" + if code not in pairing_sessions: + return code + return f"{int.from_bytes(os.urandom(4), 'big') % 100000000:08d}" + + +def _generate_register_code() -> str: + for _ in range(10): + code = f"{int.from_bytes(os.urandom(3), 'big') % 1000000:06d}" + if code not in pending_registrations: + return code + return f"{int.from_bytes(os.urandom(3), 'big') % 1000000:06d}" + +def _validate_public_key_pem(pem_str: str) -> bool: + """Validate that a string is a valid RSA public key PEM.""" + try: + key = load_public_key(pem_str.encode("utf-8")) + if key.key_size < 2048: + return False + return True + except Exception: + return False + + +def _send_registration_email(to_email: str, code: str) -> bool: + """Send registration code via SMTP. Returns True on success.""" + if not SMTP_HOST: + return False + try: + msg = MIMEText(f"Your registration code is: {code}\n\nThis code expires in 10 minutes.") + msg["Subject"] = "Encrypted Chat - Registration Code" + msg["From"] = SMTP_FROM or SMTP_USER + msg["To"] = to_email + with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=10) as server: + server.starttls() + if SMTP_USER: + server.login(SMTP_USER, SMTP_PASS) + server.send_message(msg) + return True + except Exception as e: + logger.warning("Failed to send registration email: %s", e) + return False + + +async def send_resp(msg: dict, writer: ProtocolWriter, msg_type: str, status: str, data: dict | None = None): + await writer.send_response(msg_type, status, data, request_id=msg.get("request_id")) + + +# --- Registration admission control --- + +_POW_SECRET = os.urandom(32) # per-process; restarts invalidate outstanding challenges + + +def _get_subnet(addr: str) -> str: + """Extract /24 for IPv4, /64 for IPv6.""" + try: + ip = ipaddress.ip_address(addr) + if ip.version == 4: + return str(ipaddress.ip_network(f"{ip}/24", strict=False)) + return str(ipaddress.ip_network(f"{ip}/64", strict=False)) + except ValueError: + return addr + + +def _pending_counts_by_origin(addr: str) -> tuple[int, int]: + """Count pending registrations by IP and subnet. Caller must hold _pairing_lock.""" + subnet = _get_subnet(addr) + ip_count = 0 + subnet_count = 0 + for p in pending_registrations.values(): + p_addr = p.get("addr", "") + if p_addr == addr: + ip_count += 1 + if _get_subnet(p_addr) == subnet: + subnet_count += 1 + return ip_count, subnet_count + + +def _generate_pow_challenge() -> tuple[str, str]: + """Generate a stateless PoW challenge (challenge, mac). + + The challenge embeds a timestamp so the server can reject stale solutions. + The HMAC proves the challenge was issued by this server instance. + """ + ts = str(int(asyncio.get_event_loop().time())) + nonce = secrets.token_hex(16) + challenge = f"{ts}:{nonce}" + mac = hmac.new(_POW_SECRET, challenge.encode(), hashlib.sha256).hexdigest() + return challenge, mac + + +def _verify_pow(challenge: str, mac: str, nonce: str, difficulty: int) -> bool: + """Verify a PoW solution: HMAC authentic, timestamp fresh, hash has leading zeros.""" + # Verify HMAC + expected = hmac.new(_POW_SECRET, challenge.encode(), hashlib.sha256).hexdigest() + if not hmac.compare_digest(expected, mac): + return False + # Check timestamp freshness (120s window) + try: + ts = int(challenge.split(":")[0]) + except (ValueError, IndexError): + return False + now = int(asyncio.get_event_loop().time()) + if abs(now - ts) > 120: + return False + # Verify PoW: SHA-256(challenge + nonce) must have `difficulty` leading zero bits + digest = hashlib.sha256(f"{challenge}{nonce}".encode()).digest() + # Check leading zero bits + bits_needed = difficulty + for byte in digest: + if bits_needed <= 0: + break + if bits_needed >= 8: + if byte != 0: + return False + bits_needed -= 8 + else: + mask = (0xFF << (8 - bits_needed)) & 0xFF + if byte & mask: + return False + bits_needed = 0 + return True + + +async def handle_register_start(msg: dict, writer: ProtocolWriter) -> dict | None: + await _cleanup_registrations() + username = msg.get("username", "").strip() + public_key = msg.get("public_key", "").strip() + identity_key_b64 = msg.get("identity_key", "").strip() + email = msg.get("email", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("register_start", addr, email), 3): + await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."}) + return None + # Per-IP limit (regardless of email) to prevent SMTP spam via email rotation + if await _is_rate_limited(f"register_start_ip|{addr}", 6): + await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."}) + return None + if not username or not public_key or not email or not identity_key_b64: + await send_resp(msg, writer, "register_start", "error", {"message": "Missing fields"}) + return None + if not _validate_public_key_pem(public_key): + await send_resp(msg, writer, "register_start", "error", {"message": "Invalid public key format"}) + return None + # Validate identity key is 32 bytes + try: + ik_bytes = decode_binary(identity_key_b64) + if len(ik_bytes) != 32: + raise ValueError("Identity key must be 32 bytes") + load_ed25519_public(ik_bytes) + except Exception: + await send_resp(msg, writer, "register_start", "error", {"message": "Invalid identity key"}) + return None + existing_email = await adb.get_user_by_email(email) + phantom_id = None + is_existing_real_user = False + if existing_email: + if existing_email.get("rsa_public_key") == "PHANTOM": + phantom_id = existing_email["id"] + else: + is_existing_real_user = True + # --- Admission control (all checks under lock, I/O outside) --- + # Existing-email goes through the same path so responses are + # indistinguishable from new-email (H3 anti-enumeration). + # Both allocate a slot so per-IP/subnet cap counting is identical. + async with _pairing_lock: + total = len(pending_registrations) + # Hard cap + if total >= MAX_PENDING_REGISTRATIONS: + reject_reason = "cap" + else: + # Per-IP / per-subnet slot limits + ip_count, subnet_count = _pending_counts_by_origin(addr) + if ip_count >= MAX_PENDING_PER_IP: + reject_reason = "ip" + elif subnet_count >= MAX_PENDING_PER_SUBNET: + reject_reason = "subnet" + else: + reject_reason = None + # Pressure mode: require PoW when >80% full + under_pressure = total >= MAX_PENDING_REGISTRATIONS * REGISTRATION_PRESSURE_THRESHOLD + need_pow = under_pressure and reject_reason is None + # If PoW required, verify the client's solution (one-time use) + pow_ok = False + if need_pow: + pow_challenge = msg.get("pow_challenge", "") + pow_mac = msg.get("pow_mac", "") + pow_nonce = msg.get("pow_nonce", "") + if pow_challenge and pow_mac and pow_nonce: + if pow_challenge in _used_pow_challenges: + pow_ok = False # replay + elif _verify_pow(pow_challenge, pow_mac, pow_nonce, POW_DIFFICULTY): + _used_pow_challenges[pow_challenge] = asyncio.get_event_loop().time() + pow_ok = True + # Decide: admit, challenge, or reject + if reject_reason: + admit = False + send_challenge = False + code = None + elif need_pow and not pow_ok: + admit = False + send_challenge = True + code = None + else: + # Both existing and new emails allocate a slot so per-IP/subnet + # counting behaves identically (anti-enumeration via slot side-channel). + # Existing-email slots are inert — register_confirm silently fails. + admit = True + send_challenge = False + code = _generate_register_code() + pending_registrations[code] = { + "username": username, + "public_key": public_key, + "identity_key": ik_bytes, + "email": email, + "created_at": asyncio.get_event_loop().time(), + "phantom_id": phantom_id, + "addr": addr, + "fake": is_existing_real_user, + } + # --- I/O outside lock --- + if not admit: + if send_challenge: + challenge, mac = _generate_pow_challenge() + await send_resp(msg, writer, "register_start", "pow_required", { + "challenge": challenge, "mac": mac, "difficulty": POW_DIFFICULTY, + }) + else: + await send_resp(msg, writer, "register_start", "error", {"message": "Server busy. Try later."}) + return None + logger.info("[REGISTER] registration started") + is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") + # SMTP rate limiting + smtp_blocked = False + if SMTP_HOST: + if await _is_rate_limited("smtp_send|global", SMTP_RATE_GLOBAL): + smtp_blocked = True + elif await _is_rate_limited(f"smtp_send_ip|{addr}", SMTP_RATE_PER_IP): + smtp_blocked = True + elif await _is_rate_limited(f"smtp_send_target|{email.lower()}", SMTP_RATE_PER_TARGET): + smtp_blocked = True + if smtp_blocked: + if is_dev: + logger.warning("[REGISTER] SMTP rate limit hit — returning code (dev mode)") + await send_resp(msg, writer, "register_start", "ok", {"code": code}) + else: + logger.warning("[REGISTER] SMTP rate limit hit — revoking slot silently") + async with _pairing_lock: + pending_registrations.pop(code, None) + await send_resp(msg, writer, "register_start", "ok", + {"message": "Code sent to your email."}) + return None + # Send registration email in a thread (non-blocking) for both real + # and fake registrations. For existing emails we still call SMTP so + # the response timing is indistinguishable (anti-enumeration). + # The email goes to the real address either way — existing users just + # won't be able to confirm (code is for a fake slot). + email_sent = await asyncio.to_thread(_send_registration_email, email, code) + if email_sent: + await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."}) + elif is_dev: + logger.warning("[REGISTER] No SMTP / send failed — returning code (dev mode)") + await send_resp(msg, writer, "register_start", "ok", {"code": code}) + else: + logger.warning("[REGISTER] SMTP send failed — revoking slot silently") + async with _pairing_lock: + pending_registrations.pop(code, None) + await send_resp(msg, writer, "register_start", "ok", + {"message": "Code sent to your email."}) + return None + + +async def handle_register_confirm(msg: dict, writer: ProtocolWriter) -> dict | None: + await _cleanup_registrations() + email = msg.get("email", "").strip() + code = msg.get("code", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("register_confirm", addr, email), 3): + await send_resp(msg, writer, "register_confirm", "error", {"message": "Too many attempts. Try later."}) + return None + if not email or not code: + await send_resp(msg, writer, "register_confirm", "error", {"message": "Missing email or code"}) + return None + async with _pairing_lock: + pending = pending_registrations.get(code) + if pending and pending.get("email") == email: + pending_registrations.pop(code, None) + else: + pending = None + if not pending: + await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"}) + return None + # H3 anti-enumeration: fake slot (existing email) — reject with same + # generic message so attacker can't distinguish from a wrong code + if pending.get("fake"): + await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"}) + return None + phantom_id = pending.get("phantom_id") + if phantom_id: + # Upgrade phantom in-place — preserves FK references (invitations, memberships) + user_id = await adb.upgrade_phantom_user( + phantom_id, + pending["username"], + pending["public_key"], + pending["identity_key"], + ) + if user_id: + async with _clients_lock: + phantom_user_ids.discard(phantom_id) + else: + # Phantom was deleted concurrently — fall back to normal create + user_id = await adb.create_user( + pending["username"], + pending["email"], + pending["public_key"], + pending["identity_key"], + ) + else: + user_id = await adb.create_user( + pending["username"], + pending["email"], + pending["public_key"], + pending["identity_key"], + ) + await adb.create_default_profile(user_id) + logger.info("[REGISTER] confirmed (user_id=%s)", user_id[:8]) + await send_resp(msg, writer, "register_confirm", "ok", {"user_id": user_id}) + return None + + +async def handle_login_start(msg: dict, writer: ProtocolWriter, state: dict): + email = msg.get("email", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("login_start", addr, email), 10): + await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."}) + return + if await _is_rate_limited(f"login_start_ip|{addr}", 20): + await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."}) + return + if not email: + await send_resp(msg, writer, "login_start", "error", {"message": "Missing email"}) + return + user = await adb.get_user_by_email(email) + challenge = os.urandom(32) + state["login_email"] = email + state["login_challenge"] = challenge + if not user: + # H3 anti-enumeration: return a fake challenge so attacker can't distinguish + # "user not found" from "user exists". login_finish will fail with generic error. + state["_login_fake"] = True + await send_resp(msg, writer, "login_start", "ok", {"challenge": encode_binary(challenge)}) + + +async def handle_login_finish(msg: dict, writer: ProtocolWriter, state: dict) -> dict | None: + email = msg.get("email", "").strip() + signature_b64 = msg.get("signature", "") + challenge = state.get("login_challenge") + expected_email = state.get("login_email") + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("login_finish", addr, email), 10): + await send_resp(msg, writer, "login_finish", "error", {"message": "Too many attempts. Try later."}) + return None + if not email or not signature_b64: + await send_resp(msg, writer, "login_finish", "error", {"message": "Missing email or signature"}) + return None + if not challenge or expected_email != email: + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + + # H3: if login_start was for a non-existent user, fail with generic error + is_fake = state.pop("_login_fake", False) + + try: + if is_fake: + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + + user = await adb.get_user_by_email(email) + if not user: + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + + public_key = load_public_key(user["rsa_public_key"].encode("utf-8")) + signature = decode_binary(signature_b64) + if not rsa_verify(public_key, signature, challenge): + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + except ValueError: + # H5: invalid base64 in signature + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + finally: + state.pop("login_challenge", None) + state.pop("login_email", None) + + user_id = user["id"] + + # Version check: reject outdated clients + client_version = msg.get("client_version", "") + if client_version and not version_gte(client_version, MIN_CLIENT_VERSION): + await send_resp(msg, writer, "login_finish", "error", { + "message": f"Client version {client_version} is too old. Minimum required: {MIN_CLIENT_VERSION}", + "min_version": MIN_CLIENT_VERSION, + "server_version": VERSION, + }) + return None + + # Device registration: client may send device_id to reuse an existing device + client_device_id = msg.get("device_id") + device_id = None + if client_device_id: + dev = await adb.get_device(client_device_id) + if dev and dev["user_id"] == user_id: + device_id = client_device_id + if not device_id: + device_name = msg.get("device_name", "Unknown") + device_id = await adb.create_device(user_id, device_name) + await adb.update_device_last_seen(device_id) + + async with _clients_lock: + if user_id not in connected_clients: + connected_clients[user_id] = [] + connected_clients[user_id].append(writer) + writer_device_map[id(writer)] = device_id + logger.info("[LOGIN] u=%s d=%s client_v=%s", + user_id[:8], device_id[:8] if device_id else "?", client_version or "unknown") + await send_resp(msg, writer, "login_finish", "ok", { + "user_id": user_id, "username": user["username"], "email": user["email"], + "device_id": device_id, "server_version": VERSION, + }) + + # Send online status notifications + contacts = await adb.get_user_contacts(user_id) + online_targets = [] + async with _clients_lock: + online_contacts = [cid for cid in contacts if cid in connected_clients and connected_clients[cid]] + # Always notify contacts (handles reconnect where old writer is still lingering) + for contact_id in contacts: + for cw in connected_clients.get(contact_id, []): + online_targets.append(cw) + await writer.send_response("online_users", "ok", {"user_ids": online_contacts}) + # Send online notifications outside lock + for cw in online_targets: + try: + await cw.send_response("user_online", "ok", {"user_id": user_id}) + except Exception: + pass + + return {"user_id": user_id, "username": user["username"], "email": user["email"], + "device_id": device_id} + + +async def handle_get_user_info(msg: dict, session: dict, writer: ProtocolWriter): + """Get user info including identity key (for X3DH). Requires login.""" + email = msg.get("email", "").strip() + user_id = msg.get("user_id", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("get_user_info", addr, email or user_id), 30): + await send_resp(msg, writer, "get_user_info", "error", {"message": "Too many attempts. Try later."}) + return + if user_id and not _valid_uuid(user_id): + await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) + return + user = None + if email: + user = await adb.get_user_by_email(email) + elif user_id: + user = await adb.get_user_by_id(user_id) + if not user: + await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) + return + # H4 fix: restrict lookups to self or contacts (shared conversation) + target_id = user["id"] + if target_id != session["user_id"]: + if not await adb.shares_conversation(session["user_id"], target_id): + await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) + return + ik = user.get("identity_key") + await send_resp(msg, writer, "get_user_info", "ok", { + "user_id": user["id"], + "username": user["username"], + "email": user["email"], + "identity_key": encode_binary(ik) if ik else "", + }) + + +async def handle_upload_prekeys(msg: dict, session: dict, writer: ProtocolWriter): + """Upload signed prekey + batch of one-time prekeys.""" + if await _is_rate_limited(f"upload_prekeys|{session['user_id']}", 5): + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Too many requests. Try later."}) + return + spk_data = msg.get("signed_prekey") + otps = msg.get("one_time_prekeys", []) + if not spk_data: + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Missing signed_prekey"}) + return + + spk_id = spk_data.get("id", "") + spk_pub_b64 = spk_data.get("public_key", "") + spk_sig_b64 = spk_data.get("signature", "") + if not spk_id or not spk_pub_b64 or not spk_sig_b64: + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Incomplete signed_prekey"}) + return + + spk_pub = decode_binary(spk_pub_b64) + spk_sig = decode_binary(spk_sig_b64) + + # Verify SPK signature with user's identity key + user = await adb.get_user_by_id(session["user_id"]) + if not user or not user.get("identity_key"): + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "No identity key"}) + return + ik_pub = load_ed25519_public(user["identity_key"]) + if not ed25519_verify(ik_pub, spk_sig, spk_pub): + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid SPK signature"}) + return + + device_id = session.get("device_id") + await adb.store_signed_prekey(session["user_id"], spk_id, spk_pub, spk_sig, device_id=device_id) + + # Store OTPs + otp_records = [] + for otp in otps: + otp_id = otp.get("id", "") + otp_pub_b64 = otp.get("public_key", "") + if otp_id and otp_pub_b64: + otp_records.append({"id": otp_id, "public_key": decode_binary(otp_pub_b64)}) + if otp_records: + await adb.store_one_time_prekeys(session["user_id"], otp_records, device_id=device_id) + + logger.info("[PREKEYS] %s uploaded 1 SPK + %d OTPs", _who(session), len(otp_records)) + await send_resp(msg, writer, "upload_prekeys", "ok", {"message": "OK"}) + + +async def handle_get_key_bundle(msg: dict, session: dict, writer: ProtocolWriter): + """Fetch key bundle for X3DH. Returns per-device bundles. Consumes one OTP per device.""" + target_user_id = msg.get("user_id", "").strip() + if not target_user_id: + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Missing user_id"}) + return + if not _valid_uuid(target_user_id): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Invalid user_id"}) + return + # M4: rate limit + authorization (prevents OPK depletion) + if await _is_rate_limited(f"get_key_bundle|{session['user_id']}", 10): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Too many requests. Try later."}) + return + # Auth check before per-target rate limit so unauthorized requests don't burn target's bucket + if target_user_id != session["user_id"]: + if not await adb.shares_conversation(session["user_id"], target_user_id): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Key bundle not available"}) + return + if await _is_rate_limited(f"get_key_bundle_target|{target_user_id}", 20): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Too many requests. Try later."}) + return + result = await adb.get_key_bundles_for_user(target_user_id) + if not result or not result.get("device_bundles"): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Key bundle not available"}) + return + + device_bundles_data = [] + for b in result["device_bundles"]: + entry = { + "device_id": b.get("device_id"), + "signed_prekey_id": b["signed_prekey_id"], + "signed_prekey": encode_binary(b["signed_prekey_pub"]), + "spk_signature": encode_binary(b["spk_signature"]), + } + if b.get("opk_pub"): + entry["one_time_prekey_id"] = b["opk_id"] + entry["one_time_prekey"] = encode_binary(b["opk_pub"]) + device_bundles_data.append(entry) + + # Build response with both new multi-device format and legacy flat fields + first = device_bundles_data[0] if device_bundles_data else {} + data = { + "identity_key": encode_binary(result["identity_key"]), + "device_bundles": device_bundles_data, + # Legacy flat fields from first device bundle (backward compat) + "signed_prekey_id": first.get("signed_prekey_id", ""), + "signed_prekey": first.get("signed_prekey", ""), + "spk_signature": first.get("spk_signature", ""), + } + if first.get("one_time_prekey"): + data["one_time_prekey_id"] = first["one_time_prekey_id"] + data["one_time_prekey"] = first["one_time_prekey"] + logger.info("[X3DH] %s fetched key bundle for user=%s (%d devices)", + _who(session), target_user_id[:8], len(device_bundles_data)) + await send_resp(msg, writer, "get_key_bundle", "ok", data) + + +async def handle_get_prekey_count(msg: dict, session: dict, writer: ProtocolWriter): + """How many OPKs does user have left (for this device)? Also returns SPK age for rotation.""" + device_id = session.get("device_id") + count = await adb.count_one_time_prekeys(session["user_id"], device_id=device_id) + spk_created_at = "" + spk = await adb.get_signed_prekey(session["user_id"], device_id=device_id) + if spk and spk.get("created_at"): + spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) + await send_resp(msg, writer, "get_prekey_count", "ok", + {"count": count, "spk_created_at": spk_created_at}) + + +async def handle_ensure_prekeys(msg: dict, session: dict, writer: ProtocolWriter): + """Combined get_prekey_count + upload_prekeys in one round-trip. + + Client sends current OPK/SPK data; server checks count and SPK age, + stores new keys if provided, and returns the current status. + """ + if await _is_rate_limited(f"ensure_prekeys|{session['user_id']}", 5): + await send_resp(msg, writer, "ensure_prekeys", "error", {"message": "Too many requests. Try later."}) + return + device_id = session.get("device_id") + user_id = session["user_id"] + + # Step 1: Get current count + SPK age + count = await adb.count_one_time_prekeys(user_id, device_id=device_id) + spk_created_at = "" + spk = await adb.get_signed_prekey(user_id, device_id=device_id) + if spk and spk.get("created_at"): + spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) + + # Step 2: If client included new keys, store them + uploaded_spk = False + uploaded_otps = 0 + spk_data = msg.get("signed_prekey") + if spk_data: + spk_id = spk_data.get("id", "") + spk_pub_b64 = spk_data.get("public_key", "") + spk_sig_b64 = spk_data.get("signature", "") + if spk_id and spk_pub_b64 and spk_sig_b64: + spk_pub = decode_binary(spk_pub_b64) + spk_sig = decode_binary(spk_sig_b64) + user = await adb.get_user_by_id(user_id) + if user and user.get("identity_key"): + ik_pub = load_ed25519_public(user["identity_key"]) + if ed25519_verify(ik_pub, spk_sig, spk_pub): + await adb.store_signed_prekey(user_id, spk_id, spk_pub, spk_sig, device_id=device_id) + uploaded_spk = True + + otps = msg.get("one_time_prekeys", []) + if otps: + otp_records = [] + for otp in otps: + otp_id = otp.get("id", "") + otp_pub_b64 = otp.get("public_key", "") + if otp_id and otp_pub_b64: + otp_records.append({"id": otp_id, "public_key": decode_binary(otp_pub_b64)}) + if otp_records: + await adb.store_one_time_prekeys(user_id, otp_records, device_id=device_id) + uploaded_otps = len(otp_records) + + # Recount after upload + if uploaded_spk or uploaded_otps: + count = await adb.count_one_time_prekeys(user_id, device_id=device_id) + spk = await adb.get_signed_prekey(user_id, device_id=device_id) + if spk and spk.get("created_at"): + spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) + logger.info("[PREKEYS] %s ensure_prekeys: uploaded SPK=%s, OTPs=%d, new count=%d", + _who(session), uploaded_spk, uploaded_otps, count) + + await send_resp(msg, writer, "ensure_prekeys", "ok", + {"count": count, "spk_created_at": spk_created_at, + "uploaded_spk": uploaded_spk, "uploaded_otps": uploaded_otps}) + + +async def handle_rotate_keys(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"rotate_keys|{session['user_id']}", 3): + await send_resp(msg, writer, "rotate_keys", "error", {"message": "Too many requests. Try later."}) + return + public_key = msg.get("public_key", "").strip() + if not public_key: + await send_resp(msg, writer, "rotate_keys", "error", {"message": "Missing public_key"}) + return + if not _validate_public_key_pem(public_key): + await send_resp(msg, writer, "rotate_keys", "error", {"message": "Invalid public key format"}) + return + await adb.update_user_rsa_key(session["user_id"], public_key) + logger.info("[ROTATE] %s rotated RSA key", _who(session)) + await send_resp(msg, writer, "rotate_keys", "ok", {"message": "OK"}) + # Disconnect other sessions + async with _clients_lock: + writers = connected_clients.get(session["user_id"], []) + others = [w for w in writers if w is not writer] + connected_clients[session["user_id"]] = [writer] + for w in others: + try: + w.close() + except Exception: + pass + + +async def handle_change_username(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"change_username|{session['user_id']}", 5): + await send_resp(msg, writer, "change_username", "error", {"message": "Too many requests. Try later."}) + return + new_username = msg.get("username", "").strip() + if not new_username or len(new_username) > 100: + await send_resp(msg, writer, "change_username", "error", {"message": "Invalid username (1-100 chars)"}) + return + user_id = session["user_id"] + await adb.update_username(user_id, new_username) + session["username"] = new_username + logger.info("[ACCOUNT] %s changed username", _who(session)) + await send_resp(msg, writer, "change_username", "ok", {"username": new_username}) + # Notify contacts + contacts = await adb.get_user_contacts(user_id) + targets = [] + async with _clients_lock: + for cid in contacts: + for cw in connected_clients.get(cid, []): + targets.append(cw) + for cw in targets: + try: + await cw.send_response("username_changed", "ok", { + "user_id": user_id, "username": new_username, + }) + except Exception: + pass + + +async def handle_pairing_start(msg: dict, writer: ProtocolWriter): + await _cleanup_pairings() + email = msg.get("email", "").strip() + temp_public_key = msg.get("temp_public_key", "").strip() + addr = _get_peer_addr(writer) + # H4 fix: rate limit per IP only (not per email) to prevent enumeration via email rotation + if await _is_rate_limited(_rate_limit_key("pairing_start", addr), 10): + await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."}) + return + if not email or not temp_public_key: + await send_resp(msg, writer, "pairing_start", "error", {"message": "Missing email or temp_public_key"}) + return + poll_token = secrets.token_hex(16) + cap_hit = False + async with _pairing_lock: + # H4 fix: global cap prevents memory exhaustion from dummy sessions + if len(pairing_sessions) >= PAIRING_MAX_SESSIONS: + cap_hit = True + else: + code = _generate_pairing_code() + # H4 fix: always create session (anti-enumeration). For non-existent users + # the session behaves identically (poll returns ready:false, claim never matches + # because no real account can log in to claim it). TTL cleanup handles expiry. + pairing_sessions[code] = { + "email": email, + "temp_public_key": temp_public_key, + "created_at": asyncio.get_event_loop().time(), + "payload": None, + "poll_token": poll_token, + } + if cap_hit: + await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."}) + return + await send_resp(msg, writer, "pairing_start", "ok", {"code": code, "poll_token": poll_token}) + + +async def handle_pairing_claim(msg: dict, session: dict, writer: ProtocolWriter): + await _cleanup_pairings() + code = msg.get("code", "").strip() + if not code: + await send_resp(msg, writer, "pairing_claim", "error", {"message": "Missing code"}) + return + async with _pairing_lock: + p = pairing_sessions.get(code) + p_email = p["email"] if p else None + temp_pub = p["temp_public_key"] if p else None + if p: + # Extend TTL — re-encryption may run between claim and send + p["created_at"] = asyncio.get_event_loop().time() + # H4 fix: unified error message (anti-enumeration) + if not p or p_email != session.get("email"): + await send_resp(msg, writer, "pairing_claim", "error", {"message": "Invalid or expired code"}) + return + await send_resp(msg, writer, "pairing_claim", "ok", {"temp_public_key": temp_pub}) + + +async def handle_pairing_send(msg: dict, session: dict, writer: ProtocolWriter): + await _cleanup_pairings() + code = msg.get("code", "").strip() + payload = msg.get("payload") + if not code or not payload: + await send_resp(msg, writer, "pairing_send", "error", {"message": "Missing code or payload"}) + return + error_msg = None + async with _pairing_lock: + p = pairing_sessions.get(code) + # H4 fix: unified error message (anti-enumeration) + if not p or p["email"] != session.get("email"): + error_msg = "Invalid or expired code" + else: + p["payload"] = payload + if error_msg: + await send_resp(msg, writer, "pairing_send", "error", {"message": error_msg}) + else: + await send_resp(msg, writer, "pairing_send", "ok", {"message": "OK"}) + + +async def handle_pairing_poll(msg: dict, writer: ProtocolWriter): + await _cleanup_pairings() + code = msg.get("code", "").strip() + poll_token = msg.get("poll_token", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("pairing_poll", addr), 120): + await send_resp(msg, writer, "pairing_poll", "error", {"message": "Too many attempts. Try later."}) + return + if not code: + await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing code"}) + return + if not poll_token: + await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing poll_token"}) + return + error_msg = None + ready = False + payload = None + async with _pairing_lock: + p = pairing_sessions.get(code) + if not p: + error_msg = "Invalid or expired code" + elif not secrets.compare_digest(p.get("poll_token", ""), poll_token): + error_msg = "Invalid poll_token" + else: + poll_attempts = p.get("poll_attempts", 0) + 1 + p["poll_attempts"] = poll_attempts + if poll_attempts > PAIRING_MAX_POLL_ATTEMPTS and not p.get("payload"): + pairing_sessions.pop(code, None) + error_msg = "Code invalidated due to too many attempts" + elif p.get("payload"): + ready = True + payload = p["payload"] + pairing_sessions.pop(code, None) + if error_msg: + await send_resp(msg, writer, "pairing_poll", "error", {"message": error_msg}) + elif ready: + await send_resp(msg, writer, "pairing_poll", "ok", {"ready": True, "payload": payload}) + else: + await send_resp(msg, writer, "pairing_poll", "ok", {"ready": False}) + + +async def handle_create_conversation(msg: dict, session: dict, writer: ProtocolWriter): + member_emails = msg.get("members", []) + name = msg.get("name") + addr = _get_peer_addr(writer) + if await _is_rate_limited(f"create_conversation|{session['user_id']}", 10): + await send_resp(msg, writer, "create_conversation", "error", {"message": "Too many attempts. Try later."}) + return + # Resolve all member user IDs + other_users = [] + for email in member_emails: + u = await adb.get_user_by_email(email) + if not u: + if not _valid_email(email): + await send_resp(msg, writer, "create_conversation", "error", {"message": f"Invalid email format: {email}"}) + return + # H5: atomic phantom creation (cap check + DB create + set add) + u, err_msg = await _create_phantom_guarded(email, addr, session["user_id"]) + if u is None: + await send_resp(msg, writer, "create_conversation", "error", {"message": err_msg}) + return + if u["id"] != session["user_id"]: + other_users.append(u) + is_dm = len(other_users) == 1 and not name + joined_at = datetime.now(timezone.utc) + if is_dm: + # DMs: add both members directly (no invitation) + all_ids = [session["user_id"]] + [u["id"] for u in other_users] + conv_id = await adb.create_conversation(all_ids, joined_at=joined_at, name=name, created_by=session["user_id"]) + logger.info("[CONV] %s created DM conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id}) + # Notify the other member + members_info = await adb.get_conversation_members(conv_id) + member_list = [{"user_id": m["id"], "username": m["username"], "email": m["email"]} for m in members_info] + notif_data = { + "conversation_id": conv_id, + "name": name, + "created_by": session["user_id"], + "members": member_list, + } + await _notify_users([u["id"] for u in other_users], "conversation_created", notif_data) + else: + # Groups: only add creator, create invitations for others + conv_id = await adb.create_conversation([session["user_id"]], joined_at=joined_at, name=name, created_by=session["user_id"]) + logger.info("[CONV] %s created group conv=%s", + _who(session), conv_id[:8]) + # Create invitations for other members + creator_user = await adb.get_user_by_id(session["user_id"]) + creator_name = creator_user["username"] if creator_user else "Unknown" + invited_ids = [] + async with _clients_lock: + phantom_snapshot = set(phantom_user_ids) + for u in other_users: + await adb.create_invitation(conv_id, u["id"], session["user_id"]) + if u["id"] not in phantom_snapshot: + invited_ids.append(u["id"]) # only notify non-phantoms + inv_notif = { + "conversation_id": conv_id, + "conversation_name": name, + "invited_by": session["user_id"], + "invited_by_username": creator_name, + } + await _notify_users(invited_ids, "group_invitation", inv_notif) + await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id}) + + +async def handle_find_conversation(msg: dict, session: dict, writer: ProtocolWriter): + email = msg.get("email", "").strip() + if not email: + await send_resp(msg, writer, "find_conversation", "error", {"message": "Invalid request"}) + return + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("find_conversation", addr, email), 30): + await send_resp(msg, writer, "find_conversation", "error", {"message": "Too many attempts. Try later."}) + return + other = await adb.get_user_by_email(email) + if not other: + if not _valid_email(email): + await send_resp(msg, writer, "find_conversation", "error", {"message": "Invalid email format"}) + return + # H5: atomic phantom creation (cap check + DB create + set add) + other, err_msg = await _create_phantom_guarded(email, addr, session["user_id"]) + if other is None: + await send_resp(msg, writer, "find_conversation", "error", {"message": err_msg}) + return + conv_id = await adb.find_direct_conversation(session["user_id"], other["id"]) + await send_resp(msg, writer, "find_conversation", "ok", { + "conversation_id": conv_id, + "user_id": other["id"], + }) + + +async def handle_add_member(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + email = msg.get("email", "").strip() + if not conv_id or not email: + await send_resp(msg, writer, "add_member", "error", {"message": "Invalid request"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "add_member", "error", {"message": "Invalid conversation_id"}) + return + # L8: validate email format before phantom creation + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("add_member", addr, email), 10): + await send_resp(msg, writer, "add_member", "error", {"message": "Too many attempts. Try later."}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "add_member", "error", {"message": "Not a member"}) + return + user = await adb.get_user_by_email(email) + if not user: + # Create phantom for unregistered email (same as create_conversation) + if not _valid_email(email): + await send_resp(msg, writer, "add_member", "error", {"message": "Invalid email format"}) + return + # H5: atomic phantom creation (cap check + DB create + set add) + user, err_msg = await _create_phantom_guarded(email, addr, session["user_id"]) + if user is None: + await send_resp(msg, writer, "add_member", "error", {"message": err_msg}) + return + if await adb.is_conversation_member(conv_id, user["id"]): + await send_resp(msg, writer, "add_member", "error", {"message": "Already a member"}) + return + if await adb.has_pending_invitation(conv_id, user["id"]): + await send_resp(msg, writer, "add_member", "error", {"message": "Invitation already pending"}) + return + # Create invitation (for both real and phantom users) + await adb.create_invitation(conv_id, user["id"], session["user_id"]) + logger.info("[INVITE] %s invited u=%s to conv=%s", _who(session), user["id"][:8], conv_id[:8]) + await send_resp(msg, writer, "add_member", "ok", {"user_id": user["id"]}) + # Push invitation notification only to non-phantom users + async with _clients_lock: + is_phantom = user["id"] in phantom_user_ids + if not is_phantom: + conv = await adb.get_conversation(conv_id) + creator_user = await adb.get_user_by_id(session["user_id"]) + creator_name = creator_user["username"] if creator_user else "Unknown" + inv_notif = { + "conversation_id": conv_id, + "conversation_name": conv.get("name") if conv else None, + "invited_by": session["user_id"], + "invited_by_username": creator_name, + } + await _notify_users([user["id"]], "group_invitation", inv_notif) + + +async def handle_accept_invitation(msg: dict, session: dict, writer: ProtocolWriter): + """Accept a group invitation — add user to conversation members.""" + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "accept_invitation", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "accept_invitation", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.has_pending_invitation(conv_id, session["user_id"]): + await send_resp(msg, writer, "accept_invitation", "error", {"message": "No pending invitation"}) + return + joined_at = datetime.now(timezone.utc) + await adb.add_conversation_member(conv_id, session["user_id"], joined_at=joined_at) + await adb.delete_invitation(conv_id, session["user_id"]) + logger.info("[INVITE] %s accepted invitation to conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "accept_invitation", "ok", {"conversation_id": conv_id}) + # Notify existing members about the new member + user = await adb.get_user_by_id(session["user_id"]) + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + "username": user["username"] if user else "", + "email": user["email"] if user else "", + } + members = await adb.get_conversation_members(conv_id) + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_added", notif_data) + + +async def handle_decline_invitation(msg: dict, session: dict, writer: ProtocolWriter): + """Decline a group invitation.""" + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "decline_invitation", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "decline_invitation", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.has_pending_invitation(conv_id, session["user_id"]): + await send_resp(msg, writer, "decline_invitation", "error", {"message": "No pending invitation"}) + return + await adb.delete_invitation(conv_id, session["user_id"]) + logger.info("[INVITE] %s declined invitation to conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "decline_invitation", "ok", {"message": "OK"}) + + +async def handle_list_invitations(msg: dict, session: dict, writer: ProtocolWriter): + """List pending group invitations for the current user.""" + invitations = await adb.get_pending_invitations(session["user_id"]) + result = [] + for inv in invitations: + entry = { + "conversation_id": inv["conversation_id"], + "conversation_name": inv.get("conversation_name"), + "invited_by": inv["invited_by"], + "invited_by_username": inv.get("invited_by_username", ""), + "created_at": inv["created_at"].isoformat() if hasattr(inv["created_at"], "isoformat") else str(inv["created_at"]), + } + result.append(entry) + await send_resp(msg, writer, "list_invitations", "ok", {"invitations": result}) + + +async def handle_list_conversations(msg: dict, session: dict, writer: ProtocolWriter): + convs = await adb.list_user_conversations(session["user_id"]) + unread = await adb.get_unread_counts(session["user_id"], max_age_days=METADATA_RETENTION_DAYS) + result = [] + for c in convs: + result.append({ + "conversation_id": c["id"], + "created_at": c["created_at"].isoformat() if hasattr(c["created_at"], "isoformat") else str(c["created_at"]), + "members": c["members"], + "name": c.get("name"), + "created_by": c.get("created_by"), + "avatar_file": c.get("avatar_file"), + "unread_count": unread.get(c["id"], 0), + }) + logger.info("[LIST] %s listed %d conversations", _who(session), len(result)) + await send_resp(msg, writer, "list_conversations", "ok", {"conversations": result}) + + +async def handle_send_message(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "send_message", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "send_message", "error", {"message": "Invalid conversation_id"}) + return + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("send_message", addr, session.get("email")), 20): + await send_resp(msg, writer, "send_message", "error", {"message": "Too many attempts. Try later."}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "send_message", "error", {"message": "Not a member"}) + return + + # New protocol: ratchet_header + recipients[] with per-user ciphertext + ratchet_header_raw = msg.get("ratchet_header") + recipients_raw = msg.get("recipients") + if not ratchet_header_raw or not recipients_raw: + await send_resp(msg, writer, "send_message", "error", {"message": "Missing ratchet_header or recipients"}) + return + + # C2 fix: validate header is a dict (reject raw str/bytes) + ratchet_header = _validate_header(ratchet_header_raw, "ratchet_header") + if ratchet_header is None: + await send_resp(msg, writer, "send_message", "error", {"message": "Invalid ratchet_header format"}) + return + + x3dh_header_raw = msg.get("x3dh_header") + x3dh_header = None + if x3dh_header_raw: + x3dh_header = _validate_header(x3dh_header_raw, "x3dh_header") + if x3dh_header is None: + await send_resp(msg, writer, "send_message", "error", {"message": "Invalid x3dh_header format"}) + return + + sender_chain_id_b64 = msg.get("sender_chain_id") + sender_chain_id = decode_binary(sender_chain_id_b64) if sender_chain_id_b64 else None + sender_chain_n = msg.get("sender_chain_n") + + # Validate recipients are actual members + conv_members = await adb.get_conversation_members(conv_id) + member_ids = {m["id"] for m in conv_members} + async with _clients_lock: + phantom_snapshot = set(phantom_user_ids) + db_recipients = [] + for r in recipients_raw: + uid = r.get("user_id", "") + if uid not in member_ids: + continue + if uid in phantom_snapshot: + continue + ct_b64 = r.get("encrypted_content", "") + nonce_b64 = r.get("nonce", "") + if not ct_b64 or not nonce_b64: + continue + entry = { + "user_id": uid, + "encrypted_content": decode_binary(ct_b64), + "nonce": decode_binary(nonce_b64), + } + # Per-recipient device_id (multi-device support) + r_device_id = r.get("device_id") + if r_device_id: + entry["device_id"] = r_device_id + # Per-recipient ratchet header and x3dh header (C2 fix: validate dict) + r_rh = r.get("ratchet_header") + if r_rh: + r_rh_bytes = _validate_header(r_rh, "recipient_ratchet_header") + if r_rh_bytes: + entry["ratchet_header"] = r_rh_bytes + r_x3dh = r.get("x3dh_header") + if r_x3dh: + r_x3dh_bytes = _validate_header(r_x3dh, "recipient_x3dh_header") + if r_x3dh_bytes: + entry["x3dh_header"] = r_x3dh_bytes + db_recipients.append(entry) + if not db_recipients: + await send_resp(msg, writer, "send_message", "error", {"message": "No valid recipients"}) + return + + image_file_id = msg.get("image_file_id") + + # Metadata privacy: for group messages (sender_chain_id present), store chain + # metadata in per-recipient ratchet_header instead of the messages table. + # This avoids persisting sender correlation data at the message level. + # Skip sender's own self-copy entry — it uses a different decrypt path + # (self-encryption key) and must keep its own ratchet_header ({"self":true}). + db_sender_chain_id = None + db_sender_chain_n = None + if sender_chain_id: + chain_meta = json.dumps({ + "chain_id": encode_binary(sender_chain_id), + "chain_n": sender_chain_n, + }).encode() + sender_uid = session["user_id"] + for r in db_recipients: + # Skip self-copy (sender's own entry) — uses self-encryption, not sender key + if r["user_id"] == sender_uid: + continue + if not r.get("ratchet_header"): + r["ratchet_header"] = chain_meta + + msg_id, created_at = await adb.store_message( + conv_id, session["user_id"], ratchet_header, db_recipients, + x3dh_header=x3dh_header, + sender_chain_id=db_sender_chain_id, + sender_chain_n=db_sender_chain_n, + image_file_id=image_file_id, + sender_device_id=session.get("device_id"), + ) + + # Link image upload to message if present + if image_file_id: + upload = await adb.get_image_upload(image_file_id) + if upload and upload["completed"] and upload["uploader_id"] == session["user_id"]: + await adb.set_message_image_file_id(msg_id, image_file_id) + + logger.info("[MSG] %s msg=%s conv=%s", _who(session), msg_id[:8], conv_id[:8]) + await send_resp(msg, writer, "send_message", "ok", {"message_id": msg_id, "created_at": created_at}) + + # Notify connected recipients — group all per-device entries by user_id + # Use validated db_recipients (not raw input) to prevent unvalidated headers in push + msg_ratchet_header_dict = json.loads(ratchet_header.decode()) + msg_x3dh_header_dict = json.loads(x3dh_header.decode()) if x3dh_header else None + + from collections import defaultdict + user_entries = defaultdict(list) + for r in db_recipients: + uid = r["user_id"] + # Per-recipient headers are stored as bytes; decode back to dict for notification JSON + r_rh = r.get("ratchet_header") + r_rh_dict = json.loads(r_rh.decode()) if r_rh else None + r_x3dh = r.get("x3dh_header") + r_x3dh_dict = json.loads(r_x3dh.decode()) if r_x3dh else None + user_entries[uid].append({ + "device_id": r.get("device_id", db.SELF_DEVICE_ID), + "encrypted_content": encode_binary(r["encrypted_content"]), + "nonce": encode_binary(r["nonce"]), + "ratchet_header": r_rh_dict or msg_ratchet_header_dict, + "x3dh_header": r_x3dh_dict or msg_x3dh_header_dict, + }) + + notifications = [] + for uid, entries in user_entries.items(): + notif_data = { + "message_id": msg_id, + "conversation_id": conv_id, + "sender_id": session["user_id"], + "sender_device_id": session.get("device_id"), + "device_entries": entries, + } + if sender_chain_id_b64: + notif_data["sender_chain_id"] = sender_chain_id_b64 + if sender_chain_n is not None: + notif_data["sender_chain_n"] = sender_chain_n + # Also include flat fields for backward compat with old clients + # (first entry's data as fallback) + if entries: + first = entries[0] + notif_data["ratchet_header"] = first.get("ratchet_header") or msg_ratchet_header_dict + notif_data["encrypted_content"] = first.get("encrypted_content", "") + notif_data["nonce"] = first.get("nonce", "") + if first.get("x3dh_header"): + notif_data["x3dh_header"] = first["x3dh_header"] + notifications.append((uid, "new_message", notif_data)) + await _notify_users_individual(notifications, exclude_writer=writer) + + +async def handle_get_messages(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"get_messages|{session['user_id']}", 30): + await send_resp(msg, writer, "get_messages", "error", {"message": "Too many requests. Try later."}) + return + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "get_messages", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "get_messages", "error", {"message": "Not a member"}) + return + + limit = min(max(int(msg.get("limit", 50)), 1), 200) + offset = max(int(msg.get("offset", 0)), 0) + device_id = session.get("device_id") + after_ts = msg.get("after_ts") # ISO timestamp string or None + messages = await adb.get_messages(conv_id, session["user_id"], limit, offset, + device_id=device_id, after_ts=after_ts) + + # Deduplicate: when both device-specific and SELF_DEVICE_ID rows exist for the + # same message, prefer device-specific (non-sentinel). Keep first seen per message_id. + seen_ids = {} + deduped = [] + for m in messages: + mid = m["id"] + mr_dev = m.get("mr_device_id", "") + if mid not in seen_ids: + seen_ids[mid] = len(deduped) + deduped.append(m) + elif mr_dev != db.SELF_DEVICE_ID: + # Replace SELF_DEVICE_ID entry with device-specific one + deduped[seen_ids[mid]] = m + messages = deduped + + result = [] + message_ids = [m["id"] for m in messages] + read_status = await adb.get_message_read_status(message_ids) if message_ids else {} + delivery_status = await adb.get_message_delivery_status(message_ids) if message_ids else {} + reactions_map = await adb.get_reactions(message_ids) if message_ids else {} + for m in messages: + read_by = read_status.get(m["id"], []) + # Prefer per-recipient headers (mr_*) over message-level headers + rh_raw = m.get("mr_ratchet_header") or m.get("ratchet_header") + x3dh_raw = m.get("mr_x3dh_header") or m.get("x3dh_header") + # C2 fix: defensive JSON parsing — corrupted headers don't break fetch + try: + rh_parsed = json.loads(rh_raw) if rh_raw else {} + except (json.JSONDecodeError, TypeError, UnicodeDecodeError): + logger.warning("[FETCH] Corrupted ratchet_header in message %s, skipping", m["id"]) + rh_parsed = {} + try: + x3dh_parsed = json.loads(x3dh_raw) if x3dh_raw else None + except (json.JSONDecodeError, TypeError, UnicodeDecodeError): + logger.warning("[FETCH] Corrupted x3dh_header in message %s, skipping", m["id"]) + x3dh_parsed = None + entry = { + "message_id": m["id"], + "sender_id": m.get("sender_id") or "", + "ratchet_header": rh_parsed, + "encrypted_content": encode_binary(m["encrypted_content"]) if m.get("encrypted_content") else "", + "nonce": encode_binary(m["nonce"]) if m.get("nonce") else "", + "created_at": m["created_at"].isoformat() if hasattr(m["created_at"], "isoformat") else str(m["created_at"]), + "read_by": read_by, + "delivered_to": delivery_status.get(m["id"], []), + } + if x3dh_parsed: + entry["x3dh_header"] = x3dh_parsed + # Sender chain metadata: check message-level first (backward compat), + # then per-recipient ratchet_header (new metadata-private format). + # Only extract from per-recipient header if message-level ratchet_header + # is the group dummy (dh_pub all-zeros) — prevents DM header injection. + if m.get("sender_chain_id"): + entry["sender_chain_id"] = encode_binary(m["sender_chain_id"]) + elif isinstance(rh_parsed, dict) and rh_parsed.get("chain_id"): + # Verify this is a group message by checking the message-level header + msg_rh_raw = m.get("ratchet_header") + is_group = False + if msg_rh_raw: + try: + msg_rh = json.loads(msg_rh_raw) if isinstance(msg_rh_raw, (bytes, str)) else msg_rh_raw + is_group = isinstance(msg_rh, dict) and msg_rh.get("dh_pub") == "00" * 32 + except (json.JSONDecodeError, TypeError, UnicodeDecodeError): + pass + if is_group: + entry["sender_chain_id"] = rh_parsed["chain_id"] + if m.get("sender_chain_n") is not None: + entry["sender_chain_n"] = m["sender_chain_n"] + elif isinstance(rh_parsed, dict) and rh_parsed.get("chain_n") is not None: + # Same group-only guard + if "sender_chain_id" in entry: + entry["sender_chain_n"] = rh_parsed["chain_n"] + if m.get("sender_device_id"): + entry["sender_device_id"] = m["sender_device_id"] + if m.get("deleted_at"): + entry["deleted_at"] = m["deleted_at"].isoformat() if hasattr(m["deleted_at"], "isoformat") else str(m["deleted_at"]) + # Pin metadata + if m.get("pinned_at"): + entry["pinned_at"] = m["pinned_at"].isoformat() if hasattr(m["pinned_at"], "isoformat") else str(m["pinned_at"]) + entry["pinned_by"] = m.get("pinned_by") or "" + # Reactions + msg_reactions = reactions_map.get(m["id"]) + if msg_reactions: + entry["reactions"] = msg_reactions + result.append(entry) + total_count = await adb.count_messages(conv_id, session["user_id"]) + logger.info("[FETCH] %s fetched %d/%d msgs from conv=%s (limit=%d, offset=%d%s)", + _who(session), len(result), total_count, conv_id[:8], limit, offset, + f", after={after_ts}" if after_ts else "") + await send_resp(msg, writer, "get_messages", "ok", + {"messages": result, "total_count": total_count}) + + +async def handle_remove_member(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"remove_member|{session['user_id']}", 10): + await send_resp(msg, writer, "remove_member", "error", {"message": "Too many requests. Try later."}) + return + conv_id = msg.get("conversation_id", "") + user_id = msg.get("user_id", "") + if not conv_id or not user_id: + await send_resp(msg, writer, "remove_member", "error", {"message": "Missing conversation_id or user_id"}) + return + if not _valid_uuid(conv_id) or not _valid_uuid(user_id): + await send_resp(msg, writer, "remove_member", "error", {"message": "Invalid conversation_id or user_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "remove_member", "error", {"message": "Not a member"}) + return + convs = await adb.list_user_conversations(session["user_id"]) + conv_data = None + for c in convs: + if c["id"] == conv_id: + conv_data = c + break + if not conv_data or conv_data.get("created_by") != session["user_id"]: + await send_resp(msg, writer, "remove_member", "error", {"message": "Only the group creator can remove members"}) + return + if user_id == session["user_id"]: + await send_resp(msg, writer, "remove_member", "error", {"message": "Cannot remove yourself"}) + return + # Get remaining members before removing (to notify them) + members_before = await adb.get_conversation_members(conv_id) + # M6: atomic removal — return value confirms row existed + removed = await adb.remove_conversation_member_atomic(conv_id, user_id) + if not removed: + await send_resp(msg, writer, "remove_member", "error", {"message": "Member already removed"}) + return + logger.info("[MEMBER] %s removed user=%s from conv=%s", _who(session), user_id[:8], conv_id[:8]) + await send_resp(msg, writer, "remove_member", "ok", {"message": "OK"}) + + # Notify removed member and remaining members + notif_data = { + "conversation_id": conv_id, + "user_id": user_id, + } + member_ids = [m["id"] for m in members_before if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_removed", notif_data) + + +async def handle_leave_group(msg: dict, session: dict, writer: ProtocolWriter): + """Leave a group conversation voluntarily.""" + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "leave_group", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "leave_group", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"}) + return + # Don't allow leaving DMs (2 members without a name) + conv = await adb.get_conversation(conv_id) + members = await adb.get_conversation_members(conv_id) + if len(members) <= 2 and not (conv and conv.get("name")): + await send_resp(msg, writer, "leave_group", "error", {"message": "Cannot leave a DM conversation"}) + return + # If creator is leaving, transfer to first remaining member + if conv and conv.get("created_by") == session["user_id"]: + remaining = [m for m in members if m["id"] != session["user_id"]] + if remaining: + await adb.update_conversation_creator(conv_id, remaining[0]["id"]) + # M6: atomic removal + await adb.remove_conversation_member_atomic(conv_id, session["user_id"]) + logger.info("[LEAVE] %s left group conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "leave_group", "ok", {"message": "OK"}) + # Notify remaining members + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_removed", notif_data) + + +async def handle_rename_conversation(msg: dict, session: dict, writer: ProtocolWriter): + """Rename a group conversation (creator only).""" + if await _is_rate_limited(f"rename_conv|{session['user_id']}", 5): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Too many requests. Try later."}) + return + conv_id = msg.get("conversation_id", "") + new_name = msg.get("name", "").strip() + if not conv_id or not new_name: + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Missing conversation_id or name"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Invalid conversation_id"}) + return + if len(new_name) > 100: + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Name too long (max 100)"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Not a member"}) + return + conv = await adb.get_conversation(conv_id) + if not conv or not conv.get("name"): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Cannot rename a DM conversation"}) + return + if conv.get("created_by") != session["user_id"]: + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Only the group creator can rename"}) + return + await adb.update_conversation_name(conv_id, new_name) + logger.info("[RENAME] %s renamed conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "rename_conversation", "ok", {"message": "OK"}) + # Notify all members + members = await adb.get_conversation_members(conv_id) + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "conversation_renamed", { + "conversation_id": conv_id, + "name": new_name, + "renamed_by": session["user_id"], + }) + + +async def handle_delete_conversation(msg: dict, session: dict, writer: ProtocolWriter): + """Delete a conversation for the current user. Removes user from members, + deletes the conversation if no members remain.""" + if await _is_rate_limited(f"delete_conv|{session['user_id']}", 5): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Too many requests. Try later."}) + return + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Not a member"}) + return + conv = await adb.get_conversation(conv_id) + members = await adb.get_conversation_members(conv_id) + is_group = len(members) > 2 or (conv and conv.get("name")) + # Groups can only be deleted by the creator (admin) + if is_group and (not conv or conv.get("created_by") != session["user_id"]): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Only the group creator can delete this conversation"}) + return + if is_group: + # Group: creator deletes for everyone — remove all members, clean up, delete + for member in members: + await adb.remove_conversation_member(conv_id, member["id"]) + else: + # DM: only remove self; other user keeps the conversation + await adb.remove_conversation_member(conv_id, session["user_id"]) + remaining_count = await adb.count_conversation_members(conv_id) + if remaining_count == 0: + # Clean up uploaded files from disk + file_ids = await adb.get_conversation_file_ids(conv_id) + for fid in file_ids: + for ext in (".enc", ".tmp"): + p = _safe_upload_path(fid, ext) + if not p: + continue + _secure_delete(p) + await adb.delete_conversation(conv_id) + logger.info("[DELETE] %s deleted conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "delete_conversation", "ok", {"message": "OK"}) + # Notify other members they were removed + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_removed", notif_data) + + +async def handle_mark_read(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + message_ids = msg.get("message_ids", []) + if not conv_id or not message_ids: + await send_resp(msg, writer, "mark_read", "error", {"message": "Missing conversation_id or message_ids"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid conversation_id"}) + return + if len(message_ids) > 500: + await send_resp(msg, writer, "mark_read", "error", {"message": "Too many message_ids (max 500)"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "mark_read", "error", {"message": "Not a member"}) + return + # M1 fix: filter to only message_ids that belong to this conversation + valid_ids = await adb.filter_message_ids_by_conversation(conv_id, message_ids) + if not valid_ids: + await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"}) + return + await adb.mark_messages_read(conv_id, session["user_id"], valid_ids) + logger.info("[READ] %s marked %d msgs read in conv=%s", _who(session), len(valid_ids), conv_id[:8]) + await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"}) + members = await adb.get_conversation_members(conv_id) + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + "message_ids": valid_ids, + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "messages_read", notif_data) + + +async def handle_mark_conversation_read(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "mark_conversation_read", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "mark_conversation_read", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "mark_conversation_read", "error", {"message": "Not a member"}) + return + count = await adb.mark_conversation_read(conv_id, session["user_id"]) + logger.info("[READ] %s marked conv=%s all-read (%d msgs)", _who(session), conv_id[:8], count) + await send_resp(msg, writer, "mark_conversation_read", "ok", {"marked_count": count}) + if count > 0: + members = await adb.get_conversation_members(conv_id) + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "messages_read", { + "conversation_id": conv_id, + "user_id": session["user_id"], + "message_ids": [], + }) + + +async def handle_confirm_delivery(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + message_ids = msg.get("message_ids", []) + if not conv_id or not message_ids: + await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Missing conversation_id or message_ids"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Invalid conversation_id"}) + return + if len(message_ids) > 500: + await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Too many message_ids (max 500)"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "confirm_delivery", "error", {"message": "Not a member"}) + return + # M1 fix: filter to only message_ids that belong to this conversation + valid_ids = await adb.filter_message_ids_by_conversation(conv_id, message_ids) + if not valid_ids: + await send_resp(msg, writer, "confirm_delivery", "ok", {"message": "OK"}) + return + await adb.mark_messages_delivered(conv_id, session["user_id"], valid_ids) + logger.info("[DELIVERY] %s confirmed %d msgs delivered in conv=%s", _who(session), len(valid_ids), conv_id[:8]) + await send_resp(msg, writer, "confirm_delivery", "ok", {"message": "OK"}) + + # Notify senders — batch lookup sender_id per message, push to each sender + sender_msgs: dict[str, list[str]] = {} + for mid in valid_ids: + sid = await adb.get_message_sender(mid) + if sid and sid != session["user_id"]: + sender_msgs.setdefault(sid, []).append(mid) + for sender_id, mids in sender_msgs.items(): + await _notify_users([sender_id], "message_delivered", { + "conversation_id": conv_id, + "user_id": session["user_id"], + "message_ids": mids, + }) + + +async def handle_delete_message(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"delete_msg|{session['user_id']}", 20): + await send_resp(msg, writer, "delete_message", "error", {"message": "Too many requests. Try later."}) + return + message_id = msg.get("message_id", "") + if not message_id: + await send_resp(msg, writer, "delete_message", "error", {"message": "Missing message_id"}) + return + if not _valid_uuid(message_id): + await send_resp(msg, writer, "delete_message", "error", {"message": "Invalid message_id"}) + return + conv_id = await adb.get_message_conversation(message_id) + if not conv_id: + await send_resp(msg, writer, "delete_message", "error", {"message": "Message not found"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "delete_message", "error", {"message": "Not a member"}) + return + result = await adb.soft_delete_message(message_id, session["user_id"]) + if result is None: + await send_resp(msg, writer, "delete_message", "error", {"message": "Cannot delete this message"}) + return + image_file_id = result.get("image_file_id") + if image_file_id: + image_path = _safe_upload_path(image_file_id, ".enc") + if image_path: + _secure_delete(image_path) + await adb.delete_image_upload(image_file_id) + logger.info("[MSG] %s deleted message=%s", _who(session), message_id[:8]) + await send_resp(msg, writer, "delete_message", "ok", {"message_id": message_id}) + members = await adb.get_conversation_members(conv_id) + notif_data = {"message_id": message_id, "conversation_id": conv_id} + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "message_deleted", notif_data) + + +async def handle_react_message(msg: dict, session: dict, writer: ProtocolWriter): + if await _is_rate_limited(f"react|{session['user_id']}", 20): + await send_resp(msg, writer, "react_message", "error", {"message": "Too many requests. Try later."}) + return + message_id = msg.get("message_id", "") + reaction = msg.get("reaction", "") + action = msg.get("action", "add") # "add" or "remove" + + if not message_id or not reaction: + await send_resp(msg, writer, "react_message", "error", {"message": "Missing fields"}) + return + if not _valid_uuid(message_id): + await send_resp(msg, writer, "react_message", "error", {"message": "Invalid message_id"}) + return + if reaction not in db.ALLOWED_REACTIONS: + await send_resp(msg, writer, "react_message", "error", {"message": "Invalid reaction"}) + return + if action not in ("add", "remove"): + await send_resp(msg, writer, "react_message", "error", {"message": "Invalid action"}) + return + + conv_id = await adb.get_message_conversation(message_id) + if not conv_id: + await send_resp(msg, writer, "react_message", "error", {"message": "Message not found"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "react_message", "error", {"message": "Not a member"}) + return + + old_reaction = None + if action == "add": + changed, old_reaction = await adb.add_reaction(message_id, session["user_id"], reaction) + if not changed: + await send_resp(msg, writer, "react_message", "ok", {"message_id": message_id}) + return + else: + await adb.remove_reaction(message_id, session["user_id"]) + + logger.info("[MSG] %s %s reaction '%s' on message=%s", _who(session), action, reaction, message_id[:8]) + resp_data = {"message_id": message_id} + if old_reaction: + resp_data["old_reaction"] = old_reaction + await send_resp(msg, writer, "react_message", "ok", resp_data) + + members = await adb.get_conversation_members(conv_id) + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + + # If replacing an old reaction, notify removal first + if old_reaction: + remove_data = { + "message_id": message_id, + "conversation_id": conv_id, + "user_id": session["user_id"], + "username": session.get("username", ""), + "reaction": old_reaction, + "action": "remove", + } + await _notify_users(member_ids, "message_reacted", remove_data) + + notif_data = { + "message_id": message_id, + "conversation_id": conv_id, + "user_id": session["user_id"], + "username": session.get("username", ""), + "reaction": reaction, + "action": action, + } + await _notify_users(member_ids, "message_reacted", notif_data) + + +async def handle_pin_message(msg: dict, session: dict, writer: ProtocolWriter): + message_id = msg.get("message_id", "") + action = msg.get("action", "pin") # "pin" or "unpin" + conversation_id = msg.get("conversation_id", "") + + if not message_id or not conversation_id: + await send_resp(msg, writer, "pin_message", "error", {"message": "Missing fields"}) + return + if not _valid_uuid(message_id) or not _valid_uuid(conversation_id): + await send_resp(msg, writer, "pin_message", "error", {"message": "Invalid ID"}) + return + if action not in ("pin", "unpin"): + await send_resp(msg, writer, "pin_message", "error", {"message": "Invalid action"}) + return + if not await adb.is_conversation_member(conversation_id, session["user_id"]): + await send_resp(msg, writer, "pin_message", "error", {"message": "Not a member"}) + return + + if action == "pin": + ok = await adb.pin_message(message_id, session["user_id"], conversation_id) + else: + ok = await adb.unpin_message(message_id, conversation_id) + + if not ok: + await send_resp(msg, writer, "pin_message", "error", + {"message": "Already pinned" if action == "pin" else "Not pinned"}) + return + + logger.info("[MSG] %s %s message=%s in conv=%s", _who(session), action, message_id[:8], conversation_id[:8]) + await send_resp(msg, writer, "pin_message", "ok", {"message_id": message_id, "action": action}) + + members = await adb.get_conversation_members(conversation_id) + notif_type = "message_pinned" if action == "pin" else "message_unpinned" + notif_data = { + "message_id": message_id, + "conversation_id": conversation_id, + "user_id": session["user_id"], + "username": session.get("username", ""), + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, notif_type, notif_data) + + +async def handle_get_pinned_messages(msg: dict, session: dict, writer: ProtocolWriter): + conversation_id = msg.get("conversation_id", "") + if not conversation_id: + await send_resp(msg, writer, "get_pinned_messages", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conversation_id): + await send_resp(msg, writer, "get_pinned_messages", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conversation_id, session["user_id"]): + await send_resp(msg, writer, "get_pinned_messages", "error", {"message": "Not a member"}) + return + + pinned = await adb.get_pinned_messages(conversation_id, session["user_id"]) + await send_resp(msg, writer, "get_pinned_messages", "ok", {"messages": pinned}) + + +async def handle_upload_image_start(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + file_size = msg.get("file_size", 0) + file_id = msg.get("file_id", "") + file_type = msg.get("file_type", "image") # "image" or "file" + if not conv_id or not file_id: + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Missing fields"}) + return + if not _valid_uuid(file_id): + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"}) + return + # M5: rate limit + caps on in-flight uploads + addr = _get_peer_addr(writer) + if await _is_rate_limited(f"upload_start|{session['user_id']}", 10): + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Too many uploads. Try later."}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Not a member"}) + return + max_bytes = MAX_FILE_BYTES if file_type == "file" else MAX_IMAGE_BYTES + if max_bytes > 0 and file_size > max_bytes: + await send_resp(msg, writer, "upload_image_start", "error", + {"message": f"File too large (max {max_bytes} bytes)"}) + return + UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + temp_path = _safe_upload_path(file_id, ".tmp") + if not temp_path: + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"}) + return + # M5: atomic cap check + insert under single lock acquisition + cap_error = "" + async with _uploads_lock: + total = len(pending_uploads) + user_count = sum(1 for u in pending_uploads.values() if u.get("uploader_id") == session["user_id"]) + if total >= MAX_UPLOADS_GLOBAL: + cap_error = "Server upload limit reached. Try later." + elif user_count >= MAX_UPLOADS_PER_USER: + cap_error = "Too many active uploads. Finish or cancel existing ones." + else: + temp_path.write_bytes(b"") + pending_uploads[file_id] = { + "temp_path": str(temp_path), + "received_bytes": 0, + "file_size": file_size, + "max_bytes": max_bytes, + "conv_id": conv_id, + "uploader_id": session["user_id"], + } + if cap_error: + await send_resp(msg, writer, "upload_image_start", "error", {"message": cap_error}) + return + try: + await adb.create_image_upload(file_id, conv_id, session["user_id"], file_size) + except Exception: + # Rollback: remove from pending_uploads + delete temp file + async with _uploads_lock: + pending_uploads.pop(file_id, None) + _secure_delete(temp_path) + logger.exception("[UPLOAD] DB create failed for file=%s", file_id[:8]) + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Upload failed"}) + return + logger.info("[UPLOAD] %s started upload file=%s (%s, %d bytes)", + _who(session), file_id[:8], file_type, file_size) + await send_resp(msg, writer, "upload_image_start", "ok", {"file_id": file_id}) + + +async def handle_upload_image_chunk(msg: dict, session: dict, writer: ProtocolWriter): + file_id = msg.get("file_id", "") + chunk_data = msg.get("data", "") + if not file_id or not chunk_data: + await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Missing fields"}) + return + async with _uploads_lock: + upload = pending_uploads.get(file_id) + if not upload or upload["uploader_id"] != session["user_id"]: + upload = None + else: + temp_path_str = upload["temp_path"] + upload_max = upload.get("max_bytes", 0) + if not upload: + await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"}) + return + raw = decode_binary(chunk_data) + temp_path = Path(temp_path_str) + await asyncio.to_thread(_append_file, temp_path, raw) + over_limit = False + async with _uploads_lock: + upload = pending_uploads.get(file_id) + if upload: + upload["received_bytes"] += len(raw) + if upload_max > 0 and upload["received_bytes"] > upload_max: + pending_uploads.pop(file_id, None) + over_limit = True + received = upload["received_bytes"] + if over_limit: + _secure_delete(temp_path) + await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Upload exceeds size limit"}) + return + await send_resp(msg, writer, "upload_image_chunk", "ok", {"received": received}) + + +async def handle_upload_image_end(msg: dict, session: dict, writer: ProtocolWriter): + file_id = msg.get("file_id", "") + if not file_id: + await send_resp(msg, writer, "upload_image_end", "error", {"message": "Missing file_id"}) + return + async with _uploads_lock: + upload = pending_uploads.pop(file_id, None) + if not upload or upload["uploader_id"] != session["user_id"]: + await send_resp(msg, writer, "upload_image_end", "error", {"message": "No active upload"}) + return + temp_path = Path(upload["temp_path"]) + if upload["received_bytes"] != upload["file_size"]: + _secure_delete(temp_path) + await send_resp(msg, writer, "upload_image_end", "error", + {"message": f"Incomplete upload: received {upload['received_bytes']} of {upload['file_size']} bytes"}) + return + final_path = _safe_upload_path(file_id, ".enc") + if not final_path: + _secure_delete(temp_path) + await send_resp(msg, writer, "upload_image_end", "error", {"message": "Invalid file_id"}) + return + def _move_file(): + try: + temp_path.rename(final_path) + except Exception: + import shutil + shutil.move(str(temp_path), str(final_path)) + await asyncio.to_thread(_move_file) + await adb.complete_image_upload(file_id) + logger.info("[UPLOAD] %s completed upload file=%s (%d bytes)", + _who(session), file_id[:8], upload["received_bytes"]) + await send_resp(msg, writer, "upload_image_end", "ok", {"file_id": file_id}) + + +async def handle_download_image(msg: dict, session: dict, writer: ProtocolWriter): + file_id = msg.get("file_id", "") + offset = msg.get("offset", 0) + if not file_id: + await send_resp(msg, writer, "download_image", "error", {"message": "Missing file_id"}) + return + if not _valid_uuid(file_id): + await send_resp(msg, writer, "download_image", "error", {"message": "Invalid file_id"}) + return + upload = await adb.get_image_upload(file_id) + if not upload or not upload["completed"]: + await send_resp(msg, writer, "download_image", "error", {"message": "File not found"}) + return + if not await adb.is_conversation_member(upload["conversation_id"], session["user_id"]): + await send_resp(msg, writer, "download_image", "error", {"message": "Not a member"}) + return + file_path = _safe_upload_path(file_id, ".enc") + if not file_path or not file_path.exists(): + await send_resp(msg, writer, "download_image", "error", {"message": "File not found"}) + return + file_size = file_path.stat().st_size + chunk = await asyncio.to_thread(_read_file_chunk, file_path, offset, IMAGE_CHUNK_SIZE) + done = (offset + len(chunk)) >= file_size + if offset == 0: + logger.info("[DOWNLOAD] %s downloading file=%s (%d bytes)", _who(session), file_id[:8], file_size) + await send_resp(msg, writer, "download_image", "ok", { + "file_id": file_id, + "data": encode_binary(chunk), + "offset": offset, + "done": done, + "total_size": file_size, + }) + + +MAX_AVATAR_BYTES = 2 * 1024 * 1024 # 2 MB + + +async def handle_get_profile(msg: dict, session: dict, writer: ProtocolWriter): + """Get user profile (respects visibility for other users).""" + target_user_id = msg.get("user_id", "").strip() + if not target_user_id: + target_user_id = session["user_id"] + elif not _valid_uuid(target_user_id): + await send_resp(msg, writer, "get_profile", "error", {"message": "Invalid user_id"}) + return + profile = await adb.get_user_profile(target_user_id, viewer_id=session["user_id"]) + if not profile: + await send_resp(msg, writer, "get_profile", "error", {"message": "User not found"}) + return + # Serialize datetime fields + for key in ("created_at", "updated_at"): + if profile.get(key) and hasattr(profile[key], "isoformat"): + profile[key] = profile[key].isoformat() + await send_resp(msg, writer, "get_profile", "ok", profile) + + +async def handle_update_profile(msg: dict, session: dict, writer: ProtocolWriter): + """Update own profile fields.""" + fields = {} + for key in ("phone", "phone_visible", "email_visible", "location", "location_visible"): + if key in msg: + fields[key] = msg[key] + if not fields: + await send_resp(msg, writer, "update_profile", "error", {"message": "No fields to update"}) + return + await adb.update_user_profile(session["user_id"], **fields) + await send_resp(msg, writer, "update_profile", "ok", {"message": "OK"}) + + +async def handle_update_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Upload avatar (base64 in single message, max 2MB).""" + if await _is_rate_limited(f"update_avatar|{session['user_id']}", 5): + await send_resp(msg, writer, "update_avatar", "error", {"message": "Too many requests. Try later."}) + return + avatar_b64 = msg.get("data", "") + if not avatar_b64: + await send_resp(msg, writer, "update_avatar", "error", {"message": "Missing data"}) + return + avatar_data = decode_binary(avatar_b64) + if len(avatar_data) > MAX_AVATAR_BYTES: + await send_resp(msg, writer, "update_avatar", "error", + {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"}) + return + # Detect format from magic bytes + ext = "jpg" + if avatar_data[:8] == b'\x89PNG\r\n\x1a\n': + ext = "png" + avatar_dir = UPLOAD_DIR / "avatars" + avatar_dir.mkdir(parents=True, exist_ok=True) + os.chmod(avatar_dir, 0o700) + filename = f"{session['user_id']}.{ext}" + avatar_path = _safe_avatar_path(filename) + if not avatar_path: + await send_resp(msg, writer, "update_avatar", "error", {"message": "Invalid path"}) + return + await asyncio.to_thread(avatar_path.write_bytes, avatar_data) + await adb.update_user_profile(session["user_id"], avatar_file=filename) + logger.info("[AVATAR] %s updated their avatar", _who(session)) + await send_resp(msg, writer, "update_avatar", "ok", {"avatar_file": filename}) + + +async def handle_get_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Download avatar for a user.""" + target_user_id = msg.get("user_id", "").strip() + if not target_user_id: + await send_resp(msg, writer, "get_avatar", "error", {"message": "Missing user_id"}) + return + if not _valid_uuid(target_user_id): + await send_resp(msg, writer, "get_avatar", "error", {"message": "Invalid user_id"}) + return + profile = await adb.get_user_profile(target_user_id) + if not profile or not profile.get("avatar_file"): + await send_resp(msg, writer, "get_avatar", "error", {"message": "No avatar"}) + return + avatar_path = _safe_avatar_path(profile["avatar_file"]) + if not avatar_path or not avatar_path.exists(): + await send_resp(msg, writer, "get_avatar", "error", {"message": "Avatar file not found"}) + return + avatar_data = await asyncio.to_thread(avatar_path.read_bytes) + await send_resp(msg, writer, "get_avatar", "ok", { + "user_id": target_user_id, + "data": encode_binary(avatar_data), + "filename": profile["avatar_file"], + }) + + +async def handle_update_group_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Upload avatar for a group conversation (base64, max 2MB). Only members can set it.""" + if await _is_rate_limited(f"update_avatar|{session['user_id']}", 5): + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Too many requests. Try later."}) + return + conv_id = msg.get("conversation_id", "").strip() + avatar_b64 = msg.get("data", "") + if not conv_id or not avatar_b64: + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Missing fields"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Not a member"}) + return + avatar_data = decode_binary(avatar_b64) + if len(avatar_data) > MAX_AVATAR_BYTES: + await send_resp(msg, writer, "update_group_avatar", "error", + {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"}) + return + ext = "jpg" + if avatar_data[:8] == b'\x89PNG\r\n\x1a\n': + ext = "png" + avatar_dir = UPLOAD_DIR / "avatars" + avatar_dir.mkdir(parents=True, exist_ok=True) + os.chmod(avatar_dir, 0o700) + filename = f"group_{conv_id}.{ext}" + avatar_path = _safe_avatar_path(filename) + if not avatar_path: + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid path"}) + return + await asyncio.to_thread(avatar_path.write_bytes, avatar_data) + await adb.update_conversation_avatar(conv_id, filename) + logger.info("[AVATAR] %s updated group avatar for conv=%s", _who(session), conv_id[:8]) + await send_resp(msg, writer, "update_group_avatar", "ok", {"avatar_file": filename}) + + +async def handle_get_group_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Download avatar for a group conversation.""" + conv_id = msg.get("conversation_id", "").strip() + if not conv_id: + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Not a member"}) + return + conv = await adb.get_conversation(conv_id) + if not conv or not conv.get("avatar_file"): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "No avatar"}) + return + avatar_path = _safe_avatar_path(conv["avatar_file"]) + if not avatar_path or not avatar_path.exists(): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Avatar file not found"}) + return + avatar_data = await asyncio.to_thread(avatar_path.read_bytes) + await send_resp(msg, writer, "get_group_avatar", "ok", { + "conversation_id": conv_id, + "data": encode_binary(avatar_data), + "filename": conv["avatar_file"], + }) + + +async def handle_list_devices(msg: dict, session: dict, writer: ProtocolWriter): + """List all devices for the current user.""" + devices = await adb.get_user_devices(session["user_id"]) + result = [] + for d in devices: + entry = { + "device_id": d["id"], + "device_name": d.get("device_name"), + "created_at": d["created_at"].isoformat() if hasattr(d["created_at"], "isoformat") else str(d["created_at"]), + "last_seen_at": d["last_seen_at"].isoformat() if d.get("last_seen_at") and hasattr(d["last_seen_at"], "isoformat") else (str(d["last_seen_at"]) if d.get("last_seen_at") else None), + "is_current": d["id"] == session.get("device_id"), + } + result.append(entry) + await send_resp(msg, writer, "list_devices", "ok", {"devices": result}) + + +async def handle_remove_device(msg: dict, session: dict, writer: ProtocolWriter): + """Remove a device (cannot remove current device).""" + device_id = msg.get("device_id", "").strip() + if not device_id: + await send_resp(msg, writer, "remove_device", "error", {"message": "Missing device_id"}) + return + if not _valid_uuid(device_id): + await send_resp(msg, writer, "remove_device", "error", {"message": "Invalid device_id"}) + return + if device_id == session.get("device_id"): + await send_resp(msg, writer, "remove_device", "error", {"message": "Cannot remove current device"}) + return + dev = await adb.get_device(device_id) + if not dev or dev["user_id"] != session["user_id"]: + await send_resp(msg, writer, "remove_device", "error", {"message": "Device not found"}) + return + await adb.delete_device(device_id) + logger.info("[DEVICE] %s removed device=%s", _who(session), device_id[:8]) + await send_resp(msg, writer, "remove_device", "ok", {"message": "OK"}) + + +async def handle_session_reset(msg: dict, session: dict, writer: ProtocolWriter): + """Notify peer to reset a corrupted Double Ratchet session.""" + peer_user_id = msg.get("peer_user_id", "").strip() + peer_device_id = msg.get("peer_device_id", "").strip() or None + if not peer_user_id or not _valid_uuid(peer_user_id): + await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_user_id"}) + return + if peer_device_id and not _valid_uuid(peer_device_id): + await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_device_id"}) + return + # H3 fix: rate limit (5/min per user, keyed by user_id only — IP-independent) + if await _is_rate_limited(f"session_reset|{session['user_id']}", 5): + await send_resp(msg, writer, "session_reset", "error", {"message": "Rate limit exceeded"}) + return + # H3 fix: verify users share at least one conversation + if not await adb.shares_conversation(session["user_id"], peer_user_id): + await send_resp(msg, writer, "session_reset", "error", {"message": "No shared conversation"}) + return + # Push notification to peer (target specific device if specified) + notif_data = { + "from_user_id": session["user_id"], + "from_device_id": session.get("device_id"), + } + if peer_device_id: + # Send only to the specific device + targets = [] + async with _clients_lock: + for w in connected_clients.get(peer_user_id, []): + if writer_device_map.get(id(w)) == peer_device_id: + targets.append(w) + for w in targets: + try: + await w.send_response("session_reset", "ok", notif_data) + except Exception: + pass + else: + await _notify_users([peer_user_id], "session_reset", notif_data) + logger.info("[SESSION] %s reset session with peer=%s", _who(session), peer_user_id[:8]) + await send_resp(msg, writer, "session_reset", "ok", {}) + + +async def handle_get_deleted_since(msg: dict, session: dict, writer: ProtocolWriter): + """Return message IDs deleted since a given timestamp.""" + conv_id = msg.get("conversation_id", "") + since_ts = msg.get("since_ts", "") + if not conv_id or not since_ts: + await send_resp(msg, writer, "get_deleted_since", "error", {"message": "Missing parameters"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "get_deleted_since", "error", {"message": "Invalid conversation_id"}) + return + if not await adb.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "get_deleted_since", "error", {"message": "Not a member"}) + return + deleted_ids = await adb.get_deleted_messages_since(conv_id, session["user_id"], since_ts) + await send_resp(msg, writer, "get_deleted_since", "ok", {"deleted_ids": deleted_ids}) + + +async def handle_reencrypt_messages(msg: dict, session: dict, writer: ProtocolWriter): + """Re-encrypt message history with self-encryption key (for device pairing).""" + if await _is_rate_limited(f"reencrypt|{session['user_id']}", 10): + await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "Too many requests. Try later."}) + return + updates_raw = msg.get("updates", []) + if not updates_raw: + await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No updates"}) + return + if len(updates_raw) > 500: + await send_resp(msg, writer, "reencrypt_messages", "error", + {"message": "Too many updates (max 500 per request)"}) + return + updates = [] + for u in updates_raw: + mid = u.get("message_id", "") + ct_b64 = u.get("encrypted_content", "") + nonce_b64 = u.get("nonce", "") + if not mid or not ct_b64 or not nonce_b64: + continue + updates.append({ + "message_id": mid, + "encrypted_content": decode_binary(ct_b64), + "nonce": decode_binary(nonce_b64), + }) + if not updates: + await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No valid updates"}) + return + await adb.batch_reencrypt_messages(session["user_id"], updates) + logger.info("[REENCRYPT] %s re-encrypted %d messages", _who(session), len(updates)) + await send_resp(msg, writer, "reencrypt_messages", "ok", {"count": len(updates)}) + + +async def _cleanup_uploads(): + stale = await adb.get_stale_uploads(UPLOAD_STALE_SECONDS) + for s in stale: + fid = s["file_id"] + for ext in (".tmp", ".enc"): + p = _safe_upload_path(fid, ext) + if not p: + continue + _secure_delete(p) + await adb.delete_image_upload(fid) + async with _uploads_lock: + pending_uploads.pop(fid, None) + if stale: + logger.info("Cleaned up %d stale uploads.", len(stale)) + + +async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + global current_connections + addr = _get_peer_addr(ProtocolWriter(writer)) + async with _conn_lock: + current_connections += 1 + connection_counts[addr] = connection_counts.get(addr, 0) + 1 + over_limit = (current_connections > MAX_CONNECTIONS_GLOBAL or + connection_counts[addr] > MAX_CONNECTIONS_PER_IP) + if over_limit: + try: + writer.close() + except Exception: + pass + async with _conn_lock: + current_connections = max(0, current_connections - 1) + connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1) + return + logger.info("[CONN] Client connected from %s", addr) + proto_reader = ProtocolReader(reader) + proto_writer = ProtocolWriter(writer) + session = None + state = {"_req_times": []} + + try: + while True: + try: + msg = await proto_reader.read_message() + except ValueError as e: + try: + await proto_writer.send_response("protocol_error", "error", {"message": str(e)}) + except Exception: + pass + break + if msg is None: + break + + msg_type = msg.get("type", "") + now = asyncio.get_event_loop().time() + times = [t for t in state["_req_times"] if now - t <= CONNECTION_RL_WINDOW] + if len(times) >= CONNECTION_RL_MAX: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Too many requests. Slow down."}) + state["_req_times"] = times + continue + times.append(now) + state["_req_times"] = times + + try: + if msg_type == "register": + await handle_register_start(msg, proto_writer) + elif msg_type == "register_confirm": + await handle_register_confirm(msg, proto_writer) + elif msg_type == "login_start": + await handle_login_start(msg, proto_writer, state) + elif msg_type == "login_finish": + result = await handle_login_finish(msg, proto_writer, state) + if result: + session = result + elif msg_type == "pairing_start": + await handle_pairing_start(msg, proto_writer) + elif msg_type == "pairing_poll": + await handle_pairing_poll(msg, proto_writer) + elif session is None: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Not logged in"}) + elif msg_type == "get_user_info": + await handle_get_user_info(msg, session, proto_writer) + elif msg_type == "upload_prekeys": + await handle_upload_prekeys(msg, session, proto_writer) + elif msg_type == "get_key_bundle": + await handle_get_key_bundle(msg, session, proto_writer) + elif msg_type == "get_prekey_count": + await handle_get_prekey_count(msg, session, proto_writer) + elif msg_type == "ensure_prekeys": + await handle_ensure_prekeys(msg, session, proto_writer) + elif msg_type == "create_conversation": + await handle_create_conversation(msg, session, proto_writer) + elif msg_type == "find_conversation": + await handle_find_conversation(msg, session, proto_writer) + elif msg_type == "add_member": + await handle_add_member(msg, session, proto_writer) + elif msg_type == "accept_invitation": + await handle_accept_invitation(msg, session, proto_writer) + elif msg_type == "decline_invitation": + await handle_decline_invitation(msg, session, proto_writer) + elif msg_type == "list_invitations": + await handle_list_invitations(msg, session, proto_writer) + elif msg_type == "list_conversations": + await handle_list_conversations(msg, session, proto_writer) + elif msg_type == "send_message": + await handle_send_message(msg, session, proto_writer) + elif msg_type == "get_messages": + await handle_get_messages(msg, session, proto_writer) + elif msg_type == "rotate_keys": + await handle_rotate_keys(msg, session, proto_writer) + elif msg_type == "change_username": + await handle_change_username(msg, session, proto_writer) + elif msg_type == "remove_member": + await handle_remove_member(msg, session, proto_writer) + elif msg_type == "leave_group": + await handle_leave_group(msg, session, proto_writer) + elif msg_type == "rename_conversation": + await handle_rename_conversation(msg, session, proto_writer) + elif msg_type == "delete_conversation": + await handle_delete_conversation(msg, session, proto_writer) + elif msg_type == "mark_read": + await handle_mark_read(msg, session, proto_writer) + elif msg_type == "mark_conversation_read": + await handle_mark_conversation_read(msg, session, proto_writer) + elif msg_type == "confirm_delivery": + await handle_confirm_delivery(msg, session, proto_writer) + elif msg_type == "pairing_claim": + await handle_pairing_claim(msg, session, proto_writer) + elif msg_type == "pairing_send": + await handle_pairing_send(msg, session, proto_writer) + elif msg_type == "delete_message": + await handle_delete_message(msg, session, proto_writer) + elif msg_type == "upload_image_start": + await handle_upload_image_start(msg, session, proto_writer) + elif msg_type == "upload_image_chunk": + await handle_upload_image_chunk(msg, session, proto_writer) + elif msg_type == "upload_image_end": + await handle_upload_image_end(msg, session, proto_writer) + elif msg_type == "download_image": + await handle_download_image(msg, session, proto_writer) + elif msg_type == "get_profile": + await handle_get_profile(msg, session, proto_writer) + elif msg_type == "update_profile": + await handle_update_profile(msg, session, proto_writer) + elif msg_type == "update_avatar": + await handle_update_avatar(msg, session, proto_writer) + elif msg_type == "get_avatar": + await handle_get_avatar(msg, session, proto_writer) + elif msg_type == "update_group_avatar": + await handle_update_group_avatar(msg, session, proto_writer) + elif msg_type == "get_group_avatar": + await handle_get_group_avatar(msg, session, proto_writer) + elif msg_type == "get_deleted_since": + await handle_get_deleted_since(msg, session, proto_writer) + elif msg_type == "reencrypt_messages": + await handle_reencrypt_messages(msg, session, proto_writer) + elif msg_type == "list_devices": + await handle_list_devices(msg, session, proto_writer) + elif msg_type == "remove_device": + await handle_remove_device(msg, session, proto_writer) + elif msg_type == "session_reset": + await handle_session_reset(msg, session, proto_writer) + elif msg_type == "react_message": + await handle_react_message(msg, session, proto_writer) + elif msg_type == "pin_message": + await handle_pin_message(msg, session, proto_writer) + elif msg_type == "get_pinned_messages": + await handle_get_pinned_messages(msg, session, proto_writer) + else: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Unknown type"}) + except Exception as e: + logger.warning("[ERROR] %s handler '%s' failed: %s", _who(session), msg_type, e, exc_info=True) + try: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Internal server error"}) + except Exception: + break # Can't send response — connection is dead + except Exception as e: + logger.warning("Client connection error: %s", e) + finally: + async with _conn_lock: + current_connections = max(0, current_connections - 1) + connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1) + offline_targets = [] + if session: + uid = session["user_id"] + contacts = await adb.get_user_contacts(uid) + async with _clients_lock: + writer_device_map.pop(id(proto_writer), None) + if uid in connected_clients: + remaining = [w for w in connected_clients[uid] if w is not proto_writer] + if remaining: + connected_clients[uid] = remaining + else: + del connected_clients[uid] + # User fully offline — snapshot targets under lock + for contact_id in contacts: + for cw in connected_clients.get(contact_id, []): + offline_targets.append(cw) + # Send offline notifications outside lock + for cw in offline_targets: + try: + await cw.send_response("user_offline", "ok", {"user_id": uid}) + except Exception: + pass + writer.close() + logger.info("[CONN] %s disconnected", _who(session) if session else addr) + + +async def main(): + setup_logging() + host = os.getenv("SERVER_HOST", "127.0.0.1") + port = int(os.getenv("SERVER_PORT", "9999")) + tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes") + tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes") + tls_autogen = os.getenv("TLS_AUTOGEN", "false").lower() in ("1", "true", "yes") + + is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") + ssl_context = None + if tls_required and not tls_enabled: + raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.") + if tls_enabled: + cert_file = os.getenv("TLS_CERT_FILE", "").strip() + key_file = os.getenv("TLS_KEY_FILE", "").strip() + if not cert_file or not key_file: + if tls_autogen: + if not is_dev: + raise RuntimeError("TLS_AUTOGEN is only allowed when ENVIRONMENT=dev") + cert_dir = Path(__file__).resolve().parent / "certs" + cert_dir.mkdir(parents=True, exist_ok=True) + cert_file = str(cert_dir / "server.crt") + key_file = str(cert_dir / "server.key") + if not (os.path.exists(cert_file) and os.path.exists(key_file)): + try: + subprocess.run( + [ + "openssl", "req", "-x509", "-newkey", "rsa:4096", + "-keyout", key_file, "-out", cert_file, + "-days", "365", "-nodes", "-subj", "/CN=localhost", + ], + check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, + ) + os.chmod(key_file, 0o600) + except FileNotFoundError: + raise RuntimeError("OpenSSL not found.") + except subprocess.CalledProcessError: + raise RuntimeError("Failed to auto-generate TLS cert.") + logger.warning("Using auto-generated self-signed certificate — not for production use.") + else: + raise RuntimeError("TLS is enabled but TLS_CERT_FILE or TLS_KEY_FILE is missing.") + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file) + else: + logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.") + + UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + + # Thread pool for asyncio.to_thread() — DB calls + file I/O + pool_workers = int(os.getenv("THREAD_POOL_SIZE", "40")) + asyncio.get_event_loop().set_default_executor(ThreadPoolExecutor(max_workers=pool_workers)) + logger.info("Thread pool executor: %d workers", pool_workers) + + # Load phantom user IDs from DB into in-memory cache + phantom_user_ids.update(await adb.get_all_phantom_user_ids()) + if phantom_user_ids: + logger.info("Loaded %d phantom user IDs.", len(phantom_user_ids)) + + server = await asyncio.start_server( + handle_client, host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context, + ) + logger.info("Encrypted chat server v%s listening on %s:%s", VERSION, host, port) + + async def _cleanup_rate_limits(): + async with _conn_lock: + now = asyncio.get_event_loop().time() + window_start = now - RATE_LIMIT_WINDOW + stale_keys = [k for k, times in rate_limits.items() + if not any(t >= window_start for t in times)] + for k in stale_keys: + del rate_limits[k] + stale_conns = [k for k, v in connection_counts.items() if v <= 0] + for k in stale_conns: + del connection_counts[k] + + _cleanup_cycle = 0 + + async def _periodic_cleanup(): + nonlocal _cleanup_cycle + while True: + await asyncio.sleep(120) + _cleanup_cycle += 1 + try: + await _cleanup_uploads() + except Exception as e: + logger.warning("Upload cleanup error: %s", e) + try: + await _cleanup_rate_limits() + except Exception as e: + logger.warning("Rate limit cleanup error: %s", e) + try: + await _cleanup_registrations() + except Exception as e: + logger.warning("Registration cleanup error: %s", e) + # L8: clean up stale phantom users (>30 days, no real conversations) + try: + deleted = await adb.cleanup_stale_phantoms(30) + if deleted: + async with _clients_lock: + phantom_user_ids.clear() + phantom_user_ids.update(await adb.get_all_phantom_user_ids()) + logger.info("Cleaned up %d stale phantom users.", deleted) + except Exception as e: + logger.warning("Phantom cleanup error: %s", e) + # Metadata retention: purge old reads and reactions (every 30 cycles = ~1 hour) + if _cleanup_cycle % 30 == 0: + try: + reads_del = await adb.cleanup_old_reads(METADATA_RETENTION_DAYS) + reactions_del = await adb.cleanup_old_reactions(METADATA_RETENTION_DAYS) + if reads_del or reactions_del: + logger.info("Metadata cleanup: %d reads, %d reactions purged", + reads_del, reactions_del) + except Exception as e: + logger.warning("Metadata cleanup error: %s", e) + + asyncio.create_task(_periodic_cleanup()) + + loop = asyncio.get_running_loop() + stop = loop.create_future() + + def signal_handler(): + if not stop.done(): + stop.set_result(None) + + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, signal_handler) + + async with server: + await stop + logger.info("Shutting down — closing %d client connections...", sum(len(ws) for ws in connected_clients.values())) + # Stop accepting new connections + server.close() + # Force-close all connected client writers + async with _clients_lock: + all_writers = [w for writers in connected_clients.values() for w in writers] + connected_clients.clear() + writer_device_map.clear() + for w in all_writers: + try: + w.close() + except Exception: + pass + # Give handle_client loops a moment to notice closed connections + await asyncio.sleep(0.1) + # Cancel any remaining handle_client tasks that are still blocked + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + logger.info("Server shut down.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/PENTEST_CLIENT.md b/tests/PENTEST_CLIENT.md new file mode 100644 index 0000000..3cd3202 --- /dev/null +++ b/tests/PENTEST_CLIENT.md @@ -0,0 +1,79 @@ +# `tests/pentest_client.py` + +Automatizovaný pentest/integration harness nad živým serverem s reálnými účty. + +## Co test dělá + +1. **Conversation Isolation (AuthZ)** + - Účet `outsider` zkouší `get_messages`, `mark_read` a `send_message` do konverzace, kde není členem. + - Očekávání: server vrátí `error` + `"Not a member"`. + +2. **Malformed Header Rejection** + - Platný člen konverzace pošle `send_message` s obřím `ratchet_header`. + - Očekávání: server odmítne request (`Invalid ratchet_header format`), tj. funguje `_validate_header` limit. + +3. **Session Reset Authorization** + - `outsider` pošle `session_reset` na `peer_user_id`. + - Očekávání: `error` + `"No shared conversation"`. + - Pokud účty sdílenou konverzaci opravdu mají, test se označí jako `SKIP` (setup issue, ne nutně bezpečnostní chyba). + +4. **Login Rate Limits (volitelné)** + - Anonymní klient spamuje `login_start`: + - stejný email v různých kombinacích velikosti písmen (test case-normalization bucketu), + - potom rotace různých emailů ze stejné IP (test per-IP bucketu). + - Očekávání: aktivuje se jak per-email limit, tak per-IP limit. + +## Požadavky + +- Běžící server (`server.py`). +- Existující lokální klíče pro účty v `~/.encrypted_chat//` (stejné jako pro běžného CLI klienta). +- 3 různé účty: + - `member` (A), + - `peer` (B), + - `outsider` (C). + +## Spuštění + +```bash +python3 tests/pentest_client.py \ + --server-host localhost \ + --member-email alice@example.com \ + --peer-email bob@example.com \ + --outsider-email mallory@example.com +``` + +Skript si vyžádá hesla interaktivně. Lze je předat i argumenty: + +```bash +python3 tests/pentest_client.py \ + --server-host localhost \ + --member-email alice@example.com --member-password '***' \ + --peer-email bob@example.com --peer-password '***' \ + --outsider-email mallory@example.com --outsider-password '***' +``` + +Volby: + +- `--conversation-id `: použije konkrétní konverzaci místo auto member<->peer DM. +- `--skip-login-rate-limit`: přeskočí test `login_start` limiteru. +- `--server-host `: přepíše `SERVER_HOST` pro tento běh. +- `--server-port `: přepíše `SERVER_PORT` pro tento běh. + +Poznámka k TLS: + +- Pokud máš v `.env` `SERVER_HOST=0.0.0.0`, je to správně pro server bind, ale klient na to nesmí přistupovat přes TLS. +- Pro klienta použij `--server-host` s hodnotou, která je v certifikátu (SAN/CN), typicky `localhost` nebo konkrétní IP. + +## Výstup + +Skript tiskne souhrn: + +- `[PASS]` test prošel, +- `[FAIL]` test selhal (potenciální regrese), +- `[SKIP]` test nelze vyhodnotit kvůli dataset/setup podmínkám. + +Návratový kód: + +- `0` = bez failu, +- `1` = alespoň jeden fail, +- `2` = chyba vstupních parametrů. diff --git a/tests/pentest_client.py b/tests/pentest_client.py new file mode 100644 index 0000000..502153a --- /dev/null +++ b/tests/pentest_client.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +"""Security regression harness for encrypted_chat server. + +Runs focused pentest/integration checks against a live server using real accounts. +""" + +from __future__ import annotations + +import argparse +import asyncio +import getpass +import os +import ssl +import sys +import time +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +if TYPE_CHECKING: + from chat_core import ChatClient + + +@dataclass +class TestResult: + name: str + outcome: str # PASS | FAIL | SKIP + details: str + + +def _msg(resp: dict) -> str: + data = resp.get("data") or {} + return str(data.get("message", "")) + + +async def _connect_client() -> "ChatClient": + from chat_core import ChatClient # Imported lazily so --help works without full deps + client = ChatClient() + await client.connect() + client._listener_task = asyncio.create_task(client._background_listener()) + return client + + +async def _login_client(email: str, password: str) -> tuple["ChatClient", str]: + client = await _connect_client() + ok, message = await client.login(email, password) + if not ok: + await client.close() + raise RuntimeError(f"Login failed for {email}: {message}") + return client, message + + +async def _close_client(client: "ChatClient | None"): + if not client: + return + try: + await client.close() + except Exception: + pass + + +def _too_many_attempts(resp: dict) -> bool: + return resp.get("status") == "error" and "Too many attempts" in _msg(resp) + + +async def test_conversation_isolation(outsider: "ChatClient", conv_id: str) -> TestResult: + """Outsider must not access a conversation they are not a member of.""" + fake_mid = str(uuid.uuid4()) + checks: list[tuple[str, dict]] = [ + ( + "get_messages", + await outsider.send_and_recv("get_messages", conversation_id=conv_id, limit=5, offset=0), + ), + ( + "mark_read", + await outsider.send_and_recv("mark_read", conversation_id=conv_id, message_ids=[fake_mid]), + ), + ( + "send_message", + await outsider.send_and_recv("send_message", conversation_id=conv_id), + ), + ] + failures: list[str] = [] + for endpoint, resp in checks: + if resp.get("status") != "error" or "Not a member" not in _msg(resp): + failures.append(f"{endpoint} -> status={resp.get('status')} message={_msg(resp)!r}") + if failures: + return TestResult("Conversation Isolation (AuthZ)", "FAIL", "; ".join(failures)) + return TestResult( + "Conversation Isolation (AuthZ)", + "PASS", + "Outsider got 'Not a member' for get_messages, mark_read, send_message.", + ) + + +async def test_session_reset_no_shared(outsider: "ChatClient", peer_user_id: str) -> TestResult: + """session_reset must be rejected without shared conversation.""" + resp = await outsider.send_and_recv("session_reset", peer_user_id=peer_user_id) + if resp.get("status") == "error" and "No shared conversation" in _msg(resp): + return TestResult("Session Reset Authorization", "PASS", "Rejected with 'No shared conversation'.") + if resp.get("status") == "ok": + return TestResult( + "Session Reset Authorization", + "SKIP", + "Outsider appears to share a conversation with peer in current dataset.", + ) + return TestResult( + "Session Reset Authorization", + "FAIL", + f"Unexpected response: status={resp.get('status')} message={_msg(resp)!r}", + ) + + +async def test_malformed_header_rejected(member: "ChatClient", conv_id: str) -> TestResult: + """Oversized ratchet header should be rejected by server-side validation.""" + huge_header = {"dh_pub": "A" * 5000, "n": 1, "pn": 0} + resp = await member.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header=huge_header, + recipients=[{}], + ) + if resp.get("status") == "error" and "Invalid ratchet_header format" in _msg(resp): + return TestResult("Malformed Header Rejection", "PASS", "Oversized ratchet_header rejected.") + return TestResult( + "Malformed Header Rejection", + "FAIL", + f"Unexpected response: status={resp.get('status')} message={_msg(resp)!r}", + ) + + +async def test_login_rate_limits() -> TestResult: + """Validate login_start per-email(case-insensitive) and per-IP limits.""" + probe = await _connect_client() + try: + stamp = int(time.time()) + base_local = f"pentest-login-{stamp}" + base_domain = "example.invalid" + base_email = f"{base_local}@{base_domain}" + case_variants = [ + base_email, + f"{base_local.upper()}@{base_domain}", + f"{base_local.capitalize()}@{base_domain}", + f"{base_local}@{base_domain.upper()}", + f"{base_local.swapcase()}@{base_domain}", + base_email, + f"{base_local.upper()}@{base_domain}", + f"{base_local.capitalize()}@{base_domain}", + f"{base_local}@{base_domain.upper()}", + f"{base_local.swapcase()}@{base_domain}", + base_email, # should exceed per-email bucket (10/min) + ] + + email_bucket_triggered = False + phase1_last = "" + for e in case_variants: + resp = await probe.send_and_recv("login_start", email=e) + phase1_last = _msg(resp) + if _too_many_attempts(resp): + email_bucket_triggered = True + await asyncio.sleep(0.12) # stay under per-connection 20 req/s limiter + + ip_bucket_triggered = False + phase2_last = "" + for i in range(1, 16): + unique_email = f"{base_local}-{i}@{base_domain}" + resp = await probe.send_and_recv("login_start", email=unique_email) + phase2_last = _msg(resp) + if _too_many_attempts(resp): + ip_bucket_triggered = True + break + await asyncio.sleep(0.12) + + if email_bucket_triggered and ip_bucket_triggered: + return TestResult( + "Login Rate Limits (case + per-IP)", + "PASS", + "Per-email(case-insensitive) and per-IP login_start limits both triggered.", + ) + return TestResult( + "Login Rate Limits (case + per-IP)", + "FAIL", + ( + f"email_bucket_triggered={email_bucket_triggered}, " + f"ip_bucket_triggered={ip_bucket_triggered}, " + f"phase1_last={phase1_last!r}, phase2_last={phase2_last!r}" + ), + ) + finally: + await _close_client(probe) + + +def _pick_password(flag_value: str | None, prompt: str) -> str: + if flag_value is not None: + return flag_value + return getpass.getpass(prompt) + + +async def run(args: argparse.Namespace) -> int: + if len({args.member_email.lower(), args.peer_email.lower(), args.outsider_email.lower()}) != 3: + print("ERROR: member/peer/outsider emails must be three different accounts.", file=sys.stderr) + return 2 + + if args.server_host: + os.environ["SERVER_HOST"] = args.server_host + if args.server_port is not None: + os.environ["SERVER_PORT"] = str(args.server_port) + + effective_host = os.getenv("SERVER_HOST", "127.0.0.1").strip() + if effective_host == "0.0.0.0": + print( + "ERROR: SERVER_HOST=0.0.0.0 je bind adresa serveru, ne klientský TLS hostname.\n" + "Pouzij --server-host (napr. localhost nebo 192.168.1.112).", + file=sys.stderr, + ) + return 2 + + member_password = _pick_password(args.member_password, f"Password for {args.member_email}: ") + peer_password = _pick_password(args.peer_password, f"Password for {args.peer_email}: ") + outsider_password = _pick_password(args.outsider_password, f"Password for {args.outsider_email}: ") + + member: "ChatClient | None" = None + peer: "ChatClient | None" = None + outsider: "ChatClient | None" = None + results: list[TestResult] = [] + + try: + print("[setup] Logging in member account...") + member, _ = await _login_client(args.member_email, member_password) + print("[setup] Logging in peer account...") + peer, _ = await _login_client(args.peer_email, peer_password) + print("[setup] Logging in outsider account...") + outsider, _ = await _login_client(args.outsider_email, outsider_password) + + if args.conversation_id: + conv_id = args.conversation_id + else: + print("[setup] Finding/creating member<->peer direct conversation...") + conv_id, err = await member.find_or_create_conversation(args.peer_email) + if not conv_id: + raise RuntimeError(f"Could not find/create conversation: {err}") + + outsider_convs = {c["conversation_id"] for c in await outsider.list_conversations()} + if conv_id in outsider_convs: + results.append( + TestResult( + "Conversation Isolation (AuthZ)", + "SKIP", + "Outsider is already a member of target conversation; choose different outsider/account set.", + ) + ) + else: + results.append(await test_conversation_isolation(outsider, conv_id)) + results.append(await test_malformed_header_rejected(member, conv_id)) + + results.append(await test_session_reset_no_shared(outsider, peer.session["user_id"])) + + if not args.skip_login_rate_limit: + results.append(await test_login_rate_limits()) + + except ssl.SSLCertVerificationError as e: + results.append( + TestResult( + "Harness Setup", + "FAIL", + ( + f"{e}. Zkus --server-host s hodnotou ze SAN/CN certifikatu " + "(napr. localhost nebo 192.168.1.112)." + ), + ) + ) + except Exception as e: + emsg = str(e) + if "CERTIFICATE_VERIFY_FAILED" in emsg or "IP address mismatch" in emsg: + emsg += ( + " | Hint: pouzij --server-host s hostname/IP, ktery je v certifikatu " + "(SERVER_HOST nesmi byt 0.0.0.0)." + ) + results.append(TestResult("Harness Setup", "FAIL", emsg)) + finally: + await _close_client(member) + await _close_client(peer) + await _close_client(outsider) + + print("\n=== Pentest Results ===") + for r in results: + print(f"[{r.outcome}] {r.name}: {r.details}") + + has_fail = any(r.outcome == "FAIL" for r in results) + return 1 if has_fail else 0 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run focused pentest/integration checks against encrypted_chat server." + ) + parser.add_argument("--member-email", required=True, help="Email of regular account A (conversation member).") + parser.add_argument("--peer-email", required=True, help="Email of regular account B (other conversation member).") + parser.add_argument("--outsider-email", required=True, help="Email of regular account C (must not be in target conversation).") + parser.add_argument( + "--server-host", + default=None, + help="Override SERVER_HOST for this run (must match TLS cert SAN/CN).", + ) + parser.add_argument( + "--server-port", + type=int, + default=None, + help="Override SERVER_PORT for this run.", + ) + parser.add_argument("--member-password", default=None, help="Password for --member-email (optional; prompt if omitted).") + parser.add_argument("--peer-password", default=None, help="Password for --peer-email (optional; prompt if omitted).") + parser.add_argument("--outsider-password", default=None, help="Password for --outsider-email (optional; prompt if omitted).") + parser.add_argument( + "--conversation-id", + default=None, + help="Optional target conversation UUID. If omitted, member<->peer DM is found/created automatically.", + ) + parser.add_argument( + "--skip-login-rate-limit", + action="store_true", + help="Skip anonymous login_start rate-limit regression check.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + return asyncio.run(run(args)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/theme.py b/theme.py new file mode 100644 index 0000000..15ec5c1 --- /dev/null +++ b/theme.py @@ -0,0 +1,539 @@ +"""Theme system for Encrypted Chat GUI — light + dark mode with live switching.""" + +from __future__ import annotations + +import json +import logging +import os +from dataclasses import dataclass, fields +from pathlib import Path +from typing import Callable + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ThemeColors: + """All colour tokens for one theme.""" + + # Surface hierarchy + bg_primary: str # Main background (messages area, right panel) + bg_secondary: str # Cards, inputs, elevated surfaces + bg_tertiary: str # Sidebar, deeper surfaces + bg_hover: str # Hover state on list items + bg_selected: str # Selected list item + + # Text + text_primary: str # Main text + text_secondary: str # Secondary / muted text + text_muted: str # Timestamps, counters, hints + + # Accent (brand blue) + accent: str + accent_hover: str + accent_text: str # Text on accent background + + # Message bubbles + bubble_sent_bg: str + bubble_sent_text: str + bubble_recv_bg: str + bubble_recv_text: str + bubble_sent_meta: str # Timestamp/read inside sent bubble + bubble_recv_meta: str # Timestamp inside received bubble + + # Semantic colours + success: str + warning: str + error: str + info: str + + # Chrome / borders + border: str + border_focus: str + scrollbar: str + separator: str + overlay: str # Privacy overlay background (rgba) + + # Links + link_https: str + link_http: str # Insecure link (orange) + + # Mentions & search + mention: str + search_highlight: str + search_current: str + + # Reactions + reaction_bg: str + reaction_bg_own: str + reaction_border: str + reaction_border_own: str + + # Misc + online_dot: str + online_dot_border: str + pin_color: str + sender_name_other: str # Non-self sender name colour in groups + receipt_read: str # Read receipt checkmarks (must contrast with sent bubble bg) + + +# --------------------------------------------------------------------------- +# Dark theme — Catppuccin Mocha palette +# --------------------------------------------------------------------------- + +DARK_THEME = ThemeColors( + bg_primary="#1e1e2e", + bg_secondary="#313244", + bg_tertiary="#181825", + bg_hover="#252536", + bg_selected="#313244", + + text_primary="#cdd6f4", + text_secondary="#bac2de", + text_muted="#6c7086", + + accent="#89b4fa", + accent_hover="#74c7ec", + accent_text="#1e1e2e", + + bubble_sent_bg="#2a4a7f", + bubble_sent_text="#cdd6f4", + bubble_recv_bg="#2c2c3e", + bubble_recv_text="#cdd6f4", + bubble_sent_meta="#8899bb", + bubble_recv_meta="#6c7086", + + success="#a6e3a1", + warning="#f9e2af", + error="#f38ba8", + info="#74c7ec", + + border="#45475a", + border_focus="#89b4fa", + scrollbar="#45475a", + separator="#45475a", + overlay="rgba(30, 30, 46, 245)", + + link_https="#89b4fa", + link_http="#fab387", + + mention="#89b4fa", + search_highlight="#f9e2af", + search_current="#fab387", + + reaction_bg="#313244", + reaction_bg_own="#45475a", + reaction_border="#45475a", + reaction_border_own="#585b70", + + online_dot="#a6e3a1", + online_dot_border="#181825", + pin_color="#f9e2af", + sender_name_other="#f9e2af", + receipt_read="#74c7ec", +) + +# --------------------------------------------------------------------------- +# Light theme — Signal-inspired palette +# --------------------------------------------------------------------------- + +LIGHT_THEME = ThemeColors( + bg_primary="#ffffff", + bg_secondary="#f2f2f7", + bg_tertiary="#e5e5ea", + bg_hover="#dcdce4", + bg_selected="#c7c7d1", + + text_primary="#1c1c1e", + text_secondary="#3a3a3c", + text_muted="#8a8a8e", + + accent="#3478f6", + accent_hover="#2563eb", + accent_text="#ffffff", + + bubble_sent_bg="#3478f6", + bubble_sent_text="#ffffff", + bubble_recv_bg="#e5e5ea", + bubble_recv_text="#1c1c1e", + bubble_sent_meta="#a3c4ff", + bubble_recv_meta="#8a8a8e", + + success="#34c759", + warning="#ff9500", + error="#ff3b30", + info="#5ac8fa", + + border="#c6c6c8", + border_focus="#3478f6", + scrollbar="#aeaeb2", + separator="#c6c6c8", + overlay="rgba(0, 0, 0, 200)", + + link_https="#2563eb", + link_http="#ea580c", + + mention="#2563eb", + search_highlight="#fde68a", + search_current="#fb923c", + + reaction_bg="#e5e5ea", + reaction_bg_own="#c7c7d1", + reaction_border="#c6c6c8", + reaction_border_own="#a0a0a8", + + online_dot="#34c759", + online_dot_border="#e5e5ea", + pin_color="#ff9500", + sender_name_other="#7c3aed", + receipt_read="#d0e8ff", +) + + +# --------------------------------------------------------------------------- +# ThemeManager singleton +# --------------------------------------------------------------------------- + +class ThemeManager: + """Manages the active theme, persistence and change notification.""" + + _instance: ThemeManager | None = None + + @classmethod + def instance(cls) -> ThemeManager: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + self._is_dark: bool = True + self._listeners: list[Callable[[], None]] = [] + self._email: str | None = None + self._load_global() + + # -- Public API -- + + @property + def is_dark(self) -> bool: + return self._is_dark + + @property + def colors(self) -> ThemeColors: + return DARK_THEME if self._is_dark else LIGHT_THEME + + def toggle(self): + self._is_dark = not self._is_dark + self._save() + self._notify() + + def set_dark(self, dark: bool): + if dark == self._is_dark: + return + self._is_dark = dark + self._save() + self._notify() + + def set_email(self, email: str): + """After login, bind to user-specific preference file.""" + self._email = email + self._load_user() + + def on_change(self, callback: Callable[[], None]): + self._listeners.append(callback) + + def remove_listener(self, callback: Callable[[], None]): + try: + self._listeners.remove(callback) + except ValueError: + pass + + def generate_qss(self) -> str: + return _build_qss(self.colors) + + # -- Persistence -- + + def _global_path(self) -> Path: + p = Path.home() / ".encrypted_chat" + p.mkdir(parents=True, exist_ok=True) + return p / "global_settings.json" + + def _user_path(self) -> Path | None: + if not self._email: + return None + p = Path.home() / ".encrypted_chat" / self._email + if not p.exists(): + return None + return p / "theme.json" + + def _load_global(self): + try: + p = self._global_path() + if p.exists(): + data = json.loads(p.read_text()) + self._is_dark = data.get("dark", True) + except Exception: + pass + + def _load_user(self): + try: + p = self._user_path() + if p and p.exists(): + data = json.loads(p.read_text()) + self._is_dark = data.get("dark", self._is_dark) + except Exception: + pass + + def _save(self): + data = {"dark": self._is_dark} + try: + self._global_path().write_text(json.dumps(data)) + except Exception: + pass + try: + p = self._user_path() + if p: + p.write_text(json.dumps(data)) + except Exception: + pass + + def _notify(self): + for cb in list(self._listeners): + try: + cb() + except Exception: + logger.debug("Theme listener error", exc_info=True) + + +# --------------------------------------------------------------------------- +# Convenience accessors +# --------------------------------------------------------------------------- + +def c() -> ThemeColors: + """Shorthand for ThemeManager.instance().colors.""" + return ThemeManager.instance().colors + + +def qss() -> str: + """Shorthand for ThemeManager.instance().generate_qss().""" + return ThemeManager.instance().generate_qss() + + +def tm() -> ThemeManager: + """Shorthand for ThemeManager.instance().""" + return ThemeManager.instance() + + +# --------------------------------------------------------------------------- +# QSS generator +# --------------------------------------------------------------------------- + +_FONT_STACK = ( + '"Segoe UI Variable", "Segoe UI", "Helvetica Neue", ' + '"SF Pro Text", "Calibri", sans-serif' +) + + +def _build_qss(t: ThemeColors) -> str: + return f""" +/* ── Global ──────────────────────────────────────────────── */ +QWidget {{ + background-color: {t.bg_primary}; + color: {t.text_primary}; + font-family: {_FONT_STACK}; + font-size: 11pt; +}} + +/* ── Input fields ────────────────────────────────────────── */ +QLineEdit {{ + background-color: {t.bg_secondary}; + border: 1px solid {t.border}; + border-radius: 6px; + padding: 8px; + color: {t.text_primary}; +}} +QLineEdit:focus {{ + border: 1px solid {t.border_focus}; +}} + +/* ── Buttons ─────────────────────────────────────────────── */ +QPushButton {{ + background-color: {t.accent}; + color: {t.accent_text}; + border: none; + border-radius: 6px; + padding: 8px 16px; + font-weight: bold; +}} +QPushButton:hover {{ + background-color: {t.accent_hover}; +}} +QPushButton:pressed {{ + background-color: {t.accent_hover}; +}} +QPushButton#secondaryBtn {{ + background-color: {t.bg_secondary}; + color: {t.text_primary}; + font-weight: normal; +}} +QPushButton#secondaryBtn:hover {{ + background-color: {t.bg_hover}; +}} +QPushButton#toolBtn {{ + background-color: transparent; + border: none; + border-radius: 4px; + padding: 4px; +}} +QPushButton#toolBtn:hover {{ + background-color: {t.bg_hover}; +}} + +/* ── Lists ───────────────────────────────────────────────── */ +QListWidget {{ + background-color: {t.bg_tertiary}; + border: none; + border-radius: 6px; + padding: 4px; +}} +QListWidget::item {{ + padding: 10px; + border-radius: 4px; +}} +QListWidget::item:selected {{ + background-color: {t.bg_selected}; + border-left: 3px solid {t.accent}; +}} +QListWidget::item:hover {{ + background-color: {t.bg_hover}; + color: {t.text_primary}; +}} + +/* ── Text areas ──────────────────────────────────────────── */ +QTextEdit, QTextBrowser {{ + background-color: {t.bg_primary}; + border: none; + border-radius: 6px; + padding: 8px; + color: {t.text_primary}; +}} + +/* ── Scrollbar ───────────────────────────────────────────── */ +QScrollBar:vertical {{ + background: transparent; + width: 8px; + margin: 0; +}} +QScrollBar::handle:vertical {{ + background: {t.scrollbar}; + border-radius: 4px; + min-height: 30px; +}} +QScrollBar::handle:vertical:hover {{ + background: {t.text_muted}; +}} +QScrollBar::add-line:vertical, QScrollBar::sub-line:vertical {{ + height: 0; +}} +QScrollBar::add-page:vertical, QScrollBar::sub-page:vertical {{ + background: transparent; +}} +QScrollBar:horizontal {{ + background: transparent; + height: 8px; + margin: 0; +}} +QScrollBar::handle:horizontal {{ + background: {t.scrollbar}; + border-radius: 4px; + min-width: 30px; +}} +QScrollBar::handle:horizontal:hover {{ + background: {t.text_muted}; +}} +QScrollBar::add-line:horizontal, QScrollBar::sub-line:horizontal {{ + width: 0; +}} +QScrollBar::add-page:horizontal, QScrollBar::sub-page:horizontal {{ + background: transparent; +}} + +/* ── Title label ─────────────────────────────────────────── */ +QLabel#title {{ + font-size: 15pt; + font-weight: bold; + color: {t.accent}; +}} + +/* ── Sidebar panel ───────────────────────────────────────── */ +#sidebarPanel {{ + background-color: {t.bg_tertiary}; +}} + +/* ── Splitter ────────────────────────────────────────────── */ +QSplitter::handle {{ + background-color: {t.separator}; + width: 1px; +}} + +/* ── Checkbox ────────────────────────────────────────────── */ +QCheckBox {{ + color: {t.text_primary}; +}} + +/* ── Menus ───────────────────────────────────────────────── */ +QMenu {{ + background-color: {t.bg_secondary}; + border: 1px solid {t.border}; + border-radius: 6px; + padding: 4px; +}} +QMenu::item {{ + padding: 6px 20px; + color: {t.text_primary}; + border-radius: 4px; +}} +QMenu::item:selected {{ + background-color: {t.bg_hover}; +}} +QMenu::separator {{ + height: 1px; + background: {t.separator}; + margin: 4px 8px; +}} + +/* ── Dialogs ─────────────────────────────────────────────── */ +QDialog {{ + background-color: {t.bg_primary}; + color: {t.text_primary}; +}} + +/* ── MessageBox ──────────────────────────────────────────── */ +QMessageBox {{ + background-color: {t.bg_primary}; + color: {t.text_primary}; +}} +QMessageBox QLabel {{ + color: {t.text_primary}; +}} + +/* ── InputDialog ─────────────────────────────────────────── */ +QInputDialog {{ + background-color: {t.bg_primary}; + color: {t.text_primary}; +}} + +/* ── ScrollArea ──────────────────────────────────────────── */ +QScrollArea {{ + background-color: {t.bg_primary}; + border: none; +}} + +/* ── ToolTip ─────────────────────────────────────────────── */ +QToolTip {{ + background-color: {t.bg_secondary}; + color: {t.text_primary}; + border: 1px solid {t.border}; + padding: 4px 8px; + font-size: 9pt; +}} +""" diff --git a/uploads/041d292d-c94a-4c46-b8f0-b4ac02536d50.enc b/uploads/041d292d-c94a-4c46-b8f0-b4ac02536d50.enc new file mode 100644 index 0000000..d9f1f35 Binary files /dev/null and b/uploads/041d292d-c94a-4c46-b8f0-b4ac02536d50.enc differ diff --git a/uploads/055eca42-b4b5-4211-b819-030dc601e9b4.enc b/uploads/055eca42-b4b5-4211-b819-030dc601e9b4.enc new file mode 100644 index 0000000..494c7fb Binary files /dev/null and b/uploads/055eca42-b4b5-4211-b819-030dc601e9b4.enc differ diff --git a/uploads/0cb3e08e-b294-437d-8228-abf97c996311.enc b/uploads/0cb3e08e-b294-437d-8228-abf97c996311.enc new file mode 100644 index 0000000..8368ad2 Binary files /dev/null and b/uploads/0cb3e08e-b294-437d-8228-abf97c996311.enc differ diff --git a/uploads/189af45e-fa60-4fd6-8a3f-8313921d1f48.enc b/uploads/189af45e-fa60-4fd6-8a3f-8313921d1f48.enc new file mode 100644 index 0000000..a4f53f1 Binary files /dev/null and b/uploads/189af45e-fa60-4fd6-8a3f-8313921d1f48.enc differ diff --git a/uploads/1ce346fd-54f6-4e7a-85c4-7c5098e01d2a.enc b/uploads/1ce346fd-54f6-4e7a-85c4-7c5098e01d2a.enc new file mode 100644 index 0000000..197bb3e Binary files /dev/null and b/uploads/1ce346fd-54f6-4e7a-85c4-7c5098e01d2a.enc differ diff --git a/uploads/1fb0cffc-1e95-4f21-aba9-90fc398d1bb2.enc b/uploads/1fb0cffc-1e95-4f21-aba9-90fc398d1bb2.enc new file mode 100644 index 0000000..bbc5782 Binary files /dev/null and b/uploads/1fb0cffc-1e95-4f21-aba9-90fc398d1bb2.enc differ diff --git a/uploads/23737dc9-334b-4818-a258-758596e75aef.enc b/uploads/23737dc9-334b-4818-a258-758596e75aef.enc new file mode 100644 index 0000000..d4a9e0d Binary files /dev/null and b/uploads/23737dc9-334b-4818-a258-758596e75aef.enc differ diff --git a/uploads/37ef215b-b30e-4339-a7ab-43445df27526.enc b/uploads/37ef215b-b30e-4339-a7ab-43445df27526.enc new file mode 100644 index 0000000..59c4e9f Binary files /dev/null and b/uploads/37ef215b-b30e-4339-a7ab-43445df27526.enc differ diff --git a/uploads/41384719-ab8c-4b65-9823-91217d3bf3d3.enc b/uploads/41384719-ab8c-4b65-9823-91217d3bf3d3.enc new file mode 100644 index 0000000..bf87a11 Binary files /dev/null and b/uploads/41384719-ab8c-4b65-9823-91217d3bf3d3.enc differ diff --git a/uploads/4f661a78-fd57-469f-8af1-fd88bf8c167e.enc b/uploads/4f661a78-fd57-469f-8af1-fd88bf8c167e.enc new file mode 100644 index 0000000..b3966e7 Binary files /dev/null and b/uploads/4f661a78-fd57-469f-8af1-fd88bf8c167e.enc differ diff --git a/uploads/572a6f37-6a4f-4004-a95a-9e10a3080b7e.enc b/uploads/572a6f37-6a4f-4004-a95a-9e10a3080b7e.enc new file mode 100644 index 0000000..a6bc829 Binary files /dev/null and b/uploads/572a6f37-6a4f-4004-a95a-9e10a3080b7e.enc differ diff --git a/uploads/6aabd2ba-c64b-40ae-960a-a7a161c337db.enc b/uploads/6aabd2ba-c64b-40ae-960a-a7a161c337db.enc new file mode 100644 index 0000000..0cb6da1 Binary files /dev/null and b/uploads/6aabd2ba-c64b-40ae-960a-a7a161c337db.enc differ diff --git a/uploads/6deec74d-940d-498a-8f91-38424b935a13.enc b/uploads/6deec74d-940d-498a-8f91-38424b935a13.enc new file mode 100644 index 0000000..95b61f6 Binary files /dev/null and b/uploads/6deec74d-940d-498a-8f91-38424b935a13.enc differ diff --git a/uploads/73680cd7-4a47-4944-980f-4225e019527c.enc b/uploads/73680cd7-4a47-4944-980f-4225e019527c.enc new file mode 100644 index 0000000..6284f2e Binary files /dev/null and b/uploads/73680cd7-4a47-4944-980f-4225e019527c.enc differ diff --git a/uploads/7e32ef79-2c29-4466-8c1a-cd5cee3e430c.enc b/uploads/7e32ef79-2c29-4466-8c1a-cd5cee3e430c.enc new file mode 100644 index 0000000..d983335 Binary files /dev/null and b/uploads/7e32ef79-2c29-4466-8c1a-cd5cee3e430c.enc differ diff --git a/uploads/8cc77d3d-f28a-4b9c-80b6-adc6ec672214.enc b/uploads/8cc77d3d-f28a-4b9c-80b6-adc6ec672214.enc new file mode 100644 index 0000000..1900437 Binary files /dev/null and b/uploads/8cc77d3d-f28a-4b9c-80b6-adc6ec672214.enc differ diff --git a/uploads/8e81df99-8ae0-4348-842b-a3d0de0510f5.enc b/uploads/8e81df99-8ae0-4348-842b-a3d0de0510f5.enc new file mode 100644 index 0000000..7476fbb Binary files /dev/null and b/uploads/8e81df99-8ae0-4348-842b-a3d0de0510f5.enc differ diff --git a/uploads/9321986f-0918-4e59-b30b-bc8086a20508.enc b/uploads/9321986f-0918-4e59-b30b-bc8086a20508.enc new file mode 100644 index 0000000..1339c34 Binary files /dev/null and b/uploads/9321986f-0918-4e59-b30b-bc8086a20508.enc differ diff --git a/uploads/9a17095f-58bc-461c-a8cc-43a20ec78392.enc b/uploads/9a17095f-58bc-461c-a8cc-43a20ec78392.enc new file mode 100644 index 0000000..51686fd Binary files /dev/null and b/uploads/9a17095f-58bc-461c-a8cc-43a20ec78392.enc differ diff --git a/uploads/a5eb6b09-47ab-43c0-9d44-bef051c56a17.enc b/uploads/a5eb6b09-47ab-43c0-9d44-bef051c56a17.enc new file mode 100644 index 0000000..8608ab8 Binary files /dev/null and b/uploads/a5eb6b09-47ab-43c0-9d44-bef051c56a17.enc differ diff --git a/uploads/avatars/0b282232-9214-4fbe-a72d-2f07a74760e3.png b/uploads/avatars/0b282232-9214-4fbe-a72d-2f07a74760e3.png new file mode 100644 index 0000000..730164c Binary files /dev/null and b/uploads/avatars/0b282232-9214-4fbe-a72d-2f07a74760e3.png differ diff --git a/uploads/avatars/14420bdd-f4e3-4e57-87e9-4c553361cb1c.png b/uploads/avatars/14420bdd-f4e3-4e57-87e9-4c553361cb1c.png new file mode 100644 index 0000000..d28ec02 Binary files /dev/null and b/uploads/avatars/14420bdd-f4e3-4e57-87e9-4c553361cb1c.png differ diff --git a/uploads/avatars/1c4b7c64-fb08-4cc6-948f-5c9e46aa5b35.jpg b/uploads/avatars/1c4b7c64-fb08-4cc6-948f-5c9e46aa5b35.jpg new file mode 100644 index 0000000..aeda4ea Binary files /dev/null and b/uploads/avatars/1c4b7c64-fb08-4cc6-948f-5c9e46aa5b35.jpg differ diff --git a/uploads/avatars/59abf7ba-576c-4052-9042-6ccd18e68ded.png b/uploads/avatars/59abf7ba-576c-4052-9042-6ccd18e68ded.png new file mode 100644 index 0000000..6065b4d Binary files /dev/null and b/uploads/avatars/59abf7ba-576c-4052-9042-6ccd18e68ded.png differ diff --git a/uploads/avatars/74e3f5a5-df49-4f76-ba39-c5b5aab2e77b.jpg b/uploads/avatars/74e3f5a5-df49-4f76-ba39-c5b5aab2e77b.jpg new file mode 100644 index 0000000..cc053bb Binary files /dev/null and b/uploads/avatars/74e3f5a5-df49-4f76-ba39-c5b5aab2e77b.jpg differ diff --git a/uploads/avatars/88014d90-bcbc-499b-9fd8-8147ecf850c2.png b/uploads/avatars/88014d90-bcbc-499b-9fd8-8147ecf850c2.png new file mode 100644 index 0000000..b8041f8 Binary files /dev/null and b/uploads/avatars/88014d90-bcbc-499b-9fd8-8147ecf850c2.png differ diff --git a/uploads/avatars/b819e594-fd8f-4686-99d8-ca1e25b23082.jpg b/uploads/avatars/b819e594-fd8f-4686-99d8-ca1e25b23082.jpg new file mode 100644 index 0000000..ac5b77e Binary files /dev/null and b/uploads/avatars/b819e594-fd8f-4686-99d8-ca1e25b23082.jpg differ diff --git a/uploads/avatars/c8565a7f-7aee-44f1-a9b7-d7bbec0c3d81.jpg b/uploads/avatars/c8565a7f-7aee-44f1-a9b7-d7bbec0c3d81.jpg new file mode 100644 index 0000000..3d7326f Binary files /dev/null and b/uploads/avatars/c8565a7f-7aee-44f1-a9b7-d7bbec0c3d81.jpg differ diff --git a/uploads/avatars/ca301dba-e119-4e7a-9776-3a226488fcf0.jpg b/uploads/avatars/ca301dba-e119-4e7a-9776-3a226488fcf0.jpg new file mode 100644 index 0000000..f3b1691 Binary files /dev/null and b/uploads/avatars/ca301dba-e119-4e7a-9776-3a226488fcf0.jpg differ diff --git a/uploads/avatars/d461ed16-5d8c-4a1d-984a-af784d3f0925.jpg b/uploads/avatars/d461ed16-5d8c-4a1d-984a-af784d3f0925.jpg new file mode 100644 index 0000000..f3b1691 Binary files /dev/null and b/uploads/avatars/d461ed16-5d8c-4a1d-984a-af784d3f0925.jpg differ diff --git a/uploads/avatars/group_d299bac6-8fa1-46bc-bc13-8ace9c6ada10.png b/uploads/avatars/group_d299bac6-8fa1-46bc-bc13-8ace9c6ada10.png new file mode 100644 index 0000000..20d0270 Binary files /dev/null and b/uploads/avatars/group_d299bac6-8fa1-46bc-bc13-8ace9c6ada10.png differ diff --git a/uploads/b544b412-50ea-4ba1-9798-9410ec209bb3.enc b/uploads/b544b412-50ea-4ba1-9798-9410ec209bb3.enc new file mode 100644 index 0000000..9b3de0c Binary files /dev/null and b/uploads/b544b412-50ea-4ba1-9798-9410ec209bb3.enc differ diff --git a/uploads/c9996203-b91a-4afc-b0fc-f8eb578495ed.enc b/uploads/c9996203-b91a-4afc-b0fc-f8eb578495ed.enc new file mode 100644 index 0000000..f6d6811 Binary files /dev/null and b/uploads/c9996203-b91a-4afc-b0fc-f8eb578495ed.enc differ diff --git a/uploads/ca076a30-da8c-42e3-9d8a-f560221ae691.enc b/uploads/ca076a30-da8c-42e3-9d8a-f560221ae691.enc new file mode 100644 index 0000000..40189d4 Binary files /dev/null and b/uploads/ca076a30-da8c-42e3-9d8a-f560221ae691.enc differ diff --git a/uploads/ce6de468-c848-4e7f-a2ce-d30dd7adc901.enc b/uploads/ce6de468-c848-4e7f-a2ce-d30dd7adc901.enc new file mode 100644 index 0000000..54d8baf Binary files /dev/null and b/uploads/ce6de468-c848-4e7f-a2ce-d30dd7adc901.enc differ diff --git a/uploads/d5ea4c6e-ee82-49d9-a6fb-c437af7c0bdf.enc b/uploads/d5ea4c6e-ee82-49d9-a6fb-c437af7c0bdf.enc new file mode 100644 index 0000000..ff0eb78 Binary files /dev/null and b/uploads/d5ea4c6e-ee82-49d9-a6fb-c437af7c0bdf.enc differ diff --git a/uploads/e88e52c5-9ca6-47d9-89bd-7f22def1f745.enc b/uploads/e88e52c5-9ca6-47d9-89bd-7f22def1f745.enc new file mode 100644 index 0000000..dd34ad5 Binary files /dev/null and b/uploads/e88e52c5-9ca6-47d9-89bd-7f22def1f745.enc differ diff --git a/uploads/ecc7f33d-b013-409a-ad8a-94be1a8fba1d.enc b/uploads/ecc7f33d-b013-409a-ad8a-94be1a8fba1d.enc new file mode 100644 index 0000000..3889f23 Binary files /dev/null and b/uploads/ecc7f33d-b013-409a-ad8a-94be1a8fba1d.enc differ diff --git a/uploads/f095935b-d216-415e-866a-f2a44e95c506.enc b/uploads/f095935b-d216-415e-866a-f2a44e95c506.enc new file mode 100644 index 0000000..497801d Binary files /dev/null and b/uploads/f095935b-d216-415e-866a-f2a44e95c506.enc differ diff --git a/uploads/f28e6ac4-1e84-4c9a-b9fb-3af94f6e143e.enc b/uploads/f28e6ac4-1e84-4c9a-b9fb-3af94f6e143e.enc new file mode 100644 index 0000000..96d29a0 Binary files /dev/null and b/uploads/f28e6ac4-1e84-4c9a-b9fb-3af94f6e143e.enc differ diff --git a/zaloha/.gitignore b/zaloha/.gitignore new file mode 100644 index 0000000..a1096b0 --- /dev/null +++ b/zaloha/.gitignore @@ -0,0 +1,9 @@ +__pycache__/ +*.pyc +.env +.env.* +.encrypted_chat/ +certs/* +!certs/*.sh +!certs/*.example +!certs/README.md diff --git a/zaloha/CLAUDE.md b/zaloha/CLAUDE.md new file mode 100644 index 0000000..864121b --- /dev/null +++ b/zaloha/CLAUDE.md @@ -0,0 +1,758 @@ +# Encrypted Chat — Project Context + +End-to-end encrypted chat with forward secrecy (X3DH + Double Ratchet, Signal Protocol). +Server stores and relays opaque blobs — never sees plaintext. RSA retained for login only. + +## Files + +| File | Lines | Purpose | +|------|-------|---------| +| `schema.sql` | ~158 | MySQL schema (users, devices, signed_prekeys, one_time_prekeys, conversations, conversation_members, group_invitations, messages, message_recipients, group_sender_keys, message_reads, image_uploads, user_profiles) | +| `db.py` | ~1245 | MySQL CRUD — one connection per call, `dictionary=True` cursors, returns dicts. Includes profile CRUD, `get_user_contacts()`, `update_conversation_creator()`, `get_conversation()`. Phantom user CRUD + `upgrade_phantom_user()`. Invitation CRUD. Group avatar. Device CRUD. Per-device prekey/session management. | +| `server.py` | ~1986 | Asyncio TCP server, handler dispatch, rate limiting, real-time notifications via `connected_clients` dict. Profile + avatar handlers. Online/offline status push. Leave group, delete conversation, group invitations, group avatar handlers. Phantom user support. Graceful shutdown. 4 asyncio.Lock guards (H4 fix). Device registration + per-device key bundles + per-device notifications. SPK age reporting in `get_prekey_count`. | +| `protocol.py` | ~114 | Newline-delimited JSON protocol, `ProtocolReader`/`ProtocolWriter`, `encode_binary`/`decode_binary` (base64). Constants: `VERSION`, `MAX_MESSAGE_BYTES`, `MAX_IMAGE_BYTES`, `MAX_FILE_BYTES`, `IMAGE_CHUNK_SIZE`. | +| `crypto_utils.py` | ~812 | Ed25519, X25519, AES-256-GCM, HKDF, PBKDF2, X3DH, `DoubleRatchet` (with state snapshot/rollback), `SenderKeyState` (with state snapshot/rollback). RSA for login only. ECP1 password-based key encryption format (600k PBKDF2 iterations). | +| `chat_core.py` | ~2555 | `ChatClient` class — session management, X3DH/ratchet encryption, local key storage, reconnect, profiles, file sharing, leave group, delete conversation, invitations, group avatar. Multi-device: per-device sessions, device_id persistence, device bundle cache. SPK rotation (7-day) with grace period. Used by CLI + GUI | +| `client.py` | ~382 | Interactive CLI client | +| `gui_client.py` | ~2591 | PyQt6 GUI — `AsyncBridge` QThread bridges asyncio <-> Qt signals, `MainWindow`, `UserProfileDialog`, connection indicator + auto-reconnect, online status, file sharing, leave group, unread badges, circular avatars in conv list, online green dot overlay, group invitations UI, delete conversation, group avatar support | + +## Architecture & Data Flow + +### Encryption: X3DH + Double Ratchet (Signal Protocol) + +**Keys per user:** +- **RSA-4096** — Login challenge-response only (server stores public key). Password-encrypted with ECP1 format (PBKDF2 600k iterations + AES-256-GCM). +- **Identity Key (IK)** — Ed25519 (signing) + converted to X25519 (for DH in X3DH). Password-encrypted with ECP1 format. +- **Signed Pre-Key (SPK)** — X25519, signed by IK, uploaded to server. **Rotates every 7 days** (M4). Previous SPK kept for grace period (in-flight X3DH). +- **One-Time Pre-Keys (OPK)** — X25519, consumed on X3DH initiation, auto-replenished when count < 20 + +**DM flow:** +1. Alice fetches Bob's per-device key bundles (IK, SPK per device, OPK per device) -> X3DH per device -> shared secret per device +2. Double Ratchet initialized from shared secret — one session per (user, device) pair +3. Each message: symmetric ratchet (HMAC chain) -> message key -> AES-256-GCM +4. Each reply direction change: DH ratchet (new X25519 keypair) -> new root + chain keys +5. Per-device ciphertext — each recipient device gets individually encrypted blob +6. Self-encrypted copy uses SELF_DEVICE_ID sentinel, readable by all own devices + +**Group flow (Sender Keys):** +1. Each sender has own SenderKeyState per group +2. Sender key distributed to members via pairwise Double Ratchet (as control DM with `_sender_key` field) +3. Group messages: symmetric ratchet on sender key -> AES-256-GCM +4. Same ciphertext replicated to all recipients (efficient) + +### Protocol + +Newline-delimited JSON over TCP (optional TLS). Fields: `type`, `status`, `data`, `request_id`. +Binary data encoded as base64 via `encode_binary()`/`decode_binary()`. + +**Request/response pattern:** Client sends `{"type": "...", "request_id": "uuid", ...}`, server responds with same `request_id`. Notifications (push) have no `request_id`. + +### Server notifications (push to connected clients) +- `new_message` — per-recipient ciphertext included +- `messages_read` — conversation_id + user_id + message_ids +- `message_deleted` — message_id + conversation_id +- `conversation_created` — conversation_id, name, created_by, members[] (pushed to added members) +- `member_added` — conversation_id, user_id, username, email (pushed to all members except requester) +- `member_removed` — conversation_id, user_id (pushed to removed member + remaining members) +- `group_invitation` — conversation_id, conversation_name, invited_by, invited_by_username (pushed to invited user) +- `conversation_renamed` — conversation_id, name, renamed_by (pushed to all members except renamer) +- `session_reset` — from_user_id, from_device_id (pushed to peer when session reset requested) +- `user_online` — user_id (pushed to contacts when user connects) +- `user_offline` — user_id (pushed to contacts when user's last connection drops) +- `online_users` — user_ids[] (sent to user on login — list of currently online contacts) + +## DB Schema (schema.sql) + +``` +users: id, username, email (UNIQUE), rsa_public_key (TEXT), identity_key (BLOB 32B Ed25519), created_at +devices: id, user_id FK, device_name (nullable), created_at, last_seen_at +signed_prekeys: id, user_id FK, device_id (nullable), public_key (BLOB 32B), signature (BLOB 64B), created_at +one_time_prekeys: id, user_id FK, device_id (nullable), public_key (BLOB 32B) +conversations: id, created_at, name (nullable), created_by (nullable), avatar_file (nullable) +conversation_members: conversation_id + user_id (composite PK), joined_at +group_invitations: id, conversation_id FK, user_id FK, invited_by FK, created_at, UNIQUE(conversation_id, user_id) +messages: id, conversation_id FK, sender_id FK, sender_device_id (nullable), ratchet_header (BLOB JSON), + x3dh_header (BLOB JSON nullable), sender_chain_id (BLOB nullable), sender_chain_n (INT nullable), + created_at, deleted_at, image_file_id +message_recipients: message_id + user_id + device_id (composite PK), encrypted_content (BLOB), nonce (BLOB), + ratchet_header (BLOB nullable), x3dh_header (BLOB nullable) +group_sender_keys: conversation_id + sender_id + device_id (composite PK), chain_id (BLOB 32B), created_at +message_reads: message_id + user_id (composite PK), read_at +image_uploads: file_id (PK), conversation_id FK, uploader_id FK, file_size, completed, created_at +user_profiles: user_id (PK FK), phone, phone_visible, email_visible, location, location_visible, avatar_file, updated_at +``` + +Constant: `SELF_DEVICE_ID = "00000000-0000-0000-0000-000000000000"` — sentinel for self-encrypted copies and legacy rows. + +Index: `messages(conversation_id, created_at)` for query performance. + +## Server Protocol — All Message Types + +### Pre-login (no session required) +| Type | Handler | Purpose | +|------|---------|---------| +| `register` | `handle_register_start` | Start registration (username, email, public_key, identity_key) | +| `register_confirm` | `handle_register_confirm` | Confirm with 6-digit code | +| `login_start` | `handle_login_start` | Get RSA challenge | +| `login_finish` | `handle_login_finish` | Respond with RSA signature -> session. Client sends `client_version`, server returns `server_version` in response. Also sends `online_users` and `user_online` notifications. | +| `get_user_info` | `handle_get_user_info` | Get user info + identity_key (by email or user_id) | +| `pairing_start` | `handle_pairing_start` | New device starts pairing (gets 8-digit code) | +| `pairing_poll` | `handle_pairing_poll` | New device polls for key payload | + +### Post-login (session required) +| Type | Handler | Purpose | +|------|---------|---------| +| `upload_prekeys` | `handle_upload_prekeys` | Upload SPK + batch of OPKs (server verifies SPK signature) | +| `get_key_bundle` | `handle_get_key_bundle` | Fetch key bundle for X3DH (consumes one OPK) | +| `get_prekey_count` | `handle_get_prekey_count` | Check remaining OPK count + SPK age (`spk_created_at`) for rotation | +| `create_conversation` | `handle_create_conversation` | Create conversation — DMs auto-add both; groups add creator only + create invitations for others | +| `find_conversation` | `handle_find_conversation` | Find existing DM by email | +| `add_member` | `handle_add_member` | Create invitation for user to join group (was: direct add) | +| `remove_member` | `handle_remove_member` | Remove member (creator only) | +| `leave_group` | `handle_leave_group` | Voluntarily leave a group (transfers creator if needed, blocks DM leave) | +| `rename_conversation` | `handle_rename_conversation` | Rename group conversation (creator only, max 100 chars), pushes `conversation_renamed` to members | +| `delete_conversation` | `handle_delete_conversation` | Delete conversation — DMs: remove self; groups: creator-only, removes all members + files | +| `accept_invitation` | `handle_accept_invitation` | Accept pending group invitation → add to members, notify others | +| `decline_invitation` | `handle_decline_invitation` | Decline pending group invitation | +| `list_invitations` | `handle_list_invitations` | List user's pending invitations (with conv name + inviter username) | +| `list_conversations` | `handle_list_conversations` | List all user's conversations (includes avatar_file) | +| `send_message` | `handle_send_message` | Send encrypted message (ratchet_header + recipients[]) | +| `get_messages` | `handle_get_messages` | Get messages (returns per-user ciphertext, JOINs message_recipients) | +| `mark_read` | `handle_mark_read` | Mark messages as read | +| `delete_message` | `handle_delete_message` | Soft-delete message (sender only) | +| `rotate_keys` | `handle_rotate_keys` | Rotate RSA login key, disconnect other sessions | +| `pairing_claim` | `handle_pairing_claim` | Authorized device claims pairing code | +| `pairing_send` | `handle_pairing_send` | Authorized device sends encrypted key payload | +| `upload_image_start/chunk/end` | Image/file upload | Chunked encrypted upload (32KB chunks). `file_type` param: `"image"` (5MB limit) or `"file"` (50MB limit). | +| `download_image` | Image/file download | Chunked download with offset | +| `get_profile` | `handle_get_profile` | Get user profile (respects visibility for other users) | +| `update_profile` | `handle_update_profile` | Update own profile (phone, location, visibility toggles) | +| `update_avatar` | `handle_update_avatar` | Upload user avatar (base64, max 2MB, JPEG/PNG) | +| `get_avatar` | `handle_get_avatar` | Download user's avatar | +| `update_group_avatar` | `handle_update_group_avatar` | Upload group avatar (base64, max 2MB, JPEG/PNG, creator only) | +| `get_group_avatar` | `handle_get_group_avatar` | Download group avatar | +| `reencrypt_messages` | `handle_reencrypt_messages` | Batch re-encrypt message history with self-key (max 500/request, for device pairing) | +| `list_devices` | `handle_list_devices` | List all devices for current user | +| `remove_device` | `handle_remove_device` | Remove a device (not current device) | +| `session_reset` | `handle_session_reset` | Notify peer to reset corrupted Double Ratchet session (push `session_reset` to peer) | + +## Key Classes & Functions + +### crypto_utils.py + +**Password-based key encryption (ECP1 format):** +- `PBKDF2_ITERATIONS = 600_000` — OWASP 2023 compliant +- `_encrypt_private_key(raw_bytes, password) -> bytes` — PBKDF2-HMAC-SHA256 + AES-256-GCM. Format: `_ECP1_MAGIC(4) + salt(16) + nonce(12) + ciphertext_with_tag` +- `_decrypt_private_key(data, password) -> bytes` — Detects ECP1 magic prefix, derives key, decrypts + +**RSA (login only):** `generate_rsa_keypair()`, `serialize_private_key()` (ECP1 with password, PEM without), `serialize_public_key()`, `load_private_key()` (auto-detects ECP1 vs legacy PEM), `load_public_key()`, `rsa_sign()`, `rsa_verify()` + +**AES-256-GCM:** `aes_encrypt(plaintext, key=None) -> (key, nonce, ct, tag)`, `aes_decrypt(key, nonce, ct, tag) -> plaintext` + +**Ed25519:** `generate_identity_keypair()`, `serialize_ed25519_private()` (ECP1 with password, 32-byte raw without), `serialize_ed25519_private_raw()`, `serialize_ed25519_public()`, `load_ed25519_private()` (auto-detects ECP1 vs legacy PEM vs raw), `load_ed25519_public()`, `ed25519_sign()`, `ed25519_verify()` + +**X25519:** `generate_x25519_keypair()`, `serialize_x25519_private()`, `serialize_x25519_public()`, `load_x25519_private()`, `load_x25519_public()`, `x25519_dh()` + +**Key conversion:** `ed25519_private_to_x25519()` (SHA-512 + clamp), `ed25519_public_to_x25519()` (Montgomery u-coordinate) + +**HKDF:** `hkdf_derive()`, `kdf_rk(root_key, dh_output) -> (new_root_key, chain_key)`, `kdf_ck(chain_key) -> (new_chain_key, message_key)` + +**X3DH:** `generate_signed_prekey(identity_private) -> {private, public, signature, id}`, `generate_one_time_prekeys(count=50) -> [{private, public, id}]`, `x3dh_initiate(ik_private_ed, ik_public_remote_ed, spk_remote, spk_signature, opk_remote?) -> (shared_secret, ek_priv, ek_pub)`, `x3dh_respond(ik_private_ed, spk_private, ik_remote_ed, ek_remote, opk_private?) -> shared_secret` + +**DoubleRatchet class:** +- `init_alice(shared_secret, bob_spk_pub)` — initiator, performs first DH ratchet +- `init_bob(shared_secret, spk_pair)` — responder, waits for first message +- `encrypt(plaintext) -> {header: {dh_pub, n, pn}, ciphertext, nonce}` — AAD = serialized header +- `decrypt(header_dict, ciphertext, nonce)` — handles DH ratchet step if new dh_pub, skipped messages. **State snapshot/rollback on failure (M9):** `_snapshot()` captures all mutable state before modifications, `_restore()` rolls back on any exception. +- `_snapshot() -> dict` / `_restore(snap)` — Snapshot: dh_pair, dh_remote, root_key, send/recv chain keys, counters, skipped dict. Used internally by `decrypt()`. +- `export_state() -> bytes` / `import_state(data) -> DoubleRatchet` — JSON serialization + +**SenderKeyState class:** +- `__init__(sender_key=None)` — generates random 32B key if None +- `encrypt(plaintext) -> {chain_id, n, ciphertext, nonce}` — AAD = chain_id + message number +- `decrypt(chain_id_hex, n, ciphertext, nonce)` — fast-forwards chain if needed. **State snapshot/rollback on failure (M9):** snapshots chain_key, n, _known_keys before fast-forward, restores on exception. +- `export_key() -> bytes` — for distribution to group members +- `from_key(exported_key) -> SenderKeyState` — receiver initializes from exported key +- `export_state() / import_state()` — full state persistence + +### chat_core.py + +**Local key storage** (`~/.encrypted_chat/{email}/`): +``` +private.pem / public.pem — RSA (login, ECP1 format when password-encrypted) +identity_private.bin / identity_public.bin — Ed25519 (ECP1 format when password-encrypted, 32B raw otherwise) +device_id.txt — This device's UUID +spk_private.bin / spk_id.txt — Current signed prekey +prev_spk_private.bin / prev_spk_id.txt — Previous SPK for grace period (M4, in-flight X3DH) +opk_private/{opk_id}.bin — One-time prekeys +sessions/{user_id}_{device_id}.bin — Double Ratchet states (per peer device) +sender_keys/{conv_id}.bin — Own sender key states +sender_keys_recv/{conv_id}_{sender_id}_{device_id}.bin — Received sender keys (per sender device) +``` + +Storage functions: `save_keys()`, `load_keys()`, `_save_identity_keys()`, `_load_identity_keys()`, `_save_spk()`, `_load_spk()`, `_save_prev_spk()`, `_load_prev_spk()`, `_save_opk_private()`, `_load_opk_private()`, `_delete_opk_private()`, `_save_session()`, `_load_session()`, `_save_sender_key_state()`, `_load_sender_key_state()`, `_save_recv_sender_key()`, `_load_recv_sender_key()` + +**ChatClient attributes:** +- `private_key` / `public_key` — RSA (login) +- `identity_private` / `identity_public` — Ed25519 +- `spk_private` / `spk_id` — Current SPK +- `_prev_spk_private` / `_prev_spk_id` — Previous SPK for grace period (M4) +- `opk_privates: dict[str, X25519PrivateKey]` — OPK private keys by ID +- `device_id: str | None` — this device's UUID (persisted to disk) +- `sessions: dict[str, DoubleRatchet]` — "user_id:device_id" -> ratchet (per peer device) +- `sender_key_states: dict[str, SenderKeyState]` — conv_id -> own sender key +- `recv_sender_keys: dict[str, SenderKeyState]` — "conv_id:sender_id:device_id" -> their key +- `_device_bundle_cache: dict[str, tuple[float, list]]` — user_id -> (timestamp, device_bundles) with 5-min TTL +- `_user_cache: dict[str, dict]` — user_id -> {identity_key, username, email} +- `connected: bool` — current connection state + +**Key methods:** +- `register()` — Generates RSA + Ed25519, sends to server +- `confirm_registration()` — Confirms code, uploads prekeys (SPK + 50 OPKs) +- `login()` — Loads keys from disk (including prev_spk for grace period), RSA challenge-response, auto `_ensure_prekeys()` +- `_ensure_prekeys()` — Checks OPK count AND SPK age. Replenishes OPKs if < 20, **rotates SPK if >= 7 days old** (M4). Saves old SPK as grace period before generating new one. +- `_get_device_bundles(peer_user_id)` — Fetches per-device key bundles with 5-min TTL cache +- `_get_or_create_session(peer_user_id, peer_device_id, bundle)` — Loads from memory/disk or creates via X3DH, keyed by "user:device" +- `_process_x3dh_header(sender_id, x3dh_header, sender_device_id, spk_override?)` — Bob side of X3DH. `spk_override` param allows using previous SPK for grace period (M4). +- `send_message(conv_id, text, members, reply_to?)` — Routes to `_send_dm` or `_send_group_message` +- `_send_dm()` — Per-device Double Ratchet (encrypts for each of recipient's devices), self-encrypted copy with SELF_DEVICE_ID +- `_send_group_message()` — Sender Keys, distributes key if new (per-device) +- `_distribute_sender_key()` — Sends sender key as control message via per-device pairwise ratchet, includes sender_device_id +- `_decrypt_dm()` — Handles X3DH header for new sessions, returns None for control messages. On X3DH decrypt failure, retries with previous SPK (M4 grace period). +- `_decrypt_group()` — Uses received sender key chain +- `decrypt_notification()` — Returns None for control messages (sender key distribution) +- `get_messages()` — Batch decrypt, marks read, skips control messages +- `authorize_device()` — Exports RSA + Ed25519 only (simplified for multi-device — no session/sender key transfer) +- `pairing_wait()` — Imports RSA + identity key from paired device (new device generates own SPK + OPKs on login) +- `reconnect()` — Closes connection, re-establishes TCP + RSA login using in-memory keys +- `get_profile(user_id?)` — Gets user profile from server +- `update_profile(**fields)` — Updates own profile (phone, location, visibility) +- `update_avatar(image_data)` — Uploads avatar +- `get_avatar(user_id)` — Downloads avatar bytes +- `send_file(conv_id, file_path, members, reply_to?)` — Encrypt + chunked upload + send message with `file` payload +- `download_file(file_id, file_info)` — Chunked download + AES-GCM decrypt +- `leave_group(conv_id)` — Leave group, clean up local sender keys +- `rename_conversation(conv_id, name)` — Rename group (creator only) +- `delete_conversation(conv_id)` — Delete conversation, clean up local sender keys +- `accept_invitation(conv_id)` — Accept group invitation +- `decline_invitation(conv_id)` — Decline group invitation +- `list_invitations()` — Fetch pending invitations +- `update_group_avatar(conv_id, image_data)` — Upload group avatar +- `get_group_avatar(conv_id)` — Download group avatar +- `search_messages(conv_id, query)` — Search decrypted message cache (client-side only) +- `reset_session(peer_user_id, peer_device_id?)` — Delete local session + notify peer via server +- `handle_session_reset_notification(from_user_id, from_device_id?)` — Handle incoming session reset + +### gui_client.py + +**AsyncBridge (QThread):** Runs asyncio event loop, `schedule(coro)` queues coroutines, pyqtSignals emit results back to Qt main thread. + +**Key signals:** `login_result`, `conversations_loaded`, `messages_loaded`, `message_sent`, `new_notification`, `messages_read_notification`, `message_deleted_notification`, `conversation_updated`, `connection_state_changed`, `profile_loaded`, `profile_updated`, `avatar_loaded`, `online_status_changed`, `online_users_loaded`, `file_sent`, `file_downloaded`, `group_left`, `conversation_deleted`, `invitations_loaded`, `invitation_result`, `invitation_received`, `group_avatar_loaded`, `group_avatar_updated`, `session_reset_notification` + +**MainWindow:** Dark theme (Catppuccin Mocha), conversation list with circular avatars + online green dot overlay + unread count badges, message bubbles with colored left border, context menu (reply, delete, view image, download file), image thumbnails via QTextDocument resources (`thumb://{file_id}`), file cards with download links (`file://{file_id}`), connection indicator dot (green/red/orange), profile button, attach menu (Image/File), Leave Group button in group info, delete conversation button (trash icon in header), group avatar display + change in group info dialog, invitation list (amber border) above conversation list with right-click accept/decline. + +**UserProfileDialog:** View (read-only) and edit (own profile) modes. Fields: avatar (circular crop), username, email, phone, location, visibility toggles. Avatar upload/download. Opened from "My Profile" button or user info button in group info dialog. + +**Avatar system in conversation list:** +- `_avatar_cache: dict[str, QPixmap]` — user avatars by user_id +- `_group_avatar_cache: dict[str, QPixmap]` — group avatars by conv_id +- `_avatar_requested: set[str]` / `_group_avatar_requested: set[str]` — dedup download requests +- `_make_circular_avatar(pixmap, size=32)` — QPainter circular crop +- `_make_default_avatar(username, size=32)` — colored circle with initial letter (deterministic color from username hash) +- `_add_online_dot(avatar)` — green dot overlay bottom-right +- `_get_conv_avatar(conv)` — returns QIcon (DM: user avatar + online dot; group: group avatar or default) +- Periodic refresh every 2 minutes via `_refresh_timer` / `_on_periodic_refresh()` + +## Important Implementation Details + +### X3DH Header Caching +When `_get_or_create_session()` creates a new session via X3DH, it attaches the X3DH header as `ratchet._x3dh_header`. The next `_send_dm()` reads and deletes it. This ensures the X3DH header is only sent with the first message. + +### Self-Encryption for DMs +Sender uses `derive_self_encryption_key(identity_private)` to encrypt their own copy of sent messages with a static AES key. Uses `SELF_DEVICE_ID` sentinel so all own devices can read it. This allows reading own sent messages when fetching history from any device. + +### Sender Key Distribution as Control Messages +Sender keys are distributed via normal `send_message` protocol (per-device pairwise ratchet). The payload contains `_sender_key: {conv_id, key, sender_device_id}` field. On decryption, `_decrypt_dm()` detects this field, stores the sender key keyed by `"conv_id:sender_id:sender_device_id"`, and returns `None` (not shown to user). + +### Group Messages: Dummy Ratchet Header +Group messages use `{"dh_pub": "00"*32, "n": 0, "pn": 0}` as ratchet_header because the server requires it, but groups use sender keys instead of Double Ratchet. + +### Multi-Device Architecture +Each device has independent Double Ratchet sessions. Sessions are keyed by `"peer_user_id:peer_device_id"`. When sending a DM, the client fetches per-device key bundles via `_get_device_bundles()` and encrypts separately for each device. The server registers devices at login (`handle_login_finish`), assigns device IDs, and routes notifications with `device_entries` arrays (one entry per recipient device). Device IDs are persisted to `~/.encrypted_chat/{email}/device_id.txt`. Old session files (`{user_id}.bin`) are automatically migrated to `{user_id}_{device_id}.bin` on first load. + +### Server Session Model +`connected_clients: dict[str, list[ProtocolWriter]]` — one user can have multiple connections (multi-device). `writer_device_map: dict[int, str]` maps `id(writer)` to `device_id`. Notifications are pushed to all connections except the sender's current one. + +### Device Registration +On `login_finish`, server checks for `device_id` in the request. If present and valid (belongs to user), reuses it. Otherwise creates a new device. Device ID returned in response and stored on client disk. `list_devices` and `remove_device` handlers for device management. + +### Simplified Pairing (Multi-Device) +`authorize_device()` only exports RSA + identity key (no sessions/sender keys). New device generates its own SPK + OPKs on first login, creates independent sessions via X3DH. Old messages readable via self-encryption (shared identity key). `reencrypt_history()` still runs to ensure all messages have self-encrypted copies. + +### Real-time Conversation Notifications +`handle_create_conversation`, `handle_add_member`, `handle_remove_member`, `handle_leave_group`, `handle_delete_conversation`, `handle_accept_invitation` push notifications to affected members via `connected_clients`. Types: `conversation_created`, `member_added`, `member_removed`, `group_invitation`. GUI handles these via `conversation_updated` signal -> refreshes conversation list. + +### Connection State + Auto-Reconnect +`ChatClient.connected` flag tracks TCP connection state. `_background_listener` sets `connected = False` when server closes connection and **fails all pending futures** with `ConnectionError` (prevents `send_and_recv` from hanging forever). `send_and_recv` has a 30s timeout via `asyncio.wait_for` and catches `ConnectionError`/`TimeoutError`. `reconnect()` re-establishes TCP + RSA challenge-response using in-memory keys (no password needed, includes `device_id`). GUI `_notification_loop` detects listener death -> triggers `_auto_reconnect` with exponential backoff (1s->2s->4s->...->30s). Connection indicator dot: green (connected), red (disconnected), orange (reconnecting). + +### Server Per-Message Error Handling +Server dispatch loop wraps each handler call in individual try/except. Handler crashes return "Internal server error" response instead of killing the entire connection. Errors logged with `exc_info=True` for full tracebacks. GUI `_do_send_message`/`_do_find_or_create_and_send` catch exceptions and emit error signal (prevents silent hang when send fails). + +### Online/Offline Status +- `db.get_user_contacts(user_id)` returns all user IDs sharing at least one conversation +- On login (`handle_login_finish`): server sends `online_users` list to new user + `user_online` to all contacts (only if user was fully offline before) +- On disconnect (`handle_client` finally block): if last connection drops, server sends `user_offline` to all contacts +- `_background_listener` routes `user_online`, `user_offline`, `online_users` to notification queue +- GUI: `_online_users: set[str]` tracks online users, green dot overlay on circular avatar in DM conversation list + green circle emoji in group info member list + +### Leave Group +- `handle_leave_group` in server.py: validates membership, blocks DM leave (len<=2 and no name), transfers creator to first remaining member if creator leaves, removes member, notifies remaining via `member_removed` +- `ChatClient.leave_group()`: sends request, cleans up local sender key states on success +- GUI: red "Leave Group" button in group info dialog, confirmation dialog, resets view on success + +### Delete Conversation +- **DMs:** Any member can delete. Only removes the deleting user from `conversation_members`. If both users delete, 0 members remain → conversation + files cleaned up. +- **Groups:** Only the creator (admin) can delete. Removes ALL members, cleans up `.enc` files from disk, deletes conversation via CASCADE. +- Server notifies remaining members via `member_removed` push. +- GUI: trash icon button in conversation header. Visible for DMs always, for groups only when user is creator. +- `chat_core.py`: cleans up local sender key states after successful delete. + +### Group Invitations +- **Flow:** `create_conversation` (group) or `add_member` → creates invitation → pushes `group_invitation` notification → invitee sees in invitation list → Accept (adds to members, notifies) / Decline (deletes invitation) +- **DMs are unaffected:** `create_conversation` for DMs still auto-adds both members +- **DB:** `group_invitations` table with UNIQUE(conversation_id, user_id) to prevent duplicates +- **Server:** `handle_accept_invitation` verifies invitation exists, adds member, deletes invitation, notifies existing members via `member_added`. `handle_decline_invitation` just deletes. +- **GUI:** `inv_list` QListWidget (max 120px, amber border) above `conv_list`. Right-click → Accept/Decline. `invitation_received` signal triggers refresh + notification banner. +- **Routing fix (IMPORTANT):** `group_invitation` must be in the notification types list in `chat_core.py:_background_listener` (~line 304). Without it, invitations get routed to `_response_queue` and the GUI never sees them. + +### Group Avatars +- Stored as files in `UPLOAD_DIR/avatars/group_{conv_id}.{ext}` (PNG or JPEG detected from magic bytes) +- `conversations.avatar_file` column stores the filename +- `list_conversations` response includes `avatar_file` so GUI knows which groups have avatars +- GUI: `_group_avatar_cache` dict, `_get_conv_avatar()` returns group avatar icon or default letter circle +- Group Info dialog shows 64px circular avatar + "Change Avatar" button (creator only) +- Periodic refresh every 2 minutes re-downloads all known group avatars + +### File Sharing +- Reuses image upload/download infrastructure (`upload_image_start/chunk/end`, `download_image`) +- `upload_image_start` accepts optional `file_type` param: `"image"` (MAX_IMAGE_BYTES=5MB) or `"file"` (MAX_FILE_BYTES=50MB) +- `ChatClient.send_file()`: reads raw file, AES-256-GCM encrypts, chunked upload, sends message with `file` field in payload (`{file_id, aes_key, iv, filename, size, mime_type}`) +- `ChatClient.download_file()`: identical to `download_image()` — chunked download + AES-GCM decrypt +- GUI: attach button is dropdown menu (Image / File), file messages render as styled cards with paperclip icon (transparent background, border) and clickable download link (`file://{file_id}`), context menu "Download file" option +- Files stored as `.enc` in UPLOAD_DIR, same as images + +### Unread Count Badges +- `_unread_counts: dict[str, int]` replaces old `_unread_convs: set` +- `_on_notification()` increments count per conversation +- `_on_conv_selected()` clears count for selected conversation +- Display: `(3) Username` with bold font, instead of old `● Username` + +### User Profiles +`user_profiles` table separated from `users` (clean separation, users = auth only). Default profile created on registration (`db.create_default_profile`). Visibility rules applied server-side in `db.get_user_profile(viewer_id)`. Avatars stored as files in `UPLOAD_DIR/avatars/{user_id}.{ext}` (not in DB). Format detection from magic bytes (PNG header vs default JPEG). UserProfileDialog shows circular cropped avatar (QPainter). + +### Prekey Replenishment + SPK Rotation +After login, `_ensure_prekeys()` is called as a background task. Checks two things: +1. **OPK count** — if < 20, generates and uploads a new batch of 50 +2. **SPK age** — server returns `spk_created_at` in `get_prekey_count` response. If SPK is >= 7 days old (`SPK_ROTATION_DAYS`), triggers rotation: saves current SPK as `prev_spk_private.bin`/`prev_spk_id.txt` (grace period), generates new SPK, uploads to server. + +### Password-Based Key Encryption (ECP1 Format) — M3 +Private keys (RSA, Ed25519) are encrypted with a custom envelope instead of `BestAvailableEncryption`: +- **Key derivation:** PBKDF2-HMAC-SHA256 with 600,000 iterations (OWASP 2023 compliant) +- **Encryption:** AES-256-GCM with the derived key, magic bytes as AAD +- **Format:** `_ECP1_MAGIC("ECP1", 4B) + salt(16B) + nonce(12B) + ciphertext_with_tag(N+16B)` +- **Backward compatibility:** `load_private_key()` and `load_ed25519_private()` detect ECP1 magic prefix. If absent, fall back to legacy PEM parsing (old `BestAvailableEncryption` format). On next save, files are re-encrypted in ECP1 format. +- **Functions:** `_encrypt_private_key()`, `_decrypt_private_key()` in `crypto_utils.py` +- **Applied to:** `serialize_private_key()` (RSA), `serialize_ed25519_private()` (Ed25519) + +### SPK Rotation (7-Day Cycle) — M4 +Signed Pre-Keys rotate periodically to limit exposure from a compromised SPK: +- **Rotation interval:** `SPK_ROTATION_DAYS = 7` (constant in `chat_core.py`) +- **Trigger:** `_ensure_prekeys()` checks `spk_created_at` from `get_prekey_count` response. If age >= 7 days, calls `_generate_and_upload_prekeys()`. +- **Grace period:** Before generating a new SPK, the current one is saved as `prev_spk_private.bin` / `prev_spk_id.txt`. Loaded on login into `_prev_spk_private` / `_prev_spk_id`. +- **Fallback on decrypt:** When `_decrypt_dm()` processes an X3DH header and decryption fails with the current SPK, it retries with the previous SPK via `_process_x3dh_header(..., spk_override=self._prev_spk_private)`. This handles in-flight X3DH initiated before rotation. +- **Server side:** `get_signed_prekey()` in `db.py` returns `created_at` column. `handle_get_prekey_count` includes `spk_created_at` (ISO format) in response. +- **Server SPK replacement:** `store_signed_prekey()` deletes old SPK and inserts new one — only one active SPK per device on server. + +### Ratchet State Rollback on Decrypt Failure — M9 +Both `DoubleRatchet.decrypt()` and `SenderKeyState.decrypt()` modify internal state (chain keys, counters, DH keys) before attempting AES-GCM decryption. If decryption fails (corrupted data, wrong key, AAD mismatch), the state would be permanently corrupted. + +**DoubleRatchet fix:** +- `_snapshot()` captures all mutable fields: `dh_pair`, `dh_remote`, `root_key`, `send_chain_key`, `recv_chain_key`, `send_n`, `recv_n`, `prev_send_n`, `skipped` dict (shallow copy) +- `decrypt()` takes snapshot before any state modification, wraps the entire DH ratchet + chain advance + AES-GCM decrypt in try/except, calls `_restore()` on failure +- Special case: skipped message decryption (no state modification needed) — if AES-GCM fails, the popped key is restored to `skipped` dict + +**SenderKeyState fix:** +- Before fast-forwarding the chain, snapshots `chain_key`, `n`, `_known_keys` (shallow copy) +- On any exception during fast-forward or AES-GCM decrypt, all three are restored + +### Rate Limits +- Per-IP+email window (60s): register 3/min, login 10/min, send_message 20/min +- Per-connection: 20 req/s +- Per-IP: max 10 connections, global max 200 +- Pairing: TTL 120s, max 90 poll attempts, pairing_start 10/min, pairing_poll 120/min, client polls every 2s + +### GUI Font Handling (IMPORTANT) +All widget stylesheet `font-size` declarations use `pt` (not `px`). Using `px` in Qt stylesheets sets `pixelSize` and leaves `pointSize=-1`, which causes `QFont::setPointSize: Point size <= 0` warnings on Windows. Conversion: `pt ~= px * 0.75` at 96 DPI. HTML styles inside QTextBrowser (`_render_single_message_html`) still use `px` — that's fine, QTextBrowser uses its own HTML renderer. Bold fonts for list items use `_bold_font()` helper + `item.setData(FontRole)` to avoid the same issue. + +### Phantom Users (Anti User-Enumeration) +- When a user creates a conversation with an unregistered email, the server creates a "phantom" user with `rsa_public_key = 'PHANTOM'` marker +- Phantom users have real crypto keys (Ed25519 IK, X25519 SPK + 5 OPKs) so X3DH works on the client side +- `handle_find_conversation` and `handle_create_conversation` create phantoms instead of returning "User not found" +- `handle_send_message` skips phantom recipients when storing `message_recipients` — only sender's self-encrypted copy is saved +- `phantom_user_ids: set[str]` in-memory cache loaded at startup from DB, updated on create/delete +- On registration (`handle_register_confirm`): if email belongs to a phantom, the phantom is **upgraded in-place** via `db.upgrade_phantom_user()` — preserves user_id and all FK references (invitations, conversation_members). Phantom's server-generated prekeys are deleted (real user uploads own). +- `handle_create_conversation` (groups) and `handle_add_member` create invitations for phantom users too. Push notifications only sent to non-phantom users. When phantom registers and logs in, they see pending invitations. +- Messages sent to phantom users are NOT stored and NOT recoverable after registration — this is by design (prevents user enumeration, sender sees own messages via self-encryption) +- DB functions: `db.create_phantom_user(email)`, `db.is_phantom_user(user_id)`, `db.delete_phantom_user(user_id)`, `db.upgrade_phantom_user(phantom_id, username, rsa_public_key_pem, identity_key)`, `db.get_all_phantom_user_ids()` + +### Logout/Login Fix +- `_is_logout` flag in MainWindow prevents `closeEvent()` from calling `bridge.stop()` which killed the asyncio loop +- On logout: set `_is_logout = True`, call `bridge.logout()`, then `close()` +- `closeEvent()` only calls `bridge.stop()` if `not self._is_logout` +- This allows `main()` to re-create the login/main windows after logout + +### Server Graceful Shutdown +- SIGINT handler force-closes all writers in `connected_clients` before the asyncio server context manager exits +- Without this, `async with server:` waited forever for `handle_client` loops to finish + +### Version Negotiation +- `VERSION = "0.8"` constant in `protocol.py` (shared between client and server) +- Client sends `client_version` in `login_finish` request (both `login()` and `reconnect()`) +- Server logs `client_version`, returns `server_version` in `login_finish` response +- Server startup log includes version: `"Encrypted chat server v0.8 listening on ..."` +- Future: server can reject incompatible client versions, client can warn about outdated server + +## Conventions + +- Server handlers: `handle_(msg, session, writer)` — registered in dispatch table in `handle_client()` +- DB functions: one `get_connection()` per call, `cursor(dictionary=True)`, returns dicts +- Binary data: always base64 in protocol (`encode_binary`/`decode_binary`) +- GUI signals: bridge emits `pyqtSignal`, MainWindow connects in `_connect_signals()` +- Error responses: `{"status": "error", "data": {"message": "..."}}` +- Notification decrypt returning `None` = control message, skip silently +- GUI stylesheet font sizes: always `pt`, never `px` (see Font Handling section above) +- File sharing reuses image upload infrastructure with `file_type` parameter +- Avatar files stored in `UPLOAD_DIR/avatars/` — user: `{user_id}.{ext}`, group: `group_{conv_id}.{ext}` + +## Aktuální stav práce + +### ✅ Dokončeno (tato session) +- Logout/login bug fix — `_is_logout` flag prevents bridge.stop() on logout +- Hover text readability — `color: #cdd6f4;` added to `QListWidget::item:hover` +- File card background — `background:transparent; border:1px solid #45475a` +- Delete conversations — full stack (db, server, chat_core, gui), DMs + groups (creator-only), file cleanup from disk +- Group invitation system — full stack (schema, db, server, chat_core, gui), create/accept/decline, real-time notification, invitation list UI +- Circular avatars in conversation list — QPainter circular crop, default letter avatars, online green dot overlay +- Group avatar support — upload/download, display in group info dialog, "Change Avatar" button (creator only) +- Server graceful shutdown — force-close connected clients on SIGINT +- Profile dialog avatar circular crop — QPainter in UserProfileDialog._on_avatar_loaded +- Periodic refresh timer — 2-minute QTimer re-downloads avatars + invitations +- Group invitation notification fix — `group_invitation` added to `_background_listener` notification types +- Delete button in conversation header — trash icon for DMs always, groups creator-only +- File cleanup on conversation delete — `db.get_conversation_file_ids()` + unlink `.enc` files +- Removed right-click "Delete conversation" from conversation list context menu +- README.md updated +- **H4 Race conditions fix** — 4 asyncio.Lock guards (`_clients_lock`, `_conn_lock`, `_pairing_lock`, `_uploads_lock`) pro všechny sdílené mutable struktury v server.py. `_notify_users()` + `_notify_users_individual()` helpery. Rate limit memory cleanup v periodic task. Všechny I/O operace mimo kritické sekce. +- **Unread counts pro offline uživatele** — `db.get_unread_counts()` dotaz přes `message_reads` + `message_recipients`, server vrací `unread_count` v `list_conversations`, GUI populuje `_unread_counts` ze serverových dat (max z server vs local). Opravuje bug kdy offline uživatel po přihlášení neviděl nepřečtené zprávy. +- **C6 Path traversal fix** — `_UUID_RE` regex + `_valid_file_id()`, `_safe_upload_path()`, `_safe_avatar_path()` helpery v server.py. UUID validace v `handle_upload_image_start`, `handle_download_image`. `is_relative_to()` guard ve všech path konstrukcích: upload start/end, download, delete_message file cleanup, delete_conversation file cleanup, _cleanup_uploads, get/update avatar, get/update group avatar. Celkem 10 guardovaných míst. +- **C3+H1+M13 Lokální šifrování + permissions** — `derive_local_storage_key()` v crypto_utils.py (HKDF z identity key, odlišný salt/info od self-encryption key). `_encrypt_local()`/`_decrypt_local()` helpery v chat_core.py (AES-256-GCM, formát: nonce(12)+tag(16)+ct). `_save_session`/`_load_session`, `_save_sender_key_state`/`_load_sender_key_state`, `_save_recv_sender_key`/`_load_recv_sender_key` — volitelný `local_key` parametr, při nastavení šifruje/dešifruje, `chmod 0o600` na soubory. `ChatClient._local_key` derivováno při login/registraci/pairingu. Transparentní migrace: pokud dešifrování selže, zkusí plaintext a re-uloží šifrovaně. `os.chmod(d, 0o700)` na všechny `mkdir()` v get_key_dir, opk_private, sessions, sender_keys, sender_keys_recv, message_cache. `os.chmod(p, 0o600)` na plaintext fallback message cache. +- **H7 Avatar path traversal** — `_safe_avatar_path()` guard na handle_get_avatar, handle_get_group_avatar + defense-in-depth na handle_update_avatar, handle_update_group_avatar. +- **Multi-device support (per-device sessions)** — `devices` table, `device_id` columns on prekeys/messages/recipients/sender_keys. Server: device registration at login, `writer_device_map`, per-device key bundles (`device_bundles` array), per-device notification routing (`device_entries`), `list_devices`/`remove_device` handlers. Client: `device_id` persistence, sessions keyed by `"user_id:device_id"`, `_get_device_bundles()` with 5-min TTL cache, per-device encryption in `_send_dm`/`send_image`/`send_file`/`_distribute_sender_key`, `sender_device_id` in decrypt routing, `decrypt_notification()` handles `device_entries` format. Pairing simplified: only RSA + identity key transfer, new device generates own SPK + OPKs. Migration: old session files auto-migrated, backward compat with old clients/servers. H12 OPK race condition fixed (SELECT FOR UPDATE). +- **Connection resilience fixes** — `_background_listener` fails all pending futures with `ConnectionError` on disconnect (prevents hang). `send_and_recv` has 30s timeout + catches `ConnectionError`. Server dispatch has per-message try/except (handler crash no longer kills connection). GUI `_do_send_message`/`_do_find_or_create_and_send` catch exceptions and emit error signal. +- **DB transaction fix** — `db.get_key_bundles_for_user()` had "Transaction already in progress" error because mysql-connector starts implicit transactions. Fixed with `conn.commit()` before `conn.start_transaction()`. +- **H5+H6 Protocol error handling** — `decode_binary()` catches `binascii.Error` → `ValueError`. `parse_message()` catches `JSONDecodeError`/`UnicodeDecodeError` → `ValueError`. Server dispatch already handles `ValueError` from `read_message()` gracefully. +- **H3+H13 Anti-enumeration** — `handle_register_start` returns same "ok" response for existing email (no "Email already in use" leak). `handle_login_start` returns fake challenge for non-existent email. `handle_login_finish` returns generic "Invalid credentials" for all failure cases. `get_user_info` moved behind auth barrier (requires login). +- **H8 Password memory cleanup** — `register()`, `login()`, `pairing_wait()` convert password to `bytearray`, zero out in `finally` block after key derivation. +- **H10 Image validation** — `_safe_load_image()` helper validates size (<10MB) and dimensions (<8192px) before `QImage.fromData()`. Applied to all 6 image loading locations in gui_client.py. +- **H11 Filename sanitization** — `_safe_filename()` helper strips path components via `os.path.basename()`. Applied to save dialogs and image dialog title. +- **C1+C2+C5 DoS hardening** — C1: `LimitOverrunError` now drains buffer and raises `ValueError` (server sends error response instead of silent disconnect; memory already protected by `limit=` on StreamReader). C2: `MAX_SENDER_KEY_SKIP` reduced from 1000 to 256 (matches DoubleRatchet `MAX_SKIP`). C5: `handle_upload_image_end` validates `received_bytes == file_size` before completing upload. M12 (upload end size validation) also fixed by C5. +- **M2+M8+M10+M11 Security hardening batch** — M2: SenderKeyState HKDF salt changed from `b""` to `b"\x00" * 32` (matches X3DH convention). M8: `_valid_file_id()` renamed to `_valid_uuid()`, UUID validation added to all handlers accepting client-provided `conv_id`, `user_id`, `message_id`, `device_id`. M10: `handle_mark_read` caps `message_ids` to 500 (prevents slow SQL DoS). M11: `handle_pairing_start` generates `poll_token` (secrets.token_hex(16)), `handle_pairing_poll` requires and validates it via `secrets.compare_digest()` — prevents unauthorized poll/payload extraction. +- **H2+H14 TLS hardening** — `TLS_INSECURE` a `TLS_AUTOGEN` nyní vyžadují `ENVIRONMENT=dev` (RuntimeError bez toho). Warning log na serveru i klientovi když TLS vypnuté. C4 (OPK file permissions) bylo již opraveno v C3+H1+M13 batchi. +- **Online dot fix + sorting** — Fixed timing issue where `online_users` signal processed before conv list populated. `_rebuild_conv_list()` sorts: favorites first → online DMs → rest alphabetically. +- **Favorites system** — Right-click context menu on conversation list → Add/Remove from favorites. Star indicator (★). Persisted to `favorites.json` in user key directory. +- **Group renaming** — Full stack: `db.update_conversation_name()`, `handle_rename_conversation` (server, creator-only, max 100 chars), `rename_conversation()` (chat_core), "Rename" button in group info dialog (GUI), `conversation_renamed` push notification to all members. +- **M3+M4+M9 Security hardening** — M3: PBKDF2 600k iterations (`_encrypt_private_key`/`_decrypt_private_key` s ECP1 formátem, backward compat pro PEM). M4: SPK rotace každých 7 dní, `spk_created_at` v `get_prekey_count`, grace period s `prev_spk_private.bin`, fallback v `_process_x3dh_header`/`_decrypt_dm`. M9: `_snapshot()`/`_restore()` v `DoubleRatchet.decrypt()`, snapshot/restore v `SenderKeyState.decrypt()`. +- **Phantom invitation fix** — Phantom users now receive group invitations. `handle_create_conversation` and `handle_add_member` create invitations for phantoms (no push notification). `handle_register_confirm` upgrades phantom in-place via `db.upgrade_phantom_user()` (preserves user_id + FK references). `handle_add_member` creates phantom for unregistered emails (same as `create_conversation`). After registration, user sees pending invitations on login. +- **Message Search** — Client-side search through decrypted message cache. `ChatClient.search_messages()` searches local cache. GUI: collapsible search bar (Ctrl+F) with prev/next navigation, match count, yellow/orange highlighting in message HTML. Escape closes search. Search button in chat header. +- **Session Recovery** — `session_reset` protocol message + `handle_session_reset` server handler. `ChatClient.reset_session()` deletes local session + notifies peer. Peer handles `session_reset` notification by deleting their session. Next message auto-creates new session via X3DH. GUI: "Reset session with sender" context menu on undecryptable messages, status bar notification on incoming reset. +- **L8 Phantom user DB inflation fix** — `_valid_email()` helper validates email format before phantom creation in `handle_find_conversation`, `handle_create_conversation`, `handle_add_member`. `db.cleanup_stale_phantoms(30)` deletes phantom users older than 30 days with no active conversations with real users. Runs in `_periodic_cleanup` every 10 minutes, refreshes in-memory `phantom_user_ids` cache. +- **M6 TOCTOU race fix** — `db.remove_conversation_member_atomic()` returns bool (True if row existed). Used in `handle_remove_member` (checks return value, returns error if already removed) and `handle_leave_group`. Defense-in-depth: pre-checks remain for user-friendly errors, atomic operation prevents double-removal. + +### 🐛 Známé bugy a problémy +- **Sender Key Redistribution (High Priority):** New group member can't decrypt old messages. On `add_member`, existing members should re-create and redistribute sender keys. +- **Database Connection Pooling:** Every `db.*` call creates new MySQL connection. Should use pooling for production. +- **Group delete confirmation message is generic** — could say "Delete group and remove all members?" for groups vs "Delete conversation?" for DMs. + +### ⏭️ Další kroky (TODO) + +#### Bezpečnostní opravy (priorita dle auditu) +1. **C6 (CRITICAL): Path traversal přes file_id** — `handle_upload_image_start` vytváří soubor `UPLOAD_DIR / f"{file_id}.tmp"` bez validace. Útočník může `../../...` a zapisovat/mazat mimo UPLOAD_DIR. Řešení: validovat UUID formát, ověřit `path.resolve().is_relative_to(UPLOAD_DIR.resolve())`. +2. ~~**H12 (HIGH): OPK race condition v db.get_key_bundle()**~~ ✅ OPRAVENO (součást multi-device — SELECT FOR UPDATE v consume_one_time_prekey + get_key_bundle) +3. **H3+H13: User enumeration** — `get_user_info` dostupné bez auth, vrací identity_key pro libovolný email. `register_start`/`login_start` vrací jednoznačné chyby. Řešení: auth pro `get_user_info`, generické odpovědi pro register/login. +4. ~~**H2+H14: TLS hardening**~~ ✅ OPRAVENO — `TLS_INSECURE` a `TLS_AUTOGEN` vyžadují `ENVIRONMENT=dev`. Warning log při vypnutém TLS. +5. ~~**C1+C2+C5**~~ ✅ OPRAVENO — DoS vektory (LimitOverrunError → ValueError, MAX_SENDER_KEY_SKIP 256, upload completeness check) +6. **C3+C4+H1** — Šifrování dat na disku (message cache, sessions, OPK permissions, `chmod 0o700` pro adresáře) +7. **H5+H6** — Error handling v protokolu (base64, JSON) +8. **H7** — Path traversal v avatar souborech (`resolved_path.is_relative_to()`) +9. ~~**M11 (MEDIUM): Pairing poll DoS**~~ ✅ OPRAVENO — poll_token binding (secrets.token_hex(16) + secrets.compare_digest) +10. ~~**M12: Upload end bez validace velikosti**~~ ✅ OPRAVENO (součást C5 fixu — `handle_upload_image_end` validuje `received_bytes == file_size`) +11. ~~**L8: Phantom user DB inflation**~~ ✅ OPRAVENO — email validace + periodic cleanup stale phantoms (30 dní) +12. **Version negotiation** — `VERSION = "0.8"` v protocol.py, klient posílá `client_version` při loginu, server loguje a vrací `server_version` + +#### Před nasazením do produkce (checklist) +1. **TLS certifikáty** — Získat certifikát (Let's Encrypt / vlastní CA). Nastavit `TLS_ENABLED=true`, `TLS_CERT_FILE`, `TLS_KEY_FILE` v `.env`. Ověřit že `TLS_INSECURE` a `TLS_AUTOGEN` NEJSOU nastavené (vyžadují `ENVIRONMENT=dev`). Na klientovi nastavit `TLS_ENABLED=true` a případně `TLS_CA_FILE` pokud vlastní CA. +2. **Email validace** — Zapnout `_valid_email()` kontrolu v `handle_find_conversation`, `handle_create_conversation`, `handle_add_member` (kód existuje v server.py, volání zakomentována). Teď vypnuto protože dev prostředí používá emaily bez @. +3. **MySQL TLS** — Přidat SSL parametry do `db.get_connection()` (`ssl_ca`, `ssl_cert`, `ssl_key`) pokud DB běží na jiném stroji. +4. **Connection pooling** — Nahradit `get_connection()` za `mysql.connector.pooling.MySQLConnectionPool(pool_size=10)`. +5. **SMTP** — Nastavit reálný SMTP server pro registrační kódy (`SMTP_HOST`, `SMTP_PORT`, `SMTP_USER`, `SMTP_PASS`, `SMTP_FROM`). +6. **UPLOAD_DIR** — Ověřit že `UPLOAD_DIR` je na persistentním disku s dostatkem místa, správnými právy (0o700). +7. **Rate limity** — Přezkoumat limity pro produkční zátěž (registrace 3/min, login 10/min, send_message 20/min, max 10 spojení/IP). +8. **Packaging** — Zabalit klienta (pyinstaller / cx_Freeze) pro distribuci. Po zabalení zvážit auto-update mechanismus a `get_version` endpoint. +9. **Penetrační testy** — Provést před ostrým nasazením (viz sekce níže). +10. **Backup** — Nastavit pravidelný backup MySQL databáze + `UPLOAD_DIR`. + +#### Penetrační testy +- Naplánovat a provést manuální penetrační testy zaměřené na: + - Path traversal (file_id, avatar_file) + - DoS vektory (readuntil, sender key fast-forward, upload flooding) + - Race conditions (OPK reuse, membership TOCTOU) + - User enumeration (register, login, get_user_info) + - TLS downgrade / MITM bez TLS + - Pairing session hijacking + - Memory exhaustion (rate_limits, phantom users, message_ids) +- Vytvořit testovací skripty pro automatizované security testy +- Zdokumentovat výsledky a opravit nalezené problémy + +#### Ke zvážení +- **Auto-update klientů** — distribuce aktualizovaných souborů klientům před login/registrací. Řešit až po kompilaci/packagingu (pyinstaller apod.). Mechanismus: server verze check → klient stáhne nové soubory → restart. +- **Server version check endpoint** — po packagingu mít jednoduchý endpoint (např. `get_version`), který vrací min/aktuální podporovanou verzi klienta + URL/metadata pro update; klient může před loginem ověřit kompatibilitu a nabídnout update. Vhodné i pro postupné vypínání starých klientů. + +#### Funkční vylepšení +1. **Sender Key Redistribution** — on add_member, redistribute sender keys to all members including new one +2. ~~**Device Linking fix**~~ ✅ — replaced with true multi-device (per-device sessions, simplified pairing) +3. ~~**SPK Rotation**~~ ✅ — periodic rotation with grace period (implemented in M4 fix) +4. **Typing Indicators** — `typing_start`/`typing_stop` protocol + GUI indicator +5. ~~**CLI support**~~ ✅ — profiles, file sharing, invitations, leave/rename/delete, search, devices in `client.py` +6. ~~**Message search**~~ ✅ — client-side search through decrypted cache, Ctrl+F toggle, highlight + navigation +7. ~~**Session Recovery**~~ ✅ — `session_reset` protocol, auto-recreate via X3DH on next message +8. **Connection Pooling** — `mysql.connector.pooling` for production +9. ~~**Version negotiation**~~ ✅ — `VERSION = "0.8"` in protocol.py, client sends `client_version` at login, server logs it and returns `server_version` + +## Bezpečnostní audit (Security Audit) + +Kompletní audit provedený přes všechny soubory projektu. Nálezy seřazené podle závažnosti. + +### 🔴 CRITICAL — Okamžitě řešit před nasazením + +#### ~~C1. readuntil() bez limitu → memory exhaustion (protocol.py:62)~~ ✅ OPRAVENO +`ProtocolReader.read_message()` volá `readuntil(b"\n")`, které načte CELOU zprávu do paměti PŘED kontrolou velikosti. Útočník pošle gigabyty dat bez newline → server spadne na out-of-memory. +```python +line = await self._reader.readuntil(b"\n") # buffers everything first! +if len(line) > MAX_MESSAGE_BYTES: # too late +``` +**Řešení:** Implementovat framing s hlavičkou obsahující velikost zprávy, nebo použít `readuntil()` s `limit` parametrem (asyncio StreamReader nemá nativně — nutno obalit vlastním čtením po částech). + +#### ~~C2. SenderKeyState — neomezený fast-forward DoS (crypto_utils.py:642-645)~~ ✅ OPRAVENO +Při dešifrování skupinové zprávy s libovolně vysokým `n` se smyčka `while self.n <= n` provede milionkrát — derivuje milion klíčů, spotřebuje stovky MB RAM. +```python +while self.n <= n: + self.chain_key, mk = kdf_ck(self.chain_key) + self._known_keys[self.n] = mk # unbounded dict growth + self.n += 1 +``` +**Řešení:** Přidat `MAX_FORWARD_SKIP` limit (např. 1000) — stejně jako Double Ratchet má `MAX_SKIP=256`. + +#### C3. Dešifrované zprávy uložené jako plaintext na disku (chat_core.py:222-239) +Message cache v `~/.encrypted_chat/{email}/message_cache/{conv_id}.json` obsahuje plný obsah dešifrovaných zpráv v nešifrovaném JSONu. Bez nastavení `chmod 0o600`. Kdokoliv s přístupem k disku přečte kompletní historii. +**Řešení:** Šifrovat cache klíčem odvozeným z identity key + nastavit `chmod 0o600` na soubory. + +#### ~~C4. OPK private keys bez file permissions (chat_core.py:153-156)~~ ✅ OPRAVENO +OPK privátní klíče se ukládají bez `os.chmod(0o600)`. RSA klíče (řádek 87) a identity key (řádek 121) mají `chmod` — OPK ne. Na sdílených systémech může jiný uživatel přečíst ephemeral klíče. +**Opraveno:** Součást C3+H1+M13 fixu — `_save_opk_private()` nyní volá `os.chmod(path, 0o600)` + `os.chmod(dir, 0o700)`. + +#### ~~C5. Chunked upload nevaliduje celkovou velikost (server.py:1138-1142)~~ ✅ OPRAVENO +`handle_upload_image_chunk` akumuluje `received_bytes` ale nekontroluje limit. Útočník deklaruje `file_size=5MB`, pak posílá chunky donekonečna → disk exhaustion. +```python +upload["received_bytes"] += len(raw) # no check against file_size! +``` +**Řešení:** Přidat `if upload["received_bytes"] > upload["file_size"]: reject`. + +### 🟠 HIGH — Řešit před production nasazením + +#### H1. Session + sender key soubory nešifrované na disku (chat_core.py:176-215) +Double Ratchet session state (DH privátní klíče, root key, chain keys) a sender key state se ukládají jako plaintext hex JSON v `sessions/` a `sender_keys/`. Bez šifrování, bez `chmod 0o600`. Kompromitace disku = dešifrování celé historie. +**Řešení:** Šifrovat state klíčem z identity key, nastavit `chmod 0o600`. + +#### ~~H2. TLS vypnuté ve výchozím stavu (chat_core.py:274-291, server.py)~~ ✅ OPRAVENO (hardening) +`TLS_ENABLED` je defaultně `false`. Bez TLS jdou po síti RSA challenge-response, session tokeny a metadata v plaintextu. `TLS_INSECURE=true` vypíná certificate verification → MITM. +**Opraveno:** `TLS_INSECURE` a `TLS_AUTOGEN` nyní vyžadují `ENVIRONMENT=dev` — v produkci RuntimeError. Warning log při vypnutém TLS na serveru i klientovi. TLS_ENABLED zůstává default false (uživatel nemá certifikát), ale po nasazení Let's Encrypt stačí flip na true. + +#### H3. User enumeration přes registraci (server.py:182-189) +Registrace vrací "Email already in use" pro existující uživatele vs. tiché vytvoření phantoma pro neexistující. Útočník může enumerovat platné emaily. +**Řešení:** Vrátit generickou odpověď "Check your email for verification code" i když email existuje. + +#### ~~H4. Race conditions v in-memory strukturách (server.py: multiple)~~ ✅ OPRAVENO +`connected_clients` dict, `phantom_user_ids` set, `pairing_sessions` dict — čteny a zapisovány z více concurrent koroutin bez synchronizace. Asyncio je single-threaded, ale yieldy uvnitř handlerů (await) mohou způsobit nekonzistentní stav. +**Opraveno:** 4 asyncio.Lock guards: `_clients_lock` (connected_clients, phantom_user_ids), `_conn_lock` (connection_counts, current_connections, rate_limits), `_pairing_lock` (pairing_sessions, pending_registrations), `_uploads_lock` (pending_uploads). Helper funkce `_notify_users()` / `_notify_users_individual()` — snapshot under lock, send outside. Rate limit memory cleanup v periodic task. Žádný handler nedrží dva locky současně → deadlock impossible. + +#### H5. base64 decode bez error handling (protocol.py:14-16, server.py + chat_core.py) +`decode_binary()` volá `base64.b64decode()` bez try-except. Nevalidní base64 od klienta → unhandled `binascii.Error` → handler crash. Mnoho callsites v server.py (řádky 357, 378, 783) nemá catch. +**Řešení:** Obalit `decode_binary()` try-except, nebo validovat base64 vstup před dekódováním. + +#### H6. JSON parsing bez exception handling (protocol.py:48-50) +`parse_message()` volá `json.loads()` bez try-catch. Malformovaný JSON = neošetřený `JSONDecodeError`. Server handler catch (řádek 1399) to odchytí, ale není to explicitní. +**Řešení:** Obalit `json.loads()` v `parse_message()` try-except s explicitní chybovou zprávou. + +#### H7. Path traversal v avatar souborech (server.py:1265, 1318) +`avatar_file` ze serveru (z DB) se přímo joinuje s `UPLOAD_DIR / "avatars"` bez validace. Pokud DB obsahuje `../../etc/passwd`, server přečte libovolný soubor. +**Řešení:** Přidat `resolved_path.resolve().is_relative_to(UPLOAD_DIR)` check. + +#### H8. Heslo v paměti jako Python string (chat_core.py, gui_client.py) +Python stringy jsou immutable — nelze je bezpečně vymazat z paměti. Heslo zůstává v paměti dokud garbage collector neuklidí. Memory dump = plaintext heslo. +**Řešení:** Použít `bytearray` (mutable), po použití přepsat nulami: `pwd[:] = b'\x00' * len(pwd)`. + +#### H9. Self-encryption key je statický a deterministický (chat_core.py:904+, crypto_utils.py:224-233) +`derive_self_encryption_key(identity_private)` produkuje vždy stejný klíč. Kompromitace identity klíče = dešifrování všech vlastních kopií zpráv navždy. Žádná forward secrecy pro self-copies. +**Poznámka:** Toto je by-design (nutné pro cross-device čtení), ale je to architektonické omezení. + +#### H10. Malicious image data → QImage crash (gui_client.py) +`QImage.fromData(data)` zpracovává nevalidované binární data. Speciálně vytvořený obrázek může způsobit crash, memory exhaustion, nebo v extrémním případě RCE přes Qt image codec. +**Řešení:** Validovat velikost dat před parsováním, limit na max rozlišení. + +#### H11. Filename z serveru v save dialogu (gui_client.py:2389, 2460) +Server-controlled `filename` se předává jako default do `QFileDialog.getSaveFileName()`. Pokud server pošle `"../../../.bashrc"`, dialog to navrhne. +**Řešení:** Sanitizovat filename — odstranit `../`, `\`, absolutní cesty. Použít jen `os.path.basename()`. + +### 🟡 MEDIUM — Zvážit pro hardening + +#### M1. Inconsistentní Ed25519 serializace (crypto_utils.py:99-102) +Bez hesla: 32 raw bytes. S heslem: PEM PKCS8 (~302 bytes). Dva různé formáty mohou způsobit problém při migraci nebo obnově klíčů. +**Poznámka:** M3 fix částečně řeší — s heslem je nyní ECP1 formát (ne PEM), ale `load_ed25519_private()` stále detekuje 3 formáty (ECP1, PEM, raw). Legacy PEM soubory se automaticky migrují při dalším uložení. + +#### ~~M2. Prázdný salt v SenderKeyState HKDF (crypto_utils.py:610)~~ ✅ OPRAVENO +`hkdf_derive(sender_key, salt=b"", ...)` — RFC 5869 doporučuje nenulový salt. X3DH správně používá `b"\x00" * 32`. +**Opraveno:** Změněno `salt=b""` → `salt=b"\x00" * 32` aby odpovídalo X3DH konvenci. + +#### ~~M3. PBKDF2 iterace pod doporučeným minimem (crypto_utils.py)~~ ✅ OPRAVENO +`BestAvailableEncryption` používá ~100k iterací PBKDF2. OWASP 2023 doporučuje 480k+. +**Opraveno:** Nahrazeno vlastním `_encrypt_private_key`/`_decrypt_private_key` s PBKDF2-HMAC-SHA256 (600k iterací) + AES-256-GCM. ECP1 formát (magic prefix) s backward compat pro staré PEM soubory. Aplikováno na RSA (`serialize_private_key`/`load_private_key`) i Ed25519 (`serialize_ed25519_private`/`load_ed25519_private`). + +#### ~~M4. SPK bez replay protection a bez rotace (server.py:360-368)~~ ✅ OPRAVENO +Stejný SPK lze nahrát opakovaně. Žádný nonce/timestamp v podpisu. SPK se nikdy nerotuje → kompromitovaný SPK = trvalé dešifrování nových sessions. +**Opraveno:** SPK rotace každých 7 dní (`SPK_ROTATION_DAYS`). Server vrací `spk_created_at` v `handle_get_prekey_count`. `_ensure_prekeys()` kontroluje stáří SPK a rotuje pokud >= 7 dní. Předchozí SPK uložen jako grace period (`prev_spk_private.bin`/`prev_spk_id.txt`) pro in-flight X3DH. `_process_x3dh_header` přijímá `spk_override`, `_decrypt_dm` retry s předchozím SPK při selhání dešifrování. + +#### ~~M5. Rate limit unbounded memory (server.py:73-83)~~ ✅ OPRAVENO (součást H4 fixu) +Staré záznamy se nikdy nečistí pokud klíč přestane být aktivní → útočník vytvoří miliony klíčů → memory leak. +**Opraveno:** `_cleanup_rate_limits()` v periodic cleanup (každých 10 min) maže stale entries z `rate_limits` i `connection_counts`. + +#### ~~M6. TOCTOU race v membership checks (db.py)~~ ✅ OPRAVENO +`is_conversation_member()` → `remove_conversation_member()` — mezi kontrolou a akcí může jiný klient stav změnit. +**Opraveno:** `db.remove_conversation_member_atomic()` vrací bool (True pokud řádek existoval). Použito v `handle_remove_member` a `handle_leave_group`. + +#### M7. MySQL spojení bez TLS (db.py:20-28) +`get_connection()` nepředává SSL parametry. Na vzdáleném serveru jdou credentials v plaintextu. + +#### ~~M8. Chybějící validace UUID formátu (server.py: throughout)~~ ✅ OPRAVENO +`conv_id`, `user_id` — kontrola jen na neprázdnost, ne na formát UUID v4. +**Opraveno:** `_valid_file_id()` přejmenováno na `_valid_uuid()`. UUID validace přidána do všech handlerů přijímajících klientem poskytnuté `conv_id`, `user_id`, `message_id`, `device_id`. + +#### ~~M9. Ratchet state corruption recovery (chat_core.py:1088-1104)~~ ✅ OPRAVENO +Pokud `decrypt()` změní chain keys ale selže na AAD verification, backup/restore mechanismus funguje, ale pokud backup selže (out-of-memory), stav zůstane corrupted. +**Opraveno:** `DoubleRatchet.decrypt()` nyní snapshotuje stav přes `_snapshot()`/`_restore()` a rollbackuje při jakékoliv výjimce (včetně skipped message key restore). `SenderKeyState.decrypt()` stejně snapshotuje `chain_key`, `n`, `_known_keys` před fast-forward a rollbackuje při selhání. + +#### ~~M10. Chybějící validace velikosti message_ids listu (db.py:641-646)~~ ✅ OPRAVENO +Klient může poslat tisíce message_ids v jednom požadavku → pomalý SQL dotaz, DoS. +**Opraveno:** `handle_mark_read` nyní odmítá požadavky s více než 500 message_ids. + +### 🟢 LOW — Dobrá praxe, nízké riziko + +- **L1.** Hex string keys v skipped messages dict — timing side-channel po úspěšné autentikaci (crypto_utils.py:425) +- **L2.** RatchetHeader serializace redundantně konvertuje typy (crypto_utils.py:394-405) +- **L3.** `notif_label.setText()` bezpečné proti XSS (Qt neinterpretuje HTML v setText), ale křehké — přepnutí na `setHtml()` by to rozbilo (gui_client.py:1524, 2259) +- **L4.** SQL column interpolation v `update_user_profile` — whitelist chrání, ale pattern je nebezpečný při kopírování (db.py:818-822) +- **L5.** Chybějící TLS cipher suite hardening — Python defaulty jsou rozumné, ale ne explicitně nastavené (protocol.py) +- **L6.** Temporary pairing key není bezpečně vymazán z paměti (chat_core.py:581) +- **L7.** `_user_cache` ukládá public identity keys indefinitely — memory leak pro hodně kontaktů + +### Druhý bezpečnostní review (zaměření na návrh, DB, komunikaci, lokální tmp/cache) + +#### C6. Path traversal → libovolný zápis/smazání souborů přes file_id (server.py) +`handle_upload_image_start` vytváří soubor `UPLOAD_DIR / f"{file_id}.tmp"` bez validace `file_id`. Útočník může poslat `../../...` a zapisovat mimo UPLOAD_DIR. Následné rename, cleanup, `delete_message` a `delete_conversation` pak mohou mazat libovolné soubory. +**Řešení:** Striktně validovat file_id (UUID hex/kanonický formát), odmítnout cokoliv s `/`, `\`, `..`. Ověřit `path.resolve().is_relative_to(UPLOAD_DIR.resolve())`. Ideálně ukládat do podadresářů podle hash/UUID. + +#### ~~H12. OPK race condition — reuse one-time pre-keys (db.py)~~ ✅ OPRAVENO +V `db.get_key_bundle()` se OPK vybírá SELECT → DELETE bez transakčního zámku. Při souběhu může být stejný OPK vydán vícekrát → porušení bezpečnostních předpokladů X3DH. +**Opraveno:** `consume_one_time_prekey()` a `get_key_bundle()` nyní používají `SELECT ... FOR UPDATE` + DELETE v jedné transakci (součást multi-device implementace). + +#### H13. Neautentizované get_user_info + identity key exfiltrace (server.py) +`get_user_info` je dostupné bez přihlášení a vrací username, email a identity_key pro libovolný email/user_id. Umožňuje enumeraci uživatelů a sběr metadata/klíčů. +**Řešení:** Vyžadovat auth, nebo omezit na "kontakty v konverzaci". + +#### ~~H14. TLS_INSECURE umožňuje MITM i v produkci (chat_core.py, server.py)~~ ✅ OPRAVENO +`TLS_INSECURE=true` vypíná verifikaci certifikátu → útočník může podvrhnout key bundle. Přímo ohrožuje E2EE integritu. +**Opraveno:** `TLS_INSECURE` vyžaduje `ENVIRONMENT=dev`, jinak RuntimeError. Součást H2 fixu. + +#### ~~M11. Pairing poll DoS — neautentizovaný přístup k payload (server.py)~~ ✅ OPRAVENO +Kdokoli s 8-místným kódem může pollovat a "vyzvednout" payload (smazán po vyzvednutí). I když je šifrovaný, jde o snadný DoS (reálnému zařízení pairing selže). +**Opraveno:** `handle_pairing_start` generuje `poll_token` (secrets.token_hex(16)), vrací klientovi. `handle_pairing_poll` vyžaduje `poll_token` a porovnává přes `secrets.compare_digest()`. Klient ukládá token v `pairing_start()` a posílá v `pairing_wait()`. + +#### ~~M12. Upload end bez validace received_bytes == file_size (server.py)~~ ✅ OPRAVENO +`upload_image_end` neověřuje, že `received_bytes == file_size`. Může zůstat nedokončený/nevalidní soubor. +**Řešení:** Kontrola délky před `complete_image_upload`. + +#### M13. Klíčové adresáře bez chmod 700 (chat_core.py) +`get_key_dir` a podadresáře (`sessions`, `sender_keys`, `opk_private`) se vytvářejí bez explicitních práv; spoléhá se na umask. +**Řešení:** Po `mkdir` vždy `chmod 0o700` pro adresář, `0o600` pro soubory. + +#### ~~L8. Phantom users — DB inflation (server.py, db.py)~~ ✅ OPRAVENO +`find_conversation` vytváří phantom usery pro libovolné emaily. I s rate limit lze DB časem nafouknout. +**Opraveno:** `_valid_email()` validace před vytvořením phantomu. `db.cleanup_stale_phantoms(30)` v periodic cleanup — maže phantomy starší 30 dní bez aktivních konverzací s reálnými uživateli. + +### Bezpečnostní matice (souhrn) + +| Soubor | CRITICAL | HIGH | MEDIUM | LOW | +|--------|----------|------|--------|-----| +| `protocol.py` | 1 (C1) | 2 (H5, H6) | 0 | 1 (L5) | +| `crypto_utils.py` | 1 (C2) | 0 | 3 (M1-M3) | 2 (L1, L2) | +| `server.py` | 2 (C5, C6) | 3 (H3, H7, H13) | 4 (M4, M8, M11, M12) | 0 | +| `chat_core.py` | 2 (C3, C4) | 4 (H1, H2/H14, H8, H9) | 2 (M9, M13) | 1 (L6) | +| `gui_client.py` | 0 | 2 (H10, H11) | 0 | 2 (L3, L7) | +| `db.py` | 0 | 1 (H12) | 3 (M6, M7, M10) | 2 (L4, L8) | +| **Celkem** | **6** | **12** | **12** | **8** | +| **Opraveno** | 6 (~~C1~~, ~~C2~~, ~~C3~~, ~~C4~~, ~~C5~~, ~~C6~~) | 11 (~~H1~~, ~~H2~~, ~~H3~~, ~~H4~~, ~~H5~~, ~~H6~~, ~~H7~~, ~~H8~~, ~~H10~~, ~~H11~~, ~~H12~~, ~~H13~~, ~~H14~~) | 11 (~~M2~~, ~~M3~~, ~~M4~~, ~~M5~~, ~~M6~~, ~~M8~~, ~~M9~~, ~~M10~~, ~~M11~~, ~~M12~~, ~~M13~~) | 1 (~~L8~~) | +| **Zbývá** | **0** | **1** | **1** | **7** | + +### Doporučené pořadí oprav (aktualizováno) +1. ~~**C6**~~ ✅ — Path traversal přes file_id — DONE (UUID validace + is_relative_to) +2. ~~**C1 + C2 + C5**~~ ✅ — DoS vektory — DONE (LimitOverrunError → ValueError, MAX_SENDER_KEY_SKIP 256, upload completeness check) +3. ~~**H12**~~ ✅ — OPK race condition — DONE (SELECT FOR UPDATE, součást multi-device) +4. ~~**C3 + H1 + H7 + M13**~~ ✅ — Šifrování dat na disku + file/dir permissions + avatar path traversal — DONE +5. ~~**H2/H14**~~ ✅ — TLS hardening (TLS_INSECURE + TLS_AUTOGEN vyžadují ENVIRONMENT=dev, warning log) — DONE +6. ~~**H5 + H6**~~ ✅ — Error handling v protokolu (base64, JSON) — DONE +7. ~~**H3 + H13**~~ ✅ — User enumeration (generické odpovědi, auth pro get_user_info) — DONE +8. ~~**H4**~~ ✅ — Race conditions (asyncio.Lock) — DONE +9. ~~**H8 + H10 + H11**~~ ✅ — Paměť hesel, image parsing, filename sanitizace — DONE +10. ~~**M2 + M8 + M10 + M11**~~ ✅ — Hardening batch (HKDF salt, UUID validace, message_ids cap, pairing poll token) — DONE +11. ~~**M3, M4, M9**~~ ✅ — PBKDF2 600k iterations, SPK rotace 7 dní s grace period, ratchet state rollback — DONE +12. **M1, ~~M6~~, M7** — Remaining hardening (Ed25519 serialization, ~~TOCTOU~~ ✅, MySQL TLS) +13. **Penetrační testy** — manuální + automatizované security testy + +## Důležitá rozhodnutí a kontext +- **Invitations replace direct add for groups:** `handle_add_member` and `handle_create_conversation` (for groups) now create invitations instead of directly adding members. DMs still auto-add both users. This was a design decision to give users control over joining groups. +- **Group delete = total destruction:** When creator deletes a group, ALL members are removed and the conversation is fully deleted. This is different from "leave group" which only removes the leaving user. +- **DM delete is per-user:** Deleting a DM only removes the deleting user. The other user still sees the conversation until they also delete it. +- **Avatar caching in GUI is pixmap-based:** `_avatar_cache` and `_group_avatar_cache` store QPixmap objects, not raw bytes. The `_on_avatar_for_conv_list` and `_on_group_avatar_for_conv_list` signals convert bytes → QImage → QPixmap on receipt. +- **No context menu on conversation list anymore:** Delete was the only action. Now handled by header buttons. `conv_list.setContextMenuPolicy(DefaultContextMenu)`. + +## Environment Variables +See README.md for full list. Key: `SERVER_HOST`, `SERVER_PORT`, `MYSQL_*`, `TLS_*`, `SMTP_*`, `LOG_LEVEL`, `MAX_INPUT_CHARS`, `UPLOAD_DIR`, `MAX_IMAGE_BYTES`, `MAX_FILE_BYTES`, `MAX_MESSAGE_BYTES`. + +## Commands & Workflow + +- Start server: `python server.py` +- Start GUI client: `python gui_client.py` +- Start CLI client: `python client.py` +- Environment: `.env` file in project root (loaded by `dotenv`) +- Dependencies: `PyQt6`, `mysql-connector-python`, `cryptography`, `Pillow` (for image sharing), `python-dotenv` +- Check syntax: `python3 -m py_compile .py` +- All files on both server AND client side: `crypto_utils.py`, `protocol.py`, `chat_core.py`, `gui_client.py` (or `client.py`) diff --git a/zaloha/README.md b/zaloha/README.md new file mode 100644 index 0000000..8204c9d --- /dev/null +++ b/zaloha/README.md @@ -0,0 +1,314 @@ +# Encrypted Chat + +End-to-end encrypted chat s forward secrecy (X3DH + Double Ratchet, Signal Protocol). +Server ukládá a přeposílá šifrované bloby — nikdy nevidí plaintext. + +## Soubory + +### Server +| Soubor | Účel | +|--------|------| +| `server.py` | Asyncio TCP server, handler dispatch, rate limiting, notifikace | +| `db.py` | MySQL CRUD, jedna connection na volání | +| `schema.sql` | MySQL schéma (users, conversations, messages, ...) | + +### Klient +| Soubor | Účel | +|--------|------| +| `gui_client.py` | PyQt6 GUI | +| `client.py` | CLI klient | +| `chat_core.py` | Logika klienta — session management, šifrování, lokální klíče | + +### Sdílené (server + klient) +| Soubor | Účel | +|--------|------| +| `crypto_utils.py` | Ed25519, X25519, AES-256-GCM, HKDF, PBKDF2, X3DH, Double Ratchet (state rollback), Sender Keys (state rollback), ECP1 key encryption | +| `protocol.py` | Newline-delimited JSON protokol, base64 encoding | + +## Quick Start + +1. `pip install -r requirements.txt` +2. Spustit `schema.sql` v MySQL (kompletní clean start). Pro migraci existující DB: `migration_multi_device.sql`. +3. `python server.py` +4. Klient: `python client.py` (CLI) nebo `python gui_client.py` (GUI, PyQt6) + +## Jak funguje šifrování + +### Klíče na uživatele +| Klíč | Typ | Účel | +|------|-----|------| +| RSA-4096 | Asymetrický | Pouze login challenge-response. Šifrovaný PBKDF2 (600k iterací) + AES-256-GCM. | +| Identity Key (IK) | Ed25519 | Podpisy, konverze na X25519 pro X3DH. Šifrovaný PBKDF2 (600k iterací) + AES-256-GCM. | +| Signed Pre-Key (SPK) | X25519 | DH v X3DH, podepsaný IK. **Rotuje se každých 7 dní** s grace periodem pro in-flight X3DH. | +| One-Time Pre-Keys (OPK) | X25519 | Jednorázové, spotřebuje se při X3DH, automaticky doplňované (< 20 → +50) | + +### DM (1:1 zprávy) — X3DH + Double Ratchet +1. Alice chce napsat Bobovi poprvé → stáhne jeho key bundle (IK, SPK, OPK) ze serveru. +2. X3DH: 4 DH výpočty → shared secret. +3. Double Ratchet inicializován ze shared secret. +4. Každá zpráva: symmetric ratchet (HMAC chain) → message key → AES-256-GCM. +5. Každá odpověď: DH ratchet (nový X25519 keypair) → nový root key + chain key. +6. Per-recipient ciphertext — každý recipient má vlastní šifrovaný blob. +7. Při selhání dešifrování: automatický rollback stavu ratchetu (snapshot/restore). + +### Skupiny — Sender Keys +1. Každý člen má vlastní sender key chain pro skupinu. +2. Sender key se distribuuje ostatním členům přes pairwise Double Ratchet (jako DM). +3. Skupinové zprávy: symmetric ratchet na sender key → AES-256-GCM. +4. Jeden ciphertext pro celou skupinu (efektivní). + +### Lokální úložiště klíčů +``` +~/.encrypted_chat/{email}/ + private.pem # RSA (login) — ECP1 formát s heslem, PEM bez hesla + public.pem # RSA (login) + identity_private.bin # Ed25519 — ECP1 formát s heslem, 32B raw bez hesla + identity_public.bin # Ed25519 + device_id.txt # UUID tohoto zařízení + spk_private.bin # Aktuální signed prekey + spk_id.txt + prev_spk_private.bin # Předchozí SPK (grace period pro in-flight X3DH) + prev_spk_id.txt + opk_private/ # One-time prekeys + {opk_id}.bin + sessions/ # Double Ratchet stavy (šifrované AES-256-GCM) + {user_id}_{device_id}.bin + sender_keys/ # Vlastní sender keys pro skupiny + {conv_id}.bin + sender_keys_recv/ # Přijaté sender keys od ostatních + {conv_id}_{sender_id}_{device_id}.bin +``` + +## Bezpečnostní hardening + +### Šifrování privátních klíčů na disku (ECP1 formát) +RSA a Ed25519 privátní klíče šifrované heslem používají vlastní formát ECP1 (Encrypted Chat PBKDF v1): +- **PBKDF2-HMAC-SHA256** s 600 000 iteracemi (OWASP 2023 doporučuje 480k+) +- **AES-256-GCM** pro šifrování, magic bytes "ECP1" jako AAD +- **Formát:** `ECP1(4B) + salt(16B) + nonce(12B) + ciphertext+tag` +- **Zpětná kompatibilita:** Staré PEM soubory (z `BestAvailableEncryption`) se načtou automaticky a při dalším uložení se přešifrují do ECP1. + +### SPK rotace (7 dní) +Signed Pre-Key se rotuje periodicky: +- Po přihlášení `_ensure_prekeys()` zjistí stáří SPK ze serveru (`spk_created_at`) +- Pokud je SPK starší než 7 dní → vygeneruje nový, starý uloží jako grace period +- **Grace period:** `prev_spk_private.bin` — pokud příchozí X3DH selže s aktuálním SPK, zkusí předchozí +- Omezuje dopad kompromitace SPK — útočník může vytvářet nové sessions max 7 dní + +### Odolnost ratchetu (state rollback) +Double Ratchet i Sender Keys automaticky rollbackují stav při selhání dešifrování: +- Před modifikací chain keys/counters se vytvoří snapshot +- Pokud AES-GCM dešifrování selže (corrupted data, wrong key), stav se obnoví +- Session zůstane funkční i po zpracování poškozené zprávy + +## Registrace + +1. `register` → server pošle 6-místný kód na email (nebo vrátí přímo v dev módu bez SMTP). +2. `register_confirm` → potvrzení kódu. +3. Automaticky se vygenerují a uploadnou prekeys (1 SPK + 50 OPKs). +4. Login. + +## Multi-Device Support + +Pravý multi-device (Signal-like) — každé zařízení má nezávislé Double Ratchet sessions. +Při posílání DM se zpráva šifruje zvlášť pro každé zařízení příjemce. +Všechna zařízení uživatele sdílejí Ed25519 identity key (pro self-encryption kompatibilitu). + +### Architektura +- **Devices tabulka** — každé přihlášení registruje device (UUID), server mapuje writer→device +- **Per-device prekeys** — každé zařízení má vlastní SPK + OPKs, server vrací `device_bundles` pole +- **Per-device sessions** — sessions klíčované `"user_id:device_id"`, nezávislé Double Ratchet instance +- **Self-encryption** — odesílatel šifruje vlastní kopii statickým klíčem z identity key (čitelné všemi vlastními zařízeními) +- **Notifikace** — `device_entries` pole, klient vybere záznam odpovídající svému device_id + +### Device Pairing (zjednodušený) + +Nové zařízení získá RSA + Ed25519 identity klíče od existujícího zařízení. +Přenos šifrovaný RSA-OAEP + AES-GCM přes server (server nevidí klíče). +Nové zařízení si po přihlášení automaticky vygeneruje vlastní SPK + OPKs. + +1. Nové zařízení: `Link Device` → dostane 8-místný kód. +2. Existující zařízení: `Authorize Device` → zadá kód → odešle RSA + identity klíče. +3. Nové zařízení importuje klíče, přihlásí se, vygeneruje vlastní prekeys. + +### Migrace +- Existující DB: spustit `migration_multi_device.sql` (nebo `migration_multi_device_resume.sql` pro idempotentní re-run) +- Čistá DB: `schema.sql` již obsahuje všechny multi-device sloupce + +## Device Revocation (Key Rotation) + +Rotuje RSA login klíč. Odpojí ostatní sessions. Forward secrecy zajišťuje, že kompromitace +jednoho session klíče neodhalí historii — není potřeba re-encryption. + +## Konfigurace + +### Server + DB +- `SERVER_HOST` (default `127.0.0.1`), `SERVER_PORT` (default `9999`) +- `MYSQL_HOST`, `MYSQL_PORT`, `MYSQL_USER`, `MYSQL_PASSWORD`, `MYSQL_DATABASE` + +### TLS +- `TLS_ENABLED` — zapne TLS (default `false`) +- `TLS_REQUIRED` — vyžaduje TLS_ENABLED, jinak server odmítne start +- `TLS_CERT_FILE`, `TLS_KEY_FILE` — cesty k certifikátu a privátnímu klíči (PEM) +- `TLS_AUTOGEN` — auto-generuje self-signed cert (**jen s `ENVIRONMENT=dev`**) +- `TLS_CA_FILE` (klient) — vlastní CA certifikát pro ověření serveru +- `TLS_INSECURE` (klient) — vypne ověření certifikátu (**jen s `ENVIRONMENT=dev`**) +- `ENVIRONMENT` — `dev`/`development` povolí TLS_INSECURE a TLS_AUTOGEN + +#### Produkční nasazení s Let's Encrypt +```bash +# 1. Nainstalovat certbot +sudo apt install certbot + +# 2. Získat certifikát (port 80 musí být volný pro ověření) +sudo certbot certonly --standalone -d chat.example.com + +# 3. V .env nastavit: +TLS_ENABLED=true +TLS_CERT_FILE=/etc/letsencrypt/live/chat.example.com/fullchain.pem +TLS_KEY_FILE=/etc/letsencrypt/live/chat.example.com/privkey.pem + +# 4. Klient — stačí zapnout TLS (Let's Encrypt je v systémovém trust store): +TLS_ENABLED=true +``` +Certifikát funguje na jakémkoliv portu (9999, 443, ...) — je vázaný na doménu, ne port. Certbot automaticky obnovuje certifikát každých 90 dní. + +#### Dev/testování (self-signed) +```bash +ENVIRONMENT=dev +TLS_ENABLED=true +TLS_AUTOGEN=true # server auto-generuje self-signed cert +TLS_INSECURE=true # klient přeskočí ověření certifikátu +``` + +### SMTP +- `SMTP_HOST`, `SMTP_PORT`, `SMTP_USER`, `SMTP_PASS`, `SMTP_FROM` +- Bez SMTP = dev mód (kód se vrací přímo klientovi). + +### Obrázky +- `UPLOAD_DIR` (default `uploads`), `MAX_IMAGE_BYTES` (default 5 MB, `0` = bez limitu) + +### Limity +- `MAX_MESSAGE_BYTES` (default `65536`), `MAX_INPUT_CHARS` (GUI, default `2000`) +- Rate limity: register 3/min, login 10/min, send_message 20/min, pairing_poll 10/min +- Connection: 20 req/s per connection, max 10 per IP, 200 global +- Pairing TTL: 120s, max 5 failed poll pokusů + +### Logging +- `LOG_LEVEL` (default `INFO`) + +## Features + +- Registrace (2-step, SMTP), login (RSA challenge-response), key rotation +- **Multi-device** — per-device sessions (Signal-like), device pairing (RSA + identity key transfer), automatické prekey generování na novém zařízení +- DM s forward secrecy (X3DH + Double Ratchet) — per-device šifrování +- Skupiny se Sender Keys (distribuované přes pairwise ratchet) +- Skupinové pozvánky — přidání do skupiny vyžaduje souhlas (accept/decline) +- Odpovědi na zprávy (reply_to) +- Mazání zpráv (soft-delete pro všechny, real-time notifikace) +- Mazání konverzací (pravý klik → smaže pro uživatele, pokud nezbývají členové smaže celou konverzaci) +- Šifrované obrázky (AES-256-GCM, chunked upload, thumbnail v bublině) +- Šifrované soubory (PDF, ZIP, atd. až 50 MB, chunked upload) +- Read receipts (real-time, client-side resoluce) +- Prekey replenishment (automatické doplňování OPKs po loginu + SPK rotace každých 7 dní) +- Silné šifrování klíčů na disku (PBKDF2 600k iterací + AES-256-GCM, ECP1 formát) +- Odolný ratchet — automatický rollback stavu při selhání dešifrování +- TLS (volitelný, auto-gen self-signed) +- Real-time notifikace konverzací — nové konverzace, přidání/odebrání členů se zobrazí okamžitě bez re-loginu +- Connection state indicator — zelená/červená/oranžová tečka, automatický reconnect s exponential backoff +- Online/offline status — zelená tečka na avataru v seznamu konverzací + v group info +- User profily — telefon, lokace, avatar, nastavení viditelnosti (email, telefon, lokace) +- Phantom users — anti user-enumeration: konverzace s neregistrovaným emailem funguje normálně (odesílatel vidí své zprávy), zprávy pro phantom příjemce se neukládají, phantom se smaže při skutečné registraci +- Clickable links — HTTPS modré, HTTP oranžové s ikonou zámku + potvrzovací dialog + +### GUI (PyQt6) +- Dark theme (Catppuccin Mocha) +- Seznam konverzací s kulatými avatary a online indikátorem (zelená tečka) +- Unread count badge na konverzacích (číselný počet nepřečtených zpráv) +- Message bubliny s barevným left border, timestamp vedle jména +- Read receipts (checkmarks), group info dialog, add/remove member +- Context menu: reply, delete, view image, download file +- Attach button pro obrázky a soubory, thumbnail v bublině, full-size viewer + save +- Pagination ("Load older messages") +- Connection indicator (zelená=online, červená=offline, oranžová=reconnecting) +- Auto-reconnect s exponential backoff (1s → 2s → 4s → ... → max 30s) +- Tlačítko "My Profile" — editace vlastního profilu (telefon, lokace, avatar, viditelnost) +- User profil dialog — klik na info tlačítko v group info → read-only profil uživatele +- Avatar upload/download (JPEG/PNG, max 2 MB, kruhový výřez) +- Leave group (červené tlačítko v group info, přenos creatora) +- Pozvánky do skupin — seznam pending pozvánek nad konverzacemi, pravý klik → accept/decline +- Periodický refresh avatarů a pozvánek (každé 2 minuty) + +### CLI +- Základní funkcionalita (DM, skupiny, šifrování). Profily a soubory pouze přes GUI. + +## Závislosti + +- `cryptography` — Ed25519, X25519, AES-GCM, RSA, HKDF, PBKDF2 +- `mysql-connector-python` — MySQL +- `python-dotenv` — env vars +- `PyQt6` — GUI +- `Pillow` — resize/thumbnail obrázků + +## Known Issues + +- Sender Keys pro skupiny se nedistribuují znovu při přidání nového člena (nový člen neuvidí staré skupinové zprávy). + +## TODO + +### Security — Zbývající +- [ ] **H9: Self-encryption key** — statický/deterministický klíč (by-design pro cross-device, architektonické omezení) +- [ ] M1: Nekonzistentní Ed25519 serializace (částečně vyřešeno M3 — ECP1 formát, ale 3 legacy formáty) +- [ ] M6: TOCTOU race v membership checks +- [ ] M7: MySQL spojení bez TLS +- [ ] L1-L8: Low-priority hardening +- [ ] **Penetrační testy** — manuální + automatizované + +### Features — High Priority +- [ ] Redistribuce sender keys při přidání nového člena do skupiny +- [ ] Typing indicators + +### Features — Medium Priority +- [ ] Hledání zpráv v konverzacích +- [ ] Group admin roles (více adminů) +- [ ] Edit sent messages + +### Features — Low Priority +- [ ] Dark/light theme toggle +- [ ] Desktop notifications (system tray) +- [ ] Database connection pooling +- [ ] Image gallery view +- [ ] Systemd + Docker deployment + +### Hotovo — Security +- [x] **C1-C6: Všechny CRITICAL opraveny** — readuntil DoS, sender key fast-forward, OPK permissions, upload size check, path traversal (UUID validace + is_relative_to) +- [x] **H1-H8, H10-H14: Většina HIGH opravena** — lokální šifrování dat (AES-256-GCM), TLS hardening (INSECURE/AUTOGEN jen v dev), anti-enumeration, race conditions (asyncio.Lock), protokol error handling, avatar path traversal, hesla v paměti (bytearray+zero), image validace, filename sanitizace, OPK race condition (SELECT FOR UPDATE) +- [x] **M2-M5+M8-M13: Většina MEDIUM opravena** — HKDF salt, PBKDF2 600k iterací (ECP1 formát), SPK rotace 7 dní s grace periodem, rate limit cleanup, UUID validace, ratchet state rollback, message_ids cap, pairing poll token, upload check, chmod 0o700/0o600 + +### Hotovo — Features +- [x] **Multi-device support** — per-device sessions (Signal-like), device pairing, automatické prekey generování +- [x] Unread counts pro offline uživatele +- [x] Clickable HTTP links — HTTPS modré, HTTP oranžové s varováním +- [x] User profily (telefon, lokace, avatar, viditelnost) +- [x] Connection state indicator + auto-reconnect +- [x] Encrypted file sharing (až 50 MB) +- [x] Leave group + přenos creatora +- [x] Unread count badge +- [x] User avatars (upload/download, kruhový výřez) +- [x] Online/offline status (zelená tečka na avataru) +- [x] Mazání konverzací +- [x] Skupinové pozvánky (accept/decline) +- [x] Graceful server shutdown + +## Bezpečnostní audit + +Dva bezpečnostní audity provedeny (kód review). Nalezeno 6 CRITICAL, 12 HIGH, 12 MEDIUM, 8 LOW nálezů. + +| Závažnost | Celkem | Opraveno | Zbývá | +|-----------|--------|----------|-------| +| CRITICAL | 6 | **6** | 0 | +| HIGH | 12 | **11** | 1 (H9 — by-design) | +| MEDIUM | 12 | **10** | 2 (M1 částečně, M6, M7) | +| LOW | 8 | 0 | 8 | + +Detaily viz `CLAUDE.md`. diff --git a/zaloha/chat_core.py b/zaloha/chat_core.py new file mode 100644 index 0000000..8084b55 --- /dev/null +++ b/zaloha/chat_core.py @@ -0,0 +1,2609 @@ +"""Shared network layer and ChatClient class for CLI and GUI clients. + +Uses X3DH + Double Ratchet for message encryption, Sender Keys for groups. +RSA retained for login challenge-response only. +""" + +import asyncio +import json +import logging +import os +import ssl +import uuid +from datetime import datetime, timezone +from pathlib import Path + +from dotenv import load_dotenv + +load_dotenv() + +from crypto_utils import ( + # RSA (login only) + generate_rsa_keypair, + serialize_private_key, + serialize_public_key, + load_private_key, + load_public_key, + rsa_sign, + # Ed25519 + generate_identity_keypair, + serialize_ed25519_private, + serialize_ed25519_private_raw, + serialize_ed25519_public, + load_ed25519_private, + load_ed25519_public, + ed25519_sign, + # X25519 + generate_x25519_keypair, + serialize_x25519_private, + serialize_x25519_public, + load_x25519_private, + load_x25519_public, + # X3DH + generate_signed_prekey, + generate_one_time_prekeys, + x3dh_initiate, + x3dh_respond, + # Double Ratchet + DoubleRatchet, + # Sender Keys + SenderKeyState, + # AES + aes_encrypt, + aes_decrypt, + # Self-encryption + derive_self_encryption_key, + # Local storage encryption + derive_local_storage_key, +) +from protocol import ( + VERSION, + ProtocolReader, + ProtocolWriter, + encode_binary, + decode_binary, + build_request, + MAX_MESSAGE_BYTES, + IMAGE_CHUNK_SIZE, +) + + +KEY_DIR = Path.home() / ".encrypted_chat" +OPK_REPLENISH_THRESHOLD = 20 +OPK_BATCH_SIZE = 50 +SPK_ROTATION_DAYS = 7 + + +def _encrypt_local(data: bytes, key: bytes) -> bytes: + """Encrypt data with AES-256-GCM for local storage. Format: nonce(12) + tag(16) + ciphertext.""" + _, nonce, ct, tag = aes_encrypt(data, key=key) + return nonce + tag + ct + + +def _decrypt_local(raw: bytes, key: bytes) -> bytes: + """Decrypt data encrypted by _encrypt_local.""" + nonce, tag, ct = raw[:12], raw[12:28], raw[28:] + return aes_decrypt(key, nonce, ct, tag) + + +def get_key_dir(email: str) -> Path: + d = KEY_DIR / email + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + return d + + +# --------------------------------------------------------------------------- +# RSA key storage (login only — unchanged interface) +# --------------------------------------------------------------------------- + +def save_keys(email: str, private_key, public_key, password: bytes | None = None): + d = get_key_dir(email) + (d / "private.pem").write_bytes(serialize_private_key(private_key, password=password)) + (d / "public.pem").write_bytes(serialize_public_key(public_key)) + os.chmod(d / "private.pem", 0o600) + + +def load_keys(email: str, password: bytes | None = None): + d = get_key_dir(email) + priv_path = d / "private.pem" + pub_path = d / "public.pem" + if not priv_path.exists(): + return None, None, "No local keys found." + pem = priv_path.read_bytes() + try: + private_key = load_private_key(pem, password=password) + except Exception: + try: + private_key = load_private_key(pem, password=None) + if password: + save_keys(email, private_key, load_public_key(pub_path.read_bytes()), password=password) + except Exception: + return None, None, "Invalid or missing password." + public_key = load_public_key(pub_path.read_bytes()) + return private_key, public_key, None + + +# --------------------------------------------------------------------------- +# Identity + prekey storage +# --------------------------------------------------------------------------- + +def _save_identity_keys(email: str, ed_priv, ed_pub, password: bytes | None = None): + d = get_key_dir(email) + if password: + (d / "identity_private.bin").write_bytes(serialize_ed25519_private(ed_priv, password=password)) + else: + (d / "identity_private.bin").write_bytes(serialize_ed25519_private_raw(ed_priv)) + (d / "identity_public.bin").write_bytes(serialize_ed25519_public(ed_pub)) + os.chmod(d / "identity_private.bin", 0o600) + + +def _load_identity_keys(email: str, password: bytes | None = None): + d = get_key_dir(email) + priv_path = d / "identity_private.bin" + pub_path = d / "identity_public.bin" + if not priv_path.exists(): + return None, None + priv = load_ed25519_private(priv_path.read_bytes(), password=password) + pub = load_ed25519_public(pub_path.read_bytes()) + return priv, pub + + +def _save_spk(email: str, spk_priv, spk_id: str): + d = get_key_dir(email) + (d / "spk_private.bin").write_bytes(serialize_x25519_private(spk_priv)) + (d / "spk_id.txt").write_text(spk_id) + os.chmod(d / "spk_private.bin", 0o600) + + +def _load_spk(email: str): + d = get_key_dir(email) + priv_path = d / "spk_private.bin" + id_path = d / "spk_id.txt" + if not priv_path.exists(): + return None, None + priv = load_x25519_private(priv_path.read_bytes()) + spk_id = id_path.read_text().strip() if id_path.exists() else "" + return priv, spk_id + + +def _save_prev_spk(email: str, spk_priv, spk_id: str): + """Save previous SPK for grace period (in-flight X3DH may reference old SPK).""" + d = get_key_dir(email) + (d / "prev_spk_private.bin").write_bytes(serialize_x25519_private(spk_priv)) + (d / "prev_spk_id.txt").write_text(spk_id) + os.chmod(d / "prev_spk_private.bin", 0o600) + + +def _load_prev_spk(email: str): + """Load previous SPK (grace period). Returns (private_key, spk_id) or (None, None).""" + d = get_key_dir(email) + priv_path = d / "prev_spk_private.bin" + id_path = d / "prev_spk_id.txt" + if not priv_path.exists(): + return None, None + priv = load_x25519_private(priv_path.read_bytes()) + spk_id = id_path.read_text().strip() if id_path.exists() else "" + return priv, spk_id + + +def _save_opk_private(email: str, opk_id: str, opk_priv): + d = get_key_dir(email) / "opk_private" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + (d / f"{opk_id}.bin").write_bytes(serialize_x25519_private(opk_priv)) + os.chmod(d / f"{opk_id}.bin", 0o600) + + +def _load_opk_private(email: str, opk_id: str): + d = get_key_dir(email) / "opk_private" + p = d / f"{opk_id}.bin" + if not p.exists(): + return None + return load_x25519_private(p.read_bytes()) + + +def _delete_opk_private(email: str, opk_id: str): + d = get_key_dir(email) / "opk_private" + p = d / f"{opk_id}.bin" + try: + p.unlink(missing_ok=True) + except Exception: + pass + + +def _save_device_id(email: str, device_id: str): + d = get_key_dir(email) + p = d / "device_id.txt" + p.write_text(device_id) + os.chmod(p, 0o600) + + +def _load_device_id(email: str) -> str | None: + d = get_key_dir(email) + p = d / "device_id.txt" + if not p.exists(): + return None + return p.read_text().strip() or None + + +def _save_session(email: str, peer_user_id: str, ratchet: DoubleRatchet, + local_key: bytes | None = None, peer_device_id: str | None = None): + d = get_key_dir(email) / "sessions" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + if peer_device_id: + filename = f"{peer_user_id}_{peer_device_id}.bin" + else: + filename = f"{peer_user_id}.bin" + p = d / filename + data = ratchet.export_state() + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_session(email: str, peer_user_id: str, + local_key: bytes | None = None, + peer_device_id: str | None = None) -> DoubleRatchet | None: + d = get_key_dir(email) / "sessions" + if peer_device_id: + p = d / f"{peer_user_id}_{peer_device_id}.bin" + if not p.exists(): + # Fallback: try old format (no device_id) and migrate + p_old = d / f"{peer_user_id}.bin" + if p_old.exists(): + ratchet = _load_session_file(p_old, local_key) + if ratchet: + _save_session(email, peer_user_id, ratchet, local_key, + peer_device_id=peer_device_id) + try: + p_old.unlink() + except Exception: + pass + return ratchet + return None + else: + p = d / f"{peer_user_id}.bin" + if not p.exists(): + return None + return _load_session_file(p, local_key) + + +def _load_session_file(p: Path, local_key: bytes | None = None) -> DoubleRatchet | None: + """Load a session from a specific file path.""" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + data = _decrypt_local(raw, local_key) + except Exception: + # Migration: try loading as plaintext (old unencrypted format) + try: + ratchet = DoubleRatchet.import_state(raw) + return ratchet + except Exception: + return None + return DoubleRatchet.import_state(data) + return DoubleRatchet.import_state(raw) + + +def _delete_session_file(email: str, peer_user_id: str, peer_device_id: str | None = None): + """Delete a session file from disk (for session reset).""" + d = get_key_dir(email) / "sessions" + if peer_device_id: + p = d / f"{peer_user_id}_{peer_device_id}.bin" + else: + p = d / f"{peer_user_id}.bin" + try: + p.unlink(missing_ok=True) + except Exception: + pass + + +def _save_sender_key_state(email: str, conv_id: str, state: SenderKeyState, + local_key: bytes | None = None): + d = get_key_dir(email) / "sender_keys" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + p = d / f"{conv_id}.bin" + data = state.export_state() + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_sender_key_state(email: str, conv_id: str, + local_key: bytes | None = None) -> SenderKeyState | None: + d = get_key_dir(email) / "sender_keys" + p = d / f"{conv_id}.bin" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + data = _decrypt_local(raw, local_key) + except Exception: + try: + sk = SenderKeyState.import_state(raw) + _save_sender_key_state(email, conv_id, sk, local_key) + return sk + except Exception: + return None + return SenderKeyState.import_state(data) + return SenderKeyState.import_state(raw) + + +def _save_recv_sender_key(email: str, conv_id: str, sender_id: str, state: SenderKeyState, + local_key: bytes | None = None, + sender_device_id: str | None = None): + d = get_key_dir(email) / "sender_keys_recv" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + if sender_device_id: + filename = f"{conv_id}_{sender_id}_{sender_device_id}.bin" + else: + filename = f"{conv_id}_{sender_id}.bin" + p = d / filename + data = state.export_state() + if local_key: + data = _encrypt_local(data, local_key) + p.write_bytes(data) + os.chmod(p, 0o600) + + +def _load_recv_sender_key(email: str, conv_id: str, sender_id: str, + local_key: bytes | None = None, + sender_device_id: str | None = None) -> SenderKeyState | None: + d = get_key_dir(email) / "sender_keys_recv" + if sender_device_id: + p = d / f"{conv_id}_{sender_id}_{sender_device_id}.bin" + if not p.exists(): + # Fallback: try old format and migrate + p_old = d / f"{conv_id}_{sender_id}.bin" + if p_old.exists(): + sk = _load_recv_sender_key_file(p_old, local_key) + if sk: + _save_recv_sender_key(email, conv_id, sender_id, sk, local_key, + sender_device_id=sender_device_id) + try: + p_old.unlink() + except Exception: + pass + return sk + return None + else: + p = d / f"{conv_id}_{sender_id}.bin" + if not p.exists(): + return None + return _load_recv_sender_key_file(p, local_key) + + +def _load_recv_sender_key_file(p: Path, local_key: bytes | None = None) -> SenderKeyState | None: + """Load a recv sender key from a specific file path.""" + if not p.exists(): + return None + raw = p.read_bytes() + if local_key: + try: + data = _decrypt_local(raw, local_key) + except Exception: + try: + sk = SenderKeyState.import_state(raw) + return sk + except Exception: + return None + return SenderKeyState.import_state(data) + return SenderKeyState.import_state(raw) + + +# --------------------------------------------------------------------------- +# Local decrypted message cache (Double Ratchet keys are one-time use) +# --------------------------------------------------------------------------- + +def _load_message_cache(email: str, conv_id: str, cache_key: bytes | None = None) -> dict: + d = get_key_dir(email) / "message_cache" + p_bin = d / f"{conv_id}.bin" + p_json = d / f"{conv_id}.json" + + # Migration: if old plaintext .json exists but encrypted .bin doesn't + if p_json.exists() and not p_bin.exists(): + try: + cache = json.loads(p_json.read_text("utf-8")) + if cache_key: + _save_message_cache_full(d, conv_id, cache, cache_key) + p_json.unlink(missing_ok=True) + return cache + except Exception: + return {} + + if not p_bin.exists(): + return {} + if not cache_key: + return {} + try: + raw = p_bin.read_bytes() + # Format: nonce (12) + tag (16) + ciphertext + nonce = raw[:12] + tag = raw[12:28] + ct = raw[28:] + plaintext = aes_decrypt(cache_key, nonce, ct, tag) + return json.loads(plaintext.decode("utf-8")) + except Exception: + return {} + + +def _save_message_cache_full(d: Path, conv_id: str, cache: dict, cache_key: bytes): + """Write the full cache dict encrypted to disk.""" + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + p = d / f"{conv_id}.bin" + plaintext = json.dumps(cache, ensure_ascii=False).encode("utf-8") + _key, nonce, ct, tag = aes_encrypt(plaintext, key=cache_key) + p.write_bytes(nonce + tag + ct) + os.chmod(p, 0o600) + + +def _save_message_to_cache(email: str, conv_id: str, message_id: str, payload: dict, + cache_key: bytes | None = None): + d = get_key_dir(email) / "message_cache" + cache = _load_message_cache(email, conv_id, cache_key) + cache[message_id] = payload + if cache_key: + _save_message_cache_full(d, conv_id, cache, cache_key) + else: + # Fallback: plaintext (no identity key available yet) + d.mkdir(parents=True, exist_ok=True) + os.chmod(d, 0o700) + p = d / f"{conv_id}.json" + p.write_text(json.dumps(cache, ensure_ascii=False), "utf-8") + os.chmod(p, 0o600) + + +class ChatClient: + def __init__(self): + self.reader: ProtocolReader | None = None + self.writer: ProtocolWriter | None = None + self.raw_writer: asyncio.StreamWriter | None = None + self.session: dict | None = None + self.private_key = None # RSA private key (login only) + self.public_key = None # RSA public key (login only) + self.username: str = "" + self.email: str = "" + self._listener_task: asyncio.Task | None = None + self._response_queue: asyncio.Queue = asyncio.Queue() + self._notification_queue: asyncio.Queue = asyncio.Queue() + self._pending: dict[str, asyncio.Future] = {} + self._pairing_temp_private_key = None + self._reencrypt_progress_cb = None + self._logger = logging.getLogger("encrypted_chat.client") + + # Signal Protocol keys + self.identity_private = None # Ed25519PrivateKey + self.identity_public = None # Ed25519PublicKey + self.spk_private = None # X25519PrivateKey (current signed prekey) + self.spk_id: str = "" + self._prev_spk_private = None # Previous SPK for grace period (M4) + self._prev_spk_id: str = "" + self.opk_privates: dict[str, object] = {} # id -> X25519PrivateKey + self.sessions: dict[str, DoubleRatchet] = {} # "user_id:device_id" -> ratchet + self.sender_key_states: dict[str, SenderKeyState] = {} # conv_id -> own sender key + self.recv_sender_keys: dict[str, SenderKeyState] = {} # "conv_id:sender_id:device_id" -> their key + # Cache: user_id -> {identity_key (Ed25519PublicKey), username, email} + self._user_cache: dict[str, dict] = {} + self.connected: bool = False + self.login_rejected: bool = False + self._cache_key: bytes | None = None # AES key for encrypting message cache on disk + self._local_key: bytes | None = None # AES key for encrypting session/sender key files + # Multi-device support + self.device_id: str | None = None # This device's UUID + self._device_bundle_cache: dict[str, tuple[float, list[dict]]] = {} # user_id -> (ts, bundles) + + async def connect(self): + host = os.getenv("SERVER_HOST", "127.0.0.1") + port = int(os.getenv("SERVER_PORT", "9999")) + tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes") + tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes") + ssl_context = None + if tls_required and not tls_enabled: + raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.") + if tls_enabled: + insecure = os.getenv("TLS_INSECURE", "false").lower() in ("1", "true", "yes") + is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") + if insecure and not is_dev: + raise RuntimeError("TLS_INSECURE is only allowed when ENVIRONMENT=dev") + ssl_context = ssl.create_default_context() + ca_file = os.getenv("TLS_CA_FILE", "").strip() + if ca_file: + ssl_context.load_verify_locations(cafile=ca_file) + elif insecure: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + else: + self._logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.") + r, w = await asyncio.open_connection(host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context) + self.reader = ProtocolReader(r) + self.writer = ProtocolWriter(w) + self.raw_writer = w + self.connected = True + + async def _background_listener(self): + """Read messages from server, routing responses vs notifications.""" + while True: + msg = await self.reader.read_message() + if msg is None: + self.connected = False + # Fail all pending futures so send_and_recv doesn't hang + pending = dict(self._pending) + self._pending.clear() + err = ConnectionError("Server connection lost") + for fut in pending.values(): + if not fut.done(): + fut.set_exception(err) + break + if msg.get("type") in ("new_message", "messages_read", "message_deleted", + "conversation_created", "member_added", "member_removed", + "user_online", "user_offline", "online_users", + "group_invitation", "conversation_renamed", + "session_reset"): + await self._notification_queue.put(msg) + else: + req_id = msg.get("request_id") + if req_id and req_id in self._pending: + fut = self._pending.pop(req_id) + if not fut.done(): + fut.set_result(msg) + else: + await self._response_queue.put(msg) + + async def send_and_recv(self, msg_type: str, timeout: float = 30.0, **kwargs) -> dict: + try: + request_id = str(uuid.uuid4()) + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._pending[request_id] = fut + await self.writer.send_request(msg_type, request_id=request_id, **kwargs) + except (ValueError, ConnectionError, OSError) as e: + self._pending.pop(request_id, None) + return { + "type": msg_type, + "status": "error", + "data": {"message": str(e) or "Connection lost."}, + } + try: + return await asyncio.wait_for(fut, timeout=timeout) + except asyncio.TimeoutError: + self._logger.warning("send_and_recv timeout for '%s' after %.0fs", msg_type, timeout) + return { + "type": msg_type, + "status": "error", + "data": {"message": f"Request timed out ({msg_type})"}, + } + except ConnectionError: + return { + "type": msg_type, + "status": "error", + "data": {"message": "Connection lost."}, + } + finally: + self._pending.pop(request_id, None) + + # ------------------------------------------------------------------ + # User info / identity key cache + # ------------------------------------------------------------------ + + async def _get_user_info(self, user_id: str = "", email: str = "") -> dict | None: + """Get user info from server, cache identity key.""" + cached = self._user_cache.get(user_id) + if cached: + return cached + kwargs = {} + if user_id: + kwargs["user_id"] = user_id + elif email: + kwargs["email"] = email + else: + return None + resp = await self.send_and_recv("get_user_info", **kwargs) + if resp["status"] != "ok": + return None + data = resp["data"] + ik_bytes = decode_binary(data["identity_key"]) if data.get("identity_key") else None + info = { + "user_id": data["user_id"], + "username": data["username"], + "email": data["email"], + "identity_key": load_ed25519_public(ik_bytes) if ik_bytes else None, + "identity_key_bytes": ik_bytes, + } + self._user_cache[data["user_id"]] = info + return info + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + async def register(self, username: str, password: str, email: str) -> tuple[bool, str]: + """Register user. Generates RSA + Ed25519 + prekeys.""" + self.username = username + self.email = email + pwd_bytes = bytearray(password.encode("utf-8")) if password else None + + try: + # RSA keys for login + priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) + if priv is None: + priv, pub = generate_rsa_keypair() + save_keys(email, priv, pub, password=bytes(pwd_bytes) if pwd_bytes else None) + self.private_key = priv + self.public_key = pub + + # Ed25519 identity keys + ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) + if ed_priv is None: + ed_priv, ed_pub = generate_identity_keypair() + _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None) + self.identity_private = ed_priv + self.identity_public = ed_pub + self._cache_key = derive_self_encryption_key(ed_priv) + self._local_key = derive_local_storage_key(ed_priv) + finally: + if pwd_bytes: + pwd_bytes[:] = b'\x00' * len(pwd_bytes) + + pub_pem = serialize_public_key(pub).decode("utf-8") + ik_b64 = encode_binary(serialize_ed25519_public(ed_pub)) + + start = await self.send_and_recv( + "register", + username=username, + public_key=pub_pem, + email=email, + identity_key=ik_b64, + ) + if start["status"] != "ok": + return False, start["data"]["message"] + code = start["data"].get("code") + if code: + return True, code + return True, start["data"].get("message", "Check your email for the code.") + + async def confirm_registration(self, email: str, username: str, code: str) -> tuple[bool, str]: + confirm = await self.send_and_recv("register_confirm", email=email, code=code) + if confirm["status"] == "ok": + # Upload prekeys immediately after registration + await self._generate_and_upload_prekeys() + return True, f"Registered as '{username}' (ID: {confirm['data']['user_id']})" + return False, confirm["data"]["message"] + + async def _generate_and_upload_prekeys(self, keep_spk: bool = False): + """Generate SPK + OPKs and upload to server. + + If keep_spk=True, re-sign the existing SPK instead of generating a new + one. This is used after device pairing so both devices share the same + SPK and either can respond to X3DH. + """ + if not self.identity_private: + return + + if keep_spk and self.spk_private and self.spk_id: + # Re-sign existing SPK (both devices share the identity key) + spk_pub_bytes = serialize_x25519_public(self.spk_private.public_key()) + spk_sig = ed25519_sign(self.identity_private, spk_pub_bytes) + spk_data = { + "id": self.spk_id, + "public_key": encode_binary(spk_pub_bytes), + "signature": encode_binary(spk_sig), + } + else: + # Save current SPK as previous for grace period (M4: in-flight X3DH) + if self.spk_private and self.spk_id: + self._prev_spk_private = self.spk_private + self._prev_spk_id = self.spk_id + _save_prev_spk(self.email, self.spk_private, self.spk_id) + # Generate a brand-new signed prekey + spk = generate_signed_prekey(self.identity_private) + self.spk_private = spk["private"] + self.spk_id = spk["id"] + _save_spk(self.email, spk["private"], spk["id"]) + spk_data = { + "id": spk["id"], + "public_key": encode_binary(serialize_x25519_public(spk["public"])), + "signature": encode_binary(spk["signature"]), + } + + # Generate one-time prekeys + opks = generate_one_time_prekeys(OPK_BATCH_SIZE) + for opk in opks: + self.opk_privates[opk["id"]] = opk["private"] + _save_opk_private(self.email, opk["id"], opk["private"]) + + # Upload to server + otp_data = [ + {"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))} + for opk in opks + ] + await self.send_and_recv( + "upload_prekeys", + signed_prekey=spk_data, + one_time_prekeys=otp_data, + ) + + async def _ensure_prekeys(self): + """Check OPK count and SPK age, replenish/rotate if needed.""" + resp = await self.send_and_recv("get_prekey_count") + if resp["status"] != "ok": + return + count = resp["data"].get("count", 0) + spk_created_at = resp["data"].get("spk_created_at", "") + + need_new_spk = False + if spk_created_at: + try: + created = datetime.fromisoformat(spk_created_at) + if created.tzinfo is None: + created = created.replace(tzinfo=timezone.utc) + age_days = (datetime.now(timezone.utc) - created).days + if age_days >= SPK_ROTATION_DAYS: + need_new_spk = True + self._logger.info("SPK is %d days old, rotating...", age_days) + except Exception: + pass + + if count < OPK_REPLENISH_THRESHOLD or need_new_spk: + if count >= OPK_REPLENISH_THRESHOLD: + self._logger.info("SPK rotation triggered (OPK count OK: %d)", count) + else: + self._logger.info("OPK count low (%d), replenishing...", count) + await self._generate_and_upload_prekeys() + + # ------------------------------------------------------------------ + # Login + # ------------------------------------------------------------------ + + async def login(self, email: str, password: str) -> tuple[bool, str]: + """Login user. Returns (success, message).""" + self.email = email + pwd_bytes = bytearray(password.encode("utf-8")) if password else None + + try: + # Load RSA keys + priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) + if priv is None: + return False, err or "No local keys found. Register first." + self.private_key = priv + self.public_key = pub + + # Load identity keys + ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None) + finally: + if pwd_bytes: + pwd_bytes[:] = b'\x00' * len(pwd_bytes) + + if ed_priv is not None: + self.identity_private = ed_priv + self.identity_public = ed_pub + self._cache_key = derive_self_encryption_key(ed_priv) + self._local_key = derive_local_storage_key(ed_priv) + + # Load SPK + spk_priv, spk_id = _load_spk(email) + if spk_priv: + self.spk_private = spk_priv + self.spk_id = spk_id + + # Load previous SPK for grace period (M4) + prev_spk_priv, prev_spk_id = _load_prev_spk(email) + if prev_spk_priv: + self._prev_spk_private = prev_spk_priv + self._prev_spk_id = prev_spk_id + + # Load device_id from disk + self.device_id = _load_device_id(email) + + # RSA challenge-response login + start = await self.send_and_recv("login_start", email=email) + if start["status"] != "ok": + return False, start["data"]["message"] + + challenge = decode_binary(start["data"]["challenge"]) + signature = rsa_sign(self.private_key, challenge) + login_kwargs = {"email": email, "signature": encode_binary(signature), + "client_version": VERSION} + if self.device_id: + login_kwargs["device_id"] = self.device_id + finish = await self.send_and_recv("login_finish", **login_kwargs) + if finish["status"] == "ok": + self.session = finish["data"] + self.username = self.session.get("username", "") + # Store device_id from server + self.device_id = finish["data"].get("device_id") + if self.device_id: + _save_device_id(email, self.device_id) + # Replenish prekeys in background — after pairing, the new device + # has no local OPK private keys so we must generate fresh ones + # (server-side OPKs have no matching private keys on this device). + # Use keep_spk=True to preserve the shared SPK so both devices + # can respond to X3DH. + opk_dir = get_key_dir(self.email) / "opk_private" + has_local_opks = opk_dir.exists() and any(opk_dir.iterdir()) + if has_local_opks: + asyncio.create_task(self._ensure_prekeys()) + else: + self._logger.info("No local OPKs (likely new device). Generating fresh OPKs, keeping SPK.") + asyncio.create_task(self._generate_and_upload_prekeys(keep_spk=True)) + return True, f"Logged in as '{self.username}' (ID: {self.session['user_id']})" + return False, finish["data"]["message"] + + # ------------------------------------------------------------------ + # Pairing (device pairing — transfers RSA + identity keys) + # ------------------------------------------------------------------ + + async def pairing_start(self, email: str) -> tuple[bool, str]: + """Start device pairing. Returns (success, code/message).""" + temp_priv, temp_pub = generate_rsa_keypair(2048) + self._pairing_temp_private_key = temp_priv + temp_pub_pem = serialize_public_key(temp_pub).decode("utf-8") + resp = await self.send_and_recv("pairing_start", email=email, temp_public_key=temp_pub_pem) + if resp["status"] == "ok": + self._pairing_poll_token = resp["data"].get("poll_token", "") + return True, resp["data"]["code"] + return False, resp["data"]["message"] + + async def pairing_wait(self, code: str, email: str, password: str, timeout: int = 300) -> tuple[bool, str]: + """Wait for pairing payload and import keys. Returns (success, message).""" + if not self._pairing_temp_private_key: + return False, "Pairing not started." + from crypto_utils import aes_decrypt as _aes_decrypt + poll_token = getattr(self, "_pairing_poll_token", "") + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + resp = await self.send_and_recv("pairing_poll", code=code, poll_token=poll_token) + if resp["status"] != "ok": + return False, resp["data"]["message"] + if not resp["data"].get("ready"): + await asyncio.sleep(2.0) + continue + payload = resp["data"]["payload"] + try: + # Decrypt AES key with temp RSA key + from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding + from cryptography.hazmat.primitives import hashes as rsa_hashes + enc_aes_key = decode_binary(payload["encrypted_key"]) + aes_key = self._pairing_temp_private_key.decrypt( + enc_aes_key, + rsa_padding.OAEP( + mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()), + algorithm=rsa_hashes.SHA256(), + label=None, + ), + ) + nonce = decode_binary(payload["iv"]) + ct = decode_binary(payload["ciphertext"]) + tag = decode_binary(payload["tag"]) + keys_json = _aes_decrypt(aes_key, nonce, ct, tag) + keys_data = json.loads(keys_json) + + pwd_bytes = bytearray(password.encode("utf-8")) if password else None + + try: + # Import RSA key + rsa_priv = load_private_key(keys_data["rsa_private"].encode(), password=None) + rsa_pub = rsa_priv.public_key() + save_keys(email, rsa_priv, rsa_pub, password=bytes(pwd_bytes) if pwd_bytes else None) + + # Import identity keys + ed_priv = load_ed25519_private(bytes.fromhex(keys_data["identity_private"])) + ed_pub = ed_priv.public_key() + _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None) + finally: + if pwd_bytes: + pwd_bytes[:] = b'\x00' * len(pwd_bytes) + + self.email = email + self.private_key = rsa_priv + self.public_key = rsa_pub + self.identity_private = ed_priv + self.identity_public = ed_pub + self._cache_key = derive_self_encryption_key(ed_priv) + self._local_key = derive_local_storage_key(ed_priv) + self._pairing_temp_private_key = None + + # Multi-device: new device generates own SPK + OPKs on first + # login. No session/sender key import needed — each device + # has independent Double Ratchet sessions. + + return True, "Pairing complete." + except Exception as e: + return False, f"Failed to import keys: {e}" + return False, "Pairing timed out." + + async def authorize_device(self, code: str) -> tuple[bool, str]: + """Authorize a new device by sending all keys to it.""" + if not self.private_key or not self.identity_private: + return False, "Not logged in." + claim = await self.send_and_recv("pairing_claim", code=code) + if claim["status"] != "ok": + return False, claim["data"]["message"] + + temp_pub_pem = claim["data"]["temp_public_key"].encode("utf-8") + temp_pub = load_public_key(temp_pub_pem) + + # Phase 1: Re-encrypt message history so new device can read old + # messages via self-encryption key. This also advances ratchet states + # for any previously-unfetched messages. + try: + await self.reencrypt_history() + except Exception as e: + self._logger.warning("Re-encryption failed: %s", e) + + # Phase 2: Build keys payload — only RSA + identity key. + # Multi-device: new device generates own SPK + OPKs, creates independent + # sessions. No session/sender key transfer needed. + keys_data = { + "rsa_private": serialize_private_key(self.private_key, password=None).decode(), + "identity_private": serialize_ed25519_private_raw(self.identity_private).hex(), + } + + # Phase 3: Encrypt and send keys to new device + from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding + from cryptography.hazmat.primitives import hashes as rsa_hashes + plaintext = json.dumps(keys_data).encode() + aes_key, nonce, ct, tag = aes_encrypt(plaintext) + enc_aes_key = temp_pub.encrypt( + aes_key, + rsa_padding.OAEP( + mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()), + algorithm=rsa_hashes.SHA256(), + label=None, + ), + ) + payload = { + "encrypted_key": encode_binary(enc_aes_key), + "iv": encode_binary(nonce), + "ciphertext": encode_binary(ct), + "tag": encode_binary(tag), + } + resp = await self.send_and_recv("pairing_send", code=code, payload=payload) + if resp["status"] == "ok": + return True, "Device authorized." + return False, resp["data"]["message"] + + # ------------------------------------------------------------------ + # Key rotation (RSA login key only) + # ------------------------------------------------------------------ + + async def rotate_keys(self, username: str, password: str) -> tuple[bool, str]: + """Rotate RSA keypair to revoke other devices.""" + if not self.session or self.session.get("username") != username: + return False, "Not logged in." + pwd_bytes = password.encode("utf-8") if password else None + priv, pub = generate_rsa_keypair() + save_keys(self.email, priv, pub, password=pwd_bytes) + self.private_key = priv + self.public_key = pub + pub_pem = serialize_public_key(pub).decode("utf-8") + resp = await self.send_and_recv("rotate_keys", public_key=pub_pem) + if resp["status"] == "ok": + return True, "RSA login keys rotated." + return False, resp["data"]["message"] + + # ------------------------------------------------------------------ + # Session management (X3DH + Double Ratchet) + # ------------------------------------------------------------------ + + async def _get_device_bundles(self, peer_user_id: str) -> list[dict]: + """Get per-device key bundles for a peer. Caches for 5 minutes.""" + import time + cached = self._device_bundle_cache.get(peer_user_id) + if cached: + ts, bundles = cached + if time.time() - ts < 300: + return bundles + + resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id) + if resp["status"] != "ok": + raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}") + + data = resp["data"] + ik_b64 = data.get("identity_key", "") + + device_bundles = data.get("device_bundles") + if device_bundles: + # Attach identity_key to each bundle + for b in device_bundles: + b["identity_key"] = ik_b64 + else: + # Old server: wrap flat response as single-entry list + device_bundles = [{ + "device_id": None, + "identity_key": ik_b64, + "signed_prekey_id": data.get("signed_prekey_id", ""), + "signed_prekey": data.get("signed_prekey", ""), + "spk_signature": data.get("spk_signature", ""), + "one_time_prekey_id": data.get("one_time_prekey_id"), + "one_time_prekey": data.get("one_time_prekey"), + }] + + self._device_bundle_cache[peer_user_id] = (time.time(), device_bundles) + return device_bundles + + async def _get_or_create_session(self, peer_user_id: str, + peer_device_id: str | None = None, + bundle: dict | None = None) -> DoubleRatchet: + """Load existing session or create one via X3DH. + + If peer_device_id is set, sessions are keyed by "user_id:device_id". + If bundle is provided, it's used instead of fetching from server. + """ + session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id + + # Check in-memory cache + if session_key in self.sessions: + return self.sessions[session_key] + + # Check on disk + ratchet = _load_session(self.email, peer_user_id, self._local_key, + peer_device_id=peer_device_id) + if ratchet: + self.sessions[session_key] = ratchet + return ratchet + + # Create new session via X3DH + if not bundle: + resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id) + if resp["status"] != "ok": + raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}") + bundle = resp["data"] + + ik_remote_bytes = decode_binary(bundle["identity_key"]) + ik_remote = load_ed25519_public(ik_remote_bytes) + spk_remote = load_x25519_public(decode_binary(bundle["signed_prekey"])) + spk_sig = decode_binary(bundle["spk_signature"]) + + opk_remote = None + opk_id = bundle.get("one_time_prekey_id") + if bundle.get("one_time_prekey"): + opk_remote = load_x25519_public(decode_binary(bundle["one_time_prekey"])) + + # Perform X3DH + shared_secret, ek_priv, ek_pub = x3dh_initiate( + self.identity_private, + ik_remote, + spk_remote, + spk_sig, + opk_remote, + ) + + # Initialize Double Ratchet as Alice + ratchet = DoubleRatchet.init_alice(shared_secret, spk_remote) + self.sessions[session_key] = ratchet + _save_session(self.email, peer_user_id, ratchet, self._local_key, + peer_device_id=peer_device_id) + + # Build X3DH header for first message + x3dh_header = { + "ik": encode_binary(serialize_ed25519_public(self.identity_public)), + "ek": encode_binary(serialize_x25519_public(ek_pub)), + } + if opk_id: + x3dh_header["opk_id"] = opk_id + + # Cache the x3dh header for the next send_message call + ratchet._x3dh_header = x3dh_header + + # Cache remote user info + self._user_cache[peer_user_id] = { + "user_id": peer_user_id, + "identity_key": ik_remote, + "identity_key_bytes": ik_remote_bytes, + } + + return ratchet + + def _process_x3dh_header(self, sender_id: str, x3dh_header: dict, + sender_device_id: str | None = None, + spk_override=None) -> DoubleRatchet: + """Process an incoming X3DH header to establish session as Bob. + + Args: + spk_override: If provided, use this SPK private key instead of self.spk_private. + Used for grace period fallback (M4). + """ + ik_remote_bytes = decode_binary(x3dh_header["ik"]) + ik_remote = load_ed25519_public(ik_remote_bytes) + ek_remote = load_x25519_public(decode_binary(x3dh_header["ek"])) + + opk_id = x3dh_header.get("opk_id") + opk_priv = None + if opk_id: + opk_priv = _load_opk_private(self.email, opk_id) + if opk_priv: + _delete_opk_private(self.email, opk_id) + + spk_priv = spk_override if spk_override else self.spk_private + + shared_secret = x3dh_respond( + self.identity_private, + spk_priv, + ik_remote, + ek_remote, + opk_priv, + ) + + spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None + ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub)) + + session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id + self.sessions[session_key] = ratchet + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + + self._user_cache[sender_id] = { + "user_id": sender_id, + "identity_key": ik_remote, + "identity_key_bytes": ik_remote_bytes, + } + + return ratchet + + # ------------------------------------------------------------------ + # Conversations + # ------------------------------------------------------------------ + + async def create_conversation(self, member_emails: list[str], name: str | None = None) -> tuple[str | None, str]: + kwargs = {"members": member_emails} + if name: + kwargs["name"] = name + resp = await self.send_and_recv("create_conversation", **kwargs) + if resp["status"] == "ok": + return resp["data"]["conversation_id"], "OK" + return None, resp["data"]["message"] + + async def remove_member(self, conv_id: str, user_id: str) -> tuple[bool, str]: + resp = await self.send_and_recv("remove_member", conversation_id=conv_id, user_id=user_id) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def leave_group(self, conv_id: str) -> tuple[bool, str]: + """Leave a group conversation.""" + resp = await self.send_and_recv("leave_group", conversation_id=conv_id) + if resp["status"] == "ok": + # Clean up local sender key state for this group + self.sender_key_states.pop(conv_id, None) + # Remove received sender keys for this conversation + to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")] + for k in to_remove: + self.recv_sender_keys.pop(k, None) + return True, "OK" + return False, resp["data"]["message"] + + async def rename_conversation(self, conv_id: str, name: str) -> tuple[bool, str]: + """Rename a group conversation (creator only).""" + resp = await self.send_and_recv("rename_conversation", conversation_id=conv_id, name=name) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def delete_conversation(self, conv_id: str) -> tuple[bool, str]: + """Delete a conversation (leave + server cleans up if empty).""" + resp = await self.send_and_recv("delete_conversation", conversation_id=conv_id) + if resp["status"] == "ok": + # Clean up local sender key state + self.sender_key_states.pop(conv_id, None) + to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")] + for k in to_remove: + self.recv_sender_keys.pop(k, None) + return True, "OK" + return False, resp["data"]["message"] + + async def add_member(self, conv_id: str, email: str) -> tuple[bool, str]: + resp = await self.send_and_recv("add_member", conversation_id=conv_id, email=email) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def accept_invitation(self, conv_id: str) -> tuple[bool, str]: + """Accept a group invitation.""" + resp = await self.send_and_recv("accept_invitation", conversation_id=conv_id) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def decline_invitation(self, conv_id: str) -> tuple[bool, str]: + """Decline a group invitation.""" + resp = await self.send_and_recv("decline_invitation", conversation_id=conv_id) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def list_invitations(self) -> list[dict]: + """List pending group invitations.""" + resp = await self.send_and_recv("list_invitations") + if resp["status"] == "ok": + return resp["data"]["invitations"] + return [] + + async def list_conversations(self) -> list[dict]: + resp = await self.send_and_recv("list_conversations") + if resp["status"] == "ok": + return resp["data"]["conversations"] + return [] + + async def find_or_create_conversation(self, email: str) -> tuple[str | None, str]: + resp = await self.send_and_recv("find_conversation", email=email) + if resp["status"] != "ok": + return None, resp["data"]["message"] + conv_id = resp["data"]["conversation_id"] + if conv_id: + return conv_id, "OK" + return await self.create_conversation([email]) + + # ------------------------------------------------------------------ + # Send message + # ------------------------------------------------------------------ + + def _is_group(self, members: list[dict]) -> bool: + return len(members) > 2 + + async def send_message(self, conv_id: str, text: str, members: list[dict], + reply_to: str | None = None) -> tuple[bool, str]: + """Encrypt and send a message. DM: per-recipient Double Ratchet. Group: Sender Keys.""" + my_user_id = self.session["user_id"] + + # Build plaintext payload + payload = { + "sender": self.username, + "text": text, + "reply_to": reply_to, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8") + + if self._is_group(members): + return await self._send_group_message(conv_id, plaintext, members) + else: + return await self._send_dm(conv_id, plaintext, members) + + async def _send_dm(self, conv_id: str, plaintext: bytes, members: list[dict]) -> tuple[bool, str]: + """Encrypt DM with per-device Double Ratchet.""" + my_user_id = self.session["user_id"] + recipients = [] + first_ratchet_header = None + + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + # Get all device bundles for this user + try: + device_bundles = await self._get_device_bundles(uid) + self._logger.debug("Got %d device bundles for %s", len(device_bundles), uid) + except Exception as e: + self._logger.warning("Failed to get device bundles for %s: %s", uid, e) + device_bundles = [] + + if not device_bundles: + # Fallback: try single session (legacy peer) + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_hdr = getattr(ratchet, "_x3dh_header", None) + if x3dh_hdr: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_hdr: + entry["x3dh_header"] = x3dh_hdr + recipients.append(entry) + if first_ratchet_header is None: + first_ratchet_header = result["header"] + _save_session(self.email, uid, ratchet, self._local_key) + continue + + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_hdr = getattr(ratchet, "_x3dh_header", None) + if x3dh_hdr: + delattr(ratchet, "_x3dh_header") + + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_hdr: + entry["x3dh_header"] = x3dh_hdr + recipients.append(entry) + + if first_ratchet_header is None: + first_ratchet_header = result["header"] + + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + # Encrypt self-copy with static key derived from identity (not ratchet) + # Uses SELF_DEVICE_ID so all own devices can read it + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + if not recipients: + return False, "No recipients." + + kwargs = { + "conversation_id": conv_id, + "ratchet_header": first_ratchet_header, + "recipients": recipients, + } + + resp = await self.send_and_recv("send_message", **kwargs) + if resp["status"] == "ok": + return True, "Message sent." + return False, resp["data"]["message"] + + async def _send_group_message(self, conv_id: str, plaintext: bytes, + members: list[dict]) -> tuple[bool, str]: + """Encrypt group message with Sender Keys.""" + my_user_id = self.session["user_id"] + + # Get or create sender key for this group + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if not sk: + sk = SenderKeyState() + self.sender_key_states[conv_id] = sk + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + # Distribute sender key to all members via pairwise ratchet + await self._distribute_sender_key(conv_id, members, sk) + + self.sender_key_states[conv_id] = sk + + # Encrypt with sender key + result = sk.encrypt(plaintext) + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + + # Build per-recipient entries (same ciphertext for all except self) + recipients = [] + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + recipients.append({ + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + }) + + # Self-encrypted copy (so other devices + history fetch can decrypt) + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + ratchet_header = {"dh_pub": "00" * 32, "n": 0, "pn": 0} # Dummy for groups + + kwargs = { + "conversation_id": conv_id, + "ratchet_header": ratchet_header, + "recipients": recipients, + "sender_chain_id": encode_binary(bytes.fromhex(result["chain_id"])), + "sender_chain_n": result["n"], + } + + resp = await self.send_and_recv("send_message", **kwargs) + if resp["status"] == "ok": + return True, "Message sent." + return False, resp["data"]["message"] + + async def _distribute_sender_key(self, conv_id: str, members: list[dict], + sk: SenderKeyState): + """Send own sender key to all group members via pairwise Double Ratchet (per-device).""" + my_user_id = self.session["user_id"] + exported_key = sk.export_key() + + # Build a special "sender_key_distribution" payload + payload = { + "sender": self.username, + "text": "", + "reply_to": None, + "timestamp": datetime.now(timezone.utc).isoformat(), + "_sender_key": { + "conv_id": conv_id, + "key": encode_binary(exported_key), + "sender_device_id": self.device_id, + }, + } + plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8") + + # Send as DM to each member's devices (per-device encryption) + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + try: + # Get all device bundles for this user + try: + device_bundles = await self._get_device_bundles(uid) + except Exception: + device_bundles = [] + + if not device_bundles: + # Fallback: legacy single-device + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_header = getattr(ratchet, "_x3dh_header", None) + if x3dh_header: + delattr(ratchet, "_x3dh_header") + + recipient_entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_header: + recipient_entry["x3dh_header"] = x3dh_header + kwargs = { + "conversation_id": conv_id, + "ratchet_header": result["header"], + "recipients": [recipient_entry], + } + await self.send_and_recv("send_message", **kwargs) + _save_session(self.email, uid, ratchet, self._local_key) + else: + # Per-device encryption + recipients = [] + first_rh = None + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_header = getattr(ratchet, "_x3dh_header", None) + if x3dh_header: + delattr(ratchet, "_x3dh_header") + + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_header: + entry["x3dh_header"] = x3dh_header + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + kwargs = { + "conversation_id": conv_id, + "ratchet_header": first_rh, + "recipients": recipients, + } + await self.send_and_recv("send_message", **kwargs) + except Exception as e: + self._logger.warning("Failed to distribute sender key to %s: %s", uid, e) + + # ------------------------------------------------------------------ + # Decrypt messages + # ------------------------------------------------------------------ + + def _decrypt_message(self, msg_data: dict) -> dict: + """Decrypt a single message (DM or group).""" + # Check for self-encrypted marker FIRST — after re-encryption, + # group messages will have {"self": true} ratchet_header but still + # have sender_chain_id at message level. + rh = msg_data.get("ratchet_header", {}) + if isinstance(rh, dict) and rh.get("self"): + return self._decrypt_dm(msg_data) + + if msg_data.get("sender_chain_id"): + return self._decrypt_group(msg_data) + else: + return self._decrypt_dm(msg_data) + + def _decrypt_dm(self, msg_data: dict) -> dict: + """Decrypt DM using Double Ratchet with sender, or static key for self-copies.""" + sender_id = msg_data.get("sender_id", "") + sender_device_id = msg_data.get("sender_device_id") + ratchet_header = msg_data.get("ratchet_header", {}) + ct_b64 = msg_data.get("encrypted_content", "") + nonce_b64 = msg_data.get("nonce", "") + + if not ct_b64 or not nonce_b64: + raise ValueError("Missing ciphertext or nonce") + + ciphertext = decode_binary(ct_b64) + nonce = decode_binary(nonce_b64) + + # Self-encrypted message (own sent message copy) + if isinstance(ratchet_header, dict) and ratchet_header.get("self"): + self_key = derive_self_encryption_key(self.identity_private) + ct = ciphertext[:-16] + tag = ciphertext[-16:] + plaintext = aes_decrypt(self_key, nonce, ct, tag) + else: + x3dh_header = msg_data.get("x3dh_header") + + # Session key: "sender_id:sender_device_id" or just "sender_id" for legacy + session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id + + # Try to load existing session + ratchet = self.sessions.get(session_key) + if not ratchet: + ratchet = _load_session(self.email, sender_id, self._local_key, + peer_device_id=sender_device_id) + if ratchet: + self.sessions[session_key] = ratchet + + if ratchet and not x3dh_header: + # Normal case: existing session, no X3DH header + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + elif x3dh_header: + if ratchet: + # Existing session + X3DH header: sender may have reset. + backup = ratchet.export_state() + try: + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + except Exception: + restored = DoubleRatchet.import_state(backup) + self.sessions[session_key] = restored + _save_session(self.email, sender_id, restored, self._local_key, + peer_device_id=sender_device_id) + ratchet = self._process_x3dh_header(sender_id, x3dh_header, + sender_device_id=sender_device_id) + try: + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + except Exception: + if self._prev_spk_private: + ratchet = self._process_x3dh_header( + sender_id, x3dh_header, + sender_device_id=sender_device_id, + spk_override=self._prev_spk_private) + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + else: + raise + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + else: + ratchet = self._process_x3dh_header(sender_id, x3dh_header, + sender_device_id=sender_device_id) + try: + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + except Exception: + if self._prev_spk_private: + ratchet = self._process_x3dh_header( + sender_id, x3dh_header, + sender_device_id=sender_device_id, + spk_override=self._prev_spk_private) + plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) + else: + raise + _save_session(self.email, sender_id, ratchet, self._local_key, + peer_device_id=sender_device_id) + else: + raise ValueError(f"No session for sender {sender_id}") + + payload = json.loads(plaintext) + + # Handle sender key distribution messages + if "_sender_key" in payload: + sk_data = payload["_sender_key"] + sk_conv_id = sk_data["conv_id"] + sk_key = decode_binary(sk_data["key"]) + sk_sender_device_id = sk_data.get("sender_device_id") + recv_sk = SenderKeyState.from_key(sk_key) + if sk_sender_device_id: + cache_key = f"{sk_conv_id}:{sender_id}:{sk_sender_device_id}" + else: + cache_key = f"{sk_conv_id}:{sender_id}" + self.recv_sender_keys[cache_key] = recv_sk + _save_recv_sender_key(self.email, sk_conv_id, sender_id, recv_sk, self._local_key, + sender_device_id=sk_sender_device_id) + # Return empty — this is a control message, not user-visible + return None + + return payload + + def _decrypt_group(self, msg_data: dict) -> dict: + """Decrypt group message using sender's Sender Key.""" + sender_id = msg_data.get("sender_id", "") + sender_device_id = msg_data.get("sender_device_id") + conv_id = msg_data.get("conversation_id", "") + chain_id_b64 = msg_data.get("sender_chain_id", "") + chain_n = msg_data.get("sender_chain_n", 0) + ct_b64 = msg_data.get("encrypted_content", "") + nonce_b64 = msg_data.get("nonce", "") + + if not ct_b64 or not nonce_b64 or not chain_id_b64: + raise ValueError("Missing group message fields") + + ciphertext = decode_binary(ct_b64) + nonce = decode_binary(nonce_b64) + chain_id = decode_binary(chain_id_b64) + + my_user_id = self.session["user_id"] + + # If we sent this message, use our own sender key + if sender_id == my_user_id: + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if sk: + self.sender_key_states[conv_id] = sk + if not sk: + raise ValueError("Own sender key not found") + # For our own messages, we can't decrypt from sender key (it's already advanced) + # Return a placeholder — the server echoed our ciphertext + raise ValueError("Cannot decrypt own group message from sender key") + + # Use received sender key — try with sender_device_id first, fall back to without + sk = None + if sender_device_id: + cache_key = f"{conv_id}:{sender_id}:{sender_device_id}" + sk = self.recv_sender_keys.get(cache_key) + if not sk: + sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key, + sender_device_id=sender_device_id) + if sk: + self.recv_sender_keys[cache_key] = sk + + if not sk: + # Fallback: try without device_id (legacy or same-device) + cache_key = f"{conv_id}:{sender_id}" + sk = self.recv_sender_keys.get(cache_key) + if not sk: + sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key) + if sk: + self.recv_sender_keys[cache_key] = sk + + if not sk: + raise ValueError(f"No sender key for {sender_id} in conversation {conv_id}") + + plaintext = sk.decrypt(chain_id.hex(), chain_n, ciphertext, nonce) + _save_recv_sender_key(self.email, conv_id, sender_id, sk, self._local_key, + sender_device_id=sender_device_id) + + return json.loads(plaintext) + + # ------------------------------------------------------------------ + # Get/decrypt messages (batch) + # ------------------------------------------------------------------ + + async def get_messages(self, conv_id: str, limit: int = 50, offset: int = 0) -> list[dict]: + resp = await self.send_and_recv("get_messages", conversation_id=conv_id, limit=limit, offset=offset) + if resp["status"] != "ok": + return [] + + cache = _load_message_cache(self.email, conv_id, self._cache_key) + decrypted = [] + message_ids = [] + raw_messages = resp["data"]["messages"] + raw_messages.reverse() # Server returns DESC, reverse to ASC + for m in raw_messages: + msg_id = m["message_id"] + message_ids.append(msg_id) + + if m.get("deleted_at"): + decrypted.append({ + "message_id": msg_id, + "sender": "", + "text": "", + "created_at": m["created_at"], + "read_by": [], + "sender_id": m.get("sender_id", ""), + "deleted": True, + }) + continue + + # Check local cache first (ratchet keys are one-time use) + cached = cache.get(msg_id) + if cached: + cached["read_by"] = m.get("read_by", []) + cached["created_at"] = m["created_at"] + if cached.get("_control"): + continue # Skip control messages + decrypted.append(cached) + continue + + try: + msg_data = { + "sender_id": m.get("sender_id", ""), + "sender_device_id": m.get("sender_device_id"), + "conversation_id": conv_id, + "ratchet_header": m.get("ratchet_header", {}), + "encrypted_content": m.get("encrypted_content", ""), + "nonce": m.get("nonce", ""), + "x3dh_header": m.get("x3dh_header"), + "sender_chain_id": m.get("sender_chain_id"), + "sender_chain_n": m.get("sender_chain_n"), + } + payload = self._decrypt_message(msg_data) + if payload is None: + # Control message (sender key distribution) — cache and skip + _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True}, + cache_key=self._cache_key) + continue + payload["message_id"] = msg_id + payload["created_at"] = m["created_at"] + payload["read_by"] = m.get("read_by", []) + payload["sender_id"] = m.get("sender_id", "") + decrypted.append(payload) + # Cache the decrypted payload (without read_by which changes) + cache_entry = {k: v for k, v in payload.items() if k != "read_by"} + _save_message_to_cache(self.email, conv_id, msg_id, cache_entry, + cache_key=self._cache_key) + except Exception as e: + decrypted.append({ + "message_id": msg_id, + "sender": "???", + "text": f"[Decryption failed: {e}]", + "created_at": m["created_at"], + "read_by": [], + }) + + if message_ids: + await self.mark_read(conv_id, message_ids) + + return decrypted + + async def mark_read(self, conv_id: str, message_ids: list[str]): + if not message_ids: + return + await self.send_and_recv("mark_read", conversation_id=conv_id, message_ids=message_ids) + + def search_messages(self, conv_id: str, query: str) -> list[dict]: + """Search cached messages in a conversation. Returns matching messages.""" + cache = _load_message_cache(self.email, conv_id, self._cache_key) + query_lower = query.lower() + results = [] + for msg_id, payload in cache.items(): + if payload.get("deleted") or payload.get("_control") or payload.get("_sender_key"): + continue + text = payload.get("text", "") + if query_lower in text.lower(): + entry = dict(payload) + entry["message_id"] = msg_id + results.append(entry) + results.sort(key=lambda m: m.get("created_at", "")) + return results + + async def reset_session(self, peer_user_id: str, peer_device_id: str | None = None): + """Delete local session and notify peer to do the same.""" + if peer_device_id: + session_key = f"{peer_user_id}:{peer_device_id}" + else: + session_key = peer_user_id + self.sessions.pop(session_key, None) + _delete_session_file(self.email, peer_user_id, peer_device_id) + await self.send_and_recv("session_reset", + peer_user_id=peer_user_id, + peer_device_id=peer_device_id or "") + + def handle_session_reset_notification(self, from_user_id: str, from_device_id: str | None = None): + """Handle incoming session reset notification — delete the matching session.""" + if from_device_id: + session_key = f"{from_user_id}:{from_device_id}" + else: + session_key = from_user_id + self.sessions.pop(session_key, None) + _delete_session_file(self.email, from_user_id, from_device_id) + + # ------------------------------------------------------------------ + # Decrypt notification + # ------------------------------------------------------------------ + + def decrypt_notification(self, notif_data: dict) -> dict | None: + """Decrypt a new_message notification. Returns parsed payload or None. + + Supports new multi-device format (device_entries array) and legacy flat format. + """ + try: + conv_id = notif_data.get("conversation_id", "") + msg_id = notif_data.get("message_id", "") + sender_id = notif_data.get("sender_id", "") + sender_device_id = notif_data.get("sender_device_id") + my_user_id = self.session["user_id"] if self.session else "" + + # Extract per-device encrypted content from device_entries or flat fields + encrypted_content = "" + nonce = "" + ratchet_header = {} + x3dh_header = None + + device_entries = notif_data.get("device_entries") + if device_entries: + # Multi-device format: pick entry matching our device_id or SELF_DEVICE_ID + chosen = None + self_entry = None + for entry in device_entries: + eid = entry.get("device_id", "") + if eid == self.device_id: + chosen = entry + break + if eid == "00000000-0000-0000-0000-000000000000": + self_entry = entry + + # If sender is us, prefer self-encrypted entry + if sender_id == my_user_id: + chosen = self_entry or chosen + elif not chosen: + chosen = self_entry + + if not chosen: + self._logger.warning("No matching device_entry for device %s", self.device_id) + return None + + encrypted_content = chosen.get("encrypted_content", "") + nonce = chosen.get("nonce", "") + ratchet_header = chosen.get("ratchet_header") or notif_data.get("ratchet_header", {}) + x3dh_header = chosen.get("x3dh_header") or notif_data.get("x3dh_header") + else: + # Legacy flat format + encrypted_content = notif_data.get("encrypted_content", "") + nonce = notif_data.get("nonce", "") + ratchet_header = notif_data.get("ratchet_header", {}) + x3dh_header = notif_data.get("x3dh_header") + + msg_data = { + "sender_id": sender_id, + "sender_device_id": sender_device_id, + "conversation_id": conv_id, + "ratchet_header": ratchet_header, + "encrypted_content": encrypted_content, + "nonce": nonce, + "x3dh_header": x3dh_header, + "sender_chain_id": notif_data.get("sender_chain_id"), + "sender_chain_n": notif_data.get("sender_chain_n"), + } + payload = self._decrypt_message(msg_data) + if payload is None: + # Cache control message so get_messages skips it + if msg_id and conv_id: + _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True}, + cache_key=self._cache_key) + return None + payload["conversation_id"] = conv_id + payload["message_id"] = msg_id + payload["sender_id"] = sender_id + payload["created_at"] = payload.get("timestamp", "") + payload["read_by"] = [] + # Cache so get_messages doesn't re-decrypt (ratchet keys are one-time) + if msg_id and conv_id: + cache_entry = {k: v for k, v in payload.items() if k != "read_by"} + _save_message_to_cache(self.email, conv_id, msg_id, cache_entry, + cache_key=self._cache_key) + return payload + except Exception as e: + self._logger.warning("Failed to decrypt notification: %s", e) + return None + + # ------------------------------------------------------------------ + # Delete message + # ------------------------------------------------------------------ + + async def delete_message(self, message_id: str) -> tuple[bool, str]: + resp = await self.send_and_recv("delete_message", message_id=message_id) + if resp["status"] == "ok": + return True, "Message deleted." + return False, resp["data"]["message"] + + # ------------------------------------------------------------------ + # Image sharing + # ------------------------------------------------------------------ + + async def send_image(self, conv_id: str, image_path: str, members: list[dict], + reply_to: str | None = None) -> tuple[bool, str]: + """Encrypt and upload an image, then send as a message.""" + try: + from PIL import Image + import io + except ImportError: + return False, "Pillow is required for image sharing. Install with: pip install Pillow" + + path = Path(image_path) + if not path.exists(): + return False, "File not found." + + try: + img = Image.open(path) + img.load() + except Exception as e: + return False, f"Cannot open image: {e}" + + if img.mode not in ("RGB", "L"): + img = img.convert("RGB") + + max_dim = 1920 + if max(img.size) > max_dim: + img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS) + + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=85) + image_bytes = buf.getvalue() + + thumb = img.copy() + thumb.thumbnail((200, 200), Image.Resampling.LANCZOS) + thumb_buf = io.BytesIO() + thumb.save(thumb_buf, format="JPEG", quality=60) + thumbnail_b64 = encode_binary(thumb_buf.getvalue()) + + # Encrypt image with AES-256-GCM + img_aes_key, img_iv, img_ct, img_tag = aes_encrypt(image_bytes) + encrypted_image = img_ct + img_tag + + file_id = str(uuid.uuid4()) + file_size = len(encrypted_image) + + # Chunked upload + resp = await self.send_and_recv( + "upload_image_start", + conversation_id=conv_id, + file_id=file_id, + file_size=file_size, + ) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + upload_offset = 0 + while upload_offset < file_size: + chunk = encrypted_image[upload_offset:upload_offset + IMAGE_CHUNK_SIZE] + resp = await self.send_and_recv( + "upload_image_chunk", + file_id=file_id, + data=encode_binary(chunk), + ) + if resp["status"] != "ok": + return False, resp["data"]["message"] + upload_offset += len(chunk) + + resp = await self.send_and_recv("upload_image_end", file_id=file_id) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + # Build message payload with image info + image_info = { + "file_id": file_id, + "aes_key": encode_binary(img_aes_key), + "iv": encode_binary(img_iv), + "thumbnail": thumbnail_b64, + "filename": path.name, + "size": len(image_bytes), + } + + payload = { + "sender": self.username, + "text": "", + "reply_to": reply_to, + "timestamp": datetime.now(timezone.utc).isoformat(), + "image": image_info, + } + plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8") + + my_user_id = self.session["user_id"] + + if self._is_group(members): + # Group image: use sender key + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if not sk: + sk = SenderKeyState() + self.sender_key_states[conv_id] = sk + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + await self._distribute_sender_key(conv_id, members, sk) + + result = sk.encrypt(plaintext) + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + + recipients = [] + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + recipients.append({ + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + }) + + # Self-encrypted copy for sender + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0}, + recipients=recipients, + sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])), + sender_chain_n=result["n"], + image_file_id=file_id, + ) + else: + # DM image: per-device ratchet (same pattern as _send_dm) + recipients = [] + first_rh = None + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + try: + device_bundles = await self._get_device_bundles(uid) + except Exception: + device_bundles = [] + + if not device_bundles: + # Fallback: legacy single-device + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key) + else: + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + # Encrypt self-copy with static key + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header=first_rh, + recipients=recipients, + image_file_id=file_id, + ) + + if resp["status"] == "ok": + return True, "Image sent." + return False, resp["data"]["message"] + + async def send_file(self, conv_id: str, file_path: str, members: list[dict], + reply_to: str | None = None) -> tuple[bool, str]: + """Encrypt and upload a file, then send as a message.""" + import mimetypes + + path = Path(file_path) + if not path.exists(): + return False, "File not found." + + try: + file_bytes = path.read_bytes() + except Exception as e: + return False, f"Cannot read file: {e}" + + mime_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream" + + # Encrypt file with AES-256-GCM + file_aes_key, file_iv, file_ct, file_tag = aes_encrypt(file_bytes) + encrypted_file = file_ct + file_tag + + file_id = str(uuid.uuid4()) + file_size = len(encrypted_file) + + # Chunked upload (reuse image upload infrastructure with file_type="file") + resp = await self.send_and_recv( + "upload_image_start", + conversation_id=conv_id, + file_id=file_id, + file_size=file_size, + file_type="file", + ) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + upload_offset = 0 + while upload_offset < file_size: + chunk = encrypted_file[upload_offset:upload_offset + IMAGE_CHUNK_SIZE] + resp = await self.send_and_recv( + "upload_image_chunk", + file_id=file_id, + data=encode_binary(chunk), + ) + if resp["status"] != "ok": + return False, resp["data"]["message"] + upload_offset += len(chunk) + + resp = await self.send_and_recv("upload_image_end", file_id=file_id) + if resp["status"] != "ok": + return False, resp["data"]["message"] + + # Build message payload with file info + file_info = { + "file_id": file_id, + "aes_key": encode_binary(file_aes_key), + "iv": encode_binary(file_iv), + "filename": path.name, + "size": len(file_bytes), + "mime_type": mime_type, + } + + payload = { + "sender": self.username, + "text": "", + "reply_to": reply_to, + "timestamp": datetime.now(timezone.utc).isoformat(), + "file": file_info, + } + plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8") + + my_user_id = self.session["user_id"] + + if self._is_group(members): + sk = self.sender_key_states.get(conv_id) + if not sk: + sk = _load_sender_key_state(self.email, conv_id, self._local_key) + if not sk: + sk = SenderKeyState() + self.sender_key_states[conv_id] = sk + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + await self._distribute_sender_key(conv_id, members, sk) + + result = sk.encrypt(plaintext) + _save_sender_key_state(self.email, conv_id, sk, self._local_key) + + recipients = [] + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + recipients.append({ + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + }) + + # Self-encrypted copy for sender + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0}, + recipients=recipients, + sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])), + sender_chain_n=result["n"], + image_file_id=file_id, + ) + else: + # DM file: per-device ratchet (same pattern as _send_dm) + recipients = [] + first_rh = None + for member in members: + uid = member.get("user_id") + if not uid or uid == my_user_id: + continue + + try: + device_bundles = await self._get_device_bundles(uid) + except Exception: + device_bundles = [] + + if not device_bundles: + # Fallback: legacy single-device + ratchet = await self._get_or_create_session(uid) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key) + else: + for bundle in device_bundles: + dev_id = bundle.get("device_id") + ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id, + bundle=bundle) + result = ratchet.encrypt(plaintext) + x3dh_h = getattr(ratchet, "_x3dh_header", None) + if x3dh_h: + delattr(ratchet, "_x3dh_header") + entry = { + "user_id": uid, + "encrypted_content": encode_binary(result["ciphertext"]), + "nonce": encode_binary(result["nonce"]), + "ratchet_header": result["header"], + } + if dev_id: + entry["device_id"] = dev_id + if x3dh_h: + entry["x3dh_header"] = x3dh_h + recipients.append(entry) + if first_rh is None: + first_rh = result["header"] + _save_session(self.email, uid, ratchet, self._local_key, + peer_device_id=dev_id) + + # Encrypt self-copy with static key + self_key = derive_self_encryption_key(self.identity_private) + _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key) + recipients.append({ + "user_id": my_user_id, + "encrypted_content": encode_binary(self_ct + self_tag), + "nonce": encode_binary(self_nonce), + "ratchet_header": {"self": True}, + }) + + resp = await self.send_and_recv( + "send_message", + conversation_id=conv_id, + ratchet_header=first_rh, + recipients=recipients, + image_file_id=file_id, + ) + + if resp["status"] == "ok": + return True, "File sent." + return False, resp["data"]["message"] + + async def download_file(self, file_id: str, file_info: dict) -> bytes | None: + """Download and decrypt a file. Returns decrypted file bytes or None.""" + chunks = [] + offset = 0 + while True: + resp = await self.send_and_recv( + "download_image", + file_id=file_id, + offset=offset, + ) + if resp["status"] != "ok": + return None + data = resp["data"] + chunk = decode_binary(data["data"]) + chunks.append(chunk) + offset += len(chunk) + if data.get("done"): + break + + encrypted_data = b"".join(chunks) + if len(encrypted_data) < 16: + return None + ciphertext = encrypted_data[:-16] + tag = encrypted_data[-16:] + + try: + file_aes_key = decode_binary(file_info["aes_key"]) + iv = decode_binary(file_info["iv"]) + return aes_decrypt(file_aes_key, iv, ciphertext, tag) + except Exception: + return None + + async def download_image(self, file_id: str, image_info: dict) -> bytes | None: + """Download and decrypt an image. Returns decrypted image bytes or None.""" + chunks = [] + offset = 0 + while True: + resp = await self.send_and_recv( + "download_image", + file_id=file_id, + offset=offset, + ) + if resp["status"] != "ok": + return None + data = resp["data"] + chunk = decode_binary(data["data"]) + chunks.append(chunk) + offset += len(chunk) + if data.get("done"): + break + + encrypted_data = b"".join(chunks) + if len(encrypted_data) < 16: + return None + ciphertext = encrypted_data[:-16] + tag = encrypted_data[-16:] + + try: + img_aes_key = decode_binary(image_info["aes_key"]) + iv = decode_binary(image_info["iv"]) + return aes_decrypt(img_aes_key, iv, ciphertext, tag) + except Exception: + return None + + # ------------------------------------------------------------------ + # Re-encrypt history (for device pairing) + # ------------------------------------------------------------------ + + async def reencrypt_history(self): + """Re-encrypt all cached messages with self-encryption key. + + After device pairing, the new device shares the same identity key + but cannot decrypt old messages (Double Ratchet keys are one-time use). + This re-encrypts all cached messages so they can be read using the + self-encryption key derived from the shared identity key. + """ + if not self.identity_private or not self.session: + return + + self_key = derive_self_encryption_key(self.identity_private) + + # Phase 1: Fetch & decrypt all messages to populate cache + # (messages the old device never opened won't be in cache yet) + try: + convs = await self.list_conversations() + total_convs = len(convs) + for ci, conv in enumerate(convs): + cid = conv.get("id") or conv.get("conversation_id") + if not cid: + continue + if self._reencrypt_progress_cb: + self._reencrypt_progress_cb( + f"Fetching messages: {ci + 1}/{total_convs} conversations..." + ) + offset = 0 + while True: + msgs = await self.get_messages(cid, limit=200, offset=offset) + if not msgs or len(msgs) < 200: + break + offset += len(msgs) + except Exception as e: + self._logger.warning("Failed to fetch messages for re-encryption: %s", e) + + # Phase 2: Read cache and re-encrypt + cache_dir = get_key_dir(self.email) / "message_cache" + if not cache_dir.exists(): + self._logger.info("No message cache to re-encrypt.") + return + + all_updates = [] + conv_ids = set() + for f in cache_dir.iterdir(): + if f.suffix in (".json", ".bin"): + conv_ids.add(f.stem) + + total_files = len(conv_ids) + for i, conv_id in enumerate(sorted(conv_ids)): + cache = _load_message_cache(self.email, conv_id, self._cache_key) + if not cache: + continue + + for msg_id, entry in cache.items(): + # Skip control messages (sender key distribution) + if entry.get("_control"): + continue + # Skip entries with no useful content + text = entry.get("text", "") + if not text and not entry.get("image") and not entry.get("file"): + continue + + # Rebuild plaintext from cached payload + payload = {k: v for k, v in entry.items() + if k not in ("message_id", "created_at", "read_by", "sender_id", "deleted")} + plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8") + + # Re-encrypt with self-encryption key + _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key) + all_updates.append({ + "message_id": msg_id, + "encrypted_content": encode_binary(ct + tag), + "nonce": encode_binary(nonce), + }) + + if self._reencrypt_progress_cb: + self._reencrypt_progress_cb(f"Re-encrypting history: {i + 1}/{total_files} conversations...") + + if not all_updates: + self._logger.info("No messages to re-encrypt.") + return + + # Send in batches of 500 + batch_size = 500 + total = len(all_updates) + for start in range(0, total, batch_size): + batch = all_updates[start:start + batch_size] + resp = await self.send_and_recv("reencrypt_messages", updates=batch) + if resp["status"] != "ok": + self._logger.warning("Re-encrypt batch failed: %s", resp.get("data", {}).get("message", "")) + else: + self._logger.info("Re-encrypted %d/%d messages.", min(start + batch_size, total), total) + + if self._reencrypt_progress_cb: + self._reencrypt_progress_cb(f"Re-encryption complete: {total} messages uploaded.") + + # ------------------------------------------------------------------ + # User Profiles + # ------------------------------------------------------------------ + + async def get_profile(self, user_id: str | None = None) -> dict | None: + """Get user profile. If user_id is None, returns own profile.""" + kwargs = {} + if user_id: + kwargs["user_id"] = user_id + resp = await self.send_and_recv("get_profile", **kwargs) + if resp["status"] == "ok": + return resp["data"] + return None + + async def update_profile(self, **fields) -> tuple[bool, str]: + """Update own profile (phone, location, *_visible).""" + resp = await self.send_and_recv("update_profile", **fields) + if resp["status"] == "ok": + return True, "OK" + return False, resp["data"]["message"] + + async def update_avatar(self, image_data: bytes) -> tuple[bool, str]: + """Upload avatar image.""" + resp = await self.send_and_recv("update_avatar", data=encode_binary(image_data)) + if resp["status"] == "ok": + return True, resp["data"].get("avatar_file", "") + return False, resp["data"]["message"] + + async def get_avatar(self, user_id: str) -> bytes | None: + """Download avatar for a user.""" + resp = await self.send_and_recv("get_avatar", user_id=user_id) + if resp["status"] == "ok": + return decode_binary(resp["data"]["data"]) + return None + + async def update_group_avatar(self, conv_id: str, image_data: bytes) -> tuple[bool, str]: + """Upload avatar for a group conversation.""" + resp = await self.send_and_recv("update_group_avatar", + conversation_id=conv_id, data=encode_binary(image_data)) + if resp["status"] == "ok": + return True, resp["data"].get("avatar_file", "") + return False, resp["data"]["message"] + + async def get_group_avatar(self, conv_id: str) -> bytes | None: + """Download avatar for a group conversation.""" + resp = await self.send_and_recv("get_group_avatar", conversation_id=conv_id) + if resp["status"] == "ok": + return decode_binary(resp["data"]["data"]) + return None + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + async def close(self): + self.connected = False + if self._listener_task: + self._listener_task.cancel() + if self.raw_writer: + self.raw_writer.close() + + async def reconnect(self): + """Close existing connection and re-establish: connect + re-login using in-memory keys.""" + try: + await self.close() + except Exception: + pass + # Reset reader/writer but keep keys and sessions + self.reader = None + self.writer = None + self.raw_writer = None + self._listener_task = None + self._pending.clear() + self.login_rejected = False + # Drain queues + while not self._response_queue.empty(): + try: + self._response_queue.get_nowait() + except Exception: + break + while not self._notification_queue.empty(): + try: + self._notification_queue.get_nowait() + except Exception: + break + await self.connect() + self._listener_task = asyncio.create_task(self._background_listener()) + if self.email and self.private_key: + # RSA challenge-response login (keys already in memory) + start = await self.send_and_recv("login_start", email=self.email) + if start["status"] == "ok": + challenge = decode_binary(start["data"]["challenge"]) + signature = rsa_sign(self.private_key, challenge) + login_kwargs = { + "email": self.email, + "signature": encode_binary(signature), + "client_version": VERSION, + } + if self.device_id: + login_kwargs["device_id"] = self.device_id + finish = await self.send_and_recv("login_finish", **login_kwargs) + if finish["status"] == "ok": + self.session = finish["data"] + asyncio.create_task(self._ensure_prekeys()) + else: + # Login rejected — keys were likely rotated on another device + self.session = None + self.connected = False + self.login_rejected = True diff --git a/zaloha/client.py b/zaloha/client.py new file mode 100644 index 0000000..8fc2f2c --- /dev/null +++ b/zaloha/client.py @@ -0,0 +1,636 @@ +"""Interactive CLI client for encrypted chat (X3DH + Double Ratchet).""" + +import asyncio +import logging +import os + +from chat_core import ChatClient + + +def setup_logging(): + level_name = os.getenv("LOG_LEVEL", "WARNING").upper() + level = getattr(logging, level_name, logging.WARNING) + logging.basicConfig(level=level, format="%(levelname)s: %(message)s") + + +async def prompt(text: str) -> str: + """Non-blocking terminal input.""" + return await asyncio.get_event_loop().run_in_executor(None, lambda: input(text).strip()) + + +def _human_size(n: int) -> str: + if n >= 1024 * 1024: + return f"{n / (1024*1024):.1f} MB" + if n >= 1024: + return f"{n / 1024:.0f} KB" + return f"{n} B" + + +async def _select_conversation(client: ChatClient, label: str = "Select conversation") -> tuple[dict | None, list[dict]]: + """List conversations and let user pick one. Returns (conv, convs) or (None, []).""" + convs = await client.list_conversations() + if not convs: + print("[*] No conversations.") + return None, [] + + def conv_label(c): + if c.get("name"): + return c["name"] + others = [m.get("username") or m.get("email") or "?" for m in c["members"] if m.get("email") != client.email] + return ", ".join(others) if others else client.username + + print() + for i, c in enumerate(convs): + print(f" {i+1}) {conv_label(c)}") + choice = await prompt(f"{label}: ") + try: + idx = int(choice) - 1 + if not (0 <= idx < len(convs)): + print("[!] Invalid selection.") + return None, convs + except ValueError: + print("[!] Invalid selection.") + return None, convs + return convs[idx], convs + + +async def interactive_menu(client: ChatClient): + """Interactive terminal menu.""" + while True: + print("\n--- Encrypted Chat ---") + print("1) Send direct message") + print("2) Send to conversation") + print("3) Read messages") + print("4) Create group conversation") + print("5) Add member to group") + print("6) Send image") + print("7) Send file") + print("8) Invitations") + print("9) Leave group") + print("10) Rename group") + print("11) Delete conversation") + print("12) Search messages") + print("13) My profile") + print("14) View user profile") + print("15) Manage devices") + print("q) Quit") + + choice = await prompt("> ") + + if choice == "1": + email = await prompt("To (email): ") + if not email: + continue + text = await prompt("Message: ") + if not text: + continue + conv_id, msg = await client.find_or_create_conversation(email) + if not conv_id: + print(f"[!] {msg}") + continue + convs = await client.list_conversations() + members = [] + for c in convs: + if c["conversation_id"] == conv_id: + members = c["members"] + break + ok, msg = await client.send_message(conv_id, text, members) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "2": + conv, _ = await _select_conversation(client) + if not conv: + continue + text = await prompt("Message: ") + if not text: + continue + ok, msg = await client.send_message(conv["conversation_id"], text, conv["members"]) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "3": + conv, _ = await _select_conversation(client) + if not conv: + continue + messages = await client.get_messages(conv["conversation_id"]) + if not messages: + print("[*] No messages.") + continue + _print_messages(messages, client, conv) + + action = await prompt("\nAction (r=reply, d=delete, dl=download file, empty=back): ") + if not action: + continue + if action.lower().startswith("dl"): + await _download_file_action(client, messages) + continue + if action.lower().startswith("d"): + await _delete_message_action(client, messages) + continue + if action.lower().startswith("r"): + reply_choice = await prompt("Reply to message #: ") + else: + reply_choice = action + try: + reply_idx = int(reply_choice) - 1 + if not (0 <= reply_idx < len(messages)): + print("[!] Invalid message number.") + continue + except ValueError: + print("[!] Invalid number.") + continue + reply_to_id = messages[reply_idx]["message_id"] + text = await prompt("Message: ") + if not text: + continue + ok, msg = await client.send_message(conv["conversation_id"], text, conv["members"], reply_to=reply_to_id) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "4": + name = await prompt("Group name (empty for none): ") + members_input = await prompt("Member emails (comma-separated): ") + members = [m.strip() for m in members_input.split(",") if m.strip()] + if not members: + continue + conv_id, msg = await client.create_conversation(members, name=name.strip() or None) + if conv_id: + print(f"[+] Group created with: {', '.join(members)}") + else: + print(f"[!] {msg}") + + elif choice == "5": + conv, _ = await _select_conversation(client) + if not conv: + continue + email = await prompt("Email to add: ") + ok, msg = await client.add_member(conv["conversation_id"], email) + print(f"[{'+'if ok else '!'}] {msg or 'Invitation sent.'}") + + elif choice == "6": + conv, _ = await _select_conversation(client) + if not conv: + continue + image_path = await prompt("Image path: ") + if not image_path: + continue + ok, msg = await client.send_image(conv["conversation_id"], image_path, conv["members"]) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "7": + conv, _ = await _select_conversation(client) + if not conv: + continue + file_path = await prompt("File path: ") + if not file_path: + continue + if not os.path.isfile(file_path): + print("[!] File not found.") + continue + ok, msg = await client.send_file(conv["conversation_id"], file_path, conv["members"]) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "8": + await _invitations_menu(client) + + elif choice == "9": + conv, _ = await _select_conversation(client, "Select group to leave") + if not conv: + continue + confirm = await prompt(f"Leave '{conv.get('name', 'this conversation')}'? (y/n): ") + if confirm.lower() != "y": + continue + ok, msg = await client.leave_group(conv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "10": + conv, _ = await _select_conversation(client, "Select group to rename") + if not conv: + continue + name = await prompt("New name: ") + if not name: + continue + ok, msg = await client.rename_conversation(conv["conversation_id"], name.strip()) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "11": + conv, _ = await _select_conversation(client, "Select conversation to delete") + if not conv: + continue + confirm = await prompt("Delete this conversation? This cannot be undone. (y/n): ") + if confirm.lower() != "y": + continue + ok, msg = await client.delete_conversation(conv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + elif choice == "12": + conv, _ = await _select_conversation(client, "Select conversation to search") + if not conv: + continue + query = await prompt("Search query: ") + if not query: + continue + # First ensure we have messages cached by fetching them + await client.get_messages(conv["conversation_id"]) + results = client.search_messages(conv["conversation_id"], query) + if not results: + print("[*] No matches found.") + continue + print(f"\n[*] {len(results)} match(es):") + for r in results: + sender = r.get("sender", "???") + text = r.get("text", "") + ts = r.get("created_at", "")[:16] + # Highlight match in text + idx = text.lower().find(query.lower()) + if idx >= 0: + text = text[:idx] + "\033[33m" + text[idx:idx+len(query)] + "\033[0m" + text[idx+len(query):] + print(f" [{ts}] {sender}: {text}") + + elif choice == "13": + await _my_profile_menu(client) + + elif choice == "14": + email = await prompt("User email: ") + if not email: + continue + # Need to find user_id from email — try via conversation members + user_id = None + convs = await client.list_conversations() + for c in convs: + for m in c.get("members", []): + if m.get("email") == email: + user_id = m.get("user_id") or m.get("id") + break + if user_id: + break + if not user_id: + print("[!] User not found in your conversations.") + continue + profile = await client.get_profile(user_id) + if not profile: + print("[!] Could not load profile.") + continue + _print_profile(profile) + + elif choice == "15": + await _devices_menu(client) + + elif choice in ("q", "Q", "quit", "exit"): + print("[*] Bye.") + break + + +def _print_messages(messages, client, conv): + """Print messages to terminal.""" + print() + for i, m in enumerate(messages): + if m.get("deleted"): + print(f" #{i+1} [Message deleted]") + continue + reply_info = "" + if m.get("reply_to"): + for j, orig in enumerate(messages): + if orig["message_id"] == m["reply_to"]: + reply_info = f" (reply to #{j+1})" + break + else: + reply_info = " (reply to older message)" + image_info = "" + if m.get("image"): + img = m["image"] + image_info = f" [Image: {img.get('filename', '?')} ({_human_size(img.get('size', 0))})]" + file_info = "" + if m.get("file"): + fi = m["file"] + file_info = f" [File: {fi.get('filename', '?')} ({_human_size(fi.get('size', 0))})]" + read_info = "" + if m.get("sender") == client.username: + read_by = m.get("read_by", []) + member_map = {} + for mem in conv.get("members", []): + uid = mem.get("user_id") or mem.get("id", "") + if uid: + member_map[uid] = mem.get("username") or mem.get("email") or "?" + my_uid = client.session.get("user_id", "") if client.session else "" + others_read = [r for r in read_by if r.get("user_id") != my_uid] + if others_read: + names = ", ".join(member_map.get(r["user_id"], r["user_id"][:8]) for r in others_read) + read_info = f" [\u2713\u2713 Read by {names}]" + else: + read_info = " [\u2713 Sent]" + text = m.get("text", "") + print(f" #{i+1} {m['sender']}: {text}{image_info}{file_info}{reply_info}{read_info}") + + +async def _delete_message_action(client, messages): + del_choice = await prompt("Delete message #: ") + try: + del_idx = int(del_choice) - 1 + if not (0 <= del_idx < len(messages)): + print("[!] Invalid message number.") + return + except ValueError: + print("[!] Invalid number.") + return + ok, msg = await client.delete_message(messages[del_idx]["message_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + +async def _download_file_action(client, messages): + dl_choice = await prompt("Download from message #: ") + try: + dl_idx = int(dl_choice) - 1 + if not (0 <= dl_idx < len(messages)): + print("[!] Invalid message number.") + return + except ValueError: + print("[!] Invalid number.") + return + m = messages[dl_idx] + file_info = m.get("file") or m.get("image") + if not file_info: + print("[!] No file/image in this message.") + return + filename = file_info.get("filename", "download") + save_path = await prompt(f"Save as [{filename}]: ") + if not save_path: + save_path = filename + data = await client.download_file(file_info["file_id"], file_info) + if data: + with open(save_path, "wb") as f: + f.write(data) + print(f"[+] Saved to {save_path} ({_human_size(len(data))})") + else: + print("[!] Download failed.") + + +async def _invitations_menu(client): + invitations = await client.list_invitations() + if not invitations: + print("[*] No pending invitations.") + return + print("\nPending invitations:") + for i, inv in enumerate(invitations): + inv_name = inv.get("conversation_name") or inv.get("conversation_id", "")[:8] + invited_by = inv.get("invited_by_username") or inv.get("invited_by", "")[:8] + print(f" {i+1}) {inv_name} (invited by {invited_by})") + choice = await prompt("Select invitation (or empty to go back): ") + if not choice: + return + try: + idx = int(choice) - 1 + if not (0 <= idx < len(invitations)): + print("[!] Invalid selection.") + return + except ValueError: + print("[!] Invalid selection.") + return + inv = invitations[idx] + action = await prompt("(a)ccept or (d)ecline? ") + if action.lower().startswith("a"): + ok, msg = await client.accept_invitation(inv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + elif action.lower().startswith("d"): + ok, msg = await client.decline_invitation(inv["conversation_id"]) + print(f"[{'+'if ok else '!'}] {msg}") + + +def _print_profile(profile): + print(f"\n Username: {profile.get('username', '?')}") + print(f" Email: {profile.get('email', '?')}") + phone = profile.get("phone") + if phone: + print(f" Phone: {phone}") + location = profile.get("location") + if location: + print(f" Location: {location}") + has_avatar = profile.get("avatar_file") + print(f" Avatar: {'Yes' if has_avatar else 'No'}") + + +async def _my_profile_menu(client): + profile = await client.get_profile() + if not profile: + print("[!] Could not load profile.") + return + print("\n--- My Profile ---") + _print_profile(profile) + print(f" Phone visible: {profile.get('phone_visible', False)}") + print(f" Email visible: {profile.get('email_visible', False)}") + print(f" Location visible: {profile.get('location_visible', False)}") + + action = await prompt("\n(e)dit, (a)vatar upload, or empty to go back: ") + if not action: + return + if action.lower().startswith("e"): + print("[*] Leave fields empty to keep current value.") + phone = await prompt(f"Phone [{profile.get('phone', '')}]: ") + location = await prompt(f"Location [{profile.get('location', '')}]: ") + phone_vis = await prompt(f"Phone visible [{profile.get('phone_visible', False)}] (y/n): ") + email_vis = await prompt(f"Email visible [{profile.get('email_visible', False)}] (y/n): ") + loc_vis = await prompt(f"Location visible [{profile.get('location_visible', False)}] (y/n): ") + + fields = {} + if phone: + fields["phone"] = phone + if location: + fields["location"] = location + if phone_vis.lower() in ("y", "n"): + fields["phone_visible"] = phone_vis.lower() == "y" + if email_vis.lower() in ("y", "n"): + fields["email_visible"] = email_vis.lower() == "y" + if loc_vis.lower() in ("y", "n"): + fields["location_visible"] = loc_vis.lower() == "y" + if fields: + ok, msg = await client.update_profile(**fields) + print(f"[{'+'if ok else '!'}] {msg}") + else: + print("[*] No changes.") + elif action.lower().startswith("a"): + path = await prompt("Avatar image path: ") + if not path or not os.path.isfile(path): + print("[!] File not found.") + return + data = open(path, "rb").read() + ok, msg = await client.update_avatar(data) + print(f"[{'+'if ok else '!'}] {msg}") + + +async def _devices_menu(client): + resp = await client.send_and_recv("list_devices") + if resp.get("status") != "ok": + print(f"[!] {resp.get('data', {}).get('message', 'Failed')}") + return + devices = resp["data"].get("devices", []) + if not devices: + print("[*] No devices found.") + return + current_device_id = client.device_id + print("\nYour devices:") + for i, d in enumerate(devices): + name = d.get("device_name") or "Unnamed" + did = d.get("device_id", "?") + last_seen = d.get("last_seen_at", "?") + current = " (this device)" if did == current_device_id else "" + print(f" {i+1}) {name} — {did[:8]}... — last seen: {last_seen}{current}") + action = await prompt("\n(r)emove a device, or empty to go back: ") + if not action or not action.lower().startswith("r"): + return + choice = await prompt("Remove device #: ") + try: + idx = int(choice) - 1 + if not (0 <= idx < len(devices)): + print("[!] Invalid selection.") + return + except ValueError: + print("[!] Invalid selection.") + return + d = devices[idx] + if d.get("device_id") == current_device_id: + print("[!] Cannot remove current device.") + return + resp = await client.send_and_recv("remove_device", device_id=d["device_id"]) + if resp.get("status") == "ok": + print("[+] Device removed.") + else: + print(f"[!] {resp.get('data', {}).get('message', 'Failed')}") + + +async def notification_printer(client: ChatClient): + """Print real-time notifications with sender name.""" + while True: + notif = await client._notification_queue.get() + notif_type = notif.get("type", "") + data = notif.get("data", {}) + if notif_type == "messages_read": + continue # Silent - read receipts shown when reading messages + if notif_type == "session_reset": + from_uid = data.get("from_user_id", "")[:8] + client.handle_session_reset_notification( + data.get("from_user_id", ""), + data.get("from_device_id") or None, + ) + print(f"\n[*] Session with {from_uid}... was reset. New session will be created on next message.") + continue + if notif_type == "group_invitation": + inv_name = data.get("conversation_name", "?") + invited_by = data.get("invited_by_username", "?") + print(f"\n[*] New invitation to '{inv_name}' from {invited_by}. Use option 8 to accept/decline.") + continue + if notif_type in ("conversation_created", "member_added", "member_removed", "conversation_renamed"): + print(f"\n[*] Conversation updated ({notif_type}).") + continue + if notif_type in ("user_online", "user_offline", "online_users"): + continue # Silent for CLI + payload = client.decrypt_notification(data) + if payload: + print(f"\n[*] New message from {payload['sender']} in conversation {data.get('conversation_id', '?')[:8]}...") + # None = control message (sender key distribution), skip silently + + +async def main(): + setup_logging() + client = ChatClient() + await client.connect() + + client._listener_task = asyncio.create_task(client._background_listener()) + notif_task = asyncio.create_task(notification_printer(client)) + + print("=== Encrypted Chat Client ===") + print("1) Register") + print("2) Login") + print("3) Link new device (this device)") + print("4) Authorize new device (from this device)") + print("5) Rotate keys (revoke other devices)") + choice = await prompt("> ") + + if choice == "1": + username = await prompt("Username (display): ") + email = await prompt("Email: ") + password = await prompt("Password (for private key): ") + if not email or not password: + print("[!] Email and password required.") + await client.close() + return + ok, code_or_msg = await client.register(username, password, email=email) + if not ok: + print(f"[!] {code_or_msg}") + await client.close() + return + print(f"[*] Registration code: {code_or_msg}") + code = await prompt("Enter code: ") + ok2, msg2 = await client.confirm_registration(email, username, code) + print(f"[{'+'if ok2 else '!'}] {msg2}") + if ok2: + ok3, msg3 = await client.login(email, password) + print(f"[{'+'if ok3 else '!'}] {msg3}") + elif choice == "2": + email = await prompt("Email: ") + password = await prompt("Password (for private key): ") + ok, msg = await client.login(email, password) + print(f"[{'+'if ok else '!'}] {msg}") + elif choice == "3": + email = await prompt("Email: ") + password = await prompt("Password (for private key): ") + if not password: + print("[!] Password required.") + await client.close() + return + ok, code_or_msg = await client.pairing_start(email) + if not ok: + print(f"[!] {code_or_msg}") + await client.close() + return + code = code_or_msg + print(f"[*] Pairing code: {code}") + print("[*] Approve this code on an already-logged-in device.") + ok2, msg2 = await client.pairing_wait(code, email, password) + if not ok2: + print(f"[!] {msg2}") + await client.close() + return + print(f"[+] {msg2}") + ok3, msg3 = await client.login(email, password) + print(f"[{'+'if ok3 else '!'}] {msg3}") + elif choice == "4": + email = await prompt("Email: ") + password = await prompt("Password (for private key): ") + ok, msg = await client.login(email, password) + print(f"[{'+'if ok else '!'}] {msg}") + if not ok: + await client.close() + return + code = await prompt("Pairing code: ") + ok2, msg2 = await client.authorize_device(code) + print(f"[{'+'if ok2 else '!'}] {msg2}") + elif choice == "5": + email = await prompt("Email: ") + password = await prompt("Password (for private key): ") + ok, msg = await client.login(email, password) + print(f"[{'+'if ok else '!'}] {msg}") + if not ok: + await client.close() + return + confirm = await prompt("This will revoke other devices. Type 'YES' to continue: ") + if confirm != "YES": + print("[*] Cancelled.") + await client.close() + return + ok2, msg2 = await client.rotate_keys(client.username, password) + print(f"[{'+'if ok2 else '!'}] {msg2}") + else: + print("[!] Invalid choice.") + await client.close() + return + + if client.session: + await interactive_menu(client) + + notif_task.cancel() + await client.close() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\n[*] Bye.") diff --git a/zaloha/crypto_utils.py b/zaloha/crypto_utils.py new file mode 100644 index 0000000..7c1a0d9 --- /dev/null +++ b/zaloha/crypto_utils.py @@ -0,0 +1,812 @@ +"""Cryptographic utilities: Ed25519, X25519, AES-256-GCM, Double Ratchet, Sender Keys. + +RSA functions retained for login challenge-response only. +""" + +import hashlib +import hmac +import json +import os +import struct +import uuid +from dataclasses import dataclass, field + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + +# --------------------------------------------------------------------------- +# Password-based key encryption (M3: PBKDF2 600k iterations + AES-256-GCM) +# --------------------------------------------------------------------------- + +PBKDF2_ITERATIONS = 600_000 +_ECP1_MAGIC = b"ECP1" # Encrypted Chat PBKDF v1 format marker + + +def _encrypt_private_key(raw_bytes: bytes, password: bytes) -> bytes: + """Encrypt raw key bytes with PBKDF2-HMAC-SHA256 (600k iterations) + AES-256-GCM. + + Output format: MAGIC(4) + salt(16) + nonce(12) + ciphertext_with_tag(N+16) + """ + salt = os.urandom(16) + kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, + salt=salt, iterations=PBKDF2_ITERATIONS) + derived = kdf.derive(password) + nonce = os.urandom(12) + aesgcm = AESGCM(derived) + ct = aesgcm.encrypt(nonce, raw_bytes, _ECP1_MAGIC) # AAD = magic bytes + return _ECP1_MAGIC + salt + nonce + ct + + +def _decrypt_private_key(data: bytes, password: bytes) -> bytes: + """Decrypt key bytes encrypted with _encrypt_private_key.""" + if not data.startswith(_ECP1_MAGIC): + raise ValueError("Not ECP1 format") + salt = data[4:20] + nonce = data[20:32] + ct = data[32:] + kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, + salt=salt, iterations=PBKDF2_ITERATIONS) + derived = kdf.derive(password) + aesgcm = AESGCM(derived) + return aesgcm.decrypt(nonce, ct, _ECP1_MAGIC) + + +# --------------------------------------------------------------------------- +# RSA (login challenge-response ONLY) +# --------------------------------------------------------------------------- + +def generate_rsa_keypair(key_size: int = 4096) -> tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: + private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) + return private_key, private_key.public_key() + + +def serialize_private_key(key: rsa.RSAPrivateKey, password: bytes | None = None) -> bytes: + if password: + raw = key.private_bytes(serialization.Encoding.DER, serialization.PrivateFormat.PKCS8, + serialization.NoEncryption()) + return _encrypt_private_key(raw, password) + return key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, + serialization.NoEncryption()) + + +def serialize_public_key(key: rsa.RSAPublicKey) -> bytes: + return key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo) + + +def load_private_key(data: bytes, password: bytes | None = None) -> rsa.RSAPrivateKey: + if data.startswith(_ECP1_MAGIC): + raw = _decrypt_private_key(data, password) + return serialization.load_der_private_key(raw, password=None) + # Legacy PEM format (old BestAvailableEncryption or unencrypted) + return serialization.load_pem_private_key(data, password=password) + + +def load_public_key(pem: bytes) -> rsa.RSAPublicKey: + return serialization.load_pem_public_key(pem) + + +def rsa_sign(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: + return private_key.sign( + data, + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), + hashes.SHA256(), + ) + + +def rsa_verify(public_key: rsa.RSAPublicKey, signature: bytes, data: bytes) -> bool: + try: + public_key.verify( + signature, data, + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), + hashes.SHA256(), + ) + return True + except Exception: + return False + + +# --------------------------------------------------------------------------- +# AES-256-GCM (symmetric encryption — used by ratchet message keys & images) +# --------------------------------------------------------------------------- + +def aes_encrypt(plaintext: bytes, key: bytes | None = None) -> tuple[bytes, bytes, bytes, bytes]: + """Encrypt with AES-256-GCM. Returns (key, nonce, ciphertext, tag).""" + if key is None: + key = AESGCM.generate_key(bit_length=256) + nonce = os.urandom(12) + aesgcm = AESGCM(key) + ct_with_tag = aesgcm.encrypt(nonce, plaintext, None) + ciphertext = ct_with_tag[:-16] + tag = ct_with_tag[-16:] + return key, nonce, ciphertext, tag + + +def aes_decrypt(key: bytes, nonce: bytes, ciphertext: bytes, tag: bytes) -> bytes: + """Decrypt with AES-256-GCM.""" + aesgcm = AESGCM(key) + return aesgcm.decrypt(nonce, ciphertext + tag, None) + + +# --------------------------------------------------------------------------- +# Ed25519 Identity Keys +# --------------------------------------------------------------------------- + +def generate_identity_keypair() -> tuple[Ed25519PrivateKey, Ed25519PublicKey]: + priv = Ed25519PrivateKey.generate() + return priv, priv.public_key() + + +def serialize_ed25519_private(key: Ed25519PrivateKey, password: bytes | None = None) -> bytes: + if password: + raw = serialize_ed25519_private_raw(key) # 32 bytes + return _encrypt_private_key(raw, password) + return serialize_ed25519_private_raw(key) # 32 bytes, no password + + +def serialize_ed25519_private_raw(key: Ed25519PrivateKey) -> bytes: + """Serialize Ed25519 private key to 32 raw bytes (unencrypted).""" + return key.private_bytes(serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()) + + +def serialize_ed25519_public(key: Ed25519PublicKey) -> bytes: + """Serialize Ed25519 public key to 32 raw bytes.""" + return key.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + + +def load_ed25519_private(data: bytes, password: bytes | None = None) -> Ed25519PrivateKey: + if data.startswith(_ECP1_MAGIC): + raw = _decrypt_private_key(data, password) + return Ed25519PrivateKey.from_private_bytes(raw) + # Legacy formats: PEM (old BestAvailableEncryption) or 32-byte raw + if password: + return serialization.load_pem_private_key(data, password=password) + if len(data) == 32: + return Ed25519PrivateKey.from_private_bytes(data) + return serialization.load_pem_private_key(data, password=None) + + +def load_ed25519_public(data: bytes) -> Ed25519PublicKey: + if len(data) == 32: + return Ed25519PublicKey.from_public_bytes(data) + return serialization.load_pem_public_key(data) + + +def ed25519_sign(private_key: Ed25519PrivateKey, data: bytes) -> bytes: + """Sign data with Ed25519. Returns 64-byte signature.""" + return private_key.sign(data) + + +def ed25519_verify(public_key: Ed25519PublicKey, signature: bytes, data: bytes) -> bool: + """Verify Ed25519 signature.""" + try: + public_key.verify(signature, data) + return True + except Exception: + return False + + +# --------------------------------------------------------------------------- +# X25519 Key Exchange +# --------------------------------------------------------------------------- + +def generate_x25519_keypair() -> tuple[X25519PrivateKey, X25519PublicKey]: + priv = X25519PrivateKey.generate() + return priv, priv.public_key() + + +def serialize_x25519_private(key: X25519PrivateKey) -> bytes: + """Serialize X25519 private key to 32 raw bytes.""" + return key.private_bytes(serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption()) + + +def serialize_x25519_public(key: X25519PublicKey) -> bytes: + """Serialize X25519 public key to 32 raw bytes.""" + return key.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + + +def load_x25519_private(data: bytes) -> X25519PrivateKey: + return X25519PrivateKey.from_private_bytes(data) + + +def load_x25519_public(data: bytes) -> X25519PublicKey: + return X25519PublicKey.from_public_bytes(data) + + +def x25519_dh(private_key: X25519PrivateKey, public_key: X25519PublicKey) -> bytes: + """Perform X25519 Diffie-Hellman. Returns 32-byte shared secret.""" + return private_key.exchange(public_key) + + +# --------------------------------------------------------------------------- +# Ed25519 <-> X25519 conversion (for Identity Key dual use) +# --------------------------------------------------------------------------- + +def ed25519_private_to_x25519(ed_private: Ed25519PrivateKey) -> X25519PrivateKey: + """Derive X25519 private key from Ed25519 private key via RFC 7748 clamping.""" + raw = ed_private.private_bytes( + serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() + ) + # SHA-512 hash of the seed, take first 32 bytes, clamp per RFC 7748 + h = hashlib.sha512(raw).digest()[:32] + clamped = bytearray(h) + clamped[0] &= 248 + clamped[31] &= 127 + clamped[31] |= 64 + return X25519PrivateKey.from_private_bytes(bytes(clamped)) + + +def ed25519_public_to_x25519(ed_public: Ed25519PublicKey) -> X25519PublicKey: + """Derive X25519 public key from Ed25519 public key. + + Uses the cryptography library's internal conversion. For production use, + we compute the X25519 public key from the converted private key when possible. + For remote keys (where we don't have the private key), we use a pure-Python + implementation of the Ed25519->X25519 point conversion. + """ + # Montgomery u = (1 + y) / (1 - y) mod p, where p = 2^255 - 19 + raw = ed_public.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + y = int.from_bytes(raw, "little") + # Clear the sign bit + y &= (1 << 255) - 1 + p = (1 << 255) - 19 + # u = (1 + y) * inverse(1 - y) mod p + one_plus_y = (1 + y) % p + one_minus_y = (1 - y) % p + inv = pow(one_minus_y, p - 2, p) + u = (one_plus_y * inv) % p + x25519_bytes = u.to_bytes(32, "little") + return X25519PublicKey.from_public_bytes(x25519_bytes) + + +# --------------------------------------------------------------------------- +# HKDF +# --------------------------------------------------------------------------- + +_HKDF_INFO_SELF = b"EncryptedChat_SelfKey" +_HKDF_INFO_RK = b"EncryptedChat_RootKey" + + +def derive_self_encryption_key(identity_private: Ed25519PrivateKey) -> bytes: + """Derive a static AES-256 key from identity key for encrypting own sent messages. + + This is NOT a ratchet — it's a static key. Safe because only the owner + has the identity private key, and self-copies don't need forward secrecy. + """ + raw = identity_private.private_bytes( + serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() + ) + return hkdf_derive(raw, salt=b"self_encryption", info=_HKDF_INFO_SELF, length=32) + + +_HKDF_INFO_LOCAL = b"EncryptedChat_LocalStorage" + + +def derive_local_storage_key(identity_private: Ed25519PrivateKey) -> bytes: + """Derive AES-256 key for encrypting local session/sender key files.""" + raw = identity_private.private_bytes( + serialization.Encoding.Raw, serialization.PrivateFormat.Raw, serialization.NoEncryption() + ) + return hkdf_derive(raw, salt=b"local_storage", info=_HKDF_INFO_LOCAL, length=32) + + +_HKDF_INFO_CK_MSG = b"\x01" # chain key -> message key +_HKDF_INFO_CK_NEXT = b"\x02" # chain key -> next chain key + + +def hkdf_derive(input_key: bytes, salt: bytes, info: bytes, length: int = 32) -> bytes: + return HKDF(algorithm=hashes.SHA256(), length=length, salt=salt, info=info).derive(input_key) + + +def kdf_rk(root_key: bytes, dh_output: bytes) -> tuple[bytes, bytes]: + """Root key KDF. Returns (new_root_key, chain_key). + + Uses HKDF with the root key as salt and DH output as input key material. + Derives 64 bytes: first 32 = new root key, last 32 = chain key. + """ + derived = hkdf_derive(dh_output, salt=root_key, info=_HKDF_INFO_RK, length=64) + return derived[:32], derived[32:] + + +def kdf_ck(chain_key: bytes) -> tuple[bytes, bytes]: + """Chain key KDF. Returns (new_chain_key, message_key). + + Uses HMAC-SHA256: + message_key = HMAC(chain_key, 0x01) + new_chain_key = HMAC(chain_key, 0x02) + """ + message_key = hmac.new(chain_key, _HKDF_INFO_CK_MSG, hashlib.sha256).digest() + new_chain_key = hmac.new(chain_key, _HKDF_INFO_CK_NEXT, hashlib.sha256).digest() + return new_chain_key, message_key + + +# --------------------------------------------------------------------------- +# X3DH +# --------------------------------------------------------------------------- + +_X3DH_INFO = b"EncryptedChat_X3DH" + + +def generate_signed_prekey(identity_private: Ed25519PrivateKey) -> dict: + """Generate a signed pre-key (SPK). + + Returns {private: X25519PrivateKey, public: X25519PublicKey, signature: bytes, id: str}. + """ + spk_priv, spk_pub = generate_x25519_keypair() + spk_pub_bytes = serialize_x25519_public(spk_pub) + signature = ed25519_sign(identity_private, spk_pub_bytes) + return { + "private": spk_priv, + "public": spk_pub, + "signature": signature, + "id": str(uuid.uuid4()), + } + + +def generate_one_time_prekeys(count: int = 50) -> list[dict]: + """Generate a batch of one-time pre-keys. + + Returns [{private: X25519PrivateKey, public: X25519PublicKey, id: str}, ...]. + """ + result = [] + for _ in range(count): + priv, pub = generate_x25519_keypair() + result.append({"private": priv, "public": pub, "id": str(uuid.uuid4())}) + return result + + +def x3dh_initiate( + ik_private_ed: Ed25519PrivateKey, + ik_public_remote_ed: Ed25519PublicKey, + spk_remote: X25519PublicKey, + spk_signature: bytes, + opk_remote: X25519PublicKey | None = None, +) -> tuple[bytes, X25519PrivateKey, X25519PublicKey]: + """Initiator side of X3DH. + + Args: + ik_private_ed: Our Ed25519 identity private key + ik_public_remote_ed: Remote Ed25519 identity public key + spk_remote: Remote signed pre-key (X25519 public) + spk_signature: Ed25519 signature of spk_remote by ik_public_remote_ed + opk_remote: Optional one-time pre-key (X25519 public) + + Returns: + (shared_secret, ephemeral_private, ephemeral_public) + """ + # Verify SPK signature + spk_remote_bytes = serialize_x25519_public(spk_remote) + if not ed25519_verify(ik_public_remote_ed, spk_signature, spk_remote_bytes): + raise ValueError("Invalid SPK signature") + + # Convert identity keys to X25519 + ik_x25519_private = ed25519_private_to_x25519(ik_private_ed) + ik_x25519_remote = ed25519_public_to_x25519(ik_public_remote_ed) + + # Generate ephemeral keypair + ek_priv, ek_pub = generate_x25519_keypair() + + # DH computations + dh1 = x25519_dh(ik_x25519_private, spk_remote) # IK_A, SPK_B + dh2 = x25519_dh(ek_priv, ik_x25519_remote) # EK_A, IK_B + dh3 = x25519_dh(ek_priv, spk_remote) # EK_A, SPK_B + + dh_concat = dh1 + dh2 + dh3 + if opk_remote is not None: + dh4 = x25519_dh(ek_priv, opk_remote) # EK_A, OPK_B + dh_concat += dh4 + + # Derive shared secret + shared_secret = hkdf_derive(dh_concat, salt=b"\x00" * 32, info=_X3DH_INFO, length=32) + return shared_secret, ek_priv, ek_pub + + +def x3dh_respond( + ik_private_ed: Ed25519PrivateKey, + spk_private: X25519PrivateKey, + ik_remote_ed: Ed25519PublicKey, + ek_remote: X25519PublicKey, + opk_private: X25519PrivateKey | None = None, +) -> bytes: + """Responder side of X3DH. + + Args: + ik_private_ed: Our Ed25519 identity private key + spk_private: Our signed pre-key private (X25519) + ik_remote_ed: Remote Ed25519 identity public key + ek_remote: Remote ephemeral key (X25519 public) + opk_private: Our one-time pre-key private (X25519), if used + + Returns: + shared_secret (32 bytes) + """ + ik_x25519_private = ed25519_private_to_x25519(ik_private_ed) + ik_x25519_remote = ed25519_public_to_x25519(ik_remote_ed) + + dh1 = x25519_dh(spk_private, ik_x25519_remote) # SPK_B, IK_A + dh2 = x25519_dh(ik_x25519_private, ek_remote) # IK_B, EK_A + dh3 = x25519_dh(spk_private, ek_remote) # SPK_B, EK_A + + dh_concat = dh1 + dh2 + dh3 + if opk_private is not None: + dh4 = x25519_dh(opk_private, ek_remote) # OPK_B, EK_A + dh_concat += dh4 + + shared_secret = hkdf_derive(dh_concat, salt=b"\x00" * 32, info=_X3DH_INFO, length=32) + return shared_secret + + +# --------------------------------------------------------------------------- +# Double Ratchet +# --------------------------------------------------------------------------- + +MAX_SKIP = 256 # max messages to skip in a single chain (out-of-order tolerance) + + +@dataclass +class RatchetHeader: + """Header sent with each ratchet message.""" + dh_pub: bytes # sender's current ratchet public key (32 bytes) + n: int # message number in current sending chain + pn: int # number of messages in previous sending chain + + def serialize(self) -> bytes: + return json.dumps({ + "dh_pub": serialize_x25519_public(load_x25519_public(self.dh_pub)).hex() + if isinstance(self.dh_pub, bytes) else serialize_x25519_public(self.dh_pub).hex(), + "n": self.n, + "pn": self.pn, + }).encode() + + def to_dict(self) -> dict: + pub_hex = self.dh_pub.hex() if isinstance(self.dh_pub, bytes) else \ + serialize_x25519_public(self.dh_pub).hex() + return {"dh_pub": pub_hex, "n": self.n, "pn": self.pn} + + @classmethod + def from_dict(cls, d: dict) -> "RatchetHeader": + return cls(dh_pub=bytes.fromhex(d["dh_pub"]), n=d["n"], pn=d["pn"]) + + +class DoubleRatchet: + """Signal Double Ratchet implementation.""" + + def __init__(self): + self.dh_pair: tuple[X25519PrivateKey, X25519PublicKey] | None = None + self.dh_remote: X25519PublicKey | None = None + self.root_key: bytes = b"" + self.send_chain_key: bytes | None = None + self.recv_chain_key: bytes | None = None + self.send_n: int = 0 + self.recv_n: int = 0 + self.prev_send_n: int = 0 + # (dh_pub_hex, n) -> message_key for out-of-order messages + self.skipped: dict[tuple[str, int], bytes] = {} + + @classmethod + def init_alice(cls, shared_secret: bytes, bob_spk_pub: X25519PublicKey) -> "DoubleRatchet": + """Initialize as initiator (Alice) after X3DH. + + Alice performs the first DH ratchet step immediately. + """ + ratchet = cls() + ratchet.dh_pair = generate_x25519_keypair() + ratchet.dh_remote = bob_spk_pub + + # Perform DH ratchet to derive send chain + dh_output = x25519_dh(ratchet.dh_pair[0], ratchet.dh_remote) + ratchet.root_key, ratchet.send_chain_key = kdf_rk(shared_secret, dh_output) + ratchet.recv_chain_key = None + ratchet.send_n = 0 + ratchet.recv_n = 0 + ratchet.prev_send_n = 0 + return ratchet + + @classmethod + def init_bob(cls, shared_secret: bytes, spk_pair: tuple[X25519PrivateKey, X25519PublicKey]) -> "DoubleRatchet": + """Initialize as responder (Bob) after X3DH. + + Bob uses his SPK as the initial ratchet key pair. + """ + ratchet = cls() + ratchet.dh_pair = spk_pair + ratchet.root_key = shared_secret + ratchet.send_chain_key = None + ratchet.recv_chain_key = None + ratchet.send_n = 0 + ratchet.recv_n = 0 + ratchet.prev_send_n = 0 + return ratchet + + def encrypt(self, plaintext: bytes) -> dict: + """Encrypt a message. + + Returns {header: {dh_pub, n, pn}, ciphertext: bytes, nonce: bytes}. + """ + if self.send_chain_key is None: + raise RuntimeError("Send chain not initialized") + + self.send_chain_key, message_key = kdf_ck(self.send_chain_key) + + header = RatchetHeader( + dh_pub=serialize_x25519_public(self.dh_pair[1]), + n=self.send_n, + pn=self.prev_send_n, + ) + + # Encrypt with AES-256-GCM using the message key + nonce = os.urandom(12) + aesgcm = AESGCM(message_key) + # Include header as AAD to bind ciphertext to header + aad = header.serialize() + ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad) + + self.send_n += 1 + + return { + "header": header.to_dict(), + "ciphertext": ct_with_tag, # includes 16-byte tag + "nonce": nonce, + } + + def decrypt(self, header_dict: dict, ciphertext: bytes, nonce: bytes) -> bytes: + """Decrypt a message. Handles DH ratchet step if new dh_pub. + + State is snapshotted before modification and restored on failure (M9 fix). + """ + header = RatchetHeader.from_dict(header_dict) + remote_dh_pub_bytes = header.dh_pub + + # Check if this is from a skipped message (no state modification needed) + skip_key = (remote_dh_pub_bytes.hex(), header.n) + if skip_key in self.skipped: + mk = self.skipped.pop(skip_key) + aad = header.serialize() + aesgcm = AESGCM(mk) + try: + return aesgcm.decrypt(nonce, ciphertext, aad) + except Exception: + self.skipped[skip_key] = mk # restore skipped key + raise + + # Snapshot state before modifications + snap = self._snapshot() + + try: + remote_dh_pub = load_x25519_public(remote_dh_pub_bytes) + current_remote_bytes = serialize_x25519_public(self.dh_remote) if self.dh_remote else None + + if current_remote_bytes is None or remote_dh_pub_bytes != current_remote_bytes: + # New DH ratchet step + self._skip_messages(header.pn) + self._dh_ratchet(remote_dh_pub) + + self._skip_messages(header.n) + + # Derive message key from receive chain + self.recv_chain_key, mk = kdf_ck(self.recv_chain_key) + self.recv_n += 1 + + aad = header.serialize() + aesgcm = AESGCM(mk) + return aesgcm.decrypt(nonce, ciphertext, aad) + except Exception: + self._restore(snap) + raise + + def _snapshot(self) -> dict: + """Capture mutable state for rollback on decrypt failure.""" + return { + "dh_pair": self.dh_pair, + "dh_remote": self.dh_remote, + "root_key": self.root_key, + "send_chain_key": self.send_chain_key, + "recv_chain_key": self.recv_chain_key, + "send_n": self.send_n, + "recv_n": self.recv_n, + "prev_send_n": self.prev_send_n, + "skipped": dict(self.skipped), + } + + def _restore(self, snap: dict): + """Restore state from snapshot.""" + self.dh_pair = snap["dh_pair"] + self.dh_remote = snap["dh_remote"] + self.root_key = snap["root_key"] + self.send_chain_key = snap["send_chain_key"] + self.recv_chain_key = snap["recv_chain_key"] + self.send_n = snap["send_n"] + self.recv_n = snap["recv_n"] + self.prev_send_n = snap["prev_send_n"] + self.skipped = snap["skipped"] + + def _skip_messages(self, until: int): + """Skip ahead in the receive chain, storing message keys for out-of-order delivery.""" + if self.recv_chain_key is None: + return + if until - self.recv_n > MAX_SKIP: + raise RuntimeError(f"Too many skipped messages ({until - self.recv_n} > {MAX_SKIP})") + while self.recv_n < until: + self.recv_chain_key, mk = kdf_ck(self.recv_chain_key) + remote_hex = serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else "" + self.skipped[(remote_hex, self.recv_n)] = mk + self.recv_n += 1 + + def _dh_ratchet(self, remote_dh_pub: X25519PublicKey): + """Perform a DH ratchet step: update receive chain, generate new DH pair, update send chain.""" + self.prev_send_n = self.send_n + self.send_n = 0 + self.recv_n = 0 + self.dh_remote = remote_dh_pub + + # Derive new receive chain key + dh_output = x25519_dh(self.dh_pair[0], self.dh_remote) + self.root_key, self.recv_chain_key = kdf_rk(self.root_key, dh_output) + + # Generate new DH pair and derive new send chain key + self.dh_pair = generate_x25519_keypair() + dh_output = x25519_dh(self.dh_pair[0], self.dh_remote) + self.root_key, self.send_chain_key = kdf_rk(self.root_key, dh_output) + + def export_state(self) -> bytes: + """Serialize full ratchet state for persistent storage.""" + state = { + "dh_priv": serialize_x25519_private(self.dh_pair[0]).hex() if self.dh_pair else None, + "dh_pub": serialize_x25519_public(self.dh_pair[1]).hex() if self.dh_pair else None, + "dh_remote": serialize_x25519_public(self.dh_remote).hex() if self.dh_remote else None, + "root_key": self.root_key.hex(), + "send_ck": self.send_chain_key.hex() if self.send_chain_key else None, + "recv_ck": self.recv_chain_key.hex() if self.recv_chain_key else None, + "send_n": self.send_n, + "recv_n": self.recv_n, + "prev_send_n": self.prev_send_n, + "skipped": {f"{k[0]}:{k[1]}": v.hex() for k, v in self.skipped.items()}, + } + return json.dumps(state).encode() + + @classmethod + def import_state(cls, data: bytes) -> "DoubleRatchet": + """Deserialize ratchet state.""" + state = json.loads(data) + r = cls() + if state["dh_priv"] and state["dh_pub"]: + priv = load_x25519_private(bytes.fromhex(state["dh_priv"])) + pub = load_x25519_public(bytes.fromhex(state["dh_pub"])) + r.dh_pair = (priv, pub) + if state["dh_remote"]: + r.dh_remote = load_x25519_public(bytes.fromhex(state["dh_remote"])) + r.root_key = bytes.fromhex(state["root_key"]) + r.send_chain_key = bytes.fromhex(state["send_ck"]) if state["send_ck"] else None + r.recv_chain_key = bytes.fromhex(state["recv_ck"]) if state["recv_ck"] else None + r.send_n = state["send_n"] + r.recv_n = state["recv_n"] + r.prev_send_n = state["prev_send_n"] + r.skipped = {} + for k_str, v_hex in state.get("skipped", {}).items(): + parts = k_str.rsplit(":", 1) + dh_hex = parts[0] + n = int(parts[1]) + r.skipped[(dh_hex, n)] = bytes.fromhex(v_hex) + return r + + +# --------------------------------------------------------------------------- +# Sender Keys (group messaging) +# --------------------------------------------------------------------------- + +class SenderKeyState: + """Sender key chain for group messaging. + + Each sender in a group has their own sender key chain. + Other group members receive the initial sender_key via pairwise Double Ratchet. + """ + + def __init__(self, sender_key: bytes | None = None): + if sender_key is None: + sender_key = os.urandom(32) + self.sender_key = sender_key + self.chain_id = hashlib.sha256(sender_key).digest() + self.chain_key = hkdf_derive(sender_key, salt=b"\x00" * 32, info=b"SenderKeyChain", length=32) + self.n = 0 + # For receivers: track chain state to allow fast-forward + self._known_keys: dict[int, bytes] = {} + + def encrypt(self, plaintext: bytes) -> dict: + """Encrypt with current chain key. + + Returns {chain_id: hex, n: int, ciphertext: bytes, nonce: bytes}. + """ + self.chain_key, message_key = kdf_ck(self.chain_key) + nonce = os.urandom(12) + aesgcm = AESGCM(message_key) + # AAD includes chain_id and message number + aad = self.chain_id + struct.pack(">I", self.n) + ct_with_tag = aesgcm.encrypt(nonce, plaintext, aad) + result = { + "chain_id": self.chain_id.hex(), + "n": self.n, + "ciphertext": ct_with_tag, + "nonce": nonce, + } + self.n += 1 + return result + + MAX_SENDER_KEY_SKIP = 256 + + def decrypt(self, chain_id_hex: str, n: int, ciphertext: bytes, nonce: bytes) -> bytes: + """Decrypt a group message. Fast-forwards the chain if needed. + + State is snapshotted before modification and restored on failure (M9 fix). + """ + chain_id = bytes.fromhex(chain_id_hex) + if chain_id != self.chain_id: + raise ValueError("Chain ID mismatch") + + if n - self.n > self.MAX_SENDER_KEY_SKIP: + raise ValueError(f"Sender key skip too large ({n - self.n} > {self.MAX_SENDER_KEY_SKIP})") + + # Snapshot before fast-forward + snap_chain_key = self.chain_key + snap_n = self.n + snap_known = dict(self._known_keys) + + try: + # Fast-forward the chain to reach message n + while self.n <= n: + self.chain_key, mk = kdf_ck(self.chain_key) + self._known_keys[self.n] = mk + self.n += 1 + + mk = self._known_keys.pop(n, None) + if mk is None: + raise ValueError(f"Message key for n={n} not available (already consumed)") + + aad = chain_id + struct.pack(">I", n) + aesgcm = AESGCM(mk) + return aesgcm.decrypt(nonce, ciphertext, aad) + except Exception: + self.chain_key = snap_chain_key + self.n = snap_n + self._known_keys = snap_known + raise + + def export_key(self) -> bytes: + """Export sender key for distribution to group members. + + Contains everything needed to initialize a receiving SenderKeyState. + """ + return json.dumps({ + "sender_key": self.sender_key.hex(), + }).encode() + + def export_state(self) -> bytes: + """Serialize full state for persistent storage.""" + return json.dumps({ + "sender_key": self.sender_key.hex(), + "chain_id": self.chain_id.hex(), + "chain_key": self.chain_key.hex(), + "n": self.n, + "known_keys": {str(k): v.hex() for k, v in self._known_keys.items()}, + }).encode() + + @classmethod + def import_state(cls, data: bytes) -> "SenderKeyState": + state = json.loads(data) + obj = cls.__new__(cls) + obj.sender_key = bytes.fromhex(state["sender_key"]) + obj.chain_id = bytes.fromhex(state["chain_id"]) + obj.chain_key = bytes.fromhex(state["chain_key"]) + obj.n = state["n"] + obj._known_keys = {int(k): bytes.fromhex(v) for k, v in state.get("known_keys", {}).items()} + return obj + + @classmethod + def from_key(cls, exported_key: bytes) -> "SenderKeyState": + """Initialize a receiving SenderKeyState from an exported key.""" + data = json.loads(exported_key) + return cls(sender_key=bytes.fromhex(data["sender_key"])) diff --git a/zaloha/db.py b/zaloha/db.py new file mode 100644 index 0000000..91079aa --- /dev/null +++ b/zaloha/db.py @@ -0,0 +1,1293 @@ +"""MySQL database layer for the encrypted chat server.""" + +import os +import uuid + +import mysql.connector +from dotenv import load_dotenv + +from crypto_utils import ( + generate_identity_keypair, + serialize_ed25519_public, + generate_signed_prekey, + serialize_x25519_public, + generate_one_time_prekeys, +) + +load_dotenv() + +# Sentinel device_id for self-encrypted copies and legacy (pre-multi-device) rows +SELF_DEVICE_ID = "00000000-0000-0000-0000-000000000000" + + +def get_connection(): + """Create a new MySQL connection from environment variables.""" + return mysql.connector.connect( + host=os.getenv("MYSQL_HOST", "localhost"), + port=int(os.getenv("MYSQL_PORT", "3306")), + user=os.getenv("MYSQL_USER", "root"), + password=os.getenv("MYSQL_PASSWORD", ""), + database=os.getenv("MYSQL_DATABASE", "encrypted_chat"), + ) + + +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +# --- Devices --- + +def create_device(user_id: str, device_name: str | None = None) -> str: + """Create a new device for a user. Returns device_id.""" + conn = get_connection() + try: + cursor = conn.cursor() + device_id = generate_uuid() + cursor.execute( + "INSERT INTO devices (id, user_id, device_name) VALUES (%s, %s, %s)", + (device_id, user_id, device_name), + ) + conn.commit() + return device_id + finally: + conn.close() + + +def get_user_devices(user_id: str) -> list[dict]: + """Get all devices for a user.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, user_id, device_name, created_at, last_seen_at " + "FROM devices WHERE user_id = %s ORDER BY created_at", + (user_id,), + ) + return cursor.fetchall() + finally: + conn.close() + + +def get_device(device_id: str) -> dict | None: + """Get a single device by ID.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, user_id, device_name, created_at, last_seen_at " + "FROM devices WHERE id = %s", + (device_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def update_device_last_seen(device_id: str): + """Update last_seen_at timestamp for a device.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE devices SET last_seen_at = NOW() WHERE id = %s", + (device_id,), + ) + conn.commit() + finally: + conn.close() + + +def delete_device(device_id: str): + """Delete a device. CASCADE removes its prekeys.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM devices WHERE id = %s", (device_id,)) + # Also clean up prekeys explicitly for device_id column + cursor.execute("DELETE FROM signed_prekeys WHERE device_id = %s", (device_id,)) + cursor.execute("DELETE FROM one_time_prekeys WHERE device_id = %s", (device_id,)) + conn.commit() + finally: + conn.close() + + +# --- Users --- + +def create_user(username: str, email: str, rsa_public_key_pem: str, identity_key: bytes) -> str: + """Register a new user. Returns user ID.""" + conn = get_connection() + try: + cursor = conn.cursor() + user_id = generate_uuid() + cursor.execute( + "INSERT INTO users (id, username, email, rsa_public_key, identity_key) " + "VALUES (%s, %s, %s, %s, %s)", + (user_id, username, email, rsa_public_key_pem, identity_key), + ) + conn.commit() + return user_id + finally: + conn.close() + + +def get_user_by_email(email: str) -> dict | None: + """Get user by email.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE email = %s", + (email,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def get_user_by_id(user_id: str) -> dict | None: + """Get user by ID.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, username, rsa_public_key, email, identity_key FROM users WHERE id = %s", + (user_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def get_user_contacts(user_id: str) -> list[str]: + """Get all user IDs that share at least one conversation with the given user.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT DISTINCT cm2.user_id " + "FROM conversation_members cm1 " + "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " + "WHERE cm1.user_id = %s AND cm2.user_id != %s", + (user_id, user_id), + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + +def update_user_rsa_key(user_id: str, rsa_public_key_pem: str): + """Update user's RSA public key (for login).""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("UPDATE users SET rsa_public_key = %s WHERE id = %s", (rsa_public_key_pem, user_id)) + conn.commit() + finally: + conn.close() + + +# --- Pre-keys --- + +def store_signed_prekey(user_id: str, spk_id: str, public_key: bytes, signature: bytes, + device_id: str | None = None): + """Store (or replace) a signed pre-key for a user's device.""" + conn = get_connection() + try: + cursor = conn.cursor() + # Remove old SPKs for this user+device + if device_id: + cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id = %s", + (user_id, device_id)) + else: + cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s AND device_id IS NULL", + (user_id,)) + cursor.execute( + "INSERT INTO signed_prekeys (id, user_id, device_id, public_key, signature) " + "VALUES (%s, %s, %s, %s, %s)", + (spk_id, user_id, device_id, public_key, signature), + ) + conn.commit() + finally: + conn.close() + + +def get_signed_prekey(user_id: str, device_id: str | None = None) -> dict | None: + """Get the current signed pre-key for a user (optionally per device).""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + if device_id: + cursor.execute( + "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys " + "WHERE user_id = %s AND device_id = %s " + "ORDER BY created_at DESC LIMIT 1", + (user_id, device_id), + ) + else: + cursor.execute( + "SELECT id, public_key, signature, device_id, created_at FROM signed_prekeys " + "WHERE user_id = %s ORDER BY created_at DESC LIMIT 1", + (user_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def store_one_time_prekeys(user_id: str, prekeys: list[dict], device_id: str | None = None): + """Store a batch of one-time pre-keys. Each dict has {id, public_key (bytes)}.""" + conn = get_connection() + try: + cursor = conn.cursor() + for pk in prekeys: + cursor.execute( + "INSERT INTO one_time_prekeys (id, user_id, device_id, public_key) " + "VALUES (%s, %s, %s, %s)", + (pk["id"], user_id, device_id, pk["public_key"]), + ) + conn.commit() + finally: + conn.close() + + +def consume_one_time_prekey(user_id: str, device_id: str | None = None) -> dict | None: + """Atomically consume one OTP: SELECT FOR UPDATE + DELETE. + Returns {id, public_key} or None.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + conn.start_transaction() + if device_id: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE", + (user_id, device_id), + ) + else: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s LIMIT 1 FOR UPDATE", + (user_id,), + ) + row = cursor.fetchone() + if row: + cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (row["id"],)) + conn.commit() + return row + except Exception: + conn.rollback() + raise + finally: + conn.close() + + +def count_one_time_prekeys(user_id: str, device_id: str | None = None) -> int: + """Count remaining OTPs for a user (optionally per device).""" + conn = get_connection() + try: + cursor = conn.cursor() + if device_id: + cursor.execute( + "SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s AND device_id = %s", + (user_id, device_id), + ) + else: + cursor.execute("SELECT COUNT(*) FROM one_time_prekeys WHERE user_id = %s", (user_id,)) + return cursor.fetchone()[0] + finally: + conn.close() + + +def get_key_bundle(user_id: str) -> dict | None: + """Get complete key bundle for X3DH (single device — legacy compat). + + Returns {identity_key, signed_prekey_id, signed_prekey, spk_signature, + one_time_prekey_id, one_time_prekey} or None. + OTP is consumed atomically. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + # Get user identity key + cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,)) + user = cursor.fetchone() + if not user: + return None + + # Get signed prekey + cursor.execute( + "SELECT id, public_key, signature, device_id FROM signed_prekeys WHERE user_id = %s " + "ORDER BY created_at DESC LIMIT 1", + (user_id,), + ) + spk = cursor.fetchone() + if not spk: + return None + + # Consume one OTP (may be None) — use transaction for atomicity (H12 fix) + conn.start_transaction() + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys WHERE user_id = %s LIMIT 1 FOR UPDATE", + (user_id,), + ) + opk = cursor.fetchone() + if opk: + cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],)) + conn.commit() + + result = { + "identity_key": user["identity_key"], + "signed_prekey_id": spk["id"], + "signed_prekey": spk["public_key"], + "spk_signature": spk["signature"], + } + if opk: + result["one_time_prekey_id"] = opk["id"] + result["one_time_prekey"] = opk["public_key"] + return result + except Exception: + try: + conn.rollback() + except Exception: + pass + raise + finally: + conn.close() + + +def get_key_bundles_for_user(user_id: str) -> dict | None: + """Get key bundles for ALL devices of a user. Returns + {identity_key, device_bundles: [{device_id, signed_prekey_id, signed_prekey_pub, + spk_signature, opk_id, opk_pub}]} or None. + Consumes one OPK per device atomically. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + # Get user identity key + cursor.execute("SELECT identity_key FROM users WHERE id = %s", (user_id,)) + user = cursor.fetchone() + if not user: + return None + + # Get all signed prekeys (one per device, most recent) + cursor.execute( + "SELECT id, public_key, signature, device_id FROM signed_prekeys " + "WHERE user_id = %s ORDER BY created_at DESC", + (user_id,), + ) + all_spks = cursor.fetchall() + if not all_spks: + return None + + # De-duplicate: keep only the most recent SPK per device_id + seen_devices = set() + spks_by_device = [] + for spk in all_spks: + dev = spk.get("device_id") or "__legacy__" + if dev not in seen_devices: + seen_devices.add(dev) + spks_by_device.append(spk) + + device_bundles = [] + # Commit the implicit transaction from the read-only queries above + # so we can start an explicit transaction for atomic OPK consumption. + conn.commit() + conn.start_transaction() + for spk in spks_by_device: + dev_id = spk.get("device_id") + # Consume one OPK for this device + if dev_id: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s AND device_id = %s LIMIT 1 FOR UPDATE", + (user_id, dev_id), + ) + else: + cursor.execute( + "SELECT id, public_key FROM one_time_prekeys " + "WHERE user_id = %s AND device_id IS NULL LIMIT 1 FOR UPDATE", + (user_id,), + ) + opk = cursor.fetchone() + if opk: + cursor.execute("DELETE FROM one_time_prekeys WHERE id = %s", (opk["id"],)) + + bundle = { + "device_id": dev_id, + "signed_prekey_id": spk["id"], + "signed_prekey_pub": spk["public_key"], + "spk_signature": spk["signature"], + } + if opk: + bundle["opk_id"] = opk["id"] + bundle["opk_pub"] = opk["public_key"] + device_bundles.append(bundle) + conn.commit() + + return { + "identity_key": user["identity_key"], + "device_bundles": device_bundles, + } + except Exception: + try: + conn.rollback() + except Exception: + pass + raise + finally: + conn.close() + + +# --- Conversations --- + +def create_conversation(member_user_ids: list[str], joined_at=None, name=None, created_by=None) -> str: + conn = get_connection() + try: + cursor = conn.cursor() + conv_id = generate_uuid() + cursor.execute("INSERT INTO conversations (id, name, created_by) VALUES (%s, %s, %s)", + (conv_id, name, created_by)) + for uid in member_user_ids: + cursor.execute( + "INSERT INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)", + (conv_id, uid, joined_at), + ) + conn.commit() + return conv_id + finally: + conn.close() + + +def add_conversation_member(conversation_id: str, user_id: str, joined_at=None): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT IGNORE INTO conversation_members (conversation_id, user_id, joined_at) VALUES (%s, %s, %s)", + (conversation_id, user_id, joined_at), + ) + conn.commit() + finally: + conn.close() + + +def remove_conversation_member(conversation_id: str, user_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + conn.commit() + finally: + conn.close() + + +def count_conversation_members(conversation_id: str) -> int: + """Count members in a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT COUNT(*) FROM conversation_members WHERE conversation_id = %s", + (conversation_id,), + ) + return cursor.fetchone()[0] + finally: + conn.close() + + +def get_conversation_file_ids(conversation_id: str) -> list[str]: + """Get all file IDs (images + files) uploaded to a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT file_id FROM image_uploads WHERE conversation_id = %s", + (conversation_id,), + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + +def delete_conversation(conversation_id: str): + """Delete a conversation entirely. CASCADE cleans up members, messages, etc.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM conversations WHERE id = %s", (conversation_id,)) + conn.commit() + finally: + conn.close() + + +def get_conversation_members(conversation_id: str) -> list[dict]: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT u.id, u.username, u.email, u.identity_key FROM conversation_members cm " + "JOIN users u ON cm.user_id = u.id " + "WHERE cm.conversation_id = %s", + (conversation_id,), + ) + return cursor.fetchall() + finally: + conn.close() + + +def find_direct_conversation(user_id_a: str, user_id_b: str) -> str | None: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT cm1.conversation_id FROM conversation_members cm1 " + "JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id " + "WHERE cm1.user_id = %s AND cm2.user_id = %s " + "AND (SELECT COUNT(*) FROM conversation_members cm3 " + " WHERE cm3.conversation_id = cm1.conversation_id) = 2 " + "LIMIT 1", + (user_id_a, user_id_b), + ) + row = cursor.fetchone() + return row[0] if row else None + finally: + conn.close() + + +def update_conversation_creator(conversation_id: str, new_creator_id: str): + """Transfer group creator role to another member.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE conversations SET created_by = %s WHERE id = %s", + (new_creator_id, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def get_conversation(conversation_id: str) -> dict | None: + """Get conversation by ID.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id, created_at, name, created_by, avatar_file FROM conversations WHERE id = %s", + (conversation_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def update_conversation_avatar(conversation_id: str, avatar_file: str): + """Set avatar file for a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE conversations SET avatar_file = %s WHERE id = %s", + (avatar_file, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def update_conversation_name(conversation_id: str, name: str): + """Update the name of a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE conversations SET name = %s WHERE id = %s", + (name, conversation_id), + ) + conn.commit() + finally: + conn.close() + + +def is_conversation_member(conversation_id: str, user_id: str) -> bool: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM conversation_members WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + return cursor.fetchone() is not None + finally: + conn.close() + + +def list_user_conversations(user_id: str) -> list[dict]: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT c.id, c.created_at, c.name, c.created_by, c.avatar_file FROM conversations c " + "JOIN conversation_members cm ON c.id = cm.conversation_id " + "WHERE cm.user_id = %s ORDER BY c.created_at DESC", + (user_id,), + ) + convs = cursor.fetchall() + for conv in convs: + cursor.execute( + "SELECT u.id AS user_id, u.username, u.email FROM conversation_members cm " + "JOIN users u ON cm.user_id = u.id " + "WHERE cm.conversation_id = %s", + (conv["id"],), + ) + conv["members"] = cursor.fetchall() + return convs + finally: + conn.close() + + +# --- Group Invitations --- + +def create_invitation(conversation_id: str, user_id: str, invited_by: str): + """Create a pending group invitation.""" + conn = get_connection() + try: + cursor = conn.cursor() + inv_id = generate_uuid() + cursor.execute( + "INSERT IGNORE INTO group_invitations (id, conversation_id, user_id, invited_by) " + "VALUES (%s, %s, %s, %s)", + (inv_id, conversation_id, user_id, invited_by), + ) + conn.commit() + finally: + conn.close() + + +def get_pending_invitations(user_id: str) -> list[dict]: + """Get all pending invitations for a user, joined with conversation and inviter info.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT gi.id, gi.conversation_id, gi.invited_by, gi.created_at, " + "c.name AS conversation_name, u.username AS invited_by_username " + "FROM group_invitations gi " + "JOIN conversations c ON gi.conversation_id = c.id " + "JOIN users u ON gi.invited_by = u.id " + "WHERE gi.user_id = %s " + "ORDER BY gi.created_at DESC", + (user_id,), + ) + return cursor.fetchall() + finally: + conn.close() + + +def delete_invitation(conversation_id: str, user_id: str): + """Delete a pending invitation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM group_invitations WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + conn.commit() + finally: + conn.close() + + +def has_pending_invitation(conversation_id: str, user_id: str) -> bool: + """Check if a user has a pending invitation for a conversation.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM group_invitations WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + return cursor.fetchone() is not None + finally: + conn.close() + + +# --- Messages --- + +def store_message( + conversation_id: str, + sender_id: str, + ratchet_header: bytes, + recipients: list[dict], + x3dh_header: bytes | None = None, + sender_chain_id: bytes | None = None, + sender_chain_n: int | None = None, + image_file_id: str | None = None, + sender_device_id: str | None = None, +) -> str: + """Store an encrypted message with per-recipient ciphertext. + + recipients: [{user_id, encrypted_content (bytes), nonce (bytes), + device_id (str, optional), ratchet_header (bytes, optional), + x3dh_header (bytes, optional)}] + """ + conn = get_connection() + try: + cursor = conn.cursor() + msg_id = generate_uuid() + cursor.execute( + "INSERT INTO messages (id, conversation_id, sender_id, sender_device_id, " + "ratchet_header, x3dh_header, sender_chain_id, sender_chain_n, image_file_id) " + "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)", + (msg_id, conversation_id, sender_id, sender_device_id, ratchet_header, + x3dh_header, sender_chain_id, sender_chain_n, image_file_id), + ) + for r in recipients: + device_id = r.get("device_id", SELF_DEVICE_ID) + cursor.execute( + "INSERT INTO message_recipients (message_id, user_id, device_id, " + "encrypted_content, nonce, ratchet_header, x3dh_header) " + "VALUES (%s, %s, %s, %s, %s, %s, %s)", + (msg_id, r["user_id"], device_id, r["encrypted_content"], r["nonce"], + r.get("ratchet_header"), r.get("x3dh_header")), + ) + conn.commit() + return msg_id + finally: + conn.close() + + +def get_messages(conversation_id: str, user_id: str, limit: int = 50, offset: int = 0, + device_id: str | None = None) -> list[dict]: + """Get messages for a user in a conversation, JOINing their per-recipient ciphertext. + + If device_id is set, returns rows where mr.device_id matches OR is the sentinel + (self-encrypted / legacy). This ensures both device-specific and self-encrypted + copies are returned. + """ + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + if device_id: + cursor.execute( + "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, " + "m.ratchet_header, m.x3dh_header, " + "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, " + "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, " + "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header " + "FROM messages m " + "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " + " AND (mr.device_id = %s OR mr.device_id = %s) " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + "WHERE m.conversation_id = %s AND (cm.joined_at IS NULL OR m.created_at >= cm.joined_at) " + "ORDER BY m.created_at DESC LIMIT %s OFFSET %s", + (user_id, device_id, SELF_DEVICE_ID, user_id, conversation_id, limit, offset), + ) + else: + cursor.execute( + "SELECT m.id, m.conversation_id, m.sender_id, m.sender_device_id, " + "m.ratchet_header, m.x3dh_header, " + "m.sender_chain_id, m.sender_chain_n, m.created_at, m.deleted_at, m.image_file_id, " + "mr.encrypted_content, mr.nonce, mr.device_id AS mr_device_id, " + "mr.ratchet_header AS mr_ratchet_header, mr.x3dh_header AS mr_x3dh_header " + "FROM messages m " + "JOIN message_recipients mr ON m.id = mr.message_id AND mr.user_id = %s " + "JOIN conversation_members cm ON cm.conversation_id = m.conversation_id AND cm.user_id = %s " + "WHERE m.conversation_id = %s AND (cm.joined_at IS NULL OR m.created_at >= cm.joined_at) " + "ORDER BY m.created_at DESC LIMIT %s OFFSET %s", + (user_id, user_id, conversation_id, limit, offset), + ) + return cursor.fetchall() + finally: + conn.close() + + +def get_message_conversation(message_id: str) -> str | None: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT conversation_id FROM messages WHERE id = %s", (message_id,)) + row = cursor.fetchone() + return row[0] if row else None + finally: + conn.close() + + +def get_message_sender(message_id: str) -> str | None: + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT sender_id FROM messages WHERE id = %s", (message_id,)) + row = cursor.fetchone() + return row[0] if row else None + finally: + conn.close() + + +# --- Group Sender Keys --- + +def store_sender_key(conversation_id: str, sender_id: str, chain_id: bytes, + device_id: str | None = None): + """Store or update a sender key chain ID for a group member's device.""" + conn = get_connection() + try: + cursor = conn.cursor() + dev = device_id or SELF_DEVICE_ID + cursor.execute( + "REPLACE INTO group_sender_keys (conversation_id, sender_id, device_id, chain_id) " + "VALUES (%s, %s, %s, %s)", + (conversation_id, sender_id, dev, chain_id), + ) + conn.commit() + finally: + conn.close() + + +def get_sender_key(conversation_id: str, sender_id: str, + device_id: str | None = None) -> dict | None: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + dev = device_id or SELF_DEVICE_ID + cursor.execute( + "SELECT chain_id, created_at FROM group_sender_keys " + "WHERE conversation_id = %s AND sender_id = %s AND device_id = %s", + (conversation_id, sender_id, dev), + ) + return cursor.fetchone() + finally: + conn.close() + + +# --- Read Receipts --- + +def mark_messages_read(conversation_id: str, user_id: str, message_ids: list[str]): + if not message_ids: + return + conn = get_connection() + try: + cursor = conn.cursor() + for mid in message_ids: + cursor.execute( + "INSERT IGNORE INTO message_reads (message_id, user_id) VALUES (%s, %s)", + (mid, user_id), + ) + conn.commit() + finally: + conn.close() + + +def get_unread_counts(user_id: str) -> dict[str, int]: + """Return {conversation_id: unread_count} for all conversations the user is in.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT m.conversation_id, COUNT(*) AS cnt " + "FROM messages m " + "JOIN message_recipients mr ON mr.message_id = m.id AND mr.user_id = %s " + "LEFT JOIN message_reads mrd ON mrd.message_id = m.id AND mrd.user_id = %s " + "WHERE m.sender_id != %s AND m.deleted_at IS NULL AND mrd.message_id IS NULL " + "GROUP BY m.conversation_id", + (user_id, user_id, user_id), + ) + return {row["conversation_id"]: row["cnt"] for row in cursor.fetchall()} + finally: + conn.close() + + +def get_message_read_status(message_ids: list[str]) -> dict: + if not message_ids: + return {} + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + placeholders = ",".join(["%s"] * len(message_ids)) + cursor.execute( + f"SELECT mr.message_id, mr.user_id, mr.read_at " + f"FROM message_reads mr " + f"WHERE mr.message_id IN ({placeholders})", + tuple(message_ids), + ) + result = {} + for row in cursor.fetchall(): + mid = row["message_id"] + if mid not in result: + result[mid] = [] + result[mid].append({ + "user_id": row["user_id"], + "read_at": row["read_at"].isoformat() if hasattr(row["read_at"], "isoformat") else str(row["read_at"]), + }) + return result + finally: + conn.close() + + +# --- Delete --- + +def soft_delete_message(message_id: str, sender_id: str) -> dict | None: + """Soft-delete a message if sender matches. Returns {'image_file_id': ...} or None.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT sender_id, image_file_id FROM messages WHERE id = %s AND deleted_at IS NULL", + (message_id,), + ) + row = cursor.fetchone() + if not row or row["sender_id"] != sender_id: + return None + cursor.execute( + "UPDATE messages SET deleted_at = NOW() WHERE id = %s", + (message_id,), + ) + # Clear per-recipient ciphertext + cursor.execute( + "UPDATE message_recipients SET encrypted_content = %s WHERE message_id = %s", + (b"", message_id), + ) + conn.commit() + return {"image_file_id": row.get("image_file_id")} + finally: + conn.close() + + +def set_message_image_file_id(message_id: str, file_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE messages SET image_file_id = %s WHERE id = %s", + (file_id, message_id), + ) + conn.commit() + finally: + conn.close() + + +# --- Image Uploads --- + +def create_image_upload(file_id: str, conversation_id: str, uploader_id: str, file_size: int): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO image_uploads (file_id, conversation_id, uploader_id, file_size) " + "VALUES (%s, %s, %s, %s)", + (file_id, conversation_id, uploader_id, file_size), + ) + conn.commit() + finally: + conn.close() + + +def complete_image_upload(file_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE image_uploads SET completed = TRUE WHERE file_id = %s", + (file_id,), + ) + conn.commit() + finally: + conn.close() + + +def get_image_upload(file_id: str) -> dict | None: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT file_id, conversation_id, uploader_id, file_size, completed, created_at " + "FROM image_uploads WHERE file_id = %s", + (file_id,), + ) + return cursor.fetchone() + finally: + conn.close() + + +def delete_image_upload(file_id: str): + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM image_uploads WHERE file_id = %s", (file_id,)) + conn.commit() + finally: + conn.close() + + +# --- User Profiles --- + +def create_default_profile(user_id: str): + """Create a default profile for a new user.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)", + (user_id,), + ) + conn.commit() + finally: + conn.close() + + +def get_user_profile(user_id: str, viewer_id: str | None = None) -> dict | None: + """Get user profile joined with user info. Respects visibility if viewer is different user.""" + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT u.id AS user_id, u.username, u.email, u.created_at, " + "p.phone, p.phone_visible, p.email_visible, p.location, " + "p.location_visible, p.avatar_file, p.updated_at " + "FROM users u LEFT JOIN user_profiles p ON u.id = p.user_id " + "WHERE u.id = %s", + (user_id,), + ) + row = cursor.fetchone() + if not row: + return None + # If viewing someone else's profile, apply visibility rules + if viewer_id and viewer_id != user_id: + if not row.get("email_visible"): + row["email"] = None + if not row.get("phone_visible"): + row["phone"] = None + if not row.get("location_visible"): + row["location"] = None + return row + finally: + conn.close() + + +def update_user_profile(user_id: str, **fields): + """Upsert user profile fields. Allowed: phone, phone_visible, email_visible, + location, location_visible, avatar_file.""" + allowed = {"phone", "phone_visible", "email_visible", "location", + "location_visible", "avatar_file"} + filtered = {k: v for k, v in fields.items() if k in allowed} + if not filtered: + return + conn = get_connection() + try: + cursor = conn.cursor() + # Upsert: insert default then update + cursor.execute( + "INSERT IGNORE INTO user_profiles (user_id) VALUES (%s)", + (user_id,), + ) + set_clause = ", ".join(f"{k} = %s" for k in filtered) + values = list(filtered.values()) + [user_id] + cursor.execute( + f"UPDATE user_profiles SET {set_clause} WHERE user_id = %s", + values, + ) + conn.commit() + finally: + conn.close() + + +def batch_reencrypt_messages(user_id: str, updates: list[dict]): + """Batch re-encrypt message_recipients rows with self-encryption key data. + + Each update: {message_id, encrypted_content (bytes), nonce (bytes)}. + Sets ratchet_header to '{"self":true}' and clears x3dh_header. + Only updates rows belonging to user_id with sentinel device_id (self-encrypted copies). + """ + if not updates: + return + conn = get_connection() + try: + cursor = conn.cursor() + self_header = b'{"self":true}' + for u in updates: + cursor.execute( + "UPDATE message_recipients " + "SET encrypted_content = %s, nonce = %s, ratchet_header = %s, x3dh_header = NULL " + "WHERE message_id = %s AND user_id = %s AND device_id = %s", + (u["encrypted_content"], u["nonce"], self_header, u["message_id"], + user_id, SELF_DEVICE_ID), + ) + conn.commit() + finally: + conn.close() + + +# --- Phantom Users --- + +def create_phantom_user(email: str) -> dict: + """Create a phantom user with valid crypto keys for X3DH. + + Phantom users have rsa_public_key = 'PHANTOM' as a marker. + Returns user dict: {id, username, email, identity_key}. + """ + username = email.split("@")[0] + user_id = generate_uuid() + + # Generate real crypto keys so X3DH works on the client side + ik_private, ik_public = generate_identity_keypair() + ik_public_bytes = serialize_ed25519_public(ik_public) + + spk = generate_signed_prekey(ik_private) + spk_pub_bytes = serialize_x25519_public(spk["public"]) + spk_sig = spk["signature"] + + opks = generate_one_time_prekeys(count=5) + + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO users (id, username, email, rsa_public_key, identity_key) " + "VALUES (%s, %s, %s, %s, %s)", + (user_id, username, email, "PHANTOM", ik_public_bytes), + ) + cursor.execute( + "INSERT INTO signed_prekeys (id, user_id, public_key, signature) VALUES (%s, %s, %s, %s)", + (spk["id"], user_id, spk_pub_bytes, spk_sig), + ) + for opk in opks: + cursor.execute( + "INSERT INTO one_time_prekeys (id, user_id, public_key) VALUES (%s, %s, %s)", + (opk["id"], user_id, serialize_x25519_public(opk["public"])), + ) + conn.commit() + return {"id": user_id, "username": username, "email": email, "identity_key": ik_public_bytes} + finally: + conn.close() + + +def is_phantom_user(user_id: str) -> bool: + """Check if a user is a phantom (rsa_public_key == 'PHANTOM').""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT rsa_public_key FROM users WHERE id = %s", (user_id,)) + row = cursor.fetchone() + return row is not None and row[0] == "PHANTOM" + finally: + conn.close() + + +def delete_phantom_user(user_id: str): + """Delete a phantom user. CASCADE removes signed_prekeys, one_time_prekeys, + conversation_members, message_recipients, etc.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM users WHERE id = %s AND rsa_public_key = %s", + (user_id, "PHANTOM"), + ) + conn.commit() + finally: + conn.close() + + +def upgrade_phantom_user(phantom_id: str, username: str, rsa_public_key_pem: str, + identity_key: bytes) -> str | None: + """Upgrade a phantom user to a real user in-place. + + Preserves user_id and all FK references (conversation_members, group_invitations, etc.). + Deletes phantom's server-generated prekeys (real user will upload own on first login). + Returns phantom_id as the new user_id, or None if phantom no longer exists. + """ + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET username = %s, rsa_public_key = %s, identity_key = %s " + "WHERE id = %s AND rsa_public_key = 'PHANTOM'", + (username, rsa_public_key_pem, identity_key, phantom_id), + ) + if cursor.rowcount == 0: + conn.rollback() + return None + # Remove phantom's server-generated crypto keys — real user uploads own + cursor.execute("DELETE FROM signed_prekeys WHERE user_id = %s", (phantom_id,)) + cursor.execute("DELETE FROM one_time_prekeys WHERE user_id = %s", (phantom_id,)) + conn.commit() + return phantom_id + finally: + conn.close() + + +def get_all_phantom_user_ids() -> set[str]: + """Return set of all phantom user IDs (for server startup cache).""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT id FROM users WHERE rsa_public_key = %s", ("PHANTOM",)) + return {row[0] for row in cursor.fetchall()} + finally: + conn.close() + + +def cleanup_stale_phantoms(max_age_days: int = 30) -> int: + """Delete phantom users older than max_age_days with no active conversations with real users.""" + conn = get_connection() + try: + cursor = conn.cursor() + # Two-step: SELECT ids first, then DELETE. + # MySQL error 1093: can't DELETE from table referenced in subquery. + cursor.execute(""" + SELECT u.id FROM users u + WHERE u.rsa_public_key = 'PHANTOM' + AND u.created_at < DATE_SUB(NOW(), INTERVAL %s DAY) + AND NOT EXISTS ( + SELECT 1 FROM conversation_members cm1 + JOIN conversation_members cm2 ON cm1.conversation_id = cm2.conversation_id + JOIN users u2 ON cm2.user_id = u2.id + WHERE cm1.user_id = u.id + AND u2.rsa_public_key != 'PHANTOM' + ) + """, (max_age_days,)) + ids = [row[0] for row in cursor.fetchall()] + if not ids: + return 0 + cursor.execute( + "DELETE FROM users WHERE id IN (%s)" % ",".join(["%s"] * len(ids)), + ids, + ) + deleted = cursor.rowcount + conn.commit() + return deleted + finally: + conn.close() + + +def remove_conversation_member_atomic(conversation_id: str, user_id: str) -> bool: + """Remove member and return True if actually removed (row existed). M6 TOCTOU fix.""" + conn = get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "DELETE FROM conversation_members WHERE conversation_id = %s AND user_id = %s", + (conversation_id, user_id), + ) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + +def get_stale_uploads(max_age_seconds: int = 3600) -> list[dict]: + conn = get_connection() + try: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT file_id FROM image_uploads " + "WHERE completed = FALSE AND created_at < DATE_SUB(NOW(), INTERVAL %s SECOND)", + (max_age_seconds,), + ) + return cursor.fetchall() + finally: + conn.close() diff --git a/zaloha/gui_client.py b/zaloha/gui_client.py new file mode 100644 index 0000000..60adb26 --- /dev/null +++ b/zaloha/gui_client.py @@ -0,0 +1,3335 @@ +"""PyQt6 GUI client for encrypted chat.""" + +import asyncio +import json +import logging +import os + +logger = logging.getLogger(__name__) +import re +import sys +from functools import partial + +from PyQt6.QtCore import QThread, pyqtSignal, Qt, QTimer, QUrl, QSize +from PyQt6.QtWidgets import ( + QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, + QLineEdit, QLabel, QListWidget, QListWidgetItem, QTextEdit, + QSplitter, QMessageBox, QInputDialog, QMenu, QStackedWidget, + QDialog, QFileDialog, QScrollArea, QTextBrowser, +) +from PyQt6.QtGui import QFont, QAction, QPixmap, QImage, QDesktopServices, QIcon, QPainter, QColor, QBrush, QPen, QShortcut, QKeySequence +from PyQt6.QtWidgets import QStyle + +from chat_core import ChatClient + +# H10: Image validation limits +MAX_IMAGE_DATA_SIZE = 10 * 1024 * 1024 # 10 MB max raw image data +MAX_IMAGE_DIMENSION = 8192 # 8K pixels max + + +def _safe_load_image(data: bytes) -> QImage | None: + """Load image with size and dimension validation (H10).""" + if not data or len(data) > MAX_IMAGE_DATA_SIZE: + return None + qimg = QImage.fromData(data) + if qimg.isNull(): + return None + if qimg.width() > MAX_IMAGE_DIMENSION or qimg.height() > MAX_IMAGE_DIMENSION: + return None + return qimg + + +def _safe_filename(name: str, default: str = "file") -> str: + """Sanitize filename — strip path components, prevent traversal (H11).""" + name = os.path.basename(name) + name = name.replace("\x00", "") + return name if name else default + + +# URL regex: matches http:// and https:// URLs in raw (not-yet-escaped) text +_URL_RE = re.compile( + r'(https?://[^\s<>"\')\]]+)', + re.IGNORECASE, +) +_URL_TRAILING_PUNCT = re.compile(r'[.,;:!?]+$') + + +def _linkify_urls(raw_text: str) -> str: + """HTML-escape text and convert URLs into clickable tags. + + HTTPS links get blue styling. HTTP links get orange + unlock icon warning. + Processes raw (unescaped) text — returns HTML-safe string. + """ + def _esc(s): + return s.replace("&", "&").replace("<", "<").replace(">", ">") + + parts = _URL_RE.split(raw_text) + result = [] + for i, part in enumerate(parts): + if i % 2 == 1: + # URL match — strip trailing sentence punctuation + trail_m = _URL_TRAILING_PUNCT.search(part) + if trail_m: + url = part[:trail_m.start()] + trail = part[trail_m.start():] + else: + url = part + trail = "" + url_esc = _esc(url) + if url.lower().startswith("http://"): + result.append( + f'' + f'\U0001f513 {url_esc}' + ) + else: + result.append( + f'' + f'{url_esc}' + ) + if trail: + result.append(_esc(trail)) + else: + result.append(_esc(part)) + return "".join(result) + + +def setup_logging(): + level_name = os.getenv("LOG_LEVEL", "WARNING").upper() + level = getattr(logging, level_name, logging.WARNING) + logging.basicConfig(level=level, format="%(levelname)s: %(message)s") + + +DARK_STYLE = """ +QWidget { + background-color: #1e1e2e; + color: #cdd6f4; + font-family: "Segoe UI", "DejaVu Sans", sans-serif; + font-size: 11pt; +} +QLineEdit { + background-color: #313244; + border: 1px solid #45475a; + border-radius: 6px; + padding: 8px; + color: #cdd6f4; +} +QLineEdit:focus { + border: 1px solid #89b4fa; +} +QPushButton { + background-color: #89b4fa; + color: #1e1e2e; + border: none; + border-radius: 6px; + padding: 8px 16px; + font-weight: bold; +} +QPushButton:hover { + background-color: #74c7ec; +} +QPushButton:pressed { + background-color: #89dceb; +} +QPushButton#secondaryBtn { + background-color: #45475a; + color: #cdd6f4; +} +QPushButton#secondaryBtn:hover { + background-color: #585b70; +} +QListWidget { + background-color: #181825; + border: none; + border-radius: 6px; + padding: 4px; +} +QListWidget::item { + padding: 10px; + border-radius: 4px; +} +QListWidget::item:selected { + background-color: #313244; + border-left: 3px solid #89b4fa; +} +QListWidget::item:hover { + background-color: #252536; + color: #cdd6f4; +} +QTextEdit, QTextBrowser { + background-color: #1e1e2e; + border: none; + border-radius: 6px; + padding: 8px; + color: #cdd6f4; +} +QScrollBar:vertical { + background: #1e1e2e; + width: 10px; + border-radius: 5px; +} +QScrollBar::handle:vertical { + background: #45475a; + border-radius: 5px; + min-height: 30px; +} +QScrollBar::handle:vertical:hover { + background: #585b70; +} +QScrollBar::add-line:vertical, QScrollBar::sub-line:vertical { + height: 0; +} +QLabel#title { + font-size: 15pt; + font-weight: bold; + color: #89b4fa; +} +QSplitter::handle { + background-color: #45475a; + width: 1px; +} +""" + +MAX_INPUT_CHARS = int(os.getenv("MAX_INPUT_CHARS", "2000")) + + +class MessageInput(QTextEdit): + """Multiline message input: Enter sends, Shift+Enter inserts newline.""" + send_requested = pyqtSignal() + + def __init__(self, parent=None): + super().__init__(parent) + self.setAcceptRichText(False) + self.setPlaceholderText("Type a message...") + self.setFixedHeight(72) + self.setStyleSheet( + "QTextEdit { background-color: #313244; border: 1px solid #45475a; " + "border-radius: 6px; padding: 8px; color: #cdd6f4; }" + "QTextEdit:focus { border: 1px solid #89b4fa; }" + ) + + def keyPressEvent(self, event): + if event.key() in (Qt.Key.Key_Return, Qt.Key.Key_Enter): + if event.modifiers() & Qt.KeyboardModifier.ShiftModifier: + super().keyPressEvent(event) + else: + self.send_requested.emit() + return + super().keyPressEvent(event) + + +class AsyncBridge(QThread): + """Runs asyncio event loop in a background thread, emits Qt signals.""" + connected = pyqtSignal() + connection_error = pyqtSignal(str) + login_result = pyqtSignal(bool, str) + register_result = pyqtSignal(bool, str) + conversations_loaded = pyqtSignal(list) + messages_loaded = pyqtSignal(str, list) # conv_id, messages + older_messages_loaded = pyqtSignal(str, list) # conv_id, older messages + message_sent = pyqtSignal(bool, str) + new_notification = pyqtSignal(dict) # decrypted payload + pairing_code = pyqtSignal(str) + pairing_complete = pyqtSignal(bool, str) + add_member_result = pyqtSignal(bool, str) + remove_member_result = pyqtSignal(bool, str) + authorize_result = pyqtSignal(bool, str) + rotate_result = pyqtSignal(bool, str) + reencrypt_status = pyqtSignal(str) + messages_read_notification = pyqtSignal(dict) + message_deleted_notification = pyqtSignal(dict) + image_sent = pyqtSignal(bool, str) + image_downloaded = pyqtSignal(str, bytes) # file_id, decrypted bytes + delete_message_result = pyqtSignal(bool, str) + reconnected = pyqtSignal() + conversation_updated = pyqtSignal() + connection_state_changed = pyqtSignal(str) # "connected", "disconnected", "reconnecting" + profile_loaded = pyqtSignal(dict) + profile_updated = pyqtSignal(bool, str) + avatar_loaded = pyqtSignal(str, bytes) # user_id, avatar_bytes + online_status_changed = pyqtSignal(str, bool) # user_id, is_online + online_users_loaded = pyqtSignal(list) # list of user_ids + invitations_loaded = pyqtSignal(list) # list of invitation dicts + invitation_result = pyqtSignal(bool, str) # ok, message + invitation_received = pyqtSignal(dict) # invitation notification data + group_avatar_loaded = pyqtSignal(str, bytes) # conv_id, avatar_bytes + group_avatar_updated = pyqtSignal(bool, str) # ok, message + session_reset_notification = pyqtSignal(str, str) # from_user_id, from_device_id + + def __init__(self): + super().__init__() + self.client = ChatClient() + self.loop: asyncio.AbstractEventLoop | None = None + self._running = True + self.client._reencrypt_progress_cb = self._emit_reencrypt_status + self._ready: asyncio.Event | None = None + + def _emit_reencrypt_status(self, message: str): + self.reencrypt_status.emit(message) + + def run(self): + if sys.platform == "win32": + self.loop = asyncio.SelectorEventLoop() + else: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self._ready = asyncio.Event() + try: + self.loop.run_until_complete(self._run()) + except Exception: + pass + finally: + self.loop.close() + + async def _run(self): + try: + await self.client.connect() + self.client._listener_task = asyncio.create_task(self.client._background_listener()) + if self._ready: + self._ready.set() + self.connected.emit() + self.connection_state_changed.emit("connected") + except Exception as e: + self.connection_error.emit(str(e)) + return + + # Process notifications + await self._notification_loop() + + async def _notification_loop(self): + while self._running: + try: + # Check if listener task died (connection lost) + if (self.client._listener_task and self.client._listener_task.done() + and not self.client.connected): + self.connection_state_changed.emit("disconnected") + if self.client.session: + await self._auto_reconnect() + continue + + notif = await asyncio.wait_for( + self.client._notification_queue.get(), timeout=0.5 + ) + notif_type = notif.get("type", "") + data = notif.get("data", {}) + if notif_type in ("conversation_created", "member_added", "member_removed", + "conversation_renamed"): + self.conversation_updated.emit() + elif notif_type == "group_invitation": + self.invitation_received.emit(data) + elif notif_type == "user_online": + self.online_status_changed.emit(data.get("user_id", ""), True) + elif notif_type == "user_offline": + self.online_status_changed.emit(data.get("user_id", ""), False) + elif notif_type == "online_users": + self.online_users_loaded.emit(data.get("user_ids", [])) + elif notif_type == "messages_read": + self.messages_read_notification.emit(data) + elif notif_type == "message_deleted": + self.message_deleted_notification.emit(data) + elif notif_type == "session_reset": + from_uid = data.get("from_user_id", "") + from_did = data.get("from_device_id", "") + self.client.handle_session_reset_notification(from_uid, from_did or None) + self.session_reset_notification.emit(from_uid, from_did) + elif notif_type == "new_message": + payload = self.client.decrypt_notification(data) + if payload: + self.new_notification.emit(payload) + # None = control message (e.g. sender key distribution), skip silently + except asyncio.TimeoutError: + continue + except Exception: + break + + async def _auto_reconnect(self): + """Auto-reconnect with exponential backoff.""" + delay = 1 + while self._running and not self.client.connected: + self.connection_state_changed.emit("reconnecting") + try: + await self.client.reconnect() + if self.client.connected and self.client.session: + self.connection_state_changed.emit("connected") + self.conversation_updated.emit() + return + if self.client.login_rejected: + self.connection_state_changed.emit("revoked") + return + except Exception: + pass + await asyncio.sleep(delay) + delay = min(delay * 2, 30) + + def schedule(self, coro): + """Schedule a coroutine on the asyncio loop from the Qt thread.""" + if self.loop and self.loop.is_running(): + asyncio.run_coroutine_threadsafe(coro, self.loop) + else: + # Avoid "coroutine was never awaited" warnings if loop is down. + try: + coro.close() + except Exception: + pass + + async def _do_register(self, username, password, email): + if self._ready: + await self._ready.wait() + ok, code_or_msg = await self.client.register(username, password, email=email) + self.register_result.emit(ok, code_or_msg) + + async def _do_login(self, email, password): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.login(email, password) + self.login_result.emit(ok, msg) + + async def _do_logout(self): + if self._ready: + self._ready.clear() + try: + await self.client.close() + except Exception: + pass + self.client = ChatClient() + self.client._reencrypt_progress_cb = self._emit_reencrypt_status + try: + await self.client.connect() + self.client._listener_task = asyncio.create_task(self.client._background_listener()) + if self._ready: + self._ready.set() + self.reconnected.emit() + except Exception as e: + self.connection_error.emit(str(e)) + + async def _do_load_conversations(self): + if self._ready: + await self._ready.wait() + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + async def _do_load_messages(self, conv_id): + if self._ready: + await self._ready.wait() + msgs = await self.client.get_messages(conv_id) + self.messages_loaded.emit(conv_id, msgs) + + async def _do_load_older_messages(self, conv_id, offset): + if self._ready: + await self._ready.wait() + msgs = await self.client.get_messages(conv_id, limit=50, offset=offset) + self.older_messages_loaded.emit(conv_id, msgs) + + async def _do_send_message(self, conv_id, text, members, reply_to=None): + if self._ready: + await self._ready.wait() + try: + ok, msg = await self.client.send_message(conv_id, text, members, reply_to=reply_to) + except Exception as e: + logger.error("send_message exception: %s", e, exc_info=True) + self.message_sent.emit(False, str(e)) + return + self.message_sent.emit(ok, msg) + if ok: + # Reload messages to get the server-assigned message_id and timestamp + msgs = await self.client.get_messages(conv_id, limit=50) + self.messages_loaded.emit(conv_id, msgs) + + async def _do_find_or_create_and_send(self, username, text): + if self._ready: + await self._ready.wait() + try: + conv_id, msg = await self.client.find_or_create_conversation(username) + if not conv_id: + self.message_sent.emit(False, msg) + return + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + members = [] + for c in convs: + if c["conversation_id"] == conv_id: + members = c["members"] + break + ok, msg = await self.client.send_message(conv_id, text, members) + self.message_sent.emit(ok, msg) + if ok: + msgs = await self.client.get_messages(conv_id) + self.messages_loaded.emit(conv_id, msgs) + except Exception as e: + logger.error("find_or_create_and_send exception: %s", e, exc_info=True) + self.message_sent.emit(False, str(e)) + + async def _do_create_group(self, members, name=None): + if self._ready: + await self._ready.wait() + conv_id, msg = await self.client.create_conversation(members, name=name) + if conv_id: + self.message_sent.emit(True, f"Group created") + else: + self.message_sent.emit(False, msg) + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + async def _do_link_device(self, username, password): + if self._ready: + await self._ready.wait() + ok, code_or_msg = await self.client.pairing_start(username) + if not ok: + self.pairing_complete.emit(False, code_or_msg) + return + code = code_or_msg + self.pairing_code.emit(code) + ok2, msg2 = await self.client.pairing_wait(code, username, password) + self.pairing_complete.emit(ok2, msg2) + + async def _do_authorize_device(self, code): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.authorize_device(code) + self.authorize_result.emit(ok, msg) + + async def _do_rotate_keys(self, username, password): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.rotate_keys(username, password) + self.rotate_result.emit(ok, msg) + + def do_register(self, username, password, email): + self.schedule(self._do_register(username, password, email)) + + def do_login(self, email, password): + self.schedule(self._do_login(email, password)) + + def load_conversations(self): + self.schedule(self._do_load_conversations()) + + def load_messages(self, conv_id): + self.schedule(self._do_load_messages(conv_id)) + + def load_older_messages(self, conv_id, offset): + self.schedule(self._do_load_older_messages(conv_id, offset)) + + def send_message(self, conv_id, text, members, reply_to=None): + self.schedule(self._do_send_message(conv_id, text, members, reply_to)) + + def send_new_chat(self, username, text): + self.schedule(self._do_find_or_create_and_send(username, text)) + + def create_group(self, members, name=None): + self.schedule(self._do_create_group(members, name=name)) + + async def _do_add_member(self, conv_id, email): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.add_member(conv_id, email) + self.add_member_result.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def add_member(self, conv_id, email): + self.schedule(self._do_add_member(conv_id, email)) + + async def _do_remove_member(self, conv_id, user_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.remove_member(conv_id, user_id) + self.remove_member_result.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def remove_member(self, conv_id, user_id): + self.schedule(self._do_remove_member(conv_id, user_id)) + + group_left = pyqtSignal(bool, str) + group_renamed = pyqtSignal(bool, str) + conversation_deleted = pyqtSignal(bool, str) + + async def _do_leave_group(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.leave_group(conv_id) + self.group_left.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def leave_group(self, conv_id): + self.schedule(self._do_leave_group(conv_id)) + + async def _do_rename_conversation(self, conv_id, name): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.rename_conversation(conv_id, name) + self.group_renamed.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def rename_conversation(self, conv_id, name): + self.schedule(self._do_rename_conversation(conv_id, name)) + + async def _do_delete_conversation(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.delete_conversation(conv_id) + self.conversation_deleted.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def delete_conversation(self, conv_id): + self.schedule(self._do_delete_conversation(conv_id)) + + def link_device(self, username, password): + self.schedule(self._do_link_device(username, password)) + + def authorize_device(self, code): + self.schedule(self._do_authorize_device(code)) + + def rotate_keys(self, username, password): + self.schedule(self._do_rotate_keys(username, password)) + + async def _do_delete_message(self, message_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.delete_message(message_id) + self.delete_message_result.emit(ok, msg) + + def delete_message(self, message_id): + self.schedule(self._do_delete_message(message_id)) + + def reset_session(self, peer_user_id, peer_device_id=None): + self.schedule(self.client.reset_session(peer_user_id, peer_device_id)) + + async def _do_send_image(self, conv_id, image_path, members, reply_to=None): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.send_image(conv_id, image_path, members, reply_to=reply_to) + self.image_sent.emit(ok, msg) + if ok: + msgs = await self.client.get_messages(conv_id, limit=50) + self.messages_loaded.emit(conv_id, msgs) + + def send_image(self, conv_id, image_path, members, reply_to=None): + self.schedule(self._do_send_image(conv_id, image_path, members, reply_to)) + + async def _do_download_image(self, file_id, image_info): + if self._ready: + await self._ready.wait() + data = await self.client.download_image(file_id, image_info) + if data: + self.image_downloaded.emit(file_id, data) + + def download_image(self, file_id, image_info): + self.schedule(self._do_download_image(file_id, image_info)) + + file_sent = pyqtSignal(bool, str) + file_downloaded = pyqtSignal(bytes, dict) # decrypted_bytes, file_info + + async def _do_send_file(self, conv_id, file_path, members, reply_to=None): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.send_file(conv_id, file_path, members, reply_to=reply_to) + self.file_sent.emit(ok, msg) + if ok: + msgs = await self.client.get_messages(conv_id, limit=50) + self.messages_loaded.emit(conv_id, msgs) + + def send_file(self, conv_id, file_path, members, reply_to=None): + self.schedule(self._do_send_file(conv_id, file_path, members, reply_to)) + + async def _do_download_file(self, file_id, file_info): + if self._ready: + await self._ready.wait() + data = await self.client.download_file(file_id, file_info) + if data: + self.file_downloaded.emit(data, file_info) + + def download_file(self, file_id, file_info): + self.schedule(self._do_download_file(file_id, file_info)) + + async def _do_get_profile(self, user_id=None): + if self._ready: + await self._ready.wait() + profile = await self.client.get_profile(user_id) + if profile: + self.profile_loaded.emit(profile) + + def get_profile(self, user_id=None): + self.schedule(self._do_get_profile(user_id)) + + async def _do_update_profile(self, **fields): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.update_profile(**fields) + self.profile_updated.emit(ok, msg) + + def update_profile(self, **fields): + self.schedule(self._do_update_profile(**fields)) + + async def _do_update_avatar(self, image_data): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.update_avatar(image_data) + self.profile_updated.emit(ok, msg) + + def update_avatar(self, image_data): + self.schedule(self._do_update_avatar(image_data)) + + async def _do_get_avatar(self, user_id): + if self._ready: + await self._ready.wait() + data = await self.client.get_avatar(user_id) + if data: + self.avatar_loaded.emit(user_id, data) + + def get_avatar(self, user_id): + self.schedule(self._do_get_avatar(user_id)) + + async def _do_list_invitations(self): + if self._ready: + await self._ready.wait() + invitations = await self.client.list_invitations() + self.invitations_loaded.emit(invitations) + + def list_invitations(self): + self.schedule(self._do_list_invitations()) + + async def _do_accept_invitation(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.accept_invitation(conv_id) + self.invitation_result.emit(ok, msg) + if ok: + invitations = await self.client.list_invitations() + self.invitations_loaded.emit(invitations) + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def accept_invitation(self, conv_id): + self.schedule(self._do_accept_invitation(conv_id)) + + async def _do_decline_invitation(self, conv_id): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.decline_invitation(conv_id) + self.invitation_result.emit(ok, msg) + if ok: + invitations = await self.client.list_invitations() + self.invitations_loaded.emit(invitations) + + def decline_invitation(self, conv_id): + self.schedule(self._do_decline_invitation(conv_id)) + + async def _do_update_group_avatar(self, conv_id, image_data): + if self._ready: + await self._ready.wait() + ok, msg = await self.client.update_group_avatar(conv_id, image_data) + self.group_avatar_updated.emit(ok, msg) + if ok: + convs = await self.client.list_conversations() + self.conversations_loaded.emit(convs) + + def update_group_avatar(self, conv_id, image_data): + self.schedule(self._do_update_group_avatar(conv_id, image_data)) + + async def _do_get_group_avatar(self, conv_id): + if self._ready: + await self._ready.wait() + data = await self.client.get_group_avatar(conv_id) + if data: + self.group_avatar_loaded.emit(conv_id, data) + + def get_group_avatar(self, conv_id): + self.schedule(self._do_get_group_avatar(conv_id)) + + def logout(self): + self.schedule(self._do_logout()) + + def stop(self): + self._running = False + if self.loop: + asyncio.run_coroutine_threadsafe(self.client.close(), self.loop) + + +class UserProfileDialog(QDialog): + """Dialog for viewing/editing user profiles.""" + + def __init__(self, bridge: AsyncBridge, user_id: str, editable: bool = False, parent=None): + super().__init__(parent) + self.bridge = bridge + self.user_id = user_id + self.editable = editable + self.setWindowTitle("User Profile" if not editable else "Edit Profile") + self.setMinimumWidth(400) + self._build_ui() + self._connect_signals() + self.bridge.get_profile(user_id) + + def _build_ui(self): + self.layout_main = QVBoxLayout(self) + self.layout_main.setSpacing(12) + self.layout_main.setContentsMargins(24, 20, 24, 20) + + # Avatar + self.avatar_label = QLabel() + self.avatar_label.setFixedSize(80, 80) + self.avatar_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.avatar_label.setStyleSheet( + "background-color: #313244; border-radius: 40px; " + "font-size: 21pt; color: #89b4fa;" + ) + self.avatar_label.setText("?") + self.layout_main.addWidget(self.avatar_label, alignment=Qt.AlignmentFlag.AlignCenter) + + if self.editable: + avatar_btn = QPushButton("Change Avatar") + avatar_btn.setObjectName("secondaryBtn") + avatar_btn.clicked.connect(self._on_change_avatar) + self.layout_main.addWidget(avatar_btn, alignment=Qt.AlignmentFlag.AlignCenter) + + # Info fields + self.username_label = QLabel("") + self.username_label.setStyleSheet("font-size: 14pt; font-weight: bold; color: #89b4fa;") + self.username_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.layout_main.addWidget(self.username_label) + + self.info_area = QVBoxLayout() + self.layout_main.addLayout(self.info_area) + + # Editable fields (only shown in edit mode) + if self.editable: + self.layout_main.addSpacing(8) + + form_label = QLabel("Profile Settings") + form_label.setStyleSheet("font-weight: bold; color: #89b4fa;") + self.layout_main.addWidget(form_label) + + self.phone_input = QLineEdit() + self.phone_input.setPlaceholderText("Phone number") + self.layout_main.addWidget(self.phone_input) + + self.location_input = QLineEdit() + self.location_input.setPlaceholderText("Location") + self.layout_main.addWidget(self.location_input) + + from PyQt6.QtWidgets import QCheckBox + self.email_visible_cb = QCheckBox("Email visible to others") + self.email_visible_cb.setStyleSheet("color: #cdd6f4;") + self.layout_main.addWidget(self.email_visible_cb) + + self.phone_visible_cb = QCheckBox("Phone visible to others") + self.phone_visible_cb.setStyleSheet("color: #cdd6f4;") + self.layout_main.addWidget(self.phone_visible_cb) + + self.location_visible_cb = QCheckBox("Location visible to others") + self.location_visible_cb.setStyleSheet("color: #cdd6f4;") + self.layout_main.addWidget(self.location_visible_cb) + + save_btn = QPushButton("Save") + save_btn.clicked.connect(self._on_save) + self.layout_main.addWidget(save_btn) + + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(self.accept) + self.layout_main.addWidget(close_btn) + + def _connect_signals(self): + self.bridge.profile_loaded.connect(self._on_profile_loaded) + self.bridge.avatar_loaded.connect(self._on_avatar_loaded) + self.bridge.profile_updated.connect(self._on_profile_updated) + + def _on_profile_loaded(self, profile): + if profile.get("user_id") != self.user_id: + return + username = profile.get("username", "?") + self.username_label.setText(username) + # Set avatar initial + self.avatar_label.setText(username[0].upper() if username else "?") + + # Clear info area + while self.info_area.count(): + item = self.info_area.takeAt(0) + if item.widget(): + item.widget().deleteLater() + + # Email + email = profile.get("email") + if email: + self.info_area.addWidget(QLabel(f"Email: {email}")) + + # Phone + phone = profile.get("phone") + if phone: + self.info_area.addWidget(QLabel(f"Phone: {phone}")) + + # Location + location = profile.get("location") + if location: + self.info_area.addWidget(QLabel(f"Location: {location}")) + + # Member since + created_at = profile.get("created_at", "") + if created_at: + date_str = created_at[:10] if len(created_at) >= 10 else created_at + label = QLabel(f"Member since: {date_str}") + label.setStyleSheet("color: #6c7086;") + self.info_area.addWidget(label) + + # Populate editable fields + if self.editable: + self.phone_input.setText(phone or "") + self.location_input.setText(location or "") + self.email_visible_cb.setChecked(bool(profile.get("email_visible", 1))) + self.phone_visible_cb.setChecked(bool(profile.get("phone_visible", 0))) + self.location_visible_cb.setChecked(bool(profile.get("location_visible", 0))) + + # Try to load avatar + if profile.get("avatar_file"): + self.bridge.get_avatar(self.user_id) + + def _on_avatar_loaded(self, user_id, data): + if user_id != self.user_id: + return + qimg = _safe_load_image(data) + if qimg is not None: + pixmap = QPixmap.fromImage(qimg) + # Circular crop + size = 80 + scaled = pixmap.scaled(size, size, Qt.AspectRatioMode.KeepAspectRatioByExpanding, + Qt.TransformationMode.SmoothTransformation) + result = QPixmap(size, size) + result.fill(QColor(0, 0, 0, 0)) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + painter.setBrush(QBrush(scaled)) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(0, 0, size, size) + painter.end() + self.avatar_label.setPixmap(result) + + def _on_change_avatar(self): + path, _ = QFileDialog.getOpenFileName( + self, "Select Avatar", "", + "Images (*.png *.jpg *.jpeg);;All Files (*)", + ) + if not path: + return + try: + with open(path, "rb") as f: + data = f.read() + if len(data) > 2 * 1024 * 1024: + QMessageBox.warning(self, "Error", "Avatar too large (max 2 MB).") + return + self.bridge.update_avatar(data) + except Exception as e: + QMessageBox.warning(self, "Error", f"Failed to read file: {e}") + + def _on_save(self): + fields = { + "phone": self.phone_input.text().strip() or None, + "location": self.location_input.text().strip() or None, + "email_visible": 1 if self.email_visible_cb.isChecked() else 0, + "phone_visible": 1 if self.phone_visible_cb.isChecked() else 0, + "location_visible": 1 if self.location_visible_cb.isChecked() else 0, + } + self.bridge.update_profile(**fields) + + def _on_profile_updated(self, ok, msg): + if ok: + # Refresh profile + self.bridge.get_profile(self.user_id) + else: + QMessageBox.warning(self, "Error", msg) + + def closeEvent(self, event): + # Disconnect signals to avoid stale references + try: + self.bridge.profile_loaded.disconnect(self._on_profile_loaded) + self.bridge.avatar_loaded.disconnect(self._on_avatar_loaded) + self.bridge.profile_updated.disconnect(self._on_profile_updated) + except Exception: + pass + super().closeEvent(event) + + def reject(self): + try: + self.bridge.profile_loaded.disconnect(self._on_profile_loaded) + self.bridge.avatar_loaded.disconnect(self._on_avatar_loaded) + self.bridge.profile_updated.disconnect(self._on_profile_updated) + except Exception: + pass + super().reject() + + def accept(self): + try: + self.bridge.profile_loaded.disconnect(self._on_profile_loaded) + self.bridge.avatar_loaded.disconnect(self._on_avatar_loaded) + self.bridge.profile_updated.disconnect(self._on_profile_updated) + except Exception: + pass + super().accept() + + +class LoginWindow(QWidget): + def __init__(self, bridge: AsyncBridge): + super().__init__() + self.bridge = bridge + self.setWindowTitle("Encrypted Chat - Login") + self.setFixedSize(500, 480) + self._pair_email = "" + self._pair_password = "" + self._build_ui() + + def _build_ui(self): + outer = QVBoxLayout(self) + outer.setContentsMargins(0, 0, 0, 0) + + self.stack = QStackedWidget() + outer.addWidget(self.stack) + + # --- Page 0: Login / Register form --- + page0 = QWidget() + layout = QVBoxLayout(page0) + layout.setSpacing(14) + layout.setContentsMargins(50, 40, 50, 40) + + title = QLabel("Encrypted Chat") + title.setObjectName("title") + title.setAlignment(Qt.AlignmentFlag.AlignCenter) + layout.addWidget(title) + + subtitle = QLabel("End-to-end encrypted messaging") + subtitle.setAlignment(Qt.AlignmentFlag.AlignCenter) + subtitle.setStyleSheet("color: #6c7086; font-size: 9pt; margin-bottom: 8px;") + layout.addWidget(subtitle) + + layout.addSpacing(6) + + self.username_input = QLineEdit() + self.username_input.setPlaceholderText("Username (display)") + self.username_input.returnPressed.connect(self._on_login) + layout.addWidget(self.username_input) + + self.email_input = QLineEdit() + self.email_input.setPlaceholderText("Email") + layout.addWidget(self.email_input) + + self.password_input = QLineEdit() + self.password_input.setPlaceholderText("Password") + self.password_input.setEchoMode(QLineEdit.EchoMode.Password) + self.password_input.returnPressed.connect(self._on_login) + layout.addWidget(self.password_input) + + btn_row = QHBoxLayout() + self.register_btn = QPushButton("Register") + self.register_btn.setObjectName("secondaryBtn") + self.register_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogApplyButton)) + self.register_btn.clicked.connect(self._on_register) + btn_row.addWidget(self.register_btn) + + self.login_btn = QPushButton("Login") + self.login_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogOkButton)) + self.login_btn.clicked.connect(self._on_login) + btn_row.addWidget(self.login_btn) + layout.addLayout(btn_row) + + self.link_btn = QPushButton("Link Device") + self.link_btn.setObjectName("secondaryBtn") + self.link_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogOpenButton)) + self.link_btn.clicked.connect(self._on_link_device) + layout.addWidget(self.link_btn) + + self.status_label = QLabel("") + self.status_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.status_label.setWordWrap(True) + layout.addWidget(self.status_label) + + layout.addStretch() + self.stack.addWidget(page0) + + # --- Page 1: Verification code form --- + page1 = QWidget() + vl = QVBoxLayout(page1) + vl.setSpacing(14) + vl.setContentsMargins(50, 40, 50, 40) + + step_label = QLabel("Step 2 of 2") + step_label.setStyleSheet("color: #89b4fa; font-weight: bold; font-size: 10pt;") + step_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + vl.addWidget(step_label) + + info_label = QLabel("Enter the 6-digit verification code sent to your email") + info_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + info_label.setWordWrap(True) + info_label.setStyleSheet("color: #cdd6f4; font-size: 10pt;") + vl.addWidget(info_label) + + vl.addSpacing(12) + + self.code_input = QLineEdit() + self.code_input.setPlaceholderText("000000") + self.code_input.setMaxLength(6) + self.code_input.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.code_input.setStyleSheet( + "QLineEdit { font-size: 16pt; letter-spacing: 8px; text-align: center; " + "background-color: #313244; border: 1px solid #45475a; border-radius: 6px; padding: 12px; }" + "QLineEdit:focus { border: 1px solid #89b4fa; }" + ) + self.code_input.returnPressed.connect(self._on_confirm_code) + vl.addWidget(self.code_input) + + vl.addSpacing(8) + + self.code_status_label = QLabel("") + self.code_status_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.code_status_label.setWordWrap(True) + vl.addWidget(self.code_status_label) + + code_btn_row = QHBoxLayout() + self.back_btn = QPushButton("Back") + self.back_btn.setObjectName("secondaryBtn") + self.back_btn.clicked.connect(self._on_back_to_login) + code_btn_row.addWidget(self.back_btn) + + self.confirm_btn = QPushButton("Confirm") + self.confirm_btn.clicked.connect(self._on_confirm_code) + code_btn_row.addWidget(self.confirm_btn) + vl.addLayout(code_btn_row) + + vl.addStretch() + self.stack.addWidget(page1) + + def show_verification_page(self, message=""): + """Switch to verification code page.""" + self.code_input.clear() + self.code_status_label.setText(message) + self.code_status_label.setStyleSheet("color: #a6e3a1;") + self.stack.setCurrentIndex(1) + self.code_input.setFocus() + + def _on_confirm_code(self): + code = self.code_input.text().strip() + if not code: + self.code_status_label.setText("Please enter the code.") + self.code_status_label.setStyleSheet("color: #f38ba8;") + return + self.code_status_label.setText("Confirming...") + self.code_status_label.setStyleSheet("color: #a6e3a1;") + self.confirm_btn.setEnabled(False) + self.back_btn.setEnabled(False) + # Callback set by main() to handle confirmation + if hasattr(self, '_confirm_callback'): + self._confirm_callback(code) + + def _on_back_to_login(self): + self.stack.setCurrentIndex(0) + self._set_enabled(True) + self.status_label.setText("") + self.status_label.setStyleSheet("") + + def _on_register(self): + username = self.username_input.text().strip() + password = self.password_input.text() + email = self.email_input.text().strip() + if not username: + return + if not email or not password: + self.show_error("Email and password required.") + return + self.status_label.setText("Registering...") + self._set_enabled(False) + self.bridge.do_register(username, password, email) + + def _on_login(self): + email = self.email_input.text().strip() + password = self.password_input.text() + if not email or not password: + self.show_error("Email and password required.") + return + self.status_label.setText("Logging in...") + self._set_enabled(False) + self.bridge.do_login(email, password) + + def _on_link_device(self): + email = self.email_input.text().strip() + password = self.password_input.text() + if not email or not password: + self.show_error("Email and password required.") + return + self._pair_email = email + self._pair_password = password + self.status_label.setText("Generating pairing code...") + self._set_enabled(False) + self.bridge.link_device(email, password) + + def _set_enabled(self, enabled): + self.username_input.setEnabled(enabled) + self.email_input.setEnabled(enabled) + self.password_input.setEnabled(enabled) + self.register_btn.setEnabled(enabled) + self.login_btn.setEnabled(enabled) + self.link_btn.setEnabled(enabled) + + def show_error(self, msg): + if self.stack.currentIndex() == 1: + self.code_status_label.setText(msg) + self.code_status_label.setStyleSheet("color: #f38ba8;") + self.confirm_btn.setEnabled(True) + self.back_btn.setEnabled(True) + else: + self.status_label.setText(msg) + self.status_label.setStyleSheet("color: #f38ba8;") + self._set_enabled(True) + + def show_success(self, msg): + if self.stack.currentIndex() == 1: + self.code_status_label.setText(msg) + self.code_status_label.setStyleSheet("color: #a6e3a1;") + else: + self.status_label.setText(msg) + self.status_label.setStyleSheet("color: #a6e3a1;") + + def reset(self): + self.stack.setCurrentIndex(0) + self.status_label.setText("") + self.status_label.setStyleSheet("") + self.code_status_label.setText("") + self.code_status_label.setStyleSheet("") + self.username_input.clear() + self.email_input.clear() + self.password_input.clear() + self.code_input.clear() + self._set_enabled(True) + self.confirm_btn.setEnabled(True) + self.back_btn.setEnabled(True) + + +class MainWindow(QWidget): + def __init__(self, bridge: AsyncBridge, on_logout): + super().__init__() + self.bridge = bridge + self._on_logout_cb = on_logout + self.setWindowTitle(f"Encrypted Chat - {bridge.client.username}") + self.resize(900, 600) + + self.conversations: list[dict] = [] + self.current_conv_id: str | None = None + self.current_messages: list[dict] = [] + self.reply_to_id: str | None = None + self._unread_counts: dict[str, int] = {} + self._has_more_messages: bool = True + self._pending_image_download: dict | None = None # {file_id, image_info} + self._is_dm: bool = False + self._online_users: set[str] = set() + self._is_logout = False + self._avatar_cache: dict[str, QPixmap] = {} # user_id -> pixmap + self._group_avatar_cache: dict[str, QPixmap] = {} # conv_id -> pixmap + self._avatar_requested: set[str] = set() + self._group_avatar_requested: set[str] = set() + self._pending_invitations: list[dict] = [] + self._favorites: set[str] = self._load_favorites() + # Search state + self._search_results: list[int] = [] # indices into current_messages + self._search_current: int = -1 + self._search_query: str = "" + self._search_active: bool = False + + self._build_ui() + self._connect_signals() + + # Keyboard shortcuts + QShortcut(QKeySequence("Ctrl+F"), self).activated.connect(self._toggle_search) + + self.bridge.load_conversations() + self.bridge.list_invitations() + + # Periodic refresh: re-download avatars and conversation data + self._refresh_timer = QTimer(self) + self._refresh_timer.timeout.connect(self._on_periodic_refresh) + self._refresh_timer.start(120_000) # every 2 minutes + + def _bold_font(self) -> QFont: + """Return a bold font with a valid size (avoids QFont pointSize=-1 warnings).""" + f = QFont(self.conv_list.font()) + f.setBold(True) + # Stylesheet sets font-size in px so pointSize is -1; fix by using pixelSize + if f.pointSize() <= 0: + px = f.pixelSize() + if px > 0: + f.setPixelSize(px) + else: + f.setPointSize(10) + return f + + def _make_circular_avatar(self, pixmap: QPixmap, size: int = 32) -> QPixmap: + """Crop a pixmap into a circle.""" + scaled = pixmap.scaled(size, size, Qt.AspectRatioMode.KeepAspectRatioByExpanding, + Qt.TransformationMode.SmoothTransformation) + result = QPixmap(size, size) + result.fill(QColor(0, 0, 0, 0)) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + painter.setBrush(QBrush(scaled)) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(0, 0, size, size) + painter.end() + return result + + def _make_default_avatar(self, username: str, size: int = 32) -> QPixmap: + """Generate a colored circle with the first letter of the username.""" + # Deterministic color from username + hue = (hash(username) % 360) + color = QColor.fromHsv(hue, 120, 200) + result = QPixmap(size, size) + result.fill(QColor(0, 0, 0, 0)) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + painter.setBrush(QBrush(color)) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(0, 0, size, size) + # Draw letter + painter.setPen(QColor(255, 255, 255)) + font = QFont("Segoe UI", int(size * 0.4)) + font.setBold(True) + painter.setFont(font) + letter = username[0].upper() if username else "?" + painter.drawText(0, 0, size, size, Qt.AlignmentFlag.AlignCenter, letter) + painter.end() + return result + + def _add_online_dot(self, avatar: QPixmap) -> QPixmap: + """Overlay a green dot on the bottom-right of an avatar pixmap.""" + result = QPixmap(avatar) + painter = QPainter(result) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + dot_size = max(8, avatar.width() // 4) + x = avatar.width() - dot_size + y = avatar.height() - dot_size + # Dark border + painter.setBrush(QBrush(QColor(0x1e, 0x1e, 0x2e))) + painter.setPen(Qt.PenStyle.NoPen) + painter.drawEllipse(x - 1, y - 1, dot_size + 2, dot_size + 2) + # Green dot + painter.setBrush(QBrush(QColor(0xa6, 0xe3, 0xa1))) + painter.drawEllipse(x, y, dot_size, dot_size) + painter.end() + return result + + def _get_conv_avatar(self, conv: dict) -> QIcon: + """Get avatar icon for a conversation list item.""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + if is_dm: + other = None + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + other = m + break + if other: + uid = other.get("user_id") or other.get("id") or "" + uname = other.get("username") or other.get("email") or "?" + if uid in self._avatar_cache: + avatar = self._make_circular_avatar(self._avatar_cache[uid]) + else: + avatar = self._make_default_avatar(uname) + # Request avatar download if not yet requested + if uid and uid not in self._avatar_requested: + self._avatar_requested.add(uid) + self.bridge.get_avatar(uid) + if uid in self._online_users: + avatar = self._add_online_dot(avatar) + return QIcon(avatar) + # Group: use group avatar if available + conv_id = conv.get("conversation_id") or "" + if conv_id in self._group_avatar_cache: + return QIcon(self._make_circular_avatar(self._group_avatar_cache[conv_id])) + gname = conv.get("name") or "G" + # Request group avatar download if has avatar_file + if conv.get("avatar_file") and conv_id and conv_id not in self._group_avatar_requested: + self._group_avatar_requested.add(conv_id) + self.bridge.get_group_avatar(conv_id) + return QIcon(self._make_default_avatar(gname)) + + def _build_ui(self): + main_layout = QHBoxLayout(self) + main_layout.setContentsMargins(0, 0, 0, 0) + main_layout.setSpacing(0) + + splitter = QSplitter(Qt.Orientation.Horizontal) + + # Left panel - conversations + left = QWidget() + left_layout = QVBoxLayout(left) + left_layout.setContentsMargins(8, 8, 4, 8) + + header_row = QHBoxLayout() + conv_label = QLabel("Conversations") + conv_label.setStyleSheet("font-weight: bold; font-size: 12pt; color: #89b4fa;") + header_row.addWidget(conv_label) + header_row.addStretch() + + new_chat_btn = QPushButton("") + new_chat_btn.setFixedSize(32, 32) + new_chat_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogNewFolder)) + new_chat_btn.setToolTip("New Chat") + new_chat_btn.clicked.connect(self._on_new_chat) + header_row.addWidget(new_chat_btn) + + group_btn = QPushButton("") + group_btn.setFixedSize(32, 32) + group_btn.setObjectName("secondaryBtn") + group_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DirIcon)) + group_btn.setToolTip("New Group") + group_btn.clicked.connect(self._on_new_group) + header_row.addWidget(group_btn) + + auth_btn = QPushButton("") + auth_btn.setFixedSize(32, 32) + auth_btn.setObjectName("secondaryBtn") + auth_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogApplyButton)) + auth_btn.setToolTip("Authorize Device") + auth_btn.clicked.connect(self._on_authorize_device) + header_row.addWidget(auth_btn) + + rotate_btn = QPushButton("") + rotate_btn.setFixedSize(32, 32) + rotate_btn.setObjectName("secondaryBtn") + rotate_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload)) + rotate_btn.setToolTip("Rotate Keys") + rotate_btn.clicked.connect(self._on_rotate_keys) + header_row.addWidget(rotate_btn) + + profile_btn = QPushButton("") + profile_btn.setFixedSize(32, 32) + profile_btn.setObjectName("secondaryBtn") + profile_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogInfoView)) + profile_btn.setToolTip("My Profile") + profile_btn.clicked.connect(self._on_my_profile) + header_row.addWidget(profile_btn) + + logout_btn = QPushButton("") + logout_btn.setFixedSize(32, 32) + logout_btn.setObjectName("secondaryBtn") + logout_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogCloseButton)) + logout_btn.setToolTip("Logout") + logout_btn.clicked.connect(self._on_logout) + header_row.addWidget(logout_btn) + + left_layout.addLayout(header_row) + + # Invitation section (hidden when empty) + self.inv_label = QLabel("Pending Invitations") + self.inv_label.setStyleSheet("font-weight: bold; font-size: 9pt; color: #f9e2af; margin-top: 4px;") + self.inv_label.setVisible(False) + left_layout.addWidget(self.inv_label) + + self.inv_list = QListWidget() + self.inv_list.setMaximumHeight(120) + self.inv_list.setVisible(False) + self.inv_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.inv_list.customContextMenuRequested.connect(self._on_inv_context_menu) + self.inv_list.setStyleSheet( + "QListWidget { background-color: #1e1e2e; border: 1px solid #f9e2af; border-radius: 6px; padding: 2px; }" + "QListWidget::item { padding: 6px; color: #cdd6f4; }" + "QListWidget::item:hover { background-color: #252536; color: #cdd6f4; }" + ) + left_layout.addWidget(self.inv_list) + + self.conv_list = QListWidget() + from PyQt6.QtCore import QSize + self.conv_list.setIconSize(QSize(32, 32)) + self.conv_list.currentRowChanged.connect(self._on_conv_selected) + self.conv_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.conv_list.customContextMenuRequested.connect(self._on_conv_list_context_menu) + left_layout.addWidget(self.conv_list) + + # Right panel - messages + right = QWidget() + right_layout = QVBoxLayout(right) + right_layout.setContentsMargins(4, 8, 8, 8) + + chat_header_row = QHBoxLayout() + self.chat_header_avatar = QLabel() + self.chat_header_avatar.setFixedSize(28, 28) + self.chat_header_avatar.setVisible(False) + chat_header_row.addWidget(self.chat_header_avatar) + self.chat_header = QLabel("Select a conversation") + self.chat_header.setStyleSheet("font-weight: bold; font-size: 12pt; color: #89b4fa;") + chat_header_row.addWidget(self.chat_header) + + self.connection_dot = QLabel("\u25cf") + self.connection_dot.setFixedSize(16, 16) + self.connection_dot.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.connection_dot.setStyleSheet("color: #a6e3a1; font-size: 11pt;") + self.connection_dot.setToolTip("Connected") + chat_header_row.addWidget(self.connection_dot) + chat_header_row.addStretch() + + self.group_info_btn = QPushButton("") + self.group_info_btn.setFixedSize(32, 32) + self.group_info_btn.setObjectName("secondaryBtn") + self.group_info_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_MessageBoxInformation)) + self.group_info_btn.setToolTip("Group Info") + self.group_info_btn.clicked.connect(self._on_group_info) + self.group_info_btn.setVisible(False) + chat_header_row.addWidget(self.group_info_btn) + + self.user_info_btn = QPushButton("") + self.user_info_btn.setFixedSize(32, 32) + self.user_info_btn.setObjectName("secondaryBtn") + self.user_info_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogInfoView)) + self.user_info_btn.setToolTip("User Info") + self.user_info_btn.clicked.connect(self._on_dm_user_info) + self.user_info_btn.setVisible(False) + chat_header_row.addWidget(self.user_info_btn) + + self.delete_conv_btn = QPushButton("") + self.delete_conv_btn.setFixedSize(32, 32) + self.delete_conv_btn.setStyleSheet( + "QPushButton { background-color: #45475a; color: #f38ba8; border: none; border-radius: 6px; }" + "QPushButton:hover { background-color: #f38ba8; color: #1e1e2e; }" + ) + self.delete_conv_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_TrashIcon)) + self.delete_conv_btn.setToolTip("Delete conversation") + self.delete_conv_btn.clicked.connect(self._on_delete_conv_btn) + self.delete_conv_btn.setVisible(False) + chat_header_row.addWidget(self.delete_conv_btn) + + self.add_member_btn = QPushButton("") + self.add_member_btn.setFixedSize(32, 32) + self.add_member_btn.setObjectName("secondaryBtn") + self.add_member_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogNewFolder)) + self.add_member_btn.setToolTip("Add Member") + self.add_member_btn.clicked.connect(self._on_add_member) + self.add_member_btn.setVisible(False) + chat_header_row.addWidget(self.add_member_btn) + + self.search_btn = QPushButton("") + self.search_btn.setFixedSize(32, 32) + self.search_btn.setObjectName("secondaryBtn") + self.search_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogContentsView)) + self.search_btn.setToolTip("Search messages (Ctrl+F)") + self.search_btn.clicked.connect(self._toggle_search) + self.search_btn.setVisible(False) + chat_header_row.addWidget(self.search_btn) + + right_layout.addLayout(chat_header_row) + + # Search bar (hidden by default) + self.search_widget = QWidget() + search_row = QHBoxLayout(self.search_widget) + search_row.setContentsMargins(0, 2, 0, 2) + self.search_input = QLineEdit() + self.search_input.setPlaceholderText("Search messages...") + self.search_input.setStyleSheet( + "QLineEdit { background-color: #313244; color: #cdd6f4; border: 1px solid #45475a; " + "border-radius: 4px; padding: 4px 8px; font-size: 10pt; }" + ) + self.search_input.textChanged.connect(self._on_search_text_changed) + self.search_input.returnPressed.connect(self._on_search_next) + # Escape in search input closes search + QShortcut(QKeySequence("Escape"), self.search_input).activated.connect(self._close_search) + search_row.addWidget(self.search_input, stretch=1) + self.search_prev_btn = QPushButton("\u25b2") + self.search_prev_btn.setFixedSize(28, 28) + self.search_prev_btn.setObjectName("secondaryBtn") + self.search_prev_btn.setToolTip("Previous match") + self.search_prev_btn.clicked.connect(self._on_search_prev) + search_row.addWidget(self.search_prev_btn) + self.search_next_btn = QPushButton("\u25bc") + self.search_next_btn.setFixedSize(28, 28) + self.search_next_btn.setObjectName("secondaryBtn") + self.search_next_btn.setToolTip("Next match") + self.search_next_btn.clicked.connect(self._on_search_next) + search_row.addWidget(self.search_next_btn) + self.search_count_label = QLabel("0/0") + self.search_count_label.setStyleSheet("color: #6c7086; font-size: 9pt; min-width: 40px;") + self.search_count_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + search_row.addWidget(self.search_count_label) + self.search_close_btn = QPushButton("\u2715") + self.search_close_btn.setFixedSize(28, 28) + self.search_close_btn.setObjectName("secondaryBtn") + self.search_close_btn.setToolTip("Close search") + self.search_close_btn.clicked.connect(self._close_search) + search_row.addWidget(self.search_close_btn) + self.search_widget.setVisible(False) + right_layout.addWidget(self.search_widget) + + self.load_more_btn = QPushButton("Load older messages") + self.load_more_btn.setObjectName("secondaryBtn") + self.load_more_btn.clicked.connect(self._on_load_more) + self.load_more_btn.setVisible(False) + right_layout.addWidget(self.load_more_btn) + + self.message_area = QTextBrowser() + self.message_area.setReadOnly(True) + self.message_area.setOpenExternalLinks(False) + self.message_area.setOpenLinks(False) + self.message_area.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.message_area.customContextMenuRequested.connect(self._on_message_context_menu) + self.message_area.anchorClicked.connect(self._on_anchor_clicked) + right_layout.addWidget(self.message_area, stretch=1) + + # Smart scroll: track if user is near bottom + self._is_near_bottom = True + self.message_area.verticalScrollBar().valueChanged.connect(self._on_scroll_changed) + + # "New messages" floating button (hidden by default) + self.jump_btn = QPushButton("New messages \u2193") + self.jump_btn.setParent(self.message_area) + self.jump_btn.setVisible(False) + self.jump_btn.setFixedHeight(28) + self.jump_btn.setStyleSheet( + "QPushButton { background-color: #89b4fa; color: #1e1e2e; border-radius: 14px; " + "padding: 0 16px; font-size: 8pt; font-weight: bold; }" + "QPushButton:hover { background-color: #74c7ec; }" + ) + self.jump_btn.clicked.connect(self._scroll_to_bottom) + + self.reply_label = QLabel("") + self.reply_label.setStyleSheet("color: #89b4fa; font-style: italic; padding: 2px 4px;") + self.reply_label.setVisible(False) + right_layout.addWidget(self.reply_label) + + # Input row + input_row = QHBoxLayout() + self.msg_input = MessageInput() + self.msg_input.send_requested.connect(self._on_send) + self.msg_input.textChanged.connect(self._on_input_changed) + input_row.addWidget(self.msg_input) + + attach_btn = QPushButton("") + attach_btn.setFixedSize(52, 72) + attach_btn.setObjectName("secondaryBtn") + attach_btn.setIconSize(QSize(24, 24)) + attach_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_FileIcon)) + attach_menu = QMenu(attach_btn) + attach_menu.addAction("Image", self._on_attach_image) + attach_menu.addAction("File", self._on_attach_file) + attach_btn.setMenu(attach_menu) + input_row.addWidget(attach_btn) + + send_btn = QPushButton("Send") + send_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_ArrowRight)) + send_btn.clicked.connect(self._on_send) + send_btn.setFixedHeight(72) + input_row.addWidget(send_btn) + right_layout.addLayout(input_row) + + self.char_counter = QLabel(f"0 / {MAX_INPUT_CHARS}") + self.char_counter.setStyleSheet("color: #6c7086; font-size: 8pt; padding: 0 4px;") + self.char_counter.setAlignment(Qt.AlignmentFlag.AlignRight) + right_layout.addWidget(self.char_counter) + + self.reencrypt_label = QLabel("") + self.reencrypt_label.setStyleSheet( + "background-color: #313244; border-radius: 6px; " + "padding: 8px 12px; color: #a6e3a1; font-weight: bold;" + ) + self.reencrypt_label.setVisible(False) + right_layout.addWidget(self.reencrypt_label) + + splitter.addWidget(left) + splitter.addWidget(right) + splitter.setStretchFactor(0, 1) + splitter.setStretchFactor(1, 3) + + # Wrap splitter + status bar in vertical layout for full-width status bar + wrapper = QVBoxLayout() + wrapper.setContentsMargins(0, 0, 0, 0) + wrapper.setSpacing(0) + wrapper.addWidget(splitter) + + # Status bar (permanent, fixed height, full width — no layout jumping) + self.status_bar = QLabel("") + self.status_bar.setFixedHeight(24) + self.status_bar.setStyleSheet( + "background-color: #181825; border-radius: 0px; " + "padding: 0 8px; color: #a6e3a1; font-size: 8pt;" + ) + self.status_bar.setCursor(Qt.CursorShape.PointingHandCursor) + self.status_bar.mousePressEvent = self._on_status_bar_click + self._status_bar_conv_id = None + wrapper.addWidget(self.status_bar) + + main_layout.addLayout(wrapper) + + def _connect_signals(self): + self.bridge.conversations_loaded.connect(self._on_conversations_loaded) + self.bridge.messages_loaded.connect(self._on_messages_loaded) + self.bridge.older_messages_loaded.connect(self._on_older_messages_loaded) + self.bridge.message_sent.connect(self._on_message_sent) + self.bridge.new_notification.connect(self._on_notification) + self.bridge.add_member_result.connect(self._on_add_member_result) + self.bridge.authorize_result.connect(self._on_authorize_result) + self.bridge.rotate_result.connect(self._on_rotate_result) + self.bridge.reencrypt_status.connect(self._on_reencrypt_status) + self.bridge.messages_read_notification.connect(self._on_messages_read) + self.bridge.remove_member_result.connect(self._on_remove_member_result) + self.bridge.message_deleted_notification.connect(self._on_message_deleted) + self.bridge.delete_message_result.connect(self._on_delete_message_result) + self.bridge.image_sent.connect(self._on_image_sent) + self.bridge.image_downloaded.connect(self._on_image_downloaded) + self.bridge.file_sent.connect(self._on_file_sent) + self.bridge.file_downloaded.connect(self._on_file_downloaded) + self.bridge.conversation_updated.connect(self._on_conversation_updated) + self.bridge.connection_state_changed.connect(self._on_connection_state_changed) + self.bridge.group_left.connect(self._on_group_left) + self.bridge.group_renamed.connect(self._on_group_renamed) + self.bridge.conversation_deleted.connect(self._on_conversation_deleted) + self.bridge.avatar_loaded.connect(self._on_avatar_for_conv_list) + self.bridge.invitations_loaded.connect(self._on_invitations_loaded) + self.bridge.invitation_result.connect(self._on_invitation_result) + self.bridge.invitation_received.connect(self._on_invitation_received) + self.bridge.online_status_changed.connect(self._on_online_status_changed) + self.bridge.online_users_loaded.connect(self._on_online_users_loaded) + self.bridge.group_avatar_loaded.connect(self._on_group_avatar_for_conv_list) + self.bridge.group_avatar_updated.connect(self._on_group_avatar_updated) + self.bridge.session_reset_notification.connect(self._on_session_reset) + + # ------------------------------------------------------------------ + # Favorites + # ------------------------------------------------------------------ + + def _favorites_path(self): + from chat_core import get_key_dir + return get_key_dir(self.bridge.client.email) / "favorites.json" + + def _load_favorites(self) -> set[str]: + try: + p = self._favorites_path() + if p.exists(): + return set(json.loads(p.read_text())) + except Exception: + pass + return set() + + def _save_favorites(self): + try: + self._favorites_path().write_text(json.dumps(list(self._favorites))) + except Exception: + pass + + def _on_conv_list_context_menu(self, pos): + item = self.conv_list.itemAt(pos) + if not item: + return + conv_id = item.data(Qt.ItemDataRole.UserRole) + if not conv_id: + return + from PyQt6.QtWidgets import QMenu + menu = QMenu(self) + is_fav = conv_id in self._favorites + action = menu.addAction("Odebrat z oblibených" if is_fav else "Přidat do oblíbených") + result = menu.exec(self.conv_list.mapToGlobal(pos)) + if result == action: + if is_fav: + self._favorites.discard(conv_id) + else: + self._favorites.add(conv_id) + self._save_favorites() + self._rebuild_conv_list() + + # ------------------------------------------------------------------ + # Conversation list helpers + # ------------------------------------------------------------------ + + def _get_conv_display_name(self, conv: dict) -> str: + """Get display name for a conversation (used for sorting and labels).""" + others = [m.get("username") or m.get("email") or "?" for m in conv["members"] + if m.get("email") != self.bridge.client.email] + return conv.get("name") or (", ".join(others) if others else self.bridge.client.username) + + def _get_conv_other_user_id(self, conv: dict) -> str: + """Get the other user's ID in a DM conversation (empty string for groups).""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + if not is_dm: + return "" + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + return m.get("user_id") or m.get("id") or "" + return "" + + def _get_conv_sort_key(self, conv: dict) -> tuple: + """Sort key: favorites first, then online DMs, then rest — alphabetically within each.""" + conv_id = conv.get("conversation_id", "") + is_fav = 0 if conv_id in self._favorites else 1 + other_uid = self._get_conv_other_user_id(conv) + is_online = 0 if other_uid and other_uid in self._online_users else 1 + name = self._get_conv_display_name(conv).lower() + return (is_fav, is_online, name) + + def _on_conversations_loaded(self, convs): + self.conversations = convs + # Populate unread counts from server (covers messages received while offline) + for c in convs: + cid = c["conversation_id"] + server_unread = c.get("unread_count", 0) + # Use the higher of server vs local (local may have newer real-time notifications) + if server_unread > self._unread_counts.get(cid, 0): + self._unread_counts[cid] = server_unread + self._rebuild_conv_list() + + def _rebuild_conv_list(self): + """Sort and rebuild the conversation list widget.""" + if not self.conversations: + return + # Sort: unread first, then online DMs, then rest — alphabetically within each group + self.conversations.sort(key=self._get_conv_sort_key) + prev_id = self.current_conv_id + self.conv_list.blockSignals(True) + self.conv_list.clear() + select_row = -1 + for i, c in enumerate(self.conversations): + conv_id = c["conversation_id"] + base_label = self._get_conv_display_name(c) + star = "\u2605 " if conv_id in self._favorites else "" + count = self._unread_counts.get(conv_id, 0) + label = f"{star}({count}) {base_label}" if count > 0 else f"{star}{base_label}" + item = QListWidgetItem(label) + item.setData(Qt.ItemDataRole.UserRole, conv_id) + item.setIcon(self._get_conv_avatar(c)) + if count > 0: + item.setData(Qt.ItemDataRole.FontRole, self._bold_font()) + item.setForeground(Qt.GlobalColor.white) + self.conv_list.addItem(item) + if conv_id == prev_id: + select_row = i + self.conv_list.blockSignals(False) + if select_row >= 0: + self.conv_list.setCurrentRow(select_row) + + def _on_conversation_updated(self): + """Refresh conversation list when a conversation is created/member added/removed.""" + self.bridge.load_conversations() + + def _on_periodic_refresh(self): + """Periodic refresh: re-download avatars for known users and reload invitations.""" + # Re-request avatars for all cached users (server returns latest) + for uid in list(self._avatar_requested): + self.bridge.get_avatar(uid) + # Re-request group avatars + for conv_id in list(self._group_avatar_requested): + self.bridge.get_group_avatar(conv_id) + self.bridge.list_invitations() + + def _on_online_users_loaded(self, user_ids): + self._online_users = set(user_ids) + self._rebuild_conv_list() + + def _on_online_status_changed(self, user_id, is_online): + if is_online: + self._online_users.add(user_id) + else: + self._online_users.discard(user_id) + self._rebuild_conv_list() + + def _on_avatar_for_conv_list(self, user_id, data): + """Cache downloaded avatar and refresh conversation list icons + chat header.""" + qimg = _safe_load_image(data) + if qimg is not None: + self._avatar_cache[user_id] = QPixmap.fromImage(qimg) + self._update_conv_list_styles() + # Refresh chat header avatar if current conv uses this user's avatar + self._refresh_chat_header_avatar() + + def _on_group_avatar_for_conv_list(self, conv_id, data): + """Cache downloaded group avatar and refresh conversation list icons + chat header.""" + qimg = _safe_load_image(data) + if qimg is not None: + self._group_avatar_cache[conv_id] = QPixmap.fromImage(qimg) + self._update_conv_list_styles() + # Refresh chat header avatar if current conv is this group + self._refresh_chat_header_avatar() + + def _on_group_avatar_updated(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Group Avatar", msg) + + def _on_invitations_loaded(self, invitations): + self._pending_invitations = invitations + self.inv_list.clear() + if not invitations: + self.inv_label.setVisible(False) + self.inv_list.setVisible(False) + return + self.inv_label.setVisible(True) + self.inv_list.setVisible(True) + for inv in invitations: + conv_name = inv.get("conversation_name") or "Unnamed group" + inviter = inv.get("invited_by_username", "someone") + label = f"{conv_name} (from {inviter})" + item = QListWidgetItem(label) + item.setData(Qt.ItemDataRole.UserRole, inv["conversation_id"]) + self.inv_list.addItem(item) + + def _on_invitation_result(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Invitation", msg) + + def _on_invitation_received(self, data): + """New invitation received via push notification.""" + self.bridge.list_invitations() + conv_name = data.get("conversation_name") or "a group" + inviter = data.get("invited_by_username", "Someone") + self.status_bar.setText(f"{inviter} invited you to {conv_name}") + self.status_bar.setStyleSheet( + "background-color: #181825; border-radius: 0px; " + "padding: 0 8px; color: #f9e2af; font-size: 8pt; font-weight: bold;" + ) + self._status_bar_conv_id = None + QTimer.singleShot(5000, self._clear_status_bar) + + def _on_inv_context_menu(self, pos): + item = self.inv_list.itemAt(pos) + if not item: + return + conv_id = item.data(Qt.ItemDataRole.UserRole) + if not conv_id: + return + menu = QMenu(self) + accept_action = menu.addAction("Accept") + decline_action = menu.addAction("Decline") + chosen = menu.exec(self.inv_list.mapToGlobal(pos)) + if chosen == accept_action: + self.bridge.accept_invitation(conv_id) + elif chosen == decline_action: + self.bridge.decline_invitation(conv_id) + + def _on_connection_state_changed(self, state): + if state == "connected": + self.connection_dot.setStyleSheet("color: #a6e3a1; font-size: 11pt;") + self.connection_dot.setToolTip("Connected") + self.status_bar.setText("Connected") + self.status_bar.setStyleSheet( + "background-color: #181825; border-radius: 0px; " + "padding: 0 8px; color: #a6e3a1; font-size: 8pt;" + ) + QTimer.singleShot(3000, self._clear_status_bar) + elif state == "disconnected": + self.connection_dot.setStyleSheet("color: #f38ba8; font-size: 11pt;") + self.connection_dot.setToolTip("Disconnected") + self.status_bar.setText("Disconnected from server") + self.status_bar.setStyleSheet( + "background-color: #181825; border-radius: 0px; " + "padding: 0 8px; color: #f38ba8; font-size: 8pt; font-weight: bold;" + ) + self._status_bar_conv_id = None + elif state == "reconnecting": + self.connection_dot.setStyleSheet("color: #fab387; font-size: 11pt;") + self.connection_dot.setToolTip("Reconnecting...") + self.status_bar.setText("Reconnecting...") + self.status_bar.setStyleSheet( + "background-color: #181825; border-radius: 0px; " + "padding: 0 8px; color: #fab387; font-size: 8pt;" + ) + self._status_bar_conv_id = None + elif state == "revoked": + self.connection_dot.setStyleSheet("color: #f38ba8; font-size: 11pt;") + self.connection_dot.setToolTip("Access revoked") + # Clear conversation list + self.conv_list.clear() + self.conversations = [] + self._unread_counts.clear() + # Clear open conversation + self.current_conv_id = None + self.chat_header.setText("Select a conversation") + self.chat_header_avatar.setVisible(False) + self.message_area.clear() + self.msg_input.setEnabled(False) + self.group_info_btn.setVisible(False) + self.user_info_btn.setVisible(False) + self.add_member_btn.setVisible(False) + self.delete_conv_btn.setVisible(False) + QMessageBox.warning(self, "Access Revoked", + "Your keys were rotated on another device. " + "This session is no longer valid.") + + def _on_scroll_changed(self, value): + sb = self.message_area.verticalScrollBar() + self._is_near_bottom = (sb.maximum() - value) < 60 + if self._is_near_bottom: + self.jump_btn.setVisible(False) + + def _scroll_to_bottom(self): + sb = self.message_area.verticalScrollBar() + sb.setValue(sb.maximum()) + self.jump_btn.setVisible(False) + + def _position_jump_btn(self): + w = self.message_area.width() + btn_w = self.jump_btn.sizeHint().width() + self.jump_btn.move((w - btn_w) // 2, self.message_area.height() - 40) + + def _clear_status_bar(self): + self.status_bar.setText("") + self.status_bar.setStyleSheet( + "background-color: #181825; border-radius: 0px; " + "padding: 0 8px; color: #a6e3a1; font-size: 8pt;" + ) + self._status_bar_conv_id = None + + def _on_status_bar_click(self, event): + conv_id = self._status_bar_conv_id + if conv_id: + for i, c in enumerate(self.conversations): + if c["conversation_id"] == conv_id: + self.conv_list.setCurrentRow(i) + self._clear_status_bar() + break + + def _update_chat_header_avatar(self, conv): + """Set the circular avatar next to the conversation name in the chat header.""" + is_dm = len(conv["members"]) == 2 and not conv.get("name") + size = 28 + if is_dm: + other = None + for m in conv["members"]: + if m.get("email") != self.bridge.client.email: + other = m + break + if other: + uid = other.get("user_id") or other.get("id") or "" + uname = other.get("username") or other.get("email") or "?" + if uid in self._avatar_cache: + avatar = self._make_circular_avatar(self._avatar_cache[uid], size) + else: + avatar = self._make_default_avatar(uname, size) + self.chat_header_avatar.setPixmap(avatar) + self.chat_header_avatar.setVisible(True) + else: + self.chat_header_avatar.setVisible(False) + else: + conv_id = conv.get("conversation_id") or "" + gname = conv.get("name") or "G" + if conv_id in self._group_avatar_cache: + avatar = self._make_circular_avatar(self._group_avatar_cache[conv_id], size) + else: + avatar = self._make_default_avatar(gname, size) + self.chat_header_avatar.setPixmap(avatar) + self.chat_header_avatar.setVisible(True) + + def _refresh_chat_header_avatar(self): + """Re-render chat header avatar for the currently selected conversation.""" + if not self.current_conv_id: + return + for c in self.conversations: + if c["conversation_id"] == self.current_conv_id: + self._update_chat_header_avatar(c) + return + + def _update_conv_list_styles(self): + for i in range(self.conv_list.count()): + item = self.conv_list.item(i) + conv_id = item.data(Qt.ItemDataRole.UserRole) + count = self._unread_counts.get(conv_id, 0) + others = [] + conv_name = None + conv = None + for c in self.conversations: + if c["conversation_id"] == conv_id: + others = [m.get("username") or m.get("email") or "?" for m in c["members"] + if m.get("email") != self.bridge.client.email] + conv_name = c.get("name") + conv = c + break + base_label = conv_name or (", ".join(others) if others else self.bridge.client.username) + star = "\u2605 " if conv_id in self._favorites else "" + item.setText(f"{star}({count}) {base_label}" if count > 0 else f"{star}{base_label}") + if conv: + item.setIcon(self._get_conv_avatar(conv)) + if count > 0: + item.setData(Qt.ItemDataRole.FontRole, self._bold_font()) + else: + item.setData(Qt.ItemDataRole.FontRole, None) + + def _on_conv_selected(self, row): + if row < 0 or row >= len(self.conversations): + return + conv = self.conversations[row] + self.current_conv_id = conv["conversation_id"] + others = [m.get("username") or m.get("email") or "?" for m in conv["members"] + if m.get("email") != self.bridge.client.email] + header = conv.get("name") or (", ".join(others) if others else self.bridge.client.username) + self.chat_header.setText(header) + # Set avatar in chat header + self._update_chat_header_avatar(conv) + is_group = len(conv["members"]) > 2 or conv.get("name") + self._is_dm = not is_group + self.add_member_btn.setVisible(bool(is_group)) + self.group_info_btn.setVisible(bool(is_group)) + self.user_info_btn.setVisible(self._is_dm) + # DMs: always show delete. Groups: only show for creator. + if self._is_dm: + self.delete_conv_btn.setVisible(True) + else: + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + self.delete_conv_btn.setVisible(conv.get("created_by") == my_user_id) + self.reply_to_id = None + self.reply_label.setVisible(False) + self._has_more_messages = True + self.load_more_btn.setVisible(False) + self._unread_counts.pop(self.current_conv_id, None) + self._update_conv_list_styles() + self.search_btn.setVisible(True) + self._close_search() + self.bridge.load_messages(self.current_conv_id) + + def _on_messages_loaded(self, conv_id, messages): + if conv_id != self.current_conv_id: + return + self.current_messages = messages + # Show "Load older" if we got a full batch (there may be more) + self._has_more_messages = len(messages) >= 50 + self.load_more_btn.setVisible(self._has_more_messages) + self._render_messages() + + def _render_messages(self, scroll_to_bottom=True): + html_parts = [] + for i, m in enumerate(self.current_messages): + html_parts.append(self._render_single_message_html(m, i)) + + self.message_area.setHtml("".join(html_parts)) + # Register thumbnail images as document resources + self._register_thumbnails() + # Re-set HTML so images resolve + self.message_area.setHtml("".join(html_parts)) + if scroll_to_bottom: + sb = self.message_area.verticalScrollBar() + sb.setValue(sb.maximum()) + + def _register_thumbnails(self): + """Add thumbnail images as resources to the QTextDocument.""" + from PyQt6.QtCore import QUrl + doc = self.message_area.document() + for m in self.current_messages: + image_info = m.get("image") + if not image_info: + continue + thumbnail_b64 = image_info.get("thumbnail", "") + file_id = image_info.get("file_id", "") + if not thumbnail_b64 or not file_id: + continue + from protocol import decode_binary + try: + thumb_bytes = decode_binary(thumbnail_b64) + qimg = _safe_load_image(thumb_bytes) + if qimg is not None: + url = QUrl(f"thumb://{file_id}") + doc.addResource( + doc.ResourceType.ImageResource.value, + url, + qimg, + ) + except Exception: + pass + + def _register_single_thumbnail(self, m): + """Register a single message's thumbnail as a document resource.""" + image_info = m.get("image") + if not image_info: + return + thumbnail_b64 = image_info.get("thumbnail", "") + file_id = image_info.get("file_id", "") + if not thumbnail_b64 or not file_id: + return + from PyQt6.QtCore import QUrl + from protocol import decode_binary + try: + thumb_bytes = decode_binary(thumbnail_b64) + qimg = _safe_load_image(thumb_bytes) + if qimg is not None: + doc = self.message_area.document() + url = QUrl(f"thumb://{file_id}") + doc.addResource( + doc.ResourceType.ImageResource.value, + url, + qimg, + ) + except Exception: + pass + + def _on_older_messages_loaded(self, conv_id, messages): + if conv_id != self.current_conv_id: + return + if not messages: + self._has_more_messages = False + self.load_more_btn.setVisible(False) + return + self._has_more_messages = len(messages) >= 50 + self.load_more_btn.setVisible(self._has_more_messages) + # Prepend older messages and re-render + self.current_messages = messages + self.current_messages + self._render_messages(scroll_to_bottom=False) + + def _on_load_more(self): + if not self.current_conv_id or not self._has_more_messages: + return + offset = len(self.current_messages) + self.bridge.load_older_messages(self.current_conv_id, offset) + + def _render_single_message_html(self, m, index): + """Render HTML for a single message.""" + is_dm = self._is_dm + + # Handle deleted messages + if m.get("deleted"): + timestamp = m.get("created_at", "") + time_str = "" + if timestamp: + time_short = timestamp[11:16] if len(timestamp) >= 16 else timestamp + time_str = f' \u00b7 {time_short}' + prefix = "" + is_me = m.get("sender") == self.bridge.client.username or m.get("sender_id") == ( + self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "") + if is_me: + del_align = "right" + del_border = "border-right:3px solid #6c7086; border-left:none;" + else: + del_align = "left" + del_border = "border-left:3px solid #6c7086;" + return ( + f'
' + f'' + f'' + f'{prefix}{time_str} Zpr\u00e1va byla smaz\u00e1na' + f'
' + ) + + sender = m.get("sender", "???") + text = m.get("text", "") + timestamp = m.get("created_at", "") + text = _linkify_urls(text) + text = text.replace("\n", "
") + # Search highlighting + if self._search_active and self._search_query and index in self._search_results: + is_current_match = (self._search_current >= 0 + and self._search_current < len(self._search_results) + and self._search_results[self._search_current] == index) + bg_color = "#fab387" if is_current_match else "#f9e2af" + text = self._highlight_search_text(text, self._search_query, bg_color) + sender_esc = sender.replace("&", "&").replace("<", "<").replace(">", ">") + + is_me = sender == self.bridge.client.username + if is_me: + color = "#89b4fa" + bg = "#1e1e3a" + border = "border-right:3px solid #89b4fa; border-left:none;" + align = "right" + else: + color = "#f9e2af" + bg = "#1e1e2e" + border = "border-left:3px solid #f9e2af; border-right:none;" + align = "left" + + reply_html = "" + if m.get("reply_to"): + for orig in self.current_messages: + if orig["message_id"] == m["reply_to"]: + orig_sender = orig.get("sender", "???") + orig_sender_esc = orig_sender.replace("&", "&").replace("<", "<").replace(">", ">") + orig_text = orig.get("text", "")[:50] + orig_text = orig_text.replace("&", "&").replace("<", "<").replace(">", ">") + reply_html = ( + f'
' + f'{orig_sender_esc}
{orig_text}' + f'
' + ) + break + + time_str = "" + if timestamp: + time_short = timestamp[11:16] if len(timestamp) >= 16 else timestamp + time_str = f' \u00b7 {time_short}' + + # Image rendering + image_html = "" + image_info = m.get("image") + if image_info: + thumbnail_b64 = image_info.get("thumbnail", "") + filename = image_info.get("filename", "image") + filename_esc = filename.replace("&", "&").replace("<", "<").replace(">", ">") + size_bytes = image_info.get("size", 0) + size_kb = size_bytes / 1024 + if size_kb >= 1024: + size_str = f"{size_kb/1024:.1f} MB" + else: + size_str = f"{size_kb:.0f} KB" + file_id = image_info.get("file_id", "") + if thumbnail_b64: + image_html = ( + f'
' + f'' + f'' + f'
' + f'' + ) + # Clear text if it's just the placeholder + if text == "[Image: " + filename_esc + "]": + text = "" + + # File rendering + file_html = "" + file_info = m.get("file") + if file_info: + fname = file_info.get("filename", "file") + fname_esc = fname.replace("&", "&").replace("<", "<").replace(">", ">") + fsize = file_info.get("size", 0) + size_str = self._human_file_size(fsize) + f_id = file_info.get("file_id", "") + icon = self._file_icon(fname) + file_html = ( + f'
' + f'{icon} {fname_esc}' + f' ({size_str})' + f'
' + ) + + read_html = "" + if is_me: + read_by = m.get("read_by", []) + member_map = {} + for c in self.conversations: + if c["conversation_id"] == self.current_conv_id: + for mem in c["members"]: + uid = mem.get("user_id") or mem.get("id") + if uid: + member_map[uid] = mem.get("username") or mem.get("email") or "?" + break + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + others_read = [r for r in read_by if r.get("user_id") != my_user_id] + if others_read: + names = ", ".join(member_map.get(r["user_id"], r["user_id"][:8]) for r in others_read) + read_html = f'
\u2713\u2713 Read by {names}
' + else: + read_html = '
\u2713 Sent
' + + text_html = f'{text}' if text else '' + + # DM: no sender name, no message number — just timestamp + if is_dm: + header_line = f'{time_str.strip()}' + else: + header_line = f'{sender_esc}{time_str}' + + return ( + f'
' + f'' + f'' + f'{reply_html}' + f'{header_line}
' + f'{text_html}' + f'{image_html}' + f'{file_html}' + f'{read_html}' + f'
' + ) + + def _find_message_at_pos(self, pos): + """Find the message index at the given position in message_area.""" + cursor = self.message_area.cursorForPosition(pos) + # Walk backwards through blocks to find the nearest msg: anchor + block = cursor.block() + while block.isValid(): + frag_it = block.begin() + while frag_it != block.end(): + frag = frag_it.fragment() + if frag.isValid(): + fmt = frag.charFormat() + anchor = fmt.anchorNames() + if anchor: + for a in anchor: + if a.startswith("msg:"): + try: + return int(a[4:]) + except ValueError: + pass + frag_it += 1 + block = block.previous() + return None + + def _on_message_context_menu(self, pos): + if not self.current_messages: + return + idx = self._find_message_at_pos(pos) + if idx is None or idx < 0 or idx >= len(self.current_messages): + return + m = self.current_messages[idx] + if m.get("deleted"): + return + + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + menu = QMenu(self) + menu.setStyleSheet( + "QMenu { background-color: #313244; border: 1px solid #45475a; border-radius: 6px; padding: 4px; }" + "QMenu::item { padding: 6px 12px; color: #cdd6f4; }" + "QMenu::item:selected { background-color: #45475a; }" + ) + + reply_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_ArrowBack) + reply_action = menu.addAction(reply_icon, "Reply") + + # Delete option for own messages + del_action = None + if m.get("sender_id") == my_user_id: + del_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_TrashIcon) + del_action = menu.addAction(del_icon, "Delete") + + # View image option + img_action = None + if m.get("image"): + img_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_FileDialogContentsView) + img_action = menu.addAction(img_icon, "View image") + + # Download file option + file_action = None + if m.get("file"): + file_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_DialogSaveButton) + file_action = menu.addAction(file_icon, "Download file") + + # Reset session option for undecryptable messages + reset_action = None + if m.get("text", "").startswith("[Decryption failed"): + reset_icon = self.style().standardIcon(QStyle.StandardPixmap.SP_BrowserReload) + reset_action = menu.addAction(reset_icon, "Reset session with sender") + + chosen = menu.exec(self.message_area.mapToGlobal(pos)) + if not chosen: + return + if chosen == reply_action: + self.reply_to_id = m["message_id"] + sender = m.get("sender", "???") + preview = m.get("text", "")[:40] + self.reply_label.setText(f"Replying to {sender}: {preview}") + self.reply_label.setVisible(True) + self.msg_input.setFocus() + elif chosen == del_action: + confirm = QMessageBox.question( + self, "Delete Message", + "Delete this message? This cannot be undone.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + self.bridge.delete_message(m["message_id"]) + elif chosen == img_action: + self._view_image(m) + elif chosen == file_action: + file_info = m.get("file") + if file_info: + self.bridge.download_file(file_info["file_id"], file_info) + elif chosen == reset_action: + sender_id = m.get("sender_id", "") + if sender_id: + confirm = QMessageBox.question( + self, "Reset Session", + "Reset encryption session with this sender? " + "A new session will be created on the next message.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + self.bridge.reset_session(sender_id) + + # ------------------------------------------------------------------ + # Search + # ------------------------------------------------------------------ + + def _toggle_search(self): + if self._search_active: + self._close_search() + else: + if not self.current_conv_id: + return + self._search_active = True + self.search_widget.setVisible(True) + self.search_input.setFocus() + self.search_input.selectAll() + + def _close_search(self): + self._search_active = False + self._search_query = "" + self._search_results = [] + self._search_current = -1 + self.search_widget.setVisible(False) + self.search_input.clear() + self.search_count_label.setText("0/0") + # Re-render to remove highlights + if self.current_messages: + self._render_messages(scroll_to_bottom=False) + + def _on_search_text_changed(self, text): + self._search_query = text.strip() + if not self._search_query: + self._search_results = [] + self._search_current = -1 + self.search_count_label.setText("0/0") + self._render_messages(scroll_to_bottom=False) + return + query_lower = self._search_query.lower() + self._search_results = [] + for i, m in enumerate(self.current_messages): + if m.get("deleted"): + continue + msg_text = m.get("text", "") + if query_lower in msg_text.lower(): + self._search_results.append(i) + if self._search_results: + self._search_current = 0 + self.search_count_label.setText(f"1/{len(self._search_results)}") + else: + self._search_current = -1 + self.search_count_label.setText("0/0") + self._render_messages(scroll_to_bottom=False) + if self._search_results: + self._scroll_to_message(self._search_results[self._search_current]) + + def _on_search_next(self): + if not self._search_results: + return + self._search_current = (self._search_current + 1) % len(self._search_results) + self.search_count_label.setText(f"{self._search_current + 1}/{len(self._search_results)}") + self._render_messages(scroll_to_bottom=False) + self._scroll_to_message(self._search_results[self._search_current]) + + def _on_search_prev(self): + if not self._search_results: + return + self._search_current = (self._search_current - 1) % len(self._search_results) + self.search_count_label.setText(f"{self._search_current + 1}/{len(self._search_results)}") + self._render_messages(scroll_to_bottom=False) + self._scroll_to_message(self._search_results[self._search_current]) + + def _scroll_to_message(self, index): + self.message_area.scrollToAnchor(f"msg:{index}") + + @staticmethod + def _highlight_search_text(html_text: str, query: str, bg_color: str) -> str: + """Highlight matching text in HTML, skipping content inside tags.""" + query_esc = query.replace("&", "&").replace("<", "<").replace(">", ">") + if not query_esc: + return html_text + result = [] + i = 0 + q_lower = query_esc.lower() + q_len = len(query_esc) + while i < len(html_text): + if html_text[i] == '<': + # Skip HTML tags + end = html_text.find('>', i) + if end == -1: + result.append(html_text[i:]) + break + result.append(html_text[i:end + 1]) + i = end + 1 + else: + # Look for match in text content + chunk_end = html_text.find('<', i) + if chunk_end == -1: + chunk_end = len(html_text) + chunk = html_text[i:chunk_end] + # Case-insensitive replace within this chunk + chunk_lower = chunk.lower() + out = [] + j = 0 + while j < len(chunk): + pos = chunk_lower.find(q_lower, j) + if pos == -1: + out.append(chunk[j:]) + break + out.append(chunk[j:pos]) + matched = chunk[pos:pos + q_len] + out.append(f'{matched}') + j = pos + q_len + result.append("".join(out)) + i = chunk_end + return "".join(result) + + # ------------------------------------------------------------------ + # Session reset + # ------------------------------------------------------------------ + + def _on_session_reset(self, from_user_id, from_device_id): + # Find username for the user + username = from_user_id[:8] + for c in self.conversations: + for m in c.get("members", []): + uid = m.get("user_id") or m.get("id") + if uid == from_user_id: + username = m.get("username") or m.get("email") or username + break + self.status_bar.setText(f"Session with {username} was reset. New session will be created on next message.") + self.status_bar.setStyleSheet( + "background-color: #181825; border-radius: 0px; " + "padding: 0 8px; color: #f9e2af; font-size: 8pt; font-weight: bold;" + ) + QTimer.singleShot(8000, self._clear_status_bar) + + def _on_input_changed(self): + text = self.msg_input.toPlainText() + count = len(text) + if count > MAX_INPUT_CHARS: + cursor = self.msg_input.textCursor() + self.msg_input.setPlainText(text[:MAX_INPUT_CHARS]) + cursor.movePosition(cursor.MoveOperation.End) + self.msg_input.setTextCursor(cursor) + count = MAX_INPUT_CHARS + color = "#f38ba8" if count > MAX_INPUT_CHARS * 0.9 else "#6c7086" + self.char_counter.setStyleSheet(f"color: {color}; font-size: 8pt; padding: 0 4px;") + self.char_counter.setText(f"{count} / {MAX_INPUT_CHARS}") + + def _on_send(self): + text = self.msg_input.toPlainText().strip() + if not text or not self.current_conv_id: + return + if len(text) > MAX_INPUT_CHARS: + QMessageBox.warning(self, "Message Too Long", + f"Message too long (max {MAX_INPUT_CHARS} characters).") + return + conv = None + for c in self.conversations: + if c["conversation_id"] == self.current_conv_id: + conv = c + break + if not conv: + return + self.msg_input.clear() + self.bridge.send_message( + self.current_conv_id, text, conv["members"], + reply_to=self.reply_to_id, + ) + self.reply_to_id = None + self.reply_label.setVisible(False) + + def _on_message_sent(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Error", msg) + + def _on_new_chat(self): + email, ok = QInputDialog.getText(self, "New Chat", "Email:") + if not ok or not email.strip(): + return + text, ok2 = QInputDialog.getText(self, "New Chat", "Message:") + if not ok2 or not text.strip(): + return + if len(text.strip()) > MAX_INPUT_CHARS: + QMessageBox.warning(self, "Message Too Long", + f"Message too long (max {MAX_INPUT_CHARS} characters).") + return + self.bridge.send_new_chat(email.strip(), text.strip()) + + def _on_new_group(self): + name, ok = QInputDialog.getText(self, "New Group", "Group name:") + if not ok: + return + members, ok2 = QInputDialog.getText(self, "New Group", "Member emails (comma-separated):") + if not ok2 or not members.strip(): + return + member_list = [m.strip() for m in members.split(",") if m.strip()] + if member_list: + self.bridge.create_group(member_list, name=name.strip() or None) + + def _on_add_member(self): + if not self.current_conv_id: + return + email, ok = QInputDialog.getText(self, "Add Member", "Email to add:") + if not ok or not email.strip(): + return + self.bridge.add_member(self.current_conv_id, email.strip()) + + def _on_add_member_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Add Member", "Invitation sent.") + else: + QMessageBox.warning(self, "Add Member", msg) + + def _on_group_info(self): + if not self.current_conv_id: + return + conv = None + for c in self.conversations: + if c["conversation_id"] == self.current_conv_id: + conv = c + break + if not conv: + return + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + is_creator = conv.get("created_by") == my_user_id + group_name = conv.get("name") or "Group" + members = conv["members"] + + dlg = QDialog(self) + dlg.setWindowTitle("Group Info") + dlg.setMinimumWidth(380) + dlg_layout = QVBoxLayout(dlg) + + # Group avatar + avatar_row = QHBoxLayout() + avatar_label = QLabel() + avatar_label.setFixedSize(64, 64) + avatar_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + conv_id = conv["conversation_id"] + if conv_id in self._group_avatar_cache: + avatar_pix = self._make_circular_avatar(self._group_avatar_cache[conv_id], size=64) + else: + avatar_pix = self._make_default_avatar(group_name, size=64) + avatar_label.setPixmap(avatar_pix) + avatar_row.addWidget(avatar_label) + + group_name_esc = group_name.replace("&", "&").replace("<", "<").replace(">", ">") + title = QLabel(f"{group_name_esc}") + avatar_row.addWidget(title, stretch=1) + + if is_creator: + change_avatar_btn = QPushButton("Change Avatar") + change_avatar_btn.setObjectName("secondaryBtn") + change_avatar_btn.clicked.connect(lambda: self._do_change_group_avatar(conv_id, dlg)) + avatar_row.addWidget(change_avatar_btn) + + rename_btn = QPushButton("Rename") + rename_btn.setObjectName("secondaryBtn") + rename_btn.clicked.connect(lambda: self._do_rename_group(conv_id, group_name, dlg)) + avatar_row.addWidget(rename_btn) + + dlg_layout.addLayout(avatar_row) + + count_label = QLabel(f"Members ({len(members)}):") + count_label.setStyleSheet("margin-top: 8px;") + dlg_layout.addWidget(count_label) + + for mem in members: + uname = mem.get("username") or mem.get("email") or "?" + email = mem.get("email", "") + uid = mem.get("user_id") or mem.get("id") or "" + is_mem_creator = uid == conv.get("created_by") + + row = QHBoxLayout() + uname_esc = uname.replace("&", "&").replace("<", "<").replace(">", ">") + email_esc = email.replace("&", "&").replace("<", "<").replace(">", ">") + is_online = uid in self._online_users + online_dot = "\U0001f7e2 " if is_online else "" + name_text = f"{online_dot}{uname_esc}" + if email: + name_text += f" {email_esc}" + if is_mem_creator: + name_text += " creator" + name_label = QLabel(name_text) + name_label.setWordWrap(True) + row.addWidget(name_label, stretch=1) + + info_btn = QPushButton("") + info_btn.setFixedSize(28, 28) + info_btn.setObjectName("secondaryBtn") + info_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_MessageBoxInformation)) + info_btn.setToolTip(f"View profile of {uname}") + info_btn.clicked.connect(lambda checked, u=uid, d=dlg: (d.accept(), self._show_user_profile(u))) + row.addWidget(info_btn) + + # Remove button (only for creator, not on self) + if is_creator and uid != my_user_id: + remove_btn = QPushButton("") + remove_btn.setFixedSize(28, 28) + remove_btn.setObjectName("secondaryBtn") + remove_btn.setIcon(self.style().standardIcon(QStyle.StandardPixmap.SP_DialogCloseButton)) + remove_btn.setToolTip(f"Remove {uname}") + remove_btn.clicked.connect(lambda checked, u=uid, n=uname, d=dlg: self._do_remove_member_action(u, n, d)) + row.addWidget(remove_btn) + + dlg_layout.addLayout(row) + + dlg_layout.addSpacing(12) + + # Leave Group button + leave_btn = QPushButton("Leave Group") + leave_btn.setStyleSheet( + "QPushButton { background-color: #f38ba8; color: #1e1e2e; font-weight: bold; }" + "QPushButton:hover { background-color: #eba0ac; }" + ) + leave_btn.clicked.connect(lambda: self._do_leave_group_action(dlg)) + dlg_layout.addWidget(leave_btn) + + close_btn = QPushButton("Close") + close_btn.clicked.connect(dlg.accept) + dlg_layout.addWidget(close_btn) + dlg.exec() + + def _do_remove_member_action(self, user_id, username, dialog): + if not self.current_conv_id: + return + confirm = QMessageBox.question( + self, "Remove Member", + f"Remove {username} from the group?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + dialog.accept() + self.bridge.remove_member(self.current_conv_id, user_id) + + def _do_change_group_avatar(self, conv_id, dialog): + path, _ = QFileDialog.getOpenFileName( + dialog, "Select Group Avatar", "", + "Images (*.png *.jpg *.jpeg);;All Files (*)", + ) + if not path: + return + try: + with open(path, "rb") as f: + image_data = f.read() + if len(image_data) > 2 * 1024 * 1024: + QMessageBox.warning(dialog, "Too Large", "Avatar must be under 2 MB.") + return + dialog.accept() + self.bridge.update_group_avatar(conv_id, image_data) + except Exception as e: + QMessageBox.warning(dialog, "Error", f"Failed to read image: {e}") + + def _do_rename_group(self, conv_id, current_name, dialog): + from PyQt6.QtWidgets import QInputDialog + new_name, ok = QInputDialog.getText( + dialog, "Rename Group", "New group name:", + text=current_name, + ) + if ok and new_name.strip(): + new_name = new_name.strip() + if new_name != current_name: + dialog.accept() + self.bridge.rename_conversation(conv_id, new_name) + + def _on_group_renamed(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Rename Group", msg) + + def _do_leave_group_action(self, dialog): + if not self.current_conv_id: + return + confirm = QMessageBox.question( + self, "Leave Group", + "Leave this group? You will no longer receive messages.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + dialog.accept() + self.bridge.leave_group(self.current_conv_id) + + def _on_group_left(self, ok, msg): + if ok: + self.current_conv_id = None + self.chat_header.setText("Select a conversation") + self.chat_header_avatar.setVisible(False) + self.message_area.clear() + self.group_info_btn.setVisible(False) + self.user_info_btn.setVisible(False) + self.add_member_btn.setVisible(False) + self.delete_conv_btn.setVisible(False) + else: + QMessageBox.warning(self, "Leave Group", msg) + + def _on_delete_conv_btn(self): + if not self.current_conv_id: + return + confirm = QMessageBox.question( + self, "Delete Conversation", + "Delete this conversation? This cannot be undone.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm == QMessageBox.StandardButton.Yes: + self.bridge.delete_conversation(self.current_conv_id) + + def _on_conversation_deleted(self, ok, msg): + if ok: + self.current_conv_id = None + self.chat_header.setText("Select a conversation") + self.chat_header_avatar.setVisible(False) + self.message_area.clear() + self.group_info_btn.setVisible(False) + self.user_info_btn.setVisible(False) + self.add_member_btn.setVisible(False) + self.delete_conv_btn.setVisible(False) + else: + QMessageBox.warning(self, "Delete Conversation", msg) + + def _on_remove_member_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Remove Member", "Member removed.") + if self.current_conv_id: + self.bridge.load_messages(self.current_conv_id) + else: + QMessageBox.warning(self, "Remove Member", msg) + + def _on_my_profile(self): + my_user_id = self.bridge.client.session.get("user_id", "") if self.bridge.client.session else "" + if not my_user_id: + return + dlg = UserProfileDialog(self.bridge, my_user_id, editable=True, parent=self) + dlg.exec() + + def _on_dm_user_info(self): + """Show profile of the other user in a DM conversation.""" + if not self.current_conv_id: + return + conv = None + for c in self.conversations: + if c["conversation_id"] == self.current_conv_id: + conv = c + break + if not conv: + return + my_email = self.bridge.client.email + for m in conv["members"]: + if m.get("email") != my_email: + uid = m.get("user_id") or m.get("id") + if uid: + self._show_user_profile(uid) + return + + def _show_user_profile(self, user_id): + dlg = UserProfileDialog(self.bridge, user_id, editable=False, parent=self) + dlg.exec() + + def _on_authorize_device(self): + code, ok = QInputDialog.getText(self, "Authorize Device", "Pairing code:") + if not ok or not code.strip(): + return + self.bridge.authorize_device(code.strip()) + + def _on_rotate_keys(self): + confirm = QMessageBox.question( + self, + "Rotate Keys", + "This will revoke other devices. Continue?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm != QMessageBox.StandardButton.Yes: + return + password, ok = QInputDialog.getText(self, "Rotate Keys", "Password:", QLineEdit.EchoMode.Password) + if not ok or not password: + return + self.bridge.rotate_keys(self.bridge.client.username, password) + + def _on_logout(self): + confirm = QMessageBox.question( + self, + "Logout", + "Log out and return to the login screen?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + ) + if confirm != QMessageBox.StandardButton.Yes: + return + self._is_logout = True + self.bridge.logout() + self.close() + if self._on_logout_cb: + self._on_logout_cb() + + def _on_notification(self, payload): + sender = payload.get("sender", "???") + conv_id = payload.get("conversation_id", "") + + # Show notification in status bar (with conv name for non-current conversations) + if conv_id and conv_id != self.current_conv_id: + conv_name = sender + is_notif_dm = False + for c in self.conversations: + if c["conversation_id"] == conv_id: + is_notif_dm = len(c["members"]) == 2 and not c.get("name") + if not is_notif_dm: + conv_name = c.get("name") or sender + break + if is_notif_dm: + self.status_bar.setText(f"New message from {sender}") + else: + self.status_bar.setText(f"New message from {sender} in {conv_name}") + self.status_bar.setStyleSheet( + "background-color: #181825; border-radius: 0px; " + "padding: 0 8px; color: #a6e3a1; font-size: 8pt; font-weight: bold;" + ) + self._status_bar_conv_id = conv_id + QTimer.singleShot(5000, self._clear_status_bar) + + # Increment unread count if not currently viewing this conversation + if conv_id and conv_id != self.current_conv_id: + self._unread_counts[conv_id] = self._unread_counts.get(conv_id, 0) + 1 + self._update_conv_list_styles() + + # Append directly to current conversation instead of re-fetching + if conv_id == self.current_conv_id: + self.current_messages.append(payload) + # Register thumbnail if this is an image message + self._register_single_thumbnail(payload) + idx = len(self.current_messages) - 1 + html = self._render_single_message_html(payload, idx) + self.message_area.append(html) + if self._is_near_bottom: + sb = self.message_area.verticalScrollBar() + sb.setValue(sb.maximum()) + else: + self.jump_btn.setVisible(True) + self._position_jump_btn() + # Mark as read + msg_id = payload.get("message_id") + if msg_id: + self.bridge.schedule( + self.bridge.client.mark_read(conv_id, [msg_id]) + ) + + def _on_messages_read(self, data): + conv_id = data.get("conversation_id", "") + if conv_id == self.current_conv_id: + # Update read status in memory instead of re-fetching + user_id = data.get("user_id", "") + message_ids = set(data.get("message_ids", [])) + for msg in self.current_messages: + if msg.get("message_id") in message_ids: + read_by = msg.get("read_by", []) + if not any(r.get("user_id") == user_id for r in read_by): + read_by.append({"user_id": user_id}) + msg["read_by"] = read_by + self._render_messages(scroll_to_bottom=self._is_near_bottom) + + def _on_anchor_clicked(self, url): + url_str = url.toString() + if url_str.startswith("image://"): + file_id = url_str[len("image://"):] + # Find message with this image + for msg in self.current_messages: + image_info = msg.get("image") + if image_info and image_info.get("file_id") == file_id: + self._view_image(msg) + return + elif url_str.startswith("file://"): + file_id = url_str[len("file://"):] + for msg in self.current_messages: + file_info = msg.get("file") + if file_info and file_info.get("file_id") == file_id: + self.bridge.download_file(file_id, file_info) + return + elif url_str.startswith("https://"): + QDesktopServices.openUrl(QUrl(url_str)) + elif url_str.startswith("http://"): + reply = QMessageBox.warning( + self, + "Nezabezpečený odkaz", + f"Tento odkaz používá nešifrované HTTP spojení.\n\n{url_str}\n\nChcete přesto pokračovat?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if reply == QMessageBox.StandardButton.Yes: + QDesktopServices.openUrl(QUrl(url_str)) + + def _on_attach_image(self): + if not self.current_conv_id: + return + path, _ = QFileDialog.getOpenFileName( + self, "Select Image", "", + "Images (*.png *.jpg *.jpeg *.gif *.bmp *.webp);;All Files (*)", + ) + if not path: + return + conv = None + for c in self.conversations: + if c["conversation_id"] == self.current_conv_id: + conv = c + break + if not conv: + return + self.bridge.send_image( + self.current_conv_id, path, conv["members"], + reply_to=self.reply_to_id, + ) + self.reply_to_id = None + self.reply_label.setVisible(False) + + @staticmethod + def _human_file_size(size_bytes): + if size_bytes >= 1024 * 1024: + return f"{size_bytes / (1024 * 1024):.1f} MB" + elif size_bytes >= 1024: + return f"{size_bytes / 1024:.0f} KB" + return f"{size_bytes} B" + + @staticmethod + def _file_icon(filename: str) -> str: + """Return an emoji icon based on file extension.""" + ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else "" + _icons = { + "pdf": "\U0001f4d5", # red book + "doc": "\U0001f4d8", # blue book + "docx": "\U0001f4d8", + "odt": "\U0001f4d8", + "xls": "\U0001f4ca", # bar chart + "xlsx": "\U0001f4ca", + "ods": "\U0001f4ca", + "csv": "\U0001f4ca", + "ppt": "\U0001f4d9", # orange book + "pptx": "\U0001f4d9", + "odp": "\U0001f4d9", + "zip": "\U0001f4e6", # package + "rar": "\U0001f4e6", + "7z": "\U0001f4e6", + "tar": "\U0001f4e6", + "gz": "\U0001f4e6", + "mp3": "\U0001f3b5", # music note + "wav": "\U0001f3b5", + "flac": "\U0001f3b5", + "ogg": "\U0001f3b5", + "m4a": "\U0001f3b5", + "mp4": "\U0001f3ac", # clapper board + "mkv": "\U0001f3ac", + "avi": "\U0001f3ac", + "mov": "\U0001f3ac", + "webm": "\U0001f3ac", + "py": "\U0001f40d", # snake + "js": "\U0001f4dc", # scroll + "ts": "\U0001f4dc", + "html": "\U0001f310", # globe + "css": "\U0001f3a8", # palette + "json": "\U0001f4cb", # clipboard + "xml": "\U0001f4cb", + "yaml": "\U0001f4cb", + "yml": "\U0001f4cb", + "txt": "\U0001f4c4", # page facing up + "log": "\U0001f4c4", + "md": "\U0001f4c4", + } + return _icons.get(ext, "\U0001f4ce") # default: paperclip + + def _on_attach_file(self): + if not self.current_conv_id: + return + path, _ = QFileDialog.getOpenFileName( + self, "Select File", "", + "All Files (*)", + ) + if not path: + return + conv = None + for c in self.conversations: + if c["conversation_id"] == self.current_conv_id: + conv = c + break + if not conv: + return + self.bridge.send_file( + self.current_conv_id, path, conv["members"], + reply_to=self.reply_to_id, + ) + self.reply_to_id = None + self.reply_label.setVisible(False) + + def _on_file_sent(self, ok, msg): + if not ok: + QMessageBox.warning(self, "File Error", msg) + + def _on_file_downloaded(self, data, file_info): + filename = _safe_filename(file_info.get("filename", "file"), "file") + path, _ = QFileDialog.getSaveFileName(self, "Save File", filename) + if path: + try: + with open(path, "wb") as f: + f.write(data) + QMessageBox.information(self, "Saved", f"File saved to {path}") + except Exception as e: + QMessageBox.warning(self, "Error", f"Failed to save: {e}") + + def _on_image_sent(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Image Error", msg) + + def _view_image(self, msg): + image_info = msg.get("image") + if not image_info: + return + file_id = image_info.get("file_id", "") + self._pending_image_download = {"file_id": file_id, "image_info": image_info} + self.bridge.download_image(file_id, image_info) + + def _on_image_downloaded(self, file_id, data): + if not self._pending_image_download or self._pending_image_download["file_id"] != file_id: + return + image_info = self._pending_image_download["image_info"] + self._pending_image_download = None + self._show_image_dialog(data, image_info) + + def _show_image_dialog(self, image_data, image_info): + dlg = QDialog(self) + dlg.setWindowTitle(_safe_filename(image_info.get("filename", "Image"), "Image")) + dlg.setMinimumSize(400, 300) + layout = QVBoxLayout(dlg) + + qimg = _safe_load_image(image_data) + if qimg is None: + layout.addWidget(QLabel("Failed to load image.")) + else: + pixmap = QPixmap.fromImage(qimg) + label = QLabel() + # Scale down if larger than screen + screen_size = self.screen().availableSize() + max_w = int(screen_size.width() * 0.8) + max_h = int(screen_size.height() * 0.8) + if pixmap.width() > max_w or pixmap.height() > max_h: + pixmap = pixmap.scaled(max_w, max_h, Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation) + label.setPixmap(pixmap) + label.setAlignment(Qt.AlignmentFlag.AlignCenter) + + scroll = QScrollArea() + scroll.setWidget(label) + scroll.setWidgetResizable(True) + layout.addWidget(scroll) + + btn_row = QHBoxLayout() + save_btn = QPushButton("Save") + save_btn.clicked.connect(lambda: self._save_image(image_data, image_info, dlg)) + btn_row.addWidget(save_btn) + close_btn = QPushButton("Close") + close_btn.setObjectName("secondaryBtn") + close_btn.clicked.connect(dlg.accept) + btn_row.addWidget(close_btn) + layout.addLayout(btn_row) + + dlg.resize(min(pixmap.width() + 40, max_w) if not qimg.isNull() else 400, + min(pixmap.height() + 80, max_h) if not qimg.isNull() else 300) + dlg.exec() + + def _save_image(self, image_data, image_info, dialog): + filename = _safe_filename(image_info.get("filename", "image.jpg"), "image.jpg") + path, _ = QFileDialog.getSaveFileName(dialog, "Save Image", filename) + if path: + try: + with open(path, "wb") as f: + f.write(image_data) + QMessageBox.information(dialog, "Saved", f"Image saved to {path}") + except Exception as e: + QMessageBox.warning(dialog, "Error", f"Failed to save: {e}") + + def _on_message_deleted(self, data): + message_id = data.get("message_id", "") + conv_id = data.get("conversation_id", "") + if conv_id == self.current_conv_id: + for msg in self.current_messages: + if msg.get("message_id") == message_id: + msg["deleted"] = True + msg["text"] = "" + msg["image"] = None + break + self._render_messages() + + def _on_delete_message_result(self, ok, msg): + if not ok: + QMessageBox.warning(self, "Delete Error", msg) + return + # Reload messages to reflect deletion + if self.current_conv_id: + self.bridge.load_messages(self.current_conv_id) + + def _on_authorize_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Authorize Device", msg) + else: + QMessageBox.warning(self, "Authorize Device", msg) + + def _on_rotate_result(self, ok, msg): + if ok: + QMessageBox.information(self, "Rotate Keys", msg) + else: + QMessageBox.warning(self, "Rotate Keys", msg) + + def _on_reencrypt_status(self, msg): + self.reencrypt_label.setText(msg) + self.reencrypt_label.setVisible(True) + if msg.lower().startswith("re-encryption complete"): + QTimer.singleShot(4000, lambda: self.reencrypt_label.setVisible(False)) + + def closeEvent(self, event): + if not self._is_logout: + self.bridge.stop() + self.bridge.wait(2000) + event.accept() + + +def main(): + setup_logging() + app = QApplication(sys.argv) + app.setStyleSheet(DARK_STYLE) + + bridge = AsyncBridge() + + login_win = LoginWindow(bridge) + main_win = [None] # mutable ref + + def on_connected(): + login_win.reset() + login_win.show() + + def on_conn_error(msg): + QMessageBox.critical(None, "Connection Error", f"Cannot connect to server:\n{msg}") + sys.exit(1) + + def on_register_result(ok, msg): + if ok: + # Show verification code page inline + hint = "" + if msg and len(msg) <= 6 and msg.isdigit(): + hint = f"Code: {msg}" + elif msg: + hint = msg + login_win.show_verification_page(hint) + + def do_confirm(code): + async def _confirm(): + okc, msgc = await bridge.client.confirm_registration( + login_win.email_input.text().strip(), + login_win.username_input.text().strip(), + code.strip(), + ) + if okc: + login_win.show_success(msgc) + bridge.do_login(login_win.email_input.text().strip(), login_win.password_input.text()) + else: + login_win.show_error(msgc) + bridge.schedule(_confirm()) + + login_win._confirm_callback = do_confirm + else: + login_win.show_error(msg) + + def on_pairing_code(code): + login_win.show_success(f"Pairing code: {code}") + + def on_pairing_complete(ok, msg): + if ok: + login_win.show_success(msg) + bridge.do_login(login_win._pair_email, login_win._pair_password) + else: + login_win.show_error(msg) + + def on_login_result(ok, msg): + if ok: + login_win.show_success(msg) + login_win.hide() + main_win[0] = MainWindow(bridge, on_logout=lambda: (login_win.reset(), login_win.show())) + main_win[0].show() + else: + login_win.show_error(msg) + + bridge.connected.connect(on_connected) + bridge.connection_error.connect(on_conn_error) + bridge.register_result.connect(on_register_result) + bridge.login_result.connect(on_login_result) + bridge.pairing_code.connect(on_pairing_code) + bridge.pairing_complete.connect(on_pairing_complete) + bridge.reconnected.connect(lambda: (login_win.reset(), login_win.show())) + + bridge.start() + + sys.exit(app.exec()) + + +if __name__ == "__main__": + main() diff --git a/zaloha/protocol.py b/zaloha/protocol.py new file mode 100644 index 0000000..30bfbdd --- /dev/null +++ b/zaloha/protocol.py @@ -0,0 +1,125 @@ +"""Newline-delimited JSON protocol with base64 encoding for binary data.""" + +import asyncio +import base64 +import binascii +import json +import os + + +def encode_binary(data: bytes) -> str: + """Encode bytes to base64 string.""" + return base64.b64encode(data).decode("ascii") + + +def decode_binary(data: str) -> bytes: + """Decode base64 string to bytes.""" + try: + return base64.b64decode(data) + except (TypeError, binascii.Error) as e: + raise ValueError(f"Invalid base64: {e}") + + +VERSION = "0.8.2" +MIN_CLIENT_VERSION = "0.8" # server rejects clients older than this + + +def version_gte(version: str, minimum: str) -> bool: + """Return True if version >= minimum (compares numeric tuples, e.g. '0.8.1' >= '0.8').""" + def _parse(v: str) -> tuple[int, ...]: + try: + return tuple(int(x) for x in v.split(".")) + except (ValueError, AttributeError): + return (0,) + return _parse(version) >= _parse(minimum) + + +MAX_MESSAGE_BYTES = int(os.getenv("MAX_MESSAGE_BYTES", "65536")) # 64 KiB default +MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", str(5 * 1024 * 1024))) # 5 MiB default, 0 = no limit +MAX_FILE_BYTES = int(os.getenv("MAX_FILE_BYTES", str(50 * 1024 * 1024))) # 50 MiB default +IMAGE_CHUNK_SIZE = 32768 # 32 KiB raw chunk size for image upload/download + + +def build_request(msg_type: str, request_id: str | None = None, **kwargs) -> bytes: + """Build a protocol message (newline-terminated JSON).""" + msg = {"type": msg_type, **kwargs} + if request_id: + msg["request_id"] = request_id + return json.dumps(msg, ensure_ascii=False).encode("utf-8") + b"\n" + + +def build_response( + msg_type: str, + status: str, + data: dict | None = None, + request_id: str | None = None, +) -> bytes: + """Build a server response.""" + msg = {"type": msg_type, "status": status} + if data is not None: + msg["data"] = data + if request_id: + msg["request_id"] = request_id + return json.dumps(msg, ensure_ascii=False).encode("utf-8") + b"\n" + + +def parse_message(line: bytes) -> dict: + """Parse a single protocol message from bytes.""" + try: + return json.loads(line.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise ValueError(f"Invalid message: {e}") + + +class ProtocolReader: + """Read newline-delimited JSON messages from an asyncio StreamReader.""" + + def __init__(self, reader: asyncio.StreamReader): + self._reader = reader + + async def read_message(self) -> dict | None: + """Read and parse one message. Returns None on EOF.""" + try: + line = await self._reader.readuntil(b"\n") + except (asyncio.IncompleteReadError, ConnectionError): + return None + except asyncio.LimitOverrunError: + # Message exceeded limit — drain the internal buffer and signal error + self._reader._buffer.clear() + self._reader._maybe_resume_transport() + raise ValueError("Message exceeds maximum size") + if not line: + return None + return parse_message(line.strip()) + + +class ProtocolWriter: + """Write newline-delimited JSON messages to an asyncio StreamWriter.""" + + def __init__(self, writer: asyncio.StreamWriter): + self._writer = writer + + async def send_request(self, msg_type: str, request_id: str | None = None, **kwargs): + """Send a request message.""" + payload = build_request(msg_type, request_id=request_id, **kwargs) + if len(payload) > MAX_MESSAGE_BYTES: + raise ValueError(f"Message exceeds limit ({len(payload)} > {MAX_MESSAGE_BYTES})") + self._writer.write(payload) + await self._writer.drain() + + async def send_response( + self, + msg_type: str, + status: str, + data: dict | None = None, + request_id: str | None = None, + ): + """Send a response message.""" + payload = build_response(msg_type, status, data, request_id=request_id) + if len(payload) > MAX_MESSAGE_BYTES: + raise ValueError(f"Message exceeds limit ({len(payload)} > {MAX_MESSAGE_BYTES})") + self._writer.write(payload) + await self._writer.drain() + + def close(self): + self._writer.close() diff --git a/zaloha/requirements.txt b/zaloha/requirements.txt new file mode 100644 index 0000000..04ac3a1 --- /dev/null +++ b/zaloha/requirements.txt @@ -0,0 +1,7 @@ +cryptography>=42.0.0 +mysql-connector-python>=8.3.0 +python-dotenv>=1.0.0 +# GUI client (optional, needed for gui_client.py) +PyQt6>=6.6.0 +# Image sharing (optional, needed for send_image feature) +Pillow>=10.0.0 diff --git a/zaloha/schema.sql b/zaloha/schema.sql new file mode 100644 index 0000000..135acde --- /dev/null +++ b/zaloha/schema.sql @@ -0,0 +1,158 @@ +CREATE DATABASE IF NOT EXISTS encrypted_chat + CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +USE encrypted_chat; + +-- Users: identity_key is Ed25519 (32B), rsa_public_key for login challenge only +CREATE TABLE IF NOT EXISTS users ( + id CHAR(36) NOT NULL PRIMARY KEY, + username VARCHAR(255) NOT NULL, + email VARCHAR(255) NOT NULL UNIQUE, + rsa_public_key TEXT NOT NULL, + identity_key BLOB NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +) ENGINE=InnoDB; + +-- Devices: each user can have multiple devices +CREATE TABLE IF NOT EXISTS devices ( + id CHAR(36) NOT NULL PRIMARY KEY, + user_id CHAR(36) NOT NULL, + device_name VARCHAR(255) DEFAULT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_seen_at DATETIME DEFAULT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_devices_user (user_id) +) ENGINE=InnoDB; + +-- Signed Pre-Keys (X25519, signed by Ed25519 identity key) — per device +CREATE TABLE IF NOT EXISTS signed_prekeys ( + id CHAR(36) NOT NULL PRIMARY KEY, + user_id CHAR(36) NOT NULL, + device_id CHAR(36) DEFAULT NULL, + public_key BLOB NOT NULL, + signature BLOB NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_spk_user_device (user_id, device_id) +) ENGINE=InnoDB; + +-- One-Time Pre-Keys (consumed on use) — per device +CREATE TABLE IF NOT EXISTS one_time_prekeys ( + id CHAR(36) NOT NULL PRIMARY KEY, + user_id CHAR(36) NOT NULL, + device_id CHAR(36) DEFAULT NULL, + public_key BLOB NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_opk_user_device (user_id, device_id) +) ENGINE=InnoDB; + +-- Conversations +CREATE TABLE IF NOT EXISTS conversations ( + id CHAR(36) NOT NULL PRIMARY KEY, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + name VARCHAR(255) DEFAULT NULL, + created_by CHAR(36) DEFAULT NULL, + avatar_file VARCHAR(255) DEFAULT NULL +) ENGINE=InnoDB; + +CREATE TABLE IF NOT EXISTS conversation_members ( + conversation_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + joined_at DATETIME NULL, + PRIMARY KEY (conversation_id, user_id), + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Group invitations (pending invitations to join a group) +CREATE TABLE IF NOT EXISTS group_invitations ( + id CHAR(36) NOT NULL PRIMARY KEY, + conversation_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + invited_by CHAR(36) NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uq_conv_user (conversation_id, user_id), + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + FOREIGN KEY (invited_by) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Messages: per-recipient ciphertext (Double Ratchet = each recipient has different ciphertext) +CREATE TABLE IF NOT EXISTS messages ( + id CHAR(36) NOT NULL PRIMARY KEY, + conversation_id CHAR(36) NOT NULL, + sender_id CHAR(36) NOT NULL, + sender_device_id CHAR(36) DEFAULT NULL, + ratchet_header BLOB NOT NULL, + x3dh_header BLOB DEFAULT NULL, + sender_chain_id BLOB DEFAULT NULL, + sender_chain_n INT DEFAULT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted_at DATETIME DEFAULT NULL, + image_file_id CHAR(36) DEFAULT NULL, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (sender_id) REFERENCES users(id) ON DELETE CASCADE, + INDEX idx_messages_conv_created (conversation_id, created_at) +) ENGINE=InnoDB; + +-- Per-recipient encrypted content — per device +-- device_id '00000000-0000-0000-0000-000000000000' = self-encrypted / legacy +CREATE TABLE IF NOT EXISTS message_recipients ( + message_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + device_id CHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000000', + encrypted_content BLOB NOT NULL, + nonce BLOB NOT NULL, + ratchet_header BLOB DEFAULT NULL, + x3dh_header BLOB DEFAULT NULL, + PRIMARY KEY (message_id, user_id, device_id), + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Sender Keys for groups (distributed via pairwise ratchet) — per device +CREATE TABLE IF NOT EXISTS group_sender_keys ( + conversation_id CHAR(36) NOT NULL, + sender_id CHAR(36) NOT NULL, + device_id CHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000000', + chain_id BLOB NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (conversation_id, sender_id, device_id), + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (sender_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Read receipts +CREATE TABLE IF NOT EXISTS message_reads ( + message_id CHAR(36) NOT NULL, + user_id CHAR(36) NOT NULL, + read_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (message_id, user_id), + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- User profiles +CREATE TABLE IF NOT EXISTS user_profiles ( + user_id CHAR(36) NOT NULL PRIMARY KEY, + phone VARCHAR(50) DEFAULT NULL, + phone_visible TINYINT(1) NOT NULL DEFAULT 0, + email_visible TINYINT(1) NOT NULL DEFAULT 1, + location VARCHAR(255) DEFAULT NULL, + location_visible TINYINT(1) NOT NULL DEFAULT 0, + avatar_file VARCHAR(255) DEFAULT NULL, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; + +-- Image uploads +CREATE TABLE IF NOT EXISTS image_uploads ( + file_id CHAR(36) NOT NULL PRIMARY KEY, + conversation_id CHAR(36) NOT NULL, + uploader_id CHAR(36) NOT NULL, + file_size BIGINT NOT NULL DEFAULT 0, + completed BOOLEAN NOT NULL DEFAULT FALSE, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (uploader_id) REFERENCES users(id) ON DELETE CASCADE +) ENGINE=InnoDB; diff --git a/zaloha/server.py b/zaloha/server.py new file mode 100644 index 0000000..57344ca --- /dev/null +++ b/zaloha/server.py @@ -0,0 +1,2053 @@ +"""Asyncio TCP server — stores and relays encrypted blobs without seeing content.""" + +import asyncio +import json +import logging +import os +import re +import secrets +import signal +import smtplib +import ssl +import subprocess +import sys +from email.mime.text import MIMEText +from pathlib import Path +from datetime import datetime, timezone + +from dotenv import load_dotenv + +load_dotenv() + +import db +from crypto_utils import load_public_key, rsa_verify, load_ed25519_public, ed25519_verify, serialize_x25519_public +from protocol import VERSION, MIN_CLIENT_VERSION, version_gte, ProtocolReader, ProtocolWriter, encode_binary, decode_binary, MAX_MESSAGE_BYTES, MAX_IMAGE_BYTES, MAX_FILE_BYTES, IMAGE_CHUNK_SIZE + + +# Connected clients: user_id -> list[ProtocolWriter] +connected_clients: dict[str, list[ProtocolWriter]] = {} +# Writer -> device_id mapping (id(writer) -> device_id) +writer_device_map: dict[int, str] = {} +# Pairing sessions: code -> data +pairing_sessions: dict[str, dict] = {} +pending_registrations: dict[str, dict] = {} +# Pending image uploads: file_id -> {temp_path, received_bytes, file_size, conv_id} +pending_uploads: dict[str, dict] = {} +# Phantom user IDs (loaded at startup, updated on create/delete) +phantom_user_ids: set[str] = set() + +# Locks for shared mutable state (H4 race condition fix) +_clients_lock = asyncio.Lock() # Protects: connected_clients, writer_device_map, phantom_user_ids +_conn_lock = asyncio.Lock() # Protects: connection_counts, current_connections, rate_limits +_pairing_lock = asyncio.Lock() # Protects: pairing_sessions, pending_registrations +_uploads_lock = asyncio.Lock() # Protects: pending_uploads + +UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "uploads")) + +# C6 fix: UUID validation + safe path construction to prevent path traversal +_UUID_RE = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE) + + +def _valid_uuid(value: str) -> bool: + """Validate that value is a canonical UUID (no path components).""" + return bool(_UUID_RE.match(value)) + + +# L8 fix: email validation to prevent phantom DB inflation +_EMAIL_RE = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") + + +def _valid_email(email: str) -> bool: + """Validate basic email format (L8).""" + return bool(_EMAIL_RE.match(email)) and len(email) <= 254 + + +def _append_file(path: Path, data: bytes): + """Append data to file (runs in thread pool to avoid blocking event loop).""" + with open(path, "ab") as f: + f.write(data) + + +def _read_file_chunk(path: Path, offset: int, size: int) -> bytes: + """Read a chunk from file (runs in thread pool to avoid blocking event loop).""" + with open(path, "rb") as f: + f.seek(offset) + return f.read(size) + + +def _safe_upload_path(file_id: str, suffix: str) -> Path | None: + """Return resolved path inside UPLOAD_DIR, or None if traversal detected.""" + p = (UPLOAD_DIR / f"{file_id}{suffix}").resolve() + if not p.is_relative_to(UPLOAD_DIR.resolve()): + return None + return p + + +def _safe_avatar_path(filename: str) -> Path | None: + """Return resolved avatar path inside UPLOAD_DIR/avatars, or None if traversal detected.""" + avatar_dir = (UPLOAD_DIR / "avatars").resolve() + p = (UPLOAD_DIR / "avatars" / filename).resolve() + if not p.is_relative_to(avatar_dir): + return None + return p + + +PAIRING_TTL_SECONDS = 120 +REGISTER_TTL_SECONDS = 3600 +PAIRING_MAX_POLL_ATTEMPTS = 90 + +# SMTP configuration for registration codes +SMTP_HOST = os.getenv("SMTP_HOST", "") +SMTP_PORT = int(os.getenv("SMTP_PORT", "587")) +SMTP_USER = os.getenv("SMTP_USER", "") +SMTP_PASS = os.getenv("SMTP_PASS", "") +SMTP_FROM = os.getenv("SMTP_FROM", "") +RATE_LIMIT_WINDOW = 60.0 # seconds +CONNECTION_RL_WINDOW = 1.0 # seconds +CONNECTION_RL_MAX = 20 # max requests per window per connection +MAX_CONNECTIONS_PER_IP = 10 +MAX_CONNECTIONS_GLOBAL = 200 + + +def setup_logging(): + level_name = os.getenv("LOG_LEVEL", "INFO").upper() + level = getattr(logging, level_name, logging.WARNING) + logging.basicConfig(level=level, format="%(levelname)s: %(message)s") + + +logger = logging.getLogger("encrypted_chat.server") + +rate_limits: dict[str, list[float]] = {} +connection_counts: dict[str, int] = {} +current_connections = 0 + + +def _rate_limit_key(action: str, addr: str, email: str | None = None) -> str: + if email: + return f"{action}|{addr}|{email}" + return f"{action}|{addr}" + + +async def _is_rate_limited(key: str, limit: int) -> bool: + async with _conn_lock: + now = asyncio.get_event_loop().time() + window_start = now - RATE_LIMIT_WINDOW + times = rate_limits.get(key, []) + times = [t for t in times if t >= window_start] + if len(times) >= limit: + rate_limits[key] = times + return True + times.append(now) + rate_limits[key] = times + return False + + +def _get_peer_addr(writer: ProtocolWriter) -> str: + try: + return str(writer._writer.get_extra_info("peername")[0]) + except Exception: + return "unknown" + + +async def _notify_users(user_ids, msg_type, data, exclude_writer=None): + """Snapshot writers under lock, send notifications outside lock.""" + targets = [] + async with _clients_lock: + for uid in user_ids: + for w in connected_clients.get(uid, []): + targets.append(w) + for w in targets: + if w is exclude_writer: + continue + try: + await w.send_response(msg_type, "ok", data) + except Exception: + pass + + +async def _notify_users_individual(notifications, exclude_writer=None): + """Send per-user data. notifications: list of (user_id, msg_type, data).""" + targets = [] + async with _clients_lock: + for uid, mt, d in notifications: + for w in connected_clients.get(uid, []): + targets.append((w, mt, d)) + for w, mt, d in targets: + if w is exclude_writer: + continue + try: + await w.send_response(mt, "ok", d) + except Exception: + pass + + +async def _cleanup_pairings(): + async with _pairing_lock: + now = asyncio.get_event_loop().time() + expired = [code for code, p in pairing_sessions.items() if now - p["created_at"] > PAIRING_TTL_SECONDS] + for code in expired: + pairing_sessions.pop(code, None) + + +async def _cleanup_registrations(): + async with _pairing_lock: + now = asyncio.get_event_loop().time() + expired = [code for code, p in pending_registrations.items() if now - p["created_at"] > REGISTER_TTL_SECONDS] + for code in expired: + pending_registrations.pop(code, None) + + +def _generate_pairing_code() -> str: + for _ in range(10): + code = f"{int.from_bytes(os.urandom(4), 'big') % 100000000:08d}" + if code not in pairing_sessions: + return code + return f"{int.from_bytes(os.urandom(4), 'big') % 100000000:08d}" + + +def _generate_register_code() -> str: + for _ in range(10): + code = f"{int.from_bytes(os.urandom(3), 'big') % 1000000:06d}" + if code not in pending_registrations: + return code + return f"{int.from_bytes(os.urandom(3), 'big') % 1000000:06d}" + +def _validate_public_key_pem(pem_str: str) -> bool: + """Validate that a string is a valid RSA public key PEM.""" + try: + key = load_public_key(pem_str.encode("utf-8")) + if key.key_size < 2048: + return False + return True + except Exception: + return False + + +def _send_registration_email(to_email: str, code: str) -> bool: + """Send registration code via SMTP. Returns True on success.""" + if not SMTP_HOST: + return False + try: + msg = MIMEText(f"Your registration code is: {code}\n\nThis code expires in 1 hour.") + msg["Subject"] = "Encrypted Chat - Registration Code" + msg["From"] = SMTP_FROM or SMTP_USER + msg["To"] = to_email + with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=10) as server: + server.starttls() + if SMTP_USER: + server.login(SMTP_USER, SMTP_PASS) + server.send_message(msg) + return True + except Exception as e: + logger.warning("Failed to send registration email: %s", e) + return False + + +async def send_resp(msg: dict, writer: ProtocolWriter, msg_type: str, status: str, data: dict | None = None): + await writer.send_response(msg_type, status, data, request_id=msg.get("request_id")) + + +async def handle_register_start(msg: dict, writer: ProtocolWriter) -> dict | None: + await _cleanup_registrations() + username = msg.get("username", "").strip() + public_key = msg.get("public_key", "").strip() + identity_key_b64 = msg.get("identity_key", "").strip() + email = msg.get("email", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("register_start", addr, email), 3): + await send_resp(msg, writer, "register_start", "error", {"message": "Too many attempts. Try later."}) + return None + if not username or not public_key or not email or not identity_key_b64: + await send_resp(msg, writer, "register_start", "error", {"message": "Missing fields"}) + return None + if not _validate_public_key_pem(public_key): + await send_resp(msg, writer, "register_start", "error", {"message": "Invalid public key format"}) + return None + # Validate identity key is 32 bytes + try: + ik_bytes = decode_binary(identity_key_b64) + if len(ik_bytes) != 32: + raise ValueError("Identity key must be 32 bytes") + load_ed25519_public(ik_bytes) + except Exception: + await send_resp(msg, writer, "register_start", "error", {"message": "Invalid identity key"}) + return None + existing_email = db.get_user_by_email(email) + phantom_id = None + if existing_email: + if existing_email.get("rsa_public_key") == "PHANTOM": + # Don't delete — will be upgraded in register_confirm to preserve + # FK references (group_invitations, conversation_members, etc.) + phantom_id = existing_email["id"] + else: + # H3 anti-enumeration: return same response as success to prevent + # attackers from discovering valid emails. User won't receive a code + # via email, so they can't confirm — silent failure. + logger.debug("Registration attempt for existing email (hidden from client).") + await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."}) + return None + async with _pairing_lock: + code = _generate_register_code() + pending_registrations[code] = { + "username": username, + "public_key": public_key, + "identity_key": ik_bytes, + "email": email, + "created_at": asyncio.get_event_loop().time(), + "phantom_id": phantom_id, + } + logger.info("Registration started.") + email_sent = _send_registration_email(email, code) + if email_sent: + await send_resp(msg, writer, "register_start", "ok", {"message": "Code sent to your email."}) + else: + if SMTP_HOST: + logger.warning("SMTP configured but email failed for %s", email) + else: + logger.warning("No SMTP configured — returning code directly (dev mode).") + await send_resp(msg, writer, "register_start", "ok", {"code": code}) + return None + + +async def handle_register_confirm(msg: dict, writer: ProtocolWriter) -> dict | None: + await _cleanup_registrations() + email = msg.get("email", "").strip() + code = msg.get("code", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("register_confirm", addr, email), 3): + await send_resp(msg, writer, "register_confirm", "error", {"message": "Too many attempts. Try later."}) + return None + if not email or not code: + await send_resp(msg, writer, "register_confirm", "error", {"message": "Missing email or code"}) + return None + async with _pairing_lock: + pending = pending_registrations.get(code) + if pending and pending.get("email") == email: + pending_registrations.pop(code, None) + else: + pending = None + if not pending: + await send_resp(msg, writer, "register_confirm", "error", {"message": "Invalid or expired code"}) + return None + phantom_id = pending.get("phantom_id") + if phantom_id: + # Upgrade phantom in-place — preserves FK references (invitations, memberships) + user_id = db.upgrade_phantom_user( + phantom_id, + pending["username"], + pending["public_key"], + pending["identity_key"], + ) + if user_id: + async with _clients_lock: + phantom_user_ids.discard(phantom_id) + else: + # Phantom was deleted concurrently — fall back to normal create + user_id = db.create_user( + pending["username"], + pending["email"], + pending["public_key"], + pending["identity_key"], + ) + else: + user_id = db.create_user( + pending["username"], + pending["email"], + pending["public_key"], + pending["identity_key"], + ) + db.create_default_profile(user_id) + logger.info("User registered.") + await send_resp(msg, writer, "register_confirm", "ok", {"user_id": user_id}) + return None + + +async def handle_login_start(msg: dict, writer: ProtocolWriter, state: dict): + email = msg.get("email", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("login_start", addr, email), 10): + await send_resp(msg, writer, "login_start", "error", {"message": "Too many attempts. Try later."}) + return + if not email: + await send_resp(msg, writer, "login_start", "error", {"message": "Missing email"}) + return + user = db.get_user_by_email(email) + challenge = os.urandom(32) + state["login_email"] = email + state["login_challenge"] = challenge + if not user: + # H3 anti-enumeration: return a fake challenge so attacker can't distinguish + # "user not found" from "user exists". login_finish will fail with generic error. + state["_login_fake"] = True + await send_resp(msg, writer, "login_start", "ok", {"challenge": encode_binary(challenge)}) + + +async def handle_login_finish(msg: dict, writer: ProtocolWriter, state: dict) -> dict | None: + email = msg.get("email", "").strip() + signature_b64 = msg.get("signature", "") + challenge = state.get("login_challenge") + expected_email = state.get("login_email") + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("login_finish", addr, email), 10): + await send_resp(msg, writer, "login_finish", "error", {"message": "Too many attempts. Try later."}) + return None + if not email or not signature_b64: + await send_resp(msg, writer, "login_finish", "error", {"message": "Missing email or signature"}) + return None + if not challenge or expected_email != email: + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + + # H3: if login_start was for a non-existent user, fail with generic error + is_fake = state.pop("_login_fake", False) + + try: + if is_fake: + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + + user = db.get_user_by_email(email) + if not user: + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + + public_key = load_public_key(user["rsa_public_key"].encode("utf-8")) + signature = decode_binary(signature_b64) + if not rsa_verify(public_key, signature, challenge): + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + except ValueError: + # H5: invalid base64 in signature + await send_resp(msg, writer, "login_finish", "error", {"message": "Invalid credentials"}) + return None + finally: + state.pop("login_challenge", None) + state.pop("login_email", None) + + user_id = user["id"] + + # Version check: reject outdated clients + client_version = msg.get("client_version", "") + if client_version and not version_gte(client_version, MIN_CLIENT_VERSION): + await send_resp(msg, writer, "login_finish", "error", { + "message": f"Client version {client_version} is too old. Minimum required: {MIN_CLIENT_VERSION}", + "min_version": MIN_CLIENT_VERSION, + "server_version": VERSION, + }) + return None + + # Device registration: client may send device_id to reuse an existing device + client_device_id = msg.get("device_id") + device_id = None + if client_device_id: + dev = db.get_device(client_device_id) + if dev and dev["user_id"] == user_id: + device_id = client_device_id + if not device_id: + device_name = msg.get("device_name", "Unknown") + device_id = db.create_device(user_id, device_name) + db.update_device_last_seen(device_id) + + async with _clients_lock: + was_offline = user_id not in connected_clients or not connected_clients[user_id] + if user_id not in connected_clients: + connected_clients[user_id] = [] + connected_clients[user_id].append(writer) + writer_device_map[id(writer)] = device_id + logger.info("User logged in (device %s, client v%s).", device_id, client_version or "unknown") + await send_resp(msg, writer, "login_finish", "ok", { + "user_id": user_id, "username": user["username"], "email": user["email"], + "device_id": device_id, "server_version": VERSION, + }) + + # Send online status notifications + contacts = db.get_user_contacts(user_id) + online_targets = [] + async with _clients_lock: + online_contacts = [cid for cid in contacts if cid in connected_clients and connected_clients[cid]] + if was_offline: + for contact_id in contacts: + for cw in connected_clients.get(contact_id, []): + online_targets.append(cw) + await writer.send_response("online_users", "ok", {"user_ids": online_contacts}) + # Send online notifications outside lock + for cw in online_targets: + try: + await cw.send_response("user_online", "ok", {"user_id": user_id}) + except Exception: + pass + + return {"user_id": user_id, "username": user["username"], "email": user["email"], + "device_id": device_id} + + +async def handle_get_user_info(msg: dict, writer: ProtocolWriter): + """Get user info including identity key (for X3DH).""" + email = msg.get("email", "").strip() + user_id = msg.get("user_id", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("get_user_info", addr, email or user_id), 30): + await send_resp(msg, writer, "get_user_info", "error", {"message": "Too many attempts. Try later."}) + return + if user_id and not _valid_uuid(user_id): + await send_resp(msg, writer, "get_user_info", "error", {"message": "Invalid user_id"}) + return + user = None + if email: + user = db.get_user_by_email(email) + elif user_id: + user = db.get_user_by_id(user_id) + if not user: + await send_resp(msg, writer, "get_user_info", "error", {"message": "User not found"}) + return + ik = user.get("identity_key") + await send_resp(msg, writer, "get_user_info", "ok", { + "user_id": user["id"], + "username": user["username"], + "email": user["email"], + "identity_key": encode_binary(ik) if ik else "", + }) + + +async def handle_upload_prekeys(msg: dict, session: dict, writer: ProtocolWriter): + """Upload signed prekey + batch of one-time prekeys.""" + spk_data = msg.get("signed_prekey") + otps = msg.get("one_time_prekeys", []) + if not spk_data: + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Missing signed_prekey"}) + return + + spk_id = spk_data.get("id", "") + spk_pub_b64 = spk_data.get("public_key", "") + spk_sig_b64 = spk_data.get("signature", "") + if not spk_id or not spk_pub_b64 or not spk_sig_b64: + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Incomplete signed_prekey"}) + return + + spk_pub = decode_binary(spk_pub_b64) + spk_sig = decode_binary(spk_sig_b64) + + # Verify SPK signature with user's identity key + user = db.get_user_by_id(session["user_id"]) + if not user or not user.get("identity_key"): + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "No identity key"}) + return + ik_pub = load_ed25519_public(user["identity_key"]) + if not ed25519_verify(ik_pub, spk_sig, spk_pub): + await send_resp(msg, writer, "upload_prekeys", "error", {"message": "Invalid SPK signature"}) + return + + device_id = session.get("device_id") + db.store_signed_prekey(session["user_id"], spk_id, spk_pub, spk_sig, device_id=device_id) + + # Store OTPs + otp_records = [] + for otp in otps: + otp_id = otp.get("id", "") + otp_pub_b64 = otp.get("public_key", "") + if otp_id and otp_pub_b64: + otp_records.append({"id": otp_id, "public_key": decode_binary(otp_pub_b64)}) + if otp_records: + db.store_one_time_prekeys(session["user_id"], otp_records, device_id=device_id) + + logger.info("Prekeys uploaded: 1 SPK + %d OTPs (device %s)", len(otp_records), device_id) + await send_resp(msg, writer, "upload_prekeys", "ok", {"message": "OK"}) + + +async def handle_get_key_bundle(msg: dict, session: dict, writer: ProtocolWriter): + """Fetch key bundle for X3DH. Returns per-device bundles. Consumes one OTP per device.""" + target_user_id = msg.get("user_id", "").strip() + if not target_user_id: + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Missing user_id"}) + return + if not _valid_uuid(target_user_id): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Invalid user_id"}) + return + result = db.get_key_bundles_for_user(target_user_id) + if not result or not result.get("device_bundles"): + await send_resp(msg, writer, "get_key_bundle", "error", {"message": "Key bundle not available"}) + return + + device_bundles_data = [] + for b in result["device_bundles"]: + entry = { + "device_id": b.get("device_id"), + "signed_prekey_id": b["signed_prekey_id"], + "signed_prekey": encode_binary(b["signed_prekey_pub"]), + "spk_signature": encode_binary(b["spk_signature"]), + } + if b.get("opk_pub"): + entry["one_time_prekey_id"] = b["opk_id"] + entry["one_time_prekey"] = encode_binary(b["opk_pub"]) + device_bundles_data.append(entry) + + # Build response with both new multi-device format and legacy flat fields + first = device_bundles_data[0] if device_bundles_data else {} + data = { + "identity_key": encode_binary(result["identity_key"]), + "device_bundles": device_bundles_data, + # Legacy flat fields from first device bundle (backward compat) + "signed_prekey_id": first.get("signed_prekey_id", ""), + "signed_prekey": first.get("signed_prekey", ""), + "spk_signature": first.get("spk_signature", ""), + } + if first.get("one_time_prekey"): + data["one_time_prekey_id"] = first["one_time_prekey_id"] + data["one_time_prekey"] = first["one_time_prekey"] + await send_resp(msg, writer, "get_key_bundle", "ok", data) + + +async def handle_get_prekey_count(msg: dict, session: dict, writer: ProtocolWriter): + """How many OPKs does user have left (for this device)? Also returns SPK age for rotation.""" + device_id = session.get("device_id") + count = db.count_one_time_prekeys(session["user_id"], device_id=device_id) + spk_created_at = "" + spk = db.get_signed_prekey(session["user_id"], device_id=device_id) + if spk and spk.get("created_at"): + spk_created_at = spk["created_at"].isoformat() if hasattr(spk["created_at"], "isoformat") else str(spk["created_at"]) + await send_resp(msg, writer, "get_prekey_count", "ok", + {"count": count, "spk_created_at": spk_created_at}) + + +async def handle_rotate_keys(msg: dict, session: dict, writer: ProtocolWriter): + public_key = msg.get("public_key", "").strip() + if not public_key: + await send_resp(msg, writer, "rotate_keys", "error", {"message": "Missing public_key"}) + return + if not _validate_public_key_pem(public_key): + await send_resp(msg, writer, "rotate_keys", "error", {"message": "Invalid public key format"}) + return + db.update_user_rsa_key(session["user_id"], public_key) + logger.info("RSA key rotated.") + await send_resp(msg, writer, "rotate_keys", "ok", {"message": "OK"}) + # Disconnect other sessions + async with _clients_lock: + writers = connected_clients.get(session["user_id"], []) + others = [w for w in writers if w is not writer] + connected_clients[session["user_id"]] = [writer] + for w in others: + try: + w.close() + except Exception: + pass + + +async def handle_pairing_start(msg: dict, writer: ProtocolWriter): + await _cleanup_pairings() + email = msg.get("email", "").strip() + temp_public_key = msg.get("temp_public_key", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("pairing_start", addr, email), 10): + await send_resp(msg, writer, "pairing_start", "error", {"message": "Too many attempts. Try later."}) + return + if not email or not temp_public_key: + await send_resp(msg, writer, "pairing_start", "error", {"message": "Missing email or temp_public_key"}) + return + user = db.get_user_by_email(email) + if not user: + await send_resp(msg, writer, "pairing_start", "error", {"message": "User not found"}) + return + poll_token = secrets.token_hex(16) + async with _pairing_lock: + code = _generate_pairing_code() + pairing_sessions[code] = { + "email": email, + "temp_public_key": temp_public_key, + "created_at": asyncio.get_event_loop().time(), + "payload": None, + "poll_token": poll_token, + } + await send_resp(msg, writer, "pairing_start", "ok", {"code": code, "poll_token": poll_token}) + + +async def handle_pairing_claim(msg: dict, session: dict, writer: ProtocolWriter): + await _cleanup_pairings() + code = msg.get("code", "").strip() + if not code: + await send_resp(msg, writer, "pairing_claim", "error", {"message": "Missing code"}) + return + async with _pairing_lock: + p = pairing_sessions.get(code) + p_email = p["email"] if p else None + temp_pub = p["temp_public_key"] if p else None + if p: + # Extend TTL — re-encryption may run between claim and send + p["created_at"] = asyncio.get_event_loop().time() + if not p: + await send_resp(msg, writer, "pairing_claim", "error", {"message": "Invalid or expired code"}) + return + if p_email != session.get("email"): + await send_resp(msg, writer, "pairing_claim", "error", {"message": "Not authorized for this code"}) + return + await send_resp(msg, writer, "pairing_claim", "ok", {"temp_public_key": temp_pub}) + + +async def handle_pairing_send(msg: dict, session: dict, writer: ProtocolWriter): + await _cleanup_pairings() + code = msg.get("code", "").strip() + payload = msg.get("payload") + if not code or not payload: + await send_resp(msg, writer, "pairing_send", "error", {"message": "Missing code or payload"}) + return + error_msg = None + async with _pairing_lock: + p = pairing_sessions.get(code) + if not p: + error_msg = "Invalid or expired code" + elif p["email"] != session.get("email"): + error_msg = "Not authorized for this code" + else: + p["payload"] = payload + if error_msg: + await send_resp(msg, writer, "pairing_send", "error", {"message": error_msg}) + else: + await send_resp(msg, writer, "pairing_send", "ok", {"message": "OK"}) + + +async def handle_pairing_poll(msg: dict, writer: ProtocolWriter): + await _cleanup_pairings() + code = msg.get("code", "").strip() + poll_token = msg.get("poll_token", "").strip() + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("pairing_poll", addr), 120): + await send_resp(msg, writer, "pairing_poll", "error", {"message": "Too many attempts. Try later."}) + return + if not code: + await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing code"}) + return + if not poll_token: + await send_resp(msg, writer, "pairing_poll", "error", {"message": "Missing poll_token"}) + return + error_msg = None + ready = False + payload = None + async with _pairing_lock: + p = pairing_sessions.get(code) + if not p: + error_msg = "Invalid or expired code" + elif not secrets.compare_digest(p.get("poll_token", ""), poll_token): + error_msg = "Invalid poll_token" + else: + poll_attempts = p.get("poll_attempts", 0) + 1 + p["poll_attempts"] = poll_attempts + if poll_attempts > PAIRING_MAX_POLL_ATTEMPTS and not p.get("payload"): + pairing_sessions.pop(code, None) + error_msg = "Code invalidated due to too many attempts" + elif p.get("payload"): + ready = True + payload = p["payload"] + pairing_sessions.pop(code, None) + if error_msg: + await send_resp(msg, writer, "pairing_poll", "error", {"message": error_msg}) + elif ready: + await send_resp(msg, writer, "pairing_poll", "ok", {"ready": True, "payload": payload}) + else: + await send_resp(msg, writer, "pairing_poll", "ok", {"ready": False}) + + +async def handle_create_conversation(msg: dict, session: dict, writer: ProtocolWriter): + member_emails = msg.get("members", []) + name = msg.get("name") + # Resolve all member user IDs + other_users = [] + for email in member_emails: + u = db.get_user_by_email(email) + if not u: + u = db.create_phantom_user(email) + async with _clients_lock: + phantom_user_ids.add(u["id"]) + if u["id"] != session["user_id"]: + other_users.append(u) + is_dm = len(other_users) == 1 and not name + joined_at = datetime.now(timezone.utc) + if is_dm: + # DMs: add both members directly (no invitation) + all_ids = [session["user_id"]] + [u["id"] for u in other_users] + conv_id = db.create_conversation(all_ids, joined_at=joined_at, name=name, created_by=session["user_id"]) + logger.info("DM conversation created.") + await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id}) + # Notify the other member + members_info = db.get_conversation_members(conv_id) + member_list = [{"user_id": m["id"], "username": m["username"], "email": m["email"]} for m in members_info] + notif_data = { + "conversation_id": conv_id, + "name": name, + "created_by": session["user_id"], + "members": member_list, + } + await _notify_users([u["id"] for u in other_users], "conversation_created", notif_data) + else: + # Groups: only add creator, create invitations for others + conv_id = db.create_conversation([session["user_id"]], joined_at=joined_at, name=name, created_by=session["user_id"]) + logger.info("Group conversation created with invitations.") + # Create invitations for other members + creator_user = db.get_user_by_id(session["user_id"]) + creator_name = creator_user["username"] if creator_user else "Unknown" + invited_ids = [] + async with _clients_lock: + phantom_snapshot = set(phantom_user_ids) + for u in other_users: + db.create_invitation(conv_id, u["id"], session["user_id"]) + if u["id"] not in phantom_snapshot: + invited_ids.append(u["id"]) # only notify non-phantoms + inv_notif = { + "conversation_id": conv_id, + "conversation_name": name, + "invited_by": session["user_id"], + "invited_by_username": creator_name, + } + await _notify_users(invited_ids, "group_invitation", inv_notif) + await send_resp(msg, writer, "create_conversation", "ok", {"conversation_id": conv_id}) + + +async def handle_find_conversation(msg: dict, session: dict, writer: ProtocolWriter): + email = msg.get("email", "").strip() + if not email: + await send_resp(msg, writer, "find_conversation", "error", {"message": "Invalid request"}) + return + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("find_conversation", addr, email), 30): + await send_resp(msg, writer, "find_conversation", "error", {"message": "Too many attempts. Try later."}) + return + other = db.get_user_by_email(email) + if not other: + other = db.create_phantom_user(email) + async with _clients_lock: + phantom_user_ids.add(other["id"]) + conv_id = db.find_direct_conversation(session["user_id"], other["id"]) + await send_resp(msg, writer, "find_conversation", "ok", { + "conversation_id": conv_id, + "user_id": other["id"], + }) + + +async def handle_add_member(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + email = msg.get("email", "").strip() + if not conv_id or not email: + await send_resp(msg, writer, "add_member", "error", {"message": "Invalid request"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "add_member", "error", {"message": "Invalid conversation_id"}) + return + # L8: validate email format before phantom creation + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("add_member", addr, email), 10): + await send_resp(msg, writer, "add_member", "error", {"message": "Too many attempts. Try later."}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "add_member", "error", {"message": "Not a member"}) + return + user = db.get_user_by_email(email) + if not user: + # Create phantom for unregistered email (same as create_conversation) + user = db.create_phantom_user(email) + async with _clients_lock: + phantom_user_ids.add(user["id"]) + if db.is_conversation_member(conv_id, user["id"]): + await send_resp(msg, writer, "add_member", "error", {"message": "Already a member"}) + return + if db.has_pending_invitation(conv_id, user["id"]): + await send_resp(msg, writer, "add_member", "error", {"message": "Invitation already pending"}) + return + # Create invitation (for both real and phantom users) + db.create_invitation(conv_id, user["id"], session["user_id"]) + logger.info("Group invitation created.") + await send_resp(msg, writer, "add_member", "ok", {"user_id": user["id"]}) + # Push invitation notification only to non-phantom users + async with _clients_lock: + is_phantom = user["id"] in phantom_user_ids + if not is_phantom: + conv = db.get_conversation(conv_id) + creator_user = db.get_user_by_id(session["user_id"]) + creator_name = creator_user["username"] if creator_user else "Unknown" + inv_notif = { + "conversation_id": conv_id, + "conversation_name": conv.get("name") if conv else None, + "invited_by": session["user_id"], + "invited_by_username": creator_name, + } + await _notify_users([user["id"]], "group_invitation", inv_notif) + + +async def handle_accept_invitation(msg: dict, session: dict, writer: ProtocolWriter): + """Accept a group invitation — add user to conversation members.""" + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "accept_invitation", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "accept_invitation", "error", {"message": "Invalid conversation_id"}) + return + if not db.has_pending_invitation(conv_id, session["user_id"]): + await send_resp(msg, writer, "accept_invitation", "error", {"message": "No pending invitation"}) + return + joined_at = datetime.now(timezone.utc) + db.add_conversation_member(conv_id, session["user_id"], joined_at=joined_at) + db.delete_invitation(conv_id, session["user_id"]) + logger.info("Invitation accepted.") + await send_resp(msg, writer, "accept_invitation", "ok", {"conversation_id": conv_id}) + # Notify existing members about the new member + user = db.get_user_by_id(session["user_id"]) + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + "username": user["username"] if user else "", + "email": user["email"] if user else "", + } + members = db.get_conversation_members(conv_id) + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_added", notif_data) + + +async def handle_decline_invitation(msg: dict, session: dict, writer: ProtocolWriter): + """Decline a group invitation.""" + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "decline_invitation", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "decline_invitation", "error", {"message": "Invalid conversation_id"}) + return + if not db.has_pending_invitation(conv_id, session["user_id"]): + await send_resp(msg, writer, "decline_invitation", "error", {"message": "No pending invitation"}) + return + db.delete_invitation(conv_id, session["user_id"]) + logger.info("Invitation declined.") + await send_resp(msg, writer, "decline_invitation", "ok", {"message": "OK"}) + + +async def handle_list_invitations(msg: dict, session: dict, writer: ProtocolWriter): + """List pending group invitations for the current user.""" + invitations = db.get_pending_invitations(session["user_id"]) + result = [] + for inv in invitations: + entry = { + "conversation_id": inv["conversation_id"], + "conversation_name": inv.get("conversation_name"), + "invited_by": inv["invited_by"], + "invited_by_username": inv.get("invited_by_username", ""), + "created_at": inv["created_at"].isoformat() if hasattr(inv["created_at"], "isoformat") else str(inv["created_at"]), + } + result.append(entry) + await send_resp(msg, writer, "list_invitations", "ok", {"invitations": result}) + + +async def handle_list_conversations(msg: dict, session: dict, writer: ProtocolWriter): + convs = db.list_user_conversations(session["user_id"]) + unread = db.get_unread_counts(session["user_id"]) + result = [] + for c in convs: + result.append({ + "conversation_id": c["id"], + "created_at": c["created_at"].isoformat() if hasattr(c["created_at"], "isoformat") else str(c["created_at"]), + "members": c["members"], + "name": c.get("name"), + "created_by": c.get("created_by"), + "avatar_file": c.get("avatar_file"), + "unread_count": unread.get(c["id"], 0), + }) + await send_resp(msg, writer, "list_conversations", "ok", {"conversations": result}) + + +async def handle_send_message(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "send_message", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "send_message", "error", {"message": "Invalid conversation_id"}) + return + addr = _get_peer_addr(writer) + if await _is_rate_limited(_rate_limit_key("send_message", addr, session.get("email")), 20): + await send_resp(msg, writer, "send_message", "error", {"message": "Too many attempts. Try later."}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "send_message", "error", {"message": "Not a member"}) + return + + # New protocol: ratchet_header + recipients[] with per-user ciphertext + ratchet_header_raw = msg.get("ratchet_header") + recipients_raw = msg.get("recipients") + if not ratchet_header_raw or not recipients_raw: + await send_resp(msg, writer, "send_message", "error", {"message": "Missing ratchet_header or recipients"}) + return + + ratchet_header = json.dumps(ratchet_header_raw).encode() if isinstance(ratchet_header_raw, dict) else \ + ratchet_header_raw.encode() if isinstance(ratchet_header_raw, str) else ratchet_header_raw + + x3dh_header_raw = msg.get("x3dh_header") + x3dh_header = None + if x3dh_header_raw: + x3dh_header = json.dumps(x3dh_header_raw).encode() if isinstance(x3dh_header_raw, dict) else \ + x3dh_header_raw.encode() if isinstance(x3dh_header_raw, str) else x3dh_header_raw + + sender_chain_id_b64 = msg.get("sender_chain_id") + sender_chain_id = decode_binary(sender_chain_id_b64) if sender_chain_id_b64 else None + sender_chain_n = msg.get("sender_chain_n") + + # Validate recipients are actual members + member_ids = {m["id"] for m in db.get_conversation_members(conv_id)} + async with _clients_lock: + phantom_snapshot = set(phantom_user_ids) + db_recipients = [] + for r in recipients_raw: + uid = r.get("user_id", "") + if uid not in member_ids: + continue + if uid in phantom_snapshot: + continue + ct_b64 = r.get("encrypted_content", "") + nonce_b64 = r.get("nonce", "") + if not ct_b64 or not nonce_b64: + continue + entry = { + "user_id": uid, + "encrypted_content": decode_binary(ct_b64), + "nonce": decode_binary(nonce_b64), + } + # Per-recipient device_id (multi-device support) + r_device_id = r.get("device_id") + if r_device_id: + entry["device_id"] = r_device_id + # Per-recipient ratchet header and x3dh header + r_rh = r.get("ratchet_header") + if r_rh: + entry["ratchet_header"] = json.dumps(r_rh).encode() if isinstance(r_rh, dict) else \ + r_rh.encode() if isinstance(r_rh, str) else r_rh + r_x3dh = r.get("x3dh_header") + if r_x3dh: + entry["x3dh_header"] = json.dumps(r_x3dh).encode() if isinstance(r_x3dh, dict) else \ + r_x3dh.encode() if isinstance(r_x3dh, str) else r_x3dh + db_recipients.append(entry) + if not db_recipients: + await send_resp(msg, writer, "send_message", "error", {"message": "No valid recipients"}) + return + + image_file_id = msg.get("image_file_id") + msg_id = db.store_message( + conv_id, session["user_id"], ratchet_header, db_recipients, + x3dh_header=x3dh_header, + sender_chain_id=sender_chain_id, + sender_chain_n=sender_chain_n, + image_file_id=image_file_id, + sender_device_id=session.get("device_id"), + ) + + # Link image upload to message if present + if image_file_id: + upload = db.get_image_upload(image_file_id) + if upload and upload["completed"] and upload["uploader_id"] == session["user_id"]: + db.set_message_image_file_id(msg_id, image_file_id) + + logger.info("Message stored.") + await send_resp(msg, writer, "send_message", "ok", {"message_id": msg_id}) + + # Notify connected recipients — group all per-device entries by user_id + from collections import defaultdict + user_entries = defaultdict(list) + for r in recipients_raw: + uid = r.get("user_id", "") + user_entries[uid].append({ + "device_id": r.get("device_id", db.SELF_DEVICE_ID), + "encrypted_content": r.get("encrypted_content", ""), + "nonce": r.get("nonce", ""), + "ratchet_header": r.get("ratchet_header") or ratchet_header_raw, + "x3dh_header": r.get("x3dh_header") or x3dh_header_raw, + }) + + notifications = [] + for uid, entries in user_entries.items(): + notif_data = { + "message_id": msg_id, + "conversation_id": conv_id, + "sender_id": session["user_id"], + "sender_device_id": session.get("device_id"), + "device_entries": entries, + } + if sender_chain_id_b64: + notif_data["sender_chain_id"] = sender_chain_id_b64 + if sender_chain_n is not None: + notif_data["sender_chain_n"] = sender_chain_n + # Also include flat fields for backward compat with old clients + # (first entry's data as fallback) + if entries: + first = entries[0] + notif_data["ratchet_header"] = first.get("ratchet_header") or ratchet_header_raw + notif_data["encrypted_content"] = first.get("encrypted_content", "") + notif_data["nonce"] = first.get("nonce", "") + if first.get("x3dh_header"): + notif_data["x3dh_header"] = first["x3dh_header"] + notifications.append((uid, "new_message", notif_data)) + await _notify_users_individual(notifications, exclude_writer=writer) + + +async def handle_get_messages(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "get_messages", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "get_messages", "error", {"message": "Invalid conversation_id"}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "get_messages", "error", {"message": "Not a member"}) + return + + limit = min(max(int(msg.get("limit", 50)), 1), 200) + offset = max(int(msg.get("offset", 0)), 0) + device_id = session.get("device_id") + messages = db.get_messages(conv_id, session["user_id"], limit, offset, device_id=device_id) + + result = [] + message_ids = [m["id"] for m in messages] + read_status = db.get_message_read_status(message_ids) if message_ids else {} + for m in messages: + read_by = read_status.get(m["id"], []) + # Prefer per-recipient headers (mr_*) over message-level headers + rh_raw = m.get("mr_ratchet_header") or m.get("ratchet_header") + x3dh_raw = m.get("mr_x3dh_header") or m.get("x3dh_header") + entry = { + "message_id": m["id"], + "sender_id": m.get("sender_id") or "", + "ratchet_header": json.loads(rh_raw) if rh_raw else {}, + "encrypted_content": encode_binary(m["encrypted_content"]) if m.get("encrypted_content") else "", + "nonce": encode_binary(m["nonce"]) if m.get("nonce") else "", + "created_at": m["created_at"].isoformat() if hasattr(m["created_at"], "isoformat") else str(m["created_at"]), + "read_by": read_by, + } + if x3dh_raw: + entry["x3dh_header"] = json.loads(x3dh_raw) + if m.get("sender_chain_id"): + entry["sender_chain_id"] = encode_binary(m["sender_chain_id"]) + if m.get("sender_chain_n") is not None: + entry["sender_chain_n"] = m["sender_chain_n"] + if m.get("sender_device_id"): + entry["sender_device_id"] = m["sender_device_id"] + if m.get("deleted_at"): + entry["deleted_at"] = m["deleted_at"].isoformat() if hasattr(m["deleted_at"], "isoformat") else str(m["deleted_at"]) + result.append(entry) + await send_resp(msg, writer, "get_messages", "ok", {"messages": result}) + + +async def handle_remove_member(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + user_id = msg.get("user_id", "") + if not conv_id or not user_id: + await send_resp(msg, writer, "remove_member", "error", {"message": "Missing conversation_id or user_id"}) + return + if not _valid_uuid(conv_id) or not _valid_uuid(user_id): + await send_resp(msg, writer, "remove_member", "error", {"message": "Invalid conversation_id or user_id"}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "remove_member", "error", {"message": "Not a member"}) + return + convs = db.list_user_conversations(session["user_id"]) + conv_data = None + for c in convs: + if c["id"] == conv_id: + conv_data = c + break + if not conv_data or conv_data.get("created_by") != session["user_id"]: + await send_resp(msg, writer, "remove_member", "error", {"message": "Only the group creator can remove members"}) + return + if user_id == session["user_id"]: + await send_resp(msg, writer, "remove_member", "error", {"message": "Cannot remove yourself"}) + return + # Get remaining members before removing (to notify them) + members_before = db.get_conversation_members(conv_id) + # M6: atomic removal — return value confirms row existed + removed = db.remove_conversation_member_atomic(conv_id, user_id) + if not removed: + await send_resp(msg, writer, "remove_member", "error", {"message": "Member already removed"}) + return + logger.info("Conversation member removed.") + await send_resp(msg, writer, "remove_member", "ok", {"message": "OK"}) + + # Notify removed member and remaining members + notif_data = { + "conversation_id": conv_id, + "user_id": user_id, + } + member_ids = [m["id"] for m in members_before if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_removed", notif_data) + + +async def handle_leave_group(msg: dict, session: dict, writer: ProtocolWriter): + """Leave a group conversation voluntarily.""" + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "leave_group", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "leave_group", "error", {"message": "Invalid conversation_id"}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "leave_group", "error", {"message": "Not a member"}) + return + # Don't allow leaving DMs (2 members without a name) + conv = db.get_conversation(conv_id) + members = db.get_conversation_members(conv_id) + if len(members) <= 2 and not (conv and conv.get("name")): + await send_resp(msg, writer, "leave_group", "error", {"message": "Cannot leave a DM conversation"}) + return + # If creator is leaving, transfer to first remaining member + if conv and conv.get("created_by") == session["user_id"]: + remaining = [m for m in members if m["id"] != session["user_id"]] + if remaining: + db.update_conversation_creator(conv_id, remaining[0]["id"]) + # M6: atomic removal + db.remove_conversation_member_atomic(conv_id, session["user_id"]) + logger.info("User left group.") + await send_resp(msg, writer, "leave_group", "ok", {"message": "OK"}) + # Notify remaining members + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_removed", notif_data) + + +async def handle_rename_conversation(msg: dict, session: dict, writer: ProtocolWriter): + """Rename a group conversation (creator only).""" + conv_id = msg.get("conversation_id", "") + new_name = msg.get("name", "").strip() + if not conv_id or not new_name: + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Missing conversation_id or name"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Invalid conversation_id"}) + return + if len(new_name) > 100: + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Name too long (max 100)"}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Not a member"}) + return + conv = db.get_conversation(conv_id) + if not conv or not conv.get("name"): + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Cannot rename a DM conversation"}) + return + if conv.get("created_by") != session["user_id"]: + await send_resp(msg, writer, "rename_conversation", "error", {"message": "Only the group creator can rename"}) + return + db.update_conversation_name(conv_id, new_name) + logger.info("Group renamed: %s", conv_id) + await send_resp(msg, writer, "rename_conversation", "ok", {"message": "OK"}) + # Notify all members + members = db.get_conversation_members(conv_id) + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "conversation_renamed", { + "conversation_id": conv_id, + "name": new_name, + "renamed_by": session["user_id"], + }) + + +async def handle_delete_conversation(msg: dict, session: dict, writer: ProtocolWriter): + """Delete a conversation for the current user. Removes user from members, + deletes the conversation if no members remain.""" + conv_id = msg.get("conversation_id", "") + if not conv_id: + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Invalid conversation_id"}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Not a member"}) + return + conv = db.get_conversation(conv_id) + members = db.get_conversation_members(conv_id) + is_group = len(members) > 2 or (conv and conv.get("name")) + # Groups can only be deleted by the creator (admin) + if is_group and (not conv or conv.get("created_by") != session["user_id"]): + await send_resp(msg, writer, "delete_conversation", "error", {"message": "Only the group creator can delete this conversation"}) + return + if is_group: + # Group: creator deletes for everyone — remove all members, clean up, delete + for member in members: + db.remove_conversation_member(conv_id, member["id"]) + else: + # DM: only remove self; other user keeps the conversation + db.remove_conversation_member(conv_id, session["user_id"]) + remaining_count = db.count_conversation_members(conv_id) + if remaining_count == 0: + # Clean up uploaded files from disk + file_ids = db.get_conversation_file_ids(conv_id) + for fid in file_ids: + for ext in (".enc", ".tmp"): + p = _safe_upload_path(fid, ext) + if not p: + continue + try: + p.unlink(missing_ok=True) + except Exception: + pass + db.delete_conversation(conv_id) + logger.info("Conversation deleted for user.") + await send_resp(msg, writer, "delete_conversation", "ok", {"message": "OK"}) + # Notify other members they were removed + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "member_removed", notif_data) + + +async def handle_mark_read(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + message_ids = msg.get("message_ids", []) + if not conv_id or not message_ids: + await send_resp(msg, writer, "mark_read", "error", {"message": "Missing conversation_id or message_ids"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "mark_read", "error", {"message": "Invalid conversation_id"}) + return + if len(message_ids) > 500: + await send_resp(msg, writer, "mark_read", "error", {"message": "Too many message_ids (max 500)"}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "mark_read", "error", {"message": "Not a member"}) + return + db.mark_messages_read(conv_id, session["user_id"], message_ids) + await send_resp(msg, writer, "mark_read", "ok", {"message": "OK"}) + members = db.get_conversation_members(conv_id) + notif_data = { + "conversation_id": conv_id, + "user_id": session["user_id"], + "message_ids": message_ids, + } + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "messages_read", notif_data) + + +async def handle_delete_message(msg: dict, session: dict, writer: ProtocolWriter): + message_id = msg.get("message_id", "") + if not message_id: + await send_resp(msg, writer, "delete_message", "error", {"message": "Missing message_id"}) + return + if not _valid_uuid(message_id): + await send_resp(msg, writer, "delete_message", "error", {"message": "Invalid message_id"}) + return + conv_id = db.get_message_conversation(message_id) + if not conv_id: + await send_resp(msg, writer, "delete_message", "error", {"message": "Message not found"}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "delete_message", "error", {"message": "Not a member"}) + return + result = db.soft_delete_message(message_id, session["user_id"]) + if result is None: + await send_resp(msg, writer, "delete_message", "error", {"message": "Cannot delete this message"}) + return + image_file_id = result.get("image_file_id") + if image_file_id: + image_path = _safe_upload_path(image_file_id, ".enc") + if image_path: + try: + image_path.unlink(missing_ok=True) + except Exception: + pass + db.delete_image_upload(image_file_id) + logger.info("Message deleted.") + await send_resp(msg, writer, "delete_message", "ok", {"message_id": message_id}) + members = db.get_conversation_members(conv_id) + notif_data = {"message_id": message_id, "conversation_id": conv_id} + member_ids = [m["id"] for m in members if m["id"] != session["user_id"]] + await _notify_users(member_ids, "message_deleted", notif_data) + + +async def handle_upload_image_start(msg: dict, session: dict, writer: ProtocolWriter): + conv_id = msg.get("conversation_id", "") + file_size = msg.get("file_size", 0) + file_id = msg.get("file_id", "") + file_type = msg.get("file_type", "image") # "image" or "file" + if not conv_id or not file_id: + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Missing fields"}) + return + if not _valid_uuid(file_id): + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Not a member"}) + return + max_bytes = MAX_FILE_BYTES if file_type == "file" else MAX_IMAGE_BYTES + if max_bytes > 0 and file_size > max_bytes: + await send_resp(msg, writer, "upload_image_start", "error", + {"message": f"File too large (max {max_bytes} bytes)"}) + return + UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + temp_path = _safe_upload_path(file_id, ".tmp") + if not temp_path: + await send_resp(msg, writer, "upload_image_start", "error", {"message": "Invalid file_id"}) + return + temp_path.write_bytes(b"") + async with _uploads_lock: + pending_uploads[file_id] = { + "temp_path": str(temp_path), + "received_bytes": 0, + "file_size": file_size, + "max_bytes": max_bytes, + "conv_id": conv_id, + "uploader_id": session["user_id"], + } + db.create_image_upload(file_id, conv_id, session["user_id"], file_size) + logger.info("Image upload started: %s", file_id) + await send_resp(msg, writer, "upload_image_start", "ok", {"file_id": file_id}) + + +async def handle_upload_image_chunk(msg: dict, session: dict, writer: ProtocolWriter): + file_id = msg.get("file_id", "") + chunk_data = msg.get("data", "") + if not file_id or not chunk_data: + await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Missing fields"}) + return + async with _uploads_lock: + upload = pending_uploads.get(file_id) + if not upload or upload["uploader_id"] != session["user_id"]: + upload = None + else: + temp_path_str = upload["temp_path"] + upload_max = upload.get("max_bytes", 0) + if not upload: + await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "No active upload"}) + return + raw = decode_binary(chunk_data) + temp_path = Path(temp_path_str) + await asyncio.to_thread(_append_file, temp_path, raw) + over_limit = False + async with _uploads_lock: + upload = pending_uploads.get(file_id) + if upload: + upload["received_bytes"] += len(raw) + if upload_max > 0 and upload["received_bytes"] > upload_max: + pending_uploads.pop(file_id, None) + over_limit = True + received = upload["received_bytes"] + if over_limit: + temp_path.unlink(missing_ok=True) + await send_resp(msg, writer, "upload_image_chunk", "error", {"message": "Upload exceeds size limit"}) + return + await send_resp(msg, writer, "upload_image_chunk", "ok", {"received": received}) + + +async def handle_upload_image_end(msg: dict, session: dict, writer: ProtocolWriter): + file_id = msg.get("file_id", "") + if not file_id: + await send_resp(msg, writer, "upload_image_end", "error", {"message": "Missing file_id"}) + return + async with _uploads_lock: + upload = pending_uploads.pop(file_id, None) + if not upload or upload["uploader_id"] != session["user_id"]: + await send_resp(msg, writer, "upload_image_end", "error", {"message": "No active upload"}) + return + temp_path = Path(upload["temp_path"]) + if upload["received_bytes"] != upload["file_size"]: + temp_path.unlink(missing_ok=True) + await send_resp(msg, writer, "upload_image_end", "error", + {"message": f"Incomplete upload: received {upload['received_bytes']} of {upload['file_size']} bytes"}) + return + final_path = _safe_upload_path(file_id, ".enc") + if not final_path: + temp_path.unlink(missing_ok=True) + await send_resp(msg, writer, "upload_image_end", "error", {"message": "Invalid file_id"}) + return + def _move_file(): + try: + temp_path.rename(final_path) + except Exception: + import shutil + shutil.move(str(temp_path), str(final_path)) + await asyncio.to_thread(_move_file) + db.complete_image_upload(file_id) + logger.info("Image upload completed: %s (%d bytes)", file_id, upload["received_bytes"]) + await send_resp(msg, writer, "upload_image_end", "ok", {"file_id": file_id}) + + +async def handle_download_image(msg: dict, session: dict, writer: ProtocolWriter): + file_id = msg.get("file_id", "") + offset = msg.get("offset", 0) + if not file_id: + await send_resp(msg, writer, "download_image", "error", {"message": "Missing file_id"}) + return + if not _valid_uuid(file_id): + await send_resp(msg, writer, "download_image", "error", {"message": "Invalid file_id"}) + return + upload = db.get_image_upload(file_id) + if not upload or not upload["completed"]: + await send_resp(msg, writer, "download_image", "error", {"message": "File not found"}) + return + if not db.is_conversation_member(upload["conversation_id"], session["user_id"]): + await send_resp(msg, writer, "download_image", "error", {"message": "Not a member"}) + return + file_path = _safe_upload_path(file_id, ".enc") + if not file_path or not file_path.exists(): + await send_resp(msg, writer, "download_image", "error", {"message": "File not found"}) + return + file_size = file_path.stat().st_size + chunk = await asyncio.to_thread(_read_file_chunk, file_path, offset, IMAGE_CHUNK_SIZE) + done = (offset + len(chunk)) >= file_size + await send_resp(msg, writer, "download_image", "ok", { + "file_id": file_id, + "data": encode_binary(chunk), + "offset": offset, + "done": done, + "total_size": file_size, + }) + + +MAX_AVATAR_BYTES = 2 * 1024 * 1024 # 2 MB + + +async def handle_get_profile(msg: dict, session: dict, writer: ProtocolWriter): + """Get user profile (respects visibility for other users).""" + target_user_id = msg.get("user_id", "").strip() + if not target_user_id: + target_user_id = session["user_id"] + elif not _valid_uuid(target_user_id): + await send_resp(msg, writer, "get_profile", "error", {"message": "Invalid user_id"}) + return + profile = db.get_user_profile(target_user_id, viewer_id=session["user_id"]) + if not profile: + await send_resp(msg, writer, "get_profile", "error", {"message": "User not found"}) + return + # Serialize datetime fields + for key in ("created_at", "updated_at"): + if profile.get(key) and hasattr(profile[key], "isoformat"): + profile[key] = profile[key].isoformat() + await send_resp(msg, writer, "get_profile", "ok", profile) + + +async def handle_update_profile(msg: dict, session: dict, writer: ProtocolWriter): + """Update own profile fields.""" + fields = {} + for key in ("phone", "phone_visible", "email_visible", "location", "location_visible"): + if key in msg: + fields[key] = msg[key] + if not fields: + await send_resp(msg, writer, "update_profile", "error", {"message": "No fields to update"}) + return + db.update_user_profile(session["user_id"], **fields) + await send_resp(msg, writer, "update_profile", "ok", {"message": "OK"}) + + +async def handle_update_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Upload avatar (base64 in single message, max 2MB).""" + avatar_b64 = msg.get("data", "") + if not avatar_b64: + await send_resp(msg, writer, "update_avatar", "error", {"message": "Missing data"}) + return + avatar_data = decode_binary(avatar_b64) + if len(avatar_data) > MAX_AVATAR_BYTES: + await send_resp(msg, writer, "update_avatar", "error", + {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"}) + return + # Detect format from magic bytes + ext = "jpg" + if avatar_data[:8] == b'\x89PNG\r\n\x1a\n': + ext = "png" + avatar_dir = UPLOAD_DIR / "avatars" + avatar_dir.mkdir(parents=True, exist_ok=True) + filename = f"{session['user_id']}.{ext}" + avatar_path = _safe_avatar_path(filename) + if not avatar_path: + await send_resp(msg, writer, "update_avatar", "error", {"message": "Invalid path"}) + return + await asyncio.to_thread(avatar_path.write_bytes, avatar_data) + db.update_user_profile(session["user_id"], avatar_file=filename) + logger.info("Avatar updated for user %s", session["user_id"]) + await send_resp(msg, writer, "update_avatar", "ok", {"avatar_file": filename}) + + +async def handle_get_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Download avatar for a user.""" + target_user_id = msg.get("user_id", "").strip() + if not target_user_id: + await send_resp(msg, writer, "get_avatar", "error", {"message": "Missing user_id"}) + return + if not _valid_uuid(target_user_id): + await send_resp(msg, writer, "get_avatar", "error", {"message": "Invalid user_id"}) + return + profile = db.get_user_profile(target_user_id) + if not profile or not profile.get("avatar_file"): + await send_resp(msg, writer, "get_avatar", "error", {"message": "No avatar"}) + return + avatar_path = _safe_avatar_path(profile["avatar_file"]) + if not avatar_path or not avatar_path.exists(): + await send_resp(msg, writer, "get_avatar", "error", {"message": "Avatar file not found"}) + return + avatar_data = await asyncio.to_thread(avatar_path.read_bytes) + await send_resp(msg, writer, "get_avatar", "ok", { + "user_id": target_user_id, + "data": encode_binary(avatar_data), + "filename": profile["avatar_file"], + }) + + +async def handle_update_group_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Upload avatar for a group conversation (base64, max 2MB). Only members can set it.""" + conv_id = msg.get("conversation_id", "").strip() + avatar_b64 = msg.get("data", "") + if not conv_id or not avatar_b64: + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Missing fields"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid conversation_id"}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Not a member"}) + return + avatar_data = decode_binary(avatar_b64) + if len(avatar_data) > MAX_AVATAR_BYTES: + await send_resp(msg, writer, "update_group_avatar", "error", + {"message": f"Avatar too large (max {MAX_AVATAR_BYTES} bytes)"}) + return + ext = "jpg" + if avatar_data[:8] == b'\x89PNG\r\n\x1a\n': + ext = "png" + avatar_dir = UPLOAD_DIR / "avatars" + avatar_dir.mkdir(parents=True, exist_ok=True) + filename = f"group_{conv_id}.{ext}" + avatar_path = _safe_avatar_path(filename) + if not avatar_path: + await send_resp(msg, writer, "update_group_avatar", "error", {"message": "Invalid path"}) + return + await asyncio.to_thread(avatar_path.write_bytes, avatar_data) + db.update_conversation_avatar(conv_id, filename) + logger.info("Group avatar updated for conversation %s", conv_id) + await send_resp(msg, writer, "update_group_avatar", "ok", {"avatar_file": filename}) + + +async def handle_get_group_avatar(msg: dict, session: dict, writer: ProtocolWriter): + """Download avatar for a group conversation.""" + conv_id = msg.get("conversation_id", "").strip() + if not conv_id: + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Missing conversation_id"}) + return + if not _valid_uuid(conv_id): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Invalid conversation_id"}) + return + if not db.is_conversation_member(conv_id, session["user_id"]): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Not a member"}) + return + conv = db.get_conversation(conv_id) + if not conv or not conv.get("avatar_file"): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "No avatar"}) + return + avatar_path = _safe_avatar_path(conv["avatar_file"]) + if not avatar_path or not avatar_path.exists(): + await send_resp(msg, writer, "get_group_avatar", "error", {"message": "Avatar file not found"}) + return + avatar_data = await asyncio.to_thread(avatar_path.read_bytes) + await send_resp(msg, writer, "get_group_avatar", "ok", { + "conversation_id": conv_id, + "data": encode_binary(avatar_data), + "filename": conv["avatar_file"], + }) + + +async def handle_list_devices(msg: dict, session: dict, writer: ProtocolWriter): + """List all devices for the current user.""" + devices = db.get_user_devices(session["user_id"]) + result = [] + for d in devices: + entry = { + "device_id": d["id"], + "device_name": d.get("device_name"), + "created_at": d["created_at"].isoformat() if hasattr(d["created_at"], "isoformat") else str(d["created_at"]), + "last_seen_at": d["last_seen_at"].isoformat() if d.get("last_seen_at") and hasattr(d["last_seen_at"], "isoformat") else (str(d["last_seen_at"]) if d.get("last_seen_at") else None), + "is_current": d["id"] == session.get("device_id"), + } + result.append(entry) + await send_resp(msg, writer, "list_devices", "ok", {"devices": result}) + + +async def handle_remove_device(msg: dict, session: dict, writer: ProtocolWriter): + """Remove a device (cannot remove current device).""" + device_id = msg.get("device_id", "").strip() + if not device_id: + await send_resp(msg, writer, "remove_device", "error", {"message": "Missing device_id"}) + return + if not _valid_uuid(device_id): + await send_resp(msg, writer, "remove_device", "error", {"message": "Invalid device_id"}) + return + if device_id == session.get("device_id"): + await send_resp(msg, writer, "remove_device", "error", {"message": "Cannot remove current device"}) + return + dev = db.get_device(device_id) + if not dev or dev["user_id"] != session["user_id"]: + await send_resp(msg, writer, "remove_device", "error", {"message": "Device not found"}) + return + db.delete_device(device_id) + logger.info("Device removed: %s", device_id) + await send_resp(msg, writer, "remove_device", "ok", {"message": "OK"}) + + +async def handle_session_reset(msg: dict, session: dict, writer: ProtocolWriter): + """Notify peer to reset a corrupted Double Ratchet session.""" + peer_user_id = msg.get("peer_user_id", "").strip() + peer_device_id = msg.get("peer_device_id", "").strip() or None + if not peer_user_id or not _valid_uuid(peer_user_id): + await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_user_id"}) + return + if peer_device_id and not _valid_uuid(peer_device_id): + await send_resp(msg, writer, "session_reset", "error", {"message": "Invalid peer_device_id"}) + return + # Push notification to peer + await _notify_users([peer_user_id], "session_reset", { + "from_user_id": session["user_id"], + "from_device_id": session.get("device_id"), + }) + await send_resp(msg, writer, "session_reset", "ok", {}) + + +async def handle_reencrypt_messages(msg: dict, session: dict, writer: ProtocolWriter): + """Re-encrypt message history with self-encryption key (for device pairing).""" + updates_raw = msg.get("updates", []) + if not updates_raw: + await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No updates"}) + return + if len(updates_raw) > 500: + await send_resp(msg, writer, "reencrypt_messages", "error", + {"message": "Too many updates (max 500 per request)"}) + return + updates = [] + for u in updates_raw: + mid = u.get("message_id", "") + ct_b64 = u.get("encrypted_content", "") + nonce_b64 = u.get("nonce", "") + if not mid or not ct_b64 or not nonce_b64: + continue + updates.append({ + "message_id": mid, + "encrypted_content": decode_binary(ct_b64), + "nonce": decode_binary(nonce_b64), + }) + if not updates: + await send_resp(msg, writer, "reencrypt_messages", "error", {"message": "No valid updates"}) + return + db.batch_reencrypt_messages(session["user_id"], updates) + logger.info("Re-encrypted %d messages for user.", len(updates)) + await send_resp(msg, writer, "reencrypt_messages", "ok", {"count": len(updates)}) + + +async def _cleanup_uploads(): + stale = db.get_stale_uploads(3600) + for s in stale: + fid = s["file_id"] + for ext in (".tmp", ".enc"): + p = _safe_upload_path(fid, ext) + if not p: + continue + try: + p.unlink(missing_ok=True) + except Exception: + pass + db.delete_image_upload(fid) + async with _uploads_lock: + pending_uploads.pop(fid, None) + if stale: + logger.info("Cleaned up %d stale uploads.", len(stale)) + + +async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + global current_connections + addr = _get_peer_addr(ProtocolWriter(writer)) + async with _conn_lock: + current_connections += 1 + connection_counts[addr] = connection_counts.get(addr, 0) + 1 + over_limit = (current_connections > MAX_CONNECTIONS_GLOBAL or + connection_counts[addr] > MAX_CONNECTIONS_PER_IP) + if over_limit: + try: + writer.close() + except Exception: + pass + async with _conn_lock: + current_connections = max(0, current_connections - 1) + connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1) + return + logger.debug("Client connected.") + proto_reader = ProtocolReader(reader) + proto_writer = ProtocolWriter(writer) + session = None + state = {"_req_times": []} + + try: + while True: + try: + msg = await proto_reader.read_message() + except ValueError as e: + try: + await proto_writer.send_response("protocol_error", "error", {"message": str(e)}) + except Exception: + pass + break + if msg is None: + break + + msg_type = msg.get("type", "") + now = asyncio.get_event_loop().time() + times = [t for t in state["_req_times"] if now - t <= CONNECTION_RL_WINDOW] + if len(times) >= CONNECTION_RL_MAX: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Too many requests. Slow down."}) + state["_req_times"] = times + continue + times.append(now) + state["_req_times"] = times + + try: + if msg_type == "register": + await handle_register_start(msg, proto_writer) + elif msg_type == "register_confirm": + await handle_register_confirm(msg, proto_writer) + elif msg_type == "login_start": + await handle_login_start(msg, proto_writer, state) + elif msg_type == "login_finish": + result = await handle_login_finish(msg, proto_writer, state) + if result: + session = result + elif msg_type == "pairing_start": + await handle_pairing_start(msg, proto_writer) + elif msg_type == "pairing_poll": + await handle_pairing_poll(msg, proto_writer) + elif session is None: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Not logged in"}) + elif msg_type == "get_user_info": + await handle_get_user_info(msg, proto_writer) + elif msg_type == "upload_prekeys": + await handle_upload_prekeys(msg, session, proto_writer) + elif msg_type == "get_key_bundle": + await handle_get_key_bundle(msg, session, proto_writer) + elif msg_type == "get_prekey_count": + await handle_get_prekey_count(msg, session, proto_writer) + elif msg_type == "create_conversation": + await handle_create_conversation(msg, session, proto_writer) + elif msg_type == "find_conversation": + await handle_find_conversation(msg, session, proto_writer) + elif msg_type == "add_member": + await handle_add_member(msg, session, proto_writer) + elif msg_type == "accept_invitation": + await handle_accept_invitation(msg, session, proto_writer) + elif msg_type == "decline_invitation": + await handle_decline_invitation(msg, session, proto_writer) + elif msg_type == "list_invitations": + await handle_list_invitations(msg, session, proto_writer) + elif msg_type == "list_conversations": + await handle_list_conversations(msg, session, proto_writer) + elif msg_type == "send_message": + await handle_send_message(msg, session, proto_writer) + elif msg_type == "get_messages": + await handle_get_messages(msg, session, proto_writer) + elif msg_type == "rotate_keys": + await handle_rotate_keys(msg, session, proto_writer) + elif msg_type == "remove_member": + await handle_remove_member(msg, session, proto_writer) + elif msg_type == "leave_group": + await handle_leave_group(msg, session, proto_writer) + elif msg_type == "rename_conversation": + await handle_rename_conversation(msg, session, proto_writer) + elif msg_type == "delete_conversation": + await handle_delete_conversation(msg, session, proto_writer) + elif msg_type == "mark_read": + await handle_mark_read(msg, session, proto_writer) + elif msg_type == "pairing_claim": + await handle_pairing_claim(msg, session, proto_writer) + elif msg_type == "pairing_send": + await handle_pairing_send(msg, session, proto_writer) + elif msg_type == "delete_message": + await handle_delete_message(msg, session, proto_writer) + elif msg_type == "upload_image_start": + await handle_upload_image_start(msg, session, proto_writer) + elif msg_type == "upload_image_chunk": + await handle_upload_image_chunk(msg, session, proto_writer) + elif msg_type == "upload_image_end": + await handle_upload_image_end(msg, session, proto_writer) + elif msg_type == "download_image": + await handle_download_image(msg, session, proto_writer) + elif msg_type == "get_profile": + await handle_get_profile(msg, session, proto_writer) + elif msg_type == "update_profile": + await handle_update_profile(msg, session, proto_writer) + elif msg_type == "update_avatar": + await handle_update_avatar(msg, session, proto_writer) + elif msg_type == "get_avatar": + await handle_get_avatar(msg, session, proto_writer) + elif msg_type == "update_group_avatar": + await handle_update_group_avatar(msg, session, proto_writer) + elif msg_type == "get_group_avatar": + await handle_get_group_avatar(msg, session, proto_writer) + elif msg_type == "reencrypt_messages": + await handle_reencrypt_messages(msg, session, proto_writer) + elif msg_type == "list_devices": + await handle_list_devices(msg, session, proto_writer) + elif msg_type == "remove_device": + await handle_remove_device(msg, session, proto_writer) + elif msg_type == "session_reset": + await handle_session_reset(msg, session, proto_writer) + else: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Unknown type"}) + except Exception as e: + logger.warning("Handler error for '%s': %s", msg_type, e, exc_info=True) + try: + await send_resp(msg, proto_writer, msg_type, "error", {"message": "Internal server error"}) + except Exception: + break # Can't send response — connection is dead + except Exception as e: + logger.warning("Client connection error: %s", e) + finally: + async with _conn_lock: + current_connections = max(0, current_connections - 1) + connection_counts[addr] = max(0, connection_counts.get(addr, 1) - 1) + offline_targets = [] + if session: + uid = session["user_id"] + contacts = db.get_user_contacts(uid) + async with _clients_lock: + writer_device_map.pop(id(proto_writer), None) + if uid in connected_clients: + remaining = [w for w in connected_clients[uid] if w is not proto_writer] + if remaining: + connected_clients[uid] = remaining + else: + del connected_clients[uid] + # User fully offline — snapshot targets under lock + for contact_id in contacts: + for cw in connected_clients.get(contact_id, []): + offline_targets.append(cw) + # Send offline notifications outside lock + for cw in offline_targets: + try: + await cw.send_response("user_offline", "ok", {"user_id": uid}) + except Exception: + pass + writer.close() + logger.debug("Client disconnected.") + + +async def main(): + setup_logging() + host = os.getenv("SERVER_HOST", "127.0.0.1") + port = int(os.getenv("SERVER_PORT", "9999")) + tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes") + tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes") + tls_autogen = os.getenv("TLS_AUTOGEN", "false").lower() in ("1", "true", "yes") + + is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development") + ssl_context = None + if tls_required and not tls_enabled: + raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.") + if tls_enabled: + cert_file = os.getenv("TLS_CERT_FILE", "").strip() + key_file = os.getenv("TLS_KEY_FILE", "").strip() + if not cert_file or not key_file: + if tls_autogen: + if not is_dev: + raise RuntimeError("TLS_AUTOGEN is only allowed when ENVIRONMENT=dev") + cert_dir = Path(__file__).resolve().parent / "certs" + cert_dir.mkdir(parents=True, exist_ok=True) + cert_file = str(cert_dir / "server.crt") + key_file = str(cert_dir / "server.key") + if not (os.path.exists(cert_file) and os.path.exists(key_file)): + try: + subprocess.run( + [ + "openssl", "req", "-x509", "-newkey", "rsa:4096", + "-keyout", key_file, "-out", cert_file, + "-days", "365", "-nodes", "-subj", "/CN=localhost", + ], + check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, + ) + os.chmod(key_file, 0o600) + except FileNotFoundError: + raise RuntimeError("OpenSSL not found.") + except subprocess.CalledProcessError: + raise RuntimeError("Failed to auto-generate TLS cert.") + logger.warning("Using auto-generated self-signed certificate — not for production use.") + else: + raise RuntimeError("TLS is enabled but TLS_CERT_FILE or TLS_KEY_FILE is missing.") + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file) + else: + logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.") + + UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + + # Load phantom user IDs from DB into in-memory cache + phantom_user_ids.update(db.get_all_phantom_user_ids()) + if phantom_user_ids: + logger.info("Loaded %d phantom user IDs.", len(phantom_user_ids)) + + server = await asyncio.start_server( + handle_client, host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context, + ) + logger.info("Encrypted chat server v%s listening on %s:%s", VERSION, host, port) + + async def _cleanup_rate_limits(): + async with _conn_lock: + now = asyncio.get_event_loop().time() + window_start = now - RATE_LIMIT_WINDOW + stale_keys = [k for k, times in rate_limits.items() + if not any(t >= window_start for t in times)] + for k in stale_keys: + del rate_limits[k] + stale_conns = [k for k, v in connection_counts.items() if v <= 0] + for k in stale_conns: + del connection_counts[k] + + async def _periodic_cleanup(): + while True: + await asyncio.sleep(600) + try: + await _cleanup_uploads() + except Exception as e: + logger.warning("Upload cleanup error: %s", e) + try: + await _cleanup_rate_limits() + except Exception as e: + logger.warning("Rate limit cleanup error: %s", e) + # L8: clean up stale phantom users (>30 days, no real conversations) + try: + deleted = db.cleanup_stale_phantoms(30) + if deleted: + async with _clients_lock: + phantom_user_ids.clear() + phantom_user_ids.update(db.get_all_phantom_user_ids()) + logger.info("Cleaned up %d stale phantom users.", deleted) + except Exception as e: + logger.warning("Phantom cleanup error: %s", e) + + asyncio.create_task(_periodic_cleanup()) + + loop = asyncio.get_running_loop() + stop = loop.create_future() + + def signal_handler(): + if not stop.done(): + stop.set_result(None) + + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, signal_handler) + + async with server: + await stop + # Force-close all connected clients BEFORE exiting context manager, + # otherwise wait_closed() blocks forever waiting for handle_client tasks + logger.info("Shutting down — closing %d client connections...", sum(len(ws) for ws in connected_clients.values())) + async with _clients_lock: + all_writers = [w for writers in connected_clients.values() for w in writers] + connected_clients.clear() + writer_device_map.clear() + for w in all_writers: + try: + w.close() + except Exception: + pass + logger.info("Server shut down.") + + +if __name__ == "__main__": + asyncio.run(main())