diff options
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/client.py')
-rw-r--r-- | bitbake/lib/bb/asyncrpc/client.py | 271 |
1 files changed, 271 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py new file mode 100644 index 0000000000..17b72033b9 --- /dev/null +++ b/bitbake/lib/bb/asyncrpc/client.py | |||
@@ -0,0 +1,271 @@ | |||
1 | # | ||
2 | # Copyright BitBake Contributors | ||
3 | # | ||
4 | # SPDX-License-Identifier: GPL-2.0-only | ||
5 | # | ||
6 | |||
7 | import abc | ||
8 | import asyncio | ||
9 | import json | ||
10 | import os | ||
11 | import socket | ||
12 | import sys | ||
13 | import re | ||
14 | import contextlib | ||
15 | from threading import Thread | ||
16 | from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK | ||
17 | from .exceptions import ConnectionClosedError, InvokeError | ||
18 | |||
19 | UNIX_PREFIX = "unix://" | ||
20 | WS_PREFIX = "ws://" | ||
21 | WSS_PREFIX = "wss://" | ||
22 | |||
23 | ADDR_TYPE_UNIX = 0 | ||
24 | ADDR_TYPE_TCP = 1 | ||
25 | ADDR_TYPE_WS = 2 | ||
26 | |||
27 | WEBSOCKETS_MIN_VERSION = (9, 1) | ||
28 | # Need websockets 10 with python 3.10+ | ||
29 | if sys.version_info >= (3, 10, 0): | ||
30 | WEBSOCKETS_MIN_VERSION = (10, 0) | ||
31 | |||
32 | |||
33 | def parse_address(addr): | ||
34 | if addr.startswith(UNIX_PREFIX): | ||
35 | return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],)) | ||
36 | elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX): | ||
37 | return (ADDR_TYPE_WS, (addr,)) | ||
38 | else: | ||
39 | m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr) | ||
40 | if m is not None: | ||
41 | host = m.group("host") | ||
42 | port = m.group("port") | ||
43 | else: | ||
44 | host, port = addr.split(":") | ||
45 | |||
46 | return (ADDR_TYPE_TCP, (host, int(port))) | ||
47 | |||
48 | |||
49 | class AsyncClient(object): | ||
50 | def __init__( | ||
51 | self, | ||
52 | proto_name, | ||
53 | proto_version, | ||
54 | logger, | ||
55 | timeout=30, | ||
56 | server_headers=False, | ||
57 | headers={}, | ||
58 | ): | ||
59 | self.socket = None | ||
60 | self.max_chunk = DEFAULT_MAX_CHUNK | ||
61 | self.proto_name = proto_name | ||
62 | self.proto_version = proto_version | ||
63 | self.logger = logger | ||
64 | self.timeout = timeout | ||
65 | self.needs_server_headers = server_headers | ||
66 | self.server_headers = {} | ||
67 | self.headers = headers | ||
68 | |||
69 | async def connect_tcp(self, address, port): | ||
70 | async def connect_sock(): | ||
71 | reader, writer = await asyncio.open_connection(address, port) | ||
72 | return StreamConnection(reader, writer, self.timeout, self.max_chunk) | ||
73 | |||
74 | self._connect_sock = connect_sock | ||
75 | |||
76 | async def connect_unix(self, path): | ||
77 | async def connect_sock(): | ||
78 | # AF_UNIX has path length issues so chdir here to workaround | ||
79 | cwd = os.getcwd() | ||
80 | try: | ||
81 | os.chdir(os.path.dirname(path)) | ||
82 | # The socket must be opened synchronously so that CWD doesn't get | ||
83 | # changed out from underneath us so we pass as a sock into asyncio | ||
84 | sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) | ||
85 | sock.connect(os.path.basename(path)) | ||
86 | finally: | ||
87 | os.chdir(cwd) | ||
88 | reader, writer = await asyncio.open_unix_connection(sock=sock) | ||
89 | return StreamConnection(reader, writer, self.timeout, self.max_chunk) | ||
90 | |||
91 | self._connect_sock = connect_sock | ||
92 | |||
93 | async def connect_websocket(self, uri): | ||
94 | import websockets | ||
95 | |||
96 | try: | ||
97 | version = tuple( | ||
98 | int(v) | ||
99 | for v in websockets.__version__.split(".")[ | ||
100 | 0 : len(WEBSOCKETS_MIN_VERSION) | ||
101 | ] | ||
102 | ) | ||
103 | except ValueError: | ||
104 | raise ImportError( | ||
105 | f"Unable to parse websockets version '{websockets.__version__}'" | ||
106 | ) | ||
107 | |||
108 | if version < WEBSOCKETS_MIN_VERSION: | ||
109 | min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION) | ||
110 | raise ImportError( | ||
111 | f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}" | ||
112 | ) | ||
113 | |||
114 | async def connect_sock(): | ||
115 | try: | ||
116 | websocket = await websockets.connect( | ||
117 | uri, | ||
118 | ping_interval=None, | ||
119 | open_timeout=self.timeout, | ||
120 | ) | ||
121 | except asyncio.exceptions.TimeoutError: | ||
122 | raise ConnectionError("Timeout while connecting to websocket") | ||
123 | except (OSError, websockets.InvalidHandshake, websockets.InvalidURI) as exc: | ||
124 | raise ConnectionError(f"Could not connect to websocket: {exc}") from exc | ||
125 | return WebsocketConnection(websocket, self.timeout) | ||
126 | |||
127 | self._connect_sock = connect_sock | ||
128 | |||
129 | async def setup_connection(self): | ||
130 | # Send headers | ||
131 | await self.socket.send("%s %s" % (self.proto_name, self.proto_version)) | ||
132 | await self.socket.send( | ||
133 | "needs-headers: %s" % ("true" if self.needs_server_headers else "false") | ||
134 | ) | ||
135 | for k, v in self.headers.items(): | ||
136 | await self.socket.send("%s: %s" % (k, v)) | ||
137 | |||
138 | # End of headers | ||
139 | await self.socket.send("") | ||
140 | |||
141 | self.server_headers = {} | ||
142 | if self.needs_server_headers: | ||
143 | while True: | ||
144 | line = await self.socket.recv() | ||
145 | if not line: | ||
146 | # End headers | ||
147 | break | ||
148 | tag, value = line.split(":", 1) | ||
149 | self.server_headers[tag.lower()] = value.strip() | ||
150 | |||
151 | async def get_header(self, tag, default): | ||
152 | await self.connect() | ||
153 | return self.server_headers.get(tag, default) | ||
154 | |||
155 | async def connect(self): | ||
156 | if self.socket is None: | ||
157 | self.socket = await self._connect_sock() | ||
158 | await self.setup_connection() | ||
159 | |||
160 | async def disconnect(self): | ||
161 | if self.socket is not None: | ||
162 | await self.socket.close() | ||
163 | self.socket = None | ||
164 | |||
165 | async def close(self): | ||
166 | await self.disconnect() | ||
167 | |||
168 | async def _send_wrapper(self, proc): | ||
169 | count = 0 | ||
170 | while True: | ||
171 | try: | ||
172 | await self.connect() | ||
173 | return await proc() | ||
174 | except ( | ||
175 | OSError, | ||
176 | ConnectionError, | ||
177 | ConnectionClosedError, | ||
178 | json.JSONDecodeError, | ||
179 | UnicodeDecodeError, | ||
180 | ) as e: | ||
181 | self.logger.warning("Error talking to server: %s" % e) | ||
182 | if count >= 3: | ||
183 | if not isinstance(e, ConnectionError): | ||
184 | raise ConnectionError(str(e)) | ||
185 | raise e | ||
186 | await self.close() | ||
187 | count += 1 | ||
188 | |||
189 | def check_invoke_error(self, msg): | ||
190 | if isinstance(msg, dict) and "invoke-error" in msg: | ||
191 | raise InvokeError(msg["invoke-error"]["message"]) | ||
192 | |||
193 | async def invoke(self, msg): | ||
194 | async def proc(): | ||
195 | await self.socket.send_message(msg) | ||
196 | return await self.socket.recv_message() | ||
197 | |||
198 | result = await self._send_wrapper(proc) | ||
199 | self.check_invoke_error(result) | ||
200 | return result | ||
201 | |||
202 | async def ping(self): | ||
203 | return await self.invoke({"ping": {}}) | ||
204 | |||
205 | async def __aenter__(self): | ||
206 | return self | ||
207 | |||
208 | async def __aexit__(self, exc_type, exc_value, traceback): | ||
209 | await self.close() | ||
210 | |||
211 | |||
212 | class Client(object): | ||
213 | def __init__(self): | ||
214 | self.client = self._get_async_client() | ||
215 | self.loop = asyncio.new_event_loop() | ||
216 | |||
217 | # Override any pre-existing loop. | ||
218 | # Without this, the PR server export selftest triggers a hang | ||
219 | # when running with Python 3.7. The drawback is that there is | ||
220 | # potential for issues if the PR and hash equiv (or some new) | ||
221 | # clients need to both be instantiated in the same process. | ||
222 | # This should be revisited if/when Python 3.9 becomes the | ||
223 | # minimum required version for BitBake, as it seems not | ||
224 | # required (but harmless) with it. | ||
225 | asyncio.set_event_loop(self.loop) | ||
226 | |||
227 | self._add_methods("connect_tcp", "ping") | ||
228 | |||
229 | @abc.abstractmethod | ||
230 | def _get_async_client(self): | ||
231 | pass | ||
232 | |||
233 | def _get_downcall_wrapper(self, downcall): | ||
234 | def wrapper(*args, **kwargs): | ||
235 | return self.loop.run_until_complete(downcall(*args, **kwargs)) | ||
236 | |||
237 | return wrapper | ||
238 | |||
239 | def _add_methods(self, *methods): | ||
240 | for m in methods: | ||
241 | downcall = getattr(self.client, m) | ||
242 | setattr(self, m, self._get_downcall_wrapper(downcall)) | ||
243 | |||
244 | def connect_unix(self, path): | ||
245 | self.loop.run_until_complete(self.client.connect_unix(path)) | ||
246 | self.loop.run_until_complete(self.client.connect()) | ||
247 | |||
248 | @property | ||
249 | def max_chunk(self): | ||
250 | return self.client.max_chunk | ||
251 | |||
252 | @max_chunk.setter | ||
253 | def max_chunk(self, value): | ||
254 | self.client.max_chunk = value | ||
255 | |||
256 | def disconnect(self): | ||
257 | self.loop.run_until_complete(self.client.close()) | ||
258 | |||
259 | def close(self): | ||
260 | if self.loop: | ||
261 | self.loop.run_until_complete(self.client.close()) | ||
262 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | ||
263 | self.loop.close() | ||
264 | self.loop = None | ||
265 | |||
266 | def __enter__(self): | ||
267 | return self | ||
268 | |||
269 | def __exit__(self, exc_type, exc_value, traceback): | ||
270 | self.close() | ||
271 | return False | ||