Files
Kecalek_python/chat_core.py
2026-03-11 16:54:14 +01:00

3482 lines
144 KiB
Python

"""Shared network layer and ChatClient class for CLI and GUI clients.
Uses X3DH + Double Ratchet for message encryption, Sender Keys for groups.
RSA retained for login challenge-response only.
"""
import asyncio
import hashlib
import json
import logging
import os
import ssl
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
from crypto_utils import (
# RSA (login only)
generate_rsa_keypair,
serialize_private_key,
serialize_public_key,
load_private_key,
load_public_key,
rsa_sign,
# Ed25519
generate_identity_keypair,
serialize_ed25519_private,
serialize_ed25519_private_raw,
serialize_ed25519_public,
load_ed25519_private,
load_ed25519_public,
ed25519_sign,
# X25519
generate_x25519_keypair,
serialize_x25519_private,
serialize_x25519_public,
load_x25519_private,
load_x25519_public,
# X3DH
generate_signed_prekey,
generate_one_time_prekeys,
x3dh_initiate,
x3dh_respond,
# Double Ratchet
DoubleRatchet,
# Sender Keys
SenderKeyState,
# AES
aes_encrypt,
aes_decrypt,
# Self-encryption
derive_self_encryption_key,
# Local storage encryption
derive_local_storage_key,
# Contact verification
compute_fingerprint,
format_fingerprint,
compute_safety_number,
encode_verification_qr,
decode_verification_qr,
# Message padding
pad_plaintext,
unpad_plaintext,
)
from protocol import (
VERSION,
ProtocolReader,
ProtocolWriter,
encode_binary,
decode_binary,
build_request,
MAX_MESSAGE_BYTES,
IMAGE_CHUNK_SIZE,
)
# Root directory holding one sub-directory of key material / local state per account.
KEY_DIR = Path.home().joinpath(".encrypted_chat")
# One-time prekey (OPK) pool sizing; presumably the pool is refilled with a
# batch of OPK_BATCH_SIZE when it drops below the threshold (usage not shown here).
OPK_REPLENISH_THRESHOLD = 20
OPK_BATCH_SIZE = 50
# Signed prekey (SPK) rotation interval, in days.
SPK_ROTATION_DAYS = 7
def _encrypt_local(data: bytes, key: bytes) -> bytes:
    """Seal *data* with AES-256-GCM under *key* for on-disk storage.

    The blob is laid out as ``nonce (12) || tag (16) || ciphertext``, which is
    exactly what ``_decrypt_local`` slices apart again.
    """
    sealed = aes_encrypt(data, key=key)
    nonce, ciphertext, tag = sealed[1], sealed[2], sealed[3]
    return nonce + tag + ciphertext
def _decrypt_local(raw: bytes, key: bytes) -> bytes:
    """Open a blob produced by ``_encrypt_local``.

    Args:
        raw: ``nonce (12) || tag (16) || ciphertext`` as read from disk.
        key: The AES key that was used to seal the blob.

    Returns:
        The decrypted plaintext.

    Raises:
        ValueError: If *raw* is too short to even hold a nonce and a tag
            (e.g. an empty or truncated file). Errors from the cipher itself
            (wrong key, tampered data) propagate from ``aes_decrypt``.
    """
    nonce_len, tag_len = 12, 16
    header_len = nonce_len + tag_len
    if len(raw) < header_len:
        # Fail clearly instead of handing a truncated nonce/tag to the cipher.
        raise ValueError(
            f"Encrypted blob too short ({len(raw)} bytes, need at least {header_len})"
        )
    nonce = raw[:nonce_len]
    tag = raw[nonce_len:header_len]
    ct = raw[header_len:]
    return aes_decrypt(key, nonce, ct, tag)
def get_key_dir(email: str) -> Path:
    """Return the private per-account directory for *email*, creating it if needed.

    Both the shared root ``KEY_DIR`` and the per-account sub-directory are
    restricted to the owner (0o700). The root has to be restricted as well:
    each sub-directory is named after an email address, so a world-listable
    root would let other local users see which accounts exist on this machine.
    """
    d = KEY_DIR / email
    d.mkdir(parents=True, exist_ok=True)
    os.chmod(KEY_DIR, 0o700)
    os.chmod(d, 0o700)
    return d
# ---------------------------------------------------------------------------
# RSA key storage (login only — unchanged interface)
# ---------------------------------------------------------------------------
def save_keys(email: str, private_key, public_key, password: bytes | None = None):
    """Persist the RSA login keypair for *email*.

    The private key is serialized via ``serialize_private_key`` (with
    *password* passed through) into ``private.pem``, which is then restricted
    to 0o600; the public key is written next to it as ``public.pem``.
    """
    key_dir = get_key_dir(email)
    priv_file = key_dir / "private.pem"
    pub_file = key_dir / "public.pem"
    priv_file.write_bytes(serialize_private_key(private_key, password=password))
    pub_file.write_bytes(serialize_public_key(public_key))
    os.chmod(priv_file, 0o600)
def load_keys(email: str, password: bytes | None = None):
    """Load the RSA login keypair for *email*.

    Returns ``(private_key, public_key, error)``. On failure both keys are
    ``None`` and *error* is a human-readable message.

    If the private key does not open with *password* but does open without
    one, it is treated as a legacy unencrypted key. When a password was
    supplied, the keypair is then re-saved encrypted with it.
    """
    key_dir = get_key_dir(email)
    priv_file = key_dir / "private.pem"
    pub_file = key_dir / "public.pem"
    if not priv_file.exists():
        return None, None, "No local keys found."
    pem_bytes = priv_file.read_bytes()
    try:
        priv = load_private_key(pem_bytes, password=password)
    except Exception:
        # Fallback: the stored key may predate password protection.
        try:
            priv = load_private_key(pem_bytes, password=None)
            if password:
                # Migrate: re-encrypt the legacy key with the supplied password.
                legacy_pub = load_public_key(pub_file.read_bytes())
                save_keys(email, priv, legacy_pub, password=password)
        except Exception:
            return None, None, "Invalid or missing password."
    return priv, load_public_key(pub_file.read_bytes()), None
# ---------------------------------------------------------------------------
# Identity + prekey storage
# ---------------------------------------------------------------------------
def _save_identity_keys(email: str, ed_priv, ed_pub, password: bytes | None = None):
    """Write the Ed25519 identity keypair for *email* into its key directory.

    With a (truthy) *password*, the private key goes through
    ``serialize_ed25519_private``; otherwise the raw serializer is used. The
    private-key file is restricted to 0o600 afterwards.
    """
    key_dir = get_key_dir(email)
    priv_file = key_dir / "identity_private.bin"
    priv_blob = (
        serialize_ed25519_private(ed_priv, password=password)
        if password
        else serialize_ed25519_private_raw(ed_priv)
    )
    priv_file.write_bytes(priv_blob)
    (key_dir / "identity_public.bin").write_bytes(serialize_ed25519_public(ed_pub))
    os.chmod(priv_file, 0o600)
def _load_identity_keys(email: str, password: bytes | None = None):
    """Return the ``(private, public)`` Ed25519 identity keys stored for *email*.

    Returns ``(None, None)`` when no private-key file exists. Errors raised by
    the key loaders propagate to the caller.
    """
    key_dir = get_key_dir(email)
    priv_file = key_dir / "identity_private.bin"
    if not priv_file.exists():
        return None, None
    priv = load_ed25519_private(priv_file.read_bytes(), password=password)
    pub = load_ed25519_public((key_dir / "identity_public.bin").read_bytes())
    return priv, pub
def _save_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None):
    """Persist the current signed prekey (private half) and its id.

    The key bytes are sealed with ``_encrypt_local`` when *local_key* is given,
    and written as-is otherwise. The key file is restricted to 0o600.
    """
    key_dir = get_key_dir(email)
    key_file = key_dir / "spk_private.bin"
    blob = serialize_x25519_private(spk_priv)
    if local_key:
        blob = _encrypt_local(blob, local_key)
    key_file.write_bytes(blob)
    (key_dir / "spk_id.txt").write_text(spk_id)
    os.chmod(key_file, 0o600)
def _load_spk(email: str, local_key: bytes | None = None):
    """Load the current signed prekey and its id for *email*.

    Returns ``(private_key, spk_id)``, or ``(None, None)`` when no key file
    exists. ``spk_id`` is ``""`` if the id file is missing.

    With *local_key*, the file is decrypted first. If that fails, the file is
    assumed to be a legacy plaintext key and is re-saved encrypted (one-time
    migration). Files that already decrypt are not rewritten, which avoids a
    needless rewrite of key material on every load.
    """
    d = get_key_dir(email)
    priv_path = d / "spk_private.bin"
    id_path = d / "spk_id.txt"
    if not priv_path.exists():
        return None, None
    raw = priv_path.read_bytes()
    needs_migration = False
    if local_key:
        try:
            raw = _decrypt_local(raw, local_key)
        except Exception:
            # Not decryptable with local_key: treat it as legacy plaintext.
            # If it is really foreign ciphertext, load_x25519_private raises
            # below and nothing is rewritten.
            needs_migration = True
    priv = load_x25519_private(raw)
    spk_id = id_path.read_text().strip() if id_path.exists() else ""
    if needs_migration:
        _save_spk(email, priv, spk_id, local_key)
    return priv, spk_id
def _save_prev_spk(email: str, spk_priv, spk_id: str, local_key: bytes | None = None):
    """Keep the outgoing signed prekey for a grace period after rotation.

    In-flight X3DH handshakes may still reference it. It is stored like the
    current SPK, but under ``prev_spk_private.bin`` / ``prev_spk_id.txt``.
    """
    key_dir = get_key_dir(email)
    key_file = key_dir / "prev_spk_private.bin"
    serialized = serialize_x25519_private(spk_priv)
    key_file.write_bytes(_encrypt_local(serialized, local_key) if local_key else serialized)
    (key_dir / "prev_spk_id.txt").write_text(spk_id)
    os.chmod(key_file, 0o600)
def _load_prev_spk(email: str, local_key: bytes | None = None):
    """Load the previous signed prekey kept for the rotation grace period.

    In-flight X3DH handshakes may still reference the old SPK. Returns
    ``(private_key, spk_id)``, or ``(None, None)`` when none is stored.

    With *local_key*, a file that fails to decrypt is treated as legacy
    plaintext and re-saved encrypted once. Files that already decrypt are
    left untouched rather than being rewritten on every load.
    """
    d = get_key_dir(email)
    priv_path = d / "prev_spk_private.bin"
    id_path = d / "prev_spk_id.txt"
    if not priv_path.exists():
        return None, None
    raw = priv_path.read_bytes()
    needs_migration = False
    if local_key:
        try:
            raw = _decrypt_local(raw, local_key)
        except Exception:
            # Legacy plaintext file (or foreign data, which fails to load below).
            needs_migration = True
    priv = load_x25519_private(raw)
    spk_id = id_path.read_text().strip() if id_path.exists() else ""
    if needs_migration:
        _save_prev_spk(email, priv, spk_id, local_key)
    return priv, spk_id
def _save_opk_private(email: str, opk_id: str, opk_priv, local_key: bytes | None = None):
    """Store one one-time prekey's private half as ``opk_private/<opk_id>.bin``.

    The sub-directory is kept at 0o700 and the file at 0o600. The key bytes
    are sealed with ``_encrypt_local`` when *local_key* is given.
    """
    opk_dir = get_key_dir(email) / "opk_private"
    opk_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(opk_dir, 0o700)
    key_bytes = serialize_x25519_private(opk_priv)
    payload = _encrypt_local(key_bytes, local_key) if local_key else key_bytes
    target = opk_dir / f"{opk_id}.bin"
    target.write_bytes(payload)
    os.chmod(target, 0o600)
def _load_opk_private(email: str, opk_id: str, local_key: bytes | None = None):
    """Load the private half of one-time prekey *opk_id*, or ``None`` if absent.

    With *local_key*, a file that fails to decrypt is treated as legacy
    plaintext and re-saved encrypted (one-time migration). Files that already
    decrypt are not rewritten.
    """
    p = get_key_dir(email) / "opk_private" / f"{opk_id}.bin"
    if not p.exists():
        return None
    raw = p.read_bytes()
    needs_migration = False
    if local_key:
        try:
            raw = _decrypt_local(raw, local_key)
        except Exception:
            # Legacy plaintext file (or foreign data, which fails to load below).
            needs_migration = True
    priv = load_x25519_private(raw)
    if needs_migration:
        _save_opk_private(email, opk_id, priv, local_key)
    return priv
def _secure_delete(p: Path, chunk_size: int = 64 * 1024):
    """Overwrite *p* with random bytes, fsync, then delete it (best-effort wipe).

    The overwrite is written in chunks, so wiping a large file (e.g. a big
    message cache) does not allocate the whole file size in memory at once.
    If anything fails during the wipe, the function falls back to a plain
    unlink; errors from that are ignored too, so it never raises.

    NOTE: on SSDs and on copy-on-write or journaling filesystems, an in-place
    overwrite may not destroy the old data blocks. This is a best-effort
    measure only.

    Args:
        p: File to wipe. A missing file is a no-op.
        chunk_size: Bytes of random data per write call (values < 1 act as 1).
    """
    try:
        if not p.exists():
            return
        remaining = p.stat().st_size
        if remaining > 0:
            step = max(1, chunk_size)
            with open(p, "r+b") as f:
                while remaining > 0:
                    n = min(step, remaining)
                    f.write(os.urandom(n))
                    remaining -= n
                f.flush()
                os.fsync(f.fileno())
        p.unlink()
    except Exception:
        # Deliberate best-effort: at least try to remove the file.
        try:
            p.unlink(missing_ok=True)
        except Exception:
            pass
def _delete_opk_private(email: str, opk_id: str):
    """Securely wipe the stored private half of one-time prekey *opk_id*."""
    _secure_delete(get_key_dir(email) / "opk_private" / f"{opk_id}.bin")
def _save_device_id(email: str, device_id: str):
    """Persist this device's id for *email* in ``device_id.txt`` (mode 0o600)."""
    id_file = get_key_dir(email) / "device_id.txt"
    id_file.write_text(device_id)
    os.chmod(id_file, 0o600)
def _load_device_id(email: str) -> str | None:
    """Return the stored device id for *email*, or ``None`` if missing or blank."""
    id_file = get_key_dir(email) / "device_id.txt"
    if id_file.exists():
        return id_file.read_text().strip() or None
    return None
# ------------------------------------------------------------------
# Identity key change exception (TOFU hard-fail)
# ------------------------------------------------------------------
class IdentityKeyChanged(Exception):
    """Signals a TOFU violation: a peer's identity key no longer matches.

    Session creation is blocked until the user explicitly accepts the new key.

    Attributes:
        user_id: The peer whose identity key changed.
        new_key_bytes: The newly observed identity key.
        status: ``"changed"`` or ``"changed_verified"``.
    """

    def __init__(self, user_id: str, new_key_bytes: bytes, status: str):
        message = (
            f"Identity key changed for user {user_id} (status={status}). "
            "Accept the new key before communicating."
        )
        super().__init__(message)
        self.user_id = user_id
        self.new_key_bytes = new_key_bytes
        self.status = status
# ------------------------------------------------------------------
# Client-side brute-force lockout
# ------------------------------------------------------------------
# Base of the exponential back-off after failed logins (delay = base ** attempts seconds).
_LOCKOUT_BASE_SECONDS = 2
# Ceiling for any single lockout delay: five minutes.
_LOCKOUT_MAX_SECONDS = 5 * 60
def _get_lockout_path(email: str) -> Path:
    """Location of the client-side login-lockout record for *email*."""
    key_dir = get_key_dir(email)
    return key_dir.joinpath("login_lockout.json")
def _check_lockout(email: str) -> float:
    """Return the seconds until the next login attempt is allowed (0.0 = now).

    The result is clamped to ``[0, _LOCKOUT_MAX_SECONDS]``. No lockout
    written by ``_record_failed_attempt`` is ever longer than that cap, so a
    larger value can only come from the system clock moving backwards or a
    corrupted/tampered record. Neither should lock the user out indefinitely.
    A missing or unreadable record means no lockout.
    """
    p = _get_lockout_path(email)
    if not p.exists():
        return 0.0
    try:
        data = json.loads(p.read_text())
        locked_until = float(data.get("locked_until", 0.0))
    except Exception:
        return 0.0
    remaining = locked_until - time.time()
    return min(max(0.0, remaining), float(_LOCKOUT_MAX_SECONDS))
def _record_failed_attempt(email: str):
    """Record one more failed login for *email* and extend the lockout window.

    The delay is ``_LOCKOUT_BASE_SECONDS ** failed_attempts`` seconds, capped
    at ``_LOCKOUT_MAX_SECONDS``. A missing, unreadable or malformed stored
    counter (non-integer or negative) restarts from zero instead of crashing.
    The exponent is bounded, so a huge stored count cannot trigger a costly
    big-integer power.
    """
    p = _get_lockout_path(email)
    failed = 0
    try:
        if p.exists():
            data = json.loads(p.read_text())
            stored = data.get("failed_attempts", 0)
            # bool is an int subclass; reject it along with other non-ints.
            if isinstance(stored, int) and not isinstance(stored, bool) and stored >= 0:
                failed = stored
    except Exception:
        pass
    failed += 1
    # With any base >= 2, base**64 far exceeds a sensible cap, so bounding
    # the exponent never changes the capped result.
    delay = min(_LOCKOUT_BASE_SECONDS ** min(failed, 64), _LOCKOUT_MAX_SECONDS)
    locked_until = time.time() + delay
    p.write_text(json.dumps({"failed_attempts": failed, "locked_until": locked_until}))
    os.chmod(p, 0o600)
def _clear_lockout(email: str):
    """Forget the failed-login record for *email* (reset after a successful login)."""
    lockout_file = _get_lockout_path(email)
    if not lockout_file.exists():
        return
    try:
        lockout_file.unlink()
    except Exception:
        pass
def _save_session(email: str, peer_user_id: str, ratchet: DoubleRatchet,
                  local_key: bytes | None = None, peer_device_id: str | None = None):
    """Persist the Double Ratchet session with a peer (optionally per device).

    The file is ``sessions/<peer_user_id>_<peer_device_id>.bin`` when a device
    id is given, else the legacy ``sessions/<peer_user_id>.bin``. The exported
    state is sealed with ``_encrypt_local`` when *local_key* is given.
    """
    sessions_dir = get_key_dir(email) / "sessions"
    sessions_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(sessions_dir, 0o700)
    if peer_device_id:
        target = sessions_dir / f"{peer_user_id}_{peer_device_id}.bin"
    else:
        target = sessions_dir / f"{peer_user_id}.bin"
    blob = ratchet.export_state()
    if local_key:
        blob = _encrypt_local(blob, local_key)
    target.write_bytes(blob)
    os.chmod(target, 0o600)
def _load_session(email: str, peer_user_id: str,
                  local_key: bytes | None = None,
                  peer_device_id: str | None = None) -> DoubleRatchet | None:
    """Load the stored Double Ratchet session with a peer, or return ``None``.

    With *peer_device_id*, the per-device file is preferred. If only a legacy
    per-user file exists, it is loaded, re-saved under the per-device name,
    and the legacy file is securely deleted (one-time migration).
    """
    sessions_dir = get_key_dir(email) / "sessions"
    legacy = sessions_dir / f"{peer_user_id}.bin"
    if not peer_device_id:
        return _load_session_file(legacy, local_key) if legacy.exists() else None
    per_device = sessions_dir / f"{peer_user_id}_{peer_device_id}.bin"
    if per_device.exists():
        return _load_session_file(per_device, local_key)
    if not legacy.exists():
        return None
    ratchet = _load_session_file(legacy, local_key)
    if ratchet:
        _save_session(email, peer_user_id, ratchet, local_key,
                      peer_device_id=peer_device_id)
        _secure_delete(legacy)
    return ratchet
def _load_session_file(p: Path, local_key: bytes | None = None) -> DoubleRatchet | None:
    """Deserialize a Double Ratchet session from the file at *p*.

    Returns ``None`` if the file is missing. With *local_key*, the blob is
    decrypted first. If decryption fails, the blob is retried as a legacy
    plaintext export, and ``None`` is returned if that fails too. Errors from
    importing successfully decrypted data propagate.
    """
    if not p.exists():
        return None
    blob = p.read_bytes()
    if not local_key:
        return DoubleRatchet.import_state(blob)
    try:
        plain = _decrypt_local(blob, local_key)
    except Exception:
        # Legacy (pre-encryption) session file.
        try:
            return DoubleRatchet.import_state(blob)
        except Exception:
            return None
    return DoubleRatchet.import_state(plain)
def _delete_session_file(email: str, peer_user_id: str, peer_device_id: str | None = None):
    """Securely wipe a stored session file (used when a session is reset)."""
    sessions_dir = get_key_dir(email) / "sessions"
    name = f"{peer_user_id}_{peer_device_id}.bin" if peer_device_id else f"{peer_user_id}.bin"
    _secure_delete(sessions_dir / name)
def _save_sender_key_state(email: str, conv_id: str, state: SenderKeyState,
                           local_key: bytes | None = None):
    """Persist this client's own sender-key state for conversation *conv_id*.

    Written to ``sender_keys/<conv_id>.bin`` (directory 0o700, file 0o600) and
    sealed with ``_encrypt_local`` when *local_key* is given.
    """
    keys_dir = get_key_dir(email) / "sender_keys"
    keys_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(keys_dir, 0o700)
    target = keys_dir / f"{conv_id}.bin"
    exported = state.export_state()
    blob = _encrypt_local(exported, local_key) if local_key else exported
    target.write_bytes(blob)
    os.chmod(target, 0o600)
def _load_sender_key_state(email: str, conv_id: str,
                           local_key: bytes | None = None) -> SenderKeyState | None:
    """Load this client's own sender-key state for *conv_id*, or ``None``.

    With *local_key*, a blob that fails to decrypt is retried as legacy
    plaintext. On success it is re-saved encrypted (migration); otherwise
    ``None`` is returned.
    """
    state_file = get_key_dir(email) / "sender_keys" / f"{conv_id}.bin"
    if not state_file.exists():
        return None
    blob = state_file.read_bytes()
    if not local_key:
        return SenderKeyState.import_state(blob)
    try:
        plain = _decrypt_local(blob, local_key)
    except Exception:
        try:
            legacy = SenderKeyState.import_state(blob)
            _save_sender_key_state(email, conv_id, legacy, local_key)
            return legacy
        except Exception:
            return None
    return SenderKeyState.import_state(plain)
def _save_recv_sender_key(email: str, conv_id: str, sender_id: str, state: SenderKeyState,
                          local_key: bytes | None = None,
                          sender_device_id: str | None = None):
    """Persist a sender key received from another member of *conv_id*.

    Stored in ``sender_keys_recv/`` as ``<conv_id>_<sender_id>_<sender_device_id>.bin``
    when the device is known, else as the legacy ``<conv_id>_<sender_id>.bin``.
    """
    recv_dir = get_key_dir(email) / "sender_keys_recv"
    recv_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(recv_dir, 0o700)
    stem = (f"{conv_id}_{sender_id}_{sender_device_id}" if sender_device_id
            else f"{conv_id}_{sender_id}")
    target = recv_dir / f"{stem}.bin"
    exported = state.export_state()
    blob = _encrypt_local(exported, local_key) if local_key else exported
    target.write_bytes(blob)
    os.chmod(target, 0o600)
def _load_recv_sender_key(email: str, conv_id: str, sender_id: str,
                          local_key: bytes | None = None,
                          sender_device_id: str | None = None) -> SenderKeyState | None:
    """Load a received sender key for (*conv_id*, *sender_id*), or ``None``.

    With *sender_device_id*, the per-device file is preferred. A legacy
    per-sender file is migrated to the per-device name (and securely deleted)
    the first time it is loaded this way.
    """
    recv_dir = get_key_dir(email) / "sender_keys_recv"
    legacy = recv_dir / f"{conv_id}_{sender_id}.bin"
    if not sender_device_id:
        return _load_recv_sender_key_file(legacy, local_key) if legacy.exists() else None
    per_device = recv_dir / f"{conv_id}_{sender_id}_{sender_device_id}.bin"
    if per_device.exists():
        return _load_recv_sender_key_file(per_device, local_key)
    if not legacy.exists():
        return None
    state = _load_recv_sender_key_file(legacy, local_key)
    if state:
        _save_recv_sender_key(email, conv_id, sender_id, state, local_key,
                              sender_device_id=sender_device_id)
        _secure_delete(legacy)
    return state
def _load_recv_sender_key_file(p: Path, local_key: bytes | None = None) -> SenderKeyState | None:
    """Deserialize a received sender key from the file at *p*.

    Returns ``None`` if the file is missing. With *local_key*, the blob is
    decrypted first, then retried as legacy plaintext if decryption fails;
    ``None`` is returned if both fail. Errors from importing successfully
    decrypted data propagate.
    """
    if not p.exists():
        return None
    blob = p.read_bytes()
    if not local_key:
        return SenderKeyState.import_state(blob)
    try:
        plain = _decrypt_local(blob, local_key)
    except Exception:
        try:
            return SenderKeyState.import_state(blob)
        except Exception:
            return None
    return SenderKeyState.import_state(plain)
# ---------------------------------------------------------------------------
# Local decrypted message cache (Double Ratchet keys are one-time use)
# ---------------------------------------------------------------------------
def _load_message_cache(email: str, conv_id: str, cache_key: bytes | None = None) -> dict:
    """Return the locally cached decrypted messages for *conv_id* as a dict.

    Ratchet message keys are one-time use, so decrypted messages are kept
    here for later display. The cache is stored encrypted as
    ``message_cache/<conv_id>.bin``, in the same layout as every other
    ``_encrypt_local`` blob.

    Migration: if only a legacy plaintext ``<conv_id>.json`` exists, it is
    returned. When *cache_key* is available, it is also rewritten encrypted
    and the plaintext file is securely deleted.

    Returns ``{}`` when nothing is cached, when no *cache_key* is available
    for the encrypted file, when the file cannot be read, decrypted or
    parsed, or when it does not hold a JSON object.
    """
    d = get_key_dir(email) / "message_cache"
    p_bin = d / f"{conv_id}.bin"
    p_json = d / f"{conv_id}.json"
    # Migration: a legacy plaintext .json exists but no encrypted .bin yet.
    if p_json.exists() and not p_bin.exists():
        try:
            cache = json.loads(p_json.read_text("utf-8"))
            if not isinstance(cache, dict):
                return {}
            if cache_key:
                _save_message_cache_full(d, conv_id, cache, cache_key)
                _secure_delete(p_json)
            return cache
        except Exception:
            return {}
    if not p_bin.exists() or not cache_key:
        return {}
    try:
        # Reuse the shared helper instead of re-slicing nonce/tag/ciphertext here.
        plaintext = _decrypt_local(p_bin.read_bytes(), cache_key)
        cache = json.loads(plaintext.decode("utf-8"))
    except Exception:
        return {}
    # Callers add entries by message id, so anything but a JSON object is unusable.
    return cache if isinstance(cache, dict) else {}
def _save_message_cache_full(d: Path, conv_id: str, cache: dict, cache_key: bytes):
    """Encrypt the whole cache dict and write it to ``<d>/<conv_id>.bin``.

    The blob (``_encrypt_local`` layout: nonce || tag || ciphertext) is written
    to a temporary sibling file that is owner-only from creation, then
    fsynced and atomically renamed over the target with ``os.replace``.
    A crash mid-write therefore cannot leave a truncated cache. That matters
    because an unreadable cache is loaded back as empty, and the next save
    would then discard every previously cached message.
    """
    d.mkdir(parents=True, exist_ok=True)
    os.chmod(d, 0o700)
    p = d / f"{conv_id}.bin"
    tmp = d / f"{conv_id}.bin.tmp"
    plaintext = json.dumps(cache, ensure_ascii=False).encode("utf-8")
    blob = _encrypt_local(plaintext, cache_key)
    # O_BINARY only exists on Windows; without it the fd would be text-mode there.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(tmp, flags, 0o600)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(blob)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, p)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
    os.chmod(p, 0o600)
def _save_message_to_cache(email: str, conv_id: str, message_id: str, payload: dict,
                           cache_key: bytes | None = None):
    """Insert or replace one decrypted message in the conversation's local cache.

    The whole cache is loaded, updated and rewritten. Without *cache_key*, it
    is written as plaintext JSON (``<conv_id>.json``) as a fallback.
    """
    cache_dir = get_key_dir(email) / "message_cache"
    entries = _load_message_cache(email, conv_id, cache_key)
    entries[message_id] = payload
    if cache_key:
        _save_message_cache_full(cache_dir, conv_id, entries, cache_key)
        return
    # No key yet (identity key not available): fall back to plaintext JSON.
    cache_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(cache_dir, 0o700)
    json_file = cache_dir / f"{conv_id}.json"
    json_file.write_text(json.dumps(entries, ensure_ascii=False), "utf-8")
    os.chmod(json_file, 0o600)
# ---------------------------------------------------------------------------
# Verification storage (TOFU + explicit verification)
# ---------------------------------------------------------------------------
def _save_known_identity_keys(email: str, keys: dict, local_key: bytes | None = None):
    """Write the TOFU identity-key registry to ``known_identity_keys.bin``.

    The JSON envelope is ``{"version": 1, "keys": keys}``. It is sealed with
    ``_encrypt_local`` when *local_key* is given and stored as plain JSON
    otherwise. The file is restricted to 0o600.
    """
    registry_file = get_key_dir(email) / "known_identity_keys.bin"
    envelope = {"version": 1, "keys": keys}
    blob = json.dumps(envelope).encode("utf-8")
    registry_file.write_bytes(_encrypt_local(blob, local_key) if local_key else blob)
    os.chmod(registry_file, 0o600)
def _load_known_identity_keys(email: str, local_key: bytes | None = None) -> dict:
    """Read the TOFU identity-key registry; ``{}`` if missing or unreadable.

    With *local_key*, the file must decrypt. There is deliberately no
    fallback to reading it as plaintext: the registry was introduced after
    local encryption, and such a fallback would let someone with disk access
    plant fake identity keys and bypass TOFU warnings. Without *local_key*,
    the file is parsed as plain JSON (the counterpart of saving without a key).
    """
    registry_file = get_key_dir(email) / "known_identity_keys.bin"
    if not registry_file.exists():
        return {}
    blob = registry_file.read_bytes()
    try:
        decoded = _decrypt_local(blob, local_key) if local_key else blob
        return json.loads(decoded).get("keys", {})
    except Exception:
        return {}
def _save_verified_contacts(email: str, contacts: dict, local_key: bytes | None = None):
    """Write explicit verification records to ``verified_contacts.bin``.

    The envelope is ``{"version": 1, "contacts": contacts}``, sealed with
    ``_encrypt_local`` when *local_key* is given. The file is restricted to 0o600.
    """
    records_file = get_key_dir(email) / "verified_contacts.bin"
    envelope = {"version": 1, "contacts": contacts}
    blob = json.dumps(envelope).encode("utf-8")
    records_file.write_bytes(_encrypt_local(blob, local_key) if local_key else blob)
    os.chmod(records_file, 0o600)
def _load_verified_contacts(email: str, local_key: bytes | None = None) -> dict:
    """Read explicit verification records; ``{}`` if missing or unreadable.

    With *local_key*, the file must decrypt. A plaintext fallback is
    deliberately not offered, since it would let someone with disk access
    inject fake records (e.g. mark an attacker's key as "verified").
    Without *local_key*, the file is parsed as plain JSON.
    """
    records_file = get_key_dir(email) / "verified_contacts.bin"
    if not records_file.exists():
        return {}
    blob = records_file.read_bytes()
    try:
        plain = _decrypt_local(blob, local_key) if local_key else blob
        return json.loads(plain).get("contacts", {})
    except Exception:
        return {}
def _solve_pow(challenge: str, difficulty: int) -> str:
    """Solve a server proof-of-work challenge.

    Finds the smallest non-negative integer ``nonce`` such that
    ``sha256(f"{challenge}{nonce}")`` has at least *difficulty* leading zero
    bits, and returns it as a decimal string.

    Args:
        challenge: Challenge string issued by the server.
        difficulty: Required number of leading zero bits, 0..256.

    Raises:
        ValueError: If *difficulty* is outside 0..256. The value comes from
            the server, and anything above 256 can never be met by a 256-bit
            digest, so the search would otherwise run effectively forever.
    """
    if not 0 <= difficulty <= 256:
        raise ValueError(f"PoW difficulty must be between 0 and 256, got {difficulty}")
    # A digest has >= difficulty leading zero bits iff, read as a 256-bit
    # big-endian integer, it is below 2**(256 - difficulty).
    bound = 1 << (256 - difficulty)
    nonce = 0
    while True:
        digest = hashlib.sha256(f"{challenge}{nonce}".encode()).digest()
        if int.from_bytes(digest, "big") < bound:
            return str(nonce)
        nonce += 1
class ChatClient:
def __init__(self):
    # --- Connection / transport ---
    self.reader: ProtocolReader | None = None
    self.writer: ProtocolWriter | None = None
    self.raw_writer: asyncio.StreamWriter | None = None
    self.connected: bool = False
    self.login_rejected: bool = False
    self._listener_task: asyncio.Task | None = None
    # Replies whose request_id is in _pending resolve that future; other
    # replies go to _response_queue, and server pushes go to _notification_queue.
    self._pending: dict[str, asyncio.Future] = {}
    self._response_queue: asyncio.Queue = asyncio.Queue()
    self._notification_queue: asyncio.Queue = asyncio.Queue()
    self._logger = logging.getLogger("encrypted_chat.client")

    # --- Account / login (RSA is used for login only) ---
    self.session: dict | None = None
    self.username: str = ""
    self.email: str = ""
    self.private_key = None  # RSA private key (login only)
    self.public_key = None  # RSA public key (login only)
    self._pairing_temp_private_key = None
    self._reencrypt_progress_cb = None

    # --- Signal Protocol key material ---
    self.identity_private = None  # Ed25519PrivateKey
    self.identity_public = None  # Ed25519PublicKey
    self.spk_private = None  # X25519PrivateKey: current signed prekey
    self.spk_id: str = ""
    self._prev_spk_private = None  # previous SPK, kept for the rotation grace period
    self._prev_spk_id: str = ""
    self.opk_privates: dict[str, object] = {}  # opk id -> X25519PrivateKey

    # --- Ratchet / sender-key sessions ---
    self.sessions: dict[str, DoubleRatchet] = {}  # "user_id:device_id" -> ratchet
    self.sender_key_states: dict[str, SenderKeyState] = {}  # conv_id -> our own sender key
    self.recv_sender_keys: dict[str, SenderKeyState] = {}  # "conv_id:sender_id:device_id" -> their key

    # --- Local caches and at-rest encryption keys ---
    self._user_cache: dict[str, dict] = {}  # user_id -> {identity_key, username, email, ...}
    self._cache_key: bytes | None = None  # AES key for the on-disk message cache
    self._local_key: bytes | None = None  # AES key for session / sender-key files

    # --- Multi-device ---
    self.device_id: str | None = None  # this device's UUID
    self._device_bundle_cache: dict[str, tuple[float, list[dict]]] = {}  # user_id -> (ts, bundles)
    self._pending_self_encrypt: list[dict] = []  # received messages queued for self-encryption

    # --- Contact key verification (TOFU + explicit) ---
    self._known_identity_keys: dict = {}  # user_id -> {identity_key hex, first_seen, last_seen}
    self._verified_contacts: dict = {}  # user_id -> {identity_key hex, verified_at, method}
    # Called by check_identity_key as
    # callback(user_id, username, old_key_hex, was_verified, new_key_bytes).
    self._key_change_cb = None
async def connect(self):
    """Open the connection to the chat server configured via environment variables.

    Reads SERVER_HOST / SERVER_PORT and the TLS_* switches. TLS_REQUIRED
    without TLS_ENABLED is an error, and TLS_INSECURE (no certificate
    verification, unless a CA file is given) is refused unless ENVIRONMENT
    is dev/development.
    """
    def env_flag(name: str) -> bool:
        return os.getenv(name, "false").lower() in ("1", "true", "yes")

    host = os.getenv("SERVER_HOST", "127.0.0.1")
    port = int(os.getenv("SERVER_PORT", "9999"))
    tls_enabled = env_flag("TLS_ENABLED")
    tls_required = env_flag("TLS_REQUIRED")
    ssl_context = None
    if tls_required and not tls_enabled:
        raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.")
    if not tls_enabled:
        self._logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.")
    else:
        insecure = env_flag("TLS_INSECURE")
        is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development")
        if insecure and not is_dev:
            raise RuntimeError("TLS_INSECURE is only allowed when ENVIRONMENT=dev")
        ssl_context = ssl.create_default_context()
        ca_file = os.getenv("TLS_CA_FILE", "").strip()
        if ca_file:
            ssl_context.load_verify_locations(cafile=ca_file)
        elif insecure:
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
    stream_reader, stream_writer = await asyncio.open_connection(
        host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context
    )
    self.reader = ProtocolReader(stream_reader)
    self.writer = ProtocolWriter(stream_writer)
    self.raw_writer = stream_writer
    self.connected = True
async def _background_listener(self):
    """Read server messages until the connection ends, routing each one.

    Server-pushed events (the types listed below) go to
    ``_notification_queue``. A reply whose ``request_id`` is registered in
    ``_pending`` resolves that future, and any other reply goes to
    ``_response_queue``.

    The loop ends when ``read_message`` returns ``None`` or raises (e.g. on
    a reset connection). In both cases the client is marked disconnected and
    every pending future fails with ``ConnectionError``, so ``send_and_recv``
    callers return promptly. Without handling the raised case, the task would
    die silently, leave ``connected`` True and strand pending requests until
    they time out.
    """
    # A tuple (not a set) so an unhashable "type" value from the server
    # cannot raise TypeError on the membership test.
    notification_types = (
        "new_message", "messages_read", "message_deleted",
        "conversation_created", "member_added", "member_removed",
        "user_online", "user_offline", "online_users",
        "group_invitation", "conversation_renamed",
        "session_reset",
        "message_reacted", "message_pinned", "message_unpinned",
        "message_delivered", "username_changed",
    )
    while True:
        try:
            msg = await self.reader.read_message()
        except Exception as exc:
            # asyncio.CancelledError is not an Exception subclass, so task
            # cancellation still propagates normally.
            self._logger.warning("Connection to server lost: %s", exc)
            msg = None
        if msg is None:
            self.connected = False
            # Fail all pending futures so send_and_recv doesn't hang.
            pending = dict(self._pending)
            self._pending.clear()
            err = ConnectionError("Server connection lost")
            for fut in pending.values():
                if not fut.done():
                    fut.set_exception(err)
            break
        if msg.get("type") in notification_types:
            await self._notification_queue.put(msg)
            continue
        req_id = msg.get("request_id")
        if req_id and req_id in self._pending:
            fut = self._pending.pop(req_id)
            if not fut.done():
                fut.set_result(msg)
        else:
            await self._response_queue.put(msg)
async def send_and_recv(self, msg_type: str, timeout: float = 30.0, **kwargs) -> dict:
    """Send a request and wait for the server reply carrying the same request_id.

    Never raises for transport problems. On a send failure, lost connection
    or timeout it returns a synthetic response of the form
    ``{"type": msg_type, "status": "error", "data": {"message": ...}}``.

    When the client is already known to be disconnected (``connected`` is
    False or there is no writer), it fails fast. Otherwise it would register
    a future that no listener will ever resolve and wait out the full
    *timeout*. The pending registration is always removed on exit, including
    when an unexpected exception escapes the send.
    """
    def error(message: str) -> dict:
        return {"type": msg_type, "status": "error", "data": {"message": message}}

    if not self.connected or self.writer is None:
        return error("Connection lost.")
    request_id = str(uuid.uuid4())
    fut = asyncio.get_running_loop().create_future()
    self._pending[request_id] = fut
    try:
        try:
            await self.writer.send_request(msg_type, request_id=request_id, **kwargs)
        except (ValueError, ConnectionError, OSError) as e:
            return error(str(e) or "Connection lost.")
        try:
            return await asyncio.wait_for(fut, timeout=timeout)
        except asyncio.TimeoutError:
            self._logger.warning("send_and_recv timeout for '%s' after %.0fs", msg_type, timeout)
            return error(f"Request timed out ({msg_type})")
        except ConnectionError:
            # Set by the background listener when the connection drops.
            return error("Connection lost.")
    finally:
        self._pending.pop(request_id, None)
# ------------------------------------------------------------------
# User info / identity key cache
# ------------------------------------------------------------------
async def _get_user_info(self, user_id: str = "", email: str = "") -> dict | None:
"""Get user info from server, cache identity key. Performs TOFU check."""
cached = self._user_cache.get(user_id)
if cached:
return cached
kwargs = {}
if user_id:
kwargs["user_id"] = user_id
elif email:
kwargs["email"] = email
else:
return None
resp = await self.send_and_recv("get_user_info", **kwargs)
if resp["status"] != "ok":
return None
data = resp["data"]
ik_bytes = decode_binary(data["identity_key"]) if data.get("identity_key") else None
info = {
"user_id": data["user_id"],
"username": data["username"],
"email": data["email"],
"identity_key": load_ed25519_public(ik_bytes) if ik_bytes else None,
"identity_key_bytes": ik_bytes,
}
# TOFU: check identity key against known keys
if ik_bytes:
status = self.check_identity_key(data["user_id"], ik_bytes)
info["identity_key_status"] = status
self._user_cache[data["user_id"]] = info
return info
# ------------------------------------------------------------------
# Contact Key Verification
# ------------------------------------------------------------------
def _load_verification_stores(self):
    """Reload the TOFU registry and verified-contact records for this account.

    Does nothing until ``self.email`` is set.
    """
    account = self.email
    if account:
        self._known_identity_keys = _load_known_identity_keys(account, self._local_key)
        self._verified_contacts = _load_verified_contacts(account, self._local_key)
def check_identity_key(self, user_id: str, identity_key_bytes: bytes) -> str:
    """Check *identity_key_bytes* for *user_id* against the TOFU registry.

    Returns:
        "new"              - first contact; the key is recorded (trust on first use)
        "trusted"          - matches the recorded key, not explicitly verified
        "verified"         - matches the recorded key, which was explicitly verified
        "changed"          - differs from the recorded key (WARNING)
        "changed_verified" - differs, and the contact had been verified (CRITICAL)

    On a change, the registry is left untouched (``accept_key_change`` is what
    records a new key) and ``_key_change_cb``, if set, is called as
    ``cb(user_id, username, old_key_hex, was_verified, new_key_bytes)``.
    A failing callback is logged rather than silently ignored, and it never
    affects the returned status.
    """
    ik_hex = identity_key_bytes.hex()
    now = datetime.now(timezone.utc).isoformat()
    known = self._known_identity_keys.get(user_id)
    if known is None:
        # First time seeing this user: trust on first use.
        self._known_identity_keys[user_id] = {
            "identity_key": ik_hex,
            "first_seen": now,
            "last_seen": now,
        }
        if self.email:
            _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key)
        return "new"
    if known["identity_key"] == ik_hex:
        known["last_seen"] = now
        if self.email:
            _save_known_identity_keys(self.email, self._known_identity_keys, self._local_key)
        verified = self._verified_contacts.get(user_id)
        if verified and verified.get("identity_key") == ik_hex:
            return "verified"
        return "trusted"
    # The key has CHANGED.
    was_verified = user_id in self._verified_contacts
    old_key_hex = known["identity_key"]
    if self._key_change_cb:
        cached = self._user_cache.get(user_id)
        username = cached.get("username", "") if cached else ""
        try:
            self._key_change_cb(user_id, username, old_key_hex, was_verified, identity_key_bytes)
        except Exception:
            # A broken UI hook must not mask the key-change result, but it
            # should not disappear silently either.
            self._logger.exception("Identity key change callback failed for user %s", user_id)
    return "changed_verified" if was_verified else "changed"
def verify_contact(self, user_id: str, identity_key_bytes: bytes, method: str = "manual"):
    """Record that *user_id*'s identity key was explicitly verified via *method*.

    Also makes sure the TOFU registry has an entry for the user: a new one
    with this key, or just a refreshed ``last_seen``. Both stores are
    persisted when an account email is set, and a cached profile is marked
    "verified".
    """
    key_hex = identity_key_bytes.hex()
    stamp = datetime.now(timezone.utc).isoformat()
    self._verified_contacts[user_id] = {
        "identity_key": key_hex,
        "verified_at": stamp,
        "method": method,
    }
    if user_id in self._known_identity_keys:
        self._known_identity_keys[user_id]["last_seen"] = stamp
    else:
        self._known_identity_keys[user_id] = {
            "identity_key": key_hex,
            "first_seen": stamp,
            "last_seen": stamp,
        }
    account = self.email
    if account:
        _save_verified_contacts(account, self._verified_contacts, self._local_key)
        _save_known_identity_keys(account, self._known_identity_keys, self._local_key)
    profile = self._user_cache.get(user_id)
    if profile:
        profile["identity_key_status"] = "verified"
def unverify_contact(self, user_id: str):
    """Drop the explicit verification record for *user_id* (the TOFU entry stays).

    A cached profile currently marked "verified" is downgraded to "trusted".
    """
    self._verified_contacts.pop(user_id, None)
    if self.email:
        _save_verified_contacts(self.email, self._verified_contacts, self._local_key)
    profile = self._user_cache.get(user_id)
    if not profile or profile.get("identity_key_status") != "verified":
        return
    profile["identity_key_status"] = "trusted"
def accept_key_change(self, user_id: str, new_ik_bytes: bytes):
    """Trust *new_ik_bytes* as *user_id*'s identity key after a reported change.

    Replaces the TOFU entry (first_seen restarts now) and drops any earlier
    explicit verification, so the contact has to be re-verified. Both stores
    are persisted when an email is set, and a cached profile is marked "trusted".
    """
    stamp = datetime.now(timezone.utc).isoformat()
    self._known_identity_keys[user_id] = {
        "identity_key": new_ik_bytes.hex(),
        "first_seen": stamp,
        "last_seen": stamp,
    }
    self._verified_contacts.pop(user_id, None)
    account = self.email
    if account:
        _save_known_identity_keys(account, self._known_identity_keys, self._local_key)
        _save_verified_contacts(account, self._verified_contacts, self._local_key)
    profile = self._user_cache.get(user_id)
    if profile:
        profile["identity_key_status"] = "trusted"
def get_verification_status(self, user_id: str) -> str:
    """Summarize trust in *user_id*: "verified", "trusted" or "unverified".

    "verified" requires the explicitly verified key to still equal the
    TOFU-recorded key. Otherwise, any TOFU entry yields "trusted".
    """
    record = self._verified_contacts.get(user_id)
    tofu = self._known_identity_keys.get(user_id)
    if record and tofu and tofu.get("identity_key") == record.get("identity_key"):
        return "verified"
    return "trusted" if user_id in self._known_identity_keys else "unverified"
def get_safety_number(self, peer_user_id: str) -> str | None:
"""Get formatted safety number for a peer (requires both identity keys)."""
if not self.identity_public or not self.session:
return None
my_uid = self.session.get("user_id", "")
my_ik_bytes = serialize_ed25519_public(self.identity_public)
cached = self._user_cache.get(peer_user_id)
if not cached or not cached.get("identity_key_bytes"):
return None
return compute_safety_number(my_uid, my_ik_bytes,
peer_user_id, cached["identity_key_bytes"])
def get_my_fingerprint(self) -> str | None:
"""Get formatted fingerprint for own identity key."""
if not self.identity_public or not self.session:
return None
my_uid = self.session.get("user_id", "")
my_ik_bytes = serialize_ed25519_public(self.identity_public)
fp = compute_fingerprint(my_uid, my_ik_bytes)
return format_fingerprint(fp)
def get_peer_fingerprint(self, peer_user_id: str) -> str | None:
"""Get formatted fingerprint for a peer's identity key."""
cached = self._user_cache.get(peer_user_id)
if not cached or not cached.get("identity_key_bytes"):
return None
fp = compute_fingerprint(peer_user_id, cached["identity_key_bytes"])
return format_fingerprint(fp)
def get_verification_qr_data(self) -> bytes | None:
"""Get QR code payload bytes for own identity (for peer to scan)."""
if not self.identity_public or not self.session:
return None
my_uid = self.session.get("user_id", "")
my_ik_bytes = serialize_ed25519_public(self.identity_public)
return encode_verification_qr(my_uid, my_ik_bytes)
def verify_qr_code(self, qr_data: bytes) -> tuple[bool, str, str]:
    """Check a scanned verification QR code against the cached identity key.

    Returns ``(success, user_id, message)``. When the scanned key matches the
    cached one, the contact is marked verified with method "qr_code".
    """
    try:
        scanned_uid, scanned_key = decode_verification_qr(qr_data)
    except ValueError as exc:
        return False, "", f"Invalid QR code: {exc}"
    profile = self._user_cache.get(scanned_uid)
    if not profile:
        return False, scanned_uid, "Unknown user — not in your contacts."
    known_key = profile.get("identity_key_bytes")
    if not known_key:
        return False, scanned_uid, "No identity key on record for this user."
    if known_key != scanned_key:
        return False, scanned_uid, "Identity key MISMATCH — verification failed!"
    self.verify_contact(scanned_uid, scanned_key, method="qr_code")
    display_name = profile.get("username", scanned_uid[:8])
    return True, scanned_uid, f"Verified {display_name} via QR code."
# ------------------------------------------------------------------
# Registration
# ------------------------------------------------------------------
async def register(self, username: str, password: str, email: str) -> tuple[bool, str]:
    """Register a new account. Loads or generates RSA and Ed25519 keys.

    The RSA login keypair and the Ed25519 identity keypair for ``email`` are
    loaded from local storage, or generated and saved under ``password`` when
    absent. A ``register`` request is then sent. If the server answers
    ``pow_required``, the proof-of-work challenge is solved in a worker
    thread, so the event loop is not blocked. The request is then retried
    once with the solution attached.

    Args:
        username: Display name to register.
        password: Password protecting the local key files. An empty value
            means the keys are stored without a password.
        email: Account e-mail. It also selects the local key directory.

    Returns:
        ``(True, code)`` when the server returned a confirmation code.
        ``(True, message)`` when it accepted the request without one.
        ``(False, error_message)`` otherwise.
    """
    self.username = username
    self.email = email
    pwd_bytes = bytearray(password.encode("utf-8")) if password else None
    try:
        # RSA keys are used only for the login challenge-response.
        priv, pub, _err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
        if priv is None:
            priv, pub = generate_rsa_keypair()
            save_keys(email, priv, pub, password=bytes(pwd_bytes) if pwd_bytes else None)
        self.private_key = priv
        self.public_key = pub
        # The Ed25519 identity keys are the long-term identity for X3DH and fingerprints.
        ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
        if ed_priv is None:
            ed_priv, ed_pub = generate_identity_keypair()
            _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None)
        self.identity_private = ed_priv
        self.identity_public = ed_pub
        self._cache_key = derive_self_encryption_key(ed_priv)
        self._local_key = derive_local_storage_key(ed_priv)
        self._load_verification_stores()
    finally:
        # Best-effort wipe of the mutable password copy.
        if pwd_bytes:
            pwd_bytes[:] = b'\x00' * len(pwd_bytes)

    pub_pem = serialize_public_key(pub).decode("utf-8")
    ik_b64 = encode_binary(serialize_ed25519_public(ed_pub))

    async def submit(**extra_fields):
        """Send the register request, with optional PoW fields."""
        return await self.send_and_recv(
            "register",
            username=username,
            public_key=pub_pem,
            email=email,
            identity_key=ik_b64,
            **extra_fields,
        )

    start = await submit()
    # Handle the PoW challenge (the server is under pressure).
    if start.get("status") == "pow_required":
        pow_data = start["data"]
        challenge = pow_data["challenge"]
        mac = pow_data["mac"]
        difficulty = pow_data["difficulty"]
        self._logger.info("Server requires proof-of-work (difficulty %d), solving...", difficulty)
        # The solve is CPU-bound; running it in a thread keeps the event loop responsive.
        nonce = await asyncio.to_thread(_solve_pow, challenge, difficulty)
        start = await submit(pow_challenge=challenge, pow_mac=mac, pow_nonce=nonce)

    data = start.get("data") or {}
    if start.get("status") != "ok":
        return False, data.get("message", "Registration failed.")
    code = data.get("code")
    if code:
        return True, code
    return True, data.get("message", "Check your email for the code.")
async def confirm_registration(self, email: str, username: str, code: str) -> tuple[bool, str]:
    """Submit the e-mailed confirmation code to finish registration.

    On success, the initial signed prekey and one-time prekeys are generated
    and uploaded before returning ``(True, message)``. Otherwise returns
    ``(False, server_message)``.
    """
    reply = await self.send_and_recv("register_confirm", email=email, code=code)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    # Publish prekeys right away so peers can open X3DH sessions with us.
    await self._generate_and_upload_prekeys()
    new_user_id = reply["data"]["user_id"]
    return True, f"Registered as '{username}' (ID: {new_user_id})"
async def _generate_and_upload_prekeys(self, keep_spk: bool = False):
"""Generate SPK + OPKs and upload to server.
If keep_spk=True, re-sign the existing SPK instead of generating a new
one. This is used after device pairing so both devices share the same
SPK and either can respond to X3DH.
"""
if not self.identity_private:
return
if keep_spk and self.spk_private and self.spk_id:
# Re-sign existing SPK (both devices share the identity key)
spk_pub_bytes = serialize_x25519_public(self.spk_private.public_key())
spk_sig = ed25519_sign(self.identity_private, spk_pub_bytes)
spk_data = {
"id": self.spk_id,
"public_key": encode_binary(spk_pub_bytes),
"signature": encode_binary(spk_sig),
}
else:
# Save current SPK as previous for grace period (M4: in-flight X3DH)
if self.spk_private and self.spk_id:
self._prev_spk_private = self.spk_private
self._prev_spk_id = self.spk_id
_save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key)
# Generate a brand-new signed prekey
spk = generate_signed_prekey(self.identity_private)
self.spk_private = spk["private"]
self.spk_id = spk["id"]
_save_spk(self.email, spk["private"], spk["id"], self._local_key)
spk_data = {
"id": spk["id"],
"public_key": encode_binary(serialize_x25519_public(spk["public"])),
"signature": encode_binary(spk["signature"]),
}
# Generate one-time prekeys
opks = generate_one_time_prekeys(OPK_BATCH_SIZE)
for opk in opks:
self.opk_privates[opk["id"]] = opk["private"]
_save_opk_private(self.email, opk["id"], opk["private"], self._local_key)
# Upload to server
otp_data = [
{"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))}
for opk in opks
]
resp = await self.send_and_recv(
"upload_prekeys",
signed_prekey=spk_data,
one_time_prekeys=otp_data,
)
if resp.get("status") != "ok":
self._logger.warning("upload_prekeys failed: %s (will retry on login)",
resp.get("data", {}).get("message", "unknown"))
async def _ensure_prekeys(self):
"""Check OPK count and SPK age, replenish/rotate if needed.
Uses single-roundtrip `ensure_prekeys` handler when available,
falls back to legacy two-step flow (get_prekey_count + upload_prekeys).
"""
resp = await self.send_and_recv("get_prekey_count")
if resp["status"] != "ok":
return
count = resp["data"].get("count", 0)
spk_created_at = resp["data"].get("spk_created_at", "")
need_new_spk = False
if spk_created_at:
try:
created = datetime.fromisoformat(spk_created_at)
if created.tzinfo is None:
created = created.replace(tzinfo=timezone.utc)
age_days = (datetime.now(timezone.utc) - created).days
if age_days >= SPK_ROTATION_DAYS:
need_new_spk = True
self._logger.info("SPK is %d days old, rotating...", age_days)
except Exception:
need_new_spk = True
else:
# No SPK on server for this device — must upload one
need_new_spk = True
self._logger.info("No SPK on server for this device, uploading...")
if count < OPK_REPLENISH_THRESHOLD or need_new_spk:
if count >= OPK_REPLENISH_THRESHOLD:
self._logger.info("SPK rotation triggered (OPK count OK: %d)", count)
else:
self._logger.info("OPK count low (%d), replenishing...", count)
await self._generate_and_upload_prekeys_batch(need_new_spk)
async def _generate_and_upload_prekeys_batch(self, need_new_spk: bool = False):
"""Generate and upload prekeys in a single round-trip via ensure_prekeys."""
if not self.identity_private:
return
kwargs: dict = {}
# SPK
if need_new_spk:
if self.spk_private and self.spk_id:
self._prev_spk_private = self.spk_private
self._prev_spk_id = self.spk_id
_save_prev_spk(self.email, self.spk_private, self.spk_id, self._local_key)
spk = generate_signed_prekey(self.identity_private)
self.spk_private = spk["private"]
self.spk_id = spk["id"]
_save_spk(self.email, spk["private"], spk["id"], self._local_key)
kwargs["signed_prekey"] = {
"id": spk["id"],
"public_key": encode_binary(serialize_x25519_public(spk["public"])),
"signature": encode_binary(spk["signature"]),
}
# OPKs
opks = generate_one_time_prekeys(OPK_BATCH_SIZE)
for opk in opks:
self.opk_privates[opk["id"]] = opk["private"]
_save_opk_private(self.email, opk["id"], opk["private"], self._local_key)
kwargs["one_time_prekeys"] = [
{"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))}
for opk in opks
]
resp = await self.send_and_recv("ensure_prekeys", **kwargs)
if resp["status"] == "ok":
data = resp.get("data", {})
self._logger.info("ensure_prekeys: count=%d, spk_uploaded=%s, otps_uploaded=%d",
data.get("count", 0), data.get("uploaded_spk", False),
data.get("uploaded_otps", 0))
# ------------------------------------------------------------------
# Login
# ------------------------------------------------------------------
async def login(self, email: str, password: str) -> tuple[bool, str]:
    """Log in using the locally stored keys for ``email``.

    Steps:

    1. Refuse while ``email`` is locked out after repeated failures.
    2. Load the RSA login keypair and the Ed25519 identity keypair with
       ``password``. A password-related load error counts as a failed
       attempt. The mutable password buffer is zeroed afterwards.
    3. If an identity key was loaded, derive the local cache and storage
       keys. Then restore the verification stores and the current and
       previous signed prekeys.
    4. Run the RSA challenge-response (``login_start`` / ``login_finish``),
       sending the stored device id when there is one.
    5. On success, persist the server-assigned device id and start prekey
       maintenance as a background task. Then clear the lockout counter.

    Returns:
        ``(True, message)`` on success, otherwise ``(False, error_message)``.
    """
    self.email = email
    # Brute-force lockout check
    remaining = _check_lockout(email)
    if remaining > 0:
        return False, f"Too many failed attempts. Try again in {remaining:.0f}s."
    pwd_bytes = bytearray(password.encode("utf-8")) if password else None
    try:
        # Load RSA keys
        priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
        if priv is None:
            if err and "password" in err.lower():
                _record_failed_attempt(email)
            return False, err or "No local keys found. Register first."
        self.private_key = priv
        self.public_key = pub
        # Load identity keys
        ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
    finally:
        if pwd_bytes:
            pwd_bytes[:] = b'\x00' * len(pwd_bytes)

    if ed_priv is not None:
        self.identity_private = ed_priv
        self.identity_public = ed_pub
        self._cache_key = derive_self_encryption_key(ed_priv)
        self._local_key = derive_local_storage_key(ed_priv)
        self._load_verification_stores()
        # Load SPK
        spk_priv, spk_id = _load_spk(email, self._local_key)
        if spk_priv:
            self.spk_private = spk_priv
            self.spk_id = spk_id
        # Load previous SPK for grace period (M4)
        prev_spk_priv, prev_spk_id = _load_prev_spk(email, self._local_key)
        if prev_spk_priv:
            self._prev_spk_private = prev_spk_priv
            self._prev_spk_id = prev_spk_id

    # Load device_id from disk
    self.device_id = _load_device_id(email)

    # RSA challenge-response login
    start = await self.send_and_recv("login_start", email=email)
    if start["status"] != "ok":
        return False, start["data"]["message"]
    challenge = decode_binary(start["data"]["challenge"])
    signature = rsa_sign(self.private_key, challenge)
    login_kwargs = {"email": email, "signature": encode_binary(signature),
                    "client_version": VERSION}
    if self.device_id:
        login_kwargs["device_id"] = self.device_id
    finish = await self.send_and_recv("login_finish", **login_kwargs)
    if finish["status"] != "ok":
        return False, finish["data"]["message"]

    self.session = finish["data"]
    self.username = self.session.get("username", "")
    # Store device_id from server
    self.device_id = finish["data"].get("device_id")
    if self.device_id:
        _save_device_id(email, self.device_id)

    # A freshly paired device has no local OPK private keys. Its server-side
    # OPKs have no matching private halves here, so fresh ones are generated.
    # keep_spk=True preserves a shared SPK when one is loaded.
    opk_dir = get_key_dir(self.email) / "opk_private"
    has_local_opks = opk_dir.exists() and any(opk_dir.iterdir())
    if has_local_opks:
        maintenance = self._ensure_prekeys()
    else:
        self._logger.info("No local OPKs (likely new device). Generating fresh OPKs, keeping SPK.")
        maintenance = self._generate_and_upload_prekeys(keep_spk=True)
    task = asyncio.create_task(maintenance)
    # The event loop keeps only weak references to tasks. Hold a strong
    # reference until the task finishes, so it cannot be collected mid-run.
    pending = getattr(self, "_background_tasks", None)
    if pending is None:
        pending = set()
        self._background_tasks = pending
    pending.add(task)
    task.add_done_callback(pending.discard)

    _clear_lockout(email)
    return True, f"Logged in as '{self.username}' (ID: {self.session['user_id']})"
# ------------------------------------------------------------------
# Pairing (device pairing — transfers RSA + identity keys)
# ------------------------------------------------------------------
async def pairing_start(self, email: str) -> tuple[bool, str]:
    """Begin pairing this new device with an existing one.

    Creates a throwaway 2048-bit RSA keypair. The private half is kept on
    ``self`` and the public half is sent with ``pairing_start``, so the
    authorizing device can encrypt the transferred keys to it.

    Returns ``(True, pairing_code)`` and remembers the poll token on success.
    Otherwise returns ``(False, server_message)``.
    """
    ephemeral_priv, ephemeral_pub = generate_rsa_keypair(2048)
    self._pairing_temp_private_key = ephemeral_priv
    reply = await self.send_and_recv(
        "pairing_start",
        email=email,
        temp_public_key=serialize_public_key(ephemeral_pub).decode("utf-8"),
    )
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    self._pairing_poll_token = reply["data"].get("poll_token", "")
    return True, reply["data"]["code"]
async def pairing_wait(self, code: str, email: str, password: str, timeout: int = 300) -> tuple[bool, str]:
    """Wait for the pairing payload and import the transferred keys.

    Polls ``pairing_poll`` every 2 seconds until the authorizing device has
    posted the payload or ``timeout`` seconds have elapsed. Once the payload
    arrives:

    - the AES key is unwrapped with the temporary RSA key from
      ``pairing_start`` (RSA-OAEP, SHA-256);
    - the ciphertext is decrypted with that key;
    - the RSA login key and Ed25519 identity key it contains are saved under
      ``password`` and installed on this client.

    Returns:
        ``(True, "Pairing complete.")`` on success. Otherwise
        ``(False, reason)``: pairing was not started, the server reported an
        error, the payload could not be imported, or the wait timed out.
    """
    if not self._pairing_temp_private_key:
        return False, "Pairing not started."

    poll_token = getattr(self, "_pairing_poll_token", "")
    # Inside a coroutine, use the running loop's clock for the deadline.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        resp = await self.send_and_recv("pairing_poll", code=code, poll_token=poll_token)
        if resp["status"] != "ok":
            return False, resp["data"]["message"]
        if not resp["data"].get("ready"):
            await asyncio.sleep(2.0)
            continue

        payload = resp["data"]["payload"]
        try:
            from cryptography.hazmat.primitives import hashes as rsa_hashes
            from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding

            # Unwrap the one-off AES key with the temporary RSA private key.
            aes_key = self._pairing_temp_private_key.decrypt(
                decode_binary(payload["encrypted_key"]),
                rsa_padding.OAEP(
                    mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()),
                    algorithm=rsa_hashes.SHA256(),
                    label=None,
                ),
            )
            keys_json = aes_decrypt(
                aes_key,
                decode_binary(payload["iv"]),
                decode_binary(payload["ciphertext"]),
                decode_binary(payload["tag"]),
            )
            keys_data = json.loads(keys_json)

            pwd_bytes = bytearray(password.encode("utf-8")) if password else None
            try:
                # Import RSA key
                rsa_priv = load_private_key(keys_data["rsa_private"].encode(), password=None)
                rsa_pub = rsa_priv.public_key()
                save_keys(email, rsa_priv, rsa_pub, password=bytes(pwd_bytes) if pwd_bytes else None)
                # Import identity keys
                ed_priv = load_ed25519_private(bytes.fromhex(keys_data["identity_private"]))
                ed_pub = ed_priv.public_key()
                _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None)
            finally:
                if pwd_bytes:
                    pwd_bytes[:] = b'\x00' * len(pwd_bytes)

            self.email = email
            self.private_key = rsa_priv
            self.public_key = rsa_pub
            self.identity_private = ed_priv
            self.identity_public = ed_pub
            self._cache_key = derive_self_encryption_key(ed_priv)
            self._local_key = derive_local_storage_key(ed_priv)
            self._load_verification_stores()
            self._pairing_temp_private_key = None
            # Multi-device: the new device generates its own SPK + OPKs on
            # first login. No session or sender-key import is needed, because
            # each device has independent Double Ratchet sessions.
            return True, "Pairing complete."
        except Exception as e:
            return False, f"Failed to import keys: {e}"
    return False, "Pairing timed out."
async def authorize_device(self, code: str) -> tuple[bool, str]:
    """Approve a pending pairing request by sending our keys to the new device.

    Steps:

    1. Claim the pairing ``code`` to obtain the new device's temporary RSA
       public key.
    2. Re-encrypt message history. This is best effort; failures are only
       logged.
    3. Send the RSA login private key and the raw Ed25519 identity private
       key to the new device.

    The key bundle is sealed with a fresh AES key. That AES key is wrapped
    with RSA-OAEP (SHA-256) for the temporary public key.

    Returns ``(True, "Device authorized.")`` or ``(False, message)``.
    """
    if not (self.private_key and self.identity_private):
        return False, "Not logged in."

    claim = await self.send_and_recv("pairing_claim", code=code)
    if claim["status"] != "ok":
        return False, claim["data"]["message"]
    new_device_pub = load_public_key(claim["data"]["temp_public_key"].encode("utf-8"))

    # Re-encrypting history lets the new device read older messages through
    # the self-encryption key, and advances ratchets for unfetched messages.
    try:
        await self.reencrypt_history()
    except Exception as exc:
        self._logger.warning("Re-encryption failed: %s", exc)

    # Only the RSA login key and the identity key are transferred. The new
    # device creates its own prekeys and sessions.
    transfer = {
        "rsa_private": serialize_private_key(self.private_key, password=None).decode(),
        "identity_private": serialize_ed25519_private_raw(self.identity_private).hex(),
    }

    from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding
    from cryptography.hazmat.primitives import hashes as rsa_hashes

    wrap_key, iv, sealed, tag = aes_encrypt(json.dumps(transfer).encode())
    oaep = rsa_padding.OAEP(
        mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()),
        algorithm=rsa_hashes.SHA256(),
        label=None,
    )
    wrapped_key = new_device_pub.encrypt(wrap_key, oaep)
    reply = await self.send_and_recv("pairing_send", code=code, payload={
        "encrypted_key": encode_binary(wrapped_key),
        "iv": encode_binary(iv),
        "ciphertext": encode_binary(sealed),
        "tag": encode_binary(tag),
    })
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "Device authorized."
# ------------------------------------------------------------------
# Password change (local key re-encryption only)
# ------------------------------------------------------------------
def change_password(self, old_password: str, new_password: str) -> tuple[bool, str]:
    """Re-encrypt the local RSA and identity private keys under a new password.

    Only the local key files change; nothing is sent to the server. An empty
    (or None) password is passed to the key store as ``None``, matching how
    the rest of this client loads and saves password-less key files. It is
    not passed as ``b""``.

    Returns:
        ``(True, "Password changed successfully.")``. Otherwise
        ``(False, reason)`` when not logged in, the current password is wrong,
        or the identity key cannot be loaded.
    """
    if not self.email:
        return False, "Not logged in."
    old_pwd = bytearray(old_password.encode("utf-8")) if old_password else bytearray()
    new_pwd = bytearray(new_password.encode("utf-8")) if new_password else bytearray()
    try:
        old_secret = bytes(old_pwd) if old_pwd else None
        new_secret = bytes(new_pwd) if new_pwd else None
        # 1. Verify the old password by loading both keys with it.
        priv, pub, _err = load_keys(self.email, password=old_secret)
        if priv is None:
            return False, "Wrong current password."
        ed_priv, ed_pub = _load_identity_keys(self.email, password=old_secret)
        if ed_priv is None:
            return False, "Failed to load identity key."
        # 2. Save both again under the new password.
        save_keys(self.email, priv, pub, password=new_secret)
        _save_identity_keys(self.email, ed_priv, ed_pub, password=new_secret)
        return True, "Password changed successfully."
    finally:
        # Best-effort wipe of the mutable password copies.
        old_pwd[:] = b'\x00' * len(old_pwd)
        new_pwd[:] = b'\x00' * len(new_pwd)
async def change_username(self, new_username: str) -> tuple[bool, str]:
    """Change this account's display name on the server.

    The name is stripped and must then be 1-100 characters long. On success,
    the name returned by the server is stored on ``self.username`` and in the
    session dict.
    """
    if not self.session:
        return False, "Not logged in."
    candidate = new_username.strip()
    if not 1 <= len(candidate) <= 100:
        return False, "Username must be 1-100 characters."
    reply = await self.send_and_recv("change_username", username=candidate)
    if reply["status"] != "ok":
        return False, reply["data"].get("message", "Unknown error")
    self.username = reply["data"]["username"]
    # The session may have been cleared while awaiting the server.
    if self.session:
        self.session["username"] = self.username
    return True, "Username changed."
# ------------------------------------------------------------------
# Key rotation (RSA login key only)
# ------------------------------------------------------------------
async def rotate_keys(self, username: str, password: str) -> tuple[bool, str]:
    """Rotate the RSA login keypair, revoking logins from other devices.

    A new keypair is generated and installed in memory while its public key
    is registered with ``rotate_keys``. The keypair is written to disk only
    after the server accepts it.

    If the request is rejected or raises, the previous in-memory keypair is
    restored and the on-disk key stays unchanged. This keeps the local login
    key in sync with the server's record, so the next login still works.

    Returns ``(True, "RSA login keys rotated.")`` or ``(False, message)``.
    """
    if not self.session or self.session.get("username") != username:
        return False, "Not logged in."

    priv, pub = generate_rsa_keypair()
    pub_pem = serialize_public_key(pub).decode("utf-8")

    previous = (self.private_key, self.public_key)
    self.private_key, self.public_key = priv, pub
    accepted = False
    try:
        resp = await self.send_and_recv("rotate_keys", public_key=pub_pem)
        accepted = resp["status"] == "ok"
    finally:
        if not accepted:
            self.private_key, self.public_key = previous
    if not accepted:
        return False, resp["data"]["message"]

    pwd_bytes = password.encode("utf-8") if password else None
    save_keys(self.email, priv, pub, password=pwd_bytes)
    return True, "RSA login keys rotated."
# ------------------------------------------------------------------
# Session management (X3DH + Double Ratchet)
# ------------------------------------------------------------------
async def _get_device_bundles(self, peer_user_id: str) -> list[dict]:
"""Get per-device key bundles for a peer. Caches for 5 minutes."""
import time
cached = self._device_bundle_cache.get(peer_user_id)
if cached:
ts, bundles = cached
if time.time() - ts < 300:
return bundles
resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id)
if resp["status"] != "ok":
raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}")
data = resp["data"]
ik_b64 = data.get("identity_key", "")
device_bundles = data.get("device_bundles")
if device_bundles:
# Attach identity_key to each bundle
for b in device_bundles:
b["identity_key"] = ik_b64
else:
# Old server: wrap flat response as single-entry list
device_bundles = [{
"device_id": None,
"identity_key": ik_b64,
"signed_prekey_id": data.get("signed_prekey_id", ""),
"signed_prekey": data.get("signed_prekey", ""),
"spk_signature": data.get("spk_signature", ""),
"one_time_prekey_id": data.get("one_time_prekey_id"),
"one_time_prekey": data.get("one_time_prekey"),
}]
self._device_bundle_cache[peer_user_id] = (time.time(), device_bundles)
return device_bundles
async def _get_or_create_session(self, peer_user_id: str,
                                 peer_device_id: str | None = None,
                                 bundle: dict | None = None) -> DoubleRatchet:
    """Load an existing session, or create one via X3DH as the initiator.

    Lookup order:

    1. the in-memory ``self.sessions`` cache;
    2. the session stored on disk;
    3. a fresh X3DH handshake against the peer's prekey bundle, followed by
       a Double Ratchet initialized as Alice.

    Args:
        peer_user_id: The peer's user id.
        peer_device_id: If set, the in-memory session is keyed by
            ``"user_id:device_id"``, and the id is passed to the on-disk
            session helpers. Otherwise the key is ``user_id`` alone.
        bundle: Prekey bundle to use for a new session. If omitted, it is
            fetched with ``get_key_bundle`` from the server.

    Returns:
        The session's DoubleRatchet. A newly created ratchet carries a
        transient ``_x3dh_header`` attribute. It holds our identity key, the
        ephemeral key and, if one was used, the one-time prekey id. The
        caller attaches it to the first outgoing message and then removes it.

    Raises:
        RuntimeError: if the key bundle cannot be fetched.
        IdentityKeyChanged: if the bundle's identity key conflicts with the
            one on record (TOFU check).
    """
    session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id
    # Check in-memory cache
    if session_key in self.sessions:
        return self.sessions[session_key]
    # Check on disk
    ratchet = _load_session(self.email, peer_user_id, self._local_key,
                            peer_device_id=peer_device_id)
    if ratchet:
        self.sessions[session_key] = ratchet
        return ratchet
    # Create new session via X3DH
    if not bundle:
        # NOTE(review): resp["data"] is used below as a flat bundle
        # ("identity_key", "signed_prekey", "spk_signature", ...). Confirm the
        # server still returns those top-level keys here.
        resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id)
        if resp["status"] != "ok":
            raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}")
        bundle = resp["data"]
    ik_remote_bytes = decode_binary(bundle["identity_key"])
    ik_remote = load_ed25519_public(ik_remote_bytes)
    # TOFU: verify identity key before using it in X3DH
    ik_status = self.check_identity_key(peer_user_id, ik_remote_bytes)
    if ik_status in ("changed", "changed_verified"):
        raise IdentityKeyChanged(peer_user_id, ik_remote_bytes, ik_status)
    spk_remote = load_x25519_public(decode_binary(bundle["signed_prekey"]))
    spk_sig = decode_binary(bundle["spk_signature"])
    # The one-time prekey is optional; the bundle may not contain one.
    opk_remote = None
    opk_id = bundle.get("one_time_prekey_id")
    if bundle.get("one_time_prekey"):
        opk_remote = load_x25519_public(decode_binary(bundle["one_time_prekey"]))
    # Perform X3DH. spk_sig is handed to x3dh_initiate, presumably so it can
    # verify the SPK against ik_remote (confirm in crypto_utils). ek_priv is
    # not used after this call.
    shared_secret, ek_priv, ek_pub = x3dh_initiate(
        self.identity_private,
        ik_remote,
        spk_remote,
        spk_sig,
        opk_remote,
    )
    # Initialize Double Ratchet as Alice
    ratchet = DoubleRatchet.init_alice(shared_secret, spk_remote)
    self.sessions[session_key] = ratchet
    # NOTE(review): the session is persisted before _x3dh_header is attached
    # below. Whether that attribute survives _save_session/_load_session is
    # not visible here; confirm a reloaded, never-used session still sends it.
    _save_session(self.email, peer_user_id, ratchet, self._local_key,
                  peer_device_id=peer_device_id)
    # Build X3DH header for first message
    x3dh_header = {
        "ik": encode_binary(serialize_ed25519_public(self.identity_public)),
        "ek": encode_binary(serialize_x25519_public(ek_pub)),
    }
    if opk_id:
        x3dh_header["opk_id"] = opk_id
    # Cache the x3dh header for the next send_message call
    ratchet._x3dh_header = x3dh_header
    # Cache remote user info
    # NOTE(review): this replaces the whole cache entry, so any other keys
    # stored for the peer (e.g. "username", read by verify_qr_code) are lost.
    self._user_cache[peer_user_id] = {
        "user_id": peer_user_id,
        "identity_key": ik_remote,
        "identity_key_bytes": ik_remote_bytes,
        "identity_key_status": ik_status,
    }
    return ratchet
def _process_x3dh_header(self, sender_id: str, x3dh_header: dict,
                         sender_device_id: str | None = None,
                         spk_override=None) -> DoubleRatchet:
    """Process an incoming X3DH header to establish a session as Bob (responder).

    Args:
        sender_id: User id of the message sender.
        x3dh_header: Header from the sender's first message. It has ``"ik"``
            (sender identity key) and ``"ek"`` (ephemeral key), both
            ``encode_binary``-encoded. It may also have ``"opk_id"``, naming
            the one-time prekey of ours that the sender used.
        sender_device_id: If set, the in-memory session is keyed by
            ``"sender_id:sender_device_id"`` and the id is passed to
            ``_save_session``. Otherwise the key is ``sender_id``.
        spk_override: If provided, use this SPK private key instead of
            ``self.spk_private``. Used for the grace-period fallback (M4).

    Returns:
        The new Bob-side DoubleRatchet. It is also cached in
        ``self.sessions`` and persisted with ``_save_session``.

    Raises:
        IdentityKeyChanged: if the sender's identity key conflicts with the
            one on record (TOFU check).
    """
    ik_remote_bytes = decode_binary(x3dh_header["ik"])
    ik_remote = load_ed25519_public(ik_remote_bytes)
    # TOFU: verify identity key before using it in X3DH
    ik_status = self.check_identity_key(sender_id, ik_remote_bytes)
    if ik_status in ("changed", "changed_verified"):
        raise IdentityKeyChanged(sender_id, ik_remote_bytes, ik_status)
    ek_remote = load_x25519_public(decode_binary(x3dh_header["ek"]))
    opk_id = x3dh_header.get("opk_id")
    opk_priv = None
    if opk_id:
        # NOTE(review): if opk_id is set but no private key is found on disk,
        # X3DH proceeds without the OPK. Presumably the derived secret then
        # differs from the sender's; confirm how callers surface that.
        opk_priv = _load_opk_private(self.email, opk_id, self._local_key)
        if opk_priv:
            # NOTE(review): the OPK is deleted from disk here, before
            # x3dh_respond runs and before any message has been decrypted.
            # A caller retrying with spk_override (grace period) would find
            # no OPK on the second attempt. self.opk_privates is also not
            # updated here. Confirm against the caller's retry logic.
            _delete_opk_private(self.email, opk_id)
    # Use the current SPK unless the caller supplied an older one.
    spk_priv = spk_override if spk_override else self.spk_private
    shared_secret = x3dh_respond(
        self.identity_private,
        spk_priv,
        ik_remote,
        ek_remote,
        opk_priv,
    )
    # The ratchet is seeded with our SPK keypair (public half derived here).
    spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None
    ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub))
    session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id
    self.sessions[session_key] = ratchet
    _save_session(self.email, sender_id, ratchet, self._local_key,
                  peer_device_id=sender_device_id)
    # NOTE(review): replaces the whole cache entry for sender_id.
    self._user_cache[sender_id] = {
        "user_id": sender_id,
        "identity_key": ik_remote,
        "identity_key_bytes": ik_remote_bytes,
        "identity_key_status": ik_status,
    }
    return ratchet
# ------------------------------------------------------------------
# Conversations
# ------------------------------------------------------------------
async def create_conversation(self, member_emails: list[str], name: str | None = None) -> tuple[str | None, str]:
kwargs = {"members": member_emails}
if name:
kwargs["name"] = name
resp = await self.send_and_recv("create_conversation", **kwargs)
if resp["status"] == "ok":
return resp["data"]["conversation_id"], "OK"
return None, resp["data"]["message"]
async def remove_member(self, conv_id: str, user_id: str) -> tuple[bool, str]:
    """Remove ``user_id`` from conversation ``conv_id`` on the server.

    Returns ``(True, "OK")`` or ``(False, server_message)``.
    """
    reply = await self.send_and_recv("remove_member", conversation_id=conv_id, user_id=user_id)
    succeeded = reply["status"] == "ok"
    return (True, "OK") if succeeded else (False, reply["data"]["message"])
async def leave_group(self, conv_id: str) -> tuple[bool, str]:
    """Leave a group conversation.

    On success, this client's sender-key state for ``conv_id`` is dropped
    from memory, along with every received sender key whose key starts with
    ``"<conv_id>:"``.
    """
    reply = await self.send_and_recv("leave_group", conversation_id=conv_id)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    self.sender_key_states.pop(conv_id, None)
    prefix = f"{conv_id}:"
    for stale_key in [k for k in self.recv_sender_keys if k.startswith(prefix)]:
        del self.recv_sender_keys[stale_key]
    return True, "OK"
async def rename_conversation(self, conv_id: str, name: str) -> tuple[bool, str]:
    """Rename a group conversation (creator only).

    Returns ``(True, "OK")`` or ``(False, server_message)``.
    """
    reply = await self.send_and_recv("rename_conversation", conversation_id=conv_id, name=name)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "OK"
async def delete_conversation(self, conv_id: str) -> tuple[bool, str]:
    """Delete a conversation (leave; the server cleans up if it becomes empty).

    On success, the local sender-key state for ``conv_id`` and every received
    sender key prefixed with ``"<conv_id>:"`` are removed from memory.
    """
    reply = await self.send_and_recv("delete_conversation", conversation_id=conv_id)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    self.sender_key_states.pop(conv_id, None)
    for key in list(self.recv_sender_keys):
        if key.startswith(f"{conv_id}:"):
            self.recv_sender_keys.pop(key, None)
    return True, "OK"
async def add_member(self, conv_id: str, email: str) -> tuple[bool, str]:
    """Add the user with ``email`` to conversation ``conv_id`` on the server.

    Returns ``(True, "OK")`` or ``(False, server_message)``.
    """
    reply = await self.send_and_recv("add_member", conversation_id=conv_id, email=email)
    added = reply["status"] == "ok"
    return (True, "OK") if added else (False, reply["data"]["message"])
async def accept_invitation(self, conv_id: str) -> tuple[bool, str]:
    """Accept a pending group invitation for ``conv_id``.

    Returns ``(True, "OK")`` or ``(False, server_message)``.
    """
    reply = await self.send_and_recv("accept_invitation", conversation_id=conv_id)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "OK"
async def decline_invitation(self, conv_id: str) -> tuple[bool, str]:
    """Decline a pending group invitation for ``conv_id``.

    Returns ``(True, "OK")`` or ``(False, server_message)``.
    """
    reply = await self.send_and_recv("decline_invitation", conversation_id=conv_id)
    declined = reply["status"] == "ok"
    return (True, "OK") if declined else (False, reply["data"]["message"])
async def list_invitations(self) -> list[dict]:
    """Return the pending group invitations, or an empty list on failure."""
    reply = await self.send_and_recv("list_invitations")
    return reply["data"]["invitations"] if reply["status"] == "ok" else []
async def list_conversations(self) -> list[dict]:
    """Return this user's conversations, or an empty list on failure."""
    reply = await self.send_and_recv("list_conversations")
    if reply["status"] != "ok":
        return []
    return reply["data"]["conversations"]
async def find_or_create_conversation(self, email: str) -> tuple[str | None, str]:
resp = await self.send_and_recv("find_conversation", email=email)
if resp["status"] != "ok":
return None, resp["data"]["message"]
conv_id = resp["data"]["conversation_id"]
if conv_id:
return conv_id, "OK"
return await self.create_conversation([email])
# ------------------------------------------------------------------
# Send message
# ------------------------------------------------------------------
def _is_group(self, members: list[dict]) -> bool:
return len(members) > 2
async def send_message(self, conv_id: str, text: str, members: list[dict],
reply_to: str | None = None) -> tuple[bool, str | dict]:
"""Encrypt and send a message. DM: per-recipient Double Ratchet. Group: Sender Keys.
Returns (True, msg_dict) on success or (False, error_string) on failure.
msg_dict contains the full decrypted payload ready for display.
"""
my_user_id = self.session["user_id"]
# Build plaintext payload
payload = {
"sender": self.username,
"text": text,
"reply_to": reply_to,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
if self._is_group(members):
return await self._send_group_message(conv_id, plaintext, members, payload)
else:
return await self._send_dm(conv_id, plaintext, members, payload)
async def _send_dm(self, conv_id: str, plaintext: bytes, members: list[dict],
payload: dict | None = None) -> tuple[bool, str | dict]:
"""Encrypt DM with per-device Double Ratchet."""
my_user_id = self.session["user_id"]
recipients = []
first_ratchet_header = None
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
# Get all device bundles for this user
try:
device_bundles = await self._get_device_bundles(uid)
self._logger.debug("Got %d device bundles for %s", len(device_bundles), uid)
except Exception as e:
self._logger.warning("Failed to get device bundles for %s: %s", uid, e)
device_bundles = []
if not device_bundles:
# Fallback: try single session (legacy peer)
ratchet = await self._get_or_create_session(uid)
result = ratchet.encrypt(plaintext)
x3dh_hdr = getattr(ratchet, "_x3dh_header", None)
if x3dh_hdr:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if x3dh_hdr:
entry["x3dh_header"] = x3dh_hdr
recipients.append(entry)
if first_ratchet_header is None:
first_ratchet_header = result["header"]
_save_session(self.email, uid, ratchet, self._local_key)
continue
for bundle in device_bundles:
dev_id = bundle.get("device_id")
ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
bundle=bundle)
result = ratchet.encrypt(plaintext)
x3dh_hdr = getattr(ratchet, "_x3dh_header", None)
if x3dh_hdr:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if dev_id:
entry["device_id"] = dev_id
if x3dh_hdr:
entry["x3dh_header"] = x3dh_hdr
recipients.append(entry)
if first_ratchet_header is None:
first_ratchet_header = result["header"]
_save_session(self.email, uid, ratchet, self._local_key,
peer_device_id=dev_id)
# Encrypt self-copy with static key derived from identity (not ratchet)
# Uses SELF_DEVICE_ID so all own devices can read it
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
if not recipients:
return False, "No recipients."
kwargs = {
"conversation_id": conv_id,
"ratchet_header": first_ratchet_header,
"recipients": recipients,
}
resp = await self.send_and_recv("send_message", **kwargs)
if resp["status"] == "ok":
msg_data = resp.get("data", {})
if payload is not None:
result = {
**payload,
"message_id": msg_data.get("message_id", ""),
"created_at": msg_data.get("created_at", ""),
"sender_id": self.session["user_id"],
"conversation_id": conv_id,
"read_by": [],
}
_save_message_to_cache(self.email, conv_id, result["message_id"], result, self._cache_key)
return True, result
return True, "Message sent."
return False, resp["data"]["message"]
async def _send_group_message(self, conv_id: str, plaintext: bytes,
                              members: list[dict],
                              payload: dict | None = None) -> tuple[bool, str | dict]:
    """Encrypt a group message with this client's Sender Key and send it.

    The sender-key state for ``conv_id`` is taken from memory, else loaded
    from disk. If neither exists, a new one is created, persisted and
    distributed to the members over pairwise sessions. The state is saved
    again after each encryption.

    Every other member gets the same sender-key ciphertext. Our own devices
    get a copy encrypted with the self-encryption key. On success with a
    ``payload``, the sent message dict is cached locally and returned.
    """
    me = self.session["user_id"]

    # Sender-key state: memory first, then disk, otherwise start a new chain.
    state = self.sender_key_states.get(conv_id)
    if not state:
        state = _load_sender_key_state(self.email, conv_id, self._local_key)
    if not state:
        state = SenderKeyState()
        self.sender_key_states[conv_id] = state
        _save_sender_key_state(self.email, conv_id, state, self._local_key)
        # NOTE(review): the key is distributed only here, when the chain is
        # first created; members added later do not receive it from this
        # method. Confirm redistribution happens elsewhere.
        await self._distribute_sender_key(conv_id, members, state)
    self.sender_key_states[conv_id] = state

    sealed = state.encrypt(plaintext)
    _save_sender_key_state(self.email, conv_id, state, self._local_key)

    # Every other member receives the identical sender-key ciphertext.
    recipients = [
        {
            "user_id": uid,
            "encrypted_content": encode_binary(sealed["ciphertext"]),
            "nonce": encode_binary(sealed["nonce"]),
        }
        for uid in (member.get("user_id") for member in members)
        if uid and uid != me
    ]

    # The self-encrypted copy lets our other devices and history fetches decrypt it.
    own_key = derive_self_encryption_key(self.identity_private)
    _, own_nonce, own_ct, own_tag = aes_encrypt(plaintext, key=own_key)
    recipients.append({
        "user_id": me,
        "encrypted_content": encode_binary(own_ct + own_tag),
        "nonce": encode_binary(own_nonce),
        "ratchet_header": {"self": True},
    })

    reply = await self.send_and_recv(
        "send_message",
        conversation_id=conv_id,
        # Dummy header: groups use sender keys, not a pairwise ratchet.
        ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0},
        recipients=recipients,
        sender_chain_id=encode_binary(bytes.fromhex(sealed["chain_id"])),
        sender_chain_n=sealed["n"],
    )
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    if payload is None:
        return True, "Message sent."
    server_info = reply.get("data", {})
    sent_msg = {
        **payload,
        "message_id": server_info.get("message_id", ""),
        "created_at": server_info.get("created_at", ""),
        "sender_id": self.session["user_id"],
        "conversation_id": conv_id,
        "read_by": [],
    }
    _save_message_to_cache(self.email, conv_id, sent_msg["message_id"], sent_msg, self._cache_key)
    return True, sent_msg
async def _distribute_sender_key(self, conv_id: str, members: list[dict],
                                 sk: SenderKeyState):
    """Send our Sender Key for ``conv_id`` to every other group member.

    The exported key is wrapped in a ``_sender_key`` control payload, padded,
    and encrypted separately for each of the member's devices with the
    pairwise Double Ratchet session. When no device bundles can be fetched,
    a single legacy (device-less) session is used instead. One
    ``send_message`` request is issued per member.

    Failures are isolated per member: an exception or a non-"ok" server
    response is logged and the remaining members are still processed.

    Args:
        conv_id: group conversation the key belongs to.
        members: member dicts; entries without ``user_id`` and our own entry
            are skipped.
        sk: our sender-key state whose key is exported and distributed.
    """
    my_user_id = self.session["user_id"]
    exported_key = sk.export_key()
    # Control payload recognised by the receiver's _decrypt_dm ("_sender_key").
    payload = {
        "sender": self.username,
        "text": "",
        "reply_to": None,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "_sender_key": {
            "conv_id": conv_id,
            "key": encode_binary(exported_key),
            "sender_device_id": self.device_id,
        },
    }
    plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
    for member in members:
        uid = member.get("user_id")
        if not uid or uid == my_user_id:
            continue
        try:
            try:
                device_bundles = await self._get_device_bundles(uid)
            except Exception:
                device_bundles = []
            recipients = []
            first_rh = None
            # No bundles -> one legacy session (bundle=None); otherwise one
            # ratchet-encrypted entry per device.
            for bundle in device_bundles or [None]:
                entry, header = await self._encrypt_sender_key_entry(uid, plaintext, bundle)
                recipients.append(entry)
                if first_rh is None:
                    first_rh = header
            resp = await self.send_and_recv(
                "send_message",
                conversation_id=conv_id,
                ratchet_header=first_rh,
                recipients=recipients,
            )
            if resp.get("status") != "ok":
                self._logger.warning(
                    "Server rejected sender key distribution to %s: %s",
                    uid, resp.get("data", {}).get("message"))
        except Exception as e:
            self._logger.warning("Failed to distribute sender key to %s: %s", uid, e)

async def _encrypt_sender_key_entry(self, uid: str, plaintext: bytes,
                                    bundle: dict | None) -> tuple[dict, dict]:
    """Ratchet-encrypt ``plaintext`` for one of ``uid``'s devices.

    ``bundle=None`` selects the legacy single-device session; otherwise the
    session for ``bundle["device_id"]`` is used (created from the bundle if
    needed). A pending X3DH header on the ratchet is attached once and then
    removed from the ratchet.

    The advanced ratchet is saved before returning, i.e. before any network
    send, for both the legacy and per-device cases. Previously the legacy
    path saved only after a successful send, so an exception there left the
    on-disk state behind a ratchet that had already encrypted.

    Returns:
        ``(recipient_entry, ratchet_header)`` for the ``send_message`` call.
    """
    if bundle is None:
        dev_id = None
        ratchet = await self._get_or_create_session(uid)
    else:
        dev_id = bundle.get("device_id")
        ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
                                                    bundle=bundle)
    result = ratchet.encrypt(plaintext)
    x3dh_header = getattr(ratchet, "_x3dh_header", None)
    if x3dh_header:
        # Only the first message of a fresh session carries the X3DH header.
        delattr(ratchet, "_x3dh_header")
    entry = {
        "user_id": uid,
        "encrypted_content": encode_binary(result["ciphertext"]),
        "nonce": encode_binary(result["nonce"]),
        "ratchet_header": result["header"],
    }
    if dev_id:
        entry["device_id"] = dev_id
    if x3dh_header:
        entry["x3dh_header"] = x3dh_header
    if bundle is None:
        _save_session(self.email, uid, ratchet, self._local_key)
    else:
        _save_session(self.email, uid, ratchet, self._local_key,
                      peer_device_id=dev_id)
    return entry, result["header"]
# ------------------------------------------------------------------
# Decrypt messages
# ------------------------------------------------------------------
def _decrypt_message(self, msg_data: dict) -> dict:
"""Decrypt a single message (DM or group)."""
# Check for self-encrypted marker FIRST — after re-encryption,
# group messages will have {"self": true} ratchet_header but still
# have sender_chain_id at message level.
rh = msg_data.get("ratchet_header", {})
if isinstance(rh, dict) and rh.get("self"):
return self._decrypt_dm(msg_data)
if msg_data.get("sender_chain_id"):
return self._decrypt_group(msg_data)
else:
return self._decrypt_dm(msg_data)
def _decrypt_dm(self, msg_data: dict) -> dict:
"""Decrypt DM using Double Ratchet with sender, or static key for self-copies."""
sender_id = msg_data.get("sender_id", "")
sender_device_id = msg_data.get("sender_device_id")
ratchet_header = msg_data.get("ratchet_header", {})
ct_b64 = msg_data.get("encrypted_content", "")
nonce_b64 = msg_data.get("nonce", "")
if not ct_b64 or not nonce_b64:
raise ValueError("Missing ciphertext or nonce")
ciphertext = decode_binary(ct_b64)
nonce = decode_binary(nonce_b64)
# Self-encrypted message (own sent message copy)
if isinstance(ratchet_header, dict) and ratchet_header.get("self"):
self_key = derive_self_encryption_key(self.identity_private)
ct = ciphertext[:-16]
tag = ciphertext[-16:]
plaintext = aes_decrypt(self_key, nonce, ct, tag)
else:
x3dh_header = msg_data.get("x3dh_header")
# Session key: "sender_id:sender_device_id" or just "sender_id" for legacy
session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id
# Try to load existing session
ratchet = self.sessions.get(session_key)
if not ratchet:
ratchet = _load_session(self.email, sender_id, self._local_key,
peer_device_id=sender_device_id)
if ratchet:
self.sessions[session_key] = ratchet
if ratchet and not x3dh_header:
# Normal case: existing session, no X3DH header
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
elif x3dh_header:
if ratchet:
# Existing session + X3DH header: sender may have reset.
backup = ratchet.export_state()
try:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
except Exception:
restored = DoubleRatchet.import_state(backup)
self.sessions[session_key] = restored
_save_session(self.email, sender_id, restored, self._local_key,
peer_device_id=sender_device_id)
ratchet = self._process_x3dh_header(sender_id, x3dh_header,
sender_device_id=sender_device_id)
try:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
except Exception:
if self._prev_spk_private:
ratchet = self._process_x3dh_header(
sender_id, x3dh_header,
sender_device_id=sender_device_id,
spk_override=self._prev_spk_private)
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
else:
raise
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
else:
ratchet = self._process_x3dh_header(sender_id, x3dh_header,
sender_device_id=sender_device_id)
try:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
except Exception:
if self._prev_spk_private:
ratchet = self._process_x3dh_header(
sender_id, x3dh_header,
sender_device_id=sender_device_id,
spk_override=self._prev_spk_private)
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
else:
raise
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
else:
raise ValueError(f"No session for sender {sender_id}")
plaintext = unpad_plaintext(plaintext)
payload = json.loads(plaintext)
# Handle sender key distribution messages
if "_sender_key" in payload:
sk_data = payload["_sender_key"]
sk_conv_id = sk_data["conv_id"]
sk_key = decode_binary(sk_data["key"])
sk_sender_device_id = sk_data.get("sender_device_id")
recv_sk = SenderKeyState.from_key(sk_key)
if sk_sender_device_id:
cache_key = f"{sk_conv_id}:{sender_id}:{sk_sender_device_id}"
else:
cache_key = f"{sk_conv_id}:{sender_id}"
self.recv_sender_keys[cache_key] = recv_sk
_save_recv_sender_key(self.email, sk_conv_id, sender_id, recv_sk, self._local_key,
sender_device_id=sk_sender_device_id)
# Return empty — this is a control message, not user-visible
return None
return payload
def _decrypt_group(self, msg_data: dict) -> dict:
"""Decrypt group message using sender's Sender Key."""
sender_id = msg_data.get("sender_id", "")
sender_device_id = msg_data.get("sender_device_id")
conv_id = msg_data.get("conversation_id", "")
chain_id_b64 = msg_data.get("sender_chain_id", "")
chain_n = msg_data.get("sender_chain_n", 0)
ct_b64 = msg_data.get("encrypted_content", "")
nonce_b64 = msg_data.get("nonce", "")
if not ct_b64 or not nonce_b64 or not chain_id_b64:
raise ValueError("Missing group message fields")
ciphertext = decode_binary(ct_b64)
nonce = decode_binary(nonce_b64)
chain_id = decode_binary(chain_id_b64)
my_user_id = self.session["user_id"]
# If we sent this message, use our own sender key
if sender_id == my_user_id:
sk = self.sender_key_states.get(conv_id)
if not sk:
sk = _load_sender_key_state(self.email, conv_id, self._local_key)
if sk:
self.sender_key_states[conv_id] = sk
if not sk:
raise ValueError("Own sender key not found")
# For our own messages, we can't decrypt from sender key (it's already advanced)
# Return a placeholder — the server echoed our ciphertext
raise ValueError("Cannot decrypt own group message from sender key")
# Use received sender key — try with sender_device_id first, fall back to without
sk = None
if sender_device_id:
cache_key = f"{conv_id}:{sender_id}:{sender_device_id}"
sk = self.recv_sender_keys.get(cache_key)
if not sk:
sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key,
sender_device_id=sender_device_id)
if sk:
self.recv_sender_keys[cache_key] = sk
if not sk:
# Fallback: try without device_id (legacy or same-device)
cache_key = f"{conv_id}:{sender_id}"
sk = self.recv_sender_keys.get(cache_key)
if not sk:
sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key)
if sk:
self.recv_sender_keys[cache_key] = sk
if not sk:
raise ValueError(f"No sender key for {sender_id} in conversation {conv_id}")
plaintext = unpad_plaintext(sk.decrypt(chain_id.hex(), chain_n, ciphertext, nonce))
_save_recv_sender_key(self.email, conv_id, sender_id, sk, self._local_key,
sender_device_id=sender_device_id)
return json.loads(plaintext)
# ------------------------------------------------------------------
# Get/decrypt messages (batch)
# ------------------------------------------------------------------
async def get_messages(self, conv_id: str, limit: int = 50, offset: int = 0) -> list[dict]:
    """Fetch, decrypt and cache messages for a conversation.

    With ``offset == 0`` and a non-empty local cache, only messages newer than
    the stored ``__last_server_ts`` are requested (incremental sync). Deletions
    since that timestamp are then applied, and the result is rebuilt from the
    whole cache. Otherwise the decrypted page is returned directly.

    Side effects: stores the newest server timestamp, confirms delivery of
    other users' messages and flushes pending self-encryptions as background
    tasks, and marks the whole conversation read.

    Returns:
        Message dicts in ascending order. If the request fails: the cached
        messages (first page only), otherwise an empty list.
    """
    cache = _load_message_cache(self.email, conv_id, self._cache_key)
    my_user_id = self.session["user_id"] if self.session else ""
    # Incremental sync: use stored server timestamp from last successful fetch.
    after_ts = None
    if cache and offset == 0:
        after_ts = cache.get("__last_server_ts", {}).get("ts")
    req_params = {"conversation_id": conv_id, "limit": limit, "offset": offset}
    if after_ts:
        req_params["after_ts"] = after_ts
    resp = await self.send_and_recv("get_messages", **req_params)
    if resp["status"] != "ok":
        # Offline fallback: return from cache if available
        if cache and offset == 0:
            return self._build_from_cache(cache)
        return []
    raw_messages = resp["data"]["messages"]
    raw_messages.reverse()  # Server returns DESC, reverse to ASC
    if raw_messages:
        # raw_messages are now ASC; the last one is the newest.
        newest_ts = raw_messages[-1].get("created_at", "")
        if newest_ts:
            cache["__last_server_ts"] = {"ts": newest_ts}
            _save_message_to_cache(self.email, conv_id, "__last_server_ts",
                                   {"ts": newest_ts}, cache_key=self._cache_key)
    new_decrypted = self._decrypt_raw_messages(raw_messages, cache, conv_id, my_user_id)
    deliver_ids = [m["message_id"] for m in new_decrypted
                   if m.get("sender_id") and m["sender_id"] != my_user_id
                   and not m.get("deleted")]
    if deliver_ids:
        self._spawn_background_task(self.confirm_delivery(conv_id, deliver_ids))
    # Mark entire conversation as read (bulk — server handles filtering)
    await self.mark_conversation_read(conv_id)
    if self._pending_self_encrypt:
        self._spawn_background_task(self._flush_self_encrypt())
    if after_ts:
        # Incremental: sync deletions, then build from cache
        try:
            del_resp = await self.send_and_recv("get_deleted_since",
                                                conversation_id=conv_id, since=after_ts)
            if del_resp.get("status") == "ok":
                for del_id in del_resp.get("data", {}).get("message_ids", []):
                    previous = cache.pop(del_id, None)
                    tombstone = {"deleted": True}
                    # Keep the timestamp so the tombstone sorts in place
                    # on later cache loads instead of at the top.
                    if isinstance(previous, dict) and previous.get("created_at"):
                        tombstone["created_at"] = previous["created_at"]
                    _save_message_to_cache(self.email, conv_id, del_id, tombstone,
                                           cache_key=self._cache_key)
        except Exception as e:
            # Best-effort: a failed deletion sync must not hide the messages.
            self._logger.warning("Failed to sync deletions for %s: %s", conv_id, e)
        return self._build_from_cache(cache)
    return new_decrypted

def _spawn_background_task(self, coro):
    """Schedule ``coro`` fire-and-forget while holding a strong reference.

    The event loop keeps only weak references to tasks, so an unreferenced
    task can be garbage-collected before it finishes. The reference is
    dropped automatically when the task completes.
    """
    tasks = getattr(self, "_background_task_refs", None)
    if tasks is None:
        tasks = set()
        self._background_task_refs = tasks
    task = asyncio.ensure_future(coro)
    tasks.add(task)
    task.add_done_callback(tasks.discard)
    return task
def _build_from_cache(self, cache: dict) -> list[dict]:
"""Build sorted message list from local cache (all messages)."""
messages = []
for msg_id, p in cache.items():
if p.get("_control") or msg_id.startswith("__"):
continue
entry = dict(p)
entry.setdefault("message_id", msg_id)
entry.setdefault("read_by", [])
entry.setdefault("delivered_to", [])
messages.append(entry)
messages.sort(key=lambda m: m.get("created_at", ""))
return messages
def _decrypt_raw_messages(self, raw_messages: list, cache: dict,
conv_id: str, my_user_id: str) -> list[dict]:
"""Decrypt server messages, update cache. Returns list of decrypted dicts."""
decrypted = []
for m in raw_messages:
msg_id = m["message_id"]
if m.get("deleted_at"):
decrypted.append({
"message_id": msg_id,
"sender": "",
"text": "",
"created_at": m["created_at"],
"read_by": [],
"sender_id": m.get("sender_id", ""),
"deleted": True,
})
cache[msg_id] = {"deleted": True, "created_at": m["created_at"]}
continue
# Check local cache first (ratchet keys are one-time use)
cached = cache.get(msg_id)
if cached and not cached.get("_control"):
cached["read_by"] = m.get("read_by", [])
cached["delivered_to"] = m.get("delivered_to", [])
cached["created_at"] = m["created_at"]
if m.get("reactions"):
cached["reactions"] = m["reactions"]
if m.get("pinned_at"):
cached["pinned_at"] = m["pinned_at"]
cached["pinned_by"] = m.get("pinned_by", "")
else:
cached.pop("pinned_at", None)
cached.pop("pinned_by", None)
decrypted.append(cached)
continue
if cached and cached.get("_control"):
continue
try:
msg_data = {
"sender_id": m.get("sender_id", ""),
"sender_device_id": m.get("sender_device_id"),
"conversation_id": conv_id,
"ratchet_header": m.get("ratchet_header", {}),
"encrypted_content": m.get("encrypted_content", ""),
"nonce": m.get("nonce", ""),
"x3dh_header": m.get("x3dh_header"),
"sender_chain_id": m.get("sender_chain_id"),
"sender_chain_n": m.get("sender_chain_n"),
}
payload = self._decrypt_message(msg_data)
if payload is None:
_save_message_to_cache(self.email, conv_id, msg_id, {"_control": True},
cache_key=self._cache_key)
cache[msg_id] = {"_control": True}
continue
payload["message_id"] = msg_id
payload["created_at"] = m["created_at"]
payload["read_by"] = m.get("read_by", [])
payload["delivered_to"] = m.get("delivered_to", [])
payload["sender_id"] = m.get("sender_id", "")
if m.get("reactions"):
payload["reactions"] = m["reactions"]
if m.get("pinned_at"):
payload["pinned_at"] = m["pinned_at"]
payload["pinned_by"] = m.get("pinned_by", "")
decrypted.append(payload)
_save_message_to_cache(self.email, conv_id, msg_id, payload,
cache_key=self._cache_key)
cache[msg_id] = payload
if m.get("sender_id", "") != my_user_id:
self._pending_self_encrypt.append({
"message_id": msg_id,
"payload": {k: v for k, v in payload.items()
if k not in ("message_id", "created_at", "read_by",
"delivered_to", "sender_id", "deleted")},
})
except Exception as e:
decrypted.append({
"message_id": msg_id,
"sender": "???",
"text": f"[Decryption failed: {e}]",
"created_at": m["created_at"],
"read_by": [],
})
return decrypted
async def _flush_self_encrypt(self):
    """Upload self-encrypted copies of received messages for multi-device access.

    Drains ``self._pending_self_encrypt``. Each queued payload is
    JSON-encoded, padded exactly like the send path's ``{"self": True}``
    copies, encrypted with the static self key, and uploaded via
    ``reencrypt_messages`` in batches of 500. This is best-effort: bad items
    and failed batches are logged and skipped, and the queue is cleared
    either way.
    """
    if not self._pending_self_encrypt or not self.identity_private:
        return
    self_key = derive_self_encryption_key(self.identity_private)
    updates = []
    for item in list(self._pending_self_encrypt):
        try:
            # Pad like the send path does: _decrypt_dm unpads every
            # self-encrypted copy, and padding hides the payload length.
            plaintext = pad_plaintext(
                json.dumps(item["payload"], ensure_ascii=False).encode("utf-8"))
            _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key)
            updates.append({
                "message_id": item["message_id"],
                "encrypted_content": encode_binary(ct + tag),
                "nonce": encode_binary(nonce),
            })
        except Exception as e:
            self._logger.warning("Skipping self-encryption of message %s: %s",
                                 item.get("message_id"), e)
    self._pending_self_encrypt.clear()
    for start in range(0, len(updates), 500):
        batch = updates[start:start + 500]
        try:
            await self.send_and_recv("reencrypt_messages", updates=batch)
        except Exception as e:
            # Keep going: one failed batch should not drop the others.
            self._logger.warning("Failed to self-encrypt received messages: %s", e)
async def mark_read(self, conv_id: str, message_ids: list[str]):
    """Report the given messages in ``conv_id`` as read.

    Nothing is sent when ``message_ids`` is empty.
    """
    if message_ids:
        await self.send_and_recv(
            "mark_read",
            conversation_id=conv_id,
            message_ids=message_ids,
        )
async def mark_conversation_read(self, conv_id: str):
    """Mark ALL unread messages in a conversation as read (server-side bulk).

    Best-effort: failures are logged at debug level and never raised, so
    message loading is not interrupted.
    """
    try:
        await self.send_and_recv("mark_conversation_read", conversation_id=conv_id)
    except Exception as e:
        # Non-critical: read receipts are advisory.
        self._logger.debug("mark_conversation_read failed for %s: %s", conv_id, e)
async def confirm_delivery(self, conv_id: str, message_ids: list[str]):
    """Confirm delivery of messages (fire-and-forget, non-critical).

    Nothing is sent for an empty ``message_ids``. Failures are logged at
    debug level and never raised.
    """
    if not message_ids:
        return
    try:
        await self.send_and_recv("confirm_delivery",
                                 conversation_id=conv_id, message_ids=message_ids)
    except Exception as e:
        # Non-critical: delivery receipts are advisory.
        self._logger.debug("confirm_delivery failed for %s: %s", conv_id, e)
def search_messages(self, conv_id: str, query: str) -> list[dict]:
    """Case-insensitive substring search over cached messages of a conversation.

    Deleted, control and sender-key entries are ignored, as are bookkeeping
    keys starting with ``"__"`` (e.g. ``__last_server_ts``), matching
    ``_build_from_cache``. A missing or ``None`` text is treated as "".

    Returns:
        Copies of the matching payloads with ``message_id`` set, sorted by
        ``created_at``.
    """
    cache = _load_message_cache(self.email, conv_id, self._cache_key)
    query_lower = query.lower()
    results = []
    for msg_id, payload in cache.items():
        if msg_id.startswith("__"):
            continue
        if payload.get("deleted") or payload.get("_control") or payload.get("_sender_key"):
            continue
        text = payload.get("text") or ""
        if query_lower in text.lower():
            entry = dict(payload)
            entry["message_id"] = msg_id
            results.append(entry)
    results.sort(key=lambda m: m.get("created_at", ""))
    return results
async def reset_session(self, peer_user_id: str, peer_device_id: str | None = None):
    """Forget our ratchet session with a peer (device) and tell the server.

    The in-memory session and its file are removed first. The
    ``session_reset`` request then carries ``peer_device_id`` as "" when no
    device is given.
    """
    key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id
    self.sessions.pop(key, None)
    _delete_session_file(self.email, peer_user_id, peer_device_id)
    await self.send_and_recv(
        "session_reset",
        peer_user_id=peer_user_id,
        peer_device_id=peer_device_id or "",
    )
def handle_session_reset_notification(self, from_user_id: str, from_device_id: str | None = None):
    """React to a peer's session reset by discarding our matching session.

    Removes the session from memory (key ``"user:device"`` or just
    ``"user"``) and deletes its stored file.
    """
    key = from_user_id if not from_device_id else f"{from_user_id}:{from_device_id}"
    self.sessions.pop(key, None)
    _delete_session_file(self.email, from_user_id, from_device_id)
# ------------------------------------------------------------------
# Local message cache updates
# ------------------------------------------------------------------
def load_message_cache(self, conv_id: str) -> dict:
    """Return the cached ``{message_id: payload}`` map for a conversation.

    Without a logged-in email there is no cache, so an empty dict is returned.
    """
    if self.email:
        return _load_message_cache(self.email, conv_id, self._cache_key)
    return {}
def update_message_in_cache(self, conv_id: str, message_id: str, updates: dict):
    """Apply ``updates`` to one cached message and rewrite the cache file.

    A value of ``None`` deletes that field; any other value overwrites it.
    Nothing happens without a logged-in email, for unknown message ids, or
    for control entries. The file is only rewritten when a cache key is set.
    """
    if not self.email:
        return
    cache = _load_message_cache(self.email, conv_id, self._cache_key)
    if message_id not in cache:
        return
    entry = cache[message_id]
    if entry.get("_control"):
        return
    for field, new_value in updates.items():
        if new_value is not None:
            entry[field] = new_value
        else:
            entry.pop(field, None)
    cache_dir = get_key_dir(self.email) / "message_cache"
    if self._cache_key:
        _save_message_cache_full(cache_dir, conv_id, cache, self._cache_key)
# ------------------------------------------------------------------
# Reactions, Pins, Forwarding
# ------------------------------------------------------------------
async def react_message(self, message_id: str, reaction: str, action: str = "add") -> tuple[bool, str]:
"""Add or remove a reaction on a message."""
resp = await self.send_and_recv("react_message",
message_id=message_id, reaction=reaction, action=action)
if resp["status"] == "ok":
return True, "OK"
return False, resp.get("data", {}).get("message", "Failed")
async def pin_message(self, message_id: str, conversation_id: str, action: str = "pin") -> tuple[bool, str]:
"""Pin or unpin a message."""
resp = await self.send_and_recv("pin_message",
message_id=message_id, conversation_id=conversation_id, action=action)
if resp["status"] == "ok":
return True, "OK"
return False, resp.get("data", {}).get("message", "Failed")
async def get_pinned_messages(self, conversation_id: str) -> list[dict]:
    """Return the server's list of pinned messages for a conversation.

    An error response, or an ok response without ``messages``, yields ``[]``.
    """
    resp = await self.send_and_recv("get_pinned_messages", conversation_id=conversation_id)
    return resp["data"].get("messages", []) if resp["status"] == "ok" else []
async def forward_message(self, target_conv_id: str, original_msg: dict,
                          target_members: list[dict]) -> tuple[bool, str | dict]:
    """Forward a message to another conversation.

    Builds a new payload with the original text, any ``image``/``file``
    attachment metadata (the encrypted blob already lives on the server, so
    only its reference and key material are re-sent) and a
    ``forwarded_from`` reference. It is then sent through the group or DM
    path like a normal message. A missing or ``None`` text becomes "".

    Returns:
        The ``(ok, result)`` tuple from the group or DM send path.
    """
    text = original_msg.get("text") or ""
    payload = {
        "sender": self.username,
        "text": text,
        "forwarded_from": {
            "sender": original_msg.get("sender", ""),
            "conversation_id": original_msg.get("conversation_id", ""),
            "message_id": original_msg.get("message_id", ""),
        },
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    for attachment in ("image", "file"):
        if original_msg.get(attachment):
            payload[attachment] = original_msg[attachment]
    plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
    if self._is_group(target_members):
        return await self._send_group_message(target_conv_id, plaintext, target_members, payload)
    return await self._send_dm(target_conv_id, plaintext, target_members, payload)
# ------------------------------------------------------------------
# Decrypt notification
# ------------------------------------------------------------------
def decrypt_notification(self, notif_data: dict) -> dict | None:
"""Decrypt a new_message notification. Returns parsed payload or None.
Supports new multi-device format (device_entries array) and legacy flat format.
"""
try:
conv_id = notif_data.get("conversation_id", "")
msg_id = notif_data.get("message_id", "")
sender_id = notif_data.get("sender_id", "")
sender_device_id = notif_data.get("sender_device_id")
my_user_id = self.session["user_id"] if self.session else ""
# Extract per-device encrypted content from device_entries or flat fields
encrypted_content = ""
nonce = ""
ratchet_header = {}
x3dh_header = None
device_entries = notif_data.get("device_entries")
if device_entries:
# Multi-device format: pick entry matching our device_id or SELF_DEVICE_ID
chosen = None
self_entry = None
for entry in device_entries:
eid = entry.get("device_id", "")
if eid == self.device_id:
chosen = entry
break
if eid == "00000000-0000-0000-0000-000000000000":
self_entry = entry
# If sender is us, prefer self-encrypted entry
if sender_id == my_user_id:
chosen = self_entry or chosen
elif not chosen:
chosen = self_entry
if not chosen:
self._logger.warning("No matching device_entry for device %s", self.device_id)
return None
encrypted_content = chosen.get("encrypted_content", "")
nonce = chosen.get("nonce", "")
ratchet_header = chosen.get("ratchet_header") or notif_data.get("ratchet_header", {})
x3dh_header = chosen.get("x3dh_header") or notif_data.get("x3dh_header")
else:
# Legacy flat format
encrypted_content = notif_data.get("encrypted_content", "")
nonce = notif_data.get("nonce", "")
ratchet_header = notif_data.get("ratchet_header", {})
x3dh_header = notif_data.get("x3dh_header")
msg_data = {
"sender_id": sender_id,
"sender_device_id": sender_device_id,
"conversation_id": conv_id,
"ratchet_header": ratchet_header,
"encrypted_content": encrypted_content,
"nonce": nonce,
"x3dh_header": x3dh_header,
"sender_chain_id": notif_data.get("sender_chain_id"),
"sender_chain_n": notif_data.get("sender_chain_n"),
}
payload = self._decrypt_message(msg_data)
if payload is None:
# Cache control message so get_messages skips it
if msg_id and conv_id:
_save_message_to_cache(self.email, conv_id, msg_id, {"_control": True},
cache_key=self._cache_key)
return None
payload["conversation_id"] = conv_id
payload["message_id"] = msg_id
payload["sender_id"] = sender_id
# Use server-compatible timestamp (no timezone suffix) for cache consistency
_ts = payload.get("timestamp", "")
if _ts:
# Strip timezone suffix (+00:00 or Z) to match server DATETIME format
_ts = _ts.replace("+00:00", "").replace("Z", "")
# Strip microseconds if present
if "." in _ts:
_ts = _ts[:_ts.index(".")]
payload["created_at"] = _ts
payload["read_by"] = []
payload["delivered_to"] = []
# Cache so get_messages doesn't re-decrypt (ratchet keys are one-time)
if msg_id and conv_id:
_save_message_to_cache(self.email, conv_id, msg_id, payload,
cache_key=self._cache_key)
# Queue self-encryption for received messages (multi-device access)
if sender_id != my_user_id and msg_id:
self._pending_self_encrypt.append({
"message_id": msg_id,
"payload": {k: v for k, v in payload.items()
if k not in ("conversation_id", "message_id", "created_at",
"read_by", "delivered_to", "sender_id", "deleted")},
})
return payload
except IdentityKeyChanged:
raise # Must propagate to caller for key-change UI
except Exception as e:
self._logger.warning("Failed to decrypt notification: %s", e)
return None
# ------------------------------------------------------------------
# Delete message
# ------------------------------------------------------------------
async def delete_message(self, message_id: str) -> tuple[bool, str]:
    """Ask the server to delete a message.

    Returns ``(True, "Message deleted.")`` on success, otherwise
    ``(False, <server error message>)``.
    """
    resp = await self.send_and_recv("delete_message", message_id=message_id)
    succeeded = resp["status"] == "ok"
    return (True, "Message deleted.") if succeeded else (False, resp["data"]["message"])
# ------------------------------------------------------------------
# Image sharing
# ------------------------------------------------------------------
async def send_image(self, conv_id: str, image_path: str, members: list[dict],
                     reply_to: str | None = None) -> tuple[bool, str | dict]:
    """Encrypt and upload an image, then send it as a message.

    The image bytes are encrypted with a fresh AES-256-GCM key and uploaded
    in ``IMAGE_CHUNK_SIZE`` chunks. The key, IV and a small JPEG thumbnail
    travel inside the end-to-end encrypted message payload. The original
    file is sent unchanged when its encrypted size fits ``MAX_IMAGE_BYTES``
    (or the limit is <= 0). Otherwise the image is re-encoded as JPEG at
    decreasing quality, and then at decreasing dimensions.

    Group conversations encrypt the payload with our Sender Key (created and
    distributed on first use). DMs use per-device Double Ratchet sessions.
    A self-encrypted copy for us is always included.

    Returns:
        ``(True, message_dict)`` on success (also written to the local
        cache), otherwise ``(False, error_message)``.
    """
    try:
        from PIL import Image
        import io
    except ImportError:
        return False, "Pillow is required for image sharing. Install with: pip install Pillow"
    path = Path(image_path)
    if not path.exists():
        return False, "File not found."
    try:
        # Decode fully, then detach from the file: Pillow keeps the handle
        # open for lazily loaded / multi-frame images until close().
        with Image.open(path) as opened:
            opened.load()
            img = opened.copy()
    except Exception as e:
        return False, f"Cannot open image: {e}"

    def _encrypted_size(data: bytes) -> int:
        # Upload size (ciphertext + tag), measured with a throwaway key.
        _, _, ct, tag = aes_encrypt(data)
        return len(ct) + len(tag)

    # Prefer the untouched original file (original format and quality).
    image_bytes = path.read_bytes()
    if MAX_IMAGE_BYTES > 0 and _encrypted_size(image_bytes) > MAX_IMAGE_BYTES:
        # Too large: re-encode as JPEG, lowering quality first...
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        for quality in (92, 85, 75, 60):
            buf = io.BytesIO()
            img.save(buf, format="JPEG", quality=quality)
            image_bytes = buf.getvalue()
            if _encrypted_size(image_bytes) <= MAX_IMAGE_BYTES:
                break
        else:
            # ...then shrinking dimensions at a fixed quality.
            for max_dim in (3840, 2560, 1920, 1280):
                if max(img.size) > max_dim:
                    img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)
                    buf = io.BytesIO()
                    img.save(buf, format="JPEG", quality=75)
                    image_bytes = buf.getvalue()
                    if _encrypted_size(image_bytes) <= MAX_IMAGE_BYTES:
                        break
        # NOTE(review): if nothing fits, the last attempt is still uploaded
        # and the server is left to reject it.

    # Generate thumbnail (embedded in the encrypted payload)
    thumb = img.copy()
    thumb.thumbnail((200, 200), Image.Resampling.LANCZOS)
    if thumb.mode not in ("RGB", "L"):
        thumb = thumb.convert("RGB")
    thumb_buf = io.BytesIO()
    thumb.save(thumb_buf, format="JPEG", quality=60)
    thumbnail_b64 = encode_binary(thumb_buf.getvalue())

    # Encrypt image with AES-256-GCM
    img_aes_key, img_iv, img_ct, img_tag = aes_encrypt(image_bytes)
    encrypted_image = img_ct + img_tag
    file_id = str(uuid.uuid4())
    file_size = len(encrypted_image)

    # Chunked upload
    resp = await self.send_and_recv(
        "upload_image_start",
        conversation_id=conv_id,
        file_id=file_id,
        file_size=file_size,
    )
    if resp["status"] != "ok":
        return False, resp["data"]["message"]
    upload_offset = 0
    while upload_offset < file_size:
        chunk = encrypted_image[upload_offset:upload_offset + IMAGE_CHUNK_SIZE]
        resp = await self.send_and_recv(
            "upload_image_chunk",
            file_id=file_id,
            data=encode_binary(chunk),
        )
        if resp["status"] != "ok":
            return False, resp["data"]["message"]
        upload_offset += len(chunk)
    resp = await self.send_and_recv("upload_image_end", file_id=file_id)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]

    # Build message payload with image info
    image_info = {
        "file_id": file_id,
        "aes_key": encode_binary(img_aes_key),
        "iv": encode_binary(img_iv),
        "thumbnail": thumbnail_b64,
        "filename": path.name,
        "size": len(image_bytes),
    }
    payload = {
        "sender": self.username,
        "text": "",
        "reply_to": reply_to,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "image": image_info,
    }
    plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
    my_user_id = self.session["user_id"]

    if self._is_group(members):
        # Group image: use sender key
        sk = self.sender_key_states.get(conv_id)
        if not sk:
            sk = _load_sender_key_state(self.email, conv_id, self._local_key)
            if not sk:
                sk = SenderKeyState()
                self.sender_key_states[conv_id] = sk
                _save_sender_key_state(self.email, conv_id, sk, self._local_key)
                await self._distribute_sender_key(conv_id, members, sk)
            # Cache keys loaded from disk too, as _send_group_message does.
            self.sender_key_states[conv_id] = sk
        result = sk.encrypt(plaintext)
        _save_sender_key_state(self.email, conv_id, sk, self._local_key)
        # Same Sender Key ciphertext for every other member.
        recipients = []
        for member in members:
            uid = member.get("user_id")
            if not uid or uid == my_user_id:
                continue
            recipients.append({
                "user_id": uid,
                "encrypted_content": encode_binary(result["ciphertext"]),
                "nonce": encode_binary(result["nonce"]),
            })
        # Self-encrypted copy for sender
        self_key = derive_self_encryption_key(self.identity_private)
        _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
        recipients.append({
            "user_id": my_user_id,
            "encrypted_content": encode_binary(self_ct + self_tag),
            "nonce": encode_binary(self_nonce),
            "ratchet_header": {"self": True},
        })
        resp = await self.send_and_recv(
            "send_message",
            conversation_id=conv_id,
            ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0},  # placeholder for groups
            recipients=recipients,
            sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])),
            sender_chain_n=result["n"],
            image_file_id=file_id,
        )
    else:
        # DM image: per-device ratchet (same pattern as _send_dm)
        recipients = []
        first_rh = None
        for member in members:
            uid = member.get("user_id")
            if not uid or uid == my_user_id:
                continue
            try:
                device_bundles = await self._get_device_bundles(uid)
            except Exception:
                device_bundles = []
            if not device_bundles:
                # Fallback: legacy single-device
                ratchet = await self._get_or_create_session(uid)
                result = ratchet.encrypt(plaintext)
                x3dh_h = getattr(ratchet, "_x3dh_header", None)
                if x3dh_h:
                    delattr(ratchet, "_x3dh_header")
                entry = {
                    "user_id": uid,
                    "encrypted_content": encode_binary(result["ciphertext"]),
                    "nonce": encode_binary(result["nonce"]),
                    "ratchet_header": result["header"],
                }
                if x3dh_h:
                    entry["x3dh_header"] = x3dh_h
                recipients.append(entry)
                if first_rh is None:
                    first_rh = result["header"]
                _save_session(self.email, uid, ratchet, self._local_key)
            else:
                for bundle in device_bundles:
                    dev_id = bundle.get("device_id")
                    ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
                                                                bundle=bundle)
                    result = ratchet.encrypt(plaintext)
                    x3dh_h = getattr(ratchet, "_x3dh_header", None)
                    if x3dh_h:
                        delattr(ratchet, "_x3dh_header")
                    entry = {
                        "user_id": uid,
                        "encrypted_content": encode_binary(result["ciphertext"]),
                        "nonce": encode_binary(result["nonce"]),
                        "ratchet_header": result["header"],
                    }
                    if dev_id:
                        entry["device_id"] = dev_id
                    if x3dh_h:
                        entry["x3dh_header"] = x3dh_h
                    recipients.append(entry)
                    if first_rh is None:
                        first_rh = result["header"]
                    _save_session(self.email, uid, ratchet, self._local_key,
                                  peer_device_id=dev_id)
        # Encrypt self-copy with static key
        self_key = derive_self_encryption_key(self.identity_private)
        _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
        recipients.append({
            "user_id": my_user_id,
            "encrypted_content": encode_binary(self_ct + self_tag),
            "nonce": encode_binary(self_nonce),
            "ratchet_header": {"self": True},
        })
        resp = await self.send_and_recv(
            "send_message",
            conversation_id=conv_id,
            ratchet_header=first_rh,
            recipients=recipients,
            image_file_id=file_id,
        )

    if resp["status"] == "ok":
        msg_data = resp.get("data", {})
        result_msg = {
            **payload,
            "message_id": msg_data.get("message_id", ""),
            "created_at": msg_data.get("created_at", ""),
            "sender_id": self.session["user_id"],
            "conversation_id": conv_id,
            "read_by": [],
        }
        _save_message_to_cache(self.email, conv_id, result_msg["message_id"], result_msg, self._cache_key)
        return True, result_msg
    return False, resp["data"]["message"]
async def send_file(self, conv_id: str, file_path: str, members: list[dict],
                    reply_to: str | None = None) -> tuple[bool, str | dict]:
    """Encrypt a local file, upload it in chunks, and post it as a message.

    The file body is sealed with a freshly generated AES key (``aes_encrypt``)
    and streamed to the server through the chunked ``upload_image_*``
    commands, tagged with ``file_type="file"``. The per-file key, IV and
    metadata then travel inside an end-to-end encrypted message payload:

    * groups: one Sender Key ciphertext shared by every other member;
    * DMs: one Double Ratchet ciphertext per recipient device (or a single
      legacy session when no device bundles are available).

    A copy encrypted with the sender's self-encryption key is always added
    so the sender can read the message back.

    Args:
        conv_id: Conversation to post into.
        file_path: Path of the local file to send.
        members: Conversation members; each entry is read for ``"user_id"``.
        reply_to: Optional id of the message being replied to.

    Returns:
        ``(True, message_dict)`` on success (the message is also written to
        the local cache), otherwise ``(False, error_text)``.
    """
    import mimetypes

    src = Path(file_path)
    if not src.exists():
        return False, "File not found."
    try:
        raw = src.read_bytes()
    except Exception as exc:
        return False, f"Cannot read file: {exc}"
    guessed_mime = mimetypes.guess_type(src.name)[0] or "application/octet-stream"

    # Seal the file body; the tag is appended to the ciphertext for upload.
    aes_key, iv, body_ct, body_tag = aes_encrypt(raw)
    blob = body_ct + body_tag
    upload_id = str(uuid.uuid4())
    blob_len = len(blob)

    # Chunked upload reuses the image transfer commands.
    reply = await self.send_and_recv(
        "upload_image_start",
        conversation_id=conv_id,
        file_id=upload_id,
        file_size=blob_len,
        file_type="file",
    )
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    for start in range(0, blob_len, IMAGE_CHUNK_SIZE):
        reply = await self.send_and_recv(
            "upload_image_chunk",
            file_id=upload_id,
            data=encode_binary(blob[start:start + IMAGE_CHUNK_SIZE]),
        )
        if reply["status"] != "ok":
            return False, reply["data"]["message"]
    reply = await self.send_and_recv("upload_image_end", file_id=upload_id)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]

    # Everything a recipient needs to fetch and open the blob.
    attachment = {
        "file_id": upload_id,
        "aes_key": encode_binary(aes_key),
        "iv": encode_binary(iv),
        "filename": src.name,
        "size": len(raw),
        "mime_type": guessed_mime,
    }
    payload = {
        "sender": self.username,
        "text": "",
        "reply_to": reply_to,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "file": attachment,
    }
    plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
    me = self.session["user_id"]

    def self_copy() -> dict:
        """Recipient entry only the sender can open (static self-encryption key)."""
        own_key = derive_self_encryption_key(self.identity_private)
        _, own_nonce, own_ct, own_tag = aes_encrypt(plaintext, key=own_key)
        return {
            "user_id": me,
            "encrypted_content": encode_binary(own_ct + own_tag),
            "nonce": encode_binary(own_nonce),
            "ratchet_header": {"self": True},
        }

    if self._is_group(members):
        state = self.sender_key_states.get(conv_id)
        if not state:
            state = _load_sender_key_state(self.email, conv_id, self._local_key)
        if not state:
            # Brand-new chain: remember it, persist it, and hand it to members.
            state = SenderKeyState()
            self.sender_key_states[conv_id] = state
            _save_sender_key_state(self.email, conv_id, state, self._local_key)
            await self._distribute_sender_key(conv_id, members, state)
        sealed = state.encrypt(plaintext)
        _save_sender_key_state(self.email, conv_id, state, self._local_key)
        # Every other member receives the same Sender Key ciphertext.
        shared_ct = encode_binary(sealed["ciphertext"])
        shared_nonce = encode_binary(sealed["nonce"])
        recipients = [
            {"user_id": peer, "encrypted_content": shared_ct, "nonce": shared_nonce}
            for peer in (m.get("user_id") for m in members)
            if peer and peer != me
        ]
        recipients.append(self_copy())
        reply = await self.send_and_recv(
            "send_message",
            conversation_id=conv_id,
            # Fixed placeholder ratchet header; group payloads use Sender Keys.
            ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0},
            recipients=recipients,
            sender_chain_id=encode_binary(bytes.fromhex(sealed["chain_id"])),
            sender_chain_n=sealed["n"],
            image_file_id=upload_id,
        )
    else:
        def ratchet_entry(peer: str, ratchet, dev_id=None) -> dict:
            """Encrypt *plaintext* on *ratchet* and build its recipient entry."""
            out = ratchet.encrypt(plaintext)
            # A session may carry a one-shot X3DH header; consume it here.
            handshake = getattr(ratchet, "_x3dh_header", None)
            if handshake:
                delattr(ratchet, "_x3dh_header")
            entry = {
                "user_id": peer,
                "encrypted_content": encode_binary(out["ciphertext"]),
                "nonce": encode_binary(out["nonce"]),
                "ratchet_header": out["header"],
            }
            if dev_id:
                entry["device_id"] = dev_id
            if handshake:
                entry["x3dh_header"] = handshake
            return entry

        # DM file: per-device ratchet (same pattern as _send_dm).
        recipients = []
        for member in members:
            peer = member.get("user_id")
            if not peer or peer == me:
                continue
            try:
                bundles = await self._get_device_bundles(peer)
            except Exception:
                bundles = []
            if bundles:
                for bundle in bundles:
                    dev_id = bundle.get("device_id")
                    ratchet = await self._get_or_create_session(peer, peer_device_id=dev_id,
                                                                bundle=bundle)
                    recipients.append(ratchet_entry(peer, ratchet, dev_id))
                    _save_session(self.email, peer, ratchet, self._local_key,
                                  peer_device_id=dev_id)
            else:
                # No device list: fall back to the legacy single-device session.
                ratchet = await self._get_or_create_session(peer)
                recipients.append(ratchet_entry(peer, ratchet))
                _save_session(self.email, peer, ratchet, self._local_key)
        # The top-level header is the one from the first ratchet encryption.
        first_header = recipients[0]["ratchet_header"] if recipients else None
        recipients.append(self_copy())
        reply = await self.send_and_recv(
            "send_message",
            conversation_id=conv_id,
            ratchet_header=first_header,
            recipients=recipients,
            image_file_id=upload_id,
        )

    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    meta = reply.get("data", {})
    sent = {
        **payload,
        "message_id": meta.get("message_id", ""),
        "created_at": meta.get("created_at", ""),
        "sender_id": self.session["user_id"],
        "conversation_id": conv_id,
        "read_by": [],
    }
    _save_message_to_cache(self.email, conv_id, sent["message_id"], sent, self._cache_key)
    return True, sent
async def download_file(self, file_id: str, file_info: dict) -> bytes | None:
    """Download an encrypted attachment and decrypt it.

    The blob is fetched with repeated ``download_image`` requests. Each request
    starts at the byte offset reached so far, and fetching stops when a response
    sets ``done``. The blob layout is ``ciphertext || 16-byte tag``, which is
    what ``send_file`` uploads. It is decrypted with the AES key and IV carried
    in *file_info*.

    Args:
        file_id: Server-side id of the uploaded blob.
        file_info: Attachment metadata from the message payload; must hold
            ``"aes_key"`` and ``"iv"`` in ``encode_binary`` form.

    Returns:
        The decrypted bytes, or ``None`` when any of the following happens:
        a request fails, the server stops making progress, the blob is
        shorter than a tag, or the key/IV are missing or decryption fails.
    """
    chunks: list[bytes] = []
    offset = 0
    while True:
        resp = await self.send_and_recv(
            "download_image",
            file_id=file_id,
            offset=offset,
        )
        if resp["status"] != "ok":
            return None
        data = resp["data"]
        chunk = decode_binary(data["data"])
        chunks.append(chunk)
        offset += len(chunk)
        if data.get("done"):
            break
        if not chunk:
            # Not done, yet nothing new arrived: asking again at the same
            # offset would loop forever, so treat it as a failed download.
            return None
    encrypted_data = b"".join(chunks)
    if len(encrypted_data) < 16:
        return None
    # The last 16 bytes are the authentication tag appended at upload time.
    ciphertext, tag = encrypted_data[:-16], encrypted_data[-16:]
    try:
        file_aes_key = decode_binary(file_info["aes_key"])
        iv = decode_binary(file_info["iv"])
        return aes_decrypt(file_aes_key, iv, ciphertext, tag)
    except Exception:
        # Best effort by contract: bad metadata or a failed integrity check
        # is reported to the caller as None rather than raised.
        return None
async def download_image(self, file_id: str, image_info: dict) -> bytes | None:
"""Download and decrypt an image. Returns decrypted image bytes or None."""
chunks = []
offset = 0
while True:
resp = await self.send_and_recv(
"download_image",
file_id=file_id,
offset=offset,
)
if resp["status"] != "ok":
return None
data = resp["data"]
chunk = decode_binary(data["data"])
chunks.append(chunk)
offset += len(chunk)
if data.get("done"):
break
encrypted_data = b"".join(chunks)
if len(encrypted_data) < 16:
return None
ciphertext = encrypted_data[:-16]
tag = encrypted_data[-16:]
try:
img_aes_key = decode_binary(image_info["aes_key"])
iv = decode_binary(image_info["iv"])
return aes_decrypt(img_aes_key, iv, ciphertext, tag)
except Exception:
return None
# ------------------------------------------------------------------
# Re-encrypt history (for device pairing)
# ------------------------------------------------------------------
async def reencrypt_history(self):
    """Re-encrypt all cached messages with the self-encryption key.

    After device pairing, the new device shares the same identity key
    but cannot decrypt old messages (Double Ratchet keys are one-time use).
    This re-encrypts all cached messages so they can be read using the
    self-encryption key derived from the shared identity key.

    Phase 1 pages through every conversation with ``get_messages`` (200 per
    page). This ensures that messages never opened on this device are
    decrypted into the local cache. If one conversation fails, the failure is
    logged and the remaining conversations are still fetched.

    Phase 2 rebuilds each cached payload, encrypts it with the
    self-encryption key and uploads the results in batches of 500 through
    ``reencrypt_messages``.

    Progress text is passed to ``self._reencrypt_progress_cb`` when that
    callback is set. The final report counts only batches the server
    accepted. Does nothing when there is no identity key or no active session.
    """
    if not self.identity_private or not self.session:
        return
    self_key = derive_self_encryption_key(self.identity_private)

    def report(text: str) -> None:
        # Looked up on every call rather than cached once.
        cb = self._reencrypt_progress_cb
        if cb:
            cb(text)

    # Phase 1: Fetch & decrypt all messages to populate cache
    # (messages the old device never opened won't be in cache yet)
    try:
        convs = list(await self.list_conversations() or [])
    except Exception as e:
        self._logger.warning("Failed to fetch messages for re-encryption: %s", e)
        convs = []
    total_convs = len(convs)
    for ci, conv in enumerate(convs):
        try:
            cid = conv.get("id") or conv.get("conversation_id")
            if not cid:
                continue
            report(f"Fetching messages: {ci + 1}/{total_convs} conversations...")
            offset = 0
            while True:
                msgs = await self.get_messages(cid, limit=200, offset=offset)
                if not msgs or len(msgs) < 200:
                    break
                offset += len(msgs)
        except Exception as e:
            # Isolate failures per conversation: one bad conversation must not
            # keep every later conversation out of the cache.
            self._logger.warning(
                "Failed to fetch messages for re-encryption (conversation #%d): %s",
                ci + 1, e,
            )

    # Phase 2: Read cache and re-encrypt
    cache_dir = get_key_dir(self.email) / "message_cache"
    if not cache_dir.exists():
        self._logger.info("No message cache to re-encrypt.")
        return
    conv_ids = sorted({f.stem for f in cache_dir.iterdir() if f.suffix in (".json", ".bin")})
    total_files = len(conv_ids)
    all_updates = []
    for i, conv_id in enumerate(conv_ids):
        cache = _load_message_cache(self.email, conv_id, self._cache_key)
        if not cache:
            continue
        for msg_id, entry in cache.items():
            # Skip control messages (sender key distribution)
            if entry.get("_control"):
                continue
            # Skip entries with no useful content
            text = entry.get("text", "")
            if not text and not entry.get("image") and not entry.get("file"):
                continue
            # Rebuild plaintext from cached payload (drop server/local metadata)
            payload = {k: v for k, v in entry.items()
                       if k not in ("message_id", "created_at", "read_by", "sender_id", "deleted")}
            plaintext = pad_plaintext(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
            # Re-encrypt with self-encryption key
            _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key)
            all_updates.append({
                "message_id": msg_id,
                "encrypted_content": encode_binary(ct + tag),
                "nonce": encode_binary(nonce),
            })
        report(f"Re-encrypting history: {i + 1}/{total_files} conversations...")
    if not all_updates:
        self._logger.info("No messages to re-encrypt.")
        return

    # Send in batches of 500, counting only what the server accepted.
    batch_size = 500
    total = len(all_updates)
    uploaded = 0
    for start in range(0, total, batch_size):
        batch = all_updates[start:start + batch_size]
        resp = await self.send_and_recv("reencrypt_messages", updates=batch)
        if resp["status"] != "ok":
            self._logger.warning("Re-encrypt batch failed: %s", resp.get("data", {}).get("message", ""))
            continue
        uploaded += len(batch)
        self._logger.info("Re-encrypted %d/%d messages.", uploaded, total)
    if uploaded == total:
        report(f"Re-encryption complete: {total} messages uploaded.")
    else:
        report(f"Re-encryption finished with errors: {uploaded}/{total} messages uploaded.")
# ------------------------------------------------------------------
# User Profiles
# ------------------------------------------------------------------
async def get_profile(self, user_id: str | None = None) -> dict | None:
"""Get user profile. If user_id is None, returns own profile."""
kwargs = {}
if user_id:
kwargs["user_id"] = user_id
resp = await self.send_and_recv("get_profile", **kwargs)
if resp["status"] == "ok":
return resp["data"]
return None
async def update_profile(self, **fields) -> tuple[bool, str]:
    """Send profile field changes (phone, location, ``*_visible`` flags).

    Every keyword argument is forwarded unchanged in the ``update_profile``
    request.

    Returns:
        ``(True, "OK")`` when accepted, otherwise ``(False, server_message)``.
    """
    reply = await self.send_and_recv("update_profile", **fields)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "OK"
async def update_avatar(self, image_data: bytes) -> tuple[bool, str]:
    """Upload *image_data* as the caller's avatar.

    Returns:
        ``(True, avatar_file)`` on success, where ``avatar_file`` is an empty
        string when the server omits it; otherwise ``(False, server_message)``.
    """
    reply = await self.send_and_recv("update_avatar", data=encode_binary(image_data))
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, reply["data"].get("avatar_file", "")
async def get_avatar(self, user_id: str) -> bytes | None:
    """Download the avatar image of *user_id*.

    Returns:
        The decoded image bytes, or ``None`` when the request is not ``ok``.
    """
    reply = await self.send_and_recv("get_avatar", user_id=user_id)
    if reply["status"] != "ok":
        return None
    return decode_binary(reply["data"]["data"])
async def update_group_avatar(self, conv_id: str, image_data: bytes) -> tuple[bool, str]:
    """Upload *image_data* as the avatar of group conversation *conv_id*.

    Returns:
        ``(True, avatar_file)`` on success, where ``avatar_file`` is an empty
        string when the server omits it; otherwise ``(False, server_message)``.
    """
    encoded = encode_binary(image_data)
    reply = await self.send_and_recv("update_group_avatar",
                                     conversation_id=conv_id, data=encoded)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, reply["data"].get("avatar_file", "")
async def get_group_avatar(self, conv_id: str) -> bytes | None:
    """Download the avatar of group conversation *conv_id*.

    Returns:
        The decoded image bytes, or ``None`` when the request is not ``ok``.
    """
    reply = await self.send_and_recv("get_group_avatar", conversation_id=conv_id)
    if reply["status"] != "ok":
        return None
    return decode_binary(reply["data"]["data"])
# ------------------------------------------------------------------
# Cleanup
# ------------------------------------------------------------------
async def close(self):
    """Tear down the connection.

    Marks the client disconnected, requests cancellation of the background
    listener task (if any) and closes the raw stream writer (if any). The
    attribute references themselves are left in place; ``reconnect`` resets
    them afterwards.
    """
    self.connected = False
    listener, writer = self._listener_task, self.raw_writer
    if listener:
        listener.cancel()
    if writer:
        writer.close()
async def reconnect(self):
"""Close existing connection and re-establish: connect + re-login using in-memory keys."""
try:
await self.close()
except Exception:
pass
# Reset reader/writer but keep keys and sessions
self.reader = None
self.writer = None
self.raw_writer = None
self._listener_task = None
self._pending.clear()
self.login_rejected = False
# Drain queues
while not self._response_queue.empty():
try:
self._response_queue.get_nowait()
except Exception:
break
while not self._notification_queue.empty():
try:
self._notification_queue.get_nowait()
except Exception:
break
await self.connect()
self._listener_task = asyncio.create_task(self._background_listener())
if self.email and self.private_key:
# RSA challenge-response login (keys already in memory)
start = await self.send_and_recv("login_start", email=self.email)
if start["status"] == "ok":
challenge = decode_binary(start["data"]["challenge"])
signature = rsa_sign(self.private_key, challenge)
login_kwargs = {
"email": self.email,
"signature": encode_binary(signature),
"client_version": VERSION,
}
if self.device_id:
login_kwargs["device_id"] = self.device_id
finish = await self.send_and_recv("login_finish", **login_kwargs)
if finish["status"] == "ok":
self.session = finish["data"]
asyncio.create_task(self._ensure_prekeys())
else:
# Login rejected — keys were likely rotated on another device
self.session = None
self.connected = False
self.login_rejected = True