summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/hashserv/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/hashserv/server.py')
-rw-r--r--bitbake/lib/hashserv/server.py1117
1 files changed, 700 insertions, 417 deletions
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index a0dc0c170f..58f95c7bcd 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -3,20 +3,52 @@
3# SPDX-License-Identifier: GPL-2.0-only 3# SPDX-License-Identifier: GPL-2.0-only
4# 4#
5 5
6from contextlib import closing, contextmanager 6from datetime import datetime, timedelta
7from datetime import datetime
8import asyncio 7import asyncio
9import json
10import logging 8import logging
11import math 9import math
12import os
13import signal
14import socket
15import sys
16import time 10import time
17from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS 11import os
12import base64
13import json
14import hashlib
15from . import create_async_client
16import bb.asyncrpc
17
18logger = logging.getLogger("hashserv.server")
19
20
21# This permission only exists to match nothing
22NONE_PERM = "@none"
23
24READ_PERM = "@read"
25REPORT_PERM = "@report"
26DB_ADMIN_PERM = "@db-admin"
27USER_ADMIN_PERM = "@user-admin"
28ALL_PERM = "@all"
29
30ALL_PERMISSIONS = {
31 READ_PERM,
32 REPORT_PERM,
33 DB_ADMIN_PERM,
34 USER_ADMIN_PERM,
35 ALL_PERM,
36}
37
38DEFAULT_ANON_PERMS = (
39 READ_PERM,
40 REPORT_PERM,
41 DB_ADMIN_PERM,
42)
43
44TOKEN_ALGORITHM = "sha256"
45
46# 48 bytes of random data will result in 64 characters when base64
47# encoded. This number also ensures that the base64 encoding won't have any
48# trailing '=' characters.
49TOKEN_SIZE = 48
18 50
19logger = logging.getLogger('hashserv.server') 51SALT_SIZE = 8
20 52
21 53
22class Measurement(object): 54class Measurement(object):
@@ -106,522 +138,773 @@ class Stats(object):
106 return math.sqrt(self.s / (self.num - 1)) 138 return math.sqrt(self.s / (self.num - 1))
107 139
108 def todict(self): 140 def todict(self):
109 return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} 141 return {
110 142 k: getattr(self, k)
111 143 for k in ("num", "total_time", "max_time", "average", "stdev")
112class ClientError(Exception):
113 pass
114
115class ServerError(Exception):
116 pass
117
118def insert_task(cursor, data, ignore=False):
119 keys = sorted(data.keys())
120 query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
121 " OR IGNORE" if ignore else "",
122 ', '.join(keys),
123 ', '.join(':' + k for k in keys))
124 cursor.execute(query, data)
125
126async def copy_from_upstream(client, db, method, taskhash):
127 d = await client.get_taskhash(method, taskhash, True)
128 if d is not None:
129 # Filter out unknown columns
130 d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
131 keys = sorted(d.keys())
132
133 with closing(db.cursor()) as cursor:
134 insert_task(cursor, d)
135 db.commit()
136
137 return d
138
139async def copy_outhash_from_upstream(client, db, method, outhash, taskhash):
140 d = await client.get_outhash(method, outhash, taskhash)
141 if d is not None:
142 # Filter out unknown columns
143 d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
144 keys = sorted(d.keys())
145
146 with closing(db.cursor()) as cursor:
147 insert_task(cursor, d)
148 db.commit()
149
150 return d
151
152class ServerClient(object):
153 FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
154 ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
155 OUTHASH_QUERY = '''
156 -- Find tasks with a matching outhash (that is, tasks that
157 -- are equivalent)
158 SELECT * FROM tasks_v2 WHERE method=:method AND outhash=:outhash
159
160 -- If there is an exact match on the taskhash, return it.
161 -- Otherwise return the oldest matching outhash of any
162 -- taskhash
163 ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
164 created ASC
165
166 -- Only return one row
167 LIMIT 1
168 '''
169
170 def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
171 self.reader = reader
172 self.writer = writer
173 self.db = db
174 self.request_stats = request_stats
175 self.max_chunk = DEFAULT_MAX_CHUNK
176 self.backfill_queue = backfill_queue
177 self.upstream = upstream
178
179 self.handlers = {
180 'get': self.handle_get,
181 'get-outhash': self.handle_get_outhash,
182 'get-stream': self.handle_get_stream,
183 'get-stats': self.handle_get_stats,
184 'chunk-stream': self.handle_chunk,
185 } 144 }
186 145
187 if not read_only:
188 self.handlers.update({
189 'report': self.handle_report,
190 'report-equiv': self.handle_equivreport,
191 'reset-stats': self.handle_reset_stats,
192 'backfill-wait': self.handle_backfill_wait,
193 })
194 146
195 async def process_requests(self): 147token_refresh_semaphore = asyncio.Lock()
196 if self.upstream is not None:
197 self.upstream_client = await create_async_client(self.upstream)
198 else:
199 self.upstream_client = None
200 148
201 try:
202 149
150async def new_token():
151 # Prevent malicious users from using this API to deduce the entropy
152 # pool on the server and thus be able to guess a token. *All* token
153 # refresh requests lock the same global semaphore and then sleep for a
154 # short time. The effectively rate limits the total number of requests
155 # than can be made across all clients to 10/second, which should be enough
156 # since you have to be an authenticated users to make the request in the
157 # first place
158 async with token_refresh_semaphore:
159 await asyncio.sleep(0.1)
160 raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
203 161
204 self.addr = self.writer.get_extra_info('peername') 162 return base64.b64encode(raw, b"._").decode("utf-8")
205 logger.debug('Client %r connected' % (self.addr,))
206 163
207 # Read protocol and version
208 protocol = await self.reader.readline()
209 if protocol is None:
210 return
211 164
212 (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() 165def new_salt():
213 if proto_name != 'OEHASHEQUIV': 166 return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()
214 return
215 167
216 proto_version = tuple(int(v) for v in proto_version.split('.'))
217 if proto_version < (1, 0) or proto_version > (1, 1):
218 return
219 168
220 # Read headers. Currently, no headers are implemented, so look for 169def hash_token(algo, salt, token):
221 # an empty line to signal the end of the headers 170 h = hashlib.new(algo)
222 while True: 171 h.update(salt.encode("utf-8"))
223 line = await self.reader.readline() 172 h.update(token.encode("utf-8"))
224 if line is None: 173 return ":".join([algo, salt, h.hexdigest()])
225 return
226 174
227 line = line.decode('utf-8').rstrip()
228 if not line:
229 break
230 175
231 # Handle messages 176def permissions(*permissions, allow_anon=True, allow_self_service=False):
232 while True: 177 """
233 d = await self.read_message() 178 Function decorator that can be used to decorate an RPC function call and
234 if d is None: 179 check that the current users permissions match the require permissions.
235 break
236 await self.dispatch_message(d)
237 await self.writer.drain()
238 except ClientError as e:
239 logger.error(str(e))
240 finally:
241 if self.upstream_client is not None:
242 await self.upstream_client.close()
243 180
244 self.writer.close() 181 If allow_anon is True, the user will also be allowed to make the RPC call
182 if the anonymous user permissions match the permissions.
245 183
246 async def dispatch_message(self, msg): 184 If allow_self_service is True, and the "username" property in the request
247 for k in self.handlers.keys(): 185 is the currently logged in user, or not specified, the user will also be
248 if k in msg: 186 allowed to make the request. This allows users to access normal privileged
249 logger.debug('Handling %s' % k) 187 API, as long as they are only modifying their own user properties (e.g.
250 if 'stream' in k: 188 users can be allowed to reset their own token without @user-admin
251 await self.handlers[k](msg[k]) 189 permissions, but not the token for any other user.
190 """
191
192 def wrapper(func):
193 async def wrap(self, request):
194 if allow_self_service and self.user is not None:
195 username = request.get("username", self.user.username)
196 if username == self.user.username:
197 request["username"] = self.user.username
198 return await func(self, request)
199
200 if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
201 if not self.user:
202 username = "Anonymous user"
203 user_perms = self.server.anon_perms
252 else: 204 else:
253 with self.request_stats.start_sample() as self.request_sample, \ 205 username = self.user.username
254 self.request_sample.measure(): 206 user_perms = self.user.permissions
255 await self.handlers[k](msg[k]) 207
256 return 208 self.logger.info(
209 "User %s with permissions %r denied from calling %s. Missing permissions(s) %r",
210 username,
211 ", ".join(user_perms),
212 func.__name__,
213 ", ".join(permissions),
214 )
215 raise bb.asyncrpc.InvokeError(
216 f"{username} is not allowed to access permissions(s) {', '.join(permissions)}"
217 )
218
219 return await func(self, request)
220
221 return wrap
222
223 return wrapper
224
225
226class ServerClient(bb.asyncrpc.AsyncServerConnection):
227 def __init__(self, socket, server):
228 super().__init__(socket, "OEHASHEQUIV", server.logger)
229 self.server = server
230 self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
231 self.user = None
232
233 self.handlers.update(
234 {
235 "get": self.handle_get,
236 "get-outhash": self.handle_get_outhash,
237 "get-stream": self.handle_get_stream,
238 "exists-stream": self.handle_exists_stream,
239 "get-stats": self.handle_get_stats,
240 "get-db-usage": self.handle_get_db_usage,
241 "get-db-query-columns": self.handle_get_db_query_columns,
242 # Not always read-only, but internally checks if the server is
243 # read-only
244 "report": self.handle_report,
245 "auth": self.handle_auth,
246 "get-user": self.handle_get_user,
247 "get-all-users": self.handle_get_all_users,
248 "become-user": self.handle_become_user,
249 }
250 )
257 251
258 raise ClientError("Unrecognized command %r" % msg) 252 if not self.server.read_only:
253 self.handlers.update(
254 {
255 "report-equiv": self.handle_equivreport,
256 "reset-stats": self.handle_reset_stats,
257 "backfill-wait": self.handle_backfill_wait,
258 "remove": self.handle_remove,
259 "gc-mark": self.handle_gc_mark,
260 "gc-mark-stream": self.handle_gc_mark_stream,
261 "gc-sweep": self.handle_gc_sweep,
262 "gc-status": self.handle_gc_status,
263 "clean-unused": self.handle_clean_unused,
264 "refresh-token": self.handle_refresh_token,
265 "set-user-perms": self.handle_set_perms,
266 "new-user": self.handle_new_user,
267 "delete-user": self.handle_delete_user,
268 }
269 )
259 270
260 def write_message(self, msg): 271 def raise_no_user_error(self, username):
261 for c in chunkify(json.dumps(msg), self.max_chunk): 272 raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")
262 self.writer.write(c.encode('utf-8'))
263 273
264 async def read_message(self): 274 def user_has_permissions(self, *permissions, allow_anon=True):
265 l = await self.reader.readline() 275 permissions = set(permissions)
266 if not l: 276 if allow_anon:
267 return None 277 if ALL_PERM in self.server.anon_perms:
278 return True
268 279
269 try: 280 if not permissions - self.server.anon_perms:
270 message = l.decode('utf-8') 281 return True
271 282
272 if not message.endswith('\n'): 283 if self.user is None:
273 return None 284 return False
274 285
275 return json.loads(message) 286 if ALL_PERM in self.user.permissions:
276 except (json.JSONDecodeError, UnicodeDecodeError) as e: 287 return True
277 logger.error('Bad message from client: %r' % message)
278 raise e
279 288
280 async def handle_chunk(self, request): 289 if not permissions - self.user.permissions:
281 lines = [] 290 return True
282 try:
283 while True:
284 l = await self.reader.readline()
285 l = l.rstrip(b"\n").decode("utf-8")
286 if not l:
287 break
288 lines.append(l)
289 291
290 msg = json.loads(''.join(lines)) 292 return False
291 except (json.JSONDecodeError, UnicodeDecodeError) as e:
292 logger.error('Bad message from client: %r' % message)
293 raise e
294 293
295 if 'chunk-stream' in msg: 294 def validate_proto_version(self):
296 raise ClientError("Nested chunks are not allowed") 295 return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
297 296
298 await self.dispatch_message(msg) 297 async def process_requests(self):
298 async with self.server.db_engine.connect(self.logger) as db:
299 self.db = db
300 if self.server.upstream is not None:
301 self.upstream_client = await create_async_client(self.server.upstream)
302 else:
303 self.upstream_client = None
299 304
300 async def handle_get(self, request): 305 try:
301 method = request['method'] 306 await super().process_requests()
302 taskhash = request['taskhash'] 307 finally:
308 if self.upstream_client is not None:
309 await self.upstream_client.close()
303 310
304 if request.get('all', False): 311 async def dispatch_message(self, msg):
305 row = self.query_equivalent(method, taskhash, self.ALL_QUERY) 312 for k in self.handlers.keys():
306 else: 313 if k in msg:
307 row = self.query_equivalent(method, taskhash, self.FAST_QUERY) 314 self.logger.debug("Handling %s" % k)
315 if "stream" in k:
316 return await self.handlers[k](msg[k])
317 else:
318 with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
319 return await self.handlers[k](msg[k])
308 320
309 if row is not None: 321 raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
310 logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) 322
311 d = {k: row[k] for k in row.keys()} 323 @permissions(READ_PERM)
312 elif self.upstream_client is not None: 324 async def handle_get(self, request):
313 d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash) 325 method = request["method"]
326 taskhash = request["taskhash"]
327 fetch_all = request.get("all", False)
328
329 return await self.get_unihash(method, taskhash, fetch_all)
330
331 async def get_unihash(self, method, taskhash, fetch_all=False):
332 d = None
333
334 if fetch_all:
335 row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
336 if row is not None:
337 d = {k: row[k] for k in row.keys()}
338 elif self.upstream_client is not None:
339 d = await self.upstream_client.get_taskhash(method, taskhash, True)
340 await self.update_unified(d)
314 else: 341 else:
315 d = None 342 row = await self.db.get_equivalent(method, taskhash)
343
344 if row is not None:
345 d = {k: row[k] for k in row.keys()}
346 elif self.upstream_client is not None:
347 d = await self.upstream_client.get_taskhash(method, taskhash)
348 await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
316 349
317 self.write_message(d) 350 return d
318 351
352 @permissions(READ_PERM)
319 async def handle_get_outhash(self, request): 353 async def handle_get_outhash(self, request):
320 with closing(self.db.cursor()) as cursor: 354 method = request["method"]
321 cursor.execute(self.OUTHASH_QUERY, 355 outhash = request["outhash"]
322 {k: request[k] for k in ('method', 'outhash', 'taskhash')}) 356 taskhash = request["taskhash"]
357 with_unihash = request.get("with_unihash", True)
323 358
324 row = cursor.fetchone() 359 return await self.get_outhash(method, outhash, taskhash, with_unihash)
360
361 async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
362 d = None
363 if with_unihash:
364 row = await self.db.get_unihash_by_outhash(method, outhash)
365 else:
366 row = await self.db.get_outhash(method, outhash)
325 367
326 if row is not None: 368 if row is not None:
327 logger.debug('Found equivalent outhash %s -> %s', (row['outhash'], row['unihash']))
328 d = {k: row[k] for k in row.keys()} 369 d = {k: row[k] for k in row.keys()}
329 else: 370 elif self.upstream_client is not None:
330 d = None 371 d = await self.upstream_client.get_outhash(method, outhash, taskhash)
372 await self.update_unified(d)
331 373
332 self.write_message(d) 374 return d
333 375
334 async def handle_get_stream(self, request): 376 async def update_unified(self, data):
335 self.write_message('ok') 377 if data is None:
378 return
379
380 await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
381 await self.db.insert_outhash(data)
382
383 async def _stream_handler(self, handler):
384 await self.socket.send_message("ok")
336 385
337 while True: 386 while True:
338 upstream = None 387 upstream = None
339 388
340 l = await self.reader.readline() 389 l = await self.socket.recv()
341 if not l: 390 if not l:
342 return 391 break
343 392
344 try: 393 try:
345 # This inner loop is very sensitive and must be as fast as 394 # This inner loop is very sensitive and must be as fast as
346 # possible (which is why the request sample is handled manually 395 # possible (which is why the request sample is handled manually
347 # instead of using 'with', and also why logging statements are 396 # instead of using 'with', and also why logging statements are
348 # commented out. 397 # commented out.
349 self.request_sample = self.request_stats.start_sample() 398 self.request_sample = self.server.request_stats.start_sample()
350 request_measure = self.request_sample.measure() 399 request_measure = self.request_sample.measure()
351 request_measure.start() 400 request_measure.start()
352 401
353 l = l.decode('utf-8').rstrip() 402 if l == "END":
354 if l == 'END': 403 break
355 self.writer.write('ok\n'.encode('utf-8'))
356 return
357
358 (method, taskhash) = l.split()
359 #logger.debug('Looking up %s %s' % (method, taskhash))
360 row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
361 if row is not None:
362 msg = ('%s\n' % row['unihash']).encode('utf-8')
363 #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
364 elif self.upstream_client is not None:
365 upstream = await self.upstream_client.get_unihash(method, taskhash)
366 if upstream:
367 msg = ("%s\n" % upstream).encode("utf-8")
368 else:
369 msg = "\n".encode("utf-8")
370 else:
371 msg = '\n'.encode('utf-8')
372 404
373 self.writer.write(msg) 405 msg = await handler(l)
406 await self.socket.send(msg)
374 finally: 407 finally:
375 request_measure.end() 408 request_measure.end()
376 self.request_sample.end() 409 self.request_sample.end()
377 410
378 await self.writer.drain() 411 await self.socket.send("ok")
412 return self.NO_RESPONSE
379 413
380 # Post to the backfill queue after writing the result to minimize 414 @permissions(READ_PERM)
381 # the turn around time on a request 415 async def handle_get_stream(self, request):
382 if upstream is not None: 416 async def handler(l):
383 await self.backfill_queue.put((method, taskhash)) 417 (method, taskhash) = l.split()
418 # self.logger.debug('Looking up %s %s' % (method, taskhash))
419 row = await self.db.get_equivalent(method, taskhash)
384 420
385 async def handle_report(self, data): 421 if row is not None:
386 with closing(self.db.cursor()) as cursor: 422 # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
387 cursor.execute(self.OUTHASH_QUERY, 423 return row["unihash"]
388 {k: data[k] for k in ('method', 'outhash', 'taskhash')})
389
390 row = cursor.fetchone()
391
392 if row is None and self.upstream_client:
393 # Try upstream
394 row = await copy_outhash_from_upstream(self.upstream_client,
395 self.db,
396 data['method'],
397 data['outhash'],
398 data['taskhash'])
399
400 # If no matching outhash was found, or one *was* found but it
401 # wasn't an exact match on the taskhash, a new entry for this
402 # taskhash should be added
403 if row is None or row['taskhash'] != data['taskhash']:
404 # If a row matching the outhash was found, the unihash for
405 # the new taskhash should be the same as that one.
406 # Otherwise the caller provided unihash is used.
407 unihash = data['unihash']
408 if row is not None:
409 unihash = row['unihash']
410
411 insert_data = {
412 'method': data['method'],
413 'outhash': data['outhash'],
414 'taskhash': data['taskhash'],
415 'unihash': unihash,
416 'created': datetime.now()
417 }
418 424
419 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): 425 if self.upstream_client is not None:
420 if k in data: 426 upstream = await self.upstream_client.get_unihash(method, taskhash)
421 insert_data[k] = data[k] 427 if upstream:
428 await self.server.backfill_queue.put((method, taskhash))
429 return upstream
422 430
423 insert_task(cursor, insert_data) 431 return ""
424 self.db.commit()
425 432
426 logger.info('Adding taskhash %s with unihash %s', 433 return await self._stream_handler(handler)
427 data['taskhash'], unihash)
428 434
429 d = { 435 @permissions(READ_PERM)
430 'taskhash': data['taskhash'], 436 async def handle_exists_stream(self, request):
431 'method': data['method'], 437 async def handler(l):
432 'unihash': unihash 438 if await self.db.unihash_exists(l):
433 } 439 return "true"
434 else:
435 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
436 440
437 self.write_message(d) 441 if self.upstream_client is not None:
442 if await self.upstream_client.unihash_exists(l):
443 return "true"
438 444
439 async def handle_equivreport(self, data): 445 return "false"
440 with closing(self.db.cursor()) as cursor:
441 insert_data = {
442 'method': data['method'],
443 'outhash': "",
444 'taskhash': data['taskhash'],
445 'unihash': data['unihash'],
446 'created': datetime.now()
447 }
448 446
449 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): 447 return await self._stream_handler(handler)
450 if k in data:
451 insert_data[k] = data[k]
452 448
453 insert_task(cursor, insert_data, ignore=True) 449 async def report_readonly(self, data):
454 self.db.commit() 450 method = data["method"]
451 outhash = data["outhash"]
452 taskhash = data["taskhash"]
455 453
456 # Fetch the unihash that will be reported for the taskhash. If the 454 info = await self.get_outhash(method, outhash, taskhash)
457 # unihash matches, it means this row was inserted (or the mapping 455 if info:
458 # was already valid) 456 unihash = info["unihash"]
459 row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY) 457 else:
458 unihash = data["unihash"]
460 459
461 if row['unihash'] == data['unihash']: 460 return {
462 logger.info('Adding taskhash equivalence for %s with unihash %s', 461 "taskhash": taskhash,
463 data['taskhash'], row['unihash']) 462 "method": method,
463 "unihash": unihash,
464 }
464 465
465 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} 466 # Since this can be called either read only or to report, the check to
467 # report is made inside the function
468 @permissions(READ_PERM)
469 async def handle_report(self, data):
470 if self.server.read_only or not self.user_has_permissions(REPORT_PERM):
471 return await self.report_readonly(data)
472
473 outhash_data = {
474 "method": data["method"],
475 "outhash": data["outhash"],
476 "taskhash": data["taskhash"],
477 "created": datetime.now(),
478 }
466 479
467 self.write_message(d) 480 for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
481 if k in data:
482 outhash_data[k] = data[k]
468 483
484 if self.user:
485 outhash_data["owner"] = self.user.username
469 486
470 async def handle_get_stats(self, request): 487 # Insert the new entry, unless it already exists
471 d = { 488 if await self.db.insert_outhash(outhash_data):
472 'requests': self.request_stats.todict(), 489 # If this row is new, check if it is equivalent to another
490 # output hash
491 row = await self.db.get_equivalent_for_outhash(
492 data["method"], data["outhash"], data["taskhash"]
493 )
494
495 if row is not None:
496 # A matching output hash was found. Set our taskhash to the
497 # same unihash since they are equivalent
498 unihash = row["unihash"]
499 else:
500 # No matching output hash was found. This is probably the
501 # first outhash to be added.
502 unihash = data["unihash"]
503
504 # Query upstream to see if it has a unihash we can use
505 if self.upstream_client is not None:
506 upstream_data = await self.upstream_client.get_outhash(
507 data["method"], data["outhash"], data["taskhash"]
508 )
509 if upstream_data is not None:
510 unihash = upstream_data["unihash"]
511
512 await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
513
514 unihash_data = await self.get_unihash(data["method"], data["taskhash"])
515 if unihash_data is not None:
516 unihash = unihash_data["unihash"]
517 else:
518 unihash = data["unihash"]
519
520 return {
521 "taskhash": data["taskhash"],
522 "method": data["method"],
523 "unihash": unihash,
473 } 524 }
474 525
475 self.write_message(d) 526 @permissions(READ_PERM, REPORT_PERM)
527 async def handle_equivreport(self, data):
528 await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
529
530 # Fetch the unihash that will be reported for the taskhash. If the
531 # unihash matches, it means this row was inserted (or the mapping
532 # was already valid)
533 row = await self.db.get_equivalent(data["method"], data["taskhash"])
534
535 if row["unihash"] == data["unihash"]:
536 self.logger.info(
537 "Adding taskhash equivalence for %s with unihash %s",
538 data["taskhash"],
539 row["unihash"],
540 )
541
542 return {k: row[k] for k in ("taskhash", "method", "unihash")}
476 543
544 @permissions(READ_PERM)
545 async def handle_get_stats(self, request):
546 return {
547 "requests": self.server.request_stats.todict(),
548 }
549
550 @permissions(DB_ADMIN_PERM)
477 async def handle_reset_stats(self, request): 551 async def handle_reset_stats(self, request):
478 d = { 552 d = {
479 'requests': self.request_stats.todict(), 553 "requests": self.server.request_stats.todict(),
480 } 554 }
481 555
482 self.request_stats.reset() 556 self.server.request_stats.reset()
483 self.write_message(d) 557 return d
484 558
559 @permissions(READ_PERM)
485 async def handle_backfill_wait(self, request): 560 async def handle_backfill_wait(self, request):
486 d = { 561 d = {
487 'tasks': self.backfill_queue.qsize(), 562 "tasks": self.server.backfill_queue.qsize(),
488 } 563 }
489 await self.backfill_queue.join() 564 await self.server.backfill_queue.join()
490 self.write_message(d) 565 return d
491 566
492 def query_equivalent(self, method, taskhash, query): 567 @permissions(DB_ADMIN_PERM)
493 # This is part of the inner loop and must be as fast as possible 568 async def handle_remove(self, request):
494 try: 569 condition = request["where"]
495 cursor = self.db.cursor() 570 if not isinstance(condition, dict):
496 cursor.execute(query, {'method': method, 'taskhash': taskhash}) 571 raise TypeError("Bad condition type %s" % type(condition))
497 return cursor.fetchone()
498 except:
499 cursor.close()
500 572
573 return {"count": await self.db.remove(condition)}
501 574
502class Server(object): 575 @permissions(DB_ADMIN_PERM)
503 def __init__(self, db, loop=None, upstream=None, read_only=False): 576 async def handle_gc_mark(self, request):
504 if upstream and read_only: 577 condition = request["where"]
505 raise ServerError("Read-only hashserv cannot pull from an upstream server") 578 mark = request["mark"]
506 579
507 self.request_stats = Stats() 580 if not isinstance(condition, dict):
508 self.db = db 581 raise TypeError("Bad condition type %s" % type(condition))
509 582
510 if loop is None: 583 if not isinstance(mark, str):
511 self.loop = asyncio.new_event_loop() 584 raise TypeError("Bad mark type %s" % type(mark))
512 self.close_loop = True
513 else:
514 self.loop = loop
515 self.close_loop = False
516 585
517 self.upstream = upstream 586 return {"count": await self.db.gc_mark(mark, condition)}
518 self.read_only = read_only
519 587
520 self._cleanup_socket = None 588 @permissions(DB_ADMIN_PERM)
589 async def handle_gc_mark_stream(self, request):
590 async def handler(line):
591 try:
592 decoded_line = json.loads(line)
593 except json.JSONDecodeError as exc:
594 raise bb.asyncrpc.InvokeError(
595 "Could not decode JSONL input '%s'" % line
596 ) from exc
521 597
522 def start_tcp_server(self, host, port): 598 try:
523 self.server = self.loop.run_until_complete( 599 mark = decoded_line["mark"]
524 asyncio.start_server(self.handle_client, host, port, loop=self.loop) 600 condition = decoded_line["where"]
525 ) 601 if not isinstance(mark, str):
602 raise TypeError("Bad mark type %s" % type(mark))
526 603
527 for s in self.server.sockets: 604 if not isinstance(condition, dict):
528 logger.info('Listening on %r' % (s.getsockname(),)) 605 raise TypeError("Bad condition type %s" % type(condition))
529 # Newer python does this automatically. Do it manually here for 606 except KeyError as exc:
530 # maximum compatibility 607 raise bb.asyncrpc.InvokeError(
531 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) 608 "Input line is missing key '%s' " % exc
532 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) 609 ) from exc
533 610
534 name = self.server.sockets[0].getsockname() 611 return json.dumps({"count": await self.db.gc_mark(mark, condition)})
535 if self.server.sockets[0].family == socket.AF_INET6:
536 self.address = "[%s]:%d" % (name[0], name[1])
537 else:
538 self.address = "%s:%d" % (name[0], name[1])
539 612
540 def start_unix_server(self, path): 613 return await self._stream_handler(handler)
541 def cleanup():
542 os.unlink(path)
543 614
544 cwd = os.getcwd() 615 @permissions(DB_ADMIN_PERM)
545 try: 616 async def handle_gc_sweep(self, request):
546 # Work around path length limits in AF_UNIX 617 mark = request["mark"]
547 os.chdir(os.path.dirname(path)) 618
548 self.server = self.loop.run_until_complete( 619 if not isinstance(mark, str):
549 asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop) 620 raise TypeError("Bad mark type %s" % type(mark))
621
622 current_mark = await self.db.get_current_gc_mark()
623
624 if not current_mark or mark != current_mark:
625 raise bb.asyncrpc.InvokeError(
626 f"'{mark}' is not the current mark. Refusing to sweep"
550 ) 627 )
551 finally:
552 os.chdir(cwd)
553 628
554 logger.info('Listening on %r' % path) 629 count = await self.db.gc_sweep()
630
631 return {"count": count}
555 632
556 self._cleanup_socket = cleanup 633 @permissions(DB_ADMIN_PERM)
557 self.address = "unix://%s" % os.path.abspath(path) 634 async def handle_gc_status(self, request):
635 (keep_rows, remove_rows, current_mark) = await self.db.gc_status()
636 return {
637 "keep": keep_rows,
638 "remove": remove_rows,
639 "mark": current_mark,
640 }
641
642 @permissions(DB_ADMIN_PERM)
643 async def handle_clean_unused(self, request):
644 max_age = request["max_age_seconds"]
645 oldest = datetime.now() - timedelta(seconds=-max_age)
646 return {"count": await self.db.clean_unused(oldest)}
647
648 @permissions(DB_ADMIN_PERM)
649 async def handle_get_db_usage(self, request):
650 return {"usage": await self.db.get_usage()}
651
652 @permissions(DB_ADMIN_PERM)
653 async def handle_get_db_query_columns(self, request):
654 return {"columns": await self.db.get_query_columns()}
655
    # The authentication API is always allowed
    async def handle_auth(self, request):
        """Authenticate this connection as the named user.

        Verifies the presented plain-text token against the salted hash
        stored in the database and, on success, binds the user (and their
        permissions) to this client session.
        """
        username = str(request["username"])
        token = str(request["token"])

        async def fail_auth():
            nonlocal username
            # Rate limit bad login attempts
            await asyncio.sleep(1)
            raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")

        user, db_token = await self.db.lookup_user_token(username)

        # Unknown user or user with no token set: same failure path as a
        # bad token so the response doesn't reveal which usernames exist
        if not user or not db_token:
            await fail_auth()

        try:
            # Stored tokens have the form "<algo>:<salt>:<hash>"; a
            # malformed value is treated as an authentication failure
            algo, salt, _ = db_token.split(":")
        except ValueError:
            await fail_auth()

        # Re-hash the presented token with the stored algorithm and salt;
        # the result must equal the stored db_token string exactly
        if hash_token(algo, salt, token) != db_token:
            await fail_auth()

        # Success: attach the authenticated user to this session
        self.user = user

        self.logger.info("Authenticated as %s", username)

        return {
            "result": True,
            "username": self.user.username,
            "permissions": sorted(list(self.user.permissions)),
        }
599 689
600 def serve_forever(self): 690 @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
601 def signal_handler(): 691 async def handle_refresh_token(self, request):
602 self.loop.stop() 692 username = str(request["username"])
603 693
604 asyncio.set_event_loop(self.loop) 694 token = await new_token()
695
696 updated = await self.db.set_user_token(
697 username,
698 hash_token(TOKEN_ALGORITHM, new_salt(), token),
699 )
700 if not updated:
701 self.raise_no_user_error(username)
702
703 return {"username": username, "token": token}
704
705 def get_perm_arg(self, arg):
706 if not isinstance(arg, list):
707 raise bb.asyncrpc.InvokeError("Unexpected type for permissions")
708
709 arg = set(arg)
605 try: 710 try:
606 self.backfill_queue = asyncio.Queue() 711 arg.remove(NONE_PERM)
712 except KeyError:
713 pass
714
715 unknown_perms = arg - ALL_PERMISSIONS
716 if unknown_perms:
717 raise bb.asyncrpc.InvokeError(
718 "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms)))
719 )
720
721 return sorted(list(arg))
722
723 def return_perms(self, permissions):
724 if ALL_PERM in permissions:
725 return sorted(list(ALL_PERMISSIONS))
726 return sorted(list(permissions))
607 727
608 self.loop.add_signal_handler(signal.SIGTERM, signal_handler) 728 @permissions(USER_ADMIN_PERM, allow_anon=False)
729 async def handle_set_perms(self, request):
730 username = str(request["username"])
731 permissions = self.get_perm_arg(request["permissions"])
609 732
610 with self._backfill_worker(): 733 if not await self.db.set_user_perms(username, permissions):
611 try: 734 self.raise_no_user_error(username)
612 self.loop.run_forever()
613 except KeyboardInterrupt:
614 pass
615 735
616 self.server.close() 736 return {
737 "username": username,
738 "permissions": self.return_perms(permissions),
739 }
740
741 @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
742 async def handle_get_user(self, request):
743 username = str(request["username"])
617 744
618 self.loop.run_until_complete(self.server.wait_closed()) 745 user = await self.db.lookup_user(username)
619 logger.info('Server shutting down') 746 if user is None:
620 finally: 747 return None
621 if self.close_loop: 748
622 if sys.version_info >= (3, 6): 749 return {
623 self.loop.run_until_complete(self.loop.shutdown_asyncgens()) 750 "username": user.username,
624 self.loop.close() 751 "permissions": self.return_perms(user.permissions),
752 }
753
754 @permissions(USER_ADMIN_PERM, allow_anon=False)
755 async def handle_get_all_users(self, request):
756 users = await self.db.get_all_users()
757 return {
758 "users": [
759 {
760 "username": u.username,
761 "permissions": self.return_perms(u.permissions),
762 }
763 for u in users
764 ]
765 }
625 766
626 if self._cleanup_socket is not None: 767 @permissions(USER_ADMIN_PERM, allow_anon=False)
627 self._cleanup_socket() 768 async def handle_new_user(self, request):
769 username = str(request["username"])
770 permissions = self.get_perm_arg(request["permissions"])
771
772 token = await new_token()
773
774 inserted = await self.db.new_user(
775 username,
776 permissions,
777 hash_token(TOKEN_ALGORITHM, new_salt(), token),
778 )
779 if not inserted:
780 raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")
781
782 return {
783 "username": username,
784 "permissions": self.return_perms(permissions),
785 "token": token,
786 }
787
788 @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
789 async def handle_delete_user(self, request):
790 username = str(request["username"])
791
792 if not await self.db.delete_user(username):
793 self.raise_no_user_error(username)
794
795 return {"username": username}
796
797 @permissions(USER_ADMIN_PERM, allow_anon=False)
798 async def handle_become_user(self, request):
799 username = str(request["username"])
800
801 user = await self.db.lookup_user(username)
802 if user is None:
803 raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")
804
805 self.user = user
806
807 self.logger.info("Became user %s", username)
808
809 return {
810 "username": self.user.username,
811 "permissions": self.return_perms(self.user.permissions),
812 }
813
814
class Server(bb.asyncrpc.AsyncServer):
    """Hash equivalence server.

    Wraps the async RPC server machinery with the hashserv database
    engine, optional upstream mirroring (backfill), and optional admin
    user bootstrapping.
    """

    def __init__(
        self,
        db_engine,
        upstream=None,
        read_only=False,
        anon_perms=DEFAULT_ANON_PERMS,
        admin_username=None,
        admin_password=None,
    ):
        # A read-only server cannot write entries copied from upstream
        if upstream and read_only:
            raise bb.asyncrpc.ServerError(
                "Read-only hashserv cannot pull from an upstream server"
            )

        # Anonymous users may only be granted a restricted subset of
        # permissions; user-admin (and therefore @all) is never allowed
        disallowed_perms = set(anon_perms) - set(
            [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
        )

        if disallowed_perms:
            raise bb.asyncrpc.ServerError(
                f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
            )

        super().__init__(logger)

        self.request_stats = Stats()
        self.db_engine = db_engine
        self.upstream = upstream
        self.read_only = read_only
        # Created lazily in start() only when an upstream is configured
        self.backfill_queue = None
        self.anon_perms = set(anon_perms)
        self.admin_username = admin_username
        self.admin_password = admin_password

        self.logger.info(
            "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
        )

    def accept_client(self, socket):
        # Factory hook called by the async RPC server for each connection
        return ServerClient(socket, self)

    async def create_admin_user(self):
        """Create the configured admin user, or refresh its permissions
        and token if it already exists."""
        admin_permissions = (ALL_PERM,)
        async with self.db_engine.connect(self.logger) as db:
            added = await db.new_user(
                self.admin_username,
                admin_permissions,
                hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
            )
            if added:
                self.logger.info("Created admin user '%s'", self.admin_username)
            else:
                # User already present: re-assert @all permissions and
                # reset the token to the configured password
                await db.set_user_perms(
                    self.admin_username,
                    admin_permissions,
                )
                await db.set_user_token(
                    self.admin_username,
                    hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
                )
                self.logger.info("Admin user '%s' updated", self.admin_username)

    async def backfill_worker_task(self):
        """Copy unihashes requested via the backfill queue from the
        upstream server into the local database.

        Runs until a None sentinel is queued (see stop()).
        """
        async with await create_async_client(
            self.upstream
        ) as client, self.db_engine.connect(self.logger) as db:
            while True:
                item = await self.backfill_queue.get()
                if item is None:
                    # Sentinel: acknowledge and exit the worker
                    self.backfill_queue.task_done()
                    break

                method, taskhash = item
                d = await client.get_taskhash(method, taskhash)
                if d is not None:
                    await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
                self.backfill_queue.task_done()

    def start(self):
        # Collect base server tasks, then add the backfill worker when an
        # upstream is configured
        tasks = super().start()
        if self.upstream:
            self.backfill_queue = asyncio.Queue()
            tasks += [self.backfill_worker_task()]

        # Ensure the database schema exists before serving requests
        self.loop.run_until_complete(self.db_engine.create())

        if self.admin_username:
            self.loop.run_until_complete(self.create_admin_user())

        return tasks

    async def stop(self):
        # Wake the backfill worker with the shutdown sentinel before
        # stopping the base server
        if self.backfill_queue is not None:
            await self.backfill_queue.put(None)
        await super().stop()