path: root/bitbake/lib/hashserv
Diffstat (limited to 'bitbake/lib/hashserv')
 -rw-r--r--  bitbake/lib/hashserv/__init__.py   |  175
 -rw-r--r--  bitbake/lib/hashserv/client.py     |  453
 -rw-r--r--  bitbake/lib/hashserv/server.py     | 1117
 -rw-r--r--  bitbake/lib/hashserv/sqlalchemy.py |  598
 -rw-r--r--  bitbake/lib/hashserv/sqlite.py     |  579
 -rw-r--r--  bitbake/lib/hashserv/tests.py      | 1309
6 files changed, 3521 insertions, 710 deletions
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py
index 5f2e101e52..ac891e0174 100644
--- a/bitbake/lib/hashserv/__init__.py
+++ b/bitbake/lib/hashserv/__init__.py
@@ -5,129 +5,104 @@
 
 import asyncio
 from contextlib import closing
-import re
-import sqlite3
 import itertools
 import json
+from collections import namedtuple
+from urllib.parse import urlparse
+from bb.asyncrpc.client import parse_address, ADDR_TYPE_UNIX, ADDR_TYPE_WS
 
-UNIX_PREFIX = "unix://"
-
-ADDR_TYPE_UNIX = 0
-ADDR_TYPE_TCP = 1
-
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
-TABLE_DEFINITION = (
-    ("method", "TEXT NOT NULL"),
-    ("outhash", "TEXT NOT NULL"),
-    ("taskhash", "TEXT NOT NULL"),
-    ("unihash", "TEXT NOT NULL"),
-    ("created", "DATETIME"),
-
-    # Optional fields
-    ("owner", "TEXT"),
-    ("PN", "TEXT"),
-    ("PV", "TEXT"),
-    ("PR", "TEXT"),
-    ("task", "TEXT"),
-    ("outhash_siginfo", "TEXT"),
-)
-
-TABLE_COLUMNS = tuple(name for name, _ in TABLE_DEFINITION)
-
-def setup_database(database, sync=True):
-    db = sqlite3.connect(database)
-    db.row_factory = sqlite3.Row
-
-    with closing(db.cursor()) as cursor:
-        cursor.execute('''
-            CREATE TABLE IF NOT EXISTS tasks_v2 (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
-                %s
-                UNIQUE(method, outhash, taskhash)
-                )
-            ''' % " ".join("%s %s," % (name, typ) for name, typ in TABLE_DEFINITION))
-        cursor.execute('PRAGMA journal_mode = WAL')
-        cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
-
-        # Drop old indexes
-        cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
-        cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
-
-        # Create new indexes
-        cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)')
-        cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)')
-
-    return db
-
-
-def parse_address(addr):
-    if addr.startswith(UNIX_PREFIX):
-        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
-    else:
-        m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
-        if m is not None:
-            host = m.group('host')
-            port = m.group('port')
-        else:
-            host, port = addr.split(':')
-
-        return (ADDR_TYPE_TCP, (host, int(port)))
-
+User = namedtuple("User", ("username", "permissions"))
 
-def chunkify(msg, max_chunk):
-    if len(msg) < max_chunk - 1:
-        yield ''.join((msg, "\n"))
-    else:
-        yield ''.join((json.dumps({
-                'chunk-stream': None
-            }), "\n"))
-
-        args = [iter(msg)] * (max_chunk - 1)
-        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
-            yield ''.join(itertools.chain(m, "\n"))
-        yield "\n"
 
-def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
+def create_server(
+    addr,
+    dbname,
+    *,
+    sync=True,
+    upstream=None,
+    read_only=False,
+    db_username=None,
+    db_password=None,
+    anon_perms=None,
+    admin_username=None,
+    admin_password=None,
+    reuseport=False,
+):
+    def sqlite_engine():
+        from .sqlite import DatabaseEngine
+
+        return DatabaseEngine(dbname, sync)
+
+    def sqlalchemy_engine():
+        from .sqlalchemy import DatabaseEngine
+
+        return DatabaseEngine(dbname, db_username, db_password)
+
     from . import server
-    db = setup_database(dbname, sync=sync)
-    s = server.Server(db, upstream=upstream, read_only=read_only)
+
+    if "://" in dbname:
+        db_engine = sqlalchemy_engine()
+    else:
+        db_engine = sqlite_engine()
+
+    if anon_perms is None:
+        anon_perms = server.DEFAULT_ANON_PERMS
+
+    s = server.Server(
+        db_engine,
+        upstream=upstream,
+        read_only=read_only,
+        anon_perms=anon_perms,
+        admin_username=admin_username,
+        admin_password=admin_password,
+    )
 
     (typ, a) = parse_address(addr)
     if typ == ADDR_TYPE_UNIX:
         s.start_unix_server(*a)
+    elif typ == ADDR_TYPE_WS:
+        url = urlparse(a[0])
+        s.start_websocket_server(url.hostname, url.port, reuseport=reuseport)
     else:
-        s.start_tcp_server(*a)
+        s.start_tcp_server(*a, reuseport=reuseport)
 
     return s
 
 
-def create_client(addr):
+def create_client(addr, username=None, password=None):
     from . import client
-    c = client.Client()
 
-    (typ, a) = parse_address(addr)
-    if typ == ADDR_TYPE_UNIX:
-        c.connect_unix(*a)
-    else:
-        c.connect_tcp(*a)
+    c = client.Client(username, password)
+
+    try:
+        (typ, a) = parse_address(addr)
+        if typ == ADDR_TYPE_UNIX:
+            c.connect_unix(*a)
+        elif typ == ADDR_TYPE_WS:
+            c.connect_websocket(*a)
+        else:
+            c.connect_tcp(*a)
+        return c
+    except Exception as e:
+        c.close()
+        raise e
 
-    return c
 
-
-async def create_async_client(addr):
+async def create_async_client(addr, username=None, password=None):
     from . import client
-    c = client.AsyncClient()
 
-    (typ, a) = parse_address(addr)
-    if typ == ADDR_TYPE_UNIX:
-        await c.connect_unix(*a)
-    else:
-        await c.connect_tcp(*a)
+    c = client.AsyncClient(username, password)
+
+    try:
+        (typ, a) = parse_address(addr)
+        if typ == ADDR_TYPE_UNIX:
+            await c.connect_unix(*a)
+        elif typ == ADDR_TYPE_WS:
+            await c.connect_websocket(*a)
+        else:
+            await c.connect_tcp(*a)
 
-    return c
+        return c
+    except Exception as e:
+        await c.close()
+        raise e
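
The new create_server() entry point picks its database backend from the shape of dbname: a URL-style name (anything containing "://") selects the sqlalchemy engine, while a plain filename selects sqlite. create_client() now also carries optional credentials. A minimal usage sketch under assumed conditions; the socket path, database name, and credentials below are invented for illustration and are not part of the patch:

    # Hypothetical usage sketch of the new entry points
    from hashserv import create_server, create_client

    # "hashes.db" contains no "://", so the sqlite engine is chosen
    server = create_server("unix://./hashserv.sock", "hashes.db")

    # A URL-style dbname would select the sqlalchemy engine instead, e.g.:
    # server = create_server("ws://localhost:8686", "postgresql+asyncpg://db/hashes",
    #                        db_username="hashserv", db_password="secret")

    client = create_client("unix://./hashserv.sock")                    # anonymous
    # client = create_client("ws://localhost:8686", "admin", "token")   # authenticated
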
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
index e05c1eb568..8cb18050a6 100644
--- a/bitbake/lib/hashserv/client.py
+++ b/bitbake/lib/hashserv/client.py
@@ -3,231 +3,378 @@
 # SPDX-License-Identifier: GPL-2.0-only
 #
 
-import asyncio
-import json
 import logging
 import socket
-import os
-from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client
+import asyncio
+import bb.asyncrpc
+import json
+from . import create_async_client
 
 
 logger = logging.getLogger("hashserv.client")
 
 
-class HashConnectionError(Exception):
-    pass
+class Batch(object):
+    def __init__(self):
+        self.done = False
+        self.cond = asyncio.Condition()
+        self.pending = []
+        self.results = []
+        self.sent_count = 0
+
+    async def recv(self, socket):
+        while True:
+            async with self.cond:
+                await self.cond.wait_for(lambda: self.pending or self.done)
+
+                if not self.pending:
+                    if self.done:
+                        return
+                    continue
+
+            r = await socket.recv()
+            self.results.append(r)
+
+            async with self.cond:
+                self.pending.pop(0)
+
+    async def send(self, socket, msgs):
+        try:
+            # In the event of a restart due to a reconnect, all in-flight
+            # messages need to be resent first to keep the result count in sync
+            for m in self.pending:
+                await socket.send(m)
+
+            for m in msgs:
+                # Add the message to the pending list before attempting to send
+                # it so that if the send fails it will be retried
+                async with self.cond:
+                    self.pending.append(m)
+                    self.cond.notify()
+                    self.sent_count += 1
+
+                await socket.send(m)
+
+        finally:
+            async with self.cond:
+                self.done = True
+                self.cond.notify()
+
+    async def process(self, socket, msgs):
+        await asyncio.gather(
+            self.recv(socket),
+            self.send(socket, msgs),
+        )
+
+        if len(self.results) != self.sent_count:
+            raise ValueError(
+                f"Expected result count {self.sent_count}. Got {len(self.results)}"
+            )
+
+        return self.results
 
 
-class AsyncClient(object):
-    MODE_NORMAL = 0
-    MODE_GET_STREAM = 1
-
-    def __init__(self):
-        self.reader = None
-        self.writer = None
-        self.mode = self.MODE_NORMAL
-        self.max_chunk = DEFAULT_MAX_CHUNK
-
-    async def connect_tcp(self, address, port):
-        async def connect_sock():
-            return await asyncio.open_connection(address, port)
-
-        self._connect_sock = connect_sock
-
-    async def connect_unix(self, path):
-        async def connect_sock():
-            return await asyncio.open_unix_connection(path)
-
-        self._connect_sock = connect_sock
-
-    async def connect(self):
-        if self.reader is None or self.writer is None:
-            (self.reader, self.writer) = await self._connect_sock()
-
-            self.writer.write("OEHASHEQUIV 1.1\n\n".encode("utf-8"))
-            await self.writer.drain()
-
-            cur_mode = self.mode
-            self.mode = self.MODE_NORMAL
-            await self._set_mode(cur_mode)
-
-    async def close(self):
-        self.reader = None
-
-        if self.writer is not None:
-            self.writer.close()
-            self.writer = None
+class AsyncClient(bb.asyncrpc.AsyncClient):
+    MODE_NORMAL = 0
+    MODE_GET_STREAM = 1
+    MODE_EXIST_STREAM = 2
+    MODE_MARK_STREAM = 3
+
+    def __init__(self, username=None, password=None):
+        super().__init__("OEHASHEQUIV", "1.1", logger)
+        self.mode = self.MODE_NORMAL
+        self.username = username
+        self.password = password
+        self.saved_become_user = None
 
-    async def _send_wrapper(self, proc):
-        count = 0
-        while True:
-            try:
-                await self.connect()
-                return await proc()
-            except (
-                OSError,
-                HashConnectionError,
-                json.JSONDecodeError,
-                UnicodeDecodeError,
-            ) as e:
-                logger.warning("Error talking to server: %s" % e)
-                if count >= 3:
-                    if not isinstance(e, HashConnectionError):
-                        raise HashConnectionError(str(e))
-                    raise e
-                await self.close()
-                count += 1
-
-    async def send_message(self, msg):
-        async def get_line():
-            line = await self.reader.readline()
-            if not line:
-                raise HashConnectionError("Connection closed")
-
-            line = line.decode("utf-8")
-
-            if not line.endswith("\n"):
-                raise HashConnectionError("Bad message %r" % message)
-
-            return line
-
-        async def proc():
-            for c in chunkify(json.dumps(msg), self.max_chunk):
-                self.writer.write(c.encode("utf-8"))
-            await self.writer.drain()
-
-            l = await get_line()
-
-            m = json.loads(l)
-            if m and "chunk-stream" in m:
-                lines = []
-                while True:
-                    l = (await get_line()).rstrip("\n")
-                    if not l:
-                        break
-                    lines.append(l)
-
-                m = json.loads("".join(lines))
-
-            return m
-
-        return await self._send_wrapper(proc)
+    async def setup_connection(self):
+        await super().setup_connection()
+        self.mode = self.MODE_NORMAL
+        if self.username:
+            # Save off become user temporarily because auth() resets it
+            become = self.saved_become_user
+            await self.auth(self.username, self.password)
+
+            if become:
+                await self.become_user(become)
+
+    async def send_stream_batch(self, mode, msgs):
+        """
+        Does a "batch" process of stream messages. This sends the query
+        messages as fast as possible, and simultaneously attempts to read the
+        messages back. This helps to mitigate the effects of latency to the
+        hash equivalence server by allowing multiple queries to be "in-flight"
+        at once.
+
+        The implementation does more complicated tracking using a count of sent
+        messages so that `msgs` can be a generator function (i.e. its length is
+        unknown)
+        """
+
+        b = Batch()
 
-    async def send_stream(self, msg):
         async def proc():
-            self.writer.write(("%s\n" % msg).encode("utf-8"))
-            await self.writer.drain()
-            l = await self.reader.readline()
-            if not l:
-                raise HashConnectionError("Connection closed")
-            return l.decode("utf-8").rstrip()
+            nonlocal b
+
+            await self._set_mode(mode)
+            return await b.process(self.socket, msgs)
 
         return await self._send_wrapper(proc)
 
+    async def invoke(self, *args, skip_mode=False, **kwargs):
+        # It's OK if connection errors cause a failure here, because the mode
+        # is also reset to normal on a new connection
+        if not skip_mode:
+            await self._set_mode(self.MODE_NORMAL)
+        return await super().invoke(*args, **kwargs)
+
     async def _set_mode(self, new_mode):
-        if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
-            r = await self.send_stream("END")
+        async def stream_to_normal():
+            # Check if already in normal mode (e.g. due to a connection reset)
+            if self.mode == self.MODE_NORMAL:
+                return "ok"
+            await self.socket.send("END")
+            return await self.socket.recv()
+
+        async def normal_to_stream(command):
+            r = await self.invoke({command: None}, skip_mode=True)
             if r != "ok":
-                raise HashConnectionError("Bad response from server %r" % r)
-        elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
-            r = await self.send_message({"get-stream": None})
+                self.check_invoke_error(r)
+                raise ConnectionError(
+                    f"Unable to transition to stream mode: Bad response from server {r!r}"
+                )
+            self.logger.debug("Mode is now %s", command)
+
+        if new_mode == self.mode:
+            return
+
+        self.logger.debug("Transitioning mode %s -> %s", self.mode, new_mode)
+
+        # Always transition to normal mode before switching to any other mode
+        if self.mode != self.MODE_NORMAL:
+            r = await self._send_wrapper(stream_to_normal)
             if r != "ok":
-                raise HashConnectionError("Bad response from server %r" % r)
-        elif new_mode != self.mode:
-            raise Exception(
-                "Undefined mode transition %r -> %r" % (self.mode, new_mode)
-            )
+                self.check_invoke_error(r)
+                raise ConnectionError(
+                    f"Unable to transition to normal mode: Bad response from server {r!r}"
+                )
+            self.logger.debug("Mode is now normal")
+
+        if new_mode == self.MODE_GET_STREAM:
+            await normal_to_stream("get-stream")
+        elif new_mode == self.MODE_EXIST_STREAM:
+            await normal_to_stream("exists-stream")
+        elif new_mode == self.MODE_MARK_STREAM:
+            await normal_to_stream("gc-mark-stream")
+        elif new_mode != self.MODE_NORMAL:
+            raise Exception(f"Undefined mode transition {self.mode!r} -> {new_mode!r}")
 
         self.mode = new_mode
 
     async def get_unihash(self, method, taskhash):
-        await self._set_mode(self.MODE_GET_STREAM)
-        r = await self.send_stream("%s %s" % (method, taskhash))
-        if not r:
-            return None
-        return r
+        r = await self.get_unihash_batch([(method, taskhash)])
+        return r[0]
+
+    async def get_unihash_batch(self, args):
+        result = await self.send_stream_batch(
+            self.MODE_GET_STREAM,
+            (f"{method} {taskhash}" for method, taskhash in args),
+        )
+        return [r if r else None for r in result]
 
     async def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
-        await self._set_mode(self.MODE_NORMAL)
         m = extra.copy()
         m["taskhash"] = taskhash
         m["method"] = method
         m["outhash"] = outhash
         m["unihash"] = unihash
-        return await self.send_message({"report": m})
+        return await self.invoke({"report": m})
 
     async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
-        await self._set_mode(self.MODE_NORMAL)
         m = extra.copy()
         m["taskhash"] = taskhash
         m["method"] = method
         m["unihash"] = unihash
-        return await self.send_message({"report-equiv": m})
+        return await self.invoke({"report-equiv": m})
 
     async def get_taskhash(self, method, taskhash, all_properties=False):
-        await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message(
+        return await self.invoke(
             {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
         )
 
-    async def get_outhash(self, method, outhash, taskhash):
-        await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message(
-            {"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method}}
+    async def unihash_exists(self, unihash):
+        r = await self.unihash_exists_batch([unihash])
+        return r[0]
+
+    async def unihash_exists_batch(self, unihashes):
+        result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes)
+        return [r == "true" for r in result]
+
+    async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
+        return await self.invoke(
+            {
+                "get-outhash": {
+                    "outhash": outhash,
+                    "taskhash": taskhash,
+                    "method": method,
+                    "with_unihash": with_unihash,
+                }
+            }
         )
 
     async def get_stats(self):
-        await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message({"get-stats": None})
+        return await self.invoke({"get-stats": None})
 
     async def reset_stats(self):
-        await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message({"reset-stats": None})
+        return await self.invoke({"reset-stats": None})
 
     async def backfill_wait(self):
-        await self._set_mode(self.MODE_NORMAL)
-        return (await self.send_message({"backfill-wait": None}))["tasks"]
+        return (await self.invoke({"backfill-wait": None}))["tasks"]
+
+    async def remove(self, where):
+        return await self.invoke({"remove": {"where": where}})
+
+    async def clean_unused(self, max_age):
+        return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
+
+    async def auth(self, username, token):
+        result = await self.invoke({"auth": {"username": username, "token": token}})
+        self.username = username
+        self.password = token
+        self.saved_become_user = None
+        return result
+
+    async def refresh_token(self, username=None):
+        m = {}
+        if username:
+            m["username"] = username
+        result = await self.invoke({"refresh-token": m})
+        if (
+            self.username
+            and not self.saved_become_user
+            and result["username"] == self.username
+        ):
+            self.password = result["token"]
+        return result
+
+    async def set_user_perms(self, username, permissions):
+        return await self.invoke(
+            {"set-user-perms": {"username": username, "permissions": permissions}}
+        )
+
+    async def get_user(self, username=None):
+        m = {}
+        if username:
+            m["username"] = username
+        return await self.invoke({"get-user": m})
+
+    async def get_all_users(self):
+        return (await self.invoke({"get-all-users": {}}))["users"]
+
+    async def new_user(self, username, permissions):
+        return await self.invoke(
+            {"new-user": {"username": username, "permissions": permissions}}
+        )
+
+    async def delete_user(self, username):
+        return await self.invoke({"delete-user": {"username": username}})
+
+    async def become_user(self, username):
+        result = await self.invoke({"become-user": {"username": username}})
+        if username == self.username:
+            self.saved_become_user = None
+        else:
+            self.saved_become_user = username
+        return result
+
+    async def get_db_usage(self):
+        return (await self.invoke({"get-db-usage": {}}))["usage"]
+
+    async def get_db_query_columns(self):
+        return (await self.invoke({"get-db-query-columns": {}}))["columns"]
+
+    async def gc_status(self):
+        return await self.invoke({"gc-status": {}})
+
+    async def gc_mark(self, mark, where):
+        """
+        Starts a new garbage collection operation identified by "mark". If
+        garbage collection is already in progress with "mark", the collection
+        is continued.
+
+        All unihash entries that match the "where" clause are marked to be
+        kept. In addition, any new entries added to the database after this
+        command will be automatically marked with "mark"
+        """
+        return await self.invoke({"gc-mark": {"mark": mark, "where": where}})
+
+    async def gc_mark_stream(self, mark, rows):
+        """
+        Similar to `gc-mark`, but accepts a list of "where" key-value pair
+        conditions. It utilizes stream mode to mark hashes, which helps reduce
+        the impact of latency when communicating with the hash equivalence
+        server.
+        """
+        def row_to_dict(row):
+            pairs = row.split()
+            return dict(zip(pairs[::2], pairs[1::2]))
+
+        responses = await self.send_stream_batch(
+            self.MODE_MARK_STREAM,
+            (json.dumps({"mark": mark, "where": row_to_dict(row)}) for row in rows),
+        )
+
+        return {"count": sum(int(json.loads(r)["count"]) for r in responses)}
+
+    async def gc_sweep(self, mark):
+        """
+        Finishes garbage collection for "mark". All unihash entries that have
+        not been marked will be deleted.
+
+        It is recommended to clean unused outhash entries after running this to
+        clean up any dangling outhashes
+        """
+        return await self.invoke({"gc-sweep": {"mark": mark}})
 
 
-class Client(object):
-    def __init__(self):
-        self.client = AsyncClient()
-        self.loop = asyncio.new_event_loop()
+class Client(bb.asyncrpc.Client):
+    def __init__(self, username=None, password=None):
+        self.username = username
+        self.password = password
 
-        for call in (
+        super().__init__()
+        self._add_methods(
             "connect_tcp",
-            "close",
+            "connect_websocket",
             "get_unihash",
+            "get_unihash_batch",
             "report_unihash",
             "report_unihash_equiv",
             "get_taskhash",
+            "unihash_exists",
+            "unihash_exists_batch",
+            "get_outhash",
             "get_stats",
             "reset_stats",
             "backfill_wait",
-        ):
-            downcall = getattr(self.client, call)
-            setattr(self, call, self._get_downcall_wrapper(downcall))
-
-    def _get_downcall_wrapper(self, downcall):
-        def wrapper(*args, **kwargs):
-            return self.loop.run_until_complete(downcall(*args, **kwargs))
-
-        return wrapper
-
-    def connect_unix(self, path):
-        # AF_UNIX has path length issues so chdir here to workaround
-        cwd = os.getcwd()
-        try:
-            os.chdir(os.path.dirname(path))
-            self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path)))
-            self.loop.run_until_complete(self.client.connect())
-        finally:
-            os.chdir(cwd)
-
-    @property
-    def max_chunk(self):
-        return self.client.max_chunk
-
-    @max_chunk.setter
-    def max_chunk(self, value):
-        self.client.max_chunk = value
+            "remove",
+            "clean_unused",
+            "auth",
+            "refresh_token",
+            "set_user_perms",
+            "get_user",
+            "get_all_users",
+            "new_user",
+            "delete_user",
+            "become_user",
+            "get_db_usage",
+            "get_db_query_columns",
+            "gc_status",
+            "gc_mark",
+            "gc_mark_stream",
+            "gc_sweep",
+        )
+
+    def _get_async_client(self):
+        return AsyncClient(self.username, self.password)
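
send_stream_batch() above keeps many stream queries in flight at once, so round-trip latency to the server is paid roughly once rather than per hash. A sketch of what calling the batched client API could look like; the socket path and hash values below are invented for illustration:

    # Hypothetical usage sketch of the batched query API
    from hashserv import create_client

    client = create_client("unix://./hashserv.sock")

    # One mode switch, then every query is pipelined over the stream
    queries = [("TestMethod", "%032x" % i) for i in range(100)]  # fabricated taskhashes
    unihashes = client.get_unihash_batch(queries)  # unknown hashes come back as None

    # Existence checks are pipelined the same way via exists-stream mode
    present = client.unihash_exists_batch([u for u in unihashes if u])
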
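The garbage collection API documented in the docstrings above is a two-phase mark/sweep: everything matching the "where" conditions of gc_mark()/gc_mark_stream() is kept, then gc_sweep() deletes the rest. A sketch of one collection cycle, continuing from the client created in the previous example (column names and values are invented):

    # Hypothetical mark/sweep cycle
    client.gc_mark("alpha", {"unihash": "deadbeef", "method": "TestMethod"})

    # Or stream many keep-conditions at once; each row is whitespace-separated
    # key/value pairs, exactly what row_to_dict() above expects
    client.gc_mark_stream("alpha", ["unihash deadbeef method TestMethod"])

    print(client.gc_status())   # e.g. {"keep": 1, "remove": 99, "mark": "alpha"}
    client.gc_sweep("alpha")    # delete every unihash left unmarked
    client.clean_unused(3600)   # then drop outhashes unused for an hour
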
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index a0dc0c170f..58f95c7bcd 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -3,20 +3,52 @@
 # SPDX-License-Identifier: GPL-2.0-only
 #
 
-from contextlib import closing, contextmanager
-from datetime import datetime
+from datetime import datetime, timedelta
 import asyncio
-import json
 import logging
 import math
-import os
-import signal
-import socket
-import sys
 import time
-from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS
+import os
+import base64
+import json
+import hashlib
+from . import create_async_client
+import bb.asyncrpc
+
+logger = logging.getLogger("hashserv.server")
+
+
+# This permission only exists to match nothing
+NONE_PERM = "@none"
+
+READ_PERM = "@read"
+REPORT_PERM = "@report"
+DB_ADMIN_PERM = "@db-admin"
+USER_ADMIN_PERM = "@user-admin"
+ALL_PERM = "@all"
+
+ALL_PERMISSIONS = {
+    READ_PERM,
+    REPORT_PERM,
+    DB_ADMIN_PERM,
+    USER_ADMIN_PERM,
+    ALL_PERM,
+}
+
+DEFAULT_ANON_PERMS = (
+    READ_PERM,
+    REPORT_PERM,
+    DB_ADMIN_PERM,
+)
+
+TOKEN_ALGORITHM = "sha256"
+
+# 48 bytes of random data will result in 64 characters when base64
+# encoded. This number also ensures that the base64 encoding won't have any
+# trailing '=' characters.
+TOKEN_SIZE = 48
 
-logger = logging.getLogger('hashserv.server')
+SALT_SIZE = 8
 
 
@@ -106,522 +138,773 @@ class Stats(object):
106 return math.sqrt(self.s / (self.num - 1)) 138 return math.sqrt(self.s / (self.num - 1))
107 139
108 def todict(self): 140 def todict(self):
109 return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} 141 return {
110 142 k: getattr(self, k)
111 143 for k in ("num", "total_time", "max_time", "average", "stdev")
112class ClientError(Exception):
113 pass
114
115class ServerError(Exception):
116 pass
117
118def insert_task(cursor, data, ignore=False):
119 keys = sorted(data.keys())
120 query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
121 " OR IGNORE" if ignore else "",
122 ', '.join(keys),
123 ', '.join(':' + k for k in keys))
124 cursor.execute(query, data)
125
126async def copy_from_upstream(client, db, method, taskhash):
127 d = await client.get_taskhash(method, taskhash, True)
128 if d is not None:
129 # Filter out unknown columns
130 d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
131 keys = sorted(d.keys())
132
133 with closing(db.cursor()) as cursor:
134 insert_task(cursor, d)
135 db.commit()
136
137 return d
138
139async def copy_outhash_from_upstream(client, db, method, outhash, taskhash):
140 d = await client.get_outhash(method, outhash, taskhash)
141 if d is not None:
142 # Filter out unknown columns
143 d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
144 keys = sorted(d.keys())
145
146 with closing(db.cursor()) as cursor:
147 insert_task(cursor, d)
148 db.commit()
149
150 return d
151
152class ServerClient(object):
153 FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
154 ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
155 OUTHASH_QUERY = '''
156 -- Find tasks with a matching outhash (that is, tasks that
157 -- are equivalent)
158 SELECT * FROM tasks_v2 WHERE method=:method AND outhash=:outhash
159
160 -- If there is an exact match on the taskhash, return it.
161 -- Otherwise return the oldest matching outhash of any
162 -- taskhash
163 ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
164 created ASC
165
166 -- Only return one row
167 LIMIT 1
168 '''
169
170 def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
171 self.reader = reader
172 self.writer = writer
173 self.db = db
174 self.request_stats = request_stats
175 self.max_chunk = DEFAULT_MAX_CHUNK
176 self.backfill_queue = backfill_queue
177 self.upstream = upstream
178
179 self.handlers = {
180 'get': self.handle_get,
181 'get-outhash': self.handle_get_outhash,
182 'get-stream': self.handle_get_stream,
183 'get-stats': self.handle_get_stats,
184 'chunk-stream': self.handle_chunk,
185 } 144 }
186 145
187 if not read_only:
188 self.handlers.update({
189 'report': self.handle_report,
190 'report-equiv': self.handle_equivreport,
191 'reset-stats': self.handle_reset_stats,
192 'backfill-wait': self.handle_backfill_wait,
193 })
194 146
195 async def process_requests(self): 147token_refresh_semaphore = asyncio.Lock()
196 if self.upstream is not None:
197 self.upstream_client = await create_async_client(self.upstream)
198 else:
199 self.upstream_client = None
200 148
201 try:
202 149
150async def new_token():
151 # Prevent malicious users from using this API to deduce the entropy
152 # pool on the server and thus be able to guess a token. *All* token
153 # refresh requests lock the same global semaphore and then sleep for a
154 # short time. The effectively rate limits the total number of requests
155 # than can be made across all clients to 10/second, which should be enough
156 # since you have to be an authenticated users to make the request in the
157 # first place
158 async with token_refresh_semaphore:
159 await asyncio.sleep(0.1)
160 raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
203 161
204 self.addr = self.writer.get_extra_info('peername') 162 return base64.b64encode(raw, b"._").decode("utf-8")
205 logger.debug('Client %r connected' % (self.addr,))
206 163
207 # Read protocol and version
208 protocol = await self.reader.readline()
209 if protocol is None:
210 return
211 164
212 (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() 165def new_salt():
213 if proto_name != 'OEHASHEQUIV': 166 return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()
214 return
215 167
216 proto_version = tuple(int(v) for v in proto_version.split('.'))
217 if proto_version < (1, 0) or proto_version > (1, 1):
218 return
219 168
220 # Read headers. Currently, no headers are implemented, so look for 169def hash_token(algo, salt, token):
221 # an empty line to signal the end of the headers 170 h = hashlib.new(algo)
222 while True: 171 h.update(salt.encode("utf-8"))
223 line = await self.reader.readline() 172 h.update(token.encode("utf-8"))
224 if line is None: 173 return ":".join([algo, salt, h.hexdigest()])
225 return
226 174
227 line = line.decode('utf-8').rstrip()
228 if not line:
229 break
230 175
231 # Handle messages 176def permissions(*permissions, allow_anon=True, allow_self_service=False):
232 while True: 177 """
233 d = await self.read_message() 178 Function decorator that can be used to decorate an RPC function call and
234 if d is None: 179 check that the current users permissions match the require permissions.
235 break
236 await self.dispatch_message(d)
237 await self.writer.drain()
238 except ClientError as e:
239 logger.error(str(e))
240 finally:
241 if self.upstream_client is not None:
242 await self.upstream_client.close()
243 180
244 self.writer.close() 181 If allow_anon is True, the user will also be allowed to make the RPC call
182 if the anonymous user permissions match the permissions.
245 183
246 async def dispatch_message(self, msg): 184 If allow_self_service is True, and the "username" property in the request
247 for k in self.handlers.keys(): 185 is the currently logged in user, or not specified, the user will also be
248 if k in msg: 186 allowed to make the request. This allows users to access normal privileged
249 logger.debug('Handling %s' % k) 187 API, as long as they are only modifying their own user properties (e.g.
250 if 'stream' in k: 188 users can be allowed to reset their own token without @user-admin
251 await self.handlers[k](msg[k]) 189 permissions, but not the token for any other user.
190 """
191
192 def wrapper(func):
193 async def wrap(self, request):
194 if allow_self_service and self.user is not None:
195 username = request.get("username", self.user.username)
196 if username == self.user.username:
197 request["username"] = self.user.username
198 return await func(self, request)
199
200 if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
201 if not self.user:
202 username = "Anonymous user"
203 user_perms = self.server.anon_perms
252 else: 204 else:
253 with self.request_stats.start_sample() as self.request_sample, \ 205 username = self.user.username
254 self.request_sample.measure(): 206 user_perms = self.user.permissions
255 await self.handlers[k](msg[k]) 207
256 return 208 self.logger.info(
209 "User %s with permissions %r denied from calling %s. Missing permissions(s) %r",
210 username,
211 ", ".join(user_perms),
212 func.__name__,
213 ", ".join(permissions),
214 )
215 raise bb.asyncrpc.InvokeError(
216 f"{username} is not allowed to access permissions(s) {', '.join(permissions)}"
217 )
218
219 return await func(self, request)
220
221 return wrap
222
223 return wrapper
224
225
226class ServerClient(bb.asyncrpc.AsyncServerConnection):
227 def __init__(self, socket, server):
228 super().__init__(socket, "OEHASHEQUIV", server.logger)
229 self.server = server
230 self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
231 self.user = None
232
233 self.handlers.update(
234 {
235 "get": self.handle_get,
236 "get-outhash": self.handle_get_outhash,
237 "get-stream": self.handle_get_stream,
238 "exists-stream": self.handle_exists_stream,
239 "get-stats": self.handle_get_stats,
240 "get-db-usage": self.handle_get_db_usage,
241 "get-db-query-columns": self.handle_get_db_query_columns,
242 # Not always read-only, but internally checks if the server is
243 # read-only
244 "report": self.handle_report,
245 "auth": self.handle_auth,
246 "get-user": self.handle_get_user,
247 "get-all-users": self.handle_get_all_users,
248 "become-user": self.handle_become_user,
249 }
250 )
257 251
258 raise ClientError("Unrecognized command %r" % msg) 252 if not self.server.read_only:
253 self.handlers.update(
254 {
255 "report-equiv": self.handle_equivreport,
256 "reset-stats": self.handle_reset_stats,
257 "backfill-wait": self.handle_backfill_wait,
258 "remove": self.handle_remove,
259 "gc-mark": self.handle_gc_mark,
260 "gc-mark-stream": self.handle_gc_mark_stream,
261 "gc-sweep": self.handle_gc_sweep,
262 "gc-status": self.handle_gc_status,
263 "clean-unused": self.handle_clean_unused,
264 "refresh-token": self.handle_refresh_token,
265 "set-user-perms": self.handle_set_perms,
266 "new-user": self.handle_new_user,
267 "delete-user": self.handle_delete_user,
268 }
269 )
259 270
260 def write_message(self, msg): 271 def raise_no_user_error(self, username):
261 for c in chunkify(json.dumps(msg), self.max_chunk): 272 raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")
262 self.writer.write(c.encode('utf-8'))
263 273
264 async def read_message(self): 274 def user_has_permissions(self, *permissions, allow_anon=True):
265 l = await self.reader.readline() 275 permissions = set(permissions)
266 if not l: 276 if allow_anon:
267 return None 277 if ALL_PERM in self.server.anon_perms:
278 return True
268 279
269 try: 280 if not permissions - self.server.anon_perms:
270 message = l.decode('utf-8') 281 return True
271 282
272 if not message.endswith('\n'): 283 if self.user is None:
273 return None 284 return False
274 285
275 return json.loads(message) 286 if ALL_PERM in self.user.permissions:
276 except (json.JSONDecodeError, UnicodeDecodeError) as e: 287 return True
277 logger.error('Bad message from client: %r' % message)
278 raise e
279 288
280 async def handle_chunk(self, request): 289 if not permissions - self.user.permissions:
281 lines = [] 290 return True
282 try:
283 while True:
284 l = await self.reader.readline()
285 l = l.rstrip(b"\n").decode("utf-8")
286 if not l:
287 break
288 lines.append(l)
289 291
290 msg = json.loads(''.join(lines)) 292 return False
291 except (json.JSONDecodeError, UnicodeDecodeError) as e:
292 logger.error('Bad message from client: %r' % message)
293 raise e
294 293
295 if 'chunk-stream' in msg: 294 def validate_proto_version(self):
296 raise ClientError("Nested chunks are not allowed") 295 return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
297 296
298 await self.dispatch_message(msg) 297 async def process_requests(self):
298 async with self.server.db_engine.connect(self.logger) as db:
299 self.db = db
300 if self.server.upstream is not None:
301 self.upstream_client = await create_async_client(self.server.upstream)
302 else:
303 self.upstream_client = None
299 304
300 async def handle_get(self, request): 305 try:
301 method = request['method'] 306 await super().process_requests()
302 taskhash = request['taskhash'] 307 finally:
308 if self.upstream_client is not None:
309 await self.upstream_client.close()
303 310
304 if request.get('all', False): 311 async def dispatch_message(self, msg):
305 row = self.query_equivalent(method, taskhash, self.ALL_QUERY) 312 for k in self.handlers.keys():
306 else: 313 if k in msg:
307 row = self.query_equivalent(method, taskhash, self.FAST_QUERY) 314 self.logger.debug("Handling %s" % k)
315 if "stream" in k:
316 return await self.handlers[k](msg[k])
317 else:
318 with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
319 return await self.handlers[k](msg[k])
308 320
309 if row is not None: 321 raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
310 logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) 322
311 d = {k: row[k] for k in row.keys()} 323 @permissions(READ_PERM)
312 elif self.upstream_client is not None: 324 async def handle_get(self, request):
313 d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash) 325 method = request["method"]
326 taskhash = request["taskhash"]
327 fetch_all = request.get("all", False)
328
329 return await self.get_unihash(method, taskhash, fetch_all)
330
331 async def get_unihash(self, method, taskhash, fetch_all=False):
332 d = None
333
334 if fetch_all:
335 row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
336 if row is not None:
337 d = {k: row[k] for k in row.keys()}
338 elif self.upstream_client is not None:
339 d = await self.upstream_client.get_taskhash(method, taskhash, True)
340 await self.update_unified(d)
314 else: 341 else:
315 d = None 342 row = await self.db.get_equivalent(method, taskhash)
343
344 if row is not None:
345 d = {k: row[k] for k in row.keys()}
346 elif self.upstream_client is not None:
347 d = await self.upstream_client.get_taskhash(method, taskhash)
348 await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
316 349
317 self.write_message(d) 350 return d
318 351
352 @permissions(READ_PERM)
319 async def handle_get_outhash(self, request): 353 async def handle_get_outhash(self, request):
320 with closing(self.db.cursor()) as cursor: 354 method = request["method"]
321 cursor.execute(self.OUTHASH_QUERY, 355 outhash = request["outhash"]
322 {k: request[k] for k in ('method', 'outhash', 'taskhash')}) 356 taskhash = request["taskhash"]
357 with_unihash = request.get("with_unihash", True)
323 358
324 row = cursor.fetchone() 359 return await self.get_outhash(method, outhash, taskhash, with_unihash)
360
361 async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
362 d = None
363 if with_unihash:
364 row = await self.db.get_unihash_by_outhash(method, outhash)
365 else:
366 row = await self.db.get_outhash(method, outhash)
325 367
326 if row is not None: 368 if row is not None:
327 logger.debug('Found equivalent outhash %s -> %s', (row['outhash'], row['unihash']))
328 d = {k: row[k] for k in row.keys()} 369 d = {k: row[k] for k in row.keys()}
329 else: 370 elif self.upstream_client is not None:
330 d = None 371 d = await self.upstream_client.get_outhash(method, outhash, taskhash)
372 await self.update_unified(d)
331 373
332 self.write_message(d) 374 return d
333 375
334 async def handle_get_stream(self, request): 376 async def update_unified(self, data):
335 self.write_message('ok') 377 if data is None:
378 return
379
380 await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
381 await self.db.insert_outhash(data)
382
383 async def _stream_handler(self, handler):
384 await self.socket.send_message("ok")
336 385
337 while True: 386 while True:
338 upstream = None 387 upstream = None
339 388
340 l = await self.reader.readline() 389 l = await self.socket.recv()
341 if not l: 390 if not l:
342 return 391 break
343 392
344 try: 393 try:
345 # This inner loop is very sensitive and must be as fast as 394 # This inner loop is very sensitive and must be as fast as
346 # possible (which is why the request sample is handled manually 395 # possible (which is why the request sample is handled manually
347 # instead of using 'with', and also why logging statements are 396 # instead of using 'with', and also why logging statements are
348 # commented out. 397 # commented out.
349 self.request_sample = self.request_stats.start_sample() 398 self.request_sample = self.server.request_stats.start_sample()
350 request_measure = self.request_sample.measure() 399 request_measure = self.request_sample.measure()
351 request_measure.start() 400 request_measure.start()
352 401
353 l = l.decode('utf-8').rstrip() 402 if l == "END":
354 if l == 'END': 403 break
355 self.writer.write('ok\n'.encode('utf-8'))
356 return
357
358 (method, taskhash) = l.split()
359 #logger.debug('Looking up %s %s' % (method, taskhash))
360 row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
361 if row is not None:
362 msg = ('%s\n' % row['unihash']).encode('utf-8')
363 #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
364 elif self.upstream_client is not None:
365 upstream = await self.upstream_client.get_unihash(method, taskhash)
366 if upstream:
367 msg = ("%s\n" % upstream).encode("utf-8")
368 else:
369 msg = "\n".encode("utf-8")
370 else:
371 msg = '\n'.encode('utf-8')
372 404
373 self.writer.write(msg) 405 msg = await handler(l)
406 await self.socket.send(msg)
374 finally: 407 finally:
375 request_measure.end() 408 request_measure.end()
376 self.request_sample.end() 409 self.request_sample.end()
377 410
378 await self.writer.drain() 411 await self.socket.send("ok")
412 return self.NO_RESPONSE
379 413
380 # Post to the backfill queue after writing the result to minimize 414 @permissions(READ_PERM)
381 # the turn around time on a request 415 async def handle_get_stream(self, request):
382 if upstream is not None: 416 async def handler(l):
383 await self.backfill_queue.put((method, taskhash)) 417 (method, taskhash) = l.split()
418 # self.logger.debug('Looking up %s %s' % (method, taskhash))
419 row = await self.db.get_equivalent(method, taskhash)
384 420
385 async def handle_report(self, data): 421 if row is not None:
386 with closing(self.db.cursor()) as cursor: 422 # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
387 cursor.execute(self.OUTHASH_QUERY, 423 return row["unihash"]
388 {k: data[k] for k in ('method', 'outhash', 'taskhash')})
389
390 row = cursor.fetchone()
391
392 if row is None and self.upstream_client:
393 # Try upstream
394 row = await copy_outhash_from_upstream(self.upstream_client,
395 self.db,
396 data['method'],
397 data['outhash'],
398 data['taskhash'])
399
400 # If no matching outhash was found, or one *was* found but it
401 # wasn't an exact match on the taskhash, a new entry for this
402 # taskhash should be added
403 if row is None or row['taskhash'] != data['taskhash']:
404 # If a row matching the outhash was found, the unihash for
405 # the new taskhash should be the same as that one.
406 # Otherwise the caller provided unihash is used.
407 unihash = data['unihash']
408 if row is not None:
409 unihash = row['unihash']
410
411 insert_data = {
412 'method': data['method'],
413 'outhash': data['outhash'],
414 'taskhash': data['taskhash'],
415 'unihash': unihash,
416 'created': datetime.now()
417 }
418 424
419 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): 425 if self.upstream_client is not None:
420 if k in data: 426 upstream = await self.upstream_client.get_unihash(method, taskhash)
421 insert_data[k] = data[k] 427 if upstream:
428 await self.server.backfill_queue.put((method, taskhash))
429 return upstream
422 430
423 insert_task(cursor, insert_data) 431 return ""
424 self.db.commit()
425 432
426 logger.info('Adding taskhash %s with unihash %s', 433 return await self._stream_handler(handler)
427 data['taskhash'], unihash)
428 434
429 d = { 435 @permissions(READ_PERM)
430 'taskhash': data['taskhash'], 436 async def handle_exists_stream(self, request):
431 'method': data['method'], 437 async def handler(l):
432 'unihash': unihash 438 if await self.db.unihash_exists(l):
433 } 439 return "true"
434 else:
435 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
436 440
437 self.write_message(d) 441 if self.upstream_client is not None:
442 if await self.upstream_client.unihash_exists(l):
443 return "true"
438 444
439 async def handle_equivreport(self, data): 445 return "false"
440 with closing(self.db.cursor()) as cursor:
441 insert_data = {
442 'method': data['method'],
443 'outhash': "",
444 'taskhash': data['taskhash'],
445 'unihash': data['unihash'],
446 'created': datetime.now()
447 }
448 446
449 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): 447 return await self._stream_handler(handler)
450 if k in data:
451 insert_data[k] = data[k]
452 448
453 insert_task(cursor, insert_data, ignore=True) 449 async def report_readonly(self, data):
454 self.db.commit() 450 method = data["method"]
451 outhash = data["outhash"]
452 taskhash = data["taskhash"]
455 453
456 # Fetch the unihash that will be reported for the taskhash. If the 454 info = await self.get_outhash(method, outhash, taskhash)
457 # unihash matches, it means this row was inserted (or the mapping 455 if info:
458 # was already valid) 456 unihash = info["unihash"]
459 row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY) 457 else:
458 unihash = data["unihash"]
460 459
461 if row['unihash'] == data['unihash']: 460 return {
462 logger.info('Adding taskhash equivalence for %s with unihash %s', 461 "taskhash": taskhash,
463 data['taskhash'], row['unihash']) 462 "method": method,
463 "unihash": unihash,
464 }
464 465
465 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} 466 # Since this can be called either read only or to report, the check to
467 # report is made inside the function
468 @permissions(READ_PERM)
469 async def handle_report(self, data):
470 if self.server.read_only or not self.user_has_permissions(REPORT_PERM):
471 return await self.report_readonly(data)
472
473 outhash_data = {
474 "method": data["method"],
475 "outhash": data["outhash"],
476 "taskhash": data["taskhash"],
477 "created": datetime.now(),
478 }
466 479
467 self.write_message(d) 480 for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
481 if k in data:
482 outhash_data[k] = data[k]
468 483
484 if self.user:
485 outhash_data["owner"] = self.user.username
469 486
470 async def handle_get_stats(self, request): 487 # Insert the new entry, unless it already exists
471 d = { 488 if await self.db.insert_outhash(outhash_data):
472 'requests': self.request_stats.todict(), 489 # If this row is new, check if it is equivalent to another
490 # output hash
491 row = await self.db.get_equivalent_for_outhash(
492 data["method"], data["outhash"], data["taskhash"]
493 )
494
495 if row is not None:
496 # A matching output hash was found. Set our taskhash to the
497 # same unihash since they are equivalent
498 unihash = row["unihash"]
499 else:
500 # No matching output hash was found. This is probably the
501 # first outhash to be added.
502 unihash = data["unihash"]
503
504 # Query upstream to see if it has a unihash we can use
505 if self.upstream_client is not None:
506 upstream_data = await self.upstream_client.get_outhash(
507 data["method"], data["outhash"], data["taskhash"]
508 )
509 if upstream_data is not None:
510 unihash = upstream_data["unihash"]
511
512 await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
513
514 unihash_data = await self.get_unihash(data["method"], data["taskhash"])
515 if unihash_data is not None:
516 unihash = unihash_data["unihash"]
517 else:
518 unihash = data["unihash"]
519
520 return {
521 "taskhash": data["taskhash"],
522 "method": data["method"],
523 "unihash": unihash,
473 } 524 }
474 525
475 self.write_message(d) 526 @permissions(READ_PERM, REPORT_PERM)
527 async def handle_equivreport(self, data):
528 await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
529
530 # Fetch the unihash that will be reported for the taskhash. If the
531 # unihash matches, it means this row was inserted (or the mapping
532 # was already valid)
533 row = await self.db.get_equivalent(data["method"], data["taskhash"])
534
535 if row["unihash"] == data["unihash"]:
536 self.logger.info(
537 "Adding taskhash equivalence for %s with unihash %s",
538 data["taskhash"],
539 row["unihash"],
540 )
541
542 return {k: row[k] for k in ("taskhash", "method", "unihash")}
476 543
544 @permissions(READ_PERM)
545 async def handle_get_stats(self, request):
546 return {
547 "requests": self.server.request_stats.todict(),
548 }
549
550 @permissions(DB_ADMIN_PERM)
477 async def handle_reset_stats(self, request): 551 async def handle_reset_stats(self, request):
478 d = { 552 d = {
479 'requests': self.request_stats.todict(), 553 "requests": self.server.request_stats.todict(),
480 } 554 }
481 555
482 self.request_stats.reset() 556 self.server.request_stats.reset()
483 self.write_message(d) 557 return d
484 558
559 @permissions(READ_PERM)
485 async def handle_backfill_wait(self, request): 560 async def handle_backfill_wait(self, request):
486 d = { 561 d = {
487 'tasks': self.backfill_queue.qsize(), 562 "tasks": self.server.backfill_queue.qsize(),
488 } 563 }
489 await self.backfill_queue.join() 564 await self.server.backfill_queue.join()
490 self.write_message(d) 565 return d
491 566
492 def query_equivalent(self, method, taskhash, query): 567 @permissions(DB_ADMIN_PERM)
493 # This is part of the inner loop and must be as fast as possible 568 async def handle_remove(self, request):
494 try: 569 condition = request["where"]
495 cursor = self.db.cursor() 570 if not isinstance(condition, dict):
496 cursor.execute(query, {'method': method, 'taskhash': taskhash}) 571 raise TypeError("Bad condition type %s" % type(condition))
497 return cursor.fetchone()
498 except:
499 cursor.close()
500 572
573 return {"count": await self.db.remove(condition)}
501 574
502class Server(object): 575 @permissions(DB_ADMIN_PERM)
503 def __init__(self, db, loop=None, upstream=None, read_only=False): 576 async def handle_gc_mark(self, request):
504 if upstream and read_only: 577 condition = request["where"]
505 raise ServerError("Read-only hashserv cannot pull from an upstream server") 578 mark = request["mark"]
506 579
507 self.request_stats = Stats() 580 if not isinstance(condition, dict):
508 self.db = db 581 raise TypeError("Bad condition type %s" % type(condition))
509 582
510 if loop is None: 583 if not isinstance(mark, str):
511 self.loop = asyncio.new_event_loop() 584 raise TypeError("Bad mark type %s" % type(mark))
512 self.close_loop = True
513 else:
514 self.loop = loop
515 self.close_loop = False
516 585
517 self.upstream = upstream 586 return {"count": await self.db.gc_mark(mark, condition)}
518 self.read_only = read_only
519 587
520 self._cleanup_socket = None 588 @permissions(DB_ADMIN_PERM)
589 async def handle_gc_mark_stream(self, request):
590 async def handler(line):
591 try:
592 decoded_line = json.loads(line)
593 except json.JSONDecodeError as exc:
594 raise bb.asyncrpc.InvokeError(
595 "Could not decode JSONL input '%s'" % line
596 ) from exc
521 597
522 def start_tcp_server(self, host, port): 598 try:
523 self.server = self.loop.run_until_complete( 599 mark = decoded_line["mark"]
524 asyncio.start_server(self.handle_client, host, port, loop=self.loop) 600 condition = decoded_line["where"]
525 ) 601 if not isinstance(mark, str):
602 raise TypeError("Bad mark type %s" % type(mark))
526 603
527 for s in self.server.sockets: 604 if not isinstance(condition, dict):
528 logger.info('Listening on %r' % (s.getsockname(),)) 605 raise TypeError("Bad condition type %s" % type(condition))
529 # Newer python does this automatically. Do it manually here for 606 except KeyError as exc:
530 # maximum compatibility 607 raise bb.asyncrpc.InvokeError(
531 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) 608 "Input line is missing key '%s' " % exc
532 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) 609 ) from exc
533 610
534 name = self.server.sockets[0].getsockname() 611 return json.dumps({"count": await self.db.gc_mark(mark, condition)})
535 if self.server.sockets[0].family == socket.AF_INET6:
536 self.address = "[%s]:%d" % (name[0], name[1])
537 else:
538 self.address = "%s:%d" % (name[0], name[1])
539 612
540 def start_unix_server(self, path): 613 return await self._stream_handler(handler)
541 def cleanup():
542 os.unlink(path)
543 614
544 cwd = os.getcwd() 615 @permissions(DB_ADMIN_PERM)
545 try: 616 async def handle_gc_sweep(self, request):
546 # Work around path length limits in AF_UNIX 617 mark = request["mark"]
547 os.chdir(os.path.dirname(path)) 618
548 self.server = self.loop.run_until_complete( 619 if not isinstance(mark, str):
549 asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop) 620 raise TypeError("Bad mark type %s" % type(mark))
621
622 current_mark = await self.db.get_current_gc_mark()
623
624 if not current_mark or mark != current_mark:
625 raise bb.asyncrpc.InvokeError(
626 f"'{mark}' is not the current mark. Refusing to sweep"
550 ) 627 )
551 finally:
552 os.chdir(cwd)
553 628
554 logger.info('Listening on %r' % path) 629 count = await self.db.gc_sweep()
630
631 return {"count": count}
555 632
556 self._cleanup_socket = cleanup 633 @permissions(DB_ADMIN_PERM)
557 self.address = "unix://%s" % os.path.abspath(path) 634 async def handle_gc_status(self, request):
635 (keep_rows, remove_rows, current_mark) = await self.db.gc_status()
636 return {
637 "keep": keep_rows,
638 "remove": remove_rows,
639 "mark": current_mark,
640 }
641
642 @permissions(DB_ADMIN_PERM)
643 async def handle_clean_unused(self, request):
644 max_age = request["max_age_seconds"]
645 oldest = datetime.now() - timedelta(seconds=-max_age)
646 return {"count": await self.db.clean_unused(oldest)}
647
648 @permissions(DB_ADMIN_PERM)
649 async def handle_get_db_usage(self, request):
650 return {"usage": await self.db.get_usage()}
651
652 @permissions(DB_ADMIN_PERM)
653 async def handle_get_db_query_columns(self, request):
654 return {"columns": await self.db.get_query_columns()}
655
656 # The authentication API is always allowed
657 async def handle_auth(self, request):
658 username = str(request["username"])
659 token = str(request["token"])
660
661 async def fail_auth():
662 nonlocal username
663 # Rate limit bad login attempts
664 await asyncio.sleep(1)
665 raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")
666
667 user, db_token = await self.db.lookup_user_token(username)
668
669 if not user or not db_token:
670 await fail_auth()
558 671
559 async def handle_client(self, reader, writer):
560 # writer.transport.set_write_buffer_limits(0)
561 try: 672 try:
562 client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only) 673 algo, salt, _ = db_token.split(":")
563 await client.process_requests() 674 except ValueError:
564 except Exception as e: 675 await fail_auth()
565 import traceback
566 logger.error('Error from client: %s' % str(e), exc_info=True)
567 traceback.print_exc()
568 writer.close()
569 logger.info('Client disconnected')
570
571 @contextmanager
572 def _backfill_worker(self):
573 async def backfill_worker_task():
574 client = await create_async_client(self.upstream)
575 try:
576 while True:
577 item = await self.backfill_queue.get()
578 if item is None:
579 self.backfill_queue.task_done()
580 break
581 method, taskhash = item
582 await copy_from_upstream(client, self.db, method, taskhash)
583 self.backfill_queue.task_done()
584 finally:
585 await client.close()
586 676
587 async def join_worker(worker): 677 if hash_token(algo, salt, token) != db_token:
588 await self.backfill_queue.put(None) 678 await fail_auth()
589 await worker
590 679
591 if self.upstream is not None: 680 self.user = user
592 worker = asyncio.ensure_future(backfill_worker_task()) 681
593 try: 682 self.logger.info("Authenticated as %s", username)
594 yield 683
595 finally: 684 return {
596 self.loop.run_until_complete(join_worker(worker)) 685 "result": True,
597 else: 686 "username": self.user.username,
598 yield 687 "permissions": sorted(list(self.user.permissions)),
688 }
599 689
600 def serve_forever(self): 690 @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
601 def signal_handler(): 691 async def handle_refresh_token(self, request):
602 self.loop.stop() 692 username = str(request["username"])
603 693
604 asyncio.set_event_loop(self.loop) 694 token = await new_token()
695
696 updated = await self.db.set_user_token(
697 username,
698 hash_token(TOKEN_ALGORITHM, new_salt(), token),
699 )
700 if not updated:
701 self.raise_no_user_error(username)
702
703 return {"username": username, "token": token}
704
705 def get_perm_arg(self, arg):
706 if not isinstance(arg, list):
707 raise bb.asyncrpc.InvokeError("Unexpected type for permissions")
708
709 arg = set(arg)
605 try: 710 try:
606 self.backfill_queue = asyncio.Queue() 711 arg.remove(NONE_PERM)
712 except KeyError:
713 pass
714
715 unknown_perms = arg - ALL_PERMISSIONS
716 if unknown_perms:
717 raise bb.asyncrpc.InvokeError(
718 "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms)))
719 )
720
721 return sorted(list(arg))
722
723 def return_perms(self, permissions):
724 if ALL_PERM in permissions:
725 return sorted(list(ALL_PERMISSIONS))
726 return sorted(list(permissions))
607 727
608 self.loop.add_signal_handler(signal.SIGTERM, signal_handler) 728 @permissions(USER_ADMIN_PERM, allow_anon=False)
729 async def handle_set_perms(self, request):
730 username = str(request["username"])
731 permissions = self.get_perm_arg(request["permissions"])
609 732
610 with self._backfill_worker(): 733 if not await self.db.set_user_perms(username, permissions):
611 try: 734 self.raise_no_user_error(username)
612 self.loop.run_forever()
613 except KeyboardInterrupt:
614 pass
615 735
616 self.server.close() 736 return {
737 "username": username,
738 "permissions": self.return_perms(permissions),
739 }
740
741 @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
742 async def handle_get_user(self, request):
743 username = str(request["username"])
617 744
618 self.loop.run_until_complete(self.server.wait_closed()) 745 user = await self.db.lookup_user(username)
619 logger.info('Server shutting down') 746 if user is None:
620 finally: 747 return None
621 if self.close_loop: 748
622 if sys.version_info >= (3, 6): 749 return {
623 self.loop.run_until_complete(self.loop.shutdown_asyncgens()) 750 "username": user.username,
624 self.loop.close() 751 "permissions": self.return_perms(user.permissions),
752 }
753
754 @permissions(USER_ADMIN_PERM, allow_anon=False)
755 async def handle_get_all_users(self, request):
756 users = await self.db.get_all_users()
757 return {
758 "users": [
759 {
760 "username": u.username,
761 "permissions": self.return_perms(u.permissions),
762 }
763 for u in users
764 ]
765 }
625 766
626 if self._cleanup_socket is not None: 767 @permissions(USER_ADMIN_PERM, allow_anon=False)
627 self._cleanup_socket() 768 async def handle_new_user(self, request):
769 username = str(request["username"])
770 permissions = self.get_perm_arg(request["permissions"])
771
772 token = await new_token()
773
774 inserted = await self.db.new_user(
775 username,
776 permissions,
777 hash_token(TOKEN_ALGORITHM, new_salt(), token),
778 )
779 if not inserted:
780 raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")
781
782 return {
783 "username": username,
784 "permissions": self.return_perms(permissions),
785 "token": token,
786 }
787
788 @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
789 async def handle_delete_user(self, request):
790 username = str(request["username"])
791
792 if not await self.db.delete_user(username):
793 self.raise_no_user_error(username)
794
795 return {"username": username}
796
797 @permissions(USER_ADMIN_PERM, allow_anon=False)
798 async def handle_become_user(self, request):
799 username = str(request["username"])
800
801 user = await self.db.lookup_user(username)
802 if user is None:
803 raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")
804
805 self.user = user
806
807 self.logger.info("Became user %s", username)
808
809 return {
810 "username": self.user.username,
811 "permissions": self.return_perms(self.user.permissions),
812 }
813
814
815class Server(bb.asyncrpc.AsyncServer):
816 def __init__(
817 self,
818 db_engine,
819 upstream=None,
820 read_only=False,
821 anon_perms=DEFAULT_ANON_PERMS,
822 admin_username=None,
823 admin_password=None,
824 ):
825 if upstream and read_only:
826 raise bb.asyncrpc.ServerError(
827 "Read-only hashserv cannot pull from an upstream server"
828 )
829
830 disallowed_perms = set(anon_perms) - set(
831 [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
832 )
833
834 if disallowed_perms:
835 raise bb.asyncrpc.ServerError(
836 f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
837 )
838
839 super().__init__(logger)
840
841 self.request_stats = Stats()
842 self.db_engine = db_engine
843 self.upstream = upstream
844 self.read_only = read_only
845 self.backfill_queue = None
846 self.anon_perms = set(anon_perms)
847 self.admin_username = admin_username
848 self.admin_password = admin_password
849
850 self.logger.info(
851 "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
852 )
853
854 def accept_client(self, socket):
855 return ServerClient(socket, self)
856
857 async def create_admin_user(self):
858 admin_permissions = (ALL_PERM,)
859 async with self.db_engine.connect(self.logger) as db:
860 added = await db.new_user(
861 self.admin_username,
862 admin_permissions,
863 hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
864 )
865 if added:
866 self.logger.info("Created admin user '%s'", self.admin_username)
867 else:
868 await db.set_user_perms(
869 self.admin_username,
870 admin_permissions,
871 )
872 await db.set_user_token(
873 self.admin_username,
874 hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
875 )
876 self.logger.info("Admin user '%s' updated", self.admin_username)
877
878 async def backfill_worker_task(self):
879 async with await create_async_client(
880 self.upstream
881 ) as client, self.db_engine.connect(self.logger) as db:
882 while True:
883 item = await self.backfill_queue.get()
884 if item is None:
885 self.backfill_queue.task_done()
886 break
887
888 method, taskhash = item
889 d = await client.get_taskhash(method, taskhash)
890 if d is not None:
891 await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
892 self.backfill_queue.task_done()
893
894 def start(self):
895 tasks = super().start()
896 if self.upstream:
897 self.backfill_queue = asyncio.Queue()
898 tasks += [self.backfill_worker_task()]
899
900 self.loop.run_until_complete(self.db_engine.create())
901
902 if self.admin_username:
903 self.loop.run_until_complete(self.create_admin_user())
904
905 return tasks
906
907 async def stop(self):
908 if self.backfill_queue is not None:
909 await self.backfill_queue.put(None)
910 await super().stop()
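The gc handlers above implement a mark-and-sweep protocol: an admin client records a mark, tags the rows to keep, checks the keep/remove counts, and finally sweeps everything not carrying the current mark (handle_gc_sweep refuses if the supplied mark is stale). A minimal client-side sketch of that flow, assuming the synchronous hashserv client exposes gc_mark/gc_status/gc_sweep wrappers for these RPCs, and using a made-up socket path and credentials:

import hashserv

# Hypothetical address and admin credentials, for illustration only
with hashserv.create_client("unix:///run/hashserv.sock", "admin", "password") as client:
    # Tag every row for this method as live under the mark "gc-2023-10"
    client.gc_mark("gc-2023-10", {"method": "TestMethod"})

    # Report how many rows a sweep would keep and remove
    status = client.gc_status()
    print("keep=%(keep)d remove=%(remove)d mark=%(mark)s" % status)

    # Delete everything not carrying the current mark; the server
    # rejects the call if "gc-2023-10" is no longer the current mark
    removed = client.gc_sweep("gc-2023-10")["count"]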
diff --git a/bitbake/lib/hashserv/sqlalchemy.py b/bitbake/lib/hashserv/sqlalchemy.py
new file mode 100644
index 0000000000..f7b0226a7a
--- /dev/null
+++ b/bitbake/lib/hashserv/sqlalchemy.py
@@ -0,0 +1,598 @@
1#! /usr/bin/env python3
2#
3# Copyright (C) 2023 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7
8import logging
9from datetime import datetime
10from . import User
11
12from sqlalchemy.ext.asyncio import create_async_engine
13from sqlalchemy.pool import NullPool
14from sqlalchemy import (
15 MetaData,
16 Column,
17 Table,
18 Text,
19 Integer,
20 UniqueConstraint,
21 DateTime,
22 Index,
23 select,
24 insert,
25 exists,
26 literal,
27 and_,
28 delete,
29 update,
30 func,
31 inspect,
32)
33import sqlalchemy.engine
34from sqlalchemy.orm import declarative_base
35from sqlalchemy.exc import IntegrityError
36from sqlalchemy.dialects.postgresql import insert as postgres_insert
37
38Base = declarative_base()
39
40
41class UnihashesV3(Base):
42 __tablename__ = "unihashes_v3"
43 id = Column(Integer, primary_key=True, autoincrement=True)
44 method = Column(Text, nullable=False)
45 taskhash = Column(Text, nullable=False)
46 unihash = Column(Text, nullable=False)
47 gc_mark = Column(Text, nullable=False)
48
49 __table_args__ = (
50 UniqueConstraint("method", "taskhash"),
51 Index("taskhash_lookup_v4", "method", "taskhash"),
52 Index("unihash_lookup_v1", "unihash"),
53 )
54
55
56class OuthashesV2(Base):
57 __tablename__ = "outhashes_v2"
58 id = Column(Integer, primary_key=True, autoincrement=True)
59 method = Column(Text, nullable=False)
60 taskhash = Column(Text, nullable=False)
61 outhash = Column(Text, nullable=False)
62 created = Column(DateTime)
63 owner = Column(Text)
64 PN = Column(Text)
65 PV = Column(Text)
66 PR = Column(Text)
67 task = Column(Text)
68 outhash_siginfo = Column(Text)
69
70 __table_args__ = (
71 UniqueConstraint("method", "taskhash", "outhash"),
72 Index("outhash_lookup_v3", "method", "outhash"),
73 )
74
75
76class Users(Base):
77 __tablename__ = "users"
78 id = Column(Integer, primary_key=True, autoincrement=True)
79 username = Column(Text, nullable=False)
80 token = Column(Text, nullable=False)
81 permissions = Column(Text)
82
83 __table_args__ = (UniqueConstraint("username"),)
84
85
86class Config(Base):
87 __tablename__ = "config"
88 id = Column(Integer, primary_key=True, autoincrement=True)
89 name = Column(Text, nullable=False)
90 value = Column(Text)
91 __table_args__ = (
92 UniqueConstraint("name"),
93 Index("config_lookup", "name"),
94 )
95
96
97#
98# Old table versions
99#
100DeprecatedBase = declarative_base()
101
102
103class UnihashesV2(DeprecatedBase):
104 __tablename__ = "unihashes_v2"
105 id = Column(Integer, primary_key=True, autoincrement=True)
106 method = Column(Text, nullable=False)
107 taskhash = Column(Text, nullable=False)
108 unihash = Column(Text, nullable=False)
109
110 __table_args__ = (
111 UniqueConstraint("method", "taskhash"),
112 Index("taskhash_lookup_v3", "method", "taskhash"),
113 )
114
115
116class DatabaseEngine(object):
117 def __init__(self, url, username=None, password=None):
118 self.logger = logging.getLogger("hashserv.sqlalchemy")
119 self.url = sqlalchemy.engine.make_url(url)
120
121 if username is not None:
122 self.url = self.url.set(username=username)
123
124 if password is not None:
125 self.url = self.url.set(password=password)
126
127 async def create(self):
128 def check_table_exists(conn, name):
129 return inspect(conn).has_table(name)
130
131 self.logger.info("Using database %s", self.url)
132 if self.url.drivername == 'postgresql+psycopg':
133 # Psycopg 3 (psycopg) driver can handle async connection pooling
134 self.engine = create_async_engine(self.url, max_overflow=-1)
135 else:
136 self.engine = create_async_engine(self.url, poolclass=NullPool)
137
138 async with self.engine.begin() as conn:
139 # Create tables
140 self.logger.info("Creating tables...")
141 await conn.run_sync(Base.metadata.create_all)
142
143 if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__):
144 self.logger.info("Upgrading Unihashes V2 -> V3...")
145 statement = insert(UnihashesV3).from_select(
146 ["id", "method", "unihash", "taskhash", "gc_mark"],
147 select(
148 UnihashesV2.id,
149 UnihashesV2.method,
150 UnihashesV2.unihash,
151 UnihashesV2.taskhash,
152 literal("").label("gc_mark"),
153 ),
154 )
155 self.logger.debug("%s", statement)
156 await conn.execute(statement)
157
158 await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__])
159 self.logger.info("Upgrade complete")
160
161 def connect(self, logger):
162 return Database(self.engine, logger)
163
164
165def map_row(row):
166 if row is None:
167 return None
168 return dict(**row._mapping)
169
170
171def map_user(row):
172 if row is None:
173 return None
174 return User(
175 username=row.username,
176 permissions=set(row.permissions.split()),
177 )
178
179
180def _make_condition_statement(table, condition):
181 where = {}
182 for c in table.__table__.columns:
183 if c.key in condition and condition[c.key] is not None:
184 where[c] = condition[c.key]
185
186 return [(k == v) for k, v in where.items()]
187
188
189class Database(object):
190 def __init__(self, engine, logger):
191 self.engine = engine
192 self.db = None
193 self.logger = logger
194
195 async def __aenter__(self):
196 self.db = await self.engine.connect()
197 return self
198
199 async def __aexit__(self, exc_type, exc_value, traceback):
200 await self.close()
201
202 async def close(self):
203 await self.db.close()
204 self.db = None
205
206 async def _execute(self, statement):
207 self.logger.debug("%s", statement)
208 return await self.db.execute(statement)
209
210 async def _set_config(self, name, value):
211 while True:
212 result = await self._execute(
213 update(Config).where(Config.name == name).values(value=value)
214 )
215
216 if result.rowcount == 0:
217 self.logger.debug("Config '%s' not found. Adding it", name)
218 try:
219 await self._execute(insert(Config).values(name=name, value=value))
220 except IntegrityError:
221 # Race. Try again
222 continue
223
224 break
225
226 def _get_config_subquery(self, name, default=None):
227 if default is not None:
228 return func.coalesce(
229 select(Config.value).where(Config.name == name).scalar_subquery(),
230 default,
231 )
232 return select(Config.value).where(Config.name == name).scalar_subquery()
233
234 async def _get_config(self, name):
235 result = await self._execute(select(Config.value).where(Config.name == name))
236 row = result.first()
237 if row is None:
238 return None
239 return row.value
240
241 async def get_unihash_by_taskhash_full(self, method, taskhash):
242 async with self.db.begin():
243 result = await self._execute(
244 select(
245 OuthashesV2,
246 UnihashesV3.unihash.label("unihash"),
247 )
248 .join(
249 UnihashesV3,
250 and_(
251 UnihashesV3.method == OuthashesV2.method,
252 UnihashesV3.taskhash == OuthashesV2.taskhash,
253 ),
254 )
255 .where(
256 OuthashesV2.method == method,
257 OuthashesV2.taskhash == taskhash,
258 )
259 .order_by(
260 OuthashesV2.created.asc(),
261 )
262 .limit(1)
263 )
264 return map_row(result.first())
265
266 async def get_unihash_by_outhash(self, method, outhash):
267 async with self.db.begin():
268 result = await self._execute(
269 select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
270 .join(
271 UnihashesV3,
272 and_(
273 UnihashesV3.method == OuthashesV2.method,
274 UnihashesV3.taskhash == OuthashesV2.taskhash,
275 ),
276 )
277 .where(
278 OuthashesV2.method == method,
279 OuthashesV2.outhash == outhash,
280 )
281 .order_by(
282 OuthashesV2.created.asc(),
283 )
284 .limit(1)
285 )
286 return map_row(result.first())
287
288 async def unihash_exists(self, unihash):
289 async with self.db.begin():
290 result = await self._execute(
291 select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1)
292 )
293
294 return result.first() is not None
295
296 async def get_outhash(self, method, outhash):
297 async with self.db.begin():
298 result = await self._execute(
299 select(OuthashesV2)
300 .where(
301 OuthashesV2.method == method,
302 OuthashesV2.outhash == outhash,
303 )
304 .order_by(
305 OuthashesV2.created.asc(),
306 )
307 .limit(1)
308 )
309 return map_row(result.first())
310
311 async def get_equivalent_for_outhash(self, method, outhash, taskhash):
312 async with self.db.begin():
313 result = await self._execute(
314 select(
315 OuthashesV2.taskhash.label("taskhash"),
316 UnihashesV3.unihash.label("unihash"),
317 )
318 .join(
319 UnihashesV3,
320 and_(
321 UnihashesV3.method == OuthashesV2.method,
322 UnihashesV3.taskhash == OuthashesV2.taskhash,
323 ),
324 )
325 .where(
326 OuthashesV2.method == method,
327 OuthashesV2.outhash == outhash,
328 OuthashesV2.taskhash != taskhash,
329 )
330 .order_by(
331 OuthashesV2.created.asc(),
332 )
333 .limit(1)
334 )
335 return map_row(result.first())
336
337 async def get_equivalent(self, method, taskhash):
338 async with self.db.begin():
339 result = await self._execute(
340 select(
341 UnihashesV3.unihash,
342 UnihashesV3.method,
343 UnihashesV3.taskhash,
344 ).where(
345 UnihashesV3.method == method,
346 UnihashesV3.taskhash == taskhash,
347 )
348 )
349 return map_row(result.first())
350
351 async def remove(self, condition):
352 async def do_remove(table):
353 where = _make_condition_statement(table, condition)
354 if where:
355 async with self.db.begin():
356 result = await self._execute(delete(table).where(*where))
357 return result.rowcount
358
359 return 0
360
361 count = 0
362 count += await do_remove(UnihashesV3)
363 count += await do_remove(OuthashesV2)
364
365 return count
366
367 async def get_current_gc_mark(self):
368 async with self.db.begin():
369 return await self._get_config("gc-mark")
370
371 async def gc_status(self):
372 async with self.db.begin():
373 gc_mark_subquery = self._get_config_subquery("gc-mark", "")
374
375 result = await self._execute(
376 select(func.count())
377 .select_from(UnihashesV3)
378 .where(UnihashesV3.gc_mark == gc_mark_subquery)
379 )
380 keep_rows = result.scalar()
381
382 result = await self._execute(
383 select(func.count())
384 .select_from(UnihashesV3)
385 .where(UnihashesV3.gc_mark != gc_mark_subquery)
386 )
387 remove_rows = result.scalar()
388
389 return (keep_rows, remove_rows, await self._get_config("gc-mark"))
390
391 async def gc_mark(self, mark, condition):
392 async with self.db.begin():
393 await self._set_config("gc-mark", mark)
394
395 where = _make_condition_statement(UnihashesV3, condition)
396 if not where:
397 return 0
398
399 result = await self._execute(
400 update(UnihashesV3)
401 .values(gc_mark=self._get_config_subquery("gc-mark", ""))
402 .where(*where)
403 )
404 return result.rowcount
405
406 async def gc_sweep(self):
407 async with self.db.begin():
408 result = await self._execute(
409 delete(UnihashesV3).where(
410 # A sneaky conditional that provides some protection
411 # against errant use: if the config mark is NULL, this
412 # will not match any rows because no default is
413 # specified in the select statement
414 UnihashesV3.gc_mark
415 != self._get_config_subquery("gc-mark")
416 )
417 )
418 await self._set_config("gc-mark", None)
419
420 return result.rowcount
421
422 async def clean_unused(self, oldest):
423 async with self.db.begin():
424 result = await self._execute(
425 delete(OuthashesV2).where(
426 OuthashesV2.created < oldest,
427 ~(
428 select(UnihashesV3.id)
429 .where(
430 UnihashesV3.method == OuthashesV2.method,
431 UnihashesV3.taskhash == OuthashesV2.taskhash,
432 )
433 .limit(1)
434 .exists()
435 ),
436 )
437 )
438 return result.rowcount
439
440 async def insert_unihash(self, method, taskhash, unihash):
441 # Postgres specific ignore on insert duplicate
442 if self.engine.name == "postgresql":
443 statement = (
444 postgres_insert(UnihashesV3)
445 .values(
446 method=method,
447 taskhash=taskhash,
448 unihash=unihash,
449 gc_mark=self._get_config_subquery("gc-mark", ""),
450 )
451 .on_conflict_do_nothing(index_elements=("method", "taskhash"))
452 )
453 else:
454 statement = insert(UnihashesV3).values(
455 method=method,
456 taskhash=taskhash,
457 unihash=unihash,
458 gc_mark=self._get_config_subquery("gc-mark", ""),
459 )
460
461 try:
462 async with self.db.begin():
463 result = await self._execute(statement)
464 return result.rowcount != 0
465 except IntegrityError:
466 self.logger.debug(
467 "%s, %s, %s already in unihash database", method, taskhash, unihash
468 )
469 return False
470
471 async def insert_outhash(self, data):
472 outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)
473
474 data = {k: v for k, v in data.items() if k in outhash_columns}
475
476 if "created" in data and not isinstance(data["created"], datetime):
477 data["created"] = datetime.fromisoformat(data["created"])
478
479 # Postgres specific ignore on insert duplicate
480 if self.engine.name == "postgresql":
481 statement = (
482 postgres_insert(OuthashesV2)
483 .values(**data)
484 .on_conflict_do_nothing(
485 index_elements=("method", "taskhash", "outhash")
486 )
487 )
488 else:
489 statement = insert(OuthashesV2).values(**data)
490
491 try:
492 async with self.db.begin():
493 result = await self._execute(statement)
494 return result.rowcount != 0
495 except IntegrityError:
496 self.logger.debug(
497 "%s, %s already in outhash database", data["method"], data["outhash"]
498 )
499 return False
500
501 async def _get_user(self, username):
502 async with self.db.begin():
503 result = await self._execute(
504 select(
505 Users.username,
506 Users.permissions,
507 Users.token,
508 ).where(
509 Users.username == username,
510 )
511 )
512 return result.first()
513
514 async def lookup_user_token(self, username):
515 row = await self._get_user(username)
516 if not row:
517 return None, None
518 return map_user(row), row.token
519
520 async def lookup_user(self, username):
521 return map_user(await self._get_user(username))
522
523 async def set_user_token(self, username, token):
524 async with self.db.begin():
525 result = await self._execute(
526 update(Users)
527 .where(
528 Users.username == username,
529 )
530 .values(
531 token=token,
532 )
533 )
534 return result.rowcount != 0
535
536 async def set_user_perms(self, username, permissions):
537 async with self.db.begin():
538 result = await self._execute(
539 update(Users)
540 .where(Users.username == username)
541 .values(permissions=" ".join(permissions))
542 )
543 return result.rowcount != 0
544
545 async def get_all_users(self):
546 async with self.db.begin():
547 result = await self._execute(
548 select(
549 Users.username,
550 Users.permissions,
551 )
552 )
553 return [map_user(row) for row in result]
554
555 async def new_user(self, username, permissions, token):
556 try:
557 async with self.db.begin():
558 await self._execute(
559 insert(Users).values(
560 username=username,
561 permissions=" ".join(permissions),
562 token=token,
563 )
564 )
565 return True
566 except IntegrityError as e:
567 self.logger.debug("Cannot create new user %s: %s", username, e)
568 return False
569
570 async def delete_user(self, username):
571 async with self.db.begin():
572 result = await self._execute(
573 delete(Users).where(Users.username == username)
574 )
575 return result.rowcount != 0
576
577 async def get_usage(self):
578 usage = {}
579 async with self.db.begin() as session:
580 for name, table in Base.metadata.tables.items():
581 result = await self._execute(
582 statement=select(func.count()).select_from(table)
583 )
584 usage[name] = {
585 "rows": result.scalar(),
586 }
587
588 return usage
589
590 async def get_query_columns(self):
591 columns = set()
592 for table in (UnihashesV3, OuthashesV2):
593 for c in table.__table__.columns:
594 if not isinstance(c.type, Text):
595 continue
596 columns.add(c.key)
597
598 return list(columns)
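A short usage sketch for the engine defined above, assuming a PostgreSQL database behind the made-up URL and credentials shown (any SQLAlchemy async URL should work the same way; the hash values are placeholders):

import asyncio
import logging

from hashserv.sqlalchemy import DatabaseEngine

async def main():
    # Placeholder connection details; the psycopg 3 driver enables
    # pooled async access, other drivers fall back to NullPool
    engine = DatabaseEngine(
        "postgresql+psycopg://db.example.com/hashserv",
        username="hashserv",
        password="secret",
    )
    # Creates the tables and migrates unihashes_v2 -> unihashes_v3 if present
    await engine.create()

    async with engine.connect(logging.getLogger("demo")) as db:
        await db.insert_unihash("TestMethod", "aabbcc", "ddeeff")
        print(await db.get_equivalent("TestMethod", "aabbcc"))

asyncio.run(main())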
diff --git a/bitbake/lib/hashserv/sqlite.py b/bitbake/lib/hashserv/sqlite.py
new file mode 100644
index 0000000000..976504d7f4
--- /dev/null
+++ b/bitbake/lib/hashserv/sqlite.py
@@ -0,0 +1,579 @@
1#! /usr/bin/env python3
2#
3# Copyright (C) 2023 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7from datetime import datetime, timezone
8import sqlite3
9import logging
10from contextlib import closing
11from . import User
12
13logger = logging.getLogger("hashserv.sqlite")
14
15UNIHASH_TABLE_DEFINITION = (
16 ("method", "TEXT NOT NULL", "UNIQUE"),
17 ("taskhash", "TEXT NOT NULL", "UNIQUE"),
18 ("unihash", "TEXT NOT NULL", ""),
19 ("gc_mark", "TEXT NOT NULL", ""),
20)
21
22UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
23
24OUTHASH_TABLE_DEFINITION = (
25 ("method", "TEXT NOT NULL", "UNIQUE"),
26 ("taskhash", "TEXT NOT NULL", "UNIQUE"),
27 ("outhash", "TEXT NOT NULL", "UNIQUE"),
28 ("created", "DATETIME", ""),
29 # Optional fields
30 ("owner", "TEXT", ""),
31 ("PN", "TEXT", ""),
32 ("PV", "TEXT", ""),
33 ("PR", "TEXT", ""),
34 ("task", "TEXT", ""),
35 ("outhash_siginfo", "TEXT", ""),
36)
37
38OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
39
40USERS_TABLE_DEFINITION = (
41 ("username", "TEXT NOT NULL", "UNIQUE"),
42 ("token", "TEXT NOT NULL", ""),
43 ("permissions", "TEXT NOT NULL", ""),
44)
45
46USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION)
47
48
49CONFIG_TABLE_DEFINITION = (
50 ("name", "TEXT NOT NULL", "UNIQUE"),
51 ("value", "TEXT", ""),
52)
53
54CONFIG_TABLE_COLUMNS = tuple(name for name, _, _ in CONFIG_TABLE_DEFINITION)
55
56
57def adapt_datetime_iso(val):
58 """Adapt datetime.datetime to UTC ISO 8601 date."""
59 return val.astimezone(timezone.utc).isoformat()
60
61
62sqlite3.register_adapter(datetime, adapt_datetime_iso)
63
64
65def convert_datetime(val):
66 """Convert ISO 8601 datetime to datetime.datetime object."""
67 return datetime.fromisoformat(val.decode())
68
69
70sqlite3.register_converter("DATETIME", convert_datetime)
71
72
73def _make_table(cursor, name, definition):
74 cursor.execute(
75 """
76 CREATE TABLE IF NOT EXISTS {name} (
77 id INTEGER PRIMARY KEY AUTOINCREMENT,
78 {fields}
79 UNIQUE({unique})
80 )
81 """.format(
82 name=name,
83 fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
84 unique=", ".join(
85 name for name, _, flags in definition if "UNIQUE" in flags
86 ),
87 )
88 )
89
90
91def map_user(row):
92 if row is None:
93 return None
94 return User(
95 username=row["username"],
96 permissions=set(row["permissions"].split()),
97 )
98
99
100def _make_condition_statement(columns, condition):
101 where = {}
102 for c in columns:
103 if c in condition and condition[c] is not None:
104 where[c] = condition[c]
105
106 return where, " AND ".join("%s=:%s" % (k, k) for k in where.keys())
107
108
109def _get_sqlite_version(cursor):
110 cursor.execute("SELECT sqlite_version()")
111
112 version = []
113 for v in cursor.fetchone()[0].split("."):
114 try:
115 version.append(int(v))
116 except ValueError:
117 version.append(v)
118
119 return tuple(version)
120
121
122def _schema_table_name(version):
123 if version >= (3, 33):
124 return "sqlite_schema"
125
126 return "sqlite_master"
127
128
129class DatabaseEngine(object):
130 def __init__(self, dbname, sync):
131 self.dbname = dbname
132 self.logger = logger
133 self.sync = sync
134
135 async def create(self):
136 db = sqlite3.connect(self.dbname)
137 db.row_factory = sqlite3.Row
138
139 with closing(db.cursor()) as cursor:
140 _make_table(cursor, "unihashes_v3", UNIHASH_TABLE_DEFINITION)
141 _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
142 _make_table(cursor, "users", USERS_TABLE_DEFINITION)
143 _make_table(cursor, "config", CONFIG_TABLE_DEFINITION)
144
145 cursor.execute("PRAGMA journal_mode = WAL")
146 cursor.execute(
147 "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF")
148 )
149
150 # Drop old indexes
151 cursor.execute("DROP INDEX IF EXISTS taskhash_lookup")
152 cursor.execute("DROP INDEX IF EXISTS outhash_lookup")
153 cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2")
154 cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2")
155 cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v3")
156
157 # TODO: Upgrade from tasks_v2?
158 cursor.execute("DROP TABLE IF EXISTS tasks_v2")
159
160 # Create new indexes
161 cursor.execute(
162 "CREATE INDEX IF NOT EXISTS taskhash_lookup_v4 ON unihashes_v3 (method, taskhash)"
163 )
164 cursor.execute(
165 "CREATE INDEX IF NOT EXISTS unihash_lookup_v1 ON unihashes_v3 (unihash)"
166 )
167 cursor.execute(
168 "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)"
169 )
170 cursor.execute("CREATE INDEX IF NOT EXISTS config_lookup ON config (name)")
171
172 sqlite_version = _get_sqlite_version(cursor)
173
174 cursor.execute(
175 f"""
176 SELECT name FROM {_schema_table_name(sqlite_version)} WHERE type = 'table' AND name = 'unihashes_v2'
177 """
178 )
179 if cursor.fetchone():
180 self.logger.info("Upgrading Unihashes V2 -> V3...")
181 cursor.execute(
182 """
183 INSERT INTO unihashes_v3 (id, method, unihash, taskhash, gc_mark)
184 SELECT id, method, unihash, taskhash, '' FROM unihashes_v2
185 """
186 )
187 cursor.execute("DROP TABLE unihashes_v2")
188 db.commit()
189 self.logger.info("Upgrade complete")
190
191 def connect(self, logger):
192 return Database(logger, self.dbname, self.sync)
193
194
195class Database(object):
196 def __init__(self, logger, dbname, sync):
197 self.dbname = dbname
198 self.logger = logger
199
200 self.db = sqlite3.connect(self.dbname)
201 self.db.row_factory = sqlite3.Row
202
203 with closing(self.db.cursor()) as cursor:
204 cursor.execute("PRAGMA journal_mode = WAL")
205 cursor.execute(
206 "PRAGMA synchronous = %s" % ("NORMAL" if sync else "OFF")
207 )
208
209 self.sqlite_version = _get_sqlite_version(cursor)
210
211 async def __aenter__(self):
212 return self
213
214 async def __aexit__(self, exc_type, exc_value, traceback):
215 await self.close()
216
217 async def _set_config(self, cursor, name, value):
218 cursor.execute(
219 """
220 INSERT OR REPLACE INTO config (id, name, value) VALUES
221 ((SELECT id FROM config WHERE name=:name), :name, :value)
222 """,
223 {
224 "name": name,
225 "value": value,
226 },
227 )
228
229 async def _get_config(self, cursor, name):
230 cursor.execute(
231 "SELECT value FROM config WHERE name=:name",
232 {
233 "name": name,
234 },
235 )
236 row = cursor.fetchone()
237 if row is None:
238 return None
239 return row["value"]
240
241 async def close(self):
242 self.db.close()
243
244 async def get_unihash_by_taskhash_full(self, method, taskhash):
245 with closing(self.db.cursor()) as cursor:
246 cursor.execute(
247 """
248 SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2
249 INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
250 WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
251 ORDER BY outhashes_v2.created ASC
252 LIMIT 1
253 """,
254 {
255 "method": method,
256 "taskhash": taskhash,
257 },
258 )
259 return cursor.fetchone()
260
261 async def get_unihash_by_outhash(self, method, outhash):
262 with closing(self.db.cursor()) as cursor:
263 cursor.execute(
264 """
265 SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2
266 INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
267 WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
268 ORDER BY outhashes_v2.created ASC
269 LIMIT 1
270 """,
271 {
272 "method": method,
273 "outhash": outhash,
274 },
275 )
276 return cursor.fetchone()
277
278 async def unihash_exists(self, unihash):
279 with closing(self.db.cursor()) as cursor:
280 cursor.execute(
281 """
282 SELECT * FROM unihashes_v3 WHERE unihash=:unihash
283 LIMIT 1
284 """,
285 {
286 "unihash": unihash,
287 },
288 )
289 return cursor.fetchone() is not None
290
291 async def get_outhash(self, method, outhash):
292 with closing(self.db.cursor()) as cursor:
293 cursor.execute(
294 """
295 SELECT * FROM outhashes_v2
296 WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
297 ORDER BY outhashes_v2.created ASC
298 LIMIT 1
299 """,
300 {
301 "method": method,
302 "outhash": outhash,
303 },
304 )
305 return cursor.fetchone()
306
307 async def get_equivalent_for_outhash(self, method, outhash, taskhash):
308 with closing(self.db.cursor()) as cursor:
309 cursor.execute(
310 """
311 SELECT outhashes_v2.taskhash AS taskhash, unihashes_v3.unihash AS unihash FROM outhashes_v2
312 INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
313 -- Select any matching output hash except the one we just inserted
314 WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
315 -- Pick the oldest hash
316 ORDER BY outhashes_v2.created ASC
317 LIMIT 1
318 """,
319 {
320 "method": method,
321 "outhash": outhash,
322 "taskhash": taskhash,
323 },
324 )
325 return cursor.fetchone()
326
327 async def get_equivalent(self, method, taskhash):
328 with closing(self.db.cursor()) as cursor:
329 cursor.execute(
330 "SELECT taskhash, method, unihash FROM unihashes_v3 WHERE method=:method AND taskhash=:taskhash",
331 {
332 "method": method,
333 "taskhash": taskhash,
334 },
335 )
336 return cursor.fetchone()
337
338 async def remove(self, condition):
339 def do_remove(columns, table_name, cursor):
340 where, clause = _make_condition_statement(columns, condition)
341 if where:
342 query = f"DELETE FROM {table_name} WHERE {clause}"
343 cursor.execute(query, where)
344 return cursor.rowcount
345
346 return 0
347
348 count = 0
349 with closing(self.db.cursor()) as cursor:
350 count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
351 count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v3", cursor)
352 self.db.commit()
353
354 return count
355
356 async def get_current_gc_mark(self):
357 with closing(self.db.cursor()) as cursor:
358 return await self._get_config(cursor, "gc-mark")
359
360 async def gc_status(self):
361 with closing(self.db.cursor()) as cursor:
362 cursor.execute(
363 """
364 SELECT COUNT() FROM unihashes_v3 WHERE
365 gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
366 """
367 )
368 keep_rows = cursor.fetchone()[0]
369
370 cursor.execute(
371 """
372 SELECT COUNT() FROM unihashes_v3 WHERE
373 gc_mark!=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
374 """
375 )
376 remove_rows = cursor.fetchone()[0]
377
378 current_mark = await self._get_config(cursor, "gc-mark")
379
380 return (keep_rows, remove_rows, current_mark)
381
382 async def gc_mark(self, mark, condition):
383 with closing(self.db.cursor()) as cursor:
384 await self._set_config(cursor, "gc-mark", mark)
385
386 where, clause = _make_condition_statement(UNIHASH_TABLE_COLUMNS, condition)
387
388 new_rows = 0
389 if where:
390 cursor.execute(
391 f"""
392 UPDATE unihashes_v3 SET
393 gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
394 WHERE {clause}
395 """,
396 where,
397 )
398 new_rows = cursor.rowcount
399
400 self.db.commit()
401 return new_rows
402
403 async def gc_sweep(self):
404 with closing(self.db.cursor()) as cursor:
405 # NOTE: COALESCE is not used in this query so that if the current
406 # mark is NULL, nothing will happen
407 cursor.execute(
408 """
409 DELETE FROM unihashes_v3 WHERE
410 gc_mark!=(SELECT value FROM config WHERE name='gc-mark')
411 """
412 )
413 count = cursor.rowcount
414 await self._set_config(cursor, "gc-mark", None)
415
416 self.db.commit()
417 return count
418
419 async def clean_unused(self, oldest):
420 with closing(self.db.cursor()) as cursor:
421 cursor.execute(
422 """
423 DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
424 SELECT unihashes_v3.id FROM unihashes_v3 WHERE unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash LIMIT 1
425 )
426 """,
427 {
428 "oldest": oldest,
429 },
430 )
431 self.db.commit()
432 return cursor.rowcount
433
434 async def insert_unihash(self, method, taskhash, unihash):
435 with closing(self.db.cursor()) as cursor:
436 prevrowid = cursor.lastrowid
437 cursor.execute(
438 """
439 INSERT OR IGNORE INTO unihashes_v3 (method, taskhash, unihash, gc_mark) VALUES
440 (
441 :method,
442 :taskhash,
443 :unihash,
444 COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
445 )
446 """,
447 {
448 "method": method,
449 "taskhash": taskhash,
450 "unihash": unihash,
451 },
452 )
453 self.db.commit()
454 return cursor.lastrowid != prevrowid
455
456 async def insert_outhash(self, data):
457 data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS}
458 keys = sorted(data.keys())
459 query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format(
460 fields=", ".join(keys),
461 values=", ".join(":" + k for k in keys),
462 )
463 with closing(self.db.cursor()) as cursor:
464 prevrowid = cursor.lastrowid
465 cursor.execute(query, data)
466 self.db.commit()
467 return cursor.lastrowid != prevrowid
468
469 def _get_user(self, username):
470 with closing(self.db.cursor()) as cursor:
471 cursor.execute(
472 """
473 SELECT username, permissions, token FROM users WHERE username=:username
474 """,
475 {
476 "username": username,
477 },
478 )
479 return cursor.fetchone()
480
481 async def lookup_user_token(self, username):
482 row = self._get_user(username)
483 if row is None:
484 return None, None
485 return map_user(row), row["token"]
486
487 async def lookup_user(self, username):
488 return map_user(self._get_user(username))
489
490 async def set_user_token(self, username, token):
491 with closing(self.db.cursor()) as cursor:
492 cursor.execute(
493 """
494 UPDATE users SET token=:token WHERE username=:username
495 """,
496 {
497 "username": username,
498 "token": token,
499 },
500 )
501 self.db.commit()
502 return cursor.rowcount != 0
503
504 async def set_user_perms(self, username, permissions):
505 with closing(self.db.cursor()) as cursor:
506 cursor.execute(
507 """
508 UPDATE users SET permissions=:permissions WHERE username=:username
509 """,
510 {
511 "username": username,
512 "permissions": " ".join(permissions),
513 },
514 )
515 self.db.commit()
516 return cursor.rowcount != 0
517
518 async def get_all_users(self):
519 with closing(self.db.cursor()) as cursor:
520 cursor.execute("SELECT username, permissions FROM users")
521 return [map_user(r) for r in cursor.fetchall()]
522
523 async def new_user(self, username, permissions, token):
524 with closing(self.db.cursor()) as cursor:
525 try:
526 cursor.execute(
527 """
528 INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions)
529 """,
530 {
531 "username": username,
532 "token": token,
533 "permissions": " ".join(permissions),
534 },
535 )
536 self.db.commit()
537 return True
538 except sqlite3.IntegrityError:
539 return False
540
541 async def delete_user(self, username):
542 with closing(self.db.cursor()) as cursor:
543 cursor.execute(
544 """
545 DELETE FROM users WHERE username=:username
546 """,
547 {
548 "username": username,
549 },
550 )
551 self.db.commit()
552 return cursor.rowcount != 0
553
554 async def get_usage(self):
555 usage = {}
556 with closing(self.db.cursor()) as cursor:
557 cursor.execute(
558 f"""
559 SELECT name FROM {_schema_table_name(self.sqlite_version)} WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
560 """
561 )
562 for row in cursor.fetchall():
563 cursor.execute(
564 """
565 SELECT COUNT() FROM %s
566 """
567 % row["name"],
568 )
569 usage[row["name"]] = {
570 "rows": cursor.fetchone()[0],
571 }
572 return usage
573
574 async def get_query_columns(self):
575 columns = set()
576 for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION:
577 if typ.startswith("TEXT"):
578 columns.add(name)
579 return list(columns)
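The adapter/converter pair registered at the top of this file stores datetimes as UTC ISO 8601 text and parses them back on read. A standalone sketch of the same pattern, independent of the hashserv schema; note that sqlite3 only applies registered converters when the connection is opened with detect_types:

import sqlite3
from datetime import datetime, timezone

def adapt_datetime_iso(val):
    # Store datetimes as UTC ISO 8601 strings
    return val.astimezone(timezone.utc).isoformat()

def convert_datetime(val):
    # Parse ISO 8601 bytes back into a datetime object
    return datetime.fromisoformat(val.decode())

sqlite3.register_adapter(datetime, adapt_datetime_iso)
sqlite3.register_converter("DATETIME", convert_datetime)

# PARSE_DECLTYPES makes the DATETIME column declaration trigger the converter
db = sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES)
db.execute("CREATE TABLE t (created DATETIME)")
db.execute("INSERT INTO t VALUES (?)", (datetime.now(timezone.utc),))
value = db.execute("SELECT created FROM t").fetchone()[0]
assert isinstance(value, datetime)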
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py
index 1a696481e3..da3f8e0884 100644
--- a/bitbake/lib/hashserv/tests.py
+++ b/bitbake/lib/hashserv/tests.py
@@ -6,7 +6,8 @@
6# 6#
7 7
8from . import create_server, create_client 8from . import create_server, create_client
9from .client import HashConnectionError 9from .server import DEFAULT_ANON_PERMS, ALL_PERMISSIONS
10from bb.asyncrpc import InvokeError
10import hashlib 11import hashlib
11import logging 12import logging
12import multiprocessing 13import multiprocessing
@@ -16,72 +17,161 @@ import tempfile
16import threading 17import threading
17import unittest 18import unittest
18import socket 19import socket
19 20import time
20def _run_server(server, idx): 21import signal
21 # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', 22import subprocess
22 # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') 23import json
23 sys.stdout = open('bbhashserv-%d.log' % idx, 'w') 24import re
25from pathlib import Path
26
27
28THIS_DIR = Path(__file__).parent
29BIN_DIR = THIS_DIR.parent.parent / "bin"
30
31def server_prefunc(server, idx):
32 logging.basicConfig(level=logging.DEBUG, filename='bbhashserv-%d.log' % idx, filemode='w',
33 format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
34 server.logger.debug("Running server %d" % idx)
35 sys.stdout = open('bbhashserv-stdout-%d.log' % idx, 'w')
24 sys.stderr = sys.stdout 36 sys.stderr = sys.stdout
25 server.serve_forever()
26
27 37
28class HashEquivalenceTestSetup(object): 38class HashEquivalenceTestSetup(object):
29 METHOD = 'TestMethod' 39 METHOD = 'TestMethod'
30 40
31 server_index = 0 41 server_index = 0
42 client_index = 0
32 43
33 def start_server(self, dbpath=None, upstream=None, read_only=False): 44 def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc, anon_perms=DEFAULT_ANON_PERMS, admin_username=None, admin_password=None):
34 self.server_index += 1 45 self.server_index += 1
35 if dbpath is None: 46 if dbpath is None:
36 dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) 47 dbpath = self.make_dbpath()
48
49 def cleanup_server(server):
50 if server.process.exitcode is not None:
51 return
37 52
38 def cleanup_thread(thread): 53 server.process.terminate()
39 thread.terminate() 54 server.process.join()
40 thread.join()
41 55
42 server = create_server(self.get_server_addr(self.server_index), 56 server = create_server(self.get_server_addr(self.server_index),
43 dbpath, 57 dbpath,
44 upstream=upstream, 58 upstream=upstream,
45 read_only=read_only) 59 read_only=read_only,
60 anon_perms=anon_perms,
61 admin_username=admin_username,
62 admin_password=admin_password)
46 server.dbpath = dbpath 63 server.dbpath = dbpath
47 64
48 server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) 65 server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
49 server.thread.start() 66 self.addCleanup(cleanup_server, server)
50 self.addCleanup(cleanup_thread, server.thread) 67
68 return server
69
70 def make_dbpath(self):
71 return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
51 72
73 def start_client(self, server_address, username=None, password=None):
52 def cleanup_client(client): 74 def cleanup_client(client):
53 client.close() 75 client.close()
54 76
55 client = create_client(server.address) 77 client = create_client(server_address, username=username, password=password)
56 self.addCleanup(cleanup_client, client) 78 self.addCleanup(cleanup_client, client)
57 79
58 return (client, server) 80 return client
59 81
60 def setUp(self): 82 def start_test_server(self):
61 if sys.version_info < (3, 5, 0): 83 self.server = self.start_server()
62 self.skipTest('Python 3.5 or later required') 84 return self.server.address
85
86 def start_auth_server(self):
87 auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
88 self.auth_server_address = auth_server.address
89 self.admin_client = self.start_client(auth_server.address, username="admin", password="password")
90 return self.admin_client
91
92 def auth_client(self, user):
93 return self.start_client(self.auth_server_address, user["username"], user["token"])
63 94
95 def setUp(self):
64 self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv') 96 self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv')
65 self.addCleanup(self.temp_dir.cleanup) 97 self.addCleanup(self.temp_dir.cleanup)
66 98
67 (self.client, self.server) = self.start_server() 99 self.server_address = self.start_test_server()
100
101 self.client = self.start_client(self.server_address)
68 102
69 def assertClientGetHash(self, client, taskhash, unihash): 103 def assertClientGetHash(self, client, taskhash, unihash):
70 result = client.get_unihash(self.METHOD, taskhash) 104 result = client.get_unihash(self.METHOD, taskhash)
71 self.assertEqual(result, unihash) 105 self.assertEqual(result, unihash)
72 106
107 def assertUserPerms(self, user, permissions):
108 with self.auth_client(user) as client:
109 info = client.get_user()
110 self.assertEqual(info, {
111 "username": user["username"],
112 "permissions": permissions,
113 })
73 114
74class HashEquivalenceCommonTests(object): 115 def assertUserCanAuth(self, user):
75 def test_create_hash(self): 116 with self.start_client(self.auth_server_address) as client:
117 client.auth(user["username"], user["token"])
118
119 def assertUserCannotAuth(self, user):
120 with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
121 client.auth(user["username"], user["token"])
122
123 def create_test_hash(self, client):
76 # Simple test that hashes can be created 124 # Simple test that hashes can be created
77 taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' 125 taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
78 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' 126 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
79 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' 127 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
80 128
81 self.assertClientGetHash(self.client, taskhash, None) 129 self.assertClientGetHash(client, taskhash, None)
82 130
83 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) 131 result = client.report_unihash(taskhash, self.METHOD, outhash, unihash)
84 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') 132 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
133 return taskhash, outhash, unihash
134
135 def run_hashclient(self, args, **kwargs):
136 try:
137 p = subprocess.run(
138 [BIN_DIR / "bitbake-hashclient"] + args,
139 stdout=subprocess.PIPE,
140 stderr=subprocess.STDOUT,
141 encoding="utf-8",
142 **kwargs
143 )
144 except subprocess.CalledProcessError as e:
145 print(e.output)
146 raise e
147
148 print(p.stdout)
149 return p
150
151
152class HashEquivalenceCommonTests(object):
153 def auth_perms(self, *permissions):
154 self.client_index += 1
155 user = self.create_user(f"user-{self.client_index}", permissions)
156 return self.auth_client(user)
157
158 def create_user(self, username, permissions, *, client=None):
159 def remove_user(username):
160 try:
161 self.admin_client.delete_user(username)
162 except InvokeError:
163 pass
164
165 if client is None:
166 client = self.admin_client
167
168 user = client.new_user(username, permissions)
169 self.addCleanup(remove_user, username)
170
171 return user
172
173 def test_create_hash(self):
174 return self.create_test_hash(self.client)
85 175
86 def test_create_equivalent(self): 176 def test_create_equivalent(self):
87 # Tests that a second reported task with the same outhash will be 177 # Tests that a second reported task with the same outhash will be
@@ -123,6 +213,57 @@ class HashEquivalenceCommonTests(object):
123 213
124 self.assertClientGetHash(self.client, taskhash, unihash) 214 self.assertClientGetHash(self.client, taskhash, unihash)
125 215
216 def test_remove_taskhash(self):
217 taskhash, outhash, unihash = self.create_test_hash(self.client)
218 result = self.client.remove({"taskhash": taskhash})
219 self.assertGreater(result["count"], 0)
220 self.assertClientGetHash(self.client, taskhash, None)
221
222 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
223 self.assertIsNone(result_outhash)
224
225 def test_remove_unihash(self):
226 taskhash, outhash, unihash = self.create_test_hash(self.client)
227 result = self.client.remove({"unihash": unihash})
228 self.assertGreater(result["count"], 0)
229 self.assertClientGetHash(self.client, taskhash, None)
230
231 def test_remove_outhash(self):
232 taskhash, outhash, unihash = self.create_test_hash(self.client)
233 result = self.client.remove({"outhash": outhash})
234 self.assertGreater(result["count"], 0)
235
236 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
237 self.assertIsNone(result_outhash)
238
239 def test_remove_method(self):
240 taskhash, outhash, unihash = self.create_test_hash(self.client)
241 result = self.client.remove({"method": self.METHOD})
242 self.assertGreater(result["count"], 0)
243 self.assertClientGetHash(self.client, taskhash, None)
244
245 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
246 self.assertIsNone(result_outhash)
247
248 def test_clean_unused(self):
249 taskhash, outhash, unihash = self.create_test_hash(self.client)
250
251 # Clean the database, which should not remove anything because all hashes are in use
252 result = self.client.clean_unused(0)
253 self.assertEqual(result["count"], 0)
254 self.assertClientGetHash(self.client, taskhash, unihash)
255
256 # Remove the unihash. The row in the outhash table should still be present
257 self.client.remove({"unihash": unihash})
258 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
259 self.assertIsNotNone(result_outhash)
260
261 # Now clean with no minimum age which will remove the outhash
262 result = self.client.clean_unused(0)
263 self.assertEqual(result["count"], 1)
264 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
265 self.assertIsNone(result_outhash)
266
126 def test_huge_message(self): 267 def test_huge_message(self):
127 # Simple test that hashes can be created 268 # Simple test that hashes can be created
128 taskhash = 'c665584ee6817aa99edfc77a44dd853828279370' 269 taskhash = 'c665584ee6817aa99edfc77a44dd853828279370'
@@ -138,16 +279,21 @@ class HashEquivalenceCommonTests(object):
138 }) 279 })
139 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') 280 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
140 281
141 result = self.client.get_taskhash(self.METHOD, taskhash, True) 282 result_unihash = self.client.get_taskhash(self.METHOD, taskhash, True)
142 self.assertEqual(result['taskhash'], taskhash) 283 self.assertEqual(result_unihash['taskhash'], taskhash)
143 self.assertEqual(result['unihash'], unihash) 284 self.assertEqual(result_unihash['unihash'], unihash)
144 self.assertEqual(result['method'], self.METHOD) 285 self.assertEqual(result_unihash['method'], self.METHOD)
145 self.assertEqual(result['outhash'], outhash) 286
146 self.assertEqual(result['outhash_siginfo'], siginfo) 287 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
288 self.assertEqual(result_outhash['taskhash'], taskhash)
289 self.assertEqual(result_outhash['method'], self.METHOD)
290 self.assertEqual(result_outhash['unihash'], unihash)
291 self.assertEqual(result_outhash['outhash'], outhash)
292 self.assertEqual(result_outhash['outhash_siginfo'], siginfo)
147 293
148 def test_stress(self): 294 def test_stress(self):
149 def query_server(failures): 295 def query_server(failures):
150 client = Client(self.server.address) 296 client = Client(self.server_address)
151 try: 297 try:
152 for i in range(1000): 298 for i in range(1000):
153 taskhash = hashlib.sha256() 299 taskhash = hashlib.sha256()
@@ -186,8 +332,10 @@ class HashEquivalenceCommonTests(object):
186 # the side client. It also verifies that the results are pulled into 332 # the side client. It also verifies that the results are pulled into
187 # the downstream database by checking that the downstream and side servers 333 # the downstream database by checking that the downstream and side servers
188 # match after the downstream is done waiting for all backfill tasks 334 # match after the downstream is done waiting for all backfill tasks
189 (down_client, down_server) = self.start_server(upstream=self.server.address) 335 down_server = self.start_server(upstream=self.server_address)
190 (side_client, side_server) = self.start_server(dbpath=down_server.dbpath) 336 down_client = self.start_client(down_server.address)
337 side_server = self.start_server(dbpath=down_server.dbpath)
338 side_client = self.start_client(side_server.address)
191 339
192 def check_hash(taskhash, unihash, old_sidehash): 340 def check_hash(taskhash, unihash, old_sidehash):
193 nonlocal down_client 341 nonlocal down_client
@@ -258,15 +406,57 @@ class HashEquivalenceCommonTests(object):
258 result = down_client.report_unihash(taskhash6, self.METHOD, outhash5, unihash6) 406 result = down_client.report_unihash(taskhash6, self.METHOD, outhash5, unihash6)
259 self.assertEqual(result['unihash'], unihash5, 'Server failed to copy unihash from upstream') 407 self.assertEqual(result['unihash'], unihash5, 'Server failed to copy unihash from upstream')
260 408
409 # Tests read-through from the upstream server
410 taskhash7 = '9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74'
411 outhash7 = '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69'
412 unihash7 = '05d2a63c81e32f0a36542ca677e8ad852365c538'
413 self.client.report_unihash(taskhash7, self.METHOD, outhash7, unihash7)
414
415 result = down_client.get_taskhash(self.METHOD, taskhash7, True)
416 self.assertEqual(result['unihash'], unihash7, 'Server failed to copy unihash from upstream')
417 self.assertEqual(result['outhash'], outhash7, 'Server failed to copy unihash from upstream')
418 self.assertEqual(result['taskhash'], taskhash7, 'Server failed to copy unihash from upstream')
419 self.assertEqual(result['method'], self.METHOD)
420
421 taskhash8 = '86978a4c8c71b9b487330b0152aade10c1ee58aa'
422 outhash8 = 'ca8c128e9d9e4a28ef24d0508aa20b5cf880604eacd8f65c0e366f7e0cc5fbcf'
423 unihash8 = 'd8bcf25369d40590ad7d08c84d538982f2023e01'
424 self.client.report_unihash(taskhash8, self.METHOD, outhash8, unihash8)
425
426 result = down_client.get_outhash(self.METHOD, outhash8, taskhash8)
427 self.assertEqual(result['unihash'], unihash8, 'Server failed to copy unihash from upstream')
428 self.assertEqual(result['outhash'], outhash8, 'Server failed to copy unihash from upstream')
429 self.assertEqual(result['taskhash'], taskhash8, 'Server failed to copy unihash from upstream')
430 self.assertEqual(result['method'], self.METHOD)
431
432 taskhash9 = 'ae6339531895ddf5b67e663e6a374ad8ec71d81c'
433 outhash9 = 'afc78172c81880ae10a1fec994b5b4ee33d196a001a1b66212a15ebe573e00b5'
434 unihash9 = '6662e699d6e3d894b24408ff9a4031ef9b038ee8'
435 self.client.report_unihash(taskhash9, self.METHOD, outhash9, unihash9)
436
437 result = down_client.get_taskhash(self.METHOD, taskhash9, False)
438 self.assertEqual(result['unihash'], unihash9, 'Server failed to copy unihash from upstream')
439 self.assertEqual(result['taskhash'], taskhash9, 'Server failed to copy unihash from upstream')
440 self.assertEqual(result['method'], self.METHOD)
441
442 def test_unihash_exists(self):
443 taskhash, outhash, unihash = self.create_test_hash(self.client)
444 self.assertTrue(self.client.unihash_exists(unihash))
445 self.assertFalse(self.client.unihash_exists('6662e699d6e3d894b24408ff9a4031ef9b038ee8'))
446
261 def test_ro_server(self): 447 def test_ro_server(self):
262 (ro_client, ro_server) = self.start_server(dbpath=self.server.dbpath, read_only=True) 448 rw_server = self.start_server()
449 rw_client = self.start_client(rw_server.address)
450
451 ro_server = self.start_server(dbpath=rw_server.dbpath, read_only=True)
452 ro_client = self.start_client(ro_server.address)
263 453
264 # Report a hash via the read-write server 454 # Report a hash via the read-write server
265 taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' 455 taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
266 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' 456 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
267 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' 457 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
268 458
269 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) 459 result = rw_client.report_unihash(taskhash, self.METHOD, outhash, unihash)
270 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') 460 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
271 461
272 # Check the hash via the read-only server 462 # Check the hash via the read-only server
@@ -277,11 +467,976 @@ class HashEquivalenceCommonTests(object):
277 outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' 467 outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
278 unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824' 468 unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824'
279 469
280 with self.assertRaises(HashConnectionError): 470 result = ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
281 ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) 471 self.assertEqual(result['unihash'], unihash2)
282 472
283 # Ensure that the database was not modified 473 # Ensure that the database was not modified
474 self.assertClientGetHash(rw_client, taskhash2, None)
475
476
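The rewritten test_ro_server pins down a behavior change: a read-only server now answers report_unihash calls normally but discards the write, instead of raising HashConnectionError as before. A minimal sketch of that contract, assuming a read-write and a read-only server are already running against the same database and connecting with hashserv.create_client() (the socket paths are placeholders):

    import hashserv

    METHOD = "TestMethod"  # placeholder method name
    rw = hashserv.create_client("unix:///tmp/hashserv-rw.sock")  # assumed address
    ro = hashserv.create_client("unix:///tmp/hashserv-ro.sock")  # assumed address

    # The read-only server answers the report and echoes a unihash back...
    ro.report_unihash("aa" * 20, METHOD, "bb" * 32, "cc" * 20)

    # ...but nothing was persisted, so the read-write view has no entry for it.
    assert rw.get_unihash(METHOD, "aa" * 20) is None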
477 def test_slow_server_start(self):
478 # Ensures that the server will exit correctly even if it gets a SIGTERM
479 # before entering the main loop
480
481 event = multiprocessing.Event()
482
483 def prefunc(server, idx):
484 nonlocal event
485 server_prefunc(server, idx)
486 event.wait()
487
488 def do_nothing(signum, frame):
489 pass
490
491 old_signal = signal.signal(signal.SIGTERM, do_nothing)
492 self.addCleanup(signal.signal, signal.SIGTERM, old_signal)
493
494 server = self.start_server(prefunc=prefunc)
495 server.process.terminate()
496 time.sleep(30)
497 event.set()
498 server.process.join(300)
499 self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!")
500
501 def test_diverging_report_race(self):
502 # Tests that a reported task will correctly pick up an updated unihash
503
504 # This is a baseline report added to the database to ensure that there
505 # is something to match against as equivalent
506 outhash1 = 'afd11c366050bcd75ad763e898e4430e2a60659b26f83fbb22201a60672019fa'
507 taskhash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
508 unihash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
509 result = self.client.report_unihash(taskhash1, self.METHOD, outhash1, unihash1)
510
511 # Add a report that is equivalent to Task 1. It should ignore the
512 # provided unihash and report the unihash from task 1
513 taskhash2 = '6259ae8263bd94d454c086f501c37e64c4e83cae806902ca95b4ab513546b273'
514 unihash2 = taskhash2
515 result = self.client.report_unihash(taskhash2, self.METHOD, outhash1, unihash2)
516 self.assertEqual(result['unihash'], unihash1)
517
518 # Add another report for Task 2, but with a different outhash (e.g. the
519 # task is non-deterministic). It should still be marked with the Task 1
520 # unihash because it has the Task 2 taskhash, which is equivalent to
521 # Task 1
522 outhash3 = 'd2187ee3a8966db10b34fe0e863482288d9a6185cb8ef58a6c1c6ace87a2f24c'
523 result = self.client.report_unihash(taskhash2, self.METHOD, outhash3, unihash2)
524 self.assertEqual(result['unihash'], unihash1)
525
526
527 def test_diverging_report_reverse_race(self):
528 # Same idea as the previous test, but Tasks 2 and 3 are reported in
529 # the opposite order
530
531 outhash1 = 'afd11c366050bcd75ad763e898e4430e2a60659b26f83fbb22201a60672019fa'
532 taskhash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
533 unihash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
534 result = self.client.report_unihash(taskhash1, self.METHOD, outhash1, unihash1)
535
536 taskhash2 = '6259ae8263bd94d454c086f501c37e64c4e83cae806902ca95b4ab513546b273'
537 unihash2 = taskhash2
538
539 # Report Task 3 first. Since there is nothing else in the database it
540 # will use the client-provided unihash
541 outhash3 = 'd2187ee3a8966db10b34fe0e863482288d9a6185cb8ef58a6c1c6ace87a2f24c'
542 result = self.client.report_unihash(taskhash2, self.METHOD, outhash3, unihash2)
543 self.assertEqual(result['unihash'], unihash2)
544
545 # Report Task 2. This is equivalent to Task 1 but there is already a mapping for
546 # taskhash2 so it will report unihash2
547 result = self.client.report_unihash(taskhash2, self.METHOD, outhash1, unihash2)
548 self.assertEqual(result['unihash'], unihash2)
549
550 # The originally reported unihash for Task 3 should be unchanged even if it
551 # shares a taskhash with Task 2
552 self.assertClientGetHash(self.client, taskhash2, unihash2)
553
554 def test_get_unihash_batch(self):
555 TEST_INPUT = (
556 # taskhash outhash unihash
557 ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'),
558 # Duplicated taskhash with multiple output hashes and unihashes.
559 ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'),
560 # Equivalent hash
561 ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"),
562 ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"),
563 ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'),
564 ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'),
565 ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'),
566 )
567 EXTRA_QUERIES = (
568 "6b6be7a84ab179b4240c4302518dc3f6",
569 )
570
571 for taskhash, outhash, unihash in TEST_INPUT:
572 self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
573
574
575 result = self.client.get_unihash_batch(
576 [(self.METHOD, data[0]) for data in TEST_INPUT] +
577 [(self.METHOD, e) for e in EXTRA_QUERIES]
578 )
579
580 self.assertListEqual(result, [
581 "218e57509998197d570e2c98512d0105985dffc9",
582 "218e57509998197d570e2c98512d0105985dffc9",
583 "218e57509998197d570e2c98512d0105985dffc9",
584 "3b5d3d83f07f259e9086fcb422c855286e18a57d",
585 "f46d3fbb439bd9b921095da657a4de906510d2cd",
586 "f46d3fbb439bd9b921095da657a4de906510d2cd",
587 "05d2a63c81e32f0a36542ca677e8ad852365c538",
588 None,
589 ])
590
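get_unihash_batch takes an ordered list of (method, taskhash) pairs and returns the unihashes in the same order, with None for misses, which is why the lone EXTRA_QUERIES entry maps to the trailing None above. A short usage sketch, reusing the client and METHOD placeholders from the earlier sketch:

    queries = [
        (METHOD, "8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a"),
        (METHOD, "6b6be7a84ab179b4240c4302518dc3f6"),  # unknown taskhash
    ]
    unihashes = client.get_unihash_batch(queries)
    # Results line up positionally with the queries; misses come back as None.
    assert unihashes[1] is None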
591 def test_unihash_exists_batch(self):
592 TEST_INPUT = (
593 # taskhash outhash unihash
594 ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'),
595 # Duplicated taskhash with multiple output hashes and unihashes.
596 ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'),
597 # Equivalent hash
598 ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"),
599 ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"),
600 ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'),
601 ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'),
602 ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'),
603 )
604 EXTRA_QUERIES = (
605 "6b6be7a84ab179b4240c4302518dc3f6",
606 )
607
608 result_unihashes = set()
609
610
611 for taskhash, outhash, unihash in TEST_INPUT:
612 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
613 result_unihashes.add(result["unihash"])
614
615 query = []
616 expected = []
617
618 for _, _, unihash in TEST_INPUT:
619 query.append(unihash)
620 expected.append(unihash in result_unihashes)
621
622
623 for unihash in EXTRA_QUERIES:
624 query.append(unihash)
625 expected.append(False)
626
627 result = self.client.unihash_exists_batch(query)
628 self.assertListEqual(result, expected)
629
630 def test_auth_read_perms(self):
631 admin_client = self.start_auth_server()
632
633 # Create hashes with non-authenticated server
634 taskhash, outhash, unihash = self.create_test_hash(self.client)
635
636 # Validate hash can be retrieved using authenticated client
637 with self.auth_perms("@read") as client:
638 self.assertClientGetHash(client, taskhash, unihash)
639
640 with self.auth_perms() as client, self.assertRaises(InvokeError):
641 self.assertClientGetHash(client, taskhash, unihash)
642
643 def test_auth_report_perms(self):
644 admin_client = self.start_auth_server()
645
646 # Without read permission, the user is completely denied
647 with self.auth_perms() as client, self.assertRaises(InvokeError):
648 self.create_test_hash(client)
649
650 # Read permission allows the call to succeed, but it doesn't record
651 # anything in the database
652 with self.auth_perms("@read") as client:
653 taskhash, outhash, unihash = self.create_test_hash(client)
654 self.assertClientGetHash(client, taskhash, None)
655
656 # Report permission alone is insufficient
657 with self.auth_perms("@report") as client, self.assertRaises(InvokeError):
658 self.create_test_hash(client)
659
660 # Read and report permission actually modify the database
661 with self.auth_perms("@read", "@report") as client:
662 taskhash, outhash, unihash = self.create_test_hash(client)
663 self.assertClientGetHash(client, taskhash, unihash)
664
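The pattern these auth tests establish is that @read gates every call, including the lookup a report performs internally, while the write only persists once @report is also granted. Connecting with credentials, assuming create_client() accepts the username/password keywords the test harness passes through (account values are hypothetical):

    import hashserv

    client = hashserv.create_client(
        "ws://hashserv.example.com:8686",  # assumed address
        username="builder",                # hypothetical account
        password="token-from-admin",       # hypothetical token
    )
    client.report_unihash("aa" * 20, "TestMethod", "bb" * 32, "cc" * 20)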
665 def test_auth_no_token_refresh_from_anon_user(self):
666 self.start_auth_server()
667
668 with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
669 client.refresh_token()
670
671 def test_auth_self_token_refresh(self):
672 admin_client = self.start_auth_server()
673
674 # Create a new user with no permissions
675 user = self.create_user("test-user", [])
676
677 with self.auth_client(user) as client:
678 new_user = client.refresh_token()
679
680 self.assertEqual(user["username"], new_user["username"])
681 self.assertNotEqual(user["token"], new_user["token"])
682 self.assertUserCanAuth(new_user)
683 self.assertUserCannotAuth(user)
684
685 # Explicitly specifying with your own username is fine also
686 with self.auth_client(new_user) as client:
687 new_user2 = client.refresh_token(user["username"])
688
689 self.assertEqual(user["username"], new_user2["username"])
690 self.assertNotEqual(user["token"], new_user2["token"])
691 self.assertUserCanAuth(new_user2)
692 self.assertUserCannotAuth(new_user)
693 self.assertUserCannotAuth(user)
694
695 def test_auth_token_refresh(self):
696 admin_client = self.start_auth_server()
697
698 user = self.create_user("test-user", [])
699
700 with self.auth_perms() as client, self.assertRaises(InvokeError):
701 client.refresh_token(user["username"])
702
703 with self.auth_perms("@user-admin") as client:
704 new_user = client.refresh_token(user["username"])
705
706 self.assertEqual(user["username"], new_user["username"])
707 self.assertNotEqual(user["token"], new_user["token"])
708 self.assertUserCanAuth(new_user)
709 self.assertUserCannotAuth(user)
710
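refresh_token rotates a credential in place: the reply carries a fresh token for the same username, the previous token stops authenticating immediately, and rotating someone else's token requires @user-admin. Condensed, with client and admin standing for authenticated connections made as in the sketch above:

    creds = client.refresh_token()            # rotate your own token
    rotated = admin.refresh_token("builder")  # another user's; needs @user-admin
    # Only the most recently issued token authenticates from here on.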
711 def test_auth_self_get_user(self):
712 admin_client = self.start_auth_server()
713
714 user = self.create_user("test-user", [])
715 user_info = user.copy()
716 del user_info["token"]
717
718 with self.auth_client(user) as client:
719 info = client.get_user()
720 self.assertEqual(info, user_info)
721
722 # Explicitly asking for your own username is fine also
723 info = client.get_user(user["username"])
724 self.assertEqual(info, user_info)
725
726 def test_auth_get_user(self):
727 admin_client = self.start_auth_server()
728
729 user = self.create_user("test-user", [])
730 user_info = user.copy()
731 del user_info["token"]
732
733 with self.auth_perms() as client, self.assertRaises(InvokeError):
734 client.get_user(user["username"])
735
736 with self.auth_perms("@user-admin") as client:
737 info = client.get_user(user["username"])
738 self.assertEqual(info, user_info)
739
740 info = client.get_user("nonexist-user")
741 self.assertIsNone(info)
742
743 def test_auth_reconnect(self):
744 admin_client = self.start_auth_server()
745
746 user = self.create_user("test-user", [])
747 user_info = user.copy()
748 del user_info["token"]
749
750 with self.auth_client(user) as client:
751 info = client.get_user()
752 self.assertEqual(info, user_info)
753
754 client.disconnect()
755
756 info = client.get_user()
757 self.assertEqual(info, user_info)
758
759 def test_auth_delete_user(self):
760 admin_client = self.start_auth_server()
761
762 user = self.create_user("test-user", [])
763
764 # self service
765 with self.auth_client(user) as client:
766 client.delete_user(user["username"])
767
768 self.assertIsNone(admin_client.get_user(user["username"]))
769 user = self.create_user("test-user", [])
770
771 with self.auth_perms() as client, self.assertRaises(InvokeError):
772 client.delete_user(user["username"])
773
774 with self.auth_perms("@user-admin") as client:
775 client.delete_user(user["username"])
776
777 # User doesn't exist, so even though the permission is correct, it's an
778 # error
779 with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
780 client.delete_user(user["username"])
781
782 def test_auth_set_user_perms(self):
783 admin_client = self.start_auth_server()
784
785 user = self.create_user("test-user", [])
786
787 self.assertUserPerms(user, [])
788
789 # No self service to change permissions
790 with self.auth_client(user) as client, self.assertRaises(InvokeError):
791 client.set_user_perms(user["username"], ["@all"])
792 self.assertUserPerms(user, [])
793
794 with self.auth_perms() as client, self.assertRaises(InvokeError):
795 client.set_user_perms(user["username"], ["@all"])
796 self.assertUserPerms(user, [])
797
798 with self.auth_perms("@user-admin") as client:
799 client.set_user_perms(user["username"], ["@all"])
800 self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
801
802 # Bad permissions
803 with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
804 client.set_user_perms(user["username"], ["@this-is-not-a-permission"])
805 self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
806
807 def test_auth_get_all_users(self):
808 admin_client = self.start_auth_server()
809
810 user = self.create_user("test-user", [])
811
812 with self.auth_client(user) as client, self.assertRaises(InvokeError):
813 client.get_all_users()
814
815 # Give the test user the correct permission
816 admin_client.set_user_perms(user["username"], ["@user-admin"])
817
818 with self.auth_client(user) as client:
819 all_users = client.get_all_users()
820
821 # Convert to a dictionary for easier comparison
822 all_users = {u["username"]: u for u in all_users}
823
824 self.assertEqual(all_users,
825 {
826 "admin": {
827 "username": "admin",
828 "permissions": sorted(list(ALL_PERMISSIONS)),
829 },
830 "test-user": {
831 "username": "test-user",
832 "permissions": ["@user-admin"],
833 }
834 }
835 )
836
837 def test_auth_new_user(self):
838 self.start_auth_server()
839
840 permissions = ["@read", "@report", "@db-admin", "@user-admin"]
841 permissions.sort()
842
843 with self.auth_perms() as client, self.assertRaises(InvokeError):
844 self.create_user("test-user", permissions, client=client)
845
846 with self.auth_perms("@user-admin") as client:
847 user = self.create_user("test-user", permissions, client=client)
848 self.assertIn("token", user)
849 self.assertEqual(user["username"], "test-user")
850 self.assertEqual(user["permissions"], permissions)
851
852 def test_auth_become_user(self):
853 admin_client = self.start_auth_server()
854
855 user = self.create_user("test-user", ["@read", "@report"])
856 user_info = user.copy()
857 del user_info["token"]
858
859 with self.auth_perms() as client, self.assertRaises(InvokeError):
860 client.become_user(user["username"])
861
862 with self.auth_perms("@user-admin") as client:
863 become = client.become_user(user["username"])
864 self.assertEqual(become, user_info)
865
866 info = client.get_user()
867 self.assertEqual(info, user_info)
868
869 # Verify become user is preserved across disconnect
870 client.disconnect()
871
872 info = client.get_user()
873 self.assertEqual(info, user_info)
874
875 # test-user doesn't have become_user permissions, so this should
876 # not work
877 with self.assertRaises(InvokeError):
878 client.become_user(user["username"])
879
880 # No self-service of become
881 with self.auth_client(user) as client, self.assertRaises(InvokeError):
882 client.become_user(user["username"])
883
884 # Give test user permissions to become
885 admin_client.set_user_perms(user["username"], ["@user-admin"])
886
887 # It's possible to become yourself (effectively a noop)
888 with self.auth_perms("@user-admin") as client:
889 become = client.become_user(client.username)
890
891 def test_auth_gc(self):
892 admin_client = self.start_auth_server()
893
894 with self.auth_perms() as client, self.assertRaises(InvokeError):
895 client.gc_mark("ABC", {"unihash": "123"})
896
897 with self.auth_perms() as client, self.assertRaises(InvokeError):
898 client.gc_status()
899
900 with self.auth_perms() as client, self.assertRaises(InvokeError):
901 client.gc_sweep("ABC")
902
903 with self.auth_perms("@db-admin") as client:
904 client.gc_mark("ABC", {"unihash": "123"})
905
906 with self.auth_perms("@db-admin") as client:
907 client.gc_status()
908
909 with self.auth_perms("@db-admin") as client:
910 client.gc_sweep("ABC")
911
912 def test_get_db_usage(self):
913 usage = self.client.get_db_usage()
914
915 self.assertTrue(isinstance(usage, dict))
916 for name in usage.keys():
917 self.assertTrue(isinstance(usage[name], dict))
918 self.assertIn("rows", usage[name])
919 self.assertTrue(isinstance(usage[name]["rows"], int))
920
921 def test_get_db_query_columns(self):
922 columns = self.client.get_db_query_columns()
923
924 self.assertTrue(isinstance(columns, list))
925 self.assertTrue(len(columns) > 0)
926
927 for col in columns:
928 self.client.remove({col: ""})
929
930 def test_auth_is_owner(self):
931 admin_client = self.start_auth_server()
932
933 user = self.create_user("test-user", ["@read", "@report"])
934 with self.auth_client(user) as client:
935 taskhash, outhash, unihash = self.create_test_hash(client)
936 data = client.get_taskhash(self.METHOD, taskhash, True)
937 self.assertEqual(data["owner"], user["username"])
938
939 def test_gc(self):
940 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
941 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
942 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
943
944 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
945 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
946
947 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
948 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
949 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
950
951 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
952 self.assertClientGetHash(self.client, taskhash2, unihash2)
953
954 # Mark the first unihash to be kept
955 ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
956 self.assertEqual(ret, {"count": 1})
957
958 ret = self.client.gc_status()
959 self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1})
960
961 # Second hash is still there; mark doesn't delete hashes
962 self.assertClientGetHash(self.client, taskhash2, unihash2)
963
964 ret = self.client.gc_sweep("ABC")
965 self.assertEqual(ret, {"count": 1})
966
967 # Second hash is gone; its taskhash no longer resolves to a unihash
968 self.assertClientGetHash(self.client, taskhash2, None)
969 # First hash is still present
970 self.assertClientGetHash(self.client, taskhash, unihash)
971
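The GC interface is a two-phase mark and sweep: gc_mark tags the rows matching a column filter with a mark name, gc_status reports how many rows are kept versus slated for removal under the current mark, and gc_sweep deletes everything not carrying that mark. The cycle from the test, condensed (keep_unihash is a placeholder for a hash to retain; client and METHOD as in the earlier sketches):

    client.gc_mark("2024-06", {"unihash": keep_unihash, "method": METHOD})
    status = client.gc_status()  # e.g. {"mark": "2024-06", "keep": 1, "remove": 1}
    client.gc_sweep("2024-06")   # drops every row not marked "2024-06"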
972 def test_gc_stream(self):
973 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
974 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
975 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
976
977 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
978 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
979
980 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
981 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
982 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
983
984 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
985 self.assertClientGetHash(self.client, taskhash2, unihash2)
986
987 taskhash3 = 'a1117c1f5a7c9ab2f5a39cc6fe5e6152169d09c0'
988 outhash3 = '7289c414905303700a1117c1f5a7c9ab2f5a39cc6fe5e6152169d09c04f9a53c'
989 unihash3 = '905303700a1117c1f5a7c9ab2f5a39cc6fe5e615'
990
991 result = self.client.report_unihash(taskhash3, self.METHOD, outhash3, unihash3)
992 self.assertClientGetHash(self.client, taskhash3, unihash3)
993
994 # Mark the first unihash to be kept
995 ret = self.client.gc_mark_stream("ABC", (f"unihash {h}" for h in [unihash, unihash2]))
996 self.assertEqual(ret, {"count": 2})
997
998 ret = self.client.gc_status()
999 self.assertEqual(ret, {"mark": "ABC", "keep": 2, "remove": 1})
1000
1001 # Third hash is still there; mark doesn't delete hashes
1002 self.assertClientGetHash(self.client, taskhash3, unihash3)
1003
1004 ret = self.client.gc_sweep("ABC")
1005 self.assertEqual(ret, {"count": 1})
1006
1007 # Third hash is gone; its taskhash no longer resolves to a unihash
1008 self.assertClientGetHash(self.client, taskhash3, None)
1009 # First hash is still present
1010 self.assertClientGetHash(self.client, taskhash, unihash)
1011 # Second hash is still present
1012 self.assertClientGetHash(self.client, taskhash2, unihash2)
1013
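gc_mark_stream is the bulk form of gc_mark: rather than one filter dict per call it consumes an iterable of text lines, each "<column> <value>" ("unihash <hash>" in the test above), so a large keep-set is marked in a single streamed request. A sketch, reusing the client placeholder:

    kept = ["f37918cc02eb5a520b1aff86faacbc0a38124646",
            "af36b199320e611fbb16f1f277d3ee1d619ca58b"]
    client.gc_mark_stream("ABC", (f"unihash {h}" for h in kept))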
1014 def test_gc_switch_mark(self):
1015 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
1016 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
1017 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
1018
1019 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
1020 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
1021
1022 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
1023 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
1024 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
1025
1026 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
1027 self.assertClientGetHash(self.client, taskhash2, unihash2)
1028
1029 # Mark the first unihash to be kept
1030 ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
1031 self.assertEqual(ret, {"count": 1})
1032
1033 ret = self.client.gc_status()
1034 self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1})
1035
1036 # Second hash is still there; mark doesn't delete hashes
1037 self.assertClientGetHash(self.client, taskhash2, unihash2)
1038
1039 # Switch to a different mark and mark the second hash. This will start
1040 # a new collection cycle
1041 ret = self.client.gc_mark("DEF", {"unihash": unihash2, "method": self.METHOD})
1042 self.assertEqual(ret, {"count": 1})
1043
1044 ret = self.client.gc_status()
1045 self.assertEqual(ret, {"mark": "DEF", "keep": 1, "remove": 1})
1046
1047 # Both hashes are still present
1048 self.assertClientGetHash(self.client, taskhash2, unihash2)
1049 self.assertClientGetHash(self.client, taskhash, unihash)
1050
1051 # Sweep with the new mark
1052 ret = self.client.gc_sweep("DEF")
1053 self.assertEqual(ret, {"count": 1})
1054
1055 # First hash is gone, second is kept
1056 self.assertClientGetHash(self.client, taskhash2, unihash2)
1057 self.assertClientGetHash(self.client, taskhash, None)
1058
1059 def test_gc_switch_sweep_mark(self):
1060 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
1061 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
1062 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
1063
1064 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
1065 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
1066
1067 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
1068 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
1069 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
1070
1071 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
1072 self.assertClientGetHash(self.client, taskhash2, unihash2)
1073
1074 # Mark the first unihash to be kept
1075 ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
1076 self.assertEqual(ret, {"count": 1})
1077
1078 ret = self.client.gc_status()
1079 self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1})
1080
1081 # Sweeping with a different mark raises an error
1082 with self.assertRaises(InvokeError):
1083 self.client.gc_sweep("DEF")
1084
1085 # Both hashes are present
1086 self.assertClientGetHash(self.client, taskhash2, unihash2)
1087 self.assertClientGetHash(self.client, taskhash, unihash)
1088
1089 def test_gc_new_hashes(self):
1090 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
1091 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
1092 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
1093
1094 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
1095 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
1096
1097 # Start a new garbage collection
1098 ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
1099 self.assertEqual(ret, {"count": 1})
1100
1101 ret = self.client.gc_status()
1102 self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 0})
1103
1104 # Add second hash. It should inherit the mark from the current garbage
1105 # collection operation
1106
1107 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
1108 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
1109 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
1110
1111 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
1112 self.assertClientGetHash(self.client, taskhash2, unihash2)
1113
1114 # Sweep should remove nothing
1115 ret = self.client.gc_sweep("ABC")
1116 self.assertEqual(ret, {"count": 0})
1117
1118 # Both hashes are present
1119 self.assertClientGetHash(self.client, taskhash2, unihash2)
1120 self.assertClientGetHash(self.client, taskhash, unihash)
1121
1122
1123class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
1124 def get_server_addr(self, server_idx):
1125 return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
1126
1127 def test_get(self):
1128 taskhash, outhash, unihash = self.create_test_hash(self.client)
1129
1130 p = self.run_hashclient(["--address", self.server_address, "get", self.METHOD, taskhash])
1131 data = json.loads(p.stdout)
1132 self.assertEqual(data["unihash"], unihash)
1133 self.assertEqual(data["outhash"], outhash)
1134 self.assertEqual(data["taskhash"], taskhash)
1135 self.assertEqual(data["method"], self.METHOD)
1136
1137 def test_get_outhash(self):
1138 taskhash, outhash, unihash = self.create_test_hash(self.client)
1139
1140 p = self.run_hashclient(["--address", self.server_address, "get-outhash", self.METHOD, outhash, taskhash])
1141 data = json.loads(p.stdout)
1142 self.assertEqual(data["unihash"], unihash)
1143 self.assertEqual(data["outhash"], outhash)
1144 self.assertEqual(data["taskhash"], taskhash)
1145 self.assertEqual(data["method"], self.METHOD)
1146
1147 def test_stats(self):
1148 p = self.run_hashclient(["--address", self.server_address, "stats"], check=True)
1149 json.loads(p.stdout)
1150
1151 def test_stress(self):
1152 self.run_hashclient(["--address", self.server_address, "stress"], check=True)
1153
1154 def test_unihash_exists(self):
1155 taskhash, outhash, unihash = self.create_test_hash(self.client)
1156
1157 p = self.run_hashclient([
1158 "--address", self.server_address,
1159 "unihash-exists", unihash,
1160 ], check=True)
1161 self.assertEqual(p.stdout.strip(), "true")
1162
1163 p = self.run_hashclient([
1164 "--address", self.server_address,
1165 "unihash-exists", '6662e699d6e3d894b24408ff9a4031ef9b038ee8',
1166 ], check=True)
1167 self.assertEqual(p.stdout.strip(), "false")
1168
1169 def test_unihash_exists_quiet(self):
1170 taskhash, outhash, unihash = self.create_test_hash(self.client)
1171
1172 p = self.run_hashclient([
1173 "--address", self.server_address,
1174 "unihash-exists", unihash,
1175 "--quiet",
1176 ])
1177 self.assertEqual(p.returncode, 0)
1178 self.assertEqual(p.stdout.strip(), "")
1179
1180 p = self.run_hashclient([
1181 "--address", self.server_address,
1182 "unihash-exists", '6662e699d6e3d894b24408ff9a4031ef9b038ee8',
1183 "--quiet",
1184 ])
1185 self.assertEqual(p.returncode, 1)
1186 self.assertEqual(p.stdout.strip(), "")
1187
1188 def test_remove_taskhash(self):
1189 taskhash, outhash, unihash = self.create_test_hash(self.client)
1190 self.run_hashclient([
1191 "--address", self.server_address,
1192 "remove",
1193 "--where", "taskhash", taskhash,
1194 ], check=True)
1195 self.assertClientGetHash(self.client, taskhash, None)
1196
1197 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
1198 self.assertIsNone(result_outhash)
1199
1200 def test_remove_unihash(self):
1201 taskhash, outhash, unihash = self.create_test_hash(self.client)
1202 self.run_hashclient([
1203 "--address", self.server_address,
1204 "remove",
1205 "--where", "unihash", unihash,
1206 ], check=True)
1207 self.assertClientGetHash(self.client, taskhash, None)
1208
1209 def test_remove_outhash(self):
1210 taskhash, outhash, unihash = self.create_test_hash(self.client)
1211 self.run_hashclient([
1212 "--address", self.server_address,
1213 "remove",
1214 "--where", "outhash", outhash,
1215 ], check=True)
1216
1217 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
1218 self.assertIsNone(result_outhash)
1219
1220 def test_remove_method(self):
1221 taskhash, outhash, unihash = self.create_test_hash(self.client)
1222 self.run_hashclient([
1223 "--address", self.server_address,
1224 "remove",
1225 "--where", "method", self.METHOD,
1226 ], check=True)
1227 self.assertClientGetHash(self.client, taskhash, None)
1228
1229 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
1230 self.assertIsNone(result_outhash)
1231
1232 def test_clean_unused(self):
1233 taskhash, outhash, unihash = self.create_test_hash(self.client)
1234
1235 # Clean the database, which should not remove anything because all hashes are in use
1236 self.run_hashclient([
1237 "--address", self.server_address,
1238 "clean-unused", "0",
1239 ], check=True)
1240 self.assertClientGetHash(self.client, taskhash, unihash)
1241
1242 # Remove the unihash. The row in the outhash table should still be present
1243 self.run_hashclient([
1244 "--address", self.server_address,
1245 "remove",
1246 "--where", "unihash", unihash,
1247 ], check=True)
1248 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
1249 self.assertIsNotNone(result_outhash)
1250
1251 # Now clean with no minimum age which will remove the outhash
1252 self.run_hashclient([
1253 "--address", self.server_address,
1254 "clean-unused", "0",
1255 ], check=True)
1256 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
1257 self.assertIsNone(result_outhash)
1258
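test_clean_unused indirectly documents the split between the unihash and outhash tables: removing a unihash row leaves the matching outhash row behind, and the clean-unused step (its argument is a minimum age in seconds) is what reaps outhash rows that no longer have a referent. The same flow through the Python client, assuming a clean_unused(max_age) method mirroring the CLI subcommand:

    client.remove({"unihash": unihash})  # drops the unihash row only
    client.clean_unused(0)               # reaps orphaned outhash rows;
                                         # 0 means no minimum age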
1259 def test_refresh_token(self):
1260 admin_client = self.start_auth_server()
1261
1262 user = admin_client.new_user("test-user", ["@read", "@report"])
1263
1264 p = self.run_hashclient([
1265 "--address", self.auth_server_address,
1266 "--login", user["username"],
1267 "--password", user["token"],
1268 "refresh-token"
1269 ], check=True)
1270
1271 new_token = None
1272 for l in p.stdout.splitlines():
1273 l = l.rstrip()
1274 m = re.match(r'Token: +(.*)$', l)
1275 if m is not None:
1276 new_token = m.group(1)
1277
1278 self.assertTrue(new_token)
1279
1280 print("New token is %r" % new_token)
1281
1282 self.run_hashclient([
1283 "--address", self.auth_server_address,
1284 "--login", user["username"],
1285 "--password", new_token,
1286 "get-user"
1287 ], check=True)
1288
1289 def test_set_user_perms(self):
1290 admin_client = self.start_auth_server()
1291
1292 user = admin_client.new_user("test-user", ["@read"])
1293
1294 self.run_hashclient([
1295 "--address", self.auth_server_address,
1296 "--login", admin_client.username,
1297 "--password", admin_client.password,
1298 "set-user-perms",
1299 "-u", user["username"],
1300 "@read", "@report",
1301 ], check=True)
1302
1303 new_user = admin_client.get_user(user["username"])
1304
1305 self.assertEqual(set(new_user["permissions"]), {"@read", "@report"})
1306
1307 def test_get_user(self):
1308 admin_client = self.start_auth_server()
1309
1310 user = admin_client.new_user("test-user", ["@read"])
1311
1312 p = self.run_hashclient([
1313 "--address", self.auth_server_address,
1314 "--login", admin_client.username,
1315 "--password", admin_client.password,
1316 "get-user",
1317 "-u", user["username"],
1318 ], check=True)
1319
1320 self.assertIn("Username:", p.stdout)
1321 self.assertIn("Permissions:", p.stdout)
1322
1323 p = self.run_hashclient([
1324 "--address", self.auth_server_address,
1325 "--login", user["username"],
1326 "--password", user["token"],
1327 "get-user",
1328 ], check=True)
1329
1330 self.assertIn("Username:", p.stdout)
1331 self.assertIn("Permissions:", p.stdout)
1332
1333 def test_get_all_users(self):
1334 admin_client = self.start_auth_server()
1335
1336 admin_client.new_user("test-user1", ["@read"])
1337 admin_client.new_user("test-user2", ["@read"])
1338
1339 p = self.run_hashclient([
1340 "--address", self.auth_server_address,
1341 "--login", admin_client.username,
1342 "--password", admin_client.password,
1343 "get-all-users",
1344 ], check=True)
1345
1346 self.assertIn("admin", p.stdout)
1347 self.assertIn("test-user1", p.stdout)
1348 self.assertIn("test-user2", p.stdout)
1349
1350 def test_new_user(self):
1351 admin_client = self.start_auth_server()
1352
1353 p = self.run_hashclient([
1354 "--address", self.auth_server_address,
1355 "--login", admin_client.username,
1356 "--password", admin_client.password,
1357 "new-user",
1358 "-u", "test-user",
1359 "@read", "@report",
1360 ], check=True)
1361
1362 new_token = None
1363 for l in p.stdout.splitlines():
1364 l = l.rstrip()
1365 m = re.match(r'Token: +(.*)$', l)
1366 if m is not None:
1367 new_token = m.group(1)
1368
1369 self.assertTrue(new_token)
1370
1371 user = {
1372 "username": "test-user",
1373 "token": new_token,
1374 }
1375
1376 self.assertUserPerms(user, ["@read", "@report"])
1377
1378 def test_delete_user(self):
1379 admin_client = self.start_auth_server()
1380
1381 user = admin_client.new_user("test-user", ["@read"])
1382
1383 p = self.run_hashclient([
1384 "--address", self.auth_server_address,
1385 "--login", admin_client.username,
1386 "--password", admin_client.password,
1387 "delete-user",
1388 "-u", user["username"],
1389 ], check=True)
1390
1391 self.assertIsNone(admin_client.get_user(user["username"]))
1392
1393 def test_get_db_usage(self):
1394 p = self.run_hashclient([
1395 "--address", self.server_address,
1396 "get-db-usage",
1397 ], check=True)
1398
1399 def test_get_db_query_columns(self):
1400 p = self.run_hashclient([
1401 "--address", self.server_address,
1402 "get-db-query-columns",
1403 ], check=True)
1404
1405 def test_gc(self):
1406 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
1407 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
1408 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
1409
1410 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
1411 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
1412
1413 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
1414 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
1415 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
1416
1417 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
1418 self.assertClientGetHash(self.client, taskhash2, unihash2)
1419
1420 # Mark the first unihash to be kept
1421 self.run_hashclient([
1422 "--address", self.server_address,
1423 "gc-mark", "ABC",
1424 "--where", "unihash", unihash,
1425 "--where", "method", self.METHOD
1426 ], check=True)
1427
1428 # Second hash is still there; mark doesn't delete hashes
1429 self.assertClientGetHash(self.client, taskhash2, unihash2)
1430
1431 self.run_hashclient([
1432 "--address", self.server_address,
1433 "gc-sweep", "ABC",
1434 ], check=True)
1435
1436 # Second hash is gone; its taskhash no longer resolves to a unihash
284 self.assertClientGetHash(self.client, taskhash2, None) 1437 self.assertClientGetHash(self.client, taskhash2, None)
1438 # First hash is still present
1439 self.assertClientGetHash(self.client, taskhash, unihash)
285 1440
286 1441
287class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): 1442class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
@@ -314,3 +1469,77 @@ class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm
314 # If IPv6 is enabled, it should be safe to use localhost directly, in general 1469 # If IPv6 is enabled, it should be safe to use localhost directly, in general
315 # case it is more reliable to resolve the IP address explicitly. 1470 # case it is more reliable to resolve the IP address explicitly.
316 return socket.gethostbyname("localhost") + ":0" 1471 return socket.gethostbyname("localhost") + ":0"
1472
1473
1474class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
1475 def setUp(self):
1476 try:
1477 import websockets
1478 except ImportError as e:
1479 self.skipTest(str(e))
1480
1481 super().setUp()
1482
1483 def get_server_addr(self, server_idx):
1484 # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled.
1485 # If IPv6 is enabled, it should be safe to use localhost directly, in general
1486 # case it is more reliable to resolve the IP address explicitly.
1487 host = socket.gethostbyname("localhost")
1488 return "ws://%s:0" % host
1489
1490
1491class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocketServer):
1492 def setUp(self):
1493 try:
1494 import sqlalchemy
1495 import aiosqlite
1496 except ImportError as e:
1497 self.skipTest(str(e))
1498
1499 super().setUp()
1500
1501 def make_dbpath(self):
1502 return "sqlite+aiosqlite:///%s" % os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
1503
1504
1505class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
1506 def get_env(self, name):
1507 v = os.environ.get(name)
1508 if not v:
1509 self.skipTest(f'{name} not defined to test an external server')
1510 return v
1511
1512 def start_test_server(self):
1513 return self.get_env('BB_TEST_HASHSERV')
1514
1515 def start_server(self, *args, **kwargs):
1516 self.skipTest('Cannot start local server when testing external servers')
1517
1518 def start_auth_server(self):
1519
1520 self.auth_server_address = self.server_address
1521 self.admin_client = self.start_client(
1522 self.server_address,
1523 username=self.get_env('BB_TEST_HASHSERV_USERNAME'),
1524 password=self.get_env('BB_TEST_HASHSERV_PASSWORD'),
1525 )
1526 return self.admin_client
1527
1528 def setUp(self):
1529 super().setUp()
1530 if "BB_TEST_HASHSERV_USERNAME" in os.environ:
1531 self.client = self.start_client(
1532 self.server_address,
1533 username=os.environ["BB_TEST_HASHSERV_USERNAME"],
1534 password=os.environ["BB_TEST_HASHSERV_PASSWORD"],
1535 )
1536 self.client.remove({"method": self.METHOD})
1537
1538 def tearDown(self):
1539 self.client.remove({"method": self.METHOD})
1540 super().tearDown()
1541
1542
1543 def test_auth_get_all_users(self):
1544 self.skipTest("Cannot test all users with external server")
1545