summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc/serv.py
diff options
context:
space:
mode:
authorPaul Barker <pbarker@konsulko.com>2021-04-26 09:16:29 +0100
committerRichard Purdie <richard.purdie@linuxfoundation.org>2021-04-27 15:12:57 +0100
commit244b044fd6d94c000fc9cb8d1b7a9dddd08017ad (patch)
treea2123dd620bfce57d0429ff677aa3abea0fe226f /bitbake/lib/bb/asyncrpc/serv.py
parent10236718236e6a12e2e6528abcd920276d181545 (diff)
downloadpoky-244b044fd6d94c000fc9cb8d1b7a9dddd08017ad.tar.gz
bitbake: asyncrpc: Common implementation of RPC using json & asyncio
The hashserv module implements a flexible RPC mechanism based on sending json formatted messages over unix or tcp sockets and uses Python's asyncio features to build an efficient message loop on both the client and server side. Much of this implementation is not specific to the hash equivalency service and can be extracted into a new module for easy re-use elsewhere in bitbake. (Bitbake rev: 4105ffd967fa86154ad67366aaf0f898abf78d14) Signed-off-by: Paul Barker <pbarker@konsulko.com> Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/serv.py')
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py218
1 files changed, 218 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
new file mode 100644
index 0000000000..cb3384639d
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -0,0 +1,218 @@
1#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import abc
6import asyncio
7import json
8import os
9import signal
10import socket
11import sys
12from . import chunkify, DEFAULT_MAX_CHUNK
13
14
15class ClientError(Exception):
16 pass
17
18
19class ServerError(Exception):
20 pass
21
22
23class AsyncServerConnection(object):
24 def __init__(self, reader, writer, proto_name, logger):
25 self.reader = reader
26 self.writer = writer
27 self.proto_name = proto_name
28 self.max_chunk = DEFAULT_MAX_CHUNK
29 self.handlers = {
30 'chunk-stream': self.handle_chunk,
31 }
32 self.logger = logger
33
34 async def process_requests(self):
35 try:
36 self.addr = self.writer.get_extra_info('peername')
37 self.logger.debug('Client %r connected' % (self.addr,))
38
39 # Read protocol and version
40 client_protocol = await self.reader.readline()
41 if client_protocol is None:
42 return
43
44 (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
45 if client_proto_name != self.proto_name:
46 self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
47 return
48
49 self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
50 if not self.validate_proto_version():
51 self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
52 return
53
54 # Read headers. Currently, no headers are implemented, so look for
55 # an empty line to signal the end of the headers
56 while True:
57 line = await self.reader.readline()
58 if line is None:
59 return
60
61 line = line.decode('utf-8').rstrip()
62 if not line:
63 break
64
65 # Handle messages
66 while True:
67 d = await self.read_message()
68 if d is None:
69 break
70 await self.dispatch_message(d)
71 await self.writer.drain()
72 except ClientError as e:
73 self.logger.error(str(e))
74 finally:
75 self.writer.close()
76
77 async def dispatch_message(self, msg):
78 for k in self.handlers.keys():
79 if k in msg:
80 self.logger.debug('Handling %s' % k)
81 await self.handlers[k](msg[k])
82 return
83
84 raise ClientError("Unrecognized command %r" % msg)
85
86 def write_message(self, msg):
87 for c in chunkify(json.dumps(msg), self.max_chunk):
88 self.writer.write(c.encode('utf-8'))
89
90 async def read_message(self):
91 l = await self.reader.readline()
92 if not l:
93 return None
94
95 try:
96 message = l.decode('utf-8')
97
98 if not message.endswith('\n'):
99 return None
100
101 return json.loads(message)
102 except (json.JSONDecodeError, UnicodeDecodeError) as e:
103 self.logger.error('Bad message from client: %r' % message)
104 raise e
105
106 async def handle_chunk(self, request):
107 lines = []
108 try:
109 while True:
110 l = await self.reader.readline()
111 l = l.rstrip(b"\n").decode("utf-8")
112 if not l:
113 break
114 lines.append(l)
115
116 msg = json.loads(''.join(lines))
117 except (json.JSONDecodeError, UnicodeDecodeError) as e:
118 self.logger.error('Bad message from client: %r' % lines)
119 raise e
120
121 if 'chunk-stream' in msg:
122 raise ClientError("Nested chunks are not allowed")
123
124 await self.dispatch_message(msg)
125
126
127class AsyncServer(object):
128 def __init__(self, logger, loop=None):
129 if loop is None:
130 self.loop = asyncio.new_event_loop()
131 self.close_loop = True
132 else:
133 self.loop = loop
134 self.close_loop = False
135
136 self._cleanup_socket = None
137 self.logger = logger
138
139 def start_tcp_server(self, host, port):
140 self.server = self.loop.run_until_complete(
141 asyncio.start_server(self.handle_client, host, port, loop=self.loop)
142 )
143
144 for s in self.server.sockets:
145 self.logger.info('Listening on %r' % (s.getsockname(),))
146 # Newer python does this automatically. Do it manually here for
147 # maximum compatibility
148 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
149 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
150
151 name = self.server.sockets[0].getsockname()
152 if self.server.sockets[0].family == socket.AF_INET6:
153 self.address = "[%s]:%d" % (name[0], name[1])
154 else:
155 self.address = "%s:%d" % (name[0], name[1])
156
157 def start_unix_server(self, path):
158 def cleanup():
159 os.unlink(path)
160
161 cwd = os.getcwd()
162 try:
163 # Work around path length limits in AF_UNIX
164 os.chdir(os.path.dirname(path))
165 self.server = self.loop.run_until_complete(
166 asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
167 )
168 finally:
169 os.chdir(cwd)
170
171 self.logger.info('Listening on %r' % path)
172
173 self._cleanup_socket = cleanup
174 self.address = "unix://%s" % os.path.abspath(path)
175
176 @abc.abstractmethod
177 def accept_client(self, reader, writer):
178 pass
179
180 async def handle_client(self, reader, writer):
181 # writer.transport.set_write_buffer_limits(0)
182 try:
183 client = self.accept_client(reader, writer)
184 await client.process_requests()
185 except Exception as e:
186 import traceback
187 self.logger.error('Error from client: %s' % str(e), exc_info=True)
188 traceback.print_exc()
189 writer.close()
190 self.logger.info('Client disconnected')
191
192 def run_loop_forever(self):
193 try:
194 self.loop.run_forever()
195 except KeyboardInterrupt:
196 pass
197
198 def signal_handler(self):
199 self.loop.stop()
200
201 def serve_forever(self):
202 asyncio.set_event_loop(self.loop)
203 try:
204 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
205
206 self.run_loop_forever()
207 self.server.close()
208
209 self.loop.run_until_complete(self.server.wait_closed())
210 self.logger.info('Server shutting down')
211 finally:
212 if self.close_loop:
213 if sys.version_info >= (3, 6):
214 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
215 self.loop.close()
216
217 if self._cleanup_socket is not None:
218 self._cleanup_socket()