diff options
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/client.py')
-rw-r--r-- | bitbake/lib/bb/asyncrpc/client.py | 337 |
1 files changed, 337 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py new file mode 100644 index 0000000000..b49de99313 --- /dev/null +++ b/bitbake/lib/bb/asyncrpc/client.py | |||
@@ -0,0 +1,337 @@ | |||
1 | # | ||
2 | # Copyright BitBake Contributors | ||
3 | # | ||
4 | # SPDX-License-Identifier: GPL-2.0-only | ||
5 | # | ||
6 | |||
7 | import abc | ||
8 | import asyncio | ||
9 | import json | ||
10 | import os | ||
11 | import socket | ||
12 | import sys | ||
13 | import re | ||
14 | import contextlib | ||
15 | from threading import Thread | ||
16 | from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK | ||
17 | from .exceptions import ConnectionClosedError, InvokeError | ||
18 | |||
19 | UNIX_PREFIX = "unix://" | ||
20 | WS_PREFIX = "ws://" | ||
21 | WSS_PREFIX = "wss://" | ||
22 | |||
23 | ADDR_TYPE_UNIX = 0 | ||
24 | ADDR_TYPE_TCP = 1 | ||
25 | ADDR_TYPE_WS = 2 | ||
26 | |||
27 | WEBSOCKETS_MIN_VERSION = (9, 1) | ||
28 | # Need websockets 10 with python 3.10+ | ||
29 | if sys.version_info >= (3, 10, 0): | ||
30 | WEBSOCKETS_MIN_VERSION = (10, 0) | ||
31 | |||
32 | def parse_address(addr): | ||
33 | if addr.startswith(UNIX_PREFIX): | ||
34 | return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],)) | ||
35 | elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX): | ||
36 | return (ADDR_TYPE_WS, (addr,)) | ||
37 | else: | ||
38 | m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr) | ||
39 | if m is not None: | ||
40 | host = m.group("host") | ||
41 | port = m.group("port") | ||
42 | else: | ||
43 | host, port = addr.split(":") | ||
44 | |||
45 | return (ADDR_TYPE_TCP, (host, int(port))) | ||
46 | |||
47 | |||
48 | class AsyncClient(object): | ||
49 | def __init__( | ||
50 | self, | ||
51 | proto_name, | ||
52 | proto_version, | ||
53 | logger, | ||
54 | timeout=30, | ||
55 | server_headers=False, | ||
56 | headers={}, | ||
57 | ): | ||
58 | self.socket = None | ||
59 | self.max_chunk = DEFAULT_MAX_CHUNK | ||
60 | self.proto_name = proto_name | ||
61 | self.proto_version = proto_version | ||
62 | self.logger = logger | ||
63 | self.timeout = timeout | ||
64 | self.needs_server_headers = server_headers | ||
65 | self.server_headers = {} | ||
66 | self.headers = headers | ||
67 | |||
68 | async def connect_tcp(self, address, port): | ||
69 | async def connect_sock(): | ||
70 | reader, writer = await asyncio.open_connection(address, port) | ||
71 | return StreamConnection(reader, writer, self.timeout, self.max_chunk) | ||
72 | |||
73 | self._connect_sock = connect_sock | ||
74 | |||
75 | async def connect_unix(self, path): | ||
76 | async def connect_sock(): | ||
77 | # AF_UNIX has path length issues so chdir here to workaround | ||
78 | cwd = os.getcwd() | ||
79 | try: | ||
80 | os.chdir(os.path.dirname(path)) | ||
81 | # The socket must be opened synchronously so that CWD doesn't get | ||
82 | # changed out from underneath us so we pass as a sock into asyncio | ||
83 | sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) | ||
84 | sock.connect(os.path.basename(path)) | ||
85 | finally: | ||
86 | os.chdir(cwd) | ||
87 | reader, writer = await asyncio.open_unix_connection(sock=sock) | ||
88 | return StreamConnection(reader, writer, self.timeout, self.max_chunk) | ||
89 | |||
90 | self._connect_sock = connect_sock | ||
91 | |||
92 | async def connect_websocket(self, uri): | ||
93 | import websockets | ||
94 | |||
95 | try: | ||
96 | version = tuple( | ||
97 | int(v) | ||
98 | for v in websockets.__version__.split(".")[ | ||
99 | 0 : len(WEBSOCKETS_MIN_VERSION) | ||
100 | ] | ||
101 | ) | ||
102 | except ValueError: | ||
103 | raise ImportError( | ||
104 | f"Unable to parse websockets version '{websockets.__version__}'" | ||
105 | ) | ||
106 | |||
107 | if version < WEBSOCKETS_MIN_VERSION: | ||
108 | min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION) | ||
109 | raise ImportError( | ||
110 | f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}" | ||
111 | ) | ||
112 | |||
113 | async def connect_sock(): | ||
114 | websocket = await websockets.connect(uri, ping_interval=None) | ||
115 | return WebsocketConnection(websocket, self.timeout) | ||
116 | |||
117 | self._connect_sock = connect_sock | ||
118 | |||
119 | async def setup_connection(self): | ||
120 | # Send headers | ||
121 | await self.socket.send("%s %s" % (self.proto_name, self.proto_version)) | ||
122 | await self.socket.send( | ||
123 | "needs-headers: %s" % ("true" if self.needs_server_headers else "false") | ||
124 | ) | ||
125 | for k, v in self.headers.items(): | ||
126 | await self.socket.send("%s: %s" % (k, v)) | ||
127 | |||
128 | # End of headers | ||
129 | await self.socket.send("") | ||
130 | |||
131 | self.server_headers = {} | ||
132 | if self.needs_server_headers: | ||
133 | while True: | ||
134 | line = await self.socket.recv() | ||
135 | if not line: | ||
136 | # End headers | ||
137 | break | ||
138 | tag, value = line.split(":", 1) | ||
139 | self.server_headers[tag.lower()] = value.strip() | ||
140 | |||
141 | async def get_header(self, tag, default): | ||
142 | await self.connect() | ||
143 | return self.server_headers.get(tag, default) | ||
144 | |||
145 | async def connect(self): | ||
146 | if self.socket is None: | ||
147 | self.socket = await self._connect_sock() | ||
148 | await self.setup_connection() | ||
149 | |||
150 | async def disconnect(self): | ||
151 | if self.socket is not None: | ||
152 | await self.socket.close() | ||
153 | self.socket = None | ||
154 | |||
155 | async def close(self): | ||
156 | await self.disconnect() | ||
157 | |||
158 | async def _send_wrapper(self, proc): | ||
159 | count = 0 | ||
160 | while True: | ||
161 | try: | ||
162 | await self.connect() | ||
163 | return await proc() | ||
164 | except ( | ||
165 | OSError, | ||
166 | ConnectionError, | ||
167 | ConnectionClosedError, | ||
168 | json.JSONDecodeError, | ||
169 | UnicodeDecodeError, | ||
170 | ) as e: | ||
171 | self.logger.warning("Error talking to server: %s" % e) | ||
172 | if count >= 3: | ||
173 | if not isinstance(e, ConnectionError): | ||
174 | raise ConnectionError(str(e)) | ||
175 | raise e | ||
176 | await self.close() | ||
177 | count += 1 | ||
178 | |||
179 | def check_invoke_error(self, msg): | ||
180 | if isinstance(msg, dict) and "invoke-error" in msg: | ||
181 | raise InvokeError(msg["invoke-error"]["message"]) | ||
182 | |||
183 | async def invoke(self, msg): | ||
184 | async def proc(): | ||
185 | await self.socket.send_message(msg) | ||
186 | return await self.socket.recv_message() | ||
187 | |||
188 | result = await self._send_wrapper(proc) | ||
189 | self.check_invoke_error(result) | ||
190 | return result | ||
191 | |||
192 | async def ping(self): | ||
193 | return await self.invoke({"ping": {}}) | ||
194 | |||
195 | async def __aenter__(self): | ||
196 | return self | ||
197 | |||
198 | async def __aexit__(self, exc_type, exc_value, traceback): | ||
199 | await self.close() | ||
200 | |||
201 | |||
202 | class Client(object): | ||
203 | def __init__(self): | ||
204 | self.client = self._get_async_client() | ||
205 | self.loop = asyncio.new_event_loop() | ||
206 | |||
207 | # Override any pre-existing loop. | ||
208 | # Without this, the PR server export selftest triggers a hang | ||
209 | # when running with Python 3.7. The drawback is that there is | ||
210 | # potential for issues if the PR and hash equiv (or some new) | ||
211 | # clients need to both be instantiated in the same process. | ||
212 | # This should be revisited if/when Python 3.9 becomes the | ||
213 | # minimum required version for BitBake, as it seems not | ||
214 | # required (but harmless) with it. | ||
215 | asyncio.set_event_loop(self.loop) | ||
216 | |||
217 | self._add_methods("connect_tcp", "ping") | ||
218 | |||
219 | @abc.abstractmethod | ||
220 | def _get_async_client(self): | ||
221 | pass | ||
222 | |||
223 | def _get_downcall_wrapper(self, downcall): | ||
224 | def wrapper(*args, **kwargs): | ||
225 | return self.loop.run_until_complete(downcall(*args, **kwargs)) | ||
226 | |||
227 | return wrapper | ||
228 | |||
229 | def _add_methods(self, *methods): | ||
230 | for m in methods: | ||
231 | downcall = getattr(self.client, m) | ||
232 | setattr(self, m, self._get_downcall_wrapper(downcall)) | ||
233 | |||
234 | def connect_unix(self, path): | ||
235 | self.loop.run_until_complete(self.client.connect_unix(path)) | ||
236 | self.loop.run_until_complete(self.client.connect()) | ||
237 | |||
238 | @property | ||
239 | def max_chunk(self): | ||
240 | return self.client.max_chunk | ||
241 | |||
242 | @max_chunk.setter | ||
243 | def max_chunk(self, value): | ||
244 | self.client.max_chunk = value | ||
245 | |||
246 | def disconnect(self): | ||
247 | self.loop.run_until_complete(self.client.close()) | ||
248 | |||
249 | def close(self): | ||
250 | if self.loop: | ||
251 | self.loop.run_until_complete(self.client.close()) | ||
252 | if sys.version_info >= (3, 6): | ||
253 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | ||
254 | self.loop.close() | ||
255 | self.loop = None | ||
256 | |||
257 | def __enter__(self): | ||
258 | return self | ||
259 | |||
260 | def __exit__(self, exc_type, exc_value, traceback): | ||
261 | self.close() | ||
262 | return False | ||
263 | |||
264 | |||
265 | class ClientPool(object): | ||
266 | def __init__(self, max_clients): | ||
267 | self.avail_clients = [] | ||
268 | self.num_clients = 0 | ||
269 | self.max_clients = max_clients | ||
270 | self.loop = None | ||
271 | self.client_condition = None | ||
272 | |||
273 | @abc.abstractmethod | ||
274 | async def _new_client(self): | ||
275 | raise NotImplementedError("Must be implemented in derived class") | ||
276 | |||
277 | def close(self): | ||
278 | if self.client_condition: | ||
279 | self.client_condition = None | ||
280 | |||
281 | if self.loop: | ||
282 | self.loop.run_until_complete(self.__close_clients()) | ||
283 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | ||
284 | self.loop.close() | ||
285 | self.loop = None | ||
286 | |||
287 | def run_tasks(self, tasks): | ||
288 | if not self.loop: | ||
289 | self.loop = asyncio.new_event_loop() | ||
290 | |||
291 | thread = Thread(target=self.__thread_main, args=(tasks,)) | ||
292 | thread.start() | ||
293 | thread.join() | ||
294 | |||
295 | @contextlib.asynccontextmanager | ||
296 | async def get_client(self): | ||
297 | async with self.client_condition: | ||
298 | if self.avail_clients: | ||
299 | client = self.avail_clients.pop() | ||
300 | elif self.num_clients < self.max_clients: | ||
301 | self.num_clients += 1 | ||
302 | client = await self._new_client() | ||
303 | else: | ||
304 | while not self.avail_clients: | ||
305 | await self.client_condition.wait() | ||
306 | client = self.avail_clients.pop() | ||
307 | |||
308 | try: | ||
309 | yield client | ||
310 | finally: | ||
311 | async with self.client_condition: | ||
312 | self.avail_clients.append(client) | ||
313 | self.client_condition.notify() | ||
314 | |||
315 | def __thread_main(self, tasks): | ||
316 | async def process_task(task): | ||
317 | async with self.get_client() as client: | ||
318 | await task(client) | ||
319 | |||
320 | asyncio.set_event_loop(self.loop) | ||
321 | if not self.client_condition: | ||
322 | self.client_condition = asyncio.Condition() | ||
323 | tasks = [process_task(t) for t in tasks] | ||
324 | self.loop.run_until_complete(asyncio.gather(*tasks)) | ||
325 | |||
326 | async def __close_clients(self): | ||
327 | for c in self.avail_clients: | ||
328 | await c.close() | ||
329 | self.avail_clients = [] | ||
330 | self.num_clients = 0 | ||
331 | |||
332 | def __enter__(self): | ||
333 | return self | ||
334 | |||
335 | def __exit__(self, exc_type, exc_value, traceback): | ||
336 | self.close() | ||
337 | return False | ||