summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc')
-rw-r--r--bitbake/lib/bb/asyncrpc/__init__.py16
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py271
-rw-r--r--bitbake/lib/bb/asyncrpc/connection.py146
-rw-r--r--bitbake/lib/bb/asyncrpc/exceptions.py21
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py410
5 files changed, 864 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/__init__.py b/bitbake/lib/bb/asyncrpc/__init__.py
new file mode 100644
index 0000000000..a4371643d7
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/__init__.py
@@ -0,0 +1,16 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7
8from .client import AsyncClient, Client
9from .serv import AsyncServer, AsyncServerConnection
10from .connection import DEFAULT_MAX_CHUNK
11from .exceptions import (
12 ClientError,
13 ServerError,
14 ConnectionClosedError,
15 InvokeError,
16)
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
new file mode 100644
index 0000000000..17b72033b9
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -0,0 +1,271 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import socket
12import sys
13import re
14import contextlib
15from threading import Thread
16from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
17from .exceptions import ConnectionClosedError, InvokeError
18
19UNIX_PREFIX = "unix://"
20WS_PREFIX = "ws://"
21WSS_PREFIX = "wss://"
22
23ADDR_TYPE_UNIX = 0
24ADDR_TYPE_TCP = 1
25ADDR_TYPE_WS = 2
26
27WEBSOCKETS_MIN_VERSION = (9, 1)
28# Need websockets 10 with python 3.10+
29if sys.version_info >= (3, 10, 0):
30 WEBSOCKETS_MIN_VERSION = (10, 0)
31
32
33def parse_address(addr):
34 if addr.startswith(UNIX_PREFIX):
35 return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
36 elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
37 return (ADDR_TYPE_WS, (addr,))
38 else:
39 m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
40 if m is not None:
41 host = m.group("host")
42 port = m.group("port")
43 else:
44 host, port = addr.split(":")
45
46 return (ADDR_TYPE_TCP, (host, int(port)))
47
48
49class AsyncClient(object):
50 def __init__(
51 self,
52 proto_name,
53 proto_version,
54 logger,
55 timeout=30,
56 server_headers=False,
57 headers={},
58 ):
59 self.socket = None
60 self.max_chunk = DEFAULT_MAX_CHUNK
61 self.proto_name = proto_name
62 self.proto_version = proto_version
63 self.logger = logger
64 self.timeout = timeout
65 self.needs_server_headers = server_headers
66 self.server_headers = {}
67 self.headers = headers
68
69 async def connect_tcp(self, address, port):
70 async def connect_sock():
71 reader, writer = await asyncio.open_connection(address, port)
72 return StreamConnection(reader, writer, self.timeout, self.max_chunk)
73
74 self._connect_sock = connect_sock
75
76 async def connect_unix(self, path):
77 async def connect_sock():
78 # AF_UNIX has path length issues so chdir here to workaround
79 cwd = os.getcwd()
80 try:
81 os.chdir(os.path.dirname(path))
82 # The socket must be opened synchronously so that CWD doesn't get
83 # changed out from underneath us so we pass as a sock into asyncio
84 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
85 sock.connect(os.path.basename(path))
86 finally:
87 os.chdir(cwd)
88 reader, writer = await asyncio.open_unix_connection(sock=sock)
89 return StreamConnection(reader, writer, self.timeout, self.max_chunk)
90
91 self._connect_sock = connect_sock
92
93 async def connect_websocket(self, uri):
94 import websockets
95
96 try:
97 version = tuple(
98 int(v)
99 for v in websockets.__version__.split(".")[
100 0 : len(WEBSOCKETS_MIN_VERSION)
101 ]
102 )
103 except ValueError:
104 raise ImportError(
105 f"Unable to parse websockets version '{websockets.__version__}'"
106 )
107
108 if version < WEBSOCKETS_MIN_VERSION:
109 min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
110 raise ImportError(
111 f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
112 )
113
114 async def connect_sock():
115 try:
116 websocket = await websockets.connect(
117 uri,
118 ping_interval=None,
119 open_timeout=self.timeout,
120 )
121 except asyncio.exceptions.TimeoutError:
122 raise ConnectionError("Timeout while connecting to websocket")
123 except (OSError, websockets.InvalidHandshake, websockets.InvalidURI) as exc:
124 raise ConnectionError(f"Could not connect to websocket: {exc}") from exc
125 return WebsocketConnection(websocket, self.timeout)
126
127 self._connect_sock = connect_sock
128
129 async def setup_connection(self):
130 # Send headers
131 await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
132 await self.socket.send(
133 "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
134 )
135 for k, v in self.headers.items():
136 await self.socket.send("%s: %s" % (k, v))
137
138 # End of headers
139 await self.socket.send("")
140
141 self.server_headers = {}
142 if self.needs_server_headers:
143 while True:
144 line = await self.socket.recv()
145 if not line:
146 # End headers
147 break
148 tag, value = line.split(":", 1)
149 self.server_headers[tag.lower()] = value.strip()
150
151 async def get_header(self, tag, default):
152 await self.connect()
153 return self.server_headers.get(tag, default)
154
155 async def connect(self):
156 if self.socket is None:
157 self.socket = await self._connect_sock()
158 await self.setup_connection()
159
160 async def disconnect(self):
161 if self.socket is not None:
162 await self.socket.close()
163 self.socket = None
164
165 async def close(self):
166 await self.disconnect()
167
168 async def _send_wrapper(self, proc):
169 count = 0
170 while True:
171 try:
172 await self.connect()
173 return await proc()
174 except (
175 OSError,
176 ConnectionError,
177 ConnectionClosedError,
178 json.JSONDecodeError,
179 UnicodeDecodeError,
180 ) as e:
181 self.logger.warning("Error talking to server: %s" % e)
182 if count >= 3:
183 if not isinstance(e, ConnectionError):
184 raise ConnectionError(str(e))
185 raise e
186 await self.close()
187 count += 1
188
189 def check_invoke_error(self, msg):
190 if isinstance(msg, dict) and "invoke-error" in msg:
191 raise InvokeError(msg["invoke-error"]["message"])
192
193 async def invoke(self, msg):
194 async def proc():
195 await self.socket.send_message(msg)
196 return await self.socket.recv_message()
197
198 result = await self._send_wrapper(proc)
199 self.check_invoke_error(result)
200 return result
201
202 async def ping(self):
203 return await self.invoke({"ping": {}})
204
205 async def __aenter__(self):
206 return self
207
208 async def __aexit__(self, exc_type, exc_value, traceback):
209 await self.close()
210
211
212class Client(object):
213 def __init__(self):
214 self.client = self._get_async_client()
215 self.loop = asyncio.new_event_loop()
216
217 # Override any pre-existing loop.
218 # Without this, the PR server export selftest triggers a hang
219 # when running with Python 3.7. The drawback is that there is
220 # potential for issues if the PR and hash equiv (or some new)
221 # clients need to both be instantiated in the same process.
222 # This should be revisited if/when Python 3.9 becomes the
223 # minimum required version for BitBake, as it seems not
224 # required (but harmless) with it.
225 asyncio.set_event_loop(self.loop)
226
227 self._add_methods("connect_tcp", "ping")
228
229 @abc.abstractmethod
230 def _get_async_client(self):
231 pass
232
233 def _get_downcall_wrapper(self, downcall):
234 def wrapper(*args, **kwargs):
235 return self.loop.run_until_complete(downcall(*args, **kwargs))
236
237 return wrapper
238
239 def _add_methods(self, *methods):
240 for m in methods:
241 downcall = getattr(self.client, m)
242 setattr(self, m, self._get_downcall_wrapper(downcall))
243
244 def connect_unix(self, path):
245 self.loop.run_until_complete(self.client.connect_unix(path))
246 self.loop.run_until_complete(self.client.connect())
247
248 @property
249 def max_chunk(self):
250 return self.client.max_chunk
251
252 @max_chunk.setter
253 def max_chunk(self, value):
254 self.client.max_chunk = value
255
256 def disconnect(self):
257 self.loop.run_until_complete(self.client.close())
258
259 def close(self):
260 if self.loop:
261 self.loop.run_until_complete(self.client.close())
262 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
263 self.loop.close()
264 self.loop = None
265
266 def __enter__(self):
267 return self
268
269 def __exit__(self, exc_type, exc_value, traceback):
270 self.close()
271 return False
diff --git a/bitbake/lib/bb/asyncrpc/connection.py b/bitbake/lib/bb/asyncrpc/connection.py
new file mode 100644
index 0000000000..7f0cf6ba96
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/connection.py
@@ -0,0 +1,146 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import asyncio
8import itertools
9import json
10from datetime import datetime
11from .exceptions import ClientError, ConnectionClosedError
12
13
14# The Python async server defaults to a 64K receive buffer, so we hardcode our
15# maximum chunk size. It would be better if the client and server reported to
16# each other what the maximum chunk sizes were, but that will slow down the
17# connection setup with a round trip delay so I'd rather not do that unless it
18# is necessary
19DEFAULT_MAX_CHUNK = 32 * 1024
20
21
22def chunkify(msg, max_chunk):
23 if len(msg) < max_chunk - 1:
24 yield "".join((msg, "\n"))
25 else:
26 yield "".join((json.dumps({"chunk-stream": None}), "\n"))
27
28 args = [iter(msg)] * (max_chunk - 1)
29 for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
30 yield "".join(itertools.chain(m, "\n"))
31 yield "\n"
32
33
34def json_serialize(obj):
35 if isinstance(obj, datetime):
36 return obj.isoformat()
37 raise TypeError("Type %s not serializeable" % type(obj))
38
39
40class StreamConnection(object):
41 def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
42 self.reader = reader
43 self.writer = writer
44 self.timeout = timeout
45 self.max_chunk = max_chunk
46
47 @property
48 def address(self):
49 return self.writer.get_extra_info("peername")
50
51 async def send_message(self, msg):
52 for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
53 self.writer.write(c.encode("utf-8"))
54 await self.writer.drain()
55
56 async def recv_message(self):
57 l = await self.recv()
58
59 m = json.loads(l)
60 if not m:
61 return m
62
63 if "chunk-stream" in m:
64 lines = []
65 while True:
66 l = await self.recv()
67 if not l:
68 break
69 lines.append(l)
70
71 m = json.loads("".join(lines))
72
73 return m
74
75 async def send(self, msg):
76 self.writer.write(("%s\n" % msg).encode("utf-8"))
77 await self.writer.drain()
78
79 async def recv(self):
80 if self.timeout < 0:
81 line = await self.reader.readline()
82 else:
83 try:
84 line = await asyncio.wait_for(self.reader.readline(), self.timeout)
85 except asyncio.TimeoutError:
86 raise ConnectionError("Timed out waiting for data")
87
88 if not line:
89 raise ConnectionClosedError("Connection closed")
90
91 line = line.decode("utf-8")
92
93 if not line.endswith("\n"):
94 raise ConnectionError("Bad message %r" % (line))
95
96 return line.rstrip()
97
98 async def close(self):
99 self.reader = None
100 if self.writer is not None:
101 self.writer.close()
102 self.writer = None
103
104
105class WebsocketConnection(object):
106 def __init__(self, socket, timeout):
107 self.socket = socket
108 self.timeout = timeout
109
110 @property
111 def address(self):
112 return ":".join(str(s) for s in self.socket.remote_address)
113
114 async def send_message(self, msg):
115 await self.send(json.dumps(msg, default=json_serialize))
116
117 async def recv_message(self):
118 m = await self.recv()
119 return json.loads(m)
120
121 async def send(self, msg):
122 import websockets.exceptions
123
124 try:
125 await self.socket.send(msg)
126 except websockets.exceptions.ConnectionClosed:
127 raise ConnectionClosedError("Connection closed")
128
129 async def recv(self):
130 import websockets.exceptions
131
132 try:
133 if self.timeout < 0:
134 return await self.socket.recv()
135
136 try:
137 return await asyncio.wait_for(self.socket.recv(), self.timeout)
138 except asyncio.TimeoutError:
139 raise ConnectionError("Timed out waiting for data")
140 except websockets.exceptions.ConnectionClosed:
141 raise ConnectionClosedError("Connection closed")
142
143 async def close(self):
144 if self.socket is not None:
145 await self.socket.close()
146 self.socket = None
diff --git a/bitbake/lib/bb/asyncrpc/exceptions.py b/bitbake/lib/bb/asyncrpc/exceptions.py
new file mode 100644
index 0000000000..ae1043a38b
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/exceptions.py
@@ -0,0 +1,21 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7
8class ClientError(Exception):
9 pass
10
11
12class InvokeError(Exception):
13 pass
14
15
16class ServerError(Exception):
17 pass
18
19
20class ConnectionClosedError(Exception):
21 pass
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
new file mode 100644
index 0000000000..667217c5c1
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -0,0 +1,410 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import signal
12import socket
13import sys
14import multiprocessing
15import logging
16from .connection import StreamConnection, WebsocketConnection
17from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
18
19
20class ClientLoggerAdapter(logging.LoggerAdapter):
21 def process(self, msg, kwargs):
22 return f"[Client {self.extra['address']}] {msg}", kwargs
23
24
25class AsyncServerConnection(object):
26 # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
27 # return message will be automatically be sent back to the client
28 NO_RESPONSE = object()
29
30 def __init__(self, socket, proto_name, logger):
31 self.socket = socket
32 self.proto_name = proto_name
33 self.handlers = {
34 "ping": self.handle_ping,
35 }
36 self.logger = ClientLoggerAdapter(
37 logger,
38 {
39 "address": socket.address,
40 },
41 )
42 self.client_headers = {}
43
44 async def close(self):
45 await self.socket.close()
46
47 async def handle_headers(self, headers):
48 return {}
49
50 async def process_requests(self):
51 try:
52 self.logger.info("Client %r connected" % (self.socket.address,))
53
54 # Read protocol and version
55 client_protocol = await self.socket.recv()
56 if not client_protocol:
57 return
58
59 (client_proto_name, client_proto_version) = client_protocol.split()
60 if client_proto_name != self.proto_name:
61 self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
62 return
63
64 self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
65 if not self.validate_proto_version():
66 self.logger.debug(
67 "Rejecting invalid protocol version %s" % (client_proto_version)
68 )
69 return
70
71 # Read headers
72 self.client_headers = {}
73 while True:
74 header = await self.socket.recv()
75 if not header:
76 # Empty line. End of headers
77 break
78 tag, value = header.split(":", 1)
79 self.client_headers[tag.lower()] = value.strip()
80
81 if self.client_headers.get("needs-headers", "false") == "true":
82 for k, v in (await self.handle_headers(self.client_headers)).items():
83 await self.socket.send("%s: %s" % (k, v))
84 await self.socket.send("")
85
86 # Handle messages
87 while True:
88 d = await self.socket.recv_message()
89 if d is None:
90 break
91 try:
92 response = await self.dispatch_message(d)
93 except InvokeError as e:
94 await self.socket.send_message(
95 {"invoke-error": {"message": str(e)}}
96 )
97 break
98
99 if response is not self.NO_RESPONSE:
100 await self.socket.send_message(response)
101
102 except ConnectionClosedError as e:
103 self.logger.info(str(e))
104 except (ClientError, ConnectionError) as e:
105 self.logger.error(str(e))
106 finally:
107 await self.close()
108
109 async def dispatch_message(self, msg):
110 for k in self.handlers.keys():
111 if k in msg:
112 self.logger.debug("Handling %s" % k)
113 return await self.handlers[k](msg[k])
114
115 raise ClientError("Unrecognized command %r" % msg)
116
117 async def handle_ping(self, request):
118 return {"alive": True}
119
120
121class StreamServer(object):
122 def __init__(self, handler, logger):
123 self.handler = handler
124 self.logger = logger
125 self.closed = False
126
127 async def handle_stream_client(self, reader, writer):
128 # writer.transport.set_write_buffer_limits(0)
129 socket = StreamConnection(reader, writer, -1)
130 if self.closed:
131 await socket.close()
132 return
133
134 await self.handler(socket)
135
136 async def stop(self):
137 self.closed = True
138
139
140class TCPStreamServer(StreamServer):
141 def __init__(self, host, port, handler, logger, *, reuseport=False):
142 super().__init__(handler, logger)
143 self.host = host
144 self.port = port
145 self.reuseport = reuseport
146
147 def start(self, loop):
148 self.server = loop.run_until_complete(
149 asyncio.start_server(
150 self.handle_stream_client,
151 self.host,
152 self.port,
153 reuse_port=self.reuseport,
154 )
155 )
156
157 for s in self.server.sockets:
158 self.logger.debug("Listening on %r" % (s.getsockname(),))
159 # Newer python does this automatically. Do it manually here for
160 # maximum compatibility
161 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
162 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
163
164 # Enable keep alives. This prevents broken client connections
165 # from persisting on the server for long periods of time.
166 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
167 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
168 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
169 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
170
171 name = self.server.sockets[0].getsockname()
172 if self.server.sockets[0].family == socket.AF_INET6:
173 self.address = "[%s]:%d" % (name[0], name[1])
174 else:
175 self.address = "%s:%d" % (name[0], name[1])
176
177 return [self.server.wait_closed()]
178
179 async def stop(self):
180 await super().stop()
181 self.server.close()
182
183 def cleanup(self):
184 pass
185
186
187class UnixStreamServer(StreamServer):
188 def __init__(self, path, handler, logger):
189 super().__init__(handler, logger)
190 self.path = path
191
192 def start(self, loop):
193 cwd = os.getcwd()
194 try:
195 # Work around path length limits in AF_UNIX
196 os.chdir(os.path.dirname(self.path))
197 self.server = loop.run_until_complete(
198 asyncio.start_unix_server(
199 self.handle_stream_client, os.path.basename(self.path)
200 )
201 )
202 finally:
203 os.chdir(cwd)
204
205 self.logger.debug("Listening on %r" % self.path)
206 self.address = "unix://%s" % os.path.abspath(self.path)
207 return [self.server.wait_closed()]
208
209 async def stop(self):
210 await super().stop()
211 self.server.close()
212
213 def cleanup(self):
214 os.unlink(self.path)
215
216
217class WebsocketsServer(object):
218 def __init__(self, host, port, handler, logger, *, reuseport=False):
219 self.host = host
220 self.port = port
221 self.handler = handler
222 self.logger = logger
223 self.reuseport = reuseport
224
225 def start(self, loop):
226 import websockets.server
227
228 self.server = loop.run_until_complete(
229 websockets.server.serve(
230 self.client_handler,
231 self.host,
232 self.port,
233 ping_interval=None,
234 reuse_port=self.reuseport,
235 )
236 )
237
238 for s in self.server.sockets:
239 self.logger.debug("Listening on %r" % (s.getsockname(),))
240
241 # Enable keep alives. This prevents broken client connections
242 # from persisting on the server for long periods of time.
243 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
244 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
245 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
246 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
247
248 name = self.server.sockets[0].getsockname()
249 if self.server.sockets[0].family == socket.AF_INET6:
250 self.address = "ws://[%s]:%d" % (name[0], name[1])
251 else:
252 self.address = "ws://%s:%d" % (name[0], name[1])
253
254 return [self.server.wait_closed()]
255
256 async def stop(self):
257 self.server.close()
258
259 def cleanup(self):
260 pass
261
262 async def client_handler(self, websocket):
263 socket = WebsocketConnection(websocket, -1)
264 await self.handler(socket)
265
266
267class AsyncServer(object):
268 def __init__(self, logger):
269 self.logger = logger
270 self.loop = None
271 self.run_tasks = []
272
273 def start_tcp_server(self, host, port, *, reuseport=False):
274 self.server = TCPStreamServer(
275 host,
276 port,
277 self._client_handler,
278 self.logger,
279 reuseport=reuseport,
280 )
281
282 def start_unix_server(self, path):
283 self.server = UnixStreamServer(path, self._client_handler, self.logger)
284
285 def start_websocket_server(self, host, port, reuseport=False):
286 self.server = WebsocketsServer(
287 host,
288 port,
289 self._client_handler,
290 self.logger,
291 reuseport=reuseport,
292 )
293
294 async def _client_handler(self, socket):
295 address = socket.address
296 try:
297 client = self.accept_client(socket)
298 await client.process_requests()
299 except Exception as e:
300 import traceback
301
302 self.logger.error(
303 "Error from client %s: %s" % (address, str(e)), exc_info=True
304 )
305 traceback.print_exc()
306 finally:
307 self.logger.debug("Client %s disconnected", address)
308 await socket.close()
309
310 @abc.abstractmethod
311 def accept_client(self, socket):
312 pass
313
314 async def stop(self):
315 self.logger.debug("Stopping server")
316 await self.server.stop()
317
318 def start(self):
319 tasks = self.server.start(self.loop)
320 self.address = self.server.address
321 return tasks
322
323 def signal_handler(self):
324 self.logger.debug("Got exit signal")
325 self.loop.create_task(self.stop())
326
327 def _serve_forever(self, tasks):
328 try:
329 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
330 self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
331 self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
332 signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
333
334 self.loop.run_until_complete(asyncio.gather(*tasks))
335
336 self.logger.debug("Server shutting down")
337 finally:
338 self.server.cleanup()
339
340 def serve_forever(self):
341 """
342 Serve requests in the current process
343 """
344 self._create_loop()
345 tasks = self.start()
346 self._serve_forever(tasks)
347 self.loop.close()
348
349 def _create_loop(self):
350 # Create loop and override any loop that may have existed in
351 # a parent process. It is possible that the usecases of
352 # serve_forever might be constrained enough to allow using
353 # get_event_loop here, but better safe than sorry for now.
354 self.loop = asyncio.new_event_loop()
355 asyncio.set_event_loop(self.loop)
356
357 def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
358 """
359 Serve requests in a child process
360 """
361
362 def run(queue):
363 # Create loop and override any loop that may have existed
364 # in a parent process. Without doing this and instead
365 # using get_event_loop, at the very minimum the hashserv
366 # unit tests will hang when running the second test.
367 # This happens since get_event_loop in the spawned server
368 # process for the second testcase ends up with the loop
369 # from the hashserv client created in the unit test process
370 # when running the first testcase. The problem is somewhat
371 # more general, though, as any potential use of asyncio in
372 # Cooker could create a loop that needs to replaced in this
373 # new process.
374 self._create_loop()
375 try:
376 self.address = None
377 tasks = self.start()
378 finally:
379 # Always put the server address to wake up the parent task
380 queue.put(self.address)
381 queue.close()
382
383 if prefunc is not None:
384 prefunc(self, *args)
385
386 if log_level is not None:
387 self.logger.setLevel(log_level)
388
389 self._serve_forever(tasks)
390
391 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
392 self.loop.close()
393
394 queue = multiprocessing.Queue()
395
396 # Temporarily block SIGTERM. The server process will inherit this
397 # block which will ensure it doesn't receive the SIGTERM until the
398 # handler is ready for it
399 mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
400 try:
401 self.process = multiprocessing.Process(target=run, args=(queue,))
402 self.process.start()
403
404 self.address = queue.get()
405 queue.close()
406 queue.join_thread()
407
408 return self.process
409 finally:
410 signal.pthread_sigmask(signal.SIG_SETMASK, mask)