diff options
Diffstat (limited to 'bitbake/lib/hashserv/__init__.py')
-rw-r--r-- | bitbake/lib/hashserv/__init__.py | 261 |
1 files changed, 62 insertions, 199 deletions
diff --git a/bitbake/lib/hashserv/__init__.py b/bitbake/lib/hashserv/__init__.py index eb03c32213..c3318620f5 100644 --- a/bitbake/lib/hashserv/__init__.py +++ b/bitbake/lib/hashserv/__init__.py | |||
@@ -3,203 +3,21 @@ | |||
3 | # SPDX-License-Identifier: GPL-2.0-only | 3 | # SPDX-License-Identifier: GPL-2.0-only |
4 | # | 4 | # |
5 | 5 | ||
6 | from http.server import BaseHTTPRequestHandler, HTTPServer | 6 | from contextlib import closing |
7 | import contextlib | 7 | import re |
8 | import urllib.parse | ||
9 | import sqlite3 | 8 | import sqlite3 |
10 | import json | ||
11 | import traceback | ||
12 | import logging | ||
13 | import socketserver | ||
14 | import queue | ||
15 | import threading | ||
16 | import signal | ||
17 | import socket | ||
18 | import struct | ||
19 | from datetime import datetime | ||
20 | |||
logger = logging.getLogger('hashserv')


class HashEquivalenceServer(BaseHTTPRequestHandler):
    """HTTP request handler for the hash equivalence endpoint.

    Both GET and POST are served only at ``<prefix>/v1/equivalent``; any other
    path gets a 404. The handler reads ``self.prefix``, ``self.db`` and
    ``self.dbname``, none of which are defined in this class: they must be
    set on the class (or a subclass) before any request is handled.
    """

    def log_message(self, f, *args):
        """Send the base handler's log output to the module logger at DEBUG."""
        logger.debug(f, *args)

    def opendb(self):
        """Connect to the SQLite file ``self.dbname`` and store it in ``self.db``."""
        self.db = sqlite3.connect(self.dbname)
        self.db.row_factory = sqlite3.Row
        self.db.execute("PRAGMA synchronous = OFF;")
        self.db.execute("PRAGMA journal_mode = MEMORY;")

    def _send_json(self, payload):
        """Write a 200 response whose body is ``payload`` encoded as UTF-8 JSON."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.end_headers()
        self.wfile.write(json.dumps(payload).encode('utf-8'))

    def do_GET(self):
        """Look up the oldest row matching the ``method``/``taskhash`` query.

        Responds with a JSON object holding ``taskhash``, ``method`` and
        ``unihash``, or JSON ``null`` when no row matches. Any error while
        handling the request is logged and reported as a 400.
        """
        try:
            if not self.db:
                self.opendb()

            p = urllib.parse.urlparse(self.path)

            if p.path != self.prefix + '/v1/equivalent':
                self.send_error(404)
                return

            query = urllib.parse.parse_qs(p.query, strict_parsing=True)
            method = query['method'][0]
            taskhash = query['taskhash'][0]

            d = None
            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
                               {'method': method, 'taskhash': taskhash})

                row = cursor.fetchone()

                if row is not None:
                    logger.debug('Found equivalent task %s', row['taskhash'])
                    d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

            self._send_json(d)
        except Exception:
            # Exception, not a bare except: SystemExit and KeyboardInterrupt
            # must propagate instead of being turned into a 400 response.
            # NOTE(review): the 400 body carries the full traceback; kept to
            # preserve the existing response format.
            logger.exception('Error in GET')
            self.send_error(400, explain=traceback.format_exc())
            return

    def do_POST(self):
        """Report a task's output hash and reply with the unihash to use.

        The request body is a JSON object that must contain ``method``,
        ``outhash``, ``taskhash`` and ``unihash``; ``owner``, ``PN``, ``PV``,
        ``PR``, ``task`` and ``outhash_siginfo`` are stored when present.
        Responds with a JSON object holding ``taskhash``, ``method`` and
        ``unihash``. Any error while handling the request is logged and
        reported as a 400.
        """
        try:
            if not self.db:
                self.opendb()

            p = urllib.parse.urlparse(self.path)

            if p.path != self.prefix + '/v1/equivalent':
                self.send_error(404)
                return

            length = int(self.headers['content-length'])
            data = json.loads(self.rfile.read(length).decode('utf-8'))

            with contextlib.closing(self.db.cursor()) as cursor:
                cursor.execute('''
                    -- Find tasks with a matching outhash (that is, tasks that
                    -- are equivalent)
                    SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash

                    -- If there is an exact match on the taskhash, return it.
                    -- Otherwise return the oldest matching outhash of any
                    -- taskhash
                    ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                        created ASC

                    -- Only return one row
                    LIMIT 1
                    ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

                row = cursor.fetchone()

                # If no matching outhash was found, or one *was* found but it
                # wasn't an exact match on the taskhash, a new entry for this
                # taskhash should be added
                if row is None or row['taskhash'] != data['taskhash']:
                    # If a row matching the outhash was found, the unihash for
                    # the new taskhash should be the same as that one.
                    # Otherwise the caller provided unihash is used.
                    unihash = data['unihash']
                    if row is not None:
                        unihash = row['unihash']

                    insert_data = {
                        'method': data['method'],
                        'outhash': data['outhash'],
                        'taskhash': data['taskhash'],
                        'unihash': unihash,
                        'created': datetime.now()
                    }

                    for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                        if k in data:
                            insert_data[k] = data[k]

                    # Column names come only from the fixed key lists above,
                    # never from the request, so formatting them into the
                    # statement is safe; the values are bound as parameters.
                    columns = sorted(insert_data)
                    cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
                        ', '.join(columns),
                        ', '.join(':' + k for k in columns)),
                        insert_data)

                    logger.info('Adding taskhash %s with unihash %s', data['taskhash'], unihash)

                    self.db.commit()
                    d = {'taskhash': data['taskhash'], 'method': data['method'], 'unihash': unihash}
                else:
                    d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

            self._send_json(d)
        except Exception:
            # See do_GET: only ordinary errors become a 400 response.
            logger.exception('Error in POST')
            self.send_error(400, explain=traceback.format_exc())
            return
143 | |||
class ThreadedHTTPServer(HTTPServer):
    """HTTPServer that queues accepted requests for a single handler thread.

    ``serve_forever`` starts the handler thread and installs a SIGTERM
    handler. The process is ended with ``os._exit(0)`` once serving stops or
    the handler thread leaves its loop.
    """

    # Set by server_close() to make the handler thread leave its loop.
    quit = False

    @staticmethod
    def _exit_process():
        """End the process immediately with exit status 0."""
        # Imported here because the module's import block does not provide os.
        import os
        os._exit(0)

    def serve_forever(self):
        """Start the handler thread, install the SIGTERM handler and serve."""
        self.requestqueue = queue.Queue()
        self.handlerthread = threading.Thread(target=self.process_request_thread)
        self.handlerthread.daemon = False

        self.handlerthread.start()

        signal.signal(signal.SIGTERM, self.sigterm_exception)
        super().serve_forever()
        self._exit_process()

    def sigterm_exception(self, signum, stackframe):
        """SIGTERM handler: close the server, then end the process."""
        self.server_close()
        self._exit_process()

    def server_bind(self):
        """Bind as HTTPServer does, then enable SO_LINGER with a zero timeout."""
        HTTPServer.server_bind(self)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))

    def process_request_thread(self):
        """Handler-thread loop: serve queued requests until ``quit`` is set."""
        while not self.quit:
            try:
                (request, client_address) = self.requestqueue.get(True)
            except queue.Empty:
                continue
            # (None, None) is the wake-up item queued by server_close().
            if request is None:
                continue
            try:
                self.finish_request(request, client_address)
            except Exception:
                self.handle_error(request, client_address)
            finally:
                self.shutdown_request(request)
        self._exit_process()

    def process_request(self, request, client_address):
        """Queue an accepted request for the handler thread."""
        self.requestqueue.put((request, client_address))

    def server_close(self):
        """Close the listening socket and stop the handler thread, if started."""
        super().server_close()
        self.quit = True
        # The queue and thread exist only once serve_forever() has run. The
        # base class constructor calls server_close() when binding fails, and
        # that path must not fail on the missing attributes.
        handler = getattr(self, 'handlerthread', None)
        if handler is None:
            return
        self.requestqueue.put((None, None))
        handler.join()
190 | |||
191 | def create_server(addr, dbname, prefix=''): | ||
192 | class Handler(HashEquivalenceServer): | ||
193 | pass | ||
194 | |||
195 | db = sqlite3.connect(dbname) | ||
196 | db.row_factory = sqlite3.Row | ||
197 | 9 | ||
198 | Handler.prefix = prefix | 10 | UNIX_PREFIX = "unix://" |
199 | Handler.db = None | 11 | |
200 | Handler.dbname = dbname | 12 | ADDR_TYPE_UNIX = 0 |
13 | ADDR_TYPE_TCP = 1 | ||
14 | |||
15 | |||
16 | def setup_database(database, sync=True): | ||
17 | db = sqlite3.connect(database) | ||
18 | db.row_factory = sqlite3.Row | ||
201 | 19 | ||
202 | with contextlib.closing(db.cursor()) as cursor: | 20 | with closing(db.cursor()) as cursor: |
203 | cursor.execute(''' | 21 | cursor.execute(''' |
204 | CREATE TABLE IF NOT EXISTS tasks_v2 ( | 22 | CREATE TABLE IF NOT EXISTS tasks_v2 ( |
205 | id INTEGER PRIMARY KEY AUTOINCREMENT, | 23 | id INTEGER PRIMARY KEY AUTOINCREMENT, |
@@ -220,11 +38,56 @@ def create_server(addr, dbname, prefix=''): | |||
220 | UNIQUE(method, outhash, taskhash) | 38 | UNIQUE(method, outhash, taskhash) |
221 | ) | 39 | ) |
222 | ''') | 40 | ''') |
223 | cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup ON tasks_v2 (method, taskhash)') | 41 | cursor.execute('PRAGMA journal_mode = WAL') |
224 | cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup ON tasks_v2 (method, outhash)') | 42 | cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF')) |
43 | |||
44 | # Drop old indexes | ||
45 | cursor.execute('DROP INDEX IF EXISTS taskhash_lookup') | ||
46 | cursor.execute('DROP INDEX IF EXISTS outhash_lookup') | ||
47 | |||
48 | # Create new indexes | ||
49 | cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)') | ||
50 | cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)') | ||
51 | |||
52 | return db | ||
53 | |||
54 | |||
def parse_address(addr):
    """Split an address string into an address type and connection arguments.

    Accepted forms:

    * ``unix://PATH``  -> ``(ADDR_TYPE_UNIX, (PATH,))``
    * ``[HOST]:PORT``  -> ``(ADDR_TYPE_TCP, (HOST, PORT))``; the brackets
      allow HOST to contain colons
    * ``HOST:PORT``    -> ``(ADDR_TYPE_TCP, (HOST, PORT))``

    PORT is returned as an int.

    Raises:
        ValueError: if addr matches none of the forms above or PORT is not
            an integer.
    """
    if addr.startswith(UNIX_PREFIX):
        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))

    m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
    if m is not None:
        host = m.group('host')
        port = m.group('port')
    else:
        # Exactly one ':' is required here; anything else is reported with
        # the offending address instead of a bare tuple-unpacking error.
        host, sep, port = addr.partition(':')
        if not sep or ':' in port:
            raise ValueError(
                'Invalid address %r: expected "unix://PATH", "HOST:PORT" or "[HOST]:PORT"' % (addr,))

    try:
        port = int(port)
    except ValueError:
        raise ValueError('Invalid port %r in address %r' % (port, addr)) from None

    return (ADDR_TYPE_TCP, (host, port))
67 | |||
68 | |||
def create_server(addr, dbname, *, sync=True):
    """Build a server for the database ``dbname`` and start it on ``addr``.

    ``addr`` is interpreted by ``parse_address``; ``sync`` is passed through
    to ``setup_database``. Returns the started ``server.Server`` object.
    """
    from . import server

    database = setup_database(dbname, sync=sync)
    srv = server.Server(database)

    addr_type, addr_args = parse_address(addr)
    # Unix addresses start a unix-socket server; everything else uses TCP.
    start = srv.start_unix_server if addr_type == ADDR_TYPE_UNIX else srv.start_tcp_server
    start(*addr_args)

    return srv
81 | |||
225 | 82 | ||
226 | ret = ThreadedHTTPServer(addr, Handler) | 83 | def create_client(addr): |
84 | from . import client | ||
85 | c = client.Client() | ||
227 | 86 | ||
228 | logger.info('Starting server on %s\n', ret.server_port) | 87 | (typ, a) = parse_address(addr) |
88 | if typ == ADDR_TYPE_UNIX: | ||
89 | c.connect_unix(*a) | ||
90 | else: | ||
91 | c.connect_tcp(*a) | ||
229 | 92 | ||
230 | return ret | 93 | return c |