@@ -6,7 +6,6 @@
import asyncio
from contextlib import closing
import re
-import sqlite3
import itertools
import json
from urllib.parse import urlparse
@@ -19,92 +18,34 @@ ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2
-UNIHASH_TABLE_DEFINITION = (
- ("method", "TEXT NOT NULL", "UNIQUE"),
- ("taskhash", "TEXT NOT NULL", "UNIQUE"),
- ("unihash", "TEXT NOT NULL", ""),
-)
-
-UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
-
-OUTHASH_TABLE_DEFINITION = (
- ("method", "TEXT NOT NULL", "UNIQUE"),
- ("taskhash", "TEXT NOT NULL", "UNIQUE"),
- ("outhash", "TEXT NOT NULL", "UNIQUE"),
- ("created", "DATETIME", ""),
-
- # Optional fields
- ("owner", "TEXT", ""),
- ("PN", "TEXT", ""),
- ("PV", "TEXT", ""),
- ("PR", "TEXT", ""),
- ("task", "TEXT", ""),
- ("outhash_siginfo", "TEXT", ""),
-)
-
-OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
-
-def _make_table(cursor, name, definition):
- cursor.execute('''
- CREATE TABLE IF NOT EXISTS {name} (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- {fields}
- UNIQUE({unique})
- )
- '''.format(
- name=name,
- fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
- unique=", ".join(name for name, _, flags in definition if "UNIQUE" in flags)
- ))
-
-
-def setup_database(database, sync=True):
- db = sqlite3.connect(database)
- db.row_factory = sqlite3.Row
-
- with closing(db.cursor()) as cursor:
- _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
- _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
-
- cursor.execute('PRAGMA journal_mode = WAL')
- cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
-
- # Drop old indexes
- cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
- cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
- cursor.execute('DROP INDEX IF EXISTS taskhash_lookup_v2')
- cursor.execute('DROP INDEX IF EXISTS outhash_lookup_v2')
-
- # TODO: Upgrade from tasks_v2?
- cursor.execute('DROP TABLE IF EXISTS tasks_v2')
-
- # Create new indexes
- cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)')
- cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)')
-
- return db
-
def parse_address(addr):
if addr.startswith(UNIX_PREFIX):
- return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
+ return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
return (ADDR_TYPE_WS, (addr,))
else:
- m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
+ m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
if m is not None:
- host = m.group('host')
- port = m.group('port')
+ host = m.group("host")
+ port = m.group("port")
else:
- host, port = addr.split(':')
+ host, port = addr.split(":")
return (ADDR_TYPE_TCP, (host, int(port)))
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
+ def sqlite_engine():
+ from .sqlite import DatabaseEngine
+
+ return DatabaseEngine(dbname, sync)
+
from . import server
- db = setup_database(dbname, sync=sync)
- s = server.Server(db, upstream=upstream, read_only=read_only)
+
+ db_engine = sqlite_engine()
+
+ s = server.Server(db_engine, upstream=upstream, read_only=read_only)
(typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX:
@@ -120,6 +61,7 @@ def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
def create_client(addr):
from . import client
+
c = client.Client()
(typ, a) = parse_address(addr)
@@ -132,8 +74,10 @@ def create_client(addr):
return c
+
async def create_async_client(addr):
from . import client
+
c = client.AsyncClient()
(typ, a) = parse_address(addr)
@@ -3,18 +3,16 @@
# SPDX-License-Identifier: GPL-2.0-only
#
-from contextlib import closing, contextmanager
from datetime import datetime, timedelta
-import enum
import asyncio
import logging
import math
import time
-from . import create_async_client, UNIHASH_TABLE_COLUMNS, OUTHASH_TABLE_COLUMNS
+from . import create_async_client
import bb.asyncrpc
-logger = logging.getLogger('hashserv.server')
+logger = logging.getLogger("hashserv.server")
class Measurement(object):
@@ -104,229 +102,136 @@ class Stats(object):
return math.sqrt(self.s / (self.num - 1))
def todict(self):
- return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
-
-
-@enum.unique
-class Resolve(enum.Enum):
- FAIL = enum.auto()
- IGNORE = enum.auto()
- REPLACE = enum.auto()
-
-
-def insert_table(cursor, table, data, on_conflict):
- resolve = {
- Resolve.FAIL: "",
- Resolve.IGNORE: " OR IGNORE",
- Resolve.REPLACE: " OR REPLACE",
- }[on_conflict]
-
- keys = sorted(data.keys())
- query = 'INSERT{resolve} INTO {table} ({fields}) VALUES({values})'.format(
- resolve=resolve,
- table=table,
- fields=", ".join(keys),
- values=", ".join(":" + k for k in keys),
- )
- prevrowid = cursor.lastrowid
- cursor.execute(query, data)
- logging.debug(
- "Inserting %r into %s, %s",
- data,
- table,
- on_conflict
- )
- return (cursor.lastrowid, cursor.lastrowid != prevrowid)
-
-def insert_unihash(cursor, data, on_conflict):
- return insert_table(cursor, "unihashes_v2", data, on_conflict)
-
-def insert_outhash(cursor, data, on_conflict):
- return insert_table(cursor, "outhashes_v2", data, on_conflict)
-
-async def copy_unihash_from_upstream(client, db, method, taskhash):
- d = await client.get_taskhash(method, taskhash)
- if d is not None:
- with closing(db.cursor()) as cursor:
- insert_unihash(
- cursor,
- {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS},
- Resolve.IGNORE,
- )
- db.commit()
- return d
-
-
-class ServerCursor(object):
- def __init__(self, db, cursor, upstream):
- self.db = db
- self.cursor = cursor
- self.upstream = upstream
+ return {
+ k: getattr(self, k)
+ for k in ("num", "total_time", "max_time", "average", "stdev")
+ }
class ServerClient(bb.asyncrpc.AsyncServerConnection):
- def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
- super().__init__(socket, 'OEHASHEQUIV', logger)
- self.db = db
+ def __init__(
+ self,
+ socket,
+ db_engine,
+ request_stats,
+ backfill_queue,
+ upstream,
+ read_only,
+ ):
+ super().__init__(socket, "OEHASHEQUIV", logger)
+ self.db_engine = db_engine
self.request_stats = request_stats
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
self.backfill_queue = backfill_queue
self.upstream = upstream
- self.handlers.update({
- 'get': self.handle_get,
- 'get-outhash': self.handle_get_outhash,
- 'get-stream': self.handle_get_stream,
- 'get-stats': self.handle_get_stats,
- })
+ self.handlers.update(
+ {
+ "get": self.handle_get,
+ "get-outhash": self.handle_get_outhash,
+ "get-stream": self.handle_get_stream,
+ "get-stats": self.handle_get_stats,
+ }
+ )
if not read_only:
- self.handlers.update({
- 'report': self.handle_report,
- 'report-equiv': self.handle_equivreport,
- 'reset-stats': self.handle_reset_stats,
- 'backfill-wait': self.handle_backfill_wait,
- 'remove': self.handle_remove,
- 'clean-unused': self.handle_clean_unused,
- })
+ self.handlers.update(
+ {
+ "report": self.handle_report,
+ "report-equiv": self.handle_equivreport,
+ "reset-stats": self.handle_reset_stats,
+ "backfill-wait": self.handle_backfill_wait,
+ "remove": self.handle_remove,
+ "clean-unused": self.handle_clean_unused,
+ }
+ )
def validate_proto_version(self):
- return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))
+ return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
async def process_requests(self):
- if self.upstream is not None:
- self.upstream_client = await create_async_client(self.upstream)
- else:
- self.upstream_client = None
-
- await super().process_requests()
+ async with self.db_engine.connect(self.logger) as db:
+ self.db = db
+ if self.upstream is not None:
+ self.upstream_client = await create_async_client(self.upstream)
+ else:
+ self.upstream_client = None
- if self.upstream_client is not None:
- await self.upstream_client.close()
+ try:
+ await super().process_requests()
+ finally:
+ if self.upstream_client is not None:
+ await self.upstream_client.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- self.logger.debug('Handling %s' % k)
- if 'stream' in k:
+ self.logger.debug("Handling %s" % k)
+ if "stream" in k:
return await self.handlers[k](msg[k])
else:
- with self.request_stats.start_sample() as self.request_sample, \
- self.request_sample.measure():
+ with self.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
async def handle_get(self, request):
- method = request['method']
- taskhash = request['taskhash']
- fetch_all = request.get('all', False)
+ method = request["method"]
+ taskhash = request["taskhash"]
+ fetch_all = request.get("all", False)
- with closing(self.db.cursor()) as cursor:
- return await self.get_unihash(cursor, method, taskhash, fetch_all)
+ return await self.get_unihash(method, taskhash, fetch_all)
- async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
+ async def get_unihash(self, method, taskhash, fetch_all=False):
d = None
if fetch_all:
- cursor.execute(
- '''
- SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': method,
- 'taskhash': taskhash,
- }
-
- )
- row = cursor.fetchone()
-
+ row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_taskhash(method, taskhash, True)
- self.update_unified(cursor, d)
- self.db.commit()
+ await self.update_unified(d)
else:
- row = self.query_equivalent(cursor, method, taskhash)
+ row = await self.db.get_equivalent(method, taskhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_taskhash(method, taskhash)
- d = {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS}
- insert_unihash(cursor, d, Resolve.IGNORE)
- self.db.commit()
+ await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
return d
async def handle_get_outhash(self, request):
- method = request['method']
- outhash = request['outhash']
- taskhash = request['taskhash']
+ method = request["method"]
+ outhash = request["outhash"]
+ taskhash = request["taskhash"]
with_unihash = request.get("with_unihash", True)
- with closing(self.db.cursor()) as cursor:
- return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
+ return await self.get_outhash(method, outhash, taskhash, with_unihash)
- async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
+ async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
d = None
if with_unihash:
- cursor.execute(
- '''
- SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': method,
- 'outhash': outhash,
- }
- )
+ row = await self.db.get_unihash_by_outhash(method, outhash)
else:
- cursor.execute(
- """
- SELECT * FROM outhashes_v2
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- """,
- {
- 'method': method,
- 'outhash': outhash,
- }
- )
- row = cursor.fetchone()
+ row = await self.db.get_outhash(method, outhash)
if row is not None:
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_outhash(method, outhash, taskhash)
- self.update_unified(cursor, d)
- self.db.commit()
+ await self.update_unified(d)
return d
- def update_unified(self, cursor, data):
+ async def update_unified(self, data):
if data is None:
return
- insert_unihash(
- cursor,
- {k: v for k, v in data.items() if k in UNIHASH_TABLE_COLUMNS},
- Resolve.IGNORE
- )
- insert_outhash(
- cursor,
- {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS},
- Resolve.IGNORE
- )
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+ await self.db.insert_outhash(data)
async def handle_get_stream(self, request):
await self.socket.send_message("ok")
@@ -347,20 +252,16 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
request_measure = self.request_sample.measure()
request_measure.start()
- if l == 'END':
+ if l == "END":
break
(method, taskhash) = l.split()
- #self.logger.debug('Looking up %s %s' % (method, taskhash))
- cursor = self.db.cursor()
- try:
- row = self.query_equivalent(cursor, method, taskhash)
- finally:
- cursor.close()
+ # self.logger.debug('Looking up %s %s' % (method, taskhash))
+ row = await self.db.get_equivalent(method, taskhash)
if row is not None:
- msg = row['unihash']
- #self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+ msg = row["unihash"]
+ # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
@@ -384,118 +285,81 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return self.NO_RESPONSE
async def handle_report(self, data):
- with closing(self.db.cursor()) as cursor:
- outhash_data = {
- 'method': data['method'],
- 'outhash': data['outhash'],
- 'taskhash': data['taskhash'],
- 'created': datetime.now()
- }
+ outhash_data = {
+ "method": data["method"],
+ "outhash": data["outhash"],
+ "taskhash": data["taskhash"],
+ "created": datetime.now(),
+ }
- for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
- if k in data:
- outhash_data[k] = data[k]
-
- # Insert the new entry, unless it already exists
- (rowid, inserted) = insert_outhash(cursor, outhash_data, Resolve.IGNORE)
-
- if inserted:
- # If this row is new, check if it is equivalent to another
- # output hash
- cursor.execute(
- '''
- SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
- INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
- -- Select any matching output hash except the one we just inserted
- WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
- -- Pick the oldest hash
- ORDER BY outhashes_v2.created ASC
- LIMIT 1
- ''',
- {
- 'method': data['method'],
- 'outhash': data['outhash'],
- 'taskhash': data['taskhash'],
- }
- )
- row = cursor.fetchone()
+ for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
+ if k in data:
+ outhash_data[k] = data[k]
- if row is not None:
- # A matching output hash was found. Set our taskhash to the
- # same unihash since they are equivalent
- unihash = row['unihash']
- resolve = Resolve.IGNORE
- else:
- # No matching output hash was found. This is probably the
- # first outhash to be added.
- unihash = data['unihash']
- resolve = Resolve.IGNORE
-
- # Query upstream to see if it has a unihash we can use
- if self.upstream_client is not None:
- upstream_data = await self.upstream_client.get_outhash(data['method'], data['outhash'], data['taskhash'])
- if upstream_data is not None:
- unihash = upstream_data['unihash']
-
-
- insert_unihash(
- cursor,
- {
- 'method': data['method'],
- 'taskhash': data['taskhash'],
- 'unihash': unihash,
- },
- resolve
- )
-
- unihash_data = await self.get_unihash(cursor, data['method'], data['taskhash'])
- if unihash_data is not None:
- unihash = unihash_data['unihash']
- else:
- unihash = data['unihash']
-
- self.db.commit()
+ # Insert the new entry, unless it already exists
+ if await self.db.insert_outhash(outhash_data):
+ # If this row is new, check if it is equivalent to another
+ # output hash
+ row = await self.db.get_equivalent_for_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
- d = {
- 'taskhash': data['taskhash'],
- 'method': data['method'],
- 'unihash': unihash,
- }
+ if row is not None:
+ # A matching output hash was found. Set our taskhash to the
+ # same unihash since they are equivalent
+ unihash = row["unihash"]
+ else:
+ # No matching output hash was found. This is probably the
+ # first outhash to be added.
+ unihash = data["unihash"]
+
+ # Query upstream to see if it has a unihash we can use
+ if self.upstream_client is not None:
+ upstream_data = await self.upstream_client.get_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
+ if upstream_data is not None:
+ unihash = upstream_data["unihash"]
+
+ await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
+
+ unihash_data = await self.get_unihash(data["method"], data["taskhash"])
+ if unihash_data is not None:
+ unihash = unihash_data["unihash"]
+ else:
+ unihash = data["unihash"]
- return d
+ return {
+ "taskhash": data["taskhash"],
+ "method": data["method"],
+ "unihash": unihash,
+ }
async def handle_equivreport(self, data):
- with closing(self.db.cursor()) as cursor:
- insert_data = {
- 'method': data['method'],
- 'taskhash': data['taskhash'],
- 'unihash': data['unihash'],
- }
- insert_unihash(cursor, insert_data, Resolve.IGNORE)
- self.db.commit()
-
- # Fetch the unihash that will be reported for the taskhash. If the
- # unihash matches, it means this row was inserted (or the mapping
- # was already valid)
- row = self.query_equivalent(cursor, data['method'], data['taskhash'])
-
- if row['unihash'] == data['unihash']:
- self.logger.info('Adding taskhash equivalence for %s with unihash %s',
- data['taskhash'], row['unihash'])
-
- d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
-
- return d
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+
+ # Fetch the unihash that will be reported for the taskhash. If the
+ # unihash matches, it means this row was inserted (or the mapping
+ # was already valid)
+ row = await self.db.get_equivalent(data["method"], data["taskhash"])
+
+ if row["unihash"] == data["unihash"]:
+ self.logger.info(
+ "Adding taskhash equivalence for %s with unihash %s",
+ data["taskhash"],
+ row["unihash"],
+ )
+ return {k: row[k] for k in ("taskhash", "method", "unihash")}
async def handle_get_stats(self, request):
return {
- 'requests': self.request_stats.todict(),
+ "requests": self.request_stats.todict(),
}
async def handle_reset_stats(self, request):
d = {
- 'requests': self.request_stats.todict(),
+ "requests": self.request_stats.todict(),
}
self.request_stats.reset()
@@ -503,7 +367,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
async def handle_backfill_wait(self, request):
d = {
- 'tasks': self.backfill_queue.qsize(),
+ "tasks": self.backfill_queue.qsize(),
}
await self.backfill_queue.join()
return d
@@ -513,92 +377,63 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if not isinstance(condition, dict):
raise TypeError("Bad condition type %s" % type(condition))
- def do_remove(columns, table_name, cursor):
- nonlocal condition
- where = {}
- for c in columns:
- if c in condition and condition[c] is not None:
- where[c] = condition[c]
-
- if where:
- query = ('DELETE FROM %s WHERE ' % table_name) + ' AND '.join("%s=:%s" % (k, k) for k in where.keys())
- cursor.execute(query, where)
- return cursor.rowcount
-
- return 0
-
- count = 0
- with closing(self.db.cursor()) as cursor:
- count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
- count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
- self.db.commit()
-
- return {"count": count}
+ return {"count": await self.db.remove(condition)}
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
- with closing(self.db.cursor()) as cursor:
- cursor.execute(
- """
- DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
- SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
- )
- """,
- {
- "oldest": datetime.now() - timedelta(seconds=-max_age)
- }
- )
- count = cursor.rowcount
-
- return {"count": count}
-
- def query_equivalent(self, cursor, method, taskhash):
- # This is part of the inner loop and must be as fast as possible
- cursor.execute(
- 'SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash',
- {
- 'method': method,
- 'taskhash': taskhash,
- }
- )
- return cursor.fetchone()
+ oldest = datetime.now() - timedelta(seconds=-max_age)
+ return {"count": await self.db.clean_unused(oldest)}
class Server(bb.asyncrpc.AsyncServer):
- def __init__(self, db, upstream=None, read_only=False):
+ def __init__(self, db_engine, upstream=None, read_only=False):
if upstream and read_only:
- raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")
+ raise bb.asyncrpc.ServerError(
+ "Read-only hashserv cannot pull from an upstream server"
+ )
super().__init__(logger)
self.request_stats = Stats()
- self.db = db
+ self.db_engine = db_engine
self.upstream = upstream
self.read_only = read_only
self.backfill_queue = None
def accept_client(self, socket):
- return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+ return ServerClient(
+ socket,
+ self.db_engine,
+ self.request_stats,
+ self.backfill_queue,
+ self.upstream,
+ self.read_only,
+ )
async def backfill_worker_task(self):
- client = await create_async_client(self.upstream)
- try:
+ async with await create_async_client(
+ self.upstream
+ ) as client, self.db_engine.connect(logger) as db:
while True:
item = await self.backfill_queue.get()
if item is None:
self.backfill_queue.task_done()
break
+
method, taskhash = item
- await copy_unihash_from_upstream(client, self.db, method, taskhash)
+ d = await client.get_taskhash(method, taskhash)
+ if d is not None:
+ await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
self.backfill_queue.task_done()
- finally:
- await client.close()
def start(self):
tasks = super().start()
if self.upstream:
self.backfill_queue = asyncio.Queue()
tasks += [self.backfill_worker_task()]
+
+ self.loop.run_until_complete(self.db_engine.create())
+
return tasks
async def stop(self):
new file mode 100644
@@ -0,0 +1,259 @@
+#! /usr/bin/env python3
+#
+# Copyright (C) 2023 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+import sqlite3
+import logging
+from contextlib import closing
+
+logger = logging.getLogger("hashserv.sqlite")
+
+UNIHASH_TABLE_DEFINITION = (
+ ("method", "TEXT NOT NULL", "UNIQUE"),
+ ("taskhash", "TEXT NOT NULL", "UNIQUE"),
+ ("unihash", "TEXT NOT NULL", ""),
+)
+
+UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
+
+OUTHASH_TABLE_DEFINITION = (
+ ("method", "TEXT NOT NULL", "UNIQUE"),
+ ("taskhash", "TEXT NOT NULL", "UNIQUE"),
+ ("outhash", "TEXT NOT NULL", "UNIQUE"),
+ ("created", "DATETIME", ""),
+ # Optional fields
+ ("owner", "TEXT", ""),
+ ("PN", "TEXT", ""),
+ ("PV", "TEXT", ""),
+ ("PR", "TEXT", ""),
+ ("task", "TEXT", ""),
+ ("outhash_siginfo", "TEXT", ""),
+)
+
+OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
+
+
+def _make_table(cursor, name, definition):
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS {name} (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ {fields}
+ UNIQUE({unique})
+ )
+ """.format(
+ name=name,
+ fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
+ unique=", ".join(
+ name for name, _, flags in definition if "UNIQUE" in flags
+ ),
+ )
+ )
+
+
+class DatabaseEngine(object):
+ def __init__(self, dbname, sync):
+ self.dbname = dbname
+ self.logger = logger
+ self.sync = sync
+
+ async def create(self):
+ db = sqlite3.connect(self.dbname)
+ db.row_factory = sqlite3.Row
+
+ with closing(db.cursor()) as cursor:
+ _make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
+ _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
+
+ cursor.execute("PRAGMA journal_mode = WAL")
+ cursor.execute(
+ "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF")
+ )
+
+ # Drop old indexes
+ cursor.execute("DROP INDEX IF EXISTS taskhash_lookup")
+ cursor.execute("DROP INDEX IF EXISTS outhash_lookup")
+ cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2")
+ cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2")
+
+ # TODO: Upgrade from tasks_v2?
+ cursor.execute("DROP TABLE IF EXISTS tasks_v2")
+
+ # Create new indexes
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)"
+ )
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)"
+ )
+
+ def connect(self, logger):
+ return Database(logger, self.dbname)
+
+
+class Database(object):
+ def __init__(self, logger, dbname, sync=True):
+ self.dbname = dbname
+ self.logger = logger
+
+ self.db = sqlite3.connect(self.dbname)
+ self.db.row_factory = sqlite3.Row
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+ async def close(self):
+ self.db.close()
+
+ async def get_unihash_by_taskhash_full(self, method, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_unihash_by_outhash(self, method, outhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_outhash(self, method, outhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT * FROM outhashes_v2
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_equivalent_for_outhash(self, method, outhash, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
+ INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
+ -- Select any matching output hash except the one we just inserted
+ WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
+ -- Pick the oldest hash
+ ORDER BY outhashes_v2.created ASC
+ LIMIT 1
+ """,
+ {
+ "method": method,
+ "outhash": outhash,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def get_equivalent(self, method, taskhash):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ "SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash",
+ {
+ "method": method,
+ "taskhash": taskhash,
+ },
+ )
+ return cursor.fetchone()
+
+ async def remove(self, condition):
+ def do_remove(columns, table_name, cursor):
+ where = {}
+ for c in columns:
+ if c in condition and condition[c] is not None:
+ where[c] = condition[c]
+
+ if where:
+ query = ("DELETE FROM %s WHERE " % table_name) + " AND ".join(
+ "%s=:%s" % (k, k) for k in where.keys()
+ )
+ cursor.execute(query, where)
+ return cursor.rowcount
+
+ return 0
+
+ count = 0
+ with closing(self.db.cursor()) as cursor:
+ count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
+ count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
+ self.db.commit()
+
+ return count
+
+ async def clean_unused(self, oldest):
+ with closing(self.db.cursor()) as cursor:
+ cursor.execute(
+ """
+ DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
+ SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
+ )
+ """,
+ {
+ "oldest": oldest,
+ },
+ )
+ return cursor.rowcount
+
+ async def insert_unihash(self, method, taskhash, unihash):
+ with closing(self.db.cursor()) as cursor:
+ prevrowid = cursor.lastrowid
+ cursor.execute(
+ """
+ INSERT OR IGNORE INTO unihashes_v2 (method, taskhash, unihash) VALUES(:method, :taskhash, :unihash)
+ """,
+ {
+ "method": method,
+ "taskhash": taskhash,
+ "unihash": unihash,
+ },
+ )
+ self.db.commit()
+ return cursor.lastrowid != prevrowid
+
+ async def insert_outhash(self, data):
+ data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS}
+ keys = sorted(data.keys())
+ query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format(
+ fields=", ".join(keys),
+ values=", ".join(":" + k for k in keys),
+ )
+ with closing(self.db.cursor()) as cursor:
+ prevrowid = cursor.lastrowid
+ cursor.execute(query, data)
+ self.db.commit()
+ return cursor.lastrowid != prevrowid
Abstracts the way the database backend is accessed by the hash equivalence server to make it possible to use other backends Signed-off-by: Joshua Watt <JPEWhacker@gmail.com> --- lib/hashserv/__init__.py | 90 ++----- lib/hashserv/server.py | 491 +++++++++++++-------------------------- lib/hashserv/sqlite.py | 259 +++++++++++++++++++++ 3 files changed, 439 insertions(+), 401 deletions(-) create mode 100644 lib/hashserv/sqlite.py