diff options
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/serv.py')
-rw-r--r-- | bitbake/lib/bb/asyncrpc/serv.py | 391 |
1 files changed, 391 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py new file mode 100644 index 0000000000..a66117acad --- /dev/null +++ b/bitbake/lib/bb/asyncrpc/serv.py | |||
@@ -0,0 +1,391 @@ | |||
1 | # | ||
2 | # Copyright BitBake Contributors | ||
3 | # | ||
4 | # SPDX-License-Identifier: GPL-2.0-only | ||
5 | # | ||
6 | |||
7 | import abc | ||
8 | import asyncio | ||
9 | import json | ||
10 | import os | ||
11 | import signal | ||
12 | import socket | ||
13 | import sys | ||
14 | import multiprocessing | ||
15 | import logging | ||
16 | from .connection import StreamConnection, WebsocketConnection | ||
17 | from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError | ||
18 | |||
19 | |||
20 | class ClientLoggerAdapter(logging.LoggerAdapter): | ||
21 | def process(self, msg, kwargs): | ||
22 | return f"[Client {self.extra['address']}] {msg}", kwargs | ||
23 | |||
24 | |||
25 | class AsyncServerConnection(object): | ||
26 | # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no | ||
27 | # return message will be automatically be sent back to the client | ||
28 | NO_RESPONSE = object() | ||
29 | |||
30 | def __init__(self, socket, proto_name, logger): | ||
31 | self.socket = socket | ||
32 | self.proto_name = proto_name | ||
33 | self.handlers = { | ||
34 | "ping": self.handle_ping, | ||
35 | } | ||
36 | self.logger = ClientLoggerAdapter( | ||
37 | logger, | ||
38 | { | ||
39 | "address": socket.address, | ||
40 | }, | ||
41 | ) | ||
42 | self.client_headers = {} | ||
43 | |||
44 | async def close(self): | ||
45 | await self.socket.close() | ||
46 | |||
47 | async def handle_headers(self, headers): | ||
48 | return {} | ||
49 | |||
50 | async def process_requests(self): | ||
51 | try: | ||
52 | self.logger.info("Client %r connected" % (self.socket.address,)) | ||
53 | |||
54 | # Read protocol and version | ||
55 | client_protocol = await self.socket.recv() | ||
56 | if not client_protocol: | ||
57 | return | ||
58 | |||
59 | (client_proto_name, client_proto_version) = client_protocol.split() | ||
60 | if client_proto_name != self.proto_name: | ||
61 | self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name)) | ||
62 | return | ||
63 | |||
64 | self.proto_version = tuple(int(v) for v in client_proto_version.split(".")) | ||
65 | if not self.validate_proto_version(): | ||
66 | self.logger.debug( | ||
67 | "Rejecting invalid protocol version %s" % (client_proto_version) | ||
68 | ) | ||
69 | return | ||
70 | |||
71 | # Read headers | ||
72 | self.client_headers = {} | ||
73 | while True: | ||
74 | header = await self.socket.recv() | ||
75 | if not header: | ||
76 | # Empty line. End of headers | ||
77 | break | ||
78 | tag, value = header.split(":", 1) | ||
79 | self.client_headers[tag.lower()] = value.strip() | ||
80 | |||
81 | if self.client_headers.get("needs-headers", "false") == "true": | ||
82 | for k, v in (await self.handle_headers(self.client_headers)).items(): | ||
83 | await self.socket.send("%s: %s" % (k, v)) | ||
84 | await self.socket.send("") | ||
85 | |||
86 | # Handle messages | ||
87 | while True: | ||
88 | d = await self.socket.recv_message() | ||
89 | if d is None: | ||
90 | break | ||
91 | try: | ||
92 | response = await self.dispatch_message(d) | ||
93 | except InvokeError as e: | ||
94 | await self.socket.send_message( | ||
95 | {"invoke-error": {"message": str(e)}} | ||
96 | ) | ||
97 | break | ||
98 | |||
99 | if response is not self.NO_RESPONSE: | ||
100 | await self.socket.send_message(response) | ||
101 | |||
102 | except ConnectionClosedError as e: | ||
103 | self.logger.info(str(e)) | ||
104 | except (ClientError, ConnectionError) as e: | ||
105 | self.logger.error(str(e)) | ||
106 | finally: | ||
107 | await self.close() | ||
108 | |||
109 | async def dispatch_message(self, msg): | ||
110 | for k in self.handlers.keys(): | ||
111 | if k in msg: | ||
112 | self.logger.debug("Handling %s" % k) | ||
113 | return await self.handlers[k](msg[k]) | ||
114 | |||
115 | raise ClientError("Unrecognized command %r" % msg) | ||
116 | |||
117 | async def handle_ping(self, request): | ||
118 | return {"alive": True} | ||
119 | |||
120 | |||
121 | class StreamServer(object): | ||
122 | def __init__(self, handler, logger): | ||
123 | self.handler = handler | ||
124 | self.logger = logger | ||
125 | self.closed = False | ||
126 | |||
127 | async def handle_stream_client(self, reader, writer): | ||
128 | # writer.transport.set_write_buffer_limits(0) | ||
129 | socket = StreamConnection(reader, writer, -1) | ||
130 | if self.closed: | ||
131 | await socket.close() | ||
132 | return | ||
133 | |||
134 | await self.handler(socket) | ||
135 | |||
136 | async def stop(self): | ||
137 | self.closed = True | ||
138 | |||
139 | |||
140 | class TCPStreamServer(StreamServer): | ||
141 | def __init__(self, host, port, handler, logger): | ||
142 | super().__init__(handler, logger) | ||
143 | self.host = host | ||
144 | self.port = port | ||
145 | |||
146 | def start(self, loop): | ||
147 | self.server = loop.run_until_complete( | ||
148 | asyncio.start_server(self.handle_stream_client, self.host, self.port) | ||
149 | ) | ||
150 | |||
151 | for s in self.server.sockets: | ||
152 | self.logger.debug("Listening on %r" % (s.getsockname(),)) | ||
153 | # Newer python does this automatically. Do it manually here for | ||
154 | # maximum compatibility | ||
155 | s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) | ||
156 | s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) | ||
157 | |||
158 | # Enable keep alives. This prevents broken client connections | ||
159 | # from persisting on the server for long periods of time. | ||
160 | s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) | ||
161 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) | ||
162 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15) | ||
163 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4) | ||
164 | |||
165 | name = self.server.sockets[0].getsockname() | ||
166 | if self.server.sockets[0].family == socket.AF_INET6: | ||
167 | self.address = "[%s]:%d" % (name[0], name[1]) | ||
168 | else: | ||
169 | self.address = "%s:%d" % (name[0], name[1]) | ||
170 | |||
171 | return [self.server.wait_closed()] | ||
172 | |||
173 | async def stop(self): | ||
174 | await super().stop() | ||
175 | self.server.close() | ||
176 | |||
177 | def cleanup(self): | ||
178 | pass | ||
179 | |||
180 | |||
181 | class UnixStreamServer(StreamServer): | ||
182 | def __init__(self, path, handler, logger): | ||
183 | super().__init__(handler, logger) | ||
184 | self.path = path | ||
185 | |||
186 | def start(self, loop): | ||
187 | cwd = os.getcwd() | ||
188 | try: | ||
189 | # Work around path length limits in AF_UNIX | ||
190 | os.chdir(os.path.dirname(self.path)) | ||
191 | self.server = loop.run_until_complete( | ||
192 | asyncio.start_unix_server( | ||
193 | self.handle_stream_client, os.path.basename(self.path) | ||
194 | ) | ||
195 | ) | ||
196 | finally: | ||
197 | os.chdir(cwd) | ||
198 | |||
199 | self.logger.debug("Listening on %r" % self.path) | ||
200 | self.address = "unix://%s" % os.path.abspath(self.path) | ||
201 | return [self.server.wait_closed()] | ||
202 | |||
203 | async def stop(self): | ||
204 | await super().stop() | ||
205 | self.server.close() | ||
206 | |||
207 | def cleanup(self): | ||
208 | os.unlink(self.path) | ||
209 | |||
210 | |||
211 | class WebsocketsServer(object): | ||
212 | def __init__(self, host, port, handler, logger): | ||
213 | self.host = host | ||
214 | self.port = port | ||
215 | self.handler = handler | ||
216 | self.logger = logger | ||
217 | |||
218 | def start(self, loop): | ||
219 | import websockets.server | ||
220 | |||
221 | self.server = loop.run_until_complete( | ||
222 | websockets.server.serve( | ||
223 | self.client_handler, | ||
224 | self.host, | ||
225 | self.port, | ||
226 | ping_interval=None, | ||
227 | ) | ||
228 | ) | ||
229 | |||
230 | for s in self.server.sockets: | ||
231 | self.logger.debug("Listening on %r" % (s.getsockname(),)) | ||
232 | |||
233 | # Enable keep alives. This prevents broken client connections | ||
234 | # from persisting on the server for long periods of time. | ||
235 | s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) | ||
236 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) | ||
237 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15) | ||
238 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4) | ||
239 | |||
240 | name = self.server.sockets[0].getsockname() | ||
241 | if self.server.sockets[0].family == socket.AF_INET6: | ||
242 | self.address = "ws://[%s]:%d" % (name[0], name[1]) | ||
243 | else: | ||
244 | self.address = "ws://%s:%d" % (name[0], name[1]) | ||
245 | |||
246 | return [self.server.wait_closed()] | ||
247 | |||
248 | async def stop(self): | ||
249 | self.server.close() | ||
250 | |||
251 | def cleanup(self): | ||
252 | pass | ||
253 | |||
254 | async def client_handler(self, websocket): | ||
255 | socket = WebsocketConnection(websocket, -1) | ||
256 | await self.handler(socket) | ||
257 | |||
258 | |||
259 | class AsyncServer(object): | ||
260 | def __init__(self, logger): | ||
261 | self.logger = logger | ||
262 | self.loop = None | ||
263 | self.run_tasks = [] | ||
264 | |||
265 | def start_tcp_server(self, host, port): | ||
266 | self.server = TCPStreamServer(host, port, self._client_handler, self.logger) | ||
267 | |||
268 | def start_unix_server(self, path): | ||
269 | self.server = UnixStreamServer(path, self._client_handler, self.logger) | ||
270 | |||
271 | def start_websocket_server(self, host, port): | ||
272 | self.server = WebsocketsServer(host, port, self._client_handler, self.logger) | ||
273 | |||
274 | async def _client_handler(self, socket): | ||
275 | address = socket.address | ||
276 | try: | ||
277 | client = self.accept_client(socket) | ||
278 | await client.process_requests() | ||
279 | except Exception as e: | ||
280 | import traceback | ||
281 | |||
282 | self.logger.error( | ||
283 | "Error from client %s: %s" % (address, str(e)), exc_info=True | ||
284 | ) | ||
285 | traceback.print_exc() | ||
286 | finally: | ||
287 | self.logger.debug("Client %s disconnected", address) | ||
288 | await socket.close() | ||
289 | |||
290 | @abc.abstractmethod | ||
291 | def accept_client(self, socket): | ||
292 | pass | ||
293 | |||
294 | async def stop(self): | ||
295 | self.logger.debug("Stopping server") | ||
296 | await self.server.stop() | ||
297 | |||
298 | def start(self): | ||
299 | tasks = self.server.start(self.loop) | ||
300 | self.address = self.server.address | ||
301 | return tasks | ||
302 | |||
303 | def signal_handler(self): | ||
304 | self.logger.debug("Got exit signal") | ||
305 | self.loop.create_task(self.stop()) | ||
306 | |||
307 | def _serve_forever(self, tasks): | ||
308 | try: | ||
309 | self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) | ||
310 | self.loop.add_signal_handler(signal.SIGINT, self.signal_handler) | ||
311 | self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler) | ||
312 | signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) | ||
313 | |||
314 | self.loop.run_until_complete(asyncio.gather(*tasks)) | ||
315 | |||
316 | self.logger.debug("Server shutting down") | ||
317 | finally: | ||
318 | self.server.cleanup() | ||
319 | |||
320 | def serve_forever(self): | ||
321 | """ | ||
322 | Serve requests in the current process | ||
323 | """ | ||
324 | self._create_loop() | ||
325 | tasks = self.start() | ||
326 | self._serve_forever(tasks) | ||
327 | self.loop.close() | ||
328 | |||
329 | def _create_loop(self): | ||
330 | # Create loop and override any loop that may have existed in | ||
331 | # a parent process. It is possible that the usecases of | ||
332 | # serve_forever might be constrained enough to allow using | ||
333 | # get_event_loop here, but better safe than sorry for now. | ||
334 | self.loop = asyncio.new_event_loop() | ||
335 | asyncio.set_event_loop(self.loop) | ||
336 | |||
337 | def serve_as_process(self, *, prefunc=None, args=(), log_level=None): | ||
338 | """ | ||
339 | Serve requests in a child process | ||
340 | """ | ||
341 | |||
342 | def run(queue): | ||
343 | # Create loop and override any loop that may have existed | ||
344 | # in a parent process. Without doing this and instead | ||
345 | # using get_event_loop, at the very minimum the hashserv | ||
346 | # unit tests will hang when running the second test. | ||
347 | # This happens since get_event_loop in the spawned server | ||
348 | # process for the second testcase ends up with the loop | ||
349 | # from the hashserv client created in the unit test process | ||
350 | # when running the first testcase. The problem is somewhat | ||
351 | # more general, though, as any potential use of asyncio in | ||
352 | # Cooker could create a loop that needs to replaced in this | ||
353 | # new process. | ||
354 | self._create_loop() | ||
355 | try: | ||
356 | self.address = None | ||
357 | tasks = self.start() | ||
358 | finally: | ||
359 | # Always put the server address to wake up the parent task | ||
360 | queue.put(self.address) | ||
361 | queue.close() | ||
362 | |||
363 | if prefunc is not None: | ||
364 | prefunc(self, *args) | ||
365 | |||
366 | if log_level is not None: | ||
367 | self.logger.setLevel(log_level) | ||
368 | |||
369 | self._serve_forever(tasks) | ||
370 | |||
371 | if sys.version_info >= (3, 6): | ||
372 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | ||
373 | self.loop.close() | ||
374 | |||
375 | queue = multiprocessing.Queue() | ||
376 | |||
377 | # Temporarily block SIGTERM. The server process will inherit this | ||
378 | # block which will ensure it doesn't receive the SIGTERM until the | ||
379 | # handler is ready for it | ||
380 | mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM]) | ||
381 | try: | ||
382 | self.process = multiprocessing.Process(target=run, args=(queue,)) | ||
383 | self.process.start() | ||
384 | |||
385 | self.address = queue.get() | ||
386 | queue.close() | ||
387 | queue.join_thread() | ||
388 | |||
389 | return self.process | ||
390 | finally: | ||
391 | signal.pthread_sigmask(signal.SIG_SETMASK, mask) | ||