Files
Kecalek_python/protocol.py
filip f0666ea6ac Preserve next-message bytes when draining oversized protocol message
ProtocolReader discarded the entire chunk containing the newline
delimiter while draining an oversized message, including bytes after
the newline that belong to the next pipelined message. This corrupted
framing for the rest of the connection (affects server and client).
Salvaged bytes are now kept in a _leftover buffer that read_message()
consumes before touching the stream.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-12 16:08:12 +02:00

167 lines
5.9 KiB
Python

"""Newline-delimited JSON protocol with base64 encoding for binary data."""
import asyncio
import base64
import binascii
import json
import os
def encode_binary(data: bytes) -> str:
    """Return *data* as a base64 string containing only ASCII characters."""
    encoded = base64.b64encode(data)
    # base64 output is pure ASCII, so this conversion cannot fail.
    return str(encoded, "ascii")
def decode_binary(data: str) -> bytes:
    """Decode a base64 string to bytes.

    Decoding is strict (``validate=True``): characters outside the base64
    alphabet are rejected instead of being silently skipped.

    Args:
        data: Base64-encoded text.

    Returns:
        The decoded bytes.

    Raises:
        ValueError: If *data* is not valid base64 (bad characters, wrong
            padding, non-ASCII text, or an unsupported argument type). The
            underlying error is attached as ``__cause__``.
    """
    try:
        return base64.b64decode(data, validate=True)
    # binascii.Error is a ValueError subclass; catching ValueError also covers
    # the plain ValueError raised for non-ASCII str input, so every failure
    # gets the same "Invalid base64" prefix.
    except (TypeError, ValueError) as e:
        raise ValueError(f"Invalid base64: {e}") from e
# Version string of this protocol implementation, in dotted numeric form.
VERSION = "0.8.6"
# Oldest client version the server accepts; presumably compared against the
# client's reported version with version_gte() — confirm at the call sites.
MIN_CLIENT_VERSION = "0.8.6" # server rejects clients older than this
def version_gte(version: str, minimum: str) -> bool:
"""Return True if version >= minimum (compares numeric tuples, e.g. '0.8.1' >= '0.8').
Returns False for malformed version strings (instead of silently treating them as 0).
"""
def _parse(v: str) -> tuple[int, ...] | None:
if not isinstance(v, str) or not v:
return None
parts = v.split(".")
try:
return tuple(int(x) for x in parts)
except (ValueError, AttributeError):
return None
parsed_ver = _parse(version)
parsed_min = _parse(minimum)
if parsed_ver is None or parsed_min is None:
return False
return parsed_ver >= parsed_min
# Size limits. The three MAX_* values are read once, at import time, from the
# environment variable of the same name; a value that is not an integer makes
# int() raise ValueError during import.
#
# Largest serialized message, in bytes, that ProtocolReader accepts and
# ProtocolWriter sends (the trailing newline is counted).
MAX_MESSAGE_BYTES = int(os.getenv("MAX_MESSAGE_BYTES", str(1024 * 1024))) # 1 MiB default (was 64K, raised for 256K media chunks)
# Image size cap; not enforced in this module — presumably checked by the
# upload/download handlers (TODO confirm, including the "0 = no limit" rule).
MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", str(5 * 1024 * 1024))) # 5 MiB default, 0 = no limit
# File size cap; not enforced in this module — verify against the callers.
MAX_FILE_BYTES = int(os.getenv("MAX_FILE_BYTES", str(50 * 1024 * 1024))) # 50 MiB default
# Fixed chunk size (not configurable through the environment).
IMAGE_CHUNK_SIZE = 262144 # 256 KiB raw chunk size for image upload/download
def build_request(msg_type: str, request_id: str | None = None, **kwargs) -> bytes:
"""Build a protocol message (newline-terminated JSON)."""
msg = {"type": msg_type, **kwargs}
if request_id:
msg["request_id"] = request_id
return json.dumps(msg, ensure_ascii=False).encode("utf-8") + b"\n"
def build_response(
msg_type: str,
status: str,
data: dict | None = None,
request_id: str | None = None,
) -> bytes:
"""Build a server response."""
msg = {"type": msg_type, "status": status}
if data is not None:
msg["data"] = data
if request_id:
msg["request_id"] = request_id
return json.dumps(msg, ensure_ascii=False).encode("utf-8") + b"\n"
def parse_message(line: bytes) -> dict:
    """Parse a single protocol message from bytes.

    Args:
        line: One UTF-8 encoded JSON document; surrounding whitespace
            (including the trailing newline) is tolerated by the JSON parser.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If *line* is not valid UTF-8, is not valid JSON, is nested
            too deeply to parse, or decodes to something other than a JSON
            object. The underlying error, if any, is attached as
            ``__cause__``.
    """
    try:
        msg = json.loads(line.decode("utf-8"))
    # RecursionError: pathologically nested input such as "[[[[..." must be
    # reported as a bad message, not escape as an unrelated exception type.
    except (json.JSONDecodeError, UnicodeDecodeError, RecursionError) as e:
        raise ValueError(f"Invalid message: {e}") from e
    # Valid JSON is not necessarily an object. In particular a bare "null"
    # would decode to None, which readers use to signal end of stream.
    if not isinstance(msg, dict):
        raise ValueError(
            f"Invalid message: expected JSON object, got {type(msg).__name__}"
        )
    return msg
class ProtocolReader:
"""Read newline-delimited JSON messages from an asyncio StreamReader."""
def __init__(self, reader: asyncio.StreamReader):
self._reader = reader
self._leftover = b"" # bytes after the delimiter salvaged from a drain
async def read_message(self) -> dict | None:
"""Read and parse one message. Returns None on EOF."""
prefix = b""
if self._leftover:
nl = self._leftover.find(b"\n")
if nl != -1:
line = self._leftover[:nl + 1]
self._leftover = self._leftover[nl + 1:]
if len(line) > MAX_MESSAGE_BYTES:
raise ValueError("Message exceeds maximum size")
return parse_message(line.strip())
prefix = self._leftover
self._leftover = b""
try:
line = await self._reader.readuntil(b"\n")
except (asyncio.IncompleteReadError, ConnectionError):
return None
except asyncio.LimitOverrunError as e:
# Message exceeded StreamReader limit — drain oversized data
# using public read() API (consumed=e.consumed bytes before limit).
# Read in chunks until newline found or EOF, then signal error.
remaining = e.consumed
while True:
chunk = await self._reader.read(max(remaining, 4096))
if not chunk:
return None # EOF while draining
nl = chunk.find(b"\n")
if nl != -1:
# Bytes after the delimiter belong to the NEXT message —
# discarding them would corrupt framing for the rest of
# the connection.
self._leftover = chunk[nl + 1:]
break
raise ValueError("Message exceeds maximum size")
if not line:
return None
line = prefix + line
if len(line) > MAX_MESSAGE_BYTES:
raise ValueError("Message exceeds maximum size")
return parse_message(line.strip())
class ProtocolWriter:
    """Send newline-delimited JSON messages through an asyncio StreamWriter."""

    def __init__(self, writer: asyncio.StreamWriter):
        self._writer = writer

    async def _transmit(self, frame: bytes) -> None:
        """Size-check one serialized message, write it, and wait for drain."""
        size = len(frame)
        if size > MAX_MESSAGE_BYTES:
            raise ValueError(f"Message exceeds limit ({size} > {MAX_MESSAGE_BYTES})")
        self._writer.write(frame)
        await self._writer.drain()

    async def send_request(self, msg_type: str, request_id: str | None = None, **kwargs):
        """Serialize a request with build_request() and send it.

        Raises ValueError, before anything is written, if the serialized
        message is larger than MAX_MESSAGE_BYTES.
        """
        frame = build_request(msg_type, request_id=request_id, **kwargs)
        await self._transmit(frame)

    async def send_response(
        self,
        msg_type: str,
        status: str,
        data: dict | None = None,
        request_id: str | None = None,
    ):
        """Serialize a response with build_response() and send it.

        Raises ValueError, before anything is written, if the serialized
        message is larger than MAX_MESSAGE_BYTES.
        """
        frame = build_response(msg_type, status, data, request_id=request_id)
        await self._transmit(frame)

    def is_closing(self) -> bool:
        """Report whether the wrapped writer is closing or already closed."""
        return self._writer.is_closing()

    def close(self):
        """Close the wrapped writer (does not wait for the close to finish)."""
        self._writer.close()