summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/hashserv/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/hashserv/client.py')
-rw-r--r--bitbake/lib/hashserv/client.py220
1 files changed, 127 insertions, 93 deletions
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
index 0b254beddd..8cb18050a6 100644
--- a/bitbake/lib/hashserv/client.py
+++ b/bitbake/lib/hashserv/client.py
@@ -5,6 +5,7 @@
5 5
6import logging 6import logging
7import socket 7import socket
8import asyncio
8import bb.asyncrpc 9import bb.asyncrpc
9import json 10import json
10from . import create_async_client 11from . import create_async_client
@@ -13,10 +14,71 @@ from . import create_async_client
13logger = logging.getLogger("hashserv.client") 14logger = logging.getLogger("hashserv.client")
14 15
15 16
class Batch(object):
    """
    Tracks a pipelined batch of stream messages.

    Queries are sent as fast as possible while replies are read
    concurrently, so that multiple queries can be "in-flight" at once.
    The number of sent messages is counted explicitly so that the message
    source may be a generator of unknown length.
    """

    def __init__(self):
        # Set True once the sender has finished (or aborted) so the
        # receiver knows no further messages will be queued
        self.done = False
        self.cond = asyncio.Condition()
        # Messages sent (or queued to send) whose replies have not yet
        # been received; also serves as the resend list after a reconnect
        self.pending = []
        self.results = []
        self.sent_count = 0

    async def recv(self, socket):
        """
        Receive one reply per pending message until the sender signals done
        and all pending replies have been drained.
        """
        while True:
            async with self.cond:
                await self.cond.wait_for(lambda: self.pending or self.done)

                if not self.pending:
                    if self.done:
                        return
                    continue

            r = await socket.recv()
            self.results.append(r)

            async with self.cond:
                self.pending.pop(0)

    async def send(self, socket, msgs):
        """
        Send all messages from msgs, queueing each on the pending list
        before transmission so a failed send can be retried.
        """
        try:
            # In the event of a restart due to a reconnect, all in-flight
            # messages need to be resent first to keep the result count in sync
            for m in self.pending:
                await socket.send(m)

            for m in msgs:
                # Add the message to the pending list before attempting to send
                # it so that if the send fails it will be retried
                async with self.cond:
                    self.pending.append(m)
                    self.cond.notify()
                    self.sent_count += 1

                await socket.send(m)

        finally:
            # Always mark the batch done (even on error) so the receiver
            # can terminate instead of waiting forever
            async with self.cond:
                self.done = True
                self.cond.notify()

    async def process(self, socket, msgs):
        """
        Run the sender and receiver concurrently and return the replies in
        order. Raises ValueError if the reply count does not match the
        number of messages sent.
        """
        await asyncio.gather(
            self.recv(socket),
            self.send(socket, msgs),
        )

        if len(self.results) != self.sent_count:
            # Fixed message: the original said "Expected" for both numbers
            raise ValueError(
                f"Got result count {len(self.results)}. Expected {self.sent_count}"
            )

        return self.results
76
16class AsyncClient(bb.asyncrpc.AsyncClient): 77class AsyncClient(bb.asyncrpc.AsyncClient):
17 MODE_NORMAL = 0 78 MODE_NORMAL = 0
18 MODE_GET_STREAM = 1 79 MODE_GET_STREAM = 1
19 MODE_EXIST_STREAM = 2 80 MODE_EXIST_STREAM = 2
81 MODE_MARK_STREAM = 3
20 82
21 def __init__(self, username=None, password=None): 83 def __init__(self, username=None, password=None):
22 super().__init__("OEHASHEQUIV", "1.1", logger) 84 super().__init__("OEHASHEQUIV", "1.1", logger)
@@ -36,32 +98,52 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
36 if become: 98 if become:
37 await self.become_user(become) 99 await self.become_user(become)
38 100
39 async def send_stream(self, mode, msg): 101 async def send_stream_batch(self, mode, msgs):
102 """
103 Does a "batch" process of stream messages. This sends the query
104 messages as fast as possible, and simultaneously attempts to read the
105 messages back. This helps to mitigate the effects of latency to the
106 hash equivalence server by allowing multiple queries to be "in-flight"
107 at once
108
109 The implementation does more complicated tracking using a count of sent
110 messages so that `msgs` can be a generator function (i.e. its length is
111 unknown)
112
113 """
114
115 b = Batch()
116
40 async def proc(): 117 async def proc():
118 nonlocal b
119
41 await self._set_mode(mode) 120 await self._set_mode(mode)
42 await self.socket.send(msg) 121 return await b.process(self.socket, msgs)
43 return await self.socket.recv()
44 122
45 return await self._send_wrapper(proc) 123 return await self._send_wrapper(proc)
46 124
47 async def invoke(self, *args, **kwargs): 125 async def invoke(self, *args, skip_mode=False, **kwargs):
48 # It's OK if connection errors cause a failure here, because the mode 126 # It's OK if connection errors cause a failure here, because the mode
49 # is also reset to normal on a new connection 127 # is also reset to normal on a new connection
50 await self._set_mode(self.MODE_NORMAL) 128 if not skip_mode:
129 await self._set_mode(self.MODE_NORMAL)
51 return await super().invoke(*args, **kwargs) 130 return await super().invoke(*args, **kwargs)
52 131
53 async def _set_mode(self, new_mode): 132 async def _set_mode(self, new_mode):
54 async def stream_to_normal(): 133 async def stream_to_normal():
134 # Check if already in normal mode (e.g. due to a connection reset)
135 if self.mode == self.MODE_NORMAL:
136 return "ok"
55 await self.socket.send("END") 137 await self.socket.send("END")
56 return await self.socket.recv() 138 return await self.socket.recv()
57 139
58 async def normal_to_stream(command): 140 async def normal_to_stream(command):
59 r = await self.invoke({command: None}) 141 r = await self.invoke({command: None}, skip_mode=True)
60 if r != "ok": 142 if r != "ok":
143 self.check_invoke_error(r)
61 raise ConnectionError( 144 raise ConnectionError(
62 f"Unable to transition to stream mode: Bad response from server {r!r}" 145 f"Unable to transition to stream mode: Bad response from server {r!r}"
63 ) 146 )
64
65 self.logger.debug("Mode is now %s", command) 147 self.logger.debug("Mode is now %s", command)
66 148
67 if new_mode == self.mode: 149 if new_mode == self.mode:
@@ -83,16 +165,23 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
83 await normal_to_stream("get-stream") 165 await normal_to_stream("get-stream")
84 elif new_mode == self.MODE_EXIST_STREAM: 166 elif new_mode == self.MODE_EXIST_STREAM:
85 await normal_to_stream("exists-stream") 167 await normal_to_stream("exists-stream")
168 elif new_mode == self.MODE_MARK_STREAM:
169 await normal_to_stream("gc-mark-stream")
86 elif new_mode != self.MODE_NORMAL: 170 elif new_mode != self.MODE_NORMAL:
87 raise Exception("Undefined mode transition {self.mode!r} -> {new_mode!r}") 171 raise Exception("Undefined mode transition {self.mode!r} -> {new_mode!r}")
88 172
89 self.mode = new_mode 173 self.mode = new_mode
90 174
91 async def get_unihash(self, method, taskhash): 175 async def get_unihash(self, method, taskhash):
92 r = await self.send_stream(self.MODE_GET_STREAM, "%s %s" % (method, taskhash)) 176 r = await self.get_unihash_batch([(method, taskhash)])
93 if not r: 177 return r[0]
94 return None 178
95 return r 179 async def get_unihash_batch(self, args):
180 result = await self.send_stream_batch(
181 self.MODE_GET_STREAM,
182 (f"{method} {taskhash}" for method, taskhash in args),
183 )
184 return [r if r else None for r in result]
96 185
97 async def report_unihash(self, taskhash, method, outhash, unihash, extra={}): 186 async def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
98 m = extra.copy() 187 m = extra.copy()
@@ -115,8 +204,12 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
115 ) 204 )
116 205
117 async def unihash_exists(self, unihash): 206 async def unihash_exists(self, unihash):
118 r = await self.send_stream(self.MODE_EXIST_STREAM, unihash) 207 r = await self.unihash_exists_batch([unihash])
119 return r == "true" 208 return r[0]
209
210 async def unihash_exists_batch(self, unihashes):
211 result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes)
212 return [r == "true" for r in result]
120 213
121 async def get_outhash(self, method, outhash, taskhash, with_unihash=True): 214 async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
122 return await self.invoke( 215 return await self.invoke(
@@ -216,6 +309,24 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
216 """ 309 """
217 return await self.invoke({"gc-mark": {"mark": mark, "where": where}}) 310 return await self.invoke({"gc-mark": {"mark": mark, "where": where}})
218 311
312 async def gc_mark_stream(self, mark, rows):
313 """
314 Similar to `gc-mark`, but accepts a list of "where" key-value pair
315 conditions. It utilizes stream mode to mark hashes, which helps reduce
316 the impact of latency when communicating with the hash equivalence
317 server.
318 """
319 def row_to_dict(row):
320 pairs = row.split()
321 return dict(zip(pairs[::2], pairs[1::2]))
322
323 responses = await self.send_stream_batch(
324 self.MODE_MARK_STREAM,
325 (json.dumps({"mark": mark, "where": row_to_dict(row)}) for row in rows),
326 )
327
328 return {"count": sum(int(json.loads(r)["count"]) for r in responses)}
329
219 async def gc_sweep(self, mark): 330 async def gc_sweep(self, mark):
220 """ 331 """
221 Finishes garbage collection for "mark". All unihash entries that have 332 Finishes garbage collection for "mark". All unihash entries that have
@@ -237,10 +348,12 @@ class Client(bb.asyncrpc.Client):
237 "connect_tcp", 348 "connect_tcp",
238 "connect_websocket", 349 "connect_websocket",
239 "get_unihash", 350 "get_unihash",
351 "get_unihash_batch",
240 "report_unihash", 352 "report_unihash",
241 "report_unihash_equiv", 353 "report_unihash_equiv",
242 "get_taskhash", 354 "get_taskhash",
243 "unihash_exists", 355 "unihash_exists",
356 "unihash_exists_batch",
244 "get_outhash", 357 "get_outhash",
245 "get_stats", 358 "get_stats",
246 "reset_stats", 359 "reset_stats",
@@ -259,88 +372,9 @@ class Client(bb.asyncrpc.Client):
259 "get_db_query_columns", 372 "get_db_query_columns",
260 "gc_status", 373 "gc_status",
261 "gc_mark", 374 "gc_mark",
375 "gc_mark_stream",
262 "gc_sweep", 376 "gc_sweep",
263 ) 377 )
264 378
265 def _get_async_client(self): 379 def _get_async_client(self):
266 return AsyncClient(self.username, self.password) 380 return AsyncClient(self.username, self.password)
267
268
class ClientPool(bb.asyncrpc.ClientPool):
    """
    A pool of hash equivalence clients used to run many queries in
    parallel against the server.
    """

    def __init__(
        self,
        address,
        max_clients,
        *,
        username=None,
        password=None,
        become=None,
    ):
        super().__init__(max_clients)
        self.address = address
        self.username = username
        self.password = password
        self.become = become

    async def _new_client(self):
        # Connect, authenticate, and optionally impersonate another user
        client = await create_async_client(
            self.address,
            username=self.username,
            password=self.password,
        )
        if self.become:
            await client.become_user(self.become)
        return client

    def _run_key_tasks(self, queries, call):
        # Pre-populate every key with None so failed queries still have
        # an entry in the returned mapping
        results = dict.fromkeys(queries)

        def make_task(key, args):
            async def task(client):
                results[key] = await call(client, args)

            return task

        self.run_tasks(make_task(key, args) for key, args in queries.items())
        return results

    def get_unihashes(self, queries):
        """
        Query multiple unihashes in parallel.

        `queries` maps arbitrary keys to (method, taskhash) tuples. Returns
        a dictionary with the same keys, where each value is the queried
        unihash (or None if the query failed).
        """

        async def call(client, args):
            method, taskhash = args
            return await client.get_unihash(method, taskhash)

        return self._run_key_tasks(queries, call)

    def unihashes_exist(self, queries):
        """
        Run multiple unihash existence checks in parallel.

        `queries` maps arbitrary keys to unihashes. Returns a dictionary
        with the same keys, where each value is True/False depending on
        whether the server knows the unihash (or None on failure).
        """

        async def call(client, unihash):
            return await client.unihash_exists(unihash)

        return self._run_key_tasks(queries, call)