diff options
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/client.py')
-rw-r--r-- | bitbake/lib/bb/asyncrpc/client.py | 137 |
1 files changed, 59 insertions, 78 deletions
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py index 29a5ab76aa..17b72033b9 100644 --- a/bitbake/lib/bb/asyncrpc/client.py +++ b/bitbake/lib/bb/asyncrpc/client.py | |||
@@ -10,11 +10,41 @@ import json | |||
10 | import os | 10 | import os |
11 | import socket | 11 | import socket |
12 | import sys | 12 | import sys |
13 | import re | ||
13 | import contextlib | 14 | import contextlib |
14 | from threading import Thread | 15 | from threading import Thread |
15 | from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK | 16 | from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK |
16 | from .exceptions import ConnectionClosedError, InvokeError | 17 | from .exceptions import ConnectionClosedError, InvokeError |
17 | 18 | ||
19 | UNIX_PREFIX = "unix://" | ||
20 | WS_PREFIX = "ws://" | ||
21 | WSS_PREFIX = "wss://" | ||
22 | |||
23 | ADDR_TYPE_UNIX = 0 | ||
24 | ADDR_TYPE_TCP = 1 | ||
25 | ADDR_TYPE_WS = 2 | ||
26 | |||
27 | WEBSOCKETS_MIN_VERSION = (9, 1) | ||
28 | # Need websockets 10 with python 3.10+ | ||
29 | if sys.version_info >= (3, 10, 0): | ||
30 | WEBSOCKETS_MIN_VERSION = (10, 0) | ||
31 | |||
32 | |||
33 | def parse_address(addr): | ||
34 | if addr.startswith(UNIX_PREFIX): | ||
35 | return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],)) | ||
36 | elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX): | ||
37 | return (ADDR_TYPE_WS, (addr,)) | ||
38 | else: | ||
39 | m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr) | ||
40 | if m is not None: | ||
41 | host = m.group("host") | ||
42 | port = m.group("port") | ||
43 | else: | ||
44 | host, port = addr.split(":") | ||
45 | |||
46 | return (ADDR_TYPE_TCP, (host, int(port))) | ||
47 | |||
18 | 48 | ||
19 | class AsyncClient(object): | 49 | class AsyncClient(object): |
20 | def __init__( | 50 | def __init__( |
@@ -63,8 +93,35 @@ class AsyncClient(object): | |||
63 | async def connect_websocket(self, uri): | 93 | async def connect_websocket(self, uri): |
64 | import websockets | 94 | import websockets |
65 | 95 | ||
96 | try: | ||
97 | version = tuple( | ||
98 | int(v) | ||
99 | for v in websockets.__version__.split(".")[ | ||
100 | 0 : len(WEBSOCKETS_MIN_VERSION) | ||
101 | ] | ||
102 | ) | ||
103 | except ValueError: | ||
104 | raise ImportError( | ||
105 | f"Unable to parse websockets version '{websockets.__version__}'" | ||
106 | ) | ||
107 | |||
108 | if version < WEBSOCKETS_MIN_VERSION: | ||
109 | min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION) | ||
110 | raise ImportError( | ||
111 | f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}" | ||
112 | ) | ||
113 | |||
66 | async def connect_sock(): | 114 | async def connect_sock(): |
67 | websocket = await websockets.connect(uri, ping_interval=None) | 115 | try: |
116 | websocket = await websockets.connect( | ||
117 | uri, | ||
118 | ping_interval=None, | ||
119 | open_timeout=self.timeout, | ||
120 | ) | ||
121 | except asyncio.exceptions.TimeoutError: | ||
122 | raise ConnectionError("Timeout while connecting to websocket") | ||
123 | except (OSError, websockets.InvalidHandshake, websockets.InvalidURI) as exc: | ||
124 | raise ConnectionError(f"Could not connect to websocket: {exc}") from exc | ||
68 | return WebsocketConnection(websocket, self.timeout) | 125 | return WebsocketConnection(websocket, self.timeout) |
69 | 126 | ||
70 | self._connect_sock = connect_sock | 127 | self._connect_sock = connect_sock |
@@ -202,85 +259,9 @@ class Client(object): | |||
202 | def close(self): | 259 | def close(self): |
203 | if self.loop: | 260 | if self.loop: |
204 | self.loop.run_until_complete(self.client.close()) | 261 | self.loop.run_until_complete(self.client.close()) |
205 | if sys.version_info >= (3, 6): | ||
206 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | ||
207 | self.loop.close() | ||
208 | self.loop = None | ||
209 | |||
210 | def __enter__(self): | ||
211 | return self | ||
212 | |||
213 | def __exit__(self, exc_type, exc_value, traceback): | ||
214 | self.close() | ||
215 | return False | ||
216 | |||
217 | |||
218 | class ClientPool(object): | ||
219 | def __init__(self, max_clients): | ||
220 | self.avail_clients = [] | ||
221 | self.num_clients = 0 | ||
222 | self.max_clients = max_clients | ||
223 | self.loop = None | ||
224 | self.client_condition = None | ||
225 | |||
226 | @abc.abstractmethod | ||
227 | async def _new_client(self): | ||
228 | raise NotImplementedError("Must be implemented in derived class") | ||
229 | |||
230 | def close(self): | ||
231 | if self.client_condition: | ||
232 | self.client_condition = None | ||
233 | |||
234 | if self.loop: | ||
235 | self.loop.run_until_complete(self.__close_clients()) | ||
236 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | 262 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) |
237 | self.loop.close() | 263 | self.loop.close() |
238 | self.loop = None | 264 | self.loop = None |
239 | |||
240 | def run_tasks(self, tasks): | ||
241 | if not self.loop: | ||
242 | self.loop = asyncio.new_event_loop() | ||
243 | |||
244 | thread = Thread(target=self.__thread_main, args=(tasks,)) | ||
245 | thread.start() | ||
246 | thread.join() | ||
247 | |||
248 | @contextlib.asynccontextmanager | ||
249 | async def get_client(self): | ||
250 | async with self.client_condition: | ||
251 | if self.avail_clients: | ||
252 | client = self.avail_clients.pop() | ||
253 | elif self.num_clients < self.max_clients: | ||
254 | self.num_clients += 1 | ||
255 | client = await self._new_client() | ||
256 | else: | ||
257 | while not self.avail_clients: | ||
258 | await self.client_condition.wait() | ||
259 | client = self.avail_clients.pop() | ||
260 | |||
261 | try: | ||
262 | yield client | ||
263 | finally: | ||
264 | async with self.client_condition: | ||
265 | self.avail_clients.append(client) | ||
266 | self.client_condition.notify() | ||
267 | |||
268 | def __thread_main(self, tasks): | ||
269 | async def process_task(task): | ||
270 | async with self.get_client() as client: | ||
271 | await task(client) | ||
272 | |||
273 | asyncio.set_event_loop(self.loop) | ||
274 | if not self.client_condition: | ||
275 | self.client_condition = asyncio.Condition() | ||
276 | tasks = [process_task(t) for t in tasks] | ||
277 | self.loop.run_until_complete(asyncio.gather(*tasks)) | ||
278 | |||
279 | async def __close_clients(self): | ||
280 | for c in self.avail_clients: | ||
281 | await c.close() | ||
282 | self.avail_clients = [] | ||
283 | self.num_clients = 0 | ||
284 | 265 | ||
285 | def __enter__(self): | 266 | def __enter__(self): |
286 | return self | 267 | return self |