diff --git a/protocol.py b/protocol.py index 5c9f15a..55ee78f 100644 --- a/protocol.py +++ b/protocol.py @@ -86,9 +86,21 @@ class ProtocolReader: def __init__(self, reader: asyncio.StreamReader): self._reader = reader + self._leftover = b"" # bytes after the delimiter salvaged from a drain async def read_message(self) -> dict | None: """Read and parse one message. Returns None on EOF.""" + prefix = b"" + if self._leftover: + nl = self._leftover.find(b"\n") + if nl != -1: + line = self._leftover[:nl + 1] + self._leftover = self._leftover[nl + 1:] + if len(line) > MAX_MESSAGE_BYTES: + raise ValueError("Message exceeds maximum size") + return parse_message(line.strip()) + prefix = self._leftover + self._leftover = b"" try: line = await self._reader.readuntil(b"\n") except (asyncio.IncompleteReadError, ConnectionError): @@ -102,11 +114,19 @@ class ProtocolReader: chunk = await self._reader.read(max(remaining, 4096)) if not chunk: return None # EOF while draining - if b"\n" in chunk: - break # found delimiter, oversized message fully drained + nl = chunk.find(b"\n") + if nl != -1: + # Bytes after the delimiter belong to the NEXT message — + # discarding them would corrupt framing for the rest of + # the connection. + self._leftover = chunk[nl + 1:] + break raise ValueError("Message exceeds maximum size") if not line: return None + line = prefix + line + if len(line) > MAX_MESSAGE_BYTES: + raise ValueError("Message exceeds maximum size") return parse_message(line.strip())