diff options
Diffstat (limited to 'bitbake/lib/bb/asyncrpc')
-rw-r--r-- | bitbake/lib/bb/asyncrpc/__init__.py | 16 | ||||
-rw-r--r-- | bitbake/lib/bb/asyncrpc/client.py | 266 | ||||
-rw-r--r-- | bitbake/lib/bb/asyncrpc/connection.py | 146 | ||||
-rw-r--r-- | bitbake/lib/bb/asyncrpc/exceptions.py | 21 | ||||
-rw-r--r-- | bitbake/lib/bb/asyncrpc/serv.py | 410 |
5 files changed, 859 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/__init__.py b/bitbake/lib/bb/asyncrpc/__init__.py new file mode 100644 index 0000000000..a4371643d7 --- /dev/null +++ b/bitbake/lib/bb/asyncrpc/__init__.py | |||
@@ -0,0 +1,16 @@ | |||
1 | # | ||
2 | # Copyright BitBake Contributors | ||
3 | # | ||
4 | # SPDX-License-Identifier: GPL-2.0-only | ||
5 | # | ||
6 | |||
7 | |||
8 | from .client import AsyncClient, Client | ||
9 | from .serv import AsyncServer, AsyncServerConnection | ||
10 | from .connection import DEFAULT_MAX_CHUNK | ||
11 | from .exceptions import ( | ||
12 | ClientError, | ||
13 | ServerError, | ||
14 | ConnectionClosedError, | ||
15 | InvokeError, | ||
16 | ) | ||
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py new file mode 100644 index 0000000000..9be49261c0 --- /dev/null +++ b/bitbake/lib/bb/asyncrpc/client.py | |||
@@ -0,0 +1,266 @@ | |||
1 | # | ||
2 | # Copyright BitBake Contributors | ||
3 | # | ||
4 | # SPDX-License-Identifier: GPL-2.0-only | ||
5 | # | ||
6 | |||
7 | import abc | ||
8 | import asyncio | ||
9 | import json | ||
10 | import os | ||
11 | import socket | ||
12 | import sys | ||
13 | import re | ||
14 | import contextlib | ||
15 | from threading import Thread | ||
16 | from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK | ||
17 | from .exceptions import ConnectionClosedError, InvokeError | ||
18 | |||
UNIX_PREFIX = "unix://"
WS_PREFIX = "ws://"
WSS_PREFIX = "wss://"

ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2

WEBSOCKETS_MIN_VERSION = (9, 1)
# Need websockets 10 with python 3.10+
if sys.version_info >= (3, 10, 0):
    WEBSOCKETS_MIN_VERSION = (10, 0)


def parse_address(addr):
    """Classify an address string and return an (ADDR_TYPE_*, args) pair.

    ``unix://PATH`` yields ``(ADDR_TYPE_UNIX, (PATH,))``; ``ws://`` and
    ``wss://`` URIs are passed through as ``(ADDR_TYPE_WS, (uri,))``;
    anything else is treated as ``HOST:PORT`` (with ``[...]`` brackets for
    IPv6 hosts) and returned as ``(ADDR_TYPE_TCP, (host, port))``.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))

    if addr.startswith((WS_PREFIX, WSS_PREFIX)):
        return (ADDR_TYPE_WS, (addr,))

    bracketed = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
    if bracketed is not None:
        host = bracketed.group("host")
        port = bracketed.group("port")
    else:
        host, port = addr.split(":")

    return (ADDR_TYPE_TCP, (host, int(port)))
47 | |||
48 | |||
class AsyncClient(object):
    """Asynchronous RPC client for the line-header + JSON message protocol.

    A transport is chosen with one of the ``connect_tcp()``,
    ``connect_unix()`` or ``connect_websocket()`` methods, which install a
    connection factory; ``connect()`` then (re)establishes the connection on
    demand and performs the protocol handshake.
    """

    def __init__(
        self,
        proto_name,
        proto_version,
        logger,
        timeout=30,
        server_headers=False,
        headers=None,
    ):
        """
        Args:
            proto_name: protocol name announced to the server.
            proto_version: protocol version announced to the server.
            logger: logger used for connection warnings.
            timeout: per-operation timeout in seconds (negative disables it).
            server_headers: if True, ask the server to send headers back.
            headers: optional extra headers sent during connection setup.
        """
        self.socket = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger
        self.timeout = timeout
        self.needs_server_headers = server_headers
        self.server_headers = {}
        # Use None as the default sentinel: a `headers={}` default would be a
        # single dict object shared by every instance (mutable default
        # argument pitfall)
        self.headers = headers if headers is not None else {}

    async def connect_tcp(self, address, port):
        """Configure the client to connect over TCP."""
        async def connect_sock():
            reader, writer = await asyncio.open_connection(address, port)
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Configure the client to connect over a Unix domain socket."""
        async def connect_sock():
            # AF_UNIX has path length issues so chdir here to workaround
            cwd = os.getcwd()
            try:
                os.chdir(os.path.dirname(path))
                # The socket must be opened synchronously so that CWD doesn't get
                # changed out from underneath us so we pass as a sock into asyncio
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
                sock.connect(os.path.basename(path))
            finally:
                os.chdir(cwd)
            reader, writer = await asyncio.open_unix_connection(sock=sock)
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_websocket(self, uri):
        """Configure the client to connect over a websocket.

        Raises:
            ImportError: if the installed websockets module is too old or its
                version string cannot be parsed.
        """
        import websockets

        try:
            version = tuple(
                int(v)
                for v in websockets.__version__.split(".")[
                    0 : len(WEBSOCKETS_MIN_VERSION)
                ]
            )
        except ValueError:
            raise ImportError(
                f"Unable to parse websockets version '{websockets.__version__}'"
            )

        if version < WEBSOCKETS_MIN_VERSION:
            min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
            raise ImportError(
                f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
            )

        async def connect_sock():
            websocket = await websockets.connect(
                uri,
                ping_interval=None,
                open_timeout=self.timeout,
            )
            return WebsocketConnection(websocket, self.timeout)

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Perform the handshake: send our banner and headers, then read the
        server headers if we asked for them."""
        # Send headers
        await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
        await self.socket.send(
            "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
        )
        for k, v in self.headers.items():
            await self.socket.send("%s: %s" % (k, v))

        # End of headers
        await self.socket.send("")

        self.server_headers = {}
        if self.needs_server_headers:
            while True:
                line = await self.socket.recv()
                if not line:
                    # End headers
                    break
                tag, value = line.split(":", 1)
                self.server_headers[tag.lower()] = value.strip()

    async def get_header(self, tag, default):
        """Return a server-sent header value, connecting first if needed."""
        await self.connect()
        return self.server_headers.get(tag, default)

    async def connect(self):
        """Establish the connection if not already connected."""
        if self.socket is None:
            self.socket = await self._connect_sock()
            await self.setup_connection()

    async def disconnect(self):
        """Close the connection, if any."""
        if self.socket is not None:
            await self.socket.close()
            self.socket = None

    async def close(self):
        await self.disconnect()

    async def _send_wrapper(self, proc):
        # Run proc(), reconnecting and retrying up to 3 times on transport or
        # decode errors before giving up
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                ConnectionClosedError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s" % e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e))
                    raise e
                await self.close()
                count += 1

    def check_invoke_error(self, msg):
        """Raise InvokeError if *msg* is a server-reported invocation error."""
        if isinstance(msg, dict) and "invoke-error" in msg:
            raise InvokeError(msg["invoke-error"]["message"])

    async def invoke(self, msg):
        """Send *msg* and return the server's reply (with retry on errors).

        Raises:
            InvokeError: if the server reports an invocation error.
        """
        async def proc():
            await self.socket.send_message(msg)
            return await self.socket.recv_message()

        result = await self._send_wrapper(proc)
        self.check_invoke_error(result)
        return result

    async def ping(self):
        """Liveness check against the server."""
        return await self.invoke({"ping": {}})

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()
205 | |||
206 | |||
class Client(object):
    """Synchronous facade over an :class:`AsyncClient`.

    Owns a private event loop and exposes blocking versions of the async
    client's methods. Subclasses supply the wrapped client via
    ``_get_async_client()``.
    """

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        # Override any pre-existing loop.
        # Without this, the PR server export selftest triggers a hang
        # when running with Python 3.7. The drawback is that there is
        # potential for issues if the PR and hash equiv (or some new)
        # clients need to both be instantiated in the same process.
        # This should be revisited if/when Python 3.9 becomes the
        # minimum required version for BitBake, as it seems not
        # required (but harmless) with it.
        asyncio.set_event_loop(self.loop)

        self._add_methods("connect_tcp", "ping")

    @abc.abstractmethod
    def _get_async_client(self):
        """Subclass hook: return the AsyncClient instance to wrap."""
        pass

    def _get_downcall_wrapper(self, downcall):
        # Turn a coroutine function into a blocking call on our private loop
        def call_blocking(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return call_blocking

    def _add_methods(self, *methods):
        """Expose blocking wrappers for the named async-client methods."""
        for name in methods:
            coro_fn = getattr(self.client, name)
            setattr(self, name, self._get_downcall_wrapper(coro_fn))

    def connect_unix(self, path):
        """Connect over a Unix domain socket and complete the handshake."""
        for coro in (self.client.connect_unix(path), self.client.connect()):
            self.loop.run_until_complete(coro)

    @property
    def max_chunk(self):
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value

    def disconnect(self):
        """Close the connection but keep the loop usable for reconnects."""
        self.loop.run_until_complete(self.client.close())

    def close(self):
        """Close the client and permanently tear down the event loop."""
        if self.loop:
            self.loop.run_until_complete(self.client.close())
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()
        self.loop = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
diff --git a/bitbake/lib/bb/asyncrpc/connection.py b/bitbake/lib/bb/asyncrpc/connection.py new file mode 100644 index 0000000000..7f0cf6ba96 --- /dev/null +++ b/bitbake/lib/bb/asyncrpc/connection.py | |||
@@ -0,0 +1,146 @@ | |||
1 | # | ||
2 | # Copyright BitBake Contributors | ||
3 | # | ||
4 | # SPDX-License-Identifier: GPL-2.0-only | ||
5 | # | ||
6 | |||
7 | import asyncio | ||
8 | import itertools | ||
9 | import json | ||
10 | from datetime import datetime | ||
11 | from .exceptions import ClientError, ConnectionClosedError | ||
12 | |||
13 | |||
14 | # The Python async server defaults to a 64K receive buffer, so we hardcode our | ||
15 | # maximum chunk size. It would be better if the client and server reported to | ||
16 | # each other what the maximum chunk sizes were, but that will slow down the | ||
17 | # connection setup with a round trip delay so I'd rather not do that unless it | ||
18 | # is necessary | ||
19 | DEFAULT_MAX_CHUNK = 32 * 1024 | ||
20 | |||
21 | |||
def chunkify(msg, max_chunk):
    """Split the string *msg* into newline-terminated wire chunks.

    A message short enough to fit in one chunk is yielded directly with a
    trailing newline. Longer messages are announced with a
    ``{"chunk-stream": null}`` marker line, emitted in slices of
    ``max_chunk - 1`` characters (each newline-terminated), and finished
    with a bare newline as the end-of-stream terminator.
    """
    if len(msg) < max_chunk - 1:
        yield msg + "\n"
        return

    # Announce that a chunked stream follows
    yield json.dumps({"chunk-stream": None}) + "\n"

    step = max_chunk - 1
    for start in range(0, len(msg), step):
        yield msg[start : start + step] + "\n"
    # Empty line terminates the chunk stream
    yield "\n"
32 | |||
33 | |||
def json_serialize(obj):
    """``json.dumps`` *default=* hook: encode datetimes as ISO-8601 strings."""
    if not isinstance(obj, datetime):
        raise TypeError("Type %s not serializeable" % type(obj))
    return obj.isoformat()
38 | |||
39 | |||
class StreamConnection(object):
    """Newline-delimited JSON message framing over an asyncio stream pair.

    Each message is a JSON document terminated by a newline. Documents larger
    than ``max_chunk`` are sent as a ``{"chunk-stream": null}`` marker line
    followed by the JSON split across multiple lines and a terminating empty
    line (see ``chunkify()``).
    """

    def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
        self.reader = reader
        self.writer = writer
        # Seconds to wait in recv(); a negative value disables the timeout
        self.timeout = timeout
        self.max_chunk = max_chunk

    @property
    def address(self):
        """Peer address reported by the underlying transport."""
        return self.writer.get_extra_info("peername")

    async def send_message(self, msg):
        """JSON-encode *msg* and write it as one or more chunks."""
        for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
            self.writer.write(c.encode("utf-8"))
        await self.writer.drain()

    async def recv_message(self):
        """Read one JSON message, reassembling a chunked stream if present."""
        l = await self.recv()

        m = json.loads(l)
        if not m:
            return m

        if "chunk-stream" in m:
            # Chunked message: collect lines until the empty terminator, then
            # parse the reassembled document
            lines = []
            while True:
                l = await self.recv()
                if not l:
                    break
                lines.append(l)

            m = json.loads("".join(lines))

        return m

    async def send(self, msg):
        """Write a single newline-terminated line."""
        self.writer.write(("%s\n" % msg).encode("utf-8"))
        await self.writer.drain()

    async def recv(self):
        """Read one line, honoring ``self.timeout``.

        Returns the line with trailing whitespace stripped.

        Raises:
            ConnectionError: on timeout, or on a line missing its newline
                terminator (i.e. truncated by EOF mid-message).
            ConnectionClosedError: when the peer cleanly closes the stream.
        """
        if self.timeout < 0:
            line = await self.reader.readline()
        else:
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")

        if not line:
            # readline() returns an empty byte string at EOF
            raise ConnectionClosedError("Connection closed")

        line = line.decode("utf-8")

        if not line.endswith("\n"):
            raise ConnectionError("Bad message %r" % (line))

        return line.rstrip()

    async def close(self):
        self.reader = None
        if self.writer is not None:
            self.writer.close()
            self.writer = None
103 | |||
104 | |||
class WebsocketConnection(object):
    """JSON message framing over a (third-party) websocket connection."""

    def __init__(self, socket, timeout):
        self.socket = socket
        # Seconds to wait in recv(); a negative value disables the timeout
        self.timeout = timeout

    @property
    def address(self):
        """Peer address rendered as a colon-joined string."""
        return ":".join(map(str, self.socket.remote_address))

    async def send_message(self, msg):
        """JSON-encode *msg* and send it as a single websocket frame."""
        await self.send(json.dumps(msg, default=json_serialize))

    async def recv_message(self):
        """Receive one websocket frame and decode it as JSON."""
        payload = await self.recv()
        return json.loads(payload)

    async def send(self, msg):
        # Imported lazily so the websockets dependency is only needed when
        # this transport is actually used
        import websockets.exceptions

        try:
            await self.socket.send(msg)
        except websockets.exceptions.ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def recv(self):
        import websockets.exceptions

        try:
            if self.timeout < 0:
                return await self.socket.recv()

            try:
                return await asyncio.wait_for(self.socket.recv(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")
        except websockets.exceptions.ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def close(self):
        if self.socket is not None:
            await self.socket.close()
            self.socket = None
diff --git a/bitbake/lib/bb/asyncrpc/exceptions.py b/bitbake/lib/bb/asyncrpc/exceptions.py new file mode 100644 index 0000000000..ae1043a38b --- /dev/null +++ b/bitbake/lib/bb/asyncrpc/exceptions.py | |||
@@ -0,0 +1,21 @@ | |||
1 | # | ||
2 | # Copyright BitBake Contributors | ||
3 | # | ||
4 | # SPDX-License-Identifier: GPL-2.0-only | ||
5 | # | ||
6 | |||
7 | |||
class ClientError(Exception):
    """Raised for invalid client requests (e.g. an unrecognized command)."""

    pass
10 | |||
11 | |||
class InvokeError(Exception):
    """Raised when the server reports an 'invoke-error' for a command."""

    pass
14 | |||
15 | |||
class ServerError(Exception):
    # NOTE(review): not raised within this module — presumably reserved for
    # server-side failures reported by users of this package; confirm callers
    """Generic server-side error."""

    pass
18 | |||
19 | |||
class ConnectionClosedError(Exception):
    """Raised when the connection is closed while sending or receiving."""

    pass
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py new file mode 100644 index 0000000000..667217c5c1 --- /dev/null +++ b/bitbake/lib/bb/asyncrpc/serv.py | |||
@@ -0,0 +1,410 @@ | |||
1 | # | ||
2 | # Copyright BitBake Contributors | ||
3 | # | ||
4 | # SPDX-License-Identifier: GPL-2.0-only | ||
5 | # | ||
6 | |||
7 | import abc | ||
8 | import asyncio | ||
9 | import json | ||
10 | import os | ||
11 | import signal | ||
12 | import socket | ||
13 | import sys | ||
14 | import multiprocessing | ||
15 | import logging | ||
16 | from .connection import StreamConnection, WebsocketConnection | ||
17 | from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError | ||
18 | |||
19 | |||
class ClientLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes every message with the client address."""

    def process(self, msg, kwargs):
        prefix = "[Client {}]".format(self.extra["address"])
        return "{} {}".format(prefix, msg), kwargs
23 | |||
24 | |||
class AsyncServerConnection(object):
    """Server side of a single client connection.

    Reads the protocol banner and headers, then dispatches each incoming JSON
    message to the coroutine registered for its command in ``self.handlers``.
    Subclasses register extra handlers and must provide
    ``validate_proto_version()``.
    """

    # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
    # return message will be automatically be sent back to the client
    NO_RESPONSE = object()

    def __init__(self, socket, proto_name, logger):
        self.socket = socket
        self.proto_name = proto_name
        self.handlers = {
            "ping": self.handle_ping,
        }
        self.logger = ClientLoggerAdapter(
            logger,
            {
                "address": socket.address,
            },
        )
        self.client_headers = {}

    async def close(self):
        await self.socket.close()

    async def handle_headers(self, headers):
        """Subclass hook: return headers to send back to the client."""
        return {}

    async def process_requests(self):
        """Run the connection: handshake, then dispatch messages until EOF."""
        try:
            self.logger.info("Client %r connected" % (self.socket.address,))

            # Read protocol and version
            client_protocol = await self.socket.recv()
            if not client_protocol:
                return

            (client_proto_name, client_proto_version) = client_protocol.split()
            if client_proto_name != self.proto_name:
                # Log the protocol the client actually sent (previously this
                # logged self.proto_name, the server's own always-valid name)
                self.logger.debug(
                    "Rejecting invalid protocol %s" % (client_proto_name)
                )
                return

            # validate_proto_version() is provided by subclasses
            self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
            if not self.validate_proto_version():
                self.logger.debug(
                    "Rejecting invalid protocol version %s" % (client_proto_version)
                )
                return

            # Read headers
            self.client_headers = {}
            while True:
                header = await self.socket.recv()
                if not header:
                    # Empty line. End of headers
                    break
                tag, value = header.split(":", 1)
                self.client_headers[tag.lower()] = value.strip()

            if self.client_headers.get("needs-headers", "false") == "true":
                for k, v in (await self.handle_headers(self.client_headers)).items():
                    await self.socket.send("%s: %s" % (k, v))
                await self.socket.send("")

            # Handle messages
            while True:
                d = await self.socket.recv_message()
                if d is None:
                    break
                try:
                    response = await self.dispatch_message(d)
                except InvokeError as e:
                    # Report the failure to the client, then drop the
                    # connection
                    await self.socket.send_message(
                        {"invoke-error": {"message": str(e)}}
                    )
                    break

                if response is not self.NO_RESPONSE:
                    await self.socket.send_message(response)

        except ConnectionClosedError as e:
            # A closed connection is a normal way for a client to leave
            self.logger.info(str(e))
        except (ClientError, ConnectionError) as e:
            self.logger.error(str(e))
        finally:
            await self.close()

    async def dispatch_message(self, msg):
        """Invoke the handler whose command key appears in *msg*.

        Raises:
            ClientError: if no registered command key matches.
        """
        for k in self.handlers.keys():
            if k in msg:
                self.logger.debug("Handling %s" % k)
                return await self.handlers[k](msg[k])

        raise ClientError("Unrecognized command %r" % msg)

    async def handle_ping(self, request):
        """Liveness check: always reports alive."""
        return {"alive": True}
119 | |||
120 | |||
class StreamServer(object):
    """Base for stream-transport servers: wraps accepted connections and
    hands them to *handler*, refusing new clients once stopped."""

    def __init__(self, handler, logger):
        self.handler = handler
        self.logger = logger
        self.closed = False

    async def handle_stream_client(self, reader, writer):
        # writer.transport.set_write_buffer_limits(0)
        conn = StreamConnection(reader, writer, -1)
        if not self.closed:
            await self.handler(conn)
        else:
            # Shutting down: refuse the connection immediately
            await conn.close()

    async def stop(self):
        self.closed = True
138 | |||
139 | |||
class TCPStreamServer(StreamServer):
    """Stream server listening on a TCP host/port pair."""

    def __init__(self, host, port, handler, logger, *, reuseport=False):
        super().__init__(handler, logger)
        self.host = host
        self.port = port
        # Passed through to asyncio.start_server() as reuse_port
        self.reuseport = reuseport

    def start(self, loop):
        """Start listening and return the awaitables to wait on for shutdown.

        Also records the bound address ("host:port", IPv6 host in brackets)
        in ``self.address``.
        """
        self.server = loop.run_until_complete(
            asyncio.start_server(
                self.handle_stream_client,
                self.host,
                self.port,
                reuse_port=self.reuseport,
            )
        )

        for s in self.server.sockets:
            self.logger.debug("Listening on %r" % (s.getsockname(),))
            # Newer python does this automatically. Do it manually here for
            # maximum compatibility
            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            # NOTE(review): TCP_QUICKACK is Linux-specific — confirm whether
            # non-Linux hosts are expected to run this server
            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

            # Enable keep alives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "[%s]:%d" % (name[0], name[1])
        else:
            self.address = "%s:%d" % (name[0], name[1])

        return [self.server.wait_closed()]

    async def stop(self):
        # Mark closed first so in-flight accepts are refused, then close the
        # listening sockets
        await super().stop()
        self.server.close()

    def cleanup(self):
        # Nothing to clean up for TCP sockets
        pass
185 | |||
186 | |||
class UnixStreamServer(StreamServer):
    """Stream server bound to a Unix domain socket at a filesystem path."""

    def __init__(self, path, handler, logger):
        super().__init__(handler, logger)
        self.path = path

    def start(self, loop):
        """Bind the socket and return the awaitables to wait on for shutdown.

        Records the bound address as a ``unix://`` URI in ``self.address``.
        """
        saved_cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX by binding relative
            # to the socket's own directory
            os.chdir(os.path.dirname(self.path))
            self.server = loop.run_until_complete(
                asyncio.start_unix_server(
                    self.handle_stream_client, os.path.basename(self.path)
                )
            )
        finally:
            os.chdir(saved_cwd)

        self.logger.debug("Listening on %r" % self.path)
        self.address = "unix://%s" % os.path.abspath(self.path)
        return [self.server.wait_closed()]

    async def stop(self):
        # Refuse in-flight accepts, then close the listening socket
        await super().stop()
        self.server.close()

    def cleanup(self):
        # Remove the socket file left behind on the filesystem
        os.unlink(self.path)
215 | |||
216 | |||
class WebsocketsServer(object):
    """Server accepting RPC clients over websockets (third-party library)."""

    def __init__(self, host, port, handler, logger, *, reuseport=False):
        self.host = host
        self.port = port
        self.handler = handler
        self.logger = logger
        # Passed through to websockets.server.serve() as reuse_port
        self.reuseport = reuseport

    def start(self, loop):
        """Start the websocket listener and return the awaitables to wait on.

        Records the bound address as a ``ws://`` URI in ``self.address``.
        """
        # Imported lazily so the websockets dependency is only needed when
        # this transport is actually used
        import websockets.server

        self.server = loop.run_until_complete(
            websockets.server.serve(
                self.client_handler,
                self.host,
                self.port,
                ping_interval=None,
                reuse_port=self.reuseport,
            )
        )

        for s in self.server.sockets:
            self.logger.debug("Listening on %r" % (s.getsockname(),))

            # Enable keep alives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # NOTE(review): TCP_KEEPIDLE/KEEPINTVL/KEEPCNT are Linux-specific
            # — confirm whether non-Linux hosts are expected to run this
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "ws://[%s]:%d" % (name[0], name[1])
        else:
            self.address = "ws://%s:%d" % (name[0], name[1])

        return [self.server.wait_closed()]

    async def stop(self):
        self.server.close()

    def cleanup(self):
        # Nothing extra to clean up; teardown happens in stop()
        pass

    async def client_handler(self, websocket):
        # NOTE(review): the local name shadows the module-level `socket`
        # import (harmless within this coroutine). -1 disables the read
        # timeout for server-side connections.
        socket = WebsocketConnection(websocket, -1)
        await self.handler(socket)
265 | |||
266 | |||
class AsyncServer(object):
    """Base class for the asyncio RPC servers in this package.

    Choose a transport with one of the ``start_*_server()`` methods, then
    call ``serve_forever()`` (current process) or ``serve_as_process()``
    (child process). Subclasses implement ``accept_client()`` to wrap each
    accepted connection in an AsyncServerConnection subclass.
    """

    def __init__(self, logger):
        self.logger = logger
        self.loop = None
        # NOTE(review): appears unused within this module — confirm whether
        # external callers rely on it before removing
        self.run_tasks = []

    def start_tcp_server(self, host, port, *, reuseport=False):
        """Configure the server to listen on TCP host:port."""
        self.server = TCPStreamServer(
            host,
            port,
            self._client_handler,
            self.logger,
            reuseport=reuseport,
        )

    def start_unix_server(self, path):
        """Configure the server to listen on a Unix domain socket."""
        self.server = UnixStreamServer(path, self._client_handler, self.logger)

    def start_websocket_server(self, host, port, reuseport=False):
        """Configure the server to accept websocket connections."""
        self.server = WebsocketsServer(
            host,
            port,
            self._client_handler,
            self.logger,
            reuseport=reuseport,
        )

    async def _client_handler(self, socket):
        """Serve one client connection for its whole lifetime, always closing
        the socket on the way out."""
        address = socket.address
        try:
            client = self.accept_client(socket)
            await client.process_requests()
        except Exception as e:
            import traceback

            self.logger.error(
                "Error from client %s: %s" % (address, str(e)), exc_info=True
            )
            traceback.print_exc()
        finally:
            self.logger.debug("Client %s disconnected", address)
            await socket.close()

    @abc.abstractmethod
    def accept_client(self, socket):
        """Subclass hook: return a connection object for *socket* that
        implements process_requests()."""
        pass

    async def stop(self):
        """Ask the underlying transport server to shut down."""
        self.logger.debug("Stopping server")
        await self.server.stop()

    def start(self):
        """Start the transport server; returns the awaitables whose
        completion signals shutdown. Records the bound address."""
        tasks = self.server.start(self.loop)
        self.address = self.server.address
        return tasks

    def signal_handler(self):
        # Runs on SIGTERM/SIGINT/SIGQUIT; schedules an orderly shutdown on
        # the loop rather than stopping it abruptly
        self.logger.debug("Got exit signal")
        self.loop.create_task(self.stop())

    def _serve_forever(self, tasks):
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
            # Undo the SIGTERM block set up by serve_as_process() now that a
            # handler is installed
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.loop.run_until_complete(asyncio.gather(*tasks))

            self.logger.debug("Server shutting down")
        finally:
            self.server.cleanup()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        self._create_loop()
        tasks = self.start()
        self._serve_forever(tasks)
        self.loop.close()

    def _create_loop(self):
        # Create loop and override any loop that may have existed in
        # a parent process. It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
        """
        Serve requests in a child process
        """

        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process. Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase. The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to replaced in this
            # new process.
            self._create_loop()
            try:
                self.address = None
                tasks = self.start()
            finally:
                # Always put the server address to wake up the parent task
                queue.put(self.address)
                queue.close()

            # prefunc runs in the child, after the listening socket is bound
            # but before any requests are served
            if prefunc is not None:
                prefunc(self, *args)

            if log_level is not None:
                self.logger.setLevel(log_level)

            self._serve_forever(tasks)

            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            # Block until the child reports its bound address (None if
            # startup failed)
            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            # Restore the caller's signal mask
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)