author     Joshua Watt <JPEWhacker@gmail.com>                   2020-11-10 08:59:56 -0600
committer  Richard Purdie <richard.purdie@linuxfoundation.org> 2020-11-24 15:26:12 +0000
commit     96b548a79d87120655da3ac5501b8ad4726cf1a4 (patch)
tree       06938cfe533173ad02664dc2f187867a165c5871 /bitbake
parent     859f43e176dcaaa652e24a2289abd75e18c077cf (diff)
download   poky-96b548a79d87120655da3ac5501b8ad4726cf1a4.tar.gz
bitbake: bitbake: hashserve: Add support for readonly upstream
Adds support for an upstream server to be specified. The upstream server
is queried for equivalent hashes whenever a miss occurs in the local
server. If the upstream returns a match, it is merged into the local
database. To keep the get-stream queries as fast as possible (they are
the critical path while bitbake is preparing the run queue), missing
tasks provided by the upstream are not pulled immediately; instead they
are placed on a queue to be backfilled later by a worker task.

(Bitbake rev: e6d6c0b39393e9bdf378c1eba141f815e26b724b)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
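For orientation, here is a minimal usage sketch of the feature this patch adds, pieced together from the APIs it touches (create_server() gains an upstream keyword; the client gains backfill_wait()). The addresses, database path, and server layout below are illustrative, not part of the patch:

import hashserv

# Local server that consults a read-only upstream whenever the local
# database misses (illustrative addresses and path).
server = hashserv.create_server('unix://./hashserv.sock', 'local.sqlite',
                                upstream='tcp://hashserv.example.com:8686')
# server.serve_forever()  # normally runs in its own process

client = hashserv.create_client('unix://./hashserv.sock')
# A local miss is forwarded upstream; a hit there is returned and the
# full record is queued for backfill into the local database.
unihash = client.get_unihash('TestMethod', '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a')
client.backfill_wait()  # block until the backfill queue has drained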
Diffstat (limited to 'bitbake')
-rw-r--r--  bitbake/lib/hashserv/__init__.py |  39
-rw-r--r--  bitbake/lib/hashserv/client.py   |  19
-rw-r--r--  bitbake/lib/hashserv/server.py   | 149
-rw-r--r--  bitbake/lib/hashserv/tests.py    | 147
4 files changed, 268 insertions(+), 86 deletions(-)
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py
index 622ca17a91..55f48410d3 100644
--- a/bitbake/lib/hashserv/__init__.py
+++ b/bitbake/lib/hashserv/__init__.py
@@ -22,6 +22,24 @@ ADDR_TYPE_TCP = 1
 # is necessary
 DEFAULT_MAX_CHUNK = 32 * 1024
 
+TABLE_DEFINITION = (
+    ("method", "TEXT NOT NULL"),
+    ("outhash", "TEXT NOT NULL"),
+    ("taskhash", "TEXT NOT NULL"),
+    ("unihash", "TEXT NOT NULL"),
+    ("created", "DATETIME"),
+
+    # Optional fields
+    ("owner", "TEXT"),
+    ("PN", "TEXT"),
+    ("PV", "TEXT"),
+    ("PR", "TEXT"),
+    ("task", "TEXT"),
+    ("outhash_siginfo", "TEXT"),
+)
+
+TABLE_COLUMNS = tuple(name for name, _ in TABLE_DEFINITION)
+
 def setup_database(database, sync=True):
     db = sqlite3.connect(database)
     db.row_factory = sqlite3.Row
@@ -30,23 +48,10 @@ def setup_database(database, sync=True):
     cursor.execute('''
         CREATE TABLE IF NOT EXISTS tasks_v2 (
             id INTEGER PRIMARY KEY AUTOINCREMENT,
-            method TEXT NOT NULL,
-            outhash TEXT NOT NULL,
-            taskhash TEXT NOT NULL,
-            unihash TEXT NOT NULL,
-            created DATETIME,
-
-            -- Optional fields
-            owner TEXT,
-            PN TEXT,
-            PV TEXT,
-            PR TEXT,
-            task TEXT,
-            outhash_siginfo TEXT,
-
+            %s
             UNIQUE(method, outhash, taskhash)
         )
-    ''')
+    ''' % " ".join("%s %s," % (name, typ) for name, typ in TABLE_DEFINITION))
     cursor.execute('PRAGMA journal_mode = WAL')
     cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
 
@@ -89,10 +94,10 @@ def chunkify(msg, max_chunk):
         yield "\n"
 
 
-def create_server(addr, dbname, *, sync=True):
+def create_server(addr, dbname, *, sync=True, upstream=None):
     from . import server
     db = setup_database(dbname, sync=sync)
-    s = server.Server(db)
+    s = server.Server(db, upstream=upstream)
 
     (typ, a) = parse_address(addr)
     if typ == ADDR_TYPE_UNIX:
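Since the schema is now assembled from TABLE_DEFINITION, a quick way to see what setup_database() actually executes is to expand the join by hand; this standalone snippet reuses the definitions from the hunk above:

# Prints the CREATE TABLE statement generated in setup_database().
TABLE_DEFINITION = (
    ("method", "TEXT NOT NULL"),
    ("outhash", "TEXT NOT NULL"),
    ("taskhash", "TEXT NOT NULL"),
    ("unihash", "TEXT NOT NULL"),
    ("created", "DATETIME"),
    ("owner", "TEXT"),
    ("PN", "TEXT"),
    ("PV", "TEXT"),
    ("PR", "TEXT"),
    ("task", "TEXT"),
    ("outhash_siginfo", "TEXT"),
)

columns = " ".join("%s %s," % (name, typ) for name, typ in TABLE_DEFINITION)
print('''
    CREATE TABLE IF NOT EXISTS tasks_v2 (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        %s
        UNIQUE(method, outhash, taskhash)
    )
''' % columns)
# Each entry expands to "name type," with the trailing comma kept; that
# is valid SQL here because UNIQUE(...) follows as the final element.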
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
index d0b3cf3863..ae5875d1b3 100644
--- a/bitbake/lib/hashserv/client.py
+++ b/bitbake/lib/hashserv/client.py
@@ -178,18 +178,16 @@ class AsyncClient(object):
         await self._set_mode(self.MODE_NORMAL)
         return await self.send_message({"reset-stats": None})
 
+    async def backfill_wait(self):
+        await self._set_mode(self.MODE_NORMAL)
+        return (await self.send_message({"backfill-wait": None}))["tasks"]
+
 
 class Client(object):
     def __init__(self):
         self.client = AsyncClient()
         self.loop = asyncio.new_event_loop()
 
-        def get_wrapper(self, downcall):
-            def wrapper(*args, **kwargs):
-                return self.loop.run_until_complete(downcall(*args, **kwargs))
-
-            return wrapper
-
         for call in (
             "connect_tcp",
             "connect_unix",
@@ -200,9 +198,16 @@
             "get_taskhash",
             "get_stats",
             "reset_stats",
+            "backfill_wait",
         ):
             downcall = getattr(self.client, call)
-            setattr(self, call, get_wrapper(self, downcall))
+            setattr(self, call, self._get_downcall_wrapper(downcall))
+
+    def _get_downcall_wrapper(self, downcall):
+        def wrapper(*args, **kwargs):
+            return self.loop.run_until_complete(downcall(*args, **kwargs))
+
+        return wrapper
 
     @property
     def max_chunk(self):
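The refactor above hoists the wrapper factory into a _get_downcall_wrapper() method so that backfill_wait, like every other call, gets a blocking facade that drives the coroutine on the client's private event loop. A self-contained illustration of the same pattern, with hypothetical class names:

import asyncio

class AsyncThing:
    async def ping(self):
        return "pong"

class SyncThing:
    # Same shape as hashserv's Client: build a blocking wrapper around
    # each coroutine method by running it to completion on a private loop.
    def __init__(self):
        self.inner = AsyncThing()
        self.loop = asyncio.new_event_loop()
        for call in ("ping",):
            downcall = getattr(self.inner, call)
            setattr(self, call, self._get_downcall_wrapper(downcall))

    def _get_downcall_wrapper(self, downcall):
        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))
        return wrapper

assert SyncThing().ping() == "pong"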
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index 81050715ea..3ff4c51ccb 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: GPL-2.0-only
 #
 
-from contextlib import closing
+from contextlib import closing, contextmanager
 from datetime import datetime
 import asyncio
 import json
@@ -12,8 +12,9 @@ import math
 import os
 import signal
 import socket
+import sys
 import time
-from . import chunkify, DEFAULT_MAX_CHUNK
+from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS
 
 logger = logging.getLogger('hashserv.server')
 
@@ -111,16 +112,40 @@ class Stats(object):
 class ClientError(Exception):
     pass
 
+def insert_task(cursor, data, ignore=False):
+    keys = sorted(data.keys())
+    query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
+        " OR IGNORE" if ignore else "",
+        ', '.join(keys),
+        ', '.join(':' + k for k in keys))
+    cursor.execute(query, data)
+
+async def copy_from_upstream(client, db, method, taskhash):
+    d = await client.get_taskhash(method, taskhash, True)
+    if d is not None:
+        # Filter out unknown columns
+        d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
+        keys = sorted(d.keys())
+
+
+        with closing(db.cursor()) as cursor:
+            insert_task(cursor, d)
+            db.commit()
+
+    return d
+
 class ServerClient(object):
     FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
     ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
 
-    def __init__(self, reader, writer, db, request_stats):
+    def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream):
         self.reader = reader
         self.writer = writer
         self.db = db
         self.request_stats = request_stats
         self.max_chunk = DEFAULT_MAX_CHUNK
+        self.backfill_queue = backfill_queue
+        self.upstream = upstream
 
         self.handlers = {
             'get': self.handle_get,
@@ -130,10 +155,18 @@ class ServerClient(object):
             'get-stats': self.handle_get_stats,
             'reset-stats': self.handle_reset_stats,
             'chunk-stream': self.handle_chunk,
+            'backfill-wait': self.handle_backfill_wait,
         }
 
     async def process_requests(self):
+        if self.upstream is not None:
+            self.upstream_client = await create_async_client(self.upstream)
+        else:
+            self.upstream_client = None
+
         try:
+
+
             self.addr = self.writer.get_extra_info('peername')
             logger.debug('Client %r connected' % (self.addr,))
 
@@ -171,6 +204,9 @@ class ServerClient(object):
         except ClientError as e:
             logger.error(str(e))
         finally:
+            if self.upstream_client is not None:
+                await self.upstream_client.close()
+
             self.writer.close()
 
     async def dispatch_message(self, msg):
@@ -239,15 +275,19 @@
         if row is not None:
             logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
             d = {k: row[k] for k in row.keys()}
-
-            self.write_message(d)
+        elif self.upstream_client is not None:
+            d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash)
         else:
-            self.write_message(None)
+            d = None
+
+        self.write_message(d)
 
     async def handle_get_stream(self, request):
         self.write_message('ok')
 
         while True:
+            upstream = None
+
             l = await self.reader.readline()
             if not l:
                 return
@@ -272,6 +312,12 @@
                 if row is not None:
                     msg = ('%s\n' % row['unihash']).encode('utf-8')
                     #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+                elif self.upstream_client is not None:
+                    upstream = await self.upstream_client.get_unihash(method, taskhash)
+                    if upstream:
+                        msg = ("%s\n" % upstream).encode("utf-8")
+                    else:
+                        msg = "\n".encode("utf-8")
                 else:
                     msg = '\n'.encode('utf-8')
 
@@ -282,6 +328,11 @@
 
             await self.writer.drain()
 
+            # Post to the backfill queue after writing the result to minimize
+            # the turn around time on a request
+            if upstream is not None:
+                await self.backfill_queue.put((method, taskhash))
+
     async def handle_report(self, data):
         with closing(self.db.cursor()) as cursor:
             cursor.execute('''
@@ -324,11 +375,7 @@
                 if k in data:
                     insert_data[k] = data[k]
 
-            cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
-                ', '.join(sorted(insert_data.keys())),
-                ', '.join(':' + k for k in sorted(insert_data.keys()))),
-                insert_data)
-
+            insert_task(cursor, insert_data)
             self.db.commit()
 
             logger.info('Adding taskhash %s with unihash %s',
@@ -358,11 +405,7 @@
                 if k in data:
                     insert_data[k] = data[k]
 
-            cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % (
-                ', '.join(sorted(insert_data.keys())),
-                ', '.join(':' + k for k in sorted(insert_data.keys()))),
-                insert_data)
-
+            insert_task(cursor, insert_data, ignore=True)
             self.db.commit()
 
             # Fetch the unihash that will be reported for the taskhash. If the
@@ -394,6 +437,13 @@
             self.request_stats.reset()
         self.write_message(d)
 
+    async def handle_backfill_wait(self, request):
+        d = {
+            'tasks': self.backfill_queue.qsize(),
+        }
+        await self.backfill_queue.join()
+        self.write_message(d)
+
     def query_equivalent(self, method, taskhash, query):
         # This is part of the inner loop and must be as fast as possible
         try:
@@ -405,7 +455,7 @@
 
 
 class Server(object):
-    def __init__(self, db, loop=None):
+    def __init__(self, db, loop=None, upstream=None):
         self.request_stats = Stats()
         self.db = db
 
@@ -416,6 +466,8 @@
             self.loop = loop
             self.close_loop = False
 
+        self.upstream = upstream
+
         self._cleanup_socket = None
 
     def start_tcp_server(self, host, port):
@@ -458,7 +510,7 @@
     async def handle_client(self, reader, writer):
         # writer.transport.set_write_buffer_limits(0)
         try:
-            client = ServerClient(reader, writer, self.db, self.request_stats)
+            client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream)
            await client.process_requests()
         except Exception as e:
             import traceback
@@ -467,23 +519,60 @@
             writer.close()
         logger.info('Client disconnected')
 
+    @contextmanager
+    def _backfill_worker(self):
+        async def backfill_worker_task():
+            client = await create_async_client(self.upstream)
+            try:
+                while True:
+                    item = await self.backfill_queue.get()
+                    if item is None:
+                        self.backfill_queue.task_done()
+                        break
+                    method, taskhash = item
+                    await copy_from_upstream(client, self.db, method, taskhash)
+                    self.backfill_queue.task_done()
+            finally:
+                await client.close()
+
+        async def join_worker(worker):
+            await self.backfill_queue.put(None)
+            await worker
+
+        if self.upstream is not None:
+            worker = asyncio.ensure_future(backfill_worker_task())
+            try:
+                yield
+            finally:
+                self.loop.run_until_complete(join_worker(worker))
+        else:
+            yield
+
     def serve_forever(self):
         def signal_handler():
             self.loop.stop()
 
-        self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
-
+        asyncio.set_event_loop(self.loop)
         try:
-            self.loop.run_forever()
-        except KeyboardInterrupt:
-            pass
+            self.backfill_queue = asyncio.Queue()
+
+            self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
 
-        self.server.close()
-        self.loop.run_until_complete(self.server.wait_closed())
-        logger.info('Server shutting down')
+            with self._backfill_worker():
+                try:
+                    self.loop.run_forever()
+                except KeyboardInterrupt:
+                    pass
 
-        if self.close_loop:
-            self.loop.close()
+                self.server.close()
+
+                self.loop.run_until_complete(self.server.wait_closed())
+                logger.info('Server shutting down')
+        finally:
+            if self.close_loop:
+                if sys.version_info >= (3, 6):
+                    self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+                self.loop.close()
 
         if self._cleanup_socket is not None:
             self._cleanup_socket()
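The backfill machinery is a conventional asyncio producer/consumer: handle_get_stream() posts (method, taskhash) pairs to backfill_queue, a single worker drains them, a None sentinel shuts the worker down, and Queue.join() is what the new 'backfill-wait' request blocks on. A self-contained sketch of that pattern, with the hashserv-specific calls stubbed out:

import asyncio

async def main():
    queue = asyncio.Queue()

    async def worker():
        # Mirrors backfill_worker_task(): drain until the None sentinel.
        while True:
            item = await queue.get()
            if item is None:
                queue.task_done()
                break
            method, taskhash = item
            # copy_from_upstream(client, db, method, taskhash) would run here
            queue.task_done()

    task = asyncio.ensure_future(worker())
    await queue.put(('TestMethod', '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a'))
    await queue.join()     # the barrier that 'backfill-wait' waits on
    await queue.put(None)  # shutdown sentinel, as in join_worker()
    await task

asyncio.run(main())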
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py
index 4566f24738..3dd9a31bee 100644
--- a/bitbake/lib/hashserv/tests.py
+++ b/bitbake/lib/hashserv/tests.py
@@ -16,35 +16,54 @@ import threading
 import unittest
 import socket
 
+def _run_server(server, idx):
+    # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
+    #                     format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
+    sys.stdout = open('bbhashserv-%d.log' % idx, 'w')
+    sys.stderr = sys.stdout
+    server.serve_forever()
 
 class TestHashEquivalenceServer(object):
     METHOD = 'TestMethod'
 
-    def _run_server(self):
-        # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
-        #                     format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
-        self.server.serve_forever()
+    server_index = 0
+
+    def start_server(self, dbpath=None, upstream=None):
+        self.server_index += 1
+        if dbpath is None:
+            dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
+
+        def cleanup_thread(thread):
+            thread.terminate()
+            thread.join()
+
+        server = create_server(self.get_server_addr(self.server_index), dbpath, upstream=upstream)
+        server.dbpath = dbpath
+
+        server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index))
+        server.thread.start()
+        self.addCleanup(cleanup_thread, server.thread)
+
+        def cleanup_client(client):
+            client.close()
+
+        client = create_client(server.address)
+        self.addCleanup(cleanup_client, client)
+
+        return (client, server)
 
     def setUp(self):
         if sys.version_info < (3, 5, 0):
             self.skipTest('Python 3.5 or later required')
 
         self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv')
-        self.dbfile = os.path.join(self.temp_dir.name, 'db.sqlite')
+        self.addCleanup(self.temp_dir.cleanup)
 
-        self.server = create_server(self.get_server_addr(), self.dbfile)
-        self.server_thread = multiprocessing.Process(target=self._run_server)
-        self.server_thread.start()
-        self.client = create_client(self.server.address)
-
-    def tearDown(self):
-        # Shutdown server
-        s = getattr(self, 'server', None)
-        if s is not None:
-            self.server_thread.terminate()
-            self.server_thread.join()
-        self.client.close()
-        self.temp_dir.cleanup()
+        (self.client, self.server) = self.start_server()
+
+    def assertClientGetHash(self, client, taskhash, unihash):
+        result = client.get_unihash(self.METHOD, taskhash)
+        self.assertEqual(result, unihash)
 
     def test_create_hash(self):
         # Simple test that hashes can be created
@@ -52,8 +71,7 @@ class TestHashEquivalenceServer(object):
         outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
         unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
 
-        result = self.client.get_unihash(self.METHOD, taskhash)
-        self.assertIsNone(result, msg='Found unexpected task, %r' % result)
+        self.assertClientGetHash(self.client, taskhash, None)
 
         result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
         self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
@@ -84,22 +102,19 @@ class TestHashEquivalenceServer(object):
         unihash = '218e57509998197d570e2c98512d0105985dffc9'
         self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
 
-        result = self.client.get_unihash(self.METHOD, taskhash)
-        self.assertEqual(result, unihash)
+        self.assertClientGetHash(self.client, taskhash, unihash)
 
         outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d'
         unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'
         self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2)
 
-        result = self.client.get_unihash(self.METHOD, taskhash)
-        self.assertEqual(result, unihash)
+        self.assertClientGetHash(self.client, taskhash, unihash)
 
         outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
         unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603'
         self.client.report_unihash(taskhash, self.METHOD, outhash3, unihash3)
 
-        result = self.client.get_unihash(self.METHOD, taskhash)
-        self.assertEqual(result, unihash)
+        self.assertClientGetHash(self.client, taskhash, unihash)
 
     def test_huge_message(self):
         # Simple test that hashes can be created
@@ -107,8 +122,7 @@ class TestHashEquivalenceServer(object):
         outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
         unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824'
 
-        result = self.client.get_unihash(self.METHOD, taskhash)
-        self.assertIsNone(result, msg='Found unexpected task, %r' % result)
+        self.assertClientGetHash(self.client, taskhash, None)
 
         siginfo = "0" * (self.client.max_chunk * 4)
 
@@ -156,14 +170,83 @@ class TestHashEquivalenceServer(object):
 
         self.assertFalse(failures)
 
+    def test_upstream_server(self):
+        # Tests upstream server support. This is done by creating two servers
+        # that share a database file. The downstream server has its upstream
+        # set to the test server, whereas the side server doesn't. This allows
+        # verification that the hash requests are being proxied to the upstream
+        # server by verifying that they appear on the downstream client, but not
+        # the side client. It also verifies that the results are pulled into
+        # the downstream database by checking that the downstream and side servers
+        # match after the downstream is done waiting for all backfill tasks
+        (down_client, down_server) = self.start_server(upstream=self.server.address)
+        (side_client, side_server) = self.start_server(dbpath=down_server.dbpath)
+
+        def check_hash(taskhash, unihash, old_sidehash):
+            nonlocal down_client
+            nonlocal side_client
+
+            # check upstream server
+            self.assertClientGetHash(self.client, taskhash, unihash)
+
+            # Hash should *not* be present on the side server
+            self.assertClientGetHash(side_client, taskhash, old_sidehash)
+
+            # Hash should be present on the downstream server, since it
+            # will defer to the upstream server. This will trigger
+            # the backfill in the downstream server
+            self.assertClientGetHash(down_client, taskhash, unihash)
+
+            # After waiting for the downstream client to finish backfilling the
+            # task from the upstream server, it should appear in the side server
+            # since the database is populated
+            down_client.backfill_wait()
+            self.assertClientGetHash(side_client, taskhash, unihash)
+
+        # Basic report
+        taskhash = '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a'
+        outhash = 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e'
+        unihash = '218e57509998197d570e2c98512d0105985dffc9'
+        self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+
+        check_hash(taskhash, unihash, None)
+
+        # Duplicated taskhash with multiple output hashes and unihashes.
+        # All servers should agree with the originally reported hash
+        outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d'
+        unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'
+        self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2)
+
+        check_hash(taskhash, unihash, unihash)
+
+        # Report an equivalent task. The sideload will originally report
+        # no unihash until backfilled
+        taskhash3 = "044c2ec8aaf480685a00ff6ff49e6162e6ad34e1"
+        unihash3 = "def64766090d28f627e816454ed46894bb3aab36"
+        self.client.report_unihash(taskhash3, self.METHOD, outhash, unihash3)
+
+        check_hash(taskhash3, unihash, None)
+
+        # Test that reporting a unihash in the downstream client isn't
+        # propagating to the upstream server
+        taskhash4 = "e3da00593d6a7fb435c7e2114976c59c5fd6d561"
+        outhash4 = "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a"
+        unihash4 = "3b5d3d83f07f259e9086fcb422c855286e18a57d"
+        down_client.report_unihash(taskhash4, self.METHOD, outhash4, unihash4)
+        down_client.backfill_wait()
+
+        self.assertClientGetHash(down_client, taskhash4, unihash4)
+        self.assertClientGetHash(side_client, taskhash4, unihash4)
+        self.assertClientGetHash(self.client, taskhash4, None)
+
 
 class TestHashEquivalenceUnixServer(TestHashEquivalenceServer, unittest.TestCase):
-    def get_server_addr(self):
-        return "unix://" + os.path.join(self.temp_dir.name, 'sock')
+    def get_server_addr(self, server_idx):
+        return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)
 
 
 class TestHashEquivalenceTCPServer(TestHashEquivalenceServer, unittest.TestCase):
-    def get_server_addr(self):
+    def get_server_addr(self, server_idx):
         # Some hosts cause asyncio module to misbehave, when IPv6 is not enabled.
         # If IPv6 is enabled, it should be safe to use localhost directly, in general
         # case it is more reliable to resolve the IP address explicitly.