summaryrefslogtreecommitdiffstats
path: root/bitbake
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake')
-rw-r--r--bitbake/lib/bb/asyncrpc/__init__.py31
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py145
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py218
3 files changed, 394 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/__init__.py b/bitbake/lib/bb/asyncrpc/__init__.py
new file mode 100644
index 0000000000..b2bec31ab2
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/__init__.py
@@ -0,0 +1,31 @@
1#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import itertools
6import json
7
8# The Python async server defaults to a 64K receive buffer, so we hardcode our
9# maximum chunk size. It would be better if the client and server reported to
10# each other what the maximum chunk sizes were, but that will slow down the
11# connection setup with a round trip delay so I'd rather not do that unless it
12# is necessary
13DEFAULT_MAX_CHUNK = 32 * 1024
14
15
16def chunkify(msg, max_chunk):
17 if len(msg) < max_chunk - 1:
18 yield ''.join((msg, "\n"))
19 else:
20 yield ''.join((json.dumps({
21 'chunk-stream': None
22 }), "\n"))
23
24 args = [iter(msg)] * (max_chunk - 1)
25 for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
26 yield ''.join(itertools.chain(m, "\n"))
27 yield "\n"
28
29
30from .client import AsyncClient, Client
31from .serv import AsyncServer, AsyncServerConnection
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
new file mode 100644
index 0000000000..4cdad9ac3c
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -0,0 +1,145 @@
1#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import abc
6import asyncio
7import json
8import os
9import socket
10from . import chunkify, DEFAULT_MAX_CHUNK
11
12
13class AsyncClient(object):
14 def __init__(self, proto_name, proto_version, logger):
15 self.reader = None
16 self.writer = None
17 self.max_chunk = DEFAULT_MAX_CHUNK
18 self.proto_name = proto_name
19 self.proto_version = proto_version
20 self.logger = logger
21
22 async def connect_tcp(self, address, port):
23 async def connect_sock():
24 return await asyncio.open_connection(address, port)
25
26 self._connect_sock = connect_sock
27
28 async def connect_unix(self, path):
29 async def connect_sock():
30 return await asyncio.open_unix_connection(path)
31
32 self._connect_sock = connect_sock
33
34 async def setup_connection(self):
35 s = '%s %s\n\n' % (self.proto_name, self.proto_version)
36 self.writer.write(s.encode("utf-8"))
37 await self.writer.drain()
38
39 async def connect(self):
40 if self.reader is None or self.writer is None:
41 (self.reader, self.writer) = await self._connect_sock()
42 await self.setup_connection()
43
44 async def close(self):
45 self.reader = None
46
47 if self.writer is not None:
48 self.writer.close()
49 self.writer = None
50
51 async def _send_wrapper(self, proc):
52 count = 0
53 while True:
54 try:
55 await self.connect()
56 return await proc()
57 except (
58 OSError,
59 ConnectionError,
60 json.JSONDecodeError,
61 UnicodeDecodeError,
62 ) as e:
63 self.logger.warning("Error talking to server: %s" % e)
64 if count >= 3:
65 if not isinstance(e, ConnectionError):
66 raise ConnectionError(str(e))
67 raise e
68 await self.close()
69 count += 1
70
71 async def send_message(self, msg):
72 async def get_line():
73 line = await self.reader.readline()
74 if not line:
75 raise ConnectionError("Connection closed")
76
77 line = line.decode("utf-8")
78
79 if not line.endswith("\n"):
80 raise ConnectionError("Bad message %r" % msg)
81
82 return line
83
84 async def proc():
85 for c in chunkify(json.dumps(msg), self.max_chunk):
86 self.writer.write(c.encode("utf-8"))
87 await self.writer.drain()
88
89 l = await get_line()
90
91 m = json.loads(l)
92 if m and "chunk-stream" in m:
93 lines = []
94 while True:
95 l = (await get_line()).rstrip("\n")
96 if not l:
97 break
98 lines.append(l)
99
100 m = json.loads("".join(lines))
101
102 return m
103
104 return await self._send_wrapper(proc)
105
106
107class Client(object):
108 def __init__(self):
109 self.client = self._get_async_client()
110 self.loop = asyncio.new_event_loop()
111
112 self._add_methods('connect_tcp', 'close')
113
114 @abc.abstractmethod
115 def _get_async_client(self):
116 pass
117
118 def _get_downcall_wrapper(self, downcall):
119 def wrapper(*args, **kwargs):
120 return self.loop.run_until_complete(downcall(*args, **kwargs))
121
122 return wrapper
123
124 def _add_methods(self, *methods):
125 for m in methods:
126 downcall = getattr(self.client, m)
127 setattr(self, m, self._get_downcall_wrapper(downcall))
128
129 def connect_unix(self, path):
130 # AF_UNIX has path length issues so chdir here to workaround
131 cwd = os.getcwd()
132 try:
133 os.chdir(os.path.dirname(path))
134 self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path)))
135 self.loop.run_until_complete(self.client.connect())
136 finally:
137 os.chdir(cwd)
138
139 @property
140 def max_chunk(self):
141 return self.client.max_chunk
142
143 @max_chunk.setter
144 def max_chunk(self, value):
145 self.client.max_chunk = value
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
new file mode 100644
index 0000000000..cb3384639d
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -0,0 +1,218 @@
1#
2# SPDX-License-Identifier: GPL-2.0-only
3#
4
5import abc
6import asyncio
7import json
8import os
9import signal
10import socket
11import sys
12from . import chunkify, DEFAULT_MAX_CHUNK
13
14
15class ClientError(Exception):
16 pass
17
18
19class ServerError(Exception):
20 pass
21
22
23class AsyncServerConnection(object):
24 def __init__(self, reader, writer, proto_name, logger):
25 self.reader = reader
26 self.writer = writer
27 self.proto_name = proto_name
28 self.max_chunk = DEFAULT_MAX_CHUNK
29 self.handlers = {
30 'chunk-stream': self.handle_chunk,
31 }
32 self.logger = logger
33
34 async def process_requests(self):
35 try:
36 self.addr = self.writer.get_extra_info('peername')
37 self.logger.debug('Client %r connected' % (self.addr,))
38
39 # Read protocol and version
40 client_protocol = await self.reader.readline()
41 if client_protocol is None:
42 return
43
44 (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
45 if client_proto_name != self.proto_name:
46 self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
47 return
48
49 self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
50 if not self.validate_proto_version():
51 self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
52 return
53
54 # Read headers. Currently, no headers are implemented, so look for
55 # an empty line to signal the end of the headers
56 while True:
57 line = await self.reader.readline()
58 if line is None:
59 return
60
61 line = line.decode('utf-8').rstrip()
62 if not line:
63 break
64
65 # Handle messages
66 while True:
67 d = await self.read_message()
68 if d is None:
69 break
70 await self.dispatch_message(d)
71 await self.writer.drain()
72 except ClientError as e:
73 self.logger.error(str(e))
74 finally:
75 self.writer.close()
76
77 async def dispatch_message(self, msg):
78 for k in self.handlers.keys():
79 if k in msg:
80 self.logger.debug('Handling %s' % k)
81 await self.handlers[k](msg[k])
82 return
83
84 raise ClientError("Unrecognized command %r" % msg)
85
86 def write_message(self, msg):
87 for c in chunkify(json.dumps(msg), self.max_chunk):
88 self.writer.write(c.encode('utf-8'))
89
90 async def read_message(self):
91 l = await self.reader.readline()
92 if not l:
93 return None
94
95 try:
96 message = l.decode('utf-8')
97
98 if not message.endswith('\n'):
99 return None
100
101 return json.loads(message)
102 except (json.JSONDecodeError, UnicodeDecodeError) as e:
103 self.logger.error('Bad message from client: %r' % message)
104 raise e
105
106 async def handle_chunk(self, request):
107 lines = []
108 try:
109 while True:
110 l = await self.reader.readline()
111 l = l.rstrip(b"\n").decode("utf-8")
112 if not l:
113 break
114 lines.append(l)
115
116 msg = json.loads(''.join(lines))
117 except (json.JSONDecodeError, UnicodeDecodeError) as e:
118 self.logger.error('Bad message from client: %r' % lines)
119 raise e
120
121 if 'chunk-stream' in msg:
122 raise ClientError("Nested chunks are not allowed")
123
124 await self.dispatch_message(msg)
125
126
127class AsyncServer(object):
128 def __init__(self, logger, loop=None):
129 if loop is None:
130 self.loop = asyncio.new_event_loop()
131 self.close_loop = True
132 else:
133 self.loop = loop
134 self.close_loop = False
135
136 self._cleanup_socket = None
137 self.logger = logger
138
139 def start_tcp_server(self, host, port):
140 self.server = self.loop.run_until_complete(
141 asyncio.start_server(self.handle_client, host, port, loop=self.loop)
142 )
143
144 for s in self.server.sockets:
145 self.logger.info('Listening on %r' % (s.getsockname(),))
146 # Newer python does this automatically. Do it manually here for
147 # maximum compatibility
148 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
149 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
150
151 name = self.server.sockets[0].getsockname()
152 if self.server.sockets[0].family == socket.AF_INET6:
153 self.address = "[%s]:%d" % (name[0], name[1])
154 else:
155 self.address = "%s:%d" % (name[0], name[1])
156
157 def start_unix_server(self, path):
158 def cleanup():
159 os.unlink(path)
160
161 cwd = os.getcwd()
162 try:
163 # Work around path length limits in AF_UNIX
164 os.chdir(os.path.dirname(path))
165 self.server = self.loop.run_until_complete(
166 asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
167 )
168 finally:
169 os.chdir(cwd)
170
171 self.logger.info('Listening on %r' % path)
172
173 self._cleanup_socket = cleanup
174 self.address = "unix://%s" % os.path.abspath(path)
175
176 @abc.abstractmethod
177 def accept_client(self, reader, writer):
178 pass
179
180 async def handle_client(self, reader, writer):
181 # writer.transport.set_write_buffer_limits(0)
182 try:
183 client = self.accept_client(reader, writer)
184 await client.process_requests()
185 except Exception as e:
186 import traceback
187 self.logger.error('Error from client: %s' % str(e), exc_info=True)
188 traceback.print_exc()
189 writer.close()
190 self.logger.info('Client disconnected')
191
192 def run_loop_forever(self):
193 try:
194 self.loop.run_forever()
195 except KeyboardInterrupt:
196 pass
197
198 def signal_handler(self):
199 self.loop.stop()
200
201 def serve_forever(self):
202 asyncio.set_event_loop(self.loop)
203 try:
204 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
205
206 self.run_loop_forever()
207 self.server.close()
208
209 self.loop.run_until_complete(self.server.wait_closed())
210 self.logger.info('Server shutting down')
211 finally:
212 if self.close_loop:
213 if sys.version_info >= (3, 6):
214 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
215 self.loop.close()
216
217 if self._cleanup_socket is not None:
218 self._cleanup_socket()