diff options
Diffstat (limited to 'bitbake/lib/hashserv')
-rw-r--r-- | bitbake/lib/hashserv/__init__.py | 4 | ||||
-rw-r--r-- | bitbake/lib/hashserv/server.py | 25 | ||||
-rw-r--r-- | bitbake/lib/hashserv/tests.py | 33 |
3 files changed, 51 insertions, 11 deletions
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py index 55f48410d3..5f2e101e52 100644 --- a/bitbake/lib/hashserv/__init__.py +++ b/bitbake/lib/hashserv/__init__.py | |||
@@ -94,10 +94,10 @@ def chunkify(msg, max_chunk): | |||
94 | yield "\n" | 94 | yield "\n" |
95 | 95 | ||
96 | 96 | ||
97 | def create_server(addr, dbname, *, sync=True, upstream=None): | 97 | def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False): |
98 | from . import server | 98 | from . import server |
99 | db = setup_database(dbname, sync=sync) | 99 | db = setup_database(dbname, sync=sync) |
100 | s = server.Server(db, upstream=upstream) | 100 | s = server.Server(db, upstream=upstream, read_only=read_only) |
101 | 101 | ||
102 | (typ, a) = parse_address(addr) | 102 | (typ, a) = parse_address(addr) |
103 | if typ == ADDR_TYPE_UNIX: | 103 | if typ == ADDR_TYPE_UNIX: |
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py index 3ff4c51ccb..2770c23607 100644 --- a/bitbake/lib/hashserv/server.py +++ b/bitbake/lib/hashserv/server.py | |||
@@ -112,6 +112,9 @@ class Stats(object): | |||
112 | class ClientError(Exception): | 112 | class ClientError(Exception): |
113 | pass | 113 | pass |
114 | 114 | ||
115 | class ServerError(Exception): | ||
116 | pass | ||
117 | |||
115 | def insert_task(cursor, data, ignore=False): | 118 | def insert_task(cursor, data, ignore=False): |
116 | keys = sorted(data.keys()) | 119 | keys = sorted(data.keys()) |
117 | query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % ( | 120 | query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % ( |
@@ -138,7 +141,7 @@ class ServerClient(object): | |||
138 | FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | 141 | FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' |
139 | ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | 142 | ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' |
140 | 143 | ||
141 | def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream): | 144 | def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only): |
142 | self.reader = reader | 145 | self.reader = reader |
143 | self.writer = writer | 146 | self.writer = writer |
144 | self.db = db | 147 | self.db = db |
@@ -149,15 +152,19 @@ class ServerClient(object): | |||
149 | 152 | ||
150 | self.handlers = { | 153 | self.handlers = { |
151 | 'get': self.handle_get, | 154 | 'get': self.handle_get, |
152 | 'report': self.handle_report, | ||
153 | 'report-equiv': self.handle_equivreport, | ||
154 | 'get-stream': self.handle_get_stream, | 155 | 'get-stream': self.handle_get_stream, |
155 | 'get-stats': self.handle_get_stats, | 156 | 'get-stats': self.handle_get_stats, |
156 | 'reset-stats': self.handle_reset_stats, | ||
157 | 'chunk-stream': self.handle_chunk, | 157 | 'chunk-stream': self.handle_chunk, |
158 | 'backfill-wait': self.handle_backfill_wait, | ||
159 | } | 158 | } |
160 | 159 | ||
160 | if not read_only: | ||
161 | self.handlers.update({ | ||
162 | 'report': self.handle_report, | ||
163 | 'report-equiv': self.handle_equivreport, | ||
164 | 'reset-stats': self.handle_reset_stats, | ||
165 | 'backfill-wait': self.handle_backfill_wait, | ||
166 | }) | ||
167 | |||
161 | async def process_requests(self): | 168 | async def process_requests(self): |
162 | if self.upstream is not None: | 169 | if self.upstream is not None: |
163 | self.upstream_client = await create_async_client(self.upstream) | 170 | self.upstream_client = await create_async_client(self.upstream) |
@@ -455,7 +462,10 @@ class ServerClient(object): | |||
455 | 462 | ||
456 | 463 | ||
457 | class Server(object): | 464 | class Server(object): |
458 | def __init__(self, db, loop=None, upstream=None): | 465 | def __init__(self, db, loop=None, upstream=None, read_only=False): |
466 | if upstream and read_only: | ||
467 | raise ServerError("Read-only hashserv cannot pull from an upstream server") | ||
468 | |||
459 | self.request_stats = Stats() | 469 | self.request_stats = Stats() |
460 | self.db = db | 470 | self.db = db |
461 | 471 | ||
@@ -467,6 +477,7 @@ class Server(object): | |||
467 | self.close_loop = False | 477 | self.close_loop = False |
468 | 478 | ||
469 | self.upstream = upstream | 479 | self.upstream = upstream |
480 | self.read_only = read_only | ||
470 | 481 | ||
471 | self._cleanup_socket = None | 482 | self._cleanup_socket = None |
472 | 483 | ||
@@ -510,7 +521,7 @@ class Server(object): | |||
510 | async def handle_client(self, reader, writer): | 521 | async def handle_client(self, reader, writer): |
511 | # writer.transport.set_write_buffer_limits(0) | 522 | # writer.transport.set_write_buffer_limits(0) |
512 | try: | 523 | try: |
513 | client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream) | 524 | client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only) |
514 | await client.process_requests() | 525 | await client.process_requests() |
515 | except Exception as e: | 526 | except Exception as e: |
516 | import traceback | 527 | import traceback |
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py index 77a19b8077..6f04e30d61 100644 --- a/bitbake/lib/hashserv/tests.py +++ b/bitbake/lib/hashserv/tests.py | |||
@@ -6,6 +6,7 @@ | |||
6 | # | 6 | # |
7 | 7 | ||
8 | from . import create_server, create_client | 8 | from . import create_server, create_client |
9 | from .client import HashConnectionError | ||
9 | import hashlib | 10 | import hashlib |
10 | import logging | 11 | import logging |
11 | import multiprocessing | 12 | import multiprocessing |
@@ -29,7 +30,7 @@ class HashEquivalenceTestSetup(object): | |||
29 | 30 | ||
30 | server_index = 0 | 31 | server_index = 0 |
31 | 32 | ||
32 | def start_server(self, dbpath=None, upstream=None): | 33 | def start_server(self, dbpath=None, upstream=None, read_only=False): |
33 | self.server_index += 1 | 34 | self.server_index += 1 |
34 | if dbpath is None: | 35 | if dbpath is None: |
35 | dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) | 36 | dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) |
@@ -38,7 +39,10 @@ class HashEquivalenceTestSetup(object): | |||
38 | thread.terminate() | 39 | thread.terminate() |
39 | thread.join() | 40 | thread.join() |
40 | 41 | ||
41 | server = create_server(self.get_server_addr(self.server_index), dbpath, upstream=upstream) | 42 | server = create_server(self.get_server_addr(self.server_index), |
43 | dbpath, | ||
44 | upstream=upstream, | ||
45 | read_only=read_only) | ||
42 | server.dbpath = dbpath | 46 | server.dbpath = dbpath |
43 | 47 | ||
44 | server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) | 48 | server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) |
@@ -242,6 +246,31 @@ class HashEquivalenceCommonTests(object): | |||
242 | self.assertClientGetHash(side_client, taskhash4, unihash4) | 246 | self.assertClientGetHash(side_client, taskhash4, unihash4) |
243 | self.assertClientGetHash(self.client, taskhash4, None) | 247 | self.assertClientGetHash(self.client, taskhash4, None) |
244 | 248 | ||
249 | def test_ro_server(self): | ||
250 | (ro_client, ro_server) = self.start_server(dbpath=self.server.dbpath, read_only=True) | ||
251 | |||
252 | # Report a hash via the read-write server | ||
253 | taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' | ||
254 | outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' | ||
255 | unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' | ||
256 | |||
257 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) | ||
258 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | ||
259 | |||
260 | # Check the hash via the read-only server | ||
261 | self.assertClientGetHash(ro_client, taskhash, unihash) | ||
262 | |||
263 | # Ensure that reporting via the read-only server fails | ||
264 | taskhash2 = 'c665584ee6817aa99edfc77a44dd853828279370' | ||
265 | outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' | ||
266 | unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824' | ||
267 | |||
268 | with self.assertRaises(HashConnectionError): | ||
269 | ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) | ||
270 | |||
271 | # Ensure that the database was not modified | ||
272 | self.assertClientGetHash(self.client, taskhash2, None) | ||
273 | |||
245 | 274 | ||
246 | class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): | 275 | class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): |
247 | def get_server_addr(self, server_idx): | 276 | def get_server_addr(self, server_idx): |