diff options
Diffstat (limited to 'bitbake/lib/hashserv/sqlalchemy.py')
-rw-r--r-- | bitbake/lib/hashserv/sqlalchemy.py | 111 |
1 files changed, 105 insertions, 6 deletions
diff --git a/bitbake/lib/hashserv/sqlalchemy.py b/bitbake/lib/hashserv/sqlalchemy.py index 3216621f9d..bfd8a8446e 100644 --- a/bitbake/lib/hashserv/sqlalchemy.py +++ b/bitbake/lib/hashserv/sqlalchemy.py | |||
@@ -7,6 +7,7 @@ | |||
7 | 7 | ||
8 | import logging | 8 | import logging |
9 | from datetime import datetime | 9 | from datetime import datetime |
10 | from . import User | ||
10 | 11 | ||
11 | from sqlalchemy.ext.asyncio import create_async_engine | 12 | from sqlalchemy.ext.asyncio import create_async_engine |
12 | from sqlalchemy.pool import NullPool | 13 | from sqlalchemy.pool import NullPool |
@@ -25,13 +26,12 @@ from sqlalchemy import ( | |||
25 | literal, | 26 | literal, |
26 | and_, | 27 | and_, |
27 | delete, | 28 | delete, |
29 | update, | ||
28 | ) | 30 | ) |
29 | import sqlalchemy.engine | 31 | import sqlalchemy.engine |
30 | from sqlalchemy.orm import declarative_base | 32 | from sqlalchemy.orm import declarative_base |
31 | from sqlalchemy.exc import IntegrityError | 33 | from sqlalchemy.exc import IntegrityError |
32 | 34 | ||
33 | logger = logging.getLogger("hashserv.sqlalchemy") | ||
34 | |||
35 | Base = declarative_base() | 35 | Base = declarative_base() |
36 | 36 | ||
37 | 37 | ||
@@ -68,9 +68,19 @@ class OuthashesV2(Base): | |||
68 | ) | 68 | ) |
69 | 69 | ||
70 | 70 | ||
71 | class Users(Base): | ||
72 | __tablename__ = "users" | ||
73 | id = Column(Integer, primary_key=True, autoincrement=True) | ||
74 | username = Column(Text, nullable=False) | ||
75 | token = Column(Text, nullable=False) | ||
76 | permissions = Column(Text) | ||
77 | |||
78 | __table_args__ = (UniqueConstraint("username"),) | ||
79 | |||
80 | |||
71 | class DatabaseEngine(object): | 81 | class DatabaseEngine(object): |
72 | def __init__(self, url, username=None, password=None): | 82 | def __init__(self, url, username=None, password=None): |
73 | self.logger = logger | 83 | self.logger = logging.getLogger("hashserv.sqlalchemy") |
74 | self.url = sqlalchemy.engine.make_url(url) | 84 | self.url = sqlalchemy.engine.make_url(url) |
75 | 85 | ||
76 | if username is not None: | 86 | if username is not None: |
@@ -85,7 +95,7 @@ class DatabaseEngine(object): | |||
85 | 95 | ||
86 | async with self.engine.begin() as conn: | 96 | async with self.engine.begin() as conn: |
87 | # Create tables | 97 | # Create tables |
88 | logger.info("Creating tables...") | 98 | self.logger.info("Creating tables...") |
89 | await conn.run_sync(Base.metadata.create_all) | 99 | await conn.run_sync(Base.metadata.create_all) |
90 | 100 | ||
91 | def connect(self, logger): | 101 | def connect(self, logger): |
@@ -98,6 +108,15 @@ def map_row(row): | |||
98 | return dict(**row._mapping) | 108 | return dict(**row._mapping) |
99 | 109 | ||
100 | 110 | ||
111 | def map_user(row): | ||
112 | if row is None: | ||
113 | return None | ||
114 | return User( | ||
115 | username=row.username, | ||
116 | permissions=set(row.permissions.split()), | ||
117 | ) | ||
118 | |||
119 | |||
101 | class Database(object): | 120 | class Database(object): |
102 | def __init__(self, engine, logger): | 121 | def __init__(self, engine, logger): |
103 | self.engine = engine | 122 | self.engine = engine |
@@ -278,7 +297,7 @@ class Database(object): | |||
278 | await self.db.execute(statement) | 297 | await self.db.execute(statement) |
279 | return True | 298 | return True |
280 | except IntegrityError: | 299 | except IntegrityError: |
281 | logger.debug( | 300 | self.logger.debug( |
282 | "%s, %s, %s already in unihash database", method, taskhash, unihash | 301 | "%s, %s, %s already in unihash database", method, taskhash, unihash |
283 | ) | 302 | ) |
284 | return False | 303 | return False |
@@ -298,7 +317,87 @@ class Database(object): | |||
298 | await self.db.execute(statement) | 317 | await self.db.execute(statement) |
299 | return True | 318 | return True |
300 | except IntegrityError: | 319 | except IntegrityError: |
301 | logger.debug( | 320 | self.logger.debug( |
302 | "%s, %s already in outhash database", data["method"], data["outhash"] | 321 | "%s, %s already in outhash database", data["method"], data["outhash"] |
303 | ) | 322 | ) |
304 | return False | 323 | return False |
324 | |||
325 | async def _get_user(self, username): | ||
326 | statement = select( | ||
327 | Users.username, | ||
328 | Users.permissions, | ||
329 | Users.token, | ||
330 | ).where( | ||
331 | Users.username == username, | ||
332 | ) | ||
333 | self.logger.debug("%s", statement) | ||
334 | async with self.db.begin(): | ||
335 | result = await self.db.execute(statement) | ||
336 | return result.first() | ||
337 | |||
338 | async def lookup_user_token(self, username): | ||
339 | row = await self._get_user(username) | ||
340 | if not row: | ||
341 | return None, None | ||
342 | return map_user(row), row.token | ||
343 | |||
344 | async def lookup_user(self, username): | ||
345 | return map_user(await self._get_user(username)) | ||
346 | |||
347 | async def set_user_token(self, username, token): | ||
348 | statement = ( | ||
349 | update(Users) | ||
350 | .where( | ||
351 | Users.username == username, | ||
352 | ) | ||
353 | .values( | ||
354 | token=token, | ||
355 | ) | ||
356 | ) | ||
357 | self.logger.debug("%s", statement) | ||
358 | async with self.db.begin(): | ||
359 | result = await self.db.execute(statement) | ||
360 | return result.rowcount != 0 | ||
361 | |||
362 | async def set_user_perms(self, username, permissions): | ||
363 | statement = ( | ||
364 | update(Users) | ||
365 | .where(Users.username == username) | ||
366 | .values(permissions=" ".join(permissions)) | ||
367 | ) | ||
368 | self.logger.debug("%s", statement) | ||
369 | async with self.db.begin(): | ||
370 | result = await self.db.execute(statement) | ||
371 | return result.rowcount != 0 | ||
372 | |||
373 | async def get_all_users(self): | ||
374 | statement = select( | ||
375 | Users.username, | ||
376 | Users.permissions, | ||
377 | ) | ||
378 | self.logger.debug("%s", statement) | ||
379 | async with self.db.begin(): | ||
380 | result = await self.db.execute(statement) | ||
381 | return [map_user(row) for row in result] | ||
382 | |||
383 | async def new_user(self, username, permissions, token): | ||
384 | statement = insert(Users).values( | ||
385 | username=username, | ||
386 | permissions=" ".join(permissions), | ||
387 | token=token, | ||
388 | ) | ||
389 | self.logger.debug("%s", statement) | ||
390 | try: | ||
391 | async with self.db.begin(): | ||
392 | await self.db.execute(statement) | ||
393 | return True | ||
394 | except IntegrityError as e: | ||
395 | self.logger.debug("Cannot create new user %s: %s", username, e) | ||
396 | return False | ||
397 | |||
398 | async def delete_user(self, username): | ||
399 | statement = delete(Users).where(Users.username == username) | ||
400 | self.logger.debug("%s", statement) | ||
401 | async with self.db.begin(): | ||
402 | result = await self.db.execute(statement) | ||
403 | return result.rowcount != 0 | ||