@@ -212,6 +212,26 @@ def main():
print("New hashes marked: %d" % result["count"])
return 0
+ def handle_gc_mark_stream(args, client):
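+ # Rows are read lazily from stdin; each row is whitespace-separated
+ # key-value pairs, e.g. "column1 foo column2 bar"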
+ stdin = (l.strip() for l in sys.stdin)
+ marked_hashes = 0
+
+ try:
+ result = client.gc_mark_stream(args.mark, stdin)
+ marked_hashes = result["count"]
+ except ConnectionError:
+ logger.warning(
+ "Server doesn't seem to support `gc-mark-stream`. Sending "
+ "hashes sequentially using `gc-mark` API."
+ )
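+ # Fall back to one gc-mark call per remaining input row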
+ for line in stdin:
+ pairs = line.split()
+ condition = dict(zip(pairs[::2], pairs[1::2]))
+ result = client.gc_mark(args.mark, condition)
+ marked_hashes += result["count"]
+
+ print("New hashes marked: %d" % marked_hashes)
+ return 0
+
def handle_gc_sweep(args, client):
result = client.gc_sweep(args.mark)
print("Removed %d rows" % result["count"])
@@ -313,6 +333,16 @@ def main():
help="Keep entries in table where KEY == VALUE")
gc_mark_parser.set_defaults(func=handle_gc_mark)
+ gc_mark_parser_stream = subparsers.add_parser(
+ 'gc-mark-stream',
+ help=(
+ "Mark multiple hashes to be retained for garbage collection. Input should be provided via stdin, "
+ "with each line formatted as key-value pairs separated by spaces, for example 'column1 foo column2 bar'."
+ )
+ )
+ gc_mark_parser_stream.add_argument("mark", help="Mark for this garbage collection operation")
+ gc_mark_parser_stream.set_defaults(func=handle_gc_mark_stream)
+
gc_sweep_parser = subparsers.add_parser('gc-sweep', help="Perform garbage collection and delete any entries that are not marked")
gc_sweep_parser.add_argument("mark", help="Mark for this garbage collection operation")
gc_sweep_parser.set_defaults(func=handle_gc_sweep)
@@ -78,6 +78,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
MODE_NORMAL = 0
MODE_GET_STREAM = 1
MODE_EXIST_STREAM = 2
+ MODE_MARK_STREAM = 3
def __init__(self, username=None, password=None):
super().__init__("OEHASHEQUIV", "1.1", logger)
@@ -160,6 +161,8 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
await normal_to_stream("get-stream")
elif new_mode == self.MODE_EXIST_STREAM:
await normal_to_stream("exists-stream")
+ elif new_mode == self.MODE_MARK_STREAM:
+ await normal_to_stream("gc-mark-stream")
elif new_mode != self.MODE_NORMAL:
raise Exception("Undefined mode transition {self.mode!r} -> {new_mode!r}")
@@ -302,6 +305,24 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
"""
return await self.invoke({"gc-mark": {"mark": mark, "where": where}})
+ async def gc_mark_stream(self, mark, rows):
+ """
+ Similar to `gc-mark`, but accepts multiple "where" conditions. Each entry
+ in `rows` is a string of whitespace-separated key-value pairs, for example
+ "column1 foo column2 bar". Stream mode is used to mark the hashes, which
+ helps reduce the impact of latency when communicating with the hash
+ equivalence server.
+ """
+ def row_to_dict(row):
+ pairs = row.split()
+ return dict(zip(pairs[::2], pairs[1::2]))
+
+ responses = await self.send_stream_batch(
+ self.MODE_MARK_STREAM,
+ (json.dumps({"mark": mark, "where": row_to_dict(row)}) for row in rows),
+ )
+
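+ # The server answers each streamed line with a JSON count; sum them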
+ return {"count": sum(int(json.loads(r)["count"]) for r in responses)}
+
async def gc_sweep(self, mark):
"""
Finishes garbage collection for "mark". All unihash entries that have
@@ -347,6 +368,7 @@ class Client(bb.asyncrpc.Client):
"get_db_query_columns",
"gc_status",
"gc_mark",
+ "gc_mark_stream",
"gc_sweep",
)
@@ -10,6 +10,7 @@ import math
import time
import os
import base64
+import json
import hashlib
from . import create_async_client
import bb.asyncrpc
@@ -256,6 +257,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
"backfill-wait": self.handle_backfill_wait,
"remove": self.handle_remove,
"gc-mark": self.handle_gc_mark,
+ "gc-mark-stream": self.handle_gc_mark_stream,
"gc-sweep": self.handle_gc_sweep,
"gc-status": self.handle_gc_status,
"clean-unused": self.handle_clean_unused,
@@ -583,6 +585,33 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
return {"count": await self.db.gc_mark(mark, condition)}
+ @permissions(DB_ADMIN_PERM)
+ async def handle_gc_mark_stream(self, request):
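+ # Each streamed line carries a JSON object {"mark": ..., "where": {...}}
+ # and is answered with the number of rows newly marked for that condition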
+ async def handler(line):
+ try:
+ decoded_line = json.loads(line)
+ except json.JSONDecodeError as exc:
+ raise bb.asyncrpc.InvokeError(
+ "Could not decode JSONL input '%s'" % line
+ ) from exc
+
+ try:
+ mark = decoded_line["mark"]
+ condition = decoded_line["where"]
+ if not isinstance(mark, str):
+ raise TypeError("Bad mark type %s" % type(mark))
+
+ if not isinstance(condition, dict):
+ raise TypeError("Bad condition type %s" % type(condition))
+ except (KeyError, TypeError) as exc:
+ raise bb.asyncrpc.InvokeError(
+ "Invalid JSONL input '%s': %s" % (line, exc)
+ ) from exc
+
+ return json.dumps({"count": await self.db.gc_mark(mark, condition)})
+
+ return await self._stream_handler(handler)
+
@permissions(DB_ADMIN_PERM)
async def handle_gc_sweep(self, request):
mark = request["mark"]
@@ -1054,6 +1054,48 @@ class HashEquivalenceCommonTests(object):
# First hash is still present
self.assertClientGetHash(self.client, taskhash, unihash)
+ def test_gc_stream(self):
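+ # Report three unihashes, stream-mark two of them, then sweep and
+ # verify that only the unmarked one is removed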
+ taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
+ outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
+ unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
+
+ result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+ self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
+
+ taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
+ outhash2 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
+ unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
+
+ result = self.client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2)
+ self.assertClientGetHash(self.client, taskhash2, unihash2)
+
+ taskhash3 = 'a1117c1f5a7c9ab2f5a39cc6fe5e6152169d09c0'
+ outhash3 = '7289c414905303700a1117c1f5a7c9ab2f5a39cc6fe5e6152169d09c04f9a53c'
+ unihash3 = '905303700a1117c1f5a7c9ab2f5a39cc6fe5e615'
+
+ result = self.client.report_unihash(taskhash3, self.METHOD, outhash3, unihash3)
+ self.assertClientGetHash(self.client, taskhash3, unihash3)
+
+ # Mark the first two unihashes to be kept
+ ret = self.client.gc_mark_stream("ABC", (f"unihash {h}" for h in [unihash, unihash2]))
+ self.assertEqual(ret, {"count": 2})
+
+ ret = self.client.gc_status()
+ self.assertEqual(ret, {"mark": "ABC", "keep": 2, "remove": 1})
+
+ # Third hash is still there; mark doesn't delete hashes
+ self.assertClientGetHash(self.client, taskhash3, unihash3)
+
+ ret = self.client.gc_sweep("ABC")
+ self.assertEqual(ret, {"count": 1})
+
+ # Third hash is gone since it was not marked
+ self.assertClientGetHash(self.client, taskhash3, None)
+ # First hash is still present
+ self.assertClientGetHash(self.client, taskhash, unihash)
+ # Second hash is still present
+ self.assertClientGetHash(self.client, taskhash2, unihash2)
+
def test_gc_switch_mark(self):
taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'