2610 lines
107 KiB
Python
2610 lines
107 KiB
Python
"""Shared network layer and ChatClient class for CLI and GUI clients.
|
|
|
|
Uses X3DH + Double Ratchet for message encryption, Sender Keys for groups.
|
|
RSA retained for login challenge-response only.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import ssl
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()
|
|
|
|
from crypto_utils import (
|
|
# RSA (login only)
|
|
generate_rsa_keypair,
|
|
serialize_private_key,
|
|
serialize_public_key,
|
|
load_private_key,
|
|
load_public_key,
|
|
rsa_sign,
|
|
# Ed25519
|
|
generate_identity_keypair,
|
|
serialize_ed25519_private,
|
|
serialize_ed25519_private_raw,
|
|
serialize_ed25519_public,
|
|
load_ed25519_private,
|
|
load_ed25519_public,
|
|
ed25519_sign,
|
|
# X25519
|
|
generate_x25519_keypair,
|
|
serialize_x25519_private,
|
|
serialize_x25519_public,
|
|
load_x25519_private,
|
|
load_x25519_public,
|
|
# X3DH
|
|
generate_signed_prekey,
|
|
generate_one_time_prekeys,
|
|
x3dh_initiate,
|
|
x3dh_respond,
|
|
# Double Ratchet
|
|
DoubleRatchet,
|
|
# Sender Keys
|
|
SenderKeyState,
|
|
# AES
|
|
aes_encrypt,
|
|
aes_decrypt,
|
|
# Self-encryption
|
|
derive_self_encryption_key,
|
|
# Local storage encryption
|
|
derive_local_storage_key,
|
|
)
|
|
from protocol import (
|
|
VERSION,
|
|
ProtocolReader,
|
|
ProtocolWriter,
|
|
encode_binary,
|
|
decode_binary,
|
|
build_request,
|
|
MAX_MESSAGE_BYTES,
|
|
IMAGE_CHUNK_SIZE,
|
|
)
|
|
|
|
|
|
# Root directory holding one sub-directory of key material / local state per account.
KEY_DIR: Path = Path.home().joinpath(".encrypted_chat")

# Replenish one-time prekeys once the server reports fewer than this many left.
OPK_REPLENISH_THRESHOLD: int = 20

# Number of one-time prekeys generated per replenishment batch.
OPK_BATCH_SIZE: int = 50

# Rotate the signed prekey once the server reports it is at least this many days old.
SPK_ROTATION_DAYS: int = 7
|
|
|
|
|
|
def _encrypt_local(data: bytes, key: bytes) -> bytes:
    """Seal *data* under *key* (AES-256-GCM) for storage on local disk.

    The blob is laid out as ``nonce (12) || tag (16) || ciphertext``, the
    layout ``_decrypt_local`` expects.
    """
    _key, nonce, ciphertext, tag = aes_encrypt(data, key=key)
    return b"".join((nonce, tag, ciphertext))
|
|
|
|
|
|
def _decrypt_local(raw: bytes, key: bytes) -> bytes:
    """Decrypt a blob produced by ``_encrypt_local``.

    Args:
        raw: ``nonce (12) || tag (16) || ciphertext`` as written by ``_encrypt_local``.
        key: The AES key the blob was sealed with.

    Returns:
        The decrypted plaintext bytes.

    Raises:
        ValueError: If *raw* is too short to hold the nonce and tag.
        Whatever ``aes_decrypt`` raises on authentication failure.
    """
    header_len = 12 + 16  # nonce + GCM tag
    if len(raw) < header_len:
        # A truncated/corrupted file would otherwise reach aes_decrypt with a
        # short nonce/tag and fail with a far less obvious error.
        raise ValueError(f"local blob too short ({len(raw)} bytes, need >= {header_len})")
    nonce, tag, ct = raw[:12], raw[12:28], raw[28:]
    return aes_decrypt(key, nonce, ct, tag)
|
|
|
|
|
|
def get_key_dir(email: str) -> Path:
    """Return ``KEY_DIR / email``, creating it (and parents) if missing.

    The directory's mode is forced to 0o700 on every call.
    """
    account_dir = KEY_DIR.joinpath(email)
    account_dir.mkdir(exist_ok=True, parents=True)
    os.chmod(account_dir, 0o700)
    return account_dir
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RSA key storage (login only — unchanged interface)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def save_keys(email: str, private_key, public_key, password: bytes | None = None):
    """Write the RSA login keypair for *email* to ``private.pem`` / ``public.pem``.

    *password* is passed through to ``serialize_private_key``; the private key
    file is restricted to mode 0o600 afterwards.
    """
    key_dir = get_key_dir(email)
    private_path = key_dir / "private.pem"
    private_path.write_bytes(serialize_private_key(private_key, password=password))
    key_dir.joinpath("public.pem").write_bytes(serialize_public_key(public_key))
    os.chmod(private_path, 0o600)
|
|
|
|
|
|
def load_keys(email: str, password: bytes | None = None):
    """Load the RSA login keypair for *email* from the local key directory.

    Returns:
        ``(private_key, public_key, None)`` on success, or
        ``(None, None, message)`` when no key is stored or it cannot be loaded.

    If loading with *password* fails but the key loads without one (a legacy
    unencrypted key), it is returned and, when *password* is truthy, re-saved
    encrypted under it (best effort). If ``public.pem`` is missing, the public
    key is derived from the private key.
    """
    d = get_key_dir(email)
    priv_path = d / "private.pem"
    pub_path = d / "public.pem"
    if not priv_path.exists():
        return None, None, "No local keys found."
    pem = priv_path.read_bytes()

    needs_reencrypt = False
    try:
        private_key = load_private_key(pem, password=password)
    except Exception:
        if password is None:
            # Retrying with password=None would just repeat the failed call.
            return None, None, "Invalid or missing password."
        try:
            private_key = load_private_key(pem, password=None)
        except Exception:
            return None, None, "Invalid or missing password."
        needs_reencrypt = bool(password)

    if pub_path.exists():
        public_key = load_public_key(pub_path.read_bytes())
    else:
        # public.pem lost: the private key alone is enough to recover it.
        public_key = private_key.public_key()

    if needs_reencrypt:
        # The key itself loaded fine, so a failed re-save must not be reported
        # as a bad password; keep going with the loaded key.
        try:
            save_keys(email, private_key, public_key, password=password)
        except Exception:
            logging.getLogger("encrypted_chat.client").warning(
                "Could not re-encrypt legacy RSA key on disk", exc_info=True)
    return private_key, public_key, None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Identity + prekey storage
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _save_identity_keys(email: str, ed_priv, ed_pub, password: bytes | None = None):
    """Write the Ed25519 identity keypair for *email* to disk.

    With a truthy *password* the private key is serialized with
    ``serialize_ed25519_private(..., password=...)``; otherwise its raw form is
    written. ``identity_private.bin`` is restricted to mode 0o600.
    """
    key_dir = get_key_dir(email)
    private_file = key_dir / "identity_private.bin"
    blob = (
        serialize_ed25519_private(ed_priv, password=password)
        if password
        else serialize_ed25519_private_raw(ed_priv)
    )
    private_file.write_bytes(blob)
    key_dir.joinpath("identity_public.bin").write_bytes(serialize_ed25519_public(ed_pub))
    os.chmod(private_file, 0o600)
|
|
|
|
|
|
def _load_identity_keys(email: str, password: bytes | None = None):
    """Load the Ed25519 identity keypair for *email*.

    Returns ``(private_key, public_key)``, or ``(None, None)`` when no private
    key file exists. Errors from ``load_ed25519_private`` (e.g. a wrong
    password — exact exception type TODO confirm) propagate to the caller.
    If ``identity_public.bin`` is missing, the public key is derived from the
    private key instead of crashing.
    """
    d = get_key_dir(email)
    priv_path = d / "identity_private.bin"
    pub_path = d / "identity_public.bin"
    if not priv_path.exists():
        return None, None
    priv = load_ed25519_private(priv_path.read_bytes(), password=password)
    if pub_path.exists():
        pub = load_ed25519_public(pub_path.read_bytes())
    else:
        pub = priv.public_key()
    return priv, pub
|
|
|
|
|
|
def _save_spk(email: str, spk_priv, spk_id: str):
    """Persist the current signed prekey's private half and its id for *email*."""
    key_dir = get_key_dir(email)
    spk_file = key_dir / "spk_private.bin"
    spk_file.write_bytes(serialize_x25519_private(spk_priv))
    key_dir.joinpath("spk_id.txt").write_text(spk_id)
    os.chmod(spk_file, 0o600)
|
|
|
|
|
|
def _load_spk(email: str):
    """Return ``(private_key, spk_id)`` for the current signed prekey.

    ``(None, None)`` if no key file exists; ``spk_id`` is ``""`` when the id
    file is missing.
    """
    key_dir = get_key_dir(email)
    private_path = key_dir / "spk_private.bin"
    if not private_path.exists():
        return None, None
    private_key = load_x25519_private(private_path.read_bytes())
    id_path = key_dir / "spk_id.txt"
    current_id = ""
    if id_path.exists():
        current_id = id_path.read_text().strip()
    return private_key, current_id
|
|
|
|
|
|
def _save_prev_spk(email: str, spk_priv, spk_id: str):
    """Keep the outgoing signed prekey on disk for a grace period after rotation.

    In-flight X3DH handshakes may still reference the old SPK.
    """
    key_dir = get_key_dir(email)
    prev_private = key_dir / "prev_spk_private.bin"
    prev_private.write_bytes(serialize_x25519_private(spk_priv))
    key_dir.joinpath("prev_spk_id.txt").write_text(spk_id)
    os.chmod(prev_private, 0o600)
|
|
|
|
|
|
def _load_prev_spk(email: str):
    """Return ``(private_key, spk_id)`` for the grace-period (previous) SPK.

    ``(None, None)`` if none is stored; ``spk_id`` is ``""`` when the id file
    is missing.
    """
    key_dir = get_key_dir(email)
    private_path = key_dir.joinpath("prev_spk_private.bin")
    if not private_path.exists():
        return None, None
    private_key = load_x25519_private(private_path.read_bytes())
    id_path = key_dir.joinpath("prev_spk_id.txt")
    prev_id = id_path.read_text().strip() if id_path.exists() else ""
    return private_key, prev_id
|
|
|
|
|
|
def _save_opk_private(email: str, opk_id: str, opk_priv):
    """Store one one-time prekey private key as ``opk_private/<opk_id>.bin`` (mode 0o600)."""
    opk_dir = get_key_dir(email).joinpath("opk_private")
    opk_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(opk_dir, 0o700)
    key_file = opk_dir.joinpath(f"{opk_id}.bin")
    key_file.write_bytes(serialize_x25519_private(opk_priv))
    os.chmod(key_file, 0o600)
|
|
|
|
|
|
def _load_opk_private(email: str, opk_id: str):
    """Return the stored one-time prekey private key for *opk_id*, or ``None``.

    The id becomes a file name inside ``opk_private``. Ids containing a path
    separator or NUL are rejected (``None``) so they can never resolve to a
    file elsewhere in the key directory. Such ids could not have been saved by
    ``_save_opk_private`` anyway.
    NOTE(review): presumably opk ids can arrive in remote X3DH headers —
    confirm at the call site.
    """
    filename = f"{opk_id}.bin"
    if any(bad in filename for bad in ("/", "\\", "\x00")):
        return None
    p = get_key_dir(email) / "opk_private" / filename
    if not p.exists():
        return None
    return load_x25519_private(p.read_bytes())
|
|
|
|
|
|
def _delete_opk_private(email: str, opk_id: str):
    """Delete the stored one-time prekey *opk_id*; no error if it is absent.

    Ids containing a path separator or NUL are ignored, so a crafted id (e.g.
    ``"../identity_private"``) cannot delete files outside ``opk_private``.
    Filesystem errors during deletion are ignored (best effort).
    """
    filename = f"{opk_id}.bin"
    if any(bad in filename for bad in ("/", "\\", "\x00")):
        return
    p = get_key_dir(email) / "opk_private" / filename
    try:
        p.unlink(missing_ok=True)
    except OSError:
        # Best effort, as before: failing to delete must not abort the caller.
        pass
|
|
|
|
|
|
def _save_device_id(email: str, device_id: str):
    """Persist *device_id* in ``device_id.txt`` under the account dir (mode 0o600)."""
    id_file = get_key_dir(email).joinpath("device_id.txt")
    id_file.write_text(device_id)
    os.chmod(id_file, 0o600)
|
|
|
|
|
|
def _load_device_id(email: str) -> str | None:
    """Return the stored device id for *email*, or ``None`` if missing or blank."""
    id_file = get_key_dir(email) / "device_id.txt"
    if id_file.exists():
        return id_file.read_text().strip() or None
    return None
|
|
|
|
|
|
def _save_session(email: str, peer_user_id: str, ratchet: DoubleRatchet,
                  local_key: bytes | None = None, peer_device_id: str | None = None):
    """Persist a Double Ratchet session for a peer (optionally one of its devices).

    Stored as ``sessions/<peer_user_id>[_<peer_device_id>].bin``, encrypted
    with *local_key* via ``_encrypt_local`` when a key is given. The data goes
    to a temporary sibling file that is then atomically renamed over the
    target. An interrupted write therefore cannot leave a truncated,
    undecryptable session, which would permanently break the conversation.
    """
    d = get_key_dir(email) / "sessions"
    d.mkdir(parents=True, exist_ok=True)
    os.chmod(d, 0o700)
    if peer_device_id:
        filename = f"{peer_user_id}_{peer_device_id}.bin"
    else:
        filename = f"{peer_user_id}.bin"
    p = d / filename
    data = ratchet.export_state()
    if local_key:
        data = _encrypt_local(data, local_key)
    tmp = p.with_name(p.name + ".tmp")
    tmp.write_bytes(data)
    os.chmod(tmp, 0o600)
    os.replace(tmp, p)  # atomic on POSIX and Windows (same directory)
|
|
|
|
|
|
def _load_session(email: str, peer_user_id: str,
                  local_key: bytes | None = None,
                  peer_device_id: str | None = None) -> DoubleRatchet | None:
    """Load the Double Ratchet session for a peer, or ``None`` if none is stored.

    With *peer_device_id*, the per-device file ``<peer>_<device>.bin`` is
    preferred. If only the legacy ``<peer>.bin`` exists, it is loaded,
    re-saved under the per-device name and removed (removal is best effort).
    """
    sessions_dir = get_key_dir(email) / "sessions"
    legacy_path = sessions_dir / f"{peer_user_id}.bin"

    if not peer_device_id:
        if not legacy_path.exists():
            return None
        return _load_session_file(legacy_path, local_key)

    device_path = sessions_dir / f"{peer_user_id}_{peer_device_id}.bin"
    if device_path.exists():
        return _load_session_file(device_path, local_key)
    if not legacy_path.exists():
        return None

    # Migrate the legacy single-device file to the per-device name.
    migrated = _load_session_file(legacy_path, local_key)
    if migrated:
        _save_session(email, peer_user_id, migrated, local_key,
                      peer_device_id=peer_device_id)
        try:
            legacy_path.unlink()
        except Exception:
            pass
    return migrated
|
|
|
|
|
|
def _load_session_file(p: Path, local_key: bytes | None = None) -> DoubleRatchet | None:
    """Deserialize a DoubleRatchet stored at *p*; ``None`` if the file is missing.

    Without *local_key*, the bytes are imported directly. With a key, they are
    first decrypted via ``_decrypt_local``. If that fails, the bytes are tried
    as a legacy plaintext state, and ``None`` is returned if that fails too.
    """
    if not p.exists():
        return None
    raw = p.read_bytes()
    if not local_key:
        return DoubleRatchet.import_state(raw)
    try:
        decrypted = _decrypt_local(raw, local_key)
    except Exception:
        # Legacy migration path: file predates local-storage encryption.
        try:
            return DoubleRatchet.import_state(raw)
        except Exception:
            return None
    return DoubleRatchet.import_state(decrypted)
|
|
|
|
|
|
def _delete_session_file(email: str, peer_user_id: str, peer_device_id: str | None = None):
    """Remove a stored session file (session reset); any error is ignored."""
    stem = f"{peer_user_id}_{peer_device_id}" if peer_device_id else f"{peer_user_id}"
    target = get_key_dir(email) / "sessions" / f"{stem}.bin"
    try:
        target.unlink(missing_ok=True)
    except Exception:
        pass
|
|
|
|
|
|
def _save_sender_key_state(email: str, conv_id: str, state: SenderKeyState,
                           local_key: bytes | None = None):
    """Persist our own sender-key state for conversation *conv_id*.

    Stored as ``sender_keys/<conv_id>.bin``, encrypted with *local_key* via
    ``_encrypt_local`` when given. The write goes through a temporary file and
    an atomic rename, so an interrupted write cannot truncate the existing state.
    """
    d = get_key_dir(email) / "sender_keys"
    d.mkdir(parents=True, exist_ok=True)
    os.chmod(d, 0o700)
    p = d / f"{conv_id}.bin"
    data = state.export_state()
    if local_key:
        data = _encrypt_local(data, local_key)
    tmp = p.with_name(p.name + ".tmp")
    tmp.write_bytes(data)
    os.chmod(tmp, 0o600)
    os.replace(tmp, p)
|
|
|
|
|
|
def _load_sender_key_state(email: str, conv_id: str,
                           local_key: bytes | None = None) -> SenderKeyState | None:
    """Load our own sender-key state for *conv_id*, or ``None`` if unavailable.

    With *local_key*, the file is decrypted via ``_decrypt_local``. On failure,
    it is tried as a legacy plaintext state, which is re-saved encrypted
    immediately if it parses. ``None`` is returned if neither works.
    """
    state_path = get_key_dir(email) / "sender_keys" / f"{conv_id}.bin"
    if not state_path.exists():
        return None
    raw = state_path.read_bytes()
    if not local_key:
        return SenderKeyState.import_state(raw)
    try:
        decrypted = _decrypt_local(raw, local_key)
    except Exception:
        try:
            legacy = SenderKeyState.import_state(raw)
            _save_sender_key_state(email, conv_id, legacy, local_key)
            return legacy
        except Exception:
            return None
    return SenderKeyState.import_state(decrypted)
|
|
|
|
|
|
def _save_recv_sender_key(email: str, conv_id: str, sender_id: str, state: SenderKeyState,
                          local_key: bytes | None = None,
                          sender_device_id: str | None = None):
    """Persist a peer's sender-key state for conversation *conv_id*.

    Stored as ``sender_keys_recv/<conv_id>_<sender_id>[_<sender_device_id>].bin``,
    encrypted with *local_key* via ``_encrypt_local`` when given. The write goes
    through a temporary file and an atomic rename, so an interrupted write
    cannot truncate the existing state.
    """
    d = get_key_dir(email) / "sender_keys_recv"
    d.mkdir(parents=True, exist_ok=True)
    os.chmod(d, 0o700)
    if sender_device_id:
        filename = f"{conv_id}_{sender_id}_{sender_device_id}.bin"
    else:
        filename = f"{conv_id}_{sender_id}.bin"
    p = d / filename
    data = state.export_state()
    if local_key:
        data = _encrypt_local(data, local_key)
    tmp = p.with_name(p.name + ".tmp")
    tmp.write_bytes(data)
    os.chmod(tmp, 0o600)
    os.replace(tmp, p)
|
|
|
|
|
|
def _load_recv_sender_key(email: str, conv_id: str, sender_id: str,
                          local_key: bytes | None = None,
                          sender_device_id: str | None = None) -> SenderKeyState | None:
    """Load a peer's sender-key state for *conv_id*, or ``None`` if not stored.

    With *sender_device_id*, the per-device file is preferred. If only the
    legacy ``<conv>_<sender>.bin`` exists, it is loaded, re-saved under the
    per-device name and removed (removal is best effort).
    """
    recv_dir = get_key_dir(email) / "sender_keys_recv"
    legacy_path = recv_dir / f"{conv_id}_{sender_id}.bin"

    if not sender_device_id:
        if not legacy_path.exists():
            return None
        return _load_recv_sender_key_file(legacy_path, local_key)

    device_path = recv_dir / f"{conv_id}_{sender_id}_{sender_device_id}.bin"
    if device_path.exists():
        return _load_recv_sender_key_file(device_path, local_key)
    if not legacy_path.exists():
        return None

    # Migrate the legacy file to the per-device name.
    migrated = _load_recv_sender_key_file(legacy_path, local_key)
    if migrated:
        _save_recv_sender_key(email, conv_id, sender_id, migrated, local_key,
                              sender_device_id=sender_device_id)
        try:
            legacy_path.unlink()
        except Exception:
            pass
    return migrated
|
|
|
|
|
|
def _load_recv_sender_key_file(p: Path, local_key: bytes | None = None) -> SenderKeyState | None:
    """Deserialize a received sender-key state stored at *p*; ``None`` if missing.

    With *local_key*, the file is decrypted via ``_decrypt_local``, falling
    back to legacy plaintext. ``None`` is returned if both interpretations fail.
    """
    if not p.exists():
        return None
    raw = p.read_bytes()
    if not local_key:
        return SenderKeyState.import_state(raw)
    try:
        decrypted = _decrypt_local(raw, local_key)
    except Exception:
        try:
            return SenderKeyState.import_state(raw)
        except Exception:
            return None
    return SenderKeyState.import_state(decrypted)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Local decrypted message cache (Double Ratchet keys are one-time use)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _load_message_cache(email: str, conv_id: str, cache_key: bytes | None = None) -> dict:
    """Return the local decrypted-message cache for *conv_id* (message_id -> payload).

    On-disk files in ``message_cache/``:
      * ``<conv_id>.bin`` — JSON encrypted with *cache_key* (``_encrypt_local``
        layout: nonce || tag || ciphertext).
      * ``<conv_id>.json`` — plaintext JSON, written only while no key was available.

    Without *cache_key*, the plaintext entries are returned (or ``{}``). With a
    key, the encrypted cache is decrypted and any plaintext entries are merged
    in. The merged result is re-encrypted and the plaintext file removed
    (best effort). An unreadable encrypted cache yields ``{}`` and leaves both
    files untouched.
    """
    d = get_key_dir(email) / "message_cache"
    p_bin = d / f"{conv_id}.bin"
    p_json = d / f"{conv_id}.json"

    plain = None
    if p_json.exists():
        try:
            plain = json.loads(p_json.read_text("utf-8"))
        except (OSError, ValueError):
            plain = None
        if not isinstance(plain, dict):
            plain = None

    if not cache_key:
        # Return plaintext entries even when a .bin also exists; returning {}
        # here made the next plaintext save overwrite the .json with one entry.
        return plain if plain is not None else {}

    cache: dict = {}
    if p_bin.exists():
        try:
            cache = json.loads(_decrypt_local(p_bin.read_bytes(), cache_key).decode("utf-8"))
        except Exception:
            # Wrong key / corrupt file / bad JSON: don't touch anything on disk.
            return {}

    if plain is not None:
        # Fold entries saved while no key was available into the encrypted
        # cache, then drop the plaintext copy.
        cache.update(plain)
        try:
            _save_message_cache_full(d, conv_id, cache, cache_key)
            p_json.unlink(missing_ok=True)
        except Exception:
            # Best effort: entries are still returned; migration retries next load.
            pass
    return cache
|
|
|
|
|
|
def _save_message_cache_full(d: Path, conv_id: str, cache: dict, cache_key: bytes):
    """Encrypt the whole *cache* dict with *cache_key* and write ``<d>/<conv_id>.bin``.

    Uses the shared ``_encrypt_local`` blob layout (nonce || tag || ciphertext),
    so the on-disk format is unchanged. The write goes through a temporary file
    and an atomic rename, so a crash mid-write cannot corrupt the existing cache.
    """
    d.mkdir(parents=True, exist_ok=True)
    os.chmod(d, 0o700)
    p = d / f"{conv_id}.bin"
    plaintext = json.dumps(cache, ensure_ascii=False).encode("utf-8")
    tmp = p.with_name(p.name + ".tmp")
    tmp.write_bytes(_encrypt_local(plaintext, cache_key))
    os.chmod(tmp, 0o600)
    os.replace(tmp, p)
|
|
|
|
|
|
def _save_message_to_cache(email: str, conv_id: str, message_id: str, payload: dict,
                           cache_key: bytes | None = None):
    """Insert or replace one decrypted message in the local cache for *conv_id*.

    Loads the current cache, sets ``cache[message_id] = payload`` and rewrites
    it. The file is encrypted when *cache_key* is given, otherwise it is
    written as plaintext JSON (mode 0o600).
    """
    cache_dir = get_key_dir(email) / "message_cache"
    entries = _load_message_cache(email, conv_id, cache_key)
    entries[message_id] = payload
    if cache_key:
        _save_message_cache_full(cache_dir, conv_id, entries, cache_key)
        return
    # No identity-derived key available yet: fall back to plaintext.
    cache_dir.mkdir(parents=True, exist_ok=True)
    os.chmod(cache_dir, 0o700)
    json_path = cache_dir / f"{conv_id}.json"
    json_path.write_text(json.dumps(entries, ensure_ascii=False), "utf-8")
    os.chmod(json_path, 0o600)
|
|
|
|
|
|
class ChatClient:
|
|
def __init__(self):
|
|
self.reader: ProtocolReader | None = None
|
|
self.writer: ProtocolWriter | None = None
|
|
self.raw_writer: asyncio.StreamWriter | None = None
|
|
self.session: dict | None = None
|
|
self.private_key = None # RSA private key (login only)
|
|
self.public_key = None # RSA public key (login only)
|
|
self.username: str = ""
|
|
self.email: str = ""
|
|
self._listener_task: asyncio.Task | None = None
|
|
self._response_queue: asyncio.Queue = asyncio.Queue()
|
|
self._notification_queue: asyncio.Queue = asyncio.Queue()
|
|
self._pending: dict[str, asyncio.Future] = {}
|
|
self._pairing_temp_private_key = None
|
|
self._reencrypt_progress_cb = None
|
|
self._logger = logging.getLogger("encrypted_chat.client")
|
|
|
|
# Signal Protocol keys
|
|
self.identity_private = None # Ed25519PrivateKey
|
|
self.identity_public = None # Ed25519PublicKey
|
|
self.spk_private = None # X25519PrivateKey (current signed prekey)
|
|
self.spk_id: str = ""
|
|
self._prev_spk_private = None # Previous SPK for grace period (M4)
|
|
self._prev_spk_id: str = ""
|
|
self.opk_privates: dict[str, object] = {} # id -> X25519PrivateKey
|
|
self.sessions: dict[str, DoubleRatchet] = {} # "user_id:device_id" -> ratchet
|
|
self.sender_key_states: dict[str, SenderKeyState] = {} # conv_id -> own sender key
|
|
self.recv_sender_keys: dict[str, SenderKeyState] = {} # "conv_id:sender_id:device_id" -> their key
|
|
# Cache: user_id -> {identity_key (Ed25519PublicKey), username, email}
|
|
self._user_cache: dict[str, dict] = {}
|
|
self.connected: bool = False
|
|
self.login_rejected: bool = False
|
|
self._cache_key: bytes | None = None # AES key for encrypting message cache on disk
|
|
self._local_key: bytes | None = None # AES key for encrypting session/sender key files
|
|
# Multi-device support
|
|
self.device_id: str | None = None # This device's UUID
|
|
self._device_bundle_cache: dict[str, tuple[float, list[dict]]] = {} # user_id -> (ts, bundles)
|
|
|
|
async def connect(self):
|
|
host = os.getenv("SERVER_HOST", "127.0.0.1")
|
|
port = int(os.getenv("SERVER_PORT", "9999"))
|
|
tls_enabled = os.getenv("TLS_ENABLED", "false").lower() in ("1", "true", "yes")
|
|
tls_required = os.getenv("TLS_REQUIRED", "false").lower() in ("1", "true", "yes")
|
|
ssl_context = None
|
|
if tls_required and not tls_enabled:
|
|
raise RuntimeError("TLS_REQUIRED is enabled but TLS is not enabled.")
|
|
if tls_enabled:
|
|
insecure = os.getenv("TLS_INSECURE", "false").lower() in ("1", "true", "yes")
|
|
is_dev = os.getenv("ENVIRONMENT", "").lower() in ("dev", "development")
|
|
if insecure and not is_dev:
|
|
raise RuntimeError("TLS_INSECURE is only allowed when ENVIRONMENT=dev")
|
|
ssl_context = ssl.create_default_context()
|
|
ca_file = os.getenv("TLS_CA_FILE", "").strip()
|
|
if ca_file:
|
|
ssl_context.load_verify_locations(cafile=ca_file)
|
|
elif insecure:
|
|
ssl_context.check_hostname = False
|
|
ssl_context.verify_mode = ssl.CERT_NONE
|
|
else:
|
|
self._logger.warning("TLS is disabled — traffic is unencrypted. Set TLS_ENABLED=true for production.")
|
|
r, w = await asyncio.open_connection(host, port, limit=MAX_MESSAGE_BYTES, ssl=ssl_context)
|
|
self.reader = ProtocolReader(r)
|
|
self.writer = ProtocolWriter(w)
|
|
self.raw_writer = w
|
|
self.connected = True
|
|
|
|
async def _background_listener(self):
|
|
"""Read messages from server, routing responses vs notifications."""
|
|
while True:
|
|
msg = await self.reader.read_message()
|
|
if msg is None:
|
|
self.connected = False
|
|
# Fail all pending futures so send_and_recv doesn't hang
|
|
pending = dict(self._pending)
|
|
self._pending.clear()
|
|
err = ConnectionError("Server connection lost")
|
|
for fut in pending.values():
|
|
if not fut.done():
|
|
fut.set_exception(err)
|
|
break
|
|
if msg.get("type") in ("new_message", "messages_read", "message_deleted",
|
|
"conversation_created", "member_added", "member_removed",
|
|
"user_online", "user_offline", "online_users",
|
|
"group_invitation", "conversation_renamed",
|
|
"session_reset"):
|
|
await self._notification_queue.put(msg)
|
|
else:
|
|
req_id = msg.get("request_id")
|
|
if req_id and req_id in self._pending:
|
|
fut = self._pending.pop(req_id)
|
|
if not fut.done():
|
|
fut.set_result(msg)
|
|
else:
|
|
await self._response_queue.put(msg)
|
|
|
|
async def send_and_recv(self, msg_type: str, timeout: float = 30.0, **kwargs) -> dict:
|
|
try:
|
|
request_id = str(uuid.uuid4())
|
|
loop = asyncio.get_running_loop()
|
|
fut = loop.create_future()
|
|
self._pending[request_id] = fut
|
|
await self.writer.send_request(msg_type, request_id=request_id, **kwargs)
|
|
except (ValueError, ConnectionError, OSError) as e:
|
|
self._pending.pop(request_id, None)
|
|
return {
|
|
"type": msg_type,
|
|
"status": "error",
|
|
"data": {"message": str(e) or "Connection lost."},
|
|
}
|
|
try:
|
|
return await asyncio.wait_for(fut, timeout=timeout)
|
|
except asyncio.TimeoutError:
|
|
self._logger.warning("send_and_recv timeout for '%s' after %.0fs", msg_type, timeout)
|
|
return {
|
|
"type": msg_type,
|
|
"status": "error",
|
|
"data": {"message": f"Request timed out ({msg_type})"},
|
|
}
|
|
except ConnectionError:
|
|
return {
|
|
"type": msg_type,
|
|
"status": "error",
|
|
"data": {"message": "Connection lost."},
|
|
}
|
|
finally:
|
|
self._pending.pop(request_id, None)
|
|
|
|
# ------------------------------------------------------------------
|
|
# User info / identity key cache
|
|
# ------------------------------------------------------------------
|
|
|
|
async def _get_user_info(self, user_id: str = "", email: str = "") -> dict | None:
|
|
"""Get user info from server, cache identity key."""
|
|
cached = self._user_cache.get(user_id)
|
|
if cached:
|
|
return cached
|
|
kwargs = {}
|
|
if user_id:
|
|
kwargs["user_id"] = user_id
|
|
elif email:
|
|
kwargs["email"] = email
|
|
else:
|
|
return None
|
|
resp = await self.send_and_recv("get_user_info", **kwargs)
|
|
if resp["status"] != "ok":
|
|
return None
|
|
data = resp["data"]
|
|
ik_bytes = decode_binary(data["identity_key"]) if data.get("identity_key") else None
|
|
info = {
|
|
"user_id": data["user_id"],
|
|
"username": data["username"],
|
|
"email": data["email"],
|
|
"identity_key": load_ed25519_public(ik_bytes) if ik_bytes else None,
|
|
"identity_key_bytes": ik_bytes,
|
|
}
|
|
self._user_cache[data["user_id"]] = info
|
|
return info
|
|
|
|
# ------------------------------------------------------------------
|
|
# Registration
|
|
# ------------------------------------------------------------------
|
|
|
|
async def register(self, username: str, password: str, email: str) -> tuple[bool, str]:
|
|
"""Register user. Generates RSA + Ed25519 + prekeys."""
|
|
self.username = username
|
|
self.email = email
|
|
pwd_bytes = bytearray(password.encode("utf-8")) if password else None
|
|
|
|
try:
|
|
# RSA keys for login
|
|
priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
|
|
if priv is None:
|
|
priv, pub = generate_rsa_keypair()
|
|
save_keys(email, priv, pub, password=bytes(pwd_bytes) if pwd_bytes else None)
|
|
self.private_key = priv
|
|
self.public_key = pub
|
|
|
|
# Ed25519 identity keys
|
|
ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
|
|
if ed_priv is None:
|
|
ed_priv, ed_pub = generate_identity_keypair()
|
|
_save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None)
|
|
self.identity_private = ed_priv
|
|
self.identity_public = ed_pub
|
|
self._cache_key = derive_self_encryption_key(ed_priv)
|
|
self._local_key = derive_local_storage_key(ed_priv)
|
|
finally:
|
|
if pwd_bytes:
|
|
pwd_bytes[:] = b'\x00' * len(pwd_bytes)
|
|
|
|
pub_pem = serialize_public_key(pub).decode("utf-8")
|
|
ik_b64 = encode_binary(serialize_ed25519_public(ed_pub))
|
|
|
|
start = await self.send_and_recv(
|
|
"register",
|
|
username=username,
|
|
public_key=pub_pem,
|
|
email=email,
|
|
identity_key=ik_b64,
|
|
)
|
|
if start["status"] != "ok":
|
|
return False, start["data"]["message"]
|
|
code = start["data"].get("code")
|
|
if code:
|
|
return True, code
|
|
return True, start["data"].get("message", "Check your email for the code.")
|
|
|
|
async def confirm_registration(self, email: str, username: str, code: str) -> tuple[bool, str]:
|
|
confirm = await self.send_and_recv("register_confirm", email=email, code=code)
|
|
if confirm["status"] == "ok":
|
|
# Upload prekeys immediately after registration
|
|
await self._generate_and_upload_prekeys()
|
|
return True, f"Registered as '{username}' (ID: {confirm['data']['user_id']})"
|
|
return False, confirm["data"]["message"]
|
|
|
|
async def _generate_and_upload_prekeys(self, keep_spk: bool = False):
|
|
"""Generate SPK + OPKs and upload to server.
|
|
|
|
If keep_spk=True, re-sign the existing SPK instead of generating a new
|
|
one. This is used after device pairing so both devices share the same
|
|
SPK and either can respond to X3DH.
|
|
"""
|
|
if not self.identity_private:
|
|
return
|
|
|
|
if keep_spk and self.spk_private and self.spk_id:
|
|
# Re-sign existing SPK (both devices share the identity key)
|
|
spk_pub_bytes = serialize_x25519_public(self.spk_private.public_key())
|
|
spk_sig = ed25519_sign(self.identity_private, spk_pub_bytes)
|
|
spk_data = {
|
|
"id": self.spk_id,
|
|
"public_key": encode_binary(spk_pub_bytes),
|
|
"signature": encode_binary(spk_sig),
|
|
}
|
|
else:
|
|
# Save current SPK as previous for grace period (M4: in-flight X3DH)
|
|
if self.spk_private and self.spk_id:
|
|
self._prev_spk_private = self.spk_private
|
|
self._prev_spk_id = self.spk_id
|
|
_save_prev_spk(self.email, self.spk_private, self.spk_id)
|
|
# Generate a brand-new signed prekey
|
|
spk = generate_signed_prekey(self.identity_private)
|
|
self.spk_private = spk["private"]
|
|
self.spk_id = spk["id"]
|
|
_save_spk(self.email, spk["private"], spk["id"])
|
|
spk_data = {
|
|
"id": spk["id"],
|
|
"public_key": encode_binary(serialize_x25519_public(spk["public"])),
|
|
"signature": encode_binary(spk["signature"]),
|
|
}
|
|
|
|
# Generate one-time prekeys
|
|
opks = generate_one_time_prekeys(OPK_BATCH_SIZE)
|
|
for opk in opks:
|
|
self.opk_privates[opk["id"]] = opk["private"]
|
|
_save_opk_private(self.email, opk["id"], opk["private"])
|
|
|
|
# Upload to server
|
|
otp_data = [
|
|
{"id": opk["id"], "public_key": encode_binary(serialize_x25519_public(opk["public"]))}
|
|
for opk in opks
|
|
]
|
|
await self.send_and_recv(
|
|
"upload_prekeys",
|
|
signed_prekey=spk_data,
|
|
one_time_prekeys=otp_data,
|
|
)
|
|
|
|
async def _ensure_prekeys(self):
|
|
"""Check OPK count and SPK age, replenish/rotate if needed."""
|
|
resp = await self.send_and_recv("get_prekey_count")
|
|
if resp["status"] != "ok":
|
|
return
|
|
count = resp["data"].get("count", 0)
|
|
spk_created_at = resp["data"].get("spk_created_at", "")
|
|
|
|
need_new_spk = False
|
|
if spk_created_at:
|
|
try:
|
|
created = datetime.fromisoformat(spk_created_at)
|
|
if created.tzinfo is None:
|
|
created = created.replace(tzinfo=timezone.utc)
|
|
age_days = (datetime.now(timezone.utc) - created).days
|
|
if age_days >= SPK_ROTATION_DAYS:
|
|
need_new_spk = True
|
|
self._logger.info("SPK is %d days old, rotating...", age_days)
|
|
except Exception:
|
|
pass
|
|
|
|
if count < OPK_REPLENISH_THRESHOLD or need_new_spk:
|
|
if count >= OPK_REPLENISH_THRESHOLD:
|
|
self._logger.info("SPK rotation triggered (OPK count OK: %d)", count)
|
|
else:
|
|
self._logger.info("OPK count low (%d), replenishing...", count)
|
|
await self._generate_and_upload_prekeys()
|
|
|
|
# ------------------------------------------------------------------
|
|
# Login
|
|
# ------------------------------------------------------------------
|
|
|
|
async def login(self, email: str, password: str) -> tuple[bool, str]:
|
|
"""Login user. Returns (success, message)."""
|
|
self.email = email
|
|
pwd_bytes = bytearray(password.encode("utf-8")) if password else None
|
|
|
|
try:
|
|
# Load RSA keys
|
|
priv, pub, err = load_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
|
|
if priv is None:
|
|
return False, err or "No local keys found. Register first."
|
|
self.private_key = priv
|
|
self.public_key = pub
|
|
|
|
# Load identity keys
|
|
ed_priv, ed_pub = _load_identity_keys(email, password=bytes(pwd_bytes) if pwd_bytes else None)
|
|
finally:
|
|
if pwd_bytes:
|
|
pwd_bytes[:] = b'\x00' * len(pwd_bytes)
|
|
|
|
if ed_priv is not None:
|
|
self.identity_private = ed_priv
|
|
self.identity_public = ed_pub
|
|
self._cache_key = derive_self_encryption_key(ed_priv)
|
|
self._local_key = derive_local_storage_key(ed_priv)
|
|
|
|
# Load SPK
|
|
spk_priv, spk_id = _load_spk(email)
|
|
if spk_priv:
|
|
self.spk_private = spk_priv
|
|
self.spk_id = spk_id
|
|
|
|
# Load previous SPK for grace period (M4)
|
|
prev_spk_priv, prev_spk_id = _load_prev_spk(email)
|
|
if prev_spk_priv:
|
|
self._prev_spk_private = prev_spk_priv
|
|
self._prev_spk_id = prev_spk_id
|
|
|
|
# Load device_id from disk
|
|
self.device_id = _load_device_id(email)
|
|
|
|
# RSA challenge-response login
|
|
start = await self.send_and_recv("login_start", email=email)
|
|
if start["status"] != "ok":
|
|
return False, start["data"]["message"]
|
|
|
|
challenge = decode_binary(start["data"]["challenge"])
|
|
signature = rsa_sign(self.private_key, challenge)
|
|
login_kwargs = {"email": email, "signature": encode_binary(signature),
|
|
"client_version": VERSION}
|
|
if self.device_id:
|
|
login_kwargs["device_id"] = self.device_id
|
|
finish = await self.send_and_recv("login_finish", **login_kwargs)
|
|
if finish["status"] == "ok":
|
|
self.session = finish["data"]
|
|
self.username = self.session.get("username", "")
|
|
# Store device_id from server
|
|
self.device_id = finish["data"].get("device_id")
|
|
if self.device_id:
|
|
_save_device_id(email, self.device_id)
|
|
# Replenish prekeys in background — after pairing, the new device
|
|
# has no local OPK private keys so we must generate fresh ones
|
|
# (server-side OPKs have no matching private keys on this device).
|
|
# Use keep_spk=True to preserve the shared SPK so both devices
|
|
# can respond to X3DH.
|
|
opk_dir = get_key_dir(self.email) / "opk_private"
|
|
has_local_opks = opk_dir.exists() and any(opk_dir.iterdir())
|
|
if has_local_opks:
|
|
asyncio.create_task(self._ensure_prekeys())
|
|
else:
|
|
self._logger.info("No local OPKs (likely new device). Generating fresh OPKs, keeping SPK.")
|
|
asyncio.create_task(self._generate_and_upload_prekeys(keep_spk=True))
|
|
return True, f"Logged in as '{self.username}' (ID: {self.session['user_id']})"
|
|
return False, finish["data"]["message"]
|
|
|
|
# ------------------------------------------------------------------
|
|
# Pairing (device pairing — transfers RSA + identity keys)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def pairing_start(self, email: str) -> tuple[bool, str]:
|
|
"""Start device pairing. Returns (success, code/message)."""
|
|
temp_priv, temp_pub = generate_rsa_keypair(2048)
|
|
self._pairing_temp_private_key = temp_priv
|
|
temp_pub_pem = serialize_public_key(temp_pub).decode("utf-8")
|
|
resp = await self.send_and_recv("pairing_start", email=email, temp_public_key=temp_pub_pem)
|
|
if resp["status"] == "ok":
|
|
self._pairing_poll_token = resp["data"].get("poll_token", "")
|
|
return True, resp["data"]["code"]
|
|
return False, resp["data"]["message"]
|
|
|
|
async def pairing_wait(self, code: str, email: str, password: str, timeout: int = 300) -> tuple[bool, str]:
|
|
"""Wait for pairing payload and import keys. Returns (success, message)."""
|
|
if not self._pairing_temp_private_key:
|
|
return False, "Pairing not started."
|
|
from crypto_utils import aes_decrypt as _aes_decrypt
|
|
poll_token = getattr(self, "_pairing_poll_token", "")
|
|
deadline = asyncio.get_event_loop().time() + timeout
|
|
while asyncio.get_event_loop().time() < deadline:
|
|
resp = await self.send_and_recv("pairing_poll", code=code, poll_token=poll_token)
|
|
if resp["status"] != "ok":
|
|
return False, resp["data"]["message"]
|
|
if not resp["data"].get("ready"):
|
|
await asyncio.sleep(2.0)
|
|
continue
|
|
payload = resp["data"]["payload"]
|
|
try:
|
|
# Decrypt AES key with temp RSA key
|
|
from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding
|
|
from cryptography.hazmat.primitives import hashes as rsa_hashes
|
|
enc_aes_key = decode_binary(payload["encrypted_key"])
|
|
aes_key = self._pairing_temp_private_key.decrypt(
|
|
enc_aes_key,
|
|
rsa_padding.OAEP(
|
|
mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()),
|
|
algorithm=rsa_hashes.SHA256(),
|
|
label=None,
|
|
),
|
|
)
|
|
nonce = decode_binary(payload["iv"])
|
|
ct = decode_binary(payload["ciphertext"])
|
|
tag = decode_binary(payload["tag"])
|
|
keys_json = _aes_decrypt(aes_key, nonce, ct, tag)
|
|
keys_data = json.loads(keys_json)
|
|
|
|
pwd_bytes = bytearray(password.encode("utf-8")) if password else None
|
|
|
|
try:
|
|
# Import RSA key
|
|
rsa_priv = load_private_key(keys_data["rsa_private"].encode(), password=None)
|
|
rsa_pub = rsa_priv.public_key()
|
|
save_keys(email, rsa_priv, rsa_pub, password=bytes(pwd_bytes) if pwd_bytes else None)
|
|
|
|
# Import identity keys
|
|
ed_priv = load_ed25519_private(bytes.fromhex(keys_data["identity_private"]))
|
|
ed_pub = ed_priv.public_key()
|
|
_save_identity_keys(email, ed_priv, ed_pub, password=bytes(pwd_bytes) if pwd_bytes else None)
|
|
finally:
|
|
if pwd_bytes:
|
|
pwd_bytes[:] = b'\x00' * len(pwd_bytes)
|
|
|
|
self.email = email
|
|
self.private_key = rsa_priv
|
|
self.public_key = rsa_pub
|
|
self.identity_private = ed_priv
|
|
self.identity_public = ed_pub
|
|
self._cache_key = derive_self_encryption_key(ed_priv)
|
|
self._local_key = derive_local_storage_key(ed_priv)
|
|
self._pairing_temp_private_key = None
|
|
|
|
# Multi-device: new device generates own SPK + OPKs on first
|
|
# login. No session/sender key import needed — each device
|
|
# has independent Double Ratchet sessions.
|
|
|
|
return True, "Pairing complete."
|
|
except Exception as e:
|
|
return False, f"Failed to import keys: {e}"
|
|
return False, "Pairing timed out."
|
|
|
|
async def authorize_device(self, code: str) -> tuple[bool, str]:
    """Hand this account's long-term keys to a device that is pairing.

    Claims the pairing ``code`` on the server to obtain the new device's
    temporary RSA public key. It then re-encrypts message history (best
    effort) and sends the RSA login key plus the Ed25519 identity key inside
    a hybrid envelope: an AES key wrapped with RSA-OAEP(SHA-256) around the
    AES-encrypted JSON.

    Returns:
        ``(success, human-readable message)``.
    """
    if not (self.private_key and self.identity_private):
        return False, "Not logged in."

    claimed = await self.send_and_recv("pairing_claim", code=code)
    if claimed["status"] != "ok":
        return False, claimed["data"]["message"]

    device_pub = load_public_key(claimed["data"]["temp_public_key"].encode("utf-8"))

    # Best effort: rewrap old history under the self-encryption key so the
    # new device can read it. A failure here is logged and must not block
    # pairing.
    try:
        await self.reencrypt_history()
    except Exception as exc:
        self._logger.warning("Re-encryption failed: %s", exc)

    # Only the RSA login key and the identity key travel. The new device
    # creates its own prekeys and independent Double Ratchet sessions.
    exported = {
        "rsa_private": serialize_private_key(self.private_key, password=None).decode(),
        "identity_private": serialize_ed25519_private_raw(self.identity_private).hex(),
    }

    from cryptography.hazmat.primitives import hashes as rsa_hashes
    from cryptography.hazmat.primitives.asymmetric import padding as rsa_padding

    wrap_key, iv, sealed, mac = aes_encrypt(json.dumps(exported).encode())
    oaep = rsa_padding.OAEP(
        mgf=rsa_padding.MGF1(algorithm=rsa_hashes.SHA256()),
        algorithm=rsa_hashes.SHA256(),
        label=None,
    )
    envelope = {
        "encrypted_key": encode_binary(device_pub.encrypt(wrap_key, oaep)),
        "iv": encode_binary(iv),
        "ciphertext": encode_binary(sealed),
        "tag": encode_binary(mac),
    }

    reply = await self.send_and_recv("pairing_send", code=code, payload=envelope)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "Device authorized."
|
|
|
|
# ------------------------------------------------------------------
|
|
# Key rotation (RSA login key only)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def rotate_keys(self, username: str, password: str) -> tuple[bool, str]:
    """Replace the RSA login keypair, revoking devices that hold the old key.

    The new public key is registered with the server *before* anything local
    changes. Login is RSA challenge-response, so adopting or persisting a
    private key the server has not accepted would lock this device out.

    Args:
        username: Must match ``self.session["username"]``.
        password: Forwarded to ``save_keys`` as the on-disk key password
            (``None`` when empty). The password bytes are zeroed afterwards,
            as in the pairing import path.

    Returns:
        ``(success, message)``. On failure the old keypair stays in use.
    """
    if not self.session or self.session.get("username") != username:
        return False, "Not logged in."

    new_priv, new_pub = generate_rsa_keypair()
    pub_pem = serialize_public_key(new_pub).decode("utf-8")
    resp = await self.send_and_recv("rotate_keys", public_key=pub_pem)
    if resp["status"] != "ok":
        # The server still holds the old public key, so keep the old private key.
        return False, resp["data"]["message"]

    # Adopt in memory first so this session stays usable even if the disk
    # write below fails.
    self.private_key = new_priv
    self.public_key = new_pub

    pwd_bytes = bytearray(password.encode("utf-8")) if password else None
    try:
        save_keys(self.email, new_priv, new_pub,
                  password=bytes(pwd_bytes) if pwd_bytes else None)
    finally:
        if pwd_bytes:
            pwd_bytes[:] = b"\x00" * len(pwd_bytes)
    return True, "RSA login keys rotated."
|
|
|
|
# ------------------------------------------------------------------
|
|
# Session management (X3DH + Double Ratchet)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def _get_device_bundles(self, peer_user_id: str) -> list[dict]:
|
|
"""Get per-device key bundles for a peer. Caches for 5 minutes."""
|
|
import time
|
|
cached = self._device_bundle_cache.get(peer_user_id)
|
|
if cached:
|
|
ts, bundles = cached
|
|
if time.time() - ts < 300:
|
|
return bundles
|
|
|
|
resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id)
|
|
if resp["status"] != "ok":
|
|
raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}")
|
|
|
|
data = resp["data"]
|
|
ik_b64 = data.get("identity_key", "")
|
|
|
|
device_bundles = data.get("device_bundles")
|
|
if device_bundles:
|
|
# Attach identity_key to each bundle
|
|
for b in device_bundles:
|
|
b["identity_key"] = ik_b64
|
|
else:
|
|
# Old server: wrap flat response as single-entry list
|
|
device_bundles = [{
|
|
"device_id": None,
|
|
"identity_key": ik_b64,
|
|
"signed_prekey_id": data.get("signed_prekey_id", ""),
|
|
"signed_prekey": data.get("signed_prekey", ""),
|
|
"spk_signature": data.get("spk_signature", ""),
|
|
"one_time_prekey_id": data.get("one_time_prekey_id"),
|
|
"one_time_prekey": data.get("one_time_prekey"),
|
|
}]
|
|
|
|
self._device_bundle_cache[peer_user_id] = (time.time(), device_bundles)
|
|
return device_bundles
|
|
|
|
async def _get_or_create_session(self, peer_user_id: str,
                                 peer_device_id: str | None = None,
                                 bundle: dict | None = None) -> DoubleRatchet:
    """Return the Double Ratchet session for a peer, running X3DH if none exists.

    Lookup order: ``self.sessions`` (in memory), then ``_load_session`` (on
    disk), then a new initiator ("Alice") session built from a prekey bundle.

    If peer_device_id is set, sessions are keyed by "user_id:device_id";
    otherwise by the bare user id.
    If bundle is provided, it's used instead of fetching from server.

    For a newly created session, the X3DH header the peer needs is attached
    to the ratchet as ``_x3dh_header``. It holds our identity key ``ik``, the
    ephemeral key ``ek`` and, when the bundle named one, ``opk_id``. The send
    paths in this class read it after the first ``encrypt`` and then delete
    it. Whether ``_save_session`` persists that attribute is not visible
    here — TODO confirm.

    Raises:
        RuntimeError: if no session exists and the server refuses ``get_key_bundle``.
    """
    session_key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id

    # Check in-memory cache
    if session_key in self.sessions:
        return self.sessions[session_key]

    # Check on disk
    ratchet = _load_session(self.email, peer_user_id, self._local_key,
                            peer_device_id=peer_device_id)
    if ratchet:
        self.sessions[session_key] = ratchet
        return ratchet

    # Create new session via X3DH
    if not bundle:
        # NOTE(review): this reads the flat (legacy) bundle fields below. A
        # server that only returns "device_bundles" would raise KeyError here.
        resp = await self.send_and_recv("get_key_bundle", user_id=peer_user_id)
        if resp["status"] != "ok":
            raise RuntimeError(f"Cannot get key bundle for {peer_user_id}: {resp['data']['message']}")
        bundle = resp["data"]

    ik_remote_bytes = decode_binary(bundle["identity_key"])
    ik_remote = load_ed25519_public(ik_remote_bytes)
    spk_remote = load_x25519_public(decode_binary(bundle["signed_prekey"]))
    spk_sig = decode_binary(bundle["spk_signature"])

    # The one-time prekey is optional; bundles may run out of OPKs.
    opk_remote = None
    opk_id = bundle.get("one_time_prekey_id")
    if bundle.get("one_time_prekey"):
        opk_remote = load_x25519_public(decode_binary(bundle["one_time_prekey"]))

    # Perform X3DH. spk_sig is passed through; presumably x3dh_initiate
    # verifies it against ik_remote — confirm in crypto_utils.
    shared_secret, ek_priv, ek_pub = x3dh_initiate(
        self.identity_private,
        ik_remote,
        spk_remote,
        spk_sig,
        opk_remote,
    )

    # Initialize Double Ratchet as Alice. The peer's signed prekey is the
    # initial remote ratchet key (the responder passes its SPK to init_bob).
    ratchet = DoubleRatchet.init_alice(shared_secret, spk_remote)
    self.sessions[session_key] = ratchet
    _save_session(self.email, peer_user_id, ratchet, self._local_key,
                  peer_device_id=peer_device_id)

    # Build X3DH header for first message
    x3dh_header = {
        "ik": encode_binary(serialize_ed25519_public(self.identity_public)),
        "ek": encode_binary(serialize_x25519_public(ek_pub)),
    }
    # NOTE(review): opk_id is advertised whenever the bundle names one, even if
    # the bundle carried no OPK public key (opk_remote is then None). The
    # responder would then mix in an OPK that the initiator did not use.
    if opk_id:
        x3dh_header["opk_id"] = opk_id

    # Cache the x3dh header for the next send_message call
    ratchet._x3dh_header = x3dh_header

    # Cache remote user info
    self._user_cache[peer_user_id] = {
        "user_id": peer_user_id,
        "identity_key": ik_remote,
        "identity_key_bytes": ik_remote_bytes,
    }

    return ratchet
|
|
|
|
def _process_x3dh_header(self, sender_id: str, x3dh_header: dict,
                         sender_device_id: str | None = None,
                         spk_override=None) -> DoubleRatchet:
    """Process an incoming X3DH header to establish session as Bob.

    The X3DH shared secret is derived from:
    - our identity key;
    - our signed prekey (``spk_override`` if given, else ``self.spk_private``);
    - the sender's identity key ``ik`` and ephemeral key ``ek`` from the header;
    - the one-time prekey named by ``opk_id``, if its private key is still stored.

    The new ratchet replaces any session held in ``self.sessions`` under
    ``"sender_id:sender_device_id"`` (or ``sender_id`` alone) and is saved to
    disk unconditionally. The sender's identity key is cached in
    ``self._user_cache``.

    Args:
        sender_id: User id of the message sender.
        x3dh_header: Dict with ``ik`` and ``ek`` (encode_binary form) and an
            optional ``opk_id``.
        sender_device_id: Sender's device id, or None for legacy peers.
        spk_override: If provided, use this SPK private key instead of self.spk_private.
            Used for grace period fallback (M4).

    Returns:
        The newly created responder DoubleRatchet.

    NOTE(review): the one-time prekey private key is deleted here, before any
    caller has shown that the derived secret decrypts a message. A retry with
    ``spk_override`` then finds no OPK and derives a different secret whenever
    the sender used one, and a forged header can burn an OPK. Consider
    deleting the OPK only after the first successful decrypt.
    """
    ik_remote_bytes = decode_binary(x3dh_header["ik"])
    ik_remote = load_ed25519_public(ik_remote_bytes)
    ek_remote = load_x25519_public(decode_binary(x3dh_header["ek"]))

    opk_id = x3dh_header.get("opk_id")
    opk_priv = None
    if opk_id:
        opk_priv = _load_opk_private(self.email, opk_id)
        if opk_priv:
            # One-time prekeys are single use: drop the private key once loaded.
            _delete_opk_private(self.email, opk_id)
    # If the OPK is missing (never existed or already consumed), opk_priv stays
    # None and x3dh_respond runs without it.

    spk_priv = spk_override if spk_override else self.spk_private

    shared_secret = x3dh_respond(
        self.identity_private,
        spk_priv,
        ik_remote,
        ek_remote,
        opk_priv,
    )

    # Our signed prekey serves as Bob's initial ratchet keypair (the initiator
    # passes our SPK public key to init_alice).
    spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None
    ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub))

    session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id
    self.sessions[session_key] = ratchet
    _save_session(self.email, sender_id, ratchet, self._local_key,
                  peer_device_id=sender_device_id)

    self._user_cache[sender_id] = {
        "user_id": sender_id,
        "identity_key": ik_remote,
        "identity_key_bytes": ik_remote_bytes,
    }

    return ratchet
|
|
|
|
# ------------------------------------------------------------------
|
|
# Conversations
|
|
# ------------------------------------------------------------------
|
|
|
|
async def create_conversation(self, member_emails: list[str], name: str | None = None) -> tuple[str | None, str]:
|
|
kwargs = {"members": member_emails}
|
|
if name:
|
|
kwargs["name"] = name
|
|
resp = await self.send_and_recv("create_conversation", **kwargs)
|
|
if resp["status"] == "ok":
|
|
return resp["data"]["conversation_id"], "OK"
|
|
return None, resp["data"]["message"]
|
|
|
|
async def remove_member(self, conv_id: str, user_id: str) -> tuple[bool, str]:
    """Ask the server to remove ``user_id`` from conversation ``conv_id``.

    Returns:
        ``(True, "OK")`` on success, otherwise ``(False, server message)``.
    """
    reply = await self.send_and_recv("remove_member", conversation_id=conv_id, user_id=user_id)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "OK"
|
|
|
|
async def leave_group(self, conv_id: str) -> tuple[bool, str]:
    """Leave group ``conv_id`` and forget its in-memory sender keys.

    Once the server confirms, our own sending chain for the group and every
    received sender key cached under a ``"<conv_id>:..."`` key are dropped.

    Returns:
        ``(True, "OK")`` on success, otherwise ``(False, server message)``.
    """
    reply = await self.send_and_recv("leave_group", conversation_id=conv_id)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]

    self.sender_key_states.pop(conv_id, None)
    prefix = f"{conv_id}:"
    for stale in [key for key in self.recv_sender_keys if key.startswith(prefix)]:
        del self.recv_sender_keys[stale]
    return True, "OK"
|
|
|
|
async def rename_conversation(self, conv_id: str, name: str) -> tuple[bool, str]:
    """Rename group ``conv_id`` to ``name``.

    The server is expected to allow this only for the group's creator.

    Returns:
        ``(True, "OK")`` on success, otherwise ``(False, server message)``.
    """
    reply = await self.send_and_recv("rename_conversation", conversation_id=conv_id, name=name)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "OK"
|
|
|
|
async def delete_conversation(self, conv_id: str) -> tuple[bool, str]:
    """Delete ``conv_id`` for this user and forget its in-memory sender keys.

    The server treats this as leaving, and cleans the conversation up once it
    is empty. On success, our own sending chain and every received sender key
    cached under a ``"<conv_id>:..."`` key are dropped.

    Returns:
        ``(True, "OK")`` on success, otherwise ``(False, server message)``.
    """
    reply = await self.send_and_recv("delete_conversation", conversation_id=conv_id)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]

    self.sender_key_states.pop(conv_id, None)
    prefix = f"{conv_id}:"
    for stale in [key for key in self.recv_sender_keys if key.startswith(prefix)]:
        del self.recv_sender_keys[stale]
    return True, "OK"
|
|
|
|
async def add_member(self, conv_id: str, email: str) -> tuple[bool, str]:
    """Ask the server to add the user with ``email`` to conversation ``conv_id``.

    Returns:
        ``(True, "OK")`` on success, otherwise ``(False, server message)``.
    """
    reply = await self.send_and_recv("add_member", conversation_id=conv_id, email=email)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "OK"
|
|
|
|
async def accept_invitation(self, conv_id: str) -> tuple[bool, str]:
    """Accept the pending invitation to group ``conv_id``.

    Returns:
        ``(True, "OK")`` on success, otherwise ``(False, server message)``.
    """
    reply = await self.send_and_recv("accept_invitation", conversation_id=conv_id)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "OK"
|
|
|
|
async def decline_invitation(self, conv_id: str) -> tuple[bool, str]:
    """Decline the pending invitation to group ``conv_id``.

    Returns:
        ``(True, "OK")`` on success, otherwise ``(False, server message)``.
    """
    reply = await self.send_and_recv("decline_invitation", conversation_id=conv_id)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "OK"
|
|
|
|
async def list_invitations(self) -> list[dict]:
    """Return this user's pending group invitations, or ``[]`` if the request fails."""
    reply = await self.send_and_recv("list_invitations")
    return reply["data"]["invitations"] if reply["status"] == "ok" else []
|
|
|
|
async def list_conversations(self) -> list[dict]:
    """Return the conversations this user belongs to, or ``[]`` if the request fails."""
    reply = await self.send_and_recv("list_conversations")
    return reply["data"]["conversations"] if reply["status"] == "ok" else []
|
|
|
|
async def find_or_create_conversation(self, email: str) -> tuple[str | None, str]:
|
|
resp = await self.send_and_recv("find_conversation", email=email)
|
|
if resp["status"] != "ok":
|
|
return None, resp["data"]["message"]
|
|
conv_id = resp["data"]["conversation_id"]
|
|
if conv_id:
|
|
return conv_id, "OK"
|
|
return await self.create_conversation([email])
|
|
|
|
# ------------------------------------------------------------------
|
|
# Send message
|
|
# ------------------------------------------------------------------
|
|
|
|
def _is_group(self, members: list[dict]) -> bool:
|
|
return len(members) > 2
|
|
|
|
async def send_message(self, conv_id: str, text: str, members: list[dict],
|
|
reply_to: str | None = None) -> tuple[bool, str]:
|
|
"""Encrypt and send a message. DM: per-recipient Double Ratchet. Group: Sender Keys."""
|
|
my_user_id = self.session["user_id"]
|
|
|
|
# Build plaintext payload
|
|
payload = {
|
|
"sender": self.username,
|
|
"text": text,
|
|
"reply_to": reply_to,
|
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
}
|
|
plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
|
|
|
if self._is_group(members):
|
|
return await self._send_group_message(conv_id, plaintext, members)
|
|
else:
|
|
return await self._send_dm(conv_id, plaintext, members)
|
|
|
|
async def _send_dm(self, conv_id: str, plaintext: bytes, members: list[dict]) -> tuple[bool, str]:
    """Encrypt a DM with per-device Double Ratchet and send it in one request.

    For every member other than ourselves, one recipient entry is built per
    device bundle returned by ``_get_device_bundles``. If no bundles are
    available, a single legacy (device-less) session is used instead. A
    new session's X3DH header travels only with the first message encrypted
    on it. After the member loop, a self-copy encrypted with the static
    self-encryption key is added so our other devices can read what we sent.

    Args:
        conv_id: Target conversation id.
        plaintext: UTF-8 JSON payload built by ``send_message``.
        members: Conversation members as dicts carrying ``user_id``.

    Returns:
        ``(success, message)``.
    """
    my_user_id = self.session["user_id"]
    recipients = []
    # The server also receives one top-level ratchet header: the first one produced.
    first_ratchet_header = None

    for member in members:
        uid = member.get("user_id")
        if not uid or uid == my_user_id:
            continue

        # Get all device bundles for this user
        try:
            device_bundles = await self._get_device_bundles(uid)
            self._logger.debug("Got %d device bundles for %s", len(device_bundles), uid)
        except Exception as e:
            self._logger.warning("Failed to get device bundles for %s: %s", uid, e)
            device_bundles = []

        if not device_bundles:
            # Fallback: try single session (legacy peer)
            ratchet = await self._get_or_create_session(uid)
            result = ratchet.encrypt(plaintext)
            # Consume the one-shot X3DH header so later messages don't resend it.
            x3dh_hdr = getattr(ratchet, "_x3dh_header", None)
            if x3dh_hdr:
                delattr(ratchet, "_x3dh_header")
            entry = {
                "user_id": uid,
                "encrypted_content": encode_binary(result["ciphertext"]),
                "nonce": encode_binary(result["nonce"]),
                "ratchet_header": result["header"],
            }
            if x3dh_hdr:
                entry["x3dh_header"] = x3dh_hdr
            recipients.append(entry)
            if first_ratchet_header is None:
                first_ratchet_header = result["header"]
            _save_session(self.email, uid, ratchet, self._local_key)
            continue

        # Any exception from session creation or encryption here aborts the whole send.
        for bundle in device_bundles:
            dev_id = bundle.get("device_id")
            ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
                                                        bundle=bundle)
            result = ratchet.encrypt(plaintext)
            # Consume the one-shot X3DH header so later messages don't resend it.
            x3dh_hdr = getattr(ratchet, "_x3dh_header", None)
            if x3dh_hdr:
                delattr(ratchet, "_x3dh_header")

            entry = {
                "user_id": uid,
                "encrypted_content": encode_binary(result["ciphertext"]),
                "nonce": encode_binary(result["nonce"]),
                "ratchet_header": result["header"],
            }
            if dev_id:
                entry["device_id"] = dev_id
            if x3dh_hdr:
                entry["x3dh_header"] = x3dh_hdr
            recipients.append(entry)

            if first_ratchet_header is None:
                first_ratchet_header = result["header"]

            # Persisted before the request is sent: a failed send still leaves the
            # ratchet advanced.
            _save_session(self.email, uid, ratchet, self._local_key,
                          peer_device_id=dev_id)

    # Encrypt self-copy with a static key derived from the identity key (not a
    # ratchet). The entry has no device_id, and its {"self": True} ratchet header
    # is the marker _decrypt_dm uses to pick the static-key path. The tag is
    # appended to the ciphertext.
    self_key = derive_self_encryption_key(self.identity_private)
    _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
    recipients.append({
        "user_id": my_user_id,
        "encrypted_content": encode_binary(self_ct + self_tag),
        "nonce": encode_binary(self_nonce),
        "ratchet_header": {"self": True},
    })

    # NOTE(review): unreachable as written, because the self-copy above is always appended.
    if not recipients:
        return False, "No recipients."

    kwargs = {
        "conversation_id": conv_id,
        "ratchet_header": first_ratchet_header,
        "recipients": recipients,
    }

    resp = await self.send_and_recv("send_message", **kwargs)
    if resp["status"] == "ok":
        return True, "Message sent."
    return False, resp["data"]["message"]
|
|
|
|
async def _send_group_message(self, conv_id: str, plaintext: bytes,
                              members: list[dict]) -> tuple[bool, str]:
    """Encrypt a group message once with our Sender Key and send it to all members.

    The Sender Key state for ``conv_id`` comes from memory, then disk. If
    neither has one, a new state is created, saved, and distributed
    pairwise via ``_distribute_sender_key``. Every other member receives the
    same ciphertext; the chain id and index travel as ``sender_chain_id`` /
    ``sender_chain_n``. A self-copy under the static self-encryption key is
    added for our other devices and history fetches.

    NOTE(review): distribution happens only when the state is first created
    here. Members or devices added afterwards do not receive the key through
    this path — confirm that is handled elsewhere.

    Returns:
        ``(success, message)``.
    """
    my_user_id = self.session["user_id"]

    # Get or create sender key for this group
    sk = self.sender_key_states.get(conv_id)
    if not sk:
        sk = _load_sender_key_state(self.email, conv_id, self._local_key)
        if not sk:
            sk = SenderKeyState()
            self.sender_key_states[conv_id] = sk
            _save_sender_key_state(self.email, conv_id, sk, self._local_key)
            # Distribute sender key to all members via pairwise ratchet
            await self._distribute_sender_key(conv_id, members, sk)

    # Also covers the case where the state was just loaded from disk.
    self.sender_key_states[conv_id] = sk

    # Encrypt with sender key (this advances the chain, so persist right away)
    result = sk.encrypt(plaintext)
    _save_sender_key_state(self.email, conv_id, sk, self._local_key)

    # Build per-recipient entries (same ciphertext for all except self)
    recipients = []
    for member in members:
        uid = member.get("user_id")
        if not uid or uid == my_user_id:
            continue
        recipients.append({
            "user_id": uid,
            "encrypted_content": encode_binary(result["ciphertext"]),
            "nonce": encode_binary(result["nonce"]),
        })

    # Self-encrypted copy (so other devices + history fetch can decrypt)
    self_key = derive_self_encryption_key(self.identity_private)
    _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
    recipients.append({
        "user_id": my_user_id,
        "encrypted_content": encode_binary(self_ct + self_tag),
        "nonce": encode_binary(self_nonce),
        "ratchet_header": {"self": True},
    })

    # Placeholder header for the request field; group messages don't use a Double Ratchet.
    ratchet_header = {"dh_pub": "00" * 32, "n": 0, "pn": 0}  # Dummy for groups

    kwargs = {
        "conversation_id": conv_id,
        "ratchet_header": ratchet_header,
        "recipients": recipients,
        # result["chain_id"] is a hex string; it is sent in encode_binary form.
        "sender_chain_id": encode_binary(bytes.fromhex(result["chain_id"])),
        "sender_chain_n": result["n"],
    }

    resp = await self.send_and_recv("send_message", **kwargs)
    if resp["status"] == "ok":
        return True, "Message sent."
    return False, resp["data"]["message"]
|
|
|
|
async def _distribute_sender_key(self, conv_id: str, members: list[dict],
                                 sk: SenderKeyState):
    """Send own sender key to all group members via pairwise Double Ratchet (per-device).

    The key is wrapped in an ordinary message payload whose ``_sender_key``
    field carries ``conv_id``, the exported key (encode_binary form) and our
    ``device_id``. ``_decrypt_dm`` recognises that field, stores the key and
    hides the message. One ``send_message`` request is issued per member.
    Failures for one member are logged and do not stop distribution to the
    rest.

    NOTE(review): the ``send_message`` responses are not checked, so a server
    rejection is silent. No self-copy is sent, so our own other devices do not
    receive the key through this path.
    """
    my_user_id = self.session["user_id"]
    exported_key = sk.export_key()

    # Build a special "sender_key_distribution" payload
    payload = {
        "sender": self.username,
        "text": "",
        "reply_to": None,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "_sender_key": {
            "conv_id": conv_id,
            "key": encode_binary(exported_key),
            "sender_device_id": self.device_id,
        },
    }
    plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8")

    # Send as DM to each member's devices (per-device encryption)
    for member in members:
        uid = member.get("user_id")
        if not uid or uid == my_user_id:
            continue

        try:
            # Get all device bundles for this user (errors fall back to legacy, unlogged)
            try:
                device_bundles = await self._get_device_bundles(uid)
            except Exception:
                device_bundles = []

            if not device_bundles:
                # Fallback: legacy single-device
                ratchet = await self._get_or_create_session(uid)
                result = ratchet.encrypt(plaintext)
                # Consume the one-shot X3DH header so later messages don't resend it.
                x3dh_header = getattr(ratchet, "_x3dh_header", None)
                if x3dh_header:
                    delattr(ratchet, "_x3dh_header")

                recipient_entry = {
                    "user_id": uid,
                    "encrypted_content": encode_binary(result["ciphertext"]),
                    "nonce": encode_binary(result["nonce"]),
                    "ratchet_header": result["header"],
                }
                if x3dh_header:
                    recipient_entry["x3dh_header"] = x3dh_header
                kwargs = {
                    "conversation_id": conv_id,
                    "ratchet_header": result["header"],
                    "recipients": [recipient_entry],
                }
                await self.send_and_recv("send_message", **kwargs)
                # Saved only after the send returns, unlike the per-device branch.
                _save_session(self.email, uid, ratchet, self._local_key)
            else:
                # Per-device encryption
                recipients = []
                first_rh = None
                for bundle in device_bundles:
                    dev_id = bundle.get("device_id")
                    ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
                                                                bundle=bundle)
                    result = ratchet.encrypt(plaintext)
                    # Consume the one-shot X3DH header so later messages don't resend it.
                    x3dh_header = getattr(ratchet, "_x3dh_header", None)
                    if x3dh_header:
                        delattr(ratchet, "_x3dh_header")

                    entry = {
                        "user_id": uid,
                        "encrypted_content": encode_binary(result["ciphertext"]),
                        "nonce": encode_binary(result["nonce"]),
                        "ratchet_header": result["header"],
                    }
                    if dev_id:
                        entry["device_id"] = dev_id
                    if x3dh_header:
                        entry["x3dh_header"] = x3dh_header
                    recipients.append(entry)
                    if first_rh is None:
                        first_rh = result["header"]
                    _save_session(self.email, uid, ratchet, self._local_key,
                                  peer_device_id=dev_id)

                kwargs = {
                    "conversation_id": conv_id,
                    "ratchet_header": first_rh,
                    "recipients": recipients,
                }
                await self.send_and_recv("send_message", **kwargs)
        except Exception as e:
            self._logger.warning("Failed to distribute sender key to %s: %s", uid, e)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Decrypt messages
|
|
# ------------------------------------------------------------------
|
|
|
|
def _decrypt_message(self, msg_data: dict) -> dict:
|
|
"""Decrypt a single message (DM or group)."""
|
|
# Check for self-encrypted marker FIRST — after re-encryption,
|
|
# group messages will have {"self": true} ratchet_header but still
|
|
# have sender_chain_id at message level.
|
|
rh = msg_data.get("ratchet_header", {})
|
|
if isinstance(rh, dict) and rh.get("self"):
|
|
return self._decrypt_dm(msg_data)
|
|
|
|
if msg_data.get("sender_chain_id"):
|
|
return self._decrypt_group(msg_data)
|
|
else:
|
|
return self._decrypt_dm(msg_data)
|
|
|
|
def _decrypt_dm(self, msg_data: dict) -> dict | None:
    """Decrypt DM using Double Ratchet with sender, or static key for self-copies.

    Paths:
    * ``ratchet_header`` is ``{"self": True}``: the message is our own copy.
      It is encrypted with the static self-encryption key, with the 16-byte
      AES-GCM tag appended to the ciphertext.
    * Existing session and no X3DH header: normal ratchet decrypt.
    * X3DH header present: first try the existing session (if any). If that
      fails, establish a new responder session via ``_process_x3dh_header``,
      retrying once with ``self._prev_spk_private`` when set (SPK grace
      period).
    * No session and no header: ``ValueError``.

    Sessions are keyed by ``"sender_id:sender_device_id"`` or ``sender_id``
    for legacy senders, and saved after every successful decrypt.

    Returns:
        The decoded JSON payload. Returns ``None`` for sender-key
        distribution messages (``_sender_key`` present); their key is stored
        in ``self.recv_sender_keys`` and on disk instead.

    Raises:
        ValueError: missing ciphertext/nonce, or no session and no X3DH header.
        Exception: whatever the ratchet/AES decrypt raises on failure.
    """
    sender_id = msg_data.get("sender_id", "")
    sender_device_id = msg_data.get("sender_device_id")
    ratchet_header = msg_data.get("ratchet_header", {})
    ct_b64 = msg_data.get("encrypted_content", "")
    nonce_b64 = msg_data.get("nonce", "")

    if not ct_b64 or not nonce_b64:
        raise ValueError("Missing ciphertext or nonce")

    ciphertext = decode_binary(ct_b64)
    nonce = decode_binary(nonce_b64)

    # Self-encrypted message (own sent message copy)
    if isinstance(ratchet_header, dict) and ratchet_header.get("self"):
        self_key = derive_self_encryption_key(self.identity_private)
        # The sender stored ciphertext + tag; split off the trailing 16-byte tag.
        ct = ciphertext[:-16]
        tag = ciphertext[-16:]
        plaintext = aes_decrypt(self_key, nonce, ct, tag)
    else:
        x3dh_header = msg_data.get("x3dh_header")

        # Session key: "sender_id:sender_device_id" or just "sender_id" for legacy
        session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id

        # Try to load existing session
        ratchet = self.sessions.get(session_key)
        if not ratchet:
            ratchet = _load_session(self.email, sender_id, self._local_key,
                                    peer_device_id=sender_device_id)
            if ratchet:
                self.sessions[session_key] = ratchet

        if ratchet and not x3dh_header:
            # Normal case: existing session, no X3DH header
            plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
            _save_session(self.email, sender_id, ratchet, self._local_key,
                          peer_device_id=sender_device_id)
        elif x3dh_header:
            if ratchet:
                # Existing session + X3DH header: sender may have reset.
                # Snapshot first; presumably a failed decrypt can leave the
                # ratchet partially mutated — confirm in DoubleRatchet.
                backup = ratchet.export_state()
                try:
                    plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
                    _save_session(self.email, sender_id, ratchet, self._local_key,
                                  peer_device_id=sender_device_id)
                except Exception:
                    restored = DoubleRatchet.import_state(backup)
                    self.sessions[session_key] = restored
                    _save_session(self.email, sender_id, restored, self._local_key,
                                  peer_device_id=sender_device_id)
                    # NOTE(review): _process_x3dh_header stores and saves its new
                    # ratchet unconditionally, replacing the session restored just
                    # above. If the X3DH decrypt below also fails, the old working
                    # session is lost.
                    ratchet = self._process_x3dh_header(sender_id, x3dh_header,
                                                        sender_device_id=sender_device_id)
                    try:
                        plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
                    except Exception:
                        # Grace period: the sender may have used our previous SPK.
                        if self._prev_spk_private:
                            ratchet = self._process_x3dh_header(
                                sender_id, x3dh_header,
                                sender_device_id=sender_device_id,
                                spk_override=self._prev_spk_private)
                            plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
                        else:
                            raise
                    _save_session(self.email, sender_id, ratchet, self._local_key,
                                  peer_device_id=sender_device_id)
            else:
                # First message of a new session from this sender/device.
                ratchet = self._process_x3dh_header(sender_id, x3dh_header,
                                                    sender_device_id=sender_device_id)
                try:
                    plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
                except Exception:
                    # Grace period: the sender may have used our previous SPK.
                    if self._prev_spk_private:
                        ratchet = self._process_x3dh_header(
                            sender_id, x3dh_header,
                            sender_device_id=sender_device_id,
                            spk_override=self._prev_spk_private)
                        plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
                    else:
                        raise
                _save_session(self.email, sender_id, ratchet, self._local_key,
                              peer_device_id=sender_device_id)
        else:
            raise ValueError(f"No session for sender {sender_id}")

    payload = json.loads(plaintext)

    # Handle sender key distribution messages
    if "_sender_key" in payload:
        sk_data = payload["_sender_key"]
        sk_conv_id = sk_data["conv_id"]
        sk_key = decode_binary(sk_data["key"])
        sk_sender_device_id = sk_data.get("sender_device_id")
        recv_sk = SenderKeyState.from_key(sk_key)
        # Cache key mirrors the lookups in _decrypt_group: "conv:sender[:device]".
        if sk_sender_device_id:
            cache_key = f"{sk_conv_id}:{sender_id}:{sk_sender_device_id}"
        else:
            cache_key = f"{sk_conv_id}:{sender_id}"
        self.recv_sender_keys[cache_key] = recv_sk
        _save_recv_sender_key(self.email, sk_conv_id, sender_id, recv_sk, self._local_key,
                              sender_device_id=sk_sender_device_id)
        # Return empty — this is a control message, not user-visible
        return None

    return payload
|
|
|
|
def _decrypt_group(self, msg_data: dict) -> dict:
    """Decrypt group message using sender's Sender Key.

    The received Sender Key state is looked up in ``self.recv_sender_keys``
    and then on disk. The device-scoped key ``"conv:sender:device"`` is tried
    first, then the legacy ``"conv:sender"`` key. The advanced state is
    saved after decryption.

    Returns:
        The decoded JSON payload.

    Raises:
        ValueError: missing fields, no sender key for the sender, or the
            message was sent by us. Our own group messages are read through
            their self-encrypted copy, which ``_decrypt_message`` routes to
            ``_decrypt_dm``.
    """
    sender_id = msg_data.get("sender_id", "")
    sender_device_id = msg_data.get("sender_device_id")
    conv_id = msg_data.get("conversation_id", "")
    chain_id_b64 = msg_data.get("sender_chain_id", "")
    chain_n = msg_data.get("sender_chain_n", 0)
    ct_b64 = msg_data.get("encrypted_content", "")
    nonce_b64 = msg_data.get("nonce", "")

    if not ct_b64 or not nonce_b64 or not chain_id_b64:
        raise ValueError("Missing group message fields")

    ciphertext = decode_binary(ct_b64)
    nonce = decode_binary(nonce_b64)
    chain_id = decode_binary(chain_id_b64)

    my_user_id = self.session["user_id"]

    # If we sent this message, use our own sender key
    if sender_id == my_user_id:
        sk = self.sender_key_states.get(conv_id)
        if not sk:
            sk = _load_sender_key_state(self.email, conv_id, self._local_key)
            if sk:
                self.sender_key_states[conv_id] = sk
        if not sk:
            raise ValueError("Own sender key not found")
        # For our own messages, we can't decrypt from sender key (it's already advanced)
        # Return a placeholder — the server echoed our ciphertext
        # (In fact this always raises; the load above only chooses which error.)
        raise ValueError("Cannot decrypt own group message from sender key")

    # Use received sender key — try with sender_device_id first, fall back to without
    sk = None
    if sender_device_id:
        cache_key = f"{conv_id}:{sender_id}:{sender_device_id}"
        sk = self.recv_sender_keys.get(cache_key)
        if not sk:
            sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key,
                                       sender_device_id=sender_device_id)
            if sk:
                self.recv_sender_keys[cache_key] = sk

    if not sk:
        # Fallback: try without device_id (legacy or same-device)
        cache_key = f"{conv_id}:{sender_id}"
        sk = self.recv_sender_keys.get(cache_key)
        if not sk:
            sk = _load_recv_sender_key(self.email, conv_id, sender_id, self._local_key)
            if sk:
                self.recv_sender_keys[cache_key] = sk

    if not sk:
        raise ValueError(f"No sender key for {sender_id} in conversation {conv_id}")

    # SenderKeyState.decrypt takes the chain id as hex.
    plaintext = sk.decrypt(chain_id.hex(), chain_n, ciphertext, nonce)
    # NOTE(review): this always saves under sender_device_id, even when the key
    # came from the legacy device-less fallback above. Confirm this migration
    # is intended.
    _save_recv_sender_key(self.email, conv_id, sender_id, sk, self._local_key,
                          sender_device_id=sender_device_id)

    return json.loads(plaintext)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Get/decrypt messages (batch)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def get_messages(self, conv_id: str, limit: int = 50, offset: int = 0) -> list[dict]:
    """Fetch, decrypt and mark as read one page of messages from ``conv_id``.

    Messages are returned oldest-first; the server sends them newest-first.
    Already-decrypted payloads come from the local message cache, because
    ratchet message keys are one-time use. Freshly decrypted payloads are
    added to that cache without ``read_by``, which changes over time.
    Sender-key distribution (control) messages are cached as
    ``{"_control": True}`` and left out of the result. Deleted messages and
    messages that fail to decrypt become placeholder entries. Every fetched
    id, including skipped control messages, is then sent to ``mark_read``.

    Args:
        conv_id: Conversation to read.
        limit: Page size forwarded to the server.
        offset: Page offset forwarded to the server.

    Returns:
        Message dicts with at least ``message_id``, ``created_at`` and
        ``read_by``. Placeholders and freshly decrypted entries also carry
        ``sender_id``. Returns ``[]`` if the server request fails.
    """
    resp = await self.send_and_recv("get_messages", conversation_id=conv_id, limit=limit, offset=offset)
    if resp["status"] != "ok":
        return []

    cache = _load_message_cache(self.email, conv_id, self._cache_key)
    decrypted = []
    message_ids = []
    raw_messages = resp["data"]["messages"]
    raw_messages.reverse()  # Server returns DESC, reverse to ASC
    for m in raw_messages:
        msg_id = m["message_id"]
        message_ids.append(msg_id)
        sender_id = m.get("sender_id", "")

        if m.get("deleted_at"):
            decrypted.append({
                "message_id": msg_id,
                "sender": "",
                "text": "",
                "created_at": m["created_at"],
                "read_by": [],
                "sender_id": sender_id,
                "deleted": True,
            })
            continue

        # Check local cache first (ratchet keys are one-time use)
        cached = cache.get(msg_id)
        if cached:
            cached["read_by"] = m.get("read_by", [])
            cached["created_at"] = m["created_at"]
            if cached.get("_control"):
                continue  # Skip control messages
            decrypted.append(cached)
            continue

        try:
            msg_data = {
                "sender_id": sender_id,
                "sender_device_id": m.get("sender_device_id"),
                "conversation_id": conv_id,
                "ratchet_header": m.get("ratchet_header", {}),
                "encrypted_content": m.get("encrypted_content", ""),
                "nonce": m.get("nonce", ""),
                "x3dh_header": m.get("x3dh_header"),
                "sender_chain_id": m.get("sender_chain_id"),
                "sender_chain_n": m.get("sender_chain_n"),
            }
            payload = self._decrypt_message(msg_data)
            if payload is not None:
                payload["message_id"] = msg_id
                payload["created_at"] = m["created_at"]
                payload["read_by"] = m.get("read_by", [])
                payload["sender_id"] = sender_id
        except Exception as e:
            # Not cached, so a later fetch retries this message.
            self._logger.warning("Failed to decrypt message %s in %s: %s", msg_id, conv_id, e)
            decrypted.append({
                "message_id": msg_id,
                "sender": "???",
                "text": f"[Decryption failed: {e}]",
                "created_at": m["created_at"],
                "read_by": [],
                # Same attribution key as every other entry kind.
                "sender_id": sender_id,
            })
            continue

        if payload is None:
            # Control message (sender key distribution): cache the marker, don't display it.
            cache_entry = {"_control": True}
        else:
            decrypted.append(payload)
            # Cache the decrypted payload (without read_by which changes)
            cache_entry = {k: v for k, v in payload.items() if k != "read_by"}
        try:
            _save_message_to_cache(self.email, conv_id, msg_id, cache_entry,
                                   cache_key=self._cache_key)
        except Exception as e:
            # The message is already decrypted and its ratchet key spent. A cache
            # write failure must not also add a "[Decryption failed]" row.
            self._logger.warning("Failed to cache message %s in %s: %s", msg_id, conv_id, e)

    if message_ids:
        await self.mark_read(conv_id, message_ids)

    return decrypted
|
|
|
|
async def mark_read(self, conv_id: str, message_ids: list[str]):
    """Report ``message_ids`` in ``conv_id`` as read. An empty list sends nothing."""
    if message_ids:
        await self.send_and_recv("mark_read", conversation_id=conv_id, message_ids=message_ids)
|
|
|
|
def search_messages(self, conv_id: str, query: str) -> list[dict]:
    """Search the locally cached messages of *conv_id* for *query*.

    Matching is a case-insensitive substring test on each entry's ``text``.
    It uses ``str.casefold`` (full Unicode case folding, so e.g. "SS" matches
    "ß"). Deleted entries and internal bookkeeping entries (``_control``,
    ``_sender_key``) are skipped. An empty query matches every searchable
    entry. Only the local cache is consulted; nothing is fetched.

    Args:
        conv_id: conversation whose cache is searched.
        query: text to look for.

    Returns:
        Copies of the matching cache entries, each with ``message_id`` set.
        They are ordered by ``created_at``; entries lacking it sort first.
    """
    cache = _load_message_cache(self.email, conv_id, self._cache_key) or {}
    needle = (query or "").casefold()
    results = []
    for msg_id, payload in cache.items():
        if not isinstance(payload, dict):
            continue
        if payload.get("deleted") or payload.get("_control") or payload.get("_sender_key"):
            continue
        # ``or ""`` also covers an explicit ``None`` text (e.g. attachment-only messages).
        text = payload.get("text") or ""
        if not isinstance(text, str):
            continue
        if needle in text.casefold():
            entry = dict(payload)
            entry["message_id"] = msg_id
            results.append(entry)
    # A None created_at would make the str comparison raise; treat it as "".
    results.sort(key=lambda m: m.get("created_at") or "")
    return results
|
|
|
|
async def reset_session(self, peer_user_id: str, peer_device_id: str | None = None):
    """Forget the ratchet session with a peer and notify the peer to do the same.

    The in-memory session is keyed ``"<user>:<device>"`` when
    *peer_device_id* is given, otherwise by the bare user id. The matching
    on-disk session file is deleted as well. The ``session_reset`` request
    carries ``peer_device_id=""`` when no device id was supplied.
    """
    key = f"{peer_user_id}:{peer_device_id}" if peer_device_id else peer_user_id
    self.sessions.pop(key, None)
    _delete_session_file(self.email, peer_user_id, peer_device_id)
    await self.send_and_recv(
        "session_reset",
        peer_user_id=peer_user_id,
        peer_device_id=peer_device_id or "",
    )
|
|
|
|
def handle_session_reset_notification(self, from_user_id: str, from_device_id: str | None = None):
    """React to a peer's session-reset notice by dropping our matching session.

    The in-memory session is keyed ``"<user>:<device>"`` when
    *from_device_id* is given, otherwise by the bare user id. The
    corresponding on-disk session file is deleted too.
    """
    key = f"{from_user_id}:{from_device_id}" if from_device_id else from_user_id
    self.sessions.pop(key, None)
    _delete_session_file(self.email, from_user_id, from_device_id)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Decrypt notification
|
|
# ------------------------------------------------------------------
|
|
|
|
def decrypt_notification(self, notif_data: dict) -> dict | None:
|
|
"""Decrypt a new_message notification. Returns parsed payload or None.
|
|
|
|
Supports new multi-device format (device_entries array) and legacy flat format.
|
|
"""
|
|
try:
|
|
conv_id = notif_data.get("conversation_id", "")
|
|
msg_id = notif_data.get("message_id", "")
|
|
sender_id = notif_data.get("sender_id", "")
|
|
sender_device_id = notif_data.get("sender_device_id")
|
|
my_user_id = self.session["user_id"] if self.session else ""
|
|
|
|
# Extract per-device encrypted content from device_entries or flat fields
|
|
encrypted_content = ""
|
|
nonce = ""
|
|
ratchet_header = {}
|
|
x3dh_header = None
|
|
|
|
device_entries = notif_data.get("device_entries")
|
|
if device_entries:
|
|
# Multi-device format: pick entry matching our device_id or SELF_DEVICE_ID
|
|
chosen = None
|
|
self_entry = None
|
|
for entry in device_entries:
|
|
eid = entry.get("device_id", "")
|
|
if eid == self.device_id:
|
|
chosen = entry
|
|
break
|
|
if eid == "00000000-0000-0000-0000-000000000000":
|
|
self_entry = entry
|
|
|
|
# If sender is us, prefer self-encrypted entry
|
|
if sender_id == my_user_id:
|
|
chosen = self_entry or chosen
|
|
elif not chosen:
|
|
chosen = self_entry
|
|
|
|
if not chosen:
|
|
self._logger.warning("No matching device_entry for device %s", self.device_id)
|
|
return None
|
|
|
|
encrypted_content = chosen.get("encrypted_content", "")
|
|
nonce = chosen.get("nonce", "")
|
|
ratchet_header = chosen.get("ratchet_header") or notif_data.get("ratchet_header", {})
|
|
x3dh_header = chosen.get("x3dh_header") or notif_data.get("x3dh_header")
|
|
else:
|
|
# Legacy flat format
|
|
encrypted_content = notif_data.get("encrypted_content", "")
|
|
nonce = notif_data.get("nonce", "")
|
|
ratchet_header = notif_data.get("ratchet_header", {})
|
|
x3dh_header = notif_data.get("x3dh_header")
|
|
|
|
msg_data = {
|
|
"sender_id": sender_id,
|
|
"sender_device_id": sender_device_id,
|
|
"conversation_id": conv_id,
|
|
"ratchet_header": ratchet_header,
|
|
"encrypted_content": encrypted_content,
|
|
"nonce": nonce,
|
|
"x3dh_header": x3dh_header,
|
|
"sender_chain_id": notif_data.get("sender_chain_id"),
|
|
"sender_chain_n": notif_data.get("sender_chain_n"),
|
|
}
|
|
payload = self._decrypt_message(msg_data)
|
|
if payload is None:
|
|
# Cache control message so get_messages skips it
|
|
if msg_id and conv_id:
|
|
_save_message_to_cache(self.email, conv_id, msg_id, {"_control": True},
|
|
cache_key=self._cache_key)
|
|
return None
|
|
payload["conversation_id"] = conv_id
|
|
payload["message_id"] = msg_id
|
|
payload["sender_id"] = sender_id
|
|
payload["created_at"] = payload.get("timestamp", "")
|
|
payload["read_by"] = []
|
|
# Cache so get_messages doesn't re-decrypt (ratchet keys are one-time)
|
|
if msg_id and conv_id:
|
|
cache_entry = {k: v for k, v in payload.items() if k != "read_by"}
|
|
_save_message_to_cache(self.email, conv_id, msg_id, cache_entry,
|
|
cache_key=self._cache_key)
|
|
return payload
|
|
except Exception as e:
|
|
self._logger.warning("Failed to decrypt notification: %s", e)
|
|
return None
|
|
|
|
# ------------------------------------------------------------------
|
|
# Delete message
|
|
# ------------------------------------------------------------------
|
|
|
|
async def delete_message(self, message_id: str) -> tuple[bool, str]:
    """Ask the server to delete *message_id*.

    Returns:
        ``(True, "Message deleted.")`` on success, otherwise
        ``(False, <server error message>)``.
    """
    reply = await self.send_and_recv("delete_message", message_id=message_id)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "Message deleted."
|
|
|
|
# ------------------------------------------------------------------
|
|
# Image sharing
|
|
# ------------------------------------------------------------------
|
|
|
|
async def send_image(self, conv_id: str, image_path: str, members: list[dict],
                     reply_to: str | None = None) -> tuple[bool, str]:
    """Re-encode, encrypt and upload an image, then send it as a message.

    Processing steps:

    1. The image is opened with Pillow and its EXIF orientation is applied.
    2. It is converted to RGB (unless already RGB or L) and downscaled so its
       longest side is at most 1920 px.
    3. It is re-encoded as JPEG (quality 85). A 200x200-bounded JPEG
       thumbnail (quality 60) is embedded in the message payload.
    4. The full JPEG is encrypted with a per-file key from ``aes_encrypt``
       and uploaded in chunks.
    5. The key and IV travel inside the end-to-end encrypted message payload.

    Args:
        conv_id: target conversation.
        image_path: path of the image on disk.
        members: conversation member dicts (each with a ``user_id``).
        reply_to: optional id of the message being replied to.

    Returns:
        ``(True, "Image sent.")`` on success, otherwise ``(False, reason)``.
    """
    try:
        from PIL import Image, ImageOps
        import io
    except ImportError:
        return False, "Pillow is required for image sharing. Install with: pip install Pillow"

    path = Path(image_path)
    if not path.exists():
        return False, "File not found."

    try:
        # The context manager closes the file handle; load() alone leaves it
        # open for multi-frame images.
        with Image.open(path) as src:
            src.load()
            try:
                # Re-encoding to JPEG below drops EXIF, so bake the orientation
                # into the pixels or phone photos show up rotated.
                img = ImageOps.exif_transpose(src)
            except Exception:
                img = src.copy()
    except Exception as e:
        return False, f"Cannot open image: {e}"

    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")

    max_dim = 1920
    if max(img.size) > max_dim:
        img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)

    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    image_bytes = buf.getvalue()

    thumb = img.copy()
    thumb.thumbnail((200, 200), Image.Resampling.LANCZOS)
    thumb_buf = io.BytesIO()
    thumb.save(thumb_buf, format="JPEG", quality=60)
    thumbnail_b64 = encode_binary(thumb_buf.getvalue())

    # Encrypt image with AES-256-GCM; the blob layout is ciphertext + tag.
    img_aes_key, img_iv, img_ct, img_tag = aes_encrypt(image_bytes)
    file_id = str(uuid.uuid4())
    ok, err = await self._upload_attachment_blob(conv_id, file_id, img_ct + img_tag)
    if not ok:
        return False, err

    payload = {
        "sender": self.username,
        "text": "",
        "reply_to": reply_to,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "image": {
            "file_id": file_id,
            "aes_key": encode_binary(img_aes_key),
            "iv": encode_binary(img_iv),
            "thumbnail": thumbnail_b64,
            "filename": path.name,
            "size": len(image_bytes),
        },
    }
    return await self._send_attachment_message(conv_id, members, payload, file_id, "Image sent.")


async def send_file(self, conv_id: str, file_path: str, members: list[dict],
                    reply_to: str | None = None) -> tuple[bool, str]:
    """Encrypt and upload an arbitrary file, then send it as a message.

    The file is read fully into memory and encrypted with a per-file key
    from ``aes_encrypt``. It is uploaded through the same chunked endpoint as
    images, tagged ``file_type="file"``. The key, IV, name, size and guessed
    MIME type travel inside the end-to-end encrypted message payload.

    Args:
        conv_id: target conversation.
        file_path: path of the file on disk.
        members: conversation member dicts (each with a ``user_id``).
        reply_to: optional id of the message being replied to.

    Returns:
        ``(True, "File sent.")`` on success, otherwise ``(False, reason)``.
    """
    import mimetypes

    path = Path(file_path)
    if not path.exists():
        return False, "File not found."

    try:
        file_bytes = path.read_bytes()
    except Exception as e:
        return False, f"Cannot read file: {e}"

    mime_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"

    # Encrypt file with AES-256-GCM; the blob layout is ciphertext + tag.
    file_aes_key, file_iv, file_ct, file_tag = aes_encrypt(file_bytes)
    file_id = str(uuid.uuid4())
    ok, err = await self._upload_attachment_blob(conv_id, file_id, file_ct + file_tag,
                                                 file_type="file")
    if not ok:
        return False, err

    payload = {
        "sender": self.username,
        "text": "",
        "reply_to": reply_to,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "file": {
            "file_id": file_id,
            "aes_key": encode_binary(file_aes_key),
            "iv": encode_binary(file_iv),
            "filename": path.name,
            "size": len(file_bytes),
            "mime_type": mime_type,
        },
    }
    return await self._send_attachment_message(conv_id, members, payload, file_id, "File sent.")


async def _upload_attachment_blob(self, conv_id: str, file_id: str, blob: bytes,
                                  file_type: str | None = None) -> tuple[bool, str]:
    """Upload an already-encrypted *blob* through the chunked upload requests.

    The request sequence is:

    1. ``upload_image_start``, carrying ``file_type`` only when given;
    2. ``upload_image_chunk`` requests of at most ``IMAGE_CHUNK_SIZE`` bytes;
    3. ``upload_image_end``.

    Returns:
        ``(True, "")`` on success. Otherwise ``(False, message)`` with the
        server error message of the first failing request.
    """
    start_args = {"conversation_id": conv_id, "file_id": file_id, "file_size": len(blob)}
    if file_type is not None:
        start_args["file_type"] = file_type
    resp = await self.send_and_recv("upload_image_start", **start_args)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]

    for offset in range(0, len(blob), IMAGE_CHUNK_SIZE):
        resp = await self.send_and_recv(
            "upload_image_chunk",
            file_id=file_id,
            data=encode_binary(blob[offset:offset + IMAGE_CHUNK_SIZE]),
        )
        if resp["status"] != "ok":
            return False, resp["data"]["message"]

    resp = await self.send_and_recv("upload_image_end", file_id=file_id)
    if resp["status"] != "ok":
        return False, resp["data"]["message"]
    return True, ""


async def _send_attachment_message(self, conv_id: str, members: list[dict], payload: dict,
                                   file_id: str, success_msg: str) -> tuple[bool, str]:
    """Encrypt *payload* for every recipient and post it as a message referencing *file_id*.

    Encryption depends on the conversation type:

    * Groups (per ``_is_group``) use this conversation's sender key.
    * DMs use a Double Ratchet session per peer device.

    A self-encrypted copy is always appended for the sender.

    Returns:
        ``(True, success_msg)`` on success, else ``(False, <server error message>)``.
    """
    plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    my_user_id = self.session["user_id"]

    if self._is_group(members):
        sk = await self._attachment_sender_key(conv_id, members)
        result = sk.encrypt(plaintext)
        _save_sender_key_state(self.email, conv_id, sk, self._local_key)

        # The sender-key ciphertext is the same for every member; encode it once.
        group_ct = encode_binary(result["ciphertext"])
        group_nonce = encode_binary(result["nonce"])
        recipients = [
            {"user_id": uid, "encrypted_content": group_ct, "nonce": group_nonce}
            for uid in (member.get("user_id") for member in members)
            if uid and uid != my_user_id
        ]
        recipients.append(self._attachment_self_recipient(plaintext, my_user_id))

        resp = await self.send_and_recv(
            "send_message",
            conversation_id=conv_id,
            # All-zero placeholder; the sender_chain_* fields identify the sender-key message.
            ratchet_header={"dh_pub": "00" * 32, "n": 0, "pn": 0},
            recipients=recipients,
            sender_chain_id=encode_binary(bytes.fromhex(result["chain_id"])),
            sender_chain_n=result["n"],
            image_file_id=file_id,
        )
    else:
        recipients, first_rh = await self._attachment_ratchet_recipients(
            plaintext, members, my_user_id)
        recipients.append(self._attachment_self_recipient(plaintext, my_user_id))

        resp = await self.send_and_recv(
            "send_message",
            conversation_id=conv_id,
            ratchet_header=first_rh,
            recipients=recipients,
            image_file_id=file_id,
        )

    if resp["status"] == "ok":
        return True, success_msg
    return False, resp["data"]["message"]


async def _attachment_sender_key(self, conv_id: str, members: list[dict]):
    """Return this conversation's sender key, loading or creating it as needed.

    Lookup order:

    1. the in-memory ``sender_key_states`` entry;
    2. the on-disk state, which is cached in memory once loaded;
    3. a brand-new key, which is saved and distributed to *members*.
    """
    sk = self.sender_key_states.get(conv_id)
    if not sk:
        sk = _load_sender_key_state(self.email, conv_id, self._local_key)
        if sk:
            self.sender_key_states[conv_id] = sk
    if not sk:
        sk = SenderKeyState()
        self.sender_key_states[conv_id] = sk
        _save_sender_key_state(self.email, conv_id, sk, self._local_key)
        await self._distribute_sender_key(conv_id, members, sk)
    return sk


async def _attachment_ratchet_recipients(self, plaintext: bytes, members: list[dict],
                                         my_user_id: str) -> tuple[list[dict], dict | None]:
    """Encrypt *plaintext* with a Double Ratchet session per peer device.

    Each bundle returned by ``_get_device_bundles`` gets its own entry. When
    none are available, the legacy device-less session is used instead.
    Every session used is saved after encrypting.

    Returns:
        ``(recipients, first_ratchet_header)``. The header is ``None`` when
        no peer entry was produced.
    """
    recipients = []
    first_rh = None
    for member in members:
        uid = member.get("user_id")
        if not uid or uid == my_user_id:
            continue

        try:
            device_bundles = await self._get_device_bundles(uid)
        except Exception as e:
            # Fall back to the legacy session, but leave a trace of why.
            self._logger.warning("Device bundle lookup failed for %s: %s", uid, e)
            device_bundles = []

        # ``None`` stands for the legacy single-device session.
        for bundle in device_bundles or [None]:
            if bundle is None:
                dev_id = None
                ratchet = await self._get_or_create_session(uid)
            else:
                dev_id = bundle.get("device_id")
                ratchet = await self._get_or_create_session(uid, peer_device_id=dev_id,
                                                            bundle=bundle)
            result = ratchet.encrypt(plaintext)

            # A session may carry a pending X3DH header (presumably when newly
            # created); attach it once and drop it so it is not resent.
            x3dh_h = getattr(ratchet, "_x3dh_header", None)
            if x3dh_h:
                delattr(ratchet, "_x3dh_header")

            entry = {
                "user_id": uid,
                "encrypted_content": encode_binary(result["ciphertext"]),
                "nonce": encode_binary(result["nonce"]),
                "ratchet_header": result["header"],
            }
            if dev_id:
                entry["device_id"] = dev_id
            if x3dh_h:
                entry["x3dh_header"] = x3dh_h
            recipients.append(entry)
            if first_rh is None:
                first_rh = result["header"]

            if bundle is None:
                _save_session(self.email, uid, ratchet, self._local_key)
            else:
                _save_session(self.email, uid, ratchet, self._local_key,
                              peer_device_id=dev_id)
    return recipients, first_rh


def _attachment_self_recipient(self, plaintext: bytes, my_user_id: str) -> dict:
    """Build the recipient entry readable by the sender's own devices.

    The copy is encrypted with the static key derived from the identity key
    and marked with ``ratchet_header={"self": True}``.
    """
    self_key = derive_self_encryption_key(self.identity_private)
    _, self_nonce, self_ct, self_tag = aes_encrypt(plaintext, key=self_key)
    return {
        "user_id": my_user_id,
        "encrypted_content": encode_binary(self_ct + self_tag),
        "nonce": encode_binary(self_nonce),
        "ratchet_header": {"self": True},
    }
|
|
|
|
async def download_file(self, file_id: str, file_info: dict) -> bytes | None:
    """Download and decrypt a file attachment.

    Args:
        file_id: server id of the uploaded blob.
        file_info: the ``file`` dict from the message payload (needs
            ``aes_key`` and ``iv``).

    Returns:
        The decrypted file bytes, or ``None`` on any failure.
    """
    return await self._download_attachment(file_id, file_info)


async def download_image(self, file_id: str, image_info: dict) -> bytes | None:
    """Download and decrypt an image attachment.

    Args:
        file_id: server id of the uploaded blob.
        image_info: the ``image`` dict from the message payload (needs
            ``aes_key`` and ``iv``).

    Returns:
        The decrypted image bytes, or ``None`` on any failure.
    """
    return await self._download_attachment(file_id, image_info)


async def _download_attachment(self, file_id: str, info: dict) -> bytes | None:
    """Fetch the encrypted blob *file_id* chunk by chunk and decrypt it.

    The blob is ciphertext followed by a 16-byte GCM tag, matching the
    ``ct + tag`` layout produced on upload. The key and IV are the
    binary-encoded ``aes_key`` and ``iv`` fields of *info*.

    Returns:
        The decrypted bytes, or ``None`` when:

        * any download request fails;
        * the blob is shorter than a tag;
        * the server stops making progress without signalling ``done``;
        * decryption fails.
    """
    chunks = []
    offset = 0
    while True:
        resp = await self.send_and_recv("download_image", file_id=file_id, offset=offset)
        if resp["status"] != "ok":
            return None
        data = resp["data"]
        chunk = decode_binary(data["data"])
        chunks.append(chunk)
        offset += len(chunk)
        if data.get("done"):
            break
        if not chunk:
            # No progress and no completion flag: requesting the same offset
            # again would loop forever.
            return None

    encrypted_data = b"".join(chunks)
    if len(encrypted_data) < 16:
        return None
    ciphertext, tag = encrypted_data[:-16], encrypted_data[-16:]

    try:
        aes_key = decode_binary(info["aes_key"])
        iv = decode_binary(info["iv"])
        return aes_decrypt(aes_key, iv, ciphertext, tag)
    except Exception:
        return None
|
|
|
|
# ------------------------------------------------------------------
|
|
# Re-encrypt history (for device pairing)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def reencrypt_history(self):
    """Re-encrypt all cached messages with the self-encryption key.

    After device pairing, the new device shares the same identity key but
    cannot decrypt old messages, because Double Ratchet keys are one-time
    use. This method makes those messages readable again in three phases:

    1. Page through every conversation with ``get_messages`` so the local
       cache holds messages the old device never opened.
    2. Re-encrypt each cached message with the key from
       ``derive_self_encryption_key``.
    3. Upload the results in batches via ``reencrypt_messages``.

    Progress is reported through ``self._reencrypt_progress_cb`` when it is
    set. Does nothing without an identity key and an active session.
    """
    if not self.identity_private or not self.session:
        return

    self_key = derive_self_encryption_key(self.identity_private)

    # Phase 1: populate the cache.
    await self._reencrypt_prefetch_history()

    # Phase 2: read the cache and re-encrypt.
    cache_dir = get_key_dir(self.email) / "message_cache"
    if not cache_dir.exists():
        self._logger.info("No message cache to re-encrypt.")
        return

    updates = self._reencrypt_collect_updates(cache_dir, self_key)
    if not updates:
        self._logger.info("No messages to re-encrypt.")
        return

    # Phase 3: upload.
    uploaded = await self._reencrypt_upload(updates)
    if self._reencrypt_progress_cb:
        self._reencrypt_progress_cb(f"Re-encryption complete: {uploaded} messages uploaded.")


async def _reencrypt_prefetch_history(self, page_size: int = 200):
    """Fetch every conversation's full history so it lands in the local cache.

    Errors are logged per conversation, so one failing conversation does not
    stop the others from being fetched.
    """
    try:
        convs = await self.list_conversations()
    except Exception as e:
        self._logger.warning("Failed to fetch messages for re-encryption: %s", e)
        return

    total_convs = len(convs)
    for ci, conv in enumerate(convs):
        cid = conv.get("id") or conv.get("conversation_id")
        if not cid:
            continue
        if self._reencrypt_progress_cb:
            self._reencrypt_progress_cb(
                f"Fetching messages: {ci + 1}/{total_convs} conversations..."
            )
        try:
            offset = 0
            seen_ids = set()
            while True:
                msgs = await self.get_messages(cid, limit=page_size, offset=offset)
                # get_messages omits control messages from its result, so a
                # short page is not the end of history. Advance by the
                # requested page size and stop on an empty page. Also stop on
                # a page we've already seen, in case the server ignores offset.
                if not msgs:
                    break
                page_ids = {m.get("message_id") for m in msgs}
                if page_ids <= seen_ids:
                    break
                seen_ids |= page_ids
                offset += page_size
        except Exception as e:
            self._logger.warning("Failed to fetch messages for re-encryption (conversation %s): %s",
                                 cid, e)


def _reencrypt_collect_updates(self, cache_dir: Path, self_key: bytes) -> list[dict]:
    """Build ``reencrypt_messages`` updates from every cached conversation.

    These entries are skipped:

    * control entries (``_control``) and sender-key entries (``_sender_key``);
    * deleted entries;
    * entries with no text, image or file.

    Per-message metadata is stripped before the remaining payload is
    re-encrypted.
    """
    metadata_keys = ("message_id", "created_at", "read_by", "sender_id", "deleted")
    conv_ids = sorted({f.stem for f in cache_dir.iterdir() if f.suffix in (".json", ".bin")})
    total_files = len(conv_ids)

    updates = []
    for i, conv_id in enumerate(conv_ids):
        cache = _load_message_cache(self.email, conv_id, self._cache_key)
        if not cache:
            continue

        for msg_id, entry in cache.items():
            if not isinstance(entry, dict):
                continue
            # Bookkeeping entries and deleted messages have nothing to restore;
            # re-uploading a deleted message would put its content back.
            if entry.get("_control") or entry.get("_sender_key") or entry.get("deleted"):
                continue
            if not entry.get("text", "") and not entry.get("image") and not entry.get("file"):
                continue

            payload = {k: v for k, v in entry.items() if k not in metadata_keys}
            plaintext = json.dumps(payload, ensure_ascii=False).encode("utf-8")
            _, nonce, ct, tag = aes_encrypt(plaintext, key=self_key)
            updates.append({
                "message_id": msg_id,
                "encrypted_content": encode_binary(ct + tag),
                "nonce": encode_binary(nonce),
            })

        if self._reencrypt_progress_cb:
            self._reencrypt_progress_cb(f"Re-encrypting history: {i + 1}/{total_files} conversations...")
    return updates


async def _reencrypt_upload(self, updates: list[dict], batch_size: int = 500) -> int:
    """Send *updates* in batches; return how many were accepted by the server."""
    total = len(updates)
    uploaded = 0
    for start in range(0, total, batch_size):
        batch = updates[start:start + batch_size]
        resp = await self.send_and_recv("reencrypt_messages", updates=batch)
        if resp["status"] != "ok":
            self._logger.warning("Re-encrypt batch failed: %s", resp.get("data", {}).get("message", ""))
            continue
        uploaded += len(batch)
        self._logger.info("Re-encrypted %d/%d messages.", min(start + batch_size, total), total)
    return uploaded
|
|
|
|
# ------------------------------------------------------------------
|
|
# User Profiles
|
|
# ------------------------------------------------------------------
|
|
|
|
async def get_profile(self, user_id: str | None = None) -> dict | None:
|
|
"""Get user profile. If user_id is None, returns own profile."""
|
|
kwargs = {}
|
|
if user_id:
|
|
kwargs["user_id"] = user_id
|
|
resp = await self.send_and_recv("get_profile", **kwargs)
|
|
if resp["status"] == "ok":
|
|
return resp["data"]
|
|
return None
|
|
|
|
async def update_profile(self, **fields) -> tuple[bool, str]:
    """Update the caller's own profile.

    *fields* are forwarded verbatim as request arguments (e.g. phone,
    location and their ``*_visible`` flags).

    Returns:
        ``(True, "OK")`` on success, else ``(False, <server error message>)``.
    """
    reply = await self.send_and_recv("update_profile", **fields)
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, "OK"
|
|
|
|
async def update_avatar(self, image_data: bytes) -> tuple[bool, str]:
    """Upload *image_data* as the caller's avatar.

    Returns:
        ``(True, avatar_file)`` on success, else ``(False, <server error message>)``.
        ``avatar_file`` is the value the server reports, or ``""`` if absent.
    """
    reply = await self.send_and_recv("update_avatar", data=encode_binary(image_data))
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, reply["data"].get("avatar_file", "")
|
|
|
|
async def get_avatar(self, user_id: str) -> bytes | None:
    """Download *user_id*'s avatar.

    Returns:
        The decoded image bytes on success, otherwise ``None``.
    """
    reply = await self.send_and_recv("get_avatar", user_id=user_id)
    if reply["status"] != "ok":
        return None
    return decode_binary(reply["data"]["data"])
|
|
|
|
async def update_group_avatar(self, conv_id: str, image_data: bytes) -> tuple[bool, str]:
    """Upload *image_data* as the avatar of group conversation *conv_id*.

    Returns:
        ``(True, avatar_file)`` on success, else ``(False, <server error message>)``.
        ``avatar_file`` is the value the server reports, or ``""`` if absent.
    """
    reply = await self.send_and_recv(
        "update_group_avatar",
        conversation_id=conv_id,
        data=encode_binary(image_data),
    )
    if reply["status"] != "ok":
        return False, reply["data"]["message"]
    return True, reply["data"].get("avatar_file", "")
|
|
|
|
async def get_group_avatar(self, conv_id: str) -> bytes | None:
    """Download the avatar of group conversation *conv_id*.

    Returns:
        The decoded image bytes on success, otherwise ``None``.
    """
    reply = await self.send_and_recv("get_group_avatar", conversation_id=conv_id)
    if reply["status"] != "ok":
        return None
    return decode_binary(reply["data"]["data"])
|
|
|
|
# ------------------------------------------------------------------
|
|
# Cleanup
|
|
# ------------------------------------------------------------------
|
|
|
|
async def close(self):
    """Mark the client disconnected and release its network resources.

    Cancels the background listener task (if any) and closes the raw stream
    writer (if any). Neither is awaited, so this method has no await point.
    """
    self.connected = False
    listener = self._listener_task
    if listener:
        listener.cancel()
    writer = self.raw_writer
    if writer:
        writer.close()
|
|
|
|
async def reconnect(self):
|
|
"""Close existing connection and re-establish: connect + re-login using in-memory keys."""
|
|
try:
|
|
await self.close()
|
|
except Exception:
|
|
pass
|
|
# Reset reader/writer but keep keys and sessions
|
|
self.reader = None
|
|
self.writer = None
|
|
self.raw_writer = None
|
|
self._listener_task = None
|
|
self._pending.clear()
|
|
self.login_rejected = False
|
|
# Drain queues
|
|
while not self._response_queue.empty():
|
|
try:
|
|
self._response_queue.get_nowait()
|
|
except Exception:
|
|
break
|
|
while not self._notification_queue.empty():
|
|
try:
|
|
self._notification_queue.get_nowait()
|
|
except Exception:
|
|
break
|
|
await self.connect()
|
|
self._listener_task = asyncio.create_task(self._background_listener())
|
|
if self.email and self.private_key:
|
|
# RSA challenge-response login (keys already in memory)
|
|
start = await self.send_and_recv("login_start", email=self.email)
|
|
if start["status"] == "ok":
|
|
challenge = decode_binary(start["data"]["challenge"])
|
|
signature = rsa_sign(self.private_key, challenge)
|
|
login_kwargs = {
|
|
"email": self.email,
|
|
"signature": encode_binary(signature),
|
|
"client_version": VERSION,
|
|
}
|
|
if self.device_id:
|
|
login_kwargs["device_id"] = self.device_id
|
|
finish = await self.send_and_recv("login_finish", **login_kwargs)
|
|
if finish["status"] == "ok":
|
|
self.session = finish["data"]
|
|
asyncio.create_task(self._ensure_prekeys())
|
|
else:
|
|
# Login rejected — keys were likely rotated on another device
|
|
self.session = None
|
|
self.connected = False
|
|
self.login_rejected = True
|