summary refs log tree commit diff stats
path: root/bitbake/lib/bb/asyncrpc
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc')
-rw-r--r--bitbake/lib/bb/asyncrpc/__init__.py16
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py271
-rw-r--r--bitbake/lib/bb/asyncrpc/connection.py146
-rw-r--r--bitbake/lib/bb/asyncrpc/exceptions.py21
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py413
5 files changed, 867 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/__init__.py b/bitbake/lib/bb/asyncrpc/__init__.py
new file mode 100644
index 0000000000..a4371643d7
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/__init__.py
@@ -0,0 +1,16 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7
8from .client import AsyncClient, Client
9from .serv import AsyncServer, AsyncServerConnection
10from .connection import DEFAULT_MAX_CHUNK
11from .exceptions import (
12 ClientError,
13 ServerError,
14 ConnectionClosedError,
15 InvokeError,
16)
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
new file mode 100644
index 0000000000..17b72033b9
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -0,0 +1,271 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import socket
12import sys
13import re
14import contextlib
15from threading import Thread
16from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
17from .exceptions import ConnectionClosedError, InvokeError
18
UNIX_PREFIX = "unix://"
WS_PREFIX = "ws://"
WSS_PREFIX = "wss://"

ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2

WEBSOCKETS_MIN_VERSION = (9, 1)
# Need websockets 10 with python 3.10+
if sys.version_info >= (3, 10, 0):
    WEBSOCKETS_MIN_VERSION = (10, 0)


def parse_address(addr):
    """Classify *addr* and return an ``(ADDR_TYPE_*, args)`` pair.

    The args tuple matches the parameters of the corresponding
    AsyncClient.connect_*() method.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))

    if addr.startswith((WS_PREFIX, WSS_PREFIX)):
        return (ADDR_TYPE_WS, (addr,))

    # Bracketed IPv6 form "[host]:port"; anything else is plain "host:port"
    bracketed = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
    if bracketed is None:
        host, port = addr.split(":")
    else:
        host = bracketed.group("host")
        port = bracketed.group("port")

    return (ADDR_TYPE_TCP, (host, int(port)))
47
48
class AsyncClient(object):
    """Asynchronous RPC client for the newline-delimited JSON protocol.

    A transport is selected with one of the connect_*() methods, which only
    record a connection factory; the actual connection and handshake are
    performed lazily by connect().
    """

    def __init__(
        self,
        proto_name,
        proto_version,
        logger,
        timeout=30,
        server_headers=False,
        headers=None,
    ):
        self.socket = None
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.proto_name = proto_name
        self.proto_version = proto_version
        self.logger = logger
        self.timeout = timeout
        self.needs_server_headers = server_headers
        self.server_headers = {}
        # Fix: `headers={}` was a shared mutable default argument; use a None
        # sentinel so every instance gets its own dict
        self.headers = {} if headers is None else headers

    async def connect_tcp(self, address, port):
        """Select a TCP transport to address:port (connection deferred)."""

        async def connect_sock():
            reader, writer = await asyncio.open_connection(address, port)
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_unix(self, path):
        """Select a unix domain socket transport (connection deferred)."""

        async def connect_sock():
            # AF_UNIX has path length issues so chdir here to workaround
            cwd = os.getcwd()
            try:
                os.chdir(os.path.dirname(path))
                # The socket must be opened synchronously so that CWD doesn't get
                # changed out from underneath us so we pass as a sock into asyncio
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
                sock.connect(os.path.basename(path))
            finally:
                os.chdir(cwd)
            reader, writer = await asyncio.open_unix_connection(sock=sock)
            return StreamConnection(reader, writer, self.timeout, self.max_chunk)

        self._connect_sock = connect_sock

    async def connect_websocket(self, uri):
        """Select a websocket transport to *uri* (connection deferred).

        Raises ImportError if the installed websockets package is missing,
        unparseable, or older than WEBSOCKETS_MIN_VERSION.
        """
        import websockets

        try:
            version = tuple(
                int(v)
                for v in websockets.__version__.split(".")[
                    0 : len(WEBSOCKETS_MIN_VERSION)
                ]
            )
        except ValueError:
            raise ImportError(
                f"Unable to parse websockets version '{websockets.__version__}'"
            )

        if version < WEBSOCKETS_MIN_VERSION:
            min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
            raise ImportError(
                f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
            )

        async def connect_sock():
            try:
                websocket = await websockets.connect(
                    uri,
                    ping_interval=None,
                    open_timeout=self.timeout,
                )
            except asyncio.exceptions.TimeoutError:
                raise ConnectionError("Timeout while connecting to websocket")
            except (OSError, websockets.InvalidHandshake, websockets.InvalidURI) as exc:
                raise ConnectionError(f"Could not connect to websocket: {exc}") from exc
            return WebsocketConnection(websocket, self.timeout)

        self._connect_sock = connect_sock

    async def setup_connection(self):
        """Perform the protocol handshake and header exchange."""
        # Send headers
        await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
        await self.socket.send(
            "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
        )
        for k, v in self.headers.items():
            await self.socket.send("%s: %s" % (k, v))

        # End of headers
        await self.socket.send("")

        self.server_headers = {}
        if self.needs_server_headers:
            while True:
                line = await self.socket.recv()
                if not line:
                    # End headers
                    break
                tag, value = line.split(":", 1)
                self.server_headers[tag.lower()] = value.strip()

    async def get_header(self, tag, default):
        """Return a server-sent header value, connecting first if needed."""
        await self.connect()
        return self.server_headers.get(tag, default)

    async def connect(self):
        """Connect and handshake if not already connected (idempotent)."""
        if self.socket is None:
            self.socket = await self._connect_sock()
            await self.setup_connection()

    async def disconnect(self):
        if self.socket is not None:
            await self.socket.close()
            self.socket = None

    async def close(self):
        await self.disconnect()

    async def _send_wrapper(self, proc):
        """Run *proc*, reconnecting and retrying up to 3 times on I/O or
        decode errors; non-ConnectionError failures are converted on the
        final attempt so callers see a consistent exception type."""
        count = 0
        while True:
            try:
                await self.connect()
                return await proc()
            except (
                OSError,
                ConnectionError,
                ConnectionClosedError,
                json.JSONDecodeError,
                UnicodeDecodeError,
            ) as e:
                self.logger.warning("Error talking to server: %s" % e)
                if count >= 3:
                    if not isinstance(e, ConnectionError):
                        raise ConnectionError(str(e))
                    raise e
                await self.close()
                count += 1

    def check_invoke_error(self, msg):
        """Raise InvokeError if *msg* is an invoke-error response dict."""
        if isinstance(msg, dict) and "invoke-error" in msg:
            raise InvokeError(msg["invoke-error"]["message"])

    async def invoke(self, msg):
        """Send *msg* and return the server's reply (with retries)."""

        async def proc():
            await self.socket.send_message(msg)
            return await self.socket.recv_message()

        result = await self._send_wrapper(proc)
        self.check_invoke_error(result)
        return result

    async def ping(self):
        """Round-trip liveness check against the server."""
        return await self.invoke({"ping": {}})

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()
211
class Client(object):
    """Blocking facade over an AsyncClient, driving a private event loop."""

    def __init__(self):
        self.client = self._get_async_client()
        self.loop = asyncio.new_event_loop()

        # Override any pre-existing loop.
        # Without this, the PR server export selftest triggers a hang
        # when running with Python 3.7. The drawback is that there is
        # potential for issues if the PR and hash equiv (or some new)
        # clients need to both be instantiated in the same process.
        # This should be revisited if/when Python 3.9 becomes the
        # minimum required version for BitBake, as it seems not
        # required (but harmless) with it.
        asyncio.set_event_loop(self.loop)

        self._add_methods("connect_tcp", "ping")

    @abc.abstractmethod
    def _get_async_client(self):
        """Return the AsyncClient instance to wrap (subclass hook)."""

    def _get_downcall_wrapper(self, downcall):
        """Wrap async *downcall* so it blocks on this client's loop."""

        def wrapper(*args, **kwargs):
            return self.loop.run_until_complete(downcall(*args, **kwargs))

        return wrapper

    def _add_methods(self, *methods):
        """Expose blocking versions of the named async client methods."""
        for name in methods:
            coro_fn = getattr(self.client, name)
            setattr(self, name, self._get_downcall_wrapper(coro_fn))

    def connect_unix(self, path):
        self.loop.run_until_complete(self.client.connect_unix(path))
        self.loop.run_until_complete(self.client.connect())

    @property
    def max_chunk(self):
        return self.client.max_chunk

    @max_chunk.setter
    def max_chunk(self, value):
        self.client.max_chunk = value

    def disconnect(self):
        self.loop.run_until_complete(self.client.close())

    def close(self):
        # Idempotent: once the loop is torn down, further closes are no-ops
        if self.loop:
            self.loop.run_until_complete(self.client.close())
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()
            self.loop = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
diff --git a/bitbake/lib/bb/asyncrpc/connection.py b/bitbake/lib/bb/asyncrpc/connection.py
new file mode 100644
index 0000000000..7f0cf6ba96
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/connection.py
@@ -0,0 +1,146 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import asyncio
8import itertools
9import json
10from datetime import datetime
11from .exceptions import ClientError, ConnectionClosedError
12
13
# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that will slow down the
# connection setup with a round trip delay so I'd rather not do that unless it
# is necessary
DEFAULT_MAX_CHUNK = 32 * 1024


def chunkify(msg, max_chunk):
    """Yield *msg* as newline-terminated wire chunks.

    A message shorter than ``max_chunk - 1`` is emitted as a single line.
    Longer messages are announced with a '{"chunk-stream": null}' marker
    line, split into (max_chunk - 1)-character lines, and terminated with
    an empty line.
    """
    if len(msg) < max_chunk - 1:
        yield msg + "\n"
        return

    yield json.dumps({"chunk-stream": None}) + "\n"

    # Group the message characters into fixed-size pieces
    grouper = [iter(msg)] * (max_chunk - 1)
    for piece in map("".join, itertools.zip_longest(*grouper, fillvalue="")):
        yield piece + "\n"
    yield "\n"
32
33
def json_serialize(obj):
    """``json.dumps`` default hook: encode datetimes as ISO-8601 strings."""
    if not isinstance(obj, datetime):
        raise TypeError("Type %s not serializeable" % type(obj))
    return obj.isoformat()
38
39
class StreamConnection(object):
    """Line/JSON message transport over asyncio StreamReader/StreamWriter."""

    def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
        self.reader = reader
        self.writer = writer
        self.timeout = timeout
        self.max_chunk = max_chunk

    @property
    def address(self):
        """Peer address of the underlying transport."""
        return self.writer.get_extra_info("peername")

    async def send_message(self, msg):
        """JSON-encode *msg* and write it out, chunked to self.max_chunk."""
        payload = json.dumps(msg, default=json_serialize)
        for piece in chunkify(payload, self.max_chunk):
            self.writer.write(piece.encode("utf-8"))
            await self.writer.drain()

    async def recv_message(self):
        """Read one JSON message, reassembling chunk-streamed payloads."""
        first = await self.recv()

        message = json.loads(first)
        if not message:
            return message

        if "chunk-stream" in message:
            # Chunked payload: gather lines until the empty terminator
            parts = []
            while True:
                part = await self.recv()
                if not part:
                    break
                parts.append(part)

            message = json.loads("".join(parts))

        return message

    async def send(self, msg):
        """Write *msg* as a single newline-terminated text line."""
        self.writer.write(("%s\n" % msg).encode("utf-8"))
        await self.writer.drain()

    async def recv(self):
        """Read one line, honouring self.timeout (negative disables it)."""
        if self.timeout < 0:
            line = await self.reader.readline()
        else:
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")

        if not line:
            raise ConnectionClosedError("Connection closed")

        line = line.decode("utf-8")

        # A line without a trailing newline means the peer sent a truncated
        # message (readline hit EOF mid-line)
        if not line.endswith("\n"):
            raise ConnectionError("Bad message %r" % (line))

        return line.rstrip()

    async def close(self):
        self.reader = None
        if self.writer is not None:
            self.writer.close()
            self.writer = None
103
104
class WebsocketConnection(object):
    """JSON message transport over a websockets connection object."""

    def __init__(self, socket, timeout):
        self.socket = socket
        self.timeout = timeout

    @property
    def address(self):
        """Remote peer rendered as "host:port"."""
        return ":".join(str(part) for part in self.socket.remote_address)

    async def send_message(self, msg):
        await self.send(json.dumps(msg, default=json_serialize))

    async def recv_message(self):
        return json.loads(await self.recv())

    async def send(self, msg):
        # Imported lazily so this module loads when websockets is absent
        import websockets.exceptions

        try:
            await self.socket.send(msg)
        except websockets.exceptions.ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def recv(self):
        import websockets.exceptions

        try:
            if self.timeout < 0:
                return await self.socket.recv()

            try:
                return await asyncio.wait_for(self.socket.recv(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")
        except websockets.exceptions.ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def close(self):
        if self.socket is not None:
            await self.socket.close()
            self.socket = None
diff --git a/bitbake/lib/bb/asyncrpc/exceptions.py b/bitbake/lib/bb/asyncrpc/exceptions.py
new file mode 100644
index 0000000000..ae1043a38b
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/exceptions.py
@@ -0,0 +1,21 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7
class ClientError(Exception):
    """A malformed or unrecognized client request (fatal for the connection)."""
    pass


class InvokeError(Exception):
    """An error reported by the server in an "invoke-error" response."""
    pass


class ServerError(Exception):
    """A server-side failure in the asyncrpc package."""
    pass


class ConnectionClosedError(Exception):
    """The underlying stream or websocket connection was closed."""
    pass
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
new file mode 100644
index 0000000000..bd1aded8db
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -0,0 +1,413 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import signal
12import socket
13import sys
14from bb import multiprocessing
15import logging
16from .connection import StreamConnection, WebsocketConnection
17from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
18
19
class ClientLoggerAdapter(logging.LoggerAdapter):
    """LoggerAdapter that prefixes every record with the client address."""

    def process(self, msg, kwargs):
        return "[Client {}] {}".format(self.extra["address"], msg), kwargs
23
24
class AsyncServerConnection(object):
    """Server-side handler for a single client connection.

    Subclasses register extra message handlers in self.handlers and must
    implement validate_proto_version() (not defined here).
    """

    # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
    # return message will be automatically be sent back to the client
    NO_RESPONSE = object()

    def __init__(self, socket, proto_name, logger):
        self.socket = socket
        self.proto_name = proto_name
        self.handlers = {
            "ping": self.handle_ping,
        }
        self.logger = ClientLoggerAdapter(
            logger,
            {
                "address": socket.address,
            },
        )
        self.client_headers = {}

    async def close(self):
        await self.socket.close()

    async def handle_headers(self, headers):
        """Subclass hook: return the server headers to send back."""
        return {}

    async def process_requests(self):
        """Run the connection to completion: handshake, header exchange,
        then the message dispatch loop. Always closes the socket."""
        try:
            self.logger.info("Client %r connected" % (self.socket.address,))

            # Read protocol and version
            client_protocol = await self.socket.recv()
            if not client_protocol:
                return

            (client_proto_name, client_proto_version) = client_protocol.split()
            if client_proto_name != self.proto_name:
                # Fix: log the protocol name the client actually sent, not
                # our own expected protocol name
                self.logger.debug(
                    "Rejecting invalid protocol %s" % (client_proto_name)
                )
                return

            self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
            if not self.validate_proto_version():
                self.logger.debug(
                    "Rejecting invalid protocol version %s" % (client_proto_version)
                )
                return

            # Read headers
            self.client_headers = {}
            while True:
                header = await self.socket.recv()
                if not header:
                    # Empty line. End of headers
                    break
                tag, value = header.split(":", 1)
                self.client_headers[tag.lower()] = value.strip()

            if self.client_headers.get("needs-headers", "false") == "true":
                for k, v in (await self.handle_headers(self.client_headers)).items():
                    await self.socket.send("%s: %s" % (k, v))
                await self.socket.send("")

            # Handle messages
            while True:
                d = await self.socket.recv_message()
                if d is None:
                    break
                try:
                    response = await self.dispatch_message(d)
                except InvokeError as e:
                    # Report handler failure to the client, then drop the
                    # connection
                    await self.socket.send_message(
                        {"invoke-error": {"message": str(e)}}
                    )
                    break

                if response is not self.NO_RESPONSE:
                    await self.socket.send_message(response)

        except ConnectionClosedError as e:
            self.logger.info(str(e))
        except (ClientError, ConnectionError) as e:
            self.logger.error(str(e))
        finally:
            await self.close()

    async def dispatch_message(self, msg):
        """Route *msg* to the first registered handler whose key it contains.

        Raises ClientError for unrecognized commands.
        """
        for k in self.handlers.keys():
            if k in msg:
                self.logger.debug("Handling %s" % k)
                return await self.handlers[k](msg[k])

        raise ClientError("Unrecognized command %r" % msg)

    async def handle_ping(self, request):
        """Built-in liveness handler."""
        return {"alive": True}
119
120
class StreamServer(object):
    """Base for stream servers: holds the handler, logger, and closed flag."""

    def __init__(self, handler, logger):
        self.handler = handler
        self.logger = logger
        self.closed = False

    async def handle_stream_client(self, reader, writer):
        # writer.transport.set_write_buffer_limits(0)
        connection = StreamConnection(reader, writer, -1)
        if self.closed:
            # Shutting down: refuse new clients
            await connection.close()
            return

        await self.handler(connection)

    async def stop(self):
        self.closed = True
138
139
class TCPStreamServer(StreamServer):
    """TCP implementation of StreamServer, listening on host:port."""

    def __init__(self, host, port, handler, logger, *, reuseport=False):
        super().__init__(handler, logger)
        self.host = host
        self.port = port
        self.reuseport = reuseport

    def start(self, loop):
        """Create the listening server on *loop* and return its wait tasks."""
        self.server = loop.run_until_complete(
            asyncio.start_server(
                self.handle_stream_client,
                self.host,
                self.port,
                reuse_port=self.reuseport,
            )
        )

        for sock in self.server.sockets:
            self.logger.debug("Listening on %r" % (sock.getsockname(),))
            # Newer python does this automatically. Do it manually here for
            # maximum compatibility
            sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

            # Enable keep alives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

        listen_sock = self.server.sockets[0]
        name = listen_sock.getsockname()
        if listen_sock.family == socket.AF_INET6:
            self.address = "[%s]:%d" % (name[0], name[1])
        else:
            self.address = "%s:%d" % (name[0], name[1])

        return [self.server.wait_closed()]

    async def stop(self):
        await super().stop()
        self.server.close()

    def cleanup(self):
        pass
185
186
class UnixStreamServer(StreamServer):
    """Unix domain socket implementation of StreamServer."""

    def __init__(self, path, handler, logger):
        super().__init__(handler, logger)
        self.path = path

    def start(self, loop):
        """Bind the unix socket and return the server's wait tasks."""
        saved_cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX
            os.chdir(os.path.dirname(self.path))
            self.server = loop.run_until_complete(
                asyncio.start_unix_server(
                    self.handle_stream_client, os.path.basename(self.path)
                )
            )
        finally:
            os.chdir(saved_cwd)

        self.logger.debug("Listening on %r" % self.path)
        self.address = "unix://%s" % os.path.abspath(self.path)
        return [self.server.wait_closed()]

    async def stop(self):
        await super().stop()
        self.server.close()

    def cleanup(self):
        # Remove the stale socket file; ignore if it never existed
        try:
            os.unlink(self.path)
        except FileNotFoundError:
            pass
218
219
class WebsocketsServer(object):
    """Websocket transport server mirroring the StreamServer interface."""

    def __init__(self, host, port, handler, logger, *, reuseport=False):
        self.host = host
        self.port = port
        self.handler = handler
        self.logger = logger
        self.reuseport = reuseport

    def start(self, loop):
        """Start listening and return the server's wait tasks."""
        import websockets.server

        self.server = loop.run_until_complete(
            websockets.server.serve(
                self.client_handler,
                self.host,
                self.port,
                ping_interval=None,
                reuse_port=self.reuseport,
            )
        )

        for sock in self.server.sockets:
            self.logger.debug("Listening on %r" % (sock.getsockname(),))

            # Enable keep alives. This prevents broken client connections
            # from persisting on the server for long periods of time.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)

        listen_sock = self.server.sockets[0]
        name = listen_sock.getsockname()
        if listen_sock.family == socket.AF_INET6:
            self.address = "ws://[%s]:%d" % (name[0], name[1])
        else:
            self.address = "ws://%s:%d" % (name[0], name[1])

        return [self.server.wait_closed()]

    async def stop(self):
        self.server.close()

    def cleanup(self):
        pass

    async def client_handler(self, websocket):
        # Wrap the raw websocket in our connection abstraction (no timeout)
        connection = WebsocketConnection(websocket, -1)
        await self.handler(connection)
268
269
class AsyncServer(object):
    """Base asyncio RPC server supporting TCP, unix socket, and websocket
    transports.

    Subclasses implement accept_client() to wrap each incoming connection.
    A transport is configured with one of the start_*_server() methods, then
    the server is run with serve_forever() or serve_as_process().
    """

    def __init__(self, logger):
        self.logger = logger
        self.loop = None
        self.run_tasks = []

    def start_tcp_server(self, host, port, *, reuseport=False):
        # Configure (but do not yet start) a TCP listener
        self.server = TCPStreamServer(
            host,
            port,
            self._client_handler,
            self.logger,
            reuseport=reuseport,
        )

    def start_unix_server(self, path):
        # Configure (but do not yet start) a unix domain socket listener
        self.server = UnixStreamServer(path, self._client_handler, self.logger)

    def start_websocket_server(self, host, port, reuseport=False):
        # Configure (but do not yet start) a websocket listener
        self.server = WebsocketsServer(
            host,
            port,
            self._client_handler,
            self.logger,
            reuseport=reuseport,
        )

    async def _client_handler(self, socket):
        """Handle one client: accept, process all requests, always close."""
        address = socket.address
        try:
            client = self.accept_client(socket)
            await client.process_requests()
        except Exception as e:
            import traceback

            self.logger.error(
                "Error from client %s: %s" % (address, str(e)), exc_info=True
            )
            traceback.print_exc()
        finally:
            self.logger.debug("Client %s disconnected", address)
            await socket.close()

    @abc.abstractmethod
    def accept_client(self, socket):
        """Subclass hook: return a connection object with process_requests()."""
        pass

    async def stop(self):
        """Stop the configured transport server."""
        self.logger.debug("Stopping server")
        await self.server.stop()

    def start(self):
        """Start the transport server on self.loop; return its wait tasks."""
        tasks = self.server.start(self.loop)
        self.address = self.server.address
        return tasks

    def signal_handler(self):
        # Schedule an orderly shutdown when an exit signal arrives
        self.logger.debug("Got exit signal")
        self.loop.create_task(self.stop())

    def _serve_forever(self, tasks):
        """Run *tasks* to completion with exit-signal handlers installed."""
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
            self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
            # SIGTERM may have been blocked by serve_as_process(); allow it
            # now that the handler is in place
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.loop.run_until_complete(asyncio.gather(*tasks))

            self.logger.debug("Server shutting down")
        finally:
            self.server.cleanup()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        self._create_loop()
        tasks = self.start()
        self._serve_forever(tasks)
        self.loop.close()

    def _create_loop(self):
        # Create loop and override any loop that may have existed in
        # a parent process. It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
        """
        Serve requests in a child process
        """

        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process. Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase. The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to replaced in this
            # new process.
            self._create_loop()
            try:
                self.address = None
                tasks = self.start()
            finally:
                # Always put the server address to wake up the parent task
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            if log_level is not None:
                self.logger.setLevel(log_level)

            self._serve_forever(tasks)

            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            # Block until the child reports its listening address (or None
            # if startup failed)
            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)