diff options
-rw-r--r-- | bitbake/lib/hashserv/__init__.py | 22 | ||||
-rw-r--r-- | bitbake/lib/hashserv/client.py | 43 | ||||
-rw-r--r-- | bitbake/lib/hashserv/server.py | 105 | ||||
-rw-r--r-- | bitbake/lib/hashserv/tests.py | 23 |
4 files changed, 152 insertions, 41 deletions
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py index c3318620f5..f95e8f43f1 100644 --- a/bitbake/lib/hashserv/__init__.py +++ b/bitbake/lib/hashserv/__init__.py | |||
@@ -6,12 +6,20 @@ | |||
6 | from contextlib import closing | 6 | from contextlib import closing |
7 | import re | 7 | import re |
8 | import sqlite3 | 8 | import sqlite3 |
9 | import itertools | ||
10 | import json | ||
9 | 11 | ||
10 | UNIX_PREFIX = "unix://" | 12 | UNIX_PREFIX = "unix://" |
11 | 13 | ||
12 | ADDR_TYPE_UNIX = 0 | 14 | ADDR_TYPE_UNIX = 0 |
13 | ADDR_TYPE_TCP = 1 | 15 | ADDR_TYPE_TCP = 1 |
14 | 16 | ||
17 | # The Python async server defaults to a 64K receive buffer, so we hardcode our | ||
18 | # maximum chunk size. It would be better if the client and server reported to | ||
19 | # each other what the maximum chunk sizes were, but that will slow down the | ||
20 | # connection setup with a round trip delay so I'd rather not do that unless it | ||
21 | # is necessary | ||
22 | DEFAULT_MAX_CHUNK = 32 * 1024 | ||
15 | 23 | ||
16 | def setup_database(database, sync=True): | 24 | def setup_database(database, sync=True): |
17 | db = sqlite3.connect(database) | 25 | db = sqlite3.connect(database) |
@@ -66,6 +74,20 @@ def parse_address(addr): | |||
66 | return (ADDR_TYPE_TCP, (host, int(port))) | 74 | return (ADDR_TYPE_TCP, (host, int(port))) |
67 | 75 | ||
68 | 76 | ||
77 | def chunkify(msg, max_chunk): | ||
78 | if len(msg) < max_chunk - 1: | ||
79 | yield ''.join((msg, "\n")) | ||
80 | else: | ||
81 | yield ''.join((json.dumps({ | ||
82 | 'chunk-stream': None | ||
83 | }), "\n")) | ||
84 | |||
85 | args = [iter(msg)] * (max_chunk - 1) | ||
86 | for m in map(''.join, itertools.zip_longest(*args, fillvalue='')): | ||
87 | yield ''.join(itertools.chain(m, "\n")) | ||
88 | yield "\n" | ||
89 | |||
90 | |||
69 | def create_server(addr, dbname, *, sync=True): | 91 | def create_server(addr, dbname, *, sync=True): |
70 | from . import server | 92 | from . import server |
71 | db = setup_database(dbname, sync=sync) | 93 | db = setup_database(dbname, sync=sync) |
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py index 46085d6418..a29af836d9 100644 --- a/bitbake/lib/hashserv/client.py +++ b/bitbake/lib/hashserv/client.py | |||
@@ -7,6 +7,7 @@ import json | |||
7 | import logging | 7 | import logging |
8 | import socket | 8 | import socket |
9 | import os | 9 | import os |
10 | from . import chunkify, DEFAULT_MAX_CHUNK | ||
10 | 11 | ||
11 | 12 | ||
12 | logger = logging.getLogger('hashserv.client') | 13 | logger = logging.getLogger('hashserv.client') |
@@ -25,6 +26,7 @@ class Client(object): | |||
25 | self.reader = None | 26 | self.reader = None |
26 | self.writer = None | 27 | self.writer = None |
27 | self.mode = self.MODE_NORMAL | 28 | self.mode = self.MODE_NORMAL |
29 | self.max_chunk = DEFAULT_MAX_CHUNK | ||
28 | 30 | ||
29 | def connect_tcp(self, address, port): | 31 | def connect_tcp(self, address, port): |
30 | def connect_sock(): | 32 | def connect_sock(): |
@@ -58,7 +60,7 @@ class Client(object): | |||
58 | self.reader = self._socket.makefile('r', encoding='utf-8') | 60 | self.reader = self._socket.makefile('r', encoding='utf-8') |
59 | self.writer = self._socket.makefile('w', encoding='utf-8') | 61 | self.writer = self._socket.makefile('w', encoding='utf-8') |
60 | 62 | ||
61 | self.writer.write('OEHASHEQUIV 1.0\n\n') | 63 | self.writer.write('OEHASHEQUIV 1.1\n\n') |
62 | self.writer.flush() | 64 | self.writer.flush() |
63 | 65 | ||
64 | # Restore mode if the socket is being re-created | 66 | # Restore mode if the socket is being re-created |
@@ -91,18 +93,35 @@ class Client(object): | |||
91 | count += 1 | 93 | count += 1 |
92 | 94 | ||
93 | def send_message(self, msg): | 95 | def send_message(self, msg): |
96 | def get_line(): | ||
97 | line = self.reader.readline() | ||
98 | if not line: | ||
99 | raise HashConnectionError('Connection closed') | ||
100 | |||
101 | if not line.endswith('\n'): | ||
102 | raise HashConnectionError('Bad message %r' % message) | ||
103 | |||
104 | return line | ||
105 | |||
94 | def proc(): | 106 | def proc(): |
95 | self.writer.write('%s\n' % json.dumps(msg)) | 107 | for c in chunkify(json.dumps(msg), self.max_chunk): |
108 | self.writer.write(c) | ||
96 | self.writer.flush() | 109 | self.writer.flush() |
97 | 110 | ||
98 | l = self.reader.readline() | 111 | l = get_line() |
99 | if not l: | ||
100 | raise HashConnectionError('Connection closed') | ||
101 | 112 | ||
102 | if not l.endswith('\n'): | 113 | m = json.loads(l) |
103 | raise HashConnectionError('Bad message %r' % message) | 114 | if 'chunk-stream' in m: |
115 | lines = [] | ||
116 | while True: | ||
117 | l = get_line().rstrip('\n') | ||
118 | if not l: | ||
119 | break | ||
120 | lines.append(l) | ||
104 | 121 | ||
105 | return json.loads(l) | 122 | m = json.loads(''.join(lines)) |
123 | |||
124 | return m | ||
106 | 125 | ||
107 | return self._send_wrapper(proc) | 126 | return self._send_wrapper(proc) |
108 | 127 | ||
@@ -155,6 +174,14 @@ class Client(object): | |||
155 | m['unihash'] = unihash | 174 | m['unihash'] = unihash |
156 | return self.send_message({'report-equiv': m}) | 175 | return self.send_message({'report-equiv': m}) |
157 | 176 | ||
177 | def get_taskhash(self, method, taskhash, all_properties=False): | ||
178 | self._set_mode(self.MODE_NORMAL) | ||
179 | return self.send_message({'get': { | ||
180 | 'taskhash': taskhash, | ||
181 | 'method': method, | ||
182 | 'all': all_properties | ||
183 | }}) | ||
184 | |||
158 | def get_stats(self): | 185 | def get_stats(self): |
159 | self._set_mode(self.MODE_NORMAL) | 186 | self._set_mode(self.MODE_NORMAL) |
160 | return self.send_message({'get-stats': None}) | 187 | return self.send_message({'get-stats': None}) |
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py index cc7e48233b..81050715ea 100644 --- a/bitbake/lib/hashserv/server.py +++ b/bitbake/lib/hashserv/server.py | |||
@@ -13,6 +13,7 @@ import os | |||
13 | import signal | 13 | import signal |
14 | import socket | 14 | import socket |
15 | import time | 15 | import time |
16 | from . import chunkify, DEFAULT_MAX_CHUNK | ||
16 | 17 | ||
17 | logger = logging.getLogger('hashserv.server') | 18 | logger = logging.getLogger('hashserv.server') |
18 | 19 | ||
@@ -107,12 +108,29 @@ class Stats(object): | |||
107 | return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} | 108 | return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} |
108 | 109 | ||
109 | 110 | ||
111 | class ClientError(Exception): | ||
112 | pass | ||
113 | |||
110 | class ServerClient(object): | 114 | class ServerClient(object): |
115 | FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | ||
116 | ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | ||
117 | |||
111 | def __init__(self, reader, writer, db, request_stats): | 118 | def __init__(self, reader, writer, db, request_stats): |
112 | self.reader = reader | 119 | self.reader = reader |
113 | self.writer = writer | 120 | self.writer = writer |
114 | self.db = db | 121 | self.db = db |
115 | self.request_stats = request_stats | 122 | self.request_stats = request_stats |
123 | self.max_chunk = DEFAULT_MAX_CHUNK | ||
124 | |||
125 | self.handlers = { | ||
126 | 'get': self.handle_get, | ||
127 | 'report': self.handle_report, | ||
128 | 'report-equiv': self.handle_equivreport, | ||
129 | 'get-stream': self.handle_get_stream, | ||
130 | 'get-stats': self.handle_get_stats, | ||
131 | 'reset-stats': self.handle_reset_stats, | ||
132 | 'chunk-stream': self.handle_chunk, | ||
133 | } | ||
116 | 134 | ||
117 | async def process_requests(self): | 135 | async def process_requests(self): |
118 | try: | 136 | try: |
@@ -125,7 +143,11 @@ class ServerClient(object): | |||
125 | return | 143 | return |
126 | 144 | ||
127 | (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() | 145 | (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() |
128 | if proto_name != 'OEHASHEQUIV' or proto_version != '1.0': | 146 | if proto_name != 'OEHASHEQUIV': |
147 | return | ||
148 | |||
149 | proto_version = tuple(int(v) for v in proto_version.split('.')) | ||
150 | if proto_version < (1, 0) or proto_version > (1, 1): | ||
129 | return | 151 | return |
130 | 152 | ||
131 | # Read headers. Currently, no headers are implemented, so look for | 153 | # Read headers. Currently, no headers are implemented, so look for |
@@ -140,40 +162,34 @@ class ServerClient(object): | |||
140 | break | 162 | break |
141 | 163 | ||
142 | # Handle messages | 164 | # Handle messages |
143 | handlers = { | ||
144 | 'get': self.handle_get, | ||
145 | 'report': self.handle_report, | ||
146 | 'report-equiv': self.handle_equivreport, | ||
147 | 'get-stream': self.handle_get_stream, | ||
148 | 'get-stats': self.handle_get_stats, | ||
149 | 'reset-stats': self.handle_reset_stats, | ||
150 | } | ||
151 | |||
152 | while True: | 165 | while True: |
153 | d = await self.read_message() | 166 | d = await self.read_message() |
154 | if d is None: | 167 | if d is None: |
155 | break | 168 | break |
156 | 169 | await self.dispatch_message(d) | |
157 | for k in handlers.keys(): | ||
158 | if k in d: | ||
159 | logger.debug('Handling %s' % k) | ||
160 | if 'stream' in k: | ||
161 | await handlers[k](d[k]) | ||
162 | else: | ||
163 | with self.request_stats.start_sample() as self.request_sample, \ | ||
164 | self.request_sample.measure(): | ||
165 | await handlers[k](d[k]) | ||
166 | break | ||
167 | else: | ||
168 | logger.warning("Unrecognized command %r" % d) | ||
169 | break | ||
170 | |||
171 | await self.writer.drain() | 170 | await self.writer.drain() |
171 | except ClientError as e: | ||
172 | logger.error(str(e)) | ||
172 | finally: | 173 | finally: |
173 | self.writer.close() | 174 | self.writer.close() |
174 | 175 | ||
176 | async def dispatch_message(self, msg): | ||
177 | for k in self.handlers.keys(): | ||
178 | if k in msg: | ||
179 | logger.debug('Handling %s' % k) | ||
180 | if 'stream' in k: | ||
181 | await self.handlers[k](msg[k]) | ||
182 | else: | ||
183 | with self.request_stats.start_sample() as self.request_sample, \ | ||
184 | self.request_sample.measure(): | ||
185 | await self.handlers[k](msg[k]) | ||
186 | return | ||
187 | |||
188 | raise ClientError("Unrecognized command %r" % msg) | ||
189 | |||
175 | def write_message(self, msg): | 190 | def write_message(self, msg): |
176 | self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8')) | 191 | for c in chunkify(json.dumps(msg), self.max_chunk): |
192 | self.writer.write(c.encode('utf-8')) | ||
177 | 193 | ||
178 | async def read_message(self): | 194 | async def read_message(self): |
179 | l = await self.reader.readline() | 195 | l = await self.reader.readline() |
@@ -191,14 +207,38 @@ class ServerClient(object): | |||
191 | logger.error('Bad message from client: %r' % message) | 207 | logger.error('Bad message from client: %r' % message) |
192 | raise e | 208 | raise e |
193 | 209 | ||
210 | async def handle_chunk(self, request): | ||
211 | lines = [] | ||
212 | try: | ||
213 | while True: | ||
214 | l = await self.reader.readline() | ||
215 | l = l.rstrip(b"\n").decode("utf-8") | ||
216 | if not l: | ||
217 | break | ||
218 | lines.append(l) | ||
219 | |||
220 | msg = json.loads(''.join(lines)) | ||
221 | except (json.JSONDecodeError, UnicodeDecodeError) as e: | ||
222 | logger.error('Bad message from client: %r' % message) | ||
223 | raise e | ||
224 | |||
225 | if 'chunk-stream' in msg: | ||
226 | raise ClientError("Nested chunks are not allowed") | ||
227 | |||
228 | await self.dispatch_message(msg) | ||
229 | |||
194 | async def handle_get(self, request): | 230 | async def handle_get(self, request): |
195 | method = request['method'] | 231 | method = request['method'] |
196 | taskhash = request['taskhash'] | 232 | taskhash = request['taskhash'] |
197 | 233 | ||
198 | row = self.query_equivalent(method, taskhash) | 234 | if request.get('all', False): |
235 | row = self.query_equivalent(method, taskhash, self.ALL_QUERY) | ||
236 | else: | ||
237 | row = self.query_equivalent(method, taskhash, self.FAST_QUERY) | ||
238 | |||
199 | if row is not None: | 239 | if row is not None: |
200 | logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | 240 | logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) |
201 | d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} | 241 | d = {k: row[k] for k in row.keys()} |
202 | 242 | ||
203 | self.write_message(d) | 243 | self.write_message(d) |
204 | else: | 244 | else: |
@@ -228,7 +268,7 @@ class ServerClient(object): | |||
228 | 268 | ||
229 | (method, taskhash) = l.split() | 269 | (method, taskhash) = l.split() |
230 | #logger.debug('Looking up %s %s' % (method, taskhash)) | 270 | #logger.debug('Looking up %s %s' % (method, taskhash)) |
231 | row = self.query_equivalent(method, taskhash) | 271 | row = self.query_equivalent(method, taskhash, self.FAST_QUERY) |
232 | if row is not None: | 272 | if row is not None: |
233 | msg = ('%s\n' % row['unihash']).encode('utf-8') | 273 | msg = ('%s\n' % row['unihash']).encode('utf-8') |
234 | #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | 274 | #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) |
@@ -328,7 +368,7 @@ class ServerClient(object): | |||
328 | # Fetch the unihash that will be reported for the taskhash. If the | 368 | # Fetch the unihash that will be reported for the taskhash. If the |
329 | # unihash matches, it means this row was inserted (or the mapping | 369 | # unihash matches, it means this row was inserted (or the mapping |
330 | # was already valid) | 370 | # was already valid) |
331 | row = self.query_equivalent(data['method'], data['taskhash']) | 371 | row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY) |
332 | 372 | ||
333 | if row['unihash'] == data['unihash']: | 373 | if row['unihash'] == data['unihash']: |
334 | logger.info('Adding taskhash equivalence for %s with unihash %s', | 374 | logger.info('Adding taskhash equivalence for %s with unihash %s', |
@@ -354,12 +394,11 @@ class ServerClient(object): | |||
354 | self.request_stats.reset() | 394 | self.request_stats.reset() |
355 | self.write_message(d) | 395 | self.write_message(d) |
356 | 396 | ||
357 | def query_equivalent(self, method, taskhash): | 397 | def query_equivalent(self, method, taskhash, query): |
358 | # This is part of the inner loop and must be as fast as possible | 398 | # This is part of the inner loop and must be as fast as possible |
359 | try: | 399 | try: |
360 | cursor = self.db.cursor() | 400 | cursor = self.db.cursor() |
361 | cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1', | 401 | cursor.execute(query, {'method': method, 'taskhash': taskhash}) |
362 | {'method': method, 'taskhash': taskhash}) | ||
363 | return cursor.fetchone() | 402 | return cursor.fetchone() |
364 | except: | 403 | except: |
365 | cursor.close() | 404 | cursor.close() |
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py index a5472a996d..6e86295079 100644 --- a/bitbake/lib/hashserv/tests.py +++ b/bitbake/lib/hashserv/tests.py | |||
@@ -99,6 +99,29 @@ class TestHashEquivalenceServer(object): | |||
99 | result = self.client.get_unihash(self.METHOD, taskhash) | 99 | result = self.client.get_unihash(self.METHOD, taskhash) |
100 | self.assertEqual(result, unihash) | 100 | self.assertEqual(result, unihash) |
101 | 101 | ||
102 | def test_huge_message(self): | ||
103 | # Simple test that hashes can be created | ||
104 | taskhash = 'c665584ee6817aa99edfc77a44dd853828279370' | ||
105 | outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' | ||
106 | unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824' | ||
107 | |||
108 | result = self.client.get_unihash(self.METHOD, taskhash) | ||
109 | self.assertIsNone(result, msg='Found unexpected task, %r' % result) | ||
110 | |||
111 | siginfo = "0" * (self.client.max_chunk * 4) | ||
112 | |||
113 | result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash, { | ||
114 | 'outhash_siginfo': siginfo | ||
115 | }) | ||
116 | self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') | ||
117 | |||
118 | result = self.client.get_taskhash(self.METHOD, taskhash, True) | ||
119 | self.assertEqual(result['taskhash'], taskhash) | ||
120 | self.assertEqual(result['unihash'], unihash) | ||
121 | self.assertEqual(result['method'], self.METHOD) | ||
122 | self.assertEqual(result['outhash'], outhash) | ||
123 | self.assertEqual(result['outhash_siginfo'], siginfo) | ||
124 | |||
102 | def test_stress(self): | 125 | def test_stress(self): |
103 | def query_server(failures): | 126 | def query_server(failures): |
104 | client = Client(self.server.address) | 127 | client = Client(self.server.address) |