# Source: Kecalek_python/zaloha/chat_core.py (repository file view)
# Last modified: 2026-03-11 16:54:14 +01:00
# 2610 lines, 107 KiB, Python

"""Shared network layer and ChatClient class for CLI and GUI clients.
Uses X3DH + Double Ratchet for message encryption, Sender Keys for groups.
RSA retained for login challenge-response only.
"""
import asyncio
import json
import logging
import os
import ssl
import uuid
from datetime import datetime, timezone
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
from crypto_utils import (
# RSA (login only)
generate_rsa_keypair,
serialize_private_key,
serialize_public_key,
load_private_key,
load_public_key,
rsa_sign,
# Ed25519
generate_identity_keypair,
serialize_ed25519_private,
serialize_ed25519_private_raw,
serialize_ed25519_public,
load_ed25519_private,
load_ed25519_public,
ed25519_sign,
# X25519
generate_x25519_keypair,
serialize_x25519_private,
serialize_x25519_public,
load_x25519_private,
load_x25519_public,
# X3DH
generate_signed_prekey,
generate_one_time_prekeys,
x3dh_initiate,
x3dh_respond,
# Double Ratchet
DoubleRatchet,
# Sender Keys
SenderKeyState,
# AES
aes_encrypt,
aes_decrypt,
# Self-encryption
derive_self_encryption_key,
# Local storage encryption
derive_local_storage_key,
)
from protocol import (
VERSION,
ProtocolReader,
ProtocolWriter,
encode_binary,
decode_binary,
build_request,
MAX_MESSAGE_BYTES,
IMAGE_CHUNK_SIZE,
)
# Root directory holding per-account key material on this machine.
KEY_DIR = Path.home().joinpath(".encrypted_chat")
# Replenish one-time prekeys on the server once fewer than this many remain.
OPK_REPLENISH_THRESHOLD = 20
# Number of one-time prekeys generated and uploaded per batch.
OPK_BATCH_SIZE = 50
# Rotate the signed prekey once it is at least this many days old.
SPK_ROTATION_DAYS = 7
def _encrypt_local(data: bytes, key: bytes) -> bytes:
    """Seal *data* with AES-256-GCM under *key* for on-disk storage.

    The returned blob is laid out as ``nonce (12) || tag (16) || ciphertext``,
    the layout that ``_decrypt_local`` splits apart again.
    """
    _unused_key, iv, sealed, auth_tag = aes_encrypt(data, key=key)
    return iv + auth_tag + sealed
def _decrypt_local(raw: bytes, key: bytes) -> bytes:
    """Open a blob produced by ``_encrypt_local`` (nonce || tag || ciphertext)."""
    header_end = 12 + 16
    iv = raw[:12]
    auth_tag = raw[12:header_end]
    sealed = raw[header_end:]
    return aes_decrypt(key, iv, sealed, auth_tag)
def get_key_dir(email: str) -> Path:
    """Return ``KEY_DIR/<email>``, creating it if needed and chmod-ing it to 0o700.

    Only the leaf directory is chmod-ed; any parents created along the way
    keep the process's default permissions.
    """
    account_dir = KEY_DIR.joinpath(email)
    account_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(account_dir, 0o700)
    return account_dir
# ---------------------------------------------------------------------------
# RSA key storage (login only — unchanged interface)
# ---------------------------------------------------------------------------
def save_keys(email: str, private_key, public_key, password: bytes | None = None):
    """Write the RSA login keypair as ``private.pem`` / ``public.pem``.

    *password* is forwarded to ``serialize_private_key``; the private key file
    is chmod-ed to 0o600 after both files are written.
    """
    key_dir = get_key_dir(email)
    priv_file = key_dir / "private.pem"
    pub_file = key_dir / "public.pem"
    priv_file.write_bytes(serialize_private_key(private_key, password=password))
    pub_file.write_bytes(serialize_public_key(public_key))
    os.chmod(priv_file, 0o600)
def load_keys(email: str, password: bytes | None = None):
    """Load the RSA login keypair for *email*.

    Returns ``(private_key, public_key, None)`` on success, or
    ``(None, None, error_message)`` when no key exists or it cannot be opened.

    If the private key fails to load with *password* but loads without one
    (a legacy unprotected key), it is accepted. When *password* is given, it is
    re-saved with that password. The re-save is best effort: a failure is
    logged and does not block use of the key.

    If ``public.pem`` is missing, the public key is derived from the private
    key instead of failing.
    """
    key_dir = get_key_dir(email)
    priv_path = key_dir / "private.pem"
    pub_path = key_dir / "public.pem"
    if not priv_path.exists():
        return None, None, "No local keys found."
    pem = priv_path.read_bytes()
    needs_migration = False
    try:
        private_key = load_private_key(pem, password=password)
    except Exception:
        try:
            private_key = load_private_key(pem, password=None)
        except Exception:
            return None, None, "Invalid or missing password."
        # Legacy key stored without a password: protect it now if we can.
        needs_migration = bool(password)
    if pub_path.exists():
        public_key = load_public_key(pub_path.read_bytes())
    else:
        public_key = private_key.public_key()
    if needs_migration:
        try:
            save_keys(email, private_key, public_key, password=password)
        except Exception as exc:
            logging.getLogger("encrypted_chat.client").warning(
                "Could not re-save RSA key with password protection: %s", exc
            )
    return private_key, public_key, None
# ---------------------------------------------------------------------------
# Identity + prekey storage
# ---------------------------------------------------------------------------
def _save_identity_keys(email: str, ed_priv, ed_pub, password: bytes | None = None):
    """Write the Ed25519 identity keypair to the account key directory.

    With a *password* the private key is serialized via
    ``serialize_ed25519_private``; without one, the raw serialization is
    stored. The private key file is chmod-ed to 0o600 at the end.
    """
    key_dir = get_key_dir(email)
    priv_file = key_dir / "identity_private.bin"
    blob = (
        serialize_ed25519_private(ed_priv, password=password)
        if password
        else serialize_ed25519_private_raw(ed_priv)
    )
    priv_file.write_bytes(blob)
    key_dir.joinpath("identity_public.bin").write_bytes(serialize_ed25519_public(ed_pub))
    os.chmod(priv_file, 0o600)
def _load_identity_keys(email: str, password: bytes | None = None):
    """Load the Ed25519 identity keypair; return ``(None, None)`` if absent.

    If ``identity_public.bin`` is missing, the public key is derived from the
    private key rather than raising FileNotFoundError. Errors raised by
    ``load_ed25519_private`` (its behavior on a wrong password depends on
    crypto_utils) propagate to the caller.
    """
    key_dir = get_key_dir(email)
    priv_path = key_dir / "identity_private.bin"
    pub_path = key_dir / "identity_public.bin"
    if not priv_path.exists():
        return None, None
    priv = load_ed25519_private(priv_path.read_bytes(), password=password)
    if pub_path.exists():
        pub = load_ed25519_public(pub_path.read_bytes())
    else:
        pub = priv.public_key()
    return priv, pub
def _save_spk(email: str, spk_priv, spk_id: str):
    """Persist the current signed prekey's private half and its id (key file 0o600)."""
    key_dir = get_key_dir(email)
    key_file = key_dir / "spk_private.bin"
    key_file.write_bytes(serialize_x25519_private(spk_priv))
    key_dir.joinpath("spk_id.txt").write_text(spk_id)
    os.chmod(key_file, 0o600)
def _load_spk(email: str):
    """Return ``(private_key, spk_id)`` for the current signed prekey.

    Returns ``(None, None)`` when no key file exists. The id is ``""`` when
    ``spk_id.txt`` is missing.
    """
    key_dir = get_key_dir(email)
    key_file = key_dir / "spk_private.bin"
    if not key_file.exists():
        return None, None
    private_key = load_x25519_private(key_file.read_bytes())
    id_file = key_dir / "spk_id.txt"
    stored_id = id_file.read_text().strip() if id_file.exists() else ""
    return private_key, stored_id
def _save_prev_spk(email: str, spk_priv, spk_id: str):
    """Keep the outgoing signed prekey for a grace period after rotation.

    An in-flight X3DH may still reference the old SPK, so its private half
    (mode 0o600) and id are stored separately from the current SPK.
    """
    key_dir = get_key_dir(email)
    key_file = key_dir / "prev_spk_private.bin"
    key_file.write_bytes(serialize_x25519_private(spk_priv))
    key_dir.joinpath("prev_spk_id.txt").write_text(spk_id)
    os.chmod(key_file, 0o600)
def _load_prev_spk(email: str):
    """Return ``(private_key, spk_id)`` of the grace-period SPK, or ``(None, None)``.

    The id is ``""`` when ``prev_spk_id.txt`` is missing.
    """
    key_dir = get_key_dir(email)
    key_file = key_dir / "prev_spk_private.bin"
    if not key_file.exists():
        return None, None
    private_key = load_x25519_private(key_file.read_bytes())
    id_file = key_dir / "prev_spk_id.txt"
    stored_id = id_file.read_text().strip() if id_file.exists() else ""
    return private_key, stored_id
def _save_opk_private(email: str, opk_id: str, opk_priv):
    """Store one one-time prekey private key as ``opk_private/<opk_id>.bin`` (0o600)."""
    opk_dir = get_key_dir(email).joinpath("opk_private")
    opk_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(opk_dir, 0o700)
    key_file = opk_dir / f"{opk_id}.bin"
    key_file.write_bytes(serialize_x25519_private(opk_priv))
    os.chmod(key_file, 0o600)
def _load_opk_private(email: str, opk_id: str):
    """Load a stored one-time prekey private key, or return None.

    *opk_id* must be a single file-name component. Ids containing path
    separators (e.g. ``"../identity_private"``) are rejected, because they
    would resolve to files outside ``opk_private/``.
    NOTE(review): the id presumably arrives in a peer's X3DH header; confirm
    at the call site.
    """
    name = f"{opk_id}"
    if not name or Path(name).name != name:
        return None
    p = get_key_dir(email) / "opk_private" / f"{name}.bin"
    if not p.exists():
        return None
    return load_x25519_private(p.read_bytes())
def _delete_opk_private(email: str, opk_id: str):
    """Best-effort removal of a consumed one-time prekey private key.

    Ids that are not a single file-name component are ignored, so a malformed
    id cannot delete files outside ``opk_private/`` (for example the identity
    key one level up). Filesystem errors are suppressed so deletion stays
    best effort.
    """
    name = f"{opk_id}"
    if not name or Path(name).name != name:
        return
    p = get_key_dir(email) / "opk_private" / f"{name}.bin"
    try:
        p.unlink(missing_ok=True)
    except (OSError, ValueError):
        # ValueError: e.g. an embedded NUL byte in the id.
        pass
def _save_device_id(email: str, device_id: str):
    """Remember this device's server-assigned id in ``device_id.txt`` (mode 0o600)."""
    id_file = get_key_dir(email).joinpath("device_id.txt")
    id_file.write_text(device_id)
    os.chmod(id_file, 0o600)
def _load_device_id(email: str) -> str | None:
    """Return the stored device id, or None when the file is absent or blank."""
    id_file = get_key_dir(email).joinpath("device_id.txt")
    if id_file.exists():
        return id_file.read_text().strip() or None
    return None
def _save_session(email: str, peer_user_id: str, ratchet: DoubleRatchet,
                  local_key: bytes | None = None, peer_device_id: str | None = None):
    """Persist a Double Ratchet session under ``sessions/`` (file mode 0o600).

    The file is ``<peer_user_id>_<peer_device_id>.bin`` when a device id is
    given, otherwise the legacy ``<peer_user_id>.bin``. The exported state is
    encrypted with *local_key* when one is provided.
    """
    sessions_dir = get_key_dir(email) / "sessions"
    sessions_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(sessions_dir, 0o700)
    stem = f"{peer_user_id}_{peer_device_id}" if peer_device_id else f"{peer_user_id}"
    target = sessions_dir / f"{stem}.bin"
    blob = ratchet.export_state()
    if local_key:
        blob = _encrypt_local(blob, local_key)
    target.write_bytes(blob)
    os.chmod(target, 0o600)
def _load_session(email: str, peer_user_id: str,
                  local_key: bytes | None = None,
                  peer_device_id: str | None = None) -> DoubleRatchet | None:
    """Load a persisted Double Ratchet session, or return None.

    With *peer_device_id*, the per-device file is preferred. If only the
    legacy per-user file exists, it is loaded, re-saved under the per-device
    name, and the legacy file is removed (best effort).
    """
    sessions_dir = get_key_dir(email) / "sessions"
    legacy = sessions_dir / f"{peer_user_id}.bin"
    if not peer_device_id:
        return _load_session_file(legacy, local_key) if legacy.exists() else None
    current = sessions_dir / f"{peer_user_id}_{peer_device_id}.bin"
    if current.exists():
        return _load_session_file(current, local_key)
    if not legacy.exists():
        return None
    # Migrate the legacy per-user session to the per-device layout.
    ratchet = _load_session_file(legacy, local_key)
    if ratchet:
        _save_session(email, peer_user_id, ratchet, local_key,
                      peer_device_id=peer_device_id)
        try:
            legacy.unlink()
        except Exception:
            pass
    return ratchet
def _load_session_file(p: Path, local_key: bytes | None = None) -> DoubleRatchet | None:
    """Rebuild a DoubleRatchet from the session file at *p* (None if missing).

    With *local_key*, the file is decrypted first. If decryption fails, the
    bytes are retried as a legacy plaintext export, and None is returned when
    that also fails. Without a key, the bytes are imported as-is.
    """
    if not p.exists():
        return None
    raw = p.read_bytes()
    if not local_key:
        return DoubleRatchet.import_state(raw)
    try:
        plain = _decrypt_local(raw, local_key)
    except Exception:
        # The file may predate local encryption: try it as plaintext.
        try:
            return DoubleRatchet.import_state(raw)
        except Exception:
            return None
    return DoubleRatchet.import_state(plain)
def _delete_session_file(email: str, peer_user_id: str, peer_device_id: str | None = None):
    """Remove a stored session file (used for session reset); errors are ignored."""
    sessions_dir = get_key_dir(email) / "sessions"
    stem = f"{peer_user_id}_{peer_device_id}" if peer_device_id else f"{peer_user_id}"
    target = sessions_dir / f"{stem}.bin"
    try:
        target.unlink(missing_ok=True)
    except Exception:
        pass
def _save_sender_key_state(email: str, conv_id: str, state: SenderKeyState,
                           local_key: bytes | None = None):
    """Persist our own sender-key state for *conv_id* under ``sender_keys/``.

    The exported state is encrypted with *local_key* when given. The file is
    chmod-ed to 0o600.
    """
    store = get_key_dir(email) / "sender_keys"
    store.mkdir(parents=True, exist_ok=True)
    os.chmod(store, 0o700)
    target = store / f"{conv_id}.bin"
    blob = state.export_state()
    if local_key:
        blob = _encrypt_local(blob, local_key)
    target.write_bytes(blob)
    os.chmod(target, 0o600)
def _load_sender_key_state(email: str, conv_id: str,
                           local_key: bytes | None = None) -> SenderKeyState | None:
    """Load our own sender-key state for *conv_id*, or None if not stored.

    With *local_key*, an undecryptable file is retried as a legacy plaintext
    export. If that succeeds, the state is immediately re-saved encrypted.
    If either step fails, None is returned.
    """
    path = get_key_dir(email) / "sender_keys" / f"{conv_id}.bin"
    if not path.exists():
        return None
    raw = path.read_bytes()
    if not local_key:
        return SenderKeyState.import_state(raw)
    try:
        plain = _decrypt_local(raw, local_key)
    except Exception:
        try:
            migrated = SenderKeyState.import_state(raw)
            _save_sender_key_state(email, conv_id, migrated, local_key)
            return migrated
        except Exception:
            return None
    return SenderKeyState.import_state(plain)
def _save_recv_sender_key(email: str, conv_id: str, sender_id: str, state: SenderKeyState,
                          local_key: bytes | None = None,
                          sender_device_id: str | None = None):
    """Persist a peer's sender-key state for a conversation.

    Stored under ``sender_keys_recv/`` as ``<conv>_<sender>_<device>.bin``, or
    as the legacy ``<conv>_<sender>.bin`` without a device id. It is encrypted
    with *local_key* when provided; the file mode is 0o600.
    """
    store = get_key_dir(email) / "sender_keys_recv"
    store.mkdir(parents=True, exist_ok=True)
    os.chmod(store, 0o700)
    parts = [f"{conv_id}", f"{sender_id}"]
    if sender_device_id:
        parts.append(f"{sender_device_id}")
    target = store / ("_".join(parts) + ".bin")
    blob = state.export_state()
    if local_key:
        blob = _encrypt_local(blob, local_key)
    target.write_bytes(blob)
    os.chmod(target, 0o600)
def _load_recv_sender_key(email: str, conv_id: str, sender_id: str,
                          local_key: bytes | None = None,
                          sender_device_id: str | None = None) -> SenderKeyState | None:
    """Load a peer's sender-key state, or return None.

    With *sender_device_id*, the per-device file is preferred. A legacy
    per-sender file is migrated: re-saved under the per-device name, then
    removed on a best-effort basis.
    """
    store = get_key_dir(email) / "sender_keys_recv"
    legacy = store / f"{conv_id}_{sender_id}.bin"
    if not sender_device_id:
        return _load_recv_sender_key_file(legacy, local_key) if legacy.exists() else None
    current = store / f"{conv_id}_{sender_id}_{sender_device_id}.bin"
    if current.exists():
        return _load_recv_sender_key_file(current, local_key)
    if not legacy.exists():
        return None
    state = _load_recv_sender_key_file(legacy, local_key)
    if state:
        _save_recv_sender_key(email, conv_id, sender_id, state, local_key,
                              sender_device_id=sender_device_id)
        try:
            legacy.unlink()
        except Exception:
            pass
    return state
def _load_recv_sender_key_file(p: Path, local_key: bytes | None = None) -> SenderKeyState | None:
    """Rebuild a received SenderKeyState from the file at *p* (None if missing).

    With *local_key*, decryption is attempted first. On failure, the bytes
    are retried as a legacy plaintext export, and None is returned if that
    fails too.
    """
    if not p.exists():
        return None
    raw = p.read_bytes()
    if not local_key:
        return SenderKeyState.import_state(raw)
    try:
        plain = _decrypt_local(raw, local_key)
    except Exception:
        try:
            return SenderKeyState.import_state(raw)
        except Exception:
            return None
    return SenderKeyState.import_state(plain)
# ---------------------------------------------------------------------------
# Local decrypted message cache (Double Ratchet keys are one-time use)
# ---------------------------------------------------------------------------
def _load_message_cache(email: str, conv_id: str, cache_key: bytes | None = None) -> dict:
    """Return the decrypted-message cache for *conv_id*, or ``{}`` if unavailable.

    A legacy plaintext ``<conv_id>.json`` is used when no ``.bin`` exists.
    With *cache_key*, it is migrated to the encrypted ``.bin`` file and the
    JSON is deleted. The ``.bin`` file (nonce || tag || ciphertext) is only
    readable with *cache_key*. Any read/decrypt/parse failure yields ``{}``.
    """
    cache_dir = get_key_dir(email) / "message_cache"
    bin_path = cache_dir / f"{conv_id}.bin"
    json_path = cache_dir / f"{conv_id}.json"
    if not bin_path.exists():
        if not json_path.exists():
            return {}
        try:
            legacy_cache = json.loads(json_path.read_text("utf-8"))
            if cache_key:
                _save_message_cache_full(cache_dir, conv_id, legacy_cache, cache_key)
                json_path.unlink(missing_ok=True)
            return legacy_cache
        except Exception:
            return {}
    if not cache_key:
        return {}
    try:
        decrypted = _decrypt_local(bin_path.read_bytes(), cache_key)
        return json.loads(decrypted.decode("utf-8"))
    except Exception:
        return {}
def _save_message_cache_full(d: Path, conv_id: str, cache: dict, cache_key: bytes):
    """Encrypt the whole cache dict and write it to ``<d>/<conv_id>.bin``.

    Layout: nonce (12) || tag (16) || ciphertext, the same layout
    ``_encrypt_local`` uses. The blob is written to a temporary sibling
    (mode 0o600) and then moved into place with ``os.replace``, which is
    atomic. A crash mid-write therefore cannot leave a truncated cache file.
    A truncated file would load as ``{}``, and the next save would overwrite
    it, losing plaintexts whose one-time ratchet keys are already spent.
    """
    d.mkdir(parents=True, exist_ok=True)
    os.chmod(d, 0o700)
    p = d / f"{conv_id}.bin"
    tmp = d / f"{conv_id}.bin.tmp"
    plaintext = json.dumps(cache, ensure_ascii=False).encode("utf-8")
    _key, nonce, ct, tag = aes_encrypt(plaintext, key=cache_key)
    tmp.write_bytes(nonce + tag + ct)
    os.chmod(tmp, 0o600)
    os.replace(tmp, p)
def _save_message_to_cache(email: str, conv_id: str, message_id: str, payload: dict,
                           cache_key: bytes | None = None):
    """Insert or replace one decrypted message in the conversation cache and persist it.

    With *cache_key*, the cache is written encrypted (``.bin``). Without one,
    it falls back to a plaintext ``<conv_id>.json`` with mode 0o600.
    """
    cache_dir = get_key_dir(email) / "message_cache"
    cache = _load_message_cache(email, conv_id, cache_key)
    cache[message_id] = payload
    if cache_key:
        _save_message_cache_full(cache_dir, conv_id, cache, cache_key)
        return
    # No identity-derived key available yet: store plaintext as a fallback.
    cache_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(cache_dir, 0o700)
    json_path = cache_dir / f"{conv_id}.json"
    json_path.write_text(json.dumps(cache, ensure_ascii=False), "utf-8")
    os.chmod(json_path, 0o600)
class ChatClient:
def __init__(self):
    """Create an unconnected client with empty key, session and cache state.

    Nothing is read from disk or the network here. ``connect()`` opens the
    socket; ``register()``, ``login()`` and ``pairing_wait()`` populate keys.
    """
    self.reader: ProtocolReader | None = None
    self.writer: ProtocolWriter | None = None
    self.raw_writer: asyncio.StreamWriter | None = None
    self.session: dict | None = None
    self.private_key = None  # RSA private key (login only)
    self.public_key = None  # RSA public key (login only)
    self.username: str = ""
    self.email: str = ""
    self._listener_task: asyncio.Task | None = None
    self._response_queue: asyncio.Queue = asyncio.Queue()
    self._notification_queue: asyncio.Queue = asyncio.Queue()
    self._pending: dict[str, asyncio.Future] = {}  # request_id -> reply future
    self._pairing_temp_private_key = None
    # Poll token returned by pairing_start(); declared here so the attribute
    # always exists rather than appearing only after pairing_start().
    self._pairing_poll_token: str = ""
    self._reencrypt_progress_cb = None
    self._logger = logging.getLogger("encrypted_chat.client")
    # Signal Protocol keys
    self.identity_private = None  # Ed25519PrivateKey
    self.identity_public = None  # Ed25519PublicKey
    self.spk_private = None  # X25519PrivateKey (current signed prekey)
    self.spk_id: str = ""
    self._prev_spk_private = None  # Previous SPK for grace period (M4)
    self._prev_spk_id: str = ""
    self.opk_privates: dict[str, object] = {}  # id -> X25519PrivateKey
    self.sessions: dict[str, DoubleRatchet] = {}  # "user_id:device_id" -> ratchet
    self.sender_key_states: dict[str, SenderKeyState] = {}  # conv_id -> own sender key
    self.recv_sender_keys: dict[str, SenderKeyState] = {}  # "conv_id:sender_id:device_id" -> their key
    # Cache: user_id -> {identity_key (Ed25519PublicKey), username, email}
    self._user_cache: dict[str, dict] = {}
    self.connected: bool = False
    self.login_rejected: bool = False
    self._cache_key: bytes | None = None  # AES key for encrypting message cache on disk
    self._local_key: bytes | None = None  # AES key for encrypting session/sender key files
    # Multi-device support
    self.device_id: str | None = None  # This device's UUID
    self._device_bundle_cache: dict[str, tuple[float, list[dict]]] = {}  # user_id -> (ts, bundles)
async def connect(self):
    """Open the (optionally TLS-wrapped) connection to the chat server.

    Reads SERVER_HOST, SERVER_PORT, TLS_ENABLED, TLS_REQUIRED, TLS_INSECURE,
    TLS_CA_FILE and ENVIRONMENT from the environment. Raises RuntimeError for
    contradictory TLS settings, or for TLS_INSECURE outside a dev environment.
    """
    truthy = ("1", "true", "yes")

    def _flag(name: str) -> bool:
        return os.getenv(name, "false").lower() in truthy

    host = os.getenv("SERVER_HOST", "127.0.0.1")
    port = int(os.getenv("SERVER_PORT", "9999"))
    tls_enabled = _flag("TLS_ENABLED")
    tls_required = _flag("TLS_REQUIRED")
    if tls_required and not tls_enabled:
        raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.")
    ssl_context = None
    if not tls_enabled:
        self._logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.")
    else:
        insecure = _flag("TLS_INSECURE")
        is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development")
        if insecure and not is_dev:
            raise RuntimeError("TLS_INSECURE is only allowed when ENVIRONMENT=dev")
        ssl_context = ssl.create_default_context()
        ca_file = os.getenv("TLS_CA_FILE", "").strip()
        if ca_file:
            ssl_context.load_verify_locations(cafile=ca_file)
        elif insecure:
            # Dev only: turn off hostname checking first, then certificate checks.
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
    stream_reader, stream_writer = await asyncio.open_connection(
        host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context
    )
    self.reader = ProtocolReader(stream_reader)
    self.writer = ProtocolWriter(stream_writer)
    self.raw_writer = stream_writer
    self.connected = True
async def _background_listener(self):
    """Read server messages until the stream ends, routing replies vs. notifications.

    Messages whose ``type`` is a known push notification go to
    ``_notification_queue``. Messages carrying a pending ``request_id``
    resolve that request's future. Anything else goes to ``_response_queue``.

    However the loop ends (EOF, an exception from ``read_message``, or task
    cancellation), the cleanup in ``finally`` always runs. It marks the client
    disconnected and fails every pending future with ConnectionError, so
    ``send_and_recv`` callers return promptly instead of waiting out their
    timeout. Exceptions still propagate after the cleanup.
    """
    notification_types = (
        "new_message", "messages_read", "message_deleted",
        "conversation_created", "member_added", "member_removed",
        "user_online", "user_offline", "online_users",
        "group_invitation", "conversation_renamed",
        "session_reset",
    )
    try:
        while True:
            msg = await self.reader.read_message()
            if msg is None:
                break
            if msg.get("type") in notification_types:
                await self._notification_queue.put(msg)
                continue
            req_id = msg.get("request_id")
            if req_id and req_id in self._pending:
                fut = self._pending.pop(req_id)
                if not fut.done():
                    fut.set_result(msg)
            else:
                await self._response_queue.put(msg)
    finally:
        self.connected = False
        pending = dict(self._pending)
        self._pending.clear()
        err = ConnectionError("Server connection lost")
        for fut in pending.values():
            if not fut.done():
                fut.set_exception(err)
async def send_and_recv(self, msg_type: str, timeout: float = 30.0, **kwargs) -> dict:
    """Send a request and await the reply matched by ``request_id``.

    Transport problems never raise. A send failure, lost or missing
    connection, or timeout returns an error response shaped like
    ``{"type": msg_type, "status": "error", "data": {"message": ...}}``.
    """
    def _error(message: str) -> dict:
        return {"type": msg_type, "status": "error", "data": {"message": message}}

    # Without a live connection nothing can resolve the reply future (the
    # listener has exited, or connect() was never called), so fail fast
    # instead of waiting for the full timeout.
    if not self.connected:
        return _error("Connection lost.")
    request_id = str(uuid.uuid4())
    fut = asyncio.get_running_loop().create_future()
    self._pending[request_id] = fut
    try:
        try:
            await self.writer.send_request(msg_type, request_id=request_id, **kwargs)
        except (ValueError, ConnectionError, OSError) as e:
            return _error(str(e) or "Connection lost.")
        return await asyncio.wait_for(fut, timeout=timeout)
    except asyncio.TimeoutError:
        self._logger.warning("send_and_recv timeout for '%s' after %.0fs", msg_type, timeout)
        return _error(f"Request timed out ({msg_type})")
    except ConnectionError:
        return _error("Connection lost.")
    finally:
        self._pending.pop(request_id, None)
# ------------------------------------------------------------------
# User info / identity key cache
# ------------------------------------------------------------------
async def _get_user_info(self, user_id: str = "", email: str = "") -> dict | None:
"""Get user info from server, cache identity key."""
cached = self._user_cache.get(user_id)
if cached:
return cached
kwargs = {}
if user_id:
kwargs["user_id"] = user_id
elif email:
kwargs["email"] = email
else:
return None
resp = await self.send_and_recv("get_user_info", **kwargs)
if resp["status"] != "ok":
return None
data = resp["data"]
ik_bytes = decode_binary(data["identity_key"]) if data.get("identity_key") else None
info = {
"user_id": data["user_id"],
"username": data["username"],
"email": data["email"],
"identity_key": load_ed25519_public(ik_bytes) if ik_bytes else None,
"identity_key_bytes": ik_bytes,
}
self._user_cache[data["user_id"]] = info
return info
# ------------------------------------------------------------------
# Registration
# ------------------------------------------------------------------
async def register(self, username: str, password: str, email: str) -> tuple[bool, str]:
    """Register a new account; returns ``(ok, code_or_message)``.

    Loads (or creates) the RSA login keypair and the Ed25519 identity keypair
    in the account key directory, then derives the local cache/storage keys
    and sends ``register``. On success, the second element is the
    server-provided code if present, otherwise the server's message.

    An existing RSA private key that cannot be loaded (e.g. wrong password)
    is never overwritten with a freshly generated one. Registration is
    aborted with the load error instead of silently destroying that key.
    """
    self.username = username
    self.email = email
    pwd_bytes = bytearray(password.encode("utf-8")) if password else None
    try:
        pwd = bytes(pwd_bytes) if pwd_bytes else None
        # RSA keys for login
        priv, pub, err = load_keys(email, password=pwd)
        if priv is None:
            if (get_key_dir(email) / "private.pem").exists():
                return False, err or "Local keys exist but could not be loaded."
            priv, pub = generate_rsa_keypair()
            save_keys(email, priv, pub, password=pwd)
        self.private_key = priv
        self.public_key = pub
        # Ed25519 identity keys
        ed_priv, ed_pub = _load_identity_keys(email, password=pwd)
        if ed_priv is None:
            ed_priv, ed_pub = generate_identity_keypair()
            _save_identity_keys(email, ed_priv, ed_pub, password=pwd)
        self.identity_private = ed_priv
        self.identity_public = ed_pub
        self._cache_key = derive_self_encryption_key(ed_priv)
        self._local_key = derive_local_storage_key(ed_priv)
    finally:
        # Best-effort wipe of the mutable password copy.
        if pwd_bytes:
            pwd_bytes[:] = b'\x00' * len(pwd_bytes)
    pub_pem = serialize_public_key(pub).decode("utf-8")
    ik_b64 = encode_binary(serialize_ed25519_public(ed_pub))
    start = await self.send_and_recv(
        "register",
        username=username,
        public_key=pub_pem,
        email=email,
        identity_key=ik_b64,
    )
    if start["status"] != "ok":
        return False, start["data"]["message"]
    code = start["data"].get("code")
    if code:
        return True, code
    return True, start["data"].get("message", "Check your email for the code.")
async def confirm_registration(self, email: str, username: str, code: str) -> tuple[bool, str]:
    """Submit the emailed registration code; on success upload fresh prekeys.

    Returns ``(ok, message)``.
    """
    reply = await self.send_and_recv("register_confirm", email=email, code=code)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    # Upload prekeys right away so peers can fetch a key bundle for X3DH.
    await self._generate_and_upload_prekeys()
    return True, f"Registered as '{username}' (ID: {reply['data']['user_id']})"
async def _generate_and_upload_prekeys(self, keep_spk: bool = False):
    """Generate a signed prekey (SPK) plus OPK_BATCH_SIZE one-time prekeys and upload them.

    With ``keep_spk=True`` and an SPK already loaded, that SPK is re-signed
    and re-uploaded unchanged. Otherwise the current SPK (if any) is kept as
    the grace-period "previous" SPK and a new one is generated and saved.
    OPK private halves are stored locally before the upload. Does nothing
    without an identity key. Upload failures are logged; local state has
    already been updated by then.
    """
    if not self.identity_private:
        return
    if keep_spk and self.spk_private and self.spk_id:
        # Re-sign the existing SPK (both devices share the identity key).
        spk_pub_bytes = serialize_x25519_public(self.spk_private.public_key())
        spk_sig = ed25519_sign(self.identity_private, spk_pub_bytes)
        spk_data = {
            "id": self.spk_id,
            "public_key": encode_binary(spk_pub_bytes),
            "signature": encode_binary(spk_sig),
        }
    else:
        # Keep the current SPK as "previous" for in-flight X3DH (M4).
        if self.spk_private and self.spk_id:
            self._prev_spk_private = self.spk_private
            self._prev_spk_id = self.spk_id
            _save_prev_spk(self.email, self.spk_private, self.spk_id)
        # Generate a brand-new signed prekey.
        spk = generate_signed_prekey(self.identity_private)
        self.spk_private = spk["private"]
        self.spk_id = spk["id"]
        _save_spk(self.email, spk["private"], spk["id"])
        spk_data = {
            "id": spk["id"],
            "public_key": encode_binary(serialize_x25519_public(spk["public"])),
            "signature": encode_binary(spk["signature"]),
        }
    # Generate one-time prekeys.
    opks = generate_one_time_prekeys(OPK_BATCH_SIZE)
    for opk in opks:
        self.opk_privates[opk["id"]] = opk["private"]
        _save_opk_private(self.email, opk["id"], opk["private"])
    otp_data = [
        {"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))}
        for opk in opks
    ]
    resp = await self.send_and_recv(
        "upload_prekeys",
        signed_prekey=spk_data,
        one_time_prekeys=otp_data,
    )
    if resp["status"] != "ok":
        self._logger.warning(
            "Prekey upload failed: %s",
            resp.get("data", {}).get("message", "unknown error"),
        )
async def _ensure_prekeys(self):
    """Replenish one-time prekeys and/or rotate the signed prekey when due.

    Asks the server for the remaining OPK count and the SPK creation time.
    New prekeys are generated and uploaded when fewer than
    OPK_REPLENISH_THRESHOLD OPKs remain or the SPK is at least
    SPK_ROTATION_DAYS old. Naive timestamps are treated as UTC.
    """
    resp = await self.send_and_recv("get_prekey_count")
    if resp["status"] != "ok":
        return
    count = resp["data"].get("count", 0)
    spk_created_at = resp["data"].get("spk_created_at", "")
    need_new_spk = False
    if spk_created_at:
        try:
            created = datetime.fromisoformat(spk_created_at)
        except (TypeError, ValueError) as exc:
            # An unparseable timestamp would otherwise silently disable SPK
            # rotation; make it visible.
            self._logger.warning("Cannot parse SPK timestamp %r: %s", spk_created_at, exc)
        else:
            if created.tzinfo is None:
                created = created.replace(tzinfo=timezone.utc)
            age_days = (datetime.now(timezone.utc) - created).days
            if age_days >= SPK_ROTATION_DAYS:
                need_new_spk = True
                self._logger.info("SPK is %d days old, rotating...", age_days)
    if count < OPK_REPLENISH_THRESHOLD or need_new_spk:
        if count >= OPK_REPLENISH_THRESHOLD:
            self._logger.info("SPK rotation triggered (OPK count OK: %d)", count)
        else:
            self._logger.info("OPK count low (%d), replenishing...", count)
        await self._generate_and_upload_prekeys()
# ------------------------------------------------------------------
# Login
# ------------------------------------------------------------------
async def login(self, email: str, password: str) -> tuple[bool, str]:
    """Log in with locally stored keys; returns ``(success, message)``.

    Loads the RSA login keypair (required) plus, when present, the Ed25519
    identity key, the current/previous signed prekeys and the stored device
    id. Then runs RSA challenge-response (``login_start``/``login_finish``).
    On success, prekey maintenance runs as a background task.
    """
    self.email = email
    pwd_bytes = bytearray(password.encode("utf-8")) if password else None
    try:
        # Load RSA keys
        priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
        if priv is None:
            return False, err or "No local keys found. Register first."
        self.private_key = priv
        self.public_key = pub
        # Load identity keys
        ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
    finally:
        # Best-effort wipe of the mutable password copy.
        if pwd_bytes:
            pwd_bytes[:] = b'\x00' * len(pwd_bytes)
    if ed_priv is not None:
        self.identity_private = ed_priv
        self.identity_public = ed_pub
        self._cache_key = derive_self_encryption_key(ed_priv)
        self._local_key = derive_local_storage_key(ed_priv)
    # Load SPK
    spk_priv, spk_id = _load_spk(email)
    if spk_priv:
        self.spk_private = spk_priv
        self.spk_id = spk_id
    # Load previous SPK for grace period (M4)
    prev_spk_priv, prev_spk_id = _load_prev_spk(email)
    if prev_spk_priv:
        self._prev_spk_private = prev_spk_priv
        self._prev_spk_id = prev_spk_id
    # Load device_id from disk
    self.device_id = _load_device_id(email)
    # RSA challenge-response login
    start = await self.send_and_recv("login_start", email=email)
    if start["status"] != "ok":
        return False, start["data"]["message"]
    challenge = decode_binary(start["data"]["challenge"])
    signature = rsa_sign(self.private_key, challenge)
    login_kwargs = {"email": email, "signature": encode_binary(signature),
                    "client_version": VERSION}
    if self.device_id:
        login_kwargs["device_id"] = self.device_id
    finish = await self.send_and_recv("login_finish", **login_kwargs)
    if finish["status"] != "ok":
        return False, finish["data"]["message"]
    self.session = finish["data"]
    self.username = self.session.get("username", "")
    # Store device_id from server
    self.device_id = finish["data"].get("device_id")
    if self.device_id:
        _save_device_id(email, self.device_id)
    # A device without local OPK private keys (e.g. freshly paired) cannot
    # use the server's existing OPKs, so it uploads a fresh batch.
    # keep_spk=True re-signs a locally stored SPK if there is one; a freshly
    # paired device has none, so a new SPK is generated in that case.
    opk_dir = get_key_dir(self.email) / "opk_private"
    has_local_opks = opk_dir.exists() and any(opk_dir.iterdir())
    if has_local_opks:
        prekey_job = self._ensure_prekeys()
    else:
        self._logger.info("No local OPKs (likely new device). Generating fresh OPKs, keeping SPK.")
        prekey_job = self._generate_and_upload_prekeys(keep_spk=True)
    # The event loop keeps only weak references to tasks; hold a strong one
    # until the task finishes so it cannot be garbage-collected mid-run.
    tasks = getattr(self, "_background_tasks", None)
    if tasks is None:
        tasks = self._background_tasks = set()
    task = asyncio.create_task(prekey_job)
    tasks.add(task)
    task.add_done_callback(tasks.discard)
    return True, f"Logged in as '{self.username}' (ID: {self.session['user_id']})"
# ------------------------------------------------------------------
# Pairing (device pairing — transfers RSA + identity keys)
# ------------------------------------------------------------------
async def pairing_start(self, email: str) -> tuple[bool, str]:
    """Begin pairing this device; returns ``(ok, pairing_code_or_error)``.

    A throwaway 2048-bit RSA keypair is generated and its public half is sent
    to the server. The authorizing device later encrypts the key payload to
    it. The server's poll token is remembered for ``pairing_wait``.
    """
    temp_private, temp_public = generate_rsa_keypair(2048)
    self._pairing_temp_private_key = temp_private
    reply = await self.send_and_recv(
        "pairing_start",
        email=email,
        temp_public_key=serialize_public_key(temp_public).decode("utf-8"),
    )
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    self._pairing_poll_token = reply["data"].get("poll_token", "")
    return True, reply["data"]["code"]
async def pairing_wait(self, code: str, email: str, password: str, timeout: int = 300) -> tuple[bool, str]:
    """Poll for the pairing payload and import the transferred keys.

    Polls ``pairing_poll`` (sleeping 2 s between not-ready replies) until the
    payload arrives or *timeout* seconds pass. The payload's AES key is
    unwrapped with RSA-OAEP(SHA-256) using the temporary key from
    ``pairing_start``. The decrypted JSON holds the RSA login key (PEM) and
    the raw Ed25519 identity key (hex). Both are saved locally (with
    *password*) and installed on this client. Returns ``(success, message)``.
    """
    if not self._pairing_temp_private_key:
        return False, "Pairing not started."
    from cryptography.hazmat.primitives import hashes as rsa_hashes
    from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding

    poll_token = getattr(self, "_pairing_poll_token", "")
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        resp = await self.send_and_recv("pairing_poll", code=code, poll_token=poll_token)
        if resp["status"] != "ok":
            return False, resp["data"]["message"]
        if not resp["data"].get("ready"):
            await asyncio.sleep(2.0)
            continue
        payload = resp["data"]["payload"]
        try:
            # Unwrap the one-off AES key with our temporary RSA key.
            enc_aes_key = decode_binary(payload["encrypted_key"])
            aes_key = self._pairing_temp_private_key.decrypt(
                enc_aes_key,
                rsa_padding.OAEP(
                    mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()),
                    algorithm=rsa_hashes.SHA256(),
                    label=None,
                ),
            )
            nonce = decode_binary(payload["iv"])
            ct = decode_binary(payload["ciphertext"])
            tag = decode_binary(payload["tag"])
            keys_data = json.loads(aes_decrypt(aes_key, nonce, ct, tag))
            pwd_bytes = bytearray(password.encode("utf-8")) if password else None
            try:
                # Import RSA key
                rsa_priv = load_private_key(keys_data["rsa_private"].encode(), password=None)
                rsa_pub = rsa_priv.public_key()
                save_keys(email, rsa_priv, rsa_pub, password=bytes(pwd_bytes) if pwd_bytes else None)
                # Import identity keys
                ed_priv = load_ed25519_private(bytes.fromhex(keys_data["identity_private"]))
                ed_pub = ed_priv.public_key()
                _save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None)
            finally:
                if pwd_bytes:
                    pwd_bytes[:] = b'\x00' * len(pwd_bytes)
            self.email = email
            self.private_key = rsa_priv
            self.public_key = rsa_pub
            self.identity_private = ed_priv
            self.identity_public = ed_pub
            self._cache_key = derive_self_encryption_key(ed_priv)
            self._local_key = derive_local_storage_key(ed_priv)
            self._pairing_temp_private_key = None
            # Multi-device: this device creates its own SPK + OPKs on first
            # login; sessions and sender keys are not transferred.
            return True, "Pairing complete."
        except Exception as e:
            return False, f"Failed to import keys: {e}"
    return False, "Pairing timed out."
async def authorize_device(self, code: str) -> tuple[bool, str]:
    """Approve another device's pairing request and send it our login/identity keys.

    Steps:
    1. Claim *code* to obtain the new device's temporary RSA public key.
    2. Re-encrypt message history (best effort; failures are only logged).
    3. Encrypt the RSA login key and raw Ed25519 identity key with a fresh
       AES key, wrap that key with RSA-OAEP(SHA-256), and send the payload.

    Returns ``(ok, message)``.
    """
    if not (self.private_key and self.identity_private):
        return False, "Not logged in."
    claim = await self.send_and_recv("pairing_claim", code=code)
    if claim["status"] != "ok":
        return False, claim["data"]["message"]
    temp_pub = load_public_key(claim["data"]["temp_public_key"].encode("utf-8"))

    # Re-encrypt history so the new device can read older messages via the
    # self-encryption key; this also advances ratchets for unfetched messages.
    try:
        await self.reencrypt_history()
    except Exception as e:
        self._logger.warning("Re-encryption failed: %s", e)

    from cryptography.hazmat.primitives import hashes as rsa_hashes
    from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding

    # Only RSA + identity keys are transferred; the new device builds its own
    # prekeys and sessions.
    key_blob = json.dumps({
        "rsa_private": serialize_private_key(self.private_key, password=None).decode(),
        "identity_private": serialize_ed25519_private_raw(self.identity_private).hex(),
    }).encode()
    aes_key, nonce, ct, tag = aes_encrypt(key_blob)
    oaep = rsa_padding.OAEP(
        mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()),
        algorithm=rsa_hashes.SHA256(),
        label=None,
    )
    wrapped_key = temp_pub.encrypt(aes_key, oaep)
    reply = await self.send_and_recv("pairing_send", code=code, payload={
        "encrypted_key": encode_binary(wrapped_key),
        "iv": encode_binary(nonce),
        "ciphertext": encode_binary(ct),
        "tag": encode_binary(tag),
    })
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "Device authorized."
# ------------------------------------------------------------------
# Key rotation (RSA login key only)
# ------------------------------------------------------------------
async def rotate_keys(self, username: str, password: str) -> tuple[bool, str]:
    """Replace the RSA login keypair to revoke other devices; returns ``(ok, message)``.

    The new keypair is written to disk and adopted in memory only after the
    server accepts it. Saving it beforehand would leave ``private.pem``
    holding a key the server never registered whenever the rotation is
    rejected, locking this device out of challenge-response login.

    NOTE: if the request times out, the server may or may not have applied
    the new key; the old key is kept locally in that case.
    """
    if not self.session or self.session.get("username") != username:
        return False, "Not logged in."
    new_priv, new_pub = generate_rsa_keypair()
    pub_pem = serialize_public_key(new_pub).decode("utf-8")
    resp = await self.send_and_recv("rotate_keys", public_key=pub_pem)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]
    pwd_bytes = password.encode("utf-8") if password else None
    save_keys(self.email, new_priv, new_pub, password=pwd_bytes)
    self.private_key = new_priv
    self.public_key = new_pub
    return True, "RSA login keys rotated."
# ------------------------------------------------------------------
# Session management (X3DH + Double Ratchet)
# ------------------------------------------------------------------
async def _get_device_bundles(self, peer_user_id: str) -> list[dict]:
"""Get per-device key bundles for a peer. Caches for 5 minutes."""
import time
cached = self._device_bundle_cache.get(peer_user_id)
if cached:
ts, bundles = cached
if time.time() - ts < 300:
return bundles
resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id)
if resp["status"] != "ok":
raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}")
data = resp["data"]
ik_b64 = data.get("identity_key", "")
device_bundles = data.get("device_bundles")
if device_bundles:
# Attach identity_key to each bundle
for b in device_bundles:
b["identity_key"] = ik_b64
else:
# Old server: wrap flat response as single-entry list
device_bundles = [{
"device_id": None,
"identity_key": ik_b64,
"signed_prekey_id": data.get("signed_prekey_id", ""),
"signed_prekey": data.get("signed_prekey", ""),
"spk_signature": data.get("spk_signature", ""),
"one_time_prekey_id": data.get("one_time_prekey_id"),
"one_time_prekey": data.get("one_time_prekey"),
}]
self._device_bundle_cache[peer_user_id] = (time.time(), device_bundles)
return device_bundles
async def _get_or_create_session(self, peer_user_id: str,
peer_device_id: str | None = None,
bundle: dict | None = None) -> DoubleRatchet:
"""Load existing session or create one via X3DH.
If peer_device_id is set, sessions are keyed by "user_id:device_id".
If bundle is provided, it's used instead of fetching from server.
"""
session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id
# Check in-memory cache
if session_key in self.sessions:
return self.sessions[session_key]
# Check on disk
ratchet = _load_session(self.email, peer_user_id, self._local_key,
peer_device_id=peer_device_id)
if ratchet:
self.sessions[session_key] = ratchet
return ratchet
# Create new session via X3DH
if not bundle:
resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id)
if resp["status"] != "ok":
raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}")
bundle = resp["data"]
ik_remote_bytes = decode_binary(bundle["identity_key"])
ik_remote = load_ed25519_public(ik_remote_bytes)
spk_remote = load_x25519_public(decode_binary(bundle["signed_prekey"]))
spk_sig = decode_binary(bundle["spk_signature"])
opk_remote = None
opk_id = bundle.get("one_time_prekey_id")
if bundle.get("one_time_prekey"):
opk_remote = load_x25519_public(decode_binary(bundle["one_time_prekey"]))
# Perform X3DH
shared_secret, ek_priv, ek_pub = x3dh_initiate(
self.identity_private,
ik_remote,
spk_remote,
spk_sig,
opk_remote,
)
# Initialize Double Ratchet as Alice
ratchet = DoubleRatchet.init_alice(shared_secret, spk_remote)
self.sessions[session_key] = ratchet
_save_session(self.email, peer_user_id, ratchet, self._local_key,
peer_device_id=peer_device_id)
# Build X3DH header for first message
x3dh_header = {
"ik": encode_binary(serialize_ed25519_public(self.identity_public)),
"ek": encode_binary(serialize_x25519_public(ek_pub)),
}
if opk_id:
x3dh_header["opk_id"] = opk_id
# Cache the x3dh header for the next send_message call
ratchet._x3dh_header = x3dh_header
# Cache remote user info
self._user_cache[peer_user_id] = {
"user_id": peer_user_id,
"identity_key": ik_remote,
"identity_key_bytes": ik_remote_bytes,
}
return ratchet
def _process_x3dh_header(self, sender_id: str, x3dh_header: dict,
sender_device_id: str | None = None,
spk_override=None) -> DoubleRatchet:
"""Process an incoming X3DH header to establish session as Bob.
Args:
spk_override: If provided, use this SPK private key instead of self.spk_private.
Used for grace period fallback (M4).
"""
ik_remote_bytes = decode_binary(x3dh_header["ik"])
ik_remote = load_ed25519_public(ik_remote_bytes)
ek_remote = load_x25519_public(decode_binary(x3dh_header["ek"]))
opk_id = x3dh_header.get("opk_id")
opk_priv = None
if opk_id:
opk_priv = _load_opk_private(self.email, opk_id)
if opk_priv:
_delete_opk_private(self.email, opk_id)
spk_priv = spk_override if spk_override else self.spk_private
shared_secret = x3dh_respond(
self.identity_private,
spk_priv,
ik_remote,
ek_remote,
opk_priv,
)
spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None
ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub))
session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id
self.sessions[session_key] = ratchet
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
self._user_cache[sender_id] = {
"user_id": sender_id,
"identity_key": ik_remote,
"identity_key_bytes": ik_remote_bytes,
}
return ratchet
# ------------------------------------------------------------------
# Conversations
# ------------------------------------------------------------------
async def create_conversation(self, member_emails: list[str], name: str | None = None) -> tuple[str | None, str]:
kwargs = {"members": member_emails}
if name:
kwargs["name"] = name
resp = await self.send_and_recv("create_conversation", **kwargs)
if resp["status"] == "ok":
return resp["data"]["conversation_id"], "OK"
return None, resp["data"]["message"]
async def remove_member(self, conv_id: str, user_id: str) -> tuple[bool, str]:
resp = await self.send_and_recv("remove_member", conversation_id=conv_id, user_id=user_id)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def leave_group(self, conv_id: str) -> tuple[bool, str]:
"""Leave a group conversation."""
resp = await self.send_and_recv("leave_group", conversation_id=conv_id)
if resp["status"] == "ok":
# Clean up local sender key state for this group
self.sender_key_states.pop(conv_id, None)
# Remove received sender keys for this conversation
to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")]
for k in to_remove:
self.recv_sender_keys.pop(k, None)
return True, "OK"
return False, resp["data"]["message"]
async def rename_conversation(self, conv_id: str, name: str) -> tuple[bool, str]:
"""Rename a group conversation (creator only)."""
resp = await self.send_and_recv("rename_conversation", conversation_id=conv_id, name=name)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def delete_conversation(self, conv_id: str) -> tuple[bool, str]:
"""Delete a conversation (leave + server cleans up if empty)."""
resp = await self.send_and_recv("delete_conversation", conversation_id=conv_id)
if resp["status"] == "ok":
# Clean up local sender key state
self.sender_key_states.pop(conv_id, None)
to_remove = [k for k in self.recv_sender_keys if k.startswith(f"{conv_id}:")]
for k in to_remove:
self.recv_sender_keys.pop(k, None)
return True, "OK"
return False, resp["data"]["message"]
async def add_member(self, conv_id: str, email: str) -> tuple[bool, str]:
resp = await self.send_and_recv("add_member", conversation_id=conv_id, email=email)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def accept_invitation(self, conv_id: str) -> tuple[bool, str]:
"""Accept a group invitation."""
resp = await self.send_and_recv("accept_invitation", conversation_id=conv_id)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def decline_invitation(self, conv_id: str) -> tuple[bool, str]:
"""Decline a group invitation."""
resp = await self.send_and_recv("decline_invitation", conversation_id=conv_id)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def list_invitations(self) -> list[dict]:
"""List pending group invitations."""
resp = await self.send_and_recv("list_invitations")
if resp["status"] == "ok":
return resp["data"]["invitations"]
return []
async def list_conversations(self) -> list[dict]:
resp = await self.send_and_recv("list_conversations")
if resp["status"] == "ok":
return resp["data"]["conversations"]
return []
async def find_or_create_conversation(self, email: str) -> tuple[str | None, str]:
resp = await self.send_and_recv("find_conversation", email=email)
if resp["status"] != "ok":
return None, resp["data"]["message"]
conv_id = resp["data"]["conversation_id"]
if conv_id:
return conv_id, "OK"
return await self.create_conversation([email])
# ------------------------------------------------------------------
# Send message
# ------------------------------------------------------------------
def _is_group(self, members: list[dict]) -> bool:
return len(members) > 2
async def send_message(self, conv_id: str, text: str, members: list[dict],
reply_to: str | None = None) -> tuple[bool, str]:
"""Encrypt and send a message. DM: per-recipient Double Ratchet. Group: Sender Keys."""
my_user_id = self.session["user_id"]
# Build plaintext payload
payload = {
"sender": self.username,
"text": text,
"reply_to": reply_to,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8")
if self._is_group(members):
return await self._send_group_message(conv_id, plaintext, members)
else:
return await self._send_dm(conv_id, plaintext, members)
async def _send_dm(self, conv_id: str, plaintext: bytes, members: list[dict]) -> tuple[bool, str]:
"""Encrypt DM with per-device Double Ratchet."""
my_user_id = self.session["user_id"]
recipients = []
first_ratchet_header = None
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
# Get all device bundles for this user
try:
device_bundles = await self._get_device_bundles(uid)
self._logger.debug("Got %d device bundles for %s", len(device_bundles), uid)
except Exception as e:
self._logger.warning("Failed to get device bundles for %s: %s", uid, e)
device_bundles = []
if not device_bundles:
# Fallback: try single session (legacy peer)
ratchet = await self._get_or_create_session(uid)
result = ratchet.encrypt(plaintext)
x3dh_hdr = getattr(ratchet, "_x3dh_header", None)
if x3dh_hdr:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if x3dh_hdr:
entry["x3dh_header"] = x3dh_hdr
recipients.append(entry)
if first_ratchet_header is None:
first_ratchet_header = result["header"]
_save_session(self.email, uid, ratchet, self._local_key)
continue
for bundle in device_bundles:
dev_id = bundle.get("device_id")
ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
bundle=bundle)
result = ratchet.encrypt(plaintext)
x3dh_hdr = getattr(ratchet, "_x3dh_header", None)
if x3dh_hdr:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if dev_id:
entry["device_id"] = dev_id
if x3dh_hdr:
entry["x3dh_header"] = x3dh_hdr
recipients.append(entry)
if first_ratchet_header is None:
first_ratchet_header = result["header"]
_save_session(self.email, uid, ratchet, self._local_key,
peer_device_id=dev_id)
# Encrypt self-copy with static key derived from identity (not ratchet)
# Uses SELF_DEVICE_ID so all own devices can read it
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
if not recipients:
return False, "No recipients."
kwargs = {
"conversation_id": conv_id,
"ratchet_header": first_ratchet_header,
"recipients": recipients,
}
resp = await self.send_and_recv("send_message", **kwargs)
if resp["status"] == "ok":
return True, "Message sent."
return False, resp["data"]["message"]
async def _send_group_message(self, conv_id: str, plaintext: bytes,
members: list[dict]) -> tuple[bool, str]:
"""Encrypt group message with Sender Keys."""
my_user_id = self.session["user_id"]
# Get or create sender key for this group
sk = self.sender_key_states.get(conv_id)
if not sk:
sk = _load_sender_key_state(self.email, conv_id, self._local_key)
if not sk:
sk = SenderKeyState()
self.sender_key_states[conv_id] = sk
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
# Distribute sender key to all members via pairwise ratchet
await self._distribute_sender_key(conv_id, members, sk)
self.sender_key_states[conv_id] = sk
# Encrypt with sender key
result = sk.encrypt(plaintext)
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
# Build per-recipient entries (same ciphertext for all except self)
recipients = []
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
recipients.append({
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
})
# Self-encrypted copy (so other devices + history fetch can decrypt)
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
ratchet_header = {"dh_pub": "00" * 32, "n": 0, "pn": 0} # Dummy for groups
kwargs = {
"conversation_id": conv_id,
"ratchet_header": ratchet_header,
"recipients": recipients,
"sender_chain_id": encode_binary(bytes.fromhex(result["chain_id"])),
"sender_chain_n": result["n"],
}
resp = await self.send_and_recv("send_message", **kwargs)
if resp["status"] == "ok":
return True, "Message sent."
return False, resp["data"]["message"]
async def _distribute_sender_key(self, conv_id: str, members: list[dict],
sk: SenderKeyState):
"""Send own sender key to all group members via pairwise Double Ratchet (per-device)."""
my_user_id = self.session["user_id"]
exported_key = sk.export_key()
# Build a special "sender_key_distribution" payload
payload = {
"sender": self.username,
"text": "",
"reply_to": None,
"timestamp": datetime.now(timezone.utc).isoformat(),
"_sender_key": {
"conv_id": conv_id,
"key": encode_binary(exported_key),
"sender_device_id": self.device_id,
},
}
plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8")
# Send as DM to each member's devices (per-device encryption)
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
try:
# Get all device bundles for this user
try:
device_bundles = await self._get_device_bundles(uid)
except Exception:
device_bundles = []
if not device_bundles:
# Fallback: legacy single-device
ratchet = await self._get_or_create_session(uid)
result = ratchet.encrypt(plaintext)
x3dh_header = getattr(ratchet, "_x3dh_header", None)
if x3dh_header:
delattr(ratchet, "_x3dh_header")
recipient_entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if x3dh_header:
recipient_entry["x3dh_header"] = x3dh_header
kwargs = {
"conversation_id": conv_id,
"ratchet_header": result["header"],
"recipients": [recipient_entry],
}
await self.send_and_recv("send_message", **kwargs)
_save_session(self.email, uid, ratchet, self._local_key)
else:
# Per-device encryption
recipients = []
first_rh = None
for bundle in device_bundles:
dev_id = bundle.get("device_id")
ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
bundle=bundle)
result = ratchet.encrypt(plaintext)
x3dh_header = getattr(ratchet, "_x3dh_header", None)
if x3dh_header:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if dev_id:
entry["device_id"] = dev_id
if x3dh_header:
entry["x3dh_header"] = x3dh_header
recipients.append(entry)
if first_rh is None:
first_rh = result["header"]
_save_session(self.email, uid, ratchet, self._local_key,
peer_device_id=dev_id)
kwargs = {
"conversation_id": conv_id,
"ratchet_header": first_rh,
"recipients": recipients,
}
await self.send_and_recv("send_message", **kwargs)
except Exception as e:
self._logger.warning("Failed to distribute sender key to %s: %s", uid, e)
# ------------------------------------------------------------------
# Decrypt messages
# ------------------------------------------------------------------
def _decrypt_message(self, msg_data: dict) -> dict:
"""Decrypt a single message (DM or group)."""
# Check for self-encrypted marker FIRST — after re-encryption,
# group messages will have {"self": true} ratchet_header but still
# have sender_chain_id at message level.
rh = msg_data.get("ratchet_header", {})
if isinstance(rh, dict) and rh.get("self"):
return self._decrypt_dm(msg_data)
if msg_data.get("sender_chain_id"):
return self._decrypt_group(msg_data)
else:
return self._decrypt_dm(msg_data)
def _decrypt_dm(self, msg_data: dict) -> dict:
"""Decrypt DM using Double Ratchet with sender, or static key for self-copies."""
sender_id = msg_data.get("sender_id", "")
sender_device_id = msg_data.get("sender_device_id")
ratchet_header = msg_data.get("ratchet_header", {})
ct_b64 = msg_data.get("encrypted_content", "")
nonce_b64 = msg_data.get("nonce", "")
if not ct_b64 or not nonce_b64:
raise ValueError("Missing ciphertext or nonce")
ciphertext = decode_binary(ct_b64)
nonce = decode_binary(nonce_b64)
# Self-encrypted message (own sent message copy)
if isinstance(ratchet_header, dict) and ratchet_header.get("self"):
self_key = derive_self_encryption_key(self.identity_private)
ct = ciphertext[:-16]
tag = ciphertext[-16:]
plaintext = aes_decrypt(self_key, nonce, ct, tag)
else:
x3dh_header = msg_data.get("x3dh_header")
# Session key: "sender_id:sender_device_id" or just "sender_id" for legacy
session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id
# Try to load existing session
ratchet = self.sessions.get(session_key)
if not ratchet:
ratchet = _load_session(self.email, sender_id, self._local_key,
peer_device_id=sender_device_id)
if ratchet:
self.sessions[session_key] = ratchet
if ratchet and not x3dh_header:
# Normal case: existing session, no X3DH header
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
elif x3dh_header:
if ratchet:
# Existing session + X3DH header: sender may have reset.
backup = ratchet.export_state()
try:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
except Exception:
restored = DoubleRatchet.import_state(backup)
self.sessions[session_key] = restored
_save_session(self.email, sender_id, restored, self._local_key,
peer_device_id=sender_device_id)
ratchet = self._process_x3dh_header(sender_id, x3dh_header,
sender_device_id=sender_device_id)
try:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
except Exception:
if self._prev_spk_private:
ratchet = self._process_x3dh_header(
sender_id, x3dh_header,
sender_device_id=sender_device_id,
spk_override=self._prev_spk_private)
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
else:
raise
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
else:
ratchet = self._process_x3dh_header(sender_id, x3dh_header,
sender_device_id=sender_device_id)
try:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
except Exception:
if self._prev_spk_private:
ratchet = self._process_x3dh_header(
sender_id, x3dh_header,
sender_device_id=sender_device_id,
spk_override=self._prev_spk_private)
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
else:
raise
_save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id)
else:
raise ValueError(f"No session for sender {sender_id}")
payload = json.loads(plaintext)
# Handle sender key distribution messages
if "_sender_key" in payload:
sk_data = payload["_sender_key"]
sk_conv_id = sk_data["conv_id"]
sk_key = decode_binary(sk_data["key"])
sk_sender_device_id = sk_data.get("sender_device_id")
recv_sk = SenderKeyState.from_key(sk_key)
if sk_sender_device_id:
cache_key = f"{sk_conv_id}:{sender_id}:{sk_sender_device_id}"
else:
cache_key = f"{sk_conv_id}:{sender_id}"
self.recv_sender_keys[cache_key] = recv_sk
_save_recv_sender_key(self.email, sk_conv_id, sender_id, recv_sk, self._local_key,
sender_device_id=sk_sender_device_id)
# Return empty — this is a control message, not user-visible
return None
return payload
def _decrypt_group(self, msg_data: dict) -> dict:
"""Decrypt group message using sender's Sender Key."""
sender_id = msg_data.get("sender_id", "")
sender_device_id = msg_data.get("sender_device_id")
conv_id = msg_data.get("conversation_id", "")
chain_id_b64 = msg_data.get("sender_chain_id", "")
chain_n = msg_data.get("sender_chain_n", 0)
ct_b64 = msg_data.get("encrypted_content", "")
nonce_b64 = msg_data.get("nonce", "")
if not ct_b64 or not nonce_b64 or not chain_id_b64:
raise ValueError("Missing group message fields")
ciphertext = decode_binary(ct_b64)
nonce = decode_binary(nonce_b64)
chain_id = decode_binary(chain_id_b64)
my_user_id = self.session["user_id"]
# If we sent this message, use our own sender key
if sender_id == my_user_id:
sk = self.sender_key_states.get(conv_id)
if not sk:
sk = _load_sender_key_state(self.email, conv_id, self._local_key)
if sk:
self.sender_key_states[conv_id] = sk
if not sk:
raise ValueError("Own sender key not found")
# For our own messages, we can't decrypt from sender key (it's already advanced)
# Return a placeholder — the server echoed our ciphertext
raise ValueError("Cannot decrypt own group message from sender key")
# Use received sender key — try with sender_device_id first, fall back to without
sk = None
if sender_device_id:
cache_key = f"{conv_id}:{sender_id}:{sender_device_id}"
sk = self.recv_sender_keys.get(cache_key)
if not sk:
sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key,
sender_device_id=sender_device_id)
if sk:
self.recv_sender_keys[cache_key] = sk
if not sk:
# Fallback: try without device_id (legacy or same-device)
cache_key = f"{conv_id}:{sender_id}"
sk = self.recv_sender_keys.get(cache_key)
if not sk:
sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key)
if sk:
self.recv_sender_keys[cache_key] = sk
if not sk:
raise ValueError(f"No sender key for {sender_id} in conversation {conv_id}")
plaintext = sk.decrypt(chain_id.hex(), chain_n, ciphertext, nonce)
_save_recv_sender_key(self.email, conv_id, sender_id, sk, self._local_key,
sender_device_id=sender_device_id)
return json.loads(plaintext)
# ------------------------------------------------------------------
# Get/decrypt messages (batch)
# ------------------------------------------------------------------
async def get_messages(self, conv_id: str, limit: int = 50, offset: int = 0) -> list[dict]:
resp = await self.send_and_recv("get_messages", conversation_id=conv_id, limit=limit, offset=offset)
if resp["status"] != "ok":
return []
cache = _load_message_cache(self.email, conv_id, self._cache_key)
decrypted = []
message_ids = []
raw_messages = resp["data"]["messages"]
raw_messages.reverse() # Server returns DESC, reverse to ASC
for m in raw_messages:
msg_id = m["message_id"]
message_ids.append(msg_id)
if m.get("deleted_at"):
decrypted.append({
"message_id": msg_id,
"sender": "",
"text": "",
"created_at": m["created_at"],
"read_by": [],
"sender_id": m.get("sender_id", ""),
"deleted": True,
})
continue
# Check local cache first (ratchet keys are one-time use)
cached = cache.get(msg_id)
if cached:
cached["read_by"] = m.get("read_by", [])
cached["created_at"] = m["created_at"]
if cached.get("_control"):
continue # Skip control messages
decrypted.append(cached)
continue
try:
msg_data = {
"sender_id": m.get("sender_id", ""),
"sender_device_id": m.get("sender_device_id"),
"conversation_id": conv_id,
"ratchet_header": m.get("ratchet_header", {}),
"encrypted_content": m.get("encrypted_content", ""),
"nonce": m.get("nonce", ""),
"x3dh_header": m.get("x3dh_header"),
"sender_chain_id": m.get("sender_chain_id"),
"sender_chain_n": m.get("sender_chain_n"),
}
payload = self._decrypt_message(msg_data)
if payload is None:
# Control message (sender key distribution) — cache and skip
_save_message_to_cache(self.email, conv_id, msg_id, {"_control": True},
cache_key=self._cache_key)
continue
payload["message_id"] = msg_id
payload["created_at"] = m["created_at"]
payload["read_by"] = m.get("read_by", [])
payload["sender_id"] = m.get("sender_id", "")
decrypted.append(payload)
# Cache the decrypted payload (without read_by which changes)
cache_entry = {k: v for k, v in payload.items() if k != "read_by"}
_save_message_to_cache(self.email, conv_id, msg_id, cache_entry,
cache_key=self._cache_key)
except Exception as e:
decrypted.append({
"message_id": msg_id,
"sender": "???",
"text": f"[Decryption failed: {e}]",
"created_at": m["created_at"],
"read_by": [],
})
if message_ids:
await self.mark_read(conv_id, message_ids)
return decrypted
async def mark_read(self, conv_id: str, message_ids: list[str]):
if not message_ids:
return
await self.send_and_recv("mark_read", conversation_id=conv_id, message_ids=message_ids)
def search_messages(self, conv_id: str, query: str) -> list[dict]:
"""Search cached messages in a conversation. Returns matching messages."""
cache = _load_message_cache(self.email, conv_id, self._cache_key)
query_lower = query.lower()
results = []
for msg_id, payload in cache.items():
if payload.get("deleted") or payload.get("_control") or payload.get("_sender_key"):
continue
text = payload.get("text", "")
if query_lower in text.lower():
entry = dict(payload)
entry["message_id"] = msg_id
results.append(entry)
results.sort(key=lambda m: m.get("created_at", ""))
return results
async def reset_session(self, peer_user_id: str, peer_device_id: str | None = None):
"""Delete local session and notify peer to do the same."""
if peer_device_id:
session_key = f"{peer_user_id}:{peer_device_id}"
else:
session_key = peer_user_id
self.sessions.pop(session_key, None)
_delete_session_file(self.email, peer_user_id, peer_device_id)
await self.send_and_recv("session_reset",
peer_user_id=peer_user_id,
peer_device_id=peer_device_id or "")
def handle_session_reset_notification(self, from_user_id: str, from_device_id: str | None = None):
"""Handle incoming session reset notification — delete the matching session."""
if from_device_id:
session_key = f"{from_user_id}:{from_device_id}"
else:
session_key = from_user_id
self.sessions.pop(session_key, None)
_delete_session_file(self.email, from_user_id, from_device_id)
# ------------------------------------------------------------------
# Decrypt notification
# ------------------------------------------------------------------
def decrypt_notification(self, notif_data: dict) -> dict | None:
"""Decrypt a new_message notification. Returns parsed payload or None.
Supports new multi-device format (device_entries array) and legacy flat format.
"""
try:
conv_id = notif_data.get("conversation_id", "")
msg_id = notif_data.get("message_id", "")
sender_id = notif_data.get("sender_id", "")
sender_device_id = notif_data.get("sender_device_id")
my_user_id = self.session["user_id"] if self.session else ""
# Extract per-device encrypted content from device_entries or flat fields
encrypted_content = ""
nonce = ""
ratchet_header = {}
x3dh_header = None
device_entries = notif_data.get("device_entries")
if device_entries:
# Multi-device format: pick entry matching our device_id or SELF_DEVICE_ID
chosen = None
self_entry = None
for entry in device_entries:
eid = entry.get("device_id", "")
if eid == self.device_id:
chosen = entry
break
if eid == "00000000-0000-0000-0000-000000000000":
self_entry = entry
# If sender is us, prefer self-encrypted entry
if sender_id == my_user_id:
chosen = self_entry or chosen
elif not chosen:
chosen = self_entry
if not chosen:
self._logger.warning("No matching device_entry for device %s", self.device_id)
return None
encrypted_content = chosen.get("encrypted_content", "")
nonce = chosen.get("nonce", "")
ratchet_header = chosen.get("ratchet_header") or notif_data.get("ratchet_header", {})
x3dh_header = chosen.get("x3dh_header") or notif_data.get("x3dh_header")
else:
# Legacy flat format
encrypted_content = notif_data.get("encrypted_content", "")
nonce = notif_data.get("nonce", "")
ratchet_header = notif_data.get("ratchet_header", {})
x3dh_header = notif_data.get("x3dh_header")
msg_data = {
"sender_id": sender_id,
"sender_device_id": sender_device_id,
"conversation_id": conv_id,
"ratchet_header": ratchet_header,
"encrypted_content": encrypted_content,
"nonce": nonce,
"x3dh_header": x3dh_header,
"sender_chain_id": notif_data.get("sender_chain_id"),
"sender_chain_n": notif_data.get("sender_chain_n"),
}
payload = self._decrypt_message(msg_data)
if payload is None:
# Cache control message so get_messages skips it
if msg_id and conv_id:
_save_message_to_cache(self.email, conv_id, msg_id, {"_control": True},
cache_key=self._cache_key)
return None
payload["conversation_id"] = conv_id
payload["message_id"] = msg_id
payload["sender_id"] = sender_id
payload["created_at"] = payload.get("timestamp", "")
payload["read_by"] = []
# Cache so get_messages doesn't re-decrypt (ratchet keys are one-time)
if msg_id and conv_id:
cache_entry = {k: v for k, v in payload.items() if k != "read_by"}
_save_message_to_cache(self.email, conv_id, msg_id, cache_entry,
cache_key=self._cache_key)
return payload
except Exception as e:
self._logger.warning("Failed to decrypt notification: %s", e)
return None
# ------------------------------------------------------------------
# Delete message
# ------------------------------------------------------------------
async def delete_message(self, message_id: str) -> tuple[bool, str]:
resp = await self.send_and_recv("delete_message", message_id=message_id)
if resp["status"] == "ok":
return True, "Message deleted."
return False, resp["data"]["message"]
# ------------------------------------------------------------------
# Image sharing
# ------------------------------------------------------------------
async def send_image(self, conv_id: str, image_path: str, members: list[dict],
reply_to: str | None = None) -> tuple[bool, str]:
"""Encrypt and upload an image, then send as a message."""
try:
from PIL import Image
import io
except ImportError:
return False, "Pillow is required for image sharing. Install with: pip install Pillow"
path = Path(image_path)
if not path.exists():
return False, "File not found."
try:
img = Image.open(path)
img.load()
except Exception as e:
return False, f"Cannot open image: {e}"
if img.mode not in ("RGB", "L"):
img = img.convert("RGB")
max_dim = 1920
if max(img.size) > max_dim:
img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
image_bytes = buf.getvalue()
thumb = img.copy()
thumb.thumbnail((200, 200), Image.Resampling.LANCZOS)
thumb_buf = io.BytesIO()
thumb.save(thumb_buf, format="JPEG", quality=60)
thumbnail_b64 = encode_binary(thumb_buf.getvalue())
# Encrypt image with AES-256-GCM
img_aes_key, img_iv, img_ct, img_tag = aes_encrypt(image_bytes)
encrypted_image = img_ct + img_tag
file_id = str(uuid.uuid4())
file_size = len(encrypted_image)
# Chunked upload
resp = await self.send_and_recv(
"upload_image_start",
conversation_id=conv_id,
file_id=file_id,
file_size=file_size,
)
if resp["status"] != "ok":
return False, resp["data"]["message"]
upload_offset = 0
while upload_offset < file_size:
chunk = encrypted_image[upload_offset:upload_offset + IMAGE_CHUNK_SIZE]
resp = await self.send_and_recv(
"upload_image_chunk",
file_id=file_id,
data=encode_binary(chunk),
)
if resp["status"] != "ok":
return False, resp["data"]["message"]
upload_offset += len(chunk)
resp = await self.send_and_recv("upload_image_end", file_id=file_id)
if resp["status"] != "ok":
return False, resp["data"]["message"]
# Build message payload with image info
image_info = {
"file_id": file_id,
"aes_key": encode_binary(img_aes_key),
"iv": encode_binary(img_iv),
"thumbnail": thumbnail_b64,
"filename": path.name,
"size": len(image_bytes),
}
payload = {
"sender": self.username,
"text": "",
"reply_to": reply_to,
"timestamp": datetime.now(timezone.utc).isoformat(),
"image": image_info,
}
plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8")
my_user_id = self.session["user_id"]
if self._is_group(members):
# Group image: use sender key
sk = self.sender_key_states.get(conv_id)
if not sk:
sk = _load_sender_key_state(self.email, conv_id, self._local_key)
if not sk:
sk = SenderKeyState()
self.sender_key_states[conv_id] = sk
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
await self._distribute_sender_key(conv_id, members, sk)
result = sk.encrypt(plaintext)
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
recipients = []
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
recipients.append({
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
})
# Self-encrypted copy for sender
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
resp = await self.send_and_recv(
"send_message",
conversation_id=conv_id,
ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0},
recipients=recipients,
sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])),
sender_chain_n=result["n"],
image_file_id=file_id,
)
else:
# DM image: per-device ratchet (same pattern as _send_dm)
recipients = []
first_rh = None
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
try:
device_bundles = await self._get_device_bundles(uid)
except Exception:
device_bundles = []
if not device_bundles:
# Fallback: legacy single-device
ratchet = await self._get_or_create_session(uid)
result = ratchet.encrypt(plaintext)
x3dh_h = getattr(ratchet, "_x3dh_header", None)
if x3dh_h:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if x3dh_h:
entry["x3dh_header"] = x3dh_h
recipients.append(entry)
if first_rh is None:
first_rh = result["header"]
_save_session(self.email, uid, ratchet, self._local_key)
else:
for bundle in device_bundles:
dev_id = bundle.get("device_id")
ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
bundle=bundle)
result = ratchet.encrypt(plaintext)
x3dh_h = getattr(ratchet, "_x3dh_header", None)
if x3dh_h:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if dev_id:
entry["device_id"] = dev_id
if x3dh_h:
entry["x3dh_header"] = x3dh_h
recipients.append(entry)
if first_rh is None:
first_rh = result["header"]
_save_session(self.email, uid, ratchet, self._local_key,
peer_device_id=dev_id)
# Encrypt self-copy with static key
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
resp = await self.send_and_recv(
"send_message",
conversation_id=conv_id,
ratchet_header=first_rh,
recipients=recipients,
image_file_id=file_id,
)
if resp["status"] == "ok":
return True, "Image sent."
return False, resp["data"]["message"]
async def send_file(self, conv_id: str, file_path: str, members: list[dict],
reply_to: str | None = None) -> tuple[bool, str]:
"""Encrypt and upload a file, then send as a message."""
import mimetypes
path = Path(file_path)
if not path.exists():
return False, "File not found."
try:
file_bytes = path.read_bytes()
except Exception as e:
return False, f"Cannot read file: {e}"
mime_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
# Encrypt file with AES-256-GCM
file_aes_key, file_iv, file_ct, file_tag = aes_encrypt(file_bytes)
encrypted_file = file_ct + file_tag
file_id = str(uuid.uuid4())
file_size = len(encrypted_file)
# Chunked upload (reuse image upload infrastructure with file_type="file")
resp = await self.send_and_recv(
"upload_image_start",
conversation_id=conv_id,
file_id=file_id,
file_size=file_size,
file_type="file",
)
if resp["status"] != "ok":
return False, resp["data"]["message"]
upload_offset = 0
while upload_offset < file_size:
chunk = encrypted_file[upload_offset:upload_offset + IMAGE_CHUNK_SIZE]
resp = await self.send_and_recv(
"upload_image_chunk",
file_id=file_id,
data=encode_binary(chunk),
)
if resp["status"] != "ok":
return False, resp["data"]["message"]
upload_offset += len(chunk)
resp = await self.send_and_recv("upload_image_end", file_id=file_id)
if resp["status"] != "ok":
return False, resp["data"]["message"]
# Build message payload with file info
file_info = {
"file_id": file_id,
"aes_key": encode_binary(file_aes_key),
"iv": encode_binary(file_iv),
"filename": path.name,
"size": len(file_bytes),
"mime_type": mime_type,
}
payload = {
"sender": self.username,
"text": "",
"reply_to": reply_to,
"timestamp": datetime.now(timezone.utc).isoformat(),
"file": file_info,
}
plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8")
my_user_id = self.session["user_id"]
if self._is_group(members):
sk = self.sender_key_states.get(conv_id)
if not sk:
sk = _load_sender_key_state(self.email, conv_id, self._local_key)
if not sk:
sk = SenderKeyState()
self.sender_key_states[conv_id] = sk
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
await self._distribute_sender_key(conv_id, members, sk)
result = sk.encrypt(plaintext)
_save_sender_key_state(self.email, conv_id, sk, self._local_key)
recipients = []
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
recipients.append({
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
})
# Self-encrypted copy for sender
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
resp = await self.send_and_recv(
"send_message",
conversation_id=conv_id,
ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0},
recipients=recipients,
sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])),
sender_chain_n=result["n"],
image_file_id=file_id,
)
else:
# DM file: per-device ratchet (same pattern as _send_dm)
recipients = []
first_rh = None
for member in members:
uid = member.get("user_id")
if not uid or uid == my_user_id:
continue
try:
device_bundles = await self._get_device_bundles(uid)
except Exception:
device_bundles = []
if not device_bundles:
# Fallback: legacy single-device
ratchet = await self._get_or_create_session(uid)
result = ratchet.encrypt(plaintext)
x3dh_h = getattr(ratchet, "_x3dh_header", None)
if x3dh_h:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if x3dh_h:
entry["x3dh_header"] = x3dh_h
recipients.append(entry)
if first_rh is None:
first_rh = result["header"]
_save_session(self.email, uid, ratchet, self._local_key)
else:
for bundle in device_bundles:
dev_id = bundle.get("device_id")
ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
bundle=bundle)
result = ratchet.encrypt(plaintext)
x3dh_h = getattr(ratchet, "_x3dh_header", None)
if x3dh_h:
delattr(ratchet, "_x3dh_header")
entry = {
"user_id": uid,
"encrypted_content": encode_binary(result["ciphertext"]),
"nonce": encode_binary(result["nonce"]),
"ratchet_header": result["header"],
}
if dev_id:
entry["device_id"] = dev_id
if x3dh_h:
entry["x3dh_header"] = x3dh_h
recipients.append(entry)
if first_rh is None:
first_rh = result["header"]
_save_session(self.email, uid, ratchet, self._local_key,
peer_device_id=dev_id)
# Encrypt self-copy with static key
self_key = derive_self_encryption_key(self.identity_private)
_, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
recipients.append({
"user_id": my_user_id,
"encrypted_content": encode_binary(self_ct + self_tag),
"nonce": encode_binary(self_nonce),
"ratchet_header": {"self": True},
})
resp = await self.send_and_recv(
"send_message",
conversation_id=conv_id,
ratchet_header=first_rh,
recipients=recipients,
image_file_id=file_id,
)
if resp["status"] == "ok":
return True, "File sent."
return False, resp["data"]["message"]
async def download_file(self, file_id: str, file_info: dict) -> bytes | None:
"""Download and decrypt a file. Returns decrypted file bytes or None."""
chunks = []
offset = 0
while True:
resp = await self.send_and_recv(
"download_image",
file_id=file_id,
offset=offset,
)
if resp["status"] != "ok":
return None
data = resp["data"]
chunk = decode_binary(data["data"])
chunks.append(chunk)
offset += len(chunk)
if data.get("done"):
break
encrypted_data = b"".join(chunks)
if len(encrypted_data) < 16:
return None
ciphertext = encrypted_data[:-16]
tag = encrypted_data[-16:]
try:
file_aes_key = decode_binary(file_info["aes_key"])
iv = decode_binary(file_info["iv"])
return aes_decrypt(file_aes_key, iv, ciphertext, tag)
except Exception:
return None
async def download_image(self, file_id: str, image_info: dict) -> bytes | None:
"""Download and decrypt an image. Returns decrypted image bytes or None."""
chunks = []
offset = 0
while True:
resp = await self.send_and_recv(
"download_image",
file_id=file_id,
offset=offset,
)
if resp["status"] != "ok":
return None
data = resp["data"]
chunk = decode_binary(data["data"])
chunks.append(chunk)
offset += len(chunk)
if data.get("done"):
break
encrypted_data = b"".join(chunks)
if len(encrypted_data) < 16:
return None
ciphertext = encrypted_data[:-16]
tag = encrypted_data[-16:]
try:
img_aes_key = decode_binary(image_info["aes_key"])
iv = decode_binary(image_info["iv"])
return aes_decrypt(img_aes_key, iv, ciphertext, tag)
except Exception:
return None
# ------------------------------------------------------------------
# Re-encrypt history (for device pairing)
# ------------------------------------------------------------------
async def reencrypt_history(self):
"""Re-encrypt all cached messages with self-encryption key.
After device pairing, the new device shares the same identity key
but cannot decrypt old messages (Double Ratchet keys are one-time use).
This re-encrypts all cached messages so they can be read using the
self-encryption key derived from the shared identity key.
"""
if not self.identity_private or not self.session:
return
self_key = derive_self_encryption_key(self.identity_private)
# Phase 1: Fetch & decrypt all messages to populate cache
# (messages the old device never opened won't be in cache yet)
try:
convs = await self.list_conversations()
total_convs = len(convs)
for ci, conv in enumerate(convs):
cid = conv.get("id") or conv.get("conversation_id")
if not cid:
continue
if self._reencrypt_progress_cb:
self._reencrypt_progress_cb(
f"Fetching messages: {ci + 1}/{total_convs} conversations..."
)
offset = 0
while True:
msgs = await self.get_messages(cid, limit=200, offset=offset)
if not msgs or len(msgs) < 200:
break
offset += len(msgs)
except Exception as e:
self._logger.warning("Failed to fetch messages for re-encryption: %s", e)
# Phase 2: Read cache and re-encrypt
cache_dir = get_key_dir(self.email) / "message_cache"
if not cache_dir.exists():
self._logger.info("No message cache to re-encrypt.")
return
all_updates = []
conv_ids = set()
for f in cache_dir.iterdir():
if f.suffix in (".json", ".bin"):
conv_ids.add(f.stem)
total_files = len(conv_ids)
for i, conv_id in enumerate(sorted(conv_ids)):
cache = _load_message_cache(self.email, conv_id, self._cache_key)
if not cache:
continue
for msg_id, entry in cache.items():
# Skip control messages (sender key distribution)
if entry.get("_control"):
continue
# Skip entries with no useful content
text = entry.get("text", "")
if not text and not entry.get("image") and not entry.get("file"):
continue
# Rebuild plaintext from cached payload
payload = {k: v for k, v in entry.items()
if k not in ("message_id", "created_at", "read_by", "sender_id", "deleted")}
plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8")
# Re-encrypt with self-encryption key
_, nonce, ct, tag = aes_encrypt(plaintext, key=self_key)
all_updates.append({
"message_id": msg_id,
"encrypted_content": encode_binary(ct + tag),
"nonce": encode_binary(nonce),
})
if self._reencrypt_progress_cb:
self._reencrypt_progress_cb(f"Re-encrypting history: {i + 1}/{total_files} conversations...")
if not all_updates:
self._logger.info("No messages to re-encrypt.")
return
# Send in batches of 500
batch_size = 500
total = len(all_updates)
for start in range(0, total, batch_size):
batch = all_updates[start:start + batch_size]
resp = await self.send_and_recv("reencrypt_messages", updates=batch)
if resp["status"] != "ok":
self._logger.warning("Re-encrypt batch failed: %s", resp.get("data", {}).get("message", ""))
else:
self._logger.info("Re-encrypted %d/%d messages.", min(start + batch_size, total), total)
if self._reencrypt_progress_cb:
self._reencrypt_progress_cb(f"Re-encryption complete: {total} messages uploaded.")
# ------------------------------------------------------------------
# User Profiles
# ------------------------------------------------------------------
async def get_profile(self, user_id: str | None = None) -> dict | None:
"""Get user profile. If user_id is None, returns own profile."""
kwargs = {}
if user_id:
kwargs["user_id"] = user_id
resp = await self.send_and_recv("get_profile", **kwargs)
if resp["status"] == "ok":
return resp["data"]
return None
async def update_profile(self, **fields) -> tuple[bool, str]:
"""Update own profile (phone, location, *_visible)."""
resp = await self.send_and_recv("update_profile", **fields)
if resp["status"] == "ok":
return True, "OK"
return False, resp["data"]["message"]
async def update_avatar(self, image_data: bytes) -> tuple[bool, str]:
"""Upload avatar image."""
resp = await self.send_and_recv("update_avatar", data=encode_binary(image_data))
if resp["status"] == "ok":
return True, resp["data"].get("avatar_file", "")
return False, resp["data"]["message"]
async def get_avatar(self, user_id: str) -> bytes | None:
"""Download avatar for a user."""
resp = await self.send_and_recv("get_avatar", user_id=user_id)
if resp["status"] == "ok":
return decode_binary(resp["data"]["data"])
return None
async def update_group_avatar(self, conv_id: str, image_data: bytes) -> tuple[bool, str]:
"""Upload avatar for a group conversation."""
resp = await self.send_and_recv("update_group_avatar",
conversation_id=conv_id, data=encode_binary(image_data))
if resp["status"] == "ok":
return True, resp["data"].get("avatar_file", "")
return False, resp["data"]["message"]
async def get_group_avatar(self, conv_id: str) -> bytes | None:
"""Download avatar for a group conversation."""
resp = await self.send_and_recv("get_group_avatar", conversation_id=conv_id)
if resp["status"] == "ok":
return decode_binary(resp["data"]["data"])
return None
# ------------------------------------------------------------------
# Cleanup
# ------------------------------------------------------------------
async def close(self):
self.connected = False
if self._listener_task:
self._listener_task.cancel()
if self.raw_writer:
self.raw_writer.close()
async def reconnect(self):
"""Close existing connection and re-establish: connect + re-login using in-memory keys."""
try:
await self.close()
except Exception:
pass
# Reset reader/writer but keep keys and sessions
self.reader = None
self.writer = None
self.raw_writer = None
self._listener_task = None
self._pending.clear()
self.login_rejected = False
# Drain queues
while not self._response_queue.empty():
try:
self._response_queue.get_nowait()
except Exception:
break
while not self._notification_queue.empty():
try:
self._notification_queue.get_nowait()
except Exception:
break
await self.connect()
self._listener_task = asyncio.create_task(self._background_listener())
if self.email and self.private_key:
# RSA challenge-response login (keys already in memory)
start = await self.send_and_recv("login_start", email=self.email)
if start["status"] == "ok":
challenge = decode_binary(start["data"]["challenge"])
signature = rsa_sign(self.private_key, challenge)
login_kwargs = {
"email": self.email,
"signature": encode_binary(signature),
"client_version": VERSION,
}
if self.device_id:
login_kwargs["device_id"] = self.device_id
finish = await self.send_and_recv("login_finish", **login_kwargs)
if finish["status"] == "ok":
self.session = finish["data"]
asyncio.create_task(self._ensure_prekeys())
else:
# Login rejected — keys were likely rotated on another device
self.session = None
self.connected = False
self.login_rejected = True