diff options
author | Joshua Watt <JPEWhacker@gmail.com> | 2023-11-03 08:26:19 -0600 |
---|---|---|
committer | Richard Purdie <richard.purdie@linuxfoundation.org> | 2023-11-09 17:33:02 +0000 |
commit | 8f8501ed403dec27acbe780b936bc087fc5006d0 (patch) | |
tree | 60e6415075c7c71eacec23ca7dda53e4a324b12e /bitbake/lib/bb/asyncrpc/serv.py | |
parent | f97b686884166dd77d1818e70615027c6ba8c348 (diff) | |
download | poky-8f8501ed403dec27acbe780b936bc087fc5006d0.tar.gz |
bitbake: asyncrpc: Abstract sockets
Rewrites the asyncrpc client and server code to make it possible to have
other transport backends that are not stream based (e.g. websockets
which are message based). The connection handling classes are now shared
between both the client and server to make it easier to implement new
transport mechanisms
(Bitbake rev: 2aaeae53696e4c2f13a169830c3b7089cbad6eca)
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/serv.py')
-rw-r--r-- | bitbake/lib/bb/asyncrpc/serv.py | 304 |
1 files changed, 156 insertions, 148 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py index d2de4891b8..3e0d0632cb 100644 --- a/bitbake/lib/bb/asyncrpc/serv.py +++ b/bitbake/lib/bb/asyncrpc/serv.py | |||
@@ -12,241 +12,248 @@ import signal | |||
12 | import socket | 12 | import socket |
13 | import sys | 13 | import sys |
14 | import multiprocessing | 14 | import multiprocessing |
15 | from . import chunkify, DEFAULT_MAX_CHUNK | 15 | from .connection import StreamConnection |
16 | 16 | from .exceptions import ClientError, ServerError, ConnectionClosedError | |
17 | |||
18 | class ClientError(Exception): | ||
19 | pass | ||
20 | |||
21 | |||
22 | class ServerError(Exception): | ||
23 | pass | ||
24 | 17 | ||
25 | 18 | ||
26 | class AsyncServerConnection(object): | 19 | class AsyncServerConnection(object): |
27 | def __init__(self, reader, writer, proto_name, logger): | 20 | # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no |
28 | self.reader = reader | 21 | # return message will be automatically be sent back to the client |
29 | self.writer = writer | 22 | NO_RESPONSE = object() |
23 | |||
24 | def __init__(self, socket, proto_name, logger): | ||
25 | self.socket = socket | ||
30 | self.proto_name = proto_name | 26 | self.proto_name = proto_name |
31 | self.max_chunk = DEFAULT_MAX_CHUNK | ||
32 | self.handlers = { | 27 | self.handlers = { |
33 | 'chunk-stream': self.handle_chunk, | 28 | "ping": self.handle_ping, |
34 | 'ping': self.handle_ping, | ||
35 | } | 29 | } |
36 | self.logger = logger | 30 | self.logger = logger |
37 | 31 | ||
32 | async def close(self): | ||
33 | await self.socket.close() | ||
34 | |||
38 | async def process_requests(self): | 35 | async def process_requests(self): |
39 | try: | 36 | try: |
40 | self.addr = self.writer.get_extra_info('peername') | 37 | self.logger.info("Client %r connected" % (self.socket.address,)) |
41 | self.logger.debug('Client %r connected' % (self.addr,)) | ||
42 | 38 | ||
43 | # Read protocol and version | 39 | # Read protocol and version |
44 | client_protocol = await self.reader.readline() | 40 | client_protocol = await self.socket.recv() |
45 | if not client_protocol: | 41 | if not client_protocol: |
46 | return | 42 | return |
47 | 43 | ||
48 | (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split() | 44 | (client_proto_name, client_proto_version) = client_protocol.split() |
49 | if client_proto_name != self.proto_name: | 45 | if client_proto_name != self.proto_name: |
50 | self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name)) | 46 | self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name)) |
51 | return | 47 | return |
52 | 48 | ||
53 | self.proto_version = tuple(int(v) for v in client_proto_version.split('.')) | 49 | self.proto_version = tuple(int(v) for v in client_proto_version.split(".")) |
54 | if not self.validate_proto_version(): | 50 | if not self.validate_proto_version(): |
55 | self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version)) | 51 | self.logger.debug( |
52 | "Rejecting invalid protocol version %s" % (client_proto_version) | ||
53 | ) | ||
56 | return | 54 | return |
57 | 55 | ||
58 | # Read headers. Currently, no headers are implemented, so look for | 56 | # Read headers. Currently, no headers are implemented, so look for |
59 | # an empty line to signal the end of the headers | 57 | # an empty line to signal the end of the headers |
60 | while True: | 58 | while True: |
61 | line = await self.reader.readline() | 59 | header = await self.socket.recv() |
62 | if not line: | 60 | if not header: |
63 | return | ||
64 | |||
65 | line = line.decode('utf-8').rstrip() | ||
66 | if not line: | ||
67 | break | 61 | break |
68 | 62 | ||
69 | # Handle messages | 63 | # Handle messages |
70 | while True: | 64 | while True: |
71 | d = await self.read_message() | 65 | d = await self.socket.recv_message() |
72 | if d is None: | 66 | if d is None: |
73 | break | 67 | break |
74 | await self.dispatch_message(d) | 68 | response = await self.dispatch_message(d) |
75 | await self.writer.drain() | 69 | if response is not self.NO_RESPONSE: |
76 | except ClientError as e: | 70 | await self.socket.send_message(response) |
71 | |||
72 | except ConnectionClosedError as e: | ||
73 | self.logger.info(str(e)) | ||
74 | except (ClientError, ConnectionError) as e: | ||
77 | self.logger.error(str(e)) | 75 | self.logger.error(str(e)) |
78 | finally: | 76 | finally: |
79 | self.writer.close() | 77 | await self.close() |
80 | 78 | ||
81 | async def dispatch_message(self, msg): | 79 | async def dispatch_message(self, msg): |
82 | for k in self.handlers.keys(): | 80 | for k in self.handlers.keys(): |
83 | if k in msg: | 81 | if k in msg: |
84 | self.logger.debug('Handling %s' % k) | 82 | self.logger.debug("Handling %s" % k) |
85 | await self.handlers[k](msg[k]) | 83 | return await self.handlers[k](msg[k]) |
86 | return | ||
87 | 84 | ||
88 | raise ClientError("Unrecognized command %r" % msg) | 85 | raise ClientError("Unrecognized command %r" % msg) |
89 | 86 | ||
90 | def write_message(self, msg): | 87 | async def handle_ping(self, request): |
91 | for c in chunkify(json.dumps(msg), self.max_chunk): | 88 | return {"alive": True} |
92 | self.writer.write(c.encode('utf-8')) | ||
93 | 89 | ||
94 | async def read_message(self): | ||
95 | l = await self.reader.readline() | ||
96 | if not l: | ||
97 | return None | ||
98 | 90 | ||
99 | try: | 91 | class StreamServer(object): |
100 | message = l.decode('utf-8') | 92 | def __init__(self, handler, logger): |
93 | self.handler = handler | ||
94 | self.logger = logger | ||
95 | self.closed = False | ||
101 | 96 | ||
102 | if not message.endswith('\n'): | 97 | async def handle_stream_client(self, reader, writer): |
103 | return None | 98 | # writer.transport.set_write_buffer_limits(0) |
99 | socket = StreamConnection(reader, writer, -1) | ||
100 | if self.closed: | ||
101 | await socket.close() | ||
102 | return | ||
103 | |||
104 | await self.handler(socket) | ||
105 | |||
106 | async def stop(self): | ||
107 | self.closed = True | ||
108 | |||
109 | |||
110 | class TCPStreamServer(StreamServer): | ||
111 | def __init__(self, host, port, handler, logger): | ||
112 | super().__init__(handler, logger) | ||
113 | self.host = host | ||
114 | self.port = port | ||
115 | |||
116 | def start(self, loop): | ||
117 | self.server = loop.run_until_complete( | ||
118 | asyncio.start_server(self.handle_stream_client, self.host, self.port) | ||
119 | ) | ||
120 | |||
121 | for s in self.server.sockets: | ||
122 | self.logger.debug("Listening on %r" % (s.getsockname(),)) | ||
123 | # Newer python does this automatically. Do it manually here for | ||
124 | # maximum compatibility | ||
125 | s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) | ||
126 | s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) | ||
127 | |||
128 | # Enable keep alives. This prevents broken client connections | ||
129 | # from persisting on the server for long periods of time. | ||
130 | s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) | ||
131 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) | ||
132 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15) | ||
133 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4) | ||
134 | |||
135 | name = self.server.sockets[0].getsockname() | ||
136 | if self.server.sockets[0].family == socket.AF_INET6: | ||
137 | self.address = "[%s]:%d" % (name[0], name[1]) | ||
138 | else: | ||
139 | self.address = "%s:%d" % (name[0], name[1]) | ||
140 | |||
141 | return [self.server.wait_closed()] | ||
142 | |||
143 | async def stop(self): | ||
144 | await super().stop() | ||
145 | self.server.close() | ||
146 | |||
147 | def cleanup(self): | ||
148 | pass | ||
104 | 149 | ||
105 | return json.loads(message) | ||
106 | except (json.JSONDecodeError, UnicodeDecodeError) as e: | ||
107 | self.logger.error('Bad message from client: %r' % message) | ||
108 | raise e | ||
109 | 150 | ||
110 | async def handle_chunk(self, request): | 151 | class UnixStreamServer(StreamServer): |
111 | lines = [] | 152 | def __init__(self, path, handler, logger): |
112 | try: | 153 | super().__init__(handler, logger) |
113 | while True: | 154 | self.path = path |
114 | l = await self.reader.readline() | ||
115 | l = l.rstrip(b"\n").decode("utf-8") | ||
116 | if not l: | ||
117 | break | ||
118 | lines.append(l) | ||
119 | 155 | ||
120 | msg = json.loads(''.join(lines)) | 156 | def start(self, loop): |
121 | except (json.JSONDecodeError, UnicodeDecodeError) as e: | 157 | cwd = os.getcwd() |
122 | self.logger.error('Bad message from client: %r' % lines) | 158 | try: |
123 | raise e | 159 | # Work around path length limits in AF_UNIX |
160 | os.chdir(os.path.dirname(self.path)) | ||
161 | self.server = loop.run_until_complete( | ||
162 | asyncio.start_unix_server( | ||
163 | self.handle_stream_client, os.path.basename(self.path) | ||
164 | ) | ||
165 | ) | ||
166 | finally: | ||
167 | os.chdir(cwd) | ||
124 | 168 | ||
125 | if 'chunk-stream' in msg: | 169 | self.logger.debug("Listening on %r" % self.path) |
126 | raise ClientError("Nested chunks are not allowed") | 170 | self.address = "unix://%s" % os.path.abspath(self.path) |
171 | return [self.server.wait_closed()] | ||
127 | 172 | ||
128 | await self.dispatch_message(msg) | 173 | async def stop(self): |
174 | await super().stop() | ||
175 | self.server.close() | ||
129 | 176 | ||
130 | async def handle_ping(self, request): | 177 | def cleanup(self): |
131 | response = {'alive': True} | 178 | os.unlink(self.path) |
132 | self.write_message(response) | ||
133 | 179 | ||
134 | 180 | ||
135 | class AsyncServer(object): | 181 | class AsyncServer(object): |
136 | def __init__(self, logger): | 182 | def __init__(self, logger): |
137 | self._cleanup_socket = None | ||
138 | self.logger = logger | 183 | self.logger = logger |
139 | self.start = None | ||
140 | self.address = None | ||
141 | self.loop = None | 184 | self.loop = None |
185 | self.run_tasks = [] | ||
142 | 186 | ||
143 | def start_tcp_server(self, host, port): | 187 | def start_tcp_server(self, host, port): |
144 | def start_tcp(): | 188 | self.server = TCPStreamServer(host, port, self._client_handler, self.logger) |
145 | self.server = self.loop.run_until_complete( | ||
146 | asyncio.start_server(self.handle_client, host, port) | ||
147 | ) | ||
148 | |||
149 | for s in self.server.sockets: | ||
150 | self.logger.debug('Listening on %r' % (s.getsockname(),)) | ||
151 | # Newer python does this automatically. Do it manually here for | ||
152 | # maximum compatibility | ||
153 | s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) | ||
154 | s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) | ||
155 | |||
156 | # Enable keep alives. This prevents broken client connections | ||
157 | # from persisting on the server for long periods of time. | ||
158 | s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) | ||
159 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) | ||
160 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15) | ||
161 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4) | ||
162 | |||
163 | name = self.server.sockets[0].getsockname() | ||
164 | if self.server.sockets[0].family == socket.AF_INET6: | ||
165 | self.address = "[%s]:%d" % (name[0], name[1]) | ||
166 | else: | ||
167 | self.address = "%s:%d" % (name[0], name[1]) | ||
168 | |||
169 | self.start = start_tcp | ||
170 | 189 | ||
171 | def start_unix_server(self, path): | 190 | def start_unix_server(self, path): |
172 | def cleanup(): | 191 | self.server = UnixStreamServer(path, self._client_handler, self.logger) |
173 | os.unlink(path) | ||
174 | |||
175 | def start_unix(): | ||
176 | cwd = os.getcwd() | ||
177 | try: | ||
178 | # Work around path length limits in AF_UNIX | ||
179 | os.chdir(os.path.dirname(path)) | ||
180 | self.server = self.loop.run_until_complete( | ||
181 | asyncio.start_unix_server(self.handle_client, os.path.basename(path)) | ||
182 | ) | ||
183 | finally: | ||
184 | os.chdir(cwd) | ||
185 | |||
186 | self.logger.debug('Listening on %r' % path) | ||
187 | 192 | ||
188 | self._cleanup_socket = cleanup | 193 | async def _client_handler(self, socket): |
189 | self.address = "unix://%s" % os.path.abspath(path) | ||
190 | |||
191 | self.start = start_unix | ||
192 | |||
193 | @abc.abstractmethod | ||
194 | def accept_client(self, reader, writer): | ||
195 | pass | ||
196 | |||
197 | async def handle_client(self, reader, writer): | ||
198 | # writer.transport.set_write_buffer_limits(0) | ||
199 | try: | 194 | try: |
200 | client = self.accept_client(reader, writer) | 195 | client = self.accept_client(socket) |
201 | await client.process_requests() | 196 | await client.process_requests() |
202 | except Exception as e: | 197 | except Exception as e: |
203 | import traceback | 198 | import traceback |
204 | self.logger.error('Error from client: %s' % str(e), exc_info=True) | 199 | |
200 | self.logger.error("Error from client: %s" % str(e), exc_info=True) | ||
205 | traceback.print_exc() | 201 | traceback.print_exc() |
206 | writer.close() | 202 | await socket.close() |
207 | self.logger.debug('Client disconnected') | 203 | self.logger.debug("Client disconnected") |
208 | 204 | ||
209 | def run_loop_forever(self): | 205 | @abc.abstractmethod |
210 | try: | 206 | def accept_client(self, socket): |
211 | self.loop.run_forever() | 207 | pass |
212 | except KeyboardInterrupt: | 208 | |
213 | pass | 209 | async def stop(self): |
210 | self.logger.debug("Stopping server") | ||
211 | await self.server.stop() | ||
212 | |||
213 | def start(self): | ||
214 | tasks = self.server.start(self.loop) | ||
215 | self.address = self.server.address | ||
216 | return tasks | ||
214 | 217 | ||
215 | def signal_handler(self): | 218 | def signal_handler(self): |
216 | self.logger.debug("Got exit signal") | 219 | self.logger.debug("Got exit signal") |
217 | self.loop.stop() | 220 | self.loop.create_task(self.stop()) |
218 | 221 | ||
219 | def _serve_forever(self): | 222 | def _serve_forever(self, tasks): |
220 | try: | 223 | try: |
221 | self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) | 224 | self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) |
225 | self.loop.add_signal_handler(signal.SIGINT, self.signal_handler) | ||
226 | self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler) | ||
222 | signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) | 227 | signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) |
223 | 228 | ||
224 | self.run_loop_forever() | 229 | self.loop.run_until_complete(asyncio.gather(*tasks)) |
225 | self.server.close() | ||
226 | 230 | ||
227 | self.loop.run_until_complete(self.server.wait_closed()) | 231 | self.logger.debug("Server shutting down") |
228 | self.logger.debug('Server shutting down') | ||
229 | finally: | 232 | finally: |
230 | if self._cleanup_socket is not None: | 233 | self.server.cleanup() |
231 | self._cleanup_socket() | ||
232 | 234 | ||
233 | def serve_forever(self): | 235 | def serve_forever(self): |
234 | """ | 236 | """ |
235 | Serve requests in the current process | 237 | Serve requests in the current process |
236 | """ | 238 | """ |
239 | self._create_loop() | ||
240 | tasks = self.start() | ||
241 | self._serve_forever(tasks) | ||
242 | self.loop.close() | ||
243 | |||
244 | def _create_loop(self): | ||
237 | # Create loop and override any loop that may have existed in | 245 | # Create loop and override any loop that may have existed in |
238 | # a parent process. It is possible that the usecases of | 246 | # a parent process. It is possible that the usecases of |
239 | # serve_forever might be constrained enough to allow using | 247 | # serve_forever might be constrained enough to allow using |
240 | # get_event_loop here, but better safe than sorry for now. | 248 | # get_event_loop here, but better safe than sorry for now. |
241 | self.loop = asyncio.new_event_loop() | 249 | self.loop = asyncio.new_event_loop() |
242 | asyncio.set_event_loop(self.loop) | 250 | asyncio.set_event_loop(self.loop) |
243 | self.start() | ||
244 | self._serve_forever() | ||
245 | 251 | ||
246 | def serve_as_process(self, *, prefunc=None, args=()): | 252 | def serve_as_process(self, *, prefunc=None, args=()): |
247 | """ | 253 | """ |
248 | Serve requests in a child process | 254 | Serve requests in a child process |
249 | """ | 255 | """ |
256 | |||
250 | def run(queue): | 257 | def run(queue): |
251 | # Create loop and override any loop that may have existed | 258 | # Create loop and override any loop that may have existed |
252 | # in a parent process. Without doing this and instead | 259 | # in a parent process. Without doing this and instead |
@@ -259,18 +266,19 @@ class AsyncServer(object): | |||
259 | # more general, though, as any potential use of asyncio in | 266 | # more general, though, as any potential use of asyncio in |
260 | # Cooker could create a loop that needs to replaced in this | 267 | # Cooker could create a loop that needs to replaced in this |
261 | # new process. | 268 | # new process. |
262 | self.loop = asyncio.new_event_loop() | 269 | self._create_loop() |
263 | asyncio.set_event_loop(self.loop) | ||
264 | try: | 270 | try: |
265 | self.start() | 271 | self.address = None |
272 | tasks = self.start() | ||
266 | finally: | 273 | finally: |
274 | # Always put the server address to wake up the parent task | ||
267 | queue.put(self.address) | 275 | queue.put(self.address) |
268 | queue.close() | 276 | queue.close() |
269 | 277 | ||
270 | if prefunc is not None: | 278 | if prefunc is not None: |
271 | prefunc(self, *args) | 279 | prefunc(self, *args) |
272 | 280 | ||
273 | self._serve_forever() | 281 | self._serve_forever(tasks) |
274 | 282 | ||
275 | if sys.version_info >= (3, 6): | 283 | if sys.version_info >= (3, 6): |
276 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | 284 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) |