diff options
| -rw-r--r-- | bitbake/lib/bb/asyncrpc/serv.py | 21 | ||||
| -rw-r--r-- | bitbake/lib/bb/cooker.py | 3 | ||||
| -rw-r--r-- | bitbake/lib/hashserv/tests.py | 54 |
3 files changed, 64 insertions, 14 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py index ef20cb71df..4084f300df 100644 --- a/bitbake/lib/bb/asyncrpc/serv.py +++ b/bitbake/lib/bb/asyncrpc/serv.py | |||
| @@ -9,6 +9,7 @@ import os | |||
| 9 | import signal | 9 | import signal |
| 10 | import socket | 10 | import socket |
| 11 | import sys | 11 | import sys |
| 12 | import multiprocessing | ||
| 12 | from . import chunkify, DEFAULT_MAX_CHUNK | 13 | from . import chunkify, DEFAULT_MAX_CHUNK |
| 13 | 14 | ||
| 14 | 15 | ||
| @@ -201,12 +202,14 @@ class AsyncServer(object): | |||
| 201 | pass | 202 | pass |
| 202 | 203 | ||
| 203 | def signal_handler(self): | 204 | def signal_handler(self): |
| 205 | self.logger.debug("Got exit signal") | ||
| 204 | self.loop.stop() | 206 | self.loop.stop() |
| 205 | 207 | ||
| 206 | def serve_forever(self): | 208 | def serve_forever(self): |
| 207 | asyncio.set_event_loop(self.loop) | 209 | asyncio.set_event_loop(self.loop) |
| 208 | try: | 210 | try: |
| 209 | self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) | 211 | self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) |
| 212 | signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) | ||
| 210 | 213 | ||
| 211 | self.run_loop_forever() | 214 | self.run_loop_forever() |
| 212 | self.server.close() | 215 | self.server.close() |
| @@ -221,3 +224,21 @@ class AsyncServer(object): | |||
| 221 | 224 | ||
| 222 | if self._cleanup_socket is not None: | 225 | if self._cleanup_socket is not None: |
| 223 | self._cleanup_socket() | 226 | self._cleanup_socket() |
| 227 | |||
| 228 | def serve_as_process(self, *, prefunc=None, args=()): | ||
| 229 | def run(): | ||
| 230 | if prefunc is not None: | ||
| 231 | prefunc(self, *args) | ||
| 232 | self.serve_forever() | ||
| 233 | |||
| 234 | # Temporarily block SIGTERM. The server process will inherit this | ||
| 235 | # block which will ensure it doesn't receive the SIGTERM until the | ||
| 236 | # handler is ready for it | ||
| 237 | mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM]) | ||
| 238 | try: | ||
| 239 | self.process = multiprocessing.Process(target=run) | ||
| 240 | self.process.start() | ||
| 241 | |||
| 242 | return self.process | ||
| 243 | finally: | ||
| 244 | signal.pthread_sigmask(signal.SIG_SETMASK, mask) | ||
diff --git a/bitbake/lib/bb/cooker.py b/bitbake/lib/bb/cooker.py index 39e10e6133..b2d69c28cf 100644 --- a/bitbake/lib/bb/cooker.py +++ b/bitbake/lib/bb/cooker.py | |||
| @@ -390,8 +390,7 @@ class BBCooker: | |||
| 390 | dbfile = (self.data.getVar("PERSISTENT_DIR") or self.data.getVar("CACHE")) + "/hashserv.db" | 390 | dbfile = (self.data.getVar("PERSISTENT_DIR") or self.data.getVar("CACHE")) + "/hashserv.db" |
| 391 | self.hashservaddr = "unix://%s/hashserve.sock" % self.data.getVar("TOPDIR") | 391 | self.hashservaddr = "unix://%s/hashserve.sock" % self.data.getVar("TOPDIR") |
| 392 | self.hashserv = hashserv.create_server(self.hashservaddr, dbfile, sync=False) | 392 | self.hashserv = hashserv.create_server(self.hashservaddr, dbfile, sync=False) |
| 393 | self.hashserv.process = multiprocessing.Process(target=self.hashserv.serve_forever) | 393 | self.hashserv.serve_as_process() |
| 394 | self.hashserv.process.start() | ||
| 395 | self.data.setVar("BB_HASHSERVE", self.hashservaddr) | 394 | self.data.setVar("BB_HASHSERVE", self.hashservaddr) |
| 396 | self.databuilder.origdata.setVar("BB_HASHSERVE", self.hashservaddr) | 395 | self.databuilder.origdata.setVar("BB_HASHSERVE", self.hashservaddr) |
| 397 | self.databuilder.data.setVar("BB_HASHSERVE", self.hashservaddr) | 396 | self.databuilder.data.setVar("BB_HASHSERVE", self.hashservaddr) |
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py index e2b762dbf0..e851535c59 100644 --- a/bitbake/lib/hashserv/tests.py +++ b/bitbake/lib/hashserv/tests.py | |||
| @@ -15,28 +15,32 @@ import tempfile | |||
| 15 | import threading | 15 | import threading |
| 16 | import unittest | 16 | import unittest |
| 17 | import socket | 17 | import socket |
| 18 | import time | ||
| 19 | import signal | ||
| 18 | 20 | ||
| 19 | def _run_server(server, idx): | 21 | def server_prefunc(server, idx): |
| 20 | # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', | 22 | logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', |
| 21 | # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') | 23 | format='%(levelname)s %(filename)s:%(lineno)d %(message)s') |
| 24 | server.logger.debug("Running server %d" % idx) | ||
| 22 | sys.stdout = open('bbhashserv-%d.log' % idx, 'w') | 25 | sys.stdout = open('bbhashserv-%d.log' % idx, 'w') |
| 23 | sys.stderr = sys.stdout | 26 | sys.stderr = sys.stdout |
| 24 | server.serve_forever() | ||
| 25 | |||
| 26 | 27 | ||
| 27 | class HashEquivalenceTestSetup(object): | 28 | class HashEquivalenceTestSetup(object): |
| 28 | METHOD = 'TestMethod' | 29 | METHOD = 'TestMethod' |
| 29 | 30 | ||
| 30 | server_index = 0 | 31 | server_index = 0 |
| 31 | 32 | ||
| 32 | def start_server(self, dbpath=None, upstream=None, read_only=False): | 33 | def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc): |
| 33 | self.server_index += 1 | 34 | self.server_index += 1 |
| 34 | if dbpath is None: | 35 | if dbpath is None: |
| 35 | dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) | 36 | dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) |
| 36 | 37 | ||
| 37 | def cleanup_thread(thread): | 38 | def cleanup_server(server): |
| 38 | thread.terminate() | 39 | if server.process.exitcode is not None: |
| 39 | thread.join() | 40 | return |
| 41 | |||
| 42 | server.process.terminate() | ||
| 43 | server.process.join() | ||
| 40 | 44 | ||
| 41 | server = create_server(self.get_server_addr(self.server_index), | 45 | server = create_server(self.get_server_addr(self.server_index), |
| 42 | dbpath, | 46 | dbpath, |
| @@ -44,9 +48,8 @@ class HashEquivalenceTestSetup(object): | |||
| 44 | read_only=read_only) | 48 | read_only=read_only) |
| 45 | server.dbpath = dbpath | 49 | server.dbpath = dbpath |
| 46 | 50 | ||
| 47 | server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) | 51 | server.serve_as_process(prefunc=prefunc, args=(self.server_index,)) |
| 48 | server.thread.start() | 52 | self.addCleanup(cleanup_server, server) |
| 49 | self.addCleanup(cleanup_thread, server.thread) | ||
| 50 | 53 | ||
| 51 | def cleanup_client(client): | 54 | def cleanup_client(client): |
| 52 | client.close() | 55 | client.close() |
| @@ -283,6 +286,33 @@ class HashEquivalenceCommonTests(object): | |||
| 283 | self.assertClientGetHash(self.client, taskhash2, None) | 286 | self.assertClientGetHash(self.client, taskhash2, None) |
| 284 | 287 | ||
| 285 | 288 | ||
| 289 | def test_slow_server_start(self): | ||
| 290 | """ | ||
| 291 | Ensures that the server will exit correctly even if it gets a SIGTERM | ||
| 292 | before entering the main loop | ||
| 293 | """ | ||
| 294 | |||
| 295 | event = multiprocessing.Event() | ||
| 296 | |||
| 297 | def prefunc(server, idx): | ||
| 298 | nonlocal event | ||
| 299 | server_prefunc(server, idx) | ||
| 300 | event.wait() | ||
| 301 | |||
| 302 | def do_nothing(signum, frame): | ||
| 303 | pass | ||
| 304 | |||
| 305 | old_signal = signal.signal(signal.SIGTERM, do_nothing) | ||
| 306 | self.addCleanup(signal.signal, signal.SIGTERM, old_signal) | ||
| 307 | |||
| 308 | _, server = self.start_server(prefunc=prefunc) | ||
| 309 | server.process.terminate() | ||
| 310 | time.sleep(30) | ||
| 311 | event.set() | ||
| 312 | server.process.join(300) | ||
| 313 | self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!") | ||
| 314 | |||
| 315 | |||
| 286 | class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): | 316 | class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): |
| 287 | def get_server_addr(self, server_idx): | 317 | def get_server_addr(self, server_idx): |
| 288 | return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx) | 318 | return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx) |
