Diffstat (limited to 'bitbake/lib/hashserv')
-rw-r--r--  bitbake/lib/hashserv/__init__.py   |  175
-rw-r--r--  bitbake/lib/hashserv/client.py     |  431
-rw-r--r--  bitbake/lib/hashserv/server.py     | 1088
-rw-r--r--  bitbake/lib/hashserv/sqlalchemy.py |  598
-rw-r--r--  bitbake/lib/hashserv/sqlite.py     |  562
-rw-r--r--  bitbake/lib/hashserv/tests.py      | 1267
6 files changed, 3411 insertions, 710 deletions
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py
index 5f2e101e52..ac891e0174 100644
--- a/bitbake/lib/hashserv/__init__.py
+++ b/bitbake/lib/hashserv/__init__.py
@@ -5,129 +5,104 @@
 
 import asyncio
 from contextlib import closing
-import re
-import sqlite3
 import itertools
 import json
+from collections import namedtuple
+from urllib.parse import urlparse
+from bb.asyncrpc.client import parse_address, ADDR_TYPE_UNIX, ADDR_TYPE_WS
 
-UNIX_PREFIX = "unix://"
-
-ADDR_TYPE_UNIX = 0
-ADDR_TYPE_TCP = 1
-
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
-TABLE_DEFINITION = (
-    ("method", "TEXT NOT NULL"),
-    ("outhash", "TEXT NOT NULL"),
-    ("taskhash", "TEXT NOT NULL"),
-    ("unihash", "TEXT NOT NULL"),
-    ("created", "DATETIME"),
-
-    # Optional fields
-    ("owner", "TEXT"),
-    ("PN", "TEXT"),
-    ("PV", "TEXT"),
-    ("PR", "TEXT"),
-    ("task", "TEXT"),
-    ("outhash_siginfo", "TEXT"),
-)
-
-TABLE_COLUMNS = tuple(name for name, _ in TABLE_DEFINITION)
-
-def setup_database(database, sync=True):
-    db = sqlite3.connect(database)
-    db.row_factory = sqlite3.Row
-
-    with closing(db.cursor()) as cursor:
-        cursor.execute('''
-            CREATE TABLE IF NOT EXISTS tasks_v2 (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
-                %s
-                UNIQUE(method, outhash, taskhash)
-                )
-            ''' % " ".join("%s %s," % (name, typ) for name, typ in TABLE_DEFINITION))
-        cursor.execute('PRAGMA journal_mode = WAL')
-        cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
-
-        # Drop old indexes
-        cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
-        cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
-
-        # Create new indexes
-        cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)')
-        cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)')
-
-    return db
-
-
-def parse_address(addr):
-    if addr.startswith(UNIX_PREFIX):
-        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
-    else:
-        m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
-        if m is not None:
-            host = m.group('host')
-            port = m.group('port')
-        else:
-            host, port = addr.split(':')
-
-        return (ADDR_TYPE_TCP, (host, int(port)))
+User = namedtuple("User", ("username", "permissions"))
 
 
-def chunkify(msg, max_chunk):
-    if len(msg) < max_chunk - 1:
-        yield ''.join((msg, "\n"))
-    else:
-        yield ''.join((json.dumps({
-                'chunk-stream': None
-            }), "\n"))
-
-        args = [iter(msg)] * (max_chunk - 1)
-        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
-            yield ''.join(itertools.chain(m, "\n"))
-        yield "\n"
-
-def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
+def create_server(
+    addr,
+    dbname,
+    *,
+    sync=True,
+    upstream=None,
+    read_only=False,
+    db_username=None,
+    db_password=None,
+    anon_perms=None,
+    admin_username=None,
+    admin_password=None,
+    reuseport=False,
+):
+    def sqlite_engine():
+        from .sqlite import DatabaseEngine
+
+        return DatabaseEngine(dbname, sync)
+
+    def sqlalchemy_engine():
+        from .sqlalchemy import DatabaseEngine
+
+        return DatabaseEngine(dbname, db_username, db_password)
+
     from . import server
-    db = setup_database(dbname, sync=sync)
-    s = server.Server(db, upstream=upstream, read_only=read_only)
+
+    if "://" in dbname:
+        db_engine = sqlalchemy_engine()
+    else:
+        db_engine = sqlite_engine()
+
+    if anon_perms is None:
+        anon_perms = server.DEFAULT_ANON_PERMS
+
+    s = server.Server(
+        db_engine,
+        upstream=upstream,
+        read_only=read_only,
+        anon_perms=anon_perms,
+        admin_username=admin_username,
+        admin_password=admin_password,
+    )
 
     (typ, a) = parse_address(addr)
     if typ == ADDR_TYPE_UNIX:
         s.start_unix_server(*a)
+    elif typ == ADDR_TYPE_WS:
+        url = urlparse(a[0])
+        s.start_websocket_server(url.hostname, url.port, reuseport=reuseport)
     else:
-        s.start_tcp_server(*a)
+        s.start_tcp_server(*a, reuseport=reuseport)
 
     return s
 
 
-def create_client(addr):
+def create_client(addr, username=None, password=None):
     from . import client
-    c = client.Client()
 
-    (typ, a) = parse_address(addr)
-    if typ == ADDR_TYPE_UNIX:
-        c.connect_unix(*a)
-    else:
-        c.connect_tcp(*a)
+    c = client.Client(username, password)
+
+    try:
+        (typ, a) = parse_address(addr)
+        if typ == ADDR_TYPE_UNIX:
+            c.connect_unix(*a)
+        elif typ == ADDR_TYPE_WS:
+            c.connect_websocket(*a)
+        else:
+            c.connect_tcp(*a)
+        return c
+    except Exception as e:
+        c.close()
+        raise e
 
-    return c
 
-async def create_async_client(addr):
+async def create_async_client(addr, username=None, password=None):
     from . import client
-    c = client.AsyncClient()
 
-    (typ, a) = parse_address(addr)
-    if typ == ADDR_TYPE_UNIX:
-        await c.connect_unix(*a)
-    else:
-        await c.connect_tcp(*a)
+    c = client.AsyncClient(username, password)
+
+    try:
+        (typ, a) = parse_address(addr)
+        if typ == ADDR_TYPE_UNIX:
+            await c.connect_unix(*a)
+        elif typ == ADDR_TYPE_WS:
+            await c.connect_websocket(*a)
+        else:
+            await c.connect_tcp(*a)
 
-    return c
+        return c
+    except Exception as e:
+        await c.close()
+        raise e
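
The entry points above now take optional credentials and accept unix://, ws://, and plain host:port address forms. A minimal usage sketch (the port, username, and token below are made-up values, and a running server is assumed):

    import hashserv

    # TCP address; "unix:///run/hashserv.sock" and "ws://localhost:8686"
    # (websocket support is added by this patch) work the same way
    client = hashserv.create_client("localhost:8686", username="admin", password="secret-token")
    print(client.get_unihash("fake-method", "0123456789abcdef"))
    client.close()
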
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
index e05c1eb568..a510f3284f 100644
--- a/bitbake/lib/hashserv/client.py
+++ b/bitbake/lib/hashserv/client.py
@@ -3,231 +3,356 @@
 # SPDX-License-Identifier: GPL-2.0-only
 #
 
-import asyncio
-import json
 import logging
 import socket
-import os
-from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client
+import asyncio
+import bb.asyncrpc
+import json
+from . import create_async_client
 
 
 logger = logging.getLogger("hashserv.client")
 
 
-class HashConnectionError(Exception):
-    pass
+class Batch(object):
+    def __init__(self):
+        self.done = False
+        self.cond = asyncio.Condition()
+        self.pending = []
+        self.results = []
+        self.sent_count = 0
 
+    async def recv(self, socket):
+        while True:
+            async with self.cond:
+                await self.cond.wait_for(lambda: self.pending or self.done)
 
-class AsyncClient(object):
-    MODE_NORMAL = 0
-    MODE_GET_STREAM = 1
+                if not self.pending:
+                    if self.done:
+                        return
+                    continue
 
-    def __init__(self):
-        self.reader = None
-        self.writer = None
-        self.mode = self.MODE_NORMAL
-        self.max_chunk = DEFAULT_MAX_CHUNK
+            r = await socket.recv()
+            self.results.append(r)
 
-    async def connect_tcp(self, address, port):
-        async def connect_sock():
-            return await asyncio.open_connection(address, port)
+            async with self.cond:
+                self.pending.pop(0)
 
-        self._connect_sock = connect_sock
+    async def send(self, socket, msgs):
+        try:
+            # In the event of a restart due to a reconnect, all in-flight
+            # messages need to be resent first to keep the result count in sync
+            for m in self.pending:
+                await socket.send(m)
 
-    async def connect_unix(self, path):
-        async def connect_sock():
-            return await asyncio.open_unix_connection(path)
+            for m in msgs:
+                # Add the message to the pending list before attempting to send
+                # it so that if the send fails it will be retried
+                async with self.cond:
+                    self.pending.append(m)
+                    self.cond.notify()
+                    self.sent_count += 1
 
-        self._connect_sock = connect_sock
+                await socket.send(m)
 
-    async def connect(self):
-        if self.reader is None or self.writer is None:
-            (self.reader, self.writer) = await self._connect_sock()
+        finally:
+            async with self.cond:
+                self.done = True
+                self.cond.notify()
+
+    async def process(self, socket, msgs):
+        await asyncio.gather(
+            self.recv(socket),
+            self.send(socket, msgs),
+        )
 
-        self.writer.write("OEHASHEQUIV 1.1\n\n".encode("utf-8"))
-        await self.writer.drain()
+        if len(self.results) != self.sent_count:
+            raise ValueError(
+                f"Unexpected result count {len(self.results)}. Expected {self.sent_count}"
+            )
 
-        cur_mode = self.mode
-        self.mode = self.MODE_NORMAL
-        await self._set_mode(cur_mode)
+        return self.results
 
-    async def close(self):
-        self.reader = None
 
-        if self.writer is not None:
-            self.writer.close()
-            self.writer = None
+class AsyncClient(bb.asyncrpc.AsyncClient):
+    MODE_NORMAL = 0
+    MODE_GET_STREAM = 1
+    MODE_EXIST_STREAM = 2
 
-    async def _send_wrapper(self, proc):
-        count = 0
-        while True:
-            try:
-                await self.connect()
-                return await proc()
-            except (
-                OSError,
-                HashConnectionError,
-                json.JSONDecodeError,
-                UnicodeDecodeError,
-            ) as e:
-                logger.warning("Error talking to server: %s" % e)
-                if count >= 3:
-                    if not isinstance(e, HashConnectionError):
-                        raise HashConnectionError(str(e))
-                    raise e
-                await self.close()
-                count += 1
-
-    async def send_message(self, msg):
-        async def get_line():
-            line = await self.reader.readline()
-            if not line:
-                raise HashConnectionError("Connection closed")
-
-            line = line.decode("utf-8")
-
-            if not line.endswith("\n"):
-                raise HashConnectionError("Bad message %r" % message)
-
-            return line
+    def __init__(self, username=None, password=None):
+        super().__init__("OEHASHEQUIV", "1.1", logger)
+        self.mode = self.MODE_NORMAL
+        self.username = username
+        self.password = password
+        self.saved_become_user = None
 
-        async def proc():
-            for c in chunkify(json.dumps(msg), self.max_chunk):
-                self.writer.write(c.encode("utf-8"))
-            await self.writer.drain()
+    async def setup_connection(self):
+        await super().setup_connection()
+        self.mode = self.MODE_NORMAL
+        if self.username:
+            # Save off become user temporarily because auth() resets it
+            become = self.saved_become_user
+            await self.auth(self.username, self.password)
 
-            l = await get_line()
+            if become:
+                await self.become_user(become)
 
-            m = json.loads(l)
-            if m and "chunk-stream" in m:
-                lines = []
-                while True:
-                    l = (await get_line()).rstrip("\n")
-                    if not l:
-                        break
-                    lines.append(l)
+    async def send_stream_batch(self, mode, msgs):
+        """
+        Does a "batch" process of stream messages. This sends the query
+        messages as fast as possible, and simultaneously attempts to read the
+        messages back. This helps to mitigate the effects of latency to the
+        hash equivalence server by allowing multiple queries to be "in-flight"
+        at once
 
-                m = json.loads("".join(lines))
+        The implementation does more complicated tracking using a count of sent
+        messages so that `msgs` can be a generator function (i.e. its length is
+        unknown)
 
-            return m
+        """
 
-        return await self._send_wrapper(proc)
+        b = Batch()
 
-    async def send_stream(self, msg):
         async def proc():
-            self.writer.write(("%s\n" % msg).encode("utf-8"))
-            await self.writer.drain()
-            l = await self.reader.readline()
-            if not l:
-                raise HashConnectionError("Connection closed")
-            return l.decode("utf-8").rstrip()
+            nonlocal b
+
+            await self._set_mode(mode)
+            return await b.process(self.socket, msgs)
 
         return await self._send_wrapper(proc)
 
+    async def invoke(self, *args, skip_mode=False, **kwargs):
+        # It's OK if connection errors cause a failure here, because the mode
+        # is also reset to normal on a new connection
+        if not skip_mode:
+            await self._set_mode(self.MODE_NORMAL)
+        return await super().invoke(*args, **kwargs)
+
     async def _set_mode(self, new_mode):
-        if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
-            r = await self.send_stream("END")
+        async def stream_to_normal():
+            # Check if already in normal mode (e.g. due to a connection reset)
+            if self.mode == self.MODE_NORMAL:
+                return "ok"
+            await self.socket.send("END")
+            return await self.socket.recv()
+
+        async def normal_to_stream(command):
+            r = await self.invoke({command: None}, skip_mode=True)
             if r != "ok":
-                raise HashConnectionError("Bad response from server %r" % r)
-        elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
-            r = await self.send_message({"get-stream": None})
+                self.check_invoke_error(r)
+                raise ConnectionError(
+                    f"Unable to transition to stream mode: Bad response from server {r!r}"
+                )
+            self.logger.debug("Mode is now %s", command)
+
+        if new_mode == self.mode:
+            return
+
+        self.logger.debug("Transitioning mode %s -> %s", self.mode, new_mode)
+
+        # Always transition to normal mode before switching to any other mode
+        if self.mode != self.MODE_NORMAL:
+            r = await self._send_wrapper(stream_to_normal)
             if r != "ok":
-                raise HashConnectionError("Bad response from server %r" % r)
-        elif new_mode != self.mode:
-            raise Exception(
-                "Undefined mode transition %r -> %r" % (self.mode, new_mode)
-            )
+                self.check_invoke_error(r)
+                raise ConnectionError(
+                    f"Unable to transition to normal mode: Bad response from server {r!r}"
+                )
+            self.logger.debug("Mode is now normal")
+
+        if new_mode == self.MODE_GET_STREAM:
+            await normal_to_stream("get-stream")
+        elif new_mode == self.MODE_EXIST_STREAM:
+            await normal_to_stream("exists-stream")
+        elif new_mode != self.MODE_NORMAL:
+            raise Exception(f"Undefined mode transition {self.mode!r} -> {new_mode!r}")
 
         self.mode = new_mode
 
     async def get_unihash(self, method, taskhash):
-        await self._set_mode(self.MODE_GET_STREAM)
-        r = await self.send_stream("%s %s" % (method, taskhash))
-        if not r:
-            return None
-        return r
+        r = await self.get_unihash_batch([(method, taskhash)])
+        return r[0]
+
+    async def get_unihash_batch(self, args):
+        result = await self.send_stream_batch(
+            self.MODE_GET_STREAM,
+            (f"{method} {taskhash}" for method, taskhash in args),
+        )
+        return [r if r else None for r in result]
 
     async def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
-        await self._set_mode(self.MODE_NORMAL)
         m = extra.copy()
         m["taskhash"] = taskhash
         m["method"] = method
         m["outhash"] = outhash
         m["unihash"] = unihash
-        return await self.send_message({"report": m})
+        return await self.invoke({"report": m})
 
     async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
-        await self._set_mode(self.MODE_NORMAL)
         m = extra.copy()
         m["taskhash"] = taskhash
         m["method"] = method
         m["unihash"] = unihash
-        return await self.send_message({"report-equiv": m})
+        return await self.invoke({"report-equiv": m})
 
     async def get_taskhash(self, method, taskhash, all_properties=False):
-        await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message(
+        return await self.invoke(
             {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
         )
 
-    async def get_outhash(self, method, outhash, taskhash):
-        await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message(
-            {"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method}}
+    async def unihash_exists(self, unihash):
+        r = await self.unihash_exists_batch([unihash])
+        return r[0]
+
+    async def unihash_exists_batch(self, unihashes):
+        result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes)
+        return [r == "true" for r in result]
+
+    async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
+        return await self.invoke(
+            {
+                "get-outhash": {
+                    "outhash": outhash,
+                    "taskhash": taskhash,
+                    "method": method,
+                    "with_unihash": with_unihash,
+                }
+            }
         )
 
     async def get_stats(self):
-        await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message({"get-stats": None})
+        return await self.invoke({"get-stats": None})
 
     async def reset_stats(self):
-        await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message({"reset-stats": None})
+        return await self.invoke({"reset-stats": None})
 
     async def backfill_wait(self):
-        await self._set_mode(self.MODE_NORMAL)
-        return (await self.send_message({"backfill-wait": None}))["tasks"]
+        return (await self.invoke({"backfill-wait": None}))["tasks"]
+
+    async def remove(self, where):
+        return await self.invoke({"remove": {"where": where}})
+
+    async def clean_unused(self, max_age):
+        return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
+
+    async def auth(self, username, token):
+        result = await self.invoke({"auth": {"username": username, "token": token}})
+        self.username = username
+        self.password = token
+        self.saved_become_user = None
+        return result
+
+    async def refresh_token(self, username=None):
+        m = {}
+        if username:
+            m["username"] = username
+        result = await self.invoke({"refresh-token": m})
+        if (
+            self.username
+            and not self.saved_become_user
+            and result["username"] == self.username
+        ):
+            self.password = result["token"]
+        return result
 
+    async def set_user_perms(self, username, permissions):
+        return await self.invoke(
+            {"set-user-perms": {"username": username, "permissions": permissions}}
+        )
 
-class Client(object):
-    def __init__(self):
-        self.client = AsyncClient()
-        self.loop = asyncio.new_event_loop()
+    async def get_user(self, username=None):
+        m = {}
+        if username:
+            m["username"] = username
+        return await self.invoke({"get-user": m})
+
+    async def get_all_users(self):
+        return (await self.invoke({"get-all-users": {}}))["users"]
+
+    async def new_user(self, username, permissions):
+        return await self.invoke(
+            {"new-user": {"username": username, "permissions": permissions}}
+        )
+
+    async def delete_user(self, username):
+        return await self.invoke({"delete-user": {"username": username}})
+
+    async def become_user(self, username):
+        result = await self.invoke({"become-user": {"username": username}})
+        if username == self.username:
+            self.saved_become_user = None
+        else:
+            self.saved_become_user = username
+        return result
+
+    async def get_db_usage(self):
+        return (await self.invoke({"get-db-usage": {}}))["usage"]
+
+    async def get_db_query_columns(self):
+        return (await self.invoke({"get-db-query-columns": {}}))["columns"]
+
+    async def gc_status(self):
+        return await self.invoke({"gc-status": {}})
+
+    async def gc_mark(self, mark, where):
+        """
+        Starts a new garbage collection operation identified by "mark". If
+        garbage collection is already in progress with "mark", the collection
+        is continued.
 
-        for call in (
+        All unihash entries that match the "where" clause are marked to be
+        kept. In addition, any new entries added to the database after this
+        command will be automatically marked with "mark"
+        """
+        return await self.invoke({"gc-mark": {"mark": mark, "where": where}})
+
+    async def gc_sweep(self, mark):
+        """
+        Finishes garbage collection for "mark". All unihash entries that have
+        not been marked will be deleted.
+
+        It is recommended to clean unused outhash entries after running this to
+        clean up any dangling outhashes
+        """
+        return await self.invoke({"gc-sweep": {"mark": mark}})
+
+
+class Client(bb.asyncrpc.Client):
+    def __init__(self, username=None, password=None):
+        self.username = username
+        self.password = password
+
+        super().__init__()
+        self._add_methods(
             "connect_tcp",
-            "close",
+            "connect_websocket",
             "get_unihash",
+            "get_unihash_batch",
             "report_unihash",
             "report_unihash_equiv",
             "get_taskhash",
+            "unihash_exists",
+            "unihash_exists_batch",
+            "get_outhash",
             "get_stats",
             "reset_stats",
             "backfill_wait",
-        ):
-            downcall = getattr(self.client, call)
-            setattr(self, call, self._get_downcall_wrapper(downcall))
-
-    def _get_downcall_wrapper(self, downcall):
-        def wrapper(*args, **kwargs):
-            return self.loop.run_until_complete(downcall(*args, **kwargs))
-
-        return wrapper
-
-    def connect_unix(self, path):
-        # AF_UNIX has path length issues so chdir here to workaround
-        cwd = os.getcwd()
-        try:
-            os.chdir(os.path.dirname(path))
-            self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path)))
-            self.loop.run_until_complete(self.client.connect())
-        finally:
-            os.chdir(cwd)
-
-    @property
-    def max_chunk(self):
-        return self.client.max_chunk
-
-    @max_chunk.setter
-    def max_chunk(self, value):
-        self.client.max_chunk = value
+            "remove",
+            "clean_unused",
+            "auth",
+            "refresh_token",
+            "set_user_perms",
+            "get_user",
+            "get_all_users",
+            "new_user",
+            "delete_user",
+            "become_user",
+            "get_db_usage",
+            "get_db_query_columns",
+            "gc_status",
+            "gc_mark",
+            "gc_sweep",
+        )
+
+    def _get_async_client(self):
+        return AsyncClient(self.username, self.password)
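
send_stream_batch() above keeps many stream queries in flight at once (via the Batch class) to hide round-trip latency; get_unihash_batch() and unihash_exists_batch() are its public wrappers. A short sketch of the batched calls through the blocking Client (the address and hashes are placeholders):

    import hashserv

    client = hashserv.create_client("localhost:8686")

    # One result per query, in request order; misses come back as None
    unihashes = client.get_unihash_batch([
        ("fake-method", "a" * 64),
        ("fake-method", "b" * 64),
    ])

    # True/False per unihash, also order-preserving
    exists = client.unihash_exists_batch(["a" * 64, "b" * 64])
    client.close()
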
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index a0dc0c170f..68f64f983b 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -3,20 +3,51 @@
 # SPDX-License-Identifier: GPL-2.0-only
 #
 
-from contextlib import closing, contextmanager
-from datetime import datetime
+from datetime import datetime, timedelta
 import asyncio
-import json
 import logging
 import math
-import os
-import signal
-import socket
-import sys
 import time
-from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS
+import os
+import base64
+import hashlib
+from . import create_async_client
+import bb.asyncrpc
+
+logger = logging.getLogger("hashserv.server")
+
+
+# This permission only exists to match nothing
+NONE_PERM = "@none"
+
+READ_PERM = "@read"
+REPORT_PERM = "@report"
+DB_ADMIN_PERM = "@db-admin"
+USER_ADMIN_PERM = "@user-admin"
+ALL_PERM = "@all"
 
-logger = logging.getLogger('hashserv.server')
+ALL_PERMISSIONS = {
+    READ_PERM,
+    REPORT_PERM,
+    DB_ADMIN_PERM,
+    USER_ADMIN_PERM,
+    ALL_PERM,
+}
+
+DEFAULT_ANON_PERMS = (
+    READ_PERM,
+    REPORT_PERM,
+    DB_ADMIN_PERM,
+)
+
+TOKEN_ALGORITHM = "sha256"
+
+# 48 bytes of random data will result in 64 characters when base64
+# encoded. This number also ensures that the base64 encoding won't have any
+# trailing '=' characters.
+TOKEN_SIZE = 48
+
+SALT_SIZE = 8
 
 
 class Measurement(object):
@@ -106,522 +137,745 @@ class Stats(object):
         return math.sqrt(self.s / (self.num - 1))
 
     def todict(self):
-        return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
-
-
-class ClientError(Exception):
-    pass
-
-class ServerError(Exception):
-    pass
-
-def insert_task(cursor, data, ignore=False):
-    keys = sorted(data.keys())
-    query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
-        " OR IGNORE" if ignore else "",
-        ', '.join(keys),
-        ', '.join(':' + k for k in keys))
-    cursor.execute(query, data)
-
-async def copy_from_upstream(client, db, method, taskhash):
-    d = await client.get_taskhash(method, taskhash, True)
-    if d is not None:
-        # Filter out unknown columns
-        d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
-        keys = sorted(d.keys())
-
-        with closing(db.cursor()) as cursor:
-            insert_task(cursor, d)
-            db.commit()
-
-    return d
-
-async def copy_outhash_from_upstream(client, db, method, outhash, taskhash):
-    d = await client.get_outhash(method, outhash, taskhash)
-    if d is not None:
-        # Filter out unknown columns
-        d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
-        keys = sorted(d.keys())
-
-        with closing(db.cursor()) as cursor:
-            insert_task(cursor, d)
-            db.commit()
-
-    return d
-
-class ServerClient(object):
-    FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
-    ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
-    OUTHASH_QUERY = '''
-        -- Find tasks with a matching outhash (that is, tasks that
-        -- are equivalent)
-        SELECT * FROM tasks_v2 WHERE method=:method AND outhash=:outhash
-
-        -- If there is an exact match on the taskhash, return it.
-        -- Otherwise return the oldest matching outhash of any
-        -- taskhash
-        ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
-            created ASC
-
-        -- Only return one row
-        LIMIT 1
-        '''
-
-    def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
-        self.reader = reader
-        self.writer = writer
-        self.db = db
-        self.request_stats = request_stats
-        self.max_chunk = DEFAULT_MAX_CHUNK
-        self.backfill_queue = backfill_queue
-        self.upstream = upstream
-
-        self.handlers = {
-            'get': self.handle_get,
-            'get-outhash': self.handle_get_outhash,
-            'get-stream': self.handle_get_stream,
-            'get-stats': self.handle_get_stats,
-            'chunk-stream': self.handle_chunk,
-        }
+        return {
+            k: getattr(self, k)
+            for k in ("num", "total_time", "max_time", "average", "stdev")
+        }
 
-        if not read_only:
-            self.handlers.update({
-                'report': self.handle_report,
-                'report-equiv': self.handle_equivreport,
-                'reset-stats': self.handle_reset_stats,
-                'backfill-wait': self.handle_backfill_wait,
-            })
 
-    async def process_requests(self):
-        if self.upstream is not None:
-            self.upstream_client = await create_async_client(self.upstream)
-        else:
-            self.upstream_client = None
+token_refresh_semaphore = asyncio.Lock()
 
-        try:
 
+async def new_token():
+    # Prevent malicious users from using this API to deduce the entropy
+    # pool on the server and thus be able to guess a token. *All* token
+    # refresh requests lock the same global semaphore and then sleep for a
+    # short time. This effectively rate limits the total number of requests
+    # that can be made across all clients to 10/second, which should be enough
+    # since you have to be an authenticated user to make the request in the
+    # first place
+    async with token_refresh_semaphore:
+        await asyncio.sleep(0.1)
+        raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
 
-            self.addr = self.writer.get_extra_info('peername')
-            logger.debug('Client %r connected' % (self.addr,))
+    return base64.b64encode(raw, b"._").decode("utf-8")
 
-            # Read protocol and version
-            protocol = await self.reader.readline()
-            if protocol is None:
-                return
 
-            (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
-            if proto_name != 'OEHASHEQUIV':
-                return
+def new_salt():
+    return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()
 
-            proto_version = tuple(int(v) for v in proto_version.split('.'))
-            if proto_version < (1, 0) or proto_version > (1, 1):
-                return
 
-            # Read headers. Currently, no headers are implemented, so look for
-            # an empty line to signal the end of the headers
-            while True:
-                line = await self.reader.readline()
-                if line is None:
-                    return
+def hash_token(algo, salt, token):
+    h = hashlib.new(algo)
+    h.update(salt.encode("utf-8"))
+    h.update(token.encode("utf-8"))
+    return ":".join([algo, salt, h.hexdigest()])
 
-                line = line.decode('utf-8').rstrip()
-                if not line:
-                    break
 
-            # Handle messages
-            while True:
-                d = await self.read_message()
-                if d is None:
-                    break
-                await self.dispatch_message(d)
-                await self.writer.drain()
-        except ClientError as e:
-            logger.error(str(e))
-        finally:
-            if self.upstream_client is not None:
-                await self.upstream_client.close()
+def permissions(*permissions, allow_anon=True, allow_self_service=False):
+    """
+    Function decorator that can be used to decorate an RPC function call and
+    check that the current user's permissions match the required permissions.
 
-            self.writer.close()
+    If allow_anon is True, the user will also be allowed to make the RPC call
+    if the anonymous user permissions match the permissions.
 
-    async def dispatch_message(self, msg):
-        for k in self.handlers.keys():
-            if k in msg:
-                logger.debug('Handling %s' % k)
-                if 'stream' in k:
-                    await self.handlers[k](msg[k])
-                else:
-                    with self.request_stats.start_sample() as self.request_sample, \
-                            self.request_sample.measure():
-                        await self.handlers[k](msg[k])
-                return
+    If allow_self_service is True, and the "username" property in the request
+    is the currently logged in user, or not specified, the user will also be
+    allowed to make the request. This allows users to access normal privileged
+    API, as long as they are only modifying their own user properties (e.g.
+    users can be allowed to reset their own token without @user-admin
+    permissions, but not the token for any other user).
+    """
+
+    def wrapper(func):
+        async def wrap(self, request):
+            if allow_self_service and self.user is not None:
+                username = request.get("username", self.user.username)
+                if username == self.user.username:
+                    request["username"] = self.user.username
+                    return await func(self, request)
+
+            if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
+                if not self.user:
+                    username = "Anonymous user"
+                    user_perms = self.server.anon_perms
+                else:
+                    username = self.user.username
+                    user_perms = self.user.permissions
+
+                self.logger.info(
+                    "User %s with permissions %r denied from calling %s. Missing permission(s) %r",
+                    username,
+                    ", ".join(user_perms),
+                    func.__name__,
+                    ", ".join(permissions),
+                )
+                raise bb.asyncrpc.InvokeError(
+                    f"{username} is not allowed to access permission(s) {', '.join(permissions)}"
+                )
+
+            return await func(self, request)
+
+        return wrap
+
+    return wrapper
+
+
+class ServerClient(bb.asyncrpc.AsyncServerConnection):
+    def __init__(self, socket, server):
+        super().__init__(socket, "OEHASHEQUIV", server.logger)
+        self.server = server
+        self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
+        self.user = None
+
+        self.handlers.update(
+            {
+                "get": self.handle_get,
+                "get-outhash": self.handle_get_outhash,
+                "get-stream": self.handle_get_stream,
+                "exists-stream": self.handle_exists_stream,
+                "get-stats": self.handle_get_stats,
+                "get-db-usage": self.handle_get_db_usage,
+                "get-db-query-columns": self.handle_get_db_query_columns,
+                # Not always read-only, but internally checks if the server is
+                # read-only
+                "report": self.handle_report,
+                "auth": self.handle_auth,
+                "get-user": self.handle_get_user,
+                "get-all-users": self.handle_get_all_users,
+                "become-user": self.handle_become_user,
+            }
+        )
 
-        raise ClientError("Unrecognized command %r" % msg)
+        if not self.server.read_only:
+            self.handlers.update(
+                {
+                    "report-equiv": self.handle_equivreport,
+                    "reset-stats": self.handle_reset_stats,
+                    "backfill-wait": self.handle_backfill_wait,
+                    "remove": self.handle_remove,
+                    "gc-mark": self.handle_gc_mark,
+                    "gc-sweep": self.handle_gc_sweep,
+                    "gc-status": self.handle_gc_status,
+                    "clean-unused": self.handle_clean_unused,
+                    "refresh-token": self.handle_refresh_token,
+                    "set-user-perms": self.handle_set_perms,
+                    "new-user": self.handle_new_user,
+                    "delete-user": self.handle_delete_user,
+                }
+            )
 
-    def write_message(self, msg):
-        for c in chunkify(json.dumps(msg), self.max_chunk):
-            self.writer.write(c.encode('utf-8'))
+    def raise_no_user_error(self, username):
+        raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")
 
-    async def read_message(self):
-        l = await self.reader.readline()
-        if not l:
-            return None
+    def user_has_permissions(self, *permissions, allow_anon=True):
+        permissions = set(permissions)
+        if allow_anon:
+            if ALL_PERM in self.server.anon_perms:
+                return True
 
-        try:
-            message = l.decode('utf-8')
+            if not permissions - self.server.anon_perms:
+                return True
 
-            if not message.endswith('\n'):
-                return None
+        if self.user is None:
+            return False
 
-            return json.loads(message)
-        except (json.JSONDecodeError, UnicodeDecodeError) as e:
-            logger.error('Bad message from client: %r' % message)
-            raise e
+        if ALL_PERM in self.user.permissions:
+            return True
 
-    async def handle_chunk(self, request):
-        lines = []
-        try:
-            while True:
-                l = await self.reader.readline()
-                l = l.rstrip(b"\n").decode("utf-8")
-                if not l:
-                    break
-                lines.append(l)
+        if not permissions - self.user.permissions:
+            return True
 
-            msg = json.loads(''.join(lines))
-        except (json.JSONDecodeError, UnicodeDecodeError) as e:
-            logger.error('Bad message from client: %r' % message)
-            raise e
+        return False
 
-        if 'chunk-stream' in msg:
-            raise ClientError("Nested chunks are not allowed")
+    def validate_proto_version(self):
+        return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
 
-        await self.dispatch_message(msg)
+    async def process_requests(self):
+        async with self.server.db_engine.connect(self.logger) as db:
+            self.db = db
+            if self.server.upstream is not None:
+                self.upstream_client = await create_async_client(self.server.upstream)
+            else:
+                self.upstream_client = None
 
-    async def handle_get(self, request):
-        method = request['method']
-        taskhash = request['taskhash']
+            try:
+                await super().process_requests()
+            finally:
+                if self.upstream_client is not None:
+                    await self.upstream_client.close()
 
-        if request.get('all', False):
-            row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
-        else:
-            row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
+    async def dispatch_message(self, msg):
+        for k in self.handlers.keys():
+            if k in msg:
+                self.logger.debug("Handling %s" % k)
+                if "stream" in k:
+                    return await self.handlers[k](msg[k])
+                else:
+                    with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
+                        return await self.handlers[k](msg[k])
 
-        if row is not None:
-            logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
-            d = {k: row[k] for k in row.keys()}
-        elif self.upstream_client is not None:
-            d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash)
-        else:
-            d = None
+        raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
+
+    @permissions(READ_PERM)
+    async def handle_get(self, request):
+        method = request["method"]
+        taskhash = request["taskhash"]
+        fetch_all = request.get("all", False)
+
+        return await self.get_unihash(method, taskhash, fetch_all)
+
+    async def get_unihash(self, method, taskhash, fetch_all=False):
+        d = None
+
+        if fetch_all:
+            row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
+            if row is not None:
+                d = {k: row[k] for k in row.keys()}
+            elif self.upstream_client is not None:
+                d = await self.upstream_client.get_taskhash(method, taskhash, True)
+                await self.update_unified(d)
+        else:
+            row = await self.db.get_equivalent(method, taskhash)
+
+            if row is not None:
+                d = {k: row[k] for k in row.keys()}
+            elif self.upstream_client is not None:
+                d = await self.upstream_client.get_taskhash(method, taskhash)
+                await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
 
-        self.write_message(d)
+        return d
 
+    @permissions(READ_PERM)
     async def handle_get_outhash(self, request):
-        with closing(self.db.cursor()) as cursor:
-            cursor.execute(self.OUTHASH_QUERY,
-                           {k: request[k] for k in ('method', 'outhash', 'taskhash')})
+        method = request["method"]
+        outhash = request["outhash"]
+        taskhash = request["taskhash"]
+        with_unihash = request.get("with_unihash", True)
 
-            row = cursor.fetchone()
+        return await self.get_outhash(method, outhash, taskhash, with_unihash)
+
+    async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
+        d = None
+        if with_unihash:
+            row = await self.db.get_unihash_by_outhash(method, outhash)
+        else:
+            row = await self.db.get_outhash(method, outhash)
 
         if row is not None:
-            logger.debug('Found equivalent outhash %s -> %s', (row['outhash'], row['unihash']))
             d = {k: row[k] for k in row.keys()}
-        else:
-            d = None
+        elif self.upstream_client is not None:
+            d = await self.upstream_client.get_outhash(method, outhash, taskhash)
+            await self.update_unified(d)
 
-        self.write_message(d)
+        return d
 
-    async def handle_get_stream(self, request):
-        self.write_message('ok')
+    async def update_unified(self, data):
+        if data is None:
+            return
+
+        await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+        await self.db.insert_outhash(data)
+
+    async def _stream_handler(self, handler):
+        await self.socket.send_message("ok")
 
         while True:
             upstream = None
 
-            l = await self.reader.readline()
+            l = await self.socket.recv()
             if not l:
-                return
+                break
 
             try:
                 # This inner loop is very sensitive and must be as fast as
                 # possible (which is why the request sample is handled manually
                 # instead of using 'with', and also why logging statements are
                 # commented out.
-                self.request_sample = self.request_stats.start_sample()
+                self.request_sample = self.server.request_stats.start_sample()
                 request_measure = self.request_sample.measure()
                 request_measure.start()
 
-                l = l.decode('utf-8').rstrip()
-                if l == 'END':
-                    self.writer.write('ok\n'.encode('utf-8'))
-                    return
-
-                (method, taskhash) = l.split()
-                #logger.debug('Looking up %s %s' % (method, taskhash))
-                row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
-                if row is not None:
-                    msg = ('%s\n' % row['unihash']).encode('utf-8')
-                    #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
-                elif self.upstream_client is not None:
-                    upstream = await self.upstream_client.get_unihash(method, taskhash)
-                    if upstream:
-                        msg = ("%s\n" % upstream).encode("utf-8")
-                    else:
-                        msg = "\n".encode("utf-8")
-                else:
-                    msg = '\n'.encode('utf-8')
+                if l == "END":
+                    break
 
-                self.writer.write(msg)
+                msg = await handler(l)
+                await self.socket.send(msg)
             finally:
                 request_measure.end()
                 self.request_sample.end()
 
-            await self.writer.drain()
+        await self.socket.send("ok")
+        return self.NO_RESPONSE
 
-            # Post to the backfill queue after writing the result to minimize
-            # the turn around time on a request
-            if upstream is not None:
-                await self.backfill_queue.put((method, taskhash))
+    @permissions(READ_PERM)
+    async def handle_get_stream(self, request):
+        async def handler(l):
+            (method, taskhash) = l.split()
+            # self.logger.debug('Looking up %s %s' % (method, taskhash))
+            row = await self.db.get_equivalent(method, taskhash)
 
-    async def handle_report(self, data):
-        with closing(self.db.cursor()) as cursor:
-            cursor.execute(self.OUTHASH_QUERY,
-                           {k: data[k] for k in ('method', 'outhash', 'taskhash')})
-
-            row = cursor.fetchone()
-
-            if row is None and self.upstream_client:
-                # Try upstream
-                row = await copy_outhash_from_upstream(self.upstream_client,
-                                                       self.db,
-                                                       data['method'],
-                                                       data['outhash'],
-                                                       data['taskhash'])
-
-            # If no matching outhash was found, or one *was* found but it
-            # wasn't an exact match on the taskhash, a new entry for this
-            # taskhash should be added
-            if row is None or row['taskhash'] != data['taskhash']:
-                # If a row matching the outhash was found, the unihash for
-                # the new taskhash should be the same as that one.
-                # Otherwise the caller provided unihash is used.
-                unihash = data['unihash']
-                if row is not None:
-                    unihash = row['unihash']
-
-                insert_data = {
-                    'method': data['method'],
-                    'outhash': data['outhash'],
-                    'taskhash': data['taskhash'],
-                    'unihash': unihash,
-                    'created': datetime.now()
-                }
+            if row is not None:
+                # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+                return row["unihash"]
 
-                for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
-                    if k in data:
-                        insert_data[k] = data[k]
+            if self.upstream_client is not None:
+                upstream = await self.upstream_client.get_unihash(method, taskhash)
+                if upstream:
+                    await self.server.backfill_queue.put((method, taskhash))
+                    return upstream
 
-                insert_task(cursor, insert_data)
-                self.db.commit()
+            return ""
 
-                logger.info('Adding taskhash %s with unihash %s',
-                            data['taskhash'], unihash)
+        return await self._stream_handler(handler)
 
-                d = {
-                    'taskhash': data['taskhash'],
-                    'method': data['method'],
-                    'unihash': unihash
-                }
-            else:
-                d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+    @permissions(READ_PERM)
+    async def handle_exists_stream(self, request):
+        async def handler(l):
+            if await self.db.unihash_exists(l):
+                return "true"
 
-        self.write_message(d)
+            if self.upstream_client is not None:
+                if await self.upstream_client.unihash_exists(l):
+                    return "true"
 
-    async def handle_equivreport(self, data):
-        with closing(self.db.cursor()) as cursor:
-            insert_data = {
-                'method': data['method'],
-                'outhash': "",
-                'taskhash': data['taskhash'],
-                'unihash': data['unihash'],
-                'created': datetime.now()
-            }
+            return "false"
 
-            for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
-                if k in data:
-                    insert_data[k] = data[k]
+        return await self._stream_handler(handler)
 
-            insert_task(cursor, insert_data, ignore=True)
-            self.db.commit()
+    async def report_readonly(self, data):
+        method = data["method"]
+        outhash = data["outhash"]
+        taskhash = data["taskhash"]
 
-            # Fetch the unihash that will be reported for the taskhash. If the
-            # unihash matches, it means this row was inserted (or the mapping
-            # was already valid)
-            row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)
+        info = await self.get_outhash(method, outhash, taskhash)
+        if info:
+            unihash = info["unihash"]
+        else:
+            unihash = data["unihash"]
 
-            if row['unihash'] == data['unihash']:
-                logger.info('Adding taskhash equivalence for %s with unihash %s',
-                            data['taskhash'], row['unihash'])
+        return {
+            "taskhash": taskhash,
+            "method": method,
+            "unihash": unihash,
+        }
 
-            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+    # Since this can be called either read only or to report, the check to
+    # report is made inside the function
+    @permissions(READ_PERM)
+    async def handle_report(self, data):
+        if self.server.read_only or not self.user_has_permissions(REPORT_PERM):
+            return await self.report_readonly(data)
+
+        outhash_data = {
+            "method": data["method"],
+            "outhash": data["outhash"],
+            "taskhash": data["taskhash"],
+            "created": datetime.now(),
+        }
 
-        self.write_message(d)
+        for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
+            if k in data:
+                outhash_data[k] = data[k]
 
+        if self.user:
+            outhash_data["owner"] = self.user.username
 
-    async def handle_get_stats(self, request):
-        d = {
-            'requests': self.request_stats.todict(),
+        # Insert the new entry, unless it already exists
+        if await self.db.insert_outhash(outhash_data):
+            # If this row is new, check if it is equivalent to another
+            # output hash
+            row = await self.db.get_equivalent_for_outhash(
+                data["method"], data["outhash"], data["taskhash"]
+            )
+
+            if row is not None:
+                # A matching output hash was found. Set our taskhash to the
+                # same unihash since they are equivalent
+                unihash = row["unihash"]
+            else:
+                # No matching output hash was found. This is probably the
+                # first outhash to be added.
+                unihash = data["unihash"]
+
+                # Query upstream to see if it has a unihash we can use
+                if self.upstream_client is not None:
+                    upstream_data = await self.upstream_client.get_outhash(
+                        data["method"], data["outhash"], data["taskhash"]
+                    )
+                    if upstream_data is not None:
+                        unihash = upstream_data["unihash"]
+
+            await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
+
+        unihash_data = await self.get_unihash(data["method"], data["taskhash"])
+        if unihash_data is not None:
+            unihash = unihash_data["unihash"]
+        else:
+            unihash = data["unihash"]
+
+        return {
+            "taskhash": data["taskhash"],
+            "method": data["method"],
+            "unihash": unihash,
         }
 
-        self.write_message(d)
+    @permissions(READ_PERM, REPORT_PERM)
+    async def handle_equivreport(self, data):
+        await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+
+        # Fetch the unihash that will be reported for the taskhash. If the
+        # unihash matches, it means this row was inserted (or the mapping
+        # was already valid)
+        row = await self.db.get_equivalent(data["method"], data["taskhash"])
+
+        if row["unihash"] == data["unihash"]:
+            self.logger.info(
+                "Adding taskhash equivalence for %s with unihash %s",
+                data["taskhash"],
+                row["unihash"],
+            )
+
+        return {k: row[k] for k in ("taskhash", "method", "unihash")}
 
+    @permissions(READ_PERM)
+    async def handle_get_stats(self, request):
+        return {
+            "requests": self.server.request_stats.todict(),
+        }
+
+    @permissions(DB_ADMIN_PERM)
     async def handle_reset_stats(self, request):
         d = {
-            'requests': self.request_stats.todict(),
+            "requests": self.server.request_stats.todict(),
         }
 
-        self.request_stats.reset()
-        self.write_message(d)
+        self.server.request_stats.reset()
+        return d
 
+    @permissions(READ_PERM)
     async def handle_backfill_wait(self, request):
         d = {
-            'tasks': self.backfill_queue.qsize(),
+            "tasks": self.server.backfill_queue.qsize(),
         }
-        await self.backfill_queue.join()
-        self.write_message(d)
-
-    def query_equivalent(self, method, taskhash, query):
-        # This is part of the inner loop and must be as fast as possible
-        try:
-            cursor = self.db.cursor()
-            cursor.execute(query, {'method': method, 'taskhash': taskhash})
-            return cursor.fetchone()
-        except:
-            cursor.close()
-
-
-class Server(object):
-    def __init__(self, db, loop=None, upstream=None, read_only=False):
-        if upstream and read_only:
-            raise ServerError("Read-only hashserv cannot pull from an upstream server")
-
-        self.request_stats = Stats()
-        self.db = db
-
-        if loop is None:
-            self.loop = asyncio.new_event_loop()
-            self.close_loop = True
-        else:
-            self.loop = loop
-            self.close_loop = False
-
-        self.upstream = upstream
-        self.read_only = read_only
-
-        self._cleanup_socket = None
-
-    def start_tcp_server(self, host, port):
-        self.server = self.loop.run_until_complete(
-            asyncio.start_server(self.handle_client, host, port, loop=self.loop)
-        )
-
-        for s in self.server.sockets:
-            logger.info('Listening on %r' % (s.getsockname(),))
-            # Newer python does this automatically. Do it manually here for
-            # maximum compatibility
-            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
-            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
-
-        name = self.server.sockets[0].getsockname()
-        if self.server.sockets[0].family == socket.AF_INET6:
-            self.address = "[%s]:%d" % (name[0], name[1])
-        else:
-            self.address = "%s:%d" % (name[0], name[1])
-
-    def start_unix_server(self, path):
-        def cleanup():
-            os.unlink(path)
-
-        cwd = os.getcwd()
-        try:
-            # Work around path length limits in AF_UNIX
-            os.chdir(os.path.dirname(path))
-            self.server = self.loop.run_until_complete(
-                asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
-            )
-        finally:
-            os.chdir(cwd)
-
-        logger.info('Listening on %r' % path)
-
-        self._cleanup_socket = cleanup
-        self.address = "unix://%s" % os.path.abspath(path)
-
-    async def handle_client(self, reader, writer):
-        # writer.transport.set_write_buffer_limits(0)
-        try:
-            client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
-            await client.process_requests()
-        except Exception as e:
-            import traceback
-            logger.error('Error from client: %s' % str(e), exc_info=True)
-            traceback.print_exc()
-            writer.close()
-        logger.info('Client disconnected')
-
-    @contextmanager
-    def _backfill_worker(self):
-        async def backfill_worker_task():
-            client = await create_async_client(self.upstream)
-            try:
-                while True:
-                    item = await self.backfill_queue.get()
-                    if item is None:
-                        self.backfill_queue.task_done()
-                        break
-                    method, taskhash = item
-                    await copy_from_upstream(client, self.db, method, taskhash)
-                    self.backfill_queue.task_done()
-            finally:
-                await client.close()
-
-        async def join_worker(worker):
-            await self.backfill_queue.put(None)
-            await worker
-
-        if self.upstream is not None:
-            worker = asyncio.ensure_future(backfill_worker_task())
-            try:
-                yield
-            finally:
-                self.loop.run_until_complete(join_worker(worker))
-        else:
-            yield
-
-    def serve_forever(self):
-        def signal_handler():
-            self.loop.stop()
-
-        asyncio.set_event_loop(self.loop)
-        try:
-            self.backfill_queue = asyncio.Queue()
-
-            self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
-
-            with self._backfill_worker():
-                try:
-                    self.loop.run_forever()
-                except KeyboardInterrupt:
-                    pass
-
-                self.server.close()
-
-                self.loop.run_until_complete(self.server.wait_closed())
-                logger.info('Server shutting down')
-        finally:
-            if self.close_loop:
-                if sys.version_info >= (3, 6):
-                    self.loop.run_until_complete(self.loop.shutdown_asyncgens())
-                self.loop.close()
-
-            if self._cleanup_socket is not None:
-                self._cleanup_socket()
+        await self.server.backfill_queue.join()
+        return d
+
+    @permissions(DB_ADMIN_PERM)
+    async def handle_remove(self, request):
+        condition = request["where"]
+        if not isinstance(condition, dict):
+            raise TypeError("Bad condition type %s" % type(condition))
+
+        return {"count": await self.db.remove(condition)}
+
+    @permissions(DB_ADMIN_PERM)
+    async def handle_gc_mark(self, request):
+        condition = request["where"]
+        mark = request["mark"]
+
+        if not isinstance(condition, dict):
+            raise TypeError("Bad condition type %s" % type(condition))
+
+        if not isinstance(mark, str):
+            raise TypeError("Bad mark type %s" % type(mark))
+
+        return {"count": await self.db.gc_mark(mark, condition)}
+
+    @permissions(DB_ADMIN_PERM)
+    async def handle_gc_sweep(self, request):
+        mark = request["mark"]
+
+        if not isinstance(mark, str):
+            raise TypeError("Bad mark type %s" % type(mark))
+
+        current_mark = await self.db.get_current_gc_mark()
+
+        if not current_mark or mark != current_mark:
+            raise bb.asyncrpc.InvokeError(
+                f"'{mark}' is not the current mark. Refusing to sweep"
+            )
+
+        count = await self.db.gc_sweep()
+
+        return {"count": count}
+
+    @permissions(DB_ADMIN_PERM)
+    async def handle_gc_status(self, request):
+        (keep_rows, remove_rows, current_mark) = await self.db.gc_status()
+        return {
+            "keep": keep_rows,
+            "remove": remove_rows,
+            "mark": current_mark,
+        }
+
+    @permissions(DB_ADMIN_PERM)
+    async def handle_clean_unused(self, request):
+        max_age = request["max_age_seconds"]
+        oldest = datetime.now() - timedelta(seconds=max_age)
+        return {"count": await self.db.clean_unused(oldest)}
+
+    @permissions(DB_ADMIN_PERM)
+    async def handle_get_db_usage(self, request):
+        return {"usage": await self.db.get_usage()}
+
+    @permissions(DB_ADMIN_PERM)
+    async def handle_get_db_query_columns(self, request):
+        return {"columns": await self.db.get_query_columns()}
+
+    # The authentication API is always allowed
+    async def handle_auth(self, request):
+        username = str(request["username"])
+        token = str(request["token"])
+
+        async def fail_auth():
+            nonlocal username
+            # Rate limit bad login attempts
+            await asyncio.sleep(1)
+            raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")
+
+        user, db_token = await self.db.lookup_user_token(username)
+
+        if not user or not db_token:
+            await fail_auth()
+
+        try:
+            algo, salt, _ = db_token.split(":")
+        except ValueError:
+            await fail_auth()
+
+        if hash_token(algo, salt, token) != db_token:
+            await fail_auth()
+
+        self.user = user
+
+        self.logger.info("Authenticated as %s", username)
+
+        return {
+            "result": True,
+            "username": self.user.username,
+            "permissions": sorted(list(self.user.permissions)),
+        }
+
+    @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+    async def handle_refresh_token(self, request):
+        username = str(request["username"])
+
+        token = await new_token()
+
+        updated = await self.db.set_user_token(
+            username,
+            hash_token(TOKEN_ALGORITHM, new_salt(), token),
+        )
+        if not updated:
+            self.raise_no_user_error(username)
+
+        return {"username": username, "token": token}
+
+    def get_perm_arg(self, arg):
+        if not isinstance(arg, list):
+            raise bb.asyncrpc.InvokeError("Unexpected type for permissions")
+
+        arg = set(arg)
+        try:
+            arg.remove(NONE_PERM)
+        except KeyError:
+            pass
+
+        unknown_perms = arg - ALL_PERMISSIONS
+        if unknown_perms:
+            raise bb.asyncrpc.InvokeError(
+                "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms)))
+            )
+
+        return sorted(list(arg))
+
+    def return_perms(self, permissions):
+        if ALL_PERM in permissions:
+            return sorted(list(ALL_PERMISSIONS))
+        return sorted(list(permissions))
+
+    @permissions(USER_ADMIN_PERM, allow_anon=False)
+    async def handle_set_perms(self, request):
+        username = str(request["username"])
+        permissions = self.get_perm_arg(request["permissions"])
+
+        if not await self.db.set_user_perms(username, permissions):
+            self.raise_no_user_error(username)
+
+        return {
+            "username": username,
+            "permissions": self.return_perms(permissions),
+        }
+
+    @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+    async def handle_get_user(self, request):
+        username = str(request["username"])
+
+        user = await self.db.lookup_user(username)
+        if user is None:
+            return None
+
+        return {
+            "username": user.username,
+            "permissions": self.return_perms(user.permissions),
+        }
+
+    @permissions(USER_ADMIN_PERM, allow_anon=False)
+    async def handle_get_all_users(self, request):
+        users = await self.db.get_all_users()
+        return {
+            "users": [
+                {
+                    "username": u.username,
+                    "permissions": self.return_perms(u.permissions),
+                }
+                for u in users
+            ]
+        }
+
+    @permissions(USER_ADMIN_PERM, allow_anon=False)
+    async def handle_new_user(self, request):
+        username = str(request["username"])
+        permissions = self.get_perm_arg(request["permissions"])
+
+        token = await new_token()
+
+        inserted = await self.db.new_user(
+            username,
+            permissions,
+            hash_token(TOKEN_ALGORITHM, new_salt(), token),
+        )
+        if not inserted:
+            raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")
+
+        return {
+            "username": username,
+            "permissions": self.return_perms(permissions),
+            "token": token,
+        }
+
+    @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+    async def handle_delete_user(self, request):
+        username = str(request["username"])
+
+        if not await self.db.delete_user(username):
+            self.raise_no_user_error(username)
+
+        return {"username": username}
+
+    @permissions(USER_ADMIN_PERM, allow_anon=False)
+    async def handle_become_user(self, request):
+        username = str(request["username"])
+
+        user = await self.db.lookup_user(username)
+        if user is None:
+            raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")
+
+        self.user = user
+
+        self.logger.info("Became user %s", username)
+
+        return {
+            "username": self.user.username,
+            "permissions": self.return_perms(self.user.permissions),
+        }
+
+
+class Server(bb.asyncrpc.AsyncServer):
+    def __init__(
+        self,
+        db_engine,
+        upstream=None,
+        read_only=False,
+        anon_perms=DEFAULT_ANON_PERMS,
+        admin_username=None,
+        admin_password=None,
+    ):
+        if upstream and read_only:
+            raise bb.asyncrpc.ServerError(
+                "Read-only hashserv cannot pull from an upstream server"
+            )
+
+        disallowed_perms = set(anon_perms) - set(
+            [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
+        )
+
+        if disallowed_perms:
+            raise bb.asyncrpc.ServerError(
+                f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
+            )
+
+        super().__init__(logger)
+
+        self.request_stats = Stats()
+        self.db_engine = db_engine
+        self.upstream = upstream
+        self.read_only = read_only
816 self.backfill_queue = None
817 self.anon_perms = set(anon_perms)
818 self.admin_username = admin_username
819 self.admin_password = admin_password
820
821 self.logger.info(
822 "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
823 )
824
825 def accept_client(self, socket):
826 return ServerClient(socket, self)
827
828 async def create_admin_user(self):
829 admin_permissions = (ALL_PERM,)
830 async with self.db_engine.connect(self.logger) as db:
831 added = await db.new_user(
832 self.admin_username,
833 admin_permissions,
834 hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
835 )
836 if added:
837 self.logger.info("Created admin user '%s'", self.admin_username)
838 else:
839 await db.set_user_perms(
840 self.admin_username,
841 admin_permissions,
842 )
843 await db.set_user_token(
844 self.admin_username,
845 hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
846 )
847 self.logger.info("Admin user '%s' updated", self.admin_username)
848
849 async def backfill_worker_task(self):
850 async with await create_async_client(
851 self.upstream
852 ) as client, self.db_engine.connect(self.logger) as db:
853 while True:
854 item = await self.backfill_queue.get()
855 if item is None:
856 self.backfill_queue.task_done()
857 break
858
859 method, taskhash = item
860 d = await client.get_taskhash(method, taskhash)
861 if d is not None:
862 await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
863 self.backfill_queue.task_done()
864
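The backfill worker above drains a shared asyncio.Queue of (method, taskhash) pairs and treats None as a shutdown sentinel (see Server.stop further down). A minimal sketch of that queue protocol, with a print standing in for the upstream client.get_taskhash() fetch; all names here are illustrative:

    import asyncio

    async def demo():
        queue = asyncio.Queue()
        # Producer side: request one backfill, then signal shutdown.
        await queue.put(("TestMethod", "35788efcb8dfb0a02659d81cf2bfd695fb30faf9"))
        await queue.put(None)

        # Consumer side: mirrors the loop in backfill_worker_task().
        while True:
            item = await queue.get()
            if item is None:
                queue.task_done()
                break
            method, taskhash = item
            print("would fetch", method, taskhash)  # stands in for the upstream lookup
            queue.task_done()

    asyncio.run(demo())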
865 def start(self):
866 tasks = super().start()
867 if self.upstream:
868 self.backfill_queue = asyncio.Queue()
869 tasks += [self.backfill_worker_task()]
870
871 self.loop.run_until_complete(self.db_engine.create())
872
873 if self.admin_username:
874 self.loop.run_until_complete(self.create_admin_user())
875
876 return tasks
877
878 async def stop(self):
879 if self.backfill_queue is not None:
880 await self.backfill_queue.put(None)
881 await super().stop()
diff --git a/bitbake/lib/hashserv/sqlalchemy.py b/bitbake/lib/hashserv/sqlalchemy.py
new file mode 100644
index 0000000000..f7b0226a7a
--- /dev/null
+++ b/bitbake/lib/hashserv/sqlalchemy.py
@@ -0,0 +1,598 @@
1#! /usr/bin/env python3
2#
3# Copyright (C) 2023 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7
8import logging
9from datetime import datetime
10from . import User
11
12from sqlalchemy.ext.asyncio import create_async_engine
13from sqlalchemy.pool import NullPool
14from sqlalchemy import (
15 MetaData,
16 Column,
17 Table,
18 Text,
19 Integer,
20 UniqueConstraint,
21 DateTime,
22 Index,
23 select,
24 insert,
25 exists,
26 literal,
27 and_,
28 delete,
29 update,
30 func,
31 inspect,
32)
33import sqlalchemy.engine
34from sqlalchemy.orm import declarative_base
35from sqlalchemy.exc import IntegrityError
36from sqlalchemy.dialects.postgresql import insert as postgres_insert
37
38Base = declarative_base()
39
40
41class UnihashesV3(Base):
42 __tablename__ = "unihashes_v3"
43 id = Column(Integer, primary_key=True, autoincrement=True)
44 method = Column(Text, nullable=False)
45 taskhash = Column(Text, nullable=False)
46 unihash = Column(Text, nullable=False)
47 gc_mark = Column(Text, nullable=False)
48
49 __table_args__ = (
50 UniqueConstraint("method", "taskhash"),
51 Index("taskhash_lookup_v4", "method", "taskhash"),
52 Index("unihash_lookup_v1", "unihash"),
53 )
54
55
56class OuthashesV2(Base):
57 __tablename__ = "outhashes_v2"
58 id = Column(Integer, primary_key=True, autoincrement=True)
59 method = Column(Text, nullable=False)
60 taskhash = Column(Text, nullable=False)
61 outhash = Column(Text, nullable=False)
62 created = Column(DateTime)
63 owner = Column(Text)
64 PN = Column(Text)
65 PV = Column(Text)
66 PR = Column(Text)
67 task = Column(Text)
68 outhash_siginfo = Column(Text)
69
70 __table_args__ = (
71 UniqueConstraint("method", "taskhash", "outhash"),
72 Index("outhash_lookup_v3", "method", "outhash"),
73 )
74
75
76class Users(Base):
77 __tablename__ = "users"
78 id = Column(Integer, primary_key=True, autoincrement=True)
79 username = Column(Text, nullable=False)
80 token = Column(Text, nullable=False)
81 permissions = Column(Text)
82
83 __table_args__ = (UniqueConstraint("username"),)
84
85
86class Config(Base):
87 __tablename__ = "config"
88 id = Column(Integer, primary_key=True, autoincrement=True)
89 name = Column(Text, nullable=False)
90 value = Column(Text)
91 __table_args__ = (
92 UniqueConstraint("name"),
93 Index("config_lookup", "name"),
94 )
95
96
97#
98# Old table versions
99#
100DeprecatedBase = declarative_base()
101
102
103class UnihashesV2(DeprecatedBase):
104 __tablename__ = "unihashes_v2"
105 id = Column(Integer, primary_key=True, autoincrement=True)
106 method = Column(Text, nullable=False)
107 taskhash = Column(Text, nullable=False)
108 unihash = Column(Text, nullable=False)
109
110 __table_args__ = (
111 UniqueConstraint("method", "taskhash"),
112 Index("taskhash_lookup_v3", "method", "taskhash"),
113 )
114
115
116class DatabaseEngine(object):
117 def __init__(self, url, username=None, password=None):
118 self.logger = logging.getLogger("hashserv.sqlalchemy")
119 self.url = sqlalchemy.engine.make_url(url)
120
121 if username is not None:
122 self.url = self.url.set(username=username)
123
124 if password is not None:
125 self.url = self.url.set(password=password)
126
127 async def create(self):
128 def check_table_exists(conn, name):
129 return inspect(conn).has_table(name)
130
131 self.logger.info("Using database %s", self.url)
132 if self.url.drivername == 'postgresql+psycopg':
133 # Psycopg 3 (psycopg) driver can handle async connection pooling
134 self.engine = create_async_engine(self.url, max_overflow=-1)
135 else:
136 self.engine = create_async_engine(self.url, poolclass=NullPool)
137
138 async with self.engine.begin() as conn:
139 # Create tables
140 self.logger.info("Creating tables...")
141 await conn.run_sync(Base.metadata.create_all)
142
143 if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__):
144 self.logger.info("Upgrading Unihashes V2 -> V3...")
145 statement = insert(UnihashesV3).from_select(
146 ["id", "method", "unihash", "taskhash", "gc_mark"],
147 select(
148 UnihashesV2.id,
149 UnihashesV2.method,
150 UnihashesV2.unihash,
151 UnihashesV2.taskhash,
152 literal("").label("gc_mark"),
153 ),
154 )
155 self.logger.debug("%s", statement)
156 await conn.execute(statement)
157
158 await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__])
159 self.logger.info("Upgrade complete")
160
161 def connect(self, logger):
162 return Database(self.engine, logger)
163
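A hedged wiring sketch (not part of the patch) showing how this DatabaseEngine is meant to be used: construct it with a SQLAlchemy URL, run create() once to build and migrate the schema, then open per-request connections via connect(). The aiosqlite URL and logger name are assumptions for illustration, and the aiosqlite driver must be installed:

    import asyncio
    import logging

    async def main():
        engine = DatabaseEngine("sqlite+aiosqlite:///hashserv.db")
        await engine.create()  # creates tables, upgrades unihashes_v2 -> v3 if present
        async with engine.connect(logging.getLogger("example")) as db:
            await db.insert_unihash("TestMethod", "aa" * 20, "bb" * 20)

    asyncio.run(main())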
164
165def map_row(row):
166 if row is None:
167 return None
168 return dict(**row._mapping)
169
170
171def map_user(row):
172 if row is None:
173 return None
174 return User(
175 username=row.username,
176 permissions=set(row.permissions.split()),
177 )
178
179
180def _make_condition_statement(table, condition):
181 where = {}
182 for c in table.__table__.columns:
183 if c.key in condition and condition[c.key] is not None:
184 where[c] = condition[c.key]
185
186 return [(k == v) for k, v in where.items()]
187
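For reference, _make_condition_statement() turns a condition dictionary into a list of SQLAlchemy boolean expressions, silently dropping None values and keys that are not columns of the table. A small illustrative use (the condition values are made up):

    condition = {"method": "TestMethod", "unihash": None, "bogus": 1}
    where = _make_condition_statement(UnihashesV3, condition)
    # where == [UnihashesV3.method == "TestMethod"]; "unihash" (None) and
    # "bogus" (not a column) are ignored.
    statement = delete(UnihashesV3).where(*where)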
188
189class Database(object):
190 def __init__(self, engine, logger):
191 self.engine = engine
192 self.db = None
193 self.logger = logger
194
195 async def __aenter__(self):
196 self.db = await self.engine.connect()
197 return self
198
199 async def __aexit__(self, exc_type, exc_value, traceback):
200 await self.close()
201
202 async def close(self):
203 await self.db.close()
204 self.db = None
205
206 async def _execute(self, statement):
207 self.logger.debug("%s", statement)
208 return await self.db.execute(statement)
209
210 async def _set_config(self, name, value):
211 while True:
212 result = await self._execute(
213 update(Config).where(Config.name == name).values(value=value)
214 )
215
216 if result.rowcount == 0:
217 self.logger.debug("Config '%s' not found. Adding it", name)
218 try:
219 await self._execute(insert(Config).values(name=name, value=value))
220 except IntegrityError:
221 # Race. Try again
222 continue
223
224 break
225
226 def _get_config_subquery(self, name, default=None):
227 if default is not None:
228 return func.coalesce(
229 select(Config.value).where(Config.name == name).scalar_subquery(),
230 default,
231 )
232 return select(Config.value).where(Config.name == name).scalar_subquery()
233
234 async def _get_config(self, name):
235 result = await self._execute(select(Config.value).where(Config.name == name))
236 row = result.first()
237 if row is None:
238 return None
239 return row.value
240
241 async def get_unihash_by_taskhash_full(self, method, taskhash):
242 async with self.db.begin():
243 result = await self._execute(
244 select(
245 OuthashesV2,
246 UnihashesV3.unihash.label("unihash"),
247 )
248 .join(
249 UnihashesV3,
250 and_(
251 UnihashesV3.method == OuthashesV2.method,
252 UnihashesV3.taskhash == OuthashesV2.taskhash,
253 ),
254 )
255 .where(
256 OuthashesV2.method == method,
257 OuthashesV2.taskhash == taskhash,
258 )
259 .order_by(
260 OuthashesV2.created.asc(),
261 )
262 .limit(1)
263 )
264 return map_row(result.first())
265
266 async def get_unihash_by_outhash(self, method, outhash):
267 async with self.db.begin():
268 result = await self._execute(
269 select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
270 .join(
271 UnihashesV3,
272 and_(
273 UnihashesV3.method == OuthashesV2.method,
274 UnihashesV3.taskhash == OuthashesV2.taskhash,
275 ),
276 )
277 .where(
278 OuthashesV2.method == method,
279 OuthashesV2.outhash == outhash,
280 )
281 .order_by(
282 OuthashesV2.created.asc(),
283 )
284 .limit(1)
285 )
286 return map_row(result.first())
287
288 async def unihash_exists(self, unihash):
289 async with self.db.begin():
290 result = await self._execute(
291 select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1)
292 )
293
294 return result.first() is not None
295
296 async def get_outhash(self, method, outhash):
297 async with self.db.begin():
298 result = await self._execute(
299 select(OuthashesV2)
300 .where(
301 OuthashesV2.method == method,
302 OuthashesV2.outhash == outhash,
303 )
304 .order_by(
305 OuthashesV2.created.asc(),
306 )
307 .limit(1)
308 )
309 return map_row(result.first())
310
311 async def get_equivalent_for_outhash(self, method, outhash, taskhash):
312 async with self.db.begin():
313 result = await self._execute(
314 select(
315 OuthashesV2.taskhash.label("taskhash"),
316 UnihashesV3.unihash.label("unihash"),
317 )
318 .join(
319 UnihashesV3,
320 and_(
321 UnihashesV3.method == OuthashesV2.method,
322 UnihashesV3.taskhash == OuthashesV2.taskhash,
323 ),
324 )
325 .where(
326 OuthashesV2.method == method,
327 OuthashesV2.outhash == outhash,
328 OuthashesV2.taskhash != taskhash,
329 )
330 .order_by(
331 OuthashesV2.created.asc(),
332 )
333 .limit(1)
334 )
335 return map_row(result.first())
336
337 async def get_equivalent(self, method, taskhash):
338 async with self.db.begin():
339 result = await self._execute(
340 select(
341 UnihashesV3.unihash,
342 UnihashesV3.method,
343 UnihashesV3.taskhash,
344 ).where(
345 UnihashesV3.method == method,
346 UnihashesV3.taskhash == taskhash,
347 )
348 )
349 return map_row(result.first())
350
351 async def remove(self, condition):
352 async def do_remove(table):
353 where = _make_condition_statement(table, condition)
354 if where:
355 async with self.db.begin():
356 result = await self._execute(delete(table).where(*where))
357 return result.rowcount
358
359 return 0
360
361 count = 0
362 count += await do_remove(UnihashesV3)
363 count += await do_remove(OuthashesV2)
364
365 return count
366
367 async def get_current_gc_mark(self):
368 async with self.db.begin():
369 return await self._get_config("gc-mark")
370
371 async def gc_status(self):
372 async with self.db.begin():
373 gc_mark_subquery = self._get_config_subquery("gc-mark", "")
374
375 result = await self._execute(
376 select(func.count())
377 .select_from(UnihashesV3)
378 .where(UnihashesV3.gc_mark == gc_mark_subquery)
379 )
380 keep_rows = result.scalar()
381
382 result = await self._execute(
383 select(func.count())
384 .select_from(UnihashesV3)
385 .where(UnihashesV3.gc_mark != gc_mark_subquery)
386 )
387 remove_rows = result.scalar()
388
389 return (keep_rows, remove_rows, await self._get_config("gc-mark"))
390
391 async def gc_mark(self, mark, condition):
392 async with self.db.begin():
393 await self._set_config("gc-mark", mark)
394
395 where = _make_condition_statement(UnihashesV3, condition)
396 if not where:
397 return 0
398
399 result = await self._execute(
400 update(UnihashesV3)
401 .values(gc_mark=self._get_config_subquery("gc-mark", ""))
402 .where(*where)
403 )
404 return result.rowcount
405
406 async def gc_sweep(self):
407 async with self.db.begin():
408 result = await self._execute(
409 delete(UnihashesV3).where(
410 # A sneaky conditional that provides some protection against
411 # errant use: if the config mark is NULL, this will not
412 # match any rows because no default is specified in the
413 # select statement
414 UnihashesV3.gc_mark
415 != self._get_config_subquery("gc-mark")
416 )
417 )
418 await self._set_config("gc-mark", None)
419
420 return result.rowcount
421
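Taken together, gc_mark(), gc_status() and gc_sweep() implement a mark-and-sweep collection: rows matching the condition are stamped with the current mark, and the sweep deletes everything else before clearing the mark. A hypothetical sequence (the mark name and condition are examples, not from the patch):

    async def collect(db):
        # Stamp the rows to keep with a fresh mark value.
        marked = await db.gc_mark("mark-1", {"method": "TestMethod"})
        keep_rows, remove_rows, mark = await db.gc_status()
        # Delete every row not carrying the current mark, then clear the
        # mark so an accidental second sweep deletes nothing.
        removed = await db.gc_sweep()
        return marked, removed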
422 async def clean_unused(self, oldest):
423 async with self.db.begin():
424 result = await self._execute(
425 delete(OuthashesV2).where(
426 OuthashesV2.created < oldest,
427 ~(
428 select(UnihashesV3.id)
429 .where(
430 UnihashesV3.method == OuthashesV2.method,
431 UnihashesV3.taskhash == OuthashesV2.taskhash,
432 )
433 .limit(1)
434 .exists()
435 ),
436 )
437 )
438 return result.rowcount
439
440 async def insert_unihash(self, method, taskhash, unihash):
441 # Postgres specific ignore on insert duplicate
442 if self.engine.name == "postgresql":
443 statement = (
444 postgres_insert(UnihashesV3)
445 .values(
446 method=method,
447 taskhash=taskhash,
448 unihash=unihash,
449 gc_mark=self._get_config_subquery("gc-mark", ""),
450 )
451 .on_conflict_do_nothing(index_elements=("method", "taskhash"))
452 )
453 else:
454 statement = insert(UnihashesV3).values(
455 method=method,
456 taskhash=taskhash,
457 unihash=unihash,
458 gc_mark=self._get_config_subquery("gc-mark", ""),
459 )
460
461 try:
462 async with self.db.begin():
463 result = await self._execute(statement)
464 return result.rowcount != 0
465 except IntegrityError:
466 self.logger.debug(
467 "%s, %s, %s already in unihash database", method, taskhash, unihash
468 )
469 return False
470
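Illustrative behaviour of the duplicate handling above (an assumption spelled out here, not asserted by the patch): on PostgreSQL the ON CONFLICT DO NOTHING clause makes the duplicate insert report rowcount 0, while on other dialects the plain INSERT raises IntegrityError, which is caught; either way the caller simply sees False:

    async def demo(db):
        first = await db.insert_unihash("TestMethod", "t" * 40, "u" * 40)
        again = await db.insert_unihash("TestMethod", "t" * 40, "u" * 40)
        assert first is True    # row inserted
        assert again is False   # duplicate (method, taskhash) ignored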
471 async def insert_outhash(self, data):
472 outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)
473
474 data = {k: v for k, v in data.items() if k in outhash_columns}
475
476 if "created" in data and not isinstance(data["created"], datetime):
477 data["created"] = datetime.fromisoformat(data["created"])
478
479 # Postgres specific ignore on insert duplicate
480 if self.engine.name == "postgresql":
481 statement = (
482 postgres_insert(OuthashesV2)
483 .values(**data)
484 .on_conflict_do_nothing(
485 index_elements=("method", "taskhash", "outhash")
486 )
487 )
488 else:
489 statement = insert(OuthashesV2).values(**data)
490
491 try:
492 async with self.db.begin():
493 result = await self._execute(statement)
494 return result.rowcount != 0
495 except IntegrityError:
496 self.logger.debug(
497 "%s, %s already in outhash database", data["method"], data["outhash"]
498 )
499 return False
500
501 async def _get_user(self, username):
502 async with self.db.begin():
503 result = await self._execute(
504 select(
505 Users.username,
506 Users.permissions,
507 Users.token,
508 ).where(
509 Users.username == username,
510 )
511 )
512 return result.first()
513
514 async def lookup_user_token(self, username):
515 row = await self._get_user(username)
516 if not row:
517 return None, None
518 return map_user(row), row.token
519
520 async def lookup_user(self, username):
521 return map_user(await self._get_user(username))
522
523 async def set_user_token(self, username, token):
524 async with self.db.begin():
525 result = await self._execute(
526 update(Users)
527 .where(
528 Users.username == username,
529 )
530 .values(
531 token=token,
532 )
533 )
534 return result.rowcount != 0
535
536 async def set_user_perms(self, username, permissions):
537 async with self.db.begin():
538 result = await self._execute(
539 update(Users)
540 .where(Users.username == username)
541 .values(permissions=" ".join(permissions))
542 )
543 return result.rowcount != 0
544
545 async def get_all_users(self):
546 async with self.db.begin():
547 result = await self._execute(
548 select(
549 Users.username,
550 Users.permissions,
551 )
552 )
553 return [map_user(row) for row in result]
554
555 async def new_user(self, username, permissions, token):
556 try:
557 async with self.db.begin():
558 await self._execute(
559 insert(Users).values(
560 username=username,
561 permissions=" ".join(permissions),
562 token=token,
563 )
564 )
565 return True
566 except IntegrityError as e:
567 self.logger.debug("Cannot create new user %s: %s", username, e)
568 return False
569
570 async def delete_user(self, username):
571 async with self.db.begin():
572 result = await self._execute(
573 delete(Users).where(Users.username == username)
574 )
575 return result.rowcount != 0
576
577 async def get_usage(self):
578 usage = {}
579 async with self.db.begin() as session:
580 for name, table in Base.metadata.tables.items():
581 result = await self._execute(
582 statement=select(func.count()).select_from(table)
583 )
584 usage[name] = {
585 "rows": result.scalar(),
586 }
587
588 return usage
589
590 async def get_query_columns(self):
591 columns = set()
592 for table in (UnihashesV3, OuthashesV2):
593 for c in table.__table__.columns:
594 if not isinstance(c.type, Text):
595 continue
596 columns.add(c.key)
597
598 return list(columns)
diff --git a/bitbake/lib/hashserv/sqlite.py b/bitbake/lib/hashserv/sqlite.py
new file mode 100644
index 0000000000..da2e844a03
--- /dev/null
+++ b/bitbake/lib/hashserv/sqlite.py
@@ -0,0 +1,562 @@
1#! /usr/bin/env python3
2#
3# Copyright (C) 2023 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7import sqlite3
8import logging
9from contextlib import closing
10from . import User
11
12logger = logging.getLogger("hashserv.sqlite")
13
14UNIHASH_TABLE_DEFINITION = (
15 ("method", "TEXT NOT NULL", "UNIQUE"),
16 ("taskhash", "TEXT NOT NULL", "UNIQUE"),
17 ("unihash", "TEXT NOT NULL", ""),
18 ("gc_mark", "TEXT NOT NULL", ""),
19)
20
21UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
22
23OUTHASH_TABLE_DEFINITION = (
24 ("method", "TEXT NOT NULL", "UNIQUE"),
25 ("taskhash", "TEXT NOT NULL", "UNIQUE"),
26 ("outhash", "TEXT NOT NULL", "UNIQUE"),
27 ("created", "DATETIME", ""),
28 # Optional fields
29 ("owner", "TEXT", ""),
30 ("PN", "TEXT", ""),
31 ("PV", "TEXT", ""),
32 ("PR", "TEXT", ""),
33 ("task", "TEXT", ""),
34 ("outhash_siginfo", "TEXT", ""),
35)
36
37OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
38
39USERS_TABLE_DEFINITION = (
40 ("username", "TEXT NOT NULL", "UNIQUE"),
41 ("token", "TEXT NOT NULL", ""),
42 ("permissions", "TEXT NOT NULL", ""),
43)
44
45USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION)
46
47
48CONFIG_TABLE_DEFINITION = (
49 ("name", "TEXT NOT NULL", "UNIQUE"),
50 ("value", "TEXT", ""),
51)
52
53CONFIG_TABLE_COLUMNS = tuple(name for name, _, _ in CONFIG_TABLE_DEFINITION)
54
55
56def _make_table(cursor, name, definition):
57 cursor.execute(
58 """
59 CREATE TABLE IF NOT EXISTS {name} (
60 id INTEGER PRIMARY KEY AUTOINCREMENT,
61 {fields}
62 UNIQUE({unique})
63 )
64 """.format(
65 name=name,
66 fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
67 unique=", ".join(
68 name for name, _, flags in definition if "UNIQUE" in flags
69 ),
70 )
71 )
72
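To make the string templating above concrete, this sketch prints the schema _make_table() generates for the config table definition; it assumes _make_table and CONFIG_TABLE_DEFINITION have been imported from this module:

    import sqlite3
    from contextlib import closing

    db = sqlite3.connect(":memory:")
    with closing(db.cursor()) as cursor:
        _make_table(cursor, "config", CONFIG_TABLE_DEFINITION)
        cursor.execute("SELECT sql FROM sqlite_master WHERE name = 'config'")
        print(cursor.fetchone()[0])  # the CREATE TABLE statement built above:
                                     # columns "name TEXT NOT NULL, value TEXT"
                                     # plus "UNIQUE(name)"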
73
74def map_user(row):
75 if row is None:
76 return None
77 return User(
78 username=row["username"],
79 permissions=set(row["permissions"].split()),
80 )
81
82
83def _make_condition_statement(columns, condition):
84 where = {}
85 for c in columns:
86 if c in condition and condition[c] is not None:
87 where[c] = condition[c]
88
89 return where, " AND ".join("%s=:%s" % (k, k) for k in where.keys())
90
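As with the SQLAlchemy variant, this helper builds a parameterised WHERE clause from a condition dictionary, dropping None values and unknown keys. An illustrative call (the values are made up):

    where, clause = _make_condition_statement(
        UNIHASH_TABLE_COLUMNS, {"method": "TestMethod", "unihash": None}
    )
    assert where == {"method": "TestMethod"}
    assert clause == "method=:method"
    query = f"DELETE FROM unihashes_v3 WHERE {clause}"  # executed with `where` as parameters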
91
92def _get_sqlite_version(cursor):
93 cursor.execute("SELECT sqlite_version()")
94
95 version = []
96 for v in cursor.fetchone()[0].split("."):
97 try:
98 version.append(int(v))
99 except ValueError:
100 version.append(v)
101
102 return tuple(version)
103
104
105def _schema_table_name(version):
106 if version >= (3, 33):
107 return "sqlite_schema"
108
109 return "sqlite_master"
110
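This helper pair exists because SQLite renamed its schema table from sqlite_master to sqlite_schema in version 3.33; older names remain valid as aliases, but the code queries whichever name matches the detected version. For instance:

    assert _schema_table_name((3, 35, 5)) == "sqlite_schema"
    assert _schema_table_name((3, 22, 0)) == "sqlite_master"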
111
112class DatabaseEngine(object):
113 def __init__(self, dbname, sync):
114 self.dbname = dbname
115 self.logger = logger
116 self.sync = sync
117
118 async def create(self):
119 db = sqlite3.connect(self.dbname)
120 db.row_factory = sqlite3.Row
121
122 with closing(db.cursor()) as cursor:
123 _make_table(cursor, "unihashes_v3", UNIHASH_TABLE_DEFINITION)
124 _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
125 _make_table(cursor, "users", USERS_TABLE_DEFINITION)
126 _make_table(cursor, "config", CONFIG_TABLE_DEFINITION)
127
128 cursor.execute("PRAGMA journal_mode = WAL")
129 cursor.execute(
130 "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF")
131 )
132
133 # Drop old indexes
134 cursor.execute("DROP INDEX IF EXISTS taskhash_lookup")
135 cursor.execute("DROP INDEX IF EXISTS outhash_lookup")
136 cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2")
137 cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2")
138 cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v3")
139
140 # TODO: Upgrade from tasks_v2?
141 cursor.execute("DROP TABLE IF EXISTS tasks_v2")
142
143 # Create new indexes
144 cursor.execute(
145 "CREATE INDEX IF NOT EXISTS taskhash_lookup_v4 ON unihashes_v3 (method, taskhash)"
146 )
147 cursor.execute(
148 "CREATE INDEX IF NOT EXISTS unihash_lookup_v1 ON unihashes_v3 (unihash)"
149 )
150 cursor.execute(
151 "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)"
152 )
153 cursor.execute("CREATE INDEX IF NOT EXISTS config_lookup ON config (name)")
154
155 sqlite_version = _get_sqlite_version(cursor)
156
157 cursor.execute(
158 f"""
159 SELECT name FROM {_schema_table_name(sqlite_version)} WHERE type = 'table' AND name = 'unihashes_v2'
160 """
161 )
162 if cursor.fetchone():
163 self.logger.info("Upgrading Unihashes V2 -> V3...")
164 cursor.execute(
165 """
166 INSERT INTO unihashes_v3 (id, method, unihash, taskhash, gc_mark)
167 SELECT id, method, unihash, taskhash, '' FROM unihashes_v2
168 """
169 )
170 cursor.execute("DROP TABLE unihashes_v2")
171 db.commit()
172 self.logger.info("Upgrade complete")
173
174 def connect(self, logger):
175 return Database(logger, self.dbname, self.sync)
176
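A hedged usage sketch mirroring the SQLAlchemy engine: create() builds the schema (and migrates unihashes_v2) once, after which connect() hands out per-connection Database objects. The filename and logger name are examples:

    import asyncio
    import logging

    async def main():
        engine = DatabaseEngine("hashserv.db", sync=True)
        await engine.create()
        async with engine.connect(logging.getLogger("example")) as db:
            print(await db.get_usage())  # per-table row counts

    asyncio.run(main())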
177
178class Database(object):
179 def __init__(self, logger, dbname, sync):
180 self.dbname = dbname
181 self.logger = logger
182
183 self.db = sqlite3.connect(self.dbname)
184 self.db.row_factory = sqlite3.Row
185
186 with closing(self.db.cursor()) as cursor:
187 cursor.execute("PRAGMA journal_mode = WAL")
188 cursor.execute(
189 "PRAGMA synchronous = %s" % ("NORMAL" if sync else "OFF")
190 )
191
192 self.sqlite_version = _get_sqlite_version(cursor)
193
194 async def __aenter__(self):
195 return self
196
197 async def __aexit__(self, exc_type, exc_value, traceback):
198 await self.close()
199
200 async def _set_config(self, cursor, name, value):
201 cursor.execute(
202 """
203 INSERT OR REPLACE INTO config (id, name, value) VALUES
204 ((SELECT id FROM config WHERE name=:name), :name, :value)
205 """,
206 {
207 "name": name,
208 "value": value,
209 },
210 )
211
212 async def _get_config(self, cursor, name):
213 cursor.execute(
214 "SELECT value FROM config WHERE name=:name",
215 {
216 "name": name,
217 },
218 )
219 row = cursor.fetchone()
220 if row is None:
221 return None
222 return row["value"]
223
224 async def close(self):
225 self.db.close()
226
227 async def get_unihash_by_taskhash_full(self, method, taskhash):
228 with closing(self.db.cursor()) as cursor:
229 cursor.execute(
230 """
231 SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2
232 INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
233 WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
234 ORDER BY outhashes_v2.created ASC
235 LIMIT 1
236 """,
237 {
238 "method": method,
239 "taskhash": taskhash,
240 },
241 )
242 return cursor.fetchone()
243
244 async def get_unihash_by_outhash(self, method, outhash):
245 with closing(self.db.cursor()) as cursor:
246 cursor.execute(
247 """
248 SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2
249 INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
250 WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
251 ORDER BY outhashes_v2.created ASC
252 LIMIT 1
253 """,
254 {
255 "method": method,
256 "outhash": outhash,
257 },
258 )
259 return cursor.fetchone()
260
261 async def unihash_exists(self, unihash):
262 with closing(self.db.cursor()) as cursor:
263 cursor.execute(
264 """
265 SELECT * FROM unihashes_v3 WHERE unihash=:unihash
266 LIMIT 1
267 """,
268 {
269 "unihash": unihash,
270 },
271 )
272 return cursor.fetchone() is not None
273
274 async def get_outhash(self, method, outhash):
275 with closing(self.db.cursor()) as cursor:
276 cursor.execute(
277 """
278 SELECT * FROM outhashes_v2
279 WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
280 ORDER BY outhashes_v2.created ASC
281 LIMIT 1
282 """,
283 {
284 "method": method,
285 "outhash": outhash,
286 },
287 )
288 return cursor.fetchone()
289
290 async def get_equivalent_for_outhash(self, method, outhash, taskhash):
291 with closing(self.db.cursor()) as cursor:
292 cursor.execute(
293 """
294 SELECT outhashes_v2.taskhash AS taskhash, unihashes_v3.unihash AS unihash FROM outhashes_v2
295 INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
296 -- Select any matching output hash except the one we just inserted
297 WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
298 -- Pick the oldest hash
299 ORDER BY outhashes_v2.created ASC
300 LIMIT 1
301 """,
302 {
303 "method": method,
304 "outhash": outhash,
305 "taskhash": taskhash,
306 },
307 )
308 return cursor.fetchone()
309
310 async def get_equivalent(self, method, taskhash):
311 with closing(self.db.cursor()) as cursor:
312 cursor.execute(
313 "SELECT taskhash, method, unihash FROM unihashes_v3 WHERE method=:method AND taskhash=:taskhash",
314 {
315 "method": method,
316 "taskhash": taskhash,
317 },
318 )
319 return cursor.fetchone()
320
321 async def remove(self, condition):
322 def do_remove(columns, table_name, cursor):
323 where, clause = _make_condition_statement(columns, condition)
324 if where:
325 query = f"DELETE FROM {table_name} WHERE {clause}"
326 cursor.execute(query, where)
327 return cursor.rowcount
328
329 return 0
330
331 count = 0
332 with closing(self.db.cursor()) as cursor:
333 count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
334 count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v3", cursor)
335 self.db.commit()
336
337 return count
338
339 async def get_current_gc_mark(self):
340 with closing(self.db.cursor()) as cursor:
341 return await self._get_config(cursor, "gc-mark")
342
343 async def gc_status(self):
344 with closing(self.db.cursor()) as cursor:
345 cursor.execute(
346 """
347 SELECT COUNT() FROM unihashes_v3 WHERE
348 gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
349 """
350 )
351 keep_rows = cursor.fetchone()[0]
352
353 cursor.execute(
354 """
355 SELECT COUNT() FROM unihashes_v3 WHERE
356 gc_mark!=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
357 """
358 )
359 remove_rows = cursor.fetchone()[0]
360
361 current_mark = await self._get_config(cursor, "gc-mark")
362
363 return (keep_rows, remove_rows, current_mark)
364
365 async def gc_mark(self, mark, condition):
366 with closing(self.db.cursor()) as cursor:
367 await self._set_config(cursor, "gc-mark", mark)
368
369 where, clause = _make_condition_statement(UNIHASH_TABLE_COLUMNS, condition)
370
371 new_rows = 0
372 if where:
373 cursor.execute(
374 f"""
375 UPDATE unihashes_v3 SET
376 gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
377 WHERE {clause}
378 """,
379 where,
380 )
381 new_rows = cursor.rowcount
382
383 self.db.commit()
384 return new_rows
385
386 async def gc_sweep(self):
387 with closing(self.db.cursor()) as cursor:
388 # NOTE: COALESCE is not used in this query so that if the current
389 # mark is NULL, nothing will happen
390 cursor.execute(
391 """
392 DELETE FROM unihashes_v3 WHERE
393 gc_mark!=(SELECT value FROM config WHERE name='gc-mark')
394 """
395 )
396 count = cursor.rowcount
397 await self._set_config(cursor, "gc-mark", None)
398
399 self.db.commit()
400 return count
401
402 async def clean_unused(self, oldest):
403 with closing(self.db.cursor()) as cursor:
404 cursor.execute(
405 """
406 DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
407 SELECT unihashes_v3.id FROM unihashes_v3 WHERE unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash LIMIT 1
408 )
409 """,
410 {
411 "oldest": oldest,
412 },
413 )
414 self.db.commit()
415 return cursor.rowcount
416
417 async def insert_unihash(self, method, taskhash, unihash):
418 with closing(self.db.cursor()) as cursor:
419 prevrowid = cursor.lastrowid
420 cursor.execute(
421 """
422 INSERT OR IGNORE INTO unihashes_v3 (method, taskhash, unihash, gc_mark) VALUES
423 (
424 :method,
425 :taskhash,
426 :unihash,
427 COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
428 )
429 """,
430 {
431 "method": method,
432 "taskhash": taskhash,
433 "unihash": unihash,
434 },
435 )
436 self.db.commit()
437 return cursor.lastrowid != prevrowid
438
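insert_unihash() detects whether the INSERT OR IGNORE actually inserted a row by comparing cursor.lastrowid before and after: SQLite leaves lastrowid untouched when the insert is ignored. A standalone demonstration of the idiom:

    import sqlite3
    from contextlib import closing

    db = sqlite3.connect(":memory:")
    db.execute("CREATE TABLE t (x TEXT UNIQUE)")
    with closing(db.cursor()) as cursor:
        prev = cursor.lastrowid
        cursor.execute("INSERT OR IGNORE INTO t (x) VALUES ('a')")
        print(cursor.lastrowid != prev)  # True: a row was inserted
        prev = cursor.lastrowid
        cursor.execute("INSERT OR IGNORE INTO t (x) VALUES ('a')")
        print(cursor.lastrowid != prev)  # False: the duplicate was ignored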
439 async def insert_outhash(self, data):
440 data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS}
441 keys = sorted(data.keys())
442 query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format(
443 fields=", ".join(keys),
444 values=", ".join(":" + k for k in keys),
445 )
446 with closing(self.db.cursor()) as cursor:
447 prevrowid = cursor.lastrowid
448 cursor.execute(query, data)
449 self.db.commit()
450 return cursor.lastrowid != prevrowid
451
452 def _get_user(self, username):
453 with closing(self.db.cursor()) as cursor:
454 cursor.execute(
455 """
456 SELECT username, permissions, token FROM users WHERE username=:username
457 """,
458 {
459 "username": username,
460 },
461 )
462 return cursor.fetchone()
463
464 async def lookup_user_token(self, username):
465 row = self._get_user(username)
466 if row is None:
467 return None, None
468 return map_user(row), row["token"]
469
470 async def lookup_user(self, username):
471 return map_user(self._get_user(username))
472
473 async def set_user_token(self, username, token):
474 with closing(self.db.cursor()) as cursor:
475 cursor.execute(
476 """
477 UPDATE users SET token=:token WHERE username=:username
478 """,
479 {
480 "username": username,
481 "token": token,
482 },
483 )
484 self.db.commit()
485 return cursor.rowcount != 0
486
487 async def set_user_perms(self, username, permissions):
488 with closing(self.db.cursor()) as cursor:
489 cursor.execute(
490 """
491 UPDATE users SET permissions=:permissions WHERE username=:username
492 """,
493 {
494 "username": username,
495 "permissions": " ".join(permissions),
496 },
497 )
498 self.db.commit()
499 return cursor.rowcount != 0
500
501 async def get_all_users(self):
502 with closing(self.db.cursor()) as cursor:
503 cursor.execute("SELECT username, permissions FROM users")
504 return [map_user(r) for r in cursor.fetchall()]
505
506 async def new_user(self, username, permissions, token):
507 with closing(self.db.cursor()) as cursor:
508 try:
509 cursor.execute(
510 """
511 INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions)
512 """,
513 {
514 "username": username,
515 "token": token,
516 "permissions": " ".join(permissions),
517 },
518 )
519 self.db.commit()
520 return True
521 except sqlite3.IntegrityError:
522 return False
523
524 async def delete_user(self, username):
525 with closing(self.db.cursor()) as cursor:
526 cursor.execute(
527 """
528 DELETE FROM users WHERE username=:username
529 """,
530 {
531 "username": username,
532 },
533 )
534 self.db.commit()
535 return cursor.rowcount != 0
536
537 async def get_usage(self):
538 usage = {}
539 with closing(self.db.cursor()) as cursor:
540 cursor.execute(
541 f"""
542 SELECT name FROM {_schema_table_name(self.sqlite_version)} WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
543 """
544 )
545 for row in cursor.fetchall():
546 cursor.execute(
547 """
548 SELECT COUNT() FROM %s
549 """
550 % row["name"],
551 )
552 usage[row["name"]] = {
553 "rows": cursor.fetchone()[0],
554 }
555 return usage
556
557 async def get_query_columns(self):
558 columns = set()
559 for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION:
560 if typ.startswith("TEXT"):
561 columns.add(name)
562 return list(columns)
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py
index 1a696481e3..13ccb20ebf 100644
--- a/bitbake/lib/hashserv/tests.py
+++ b/bitbake/lib/hashserv/tests.py
@@ -6,7 +6,8 @@
6# 6#
7 7
8from . import create_server, create_client 8from . import create_server, create_client
9from .client import HashConnectionError 9from .server import DEFAULT_ANON_PERMS, ALL_PERMISSIONS
10from bb.asyncrpc import InvokeError
10import hashlib 11import hashlib
11import logging 12import logging
12import multiprocessing 13import multiprocessing
@@ -16,72 +17,161 @@ import tempfile
16import threading 17import threading
17import unittest 18import unittest
18import socket 19import socket
19 20import time
20def _run_server(server, idx): 21import signal
21 # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', 22import subprocess
22 # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') 23import json
23 sys.stdout = open('bbhashserv-%d.log' % idx, 'w') 24import re
25from pathlib import Path
26
27
28THIS_DIR = Path(__file__).parent
29BIN_DIR = THIS_DIR.parent.parent / "bin"
30
31def server_prefunc(server, idx):
32 logging.basicConfig(level=logging.DEBUG, filename='bbhashserv-%d.log' % idx, filemode='w',
33 format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
34 server.logger.debug("Running server %d" % idx)
35 sys.stdout = open('bbhashserv-stdout-%d.log' % idx, 'w')
24 sys.stderr = sys.stdout 36 sys.stderr = sys.stdout
25 server.serve_forever()
26
27 37
28class HashEquivalenceTestSetup(object): 38class HashEquivalenceTestSetup(object):
29 METHOD = 'TestMethod' 39 METHOD = 'TestMethod'
30 40
31 server_index = 0 41 server_index = 0
42 client_index = 0
32 43
33 def start_server(self, dbpath=None, upstream=None, read_only=False): 44 def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc, anon_perms=DEFAULT_ANON_PERMS, admin_username=None, admin_password=None):
34 self.server_index += 1 45 self.server_index += 1
35 if dbpath is None: 46 if dbpath is None:
36 dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) 47 dbpath = self.make_dbpath()
48
49 def cleanup_server(server):
50 if server.process.exitcode is not None:
51 return
37 52
38 def cleanup_thread(thread): 53 server.process.terminate()
39 thread.terminate() 54 server.process.join()
40 thread.join()
41 55
42 server = create_server(self.get_server_addr(self.server_index), 56 server = create_server(self.get_server_addr(self.server_index),
43 dbpath, 57 dbpath,
44 upstream=upstream, 58 upstream=upstream,
45 read_only=read_only) 59 read_only=read_only,
60 anon_perms=anon_perms,
61 admin_username=admin_username,
62 admin_password=admin_password)
46 server.dbpath = dbpath 63 server.dbpath = dbpath
47 64
48 server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) 65 server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
49 server.thread.start() 66 self.addCleanup(cleanup_server, server)
50 self.addCleanup(cleanup_thread, server.thread) 67
68 return server
69
70 def make_dbpath(self):
71 return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
51 72
73 def start_client(self, server_address, username=None, password=None):
52 def cleanup_client(client): 74 def cleanup_client(client):
53 client.close() 75 client.close()
54 76
55 client = create_client(server.address) 77 client = create_client(server_address, username=username, password=password)
56 self.addCleanup(cleanup_client, client) 78 self.addCleanup(cleanup_client, client)
57 79
58 return (client, server) 80 return client
59 81
60 def setUp(self): 82 def start_test_server(self):
61 if sys.version_info < (3, 5, 0): 83 self.server = self.start_server()
62 self.skipTest('Python 3.5 or later required') 84 return self.server.address
85
86 def start_auth_server(self):
87 auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password")
88 self.auth_server_address = auth_server.address
89 self.admin_client = self.start_client(auth_server.address, username="admin", password="password")
90 return self.admin_client
63 91
92 def auth_client(self, user):
93 return self.start_client(self.auth_server_address, user["username"], user["token"])
94
95 def setUp(self):
64 self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv') 96 self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv')
65 self.addCleanup(self.temp_dir.cleanup) 97 self.addCleanup(self.temp_dir.cleanup)
66 98
67 (self.client, self.server) = self.start_server() 99 self.server_address = self.start_test_server()
100
101 self.client = self.start_client(self.server_address)
68 102
69 def assertClientGetHash(self, client, taskhash, unihash): 103 def assertClientGetHash(self, client, taskhash, unihash):
70 result = client.get_unihash(self.METHOD, taskhash) 104 result = client.get_unihash(self.METHOD, taskhash)
71 self.assertEqual(result, unihash) 105 self.assertEqual(result, unihash)
72 106
107 def assertUserPerms(self, user, permissions):
108 with self.auth_client(user) as client:
109 info = client.get_user()
110 self.assertEqual(info, {
111 "username": user["username"],
112 "permissions": permissions,
113 })
73 114
74class HashEquivalenceCommonTests(object): 115 def assertUserCanAuth(self, user):
75 def test_create_hash(self): 116 with self.start_client(self.auth_server_address) as client:
117 client.auth(user["username"], user["token"])
118
119 def assertUserCannotAuth(self, user):
120 with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
121 client.auth(user["username"], user["token"])
122
123 def create_test_hash(self, client):
76 # Simple test that hashes can be created 124 # Simple test that hashes can be created
77 taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' 125 taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
78 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' 126 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
79 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' 127 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
80 128
81 self.assertClientGetHash(self.client, taskhash, None) 129 self.assertClientGetHash(client, taskhash, None)
82 130
83 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) 131 result = client.report_unihash(taskhash, self.METHOD, outhash, unihash)
84 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') 132 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
133 return taskhash, outhash, unihash
134
135 def run_hashclient(self, args, **kwargs):
136 try:
137 p = subprocess.run(
138 [BIN_DIR / "bitbake-hashclient"] + args,
139 stdout=subprocess.PIPE,
140 stderr=subprocess.STDOUT,
141 encoding="utf-8",
142 **kwargs
143 )
144 except subprocess.CalledProcessError as e:
145 print(e.output)
146 raise e
147
148 print(p.stdout)
149 return p
150
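A hypothetical call illustrating how tests drive this helper; the "stats" subcommand and --address flag are assumptions about bitbake-hashclient's CLI, not guaranteed by this patch. Passing check=True through **kwargs is what arms the CalledProcessError branch above:

    # Inside a test method (sketch): run the client binary against the
    # test server and fail the test on a non-zero exit code.
    p = self.run_hashclient(["--address", self.server_address, "stats"], check=True)
    data = json.loads(p.stdout)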
151
152class HashEquivalenceCommonTests(object):
153 def auth_perms(self, *permissions):
154 self.client_index += 1
155 user = self.create_user(f"user-{self.client_index}", permissions)
156 return self.auth_client(user)
157
158 def create_user(self, username, permissions, *, client=None):
159 def remove_user(username):
160 try:
161 self.admin_client.delete_user(username)
162 except bb.asyncrpc.InvokeError:
163 pass
164
165 if client is None:
166 client = self.admin_client
167
168 user = client.new_user(username, permissions)
169 self.addCleanup(remove_user, username)
170
171 return user
172
173 def test_create_hash(self):
174 return self.create_test_hash(self.client)
85 175
86 def test_create_equivalent(self): 176 def test_create_equivalent(self):
87 # Tests that a second reported task with the same outhash will be 177 # Tests that a second reported task with the same outhash will be
@@ -123,6 +213,57 @@ class HashEquivalenceCommonTests(object):
123 213
124 self.assertClientGetHash(self.client, taskhash, unihash) 214 self.assertClientGetHash(self.client, taskhash, unihash)
125 215
216 def test_remove_taskhash(self):
217 taskhash, outhash, unihash = self.create_test_hash(self.client)
218 result = self.client.remove({"taskhash": taskhash})
219 self.assertGreater(result["count"], 0)
220 self.assertClientGetHash(self.client, taskhash, None)
221
222 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
223 self.assertIsNone(result_outhash)
224
225 def test_remove_unihash(self):
226 taskhash, outhash, unihash = self.create_test_hash(self.client)
227 result = self.client.remove({"unihash": unihash})
228 self.assertGreater(result["count"], 0)
229 self.assertClientGetHash(self.client, taskhash, None)
230
231 def test_remove_outhash(self):
232 taskhash, outhash, unihash = self.create_test_hash(self.client)
233 result = self.client.remove({"outhash": outhash})
234 self.assertGreater(result["count"], 0)
235
236 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
237 self.assertIsNone(result_outhash)
238
239 def test_remove_method(self):
240 taskhash, outhash, unihash = self.create_test_hash(self.client)
241 result = self.client.remove({"method": self.METHOD})
242 self.assertGreater(result["count"], 0)
243 self.assertClientGetHash(self.client, taskhash, None)
244
245 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
246 self.assertIsNone(result_outhash)
247
248 def test_clean_unused(self):
249 taskhash, outhash, unihash = self.create_test_hash(self.client)
250
251 # Clean the database, which should not remove anything because all hashes are in use
252 result = self.client.clean_unused(0)
253 self.assertEqual(result["count"], 0)
254 self.assertClientGetHash(self.client, taskhash, unihash)
255
256 # Remove the unihash. The row in the outhash table should still be present
257 self.client.remove({"unihash": unihash})
258 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
259 self.assertIsNotNone(result_outhash)
260
261 # Now clean with no minimum age which will remove the outhash
262 result = self.client.clean_unused(0)
263 self.assertEqual(result["count"], 1)
264 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
265 self.assertIsNone(result_outhash)
266
126 def test_huge_message(self): 267 def test_huge_message(self):
127 # Simple test that hashes can be created 268 # Simple test that hashes can be created
128 taskhash = 'c665584ee6817aa99edfc77a44dd853828279370' 269 taskhash = 'c665584ee6817aa99edfc77a44dd853828279370'
@@ -138,16 +279,21 @@ class HashEquivalenceCommonTests(object):
138 }) 279 })
139 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') 280 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
140 281
141 result = self.client.get_taskhash(self.METHOD, taskhash, True) 282 result_unihash = self.client.get_taskhash(self.METHOD, taskhash, True)
142 self.assertEqual(result['taskhash'], taskhash) 283 self.assertEqual(result_unihash['taskhash'], taskhash)
143 self.assertEqual(result['unihash'], unihash) 284 self.assertEqual(result_unihash['unihash'], unihash)
144 self.assertEqual(result['method'], self.METHOD) 285 self.assertEqual(result_unihash['method'], self.METHOD)
145 self.assertEqual(result['outhash'], outhash) 286
146 self.assertEqual(result['outhash_siginfo'], siginfo) 287 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
288 self.assertEqual(result_outhash['taskhash'], taskhash)
289 self.assertEqual(result_outhash['method'], self.METHOD)
290 self.assertEqual(result_outhash['unihash'], unihash)
291 self.assertEqual(result_outhash['outhash'], outhash)
292 self.assertEqual(result_outhash['outhash_siginfo'], siginfo)
147 293
148 def test_stress(self): 294 def test_stress(self):
149 def query_server(failures): 295 def query_server(failures):
150 client = Client(self.server.address) 296 client = Client(self.server_address)
151 try: 297 try:
152 for i in range(1000): 298 for i in range(1000):
153 taskhash = hashlib.sha256() 299 taskhash = hashlib.sha256()
@@ -186,8 +332,10 @@ class HashEquivalenceCommonTests(object):
186 # the side client. It also verifies that the results are pulled into 332 # the side client. It also verifies that the results are pulled into
187 # the downstream database by checking that the downstream and side servers 333 # the downstream database by checking that the downstream and side servers
188 # match after the downstream is done waiting for all backfill tasks 334 # match after the downstream is done waiting for all backfill tasks
189 (down_client, down_server) = self.start_server(upstream=self.server.address) 335 down_server = self.start_server(upstream=self.server_address)
190 (side_client, side_server) = self.start_server(dbpath=down_server.dbpath) 336 down_client = self.start_client(down_server.address)
337 side_server = self.start_server(dbpath=down_server.dbpath)
338 side_client = self.start_client(side_server.address)
191 339
192 def check_hash(taskhash, unihash, old_sidehash): 340 def check_hash(taskhash, unihash, old_sidehash):
193 nonlocal down_client 341 nonlocal down_client
@@ -258,15 +406,57 @@ class HashEquivalenceCommonTests(object):
258 result = down_client.report_unihash(taskhash6, self.METHOD, outhash5, unihash6) 406 result = down_client.report_unihash(taskhash6, self.METHOD, outhash5, unihash6)
259 self.assertEqual(result['unihash'], unihash5, 'Server failed to copy unihash from upstream') 407 self.assertEqual(result['unihash'], unihash5, 'Server failed to copy unihash from upstream')
260 408
409 # Tests read-through from the upstream server
410 taskhash7 = '9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74'
411 outhash7 = '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69'
412 unihash7 = '05d2a63c81e32f0a36542ca677e8ad852365c538'
413 self.client.report_unihash(taskhash7, self.METHOD, outhash7, unihash7)
414
415 result = down_client.get_taskhash(self.METHOD, taskhash7, True)
416 self.assertEqual(result['unihash'], unihash7, 'Server failed to copy unihash from upstream')
417 self.assertEqual(result['outhash'], outhash7, 'Server failed to copy outhash from upstream')
418 self.assertEqual(result['taskhash'], taskhash7, 'Server failed to copy taskhash from upstream')
419 self.assertEqual(result['method'], self.METHOD)
420
421 taskhash8 = '86978a4c8c71b9b487330b0152aade10c1ee58aa'
422 outhash8 = 'ca8c128e9d9e4a28ef24d0508aa20b5cf880604eacd8f65c0e366f7e0cc5fbcf'
423 unihash8 = 'd8bcf25369d40590ad7d08c84d538982f2023e01'
424 self.client.report_unihash(taskhash8, self.METHOD, outhash8, unihash8)
425
426 result = down_client.get_outhash(self.METHOD, outhash8, taskhash8)
427 self.assertEqual(result['unihash'], unihash8, 'Server failed to copy unihash from upstream')
428 self.assertEqual(result['outhash'], outhash8, 'Server failed to copy unihash from upstream')
429 self.assertEqual(result['taskhash'], taskhash8, 'Server failed to copy unihash from upstream')
430 self.assertEqual(result['method'], self.METHOD)
431
432 taskhash9 = 'ae6339531895ddf5b67e663e6a374ad8ec71d81c'
433 outhash9 = 'afc78172c81880ae10a1fec994b5b4ee33d196a001a1b66212a15ebe573e00b5'
434 unihash9 = '6662e699d6e3d894b24408ff9a4031ef9b038ee8'
435 self.client.report_unihash(taskhash9, self.METHOD, outhash9, unihash9)
436
437 result = down_client.get_taskhash(self.METHOD, taskhash9, False)
438 self.assertEqual(result['unihash'], unihash9, 'Server failed to copy unihash from upstream')
439 self.assertEqual(result['taskhash'], taskhash9, 'Server failed to copy taskhash from upstream')
440 self.assertEqual(result['method'], self.METHOD)
441
442 def test_unihash_exists(self):
443 taskhash, outhash, unihash = self.create_test_hash(self.client)
444 self.assertTrue(self.client.unihash_exists(unihash))
445 self.assertFalse(self.client.unihash_exists('6662e699d6e3d894b24408ff9a4031ef9b038ee8'))
446
261 def test_ro_server(self): 447 def test_ro_server(self):
262 (ro_client, ro_server) = self.start_server(dbpath=self.server.dbpath, read_only=True) 448 rw_server = self.start_server()
449 rw_client = self.start_client(rw_server.address)
450
451 ro_server = self.start_server(dbpath=rw_server.dbpath, read_only=True)
452 ro_client = self.start_client(ro_server.address)
263 453
264 # Report a hash via the read-write server 454 # Report a hash via the read-write server
265 taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' 455 taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
266 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' 456 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
267 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' 457 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
268 458
269 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) 459 result = rw_client.report_unihash(taskhash, self.METHOD, outhash, unihash)
270 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') 460 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
271 461
272 # Check the hash via the read-only server 462 # Check the hash via the read-only server
@@ -277,11 +467,934 @@ class HashEquivalenceCommonTests(object):
277 outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' 467 outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
278 unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824' 468 unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824'
279 469
280 with self.assertRaises(HashConnectionError): 470 result = ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
281 ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) 471 self.assertEqual(result['unihash'], unihash2)
282 472
283 # Ensure that the database was not modified 473 # Ensure that the database was not modified
474 self.assertClientGetHash(rw_client, taskhash2, None)
475
476
477 def test_slow_server_start(self):
478 # Ensures that the server will exit correctly even if it gets a SIGTERM
479 # before entering the main loop
480
481 event = multiprocessing.Event()
482
483 def prefunc(server, idx):
484 nonlocal event
485 server_prefunc(server, idx)
486 event.wait()
487
488 def do_nothing(signum, frame):
489 pass
490
491 old_signal = signal.signal(signal.SIGTERM, do_nothing)
492 self.addCleanup(signal.signal, signal.SIGTERM, old_signal)
493
494 server = self.start_server(prefunc=prefunc)
495 server.process.terminate()
496 time.sleep(30)
497 event.set()
498 server.process.join(300)
499 self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!")
500
501 def test_diverging_report_race(self):
502 # Tests that a reported task will correctly pick up an updated unihash
503
504 # This is a baseline report added to the database to ensure that there
505 # is something to match against as equivalent
506 outhash1 = 'afd11c366050bcd75ad763e898e4430e2a60659b26f83fbb22201a60672019fa'
507 taskhash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
508 unihash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
509 result = self.client.report_unihash(taskhash1, self.METHOD, outhash1, unihash1)
510
511 # Add a report that is equivalent to Task 1. It should ignore the
512 # provided unihash and report the unihash from task 1
513 taskhash2 = '6259ae8263bd94d454c086f501c37e64c4e83cae806902ca95b4ab513546b273'
514 unihash2 = taskhash2
515 result = self.client.report_unihash(taskhash2, self.METHOD, outhash1, unihash2)
516 self.assertEqual(result['unihash'], unihash1)
517
518 # Add another report for Task 2, but with a different outhash (e.g. the
519 # task is non-deterministic). It should still be marked with the Task 1
520 # unihash because it has the Task 2 taskhash, which is equivalent to
521 # Task 1
522 outhash3 = 'd2187ee3a8966db10b34fe0e863482288d9a6185cb8ef58a6c1c6ace87a2f24c'
523 result = self.client.report_unihash(taskhash2, self.METHOD, outhash3, unihash2)
524 self.assertEqual(result['unihash'], unihash1)
525
526
527 def test_diverging_report_reverse_race(self):
528        # Same idea as the previous test, but Tasks 2 and 3 are reported in
529        # the opposite order
530
531 outhash1 = 'afd11c366050bcd75ad763e898e4430e2a60659b26f83fbb22201a60672019fa'
532 taskhash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
533 unihash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab'
534 result = self.client.report_unihash(taskhash1, self.METHOD, outhash1, unihash1)
535
536 taskhash2 = '6259ae8263bd94d454c086f501c37e64c4e83cae806902ca95b4ab513546b273'
537 unihash2 = taskhash2
538
539 # Report Task 3 first. Since there is nothing else in the database it
540 # will use the client provided unihash
541 outhash3 = 'd2187ee3a8966db10b34fe0e863482288d9a6185cb8ef58a6c1c6ace87a2f24c'
542 result = self.client.report_unihash(taskhash2, self.METHOD, outhash3, unihash2)
543 self.assertEqual(result['unihash'], unihash2)
544
545 # Report Task 2. This is equivalent to Task 1 but there is already a mapping for
546 # taskhash2 so it will report unihash2
547 result = self.client.report_unihash(taskhash2, self.METHOD, outhash1, unihash2)
548 self.assertEqual(result['unihash'], unihash2)
549
550        # The originally reported unihash for Task 3 should be unchanged even though it
551 # shares a taskhash with Task 2
552 self.assertClientGetHash(self.client, taskhash2, unihash2)
553
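Between them, test_diverging_report_race and test_diverging_report_reverse_race pin down a first-report-wins rule: a taskhash keeps whatever unihash it was first resolved to, and later equivalent reports are folded into it. A standalone sketch of that rule, assuming hashserv.create_client() is importable and a server is listening on the (hypothetical) socket path below:

    import hashserv

    METHOD = "example.method"                                  # hypothetical method name
    client = hashserv.create_client("unix:///tmp/hashserve.sock")  # assumed address

    # Two distinct taskhashes that produce the same outhash...
    r1 = client.report_unihash("aa" * 20, METHOD, "11" * 32, "aa" * 20)
    r2 = client.report_unihash("bb" * 20, METHOD, "11" * 32, "bb" * 20)

    # ...are unified under the unihash of whichever report the server saw first.
    assert r2["unihash"] == r1["unihash"]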
554 def test_get_unihash_batch(self):
555 TEST_INPUT = (
556 # taskhash outhash unihash
557 ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'),
558 # Duplicated taskhash with multiple output hashes and unihashes.
559 ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'),
560 # Equivalent hash
561 ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"),
562 ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"),
563 ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'),
564 ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'),
565 ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'),
566 )
567 EXTRA_QUERIES = (
568 "6b6be7a84ab179b4240c4302518dc3f6",
569 )
570
571 for taskhash, outhash, unihash in TEST_INPUT:
572 self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
573
574
575 result = self.client.get_unihash_batch(
576 [(self.METHOD, data[0]) for data in TEST_INPUT] +
577 [(self.METHOD, e) for e in EXTRA_QUERIES]
578 )
579
580 self.assertListEqual(result, [
581 "218e57509998197d570e2c98512d0105985dffc9",
582 "218e57509998197d570e2c98512d0105985dffc9",
583 "218e57509998197d570e2c98512d0105985dffc9",
584 "3b5d3d83f07f259e9086fcb422c855286e18a57d",
585 "f46d3fbb439bd9b921095da657a4de906510d2cd",
586 "f46d3fbb439bd9b921095da657a4de906510d2cd",
587 "05d2a63c81e32f0a36542ca677e8ad852365c538",
588 None,
589 ])
590
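The batch API queried above takes (method, taskhash) pairs and returns results positionally, with None for any taskhash the server cannot resolve. Continuing the sketch from above (client and METHOD as before):

    queries = [
        (METHOD, "aa" * 20),                                   # reported above, resolves
        (METHOD, "00" * 20),                                   # never reported, returns None
    ]
    results = client.get_unihash_batch(queries)
    assert len(results) == len(queries)                        # one entry per query, in order
    assert results[-1] is None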
591 def test_unihash_exists_batch(self):
592 TEST_INPUT = (
593 # taskhash outhash unihash
594 ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'),
595 # Duplicated taskhash with multiple output hashes and unihashes.
596 ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'),
597 # Equivalent hash
598 ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"),
599 ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"),
600 ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'),
601 ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'),
602 ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'),
603 )
604 EXTRA_QUERIES = (
605 "6b6be7a84ab179b4240c4302518dc3f6",
606 )
607
608 result_unihashes = set()
609
610
611 for taskhash, outhash, unihash in TEST_INPUT:
612 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
613 result_unihashes.add(result["unihash"])
614
615 query = []
616 expected = []
617
618 for _, _, unihash in TEST_INPUT:
619 query.append(unihash)
620 expected.append(unihash in result_unihashes)
621
622
623 for unihash in EXTRA_QUERIES:
624 query.append(unihash)
625 expected.append(False)
626
627 result = self.client.unihash_exists_batch(query)
628 self.assertListEqual(result, expected)
629
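unihash_exists_batch is the mirror image: it takes plain unihash strings rather than (method, taskhash) pairs and answers with one boolean per query, again positionally. A short continuation of the same sketch:

    exists = client.unihash_exists_batch([
        r1["unihash"],                                         # reported above -> True
        "6b6be7a84ab179b4240c4302518dc3f6",                    # unknown -> False
    ])
    assert exists == [True, False]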
630 def test_auth_read_perms(self):
631 admin_client = self.start_auth_server()
632
633 # Create hashes with non-authenticated server
634 taskhash, outhash, unihash = self.create_test_hash(self.client)
635
636 # Validate hash can be retrieved using authenticated client
637 with self.auth_perms("@read") as client:
638 self.assertClientGetHash(client, taskhash, unihash)
639
640 with self.auth_perms() as client, self.assertRaises(InvokeError):
641 self.assertClientGetHash(client, taskhash, unihash)
642
643 def test_auth_report_perms(self):
644 admin_client = self.start_auth_server()
645
646 # Without read permission, the user is completely denied
647 with self.auth_perms() as client, self.assertRaises(InvokeError):
648 self.create_test_hash(client)
649
650 # Read permission allows the call to succeed, but it doesn't record
651        # anything in the database
652 with self.auth_perms("@read") as client:
653 taskhash, outhash, unihash = self.create_test_hash(client)
654 self.assertClientGetHash(client, taskhash, None)
655
656 # Report permission alone is insufficient
657 with self.auth_perms("@report") as client, self.assertRaises(InvokeError):
658 self.create_test_hash(client)
659
660 # Read and report permission actually modify the database
661 with self.auth_perms("@read", "@report") as client:
662 taskhash, outhash, unihash = self.create_test_hash(client)
663 self.assertClientGetHash(client, taskhash, unihash)
664
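The two tests above spell out the permission ladder: with no permissions every call is denied, @read alone lets queries through and silently discards reports, and only @read plus @report actually writes to the database. A hedged sketch of connecting with credentials, assuming create_client() accepts username/password the way the tests' start_client() helper does (values are placeholders):

    authed = hashserv.create_client(
        "unix:///tmp/hashserve.sock",                          # assumed address
        username="test-user",                                  # hypothetical credentials
        password="token-from-new-user",
    )
    # With "@read" only, this succeeds but is not recorded server-side;
    # with "@read" and "@report" it is persisted.
    authed.report_unihash("cc" * 20, METHOD, "22" * 32, "cc" * 20)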
665 def test_auth_no_token_refresh_from_anon_user(self):
666 self.start_auth_server()
667
668 with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError):
669 client.refresh_token()
670
671 def test_auth_self_token_refresh(self):
672 admin_client = self.start_auth_server()
673
674 # Create a new user with no permissions
675 user = self.create_user("test-user", [])
676
677 with self.auth_client(user) as client:
678 new_user = client.refresh_token()
679
680 self.assertEqual(user["username"], new_user["username"])
681 self.assertNotEqual(user["token"], new_user["token"])
682 self.assertUserCanAuth(new_user)
683 self.assertUserCannotAuth(user)
684
685 # Explicitly specifying with your own username is fine also
686 with self.auth_client(new_user) as client:
687 new_user2 = client.refresh_token(user["username"])
688
689 self.assertEqual(user["username"], new_user2["username"])
690 self.assertNotEqual(user["token"], new_user2["token"])
691 self.assertUserCanAuth(new_user2)
692 self.assertUserCannotAuth(new_user)
693 self.assertUserCannotAuth(user)
694
695 def test_auth_token_refresh(self):
696 admin_client = self.start_auth_server()
697
698 user = self.create_user("test-user", [])
699
700 with self.auth_perms() as client, self.assertRaises(InvokeError):
701 client.refresh_token(user["username"])
702
703 with self.auth_perms("@user-admin") as client:
704 new_user = client.refresh_token(user["username"])
705
706 self.assertEqual(user["username"], new_user["username"])
707 self.assertNotEqual(user["token"], new_user["token"])
708 self.assertUserCanAuth(new_user)
709 self.assertUserCannotAuth(user)
710
711 def test_auth_self_get_user(self):
712 admin_client = self.start_auth_server()
713
714 user = self.create_user("test-user", [])
715 user_info = user.copy()
716 del user_info["token"]
717
718 with self.auth_client(user) as client:
719 info = client.get_user()
720 self.assertEqual(info, user_info)
721
722 # Explicitly asking for your own username is fine also
723 info = client.get_user(user["username"])
724 self.assertEqual(info, user_info)
725
726 def test_auth_get_user(self):
727 admin_client = self.start_auth_server()
728
729 user = self.create_user("test-user", [])
730 user_info = user.copy()
731 del user_info["token"]
732
733 with self.auth_perms() as client, self.assertRaises(InvokeError):
734 client.get_user(user["username"])
735
736 with self.auth_perms("@user-admin") as client:
737 info = client.get_user(user["username"])
738 self.assertEqual(info, user_info)
739
740 info = client.get_user("nonexist-user")
741 self.assertIsNone(info)
742
743 def test_auth_reconnect(self):
744 admin_client = self.start_auth_server()
745
746 user = self.create_user("test-user", [])
747 user_info = user.copy()
748 del user_info["token"]
749
750 with self.auth_client(user) as client:
751 info = client.get_user()
752 self.assertEqual(info, user_info)
753
754 client.disconnect()
755
756 info = client.get_user()
757 self.assertEqual(info, user_info)
758
759 def test_auth_delete_user(self):
760 admin_client = self.start_auth_server()
761
762 user = self.create_user("test-user", [])
763
764        # Self service: a user can delete their own account
765 with self.auth_client(user) as client:
766 client.delete_user(user["username"])
767
768 self.assertIsNone(admin_client.get_user(user["username"]))
769 user = self.create_user("test-user", [])
770
771 with self.auth_perms() as client, self.assertRaises(InvokeError):
772 client.delete_user(user["username"])
773
774 with self.auth_perms("@user-admin") as client:
775 client.delete_user(user["username"])
776
777 # User doesn't exist, so even though the permission is correct, it's an
778 # error
779 with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
780 client.delete_user(user["username"])
781
782 def test_auth_set_user_perms(self):
783 admin_client = self.start_auth_server()
784
785 user = self.create_user("test-user", [])
786
787 self.assertUserPerms(user, [])
788
789 # No self service to change permissions
790 with self.auth_client(user) as client, self.assertRaises(InvokeError):
791 client.set_user_perms(user["username"], ["@all"])
792 self.assertUserPerms(user, [])
793
794 with self.auth_perms() as client, self.assertRaises(InvokeError):
795 client.set_user_perms(user["username"], ["@all"])
796 self.assertUserPerms(user, [])
797
798 with self.auth_perms("@user-admin") as client:
799 client.set_user_perms(user["username"], ["@all"])
800 self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
801
802 # Bad permissions
803 with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError):
804 client.set_user_perms(user["username"], ["@this-is-not-a-permission"])
805 self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS)))
806
807 def test_auth_get_all_users(self):
808 admin_client = self.start_auth_server()
809
810 user = self.create_user("test-user", [])
811
812 with self.auth_client(user) as client, self.assertRaises(InvokeError):
813 client.get_all_users()
814
815 # Give the test user the correct permission
816 admin_client.set_user_perms(user["username"], ["@user-admin"])
817
818 with self.auth_client(user) as client:
819 all_users = client.get_all_users()
820
821 # Convert to a dictionary for easier comparison
822 all_users = {u["username"]: u for u in all_users}
823
824 self.assertEqual(all_users,
825 {
826 "admin": {
827 "username": "admin",
828 "permissions": sorted(list(ALL_PERMISSIONS)),
829 },
830 "test-user": {
831 "username": "test-user",
832 "permissions": ["@user-admin"],
833 }
834 }
835 )
836
837 def test_auth_new_user(self):
838 self.start_auth_server()
839
840 permissions = ["@read", "@report", "@db-admin", "@user-admin"]
841 permissions.sort()
842
843 with self.auth_perms() as client, self.assertRaises(InvokeError):
844 self.create_user("test-user", permissions, client=client)
845
846 with self.auth_perms("@user-admin") as client:
847 user = self.create_user("test-user", permissions, client=client)
848 self.assertIn("token", user)
849 self.assertEqual(user["username"], "test-user")
850 self.assertEqual(user["permissions"], permissions)
851
852 def test_auth_become_user(self):
853 admin_client = self.start_auth_server()
854
855 user = self.create_user("test-user", ["@read", "@report"])
856 user_info = user.copy()
857 del user_info["token"]
858
859 with self.auth_perms() as client, self.assertRaises(InvokeError):
860 client.become_user(user["username"])
861
862 with self.auth_perms("@user-admin") as client:
863 become = client.become_user(user["username"])
864 self.assertEqual(become, user_info)
865
866 info = client.get_user()
867 self.assertEqual(info, user_info)
868
869            # Verify the become-user state is preserved across a reconnect
870 client.disconnect()
871
872 info = client.get_user()
873 self.assertEqual(info, user_info)
874
875 # test-user doesn't have become_user permissions, so this should
876 # not work
877 with self.assertRaises(InvokeError):
878 client.become_user(user["username"])
879
880 # No self-service of become
881 with self.auth_client(user) as client, self.assertRaises(InvokeError):
882 client.become_user(user["username"])
883
884 # Give test user permissions to become
885 admin_client.set_user_perms(user["username"], ["@user-admin"])
886
887 # It's possible to become yourself (effectively a noop)
888 with self.auth_perms("@user-admin") as client:
889 become = client.become_user(client.username)
890
891 def test_auth_gc(self):
892 admin_client = self.start_auth_server()
893
894 with self.auth_perms() as client, self.assertRaises(InvokeError):
895 client.gc_mark("ABC", {"unihash": "123"})
896
897 with self.auth_perms() as client, self.assertRaises(InvokeError):
898 client.gc_status()
899
900 with self.auth_perms() as client, self.assertRaises(InvokeError):
901 client.gc_sweep("ABC")
902
903 with self.auth_perms("@db-admin") as client:
904 client.gc_mark("ABC", {"unihash": "123"})
905
906 with self.auth_perms("@db-admin") as client:
907 client.gc_status()
908
909 with self.auth_perms("@db-admin") as client:
910 client.gc_sweep("ABC")
911
912 def test_get_db_usage(self):
913 usage = self.client.get_db_usage()
914
915 self.assertTrue(isinstance(usage, dict))
916 for name in usage.keys():
917 self.assertTrue(isinstance(usage[name], dict))
918 self.assertIn("rows", usage[name])
919 self.assertTrue(isinstance(usage[name]["rows"], int))
920
921 def test_get_db_query_columns(self):
922 columns = self.client.get_db_query_columns()
923
924 self.assertTrue(isinstance(columns, list))
925 self.assertTrue(len(columns) > 0)
926
927 for col in columns:
928 self.client.remove({col: ""})
929
930 def test_auth_is_owner(self):
931 admin_client = self.start_auth_server()
932
933 user = self.create_user("test-user", ["@read", "@report"])
934 with self.auth_client(user) as client:
935 taskhash, outhash, unihash = self.create_test_hash(client)
936 data = client.get_taskhash(self.METHOD, taskhash, True)
937 self.assertEqual(data["owner"], user["username"])
938
939 def test_gc(self):
940 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
941 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
942 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
943
944 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
945 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
946
947 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
948 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
949 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
950
951 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
952 self.assertClientGetHash(self.client, taskhash2, unihash2)
953
954 # Mark the first unihash to be kept
955 ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
956 self.assertEqual(ret, {"count": 1})
957
958 ret = self.client.gc_status()
959 self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1})
960
961 # Second hash is still there; mark doesn't delete hashes
962 self.assertClientGetHash(self.client, taskhash2, unihash2)
963
964 ret = self.client.gc_sweep("ABC")
965 self.assertEqual(ret, {"count": 1})
966
967        # Second hash is gone; its taskhash no longer maps to a unihash
284 self.assertClientGetHash(self.client, taskhash2, None) 968 self.assertClientGetHash(self.client, taskhash2, None)
969 # First hash is still present
970 self.assertClientGetHash(self.client, taskhash, unihash)
971
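test_gc above walks one full mark-and-sweep cycle. Condensed, and reusing the client from the earlier sketch (mark name and filter values are arbitrary):

    # Mark the rows to keep, check the tally, then sweep everything else.
    client.gc_mark("my-mark", {"unihash": r1["unihash"], "method": METHOD})
    status = client.gc_status()          # e.g. {"mark": "my-mark", "keep": 1, "remove": 1}
    client.gc_sweep("my-mark")           # removes every row not carrying "my-mark"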
972 def test_gc_switch_mark(self):
973 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
974 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
975 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
976
977 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
978 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
979
980 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
981 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
982 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
983
984 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
985 self.assertClientGetHash(self.client, taskhash2, unihash2)
986
987 # Mark the first unihash to be kept
988 ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
989 self.assertEqual(ret, {"count": 1})
990
991 ret = self.client.gc_status()
992 self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1})
993
994 # Second hash is still there; mark doesn't delete hashes
995 self.assertClientGetHash(self.client, taskhash2, unihash2)
996
997 # Switch to a different mark and mark the second hash. This will start
998 # a new collection cycle
999 ret = self.client.gc_mark("DEF", {"unihash": unihash2, "method": self.METHOD})
1000 self.assertEqual(ret, {"count": 1})
1001
1002 ret = self.client.gc_status()
1003 self.assertEqual(ret, {"mark": "DEF", "keep": 1, "remove": 1})
1004
1005 # Both hashes are still present
1006 self.assertClientGetHash(self.client, taskhash2, unihash2)
1007 self.assertClientGetHash(self.client, taskhash, unihash)
1008
1009 # Sweep with the new mark
1010 ret = self.client.gc_sweep("DEF")
1011 self.assertEqual(ret, {"count": 1})
1012
1013 # First hash is gone, second is kept
1014 self.assertClientGetHash(self.client, taskhash2, unihash2)
1015 self.assertClientGetHash(self.client, taskhash, None)
1016
1017 def test_gc_switch_sweep_mark(self):
1018 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
1019 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
1020 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
1021
1022 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
1023 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
1024
1025 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
1026 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
1027 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
1028
1029 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
1030 self.assertClientGetHash(self.client, taskhash2, unihash2)
1031
1032 # Mark the first unihash to be kept
1033 ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
1034 self.assertEqual(ret, {"count": 1})
1035
1036 ret = self.client.gc_status()
1037 self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1})
1038
1039 # Sweeping with a different mark raises an error
1040 with self.assertRaises(InvokeError):
1041 self.client.gc_sweep("DEF")
1042
1043 # Both hashes are present
1044 self.assertClientGetHash(self.client, taskhash2, unihash2)
1045 self.assertClientGetHash(self.client, taskhash, unihash)
1046
1047 def test_gc_new_hashes(self):
1048 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
1049 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
1050 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
1051
1052 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
1053 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
1054
1055 # Start a new garbage collection
1056 ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD})
1057 self.assertEqual(ret, {"count": 1})
1058
1059 ret = self.client.gc_status()
1060 self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 0})
1061
1062 # Add second hash. It should inherit the mark from the current garbage
1063 # collection operation
1064
1065 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
1066 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
1067 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
1068
1069 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
1070 self.assertClientGetHash(self.client, taskhash2, unihash2)
1071
1072 # Sweep should remove nothing
1073 ret = self.client.gc_sweep("ABC")
1074 self.assertEqual(ret, {"count": 0})
1075
1076 # Both hashes are present
1077 self.assertClientGetHash(self.client, taskhash2, unihash2)
1078 self.assertClientGetHash(self.client, taskhash, unihash)
1079
1080
1081class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase):
1082 def get_server_addr(self, server_idx):
1083 return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
1084
1085 def test_get(self):
1086 taskhash, outhash, unihash = self.create_test_hash(self.client)
1087
1088 p = self.run_hashclient(["--address", self.server_address, "get", self.METHOD, taskhash])
1089 data = json.loads(p.stdout)
1090 self.assertEqual(data["unihash"], unihash)
1091 self.assertEqual(data["outhash"], outhash)
1092 self.assertEqual(data["taskhash"], taskhash)
1093 self.assertEqual(data["method"], self.METHOD)
1094
1095 def test_get_outhash(self):
1096 taskhash, outhash, unihash = self.create_test_hash(self.client)
1097
1098 p = self.run_hashclient(["--address", self.server_address, "get-outhash", self.METHOD, outhash, taskhash])
1099 data = json.loads(p.stdout)
1100 self.assertEqual(data["unihash"], unihash)
1101 self.assertEqual(data["outhash"], outhash)
1102 self.assertEqual(data["taskhash"], taskhash)
1103 self.assertEqual(data["method"], self.METHOD)
1104
1105 def test_stats(self):
1106 p = self.run_hashclient(["--address", self.server_address, "stats"], check=True)
1107 json.loads(p.stdout)
1108
1109 def test_stress(self):
1110 self.run_hashclient(["--address", self.server_address, "stress"], check=True)
1111
1112    def test_unihash_exists(self):
1113 taskhash, outhash, unihash = self.create_test_hash(self.client)
1114
1115 p = self.run_hashclient([
1116 "--address", self.server_address,
1117 "unihash-exists", unihash,
1118 ], check=True)
1119 self.assertEqual(p.stdout.strip(), "true")
1120
1121 p = self.run_hashclient([
1122 "--address", self.server_address,
1123 "unihash-exists", '6662e699d6e3d894b24408ff9a4031ef9b038ee8',
1124 ], check=True)
1125 self.assertEqual(p.stdout.strip(), "false")
1126
1127    def test_unihash_exists_quiet(self):
1128 taskhash, outhash, unihash = self.create_test_hash(self.client)
1129
1130 p = self.run_hashclient([
1131 "--address", self.server_address,
1132 "unihash-exists", unihash,
1133 "--quiet",
1134 ])
1135 self.assertEqual(p.returncode, 0)
1136 self.assertEqual(p.stdout.strip(), "")
1137
1138 p = self.run_hashclient([
1139 "--address", self.server_address,
1140 "unihash-exists", '6662e699d6e3d894b24408ff9a4031ef9b038ee8',
1141 "--quiet",
1142 ])
1143 self.assertEqual(p.returncode, 1)
1144 self.assertEqual(p.stdout.strip(), "")
1145
1146 def test_remove_taskhash(self):
1147 taskhash, outhash, unihash = self.create_test_hash(self.client)
1148 self.run_hashclient([
1149 "--address", self.server_address,
1150 "remove",
1151 "--where", "taskhash", taskhash,
1152 ], check=True)
1153 self.assertClientGetHash(self.client, taskhash, None)
1154
1155 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
1156 self.assertIsNone(result_outhash)
1157
1158 def test_remove_unihash(self):
1159 taskhash, outhash, unihash = self.create_test_hash(self.client)
1160 self.run_hashclient([
1161 "--address", self.server_address,
1162 "remove",
1163 "--where", "unihash", unihash,
1164 ], check=True)
1165 self.assertClientGetHash(self.client, taskhash, None)
1166
1167 def test_remove_outhash(self):
1168 taskhash, outhash, unihash = self.create_test_hash(self.client)
1169 self.run_hashclient([
1170 "--address", self.server_address,
1171 "remove",
1172 "--where", "outhash", outhash,
1173 ], check=True)
1174
1175 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
1176 self.assertIsNone(result_outhash)
1177
1178 def test_remove_method(self):
1179 taskhash, outhash, unihash = self.create_test_hash(self.client)
1180 self.run_hashclient([
1181 "--address", self.server_address,
1182 "remove",
1183 "--where", "method", self.METHOD,
1184 ], check=True)
1185 self.assertClientGetHash(self.client, taskhash, None)
1186
1187 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash)
1188 self.assertIsNone(result_outhash)
1189
1190 def test_clean_unused(self):
1191 taskhash, outhash, unihash = self.create_test_hash(self.client)
1192
1193        # Clean the database, which should not remove anything because all hashes are in use
1194 self.run_hashclient([
1195 "--address", self.server_address,
1196 "clean-unused", "0",
1197 ], check=True)
1198 self.assertClientGetHash(self.client, taskhash, unihash)
1199
1200 # Remove the unihash. The row in the outhash table should still be present
1201 self.run_hashclient([
1202 "--address", self.server_address,
1203 "remove",
1204 "--where", "unihash", unihash,
1205 ], check=True)
1206 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
1207 self.assertIsNotNone(result_outhash)
1208
1209 # Now clean with no minimum age which will remove the outhash
1210 self.run_hashclient([
1211 "--address", self.server_address,
1212 "clean-unused", "0",
1213 ], check=True)
1214 result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False)
1215 self.assertIsNone(result_outhash)
1216
1217 def test_refresh_token(self):
1218 admin_client = self.start_auth_server()
1219
1220 user = admin_client.new_user("test-user", ["@read", "@report"])
1221
1222 p = self.run_hashclient([
1223 "--address", self.auth_server_address,
1224 "--login", user["username"],
1225 "--password", user["token"],
1226 "refresh-token"
1227 ], check=True)
1228
1229 new_token = None
1230 for l in p.stdout.splitlines():
1231 l = l.rstrip()
1232 m = re.match(r'Token: +(.*)$', l)
1233 if m is not None:
1234 new_token = m.group(1)
1235
1236 self.assertTrue(new_token)
1237
1238 print("New token is %r" % new_token)
1239
1240 self.run_hashclient([
1241 "--address", self.auth_server_address,
1242 "--login", user["username"],
1243 "--password", new_token,
1244 "get-user"
1245 ], check=True)
1246
1247 def test_set_user_perms(self):
1248 admin_client = self.start_auth_server()
1249
1250 user = admin_client.new_user("test-user", ["@read"])
1251
1252 self.run_hashclient([
1253 "--address", self.auth_server_address,
1254 "--login", admin_client.username,
1255 "--password", admin_client.password,
1256 "set-user-perms",
1257 "-u", user["username"],
1258 "@read", "@report",
1259 ], check=True)
1260
1261 new_user = admin_client.get_user(user["username"])
1262
1263 self.assertEqual(set(new_user["permissions"]), {"@read", "@report"})
1264
1265 def test_get_user(self):
1266 admin_client = self.start_auth_server()
1267
1268 user = admin_client.new_user("test-user", ["@read"])
1269
1270 p = self.run_hashclient([
1271 "--address", self.auth_server_address,
1272 "--login", admin_client.username,
1273 "--password", admin_client.password,
1274 "get-user",
1275 "-u", user["username"],
1276 ], check=True)
1277
1278 self.assertIn("Username:", p.stdout)
1279 self.assertIn("Permissions:", p.stdout)
1280
1281 p = self.run_hashclient([
1282 "--address", self.auth_server_address,
1283 "--login", user["username"],
1284 "--password", user["token"],
1285 "get-user",
1286 ], check=True)
1287
1288 self.assertIn("Username:", p.stdout)
1289 self.assertIn("Permissions:", p.stdout)
1290
1291 def test_get_all_users(self):
1292 admin_client = self.start_auth_server()
1293
1294 admin_client.new_user("test-user1", ["@read"])
1295 admin_client.new_user("test-user2", ["@read"])
1296
1297 p = self.run_hashclient([
1298 "--address", self.auth_server_address,
1299 "--login", admin_client.username,
1300 "--password", admin_client.password,
1301 "get-all-users",
1302 ], check=True)
1303
1304 self.assertIn("admin", p.stdout)
1305 self.assertIn("test-user1", p.stdout)
1306 self.assertIn("test-user2", p.stdout)
1307
1308 def test_new_user(self):
1309 admin_client = self.start_auth_server()
1310
1311 p = self.run_hashclient([
1312 "--address", self.auth_server_address,
1313 "--login", admin_client.username,
1314 "--password", admin_client.password,
1315 "new-user",
1316 "-u", "test-user",
1317 "@read", "@report",
1318 ], check=True)
1319
1320 new_token = None
1321 for l in p.stdout.splitlines():
1322 l = l.rstrip()
1323 m = re.match(r'Token: +(.*)$', l)
1324 if m is not None:
1325 new_token = m.group(1)
1326
1327 self.assertTrue(new_token)
1328
1329 user = {
1330 "username": "test-user",
1331 "token": new_token,
1332 }
1333
1334 self.assertUserPerms(user, ["@read", "@report"])
1335
1336 def test_delete_user(self):
1337 admin_client = self.start_auth_server()
1338
1339 user = admin_client.new_user("test-user", ["@read"])
1340
1341 p = self.run_hashclient([
1342 "--address", self.auth_server_address,
1343 "--login", admin_client.username,
1344 "--password", admin_client.password,
1345 "delete-user",
1346 "-u", user["username"],
1347 ], check=True)
1348
1349 self.assertIsNone(admin_client.get_user(user["username"]))
1350
1351 def test_get_db_usage(self):
1352 p = self.run_hashclient([
1353 "--address", self.server_address,
1354 "get-db-usage",
1355 ], check=True)
1356
1357 def test_get_db_query_columns(self):
1358 p = self.run_hashclient([
1359 "--address", self.server_address,
1360 "get-db-query-columns",
1361 ], check=True)
1362
1363 def test_gc(self):
1364 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
1365 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
1366 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
1367
1368 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
1369 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
1370
1371 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
1372 outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
1373 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
1374
1375 result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
1376 self.assertClientGetHash(self.client, taskhash2, unihash2)
1377
1378 # Mark the first unihash to be kept
1379 self.run_hashclient([
1380 "--address", self.server_address,
1381 "gc-mark", "ABC",
1382 "--where", "unihash", unihash,
1383 "--where", "method", self.METHOD
1384 ], check=True)
1385
1386 # Second hash is still there; mark doesn't delete hashes
1387 self.assertClientGetHash(self.client, taskhash2, unihash2)
1388
1389 self.run_hashclient([
1390 "--address", self.server_address,
1391 "gc-sweep", "ABC",
1392 ], check=True)
1393
1394        # Second hash is gone; its taskhash no longer maps to a unihash
1395 self.assertClientGetHash(self.client, taskhash2, None)
1396 # First hash is still present
1397 self.assertClientGetHash(self.client, taskhash, unihash)
285 1398
286 1399
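The TestHashEquivalenceClient class above drives the bitbake-hashclient command-line tool through run_hashclient(). The equivalent direct invocation, sketched with subprocess (address and taskhash are placeholders):

    import json
    import subprocess

    p = subprocess.run(
        ["bitbake-hashclient", "--address", "unix:///tmp/hashserve.sock",
         "get", "example.method", "aa" * 20],
        capture_output=True, text=True, check=True,
    )
    data = json.loads(p.stdout)          # same fields the test asserts on
    print(data["unihash"], data["outhash"])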
287class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): 1400class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
@@ -314,3 +1427,77 @@ class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm
314 # If IPv6 is enabled, it should be safe to use localhost directly; in the 1427 # If IPv6 is enabled, it should be safe to use localhost directly; in the
315 # general case it is more reliable to resolve the IP address explicitly. 1428 # general case it is more reliable to resolve the IP address explicitly.
316 return socket.gethostbyname("localhost") + ":0" 1429 return socket.gethostbyname("localhost") + ":0"
1430
1431
1432class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
1433 def setUp(self):
1434 try:
1435 import websockets
1436 except ImportError as e:
1437 self.skipTest(str(e))
1438
1439 super().setUp()
1440
1441 def get_server_addr(self, server_idx):
1442        # Some hosts cause the asyncio module to misbehave when IPv6 is not enabled.
1443        # If IPv6 is enabled, it should be safe to use localhost directly; in the
1444        # general case it is more reliable to resolve the IP address explicitly.
1445 host = socket.gethostbyname("localhost")
1446 return "ws://%s:0" % host
1447
1448
1449class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocketServer):
1450 def setUp(self):
1451 try:
1452 import sqlalchemy
1453 import aiosqlite
1454 except ImportError as e:
1455 self.skipTest(str(e))
1456
1457 super().setUp()
1458
1459 def make_dbpath(self):
1460 return "sqlite+aiosqlite:///%s" % os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
1461
1462
1463class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
1464 def get_env(self, name):
1465 v = os.environ.get(name)
1466 if not v:
1467 self.skipTest(f'{name} not defined to test an external server')
1468 return v
1469
1470 def start_test_server(self):
1471 return self.get_env('BB_TEST_HASHSERV')
1472
1473 def start_server(self, *args, **kwargs):
1474 self.skipTest('Cannot start local server when testing external servers')
1475
1476 def start_auth_server(self):
1477
1478 self.auth_server_address = self.server_address
1479 self.admin_client = self.start_client(
1480 self.server_address,
1481 username=self.get_env('BB_TEST_HASHSERV_USERNAME'),
1482 password=self.get_env('BB_TEST_HASHSERV_PASSWORD'),
1483 )
1484 return self.admin_client
1485
1486 def setUp(self):
1487 super().setUp()
1488 if "BB_TEST_HASHSERV_USERNAME" in os.environ:
1489 self.client = self.start_client(
1490 self.server_address,
1491 username=os.environ["BB_TEST_HASHSERV_USERNAME"],
1492 password=os.environ["BB_TEST_HASHSERV_PASSWORD"],
1493 )
1494 self.client.remove({"method": self.METHOD})
1495
1496 def tearDown(self):
1497 self.client.remove({"method": self.METHOD})
1498 super().tearDown()
1499
1500
1501 def test_auth_get_all_users(self):
1502 self.skipTest("Cannot test all users with external server")
1503
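Finally, the external-server tests only run when pointed at a live instance via the environment variables read by get_env() above. A hypothetical way to aim them at one (host, credentials, and module path are placeholders):

    import os
    import unittest

    os.environ["BB_TEST_HASHSERV"] = "localhost:8686"          # placeholder server
    os.environ["BB_TEST_HASHSERV_USERNAME"] = "admin"          # placeholder credentials
    os.environ["BB_TEST_HASHSERV_PASSWORD"] = "password"

    unittest.main(module="hashserv.tests", argv=["hashserv.tests"])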