summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc')
-rw-r--r--bitbake/lib/bb/asyncrpc/__init__.py16
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py266
-rw-r--r--bitbake/lib/bb/asyncrpc/connection.py146
-rw-r--r--bitbake/lib/bb/asyncrpc/exceptions.py21
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py410
5 files changed, 859 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/__init__.py b/bitbake/lib/bb/asyncrpc/__init__.py
new file mode 100644
index 0000000000..a4371643d7
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/__init__.py
@@ -0,0 +1,16 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7
8from .client import AsyncClient, Client
9from .serv import AsyncServer, AsyncServerConnection
10from .connection import DEFAULT_MAX_CHUNK
11from .exceptions import (
12 ClientError,
13 ServerError,
14 ConnectionClosedError,
15 InvokeError,
16)
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
new file mode 100644
index 0000000000..9be49261c0
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -0,0 +1,266 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import socket
12import sys
13import re
14import contextlib
15from threading import Thread
16from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
17from .exceptions import ConnectionClosedError, InvokeError
18
19UNIX_PREFIX = "unix://"
20WS_PREFIX = "ws://"
21WSS_PREFIX = "wss://"
22
23ADDR_TYPE_UNIX = 0
24ADDR_TYPE_TCP = 1
25ADDR_TYPE_WS = 2
26
27WEBSOCKETS_MIN_VERSION = (9, 1)
28# Need websockets 10 with python 3.10+
29if sys.version_info >= (3, 10, 0):
30 WEBSOCKETS_MIN_VERSION = (10, 0)
31
32
33def parse_address(addr):
34 if addr.startswith(UNIX_PREFIX):
35 return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
36 elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
37 return (ADDR_TYPE_WS, (addr,))
38 else:
39 m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
40 if m is not None:
41 host = m.group("host")
42 port = m.group("port")
43 else:
44 host, port = addr.split(":")
45
46 return (ADDR_TYPE_TCP, (host, int(port)))
47
48
49class AsyncClient(object):
50 def __init__(
51 self,
52 proto_name,
53 proto_version,
54 logger,
55 timeout=30,
56 server_headers=False,
57 headers={},
58 ):
59 self.socket = None
60 self.max_chunk = DEFAULT_MAX_CHUNK
61 self.proto_name = proto_name
62 self.proto_version = proto_version
63 self.logger = logger
64 self.timeout = timeout
65 self.needs_server_headers = server_headers
66 self.server_headers = {}
67 self.headers = headers
68
69 async def connect_tcp(self, address, port):
70 async def connect_sock():
71 reader, writer = await asyncio.open_connection(address, port)
72 return StreamConnection(reader, writer, self.timeout, self.max_chunk)
73
74 self._connect_sock = connect_sock
75
76 async def connect_unix(self, path):
77 async def connect_sock():
78 # AF_UNIX has path length issues so chdir here to workaround
79 cwd = os.getcwd()
80 try:
81 os.chdir(os.path.dirname(path))
82 # The socket must be opened synchronously so that CWD doesn't get
83 # changed out from underneath us so we pass as a sock into asyncio
84 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
85 sock.connect(os.path.basename(path))
86 finally:
87 os.chdir(cwd)
88 reader, writer = await asyncio.open_unix_connection(sock=sock)
89 return StreamConnection(reader, writer, self.timeout, self.max_chunk)
90
91 self._connect_sock = connect_sock
92
93 async def connect_websocket(self, uri):
94 import websockets
95
96 try:
97 version = tuple(
98 int(v)
99 for v in websockets.__version__.split(".")[
100 0 : len(WEBSOCKETS_MIN_VERSION)
101 ]
102 )
103 except ValueError:
104 raise ImportError(
105 f"Unable to parse websockets version '{websockets.__version__}'"
106 )
107
108 if version < WEBSOCKETS_MIN_VERSION:
109 min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
110 raise ImportError(
111 f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
112 )
113
114 async def connect_sock():
115 websocket = await websockets.connect(
116 uri,
117 ping_interval=None,
118 open_timeout=self.timeout,
119 )
120 return WebsocketConnection(websocket, self.timeout)
121
122 self._connect_sock = connect_sock
123
124 async def setup_connection(self):
125 # Send headers
126 await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
127 await self.socket.send(
128 "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
129 )
130 for k, v in self.headers.items():
131 await self.socket.send("%s: %s" % (k, v))
132
133 # End of headers
134 await self.socket.send("")
135
136 self.server_headers = {}
137 if self.needs_server_headers:
138 while True:
139 line = await self.socket.recv()
140 if not line:
141 # End headers
142 break
143 tag, value = line.split(":", 1)
144 self.server_headers[tag.lower()] = value.strip()
145
146 async def get_header(self, tag, default):
147 await self.connect()
148 return self.server_headers.get(tag, default)
149
150 async def connect(self):
151 if self.socket is None:
152 self.socket = await self._connect_sock()
153 await self.setup_connection()
154
155 async def disconnect(self):
156 if self.socket is not None:
157 await self.socket.close()
158 self.socket = None
159
160 async def close(self):
161 await self.disconnect()
162
163 async def _send_wrapper(self, proc):
164 count = 0
165 while True:
166 try:
167 await self.connect()
168 return await proc()
169 except (
170 OSError,
171 ConnectionError,
172 ConnectionClosedError,
173 json.JSONDecodeError,
174 UnicodeDecodeError,
175 ) as e:
176 self.logger.warning("Error talking to server: %s" % e)
177 if count >= 3:
178 if not isinstance(e, ConnectionError):
179 raise ConnectionError(str(e))
180 raise e
181 await self.close()
182 count += 1
183
184 def check_invoke_error(self, msg):
185 if isinstance(msg, dict) and "invoke-error" in msg:
186 raise InvokeError(msg["invoke-error"]["message"])
187
188 async def invoke(self, msg):
189 async def proc():
190 await self.socket.send_message(msg)
191 return await self.socket.recv_message()
192
193 result = await self._send_wrapper(proc)
194 self.check_invoke_error(result)
195 return result
196
197 async def ping(self):
198 return await self.invoke({"ping": {}})
199
200 async def __aenter__(self):
201 return self
202
203 async def __aexit__(self, exc_type, exc_value, traceback):
204 await self.close()
205
206
207class Client(object):
208 def __init__(self):
209 self.client = self._get_async_client()
210 self.loop = asyncio.new_event_loop()
211
212 # Override any pre-existing loop.
213 # Without this, the PR server export selftest triggers a hang
214 # when running with Python 3.7. The drawback is that there is
215 # potential for issues if the PR and hash equiv (or some new)
216 # clients need to both be instantiated in the same process.
217 # This should be revisited if/when Python 3.9 becomes the
218 # minimum required version for BitBake, as it seems not
219 # required (but harmless) with it.
220 asyncio.set_event_loop(self.loop)
221
222 self._add_methods("connect_tcp", "ping")
223
224 @abc.abstractmethod
225 def _get_async_client(self):
226 pass
227
228 def _get_downcall_wrapper(self, downcall):
229 def wrapper(*args, **kwargs):
230 return self.loop.run_until_complete(downcall(*args, **kwargs))
231
232 return wrapper
233
234 def _add_methods(self, *methods):
235 for m in methods:
236 downcall = getattr(self.client, m)
237 setattr(self, m, self._get_downcall_wrapper(downcall))
238
239 def connect_unix(self, path):
240 self.loop.run_until_complete(self.client.connect_unix(path))
241 self.loop.run_until_complete(self.client.connect())
242
243 @property
244 def max_chunk(self):
245 return self.client.max_chunk
246
247 @max_chunk.setter
248 def max_chunk(self, value):
249 self.client.max_chunk = value
250
251 def disconnect(self):
252 self.loop.run_until_complete(self.client.close())
253
254 def close(self):
255 if self.loop:
256 self.loop.run_until_complete(self.client.close())
257 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
258 self.loop.close()
259 self.loop = None
260
261 def __enter__(self):
262 return self
263
264 def __exit__(self, exc_type, exc_value, traceback):
265 self.close()
266 return False
diff --git a/bitbake/lib/bb/asyncrpc/connection.py b/bitbake/lib/bb/asyncrpc/connection.py
new file mode 100644
index 0000000000..7f0cf6ba96
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/connection.py
@@ -0,0 +1,146 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import asyncio
8import itertools
9import json
10from datetime import datetime
11from .exceptions import ClientError, ConnectionClosedError
12
13
14# The Python async server defaults to a 64K receive buffer, so we hardcode our
15# maximum chunk size. It would be better if the client and server reported to
16# each other what the maximum chunk sizes were, but that will slow down the
17# connection setup with a round trip delay so I'd rather not do that unless it
18# is necessary
19DEFAULT_MAX_CHUNK = 32 * 1024
20
21
22def chunkify(msg, max_chunk):
23 if len(msg) < max_chunk - 1:
24 yield "".join((msg, "\n"))
25 else:
26 yield "".join((json.dumps({"chunk-stream": None}), "\n"))
27
28 args = [iter(msg)] * (max_chunk - 1)
29 for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
30 yield "".join(itertools.chain(m, "\n"))
31 yield "\n"
32
33
34def json_serialize(obj):
35 if isinstance(obj, datetime):
36 return obj.isoformat()
37 raise TypeError("Type %s not serializeable" % type(obj))
38
39
40class StreamConnection(object):
41 def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
42 self.reader = reader
43 self.writer = writer
44 self.timeout = timeout
45 self.max_chunk = max_chunk
46
47 @property
48 def address(self):
49 return self.writer.get_extra_info("peername")
50
51 async def send_message(self, msg):
52 for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
53 self.writer.write(c.encode("utf-8"))
54 await self.writer.drain()
55
56 async def recv_message(self):
57 l = await self.recv()
58
59 m = json.loads(l)
60 if not m:
61 return m
62
63 if "chunk-stream" in m:
64 lines = []
65 while True:
66 l = await self.recv()
67 if not l:
68 break
69 lines.append(l)
70
71 m = json.loads("".join(lines))
72
73 return m
74
75 async def send(self, msg):
76 self.writer.write(("%s\n" % msg).encode("utf-8"))
77 await self.writer.drain()
78
79 async def recv(self):
80 if self.timeout < 0:
81 line = await self.reader.readline()
82 else:
83 try:
84 line = await asyncio.wait_for(self.reader.readline(), self.timeout)
85 except asyncio.TimeoutError:
86 raise ConnectionError("Timed out waiting for data")
87
88 if not line:
89 raise ConnectionClosedError("Connection closed")
90
91 line = line.decode("utf-8")
92
93 if not line.endswith("\n"):
94 raise ConnectionError("Bad message %r" % (line))
95
96 return line.rstrip()
97
98 async def close(self):
99 self.reader = None
100 if self.writer is not None:
101 self.writer.close()
102 self.writer = None
103
104
105class WebsocketConnection(object):
106 def __init__(self, socket, timeout):
107 self.socket = socket
108 self.timeout = timeout
109
110 @property
111 def address(self):
112 return ":".join(str(s) for s in self.socket.remote_address)
113
114 async def send_message(self, msg):
115 await self.send(json.dumps(msg, default=json_serialize))
116
117 async def recv_message(self):
118 m = await self.recv()
119 return json.loads(m)
120
121 async def send(self, msg):
122 import websockets.exceptions
123
124 try:
125 await self.socket.send(msg)
126 except websockets.exceptions.ConnectionClosed:
127 raise ConnectionClosedError("Connection closed")
128
129 async def recv(self):
130 import websockets.exceptions
131
132 try:
133 if self.timeout < 0:
134 return await self.socket.recv()
135
136 try:
137 return await asyncio.wait_for(self.socket.recv(), self.timeout)
138 except asyncio.TimeoutError:
139 raise ConnectionError("Timed out waiting for data")
140 except websockets.exceptions.ConnectionClosed:
141 raise ConnectionClosedError("Connection closed")
142
143 async def close(self):
144 if self.socket is not None:
145 await self.socket.close()
146 self.socket = None
diff --git a/bitbake/lib/bb/asyncrpc/exceptions.py b/bitbake/lib/bb/asyncrpc/exceptions.py
new file mode 100644
index 0000000000..ae1043a38b
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/exceptions.py
@@ -0,0 +1,21 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7
8class ClientError(Exception):
9 pass
10
11
12class InvokeError(Exception):
13 pass
14
15
16class ServerError(Exception):
17 pass
18
19
20class ConnectionClosedError(Exception):
21 pass
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
new file mode 100644
index 0000000000..667217c5c1
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -0,0 +1,410 @@
1#
2# Copyright BitBake Contributors
3#
4# SPDX-License-Identifier: GPL-2.0-only
5#
6
7import abc
8import asyncio
9import json
10import os
11import signal
12import socket
13import sys
14import multiprocessing
15import logging
16from .connection import StreamConnection, WebsocketConnection
17from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
18
19
20class ClientLoggerAdapter(logging.LoggerAdapter):
21 def process(self, msg, kwargs):
22 return f"[Client {self.extra['address']}] {msg}", kwargs
23
24
25class AsyncServerConnection(object):
26 # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
27 # return message will be automatically be sent back to the client
28 NO_RESPONSE = object()
29
30 def __init__(self, socket, proto_name, logger):
31 self.socket = socket
32 self.proto_name = proto_name
33 self.handlers = {
34 "ping": self.handle_ping,
35 }
36 self.logger = ClientLoggerAdapter(
37 logger,
38 {
39 "address": socket.address,
40 },
41 )
42 self.client_headers = {}
43
44 async def close(self):
45 await self.socket.close()
46
47 async def handle_headers(self, headers):
48 return {}
49
50 async def process_requests(self):
51 try:
52 self.logger.info("Client %r connected" % (self.socket.address,))
53
54 # Read protocol and version
55 client_protocol = await self.socket.recv()
56 if not client_protocol:
57 return
58
59 (client_proto_name, client_proto_version) = client_protocol.split()
60 if client_proto_name != self.proto_name:
61 self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
62 return
63
64 self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
65 if not self.validate_proto_version():
66 self.logger.debug(
67 "Rejecting invalid protocol version %s" % (client_proto_version)
68 )
69 return
70
71 # Read headers
72 self.client_headers = {}
73 while True:
74 header = await self.socket.recv()
75 if not header:
76 # Empty line. End of headers
77 break
78 tag, value = header.split(":", 1)
79 self.client_headers[tag.lower()] = value.strip()
80
81 if self.client_headers.get("needs-headers", "false") == "true":
82 for k, v in (await self.handle_headers(self.client_headers)).items():
83 await self.socket.send("%s: %s" % (k, v))
84 await self.socket.send("")
85
86 # Handle messages
87 while True:
88 d = await self.socket.recv_message()
89 if d is None:
90 break
91 try:
92 response = await self.dispatch_message(d)
93 except InvokeError as e:
94 await self.socket.send_message(
95 {"invoke-error": {"message": str(e)}}
96 )
97 break
98
99 if response is not self.NO_RESPONSE:
100 await self.socket.send_message(response)
101
102 except ConnectionClosedError as e:
103 self.logger.info(str(e))
104 except (ClientError, ConnectionError) as e:
105 self.logger.error(str(e))
106 finally:
107 await self.close()
108
109 async def dispatch_message(self, msg):
110 for k in self.handlers.keys():
111 if k in msg:
112 self.logger.debug("Handling %s" % k)
113 return await self.handlers[k](msg[k])
114
115 raise ClientError("Unrecognized command %r" % msg)
116
117 async def handle_ping(self, request):
118 return {"alive": True}
119
120
121class StreamServer(object):
122 def __init__(self, handler, logger):
123 self.handler = handler
124 self.logger = logger
125 self.closed = False
126
127 async def handle_stream_client(self, reader, writer):
128 # writer.transport.set_write_buffer_limits(0)
129 socket = StreamConnection(reader, writer, -1)
130 if self.closed:
131 await socket.close()
132 return
133
134 await self.handler(socket)
135
136 async def stop(self):
137 self.closed = True
138
139
140class TCPStreamServer(StreamServer):
141 def __init__(self, host, port, handler, logger, *, reuseport=False):
142 super().__init__(handler, logger)
143 self.host = host
144 self.port = port
145 self.reuseport = reuseport
146
147 def start(self, loop):
148 self.server = loop.run_until_complete(
149 asyncio.start_server(
150 self.handle_stream_client,
151 self.host,
152 self.port,
153 reuse_port=self.reuseport,
154 )
155 )
156
157 for s in self.server.sockets:
158 self.logger.debug("Listening on %r" % (s.getsockname(),))
159 # Newer python does this automatically. Do it manually here for
160 # maximum compatibility
161 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
162 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
163
164 # Enable keep alives. This prevents broken client connections
165 # from persisting on the server for long periods of time.
166 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
167 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
168 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
169 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
170
171 name = self.server.sockets[0].getsockname()
172 if self.server.sockets[0].family == socket.AF_INET6:
173 self.address = "[%s]:%d" % (name[0], name[1])
174 else:
175 self.address = "%s:%d" % (name[0], name[1])
176
177 return [self.server.wait_closed()]
178
179 async def stop(self):
180 await super().stop()
181 self.server.close()
182
183 def cleanup(self):
184 pass
185
186
187class UnixStreamServer(StreamServer):
188 def __init__(self, path, handler, logger):
189 super().__init__(handler, logger)
190 self.path = path
191
192 def start(self, loop):
193 cwd = os.getcwd()
194 try:
195 # Work around path length limits in AF_UNIX
196 os.chdir(os.path.dirname(self.path))
197 self.server = loop.run_until_complete(
198 asyncio.start_unix_server(
199 self.handle_stream_client, os.path.basename(self.path)
200 )
201 )
202 finally:
203 os.chdir(cwd)
204
205 self.logger.debug("Listening on %r" % self.path)
206 self.address = "unix://%s" % os.path.abspath(self.path)
207 return [self.server.wait_closed()]
208
209 async def stop(self):
210 await super().stop()
211 self.server.close()
212
213 def cleanup(self):
214 os.unlink(self.path)
215
216
217class WebsocketsServer(object):
218 def __init__(self, host, port, handler, logger, *, reuseport=False):
219 self.host = host
220 self.port = port
221 self.handler = handler
222 self.logger = logger
223 self.reuseport = reuseport
224
225 def start(self, loop):
226 import websockets.server
227
228 self.server = loop.run_until_complete(
229 websockets.server.serve(
230 self.client_handler,
231 self.host,
232 self.port,
233 ping_interval=None,
234 reuse_port=self.reuseport,
235 )
236 )
237
238 for s in self.server.sockets:
239 self.logger.debug("Listening on %r" % (s.getsockname(),))
240
241 # Enable keep alives. This prevents broken client connections
242 # from persisting on the server for long periods of time.
243 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
244 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
245 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
246 s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
247
248 name = self.server.sockets[0].getsockname()
249 if self.server.sockets[0].family == socket.AF_INET6:
250 self.address = "ws://[%s]:%d" % (name[0], name[1])
251 else:
252 self.address = "ws://%s:%d" % (name[0], name[1])
253
254 return [self.server.wait_closed()]
255
256 async def stop(self):
257 self.server.close()
258
259 def cleanup(self):
260 pass
261
262 async def client_handler(self, websocket):
263 socket = WebsocketConnection(websocket, -1)
264 await self.handler(socket)
265
266
267class AsyncServer(object):
268 def __init__(self, logger):
269 self.logger = logger
270 self.loop = None
271 self.run_tasks = []
272
273 def start_tcp_server(self, host, port, *, reuseport=False):
274 self.server = TCPStreamServer(
275 host,
276 port,
277 self._client_handler,
278 self.logger,
279 reuseport=reuseport,
280 )
281
282 def start_unix_server(self, path):
283 self.server = UnixStreamServer(path, self._client_handler, self.logger)
284
285 def start_websocket_server(self, host, port, reuseport=False):
286 self.server = WebsocketsServer(
287 host,
288 port,
289 self._client_handler,
290 self.logger,
291 reuseport=reuseport,
292 )
293
294 async def _client_handler(self, socket):
295 address = socket.address
296 try:
297 client = self.accept_client(socket)
298 await client.process_requests()
299 except Exception as e:
300 import traceback
301
302 self.logger.error(
303 "Error from client %s: %s" % (address, str(e)), exc_info=True
304 )
305 traceback.print_exc()
306 finally:
307 self.logger.debug("Client %s disconnected", address)
308 await socket.close()
309
310 @abc.abstractmethod
311 def accept_client(self, socket):
312 pass
313
314 async def stop(self):
315 self.logger.debug("Stopping server")
316 await self.server.stop()
317
318 def start(self):
319 tasks = self.server.start(self.loop)
320 self.address = self.server.address
321 return tasks
322
323 def signal_handler(self):
324 self.logger.debug("Got exit signal")
325 self.loop.create_task(self.stop())
326
327 def _serve_forever(self, tasks):
328 try:
329 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
330 self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
331 self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
332 signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
333
334 self.loop.run_until_complete(asyncio.gather(*tasks))
335
336 self.logger.debug("Server shutting down")
337 finally:
338 self.server.cleanup()
339
340 def serve_forever(self):
341 """
342 Serve requests in the current process
343 """
344 self._create_loop()
345 tasks = self.start()
346 self._serve_forever(tasks)
347 self.loop.close()
348
349 def _create_loop(self):
350 # Create loop and override any loop that may have existed in
351 # a parent process. It is possible that the usecases of
352 # serve_forever might be constrained enough to allow using
353 # get_event_loop here, but better safe than sorry for now.
354 self.loop = asyncio.new_event_loop()
355 asyncio.set_event_loop(self.loop)
356
357 def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
358 """
359 Serve requests in a child process
360 """
361
362 def run(queue):
363 # Create loop and override any loop that may have existed
364 # in a parent process. Without doing this and instead
365 # using get_event_loop, at the very minimum the hashserv
366 # unit tests will hang when running the second test.
367 # This happens since get_event_loop in the spawned server
368 # process for the second testcase ends up with the loop
369 # from the hashserv client created in the unit test process
370 # when running the first testcase. The problem is somewhat
371 # more general, though, as any potential use of asyncio in
372 # Cooker could create a loop that needs to replaced in this
373 # new process.
374 self._create_loop()
375 try:
376 self.address = None
377 tasks = self.start()
378 finally:
379 # Always put the server address to wake up the parent task
380 queue.put(self.address)
381 queue.close()
382
383 if prefunc is not None:
384 prefunc(self, *args)
385
386 if log_level is not None:
387 self.logger.setLevel(log_level)
388
389 self._serve_forever(tasks)
390
391 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
392 self.loop.close()
393
394 queue = multiprocessing.Queue()
395
396 # Temporarily block SIGTERM. The server process will inherit this
397 # block which will ensure it doesn't receive the SIGTERM until the
398 # handler is ready for it
399 mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
400 try:
401 self.process = multiprocessing.Process(target=run, args=(queue,))
402 self.process.start()
403
404 self.address = queue.get()
405 queue.close()
406 queue.join_thread()
407
408 return self.process
409 finally:
410 signal.pthread_sigmask(signal.SIG_SETMASK, mask)