summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc/serv.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/serv.py')
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py410
1 files changed, 410 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
new file mode 100644
index 0000000000..667217c5c1
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -0,0 +1,410 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import signal
12import socket
13import sys
14import multiprocessing
15import logging
16from .connection import StreamConnection, WebsocketConnection
17from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
18
19
20class ClientLoggerAdapter(logging.LoggerAdapter):
21 def process(self, msg, kwargs):
22 return f"[Client {self.extra['address']}] {msg}", kwargs
23
24
25class AsyncServerConnection(object):
26 # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
27 # return message will be automatically be sent back to the client
28 NO_RESPONSE = object()
29
30 def __init__(self, socket, proto_name, logger):
31 self.socket = socket
32 self.proto_name = proto_name
33 self.handlers = {
34 "ping": self.handle_ping,
35 }
36 self.logger = ClientLoggerAdapter(
37 logger,
38 {
39 "address": socket.address,
40 },
41 )
42 self.client_headers = {}
43
44 async def close(self):
45 await self.socket.close()
46
47 async def handle_headers(self, headers):
48 return {}
49
50 async def process_requests(self):
51 try:
52 self.logger.info("Client %r connected" % (self.socket.address,))
53
54 # Read protocol and version
55 client_protocol = await self.socket.recv()
56 if not client_protocol:
57 return
58
59 (client_proto_name, client_proto_version) = client_protocol.split()
60 if client_proto_name != self.proto_name:
61 self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
62 return
63
64 self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
65 if not self.validate_proto_version():
66 self.logger.debug(
67 "Rejecting invalid protocol version %s" % (client_proto_version)
68 )
69 return
70
71 # Read headers
72 self.client_headers = {}
73 while True:
74 header = await self.socket.recv()
75 if not header:
76 # Empty line. End of headers
77 break
78 tag, value = header.split(":", 1)
79 self.client_headers[tag.lower()] = value.strip()
80
81 if self.client_headers.get("needs-headers", "false") == "true":
82 for k, v in (await self.handle_headers(self.client_headers)).items():
83 await self.socket.send("%s: %s" % (k, v))
84 await self.socket.send("")
85
86 # Handle messages
87 while True:
88 d = await self.socket.recv_message()
89 if d is None:
90 break
91 try:
92 response = await self.dispatch_message(d)
93 except InvokeError as e:
94 await self.socket.send_message(
95 {"invoke-error": {"message": str(e)}}
96 )
97 break
98
99 if response is not self.NO_RESPONSE:
100 await self.socket.send_message(response)
101
102 except ConnectionClosedError as e:
103 self.logger.info(str(e))
104 except (ClientError, ConnectionError) as e:
105 self.logger.error(str(e))
106 finally:
107 await self.close()
108
109 async def dispatch_message(self, msg):
110 for k in self.handlers.keys():
111 if k in msg:
112 self.logger.debug("Handling %s" % k)
113 return await self.handlers[k](msg[k])
114
115 raise ClientError("Unrecognized command %r" % msg)
116
117 async def handle_ping(self, request):
118 return {"alive": True}
119
120
121class StreamServer(object):
122 def __init__(self, handler, logger):
123 self.handler = handler
124 self.logger = logger
125 self.closed = False
126
127 async def handle_stream_client(self, reader, writer):
128 # writer.transport.set_write_buffer_limits(0)
129 socket = StreamConnection(reader, writer, -1)
130 if self.closed:
131 await socket.close()
132 return
133
134 await self.handler(socket)
135
136 async def stop(self):
137 self.closed = True
138
139
140class TCPStreamServer(StreamServer):
141 def __init__(self, host, port, handler, logger, *, reuseport=False):
142 super().__init__(handler, logger)
143 self.host = host
144 self.port = port
145 self.reuseport = reuseport
146
147 def start(self, loop):
148 self.server = loop.run_until_complete(
149 asyncio.start_server(
150 self.handle_stream_client,
151 self.host,
152 self.port,
153 reuse_port=self.reuseport,
154 )
155 )
156
157 for s in self.server.sockets:
158 self.logger.debug("Listening on %r" % (s.getsockname(),))
159 # Newer python does this automatically. Do it manually here for
160 # maximum compatibility
161 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
162 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
163
164 # Enable keep alives. This prevents broken client connections
165 # from persisting on the server for long periods of time.
166 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
167 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
168 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
169 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
170
171 name = self.server.sockets[0].getsockname()
172 if self.server.sockets[0].family == socket.AF_INET6:
173 self.address = "[%s]:%d" % (name[0], name[1])
174 else:
175 self.address = "%s:%d" % (name[0], name[1])
176
177 return [self.server.wait_closed()]
178
179 async def stop(self):
180 await super().stop()
181 self.server.close()
182
183 def cleanup(self):
184 pass
185
186
187class UnixStreamServer(StreamServer):
188 def __init__(self, path, handler, logger):
189 super().__init__(handler, logger)
190 self.path = path
191
192 def start(self, loop):
193 cwd = os.getcwd()
194 try:
195 # Work around path length limits in AF_UNIX
196 os.chdir(os.path.dirname(self.path))
197 self.server = loop.run_until_complete(
198 asyncio.start_unix_server(
199 self.handle_stream_client, os.path.basename(self.path)
200 )
201 )
202 finally:
203 os.chdir(cwd)
204
205 self.logger.debug("Listening on %r" % self.path)
206 self.address = "unix://%s" % os.path.abspath(self.path)
207 return [self.server.wait_closed()]
208
209 async def stop(self):
210 await super().stop()
211 self.server.close()
212
213 def cleanup(self):
214 os.unlink(self.path)
215
216
217class WebsocketsServer(object):
218 def __init__(self, host, port, handler, logger, *, reuseport=False):
219 self.host = host
220 self.port = port
221 self.handler = handler
222 self.logger = logger
223 self.reuseport = reuseport
224
225 def start(self, loop):
226 import websockets.server
227
228 self.server = loop.run_until_complete(
229 websockets.server.serve(
230 self.client_handler,
231 self.host,
232 self.port,
233 ping_interval=None,
234 reuse_port=self.reuseport,
235 )
236 )
237
238 for s in self.server.sockets:
239 self.logger.debug("Listening on %r" % (s.getsockname(),))
240
241 # Enable keep alives. This prevents broken client connections
242 # from persisting on the server for long periods of time.
243 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
244 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
245 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
246 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
247
248 name = self.server.sockets[0].getsockname()
249 if self.server.sockets[0].family == socket.AF_INET6:
250 self.address = "ws://[%s]:%d" % (name[0], name[1])
251 else:
252 self.address = "ws://%s:%d" % (name[0], name[1])
253
254 return [self.server.wait_closed()]
255
256 async def stop(self):
257 self.server.close()
258
259 def cleanup(self):
260 pass
261
262 async def client_handler(self, websocket):
263 socket = WebsocketConnection(websocket, -1)
264 await self.handler(socket)
265
266
267class AsyncServer(object):
268 def __init__(self, logger):
269 self.logger = logger
270 self.loop = None
271 self.run_tasks = []
272
273 def start_tcp_server(self, host, port, *, reuseport=False):
274 self.server = TCPStreamServer(
275 host,
276 port,
277 self._client_handler,
278 self.logger,
279 reuseport=reuseport,
280 )
281
282 def start_unix_server(self, path):
283 self.server = UnixStreamServer(path, self._client_handler, self.logger)
284
285 def start_websocket_server(self, host, port, reuseport=False):
286 self.server = WebsocketsServer(
287 host,
288 port,
289 self._client_handler,
290 self.logger,
291 reuseport=reuseport,
292 )
293
294 async def _client_handler(self, socket):
295 address = socket.address
296 try:
297 client = self.accept_client(socket)
298 await client.process_requests()
299 except Exception as e:
300 import traceback
301
302 self.logger.error(
303 "Error from client %s: %s" % (address, str(e)), exc_info=True
304 )
305 traceback.print_exc()
306 finally:
307 self.logger.debug("Client %s disconnected", address)
308 await socket.close()
309
310 @abc.abstractmethod
311 def accept_client(self, socket):
312 pass
313
314 async def stop(self):
315 self.logger.debug("Stopping server")
316 await self.server.stop()
317
318 def start(self):
319 tasks = self.server.start(self.loop)
320 self.address = self.server.address
321 return tasks
322
323 def signal_handler(self):
324 self.logger.debug("Got exit signal")
325 self.loop.create_task(self.stop())
326
327 def _serve_forever(self, tasks):
328 try:
329 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
330 self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
331 self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
332 signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
333
334 self.loop.run_until_complete(asyncio.gather(*tasks))
335
336 self.logger.debug("Server shutting down")
337 finally:
338 self.server.cleanup()
339
340 def serve_forever(self):
341 """
342 Serve requests in the current process
343 """
344 self._create_loop()
345 tasks = self.start()
346 self._serve_forever(tasks)
347 self.loop.close()
348
349 def _create_loop(self):
350 # Create loop and override any loop that may have existed in
351 # a parent process. It is possible that the usecases of
352 # serve_forever might be constrained enough to allow using
353 # get_event_loop here, but better safe than sorry for now.
354 self.loop = asyncio.new_event_loop()
355 asyncio.set_event_loop(self.loop)
356
357 def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
358 """
359 Serve requests in a child process
360 """
361
362 def run(queue):
363 # Create loop and override any loop that may have existed
364 # in a parent process. Without doing this and instead
365 # using get_event_loop, at the very minimum the hashserv
366 # unit tests will hang when running the second test.
367 # This happens since get_event_loop in the spawned server
368 # process for the second testcase ends up with the loop
369 # from the hashserv client created in the unit test process
370 # when running the first testcase. The problem is somewhat
371 # more general, though, as any potential use of asyncio in
372 # Cooker could create a loop that needs to replaced in this
373 # new process.
374 self._create_loop()
375 try:
376 self.address = None
377 tasks = self.start()
378 finally:
379 # Always put the server address to wake up the parent task
380 queue.put(self.address)
381 queue.close()
382
383 if prefunc is not None:
384 prefunc(self, *args)
385
386 if log_level is not None:
387 self.logger.setLevel(log_level)
388
389 self._serve_forever(tasks)
390
391 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
392 self.loop.close()
393
394 queue = multiprocessing.Queue()
395
396 # Temporarily block SIGTERM. The server process will inherit this
397 # block which will ensure it doesn't receive the SIGTERM until the
398 # handler is ready for it
399 mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
400 try:
401 self.process = multiprocessing.Process(target=run, args=(queue,))
402 self.process.start()
403
404 self.address = queue.get()
405 queue.close()
406 queue.join_thread()
407
408 return self.process
409 finally:
410 signal.pthread_sigmask(signal.SIG_SETMASK, mask)