summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc/serv.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/serv.py')
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py118
1 files changed, 72 insertions, 46 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
index 4084f300df..45628698b6 100644
--- a/bitbake/lib/bb/asyncrpc/serv.py
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -131,53 +131,58 @@ class AsyncServerConnection(object):
131 131
132 132
133class AsyncServer(object): 133class AsyncServer(object):
134 def __init__(self, logger, loop=None): 134 def __init__(self, logger):
135 if loop is None:
136 self.loop = asyncio.new_event_loop()
137 self.close_loop = True
138 else:
139 self.loop = loop
140 self.close_loop = False
141
142 self._cleanup_socket = None 135 self._cleanup_socket = None
143 self.logger = logger 136 self.logger = logger
137 self.start = None
138 self.address = None
139
140 @property
141 def loop(self):
142 return asyncio.get_event_loop()
144 143
145 def start_tcp_server(self, host, port): 144 def start_tcp_server(self, host, port):
146 self.server = self.loop.run_until_complete( 145 def start_tcp():
147 asyncio.start_server(self.handle_client, host, port, loop=self.loop) 146 self.server = self.loop.run_until_complete(
148 ) 147 asyncio.start_server(self.handle_client, host, port)
149 148 )
150 for s in self.server.sockets: 149
151 self.logger.debug('Listening on %r' % (s.getsockname(),)) 150 for s in self.server.sockets:
152 # Newer python does this automatically. Do it manually here for 151 self.logger.debug('Listening on %r' % (s.getsockname(),))
153 # maximum compatibility 152 # Newer python does this automatically. Do it manually here for
154 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) 153 # maximum compatibility
155 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) 154 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
156 155 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
157 name = self.server.sockets[0].getsockname() 156
158 if self.server.sockets[0].family == socket.AF_INET6: 157 name = self.server.sockets[0].getsockname()
159 self.address = "[%s]:%d" % (name[0], name[1]) 158 if self.server.sockets[0].family == socket.AF_INET6:
160 else: 159 self.address = "[%s]:%d" % (name[0], name[1])
161 self.address = "%s:%d" % (name[0], name[1]) 160 else:
161 self.address = "%s:%d" % (name[0], name[1])
162
163 self.start = start_tcp
162 164
163 def start_unix_server(self, path): 165 def start_unix_server(self, path):
164 def cleanup(): 166 def cleanup():
165 os.unlink(path) 167 os.unlink(path)
166 168
167 cwd = os.getcwd() 169 def start_unix():
168 try: 170 cwd = os.getcwd()
169 # Work around path length limits in AF_UNIX 171 try:
170 os.chdir(os.path.dirname(path)) 172 # Work around path length limits in AF_UNIX
171 self.server = self.loop.run_until_complete( 173 os.chdir(os.path.dirname(path))
172 asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop) 174 self.server = self.loop.run_until_complete(
173 ) 175 asyncio.start_unix_server(self.handle_client, os.path.basename(path))
174 finally: 176 )
175 os.chdir(cwd) 177 finally:
178 os.chdir(cwd)
176 179
177 self.logger.debug('Listening on %r' % path) 180 self.logger.debug('Listening on %r' % path)
178 181
179 self._cleanup_socket = cleanup 182 self._cleanup_socket = cleanup
180 self.address = "unix://%s" % os.path.abspath(path) 183 self.address = "unix://%s" % os.path.abspath(path)
184
185 self.start = start_unix
181 186
182 @abc.abstractmethod 187 @abc.abstractmethod
183 def accept_client(self, reader, writer): 188 def accept_client(self, reader, writer):
@@ -205,8 +210,7 @@ class AsyncServer(object):
205 self.logger.debug("Got exit signal") 210 self.logger.debug("Got exit signal")
206 self.loop.stop() 211 self.loop.stop()
207 212
208 def serve_forever(self): 213 def _serve_forever(self):
209 asyncio.set_event_loop(self.loop)
210 try: 214 try:
211 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) 215 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
212 signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) 216 signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
@@ -217,28 +221,50 @@ class AsyncServer(object):
217 self.loop.run_until_complete(self.server.wait_closed()) 221 self.loop.run_until_complete(self.server.wait_closed())
218 self.logger.debug('Server shutting down') 222 self.logger.debug('Server shutting down')
219 finally: 223 finally:
220 if self.close_loop:
221 if sys.version_info >= (3, 6):
222 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
223 self.loop.close()
224
225 if self._cleanup_socket is not None: 224 if self._cleanup_socket is not None:
226 self._cleanup_socket() 225 self._cleanup_socket()
227 226
227 def serve_forever(self):
228 """
229 Serve requests in the current process
230 """
231 self.start()
232 self._serve_forever()
233
228 def serve_as_process(self, *, prefunc=None, args=()): 234 def serve_as_process(self, *, prefunc=None, args=()):
229 def run(): 235 """
236 Serve requests in a child process
237 """
238 def run(queue):
239 try:
240 self.start()
241 finally:
242 queue.put(self.address)
243 queue.close()
244
230 if prefunc is not None: 245 if prefunc is not None:
231 prefunc(self, *args) 246 prefunc(self, *args)
232 self.serve_forever() 247
248 self._serve_forever()
249
250 if sys.version_info >= (3, 6):
251 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
252 self.loop.close()
253
254 queue = multiprocessing.Queue()
233 255
234 # Temporarily block SIGTERM. The server process will inherit this 256 # Temporarily block SIGTERM. The server process will inherit this
235 # block which will ensure it doesn't receive the SIGTERM until the 257 # block which will ensure it doesn't receive the SIGTERM until the
236 # handler is ready for it 258 # handler is ready for it
237 mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM]) 259 mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
238 try: 260 try:
239 self.process = multiprocessing.Process(target=run) 261 self.process = multiprocessing.Process(target=run, args=(queue,))
240 self.process.start() 262 self.process.start()
241 263
264 self.address = queue.get()
265 queue.close()
266 queue.join_thread()
267
242 return self.process 268 return self.process
243 finally: 269 finally:
244 signal.pthread_sigmask(signal.SIG_SETMASK, mask) 270 signal.pthread_sigmask(signal.SIG_SETMASK, mask)