author      Joshua Watt <JPEWhacker@gmail.com>                    2020-06-25 09:21:07 -0500
committer   Richard Purdie <richard.purdie@linuxfoundation.org>   2020-07-02 16:11:40 +0100
commit      6ebf01bfd43b6d95a70699b1e58a42fd7d1002a6 (patch)
tree        3f18afa2f1918dde70ced1013e5d859cc4c573e7
parent      b6e0f5889eb55d88276807407f75eaad9bf0a96a (diff)
download    poky-6ebf01bfd43b6d95a70699b1e58a42fd7d1002a6.tar.gz
bitbake: hashserv: Chunkify large messages
The hash equivalence client and server can occasionally send messages
that are too large to fit in the server's receive buffer (64 KB). To
prevent this, support is added to the protocol to "chunkify" the stream
and break it up into manageable pieces that each side can reassemble.

Ideally, the maximum chunk size would be negotiated between the client
and server, but it is currently hardcoded to 32 KB to avoid the
round-trip delay that negotiation would add to connection setup.
(Bitbake rev: 1a7bddb5471a02a744e7a441a3b4a6da693348b0)
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
(cherry picked from commit e27a28c1e40e886ee68ba4b99b537ffc9c3577d4)
Signed-off-by: Steve Sakoman <steve@sakoman.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
-rw-r--r--   bitbake/lib/hashserv/__init__.py |  22
-rw-r--r--   bitbake/lib/hashserv/client.py   |  43
-rw-r--r--   bitbake/lib/hashserv/server.py   | 105
-rw-r--r--   bitbake/lib/hashserv/tests.py    |  23
4 files changed, 152 insertions(+), 41 deletions(-)
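
On the wire, the change is a small framing layer: a message that fits in one
chunk is still sent as a single JSON line, while a larger one is announced by
a {"chunk-stream": null} header line, followed by fixed-size slices of the
JSON text (one per line) and a terminating empty line. A minimal receiver
sketch of that framing (illustrative only, not part of the patch; readline
stands in for whatever line source the caller has):

    import json

    def read_framed_message(readline):
        # 'readline' is a hypothetical callable returning the next
        # newline-terminated string from the peer.
        line = readline()
        msg = json.loads(line)
        if 'chunk-stream' in msg:
            # Chunked message: gather slices until the empty terminator
            # line, then parse the reassembled JSON document.
            parts = []
            while True:
                line = readline().rstrip('\n')
                if not line:
                    break
                parts.append(line)
            msg = json.loads(''.join(parts))
        return msg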
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py
index c3318620f5..f95e8f43f1 100644
--- a/bitbake/lib/hashserv/__init__.py
+++ b/bitbake/lib/hashserv/__init__.py
@@ -6,12 +6,20 @@
 from contextlib import closing
 import re
 import sqlite3
+import itertools
+import json
 
 UNIX_PREFIX = "unix://"
 
 ADDR_TYPE_UNIX = 0
 ADDR_TYPE_TCP = 1
 
+# The Python async server defaults to a 64K receive buffer, so we hardcode our
+# maximum chunk size. It would be better if the client and server reported to
+# each other what the maximum chunk sizes were, but that will slow down the
+# connection setup with a round trip delay so I'd rather not do that unless it
+# is necessary
+DEFAULT_MAX_CHUNK = 32 * 1024
 
 def setup_database(database, sync=True):
     db = sqlite3.connect(database)
@@ -66,6 +74,20 @@ def parse_address(addr):
         return (ADDR_TYPE_TCP, (host, int(port)))
 
 
+def chunkify(msg, max_chunk):
+    if len(msg) < max_chunk - 1:
+        yield ''.join((msg, "\n"))
+    else:
+        yield ''.join((json.dumps({
+                'chunk-stream': None
+            }), "\n"))
+
+        args = [iter(msg)] * (max_chunk - 1)
+        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
+            yield ''.join(itertools.chain(m, "\n"))
+        yield "\n"
+
+
 def create_server(addr, dbname, *, sync=True):
     from . import server
     db = setup_database(dbname, sync=sync)
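
The chunkify generator above is the entire producer side of that framing, so
a round trip is easy to check. A quick sketch (assumes bitbake/lib is on
sys.path so the patched package imports as hashserv):

    import json
    from hashserv import chunkify, DEFAULT_MAX_CHUNK

    payload = json.dumps({'outhash_siginfo': '0' * (DEFAULT_MAX_CHUNK * 4)})
    lines = list(chunkify(payload, DEFAULT_MAX_CHUNK))

    # An oversized message is announced with a chunk-stream header line...
    assert json.loads(lines[0]) == {'chunk-stream': None}
    # ...and the remaining lines, terminator included, reassemble to the
    # original JSON once trailing newlines are stripped.
    assert ''.join(l.rstrip('\n') for l in lines[1:]) == payload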
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
index 46085d6418..a29af836d9 100644
--- a/bitbake/lib/hashserv/client.py
+++ b/bitbake/lib/hashserv/client.py
@@ -7,6 +7,7 @@ import json
 import logging
 import socket
 import os
+from . import chunkify, DEFAULT_MAX_CHUNK
 
 
 logger = logging.getLogger('hashserv.client')
@@ -25,6 +26,7 @@ class Client(object):
         self.reader = None
         self.writer = None
         self.mode = self.MODE_NORMAL
+        self.max_chunk = DEFAULT_MAX_CHUNK
 
     def connect_tcp(self, address, port):
         def connect_sock():
@@ -58,7 +60,7 @@ class Client(object):
             self.reader = self._socket.makefile('r', encoding='utf-8')
             self.writer = self._socket.makefile('w', encoding='utf-8')
 
-            self.writer.write('OEHASHEQUIV 1.0\n\n')
+            self.writer.write('OEHASHEQUIV 1.1\n\n')
             self.writer.flush()
 
             # Restore mode if the socket is being re-created
@@ -91,18 +93,35 @@ class Client(object):
             count += 1
 
     def send_message(self, msg):
+        def get_line():
+            line = self.reader.readline()
+            if not line:
+                raise HashConnectionError('Connection closed')
+
+            if not line.endswith('\n'):
+                raise HashConnectionError('Bad message %r' % message)
+
+            return line
+
         def proc():
-            self.writer.write('%s\n' % json.dumps(msg))
+            for c in chunkify(json.dumps(msg), self.max_chunk):
+                self.writer.write(c)
             self.writer.flush()
 
-            l = self.reader.readline()
-            if not l:
-                raise HashConnectionError('Connection closed')
+            l = get_line()
 
-            if not l.endswith('\n'):
-                raise HashConnectionError('Bad message %r' % message)
+            m = json.loads(l)
+            if 'chunk-stream' in m:
+                lines = []
+                while True:
+                    l = get_line().rstrip('\n')
+                    if not l:
+                        break
+                    lines.append(l)
 
-            return json.loads(l)
+                m = json.loads(''.join(lines))
+
+            return m
 
         return self._send_wrapper(proc)
 
@@ -155,6 +174,14 @@ class Client(object):
         m['unihash'] = unihash
         return self.send_message({'report-equiv': m})
 
+    def get_taskhash(self, method, taskhash, all_properties=False):
+        self._set_mode(self.MODE_NORMAL)
+        return self.send_message({'get': {
+            'taskhash': taskhash,
+            'method': method,
+            'all': all_properties
+        }})
+
     def get_stats(self):
         self._set_mode(self.MODE_NORMAL)
         return self.send_message({'get-stats': None})
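
The new get_taskhash method is what makes the 'all' form of the 'get' request
reachable from the client. A hedged usage sketch (the socket path is made up;
the method and taskhash values are borrowed from the test added below):

    from hashserv.client import Client

    client = Client('unix:///tmp/hashserv.sock')  # hypothetical address
    # all_properties=True makes the server answer with ALL_QUERY, i.e. every
    # column including large ones such as outhash_siginfo -- exactly the
    # replies that could previously overflow the 64 KB receive buffer.
    row = client.get_taskhash('TestMethod',
                              'c665584ee6817aa99edfc77a44dd853828279370',
                              all_properties=True)
    if row is not None:
        print(row['unihash'], len(row.get('outhash_siginfo', '')))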
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index cc7e48233b..81050715ea 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -13,6 +13,7 @@ import os
 import signal
 import socket
 import time
+from . import chunkify, DEFAULT_MAX_CHUNK
 
 logger = logging.getLogger('hashserv.server')
 
@@ -107,12 +108,29 @@ class Stats(object):
         return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
 
 
+class ClientError(Exception):
+    pass
+
 class ServerClient(object):
+    FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
+    ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
+
     def __init__(self, reader, writer, db, request_stats):
         self.reader = reader
         self.writer = writer
         self.db = db
         self.request_stats = request_stats
+        self.max_chunk = DEFAULT_MAX_CHUNK
+
+        self.handlers = {
+            'get': self.handle_get,
+            'report': self.handle_report,
+            'report-equiv': self.handle_equivreport,
+            'get-stream': self.handle_get_stream,
+            'get-stats': self.handle_get_stats,
+            'reset-stats': self.handle_reset_stats,
+            'chunk-stream': self.handle_chunk,
+        }
 
     async def process_requests(self):
         try:
@@ -125,7 +143,11 @@ class ServerClient(object):
                 return
 
             (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
-            if proto_name != 'OEHASHEQUIV' or proto_version != '1.0':
+            if proto_name != 'OEHASHEQUIV':
+                return
+
+            proto_version = tuple(int(v) for v in proto_version.split('.'))
+            if proto_version < (1, 0) or proto_version > (1, 1):
                 return
 
             # Read headers. Currently, no headers are implemented, so look for
@@ -140,40 +162,34 @@ class ServerClient(object):
                     break
 
             # Handle messages
-            handlers = {
-                'get': self.handle_get,
-                'report': self.handle_report,
-                'report-equiv': self.handle_equivreport,
-                'get-stream': self.handle_get_stream,
-                'get-stats': self.handle_get_stats,
-                'reset-stats': self.handle_reset_stats,
-            }
-
             while True:
                 d = await self.read_message()
                 if d is None:
                     break
-
-                for k in handlers.keys():
-                    if k in d:
-                        logger.debug('Handling %s' % k)
-                        if 'stream' in k:
-                            await handlers[k](d[k])
-                        else:
-                            with self.request_stats.start_sample() as self.request_sample, \
-                                    self.request_sample.measure():
-                                await handlers[k](d[k])
-                        break
-                else:
-                    logger.warning("Unrecognized command %r" % d)
-                    break
-
+                await self.dispatch_message(d)
                 await self.writer.drain()
+        except ClientError as e:
+            logger.error(str(e))
         finally:
             self.writer.close()
 
+    async def dispatch_message(self, msg):
+        for k in self.handlers.keys():
+            if k in msg:
+                logger.debug('Handling %s' % k)
+                if 'stream' in k:
+                    await self.handlers[k](msg[k])
+                else:
+                    with self.request_stats.start_sample() as self.request_sample, \
+                            self.request_sample.measure():
+                        await self.handlers[k](msg[k])
+                return
+
+        raise ClientError("Unrecognized command %r" % msg)
+
     def write_message(self, msg):
-        self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))
+        for c in chunkify(json.dumps(msg), self.max_chunk):
+            self.writer.write(c.encode('utf-8'))
 
     async def read_message(self):
         l = await self.reader.readline()
@@ -191,14 +207,38 @@ class ServerClient(object):
             logger.error('Bad message from client: %r' % message)
             raise e
 
+    async def handle_chunk(self, request):
+        lines = []
+        try:
+            while True:
+                l = await self.reader.readline()
+                l = l.rstrip(b"\n").decode("utf-8")
+                if not l:
+                    break
+                lines.append(l)
+
+            msg = json.loads(''.join(lines))
+        except (json.JSONDecodeError, UnicodeDecodeError) as e:
+            logger.error('Bad message from client: %r' % message)
+            raise e
+
+        if 'chunk-stream' in msg:
+            raise ClientError("Nested chunks are not allowed")
+
+        await self.dispatch_message(msg)
+
     async def handle_get(self, request):
         method = request['method']
         taskhash = request['taskhash']
 
-        row = self.query_equivalent(method, taskhash)
+        if request.get('all', False):
+            row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
+        else:
+            row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
+
         if row is not None:
             logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
-            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+            d = {k: row[k] for k in row.keys()}
 
             self.write_message(d)
         else:
@@ -228,7 +268,7 @@ class ServerClient(object):
 
             (method, taskhash) = l.split()
             #logger.debug('Looking up %s %s' % (method, taskhash))
-            row = self.query_equivalent(method, taskhash)
+            row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
             if row is not None:
                 msg = ('%s\n' % row['unihash']).encode('utf-8')
                 #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
@@ -328,7 +368,7 @@ class ServerClient(object):
         # Fetch the unihash that will be reported for the taskhash. If the
         # unihash matches, it means this row was inserted (or the mapping
         # was already valid)
-        row = self.query_equivalent(data['method'], data['taskhash'])
+        row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)
 
         if row['unihash'] == data['unihash']:
             logger.info('Adding taskhash equivalence for %s with unihash %s',
@@ -354,12 +394,11 @@ class ServerClient(object):
         self.request_stats.reset()
         self.write_message(d)
 
-    def query_equivalent(self, method, taskhash):
+    def query_equivalent(self, method, taskhash, query):
         # This is part of the inner loop and must be as fast as possible
         try:
             cursor = self.db.cursor()
-            cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
-                {'method': method, 'taskhash': taskhash})
+            cursor.execute(query, {'method': method, 'taskhash': taskhash})
             return cursor.fetchone()
         except:
             cursor.close()
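
Besides chunking, the handshake change above is worth noting: the server now
parses the client's version string into a tuple and accepts the whole 1.0
through 1.1 range rather than string-matching '1.0'. The check in isolation
(an illustrative copy of the server's logic, not an import from it):

    def version_ok(proto_version):
        # Mirrors the server's parsing and range test.
        v = tuple(int(p) for p in proto_version.split('.'))
        return (1, 0) <= v <= (1, 1)

    assert version_ok('1.0')      # pre-chunking clients still connect
    assert version_ok('1.1')      # chunk-aware clients negotiate 1.1
    assert not version_ok('2.0')  # unknown major versions are refused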
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py
index a5472a996d..6e86295079 100644
--- a/bitbake/lib/hashserv/tests.py
+++ b/bitbake/lib/hashserv/tests.py
@@ -99,6 +99,29 @@ class TestHashEquivalenceServer(object):
         result = self.client.get_unihash(self.METHOD, taskhash)
         self.assertEqual(result, unihash)
 
+    def test_huge_message(self):
+        # Simple test that hashes can be created
+        taskhash = 'c665584ee6817aa99edfc77a44dd853828279370'
+        outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
+        unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824'
+
+        result = self.client.get_unihash(self.METHOD, taskhash)
+        self.assertIsNone(result, msg='Found unexpected task, %r' % result)
+
+        siginfo = "0" * (self.client.max_chunk * 4)
+
+        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash, {
+            'outhash_siginfo': siginfo
+        })
+        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
+
+        result = self.client.get_taskhash(self.METHOD, taskhash, True)
+        self.assertEqual(result['taskhash'], taskhash)
+        self.assertEqual(result['unihash'], unihash)
+        self.assertEqual(result['method'], self.METHOD)
+        self.assertEqual(result['outhash'], outhash)
+        self.assertEqual(result['outhash_siginfo'], siginfo)
+
     def test_stress(self):
         def query_server(failures):
             client = Client(self.server.address)
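
The new test pushes a siginfo four times the chunk size through
report_unihash and reads it back via get_taskhash, exercising chunking in
both directions. It should be runnable through bitbake's usual unit-test
entry point, along the lines of (invocation hedged; check
bitbake-selftest --help for the exact form):

    $ bitbake-selftest hashserv.tests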