summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/hashserv
diff options
context:
space:
mode:
authorJoshua Watt <jpewhacker@gmail.com>2019-09-17 08:37:11 -0500
committerRichard Purdie <richard.purdie@linuxfoundation.org>2019-09-18 17:52:01 +0100
commit20f032338ff3b4b25f2cbb7f975b5fd1c105004d (patch)
tree84c1a4693fdbaac0823e99d70ed7c814b890244c /bitbake/lib/hashserv
parent34923e4f772fc57c29421741d2f622eb4009961c (diff)
downloadpoky-20f032338ff3b4b25f2cbb7f975b5fd1c105004d.tar.gz
bitbake: bitbake: Rework hash equivalence
Reworks the hash equivalence server to address performance issues that were encountered with the REST mechanism used previously, particularly during the heavy request load encountered during signature generation. Notable changes are: 1) The server protocol is no longer HTTP based. Instead, it uses a simpler JSON over a streaming protocol link. This protocol has much lower overhead than HTTP since it eliminates the HTTP headers. 2) The hash equivalence server can either bind to a TCP port, or a Unix domain socket. Unix domain sockets are more efficient for local communication, and so are preferred if the user enables hash equivalence only for the local build. The arguments to the 'bitbake-hashserve' command have been updated accordingly. 3) The value to which BB_HASHSERVE should be set to enable a local hash equivalence server is changed to "auto" instead of "localhost:0". The latter didn't make sense when the local server was using a Unix domain socket. 4) Clients are expected to keep a persistent connection to the server instead of creating a new connection each time a request is made for optimal performance. 5) Most of the client logic has been moved to the hashserve module in bitbake. This makes it easier to share the client code. 6) A new bitbake command has been added called 'bitbake-hashclient'. This command can be used to query a hash equivalence server, including fetching the statistics and running a performance stress test. 7) The table indexes in the SQLite database have been updated to optimize hash lookups. This change is backward compatible, as the database will delete the old indexes first if they exist. 8) The server has been reworked to use python async to maximize performance with persistently connected clients. This requires Python 3.5 or later. (Bitbake rev: 2124eec3a5830afe8e07ffb6f2a0df6a417ac973) Signed-off-by: Joshua Watt <JPEWhacker@gmail.com> Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
Diffstat (limited to 'bitbake/lib/hashserv')
-rw-r--r--bitbake/lib/hashserv/__init__.py261
-rw-r--r--bitbake/lib/hashserv/client.py156
-rw-r--r--bitbake/lib/hashserv/server.py414
-rw-r--r--bitbake/lib/hashserv/tests.py159
4 files changed, 717 insertions, 273 deletions
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py
index eb03c32213..c3318620f5 100644
--- a/bitbake/lib/hashserv/__init__.py
+++ b/bitbake/lib/hashserv/__init__.py
@@ -3,203 +3,21 @@
3# SPDX-License-Identifier: GPL-2.0-only 3# SPDX-License-Identifier: GPL-2.0-only
4# 4#
5 5
6from http.server import BaseHTTPRequestHandler, HTTPServer 6from contextlib import closing
7import contextlib 7import re
8import urllib.parse
9import sqlite3 8import sqlite3
10import json
11import traceback
12import logging
13import socketserver
14import queue
15import threading
16import signal
17import socket
18import struct
19from datetime import datetime
20
21logger = logging.getLogger('hashserv')
22
23class HashEquivalenceServer(BaseHTTPRequestHandler):
24 def log_message(self, f, *args):
25 logger.debug(f, *args)
26
27 def opendb(self):
28 self.db = sqlite3.connect(self.dbname)
29 self.db.row_factory = sqlite3.Row
30 self.db.execute("PRAGMA synchronous = OFF;")
31 self.db.execute("PRAGMA journal_mode = MEMORY;")
32
33 def do_GET(self):
34 try:
35 if not self.db:
36 self.opendb()
37
38 p = urllib.parse.urlparse(self.path)
39
40 if p.path != self.prefix + '/v1/equivalent':
41 self.send_error(404)
42 return
43
44 query = urllib.parse.parse_qs(p.query, strict_parsing=True)
45 method = query['method'][0]
46 taskhash = query['taskhash'][0]
47
48 d = None
49 with contextlib.closing(self.db.cursor()) as cursor:
50 cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
51 {'method': method, 'taskhash': taskhash})
52
53 row = cursor.fetchone()
54
55 if row is not None:
56 logger.debug('Found equivalent task %s', row['taskhash'])
57 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
58
59 self.send_response(200)
60 self.send_header('Content-Type', 'application/json; charset=utf-8')
61 self.end_headers()
62 self.wfile.write(json.dumps(d).encode('utf-8'))
63 except:
64 logger.exception('Error in GET')
65 self.send_error(400, explain=traceback.format_exc())
66 return
67
68 def do_POST(self):
69 try:
70 if not self.db:
71 self.opendb()
72
73 p = urllib.parse.urlparse(self.path)
74
75 if p.path != self.prefix + '/v1/equivalent':
76 self.send_error(404)
77 return
78
79 length = int(self.headers['content-length'])
80 data = json.loads(self.rfile.read(length).decode('utf-8'))
81
82 with contextlib.closing(self.db.cursor()) as cursor:
83 cursor.execute('''
84 -- Find tasks with a matching outhash (that is, tasks that
85 -- are equivalent)
86 SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash
87
88 -- If there is an exact match on the taskhash, return it.
89 -- Otherwise return the oldest matching outhash of any
90 -- taskhash
91 ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
92 created ASC
93
94 -- Only return one row
95 LIMIT 1
96 ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})
97
98 row = cursor.fetchone()
99
100 # If no matching outhash was found, or one *was* found but it
101 # wasn't an exact match on the taskhash, a new entry for this
102 # taskhash should be added
103 if row is None or row['taskhash'] != data['taskhash']:
104 # If a row matching the outhash was found, the unihash for
105 # the new taskhash should be the same as that one.
106 # Otherwise the caller provided unihash is used.
107 unihash = data['unihash']
108 if row is not None:
109 unihash = row['unihash']
110
111 insert_data = {
112 'method': data['method'],
113 'outhash': data['outhash'],
114 'taskhash': data['taskhash'],
115 'unihash': unihash,
116 'created': datetime.now()
117 }
118
119 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
120 if k in data:
121 insert_data[k] = data[k]
122
123 cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
124 ', '.join(sorted(insert_data.keys())),
125 ', '.join(':' + k for k in sorted(insert_data.keys()))),
126 insert_data)
127
128 logger.info('Adding taskhash %s with unihash %s', data['taskhash'], unihash)
129
130 self.db.commit()
131 d = {'taskhash': data['taskhash'], 'method': data['method'], 'unihash': unihash}
132 else:
133 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
134
135 self.send_response(200)
136 self.send_header('Content-Type', 'application/json; charset=utf-8')
137 self.end_headers()
138 self.wfile.write(json.dumps(d).encode('utf-8'))
139 except:
140 logger.exception('Error in POST')
141 self.send_error(400, explain=traceback.format_exc())
142 return
143
144class ThreadedHTTPServer(HTTPServer):
145 quit = False
146
147 def serve_forever(self):
148 self.requestqueue = queue.Queue()
149 self.handlerthread = threading.Thread(target=self.process_request_thread)
150 self.handlerthread.daemon = False
151
152 self.handlerthread.start()
153
154 signal.signal(signal.SIGTERM, self.sigterm_exception)
155 super().serve_forever()
156 os._exit(0)
157
158 def sigterm_exception(self, signum, stackframe):
159 self.server_close()
160 os._exit(0)
161
162 def server_bind(self):
163 HTTPServer.server_bind(self)
164 self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
165
166 def process_request_thread(self):
167 while not self.quit:
168 try:
169 (request, client_address) = self.requestqueue.get(True)
170 except queue.Empty:
171 continue
172 if request is None:
173 continue
174 try:
175 self.finish_request(request, client_address)
176 except Exception:
177 self.handle_error(request, client_address)
178 finally:
179 self.shutdown_request(request)
180 os._exit(0)
181
182 def process_request(self, request, client_address):
183 self.requestqueue.put((request, client_address))
184
185 def server_close(self):
186 super().server_close()
187 self.quit = True
188 self.requestqueue.put((None, None))
189 self.handlerthread.join()
190
191def create_server(addr, dbname, prefix=''):
192 class Handler(HashEquivalenceServer):
193 pass
194
195 db = sqlite3.connect(dbname)
196 db.row_factory = sqlite3.Row
197 9
198 Handler.prefix = prefix 10UNIX_PREFIX = "unix://"
199 Handler.db = None 11
200 Handler.dbname = dbname 12ADDR_TYPE_UNIX = 0
13ADDR_TYPE_TCP = 1
14
15
16def setup_database(database, sync=True):
17 db = sqlite3.connect(database)
18 db.row_factory = sqlite3.Row
201 19
202 with contextlib.closing(db.cursor()) as cursor: 20 with closing(db.cursor()) as cursor:
203 cursor.execute(''' 21 cursor.execute('''
204 CREATE TABLE IF NOT EXISTS tasks_v2 ( 22 CREATE TABLE IF NOT EXISTS tasks_v2 (
205 id INTEGER PRIMARY KEY AUTOINCREMENT, 23 id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -220,11 +38,56 @@ def create_server(addr, dbname, prefix=''):
220 UNIQUE(method, outhash, taskhash) 38 UNIQUE(method, outhash, taskhash)
221 ) 39 )
222 ''') 40 ''')
223 cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup ON tasks_v2 (method, taskhash)') 41 cursor.execute('PRAGMA journal_mode = WAL')
224 cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup ON tasks_v2 (method, outhash)') 42 cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
43
44 # Drop old indexes
45 cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
46 cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
47
48 # Create new indexes
49 cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)')
50 cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)')
51
52 return db
53
54
55def parse_address(addr):
56 if addr.startswith(UNIX_PREFIX):
57 return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
58 else:
59 m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
60 if m is not None:
61 host = m.group('host')
62 port = m.group('port')
63 else:
64 host, port = addr.split(':')
65
66 return (ADDR_TYPE_TCP, (host, int(port)))
67
68
69def create_server(addr, dbname, *, sync=True):
70 from . import server
71 db = setup_database(dbname, sync=sync)
72 s = server.Server(db)
73
74 (typ, a) = parse_address(addr)
75 if typ == ADDR_TYPE_UNIX:
76 s.start_unix_server(*a)
77 else:
78 s.start_tcp_server(*a)
79
80 return s
81
225 82
226 ret = ThreadedHTTPServer(addr, Handler) 83def create_client(addr):
84 from . import client
85 c = client.Client()
227 86
228 logger.info('Starting server on %s\n', ret.server_port) 87 (typ, a) = parse_address(addr)
88 if typ == ADDR_TYPE_UNIX:
89 c.connect_unix(*a)
90 else:
91 c.connect_tcp(*a)
229 92
230 return ret 93 return c
diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
new file mode 100644
index 0000000000..2559bbb3fb
--- /dev/null
+++ b/bitbake/lib/hashserv/client.py
@@ -0,0 +1,156 @@
1# Copyright (C) 2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6from contextlib import closing
7import json
8import logging
9import socket
10
11
12logger = logging.getLogger('hashserv.client')
13
14
15class HashConnectionError(Exception):
16 pass
17
18
19class Client(object):
20 MODE_NORMAL = 0
21 MODE_GET_STREAM = 1
22
23 def __init__(self):
24 self._socket = None
25 self.reader = None
26 self.writer = None
27 self.mode = self.MODE_NORMAL
28
29 def connect_tcp(self, address, port):
30 def connect_sock():
31 s = socket.create_connection((address, port))
32
33 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
34 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
35 s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
36 return s
37
38 self._connect_sock = connect_sock
39
40 def connect_unix(self, path):
41 def connect_sock():
42 s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
43 # AF_UNIX has path length issues so chdir here to workaround
44 cwd = os.getcwd()
45 try:
46 os.chdir(os.path.dirname(path))
47 s.connect(os.path.basename(path))
48 finally:
49 os.chdir(cwd)
50 return s
51
52 self._connect_sock = connect_sock
53
54 def connect(self):
55 if self._socket is None:
56 self._socket = self._connect_sock()
57
58 self.reader = self._socket.makefile('r', encoding='utf-8')
59 self.writer = self._socket.makefile('w', encoding='utf-8')
60
61 self.writer.write('OEHASHEQUIV 1.0\n\n')
62 self.writer.flush()
63
64 # Restore mode if the socket is being re-created
65 cur_mode = self.mode
66 self.mode = self.MODE_NORMAL
67 self._set_mode(cur_mode)
68
69 return self._socket
70
71 def close(self):
72 if self._socket is not None:
73 self._socket.close()
74 self._socket = None
75 self.reader = None
76 self.writer = None
77
78 def _send_wrapper(self, proc):
79 count = 0
80 while True:
81 try:
82 self.connect()
83 return proc()
84 except (OSError, HashConnectionError, json.JSONDecodeError, UnicodeDecodeError) as e:
85 logger.warning('Error talking to server: %s' % e)
86 if count >= 3:
87 if not isinstance(e, HashConnectionError):
88 raise HashConnectionError(str(e))
89 raise e
90 self.close()
91 count += 1
92
93 def send_message(self, msg):
94 def proc():
95 self.writer.write('%s\n' % json.dumps(msg))
96 self.writer.flush()
97
98 l = self.reader.readline()
99 if not l:
100 raise HashConnectionError('Connection closed')
101
102 if not l.endswith('\n'):
103 raise HashConnectionError('Bad message %r' % message)
104
105 return json.loads(l)
106
107 return self._send_wrapper(proc)
108
109 def send_stream(self, msg):
110 def proc():
111 self.writer.write("%s\n" % msg)
112 self.writer.flush()
113 l = self.reader.readline()
114 if not l:
115 raise HashConnectionError('Connection closed')
116 return l.rstrip()
117
118 return self._send_wrapper(proc)
119
120 def _set_mode(self, new_mode):
121 if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
122 r = self.send_stream('END')
123 if r != 'ok':
124 raise HashConnectionError('Bad response from server %r' % r)
125 elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
126 r = self.send_message({'get-stream': None})
127 if r != 'ok':
128 raise HashConnectionError('Bad response from server %r' % r)
129 elif new_mode != self.mode:
130 raise Exception('Undefined mode transition %r -> %r' % (self.mode, new_mode))
131
132 self.mode = new_mode
133
134 def get_unihash(self, method, taskhash):
135 self._set_mode(self.MODE_GET_STREAM)
136 r = self.send_stream('%s %s' % (method, taskhash))
137 if not r:
138 return None
139 return r
140
141 def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
142 self._set_mode(self.MODE_NORMAL)
143 m = extra.copy()
144 m['taskhash'] = taskhash
145 m['method'] = method
146 m['outhash'] = outhash
147 m['unihash'] = unihash
148 return self.send_message({'report': m})
149
150 def get_stats(self):
151 self._set_mode(self.MODE_NORMAL)
152 return self.send_message({'get-stats': None})
153
154 def reset_stats(self):
155 self._set_mode(self.MODE_NORMAL)
156 return self.send_message({'reset-stats': None})
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
new file mode 100644
index 0000000000..0aff77688e
--- /dev/null
+++ b/bitbake/lib/hashserv/server.py
@@ -0,0 +1,414 @@
1# Copyright (C) 2019 Garmin Ltd.
2#
3# SPDX-License-Identifier: GPL-2.0-only
4#
5
6from contextlib import closing
7from datetime import datetime
8import asyncio
9import json
10import logging
11import math
12import os
13import signal
14import socket
15import time
16
17logger = logging.getLogger('hashserv.server')
18
19
20class Measurement(object):
21 def __init__(self, sample):
22 self.sample = sample
23
24 def start(self):
25 self.start_time = time.perf_counter()
26
27 def end(self):
28 self.sample.add(time.perf_counter() - self.start_time)
29
30 def __enter__(self):
31 self.start()
32 return self
33
34 def __exit__(self, *args, **kwargs):
35 self.end()
36
37
38class Sample(object):
39 def __init__(self, stats):
40 self.stats = stats
41 self.num_samples = 0
42 self.elapsed = 0
43
44 def measure(self):
45 return Measurement(self)
46
47 def __enter__(self):
48 return self
49
50 def __exit__(self, *args, **kwargs):
51 self.end()
52
53 def add(self, elapsed):
54 self.num_samples += 1
55 self.elapsed += elapsed
56
57 def end(self):
58 if self.num_samples:
59 self.stats.add(self.elapsed)
60 self.num_samples = 0
61 self.elapsed = 0
62
63
64class Stats(object):
65 def __init__(self):
66 self.reset()
67
68 def reset(self):
69 self.num = 0
70 self.total_time = 0
71 self.max_time = 0
72 self.m = 0
73 self.s = 0
74 self.current_elapsed = None
75
76 def add(self, elapsed):
77 self.num += 1
78 if self.num == 1:
79 self.m = elapsed
80 self.s = 0
81 else:
82 last_m = self.m
83 self.m = last_m + (elapsed - last_m) / self.num
84 self.s = self.s + (elapsed - last_m) * (elapsed - self.m)
85
86 self.total_time += elapsed
87
88 if self.max_time < elapsed:
89 self.max_time = elapsed
90
91 def start_sample(self):
92 return Sample(self)
93
94 @property
95 def average(self):
96 if self.num == 0:
97 return 0
98 return self.total_time / self.num
99
100 @property
101 def stdev(self):
102 if self.num <= 1:
103 return 0
104 return math.sqrt(self.s / (self.num - 1))
105
106 def todict(self):
107 return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
108
109
110class ServerClient(object):
111 def __init__(self, reader, writer, db, request_stats):
112 self.reader = reader
113 self.writer = writer
114 self.db = db
115 self.request_stats = request_stats
116
117 async def process_requests(self):
118 try:
119 self.addr = self.writer.get_extra_info('peername')
120 logger.debug('Client %r connected' % (self.addr,))
121
122 # Read protocol and version
123 protocol = await self.reader.readline()
124 if protocol is None:
125 return
126
127 (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
128 if proto_name != 'OEHASHEQUIV' or proto_version != '1.0':
129 return
130
131 # Read headers. Currently, no headers are implemented, so look for
132 # an empty line to signal the end of the headers
133 while True:
134 line = await self.reader.readline()
135 if line is None:
136 return
137
138 line = line.decode('utf-8').rstrip()
139 if not line:
140 break
141
142 # Handle messages
143 handlers = {
144 'get': self.handle_get,
145 'report': self.handle_report,
146 'get-stream': self.handle_get_stream,
147 'get-stats': self.handle_get_stats,
148 'reset-stats': self.handle_reset_stats,
149 }
150
151 while True:
152 d = await self.read_message()
153 if d is None:
154 break
155
156 for k in handlers.keys():
157 if k in d:
158 logger.debug('Handling %s' % k)
159 if 'stream' in k:
160 await handlers[k](d[k])
161 else:
162 with self.request_stats.start_sample() as self.request_sample, \
163 self.request_sample.measure():
164 await handlers[k](d[k])
165 break
166 else:
167 logger.warning("Unrecognized command %r" % d)
168 break
169
170 await self.writer.drain()
171 finally:
172 self.writer.close()
173
174 def write_message(self, msg):
175 self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))
176
177 async def read_message(self):
178 l = await self.reader.readline()
179 if not l:
180 return None
181
182 try:
183 message = l.decode('utf-8')
184
185 if not message.endswith('\n'):
186 return None
187
188 return json.loads(message)
189 except (json.JSONDecodeError, UnicodeDecodeError) as e:
190 logger.error('Bad message from client: %r' % message)
191 raise e
192
193 async def handle_get(self, request):
194 method = request['method']
195 taskhash = request['taskhash']
196
197 row = self.query_equivalent(method, taskhash)
198 if row is not None:
199 logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
200 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
201
202 self.write_message(d)
203 else:
204 self.write_message(None)
205
206 async def handle_get_stream(self, request):
207 self.write_message('ok')
208
209 while True:
210 l = await self.reader.readline()
211 if not l:
212 return
213
214 try:
215 # This inner loop is very sensitive and must be as fast as
216 # possible (which is why the request sample is handled manually
217 # instead of using 'with', and also why logging statements are
218 # commented out.
219 self.request_sample = self.request_stats.start_sample()
220 request_measure = self.request_sample.measure()
221 request_measure.start()
222
223 l = l.decode('utf-8').rstrip()
224 if l == 'END':
225 self.writer.write('ok\n'.encode('utf-8'))
226 return
227
228 (method, taskhash) = l.split()
229 #logger.debug('Looking up %s %s' % (method, taskhash))
230 row = self.query_equivalent(method, taskhash)
231 if row is not None:
232 msg = ('%s\n' % row['unihash']).encode('utf-8')
233 #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
234 else:
235 msg = '\n'.encode('utf-8')
236
237 self.writer.write(msg)
238 finally:
239 request_measure.end()
240 self.request_sample.end()
241
242 await self.writer.drain()
243
244 async def handle_report(self, data):
245 with closing(self.db.cursor()) as cursor:
246 cursor.execute('''
247 -- Find tasks with a matching outhash (that is, tasks that
248 -- are equivalent)
249 SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash
250
251 -- If there is an exact match on the taskhash, return it.
252 -- Otherwise return the oldest matching outhash of any
253 -- taskhash
254 ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
255 created ASC
256
257 -- Only return one row
258 LIMIT 1
259 ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})
260
261 row = cursor.fetchone()
262
263 # If no matching outhash was found, or one *was* found but it
264 # wasn't an exact match on the taskhash, a new entry for this
265 # taskhash should be added
266 if row is None or row['taskhash'] != data['taskhash']:
267 # If a row matching the outhash was found, the unihash for
268 # the new taskhash should be the same as that one.
269 # Otherwise the caller provided unihash is used.
270 unihash = data['unihash']
271 if row is not None:
272 unihash = row['unihash']
273
274 insert_data = {
275 'method': data['method'],
276 'outhash': data['outhash'],
277 'taskhash': data['taskhash'],
278 'unihash': unihash,
279 'created': datetime.now()
280 }
281
282 for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
283 if k in data:
284 insert_data[k] = data[k]
285
286 cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
287 ', '.join(sorted(insert_data.keys())),
288 ', '.join(':' + k for k in sorted(insert_data.keys()))),
289 insert_data)
290
291 self.db.commit()
292
293 logger.info('Adding taskhash %s with unihash %s',
294 data['taskhash'], unihash)
295
296 d = {
297 'taskhash': data['taskhash'],
298 'method': data['method'],
299 'unihash': unihash
300 }
301 else:
302 d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
303
304 self.write_message(d)
305
306 async def handle_get_stats(self, request):
307 d = {
308 'requests': self.request_stats.todict(),
309 }
310
311 self.write_message(d)
312
313 async def handle_reset_stats(self, request):
314 d = {
315 'requests': self.request_stats.todict(),
316 }
317
318 self.request_stats.reset()
319 self.write_message(d)
320
321 def query_equivalent(self, method, taskhash):
322 # This is part of the inner loop and must be as fast as possible
323 try:
324 cursor = self.db.cursor()
325 cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
326 {'method': method, 'taskhash': taskhash})
327 return cursor.fetchone()
328 except:
329 cursor.close()
330
331
332class Server(object):
333 def __init__(self, db, loop=None):
334 self.request_stats = Stats()
335 self.db = db
336
337 if loop is None:
338 self.loop = asyncio.new_event_loop()
339 self.close_loop = True
340 else:
341 self.loop = loop
342 self.close_loop = False
343
344 self._cleanup_socket = None
345
346 def start_tcp_server(self, host, port):
347 self.server = self.loop.run_until_complete(
348 asyncio.start_server(self.handle_client, host, port, loop=self.loop)
349 )
350
351 for s in self.server.sockets:
352 logger.info('Listening on %r' % (s.getsockname(),))
353 # Newer python does this automatically. Do it manually here for
354 # maximum compatibility
355 s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
356 s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
357
358 name = self.server.sockets[0].getsockname()
359 if self.server.sockets[0].family == socket.AF_INET6:
360 self.address = "[%s]:%d" % (name[0], name[1])
361 else:
362 self.address = "%s:%d" % (name[0], name[1])
363
364 def start_unix_server(self, path):
365 def cleanup():
366 os.unlink(path)
367
368 cwd = os.getcwd()
369 try:
370 # Work around path length limits in AF_UNIX
371 os.chdir(os.path.dirname(path))
372 self.server = self.loop.run_until_complete(
373 asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
374 )
375 finally:
376 os.chdir(cwd)
377
378 logger.info('Listening on %r' % path)
379
380 self._cleanup_socket = cleanup
381 self.address = "unix://%s" % os.path.abspath(path)
382
383 async def handle_client(self, reader, writer):
384 # writer.transport.set_write_buffer_limits(0)
385 try:
386 client = ServerClient(reader, writer, self.db, self.request_stats)
387 await client.process_requests()
388 except Exception as e:
389 import traceback
390 logger.error('Error from client: %s' % str(e), exc_info=True)
391 traceback.print_exc()
392 writer.close()
393 logger.info('Client disconnected')
394
395 def serve_forever(self):
396 def signal_handler():
397 self.loop.stop()
398
399 self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
400
401 try:
402 self.loop.run_forever()
403 except KeyboardInterrupt:
404 pass
405
406 self.server.close()
407 self.loop.run_until_complete(self.server.wait_closed())
408 logger.info('Server shutting down')
409
410 if self.close_loop:
411 self.loop.close()
412
413 if self._cleanup_socket is not None:
414 self._cleanup_socket()
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py
index 6845b53884..6584ff57b4 100644
--- a/bitbake/lib/hashserv/tests.py
+++ b/bitbake/lib/hashserv/tests.py
@@ -1,29 +1,40 @@
1#! /usr/bin/env python3 1#! /usr/bin/env python3
2# 2#
3# Copyright (C) 2018 Garmin Ltd. 3# Copyright (C) 2018-2019 Garmin Ltd.
4# 4#
5# SPDX-License-Identifier: GPL-2.0-only 5# SPDX-License-Identifier: GPL-2.0-only
6# 6#
7 7
8import unittest 8from . import create_server, create_client
9import multiprocessing
10import sqlite3
11import hashlib 9import hashlib
12import urllib.request 10import logging
13import json 11import multiprocessing
12import sys
14import tempfile 13import tempfile
15from . import create_server 14import threading
15import unittest
16
17
18class TestHashEquivalenceServer(object):
19 METHOD = 'TestMethod'
20
21 def _run_server(self):
22 # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
23 # format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
24 self.server.serve_forever()
16 25
17class TestHashEquivalenceServer(unittest.TestCase):
18 def setUp(self): 26 def setUp(self):
19 # Start a hash equivalence server in the background bound to 27 if sys.version_info < (3, 5, 0):
20 # an ephemeral port 28 self.skipTest('Python 3.5 or later required')
21 self.dbfile = tempfile.NamedTemporaryFile(prefix="bb-hashserv-db-") 29
22 self.server = create_server(('localhost', 0), self.dbfile.name) 30 self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv')
23 self.server_addr = 'http://localhost:%d' % self.server.socket.getsockname()[1] 31 self.dbfile = os.path.join(self.temp_dir.name, 'db.sqlite')
24 self.server_thread = multiprocessing.Process(target=self.server.serve_forever) 32
33 self.server = create_server(self.get_server_addr(), self.dbfile)
34 self.server_thread = multiprocessing.Process(target=self._run_server)
25 self.server_thread.daemon = True 35 self.server_thread.daemon = True
26 self.server_thread.start() 36 self.server_thread.start()
37 self.client = create_client(self.server.address)
27 38
28 def tearDown(self): 39 def tearDown(self):
29 # Shutdown server 40 # Shutdown server
@@ -31,19 +42,8 @@ class TestHashEquivalenceServer(unittest.TestCase):
31 if s is not None: 42 if s is not None:
32 self.server_thread.terminate() 43 self.server_thread.terminate()
33 self.server_thread.join() 44 self.server_thread.join()
34 45 self.client.close()
35 def send_get(self, path): 46 self.temp_dir.cleanup()
36 url = '%s/%s' % (self.server_addr, path)
37 request = urllib.request.Request(url)
38 response = urllib.request.urlopen(request)
39 return json.loads(response.read().decode('utf-8'))
40
41 def send_post(self, path, data):
42 headers = {'content-type': 'application/json'}
43 url = '%s/%s' % (self.server_addr, path)
44 request = urllib.request.Request(url, json.dumps(data).encode('utf-8'), headers)
45 response = urllib.request.urlopen(request)
46 return json.loads(response.read().decode('utf-8'))
47 47
48 def test_create_hash(self): 48 def test_create_hash(self):
49 # Simple test that hashes can be created 49 # Simple test that hashes can be created
@@ -51,16 +51,11 @@ class TestHashEquivalenceServer(unittest.TestCase):
51 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' 51 outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
52 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' 52 unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
53 53
54 d = self.send_get('v1/equivalent?method=TestMethod&taskhash=%s' % taskhash) 54 result = self.client.get_unihash(self.METHOD, taskhash)
55 self.assertIsNone(d, msg='Found unexpected task, %r' % d) 55 self.assertIsNone(result, msg='Found unexpected task, %r' % result)
56 56
57 d = self.send_post('v1/equivalent', { 57 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
58 'taskhash': taskhash, 58 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
59 'method': 'TestMethod',
60 'outhash': outhash,
61 'unihash': unihash,
62 })
63 self.assertEqual(d['unihash'], unihash, 'Server returned bad unihash')
64 59
65 def test_create_equivalent(self): 60 def test_create_equivalent(self):
66 # Tests that a second reported task with the same outhash will be 61 # Tests that a second reported task with the same outhash will be
@@ -68,25 +63,16 @@ class TestHashEquivalenceServer(unittest.TestCase):
68 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4' 63 taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
69 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8' 64 outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
70 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646' 65 unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
71 d = self.send_post('v1/equivalent', { 66
72 'taskhash': taskhash, 67 result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
73 'method': 'TestMethod', 68 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
74 'outhash': outhash,
75 'unihash': unihash,
76 })
77 self.assertEqual(d['unihash'], unihash, 'Server returned bad unihash')
78 69
79 # Report a different task with the same outhash. The returned unihash 70 # Report a different task with the same outhash. The returned unihash
80 # should match the first task 71 # should match the first task
81 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4' 72 taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
82 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b' 73 unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
83 d = self.send_post('v1/equivalent', { 74 result = self.client.report_unihash(taskhash2, self.METHOD, outhash, unihash2)
84 'taskhash': taskhash2, 75 self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
85 'method': 'TestMethod',
86 'outhash': outhash,
87 'unihash': unihash2,
88 })
89 self.assertEqual(d['unihash'], unihash, 'Server returned bad unihash')
90 76
91 def test_duplicate_taskhash(self): 77 def test_duplicate_taskhash(self):
92 # Tests that duplicate reports of the same taskhash with different 78 # Tests that duplicate reports of the same taskhash with different
@@ -95,38 +81,63 @@ class TestHashEquivalenceServer(unittest.TestCase):
95 taskhash = '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a' 81 taskhash = '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a'
96 outhash = 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e' 82 outhash = 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e'
97 unihash = '218e57509998197d570e2c98512d0105985dffc9' 83 unihash = '218e57509998197d570e2c98512d0105985dffc9'
98 d = self.send_post('v1/equivalent', { 84 self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
99 'taskhash': taskhash,
100 'method': 'TestMethod',
101 'outhash': outhash,
102 'unihash': unihash,
103 })
104 85
105 d = self.send_get('v1/equivalent?method=TestMethod&taskhash=%s' % taskhash) 86 result = self.client.get_unihash(self.METHOD, taskhash)
106 self.assertEqual(d['unihash'], unihash) 87 self.assertEqual(result, unihash)
107 88
108 outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d' 89 outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d'
109 unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c' 90 unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'
110 d = self.send_post('v1/equivalent', { 91 self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2)
111 'taskhash': taskhash,
112 'method': 'TestMethod',
113 'outhash': outhash2,
114 'unihash': unihash2
115 })
116 92
117 d = self.send_get('v1/equivalent?method=TestMethod&taskhash=%s' % taskhash) 93 result = self.client.get_unihash(self.METHOD, taskhash)
118 self.assertEqual(d['unihash'], unihash) 94 self.assertEqual(result, unihash)
119 95
120 outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4' 96 outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
121 unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603' 97 unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603'
122 d = self.send_post('v1/equivalent', { 98 self.client.report_unihash(taskhash, self.METHOD, outhash3, unihash3)
123 'taskhash': taskhash, 99
124 'method': 'TestMethod', 100 result = self.client.get_unihash(self.METHOD, taskhash)
125 'outhash': outhash3, 101 self.assertEqual(result, unihash)
126 'unihash': unihash3 102
127 }) 103 def test_stress(self):
104 def query_server(failures):
105 client = Client(self.server.address)
106 try:
107 for i in range(1000):
108 taskhash = hashlib.sha256()
109 taskhash.update(str(i).encode('utf-8'))
110 taskhash = taskhash.hexdigest()
111 result = client.get_unihash(self.METHOD, taskhash)
112 if result != taskhash:
113 failures.append("taskhash mismatch: %s != %s" % (result, taskhash))
114 finally:
115 client.close()
116
117 # Report hashes
118 for i in range(1000):
119 taskhash = hashlib.sha256()
120 taskhash.update(str(i).encode('utf-8'))
121 taskhash = taskhash.hexdigest()
122 self.client.report_unihash(taskhash, self.METHOD, taskhash, taskhash)
123
124 failures = []
125 threads = [threading.Thread(target=query_server, args=(failures,)) for t in range(100)]
126
127 for t in threads:
128 t.start()
129
130 for t in threads:
131 t.join()
132
133 self.assertFalse(failures)
134
128 135
129 d = self.send_get('v1/equivalent?method=TestMethod&taskhash=%s' % taskhash) 136class TestHashEquivalenceUnixServer(TestHashEquivalenceServer, unittest.TestCase):
130 self.assertEqual(d['unihash'], unihash) 137 def get_server_addr(self):
138 return "unix://" + os.path.join(self.temp_dir.name, 'sock')
131 139
132 140
141class TestHashEquivalenceTCPServer(TestHashEquivalenceServer, unittest.TestCase):
142 def get_server_addr(self):
143 return "localhost:0"