diff options
Diffstat (limited to 'bitbake/lib/hashserv/client.py')
-rw-r--r-- | bitbake/lib/hashserv/client.py | 198 |
1 files changed, 105 insertions, 93 deletions
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py index 0b254beddd..a510f3284f 100644 --- a/bitbake/lib/hashserv/client.py +++ b/bitbake/lib/hashserv/client.py | |||
@@ -5,6 +5,7 @@ | |||
5 | 5 | ||
6 | import logging | 6 | import logging |
7 | import socket | 7 | import socket |
8 | import asyncio | ||
8 | import bb.asyncrpc | 9 | import bb.asyncrpc |
9 | import json | 10 | import json |
10 | from . import create_async_client | 11 | from . import create_async_client |
@@ -13,6 +14,66 @@ from . import create_async_client | |||
13 | logger = logging.getLogger("hashserv.client") | 14 | logger = logging.getLogger("hashserv.client") |
14 | 15 | ||
15 | 16 | ||
17 | class Batch(object): | ||
18 | def __init__(self): | ||
19 | self.done = False | ||
20 | self.cond = asyncio.Condition() | ||
21 | self.pending = [] | ||
22 | self.results = [] | ||
23 | self.sent_count = 0 | ||
24 | |||
25 | async def recv(self, socket): | ||
26 | while True: | ||
27 | async with self.cond: | ||
28 | await self.cond.wait_for(lambda: self.pending or self.done) | ||
29 | |||
30 | if not self.pending: | ||
31 | if self.done: | ||
32 | return | ||
33 | continue | ||
34 | |||
35 | r = await socket.recv() | ||
36 | self.results.append(r) | ||
37 | |||
38 | async with self.cond: | ||
39 | self.pending.pop(0) | ||
40 | |||
41 | async def send(self, socket, msgs): | ||
42 | try: | ||
43 | # In the event of a restart due to a reconnect, all in-flight | ||
44 | # messages need to be resent first to keep the result count in sync | ||
45 | for m in self.pending: | ||
46 | await socket.send(m) | ||
47 | |||
48 | for m in msgs: | ||
49 | # Add the message to the pending list before attempting to send | ||
50 | # it so that if the send fails it will be retried | ||
51 | async with self.cond: | ||
52 | self.pending.append(m) | ||
53 | self.cond.notify() | ||
54 | self.sent_count += 1 | ||
55 | |||
56 | await socket.send(m) | ||
57 | |||
58 | finally: | ||
59 | async with self.cond: | ||
60 | self.done = True | ||
61 | self.cond.notify() | ||
62 | |||
63 | async def process(self, socket, msgs): | ||
64 | await asyncio.gather( | ||
65 | self.recv(socket), | ||
66 | self.send(socket, msgs), | ||
67 | ) | ||
68 | |||
69 | if len(self.results) != self.sent_count: | ||
70 | raise ValueError( | ||
71 | f"Got result count {len(self.results)}. Expected {self.sent_count}" | ||
72 | ) | ||
73 | |||
74 | return self.results | ||
75 | |||
76 | |||
16 | class AsyncClient(bb.asyncrpc.AsyncClient): | 77 | class AsyncClient(bb.asyncrpc.AsyncClient): |
17 | MODE_NORMAL = 0 | 78 | MODE_NORMAL = 0 |
18 | MODE_GET_STREAM = 1 | 79 | MODE_GET_STREAM = 1 |
@@ -36,32 +97,52 @@ class AsyncClient(bb.asyncrpc.AsyncClient): | |||
36 | if become: | 97 | if become: |
37 | await self.become_user(become) | 98 | await self.become_user(become) |
38 | 99 | ||
39 | async def send_stream(self, mode, msg): | 100 | async def send_stream_batch(self, mode, msgs): |
101 | """ | ||
102 | Does a "batch" process of stream messages. This sends the query | ||
103 | messages as fast as possible, and simultaneously attempts to read the | ||
104 | messages back. This helps to mitigate the effects of latency to the | ||
105 | hash equivalence server by allowing multiple queries to be "in-flight" | ||
106 | at once | ||
107 | |||
108 | The implementation does more complicated tracking using a count of sent | ||
109 | messages so that `msgs` can be a generator function (i.e. its length is | ||
110 | unknown) | ||
111 | |||
112 | """ | ||
113 | |||
114 | b = Batch() | ||
115 | |||
40 | async def proc(): | 116 | async def proc(): |
117 | nonlocal b | ||
118 | |||
41 | await self._set_mode(mode) | 119 | await self._set_mode(mode) |
42 | await self.socket.send(msg) | 120 | return await b.process(self.socket, msgs) |
43 | return await self.socket.recv() | ||
44 | 121 | ||
45 | return await self._send_wrapper(proc) | 122 | return await self._send_wrapper(proc) |
46 | 123 | ||
47 | async def invoke(self, *args, **kwargs): | 124 | async def invoke(self, *args, skip_mode=False, **kwargs): |
48 | # It's OK if connection errors cause a failure here, because the mode | 125 | # It's OK if connection errors cause a failure here, because the mode |
49 | # is also reset to normal on a new connection | 126 | # is also reset to normal on a new connection |
50 | await self._set_mode(self.MODE_NORMAL) | 127 | if not skip_mode: |
128 | await self._set_mode(self.MODE_NORMAL) | ||
51 | return await super().invoke(*args, **kwargs) | 129 | return await super().invoke(*args, **kwargs) |
52 | 130 | ||
53 | async def _set_mode(self, new_mode): | 131 | async def _set_mode(self, new_mode): |
54 | async def stream_to_normal(): | 132 | async def stream_to_normal(): |
133 | # Check if already in normal mode (e.g. due to a connection reset) | ||
134 | if self.mode == self.MODE_NORMAL: | ||
135 | return "ok" | ||
55 | await self.socket.send("END") | 136 | await self.socket.send("END") |
56 | return await self.socket.recv() | 137 | return await self.socket.recv() |
57 | 138 | ||
58 | async def normal_to_stream(command): | 139 | async def normal_to_stream(command): |
59 | r = await self.invoke({command: None}) | 140 | r = await self.invoke({command: None}, skip_mode=True) |
60 | if r != "ok": | 141 | if r != "ok": |
142 | self.check_invoke_error(r) | ||
61 | raise ConnectionError( | 143 | raise ConnectionError( |
62 | f"Unable to transition to stream mode: Bad response from server {r!r}" | 144 | f"Unable to transition to stream mode: Bad response from server {r!r}" |
63 | ) | 145 | ) |
64 | |||
65 | self.logger.debug("Mode is now %s", command) | 146 | self.logger.debug("Mode is now %s", command) |
66 | 147 | ||
67 | if new_mode == self.mode: | 148 | if new_mode == self.mode: |
@@ -89,10 +170,15 @@ class AsyncClient(bb.asyncrpc.AsyncClient): | |||
89 | self.mode = new_mode | 170 | self.mode = new_mode |
90 | 171 | ||
91 | async def get_unihash(self, method, taskhash): | 172 | async def get_unihash(self, method, taskhash): |
92 | r = await self.send_stream(self.MODE_GET_STREAM, "%s %s" % (method, taskhash)) | 173 | r = await self.get_unihash_batch([(method, taskhash)]) |
93 | if not r: | 174 | return r[0] |
94 | return None | 175 | |
95 | return r | 176 | async def get_unihash_batch(self, args): |
177 | result = await self.send_stream_batch( | ||
178 | self.MODE_GET_STREAM, | ||
179 | (f"{method} {taskhash}" for method, taskhash in args), | ||
180 | ) | ||
181 | return [r if r else None for r in result] | ||
96 | 182 | ||
97 | async def report_unihash(self, taskhash, method, outhash, unihash, extra={}): | 183 | async def report_unihash(self, taskhash, method, outhash, unihash, extra={}): |
98 | m = extra.copy() | 184 | m = extra.copy() |
@@ -115,8 +201,12 @@ class AsyncClient(bb.asyncrpc.AsyncClient): | |||
115 | ) | 201 | ) |
116 | 202 | ||
117 | async def unihash_exists(self, unihash): | 203 | async def unihash_exists(self, unihash): |
118 | r = await self.send_stream(self.MODE_EXIST_STREAM, unihash) | 204 | r = await self.unihash_exists_batch([unihash]) |
119 | return r == "true" | 205 | return r[0] |
206 | |||
207 | async def unihash_exists_batch(self, unihashes): | ||
208 | result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes) | ||
209 | return [r == "true" for r in result] | ||
120 | 210 | ||
121 | async def get_outhash(self, method, outhash, taskhash, with_unihash=True): | 211 | async def get_outhash(self, method, outhash, taskhash, with_unihash=True): |
122 | return await self.invoke( | 212 | return await self.invoke( |
@@ -237,10 +327,12 @@ class Client(bb.asyncrpc.Client): | |||
237 | "connect_tcp", | 327 | "connect_tcp", |
238 | "connect_websocket", | 328 | "connect_websocket", |
239 | "get_unihash", | 329 | "get_unihash", |
330 | "get_unihash_batch", | ||
240 | "report_unihash", | 331 | "report_unihash", |
241 | "report_unihash_equiv", | 332 | "report_unihash_equiv", |
242 | "get_taskhash", | 333 | "get_taskhash", |
243 | "unihash_exists", | 334 | "unihash_exists", |
335 | "unihash_exists_batch", | ||
244 | "get_outhash", | 336 | "get_outhash", |
245 | "get_stats", | 337 | "get_stats", |
246 | "reset_stats", | 338 | "reset_stats", |
@@ -264,83 +356,3 @@ class Client(bb.asyncrpc.Client): | |||
264 | 356 | ||
265 | def _get_async_client(self): | 357 | def _get_async_client(self): |
266 | return AsyncClient(self.username, self.password) | 358 | return AsyncClient(self.username, self.password) |
267 | |||
268 | |||
269 | class ClientPool(bb.asyncrpc.ClientPool): | ||
270 | def __init__( | ||
271 | self, | ||
272 | address, | ||
273 | max_clients, | ||
274 | *, | ||
275 | username=None, | ||
276 | password=None, | ||
277 | become=None, | ||
278 | ): | ||
279 | super().__init__(max_clients) | ||
280 | self.address = address | ||
281 | self.username = username | ||
282 | self.password = password | ||
283 | self.become = become | ||
284 | |||
285 | async def _new_client(self): | ||
286 | client = await create_async_client( | ||
287 | self.address, | ||
288 | username=self.username, | ||
289 | password=self.password, | ||
290 | ) | ||
291 | if self.become: | ||
292 | await client.become_user(self.become) | ||
293 | return client | ||
294 | |||
295 | def _run_key_tasks(self, queries, call): | ||
296 | results = {key: None for key in queries.keys()} | ||
297 | |||
298 | def make_task(key, args): | ||
299 | async def task(client): | ||
300 | nonlocal results | ||
301 | unihash = await call(client, args) | ||
302 | results[key] = unihash | ||
303 | |||
304 | return task | ||
305 | |||
306 | def gen_tasks(): | ||
307 | for key, args in queries.items(): | ||
308 | yield make_task(key, args) | ||
309 | |||
310 | self.run_tasks(gen_tasks()) | ||
311 | return results | ||
312 | |||
313 | def get_unihashes(self, queries): | ||
314 | """ | ||
315 | Query multiple unihashes in parallel. | ||
316 | |||
317 | The queries argument is a dictionary with arbitrary key. The values | ||
318 | must be a tuple of (method, taskhash). | ||
319 | |||
320 | Returns a dictionary with a corresponding key for each input key, and | ||
321 | the value is the queried unihash (which might be none if the query | ||
322 | failed) | ||
323 | """ | ||
324 | |||
325 | async def call(client, args): | ||
326 | method, taskhash = args | ||
327 | return await client.get_unihash(method, taskhash) | ||
328 | |||
329 | return self._run_key_tasks(queries, call) | ||
330 | |||
331 | def unihashes_exist(self, queries): | ||
332 | """ | ||
333 | Query multiple unihash existence checks in parallel. | ||
334 | |||
335 | The queries argument is a dictionary with arbitrary key. The values | ||
336 | must be a unihash. | ||
337 | |||
338 | Returns a dictionary with a corresponding key for each input key, and | ||
339 | the value is True or False if the unihash is known by the server (or | ||
340 | None if there was a failure) | ||
341 | """ | ||
342 | |||
343 | async def call(client, unihash): | ||
344 | return await client.unihash_exists(unihash) | ||
345 | |||
346 | return self._run_key_tasks(queries, call) | ||