diff --git a/chat_core.py b/chat_core.py index 59d7260..94bab3c 100644 --- a/chat_core.py +++ b/chat_core.py @@ -916,7 +916,7 @@ class ChatClient: "session_reset", "message_reacted", "message_pinned", "message_unpinned", "message_delivered", "username_changed", - "avatar_changed", + "avatar_changed", "keys_updated", "typing_start", "typing_stop"): await self._notification_queue.put(msg) else: @@ -1852,8 +1852,10 @@ class ChatClient: opk_priv = None if opk_id: opk_priv = _load_opk_private(self.email, opk_id, self._local_key) - if opk_priv: - _delete_opk_private(self.email, opk_id) + # Deletion is deferred until the first message decrypts successfully + # (_consume_pending_opk). Deleting here would break the SPK + # grace-period retry: the second _process_x3dh_header call could no + # longer load the OPK and the message would be lost permanently. spk_priv = spk_override if spk_override else self.spk_private @@ -1868,6 +1870,8 @@ class ChatClient: spk_pub = spk_priv.public_key() if hasattr(spk_priv, 'public_key') else None ratchet = DoubleRatchet.init_bob(shared_secret, (spk_priv, spk_pub)) + ratchet._pending_opk_delete = opk_id if opk_priv else None + session_key = f"{sender_id}:{sender_device_id}" if sender_device_id else sender_id self.sessions[session_key] = ratchet _save_session(self.email, sender_id, ratchet, self._local_key, @@ -1882,6 +1886,18 @@ class ChatClient: return ratchet + def _consume_pending_opk(self, ratchet) -> None: + """Delete the one-time prekey consumed by an X3DH handshake. + + Called only after the first message decrypted successfully, so a failed + attempt (e.g. wrong SPK during the grace period) can still retry with + the same OPK. + """ + opk_id = getattr(ratchet, "_pending_opk_delete", None) + if opk_id: + _delete_opk_private(self.email, opk_id) + ratchet._pending_opk_delete = None + # ------------------------------------------------------------------ # Conversations # ------------------------------------------------------------------ @@ -2485,6 +2501,8 @@ class ChatClient: else: raise ValueError(f"No session for sender {sender_id}") + self._consume_pending_opk(ratchet) + plaintext = unpad_plaintext(plaintext) payload = json.loads(plaintext) @@ -2633,9 +2651,9 @@ class ChatClient: """Sync message deletions from server (background, non-blocking).""" try: del_resp = await self.send_and_recv("get_deleted_since", - conversation_id=conv_id, since=after_ts) + conversation_id=conv_id, since_ts=after_ts) if del_resp.get("status") == "ok": - for del_id in del_resp.get("data", {}).get("message_ids", []): + for del_id in del_resp.get("data", {}).get("deleted_ids", []): _save_message_to_cache(self.email, conv_id, del_id, {"deleted": True}, cache_key=self._cache_key) except Exception: diff --git a/server.py b/server.py index 3caed14..c1a58f8 100644 --- a/server.py +++ b/server.py @@ -842,6 +842,10 @@ async def handle_login_finish(msg: dict, writer: ProtocolWriter, state: dict) -> "ip": addr, "added_at": datetime.now(timezone.utc).isoformat(), }, exclude_writer=writer) + # Contacts must refetch key bundles to include the new device + if contacts: + await _notify_users(contacts, "keys_updated", {"user_id": user_id}, + exclude_writer=writer) return {"user_id": user_id, "username": user["username"], "email": user["email"], "device_id": device_id} @@ -1060,6 +1064,13 @@ async def handle_ensure_prekeys(msg: dict, session: dict, writer: ProtocolWriter {"count": count, "spk_created_at": spk_created_at, "uploaded_spk": uploaded_spk, "uploaded_otps": uploaded_otps}) + # SPK change invalidates key bundles cached by other clients + if uploaded_spk: + contacts = await adb.get_user_contacts(user_id) + if contacts: + await _notify_users(contacts, "keys_updated", {"user_id": user_id}, + exclude_writer=writer) + async def handle_rotate_keys(msg: dict, session: dict, writer: ProtocolWriter): if await _is_rate_limited(f"rotate_keys|{session['user_id']}", 3): @@ -2478,6 +2489,10 @@ async def handle_download_stream(msg: dict, session: dict, writer: ProtocolWrite so the client can reassemble chunks in order even if delivery is reordered. """ file_id = msg.get("file_id", "") + if await _is_rate_limited(f"download_stream|{session['user_id']}", 30): + await send_resp(msg, writer, "download_stream", "error", + {"message": "Too many requests. Try later."}) + return result = await _validate_download(msg, session, writer, "download_stream") if not result: return