summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJoshua Watt <jpewhacker@gmail.com>2021-07-22 11:19:37 -0500
committerRichard Purdie <richard.purdie@linuxfoundation.org>2021-07-29 23:21:24 +0100
commita83ea01b99d799025caede57512648d3258579d0 (patch)
tree72b5d1858931bc1e5f7244c9dcff36493650e8a7
parent445c5b9324a6243a0a3941c5f00e369d7749efde (diff)
downloadpoky-a83ea01b99d799025caede57512648d3258579d0.tar.gz
bitbake: bitbake: asyncrpc: Catch early SIGTERM
If the SIGTERM signal is sent to an asyncrpc server before it has installed the SIGTERM handler in the main loop, it may miss the signal, which can cause the calling process to wait forever on the join(). To resolve this, the calling process should mask off SIGTERM before forking the server process, and the server should unmask the signal only after the handler is installed. To simplify the usage of the server, a new helper function called serve_as_process() is added to do this automatically and correctly. Thanks: Scott Murray <scott.murray@konsulko.com> for helping debug (Bitbake rev: ef2865efa98ad20823267364f2159d8d8c931400) Signed-off-by: Joshua Watt <JPEWhacker@gmail.com> Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py21
-rw-r--r--bitbake/lib/bb/cooker.py3
-rw-r--r--bitbake/lib/hashserv/tests.py54
3 files changed, 64 insertions, 14 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
index ef20cb71df..4084f300df 100644
--- a/bitbake/lib/bb/asyncrpc/serv.py
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -9,6 +9,7 @@ import os
9import signal 9import signal
10import socket 10import socket
11import sys 11import sys
12import multiprocessing
12from . import chunkify, DEFAULT_MAX_CHUNK 13from . import chunkify, DEFAULT_MAX_CHUNK
13 14
14 15
@@ -201,12 +202,14 @@ class AsyncServer(object):
201 pass 202 pass
202 203
203 def signal_handler(self): 204 def signal_handler(self):
205 self.logger.debug("Got exit signal")
204 self.loop.stop() 206 self.loop.stop()
205 207
206 def serve_forever(self): 208 def serve_forever(self):
207 asyncio.set_event_loop(self.loop) 209 asyncio.set_event_loop(self.loop)
208 try: 210 try:
209 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) 211 self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
212 signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
210 213
211 self.run_loop_forever() 214 self.run_loop_forever()
212 self.server.close() 215 self.server.close()
@@ -221,3 +224,21 @@ class AsyncServer(object):
221 224
222 if self._cleanup_socket is not None: 225 if self._cleanup_socket is not None:
223 self._cleanup_socket() 226 self._cleanup_socket()
227
228 def serve_as_process(self, *, prefunc=None, args=()):
229 def run():
230 if prefunc is not None:
231 prefunc(self, *args)
232 self.serve_forever()
233
234 # Temporarily block SIGTERM. The server process will inherit this
235 # block which will ensure it doesn't receive the SIGTERM until the
236 # handler is ready for it
237 mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
238 try:
239 self.process = multiprocessing.Process(target=run)
240 self.process.start()
241
242 return self.process
243 finally:
244 signal.pthread_sigmask(signal.SIG_SETMASK, mask)
diff --git a/bitbake/lib/bb/cooker.py b/bitbake/lib/bb/cooker.py
index 39e10e6133..b2d69c28cf 100644
--- a/bitbake/lib/bb/cooker.py
+++ b/bitbake/lib/bb/cooker.py
@@ -390,8 +390,7 @@ class BBCooker:
390 dbfile = (self.data.getVar("PERSISTENT_DIR") or self.data.getVar("CACHE")) + "/hashserv.db" 390 dbfile = (self.data.getVar("PERSISTENT_DIR") or self.data.getVar("CACHE")) + "/hashserv.db"
391 self.hashservaddr = "unix://%s/hashserve.sock" % self.data.getVar("TOPDIR") 391 self.hashservaddr = "unix://%s/hashserve.sock" % self.data.getVar("TOPDIR")
392 self.hashserv = hashserv.create_server(self.hashservaddr, dbfile, sync=False) 392 self.hashserv = hashserv.create_server(self.hashservaddr, dbfile, sync=False)
393 self.hashserv.process = multiprocessing.Process(target=self.hashserv.serve_forever) 393 self.hashserv.serve_as_process()
394 self.hashserv.process.start()
395 self.data.setVar("BB_HASHSERVE", self.hashservaddr) 394 self.data.setVar("BB_HASHSERVE", self.hashservaddr)
396 self.databuilder.origdata.setVar("BB_HASHSERVE", self.hashservaddr) 395 self.databuilder.origdata.setVar("BB_HASHSERVE", self.hashservaddr)
397 self.databuilder.data.setVar("BB_HASHSERVE", self.hashservaddr) 396 self.databuilder.data.setVar("BB_HASHSERVE", self.hashservaddr)
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py
index e2b762dbf0..e851535c59 100644
--- a/bitbake/lib/hashserv/tests.py
+++ b/bitbake/lib/hashserv/tests.py
@@ -15,28 +15,32 @@ import tempfile
15import threading 15import threading
16import unittest 16import unittest
17import socket 17import socket
18import time
19import signal
18 20
19def _run_server(server, idx): 21def server_prefunc(server, idx):
20 # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w', 22 logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
21 # format='%(levelname)s %(filename)s:%(lineno)d %(message)s') 23 format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
24 server.logger.debug("Running server %d" % idx)
22 sys.stdout = open('bbhashserv-%d.log' % idx, 'w') 25 sys.stdout = open('bbhashserv-%d.log' % idx, 'w')
23 sys.stderr = sys.stdout 26 sys.stderr = sys.stdout
24 server.serve_forever()
25
26 27
27class HashEquivalenceTestSetup(object): 28class HashEquivalenceTestSetup(object):
28 METHOD = 'TestMethod' 29 METHOD = 'TestMethod'
29 30
30 server_index = 0 31 server_index = 0
31 32
32 def start_server(self, dbpath=None, upstream=None, read_only=False): 33 def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
33 self.server_index += 1 34 self.server_index += 1
34 if dbpath is None: 35 if dbpath is None:
35 dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) 36 dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
36 37
37 def cleanup_thread(thread): 38 def cleanup_server(server):
38 thread.terminate() 39 if server.process.exitcode is not None:
39 thread.join() 40 return
41
42 server.process.terminate()
43 server.process.join()
40 44
41 server = create_server(self.get_server_addr(self.server_index), 45 server = create_server(self.get_server_addr(self.server_index),
42 dbpath, 46 dbpath,
@@ -44,9 +48,8 @@ class HashEquivalenceTestSetup(object):
44 read_only=read_only) 48 read_only=read_only)
45 server.dbpath = dbpath 49 server.dbpath = dbpath
46 50
47 server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) 51 server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
48 server.thread.start() 52 self.addCleanup(cleanup_server, server)
49 self.addCleanup(cleanup_thread, server.thread)
50 53
51 def cleanup_client(client): 54 def cleanup_client(client):
52 client.close() 55 client.close()
@@ -283,6 +286,33 @@ class HashEquivalenceCommonTests(object):
283 self.assertClientGetHash(self.client, taskhash2, None) 286 self.assertClientGetHash(self.client, taskhash2, None)
284 287
285 288
289 def test_slow_server_start(self):
290 """
291 Ensures that the server will exit correctly even if it gets a SIGTERM
292 before entering the main loop
293 """
294
295 event = multiprocessing.Event()
296
297 def prefunc(server, idx):
298 nonlocal event
299 server_prefunc(server, idx)
300 event.wait()
301
302 def do_nothing(signum, frame):
303 pass
304
305 old_signal = signal.signal(signal.SIGTERM, do_nothing)
306 self.addCleanup(signal.signal, signal.SIGTERM, old_signal)
307
308 _, server = self.start_server(prefunc=prefunc)
309 server.process.terminate()
310 time.sleep(30)
311 event.set()
312 server.process.join(300)
313 self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!")
314
315
286class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): 316class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
287 def get_server_addr(self, server_idx): 317 def get_server_addr(self, server_idx):
288 return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx) 318 return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)