summaryrefslogtreecommitdiffstats
path: root/bitbake
diff options
context:
space:
mode:
authorPaul Barker <pbarker@konsulko.com>2021-04-26 09:16:30 +0100
committerRichard Purdie <richard.purdie@linuxfoundation.org>2021-04-27 15:12:57 +0100
commit421e86e7edadb8c88baf4df68b9fc15671e425de (patch)
tree4535af55064f7de8f41929085718e88bcd1fc15e /bitbake
parent244b044fd6d94c000fc9cb8d1b7a9dddd08017ad (diff)
downloadpoky-421e86e7edadb8c88baf4df68b9fc15671e425de.tar.gz
bitbake: hashserv: Refactor to use asyncrpc
The asyncrpc module can now be used to provide the json & asyncio based RPC system used by hashserv. (Bitbake rev: 5afb9586b0a4a23a05efb0e8ff4a97262631ae4a) Signed-off-by: Paul Barker <pbarker@konsulko.com> Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
Diffstat (limited to 'bitbake')
-rw-r--r--bitbake/lib/hashserv/client.py137
-rw-r--r--bitbake/lib/hashserv/server.py210
2 files changed, 41 insertions, 306 deletions
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
index f370cba63f..5311709677 100644
--- a/bitbake/lib/hashserv/client.py
+++ b/bitbake/lib/hashserv/client.py
@@ -8,106 +8,26 @@ import json
8import logging 8import logging
9import socket 9import socket
10import os 10import os
11from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client 11import bb.asyncrpc
12from . import create_async_client
12 13
13 14
14logger = logging.getLogger("hashserv.client") 15logger = logging.getLogger("hashserv.client")
15 16
16 17
17class AsyncClient(object): 18class AsyncClient(bb.asyncrpc.AsyncClient):
18 MODE_NORMAL = 0 19 MODE_NORMAL = 0
19 MODE_GET_STREAM = 1 20 MODE_GET_STREAM = 1
20 21
21 def __init__(self): 22 def __init__(self):
22 self.reader = None 23 super().__init__('OEHASHEQUIV', '1.1', logger)
23 self.writer = None
24 self.mode = self.MODE_NORMAL 24 self.mode = self.MODE_NORMAL
25 self.max_chunk = DEFAULT_MAX_CHUNK
26 25
27 async def connect_tcp(self, address, port): 26 async def setup_connection(self):
28 async def connect_sock(): 27 await super().setup_connection()
29 return await asyncio.open_connection(address, port) 28 cur_mode = self.mode
30 29 self.mode = self.MODE_NORMAL
31 self._connect_sock = connect_sock 30 await self._set_mode(cur_mode)
32
33 async def connect_unix(self, path):
34 async def connect_sock():
35 return await asyncio.open_unix_connection(path)
36
37 self._connect_sock = connect_sock
38
39 async def connect(self):
40 if self.reader is None or self.writer is None:
41 (self.reader, self.writer) = await self._connect_sock()
42
43 self.writer.write("OEHASHEQUIV 1.1\n\n".encode("utf-8"))
44 await self.writer.drain()
45
46 cur_mode = self.mode
47 self.mode = self.MODE_NORMAL
48 await self._set_mode(cur_mode)
49
50 async def close(self):
51 self.reader = None
52
53 if self.writer is not None:
54 self.writer.close()
55 self.writer = None
56
57 async def _send_wrapper(self, proc):
58 count = 0
59 while True:
60 try:
61 await self.connect()
62 return await proc()
63 except (
64 OSError,
65 ConnectionError,
66 json.JSONDecodeError,
67 UnicodeDecodeError,
68 ) as e:
69 logger.warning("Error talking to server: %s" % e)
70 if count >= 3:
71 if not isinstance(e, ConnectionError):
72 raise ConnectionError(str(e))
73 raise e
74 await self.close()
75 count += 1
76
77 async def send_message(self, msg):
78 async def get_line():
79 line = await self.reader.readline()
80 if not line:
81 raise ConnectionError("Connection closed")
82
83 line = line.decode("utf-8")
84
85 if not line.endswith("\n"):
86 raise ConnectionError("Bad message %r" % message)
87
88 return line
89
90 async def proc():
91 for c in chunkify(json.dumps(msg), self.max_chunk):
92 self.writer.write(c.encode("utf-8"))
93 await self.writer.drain()
94
95 l = await get_line()
96
97 m = json.loads(l)
98 if m and "chunk-stream" in m:
99 lines = []
100 while True:
101 l = (await get_line()).rstrip("\n")
102 if not l:
103 break
104 lines.append(l)
105
106 m = json.loads("".join(lines))
107
108 return m
109
110 return await self._send_wrapper(proc)
111 31
112 async def send_stream(self, msg): 32 async def send_stream(self, msg):
113 async def proc(): 33 async def proc():
@@ -185,12 +105,10 @@ class AsyncClient(object):
185 return (await self.send_message({"backfill-wait": None}))["tasks"] 105 return (await self.send_message({"backfill-wait": None}))["tasks"]
186 106
187 107
188class Client(object): 108class Client(bb.asyncrpc.Client):
189 def __init__(self): 109 def __init__(self):
190 self.client = AsyncClient() 110 super().__init__()
191 self.loop = asyncio.new_event_loop() 111 self._add_methods(
192
193 for call in (
194 "connect_tcp", 112 "connect_tcp",
195 "close", 113 "close",
196 "get_unihash", 114 "get_unihash",
@@ -200,30 +118,7 @@ class Client(object):
200 "get_stats", 118 "get_stats",
201 "reset_stats", 119 "reset_stats",
202 "backfill_wait", 120 "backfill_wait",
203 ): 121 )
204 downcall = getattr(self.client, call) 122
205 setattr(self, call, self._get_downcall_wrapper(downcall)) 123 def _get_async_client(self):
206 124 return AsyncClient()
207 def _get_downcall_wrapper(self, downcall):
208 def wrapper(*args, **kwargs):
209 return self.loop.run_until_complete(downcall(*args, **kwargs))
210
211 return wrapper
212
213 def connect_unix(self, path):
214 # AF_UNIX has path length issues so chdir here to workaround
215 cwd = os.getcwd()
216 try:
217 os.chdir(os.path.dirname(path))
218 self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path)))
219 self.loop.run_until_complete(self.client.connect())
220 finally:
221 os.chdir(cwd)
222
223 @property
224 def max_chunk(self):
225 return self.client.max_chunk
226
227 @max_chunk.setter
228 def max_chunk(self, value):
229 self.client.max_chunk = value
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index a0dc0c170f..c941c0e9dd 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -14,7 +14,9 @@ import signal
14import socket 14import socket
15import sys 15import sys
16import time 16import time
17from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS 17from . import create_async_client, TABLE_COLUMNS
18import bb.asyncrpc
19
18 20
19logger = logging.getLogger('hashserv.server') 21logger = logging.getLogger('hashserv.server')
20 22
@@ -109,12 +111,6 @@ class Stats(object):
109 return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} 111 return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
110 112
111 113
112class ClientError(Exception):
113 pass
114
115class ServerError(Exception):
116 pass
117
118def insert_task(cursor, data, ignore=False): 114def insert_task(cursor, data, ignore=False):
119 keys = sorted(data.keys()) 115 keys = sorted(data.keys())
120 query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % ( 116 query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
@@ -149,7 +145,7 @@ async def copy_outhash_from_upstream(client, db, method, outhash, taskhash):
149 145
150 return d 146 return d
151 147
152class ServerClient(object): 148class ServerClient(bb.asyncrpc.AsyncServerConnection):
153 FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' 149 FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
154 ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' 150 ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
155 OUTHASH_QUERY = ''' 151 OUTHASH_QUERY = '''
@@ -168,21 +164,19 @@ class ServerClient(object):
168 ''' 164 '''
169 165
170 def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only): 166 def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
171 self.reader = reader 167 super().__init__(reader, writer, 'OEHASHEQUIV', logger)
172 self.writer = writer
173 self.db = db 168 self.db = db
174 self.request_stats = request_stats 169 self.request_stats = request_stats
175 self.max_chunk = DEFAULT_MAX_CHUNK 170 self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
176 self.backfill_queue = backfill_queue 171 self.backfill_queue = backfill_queue
177 self.upstream = upstream 172 self.upstream = upstream
178 173
179 self.handlers = { 174 self.handlers.update({
180 'get': self.handle_get, 175 'get': self.handle_get,
181 'get-outhash': self.handle_get_outhash, 176 'get-outhash': self.handle_get_outhash,
182 'get-stream': self.handle_get_stream, 177 'get-stream': self.handle_get_stream,
183 'get-stats': self.handle_get_stats, 178 'get-stats': self.handle_get_stats,
184 'chunk-stream': self.handle_chunk, 179 })
185 }
186 180
187 if not read_only: 181 if not read_only:
188 self.handlers.update({ 182 self.handlers.update({
@@ -192,56 +186,19 @@ class ServerClient(object):
192 'backfill-wait': self.handle_backfill_wait, 186 'backfill-wait': self.handle_backfill_wait,
193 }) 187 })
194 188
189 def validate_proto_version(self):
190 return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))
191
195 async def process_requests(self): 192 async def process_requests(self):
196 if self.upstream is not None: 193 if self.upstream is not None:
197 self.upstream_client = await create_async_client(self.upstream) 194 self.upstream_client = await create_async_client(self.upstream)
198 else: 195 else:
199 self.upstream_client = None 196 self.upstream_client = None
200 197
201 try: 198 await super().process_requests()
202
203
204 self.addr = self.writer.get_extra_info('peername')
205 logger.debug('Client %r connected' % (self.addr,))
206
207 # Read protocol and version
208 protocol = await self.reader.readline()
209 if protocol is None:
210 return
211
212 (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
213 if proto_name != 'OEHASHEQUIV':
214 return
215
216 proto_version = tuple(int(v) for v in proto_version.split('.'))
217 if proto_version < (1, 0) or proto_version > (1, 1):
218 return
219
220 # Read headers. Currently, no headers are implemented, so look for
221 # an empty line to signal the end of the headers
222 while True:
223 line = await self.reader.readline()
224 if line is None:
225 return
226 199
227 line = line.decode('utf-8').rstrip() 200 if self.upstream_client is not None:
228 if not line: 201 await self.upstream_client.close()
229 break
230
231 # Handle messages
232 while True:
233 d = await self.read_message()
234 if d is None:
235 break
236 await self.dispatch_message(d)
237 await self.writer.drain()
238 except ClientError as e:
239 logger.error(str(e))
240 finally:
241 if self.upstream_client is not None:
242 await self.upstream_client.close()
243
244 self.writer.close()
245 202
246 async def dispatch_message(self, msg): 203 async def dispatch_message(self, msg):
247 for k in self.handlers.keys(): 204 for k in self.handlers.keys():
@@ -255,47 +212,7 @@ class ServerClient(object):
255 await self.handlers[k](msg[k]) 212 await self.handlers[k](msg[k])
256 return 213 return
257 214
258 raise ClientError("Unrecognized command %r" % msg) 215 raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
259
260 def write_message(self, msg):
261 for c in chunkify(json.dumps(msg), self.max_chunk):
262 self.writer.write(c.encode('utf-8'))
263
264 async def read_message(self):
265 l = await self.reader.readline()
266 if not l:
267 return None
268
269 try:
270 message = l.decode('utf-8')
271
272 if not message.endswith('\n'):
273 return None
274
275 return json.loads(message)
276 except (json.JSONDecodeError, UnicodeDecodeError) as e:
277 logger.error('Bad message from client: %r' % message)
278 raise e
279
280 async def handle_chunk(self, request):
281 lines = []
282 try:
283 while True:
284 l = await self.reader.readline()
285 l = l.rstrip(b"\n").decode("utf-8")
286 if not l:
287 break
288 lines.append(l)
289
290 msg = json.loads(''.join(lines))
291 except (json.JSONDecodeError, UnicodeDecodeError) as e:
292 logger.error('Bad message from client: %r' % message)
293 raise e
294
295 if 'chunk-stream' in msg:
296 raise ClientError("Nested chunks are not allowed")
297
298 await self.dispatch_message(msg)
299 216
300 async def handle_get(self, request): 217 async def handle_get(self, request):
301 method = request['method'] 218 method = request['method']
@@ -499,74 +416,20 @@ class ServerClient(object):
499 cursor.close() 416 cursor.close()
500 417
501 418
502class Server(object): 419class Server(bb.asyncrpc.AsyncServer):
503 def __init__(self, db, loop=None, upstream=None, read_only=False): 420 def __init__(self, db, loop=None, upstream=None, read_only=False):
504 if upstream and read_only: 421 if upstream and read_only:
505 raise ServerError("Read-only hashserv cannot pull from an upstream server") 422 raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")
423
424 super().__init__(logger, loop)
506 425
507 self.request_stats = Stats() 426 self.request_stats = Stats()
508 self.db = db 427 self.db = db
509
510 if loop is None:
511 self.loop = asyncio.new_event_loop()
512 self.close_loop = True
513 else:
514 self.loop = loop
515 self.close_loop = False
516
517 self.upstream = upstream 428 self.upstream = upstream
518 self.read_only = read_only 429 self.read_only = read_only
519 430
520 self._cleanup_socket = None 431 def accept_client(self, reader, writer):
521 432 return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
522 def start_tcp_server(self, host, port):
523 self.server = self.loop.run_until_complete(
524 asyncio.start_server(self.handle_client, host, port, loop=self.loop)
525 )
526
527 for s in self.server.sockets:
528 logger.info('Listening on %r' % (s.getsockname(),))
529 # Newer python does this automatically. Do it manually here for
530 # maximum compatibility
531 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
532 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
533
534 name = self.server.sockets[0].getsockname()
535 if self.server.sockets[0].family == socket.AF_INET6:
536 self.address = "[%s]:%d" % (name[0], name[1])
537 else:
538 self.address = "%s:%d" % (name[0], name[1])
539
540 def start_unix_server(self, path):
541 def cleanup():
542 os.unlink(path)
543
544 cwd = os.getcwd()
545 try:
546 # Work around path length limits in AF_UNIX
547 os.chdir(os.path.dirname(path))
548 self.server = self.loop.run_until_complete(
549 asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
550 )
551 finally:
552 os.chdir(cwd)
553
554 logger.info('Listening on %r' % path)
555
556 self._cleanup_socket = cleanup
557 self.address = "unix://%s" % os.path.abspath(path)
558
559 async def handle_client(self, reader, writer):
560 # writer.transport.set_write_buffer_limits(0)
561 try:
562 client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
563 await client.process_requests()
564 except Exception as e:
565 import traceback
566 logger.error('Error from client: %s' % str(e), exc_info=True)
567 traceback.print_exc()
568 writer.close()
569 logger.info('Client disconnected')
570 433
571 @contextmanager 434 @contextmanager
572 def _backfill_worker(self): 435 def _backfill_worker(self):
@@ -597,31 +460,8 @@ class Server(object):
597 else: 460 else:
598 yield 461 yield
599 462
600 def serve_forever(self): 463 def run_loop_forever(self):
601 def signal_handler(): 464 self.backfill_queue = asyncio.Queue()
602 self.loop.stop()
603
604 asyncio.set_event_loop(self.loop)
605 try:
606 self.backfill_queue = asyncio.Queue()
607
608 self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
609
610 with self._backfill_worker():
611 try:
612 self.loop.run_forever()
613 except KeyboardInterrupt:
614 pass
615
616 self.server.close()
617
618 self.loop.run_until_complete(self.server.wait_closed())
619 logger.info('Server shutting down')
620 finally:
621 if self.close_loop:
622 if sys.version_info >= (3, 6):
623 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
624 self.loop.close()
625 465
626 if self._cleanup_socket is not None: 466 with self._backfill_worker():
627 self._cleanup_socket() 467 super().run_loop_forever()