ProtocolReader discarded the entire chunk containing the newline delimiter while draining an oversized message, including bytes after the newline that belong to the next pipelined message. This corrupted framing for the rest of the connection (affects server and client). Salvaged bytes are now kept in a _leftover buffer that read_message() consumes before touching the stream. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
167 lines
5.9 KiB
Python
167 lines
5.9 KiB
Python
"""Newline-delimited JSON protocol with base64 encoding for binary data."""
|
|
|
|
import asyncio
|
|
import base64
|
|
import binascii
|
|
import json
|
|
import os
|
|
|
|
|
|
def encode_binary(data: bytes) -> str:
    """Return *data* encoded as a base64 ASCII string."""
    encoded = base64.b64encode(data)
    # base64 output is pure ASCII, so decoding it as ASCII cannot fail.
    return str(encoded, "ascii")
|
|
|
|
|
|
def decode_binary(data: str) -> bytes:
    """Decode a base64 string to bytes.

    Decoding is strict (``validate=True``): characters outside the base64
    alphabet are rejected rather than silently skipped.

    Raises:
        ValueError: if *data* is not valid base64. The message always starts
            with ``"Invalid base64: "``.
    """
    try:
        return base64.b64decode(data, validate=True)
    except (TypeError, ValueError) as e:
        # binascii.Error is a ValueError subclass, so it is covered here.
        # Catching ValueError also covers the plain ValueError that
        # b64decode raises for a str containing non-ASCII characters, so
        # every failure is reported with the same message prefix.
        raise ValueError(f"Invalid base64: {e}") from e
|
|
|
|
|
|
# Current version string.
VERSION = "0.8.6"
# Minimum accepted client version: the server rejects clients older than this.
MIN_CLIENT_VERSION = "0.8.6"
|
|
|
|
|
|
def version_gte(version: str, minimum: str) -> bool:
|
|
"""Return True if version >= minimum (compares numeric tuples, e.g. '0.8.1' >= '0.8').
|
|
|
|
Returns False for malformed version strings (instead of silently treating them as 0).
|
|
"""
|
|
def _parse(v: str) -> tuple[int, ...] | None:
|
|
if not isinstance(v, str) or not v:
|
|
return None
|
|
parts = v.split(".")
|
|
try:
|
|
return tuple(int(x) for x in parts)
|
|
except (ValueError, AttributeError):
|
|
return None
|
|
parsed_ver = _parse(version)
|
|
parsed_min = _parse(minimum)
|
|
if parsed_ver is None or parsed_min is None:
|
|
return False
|
|
return parsed_ver >= parsed_min
|
|
|
|
|
|
# Size limits, each overridable through an environment variable of the same
# name. Values are parsed with int() at import time.

# 1 MiB default (was 64K, raised for 256K media chunks)
MAX_MESSAGE_BYTES = int(os.environ.get("MAX_MESSAGE_BYTES", str(1024 * 1024)))
# 5 MiB default, 0 = no limit
MAX_IMAGE_BYTES = int(os.environ.get("MAX_IMAGE_BYTES", str(5 * 1024 * 1024)))
# 50 MiB default
MAX_FILE_BYTES = int(os.environ.get("MAX_FILE_BYTES", str(50 * 1024 * 1024)))
# 256 KiB raw chunk size for image upload/download
IMAGE_CHUNK_SIZE = 256 * 1024
|
|
|
|
|
|
def build_request(msg_type: str, request_id: str | None = None, **kwargs) -> bytes:
|
|
"""Build a protocol message (newline-terminated JSON)."""
|
|
msg = {"type": msg_type, **kwargs}
|
|
if request_id:
|
|
msg["request_id"] = request_id
|
|
return json.dumps(msg, ensure_ascii=False).encode("utf-8") + b"\n"
|
|
|
|
|
|
def build_response(
|
|
msg_type: str,
|
|
status: str,
|
|
data: dict | None = None,
|
|
request_id: str | None = None,
|
|
) -> bytes:
|
|
"""Build a server response."""
|
|
msg = {"type": msg_type, "status": status}
|
|
if data is not None:
|
|
msg["data"] = data
|
|
if request_id:
|
|
msg["request_id"] = request_id
|
|
return json.dumps(msg, ensure_ascii=False).encode("utf-8") + b"\n"
|
|
|
|
|
|
def parse_message(line: bytes) -> dict:
    """Parse a single protocol message from bytes.

    *line* must be UTF-8 encoded JSON whose top-level value is an object.

    Raises:
        ValueError: if *line* is not valid UTF-8, is not valid JSON, is
            nested too deeply to decode, or decodes to something other than
            a JSON object.
    """
    try:
        msg = json.loads(line.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError, RecursionError) as e:
        # RecursionError: pathologically nested input (e.g. thousands of
        # "[") must be reported as a bad message, not escape as an
        # unrelated exception type.
        raise ValueError(f"Invalid message: {e}") from e
    if not isinstance(msg, dict):
        # Valid JSON such as `[]`, `1` or `"x"` is not a protocol message;
        # reject it here so callers can rely on the `dict` return type.
        raise ValueError(
            f"Invalid message: expected JSON object, got {type(msg).__name__}"
        )
    return msg
|
|
|
|
|
|
class ProtocolReader:
|
|
"""Read newline-delimited JSON messages from an asyncio StreamReader."""
|
|
|
|
def __init__(self, reader: asyncio.StreamReader):
|
|
self._reader = reader
|
|
self._leftover = b"" # bytes after the delimiter salvaged from a drain
|
|
|
|
async def read_message(self) -> dict | None:
|
|
"""Read and parse one message. Returns None on EOF."""
|
|
prefix = b""
|
|
if self._leftover:
|
|
nl = self._leftover.find(b"\n")
|
|
if nl != -1:
|
|
line = self._leftover[:nl + 1]
|
|
self._leftover = self._leftover[nl + 1:]
|
|
if len(line) > MAX_MESSAGE_BYTES:
|
|
raise ValueError("Message exceeds maximum size")
|
|
return parse_message(line.strip())
|
|
prefix = self._leftover
|
|
self._leftover = b""
|
|
try:
|
|
line = await self._reader.readuntil(b"\n")
|
|
except (asyncio.IncompleteReadError, ConnectionError):
|
|
return None
|
|
except asyncio.LimitOverrunError as e:
|
|
# Message exceeded StreamReader limit — drain oversized data
|
|
# using public read() API (consumed=e.consumed bytes before limit).
|
|
# Read in chunks until newline found or EOF, then signal error.
|
|
remaining = e.consumed
|
|
while True:
|
|
chunk = await self._reader.read(max(remaining, 4096))
|
|
if not chunk:
|
|
return None # EOF while draining
|
|
nl = chunk.find(b"\n")
|
|
if nl != -1:
|
|
# Bytes after the delimiter belong to the NEXT message —
|
|
# discarding them would corrupt framing for the rest of
|
|
# the connection.
|
|
self._leftover = chunk[nl + 1:]
|
|
break
|
|
raise ValueError("Message exceeds maximum size")
|
|
if not line:
|
|
return None
|
|
line = prefix + line
|
|
if len(line) > MAX_MESSAGE_BYTES:
|
|
raise ValueError("Message exceeds maximum size")
|
|
return parse_message(line.strip())
|
|
|
|
|
|
class ProtocolWriter:
|
|
"""Write newline-delimited JSON messages to an asyncio StreamWriter."""
|
|
|
|
def __init__(self, writer: asyncio.StreamWriter):
|
|
self._writer = writer
|
|
|
|
async def send_request(self, msg_type: str, request_id: str | None = None, **kwargs):
|
|
"""Send a request message."""
|
|
payload = build_request(msg_type, request_id=request_id, **kwargs)
|
|
if len(payload) > MAX_MESSAGE_BYTES:
|
|
raise ValueError(f"Message exceeds limit ({len(payload)} > {MAX_MESSAGE_BYTES})")
|
|
self._writer.write(payload)
|
|
await self._writer.drain()
|
|
|
|
async def send_response(
|
|
self,
|
|
msg_type: str,
|
|
status: str,
|
|
data: dict | None = None,
|
|
request_id: str | None = None,
|
|
):
|
|
"""Send a response message."""
|
|
payload = build_response(msg_type, status, data, request_id=request_id)
|
|
if len(payload) > MAX_MESSAGE_BYTES:
|
|
raise ValueError(f"Message exceeds limit ({len(payload)} > {MAX_MESSAGE_BYTES})")
|
|
self._writer.write(payload)
|
|
await self._writer.drain()
|
|
|
|
def is_closing(self) -> bool:
|
|
"""Check if the underlying transport is closing or closed."""
|
|
return self._writer.is_closing()
|
|
|
|
def close(self):
|
|
self._writer.close()
|