summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/client.py')
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py337
1 files changed, 337 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
new file mode 100644
index 0000000000..b49de99313
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -0,0 +1,337 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import socket
12import sys
13import re
14import contextlib
15from threading import Thread
16from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
17from .exceptions import ConnectionClosedError, InvokeError
18
19UNIX_PREFIX = "unix://"
20WS_PREFIX = "ws://"
21WSS_PREFIX = "wss://"
22
23ADDR_TYPE_UNIX = 0
24ADDR_TYPE_TCP = 1
25ADDR_TYPE_WS = 2
26
27WEBSOCKETS_MIN_VERSION = (9, 1)
28# Need websockets 10 with python 3.10+
29if sys.version_info >= (3, 10, 0):
30 WEBSOCKETS_MIN_VERSION = (10, 0)
31
32def parse_address(addr):
33 if addr.startswith(UNIX_PREFIX):
34 return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
35 elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
36 return (ADDR_TYPE_WS, (addr,))
37 else:
38 m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
39 if m is not None:
40 host = m.group("host")
41 port = m.group("port")
42 else:
43 host, port = addr.split(":")
44
45 return (ADDR_TYPE_TCP, (host, int(port)))
46
47
48class AsyncClient(object):
49 def __init__(
50 self,
51 proto_name,
52 proto_version,
53 logger,
54 timeout=30,
55 server_headers=False,
56 headers={},
57 ):
58 self.socket = None
59 self.max_chunk = DEFAULT_MAX_CHUNK
60 self.proto_name = proto_name
61 self.proto_version = proto_version
62 self.logger = logger
63 self.timeout = timeout
64 self.needs_server_headers = server_headers
65 self.server_headers = {}
66 self.headers = headers
67
68 async def connect_tcp(self, address, port):
69 async def connect_sock():
70 reader, writer = await asyncio.open_connection(address, port)
71 return StreamConnection(reader, writer, self.timeout, self.max_chunk)
72
73 self._connect_sock = connect_sock
74
75 async def connect_unix(self, path):
76 async def connect_sock():
77 # AF_UNIX has path length issues so chdir here to workaround
78 cwd = os.getcwd()
79 try:
80 os.chdir(os.path.dirname(path))
81 # The socket must be opened synchronously so that CWD doesn't get
82 # changed out from underneath us so we pass as a sock into asyncio
83 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
84 sock.connect(os.path.basename(path))
85 finally:
86 os.chdir(cwd)
87 reader, writer = await asyncio.open_unix_connection(sock=sock)
88 return StreamConnection(reader, writer, self.timeout, self.max_chunk)
89
90 self._connect_sock = connect_sock
91
92 async def connect_websocket(self, uri):
93 import websockets
94
95 try:
96 version = tuple(
97 int(v)
98 for v in websockets.__version__.split(".")[
99 0 : len(WEBSOCKETS_MIN_VERSION)
100 ]
101 )
102 except ValueError:
103 raise ImportError(
104 f"Unable to parse websockets version '{websockets.__version__}'"
105 )
106
107 if version < WEBSOCKETS_MIN_VERSION:
108 min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
109 raise ImportError(
110 f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
111 )
112
113 async def connect_sock():
114 websocket = await websockets.connect(uri, ping_interval=None)
115 return WebsocketConnection(websocket, self.timeout)
116
117 self._connect_sock = connect_sock
118
119 async def setup_connection(self):
120 # Send headers
121 await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
122 await self.socket.send(
123 "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
124 )
125 for k, v in self.headers.items():
126 await self.socket.send("%s: %s" % (k, v))
127
128 # End of headers
129 await self.socket.send("")
130
131 self.server_headers = {}
132 if self.needs_server_headers:
133 while True:
134 line = await self.socket.recv()
135 if not line:
136 # End headers
137 break
138 tag, value = line.split(":", 1)
139 self.server_headers[tag.lower()] = value.strip()
140
141 async def get_header(self, tag, default):
142 await self.connect()
143 return self.server_headers.get(tag, default)
144
145 async def connect(self):
146 if self.socket is None:
147 self.socket = await self._connect_sock()
148 await self.setup_connection()
149
150 async def disconnect(self):
151 if self.socket is not None:
152 await self.socket.close()
153 self.socket = None
154
155 async def close(self):
156 await self.disconnect()
157
158 async def _send_wrapper(self, proc):
159 count = 0
160 while True:
161 try:
162 await self.connect()
163 return await proc()
164 except (
165 OSError,
166 ConnectionError,
167 ConnectionClosedError,
168 json.JSONDecodeError,
169 UnicodeDecodeError,
170 ) as e:
171 self.logger.warning("Error talking to server: %s" % e)
172 if count >= 3:
173 if not isinstance(e, ConnectionError):
174 raise ConnectionError(str(e))
175 raise e
176 await self.close()
177 count += 1
178
179 def check_invoke_error(self, msg):
180 if isinstance(msg, dict) and "invoke-error" in msg:
181 raise InvokeError(msg["invoke-error"]["message"])
182
183 async def invoke(self, msg):
184 async def proc():
185 await self.socket.send_message(msg)
186 return await self.socket.recv_message()
187
188 result = await self._send_wrapper(proc)
189 self.check_invoke_error(result)
190 return result
191
192 async def ping(self):
193 return await self.invoke({"ping": {}})
194
195 async def __aenter__(self):
196 return self
197
198 async def __aexit__(self, exc_type, exc_value, traceback):
199 await self.close()
200
201
202class Client(object):
203 def __init__(self):
204 self.client = self._get_async_client()
205 self.loop = asyncio.new_event_loop()
206
207 # Override any pre-existing loop.
208 # Without this, the PR server export selftest triggers a hang
209 # when running with Python 3.7. The drawback is that there is
210 # potential for issues if the PR and hash equiv (or some new)
211 # clients need to both be instantiated in the same process.
212 # This should be revisited if/when Python 3.9 becomes the
213 # minimum required version for BitBake, as it seems not
214 # required (but harmless) with it.
215 asyncio.set_event_loop(self.loop)
216
217 self._add_methods("connect_tcp", "ping")
218
219 @abc.abstractmethod
220 def _get_async_client(self):
221 pass
222
223 def _get_downcall_wrapper(self, downcall):
224 def wrapper(*args, **kwargs):
225 return self.loop.run_until_complete(downcall(*args, **kwargs))
226
227 return wrapper
228
229 def _add_methods(self, *methods):
230 for m in methods:
231 downcall = getattr(self.client, m)
232 setattr(self, m, self._get_downcall_wrapper(downcall))
233
234 def connect_unix(self, path):
235 self.loop.run_until_complete(self.client.connect_unix(path))
236 self.loop.run_until_complete(self.client.connect())
237
238 @property
239 def max_chunk(self):
240 return self.client.max_chunk
241
242 @max_chunk.setter
243 def max_chunk(self, value):
244 self.client.max_chunk = value
245
246 def disconnect(self):
247 self.loop.run_until_complete(self.client.close())
248
249 def close(self):
250 if self.loop:
251 self.loop.run_until_complete(self.client.close())
252 if sys.version_info >= (3, 6):
253 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
254 self.loop.close()
255 self.loop = None
256
257 def __enter__(self):
258 return self
259
260 def __exit__(self, exc_type, exc_value, traceback):
261 self.close()
262 return False
263
264
265class ClientPool(object):
266 def __init__(self, max_clients):
267 self.avail_clients = []
268 self.num_clients = 0
269 self.max_clients = max_clients
270 self.loop = None
271 self.client_condition = None
272
273 @abc.abstractmethod
274 async def _new_client(self):
275 raise NotImplementedError("Must be implemented in derived class")
276
277 def close(self):
278 if self.client_condition:
279 self.client_condition = None
280
281 if self.loop:
282 self.loop.run_until_complete(self.__close_clients())
283 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
284 self.loop.close()
285 self.loop = None
286
287 def run_tasks(self, tasks):
288 if not self.loop:
289 self.loop = asyncio.new_event_loop()
290
291 thread = Thread(target=self.__thread_main, args=(tasks,))
292 thread.start()
293 thread.join()
294
295 @contextlib.asynccontextmanager
296 async def get_client(self):
297 async with self.client_condition:
298 if self.avail_clients:
299 client = self.avail_clients.pop()
300 elif self.num_clients < self.max_clients:
301 self.num_clients += 1
302 client = await self._new_client()
303 else:
304 while not self.avail_clients:
305 await self.client_condition.wait()
306 client = self.avail_clients.pop()
307
308 try:
309 yield client
310 finally:
311 async with self.client_condition:
312 self.avail_clients.append(client)
313 self.client_condition.notify()
314
315 def __thread_main(self, tasks):
316 async def process_task(task):
317 async with self.get_client() as client:
318 await task(client)
319
320 asyncio.set_event_loop(self.loop)
321 if not self.client_condition:
322 self.client_condition = asyncio.Condition()
323 tasks = [process_task(t) for t in tasks]
324 self.loop.run_until_complete(asyncio.gather(*tasks))
325
326 async def __close_clients(self):
327 for c in self.avail_clients:
328 await c.close()
329 self.avail_clients = []
330 self.num_clients = 0
331
332 def __enter__(self):
333 return self
334
335 def __exit__(self, exc_type, exc_value, traceback):
336 self.close()
337 return False