summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/client.py')
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py137
1 files changed, 59 insertions, 78 deletions
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
index 29a5ab76aa..17b72033b9 100644
--- a/bitbake/lib/bb/asyncrpc/client.py
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -10,11 +10,41 @@ import json
10import os 10import os
11import socket 11import socket
12import sys 12import sys
13import re
13import contextlib 14import contextlib
14from threading import Thread 15from threading import Thread
15from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK 16from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
16from .exceptions import ConnectionClosedError, InvokeError 17from .exceptions import ConnectionClosedError, InvokeError
17 18
19UNIX_PREFIX = "unix://"
20WS_PREFIX = "ws://"
21WSS_PREFIX = "wss://"
22
23ADDR_TYPE_UNIX = 0
24ADDR_TYPE_TCP = 1
25ADDR_TYPE_WS = 2
26
27WEBSOCKETS_MIN_VERSION = (9, 1)
28# Need websockets 10 with python 3.10+
29if sys.version_info >= (3, 10, 0):
30 WEBSOCKETS_MIN_VERSION = (10, 0)
31
32
33def parse_address(addr):
34 if addr.startswith(UNIX_PREFIX):
35 return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
36 elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
37 return (ADDR_TYPE_WS, (addr,))
38 else:
39 m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
40 if m is not None:
41 host = m.group("host")
42 port = m.group("port")
43 else:
44 host, port = addr.split(":")
45
46 return (ADDR_TYPE_TCP, (host, int(port)))
47
18 48
19class AsyncClient(object): 49class AsyncClient(object):
20 def __init__( 50 def __init__(
@@ -63,8 +93,35 @@ class AsyncClient(object):
63 async def connect_websocket(self, uri): 93 async def connect_websocket(self, uri):
64 import websockets 94 import websockets
65 95
96 try:
97 version = tuple(
98 int(v)
99 for v in websockets.__version__.split(".")[
100 0 : len(WEBSOCKETS_MIN_VERSION)
101 ]
102 )
103 except ValueError:
104 raise ImportError(
105 f"Unable to parse websockets version '{websockets.__version__}'"
106 )
107
108 if version < WEBSOCKETS_MIN_VERSION:
109 min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
110 raise ImportError(
111 f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
112 )
113
66 async def connect_sock(): 114 async def connect_sock():
67 websocket = await websockets.connect(uri, ping_interval=None) 115 try:
116 websocket = await websockets.connect(
117 uri,
118 ping_interval=None,
119 open_timeout=self.timeout,
120 )
121 except asyncio.exceptions.TimeoutError:
122 raise ConnectionError("Timeout while connecting to websocket")
123 except (OSError, websockets.InvalidHandshake, websockets.InvalidURI) as exc:
124 raise ConnectionError(f"Could not connect to websocket: {exc}") from exc
68 return WebsocketConnection(websocket, self.timeout) 125 return WebsocketConnection(websocket, self.timeout)
69 126
70 self._connect_sock = connect_sock 127 self._connect_sock = connect_sock
@@ -202,85 +259,9 @@ class Client(object):
202 def close(self): 259 def close(self):
203 if self.loop: 260 if self.loop:
204 self.loop.run_until_complete(self.client.close()) 261 self.loop.run_until_complete(self.client.close())
205 if sys.version_info >= (3, 6):
206 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
207 self.loop.close()
208 self.loop = None
209
210 def __enter__(self):
211 return self
212
213 def __exit__(self, exc_type, exc_value, traceback):
214 self.close()
215 return False
216
217
218class ClientPool(object):
219 def __init__(self, max_clients):
220 self.avail_clients = []
221 self.num_clients = 0
222 self.max_clients = max_clients
223 self.loop = None
224 self.client_condition = None
225
226 @abc.abstractmethod
227 async def _new_client(self):
228 raise NotImplementedError("Must be implemented in derived class")
229
230 def close(self):
231 if self.client_condition:
232 self.client_condition = None
233
234 if self.loop:
235 self.loop.run_until_complete(self.__close_clients())
236 self.loop.run_until_complete(self.loop.shutdown_asyncgens()) 262 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
237 self.loop.close() 263 self.loop.close()
238 self.loop = None 264 self.loop = None
239
240 def run_tasks(self, tasks):
241 if not self.loop:
242 self.loop = asyncio.new_event_loop()
243
244 thread = Thread(target=self.__thread_main, args=(tasks,))
245 thread.start()
246 thread.join()
247
248 @contextlib.asynccontextmanager
249 async def get_client(self):
250 async with self.client_condition:
251 if self.avail_clients:
252 client = self.avail_clients.pop()
253 elif self.num_clients < self.max_clients:
254 self.num_clients += 1
255 client = await self._new_client()
256 else:
257 while not self.avail_clients:
258 await self.client_condition.wait()
259 client = self.avail_clients.pop()
260
261 try:
262 yield client
263 finally:
264 async with self.client_condition:
265 self.avail_clients.append(client)
266 self.client_condition.notify()
267
268 def __thread_main(self, tasks):
269 async def process_task(task):
270 async with self.get_client() as client:
271 await task(client)
272
273 asyncio.set_event_loop(self.loop)
274 if not self.client_condition:
275 self.client_condition = asyncio.Condition()
276 tasks = [process_task(t) for t in tasks]
277 self.loop.run_until_complete(asyncio.gather(*tasks))
278
279 async def __close_clients(self):
280 for c in self.avail_clients:
281 await c.close()
282 self.avail_clients = []
283 self.num_clients = 0
284 265
285 def __enter__(self): 266 def __enter__(self):
286 return self 267 return self