diff options
Diffstat (limited to 'bitbake/lib/hashserv')
-rw-r--r-- | bitbake/lib/hashserv/__init__.py | 39 | ||||
-rw-r--r-- | bitbake/lib/hashserv/client.py | 19 | ||||
-rw-r--r-- | bitbake/lib/hashserv/server.py | 149 | ||||
-rw-r--r-- | bitbake/lib/hashserv/tests.py | 147 |
4 files changed, 268 insertions, 86 deletions
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py index 622ca17a91..55f48410d3 100644 --- a/bitbake/lib/hashserv/__init__.py +++ b/bitbake/lib/hashserv/__init__.py | |||
@@ -22,6 +22,24 @@ ADDR_TYPE_TCP = 1 | |||
22 | # is necessary | 22 | # is necessary |
23 | DEFAULT_MAX_CHUNK = 32 * 1024 | 23 | DEFAULT_MAX_CHUNK = 32 * 1024 |
24 | 24 | ||
25 | TABLE_DEFINITION = ( | ||
26 | ("method", "TEXT NOT NULL"), | ||
27 | ("outhash", "TEXT NOT NULL"), | ||
28 | ("taskhash", "TEXT NOT NULL"), | ||
29 | ("unihash", "TEXT NOT NULL"), | ||
30 | ("created", "DATETIME"), | ||
31 | |||
32 | # Optional fields | ||
33 | ("owner", "TEXT"), | ||
34 | ("PN", "TEXT"), | ||
35 | ("PV", "TEXT"), | ||
36 | ("PR", "TEXT"), | ||
37 | ("task", "TEXT"), | ||
38 | ("outhash_siginfo", "TEXT"), | ||
39 | ) | ||
40 | |||
41 | TABLE_COLUMNS = tuple(name for name, _ in TABLE_DEFINITION) | ||
42 | |||
25 | def setup_database(database, sync=True): | 43 | def setup_database(database, sync=True): |
26 | db = sqlite3.connect(database) | 44 | db = sqlite3.connect(database) |
27 | db.row_factory = sqlite3.Row | 45 | db.row_factory = sqlite3.Row |
@@ -30,23 +48,10 @@ def setup_database(database, sync=True): | |||
30 | cursor.execute(''' | 48 | cursor.execute(''' |
31 | CREATE TABLE IF NOT EXISTS tasks_v2 ( | 49 | CREATE TABLE IF NOT EXISTS tasks_v2 ( |
32 | id INTEGER PRIMARY KEY AUTOINCREMENT, | 50 | id INTEGER PRIMARY KEY AUTOINCREMENT, |
33 | method TEXT NOT NULL, | 51 | %s |
34 | outhash TEXT NOT NULL, | ||
35 | taskhash TEXT NOT NULL, | ||
36 | unihash TEXT NOT NULL, | ||
37 | created DATETIME, | ||
38 | |||
39 | -- Optional fields | ||
40 | owner TEXT, | ||
41 | PN TEXT, | ||
42 | PV TEXT, | ||
43 | PR TEXT, | ||
44 | task TEXT, | ||
45 | outhash_siginfo TEXT, | ||
46 | |||
47 | UNIQUE(method, outhash, taskhash) | 52 | UNIQUE(method, outhash, taskhash) |
48 | ) | 53 | ) |
49 | ''') | 54 | ''' % " ".join("%s %s," % (name, typ) for name, typ in TABLE_DEFINITION)) |
50 | cursor.execute('PRAGMA journal_mode = WAL') | 55 | cursor.execute('PRAGMA journal_mode = WAL') |
51 | cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF')) | 56 | cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF')) |
52 | 57 | ||
@@ -89,10 +94,10 @@ def chunkify(msg, max_chunk): | |||
89 | yield "\n" | 94 | yield "\n" |
90 | 95 | ||
91 | 96 | ||
92 | def create_server(addr, dbname, *, sync=True): | 97 | def create_server(addr, dbname, *, sync=True, upstream=None): |
93 | from . import server | 98 | from . import server |
94 | db = setup_database(dbname, sync=sync) | 99 | db = setup_database(dbname, sync=sync) |
95 | s = server.Server(db) | 100 | s = server.Server(db, upstream=upstream) |
96 | 101 | ||
97 | (typ, a) = parse_address(addr) | 102 | (typ, a) = parse_address(addr) |
98 | if typ == ADDR_TYPE_UNIX: | 103 | if typ == ADDR_TYPE_UNIX: |
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py index d0b3cf3863..ae5875d1b3 100644 --- a/bitbake/lib/hashserv/client.py +++ b/bitbake/lib/hashserv/client.py | |||
@@ -178,18 +178,16 @@ class AsyncClient(object): | |||
178 | await self._set_mode(self.MODE_NORMAL) | 178 | await self._set_mode(self.MODE_NORMAL) |
179 | return await self.send_message({"reset-stats": None}) | 179 | return await self.send_message({"reset-stats": None}) |
180 | 180 | ||
181 | async def backfill_wait(self): | ||
182 | await self._set_mode(self.MODE_NORMAL) | ||
183 | return (await self.send_message({"backfill-wait": None}))["tasks"] | ||
184 | |||
181 | 185 | ||
182 | class Client(object): | 186 | class Client(object): |
183 | def __init__(self): | 187 | def __init__(self): |
184 | self.client = AsyncClient() | 188 | self.client = AsyncClient() |
185 | self.loop = asyncio.new_event_loop() | 189 | self.loop = asyncio.new_event_loop() |
186 | 190 | ||
187 | def get_wrapper(self, downcall): | ||
188 | def wrapper(*args, **kwargs): | ||
189 | return self.loop.run_until_complete(downcall(*args, **kwargs)) | ||
190 | |||
191 | return wrapper | ||
192 | |||
193 | for call in ( | 191 | for call in ( |
194 | "connect_tcp", | 192 | "connect_tcp", |
195 | "connect_unix", | 193 | "connect_unix", |
@@ -200,9 +198,16 @@ class Client(object): | |||
200 | "get_taskhash", | 198 | "get_taskhash", |
201 | "get_stats", | 199 | "get_stats", |
202 | "reset_stats", | 200 | "reset_stats", |
201 | "backfill_wait", | ||
203 | ): | 202 | ): |
204 | downcall = getattr(self.client, call) | 203 | downcall = getattr(self.client, call) |
205 | setattr(self, call, get_wrapper(self, downcall)) | 204 | setattr(self, call, self._get_downcall_wrapper(downcall)) |
205 | |||
206 | def _get_downcall_wrapper(self, downcall): | ||
207 | def wrapper(*args, **kwargs): | ||
208 | return self.loop.run_until_complete(downcall(*args, **kwargs)) | ||
209 | |||
210 | return wrapper | ||
206 | 211 | ||
207 | @property | 212 | @property |
208 | def max_chunk(self): | 213 | def max_chunk(self): |
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py index 81050715ea..3ff4c51ccb 100644 --- a/bitbake/lib/hashserv/server.py +++ b/bitbake/lib/hashserv/server.py | |||
@@ -3,7 +3,7 @@ | |||
3 | # SPDX-License-Identifier: GPL-2.0-only | 3 | # SPDX-License-Identifier: GPL-2.0-only |
4 | # | 4 | # |
5 | 5 | ||
6 | from contextlib import closing | 6 | from contextlib import closing, contextmanager |
7 | from datetime import datetime | 7 | from datetime import datetime |
8 | import asyncio | 8 | import asyncio |
9 | import json | 9 | import json |
@@ -12,8 +12,9 @@ import math | |||
12 | import os | 12 | import os |
13 | import signal | 13 | import signal |
14 | import socket | 14 | import socket |
15 | import sys | ||
15 | import time | 16 | import time |
16 | from . import chunkify, DEFAULT_MAX_CHUNK | 17 | from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS |
17 | 18 | ||
18 | logger = logging.getLogger('hashserv.server') | 19 | logger = logging.getLogger('hashserv.server') |
19 | 20 | ||
@@ -111,16 +112,40 @@ class Stats(object): | |||
111 | class ClientError(Exception): | 112 | class ClientError(Exception): |
112 | pass | 113 | pass |
113 | 114 | ||
115 | def insert_task(cursor, data, ignore=False): | ||
116 | keys = sorted(data.keys()) | ||
117 | query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % ( | ||
118 | " OR IGNORE" if ignore else "", | ||
119 | ', '.join(keys), | ||
120 | ', '.join(':' + k for k in keys)) | ||
121 | cursor.execute(query, data) | ||
122 | |||
123 | async def copy_from_upstream(client, db, method, taskhash): | ||
124 | d = await client.get_taskhash(method, taskhash, True) | ||
125 | if d is not None: | ||
126 | # Filter out unknown columns | ||
127 | d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} | ||
128 | keys = sorted(d.keys()) | ||
129 | |||
130 | |||
131 | with closing(db.cursor()) as cursor: | ||
132 | insert_task(cursor, d) | ||
133 | db.commit() | ||
134 | |||
135 | return d | ||
136 | |||
114 | class ServerClient(object): | 137 | class ServerClient(object): |
115 | FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | 138 | FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' |
116 | ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | 139 | ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' |
117 | 140 | ||
118 | def __init__(self, reader, writer, db, request_stats): | 141 | def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream): |
119 | self.reader = reader | 142 | self.reader = reader |
120 | self.writer = writer | 143 | self.writer = writer |
121 | self.db = db | 144 | self.db = db |
122 | self.request_stats = request_stats | 145 | self.request_stats = request_stats |
123 | self.max_chunk = DEFAULT_MAX_CHUNK | 146 | self.max_chunk = DEFAULT_MAX_CHUNK |
147 | self.backfill_queue = backfill_queue | ||
148 | self.upstream = upstream | ||
124 | 149 | ||
125 | self.handlers = { | 150 | self.handlers = { |
126 | 'get': self.handle_get, | 151 | 'get': self.handle_get, |
@@ -130,10 +155,18 @@ class ServerClient(object): | |||
130 | 'get-stats': self.handle_get_stats, | 155 | 'get-stats': self.handle_get_stats, |
131 | 'reset-stats': self.handle_reset_stats, | 156 | 'reset-stats': self.handle_reset_stats, |
132 | 'chunk-stream': self.handle_chunk, | 157 | 'chunk-stream': self.handle_chunk, |
158 | 'backfill-wait': self.handle_backfill_wait, | ||
133 | } | 159 | } |
134 | 160 | ||
135 | async def process_requests(self): | 161 | async def process_requests(self): |
162 | if self.upstream is not None: | ||
163 | self.upstream_client = await create_async_client(self.upstream) | ||
164 | else: | ||
165 | self.upstream_client = None | ||
166 | |||
136 | try: | 167 | try: |
168 | |||
169 | |||
137 | self.addr = self.writer.get_extra_info('peername') | 170 | self.addr = self.writer.get_extra_info('peername') |
138 | logger.debug('Client %r connected' % (self.addr,)) | 171 | logger.debug('Client %r connected' % (self.addr,)) |
139 | 172 | ||
@@ -171,6 +204,9 @@ class ServerClient(object): | |||
171 | except ClientError as e: | 204 | except ClientError as e: |
172 | logger.error(str(e)) | 205 | logger.error(str(e)) |
173 | finally: | 206 | finally: |
207 | if self.upstream_client is not None: | ||
208 | await self.upstream_client.close() | ||
209 | |||
174 | self.writer.close() | 210 | self.writer.close() |
175 | 211 | ||
176 | async def dispatch_message(self, msg): | 212 | async def dispatch_message(self, msg): |
@@ -239,15 +275,19 @@ class ServerClient(object): | |||
239 | if row is not None: | 275 | if row is not None: |
240 | logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | 276 | logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) |
241 | d = {k: row[k] for k in row.keys()} | 277 | d = {k: row[k] for k in row.keys()} |
242 | 278 | elif self.upstream_client is not None: | |
243 | self.write_message(d) | 279 | d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash) |
244 | else: | 280 | else: |
245 | self.write_message(None) | 281 | d = None |
282 | |||
283 | self.write_message(d) | ||
246 | 284 | ||
247 | async def handle_get_stream(self, request): | 285 | async def handle_get_stream(self, request): |
248 | self.write_message('ok') | 286 | self.write_message('ok') |
249 | 287 | ||
250 | while True: | 288 | while True: |
289 | upstream = None | ||
290 | |||
251 | l = await self.reader.readline() | 291 | l = await self.reader.readline() |
252 | if not l: | 292 | if not l: |
253 | return | 293 | return |
@@ -272,6 +312,12 @@ class ServerClient(object): | |||
272 | if row is not None: | 312 | if row is not None: |
273 | msg = ('%s\n' % row['unihash']).encode('utf-8') | 313 | msg = ('%s\n' % row['unihash']).encode('utf-8') |
274 | #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | 314 | #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) |
315 | elif self.upstream_client is not None: | ||
316 | upstream = await self.upstream_client.get_unihash(method, taskhash) | ||
317 | if upstream: | ||
318 | msg = ("%s\n" % upstream).encode("utf-8") | ||
319 | else: | ||
320 | msg = "\n".encode("utf-8") | ||
275 | else: | 321 | else: |
276 | msg = '\n'.encode('utf-8') | 322 | msg = '\n'.encode('utf-8') |
277 | 323 | ||
@@ -282,6 +328,11 @@ class ServerClient(object): | |||
282 | 328 | ||
283 | await self.writer.drain() | 329 | await self.writer.drain() |
284 | 330 | ||
331 | # Post to the backfill queue after writing the result to minimize | ||
332 | # the turn around time on a request | ||
333 | if upstream is not None: | ||
334 | await self.backfill_queue.put((method, taskhash)) | ||
335 | |||
285 | async def handle_report(self, data): | 336 | async def handle_report(self, data): |
286 | with closing(self.db.cursor()) as cursor: | 337 | with closing(self.db.cursor()) as cursor: |
287 | cursor.execute(''' | 338 | cursor.execute(''' |
@@ -324,11 +375,7 @@ class ServerClient(object): | |||
324 | if k in data: | 375 | if k in data: |
325 | insert_data[k] = data[k] | 376 | insert_data[k] = data[k] |
326 | 377 | ||
327 | cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % ( | 378 | insert_task(cursor, insert_data) |
328 | ', '.join(sorted(insert_data.keys())), | ||
329 | ', '.join(':' + k for k in sorted(insert_data.keys()))), | ||
330 | insert_data) | ||
331 | |||
332 | self.db.commit() | 379 | self.db.commit() |
333 | 380 | ||
334 | logger.info('Adding taskhash %s with unihash %s', | 381 | logger.info('Adding taskhash %s with unihash %s', |
@@ -358,11 +405,7 @@ class ServerClient(object): | |||
358 | if k in data: | 405 | if k in data: |
359 | insert_data[k] = data[k] | 406 | insert_data[k] = data[k] |
360 | 407 | ||
361 | cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % ( | 408 | insert_task(cursor, insert_data, ignore=True) |
362 | ', '.join(sorted(insert_data.keys())), | ||
363 | ', '.join(':' + k for k in sorted(insert_data.keys()))), | ||
364 | insert_data) | ||
365 | |||
366 | self.db.commit() | 409 | self.db.commit() |
367 | 410 | ||
368 | # Fetch the unihash that will be reported for the taskhash. If the | 411 | # Fetch the unihash that will be reported for the taskhash. If the |
@@ -394,6 +437,13 @@ class ServerClient(object): | |||
394 | self.request_stats.reset() | 437 | self.request_stats.reset() |
395 | self.write_message(d) | 438 | self.write_message(d) |
396 | 439 | ||
440 | async def handle_backfill_wait(self, request): | ||
441 | d = { | ||
442 | 'tasks': self.backfill_queue.qsize(), | ||
443 | } | ||
444 | await self.backfill_queue.join() | ||
445 | self.write_message(d) | ||
446 | |||
397 | def query_equivalent(self, method, taskhash, query): | 447 | def query_equivalent(self, method, taskhash, query): |
398 | # This is part of the inner loop and must be as fast as possible | 448 | # This is part of the inner loop and must be as fast as possible |
399 | try: | 449 | try: |
@@ -405,7 +455,7 @@ class ServerClient(object): | |||
405 | 455 | ||
406 | 456 | ||
407 | class Server(object): | 457 | class Server(object): |
408 | def __init__(self, db, loop=None): | 458 | def __init__(self, db, loop=None, upstream=None): |
409 | self.request_stats = Stats() | 459 | self.request_stats = Stats() |
410 | self.db = db | 460 | self.db = db |
411 | 461 | ||
@@ -416,6 +466,8 @@ class Server(object): | |||
416 | self.loop = loop | 466 | self.loop = loop |
417 | self.close_loop = False | 467 | self.close_loop = False |
418 | 468 | ||
469 | self.upstream = upstream | ||
470 | |||
419 | self._cleanup_socket = None | 471 | self._cleanup_socket = None |
420 | 472 | ||
421 | def start_tcp_server(self, host, port): | 473 | def start_tcp_server(self, host, port): |
@@ -458,7 +510,7 @@ class Server(object): | |||
458 | async def handle_client(self, reader, writer): | 510 | async def handle_client(self, reader, writer): |
459 | # writer.transport.set_write_buffer_limits(0) | 511 | # writer.transport.set_write_buffer_limits(0) |
460 | try: | 512 | try: |
461 | client = ServerClient(reader, writer, self.db, self.request_stats) | 513 | client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream) |
462 | await client.process_requests() | 514 | await client.process_requests() |
463 | except Exception as e: | 515 | except Exception as e: |
464 | import traceback | 516 | import traceback |
@@ -467,23 +519,60 @@ class Server(object): | |||
467 | writer.close() | 519 | writer.close() |
468 | logger.info('Client disconnected') | 520 | logger.info('Client disconnected') |
469 | 521 | ||
522 | @contextmanager | ||
523 | def _backfill_worker(self): | ||
524 | async def backfill_worker_task(): | ||
525 | client = await create_async_client(self.upstream) | ||
526 | try: | ||
527 | while True: | ||
528 | item = await self.backfill_queue.get() | ||
529 | if item is None: | ||
530 | self.backfill_queue.task_done() | ||
531 | break | ||
532 | method, taskhash = item | ||
533 | await copy_from_upstream(client, self.db, method, taskhash) | ||
534 | self.backfill_queue.task_done() | ||
535 | finally: | ||
536 | await client.close() | ||
537 | |||
538 | async def join_worker(worker): | ||
539 | await self.backfill_queue.put(None) | ||
540 | await worker | ||
541 | |||
542 | if self.upstream is not None: | ||
543 | worker = asyncio.ensure_future(backfill_worker_task()) | ||
544 | try: | ||
545 | yield | ||
546 | finally: | ||
547 | self.loop.run_until_complete(join_worker(worker)) | ||
548 | else: | ||
549 | yield | ||
550 | |||
470 | def serve_forever(self): | 551 | def serve_forever(self): |
471 | def signal_handler(): | 552 | def signal_handler(): |
472 | self.loop.stop() | 553 | self.loop.stop() |
473 | 554 | ||
474 | self.loop.add_signal_handler(signal.SIGTERM, signal_handler) | 555 | asyncio.set_event_loop(self.loop) |
475 | |||
476 | try: | 556 | try: |
477 | self.loop.run_forever() | 557 | self.backfill_queue = asyncio.Queue() |
478 | except KeyboardInterrupt: | 558 | |
479 | pass | 559 | self.loop.add_signal_handler(signal.SIGTERM, signal_handler) |
480 | 560 | ||
481 | self.server.close() | 561 | with self._backfill_worker(): |
482 | self.loop.run_until_complete(self.server.wait_closed()) | 562 | try: |
483 | logger.info('Server shutting down') | 563 | self.loop.run_forever() |
564 | except KeyboardInterrupt: | ||
565 | pass | ||
484 | 566 | ||
485 | if self.close_loop: | 567 | self.server.close() |
486 | self.loop.close() | 568 | |
569 | self.loop.run_until_complete(self.server.wait_closed()) | ||
570 | logger.info('Server shutting down') | ||
571 | finally: | ||
572 | if self.close_loop: | ||
573 | if sys.version_info >= (3, 6): | ||
574 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | ||
575 | self.loop.close() | ||
487 | 576 | ||
488 | if self._cleanup_socket is not None: | 577 | if self._cleanup_socket is not None: |
489 | self._cleanup_socket() | 578 | self._cleanup_socket() |
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py index 4566f24738..3dd9a31bee 100644 --- a/bitbake/lib/hashserv/tests.py +++ b/bitbake/lib/hashserv/tests.py | |||
@@ -16,35 +16,54 @@ import threading | |||
16 | import unittest | 16 | import unittest |
17 | import socket | 17 | import socket |
18 | 18 | ||
19 | def _run_server(server, idx): | ||
20 | # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', | ||
21 | # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') | ||
22 | sys.stdout = open('bbhashserv-%d.log' % idx, 'w') | ||
23 | sys.stderr = sys.stdout | ||
24 | server.serve_forever() | ||
19 | 25 | ||
20 | class TestHashEquivalenceServer(object): | 26 | class TestHashEquivalenceServer(object): |
21 | METHOD = 'TestMethod' | 27 | METHOD = 'TestMethod' |
22 | 28 | ||
23 | def _run_server(self): | 29 | server_index = 0 |
24 | # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', | 30 | |
25 | # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') | 31 | def start_server(self, dbpath=None, upstream=None): |
26 | self.server.serve_forever() | 32 | self.server_index += 1 |
33 | if dbpath is None: | ||
34 | dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) | ||
35 | |||
36 | def cleanup_thread(thread): | ||
37 | thread.terminate() | ||
38 | thread.join() | ||
39 | |||
40 | server = create_server(self.get_server_addr(self.server_index), dbpath, upstream=upstream) | ||
41 | server.dbpath = dbpath | ||
42 | |||
43 | server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) | ||
44 | server.thread.start() | ||
45 | self.addCleanup(cleanup_thread, server.thread) | ||
46 | |||
47 | def cleanup_client(client): | ||
48 | client.close() | ||
49 | |||
50 | client = create_client(server.address) | ||
51 | self.addCleanup(cleanup_client, client) | ||
52 | |||
53 | return (client, server) | ||
27 | 54 | ||
28 | def setUp(self): | 55 | def setUp(self): |
29 | if sys.version_info < (3, 5, 0): | 56 | if sys.version_info < (3, 5, 0): |
30 | self.skipTest('Python 3.5 or later required') | 57 | self.skipTest('Python 3.5 or later required') |
31 | 58 | ||
32 | self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv') | 59 | self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv') |
33 | self.dbfile = os.path.join(self.temp_dir.name, 'db.sqlite') | 60 | self.addCleanup(self.temp_dir.cleanup) |
34 | 61 | ||
35 | self.server = create_server(self.get_server_addr(), self.dbfile) | 62 | (self.client, self.server) = self.start_server() |
36 | self.server_thread = multiprocessing.Process(target=self._run_server) | 63 | |
37 | self.server_thread.start() | 64 | def assertClientGetHash(self, client, taskhash, unihash): |
38 | self.client = create_client(self.server.address) | 65 | result = client.get_unihash(self.METHOD, taskhash) |
39 | 66 | self.assertEqual(result, unihash) | |
40 | def tearDown(self): | ||
41 | # Shutdown server | ||
42 | s = getattr(self, 'server', None) | ||
43 | if s is not None: | ||
44 | self.server_thread.terminate() | ||
45 | self.server_thread.join() | ||
46 | self.client.close() | ||
47 | self.temp_dir.cleanup() | ||
48 | 67 | ||
49 | def test_create_hash(self): | 68 | def test_create_hash(self): |
50 | # Simple test that hashes can be created | 69 | # Simple test that hashes can be created |
@@ -52,8 +71,7 @@ class TestHashEquivalenceServer(object): | |||
52 | outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' | 71 | outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' |
53 | unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' | 72 | unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' |
54 | 73 | ||
55 | result = self.client.get_unihash(self.METHOD, taskhash) | 74 | self.assertClientGetHash(self.client, taskhash, None) |
56 | self.assertIsNone(result, msg='Found unexpected task, %r' % result) | ||
57 | 75 | ||
58 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | 76 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) |
59 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | 77 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') |
@@ -84,22 +102,19 @@ class TestHashEquivalenceServer(object): | |||
84 | unihash = '218e57509998197d570e2c98512d0105985dffc9' | 102 | unihash = '218e57509998197d570e2c98512d0105985dffc9' |
85 | self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | 103 | self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) |
86 | 104 | ||
87 | result = self.client.get_unihash(self.METHOD, taskhash) | 105 | self.assertClientGetHash(self.client, taskhash, unihash) |
88 | self.assertEqual(result, unihash) | ||
89 | 106 | ||
90 | outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d' | 107 | outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d' |
91 | unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c' | 108 | unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c' |
92 | self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2) | 109 | self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2) |
93 | 110 | ||
94 | result = self.client.get_unihash(self.METHOD, taskhash) | 111 | self.assertClientGetHash(self.client, taskhash, unihash) |
95 | self.assertEqual(result, unihash) | ||
96 | 112 | ||
97 | outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' | 113 | outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' |
98 | unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603' | 114 | unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603' |
99 | self.client.report_unihash(taskhash, self.METHOD, outhash3, unihash3) | 115 | self.client.report_unihash(taskhash, self.METHOD, outhash3, unihash3) |
100 | 116 | ||
101 | result = self.client.get_unihash(self.METHOD, taskhash) | 117 | self.assertClientGetHash(self.client, taskhash, unihash) |
102 | self.assertEqual(result, unihash) | ||
103 | 118 | ||
104 | def test_huge_message(self): | 119 | def test_huge_message(self): |
105 | # Simple test that hashes can be created | 120 | # Simple test that hashes can be created |
@@ -107,8 +122,7 @@ class TestHashEquivalenceServer(object): | |||
107 | outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' | 122 | outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' |
108 | unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824' | 123 | unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824' |
109 | 124 | ||
110 | result = self.client.get_unihash(self.METHOD, taskhash) | 125 | self.assertClientGetHash(self.client, taskhash, None) |
111 | self.assertIsNone(result, msg='Found unexpected task, %r' % result) | ||
112 | 126 | ||
113 | siginfo = "0" * (self.client.max_chunk * 4) | 127 | siginfo = "0" * (self.client.max_chunk * 4) |
114 | 128 | ||
@@ -156,14 +170,83 @@ class TestHashEquivalenceServer(object): | |||
156 | 170 | ||
157 | self.assertFalse(failures) | 171 | self.assertFalse(failures) |
158 | 172 | ||
173 | def test_upstream_server(self): | ||
174 | # Tests upstream server support. This is done by creating two servers | ||
175 | # that share a database file. The downstream server has it upstream | ||
176 | # set to the test server, whereas the side server doesn't. This allows | ||
177 | # verification that the hash requests are being proxied to the upstream | ||
178 | # server by verifying that they appear on the downstream client, but not | ||
179 | # the side client. It also verifies that the results are pulled into | ||
180 | # the downstream database by checking that the downstream and side servers | ||
181 | # match after the downstream is done waiting for all backfill tasks | ||
182 | (down_client, down_server) = self.start_server(upstream=self.server.address) | ||
183 | (side_client, side_server) = self.start_server(dbpath=down_server.dbpath) | ||
184 | |||
185 | def check_hash(taskhash, unihash, old_sidehash): | ||
186 | nonlocal down_client | ||
187 | nonlocal side_client | ||
188 | |||
189 | # check upstream server | ||
190 | self.assertClientGetHash(self.client, taskhash, unihash) | ||
191 | |||
192 | # Hash should *not* be present on the side server | ||
193 | self.assertClientGetHash(side_client, taskhash, old_sidehash) | ||
194 | |||
195 | # Hash should be present on the downstream server, since it | ||
196 | # will defer to the upstream server. This will trigger | ||
197 | # the backfill in the downstream server | ||
198 | self.assertClientGetHash(down_client, taskhash, unihash) | ||
199 | |||
200 | # After waiting for the downstream client to finish backfilling the | ||
201 | # task from the upstream server, it should appear in the side server | ||
202 | # since the database is populated | ||
203 | down_client.backfill_wait() | ||
204 | self.assertClientGetHash(side_client, taskhash, unihash) | ||
205 | |||
206 | # Basic report | ||
207 | taskhash = '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a' | ||
208 | outhash = 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e' | ||
209 | unihash = '218e57509998197d570e2c98512d0105985dffc9' | ||
210 | self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | ||
211 | |||
212 | check_hash(taskhash, unihash, None) | ||
213 | |||
214 | # Duplicated taskhash with multiple output hashes and unihashes. | ||
215 | # All servers should agree with the originally reported hash | ||
216 | outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d' | ||
217 | unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c' | ||
218 | self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2) | ||
219 | |||
220 | check_hash(taskhash, unihash, unihash) | ||
221 | |||
222 | # Report an equivalent task. The sideload will originally report | ||
223 | # no unihash until backfilled | ||
224 | taskhash3 = "044c2ec8aaf480685a00ff6ff49e6162e6ad34e1" | ||
225 | unihash3 = "def64766090d28f627e816454ed46894bb3aab36" | ||
226 | self.client.report_unihash(taskhash3, self.METHOD, outhash, unihash3) | ||
227 | |||
228 | check_hash(taskhash3, unihash, None) | ||
229 | |||
230 | # Test that reporting a unihash in the downstream client isn't | ||
231 | # propagating to the upstream server | ||
232 | taskhash4 = "e3da00593d6a7fb435c7e2114976c59c5fd6d561" | ||
233 | outhash4 = "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a" | ||
234 | unihash4 = "3b5d3d83f07f259e9086fcb422c855286e18a57d" | ||
235 | down_client.report_unihash(taskhash4, self.METHOD, outhash4, unihash4) | ||
236 | down_client.backfill_wait() | ||
237 | |||
238 | self.assertClientGetHash(down_client, taskhash4, unihash4) | ||
239 | self.assertClientGetHash(side_client, taskhash4, unihash4) | ||
240 | self.assertClientGetHash(self.client, taskhash4, None) | ||
241 | |||
159 | 242 | ||
160 | class TestHashEquivalenceUnixServer(TestHashEquivalenceServer, unittest.TestCase): | 243 | class TestHashEquivalenceUnixServer(TestHashEquivalenceServer, unittest.TestCase): |
161 | def get_server_addr(self): | 244 | def get_server_addr(self, server_idx): |
162 | return "unix://" + os.path.join(self.temp_dir.name, 'sock') | 245 | return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx) |
163 | 246 | ||
164 | 247 | ||
165 | class TestHashEquivalenceTCPServer(TestHashEquivalenceServer, unittest.TestCase): | 248 | class TestHashEquivalenceTCPServer(TestHashEquivalenceServer, unittest.TestCase): |
166 | def get_server_addr(self): | 249 | def get_server_addr(self, server_idx): |
167 | # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled. | 250 | # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled. |
168 | # If IPv6 is enabled, it should be safe to use localhost directly, in general | 251 | # If IPv6 is enabled, it should be safe to use localhost directly, in general |
169 | # case it is more reliable to resolve the IP address explicitly. | 252 | # case it is more reliable to resolve the IP address explicitly. |