diff options
Diffstat (limited to 'bitbake/lib/hashserv/server.py')
-rw-r--r-- | bitbake/lib/hashserv/server.py | 1088 |
1 files changed, 671 insertions, 417 deletions
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py index a0dc0c170f..68f64f983b 100644 --- a/bitbake/lib/hashserv/server.py +++ b/bitbake/lib/hashserv/server.py | |||
@@ -3,20 +3,51 @@ | |||
3 | # SPDX-License-Identifier: GPL-2.0-only | 3 | # SPDX-License-Identifier: GPL-2.0-only |
4 | # | 4 | # |
5 | 5 | ||
6 | from contextlib import closing, contextmanager | 6 | from datetime import datetime, timedelta |
7 | from datetime import datetime | ||
8 | import asyncio | 7 | import asyncio |
9 | import json | ||
10 | import logging | 8 | import logging |
11 | import math | 9 | import math |
12 | import os | ||
13 | import signal | ||
14 | import socket | ||
15 | import sys | ||
16 | import time | 10 | import time |
17 | from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS | 11 | import os |
12 | import base64 | ||
13 | import hashlib | ||
14 | from . import create_async_client | ||
15 | import bb.asyncrpc | ||
16 | |||
17 | logger = logging.getLogger("hashserv.server") | ||
18 | |||
19 | |||
20 | # This permission only exists to match nothing | ||
21 | NONE_PERM = "@none" | ||
22 | |||
23 | READ_PERM = "@read" | ||
24 | REPORT_PERM = "@report" | ||
25 | DB_ADMIN_PERM = "@db-admin" | ||
26 | USER_ADMIN_PERM = "@user-admin" | ||
27 | ALL_PERM = "@all" | ||
18 | 28 | ||
19 | logger = logging.getLogger('hashserv.server') | 29 | ALL_PERMISSIONS = { |
30 | READ_PERM, | ||
31 | REPORT_PERM, | ||
32 | DB_ADMIN_PERM, | ||
33 | USER_ADMIN_PERM, | ||
34 | ALL_PERM, | ||
35 | } | ||
36 | |||
37 | DEFAULT_ANON_PERMS = ( | ||
38 | READ_PERM, | ||
39 | REPORT_PERM, | ||
40 | DB_ADMIN_PERM, | ||
41 | ) | ||
42 | |||
43 | TOKEN_ALGORITHM = "sha256" | ||
44 | |||
45 | # 48 bytes of random data will result in 64 characters when base64 | ||
46 | # encoded. This number also ensures that the base64 encoding won't have any | ||
47 | # trailing '=' characters. | ||
48 | TOKEN_SIZE = 48 | ||
49 | |||
50 | SALT_SIZE = 8 | ||
20 | 51 | ||
21 | 52 | ||
22 | class Measurement(object): | 53 | class Measurement(object): |
@@ -106,522 +137,745 @@ class Stats(object): | |||
106 | return math.sqrt(self.s / (self.num - 1)) | 137 | return math.sqrt(self.s / (self.num - 1)) |
107 | 138 | ||
108 | def todict(self): | 139 | def todict(self): |
109 | return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} | 140 | return { |
110 | 141 | k: getattr(self, k) | |
111 | 142 | for k in ("num", "total_time", "max_time", "average", "stdev") | |
112 | class ClientError(Exception): | ||
113 | pass | ||
114 | |||
115 | class ServerError(Exception): | ||
116 | pass | ||
117 | |||
118 | def insert_task(cursor, data, ignore=False): | ||
119 | keys = sorted(data.keys()) | ||
120 | query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % ( | ||
121 | " OR IGNORE" if ignore else "", | ||
122 | ', '.join(keys), | ||
123 | ', '.join(':' + k for k in keys)) | ||
124 | cursor.execute(query, data) | ||
125 | |||
126 | async def copy_from_upstream(client, db, method, taskhash): | ||
127 | d = await client.get_taskhash(method, taskhash, True) | ||
128 | if d is not None: | ||
129 | # Filter out unknown columns | ||
130 | d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} | ||
131 | keys = sorted(d.keys()) | ||
132 | |||
133 | with closing(db.cursor()) as cursor: | ||
134 | insert_task(cursor, d) | ||
135 | db.commit() | ||
136 | |||
137 | return d | ||
138 | |||
139 | async def copy_outhash_from_upstream(client, db, method, outhash, taskhash): | ||
140 | d = await client.get_outhash(method, outhash, taskhash) | ||
141 | if d is not None: | ||
142 | # Filter out unknown columns | ||
143 | d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} | ||
144 | keys = sorted(d.keys()) | ||
145 | |||
146 | with closing(db.cursor()) as cursor: | ||
147 | insert_task(cursor, d) | ||
148 | db.commit() | ||
149 | |||
150 | return d | ||
151 | |||
152 | class ServerClient(object): | ||
153 | FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | ||
154 | ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | ||
155 | OUTHASH_QUERY = ''' | ||
156 | -- Find tasks with a matching outhash (that is, tasks that | ||
157 | -- are equivalent) | ||
158 | SELECT * FROM tasks_v2 WHERE method=:method AND outhash=:outhash | ||
159 | |||
160 | -- If there is an exact match on the taskhash, return it. | ||
161 | -- Otherwise return the oldest matching outhash of any | ||
162 | -- taskhash | ||
163 | ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END, | ||
164 | created ASC | ||
165 | |||
166 | -- Only return one row | ||
167 | LIMIT 1 | ||
168 | ''' | ||
169 | |||
170 | def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only): | ||
171 | self.reader = reader | ||
172 | self.writer = writer | ||
173 | self.db = db | ||
174 | self.request_stats = request_stats | ||
175 | self.max_chunk = DEFAULT_MAX_CHUNK | ||
176 | self.backfill_queue = backfill_queue | ||
177 | self.upstream = upstream | ||
178 | |||
179 | self.handlers = { | ||
180 | 'get': self.handle_get, | ||
181 | 'get-outhash': self.handle_get_outhash, | ||
182 | 'get-stream': self.handle_get_stream, | ||
183 | 'get-stats': self.handle_get_stats, | ||
184 | 'chunk-stream': self.handle_chunk, | ||
185 | } | 143 | } |
186 | 144 | ||
187 | if not read_only: | ||
188 | self.handlers.update({ | ||
189 | 'report': self.handle_report, | ||
190 | 'report-equiv': self.handle_equivreport, | ||
191 | 'reset-stats': self.handle_reset_stats, | ||
192 | 'backfill-wait': self.handle_backfill_wait, | ||
193 | }) | ||
194 | 145 | ||
195 | async def process_requests(self): | 146 | token_refresh_semaphore = asyncio.Lock() |
196 | if self.upstream is not None: | ||
197 | self.upstream_client = await create_async_client(self.upstream) | ||
198 | else: | ||
199 | self.upstream_client = None | ||
200 | 147 | ||
201 | try: | ||
202 | 148 | ||
149 | async def new_token(): | ||
150 | # Prevent malicious users from using this API to deduce the entropy | ||
151 | # pool on the server and thus be able to guess a token. *All* token | ||
152 | # refresh requests lock the same global semaphore and then sleep for a | ||
153 | # short time. The effectively rate limits the total number of requests | ||
154 | # than can be made across all clients to 10/second, which should be enough | ||
155 | # since you have to be an authenticated users to make the request in the | ||
156 | # first place | ||
157 | async with token_refresh_semaphore: | ||
158 | await asyncio.sleep(0.1) | ||
159 | raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK) | ||
203 | 160 | ||
204 | self.addr = self.writer.get_extra_info('peername') | 161 | return base64.b64encode(raw, b"._").decode("utf-8") |
205 | logger.debug('Client %r connected' % (self.addr,)) | ||
206 | 162 | ||
207 | # Read protocol and version | ||
208 | protocol = await self.reader.readline() | ||
209 | if protocol is None: | ||
210 | return | ||
211 | 163 | ||
212 | (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() | 164 | def new_salt(): |
213 | if proto_name != 'OEHASHEQUIV': | 165 | return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex() |
214 | return | ||
215 | 166 | ||
216 | proto_version = tuple(int(v) for v in proto_version.split('.')) | ||
217 | if proto_version < (1, 0) or proto_version > (1, 1): | ||
218 | return | ||
219 | 167 | ||
220 | # Read headers. Currently, no headers are implemented, so look for | 168 | def hash_token(algo, salt, token): |
221 | # an empty line to signal the end of the headers | 169 | h = hashlib.new(algo) |
222 | while True: | 170 | h.update(salt.encode("utf-8")) |
223 | line = await self.reader.readline() | 171 | h.update(token.encode("utf-8")) |
224 | if line is None: | 172 | return ":".join([algo, salt, h.hexdigest()]) |
225 | return | ||
226 | 173 | ||
227 | line = line.decode('utf-8').rstrip() | ||
228 | if not line: | ||
229 | break | ||
230 | 174 | ||
231 | # Handle messages | 175 | def permissions(*permissions, allow_anon=True, allow_self_service=False): |
232 | while True: | 176 | """ |
233 | d = await self.read_message() | 177 | Function decorator that can be used to decorate an RPC function call and |
234 | if d is None: | 178 | check that the current users permissions match the require permissions. |
235 | break | ||
236 | await self.dispatch_message(d) | ||
237 | await self.writer.drain() | ||
238 | except ClientError as e: | ||
239 | logger.error(str(e)) | ||
240 | finally: | ||
241 | if self.upstream_client is not None: | ||
242 | await self.upstream_client.close() | ||
243 | 179 | ||
244 | self.writer.close() | 180 | If allow_anon is True, the user will also be allowed to make the RPC call |
181 | if the anonymous user permissions match the permissions. | ||
245 | 182 | ||
246 | async def dispatch_message(self, msg): | 183 | If allow_self_service is True, and the "username" property in the request |
247 | for k in self.handlers.keys(): | 184 | is the currently logged in user, or not specified, the user will also be |
248 | if k in msg: | 185 | allowed to make the request. This allows users to access normal privileged |
249 | logger.debug('Handling %s' % k) | 186 | API, as long as they are only modifying their own user properties (e.g. |
250 | if 'stream' in k: | 187 | users can be allowed to reset their own token without @user-admin |
251 | await self.handlers[k](msg[k]) | 188 | permissions, but not the token for any other user. |
189 | """ | ||
190 | |||
191 | def wrapper(func): | ||
192 | async def wrap(self, request): | ||
193 | if allow_self_service and self.user is not None: | ||
194 | username = request.get("username", self.user.username) | ||
195 | if username == self.user.username: | ||
196 | request["username"] = self.user.username | ||
197 | return await func(self, request) | ||
198 | |||
199 | if not self.user_has_permissions(*permissions, allow_anon=allow_anon): | ||
200 | if not self.user: | ||
201 | username = "Anonymous user" | ||
202 | user_perms = self.server.anon_perms | ||
252 | else: | 203 | else: |
253 | with self.request_stats.start_sample() as self.request_sample, \ | 204 | username = self.user.username |
254 | self.request_sample.measure(): | 205 | user_perms = self.user.permissions |
255 | await self.handlers[k](msg[k]) | 206 | |
256 | return | 207 | self.logger.info( |
208 | "User %s with permissions %r denied from calling %s. Missing permissions(s) %r", | ||
209 | username, | ||
210 | ", ".join(user_perms), | ||
211 | func.__name__, | ||
212 | ", ".join(permissions), | ||
213 | ) | ||
214 | raise bb.asyncrpc.InvokeError( | ||
215 | f"{username} is not allowed to access permissions(s) {', '.join(permissions)}" | ||
216 | ) | ||
217 | |||
218 | return await func(self, request) | ||
219 | |||
220 | return wrap | ||
221 | |||
222 | return wrapper | ||
223 | |||
224 | |||
225 | class ServerClient(bb.asyncrpc.AsyncServerConnection): | ||
226 | def __init__(self, socket, server): | ||
227 | super().__init__(socket, "OEHASHEQUIV", server.logger) | ||
228 | self.server = server | ||
229 | self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK | ||
230 | self.user = None | ||
231 | |||
232 | self.handlers.update( | ||
233 | { | ||
234 | "get": self.handle_get, | ||
235 | "get-outhash": self.handle_get_outhash, | ||
236 | "get-stream": self.handle_get_stream, | ||
237 | "exists-stream": self.handle_exists_stream, | ||
238 | "get-stats": self.handle_get_stats, | ||
239 | "get-db-usage": self.handle_get_db_usage, | ||
240 | "get-db-query-columns": self.handle_get_db_query_columns, | ||
241 | # Not always read-only, but internally checks if the server is | ||
242 | # read-only | ||
243 | "report": self.handle_report, | ||
244 | "auth": self.handle_auth, | ||
245 | "get-user": self.handle_get_user, | ||
246 | "get-all-users": self.handle_get_all_users, | ||
247 | "become-user": self.handle_become_user, | ||
248 | } | ||
249 | ) | ||
257 | 250 | ||
258 | raise ClientError("Unrecognized command %r" % msg) | 251 | if not self.server.read_only: |
252 | self.handlers.update( | ||
253 | { | ||
254 | "report-equiv": self.handle_equivreport, | ||
255 | "reset-stats": self.handle_reset_stats, | ||
256 | "backfill-wait": self.handle_backfill_wait, | ||
257 | "remove": self.handle_remove, | ||
258 | "gc-mark": self.handle_gc_mark, | ||
259 | "gc-sweep": self.handle_gc_sweep, | ||
260 | "gc-status": self.handle_gc_status, | ||
261 | "clean-unused": self.handle_clean_unused, | ||
262 | "refresh-token": self.handle_refresh_token, | ||
263 | "set-user-perms": self.handle_set_perms, | ||
264 | "new-user": self.handle_new_user, | ||
265 | "delete-user": self.handle_delete_user, | ||
266 | } | ||
267 | ) | ||
259 | 268 | ||
260 | def write_message(self, msg): | 269 | def raise_no_user_error(self, username): |
261 | for c in chunkify(json.dumps(msg), self.max_chunk): | 270 | raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists") |
262 | self.writer.write(c.encode('utf-8')) | ||
263 | 271 | ||
264 | async def read_message(self): | 272 | def user_has_permissions(self, *permissions, allow_anon=True): |
265 | l = await self.reader.readline() | 273 | permissions = set(permissions) |
266 | if not l: | 274 | if allow_anon: |
267 | return None | 275 | if ALL_PERM in self.server.anon_perms: |
276 | return True | ||
268 | 277 | ||
269 | try: | 278 | if not permissions - self.server.anon_perms: |
270 | message = l.decode('utf-8') | 279 | return True |
271 | 280 | ||
272 | if not message.endswith('\n'): | 281 | if self.user is None: |
273 | return None | 282 | return False |
274 | 283 | ||
275 | return json.loads(message) | 284 | if ALL_PERM in self.user.permissions: |
276 | except (json.JSONDecodeError, UnicodeDecodeError) as e: | 285 | return True |
277 | logger.error('Bad message from client: %r' % message) | ||
278 | raise e | ||
279 | 286 | ||
280 | async def handle_chunk(self, request): | 287 | if not permissions - self.user.permissions: |
281 | lines = [] | 288 | return True |
282 | try: | ||
283 | while True: | ||
284 | l = await self.reader.readline() | ||
285 | l = l.rstrip(b"\n").decode("utf-8") | ||
286 | if not l: | ||
287 | break | ||
288 | lines.append(l) | ||
289 | 289 | ||
290 | msg = json.loads(''.join(lines)) | 290 | return False |
291 | except (json.JSONDecodeError, UnicodeDecodeError) as e: | ||
292 | logger.error('Bad message from client: %r' % message) | ||
293 | raise e | ||
294 | 291 | ||
295 | if 'chunk-stream' in msg: | 292 | def validate_proto_version(self): |
296 | raise ClientError("Nested chunks are not allowed") | 293 | return self.proto_version > (1, 0) and self.proto_version <= (1, 1) |
297 | 294 | ||
298 | await self.dispatch_message(msg) | 295 | async def process_requests(self): |
296 | async with self.server.db_engine.connect(self.logger) as db: | ||
297 | self.db = db | ||
298 | if self.server.upstream is not None: | ||
299 | self.upstream_client = await create_async_client(self.server.upstream) | ||
300 | else: | ||
301 | self.upstream_client = None | ||
299 | 302 | ||
300 | async def handle_get(self, request): | 303 | try: |
301 | method = request['method'] | 304 | await super().process_requests() |
302 | taskhash = request['taskhash'] | 305 | finally: |
306 | if self.upstream_client is not None: | ||
307 | await self.upstream_client.close() | ||
303 | 308 | ||
304 | if request.get('all', False): | 309 | async def dispatch_message(self, msg): |
305 | row = self.query_equivalent(method, taskhash, self.ALL_QUERY) | 310 | for k in self.handlers.keys(): |
306 | else: | 311 | if k in msg: |
307 | row = self.query_equivalent(method, taskhash, self.FAST_QUERY) | 312 | self.logger.debug("Handling %s" % k) |
313 | if "stream" in k: | ||
314 | return await self.handlers[k](msg[k]) | ||
315 | else: | ||
316 | with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure(): | ||
317 | return await self.handlers[k](msg[k]) | ||
308 | 318 | ||
309 | if row is not None: | 319 | raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg) |
310 | logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | 320 | |
311 | d = {k: row[k] for k in row.keys()} | 321 | @permissions(READ_PERM) |
312 | elif self.upstream_client is not None: | 322 | async def handle_get(self, request): |
313 | d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash) | 323 | method = request["method"] |
324 | taskhash = request["taskhash"] | ||
325 | fetch_all = request.get("all", False) | ||
326 | |||
327 | return await self.get_unihash(method, taskhash, fetch_all) | ||
328 | |||
329 | async def get_unihash(self, method, taskhash, fetch_all=False): | ||
330 | d = None | ||
331 | |||
332 | if fetch_all: | ||
333 | row = await self.db.get_unihash_by_taskhash_full(method, taskhash) | ||
334 | if row is not None: | ||
335 | d = {k: row[k] for k in row.keys()} | ||
336 | elif self.upstream_client is not None: | ||
337 | d = await self.upstream_client.get_taskhash(method, taskhash, True) | ||
338 | await self.update_unified(d) | ||
314 | else: | 339 | else: |
315 | d = None | 340 | row = await self.db.get_equivalent(method, taskhash) |
341 | |||
342 | if row is not None: | ||
343 | d = {k: row[k] for k in row.keys()} | ||
344 | elif self.upstream_client is not None: | ||
345 | d = await self.upstream_client.get_taskhash(method, taskhash) | ||
346 | await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"]) | ||
316 | 347 | ||
317 | self.write_message(d) | 348 | return d |
318 | 349 | ||
350 | @permissions(READ_PERM) | ||
319 | async def handle_get_outhash(self, request): | 351 | async def handle_get_outhash(self, request): |
320 | with closing(self.db.cursor()) as cursor: | 352 | method = request["method"] |
321 | cursor.execute(self.OUTHASH_QUERY, | 353 | outhash = request["outhash"] |
322 | {k: request[k] for k in ('method', 'outhash', 'taskhash')}) | 354 | taskhash = request["taskhash"] |
355 | with_unihash = request.get("with_unihash", True) | ||
323 | 356 | ||
324 | row = cursor.fetchone() | 357 | return await self.get_outhash(method, outhash, taskhash, with_unihash) |
358 | |||
359 | async def get_outhash(self, method, outhash, taskhash, with_unihash=True): | ||
360 | d = None | ||
361 | if with_unihash: | ||
362 | row = await self.db.get_unihash_by_outhash(method, outhash) | ||
363 | else: | ||
364 | row = await self.db.get_outhash(method, outhash) | ||
325 | 365 | ||
326 | if row is not None: | 366 | if row is not None: |
327 | logger.debug('Found equivalent outhash %s -> %s', (row['outhash'], row['unihash'])) | ||
328 | d = {k: row[k] for k in row.keys()} | 367 | d = {k: row[k] for k in row.keys()} |
329 | else: | 368 | elif self.upstream_client is not None: |
330 | d = None | 369 | d = await self.upstream_client.get_outhash(method, outhash, taskhash) |
370 | await self.update_unified(d) | ||
331 | 371 | ||
332 | self.write_message(d) | 372 | return d |
333 | 373 | ||
334 | async def handle_get_stream(self, request): | 374 | async def update_unified(self, data): |
335 | self.write_message('ok') | 375 | if data is None: |
376 | return | ||
377 | |||
378 | await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"]) | ||
379 | await self.db.insert_outhash(data) | ||
380 | |||
381 | async def _stream_handler(self, handler): | ||
382 | await self.socket.send_message("ok") | ||
336 | 383 | ||
337 | while True: | 384 | while True: |
338 | upstream = None | 385 | upstream = None |
339 | 386 | ||
340 | l = await self.reader.readline() | 387 | l = await self.socket.recv() |
341 | if not l: | 388 | if not l: |
342 | return | 389 | break |
343 | 390 | ||
344 | try: | 391 | try: |
345 | # This inner loop is very sensitive and must be as fast as | 392 | # This inner loop is very sensitive and must be as fast as |
346 | # possible (which is why the request sample is handled manually | 393 | # possible (which is why the request sample is handled manually |
347 | # instead of using 'with', and also why logging statements are | 394 | # instead of using 'with', and also why logging statements are |
348 | # commented out. | 395 | # commented out. |
349 | self.request_sample = self.request_stats.start_sample() | 396 | self.request_sample = self.server.request_stats.start_sample() |
350 | request_measure = self.request_sample.measure() | 397 | request_measure = self.request_sample.measure() |
351 | request_measure.start() | 398 | request_measure.start() |
352 | 399 | ||
353 | l = l.decode('utf-8').rstrip() | 400 | if l == "END": |
354 | if l == 'END': | 401 | break |
355 | self.writer.write('ok\n'.encode('utf-8')) | ||
356 | return | ||
357 | |||
358 | (method, taskhash) = l.split() | ||
359 | #logger.debug('Looking up %s %s' % (method, taskhash)) | ||
360 | row = self.query_equivalent(method, taskhash, self.FAST_QUERY) | ||
361 | if row is not None: | ||
362 | msg = ('%s\n' % row['unihash']).encode('utf-8') | ||
363 | #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | ||
364 | elif self.upstream_client is not None: | ||
365 | upstream = await self.upstream_client.get_unihash(method, taskhash) | ||
366 | if upstream: | ||
367 | msg = ("%s\n" % upstream).encode("utf-8") | ||
368 | else: | ||
369 | msg = "\n".encode("utf-8") | ||
370 | else: | ||
371 | msg = '\n'.encode('utf-8') | ||
372 | 402 | ||
373 | self.writer.write(msg) | 403 | msg = await handler(l) |
404 | await self.socket.send(msg) | ||
374 | finally: | 405 | finally: |
375 | request_measure.end() | 406 | request_measure.end() |
376 | self.request_sample.end() | 407 | self.request_sample.end() |
377 | 408 | ||
378 | await self.writer.drain() | 409 | await self.socket.send("ok") |
410 | return self.NO_RESPONSE | ||
379 | 411 | ||
380 | # Post to the backfill queue after writing the result to minimize | 412 | @permissions(READ_PERM) |
381 | # the turn around time on a request | 413 | async def handle_get_stream(self, request): |
382 | if upstream is not None: | 414 | async def handler(l): |
383 | await self.backfill_queue.put((method, taskhash)) | 415 | (method, taskhash) = l.split() |
416 | # self.logger.debug('Looking up %s %s' % (method, taskhash)) | ||
417 | row = await self.db.get_equivalent(method, taskhash) | ||
384 | 418 | ||
385 | async def handle_report(self, data): | 419 | if row is not None: |
386 | with closing(self.db.cursor()) as cursor: | 420 | # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) |
387 | cursor.execute(self.OUTHASH_QUERY, | 421 | return row["unihash"] |
388 | {k: data[k] for k in ('method', 'outhash', 'taskhash')}) | ||
389 | |||
390 | row = cursor.fetchone() | ||
391 | |||
392 | if row is None and self.upstream_client: | ||
393 | # Try upstream | ||
394 | row = await copy_outhash_from_upstream(self.upstream_client, | ||
395 | self.db, | ||
396 | data['method'], | ||
397 | data['outhash'], | ||
398 | data['taskhash']) | ||
399 | |||
400 | # If no matching outhash was found, or one *was* found but it | ||
401 | # wasn't an exact match on the taskhash, a new entry for this | ||
402 | # taskhash should be added | ||
403 | if row is None or row['taskhash'] != data['taskhash']: | ||
404 | # If a row matching the outhash was found, the unihash for | ||
405 | # the new taskhash should be the same as that one. | ||
406 | # Otherwise the caller provided unihash is used. | ||
407 | unihash = data['unihash'] | ||
408 | if row is not None: | ||
409 | unihash = row['unihash'] | ||
410 | |||
411 | insert_data = { | ||
412 | 'method': data['method'], | ||
413 | 'outhash': data['outhash'], | ||
414 | 'taskhash': data['taskhash'], | ||
415 | 'unihash': unihash, | ||
416 | 'created': datetime.now() | ||
417 | } | ||
418 | 422 | ||
419 | for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): | 423 | if self.upstream_client is not None: |
420 | if k in data: | 424 | upstream = await self.upstream_client.get_unihash(method, taskhash) |
421 | insert_data[k] = data[k] | 425 | if upstream: |
426 | await self.server.backfill_queue.put((method, taskhash)) | ||
427 | return upstream | ||
422 | 428 | ||
423 | insert_task(cursor, insert_data) | 429 | return "" |
424 | self.db.commit() | ||
425 | 430 | ||
426 | logger.info('Adding taskhash %s with unihash %s', | 431 | return await self._stream_handler(handler) |
427 | data['taskhash'], unihash) | ||
428 | 432 | ||
429 | d = { | 433 | @permissions(READ_PERM) |
430 | 'taskhash': data['taskhash'], | 434 | async def handle_exists_stream(self, request): |
431 | 'method': data['method'], | 435 | async def handler(l): |
432 | 'unihash': unihash | 436 | if await self.db.unihash_exists(l): |
433 | } | 437 | return "true" |
434 | else: | ||
435 | d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} | ||
436 | 438 | ||
437 | self.write_message(d) | 439 | if self.upstream_client is not None: |
440 | if await self.upstream_client.unihash_exists(l): | ||
441 | return "true" | ||
438 | 442 | ||
439 | async def handle_equivreport(self, data): | 443 | return "false" |
440 | with closing(self.db.cursor()) as cursor: | ||
441 | insert_data = { | ||
442 | 'method': data['method'], | ||
443 | 'outhash': "", | ||
444 | 'taskhash': data['taskhash'], | ||
445 | 'unihash': data['unihash'], | ||
446 | 'created': datetime.now() | ||
447 | } | ||
448 | 444 | ||
449 | for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): | 445 | return await self._stream_handler(handler) |
450 | if k in data: | ||
451 | insert_data[k] = data[k] | ||
452 | 446 | ||
453 | insert_task(cursor, insert_data, ignore=True) | 447 | async def report_readonly(self, data): |
454 | self.db.commit() | 448 | method = data["method"] |
449 | outhash = data["outhash"] | ||
450 | taskhash = data["taskhash"] | ||
455 | 451 | ||
456 | # Fetch the unihash that will be reported for the taskhash. If the | 452 | info = await self.get_outhash(method, outhash, taskhash) |
457 | # unihash matches, it means this row was inserted (or the mapping | 453 | if info: |
458 | # was already valid) | 454 | unihash = info["unihash"] |
459 | row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY) | 455 | else: |
456 | unihash = data["unihash"] | ||
460 | 457 | ||
461 | if row['unihash'] == data['unihash']: | 458 | return { |
462 | logger.info('Adding taskhash equivalence for %s with unihash %s', | 459 | "taskhash": taskhash, |
463 | data['taskhash'], row['unihash']) | 460 | "method": method, |
461 | "unihash": unihash, | ||
462 | } | ||
464 | 463 | ||
465 | d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} | 464 | # Since this can be called either read only or to report, the check to |
465 | # report is made inside the function | ||
466 | @permissions(READ_PERM) | ||
467 | async def handle_report(self, data): | ||
468 | if self.server.read_only or not self.user_has_permissions(REPORT_PERM): | ||
469 | return await self.report_readonly(data) | ||
470 | |||
471 | outhash_data = { | ||
472 | "method": data["method"], | ||
473 | "outhash": data["outhash"], | ||
474 | "taskhash": data["taskhash"], | ||
475 | "created": datetime.now(), | ||
476 | } | ||
466 | 477 | ||
467 | self.write_message(d) | 478 | for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"): |
479 | if k in data: | ||
480 | outhash_data[k] = data[k] | ||
468 | 481 | ||
482 | if self.user: | ||
483 | outhash_data["owner"] = self.user.username | ||
469 | 484 | ||
470 | async def handle_get_stats(self, request): | 485 | # Insert the new entry, unless it already exists |
471 | d = { | 486 | if await self.db.insert_outhash(outhash_data): |
472 | 'requests': self.request_stats.todict(), | 487 | # If this row is new, check if it is equivalent to another |
488 | # output hash | ||
489 | row = await self.db.get_equivalent_for_outhash( | ||
490 | data["method"], data["outhash"], data["taskhash"] | ||
491 | ) | ||
492 | |||
493 | if row is not None: | ||
494 | # A matching output hash was found. Set our taskhash to the | ||
495 | # same unihash since they are equivalent | ||
496 | unihash = row["unihash"] | ||
497 | else: | ||
498 | # No matching output hash was found. This is probably the | ||
499 | # first outhash to be added. | ||
500 | unihash = data["unihash"] | ||
501 | |||
502 | # Query upstream to see if it has a unihash we can use | ||
503 | if self.upstream_client is not None: | ||
504 | upstream_data = await self.upstream_client.get_outhash( | ||
505 | data["method"], data["outhash"], data["taskhash"] | ||
506 | ) | ||
507 | if upstream_data is not None: | ||
508 | unihash = upstream_data["unihash"] | ||
509 | |||
510 | await self.db.insert_unihash(data["method"], data["taskhash"], unihash) | ||
511 | |||
512 | unihash_data = await self.get_unihash(data["method"], data["taskhash"]) | ||
513 | if unihash_data is not None: | ||
514 | unihash = unihash_data["unihash"] | ||
515 | else: | ||
516 | unihash = data["unihash"] | ||
517 | |||
518 | return { | ||
519 | "taskhash": data["taskhash"], | ||
520 | "method": data["method"], | ||
521 | "unihash": unihash, | ||
473 | } | 522 | } |
474 | 523 | ||
475 | self.write_message(d) | 524 | @permissions(READ_PERM, REPORT_PERM) |
525 | async def handle_equivreport(self, data): | ||
526 | await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"]) | ||
527 | |||
528 | # Fetch the unihash that will be reported for the taskhash. If the | ||
529 | # unihash matches, it means this row was inserted (or the mapping | ||
530 | # was already valid) | ||
531 | row = await self.db.get_equivalent(data["method"], data["taskhash"]) | ||
532 | |||
533 | if row["unihash"] == data["unihash"]: | ||
534 | self.logger.info( | ||
535 | "Adding taskhash equivalence for %s with unihash %s", | ||
536 | data["taskhash"], | ||
537 | row["unihash"], | ||
538 | ) | ||
539 | |||
540 | return {k: row[k] for k in ("taskhash", "method", "unihash")} | ||
476 | 541 | ||
542 | @permissions(READ_PERM) | ||
543 | async def handle_get_stats(self, request): | ||
544 | return { | ||
545 | "requests": self.server.request_stats.todict(), | ||
546 | } | ||
547 | |||
548 | @permissions(DB_ADMIN_PERM) | ||
477 | async def handle_reset_stats(self, request): | 549 | async def handle_reset_stats(self, request): |
478 | d = { | 550 | d = { |
479 | 'requests': self.request_stats.todict(), | 551 | "requests": self.server.request_stats.todict(), |
480 | } | 552 | } |
481 | 553 | ||
482 | self.request_stats.reset() | 554 | self.server.request_stats.reset() |
483 | self.write_message(d) | 555 | return d |
484 | 556 | ||
557 | @permissions(READ_PERM) | ||
485 | async def handle_backfill_wait(self, request): | 558 | async def handle_backfill_wait(self, request): |
486 | d = { | 559 | d = { |
487 | 'tasks': self.backfill_queue.qsize(), | 560 | "tasks": self.server.backfill_queue.qsize(), |
488 | } | 561 | } |
489 | await self.backfill_queue.join() | 562 | await self.server.backfill_queue.join() |
490 | self.write_message(d) | 563 | return d |
564 | |||
565 | @permissions(DB_ADMIN_PERM) | ||
566 | async def handle_remove(self, request): | ||
567 | condition = request["where"] | ||
568 | if not isinstance(condition, dict): | ||
569 | raise TypeError("Bad condition type %s" % type(condition)) | ||
570 | |||
571 | return {"count": await self.db.remove(condition)} | ||
572 | |||
573 | @permissions(DB_ADMIN_PERM) | ||
574 | async def handle_gc_mark(self, request): | ||
575 | condition = request["where"] | ||
576 | mark = request["mark"] | ||
577 | |||
578 | if not isinstance(condition, dict): | ||
579 | raise TypeError("Bad condition type %s" % type(condition)) | ||
580 | |||
581 | if not isinstance(mark, str): | ||
582 | raise TypeError("Bad mark type %s" % type(mark)) | ||
583 | |||
584 | return {"count": await self.db.gc_mark(mark, condition)} | ||
585 | |||
586 | @permissions(DB_ADMIN_PERM) | ||
587 | async def handle_gc_sweep(self, request): | ||
588 | mark = request["mark"] | ||
589 | |||
590 | if not isinstance(mark, str): | ||
591 | raise TypeError("Bad mark type %s" % type(mark)) | ||
592 | |||
593 | current_mark = await self.db.get_current_gc_mark() | ||
594 | |||
595 | if not current_mark or mark != current_mark: | ||
596 | raise bb.asyncrpc.InvokeError( | ||
597 | f"'{mark}' is not the current mark. Refusing to sweep" | ||
598 | ) | ||
599 | |||
600 | count = await self.db.gc_sweep() | ||
601 | |||
602 | return {"count": count} | ||
603 | |||
604 | @permissions(DB_ADMIN_PERM) | ||
605 | async def handle_gc_status(self, request): | ||
606 | (keep_rows, remove_rows, current_mark) = await self.db.gc_status() | ||
607 | return { | ||
608 | "keep": keep_rows, | ||
609 | "remove": remove_rows, | ||
610 | "mark": current_mark, | ||
611 | } | ||
612 | |||
613 | @permissions(DB_ADMIN_PERM) | ||
614 | async def handle_clean_unused(self, request): | ||
615 | max_age = request["max_age_seconds"] | ||
616 | oldest = datetime.now() - timedelta(seconds=-max_age) | ||
617 | return {"count": await self.db.clean_unused(oldest)} | ||
618 | |||
619 | @permissions(DB_ADMIN_PERM) | ||
620 | async def handle_get_db_usage(self, request): | ||
621 | return {"usage": await self.db.get_usage()} | ||
622 | |||
623 | @permissions(DB_ADMIN_PERM) | ||
624 | async def handle_get_db_query_columns(self, request): | ||
625 | return {"columns": await self.db.get_query_columns()} | ||
626 | |||
627 | # The authentication API is always allowed | ||
628 | async def handle_auth(self, request): | ||
629 | username = str(request["username"]) | ||
630 | token = str(request["token"]) | ||
631 | |||
632 | async def fail_auth(): | ||
633 | nonlocal username | ||
634 | # Rate limit bad login attempts | ||
635 | await asyncio.sleep(1) | ||
636 | raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}") | ||
637 | |||
638 | user, db_token = await self.db.lookup_user_token(username) | ||
639 | |||
640 | if not user or not db_token: | ||
641 | await fail_auth() | ||
491 | 642 | ||
492 | def query_equivalent(self, method, taskhash, query): | ||
493 | # This is part of the inner loop and must be as fast as possible | ||
494 | try: | 643 | try: |
495 | cursor = self.db.cursor() | 644 | algo, salt, _ = db_token.split(":") |
496 | cursor.execute(query, {'method': method, 'taskhash': taskhash}) | 645 | except ValueError: |
497 | return cursor.fetchone() | 646 | await fail_auth() |
498 | except: | ||
499 | cursor.close() | ||
500 | 647 | ||
648 | if hash_token(algo, salt, token) != db_token: | ||
649 | await fail_auth() | ||
501 | 650 | ||
502 | class Server(object): | 651 | self.user = user |
503 | def __init__(self, db, loop=None, upstream=None, read_only=False): | ||
504 | if upstream and read_only: | ||
505 | raise ServerError("Read-only hashserv cannot pull from an upstream server") | ||
506 | 652 | ||
507 | self.request_stats = Stats() | 653 | self.logger.info("Authenticated as %s", username) |
508 | self.db = db | ||
509 | 654 | ||
510 | if loop is None: | 655 | return { |
511 | self.loop = asyncio.new_event_loop() | 656 | "result": True, |
512 | self.close_loop = True | 657 | "username": self.user.username, |
513 | else: | 658 | "permissions": sorted(list(self.user.permissions)), |
514 | self.loop = loop | 659 | } |
515 | self.close_loop = False | ||
516 | 660 | ||
517 | self.upstream = upstream | 661 | @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) |
518 | self.read_only = read_only | 662 | async def handle_refresh_token(self, request): |
663 | username = str(request["username"]) | ||
519 | 664 | ||
520 | self._cleanup_socket = None | 665 | token = await new_token() |
521 | 666 | ||
522 | def start_tcp_server(self, host, port): | 667 | updated = await self.db.set_user_token( |
523 | self.server = self.loop.run_until_complete( | 668 | username, |
524 | asyncio.start_server(self.handle_client, host, port, loop=self.loop) | 669 | hash_token(TOKEN_ALGORITHM, new_salt(), token), |
525 | ) | 670 | ) |
671 | if not updated: | ||
672 | self.raise_no_user_error(username) | ||
526 | 673 | ||
527 | for s in self.server.sockets: | 674 | return {"username": username, "token": token} |
528 | logger.info('Listening on %r' % (s.getsockname(),)) | ||
529 | # Newer python does this automatically. Do it manually here for | ||
530 | # maximum compatibility | ||
531 | s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) | ||
532 | s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) | ||
533 | |||
534 | name = self.server.sockets[0].getsockname() | ||
535 | if self.server.sockets[0].family == socket.AF_INET6: | ||
536 | self.address = "[%s]:%d" % (name[0], name[1]) | ||
537 | else: | ||
538 | self.address = "%s:%d" % (name[0], name[1]) | ||
539 | 675 | ||
540 | def start_unix_server(self, path): | 676 | def get_perm_arg(self, arg): |
541 | def cleanup(): | 677 | if not isinstance(arg, list): |
542 | os.unlink(path) | 678 | raise bb.asyncrpc.InvokeError("Unexpected type for permissions") |
543 | 679 | ||
544 | cwd = os.getcwd() | 680 | arg = set(arg) |
545 | try: | 681 | try: |
546 | # Work around path length limits in AF_UNIX | 682 | arg.remove(NONE_PERM) |
547 | os.chdir(os.path.dirname(path)) | 683 | except KeyError: |
548 | self.server = self.loop.run_until_complete( | 684 | pass |
549 | asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop) | 685 | |
686 | unknown_perms = arg - ALL_PERMISSIONS | ||
687 | if unknown_perms: | ||
688 | raise bb.asyncrpc.InvokeError( | ||
689 | "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms))) | ||
550 | ) | 690 | ) |
551 | finally: | ||
552 | os.chdir(cwd) | ||
553 | 691 | ||
554 | logger.info('Listening on %r' % path) | 692 | return sorted(list(arg)) |
555 | 693 | ||
556 | self._cleanup_socket = cleanup | 694 | def return_perms(self, permissions): |
557 | self.address = "unix://%s" % os.path.abspath(path) | 695 | if ALL_PERM in permissions: |
696 | return sorted(list(ALL_PERMISSIONS)) | ||
697 | return sorted(list(permissions)) | ||
558 | 698 | ||
559 | async def handle_client(self, reader, writer): | 699 | @permissions(USER_ADMIN_PERM, allow_anon=False) |
560 | # writer.transport.set_write_buffer_limits(0) | 700 | async def handle_set_perms(self, request): |
561 | try: | 701 | username = str(request["username"]) |
562 | client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only) | 702 | permissions = self.get_perm_arg(request["permissions"]) |
563 | await client.process_requests() | ||
564 | except Exception as e: | ||
565 | import traceback | ||
566 | logger.error('Error from client: %s' % str(e), exc_info=True) | ||
567 | traceback.print_exc() | ||
568 | writer.close() | ||
569 | logger.info('Client disconnected') | ||
570 | |||
571 | @contextmanager | ||
572 | def _backfill_worker(self): | ||
573 | async def backfill_worker_task(): | ||
574 | client = await create_async_client(self.upstream) | ||
575 | try: | ||
576 | while True: | ||
577 | item = await self.backfill_queue.get() | ||
578 | if item is None: | ||
579 | self.backfill_queue.task_done() | ||
580 | break | ||
581 | method, taskhash = item | ||
582 | await copy_from_upstream(client, self.db, method, taskhash) | ||
583 | self.backfill_queue.task_done() | ||
584 | finally: | ||
585 | await client.close() | ||
586 | 703 | ||
587 | async def join_worker(worker): | 704 | if not await self.db.set_user_perms(username, permissions): |
588 | await self.backfill_queue.put(None) | 705 | self.raise_no_user_error(username) |
589 | await worker | ||
590 | 706 | ||
591 | if self.upstream is not None: | 707 | return { |
592 | worker = asyncio.ensure_future(backfill_worker_task()) | 708 | "username": username, |
593 | try: | 709 | "permissions": self.return_perms(permissions), |
594 | yield | 710 | } |
595 | finally: | ||
596 | self.loop.run_until_complete(join_worker(worker)) | ||
597 | else: | ||
598 | yield | ||
599 | 711 | ||
600 | def serve_forever(self): | 712 | @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) |
601 | def signal_handler(): | 713 | async def handle_get_user(self, request): |
602 | self.loop.stop() | 714 | username = str(request["username"]) |
603 | 715 | ||
604 | asyncio.set_event_loop(self.loop) | 716 | user = await self.db.lookup_user(username) |
605 | try: | 717 | if user is None: |
606 | self.backfill_queue = asyncio.Queue() | 718 | return None |
719 | |||
720 | return { | ||
721 | "username": user.username, | ||
722 | "permissions": self.return_perms(user.permissions), | ||
723 | } | ||
724 | |||
725 | @permissions(USER_ADMIN_PERM, allow_anon=False) | ||
726 | async def handle_get_all_users(self, request): | ||
727 | users = await self.db.get_all_users() | ||
728 | return { | ||
729 | "users": [ | ||
730 | { | ||
731 | "username": u.username, | ||
732 | "permissions": self.return_perms(u.permissions), | ||
733 | } | ||
734 | for u in users | ||
735 | ] | ||
736 | } | ||
737 | |||
738 | @permissions(USER_ADMIN_PERM, allow_anon=False) | ||
739 | async def handle_new_user(self, request): | ||
740 | username = str(request["username"]) | ||
741 | permissions = self.get_perm_arg(request["permissions"]) | ||
742 | |||
743 | token = await new_token() | ||
744 | |||
745 | inserted = await self.db.new_user( | ||
746 | username, | ||
747 | permissions, | ||
748 | hash_token(TOKEN_ALGORITHM, new_salt(), token), | ||
749 | ) | ||
750 | if not inserted: | ||
751 | raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'") | ||
752 | |||
753 | return { | ||
754 | "username": username, | ||
755 | "permissions": self.return_perms(permissions), | ||
756 | "token": token, | ||
757 | } | ||
758 | |||
759 | @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) | ||
760 | async def handle_delete_user(self, request): | ||
761 | username = str(request["username"]) | ||
762 | |||
763 | if not await self.db.delete_user(username): | ||
764 | self.raise_no_user_error(username) | ||
765 | |||
766 | return {"username": username} | ||
607 | 767 | ||
608 | self.loop.add_signal_handler(signal.SIGTERM, signal_handler) | 768 | @permissions(USER_ADMIN_PERM, allow_anon=False) |
769 | async def handle_become_user(self, request): | ||
770 | username = str(request["username"]) | ||
609 | 771 | ||
610 | with self._backfill_worker(): | 772 | user = await self.db.lookup_user(username) |
611 | try: | 773 | if user is None: |
612 | self.loop.run_forever() | 774 | raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist") |
613 | except KeyboardInterrupt: | ||
614 | pass | ||
615 | 775 | ||
616 | self.server.close() | 776 | self.user = user |
777 | |||
778 | self.logger.info("Became user %s", username) | ||
779 | |||
780 | return { | ||
781 | "username": self.user.username, | ||
782 | "permissions": self.return_perms(self.user.permissions), | ||
783 | } | ||
784 | |||
785 | |||
786 | class Server(bb.asyncrpc.AsyncServer): | ||
    def __init__(
        self,
        db_engine,
        upstream=None,
        read_only=False,
        anon_perms=DEFAULT_ANON_PERMS,
        admin_username=None,
        admin_password=None,
    ):
        """Hash equivalence server.

        db_engine -- database backend; connections are created per client
        upstream -- optional address of an upstream server to backfill from
        read_only -- reject write operations when True
        anon_perms -- permissions granted to unauthenticated clients
        admin_username/admin_password -- if set, an admin account is
            created (or reset) on start()
        """
        # A read-only server cannot write entries pulled from an upstream,
        # so the two options are mutually exclusive
        if upstream and read_only:
            raise bb.asyncrpc.ServerError(
                "Read-only hashserv cannot pull from an upstream server"
            )

        # Anonymous clients may only be granted this safe subset; user
        # administration in particular always requires authentication
        disallowed_perms = set(anon_perms) - set(
            [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
        )

        if disallowed_perms:
            raise bb.asyncrpc.ServerError(
                f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
            )

        super().__init__(logger)

        self.request_stats = Stats()
        self.db_engine = db_engine
        self.upstream = upstream
        self.read_only = read_only
        # Created lazily in start() and only when an upstream is configured
        self.backfill_queue = None
        self.anon_perms = set(anon_perms)
        self.admin_username = admin_username
        self.admin_password = admin_password

        self.logger.info(
            "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
        )
825 | def accept_client(self, socket): | ||
826 | return ServerClient(socket, self) | ||
827 | |||
828 | async def create_admin_user(self): | ||
829 | admin_permissions = (ALL_PERM,) | ||
830 | async with self.db_engine.connect(self.logger) as db: | ||
831 | added = await db.new_user( | ||
832 | self.admin_username, | ||
833 | admin_permissions, | ||
834 | hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password), | ||
835 | ) | ||
836 | if added: | ||
837 | self.logger.info("Created admin user '%s'", self.admin_username) | ||
838 | else: | ||
839 | await db.set_user_perms( | ||
840 | self.admin_username, | ||
841 | admin_permissions, | ||
842 | ) | ||
843 | await db.set_user_token( | ||
844 | self.admin_username, | ||
845 | hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password), | ||
846 | ) | ||
847 | self.logger.info("Admin user '%s' updated", self.admin_username) | ||
848 | |||
    async def backfill_worker_task(self):
        # Long-running worker: copies unihash entries from the upstream
        # server into the local database as (method, taskhash) pairs are
        # queued by request handlers.
        async with await create_async_client(
            self.upstream
        ) as client, self.db_engine.connect(self.logger) as db:
            while True:
                item = await self.backfill_queue.get()
                if item is None:
                    # Sentinel queued by stop(): acknowledge it and exit
                    self.backfill_queue.task_done()
                    break

                method, taskhash = item
                d = await client.get_taskhash(method, taskhash)
                if d is not None:
                    await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
                # Always mark the item done so backfill_wait can unblock,
                # even when the upstream had no entry for this taskhash
                self.backfill_queue.task_done()
865 | def start(self): | ||
866 | tasks = super().start() | ||
867 | if self.upstream: | ||
868 | self.backfill_queue = asyncio.Queue() | ||
869 | tasks += [self.backfill_worker_task()] | ||
870 | |||
871 | self.loop.run_until_complete(self.db_engine.create()) | ||
872 | |||
873 | if self.admin_username: | ||
874 | self.loop.run_until_complete(self.create_admin_user()) | ||
875 | |||
876 | return tasks | ||
877 | |||
878 | async def stop(self): | ||
879 | if self.backfill_queue is not None: | ||
880 | await self.backfill_queue.put(None) | ||
881 | await super().stop() | ||