summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/hashserv/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/hashserv/server.py')
-rw-r--r--bitbake/lib/hashserv/server.py1088
1 files changed, 671 insertions, 417 deletions
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index a0dc0c170f..68f64f983b 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -3,20 +3,51 @@
3# SPDX-License-Identifier: GPL-2.0-only 3# SPDX-License-Identifier: GPL-2.0-only
4# 4#
5 5
6from contextlib import closing, contextmanager 6from datetime import datetime, timedelta
7from datetime import datetime
8import asyncio 7import asyncio
9import json
10import logging 8import logging
11import math 9import math
12import os
13import signal
14import socket
15import sys
16import time 10import time
17from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS 11import os
12import base64
13import hashlib
14from . import create_async_client
15import bb.asyncrpc
16
17logger = logging.getLogger("hashserv.server")
18
19
20# This permission only exists to match nothing
21NONE_PERM = "@none"
22
23READ_PERM = "@read"
24REPORT_PERM = "@report"
25DB_ADMIN_PERM = "@db-admin"
26USER_ADMIN_PERM = "@user-admin"
27ALL_PERM = "@all"
18 28
19logger = logging.getLogger('hashserv.server') 29ALL_PERMISSIONS = {
30 READ_PERM,
31 REPORT_PERM,
32 DB_ADMIN_PERM,
33 USER_ADMIN_PERM,
34 ALL_PERM,
35}
36
37DEFAULT_ANON_PERMS = (
38 READ_PERM,
39 REPORT_PERM,
40 DB_ADMIN_PERM,
41)
42
43TOKEN_ALGORITHM = "sha256"
44
45# 48 bytes of random data will result in 64 characters when base64
46# encoded. This number also ensures that the base64 encoding won't have any
47# trailing '=' characters.
48TOKEN_SIZE = 48
49
50SALT_SIZE = 8
20 51
21 52
22class Measurement(object): 53class Measurement(object):
@@ -106,522 +137,745 @@ class Stats(object):
106 return math.sqrt(self.s / (self.num - 1)) 137 return math.sqrt(self.s / (self.num - 1))
107 138
108 def todict(self): 139 def todict(self):
109 return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} 140 return {
110 141 k: getattr(self, k)
111 142 for k in ("num", "total_time", "max_time", "average", "stdev")
112class ClientError(Exception):
113 pass
114
115class ServerError(Exception):
116 pass
117
118def insert_task(cursor, data, ignore=False):
119 keys = sorted(data.keys())
120 query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
121 " OR IGNORE" if ignore else "",
122 ', '.join(keys),
123 ', '.join(':' + k for k in keys))
124 cursor.execute(query, data)
125
126async def copy_from_upstream(client, db, method, taskhash):
127 d = await client.get_taskhash(method, taskhash, True)
128 if d is not None:
129 # Filter out unknown columns
130 d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
131 keys = sorted(d.keys())
132
133 with closing(db.cursor()) as cursor:
134 insert_task(cursor, d)
135 db.commit()
136
137 return d
138
139async def copy_outhash_from_upstream(client, db, method, outhash, taskhash):
140 d = await client.get_outhash(method, outhash, taskhash)
141 if d is not None:
142 # Filter out unknown columns
143 d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
144 keys = sorted(d.keys())
145
146 with closing(db.cursor()) as cursor:
147 insert_task(cursor, d)
148 db.commit()
149
150 return d
151
152class ServerClient(object):
153 FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
154 ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
155 OUTHASH_QUERY = '''
156 -- Find tasks with a matching outhash (that is, tasks that
157 -- are equivalent)
158 SELECT * FROM tasks_v2 WHERE method=:method AND outhash=:outhash
159
160 -- If there is an exact match on the taskhash, return it.
161 -- Otherwise return the oldest matching outhash of any
162 -- taskhash
163 ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
164 created ASC
165
166 -- Only return one row
167 LIMIT 1
168 '''
169
170 def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
171 self.reader = reader
172 self.writer = writer
173 self.db = db
174 self.request_stats = request_stats
175 self.max_chunk = DEFAULT_MAX_CHUNK
176 self.backfill_queue = backfill_queue
177 self.upstream = upstream
178
179 self.handlers = {
180 'get': self.handle_get,
181 'get-outhash': self.handle_get_outhash,
182 'get-stream': self.handle_get_stream,
183 'get-stats': self.handle_get_stats,
184 'chunk-stream': self.handle_chunk,
185 } 143 }
186 144
187 if not read_only:
188 self.handlers.update({
189 'report': self.handle_report,
190 'report-equiv': self.handle_equivreport,
191 'reset-stats': self.handle_reset_stats,
192 'backfill-wait': self.handle_backfill_wait,
193 })
194 145
195 async def process_requests(self): 146token_refresh_semaphore = asyncio.Lock()
196 if self.upstream is not None:
197 self.upstream_client = await create_async_client(self.upstream)
198 else:
199 self.upstream_client = None
200 147
201 try:
202 148
149async def new_token():
150 # Prevent malicious users from using this API to deduce the entropy
151 # pool on the server and thus be able to guess a token. *All* token
152 # refresh requests lock the same global semaphore and then sleep for a
153 # short time. The effectively rate limits the total number of requests
154 # than can be made across all clients to 10/second, which should be enough
155 # since you have to be an authenticated users to make the request in the
156 # first place
157 async with token_refresh_semaphore:
158 await asyncio.sleep(0.1)
159 raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
203 160
204 self.addr = self.writer.get_extra_info('peername') 161 return base64.b64encode(raw, b"._").decode("utf-8")
205 logger.debug('Client %r connected' % (self.addr,))
206 162
207 # Read protocol and version
208 protocol = await self.reader.readline()
209 if protocol is None:
210 return
211 163
212 (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() 164def new_salt():
213 if proto_name != 'OEHASHEQUIV': 165 return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()
214 return
215 166
216 proto_version = tuple(int(v) for v in proto_version.split('.'))
217 if proto_version < (1, 0) or proto_version > (1, 1):
218 return
219 167
220 # Read headers. Currently, no headers are implemented, so look for 168def hash_token(algo, salt, token):
221 # an empty line to signal the end of the headers 169 h = hashlib.new(algo)
222 while True: 170 h.update(salt.encode("utf-8"))
223 line = await self.reader.readline() 171 h.update(token.encode("utf-8"))
224 if line is None: 172 return ":".join([algo, salt, h.hexdigest()])
225 return
226 173
227 line = line.decode('utf-8').rstrip()
228 if not line:
229 break
230 174
231 # Handle messages 175def permissions(*permissions, allow_anon=True, allow_self_service=False):
232 while True: 176 """
233 d = await self.read_message() 177 Function decorator that can be used to decorate an RPC function call and
234 if d is None: 178 check that the current users permissions match the require permissions.
235 break
236 await self.dispatch_message(d)
237 await self.writer.drain()
238 except ClientError as e:
239 logger.error(str(e))
240 finally:
241 if self.upstream_client is not None:
242 await self.upstream_client.close()
243 179
244 self.writer.close() 180 If allow_anon is True, the user will also be allowed to make the RPC call
181 if the anonymous user permissions match the permissions.
245 182
246 async def dispatch_message(self, msg): 183 If allow_self_service is True, and the "username" property in the request
247 for k in self.handlers.keys(): 184 is the currently logged in user, or not specified, the user will also be
248 if k in msg: 185 allowed to make the request. This allows users to access normal privileged
249 logger.debug('Handling %s' % k) 186 API, as long as they are only modifying their own user properties (e.g.
250 if 'stream' in k: 187 users can be allowed to reset their own token without @user-admin
251 await self.handlers[k](msg[k]) 188 permissions, but not the token for any other user.
189 """
190
191 def wrapper(func):
192 async def wrap(self, request):
193 if allow_self_service and self.user is not None:
194 username = request.get("username", self.user.username)
195 if username == self.user.username:
196 request["username"] = self.user.username
197 return await func(self, request)
198
199 if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
200 if not self.user:
201 username = "Anonymous user"
202 user_perms = self.server.anon_perms
252 else: 203 else:
253 with self.request_stats.start_sample() as self.request_sample, \ 204 username = self.user.username
254 self.request_sample.measure(): 205 user_perms = self.user.permissions
255 await self.handlers[k](msg[k]) 206
256 return 207 self.logger.info(
208 "User %s with permissions %r denied from calling %s. Missing permissions(s) %r",
209 username,
210 ", ".join(user_perms),
211 func.__name__,
212 ", ".join(permissions),
213 )
214 raise bb.asyncrpc.InvokeError(
215 f"{username} is not allowed to access permissions(s) {', '.join(permissions)}"
216 )
217
218 return await func(self, request)
219
220 return wrap
221
222 return wrapper
223
224
225class ServerClient(bb.asyncrpc.AsyncServerConnection):
226 def __init__(self, socket, server):
227 super().__init__(socket, "OEHASHEQUIV", server.logger)
228 self.server = server
229 self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
230 self.user = None
231
232 self.handlers.update(
233 {
234 "get": self.handle_get,
235 "get-outhash": self.handle_get_outhash,
236 "get-stream": self.handle_get_stream,
237 "exists-stream": self.handle_exists_stream,
238 "get-stats": self.handle_get_stats,
239 "get-db-usage": self.handle_get_db_usage,
240 "get-db-query-columns": self.handle_get_db_query_columns,
241 # Not always read-only, but internally checks if the server is
242 # read-only
243 "report": self.handle_report,
244 "auth": self.handle_auth,
245 "get-user": self.handle_get_user,
246 "get-all-users": self.handle_get_all_users,
247 "become-user": self.handle_become_user,
248 }
249 )
257 250
258 raise ClientError("Unrecognized command %r" % msg) 251 if not self.server.read_only:
252 self.handlers.update(
253 {
254 "report-equiv": self.handle_equivreport,
255 "reset-stats": self.handle_reset_stats,
256 "backfill-wait": self.handle_backfill_wait,
257 "remove": self.handle_remove,
258 "gc-mark": self.handle_gc_mark,
259 "gc-sweep": self.handle_gc_sweep,
260 "gc-status": self.handle_gc_status,
261 "clean-unused": self.handle_clean_unused,
262 "refresh-token": self.handle_refresh_token,
263 "set-user-perms": self.handle_set_perms,
264 "new-user": self.handle_new_user,
265 "delete-user": self.handle_delete_user,
266 }
267 )
259 268
260 def write_message(self, msg): 269 def raise_no_user_error(self, username):
261 for c in chunkify(json.dumps(msg), self.max_chunk): 270 raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")
262 self.writer.write(c.encode('utf-8'))
263 271
264 async def read_message(self): 272 def user_has_permissions(self, *permissions, allow_anon=True):
265 l = await self.reader.readline() 273 permissions = set(permissions)
266 if not l: 274 if allow_anon:
267 return None 275 if ALL_PERM in self.server.anon_perms:
276 return True
268 277
269 try: 278 if not permissions - self.server.anon_perms:
270 message = l.decode('utf-8') 279 return True
271 280
272 if not message.endswith('\n'): 281 if self.user is None:
273 return None 282 return False
274 283
275 return json.loads(message) 284 if ALL_PERM in self.user.permissions:
276 except (json.JSONDecodeError, UnicodeDecodeError) as e: 285 return True
277 logger.error('Bad message from client: %r' % message)
278 raise e
279 286
280 async def handle_chunk(self, request): 287 if not permissions - self.user.permissions:
281 lines = [] 288 return True
282 try:
283 while True:
284 l = await self.reader.readline()
285 l = l.rstrip(b"\n").decode("utf-8")
286 if not l:
287 break
288 lines.append(l)
289 289
290 msg = json.loads(''.join(lines)) 290 return False
291 except (json.JSONDecodeError, UnicodeDecodeError) as e:
292 logger.error('Bad message from client: %r' % message)
293 raise e
294 291
295 if 'chunk-stream' in msg: 292 def validate_proto_version(self):
296 raise ClientError("Nested chunks are not allowed") 293 return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
297 294
298 await self.dispatch_message(msg) 295 async def process_requests(self):
296 async with self.server.db_engine.connect(self.logger) as db:
297 self.db = db
298 if self.server.upstream is not None:
299 self.upstream_client = await create_async_client(self.server.upstream)
300 else:
301 self.upstream_client = None
299 302
300 async def handle_get(self, request): 303 try:
301 method = request['method'] 304 await super().process_requests()
302 taskhash = request['taskhash'] 305 finally:
306 if self.upstream_client is not None:
307 await self.upstream_client.close()
303 308
304 if request.get('all', False): 309 async def dispatch_message(self, msg):
305 row = self.query_equivalent(method, taskhash, self.ALL_QUERY) 310 for k in self.handlers.keys():
306 else: 311 if k in msg:
307 row = self.query_equivalent(method, taskhash, self.FAST_QUERY) 312 self.logger.debug("Handling %s" % k)
313 if "stream" in k:
314 return await self.handlers[k](msg[k])
315 else:
316 with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
317 return await self.handlers[k](msg[k])
308 318
309 if row is not None: 319 raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
310 logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) 320
311 d = {k: row[k] for k in row.keys()} 321 @permissions(READ_PERM)
312 elif self.upstream_client is not None: 322 async def handle_get(self, request):
313 d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash) 323 method = request["method"]
324 taskhash = request["taskhash"]
325 fetch_all = request.get("all", False)
326
327 return await self.get_unihash(method, taskhash, fetch_all)
328
329 async def get_unihash(self, method, taskhash, fetch_all=False):
330 d = None
331
332 if fetch_all:
333 row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
334 if row is not None:
335 d = {k: row[k] for k in row.keys()}
336 elif self.upstream_client is not None:
337 d = await self.upstream_client.get_taskhash(method, taskhash, True)
338 await self.update_unified(d)
314 else: 339 else:
315 d = None 340 row = await self.db.get_equivalent(method, taskhash)
341
342 if row is not None:
343 d = {k: row[k] for k in row.keys()}
344 elif self.upstream_client is not None:
345 d = await self.upstream_client.get_taskhash(method, taskhash)
346 await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
316 347
317 self.write_message(d) 348 return d
318 349
350 @permissions(READ_PERM)
319 async def handle_get_outhash(self, request): 351 async def handle_get_outhash(self, request):
320 with closing(self.db.cursor()) as cursor: 352 method = request["method"]
321 cursor.execute(self.OUTHASH_QUERY, 353 outhash = request["outhash"]
322 {k: request[k] for k in ('method', 'outhash', 'taskhash')}) 354 taskhash = request["taskhash"]
355 with_unihash = request.get("with_unihash", True)
323 356
324 row = cursor.fetchone() 357 return await self.get_outhash(method, outhash, taskhash, with_unihash)
358
359 async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
360 d = None
361 if with_unihash:
362 row = await self.db.get_unihash_by_outhash(method, outhash)
363 else:
364 row = await self.db.get_outhash(method, outhash)
325 365
326 if row is not None: 366 if row is not None:
327 logger.debug('Found equivalent outhash %s -> %s', (row['outhash'], row['unihash']))
328 d = {k: row[k] for k in row.keys()} 367 d = {k: row[k] for k in row.keys()}
329 else: 368 elif self.upstream_client is not None:
330 d = None 369 d = await self.upstream_client.get_outhash(method, outhash, taskhash)
370 await self.update_unified(d)
331 371
332 self.write_message(d) 372 return d
333 373
334 async def handle_get_stream(self, request): 374 async def update_unified(self, data):
335 self.write_message('ok') 375 if data is None:
376 return
377
378 await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
379 await self.db.insert_outhash(data)
380
381 async def _stream_handler(self, handler):
382 await self.socket.send_message("ok")
336 383
337 while True: 384 while True:
338 upstream = None 385 upstream = None
339 386
340 l = await self.reader.readline() 387 l = await self.socket.recv()
341 if not l: 388 if not l:
342 return 389 break
343 390
344 try: 391 try:
345 # This inner loop is very sensitive and must be as fast as 392 # This inner loop is very sensitive and must be as fast as
346 # possible (which is why the request sample is handled manually 393 # possible (which is why the request sample is handled manually
347 # instead of using 'with', and also why logging statements are 394 # instead of using 'with', and also why logging statements are
348 # commented out. 395 # commented out.
349 self.request_sample = self.request_stats.start_sample() 396 self.request_sample = self.server.request_stats.start_sample()
350 request_measure = self.request_sample.measure() 397 request_measure = self.request_sample.measure()
351 request_measure.start() 398 request_measure.start()
352 399
353 l = l.decode('utf-8').rstrip() 400 if l == "END":
354 if l == 'END': 401 break
355 self.writer.write('ok\n'.encode('utf-8'))
356 return
357
358 (method, taskhash) = l.split()
359 #logger.debug('Looking up %s %s' % (method, taskhash))
360 row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
361 if row is not None:
362 msg = ('%s\n' % row['unihash']).encode('utf-8')
363 #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
364 elif self.upstream_client is not None:
365 upstream = await self.upstream_client.get_unihash(method, taskhash)
366 if upstream:
367 msg = ("%s\n" % upstream).encode("utf-8")
368 else:
369 msg = "\n".encode("utf-8")
370 else:
371 msg = '\n'.encode('utf-8')
372 402
373 self.writer.write(msg) 403 msg = await handler(l)
404 await self.socket.send(msg)
374 finally: 405 finally:
375 request_measure.end() 406 request_measure.end()
376 self.request_sample.end() 407 self.request_sample.end()
377 408
378 await self.writer.drain() 409 await self.socket.send("ok")
410 return self.NO_RESPONSE
379 411
380 # Post to the backfill queue after writing the result to minimize 412 @permissions(READ_PERM)
381 # the turn around time on a request 413 async def handle_get_stream(self, request):
382 if upstream is not None: 414 async def handler(l):
383 await self.backfill_queue.put((method, taskhash)) 415 (method, taskhash) = l.split()
416 # self.logger.debug('Looking up %s %s' % (method, taskhash))
417 row = await self.db.get_equivalent(method, taskhash)
384 418
385 async def handle_report(self, data): 419 if row is not None:
386 with closing(self.db.cursor()) as cursor: 420 # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
387 cursor.execute(self.OUTHASH_QUERY, 421 return row["unihash"]
388 {k: data[k] for k in ('method', 'outhash', 'taskhash')})
389
390 row = cursor.fetchone()
391
392 if row is None and self.upstream_client:
393 # Try upstream
394 row = await copy_outhash_from_upstream(self.upstream_client,
395 self.db,
396 data['method'],
397 data['outhash'],
398 data['taskhash'])
399
400 # If no matching outhash was found, or one *was* found but it
401 # wasn't an exact match on the taskhash, a new entry for this
402 # taskhash should be added
403 if row is None or row['taskhash'] != data['taskhash']:
404 # If a row matching the outhash was found, the unihash for
405 # the new taskhash should be the same as that one.
406 # Otherwise the caller provided unihash is used.
407 unihash = data['unihash']
408 if row is not None:
409 unihash = row['unihash']
410
411 insert_data = {
412 'method': data['method'],
413 'outhash': data['outhash'],
414 'taskhash': data['taskhash'],
415 'unihash': unihash,
416 'created': datetime.now()
417 }
418 422
419 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): 423 if self.upstream_client is not None:
420 if k in data: 424 upstream = await self.upstream_client.get_unihash(method, taskhash)
421 insert_data[k] = data[k] 425 if upstream:
426 await self.server.backfill_queue.put((method, taskhash))
427 return upstream
422 428
423 insert_task(cursor, insert_data) 429 return ""
424 self.db.commit()
425 430
426 logger.info('Adding taskhash %s with unihash %s', 431 return await self._stream_handler(handler)
427 data['taskhash'], unihash)
428 432
429 d = { 433 @permissions(READ_PERM)
430 'taskhash': data['taskhash'], 434 async def handle_exists_stream(self, request):
431 'method': data['method'], 435 async def handler(l):
432 'unihash': unihash 436 if await self.db.unihash_exists(l):
433 } 437 return "true"
434 else:
435 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
436 438
437 self.write_message(d) 439 if self.upstream_client is not None:
440 if await self.upstream_client.unihash_exists(l):
441 return "true"
438 442
439 async def handle_equivreport(self, data): 443 return "false"
440 with closing(self.db.cursor()) as cursor:
441 insert_data = {
442 'method': data['method'],
443 'outhash': "",
444 'taskhash': data['taskhash'],
445 'unihash': data['unihash'],
446 'created': datetime.now()
447 }
448 444
449 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): 445 return await self._stream_handler(handler)
450 if k in data:
451 insert_data[k] = data[k]
452 446
453 insert_task(cursor, insert_data, ignore=True) 447 async def report_readonly(self, data):
454 self.db.commit() 448 method = data["method"]
449 outhash = data["outhash"]
450 taskhash = data["taskhash"]
455 451
456 # Fetch the unihash that will be reported for the taskhash. If the 452 info = await self.get_outhash(method, outhash, taskhash)
457 # unihash matches, it means this row was inserted (or the mapping 453 if info:
458 # was already valid) 454 unihash = info["unihash"]
459 row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY) 455 else:
456 unihash = data["unihash"]
460 457
461 if row['unihash'] == data['unihash']: 458 return {
462 logger.info('Adding taskhash equivalence for %s with unihash %s', 459 "taskhash": taskhash,
463 data['taskhash'], row['unihash']) 460 "method": method,
461 "unihash": unihash,
462 }
464 463
465 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} 464 # Since this can be called either read only or to report, the check to
465 # report is made inside the function
466 @permissions(READ_PERM)
467 async def handle_report(self, data):
468 if self.server.read_only or not self.user_has_permissions(REPORT_PERM):
469 return await self.report_readonly(data)
470
471 outhash_data = {
472 "method": data["method"],
473 "outhash": data["outhash"],
474 "taskhash": data["taskhash"],
475 "created": datetime.now(),
476 }
466 477
467 self.write_message(d) 478 for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
479 if k in data:
480 outhash_data[k] = data[k]
468 481
482 if self.user:
483 outhash_data["owner"] = self.user.username
469 484
470 async def handle_get_stats(self, request): 485 # Insert the new entry, unless it already exists
471 d = { 486 if await self.db.insert_outhash(outhash_data):
472 'requests': self.request_stats.todict(), 487 # If this row is new, check if it is equivalent to another
488 # output hash
489 row = await self.db.get_equivalent_for_outhash(
490 data["method"], data["outhash"], data["taskhash"]
491 )
492
493 if row is not None:
494 # A matching output hash was found. Set our taskhash to the
495 # same unihash since they are equivalent
496 unihash = row["unihash"]
497 else:
498 # No matching output hash was found. This is probably the
499 # first outhash to be added.
500 unihash = data["unihash"]
501
502 # Query upstream to see if it has a unihash we can use
503 if self.upstream_client is not None:
504 upstream_data = await self.upstream_client.get_outhash(
505 data["method"], data["outhash"], data["taskhash"]
506 )
507 if upstream_data is not None:
508 unihash = upstream_data["unihash"]
509
510 await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
511
512 unihash_data = await self.get_unihash(data["method"], data["taskhash"])
513 if unihash_data is not None:
514 unihash = unihash_data["unihash"]
515 else:
516 unihash = data["unihash"]
517
518 return {
519 "taskhash": data["taskhash"],
520 "method": data["method"],
521 "unihash": unihash,
473 } 522 }
474 523
475 self.write_message(d) 524 @permissions(READ_PERM, REPORT_PERM)
525 async def handle_equivreport(self, data):
526 await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
527
528 # Fetch the unihash that will be reported for the taskhash. If the
529 # unihash matches, it means this row was inserted (or the mapping
530 # was already valid)
531 row = await self.db.get_equivalent(data["method"], data["taskhash"])
532
533 if row["unihash"] == data["unihash"]:
534 self.logger.info(
535 "Adding taskhash equivalence for %s with unihash %s",
536 data["taskhash"],
537 row["unihash"],
538 )
539
540 return {k: row[k] for k in ("taskhash", "method", "unihash")}
476 541
542 @permissions(READ_PERM)
543 async def handle_get_stats(self, request):
544 return {
545 "requests": self.server.request_stats.todict(),
546 }
547
548 @permissions(DB_ADMIN_PERM)
477 async def handle_reset_stats(self, request): 549 async def handle_reset_stats(self, request):
478 d = { 550 d = {
479 'requests': self.request_stats.todict(), 551 "requests": self.server.request_stats.todict(),
480 } 552 }
481 553
482 self.request_stats.reset() 554 self.server.request_stats.reset()
483 self.write_message(d) 555 return d
484 556
557 @permissions(READ_PERM)
485 async def handle_backfill_wait(self, request): 558 async def handle_backfill_wait(self, request):
486 d = { 559 d = {
487 'tasks': self.backfill_queue.qsize(), 560 "tasks": self.server.backfill_queue.qsize(),
488 } 561 }
489 await self.backfill_queue.join() 562 await self.server.backfill_queue.join()
490 self.write_message(d) 563 return d
564
565 @permissions(DB_ADMIN_PERM)
566 async def handle_remove(self, request):
567 condition = request["where"]
568 if not isinstance(condition, dict):
569 raise TypeError("Bad condition type %s" % type(condition))
570
571 return {"count": await self.db.remove(condition)}
572
573 @permissions(DB_ADMIN_PERM)
574 async def handle_gc_mark(self, request):
575 condition = request["where"]
576 mark = request["mark"]
577
578 if not isinstance(condition, dict):
579 raise TypeError("Bad condition type %s" % type(condition))
580
581 if not isinstance(mark, str):
582 raise TypeError("Bad mark type %s" % type(mark))
583
584 return {"count": await self.db.gc_mark(mark, condition)}
585
586 @permissions(DB_ADMIN_PERM)
587 async def handle_gc_sweep(self, request):
588 mark = request["mark"]
589
590 if not isinstance(mark, str):
591 raise TypeError("Bad mark type %s" % type(mark))
592
593 current_mark = await self.db.get_current_gc_mark()
594
595 if not current_mark or mark != current_mark:
596 raise bb.asyncrpc.InvokeError(
597 f"'{mark}' is not the current mark. Refusing to sweep"
598 )
599
600 count = await self.db.gc_sweep()
601
602 return {"count": count}
603
604 @permissions(DB_ADMIN_PERM)
605 async def handle_gc_status(self, request):
606 (keep_rows, remove_rows, current_mark) = await self.db.gc_status()
607 return {
608 "keep": keep_rows,
609 "remove": remove_rows,
610 "mark": current_mark,
611 }
612
613 @permissions(DB_ADMIN_PERM)
614 async def handle_clean_unused(self, request):
615 max_age = request["max_age_seconds"]
616 oldest = datetime.now() - timedelta(seconds=-max_age)
617 return {"count": await self.db.clean_unused(oldest)}
618
619 @permissions(DB_ADMIN_PERM)
620 async def handle_get_db_usage(self, request):
621 return {"usage": await self.db.get_usage()}
622
623 @permissions(DB_ADMIN_PERM)
624 async def handle_get_db_query_columns(self, request):
625 return {"columns": await self.db.get_query_columns()}
626
627 # The authentication API is always allowed
628 async def handle_auth(self, request):
629 username = str(request["username"])
630 token = str(request["token"])
631
632 async def fail_auth():
633 nonlocal username
634 # Rate limit bad login attempts
635 await asyncio.sleep(1)
636 raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")
637
638 user, db_token = await self.db.lookup_user_token(username)
639
640 if not user or not db_token:
641 await fail_auth()
491 642
492 def query_equivalent(self, method, taskhash, query):
493 # This is part of the inner loop and must be as fast as possible
494 try: 643 try:
495 cursor = self.db.cursor() 644 algo, salt, _ = db_token.split(":")
496 cursor.execute(query, {'method': method, 'taskhash': taskhash}) 645 except ValueError:
497 return cursor.fetchone() 646 await fail_auth()
498 except:
499 cursor.close()
500 647
648 if hash_token(algo, salt, token) != db_token:
649 await fail_auth()
501 650
502class Server(object): 651 self.user = user
503 def __init__(self, db, loop=None, upstream=None, read_only=False):
504 if upstream and read_only:
505 raise ServerError("Read-only hashserv cannot pull from an upstream server")
506 652
507 self.request_stats = Stats() 653 self.logger.info("Authenticated as %s", username)
508 self.db = db
509 654
510 if loop is None: 655 return {
511 self.loop = asyncio.new_event_loop() 656 "result": True,
512 self.close_loop = True 657 "username": self.user.username,
513 else: 658 "permissions": sorted(list(self.user.permissions)),
514 self.loop = loop 659 }
515 self.close_loop = False
516 660
517 self.upstream = upstream 661 @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
518 self.read_only = read_only 662 async def handle_refresh_token(self, request):
663 username = str(request["username"])
519 664
520 self._cleanup_socket = None 665 token = await new_token()
521 666
522 def start_tcp_server(self, host, port): 667 updated = await self.db.set_user_token(
523 self.server = self.loop.run_until_complete( 668 username,
524 asyncio.start_server(self.handle_client, host, port, loop=self.loop) 669 hash_token(TOKEN_ALGORITHM, new_salt(), token),
525 ) 670 )
671 if not updated:
672 self.raise_no_user_error(username)
526 673
527 for s in self.server.sockets: 674 return {"username": username, "token": token}
528 logger.info('Listening on %r' % (s.getsockname(),))
529 # Newer python does this automatically. Do it manually here for
530 # maximum compatibility
531 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
532 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
533
534 name = self.server.sockets[0].getsockname()
535 if self.server.sockets[0].family == socket.AF_INET6:
536 self.address = "[%s]:%d" % (name[0], name[1])
537 else:
538 self.address = "%s:%d" % (name[0], name[1])
539 675
540 def start_unix_server(self, path): 676 def get_perm_arg(self, arg):
541 def cleanup(): 677 if not isinstance(arg, list):
542 os.unlink(path) 678 raise bb.asyncrpc.InvokeError("Unexpected type for permissions")
543 679
544 cwd = os.getcwd() 680 arg = set(arg)
545 try: 681 try:
546 # Work around path length limits in AF_UNIX 682 arg.remove(NONE_PERM)
547 os.chdir(os.path.dirname(path)) 683 except KeyError:
548 self.server = self.loop.run_until_complete( 684 pass
549 asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop) 685
686 unknown_perms = arg - ALL_PERMISSIONS
687 if unknown_perms:
688 raise bb.asyncrpc.InvokeError(
689 "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms)))
550 ) 690 )
551 finally:
552 os.chdir(cwd)
553 691
554 logger.info('Listening on %r' % path) 692 return sorted(list(arg))
555 693
556 self._cleanup_socket = cleanup 694 def return_perms(self, permissions):
557 self.address = "unix://%s" % os.path.abspath(path) 695 if ALL_PERM in permissions:
696 return sorted(list(ALL_PERMISSIONS))
697 return sorted(list(permissions))
558 698
559 async def handle_client(self, reader, writer): 699 @permissions(USER_ADMIN_PERM, allow_anon=False)
560 # writer.transport.set_write_buffer_limits(0) 700 async def handle_set_perms(self, request):
561 try: 701 username = str(request["username"])
562 client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only) 702 permissions = self.get_perm_arg(request["permissions"])
563 await client.process_requests()
564 except Exception as e:
565 import traceback
566 logger.error('Error from client: %s' % str(e), exc_info=True)
567 traceback.print_exc()
568 writer.close()
569 logger.info('Client disconnected')
570
571 @contextmanager
572 def _backfill_worker(self):
573 async def backfill_worker_task():
574 client = await create_async_client(self.upstream)
575 try:
576 while True:
577 item = await self.backfill_queue.get()
578 if item is None:
579 self.backfill_queue.task_done()
580 break
581 method, taskhash = item
582 await copy_from_upstream(client, self.db, method, taskhash)
583 self.backfill_queue.task_done()
584 finally:
585 await client.close()
586 703
587 async def join_worker(worker): 704 if not await self.db.set_user_perms(username, permissions):
588 await self.backfill_queue.put(None) 705 self.raise_no_user_error(username)
589 await worker
590 706
591 if self.upstream is not None: 707 return {
592 worker = asyncio.ensure_future(backfill_worker_task()) 708 "username": username,
593 try: 709 "permissions": self.return_perms(permissions),
594 yield 710 }
595 finally:
596 self.loop.run_until_complete(join_worker(worker))
597 else:
598 yield
599 711
600 def serve_forever(self): 712 @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
601 def signal_handler(): 713 async def handle_get_user(self, request):
602 self.loop.stop() 714 username = str(request["username"])
603 715
604 asyncio.set_event_loop(self.loop) 716 user = await self.db.lookup_user(username)
605 try: 717 if user is None:
606 self.backfill_queue = asyncio.Queue() 718 return None
719
720 return {
721 "username": user.username,
722 "permissions": self.return_perms(user.permissions),
723 }
724
725 @permissions(USER_ADMIN_PERM, allow_anon=False)
726 async def handle_get_all_users(self, request):
727 users = await self.db.get_all_users()
728 return {
729 "users": [
730 {
731 "username": u.username,
732 "permissions": self.return_perms(u.permissions),
733 }
734 for u in users
735 ]
736 }
737
738 @permissions(USER_ADMIN_PERM, allow_anon=False)
739 async def handle_new_user(self, request):
740 username = str(request["username"])
741 permissions = self.get_perm_arg(request["permissions"])
742
743 token = await new_token()
744
745 inserted = await self.db.new_user(
746 username,
747 permissions,
748 hash_token(TOKEN_ALGORITHM, new_salt(), token),
749 )
750 if not inserted:
751 raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")
752
753 return {
754 "username": username,
755 "permissions": self.return_perms(permissions),
756 "token": token,
757 }
758
759 @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
760 async def handle_delete_user(self, request):
761 username = str(request["username"])
762
763 if not await self.db.delete_user(username):
764 self.raise_no_user_error(username)
765
766 return {"username": username}
607 767
608 self.loop.add_signal_handler(signal.SIGTERM, signal_handler) 768 @permissions(USER_ADMIN_PERM, allow_anon=False)
769 async def handle_become_user(self, request):
770 username = str(request["username"])
609 771
610 with self._backfill_worker(): 772 user = await self.db.lookup_user(username)
611 try: 773 if user is None:
612 self.loop.run_forever() 774 raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")
613 except KeyboardInterrupt:
614 pass
615 775
616 self.server.close() 776 self.user = user
777
778 self.logger.info("Became user %s", username)
779
780 return {
781 "username": self.user.username,
782 "permissions": self.return_perms(self.user.permissions),
783 }
784
785
786class Server(bb.asyncrpc.AsyncServer):
787 def __init__(
788 self,
789 db_engine,
790 upstream=None,
791 read_only=False,
792 anon_perms=DEFAULT_ANON_PERMS,
793 admin_username=None,
794 admin_password=None,
795 ):
796 if upstream and read_only:
797 raise bb.asyncrpc.ServerError(
798 "Read-only hashserv cannot pull from an upstream server"
799 )
800
801 disallowed_perms = set(anon_perms) - set(
802 [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
803 )
804
805 if disallowed_perms:
806 raise bb.asyncrpc.ServerError(
807 f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
808 )
617 809
618 self.loop.run_until_complete(self.server.wait_closed()) 810 super().__init__(logger)
619 logger.info('Server shutting down')
620 finally:
621 if self.close_loop:
622 if sys.version_info >= (3, 6):
623 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
624 self.loop.close()
625 811
626 if self._cleanup_socket is not None: 812 self.request_stats = Stats()
627 self._cleanup_socket() 813 self.db_engine = db_engine
814 self.upstream = upstream
815 self.read_only = read_only
816 self.backfill_queue = None
817 self.anon_perms = set(anon_perms)
818 self.admin_username = admin_username
819 self.admin_password = admin_password
820
821 self.logger.info(
822 "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
823 )
824
825 def accept_client(self, socket):
826 return ServerClient(socket, self)
827
828 async def create_admin_user(self):
829 admin_permissions = (ALL_PERM,)
830 async with self.db_engine.connect(self.logger) as db:
831 added = await db.new_user(
832 self.admin_username,
833 admin_permissions,
834 hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
835 )
836 if added:
837 self.logger.info("Created admin user '%s'", self.admin_username)
838 else:
839 await db.set_user_perms(
840 self.admin_username,
841 admin_permissions,
842 )
843 await db.set_user_token(
844 self.admin_username,
845 hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
846 )
847 self.logger.info("Admin user '%s' updated", self.admin_username)
848
849 async def backfill_worker_task(self):
850 async with await create_async_client(
851 self.upstream
852 ) as client, self.db_engine.connect(self.logger) as db:
853 while True:
854 item = await self.backfill_queue.get()
855 if item is None:
856 self.backfill_queue.task_done()
857 break
858
859 method, taskhash = item
860 d = await client.get_taskhash(method, taskhash)
861 if d is not None:
862 await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
863 self.backfill_queue.task_done()
864
865 def start(self):
866 tasks = super().start()
867 if self.upstream:
868 self.backfill_queue = asyncio.Queue()
869 tasks += [self.backfill_worker_task()]
870
871 self.loop.run_until_complete(self.db_engine.create())
872
873 if self.admin_username:
874 self.loop.run_until_complete(self.create_admin_user())
875
876 return tasks
877
878 async def stop(self):
879 if self.backfill_queue is not None:
880 await self.backfill_queue.put(None)
881 await super().stop()