diff options
author | Joshua Watt <JPEWhacker@gmail.com> | 2020-11-10 08:59:56 -0600 |
---|---|---|
committer | Richard Purdie <richard.purdie@linuxfoundation.org> | 2020-11-24 15:26:12 +0000 |
commit | 96b548a79d87120655da3ac5501b8ad4726cf1a4 (patch) | |
tree | 06938cfe533173ad02664dc2f187867a165c5871 | |
parent | 859f43e176dcaaa652e24a2289abd75e18c077cf (diff) | |
download | poky-96b548a79d87120655da3ac5501b8ad4726cf1a4.tar.gz |
bitbake: bitbake: hashserv: Add support for readonly upstream
Adds support for an upstream server to be specified. The upstream server
will be queried for equivalent hashes whenever a miss is found in the
local server. If the server returns a match, it is merged into the
local database. In order to keep the get-stream queries as fast as
possible, since they are the critical path when bitbake is preparing the
run queue, missing tasks found on the upstream server are not immediately
pulled into the local database; instead they are put into a queue to be
backfilled by a worker task later.
(Bitbake rev: e6d6c0b39393e9bdf378c1eba141f815e26b724b)
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
-rw-r--r-- | bitbake/lib/hashserv/__init__.py | 39 | ||||
-rw-r--r-- | bitbake/lib/hashserv/client.py | 19 | ||||
-rw-r--r-- | bitbake/lib/hashserv/server.py | 149 | ||||
-rw-r--r-- | bitbake/lib/hashserv/tests.py | 147 |
4 files changed, 268 insertions, 86 deletions
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py index 622ca17a91..55f48410d3 100644 --- a/bitbake/lib/hashserv/__init__.py +++ b/bitbake/lib/hashserv/__init__.py | |||
@@ -22,6 +22,24 @@ ADDR_TYPE_TCP = 1 | |||
22 | # is necessary | 22 | # is necessary |
23 | DEFAULT_MAX_CHUNK = 32 * 1024 | 23 | DEFAULT_MAX_CHUNK = 32 * 1024 |
24 | 24 | ||
25 | TABLE_DEFINITION = ( | ||
26 | ("method", "TEXT NOT NULL"), | ||
27 | ("outhash", "TEXT NOT NULL"), | ||
28 | ("taskhash", "TEXT NOT NULL"), | ||
29 | ("unihash", "TEXT NOT NULL"), | ||
30 | ("created", "DATETIME"), | ||
31 | |||
32 | # Optional fields | ||
33 | ("owner", "TEXT"), | ||
34 | ("PN", "TEXT"), | ||
35 | ("PV", "TEXT"), | ||
36 | ("PR", "TEXT"), | ||
37 | ("task", "TEXT"), | ||
38 | ("outhash_siginfo", "TEXT"), | ||
39 | ) | ||
40 | |||
41 | TABLE_COLUMNS = tuple(name for name, _ in TABLE_DEFINITION) | ||
42 | |||
25 | def setup_database(database, sync=True): | 43 | def setup_database(database, sync=True): |
26 | db = sqlite3.connect(database) | 44 | db = sqlite3.connect(database) |
27 | db.row_factory = sqlite3.Row | 45 | db.row_factory = sqlite3.Row |
@@ -30,23 +48,10 @@ def setup_database(database, sync=True): | |||
30 | cursor.execute(''' | 48 | cursor.execute(''' |
31 | CREATE TABLE IF NOT EXISTS tasks_v2 ( | 49 | CREATE TABLE IF NOT EXISTS tasks_v2 ( |
32 | id INTEGER PRIMARY KEY AUTOINCREMENT, | 50 | id INTEGER PRIMARY KEY AUTOINCREMENT, |
33 | method TEXT NOT NULL, | 51 | %s |
34 | outhash TEXT NOT NULL, | ||
35 | taskhash TEXT NOT NULL, | ||
36 | unihash TEXT NOT NULL, | ||
37 | created DATETIME, | ||
38 | |||
39 | -- Optional fields | ||
40 | owner TEXT, | ||
41 | PN TEXT, | ||
42 | PV TEXT, | ||
43 | PR TEXT, | ||
44 | task TEXT, | ||
45 | outhash_siginfo TEXT, | ||
46 | |||
47 | UNIQUE(method, outhash, taskhash) | 52 | UNIQUE(method, outhash, taskhash) |
48 | ) | 53 | ) |
49 | ''') | 54 | ''' % " ".join("%s %s," % (name, typ) for name, typ in TABLE_DEFINITION)) |
50 | cursor.execute('PRAGMA journal_mode = WAL') | 55 | cursor.execute('PRAGMA journal_mode = WAL') |
51 | cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF')) | 56 | cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF')) |
52 | 57 | ||
@@ -89,10 +94,10 @@ def chunkify(msg, max_chunk): | |||
89 | yield "\n" | 94 | yield "\n" |
90 | 95 | ||
91 | 96 | ||
92 | def create_server(addr, dbname, *, sync=True): | 97 | def create_server(addr, dbname, *, sync=True, upstream=None): |
93 | from . import server | 98 | from . import server |
94 | db = setup_database(dbname, sync=sync) | 99 | db = setup_database(dbname, sync=sync) |
95 | s = server.Server(db) | 100 | s = server.Server(db, upstream=upstream) |
96 | 101 | ||
97 | (typ, a) = parse_address(addr) | 102 | (typ, a) = parse_address(addr) |
98 | if typ == ADDR_TYPE_UNIX: | 103 | if typ == ADDR_TYPE_UNIX: |
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py index d0b3cf3863..ae5875d1b3 100644 --- a/bitbake/lib/hashserv/client.py +++ b/bitbake/lib/hashserv/client.py | |||
@@ -178,18 +178,16 @@ class AsyncClient(object): | |||
178 | await self._set_mode(self.MODE_NORMAL) | 178 | await self._set_mode(self.MODE_NORMAL) |
179 | return await self.send_message({"reset-stats": None}) | 179 | return await self.send_message({"reset-stats": None}) |
180 | 180 | ||
181 | async def backfill_wait(self): | ||
182 | await self._set_mode(self.MODE_NORMAL) | ||
183 | return (await self.send_message({"backfill-wait": None}))["tasks"] | ||
184 | |||
181 | 185 | ||
182 | class Client(object): | 186 | class Client(object): |
183 | def __init__(self): | 187 | def __init__(self): |
184 | self.client = AsyncClient() | 188 | self.client = AsyncClient() |
185 | self.loop = asyncio.new_event_loop() | 189 | self.loop = asyncio.new_event_loop() |
186 | 190 | ||
187 | def get_wrapper(self, downcall): | ||
188 | def wrapper(*args, **kwargs): | ||
189 | return self.loop.run_until_complete(downcall(*args, **kwargs)) | ||
190 | |||
191 | return wrapper | ||
192 | |||
193 | for call in ( | 191 | for call in ( |
194 | "connect_tcp", | 192 | "connect_tcp", |
195 | "connect_unix", | 193 | "connect_unix", |
@@ -200,9 +198,16 @@ class Client(object): | |||
200 | "get_taskhash", | 198 | "get_taskhash", |
201 | "get_stats", | 199 | "get_stats", |
202 | "reset_stats", | 200 | "reset_stats", |
201 | "backfill_wait", | ||
203 | ): | 202 | ): |
204 | downcall = getattr(self.client, call) | 203 | downcall = getattr(self.client, call) |
205 | setattr(self, call, get_wrapper(self, downcall)) | 204 | setattr(self, call, self._get_downcall_wrapper(downcall)) |
205 | |||
206 | def _get_downcall_wrapper(self, downcall): | ||
207 | def wrapper(*args, **kwargs): | ||
208 | return self.loop.run_until_complete(downcall(*args, **kwargs)) | ||
209 | |||
210 | return wrapper | ||
206 | 211 | ||
207 | @property | 212 | @property |
208 | def max_chunk(self): | 213 | def max_chunk(self): |
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py index 81050715ea..3ff4c51ccb 100644 --- a/bitbake/lib/hashserv/server.py +++ b/bitbake/lib/hashserv/server.py | |||
@@ -3,7 +3,7 @@ | |||
3 | # SPDX-License-Identifier: GPL-2.0-only | 3 | # SPDX-License-Identifier: GPL-2.0-only |
4 | # | 4 | # |
5 | 5 | ||
6 | from contextlib import closing | 6 | from contextlib import closing, contextmanager |
7 | from datetime import datetime | 7 | from datetime import datetime |
8 | import asyncio | 8 | import asyncio |
9 | import json | 9 | import json |
@@ -12,8 +12,9 @@ import math | |||
12 | import os | 12 | import os |
13 | import signal | 13 | import signal |
14 | import socket | 14 | import socket |
15 | import sys | ||
15 | import time | 16 | import time |
16 | from . import chunkify, DEFAULT_MAX_CHUNK | 17 | from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS |
17 | 18 | ||
18 | logger = logging.getLogger('hashserv.server') | 19 | logger = logging.getLogger('hashserv.server') |
19 | 20 | ||
@@ -111,16 +112,40 @@ class Stats(object): | |||
111 | class ClientError(Exception): | 112 | class ClientError(Exception): |
112 | pass | 113 | pass |
113 | 114 | ||
115 | def insert_task(cursor, data, ignore=False): | ||
116 | keys = sorted(data.keys()) | ||
117 | query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % ( | ||
118 | " OR IGNORE" if ignore else "", | ||
119 | ', '.join(keys), | ||
120 | ', '.join(':' + k for k in keys)) | ||
121 | cursor.execute(query, data) | ||
122 | |||
123 | async def copy_from_upstream(client, db, method, taskhash): | ||
124 | d = await client.get_taskhash(method, taskhash, True) | ||
125 | if d is not None: | ||
126 | # Filter out unknown columns | ||
127 | d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} | ||
128 | keys = sorted(d.keys()) | ||
129 | |||
130 | |||
131 | with closing(db.cursor()) as cursor: | ||
132 | insert_task(cursor, d) | ||
133 | db.commit() | ||
134 | |||
135 | return d | ||
136 | |||
114 | class ServerClient(object): | 137 | class ServerClient(object): |
115 | FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | 138 | FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' |
116 | ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | 139 | ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' |
117 | 140 | ||
118 | def __init__(self, reader, writer, db, request_stats): | 141 | def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream): |
119 | self.reader = reader | 142 | self.reader = reader |
120 | self.writer = writer | 143 | self.writer = writer |
121 | self.db = db | 144 | self.db = db |
122 | self.request_stats = request_stats | 145 | self.request_stats = request_stats |
123 | self.max_chunk = DEFAULT_MAX_CHUNK | 146 | self.max_chunk = DEFAULT_MAX_CHUNK |
147 | self.backfill_queue = backfill_queue | ||
148 | self.upstream = upstream | ||
124 | 149 | ||
125 | self.handlers = { | 150 | self.handlers = { |
126 | 'get': self.handle_get, | 151 | 'get': self.handle_get, |
@@ -130,10 +155,18 @@ class ServerClient(object): | |||
130 | 'get-stats': self.handle_get_stats, | 155 | 'get-stats': self.handle_get_stats, |
131 | 'reset-stats': self.handle_reset_stats, | 156 | 'reset-stats': self.handle_reset_stats, |
132 | 'chunk-stream': self.handle_chunk, | 157 | 'chunk-stream': self.handle_chunk, |
158 | 'backfill-wait': self.handle_backfill_wait, | ||
133 | } | 159 | } |
134 | 160 | ||
135 | async def process_requests(self): | 161 | async def process_requests(self): |
162 | if self.upstream is not None: | ||
163 | self.upstream_client = await create_async_client(self.upstream) | ||
164 | else: | ||
165 | self.upstream_client = None | ||
166 | |||
136 | try: | 167 | try: |
168 | |||
169 | |||
137 | self.addr = self.writer.get_extra_info('peername') | 170 | self.addr = self.writer.get_extra_info('peername') |
138 | logger.debug('Client %r connected' % (self.addr,)) | 171 | logger.debug('Client %r connected' % (self.addr,)) |
139 | 172 | ||
@@ -171,6 +204,9 @@ class ServerClient(object): | |||
171 | except ClientError as e: | 204 | except ClientError as e: |
172 | logger.error(str(e)) | 205 | logger.error(str(e)) |
173 | finally: | 206 | finally: |
207 | if self.upstream_client is not None: | ||
208 | await self.upstream_client.close() | ||
209 | |||
174 | self.writer.close() | 210 | self.writer.close() |
175 | 211 | ||
176 | async def dispatch_message(self, msg): | 212 | async def dispatch_message(self, msg): |
@@ -239,15 +275,19 @@ class ServerClient(object): | |||
239 | if row is not None: | 275 | if row is not None: |
240 | logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | 276 | logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) |
241 | d = {k: row[k] for k in row.keys()} | 277 | d = {k: row[k] for k in row.keys()} |
242 | 278 | elif self.upstream_client is not None: | |
243 | self.write_message(d) | 279 | d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash) |
244 | else: | 280 | else: |
245 | self.write_message(None) | 281 | d = None |
282 | |||
283 | self.write_message(d) | ||
246 | 284 | ||
247 | async def handle_get_stream(self, request): | 285 | async def handle_get_stream(self, request): |
248 | self.write_message('ok') | 286 | self.write_message('ok') |
249 | 287 | ||
250 | while True: | 288 | while True: |
289 | upstream = None | ||
290 | |||
251 | l = await self.reader.readline() | 291 | l = await self.reader.readline() |
252 | if not l: | 292 | if not l: |
253 | return | 293 | return |
@@ -272,6 +312,12 @@ class ServerClient(object): | |||
272 | if row is not None: | 312 | if row is not None: |
273 | msg = ('%s\n' % row['unihash']).encode('utf-8') | 313 | msg = ('%s\n' % row['unihash']).encode('utf-8') |
274 | #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | 314 | #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) |
315 | elif self.upstream_client is not None: | ||
316 | upstream = await self.upstream_client.get_unihash(method, taskhash) | ||
317 | if upstream: | ||
318 | msg = ("%s\n" % upstream).encode("utf-8") | ||
319 | else: | ||
320 | msg = "\n".encode("utf-8") | ||
275 | else: | 321 | else: |
276 | msg = '\n'.encode('utf-8') | 322 | msg = '\n'.encode('utf-8') |
277 | 323 | ||
@@ -282,6 +328,11 @@ class ServerClient(object): | |||
282 | 328 | ||
283 | await self.writer.drain() | 329 | await self.writer.drain() |
284 | 330 | ||
331 | # Post to the backfill queue after writing the result to minimize | ||
332 | # the turn around time on a request | ||
333 | if upstream is not None: | ||
334 | await self.backfill_queue.put((method, taskhash)) | ||
335 | |||
285 | async def handle_report(self, data): | 336 | async def handle_report(self, data): |
286 | with closing(self.db.cursor()) as cursor: | 337 | with closing(self.db.cursor()) as cursor: |
287 | cursor.execute(''' | 338 | cursor.execute(''' |
@@ -324,11 +375,7 @@ class ServerClient(object): | |||
324 | if k in data: | 375 | if k in data: |
325 | insert_data[k] = data[k] | 376 | insert_data[k] = data[k] |
326 | 377 | ||
327 | cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % ( | 378 | insert_task(cursor, insert_data) |
328 | ', '.join(sorted(insert_data.keys())), | ||
329 | ', '.join(':' + k for k in sorted(insert_data.keys()))), | ||
330 | insert_data) | ||
331 | |||
332 | self.db.commit() | 379 | self.db.commit() |
333 | 380 | ||
334 | logger.info('Adding taskhash %s with unihash %s', | 381 | logger.info('Adding taskhash %s with unihash %s', |
@@ -358,11 +405,7 @@ class ServerClient(object): | |||
358 | if k in data: | 405 | if k in data: |
359 | insert_data[k] = data[k] | 406 | insert_data[k] = data[k] |
360 | 407 | ||
361 | cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % ( | 408 | insert_task(cursor, insert_data, ignore=True) |
362 | ', '.join(sorted(insert_data.keys())), | ||
363 | ', '.join(':' + k for k in sorted(insert_data.keys()))), | ||
364 | insert_data) | ||
365 | |||
366 | self.db.commit() | 409 | self.db.commit() |
367 | 410 | ||
368 | # Fetch the unihash that will be reported for the taskhash. If the | 411 | # Fetch the unihash that will be reported for the taskhash. If the |
@@ -394,6 +437,13 @@ class ServerClient(object): | |||
394 | self.request_stats.reset() | 437 | self.request_stats.reset() |
395 | self.write_message(d) | 438 | self.write_message(d) |
396 | 439 | ||
440 | async def handle_backfill_wait(self, request): | ||
441 | d = { | ||
442 | 'tasks': self.backfill_queue.qsize(), | ||
443 | } | ||
444 | await self.backfill_queue.join() | ||
445 | self.write_message(d) | ||
446 | |||
397 | def query_equivalent(self, method, taskhash, query): | 447 | def query_equivalent(self, method, taskhash, query): |
398 | # This is part of the inner loop and must be as fast as possible | 448 | # This is part of the inner loop and must be as fast as possible |
399 | try: | 449 | try: |
@@ -405,7 +455,7 @@ class ServerClient(object): | |||
405 | 455 | ||
406 | 456 | ||
407 | class Server(object): | 457 | class Server(object): |
408 | def __init__(self, db, loop=None): | 458 | def __init__(self, db, loop=None, upstream=None): |
409 | self.request_stats = Stats() | 459 | self.request_stats = Stats() |
410 | self.db = db | 460 | self.db = db |
411 | 461 | ||
@@ -416,6 +466,8 @@ class Server(object): | |||
416 | self.loop = loop | 466 | self.loop = loop |
417 | self.close_loop = False | 467 | self.close_loop = False |
418 | 468 | ||
469 | self.upstream = upstream | ||
470 | |||
419 | self._cleanup_socket = None | 471 | self._cleanup_socket = None |
420 | 472 | ||
421 | def start_tcp_server(self, host, port): | 473 | def start_tcp_server(self, host, port): |
@@ -458,7 +510,7 @@ class Server(object): | |||
458 | async def handle_client(self, reader, writer): | 510 | async def handle_client(self, reader, writer): |
459 | # writer.transport.set_write_buffer_limits(0) | 511 | # writer.transport.set_write_buffer_limits(0) |
460 | try: | 512 | try: |
461 | client = ServerClient(reader, writer, self.db, self.request_stats) | 513 | client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream) |
462 | await client.process_requests() | 514 | await client.process_requests() |
463 | except Exception as e: | 515 | except Exception as e: |
464 | import traceback | 516 | import traceback |
@@ -467,23 +519,60 @@ class Server(object): | |||
467 | writer.close() | 519 | writer.close() |
468 | logger.info('Client disconnected') | 520 | logger.info('Client disconnected') |
469 | 521 | ||
522 | @contextmanager | ||
523 | def _backfill_worker(self): | ||
524 | async def backfill_worker_task(): | ||
525 | client = await create_async_client(self.upstream) | ||
526 | try: | ||
527 | while True: | ||
528 | item = await self.backfill_queue.get() | ||
529 | if item is None: | ||
530 | self.backfill_queue.task_done() | ||
531 | break | ||
532 | method, taskhash = item | ||
533 | await copy_from_upstream(client, self.db, method, taskhash) | ||
534 | self.backfill_queue.task_done() | ||
535 | finally: | ||
536 | await client.close() | ||
537 | |||
538 | async def join_worker(worker): | ||
539 | await self.backfill_queue.put(None) | ||
540 | await worker | ||
541 | |||
542 | if self.upstream is not None: | ||
543 | worker = asyncio.ensure_future(backfill_worker_task()) | ||
544 | try: | ||
545 | yield | ||
546 | finally: | ||
547 | self.loop.run_until_complete(join_worker(worker)) | ||
548 | else: | ||
549 | yield | ||
550 | |||
470 | def serve_forever(self): | 551 | def serve_forever(self): |
471 | def signal_handler(): | 552 | def signal_handler(): |
472 | self.loop.stop() | 553 | self.loop.stop() |
473 | 554 | ||
474 | self.loop.add_signal_handler(signal.SIGTERM, signal_handler) | 555 | asyncio.set_event_loop(self.loop) |
475 | |||
476 | try: | 556 | try: |
477 | self.loop.run_forever() | 557 | self.backfill_queue = asyncio.Queue() |
478 | except KeyboardInterrupt: | 558 | |
479 | pass | 559 | self.loop.add_signal_handler(signal.SIGTERM, signal_handler) |
480 | 560 | ||
481 | self.server.close() | 561 | with self._backfill_worker(): |
482 | self.loop.run_until_complete(self.server.wait_closed()) | 562 | try: |
483 | logger.info('Server shutting down') | 563 | self.loop.run_forever() |
564 | except KeyboardInterrupt: | ||
565 | pass | ||
484 | 566 | ||
485 | if self.close_loop: | 567 | self.server.close() |
486 | self.loop.close() | 568 | |
569 | self.loop.run_until_complete(self.server.wait_closed()) | ||
570 | logger.info('Server shutting down') | ||
571 | finally: | ||
572 | if self.close_loop: | ||
573 | if sys.version_info >= (3, 6): | ||
574 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | ||
575 | self.loop.close() | ||
487 | 576 | ||
488 | if self._cleanup_socket is not None: | 577 | if self._cleanup_socket is not None: |
489 | self._cleanup_socket() | 578 | self._cleanup_socket() |
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py index 4566f24738..3dd9a31bee 100644 --- a/bitbake/lib/hashserv/tests.py +++ b/bitbake/lib/hashserv/tests.py | |||
@@ -16,35 +16,54 @@ import threading | |||
16 | import unittest | 16 | import unittest |
17 | import socket | 17 | import socket |
18 | 18 | ||
19 | def _run_server(server, idx): | ||
20 | # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', | ||
21 | # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') | ||
22 | sys.stdout = open('bbhashserv-%d.log' % idx, 'w') | ||
23 | sys.stderr = sys.stdout | ||
24 | server.serve_forever() | ||
19 | 25 | ||
20 | class TestHashEquivalenceServer(object): | 26 | class TestHashEquivalenceServer(object): |
21 | METHOD = 'TestMethod' | 27 | METHOD = 'TestMethod' |
22 | 28 | ||
23 | def _run_server(self): | 29 | server_index = 0 |
24 | # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', | 30 | |
25 | # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') | 31 | def start_server(self, dbpath=None, upstream=None): |
26 | self.server.serve_forever() | 32 | self.server_index += 1 |
33 | if dbpath is None: | ||
34 | dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) | ||
35 | |||
36 | def cleanup_thread(thread): | ||
37 | thread.terminate() | ||
38 | thread.join() | ||
39 | |||
40 | server = create_server(self.get_server_addr(self.server_index), dbpath, upstream=upstream) | ||
41 | server.dbpath = dbpath | ||
42 | |||
43 | server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) | ||
44 | server.thread.start() | ||
45 | self.addCleanup(cleanup_thread, server.thread) | ||
46 | |||
47 | def cleanup_client(client): | ||
48 | client.close() | ||
49 | |||
50 | client = create_client(server.address) | ||
51 | self.addCleanup(cleanup_client, client) | ||
52 | |||
53 | return (client, server) | ||
27 | 54 | ||
28 | def setUp(self): | 55 | def setUp(self): |
29 | if sys.version_info < (3, 5, 0): | 56 | if sys.version_info < (3, 5, 0): |
30 | self.skipTest('Python 3.5 or later required') | 57 | self.skipTest('Python 3.5 or later required') |
31 | 58 | ||
32 | self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv') | 59 | self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv') |
33 | self.dbfile = os.path.join(self.temp_dir.name, 'db.sqlite') | 60 | self.addCleanup(self.temp_dir.cleanup) |
34 | 61 | ||
35 | self.server = create_server(self.get_server_addr(), self.dbfile) | 62 | (self.client, self.server) = self.start_server() |
36 | self.server_thread = multiprocessing.Process(target=self._run_server) | 63 | |
37 | self.server_thread.start() | 64 | def assertClientGetHash(self, client, taskhash, unihash): |
38 | self.client = create_client(self.server.address) | 65 | result = client.get_unihash(self.METHOD, taskhash) |
39 | 66 | self.assertEqual(result, unihash) | |
40 | def tearDown(self): | ||
41 | # Shutdown server | ||
42 | s = getattr(self, 'server', None) | ||
43 | if s is not None: | ||
44 | self.server_thread.terminate() | ||
45 | self.server_thread.join() | ||
46 | self.client.close() | ||
47 | self.temp_dir.cleanup() | ||
48 | 67 | ||
49 | def test_create_hash(self): | 68 | def test_create_hash(self): |
50 | # Simple test that hashes can be created | 69 | # Simple test that hashes can be created |
@@ -52,8 +71,7 @@ class TestHashEquivalenceServer(object): | |||
52 | outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' | 71 | outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' |
53 | unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' | 72 | unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' |
54 | 73 | ||
55 | result = self.client.get_unihash(self.METHOD, taskhash) | 74 | self.assertClientGetHash(self.client, taskhash, None) |
56 | self.assertIsNone(result, msg='Found unexpected task, %r' % result) | ||
57 | 75 | ||
58 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | 76 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) |
59 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | 77 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') |
@@ -84,22 +102,19 @@ class TestHashEquivalenceServer(object): | |||
84 | unihash = '218e57509998197d570e2c98512d0105985dffc9' | 102 | unihash = '218e57509998197d570e2c98512d0105985dffc9' |
85 | self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | 103 | self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) |
86 | 104 | ||
87 | result = self.client.get_unihash(self.METHOD, taskhash) | 105 | self.assertClientGetHash(self.client, taskhash, unihash) |
88 | self.assertEqual(result, unihash) | ||
89 | 106 | ||
90 | outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d' | 107 | outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d' |
91 | unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c' | 108 | unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c' |
92 | self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2) | 109 | self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2) |
93 | 110 | ||
94 | result = self.client.get_unihash(self.METHOD, taskhash) | 111 | self.assertClientGetHash(self.client, taskhash, unihash) |
95 | self.assertEqual(result, unihash) | ||
96 | 112 | ||
97 | outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' | 113 | outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' |
98 | unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603' | 114 | unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603' |
99 | self.client.report_unihash(taskhash, self.METHOD, outhash3, unihash3) | 115 | self.client.report_unihash(taskhash, self.METHOD, outhash3, unihash3) |
100 | 116 | ||
101 | result = self.client.get_unihash(self.METHOD, taskhash) | 117 | self.assertClientGetHash(self.client, taskhash, unihash) |
102 | self.assertEqual(result, unihash) | ||
103 | 118 | ||
104 | def test_huge_message(self): | 119 | def test_huge_message(self): |
105 | # Simple test that hashes can be created | 120 | # Simple test that hashes can be created |
@@ -107,8 +122,7 @@ class TestHashEquivalenceServer(object): | |||
107 | outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' | 122 | outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' |
108 | unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824' | 123 | unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824' |
109 | 124 | ||
110 | result = self.client.get_unihash(self.METHOD, taskhash) | 125 | self.assertClientGetHash(self.client, taskhash, None) |
111 | self.assertIsNone(result, msg='Found unexpected task, %r' % result) | ||
112 | 126 | ||
113 | siginfo = "0" * (self.client.max_chunk * 4) | 127 | siginfo = "0" * (self.client.max_chunk * 4) |
114 | 128 | ||
@@ -156,14 +170,83 @@ class TestHashEquivalenceServer(object): | |||
156 | 170 | ||
157 | self.assertFalse(failures) | 171 | self.assertFalse(failures) |
158 | 172 | ||
173 | def test_upstream_server(self): | ||
174 | # Tests upstream server support. This is done by creating two servers | ||
175 | # that share a database file. The downstream server has its upstream | ||
176 | # set to the test server, whereas the side server doesn't. This allows | ||
177 | # verification that the hash requests are being proxied to the upstream | ||
178 | # server by verifying that they appear on the downstream client, but not | ||
179 | # the side client. It also verifies that the results are pulled into | ||
180 | # the downstream database by checking that the downstream and side servers | ||
181 | # match after the downstream is done waiting for all backfill tasks | ||
182 | (down_client, down_server) = self.start_server(upstream=self.server.address) | ||
183 | (side_client, side_server) = self.start_server(dbpath=down_server.dbpath) | ||
184 | |||
185 | def check_hash(taskhash, unihash, old_sidehash): | ||
186 | nonlocal down_client | ||
187 | nonlocal side_client | ||
188 | |||
189 | # check upstream server | ||
190 | self.assertClientGetHash(self.client, taskhash, unihash) | ||
191 | |||
192 | # Hash should *not* be present on the side server | ||
193 | self.assertClientGetHash(side_client, taskhash, old_sidehash) | ||
194 | |||
195 | # Hash should be present on the downstream server, since it | ||
196 | # will defer to the upstream server. This will trigger | ||
197 | # the backfill in the downstream server | ||
198 | self.assertClientGetHash(down_client, taskhash, unihash) | ||
199 | |||
200 | # After waiting for the downstream client to finish backfilling the | ||
201 | # task from the upstream server, it should appear in the side server | ||
202 | # since the database is populated | ||
203 | down_client.backfill_wait() | ||
204 | self.assertClientGetHash(side_client, taskhash, unihash) | ||
205 | |||
206 | # Basic report | ||
207 | taskhash = '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a' | ||
208 | outhash = 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e' | ||
209 | unihash = '218e57509998197d570e2c98512d0105985dffc9' | ||
210 | self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | ||
211 | |||
212 | check_hash(taskhash, unihash, None) | ||
213 | |||
214 | # Duplicated taskhash with multiple output hashes and unihashes. | ||
215 | # All servers should agree with the originally reported hash | ||
216 | outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d' | ||
217 | unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c' | ||
218 | self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2) | ||
219 | |||
220 | check_hash(taskhash, unihash, unihash) | ||
221 | |||
222 | # Report an equivalent task. The sideload will originally report | ||
223 | # no unihash until backfilled | ||
224 | taskhash3 = "044c2ec8aaf480685a00ff6ff49e6162e6ad34e1" | ||
225 | unihash3 = "def64766090d28f627e816454ed46894bb3aab36" | ||
226 | self.client.report_unihash(taskhash3, self.METHOD, outhash, unihash3) | ||
227 | |||
228 | check_hash(taskhash3, unihash, None) | ||
229 | |||
230 | # Test that reporting a unihash in the downstream client isn't | ||
231 | # propagating to the upstream server | ||
232 | taskhash4 = "e3da00593d6a7fb435c7e2114976c59c5fd6d561" | ||
233 | outhash4 = "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a" | ||
234 | unihash4 = "3b5d3d83f07f259e9086fcb422c855286e18a57d" | ||
235 | down_client.report_unihash(taskhash4, self.METHOD, outhash4, unihash4) | ||
236 | down_client.backfill_wait() | ||
237 | |||
238 | self.assertClientGetHash(down_client, taskhash4, unihash4) | ||
239 | self.assertClientGetHash(side_client, taskhash4, unihash4) | ||
240 | self.assertClientGetHash(self.client, taskhash4, None) | ||
241 | |||
159 | 242 | ||
160 | class TestHashEquivalenceUnixServer(TestHashEquivalenceServer, unittest.TestCase): | 243 | class TestHashEquivalenceUnixServer(TestHashEquivalenceServer, unittest.TestCase): |
161 | def get_server_addr(self): | 244 | def get_server_addr(self, server_idx): |
162 | return "unix://" + os.path.join(self.temp_dir.name, 'sock') | 245 | return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx) |
163 | 246 | ||
164 | 247 | ||
165 | class TestHashEquivalenceTCPServer(TestHashEquivalenceServer, unittest.TestCase): | 248 | class TestHashEquivalenceTCPServer(TestHashEquivalenceServer, unittest.TestCase): |
166 | def get_server_addr(self): | 249 | def get_server_addr(self, server_idx): |
167 | # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled. | 250 | # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled. |
168 | # If IPv6 is enabled, it should be safe to use localhost directly, in general | 251 | # If IPv6 is enabled, it should be safe to use localhost directly, in general |
169 | # case it is more reliable to resolve the IP address explicitly. | 252 | # case it is more reliable to resolve the IP address explicitly. |