summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb
diff options
context:
space:
mode:
authorJoshua Watt <JPEWhacker@gmail.com>2023-11-03 08:26:19 -0600
committerRichard Purdie <richard.purdie@linuxfoundation.org>2023-11-09 17:33:02 +0000
commit8f8501ed403dec27acbe780b936bc087fc5006d0 (patch)
tree60e6415075c7c71eacec23ca7dda53e4a324b12e /bitbake/lib/bb
parentf97b686884166dd77d1818e70615027c6ba8c348 (diff)
downloadpoky-8f8501ed403dec27acbe780b936bc087fc5006d0.tar.gz
bitbake: asyncrpc: Abstract sockets
Rewrites the asyncrpc client and server code to make it possible to have other transport backends that are not stream based (e.g. websockets which are message based). The connection handling classes are now shared between both the client and server to make it easier to implement new transport mechanisms (Bitbake rev: 2aaeae53696e4c2f13a169830c3b7089cbad6eca) Signed-off-by: Joshua Watt <JPEWhacker@gmail.com> Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
Diffstat (limited to 'bitbake/lib/bb')
-rw-r--r--bitbake/lib/bb/asyncrpc/__init__.py32
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py78
-rw-r--r--bitbake/lib/bb/asyncrpc/connection.py95
-rw-r--r--bitbake/lib/bb/asyncrpc/exceptions.py17
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py304
5 files changed, 298 insertions, 228 deletions
diff --git a/bitbake/lib/bb/asyncrpc/__init__.py b/bitbake/lib/bb/asyncrpc/__init__.py
index 9a85e9965b..9f677eac4c 100644
--- a/bitbake/lib/bb/asyncrpc/__init__.py
+++ b/bitbake/lib/bb/asyncrpc/__init__.py
@@ -4,30 +4,12 @@
4# SPDX-License-Identifier: GPL-2.0-only 4# SPDX-License-Identifier: GPL-2.0-only
5# 5#
6 6
7import itertools
8import json
9
10# The Python async server defaults to a 64K receive buffer, so we hardcode our
11# maximum chunk size. It would be better if the client and server reported to
12# each other what the maximum chunk sizes were, but that will slow down the
13# connection setup with a round trip delay so I'd rather not do that unless it
14# is necessary
15DEFAULT_MAX_CHUNK = 32 * 1024
16
17
18def chunkify(msg, max_chunk):
19 if len(msg) < max_chunk - 1:
20 yield ''.join((msg, "\n"))
21 else:
22 yield ''.join((json.dumps({
23 'chunk-stream': None
24 }), "\n"))
25
26 args = [iter(msg)] * (max_chunk - 1)
27 for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
28 yield ''.join(itertools.chain(m, "\n"))
29 yield "\n"
30
31 7
32from .client import AsyncClient, Client 8from .client import AsyncClient, Client
33from .serv import AsyncServer, AsyncServerConnection, ClientError, ServerError 9from .serv import AsyncServer, AsyncServerConnection
10from .connection import DEFAULT_MAX_CHUNK
11from .exceptions import (
12 ClientError,
13 ServerError,
14 ConnectionClosedError,
15)
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
index fa042bbe87..7f33099b63 100644
--- a/bitbake/lib/bb/asyncrpc/client.py
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -10,13 +10,13 @@ import json
10import os 10import os
11import socket 11import socket
12import sys 12import sys
13from . import chunkify, DEFAULT_MAX_CHUNK 13from .connection import StreamConnection, DEFAULT_MAX_CHUNK
14from .exceptions import ConnectionClosedError
14 15
15 16
16class AsyncClient(object): 17class AsyncClient(object):
17 def __init__(self, proto_name, proto_version, logger, timeout=30): 18 def __init__(self, proto_name, proto_version, logger, timeout=30):
18 self.reader = None 19 self.socket = None
19 self.writer = None
20 self.max_chunk = DEFAULT_MAX_CHUNK 20 self.max_chunk = DEFAULT_MAX_CHUNK
21 self.proto_name = proto_name 21 self.proto_name = proto_name
22 self.proto_version = proto_version 22 self.proto_version = proto_version
@@ -25,7 +25,8 @@ class AsyncClient(object):
25 25
26 async def connect_tcp(self, address, port): 26 async def connect_tcp(self, address, port):
27 async def connect_sock(): 27 async def connect_sock():
28 return await asyncio.open_connection(address, port) 28 reader, writer = await asyncio.open_connection(address, port)
29 return StreamConnection(reader, writer, self.timeout, self.max_chunk)
29 30
30 self._connect_sock = connect_sock 31 self._connect_sock = connect_sock
31 32
@@ -40,27 +41,27 @@ class AsyncClient(object):
40 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) 41 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
41 sock.connect(os.path.basename(path)) 42 sock.connect(os.path.basename(path))
42 finally: 43 finally:
43 os.chdir(cwd) 44 os.chdir(cwd)
44 return await asyncio.open_unix_connection(sock=sock) 45 reader, writer = await asyncio.open_unix_connection(sock=sock)
46 return StreamConnection(reader, writer, self.timeout, self.max_chunk)
45 47
46 self._connect_sock = connect_sock 48 self._connect_sock = connect_sock
47 49
48 async def setup_connection(self): 50 async def setup_connection(self):
49 s = '%s %s\n\n' % (self.proto_name, self.proto_version) 51 # Send headers
50 self.writer.write(s.encode("utf-8")) 52 await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
51 await self.writer.drain() 53 # End of headers
54 await self.socket.send("")
52 55
53 async def connect(self): 56 async def connect(self):
54 if self.reader is None or self.writer is None: 57 if self.socket is None:
55 (self.reader, self.writer) = await self._connect_sock() 58 self.socket = await self._connect_sock()
56 await self.setup_connection() 59 await self.setup_connection()
57 60
58 async def close(self): 61 async def close(self):
59 self.reader = None 62 if self.socket is not None:
60 63 await self.socket.close()
61 if self.writer is not None: 64 self.socket = None
62 self.writer.close()
63 self.writer = None
64 65
65 async def _send_wrapper(self, proc): 66 async def _send_wrapper(self, proc):
66 count = 0 67 count = 0
@@ -71,6 +72,7 @@ class AsyncClient(object):
71 except ( 72 except (
72 OSError, 73 OSError,
73 ConnectionError, 74 ConnectionError,
75 ConnectionClosedError,
74 json.JSONDecodeError, 76 json.JSONDecodeError,
75 UnicodeDecodeError, 77 UnicodeDecodeError,
76 ) as e: 78 ) as e:
@@ -82,49 +84,15 @@ class AsyncClient(object):
82 await self.close() 84 await self.close()
83 count += 1 85 count += 1
84 86
85 async def send_message(self, msg): 87 async def invoke(self, msg):
86 async def get_line():
87 try:
88 line = await asyncio.wait_for(self.reader.readline(), self.timeout)
89 except asyncio.TimeoutError:
90 raise ConnectionError("Timed out waiting for server")
91
92 if not line:
93 raise ConnectionError("Connection closed")
94
95 line = line.decode("utf-8")
96
97 if not line.endswith("\n"):
98 raise ConnectionError("Bad message %r" % (line))
99
100 return line
101
102 async def proc(): 88 async def proc():
103 for c in chunkify(json.dumps(msg), self.max_chunk): 89 await self.socket.send_message(msg)
104 self.writer.write(c.encode("utf-8")) 90 return await self.socket.recv_message()
105 await self.writer.drain()
106
107 l = await get_line()
108
109 m = json.loads(l)
110 if m and "chunk-stream" in m:
111 lines = []
112 while True:
113 l = (await get_line()).rstrip("\n")
114 if not l:
115 break
116 lines.append(l)
117
118 m = json.loads("".join(lines))
119
120 return m
121 91
122 return await self._send_wrapper(proc) 92 return await self._send_wrapper(proc)
123 93
124 async def ping(self): 94 async def ping(self):
125 return await self.send_message( 95 return await self.invoke({"ping": {}})
126 {'ping': {}}
127 )
128 96
129 97
130class Client(object): 98class Client(object):
@@ -142,7 +110,7 @@ class Client(object):
142 # required (but harmless) with it. 110 # required (but harmless) with it.
143 asyncio.set_event_loop(self.loop) 111 asyncio.set_event_loop(self.loop)
144 112
145 self._add_methods('connect_tcp', 'ping') 113 self._add_methods("connect_tcp", "ping")
146 114
147 @abc.abstractmethod 115 @abc.abstractmethod
148 def _get_async_client(self): 116 def _get_async_client(self):
diff --git a/bitbake/lib/bb/asyncrpc/connection.py b/bitbake/lib/bb/asyncrpc/connection.py
new file mode 100644
index 0000000000..c4fd24754c
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/connection.py
@@ -0,0 +1,95 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import asyncio
8import itertools
9import json
10from .exceptions import ClientError, ConnectionClosedError
11
12
13# The Python async server defaults to a 64K receive buffer, so we hardcode our
14# maximum chunk size. It would be better if the client and server reported to
15# each other what the maximum chunk sizes were, but that will slow down the
16# connection setup with a round trip delay so I'd rather not do that unless it
17# is necessary
18DEFAULT_MAX_CHUNK = 32 * 1024
19
20
21def chunkify(msg, max_chunk):
22 if len(msg) < max_chunk - 1:
23 yield "".join((msg, "\n"))
24 else:
25 yield "".join((json.dumps({"chunk-stream": None}), "\n"))
26
27 args = [iter(msg)] * (max_chunk - 1)
28 for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
29 yield "".join(itertools.chain(m, "\n"))
30 yield "\n"
31
32
33class StreamConnection(object):
34 def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
35 self.reader = reader
36 self.writer = writer
37 self.timeout = timeout
38 self.max_chunk = max_chunk
39
40 @property
41 def address(self):
42 return self.writer.get_extra_info("peername")
43
44 async def send_message(self, msg):
45 for c in chunkify(json.dumps(msg), self.max_chunk):
46 self.writer.write(c.encode("utf-8"))
47 await self.writer.drain()
48
49 async def recv_message(self):
50 l = await self.recv()
51
52 m = json.loads(l)
53 if not m:
54 return m
55
56 if "chunk-stream" in m:
57 lines = []
58 while True:
59 l = await self.recv()
60 if not l:
61 break
62 lines.append(l)
63
64 m = json.loads("".join(lines))
65
66 return m
67
68 async def send(self, msg):
69 self.writer.write(("%s\n" % msg).encode("utf-8"))
70 await self.writer.drain()
71
72 async def recv(self):
73 if self.timeout < 0:
74 line = await self.reader.readline()
75 else:
76 try:
77 line = await asyncio.wait_for(self.reader.readline(), self.timeout)
78 except asyncio.TimeoutError:
79 raise ConnectionError("Timed out waiting for data")
80
81 if not line:
82 raise ConnectionClosedError("Connection closed")
83
84 line = line.decode("utf-8")
85
86 if not line.endswith("\n"):
87 raise ConnectionError("Bad message %r" % (line))
88
89 return line.rstrip()
90
91 async def close(self):
92 self.reader = None
93 if self.writer is not None:
94 self.writer.close()
95 self.writer = None
diff --git a/bitbake/lib/bb/asyncrpc/exceptions.py b/bitbake/lib/bb/asyncrpc/exceptions.py
new file mode 100644
index 0000000000..a8942b4f0c
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/exceptions.py
@@ -0,0 +1,17 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7
8class ClientError(Exception):
9 pass
10
11
12class ServerError(Exception):
13 pass
14
15
16class ConnectionClosedError(Exception):
17 pass
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
index d2de4891b8..3e0d0632cb 100644
--- a/bitbake/lib/bb/asyncrpc/serv.py
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -12,241 +12,248 @@ import signal
12import socket 12import socket
13import sys 13import sys
14import multiprocessing 14import multiprocessing
15from . import chunkify, DEFAULT_MAX_CHUNK 15from .connection import StreamConnection
16 16from .exceptions import ClientError, ServerError, ConnectionClosedError
17
18class ClientError(Exception):
19 pass
20
21
22class ServerError(Exception):
23 pass
24 17
25 18
26class AsyncServerConnection(object): 19class AsyncServerConnection(object):
27 def __init__(self, reader, writer, proto_name, logger): 20 # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
28 self.reader = reader 21 # return message will be automatically be sent back to the client
29 self.writer = writer 22 NO_RESPONSE = object()
23
24 def __init__(self, socket, proto_name, logger):
25 self.socket = socket
30 self.proto_name = proto_name 26 self.proto_name = proto_name
31 self.max_chunk = DEFAULT_MAX_CHUNK
32 self.handlers = { 27 self.handlers = {
33 'chunk-stream': self.handle_chunk, 28 "ping": self.handle_ping,
34 'ping': self.handle_ping,
35 } 29 }
36 self.logger = logger 30 self.logger = logger
37 31
32 async def close(self):
33 await self.socket.close()
34
38 async def process_requests(self): 35 async def process_requests(self):
39 try: 36 try:
40 self.addr = self.writer.get_extra_info('peername') 37 self.logger.info("Client %r connected" % (self.socket.address,))
41 self.logger.debug('Client %r connected' % (self.addr,))
42 38
43 # Read protocol and version 39 # Read protocol and version
44 client_protocol = await self.reader.readline() 40 client_protocol = await self.socket.recv()
45 if not client_protocol: 41 if not client_protocol:
46 return 42 return
47 43
48 (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split() 44 (client_proto_name, client_proto_version) = client_protocol.split()
49 if client_proto_name != self.proto_name: 45 if client_proto_name != self.proto_name:
50 self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name)) 46 self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
51 return 47 return
52 48
53 self.proto_version = tuple(int(v) for v in client_proto_version.split('.')) 49 self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
54 if not self.validate_proto_version(): 50 if not self.validate_proto_version():
55 self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version)) 51 self.logger.debug(
52 "Rejecting invalid protocol version %s" % (client_proto_version)
53 )
56 return 54 return
57 55
58 # Read headers. Currently, no headers are implemented, so look for 56 # Read headers. Currently, no headers are implemented, so look for
59 # an empty line to signal the end of the headers 57 # an empty line to signal the end of the headers
60 while True: 58 while True:
61 line = await self.reader.readline() 59 header = await self.socket.recv()
62 if not line: 60 if not header:
63 return
64
65 line = line.decode('utf-8').rstrip()
66 if not line:
67 break 61 break
68 62
69 # Handle messages 63 # Handle messages
70 while True: 64 while True:
71 d = await self.read_message() 65 d = await self.socket.recv_message()
72 if d is None: 66 if d is None:
73 break 67 break
74 await self.dispatch_message(d) 68 response = await self.dispatch_message(d)
75 await self.writer.drain() 69 if response is not self.NO_RESPONSE:
76 except ClientError as e: 70 await self.socket.send_message(response)
71
72 except ConnectionClosedError as e:
73 self.logger.info(str(e))
74 except (ClientError, ConnectionError) as e:
77 self.logger.error(str(e)) 75 self.logger.error(str(e))
78 finally: 76 finally:
79 self.writer.close() 77 await self.close()
80 78
81 async def dispatch_message(self, msg): 79 async def dispatch_message(self, msg):
82 for k in self.handlers.keys(): 80 for k in self.handlers.keys():
83 if k in msg: 81 if k in msg:
84 self.logger.debug('Handling %s' % k) 82 self.logger.debug("Handling %s" % k)
85 await self.handlers[k](msg[k]) 83 return await self.handlers[k](msg[k])
86 return
87 84
88 raise ClientError("Unrecognized command %r" % msg) 85 raise ClientError("Unrecognized command %r" % msg)
89 86
90 def write_message(self, msg): 87 async def handle_ping(self, request):
91 for c in chunkify(json.dumps(msg), self.max_chunk): 88 return {"alive": True}
92 self.writer.write(c.encode('utf-8'))
93 89
94 async def read_message(self):
95 l = await self.reader.readline()
96 if not l:
97 return None
98 90
99 try: 91class StreamServer(object):
100 message = l.decode('utf-8') 92 def __init__(self, handler, logger):
93 self.handler = handler
94 self.logger = logger
95 self.closed = False
101 96
102 if not message.endswith('\n'): 97 async def handle_stream_client(self, reader, writer):
103 return None 98 # writer.transport.set_write_buffer_limits(0)
99 socket = StreamConnection(reader, writer, -1)
100 if self.closed:
101 await socket.close()
102 return
103
104 await self.handler(socket)
105
106 async def stop(self):
107 self.closed = True
108
109
110class TCPStreamServer(StreamServer):
111 def __init__(self, host, port, handler, logger):
112 super().__init__(handler, logger)
113 self.host = host
114 self.port = port
115
116 def start(self, loop):
117 self.server = loop.run_until_complete(
118 asyncio.start_server(self.handle_stream_client, self.host, self.port)
119 )
120
121 for s in self.server.sockets:
122 self.logger.debug("Listening on %r" % (s.getsockname(),))
123 # Newer python does this automatically. Do it manually here for
124 # maximum compatibility
125 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
126 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
127
128 # Enable keep alives. This prevents broken client connections
129 # from persisting on the server for long periods of time.
130 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
131 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
132 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
133 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
134
135 name = self.server.sockets[0].getsockname()
136 if self.server.sockets[0].family == socket.AF_INET6:
137 self.address = "[%s]:%d" % (name[0], name[1])
138 else:
139 self.address = "%s:%d" % (name[0], name[1])
140
141 return [self.server.wait_closed()]
142
143 async def stop(self):
144 await super().stop()
145 self.server.close()
146
147 def cleanup(self):
148 pass
104 149
105 return json.loads(message)
106 except (json.JSONDecodeError, UnicodeDecodeError) as e:
107 self.logger.error('Bad message from client: %r' % message)
108 raise e
109 150
110 async def handle_chunk(self, request): 151class UnixStreamServer(StreamServer):
111 lines = [] 152 def __init__(self, path, handler, logger):
112 try: 153 super().__init__(handler, logger)
113 while True: 154 self.path = path
114 l = await self.reader.readline()
115 l = l.rstrip(b"\n").decode("utf-8")
116 if not l:
117 break
118 lines.append(l)
119 155
120 msg = json.loads(''.join(lines)) 156 def start(self, loop):
121 except (json.JSONDecodeError, UnicodeDecodeError) as e: 157 cwd = os.getcwd()
122 self.logger.error('Bad message from client: %r' % lines) 158 try:
123 raise e 159 # Work around path length limits in AF_UNIX
160 os.chdir(os.path.dirname(self.path))
161 self.server = loop.run_until_complete(
162 asyncio.start_unix_server(
163 self.handle_stream_client, os.path.basename(self.path)
164 )
165 )
166 finally:
167 os.chdir(cwd)
124 168
125 if 'chunk-stream' in msg: 169 self.logger.debug("Listening on %r" % self.path)
126 raise ClientError("Nested chunks are not allowed") 170 self.address = "unix://%s" % os.path.abspath(self.path)
171 return [self.server.wait_closed()]
127 172
128 await self.dispatch_message(msg) 173 async def stop(self):
174 await super().stop()
175 self.server.close()
129 176
130 async def handle_ping(self, request): 177 def cleanup(self):
131 response = {'alive': True} 178 os.unlink(self.path)
132 self.write_message(response)
133 179
134 180
135class AsyncServer(object): 181class AsyncServer(object):
136 def __init__(self, logger): 182 def __init__(self, logger):
137 self._cleanup_socket = None
138 self.logger = logger 183 self.logger = logger
139 self.start = None
140 self.address = None
141 self.loop = None 184 self.loop = None
185 self.run_tasks = []
142 186
143 def start_tcp_server(self, host, port): 187 def start_tcp_server(self, host, port):
144 def start_tcp(): 188 self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
145 self.server = self.loop.run_until_complete(
146 asyncio.start_server(self.handle_client, host, port)
147 )
148
149 for s in self.server.sockets:
150 self.logger.debug('Listening on %r' % (s.getsockname(),))
151 # Newer python does this automatically. Do it manually here for
152 # maximum compatibility
153 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
154 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
155
156 # Enable keep alives. This prevents broken client connections
157 # from persisting on the server for long periods of time.
158 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
159 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
160 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
161 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
162
163 name = self.server.sockets[0].getsockname()
164 if self.server.sockets[0].family == socket.AF_INET6:
165 self.address = "[%s]:%d" % (name[0], name[1])
166 else:
167 self.address = "%s:%d" % (name[0], name[1])
168
169 self.start = start_tcp
170 189
171 def start_unix_server(self, path): 190 def start_unix_server(self, path):
172 def cleanup(): 191 self.server = UnixStreamServer(path, self._client_handler, self.logger)
173 os.unlink(path)
174
175 def start_unix():
176 cwd = os.getcwd()
177 try:
178 # Work around path length limits in AF_UNIX
179 os.chdir(os.path.dirname(path))
180 self.server = self.loop.run_until_complete(
181 asyncio.start_unix_server(self.handle_client, os.path.basename(path))
182 )
183 finally:
184 os.chdir(cwd)
185
186 self.logger.debug('Listening on %r' % path)
187 192
188 self._cleanup_socket = cleanup 193 async def _client_handler(self, socket):
189 self.address = "unix://%s" % os.path.abspath(path)
190
191 self.start = start_unix
192
193 @abc.abstractmethod
194 def accept_client(self, reader, writer):
195 pass
196
197 async def handle_client(self, reader, writer):
198 # writer.transport.set_write_buffer_limits(0)
199 try: 194 try:
200 client = self.accept_client(reader, writer) 195 client = self.accept_client(socket)
201 await client.process_requests() 196 await client.process_requests()
202 except Exception as e: 197 except Exception as e:
203 import traceback 198 import traceback
204 self.logger.error('Error from client: %s' % str(e), exc_info=True) 199
200 self.logger.error("Error from client: %s" % str(e), exc_info=True)
205 traceback.print_exc() 201 traceback.print_exc()
206 writer.close() 202 await socket.close()
207 self.logger.debug('Client disconnected') 203 self.logger.debug("Client disconnected")
208 204
209 def run_loop_forever(self): 205 @abc.abstractmethod
210 try: 206 def accept_client(self, socket):
211 self.loop.run_forever() 207 pass
212 except KeyboardInterrupt: 208
213 pass 209 async def stop(self):
210 self.logger.debug("Stopping server")
211 await self.server.stop()
212
213 def start(self):
214 tasks = self.server.start(self.loop)
215 self.address = self.server.address
216 return tasks
214 217
215 def signal_handler(self): 218 def signal_handler(self):
216 self.logger.debug("Got exit signal") 219 self.logger.debug("Got exit signal")
217 self.loop.stop() 220 self.loop.create_task(self.stop())
218 221
219 def _serve_forever(self): 222 def _serve_forever(self, tasks):
220 try: 223 try:
221 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) 224 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
225 self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
226 self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
222 signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) 227 signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
223 228
224 self.run_loop_forever() 229 self.loop.run_until_complete(asyncio.gather(*tasks))
225 self.server.close()
226 230
227 self.loop.run_until_complete(self.server.wait_closed()) 231 self.logger.debug("Server shutting down")
228 self.logger.debug('Server shutting down')
229 finally: 232 finally:
230 if self._cleanup_socket is not None: 233 self.server.cleanup()
231 self._cleanup_socket()
232 234
233 def serve_forever(self): 235 def serve_forever(self):
234 """ 236 """
235 Serve requests in the current process 237 Serve requests in the current process
236 """ 238 """
239 self._create_loop()
240 tasks = self.start()
241 self._serve_forever(tasks)
242 self.loop.close()
243
244 def _create_loop(self):
237 # Create loop and override any loop that may have existed in 245 # Create loop and override any loop that may have existed in
238 # a parent process. It is possible that the usecases of 246 # a parent process. It is possible that the usecases of
239 # serve_forever might be constrained enough to allow using 247 # serve_forever might be constrained enough to allow using
240 # get_event_loop here, but better safe than sorry for now. 248 # get_event_loop here, but better safe than sorry for now.
241 self.loop = asyncio.new_event_loop() 249 self.loop = asyncio.new_event_loop()
242 asyncio.set_event_loop(self.loop) 250 asyncio.set_event_loop(self.loop)
243 self.start()
244 self._serve_forever()
245 251
246 def serve_as_process(self, *, prefunc=None, args=()): 252 def serve_as_process(self, *, prefunc=None, args=()):
247 """ 253 """
248 Serve requests in a child process 254 Serve requests in a child process
249 """ 255 """
256
250 def run(queue): 257 def run(queue):
251 # Create loop and override any loop that may have existed 258 # Create loop and override any loop that may have existed
252 # in a parent process. Without doing this and instead 259 # in a parent process. Without doing this and instead
@@ -259,18 +266,19 @@ class AsyncServer(object):
259 # more general, though, as any potential use of asyncio in 266 # more general, though, as any potential use of asyncio in
 260 # Cooker could create a loop that needs to be replaced in this 267 # Cooker could create a loop that needs to be replaced in this
261 # new process. 268 # new process.
262 self.loop = asyncio.new_event_loop() 269 self._create_loop()
263 asyncio.set_event_loop(self.loop)
264 try: 270 try:
265 self.start() 271 self.address = None
272 tasks = self.start()
266 finally: 273 finally:
274 # Always put the server address to wake up the parent task
267 queue.put(self.address) 275 queue.put(self.address)
268 queue.close() 276 queue.close()
269 277
270 if prefunc is not None: 278 if prefunc is not None:
271 prefunc(self, *args) 279 prefunc(self, *args)
272 280
273 self._serve_forever() 281 self._serve_forever(tasks)
274 282
275 if sys.version_info >= (3, 6): 283 if sys.version_info >= (3, 6):
276 self.loop.run_until_complete(self.loop.shutdown_asyncgens()) 284 self.loop.run_until_complete(self.loop.shutdown_asyncgens())