summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc')
-rw-r--r--bitbake/lib/bb/asyncrpc/__init__.py2
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py90
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py37
3 files changed, 41 insertions, 88 deletions
diff --git a/bitbake/lib/bb/asyncrpc/__init__.py b/bitbake/lib/bb/asyncrpc/__init__.py
index 639e1607f8..a4371643d7 100644
--- a/bitbake/lib/bb/asyncrpc/__init__.py
+++ b/bitbake/lib/bb/asyncrpc/__init__.py
@@ -5,7 +5,7 @@
5# 5#
6 6
7 7
8from .client import AsyncClient, Client, ClientPool 8from .client import AsyncClient, Client
9from .serv import AsyncServer, AsyncServerConnection 9from .serv import AsyncServer, AsyncServerConnection
10from .connection import DEFAULT_MAX_CHUNK 10from .connection import DEFAULT_MAX_CHUNK
11from .exceptions import ( 11from .exceptions import (
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
index b49de99313..17b72033b9 100644
--- a/bitbake/lib/bb/asyncrpc/client.py
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -29,6 +29,7 @@ WEBSOCKETS_MIN_VERSION = (9, 1)
29if sys.version_info >= (3, 10, 0): 29if sys.version_info >= (3, 10, 0):
30 WEBSOCKETS_MIN_VERSION = (10, 0) 30 WEBSOCKETS_MIN_VERSION = (10, 0)
31 31
32
32def parse_address(addr): 33def parse_address(addr):
33 if addr.startswith(UNIX_PREFIX): 34 if addr.startswith(UNIX_PREFIX):
34 return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],)) 35 return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
@@ -111,7 +112,16 @@ class AsyncClient(object):
111 ) 112 )
112 113
113 async def connect_sock(): 114 async def connect_sock():
114 websocket = await websockets.connect(uri, ping_interval=None) 115 try:
116 websocket = await websockets.connect(
117 uri,
118 ping_interval=None,
119 open_timeout=self.timeout,
120 )
121 except asyncio.exceptions.TimeoutError:
122 raise ConnectionError("Timeout while connecting to websocket")
123 except (OSError, websockets.InvalidHandshake, websockets.InvalidURI) as exc:
124 raise ConnectionError(f"Could not connect to websocket: {exc}") from exc
115 return WebsocketConnection(websocket, self.timeout) 125 return WebsocketConnection(websocket, self.timeout)
116 126
117 self._connect_sock = connect_sock 127 self._connect_sock = connect_sock
@@ -249,85 +259,9 @@ class Client(object):
249 def close(self): 259 def close(self):
250 if self.loop: 260 if self.loop:
251 self.loop.run_until_complete(self.client.close()) 261 self.loop.run_until_complete(self.client.close())
252 if sys.version_info >= (3, 6):
253 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
254 self.loop.close()
255 self.loop = None
256
257 def __enter__(self):
258 return self
259
260 def __exit__(self, exc_type, exc_value, traceback):
261 self.close()
262 return False
263
264
265class ClientPool(object):
266 def __init__(self, max_clients):
267 self.avail_clients = []
268 self.num_clients = 0
269 self.max_clients = max_clients
270 self.loop = None
271 self.client_condition = None
272
273 @abc.abstractmethod
274 async def _new_client(self):
275 raise NotImplementedError("Must be implemented in derived class")
276
277 def close(self):
278 if self.client_condition:
279 self.client_condition = None
280
281 if self.loop:
282 self.loop.run_until_complete(self.__close_clients())
283 self.loop.run_until_complete(self.loop.shutdown_asyncgens()) 262 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
284 self.loop.close() 263 self.loop.close()
285 self.loop = None 264 self.loop = None
286
287 def run_tasks(self, tasks):
288 if not self.loop:
289 self.loop = asyncio.new_event_loop()
290
291 thread = Thread(target=self.__thread_main, args=(tasks,))
292 thread.start()
293 thread.join()
294
295 @contextlib.asynccontextmanager
296 async def get_client(self):
297 async with self.client_condition:
298 if self.avail_clients:
299 client = self.avail_clients.pop()
300 elif self.num_clients < self.max_clients:
301 self.num_clients += 1
302 client = await self._new_client()
303 else:
304 while not self.avail_clients:
305 await self.client_condition.wait()
306 client = self.avail_clients.pop()
307
308 try:
309 yield client
310 finally:
311 async with self.client_condition:
312 self.avail_clients.append(client)
313 self.client_condition.notify()
314
315 def __thread_main(self, tasks):
316 async def process_task(task):
317 async with self.get_client() as client:
318 await task(client)
319
320 asyncio.set_event_loop(self.loop)
321 if not self.client_condition:
322 self.client_condition = asyncio.Condition()
323 tasks = [process_task(t) for t in tasks]
324 self.loop.run_until_complete(asyncio.gather(*tasks))
325
326 async def __close_clients(self):
327 for c in self.avail_clients:
328 await c.close()
329 self.avail_clients = []
330 self.num_clients = 0
331 265
332 def __enter__(self): 266 def __enter__(self):
333 return self 267 return self
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
index a66117acad..667217c5c1 100644
--- a/bitbake/lib/bb/asyncrpc/serv.py
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -138,14 +138,20 @@ class StreamServer(object):
138 138
139 139
140class TCPStreamServer(StreamServer): 140class TCPStreamServer(StreamServer):
141 def __init__(self, host, port, handler, logger): 141 def __init__(self, host, port, handler, logger, *, reuseport=False):
142 super().__init__(handler, logger) 142 super().__init__(handler, logger)
143 self.host = host 143 self.host = host
144 self.port = port 144 self.port = port
145 self.reuseport = reuseport
145 146
146 def start(self, loop): 147 def start(self, loop):
147 self.server = loop.run_until_complete( 148 self.server = loop.run_until_complete(
148 asyncio.start_server(self.handle_stream_client, self.host, self.port) 149 asyncio.start_server(
150 self.handle_stream_client,
151 self.host,
152 self.port,
153 reuse_port=self.reuseport,
154 )
149 ) 155 )
150 156
151 for s in self.server.sockets: 157 for s in self.server.sockets:
@@ -209,11 +215,12 @@ class UnixStreamServer(StreamServer):
209 215
210 216
211class WebsocketsServer(object): 217class WebsocketsServer(object):
212 def __init__(self, host, port, handler, logger): 218 def __init__(self, host, port, handler, logger, *, reuseport=False):
213 self.host = host 219 self.host = host
214 self.port = port 220 self.port = port
215 self.handler = handler 221 self.handler = handler
216 self.logger = logger 222 self.logger = logger
223 self.reuseport = reuseport
217 224
218 def start(self, loop): 225 def start(self, loop):
219 import websockets.server 226 import websockets.server
@@ -224,6 +231,7 @@ class WebsocketsServer(object):
224 self.host, 231 self.host,
225 self.port, 232 self.port,
226 ping_interval=None, 233 ping_interval=None,
234 reuse_port=self.reuseport,
227 ) 235 )
228 ) 236 )
229 237
@@ -262,14 +270,26 @@ class AsyncServer(object):
262 self.loop = None 270 self.loop = None
263 self.run_tasks = [] 271 self.run_tasks = []
264 272
265 def start_tcp_server(self, host, port): 273 def start_tcp_server(self, host, port, *, reuseport=False):
266 self.server = TCPStreamServer(host, port, self._client_handler, self.logger) 274 self.server = TCPStreamServer(
275 host,
276 port,
277 self._client_handler,
278 self.logger,
279 reuseport=reuseport,
280 )
267 281
268 def start_unix_server(self, path): 282 def start_unix_server(self, path):
269 self.server = UnixStreamServer(path, self._client_handler, self.logger) 283 self.server = UnixStreamServer(path, self._client_handler, self.logger)
270 284
271 def start_websocket_server(self, host, port): 285 def start_websocket_server(self, host, port, reuseport=False):
272 self.server = WebsocketsServer(host, port, self._client_handler, self.logger) 286 self.server = WebsocketsServer(
287 host,
288 port,
289 self._client_handler,
290 self.logger,
291 reuseport=reuseport,
292 )
273 293
274 async def _client_handler(self, socket): 294 async def _client_handler(self, socket):
275 address = socket.address 295 address = socket.address
@@ -368,8 +388,7 @@ class AsyncServer(object):
368 388
369 self._serve_forever(tasks) 389 self._serve_forever(tasks)
370 390
371 if sys.version_info >= (3, 6): 391 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
372 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
373 self.loop.close() 392 self.loop.close()
374 393
375 queue = multiprocessing.Queue() 394 queue = multiprocessing.Queue()