Files
Kecalek_python/chat_core.py
Filip 6da7515d1e Transfer contact verification state during device pairing
When authorizing a new device, include the TOFU registry
(known_identity_keys) and manual verifications (verified_contacts) in
the encrypted pairing payload, so a contact verified on the existing
device stays verified on the newly paired one. Previously these stores
are device-local and started empty on the new device, dropping verified
status. Fields are optional and ignored by older clients; symmetric with
the iOS client.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-15 20:00:53 +02:00

4009 lines
167 KiB
Python

"""Shared network layer and ChatClient class for CLI and GUI clients.
Uses X3DH + Double Ratchet for message encryption, Sender Keys for groups.
RSA retained for login challenge-response only.
"""
import asyncio
import collections
import hashlib
import json
import logging
import os
import random
import ssl
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
from crypto_utils import (
# RSA (login only)
generate_rsa_keypair,
serialize_private_key,
serialize_public_key,
load_private_key,
load_public_key,
rsa_sign,
# Ed25519
generate_identity_keypair,
serialize_ed25519_private,
serialize_ed25519_private_raw,
serialize_ed25519_public,
load_ed25519_private,
load_ed25519_public,
ed25519_sign,
# X25519
generate_x25519_keypair,
serialize_x25519_private,
serialize_x25519_public,
load_x25519_private,
load_x25519_public,
x25519_dh,
derive_pairing_shared_key,
# X3DH
generate_signed_prekey,
generate_one_time_prekeys,
x3dh_initiate,
x3dh_respond,
# Double Ratchet
DoubleRatchet,
# Sender Keys
SenderKeyState,
# AES
aes_encrypt,
aes_decrypt,
# Self-encryption
derive_self_encryption_key,
# Local storage encryption
derive_local_storage_key,
# Contact verification
compute_fingerprint,
compute_pairing_fingerprint,
encode_pairing_qr,
format_fingerprint,
normalize_pairing_fingerprint,
compute_safety_number,
encode_verification_qr,
decode_verification_qr,
# Message padding
pad_plaintext,
unpad_plaintext,
)
from protocol import (
VERSION,
ProtocolReader,
ProtocolWriter,
encode_binary,
decode_binary,
build_request,
MAX_MESSAGE_BYTES,
MAX_IMAGE_BYTES,
IMAGE_CHUNK_SIZE,
)
KEY_DIR = Path.home() / ".encrypted_chat"
OPK_REPLENISH_THRESHOLD = 20
OPK_BATCH_SIZE = 50
SPK_ROTATION_DAYS = 7
PAIRING_REENCRYPT_INITIAL_DELAY_RANGE = (20.0, 75.0)
PAIRING_REENCRYPT_INTER_BATCH_DELAY_RANGE = (1.0, 3.0)
PAIRING_REENCRYPT_INTER_FETCH_DELAY_RANGE = (0.15, 0.5)
PAIRING_REENCRYPT_BATCH_SIZE = 500
def _encrypt_local(data: bytes, key: bytes) -> bytes:
"""Encrypt data with AES-256-GCM for local storage. Format: nonce(12) + tag(16) + ciphertext."""
_, nonce, ct, tag = aes_encrypt(data, key=key)
return nonce + tag + ct
def _decrypt_local(raw: bytes, key: bytes) -> bytes:
"""Decrypt data encrypted by _encrypt_local."""
nonce, tag, ct = raw[:12], raw[12:28], raw[28:]
return aes_decrypt(key, nonce, ct, tag)
def get_key_dir(email: str) -> Path:
d = KEY_DIR / email
d.mkdir(parents=True, exist_ok=True)
os.chmod(d, 0o700)
return d
# ---------------------------------------------------------------------------
# RSA key storage (login only — unchanged interface)
# ---------------------------------------------------------------------------
def save_keys(email: str, private_key, public_key, password: bytes | None = None):
d = get_key_dir(email)
(d / "private.pem").write_bytes(serialize_private_key(private_key, password=password))
(d / "public.pem").write_bytes(serialize_public_key(public_key))
os.chmod(d / "private.pem", 0o600)
def load_keys(email: str, password: bytes | None = None):
d = get_key_dir(email)
priv_path = d / "private.pem"
pub_path = d / "public.pem"
if not priv_path.exists():
return None, None, "No local keys found."
pem = priv_path.read_bytes()
try:
private_key = load_private_key(pem, password=password)
except Exception:
try:
private_key = load_private_key(pem, password=None)
if password:
save_keys(email, private_key, load_public_key(pub_path.read_bytes()), password=password)
except Exception:
return None, None, "Invalid or missing password."
public_key = load_public_key(pub_path.read_bytes())
return private_key, public_key, None
# ---------------------------------------------------------------------------
# Identity + prekey storage
# ---------------------------------------------------------------------------
def _save_identity_keys(email: str, ed_priv, ed_pub, password: bytes | None = None):
d = get_key_dir(email)
if password:
(d / "identity_private.bin").write_bytes(serialize_ed25519_private(ed_priv, password=password))
else:
(d / "identity_private.bin").write_bytes(serialize_ed25519_private_raw(ed_priv))
(d / "identity_public.bin").write_bytes(serialize_ed25519_public(ed_pub))
os.chmod(d / "identity_private.bin", 0o600)
def _load_identity_keys(email: str, password: bytes | None = None):
d = get_key_dir(email)
priv_path = d / "identity_private.bin"
pub_path = d / "identity_public.bin"
if not priv_path.exists():
return None, None
priv = load_ed25519_private(priv_path.read_bytes(), password=password)
pub = load_ed25519_public(pub_path.read_bytes())
return priv, pub
def _save_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None):
d = get_key_dir(email)
raw = serialize_x25519_private(spk_priv)
data = _encrypt_local(raw, local_key) if local_key else raw
(d / "spk_private.bin").write_bytes(data)
(d / "spk_id.txt").write_text(spk_id)
os.chmod(d / "spk_private.bin", 0o600)
def _load_spk(email: str, local_key: bytes | None = None):
d = get_key_dir(email)
priv_path = d / "spk_private.bin"
id_path = d / "spk_id.txt"
if not priv_path.exists():
return None, None
raw = priv_path.read_bytes()
if local_key:
try:
raw = _decrypt_local(raw, local_key)
except Exception:
# Plaintext fallback (migration) — re-save encrypted
pass
priv = load_x25519_private(raw)
spk_id = id_path.read_text().strip() if id_path.exists() else ""
if local_key:
_save_spk(email, priv, spk_id, local_key)
return priv, spk_id
def _save_prev_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None):
"""Save previous SPK for grace period (in-flight X3DH may reference old SPK)."""
d = get_key_dir(email)
raw = serialize_x25519_private(spk_priv)
data = _encrypt_local(raw, local_key) if local_key else raw
(d / "prev_spk_private.bin").write_bytes(data)
(d / "prev_spk_id.txt").write_text(spk_id)
os.chmod(d / "prev_spk_private.bin", 0o600)
def _load_prev_spk(email: str, local_key: bytes | None = None):
"""Load previous SPK (grace period). Returns (private_key, spk_id) or (None, None)."""
d = get_key_dir(email)
priv_path = d / "prev_spk_private.bin"
id_path = d / "prev_spk_id.txt"
if not priv_path.exists():
return None, None
raw = priv_path.read_bytes()
if local_key:
try:
raw = _decrypt_local(raw, local_key)
except Exception:
pass
priv = load_x25519_private(raw)
spk_id = id_path.read_text().strip() if id_path.exists() else ""
if local_key:
_save_prev_spk(email, priv, spk_id, local_key)
return priv, spk_id
def _save_opk_private(email: str, opk_id: str, opk_priv, local_key: bytes | None = None):
d = get_key_dir(email) / "opk_private"
d.mkdir(parents=True, exist_ok=True)
os.chmod(d, 0o700)
raw = serialize_x25519_private(opk_priv)
data = _encrypt_local(raw, local_key) if local_key else raw
(d / f"{opk_id}.bin").write_bytes(data)
os.chmod(d / f"{opk_id}.bin", 0o600)
def _load_opk_private(email: str, opk_id: str, local_key: bytes | None = None):
d = get_key_dir(email) / "opk_private"
p = d / f"{opk_id}.bin"
if not p.exists():
return None
raw = p.read_bytes()
if local_key:
try:
raw = _decrypt_local(raw, local_key)
except Exception:
pass
priv = load_x25519_private(raw)
# Migration: re-save encrypted if local_key provided
if local_key:
_save_opk_private(email, opk_id, priv, local_key)
return priv
def _secure_delete(p: Path):
"""Overwrite file with random data before deletion (anti-forensic wipe)."""
try:
if not p.exists():
return
size = p.stat().st_size
if size > 0:
with open(p, "r+b") as f:
f.write(os.urandom(size))
f.flush()
os.fsync(f.fileno())
p.unlink()
except Exception:
try:
p.unlink(missing_ok=True)
except Exception:
pass
def _delete_opk_private(email: str, opk_id: str):
d = get_key_dir(email) / "opk_private"
p = d / f"{opk_id}.bin"
_secure_delete(p)
def _save_device_id(email: str, device_id: str):
d = get_key_dir(email)
p = d / "device_id.txt"
p.write_text(device_id)
os.chmod(p, 0o600)
def _load_device_id(email: str) -> str | None:
d = get_key_dir(email)
p = d / "device_id.txt"
if not p.exists():
return None
return p.read_text().strip() or None
# ------------------------------------------------------------------
# Expiring LRU cache for user info (max_size + TTL eviction)
# ------------------------------------------------------------------
class _ExpiringLRUCache:
"""Dict-like cache with max size (LRU eviction) and per-entry TTL."""
def __init__(self, max_size: int = 10_000, ttl: float = 3600.0):
self._max_size = max_size
self._ttl = ttl
self._data: collections.OrderedDict = collections.OrderedDict() # key -> (value, ts)
def get(self, key, default=None):
entry = self._data.get(key)
if entry is None:
return default
value, ts = entry
if time.monotonic() - ts > self._ttl:
del self._data[key]
return default
# Move to end (most recently used)
self._data.move_to_end(key)
return value
def __setitem__(self, key, value):
if key in self._data:
self._data.move_to_end(key)
self._data[key] = (value, time.monotonic())
while len(self._data) > self._max_size:
self._data.popitem(last=False)
def __getitem__(self, key):
result = self.get(key)
if result is None and key not in self._data:
raise KeyError(key)
return result
def __contains__(self, key):
return self.get(key) is not None
# ------------------------------------------------------------------
# Identity key change exception (TOFU hard-fail)
# ------------------------------------------------------------------
class IdentityKeyChanged(Exception):
"""Raised when a peer's identity key has changed (TOFU violation).
Session creation is blocked until the user explicitly accepts the new key.
"""
def __init__(self, user_id: str, new_key_bytes: bytes, status: str):
self.user_id = user_id
self.new_key_bytes = new_key_bytes
self.status = status # "changed" or "changed_verified"
super().__init__(
f"Identity key changed for user {user_id} (status={status}). "
f"Accept the new key before communicating."
)
# ------------------------------------------------------------------
# Client-side brute-force lockout
# ------------------------------------------------------------------
_LOCKOUT_BASE_SECONDS = 2
_LOCKOUT_MAX_SECONDS = 300 # 5 min cap
def _get_lockout_path(email: str) -> Path:
return get_key_dir(email) / "login_lockout.json"
def _check_lockout(email: str) -> float:
"""Return seconds remaining until next attempt allowed. 0 = can try now."""
p = _get_lockout_path(email)
if not p.exists():
return 0.0
try:
data = json.loads(p.read_text())
locked_until = data.get("locked_until", 0.0)
remaining = locked_until - time.time()
return max(0.0, remaining)
except Exception:
return 0.0
def _record_failed_attempt(email: str):
"""Increment failed counter, update locked_until."""
p = _get_lockout_path(email)
failed = 0
try:
if p.exists():
data = json.loads(p.read_text())
failed = data.get("failed_attempts", 0)
except Exception:
pass
failed += 1
delay = min(_LOCKOUT_BASE_SECONDS ** failed, _LOCKOUT_MAX_SECONDS)
locked_until = time.time() + delay
p.write_text(json.dumps({"failed_attempts": failed, "locked_until": locked_until}))
os.chmod(p, 0o600)
def _clear_lockout(email: str):
"""Reset on successful login."""
p = _get_lockout_path(email)
if p.exists():
try:
p.unlink()
except Exception:
pass
def _save_session(email: str, peer_user_id: str, ratchet: DoubleRatchet,
local_key: bytes | None = None, peer_device_id: str | None = None):
d = get_key_dir(email) / "sessions"
d.mkdir(parents=True, exist_ok=True)
os.chmod(d, 0o700)
if peer_device_id:
filename = f"{peer_user_id}_{peer_device_id}.bin"
else:
filename = f"{peer_user_id}.bin"
p = d / filename
data = ratchet.export_state()
if local_key:
data = _encrypt_local(data, local_key)
p.write_bytes(data)
os.chmod(p, 0o600)
def _load_session(email: str, peer_user_id: str,
local_key: bytes | None = None,
peer_device_id: str | None = None) -> DoubleRatchet | None:
d = get_key_dir(email) / "sessions"
if peer_device_id:
p = d / f"{peer_user_id}_{peer_device_id}.bin"
if not p.exists():
# Fallback: try old format (no device_id) and migrate
p_old = d / f"{peer_user_id}.bin"
if p_old.exists():
ratchet = _load_session_file(p_old, local_key)
if ratchet:
_save_session(email, peer_user_id, ratchet, local_key,
peer_device_id=peer_device_id)
_secure_delete(p_old)
return ratchet
return None
else:
p = d / f"{peer_user_id}.bin"
if not p.exists():
return None
return _load_session_file(p, local_key)
def _load_session_file(p: Path, local_key: bytes | None = None) -> DoubleRatchet | None:
"""Load a session from a specific file path."""
if not p.exists():
return None
raw = p.read_bytes()
if local_key:
try:
data = _decrypt_local(raw, local_key)
except Exception:
# Migration: try loading as plaintext (old unencrypted format)
try:
ratchet = DoubleRatchet.import_state(raw)
return ratchet
except Exception:
return None
return DoubleRatchet.import_state(data)
return DoubleRatchet.import_state(raw)
def _delete_session_file(email: str, peer_user_id: str, peer_device_id: str | None = None):
"""Securely delete a session file from disk (for session reset)."""
d = get_key_dir(email) / "sessions"
if peer_device_id:
p = d / f"{peer_user_id}_{peer_device_id}.bin"
else:
p = d / f"{peer_user_id}.bin"
_secure_delete(p)
def _save_sender_key_state(email: str, conv_id: str, state: SenderKeyState,
local_key: bytes | None = None):
d = get_key_dir(email) / "sender_keys"
d.mkdir(parents=True, exist_ok=True)
os.chmod(d, 0o700)
p = d / f"{conv_id}.bin"
data = state.export_state()
if local_key:
data = _encrypt_local(data, local_key)
p.write_bytes(data)
os.chmod(p, 0o600)
def _load_sender_key_state(email: str, conv_id: str,
local_key: bytes | None = None) -> SenderKeyState | None:
d = get_key_dir(email) / "sender_keys"
p = d / f"{conv_id}.bin"
if not p.exists():
return None
raw = p.read_bytes()
if local_key:
try:
data = _decrypt_local(raw, local_key)
except Exception:
try:
sk = SenderKeyState.import_state(raw)
_save_sender_key_state(email, conv_id, sk, local_key)
return sk
except Exception:
return None
return SenderKeyState.import_state(data)
return SenderKeyState.import_state(raw)
def _save_sender_key_recipients(email: str, conv_id: str, recipients: set[str],
local_key: bytes | None = None):
d = get_key_dir(email) / "sender_keys"
d.mkdir(parents=True, exist_ok=True)
os.chmod(d, 0o700)
p = d / f"{conv_id}.members.bin"
data = json.dumps(sorted(recipients), ensure_ascii=False).encode("utf-8")
if local_key:
data = _encrypt_local(data, local_key)
p.write_bytes(data)
os.chmod(p, 0o600)
def _load_sender_key_recipients(email: str, conv_id: str,
local_key: bytes | None = None) -> set[str]:
d = get_key_dir(email) / "sender_keys"
p = d / f"{conv_id}.members.bin"
if not p.exists():
return set()
raw = p.read_bytes()
if local_key:
try:
data = _decrypt_local(raw, local_key)
except Exception:
# Migration: previous plaintext storage, re-save encrypted on success.
try:
parsed = json.loads(raw.decode("utf-8"))
recipients = {str(uid) for uid in parsed if isinstance(uid, str)}
_save_sender_key_recipients(email, conv_id, recipients, local_key)
return recipients
except Exception:
return set()
else:
data = raw
try:
parsed = json.loads(data.decode("utf-8"))
if not isinstance(parsed, list):
return set()
return {str(uid) for uid in parsed if isinstance(uid, str)}
except Exception:
return set()
def _save_recv_sender_key(email: str, conv_id: str, sender_id: str, state: SenderKeyState,
local_key: bytes | None = None,
sender_device_id: str | None = None):
d = get_key_dir(email) / "sender_keys_recv"
d.mkdir(parents=True, exist_ok=True)
os.chmod(d, 0o700)
if sender_device_id:
filename = f"{conv_id}_{sender_id}_{sender_device_id}.bin"
else:
filename = f"{conv_id}_{sender_id}.bin"
p = d / filename
data = state.export_state()
if local_key:
data = _encrypt_local(data, local_key)
p.write_bytes(data)
os.chmod(p, 0o600)
def _load_recv_sender_key(email: str, conv_id: str, sender_id: str,
local_key: bytes | None = None,
sender_device_id: str | None = None) -> SenderKeyState | None:
d = get_key_dir(email) / "sender_keys_recv"
if sender_device_id:
p = d / f"{conv_id}_{sender_id}_{sender_device_id}.bin"
if not p.exists():
# Fallback: try old format and migrate
p_old = d / f"{conv_id}_{sender_id}.bin"
if p_old.exists():
sk = _load_recv_sender_key_file(p_old, local_key)
if sk:
_save_recv_sender_key(email, conv_id, sender_id, sk, local_key,
sender_device_id=sender_device_id)
_secure_delete(p_old)
return sk
return None
else:
p = d / f"{conv_id}_{sender_id}.bin"
if not p.exists():
return None
return _load_recv_sender_key_file(p, local_key)
def _load_recv_sender_key_file(p: Path, local_key: bytes | None = None) -> SenderKeyState | None:
"""Load a recv sender key from a specific file path."""
if not p.exists():
return None
raw = p.read_bytes()
if local_key:
try:
data = _decrypt_local(raw, local_key)
except Exception:
try:
sk = SenderKeyState.import_state(raw)
return sk
except Exception:
return None
return SenderKeyState.import_state(data)
return SenderKeyState.import_state(raw)
# ---------------------------------------------------------------------------
# Local decrypted message cache (Double Ratchet keys are one-time use)
# ---------------------------------------------------------------------------
def _load_message_cache(email: str, conv_id: str, cache_key: bytes | None = None) -> dict:
d = get_key_dir(email) / "message_cache"
p_bin = d / f"{conv_id}.bin"
p_json = d / f"{conv_id}.json"
# Migration: if old plaintext .json exists but encrypted .bin doesn't
if p_json.exists() and not p_bin.exists():
try:
cache = json.loads(p_json.read_text("utf-8"))
if cache_key:
_save_message_cache_full(d, conv_id, cache, cache_key)
_secure_delete(p_json)
return cache
except Exception:
return {}
if not p_bin.exists():
return {}
if not cache_key:
return {}
try:
raw = p_bin.read_bytes()
# Format: nonce (12) + tag (16) + ciphertext
nonce = raw[:12]
tag = raw[12:28]
ct = raw[28:]
plaintext = aes_decrypt(cache_key, nonce, ct, tag)
return json.loads(plaintext.decode("utf-8"))
except Exception:
return {}
def _save_message_cache_full(d: Path, conv_id: str, cache: dict, cache_key: bytes):
"""Write the full cache dict encrypted to disk."""
d.mkdir(parents=True, exist_ok=True)
os.chmod(d, 0o700)
p = d / f"{conv_id}.bin"
plaintext = json.dumps(cache, ensure_ascii=False).encode("utf-8")
_key, nonce, ct, tag = aes_encrypt(plaintext, key=cache_key)
p.write_bytes(nonce + tag + ct)
os.chmod(p, 0o600)
def _save_message_to_cache(email: str, conv_id: str, message_id: str, payload: dict,
cache_key: bytes | None = None):
d = get_key_dir(email) / "message_cache"
cache = _load_message_cache(email, conv_id, cache_key)
cache[message_id] = payload
if cache_key:
_save_message_cache_full(d, conv_id, cache, cache_key)
else:
# Fallback: plaintext (no identity key available yet)
d.mkdir(parents=True, exist_ok=True)
os.chmod(d, 0o700)
p = d / f"{conv_id}.json"
p.write_text(json.dumps(cache, ensure_ascii=False), "utf-8")
os.chmod(p, 0o600)
# ---------------------------------------------------------------------------
# Verification storage (TOFU + explicit verification)
# ---------------------------------------------------------------------------
def _save_known_identity_keys(email: str, keys: dict, local_key: bytes | None = None):
"""Save TOFU identity key registry (encrypted with local_key)."""
p = get_key_dir(email) / "known_identity_keys.bin"
data = json.dumps({"version": 1, "keys": keys}).encode("utf-8")
if local_key:
data = _encrypt_local(data, local_key)
p.write_bytes(data)
os.chmod(p, 0o600)
def _load_known_identity_keys(email: str, local_key: bytes | None = None) -> dict:
"""Load TOFU identity key registry. Returns empty dict on error.
No plaintext fallback — these files were never stored unencrypted
(feature introduced after local encryption was implemented).
Accepting plaintext would allow an attacker with disk access to
inject fake identity keys and bypass TOFU warnings.
"""
p = get_key_dir(email) / "known_identity_keys.bin"
if not p.exists():
return {}
raw = p.read_bytes()
try:
if local_key:
data = _decrypt_local(raw, local_key)
else:
data = raw
obj = json.loads(data)
return obj.get("keys", {})
except Exception:
return {}
def _save_verified_contacts(email: str, contacts: dict, local_key: bytes | None = None):
"""Save explicit verification state (encrypted with local_key)."""
p = get_key_dir(email) / "verified_contacts.bin"
data = json.dumps({"version": 1, "contacts": contacts}).encode("utf-8")
if local_key:
data = _encrypt_local(data, local_key)
p.write_bytes(data)
os.chmod(p, 0o600)
def _load_verified_contacts(email: str, local_key: bytes | None = None) -> dict:
"""Load explicit verification state. Returns empty dict on error.
No plaintext fallback — these files were never stored unencrypted.
Accepting plaintext would allow an attacker with disk access to
inject fake verification records (mark attacker as "verified").
"""
p = get_key_dir(email) / "verified_contacts.bin"
if not p.exists():
return {}
raw = p.read_bytes()
try:
if local_key:
data = _decrypt_local(raw, local_key)
else:
data = raw
obj = json.loads(data)
return obj.get("contacts", {})
except Exception:
return {}
# How many sync cycles a message that fails to decrypt is retried before it
# is recorded as permanently failed and the sync watermark moves past it.
_MAX_DECRYPT_RETRIES = 3
def _solve_pow(challenge: str, difficulty: int) -> str:
"""Solve a proof-of-work challenge by finding a nonce with enough leading zero bits."""
target_bytes = difficulty // 8
target_bits = difficulty % 8
mask = (0xFF << (8 - target_bits)) & 0xFF if target_bits else 0
nonce = 0
while True:
digest = hashlib.sha256(f"{challenge}{nonce}".encode()).digest()
# Fast path: check full zero bytes first
ok = True
for i in range(target_bytes):
if digest[i] != 0:
ok = False
break
if ok and target_bits:
if digest[target_bytes] & mask:
ok = False
if ok:
return str(nonce)
nonce += 1
class ChatClient:
def __init__(self):
self.reader: ProtocolReader | None = None
self.writer: ProtocolWriter | None = None
self.raw_writer: asyncio.StreamWriter | None = None
self.session: dict | None = None
self.private_key = None # RSA private key (login only)
self.public_key = None # RSA public key (login only)
self.username: str = ""
self.email: str = ""
self._listener_task: asyncio.Task | None = None
self._response_queue: asyncio.Queue = asyncio.Queue()
self._notification_queue: asyncio.Queue = asyncio.Queue()
self._pending: dict[str, asyncio.Future] = {}
self._pairing_temp_private_key = None
self._pairing_fingerprint: str = ""
self._pairing_code: str = ""
self._reencrypt_progress_cb = None
self._logger = logging.getLogger("encrypted_chat.client")
# Signal Protocol keys
self.identity_private = None # Ed25519PrivateKey
self.identity_public = None # Ed25519PublicKey
self.spk_private = None # X25519PrivateKey (current signed prekey)
self.spk_id: str = ""
self._prev_spk_private = None # Previous SPK for grace period (M4)
self._prev_spk_id: str = ""
self.opk_privates: dict[str, object] = {} # id -> X25519PrivateKey
self.sessions: dict[str, DoubleRatchet] = {} # "user_id:device_id" -> ratchet
self.sender_key_states: dict[str, SenderKeyState] = {} # conv_id -> own sender key
self.recv_sender_keys: dict[str, SenderKeyState] = {} # "conv_id:sender_id:device_id" -> their key
# Cache: user_id -> {identity_key (Ed25519PublicKey), username, email}
# Bounded to 10K entries with 1-hour TTL to prevent unbounded growth (L7)
self._user_cache: _ExpiringLRUCache = _ExpiringLRUCache(max_size=10_000, ttl=3600.0)
self.connected: bool = False
self.login_rejected: bool = False
self._cache_key: bytes | None = None # AES key for encrypting message cache on disk
self._local_key: bytes | None = None # AES key for encrypting session/sender key files
# Multi-device support
self.device_id: str | None = None # This device's UUID
self._device_bundle_cache: dict[str, tuple[float, list[dict]]] = {} # user_id -> (ts, bundles)
# Queue of received messages to self-encrypt for multi-device access
self._pending_self_encrypt: list[dict] = []
self._typing_active: dict[str, bool] = {}
self._typing_last_sent: dict[str, float] = {}
self._typing_stop_tasks: dict[str, asyncio.Task] = {}
# Contact key verification (TOFU + explicit)
self._known_identity_keys: dict = {} # user_id -> {identity_key hex, first_seen, last_seen}
self._verified_contacts: dict = {} # user_id -> {identity_key hex, verified_at, method}
self._key_change_cb = None # callback(user_id, username, old_key_hex, was_verified)
async def connect(self):
host = os.getenv("SERVER_HOST", "127.0.0.1")
port = int(os.getenv("SERVER_PORT", "9999"))
tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes")
tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes")
ssl_context = None
if tls_required and not tls_enabled:
raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.")
if tls_enabled:
insecure = os.getenv("TLS_INSECURE", "false").lower() in ("1", "true", "yes")
is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development")
if insecure and not is_dev:
raise RuntimeError("TLS_INSECURE is only allowed when ENVIRONMENT=dev")
ssl_context = ssl.create_default_context()
ca_file = os.getenv("TLS_CA_FILE", "").strip()
if ca_file:
ssl_context.load_verify_locations(cafile=ca_file)
elif insecure:
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
else:
self._logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.")
r, w = await asyncio.open_connection(host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context)
# Enable TCP keepalive to detect dead connections through NAT/firewalls
sock = w.get_extra_info("socket")
if sock is not None:
import socket
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if hasattr(socket, "TCP_KEEPIDLE"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 25)
if hasattr(socket, "TCP_KEEPINTVL"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10)
if hasattr(socket, "TCP_KEEPCNT"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3)
except OSError:
pass
self.reader = ProtocolReader(r)
self.writer = ProtocolWriter(w)
self.raw_writer = w
self.connected = True
self._logger.info("Connected to %s:%s (tls=%s)", host, port, "on" if tls_enabled else "off")
def server_endpoint(self) -> str:
host = os.getenv("SERVER_HOST", "127.0.0.1")
port = os.getenv("SERVER_PORT", "9999")
return f"{host}:{port}"
def pairing_fingerprint(self) -> str:
return self._pairing_fingerprint
def pairing_qr_data(self) -> bytes | None:
if not self._pairing_code or not self._pairing_fingerprint:
return None
return encode_pairing_qr(self._pairing_code, self._pairing_fingerprint)
async def _background_listener(self):
"""Read messages from server, routing responses vs notifications."""
while True:
msg = await self.reader.read_message()
if msg is None:
self.connected = False
# Fail all pending futures so send_and_recv doesn't hang
pending = dict(self._pending)
self._pending.clear()
err = ConnectionError("Server connection lost")
for obj in pending.values():
if isinstance(obj, asyncio.Queue):
# Signal stream consumers that connection died
obj.put_nowait({"status": "error", "data": {"message": "Connection lost"}})
elif not obj.done():
obj.set_exception(err)
break
# Responses to our own requests (have request_id matching a pending future)
# must be routed to the pending future, even if the type matches a notification name.
req_id = msg.get("request_id")
if req_id and req_id in self._pending:
pending_obj = self._pending[req_id]
if isinstance(pending_obj, asyncio.Queue):
await pending_obj.put(msg)
else:
self._pending.pop(req_id)
if not pending_obj.done():
pending_obj.set_result(msg)
elif msg.get("type") in ("new_message", "messages_read", "message_deleted",
"conversation_created", "member_added", "member_removed",
"user_online", "user_offline", "online_users",
"group_invitation", "conversation_renamed",
"device_added",
"session_reset",
"message_reacted", "message_pinned", "message_unpinned",
"message_delivered", "username_changed",
"avatar_changed", "keys_updated",
"typing_start", "typing_stop"):
await self._notification_queue.put(msg)
else:
await self._response_queue.put(msg)
async def send_and_recv(self, msg_type: str, timeout: float = 30.0, **kwargs) -> dict:
try:
request_id = str(uuid.uuid4())
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._pending[request_id] = fut
await self.writer.send_request(msg_type, request_id=request_id, **kwargs)
except (ValueError, ConnectionError, OSError) as e:
self._pending.pop(request_id, None)
return {
"type": msg_type,
"status": "error",
"data": {"message": str(e) or "Connection lost."},
}
try:
return await asyncio.wait_for(fut, timeout=timeout)
except asyncio.TimeoutError:
self._logger.warning("send_and_recv timeout for '%s' after %.0fs", msg_type, timeout)
return {
"type": msg_type,
"status": "error",
"data": {"message": f"Request timed out ({msg_type})"},
}
except ConnectionError:
return {
"type": msg_type,
"status": "error",
"data": {"message": "Connection lost."},
}
finally:
self._pending.pop(request_id, None)
# ------------------------------------------------------------------
# User info / identity key cache
# ------------------------------------------------------------------
async def _get_user_info(self, user_id: str = "", email: str = "") -> dict | None:
"""Get user info from server, cache identity key. Performs TOFU check."""
cached = self._user_cache.get(user_id)
if cached:
return cached
kwargs = {}
if user_id:
kwargs["user_id"] = user_id
elif email:
kwargs["email"] = email
else:
return None
resp = await self.send_and_recv("get_user_info", **kwargs)
if resp["status"] != "ok":
return None
data = resp["data"]
ik_bytes = decode_binary(data["identity_key"]) if data.get("identity_key") else None
info = {
"user_id": data["user_id"],
"username": data["username"],
"email": data["email"],
"identity_key": load_ed25519_public(ik_bytes) if ik_bytes else None,
"identity_key_bytes": ik_bytes,
}
# TOFU: check identity key against known keys
if ik_bytes:
status = self.check_identity_key(data["user_id"], ik_bytes)
info["identity_key_status"] = status
self._user_cache[data["user_id"]] = info
return info
# ------------------------------------------------------------------
# Contact Key Verification
# ------------------------------------------------------------------
def _load_verification_stores(self):
"""Load TOFU and verification stores from disk."""
if not self.email:
return
self._known_identity_keys = _load_known_identity_keys(self.email, self._local_key)
self._verified_contacts = _load_verified_contacts(self.email, self._local_key)
def check_identity_key(self, user_id: str, identity_key_bytes: bytes) -> str:
"""Check a user's identity key against TOFU registry.
Returns:
"new" — first contact, key recorded (TOFU)
"trusted" — key matches previously seen, not explicitly verified
"verified" — key matches and explicitly verified
"changed" — key differs from recorded (WARNING)
"changed_verified" — key changed AND was previously verified (CRITICAL)
"""
ik_hex = identity_key_bytes.hex()
now = datetime.now(timezone.utc).isoformat()
known = self._known_identity_keys.get(user_id)
if known is None:
# First time seeing this user — TOFU: trust on first use
self._known_identity_keys[user_id] = {
"identity_key": ik_hex,
"first_seen": now,
"last_seen": now,
}
if self.email:
_save_known_identity_keys(self.email, self._known_identity_keys, self._local_key)
return "new"
if known["identity_key"] == ik_hex:
# Key matches — update last_seen
known["last_seen"] = now
if self.email:
_save_known_identity_keys(self.email, self._known_identity_keys, self._local_key)
# Check if explicitly verified
verified = self._verified_contacts.get(user_id)
if verified and verified.get("identity_key") == ik_hex:
return "verified"
return "trusted"
# Key has CHANGED
was_verified = user_id in self._verified_contacts
old_key_hex = known["identity_key"]
# Invoke callback for GUI/CLI warning
if self._key_change_cb:
username = ""
cached = self._user_cache.get(user_id)
if cached:
username = cached.get("username", "")
try:
self._key_change_cb(user_id, username, old_key_hex, was_verified, identity_key_bytes)
except Exception:
pass
return "changed_verified" if was_verified else "changed"
def verify_contact(self, user_id: str, identity_key_bytes: bytes, method: str = "manual"):
"""Mark a contact's identity key as explicitly verified."""
ik_hex = identity_key_bytes.hex()
now = datetime.now(timezone.utc).isoformat()
self._verified_contacts[user_id] = {
"identity_key": ik_hex,
"verified_at": now,
"method": method,
}
# Also ensure TOFU registry is up to date
if user_id not in self._known_identity_keys:
self._known_identity_keys[user_id] = {
"identity_key": ik_hex,
"first_seen": now,
"last_seen": now,
}
else:
self._known_identity_keys[user_id]["last_seen"] = now
if self.email:
_save_verified_contacts(self.email, self._verified_contacts, self._local_key)
_save_known_identity_keys(self.email, self._known_identity_keys, self._local_key)
# Update user cache status
cached = self._user_cache.get(user_id)
if cached:
cached["identity_key_status"] = "verified"
def unverify_contact(self, user_id: str):
"""Remove explicit verification for a contact."""
self._verified_contacts.pop(user_id, None)
if self.email:
_save_verified_contacts(self.email, self._verified_contacts, self._local_key)
cached = self._user_cache.get(user_id)
if cached and cached.get("identity_key_status") == "verified":
cached["identity_key_status"] = "trusted"
def accept_key_change(self, user_id: str, new_ik_bytes: bytes):
"""Accept a changed identity key — update TOFU, remove old verification."""
ik_hex = new_ik_bytes.hex()
now = datetime.now(timezone.utc).isoformat()
self._known_identity_keys[user_id] = {
"identity_key": ik_hex,
"first_seen": now,
"last_seen": now,
}
# Remove old verification — user must re-verify
self._verified_contacts.pop(user_id, None)
if self.email:
_save_known_identity_keys(self.email, self._known_identity_keys, self._local_key)
_save_verified_contacts(self.email, self._verified_contacts, self._local_key)
# Update cache
cached = self._user_cache.get(user_id)
if cached:
cached["identity_key_status"] = "trusted"
def get_verification_status(self, user_id: str) -> str:
"""Get verification status for a user.
Returns: "verified", "trusted", or "unverified".
"""
verified = self._verified_contacts.get(user_id)
if verified:
# Check key still matches
known = self._known_identity_keys.get(user_id)
if known and known.get("identity_key") == verified.get("identity_key"):
return "verified"
if user_id in self._known_identity_keys:
return "trusted"
return "unverified"
def get_safety_number(self, peer_user_id: str) -> str | None:
"""Get formatted safety number for a peer (requires both identity keys)."""
if not self.identity_public or not self.session:
return None
my_uid = self.session.get("user_id", "")
my_ik_bytes = serialize_ed25519_public(self.identity_public)
cached = self._user_cache.get(peer_user_id)
if not cached or not cached.get("identity_key_bytes"):
return None
return compute_safety_number(my_uid, my_ik_bytes,
peer_user_id, cached["identity_key_bytes"])
def get_my_fingerprint(self) -> str | None:
"""Get formatted fingerprint for own identity key."""
if not self.identity_public or not self.session:
return None
my_uid = self.session.get("user_id", "")
my_ik_bytes = serialize_ed25519_public(self.identity_public)
fp = compute_fingerprint(my_uid, my_ik_bytes)
return format_fingerprint(fp)
def get_peer_fingerprint(self, peer_user_id: str) -> str | None:
"""Get formatted fingerprint for a peer's identity key."""
cached = self._user_cache.get(peer_user_id)
if not cached or not cached.get("identity_key_bytes"):
return None
fp = compute_fingerprint(peer_user_id, cached["identity_key_bytes"])
return format_fingerprint(fp)
def get_verification_qr_data(self) -> bytes | None:
"""Get QR code payload bytes for own identity (for peer to scan)."""
if not self.identity_public or not self.session:
return None
my_uid = self.session.get("user_id", "")
my_ik_bytes = serialize_ed25519_public(self.identity_public)
return encode_verification_qr(my_uid, my_ik_bytes)
def verify_qr_code(self, qr_data: bytes) -> tuple[bool, str, str]:
"""Verify a scanned QR code against known identity keys.
Returns (success, user_id, message).
"""
try:
user_id, ik_bytes = decode_verification_qr(qr_data)
except ValueError as e:
return False, "", f"Invalid QR code: {e}"
cached = self._user_cache.get(user_id)
if not cached:
return False, user_id, "Unknown user — not in your contacts."
if not cached.get("identity_key_bytes"):
return False, user_id, "No identity key on record for this user."
if cached["identity_key_bytes"] != ik_bytes:
return False, user_id, "Identity key MISMATCH — verification failed!"
# Keys match — mark as verified
self.verify_contact(user_id, ik_bytes, method="qr_code")
username = cached.get("username", user_id[:8])
return True, user_id, f"Verified {username} via QR code."
# ------------------------------------------------------------------
# Registration
# ------------------------------------------------------------------
async def register(self, username: str, password: str, email: str) -> tuple[bool, str]:
"""Register user. Generates RSA + Ed25519 in memory (saved to disk
only after server confirms registration via confirm_registration)."""
self.username = username
self.email = email
pwd_bytes = bytearray(password.encode("utf-8")) if password else None
try:
pwd = bytes(pwd_bytes) if pwd_bytes else None
# Try loading existing keys (previous successful registration)
priv, pub, err = load_keys(email, password=pwd)
if priv is None:
priv, pub = generate_rsa_keypair()
self.private_key = priv
self.public_key = pub
try:
ed_priv, ed_pub = _load_identity_keys(email, password=pwd)
except Exception:
ed_priv, ed_pub = None, None
if ed_priv is None:
ed_priv, ed_pub = generate_identity_keypair()
self.identity_private = ed_priv
self.identity_public = ed_pub
self._cache_key = derive_self_encryption_key(ed_priv)
self._local_key = derive_local_storage_key(ed_priv)
# Store password for saving keys after confirm
self._reg_password = pwd
finally:
if pwd_bytes:
pwd_bytes[:] = b'\x00' * len(pwd_bytes)
pub_pem = serialize_public_key(pub).decode("utf-8")
ik_b64 = encode_binary(serialize_ed25519_public(ed_pub))
extra_fields: dict = {}
start = await self.send_and_recv(
"register",
username=username,
public_key=pub_pem,
email=email,
identity_key=ik_b64,
)
# Handle PoW challenge (server under pressure)
if start.get("status") == "pow_required":
challenge = start["data"]["challenge"]
mac = start["data"]["mac"]
difficulty = start["data"]["difficulty"]
self._logger.info("Server requires proof-of-work (difficulty %d), solving...", difficulty)
nonce = await asyncio.get_running_loop().run_in_executor(
None, _solve_pow, challenge, difficulty)
extra_fields = {"pow_challenge": challenge, "pow_mac": mac, "pow_nonce": nonce}
start = await self.send_and_recv(
"register",
username=username,
public_key=pub_pem,
email=email,
identity_key=ik_b64,
**extra_fields,
)
if start["status"] != "ok":
self._reg_password = None
return False, start["data"]["message"]
code = start["data"].get("code")
if code:
return True, code
return True, start["data"].get("message", "Check your email for the code.")
async def confirm_registration(self, email: str, username: str, code: str) -> tuple[bool, str]:
confirm = await self.send_and_recv("register_confirm", email=email, code=code)
if confirm["status"] == "ok":
# Registration confirmed — NOW save keys to disk
pwd = getattr(self, "_reg_password", None)
save_keys(email, self.private_key, self.public_key, password=pwd)
_save_identity_keys(email, self.identity_private, self.identity_public, password=pwd)
self._reg_password = None
self._load_verification_stores()
return True, f"Registered as '{username}' (ID: {confirm['data']['user_id']})"
return False, confirm["data"]["message"]
async def _generate_and_upload_prekeys(self, keep_spk: bool = False):
"""Generate SPK + OPKs and upload to server.
If keep_spk=True, re-sign the existing SPK instead of generating a new
one. This is used after device pairing so both devices share the same
SPK and either can respond to X3DH.
"""
if not self.identity_private:
return
if keep_spk and self.spk_private and self.spk_id:
# Re-sign existing SPK (both devices share the identity key)
spk_pub_bytes = serialize_x25519_public(self.spk_private.public_key())
spk_sig = ed25519_sign(self.identity_private, spk_pub_bytes)
spk_data = {
"id": self.spk_id,
"public_key": encode_binary(spk_pub_bytes),
"signature": encode_binary(spk_sig),
}
else:
# Save current SPK as previous for grace period (M4: in-flight X3DH)
if self.spk_private and self.spk_id:
self._prev_spk_private = self.spk_private
self._prev_spk_id = self.spk_id
_save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key)
# Generate a brand-new signed prekey
spk = generate_signed_prekey(self.identity_private)
self.spk_private = spk["private"]
self.spk_id = spk["id"]
_save_spk(self.email, spk["private"], spk["id"], self._local_key)
spk_data = {
"id": spk["id"],
"public_key": encode_binary(serialize_x25519_public(spk["public"])),
"signature": encode_binary(spk["signature"]),
}
# Generate one-time prekeys
opks = generate_one_time_prekeys(OPK_BATCH_SIZE)
for opk in opks:
self.opk_privates[opk["id"]] = opk["private"]
_save_opk_private(self.email, opk["id"], opk["private"], self._local_key)
# Upload to server
otp_data = [
{"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))}
for opk in opks
]
resp = await self.send_and_recv(
"upload_prekeys",
signed_prekey=spk_data,
one_time_prekeys=otp_data,
)
if resp.get("status") != "ok":
self._logger.warning("upload_prekeys failed: %s (will retry on login)",
resp.get("data", {}).get("message", "unknown"))
async def _ensure_prekeys(self):
"""Check OPK count and SPK age, replenish/rotate if needed.
Uses single-roundtrip `ensure_prekeys` handler when available,
falls back to legacy two-step flow (get_prekey_count + upload_prekeys).
"""
resp = await self.send_and_recv("get_prekey_count")
if resp["status"] != "ok":
return
count = resp["data"].get("count", 0)
spk_created_at = resp["data"].get("spk_created_at", "")
need_new_spk = False
if spk_created_at:
try:
created = datetime.fromisoformat(spk_created_at)
if created.tzinfo is None:
created = created.replace(tzinfo=timezone.utc)
age_days = (datetime.now(timezone.utc) - created).days
if age_days >= SPK_ROTATION_DAYS:
need_new_spk = True
self._logger.info("SPK is %d days old, rotating...", age_days)
except Exception:
need_new_spk = True
else:
# No SPK on server for this device — must upload one
need_new_spk = True
self._logger.info("No SPK on server for this device, uploading...")
if count < OPK_REPLENISH_THRESHOLD or need_new_spk:
if count >= OPK_REPLENISH_THRESHOLD:
self._logger.info("SPK rotation triggered (OPK count OK: %d)", count)
else:
self._logger.info("OPK count low (%d), replenishing...", count)
await self._generate_and_upload_prekeys_batch(need_new_spk)
async def _generate_and_upload_prekeys_batch(self, need_new_spk: bool = False):
"""Generate and upload prekeys in a single round-trip via ensure_prekeys."""
if not self.identity_private:
return
kwargs: dict = {}
# SPK
if need_new_spk:
if self.spk_private and self.spk_id:
self._prev_spk_private = self.spk_private
self._prev_spk_id = self.spk_id
_save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key)
spk = generate_signed_prekey(self.identity_private)
self.spk_private = spk["private"]
self.spk_id = spk["id"]
_save_spk(self.email, spk["private"], spk["id"], self._local_key)
kwargs["signed_prekey"] = {
"id": spk["id"],
"public_key": encode_binary(serialize_x25519_public(spk["public"])),
"signature": encode_binary(spk["signature"]),
}
# OPKs
opks = generate_one_time_prekeys(OPK_BATCH_SIZE)
for opk in opks:
self.opk_privates[opk["id"]] = opk["private"]
_save_opk_private(self.email, opk["id"], opk["private"], self._local_key)
kwargs["one_time_prekeys"] = [
{"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))}
for opk in opks
]
resp = await self.send_and_recv("ensure_prekeys", **kwargs)
if resp["status"] == "ok":
data = resp.get("data", {})
self._logger.info("ensure_prekeys: count=%d, spk_uploaded=%s, otps_uploaded=%d",
data.get("count", 0), data.get("uploaded_spk", False),
data.get("uploaded_otps", 0))
# ------------------------------------------------------------------
# Login
# ------------------------------------------------------------------
async def login(self, email: str, password: str) -> tuple[bool, str]:
"""Login user. Returns (success, message)."""
self.email = email
# Brute-force lockout check
remaining = _check_lockout(email)
if remaining > 0:
return False, f"Too many failed attempts. Try again in {remaining:.0f}s."
pwd_bytes = bytearray(password.encode("utf-8")) if password else None
try:
# Load RSA keys
priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
if priv is None:
if err and "password" in err.lower():
_record_failed_attempt(email)
return False, err or "No local keys found. Register first."
self.private_key = priv
self.public_key = pub
# Load identity keys
ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
finally:
if pwd_bytes:
pwd_bytes[:] = b'\x00' * len(pwd_bytes)
if ed_priv is not None:
self.identity_private = ed_priv
self.identity_public = ed_pub
self._cache_key = derive_self_encryption_key(ed_priv)
self._local_key = derive_local_storage_key(ed_priv)
self._load_verification_stores()
# Load SPK
spk_priv, spk_id = _load_spk(email, self._local_key)
if spk_priv:
self.spk_private = spk_priv
self.spk_id = spk_id
# Load previous SPK for grace period (M4)
prev_spk_priv, prev_spk_id = _load_prev_spk(email, self._local_key)
if prev_spk_priv:
self._prev_spk_private = prev_spk_priv
self._prev_spk_id = prev_spk_id
# Load device_id from disk
self.device_id = _load_device_id(email)
# RSA challenge-response login
start = await self.send_and_recv("login_start", email=email)
if start["status"] != "ok":
return False, start["data"]["message"]
challenge = decode_binary(start["data"]["challenge"])
signature = rsa_sign(self.private_key, challenge)
login_kwargs = {"email": email, "signature": encode_binary(signature),
"client_version": VERSION}
if self.device_id:
login_kwargs["device_id"] = self.device_id
finish = await self.send_and_recv("login_finish", **login_kwargs)
if finish["status"] == "ok":
self.session = finish["data"]
self.username = self.session.get("username", "")
# Store device_id from server
self.device_id = finish["data"].get("device_id")
if self.device_id:
_save_device_id(email, self.device_id)
# Replenish prekeys in background — after pairing, the new device
# has no local OPK private keys so we must generate fresh ones
# (server-side OPKs have no matching private keys on this device).
# Use keep_spk=True to preserve the shared SPK so both devices
# can respond to X3DH.
opk_dir = get_key_dir(self.email) / "opk_private"
has_local_opks = opk_dir.exists() and any(opk_dir.iterdir())
if has_local_opks:
asyncio.create_task(self._ensure_prekeys())
else:
self._logger.info("No local OPKs (likely new device). Generating fresh OPKs, keeping SPK.")
asyncio.create_task(self._generate_and_upload_prekeys(keep_spk=True))
_clear_lockout(email)
return True, f"Logged in as '{self.username}' (ID: {self.session['user_id']})"
return False, finish["data"]["message"]
# ------------------------------------------------------------------
# Pairing (device pairing — transfers RSA + identity keys)
# ------------------------------------------------------------------
async def pairing_start(self, email: str) -> tuple[bool, str]:
"""Start device pairing. Returns (success, code/message)."""
self._logger.info("pairing_start via %s for %s", self.server_endpoint(), email.strip().lower())
temp_priv, temp_pub = generate_x25519_keypair()
self._pairing_temp_private_key = temp_priv
temp_pub_raw = serialize_x25519_public(temp_pub)
self._pairing_fingerprint = compute_pairing_fingerprint(temp_pub_raw)
resp = await self.send_and_recv(
"pairing_start",
email=email,
temp_public_key=encode_binary(temp_pub_raw),
temp_key_type="x25519",
)
if resp["status"] == "ok":
self._pairing_code = resp["data"]["code"]
self._pairing_poll_token = resp["data"].get("poll_token", "")
return True, resp["data"]["code"]
self._pairing_fingerprint = ""
self._pairing_code = ""
return False, resp["data"]["message"]
async def pairing_wait(self, code: str, email: str, password: str, timeout: int = 300) -> tuple[bool, str]:
"""Wait for pairing payload and import keys. Returns (success, message)."""
if not self._pairing_temp_private_key:
return False, "Pairing not started."
from crypto_utils import aes_decrypt as _aes_decrypt
poll_token = getattr(self, "_pairing_poll_token", "")
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
resp = await self.send_and_recv("pairing_poll", code=code, poll_token=poll_token)
if resp["status"] != "ok":
self._pairing_fingerprint = ""
self._pairing_code = ""
return False, resp["data"]["message"]
if not resp["data"].get("ready"):
await asyncio.sleep(2.0)
continue
payload = resp["data"]["payload"]
try:
sender_pub_raw = decode_binary(payload["sender_public_key"])
sender_pub = load_x25519_public(sender_pub_raw)
my_pub_raw = serialize_x25519_public(self._pairing_temp_private_key.public_key())
shared_secret = x25519_dh(self._pairing_temp_private_key, sender_pub)
aes_key = derive_pairing_shared_key(shared_secret, my_pub_raw, sender_pub_raw)
nonce = decode_binary(payload["iv"])
ct = decode_binary(payload["ciphertext"])
tag = decode_binary(payload["tag"])
keys_json = _aes_decrypt(aes_key, nonce, ct, tag)
keys_data = json.loads(keys_json)
pwd_bytes = bytearray(password.encode("utf-8")) if password else None
try:
# Import RSA key
rsa_priv = load_private_key(keys_data["rsa_private"].encode(), password=None)
rsa_pub = rsa_priv.public_key()
save_keys(email, rsa_priv, rsa_pub, password=bytes(pwd_bytes) if pwd_bytes else None)
# Import identity keys
ed_priv = load_ed25519_private(bytes.fromhex(keys_data["identity_private"]))
ed_pub = ed_priv.public_key()
_save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None)
finally:
if pwd_bytes:
pwd_bytes[:] = b'\x00' * len(pwd_bytes)
self.email = email
self.private_key = rsa_priv
self.public_key = rsa_pub
self.identity_private = ed_priv
self.identity_public = ed_pub
self._cache_key = derive_self_encryption_key(ed_priv)
self._local_key = derive_local_storage_key(ed_priv)
self._load_verification_stores()
# Import verification state transferred from the authorizing
# device (optional — absent when paired from an older client)
vc_raw = keys_data.get("verified_contacts")
if vc_raw:
try:
self._verified_contacts = json.loads(vc_raw)
_save_verified_contacts(email, self._verified_contacts, self._local_key)
except (json.JSONDecodeError, TypeError):
pass
kik_raw = keys_data.get("known_identity_keys")
if kik_raw:
try:
self._known_identity_keys = json.loads(kik_raw)
_save_known_identity_keys(email, self._known_identity_keys, self._local_key)
except (json.JSONDecodeError, TypeError):
pass
self._pairing_temp_private_key = None
self._pairing_fingerprint = ""
self._pairing_code = ""
# Multi-device: new device generates own SPK + OPKs on first
# login. No session/sender key import needed — each device
# has independent Double Ratchet sessions.
return True, "Pairing complete."
except Exception as e:
self._pairing_fingerprint = ""
self._pairing_code = ""
return False, f"Failed to import keys: {e}"
self._pairing_fingerprint = ""
self._pairing_code = ""
return False, "Pairing timed out."
async def authorize_device(self, code: str, expected_fingerprint: str) -> tuple[bool, str]:
"""Authorize a new device by sending all keys to it."""
if not self.private_key or not self.identity_private:
return False, "Not logged in."
expected_digits = normalize_pairing_fingerprint(expected_fingerprint)
if len(expected_digits) != 30:
return False, "Pairing fingerprint must contain 30 digits."
claim = await self.send_and_recv("pairing_claim", code=code)
if claim["status"] != "ok":
return False, claim["data"]["message"]
if claim["data"].get("temp_key_type") != "x25519":
return False, "Unsupported pairing key type. Update both devices and try again."
temp_pub_raw = decode_binary(claim["data"]["temp_public_key"])
actual_fp = compute_pairing_fingerprint(temp_pub_raw)
if normalize_pairing_fingerprint(actual_fp) != expected_digits:
self._logger.warning("Pairing fingerprint mismatch for code %s", code[:8])
return False, (
"Pairing fingerprint mismatch. Verify the new device fingerprint and try again.\n"
f"Expected: {expected_fingerprint}\n"
f"Received: {actual_fp}"
)
temp_pub = load_x25519_public(temp_pub_raw)
# Build keys payload — only RSA + identity key.
# Multi-device: new device generates own SPK + OPKs, creates independent
# sessions. No session/sender key transfer needed.
keys_data = {
"rsa_private": serialize_private_key(self.private_key, password=None).decode(),
"identity_private": serialize_ed25519_private_raw(self.identity_private).hex(),
}
# Carry the TOFU registry + manual verifications so a contact verified on
# this device stays verified on the new one (these stores are local and
# would otherwise start empty). Receivers ignore unknown fields.
if self._verified_contacts:
keys_data["verified_contacts"] = json.dumps(self._verified_contacts)
if self._known_identity_keys:
keys_data["known_identity_keys"] = json.dumps(self._known_identity_keys)
# Send keys to the new device first. Re-encrypting history can take a
# while on large accounts; doing it before pairing_send can make a valid
# code expire during authorization.
plaintext = json.dumps(keys_data).encode()
sender_priv, sender_pub = generate_x25519_keypair()
sender_pub_raw = serialize_x25519_public(sender_pub)
shared_secret = x25519_dh(sender_priv, temp_pub)
pairing_key = derive_pairing_shared_key(shared_secret, sender_pub_raw, temp_pub_raw)
_, nonce, ct, tag = aes_encrypt(plaintext, key=pairing_key)
payload = {
"sender_public_key": encode_binary(sender_pub_raw),
"iv": encode_binary(nonce),
"ciphertext": encode_binary(ct),
"tag": encode_binary(tag),
}
resp = await self.send_and_recv("pairing_send", code=code, payload=payload)
if resp["status"] == "ok":
async def _reencrypt_after_pairing():
try:
delay = random.uniform(*PAIRING_REENCRYPT_INITIAL_DELAY_RANGE)
self._logger.info("Delaying post-pairing history resync by %.1fs", delay)
await asyncio.sleep(delay)
await self.reencrypt_history()
except Exception as e:
self._logger.warning("Post-pairing re-encryption failed: %s", e)
asyncio.create_task(_reencrypt_after_pairing())
return True, "Device authorized."
return False, resp["data"]["message"]
# ------------------------------------------------------------------
# Password change (local key re-encryption only)
# ------------------------------------------------------------------
def change_password(self, old_password: str, new_password: str) -> tuple[bool, str]:
"""Change password for local key encryption (RSA + identity key).
Returns (success, message).
"""
if not self.email:
return False, "Not logged in."
old_pwd = bytearray(old_password.encode("utf-8"))
new_pwd = bytearray(new_password.encode("utf-8"))
try:
# 1. Verify old password by loading keys
priv, pub, err = load_keys(self.email, password=bytes(old_pwd))
if priv is None:
return False, "Wrong current password."
ed_priv, ed_pub = _load_identity_keys(self.email, password=bytes(old_pwd))
if ed_priv is None:
return False, "Failed to load identity key."
# 2. Re-save with new password
save_keys(self.email, priv, pub, password=bytes(new_pwd))
_save_identity_keys(self.email, ed_priv, ed_pub, password=bytes(new_pwd))
return True, "Password changed successfully."
finally:
old_pwd[:] = b'\x00' * len(old_pwd)
new_pwd[:] = b'\x00' * len(new_pwd)
async def change_username(self, new_username: str) -> tuple[bool, str]:
"""Change display name on server."""
if not self.session:
return False, "Not logged in."
new_username = new_username.strip()
if not new_username or len(new_username) > 100:
return False, "Username must be 1-100 characters."
resp = await self.send_and_recv("change_username", username=new_username)
if resp["status"] == "ok":
self.username = resp["data"]["username"]
if self.session:
self.session["username"] = self.username
return True, "Username changed."
return False, resp["data"].get("message", "Unknown error")
# ------------------------------------------------------------------
# Key rotation (RSA login key only)
# ------------------------------------------------------------------
async def rotate_keys(self, username: str, password: str) -> tuple[bool, str]:
"""Rotate RSA keypair to revoke other devices."""
if not self.session or self.session.get("username") != username:
return False, "Not logged in."
pwd_bytes = password.encode("utf-8") if password else None
priv, pub = generate_rsa_keypair()
pub_pem = serialize_public_key(pub).decode("utf-8")
# Persist the new key only after the server accepted it — overwriting
# private.pem first would brick the account if rotation fails.
resp = await self.send_and_recv("rotate_keys", public_key=pub_pem)
if resp["status"] == "ok":
save_keys(self.email, priv, pub, password=pwd_bytes)
self.private_key = priv
self.public_key = pub
return True, "RSA login keys rotated."
return False, resp["data"]["message"]
# ------------------------------------------------------------------
# Session management (X3DH + Double Ratchet)
# ------------------------------------------------------------------
async def _get_device_bundles(self, peer_user_id: str) -> list[dict]:
"""Get per-device key bundles for a peer. Caches for 5 minutes."""
import time
cached = self._device_bundle_cache.get(peer_user_id)
if cached:
ts, bundles = cached
if time.time() - ts < 300:
return bundles
resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id)
if resp["status"] != "ok":
raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}")
data = resp["data"]
ik_b64 = data.get("identity_key", "")
device_bundles = data.get("device_bundles")
if device_bundles:
# Attach identity_key to each bundle
for b in device_bundles:
b["identity_key"] = ik_b64
else:
# Old server: wrap flat response as single-entry list
device_bundles = [{
"device_id": None,
"identity_key": ik_b64,
"signed_prekey_id": data.get("signed_prekey_id", ""),
"signed_prekey": data.get("signed_prekey", ""),
"spk_signature": data.get("spk_signature", ""),
"one_time_prekey_id": data.get("one_time_prekey_id"),
"one_time_prekey": data.get("one_time_prekey"),
}]
self._device_bundle_cache[peer_user_id] = (time.time(), device_bundles)
return device_bundles
async def _get_or_create_session(self, peer_user_id: str,
peer_device_id: str | None = None,
bundle: dict | None = None) -> DoubleRatchet:
"""Load existing session or create one via X3DH.
If peer_device_id is set, sessions are keyed by "user_id:device_id".
If bundle is provided, it's used instead of fetching from server.
"""
session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id
# Check in-memory cache
if session_key in self.sessions:
return self.sessions[session_key]
# Check on disk
ratchet = _load_session(self.email, peer_user_id, self._local_key,
peer_device_id=peer_device_id)
if ratchet:
self.sessions[session_key] = ratchet
return ratchet
# Create new session via X3DH
if not bundle:
resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id)
if resp["status"] != "ok":
raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}")
bundle = resp["data"]
ik_remote_bytes = decode_binary(bundle["identity_key"])
ik_remote = load_ed25519_public(ik_remote_bytes)
# TOFU: verify identity key before using it in X3DH
ik_status = self.check_identity_key(peer_user_id, ik_remote_bytes)
if ik_status in ("changed", "changed_verified"):
raise IdentityKeyChanged(peer_user_id, ik_remote_bytes, ik_status)
spk_remote = load_x25519_public(decode_binary(bundle["signed_prekey"]))
spk_sig = decode_binary(bundle["spk_signature"])
opk_remote = None
opk_id = bundle.get("one_time_prekey_id")
if bundle.get("one_time_prekey"):
opk_remote = load_x25519_public(decode_binary(bundle["one_time_prekey"]))
# Perform X3DH
shared_secret, ek_priv, ek_pub = x3dh_initiate(
self.identity_private,
ik_remote,
spk_remote,
spk_sig,
opk_remote,
)
# Initialize Double Ratchet as Alice
ratchet = DoubleRatchet.init_alice(shared_secret, spk_remote)
self.sessions[session_key] = ratchet
_save_session(self.email, peer_user_id, ratchet, self._local_key,
peer_device_id=peer_device_id)
# Build X3DH header for first message
x3dh_header = {
"ik": encode_binary(serialize_ed25519_public(self.identity_public)),
"ek": encode_binary(serialize_x25519_public(ek_pub)),
}
if opk_id:
x3dh_header["opk_id"] = opk_id
# Cache the x3dh header for the next send_message call
ratchet._x3dh_header = x3dh_header
# Cache remote user info
self._user_cache[peer_user_id] = {
"user_id": peer_user_id,
"identity_key": ik_remote,
"identity_key_bytes": ik_remote_bytes,
"identity_key_status": ik_status,
}
return ratchet
def _process_x3dh_header(self, sender_id: str, x3dh_header: dict,
sender_device_id: str | None = None,
spk_override=None) -> DoubleRatchet:
"""Process an incoming X3DH header to establish session as Bob.
Args:
spk_override: If provided, use this SPK private key instead of self.spk_private.
Used for grace period fallback (M4).
"""
ik_remote_bytes = decode_binary(x3dh_header["ik"])
ik_remote = load_ed25519_public(ik_remote_bytes)
# TOFU: verify identity key before using it in X3DH
ik_status = self.check_identity_key(sender_id, ik_remote_bytes)
if ik_status in ("changed", "changed_verified"):
raise IdentityKeyChanged(sender_id, ik_remote_bytes, ik_status)
ek_remote = load_x25519_public(decode_binary(x3dh_header["ek"]))
opk_id = x3dh_header.get("opk_id")
opk_priv = None
if opk_id:
opk_priv = _load_opk_private(self.email, opk_id, self._local_key)
# Deletion is deferred until the first message decrypts successfully
# (_consume_pending_opk). Deleting here would break the SPK
# grace-period retry: the second _process_x3dh_header call could no
# longer load the OPK and the message would be lost permanently.
spk_priv = spk_override if spk_override else self.spk_private
shared_secret = x3dh_respond(
self.identity_private,
spk_priv,
ik_remote,
ek_remote,
opk_priv,
)
spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None
ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub))
ratchet._pending_opk_delete = opk_id if opk_priv else None
# NOTE: the ratchet is intentionally NOT installed into self.sessions
# nor saved to disk here. The caller does that only after the first
# message decrypts successfully — otherwise a failed/forged X3DH
# header would overwrite a working session.
self._user_cache[sender_id] = {
"user_id": sender_id,
"identity_key": ik_remote,
"identity_key_bytes": ik_remote_bytes,
"identity_key_status": ik_status,
}
return ratchet
def _consume_pending_opk(self, ratchet) -> None:
"""Delete the one-time prekey consumed by an X3DH handshake.
Called only after the first message decrypted successfully, so a failed
attempt (e.g. wrong SPK during the grace period) can still retry with
the same OPK.
"""
opk_id = getattr(ratchet, "_pending_opk_delete", None)
if opk_id:
_delete_opk_private(self.email, opk_id)
ratchet._pending_opk_delete = None
# ------------------------------------------------------------------
# Conversations
# ------------------------------------------------------------------
async def create_conversation(self, member_emails: list[str], name: str | None = None) -> tuple[str | None, str]:
kwargs = {"members": member_emails}
if name:
kwargs["name"] = name
resp = await self.send_and_recv("create_conversation", **kwargs)
if resp["status"] == "ok":
return resp["data"]["conversation_id"], "OK"
return None, resp["data"]["message"]
async def remove_member(self, conv_id: str, user_id: str) -> tuple[bool, str]:
resp = await self.send_and_recv("remove_member", conversation_id=conv_id, user_id=user_id)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def leave_group(self, conv_id: str) -> tuple[bool, str]:
"""Leave a group conversation."""
resp = await self.send_and_recv("leave_group", conversation_id=conv_id)
if resp["status"] == "ok":
# Clean up local sender key state for this group
self.sender_key_states.pop(conv_id, None)
# Remove received sender keys for this conversation
to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")]
for k in to_remove:
self.recv_sender_keys.pop(k, None)
return True, "OK"
return False, resp["data"]["message"]
async def rename_conversation(self, conv_id: str, name: str) -> tuple[bool, str]:
"""Rename a group conversation (creator only)."""
resp = await self.send_and_recv("rename_conversation", conversation_id=conv_id, name=name)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def delete_conversation(self, conv_id: str) -> tuple[bool, str]:
"""Delete a conversation (leave + server cleans up if empty)."""
resp = await self.send_and_recv("delete_conversation", conversation_id=conv_id)
if resp["status"] == "ok":
# Clean up local sender key state
self.sender_key_states.pop(conv_id, None)
to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")]
for k in to_remove:
self.recv_sender_keys.pop(k, None)
return True, "OK"
return False, resp["data"]["message"]
async def add_member(self, conv_id: str, email: str) -> tuple[bool, str]:
resp = await self.send_and_recv("add_member", conversation_id=conv_id, email=email)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def accept_invitation(self, conv_id: str) -> tuple[bool, str]:
"""Accept a group invitation."""
resp = await self.send_and_recv("accept_invitation", conversation_id=conv_id)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def decline_invitation(self, conv_id: str) -> tuple[bool, str]:
"""Decline a group invitation."""
resp = await self.send_and_recv("decline_invitation", conversation_id=conv_id)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def list_invitations(self) -> list[dict]:
"""List pending group invitations."""
resp = await self.send_and_recv("list_invitations")
if resp["status"] == "ok":
return resp["data"]["invitations"]
return []
async def list_conversations(self) -> list[dict]:
resp = await self.send_and_recv("list_conversations")
if resp["status"] == "ok":
return resp["data"]["conversations"]
return []
async def find_or_create_conversation(self, email: str) -> tuple[str | None, str]:
resp = await self.send_and_recv("find_conversation", email=email)
if resp["status"] != "ok":
return None, resp["data"]["message"]
conv_id = resp["data"]["conversation_id"]
if conv_id:
return conv_id, "OK"
return await self.create_conversation([email])
# ------------------------------------------------------------------
# Send message
# ------------------------------------------------------------------
def _cancel_typing_timer(self, conv_id: str):
task = self._typing_stop_tasks.pop(conv_id, None)
if task and not task.done():
task.cancel()
async def _typing_stop_after_delay(self, conv_id: str, delay: float):
try:
await asyncio.sleep(delay)
await self.typing_stop(conv_id)
except asyncio.CancelledError:
return
async def typing_start(self, conv_id: str):
"""Debounced typing_start with 3s inactivity timeout."""
if not conv_id or not self.session:
return
now = time.monotonic()
was_active = self._typing_active.get(conv_id, False)
last_sent = self._typing_last_sent.get(conv_id, 0.0)
should_send = (not was_active) or (now - last_sent >= 1.0)
self._typing_active[conv_id] = True
self._cancel_typing_timer(conv_id)
self._typing_stop_tasks[conv_id] = asyncio.create_task(
self._typing_stop_after_delay(conv_id, 3.0)
)
if not should_send:
return
self._typing_last_sent[conv_id] = now
try:
await self.send_and_recv("typing_start", timeout=5.0, conversation_id=conv_id)
except Exception:
pass
async def typing_stop(self, conv_id: str, force: bool = False):
if not conv_id or not self.session:
return
self._cancel_typing_timer(conv_id)
was_active = self._typing_active.get(conv_id, False)
self._typing_active[conv_id] = False
if not was_active and not force:
return
try:
await self.send_and_recv("typing_stop", timeout=5.0, conversation_id=conv_id)
except Exception:
pass
def _is_group(self, members: list[dict]) -> bool:
return len(members) > 2
async def send_message(self, conv_id: str, text: str, members: list[dict],
reply_to: str | None = None) -> tuple[bool, str | dict]:
"""Encrypt and send a message. DM: per-recipient Double Ratchet. Group: Sender Keys.
Returns (True, msg_dict) on success or (False, error_string) on failure.
msg_dict contains the full decrypted payload ready for display.
"""
my_user_id = self.session["user_id"]
await self.typing_stop(conv_id, force=True)
# Build plaintext payload
payload = {
"sender": self.username,
"text": text,
"reply_to": reply_to,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
if self._is_group(members):
return await self._send_group_message(conv_id, plaintext, members, payload)
else:
return await self._send_dm(conv_id, plaintext, members, payload)
async def _send_dm(self, conv_id: str, plaintext: bytes, members: list[dict],
payload: dict | None = None) -> tuple[bool, str | dict]:
"""Encrypt DM with per-device Double Ratchet."""
my_user_id = self.session["user_id"]
recipients = []
first_ratchet_header = None
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
# Get all device bundles for this user
try:
device_bundles = await self._get_device_bundles(uid)
self._logger.debug("Got %d device bundles for %s", len(device_bundles), uid)
except Exception as e:
self._logger.warning("Failed to get device bundles for %s: %s", uid, e)
device_bundles = []
if not device_bundles:
# Fallback: try single session (legacy peer)
ratchet = await self._get_or_create_session(uid)
result = ratchet.encrypt(plaintext)
x3dh_hdr = getattr(ratchet, "_x3dh_header", None)
if x3dh_hdr:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if x3dh_hdr:
entry["x3dh_header"] = x3dh_hdr
recipients.append(entry)
if first_ratchet_header is None:
first_ratchet_header = result["header"]
_save_session(self.email, uid, ratchet, self._local_key)
continue
for bundle in device_bundles:
dev_id = bundle.get("device_id")
ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
bundle=bundle)
result = ratchet.encrypt(plaintext)
x3dh_hdr = getattr(ratchet, "_x3dh_header", None)
if x3dh_hdr:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if dev_id:
entry["device_id"] = dev_id
if x3dh_hdr:
entry["x3dh_header"] = x3dh_hdr
recipients.append(entry)
if first_ratchet_header is None:
first_ratchet_header = result["header"]
_save_session(self.email, uid, ratchet, self._local_key,
peer_device_id=dev_id)
# Encrypt self-copy with static key derived from identity (not ratchet)
# Uses SELF_DEVICE_ID so all own devices can read it
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
if not recipients:
return False, "No recipients."
kwargs = {
"conversation_id": conv_id,
"ratchet_header": first_ratchet_header,
"recipients": recipients,
}
resp = await self.send_and_recv("send_message", **kwargs)
if resp["status"] == "ok":
msg_data = resp.get("data", {})
if payload is not None:
result = {
**payload,
"message_id": msg_data.get("message_id", ""),
"created_at": msg_data.get("created_at", ""),
"sender_id": self.session["user_id"],
"conversation_id": conv_id,
"read_by": [],
}
_save_message_to_cache(self.email, conv_id, result["message_id"], result, self._cache_key)
return True, result
return True, "Message sent."
return False, resp["data"]["message"]
async def _send_group_message(self, conv_id: str, plaintext: bytes,
members: list[dict],
payload: dict | None = None) -> tuple[bool, str | dict]:
"""Encrypt group message with Sender Keys."""
my_user_id = self.session["user_id"]
# Get or create sender key for this group
sk = self.sender_key_states.get(conv_id)
if not sk:
sk = _load_sender_key_state(self.email, conv_id, self._local_key)
if not sk:
sk = SenderKeyState()
self.sender_key_states[conv_id] = sk
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
self.sender_key_states[conv_id] = sk
await self._catch_up_sender_key_distribution(conv_id, members, sk)
# Encrypt with sender key
result = sk.encrypt(plaintext)
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
# Build per-recipient entries (same ciphertext for all except self)
recipients = []
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
recipients.append({
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
})
# Self-encrypted copy (so other devices + history fetch can decrypt)
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
ratchet_header = {"dh_pub": "00" * 32, "n": 0, "pn": 0} # Dummy for groups
kwargs = {
"conversation_id": conv_id,
"ratchet_header": ratchet_header,
"recipients": recipients,
"sender_chain_id": encode_binary(bytes.fromhex(result["chain_id"])),
"sender_chain_n": result["n"],
}
resp = await self.send_and_recv("send_message", **kwargs)
if resp["status"] == "ok":
msg_data = resp.get("data", {})
if payload is not None:
result_msg = {
**payload,
"message_id": msg_data.get("message_id", ""),
"created_at": msg_data.get("created_at", ""),
"sender_id": self.session["user_id"],
"conversation_id": conv_id,
"read_by": [],
}
_save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key)
return True, result_msg
return True, "Message sent."
return False, resp["data"]["message"]
async def _catch_up_sender_key_distribution(self, conv_id: str, members: list[dict],
sk: SenderKeyState):
"""Ensure all current members have our existing sender key."""
my_user_id = self.session["user_id"]
current_member_ids = {
uid for uid in (member.get("user_id") for member in members)
if uid and uid != my_user_id
}
if not current_member_ids:
return
distributed_to = _load_sender_key_recipients(self.email, conv_id, self._local_key)
missing_ids = sorted(current_member_ids - distributed_to)
if not missing_ids:
return
distributed_now = await self._distribute_sender_key(
conv_id,
[{"user_id": uid} for uid in missing_ids],
sk,
)
if distributed_now:
distributed_to.update(distributed_now)
_save_sender_key_recipients(self.email, conv_id, distributed_to, self._local_key)
async def _distribute_sender_key(self, conv_id: str, members: list[dict],
sk: SenderKeyState) -> set[str]:
"""Send own sender key to all group members via pairwise Double Ratchet (per-device)."""
my_user_id = self.session["user_id"]
distributed_to: set[str] = set()
exported_key = sk.export_key()
# Build a special "sender_key_distribution" payload
payload = {
"sender": self.username,
"text": "",
"reply_to": None,
"timestamp": datetime.now(timezone.utc).isoformat(),
"_sender_key": {
"conv_id": conv_id,
"key": encode_binary(exported_key),
"sender_device_id": self.device_id,
},
}
plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
# Send as DM to each member's devices (per-device encryption)
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
try:
# Get all device bundles for this user
try:
device_bundles = await self._get_device_bundles(uid)
except Exception:
device_bundles = []
if not device_bundles:
# Fallback: legacy single-device
ratchet = await self._get_or_create_session(uid)
result = ratchet.encrypt(plaintext)
x3dh_header = getattr(ratchet, "_x3dh_header", None)
if x3dh_header:
delattr(ratchet, "_x3dh_header")
recipient_entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if x3dh_header:
recipient_entry["x3dh_header"] = x3dh_header
kwargs = {
"conversation_id": conv_id,
"ratchet_header": result["header"],
"recipients": [recipient_entry],
}
await self.send_and_recv("send_message", **kwargs)
_save_session(self.email, uid, ratchet, self._local_key)
distributed_to.add(uid)
else:
# Per-device encryption
recipients = []
first_rh = None
for bundle in device_bundles:
dev_id = bundle.get("device_id")
ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
bundle=bundle)
result = ratchet.encrypt(plaintext)
x3dh_header = getattr(ratchet, "_x3dh_header", None)
if x3dh_header:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if dev_id:
entry["device_id"] = dev_id
if x3dh_header:
entry["x3dh_header"] = x3dh_header
recipients.append(entry)
if first_rh is None:
first_rh = result["header"]
_save_session(self.email, uid, ratchet, self._local_key,
peer_device_id=dev_id)
kwargs = {
"conversation_id": conv_id,
"ratchet_header": first_rh,
"recipients": recipients,
}
await self.send_and_recv("send_message", **kwargs)
distributed_to.add(uid)
except Exception as e:
self._logger.warning("Failed to distribute sender key to %s: %s", uid, e)
return distributed_to
async def redistribute_sender_key_to_member(self, conv_id: str, new_user_id: str):
"""Redistribute our existing sender key to a newly joined group member.
Called when we receive a member_added notification — the new member needs our
sender key so they can decrypt future messages we send in the group.
"""
if not self.session:
return
sk = self.sender_key_states.get(conv_id)
if sk is None:
sk = _load_sender_key_state(self.email, conv_id, self._local_key)
if sk is None:
# We haven't sent anything in this group yet — no key to redistribute
return
try:
distributed = await self._distribute_sender_key(conv_id, [{"user_id": new_user_id}], sk)
if new_user_id in distributed:
recipients = _load_sender_key_recipients(self.email, conv_id, self._local_key)
recipients.add(new_user_id)
_save_sender_key_recipients(self.email, conv_id, recipients, self._local_key)
self._logger.info("Redistributed sender key for conv=%s to new member %s",
conv_id[:8], new_user_id[:8])
except Exception as e:
self._logger.warning("Failed to redistribute sender key to %s: %s", new_user_id, e)
# ------------------------------------------------------------------
# Decrypt messages
# ------------------------------------------------------------------
def _decrypt_message(self, msg_data: dict) -> dict:
"""Decrypt a single message (DM or group)."""
# Check for self-encrypted marker FIRST — after re-encryption,
# group messages will have {"self": true} ratchet_header but still
# have sender_chain_id at message level.
rh = msg_data.get("ratchet_header", {})
if isinstance(rh, dict) and rh.get("self"):
return self._decrypt_dm(msg_data)
if msg_data.get("sender_chain_id"):
return self._decrypt_group(msg_data)
else:
return self._decrypt_dm(msg_data)
def _decrypt_dm(self, msg_data: dict) -> dict:
"""Decrypt DM using Double Ratchet with sender, or static key for self-copies."""
sender_id = msg_data.get("sender_id", "")
sender_device_id = msg_data.get("sender_device_id")
ratchet_header = msg_data.get("ratchet_header", {})
ct_b64 = msg_data.get("encrypted_content", "")
nonce_b64 = msg_data.get("nonce", "")
if not ct_b64 or not nonce_b64:
raise ValueError("Missing ciphertext or nonce")
ciphertext = decode_binary(ct_b64)
nonce = decode_binary(nonce_b64)
# Self-encrypted message (own sent message copy)
if isinstance(ratchet_header, dict) and ratchet_header.get("self"):
self_key = derive_self_encryption_key(self.identity_private)
ct = ciphertext[:-16]
tag = ciphertext[-16:]
plaintext = aes_decrypt(self_key, nonce, ct, tag)
else:
x3dh_header = msg_data.get("x3dh_header")
# Session key: "sender_id:sender_device_id" or just "sender_id" for legacy
session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id
# Try to load existing session
ratchet = self.sessions.get(session_key)
if not ratchet:
ratchet = _load_session(self.email, sender_id, self._local_key,
peer_device_id=sender_device_id)
if ratchet:
self.sessions[session_key] = ratchet
if ratchet and not x3dh_header:
# Normal case: existing session, no X3DH header
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
elif x3dh_header:
if ratchet:
# Existing session + X3DH header: sender may have reset.
backup = ratchet.export_state()
try:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
except Exception:
# Restore the known-good session before attempting a
# fresh X3DH; if the X3DH path fails too, this restored
# session stays installed (in memory and on disk).
restored = DoubleRatchet.import_state(backup)
self.sessions[session_key] = restored
_save_session(self.email, sender_id, restored, self._local_key,
peer_device_id=sender_device_id)
ratchet = self._process_x3dh_header(sender_id, x3dh_header,
sender_device_id=sender_device_id)
try:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
except Exception:
if self._prev_spk_private:
ratchet = self._process_x3dh_header(
sender_id, x3dh_header,
sender_device_id=sender_device_id,
spk_override=self._prev_spk_private)
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
else:
raise
# First decrypt succeeded — only now adopt the new session
self.sessions[session_key] = ratchet
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
else:
ratchet = self._process_x3dh_header(sender_id, x3dh_header,
sender_device_id=sender_device_id)
try:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
except Exception:
if self._prev_spk_private:
ratchet = self._process_x3dh_header(
sender_id, x3dh_header,
sender_device_id=sender_device_id,
spk_override=self._prev_spk_private)
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
else:
raise
# First decrypt succeeded — install + persist the session
self.sessions[session_key] = ratchet
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
else:
raise ValueError(f"No session for sender {sender_id}")
self._consume_pending_opk(ratchet)
plaintext = unpad_plaintext(plaintext)
payload = json.loads(plaintext)
# Handle sender key distribution messages
if "_sender_key" in payload:
sk_data = payload["_sender_key"]
sk_conv_id = sk_data["conv_id"]
sk_key = decode_binary(sk_data["key"])
sk_sender_device_id = sk_data.get("sender_device_id")
recv_sk = SenderKeyState.from_key(sk_key)
if sk_sender_device_id:
cache_key = f"{sk_conv_id}:{sender_id}:{sk_sender_device_id}"
else:
cache_key = f"{sk_conv_id}:{sender_id}"
self.recv_sender_keys[cache_key] = recv_sk
_save_recv_sender_key(self.email, sk_conv_id, sender_id, recv_sk, self._local_key,
sender_device_id=sk_sender_device_id)
# Return empty — this is a control message, not user-visible
return None
return payload
def _decrypt_group(self, msg_data: dict) -> dict:
"""Decrypt group message using sender's Sender Key."""
sender_id = msg_data.get("sender_id", "")
sender_device_id = msg_data.get("sender_device_id")
conv_id = msg_data.get("conversation_id", "")
chain_id_b64 = msg_data.get("sender_chain_id", "")
chain_n = msg_data.get("sender_chain_n", 0)
ct_b64 = msg_data.get("encrypted_content", "")
nonce_b64 = msg_data.get("nonce", "")
if not ct_b64 or not nonce_b64 or not chain_id_b64:
raise ValueError("Missing group message fields")
ciphertext = decode_binary(ct_b64)
nonce = decode_binary(nonce_b64)
chain_id = decode_binary(chain_id_b64)
my_user_id = self.session["user_id"]
# If we sent this message, use our own sender key
if sender_id == my_user_id:
sk = self.sender_key_states.get(conv_id)
if not sk:
sk = _load_sender_key_state(self.email, conv_id, self._local_key)
if sk:
self.sender_key_states[conv_id] = sk
if not sk:
raise ValueError("Own sender key not found")
# For our own messages, we can't decrypt from sender key (it's already advanced)
# Return a placeholder — the server echoed our ciphertext
raise ValueError("Cannot decrypt own group message from sender key")
# Use received sender key — try with sender_device_id first, fall back to without
sk = None
if sender_device_id:
cache_key = f"{conv_id}:{sender_id}:{sender_device_id}"
sk = self.recv_sender_keys.get(cache_key)
if not sk:
sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key,
sender_device_id=sender_device_id)
if sk:
self.recv_sender_keys[cache_key] = sk
if not sk:
# Fallback: try without device_id (legacy or same-device)
cache_key = f"{conv_id}:{sender_id}"
sk = self.recv_sender_keys.get(cache_key)
if not sk:
sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key)
if sk:
self.recv_sender_keys[cache_key] = sk
if not sk:
raise ValueError(f"No sender key for {sender_id} in conversation {conv_id}")
plaintext = unpad_plaintext(sk.decrypt(chain_id.hex(), chain_n, ciphertext, nonce))
_save_recv_sender_key(self.email, conv_id, sender_id, sk, self._local_key,
sender_device_id=sender_device_id)
return json.loads(plaintext)
# ------------------------------------------------------------------
# Get/decrypt messages (batch)
# ------------------------------------------------------------------
async def get_messages(self, conv_id: str, limit: int = 50, offset: int = 0) -> list[dict]:
cache = _load_message_cache(self.email, conv_id, self._cache_key)
my_user_id = self.session["user_id"] if self.session else ""
# Incremental sync: use stored server timestamp from last successful fetch.
after_ts = None
if cache and offset == 0:
after_ts = cache.get("__last_server_ts", {}).get("ts")
req_params = {"conversation_id": conv_id, "limit": limit, "offset": offset}
if after_ts:
req_params["after_ts"] = after_ts
resp = await self.send_and_recv("get_messages", **req_params)
if resp["status"] != "ok":
# Offline fallback: return from cache if available
if cache and offset == 0:
return self._build_from_cache(cache)
return []
raw_messages = resp["data"]["messages"]
raw_messages.reverse() # Server returns DESC, reverse to ASC
# Decrypt new messages from server
new_decrypted = self._decrypt_raw_messages(raw_messages, cache, conv_id, my_user_id)
# Advance the incremental-sync watermark only across the prefix of
# messages that are settled in the cache (decrypted, control, deleted,
# or failed too many times). Stopping at the first unsettled message
# means a transiently undecryptable message (e.g. sender key not yet
# received) is re-fetched and retried on the next sync instead of
# being skipped forever.
if raw_messages and offset == 0:
newest_ts = ""
for m in raw_messages:
entry = cache.get(m["message_id"])
if entry is None:
break
fails = entry.get("_decrypt_failed", 0)
if fails and fails < _MAX_DECRYPT_RETRIES:
break
newest_ts = m.get("created_at", "") or newest_ts
prev_ts = cache.get("__last_server_ts", {}).get("ts", "")
if newest_ts and newest_ts > prev_ts:
cache["__last_server_ts"] = {"ts": newest_ts}
_save_message_to_cache(self.email, conv_id, "__last_server_ts",
{"ts": newest_ts}, cache_key=self._cache_key)
# All non-critical ops fire-and-forget to avoid blocking message display
# Confirm delivery for messages from others
deliver_ids = [m["message_id"] for m in new_decrypted
if m.get("sender_id") and m["sender_id"] != my_user_id
and not m.get("deleted")]
if deliver_ids:
asyncio.ensure_future(self.confirm_delivery(conv_id, deliver_ids))
# Mark entire conversation as read (fire-and-forget)
asyncio.ensure_future(self.mark_conversation_read(conv_id))
# Flush self-encryption queue in background
if self._pending_self_encrypt:
asyncio.ensure_future(self._flush_self_encrypt())
if after_ts:
# Incremental: sync deletions in background, build from cache NOW
asyncio.ensure_future(self._sync_deletions(conv_id, after_ts))
return self._build_from_cache(cache)
return new_decrypted
async def _sync_deletions(self, conv_id: str, after_ts: str):
"""Sync message deletions from server (background, non-blocking)."""
try:
del_resp = await self.send_and_recv("get_deleted_since",
conversation_id=conv_id, since_ts=after_ts)
if del_resp.get("status") == "ok":
for del_id in del_resp.get("data", {}).get("deleted_ids", []):
_save_message_to_cache(self.email, conv_id, del_id, {"deleted": True},
cache_key=self._cache_key)
except Exception:
pass
def get_cached_messages(self, conv_id: str) -> list[dict]:
"""Return messages from local disk cache only (no server call). Instant."""
if not self.email:
return []
cache = _load_message_cache(self.email, conv_id, self._cache_key)
if not cache:
return []
return self._build_from_cache(cache)
def _build_from_cache(self, cache: dict) -> list[dict]:
"""Build sorted message list from local cache (all messages)."""
messages = []
for msg_id, p in cache.items():
if p.get("_control") or msg_id.startswith("__"):
continue
if p.get("_decrypt_failed"):
messages.append({
"message_id": msg_id,
"sender": "???",
"text": "[Decryption failed]",
"created_at": p.get("created_at", ""),
"read_by": [],
"delivered_to": [],
})
continue
entry = dict(p)
entry.setdefault("message_id", msg_id)
entry.setdefault("read_by", [])
entry.setdefault("delivered_to", [])
messages.append(entry)
messages.sort(key=lambda m: m.get("created_at", ""))
return messages
def _decrypt_raw_messages(self, raw_messages: list, cache: dict,
conv_id: str, my_user_id: str) -> list[dict]:
"""Decrypt server messages, update cache. Returns list of decrypted dicts."""
decrypted = []
for m in raw_messages:
msg_id = m["message_id"]
if m.get("deleted_at"):
decrypted.append({
"message_id": msg_id,
"sender": "",
"text": "",
"created_at": m["created_at"],
"read_by": [],
"sender_id": m.get("sender_id", ""),
"deleted": True,
})
cache[msg_id] = {"deleted": True, "created_at": m["created_at"]}
continue
# Check local cache first (ratchet keys are one-time use)
cached = cache.get(msg_id)
if cached and cached.get("_decrypt_failed"):
if cached["_decrypt_failed"] >= _MAX_DECRYPT_RETRIES:
decrypted.append({
"message_id": msg_id,
"sender": "???",
"text": "[Decryption failed]",
"created_at": m["created_at"],
"read_by": [],
"sender_id": m.get("sender_id", ""),
})
continue
cached = None # retry decryption below
if cached and not cached.get("_control"):
cached["read_by"] = m.get("read_by", [])
cached["delivered_to"] = m.get("delivered_to", [])
cached["created_at"] = m["created_at"]
if m.get("reactions"):
cached["reactions"] = m["reactions"]
if m.get("pinned_at"):
cached["pinned_at"] = m["pinned_at"]
cached["pinned_by"] = m.get("pinned_by", "")
else:
cached.pop("pinned_at", None)
cached.pop("pinned_by", None)
decrypted.append(cached)
continue
if cached and cached.get("_control"):
continue
try:
msg_data = {
"sender_id": m.get("sender_id", ""),
"sender_device_id": m.get("sender_device_id"),
"conversation_id": conv_id,
"ratchet_header": m.get("ratchet_header", {}),
"encrypted_content": m.get("encrypted_content", ""),
"nonce": m.get("nonce", ""),
"x3dh_header": m.get("x3dh_header"),
"sender_chain_id": m.get("sender_chain_id"),
"sender_chain_n": m.get("sender_chain_n"),
}
payload = self._decrypt_message(msg_data)
if payload is None:
_save_message_to_cache(self.email, conv_id, msg_id, {"_control": True},
cache_key=self._cache_key)
cache[msg_id] = {"_control": True}
continue
payload["message_id"] = msg_id
payload["created_at"] = m["created_at"]
payload["read_by"] = m.get("read_by", [])
payload["delivered_to"] = m.get("delivered_to", [])
payload["sender_id"] = m.get("sender_id", "")
if m.get("reactions"):
payload["reactions"] = m["reactions"]
if m.get("pinned_at"):
payload["pinned_at"] = m["pinned_at"]
payload["pinned_by"] = m.get("pinned_by", "")
decrypted.append(payload)
_save_message_to_cache(self.email, conv_id, msg_id, payload,
cache_key=self._cache_key)
cache[msg_id] = payload
if m.get("sender_id", "") != my_user_id:
self._pending_self_encrypt.append({
"message_id": msg_id,
"payload": {k: v for k, v in payload.items()
if k not in ("message_id", "created_at", "read_by",
"delivered_to", "sender_id", "deleted")},
})
except Exception as e:
# Record the failure (with retry count) so the sync watermark
# stops here and the message is retried on the next fetch.
fails = (cache.get(msg_id) or {}).get("_decrypt_failed", 0) + 1
fail_entry = {"_decrypt_failed": fails, "created_at": m["created_at"]}
cache[msg_id] = fail_entry
_save_message_to_cache(self.email, conv_id, msg_id, fail_entry,
cache_key=self._cache_key)
decrypted.append({
"message_id": msg_id,
"sender": "???",
"text": f"[Decryption failed: {e}]",
"created_at": m["created_at"],
"read_by": [],
"sender_id": m.get("sender_id", ""),
})
return decrypted
async def _flush_self_encrypt(self):
"""Upload self-encrypted copies of received messages for multi-device access."""
if not self._pending_self_encrypt or not self.identity_private:
return
self_key = derive_self_encryption_key(self.identity_private)
updates = []
for item in list(self._pending_self_encrypt):
try:
plaintext = json.dumps(item["payload"], ensure_ascii=False).encode("utf-8")
_, nonce, ct, tag = aes_encrypt(plaintext, key=self_key)
updates.append({
"message_id": item["message_id"],
"encrypted_content": encode_binary(ct + tag),
"nonce": encode_binary(nonce),
})
except Exception:
pass
self._pending_self_encrypt.clear()
if updates:
try:
for i in range(0, len(updates), 500):
batch = updates[i:i + 500]
await self.send_and_recv("reencrypt_messages", updates=batch)
except Exception as e:
self._logger.warning("Failed to self-encrypt received messages: %s", e)
async def mark_read(self, conv_id: str, message_ids: list[str]):
if not message_ids:
return
await self.send_and_recv("mark_read", conversation_id=conv_id, message_ids=message_ids)
async def mark_conversation_read(self, conv_id: str):
"""Mark ALL unread messages in a conversation as read (server-side bulk)."""
try:
await self.send_and_recv("mark_conversation_read", conversation_id=conv_id)
except Exception:
pass # non-critical — don't fail message loading
async def confirm_delivery(self, conv_id: str, message_ids: list[str]):
"""Confirm delivery of messages (fire-and-forget, non-critical)."""
if not message_ids:
return
try:
await self.send_and_recv("confirm_delivery",
conversation_id=conv_id, message_ids=message_ids)
except Exception:
pass # non-critical
def search_messages(self, conv_id: str, query: str) -> list[dict]:
"""Search cached messages in a conversation. Returns matching messages."""
cache = _load_message_cache(self.email, conv_id, self._cache_key)
query_lower = query.lower()
results = []
for msg_id, payload in cache.items():
if payload.get("deleted") or payload.get("_control") or payload.get("_sender_key"):
continue
text = payload.get("text", "")
if query_lower in text.lower():
entry = dict(payload)
entry["message_id"] = msg_id
results.append(entry)
results.sort(key=lambda m: m.get("created_at", ""))
return results
async def reset_session(self, peer_user_id: str, peer_device_id: str | None = None):
"""Delete local session and notify peer to do the same."""
if peer_device_id:
session_key = f"{peer_user_id}:{peer_device_id}"
else:
session_key = peer_user_id
self.sessions.pop(session_key, None)
_delete_session_file(self.email, peer_user_id, peer_device_id)
await self.send_and_recv("session_reset",
peer_user_id=peer_user_id,
peer_device_id=peer_device_id or "")
def handle_session_reset_notification(self, from_user_id: str, from_device_id: str | None = None):
"""Handle incoming session reset notification — delete the matching session."""
if from_device_id:
session_key = f"{from_user_id}:{from_device_id}"
else:
session_key = from_user_id
self.sessions.pop(session_key, None)
_delete_session_file(self.email, from_user_id, from_device_id)
# ------------------------------------------------------------------
# Local message cache updates
# ------------------------------------------------------------------
def load_message_cache(self, conv_id: str) -> dict:
"""Load cached messages for a conversation. Returns {msg_id: payload}."""
if not self.email:
return {}
return _load_message_cache(self.email, conv_id, self._cache_key)
def update_message_in_cache(self, conv_id: str, message_id: str, updates: dict):
"""Update fields of a cached message on disk (synchronous)."""
if not self.email:
return
cache = _load_message_cache(self.email, conv_id, self._cache_key)
if message_id not in cache or cache[message_id].get("_control"):
return
for key, value in updates.items():
if value is None:
cache[message_id].pop(key, None)
else:
cache[message_id][key] = value
d = get_key_dir(self.email) / "message_cache"
if self._cache_key:
_save_message_cache_full(d, conv_id, cache, self._cache_key)
# ------------------------------------------------------------------
# Reactions, Pins, Forwarding
# ------------------------------------------------------------------
async def react_message(self, message_id: str, reaction: str, action: str = "add") -> tuple[bool, str]:
"""Add or remove a reaction on a message."""
resp = await self.send_and_recv("react_message",
message_id=message_id, reaction=reaction, action=action)
if resp["status"] == "ok":
return True, "OK"
return False, resp.get("data", {}).get("message", "Failed")
async def pin_message(self, message_id: str, conversation_id: str, action: str = "pin") -> tuple[bool, str]:
"""Pin or unpin a message."""
resp = await self.send_and_recv("pin_message",
message_id=message_id, conversation_id=conversation_id, action=action)
if resp["status"] == "ok":
return True, "OK"
return False, resp.get("data", {}).get("message", "Failed")
async def get_pinned_messages(self, conversation_id: str) -> list[dict]:
"""Get list of pinned messages for a conversation."""
resp = await self.send_and_recv("get_pinned_messages", conversation_id=conversation_id)
if resp["status"] == "ok":
return resp["data"].get("messages", [])
return []
async def forward_message(self, target_conv_id: str, original_msg: dict,
target_members: list[dict]) -> tuple[bool, str | dict]:
"""Forward a message to another conversation."""
text = original_msg.get("text", "")
payload = {
"sender": self.username,
"text": text,
"forwarded_from": {
"sender": original_msg.get("sender", ""),
"conversation_id": original_msg.get("conversation_id", ""),
"message_id": original_msg.get("message_id", ""),
},
"timestamp": datetime.now(timezone.utc).isoformat(),
}
# Forward image/file metadata (the encrypted blob is already on the server)
if original_msg.get("image"):
payload["image"] = original_msg["image"]
if not text:
payload["text"] = ""
if original_msg.get("file"):
payload["file"] = original_msg["file"]
if not text:
payload["text"] = ""
plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
if self._is_group(target_members):
return await self._send_group_message(target_conv_id, plaintext, target_members, payload)
else:
return await self._send_dm(target_conv_id, plaintext, target_members, payload)
# ------------------------------------------------------------------
# Decrypt notification
# ------------------------------------------------------------------
def decrypt_notification(self, notif_data: dict) -> dict | None:
"""Decrypt a new_message notification. Returns parsed payload or None.
Supports new multi-device format (device_entries array) and legacy flat format.
"""
try:
conv_id = notif_data.get("conversation_id", "")
msg_id = notif_data.get("message_id", "")
sender_id = notif_data.get("sender_id", "")
sender_device_id = notif_data.get("sender_device_id")
my_user_id = self.session["user_id"] if self.session else ""
# Extract per-device encrypted content from device_entries or flat fields
encrypted_content = ""
nonce = ""
ratchet_header = {}
x3dh_header = None
device_entries = notif_data.get("device_entries")
if device_entries:
# Multi-device format: pick entry matching our device_id or SELF_DEVICE_ID
chosen = None
self_entry = None
for entry in device_entries:
eid = entry.get("device_id", "")
if eid == self.device_id:
chosen = entry
break
if eid == "00000000-0000-0000-0000-000000000000":
self_entry = entry
# If sender is us, prefer self-encrypted entry
if sender_id == my_user_id:
chosen = self_entry or chosen
elif not chosen:
chosen = self_entry
if not chosen:
self._logger.warning("No matching device_entry for device %s", self.device_id)
return None
encrypted_content = chosen.get("encrypted_content", "")
nonce = chosen.get("nonce", "")
ratchet_header = chosen.get("ratchet_header") or notif_data.get("ratchet_header", {})
x3dh_header = chosen.get("x3dh_header") or notif_data.get("x3dh_header")
else:
# Legacy flat format
encrypted_content = notif_data.get("encrypted_content", "")
nonce = notif_data.get("nonce", "")
ratchet_header = notif_data.get("ratchet_header", {})
x3dh_header = notif_data.get("x3dh_header")
msg_data = {
"sender_id": sender_id,
"sender_device_id": sender_device_id,
"conversation_id": conv_id,
"ratchet_header": ratchet_header,
"encrypted_content": encrypted_content,
"nonce": nonce,
"x3dh_header": x3dh_header,
"sender_chain_id": notif_data.get("sender_chain_id"),
"sender_chain_n": notif_data.get("sender_chain_n"),
}
payload = self._decrypt_message(msg_data)
if payload is None:
# Cache control message so get_messages skips it
if msg_id and conv_id:
_save_message_to_cache(self.email, conv_id, msg_id, {"_control": True},
cache_key=self._cache_key)
return None
payload["conversation_id"] = conv_id
payload["message_id"] = msg_id
payload["sender_id"] = sender_id
# Use server-compatible timestamp (no timezone suffix) for cache consistency
_ts = payload.get("timestamp", "")
if _ts:
# Strip timezone suffix (+00:00 or Z) to match server DATETIME format
_ts = _ts.replace("+00:00", "").replace("Z", "")
# Strip microseconds if present
if "." in _ts:
_ts = _ts[:_ts.index(".")]
payload["created_at"] = _ts
payload["read_by"] = []
payload["delivered_to"] = []
# Cache so get_messages doesn't re-decrypt (ratchet keys are one-time)
if msg_id and conv_id:
_save_message_to_cache(self.email, conv_id, msg_id, payload,
cache_key=self._cache_key)
# Queue self-encryption for received messages (multi-device access)
if sender_id != my_user_id and msg_id:
self._pending_self_encrypt.append({
"message_id": msg_id,
"payload": {k: v for k, v in payload.items()
if k not in ("conversation_id", "message_id", "created_at",
"read_by", "delivered_to", "sender_id", "deleted")},
})
return payload
except IdentityKeyChanged:
raise # Must propagate to caller for key-change UI
except Exception as e:
self._logger.warning("Failed to decrypt notification: %s", e)
return None
# ------------------------------------------------------------------
# Delete message
# ------------------------------------------------------------------
async def delete_message(self, message_id: str) -> tuple[bool, str]:
resp = await self.send_and_recv("delete_message", message_id=message_id)
if resp["status"] == "ok":
return True, "Message deleted."
return False, resp["data"]["message"]
# ------------------------------------------------------------------
# Image sharing
# ------------------------------------------------------------------
async def send_image(self, conv_id: str, image_path: str, members: list[dict],
reply_to: str | None = None) -> tuple[bool, str]:
"""Encrypt and upload an image, then send as a message."""
await self.typing_stop(conv_id, force=True)
try:
from PIL import Image
import io
except ImportError:
return False, "Pillow is required for image sharing. Install with: pip install Pillow"
path = Path(image_path)
if not path.exists():
return False, "File not found."
try:
img = Image.open(path)
img.load()
except Exception as e:
return False, f"Cannot open image: {e}"
# Prepare image bytes — offload CPU-heavy work to thread
def _prepare_image(img, path):
"""PIL resize + thumbnail + AES encrypt (runs in thread)."""
original_format = img.format or "JPEG"
if original_format.upper() not in ("JPEG", "PNG", "WEBP", "GIF", "BMP"):
original_format = "JPEG"
image_bytes = path.read_bytes()
# AES-GCM overhead: 16 bytes tag. Check raw size as proxy.
if MAX_IMAGE_BYTES > 0 and len(image_bytes) + 16 > MAX_IMAGE_BYTES:
if img.mode not in ("RGB", "L"):
img = img.convert("RGB")
for quality in (92, 85, 75, 60):
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=quality)
image_bytes = buf.getvalue()
if len(image_bytes) + 16 <= MAX_IMAGE_BYTES:
break
else:
for max_dim in (3840, 2560, 1920, 1280):
if max(img.size) > max_dim:
img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=75)
image_bytes = buf.getvalue()
if len(image_bytes) + 16 <= MAX_IMAGE_BYTES:
break
# Generate thumbnail
thumb = img.copy()
thumb.thumbnail((200, 200), Image.Resampling.LANCZOS)
if thumb.mode not in ("RGB", "L"):
thumb = thumb.convert("RGB")
thumb_buf = io.BytesIO()
thumb.save(thumb_buf, format="JPEG", quality=60)
thumbnail_b64 = encode_binary(thumb_buf.getvalue())
# Encrypt
img_aes_key, img_iv, img_ct, img_tag = aes_encrypt(image_bytes)
encrypted_image = img_ct + img_tag
return image_bytes, thumbnail_b64, img_aes_key, img_iv, encrypted_image
image_bytes, thumbnail_b64, img_aes_key, img_iv, encrypted_image = \
await asyncio.get_event_loop().run_in_executor(None, _prepare_image, img, path)
file_id = str(uuid.uuid4())
file_size = len(encrypted_image)
# Chunked upload
resp = await self.send_and_recv(
"upload_image_start",
conversation_id=conv_id,
file_id=file_id,
file_size=file_size,
)
if resp["status"] != "ok":
return False, resp["data"]["message"]
ok, err = await self._pipelined_upload(file_id, encrypted_image)
if not ok:
return False, err
resp = await self.send_and_recv("upload_image_end", file_id=file_id)
if resp["status"] != "ok":
return False, resp["data"]["message"]
# Cache decrypted image locally so sender never re-downloads
cache_path = self._media_cache_path(file_id)
if cache_path:
try:
cache_path.write_bytes(image_bytes)
os.chmod(cache_path, 0o600)
except OSError:
pass
# Build message payload with image info
image_info = {
"file_id": file_id,
"aes_key": encode_binary(img_aes_key),
"iv": encode_binary(img_iv),
"thumbnail": thumbnail_b64,
"filename": path.name,
"size": len(image_bytes),
}
payload = {
"sender": self.username,
"text": "",
"reply_to": reply_to,
"timestamp": datetime.now(timezone.utc).isoformat(),
"image": image_info,
}
plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
my_user_id = self.session["user_id"]
if self._is_group(members):
# Group image: use sender key
sk = self.sender_key_states.get(conv_id)
if not sk:
sk = _load_sender_key_state(self.email, conv_id, self._local_key)
if not sk:
sk = SenderKeyState()
self.sender_key_states[conv_id] = sk
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
self.sender_key_states[conv_id] = sk
await self._catch_up_sender_key_distribution(conv_id, members, sk)
result = sk.encrypt(plaintext)
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
recipients = []
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
recipients.append({
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
})
# Self-encrypted copy for sender
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
resp = await self.send_and_recv(
"send_message",
conversation_id=conv_id,
ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0},
recipients=recipients,
sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])),
sender_chain_n=result["n"],
image_file_id=file_id,
)
else:
# DM image: per-device ratchet (same pattern as _send_dm)
recipients = []
first_rh = None
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
try:
device_bundles = await self._get_device_bundles(uid)
except Exception:
device_bundles = []
if not device_bundles:
# Fallback: legacy single-device
ratchet = await self._get_or_create_session(uid)
result = ratchet.encrypt(plaintext)
x3dh_h = getattr(ratchet, "_x3dh_header", None)
if x3dh_h:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if x3dh_h:
entry["x3dh_header"] = x3dh_h
recipients.append(entry)
if first_rh is None:
first_rh = result["header"]
_save_session(self.email, uid, ratchet, self._local_key)
else:
for bundle in device_bundles:
dev_id = bundle.get("device_id")
ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
bundle=bundle)
result = ratchet.encrypt(plaintext)
x3dh_h = getattr(ratchet, "_x3dh_header", None)
if x3dh_h:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if dev_id:
entry["device_id"] = dev_id
if x3dh_h:
entry["x3dh_header"] = x3dh_h
recipients.append(entry)
if first_rh is None:
first_rh = result["header"]
_save_session(self.email, uid, ratchet, self._local_key,
peer_device_id=dev_id)
# Encrypt self-copy with static key
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
resp = await self.send_and_recv(
"send_message",
conversation_id=conv_id,
ratchet_header=first_rh,
recipients=recipients,
image_file_id=file_id,
)
if resp["status"] == "ok":
msg_data = resp.get("data", {})
result_msg = {
**payload,
"message_id": msg_data.get("message_id", ""),
"created_at": msg_data.get("created_at", ""),
"sender_id": self.session["user_id"],
"conversation_id": conv_id,
"read_by": [],
}
_save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key)
return True, result_msg
return False, resp["data"]["message"]
async def _pipelined_upload(self, file_id: str, encrypted_data: bytes) -> tuple[bool, str]:
"""Upload encrypted data in pipelined chunks (no per-chunk round-trip wait)."""
file_size = len(encrypted_data)
chunk_futures = []
upload_offset = 0
while upload_offset < file_size:
chunk = encrypted_data[upload_offset:upload_offset + IMAGE_CHUNK_SIZE]
request_id = str(uuid.uuid4())
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._pending[request_id] = fut
try:
await self.writer.send_request(
"upload_image_chunk",
request_id=request_id,
file_id=file_id,
data=encode_binary(chunk),
)
except Exception as e:
self._pending.pop(request_id, None)
return False, f"Upload failed: {e}"
chunk_futures.append((request_id, fut))
upload_offset += len(chunk)
for request_id, fut in chunk_futures:
try:
resp = await asyncio.wait_for(fut, timeout=30.0)
except (asyncio.TimeoutError, ConnectionError):
self._pending.pop(request_id, None)
return False, "Upload chunk timed out."
finally:
self._pending.pop(request_id, None)
if resp["status"] != "ok":
return False, resp["data"]["message"]
return True, ""
async def send_file(self, conv_id: str, file_path: str, members: list[dict],
reply_to: str | None = None) -> tuple[bool, str | dict]:
"""Encrypt and upload a file, then send as a message."""
await self.typing_stop(conv_id, force=True)
import mimetypes
path = Path(file_path)
if not path.exists():
return False, "File not found."
try:
file_bytes = path.read_bytes()
except Exception as e:
return False, f"Cannot read file: {e}"
mime_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
# Encrypt file with AES-256-GCM (offload to thread)
file_aes_key, file_iv, file_ct, file_tag = await asyncio.get_event_loop().run_in_executor(
None, aes_encrypt, file_bytes)
encrypted_file = file_ct + file_tag
file_id = str(uuid.uuid4())
file_size = len(encrypted_file)
# Chunked upload (reuse image upload infrastructure with file_type="file")
resp = await self.send_and_recv(
"upload_image_start",
conversation_id=conv_id,
file_id=file_id,
file_size=file_size,
file_type="file",
)
if resp["status"] != "ok":
return False, resp["data"]["message"]
ok, err = await self._pipelined_upload(file_id, encrypted_file)
if not ok:
return False, err
resp = await self.send_and_recv("upload_image_end", file_id=file_id)
if resp["status"] != "ok":
return False, resp["data"]["message"]
# Cache decrypted file locally so sender never re-downloads
cache_path = self._media_cache_path(file_id)
if cache_path:
try:
cache_path.write_bytes(file_bytes)
os.chmod(cache_path, 0o600)
except OSError:
pass
# Build message payload with file info
file_info = {
"file_id": file_id,
"aes_key": encode_binary(file_aes_key),
"iv": encode_binary(file_iv),
"filename": path.name,
"size": len(file_bytes),
"mime_type": mime_type,
}
payload = {
"sender": self.username,
"text": "",
"reply_to": reply_to,
"timestamp": datetime.now(timezone.utc).isoformat(),
"file": file_info,
}
plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
my_user_id = self.session["user_id"]
if self._is_group(members):
sk = self.sender_key_states.get(conv_id)
if not sk:
sk = _load_sender_key_state(self.email, conv_id, self._local_key)
if not sk:
sk = SenderKeyState()
self.sender_key_states[conv_id] = sk
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
self.sender_key_states[conv_id] = sk
await self._catch_up_sender_key_distribution(conv_id, members, sk)
result = sk.encrypt(plaintext)
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
recipients = []
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
recipients.append({
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
})
# Self-encrypted copy for sender
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
resp = await self.send_and_recv(
"send_message",
conversation_id=conv_id,
ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0},
recipients=recipients,
sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])),
sender_chain_n=result["n"],
image_file_id=file_id,
)
else:
# DM file: per-device ratchet (same pattern as _send_dm)
recipients = []
first_rh = None
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
try:
device_bundles = await self._get_device_bundles(uid)
except Exception:
device_bundles = []
if not device_bundles:
# Fallback: legacy single-device
ratchet = await self._get_or_create_session(uid)
result = ratchet.encrypt(plaintext)
x3dh_h = getattr(ratchet, "_x3dh_header", None)
if x3dh_h:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if x3dh_h:
entry["x3dh_header"] = x3dh_h
recipients.append(entry)
if first_rh is None:
first_rh = result["header"]
_save_session(self.email, uid, ratchet, self._local_key)
else:
for bundle in device_bundles:
dev_id = bundle.get("device_id")
ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
bundle=bundle)
result = ratchet.encrypt(plaintext)
x3dh_h = getattr(ratchet, "_x3dh_header", None)
if x3dh_h:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if dev_id:
entry["device_id"] = dev_id
if x3dh_h:
entry["x3dh_header"] = x3dh_h
recipients.append(entry)
if first_rh is None:
first_rh = result["header"]
_save_session(self.email, uid, ratchet, self._local_key,
peer_device_id=dev_id)
# Encrypt self-copy with static key
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
resp = await self.send_and_recv(
"send_message",
conversation_id=conv_id,
ratchet_header=first_rh,
recipients=recipients,
image_file_id=file_id,
)
if resp["status"] == "ok":
msg_data = resp.get("data", {})
result_msg = {
**payload,
"message_id": msg_data.get("message_id", ""),
"created_at": msg_data.get("created_at", ""),
"sender_id": self.session["user_id"],
"conversation_id": conv_id,
"read_by": [],
}
_save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key)
return True, result_msg
return False, resp["data"]["message"]
async def download_file(self, file_id: str, file_info: dict) -> bytes | None:
"""Download and decrypt a file. Returns decrypted file bytes or None."""
return await self._download_and_decrypt(file_id, file_info)
def _media_cache_path(self, file_id: str) -> Path | None:
"""Return path for cached decrypted media file, or None if no email."""
if not self.email:
return None
d = get_key_dir(self.email) / "media_cache"
d.mkdir(parents=True, exist_ok=True)
try:
os.chmod(d, 0o700)
except OSError:
pass
return d / f"{file_id}.bin"
async def _stream_download(self, file_id: str) -> bytes | None:
"""Download file via streaming (single request, server sends all chunks).
Falls back to legacy pipelined download if server doesn't support streaming.
"""
request_id = str(uuid.uuid4())
q: asyncio.Queue = asyncio.Queue()
self._pending[request_id] = q
try:
await self.writer.send_request(
"download_stream", request_id=request_id, file_id=file_id,
)
except Exception:
self._pending.pop(request_id, None)
return None
chunks: dict[int, bytes] = {}
try:
while True:
try:
resp = await asyncio.wait_for(q.get(), timeout=60.0)
except (asyncio.TimeoutError, ConnectionError):
return None
if resp.get("status") != "ok":
# Server may not support download_stream — fall back
return await self._legacy_download(file_id)
data = resp["data"]
chunk_data = decode_binary(data["data"])
chunks[data["offset"]] = chunk_data
if data.get("done"):
break
finally:
self._pending.pop(request_id, None)
# Reassemble in order
parts = []
for off in sorted(chunks.keys()):
parts.append(chunks[off])
return b"".join(parts)
async def _legacy_download(self, file_id: str) -> bytes | None:
"""Fallback: download file chunk by chunk (for older servers)."""
resp = await self.send_and_recv("download_image", file_id=file_id, offset=0)
if resp["status"] != "ok":
return None
data = resp["data"]
first_chunk = decode_binary(data["data"])
if data.get("done"):
return first_chunk
chunk_size = len(first_chunk)
chunks = {0: first_chunk}
# Pipeline remaining chunks
futures = []
offset = chunk_size
total_size = data.get("total_size", 0)
# Calculate how many chunks we need
while offset < total_size:
request_id = str(uuid.uuid4())
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._pending[request_id] = fut
try:
await self.writer.send_request(
"download_image", request_id=request_id,
file_id=file_id, offset=offset,
)
except Exception:
self._pending.pop(request_id, None)
return None
futures.append((request_id, offset, fut))
offset += chunk_size
for request_id, off, fut in futures:
try:
resp = await asyncio.wait_for(fut, timeout=30.0)
except (asyncio.TimeoutError, ConnectionError):
self._pending.pop(request_id, None)
return None
finally:
self._pending.pop(request_id, None)
if resp["status"] != "ok":
return None
chunk_data = decode_binary(resp["data"]["data"])
chunks[off] = chunk_data
parts = []
for off in sorted(chunks.keys()):
parts.append(chunks[off])
return b"".join(parts)
async def _download_and_decrypt(self, file_id: str, info: dict) -> bytes | None:
"""Download, decrypt, and cache a media file. Used by both image and file download."""
# Check local cache first
cache_path = self._media_cache_path(file_id)
if cache_path and cache_path.exists():
try:
return cache_path.read_bytes()
except OSError:
pass
encrypted_data = await self._stream_download(file_id)
if not encrypted_data or len(encrypted_data) < 16:
return None
ciphertext = encrypted_data[:-16]
tag = encrypted_data[-16:]
try:
aes_key = decode_binary(info["aes_key"])
iv = decode_binary(info["iv"])
decrypted = aes_decrypt(aes_key, iv, ciphertext, tag)
except Exception:
return None
# Cache decrypted result to disk
if cache_path and decrypted:
try:
cache_path.write_bytes(decrypted)
os.chmod(cache_path, 0o600)
except OSError:
pass
return decrypted
async def download_image(self, file_id: str, image_info: dict) -> bytes | None:
"""Download and decrypt an image. Returns decrypted image bytes or None."""
return await self._download_and_decrypt(file_id, image_info)
# ------------------------------------------------------------------
# Re-encrypt history (for device pairing)
# ------------------------------------------------------------------
async def reencrypt_history(self):
"""Re-encrypt all cached messages with self-encryption key.
After device pairing, the new device shares the same identity key
but cannot decrypt old messages (Double Ratchet keys are one-time use).
This re-encrypts all cached messages so they can be read using the
self-encryption key derived from the shared identity key.
"""
if not self.identity_private or not self.session:
return
self_key = derive_self_encryption_key(self.identity_private)
# Phase 1: Fetch & decrypt all messages to populate cache
# (messages the old device never opened won't be in cache yet)
try:
convs = await self.list_conversations()
convs = list(convs)
random.shuffle(convs)
total_convs = len(convs)
for ci, conv in enumerate(convs):
cid = conv.get("id") or conv.get("conversation_id")
if not cid:
continue
if self._reencrypt_progress_cb:
self._reencrypt_progress_cb(
f"Fetching messages: {ci + 1}/{total_convs} conversations..."
)
offset = 0
while True:
msgs = await self.get_messages(cid, limit=200, offset=offset)
if not msgs or len(msgs) < 200:
break
offset += len(msgs)
await asyncio.sleep(random.uniform(*PAIRING_REENCRYPT_INTER_FETCH_DELAY_RANGE))
except Exception as e:
self._logger.warning("Failed to fetch messages for re-encryption: %s", e)
# Phase 2: Read cache and re-encrypt
cache_dir = get_key_dir(self.email) / "message_cache"
if not cache_dir.exists():
self._logger.info("No message cache to re-encrypt.")
return
all_updates = []
conv_ids = set()
for f in cache_dir.iterdir():
if f.suffix in (".json", ".bin"):
conv_ids.add(f.stem)
total_files = len(conv_ids)
conv_order = sorted(conv_ids)
random.shuffle(conv_order)
for i, conv_id in enumerate(conv_order):
cache = _load_message_cache(self.email, conv_id, self._cache_key)
if not cache:
continue
items = list(cache.items())
random.shuffle(items)
for msg_id, entry in items:
# Skip control messages (sender key distribution)
if entry.get("_control"):
continue
# Skip entries with no useful content
text = entry.get("text", "")
if not text and not entry.get("image") and not entry.get("file"):
continue
# Rebuild plaintext from cached payload
payload = {k: v for k, v in entry.items()
if k not in ("message_id", "created_at", "read_by", "sender_id", "deleted")}
plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
# Re-encrypt with self-encryption key
_, nonce, ct, tag = aes_encrypt(plaintext, key=self_key)
all_updates.append({
"message_id": msg_id,
"encrypted_content": encode_binary(ct + tag),
"nonce": encode_binary(nonce),
})
if self._reencrypt_progress_cb:
self._reencrypt_progress_cb(f"Re-encrypting history: {i + 1}/{total_files} conversations...")
if not all_updates:
self._logger.info("No messages to re-encrypt.")
return
# Send in batches of 500
batch_size = PAIRING_REENCRYPT_BATCH_SIZE
random.shuffle(all_updates)
total = len(all_updates)
for start in range(0, total, batch_size):
batch = all_updates[start:start + batch_size]
resp = await self.send_and_recv("reencrypt_messages", updates=batch)
if resp["status"] != "ok":
self._logger.warning("Re-encrypt batch failed: %s", resp.get("data", {}).get("message", ""))
else:
self._logger.info("Re-encrypted %d/%d messages.", min(start + batch_size, total), total)
if start + batch_size < total:
await asyncio.sleep(random.uniform(*PAIRING_REENCRYPT_INTER_BATCH_DELAY_RANGE))
if self._reencrypt_progress_cb:
self._reencrypt_progress_cb(f"Re-encryption complete: {total} messages uploaded.")
# ------------------------------------------------------------------
# User Profiles
# ------------------------------------------------------------------
async def get_profile(self, user_id: str | None = None) -> dict | None:
"""Get user profile. If user_id is None, returns own profile."""
kwargs = {}
if user_id:
kwargs["user_id"] = user_id
resp = await self.send_and_recv("get_profile", **kwargs)
if resp["status"] == "ok":
return resp["data"]
return None
async def update_profile(self, **fields) -> tuple[bool, str]:
"""Update own profile (phone, location, *_visible)."""
resp = await self.send_and_recv("update_profile", **fields)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def update_avatar(self, image_data: bytes) -> tuple[bool, str]:
"""Upload avatar image."""
resp = await self.send_and_recv("update_avatar", data=encode_binary(image_data))
if resp["status"] == "ok":
return True, resp["data"].get("avatar_file", "")
return False, resp["data"]["message"]
async def get_avatar(self, user_id: str) -> bytes | None:
"""Download avatar for a user."""
resp = await self.send_and_recv("get_avatar", 10.0, user_id=user_id)
if resp["status"] == "ok":
return decode_binary(resp["data"]["data"])
return None
async def update_group_avatar(self, conv_id: str, image_data: bytes) -> tuple[bool, str]:
"""Upload avatar for a group conversation."""
resp = await self.send_and_recv("update_group_avatar",
conversation_id=conv_id, data=encode_binary(image_data))
if resp["status"] == "ok":
return True, resp["data"].get("avatar_file", "")
return False, resp["data"]["message"]
async def get_group_avatar(self, conv_id: str) -> bytes | None:
"""Download avatar for a group conversation."""
resp = await self.send_and_recv("get_group_avatar", 10.0, conversation_id=conv_id)
if resp["status"] == "ok":
return decode_binary(resp["data"]["data"])
return None
# ------------------------------------------------------------------
# Cleanup
# ------------------------------------------------------------------
async def close(self):
self.connected = False
for conv_id in list(self._typing_stop_tasks.keys()):
self._cancel_typing_timer(conv_id)
if self._listener_task:
self._listener_task.cancel()
if self.raw_writer:
self.raw_writer.close()
async def reconnect(self):
"""Close existing connection and re-establish: connect + re-login using in-memory keys."""
try:
await self.close()
except Exception:
pass
# Reset reader/writer but keep keys and sessions
self.reader = None
self.writer = None
self.raw_writer = None
self._listener_task = None
self._pending.clear()
self.login_rejected = False
# Drain queues
while not self._response_queue.empty():
try:
self._response_queue.get_nowait()
except Exception:
break
while not self._notification_queue.empty():
try:
self._notification_queue.get_nowait()
except Exception:
break
await self.connect()
self._listener_task = asyncio.create_task(self._background_listener())
if self.email and self.private_key:
# RSA challenge-response login (keys already in memory)
start = await self.send_and_recv("login_start", email=self.email)
if start["status"] == "ok":
challenge = decode_binary(start["data"]["challenge"])
signature = rsa_sign(self.private_key, challenge)
login_kwargs = {
"email": self.email,
"signature": encode_binary(signature),
"client_version": VERSION,
}
if self.device_id:
login_kwargs["device_id"] = self.device_id
finish = await self.send_and_recv("login_finish", **login_kwargs)
if finish["status"] == "ok":
self.session = finish["data"]
asyncio.create_task(self._ensure_prekeys())
else:
# Login rejected — keys were likely rotated on another device
self.session = None
self.connected = False
self.login_rejected = True