summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc/serv.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/serv.py')
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py391
1 files changed, 391 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
new file mode 100644
index 0000000000..a66117acad
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -0,0 +1,391 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import signal
12import socket
13import sys
14import multiprocessing
15import logging
16from .connection import StreamConnection, WebsocketConnection
17from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
18
19
20class ClientLoggerAdapter(logging.LoggerAdapter):
21 def process(self, msg, kwargs):
22 return f"[Client {self.extra['address']}] {msg}", kwargs
23
24
25class AsyncServerConnection(object):
26 # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
27 # return message will be automatically be sent back to the client
28 NO_RESPONSE = object()
29
30 def __init__(self, socket, proto_name, logger):
31 self.socket = socket
32 self.proto_name = proto_name
33 self.handlers = {
34 "ping": self.handle_ping,
35 }
36 self.logger = ClientLoggerAdapter(
37 logger,
38 {
39 "address": socket.address,
40 },
41 )
42 self.client_headers = {}
43
44 async def close(self):
45 await self.socket.close()
46
47 async def handle_headers(self, headers):
48 return {}
49
50 async def process_requests(self):
51 try:
52 self.logger.info("Client %r connected" % (self.socket.address,))
53
54 # Read protocol and version
55 client_protocol = await self.socket.recv()
56 if not client_protocol:
57 return
58
59 (client_proto_name, client_proto_version) = client_protocol.split()
60 if client_proto_name != self.proto_name:
61 self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
62 return
63
64 self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
65 if not self.validate_proto_version():
66 self.logger.debug(
67 "Rejecting invalid protocol version %s" % (client_proto_version)
68 )
69 return
70
71 # Read headers
72 self.client_headers = {}
73 while True:
74 header = await self.socket.recv()
75 if not header:
76 # Empty line. End of headers
77 break
78 tag, value = header.split(":", 1)
79 self.client_headers[tag.lower()] = value.strip()
80
81 if self.client_headers.get("needs-headers", "false") == "true":
82 for k, v in (await self.handle_headers(self.client_headers)).items():
83 await self.socket.send("%s: %s" % (k, v))
84 await self.socket.send("")
85
86 # Handle messages
87 while True:
88 d = await self.socket.recv_message()
89 if d is None:
90 break
91 try:
92 response = await self.dispatch_message(d)
93 except InvokeError as e:
94 await self.socket.send_message(
95 {"invoke-error": {"message": str(e)}}
96 )
97 break
98
99 if response is not self.NO_RESPONSE:
100 await self.socket.send_message(response)
101
102 except ConnectionClosedError as e:
103 self.logger.info(str(e))
104 except (ClientError, ConnectionError) as e:
105 self.logger.error(str(e))
106 finally:
107 await self.close()
108
109 async def dispatch_message(self, msg):
110 for k in self.handlers.keys():
111 if k in msg:
112 self.logger.debug("Handling %s" % k)
113 return await self.handlers[k](msg[k])
114
115 raise ClientError("Unrecognized command %r" % msg)
116
117 async def handle_ping(self, request):
118 return {"alive": True}
119
120
121class StreamServer(object):
122 def __init__(self, handler, logger):
123 self.handler = handler
124 self.logger = logger
125 self.closed = False
126
127 async def handle_stream_client(self, reader, writer):
128 # writer.transport.set_write_buffer_limits(0)
129 socket = StreamConnection(reader, writer, -1)
130 if self.closed:
131 await socket.close()
132 return
133
134 await self.handler(socket)
135
136 async def stop(self):
137 self.closed = True
138
139
140class TCPStreamServer(StreamServer):
141 def __init__(self, host, port, handler, logger):
142 super().__init__(handler, logger)
143 self.host = host
144 self.port = port
145
146 def start(self, loop):
147 self.server = loop.run_until_complete(
148 asyncio.start_server(self.handle_stream_client, self.host, self.port)
149 )
150
151 for s in self.server.sockets:
152 self.logger.debug("Listening on %r" % (s.getsockname(),))
153 # Newer python does this automatically. Do it manually here for
154 # maximum compatibility
155 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
156 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
157
158 # Enable keep alives. This prevents broken client connections
159 # from persisting on the server for long periods of time.
160 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
161 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
162 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
163 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
164
165 name = self.server.sockets[0].getsockname()
166 if self.server.sockets[0].family == socket.AF_INET6:
167 self.address = "[%s]:%d" % (name[0], name[1])
168 else:
169 self.address = "%s:%d" % (name[0], name[1])
170
171 return [self.server.wait_closed()]
172
173 async def stop(self):
174 await super().stop()
175 self.server.close()
176
177 def cleanup(self):
178 pass
179
180
181class UnixStreamServer(StreamServer):
182 def __init__(self, path, handler, logger):
183 super().__init__(handler, logger)
184 self.path = path
185
186 def start(self, loop):
187 cwd = os.getcwd()
188 try:
189 # Work around path length limits in AF_UNIX
190 os.chdir(os.path.dirname(self.path))
191 self.server = loop.run_until_complete(
192 asyncio.start_unix_server(
193 self.handle_stream_client, os.path.basename(self.path)
194 )
195 )
196 finally:
197 os.chdir(cwd)
198
199 self.logger.debug("Listening on %r" % self.path)
200 self.address = "unix://%s" % os.path.abspath(self.path)
201 return [self.server.wait_closed()]
202
203 async def stop(self):
204 await super().stop()
205 self.server.close()
206
207 def cleanup(self):
208 os.unlink(self.path)
209
210
211class WebsocketsServer(object):
212 def __init__(self, host, port, handler, logger):
213 self.host = host
214 self.port = port
215 self.handler = handler
216 self.logger = logger
217
218 def start(self, loop):
219 import websockets.server
220
221 self.server = loop.run_until_complete(
222 websockets.server.serve(
223 self.client_handler,
224 self.host,
225 self.port,
226 ping_interval=None,
227 )
228 )
229
230 for s in self.server.sockets:
231 self.logger.debug("Listening on %r" % (s.getsockname(),))
232
233 # Enable keep alives. This prevents broken client connections
234 # from persisting on the server for long periods of time.
235 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
236 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
237 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
238 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
239
240 name = self.server.sockets[0].getsockname()
241 if self.server.sockets[0].family == socket.AF_INET6:
242 self.address = "ws://[%s]:%d" % (name[0], name[1])
243 else:
244 self.address = "ws://%s:%d" % (name[0], name[1])
245
246 return [self.server.wait_closed()]
247
248 async def stop(self):
249 self.server.close()
250
251 def cleanup(self):
252 pass
253
254 async def client_handler(self, websocket):
255 socket = WebsocketConnection(websocket, -1)
256 await self.handler(socket)
257
258
259class AsyncServer(object):
260 def __init__(self, logger):
261 self.logger = logger
262 self.loop = None
263 self.run_tasks = []
264
265 def start_tcp_server(self, host, port):
266 self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
267
268 def start_unix_server(self, path):
269 self.server = UnixStreamServer(path, self._client_handler, self.logger)
270
271 def start_websocket_server(self, host, port):
272 self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
273
274 async def _client_handler(self, socket):
275 address = socket.address
276 try:
277 client = self.accept_client(socket)
278 await client.process_requests()
279 except Exception as e:
280 import traceback
281
282 self.logger.error(
283 "Error from client %s: %s" % (address, str(e)), exc_info=True
284 )
285 traceback.print_exc()
286 finally:
287 self.logger.debug("Client %s disconnected", address)
288 await socket.close()
289
290 @abc.abstractmethod
291 def accept_client(self, socket):
292 pass
293
294 async def stop(self):
295 self.logger.debug("Stopping server")
296 await self.server.stop()
297
298 def start(self):
299 tasks = self.server.start(self.loop)
300 self.address = self.server.address
301 return tasks
302
303 def signal_handler(self):
304 self.logger.debug("Got exit signal")
305 self.loop.create_task(self.stop())
306
307 def _serve_forever(self, tasks):
308 try:
309 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
310 self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
311 self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
312 signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
313
314 self.loop.run_until_complete(asyncio.gather(*tasks))
315
316 self.logger.debug("Server shutting down")
317 finally:
318 self.server.cleanup()
319
320 def serve_forever(self):
321 """
322 Serve requests in the current process
323 """
324 self._create_loop()
325 tasks = self.start()
326 self._serve_forever(tasks)
327 self.loop.close()
328
329 def _create_loop(self):
330 # Create loop and override any loop that may have existed in
331 # a parent process. It is possible that the usecases of
332 # serve_forever might be constrained enough to allow using
333 # get_event_loop here, but better safe than sorry for now.
334 self.loop = asyncio.new_event_loop()
335 asyncio.set_event_loop(self.loop)
336
337 def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
338 """
339 Serve requests in a child process
340 """
341
342 def run(queue):
343 # Create loop and override any loop that may have existed
344 # in a parent process. Without doing this and instead
345 # using get_event_loop, at the very minimum the hashserv
346 # unit tests will hang when running the second test.
347 # This happens since get_event_loop in the spawned server
348 # process for the second testcase ends up with the loop
349 # from the hashserv client created in the unit test process
350 # when running the first testcase. The problem is somewhat
351 # more general, though, as any potential use of asyncio in
352 # Cooker could create a loop that needs to replaced in this
353 # new process.
354 self._create_loop()
355 try:
356 self.address = None
357 tasks = self.start()
358 finally:
359 # Always put the server address to wake up the parent task
360 queue.put(self.address)
361 queue.close()
362
363 if prefunc is not None:
364 prefunc(self, *args)
365
366 if log_level is not None:
367 self.logger.setLevel(log_level)
368
369 self._serve_forever(tasks)
370
371 if sys.version_info >= (3, 6):
372 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
373 self.loop.close()
374
375 queue = multiprocessing.Queue()
376
377 # Temporarily block SIGTERM. The server process will inherit this
378 # block which will ensure it doesn't receive the SIGTERM until the
379 # handler is ready for it
380 mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
381 try:
382 self.process = multiprocessing.Process(target=run, args=(queue,))
383 self.process.start()
384
385 self.address = queue.get()
386 queue.close()
387 queue.join_thread()
388
389 return self.process
390 finally:
391 signal.pthread_sigmask(signal.SIG_SETMASK, mask)