diff options
Diffstat (limited to 'bitbake/lib/hashserv/server.py')
-rw-r--r-- | bitbake/lib/hashserv/server.py | 105 |
1 files changed, 72 insertions, 33 deletions
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py index cc7e48233b..81050715ea 100644 --- a/bitbake/lib/hashserv/server.py +++ b/bitbake/lib/hashserv/server.py | |||
@@ -13,6 +13,7 @@ import os | |||
13 | import signal | 13 | import signal |
14 | import socket | 14 | import socket |
15 | import time | 15 | import time |
16 | from . import chunkify, DEFAULT_MAX_CHUNK | ||
16 | 17 | ||
17 | logger = logging.getLogger('hashserv.server') | 18 | logger = logging.getLogger('hashserv.server') |
18 | 19 | ||
@@ -107,12 +108,29 @@ class Stats(object): | |||
107 | return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} | 108 | return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} |
108 | 109 | ||
109 | 110 | ||
111 | class ClientError(Exception): | ||
112 | pass | ||
113 | |||
110 | class ServerClient(object): | 114 | class ServerClient(object): |
115 | FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | ||
116 | ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' | ||
117 | |||
111 | def __init__(self, reader, writer, db, request_stats): | 118 | def __init__(self, reader, writer, db, request_stats): |
112 | self.reader = reader | 119 | self.reader = reader |
113 | self.writer = writer | 120 | self.writer = writer |
114 | self.db = db | 121 | self.db = db |
115 | self.request_stats = request_stats | 122 | self.request_stats = request_stats |
123 | self.max_chunk = DEFAULT_MAX_CHUNK | ||
124 | |||
125 | self.handlers = { | ||
126 | 'get': self.handle_get, | ||
127 | 'report': self.handle_report, | ||
128 | 'report-equiv': self.handle_equivreport, | ||
129 | 'get-stream': self.handle_get_stream, | ||
130 | 'get-stats': self.handle_get_stats, | ||
131 | 'reset-stats': self.handle_reset_stats, | ||
132 | 'chunk-stream': self.handle_chunk, | ||
133 | } | ||
116 | 134 | ||
117 | async def process_requests(self): | 135 | async def process_requests(self): |
118 | try: | 136 | try: |
@@ -125,7 +143,11 @@ class ServerClient(object): | |||
125 | return | 143 | return |
126 | 144 | ||
127 | (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() | 145 | (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() |
128 | if proto_name != 'OEHASHEQUIV' or proto_version != '1.0': | 146 | if proto_name != 'OEHASHEQUIV': |
147 | return | ||
148 | |||
149 | proto_version = tuple(int(v) for v in proto_version.split('.')) | ||
150 | if proto_version < (1, 0) or proto_version > (1, 1): | ||
129 | return | 151 | return |
130 | 152 | ||
131 | # Read headers. Currently, no headers are implemented, so look for | 153 | # Read headers. Currently, no headers are implemented, so look for |
@@ -140,40 +162,34 @@ class ServerClient(object): | |||
140 | break | 162 | break |
141 | 163 | ||
142 | # Handle messages | 164 | # Handle messages |
143 | handlers = { | ||
144 | 'get': self.handle_get, | ||
145 | 'report': self.handle_report, | ||
146 | 'report-equiv': self.handle_equivreport, | ||
147 | 'get-stream': self.handle_get_stream, | ||
148 | 'get-stats': self.handle_get_stats, | ||
149 | 'reset-stats': self.handle_reset_stats, | ||
150 | } | ||
151 | |||
152 | while True: | 165 | while True: |
153 | d = await self.read_message() | 166 | d = await self.read_message() |
154 | if d is None: | 167 | if d is None: |
155 | break | 168 | break |
156 | 169 | await self.dispatch_message(d) | |
157 | for k in handlers.keys(): | ||
158 | if k in d: | ||
159 | logger.debug('Handling %s' % k) | ||
160 | if 'stream' in k: | ||
161 | await handlers[k](d[k]) | ||
162 | else: | ||
163 | with self.request_stats.start_sample() as self.request_sample, \ | ||
164 | self.request_sample.measure(): | ||
165 | await handlers[k](d[k]) | ||
166 | break | ||
167 | else: | ||
168 | logger.warning("Unrecognized command %r" % d) | ||
169 | break | ||
170 | |||
171 | await self.writer.drain() | 170 | await self.writer.drain() |
171 | except ClientError as e: | ||
172 | logger.error(str(e)) | ||
172 | finally: | 173 | finally: |
173 | self.writer.close() | 174 | self.writer.close() |
174 | 175 | ||
176 | async def dispatch_message(self, msg): | ||
177 | for k in self.handlers.keys(): | ||
178 | if k in msg: | ||
179 | logger.debug('Handling %s' % k) | ||
180 | if 'stream' in k: | ||
181 | await self.handlers[k](msg[k]) | ||
182 | else: | ||
183 | with self.request_stats.start_sample() as self.request_sample, \ | ||
184 | self.request_sample.measure(): | ||
185 | await self.handlers[k](msg[k]) | ||
186 | return | ||
187 | |||
188 | raise ClientError("Unrecognized command %r" % msg) | ||
189 | |||
175 | def write_message(self, msg): | 190 | def write_message(self, msg): |
176 | self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8')) | 191 | for c in chunkify(json.dumps(msg), self.max_chunk): |
192 | self.writer.write(c.encode('utf-8')) | ||
177 | 193 | ||
178 | async def read_message(self): | 194 | async def read_message(self): |
179 | l = await self.reader.readline() | 195 | l = await self.reader.readline() |
@@ -191,14 +207,38 @@ class ServerClient(object): | |||
191 | logger.error('Bad message from client: %r' % message) | 207 | logger.error('Bad message from client: %r' % message) |
192 | raise e | 208 | raise e |
193 | 209 | ||
210 | async def handle_chunk(self, request): | ||
211 | lines = [] | ||
212 | try: | ||
213 | while True: | ||
214 | l = await self.reader.readline() | ||
215 | l = l.rstrip(b"\n").decode("utf-8") | ||
216 | if not l: | ||
217 | break | ||
218 | lines.append(l) | ||
219 | |||
220 | msg = json.loads(''.join(lines)) | ||
221 | except (json.JSONDecodeError, UnicodeDecodeError) as e: | ||
222 | logger.error('Bad message from client: %r' % message) | ||
223 | raise e | ||
224 | |||
225 | if 'chunk-stream' in msg: | ||
226 | raise ClientError("Nested chunks are not allowed") | ||
227 | |||
228 | await self.dispatch_message(msg) | ||
229 | |||
194 | async def handle_get(self, request): | 230 | async def handle_get(self, request): |
195 | method = request['method'] | 231 | method = request['method'] |
196 | taskhash = request['taskhash'] | 232 | taskhash = request['taskhash'] |
197 | 233 | ||
198 | row = self.query_equivalent(method, taskhash) | 234 | if request.get('all', False): |
235 | row = self.query_equivalent(method, taskhash, self.ALL_QUERY) | ||
236 | else: | ||
237 | row = self.query_equivalent(method, taskhash, self.FAST_QUERY) | ||
238 | |||
199 | if row is not None: | 239 | if row is not None: |
200 | logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | 240 | logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) |
201 | d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} | 241 | d = {k: row[k] for k in row.keys()} |
202 | 242 | ||
203 | self.write_message(d) | 243 | self.write_message(d) |
204 | else: | 244 | else: |
@@ -228,7 +268,7 @@ class ServerClient(object): | |||
228 | 268 | ||
229 | (method, taskhash) = l.split() | 269 | (method, taskhash) = l.split() |
230 | #logger.debug('Looking up %s %s' % (method, taskhash)) | 270 | #logger.debug('Looking up %s %s' % (method, taskhash)) |
231 | row = self.query_equivalent(method, taskhash) | 271 | row = self.query_equivalent(method, taskhash, self.FAST_QUERY) |
232 | if row is not None: | 272 | if row is not None: |
233 | msg = ('%s\n' % row['unihash']).encode('utf-8') | 273 | msg = ('%s\n' % row['unihash']).encode('utf-8') |
234 | #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) | 274 | #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) |
@@ -328,7 +368,7 @@ class ServerClient(object): | |||
328 | # Fetch the unihash that will be reported for the taskhash. If the | 368 | # Fetch the unihash that will be reported for the taskhash. If the |
329 | # unihash matches, it means this row was inserted (or the mapping | 369 | # unihash matches, it means this row was inserted (or the mapping |
330 | # was already valid) | 370 | # was already valid) |
331 | row = self.query_equivalent(data['method'], data['taskhash']) | 371 | row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY) |
332 | 372 | ||
333 | if row['unihash'] == data['unihash']: | 373 | if row['unihash'] == data['unihash']: |
334 | logger.info('Adding taskhash equivalence for %s with unihash %s', | 374 | logger.info('Adding taskhash equivalence for %s with unihash %s', |
@@ -354,12 +394,11 @@ class ServerClient(object): | |||
354 | self.request_stats.reset() | 394 | self.request_stats.reset() |
355 | self.write_message(d) | 395 | self.write_message(d) |
356 | 396 | ||
357 | def query_equivalent(self, method, taskhash): | 397 | def query_equivalent(self, method, taskhash, query): |
358 | # This is part of the inner loop and must be as fast as possible | 398 | # This is part of the inner loop and must be as fast as possible |
359 | try: | 399 | try: |
360 | cursor = self.db.cursor() | 400 | cursor = self.db.cursor() |
361 | cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1', | 401 | cursor.execute(query, {'method': method, 'taskhash': taskhash}) |
362 | {'method': method, 'taskhash': taskhash}) | ||
363 | return cursor.fetchone() | 402 | return cursor.fetchone() |
364 | except: | 403 | except: |
365 | cursor.close() | 404 | cursor.close() |