summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/client.py')
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py271
1 files changed, 271 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
new file mode 100644
index 0000000000..17b72033b9
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -0,0 +1,271 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import socket
12import sys
13import re
14import contextlib
15from threading import Thread
16from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
17from .exceptions import ConnectionClosedError, InvokeError
18
19UNIX_PREFIX = "unix://"
20WS_PREFIX = "ws://"
21WSS_PREFIX = "wss://"
22
23ADDR_TYPE_UNIX = 0
24ADDR_TYPE_TCP = 1
25ADDR_TYPE_WS = 2
26
27WEBSOCKETS_MIN_VERSION = (9, 1)
28# Need websockets 10 with python 3.10+
29if sys.version_info >= (3, 10, 0):
30 WEBSOCKETS_MIN_VERSION = (10, 0)
31
32
33def parse_address(addr):
34 if addr.startswith(UNIX_PREFIX):
35 return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
36 elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
37 return (ADDR_TYPE_WS, (addr,))
38 else:
39 m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
40 if m is not None:
41 host = m.group("host")
42 port = m.group("port")
43 else:
44 host, port = addr.split(":")
45
46 return (ADDR_TYPE_TCP, (host, int(port)))
47
48
49class AsyncClient(object):
50 def __init__(
51 self,
52 proto_name,
53 proto_version,
54 logger,
55 timeout=30,
56 server_headers=False,
57 headers={},
58 ):
59 self.socket = None
60 self.max_chunk = DEFAULT_MAX_CHUNK
61 self.proto_name = proto_name
62 self.proto_version = proto_version
63 self.logger = logger
64 self.timeout = timeout
65 self.needs_server_headers = server_headers
66 self.server_headers = {}
67 self.headers = headers
68
69 async def connect_tcp(self, address, port):
70 async def connect_sock():
71 reader, writer = await asyncio.open_connection(address, port)
72 return StreamConnection(reader, writer, self.timeout, self.max_chunk)
73
74 self._connect_sock = connect_sock
75
76 async def connect_unix(self, path):
77 async def connect_sock():
78 # AF_UNIX has path length issues so chdir here to workaround
79 cwd = os.getcwd()
80 try:
81 os.chdir(os.path.dirname(path))
82 # The socket must be opened synchronously so that CWD doesn't get
83 # changed out from underneath us so we pass as a sock into asyncio
84 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
85 sock.connect(os.path.basename(path))
86 finally:
87 os.chdir(cwd)
88 reader, writer = await asyncio.open_unix_connection(sock=sock)
89 return StreamConnection(reader, writer, self.timeout, self.max_chunk)
90
91 self._connect_sock = connect_sock
92
93 async def connect_websocket(self, uri):
94 import websockets
95
96 try:
97 version = tuple(
98 int(v)
99 for v in websockets.__version__.split(".")[
100 0 : len(WEBSOCKETS_MIN_VERSION)
101 ]
102 )
103 except ValueError:
104 raise ImportError(
105 f"Unable to parse websockets version '{websockets.__version__}'"
106 )
107
108 if version < WEBSOCKETS_MIN_VERSION:
109 min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
110 raise ImportError(
111 f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
112 )
113
114 async def connect_sock():
115 try:
116 websocket = await websockets.connect(
117 uri,
118 ping_interval=None,
119 open_timeout=self.timeout,
120 )
121 except asyncio.exceptions.TimeoutError:
122 raise ConnectionError("Timeout while connecting to websocket")
123 except (OSError, websockets.InvalidHandshake, websockets.InvalidURI) as exc:
124 raise ConnectionError(f"Could not connect to websocket: {exc}") from exc
125 return WebsocketConnection(websocket, self.timeout)
126
127 self._connect_sock = connect_sock
128
129 async def setup_connection(self):
130 # Send headers
131 await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
132 await self.socket.send(
133 "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
134 )
135 for k, v in self.headers.items():
136 await self.socket.send("%s: %s" % (k, v))
137
138 # End of headers
139 await self.socket.send("")
140
141 self.server_headers = {}
142 if self.needs_server_headers:
143 while True:
144 line = await self.socket.recv()
145 if not line:
146 # End headers
147 break
148 tag, value = line.split(":", 1)
149 self.server_headers[tag.lower()] = value.strip()
150
151 async def get_header(self, tag, default):
152 await self.connect()
153 return self.server_headers.get(tag, default)
154
155 async def connect(self):
156 if self.socket is None:
157 self.socket = await self._connect_sock()
158 await self.setup_connection()
159
160 async def disconnect(self):
161 if self.socket is not None:
162 await self.socket.close()
163 self.socket = None
164
165 async def close(self):
166 await self.disconnect()
167
168 async def _send_wrapper(self, proc):
169 count = 0
170 while True:
171 try:
172 await self.connect()
173 return await proc()
174 except (
175 OSError,
176 ConnectionError,
177 ConnectionClosedError,
178 json.JSONDecodeError,
179 UnicodeDecodeError,
180 ) as e:
181 self.logger.warning("Error talking to server: %s" % e)
182 if count >= 3:
183 if not isinstance(e, ConnectionError):
184 raise ConnectionError(str(e))
185 raise e
186 await self.close()
187 count += 1
188
189 def check_invoke_error(self, msg):
190 if isinstance(msg, dict) and "invoke-error" in msg:
191 raise InvokeError(msg["invoke-error"]["message"])
192
193 async def invoke(self, msg):
194 async def proc():
195 await self.socket.send_message(msg)
196 return await self.socket.recv_message()
197
198 result = await self._send_wrapper(proc)
199 self.check_invoke_error(result)
200 return result
201
202 async def ping(self):
203 return await self.invoke({"ping": {}})
204
205 async def __aenter__(self):
206 return self
207
208 async def __aexit__(self, exc_type, exc_value, traceback):
209 await self.close()
210
211
212class Client(object):
213 def __init__(self):
214 self.client = self._get_async_client()
215 self.loop = asyncio.new_event_loop()
216
217 # Override any pre-existing loop.
218 # Without this, the PR server export selftest triggers a hang
219 # when running with Python 3.7. The drawback is that there is
220 # potential for issues if the PR and hash equiv (or some new)
221 # clients need to both be instantiated in the same process.
222 # This should be revisited if/when Python 3.9 becomes the
223 # minimum required version for BitBake, as it seems not
224 # required (but harmless) with it.
225 asyncio.set_event_loop(self.loop)
226
227 self._add_methods("connect_tcp", "ping")
228
229 @abc.abstractmethod
230 def _get_async_client(self):
231 pass
232
233 def _get_downcall_wrapper(self, downcall):
234 def wrapper(*args, **kwargs):
235 return self.loop.run_until_complete(downcall(*args, **kwargs))
236
237 return wrapper
238
239 def _add_methods(self, *methods):
240 for m in methods:
241 downcall = getattr(self.client, m)
242 setattr(self, m, self._get_downcall_wrapper(downcall))
243
244 def connect_unix(self, path):
245 self.loop.run_until_complete(self.client.connect_unix(path))
246 self.loop.run_until_complete(self.client.connect())
247
248 @property
249 def max_chunk(self):
250 return self.client.max_chunk
251
252 @max_chunk.setter
253 def max_chunk(self, value):
254 self.client.max_chunk = value
255
256 def disconnect(self):
257 self.loop.run_until_complete(self.client.close())
258
259 def close(self):
260 if self.loop:
261 self.loop.run_until_complete(self.client.close())
262 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
263 self.loop.close()
264 self.loop = None
265
266 def __enter__(self):
267 return self
268
269 def __exit__(self, exc_type, exc_value, traceback):
270 self.close()
271 return False