- chat_core: defer one-time-prekey deletion until the first message decrypts successfully; deleting it on load made the SPK grace-period retry derive a wrong shared secret and lose the message permanently
- chat_core: fix get_deleted_since params (since -> since_ts) and response field (message_ids -> deleted_ids) so incremental deletion sync actually works
- chat_core: route keys_updated pushes into the notification queue
- server: notify contacts with keys_updated when a user uploads a new SPK or logs in with a new device, so clients invalidate cached key bundles instead of waiting for the TTL
- server: rate-limit download_stream like other heavy handlers

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
3926 lines
163 KiB
Python
3926 lines
163 KiB
Python
"""Shared network layer and ChatClient class for CLI and GUI clients.
|
|
|
|
Uses X3DH + Double Ratchet for message encryption, Sender Keys for groups.
|
|
RSA retained for login challenge-response only.
|
|
"""
|
|
|
|
import asyncio
|
|
import collections
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import ssl
|
|
import time
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()
|
|
|
|
from crypto_utils import (
|
|
# RSA (login only)
|
|
generate_rsa_keypair,
|
|
serialize_private_key,
|
|
serialize_public_key,
|
|
load_private_key,
|
|
load_public_key,
|
|
rsa_sign,
|
|
# Ed25519
|
|
generate_identity_keypair,
|
|
serialize_ed25519_private,
|
|
serialize_ed25519_private_raw,
|
|
serialize_ed25519_public,
|
|
load_ed25519_private,
|
|
load_ed25519_public,
|
|
ed25519_sign,
|
|
# X25519
|
|
generate_x25519_keypair,
|
|
serialize_x25519_private,
|
|
serialize_x25519_public,
|
|
load_x25519_private,
|
|
load_x25519_public,
|
|
x25519_dh,
|
|
derive_pairing_shared_key,
|
|
# X3DH
|
|
generate_signed_prekey,
|
|
generate_one_time_prekeys,
|
|
x3dh_initiate,
|
|
x3dh_respond,
|
|
# Double Ratchet
|
|
DoubleRatchet,
|
|
# Sender Keys
|
|
SenderKeyState,
|
|
# AES
|
|
aes_encrypt,
|
|
aes_decrypt,
|
|
# Self-encryption
|
|
derive_self_encryption_key,
|
|
# Local storage encryption
|
|
derive_local_storage_key,
|
|
# Contact verification
|
|
compute_fingerprint,
|
|
compute_pairing_fingerprint,
|
|
encode_pairing_qr,
|
|
format_fingerprint,
|
|
normalize_pairing_fingerprint,
|
|
compute_safety_number,
|
|
encode_verification_qr,
|
|
decode_verification_qr,
|
|
# Message padding
|
|
pad_plaintext,
|
|
unpad_plaintext,
|
|
)
|
|
from protocol import (
|
|
VERSION,
|
|
ProtocolReader,
|
|
ProtocolWriter,
|
|
encode_binary,
|
|
decode_binary,
|
|
build_request,
|
|
MAX_MESSAGE_BYTES,
|
|
MAX_IMAGE_BYTES,
|
|
IMAGE_CHUNK_SIZE,
|
|
)
|
|
|
|
|
|
# Root directory for per-account key material (get_key_dir appends the email).
KEY_DIR = Path.home() / ".encrypted_chat"

# One-time prekey pool: replenish threshold and batch size.
OPK_REPLENISH_THRESHOLD, OPK_BATCH_SIZE = 20, 50

# Signed prekey rotation interval, in days.
SPK_ROTATION_DAYS = 7

# Pacing for pairing re-encryption. The *_RANGE values are (low, high) pairs —
# presumably seconds sampled at random; confirm at the call site.
PAIRING_REENCRYPT_INITIAL_DELAY_RANGE = (20.0, 75.0)
PAIRING_REENCRYPT_INTER_BATCH_DELAY_RANGE = (1.0, 3.0)
PAIRING_REENCRYPT_INTER_FETCH_DELAY_RANGE = (0.15, 0.5)
PAIRING_REENCRYPT_BATCH_SIZE = 500
|
|
|
|
|
|
def _encrypt_local(data: bytes, key: bytes) -> bytes:
    """Seal *data* with AES-256-GCM for local storage.

    The returned blob is laid out as ``nonce (12) || tag (16) || ciphertext``,
    the layout that :func:`_decrypt_local` splits apart again.
    """
    _unused_key, iv, sealed, auth_tag = aes_encrypt(data, key=key)
    return b"".join((iv, auth_tag, sealed))
|
|
|
|
|
|
def _decrypt_local(raw: bytes, key: bytes) -> bytes:
    """Open a blob produced by :func:`_encrypt_local` and return the plaintext."""
    iv = raw[0:12]
    auth_tag = raw[12:28]
    sealed = raw[28:]
    return aes_decrypt(key, iv, sealed, auth_tag)
|
|
|
|
|
|
def get_key_dir(email: str) -> Path:
    """Return ``KEY_DIR / email``, creating it if needed and forcing mode 0o700."""
    key_dir = KEY_DIR.joinpath(email)
    key_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(key_dir, 0o700)
    return key_dir
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RSA key storage (login only — unchanged interface)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def save_keys(email: str, private_key, public_key, password: bytes | None = None):
    """Write the RSA login keypair as ``private.pem`` / ``public.pem``.

    *password* is passed to ``serialize_private_key``; the private key file is
    restricted to mode 0o600.
    """
    key_dir = get_key_dir(email)
    priv_file = key_dir / "private.pem"
    priv_file.write_bytes(serialize_private_key(private_key, password=password))
    (key_dir / "public.pem").write_bytes(serialize_public_key(public_key))
    os.chmod(priv_file, 0o600)
|
|
|
|
|
|
def load_keys(email: str, password: bytes | None = None):
    """Load the RSA login keypair for *email*.

    Returns ``(private_key, public_key, None)`` on success, or
    ``(None, None, error_message)`` when no key exists or it cannot be opened.
    If the private key fails to load with *password* but loads without one, it
    is treated as a legacy unprotected key; when a *password* was supplied, the
    keypair is re-saved protected by it.
    """
    key_dir = get_key_dir(email)
    priv_file = key_dir / "private.pem"
    pub_file = key_dir / "public.pem"
    if not priv_file.exists():
        return None, None, "No local keys found."
    pem_bytes = priv_file.read_bytes()
    try:
        private_key = load_private_key(pem_bytes, password=password)
    except Exception:
        try:
            private_key = load_private_key(pem_bytes, password=None)
            if password:
                # Migrate the legacy unprotected key to a password-protected one.
                legacy_pub = load_public_key(pub_file.read_bytes())
                save_keys(email, private_key, legacy_pub, password=password)
        except Exception:
            return None, None, "Invalid or missing password."
    return private_key, load_public_key(pub_file.read_bytes()), None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Identity + prekey storage
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _save_identity_keys(email: str, ed_priv, ed_pub, password: bytes | None = None):
    """Write the Ed25519 identity keypair to the account's key directory.

    With a truthy *password*, the private key is serialized through
    ``serialize_ed25519_private``; otherwise the raw form is written. The
    private key file gets mode 0o600.
    """
    key_dir = get_key_dir(email)
    priv_file = key_dir / "identity_private.bin"
    blob = (
        serialize_ed25519_private(ed_priv, password=password)
        if password
        else serialize_ed25519_private_raw(ed_priv)
    )
    priv_file.write_bytes(blob)
    (key_dir / "identity_public.bin").write_bytes(serialize_ed25519_public(ed_pub))
    os.chmod(priv_file, 0o600)
|
|
|
|
|
|
def _load_identity_keys(email: str, password: bytes | None = None):
    """Return ``(private_key, public_key)`` for the Ed25519 identity.

    Returns ``(None, None)`` when ``identity_private.bin`` does not exist.
    *password* is forwarded to ``load_ed25519_private``.
    """
    key_dir = get_key_dir(email)
    priv_file = key_dir / "identity_private.bin"
    if not priv_file.exists():
        return None, None
    pub_file = key_dir / "identity_public.bin"
    return (
        load_ed25519_private(priv_file.read_bytes(), password=password),
        load_ed25519_public(pub_file.read_bytes()),
    )
|
|
|
|
|
|
def _save_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None):
    """Persist the current signed prekey (``spk_private.bin``) and its id (``spk_id.txt``).

    The private key is encrypted with :func:`_encrypt_local` when *local_key*
    is given; the key file gets mode 0o600.
    """
    key_dir = get_key_dir(email)
    blob = serialize_x25519_private(spk_priv)
    if local_key:
        blob = _encrypt_local(blob, local_key)
    priv_file = key_dir / "spk_private.bin"
    priv_file.write_bytes(blob)
    (key_dir / "spk_id.txt").write_text(spk_id)
    os.chmod(priv_file, 0o600)
|
|
|
|
|
|
def _load_spk(email: str, local_key: bytes | None = None):
    """Load the current signed prekey.

    Returns ``(private_key, spk_id)``, or ``(None, None)`` when no SPK file
    exists. ``spk_id`` is ``""`` if ``spk_id.txt`` is missing.

    With *local_key*, the file is expected to be encrypted by
    :func:`_encrypt_local`. A file that fails to decrypt is treated as a legacy
    plaintext key and re-saved encrypted (one-time migration). An
    already-encrypted file is NOT rewritten. Rewriting long-term key material on
    every load is a non-atomic write that a crash could leave truncated.
    """
    key_dir = get_key_dir(email)
    priv_path = key_dir / "spk_private.bin"
    id_path = key_dir / "spk_id.txt"
    if not priv_path.exists():
        return None, None
    raw = priv_path.read_bytes()
    needs_migration = False
    if local_key:
        try:
            raw = _decrypt_local(raw, local_key)
        except Exception:
            # Plaintext fallback (migration) — re-saved encrypted below.
            needs_migration = True
    priv = load_x25519_private(raw)
    spk_id = id_path.read_text().strip() if id_path.exists() else ""
    if needs_migration:
        _save_spk(email, priv, spk_id, local_key)
    return priv, spk_id
|
|
|
|
|
|
def _save_prev_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None):
    """Keep the previous signed prekey around for a grace period.

    In-flight X3DH handshakes may still reference the old SPK. Writes
    ``prev_spk_private.bin`` (encrypted when *local_key* is given, mode 0o600)
    and ``prev_spk_id.txt``.
    """
    key_dir = get_key_dir(email)
    blob = serialize_x25519_private(spk_priv)
    if local_key:
        blob = _encrypt_local(blob, local_key)
    priv_file = key_dir / "prev_spk_private.bin"
    priv_file.write_bytes(blob)
    (key_dir / "prev_spk_id.txt").write_text(spk_id)
    os.chmod(priv_file, 0o600)
|
|
|
|
|
|
def _load_prev_spk(email: str, local_key: bytes | None = None):
    """Load the previous SPK kept for the grace period.

    Returns ``(private_key, spk_id)``, or ``(None, None)`` if none is stored.

    With *local_key*, a file that fails to decrypt is treated as a legacy
    plaintext key and re-saved encrypted (migration). An already-encrypted file
    is not rewritten, which avoids a needless non-atomic rewrite of key material
    on every load.
    """
    key_dir = get_key_dir(email)
    priv_path = key_dir / "prev_spk_private.bin"
    id_path = key_dir / "prev_spk_id.txt"
    if not priv_path.exists():
        return None, None
    raw = priv_path.read_bytes()
    needs_migration = False
    if local_key:
        try:
            raw = _decrypt_local(raw, local_key)
        except Exception:
            # Plaintext fallback (migration) — re-saved encrypted below.
            needs_migration = True
    priv = load_x25519_private(raw)
    spk_id = id_path.read_text().strip() if id_path.exists() else ""
    if needs_migration:
        _save_prev_spk(email, priv, spk_id, local_key)
    return priv, spk_id
|
|
|
|
|
|
def _save_opk_private(email: str, opk_id: str, opk_priv, local_key: bytes | None = None):
    """Store a one-time prekey private key as ``opk_private/<opk_id>.bin``.

    Encrypted with :func:`_encrypt_local` when *local_key* is given. The
    directory gets mode 0o700 and the file 0o600.
    """
    opk_dir = get_key_dir(email) / "opk_private"
    opk_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(opk_dir, 0o700)
    blob = serialize_x25519_private(opk_priv)
    if local_key:
        blob = _encrypt_local(blob, local_key)
    opk_file = opk_dir / f"{opk_id}.bin"
    opk_file.write_bytes(blob)
    os.chmod(opk_file, 0o600)
|
|
|
|
|
|
def _load_opk_private(email: str, opk_id: str, local_key: bytes | None = None):
    """Load the private half of one-time prekey *opk_id*, or None if not stored.

    With *local_key*, a file that fails to decrypt is treated as a legacy
    plaintext key and re-saved encrypted (migration). Files that already
    decrypt are left untouched, so loading the OPK pool does not rewrite every
    key file each time.
    """
    opk_file = get_key_dir(email) / "opk_private" / f"{opk_id}.bin"
    if not opk_file.exists():
        return None
    raw = opk_file.read_bytes()
    needs_migration = False
    if local_key:
        try:
            raw = _decrypt_local(raw, local_key)
        except Exception:
            # Plaintext fallback (migration) — re-saved encrypted below.
            needs_migration = True
    priv = load_x25519_private(raw)
    if needs_migration:
        _save_opk_private(email, opk_id, priv, local_key)
    return priv
|
|
|
|
|
|
def _secure_delete(p: Path):
|
|
"""Overwrite file with random data before deletion (anti-forensic wipe)."""
|
|
try:
|
|
if not p.exists():
|
|
return
|
|
size = p.stat().st_size
|
|
if size > 0:
|
|
with open(p, "r+b") as f:
|
|
f.write(os.urandom(size))
|
|
f.flush()
|
|
os.fsync(f.fileno())
|
|
p.unlink()
|
|
except Exception:
|
|
try:
|
|
p.unlink(missing_ok=True)
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
def _delete_opk_private(email: str, opk_id: str):
    """Securely remove the stored private key of one-time prekey *opk_id*."""
    _secure_delete(get_key_dir(email) / "opk_private" / f"{opk_id}.bin")
|
|
|
|
|
|
def _save_device_id(email: str, device_id: str):
    """Persist this device's id to ``device_id.txt`` (mode 0o600)."""
    id_file = get_key_dir(email).joinpath("device_id.txt")
    id_file.write_text(device_id)
    os.chmod(id_file, 0o600)
|
|
|
|
|
|
def _load_device_id(email: str) -> str | None:
    """Return the stored device id, or None if the file is missing or blank."""
    id_file = get_key_dir(email).joinpath("device_id.txt")
    if id_file.exists():
        return id_file.read_text().strip() or None
    return None
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Expiring LRU cache for user info (max_size + TTL eviction)
|
|
# ------------------------------------------------------------------
|
|
|
|
class _ExpiringLRUCache:
|
|
"""Dict-like cache with max size (LRU eviction) and per-entry TTL."""
|
|
|
|
def __init__(self, max_size: int = 10_000, ttl: float = 3600.0):
|
|
self._max_size = max_size
|
|
self._ttl = ttl
|
|
self._data: collections.OrderedDict = collections.OrderedDict() # key -> (value, ts)
|
|
|
|
def get(self, key, default=None):
|
|
entry = self._data.get(key)
|
|
if entry is None:
|
|
return default
|
|
value, ts = entry
|
|
if time.monotonic() - ts > self._ttl:
|
|
del self._data[key]
|
|
return default
|
|
# Move to end (most recently used)
|
|
self._data.move_to_end(key)
|
|
return value
|
|
|
|
def __setitem__(self, key, value):
|
|
if key in self._data:
|
|
self._data.move_to_end(key)
|
|
self._data[key] = (value, time.monotonic())
|
|
while len(self._data) > self._max_size:
|
|
self._data.popitem(last=False)
|
|
|
|
def __getitem__(self, key):
|
|
result = self.get(key)
|
|
if result is None and key not in self._data:
|
|
raise KeyError(key)
|
|
return result
|
|
|
|
def __contains__(self, key):
|
|
return self.get(key) is not None
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Identity key change exception (TOFU hard-fail)
|
|
# ------------------------------------------------------------------
|
|
|
|
class IdentityKeyChanged(Exception):
    """A peer presented an identity key that differs from the recorded one (TOFU violation).

    Session creation is blocked until the user explicitly accepts the new key.

    Attributes:
        user_id: the peer whose identity key changed.
        new_key_bytes: the newly presented identity key.
        status: ``"changed"`` or ``"changed_verified"``.
    """

    def __init__(self, user_id: str, new_key_bytes: bytes, status: str):
        message = (
            f"Identity key changed for user {user_id} (status={status}). "
            "Accept the new key before communicating."
        )
        super().__init__(message)
        self.user_id = user_id
        self.new_key_bytes = new_key_bytes
        self.status = status
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Client-side brute-force lockout
|
|
# ------------------------------------------------------------------
|
|
|
|
_LOCKOUT_BASE_SECONDS = 2
|
|
_LOCKOUT_MAX_SECONDS = 300 # 5 min cap
|
|
|
|
|
|
def _get_lockout_path(email: str) -> Path:
    """Location of the JSON file that tracks failed login attempts for *email*."""
    return get_key_dir(email).joinpath("login_lockout.json")
|
|
|
|
|
|
def _check_lockout(email: str) -> float:
    """Seconds until the next login attempt is allowed; 0.0 means "try now".

    A missing or unreadable lockout file counts as not locked.
    """
    lock_file = _get_lockout_path(email)
    if not lock_file.exists():
        return 0.0
    try:
        state = json.loads(lock_file.read_text())
        return max(0.0, state.get("locked_until", 0.0) - time.time())
    except Exception:
        return 0.0
|
|
|
|
|
|
def _record_failed_attempt(email: str):
    """Increment the failed-login counter and push ``locked_until`` forward.

    After the n-th consecutive failure the lockout lasts
    ``min(_LOCKOUT_BASE_SECONDS ** n, _LOCKOUT_MAX_SECONDS)`` seconds. An
    unreadable state file restarts the count from zero. The file gets mode 0o600.
    """
    lock_file = _get_lockout_path(email)
    try:
        previous = (
            json.loads(lock_file.read_text()).get("failed_attempts", 0)
            if lock_file.exists()
            else 0
        )
    except Exception:
        previous = 0
    attempts = previous + 1
    wait_seconds = min(_LOCKOUT_BASE_SECONDS ** attempts, _LOCKOUT_MAX_SECONDS)
    record = {"failed_attempts": attempts, "locked_until": time.time() + wait_seconds}
    lock_file.write_text(json.dumps(record))
    os.chmod(lock_file, 0o600)
|
|
|
|
|
|
def _clear_lockout(email: str):
    """Reset the failed-login state after a successful login (best effort)."""
    lock_file = _get_lockout_path(email)
    if not lock_file.exists():
        return
    try:
        lock_file.unlink()
    except Exception:
        pass
|
|
|
|
|
|
def _save_session(email: str, peer_user_id: str, ratchet: DoubleRatchet,
                  local_key: bytes | None = None, peer_device_id: str | None = None):
    """Persist a Double Ratchet session to ``sessions/<peer>[_<device>].bin``.

    The exported state is encrypted with :func:`_encrypt_local` when
    *local_key* is given. The directory gets mode 0o700 and the file 0o600.
    """
    sessions_dir = get_key_dir(email) / "sessions"
    sessions_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(sessions_dir, 0o700)
    stem = f"{peer_user_id}_{peer_device_id}" if peer_device_id else f"{peer_user_id}"
    target = sessions_dir / f"{stem}.bin"
    blob = ratchet.export_state()
    if local_key:
        blob = _encrypt_local(blob, local_key)
    target.write_bytes(blob)
    os.chmod(target, 0o600)
|
|
|
|
|
|
def _load_session(email: str, peer_user_id: str,
                  local_key: bytes | None = None,
                  peer_device_id: str | None = None) -> DoubleRatchet | None:
    """Load the Double Ratchet session with a peer (or peer device), or None.

    When *peer_device_id* is given and no per-device file exists, a legacy
    ``<peer>.bin`` file is loaded instead. If that succeeds, it is migrated to
    the per-device name and the old file is securely deleted.
    """
    sessions_dir = get_key_dir(email) / "sessions"
    legacy = sessions_dir / f"{peer_user_id}.bin"
    if not peer_device_id:
        return _load_session_file(legacy, local_key) if legacy.exists() else None
    per_device = sessions_dir / f"{peer_user_id}_{peer_device_id}.bin"
    if per_device.exists():
        return _load_session_file(per_device, local_key)
    # Fallback: try old format (no device_id) and migrate
    if not legacy.exists():
        return None
    ratchet = _load_session_file(legacy, local_key)
    if ratchet:
        _save_session(email, peer_user_id, ratchet, local_key,
                      peer_device_id=peer_device_id)
        _secure_delete(legacy)
    return ratchet
|
|
|
|
|
|
def _load_session_file(p: Path, local_key: bytes | None = None) -> DoubleRatchet | None:
    """Deserialize the Double Ratchet session stored at *p* (None if missing).

    With *local_key*, the file is decrypted first. If that fails, the bytes are
    retried as a legacy plaintext export, and None is returned if the retry
    fails too. The legacy file is not re-encrypted here.
    """
    if not p.exists():
        return None
    blob = p.read_bytes()
    if not local_key:
        return DoubleRatchet.import_state(blob)
    try:
        plain = _decrypt_local(blob, local_key)
    except Exception:
        # Migration: older clients wrote sessions unencrypted.
        try:
            return DoubleRatchet.import_state(blob)
        except Exception:
            return None
    return DoubleRatchet.import_state(plain)
|
|
|
|
|
|
def _delete_session_file(email: str, peer_user_id: str, peer_device_id: str | None = None):
    """Securely delete a stored session file (for session reset).

    Targets ``<peer>_<device>.bin`` when *peer_device_id* is given, otherwise
    ``<peer>.bin``.
    """
    name = f"{peer_user_id}_{peer_device_id}.bin" if peer_device_id else f"{peer_user_id}.bin"
    _secure_delete(get_key_dir(email) / "sessions" / name)
|
|
|
|
|
|
def _save_sender_key_state(email: str, conv_id: str, state: SenderKeyState,
                           local_key: bytes | None = None):
    """Persist our own sender key for *conv_id* to ``sender_keys/<conv_id>.bin``.

    Encrypted with :func:`_encrypt_local` when *local_key* is given. The
    directory gets mode 0o700 and the file 0o600.
    """
    keys_dir = get_key_dir(email) / "sender_keys"
    keys_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(keys_dir, 0o700)
    target = keys_dir / f"{conv_id}.bin"
    blob = state.export_state()
    if local_key:
        blob = _encrypt_local(blob, local_key)
    target.write_bytes(blob)
    os.chmod(target, 0o600)
|
|
|
|
|
|
def _load_sender_key_state(email: str, conv_id: str,
                           local_key: bytes | None = None) -> SenderKeyState | None:
    """Load our own sender key for *conv_id*, or None if not stored.

    With *local_key*, a file that fails to decrypt is retried as a legacy
    plaintext export. If the retry works, the key is re-saved encrypted; if
    not, None is returned.
    """
    key_file = get_key_dir(email) / "sender_keys" / f"{conv_id}.bin"
    if not key_file.exists():
        return None
    blob = key_file.read_bytes()
    if not local_key:
        return SenderKeyState.import_state(blob)
    try:
        plain = _decrypt_local(blob, local_key)
    except Exception:
        try:
            legacy = SenderKeyState.import_state(blob)
            _save_sender_key_state(email, conv_id, legacy, local_key)
            return legacy
        except Exception:
            return None
    return SenderKeyState.import_state(plain)
|
|
|
|
|
|
def _save_sender_key_recipients(email: str, conv_id: str, recipients: set[str],
                                local_key: bytes | None = None):
    """Persist the recipient set tied to our sender key for *conv_id*.

    Stored as a sorted JSON list in ``sender_keys/<conv_id>.members.bin``,
    encrypted with :func:`_encrypt_local` when *local_key* is given.
    """
    keys_dir = get_key_dir(email) / "sender_keys"
    keys_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(keys_dir, 0o700)
    target = keys_dir / f"{conv_id}.members.bin"
    blob = json.dumps(sorted(recipients), ensure_ascii=False).encode("utf-8")
    if local_key:
        blob = _encrypt_local(blob, local_key)
    target.write_bytes(blob)
    os.chmod(target, 0o600)
|
|
|
|
|
|
def _parse_recipient_list(data: bytes) -> set[str] | None:
    """Decode a UTF-8 JSON list of user-id strings; return None if malformed."""
    try:
        parsed = json.loads(data.decode("utf-8"))
    except (ValueError, RecursionError):
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return None
    if not isinstance(parsed, list):
        return None
    return {uid for uid in parsed if isinstance(uid, str)}


def _load_sender_key_recipients(email: str, conv_id: str,
                                local_key: bytes | None = None) -> set[str]:
    """Load the recipient set stored for our sender key in *conv_id*.

    Returns an empty set when the file is missing or malformed. With
    *local_key*, a file that fails to decrypt is retried as legacy plaintext
    JSON and, when valid, re-saved encrypted. Both paths apply the same
    validation: only a JSON list is accepted, and non-string items are dropped.
    """
    members_file = get_key_dir(email) / "sender_keys" / f"{conv_id}.members.bin"
    if not members_file.exists():
        return set()
    raw = members_file.read_bytes()
    if local_key:
        try:
            data = _decrypt_local(raw, local_key)
        except Exception:
            # Migration: previous plaintext storage, re-save encrypted on success.
            recipients = _parse_recipient_list(raw)
            if recipients is None:
                return set()
            _save_sender_key_recipients(email, conv_id, recipients, local_key)
            return recipients
    else:
        data = raw
    recipients = _parse_recipient_list(data)
    return recipients if recipients is not None else set()
|
|
|
|
|
|
def _save_recv_sender_key(email: str, conv_id: str, sender_id: str, state: SenderKeyState,
                          local_key: bytes | None = None,
                          sender_device_id: str | None = None):
    """Persist a received sender key to ``sender_keys_recv/<conv>_<sender>[_<device>].bin``.

    Encrypted with :func:`_encrypt_local` when *local_key* is given. The
    directory gets mode 0o700 and the file 0o600.
    """
    recv_dir = get_key_dir(email) / "sender_keys_recv"
    recv_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(recv_dir, 0o700)
    suffix = f"_{sender_device_id}" if sender_device_id else ""
    target = recv_dir / f"{conv_id}_{sender_id}{suffix}.bin"
    blob = state.export_state()
    if local_key:
        blob = _encrypt_local(blob, local_key)
    target.write_bytes(blob)
    os.chmod(target, 0o600)
|
|
|
|
|
|
def _load_recv_sender_key(email: str, conv_id: str, sender_id: str,
                          local_key: bytes | None = None,
                          sender_device_id: str | None = None) -> SenderKeyState | None:
    """Load a received sender key for (*conv_id*, *sender_id*[, device]), or None.

    When *sender_device_id* is given and no per-device file exists, a legacy
    ``<conv>_<sender>.bin`` file is loaded instead. If that succeeds, it is
    migrated to the per-device name and the old file is securely deleted.
    """
    recv_dir = get_key_dir(email) / "sender_keys_recv"
    legacy = recv_dir / f"{conv_id}_{sender_id}.bin"
    if not sender_device_id:
        return _load_recv_sender_key_file(legacy, local_key) if legacy.exists() else None
    per_device = recv_dir / f"{conv_id}_{sender_id}_{sender_device_id}.bin"
    if per_device.exists():
        return _load_recv_sender_key_file(per_device, local_key)
    # Fallback: try old format and migrate
    if not legacy.exists():
        return None
    sk = _load_recv_sender_key_file(legacy, local_key)
    if sk:
        _save_recv_sender_key(email, conv_id, sender_id, sk, local_key,
                              sender_device_id=sender_device_id)
        _secure_delete(legacy)
    return sk
|
|
|
|
|
|
def _load_recv_sender_key_file(p: Path, local_key: bytes | None = None) -> SenderKeyState | None:
    """Deserialize a received sender key stored at *p* (None if missing).

    With *local_key*, the file is decrypted first. If that fails, the bytes are
    retried as a legacy plaintext export, and None is returned if the retry
    fails too.
    """
    if not p.exists():
        return None
    blob = p.read_bytes()
    if not local_key:
        return SenderKeyState.import_state(blob)
    try:
        plain = _decrypt_local(blob, local_key)
    except Exception:
        try:
            return SenderKeyState.import_state(blob)
        except Exception:
            return None
    return SenderKeyState.import_state(plain)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Local decrypted message cache (Double Ratchet keys are one-time use)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _load_message_cache(email: str, conv_id: str, cache_key: bytes | None = None) -> dict:
    """Load the local decrypted-message cache for *conv_id* as a dict.

    Double Ratchet message keys are one-time use, so decrypted payloads are
    cached on disk. Returns ``{}`` in these cases:
      - nothing is stored;
      - the encrypted cache exists but no *cache_key* is given;
      - the data cannot be decrypted or parsed;
      - the stored JSON is not an object. Callers mutate the result as a dict,
        so a non-dict payload is rejected rather than handed back.

    Migration: if only the legacy plaintext ``<conv_id>.json`` exists, it is
    returned. When *cache_key* is available, it is also re-written encrypted and
    the plaintext file is securely deleted.
    """
    d = get_key_dir(email) / "message_cache"
    p_bin = d / f"{conv_id}.bin"
    p_json = d / f"{conv_id}.json"

    # Migration: if old plaintext .json exists but encrypted .bin doesn't
    if p_json.exists() and not p_bin.exists():
        try:
            cache = json.loads(p_json.read_text("utf-8"))
            if not isinstance(cache, dict):
                return {}
            if cache_key:
                _save_message_cache_full(d, conv_id, cache, cache_key)
                _secure_delete(p_json)
            return cache
        except Exception:
            return {}

    if not p_bin.exists() or not cache_key:
        return {}
    try:
        # Same nonce(12) + tag(16) + ciphertext layout as _encrypt_local.
        plaintext = _decrypt_local(p_bin.read_bytes(), cache_key)
        cache = json.loads(plaintext.decode("utf-8"))
    except Exception:
        return {}
    return cache if isinstance(cache, dict) else {}
|
|
|
|
|
|
def _save_message_cache_full(d: Path, conv_id: str, cache: dict, cache_key: bytes):
    """Encrypt the whole cache dict as JSON and write it to ``d/<conv_id>.bin``.

    The blob uses the ``nonce || tag || ciphertext`` layout produced by
    :func:`_encrypt_local`. The directory gets mode 0o700 and the file 0o600.
    """
    d.mkdir(parents=True, exist_ok=True)
    os.chmod(d, 0o700)
    target = d / f"{conv_id}.bin"
    serialized = json.dumps(cache, ensure_ascii=False).encode("utf-8")
    target.write_bytes(_encrypt_local(serialized, cache_key))
    os.chmod(target, 0o600)
|
|
|
|
|
|
def _save_message_to_cache(email: str, conv_id: str, message_id: str, payload: dict,
                           cache_key: bytes | None = None):
    """Insert or replace one decrypted message in the per-conversation cache.

    With *cache_key*, the whole cache is re-encrypted to ``<conv_id>.bin``.
    Without it, the cache is written as plaintext JSON to ``<conv_id>.json``.
    """
    cache_dir = get_key_dir(email) / "message_cache"
    cache = _load_message_cache(email, conv_id, cache_key)
    cache[message_id] = payload
    if cache_key:
        _save_message_cache_full(cache_dir, conv_id, cache, cache_key)
        return
    # Fallback: plaintext (no identity key available yet)
    cache_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(cache_dir, 0o700)
    json_file = cache_dir / f"{conv_id}.json"
    json_file.write_text(json.dumps(cache, ensure_ascii=False), "utf-8")
    os.chmod(json_file, 0o600)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Verification storage (TOFU + explicit verification)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _save_known_identity_keys(email: str, keys: dict, local_key: bytes | None = None):
    """Write the TOFU identity-key registry as ``{"version": 1, "keys": ...}`` JSON.

    Encrypted with :func:`_encrypt_local` when *local_key* is given; mode 0o600.
    """
    registry_file = get_key_dir(email) / "known_identity_keys.bin"
    blob = json.dumps({"version": 1, "keys": keys}).encode("utf-8")
    registry_file.write_bytes(_encrypt_local(blob, local_key) if local_key else blob)
    os.chmod(registry_file, 0o600)
|
|
|
|
|
|
def _load_known_identity_keys(email: str, local_key: bytes | None = None) -> dict:
    """Load the TOFU identity-key registry (the ``"keys"`` field); ``{}`` on any error.

    When *local_key* is given, a file that fails to decrypt yields ``{}``.
    There is no plaintext fallback: these files were never stored unencrypted
    (the feature post-dates local encryption). Accepting plaintext would let an
    attacker with disk access inject fake identity keys and bypass TOFU
    warnings. Without *local_key*, the file is parsed as-is.
    """
    registry_file = get_key_dir(email) / "known_identity_keys.bin"
    if not registry_file.exists():
        return {}
    blob = registry_file.read_bytes()
    try:
        data = _decrypt_local(blob, local_key) if local_key else blob
        return json.loads(data).get("keys", {})
    except Exception:
        return {}
|
|
|
|
|
|
def _save_verified_contacts(email: str, contacts: dict, local_key: bytes | None = None):
    """Write explicit verification state as ``{"version": 1, "contacts": ...}`` JSON.

    Encrypted with :func:`_encrypt_local` when *local_key* is given; mode 0o600.
    """
    contacts_file = get_key_dir(email) / "verified_contacts.bin"
    blob = json.dumps({"version": 1, "contacts": contacts}).encode("utf-8")
    contacts_file.write_bytes(_encrypt_local(blob, local_key) if local_key else blob)
    os.chmod(contacts_file, 0o600)
|
|
|
|
|
|
def _load_verified_contacts(email: str, local_key: bytes | None = None) -> dict:
    """Load explicit verification state (the ``"contacts"`` field); ``{}`` on any error.

    When *local_key* is given, a file that fails to decrypt yields ``{}``.
    There is no plaintext fallback: these files were never stored unencrypted.
    Accepting plaintext would let an attacker with disk access inject fake
    verification records (marking the attacker as "verified").
    """
    contacts_file = get_key_dir(email) / "verified_contacts.bin"
    if not contacts_file.exists():
        return {}
    blob = contacts_file.read_bytes()
    try:
        data = _decrypt_local(blob, local_key) if local_key else blob
        return json.loads(data).get("contacts", {})
    except Exception:
        return {}
|
|
|
|
|
|
def _solve_pow(challenge: str, difficulty: int) -> str:
|
|
"""Solve a proof-of-work challenge by finding a nonce with enough leading zero bits."""
|
|
target_bytes = difficulty // 8
|
|
target_bits = difficulty % 8
|
|
mask = (0xFF << (8 - target_bits)) & 0xFF if target_bits else 0
|
|
nonce = 0
|
|
while True:
|
|
digest = hashlib.sha256(f"{challenge}{nonce}".encode()).digest()
|
|
# Fast path: check full zero bytes first
|
|
ok = True
|
|
for i in range(target_bytes):
|
|
if digest[i] != 0:
|
|
ok = False
|
|
break
|
|
if ok and target_bits:
|
|
if digest[target_bytes] & mask:
|
|
ok = False
|
|
if ok:
|
|
return str(nonce)
|
|
nonce += 1
|
|
|
|
|
|
class ChatClient:
|
|
def __init__(self):
    """Create a disconnected client with no account, keys, or sessions loaded."""
    # --- Transport ---------------------------------------------------------
    self.reader: ProtocolReader | None = None
    self.writer: ProtocolWriter | None = None
    self.raw_writer: asyncio.StreamWriter | None = None
    self.connected: bool = False
    self._listener_task: asyncio.Task | None = None
    self._logger = logging.getLogger("encrypted_chat.client")

    # --- Request / notification routing -----------------------------------
    # request_id -> Future (one-shot reply) or Queue (streamed replies)
    self._pending: dict[str, asyncio.Future] = {}
    self._response_queue: asyncio.Queue = asyncio.Queue()
    self._notification_queue: asyncio.Queue = asyncio.Queue()

    # --- Account / login ---------------------------------------------------
    self.session: dict | None = None
    self.username: str = ""
    self.email: str = ""
    self.private_key = None  # RSA private key (login only)
    self.public_key = None  # RSA public key (login only)
    self.login_rejected: bool = False

    # --- Device pairing ----------------------------------------------------
    self._pairing_temp_private_key = None
    self._pairing_fingerprint: str = ""
    self._pairing_code: str = ""
    self._reencrypt_progress_cb = None

    # --- Signal Protocol keys and sessions ---------------------------------
    self.identity_private = None  # Ed25519PrivateKey
    self.identity_public = None  # Ed25519PublicKey
    self.spk_private = None  # X25519PrivateKey (current signed prekey)
    self.spk_id: str = ""
    self._prev_spk_private = None  # Previous SPK kept for a grace period
    self._prev_spk_id: str = ""
    self.opk_privates: dict[str, object] = {}  # id -> X25519PrivateKey
    self.sessions: dict[str, DoubleRatchet] = {}  # "user_id:device_id" -> ratchet
    self.sender_key_states: dict[str, SenderKeyState] = {}  # conv_id -> own sender key
    self.recv_sender_keys: dict[str, SenderKeyState] = {}  # "conv_id:sender_id:device_id" -> their key

    # --- Local-storage encryption keys -------------------------------------
    self._cache_key: bytes | None = None  # AES key for encrypting message cache on disk
    self._local_key: bytes | None = None  # AES key for encrypting session/sender key files

    # --- Multi-device ------------------------------------------------------
    self.device_id: str | None = None  # This device's UUID
    self._device_bundle_cache: dict[str, tuple[float, list[dict]]] = {}  # user_id -> (ts, bundles)
    self._pending_self_encrypt: list[dict] = []  # received messages queued for self-encryption

    # user_id -> {identity_key (Ed25519PublicKey), username, email};
    # bounded to 10K entries with a 1-hour TTL so it cannot grow without limit.
    self._user_cache: _ExpiringLRUCache = _ExpiringLRUCache(max_size=10_000, ttl=3600.0)

    # --- Typing indicators -------------------------------------------------
    self._typing_active: dict[str, bool] = {}
    self._typing_last_sent: dict[str, float] = {}
    self._typing_stop_tasks: dict[str, asyncio.Task] = {}

    # --- Contact key verification (TOFU + explicit) ------------------------
    self._known_identity_keys: dict = {}  # user_id -> {identity_key hex, first_seen, last_seen}
    self._verified_contacts: dict = {}  # user_id -> {identity_key hex, verified_at, method}
    self._key_change_cb = None  # callback(user_id, username, old_key_hex, was_verified)
|
|
|
|
async def connect(self):
    """Open the (optionally TLS-wrapped) connection to the chat server.

    Settings come from environment variables:
      - SERVER_HOST (default 127.0.0.1) and SERVER_PORT (default 9999);
      - TLS_ENABLED and TLS_REQUIRED;
      - TLS_INSECURE, honoured only when ENVIRONMENT is dev/development;
      - TLS_CA_FILE.
    Flag variables accept 1/true/yes, case-insensitive. TCP keepalive is
    enabled on the socket where the platform supports it.

    Raises:
        RuntimeError: TLS_REQUIRED without TLS_ENABLED, or TLS_INSECURE outside dev.
    """
    import socket

    def env_flag(name: str) -> bool:
        return os.getenv(name, "false").lower() in ("1", "true", "yes")

    host = os.getenv("SERVER_HOST", "127.0.0.1")
    port = int(os.getenv("SERVER_PORT", "9999"))
    tls_enabled = env_flag("TLS_ENABLED")
    tls_required = env_flag("TLS_REQUIRED")
    ssl_context = None
    if tls_required and not tls_enabled:
        raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.")
    if not tls_enabled:
        self._logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.")
    else:
        insecure = env_flag("TLS_INSECURE")
        is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development")
        if insecure and not is_dev:
            raise RuntimeError("TLS_INSECURE is only allowed when ENVIRONMENT=dev")
        ssl_context = ssl.create_default_context()
        ca_file = os.getenv("TLS_CA_FILE", "").strip()
        if ca_file:
            ssl_context.load_verify_locations(cafile=ca_file)
        elif insecure:
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE

    stream_reader, stream_writer = await asyncio.open_connection(
        host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context
    )

    # TCP keepalive lets dead connections through NAT/firewalls be detected.
    sock = stream_writer.get_extra_info("socket")
    if sock is not None:
        keepalive_tuning = (("TCP_KEEPIDLE", 25), ("TCP_KEEPINTVL", 10), ("TCP_KEEPCNT", 3))
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            for option_name, option_value in keepalive_tuning:
                if hasattr(socket, option_name):
                    sock.setsockopt(socket.IPPROTO_TCP, getattr(socket, option_name), option_value)
        except OSError:
            pass

    self.reader = ProtocolReader(stream_reader)
    self.writer = ProtocolWriter(stream_writer)
    self.raw_writer = stream_writer
    self.connected = True
    self._logger.info("Connected to %s:%s (tls=%s)", host, port, "on" if tls_enabled else "off")
|
|
|
|
def server_endpoint(self) -> str:
    """Return ``"host:port"`` from SERVER_HOST / SERVER_PORT (defaults 127.0.0.1 / 9999)."""
    return "{}:{}".format(
        os.getenv("SERVER_HOST", "127.0.0.1"),
        os.getenv("SERVER_PORT", "9999"),
    )
|
|
|
|
def pairing_fingerprint(self) -> str:
|
|
return self._pairing_fingerprint
|
|
|
|
def pairing_qr_data(self) -> bytes | None:
|
|
if not self._pairing_code or not self._pairing_fingerprint:
|
|
return None
|
|
return encode_pairing_qr(self._pairing_code, self._pairing_fingerprint)
|
|
|
|
async def _background_listener(self):
|
|
"""Read messages from server, routing responses vs notifications."""
|
|
while True:
|
|
msg = await self.reader.read_message()
|
|
if msg is None:
|
|
self.connected = False
|
|
# Fail all pending futures so send_and_recv doesn't hang
|
|
pending = dict(self._pending)
|
|
self._pending.clear()
|
|
err = ConnectionError("Server connection lost")
|
|
for obj in pending.values():
|
|
if isinstance(obj, asyncio.Queue):
|
|
# Signal stream consumers that connection died
|
|
obj.put_nowait({"status": "error", "data": {"message": "Connection lost"}})
|
|
elif not obj.done():
|
|
obj.set_exception(err)
|
|
break
|
|
# Responses to our own requests (have request_id matching a pending future)
|
|
# must be routed to the pending future, even if the type matches a notification name.
|
|
req_id = msg.get("request_id")
|
|
if req_id and req_id in self._pending:
|
|
pending_obj = self._pending[req_id]
|
|
if isinstance(pending_obj, asyncio.Queue):
|
|
await pending_obj.put(msg)
|
|
else:
|
|
self._pending.pop(req_id)
|
|
if not pending_obj.done():
|
|
pending_obj.set_result(msg)
|
|
elif msg.get("type") in ("new_message", "messages_read", "message_deleted",
|
|
"conversation_created", "member_added", "member_removed",
|
|
"user_online", "user_offline", "online_users",
|
|
"group_invitation", "conversation_renamed",
|
|
"device_added",
|
|
"session_reset",
|
|
"message_reacted", "message_pinned", "message_unpinned",
|
|
"message_delivered", "username_changed",
|
|
"avatar_changed", "keys_updated",
|
|
"typing_start", "typing_stop"):
|
|
await self._notification_queue.put(msg)
|
|
else:
|
|
await self._response_queue.put(msg)
|
|
|
|
async def send_and_recv(self, msg_type: str, timeout: float = 30.0, **kwargs) -> dict:
    """Send a request and wait for the response with the same request_id.

    The future is registered in ``_pending`` before the request is written,
    so the background listener can resolve it however quickly the server
    replies. The entry is removed again whenever this call finishes. That
    includes an exception from ``send_request`` that is not handled here; it
    propagates, but does not leave a stale future behind.

    Args:
        msg_type: request type sent to the server.
        timeout: seconds to wait for the response.
        **kwargs: extra request fields passed to ``send_request``.

    Returns:
        The server's response dict. If sending fails, the connection drops,
        or no reply arrives within ``timeout``, returns a synthesized
        ``{"type", "status": "error", "data": {"message"}}`` dict instead.
    """
    def _error(message: str) -> dict:
        return {
            "type": msg_type,
            "status": "error",
            "data": {"message": message},
        }

    request_id = str(uuid.uuid4())
    fut = asyncio.get_running_loop().create_future()
    self._pending[request_id] = fut
    try:
        try:
            await self.writer.send_request(msg_type, request_id=request_id, **kwargs)
        except (ValueError, ConnectionError, OSError) as e:
            return _error(str(e) or "Connection lost.")

        try:
            return await asyncio.wait_for(fut, timeout=timeout)
        except asyncio.TimeoutError:
            self._logger.warning("send_and_recv timeout for '%s' after %.0fs", msg_type, timeout)
            return _error(f"Request timed out ({msg_type})")
        except ConnectionError:
            # Set by the background listener when the connection drops.
            return _error("Connection lost.")
    finally:
        self._pending.pop(request_id, None)
|
|
|
|
# ------------------------------------------------------------------
|
|
# User info / identity key cache
|
|
# ------------------------------------------------------------------
|
|
|
|
async def _get_user_info(self, user_id: str = "", email: str = "") -> dict | None:
|
|
"""Get user info from server, cache identity key. Performs TOFU check."""
|
|
cached = self._user_cache.get(user_id)
|
|
if cached:
|
|
return cached
|
|
kwargs = {}
|
|
if user_id:
|
|
kwargs["user_id"] = user_id
|
|
elif email:
|
|
kwargs["email"] = email
|
|
else:
|
|
return None
|
|
resp = await self.send_and_recv("get_user_info", **kwargs)
|
|
if resp["status"] != "ok":
|
|
return None
|
|
data = resp["data"]
|
|
ik_bytes = decode_binary(data["identity_key"]) if data.get("identity_key") else None
|
|
info = {
|
|
"user_id": data["user_id"],
|
|
"username": data["username"],
|
|
"email": data["email"],
|
|
"identity_key": load_ed25519_public(ik_bytes) if ik_bytes else None,
|
|
"identity_key_bytes": ik_bytes,
|
|
}
|
|
# TOFU: check identity key against known keys
|
|
if ik_bytes:
|
|
status = self.check_identity_key(data["user_id"], ik_bytes)
|
|
info["identity_key_status"] = status
|
|
self._user_cache[data["user_id"]] = info
|
|
return info
|
|
|
|
# ------------------------------------------------------------------
|
|
# Contact Key Verification
|
|
# ------------------------------------------------------------------
|
|
|
|
def _load_verification_stores(self):
    """Load the TOFU key registry and verified-contacts store from disk.

    Does nothing while no account email is set, because both loaders need it.
    """
    owner = self.email
    if owner:
        self._known_identity_keys = _load_known_identity_keys(owner, self._local_key)
        self._verified_contacts = _load_verified_contacts(owner, self._local_key)
|
|
|
|
def check_identity_key(self, user_id: str, identity_key_bytes: bytes) -> str:
    """Check a user's identity key against the TOFU registry.

    On first sight the key is recorded. A matching key refreshes
    ``last_seen``; both cases persist the registry when an email is set. A
    changed key is NOT recorded: the registry keeps the old key until
    accept_key_change() is called. ``_key_change_cb`` (if set) is invoked as
    ``(user_id, username, old_key_hex, was_verified, new_key_bytes)``. An
    exception from that callback is logged and does not affect the result.

    Returns:
        "new" — first contact, key recorded (TOFU)
        "trusted" — key matches previously seen, not explicitly verified
        "verified" — key matches and explicitly verified
        "changed" — key differs from recorded (WARNING)
        "changed_verified" — key changed AND was previously verified (CRITICAL)
    """
    ik_hex = identity_key_bytes.hex()
    now = datetime.now(timezone.utc).isoformat()
    known = self._known_identity_keys.get(user_id)

    if known is None:
        # First time seeing this user: trust on first use.
        self._known_identity_keys[user_id] = {
            "identity_key": ik_hex,
            "first_seen": now,
            "last_seen": now,
        }
        if self.email:
            _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key)
        return "new"

    if known["identity_key"] == ik_hex:
        known["last_seen"] = now
        if self.email:
            _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key)
        verified = self._verified_contacts.get(user_id)
        if verified and verified.get("identity_key") == ik_hex:
            return "verified"
        return "trusted"

    # The key has CHANGED. Leave the registry untouched until the user accepts it.
    was_verified = user_id in self._verified_contacts
    old_key_hex = known["identity_key"]

    if self._key_change_cb:
        cached = self._user_cache.get(user_id)
        username = cached.get("username", "") if cached else ""
        try:
            self._key_change_cb(user_id, username, old_key_hex, was_verified, identity_key_bytes)
        except Exception:
            # A broken UI callback must not break the key check, but a
            # security warning that failed to display should not vanish silently.
            self._logger.exception("Identity key change callback failed for %s", user_id)

    return "changed_verified" if was_verified else "changed"
|
|
|
|
def verify_contact(self, user_id: str, identity_key_bytes: bytes, method: str = "manual"):
    """Record ``identity_key_bytes`` as the explicitly verified key for a user.

    Also makes sure the TOFU registry knows the user: a new entry is added,
    or an existing one has ``last_seen`` refreshed (its key is left as is).
    Both stores are persisted when an email is set, and a cached profile is
    marked "verified".
    """
    key_hex = identity_key_bytes.hex()
    stamp = datetime.now(timezone.utc).isoformat()
    self._verified_contacts[user_id] = {
        "identity_key": key_hex,
        "verified_at": stamp,
        "method": method,
    }

    registry = self._known_identity_keys
    if user_id in registry:
        registry[user_id]["last_seen"] = stamp
    else:
        registry[user_id] = {
            "identity_key": key_hex,
            "first_seen": stamp,
            "last_seen": stamp,
        }

    if self.email:
        _save_verified_contacts(self.email, self._verified_contacts, self._local_key)
        _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key)

    profile = self._user_cache.get(user_id)
    if profile:
        profile["identity_key_status"] = "verified"
|
|
|
|
def unverify_contact(self, user_id: str):
    """Remove the explicit verification for ``user_id``.

    The TOFU registry is not touched, so a cached profile that was marked
    "verified" drops back to "trusted".
    """
    self._verified_contacts.pop(user_id, None)
    if self.email:
        _save_verified_contacts(self.email, self._verified_contacts, self._local_key)

    profile = self._user_cache.get(user_id)
    if not profile:
        return
    if profile.get("identity_key_status") == "verified":
        profile["identity_key_status"] = "trusted"
|
|
|
|
def accept_key_change(self, user_id: str, new_ik_bytes: bytes):
    """Trust a user's new identity key after a reported change.

    The TOFU entry is replaced (``first_seen`` restarts now), and any earlier
    explicit verification is dropped, so the user must be verified again.
    Both stores are persisted when an email is set, and a cached profile is
    marked "trusted".
    """
    key_hex = new_ik_bytes.hex()
    stamp = datetime.now(timezone.utc).isoformat()
    self._known_identity_keys[user_id] = {
        "identity_key": key_hex,
        "first_seen": stamp,
        "last_seen": stamp,
    }
    self._verified_contacts.pop(user_id, None)

    if self.email:
        _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key)
        _save_verified_contacts(self.email, self._verified_contacts, self._local_key)

    profile = self._user_cache.get(user_id)
    if profile:
        profile["identity_key_status"] = "trusted"
|
|
|
|
def get_verification_status(self, user_id: str) -> str:
    """Classify how much a user's identity key is trusted.

    "verified" requires an explicit verification whose key still equals the
    TOFU-recorded key. Otherwise the result is "trusted" if a TOFU record
    exists, else "unverified".
    """
    verified = self._verified_contacts.get(user_id)
    known = self._known_identity_keys.get(user_id)
    if verified and known and known.get("identity_key") == verified.get("identity_key"):
        return "verified"
    return "trusted" if user_id in self._known_identity_keys else "unverified"
|
|
|
|
def get_safety_number(self, peer_user_id: str) -> str | None:
|
|
"""Get formatted safety number for a peer (requires both identity keys)."""
|
|
if not self.identity_public or not self.session:
|
|
return None
|
|
my_uid = self.session.get("user_id", "")
|
|
my_ik_bytes = serialize_ed25519_public(self.identity_public)
|
|
cached = self._user_cache.get(peer_user_id)
|
|
if not cached or not cached.get("identity_key_bytes"):
|
|
return None
|
|
return compute_safety_number(my_uid, my_ik_bytes,
|
|
peer_user_id, cached["identity_key_bytes"])
|
|
|
|
def get_my_fingerprint(self) -> str | None:
|
|
"""Get formatted fingerprint for own identity key."""
|
|
if not self.identity_public or not self.session:
|
|
return None
|
|
my_uid = self.session.get("user_id", "")
|
|
my_ik_bytes = serialize_ed25519_public(self.identity_public)
|
|
fp = compute_fingerprint(my_uid, my_ik_bytes)
|
|
return format_fingerprint(fp)
|
|
|
|
def get_peer_fingerprint(self, peer_user_id: str) -> str | None:
|
|
"""Get formatted fingerprint for a peer's identity key."""
|
|
cached = self._user_cache.get(peer_user_id)
|
|
if not cached or not cached.get("identity_key_bytes"):
|
|
return None
|
|
fp = compute_fingerprint(peer_user_id, cached["identity_key_bytes"])
|
|
return format_fingerprint(fp)
|
|
|
|
def get_verification_qr_data(self) -> bytes | None:
|
|
"""Get QR code payload bytes for own identity (for peer to scan)."""
|
|
if not self.identity_public or not self.session:
|
|
return None
|
|
my_uid = self.session.get("user_id", "")
|
|
my_ik_bytes = serialize_ed25519_public(self.identity_public)
|
|
return encode_verification_qr(my_uid, my_ik_bytes)
|
|
|
|
def verify_qr_code(self, qr_data: bytes) -> tuple[bool, str, str]:
    """Check a scanned identity QR code against the cached key for that user.

    On a match the contact is marked verified with method "qr_code".

    Returns:
        (success, user_id, human-readable message). ``user_id`` is "" when
        the payload could not be decoded.
    """
    try:
        user_id, scanned_ik = decode_verification_qr(qr_data)
    except ValueError as exc:
        return False, "", f"Invalid QR code: {exc}"

    profile = self._user_cache.get(user_id)
    if not profile:
        failure = "Unknown user — not in your contacts."
    elif not profile.get("identity_key_bytes"):
        failure = "No identity key on record for this user."
    elif profile["identity_key_bytes"] != scanned_ik:
        failure = "Identity key MISMATCH — verification failed!"
    else:
        failure = None
    if failure is not None:
        return False, user_id, failure

    self.verify_contact(user_id, scanned_ik, method="qr_code")
    name = profile.get("username", user_id[:8])
    return True, user_id, f"Verified {name} via QR code."
|
|
|
|
# ------------------------------------------------------------------
|
|
# Registration
|
|
# ------------------------------------------------------------------
|
|
|
|
async def register(self, username: str, password: str, email: str) -> tuple[bool, str]:
    """Start account registration for ``email``.

    The RSA (login) and Ed25519 (identity) keypairs are reused if they can
    be loaded from disk; otherwise new ones are generated in memory. Nothing
    is saved here. The password is held in ``_reg_password`` until
    confirm_registration() writes the keys. If the server demands
    proof-of-work, the challenge is solved and the request is sent once more.

    Returns:
        (True, code) when the server returns a code directly,
        (True, message) otherwise, or (False, error message) on failure.
    """
    self.username = username
    self.email = email
    secret = bytearray(password.encode("utf-8")) if password else None

    try:
        pwd = bytes(secret) if secret else None

        # Reuse keys left by an earlier successful registration, if any.
        rsa_priv, rsa_pub, _err = load_keys(email, password=pwd)
        if rsa_priv is None:
            rsa_priv, rsa_pub = generate_rsa_keypair()
        self.private_key, self.public_key = rsa_priv, rsa_pub

        try:
            ed_priv, ed_pub = _load_identity_keys(email, password=pwd)
        except Exception:
            ed_priv = ed_pub = None
        if ed_priv is None:
            ed_priv, ed_pub = generate_identity_keypair()
        self.identity_private, self.identity_public = ed_priv, ed_pub
        self._cache_key = derive_self_encryption_key(ed_priv)
        self._local_key = derive_local_storage_key(ed_priv)

        # Kept only until confirm_registration() persists the keys.
        self._reg_password = pwd
    finally:
        if secret:
            secret[:] = b'\x00' * len(secret)

    fields = {
        "username": username,
        "public_key": serialize_public_key(rsa_pub).decode("utf-8"),
        "email": email,
        "identity_key": encode_binary(serialize_ed25519_public(ed_pub)),
    }
    start = await self.send_and_recv("register", **fields)

    if start.get("status") == "pow_required":
        # Server is under pressure: solve the challenge and resend once.
        pow_info = start["data"]
        challenge = pow_info["challenge"]
        mac = pow_info["mac"]
        difficulty = pow_info["difficulty"]
        logger.info("Server requires proof-of-work (difficulty %d), solving...", difficulty)
        nonce = _solve_pow(challenge, difficulty)
        start = await self.send_and_recv(
            "register",
            **fields,
            pow_challenge=challenge,
            pow_mac=mac,
            pow_nonce=nonce,
        )

    if start["status"] != "ok":
        self._reg_password = None
        return False, start["data"]["message"]

    body = start["data"]
    code = body.get("code")
    if code:
        return True, code
    return True, body.get("message", "Check your email for the code.")
|
|
|
|
async def confirm_registration(self, email: str, username: str, code: str) -> tuple[bool, str]:
    """Submit the emailed registration code; on success save the pending keys.

    The keys prepared by register() are written to disk only here, using the
    password register() stored in ``_reg_password``, which is then cleared.
    """
    reply = await self.send_and_recv("register_confirm", email=email, code=code)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]

    pwd = getattr(self, "_reg_password", None)
    save_keys(email, self.private_key, self.public_key, password=pwd)
    _save_identity_keys(email, self.identity_private, self.identity_public, password=pwd)
    self._reg_password = None
    self._load_verification_stores()
    return True, f"Registered as '{username}' (ID: {reply['data']['user_id']})"
|
|
|
|
async def _generate_and_upload_prekeys(self, keep_spk: bool = False):
    """Generate an SPK and a batch of OPKs, save them locally, and upload them.

    With ``keep_spk=True`` (and an SPK already loaded), the existing SPK is
    re-signed instead of replaced. This is used after device pairing so both
    devices share one SPK and either can answer X3DH. Otherwise the current
    SPK is kept as the "previous" SPK for the grace period, and a new one is
    generated. A failed upload is only logged.
    """
    if not self.identity_private:
        return

    if keep_spk and self.spk_private and self.spk_id:
        # Re-sign the existing SPK (both devices share the identity key).
        spk_pub_bytes = serialize_x25519_public(self.spk_private.public_key())
        spk_sig = ed25519_sign(self.identity_private, spk_pub_bytes)
        spk_data = {
            "id": self.spk_id,
            "public_key": encode_binary(spk_pub_bytes),
            "signature": encode_binary(spk_sig),
        }
    else:
        if self.spk_private and self.spk_id:
            # Keep the outgoing SPK for the grace period (M4: in-flight X3DH).
            self._prev_spk_private = self.spk_private
            self._prev_spk_id = self.spk_id
            _save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key)
        fresh = generate_signed_prekey(self.identity_private)
        self.spk_private = fresh["private"]
        self.spk_id = fresh["id"]
        _save_spk(self.email, fresh["private"], fresh["id"], self._local_key)
        spk_data = {
            "id": fresh["id"],
            "public_key": encode_binary(serialize_x25519_public(fresh["public"])),
            "signature": encode_binary(fresh["signature"]),
        }

    batch = generate_one_time_prekeys(OPK_BATCH_SIZE)
    for item in batch:
        self.opk_privates[item["id"]] = item["private"]
        _save_opk_private(self.email, item["id"], item["private"], self._local_key)
    public_opks = [
        {"id": item["id"], "public_key": encode_binary(serialize_x25519_public(item["public"]))}
        for item in batch
    ]

    resp = await self.send_and_recv(
        "upload_prekeys",
        signed_prekey=spk_data,
        one_time_prekeys=public_opks,
    )
    if resp.get("status") != "ok":
        reason = resp.get("data", {}).get("message", "unknown")
        self._logger.warning("upload_prekeys failed: %s (will retry on login)", reason)
|
|
|
|
async def _ensure_prekeys(self):
    """Rotate the SPK and/or replenish OPKs when the server reports a need.

    Asks the server (``get_prekey_count``) how many one-time prekeys remain
    for this device and when its signed prekey was created. A new SPK is
    needed in three cases: none is on the server, its timestamp cannot be
    parsed, or it is at least ``SPK_ROTATION_DAYS`` old. If a new SPK is
    needed, or fewer than ``OPK_REPLENISH_THRESHOLD`` OPKs remain, a fresh
    batch is uploaded in one round-trip by _generate_and_upload_prekeys_batch().
    Does nothing when the count request fails.
    """
    resp = await self.send_and_recv("get_prekey_count")
    if resp["status"] != "ok":
        return

    info = resp["data"]
    # A null count from the server is treated like "no OPKs left".
    count = info.get("count") or 0
    spk_created_at = info.get("spk_created_at", "")

    need_new_spk = False
    if not spk_created_at:
        # No SPK on the server for this device: one must be uploaded.
        need_new_spk = True
        self._logger.info("No SPK on server for this device, uploading...")
    else:
        try:
            created = datetime.fromisoformat(spk_created_at)
        except (TypeError, ValueError):
            # Rotating is safer than keeping an SPK of unknown age.
            self._logger.warning("Unparseable spk_created_at %r, rotating SPK", spk_created_at)
            need_new_spk = True
        else:
            if created.tzinfo is None:
                # Naive server timestamps are interpreted as UTC.
                created = created.replace(tzinfo=timezone.utc)
            age_days = (datetime.now(timezone.utc) - created).days
            if age_days >= SPK_ROTATION_DAYS:
                need_new_spk = True
                self._logger.info("SPK is %d days old, rotating...", age_days)

    if count >= OPK_REPLENISH_THRESHOLD and not need_new_spk:
        return
    if count >= OPK_REPLENISH_THRESHOLD:
        self._logger.info("SPK rotation triggered (OPK count OK: %d)", count)
    else:
        self._logger.info("OPK count low (%d), replenishing...", count)
    await self._generate_and_upload_prekeys_batch(need_new_spk)
|
|
|
|
async def _generate_and_upload_prekeys_batch(self, need_new_spk: bool = False):
    """Generate a fresh OPK batch (and optionally a new SPK) and upload it.

    Uses the single-roundtrip ``ensure_prekeys`` request. When
    ``need_new_spk`` is set, the current SPK is first kept as the previous
    SPK for the grace period, and then a new SPK is generated. All private
    keys are saved locally before the upload. A failed upload is logged, as
    in _generate_and_upload_prekeys(); _ensure_prekeys() checks again the
    next time it runs (it is started at login).
    """
    if not self.identity_private:
        return

    kwargs: dict = {}

    if need_new_spk:
        if self.spk_private and self.spk_id:
            # Keep the outgoing SPK so in-flight X3DH messages still decrypt.
            self._prev_spk_private = self.spk_private
            self._prev_spk_id = self.spk_id
            _save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key)
        spk = generate_signed_prekey(self.identity_private)
        self.spk_private = spk["private"]
        self.spk_id = spk["id"]
        _save_spk(self.email, spk["private"], spk["id"], self._local_key)
        kwargs["signed_prekey"] = {
            "id": spk["id"],
            "public_key": encode_binary(serialize_x25519_public(spk["public"])),
            "signature": encode_binary(spk["signature"]),
        }

    opks = generate_one_time_prekeys(OPK_BATCH_SIZE)
    for opk in opks:
        self.opk_privates[opk["id"]] = opk["private"]
        _save_opk_private(self.email, opk["id"], opk["private"], self._local_key)
    kwargs["one_time_prekeys"] = [
        {"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))}
        for opk in opks
    ]

    resp = await self.send_and_recv("ensure_prekeys", **kwargs)
    data = resp.get("data") or {}
    if resp.get("status") == "ok":
        self._logger.info("ensure_prekeys: count=%d, spk_uploaded=%s, otps_uploaded=%d",
                          data.get("count", 0), data.get("uploaded_spk", False),
                          data.get("uploaded_otps", 0))
    else:
        self._logger.warning("ensure_prekeys failed: %s (will retry on login)",
                             data.get("message", "unknown"))
|
|
|
|
# ------------------------------------------------------------------
|
|
# Login
|
|
# ------------------------------------------------------------------
|
|
|
|
async def login(self, email: str, password: str) -> tuple[bool, str]:
    """Log in with the keys stored on this device. Returns (success, message).

    Steps:
      1. Refuse early while the local brute-force lockout is active.
      2. Unlock the RSA login key and the Ed25519 identity key with
         ``password``. A password error from the key store counts toward
         the lockout.
      3. Restore the SPK, the previous SPK (rotation grace period) and the
         device_id from disk.
      4. Answer the server's RSA challenge (login_start / login_finish).
      5. On success, store the session and device_id, and start prekey
         maintenance in the background.

    The background task is kept in ``self._background_tasks`` until it
    finishes. The event loop holds only weak references to tasks, so an
    unreferenced task can be garbage-collected before it completes.
    """
    self.email = email

    remaining = _check_lockout(email)
    if remaining > 0:
        return False, f"Too many failed attempts. Try again in {remaining:.0f}s."

    pwd_bytes = bytearray(password.encode("utf-8")) if password else None
    try:
        priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
        if priv is None:
            # Only a wrong password counts as a failed attempt, not missing keys.
            if err and "password" in err.lower():
                _record_failed_attempt(email)
            return False, err or "No local keys found. Register first."
        self.private_key = priv
        self.public_key = pub
        ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
    finally:
        # Best-effort scrub of the mutable password copy.
        if pwd_bytes:
            pwd_bytes[:] = b'\x00' * len(pwd_bytes)

    if ed_priv is not None:
        self.identity_private = ed_priv
        self.identity_public = ed_pub
        self._cache_key = derive_self_encryption_key(ed_priv)
        self._local_key = derive_local_storage_key(ed_priv)
        self._load_verification_stores()

    spk_priv, spk_id = _load_spk(email, self._local_key)
    if spk_priv:
        self.spk_private = spk_priv
        self.spk_id = spk_id

    # Previous SPK for the rotation grace period (M4).
    prev_spk_priv, prev_spk_id = _load_prev_spk(email, self._local_key)
    if prev_spk_priv:
        self._prev_spk_private = prev_spk_priv
        self._prev_spk_id = prev_spk_id

    self.device_id = _load_device_id(email)

    # RSA challenge-response login.
    start = await self.send_and_recv("login_start", email=email)
    if start["status"] != "ok":
        return False, start["data"]["message"]

    challenge = decode_binary(start["data"]["challenge"])
    signature = rsa_sign(self.private_key, challenge)
    login_kwargs = {"email": email, "signature": encode_binary(signature),
                    "client_version": VERSION}
    if self.device_id:
        login_kwargs["device_id"] = self.device_id
    finish = await self.send_and_recv("login_finish", **login_kwargs)
    if finish["status"] != "ok":
        return False, finish["data"]["message"]

    self.session = finish["data"]
    self.username = self.session.get("username", "")
    self.device_id = finish["data"].get("device_id")
    if self.device_id:
        _save_device_id(email, self.device_id)

    # After pairing, the new device has no local OPK private keys, so the
    # server-side OPKs are unusable here: generate fresh ones, but keep the
    # shared SPK (keep_spk=True) so both devices can still answer X3DH.
    opk_dir = get_key_dir(self.email) / "opk_private"
    has_local_opks = opk_dir.exists() and any(opk_dir.iterdir())
    if has_local_opks:
        maintenance = self._ensure_prekeys()
    else:
        self._logger.info("No local OPKs (likely new device). Generating fresh OPKs, keeping SPK.")
        maintenance = self._generate_and_upload_prekeys(keep_spk=True)
    task = asyncio.create_task(maintenance)
    tasks = getattr(self, "_background_tasks", None)
    if tasks is None:
        tasks = self._background_tasks = set()
    tasks.add(task)
    task.add_done_callback(tasks.discard)

    _clear_lockout(email)
    return True, f"Logged in as '{self.username}' (ID: {self.session['user_id']})"
|
|
|
|
# ------------------------------------------------------------------
|
|
# Pairing (device pairing — transfers RSA + identity keys)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def pairing_start(self, email: str) -> tuple[bool, str]:
    """Begin pairing this (new) device. Returns (success, code or error message).

    Generates an ephemeral X25519 keypair, derives the fingerprint that the
    authorizing device must confirm, and registers the public half with the
    server. On success the pairing code and poll token are kept for
    pairing_wait().
    """
    self._logger.info("pairing_start via %s for %s", self.server_endpoint(), email.strip().lower())
    eph_priv, eph_pub = generate_x25519_keypair()
    self._pairing_temp_private_key = eph_priv
    eph_pub_raw = serialize_x25519_public(eph_pub)
    self._pairing_fingerprint = compute_pairing_fingerprint(eph_pub_raw)

    resp = await self.send_and_recv(
        "pairing_start",
        email=email,
        temp_public_key=encode_binary(eph_pub_raw),
        temp_key_type="x25519",
    )
    if resp["status"] != "ok":
        # Clearing these makes pairing_qr_data() stop offering a QR code.
        self._pairing_fingerprint = self._pairing_code = ""
        return False, resp["data"]["message"]

    data = resp["data"]
    self._pairing_code = data["code"]
    self._pairing_poll_token = data.get("poll_token", "")
    return True, data["code"]
|
|
|
|
async def pairing_wait(self, code: str, email: str, password: str, timeout: int = 300) -> tuple[bool, str]:
    """Poll for the pairing payload, then import the account keys it carries.

    The payload is decrypted with a key derived from an X25519 exchange
    between our ephemeral pairing key and the sender's ephemeral key. The
    RSA and identity private keys inside are saved to disk under
    ``password`` and installed on this client.

    Polls every 2 seconds for up to ``timeout`` seconds. Returns
    (success, message). Server errors, import failures and timeouts clear the
    pairing code and fingerprint; success also drops the ephemeral key.
    """
    if not self._pairing_temp_private_key:
        return False, "Pairing not started."

    from crypto_utils import aes_decrypt as _aes_decrypt

    poll_token = getattr(self, "_pairing_poll_token", "")
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout

    while loop.time() < deadline:
        resp = await self.send_and_recv("pairing_poll", code=code, poll_token=poll_token)
        if resp["status"] != "ok":
            self._pairing_fingerprint = self._pairing_code = ""
            return False, resp["data"]["message"]
        if not resp["data"].get("ready"):
            await asyncio.sleep(2.0)
            continue

        payload = resp["data"]["payload"]
        try:
            peer_pub_raw = decode_binary(payload["sender_public_key"])
            peer_pub = load_x25519_public(peer_pub_raw)
            own_pub_raw = serialize_x25519_public(self._pairing_temp_private_key.public_key())
            secret = x25519_dh(self._pairing_temp_private_key, peer_pub)
            aes_key = derive_pairing_shared_key(secret, own_pub_raw, peer_pub_raw)
            plaintext = _aes_decrypt(
                aes_key,
                decode_binary(payload["iv"]),
                decode_binary(payload["ciphertext"]),
                decode_binary(payload["tag"]),
            )
            keys_data = json.loads(plaintext)

            pwd_buf = bytearray(password.encode("utf-8")) if password else None
            try:
                pwd = bytes(pwd_buf) if pwd_buf else None
                rsa_priv = load_private_key(keys_data["rsa_private"].encode(), password=None)
                rsa_pub = rsa_priv.public_key()
                save_keys(email, rsa_priv, rsa_pub, password=pwd)

                ed_priv = load_ed25519_private(bytes.fromhex(keys_data["identity_private"]))
                ed_pub = ed_priv.public_key()
                _save_identity_keys(email, ed_priv, ed_pub, password=pwd)
            finally:
                if pwd_buf:
                    pwd_buf[:] = b'\x00' * len(pwd_buf)

            self.email = email
            self.private_key, self.public_key = rsa_priv, rsa_pub
            self.identity_private, self.identity_public = ed_priv, ed_pub
            self._cache_key = derive_self_encryption_key(ed_priv)
            self._local_key = derive_local_storage_key(ed_priv)
            self._load_verification_stores()
            self._pairing_temp_private_key = None
            self._pairing_fingerprint = self._pairing_code = ""

            # No sessions or sender keys are imported: this device creates its
            # own SPK/OPKs at first login and builds independent ratchets.
            return True, "Pairing complete."
        except Exception as exc:
            self._pairing_fingerprint = self._pairing_code = ""
            return False, f"Failed to import keys: {exc}"

    self._pairing_fingerprint = self._pairing_code = ""
    return False, "Pairing timed out."
|
|
|
|
async def authorize_device(self, code: str, expected_fingerprint: str) -> tuple[bool, str]:
    """Authorize a new device by sending it this account's long-term keys.

    The new device's ephemeral X25519 key is fetched with ``pairing_claim``.
    Its fingerprint must match ``expected_fingerprint`` (30 digits) before
    anything is sent. Only the RSA login key and the Ed25519 identity key are
    sent; they are encrypted to that ephemeral key and delivered with
    ``pairing_send``. Sessions and sender keys are not transferred.

    On success, a history re-encryption starts after a random delay. The task
    is kept in ``self._background_tasks`` until it finishes, because the event
    loop holds only weak references to tasks.

    Returns (success, message).
    """
    if not self.private_key or not self.identity_private:
        return False, "Not logged in."
    expected_digits = normalize_pairing_fingerprint(expected_fingerprint)
    if len(expected_digits) != 30:
        return False, "Pairing fingerprint must contain 30 digits."

    claim = await self.send_and_recv("pairing_claim", code=code)
    if claim["status"] != "ok":
        return False, claim["data"]["message"]
    if claim["data"].get("temp_key_type") != "x25519":
        return False, "Unsupported pairing key type. Update both devices and try again."

    temp_pub_raw = decode_binary(claim["data"]["temp_public_key"])
    actual_fp = compute_pairing_fingerprint(temp_pub_raw)
    if normalize_pairing_fingerprint(actual_fp) != expected_digits:
        # The claimed key is not the one the user confirmed; send nothing.
        self._logger.warning("Pairing fingerprint mismatch for code %s", code[:8])
        return False, (
            "Pairing fingerprint mismatch. Verify the new device fingerprint and try again.\n"
            f"Expected: {expected_fingerprint}\n"
            f"Received: {actual_fp}"
        )
    temp_pub = load_x25519_public(temp_pub_raw)

    keys_data = {
        "rsa_private": serialize_private_key(self.private_key, password=None).decode(),
        "identity_private": serialize_ed25519_private_raw(self.identity_private).hex(),
    }

    # Deliver the keys before re-encrypting history: re-encryption can take a
    # while on large accounts, and the pairing code could expire meanwhile.
    plaintext = json.dumps(keys_data).encode()
    sender_priv, sender_pub = generate_x25519_keypair()
    sender_pub_raw = serialize_x25519_public(sender_pub)
    shared_secret = x25519_dh(sender_priv, temp_pub)
    pairing_key = derive_pairing_shared_key(shared_secret, sender_pub_raw, temp_pub_raw)
    _, nonce, ct, tag = aes_encrypt(plaintext, key=pairing_key)
    payload = {
        "sender_public_key": encode_binary(sender_pub_raw),
        "iv": encode_binary(nonce),
        "ciphertext": encode_binary(ct),
        "tag": encode_binary(tag),
    }
    resp = await self.send_and_recv("pairing_send", code=code, payload=payload)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]

    async def _reencrypt_after_pairing():
        try:
            delay = random.uniform(*PAIRING_REENCRYPT_INITIAL_DELAY_RANGE)
            self._logger.info("Delaying post-pairing history resync by %.1fs", delay)
            await asyncio.sleep(delay)
            await self.reencrypt_history()
        except Exception as e:
            self._logger.warning("Post-pairing re-encryption failed: %s", e)

    task = asyncio.create_task(_reencrypt_after_pairing())
    tasks = getattr(self, "_background_tasks", None)
    if tasks is None:
        tasks = self._background_tasks = set()
    tasks.add(task)
    task.add_done_callback(tasks.discard)
    return True, "Device authorized."
|
|
|
|
# ------------------------------------------------------------------
|
|
# Password change (local key re-encryption only)
|
|
# ------------------------------------------------------------------
|
|
|
|
def change_password(self, old_password: str, new_password: str) -> tuple[bool, str]:
    """Re-save the local RSA and identity keys under a new password.

    Only the on-disk key files change; nothing is sent to the server. An
    empty password is passed to the key store as None, the same as login(),
    register() and pairing_wait() do. Without this, keys saved with no
    password could not be unlocked here.

    Returns (success, message).
    """
    if not self.email:
        return False, "Not logged in."

    old_pwd = bytearray(old_password.encode("utf-8"))
    new_pwd = bytearray(new_password.encode("utf-8"))
    try:
        old_secret = bytes(old_pwd) or None
        new_secret = bytes(new_pwd) or None

        # 1. Verify the current password by unlocking both keys.
        priv, pub, _err = load_keys(self.email, password=old_secret)
        if priv is None:
            return False, "Wrong current password."
        ed_priv, ed_pub = _load_identity_keys(self.email, password=old_secret)
        if ed_priv is None:
            return False, "Failed to load identity key."

        # 2. Re-save both keys with the new password.
        save_keys(self.email, priv, pub, password=new_secret)
        _save_identity_keys(self.email, ed_priv, ed_pub, password=new_secret)
        return True, "Password changed successfully."
    finally:
        # Best-effort scrub of the mutable password copies.
        old_pwd[:] = b'\x00' * len(old_pwd)
        new_pwd[:] = b'\x00' * len(new_pwd)
|
|
|
|
async def change_username(self, new_username: str) -> tuple[bool, str]:
    """Change the account's display name on the server.

    The name is stripped of surrounding whitespace and must then be 1-100
    characters long. On success, the local username and session copy are
    updated with the name the server returns.
    """
    if not self.session:
        return False, "Not logged in."
    candidate = new_username.strip()
    if not 0 < len(candidate) <= 100:
        return False, "Username must be 1-100 characters."

    resp = await self.send_and_recv("change_username", username=candidate)
    if resp["status"] != "ok":
        return False, resp["data"].get("message", "Unknown error")

    self.username = resp["data"]["username"]
    if self.session:
        self.session["username"] = self.username
    return True, "Username changed."
|
|
|
|
# ------------------------------------------------------------------
|
|
# Key rotation (RSA login key only)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def rotate_keys(self, username: str, password: str) -> tuple[bool, str]:
    """Rotate the RSA login keypair, revoking other devices.

    Only the RSA key used for challenge-response login changes; identity and
    messaging keys are untouched. The new keypair is saved to disk and
    installed only after the server accepts the new public key. A rejected or
    undelivered request therefore leaves the current, still-valid keys in
    place instead of locking this device out.

    Caveat: if the server applies the rotation but the response is lost
    (e.g. a timeout), the new keys are not saved locally.
    """
    if not self.session or self.session.get("username") != username:
        return False, "Not logged in."

    new_priv, new_pub = generate_rsa_keypair()
    pub_pem = serialize_public_key(new_pub).decode("utf-8")
    resp = await self.send_and_recv("rotate_keys", public_key=pub_pem)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]

    pwd_bytes = password.encode("utf-8") if password else None
    save_keys(self.email, new_priv, new_pub, password=pwd_bytes)
    self.private_key = new_priv
    self.public_key = new_pub
    return True, "RSA login keys rotated."
|
|
|
|
# ------------------------------------------------------------------
|
|
# Session management (X3DH + Double Ratchet)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def _get_device_bundles(self, peer_user_id: str) -> list[dict]:
    """Return the peer's per-device key bundles, cached for 5 minutes.

    Every bundle carries the peer's account-wide ``identity_key``. When the
    response has no (or an empty) ``device_bundles`` list, as from older
    servers, the flat bundle fields are wrapped as a single bundle whose
    ``device_id`` is None.

    Raises:
        RuntimeError: if the server rejects the ``get_key_bundle`` request.
    """
    import time

    cached = self._device_bundle_cache.get(peer_user_id)
    if cached:
        fetched_at, bundles = cached
        if time.time() - fetched_at < 300:  # 5-minute TTL
            return bundles

    resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id)
    if resp["status"] != "ok":
        raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}")

    payload = resp["data"]
    identity_key = payload.get("identity_key", "")
    bundles = payload.get("device_bundles")
    if not bundles:
        # Old server: wrap the flat response as a single-entry list.
        bundles = [{
            "device_id": None,
            "identity_key": identity_key,
            "signed_prekey_id": payload.get("signed_prekey_id", ""),
            "signed_prekey": payload.get("signed_prekey", ""),
            "spk_signature": payload.get("spk_signature", ""),
            "one_time_prekey_id": payload.get("one_time_prekey_id"),
            "one_time_prekey": payload.get("one_time_prekey"),
        }]
    else:
        for bundle in bundles:
            # The identity key is sent once per account, so attach it to each device.
            bundle["identity_key"] = identity_key

    self._device_bundle_cache[peer_user_id] = (time.time(), bundles)
    return bundles
|
|
|
|
async def _get_or_create_session(self, peer_user_id: str,
                                 peer_device_id: str | None = None,
                                 bundle: dict | None = None) -> DoubleRatchet:
    """Load existing session or create one via X3DH.

    If peer_device_id is set, sessions are keyed by "user_id:device_id".
    If bundle is provided, it's used instead of fetching from server.

    Lookup order: the in-memory ``self.sessions`` cache, then the on-disk
    store (``_load_session``), and only then a new X3DH handshake in which we
    act as the initiator (``DoubleRatchet.init_alice``).

    A newly created ratchet carries its X3DH header in the attribute
    ``_x3dh_header``; callers attach it to the first outgoing message. The
    attribute is set *after* ``_save_session`` runs.
    NOTE(review): it is therefore presumably not part of the persisted
    state, so a restart before the first send would lose it. Confirm.

    Raises:
        RuntimeError: if the server refuses to return a key bundle.
        IdentityKeyChanged: if the peer's identity key is reported as
            changed by ``check_identity_key`` (TOFU).
    """
    session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id

    # Check in-memory cache
    if session_key in self.sessions:
        return self.sessions[session_key]

    # Check on disk
    ratchet = _load_session(self.email, peer_user_id, self._local_key,
                            peer_device_id=peer_device_id)
    if ratchet:
        self.sessions[session_key] = ratchet
        return ratchet

    # Create new session via X3DH
    if not bundle:
        # NOTE(review): this reads the flat bundle fields below. A server
        # response that only carries "device_bundles" (see
        # _get_device_bundles) would raise KeyError on "signed_prekey". Confirm.
        resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id)
        if resp["status"] != "ok":
            raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}")
        bundle = resp["data"]

    ik_remote_bytes = decode_binary(bundle["identity_key"])
    ik_remote = load_ed25519_public(ik_remote_bytes)

    # TOFU: verify identity key before using it in X3DH
    ik_status = self.check_identity_key(peer_user_id, ik_remote_bytes)
    if ik_status in ("changed", "changed_verified"):
        raise IdentityKeyChanged(peer_user_id, ik_remote_bytes, ik_status)

    spk_remote = load_x25519_public(decode_binary(bundle["signed_prekey"]))
    spk_sig = decode_binary(bundle["spk_signature"])

    # The one-time prekey is optional. The peer may have run out.
    # NOTE(review): the header below advertises opk_id whenever the id is
    # set, while X3DH uses the key only when the key is set. Both are
    # presumably always present together. Confirm.
    opk_remote = None
    opk_id = bundle.get("one_time_prekey_id")
    if bundle.get("one_time_prekey"):
        opk_remote = load_x25519_public(decode_binary(bundle["one_time_prekey"]))

    # Perform X3DH. spk_sig is passed in, presumably so x3dh_initiate can
    # verify the SPK signature against ik_remote. Confirm in crypto_utils.
    shared_secret, ek_priv, ek_pub = x3dh_initiate(
        self.identity_private,
        ik_remote,
        spk_remote,
        spk_sig,
        opk_remote,
    )

    # Initialize Double Ratchet as Alice
    ratchet = DoubleRatchet.init_alice(shared_secret, spk_remote)
    self.sessions[session_key] = ratchet
    _save_session(self.email, peer_user_id, ratchet, self._local_key,
                  peer_device_id=peer_device_id)

    # Build X3DH header for first message (our identity key + ephemeral key)
    x3dh_header = {
        "ik": encode_binary(serialize_ed25519_public(self.identity_public)),
        "ek": encode_binary(serialize_x25519_public(ek_pub)),
    }
    if opk_id:
        x3dh_header["opk_id"] = opk_id

    # Cache the x3dh header for the next send_message call
    ratchet._x3dh_header = x3dh_header

    # Cache remote user info
    self._user_cache[peer_user_id] = {
        "user_id": peer_user_id,
        "identity_key": ik_remote,
        "identity_key_bytes": ik_remote_bytes,
        "identity_key_status": ik_status,
    }

    return ratchet
|
|
|
|
def _process_x3dh_header(self, sender_id: str, x3dh_header: dict,
                         sender_device_id: str | None = None,
                         spk_override=None) -> DoubleRatchet:
    """Process an incoming X3DH header to establish session as Bob.

    The X3DH responder secret is derived from the following inputs:
    - our identity key;
    - our signed prekey (or ``spk_override``);
    - the sender's identity key (``x3dh_header["ik"]``);
    - the sender's ephemeral key (``x3dh_header["ek"]``);
    - when ``x3dh_header["opk_id"]`` is present, our matching one-time
      prekey.

    The resulting ratchet replaces any session stored under the same key,
    both in ``self.sessions`` and on disk (``_save_session``). This happens
    before any message has been decrypted with it.

    Args:
        sender_id: User id of the message sender.
        x3dh_header: Header from the sender's first message, with keys "ik",
            "ek" and optionally "opk_id".
        sender_device_id: Sender's device id. When set, the session is keyed
            "sender_id:sender_device_id".
        spk_override: If provided, use this SPK private key instead of self.spk_private.
            Used for grace period fallback (M4).

    Returns:
        The freshly initialised DoubleRatchet. If a one-time prekey was
        loaded, its id is stored in ``_pending_opk_delete`` so the caller can
        delete the key once a message actually decrypts.

    Raises:
        IdentityKeyChanged: if ``check_identity_key`` reports the sender's
            identity key as changed (TOFU).
    """
    ik_remote_bytes = decode_binary(x3dh_header["ik"])
    ik_remote = load_ed25519_public(ik_remote_bytes)

    # TOFU: verify identity key before using it in X3DH
    ik_status = self.check_identity_key(sender_id, ik_remote_bytes)
    if ik_status in ("changed", "changed_verified"):
        raise IdentityKeyChanged(sender_id, ik_remote_bytes, ik_status)

    ek_remote = load_x25519_public(decode_binary(x3dh_header["ek"]))

    opk_id = x3dh_header.get("opk_id")
    opk_priv = None
    if opk_id:
        # NOTE(review): presumably returns None when the OPK is no longer on
        # disk (e.g. already consumed). X3DH then proceeds without it and the
        # derived secret will not match the sender's. Confirm.
        opk_priv = _load_opk_private(self.email, opk_id, self._local_key)
        # Deletion is deferred until the first message decrypts successfully
        # (_consume_pending_opk). Deleting here would break the SPK
        # grace-period retry: the second _process_x3dh_header call could no
        # longer load the OPK and the message would be lost permanently.

    spk_priv = spk_override if spk_override else self.spk_private

    shared_secret = x3dh_respond(
        self.identity_private,
        spk_priv,
        ik_remote,
        ek_remote,
        opk_priv,
    )

    # init_bob takes the SPK as a (private, public) pair. The public half is
    # derived when the key object exposes public_key().
    spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None
    ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub))

    # Only mark the OPK for deletion if it was actually loaded and used.
    ratchet._pending_opk_delete = opk_id if opk_priv else None

    session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id
    self.sessions[session_key] = ratchet
    _save_session(self.email, sender_id, ratchet, self._local_key,
                  peer_device_id=sender_device_id)

    self._user_cache[sender_id] = {
        "user_id": sender_id,
        "identity_key": ik_remote,
        "identity_key_bytes": ik_remote_bytes,
        "identity_key_status": ik_status,
    }

    return ratchet
|
|
|
|
def _consume_pending_opk(self, ratchet) -> None:
|
|
"""Delete the one-time prekey consumed by an X3DH handshake.
|
|
|
|
Called only after the first message decrypted successfully, so a failed
|
|
attempt (e.g. wrong SPK during the grace period) can still retry with
|
|
the same OPK.
|
|
"""
|
|
opk_id = getattr(ratchet, "_pending_opk_delete", None)
|
|
if opk_id:
|
|
_delete_opk_private(self.email, opk_id)
|
|
ratchet._pending_opk_delete = None
|
|
|
|
# ------------------------------------------------------------------
|
|
# Conversations
|
|
# ------------------------------------------------------------------
|
|
|
|
async def create_conversation(self, member_emails: list[str], name: str | None = None) -> tuple[str | None, str]:
|
|
kwargs = {"members": member_emails}
|
|
if name:
|
|
kwargs["name"] = name
|
|
resp = await self.send_and_recv("create_conversation", **kwargs)
|
|
if resp["status"] == "ok":
|
|
return resp["data"]["conversation_id"], "OK"
|
|
return None, resp["data"]["message"]
|
|
|
|
async def remove_member(self, conv_id: str, user_id: str) -> tuple[bool, str]:
    """Ask the server to remove *user_id* from conversation *conv_id*.

    Returns:
        (True, "OK") on success, otherwise (False, server message).
    """
    resp = await self.send_and_recv("remove_member", conversation_id=conv_id, user_id=user_id)
    return (True, "OK") if resp["status"] == "ok" else (False, resp["data"]["message"])
|
|
|
|
async def leave_group(self, conv_id: str) -> tuple[bool, str]:
    """Leave a group conversation.

    On success, our own sender key state for *conv_id* and every received
    sender key keyed ``"<conv_id>:..."`` are dropped from the in-memory
    dicts. Nothing on disk is touched here.

    Returns:
        (True, "OK") on success, otherwise (False, server message).
    """
    resp = await self.send_and_recv("leave_group", conversation_id=conv_id)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]

    self.sender_key_states.pop(conv_id, None)
    prefix = f"{conv_id}:"
    stale = [key for key in self.recv_sender_keys if key.startswith(prefix)]
    for key in stale:
        del self.recv_sender_keys[key]
    return True, "OK"
|
|
|
|
async def rename_conversation(self, conv_id: str, name: str) -> tuple[bool, str]:
    """Rename a group conversation (creator only).

    Returns:
        (True, "OK") on success, otherwise (False, server message).
    """
    resp = await self.send_and_recv("rename_conversation", conversation_id=conv_id, name=name)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]
    return True, "OK"
|
|
|
|
async def delete_conversation(self, conv_id: str) -> tuple[bool, str]:
    """Delete a conversation (leave + server cleans up if empty).

    On success, our own sender key state for *conv_id* and every received
    sender key keyed ``"<conv_id>:..."`` are dropped from the in-memory
    dicts. Nothing on disk is touched here.

    Returns:
        (True, "OK") on success, otherwise (False, server message).
    """
    resp = await self.send_and_recv("delete_conversation", conversation_id=conv_id)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]

    self.sender_key_states.pop(conv_id, None)
    prefix = f"{conv_id}:"
    stale = [key for key in self.recv_sender_keys if key.startswith(prefix)]
    for key in stale:
        del self.recv_sender_keys[key]
    return True, "OK"
|
|
|
|
async def add_member(self, conv_id: str, email: str) -> tuple[bool, str]:
    """Ask the server to add the user with *email* to conversation *conv_id*.

    Returns:
        (True, "OK") on success, otherwise (False, server message).
    """
    resp = await self.send_and_recv("add_member", conversation_id=conv_id, email=email)
    return (True, "OK") if resp["status"] == "ok" else (False, resp["data"]["message"])
|
|
|
|
async def accept_invitation(self, conv_id: str) -> tuple[bool, str]:
    """Accept a group invitation.

    Returns:
        (True, "OK") on success, otherwise (False, server message).
    """
    resp = await self.send_and_recv("accept_invitation", conversation_id=conv_id)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]
    return True, "OK"
|
|
|
|
async def decline_invitation(self, conv_id: str) -> tuple[bool, str]:
    """Decline a group invitation.

    Returns:
        (True, "OK") on success, otherwise (False, server message).
    """
    resp = await self.send_and_recv("decline_invitation", conversation_id=conv_id)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]
    return True, "OK"
|
|
|
|
async def list_invitations(self) -> list[dict]:
    """List pending group invitations.

    Returns the server's ``invitations`` list, or ``[]`` when the request
    fails.
    """
    resp = await self.send_and_recv("list_invitations")
    return resp["data"]["invitations"] if resp["status"] == "ok" else []
|
|
|
|
async def list_conversations(self) -> list[dict]:
    """Return the server's ``conversations`` list, or ``[]`` on failure."""
    resp = await self.send_and_recv("list_conversations")
    return resp["data"]["conversations"] if resp["status"] == "ok" else []
|
|
|
|
async def find_or_create_conversation(self, email: str) -> tuple[str | None, str]:
|
|
resp = await self.send_and_recv("find_conversation", email=email)
|
|
if resp["status"] != "ok":
|
|
return None, resp["data"]["message"]
|
|
conv_id = resp["data"]["conversation_id"]
|
|
if conv_id:
|
|
return conv_id, "OK"
|
|
return await self.create_conversation([email])
|
|
|
|
# ------------------------------------------------------------------
|
|
# Send message
|
|
# ------------------------------------------------------------------
|
|
|
|
def _cancel_typing_timer(self, conv_id: str):
|
|
task = self._typing_stop_tasks.pop(conv_id, None)
|
|
if task and not task.done():
|
|
task.cancel()
|
|
|
|
async def _typing_stop_after_delay(self, conv_id: str, delay: float):
|
|
try:
|
|
await asyncio.sleep(delay)
|
|
await self.typing_stop(conv_id)
|
|
except asyncio.CancelledError:
|
|
return
|
|
|
|
async def typing_start(self, conv_id: str):
    """Debounced typing_start with 3s inactivity timeout.

    Every call re-arms a 3-second auto-stop task. The server is notified
    only at the start of a typing burst, or when at least 1 second has
    passed since the last notification. Send failures are ignored, since
    typing indicators are best-effort.
    """
    if not (conv_id and self.session):
        return

    now = time.monotonic()
    # Throttled = already reported as typing, and reported less than 1 s ago.
    throttled = (self._typing_active.get(conv_id, False)
                 and now - self._typing_last_sent.get(conv_id, 0.0) < 1.0)

    self._typing_active[conv_id] = True
    self._cancel_typing_timer(conv_id)
    auto_stop = self._typing_stop_after_delay(conv_id, 3.0)
    self._typing_stop_tasks[conv_id] = asyncio.create_task(auto_stop)

    if throttled:
        return

    self._typing_last_sent[conv_id] = now
    try:
        await self.send_and_recv("typing_start", timeout=5.0, conversation_id=conv_id)
    except Exception:
        pass
|
|
|
|
async def typing_stop(self, conv_id: str, force: bool = False):
    """Mark the user as no longer typing in *conv_id*.

    Always cancels the pending auto-stop task and clears the active flag.
    The server is notified only if typing was previously reported, or when
    *force* is set. Send failures are ignored.
    """
    if not conv_id or not self.session:
        return
    self._cancel_typing_timer(conv_id)

    previously_active = self._typing_active.get(conv_id, False)
    self._typing_active[conv_id] = False

    if force or previously_active:
        try:
            await self.send_and_recv("typing_stop", timeout=5.0, conversation_id=conv_id)
        except Exception:
            pass
|
|
|
|
def _is_group(self, members: list[dict]) -> bool:
|
|
return len(members) > 2
|
|
|
|
async def send_message(self, conv_id: str, text: str, members: list[dict],
|
|
reply_to: str | None = None) -> tuple[bool, str | dict]:
|
|
"""Encrypt and send a message. DM: per-recipient Double Ratchet. Group: Sender Keys.
|
|
|
|
Returns (True, msg_dict) on success or (False, error_string) on failure.
|
|
msg_dict contains the full decrypted payload ready for display.
|
|
"""
|
|
my_user_id = self.session["user_id"]
|
|
await self.typing_stop(conv_id, force=True)
|
|
|
|
# Build plaintext payload
|
|
payload = {
|
|
"sender": self.username,
|
|
"text": text,
|
|
"reply_to": reply_to,
|
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
}
|
|
plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
|
|
|
|
if self._is_group(members):
|
|
return await self._send_group_message(conv_id, plaintext, members, payload)
|
|
else:
|
|
return await self._send_dm(conv_id, plaintext, members, payload)
|
|
|
|
async def _send_dm(self, conv_id: str, plaintext: bytes, members: list[dict],
                   payload: dict | None = None) -> tuple[bool, str | dict]:
    """Encrypt DM with per-device Double Ratchet.

    For every member other than ourselves, one recipient entry is produced
    per device bundle from ``_get_device_bundles``. When no bundles are
    available, the legacy per-user session (no device id) is used. A copy
    encrypted with the identity-derived self key is always appended, so all
    of our own devices can read the message.

    A brand-new session's X3DH header (``ratchet._x3dh_header``) is attached
    to the outgoing entry and cleared from the ratchet only after the server
    has accepted the message. Clearing it before the send would lose the
    header on any failed or rejected send. Later messages would then go out
    without it, and the peer could never establish the session.

    Returns:
        (True, message dict) when *payload* is given, (True, "Message sent.")
        otherwise, or (False, server error message).

    Raises:
        Whatever ``_get_or_create_session`` raises (e.g. IdentityKeyChanged).
    """
    my_user_id = self.session["user_id"]
    recipients: list[dict] = []
    first_ratchet_header = None
    # Ratchets whose X3DH header rides on this message.
    pending_x3dh = []

    for member in members:
        uid = member.get("user_id")
        if not uid or uid == my_user_id:
            continue

        # Get all device bundles for this user
        try:
            device_bundles = await self._get_device_bundles(uid)
            self._logger.debug("Got %d device bundles for %s", len(device_bundles), uid)
        except Exception as e:
            self._logger.warning("Failed to get device bundles for %s: %s", uid, e)
            device_bundles = []

        # No bundles: fall back to the legacy single session (no device id, no bundle).
        targets = [(b.get("device_id"), b) for b in device_bundles] or [(None, None)]

        for dev_id, bundle in targets:
            ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
                                                        bundle=bundle)
            result = ratchet.encrypt(plaintext)

            entry = {
                "user_id": uid,
                "encrypted_content": encode_binary(result["ciphertext"]),
                "nonce": encode_binary(result["nonce"]),
                "ratchet_header": result["header"],
            }
            if dev_id:
                entry["device_id"] = dev_id
            x3dh_hdr = getattr(ratchet, "_x3dh_header", None)
            if x3dh_hdr:
                entry["x3dh_header"] = x3dh_hdr
                pending_x3dh.append(ratchet)
            recipients.append(entry)

            if first_ratchet_header is None:
                first_ratchet_header = result["header"]

            _save_session(self.email, uid, ratchet, self._local_key,
                          peer_device_id=dev_id)

    # Encrypt self-copy with static key derived from identity (not ratchet)
    # Uses SELF_DEVICE_ID so all own devices can read it
    self_key = derive_self_encryption_key(self.identity_private)
    _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
    recipients.append({
        "user_id": my_user_id,
        "encrypted_content": encode_binary(self_ct + self_tag),
        "nonce": encode_binary(self_nonce),
        "ratchet_header": {"self": True},
    })

    # NOTE: the self-copy above is always appended, so this never triggers.
    if not recipients:
        return False, "No recipients."

    resp = await self.send_and_recv(
        "send_message",
        conversation_id=conv_id,
        ratchet_header=first_ratchet_header,
        recipients=recipients,
    )
    if resp["status"] != "ok":
        # Keep the X3DH headers so the next message re-sends them.
        return False, resp["data"]["message"]

    # Delivered: the peer can now establish the session from that message.
    for ratchet in pending_x3dh:
        if getattr(ratchet, "_x3dh_header", None):
            delattr(ratchet, "_x3dh_header")

    if payload is None:
        return True, "Message sent."

    msg_data = resp.get("data", {})
    message = {
        **payload,
        "message_id": msg_data.get("message_id", ""),
        "created_at": msg_data.get("created_at", ""),
        "sender_id": self.session["user_id"],
        "conversation_id": conv_id,
        "read_by": [],
    }
    _save_message_to_cache(self.email, conv_id, message["message_id"], message, self._cache_key)
    return True, message
|
|
|
|
async def _send_group_message(self, conv_id: str, plaintext: bytes,
                              members: list[dict],
                              payload: dict | None = None) -> tuple[bool, str | dict]:
    """Encrypt group message with Sender Keys.

    The plaintext is encrypted once with our sender key for *conv_id*. The
    key is taken from memory, loaded from disk, or created and persisted if
    absent. Every other member receives that same ciphertext. A copy
    encrypted with the identity-derived self key is appended, so our other
    devices and history fetches can read it. Members who have not yet
    received our sender key are caught up before encrypting.

    Returns:
        (True, message dict) when *payload* is given, (True, "Message sent.")
        otherwise, or (False, server error message).
    """
    my_user_id = self.session["user_id"]

    sender_key = (self.sender_key_states.get(conv_id)
                  or _load_sender_key_state(self.email, conv_id, self._local_key))
    if not sender_key:
        sender_key = SenderKeyState()
        self.sender_key_states[conv_id] = sender_key
        _save_sender_key_state(self.email, conv_id, sender_key, self._local_key)
    self.sender_key_states[conv_id] = sender_key

    await self._catch_up_sender_key_distribution(conv_id, members, sender_key)

    sealed = sender_key.encrypt(plaintext)
    _save_sender_key_state(self.email, conv_id, sender_key, self._local_key)

    # Same sender-key ciphertext for every other member.
    recipients = [
        {
            "user_id": uid,
            "encrypted_content": encode_binary(sealed["ciphertext"]),
            "nonce": encode_binary(sealed["nonce"]),
        }
        for uid in (member.get("user_id") for member in members)
        if uid and uid != my_user_id
    ]

    own_key = derive_self_encryption_key(self.identity_private)
    _, own_nonce, own_ct, own_tag = aes_encrypt(plaintext, key=own_key)
    recipients.append({
        "user_id": my_user_id,
        "encrypted_content": encode_binary(own_ct + own_tag),
        "nonce": encode_binary(own_nonce),
        "ratchet_header": {"self": True},
    })

    resp = await self.send_and_recv(
        "send_message",
        conversation_id=conv_id,
        # Groups don't use the pairwise ratchet; the field carries a placeholder.
        ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0},
        recipients=recipients,
        sender_chain_id=encode_binary(bytes.fromhex(sealed["chain_id"])),
        sender_chain_n=sealed["n"],
    )

    if resp["status"] != "ok":
        return False, resp["data"]["message"]
    if payload is None:
        return True, "Message sent."

    server_info = resp.get("data", {})
    result_msg = {
        **payload,
        "message_id": server_info.get("message_id", ""),
        "created_at": server_info.get("created_at", ""),
        "sender_id": self.session["user_id"],
        "conversation_id": conv_id,
        "read_by": [],
    }
    _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key)
    return True, result_msg
|
|
|
|
async def _catch_up_sender_key_distribution(self, conv_id: str, members: list[dict],
|
|
sk: SenderKeyState):
|
|
"""Ensure all current members have our existing sender key."""
|
|
my_user_id = self.session["user_id"]
|
|
current_member_ids = {
|
|
uid for uid in (member.get("user_id") for member in members)
|
|
if uid and uid != my_user_id
|
|
}
|
|
if not current_member_ids:
|
|
return
|
|
|
|
distributed_to = _load_sender_key_recipients(self.email, conv_id, self._local_key)
|
|
missing_ids = sorted(current_member_ids - distributed_to)
|
|
if not missing_ids:
|
|
return
|
|
|
|
distributed_now = await self._distribute_sender_key(
|
|
conv_id,
|
|
[{"user_id": uid} for uid in missing_ids],
|
|
sk,
|
|
)
|
|
if distributed_now:
|
|
distributed_to.update(distributed_now)
|
|
_save_sender_key_recipients(self.email, conv_id, distributed_to, self._local_key)
|
|
|
|
async def _distribute_sender_key(self, conv_id: str, members: list[dict],
                                 sk: SenderKeyState) -> set[str]:
    """Send own sender key to all group members via pairwise Double Ratchet (per-device).

    For each member other than ourselves, a control message carrying
    ``payload["_sender_key"]`` (conversation id, exported key, our device
    id) is encrypted for each of the member's devices. When no device
    bundles are available, the legacy per-user session is used. The
    per-member message is then sent in one ``send_message`` request.

    A member counts as served only when the server answered "ok". A
    rejected request is logged and left out, so
    ``_catch_up_sender_key_distribution`` retries it on the next group
    send. As in ``_send_dm``, a new session's X3DH header is cleared from
    the ratchet only after the server accepted the message carrying it.

    Failures for one member (bundle lookup, session setup, network) are
    logged and do not stop distribution to the others.

    Returns:
        User ids whose distribution message the server accepted.
    """
    my_user_id = self.session["user_id"]
    distributed_to: set[str] = set()
    exported_key = sk.export_key()

    # Build a special "sender_key_distribution" payload
    payload = {
        "sender": self.username,
        "text": "",
        "reply_to": None,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "_sender_key": {
            "conv_id": conv_id,
            "key": encode_binary(exported_key),
            "sender_device_id": self.device_id,
        },
    }
    plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))

    # Send as DM to each member's devices (per-device encryption)
    for member in members:
        uid = member.get("user_id")
        if not uid or uid == my_user_id:
            continue

        try:
            try:
                device_bundles = await self._get_device_bundles(uid)
            except Exception:
                device_bundles = []

            # No bundles: legacy single-device session (no device id, no bundle).
            targets = [(b.get("device_id"), b) for b in device_bundles] or [(None, None)]

            recipients = []
            first_rh = None
            pending_x3dh = []
            for dev_id, bundle in targets:
                ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
                                                            bundle=bundle)
                result = ratchet.encrypt(plaintext)

                entry = {
                    "user_id": uid,
                    "encrypted_content": encode_binary(result["ciphertext"]),
                    "nonce": encode_binary(result["nonce"]),
                    "ratchet_header": result["header"],
                }
                if dev_id:
                    entry["device_id"] = dev_id
                x3dh_header = getattr(ratchet, "_x3dh_header", None)
                if x3dh_header:
                    entry["x3dh_header"] = x3dh_header
                    pending_x3dh.append(ratchet)
                recipients.append(entry)

                if first_rh is None:
                    first_rh = result["header"]
                _save_session(self.email, uid, ratchet, self._local_key,
                              peer_device_id=dev_id)

            resp = await self.send_and_recv(
                "send_message",
                conversation_id=conv_id,
                ratchet_header=first_rh,
                recipients=recipients,
            )
            if resp.get("status") != "ok":
                self._logger.warning(
                    "Server rejected sender key distribution to %s: %s",
                    uid, (resp.get("data") or {}).get("message", "Unknown error"))
                continue

            for ratchet in pending_x3dh:
                if getattr(ratchet, "_x3dh_header", None):
                    delattr(ratchet, "_x3dh_header")
            distributed_to.add(uid)
        except Exception as e:
            self._logger.warning("Failed to distribute sender key to %s: %s", uid, e)

    return distributed_to
|
|
|
|
async def redistribute_sender_key_to_member(self, conv_id: str, new_user_id: str):
    """Redistribute our existing sender key to a newly joined group member.

    Called when we receive a member_added notification — the new member needs our
    sender key so they can decrypt future messages we send in the group.

    Does nothing when logged out, or when we have no sender key for
    *conv_id* in memory or on disk (we have not sent in that group yet). If
    the member is reported as served, they are added to the persisted
    recipient set. Errors are logged, not raised.
    """
    if not self.session:
        return

    sk = self.sender_key_states.get(conv_id)
    if sk is None:
        sk = _load_sender_key_state(self.email, conv_id, self._local_key)
    if sk is None:
        return

    try:
        served = await self._distribute_sender_key(conv_id, [{"user_id": new_user_id}], sk)
        if new_user_id not in served:
            return
        known = _load_sender_key_recipients(self.email, conv_id, self._local_key)
        known.add(new_user_id)
        _save_sender_key_recipients(self.email, conv_id, known, self._local_key)
        self._logger.info("Redistributed sender key for conv=%s to new member %s",
                          conv_id[:8], new_user_id[:8])
    except Exception as e:
        self._logger.warning("Failed to redistribute sender key to %s: %s", new_user_id, e)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Decrypt messages
|
|
# ------------------------------------------------------------------
|
|
|
|
def _decrypt_message(self, msg_data: dict) -> dict:
|
|
"""Decrypt a single message (DM or group)."""
|
|
# Check for self-encrypted marker FIRST — after re-encryption,
|
|
# group messages will have {"self": true} ratchet_header but still
|
|
# have sender_chain_id at message level.
|
|
rh = msg_data.get("ratchet_header", {})
|
|
if isinstance(rh, dict) and rh.get("self"):
|
|
return self._decrypt_dm(msg_data)
|
|
|
|
if msg_data.get("sender_chain_id"):
|
|
return self._decrypt_group(msg_data)
|
|
else:
|
|
return self._decrypt_dm(msg_data)
|
|
|
|
def _decrypt_dm(self, msg_data: dict) -> dict | None:
    """Decrypt DM using Double Ratchet with sender, or static key for self-copies.

    Self-copies (ratchet header ``{"self": true}``) are decrypted with the
    key from ``derive_self_encryption_key``; the last 16 bytes of the
    ciphertext are the tag. Other messages use the pairwise session keyed
    ``"sender_id:sender_device_id"`` (or ``"sender_id"`` for legacy senders).
    When the message carries an X3DH header, the session is (re)established
    from it. If decryption with the current signed prekey fails and a
    previous one is still held (``self._prev_spk_private``), the handshake is
    retried with that one.

    If an existing session fails on a message that carries an X3DH header
    and the re-handshake fails too (e.g. a replayed first message whose
    one-time prekey is already gone), the existing session is restored in
    memory and on disk. Without that, the broken session written by
    ``_process_x3dh_header`` would replace the working one.

    Returns:
        The decoded payload dict, or None for a sender-key distribution
        control message (the key is stored as a side effect).

    Raises:
        ValueError: if ciphertext/nonce are missing or there is neither a
            session nor an X3DH header for the sender.
        Any decryption / X3DH error raised by the crypto layer.
    """
    sender_id = msg_data.get("sender_id", "")
    sender_device_id = msg_data.get("sender_device_id")
    ratchet_header = msg_data.get("ratchet_header", {})
    ct_b64 = msg_data.get("encrypted_content", "")
    nonce_b64 = msg_data.get("nonce", "")

    if not ct_b64 or not nonce_b64:
        raise ValueError("Missing ciphertext or nonce")

    ciphertext = decode_binary(ct_b64)
    nonce = decode_binary(nonce_b64)

    if isinstance(ratchet_header, dict) and ratchet_header.get("self"):
        # Self-encrypted message (own sent message copy): ciphertext || 16-byte tag
        self_key = derive_self_encryption_key(self.identity_private)
        ct = ciphertext[:-16]
        tag = ciphertext[-16:]
        plaintext = aes_decrypt(self_key, nonce, ct, tag)
    else:
        x3dh_header = msg_data.get("x3dh_header")

        # Session key: "sender_id:sender_device_id" or just "sender_id" for legacy
        session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id

        def save(session) -> None:
            """Persist *session* for this sender/device."""
            _save_session(self.email, sender_id, session, self._local_key,
                          peer_device_id=sender_device_id)

        def establish_and_decrypt():
            """Build a Bob session from the X3DH header and decrypt with it.

            Retries once with the previous signed prekey (grace period) when
            the current one does not yield a decryptable message. Returns
            (new_ratchet, plaintext); the new ratchet is saved on success.
            """
            new_ratchet = self._process_x3dh_header(sender_id, x3dh_header,
                                                    sender_device_id=sender_device_id)
            try:
                pt = new_ratchet.decrypt(ratchet_header, ciphertext, nonce)
            except Exception:
                if not self._prev_spk_private:
                    raise
                new_ratchet = self._process_x3dh_header(
                    sender_id, x3dh_header,
                    sender_device_id=sender_device_id,
                    spk_override=self._prev_spk_private)
                pt = new_ratchet.decrypt(ratchet_header, ciphertext, nonce)
            save(new_ratchet)
            return new_ratchet, pt

        # Try to load existing session
        ratchet = self.sessions.get(session_key)
        if not ratchet:
            ratchet = _load_session(self.email, sender_id, self._local_key,
                                    peer_device_id=sender_device_id)
            if ratchet:
                self.sessions[session_key] = ratchet

        if ratchet and not x3dh_header:
            # Normal case: existing session, no X3DH header
            plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
            save(ratchet)
        elif x3dh_header and ratchet:
            # Existing session + X3DH header: the sender may have reset, or
            # this is a (re)delivered first message for the current session.
            backup = ratchet.export_state()
            try:
                plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
                save(ratchet)
            except Exception:
                # decrypt() may have mutated state; roll back before re-handshaking.
                restored = DoubleRatchet.import_state(backup)
                self.sessions[session_key] = restored
                save(restored)
                try:
                    ratchet, plaintext = establish_and_decrypt()
                except Exception:
                    # _process_x3dh_header already overwrote the session in
                    # memory and on disk; put the working one back.
                    self.sessions[session_key] = restored
                    save(restored)
                    raise
        elif x3dh_header:
            ratchet, plaintext = establish_and_decrypt()
        else:
            raise ValueError(f"No session for sender {sender_id}")

        # Only now that decryption succeeded may the one-time prekey go.
        self._consume_pending_opk(ratchet)

    plaintext = unpad_plaintext(plaintext)
    payload = json.loads(plaintext)

    # Handle sender key distribution messages
    if "_sender_key" in payload:
        sk_data = payload["_sender_key"]
        sk_conv_id = sk_data["conv_id"]
        sk_key = decode_binary(sk_data["key"])
        sk_sender_device_id = sk_data.get("sender_device_id")
        recv_sk = SenderKeyState.from_key(sk_key)
        if sk_sender_device_id:
            cache_key = f"{sk_conv_id}:{sender_id}:{sk_sender_device_id}"
        else:
            cache_key = f"{sk_conv_id}:{sender_id}"
        self.recv_sender_keys[cache_key] = recv_sk
        _save_recv_sender_key(self.email, sk_conv_id, sender_id, recv_sk, self._local_key,
                              sender_device_id=sk_sender_device_id)
        # Return empty — this is a control message, not user-visible
        return None

    return payload
|
|
|
|
def _decrypt_group(self, msg_data: dict) -> dict:
    """Decrypt group message using sender's Sender Key.

    The received sender key is looked up first under
    ``"<conv>:<sender>:<device>"`` (when the message names a sender device),
    then under the legacy ``"<conv>:<sender>"`` key. Each lookup checks the
    in-memory cache before the on-disk store, and caches what it loads. After
    decryption, the advanced key state is saved back to disk.

    Raises:
        ValueError: if required fields are missing, if the message is our
            own (our sender chain has already moved past it), or if no
            sender key is known for the sender.
    """
    sender_id = msg_data.get("sender_id", "")
    sender_device_id = msg_data.get("sender_device_id")
    conv_id = msg_data.get("conversation_id", "")
    chain_id_b64 = msg_data.get("sender_chain_id", "")
    chain_n = msg_data.get("sender_chain_n", 0)
    ct_b64 = msg_data.get("encrypted_content", "")
    nonce_b64 = msg_data.get("nonce", "")

    if not (ct_b64 and nonce_b64 and chain_id_b64):
        raise ValueError("Missing group message fields")

    ciphertext = decode_binary(ct_b64)
    nonce = decode_binary(nonce_b64)
    chain_id = decode_binary(chain_id_b64)

    if sender_id == self.session["user_id"]:
        own_key = self.sender_key_states.get(conv_id)
        if not own_key:
            own_key = _load_sender_key_state(self.email, conv_id, self._local_key)
            if own_key:
                self.sender_key_states[conv_id] = own_key
        if not own_key:
            raise ValueError("Own sender key not found")
        # Our chain has already advanced past this message; it cannot be
        # decrypted from the sender key.
        raise ValueError("Cannot decrypt own group message from sender key")

    def lookup(cache_key, **load_kwargs):
        """Return the cached or on-disk received sender key, caching hits."""
        found = self.recv_sender_keys.get(cache_key)
        if not found:
            found = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key,
                                          **load_kwargs)
            if found:
                self.recv_sender_keys[cache_key] = found
        return found

    sk = None
    if sender_device_id:
        sk = lookup(f"{conv_id}:{sender_id}:{sender_device_id}",
                    sender_device_id=sender_device_id)
    if not sk:
        # Legacy sender or key distributed without a device id
        sk = lookup(f"{conv_id}:{sender_id}")
    if not sk:
        raise ValueError(f"No sender key for {sender_id} in conversation {conv_id}")

    plaintext = unpad_plaintext(sk.decrypt(chain_id.hex(), chain_n, ciphertext, nonce))
    _save_recv_sender_key(self.email, conv_id, sender_id, sk, self._local_key,
                          sender_device_id=sender_device_id)

    return json.loads(plaintext)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Get/decrypt messages (batch)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def get_messages(self, conv_id: str, limit: int = 50, offset: int = 0) -> list[dict]:
    """Fetch, decrypt and cache messages for a conversation.

    With ``offset == 0`` and a non-empty local cache this is an incremental
    sync: only messages newer than the stored ``__last_server_ts`` watermark
    are requested and the whole cached history is returned. Otherwise the
    requested page is fetched and returned oldest-first.

    If the server request fails, the cached history is returned for the first
    page (``offset == 0``) and ``[]`` otherwise.

    Delivery receipts, the conversation read marker, self-encryption uploads
    and (for incremental syncs) deletion sync run in the background.
    """
    cache = _load_message_cache(self.email, conv_id, self._cache_key)
    my_user_id = self.session["user_id"] if self.session else ""

    # Incremental sync: use stored server timestamp from last successful fetch.
    after_ts = None
    if cache and offset == 0:
        after_ts = cache.get("__last_server_ts", {}).get("ts")

    req_params = {"conversation_id": conv_id, "limit": limit, "offset": offset}
    if after_ts:
        req_params["after_ts"] = after_ts
    resp = await self.send_and_recv("get_messages", **req_params)

    if resp["status"] != "ok":
        # Offline fallback: return from cache if available
        if cache and offset == 0:
            return self._build_from_cache(cache)
        return []

    raw_messages = resp["data"]["messages"]  # server order: newest first
    if after_ts and limit > 0 and len(raw_messages) >= limit:
        # A full page means more than `limit` messages arrived since the
        # watermark. Advancing the watermark past just this page would leave a
        # permanent hole in the cache, so pull the older pages as well.
        raw_messages = raw_messages + await self._fetch_incremental_backlog(
            conv_id, limit, after_ts, raw_messages)
    raw_messages.reverse()  # now oldest first

    # Save latest server timestamp for next incremental sync
    if raw_messages:
        newest_ts = raw_messages[-1].get("created_at", "")
        stored_ts = cache.get("__last_server_ts", {}).get("ts") or ""
        # Only ever move the watermark forward: loading an older page
        # (offset > 0) must not rewind it.
        if newest_ts and newest_ts > stored_ts:
            cache["__last_server_ts"] = {"ts": newest_ts}
            _save_message_to_cache(self.email, conv_id, "__last_server_ts",
                                   {"ts": newest_ts}, cache_key=self._cache_key)

    # Decrypt new messages from server
    new_decrypted = self._decrypt_raw_messages(raw_messages, cache, conv_id, my_user_id)

    # All non-critical ops fire-and-forget to avoid blocking message display.
    # Confirm delivery for messages from others.
    deliver_ids = [m["message_id"] for m in new_decrypted
                   if m.get("sender_id") and m["sender_id"] != my_user_id
                   and not m.get("deleted")]
    if deliver_ids:
        asyncio.ensure_future(self.confirm_delivery(conv_id, deliver_ids))

    # Mark entire conversation as read (fire-and-forget)
    asyncio.ensure_future(self.mark_conversation_read(conv_id))

    # Flush self-encryption queue in background
    if self._pending_self_encrypt:
        asyncio.ensure_future(self._flush_self_encrypt())

    if after_ts:
        # Incremental: sync deletions in background, build from cache NOW
        asyncio.ensure_future(self._sync_deletions(conv_id, after_ts))
        return self._build_from_cache(cache)

    return new_decrypted


async def _fetch_incremental_backlog(self, conv_id: str, limit: int, after_ts: str,
                                     first_page: list, max_pages: int = 20) -> list:
    """Fetch the older pages of an incremental sync whose first page came back full.

    Pages are requested with growing ``offset`` and the same ``after_ts``.
    Fetching stops on a short page, on a server error, after ``max_pages``
    pages, or when a page contains no message ids not already seen. The last
    condition also guards against a server that ignores ``offset`` for
    ``after_ts`` queries. Returns the extra rows in server (newest-first)
    order, continuing after ``first_page``.
    """
    seen = {m.get("message_id") for m in first_page}
    extra = []
    page_offset = limit
    for _ in range(max_pages):
        resp = await self.send_and_recv("get_messages", conversation_id=conv_id,
                                        limit=limit, offset=page_offset, after_ts=after_ts)
        if resp.get("status") != "ok":
            break
        page = (resp.get("data") or {}).get("messages") or []
        fresh = [m for m in page if m.get("message_id") not in seen]
        if not fresh:
            break
        seen.update(m.get("message_id") for m in fresh)
        extra.extend(fresh)
        if len(page) < limit:
            break
        page_offset += limit
    return extra
|
|
|
|
async def _sync_deletions(self, conv_id: str, after_ts: str):
    """Mirror server-side deletions made since *after_ts* into the local cache.

    Intended to run as a fire-and-forget background task, so it never raises.
    Failures (offline, non-ok status, cache write errors) are logged at debug
    level instead of being silently discarded.
    """
    try:
        del_resp = await self.send_and_recv("get_deleted_since",
                                            conversation_id=conv_id, since_ts=after_ts)
        if del_resp.get("status") != "ok":
            self._logger.debug("Deletion sync for %s rejected: %s", conv_id,
                               (del_resp.get("data") or {}).get("message", ""))
            return
        # `or []` tolerates an explicit null from the server.
        deleted_ids = (del_resp.get("data") or {}).get("deleted_ids") or []
        for del_id in deleted_ids:
            _save_message_to_cache(self.email, conv_id, del_id, {"deleted": True},
                                   cache_key=self._cache_key)
    except Exception as e:
        self._logger.debug("Deletion sync for %s failed: %s", conv_id, e)
|
|
|
|
def get_cached_messages(self, conv_id: str) -> list[dict]:
    """Return this conversation's messages straight from the on-disk cache.

    No server round-trip is made. An empty list is returned when no user is
    logged in or when nothing has been cached for *conv_id* yet.
    """
    if self.email:
        stored = _load_message_cache(self.email, conv_id, self._cache_key)
        if stored:
            return self._build_from_cache(stored)
    return []
|
|
|
|
def _build_from_cache(self, cache: dict) -> list[dict]:
    """Turn a raw ``{message_id: payload}`` cache into an ordered message list.

    Entries flagged ``_control`` and bookkeeping keys that begin with ``__``
    are left out. Each returned item is a shallow copy that carries its
    ``message_id`` and gets empty ``read_by`` / ``delivered_to`` lists when
    those are absent. The result is sorted by ``created_at``.
    """
    def _is_visible(item):
        key, entry = item
        return not (entry.get("_control") or key.startswith("__"))

    result = []
    for key, entry in filter(_is_visible, cache.items()):
        message = {**entry}
        for field, fallback in (("message_id", key), ("read_by", []), ("delivered_to", [])):
            message.setdefault(field, fallback)
        result.append(message)
    return sorted(result, key=lambda msg: msg.get("created_at", ""))
|
|
|
|
def _decrypt_raw_messages(self, raw_messages: list, cache: dict,
|
|
conv_id: str, my_user_id: str) -> list[dict]:
|
|
"""Decrypt server messages, update cache. Returns list of decrypted dicts."""
|
|
decrypted = []
|
|
for m in raw_messages:
|
|
msg_id = m["message_id"]
|
|
|
|
if m.get("deleted_at"):
|
|
decrypted.append({
|
|
"message_id": msg_id,
|
|
"sender": "",
|
|
"text": "",
|
|
"created_at": m["created_at"],
|
|
"read_by": [],
|
|
"sender_id": m.get("sender_id", ""),
|
|
"deleted": True,
|
|
})
|
|
cache[msg_id] = {"deleted": True, "created_at": m["created_at"]}
|
|
continue
|
|
|
|
# Check local cache first (ratchet keys are one-time use)
|
|
cached = cache.get(msg_id)
|
|
if cached and not cached.get("_control"):
|
|
cached["read_by"] = m.get("read_by", [])
|
|
cached["delivered_to"] = m.get("delivered_to", [])
|
|
cached["created_at"] = m["created_at"]
|
|
if m.get("reactions"):
|
|
cached["reactions"] = m["reactions"]
|
|
if m.get("pinned_at"):
|
|
cached["pinned_at"] = m["pinned_at"]
|
|
cached["pinned_by"] = m.get("pinned_by", "")
|
|
else:
|
|
cached.pop("pinned_at", None)
|
|
cached.pop("pinned_by", None)
|
|
decrypted.append(cached)
|
|
continue
|
|
if cached and cached.get("_control"):
|
|
continue
|
|
|
|
try:
|
|
msg_data = {
|
|
"sender_id": m.get("sender_id", ""),
|
|
"sender_device_id": m.get("sender_device_id"),
|
|
"conversation_id": conv_id,
|
|
"ratchet_header": m.get("ratchet_header", {}),
|
|
"encrypted_content": m.get("encrypted_content", ""),
|
|
"nonce": m.get("nonce", ""),
|
|
"x3dh_header": m.get("x3dh_header"),
|
|
"sender_chain_id": m.get("sender_chain_id"),
|
|
"sender_chain_n": m.get("sender_chain_n"),
|
|
}
|
|
payload = self._decrypt_message(msg_data)
|
|
if payload is None:
|
|
_save_message_to_cache(self.email, conv_id, msg_id, {"_control": True},
|
|
cache_key=self._cache_key)
|
|
cache[msg_id] = {"_control": True}
|
|
continue
|
|
payload["message_id"] = msg_id
|
|
payload["created_at"] = m["created_at"]
|
|
payload["read_by"] = m.get("read_by", [])
|
|
payload["delivered_to"] = m.get("delivered_to", [])
|
|
payload["sender_id"] = m.get("sender_id", "")
|
|
if m.get("reactions"):
|
|
payload["reactions"] = m["reactions"]
|
|
if m.get("pinned_at"):
|
|
payload["pinned_at"] = m["pinned_at"]
|
|
payload["pinned_by"] = m.get("pinned_by", "")
|
|
decrypted.append(payload)
|
|
_save_message_to_cache(self.email, conv_id, msg_id, payload,
|
|
cache_key=self._cache_key)
|
|
cache[msg_id] = payload
|
|
if m.get("sender_id", "") != my_user_id:
|
|
self._pending_self_encrypt.append({
|
|
"message_id": msg_id,
|
|
"payload": {k: v for k, v in payload.items()
|
|
if k not in ("message_id", "created_at", "read_by",
|
|
"delivered_to", "sender_id", "deleted")},
|
|
})
|
|
except Exception as e:
|
|
decrypted.append({
|
|
"message_id": msg_id,
|
|
"sender": "???",
|
|
"text": f"[Decryption failed: {e}]",
|
|
"created_at": m["created_at"],
|
|
"read_by": [],
|
|
})
|
|
return decrypted
|
|
|
|
async def _flush_self_encrypt(self):
    """Upload self-encrypted copies of received messages for multi-device access.

    Drains ``self._pending_self_encrypt``. Each queued payload is JSON-encoded,
    AES-encrypted with ``derive_self_encryption_key(self.identity_private)``
    and sent to the server via ``reencrypt_messages`` in batches of 500.

    Best-effort: failures are logged and the affected items are dropped, but
    one failing item or batch no longer prevents the rest from uploading.
    Nothing happens (and the queue is kept) while there is no identity key.
    """
    if not self._pending_self_encrypt or not self.identity_private:
        return
    self_key = derive_self_encryption_key(self.identity_private)
    # Take the whole queue before the first await so that a concurrent flush
    # (get_messages may start several) cannot pick up the same items.
    pending = list(self._pending_self_encrypt)
    self._pending_self_encrypt.clear()

    updates = []
    for item in pending:
        try:
            plaintext = json.dumps(item["payload"], ensure_ascii=False).encode("utf-8")
            _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key)
            updates.append({
                "message_id": item["message_id"],
                "encrypted_content": encode_binary(ct + tag),
                "nonce": encode_binary(nonce),
            })
        except Exception as e:
            self._logger.warning("Failed to self-encrypt message %s: %s",
                                 item.get("message_id"), e)

    batch_size = 500
    for start in range(0, len(updates), batch_size):
        batch = updates[start:start + batch_size]
        try:
            resp = await self.send_and_recv("reencrypt_messages", updates=batch)
        except Exception as e:
            self._logger.warning("Failed to self-encrypt received messages: %s", e)
            continue
        if resp.get("status") != "ok":
            self._logger.warning("Server rejected %d self-encrypted messages: %s",
                                 len(batch), (resp.get("data") or {}).get("message", ""))
|
|
|
|
async def mark_read(self, conv_id: str, message_ids: list[str]):
    """Report the given messages of *conv_id* as read to the server.

    No request is sent when *message_ids* is empty.
    """
    if message_ids:
        await self.send_and_recv("mark_read", conversation_id=conv_id,
                                 message_ids=message_ids)
|
|
|
|
async def mark_conversation_read(self, conv_id: str):
    """Ask the server to mark every unread message in *conv_id* as read.

    Best-effort: any exception is swallowed so that a failed read receipt
    never interrupts message loading.
    """
    try:
        await self.send_and_recv("mark_conversation_read", conversation_id=conv_id)
    except Exception:
        return  # read receipts are non-critical
|
|
|
|
async def confirm_delivery(self, conv_id: str, message_ids: list[str]):
    """Acknowledge receipt of *message_ids* in *conv_id* to the server.

    Does nothing for an empty list; errors are swallowed because delivery
    receipts are non-critical (callers fire this and forget it).
    """
    if not message_ids:
        return
    params = {"conversation_id": conv_id, "message_ids": message_ids}
    try:
        await self.send_and_recv("confirm_delivery", **params)
    except Exception:
        return  # non-critical
|
|
|
|
def search_messages(self, conv_id: str, query: str) -> list[dict]:
    """Case-insensitively search the cached messages of *conv_id* for *query*.

    The following entries are skipped: deleted messages, control messages,
    sender-key records, and internal bookkeeping entries whose key starts
    with ``__`` (such as ``__last_server_ts``).

    Returns copies of the matching payloads, each with ``message_id`` set,
    ordered by ``created_at``. Returns ``[]`` when no user is logged in.
    """
    if not self.email:
        return []
    cache = _load_message_cache(self.email, conv_id, self._cache_key)
    needle = (query or "").lower()
    results = []
    for msg_id, payload in cache.items():
        # Bookkeeping entries are not messages; an empty query would
        # otherwise "match" them because they have no text.
        if msg_id.startswith("__"):
            continue
        if payload.get("deleted") or payload.get("_control") or payload.get("_sender_key"):
            continue
        text = payload.get("text") or ""  # tolerate text=None
        if needle in text.lower():
            entry = dict(payload)
            entry["message_id"] = msg_id
            results.append(entry)
    results.sort(key=lambda m: m.get("created_at", ""))
    return results
|
|
|
|
async def reset_session(self, peer_user_id: str, peer_device_id: str | None = None):
    """Forget our ratchet session with a peer and ask the peer to do the same.

    The in-memory session is keyed ``"<user>:<device>"`` when a device id is
    given and by the bare user id otherwise. The on-disk session file is
    removed as well, then a ``session_reset`` request is sent (with an empty
    ``peer_device_id`` when none was given).
    """
    session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id
    self.sessions.pop(session_key, None)
    _delete_session_file(self.email, peer_user_id, peer_device_id)
    await self.send_and_recv(
        "session_reset",
        peer_user_id=peer_user_id,
        peer_device_id=peer_device_id or "",
    )
|
|
|
|
def handle_session_reset_notification(self, from_user_id: str, from_device_id: str | None = None):
    """React to a peer's session-reset notice by dropping our matching session.

    Uses the same keying as ``reset_session``: ``"<user>:<device>"`` when a
    device id is present, otherwise just the user id. The on-disk session file
    is deleted too.
    """
    key = from_user_id if not from_device_id else f"{from_user_id}:{from_device_id}"
    self.sessions.pop(key, None)
    _delete_session_file(self.email, from_user_id, from_device_id)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Local message cache updates
|
|
# ------------------------------------------------------------------
|
|
|
|
def load_message_cache(self, conv_id: str) -> dict:
    """Return the cached ``{message_id: payload}`` mapping for *conv_id*.

    An empty dict is returned when no user is logged in.
    """
    return _load_message_cache(self.email, conv_id, self._cache_key) if self.email else {}
|
|
|
|
def update_message_in_cache(self, conv_id: str, message_id: str, updates: dict):
    """Apply *updates* to one cached message and write the cache back (synchronous).

    A value of ``None`` removes that field; any other value overwrites it.
    Unknown message ids and control entries are left untouched. The cache is
    only written back to disk when a cache key is configured.
    """
    if not self.email:
        return
    cache = _load_message_cache(self.email, conv_id, self._cache_key)
    if message_id not in cache:
        return
    entry = cache[message_id]
    if entry.get("_control"):
        return
    for field in [f for f, v in updates.items() if v is None]:
        entry.pop(field, None)
    entry.update({f: v for f, v in updates.items() if v is not None})
    cache_dir = get_key_dir(self.email) / "message_cache"
    if self._cache_key:
        _save_message_cache_full(cache_dir, conv_id, cache, self._cache_key)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Reactions, Pins, Forwarding
|
|
# ------------------------------------------------------------------
|
|
|
|
async def react_message(self, message_id: str, reaction: str, action: str = "add") -> tuple[bool, str]:
|
|
"""Add or remove a reaction on a message."""
|
|
resp = await self.send_and_recv("react_message",
|
|
message_id=message_id, reaction=reaction, action=action)
|
|
if resp["status"] == "ok":
|
|
return True, "OK"
|
|
return False, resp.get("data", {}).get("message", "Failed")
|
|
|
|
async def pin_message(self, message_id: str, conversation_id: str, action: str = "pin") -> tuple[bool, str]:
|
|
"""Pin or unpin a message."""
|
|
resp = await self.send_and_recv("pin_message",
|
|
message_id=message_id, conversation_id=conversation_id, action=action)
|
|
if resp["status"] == "ok":
|
|
return True, "OK"
|
|
return False, resp.get("data", {}).get("message", "Failed")
|
|
|
|
async def get_pinned_messages(self, conversation_id: str) -> list[dict]:
    """Return the server's list of pinned messages for a conversation.

    An error response yields an empty list; so does an ok response without a
    ``messages`` field.
    """
    resp = await self.send_and_recv("get_pinned_messages", conversation_id=conversation_id)
    return resp["data"].get("messages", []) if resp["status"] == "ok" else []
|
|
|
|
async def forward_message(self, target_conv_id: str, original_msg: dict,
                          target_members: list[dict]) -> tuple[bool, str | dict]:
    """Forward a message to another conversation.

    The text and any ``image`` / ``file`` metadata are copied into a new
    payload with a ``forwarded_from`` reference (original sender,
    conversation and message id). The encrypted blob itself is already on
    the server and is not re-uploaded; only its metadata is re-sent.

    Group targets go through ``_send_group_message``, everything else through
    ``_send_dm``; their result is returned unchanged.
    """
    # A cached entry may carry text=None; always send a string.
    text = original_msg.get("text") or ""

    payload = {
        "sender": self.username,
        "text": text,
        "forwarded_from": {
            "sender": original_msg.get("sender", ""),
            "conversation_id": original_msg.get("conversation_id", ""),
            "message_id": original_msg.get("message_id", ""),
        },
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    # Forward image/file metadata (the encrypted blob is already on the server)
    for attachment in ("image", "file"):
        if original_msg.get(attachment):
            payload[attachment] = original_msg[attachment]

    plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))

    if self._is_group(target_members):
        return await self._send_group_message(target_conv_id, plaintext, target_members, payload)
    return await self._send_dm(target_conv_id, plaintext, target_members, payload)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Decrypt notification
|
|
# ------------------------------------------------------------------
|
|
|
|
def decrypt_notification(self, notif_data: dict) -> dict | None:
    """Decrypt a new_message notification. Returns parsed payload or None.

    Supports new multi-device format (device_entries array) and legacy flat format.

    On success the decrypted payload is augmented with ``conversation_id``,
    ``message_id``, ``sender_id``, ``read_by`` and ``delivered_to``. When the
    payload has a ``timestamp``, ``created_at`` is also set to a normalized
    copy of it. The payload is saved to the message cache. Messages from
    other users are queued in ``self._pending_self_encrypt``.

    Returns None when no device entry matches this device, when
    ``_decrypt_message`` returns None (a control message, which is cached as
    ``{"_control": True}``), or when any other error occurs (logged).

    Raises:
        IdentityKeyChanged: re-raised so the caller can show the key-change UI.
    """
    try:
        conv_id = notif_data.get("conversation_id", "")
        msg_id = notif_data.get("message_id", "")
        sender_id = notif_data.get("sender_id", "")
        sender_device_id = notif_data.get("sender_device_id")
        my_user_id = self.session["user_id"] if self.session else ""

        # Extract per-device encrypted content from device_entries or flat fields
        encrypted_content = ""
        nonce = ""
        ratchet_header = {}
        x3dh_header = None

        device_entries = notif_data.get("device_entries")
        if device_entries:
            # Multi-device format: pick entry matching our device_id or SELF_DEVICE_ID
            chosen = None
            self_entry = None
            for entry in device_entries:
                eid = entry.get("device_id", "")
                if eid == self.device_id:
                    chosen = entry
                    # NOTE(review): breaking here means a self entry (all-zero
                    # id) listed AFTER our device's entry is never seen, so the
                    # "prefer self-encrypted entry" rule below only applies when
                    # the self entry comes first. Confirm whether that ordering
                    # dependence is intended.
                    break
                # All-zero device id marks the copy encrypted with the sender's
                # self-encryption key (presumably SELF_DEVICE_ID elsewhere).
                if eid == "00000000-0000-0000-0000-000000000000":
                    self_entry = entry

            # If sender is us, prefer self-encrypted entry
            if sender_id == my_user_id:
                chosen = self_entry or chosen
            elif not chosen:
                chosen = self_entry

            if not chosen:
                self._logger.warning("No matching device_entry for device %s", self.device_id)
                return None

            # Per-entry headers win; fall back to the notification-level ones.
            encrypted_content = chosen.get("encrypted_content", "")
            nonce = chosen.get("nonce", "")
            ratchet_header = chosen.get("ratchet_header") or notif_data.get("ratchet_header", {})
            x3dh_header = chosen.get("x3dh_header") or notif_data.get("x3dh_header")
        else:
            # Legacy flat format
            encrypted_content = notif_data.get("encrypted_content", "")
            nonce = notif_data.get("nonce", "")
            ratchet_header = notif_data.get("ratchet_header", {})
            x3dh_header = notif_data.get("x3dh_header")

        # Same shape that _decrypt_raw_messages passes to _decrypt_message.
        msg_data = {
            "sender_id": sender_id,
            "sender_device_id": sender_device_id,
            "conversation_id": conv_id,
            "ratchet_header": ratchet_header,
            "encrypted_content": encrypted_content,
            "nonce": nonce,
            "x3dh_header": x3dh_header,
            "sender_chain_id": notif_data.get("sender_chain_id"),
            "sender_chain_n": notif_data.get("sender_chain_n"),
        }
        payload = self._decrypt_message(msg_data)
        if payload is None:
            # Cache control message so get_messages skips it
            if msg_id and conv_id:
                _save_message_to_cache(self.email, conv_id, msg_id, {"_control": True},
                                       cache_key=self._cache_key)
            return None
        payload["conversation_id"] = conv_id
        payload["message_id"] = msg_id
        payload["sender_id"] = sender_id
        # Use server-compatible timestamp (no timezone suffix) for cache consistency
        _ts = payload.get("timestamp", "")
        if _ts:
            # Strip timezone suffix (+00:00 or Z) to match server DATETIME format
            _ts = _ts.replace("+00:00", "").replace("Z", "")
            # Strip microseconds if present
            if "." in _ts:
                _ts = _ts[:_ts.index(".")]
            payload["created_at"] = _ts
        payload["read_by"] = []
        payload["delivered_to"] = []
        # Cache so get_messages doesn't re-decrypt (ratchet keys are one-time)
        if msg_id and conv_id:
            _save_message_to_cache(self.email, conv_id, msg_id, payload,
                                   cache_key=self._cache_key)
        # Queue self-encryption for received messages (multi-device access)
        if sender_id != my_user_id and msg_id:
            self._pending_self_encrypt.append({
                "message_id": msg_id,
                # Only the message content is re-encrypted; server-side
                # metadata fields are stripped.
                "payload": {k: v for k, v in payload.items()
                            if k not in ("conversation_id", "message_id", "created_at",
                                         "read_by", "delivered_to", "sender_id", "deleted")},
            })
        return payload
    except IdentityKeyChanged:
        raise  # Must propagate to caller for key-change UI
    except Exception as e:
        self._logger.warning("Failed to decrypt notification: %s", e)
        return None
|
|
|
|
# ------------------------------------------------------------------
|
|
# Delete message
|
|
# ------------------------------------------------------------------
|
|
|
|
async def delete_message(self, message_id: str) -> tuple[bool, str]:
    """Delete a message on the server.

    Returns ``(True, "Message deleted.")`` on success, otherwise
    ``(False, reason)`` with the server's message. When the error response
    carries no message, *reason* is ``"Failed"`` (the same fallback as
    ``react_message`` / ``pin_message``) instead of raising KeyError.
    """
    resp = await self.send_and_recv("delete_message", message_id=message_id)
    if resp["status"] == "ok":
        return True, "Message deleted."
    return False, (resp.get("data") or {}).get("message", "Failed")
|
|
|
|
# ------------------------------------------------------------------
|
|
# Image sharing
|
|
# ------------------------------------------------------------------
|
|
|
|
async def send_image(self, conv_id: str, image_path: str, members: list[dict],
                     reply_to: str | None = None) -> tuple[bool, str | dict]:
    """Encrypt and upload an image, then send it as a message.

    If the file exceeds MAX_IMAGE_BYTES (when that limit is positive), it is
    re-encoded as JPEG at decreasing quality, then progressively downscaled.
    A thumbnail of at most 200x200 is embedded in the payload. The full image
    is encrypted with ``aes_encrypt`` and uploaded in chunks; the returned
    key and IV travel inside the end-to-end encrypted message.

    Returns:
        ``(True, message_dict)`` on success, ``(False, error_text)`` otherwise.
    """
    await self.typing_stop(conv_id, force=True)
    try:
        from PIL import Image
        import io
    except ImportError:
        return False, "Pillow is required for image sharing. Install with: pip install Pillow"

    path = Path(image_path)
    if not path.exists():
        return False, "File not found."

    try:
        img = Image.open(path)
        img.load()
    except Exception as e:
        return False, f"Cannot open image: {e}"

    def _prepare_image(img, path):
        """PIL resize + thumbnail + AES encrypt (runs in a worker thread)."""
        image_bytes = path.read_bytes()

        # AES-GCM overhead: 16 bytes tag. Check raw size as proxy.
        if MAX_IMAGE_BYTES > 0 and len(image_bytes) + 16 > MAX_IMAGE_BYTES:
            if img.mode not in ("RGB", "L"):
                img = img.convert("RGB")
            for quality in (92, 85, 75, 60):
                buf = io.BytesIO()
                img.save(buf, format="JPEG", quality=quality)
                image_bytes = buf.getvalue()
                if len(image_bytes) + 16 <= MAX_IMAGE_BYTES:
                    break
            else:
                # Lower quality alone was not enough: progressively downscale.
                for max_dim in (3840, 2560, 1920, 1280):
                    if max(img.size) > max_dim:
                        img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)
                    buf = io.BytesIO()
                    img.save(buf, format="JPEG", quality=75)
                    image_bytes = buf.getvalue()
                    if len(image_bytes) + 16 <= MAX_IMAGE_BYTES:
                        break

        # Generate thumbnail
        thumb = img.copy()
        thumb.thumbnail((200, 200), Image.Resampling.LANCZOS)
        if thumb.mode not in ("RGB", "L"):
            thumb = thumb.convert("RGB")
        thumb_buf = io.BytesIO()
        thumb.save(thumb_buf, format="JPEG", quality=60)
        thumbnail_b64 = encode_binary(thumb_buf.getvalue())

        # Encrypt
        img_aes_key, img_iv, img_ct, img_tag = aes_encrypt(image_bytes)
        return image_bytes, thumbnail_b64, img_aes_key, img_iv, img_ct + img_tag

    loop = asyncio.get_running_loop()
    try:
        image_bytes, thumbnail_b64, img_aes_key, img_iv, encrypted_image = \
            await loop.run_in_executor(None, _prepare_image, img, path)
    finally:
        # Release any file handle Image.open may still hold (e.g. multi-frame images).
        img.close()

    file_id = str(uuid.uuid4())
    ok, err = await self._upload_encrypted_blob(conv_id, file_id, encrypted_image)
    if not ok:
        return False, err

    # Cache decrypted image locally so sender never re-downloads
    self._cache_media_plaintext(file_id, image_bytes)

    # Build message payload with image info
    image_info = {
        "file_id": file_id,
        "aes_key": encode_binary(img_aes_key),
        "iv": encode_binary(img_iv),
        "thumbnail": thumbnail_b64,
        "filename": path.name,
        "size": len(image_bytes),
    }
    payload = {
        "sender": self.username,
        "text": "",
        "reply_to": reply_to,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "image": image_info,
    }
    return await self._send_attachment_message(conv_id, payload, members, file_id)


def _attachment_error_text(self, resp: dict) -> str:
    """Return the server's error message from *resp*, or ``"Failed"`` if absent."""
    return (resp.get("data") or {}).get("message", "Failed")


async def _upload_encrypted_blob(self, conv_id: str, file_id: str, encrypted: bytes,
                                 file_type: str | None = None) -> tuple[bool, str]:
    """Run upload_image_start, the chunked upload, then upload_image_end for one blob.

    *file_type* is passed to ``upload_image_start`` only when given
    (``send_file`` uses ``"file"``). Returns ``(True, "")`` or
    ``(False, error_text)``.
    """
    start_params = {"conversation_id": conv_id, "file_id": file_id,
                    "file_size": len(encrypted)}
    if file_type is not None:
        start_params["file_type"] = file_type
    resp = await self.send_and_recv("upload_image_start", **start_params)
    if resp["status"] != "ok":
        return False, self._attachment_error_text(resp)

    ok, err = await self._pipelined_upload(file_id, encrypted)
    if not ok:
        return False, err

    resp = await self.send_and_recv("upload_image_end", file_id=file_id)
    if resp["status"] != "ok":
        return False, self._attachment_error_text(resp)
    return True, ""


def _cache_media_plaintext(self, file_id: str, data: bytes) -> None:
    """Best-effort write of decrypted media into the local media cache (mode 0600)."""
    cache_path = self._media_cache_path(file_id)
    if not cache_path:
        return
    try:
        # Create with 0600 up front so the plaintext is never briefly readable
        # with umask-default permissions before the chmod.
        fd = os.open(cache_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "wb") as fh:
            fh.write(data)
        os.chmod(cache_path, 0o600)  # the file may have pre-existed with wider permissions
    except OSError:
        pass


def _self_copy_recipient(self, plaintext: bytes, my_user_id: str) -> dict:
    """Build the recipient entry readable by our own account via the self-encryption key."""
    self_key = derive_self_encryption_key(self.identity_private)
    _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
    return {
        "user_id": my_user_id,
        "encrypted_content": encode_binary(self_ct + self_tag),
        "nonce": encode_binary(self_nonce),
        "ratchet_header": {"self": True},
    }


def _ratchet_recipient_entry(self, uid: str, ratchet, plaintext: bytes,
                             dev_id: str | None = None) -> tuple[dict, dict]:
    """Encrypt *plaintext* with *ratchet* and build one DM recipient entry.

    A pending ``_x3dh_header`` on the ratchet (first message of a new
    session) is attached to the entry once and then removed from the ratchet.
    Returns ``(entry, ratchet_header)``.
    """
    result = ratchet.encrypt(plaintext)
    x3dh_h = getattr(ratchet, "_x3dh_header", None)
    if x3dh_h:
        delattr(ratchet, "_x3dh_header")
    entry = {
        "user_id": uid,
        "encrypted_content": encode_binary(result["ciphertext"]),
        "nonce": encode_binary(result["nonce"]),
        "ratchet_header": result["header"],
    }
    if dev_id:
        entry["device_id"] = dev_id
    if x3dh_h:
        entry["x3dh_header"] = x3dh_h
    return entry, result["header"]


async def _encrypt_attachment_for_group(self, conv_id: str, plaintext: bytes,
                                        members: list[dict], my_user_id: str) -> tuple[list, dict]:
    """Encrypt *plaintext* once with our sender key for a group conversation.

    Loads the sender key from memory or disk, or creates and persists a new
    one. Pending sender-key distribution is caught up before encrypting.
    Returns the per-member recipient entries (excluding us) and the extra
    ``send_message`` header parameters.
    """
    sk = self.sender_key_states.get(conv_id)
    if not sk:
        sk = _load_sender_key_state(self.email, conv_id, self._local_key)
    if not sk:
        sk = SenderKeyState()
        _save_sender_key_state(self.email, conv_id, sk, self._local_key)
    self.sender_key_states[conv_id] = sk
    await self._catch_up_sender_key_distribution(conv_id, members, sk)

    result = sk.encrypt(plaintext)
    _save_sender_key_state(self.email, conv_id, sk, self._local_key)

    # Every member receives the same sender-key ciphertext.
    ciphertext_b64 = encode_binary(result["ciphertext"])
    nonce_b64 = encode_binary(result["nonce"])
    recipients = []
    for member in members:
        uid = member.get("user_id")
        if not uid or uid == my_user_id:
            continue
        recipients.append({"user_id": uid, "encrypted_content": ciphertext_b64,
                           "nonce": nonce_b64})

    header_params = {
        # Placeholder ratchet header: group messages use the sender chain below.
        "ratchet_header": {"dh_pub": "00" * 32, "n": 0, "pn": 0},
        "sender_chain_id": encode_binary(bytes.fromhex(result["chain_id"])),
        "sender_chain_n": result["n"],
    }
    return recipients, header_params


async def _encrypt_attachment_for_dm(self, plaintext: bytes, members: list[dict],
                                     my_user_id: str) -> tuple[list, dict]:
    """Encrypt *plaintext* per device for every other member of a DM.

    Falls back to a single legacy (device-less) session when a member has no
    device bundles or fetching them fails. Returns the recipient entries and
    the ``send_message`` header parameters (the first entry's ratchet header).
    """
    recipients = []
    first_rh = None
    for member in members:
        uid = member.get("user_id")
        if not uid or uid == my_user_id:
            continue

        try:
            device_bundles = await self._get_device_bundles(uid)
        except Exception:
            device_bundles = []

        if not device_bundles:
            # Fallback: legacy single-device
            ratchet = await self._get_or_create_session(uid)
            entry, header = self._ratchet_recipient_entry(uid, ratchet, plaintext)
            recipients.append(entry)
            if first_rh is None:
                first_rh = header
            _save_session(self.email, uid, ratchet, self._local_key)
            continue

        for bundle in device_bundles:
            dev_id = bundle.get("device_id")
            ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
                                                        bundle=bundle)
            entry, header = self._ratchet_recipient_entry(uid, ratchet, plaintext, dev_id)
            recipients.append(entry)
            if first_rh is None:
                first_rh = header
            _save_session(self.email, uid, ratchet, self._local_key,
                          peer_device_id=dev_id)

    return recipients, {"ratchet_header": first_rh}


async def _send_attachment_message(self, conv_id: str, payload: dict, members: list[dict],
                                   file_id: str) -> tuple[bool, str | dict]:
    """Encrypt an attachment *payload* for all recipients and post it via send_message.

    Group conversations use our sender key; DMs use a ratchet per recipient
    device. A self-encrypted copy for our own account is always included.
    On success the message is cached locally and ``(True, message_dict)``
    is returned; otherwise ``(False, error_text)``.
    """
    plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
    my_user_id = self.session["user_id"]

    if self._is_group(members):
        recipients, header_params = await self._encrypt_attachment_for_group(
            conv_id, plaintext, members, my_user_id)
    else:
        recipients, header_params = await self._encrypt_attachment_for_dm(
            plaintext, members, my_user_id)
    recipients.append(self._self_copy_recipient(plaintext, my_user_id))

    resp = await self.send_and_recv(
        "send_message",
        conversation_id=conv_id,
        recipients=recipients,
        image_file_id=file_id,
        **header_params,
    )
    if resp["status"] != "ok":
        return False, self._attachment_error_text(resp)

    msg_data = resp.get("data", {})
    result_msg = {
        **payload,
        "message_id": msg_data.get("message_id", ""),
        "created_at": msg_data.get("created_at", ""),
        "sender_id": my_user_id,
        "conversation_id": conv_id,
        "read_by": [],
    }
    _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg,
                           cache_key=self._cache_key)
    return True, result_msg


async def _pipelined_upload(self, file_id: str, encrypted_data: bytes) -> tuple[bool, str]:
    """Upload encrypted data in pipelined chunks (no per-chunk round-trip wait).

    All ``upload_image_chunk`` requests of IMAGE_CHUNK_SIZE bytes are sent
    first. Their responses are then awaited in order, with a 30 s timeout
    each. Every response future registered in ``self._pending`` is removed
    again on all exit paths, including early failures, so nothing leaks.

    Returns ``(True, "")`` on success or ``(False, error_text)``.
    """
    loop = asyncio.get_running_loop()
    chunk_futures = []
    try:
        for upload_offset in range(0, len(encrypted_data), IMAGE_CHUNK_SIZE):
            chunk = encrypted_data[upload_offset:upload_offset + IMAGE_CHUNK_SIZE]
            request_id = str(uuid.uuid4())
            fut = loop.create_future()
            self._pending[request_id] = fut
            chunk_futures.append((request_id, fut))
            try:
                await self.writer.send_request(
                    "upload_image_chunk",
                    request_id=request_id,
                    file_id=file_id,
                    data=encode_binary(chunk),
                )
            except Exception as e:
                return False, f"Upload failed: {e}"

        for _request_id, fut in chunk_futures:
            try:
                resp = await asyncio.wait_for(fut, timeout=30.0)
            except (asyncio.TimeoutError, ConnectionError):
                return False, "Upload chunk timed out."
            if resp["status"] != "ok":
                return False, self._attachment_error_text(resp)
        return True, ""
    finally:
        for request_id, _fut in chunk_futures:
            self._pending.pop(request_id, None)


async def send_file(self, conv_id: str, file_path: str, members: list[dict],
                    reply_to: str | None = None) -> tuple[bool, str | dict]:
    """Encrypt and upload a file, then send it as a message.

    The file is read and encrypted off the event loop (AES-GCM via
    ``aes_encrypt``). It is uploaded through the image-upload endpoints with
    ``file_type="file"``, and its key, IV, name, size and guessed MIME type
    travel inside the end-to-end encrypted message.

    Returns:
        ``(True, message_dict)`` on success, ``(False, error_text)`` otherwise.
    """
    await self.typing_stop(conv_id, force=True)
    import mimetypes

    path = Path(file_path)
    if not path.exists():
        return False, "File not found."

    loop = asyncio.get_running_loop()
    try:
        # Read in a worker thread: attachments can be large.
        file_bytes = await loop.run_in_executor(None, path.read_bytes)
    except Exception as e:
        return False, f"Cannot read file: {e}"

    mime_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"

    # Encrypt file with AES-256-GCM (offload to thread)
    file_aes_key, file_iv, file_ct, file_tag = await loop.run_in_executor(
        None, aes_encrypt, file_bytes)

    file_id = str(uuid.uuid4())
    # Chunked upload (reuse image upload infrastructure with file_type="file")
    ok, err = await self._upload_encrypted_blob(conv_id, file_id, file_ct + file_tag,
                                                file_type="file")
    if not ok:
        return False, err

    # Cache decrypted file locally so sender never re-downloads
    self._cache_media_plaintext(file_id, file_bytes)

    # Build message payload with file info
    file_info = {
        "file_id": file_id,
        "aes_key": encode_binary(file_aes_key),
        "iv": encode_binary(file_iv),
        "filename": path.name,
        "size": len(file_bytes),
        "mime_type": mime_type,
    }
    payload = {
        "sender": self.username,
        "text": "",
        "reply_to": reply_to,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "file": file_info,
    }
    return await self._send_attachment_message(conv_id, payload, members, file_id)
|
|
|
|
async def download_file(self, file_id: str, file_info: dict) -> bytes | None:
|
|
"""Download and decrypt a file. Returns decrypted file bytes or None."""
|
|
return await self._download_and_decrypt(file_id, file_info)
|
|
|
|
def _media_cache_path(self, file_id: str) -> Path | None:
|
|
"""Return path for cached decrypted media file, or None if no email."""
|
|
if not self.email:
|
|
return None
|
|
d = get_key_dir(self.email) / "media_cache"
|
|
d.mkdir(parents=True, exist_ok=True)
|
|
try:
|
|
os.chmod(d, 0o700)
|
|
except OSError:
|
|
pass
|
|
return d / f"{file_id}.bin"
|
|
|
|
async def _stream_download(self, file_id: str) -> bytes | None:
|
|
"""Download file via streaming (single request, server sends all chunks).
|
|
|
|
Falls back to legacy pipelined download if server doesn't support streaming.
|
|
"""
|
|
request_id = str(uuid.uuid4())
|
|
q: asyncio.Queue = asyncio.Queue()
|
|
self._pending[request_id] = q
|
|
try:
|
|
await self.writer.send_request(
|
|
"download_stream", request_id=request_id, file_id=file_id,
|
|
)
|
|
except Exception:
|
|
self._pending.pop(request_id, None)
|
|
return None
|
|
|
|
chunks: dict[int, bytes] = {}
|
|
try:
|
|
while True:
|
|
try:
|
|
resp = await asyncio.wait_for(q.get(), timeout=60.0)
|
|
except (asyncio.TimeoutError, ConnectionError):
|
|
return None
|
|
if resp.get("status") != "ok":
|
|
# Server may not support download_stream — fall back
|
|
return await self._legacy_download(file_id)
|
|
data = resp["data"]
|
|
chunk_data = decode_binary(data["data"])
|
|
chunks[data["offset"]] = chunk_data
|
|
if data.get("done"):
|
|
break
|
|
finally:
|
|
self._pending.pop(request_id, None)
|
|
|
|
# Reassemble in order
|
|
parts = []
|
|
for off in sorted(chunks.keys()):
|
|
parts.append(chunks[off])
|
|
return b"".join(parts)
|
|
|
|
async def _legacy_download(self, file_id: str) -> bytes | None:
|
|
"""Fallback: download file chunk by chunk (for older servers)."""
|
|
resp = await self.send_and_recv("download_image", file_id=file_id, offset=0)
|
|
if resp["status"] != "ok":
|
|
return None
|
|
data = resp["data"]
|
|
first_chunk = decode_binary(data["data"])
|
|
if data.get("done"):
|
|
return first_chunk
|
|
|
|
chunk_size = len(first_chunk)
|
|
chunks = {0: first_chunk}
|
|
|
|
# Pipeline remaining chunks
|
|
futures = []
|
|
offset = chunk_size
|
|
total_size = data.get("total_size", 0)
|
|
# Calculate how many chunks we need
|
|
while offset < total_size:
|
|
request_id = str(uuid.uuid4())
|
|
loop = asyncio.get_running_loop()
|
|
fut = loop.create_future()
|
|
self._pending[request_id] = fut
|
|
try:
|
|
await self.writer.send_request(
|
|
"download_image", request_id=request_id,
|
|
file_id=file_id, offset=offset,
|
|
)
|
|
except Exception:
|
|
self._pending.pop(request_id, None)
|
|
return None
|
|
futures.append((request_id, offset, fut))
|
|
offset += chunk_size
|
|
|
|
for request_id, off, fut in futures:
|
|
try:
|
|
resp = await asyncio.wait_for(fut, timeout=30.0)
|
|
except (asyncio.TimeoutError, ConnectionError):
|
|
self._pending.pop(request_id, None)
|
|
return None
|
|
finally:
|
|
self._pending.pop(request_id, None)
|
|
if resp["status"] != "ok":
|
|
return None
|
|
chunk_data = decode_binary(resp["data"]["data"])
|
|
chunks[off] = chunk_data
|
|
|
|
parts = []
|
|
for off in sorted(chunks.keys()):
|
|
parts.append(chunks[off])
|
|
return b"".join(parts)
|
|
|
|
async def _download_and_decrypt(self, file_id: str, info: dict) -> bytes | None:
|
|
"""Download, decrypt, and cache a media file. Used by both image and file download."""
|
|
# Check local cache first
|
|
cache_path = self._media_cache_path(file_id)
|
|
if cache_path and cache_path.exists():
|
|
try:
|
|
return cache_path.read_bytes()
|
|
except OSError:
|
|
pass
|
|
|
|
encrypted_data = await self._stream_download(file_id)
|
|
if not encrypted_data or len(encrypted_data) < 16:
|
|
return None
|
|
|
|
ciphertext = encrypted_data[:-16]
|
|
tag = encrypted_data[-16:]
|
|
|
|
try:
|
|
aes_key = decode_binary(info["aes_key"])
|
|
iv = decode_binary(info["iv"])
|
|
decrypted = aes_decrypt(aes_key, iv, ciphertext, tag)
|
|
except Exception:
|
|
return None
|
|
|
|
# Cache decrypted result to disk
|
|
if cache_path and decrypted:
|
|
try:
|
|
cache_path.write_bytes(decrypted)
|
|
os.chmod(cache_path, 0o600)
|
|
except OSError:
|
|
pass
|
|
|
|
return decrypted
|
|
|
|
async def download_image(self, file_id: str, image_info: dict) -> bytes | None:
|
|
"""Download and decrypt an image. Returns decrypted image bytes or None."""
|
|
return await self._download_and_decrypt(file_id, image_info)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Re-encrypt history (for device pairing)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def reencrypt_history(self):
    """Re-encrypt cached message history under the self-encryption key.

    After device pairing, the new device shares the same identity key but
    cannot decrypt old messages, because Double Ratchet message keys are
    one-time use. This re-encrypts cached messages with the key returned by
    ``derive_self_encryption_key(self.identity_private)``. It then uploads them
    via ``reencrypt_messages`` so they can be read with that key.

    Phase 1 pages through every conversation with ``get_messages`` so that
    messages never opened on this device reach the local cache. An error aborts
    Phase 1, but Phase 2 still runs.

    Phase 2 rebuilds each cached payload, encrypts it with the self key, and
    uploads the results in batches of ``PAIRING_REENCRYPT_BATCH_SIZE``.

    Progress strings go to ``self._reencrypt_progress_cb`` when it is set.
    Failed batches are logged, not raised. The final progress message reports
    how many messages were actually accepted.
    Does nothing without an identity key or a session.
    """
    if not self.identity_private or not self.session:
        return

    self_key = derive_self_encryption_key(self.identity_private)

    # Phase 1: Fetch & decrypt all messages to populate cache
    # (messages the old device never opened won't be in cache yet)
    try:
        convs = list(await self.list_conversations())
        random.shuffle(convs)
        total_convs = len(convs)
        for ci, conv in enumerate(convs):
            cid = conv.get("id") or conv.get("conversation_id")
            if not cid:
                continue
            if self._reencrypt_progress_cb:
                self._reencrypt_progress_cb(
                    f"Fetching messages: {ci + 1}/{total_convs} conversations..."
                )
            offset = 0
            while True:
                msgs = await self.get_messages(cid, limit=200, offset=offset)
                # A short (or empty) page means this conversation is exhausted.
                if not msgs or len(msgs) < 200:
                    break
                offset += len(msgs)
                await asyncio.sleep(random.uniform(*PAIRING_REENCRYPT_INTER_FETCH_DELAY_RANGE))
    except Exception as e:
        self._logger.warning("Failed to fetch messages for re-encryption: %s", e)

    # Phase 2: Read cache and re-encrypt
    cache_dir = get_key_dir(self.email) / "message_cache"
    if not cache_dir.exists():
        self._logger.info("No message cache to re-encrypt.")
        return

    all_updates = []
    # A conversation may have both a .json and a .bin cache file; the set
    # de-duplicates them by stem.
    conv_ids = {f.stem for f in cache_dir.iterdir() if f.suffix in (".json", ".bin")}

    total_files = len(conv_ids)
    conv_order = sorted(conv_ids)
    random.shuffle(conv_order)
    for i, conv_id in enumerate(conv_order):
        cache = _load_message_cache(self.email, conv_id, self._cache_key)
        if not cache:
            continue

        items = list(cache.items())
        random.shuffle(items)
        for msg_id, entry in items:
            # Skip control messages (sender key distribution)
            if entry.get("_control"):
                continue
            # Skip entries with no useful content
            text = entry.get("text", "")
            if not text and not entry.get("image") and not entry.get("file"):
                continue

            # Rebuild plaintext from cached payload, dropping server/cache metadata.
            payload = {k: v for k, v in entry.items()
                       if k not in ("message_id", "created_at", "read_by", "sender_id", "deleted")}
            plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))

            # Re-encrypt with self-encryption key
            _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key)
            all_updates.append({
                "message_id": msg_id,
                "encrypted_content": encode_binary(ct + tag),
                "nonce": encode_binary(nonce),
            })

        if self._reencrypt_progress_cb:
            self._reencrypt_progress_cb(f"Re-encrypting history: {i + 1}/{total_files} conversations...")

    if not all_updates:
        self._logger.info("No messages to re-encrypt.")
        return

    # Upload in batches of PAIRING_REENCRYPT_BATCH_SIZE, counting only the
    # messages the server actually accepted.
    batch_size = PAIRING_REENCRYPT_BATCH_SIZE
    random.shuffle(all_updates)
    total = len(all_updates)
    uploaded = 0
    for start in range(0, total, batch_size):
        batch = all_updates[start:start + batch_size]
        resp = await self.send_and_recv("reencrypt_messages", updates=batch)
        if resp["status"] != "ok":
            self._logger.warning("Re-encrypt batch failed: %s", resp.get("data", {}).get("message", ""))
        else:
            uploaded += len(batch)
            self._logger.info("Re-encrypted %d/%d messages.", uploaded, total)
        if start + batch_size < total:
            await asyncio.sleep(random.uniform(*PAIRING_REENCRYPT_INTER_BATCH_DELAY_RANGE))

    if self._reencrypt_progress_cb:
        if uploaded == total:
            self._reencrypt_progress_cb(f"Re-encryption complete: {total} messages uploaded.")
        else:
            self._reencrypt_progress_cb(
                f"Re-encryption finished with errors: {uploaded}/{total} messages uploaded."
            )
|
|
|
|
# ------------------------------------------------------------------
|
|
# User Profiles
|
|
# ------------------------------------------------------------------
|
|
|
|
async def get_profile(self, user_id: str | None = None) -> dict | None:
|
|
"""Get user profile. If user_id is None, returns own profile."""
|
|
kwargs = {}
|
|
if user_id:
|
|
kwargs["user_id"] = user_id
|
|
resp = await self.send_and_recv("get_profile", **kwargs)
|
|
if resp["status"] == "ok":
|
|
return resp["data"]
|
|
return None
|
|
|
|
async def update_profile(self, **fields) -> tuple[bool, str]:
    """Update the signed-in user's profile fields (phone, location, *_visible).

    All keyword arguments are forwarded unchanged to ``update_profile``.

    Returns:
        ``(True, "OK")`` on success, otherwise ``(False, <server message>)``.
    """
    response = await self.send_and_recv("update_profile", **fields)
    if response["status"] != "ok":
        return False, response["data"]["message"]
    return True, "OK"
|
|
|
|
async def update_avatar(self, image_data: bytes) -> tuple[bool, str]:
    """Upload a new avatar image for the signed-in user.

    The raw bytes are encoded with ``encode_binary`` before sending.

    Returns:
        ``(True, avatar_file)`` on success. ``avatar_file`` is "" if the server
        omits it. On failure, ``(False, <server message>)``.
    """
    encoded = encode_binary(image_data)
    response = await self.send_and_recv("update_avatar", data=encoded)
    if response["status"] != "ok":
        return False, response["data"]["message"]
    return True, response["data"].get("avatar_file", "")
|
|
|
|
async def get_avatar(self, user_id: str) -> bytes | None:
|
|
"""Download avatar for a user."""
|
|
resp = await self.send_and_recv("get_avatar", 10.0, user_id=user_id)
|
|
if resp["status"] == "ok":
|
|
return decode_binary(resp["data"]["data"])
|
|
return None
|
|
|
|
async def update_group_avatar(self, conv_id: str, image_data: bytes) -> tuple[bool, str]:
    """Upload a new avatar image for the group conversation ``conv_id``.

    The raw bytes are encoded with ``encode_binary`` before sending.

    Returns:
        ``(True, avatar_file)`` on success. ``avatar_file`` is "" if the server
        omits it. On failure, ``(False, <server message>)``.
    """
    encoded = encode_binary(image_data)
    response = await self.send_and_recv(
        "update_group_avatar", conversation_id=conv_id, data=encoded,
    )
    if response["status"] != "ok":
        return False, response["data"]["message"]
    return True, response["data"].get("avatar_file", "")
|
|
|
|
async def get_group_avatar(self, conv_id: str) -> bytes | None:
|
|
"""Download avatar for a group conversation."""
|
|
resp = await self.send_and_recv("get_group_avatar", 10.0, conversation_id=conv_id)
|
|
if resp["status"] == "ok":
|
|
return decode_binary(resp["data"]["data"])
|
|
return None
|
|
|
|
# ------------------------------------------------------------------
|
|
# Cleanup
|
|
# ------------------------------------------------------------------
|
|
|
|
async def close(self):
    """Tear down the client's connection state.

    Marks the client disconnected, cancels every pending typing-stop timer,
    cancels the background listener task and closes the raw writer, each only
    if present. It neither awaits the cancelled listener nor waits for the
    writer to finish closing.
    """
    self.connected = False
    # Snapshot the keys: _cancel_typing_timer may remove entries as it goes.
    for cid in list(self._typing_stop_tasks):
        self._cancel_typing_timer(cid)
    listener = self._listener_task
    if listener:
        listener.cancel()
    writer = self.raw_writer
    if writer:
        writer.close()
|
|
|
|
async def reconnect(self):
|
|
"""Close existing connection and re-establish: connect + re-login using in-memory keys."""
|
|
try:
|
|
await self.close()
|
|
except Exception:
|
|
pass
|
|
# Reset reader/writer but keep keys and sessions
|
|
self.reader = None
|
|
self.writer = None
|
|
self.raw_writer = None
|
|
self._listener_task = None
|
|
self._pending.clear()
|
|
self.login_rejected = False
|
|
# Drain queues
|
|
while not self._response_queue.empty():
|
|
try:
|
|
self._response_queue.get_nowait()
|
|
except Exception:
|
|
break
|
|
while not self._notification_queue.empty():
|
|
try:
|
|
self._notification_queue.get_nowait()
|
|
except Exception:
|
|
break
|
|
await self.connect()
|
|
self._listener_task = asyncio.create_task(self._background_listener())
|
|
if self.email and self.private_key:
|
|
# RSA challenge-response login (keys already in memory)
|
|
start = await self.send_and_recv("login_start", email=self.email)
|
|
if start["status"] == "ok":
|
|
challenge = decode_binary(start["data"]["challenge"])
|
|
signature = rsa_sign(self.private_key, challenge)
|
|
login_kwargs = {
|
|
"email": self.email,
|
|
"signature": encode_binary(signature),
|
|
"client_version": VERSION,
|
|
}
|
|
if self.device_id:
|
|
login_kwargs["device_id"] = self.device_id
|
|
finish = await self.send_and_recv("login_finish", **login_kwargs)
|
|
if finish["status"] == "ok":
|
|
self.session = finish["data"]
|
|
asyncio.create_task(self._ensure_prekeys())
|
|
else:
|
|
# Login rejected — keys were likely rotated on another device
|
|
self.session = None
|
|
self.connected = False
|
|
self.login_rejected = True
|