diff options
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/client.py')
-rw-r--r-- | bitbake/lib/bb/asyncrpc/client.py | 109 |
1 files changed, 31 insertions, 78 deletions
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py index a350b4fb12..9be49261c0 100644 --- a/bitbake/lib/bb/asyncrpc/client.py +++ b/bitbake/lib/bb/asyncrpc/client.py | |||
@@ -24,6 +24,12 @@ ADDR_TYPE_UNIX = 0 | |||
24 | ADDR_TYPE_TCP = 1 | 24 | ADDR_TYPE_TCP = 1 |
25 | ADDR_TYPE_WS = 2 | 25 | ADDR_TYPE_WS = 2 |
26 | 26 | ||
27 | WEBSOCKETS_MIN_VERSION = (9, 1) | ||
28 | # Need websockets 10 with python 3.10+ | ||
29 | if sys.version_info >= (3, 10, 0): | ||
30 | WEBSOCKETS_MIN_VERSION = (10, 0) | ||
31 | |||
32 | |||
27 | def parse_address(addr): | 33 | def parse_address(addr): |
28 | if addr.startswith(UNIX_PREFIX): | 34 | if addr.startswith(UNIX_PREFIX): |
29 | return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],)) | 35 | return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],)) |
@@ -39,6 +45,7 @@ def parse_address(addr): | |||
39 | 45 | ||
40 | return (ADDR_TYPE_TCP, (host, int(port))) | 46 | return (ADDR_TYPE_TCP, (host, int(port))) |
41 | 47 | ||
48 | |||
42 | class AsyncClient(object): | 49 | class AsyncClient(object): |
43 | def __init__( | 50 | def __init__( |
44 | self, | 51 | self, |
@@ -86,8 +93,30 @@ class AsyncClient(object): | |||
86 | async def connect_websocket(self, uri): | 93 | async def connect_websocket(self, uri): |
87 | import websockets | 94 | import websockets |
88 | 95 | ||
96 | try: | ||
97 | version = tuple( | ||
98 | int(v) | ||
99 | for v in websockets.__version__.split(".")[ | ||
100 | 0 : len(WEBSOCKETS_MIN_VERSION) | ||
101 | ] | ||
102 | ) | ||
103 | except ValueError: | ||
104 | raise ImportError( | ||
105 | f"Unable to parse websockets version '{websockets.__version__}'" | ||
106 | ) | ||
107 | |||
108 | if version < WEBSOCKETS_MIN_VERSION: | ||
109 | min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION) | ||
110 | raise ImportError( | ||
111 | f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}" | ||
112 | ) | ||
113 | |||
89 | async def connect_sock(): | 114 | async def connect_sock(): |
90 | websocket = await websockets.connect(uri, ping_interval=None) | 115 | websocket = await websockets.connect( |
116 | uri, | ||
117 | ping_interval=None, | ||
118 | open_timeout=self.timeout, | ||
119 | ) | ||
91 | return WebsocketConnection(websocket, self.timeout) | 120 | return WebsocketConnection(websocket, self.timeout) |
92 | 121 | ||
93 | self._connect_sock = connect_sock | 122 | self._connect_sock = connect_sock |
@@ -225,85 +254,9 @@ class Client(object): | |||
225 | def close(self): | 254 | def close(self): |
226 | if self.loop: | 255 | if self.loop: |
227 | self.loop.run_until_complete(self.client.close()) | 256 | self.loop.run_until_complete(self.client.close()) |
228 | if sys.version_info >= (3, 6): | ||
229 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | ||
230 | self.loop.close() | ||
231 | self.loop = None | ||
232 | |||
233 | def __enter__(self): | ||
234 | return self | ||
235 | |||
236 | def __exit__(self, exc_type, exc_value, traceback): | ||
237 | self.close() | ||
238 | return False | ||
239 | |||
240 | |||
241 | class ClientPool(object): | ||
242 | def __init__(self, max_clients): | ||
243 | self.avail_clients = [] | ||
244 | self.num_clients = 0 | ||
245 | self.max_clients = max_clients | ||
246 | self.loop = None | ||
247 | self.client_condition = None | ||
248 | |||
249 | @abc.abstractmethod | ||
250 | async def _new_client(self): | ||
251 | raise NotImplementedError("Must be implemented in derived class") | ||
252 | |||
253 | def close(self): | ||
254 | if self.client_condition: | ||
255 | self.client_condition = None | ||
256 | |||
257 | if self.loop: | ||
258 | self.loop.run_until_complete(self.__close_clients()) | ||
259 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | 257 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) |
260 | self.loop.close() | 258 | self.loop.close() |
261 | self.loop = None | 259 | self.loop = None |
262 | |||
263 | def run_tasks(self, tasks): | ||
264 | if not self.loop: | ||
265 | self.loop = asyncio.new_event_loop() | ||
266 | |||
267 | thread = Thread(target=self.__thread_main, args=(tasks,)) | ||
268 | thread.start() | ||
269 | thread.join() | ||
270 | |||
271 | @contextlib.asynccontextmanager | ||
272 | async def get_client(self): | ||
273 | async with self.client_condition: | ||
274 | if self.avail_clients: | ||
275 | client = self.avail_clients.pop() | ||
276 | elif self.num_clients < self.max_clients: | ||
277 | self.num_clients += 1 | ||
278 | client = await self._new_client() | ||
279 | else: | ||
280 | while not self.avail_clients: | ||
281 | await self.client_condition.wait() | ||
282 | client = self.avail_clients.pop() | ||
283 | |||
284 | try: | ||
285 | yield client | ||
286 | finally: | ||
287 | async with self.client_condition: | ||
288 | self.avail_clients.append(client) | ||
289 | self.client_condition.notify() | ||
290 | |||
291 | def __thread_main(self, tasks): | ||
292 | async def process_task(task): | ||
293 | async with self.get_client() as client: | ||
294 | await task(client) | ||
295 | |||
296 | asyncio.set_event_loop(self.loop) | ||
297 | if not self.client_condition: | ||
298 | self.client_condition = asyncio.Condition() | ||
299 | tasks = [process_task(t) for t in tasks] | ||
300 | self.loop.run_until_complete(asyncio.gather(*tasks)) | ||
301 | |||
302 | async def __close_clients(self): | ||
303 | for c in self.avail_clients: | ||
304 | await c.close() | ||
305 | self.avail_clients = [] | ||
306 | self.num_clients = 0 | ||
307 | 260 | ||
308 | def __enter__(self): | 261 | def __enter__(self): |
309 | return self | 262 | return self |