Diffstat (limited to 'bitbake/lib/hashserv')
-rw-r--r-- | bitbake/lib/hashserv/__init__.py | 175
-rw-r--r-- | bitbake/lib/hashserv/client.py | 431
-rw-r--r-- | bitbake/lib/hashserv/server.py | 1088
-rw-r--r-- | bitbake/lib/hashserv/sqlalchemy.py | 598
-rw-r--r-- | bitbake/lib/hashserv/sqlite.py | 562
-rw-r--r-- | bitbake/lib/hashserv/tests.py | 1267
6 files changed, 3411 insertions, 710 deletions
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py index 5f2e101e52..ac891e0174 100644 --- a/bitbake/lib/hashserv/__init__.py +++ b/bitbake/lib/hashserv/__init__.py | |||
@@ -5,129 +5,104 @@ | |||
5 | 5 | ||
6 | import asyncio | 6 | import asyncio |
7 | from contextlib import closing | 7 | from contextlib import closing |
8 | import re | ||
9 | import sqlite3 | ||
10 | import itertools | 8 | import itertools |
11 | import json | 9 | import json |
10 | from collections import namedtuple | ||
11 | from urllib.parse import urlparse | ||
12 | from bb.asyncrpc.client import parse_address, ADDR_TYPE_UNIX, ADDR_TYPE_WS | ||
12 | 13 | ||
13 | UNIX_PREFIX = "unix://" | 14 | User = namedtuple("User", ("username", "permissions")) |
14 | |||
15 | ADDR_TYPE_UNIX = 0 | ||
16 | ADDR_TYPE_TCP = 1 | ||
17 | |||
18 | # The Python async server defaults to a 64K receive buffer, so we hardcode our | ||
19 | # maximum chunk size. It would be better if the client and server reported to | ||
20 | # each other what the maximum chunk sizes were, but that will slow down the | ||
21 | # connection setup with a round trip delay so I'd rather not do that unless it | ||
22 | # is necessary | ||
23 | DEFAULT_MAX_CHUNK = 32 * 1024 | ||
24 | |||
25 | TABLE_DEFINITION = ( | ||
26 | ("method", "TEXT NOT NULL"), | ||
27 | ("outhash", "TEXT NOT NULL"), | ||
28 | ("taskhash", "TEXT NOT NULL"), | ||
29 | ("unihash", "TEXT NOT NULL"), | ||
30 | ("created", "DATETIME"), | ||
31 | |||
32 | # Optional fields | ||
33 | ("owner", "TEXT"), | ||
34 | ("PN", "TEXT"), | ||
35 | ("PV", "TEXT"), | ||
36 | ("PR", "TEXT"), | ||
37 | ("task", "TEXT"), | ||
38 | ("outhash_siginfo", "TEXT"), | ||
39 | ) | ||
40 | |||
41 | TABLE_COLUMNS = tuple(name for name, _ in TABLE_DEFINITION) | ||
42 | |||
43 | def setup_database(database, sync=True): | ||
44 | db = sqlite3.connect(database) | ||
45 | db.row_factory = sqlite3.Row | ||
46 | |||
47 | with closing(db.cursor()) as cursor: | ||
48 | cursor.execute(''' | ||
49 | CREATE TABLE IF NOT EXISTS tasks_v2 ( | ||
50 | id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
51 | %s | ||
52 | UNIQUE(method, outhash, taskhash) | ||
53 | ) | ||
54 | ''' % " ".join("%s %s," % (name, typ) for name, typ in TABLE_DEFINITION)) | ||
55 | cursor.execute('PRAGMA journal_mode = WAL') | ||
56 | cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF')) | ||
57 | |||
58 | # Drop old indexes | ||
59 | cursor.execute('DROP INDEX IF EXISTS taskhash_lookup') | ||
60 | cursor.execute('DROP INDEX IF EXISTS outhash_lookup') | ||
61 | |||
62 | # Create new indexes | ||
63 | cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)') | ||
64 | cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)') | ||
65 | |||
66 | return db | ||
67 | |||
68 | |||
69 | def parse_address(addr): | ||
70 | if addr.startswith(UNIX_PREFIX): | ||
71 | return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],)) | ||
72 | else: | ||
73 | m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr) | ||
74 | if m is not None: | ||
75 | host = m.group('host') | ||
76 | port = m.group('port') | ||
77 | else: | ||
78 | host, port = addr.split(':') | ||
79 | 15 | ||
80 | return (ADDR_TYPE_TCP, (host, int(port))) | ||
81 | 16 | ||
17 | def create_server( | ||
18 | addr, | ||
19 | dbname, | ||
20 | *, | ||
21 | sync=True, | ||
22 | upstream=None, | ||
23 | read_only=False, | ||
24 | db_username=None, | ||
25 | db_password=None, | ||
26 | anon_perms=None, | ||
27 | admin_username=None, | ||
28 | admin_password=None, | ||
29 | reuseport=False, | ||
30 | ): | ||
31 | def sqlite_engine(): | ||
32 | from .sqlite import DatabaseEngine | ||
82 | 33 | ||
83 | def chunkify(msg, max_chunk): | 34 | return DatabaseEngine(dbname, sync) |
84 | if len(msg) < max_chunk - 1: | ||
85 | yield ''.join((msg, "\n")) | ||
86 | else: | ||
87 | yield ''.join((json.dumps({ | ||
88 | 'chunk-stream': None | ||
89 | }), "\n")) | ||
90 | 35 | ||
91 | args = [iter(msg)] * (max_chunk - 1) | 36 | def sqlalchemy_engine(): |
92 | for m in map(''.join, itertools.zip_longest(*args, fillvalue='')): | 37 | from .sqlalchemy import DatabaseEngine |
93 | yield ''.join(itertools.chain(m, "\n")) | ||
94 | yield "\n" | ||
95 | 38 | ||
39 | return DatabaseEngine(dbname, db_username, db_password) | ||
96 | 40 | ||
97 | def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False): | ||
98 | from . import server | 41 | from . import server |
99 | db = setup_database(dbname, sync=sync) | 42 | |
100 | s = server.Server(db, upstream=upstream, read_only=read_only) | 43 | if "://" in dbname: |
44 | db_engine = sqlalchemy_engine() | ||
45 | else: | ||
46 | db_engine = sqlite_engine() | ||
47 | |||
48 | if anon_perms is None: | ||
49 | anon_perms = server.DEFAULT_ANON_PERMS | ||
50 | |||
51 | s = server.Server( | ||
52 | db_engine, | ||
53 | upstream=upstream, | ||
54 | read_only=read_only, | ||
55 | anon_perms=anon_perms, | ||
56 | admin_username=admin_username, | ||
57 | admin_password=admin_password, | ||
58 | ) | ||
101 | 59 | ||
102 | (typ, a) = parse_address(addr) | 60 | (typ, a) = parse_address(addr) |
103 | if typ == ADDR_TYPE_UNIX: | 61 | if typ == ADDR_TYPE_UNIX: |
104 | s.start_unix_server(*a) | 62 | s.start_unix_server(*a) |
63 | elif typ == ADDR_TYPE_WS: | ||
64 | url = urlparse(a[0]) | ||
65 | s.start_websocket_server(url.hostname, url.port, reuseport=reuseport) | ||
105 | else: | 66 | else: |
106 | s.start_tcp_server(*a) | 67 | s.start_tcp_server(*a, reuseport=reuseport) |
107 | 68 | ||
108 | return s | 69 | return s |
109 | 70 | ||
110 | 71 | ||
111 | def create_client(addr): | 72 | def create_client(addr, username=None, password=None): |
112 | from . import client | 73 | from . import client |
113 | c = client.Client() | ||
114 | 74 | ||
115 | (typ, a) = parse_address(addr) | 75 | c = client.Client(username, password) |
116 | if typ == ADDR_TYPE_UNIX: | 76 | |
117 | c.connect_unix(*a) | 77 | try: |
118 | else: | 78 | (typ, a) = parse_address(addr) |
119 | c.connect_tcp(*a) | 79 | if typ == ADDR_TYPE_UNIX: |
80 | c.connect_unix(*a) | ||
81 | elif typ == ADDR_TYPE_WS: | ||
82 | c.connect_websocket(*a) | ||
83 | else: | ||
84 | c.connect_tcp(*a) | ||
85 | return c | ||
86 | except Exception as e: | ||
87 | c.close() | ||
88 | raise e | ||
120 | 89 | ||
121 | return c | ||
122 | 90 | ||
123 | async def create_async_client(addr): | 91 | async def create_async_client(addr, username=None, password=None): |
124 | from . import client | 92 | from . import client |
125 | c = client.AsyncClient() | ||
126 | 93 | ||
127 | (typ, a) = parse_address(addr) | 94 | c = client.AsyncClient(username, password) |
128 | if typ == ADDR_TYPE_UNIX: | 95 | |
129 | await c.connect_unix(*a) | 96 | try: |
130 | else: | 97 | (typ, a) = parse_address(addr) |
131 | await c.connect_tcp(*a) | 98 | if typ == ADDR_TYPE_UNIX: |
99 | await c.connect_unix(*a) | ||
100 | elif typ == ADDR_TYPE_WS: | ||
101 | await c.connect_websocket(*a) | ||
102 | else: | ||
103 | await c.connect_tcp(*a) | ||
132 | 104 | ||
133 | return c | 105 | return c |
106 | except Exception as e: | ||
107 | await c.close() | ||
108 | raise e | ||
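Taken together, the new factories hide both the transport and the database engine behind plain strings. A minimal sketch of how they might be invoked; the addresses, database names, and credentials are illustrative, and `serve_forever()` mirrors how the `bitbake-hashserv` entry point runs the returned server:

```python
import hashserv

# Server side: "://" in the database name selects the sqlalchemy engine,
# a bare path the built-in sqlite engine. The address prefix picks the
# transport (unix://, ws://, or host:port for TCP).
server = hashserv.create_server(
    "localhost:8686",
    "./hashserv.db",  # or e.g. "postgresql+asyncpg://user@host/db"
    anon_perms=["@read"],
)
server.serve_forever()  # blocks, as in the bitbake-hashserv entry point

# Client side (in another process): credentials are optional and only
# matter when the server restricts anonymous permissions.
client = hashserv.create_client("localhost:8686", username="admin", password="token")
print(client.get_stats())
```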
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py index e05c1eb568..a510f3284f 100644 --- a/bitbake/lib/hashserv/client.py +++ b/bitbake/lib/hashserv/client.py | |||
@@ -3,231 +3,356 @@ | |||
3 | # SPDX-License-Identifier: GPL-2.0-only | 3 | # SPDX-License-Identifier: GPL-2.0-only |
4 | # | 4 | # |
5 | 5 | ||
6 | import asyncio | ||
7 | import json | ||
8 | import logging | 6 | import logging |
9 | import socket | 7 | import socket |
10 | import os | 8 | import asyncio |
11 | from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client | 9 | import bb.asyncrpc |
10 | import json | ||
11 | from . import create_async_client | ||
12 | 12 | ||
13 | 13 | ||
14 | logger = logging.getLogger("hashserv.client") | 14 | logger = logging.getLogger("hashserv.client") |
15 | 15 | ||
16 | 16 | ||
17 | class HashConnectionError(Exception): | 17 | class Batch(object): |
18 | pass | 18 | def __init__(self): |
19 | self.done = False | ||
20 | self.cond = asyncio.Condition() | ||
21 | self.pending = [] | ||
22 | self.results = [] | ||
23 | self.sent_count = 0 | ||
19 | 24 | ||
25 | async def recv(self, socket): | ||
26 | while True: | ||
27 | async with self.cond: | ||
28 | await self.cond.wait_for(lambda: self.pending or self.done) | ||
20 | 29 | ||
21 | class AsyncClient(object): | 30 | if not self.pending: |
22 | MODE_NORMAL = 0 | 31 | if self.done: |
23 | MODE_GET_STREAM = 1 | 32 | return |
33 | continue | ||
24 | 34 | ||
25 | def __init__(self): | 35 | r = await socket.recv() |
26 | self.reader = None | 36 | self.results.append(r) |
27 | self.writer = None | ||
28 | self.mode = self.MODE_NORMAL | ||
29 | self.max_chunk = DEFAULT_MAX_CHUNK | ||
30 | 37 | ||
31 | async def connect_tcp(self, address, port): | 38 | async with self.cond: |
32 | async def connect_sock(): | 39 | self.pending.pop(0) |
33 | return await asyncio.open_connection(address, port) | ||
34 | 40 | ||
35 | self._connect_sock = connect_sock | 41 | async def send(self, socket, msgs): |
42 | try: | ||
43 | # In the event of a restart due to a reconnect, all in-flight | ||
44 | # messages need to be resent first to keep the result count in sync | ||
45 | for m in self.pending: | ||
46 | await socket.send(m) | ||
36 | 47 | ||
37 | async def connect_unix(self, path): | 48 | for m in msgs: |
38 | async def connect_sock(): | 49 | # Add the message to the pending list before attempting to send |
39 | return await asyncio.open_unix_connection(path) | 50 | # it so that if the send fails it will be retried |
51 | async with self.cond: | ||
52 | self.pending.append(m) | ||
53 | self.cond.notify() | ||
54 | self.sent_count += 1 | ||
40 | 55 | ||
41 | self._connect_sock = connect_sock | 56 | await socket.send(m) |
42 | 57 | ||
43 | async def connect(self): | 58 | finally: |
44 | if self.reader is None or self.writer is None: | 59 | async with self.cond: |
45 | (self.reader, self.writer) = await self._connect_sock() | 60 | self.done = True |
61 | self.cond.notify() | ||
62 | |||
63 | async def process(self, socket, msgs): | ||
64 | await asyncio.gather( | ||
65 | self.recv(socket), | ||
66 | self.send(socket, msgs), | ||
67 | ) | ||
46 | 68 | ||
47 | self.writer.write("OEHASHEQUIV 1.1\n\n".encode("utf-8")) | 69 | if len(self.results) != self.sent_count: |
48 | await self.writer.drain() | 70 | raise ValueError( |
71 | f"Expected result count {len(self.results)}. Expected {self.sent_count}" | ||
72 | ) | ||
49 | 73 | ||
50 | cur_mode = self.mode | 74 | return self.results |
51 | self.mode = self.MODE_NORMAL | ||
52 | await self._set_mode(cur_mode) | ||
53 | 75 | ||
54 | async def close(self): | ||
55 | self.reader = None | ||
56 | 76 | ||
57 | if self.writer is not None: | 77 | class AsyncClient(bb.asyncrpc.AsyncClient): |
58 | self.writer.close() | 78 | MODE_NORMAL = 0 |
59 | self.writer = None | 79 | MODE_GET_STREAM = 1 |
80 | MODE_EXIST_STREAM = 2 | ||
60 | 81 | ||
61 | async def _send_wrapper(self, proc): | 82 | def __init__(self, username=None, password=None): |
62 | count = 0 | 83 | super().__init__("OEHASHEQUIV", "1.1", logger) |
63 | while True: | 84 | self.mode = self.MODE_NORMAL |
64 | try: | 85 | self.username = username |
65 | await self.connect() | 86 | self.password = password |
66 | return await proc() | 87 | self.saved_become_user = None |
67 | except ( | ||
68 | OSError, | ||
69 | HashConnectionError, | ||
70 | json.JSONDecodeError, | ||
71 | UnicodeDecodeError, | ||
72 | ) as e: | ||
73 | logger.warning("Error talking to server: %s" % e) | ||
74 | if count >= 3: | ||
75 | if not isinstance(e, HashConnectionError): | ||
76 | raise HashConnectionError(str(e)) | ||
77 | raise e | ||
78 | await self.close() | ||
79 | count += 1 | ||
80 | |||
81 | async def send_message(self, msg): | ||
82 | async def get_line(): | ||
83 | line = await self.reader.readline() | ||
84 | if not line: | ||
85 | raise HashConnectionError("Connection closed") | ||
86 | |||
87 | line = line.decode("utf-8") | ||
88 | |||
89 | if not line.endswith("\n"): | ||
90 | raise HashConnectionError("Bad message %r" % message) | ||
91 | |||
92 | return line | ||
93 | 88 | ||
94 | async def proc(): | 89 | async def setup_connection(self): |
95 | for c in chunkify(json.dumps(msg), self.max_chunk): | 90 | await super().setup_connection() |
96 | self.writer.write(c.encode("utf-8")) | 91 | self.mode = self.MODE_NORMAL |
97 | await self.writer.drain() | 92 | if self.username: |
93 | # Save off become user temporarily because auth() resets it | ||
94 | become = self.saved_become_user | ||
95 | await self.auth(self.username, self.password) | ||
98 | 96 | ||
99 | l = await get_line() | 97 | if become: |
98 | await self.become_user(become) | ||
100 | 99 | ||
101 | m = json.loads(l) | 100 | async def send_stream_batch(self, mode, msgs): |
102 | if m and "chunk-stream" in m: | 101 | """ |
103 | lines = [] | 102 | Does a "batch" process of stream messages. This sends the query |
104 | while True: | 103 | messages as fast as possible, and simultaneously attempts to read the |
105 | l = (await get_line()).rstrip("\n") | 104 | hash equivalence server by allowing multiple queries to be "in-flight" |
106 | if not l: | 105 | at once. |
107 | break | 106 | at once |
108 | lines.append(l) | ||
109 | 107 | ||
110 | m = json.loads("".join(lines)) | 108 | The implementation does more complicated tracking using a count of sent |
109 | messages so that `msgs` can be a generator function (i.e. its length is | ||
110 | unknown) | ||
111 | 111 | ||
112 | return m | 112 | """ |
113 | 113 | ||
114 | return await self._send_wrapper(proc) | 114 | b = Batch() |
115 | 115 | ||
116 | async def send_stream(self, msg): | ||
117 | async def proc(): | 116 | async def proc(): |
118 | self.writer.write(("%s\n" % msg).encode("utf-8")) | 117 | nonlocal b |
119 | await self.writer.drain() | 118 | |
120 | l = await self.reader.readline() | 119 | await self._set_mode(mode) |
121 | if not l: | 120 | return await b.process(self.socket, msgs) |
122 | raise HashConnectionError("Connection closed") | ||
123 | return l.decode("utf-8").rstrip() | ||
124 | 121 | ||
125 | return await self._send_wrapper(proc) | 122 | return await self._send_wrapper(proc) |
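For callers, this batching surfaces in the `*_batch` convenience wrappers defined below. A small sketch of the intended use (hash values are placeholders); because `send_stream_batch()` counts sent messages rather than calling `len()`, the argument can be a generator whose length is not known up front:

```python
async def query_many(client, method, taskhashes):
    # One round trip's worth of latency is shared across all queries
    # instead of being paid once per taskhash.
    unihashes = await client.get_unihash_batch(
        (method, th) for th in taskhashes
    )
    # Results come back in request order; misses are None.
    return dict(zip(taskhashes, unihashes))
```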
126 | 123 | ||
124 | async def invoke(self, *args, skip_mode=False, **kwargs): | ||
125 | # It's OK if connection errors cause a failure here, because the mode | ||
126 | # is also reset to normal on a new connection | ||
127 | if not skip_mode: | ||
128 | await self._set_mode(self.MODE_NORMAL) | ||
129 | return await super().invoke(*args, **kwargs) | ||
130 | |||
127 | async def _set_mode(self, new_mode): | 131 | async def _set_mode(self, new_mode): |
128 | if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM: | 132 | async def stream_to_normal(): |
129 | r = await self.send_stream("END") | 133 | # Check if already in normal mode (e.g. due to a connection reset) |
134 | if self.mode == self.MODE_NORMAL: | ||
135 | return "ok" | ||
136 | await self.socket.send("END") | ||
137 | return await self.socket.recv() | ||
138 | |||
139 | async def normal_to_stream(command): | ||
140 | r = await self.invoke({command: None}, skip_mode=True) | ||
130 | if r != "ok": | 141 | if r != "ok": |
131 | raise HashConnectionError("Bad response from server %r" % r) | 142 | self.check_invoke_error(r) |
132 | elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL: | 143 | raise ConnectionError( |
133 | r = await self.send_message({"get-stream": None}) | 144 | f"Unable to transition to stream mode: Bad response from server {r!r}" |
145 | ) | ||
146 | self.logger.debug("Mode is now %s", command) | ||
147 | |||
148 | if new_mode == self.mode: | ||
149 | return | ||
150 | |||
151 | self.logger.debug("Transitioning mode %s -> %s", self.mode, new_mode) | ||
152 | |||
153 | # Always transition to normal mode before switching to any other mode | ||
154 | if self.mode != self.MODE_NORMAL: | ||
155 | r = await self._send_wrapper(stream_to_normal) | ||
134 | if r != "ok": | 156 | if r != "ok": |
135 | raise HashConnectionError("Bad response from server %r" % r) | 157 | self.check_invoke_error(r) |
136 | elif new_mode != self.mode: | 158 | raise ConnectionError( |
137 | raise Exception( | 159 | f"Unable to transition to normal mode: Bad response from server {r!r}" |
138 | "Undefined mode transition %r -> %r" % (self.mode, new_mode) | 160 | ) |
139 | ) | 161 | self.logger.debug("Mode is now normal") |
162 | |||
163 | if new_mode == self.MODE_GET_STREAM: | ||
164 | await normal_to_stream("get-stream") | ||
165 | elif new_mode == self.MODE_EXIST_STREAM: | ||
166 | await normal_to_stream("exists-stream") | ||
167 | elif new_mode != self.MODE_NORMAL: | ||
168 | raise Exception("Undefined mode transition {self.mode!r} -> {new_mode!r}") | ||
140 | 169 | ||
141 | self.mode = new_mode | 170 | self.mode = new_mode |
142 | 171 | ||
143 | async def get_unihash(self, method, taskhash): | 172 | async def get_unihash(self, method, taskhash): |
144 | await self._set_mode(self.MODE_GET_STREAM) | 173 | r = await self.get_unihash_batch([(method, taskhash)]) |
145 | r = await self.send_stream("%s %s" % (method, taskhash)) | 174 | return r[0] |
146 | if not r: | 175 | |
147 | return None | 176 | async def get_unihash_batch(self, args): |
148 | return r | 177 | result = await self.send_stream_batch( |
178 | self.MODE_GET_STREAM, | ||
179 | (f"{method} {taskhash}" for method, taskhash in args), | ||
180 | ) | ||
181 | return [r if r else None for r in result] | ||
149 | 182 | ||
150 | async def report_unihash(self, taskhash, method, outhash, unihash, extra={}): | 183 | async def report_unihash(self, taskhash, method, outhash, unihash, extra={}): |
151 | await self._set_mode(self.MODE_NORMAL) | ||
152 | m = extra.copy() | 184 | m = extra.copy() |
153 | m["taskhash"] = taskhash | 185 | m["taskhash"] = taskhash |
154 | m["method"] = method | 186 | m["method"] = method |
155 | m["outhash"] = outhash | 187 | m["outhash"] = outhash |
156 | m["unihash"] = unihash | 188 | m["unihash"] = unihash |
157 | return await self.send_message({"report": m}) | 189 | return await self.invoke({"report": m}) |
158 | 190 | ||
159 | async def report_unihash_equiv(self, taskhash, method, unihash, extra={}): | 191 | async def report_unihash_equiv(self, taskhash, method, unihash, extra={}): |
160 | await self._set_mode(self.MODE_NORMAL) | ||
161 | m = extra.copy() | 192 | m = extra.copy() |
162 | m["taskhash"] = taskhash | 193 | m["taskhash"] = taskhash |
163 | m["method"] = method | 194 | m["method"] = method |
164 | m["unihash"] = unihash | 195 | m["unihash"] = unihash |
165 | return await self.send_message({"report-equiv": m}) | 196 | return await self.invoke({"report-equiv": m}) |
166 | 197 | ||
167 | async def get_taskhash(self, method, taskhash, all_properties=False): | 198 | async def get_taskhash(self, method, taskhash, all_properties=False): |
168 | await self._set_mode(self.MODE_NORMAL) | 199 | return await self.invoke( |
169 | return await self.send_message( | ||
170 | {"get": {"taskhash": taskhash, "method": method, "all": all_properties}} | 200 | {"get": {"taskhash": taskhash, "method": method, "all": all_properties}} |
171 | ) | 201 | ) |
172 | 202 | ||
173 | async def get_outhash(self, method, outhash, taskhash): | 203 | async def unihash_exists(self, unihash): |
174 | await self._set_mode(self.MODE_NORMAL) | 204 | r = await self.unihash_exists_batch([unihash]) |
175 | return await self.send_message( | 205 | return r[0] |
176 | {"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method}} | 206 | |
207 | async def unihash_exists_batch(self, unihashes): | ||
208 | result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes) | ||
209 | return [r == "true" for r in result] | ||
210 | |||
211 | async def get_outhash(self, method, outhash, taskhash, with_unihash=True): | ||
212 | return await self.invoke( | ||
213 | { | ||
214 | "get-outhash": { | ||
215 | "outhash": outhash, | ||
216 | "taskhash": taskhash, | ||
217 | "method": method, | ||
218 | "with_unihash": with_unihash, | ||
219 | } | ||
220 | } | ||
177 | ) | 221 | ) |
178 | 222 | ||
179 | async def get_stats(self): | 223 | async def get_stats(self): |
180 | await self._set_mode(self.MODE_NORMAL) | 224 | return await self.invoke({"get-stats": None}) |
181 | return await self.send_message({"get-stats": None}) | ||
182 | 225 | ||
183 | async def reset_stats(self): | 226 | async def reset_stats(self): |
184 | await self._set_mode(self.MODE_NORMAL) | 227 | return await self.invoke({"reset-stats": None}) |
185 | return await self.send_message({"reset-stats": None}) | ||
186 | 228 | ||
187 | async def backfill_wait(self): | 229 | async def backfill_wait(self): |
188 | await self._set_mode(self.MODE_NORMAL) | 230 | return (await self.invoke({"backfill-wait": None}))["tasks"] |
189 | return (await self.send_message({"backfill-wait": None}))["tasks"] | 231 | |
232 | async def remove(self, where): | ||
233 | return await self.invoke({"remove": {"where": where}}) | ||
234 | |||
235 | async def clean_unused(self, max_age): | ||
236 | return await self.invoke({"clean-unused": {"max_age_seconds": max_age}}) | ||
237 | |||
238 | async def auth(self, username, token): | ||
239 | result = await self.invoke({"auth": {"username": username, "token": token}}) | ||
240 | self.username = username | ||
241 | self.password = token | ||
242 | self.saved_become_user = None | ||
243 | return result | ||
244 | |||
245 | async def refresh_token(self, username=None): | ||
246 | m = {} | ||
247 | if username: | ||
248 | m["username"] = username | ||
249 | result = await self.invoke({"refresh-token": m}) | ||
250 | if ( | ||
251 | self.username | ||
252 | and not self.saved_become_user | ||
253 | and result["username"] == self.username | ||
254 | ): | ||
255 | self.password = result["token"] | ||
256 | return result | ||
190 | 257 | ||
258 | async def set_user_perms(self, username, permissions): | ||
259 | return await self.invoke( | ||
260 | {"set-user-perms": {"username": username, "permissions": permissions}} | ||
261 | ) | ||
191 | 262 | ||
192 | class Client(object): | 263 | async def get_user(self, username=None): |
193 | def __init__(self): | 264 | m = {} |
194 | self.client = AsyncClient() | 265 | if username: |
195 | self.loop = asyncio.new_event_loop() | 266 | m["username"] = username |
267 | return await self.invoke({"get-user": m}) | ||
268 | |||
269 | async def get_all_users(self): | ||
270 | return (await self.invoke({"get-all-users": {}}))["users"] | ||
271 | |||
272 | async def new_user(self, username, permissions): | ||
273 | return await self.invoke( | ||
274 | {"new-user": {"username": username, "permissions": permissions}} | ||
275 | ) | ||
276 | |||
277 | async def delete_user(self, username): | ||
278 | return await self.invoke({"delete-user": {"username": username}}) | ||
279 | |||
280 | async def become_user(self, username): | ||
281 | result = await self.invoke({"become-user": {"username": username}}) | ||
282 | if username == self.username: | ||
283 | self.saved_become_user = None | ||
284 | else: | ||
285 | self.saved_become_user = username | ||
286 | return result | ||
287 | |||
288 | async def get_db_usage(self): | ||
289 | return (await self.invoke({"get-db-usage": {}}))["usage"] | ||
290 | |||
291 | async def get_db_query_columns(self): | ||
292 | return (await self.invoke({"get-db-query-columns": {}}))["columns"] | ||
293 | |||
294 | async def gc_status(self): | ||
295 | return await self.invoke({"gc-status": {}}) | ||
296 | |||
297 | async def gc_mark(self, mark, where): | ||
298 | """ | ||
299 | Starts a new garbage collection operation identified by "mark". If | ||
300 | garbage collection is already in progress with "mark", the collection | ||
301 | is continued. | ||
196 | 302 | ||
197 | for call in ( | 303 | All unihash entries that match the "where" clause are marked to be |
304 | kept. In addition, any new entries added to the database after this | ||
305 | command will be automatically marked with "mark" | ||
306 | """ | ||
307 | return await self.invoke({"gc-mark": {"mark": mark, "where": where}}) | ||
308 | |||
309 | async def gc_sweep(self, mark): | ||
310 | """ | ||
311 | Finishes garbage collection for "mark". All unihash entries that have | ||
312 | not been marked will be deleted. | ||
313 | |||
314 | It is recommended to clean unused outhash entries after running this to | ||
315 | clean up any dangling outhashes | ||
316 | """ | ||
317 | return await self.invoke({"gc-sweep": {"mark": mark}}) | ||
318 | |||
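The two calls are meant to bracket a mark-and-sweep cycle. An illustrative client-side sequence (the mark name and "where" clause are made up):

```python
async def collect_garbage(client):
    # Mark every entry for the given method as live; entries reported
    # while the mark is active are kept automatically.
    await client.gc_mark("gc-2024-01", {"method": "mymethod"})

    status = await client.gc_status()
    print(f"keep={status['keep']} remove={status['remove']} mark={status['mark']}")

    # Delete everything that was not marked, then drop dangling outhashes
    # as the gc_sweep docstring recommends.
    await client.gc_sweep("gc-2024-01")
    await client.clean_unused(0)
```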
319 | |||
320 | class Client(bb.asyncrpc.Client): | ||
321 | def __init__(self, username=None, password=None): | ||
322 | self.username = username | ||
323 | self.password = password | ||
324 | |||
325 | super().__init__() | ||
326 | self._add_methods( | ||
198 | "connect_tcp", | 327 | "connect_tcp", |
199 | "close", | 328 | "connect_websocket", |
200 | "get_unihash", | 329 | "get_unihash", |
330 | "get_unihash_batch", | ||
201 | "report_unihash", | 331 | "report_unihash", |
202 | "report_unihash_equiv", | 332 | "report_unihash_equiv", |
203 | "get_taskhash", | 333 | "get_taskhash", |
334 | "unihash_exists", | ||
335 | "unihash_exists_batch", | ||
336 | "get_outhash", | ||
204 | "get_stats", | 337 | "get_stats", |
205 | "reset_stats", | 338 | "reset_stats", |
206 | "backfill_wait", | 339 | "backfill_wait", |
207 | ): | 340 | "remove", |
208 | downcall = getattr(self.client, call) | 341 | "clean_unused", |
209 | setattr(self, call, self._get_downcall_wrapper(downcall)) | 342 | "auth", |
210 | 343 | "refresh_token", | |
211 | def _get_downcall_wrapper(self, downcall): | 344 | "set_user_perms", |
212 | def wrapper(*args, **kwargs): | 345 | "get_user", |
213 | return self.loop.run_until_complete(downcall(*args, **kwargs)) | 346 | "get_all_users", |
214 | 347 | "new_user", | |
215 | return wrapper | 348 | "delete_user", |
216 | 349 | "become_user", | |
217 | def connect_unix(self, path): | 350 | "get_db_usage", |
218 | # AF_UNIX has path length issues so chdir here to workaround | 351 | "get_db_query_columns", |
219 | cwd = os.getcwd() | 352 | "gc_status", |
220 | try: | 353 | "gc_mark", |
221 | os.chdir(os.path.dirname(path)) | 354 | "gc_sweep", |
222 | self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path))) | 355 | ) |
223 | self.loop.run_until_complete(self.client.connect()) | ||
224 | finally: | ||
225 | os.chdir(cwd) | ||
226 | |||
227 | @property | ||
228 | def max_chunk(self): | ||
229 | return self.client.max_chunk | ||
230 | 356 | ||
231 | @max_chunk.setter | 357 | def _get_async_client(self): |
232 | def max_chunk(self, value): | 358 | return AsyncClient(self.username, self.password) |
233 | self.client.max_chunk = value | ||
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py index a0dc0c170f..68f64f983b 100644 --- a/bitbake/lib/hashserv/server.py +++ b/bitbake/lib/hashserv/server.py | |||
@@ -3,20 +3,51 @@ | |||
3 | # SPDX-License-Identifier: GPL-2.0-only | 3 | # SPDX-License-Identifier: GPL-2.0-only |
4 | # | 4 | # |
5 | 5 | ||
6 | from contextlib import closing, contextmanager | 6 | from datetime import datetime, timedelta |
7 | from datetime import datetime | ||
8 | import asyncio | 7 | import asyncio |
9 | import json | ||
10 | import logging | 8 | import logging |
11 | import math | 9 | import math |
12 | import os | ||
13 | import signal | ||
14 | import socket | ||
15 | import sys | ||
16 | import time | 10 | import time |
17 | from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS | 11 | import os |
12 | import base64 | ||
13 | import hashlib | ||
14 | from . import create_async_client | ||
15 | import bb.asyncrpc | ||
16 | |||
17 | logger = logging.getLogger("hashserv.server") | ||
18 | |||
19 | |||
20 | # This permission only exists to match nothing | ||
21 | NONE_PERM = "@none" | ||
22 | |||
23 | READ_PERM = "@read" | ||
24 | REPORT_PERM = "@report" | ||
25 | DB_ADMIN_PERM = "@db-admin" | ||
26 | USER_ADMIN_PERM = "@user-admin" | ||
27 | ALL_PERM = "@all" | ||
18 | 28 | ||
19 | logger = logging.getLogger('hashserv.server') | 29 | ALL_PERMISSIONS = { |
30 | READ_PERM, | ||
31 | REPORT_PERM, | ||
32 | DB_ADMIN_PERM, | ||
33 | USER_ADMIN_PERM, | ||
34 | ALL_PERM, | ||
35 | } | ||
36 | |||
37 | DEFAULT_ANON_PERMS = ( | ||
38 | READ_PERM, | ||
39 | REPORT_PERM, | ||
40 | DB_ADMIN_PERM, | ||
41 | ) | ||
42 | |||
43 | TOKEN_ALGORITHM = "sha256" | ||
44 | |||
45 | # 48 bytes of random data will result in 64 characters when base64 | ||
46 | # encoded. This number also ensures that the base64 encoding won't have any | ||
47 | # trailing '=' characters. | ||
48 | TOKEN_SIZE = 48 | ||
49 | |||
50 | SALT_SIZE = 8 | ||
20 | 51 | ||
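The TOKEN_SIZE comment can be checked directly: base64 emits four output characters per three input bytes, and padding only appears when the input length is not a multiple of three. A quick verification, using os.urandom here for portability where the server itself uses os.getrandom:

```python
import base64
import os

raw = os.urandom(48)  # TOKEN_SIZE
encoded = base64.b64encode(raw, b"._")  # same altchars as new_token()
assert len(encoded) == 64  # 48 / 3 * 4
assert not encoded.endswith(b"=")  # 48 % 3 == 0, so no padding
```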
21 | 52 | ||
22 | class Measurement(object): | 53 | class Measurement(object): |
@@ -106,522 +137,745 @@ class Stats(object): | |||
106 | return math.sqrt(self.s / (self.num - 1)) | 137 | return math.sqrt(self.s / (self.num - 1)) |
107 | 138 | ||
108 | def todict(self): | 139 | def todict(self): |
109 | return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} | 140 | return { |
110 | 141 | k: getattr(self, k) | |
111 | 142 | for k in ("num", "total_time", "max_time", "average", "stdev") | |
112 | class ClientError(Exception): | ||
113 | pass | ||
114 | |||
115 | class ServerError(Exception): | ||
116 | pass | ||
117 | |||
118 | def insert_task(cursor, data, ignore=False): | ||
119 | keys = sorted(data.keys()) | ||
120 | query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % ( | ||
121 | " OR IGNORE" if ignore else "", | ||
122 | ', '.join(keys), | ||
123 | ', '.join(':' + k for k in keys)) | ||
124 | cursor.execute(query, data) | ||
125 | |||
126 | async def copy_from_upstream(client, db, method, taskhash): | ||
127 | d = await client.get_taskhash(method, taskhash, True) | ||
128 | if d is not None: | ||
129 | # Filter out unknown columns | ||
130 | d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} | ||
131 | keys = sorted(d.keys()) | ||
132 | |||
133 | with closing(db.cursor()) as cursor: | ||
134 | insert_task(cursor, d) | ||
135 | db.commit() | ||
136 | |||
137 | return d | ||
138 | |||
139 | async def copy_outhash_from_upstream(client, db, method, outhash, taskhash): | ||
140 | d = await client.get_outhash(method, outhash, taskhash) | ||
141 | if d is not None: | ||
142 | # Filter out unknown columns | ||
143 | d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} | ||
144 | keys = sorted(d.keys()) | ||
145 | |||
146 | with closing(db.cursor()) as cursor: | ||
147 | insert_task(cursor, d) | ||
148 | db.commit() | ||
149 | |||
150 | return d | ||
151 | |||
152 | class ServerClient(object): | ||
153 | FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | ||
154 | ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | ||
155 | OUTHASH_QUERY = ''' | ||
156 | -- Find tasks with a matching outhash (that is, tasks that | ||
157 | -- are equivalent) | ||
158 | SELECT * FROM tasks_v2 WHERE method=:method AND outhash=:outhash | ||
159 | |||
160 | -- If there is an exact match on the taskhash, return it. | ||
161 | -- Otherwise return the oldest matching outhash of any | ||
162 | -- taskhash | ||
163 | ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END, | ||
164 | created ASC | ||
165 | |||
166 | -- Only return one row | ||
167 | LIMIT 1 | ||
168 | ''' | ||
169 | |||
170 | def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only): | ||
171 | self.reader = reader | ||
172 | self.writer = writer | ||
173 | self.db = db | ||
174 | self.request_stats = request_stats | ||
175 | self.max_chunk = DEFAULT_MAX_CHUNK | ||
176 | self.backfill_queue = backfill_queue | ||
177 | self.upstream = upstream | ||
178 | |||
179 | self.handlers = { | ||
180 | 'get': self.handle_get, | ||
181 | 'get-outhash': self.handle_get_outhash, | ||
182 | 'get-stream': self.handle_get_stream, | ||
183 | 'get-stats': self.handle_get_stats, | ||
184 | 'chunk-stream': self.handle_chunk, | ||
185 | } | 143 | } |
186 | 144 | ||
187 | if not read_only: | ||
188 | self.handlers.update({ | ||
189 | 'report': self.handle_report, | ||
190 | 'report-equiv': self.handle_equivreport, | ||
191 | 'reset-stats': self.handle_reset_stats, | ||
192 | 'backfill-wait': self.handle_backfill_wait, | ||
193 | }) | ||
194 | 145 | ||
195 | async def process_requests(self): | 146 | token_refresh_semaphore = asyncio.Lock() |
196 | if self.upstream is not None: | ||
197 | self.upstream_client = await create_async_client(self.upstream) | ||
198 | else: | ||
199 | self.upstream_client = None | ||
200 | 147 | ||
201 | try: | ||
202 | 148 | ||
149 | async def new_token(): | ||
150 | # Prevent malicious users from using this API to deduce the entropy | ||
151 | # pool on the server and thus be able to guess a token. *All* token | ||
152 | # refresh requests lock the same global semaphore and then sleep for a | ||
153 | # short time. This effectively rate limits the total number of requests | ||
154 | # that can be made across all clients to 10/second, which should be enough | ||
155 | # since you have to be an authenticated user to make the request in the | ||
156 | # first place | ||
157 | async with token_refresh_semaphore: | ||
158 | await asyncio.sleep(0.1) | ||
159 | raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK) | ||
203 | 160 | ||
204 | self.addr = self.writer.get_extra_info('peername') | 161 | return base64.b64encode(raw, b"._").decode("utf-8") |
205 | logger.debug('Client %r connected' % (self.addr,)) | ||
206 | 162 | ||
207 | # Read protocol and version | ||
208 | protocol = await self.reader.readline() | ||
209 | if protocol is None: | ||
210 | return | ||
211 | 163 | ||
212 | (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() | 164 | def new_salt(): |
213 | if proto_name != 'OEHASHEQUIV': | 165 | return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex() |
214 | return | ||
215 | 166 | ||
216 | proto_version = tuple(int(v) for v in proto_version.split('.')) | ||
217 | if proto_version < (1, 0) or proto_version > (1, 1): | ||
218 | return | ||
219 | 167 | ||
220 | # Read headers. Currently, no headers are implemented, so look for | 168 | def hash_token(algo, salt, token): |
221 | # an empty line to signal the end of the headers | 169 | h = hashlib.new(algo) |
222 | while True: | 170 | h.update(salt.encode("utf-8")) |
223 | line = await self.reader.readline() | 171 | h.update(token.encode("utf-8")) |
224 | if line is None: | 172 | return ":".join([algo, salt, h.hexdigest()]) |
225 | return | ||
226 | 173 | ||
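hash_token() stores everything needed to check a token later in the value itself. A hypothetical verification helper (not part of this change) would split the stored string and recompute:

```python
import hmac

def verify_token(stored, presented):
    # stored has the form "algo:salt:hexdigest", so the algorithm and
    # salt can be recovered from it and the hash recomputed.
    algo, salt, _ = stored.split(":")
    candidate = hash_token(algo, salt, presented)
    # hmac.compare_digest avoids leaking the match position via timing
    return hmac.compare_digest(candidate, stored)
```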
227 | line = line.decode('utf-8').rstrip() | ||
228 | if not line: | ||
229 | break | ||
230 | 174 | ||
231 | # Handle messages | 175 | def permissions(*permissions, allow_anon=True, allow_self_service=False): |
232 | while True: | 176 | """ |
233 | d = await self.read_message() | 177 | Function decorator that can be used to decorate an RPC function call and |
234 | if d is None: | 178 | check that the current user's permissions match the required permissions. |
235 | break | ||
236 | await self.dispatch_message(d) | ||
237 | await self.writer.drain() | ||
238 | except ClientError as e: | ||
239 | logger.error(str(e)) | ||
240 | finally: | ||
241 | if self.upstream_client is not None: | ||
242 | await self.upstream_client.close() | ||
243 | 179 | ||
244 | self.writer.close() | 180 | If allow_anon is True, the user will also be allowed to make the RPC call |
181 | if the anonymous user permissions match the permissions. | ||
245 | 182 | ||
246 | async def dispatch_message(self, msg): | 183 | If allow_self_service is True, and the "username" property in the request |
247 | for k in self.handlers.keys(): | 184 | is the currently logged in user, or not specified, the user will also be |
248 | if k in msg: | 185 | allowed to make the request. This allows users to access the normally |
249 | logger.debug('Handling %s' % k) | 186 | privileged API, as long as they are only modifying their own user |
250 | if 'stream' in k: | 187 | properties (e.g. users can be allowed to reset their own token without |
251 | await self.handlers[k](msg[k]) | 188 | @user-admin permissions, but not the token for any other user). |
189 | """ | ||
190 | |||
191 | def wrapper(func): | ||
192 | async def wrap(self, request): | ||
193 | if allow_self_service and self.user is not None: | ||
194 | username = request.get("username", self.user.username) | ||
195 | if username == self.user.username: | ||
196 | request["username"] = self.user.username | ||
197 | return await func(self, request) | ||
198 | |||
199 | if not self.user_has_permissions(*permissions, allow_anon=allow_anon): | ||
200 | if not self.user: | ||
201 | username = "Anonymous user" | ||
202 | user_perms = self.server.anon_perms | ||
252 | else: | 203 | else: |
253 | with self.request_stats.start_sample() as self.request_sample, \ | 204 | username = self.user.username |
254 | self.request_sample.measure(): | 205 | user_perms = self.user.permissions |
255 | await self.handlers[k](msg[k]) | 206 | |
256 | return | 207 | self.logger.info( |
208 | "User %s with permissions %r denied from calling %s. Missing permissions(s) %r", | ||
209 | username, | ||
210 | ", ".join(user_perms), | ||
211 | func.__name__, | ||
212 | ", ".join(permissions), | ||
213 | ) | ||
214 | raise bb.asyncrpc.InvokeError( | ||
215 | f"{username} is not allowed to access permissions(s) {', '.join(permissions)}" | ||
216 | ) | ||
217 | |||
218 | return await func(self, request) | ||
219 | |||
220 | return wrap | ||
221 | |||
222 | return wrapper | ||
223 | |||
224 | |||
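The decorator is then applied per handler, as seen throughout the server class below. A hypothetical handler illustrating the self-service path; the decoration shown here is an example, not necessarily how the real handlers are annotated:

```python
# A user with @user-admin may operate on anyone's account; with
# allow_self_service=True a user lacking that permission can still
# operate on their own.
@permissions(USER_ADMIN_PERM, allow_anon=False, allow_self_service=True)
async def handle_example_refresh(self, request):
    ...
```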
225 | class ServerClient(bb.asyncrpc.AsyncServerConnection): | ||
226 | def __init__(self, socket, server): | ||
227 | super().__init__(socket, "OEHASHEQUIV", server.logger) | ||
228 | self.server = server | ||
229 | self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK | ||
230 | self.user = None | ||
231 | |||
232 | self.handlers.update( | ||
233 | { | ||
234 | "get": self.handle_get, | ||
235 | "get-outhash": self.handle_get_outhash, | ||
236 | "get-stream": self.handle_get_stream, | ||
237 | "exists-stream": self.handle_exists_stream, | ||
238 | "get-stats": self.handle_get_stats, | ||
239 | "get-db-usage": self.handle_get_db_usage, | ||
240 | "get-db-query-columns": self.handle_get_db_query_columns, | ||
241 | # Not always read-only, but internally checks if the server is | ||
242 | # read-only | ||
243 | "report": self.handle_report, | ||
244 | "auth": self.handle_auth, | ||
245 | "get-user": self.handle_get_user, | ||
246 | "get-all-users": self.handle_get_all_users, | ||
247 | "become-user": self.handle_become_user, | ||
248 | } | ||
249 | ) | ||
257 | 250 | ||
258 | raise ClientError("Unrecognized command %r" % msg) | 251 | if not self.server.read_only: |
252 | self.handlers.update( | ||
253 | { | ||
254 | "report-equiv": self.handle_equivreport, | ||
255 | "reset-stats": self.handle_reset_stats, | ||
256 | "backfill-wait": self.handle_backfill_wait, | ||
257 | "remove": self.handle_remove, | ||
258 | "gc-mark": self.handle_gc_mark, | ||
259 | "gc-sweep": self.handle_gc_sweep, | ||
260 | "gc-status": self.handle_gc_status, | ||
261 | "clean-unused": self.handle_clean_unused, | ||
262 | "refresh-token": self.handle_refresh_token, | ||
263 | "set-user-perms": self.handle_set_perms, | ||
264 | "new-user": self.handle_new_user, | ||
265 | "delete-user": self.handle_delete_user, | ||
266 | } | ||
267 | ) | ||
259 | 268 | ||
260 | def write_message(self, msg): | 269 | def raise_no_user_error(self, username): |
261 | for c in chunkify(json.dumps(msg), self.max_chunk): | 270 | raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists") |
262 | self.writer.write(c.encode('utf-8')) | ||
263 | 271 | ||
264 | async def read_message(self): | 272 | def user_has_permissions(self, *permissions, allow_anon=True): |
265 | l = await self.reader.readline() | 273 | permissions = set(permissions) |
266 | if not l: | 274 | if allow_anon: |
267 | return None | 275 | if ALL_PERM in self.server.anon_perms: |
276 | return True | ||
268 | 277 | ||
269 | try: | 278 | if not permissions - self.server.anon_perms: |
270 | message = l.decode('utf-8') | 279 | return True |
271 | 280 | ||
272 | if not message.endswith('\n'): | 281 | if self.user is None: |
273 | return None | 282 | return False |
274 | 283 | ||
275 | return json.loads(message) | 284 | if ALL_PERM in self.user.permissions: |
276 | except (json.JSONDecodeError, UnicodeDecodeError) as e: | 285 | return True |
277 | logger.error('Bad message from client: %r' % message) | ||
278 | raise e | ||
279 | 286 | ||
280 | async def handle_chunk(self, request): | 287 | if not permissions - self.user.permissions: |
281 | lines = [] | 288 | return True |
282 | try: | ||
283 | while True: | ||
284 | l = await self.reader.readline() | ||
285 | l = l.rstrip(b"\n").decode("utf-8") | ||
286 | if not l: | ||
287 | break | ||
288 | lines.append(l) | ||
289 | 289 | ||
290 | msg = json.loads(''.join(lines)) | 290 | return False |
291 | except (json.JSONDecodeError, UnicodeDecodeError) as e: | ||
292 | logger.error('Bad message from client: %r' % message) | ||
293 | raise e | ||
294 | 291 | ||
295 | if 'chunk-stream' in msg: | 292 | def validate_proto_version(self): |
296 | raise ClientError("Nested chunks are not allowed") | 293 | return self.proto_version > (1, 0) and self.proto_version <= (1, 1) |
297 | 294 | ||
298 | await self.dispatch_message(msg) | 295 | async def process_requests(self): |
296 | async with self.server.db_engine.connect(self.logger) as db: | ||
297 | self.db = db | ||
298 | if self.server.upstream is not None: | ||
299 | self.upstream_client = await create_async_client(self.server.upstream) | ||
300 | else: | ||
301 | self.upstream_client = None | ||
299 | 302 | ||
300 | async def handle_get(self, request): | 303 | try: |
301 | method = request['method'] | 304 | await super().process_requests() |
302 | taskhash = request['taskhash'] | 305 | finally: |
306 | if self.upstream_client is not None: | ||
307 | await self.upstream_client.close() | ||
303 | 308 | ||
304 | if request.get('all', False): | 309 | async def dispatch_message(self, msg): |
305 | row = self.query_equivalent(method, taskhash, self.ALL_QUERY) | 310 | for k in self.handlers.keys(): |
306 | else: | 311 | if k in msg: |
307 | row = self.query_equivalent(method, taskhash, self.FAST_QUERY) | 312 | self.logger.debug("Handling %s" % k) |
313 | if "stream" in k: | ||
314 | return await self.handlers[k](msg[k]) | ||
315 | else: | ||
316 | with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure(): | ||
317 | return await self.handlers[k](msg[k]) | ||
308 | 318 | ||
309 | if row is not None: | 319 | raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg) |
310 | logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | 320 | |
311 | d = {k: row[k] for k in row.keys()} | 321 | @permissions(READ_PERM) |
312 | elif self.upstream_client is not None: | 322 | async def handle_get(self, request): |
313 | d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash) | 323 | method = request["method"] |
324 | taskhash = request["taskhash"] | ||
325 | fetch_all = request.get("all", False) | ||
326 | |||
327 | return await self.get_unihash(method, taskhash, fetch_all) | ||
328 | |||
329 | async def get_unihash(self, method, taskhash, fetch_all=False): | ||
330 | d = None | ||
331 | |||
332 | if fetch_all: | ||
333 | row = await self.db.get_unihash_by_taskhash_full(method, taskhash) | ||
334 | if row is not None: | ||
335 | d = {k: row[k] for k in row.keys()} | ||
336 | elif self.upstream_client is not None: | ||
337 | d = await self.upstream_client.get_taskhash(method, taskhash, True) | ||
338 | await self.update_unified(d) | ||
314 | else: | 339 | else: |
315 | d = None | 340 | row = await self.db.get_equivalent(method, taskhash) |
341 | |||
342 | if row is not None: | ||
343 | d = {k: row[k] for k in row.keys()} | ||
344 | elif self.upstream_client is not None: | ||
345 | d = await self.upstream_client.get_taskhash(method, taskhash) | ||
346 | await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"]) | ||
316 | 347 | ||
317 | self.write_message(d) | 348 | return d |
318 | 349 | ||
350 | @permissions(READ_PERM) | ||
319 | async def handle_get_outhash(self, request): | 351 | async def handle_get_outhash(self, request): |
320 | with closing(self.db.cursor()) as cursor: | 352 | method = request["method"] |
321 | cursor.execute(self.OUTHASH_QUERY, | 353 | outhash = request["outhash"] |
322 | {k: request[k] for k in ('method', 'outhash', 'taskhash')}) | 354 | taskhash = request["taskhash"] |
355 | with_unihash = request.get("with_unihash", True) | ||
323 | 356 | ||
324 | row = cursor.fetchone() | 357 | return await self.get_outhash(method, outhash, taskhash, with_unihash) |
358 | |||
359 | async def get_outhash(self, method, outhash, taskhash, with_unihash=True): | ||
360 | d = None | ||
361 | if with_unihash: | ||
362 | row = await self.db.get_unihash_by_outhash(method, outhash) | ||
363 | else: | ||
364 | row = await self.db.get_outhash(method, outhash) | ||
325 | 365 | ||
326 | if row is not None: | 366 | if row is not None: |
327 | logger.debug('Found equivalent outhash %s -> %s', (row['outhash'], row['unihash'])) | ||
328 | d = {k: row[k] for k in row.keys()} | 367 | d = {k: row[k] for k in row.keys()} |
329 | else: | 368 | elif self.upstream_client is not None: |
330 | d = None | 369 | d = await self.upstream_client.get_outhash(method, outhash, taskhash) |
370 | await self.update_unified(d) | ||
331 | 371 | ||
332 | self.write_message(d) | 372 | return d |
333 | 373 | ||
334 | async def handle_get_stream(self, request): | 374 | async def update_unified(self, data): |
335 | self.write_message('ok') | 375 | if data is None: |
376 | return | ||
377 | |||
378 | await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"]) | ||
379 | await self.db.insert_outhash(data) | ||
380 | |||
381 | async def _stream_handler(self, handler): | ||
382 | await self.socket.send_message("ok") | ||
336 | 383 | ||
337 | while True: | 384 | while True: |
338 | upstream = None | 385 | upstream = None |
339 | 386 | ||
340 | l = await self.reader.readline() | 387 | l = await self.socket.recv() |
341 | if not l: | 388 | if not l: |
342 | return | 389 | break |
343 | 390 | ||
344 | try: | 391 | try: |
345 | # This inner loop is very sensitive and must be as fast as | 392 | # This inner loop is very sensitive and must be as fast as |
346 | # possible (which is why the request sample is handled manually | 393 | # possible (which is why the request sample is handled manually |
347 | # instead of using 'with', and also why logging statements are | 394 | # instead of using 'with', and also why logging statements are |
348 | # commented out). | 395 | # commented out). |
349 | self.request_sample = self.request_stats.start_sample() | 396 | self.request_sample = self.server.request_stats.start_sample() |
350 | request_measure = self.request_sample.measure() | 397 | request_measure = self.request_sample.measure() |
351 | request_measure.start() | 398 | request_measure.start() |
352 | 399 | ||
353 | l = l.decode('utf-8').rstrip() | 400 | if l == "END": |
354 | if l == 'END': | 401 | break |
355 | self.writer.write('ok\n'.encode('utf-8')) | ||
356 | return | ||
357 | |||
358 | (method, taskhash) = l.split() | ||
359 | #logger.debug('Looking up %s %s' % (method, taskhash)) | ||
360 | row = self.query_equivalent(method, taskhash, self.FAST_QUERY) | ||
361 | if row is not None: | ||
362 | msg = ('%s\n' % row['unihash']).encode('utf-8') | ||
363 | #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | ||
364 | elif self.upstream_client is not None: | ||
365 | upstream = await self.upstream_client.get_unihash(method, taskhash) | ||
366 | if upstream: | ||
367 | msg = ("%s\n" % upstream).encode("utf-8") | ||
368 | else: | ||
369 | msg = "\n".encode("utf-8") | ||
370 | else: | ||
371 | msg = '\n'.encode('utf-8') | ||
372 | 402 | ||
373 | self.writer.write(msg) | 403 | msg = await handler(l) |
404 | await self.socket.send(msg) | ||
374 | finally: | 405 | finally: |
375 | request_measure.end() | 406 | request_measure.end() |
376 | self.request_sample.end() | 407 | self.request_sample.end() |
377 | 408 | ||
378 | await self.writer.drain() | 409 | await self.socket.send("ok") |
410 | return self.NO_RESPONSE | ||
379 | 411 | ||
380 | # Post to the backfill queue after writing the result to minimize | 412 | @permissions(READ_PERM) |
381 | # the turn around time on a request | 413 | async def handle_get_stream(self, request): |
382 | if upstream is not None: | 414 | async def handler(l): |
383 | await self.backfill_queue.put((method, taskhash)) | 415 | (method, taskhash) = l.split() |
416 | # self.logger.debug('Looking up %s %s' % (method, taskhash)) | ||
417 | row = await self.db.get_equivalent(method, taskhash) | ||
384 | 418 | ||
385 | async def handle_report(self, data): | 419 | if row is not None: |
386 | with closing(self.db.cursor()) as cursor: | 420 | # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) |
387 | cursor.execute(self.OUTHASH_QUERY, | 421 | return row["unihash"] |
388 | {k: data[k] for k in ('method', 'outhash', 'taskhash')}) | ||
389 | |||
390 | row = cursor.fetchone() | ||
391 | |||
392 | if row is None and self.upstream_client: | ||
393 | # Try upstream | ||
394 | row = await copy_outhash_from_upstream(self.upstream_client, | ||
395 | self.db, | ||
396 | data['method'], | ||
397 | data['outhash'], | ||
398 | data['taskhash']) | ||
399 | |||
400 | # If no matching outhash was found, or one *was* found but it | ||
401 | # wasn't an exact match on the taskhash, a new entry for this | ||
402 | # taskhash should be added | ||
403 | if row is None or row['taskhash'] != data['taskhash']: | ||
404 | # If a row matching the outhash was found, the unihash for | ||
405 | # the new taskhash should be the same as that one. | ||
406 | # Otherwise the caller provided unihash is used. | ||
407 | unihash = data['unihash'] | ||
408 | if row is not None: | ||
409 | unihash = row['unihash'] | ||
410 | |||
411 | insert_data = { | ||
412 | 'method': data['method'], | ||
413 | 'outhash': data['outhash'], | ||
414 | 'taskhash': data['taskhash'], | ||
415 | 'unihash': unihash, | ||
416 | 'created': datetime.now() | ||
417 | } | ||
418 | 422 | ||
419 | for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): | 423 | if self.upstream_client is not None: |
420 | if k in data: | 424 | upstream = await self.upstream_client.get_unihash(method, taskhash) |
421 | insert_data[k] = data[k] | 425 | if upstream: |
426 | await self.server.backfill_queue.put((method, taskhash)) | ||
427 | return upstream | ||
422 | 428 | ||
423 | insert_task(cursor, insert_data) | 429 | return "" |
424 | self.db.commit() | ||
425 | 430 | ||
426 | logger.info('Adding taskhash %s with unihash %s', | 431 | return await self._stream_handler(handler) |
427 | data['taskhash'], unihash) | ||
428 | 432 | ||
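On the wire, the stream modes reduce each query to one line in each direction. A sketch of a get-stream exchange (values illustrative):

```python
# client -> {"get-stream": null}      enter stream mode (JSON request)
# server -> "ok"
# client -> "mymethod 1234abcd"       one "<method> <taskhash>" per line
# server -> "fedc9876"                the unihash, or an empty line on a miss
# client -> "END"                     leave stream mode
# server -> "ok"
```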
429 | d = { | 433 | @permissions(READ_PERM) |
430 | 'taskhash': data['taskhash'], | 434 | async def handle_exists_stream(self, request): |
431 | 'method': data['method'], | 435 | async def handler(l): |
432 | 'unihash': unihash | 436 | if await self.db.unihash_exists(l): |
433 | } | 437 | return "true" |
434 | else: | ||
435 | d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} | ||
436 | 438 | ||
437 | self.write_message(d) | 439 | if self.upstream_client is not None: |
440 | if await self.upstream_client.unihash_exists(l): | ||
441 | return "true" | ||
438 | 442 | ||
439 | async def handle_equivreport(self, data): | 443 | return "false" |
440 | with closing(self.db.cursor()) as cursor: | ||
441 | insert_data = { | ||
442 | 'method': data['method'], | ||
443 | 'outhash': "", | ||
444 | 'taskhash': data['taskhash'], | ||
445 | 'unihash': data['unihash'], | ||
446 | 'created': datetime.now() | ||
447 | } | ||
448 | 444 | ||
449 | for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): | 445 | return await self._stream_handler(handler) |
450 | if k in data: | ||
451 | insert_data[k] = data[k] | ||
452 | 446 | ||
453 | insert_task(cursor, insert_data, ignore=True) | 447 | async def report_readonly(self, data): |
454 | self.db.commit() | 448 | method = data["method"] |
449 | outhash = data["outhash"] | ||
450 | taskhash = data["taskhash"] | ||
455 | 451 | ||
456 | # Fetch the unihash that will be reported for the taskhash. If the | 452 | info = await self.get_outhash(method, outhash, taskhash) |
457 | # unihash matches, it means this row was inserted (or the mapping | 453 | if info: |
458 | # was already valid) | 454 | unihash = info["unihash"] |
459 | row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY) | 455 | else: |
456 | unihash = data["unihash"] | ||
460 | 457 | ||
461 | if row['unihash'] == data['unihash']: | 458 | return { |
462 | logger.info('Adding taskhash equivalence for %s with unihash %s', | 459 | "taskhash": taskhash, |
463 | data['taskhash'], row['unihash']) | 460 | "method": method, |
461 | "unihash": unihash, | ||
462 | } | ||
464 | 463 | ||
465 | d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} | 464 | # Since this can be called either read only or to report, the check to |
465 | # report is made inside the function | ||
466 | @permissions(READ_PERM) | ||
467 | async def handle_report(self, data): | ||
468 | if self.server.read_only or not self.user_has_permissions(REPORT_PERM): | ||
469 | return await self.report_readonly(data) | ||
470 | |||
471 | outhash_data = { | ||
472 | "method": data["method"], | ||
473 | "outhash": data["outhash"], | ||
474 | "taskhash": data["taskhash"], | ||
475 | "created": datetime.now(), | ||
476 | } | ||
466 | 477 | ||
467 | self.write_message(d) | 478 | for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"): |
479 | if k in data: | ||
480 | outhash_data[k] = data[k] | ||
468 | 481 | ||
482 | if self.user: | ||
483 | outhash_data["owner"] = self.user.username | ||
469 | 484 | ||
470 | async def handle_get_stats(self, request): | 485 | # Insert the new entry, unless it already exists |
471 | d = { | 486 | if await self.db.insert_outhash(outhash_data): |
472 | 'requests': self.request_stats.todict(), | 487 | # If this row is new, check if it is equivalent to another |
488 | # output hash | ||
489 | row = await self.db.get_equivalent_for_outhash( | ||
490 | data["method"], data["outhash"], data["taskhash"] | ||
491 | ) | ||
492 | |||
493 | if row is not None: | ||
494 | # A matching output hash was found. Set our taskhash to the | ||
495 | # same unihash since they are equivalent | ||
496 | unihash = row["unihash"] | ||
497 | else: | ||
498 | # No matching output hash was found. This is probably the | ||
499 | # first outhash to be added. | ||
500 | unihash = data["unihash"] | ||
501 | |||
502 | # Query upstream to see if it has a unihash we can use | ||
503 | if self.upstream_client is not None: | ||
504 | upstream_data = await self.upstream_client.get_outhash( | ||
505 | data["method"], data["outhash"], data["taskhash"] | ||
506 | ) | ||
507 | if upstream_data is not None: | ||
508 | unihash = upstream_data["unihash"] | ||
509 | |||
510 | await self.db.insert_unihash(data["method"], data["taskhash"], unihash) | ||
511 | |||
512 | unihash_data = await self.get_unihash(data["method"], data["taskhash"]) | ||
513 | if unihash_data is not None: | ||
514 | unihash = unihash_data["unihash"] | ||
515 | else: | ||
516 | unihash = data["unihash"] | ||
517 | |||
518 | return { | ||
519 | "taskhash": data["taskhash"], | ||
520 | "method": data["method"], | ||
521 | "unihash": unihash, | ||
473 | } | 522 | } |
474 | 523 | ||
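The effect of this logic is that two different taskhashes reporting the same outhash are unified. A sketch with placeholder hashes, using the synchronous client:

```python
r1 = client.report_unihash("taskhash-A", "mymethod", "outhash-X", "unihash-A")
r2 = client.report_unihash("taskhash-B", "mymethod", "outhash-X", "unihash-B")

# Both builds produced outhash-X, so the server decides the two taskhashes
# are equivalent and hands the second reporter the first unihash.
assert r1["unihash"] == "unihash-A"
assert r2["unihash"] == "unihash-A"
```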
475 | self.write_message(d) | 524 | @permissions(READ_PERM, REPORT_PERM) |
525 | async def handle_equivreport(self, data): | ||
526 | await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"]) | ||
527 | |||
528 | # Fetch the unihash that will be reported for the taskhash. If the | ||
529 | # unihash matches, it means this row was inserted (or the mapping | ||
530 | # was already valid) | ||
531 | row = await self.db.get_equivalent(data["method"], data["taskhash"]) | ||
532 | |||
533 | if row["unihash"] == data["unihash"]: | ||
534 | self.logger.info( | ||
535 | "Adding taskhash equivalence for %s with unihash %s", | ||
536 | data["taskhash"], | ||
537 | row["unihash"], | ||
538 | ) | ||
539 | |||
540 | return {k: row[k] for k in ("taskhash", "method", "unihash")} | ||
476 | 541 | ||
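Taken together, handle_report and handle_equivreport implement the two write paths a build worker uses. As a hedged sketch, assuming the companion client in hashserv/client.py exposes report_unihash() and report_unihash_equiv() wrappers for these RPCs, and with taskhash/outhash/unihash as placeholder hex strings:

    from hashserv import create_client

    client = create_client("unix:///tmp/hashserv.sock")  # illustrative address

    # "report": record the output hash and learn which unihash to use. If an
    # older row with the same outhash already exists, its unihash wins.
    result = client.report_unihash(taskhash, "TestMethod", outhash, unihash)
    final_unihash = result["unihash"]

    # "report-equiv": directly assert a taskhash -> unihash mapping, e.g. when
    # relaying equivalences discovered elsewhere.
    client.report_unihash_equiv(taskhash, "TestMethod", unihash)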
542 | @permissions(READ_PERM) | ||
543 | async def handle_get_stats(self, request): | ||
544 | return { | ||
545 | "requests": self.server.request_stats.todict(), | ||
546 | } | ||
547 | |||
548 | @permissions(DB_ADMIN_PERM) | ||
477 | async def handle_reset_stats(self, request): | 549 | async def handle_reset_stats(self, request): |
478 | d = { | 550 | d = { |
479 | 'requests': self.request_stats.todict(), | 551 | "requests": self.server.request_stats.todict(), |
480 | } | 552 | } |
481 | 553 | ||
482 | self.request_stats.reset() | 554 | self.server.request_stats.reset() |
483 | self.write_message(d) | 555 | return d |
484 | 556 | ||
557 | @permissions(READ_PERM) | ||
485 | async def handle_backfill_wait(self, request): | 558 | async def handle_backfill_wait(self, request): |
486 | d = { | 559 | d = { |
487 | 'tasks': self.backfill_queue.qsize(), | 560 | "tasks": self.server.backfill_queue.qsize(), |
488 | } | 561 | } |
489 | await self.backfill_queue.join() | 562 | await self.server.backfill_queue.join() |
490 | self.write_message(d) | 563 | return d |
564 | |||
565 | @permissions(DB_ADMIN_PERM) | ||
566 | async def handle_remove(self, request): | ||
567 | condition = request["where"] | ||
568 | if not isinstance(condition, dict): | ||
569 | raise TypeError("Bad condition type %s" % type(condition)) | ||
570 | |||
571 | return {"count": await self.db.remove(condition)} | ||
572 | |||
573 | @permissions(DB_ADMIN_PERM) | ||
574 | async def handle_gc_mark(self, request): | ||
575 | condition = request["where"] | ||
576 | mark = request["mark"] | ||
577 | |||
578 | if not isinstance(condition, dict): | ||
579 | raise TypeError("Bad condition type %s" % type(condition)) | ||
580 | |||
581 | if not isinstance(mark, str): | ||
582 | raise TypeError("Bad mark type %s" % type(mark)) | ||
583 | |||
584 | return {"count": await self.db.gc_mark(mark, condition)} | ||
585 | |||
586 | @permissions(DB_ADMIN_PERM) | ||
587 | async def handle_gc_sweep(self, request): | ||
588 | mark = request["mark"] | ||
589 | |||
590 | if not isinstance(mark, str): | ||
591 | raise TypeError("Bad mark type %s" % type(mark)) | ||
592 | |||
593 | current_mark = await self.db.get_current_gc_mark() | ||
594 | |||
595 | if not current_mark or mark != current_mark: | ||
596 | raise bb.asyncrpc.InvokeError( | ||
597 | f"'{mark}' is not the current mark. Refusing to sweep" | ||
598 | ) | ||
599 | |||
600 | count = await self.db.gc_sweep() | ||
601 | |||
602 | return {"count": count} | ||
603 | |||
604 | @permissions(DB_ADMIN_PERM) | ||
605 | async def handle_gc_status(self, request): | ||
606 | (keep_rows, remove_rows, current_mark) = await self.db.gc_status() | ||
607 | return { | ||
608 | "keep": keep_rows, | ||
609 | "remove": remove_rows, | ||
610 | "mark": current_mark, | ||
611 | } | ||
612 | |||
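The three handlers above form a mark-and-sweep cycle: gc-mark tags the rows to keep, gc-status reports the keep/remove split, and gc-sweep deletes everything not carrying the current mark. A rough client-side sketch (the gc_mark/gc_status/gc_sweep wrapper names are assumptions):

    import uuid

    mark = str(uuid.uuid4())

    # Tag the unihash rows that should survive; real callers mark in batches,
    # selecting rows by column values in the "where" dict.
    client.gc_mark(mark, {"unihash": wanted_unihash})

    status = client.gc_status()
    print(status["keep"], status["remove"], status["mark"])

    # Only succeeds if "mark" is still the current mark recorded in the config
    # table; every row without it is then deleted.
    client.gc_sweep(mark)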
613 | @permissions(DB_ADMIN_PERM) | ||
614 | async def handle_clean_unused(self, request): | ||
615 | max_age = request["max_age_seconds"] | ||
616 | oldest = datetime.now() - timedelta(seconds=max_age) | ||
617 | return {"count": await self.db.clean_unused(oldest)} | ||
618 | |||
619 | @permissions(DB_ADMIN_PERM) | ||
620 | async def handle_get_db_usage(self, request): | ||
621 | return {"usage": await self.db.get_usage()} | ||
622 | |||
623 | @permissions(DB_ADMIN_PERM) | ||
624 | async def handle_get_db_query_columns(self, request): | ||
625 | return {"columns": await self.db.get_query_columns()} | ||
626 | |||
627 | # The authentication API is always allowed | ||
628 | async def handle_auth(self, request): | ||
629 | username = str(request["username"]) | ||
630 | token = str(request["token"]) | ||
631 | |||
632 | async def fail_auth(): | ||
633 | nonlocal username | ||
634 | # Rate limit bad login attempts | ||
635 | await asyncio.sleep(1) | ||
636 | raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}") | ||
637 | |||
638 | user, db_token = await self.db.lookup_user_token(username) | ||
639 | |||
640 | if not user or not db_token: | ||
641 | await fail_auth() | ||
491 | 642 | ||
492 | def query_equivalent(self, method, taskhash, query): | ||
493 | # This is part of the inner loop and must be as fast as possible | ||
494 | try: | 643 | try: |
495 | cursor = self.db.cursor() | 644 | algo, salt, _ = db_token.split(":") |
496 | cursor.execute(query, {'method': method, 'taskhash': taskhash}) | 645 | except ValueError: |
497 | return cursor.fetchone() | 646 | await fail_auth() |
498 | except: | ||
499 | cursor.close() | ||
500 | 647 | ||
648 | if hash_token(algo, salt, token) != db_token: | ||
649 | await fail_auth() | ||
501 | 650 | ||
502 | class Server(object): | 651 | self.user = user |
503 | def __init__(self, db, loop=None, upstream=None, read_only=False): | ||
504 | if upstream and read_only: | ||
505 | raise ServerError("Read-only hashserv cannot pull from an upstream server") | ||
506 | 652 | ||
507 | self.request_stats = Stats() | 653 | self.logger.info("Authenticated as %s", username) |
508 | self.db = db | ||
509 | 654 | ||
510 | if loop is None: | 655 | return { |
511 | self.loop = asyncio.new_event_loop() | 656 | "result": True, |
512 | self.close_loop = True | 657 | "username": self.user.username, |
513 | else: | 658 | "permissions": sorted(list(self.user.permissions)), |
514 | self.loop = loop | 659 | } |
515 | self.close_loop = False | ||
516 | 660 | ||
517 | self.upstream = upstream | 661 | @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) |
518 | self.read_only = read_only | 662 | async def handle_refresh_token(self, request): |
663 | username = str(request["username"]) | ||
519 | 664 | ||
520 | self._cleanup_socket = None | 665 | token = await new_token() |
521 | 666 | ||
522 | def start_tcp_server(self, host, port): | 667 | updated = await self.db.set_user_token( |
523 | self.server = self.loop.run_until_complete( | 668 | username, |
524 | asyncio.start_server(self.handle_client, host, port, loop=self.loop) | 669 | hash_token(TOKEN_ALGORITHM, new_salt(), token), |
525 | ) | 670 | ) |
671 | if not updated: | ||
672 | self.raise_no_user_error(username) | ||
526 | 673 | ||
527 | for s in self.server.sockets: | 674 | return {"username": username, "token": token} |
528 | logger.info('Listening on %r' % (s.getsockname(),)) | ||
529 | # Newer python does this automatically. Do it manually here for | ||
530 | # maximum compatibility | ||
531 | s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) | ||
532 | s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) | ||
533 | |||
534 | name = self.server.sockets[0].getsockname() | ||
535 | if self.server.sockets[0].family == socket.AF_INET6: | ||
536 | self.address = "[%s]:%d" % (name[0], name[1]) | ||
537 | else: | ||
538 | self.address = "%s:%d" % (name[0], name[1]) | ||
539 | 675 | ||
540 | def start_unix_server(self, path): | 676 | def get_perm_arg(self, arg): |
541 | def cleanup(): | 677 | if not isinstance(arg, list): |
542 | os.unlink(path) | 678 | raise bb.asyncrpc.InvokeError("Unexpected type for permissions") |
543 | 679 | ||
544 | cwd = os.getcwd() | 680 | arg = set(arg) |
545 | try: | 681 | try: |
546 | # Work around path length limits in AF_UNIX | 682 | arg.remove(NONE_PERM) |
547 | os.chdir(os.path.dirname(path)) | 683 | except KeyError: |
548 | self.server = self.loop.run_until_complete( | 684 | pass |
549 | asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop) | 685 | |
686 | unknown_perms = arg - ALL_PERMISSIONS | ||
687 | if unknown_perms: | ||
688 | raise bb.asyncrpc.InvokeError( | ||
689 | "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms))) | ||
550 | ) | 690 | ) |
551 | finally: | ||
552 | os.chdir(cwd) | ||
553 | 691 | ||
554 | logger.info('Listening on %r' % path) | 692 | return sorted(list(arg)) |
555 | 693 | ||
556 | self._cleanup_socket = cleanup | 694 | def return_perms(self, permissions): |
557 | self.address = "unix://%s" % os.path.abspath(path) | 695 | if ALL_PERM in permissions: |
696 | return sorted(list(ALL_PERMISSIONS)) | ||
697 | return sorted(list(permissions)) | ||
558 | 698 | ||
559 | async def handle_client(self, reader, writer): | 699 | @permissions(USER_ADMIN_PERM, allow_anon=False) |
560 | # writer.transport.set_write_buffer_limits(0) | 700 | async def handle_set_perms(self, request): |
561 | try: | 701 | username = str(request["username"]) |
562 | client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only) | 702 | permissions = self.get_perm_arg(request["permissions"]) |
563 | await client.process_requests() | ||
564 | except Exception as e: | ||
565 | import traceback | ||
566 | logger.error('Error from client: %s' % str(e), exc_info=True) | ||
567 | traceback.print_exc() | ||
568 | writer.close() | ||
569 | logger.info('Client disconnected') | ||
570 | |||
571 | @contextmanager | ||
572 | def _backfill_worker(self): | ||
573 | async def backfill_worker_task(): | ||
574 | client = await create_async_client(self.upstream) | ||
575 | try: | ||
576 | while True: | ||
577 | item = await self.backfill_queue.get() | ||
578 | if item is None: | ||
579 | self.backfill_queue.task_done() | ||
580 | break | ||
581 | method, taskhash = item | ||
582 | await copy_from_upstream(client, self.db, method, taskhash) | ||
583 | self.backfill_queue.task_done() | ||
584 | finally: | ||
585 | await client.close() | ||
586 | 703 | ||
587 | async def join_worker(worker): | 704 | if not await self.db.set_user_perms(username, permissions): |
588 | await self.backfill_queue.put(None) | 705 | self.raise_no_user_error(username) |
589 | await worker | ||
590 | 706 | ||
591 | if self.upstream is not None: | 707 | return { |
592 | worker = asyncio.ensure_future(backfill_worker_task()) | 708 | "username": username, |
593 | try: | 709 | "permissions": self.return_perms(permissions), |
594 | yield | 710 | } |
595 | finally: | ||
596 | self.loop.run_until_complete(join_worker(worker)) | ||
597 | else: | ||
598 | yield | ||
599 | 711 | ||
600 | def serve_forever(self): | 712 | @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) |
601 | def signal_handler(): | 713 | async def handle_get_user(self, request): |
602 | self.loop.stop() | 714 | username = str(request["username"]) |
603 | 715 | ||
604 | asyncio.set_event_loop(self.loop) | 716 | user = await self.db.lookup_user(username) |
605 | try: | 717 | if user is None: |
606 | self.backfill_queue = asyncio.Queue() | 718 | return None |
719 | |||
720 | return { | ||
721 | "username": user.username, | ||
722 | "permissions": self.return_perms(user.permissions), | ||
723 | } | ||
724 | |||
725 | @permissions(USER_ADMIN_PERM, allow_anon=False) | ||
726 | async def handle_get_all_users(self, request): | ||
727 | users = await self.db.get_all_users() | ||
728 | return { | ||
729 | "users": [ | ||
730 | { | ||
731 | "username": u.username, | ||
732 | "permissions": self.return_perms(u.permissions), | ||
733 | } | ||
734 | for u in users | ||
735 | ] | ||
736 | } | ||
737 | |||
738 | @permissions(USER_ADMIN_PERM, allow_anon=False) | ||
739 | async def handle_new_user(self, request): | ||
740 | username = str(request["username"]) | ||
741 | permissions = self.get_perm_arg(request["permissions"]) | ||
742 | |||
743 | token = await new_token() | ||
744 | |||
745 | inserted = await self.db.new_user( | ||
746 | username, | ||
747 | permissions, | ||
748 | hash_token(TOKEN_ALGORITHM, new_salt(), token), | ||
749 | ) | ||
750 | if not inserted: | ||
751 | raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'") | ||
752 | |||
753 | return { | ||
754 | "username": username, | ||
755 | "permissions": self.return_perms(permissions), | ||
756 | "token": token, | ||
757 | } | ||
758 | |||
759 | @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) | ||
760 | async def handle_delete_user(self, request): | ||
761 | username = str(request["username"]) | ||
762 | |||
763 | if not await self.db.delete_user(username): | ||
764 | self.raise_no_user_error(username) | ||
765 | |||
766 | return {"username": username} | ||
607 | 767 | ||
608 | self.loop.add_signal_handler(signal.SIGTERM, signal_handler) | 768 | @permissions(USER_ADMIN_PERM, allow_anon=False) |
769 | async def handle_become_user(self, request): | ||
770 | username = str(request["username"]) | ||
609 | 771 | ||
610 | with self._backfill_worker(): | 772 | user = await self.db.lookup_user(username) |
611 | try: | 773 | if user is None: |
612 | self.loop.run_forever() | 774 | raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist") |
613 | except KeyboardInterrupt: | ||
614 | pass | ||
615 | 775 | ||
616 | self.server.close() | 776 | self.user = user |
777 | |||
778 | self.logger.info("Became user %s", username) | ||
779 | |||
780 | return { | ||
781 | "username": self.user.username, | ||
782 | "permissions": self.return_perms(self.user.permissions), | ||
783 | } | ||
784 | |||
785 | |||
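The authentication and user-administration handlers above are intended to be driven from an admin session. A hedged end-to-end sketch (the client wrapper names and the "@read"/"@report" permission strings are assumptions; tokens are placeholders):

    client.auth("admin", admin_token)                  # requires user-admin permission

    created = client.new_user("builder", ["@read", "@report"])
    print(created["token"])                            # only revealed at creation time

    client.set_user_perms("builder", ["@read"])        # tighten permissions later
    rotated = client.refresh_token("builder")          # the old token stops working
    client.delete_user("builder")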
786 | class Server(bb.asyncrpc.AsyncServer): | ||
787 | def __init__( | ||
788 | self, | ||
789 | db_engine, | ||
790 | upstream=None, | ||
791 | read_only=False, | ||
792 | anon_perms=DEFAULT_ANON_PERMS, | ||
793 | admin_username=None, | ||
794 | admin_password=None, | ||
795 | ): | ||
796 | if upstream and read_only: | ||
797 | raise bb.asyncrpc.ServerError( | ||
798 | "Read-only hashserv cannot pull from an upstream server" | ||
799 | ) | ||
800 | |||
801 | disallowed_perms = set(anon_perms) - set( | ||
802 | [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM] | ||
803 | ) | ||
804 | |||
805 | if disallowed_perms: | ||
806 | raise bb.asyncrpc.ServerError( | ||
807 | f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users" | ||
808 | ) | ||
617 | 809 | ||
618 | self.loop.run_until_complete(self.server.wait_closed()) | 810 | super().__init__(logger) |
619 | logger.info('Server shutting down') | ||
620 | finally: | ||
621 | if self.close_loop: | ||
622 | if sys.version_info >= (3, 6): | ||
623 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | ||
624 | self.loop.close() | ||
625 | 811 | ||
626 | if self._cleanup_socket is not None: | 812 | self.request_stats = Stats() |
627 | self._cleanup_socket() | 813 | self.db_engine = db_engine |
814 | self.upstream = upstream | ||
815 | self.read_only = read_only | ||
816 | self.backfill_queue = None | ||
817 | self.anon_perms = set(anon_perms) | ||
818 | self.admin_username = admin_username | ||
819 | self.admin_password = admin_password | ||
820 | |||
821 | self.logger.info( | ||
822 | "Anonymous user permissions are: %s", ", ".join(self.anon_perms) | ||
823 | ) | ||
824 | |||
825 | def accept_client(self, socket): | ||
826 | return ServerClient(socket, self) | ||
827 | |||
828 | async def create_admin_user(self): | ||
829 | admin_permissions = (ALL_PERM,) | ||
830 | async with self.db_engine.connect(self.logger) as db: | ||
831 | added = await db.new_user( | ||
832 | self.admin_username, | ||
833 | admin_permissions, | ||
834 | hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password), | ||
835 | ) | ||
836 | if added: | ||
837 | self.logger.info("Created admin user '%s'", self.admin_username) | ||
838 | else: | ||
839 | await db.set_user_perms( | ||
840 | self.admin_username, | ||
841 | admin_permissions, | ||
842 | ) | ||
843 | await db.set_user_token( | ||
844 | self.admin_username, | ||
845 | hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password), | ||
846 | ) | ||
847 | self.logger.info("Admin user '%s' updated", self.admin_username) | ||
848 | |||
849 | async def backfill_worker_task(self): | ||
850 | async with await create_async_client( | ||
851 | self.upstream | ||
852 | ) as client, self.db_engine.connect(self.logger) as db: | ||
853 | while True: | ||
854 | item = await self.backfill_queue.get() | ||
855 | if item is None: | ||
856 | self.backfill_queue.task_done() | ||
857 | break | ||
858 | |||
859 | method, taskhash = item | ||
860 | d = await client.get_taskhash(method, taskhash) | ||
861 | if d is not None: | ||
862 | await db.insert_unihash(d["method"], d["taskhash"], d["unihash"]) | ||
863 | self.backfill_queue.task_done() | ||
864 | |||
865 | def start(self): | ||
866 | tasks = super().start() | ||
867 | if self.upstream: | ||
868 | self.backfill_queue = asyncio.Queue() | ||
869 | tasks += [self.backfill_worker_task()] | ||
870 | |||
871 | self.loop.run_until_complete(self.db_engine.create()) | ||
872 | |||
873 | if self.admin_username: | ||
874 | self.loop.run_until_complete(self.create_admin_user()) | ||
875 | |||
876 | return tasks | ||
877 | |||
878 | async def stop(self): | ||
879 | if self.backfill_queue is not None: | ||
880 | await self.backfill_queue.put(None) | ||
881 | await super().stop() | ||
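For completeness, a minimal way to stand this Server up outside of BitBake, mirroring what the tests at the end of this patch do (the create_server() factory and serve_as_process() come from elsewhere in the series, so treat the exact signatures as assumptions):

    from hashserv import create_server

    server = create_server(
        "unix:///tmp/hashserv.sock",        # illustrative address
        "/tmp/hashserv.sqlite",             # a plain path selects the sqlite backend
        anon_perms=["@read"],               # permission strings are assumptions
        admin_username="admin",             # placeholder credentials
        admin_password="password",
    )
    server.serve_as_process()               # or run it in the foreground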
diff --git a/bitbake/lib/hashserv/sqlalchemy.py b/bitbake/lib/hashserv/sqlalchemy.py new file mode 100644 index 0000000000..f7b0226a7a --- /dev/null +++ b/bitbake/lib/hashserv/sqlalchemy.py | |||
@@ -0,0 +1,598 @@ | |||
1 | #! /usr/bin/env python3 | ||
2 | # | ||
3 | # Copyright (C) 2023 Garmin Ltd. | ||
4 | # | ||
5 | # SPDX-License-Identifier: GPL-2.0-only | ||
6 | # | ||
7 | |||
8 | import logging | ||
9 | from datetime import datetime | ||
10 | from . import User | ||
11 | |||
12 | from sqlalchemy.ext.asyncio import create_async_engine | ||
13 | from sqlalchemy.pool import NullPool | ||
14 | from sqlalchemy import ( | ||
15 | MetaData, | ||
16 | Column, | ||
17 | Table, | ||
18 | Text, | ||
19 | Integer, | ||
20 | UniqueConstraint, | ||
21 | DateTime, | ||
22 | Index, | ||
23 | select, | ||
24 | insert, | ||
25 | exists, | ||
26 | literal, | ||
27 | and_, | ||
28 | delete, | ||
29 | update, | ||
30 | func, | ||
31 | inspect, | ||
32 | ) | ||
33 | import sqlalchemy.engine | ||
34 | from sqlalchemy.orm import declarative_base | ||
35 | from sqlalchemy.exc import IntegrityError | ||
36 | from sqlalchemy.dialects.postgresql import insert as postgres_insert | ||
37 | |||
38 | Base = declarative_base() | ||
39 | |||
40 | |||
41 | class UnihashesV3(Base): | ||
42 | __tablename__ = "unihashes_v3" | ||
43 | id = Column(Integer, primary_key=True, autoincrement=True) | ||
44 | method = Column(Text, nullable=False) | ||
45 | taskhash = Column(Text, nullable=False) | ||
46 | unihash = Column(Text, nullable=False) | ||
47 | gc_mark = Column(Text, nullable=False) | ||
48 | |||
49 | __table_args__ = ( | ||
50 | UniqueConstraint("method", "taskhash"), | ||
51 | Index("taskhash_lookup_v4", "method", "taskhash"), | ||
52 | Index("unihash_lookup_v1", "unihash"), | ||
53 | ) | ||
54 | |||
55 | |||
56 | class OuthashesV2(Base): | ||
57 | __tablename__ = "outhashes_v2" | ||
58 | id = Column(Integer, primary_key=True, autoincrement=True) | ||
59 | method = Column(Text, nullable=False) | ||
60 | taskhash = Column(Text, nullable=False) | ||
61 | outhash = Column(Text, nullable=False) | ||
62 | created = Column(DateTime) | ||
63 | owner = Column(Text) | ||
64 | PN = Column(Text) | ||
65 | PV = Column(Text) | ||
66 | PR = Column(Text) | ||
67 | task = Column(Text) | ||
68 | outhash_siginfo = Column(Text) | ||
69 | |||
70 | __table_args__ = ( | ||
71 | UniqueConstraint("method", "taskhash", "outhash"), | ||
72 | Index("outhash_lookup_v3", "method", "outhash"), | ||
73 | ) | ||
74 | |||
75 | |||
76 | class Users(Base): | ||
77 | __tablename__ = "users" | ||
78 | id = Column(Integer, primary_key=True, autoincrement=True) | ||
79 | username = Column(Text, nullable=False) | ||
80 | token = Column(Text, nullable=False) | ||
81 | permissions = Column(Text) | ||
82 | |||
83 | __table_args__ = (UniqueConstraint("username"),) | ||
84 | |||
85 | |||
86 | class Config(Base): | ||
87 | __tablename__ = "config" | ||
88 | id = Column(Integer, primary_key=True, autoincrement=True) | ||
89 | name = Column(Text, nullable=False) | ||
90 | value = Column(Text) | ||
91 | __table_args__ = ( | ||
92 | UniqueConstraint("name"), | ||
93 | Index("config_lookup", "name"), | ||
94 | ) | ||
95 | |||
96 | |||
97 | # | ||
98 | # Old table versions | ||
99 | # | ||
100 | DeprecatedBase = declarative_base() | ||
101 | |||
102 | |||
103 | class UnihashesV2(DeprecatedBase): | ||
104 | __tablename__ = "unihashes_v2" | ||
105 | id = Column(Integer, primary_key=True, autoincrement=True) | ||
106 | method = Column(Text, nullable=False) | ||
107 | taskhash = Column(Text, nullable=False) | ||
108 | unihash = Column(Text, nullable=False) | ||
109 | |||
110 | __table_args__ = ( | ||
111 | UniqueConstraint("method", "taskhash"), | ||
112 | Index("taskhash_lookup_v3", "method", "taskhash"), | ||
113 | ) | ||
114 | |||
115 | |||
116 | class DatabaseEngine(object): | ||
117 | def __init__(self, url, username=None, password=None): | ||
118 | self.logger = logging.getLogger("hashserv.sqlalchemy") | ||
119 | self.url = sqlalchemy.engine.make_url(url) | ||
120 | |||
121 | if username is not None: | ||
122 | self.url = self.url.set(username=username) | ||
123 | |||
124 | if password is not None: | ||
125 | self.url = self.url.set(password=password) | ||
126 | |||
127 | async def create(self): | ||
128 | def check_table_exists(conn, name): | ||
129 | return inspect(conn).has_table(name) | ||
130 | |||
131 | self.logger.info("Using database %s", self.url) | ||
132 | if self.url.drivername == 'postgresql+psycopg': | ||
133 | # Psycopg 3 (psycopg) driver can handle async connection pooling | ||
134 | self.engine = create_async_engine(self.url, max_overflow=-1) | ||
135 | else: | ||
136 | self.engine = create_async_engine(self.url, poolclass=NullPool) | ||
137 | |||
138 | async with self.engine.begin() as conn: | ||
139 | # Create tables | ||
140 | self.logger.info("Creating tables...") | ||
141 | await conn.run_sync(Base.metadata.create_all) | ||
142 | |||
143 | if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__): | ||
144 | self.logger.info("Upgrading Unihashes V2 -> V3...") | ||
145 | statement = insert(UnihashesV3).from_select( | ||
146 | ["id", "method", "unihash", "taskhash", "gc_mark"], | ||
147 | select( | ||
148 | UnihashesV2.id, | ||
149 | UnihashesV2.method, | ||
150 | UnihashesV2.unihash, | ||
151 | UnihashesV2.taskhash, | ||
152 | literal("").label("gc_mark"), | ||
153 | ), | ||
154 | ) | ||
155 | self.logger.debug("%s", statement) | ||
156 | await conn.execute(statement) | ||
157 | |||
158 | await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__]) | ||
159 | self.logger.info("Upgrade complete") | ||
160 | |||
161 | def connect(self, logger): | ||
162 | return Database(self.engine, logger) | ||
163 | |||
164 | |||
165 | def map_row(row): | ||
166 | if row is None: | ||
167 | return None | ||
168 | return dict(**row._mapping) | ||
169 | |||
170 | |||
171 | def map_user(row): | ||
172 | if row is None: | ||
173 | return None | ||
174 | return User( | ||
175 | username=row.username, | ||
176 | permissions=set(row.permissions.split()), | ||
177 | ) | ||
178 | |||
179 | |||
180 | def _make_condition_statement(table, condition): | ||
181 | where = {} | ||
182 | for c in table.__table__.columns: | ||
183 | if c.key in condition and condition[c.key] is not None: | ||
184 | where[c] = condition[c.key] | ||
185 | |||
186 | return [(k == v) for k, v in where.items()] | ||
187 | |||
188 | |||
189 | class Database(object): | ||
190 | def __init__(self, engine, logger): | ||
191 | self.engine = engine | ||
192 | self.db = None | ||
193 | self.logger = logger | ||
194 | |||
195 | async def __aenter__(self): | ||
196 | self.db = await self.engine.connect() | ||
197 | return self | ||
198 | |||
199 | async def __aexit__(self, exc_type, exc_value, traceback): | ||
200 | await self.close() | ||
201 | |||
202 | async def close(self): | ||
203 | await self.db.close() | ||
204 | self.db = None | ||
205 | |||
206 | async def _execute(self, statement): | ||
207 | self.logger.debug("%s", statement) | ||
208 | return await self.db.execute(statement) | ||
209 | |||
210 | async def _set_config(self, name, value): | ||
211 | while True: | ||
212 | result = await self._execute( | ||
213 | update(Config).where(Config.name == name).values(value=value) | ||
214 | ) | ||
215 | |||
216 | if result.rowcount == 0: | ||
217 | self.logger.debug("Config '%s' not found. Adding it", name) | ||
218 | try: | ||
219 | await self._execute(insert(Config).values(name=name, value=value)) | ||
220 | except IntegrityError: | ||
221 | # Race. Try again | ||
222 | continue | ||
223 | |||
224 | break | ||
225 | |||
226 | def _get_config_subquery(self, name, default=None): | ||
227 | if default is not None: | ||
228 | return func.coalesce( | ||
229 | select(Config.value).where(Config.name == name).scalar_subquery(), | ||
230 | default, | ||
231 | ) | ||
232 | return select(Config.value).where(Config.name == name).scalar_subquery() | ||
233 | |||
234 | async def _get_config(self, name): | ||
235 | result = await self._execute(select(Config.value).where(Config.name == name)) | ||
236 | row = result.first() | ||
237 | if row is None: | ||
238 | return None | ||
239 | return row.value | ||
240 | |||
241 | async def get_unihash_by_taskhash_full(self, method, taskhash): | ||
242 | async with self.db.begin(): | ||
243 | result = await self._execute( | ||
244 | select( | ||
245 | OuthashesV2, | ||
246 | UnihashesV3.unihash.label("unihash"), | ||
247 | ) | ||
248 | .join( | ||
249 | UnihashesV3, | ||
250 | and_( | ||
251 | UnihashesV3.method == OuthashesV2.method, | ||
252 | UnihashesV3.taskhash == OuthashesV2.taskhash, | ||
253 | ), | ||
254 | ) | ||
255 | .where( | ||
256 | OuthashesV2.method == method, | ||
257 | OuthashesV2.taskhash == taskhash, | ||
258 | ) | ||
259 | .order_by( | ||
260 | OuthashesV2.created.asc(), | ||
261 | ) | ||
262 | .limit(1) | ||
263 | ) | ||
264 | return map_row(result.first()) | ||
265 | |||
266 | async def get_unihash_by_outhash(self, method, outhash): | ||
267 | async with self.db.begin(): | ||
268 | result = await self._execute( | ||
269 | select(OuthashesV2, UnihashesV3.unihash.label("unihash")) | ||
270 | .join( | ||
271 | UnihashesV3, | ||
272 | and_( | ||
273 | UnihashesV3.method == OuthashesV2.method, | ||
274 | UnihashesV3.taskhash == OuthashesV2.taskhash, | ||
275 | ), | ||
276 | ) | ||
277 | .where( | ||
278 | OuthashesV2.method == method, | ||
279 | OuthashesV2.outhash == outhash, | ||
280 | ) | ||
281 | .order_by( | ||
282 | OuthashesV2.created.asc(), | ||
283 | ) | ||
284 | .limit(1) | ||
285 | ) | ||
286 | return map_row(result.first()) | ||
287 | |||
288 | async def unihash_exists(self, unihash): | ||
289 | async with self.db.begin(): | ||
290 | result = await self._execute( | ||
291 | select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1) | ||
292 | ) | ||
293 | |||
294 | return result.first() is not None | ||
295 | |||
296 | async def get_outhash(self, method, outhash): | ||
297 | async with self.db.begin(): | ||
298 | result = await self._execute( | ||
299 | select(OuthashesV2) | ||
300 | .where( | ||
301 | OuthashesV2.method == method, | ||
302 | OuthashesV2.outhash == outhash, | ||
303 | ) | ||
304 | .order_by( | ||
305 | OuthashesV2.created.asc(), | ||
306 | ) | ||
307 | .limit(1) | ||
308 | ) | ||
309 | return map_row(result.first()) | ||
310 | |||
311 | async def get_equivalent_for_outhash(self, method, outhash, taskhash): | ||
312 | async with self.db.begin(): | ||
313 | result = await self._execute( | ||
314 | select( | ||
315 | OuthashesV2.taskhash.label("taskhash"), | ||
316 | UnihashesV3.unihash.label("unihash"), | ||
317 | ) | ||
318 | .join( | ||
319 | UnihashesV3, | ||
320 | and_( | ||
321 | UnihashesV3.method == OuthashesV2.method, | ||
322 | UnihashesV3.taskhash == OuthashesV2.taskhash, | ||
323 | ), | ||
324 | ) | ||
325 | .where( | ||
326 | OuthashesV2.method == method, | ||
327 | OuthashesV2.outhash == outhash, | ||
328 | OuthashesV2.taskhash != taskhash, | ||
329 | ) | ||
330 | .order_by( | ||
331 | OuthashesV2.created.asc(), | ||
332 | ) | ||
333 | .limit(1) | ||
334 | ) | ||
335 | return map_row(result.first()) | ||
336 | |||
337 | async def get_equivalent(self, method, taskhash): | ||
338 | async with self.db.begin(): | ||
339 | result = await self._execute( | ||
340 | select( | ||
341 | UnihashesV3.unihash, | ||
342 | UnihashesV3.method, | ||
343 | UnihashesV3.taskhash, | ||
344 | ).where( | ||
345 | UnihashesV3.method == method, | ||
346 | UnihashesV3.taskhash == taskhash, | ||
347 | ) | ||
348 | ) | ||
349 | return map_row(result.first()) | ||
350 | |||
351 | async def remove(self, condition): | ||
352 | async def do_remove(table): | ||
353 | where = _make_condition_statement(table, condition) | ||
354 | if where: | ||
355 | async with self.db.begin(): | ||
356 | result = await self._execute(delete(table).where(*where)) | ||
357 | return result.rowcount | ||
358 | |||
359 | return 0 | ||
360 | |||
361 | count = 0 | ||
362 | count += await do_remove(UnihashesV3) | ||
363 | count += await do_remove(OuthashesV2) | ||
364 | |||
365 | return count | ||
366 | |||
367 | async def get_current_gc_mark(self): | ||
368 | async with self.db.begin(): | ||
369 | return await self._get_config("gc-mark") | ||
370 | |||
371 | async def gc_status(self): | ||
372 | async with self.db.begin(): | ||
373 | gc_mark_subquery = self._get_config_subquery("gc-mark", "") | ||
374 | |||
375 | result = await self._execute( | ||
376 | select(func.count()) | ||
377 | .select_from(UnihashesV3) | ||
378 | .where(UnihashesV3.gc_mark == gc_mark_subquery) | ||
379 | ) | ||
380 | keep_rows = result.scalar() | ||
381 | |||
382 | result = await self._execute( | ||
383 | select(func.count()) | ||
384 | .select_from(UnihashesV3) | ||
385 | .where(UnihashesV3.gc_mark != gc_mark_subquery) | ||
386 | ) | ||
387 | remove_rows = result.scalar() | ||
388 | |||
389 | return (keep_rows, remove_rows, await self._get_config("gc-mark")) | ||
390 | |||
391 | async def gc_mark(self, mark, condition): | ||
392 | async with self.db.begin(): | ||
393 | await self._set_config("gc-mark", mark) | ||
394 | |||
395 | where = _make_condition_statement(UnihashesV3, condition) | ||
396 | if not where: | ||
397 | return 0 | ||
398 | |||
399 | result = await self._execute( | ||
400 | update(UnihashesV3) | ||
401 | .values(gc_mark=self._get_config_subquery("gc-mark", "")) | ||
402 | .where(*where) | ||
403 | ) | ||
404 | return result.rowcount | ||
405 | |||
406 | async def gc_sweep(self): | ||
407 | async with self.db.begin(): | ||
408 | result = await self._execute( | ||
409 | delete(UnihashesV3).where( | ||
410 | # A sneaky conditional that provides some protection against | ||
411 | # erroneous use: if the config mark is NULL, this will not | ||
412 | # match any rows because no default is specified in the | ||
413 | # select statement | ||
414 | UnihashesV3.gc_mark | ||
415 | != self._get_config_subquery("gc-mark") | ||
416 | ) | ||
417 | ) | ||
418 | await self._set_config("gc-mark", None) | ||
419 | |||
420 | return result.rowcount | ||
421 | |||
422 | async def clean_unused(self, oldest): | ||
423 | async with self.db.begin(): | ||
424 | result = await self._execute( | ||
425 | delete(OuthashesV2).where( | ||
426 | OuthashesV2.created < oldest, | ||
427 | ~( | ||
428 | select(UnihashesV3.id) | ||
429 | .where( | ||
430 | UnihashesV3.method == OuthashesV2.method, | ||
431 | UnihashesV3.taskhash == OuthashesV2.taskhash, | ||
432 | ) | ||
433 | .limit(1) | ||
434 | .exists() | ||
435 | ), | ||
436 | ) | ||
437 | ) | ||
438 | return result.rowcount | ||
439 | |||
440 | async def insert_unihash(self, method, taskhash, unihash): | ||
441 | # Postgres specific ignore on insert duplicate | ||
442 | if self.engine.name == "postgresql": | ||
443 | statement = ( | ||
444 | postgres_insert(UnihashesV3) | ||
445 | .values( | ||
446 | method=method, | ||
447 | taskhash=taskhash, | ||
448 | unihash=unihash, | ||
449 | gc_mark=self._get_config_subquery("gc-mark", ""), | ||
450 | ) | ||
451 | .on_conflict_do_nothing(index_elements=("method", "taskhash")) | ||
452 | ) | ||
453 | else: | ||
454 | statement = insert(UnihashesV3).values( | ||
455 | method=method, | ||
456 | taskhash=taskhash, | ||
457 | unihash=unihash, | ||
458 | gc_mark=self._get_config_subquery("gc-mark", ""), | ||
459 | ) | ||
460 | |||
461 | try: | ||
462 | async with self.db.begin(): | ||
463 | result = await self._execute(statement) | ||
464 | return result.rowcount != 0 | ||
465 | except IntegrityError: | ||
466 | self.logger.debug( | ||
467 | "%s, %s, %s already in unihash database", method, taskhash, unihash | ||
468 | ) | ||
469 | return False | ||
470 | |||
471 | async def insert_outhash(self, data): | ||
472 | outhash_columns = set(c.key for c in OuthashesV2.__table__.columns) | ||
473 | |||
474 | data = {k: v for k, v in data.items() if k in outhash_columns} | ||
475 | |||
476 | if "created" in data and not isinstance(data["created"], datetime): | ||
477 | data["created"] = datetime.fromisoformat(data["created"]) | ||
478 | |||
479 | # Postgres specific ignore on insert duplicate | ||
480 | if self.engine.name == "postgresql": | ||
481 | statement = ( | ||
482 | postgres_insert(OuthashesV2) | ||
483 | .values(**data) | ||
484 | .on_conflict_do_nothing( | ||
485 | index_elements=("method", "taskhash", "outhash") | ||
486 | ) | ||
487 | ) | ||
488 | else: | ||
489 | statement = insert(OuthashesV2).values(**data) | ||
490 | |||
491 | try: | ||
492 | async with self.db.begin(): | ||
493 | result = await self._execute(statement) | ||
494 | return result.rowcount != 0 | ||
495 | except IntegrityError: | ||
496 | self.logger.debug( | ||
497 | "%s, %s already in outhash database", data["method"], data["outhash"] | ||
498 | ) | ||
499 | return False | ||
500 | |||
501 | async def _get_user(self, username): | ||
502 | async with self.db.begin(): | ||
503 | result = await self._execute( | ||
504 | select( | ||
505 | Users.username, | ||
506 | Users.permissions, | ||
507 | Users.token, | ||
508 | ).where( | ||
509 | Users.username == username, | ||
510 | ) | ||
511 | ) | ||
512 | return result.first() | ||
513 | |||
514 | async def lookup_user_token(self, username): | ||
515 | row = await self._get_user(username) | ||
516 | if not row: | ||
517 | return None, None | ||
518 | return map_user(row), row.token | ||
519 | |||
520 | async def lookup_user(self, username): | ||
521 | return map_user(await self._get_user(username)) | ||
522 | |||
523 | async def set_user_token(self, username, token): | ||
524 | async with self.db.begin(): | ||
525 | result = await self._execute( | ||
526 | update(Users) | ||
527 | .where( | ||
528 | Users.username == username, | ||
529 | ) | ||
530 | .values( | ||
531 | token=token, | ||
532 | ) | ||
533 | ) | ||
534 | return result.rowcount != 0 | ||
535 | |||
536 | async def set_user_perms(self, username, permissions): | ||
537 | async with self.db.begin(): | ||
538 | result = await self._execute( | ||
539 | update(Users) | ||
540 | .where(Users.username == username) | ||
541 | .values(permissions=" ".join(permissions)) | ||
542 | ) | ||
543 | return result.rowcount != 0 | ||
544 | |||
545 | async def get_all_users(self): | ||
546 | async with self.db.begin(): | ||
547 | result = await self._execute( | ||
548 | select( | ||
549 | Users.username, | ||
550 | Users.permissions, | ||
551 | ) | ||
552 | ) | ||
553 | return [map_user(row) for row in result] | ||
554 | |||
555 | async def new_user(self, username, permissions, token): | ||
556 | try: | ||
557 | async with self.db.begin(): | ||
558 | await self._execute( | ||
559 | insert(Users).values( | ||
560 | username=username, | ||
561 | permissions=" ".join(permissions), | ||
562 | token=token, | ||
563 | ) | ||
564 | ) | ||
565 | return True | ||
566 | except IntegrityError as e: | ||
567 | self.logger.debug("Cannot create new user %s: %s", username, e) | ||
568 | return False | ||
569 | |||
570 | async def delete_user(self, username): | ||
571 | async with self.db.begin(): | ||
572 | result = await self._execute( | ||
573 | delete(Users).where(Users.username == username) | ||
574 | ) | ||
575 | return result.rowcount != 0 | ||
576 | |||
577 | async def get_usage(self): | ||
578 | usage = {} | ||
579 | async with self.db.begin() as session: | ||
580 | for name, table in Base.metadata.tables.items(): | ||
581 | result = await self._execute( | ||
582 | statement=select(func.count()).select_from(table) | ||
583 | ) | ||
584 | usage[name] = { | ||
585 | "rows": result.scalar(), | ||
586 | } | ||
587 | |||
588 | return usage | ||
589 | |||
590 | async def get_query_columns(self): | ||
591 | columns = set() | ||
592 | for table in (UnihashesV3, OuthashesV2): | ||
593 | for c in table.__table__.columns: | ||
594 | if not isinstance(c.type, Text): | ||
595 | continue | ||
596 | columns.add(c.key) | ||
597 | |||
598 | return list(columns) | ||
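A short usage sketch for this backend; the PostgreSQL URL is only an example, and any async-capable SQLAlchemy driver should behave the same way:

    import asyncio
    import logging
    from hashserv.sqlalchemy import DatabaseEngine

    async def main():
        engine = DatabaseEngine("postgresql+psycopg://hashserv@db/hashes")
        await engine.create()              # creates tables and runs the V2 -> V3 upgrade

        async with engine.connect(logging.getLogger("demo")) as db:
            await db.insert_unihash("TestMethod", "aa" * 20, "bb" * 20)
            row = await db.get_equivalent("TestMethod", "aa" * 20)
            print(row["unihash"])

    asyncio.run(main())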
diff --git a/bitbake/lib/hashserv/sqlite.py b/bitbake/lib/hashserv/sqlite.py new file mode 100644 index 0000000000..da2e844a03 --- /dev/null +++ b/bitbake/lib/hashserv/sqlite.py | |||
@@ -0,0 +1,562 @@ | |||
1 | #! /usr/bin/env python3 | ||
2 | # | ||
3 | # Copyright (C) 2023 Garmin Ltd. | ||
4 | # | ||
5 | # SPDX-License-Identifier: GPL-2.0-only | ||
6 | # | ||
7 | import sqlite3 | ||
8 | import logging | ||
9 | from contextlib import closing | ||
10 | from . import User | ||
11 | |||
12 | logger = logging.getLogger("hashserv.sqlite") | ||
13 | |||
14 | UNIHASH_TABLE_DEFINITION = ( | ||
15 | ("method", "TEXT NOT NULL", "UNIQUE"), | ||
16 | ("taskhash", "TEXT NOT NULL", "UNIQUE"), | ||
17 | ("unihash", "TEXT NOT NULL", ""), | ||
18 | ("gc_mark", "TEXT NOT NULL", ""), | ||
19 | ) | ||
20 | |||
21 | UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION) | ||
22 | |||
23 | OUTHASH_TABLE_DEFINITION = ( | ||
24 | ("method", "TEXT NOT NULL", "UNIQUE"), | ||
25 | ("taskhash", "TEXT NOT NULL", "UNIQUE"), | ||
26 | ("outhash", "TEXT NOT NULL", "UNIQUE"), | ||
27 | ("created", "DATETIME", ""), | ||
28 | # Optional fields | ||
29 | ("owner", "TEXT", ""), | ||
30 | ("PN", "TEXT", ""), | ||
31 | ("PV", "TEXT", ""), | ||
32 | ("PR", "TEXT", ""), | ||
33 | ("task", "TEXT", ""), | ||
34 | ("outhash_siginfo", "TEXT", ""), | ||
35 | ) | ||
36 | |||
37 | OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION) | ||
38 | |||
39 | USERS_TABLE_DEFINITION = ( | ||
40 | ("username", "TEXT NOT NULL", "UNIQUE"), | ||
41 | ("token", "TEXT NOT NULL", ""), | ||
42 | ("permissions", "TEXT NOT NULL", ""), | ||
43 | ) | ||
44 | |||
45 | USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION) | ||
46 | |||
47 | |||
48 | CONFIG_TABLE_DEFINITION = ( | ||
49 | ("name", "TEXT NOT NULL", "UNIQUE"), | ||
50 | ("value", "TEXT", ""), | ||
51 | ) | ||
52 | |||
53 | CONFIG_TABLE_COLUMNS = tuple(name for name, _, _ in CONFIG_TABLE_DEFINITION) | ||
54 | |||
55 | |||
56 | def _make_table(cursor, name, definition): | ||
57 | cursor.execute( | ||
58 | """ | ||
59 | CREATE TABLE IF NOT EXISTS {name} ( | ||
60 | id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
61 | {fields} | ||
62 | UNIQUE({unique}) | ||
63 | ) | ||
64 | """.format( | ||
65 | name=name, | ||
66 | fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition), | ||
67 | unique=", ".join( | ||
68 | name for name, _, flags in definition if "UNIQUE" in flags | ||
69 | ), | ||
70 | ) | ||
71 | ) | ||
72 | |||
73 | |||
74 | def map_user(row): | ||
75 | if row is None: | ||
76 | return None | ||
77 | return User( | ||
78 | username=row["username"], | ||
79 | permissions=set(row["permissions"].split()), | ||
80 | ) | ||
81 | |||
82 | |||
83 | def _make_condition_statement(columns, condition): | ||
84 | where = {} | ||
85 | for c in columns: | ||
86 | if c in condition and condition[c] is not None: | ||
87 | where[c] = condition[c] | ||
88 | |||
89 | return where, " AND ".join("%s=:%s" % (k, k) for k in where.keys()) | ||
90 | |||
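For illustration, given the column tuples above, the helper turns a caller-supplied condition dict into bound parameters plus a matching WHERE fragment:

    where, clause = _make_condition_statement(
        OUTHASH_TABLE_COLUMNS, {"method": "TestMethod", "owner": None, "bogus": 1}
    )
    # where  == {"method": "TestMethod"}   (None values and unknown keys are dropped)
    # clause == "method=:method"
    # remove() and gc_mark() below splice "clause" into their DELETE/UPDATE statements.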
91 | |||
92 | def _get_sqlite_version(cursor): | ||
93 | cursor.execute("SELECT sqlite_version()") | ||
94 | |||
95 | version = [] | ||
96 | for v in cursor.fetchone()[0].split("."): | ||
97 | try: | ||
98 | version.append(int(v)) | ||
99 | except ValueError: | ||
100 | version.append(v) | ||
101 | |||
102 | return tuple(version) | ||
103 | |||
104 | |||
105 | def _schema_table_name(version): | ||
106 | if version >= (3, 33): | ||
107 | return "sqlite_schema" | ||
108 | |||
109 | return "sqlite_master" | ||
110 | |||
111 | |||
112 | class DatabaseEngine(object): | ||
113 | def __init__(self, dbname, sync): | ||
114 | self.dbname = dbname | ||
115 | self.logger = logger | ||
116 | self.sync = sync | ||
117 | |||
118 | async def create(self): | ||
119 | db = sqlite3.connect(self.dbname) | ||
120 | db.row_factory = sqlite3.Row | ||
121 | |||
122 | with closing(db.cursor()) as cursor: | ||
123 | _make_table(cursor, "unihashes_v3", UNIHASH_TABLE_DEFINITION) | ||
124 | _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION) | ||
125 | _make_table(cursor, "users", USERS_TABLE_DEFINITION) | ||
126 | _make_table(cursor, "config", CONFIG_TABLE_DEFINITION) | ||
127 | |||
128 | cursor.execute("PRAGMA journal_mode = WAL") | ||
129 | cursor.execute( | ||
130 | "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF") | ||
131 | ) | ||
132 | |||
133 | # Drop old indexes | ||
134 | cursor.execute("DROP INDEX IF EXISTS taskhash_lookup") | ||
135 | cursor.execute("DROP INDEX IF EXISTS outhash_lookup") | ||
136 | cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2") | ||
137 | cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2") | ||
138 | cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v3") | ||
139 | |||
140 | # TODO: Upgrade from tasks_v2? | ||
141 | cursor.execute("DROP TABLE IF EXISTS tasks_v2") | ||
142 | |||
143 | # Create new indexes | ||
144 | cursor.execute( | ||
145 | "CREATE INDEX IF NOT EXISTS taskhash_lookup_v4 ON unihashes_v3 (method, taskhash)" | ||
146 | ) | ||
147 | cursor.execute( | ||
148 | "CREATE INDEX IF NOT EXISTS unihash_lookup_v1 ON unihashes_v3 (unihash)" | ||
149 | ) | ||
150 | cursor.execute( | ||
151 | "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)" | ||
152 | ) | ||
153 | cursor.execute("CREATE INDEX IF NOT EXISTS config_lookup ON config (name)") | ||
154 | |||
155 | sqlite_version = _get_sqlite_version(cursor) | ||
156 | |||
157 | cursor.execute( | ||
158 | f""" | ||
159 | SELECT name FROM {_schema_table_name(sqlite_version)} WHERE type = 'table' AND name = 'unihashes_v2' | ||
160 | """ | ||
161 | ) | ||
162 | if cursor.fetchone(): | ||
163 | self.logger.info("Upgrading Unihashes V2 -> V3...") | ||
164 | cursor.execute( | ||
165 | """ | ||
166 | INSERT INTO unihashes_v3 (id, method, unihash, taskhash, gc_mark) | ||
167 | SELECT id, method, unihash, taskhash, '' FROM unihashes_v2 | ||
168 | """ | ||
169 | ) | ||
170 | cursor.execute("DROP TABLE unihashes_v2") | ||
171 | db.commit() | ||
172 | self.logger.info("Upgrade complete") | ||
173 | |||
174 | def connect(self, logger): | ||
175 | return Database(logger, self.dbname, self.sync) | ||
176 | |||
177 | |||
178 | class Database(object): | ||
179 | def __init__(self, logger, dbname, sync): | ||
180 | self.dbname = dbname | ||
181 | self.logger = logger | ||
182 | |||
183 | self.db = sqlite3.connect(self.dbname) | ||
184 | self.db.row_factory = sqlite3.Row | ||
185 | |||
186 | with closing(self.db.cursor()) as cursor: | ||
187 | cursor.execute("PRAGMA journal_mode = WAL") | ||
188 | cursor.execute( | ||
189 | "PRAGMA synchronous = %s" % ("NORMAL" if sync else "OFF") | ||
190 | ) | ||
191 | |||
192 | self.sqlite_version = _get_sqlite_version(cursor) | ||
193 | |||
194 | async def __aenter__(self): | ||
195 | return self | ||
196 | |||
197 | async def __aexit__(self, exc_type, exc_value, traceback): | ||
198 | await self.close() | ||
199 | |||
200 | async def _set_config(self, cursor, name, value): | ||
201 | cursor.execute( | ||
202 | """ | ||
203 | INSERT OR REPLACE INTO config (id, name, value) VALUES | ||
204 | ((SELECT id FROM config WHERE name=:name), :name, :value) | ||
205 | """, | ||
206 | { | ||
207 | "name": name, | ||
208 | "value": value, | ||
209 | }, | ||
210 | ) | ||
211 | |||
212 | async def _get_config(self, cursor, name): | ||
213 | cursor.execute( | ||
214 | "SELECT value FROM config WHERE name=:name", | ||
215 | { | ||
216 | "name": name, | ||
217 | }, | ||
218 | ) | ||
219 | row = cursor.fetchone() | ||
220 | if row is None: | ||
221 | return None | ||
222 | return row["value"] | ||
223 | |||
224 | async def close(self): | ||
225 | self.db.close() | ||
226 | |||
227 | async def get_unihash_by_taskhash_full(self, method, taskhash): | ||
228 | with closing(self.db.cursor()) as cursor: | ||
229 | cursor.execute( | ||
230 | """ | ||
231 | SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2 | ||
232 | INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash | ||
233 | WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash | ||
234 | ORDER BY outhashes_v2.created ASC | ||
235 | LIMIT 1 | ||
236 | """, | ||
237 | { | ||
238 | "method": method, | ||
239 | "taskhash": taskhash, | ||
240 | }, | ||
241 | ) | ||
242 | return cursor.fetchone() | ||
243 | |||
244 | async def get_unihash_by_outhash(self, method, outhash): | ||
245 | with closing(self.db.cursor()) as cursor: | ||
246 | cursor.execute( | ||
247 | """ | ||
248 | SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2 | ||
249 | INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash | ||
250 | WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash | ||
251 | ORDER BY outhashes_v2.created ASC | ||
252 | LIMIT 1 | ||
253 | """, | ||
254 | { | ||
255 | "method": method, | ||
256 | "outhash": outhash, | ||
257 | }, | ||
258 | ) | ||
259 | return cursor.fetchone() | ||
260 | |||
261 | async def unihash_exists(self, unihash): | ||
262 | with closing(self.db.cursor()) as cursor: | ||
263 | cursor.execute( | ||
264 | """ | ||
265 | SELECT * FROM unihashes_v3 WHERE unihash=:unihash | ||
266 | LIMIT 1 | ||
267 | """, | ||
268 | { | ||
269 | "unihash": unihash, | ||
270 | }, | ||
271 | ) | ||
272 | return cursor.fetchone() is not None | ||
273 | |||
274 | async def get_outhash(self, method, outhash): | ||
275 | with closing(self.db.cursor()) as cursor: | ||
276 | cursor.execute( | ||
277 | """ | ||
278 | SELECT * FROM outhashes_v2 | ||
279 | WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash | ||
280 | ORDER BY outhashes_v2.created ASC | ||
281 | LIMIT 1 | ||
282 | """, | ||
283 | { | ||
284 | "method": method, | ||
285 | "outhash": outhash, | ||
286 | }, | ||
287 | ) | ||
288 | return cursor.fetchone() | ||
289 | |||
290 | async def get_equivalent_for_outhash(self, method, outhash, taskhash): | ||
291 | with closing(self.db.cursor()) as cursor: | ||
292 | cursor.execute( | ||
293 | """ | ||
294 | SELECT outhashes_v2.taskhash AS taskhash, unihashes_v3.unihash AS unihash FROM outhashes_v2 | ||
295 | INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash | ||
296 | -- Select any matching output hash except the one we just inserted | ||
297 | WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash | ||
298 | -- Pick the oldest hash | ||
299 | ORDER BY outhashes_v2.created ASC | ||
300 | LIMIT 1 | ||
301 | """, | ||
302 | { | ||
303 | "method": method, | ||
304 | "outhash": outhash, | ||
305 | "taskhash": taskhash, | ||
306 | }, | ||
307 | ) | ||
308 | return cursor.fetchone() | ||
309 | |||
310 | async def get_equivalent(self, method, taskhash): | ||
311 | with closing(self.db.cursor()) as cursor: | ||
312 | cursor.execute( | ||
313 | "SELECT taskhash, method, unihash FROM unihashes_v3 WHERE method=:method AND taskhash=:taskhash", | ||
314 | { | ||
315 | "method": method, | ||
316 | "taskhash": taskhash, | ||
317 | }, | ||
318 | ) | ||
319 | return cursor.fetchone() | ||
320 | |||
321 | async def remove(self, condition): | ||
322 | def do_remove(columns, table_name, cursor): | ||
323 | where, clause = _make_condition_statement(columns, condition) | ||
324 | if where: | ||
325 | query = f"DELETE FROM {table_name} WHERE {clause}" | ||
326 | cursor.execute(query, where) | ||
327 | return cursor.rowcount | ||
328 | |||
329 | return 0 | ||
330 | |||
331 | count = 0 | ||
332 | with closing(self.db.cursor()) as cursor: | ||
333 | count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor) | ||
334 | count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v3", cursor) | ||
335 | self.db.commit() | ||
336 | |||
337 | return count | ||
338 | |||
339 | async def get_current_gc_mark(self): | ||
340 | with closing(self.db.cursor()) as cursor: | ||
341 | return await self._get_config(cursor, "gc-mark") | ||
342 | |||
343 | async def gc_status(self): | ||
344 | with closing(self.db.cursor()) as cursor: | ||
345 | cursor.execute( | ||
346 | """ | ||
347 | SELECT COUNT() FROM unihashes_v3 WHERE | ||
348 | gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') | ||
349 | """ | ||
350 | ) | ||
351 | keep_rows = cursor.fetchone()[0] | ||
352 | |||
353 | cursor.execute( | ||
354 | """ | ||
355 | SELECT COUNT() FROM unihashes_v3 WHERE | ||
356 | gc_mark!=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') | ||
357 | """ | ||
358 | ) | ||
359 | remove_rows = cursor.fetchone()[0] | ||
360 | |||
361 | current_mark = await self._get_config(cursor, "gc-mark") | ||
362 | |||
363 | return (keep_rows, remove_rows, current_mark) | ||
364 | |||
365 | async def gc_mark(self, mark, condition): | ||
366 | with closing(self.db.cursor()) as cursor: | ||
367 | await self._set_config(cursor, "gc-mark", mark) | ||
368 | |||
369 | where, clause = _make_condition_statement(UNIHASH_TABLE_COLUMNS, condition) | ||
370 | |||
371 | new_rows = 0 | ||
372 | if where: | ||
373 | cursor.execute( | ||
374 | f""" | ||
375 | UPDATE unihashes_v3 SET | ||
376 | gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') | ||
377 | WHERE {clause} | ||
378 | """, | ||
379 | where, | ||
380 | ) | ||
381 | new_rows = cursor.rowcount | ||
382 | |||
383 | self.db.commit() | ||
384 | return new_rows | ||
385 | |||
386 | async def gc_sweep(self): | ||
387 | with closing(self.db.cursor()) as cursor: | ||
388 | # NOTE: COALESCE is not used in this query so that if the current | ||
389 | # mark is NULL, nothing will happen | ||
390 | cursor.execute( | ||
391 | """ | ||
392 | DELETE FROM unihashes_v3 WHERE | ||
393 | gc_mark!=(SELECT value FROM config WHERE name='gc-mark') | ||
394 | """ | ||
395 | ) | ||
396 | count = cursor.rowcount | ||
397 | await self._set_config(cursor, "gc-mark", None) | ||
398 | |||
399 | self.db.commit() | ||
400 | return count | ||
401 | |||
402 | async def clean_unused(self, oldest): | ||
403 | with closing(self.db.cursor()) as cursor: | ||
404 | cursor.execute( | ||
405 | """ | ||
406 | DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS ( | ||
407 | SELECT unihashes_v3.id FROM unihashes_v3 WHERE unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash LIMIT 1 | ||
408 | ) | ||
409 | """, | ||
410 | { | ||
411 | "oldest": oldest, | ||
412 | }, | ||
413 | ) | ||
414 | self.db.commit() | ||
415 | return cursor.rowcount | ||
416 | |||
417 | async def insert_unihash(self, method, taskhash, unihash): | ||
418 | with closing(self.db.cursor()) as cursor: | ||
419 | prevrowid = cursor.lastrowid | ||
420 | cursor.execute( | ||
421 | """ | ||
422 | INSERT OR IGNORE INTO unihashes_v3 (method, taskhash, unihash, gc_mark) VALUES | ||
423 | ( | ||
424 | :method, | ||
425 | :taskhash, | ||
426 | :unihash, | ||
427 | COALESCE((SELECT value FROM config WHERE name='gc-mark'), '') | ||
428 | ) | ||
429 | """, | ||
430 | { | ||
431 | "method": method, | ||
432 | "taskhash": taskhash, | ||
433 | "unihash": unihash, | ||
434 | }, | ||
435 | ) | ||
436 | self.db.commit() | ||
437 | return cursor.lastrowid != prevrowid | ||
438 | |||
439 | async def insert_outhash(self, data): | ||
440 | data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS} | ||
441 | keys = sorted(data.keys()) | ||
442 | query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format( | ||
443 | fields=", ".join(keys), | ||
444 | values=", ".join(":" + k for k in keys), | ||
445 | ) | ||
446 | with closing(self.db.cursor()) as cursor: | ||
447 | prevrowid = cursor.lastrowid | ||
448 | cursor.execute(query, data) | ||
449 | self.db.commit() | ||
450 | return cursor.lastrowid != prevrowid | ||
451 | |||
452 | def _get_user(self, username): | ||
453 | with closing(self.db.cursor()) as cursor: | ||
454 | cursor.execute( | ||
455 | """ | ||
456 | SELECT username, permissions, token FROM users WHERE username=:username | ||
457 | """, | ||
458 | { | ||
459 | "username": username, | ||
460 | }, | ||
461 | ) | ||
462 | return cursor.fetchone() | ||
463 | |||
464 | async def lookup_user_token(self, username): | ||
465 | row = self._get_user(username) | ||
466 | if row is None: | ||
467 | return None, None | ||
468 | return map_user(row), row["token"] | ||
469 | |||
470 | async def lookup_user(self, username): | ||
471 | return map_user(self._get_user(username)) | ||
472 | |||
473 | async def set_user_token(self, username, token): | ||
474 | with closing(self.db.cursor()) as cursor: | ||
475 | cursor.execute( | ||
476 | """ | ||
477 | UPDATE users SET token=:token WHERE username=:username | ||
478 | """, | ||
479 | { | ||
480 | "username": username, | ||
481 | "token": token, | ||
482 | }, | ||
483 | ) | ||
484 | self.db.commit() | ||
485 | return cursor.rowcount != 0 | ||
486 | |||
487 | async def set_user_perms(self, username, permissions): | ||
488 | with closing(self.db.cursor()) as cursor: | ||
489 | cursor.execute( | ||
490 | """ | ||
491 | UPDATE users SET permissions=:permissions WHERE username=:username | ||
492 | """, | ||
493 | { | ||
494 | "username": username, | ||
495 | "permissions": " ".join(permissions), | ||
496 | }, | ||
497 | ) | ||
498 | self.db.commit() | ||
499 | return cursor.rowcount != 0 | ||
500 | |||
501 | async def get_all_users(self): | ||
502 | with closing(self.db.cursor()) as cursor: | ||
503 | cursor.execute("SELECT username, permissions FROM users") | ||
504 | return [map_user(r) for r in cursor.fetchall()] | ||
505 | |||
506 | async def new_user(self, username, permissions, token): | ||
507 | with closing(self.db.cursor()) as cursor: | ||
508 | try: | ||
509 | cursor.execute( | ||
510 | """ | ||
511 | INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions) | ||
512 | """, | ||
513 | { | ||
514 | "username": username, | ||
515 | "token": token, | ||
516 | "permissions": " ".join(permissions), | ||
517 | }, | ||
518 | ) | ||
519 | self.db.commit() | ||
520 | return True | ||
521 | except sqlite3.IntegrityError: | ||
522 | return False | ||
523 | |||
524 | async def delete_user(self, username): | ||
525 | with closing(self.db.cursor()) as cursor: | ||
526 | cursor.execute( | ||
527 | """ | ||
528 | DELETE FROM users WHERE username=:username | ||
529 | """, | ||
530 | { | ||
531 | "username": username, | ||
532 | }, | ||
533 | ) | ||
534 | self.db.commit() | ||
535 | return cursor.rowcount != 0 | ||
536 | |||
537 | async def get_usage(self): | ||
538 | usage = {} | ||
539 | with closing(self.db.cursor()) as cursor: | ||
540 | cursor.execute( | ||
541 | f""" | ||
542 | SELECT name FROM {_schema_table_name(self.sqlite_version)} WHERE type = 'table' AND name NOT LIKE 'sqlite_%' | ||
543 | """ | ||
544 | ) | ||
545 | for row in cursor.fetchall(): | ||
546 | cursor.execute( | ||
547 | """ | ||
548 | SELECT COUNT() FROM %s | ||
549 | """ | ||
550 | % row["name"], | ||
551 | ) | ||
552 | usage[row["name"]] = { | ||
553 | "rows": cursor.fetchone()[0], | ||
554 | } | ||
555 | return usage | ||
556 | |||
557 | async def get_query_columns(self): | ||
558 | columns = set() | ||
559 | for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION: | ||
560 | if typ.startswith("TEXT"): | ||
561 | columns.add(name) | ||
562 | return list(columns) | ||
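To tie the two layers together, here is a condensed restatement of the handle_report() flow from server.py above, expressed directly against this Database class (upstream lookup and permission checks omitted):

    from datetime import datetime

    async def record(db, method, taskhash, outhash, proposed_unihash):
        data = {"method": method, "taskhash": taskhash, "outhash": outhash,
                "created": datetime.now()}
        if await db.insert_outhash(data):
            # New outhash row: reuse the unihash of an equivalent output, if any
            row = await db.get_equivalent_for_outhash(method, outhash, taskhash)
            unihash = row["unihash"] if row else proposed_unihash
            await db.insert_unihash(method, taskhash, unihash)
        existing = await db.get_equivalent(method, taskhash)
        return existing["unihash"] if existing else proposed_unihash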
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py index 1a696481e3..13ccb20ebf 100644 --- a/bitbake/lib/hashserv/tests.py +++ b/bitbake/lib/hashserv/tests.py | |||
@@ -6,7 +6,8 @@ | |||
6 | # | 6 | # |
7 | 7 | ||
8 | from . import create_server, create_client | 8 | from . import create_server, create_client |
9 | from .client import HashConnectionError | 9 | from .server import DEFAULT_ANON_PERMS, ALL_PERMISSIONS |
10 | from bb.asyncrpc import InvokeError | ||
10 | import hashlib | 11 | import hashlib |
11 | import logging | 12 | import logging |
12 | import multiprocessing | 13 | import multiprocessing |
@@ -16,72 +17,161 @@ import tempfile | |||
16 | import threading | 17 | import threading |
17 | import unittest | 18 | import unittest |
18 | import socket | 19 | import socket |
19 | 20 | import time | |
20 | def _run_server(server, idx): | 21 | import signal |
21 | # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', | 22 | import subprocess |
22 | # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') | 23 | import json |
23 | sys.stdout = open('bbhashserv-%d.log' % idx, 'w') | 24 | import re |
25 | from pathlib import Path | ||
26 | |||
27 | |||
28 | THIS_DIR = Path(__file__).parent | ||
29 | BIN_DIR = THIS_DIR.parent.parent / "bin" | ||
30 | |||
31 | def server_prefunc(server, idx): | ||
32 | logging.basicConfig(level=logging.DEBUG, filename='bbhashserv-%d.log' % idx, filemode='w', | ||
33 | format='%(levelname)s %(filename)s:%(lineno)d %(message)s') | ||
34 | server.logger.debug("Running server %d" % idx) | ||
35 | sys.stdout = open('bbhashserv-stdout-%d.log' % idx, 'w') | ||
24 | sys.stderr = sys.stdout | 36 | sys.stderr = sys.stdout |
25 | server.serve_forever() | ||
26 | |||
27 | 37 | ||
28 | class HashEquivalenceTestSetup(object): | 38 | class HashEquivalenceTestSetup(object): |
29 | METHOD = 'TestMethod' | 39 | METHOD = 'TestMethod' |
30 | 40 | ||
31 | server_index = 0 | 41 | server_index = 0 |
42 | client_index = 0 | ||
32 | 43 | ||
33 | def start_server(self, dbpath=None, upstream=None, read_only=False): | 44 | def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc, anon_perms=DEFAULT_ANON_PERMS, admin_username=None, admin_password=None): |
34 | self.server_index += 1 | 45 | self.server_index += 1 |
35 | if dbpath is None: | 46 | if dbpath is None: |
36 | dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) | 47 | dbpath = self.make_dbpath() |
48 | |||
49 | def cleanup_server(server): | ||
50 | if server.process.exitcode is not None: | ||
51 | return | ||
37 | 52 | ||
38 | def cleanup_thread(thread): | 53 | server.process.terminate() |
39 | thread.terminate() | 54 | server.process.join() |
40 | thread.join() | ||
41 | 55 | ||
42 | server = create_server(self.get_server_addr(self.server_index), | 56 | server = create_server(self.get_server_addr(self.server_index), |
43 | dbpath, | 57 | dbpath, |
44 | upstream=upstream, | 58 | upstream=upstream, |
45 | read_only=read_only) | 59 | read_only=read_only, |
60 | anon_perms=anon_perms, | ||
61 | admin_username=admin_username, | ||
62 | admin_password=admin_password) | ||
46 | server.dbpath = dbpath | 63 | server.dbpath = dbpath |
47 | 64 | ||
48 | server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) | 65 | server.serve_as_process(prefunc=prefunc, args=(self.server_index,)) |
49 | server.thread.start() | 66 | self.addCleanup(cleanup_server, server) |
50 | self.addCleanup(cleanup_thread, server.thread) | 67 | |
68 | return server | ||
69 | |||
70 | def make_dbpath(self): | ||
71 | return os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) | ||
51 | 72 | ||
73 | def start_client(self, server_address, username=None, password=None): | ||
52 | def cleanup_client(client): | 74 | def cleanup_client(client): |
53 | client.close() | 75 | client.close() |
54 | 76 | ||
55 | client = create_client(server.address) | 77 | client = create_client(server_address, username=username, password=password) |
56 | self.addCleanup(cleanup_client, client) | 78 | self.addCleanup(cleanup_client, client) |
57 | 79 | ||
58 | return (client, server) | 80 | return client |
59 | 81 | ||
60 | def setUp(self): | 82 | def start_test_server(self): |
61 | if sys.version_info < (3, 5, 0): | 83 | self.server = self.start_server() |
62 | self.skipTest('Python 3.5 or later required') | 84 | return self.server.address |
85 | |||
86 | def start_auth_server(self): | ||
87 | auth_server = self.start_server(self.server.dbpath, anon_perms=[], admin_username="admin", admin_password="password") | ||
88 | self.auth_server_address = auth_server.address | ||
89 | self.admin_client = self.start_client(auth_server.address, username="admin", password="password") | ||
90 | return self.admin_client | ||
63 | 91 | ||
92 | def auth_client(self, user): | ||
93 | return self.start_client(self.auth_server_address, user["username"], user["token"]) | ||
94 | |||
95 | def setUp(self): | ||
64 | self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv') | 96 | self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv') |
65 | self.addCleanup(self.temp_dir.cleanup) | 97 | self.addCleanup(self.temp_dir.cleanup) |
66 | 98 | ||
67 | (self.client, self.server) = self.start_server() | 99 | self.server_address = self.start_test_server() |
100 | |||
101 | self.client = self.start_client(self.server_address) | ||
68 | 102 | ||
69 | def assertClientGetHash(self, client, taskhash, unihash): | 103 | def assertClientGetHash(self, client, taskhash, unihash): |
70 | result = client.get_unihash(self.METHOD, taskhash) | 104 | result = client.get_unihash(self.METHOD, taskhash) |
71 | self.assertEqual(result, unihash) | 105 | self.assertEqual(result, unihash) |
72 | 106 | ||
107 | def assertUserPerms(self, user, permissions): | ||
108 | with self.auth_client(user) as client: | ||
109 | info = client.get_user() | ||
110 | self.assertEqual(info, { | ||
111 | "username": user["username"], | ||
112 | "permissions": permissions, | ||
113 | }) | ||
73 | 114 | ||
74 | class HashEquivalenceCommonTests(object): | 115 | def assertUserCanAuth(self, user): |
75 | def test_create_hash(self): | 116 | with self.start_client(self.auth_server_address) as client: |
117 | client.auth(user["username"], user["token"]) | ||
118 | |||
119 | def assertUserCannotAuth(self, user): | ||
120 | with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError): | ||
121 | client.auth(user["username"], user["token"]) | ||
122 | |||
123 | def create_test_hash(self, client): | ||
76 | # Simple test that hashes can be created | 124 | # Simple test that hashes can be created |
77 | taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' | 125 | taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' |
78 | outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' | 126 | outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' |
79 | unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' | 127 | unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' |
80 | 128 | ||
81 | self.assertClientGetHash(self.client, taskhash, None) | 129 | self.assertClientGetHash(client, taskhash, None) |
82 | 130 | ||
83 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | 131 | result = client.report_unihash(taskhash, self.METHOD, outhash, unihash) |
84 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | 132 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') |
133 | return taskhash, outhash, unihash | ||
134 | |||
135 | def run_hashclient(self, args, **kwargs): | ||
136 | try: | ||
137 | p = subprocess.run( | ||
138 | [BIN_DIR / "bitbake-hashclient"] + args, | ||
139 | stdout=subprocess.PIPE, | ||
140 | stderr=subprocess.STDOUT, | ||
141 | encoding="utf-8", | ||
142 | **kwargs | ||
143 | ) | ||
144 | except subprocess.CalledProcessError as e: | ||
145 | print(e.output) | ||
146 | raise e | ||
147 | |||
148 | print(p.stdout) | ||
149 | return p | ||
150 | |||
151 | |||
152 | class HashEquivalenceCommonTests(object): | ||
153 | def auth_perms(self, *permissions): | ||
154 | self.client_index += 1 | ||
155 | user = self.create_user(f"user-{self.client_index}", permissions) | ||
156 | return self.auth_client(user) | ||
157 | |||
158 | def create_user(self, username, permissions, *, client=None): | ||
159 | def remove_user(username): | ||
160 | try: | ||
161 | self.admin_client.delete_user(username) | ||
162 | except bb.asyncrpc.InvokeError: | ||
163 | pass | ||
164 | |||
165 | if client is None: | ||
166 | client = self.admin_client | ||
167 | |||
168 | user = client.new_user(username, permissions) | ||
169 | self.addCleanup(remove_user, username) | ||
170 | |||
171 | return user | ||
172 | |||
173 | def test_create_hash(self): | ||
174 | return self.create_test_hash(self.client) | ||
85 | 175 | ||
86 | def test_create_equivalent(self): | 176 | def test_create_equivalent(self): |
87 | # Tests that a second reported task with the same outhash will be | 177 | # Tests that a second reported task with the same outhash will be |
@@ -123,6 +213,57 @@ class HashEquivalenceCommonTests(object): | |||
123 | 213 | ||
124 | self.assertClientGetHash(self.client, taskhash, unihash) | 214 | self.assertClientGetHash(self.client, taskhash, unihash) |
125 | 215 | ||
216 | def test_remove_taskhash(self): | ||
217 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
218 | result = self.client.remove({"taskhash": taskhash}) | ||
219 | self.assertGreater(result["count"], 0) | ||
220 | self.assertClientGetHash(self.client, taskhash, None) | ||
221 | |||
222 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) | ||
223 | self.assertIsNone(result_outhash) | ||
224 | |||
225 | def test_remove_unihash(self): | ||
226 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
227 | result = self.client.remove({"unihash": unihash}) | ||
228 | self.assertGreater(result["count"], 0) | ||
229 | self.assertClientGetHash(self.client, taskhash, None) | ||
230 | |||
231 | def test_remove_outhash(self): | ||
232 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
233 | result = self.client.remove({"outhash": outhash}) | ||
234 | self.assertGreater(result["count"], 0) | ||
235 | |||
236 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) | ||
237 | self.assertIsNone(result_outhash) | ||
238 | |||
239 | def test_remove_method(self): | ||
240 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
241 | result = self.client.remove({"method": self.METHOD}) | ||
242 | self.assertGreater(result["count"], 0) | ||
243 | self.assertClientGetHash(self.client, taskhash, None) | ||
244 | |||
245 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) | ||
246 | self.assertIsNone(result_outhash) | ||
247 | |||
248 | def test_clean_unused(self): | ||
249 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
250 | |||
251 | # Clean the database, which should not remove anything because all hashes are in use | ||
252 | result = self.client.clean_unused(0) | ||
253 | self.assertEqual(result["count"], 0) | ||
254 | self.assertClientGetHash(self.client, taskhash, unihash) | ||
255 | |||
256 | # Remove the unihash. The row in the outhash table should still be present | ||
257 | self.client.remove({"unihash": unihash}) | ||
258 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False) | ||
259 | self.assertIsNotNone(result_outhash) | ||
260 | |||
261 | # Now clean with no minimum age which will remove the outhash | ||
262 | result = self.client.clean_unused(0) | ||
263 | self.assertEqual(result["count"], 1) | ||
264 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False) | ||
265 | self.assertIsNone(result_outhash) | ||
266 | |||
126 | def test_huge_message(self): | 267 | def test_huge_message(self): |
127 | # Simple test that hashes can be created | 268 | # Simple test that hashes can be created |
128 | taskhash = 'c665584ee6817aa99edfc77a44dd853828279370' | 269 | taskhash = 'c665584ee6817aa99edfc77a44dd853828279370' |
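The hunk above drives the new remove() and clean_unused() calls through the test client. A rough sketch of the same report → remove → clean_unused flow outside the test harness, with placeholder hashes and the assumption that bitbake/lib is on sys.path so the hashserv package imports:

    import os
    import tempfile
    from hashserv import create_server, create_client

    METHOD = "TestMethod"  # arbitrary placeholder method, as in the tests

    with tempfile.TemporaryDirectory(prefix="bb-hashserv") as tmpdir:
        server = create_server("unix://" + os.path.join(tmpdir, "sock"),
                               os.path.join(tmpdir, "db.sqlite"))
        # Mirrors the test harness; the no-op prefunc matches the prefunc(server, idx) signature
        server.serve_as_process(prefunc=lambda server, idx: None, args=(0,))
        try:
            client = create_client(server.address)
            try:
                # Report a placeholder hash, then confirm the mapping is visible
                client.report_unihash("taskhash-1", METHOD, "outhash-1", "unihash-1")
                assert client.get_unihash(METHOD, "taskhash-1") == "unihash-1"

                # Removing by unihash drops the taskhash mapping but leaves the outhash row
                client.remove({"unihash": "unihash-1"})
                assert client.get_unihash(METHOD, "taskhash-1") is None

                # clean_unused(0) then sweeps outhash rows no longer referenced by any unihash
                client.clean_unused(0)
            finally:
                client.close()
        finally:
            server.process.terminate()
            server.process.join()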
@@ -138,16 +279,21 @@ class HashEquivalenceCommonTests(object): | |||
138 | }) | 279 | }) |
139 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | 280 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') |
140 | 281 | ||
141 | result = self.client.get_taskhash(self.METHOD, taskhash, True) | 282 | result_unihash = self.client.get_taskhash(self.METHOD, taskhash, True) |
142 | self.assertEqual(result['taskhash'], taskhash) | 283 | self.assertEqual(result_unihash['taskhash'], taskhash) |
143 | self.assertEqual(result['unihash'], unihash) | 284 | self.assertEqual(result_unihash['unihash'], unihash) |
144 | self.assertEqual(result['method'], self.METHOD) | 285 | self.assertEqual(result_unihash['method'], self.METHOD) |
145 | self.assertEqual(result['outhash'], outhash) | 286 | |
146 | self.assertEqual(result['outhash_siginfo'], siginfo) | 287 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) |
288 | self.assertEqual(result_outhash['taskhash'], taskhash) | ||
289 | self.assertEqual(result_outhash['method'], self.METHOD) | ||
290 | self.assertEqual(result_outhash['unihash'], unihash) | ||
291 | self.assertEqual(result_outhash['outhash'], outhash) | ||
292 | self.assertEqual(result_outhash['outhash_siginfo'], siginfo) | ||
147 | 293 | ||
148 | def test_stress(self): | 294 | def test_stress(self): |
149 | def query_server(failures): | 295 | def query_server(failures): |
150 | client = Client(self.server.address) | 296 | client = Client(self.server_address) |
151 | try: | 297 | try: |
152 | for i in range(1000): | 298 | for i in range(1000): |
153 | taskhash = hashlib.sha256() | 299 | taskhash = hashlib.sha256() |
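The huge-message hunk above splits the old single get_taskhash() query into get_taskhash(..., True) plus a separate get_outhash() lookup that carries outhash_siginfo. A small helper sketch of that query pair, assuming a connected hashserv client such as the one created in the previous sketch:

    def fetch_full_records(client, method, taskhash, outhash):
        # get_taskhash(..., True) returns the stored row for the taskhash,
        # including its unihash; get_outhash() returns the output-hash row,
        # which also carries outhash_siginfo (see the assertions above).
        by_taskhash = client.get_taskhash(method, taskhash, True)
        by_outhash = client.get_outhash(method, outhash, taskhash)
        return {
            "unihash": by_taskhash["unihash"],
            "outhash": by_outhash["outhash"],
            "siginfo": by_outhash["outhash_siginfo"],
        }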
@@ -186,8 +332,10 @@ class HashEquivalenceCommonTests(object): | |||
186 | # the side client. It also verifies that the results are pulled into | 332 | # the side client. It also verifies that the results are pulled into |
187 | # the downstream database by checking that the downstream and side servers | 333 | # the downstream database by checking that the downstream and side servers |
188 | # match after the downstream is done waiting for all backfill tasks | 334 | # match after the downstream is done waiting for all backfill tasks |
189 | (down_client, down_server) = self.start_server(upstream=self.server.address) | 335 | down_server = self.start_server(upstream=self.server_address) |
190 | (side_client, side_server) = self.start_server(dbpath=down_server.dbpath) | 336 | down_client = self.start_client(down_server.address) |
337 | side_server = self.start_server(dbpath=down_server.dbpath) | ||
338 | side_client = self.start_client(side_server.address) | ||
191 | 339 | ||
192 | def check_hash(taskhash, unihash, old_sidehash): | 340 | def check_hash(taskhash, unihash, old_sidehash): |
193 | nonlocal down_client | 341 | nonlocal down_client |
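The hunk above now starts the downstream and side servers explicitly instead of getting a client back from start_server(). The upstream= argument is what makes a server read through to another hashserv instance on a miss; a compact wiring sketch with placeholder socket paths and the same import caveats as the earlier lifecycle sketch:

    import os
    import tempfile
    from hashserv import create_server

    with tempfile.TemporaryDirectory(prefix="bb-hashserv") as tmpdir:
        up_addr = "unix://" + os.path.join(tmpdir, "up.sock")

        upstream = create_server(up_addr, os.path.join(tmpdir, "up.sqlite"))
        upstream.serve_as_process(prefunc=lambda server, idx: None, args=(0,))

        # Queries that miss in down.sqlite are fetched from the upstream server
        downstream = create_server("unix://" + os.path.join(tmpdir, "down.sock"),
                                   os.path.join(tmpdir, "down.sqlite"),
                                   upstream=up_addr)
        downstream.serve_as_process(prefunc=lambda server, idx: None, args=(1,))

        for srv in (downstream, upstream):
            srv.process.terminate()
            srv.process.join()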
@@ -258,15 +406,57 @@ class HashEquivalenceCommonTests(object): | |||
258 | result = down_client.report_unihash(taskhash6, self.METHOD, outhash5, unihash6) | 406 | result = down_client.report_unihash(taskhash6, self.METHOD, outhash5, unihash6) |
259 | self.assertEqual(result['unihash'], unihash5, 'Server failed to copy unihash from upstream') | 407 | self.assertEqual(result['unihash'], unihash5, 'Server failed to copy unihash from upstream') |
260 | 408 | ||
409 | # Tests read-through from the upstream server | ||
410 | taskhash7 = '9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74' | ||
411 | outhash7 = '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69' | ||
412 | unihash7 = '05d2a63c81e32f0a36542ca677e8ad852365c538' | ||
413 | self.client.report_unihash(taskhash7, self.METHOD, outhash7, unihash7) | ||
414 | |||
415 | result = down_client.get_taskhash(self.METHOD, taskhash7, True) | ||
416 | self.assertEqual(result['unihash'], unihash7, 'Server failed to copy unihash from upstream') | ||
417 | self.assertEqual(result['outhash'], outhash7, 'Server failed to copy outhash from upstream') | ||
418 | self.assertEqual(result['taskhash'], taskhash7, 'Server failed to copy taskhash from upstream') | ||
419 | self.assertEqual(result['method'], self.METHOD) | ||
420 | |||
421 | taskhash8 = '86978a4c8c71b9b487330b0152aade10c1ee58aa' | ||
422 | outhash8 = 'ca8c128e9d9e4a28ef24d0508aa20b5cf880604eacd8f65c0e366f7e0cc5fbcf' | ||
423 | unihash8 = 'd8bcf25369d40590ad7d08c84d538982f2023e01' | ||
424 | self.client.report_unihash(taskhash8, self.METHOD, outhash8, unihash8) | ||
425 | |||
426 | result = down_client.get_outhash(self.METHOD, outhash8, taskhash8) | ||
427 | self.assertEqual(result['unihash'], unihash8, 'Server failed to copy unihash from upstream') | ||
428 | self.assertEqual(result['outhash'], outhash8, 'Server failed to copy unihash from upstream') | ||
429 | self.assertEqual(result['taskhash'], taskhash8, 'Server failed to copy unihash from upstream') | ||
430 | self.assertEqual(result['method'], self.METHOD) | ||
431 | |||
432 | taskhash9 = 'ae6339531895ddf5b67e663e6a374ad8ec71d81c' | ||
433 | outhash9 = 'afc78172c81880ae10a1fec994b5b4ee33d196a001a1b66212a15ebe573e00b5' | ||
434 | unihash9 = '6662e699d6e3d894b24408ff9a4031ef9b038ee8' | ||
435 | self.client.report_unihash(taskhash9, self.METHOD, outhash9, unihash9) | ||
436 | |||
437 | result = down_client.get_taskhash(self.METHOD, taskhash9, False) | ||
438 | self.assertEqual(result['unihash'], unihash9, 'Server failed to copy unihash from upstream') | ||
439 | self.assertEqual(result['taskhash'], taskhash9, 'Server failed to copy taskhash from upstream') | ||
440 | self.assertEqual(result['method'], self.METHOD) | ||
441 | |||
442 | def test_unihash_exists(self): | ||
443 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
444 | self.assertTrue(self.client.unihash_exists(unihash)) | ||
445 | self.assertFalse(self.client.unihash_exists('6662e699d6e3d894b24408ff9a4031ef9b038ee8')) | ||
446 | |||
261 | def test_ro_server(self): | 447 | def test_ro_server(self): |
262 | (ro_client, ro_server) = self.start_server(dbpath=self.server.dbpath, read_only=True) | 448 | rw_server = self.start_server() |
449 | rw_client = self.start_client(rw_server.address) | ||
450 | |||
451 | ro_server = self.start_server(dbpath=rw_server.dbpath, read_only=True) | ||
452 | ro_client = self.start_client(ro_server.address) | ||
263 | 453 | ||
264 | # Report a hash via the read-write server | 454 | # Report a hash via the read-write server |
265 | taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' | 455 | taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' |
266 | outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' | 456 | outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' |
267 | unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' | 457 | unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' |
268 | 458 | ||
269 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | 459 | result = rw_client.report_unihash(taskhash, self.METHOD, outhash, unihash) |
270 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | 460 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') |
271 | 461 | ||
272 | # Check the hash via the read-only server | 462 | # Check the hash via the read-only server |
@@ -277,11 +467,934 @@ class HashEquivalenceCommonTests(object): | |||
277 | outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' | 467 | outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' |
278 | unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824' | 468 | unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824' |
279 | 469 | ||
280 | with self.assertRaises(HashConnectionError): | 470 | result = ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) |
281 | ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) | 471 | self.assertEqual(result['unihash'], unihash2) |
282 | 472 | ||
283 | # Ensure that the database was not modified | 473 | # Ensure that the database was not modified |
474 | self.assertClientGetHash(rw_client, taskhash2, None) | ||
475 | |||
476 | |||
477 | def test_slow_server_start(self): | ||
478 | # Ensures that the server will exit correctly even if it gets a SIGTERM | ||
479 | # before entering the main loop | ||
480 | |||
481 | event = multiprocessing.Event() | ||
482 | |||
483 | def prefunc(server, idx): | ||
484 | nonlocal event | ||
485 | server_prefunc(server, idx) | ||
486 | event.wait() | ||
487 | |||
488 | def do_nothing(signum, frame): | ||
489 | pass | ||
490 | |||
491 | old_signal = signal.signal(signal.SIGTERM, do_nothing) | ||
492 | self.addCleanup(signal.signal, signal.SIGTERM, old_signal) | ||
493 | |||
494 | server = self.start_server(prefunc=prefunc) | ||
495 | server.process.terminate() | ||
496 | time.sleep(30) | ||
497 | event.set() | ||
498 | server.process.join(300) | ||
499 | self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!") | ||
500 | |||
501 | def test_diverging_report_race(self): | ||
502 | # Tests that a reported task will correctly pick up an updated unihash | ||
503 | |||
504 | # This is a baseline report added to the database to ensure that there | ||
505 | # is something to match against as equivalent | ||
506 | outhash1 = 'afd11c366050bcd75ad763e898e4430e2a60659b26f83fbb22201a60672019fa' | ||
507 | taskhash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab' | ||
508 | unihash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab' | ||
509 | result = self.client.report_unihash(taskhash1, self.METHOD, outhash1, unihash1) | ||
510 | |||
511 | # Add a report that is equivalent to Task 1. It should ignore the | ||
512 | # provided unihash and report the unihash from task 1 | ||
513 | taskhash2 = '6259ae8263bd94d454c086f501c37e64c4e83cae806902ca95b4ab513546b273' | ||
514 | unihash2 = taskhash2 | ||
515 | result = self.client.report_unihash(taskhash2, self.METHOD, outhash1, unihash2) | ||
516 | self.assertEqual(result['unihash'], unihash1) | ||
517 | |||
518 | # Add another report for Task 2, but with a different outhash (e.g. the | ||
519 | # task is non-deterministic). It should still be marked with the Task 1 | ||
520 | # unihash because it has the Task 2 taskhash, which is equivalent to | ||
521 | # Task 1 | ||
522 | outhash3 = 'd2187ee3a8966db10b34fe0e863482288d9a6185cb8ef58a6c1c6ace87a2f24c' | ||
523 | result = self.client.report_unihash(taskhash2, self.METHOD, outhash3, unihash2) | ||
524 | self.assertEqual(result['unihash'], unihash1) | ||
525 | |||
526 | |||
527 | def test_diverging_report_reverse_race(self): | ||
528 | # Same idea as the previous test, but Tasks 2 and 3 are reported in | ||
529 | # the opposite order | ||
530 | |||
531 | outhash1 = 'afd11c366050bcd75ad763e898e4430e2a60659b26f83fbb22201a60672019fa' | ||
532 | taskhash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab' | ||
533 | unihash1 = '3bde230c743fc45ab61a065d7a1815fbfa01c4740e4c895af2eb8dc0f684a4ab' | ||
534 | result = self.client.report_unihash(taskhash1, self.METHOD, outhash1, unihash1) | ||
535 | |||
536 | taskhash2 = '6259ae8263bd94d454c086f501c37e64c4e83cae806902ca95b4ab513546b273' | ||
537 | unihash2 = taskhash2 | ||
538 | |||
539 | # Report Task 3 first. Since there is nothing else in the database it | ||
540 | # will use the client provided unihash | ||
541 | outhash3 = 'd2187ee3a8966db10b34fe0e863482288d9a6185cb8ef58a6c1c6ace87a2f24c' | ||
542 | result = self.client.report_unihash(taskhash2, self.METHOD, outhash3, unihash2) | ||
543 | self.assertEqual(result['unihash'], unihash2) | ||
544 | |||
545 | # Report Task 2. This is equivalent to Task 1 but there is already a mapping for | ||
546 | # taskhash2 so it will report unihash2 | ||
547 | result = self.client.report_unihash(taskhash2, self.METHOD, outhash1, unihash2) | ||
548 | self.assertEqual(result['unihash'], unihash2) | ||
549 | |||
550 | # The originally reported unihash for Task 3 should be unchanged even if it | ||
551 | # shares a taskhash with Task 2 | ||
552 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
553 | |||
554 | def test_get_unihash_batch(self): | ||
555 | TEST_INPUT = ( | ||
556 | # taskhash outhash unihash | ||
557 | ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'), | ||
558 | # Duplicated taskhash with multiple output hashes and unihashes. | ||
559 | ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'), | ||
560 | # Equivalent hash | ||
561 | ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"), | ||
562 | ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"), | ||
563 | ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'), | ||
564 | ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'), | ||
565 | ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'), | ||
566 | ) | ||
567 | EXTRA_QUERIES = ( | ||
568 | "6b6be7a84ab179b4240c4302518dc3f6", | ||
569 | ) | ||
570 | |||
571 | for taskhash, outhash, unihash in TEST_INPUT: | ||
572 | self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | ||
573 | |||
574 | |||
575 | result = self.client.get_unihash_batch( | ||
576 | [(self.METHOD, data[0]) for data in TEST_INPUT] + | ||
577 | [(self.METHOD, e) for e in EXTRA_QUERIES] | ||
578 | ) | ||
579 | |||
580 | self.assertListEqual(result, [ | ||
581 | "218e57509998197d570e2c98512d0105985dffc9", | ||
582 | "218e57509998197d570e2c98512d0105985dffc9", | ||
583 | "218e57509998197d570e2c98512d0105985dffc9", | ||
584 | "3b5d3d83f07f259e9086fcb422c855286e18a57d", | ||
585 | "f46d3fbb439bd9b921095da657a4de906510d2cd", | ||
586 | "f46d3fbb439bd9b921095da657a4de906510d2cd", | ||
587 | "05d2a63c81e32f0a36542ca677e8ad852365c538", | ||
588 | None, | ||
589 | ]) | ||
590 | |||
591 | def test_unihash_exists_batch(self): | ||
592 | TEST_INPUT = ( | ||
593 | # taskhash outhash unihash | ||
594 | ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'), | ||
595 | # Duplicated taskhash with multiple output hashes and unihashes. | ||
596 | ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'), | ||
597 | # Equivalent hash | ||
598 | ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"), | ||
599 | ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"), | ||
600 | ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'), | ||
601 | ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'), | ||
602 | ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'), | ||
603 | ) | ||
604 | EXTRA_QUERIES = ( | ||
605 | "6b6be7a84ab179b4240c4302518dc3f6", | ||
606 | ) | ||
607 | |||
608 | result_unihashes = set() | ||
609 | |||
610 | |||
611 | for taskhash, outhash, unihash in TEST_INPUT: | ||
612 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | ||
613 | result_unihashes.add(result["unihash"]) | ||
614 | |||
615 | query = [] | ||
616 | expected = [] | ||
617 | |||
618 | for _, _, unihash in TEST_INPUT: | ||
619 | query.append(unihash) | ||
620 | expected.append(unihash in result_unihashes) | ||
621 | |||
622 | |||
623 | for unihash in EXTRA_QUERIES: | ||
624 | query.append(unihash) | ||
625 | expected.append(False) | ||
626 | |||
627 | result = self.client.unihash_exists_batch(query) | ||
628 | self.assertListEqual(result, expected) | ||
629 | |||
630 | def test_auth_read_perms(self): | ||
631 | admin_client = self.start_auth_server() | ||
632 | |||
633 | # Create hashes with non-authenticated server | ||
634 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
635 | |||
636 | # Validate hash can be retrieved using authenticated client | ||
637 | with self.auth_perms("@read") as client: | ||
638 | self.assertClientGetHash(client, taskhash, unihash) | ||
639 | |||
640 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
641 | self.assertClientGetHash(client, taskhash, unihash) | ||
642 | |||
643 | def test_auth_report_perms(self): | ||
644 | admin_client = self.start_auth_server() | ||
645 | |||
646 | # Without read permission, the user is completely denied | ||
647 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
648 | self.create_test_hash(client) | ||
649 | |||
650 | # Read permission allows the call to succeed, but it doesn't record | ||
651 | # anything in the database | ||
652 | with self.auth_perms("@read") as client: | ||
653 | taskhash, outhash, unihash = self.create_test_hash(client) | ||
654 | self.assertClientGetHash(client, taskhash, None) | ||
655 | |||
656 | # Report permission alone is insufficient | ||
657 | with self.auth_perms("@report") as client, self.assertRaises(InvokeError): | ||
658 | self.create_test_hash(client) | ||
659 | |||
660 | # Read and report permissions actually modify the database | ||
661 | with self.auth_perms("@read", "@report") as client: | ||
662 | taskhash, outhash, unihash = self.create_test_hash(client) | ||
663 | self.assertClientGetHash(client, taskhash, unihash) | ||
664 | |||
665 | def test_auth_no_token_refresh_from_anon_user(self): | ||
666 | self.start_auth_server() | ||
667 | |||
668 | with self.start_client(self.auth_server_address) as client, self.assertRaises(InvokeError): | ||
669 | client.refresh_token() | ||
670 | |||
671 | def test_auth_self_token_refresh(self): | ||
672 | admin_client = self.start_auth_server() | ||
673 | |||
674 | # Create a new user with no permissions | ||
675 | user = self.create_user("test-user", []) | ||
676 | |||
677 | with self.auth_client(user) as client: | ||
678 | new_user = client.refresh_token() | ||
679 | |||
680 | self.assertEqual(user["username"], new_user["username"]) | ||
681 | self.assertNotEqual(user["token"], new_user["token"]) | ||
682 | self.assertUserCanAuth(new_user) | ||
683 | self.assertUserCannotAuth(user) | ||
684 | |||
685 | # Explicitly specifying with your own username is fine also | ||
686 | with self.auth_client(new_user) as client: | ||
687 | new_user2 = client.refresh_token(user["username"]) | ||
688 | |||
689 | self.assertEqual(user["username"], new_user2["username"]) | ||
690 | self.assertNotEqual(user["token"], new_user2["token"]) | ||
691 | self.assertUserCanAuth(new_user2) | ||
692 | self.assertUserCannotAuth(new_user) | ||
693 | self.assertUserCannotAuth(user) | ||
694 | |||
695 | def test_auth_token_refresh(self): | ||
696 | admin_client = self.start_auth_server() | ||
697 | |||
698 | user = self.create_user("test-user", []) | ||
699 | |||
700 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
701 | client.refresh_token(user["username"]) | ||
702 | |||
703 | with self.auth_perms("@user-admin") as client: | ||
704 | new_user = client.refresh_token(user["username"]) | ||
705 | |||
706 | self.assertEqual(user["username"], new_user["username"]) | ||
707 | self.assertNotEqual(user["token"], new_user["token"]) | ||
708 | self.assertUserCanAuth(new_user) | ||
709 | self.assertUserCannotAuth(user) | ||
710 | |||
711 | def test_auth_self_get_user(self): | ||
712 | admin_client = self.start_auth_server() | ||
713 | |||
714 | user = self.create_user("test-user", []) | ||
715 | user_info = user.copy() | ||
716 | del user_info["token"] | ||
717 | |||
718 | with self.auth_client(user) as client: | ||
719 | info = client.get_user() | ||
720 | self.assertEqual(info, user_info) | ||
721 | |||
722 | # Explicitly asking for your own username is fine also | ||
723 | info = client.get_user(user["username"]) | ||
724 | self.assertEqual(info, user_info) | ||
725 | |||
726 | def test_auth_get_user(self): | ||
727 | admin_client = self.start_auth_server() | ||
728 | |||
729 | user = self.create_user("test-user", []) | ||
730 | user_info = user.copy() | ||
731 | del user_info["token"] | ||
732 | |||
733 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
734 | client.get_user(user["username"]) | ||
735 | |||
736 | with self.auth_perms("@user-admin") as client: | ||
737 | info = client.get_user(user["username"]) | ||
738 | self.assertEqual(info, user_info) | ||
739 | |||
740 | info = client.get_user("nonexist-user") | ||
741 | self.assertIsNone(info) | ||
742 | |||
743 | def test_auth_reconnect(self): | ||
744 | admin_client = self.start_auth_server() | ||
745 | |||
746 | user = self.create_user("test-user", []) | ||
747 | user_info = user.copy() | ||
748 | del user_info["token"] | ||
749 | |||
750 | with self.auth_client(user) as client: | ||
751 | info = client.get_user() | ||
752 | self.assertEqual(info, user_info) | ||
753 | |||
754 | client.disconnect() | ||
755 | |||
756 | info = client.get_user() | ||
757 | self.assertEqual(info, user_info) | ||
758 | |||
759 | def test_auth_delete_user(self): | ||
760 | admin_client = self.start_auth_server() | ||
761 | |||
762 | user = self.create_user("test-user", []) | ||
763 | |||
764 | # self service | ||
765 | with self.auth_client(user) as client: | ||
766 | client.delete_user(user["username"]) | ||
767 | |||
768 | self.assertIsNone(admin_client.get_user(user["username"])) | ||
769 | user = self.create_user("test-user", []) | ||
770 | |||
771 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
772 | client.delete_user(user["username"]) | ||
773 | |||
774 | with self.auth_perms("@user-admin") as client: | ||
775 | client.delete_user(user["username"]) | ||
776 | |||
777 | # User doesn't exist, so even though the permission is correct, it's an | ||
778 | # error | ||
779 | with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError): | ||
780 | client.delete_user(user["username"]) | ||
781 | |||
782 | def test_auth_set_user_perms(self): | ||
783 | admin_client = self.start_auth_server() | ||
784 | |||
785 | user = self.create_user("test-user", []) | ||
786 | |||
787 | self.assertUserPerms(user, []) | ||
788 | |||
789 | # No self service to change permissions | ||
790 | with self.auth_client(user) as client, self.assertRaises(InvokeError): | ||
791 | client.set_user_perms(user["username"], ["@all"]) | ||
792 | self.assertUserPerms(user, []) | ||
793 | |||
794 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
795 | client.set_user_perms(user["username"], ["@all"]) | ||
796 | self.assertUserPerms(user, []) | ||
797 | |||
798 | with self.auth_perms("@user-admin") as client: | ||
799 | client.set_user_perms(user["username"], ["@all"]) | ||
800 | self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS))) | ||
801 | |||
802 | # Bad permissions | ||
803 | with self.auth_perms("@user-admin") as client, self.assertRaises(InvokeError): | ||
804 | client.set_user_perms(user["username"], ["@this-is-not-a-permission"]) | ||
805 | self.assertUserPerms(user, sorted(list(ALL_PERMISSIONS))) | ||
806 | |||
807 | def test_auth_get_all_users(self): | ||
808 | admin_client = self.start_auth_server() | ||
809 | |||
810 | user = self.create_user("test-user", []) | ||
811 | |||
812 | with self.auth_client(user) as client, self.assertRaises(InvokeError): | ||
813 | client.get_all_users() | ||
814 | |||
815 | # Give the test user the correct permission | ||
816 | admin_client.set_user_perms(user["username"], ["@user-admin"]) | ||
817 | |||
818 | with self.auth_client(user) as client: | ||
819 | all_users = client.get_all_users() | ||
820 | |||
821 | # Convert to a dictionary for easier comparison | ||
822 | all_users = {u["username"]: u for u in all_users} | ||
823 | |||
824 | self.assertEqual(all_users, | ||
825 | { | ||
826 | "admin": { | ||
827 | "username": "admin", | ||
828 | "permissions": sorted(list(ALL_PERMISSIONS)), | ||
829 | }, | ||
830 | "test-user": { | ||
831 | "username": "test-user", | ||
832 | "permissions": ["@user-admin"], | ||
833 | } | ||
834 | } | ||
835 | ) | ||
836 | |||
837 | def test_auth_new_user(self): | ||
838 | self.start_auth_server() | ||
839 | |||
840 | permissions = ["@read", "@report", "@db-admin", "@user-admin"] | ||
841 | permissions.sort() | ||
842 | |||
843 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
844 | self.create_user("test-user", permissions, client=client) | ||
845 | |||
846 | with self.auth_perms("@user-admin") as client: | ||
847 | user = self.create_user("test-user", permissions, client=client) | ||
848 | self.assertIn("token", user) | ||
849 | self.assertEqual(user["username"], "test-user") | ||
850 | self.assertEqual(user["permissions"], permissions) | ||
851 | |||
852 | def test_auth_become_user(self): | ||
853 | admin_client = self.start_auth_server() | ||
854 | |||
855 | user = self.create_user("test-user", ["@read", "@report"]) | ||
856 | user_info = user.copy() | ||
857 | del user_info["token"] | ||
858 | |||
859 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
860 | client.become_user(user["username"]) | ||
861 | |||
862 | with self.auth_perms("@user-admin") as client: | ||
863 | become = client.become_user(user["username"]) | ||
864 | self.assertEqual(become, user_info) | ||
865 | |||
866 | info = client.get_user() | ||
867 | self.assertEqual(info, user_info) | ||
868 | |||
869 | # Verify become user is preserved across disconnect | ||
870 | client.disconnect() | ||
871 | |||
872 | info = client.get_user() | ||
873 | self.assertEqual(info, user_info) | ||
874 | |||
875 | # test-user doesn't have become_user permissions, so this should | ||
876 | # not work | ||
877 | with self.assertRaises(InvokeError): | ||
878 | client.become_user(user["username"]) | ||
879 | |||
880 | # No self-service of become | ||
881 | with self.auth_client(user) as client, self.assertRaises(InvokeError): | ||
882 | client.become_user(user["username"]) | ||
883 | |||
884 | # Give test user permissions to become | ||
885 | admin_client.set_user_perms(user["username"], ["@user-admin"]) | ||
886 | |||
887 | # It's possible to become yourself (effectively a noop) | ||
888 | with self.auth_perms("@user-admin") as client: | ||
889 | become = client.become_user(client.username) | ||
890 | |||
891 | def test_auth_gc(self): | ||
892 | admin_client = self.start_auth_server() | ||
893 | |||
894 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
895 | client.gc_mark("ABC", {"unihash": "123"}) | ||
896 | |||
897 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
898 | client.gc_status() | ||
899 | |||
900 | with self.auth_perms() as client, self.assertRaises(InvokeError): | ||
901 | client.gc_sweep("ABC") | ||
902 | |||
903 | with self.auth_perms("@db-admin") as client: | ||
904 | client.gc_mark("ABC", {"unihash": "123"}) | ||
905 | |||
906 | with self.auth_perms("@db-admin") as client: | ||
907 | client.gc_status() | ||
908 | |||
909 | with self.auth_perms("@db-admin") as client: | ||
910 | client.gc_sweep("ABC") | ||
911 | |||
912 | def test_get_db_usage(self): | ||
913 | usage = self.client.get_db_usage() | ||
914 | |||
915 | self.assertTrue(isinstance(usage, dict)) | ||
916 | for name in usage.keys(): | ||
917 | self.assertTrue(isinstance(usage[name], dict)) | ||
918 | self.assertIn("rows", usage[name]) | ||
919 | self.assertTrue(isinstance(usage[name]["rows"], int)) | ||
920 | |||
921 | def test_get_db_query_columns(self): | ||
922 | columns = self.client.get_db_query_columns() | ||
923 | |||
924 | self.assertTrue(isinstance(columns, list)) | ||
925 | self.assertTrue(len(columns) > 0) | ||
926 | |||
927 | for col in columns: | ||
928 | self.client.remove({col: ""}) | ||
929 | |||
930 | def test_auth_is_owner(self): | ||
931 | admin_client = self.start_auth_server() | ||
932 | |||
933 | user = self.create_user("test-user", ["@read", "@report"]) | ||
934 | with self.auth_client(user) as client: | ||
935 | taskhash, outhash, unihash = self.create_test_hash(client) | ||
936 | data = client.get_taskhash(self.METHOD, taskhash, True) | ||
937 | self.assertEqual(data["owner"], user["username"]) | ||
938 | |||
939 | def test_gc(self): | ||
940 | taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' | ||
941 | outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' | ||
942 | unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' | ||
943 | |||
944 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | ||
945 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | ||
946 | |||
947 | taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' | ||
948 | outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' | ||
949 | unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' | ||
950 | |||
951 | result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) | ||
952 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
953 | |||
954 | # Mark the first unihash to be kept | ||
955 | ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) | ||
956 | self.assertEqual(ret, {"count": 1}) | ||
957 | |||
958 | ret = self.client.gc_status() | ||
959 | self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1}) | ||
960 | |||
961 | # Second hash is still there; mark doesn't delete hashes | ||
962 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
963 | |||
964 | ret = self.client.gc_sweep("ABC") | ||
965 | self.assertEqual(ret, {"count": 1}) | ||
966 | |||
967 | # Second hash is gone; no unihash is returned for its taskhash | ||
284 | self.assertClientGetHash(self.client, taskhash2, None) | 968 | self.assertClientGetHash(self.client, taskhash2, None) |
969 | # First hash is still present | ||
970 | self.assertClientGetHash(self.client, taskhash, unihash) | ||
971 | |||
972 | def test_gc_switch_mark(self): | ||
973 | taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' | ||
974 | outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' | ||
975 | unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' | ||
976 | |||
977 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | ||
978 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | ||
979 | |||
980 | taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' | ||
981 | outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' | ||
982 | unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' | ||
983 | |||
984 | result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) | ||
985 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
986 | |||
987 | # Mark the first unihash to be kept | ||
988 | ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) | ||
989 | self.assertEqual(ret, {"count": 1}) | ||
990 | |||
991 | ret = self.client.gc_status() | ||
992 | self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1}) | ||
993 | |||
994 | # Second hash is still there; mark doesn't delete hashes | ||
995 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
996 | |||
997 | # Switch to a different mark and mark the second hash. This will start | ||
998 | # a new collection cycle | ||
999 | ret = self.client.gc_mark("DEF", {"unihash": unihash2, "method": self.METHOD}) | ||
1000 | self.assertEqual(ret, {"count": 1}) | ||
1001 | |||
1002 | ret = self.client.gc_status() | ||
1003 | self.assertEqual(ret, {"mark": "DEF", "keep": 1, "remove": 1}) | ||
1004 | |||
1005 | # Both hashes are still present | ||
1006 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
1007 | self.assertClientGetHash(self.client, taskhash, unihash) | ||
1008 | |||
1009 | # Sweep with the new mark | ||
1010 | ret = self.client.gc_sweep("DEF") | ||
1011 | self.assertEqual(ret, {"count": 1}) | ||
1012 | |||
1013 | # First hash is gone, second is kept | ||
1014 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
1015 | self.assertClientGetHash(self.client, taskhash, None) | ||
1016 | |||
1017 | def test_gc_switch_sweep_mark(self): | ||
1018 | taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' | ||
1019 | outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' | ||
1020 | unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' | ||
1021 | |||
1022 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | ||
1023 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | ||
1024 | |||
1025 | taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' | ||
1026 | outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' | ||
1027 | unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' | ||
1028 | |||
1029 | result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) | ||
1030 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
1031 | |||
1032 | # Mark the first unihash to be kept | ||
1033 | ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) | ||
1034 | self.assertEqual(ret, {"count": 1}) | ||
1035 | |||
1036 | ret = self.client.gc_status() | ||
1037 | self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 1}) | ||
1038 | |||
1039 | # Sweeping with a different mark raises an error | ||
1040 | with self.assertRaises(InvokeError): | ||
1041 | self.client.gc_sweep("DEF") | ||
1042 | |||
1043 | # Both hashes are present | ||
1044 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
1045 | self.assertClientGetHash(self.client, taskhash, unihash) | ||
1046 | |||
1047 | def test_gc_new_hashes(self): | ||
1048 | taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' | ||
1049 | outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' | ||
1050 | unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' | ||
1051 | |||
1052 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | ||
1053 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | ||
1054 | |||
1055 | # Start a new garbage collection | ||
1056 | ret = self.client.gc_mark("ABC", {"unihash": unihash, "method": self.METHOD}) | ||
1057 | self.assertEqual(ret, {"count": 1}) | ||
1058 | |||
1059 | ret = self.client.gc_status() | ||
1060 | self.assertEqual(ret, {"mark": "ABC", "keep": 1, "remove": 0}) | ||
1061 | |||
1062 | # Add second hash. It should inherit the mark from the current garbage | ||
1063 | # collection operation | ||
1064 | |||
1065 | taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' | ||
1066 | outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' | ||
1067 | unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' | ||
1068 | |||
1069 | result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) | ||
1070 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
1071 | |||
1072 | # Sweep should remove nothing | ||
1073 | ret = self.client.gc_sweep("ABC") | ||
1074 | self.assertEqual(ret, {"count": 0}) | ||
1075 | |||
1076 | # Both hashes are present | ||
1077 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
1078 | self.assertClientGetHash(self.client, taskhash, unihash) | ||
1079 | |||
1080 | |||
1081 | class TestHashEquivalenceClient(HashEquivalenceTestSetup, unittest.TestCase): | ||
1082 | def get_server_addr(self, server_idx): | ||
1083 | return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx) | ||
1084 | |||
1085 | def test_get(self): | ||
1086 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
1087 | |||
1088 | p = self.run_hashclient(["--address", self.server_address, "get", self.METHOD, taskhash]) | ||
1089 | data = json.loads(p.stdout) | ||
1090 | self.assertEqual(data["unihash"], unihash) | ||
1091 | self.assertEqual(data["outhash"], outhash) | ||
1092 | self.assertEqual(data["taskhash"], taskhash) | ||
1093 | self.assertEqual(data["method"], self.METHOD) | ||
1094 | |||
1095 | def test_get_outhash(self): | ||
1096 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
1097 | |||
1098 | p = self.run_hashclient(["--address", self.server_address, "get-outhash", self.METHOD, outhash, taskhash]) | ||
1099 | data = json.loads(p.stdout) | ||
1100 | self.assertEqual(data["unihash"], unihash) | ||
1101 | self.assertEqual(data["outhash"], outhash) | ||
1102 | self.assertEqual(data["taskhash"], taskhash) | ||
1103 | self.assertEqual(data["method"], self.METHOD) | ||
1104 | |||
1105 | def test_stats(self): | ||
1106 | p = self.run_hashclient(["--address", self.server_address, "stats"], check=True) | ||
1107 | json.loads(p.stdout) | ||
1108 | |||
1109 | def test_stress(self): | ||
1110 | self.run_hashclient(["--address", self.server_address, "stress"], check=True) | ||
1111 | |||
1112 | def test_unihash_exists(self): | ||
1113 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
1114 | |||
1115 | p = self.run_hashclient([ | ||
1116 | "--address", self.server_address, | ||
1117 | "unihash-exists", unihash, | ||
1118 | ], check=True) | ||
1119 | self.assertEqual(p.stdout.strip(), "true") | ||
1120 | |||
1121 | p = self.run_hashclient([ | ||
1122 | "--address", self.server_address, | ||
1123 | "unihash-exists", '6662e699d6e3d894b24408ff9a4031ef9b038ee8', | ||
1124 | ], check=True) | ||
1125 | self.assertEqual(p.stdout.strip(), "false") | ||
1126 | |||
1127 | def test_unihash_exists_quiet(self): | ||
1128 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
1129 | |||
1130 | p = self.run_hashclient([ | ||
1131 | "--address", self.server_address, | ||
1132 | "unihash-exists", unihash, | ||
1133 | "--quiet", | ||
1134 | ]) | ||
1135 | self.assertEqual(p.returncode, 0) | ||
1136 | self.assertEqual(p.stdout.strip(), "") | ||
1137 | |||
1138 | p = self.run_hashclient([ | ||
1139 | "--address", self.server_address, | ||
1140 | "unihash-exists", '6662e699d6e3d894b24408ff9a4031ef9b038ee8', | ||
1141 | "--quiet", | ||
1142 | ]) | ||
1143 | self.assertEqual(p.returncode, 1) | ||
1144 | self.assertEqual(p.stdout.strip(), "") | ||
1145 | |||
1146 | def test_remove_taskhash(self): | ||
1147 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
1148 | self.run_hashclient([ | ||
1149 | "--address", self.server_address, | ||
1150 | "remove", | ||
1151 | "--where", "taskhash", taskhash, | ||
1152 | ], check=True) | ||
1153 | self.assertClientGetHash(self.client, taskhash, None) | ||
1154 | |||
1155 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) | ||
1156 | self.assertIsNone(result_outhash) | ||
1157 | |||
1158 | def test_remove_unihash(self): | ||
1159 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
1160 | self.run_hashclient([ | ||
1161 | "--address", self.server_address, | ||
1162 | "remove", | ||
1163 | "--where", "unihash", unihash, | ||
1164 | ], check=True) | ||
1165 | self.assertClientGetHash(self.client, taskhash, None) | ||
1166 | |||
1167 | def test_remove_outhash(self): | ||
1168 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
1169 | self.run_hashclient([ | ||
1170 | "--address", self.server_address, | ||
1171 | "remove", | ||
1172 | "--where", "outhash", outhash, | ||
1173 | ], check=True) | ||
1174 | |||
1175 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) | ||
1176 | self.assertIsNone(result_outhash) | ||
1177 | |||
1178 | def test_remove_method(self): | ||
1179 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
1180 | self.run_hashclient([ | ||
1181 | "--address", self.server_address, | ||
1182 | "remove", | ||
1183 | "--where", "method", self.METHOD, | ||
1184 | ], check=True) | ||
1185 | self.assertClientGetHash(self.client, taskhash, None) | ||
1186 | |||
1187 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash) | ||
1188 | self.assertIsNone(result_outhash) | ||
1189 | |||
1190 | def test_clean_unused(self): | ||
1191 | taskhash, outhash, unihash = self.create_test_hash(self.client) | ||
1192 | |||
1193 | # Clean the database, which should not remove anything because all hashes are in use | ||
1194 | self.run_hashclient([ | ||
1195 | "--address", self.server_address, | ||
1196 | "clean-unused", "0", | ||
1197 | ], check=True) | ||
1198 | self.assertClientGetHash(self.client, taskhash, unihash) | ||
1199 | |||
1200 | # Remove the unihash. The row in the outhash table should still be present | ||
1201 | self.run_hashclient([ | ||
1202 | "--address", self.server_address, | ||
1203 | "remove", | ||
1204 | "--where", "unihash", unihash, | ||
1205 | ], check=True) | ||
1206 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False) | ||
1207 | self.assertIsNotNone(result_outhash) | ||
1208 | |||
1209 | # Now clean with no minimum age which will remove the outhash | ||
1210 | self.run_hashclient([ | ||
1211 | "--address", self.server_address, | ||
1212 | "clean-unused", "0", | ||
1213 | ], check=True) | ||
1214 | result_outhash = self.client.get_outhash(self.METHOD, outhash, taskhash, False) | ||
1215 | self.assertIsNone(result_outhash) | ||
1216 | |||
1217 | def test_refresh_token(self): | ||
1218 | admin_client = self.start_auth_server() | ||
1219 | |||
1220 | user = admin_client.new_user("test-user", ["@read", "@report"]) | ||
1221 | |||
1222 | p = self.run_hashclient([ | ||
1223 | "--address", self.auth_server_address, | ||
1224 | "--login", user["username"], | ||
1225 | "--password", user["token"], | ||
1226 | "refresh-token" | ||
1227 | ], check=True) | ||
1228 | |||
1229 | new_token = None | ||
1230 | for l in p.stdout.splitlines(): | ||
1231 | l = l.rstrip() | ||
1232 | m = re.match(r'Token: +(.*)$', l) | ||
1233 | if m is not None: | ||
1234 | new_token = m.group(1) | ||
1235 | |||
1236 | self.assertTrue(new_token) | ||
1237 | |||
1238 | print("New token is %r" % new_token) | ||
1239 | |||
1240 | self.run_hashclient([ | ||
1241 | "--address", self.auth_server_address, | ||
1242 | "--login", user["username"], | ||
1243 | "--password", new_token, | ||
1244 | "get-user" | ||
1245 | ], check=True) | ||
1246 | |||
1247 | def test_set_user_perms(self): | ||
1248 | admin_client = self.start_auth_server() | ||
1249 | |||
1250 | user = admin_client.new_user("test-user", ["@read"]) | ||
1251 | |||
1252 | self.run_hashclient([ | ||
1253 | "--address", self.auth_server_address, | ||
1254 | "--login", admin_client.username, | ||
1255 | "--password", admin_client.password, | ||
1256 | "set-user-perms", | ||
1257 | "-u", user["username"], | ||
1258 | "@read", "@report", | ||
1259 | ], check=True) | ||
1260 | |||
1261 | new_user = admin_client.get_user(user["username"]) | ||
1262 | |||
1263 | self.assertEqual(set(new_user["permissions"]), {"@read", "@report"}) | ||
1264 | |||
1265 | def test_get_user(self): | ||
1266 | admin_client = self.start_auth_server() | ||
1267 | |||
1268 | user = admin_client.new_user("test-user", ["@read"]) | ||
1269 | |||
1270 | p = self.run_hashclient([ | ||
1271 | "--address", self.auth_server_address, | ||
1272 | "--login", admin_client.username, | ||
1273 | "--password", admin_client.password, | ||
1274 | "get-user", | ||
1275 | "-u", user["username"], | ||
1276 | ], check=True) | ||
1277 | |||
1278 | self.assertIn("Username:", p.stdout) | ||
1279 | self.assertIn("Permissions:", p.stdout) | ||
1280 | |||
1281 | p = self.run_hashclient([ | ||
1282 | "--address", self.auth_server_address, | ||
1283 | "--login", user["username"], | ||
1284 | "--password", user["token"], | ||
1285 | "get-user", | ||
1286 | ], check=True) | ||
1287 | |||
1288 | self.assertIn("Username:", p.stdout) | ||
1289 | self.assertIn("Permissions:", p.stdout) | ||
1290 | |||
1291 | def test_get_all_users(self): | ||
1292 | admin_client = self.start_auth_server() | ||
1293 | |||
1294 | admin_client.new_user("test-user1", ["@read"]) | ||
1295 | admin_client.new_user("test-user2", ["@read"]) | ||
1296 | |||
1297 | p = self.run_hashclient([ | ||
1298 | "--address", self.auth_server_address, | ||
1299 | "--login", admin_client.username, | ||
1300 | "--password", admin_client.password, | ||
1301 | "get-all-users", | ||
1302 | ], check=True) | ||
1303 | |||
1304 | self.assertIn("admin", p.stdout) | ||
1305 | self.assertIn("test-user1", p.stdout) | ||
1306 | self.assertIn("test-user2", p.stdout) | ||
1307 | |||
1308 | def test_new_user(self): | ||
1309 | admin_client = self.start_auth_server() | ||
1310 | |||
1311 | p = self.run_hashclient([ | ||
1312 | "--address", self.auth_server_address, | ||
1313 | "--login", admin_client.username, | ||
1314 | "--password", admin_client.password, | ||
1315 | "new-user", | ||
1316 | "-u", "test-user", | ||
1317 | "@read", "@report", | ||
1318 | ], check=True) | ||
1319 | |||
1320 | new_token = None | ||
1321 | for l in p.stdout.splitlines(): | ||
1322 | l = l.rstrip() | ||
1323 | m = re.match(r'Token: +(.*)$', l) | ||
1324 | if m is not None: | ||
1325 | new_token = m.group(1) | ||
1326 | |||
1327 | self.assertTrue(new_token) | ||
1328 | |||
1329 | user = { | ||
1330 | "username": "test-user", | ||
1331 | "token": new_token, | ||
1332 | } | ||
1333 | |||
1334 | self.assertUserPerms(user, ["@read", "@report"]) | ||
1335 | |||
1336 | def test_delete_user(self): | ||
1337 | admin_client = self.start_auth_server() | ||
1338 | |||
1339 | user = admin_client.new_user("test-user", ["@read"]) | ||
1340 | |||
1341 | p = self.run_hashclient([ | ||
1342 | "--address", self.auth_server_address, | ||
1343 | "--login", admin_client.username, | ||
1344 | "--password", admin_client.password, | ||
1345 | "delete-user", | ||
1346 | "-u", user["username"], | ||
1347 | ], check=True) | ||
1348 | |||
1349 | self.assertIsNone(admin_client.get_user(user["username"])) | ||
1350 | |||
1351 | def test_get_db_usage(self): | ||
1352 | p = self.run_hashclient([ | ||
1353 | "--address", self.server_address, | ||
1354 | "get-db-usage", | ||
1355 | ], check=True) | ||
1356 | |||
1357 | def test_get_db_query_columns(self): | ||
1358 | p = self.run_hashclient([ | ||
1359 | "--address", self.server_address, | ||
1360 | "get-db-query-columns", | ||
1361 | ], check=True) | ||
1362 | |||
1363 | def test_gc(self): | ||
1364 | taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' | ||
1365 | outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' | ||
1366 | unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' | ||
1367 | |||
1368 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | ||
1369 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | ||
1370 | |||
1371 | taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' | ||
1372 | outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' | ||
1373 | unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' | ||
1374 | |||
1375 | result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) | ||
1376 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
1377 | |||
1378 | # Mark the first unihash to be kept | ||
1379 | self.run_hashclient([ | ||
1380 | "--address", self.server_address, | ||
1381 | "gc-mark", "ABC", | ||
1382 | "--where", "unihash", unihash, | ||
1383 | "--where", "method", self.METHOD | ||
1384 | ], check=True) | ||
1385 | |||
1386 | # Second hash is still there; mark doesn't delete hashes | ||
1387 | self.assertClientGetHash(self.client, taskhash2, unihash2) | ||
1388 | |||
1389 | self.run_hashclient([ | ||
1390 | "--address", self.server_address, | ||
1391 | "gc-sweep", "ABC", | ||
1392 | ], check=True) | ||
1393 | |||
1394 | # Second hash is gone; querying its taskhash no longer returns a unihash | ||
1395 | self.assertClientGetHash(self.client, taskhash2, None) | ||
1396 | # First hash is still present | ||
1397 | self.assertClientGetHash(self.client, taskhash, unihash) | ||
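The mark-and-sweep flow exercised by this test reduces to two CLI calls: gc-mark tags the rows matching the --where filters with a mark name, and gc-sweep deletes everything that does not carry that mark. A sketch under the same assumptions as before (bitbake-hashclient on PATH, an assumed server address, and placeholder values for the method and unihash):

    import subprocess

    def hashclient(*args, address="localhost:8686"):   # address is an assumption
        subprocess.run(["bitbake-hashclient", "--address", address, *args], check=True)

    METHOD = "example.method"   # placeholder for self.METHOD in the tests
    KEEP_UNIHASH = "f37918cc02eb5a520b1aff86faacbc0a38124646"   # unihash reported above

    # Tag the rows to keep; gc-mark only records the mark and never deletes anything
    hashclient("gc-mark", "ABC",
               "--where", "unihash", KEEP_UNIHASH,
               "--where", "method", METHOD)
    # gc-sweep then removes every row that does not carry the "ABC" mark
    hashclient("gc-sweep", "ABC")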
285 | 1398 | ||
286 | 1399 | ||
287 | class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): | 1400 | class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): |
@@ -314,3 +1427,77 @@ class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm | |||
314 | # If IPv6 is enabled it should be safe to use localhost directly, but in the | 1427 | # If IPv6 is enabled it should be safe to use localhost directly, but in the
315 | # general case it is more reliable to resolve the IP address explicitly. | 1428 | # general case it is more reliable to resolve the IP address explicitly.
316 | return socket.gethostbyname("localhost") + ":0" | 1429 | return socket.gethostbyname("localhost") + ":0" |
1430 | |||
1431 | |||
1432 | class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): | ||
1433 | def setUp(self): | ||
1434 | try: | ||
1435 | import websockets | ||
1436 | except ImportError as e: | ||
1437 | self.skipTest(str(e)) | ||
1438 | |||
1439 | super().setUp() | ||
1440 | |||
1441 | def get_server_addr(self, server_idx): | ||
1442 | # Some hosts cause the asyncio module to misbehave when IPv6 is not enabled. | ||
1443 | # If IPv6 is enabled it should be safe to use localhost directly, but in the | ||
1444 | # general case it is more reliable to resolve the IP address explicitly. | ||
1445 | host = socket.gethostbyname("localhost") | ||
1446 | return "ws://%s:0" % host | ||
1447 | |||
1448 | |||
1449 | class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocketServer): | ||
1450 | def setUp(self): | ||
1451 | try: | ||
1452 | import sqlalchemy | ||
1453 | import aiosqlite | ||
1454 | except ImportError as e: | ||
1455 | self.skipTest(str(e)) | ||
1456 | |||
1457 | super().setUp() | ||
1458 | |||
1459 | def make_dbpath(self): | ||
1460 | return "sqlite+aiosqlite:///%s" % os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) | ||
1461 | |||
1462 | |||
1463 | class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): | ||
1464 | def get_env(self, name): | ||
1465 | v = os.environ.get(name) | ||
1466 | if not v: | ||
1467 | self.skipTest(f'{name} not defined to test an external server') | ||
1468 | return v | ||
1469 | |||
1470 | def start_test_server(self): | ||
1471 | return self.get_env('BB_TEST_HASHSERV') | ||
1472 | |||
1473 | def start_server(self, *args, **kwargs): | ||
1474 | self.skipTest('Cannot start local server when testing external servers') | ||
1475 | |||
1476 | def start_auth_server(self): | ||
1477 | |||
1478 | self.auth_server_address = self.server_address | ||
1479 | self.admin_client = self.start_client( | ||
1480 | self.server_address, | ||
1481 | username=self.get_env('BB_TEST_HASHSERV_USERNAME'), | ||
1482 | password=self.get_env('BB_TEST_HASHSERV_PASSWORD'), | ||
1483 | ) | ||
1484 | return self.admin_client | ||
1485 | |||
1486 | def setUp(self): | ||
1487 | super().setUp() | ||
1488 | if "BB_TEST_HASHSERV_USERNAME" in os.environ: | ||
1489 | self.client = self.start_client( | ||
1490 | self.server_address, | ||
1491 | username=os.environ["BB_TEST_HASHSERV_USERNAME"], | ||
1492 | password=os.environ["BB_TEST_HASHSERV_PASSWORD"], | ||
1493 | ) | ||
1494 | self.client.remove({"method": self.METHOD}) | ||
1495 | |||
1496 | def tearDown(self): | ||
1497 | self.client.remove({"method": self.METHOD}) | ||
1498 | super().tearDown() | ||
1499 | |||
1500 | |||
1501 | def test_auth_get_all_users(self): | ||
1502 | self.skipTest("Cannot test all users with external server") | ||
1503 | |||
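The external-server variant above is configured entirely through environment variables: BB_TEST_HASHSERV selects the server address, and the optional BB_TEST_HASHSERV_USERNAME / BB_TEST_HASHSERV_PASSWORD pair enables the authenticated client. A hedged example of running just this test case against an existing server (the address, credentials, and PYTHONPATH arrangement are placeholders, not taken from this patch):

    import os
    import subprocess

    # Placeholder address and admin credentials for an already-running server
    os.environ["BB_TEST_HASHSERV"] = "ws://hashserv.example.com:8686"
    os.environ["BB_TEST_HASHSERV_USERNAME"] = "admin"
    os.environ["BB_TEST_HASHSERV_PASSWORD"] = "admin-token"

    # Run only the external-server case; assumes bitbake/lib is on PYTHONPATH
    subprocess.run(
        ["python3", "-m", "unittest", "-v",
         "hashserv.tests.TestHashEquivalenceExternalServer"],
        check=True,
    )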