diff options
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/client.py')
-rw-r--r-- | bitbake/lib/bb/asyncrpc/client.py | 90 |
1 files changed, 12 insertions, 78 deletions
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py index b49de99313..17b72033b9 100644 --- a/bitbake/lib/bb/asyncrpc/client.py +++ b/bitbake/lib/bb/asyncrpc/client.py | |||
@@ -29,6 +29,7 @@ WEBSOCKETS_MIN_VERSION = (9, 1) | |||
29 | if sys.version_info >= (3, 10, 0): | 29 | if sys.version_info >= (3, 10, 0): |
30 | WEBSOCKETS_MIN_VERSION = (10, 0) | 30 | WEBSOCKETS_MIN_VERSION = (10, 0) |
31 | 31 | ||
32 | |||
32 | def parse_address(addr): | 33 | def parse_address(addr): |
33 | if addr.startswith(UNIX_PREFIX): | 34 | if addr.startswith(UNIX_PREFIX): |
34 | return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],)) | 35 | return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],)) |
@@ -111,7 +112,16 @@ class AsyncClient(object): | |||
111 | ) | 112 | ) |
112 | 113 | ||
113 | async def connect_sock(): | 114 | async def connect_sock(): |
114 | websocket = await websockets.connect(uri, ping_interval=None) | 115 | try: |
116 | websocket = await websockets.connect( | ||
117 | uri, | ||
118 | ping_interval=None, | ||
119 | open_timeout=self.timeout, | ||
120 | ) | ||
121 | except asyncio.exceptions.TimeoutError: | ||
122 | raise ConnectionError("Timeout while connecting to websocket") | ||
123 | except (OSError, websockets.InvalidHandshake, websockets.InvalidURI) as exc: | ||
124 | raise ConnectionError(f"Could not connect to websocket: {exc}") from exc | ||
115 | return WebsocketConnection(websocket, self.timeout) | 125 | return WebsocketConnection(websocket, self.timeout) |
116 | 126 | ||
117 | self._connect_sock = connect_sock | 127 | self._connect_sock = connect_sock |
@@ -249,85 +259,9 @@ class Client(object): | |||
249 | def close(self): | 259 | def close(self): |
250 | if self.loop: | 260 | if self.loop: |
251 | self.loop.run_until_complete(self.client.close()) | 261 | self.loop.run_until_complete(self.client.close()) |
252 | if sys.version_info >= (3, 6): | ||
253 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | ||
254 | self.loop.close() | ||
255 | self.loop = None | ||
256 | |||
257 | def __enter__(self): | ||
258 | return self | ||
259 | |||
260 | def __exit__(self, exc_type, exc_value, traceback): | ||
261 | self.close() | ||
262 | return False | ||
263 | |||
264 | |||
265 | class ClientPool(object): | ||
266 | def __init__(self, max_clients): | ||
267 | self.avail_clients = [] | ||
268 | self.num_clients = 0 | ||
269 | self.max_clients = max_clients | ||
270 | self.loop = None | ||
271 | self.client_condition = None | ||
272 | |||
273 | @abc.abstractmethod | ||
274 | async def _new_client(self): | ||
275 | raise NotImplementedError("Must be implemented in derived class") | ||
276 | |||
277 | def close(self): | ||
278 | if self.client_condition: | ||
279 | self.client_condition = None | ||
280 | |||
281 | if self.loop: | ||
282 | self.loop.run_until_complete(self.__close_clients()) | ||
283 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | 262 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) |
284 | self.loop.close() | 263 | self.loop.close() |
285 | self.loop = None | 264 | self.loop = None |
286 | |||
287 | def run_tasks(self, tasks): | ||
288 | if not self.loop: | ||
289 | self.loop = asyncio.new_event_loop() | ||
290 | |||
291 | thread = Thread(target=self.__thread_main, args=(tasks,)) | ||
292 | thread.start() | ||
293 | thread.join() | ||
294 | |||
295 | @contextlib.asynccontextmanager | ||
296 | async def get_client(self): | ||
297 | async with self.client_condition: | ||
298 | if self.avail_clients: | ||
299 | client = self.avail_clients.pop() | ||
300 | elif self.num_clients < self.max_clients: | ||
301 | self.num_clients += 1 | ||
302 | client = await self._new_client() | ||
303 | else: | ||
304 | while not self.avail_clients: | ||
305 | await self.client_condition.wait() | ||
306 | client = self.avail_clients.pop() | ||
307 | |||
308 | try: | ||
309 | yield client | ||
310 | finally: | ||
311 | async with self.client_condition: | ||
312 | self.avail_clients.append(client) | ||
313 | self.client_condition.notify() | ||
314 | |||
315 | def __thread_main(self, tasks): | ||
316 | async def process_task(task): | ||
317 | async with self.get_client() as client: | ||
318 | await task(client) | ||
319 | |||
320 | asyncio.set_event_loop(self.loop) | ||
321 | if not self.client_condition: | ||
322 | self.client_condition = asyncio.Condition() | ||
323 | tasks = [process_task(t) for t in tasks] | ||
324 | self.loop.run_until_complete(asyncio.gather(*tasks)) | ||
325 | |||
326 | async def __close_clients(self): | ||
327 | for c in self.avail_clients: | ||
328 | await c.close() | ||
329 | self.avail_clients = [] | ||
330 | self.num_clients = 0 | ||
331 | 265 | ||
332 | def __enter__(self): | 266 | def __enter__(self): |
333 | return self | 267 | return self |