author    Joshua Watt <JPEWhacker@gmail.com>    2020-06-25 09:21:07 -0500
committer Richard Purdie <richard.purdie@linuxfoundation.org>    2020-07-02 16:11:40 +0100
commit    6ebf01bfd43b6d95a70699b1e58a42fd7d1002a6 (patch)
tree      3f18afa2f1918dde70ced1013e5d859cc4c573e7 /bitbake
parent    b6e0f5889eb55d88276807407f75eaad9bf0a96a (diff)
download  poky-6ebf01bfd43b6d95a70699b1e58a42fd7d1002a6.tar.gz
bitbake: hashserv: Chunkify large messages
The hash equivalence client and server can occasionally send messages that are too large for the server to fit in the receive buffer (64 KB). To prevent this, support is added to the protocol to "chunkify" the stream and break it up into manageable pieces that each side can reassemble. Ideally, the maximum chunk size would be negotiated by the client and server, but it is currently hard coded to 32 KB to avoid the round-trip delay of a negotiation.

(Bitbake rev: 1a7bddb5471a02a744e7a441a3b4a6da693348b0)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
(cherry picked from commit e27a28c1e40e886ee68ba4b99b537ffc9c3577d4)
Signed-off-by: Steve Sakoman <steve@sakoman.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
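To make the framing concrete: a message that fits in one chunk travels as a single JSON line, while an oversized message is announced by a {"chunk-stream": null} marker line, followed by newline-terminated slices and closed by an empty line. A minimal receiver sketch (illustrative only; the reader object and helper name are assumptions, though client.py and server.py below implement this same loop):

    # Illustrative receiver for the chunk framing (names are assumed).
    import json

    def read_chunked_message(reader):
        # The first line is either a complete JSON message or the marker.
        msg = json.loads(reader.readline())
        if 'chunk-stream' in msg:
            chunks = []
            while True:
                # Chunks are newline-terminated; an empty line ends the stream.
                line = reader.readline().rstrip('\n')
                if not line:
                    break
                chunks.append(line)
            msg = json.loads(''.join(chunks))
        return msg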
Diffstat (limited to 'bitbake')
-rw-r--r-- bitbake/lib/hashserv/__init__.py |  22
-rw-r--r-- bitbake/lib/hashserv/client.py   |  43
-rw-r--r-- bitbake/lib/hashserv/server.py   | 105
-rw-r--r-- bitbake/lib/hashserv/tests.py    |  23
4 files changed, 152 insertions(+), 41 deletions(-)
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py
index c3318620f5..f95e8f43f1 100644
--- a/bitbake/lib/hashserv/__init__.py
+++ b/bitbake/lib/hashserv/__init__.py
@@ -6,12 +6,20 @@
 from contextlib import closing
 import re
 import sqlite3
+import itertools
+import json
 
 UNIX_PREFIX = "unix://"
 
 ADDR_TYPE_UNIX = 0
 ADDR_TYPE_TCP = 1
 
+# The Python async server defaults to a 64K receive buffer, so we hardcode our
+# maximum chunk size. It would be better if the client and server reported to
+# each other what the maximum chunk sizes were, but that will slow down the
+# connection setup with a round trip delay so I'd rather not do that unless it
+# is necessary
+DEFAULT_MAX_CHUNK = 32 * 1024
 
 def setup_database(database, sync=True):
     db = sqlite3.connect(database)
@@ -66,6 +74,20 @@ def parse_address(addr):
         return (ADDR_TYPE_TCP, (host, int(port)))
 
 
+def chunkify(msg, max_chunk):
+    if len(msg) < max_chunk - 1:
+        yield ''.join((msg, "\n"))
+    else:
+        yield ''.join((json.dumps({
+                'chunk-stream': None
+            }), "\n"))
+
+        args = [iter(msg)] * (max_chunk - 1)
+        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
+            yield ''.join(itertools.chain(m, "\n"))
+        yield "\n"
+
+
 def create_server(addr, dbname, *, sync=True):
     from . import server
     db = setup_database(dbname, sync=sync)
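As a quick sanity check of chunkify() above, this standalone sketch (illustrative; it assumes bitbake/lib is on sys.path so the hashserv package imports) frames an oversized payload and reassembles it:

    # Hypothetical round trip through chunkify (assumes bitbake/lib on sys.path).
    import io
    import json
    from hashserv import chunkify, DEFAULT_MAX_CHUNK

    msg = json.dumps({'outhash_siginfo': '0' * (DEFAULT_MAX_CHUNK * 4)})
    stream = io.StringIO(''.join(chunkify(msg, DEFAULT_MAX_CHUNK)))

    header = json.loads(stream.readline())
    assert 'chunk-stream' in header  # oversized, so the marker line comes first

    chunks = []
    while True:
        line = stream.readline().rstrip('\n')
        if not line:  # the empty line terminates the chunk stream
            break
        chunks.append(line)
    assert ''.join(chunks) == msg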
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
index 46085d6418..a29af836d9 100644
--- a/bitbake/lib/hashserv/client.py
+++ b/bitbake/lib/hashserv/client.py
@@ -7,6 +7,7 @@ import json
 import logging
 import socket
 import os
+from . import chunkify, DEFAULT_MAX_CHUNK
 
 
 logger = logging.getLogger('hashserv.client')
@@ -25,6 +26,7 @@ class Client(object):
         self.reader = None
         self.writer = None
         self.mode = self.MODE_NORMAL
+        self.max_chunk = DEFAULT_MAX_CHUNK
 
     def connect_tcp(self, address, port):
         def connect_sock():
@@ -58,7 +60,7 @@ class Client(object):
             self.reader = self._socket.makefile('r', encoding='utf-8')
             self.writer = self._socket.makefile('w', encoding='utf-8')
 
-            self.writer.write('OEHASHEQUIV 1.0\n\n')
+            self.writer.write('OEHASHEQUIV 1.1\n\n')
             self.writer.flush()
 
             # Restore mode if the socket is being re-created
@@ -91,18 +93,35 @@ class Client(object):
                 count += 1
 
     def send_message(self, msg):
+        def get_line():
+            line = self.reader.readline()
+            if not line:
+                raise HashConnectionError('Connection closed')
+
+            if not line.endswith('\n'):
+                raise HashConnectionError('Bad message %r' % message)
+
+            return line
+
         def proc():
-            self.writer.write('%s\n' % json.dumps(msg))
+            for c in chunkify(json.dumps(msg), self.max_chunk):
+                self.writer.write(c)
             self.writer.flush()
 
-            l = self.reader.readline()
-            if not l:
-                raise HashConnectionError('Connection closed')
+            l = get_line()
 
-            if not l.endswith('\n'):
-                raise HashConnectionError('Bad message %r' % message)
+            m = json.loads(l)
+            if 'chunk-stream' in m:
+                lines = []
+                while True:
+                    l = get_line().rstrip('\n')
+                    if not l:
+                        break
+                    lines.append(l)
 
-            return json.loads(l)
+                m = json.loads(''.join(lines))
+
+            return m
 
         return self._send_wrapper(proc)
 
@@ -155,6 +174,14 @@ class Client(object):
             m['unihash'] = unihash
         return self.send_message({'report-equiv': m})
 
+    def get_taskhash(self, method, taskhash, all_properties=False):
+        self._set_mode(self.MODE_NORMAL)
+        return self.send_message({'get': {
+            'taskhash': taskhash,
+            'method': method,
+            'all': all_properties
+        }})
+
     def get_stats(self):
         self._set_mode(self.MODE_NORMAL)
         return self.send_message({'get-stats': None})
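The new get_taskhash() call sends the same 'get' request the server already handles, with the 'all' flag selecting the wide query; an illustrative session (the socket path is made up; create_client is the existing helper in the hashserv package):

    # Hypothetical client session exercising the new 'all' flag.
    import hashserv

    client = hashserv.create_client('unix:///tmp/hashserv.sock')  # assumed address
    # all_properties=True makes the server run ALL_QUERY and return every column,
    # including a potentially huge outhash_siginfo that now arrives chunked.
    result = client.get_taskhash('TestMethod', 'c665584ee6817aa99edfc77a44dd853828279370', True)
    print(result)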
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index cc7e48233b..81050715ea 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -13,6 +13,7 @@ import os
 import signal
 import socket
 import time
+from . import chunkify, DEFAULT_MAX_CHUNK
 
 logger = logging.getLogger('hashserv.server')
 
@@ -107,12 +108,29 @@ class Stats(object):
         return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
 
 
+class ClientError(Exception):
+    pass
+
 class ServerClient(object):
+    FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
+    ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
+
     def __init__(self, reader, writer, db, request_stats):
         self.reader = reader
         self.writer = writer
         self.db = db
         self.request_stats = request_stats
+        self.max_chunk = DEFAULT_MAX_CHUNK
+
+        self.handlers = {
+            'get': self.handle_get,
+            'report': self.handle_report,
+            'report-equiv': self.handle_equivreport,
+            'get-stream': self.handle_get_stream,
+            'get-stats': self.handle_get_stats,
+            'reset-stats': self.handle_reset_stats,
+            'chunk-stream': self.handle_chunk,
+        }
 
     async def process_requests(self):
         try:
@@ -125,7 +143,11 @@ class ServerClient(object):
                 return
 
             (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
-            if proto_name != 'OEHASHEQUIV' or proto_version != '1.0':
+            if proto_name != 'OEHASHEQUIV':
+                return
+
+            proto_version = tuple(int(v) for v in proto_version.split('.'))
+            if proto_version < (1, 0) or proto_version > (1, 1):
                 return
 
             # Read headers. Currently, no headers are implemented, so look for
@@ -140,40 +162,34 @@ class ServerClient(object):
                     break
 
             # Handle messages
-            handlers = {
-                'get': self.handle_get,
-                'report': self.handle_report,
-                'report-equiv': self.handle_equivreport,
-                'get-stream': self.handle_get_stream,
-                'get-stats': self.handle_get_stats,
-                'reset-stats': self.handle_reset_stats,
-            }
-
             while True:
                 d = await self.read_message()
                 if d is None:
                     break
-
-                for k in handlers.keys():
-                    if k in d:
-                        logger.debug('Handling %s' % k)
-                        if 'stream' in k:
-                            await handlers[k](d[k])
-                        else:
-                            with self.request_stats.start_sample() as self.request_sample, \
-                                    self.request_sample.measure():
-                                await handlers[k](d[k])
-                        break
-                else:
-                    logger.warning("Unrecognized command %r" % d)
-                    break
-
+                await self.dispatch_message(d)
                 await self.writer.drain()
+        except ClientError as e:
+            logger.error(str(e))
         finally:
             self.writer.close()
 
+    async def dispatch_message(self, msg):
+        for k in self.handlers.keys():
+            if k in msg:
+                logger.debug('Handling %s' % k)
+                if 'stream' in k:
+                    await self.handlers[k](msg[k])
+                else:
+                    with self.request_stats.start_sample() as self.request_sample, \
+                            self.request_sample.measure():
+                        await self.handlers[k](msg[k])
+                return
+
+        raise ClientError("Unrecognized command %r" % msg)
+
     def write_message(self, msg):
-        self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))
+        for c in chunkify(json.dumps(msg), self.max_chunk):
+            self.writer.write(c.encode('utf-8'))
 
     async def read_message(self):
         l = await self.reader.readline()
@@ -191,14 +207,38 @@ class ServerClient(object):
             logger.error('Bad message from client: %r' % message)
             raise e
 
+    async def handle_chunk(self, request):
+        lines = []
+        try:
+            while True:
+                l = await self.reader.readline()
+                l = l.rstrip(b"\n").decode("utf-8")
+                if not l:
+                    break
+                lines.append(l)
+
+            msg = json.loads(''.join(lines))
+        except (json.JSONDecodeError, UnicodeDecodeError) as e:
+            logger.error('Bad message from client: %r' % message)
+            raise e
+
+        if 'chunk-stream' in msg:
+            raise ClientError("Nested chunks are not allowed")
+
+        await self.dispatch_message(msg)
+
     async def handle_get(self, request):
         method = request['method']
         taskhash = request['taskhash']
 
-        row = self.query_equivalent(method, taskhash)
+        if request.get('all', False):
+            row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
+        else:
+            row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
+
         if row is not None:
             logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
-            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+            d = {k: row[k] for k in row.keys()}
 
             self.write_message(d)
         else:
@@ -228,7 +268,7 @@ class ServerClient(object):
 
                 (method, taskhash) = l.split()
                 #logger.debug('Looking up %s %s' % (method, taskhash))
-                row = self.query_equivalent(method, taskhash)
+                row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
                 if row is not None:
                     msg = ('%s\n' % row['unihash']).encode('utf-8')
                     #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
@@ -328,7 +368,7 @@ class ServerClient(object):
         # Fetch the unihash that will be reported for the taskhash. If the
         # unihash matches, it means this row was inserted (or the mapping
         # was already valid)
-        row = self.query_equivalent(data['method'], data['taskhash'])
+        row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)
 
         if row['unihash'] == data['unihash']:
             logger.info('Adding taskhash equivalence for %s with unihash %s',
@@ -354,12 +394,11 @@ class ServerClient(object):
             self.request_stats.reset()
         self.write_message(d)
 
-    def query_equivalent(self, method, taskhash):
+    def query_equivalent(self, method, taskhash, query):
         # This is part of the inner loop and must be as fast as possible
         try:
             cursor = self.db.cursor()
-            cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
-                    {'method': method, 'taskhash': taskhash})
+            cursor.execute(query, {'method': method, 'taskhash': taskhash})
             return cursor.fetchone()
         except:
             cursor.close()
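For reference, the tuple comparison in the new version gate accepts any client from 1.0 through 1.1 inclusive, so existing 1.0 clients that never send chunked messages keep working; a minimal illustration:

    # Illustration of the version gate in process_requests().
    def accepted(version_string):
        proto_version = tuple(int(v) for v in version_string.split('.'))
        return not (proto_version < (1, 0) or proto_version > (1, 1))

    assert accepted('1.0') and accepted('1.1')
    assert not accepted('1.2')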
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py
index a5472a996d..6e86295079 100644
--- a/bitbake/lib/hashserv/tests.py
+++ b/bitbake/lib/hashserv/tests.py
@@ -99,6 +99,29 @@ class TestHashEquivalenceServer(object):
         result = self.client.get_unihash(self.METHOD, taskhash)
         self.assertEqual(result, unihash)
 
+    def test_huge_message(self):
+        # Simple test that hashes can be created
+        taskhash = 'c665584ee6817aa99edfc77a44dd853828279370'
+        outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
+        unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824'
+
+        result = self.client.get_unihash(self.METHOD, taskhash)
+        self.assertIsNone(result, msg='Found unexpected task, %r' % result)
+
+        siginfo = "0" * (self.client.max_chunk * 4)
+
+        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash, {
+            'outhash_siginfo': siginfo
+        })
+        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
+
+        result = self.client.get_taskhash(self.METHOD, taskhash, True)
+        self.assertEqual(result['taskhash'], taskhash)
+        self.assertEqual(result['unihash'], unihash)
+        self.assertEqual(result['method'], self.METHOD)
+        self.assertEqual(result['outhash'], outhash)
+        self.assertEqual(result['outhash_siginfo'], siginfo)
+
     def test_stress(self):
         def query_server(failures):
             client = Client(self.server.address)
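The new test drives a chunked reply end to end by reporting a siginfo four times larger than a chunk. Outside the test suite, a rough standalone equivalent might look like this (illustrative only; the paths, hashes, and process handling are assumptions modeled on how tests.py runs the server):

    # Hypothetical standalone check mirroring test_huge_message.
    import multiprocessing
    import hashserv

    addr = 'unix:///tmp/hashserv-demo.sock'            # assumed scratch paths
    server = hashserv.create_server(addr, '/tmp/hashserv-demo.db')
    p = multiprocessing.Process(target=server.serve_forever)
    p.start()

    client = hashserv.create_client(addr)
    siginfo = '0' * (client.max_chunk * 4)             # forces chunking both ways
    result = client.report_unihash('a' * 40, 'TestMethod', 'b' * 64, 'c' * 40,
                                   {'outhash_siginfo': siginfo})
    print(result['unihash'])                           # expected: 'c' * 40
    p.terminate()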