diff options
author | Joshua Watt <JPEWhacker@gmail.com> | 2023-11-03 08:26:19 -0600 |
---|---|---|
committer | Richard Purdie <richard.purdie@linuxfoundation.org> | 2023-11-09 17:33:02 +0000 |
commit | 8f8501ed403dec27acbe780b936bc087fc5006d0 (patch) | |
tree | 60e6415075c7c71eacec23ca7dda53e4a324b12e /bitbake/lib/bb | |
parent | f97b686884166dd77d1818e70615027c6ba8c348 (diff) | |
download | poky-8f8501ed403dec27acbe780b936bc087fc5006d0.tar.gz |
bitbake: asyncrpc: Abstract sockets
Rewrites the asyncrpc client and server code to make it possible to have
other transport backends that are not stream based (e.g. websockets
which are message based). The connection handling classes are now shared
between both the client and server to make it easier to implement new
transport mechanisms
(Bitbake rev: 2aaeae53696e4c2f13a169830c3b7089cbad6eca)
Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
Diffstat (limited to 'bitbake/lib/bb')
-rw-r--r-- | bitbake/lib/bb/asyncrpc/__init__.py | 32 | ||||
-rw-r--r-- | bitbake/lib/bb/asyncrpc/client.py | 78 | ||||
-rw-r--r-- | bitbake/lib/bb/asyncrpc/connection.py | 95 | ||||
-rw-r--r-- | bitbake/lib/bb/asyncrpc/exceptions.py | 17 | ||||
-rw-r--r-- | bitbake/lib/bb/asyncrpc/serv.py | 304 |
5 files changed, 298 insertions, 228 deletions
diff --git a/bitbake/lib/bb/asyncrpc/__init__.py b/bitbake/lib/bb/asyncrpc/__init__.py index 9a85e9965b..9f677eac4c 100644 --- a/bitbake/lib/bb/asyncrpc/__init__.py +++ b/bitbake/lib/bb/asyncrpc/__init__.py | |||
@@ -4,30 +4,12 @@ | |||
4 | # SPDX-License-Identifier: GPL-2.0-only | 4 | # SPDX-License-Identifier: GPL-2.0-only |
5 | # | 5 | # |
6 | 6 | ||
7 | import itertools | ||
8 | import json | ||
9 | |||
10 | # The Python async server defaults to a 64K receive buffer, so we hardcode our | ||
11 | # maximum chunk size. It would be better if the client and server reported to | ||
12 | # each other what the maximum chunk sizes were, but that will slow down the | ||
13 | # connection setup with a round trip delay so I'd rather not do that unless it | ||
14 | # is necessary | ||
15 | DEFAULT_MAX_CHUNK = 32 * 1024 | ||
16 | |||
17 | |||
18 | def chunkify(msg, max_chunk): | ||
19 | if len(msg) < max_chunk - 1: | ||
20 | yield ''.join((msg, "\n")) | ||
21 | else: | ||
22 | yield ''.join((json.dumps({ | ||
23 | 'chunk-stream': None | ||
24 | }), "\n")) | ||
25 | |||
26 | args = [iter(msg)] * (max_chunk - 1) | ||
27 | for m in map(''.join, itertools.zip_longest(*args, fillvalue='')): | ||
28 | yield ''.join(itertools.chain(m, "\n")) | ||
29 | yield "\n" | ||
30 | |||
31 | 7 | ||
32 | from .client import AsyncClient, Client | 8 | from .client import AsyncClient, Client |
33 | from .serv import AsyncServer, AsyncServerConnection, ClientError, ServerError | 9 | from .serv import AsyncServer, AsyncServerConnection |
10 | from .connection import DEFAULT_MAX_CHUNK | ||
11 | from .exceptions import ( | ||
12 | ClientError, | ||
13 | ServerError, | ||
14 | ConnectionClosedError, | ||
15 | ) | ||
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py index fa042bbe87..7f33099b63 100644 --- a/bitbake/lib/bb/asyncrpc/client.py +++ b/bitbake/lib/bb/asyncrpc/client.py | |||
@@ -10,13 +10,13 @@ import json | |||
10 | import os | 10 | import os |
11 | import socket | 11 | import socket |
12 | import sys | 12 | import sys |
13 | from . import chunkify, DEFAULT_MAX_CHUNK | 13 | from .connection import StreamConnection, DEFAULT_MAX_CHUNK |
14 | from .exceptions import ConnectionClosedError | ||
14 | 15 | ||
15 | 16 | ||
16 | class AsyncClient(object): | 17 | class AsyncClient(object): |
17 | def __init__(self, proto_name, proto_version, logger, timeout=30): | 18 | def __init__(self, proto_name, proto_version, logger, timeout=30): |
18 | self.reader = None | 19 | self.socket = None |
19 | self.writer = None | ||
20 | self.max_chunk = DEFAULT_MAX_CHUNK | 20 | self.max_chunk = DEFAULT_MAX_CHUNK |
21 | self.proto_name = proto_name | 21 | self.proto_name = proto_name |
22 | self.proto_version = proto_version | 22 | self.proto_version = proto_version |
@@ -25,7 +25,8 @@ class AsyncClient(object): | |||
25 | 25 | ||
26 | async def connect_tcp(self, address, port): | 26 | async def connect_tcp(self, address, port): |
27 | async def connect_sock(): | 27 | async def connect_sock(): |
28 | return await asyncio.open_connection(address, port) | 28 | reader, writer = await asyncio.open_connection(address, port) |
29 | return StreamConnection(reader, writer, self.timeout, self.max_chunk) | ||
29 | 30 | ||
30 | self._connect_sock = connect_sock | 31 | self._connect_sock = connect_sock |
31 | 32 | ||
@@ -40,27 +41,27 @@ class AsyncClient(object): | |||
40 | sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) | 41 | sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) |
41 | sock.connect(os.path.basename(path)) | 42 | sock.connect(os.path.basename(path)) |
42 | finally: | 43 | finally: |
43 | os.chdir(cwd) | 44 | os.chdir(cwd) |
44 | return await asyncio.open_unix_connection(sock=sock) | 45 | reader, writer = await asyncio.open_unix_connection(sock=sock) |
46 | return StreamConnection(reader, writer, self.timeout, self.max_chunk) | ||
45 | 47 | ||
46 | self._connect_sock = connect_sock | 48 | self._connect_sock = connect_sock |
47 | 49 | ||
48 | async def setup_connection(self): | 50 | async def setup_connection(self): |
49 | s = '%s %s\n\n' % (self.proto_name, self.proto_version) | 51 | # Send headers |
50 | self.writer.write(s.encode("utf-8")) | 52 | await self.socket.send("%s %s" % (self.proto_name, self.proto_version)) |
51 | await self.writer.drain() | 53 | # End of headers |
54 | await self.socket.send("") | ||
52 | 55 | ||
53 | async def connect(self): | 56 | async def connect(self): |
54 | if self.reader is None or self.writer is None: | 57 | if self.socket is None: |
55 | (self.reader, self.writer) = await self._connect_sock() | 58 | self.socket = await self._connect_sock() |
56 | await self.setup_connection() | 59 | await self.setup_connection() |
57 | 60 | ||
58 | async def close(self): | 61 | async def close(self): |
59 | self.reader = None | 62 | if self.socket is not None: |
60 | 63 | await self.socket.close() | |
61 | if self.writer is not None: | 64 | self.socket = None |
62 | self.writer.close() | ||
63 | self.writer = None | ||
64 | 65 | ||
65 | async def _send_wrapper(self, proc): | 66 | async def _send_wrapper(self, proc): |
66 | count = 0 | 67 | count = 0 |
@@ -71,6 +72,7 @@ class AsyncClient(object): | |||
71 | except ( | 72 | except ( |
72 | OSError, | 73 | OSError, |
73 | ConnectionError, | 74 | ConnectionError, |
75 | ConnectionClosedError, | ||
74 | json.JSONDecodeError, | 76 | json.JSONDecodeError, |
75 | UnicodeDecodeError, | 77 | UnicodeDecodeError, |
76 | ) as e: | 78 | ) as e: |
@@ -82,49 +84,15 @@ class AsyncClient(object): | |||
82 | await self.close() | 84 | await self.close() |
83 | count += 1 | 85 | count += 1 |
84 | 86 | ||
85 | async def send_message(self, msg): | 87 | async def invoke(self, msg): |
86 | async def get_line(): | ||
87 | try: | ||
88 | line = await asyncio.wait_for(self.reader.readline(), self.timeout) | ||
89 | except asyncio.TimeoutError: | ||
90 | raise ConnectionError("Timed out waiting for server") | ||
91 | |||
92 | if not line: | ||
93 | raise ConnectionError("Connection closed") | ||
94 | |||
95 | line = line.decode("utf-8") | ||
96 | |||
97 | if not line.endswith("\n"): | ||
98 | raise ConnectionError("Bad message %r" % (line)) | ||
99 | |||
100 | return line | ||
101 | |||
102 | async def proc(): | 88 | async def proc(): |
103 | for c in chunkify(json.dumps(msg), self.max_chunk): | 89 | await self.socket.send_message(msg) |
104 | self.writer.write(c.encode("utf-8")) | 90 | return await self.socket.recv_message() |
105 | await self.writer.drain() | ||
106 | |||
107 | l = await get_line() | ||
108 | |||
109 | m = json.loads(l) | ||
110 | if m and "chunk-stream" in m: | ||
111 | lines = [] | ||
112 | while True: | ||
113 | l = (await get_line()).rstrip("\n") | ||
114 | if not l: | ||
115 | break | ||
116 | lines.append(l) | ||
117 | |||
118 | m = json.loads("".join(lines)) | ||
119 | |||
120 | return m | ||
121 | 91 | ||
122 | return await self._send_wrapper(proc) | 92 | return await self._send_wrapper(proc) |
123 | 93 | ||
124 | async def ping(self): | 94 | async def ping(self): |
125 | return await self.send_message( | 95 | return await self.invoke({"ping": {}}) |
126 | {'ping': {}} | ||
127 | ) | ||
128 | 96 | ||
129 | 97 | ||
130 | class Client(object): | 98 | class Client(object): |
@@ -142,7 +110,7 @@ class Client(object): | |||
142 | # required (but harmless) with it. | 110 | # required (but harmless) with it. |
143 | asyncio.set_event_loop(self.loop) | 111 | asyncio.set_event_loop(self.loop) |
144 | 112 | ||
145 | self._add_methods('connect_tcp', 'ping') | 113 | self._add_methods("connect_tcp", "ping") |
146 | 114 | ||
147 | @abc.abstractmethod | 115 | @abc.abstractmethod |
148 | def _get_async_client(self): | 116 | def _get_async_client(self): |
diff --git a/bitbake/lib/bb/asyncrpc/connection.py b/bitbake/lib/bb/asyncrpc/connection.py new file mode 100644 index 0000000000..c4fd24754c --- /dev/null +++ b/bitbake/lib/bb/asyncrpc/connection.py | |||
@@ -0,0 +1,95 @@ | |||
1 | # | ||
2 | # Copyright BitBake Contributors | ||
3 | # | ||
4 | # SPDX-License-Identifier: GPL-2.0-only | ||
5 | # | ||
6 | |||
7 | import asyncio | ||
8 | import itertools | ||
9 | import json | ||
10 | from .exceptions import ClientError, ConnectionClosedError | ||
11 | |||
12 | |||
13 | # The Python async server defaults to a 64K receive buffer, so we hardcode our | ||
14 | # maximum chunk size. It would be better if the client and server reported to | ||
15 | # each other what the maximum chunk sizes were, but that will slow down the | ||
16 | # connection setup with a round trip delay so I'd rather not do that unless it | ||
17 | # is necessary | ||
18 | DEFAULT_MAX_CHUNK = 32 * 1024 | ||
19 | |||
20 | |||
21 | def chunkify(msg, max_chunk): | ||
22 | if len(msg) < max_chunk - 1: | ||
23 | yield "".join((msg, "\n")) | ||
24 | else: | ||
25 | yield "".join((json.dumps({"chunk-stream": None}), "\n")) | ||
26 | |||
27 | args = [iter(msg)] * (max_chunk - 1) | ||
28 | for m in map("".join, itertools.zip_longest(*args, fillvalue="")): | ||
29 | yield "".join(itertools.chain(m, "\n")) | ||
30 | yield "\n" | ||
31 | |||
32 | |||
33 | class StreamConnection(object): | ||
34 | def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK): | ||
35 | self.reader = reader | ||
36 | self.writer = writer | ||
37 | self.timeout = timeout | ||
38 | self.max_chunk = max_chunk | ||
39 | |||
40 | @property | ||
41 | def address(self): | ||
42 | return self.writer.get_extra_info("peername") | ||
43 | |||
44 | async def send_message(self, msg): | ||
45 | for c in chunkify(json.dumps(msg), self.max_chunk): | ||
46 | self.writer.write(c.encode("utf-8")) | ||
47 | await self.writer.drain() | ||
48 | |||
49 | async def recv_message(self): | ||
50 | l = await self.recv() | ||
51 | |||
52 | m = json.loads(l) | ||
53 | if not m: | ||
54 | return m | ||
55 | |||
56 | if "chunk-stream" in m: | ||
57 | lines = [] | ||
58 | while True: | ||
59 | l = await self.recv() | ||
60 | if not l: | ||
61 | break | ||
62 | lines.append(l) | ||
63 | |||
64 | m = json.loads("".join(lines)) | ||
65 | |||
66 | return m | ||
67 | |||
68 | async def send(self, msg): | ||
69 | self.writer.write(("%s\n" % msg).encode("utf-8")) | ||
70 | await self.writer.drain() | ||
71 | |||
72 | async def recv(self): | ||
73 | if self.timeout < 0: | ||
74 | line = await self.reader.readline() | ||
75 | else: | ||
76 | try: | ||
77 | line = await asyncio.wait_for(self.reader.readline(), self.timeout) | ||
78 | except asyncio.TimeoutError: | ||
79 | raise ConnectionError("Timed out waiting for data") | ||
80 | |||
81 | if not line: | ||
82 | raise ConnectionClosedError("Connection closed") | ||
83 | |||
84 | line = line.decode("utf-8") | ||
85 | |||
86 | if not line.endswith("\n"): | ||
87 | raise ConnectionError("Bad message %r" % (line)) | ||
88 | |||
89 | return line.rstrip() | ||
90 | |||
91 | async def close(self): | ||
92 | self.reader = None | ||
93 | if self.writer is not None: | ||
94 | self.writer.close() | ||
95 | self.writer = None | ||
diff --git a/bitbake/lib/bb/asyncrpc/exceptions.py b/bitbake/lib/bb/asyncrpc/exceptions.py new file mode 100644 index 0000000000..a8942b4f0c --- /dev/null +++ b/bitbake/lib/bb/asyncrpc/exceptions.py | |||
@@ -0,0 +1,17 @@ | |||
1 | # | ||
2 | # Copyright BitBake Contributors | ||
3 | # | ||
4 | # SPDX-License-Identifier: GPL-2.0-only | ||
5 | # | ||
6 | |||
7 | |||
8 | class ClientError(Exception): | ||
9 | pass | ||
10 | |||
11 | |||
12 | class ServerError(Exception): | ||
13 | pass | ||
14 | |||
15 | |||
16 | class ConnectionClosedError(Exception): | ||
17 | pass | ||
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py index d2de4891b8..3e0d0632cb 100644 --- a/bitbake/lib/bb/asyncrpc/serv.py +++ b/bitbake/lib/bb/asyncrpc/serv.py | |||
@@ -12,241 +12,248 @@ import signal | |||
12 | import socket | 12 | import socket |
13 | import sys | 13 | import sys |
14 | import multiprocessing | 14 | import multiprocessing |
15 | from . import chunkify, DEFAULT_MAX_CHUNK | 15 | from .connection import StreamConnection |
16 | 16 | from .exceptions import ClientError, ServerError, ConnectionClosedError | |
17 | |||
18 | class ClientError(Exception): | ||
19 | pass | ||
20 | |||
21 | |||
22 | class ServerError(Exception): | ||
23 | pass | ||
24 | 17 | ||
25 | 18 | ||
26 | class AsyncServerConnection(object): | 19 | class AsyncServerConnection(object): |
27 | def __init__(self, reader, writer, proto_name, logger): | 20 | # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no |
28 | self.reader = reader | 21 | # return message will be automatically be sent back to the client |
29 | self.writer = writer | 22 | NO_RESPONSE = object() |
23 | |||
24 | def __init__(self, socket, proto_name, logger): | ||
25 | self.socket = socket | ||
30 | self.proto_name = proto_name | 26 | self.proto_name = proto_name |
31 | self.max_chunk = DEFAULT_MAX_CHUNK | ||
32 | self.handlers = { | 27 | self.handlers = { |
33 | 'chunk-stream': self.handle_chunk, | 28 | "ping": self.handle_ping, |
34 | 'ping': self.handle_ping, | ||
35 | } | 29 | } |
36 | self.logger = logger | 30 | self.logger = logger |
37 | 31 | ||
32 | async def close(self): | ||
33 | await self.socket.close() | ||
34 | |||
38 | async def process_requests(self): | 35 | async def process_requests(self): |
39 | try: | 36 | try: |
40 | self.addr = self.writer.get_extra_info('peername') | 37 | self.logger.info("Client %r connected" % (self.socket.address,)) |
41 | self.logger.debug('Client %r connected' % (self.addr,)) | ||
42 | 38 | ||
43 | # Read protocol and version | 39 | # Read protocol and version |
44 | client_protocol = await self.reader.readline() | 40 | client_protocol = await self.socket.recv() |
45 | if not client_protocol: | 41 | if not client_protocol: |
46 | return | 42 | return |
47 | 43 | ||
48 | (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split() | 44 | (client_proto_name, client_proto_version) = client_protocol.split() |
49 | if client_proto_name != self.proto_name: | 45 | if client_proto_name != self.proto_name: |
50 | self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name)) | 46 | self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name)) |
51 | return | 47 | return |
52 | 48 | ||
53 | self.proto_version = tuple(int(v) for v in client_proto_version.split('.')) | 49 | self.proto_version = tuple(int(v) for v in client_proto_version.split(".")) |
54 | if not self.validate_proto_version(): | 50 | if not self.validate_proto_version(): |
55 | self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version)) | 51 | self.logger.debug( |
52 | "Rejecting invalid protocol version %s" % (client_proto_version) | ||
53 | ) | ||
56 | return | 54 | return |
57 | 55 | ||
58 | # Read headers. Currently, no headers are implemented, so look for | 56 | # Read headers. Currently, no headers are implemented, so look for |
59 | # an empty line to signal the end of the headers | 57 | # an empty line to signal the end of the headers |
60 | while True: | 58 | while True: |
61 | line = await self.reader.readline() | 59 | header = await self.socket.recv() |
62 | if not line: | 60 | if not header: |
63 | return | ||
64 | |||
65 | line = line.decode('utf-8').rstrip() | ||
66 | if not line: | ||
67 | break | 61 | break |
68 | 62 | ||
69 | # Handle messages | 63 | # Handle messages |
70 | while True: | 64 | while True: |
71 | d = await self.read_message() | 65 | d = await self.socket.recv_message() |
72 | if d is None: | 66 | if d is None: |
73 | break | 67 | break |
74 | await self.dispatch_message(d) | 68 | response = await self.dispatch_message(d) |
75 | await self.writer.drain() | 69 | if response is not self.NO_RESPONSE: |
76 | except ClientError as e: | 70 | await self.socket.send_message(response) |
71 | |||
72 | except ConnectionClosedError as e: | ||
73 | self.logger.info(str(e)) | ||
74 | except (ClientError, ConnectionError) as e: | ||
77 | self.logger.error(str(e)) | 75 | self.logger.error(str(e)) |
78 | finally: | 76 | finally: |
79 | self.writer.close() | 77 | await self.close() |
80 | 78 | ||
81 | async def dispatch_message(self, msg): | 79 | async def dispatch_message(self, msg): |
82 | for k in self.handlers.keys(): | 80 | for k in self.handlers.keys(): |
83 | if k in msg: | 81 | if k in msg: |
84 | self.logger.debug('Handling %s' % k) | 82 | self.logger.debug("Handling %s" % k) |
85 | await self.handlers[k](msg[k]) | 83 | return await self.handlers[k](msg[k]) |
86 | return | ||
87 | 84 | ||
88 | raise ClientError("Unrecognized command %r" % msg) | 85 | raise ClientError("Unrecognized command %r" % msg) |
89 | 86 | ||
90 | def write_message(self, msg): | 87 | async def handle_ping(self, request): |
91 | for c in chunkify(json.dumps(msg), self.max_chunk): | 88 | return {"alive": True} |
92 | self.writer.write(c.encode('utf-8')) | ||
93 | 89 | ||
94 | async def read_message(self): | ||
95 | l = await self.reader.readline() | ||
96 | if not l: | ||
97 | return None | ||
98 | 90 | ||
99 | try: | 91 | class StreamServer(object): |
100 | message = l.decode('utf-8') | 92 | def __init__(self, handler, logger): |
93 | self.handler = handler | ||
94 | self.logger = logger | ||
95 | self.closed = False | ||
101 | 96 | ||
102 | if not message.endswith('\n'): | 97 | async def handle_stream_client(self, reader, writer): |
103 | return None | 98 | # writer.transport.set_write_buffer_limits(0) |
99 | socket = StreamConnection(reader, writer, -1) | ||
100 | if self.closed: | ||
101 | await socket.close() | ||
102 | return | ||
103 | |||
104 | await self.handler(socket) | ||
105 | |||
106 | async def stop(self): | ||
107 | self.closed = True | ||
108 | |||
109 | |||
110 | class TCPStreamServer(StreamServer): | ||
111 | def __init__(self, host, port, handler, logger): | ||
112 | super().__init__(handler, logger) | ||
113 | self.host = host | ||
114 | self.port = port | ||
115 | |||
116 | def start(self, loop): | ||
117 | self.server = loop.run_until_complete( | ||
118 | asyncio.start_server(self.handle_stream_client, self.host, self.port) | ||
119 | ) | ||
120 | |||
121 | for s in self.server.sockets: | ||
122 | self.logger.debug("Listening on %r" % (s.getsockname(),)) | ||
123 | # Newer python does this automatically. Do it manually here for | ||
124 | # maximum compatibility | ||
125 | s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) | ||
126 | s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) | ||
127 | |||
128 | # Enable keep alives. This prevents broken client connections | ||
129 | # from persisting on the server for long periods of time. | ||
130 | s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) | ||
131 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) | ||
132 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15) | ||
133 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4) | ||
134 | |||
135 | name = self.server.sockets[0].getsockname() | ||
136 | if self.server.sockets[0].family == socket.AF_INET6: | ||
137 | self.address = "[%s]:%d" % (name[0], name[1]) | ||
138 | else: | ||
139 | self.address = "%s:%d" % (name[0], name[1]) | ||
140 | |||
141 | return [self.server.wait_closed()] | ||
142 | |||
143 | async def stop(self): | ||
144 | await super().stop() | ||
145 | self.server.close() | ||
146 | |||
147 | def cleanup(self): | ||
148 | pass | ||
104 | 149 | ||
105 | return json.loads(message) | ||
106 | except (json.JSONDecodeError, UnicodeDecodeError) as e: | ||
107 | self.logger.error('Bad message from client: %r' % message) | ||
108 | raise e | ||
109 | 150 | ||
110 | async def handle_chunk(self, request): | 151 | class UnixStreamServer(StreamServer): |
111 | lines = [] | 152 | def __init__(self, path, handler, logger): |
112 | try: | 153 | super().__init__(handler, logger) |
113 | while True: | 154 | self.path = path |
114 | l = await self.reader.readline() | ||
115 | l = l.rstrip(b"\n").decode("utf-8") | ||
116 | if not l: | ||
117 | break | ||
118 | lines.append(l) | ||
119 | 155 | ||
120 | msg = json.loads(''.join(lines)) | 156 | def start(self, loop): |
121 | except (json.JSONDecodeError, UnicodeDecodeError) as e: | 157 | cwd = os.getcwd() |
122 | self.logger.error('Bad message from client: %r' % lines) | 158 | try: |
123 | raise e | 159 | # Work around path length limits in AF_UNIX |
160 | os.chdir(os.path.dirname(self.path)) | ||
161 | self.server = loop.run_until_complete( | ||
162 | asyncio.start_unix_server( | ||
163 | self.handle_stream_client, os.path.basename(self.path) | ||
164 | ) | ||
165 | ) | ||
166 | finally: | ||
167 | os.chdir(cwd) | ||
124 | 168 | ||
125 | if 'chunk-stream' in msg: | 169 | self.logger.debug("Listening on %r" % self.path) |
126 | raise ClientError("Nested chunks are not allowed") | 170 | self.address = "unix://%s" % os.path.abspath(self.path) |
171 | return [self.server.wait_closed()] | ||
127 | 172 | ||
128 | await self.dispatch_message(msg) | 173 | async def stop(self): |
174 | await super().stop() | ||
175 | self.server.close() | ||
129 | 176 | ||
130 | async def handle_ping(self, request): | 177 | def cleanup(self): |
131 | response = {'alive': True} | 178 | os.unlink(self.path) |
132 | self.write_message(response) | ||
133 | 179 | ||
134 | 180 | ||
135 | class AsyncServer(object): | 181 | class AsyncServer(object): |
136 | def __init__(self, logger): | 182 | def __init__(self, logger): |
137 | self._cleanup_socket = None | ||
138 | self.logger = logger | 183 | self.logger = logger |
139 | self.start = None | ||
140 | self.address = None | ||
141 | self.loop = None | 184 | self.loop = None |
185 | self.run_tasks = [] | ||
142 | 186 | ||
143 | def start_tcp_server(self, host, port): | 187 | def start_tcp_server(self, host, port): |
144 | def start_tcp(): | 188 | self.server = TCPStreamServer(host, port, self._client_handler, self.logger) |
145 | self.server = self.loop.run_until_complete( | ||
146 | asyncio.start_server(self.handle_client, host, port) | ||
147 | ) | ||
148 | |||
149 | for s in self.server.sockets: | ||
150 | self.logger.debug('Listening on %r' % (s.getsockname(),)) | ||
151 | # Newer python does this automatically. Do it manually here for | ||
152 | # maximum compatibility | ||
153 | s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) | ||
154 | s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) | ||
155 | |||
156 | # Enable keep alives. This prevents broken client connections | ||
157 | # from persisting on the server for long periods of time. | ||
158 | s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) | ||
159 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) | ||
160 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15) | ||
161 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4) | ||
162 | |||
163 | name = self.server.sockets[0].getsockname() | ||
164 | if self.server.sockets[0].family == socket.AF_INET6: | ||
165 | self.address = "[%s]:%d" % (name[0], name[1]) | ||
166 | else: | ||
167 | self.address = "%s:%d" % (name[0], name[1]) | ||
168 | |||
169 | self.start = start_tcp | ||
170 | 189 | ||
171 | def start_unix_server(self, path): | 190 | def start_unix_server(self, path): |
172 | def cleanup(): | 191 | self.server = UnixStreamServer(path, self._client_handler, self.logger) |
173 | os.unlink(path) | ||
174 | |||
175 | def start_unix(): | ||
176 | cwd = os.getcwd() | ||
177 | try: | ||
178 | # Work around path length limits in AF_UNIX | ||
179 | os.chdir(os.path.dirname(path)) | ||
180 | self.server = self.loop.run_until_complete( | ||
181 | asyncio.start_unix_server(self.handle_client, os.path.basename(path)) | ||
182 | ) | ||
183 | finally: | ||
184 | os.chdir(cwd) | ||
185 | |||
186 | self.logger.debug('Listening on %r' % path) | ||
187 | 192 | ||
188 | self._cleanup_socket = cleanup | 193 | async def _client_handler(self, socket): |
189 | self.address = "unix://%s" % os.path.abspath(path) | ||
190 | |||
191 | self.start = start_unix | ||
192 | |||
193 | @abc.abstractmethod | ||
194 | def accept_client(self, reader, writer): | ||
195 | pass | ||
196 | |||
197 | async def handle_client(self, reader, writer): | ||
198 | # writer.transport.set_write_buffer_limits(0) | ||
199 | try: | 194 | try: |
200 | client = self.accept_client(reader, writer) | 195 | client = self.accept_client(socket) |
201 | await client.process_requests() | 196 | await client.process_requests() |
202 | except Exception as e: | 197 | except Exception as e: |
203 | import traceback | 198 | import traceback |
204 | self.logger.error('Error from client: %s' % str(e), exc_info=True) | 199 | |
200 | self.logger.error("Error from client: %s" % str(e), exc_info=True) | ||
205 | traceback.print_exc() | 201 | traceback.print_exc() |
206 | writer.close() | 202 | await socket.close() |
207 | self.logger.debug('Client disconnected') | 203 | self.logger.debug("Client disconnected") |
208 | 204 | ||
209 | def run_loop_forever(self): | 205 | @abc.abstractmethod |
210 | try: | 206 | def accept_client(self, socket): |
211 | self.loop.run_forever() | 207 | pass |
212 | except KeyboardInterrupt: | 208 | |
213 | pass | 209 | async def stop(self): |
210 | self.logger.debug("Stopping server") | ||
211 | await self.server.stop() | ||
212 | |||
213 | def start(self): | ||
214 | tasks = self.server.start(self.loop) | ||
215 | self.address = self.server.address | ||
216 | return tasks | ||
214 | 217 | ||
215 | def signal_handler(self): | 218 | def signal_handler(self): |
216 | self.logger.debug("Got exit signal") | 219 | self.logger.debug("Got exit signal") |
217 | self.loop.stop() | 220 | self.loop.create_task(self.stop()) |
218 | 221 | ||
219 | def _serve_forever(self): | 222 | def _serve_forever(self, tasks): |
220 | try: | 223 | try: |
221 | self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) | 224 | self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) |
225 | self.loop.add_signal_handler(signal.SIGINT, self.signal_handler) | ||
226 | self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler) | ||
222 | signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) | 227 | signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) |
223 | 228 | ||
224 | self.run_loop_forever() | 229 | self.loop.run_until_complete(asyncio.gather(*tasks)) |
225 | self.server.close() | ||
226 | 230 | ||
227 | self.loop.run_until_complete(self.server.wait_closed()) | 231 | self.logger.debug("Server shutting down") |
228 | self.logger.debug('Server shutting down') | ||
229 | finally: | 232 | finally: |
230 | if self._cleanup_socket is not None: | 233 | self.server.cleanup() |
231 | self._cleanup_socket() | ||
232 | 234 | ||
233 | def serve_forever(self): | 235 | def serve_forever(self): |
234 | """ | 236 | """ |
235 | Serve requests in the current process | 237 | Serve requests in the current process |
236 | """ | 238 | """ |
239 | self._create_loop() | ||
240 | tasks = self.start() | ||
241 | self._serve_forever(tasks) | ||
242 | self.loop.close() | ||
243 | |||
244 | def _create_loop(self): | ||
237 | # Create loop and override any loop that may have existed in | 245 | # Create loop and override any loop that may have existed in |
238 | # a parent process. It is possible that the usecases of | 246 | # a parent process. It is possible that the usecases of |
239 | # serve_forever might be constrained enough to allow using | 247 | # serve_forever might be constrained enough to allow using |
240 | # get_event_loop here, but better safe than sorry for now. | 248 | # get_event_loop here, but better safe than sorry for now. |
241 | self.loop = asyncio.new_event_loop() | 249 | self.loop = asyncio.new_event_loop() |
242 | asyncio.set_event_loop(self.loop) | 250 | asyncio.set_event_loop(self.loop) |
243 | self.start() | ||
244 | self._serve_forever() | ||
245 | 251 | ||
246 | def serve_as_process(self, *, prefunc=None, args=()): | 252 | def serve_as_process(self, *, prefunc=None, args=()): |
247 | """ | 253 | """ |
248 | Serve requests in a child process | 254 | Serve requests in a child process |
249 | """ | 255 | """ |
256 | |||
250 | def run(queue): | 257 | def run(queue): |
251 | # Create loop and override any loop that may have existed | 258 | # Create loop and override any loop that may have existed |
252 | # in a parent process. Without doing this and instead | 259 | # in a parent process. Without doing this and instead |
@@ -259,18 +266,19 @@ class AsyncServer(object): | |||
259 | # more general, though, as any potential use of asyncio in | 266 | # more general, though, as any potential use of asyncio in |
260 | # Cooker could create a loop that needs to replaced in this | 267 | # Cooker could create a loop that needs to replaced in this |
261 | # new process. | 268 | # new process. |
262 | self.loop = asyncio.new_event_loop() | 269 | self._create_loop() |
263 | asyncio.set_event_loop(self.loop) | ||
264 | try: | 270 | try: |
265 | self.start() | 271 | self.address = None |
272 | tasks = self.start() | ||
266 | finally: | 273 | finally: |
274 | # Always put the server address to wake up the parent task | ||
267 | queue.put(self.address) | 275 | queue.put(self.address) |
268 | queue.close() | 276 | queue.close() |
269 | 277 | ||
270 | if prefunc is not None: | 278 | if prefunc is not None: |
271 | prefunc(self, *args) | 279 | prefunc(self, *args) |
272 | 280 | ||
273 | self._serve_forever() | 281 | self._serve_forever(tasks) |
274 | 282 | ||
275 | if sys.version_info >= (3, 6): | 283 | if sys.version_info >= (3, 6): |
276 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) | 284 | self.loop.run_until_complete(self.loop.shutdown_asyncgens()) |