Client core hardening: X3DH session adoption, sync watermark, PoW, key rotation

- Adopt a new X3DH session (install into self.sessions + persist) only
  after the first message decrypts successfully. Previously
  _process_x3dh_header saved the candidate session immediately, so a
  replayed/forged X3DH header permanently overwrote a working ratchet.
- Advance the incremental-sync watermark (__last_server_ts) only across
  the prefix of messages settled in the cache. An undecryptable message
  (e.g. sender key not yet received) is re-fetched and retried up to
  _MAX_DECRYPT_RETRIES=3 times instead of being silently lost forever.
  Watermark is no longer touched on offset>0 pages and never regresses.
- Fix NameError in the proof-of-work registration path (logger ->
  self._logger) and run _solve_pow in an executor so it does not block
  the event loop.
- Persist the rotated RSA login key only after the server confirmed
  rotate_keys; writing private.pem first bricked the account when the
  request failed.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
filip
2026-06-12 16:08:31 +02:00
parent f0666ea6ac
commit d499fd8436

View File

@@ -751,6 +751,11 @@ def _load_verified_contacts(email: str, local_key: bytes | None = None) -> dict:
return {} return {}
# How many sync cycles a message that fails to decrypt is retried before it
# is recorded as permanently failed and the sync watermark moves past it.
_MAX_DECRYPT_RETRIES = 3
def _solve_pow(challenge: str, difficulty: int) -> str: def _solve_pow(challenge: str, difficulty: int) -> str:
"""Solve a proof-of-work challenge by finding a nonce with enough leading zero bits.""" """Solve a proof-of-work challenge by finding a nonce with enough leading zero bits."""
target_bytes = difficulty // 8 target_bytes = difficulty // 8
@@ -1233,8 +1238,9 @@ class ChatClient:
challenge = start["data"]["challenge"] challenge = start["data"]["challenge"]
mac = start["data"]["mac"] mac = start["data"]["mac"]
difficulty = start["data"]["difficulty"] difficulty = start["data"]["difficulty"]
logger.info("Server requires proof-of-work (difficulty %d), solving...", difficulty) self._logger.info("Server requires proof-of-work (difficulty %d), solving...", difficulty)
nonce = _solve_pow(challenge, difficulty) nonce = await asyncio.get_running_loop().run_in_executor(
None, _solve_pow, challenge, difficulty)
extra_fields = {"pow_challenge": challenge, "pow_mac": mac, "pow_nonce": nonce} extra_fields = {"pow_challenge": challenge, "pow_mac": mac, "pow_nonce": nonce}
start = await self.send_and_recv( start = await self.send_and_recv(
"register", "register",
@@ -1700,12 +1706,14 @@ class ChatClient:
return False, "Not logged in." return False, "Not logged in."
pwd_bytes = password.encode("utf-8") if password else None pwd_bytes = password.encode("utf-8") if password else None
priv, pub = generate_rsa_keypair() priv, pub = generate_rsa_keypair()
save_keys(self.email, priv, pub, password=pwd_bytes)
self.private_key = priv
self.public_key = pub
pub_pem = serialize_public_key(pub).decode("utf-8") pub_pem = serialize_public_key(pub).decode("utf-8")
# Persist the new key only after the server accepted it — overwriting
# private.pem first would brick the account if rotation fails.
resp = await self.send_and_recv("rotate_keys", public_key=pub_pem) resp = await self.send_and_recv("rotate_keys", public_key=pub_pem)
if resp["status"] == "ok": if resp["status"] == "ok":
save_keys(self.email, priv, pub, password=pwd_bytes)
self.private_key = priv
self.public_key = pub
return True, "RSA login keys rotated." return True, "RSA login keys rotated."
return False, resp["data"]["message"] return False, resp["data"]["message"]
@@ -1872,10 +1880,10 @@ class ChatClient:
ratchet._pending_opk_delete = opk_id if opk_priv else None ratchet._pending_opk_delete = opk_id if opk_priv else None
session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id # NOTE: the ratchet is intentionally NOT installed into self.sessions
self.sessions[session_key] = ratchet # nor saved to disk here. The caller does that only after the first
_save_session(self.email, sender_id, ratchet, self._local_key, # message decrypts successfully — otherwise a failed/forged X3DH
peer_device_id=sender_device_id) # header would overwrite a working session.
self._user_cache[sender_id] = { self._user_cache[sender_id] = {
"user_id": sender_id, "user_id": sender_id,
@@ -2463,6 +2471,9 @@ class ChatClient:
_save_session(self.email, sender_id, ratchet, self._local_key, _save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id) peer_device_id=sender_device_id)
except Exception: except Exception:
# Restore the known-good session before attempting a
# fresh X3DH; if the X3DH path fails too, this restored
# session stays installed (in memory and on disk).
restored = DoubleRatchet.import_state(backup) restored = DoubleRatchet.import_state(backup)
self.sessions[session_key] = restored self.sessions[session_key] = restored
_save_session(self.email, sender_id, restored, self._local_key, _save_session(self.email, sender_id, restored, self._local_key,
@@ -2480,6 +2491,8 @@ class ChatClient:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
else: else:
raise raise
# First decrypt succeeded — only now adopt the new session
self.sessions[session_key] = ratchet
_save_session(self.email, sender_id, ratchet, self._local_key, _save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id) peer_device_id=sender_device_id)
else: else:
@@ -2496,6 +2509,8 @@ class ChatClient:
plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce) plaintext = ratchet.decrypt(ratchet_header, ciphertext, nonce)
else: else:
raise raise
# First decrypt succeeded — install + persist the session
self.sessions[session_key] = ratchet
_save_session(self.email, sender_id, ratchet, self._local_key, _save_session(self.email, sender_id, ratchet, self._local_key,
peer_device_id=sender_device_id) peer_device_id=sender_device_id)
else: else:
@@ -2613,18 +2628,31 @@ class ChatClient:
raw_messages = resp["data"]["messages"] raw_messages = resp["data"]["messages"]
raw_messages.reverse() # Server returns DESC, reverse to ASC raw_messages.reverse() # Server returns DESC, reverse to ASC
# Save latest server timestamp for next incremental sync # Decrypt new messages from server
if raw_messages: new_decrypted = self._decrypt_raw_messages(raw_messages, cache, conv_id, my_user_id)
# raw_messages are now ASC; last one is newest
newest_ts = raw_messages[-1].get("created_at", "") # Advance the incremental-sync watermark only across the prefix of
if newest_ts: # messages that are settled in the cache (decrypted, control, deleted,
# or failed too many times). Stopping at the first unsettled message
# means a transiently undecryptable message (e.g. sender key not yet
# received) is re-fetched and retried on the next sync instead of
# being skipped forever.
if raw_messages and offset == 0:
newest_ts = ""
for m in raw_messages:
entry = cache.get(m["message_id"])
if entry is None:
break
fails = entry.get("_decrypt_failed", 0)
if fails and fails < _MAX_DECRYPT_RETRIES:
break
newest_ts = m.get("created_at", "") or newest_ts
prev_ts = cache.get("__last_server_ts", {}).get("ts", "")
if newest_ts and newest_ts > prev_ts:
cache["__last_server_ts"] = {"ts": newest_ts} cache["__last_server_ts"] = {"ts": newest_ts}
_save_message_to_cache(self.email, conv_id, "__last_server_ts", _save_message_to_cache(self.email, conv_id, "__last_server_ts",
{"ts": newest_ts}, cache_key=self._cache_key) {"ts": newest_ts}, cache_key=self._cache_key)
# Decrypt new messages from server
new_decrypted = self._decrypt_raw_messages(raw_messages, cache, conv_id, my_user_id)
# All non-critical ops fire-and-forget to avoid blocking message display # All non-critical ops fire-and-forget to avoid blocking message display
# Confirm delivery for messages from others # Confirm delivery for messages from others
deliver_ids = [m["message_id"] for m in new_decrypted deliver_ids = [m["message_id"] for m in new_decrypted
@@ -2674,6 +2702,16 @@ class ChatClient:
for msg_id, p in cache.items(): for msg_id, p in cache.items():
if p.get("_control") or msg_id.startswith("__"): if p.get("_control") or msg_id.startswith("__"):
continue continue
if p.get("_decrypt_failed"):
messages.append({
"message_id": msg_id,
"sender": "???",
"text": "[Decryption failed]",
"created_at": p.get("created_at", ""),
"read_by": [],
"delivered_to": [],
})
continue
entry = dict(p) entry = dict(p)
entry.setdefault("message_id", msg_id) entry.setdefault("message_id", msg_id)
entry.setdefault("read_by", []) entry.setdefault("read_by", [])
@@ -2704,6 +2742,18 @@ class ChatClient:
# Check local cache first (ratchet keys are one-time use) # Check local cache first (ratchet keys are one-time use)
cached = cache.get(msg_id) cached = cache.get(msg_id)
if cached and cached.get("_decrypt_failed"):
if cached["_decrypt_failed"] >= _MAX_DECRYPT_RETRIES:
decrypted.append({
"message_id": msg_id,
"sender": "???",
"text": "[Decryption failed]",
"created_at": m["created_at"],
"read_by": [],
"sender_id": m.get("sender_id", ""),
})
continue
cached = None # retry decryption below
if cached and not cached.get("_control"): if cached and not cached.get("_control"):
cached["read_by"] = m.get("read_by", []) cached["read_by"] = m.get("read_by", [])
cached["delivered_to"] = m.get("delivered_to", []) cached["delivered_to"] = m.get("delivered_to", [])
@@ -2761,12 +2811,20 @@ class ChatClient:
"delivered_to", "sender_id", "deleted")}, "delivered_to", "sender_id", "deleted")},
}) })
except Exception as e: except Exception as e:
# Record the failure (with retry count) so the sync watermark
# stops here and the message is retried on the next fetch.
fails = (cache.get(msg_id) or {}).get("_decrypt_failed", 0) + 1
fail_entry = {"_decrypt_failed": fails, "created_at": m["created_at"]}
cache[msg_id] = fail_entry
_save_message_to_cache(self.email, conv_id, msg_id, fail_entry,
cache_key=self._cache_key)
decrypted.append({ decrypted.append({
"message_id": msg_id, "message_id": msg_id,
"sender": "???", "sender": "???",
"text": f"[Decryption failed: {e}]", "text": f"[Decryption failed: {e}]",
"created_at": m["created_at"], "created_at": m["created_at"],
"read_by": [], "read_by": [],
"sender_id": m.get("sender_id", ""),
}) })
return decrypted return decrypted