#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import asyncio
import itertools
import json
from datetime import datetime
from .exceptions import ClientError, ConnectionClosedError


# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that would slow down the
# connection setup with a round-trip delay, so I'd rather not do that unless it
# is necessary.
DEFAULT_MAX_CHUNK = 32 * 1024


def chunkify(msg, max_chunk):
    """Split ``msg`` into newline-terminated lines of at most ``max_chunk`` characters.

    A message short enough to fit on one line is yielded as ``msg + "\\n"``.
    Anything longer is sent as a chunk stream:

    1. a header line holding the JSON object ``{"chunk-stream": null}``,
    2. ``msg`` cut into pieces of ``max_chunk - 1`` characters (the last piece
       may be shorter), each followed by ``"\\n"``,
    3. a single empty line (``"\\n"``) marking the end of the stream.

    Pieces are cut at arbitrary character positions and the receiver splits
    on newlines, so ``msg`` must not itself contain a raw newline.
    :meth:`StreamConnection.send_message` satisfies this by passing
    ``json.dumps`` output with the default ``indent=None``.

    :param msg: text to send.
    :param max_chunk: maximum length of every yielded line, newline included.
    :raises ValueError: if ``max_chunk`` is below 2, because a line could then
        carry no payload and the message would be silently lost. Since this
        is a generator, the error is raised when iteration starts.
    """
    if max_chunk < 2:
        raise ValueError("max_chunk must be at least 2, got %r" % (max_chunk,))

    if len(msg) < max_chunk - 1:
        yield "".join((msg, "\n"))
    else:
        yield "".join((json.dumps({"chunk-stream": None}), "\n"))

        # Repeating the *same* iterator max_chunk - 1 times makes zip_longest
        # pull consecutive characters into each group (the itertools "grouper"
        # recipe); the final group is padded with empty strings.
        args = [iter(msg)] * (max_chunk - 1)
        for piece in map("".join, itertools.zip_longest(*args, fillvalue="")):
            yield "".join(itertools.chain(piece, "\n"))
        # An empty line terminates the chunk stream.
        yield "\n"


def json_serialize(obj):
    """``default`` hook for ``json.dumps`` that encodes datetimes as ISO 8601.

    :raises TypeError: for any other type, as ``json.dumps`` expects from a
        ``default`` hook that cannot handle the object.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError("Type %s not serializeable" % type(obj))


class StreamConnection(object):
    """Newline-delimited JSON messaging over a stream reader/writer pair.

    ``reader`` must provide an awaitable ``readline()``; ``writer`` must
    provide ``write()``, an awaitable ``drain()``, ``get_extra_info()`` and
    ``close()`` (presumably ``asyncio.StreamReader`` / ``StreamWriter``).
    Messages longer than ``max_chunk`` are split with :func:`chunkify`.
    """

    def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
        self.reader = reader
        self.writer = writer
        # Seconds to wait for each incoming line; a negative value waits forever.
        self.timeout = timeout
        # Longest line (newline included) that send_message will write.
        self.max_chunk = max_chunk

    @property
    def address(self):
        """Peer address as reported by the writer's ``"peername"`` extra info."""
        return self.writer.get_extra_info("peername")

    async def send_message(self, msg):
        """Serialize ``msg`` as JSON and write it, chunking it if it is too long.

        ``datetime`` values are encoded via :func:`json_serialize`. All lines
        are written before a single ``drain()``.
        """
        for line in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
            self.writer.write(line.encode("utf-8"))
        await self.writer.drain()

    async def recv_message(self):
        """Receive and decode one JSON message, reassembling chunk streams.

        :returns: the decoded JSON value (falsy values are returned as-is).
        :raises ConnectionError: on timeout or a truncated line.
        :raises ConnectionClosedError: if the peer closes the connection.
        :raises json.JSONDecodeError: if the payload is not valid JSON.
        """
        m = json.loads(await self.recv())
        if not m:
            return m

        # Only a JSON object can be the chunk-stream header. Testing the type
        # first avoids a TypeError on scalar messages (numbers, ``true``), and
        # stops a list or string that merely contains "chunk-stream" from
        # being mistaken for the header.
        if isinstance(m, dict) and "chunk-stream" in m:
            pieces = []
            while True:
                # Strip only the line terminator. chunkify() cuts at arbitrary
                # positions, so a piece may legitimately end with a space that
                # belongs inside a JSON string; recv()'s rstrip() would drop
                # that space and corrupt the reassembled message.
                piece = (await self._readline()).rstrip("\r\n")
                if not piece:
                    # An empty line terminates the stream.
                    break
                pieces.append(piece)

            m = json.loads("".join(pieces))

        return m

    async def send(self, msg):
        """Write ``msg`` as a single raw line (no JSON encoding, no chunking)."""
        self.writer.write(("%s\n" % msg).encode("utf-8"))
        await self.writer.drain()

    async def _readline(self):
        """Read one complete line and return it decoded, terminator included.

        :raises ConnectionError: if no line arrives within ``self.timeout``
            seconds (when non-negative) or the line is not newline-terminated.
        :raises ConnectionClosedError: if the peer has closed the connection.
        """
        if self.timeout < 0:
            line = await self.reader.readline()
        else:
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")

        # asyncio.StreamReader.readline() returns b"" at EOF with nothing buffered.
        if not line:
            raise ConnectionClosedError("Connection closed")

        line = line.decode("utf-8")

        # readline() returns a partial line without "\n" when EOF arrives mid-line.
        if not line.endswith("\n"):
            raise ConnectionError("Bad message %r" % (line))

        return line

    async def recv(self):
        """Read one raw line and return it with trailing whitespace stripped.

        :raises ConnectionError: on timeout or a truncated line.
        :raises ConnectionClosedError: if the peer has closed the connection.
        """
        return (await self._readline()).rstrip()

    async def close(self):
        """Drop the reader and close the writer; safe to call more than once."""
        self.reader = None
        if self.writer is not None:
            self.writer.close()
            self.writer = None


class WebsocketConnection(object):
    """JSON messaging over a websocket, one JSON document per websocket message.

    The ``websockets`` package is imported lazily inside the methods that need
    its exception types, so importing this module does not require it.
    """

    def __init__(self, socket, timeout):
        self.socket = socket
        # Seconds to wait for each incoming message; a negative value waits forever.
        self.timeout = timeout

    @property
    def address(self):
        """The socket's ``remote_address`` components joined with ``":"``."""
        return ":".join(str(s) for s in self.socket.remote_address)

    async def send_message(self, msg):
        """Serialize ``msg`` as JSON (see :func:`json_serialize`) and send it."""
        await self.send(json.dumps(msg, default=json_serialize))

    async def recv_message(self):
        """Receive one websocket message and decode it as JSON."""
        m = await self.recv()
        return json.loads(m)

    async def send(self, msg):
        """Send ``msg`` as one websocket message.

        :raises ConnectionClosedError: if the websocket is closed.
        """
        import websockets.exceptions

        try:
            await self.socket.send(msg)
        except websockets.exceptions.ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def recv(self):
        """Receive one websocket message.

        :raises ConnectionError: if nothing arrives within ``self.timeout``
            seconds (when non-negative).
        :raises ConnectionClosedError: if the websocket is closed.
        """
        import websockets.exceptions

        try:
            if self.timeout < 0:
                return await self.socket.recv()

            try:
                return await asyncio.wait_for(self.socket.recv(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")
        except websockets.exceptions.ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def close(self):
        """Close the websocket; safe to call more than once."""
        if self.socket is not None:
            await self.socket.close()
            self.socket = None