summaryrefslogtreecommitdiffstats
path: root/bitbake/lib
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib')
-rw-r--r--bitbake/lib/bb/cooker.py17
-rw-r--r--bitbake/lib/bb/runqueue.py4
-rw-r--r--bitbake/lib/bb/siggen.py74
-rw-r--r--bitbake/lib/bb/tests/runqueue.py19
-rw-r--r--bitbake/lib/hashserv/__init__.py261
-rw-r--r--bitbake/lib/hashserv/client.py156
-rw-r--r--bitbake/lib/hashserv/server.py414
-rw-r--r--bitbake/lib/hashserv/tests.py159
8 files changed, 767 insertions, 337 deletions
diff --git a/bitbake/lib/bb/cooker.py b/bitbake/lib/bb/cooker.py
index e46868ddd0..0c540028ae 100644
--- a/bitbake/lib/bb/cooker.py
+++ b/bitbake/lib/bb/cooker.py
@@ -194,7 +194,7 @@ class BBCooker:
194 194
195 self.ui_cmdline = None 195 self.ui_cmdline = None
196 self.hashserv = None 196 self.hashserv = None
197 self.hashservport = None 197 self.hashservaddr = None
198 198
199 self.initConfigurationData() 199 self.initConfigurationData()
200 200
@@ -392,19 +392,20 @@ class BBCooker:
392 except prserv.serv.PRServiceConfigError as e: 392 except prserv.serv.PRServiceConfigError as e:
393 bb.fatal("Unable to start PR Server, exitting") 393 bb.fatal("Unable to start PR Server, exitting")
394 394
395 if self.data.getVar("BB_HASHSERVE") == "localhost:0": 395 if self.data.getVar("BB_HASHSERVE") == "auto":
396 # Create a new hash server bound to a unix domain socket
396 if not self.hashserv: 397 if not self.hashserv:
397 dbfile = (self.data.getVar("PERSISTENT_DIR") or self.data.getVar("CACHE")) + "/hashserv.db" 398 dbfile = (self.data.getVar("PERSISTENT_DIR") or self.data.getVar("CACHE")) + "/hashserv.db"
398 self.hashserv = hashserv.create_server(('localhost', 0), dbfile, '') 399 self.hashservaddr = "unix://%s/hashserve.sock" % self.data.getVar("TOPDIR")
399 self.hashservport = "localhost:" + str(self.hashserv.server_port) 400 self.hashserv = hashserv.create_server(self.hashservaddr, dbfile, sync=False)
400 self.hashserv.process = multiprocessing.Process(target=self.hashserv.serve_forever) 401 self.hashserv.process = multiprocessing.Process(target=self.hashserv.serve_forever)
401 self.hashserv.process.daemon = True 402 self.hashserv.process.daemon = True
402 self.hashserv.process.start() 403 self.hashserv.process.start()
403 self.data.setVar("BB_HASHSERVE", self.hashservport) 404 self.data.setVar("BB_HASHSERVE", self.hashservaddr)
404 self.databuilder.origdata.setVar("BB_HASHSERVE", self.hashservport) 405 self.databuilder.origdata.setVar("BB_HASHSERVE", self.hashservaddr)
405 self.databuilder.data.setVar("BB_HASHSERVE", self.hashservport) 406 self.databuilder.data.setVar("BB_HASHSERVE", self.hashservaddr)
406 for mc in self.databuilder.mcdata: 407 for mc in self.databuilder.mcdata:
407 self.databuilder.mcdata[mc].setVar("BB_HASHSERVE", self.hashservport) 408 self.databuilder.mcdata[mc].setVar("BB_HASHSERVE", self.hashservaddr)
408 409
409 bb.parse.init_parser(self.data) 410 bb.parse.init_parser(self.data)
410 411
diff --git a/bitbake/lib/bb/runqueue.py b/bitbake/lib/bb/runqueue.py
index 45bfec8c37..314a30908b 100644
--- a/bitbake/lib/bb/runqueue.py
+++ b/bitbake/lib/bb/runqueue.py
@@ -1260,7 +1260,7 @@ class RunQueue:
1260 "buildname" : self.cfgData.getVar("BUILDNAME"), 1260 "buildname" : self.cfgData.getVar("BUILDNAME"),
1261 "date" : self.cfgData.getVar("DATE"), 1261 "date" : self.cfgData.getVar("DATE"),
1262 "time" : self.cfgData.getVar("TIME"), 1262 "time" : self.cfgData.getVar("TIME"),
1263 "hashservport" : self.cooker.hashservport, 1263 "hashservaddr" : self.cooker.hashservaddr,
1264 } 1264 }
1265 1265
1266 worker.stdin.write(b"<cookerconfig>" + pickle.dumps(self.cooker.configuration) + b"</cookerconfig>") 1266 worker.stdin.write(b"<cookerconfig>" + pickle.dumps(self.cooker.configuration) + b"</cookerconfig>")
@@ -2174,7 +2174,7 @@ class RunQueueExecute:
2174 ret.add(dep) 2174 ret.add(dep)
2175 return ret 2175 return ret
2176 2176
2177 # We filter out multiconfig dependencies from taskdepdata we pass to the tasks 2177 # We filter out multiconfig dependencies from taskdepdata we pass to the tasks
2178 # as most code can't handle them 2178 # as most code can't handle them
2179 def build_taskdepdata(self, task): 2179 def build_taskdepdata(self, task):
2180 taskdepdata = {} 2180 taskdepdata = {}
diff --git a/bitbake/lib/bb/siggen.py b/bitbake/lib/bb/siggen.py
index 8b593a348b..e047c217e5 100644
--- a/bitbake/lib/bb/siggen.py
+++ b/bitbake/lib/bb/siggen.py
@@ -13,6 +13,7 @@ import difflib
13import simplediff 13import simplediff
14from bb.checksum import FileChecksumCache 14from bb.checksum import FileChecksumCache
15from bb import runqueue 15from bb import runqueue
16import hashserv
16 17
17logger = logging.getLogger('BitBake.SigGen') 18logger = logging.getLogger('BitBake.SigGen')
18 19
@@ -375,6 +376,11 @@ class SignatureGeneratorUniHashMixIn(object):
375 self.server, self.method = data[:2] 376 self.server, self.method = data[:2]
376 super().set_taskdata(data[2:]) 377 super().set_taskdata(data[2:])
377 378
379 def client(self):
380 if getattr(self, '_client', None) is None:
381 self._client = hashserv.create_client(self.server)
382 return self._client
383
378 def __get_task_unihash_key(self, tid): 384 def __get_task_unihash_key(self, tid):
379 # TODO: The key only *needs* to be the taskhash, the tid is just 385 # TODO: The key only *needs* to be the taskhash, the tid is just
380 # convenient 386 # convenient
@@ -395,9 +401,6 @@ class SignatureGeneratorUniHashMixIn(object):
395 self.unitaskhashes[self.__get_task_unihash_key(tid)] = unihash 401 self.unitaskhashes[self.__get_task_unihash_key(tid)] = unihash
396 402
397 def get_unihash(self, tid): 403 def get_unihash(self, tid):
398 import urllib
399 import json
400
401 taskhash = self.taskhash[tid] 404 taskhash = self.taskhash[tid]
402 405
403 # If its not a setscene task we can return 406 # If its not a setscene task we can return
@@ -428,36 +431,22 @@ class SignatureGeneratorUniHashMixIn(object):
428 unihash = taskhash 431 unihash = taskhash
429 432
430 try: 433 try:
431 url = '%s/v1/equivalent?%s' % (self.server, 434 data = self.client().get_unihash(self.method, self.taskhash[tid])
432 urllib.parse.urlencode({'method': self.method, 'taskhash': self.taskhash[tid]})) 435 if data:
433 436 unihash = data
434 request = urllib.request.Request(url)
435 response = urllib.request.urlopen(request)
436 data = response.read().decode('utf-8')
437
438 json_data = json.loads(data)
439
440 if json_data:
441 unihash = json_data['unihash']
442 # A unique hash equal to the taskhash is not very interesting, 437 # A unique hash equal to the taskhash is not very interesting,
443 # so it is reported it at debug level 2. If they differ, that 438 # so it is reported it at debug level 2. If they differ, that
444 # is much more interesting, so it is reported at debug level 1 439 # is much more interesting, so it is reported at debug level 1
445 bb.debug((1, 2)[unihash == taskhash], 'Found unihash %s in place of %s for %s from %s' % (unihash, taskhash, tid, self.server)) 440 bb.debug((1, 2)[unihash == taskhash], 'Found unihash %s in place of %s for %s from %s' % (unihash, taskhash, tid, self.server))
446 else: 441 else:
447 bb.debug(2, 'No reported unihash for %s:%s from %s' % (tid, taskhash, self.server)) 442 bb.debug(2, 'No reported unihash for %s:%s from %s' % (tid, taskhash, self.server))
448 except urllib.error.URLError as e: 443 except hashserv.HashConnectionError as e:
449 bb.warn('Failure contacting Hash Equivalence Server %s: %s' % (self.server, str(e))) 444 bb.warn('Error contacting Hash Equivalence Server %s: %s' % (self.server, str(e)))
450 except (KeyError, json.JSONDecodeError) as e:
451 bb.warn('Poorly formatted response from %s: %s' % (self.server, str(e)))
452 445
453 self.unitaskhashes[key] = unihash 446 self.unitaskhashes[key] = unihash
454 return unihash 447 return unihash
455 448
456 def report_unihash(self, path, task, d): 449 def report_unihash(self, path, task, d):
457 import urllib
458 import json
459 import tempfile
460 import base64
461 import importlib 450 import importlib
462 451
463 taskhash = d.getVar('BB_TASKHASH') 452 taskhash = d.getVar('BB_TASKHASH')
@@ -492,42 +481,31 @@ class SignatureGeneratorUniHashMixIn(object):
492 outhash = bb.utils.better_eval(self.method + '(path, sigfile, task, d)', locs) 481 outhash = bb.utils.better_eval(self.method + '(path, sigfile, task, d)', locs)
493 482
494 try: 483 try:
495 url = '%s/v1/equivalent' % self.server 484 extra_data = {}
496 task_data = { 485
497 'taskhash': taskhash, 486 owner = d.getVar('SSTATE_HASHEQUIV_OWNER')
498 'method': self.method, 487 if owner:
499 'outhash': outhash, 488 extra_data['owner'] = owner
500 'unihash': unihash,
501 'owner': d.getVar('SSTATE_HASHEQUIV_OWNER')
502 }
503 489
504 if report_taskdata: 490 if report_taskdata:
505 sigfile.seek(0) 491 sigfile.seek(0)
506 492
507 task_data['PN'] = d.getVar('PN') 493 extra_data['PN'] = d.getVar('PN')
508 task_data['PV'] = d.getVar('PV') 494 extra_data['PV'] = d.getVar('PV')
509 task_data['PR'] = d.getVar('PR') 495 extra_data['PR'] = d.getVar('PR')
510 task_data['task'] = task 496 extra_data['task'] = task
511 task_data['outhash_siginfo'] = sigfile.read().decode('utf-8') 497 extra_data['outhash_siginfo'] = sigfile.read().decode('utf-8')
512
513 headers = {'content-type': 'application/json'}
514
515 request = urllib.request.Request(url, json.dumps(task_data).encode('utf-8'), headers)
516 response = urllib.request.urlopen(request)
517 data = response.read().decode('utf-8')
518 498
519 json_data = json.loads(data) 499 data = self.client().report_unihash(taskhash, self.method, outhash, unihash, extra_data)
520 new_unihash = json_data['unihash'] 500 new_unihash = data['unihash']
521 501
522 if new_unihash != unihash: 502 if new_unihash != unihash:
523 bb.debug(1, 'Task %s unihash changed %s -> %s by server %s' % (taskhash, unihash, new_unihash, self.server)) 503 bb.debug(1, 'Task %s unihash changed %s -> %s by server %s' % (taskhash, unihash, new_unihash, self.server))
524 bb.event.fire(bb.runqueue.taskUniHashUpdate(fn + ':do_' + task, new_unihash), d) 504 bb.event.fire(bb.runqueue.taskUniHashUpdate(fn + ':do_' + task, new_unihash), d)
525 else: 505 else:
526 bb.debug(1, 'Reported task %s as unihash %s to %s' % (taskhash, unihash, self.server)) 506 bb.debug(1, 'Reported task %s as unihash %s to %s' % (taskhash, unihash, self.server))
527 except urllib.error.URLError as e: 507 except hashserv.HashConnectionError as e:
528 bb.warn('Failure contacting Hash Equivalence Server %s: %s' % (self.server, str(e))) 508 bb.warn('Error contacting Hash Equivalence Server %s: %s' % (self.server, str(e)))
529 except (KeyError, json.JSONDecodeError) as e:
530 bb.warn('Poorly formatted response from %s: %s' % (self.server, str(e)))
531 finally: 509 finally:
532 if sigfile: 510 if sigfile:
533 sigfile.close() 511 sigfile.close()
@@ -548,7 +526,7 @@ class SignatureGeneratorTestEquivHash(SignatureGeneratorUniHashMixIn, SignatureG
548 name = "TestEquivHash" 526 name = "TestEquivHash"
549 def init_rundepcheck(self, data): 527 def init_rundepcheck(self, data):
550 super().init_rundepcheck(data) 528 super().init_rundepcheck(data)
551 self.server = "http://" + data.getVar('BB_HASHSERVE') 529 self.server = data.getVar('BB_HASHSERVE')
552 self.method = "sstate_output_hash" 530 self.method = "sstate_output_hash"
553 531
554 532
diff --git a/bitbake/lib/bb/tests/runqueue.py b/bitbake/lib/bb/tests/runqueue.py
index c7f5e55726..cb4d526f13 100644
--- a/bitbake/lib/bb/tests/runqueue.py
+++ b/bitbake/lib/bb/tests/runqueue.py
@@ -11,6 +11,7 @@ import bb
11import os 11import os
12import tempfile 12import tempfile
13import subprocess 13import subprocess
14import sys
14 15
15# 16#
16# TODO: 17# TODO:
@@ -232,10 +233,11 @@ class RunQueueTests(unittest.TestCase):
232 self.assertEqual(set(tasks), set(expected)) 233 self.assertEqual(set(tasks), set(expected))
233 234
234 235
236 @unittest.skipIf(sys.version_info < (3, 5, 0), 'Python 3.5 or later required')
235 def test_hashserv_single(self): 237 def test_hashserv_single(self):
236 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir: 238 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir:
237 extraenv = { 239 extraenv = {
238 "BB_HASHSERVE" : "localhost:0", 240 "BB_HASHSERVE" : "auto",
239 "BB_SIGNATURE_HANDLER" : "TestEquivHash" 241 "BB_SIGNATURE_HANDLER" : "TestEquivHash"
240 } 242 }
241 cmd = ["bitbake", "a1", "b1"] 243 cmd = ["bitbake", "a1", "b1"]
@@ -255,10 +257,11 @@ class RunQueueTests(unittest.TestCase):
255 'a1:package_write_ipk_setscene', 'a1:package_qa_setscene'] 257 'a1:package_write_ipk_setscene', 'a1:package_qa_setscene']
256 self.assertEqual(set(tasks), set(expected)) 258 self.assertEqual(set(tasks), set(expected))
257 259
260 @unittest.skipIf(sys.version_info < (3, 5, 0), 'Python 3.5 or later required')
258 def test_hashserv_double(self): 261 def test_hashserv_double(self):
259 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir: 262 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir:
260 extraenv = { 263 extraenv = {
261 "BB_HASHSERVE" : "localhost:0", 264 "BB_HASHSERVE" : "auto",
262 "BB_SIGNATURE_HANDLER" : "TestEquivHash" 265 "BB_SIGNATURE_HANDLER" : "TestEquivHash"
263 } 266 }
264 cmd = ["bitbake", "a1", "b1", "e1"] 267 cmd = ["bitbake", "a1", "b1", "e1"]
@@ -278,11 +281,12 @@ class RunQueueTests(unittest.TestCase):
278 self.assertEqual(set(tasks), set(expected)) 281 self.assertEqual(set(tasks), set(expected))
279 282
280 283
284 @unittest.skipIf(sys.version_info < (3, 5, 0), 'Python 3.5 or later required')
281 def test_hashserv_multiple_setscene(self): 285 def test_hashserv_multiple_setscene(self):
282 # Runs e1:do_package_setscene twice 286 # Runs e1:do_package_setscene twice
283 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir: 287 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir:
284 extraenv = { 288 extraenv = {
285 "BB_HASHSERVE" : "localhost:0", 289 "BB_HASHSERVE" : "auto",
286 "BB_SIGNATURE_HANDLER" : "TestEquivHash" 290 "BB_SIGNATURE_HANDLER" : "TestEquivHash"
287 } 291 }
288 cmd = ["bitbake", "a1", "b1", "e1"] 292 cmd = ["bitbake", "a1", "b1", "e1"]
@@ -308,11 +312,12 @@ class RunQueueTests(unittest.TestCase):
308 else: 312 else:
309 self.assertEqual(tasks.count(i), 1, "%s not in task list once" % i) 313 self.assertEqual(tasks.count(i), 1, "%s not in task list once" % i)
310 314
315 @unittest.skipIf(sys.version_info < (3, 5, 0), 'Python 3.5 or later required')
311 def test_hashserv_partial_match(self): 316 def test_hashserv_partial_match(self):
312 # e1:do_package matches initial built but not second hash value 317 # e1:do_package matches initial built but not second hash value
313 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir: 318 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir:
314 extraenv = { 319 extraenv = {
315 "BB_HASHSERVE" : "localhost:0", 320 "BB_HASHSERVE" : "auto",
316 "BB_SIGNATURE_HANDLER" : "TestEquivHash" 321 "BB_SIGNATURE_HANDLER" : "TestEquivHash"
317 } 322 }
318 cmd = ["bitbake", "a1", "b1"] 323 cmd = ["bitbake", "a1", "b1"]
@@ -336,11 +341,12 @@ class RunQueueTests(unittest.TestCase):
336 expected.remove('e1:package') 341 expected.remove('e1:package')
337 self.assertEqual(set(tasks), set(expected)) 342 self.assertEqual(set(tasks), set(expected))
338 343
344 @unittest.skipIf(sys.version_info < (3, 5, 0), 'Python 3.5 or later required')
339 def test_hashserv_partial_match2(self): 345 def test_hashserv_partial_match2(self):
340 # e1:do_package + e1:do_populate_sysroot matches initial built but not second hash value 346 # e1:do_package + e1:do_populate_sysroot matches initial built but not second hash value
341 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir: 347 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir:
342 extraenv = { 348 extraenv = {
343 "BB_HASHSERVE" : "localhost:0", 349 "BB_HASHSERVE" : "auto",
344 "BB_SIGNATURE_HANDLER" : "TestEquivHash" 350 "BB_SIGNATURE_HANDLER" : "TestEquivHash"
345 } 351 }
346 cmd = ["bitbake", "a1", "b1"] 352 cmd = ["bitbake", "a1", "b1"]
@@ -363,13 +369,14 @@ class RunQueueTests(unittest.TestCase):
363 'e1:package_setscene', 'e1:populate_sysroot_setscene', 'e1:build', 'e1:package_qa', 'e1:package_write_rpm', 'e1:package_write_ipk', 'e1:packagedata'] 369 'e1:package_setscene', 'e1:populate_sysroot_setscene', 'e1:build', 'e1:package_qa', 'e1:package_write_rpm', 'e1:package_write_ipk', 'e1:packagedata']
364 self.assertEqual(set(tasks), set(expected)) 370 self.assertEqual(set(tasks), set(expected))
365 371
372 @unittest.skipIf(sys.version_info < (3, 5, 0), 'Python 3.5 or later required')
366 def test_hashserv_partial_match3(self): 373 def test_hashserv_partial_match3(self):
367 # e1:do_package is valid for a1 but not after b1 374 # e1:do_package is valid for a1 but not after b1
368 # In former buggy code, this triggered e1:do_fetch, then e1:do_populate_sysroot to run 375 # In former buggy code, this triggered e1:do_fetch, then e1:do_populate_sysroot to run
369 # with none of the intermediate tasks which is a serious bug 376 # with none of the intermediate tasks which is a serious bug
370 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir: 377 with tempfile.TemporaryDirectory(prefix="runqueuetest") as tempdir:
371 extraenv = { 378 extraenv = {
372 "BB_HASHSERVE" : "localhost:0", 379 "BB_HASHSERVE" : "auto",
373 "BB_SIGNATURE_HANDLER" : "TestEquivHash" 380 "BB_SIGNATURE_HANDLER" : "TestEquivHash"
374 } 381 }
375 cmd = ["bitbake", "a1", "b1"] 382 cmd = ["bitbake", "a1", "b1"]
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py
index eb03c32213..c3318620f5 100644
--- a/bitbake/lib/hashserv/__init__.py
+++ b/bitbake/lib/hashserv/__init__.py
@@ -3,203 +3,21 @@
3# SPDX-License-Identifier: GPL-2.0-only 3# SPDX-License-Identifier: GPL-2.0-only
4# 4#
5 5
6from http.server import BaseHTTPRequestHandler, HTTPServer 6from contextlib import closing
7import contextlib 7import re
8import urllib.parse
9import sqlite3 8import sqlite3
10import json
11import traceback
12import logging
13import socketserver
14import queue
15import threading
16import signal
17import socket
18import struct
19from datetime import datetime
20
21logger = logging.getLogger('hashserv')
22
23class HashEquivalenceServer(BaseHTTPRequestHandler):
24 def log_message(self, f, *args):
25 logger.debug(f, *args)
26
27 def opendb(self):
28 self.db = sqlite3.connect(self.dbname)
29 self.db.row_factory = sqlite3.Row
30 self.db.execute("PRAGMA synchronous = OFF;")
31 self.db.execute("PRAGMA journal_mode = MEMORY;")
32
33 def do_GET(self):
34 try:
35 if not self.db:
36 self.opendb()
37
38 p = urllib.parse.urlparse(self.path)
39
40 if p.path != self.prefix + '/v1/equivalent':
41 self.send_error(404)
42 return
43
44 query = urllib.parse.parse_qs(p.query, strict_parsing=True)
45 method = query['method'][0]
46 taskhash = query['taskhash'][0]
47
48 d = None
49 with contextlib.closing(self.db.cursor()) as cursor:
50 cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
51 {'method': method, 'taskhash': taskhash})
52
53 row = cursor.fetchone()
54
55 if row is not None:
56 logger.debug('Found equivalent task %s', row['taskhash'])
57 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
58
59 self.send_response(200)
60 self.send_header('Content-Type', 'application/json; charset=utf-8')
61 self.end_headers()
62 self.wfile.write(json.dumps(d).encode('utf-8'))
63 except:
64 logger.exception('Error in GET')
65 self.send_error(400, explain=traceback.format_exc())
66 return
67
68 def do_POST(self):
69 try:
70 if not self.db:
71 self.opendb()
72
73 p = urllib.parse.urlparse(self.path)
74
75 if p.path != self.prefix + '/v1/equivalent':
76 self.send_error(404)
77 return
78
79 length = int(self.headers['content-length'])
80 data = json.loads(self.rfile.read(length).decode('utf-8'))
81
82 with contextlib.closing(self.db.cursor()) as cursor:
83 cursor.execute('''
84 -- Find tasks with a matching outhash (that is, tasks that
85 -- are equivalent)
86 SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash
87
88 -- If there is an exact match on the taskhash, return it.
89 -- Otherwise return the oldest matching outhash of any
90 -- taskhash
91 ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
92 created ASC
93
94 -- Only return one row
95 LIMIT 1
96 ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})
97
98 row = cursor.fetchone()
99
100 # If no matching outhash was found, or one *was* found but it
101 # wasn't an exact match on the taskhash, a new entry for this
102 # taskhash should be added
103 if row is None or row['taskhash'] != data['taskhash']:
104 # If a row matching the outhash was found, the unihash for
105 # the new taskhash should be the same as that one.
106 # Otherwise the caller provided unihash is used.
107 unihash = data['unihash']
108 if row is not None:
109 unihash = row['unihash']
110
111 insert_data = {
112 'method': data['method'],
113 'outhash': data['outhash'],
114 'taskhash': data['taskhash'],
115 'unihash': unihash,
116 'created': datetime.now()
117 }
118
119 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
120 if k in data:
121 insert_data[k] = data[k]
122
123 cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
124 ', '.join(sorted(insert_data.keys())),
125 ', '.join(':' + k for k in sorted(insert_data.keys()))),
126 insert_data)
127
128 logger.info('Adding taskhash %s with unihash %s', data['taskhash'], unihash)
129
130 self.db.commit()
131 d = {'taskhash': data['taskhash'], 'method': data['method'], 'unihash': unihash}
132 else:
133 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
134
135 self.send_response(200)
136 self.send_header('Content-Type', 'application/json; charset=utf-8')
137 self.end_headers()
138 self.wfile.write(json.dumps(d).encode('utf-8'))
139 except:
140 logger.exception('Error in POST')
141 self.send_error(400, explain=traceback.format_exc())
142 return
143
144class ThreadedHTTPServer(HTTPServer):
145 quit = False
146
147 def serve_forever(self):
148 self.requestqueue = queue.Queue()
149 self.handlerthread = threading.Thread(target=self.process_request_thread)
150 self.handlerthread.daemon = False
151
152 self.handlerthread.start()
153
154 signal.signal(signal.SIGTERM, self.sigterm_exception)
155 super().serve_forever()
156 os._exit(0)
157
158 def sigterm_exception(self, signum, stackframe):
159 self.server_close()
160 os._exit(0)
161
162 def server_bind(self):
163 HTTPServer.server_bind(self)
164 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
165
166 def process_request_thread(self):
167 while not self.quit:
168 try:
169 (request, client_address) = self.requestqueue.get(True)
170 except queue.Empty:
171 continue
172 if request is None:
173 continue
174 try:
175 self.finish_request(request, client_address)
176 except Exception:
177 self.handle_error(request, client_address)
178 finally:
179 self.shutdown_request(request)
180 os._exit(0)
181
182 def process_request(self, request, client_address):
183 self.requestqueue.put((request, client_address))
184
185 def server_close(self):
186 super().server_close()
187 self.quit = True
188 self.requestqueue.put((None, None))
189 self.handlerthread.join()
190
191def create_server(addr, dbname, prefix=''):
192 class Handler(HashEquivalenceServer):
193 pass
194
195 db = sqlite3.connect(dbname)
196 db.row_factory = sqlite3.Row
197 9
198 Handler.prefix = prefix 10UNIX_PREFIX = "unix://"
199 Handler.db = None 11
200 Handler.dbname = dbname 12ADDR_TYPE_UNIX = 0
13ADDR_TYPE_TCP = 1
14
15
16def setup_database(database, sync=True):
17 db = sqlite3.connect(database)
18 db.row_factory = sqlite3.Row
201 19
202 with contextlib.closing(db.cursor()) as cursor: 20 with closing(db.cursor()) as cursor:
203 cursor.execute(''' 21 cursor.execute('''
204 CREATE TABLE IF NOT EXISTS tasks_v2 ( 22 CREATE TABLE IF NOT EXISTS tasks_v2 (
205 id INTEGER PRIMARY KEY AUTOINCREMENT, 23 id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -220,11 +38,56 @@ def create_server(addr, dbname, prefix=''):
220 UNIQUE(method, outhash, taskhash) 38 UNIQUE(method, outhash, taskhash)
221 ) 39 )
222 ''') 40 ''')
223 cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup ON tasks_v2 (method, taskhash)') 41 cursor.execute('PRAGMA journal_mode = WAL')
224 cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup ON tasks_v2 (method, outhash)') 42 cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
43
44 # Drop old indexes
45 cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
46 cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
47
48 # Create new indexes
49 cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)')
50 cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)')
51
52 return db
53
54
55def parse_address(addr):
56 if addr.startswith(UNIX_PREFIX):
57 return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
58 else:
59 m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
60 if m is not None:
61 host = m.group('host')
62 port = m.group('port')
63 else:
64 host, port = addr.split(':')
65
66 return (ADDR_TYPE_TCP, (host, int(port)))
67
68
69def create_server(addr, dbname, *, sync=True):
70 from . import server
71 db = setup_database(dbname, sync=sync)
72 s = server.Server(db)
73
74 (typ, a) = parse_address(addr)
75 if typ == ADDR_TYPE_UNIX:
76 s.start_unix_server(*a)
77 else:
78 s.start_tcp_server(*a)
79
80 return s
81
225 82
226 ret = ThreadedHTTPServer(addr, Handler) 83def create_client(addr):
84 from . import client
85 c = client.Client()
227 86
228 logger.info('Starting server on %s\n', ret.server_port) 87 (typ, a) = parse_address(addr)
88 if typ == ADDR_TYPE_UNIX:
89 c.connect_unix(*a)
90 else:
91 c.connect_tcp(*a)
229 92
230 return ret 93 return c
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
new file mode 100644
index 0000000000..2559bbb3fb
--- /dev/null
+++ b/bitbake/lib/hashserv/client.py
@@ -0,0 +1,156 @@
1# Copyright (C) 2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6from contextlib import closing
7import json
8import logging
9import socket
10
11
12logger = logging.getLogger('hashserv.client')
13
14
15class HashConnectionError(Exception):
16 pass
17
18
19class Client(object):
20 MODE_NORMAL = 0
21 MODE_GET_STREAM = 1
22
23 def __init__(self):
24 self._socket = None
25 self.reader = None
26 self.writer = None
27 self.mode = self.MODE_NORMAL
28
29 def connect_tcp(self, address, port):
30 def connect_sock():
31 s = socket.create_connection((address, port))
32
33 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
34 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
35 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
36 return s
37
38 self._connect_sock = connect_sock
39
40 def connect_unix(self, path):
41 def connect_sock():
42 s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
43 # AF_UNIX has path length issues so chdir here to workaround
44 cwd = os.getcwd()
45 try:
46 os.chdir(os.path.dirname(path))
47 s.connect(os.path.basename(path))
48 finally:
49 os.chdir(cwd)
50 return s
51
52 self._connect_sock = connect_sock
53
54 def connect(self):
55 if self._socket is None:
56 self._socket = self._connect_sock()
57
58 self.reader = self._socket.makefile('r', encoding='utf-8')
59 self.writer = self._socket.makefile('w', encoding='utf-8')
60
61 self.writer.write('OEHASHEQUIV 1.0\n\n')
62 self.writer.flush()
63
64 # Restore mode if the socket is being re-created
65 cur_mode = self.mode
66 self.mode = self.MODE_NORMAL
67 self._set_mode(cur_mode)
68
69 return self._socket
70
71 def close(self):
72 if self._socket is not None:
73 self._socket.close()
74 self._socket = None
75 self.reader = None
76 self.writer = None
77
78 def _send_wrapper(self, proc):
79 count = 0
80 while True:
81 try:
82 self.connect()
83 return proc()
84 except (OSError, HashConnectionError, json.JSONDecodeError, UnicodeDecodeError) as e:
85 logger.warning('Error talking to server: %s' % e)
86 if count >= 3:
87 if not isinstance(e, HashConnectionError):
88 raise HashConnectionError(str(e))
89 raise e
90 self.close()
91 count += 1
92
93 def send_message(self, msg):
94 def proc():
95 self.writer.write('%s\n' % json.dumps(msg))
96 self.writer.flush()
97
98 l = self.reader.readline()
99 if not l:
100 raise HashConnectionError('Connection closed')
101
102 if not l.endswith('\n'):
103 raise HashConnectionError('Bad message %r' % message)
104
105 return json.loads(l)
106
107 return self._send_wrapper(proc)
108
109 def send_stream(self, msg):
110 def proc():
111 self.writer.write("%s\n" % msg)
112 self.writer.flush()
113 l = self.reader.readline()
114 if not l:
115 raise HashConnectionError('Connection closed')
116 return l.rstrip()
117
118 return self._send_wrapper(proc)
119
120 def _set_mode(self, new_mode):
121 if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
122 r = self.send_stream('END')
123 if r != 'ok':
124 raise HashConnectionError('Bad response from server %r' % r)
125 elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
126 r = self.send_message({'get-stream': None})
127 if r != 'ok':
128 raise HashConnectionError('Bad response from server %r' % r)
129 elif new_mode != self.mode:
130 raise Exception('Undefined mode transition %r -> %r' % (self.mode, new_mode))
131
132 self.mode = new_mode
133
134 def get_unihash(self, method, taskhash):
135 self._set_mode(self.MODE_GET_STREAM)
136 r = self.send_stream('%s %s' % (method, taskhash))
137 if not r:
138 return None
139 return r
140
141 def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
142 self._set_mode(self.MODE_NORMAL)
143 m = extra.copy()
144 m['taskhash'] = taskhash
145 m['method'] = method
146 m['outhash'] = outhash
147 m['unihash'] = unihash
148 return self.send_message({'report': m})
149
150 def get_stats(self):
151 self._set_mode(self.MODE_NORMAL)
152 return self.send_message({'get-stats': None})
153
154 def reset_stats(self):
155 self._set_mode(self.MODE_NORMAL)
156 return self.send_message({'reset-stats': None})
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
new file mode 100644
index 0000000000..0aff77688e
--- /dev/null
+++ b/bitbake/lib/hashserv/server.py
@@ -0,0 +1,414 @@
1# Copyright (C) 2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6from contextlib import closing
7from datetime import datetime
8import asyncio
9import json
10import logging
11import math
12import os
13import signal
14import socket
15import time
16
17logger = logging.getLogger('hashserv.server')
18
19
20class Measurement(object):
21 def __init__(self, sample):
22 self.sample = sample
23
24 def start(self):
25 self.start_time = time.perf_counter()
26
27 def end(self):
28 self.sample.add(time.perf_counter() - self.start_time)
29
30 def __enter__(self):
31 self.start()
32 return self
33
34 def __exit__(self, *args, **kwargs):
35 self.end()
36
37
38class Sample(object):
39 def __init__(self, stats):
40 self.stats = stats
41 self.num_samples = 0
42 self.elapsed = 0
43
44 def measure(self):
45 return Measurement(self)
46
47 def __enter__(self):
48 return self
49
50 def __exit__(self, *args, **kwargs):
51 self.end()
52
53 def add(self, elapsed):
54 self.num_samples += 1
55 self.elapsed += elapsed
56
57 def end(self):
58 if self.num_samples:
59 self.stats.add(self.elapsed)
60 self.num_samples = 0
61 self.elapsed = 0
62
63
64class Stats(object):
65 def __init__(self):
66 self.reset()
67
68 def reset(self):
69 self.num = 0
70 self.total_time = 0
71 self.max_time = 0
72 self.m = 0
73 self.s = 0
74 self.current_elapsed = None
75
76 def add(self, elapsed):
77 self.num += 1
78 if self.num == 1:
79 self.m = elapsed
80 self.s = 0
81 else:
82 last_m = self.m
83 self.m = last_m + (elapsed - last_m) / self.num
84 self.s = self.s + (elapsed - last_m) * (elapsed - self.m)
85
86 self.total_time += elapsed
87
88 if self.max_time < elapsed:
89 self.max_time = elapsed
90
91 def start_sample(self):
92 return Sample(self)
93
94 @property
95 def average(self):
96 if self.num == 0:
97 return 0
98 return self.total_time / self.num
99
100 @property
101 def stdev(self):
102 if self.num <= 1:
103 return 0
104 return math.sqrt(self.s / (self.num - 1))
105
106 def todict(self):
107 return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
108
109
110class ServerClient(object):
111 def __init__(self, reader, writer, db, request_stats):
112 self.reader = reader
113 self.writer = writer
114 self.db = db
115 self.request_stats = request_stats
116
117 async def process_requests(self):
118 try:
119 self.addr = self.writer.get_extra_info('peername')
120 logger.debug('Client %r connected' % (self.addr,))
121
122 # Read protocol and version
123 protocol = await self.reader.readline()
124 if protocol is None:
125 return
126
127 (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
128 if proto_name != 'OEHASHEQUIV' or proto_version != '1.0':
129 return
130
131 # Read headers. Currently, no headers are implemented, so look for
132 # an empty line to signal the end of the headers
133 while True:
134 line = await self.reader.readline()
135 if line is None:
136 return
137
138 line = line.decode('utf-8').rstrip()
139 if not line:
140 break
141
142 # Handle messages
143 handlers = {
144 'get': self.handle_get,
145 'report': self.handle_report,
146 'get-stream': self.handle_get_stream,
147 'get-stats': self.handle_get_stats,
148 'reset-stats': self.handle_reset_stats,
149 }
150
151 while True:
152 d = await self.read_message()
153 if d is None:
154 break
155
156 for k in handlers.keys():
157 if k in d:
158 logger.debug('Handling %s' % k)
159 if 'stream' in k:
160 await handlers[k](d[k])
161 else:
162 with self.request_stats.start_sample() as self.request_sample, \
163 self.request_sample.measure():
164 await handlers[k](d[k])
165 break
166 else:
167 logger.warning("Unrecognized command %r" % d)
168 break
169
170 await self.writer.drain()
171 finally:
172 self.writer.close()
173
174 def write_message(self, msg):
175 self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))
176
177 async def read_message(self):
178 l = await self.reader.readline()
179 if not l:
180 return None
181
182 try:
183 message = l.decode('utf-8')
184
185 if not message.endswith('\n'):
186 return None
187
188 return json.loads(message)
189 except (json.JSONDecodeError, UnicodeDecodeError) as e:
190 logger.error('Bad message from client: %r' % message)
191 raise e
192
193 async def handle_get(self, request):
194 method = request['method']
195 taskhash = request['taskhash']
196
197 row = self.query_equivalent(method, taskhash)
198 if row is not None:
199 logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
200 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
201
202 self.write_message(d)
203 else:
204 self.write_message(None)
205
206 async def handle_get_stream(self, request):
207 self.write_message('ok')
208
209 while True:
210 l = await self.reader.readline()
211 if not l:
212 return
213
214 try:
215 # This inner loop is very sensitive and must be as fast as
216 # possible (which is why the request sample is handled manually
217 # instead of using 'with', and also why logging statements are
218 # commented out.
219 self.request_sample = self.request_stats.start_sample()
220 request_measure = self.request_sample.measure()
221 request_measure.start()
222
223 l = l.decode('utf-8').rstrip()
224 if l == 'END':
225 self.writer.write('ok\n'.encode('utf-8'))
226 return
227
228 (method, taskhash) = l.split()
229 #logger.debug('Looking up %s %s' % (method, taskhash))
230 row = self.query_equivalent(method, taskhash)
231 if row is not None:
232 msg = ('%s\n' % row['unihash']).encode('utf-8')
233 #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
234 else:
235 msg = '\n'.encode('utf-8')
236
237 self.writer.write(msg)
238 finally:
239 request_measure.end()
240 self.request_sample.end()
241
242 await self.writer.drain()
243
244 async def handle_report(self, data):
245 with closing(self.db.cursor()) as cursor:
246 cursor.execute('''
247 -- Find tasks with a matching outhash (that is, tasks that
248 -- are equivalent)
249 SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash
250
251 -- If there is an exact match on the taskhash, return it.
252 -- Otherwise return the oldest matching outhash of any
253 -- taskhash
254 ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
255 created ASC
256
257 -- Only return one row
258 LIMIT 1
259 ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})
260
261 row = cursor.fetchone()
262
263 # If no matching outhash was found, or one *was* found but it
264 # wasn't an exact match on the taskhash, a new entry for this
265 # taskhash should be added
266 if row is None or row['taskhash'] != data['taskhash']:
267 # If a row matching the outhash was found, the unihash for
268 # the new taskhash should be the same as that one.
269 # Otherwise the caller provided unihash is used.
270 unihash = data['unihash']
271 if row is not None:
272 unihash = row['unihash']
273
274 insert_data = {
275 'method': data['method'],
276 'outhash': data['outhash'],
277 'taskhash': data['taskhash'],
278 'unihash': unihash,
279 'created': datetime.now()
280 }
281
282 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
283 if k in data:
284 insert_data[k] = data[k]
285
286 cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
287 ', '.join(sorted(insert_data.keys())),
288 ', '.join(':' + k for k in sorted(insert_data.keys()))),
289 insert_data)
290
291 self.db.commit()
292
293 logger.info('Adding taskhash %s with unihash %s',
294 data['taskhash'], unihash)
295
296 d = {
297 'taskhash': data['taskhash'],
298 'method': data['method'],
299 'unihash': unihash
300 }
301 else:
302 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
303
304 self.write_message(d)
305
306 async def handle_get_stats(self, request):
307 d = {
308 'requests': self.request_stats.todict(),
309 }
310
311 self.write_message(d)
312
313 async def handle_reset_stats(self, request):
314 d = {
315 'requests': self.request_stats.todict(),
316 }
317
318 self.request_stats.reset()
319 self.write_message(d)
320
321 def query_equivalent(self, method, taskhash):
322 # This is part of the inner loop and must be as fast as possible
323 try:
324 cursor = self.db.cursor()
325 cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
326 {'method': method, 'taskhash': taskhash})
327 return cursor.fetchone()
328 except:
329 cursor.close()
330
331
332class Server(object):
333 def __init__(self, db, loop=None):
334 self.request_stats = Stats()
335 self.db = db
336
337 if loop is None:
338 self.loop = asyncio.new_event_loop()
339 self.close_loop = True
340 else:
341 self.loop = loop
342 self.close_loop = False
343
344 self._cleanup_socket = None
345
346 def start_tcp_server(self, host, port):
347 self.server = self.loop.run_until_complete(
348 asyncio.start_server(self.handle_client, host, port, loop=self.loop)
349 )
350
351 for s in self.server.sockets:
352 logger.info('Listening on %r' % (s.getsockname(),))
353 # Newer python does this automatically. Do it manually here for
354 # maximum compatibility
355 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
356 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
357
358 name = self.server.sockets[0].getsockname()
359 if self.server.sockets[0].family == socket.AF_INET6:
360 self.address = "[%s]:%d" % (name[0], name[1])
361 else:
362 self.address = "%s:%d" % (name[0], name[1])
363
364 def start_unix_server(self, path):
365 def cleanup():
366 os.unlink(path)
367
368 cwd = os.getcwd()
369 try:
370 # Work around path length limits in AF_UNIX
371 os.chdir(os.path.dirname(path))
372 self.server = self.loop.run_until_complete(
373 asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
374 )
375 finally:
376 os.chdir(cwd)
377
378 logger.info('Listening on %r' % path)
379
380 self._cleanup_socket = cleanup
381 self.address = "unix://%s" % os.path.abspath(path)
382
383 async def handle_client(self, reader, writer):
384 # writer.transport.set_write_buffer_limits(0)
385 try:
386 client = ServerClient(reader, writer, self.db, self.request_stats)
387 await client.process_requests()
388 except Exception as e:
389 import traceback
390 logger.error('Error from client: %s' % str(e), exc_info=True)
391 traceback.print_exc()
392 writer.close()
393 logger.info('Client disconnected')
394
395 def serve_forever(self):
396 def signal_handler():
397 self.loop.stop()
398
399 self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
400
401 try:
402 self.loop.run_forever()
403 except KeyboardInterrupt:
404 pass
405
406 self.server.close()
407 self.loop.run_until_complete(self.server.wait_closed())
408 logger.info('Server shutting down')
409
410 if self.close_loop:
411 self.loop.close()
412
413 if self._cleanup_socket is not None:
414 self._cleanup_socket()
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py
index 6845b53884..6584ff57b4 100644
--- a/bitbake/lib/hashserv/tests.py
+++ b/bitbake/lib/hashserv/tests.py
@@ -1,29 +1,40 @@
1#! /usr/bin/env python3 1#! /usr/bin/env python3
2# 2#
3# Copyright (C) 2018 Garmin Ltd. 3# Copyright (C) 2018-2019 Garmin Ltd.
4# 4#
5# SPDX-License-Identifier: GPL-2.0-only 5# SPDX-License-Identifier: GPL-2.0-only
6# 6#
7 7
8import unittest 8from . import create_server, create_client
9import multiprocessing
10import sqlite3
11import hashlib 9import hashlib
12import urllib.request 10import logging
13import json 11import multiprocessing
12import sys
14import tempfile 13import tempfile
15from . import create_server 14import threading
15import unittest
16
17
18class TestHashEquivalenceServer(object):
19 METHOD = 'TestMethod'
20
21 def _run_server(self):
22 # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
23 # format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
24 self.server.serve_forever()
16 25
17class TestHashEquivalenceServer(unittest.TestCase):
18 def setUp(self): 26 def setUp(self):
19 # Start a hash equivalence server in the background bound to 27 if sys.version_info < (3, 5, 0):
20 # an ephemeral port 28 self.skipTest('Python 3.5 or later required')
21 self.dbfile = tempfile.NamedTemporaryFile(prefix="bb-hashserv-db-") 29
22 self.server = create_server(('localhost', 0), self.dbfile.name) 30 self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv')
23 self.server_addr = 'http://localhost:%d' % self.server.socket.getsockname()[1] 31 self.dbfile = os.path.join(self.temp_dir.name, 'db.sqlite')
24 self.server_thread = multiprocessing.Process(target=self.server.serve_forever) 32
33 self.server = create_server(self.get_server_addr(), self.dbfile)
34 self.server_thread = multiprocessing.Process(target=self._run_server)
25 self.server_thread.daemon = True 35 self.server_thread.daemon = True
26 self.server_thread.start() 36 self.server_thread.start()
37 self.client = create_client(self.server.address)
27 38
28 def tearDown(self): 39 def tearDown(self):
29 # Shutdown server 40 # Shutdown server
@@ -31,19 +42,8 @@ class TestHashEquivalenceServer(unittest.TestCase):
31 if s is not None: 42 if s is not None:
32 self.server_thread.terminate() 43 self.server_thread.terminate()
33 self.server_thread.join() 44 self.server_thread.join()
34 45 self.client.close()
35 def send_get(self, path): 46 self.temp_dir.cleanup()
36 url = '%s/%s' % (self.server_addr, path)
37 request = urllib.request.Request(url)
38 response = urllib.request.urlopen(request)
39 return json.loads(response.read().decode('utf-8'))
40
41 def send_post(self, path, data):
42 headers = {'content-type': 'application/json'}
43 url = '%s/%s' % (self.server_addr, path)
44 request = urllib.request.Request(url, json.dumps(data).encode('utf-8'), headers)
45 response = urllib.request.urlopen(request)
46 return json.loads(response.read().decode('utf-8'))
47 47
48 def test_create_hash(self): 48 def test_create_hash(self):
49 # Simple test that hashes can be created 49 # Simple test that hashes can be created
@@ -51,16 +51,11 @@ class TestHashEquivalenceServer(unittest.TestCase):
51 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' 51 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
52 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' 52 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
53 53
54 d = self.send_get('v1/equivalent?method=TestMethod&taskhash=%s' % taskhash) 54 result = self.client.get_unihash(self.METHOD, taskhash)
55 self.assertIsNone(d, msg='Found unexpected task, %r' % d) 55 self.assertIsNone(result, msg='Found unexpected task, %r' % result)
56 56
57 d = self.send_post('v1/equivalent', { 57 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
58 'taskhash': taskhash, 58 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
59 'method': 'TestMethod',
60 'outhash': outhash,
61 'unihash': unihash,
62 })
63 self.assertEqual(d['unihash'], unihash, 'Server returned bad unihash')
64 59
65 def test_create_equivalent(self): 60 def test_create_equivalent(self):
66 # Tests that a second reported task with the same outhash will be 61 # Tests that a second reported task with the same outhash will be
@@ -68,25 +63,16 @@ class TestHashEquivalenceServer(unittest.TestCase):
68 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' 63 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
69 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' 64 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
70 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' 65 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
71 d = self.send_post('v1/equivalent', { 66
72 'taskhash': taskhash, 67 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
73 'method': 'TestMethod', 68 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
74 'outhash': outhash,
75 'unihash': unihash,
76 })
77 self.assertEqual(d['unihash'], unihash, 'Server returned bad unihash')
78 69
79 # Report a different task with the same outhash. The returned unihash 70 # Report a different task with the same outhash. The returned unihash
80 # should match the first task 71 # should match the first task
81 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' 72 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
82 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' 73 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
83 d = self.send_post('v1/equivalent', { 74 result = self.client.report_unihash(taskhash2, self.METHOD, outhash, unihash2)
84 'taskhash': taskhash2, 75 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
85 'method': 'TestMethod',
86 'outhash': outhash,
87 'unihash': unihash2,
88 })
89 self.assertEqual(d['unihash'], unihash, 'Server returned bad unihash')
90 76
91 def test_duplicate_taskhash(self): 77 def test_duplicate_taskhash(self):
92 # Tests that duplicate reports of the same taskhash with different 78 # Tests that duplicate reports of the same taskhash with different
@@ -95,38 +81,63 @@ class TestHashEquivalenceServer(unittest.TestCase):
95 taskhash = '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a' 81 taskhash = '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a'
96 outhash = 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e' 82 outhash = 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e'
97 unihash = '218e57509998197d570e2c98512d0105985dffc9' 83 unihash = '218e57509998197d570e2c98512d0105985dffc9'
98 d = self.send_post('v1/equivalent', { 84 self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
99 'taskhash': taskhash,
100 'method': 'TestMethod',
101 'outhash': outhash,
102 'unihash': unihash,
103 })
104 85
105 d = self.send_get('v1/equivalent?method=TestMethod&taskhash=%s' % taskhash) 86 result = self.client.get_unihash(self.METHOD, taskhash)
106 self.assertEqual(d['unihash'], unihash) 87 self.assertEqual(result, unihash)
107 88
108 outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d' 89 outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d'
109 unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c' 90 unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'
110 d = self.send_post('v1/equivalent', { 91 self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2)
111 'taskhash': taskhash,
112 'method': 'TestMethod',
113 'outhash': outhash2,
114 'unihash': unihash2
115 })
116 92
117 d = self.send_get('v1/equivalent?method=TestMethod&taskhash=%s' % taskhash) 93 result = self.client.get_unihash(self.METHOD, taskhash)
118 self.assertEqual(d['unihash'], unihash) 94 self.assertEqual(result, unihash)
119 95
120 outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' 96 outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
121 unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603' 97 unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603'
122 d = self.send_post('v1/equivalent', { 98 self.client.report_unihash(taskhash, self.METHOD, outhash3, unihash3)
123 'taskhash': taskhash, 99
124 'method': 'TestMethod', 100 result = self.client.get_unihash(self.METHOD, taskhash)
125 'outhash': outhash3, 101 self.assertEqual(result, unihash)
126 'unihash': unihash3 102
127 }) 103 def test_stress(self):
104 def query_server(failures):
105 client = Client(self.server.address)
106 try:
107 for i in range(1000):
108 taskhash = hashlib.sha256()
109 taskhash.update(str(i).encode('utf-8'))
110 taskhash = taskhash.hexdigest()
111 result = client.get_unihash(self.METHOD, taskhash)
112 if result != taskhash:
113 failures.append("taskhash mismatch: %s != %s" % (result, taskhash))
114 finally:
115 client.close()
116
117 # Report hashes
118 for i in range(1000):
119 taskhash = hashlib.sha256()
120 taskhash.update(str(i).encode('utf-8'))
121 taskhash = taskhash.hexdigest()
122 self.client.report_unihash(taskhash, self.METHOD, taskhash, taskhash)
123
124 failures = []
125 threads = [threading.Thread(target=query_server, args=(failures,)) for t in range(100)]
126
127 for t in threads:
128 t.start()
129
130 for t in threads:
131 t.join()
132
133 self.assertFalse(failures)
134
128 135
129 d = self.send_get('v1/equivalent?method=TestMethod&taskhash=%s' % taskhash) 136class TestHashEquivalenceUnixServer(TestHashEquivalenceServer, unittest.TestCase):
130 self.assertEqual(d['unihash'], unihash) 137 def get_server_addr(self):
138 return "unix://" + os.path.join(self.temp_dir.name, 'sock')
131 139
132 140
141class TestHashEquivalenceTCPServer(TestHashEquivalenceServer, unittest.TestCase):
142 def get_server_addr(self):
143 return "localhost:0"