Diffstat (limited to 'bitbake/lib/bb/asyncrpc/serv.py')
-rw-r--r--    bitbake/lib/bb/asyncrpc/serv.py    304
1 file changed, 156 insertions, 148 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
index d2de4891b8..3e0d0632cb 100644
--- a/bitbake/lib/bb/asyncrpc/serv.py
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -12,241 +12,248 @@ import signal
 import socket
 import sys
 import multiprocessing
-from . import chunkify, DEFAULT_MAX_CHUNK
-
-
-class ClientError(Exception):
-    pass
-
-
-class ServerError(Exception):
-    pass
+from .connection import StreamConnection
+from .exceptions import ClientError, ServerError, ConnectionClosedError


 class AsyncServerConnection(object):
-    def __init__(self, reader, writer, proto_name, logger):
-        self.reader = reader
-        self.writer = writer
+    # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
+    # return message will be automatically be sent back to the client
+    NO_RESPONSE = object()
+
+    def __init__(self, socket, proto_name, logger):
+        self.socket = socket
         self.proto_name = proto_name
-        self.max_chunk = DEFAULT_MAX_CHUNK
         self.handlers = {
-            'chunk-stream': self.handle_chunk,
-            'ping': self.handle_ping,
+            "ping": self.handle_ping,
         }
         self.logger = logger

+    async def close(self):
+        await self.socket.close()
+
     async def process_requests(self):
         try:
-            self.addr = self.writer.get_extra_info('peername')
-            self.logger.debug('Client %r connected' % (self.addr,))
+            self.logger.info("Client %r connected" % (self.socket.address,))

             # Read protocol and version
-            client_protocol = await self.reader.readline()
+            client_protocol = await self.socket.recv()
             if not client_protocol:
                 return

-            (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
+            (client_proto_name, client_proto_version) = client_protocol.split()
             if client_proto_name != self.proto_name:
-                self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
+                self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
                 return

-            self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
+            self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
             if not self.validate_proto_version():
-                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
+                self.logger.debug(
+                    "Rejecting invalid protocol version %s" % (client_proto_version)
+                )
                 return

             # Read headers. Currently, no headers are implemented, so look for
             # an empty line to signal the end of the headers
             while True:
-                line = await self.reader.readline()
-                if not line:
-                    return
-
-                line = line.decode('utf-8').rstrip()
-                if not line:
+                header = await self.socket.recv()
+                if not header:
                     break

             # Handle messages
             while True:
-                d = await self.read_message()
+                d = await self.socket.recv_message()
                 if d is None:
                     break
-                await self.dispatch_message(d)
-                await self.writer.drain()
-        except ClientError as e:
+                response = await self.dispatch_message(d)
+                if response is not self.NO_RESPONSE:
+                    await self.socket.send_message(response)
+
+        except ConnectionClosedError as e:
+            self.logger.info(str(e))
+        except (ClientError, ConnectionError) as e:
             self.logger.error(str(e))
         finally:
-            self.writer.close()
+            await self.close()

     async def dispatch_message(self, msg):
         for k in self.handlers.keys():
             if k in msg:
-                self.logger.debug('Handling %s' % k)
-                await self.handlers[k](msg[k])
-                return
+                self.logger.debug("Handling %s" % k)
+                return await self.handlers[k](msg[k])

         raise ClientError("Unrecognized command %r" % msg)

-    def write_message(self, msg):
-        for c in chunkify(json.dumps(msg), self.max_chunk):
-            self.writer.write(c.encode('utf-8'))
+    async def handle_ping(self, request):
+        return {"alive": True}

-    async def read_message(self):
-        l = await self.reader.readline()
-        if not l:
-            return None

-        try:
-            message = l.decode('utf-8')
+class StreamServer(object):
+    def __init__(self, handler, logger):
+        self.handler = handler
+        self.logger = logger
+        self.closed = False

-            if not message.endswith('\n'):
-                return None
+    async def handle_stream_client(self, reader, writer):
+        # writer.transport.set_write_buffer_limits(0)
+        socket = StreamConnection(reader, writer, -1)
+        if self.closed:
+            await socket.close()
+            return
+
+        await self.handler(socket)
+
+    async def stop(self):
+        self.closed = True
+
+
+class TCPStreamServer(StreamServer):
+    def __init__(self, host, port, handler, logger):
+        super().__init__(handler, logger)
+        self.host = host
+        self.port = port
+
+    def start(self, loop):
+        self.server = loop.run_until_complete(
+            asyncio.start_server(self.handle_stream_client, self.host, self.port)
+        )
+
+        for s in self.server.sockets:
+            self.logger.debug("Listening on %r" % (s.getsockname(),))
+            # Newer python does this automatically. Do it manually here for
+            # maximum compatibility
+            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+
+            # Enable keep alives. This prevents broken client connections
+            # from persisting on the server for long periods of time.
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+        name = self.server.sockets[0].getsockname()
+        if self.server.sockets[0].family == socket.AF_INET6:
+            self.address = "[%s]:%d" % (name[0], name[1])
+        else:
+            self.address = "%s:%d" % (name[0], name[1])
+
+        return [self.server.wait_closed()]
+
+    async def stop(self):
+        await super().stop()
+        self.server.close()
+
+    def cleanup(self):
+        pass

-            return json.loads(message)
-        except (json.JSONDecodeError, UnicodeDecodeError) as e:
-            self.logger.error('Bad message from client: %r' % message)
-            raise e

-    async def handle_chunk(self, request):
-        lines = []
-        try:
-            while True:
-                l = await self.reader.readline()
-                l = l.rstrip(b"\n").decode("utf-8")
-                if not l:
-                    break
-                lines.append(l)
+class UnixStreamServer(StreamServer):
+    def __init__(self, path, handler, logger):
+        super().__init__(handler, logger)
+        self.path = path

-            msg = json.loads(''.join(lines))
-        except (json.JSONDecodeError, UnicodeDecodeError) as e:
-            self.logger.error('Bad message from client: %r' % lines)
-            raise e
+    def start(self, loop):
+        cwd = os.getcwd()
+        try:
+            # Work around path length limits in AF_UNIX
+            os.chdir(os.path.dirname(self.path))
+            self.server = loop.run_until_complete(
+                asyncio.start_unix_server(
+                    self.handle_stream_client, os.path.basename(self.path)
+                )
+            )
+        finally:
+            os.chdir(cwd)

-        if 'chunk-stream' in msg:
-            raise ClientError("Nested chunks are not allowed")
+        self.logger.debug("Listening on %r" % self.path)
+        self.address = "unix://%s" % os.path.abspath(self.path)
+        return [self.server.wait_closed()]

-        await self.dispatch_message(msg)
+    async def stop(self):
+        await super().stop()
+        self.server.close()

-    async def handle_ping(self, request):
-        response = {'alive': True}
-        self.write_message(response)
+    def cleanup(self):
+        os.unlink(self.path)


 class AsyncServer(object):
     def __init__(self, logger):
-        self._cleanup_socket = None
         self.logger = logger
-        self.start = None
-        self.address = None
         self.loop = None
+        self.run_tasks = []

     def start_tcp_server(self, host, port):
-        def start_tcp():
-            self.server = self.loop.run_until_complete(
-                asyncio.start_server(self.handle_client, host, port)
-            )
-
-            for s in self.server.sockets:
-                self.logger.debug('Listening on %r' % (s.getsockname(),))
-                # Newer python does this automatically. Do it manually here for
-                # maximum compatibility
-                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
-                s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
-
-                # Enable keep alives. This prevents broken client connections
-                # from persisting on the server for long periods of time.
-                s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
-                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
-                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
-                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
-
-            name = self.server.sockets[0].getsockname()
-            if self.server.sockets[0].family == socket.AF_INET6:
-                self.address = "[%s]:%d" % (name[0], name[1])
-            else:
-                self.address = "%s:%d" % (name[0], name[1])
-
-        self.start = start_tcp
+        self.server = TCPStreamServer(host, port, self._client_handler, self.logger)

     def start_unix_server(self, path):
-        def cleanup():
-            os.unlink(path)
-
-        def start_unix():
-            cwd = os.getcwd()
-            try:
-                # Work around path length limits in AF_UNIX
-                os.chdir(os.path.dirname(path))
-                self.server = self.loop.run_until_complete(
-                    asyncio.start_unix_server(self.handle_client, os.path.basename(path))
-                )
-            finally:
-                os.chdir(cwd)
-
-            self.logger.debug('Listening on %r' % path)
+        self.server = UnixStreamServer(path, self._client_handler, self.logger)

-        self._cleanup_socket = cleanup
-        self.address = "unix://%s" % os.path.abspath(path)
-
-        self.start = start_unix
-
-    @abc.abstractmethod
-    def accept_client(self, reader, writer):
-        pass
-
-    async def handle_client(self, reader, writer):
-        # writer.transport.set_write_buffer_limits(0)
+    async def _client_handler(self, socket):
         try:
-            client = self.accept_client(reader, writer)
+            client = self.accept_client(socket)
             await client.process_requests()
         except Exception as e:
             import traceback
-            self.logger.error('Error from client: %s' % str(e), exc_info=True)
+
+            self.logger.error("Error from client: %s" % str(e), exc_info=True)
             traceback.print_exc()
-        writer.close()
-        self.logger.debug('Client disconnected')
+        await socket.close()
+        self.logger.debug("Client disconnected")

-    def run_loop_forever(self):
-        try:
-            self.loop.run_forever()
-        except KeyboardInterrupt:
-            pass
+    @abc.abstractmethod
+    def accept_client(self, socket):
+        pass
+
+    async def stop(self):
+        self.logger.debug("Stopping server")
+        await self.server.stop()
+
+    def start(self):
+        tasks = self.server.start(self.loop)
+        self.address = self.server.address
+        return tasks

     def signal_handler(self):
         self.logger.debug("Got exit signal")
-        self.loop.stop()
+        self.loop.create_task(self.stop())

-    def _serve_forever(self):
+    def _serve_forever(self, tasks):
         try:
             self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
+            self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
+            self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
             signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

-            self.run_loop_forever()
-            self.server.close()
+            self.loop.run_until_complete(asyncio.gather(*tasks))

-            self.loop.run_until_complete(self.server.wait_closed())
-            self.logger.debug('Server shutting down')
+            self.logger.debug("Server shutting down")
         finally:
-            if self._cleanup_socket is not None:
-                self._cleanup_socket()
+            self.server.cleanup()

     def serve_forever(self):
         """
         Serve requests in the current process
         """
+        self._create_loop()
+        tasks = self.start()
+        self._serve_forever(tasks)
+        self.loop.close()
+
+    def _create_loop(self):
         # Create loop and override any loop that may have existed in
         # a parent process. It is possible that the usecases of
         # serve_forever might be constrained enough to allow using
         # get_event_loop here, but better safe than sorry for now.
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)
-        self.start()
-        self._serve_forever()

     def serve_as_process(self, *, prefunc=None, args=()):
         """
         Serve requests in a child process
         """
+
         def run(queue):
             # Create loop and override any loop that may have existed
             # in a parent process. Without doing this and instead
@@ -259,18 +266,19 @@ class AsyncServer(object):
             # more general, though, as any potential use of asyncio in
             # Cooker could create a loop that needs to replaced in this
             # new process.
-            self.loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(self.loop)
+            self._create_loop()
             try:
-                self.start()
+                self.address = None
+                tasks = self.start()
             finally:
+                # Always put the server address to wake up the parent task
                 queue.put(self.address)
                 queue.close()

             if prefunc is not None:
                 prefunc(self, *args)

-            self._serve_forever()
+            self._serve_forever(tasks)

             if sys.version_info >= (3, 6):
                 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
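
A rough usage sketch (not part of this commit) of how the refactored classes fit together: a concrete server subclasses AsyncServerConnection to register handlers and AsyncServer to create connections via accept_client(). The ExampleConnection/ExampleServer names, the "EXAMPLE" protocol string, and the "echo" handler are invented for illustration; only AsyncServerConnection, AsyncServer, and the methods visible in the diff above come from serv.py.

import logging

from bb.asyncrpc.serv import AsyncServer, AsyncServerConnection

logger = logging.getLogger("example")


class ExampleConnection(AsyncServerConnection):
    def __init__(self, socket):
        super().__init__(socket, "EXAMPLE", logger)
        # Handlers map a top-level key in the request message to a coroutine;
        # the returned dict is sent back unless it is self.NO_RESPONSE.
        self.handlers["echo"] = self.handle_echo

    def validate_proto_version(self):
        # Accept any 1.x client for this sketch
        return self.proto_version[0] == 1

    async def handle_echo(self, request):
        return {"echo": request}


class ExampleServer(AsyncServer):
    def __init__(self):
        super().__init__(logger)

    def accept_client(self, socket):
        # Called from _client_handler() with a StreamConnection wrapper
        return ExampleConnection(socket)


if __name__ == "__main__":
    server = ExampleServer()
    server.start_tcp_server("127.0.0.1", 0)  # port 0: let the kernel pick a free port
    server.serve_forever()

The same server could listen on a UNIX domain socket instead by calling server.start_unix_server(path); serve_forever() then drives whichever StreamServer subclass was created.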