summary refs log tree commit diff stats
path: root/bitbake
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake')
-rw-r--r--bitbake/lib/hashserv/__init__.py22
-rw-r--r--bitbake/lib/hashserv/client.py43
-rw-r--r--bitbake/lib/hashserv/server.py105
-rw-r--r--bitbake/lib/hashserv/tests.py23
4 files changed, 152 insertions, 41 deletions
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py
index c3318620f5..f95e8f43f1 100644
--- a/bitbake/lib/hashserv/__init__.py
+++ b/bitbake/lib/hashserv/__init__.py
@@ -6,12 +6,20 @@
6from contextlib import closing 6from contextlib import closing
7import re 7import re
8import sqlite3 8import sqlite3
9import itertools
10import json
9 11
10UNIX_PREFIX = "unix://" 12UNIX_PREFIX = "unix://"
11 13
12ADDR_TYPE_UNIX = 0 14ADDR_TYPE_UNIX = 0
13ADDR_TYPE_TCP = 1 15ADDR_TYPE_TCP = 1
14 16
17# The Python async server defaults to a 64K receive buffer, so we hardcode our
18# maximum chunk size. It would be better if the client and server reported to
19# each other what the maximum chunk sizes were, but that will slow down the
20# connection setup with a round trip delay so I'd rather not do that unless it
21# is necessary
22DEFAULT_MAX_CHUNK = 32 * 1024
15 23
16def setup_database(database, sync=True): 24def setup_database(database, sync=True):
17 db = sqlite3.connect(database) 25 db = sqlite3.connect(database)
@@ -66,6 +74,20 @@ def parse_address(addr):
66 return (ADDR_TYPE_TCP, (host, int(port))) 74 return (ADDR_TYPE_TCP, (host, int(port)))
67 75
68 76
def chunkify(msg, max_chunk):
    """Yield *msg* as one or more newline-terminated wire chunks.

    A message that fits in a single chunk (leaving room for the trailing
    newline) is emitted as-is.  Otherwise a ``{"chunk-stream": null}``
    marker line is emitted first, followed by fixed-size slices of the
    message (each newline-terminated), and finally an empty line that
    signals the end of the stream to the receiver.
    """
    # One byte of each chunk is reserved for the terminating newline.
    payload_size = max_chunk - 1

    if len(msg) < payload_size:
        # Small message: single chunk, no stream framing needed.
        yield msg + "\n"
        return

    # Announce that a chunked stream follows.
    yield json.dumps({'chunk-stream': None}) + "\n"

    # Emit the message in fixed-size slices.
    for start in range(0, len(msg), payload_size):
        yield msg[start:start + payload_size] + "\n"

    # Blank line terminates the stream.
    yield "\n"
89
90
69def create_server(addr, dbname, *, sync=True): 91def create_server(addr, dbname, *, sync=True):
70 from . import server 92 from . import server
71 db = setup_database(dbname, sync=sync) 93 db = setup_database(dbname, sync=sync)
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
index 46085d6418..a29af836d9 100644
--- a/bitbake/lib/hashserv/client.py
+++ b/bitbake/lib/hashserv/client.py
@@ -7,6 +7,7 @@ import json
7import logging 7import logging
8import socket 8import socket
9import os 9import os
10from . import chunkify, DEFAULT_MAX_CHUNK
10 11
11 12
12logger = logging.getLogger('hashserv.client') 13logger = logging.getLogger('hashserv.client')
@@ -25,6 +26,7 @@ class Client(object):
25 self.reader = None 26 self.reader = None
26 self.writer = None 27 self.writer = None
27 self.mode = self.MODE_NORMAL 28 self.mode = self.MODE_NORMAL
29 self.max_chunk = DEFAULT_MAX_CHUNK
28 30
29 def connect_tcp(self, address, port): 31 def connect_tcp(self, address, port):
30 def connect_sock(): 32 def connect_sock():
@@ -58,7 +60,7 @@ class Client(object):
58 self.reader = self._socket.makefile('r', encoding='utf-8') 60 self.reader = self._socket.makefile('r', encoding='utf-8')
59 self.writer = self._socket.makefile('w', encoding='utf-8') 61 self.writer = self._socket.makefile('w', encoding='utf-8')
60 62
61 self.writer.write('OEHASHEQUIV 1.0\n\n') 63 self.writer.write('OEHASHEQUIV 1.1\n\n')
62 self.writer.flush() 64 self.writer.flush()
63 65
64 # Restore mode if the socket is being re-created 66 # Restore mode if the socket is being re-created
@@ -91,18 +93,35 @@ class Client(object):
91 count += 1 93 count += 1
92 94
93 def send_message(self, msg): 95 def send_message(self, msg):
96 def get_line():
97 line = self.reader.readline()
98 if not line:
99 raise HashConnectionError('Connection closed')
100
101 if not line.endswith('\n'):
102 raise HashConnectionError('Bad message %r' % message)
103
104 return line
105
94 def proc(): 106 def proc():
95 self.writer.write('%s\n' % json.dumps(msg)) 107 for c in chunkify(json.dumps(msg), self.max_chunk):
108 self.writer.write(c)
96 self.writer.flush() 109 self.writer.flush()
97 110
98 l = self.reader.readline() 111 l = get_line()
99 if not l:
100 raise HashConnectionError('Connection closed')
101 112
102 if not l.endswith('\n'): 113 m = json.loads(l)
103 raise HashConnectionError('Bad message %r' % message) 114 if 'chunk-stream' in m:
115 lines = []
116 while True:
117 l = get_line().rstrip('\n')
118 if not l:
119 break
120 lines.append(l)
104 121
105 return json.loads(l) 122 m = json.loads(''.join(lines))
123
124 return m
106 125
107 return self._send_wrapper(proc) 126 return self._send_wrapper(proc)
108 127
@@ -155,6 +174,14 @@ class Client(object):
155 m['unihash'] = unihash 174 m['unihash'] = unihash
156 return self.send_message({'report-equiv': m}) 175 return self.send_message({'report-equiv': m})
157 176
    def get_taskhash(self, method, taskhash, all_properties=False):
        """Query the server for the row matching *method* and *taskhash*.

        Sends a ``get`` request in normal (non-stream) mode.  When
        *all_properties* is True the server is asked to return every
        column of the matching row rather than the fast subset.
        Returns the decoded response message from send_message().
        """
        self._set_mode(self.MODE_NORMAL)
        return self.send_message({'get': {
            'taskhash': taskhash,
            'method': method,
            'all': all_properties
        }})
184
158 def get_stats(self): 185 def get_stats(self):
159 self._set_mode(self.MODE_NORMAL) 186 self._set_mode(self.MODE_NORMAL)
160 return self.send_message({'get-stats': None}) 187 return self.send_message({'get-stats': None})
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index cc7e48233b..81050715ea 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -13,6 +13,7 @@ import os
13import signal 13import signal
14import socket 14import socket
15import time 15import time
16from . import chunkify, DEFAULT_MAX_CHUNK
16 17
17logger = logging.getLogger('hashserv.server') 18logger = logging.getLogger('hashserv.server')
18 19
@@ -107,12 +108,29 @@ class Stats(object):
107 return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} 108 return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
108 109
109 110
class ClientError(Exception):
    """Raised for invalid client requests (e.g. an unrecognized command
    or illegally nested chunk streams); caught and logged by the server's
    request loop."""
    pass
113
110class ServerClient(object): 114class ServerClient(object):
115 FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
116 ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
117
111 def __init__(self, reader, writer, db, request_stats): 118 def __init__(self, reader, writer, db, request_stats):
112 self.reader = reader 119 self.reader = reader
113 self.writer = writer 120 self.writer = writer
114 self.db = db 121 self.db = db
115 self.request_stats = request_stats 122 self.request_stats = request_stats
123 self.max_chunk = DEFAULT_MAX_CHUNK
124
125 self.handlers = {
126 'get': self.handle_get,
127 'report': self.handle_report,
128 'report-equiv': self.handle_equivreport,
129 'get-stream': self.handle_get_stream,
130 'get-stats': self.handle_get_stats,
131 'reset-stats': self.handle_reset_stats,
132 'chunk-stream': self.handle_chunk,
133 }
116 134
117 async def process_requests(self): 135 async def process_requests(self):
118 try: 136 try:
@@ -125,7 +143,11 @@ class ServerClient(object):
125 return 143 return
126 144
127 (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() 145 (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
128 if proto_name != 'OEHASHEQUIV' or proto_version != '1.0': 146 if proto_name != 'OEHASHEQUIV':
147 return
148
149 proto_version = tuple(int(v) for v in proto_version.split('.'))
150 if proto_version < (1, 0) or proto_version > (1, 1):
129 return 151 return
130 152
131 # Read headers. Currently, no headers are implemented, so look for 153 # Read headers. Currently, no headers are implemented, so look for
@@ -140,40 +162,34 @@ class ServerClient(object):
140 break 162 break
141 163
142 # Handle messages 164 # Handle messages
143 handlers = {
144 'get': self.handle_get,
145 'report': self.handle_report,
146 'report-equiv': self.handle_equivreport,
147 'get-stream': self.handle_get_stream,
148 'get-stats': self.handle_get_stats,
149 'reset-stats': self.handle_reset_stats,
150 }
151
152 while True: 165 while True:
153 d = await self.read_message() 166 d = await self.read_message()
154 if d is None: 167 if d is None:
155 break 168 break
156 169 await self.dispatch_message(d)
157 for k in handlers.keys():
158 if k in d:
159 logger.debug('Handling %s' % k)
160 if 'stream' in k:
161 await handlers[k](d[k])
162 else:
163 with self.request_stats.start_sample() as self.request_sample, \
164 self.request_sample.measure():
165 await handlers[k](d[k])
166 break
167 else:
168 logger.warning("Unrecognized command %r" % d)
169 break
170
171 await self.writer.drain() 170 await self.writer.drain()
171 except ClientError as e:
172 logger.error(str(e))
172 finally: 173 finally:
173 self.writer.close() 174 self.writer.close()
174 175
    async def dispatch_message(self, msg):
        """Route a decoded client message to its registered handler.

        Looks for the first key of self.handlers present in *msg* and
        awaits that handler with the key's value.  Stream handlers run
        unmetered; all other handlers run inside a request-stats sample
        so their latency is recorded.  Raises ClientError when no known
        command key is found in the message.
        """
        for k in self.handlers.keys():
            if k in msg:
                logger.debug('Handling %s' % k)
                if 'stream' in k:
                    # Stream commands are long-lived; excluding them keeps
                    # the request timing statistics meaningful.
                    await self.handlers[k](msg[k])
                else:
                    with self.request_stats.start_sample() as self.request_sample, \
                            self.request_sample.measure():
                        await self.handlers[k](msg[k])
                return

        raise ClientError("Unrecognized command %r" % msg)
189
175 def write_message(self, msg): 190 def write_message(self, msg):
176 self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8')) 191 for c in chunkify(json.dumps(msg), self.max_chunk):
192 self.writer.write(c.encode('utf-8'))
177 193
178 async def read_message(self): 194 async def read_message(self):
179 l = await self.reader.readline() 195 l = await self.reader.readline()
@@ -191,14 +207,38 @@ class ServerClient(object):
191 logger.error('Bad message from client: %r' % message) 207 logger.error('Bad message from client: %r' % message)
192 raise e 208 raise e
193 209
210 async def handle_chunk(self, request):
211 lines = []
212 try:
213 while True:
214 l = await self.reader.readline()
215 l = l.rstrip(b"\n").decode("utf-8")
216 if not l:
217 break
218 lines.append(l)
219
220 msg = json.loads(''.join(lines))
221 except (json.JSONDecodeError, UnicodeDecodeError) as e:
222 logger.error('Bad message from client: %r' % message)
223 raise e
224
225 if 'chunk-stream' in msg:
226 raise ClientError("Nested chunks are not allowed")
227
228 await self.dispatch_message(msg)
229
194 async def handle_get(self, request): 230 async def handle_get(self, request):
195 method = request['method'] 231 method = request['method']
196 taskhash = request['taskhash'] 232 taskhash = request['taskhash']
197 233
198 row = self.query_equivalent(method, taskhash) 234 if request.get('all', False):
235 row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
236 else:
237 row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
238
199 if row is not None: 239 if row is not None:
200 logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) 240 logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
201 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} 241 d = {k: row[k] for k in row.keys()}
202 242
203 self.write_message(d) 243 self.write_message(d)
204 else: 244 else:
@@ -228,7 +268,7 @@ class ServerClient(object):
228 268
229 (method, taskhash) = l.split() 269 (method, taskhash) = l.split()
230 #logger.debug('Looking up %s %s' % (method, taskhash)) 270 #logger.debug('Looking up %s %s' % (method, taskhash))
231 row = self.query_equivalent(method, taskhash) 271 row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
232 if row is not None: 272 if row is not None:
233 msg = ('%s\n' % row['unihash']).encode('utf-8') 273 msg = ('%s\n' % row['unihash']).encode('utf-8')
234 #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) 274 #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
@@ -328,7 +368,7 @@ class ServerClient(object):
328 # Fetch the unihash that will be reported for the taskhash. If the 368 # Fetch the unihash that will be reported for the taskhash. If the
329 # unihash matches, it means this row was inserted (or the mapping 369 # unihash matches, it means this row was inserted (or the mapping
330 # was already valid) 370 # was already valid)
331 row = self.query_equivalent(data['method'], data['taskhash']) 371 row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)
332 372
333 if row['unihash'] == data['unihash']: 373 if row['unihash'] == data['unihash']:
334 logger.info('Adding taskhash equivalence for %s with unihash %s', 374 logger.info('Adding taskhash equivalence for %s with unihash %s',
@@ -354,12 +394,11 @@ class ServerClient(object):
354 self.request_stats.reset() 394 self.request_stats.reset()
355 self.write_message(d) 395 self.write_message(d)
356 396
357 def query_equivalent(self, method, taskhash): 397 def query_equivalent(self, method, taskhash, query):
358 # This is part of the inner loop and must be as fast as possible 398 # This is part of the inner loop and must be as fast as possible
359 try: 399 try:
360 cursor = self.db.cursor() 400 cursor = self.db.cursor()
361 cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1', 401 cursor.execute(query, {'method': method, 'taskhash': taskhash})
362 {'method': method, 'taskhash': taskhash})
363 return cursor.fetchone() 402 return cursor.fetchone()
364 except: 403 except:
365 cursor.close() 404 cursor.close()
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py
index a5472a996d..6e86295079 100644
--- a/bitbake/lib/hashserv/tests.py
+++ b/bitbake/lib/hashserv/tests.py
@@ -99,6 +99,29 @@ class TestHashEquivalenceServer(object):
99 result = self.client.get_unihash(self.METHOD, taskhash) 99 result = self.client.get_unihash(self.METHOD, taskhash)
100 self.assertEqual(result, unihash) 100 self.assertEqual(result, unihash)
101 101
    def test_huge_message(self):
        """Verify that messages larger than the chunk size round-trip.

        Reports a unihash whose outhash_siginfo is four times the
        client's maximum chunk size, forcing the chunked message path in
        both directions, then fetches the full row back and checks every
        field survived intact.
        """
        # Simple test that hashes can be created
        taskhash = 'c665584ee6817aa99edfc77a44dd853828279370'
        outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
        unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824'

        # The hash must not already exist in the (fresh) test database.
        result = self.client.get_unihash(self.METHOD, taskhash)
        self.assertIsNone(result, msg='Found unexpected task, %r' % result)

        # Payload deliberately larger than one chunk (4x max_chunk).
        siginfo = "0" * (self.client.max_chunk * 4)

        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash, {
            'outhash_siginfo': siginfo
        })
        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')

        # Fetch all properties back and check nothing was truncated.
        result = self.client.get_taskhash(self.METHOD, taskhash, True)
        self.assertEqual(result['taskhash'], taskhash)
        self.assertEqual(result['unihash'], unihash)
        self.assertEqual(result['method'], self.METHOD)
        self.assertEqual(result['outhash'], outhash)
        self.assertEqual(result['outhash_siginfo'], siginfo)
124
102 def test_stress(self): 125 def test_stress(self):
103 def query_server(failures): 126 def query_server(failures):
104 client = Client(self.server.address) 127 client = Client(self.server.address)