diff mbox series

[bitbake-devel,v2] hashserv: client: Add batch stream API

Message ID 20240529193751.3372652-1-JPEWhacker@gmail.com
State New
Headers show
Series [bitbake-devel,v2] hashserv: client: Add batch stream API | expand

Commit Message

Joshua Watt May 29, 2024, 7:37 p.m. UTC
Changes the stream mode to do "batch" processing. This means that the
sending and receiving of messages is done simultaneously so that messages
can be sent as fast as possible without having to wait for each reply.
This allows multiple messages to be in flight at once, reducing the
effect of the round trip latency from the server.

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
 bitbake/lib/hashserv/client.py | 92 ++++++++++++++++++++++++++++++----
 bitbake/lib/hashserv/tests.py  | 75 +++++++++++++++++++++++++++
 2 files changed, 158 insertions(+), 9 deletions(-)
diff mbox series

Patch

diff --git a/bitbake/lib/hashserv/client.py b/bitbake/lib/hashserv/client.py
index 0b254beddd7..9cc5e10708e 100644
--- a/bitbake/lib/hashserv/client.py
+++ b/bitbake/lib/hashserv/client.py
@@ -5,6 +5,7 @@ 
 
 import logging
 import socket
+import asyncio
 import bb.asyncrpc
 import json
 from . import create_async_client
@@ -13,6 +14,51 @@  from . import create_async_client
 logger = logging.getLogger("hashserv.client")
 
 
+class Batch(object):
+    def __init__(self, socket):
+        self.done = False
+        self.cond = asyncio.Condition()
+        self.pending = 0
+        self.socket = socket
+
+    async def recv(self):
+        result = []
+
+        while True:
+            async with self.cond:
+                await self.cond.wait_for(lambda: self.pending or self.done)
+
+                if self.pending == 0:
+                    if self.done:
+                        return result
+                    continue
+
+                self.pending -= 1
+
+            r = await self.socket.recv()
+            result.append(r)
+
+    async def send(self, msgs):
+        try:
+            for m in msgs:
+                await self.socket.send(m)
+
+                async with self.cond:
+                    self.pending += 1
+                    self.cond.notify()
+        finally:
+            async with self.cond:
+                self.done = True
+                self.cond.notify()
+
+    async def process(self, msgs):
+        result = await asyncio.gather(
+            self.recv(),
+            self.send(msgs),
+        )
+        return result[0]
+
+
 class AsyncClient(bb.asyncrpc.AsyncClient):
     MODE_NORMAL = 0
     MODE_GET_STREAM = 1
@@ -36,11 +82,25 @@  class AsyncClient(bb.asyncrpc.AsyncClient):
             if become:
                 await self.become_user(become)
 
-    async def send_stream(self, mode, msg):
+    async def send_stream_batch(self, mode, msgs):
+        """
+        Does a "batch" process of stream messages. This sends the query
+        messages as fast as possible, and simultaneously attempts to read the
+        messages back. This helps to mitigate the effects of latency to the
+        hash equivalence server by allowing multiple queries to be "in-flight"
+        at once
+
+        The implementation does more complicated tracking using a count of sent
+        messages so that `msgs` can be a generator function (i.e. its length is
+        unknown)
+
+        """
+
         async def proc():
             await self._set_mode(mode)
-            await self.socket.send(msg)
-            return await self.socket.recv()
+
+            b = Batch(self.socket)
+            return await b.process(msgs)
 
         return await self._send_wrapper(proc)
 
@@ -89,10 +149,18 @@  class AsyncClient(bb.asyncrpc.AsyncClient):
         self.mode = new_mode
 
     async def get_unihash(self, method, taskhash):
-        r = await self.send_stream(self.MODE_GET_STREAM, "%s %s" % (method, taskhash))
-        if not r:
-            return None
-        return r
+        r = await self.get_unihash_batch([(method, taskhash)])
+        return r[0]
+
+    async def get_unihash_batch(self, args):
+        result = await self.send_stream_batch(
+            self.MODE_GET_STREAM,
+            (f"{method} {taskhash}" for method, taskhash in args),
+        )
+
+        self.logger.warning(f"sent batch of {len(result)}")
+
+        return [r if r else None for r in result]
 
     async def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
         m = extra.copy()
@@ -115,8 +183,12 @@  class AsyncClient(bb.asyncrpc.AsyncClient):
         )
 
     async def unihash_exists(self, unihash):
-        r = await self.send_stream(self.MODE_EXIST_STREAM, unihash)
-        return r == "true"
+        r = await self.unihash_exists_batch([unihash])
+        return r[0]
+
+    async def unihash_exists_batch(self, unihashes):
+        result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes)
+        return [r == "true" for r in result]
 
     async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
         return await self.invoke(
@@ -237,10 +309,12 @@  class Client(bb.asyncrpc.Client):
             "connect_tcp",
             "connect_websocket",
             "get_unihash",
+            "get_unihash_batch",
             "report_unihash",
             "report_unihash_equiv",
             "get_taskhash",
             "unihash_exists",
+            "unihash_exists_batch",
             "get_outhash",
             "get_stats",
             "reset_stats",
diff --git a/bitbake/lib/hashserv/tests.py b/bitbake/lib/hashserv/tests.py
index 0809453cf87..5349cd58677 100644
--- a/bitbake/lib/hashserv/tests.py
+++ b/bitbake/lib/hashserv/tests.py
@@ -594,6 +594,43 @@  class HashEquivalenceCommonTests(object):
                 7: None,
             })
 
+    def test_get_unihash_batch(self):
+        TEST_INPUT = (
+            # taskhash                                   outhash                                                            unihash
+            ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'),
+            # Duplicated taskhash with multiple output hashes and unihashes.
+            ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'),
+            # Equivalent hash
+            ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"),
+            ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"),
+            ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'),
+            ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'),
+            ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'),
+        )
+        EXTRA_QUERIES = (
+            "6b6be7a84ab179b4240c4302518dc3f6",
+        )
+
+        for taskhash, outhash, unihash in TEST_INPUT:
+            self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+
+
+        result = self.client.get_unihash_batch(
+            [(self.METHOD, data[0]) for data in TEST_INPUT] +
+            [(self.METHOD, e) for e in EXTRA_QUERIES]
+        )
+
+        self.assertListEqual(result, [
+            "218e57509998197d570e2c98512d0105985dffc9",
+            "218e57509998197d570e2c98512d0105985dffc9",
+            "218e57509998197d570e2c98512d0105985dffc9",
+            "3b5d3d83f07f259e9086fcb422c855286e18a57d",
+            "f46d3fbb439bd9b921095da657a4de906510d2cd",
+            "f46d3fbb439bd9b921095da657a4de906510d2cd",
+            "05d2a63c81e32f0a36542ca677e8ad852365c538",
+            None,
+        ])
+
     def test_client_pool_unihash_exists(self):
         TEST_INPUT = (
             # taskhash                                   outhash                                                            unihash
@@ -636,6 +673,44 @@  class HashEquivalenceCommonTests(object):
             result = client_pool.unihashes_exist(query)
             self.assertDictEqual(result, expected)
 
+    def test_unihash_exists_batch(self):
+        TEST_INPUT = (
+            # taskhash                                   outhash                                                            unihash
+            ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e','218e57509998197d570e2c98512d0105985dffc9'),
+            # Duplicated taskhash with multiple output hashes and unihashes.
+            ('8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a', '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'),
+            # Equivalent hash
+            ("044c2ec8aaf480685a00ff6ff49e6162e6ad34e1", '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d', "def64766090d28f627e816454ed46894bb3aab36"),
+            ("e3da00593d6a7fb435c7e2114976c59c5fd6d561", "1cf8713e645f491eb9c959d20b5cae1c47133a292626dda9b10709857cbe688a", "3b5d3d83f07f259e9086fcb422c855286e18a57d"),
+            ('35788efcb8dfb0a02659d81cf2bfd695fb30faf9', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2cd'),
+            ('35788efcb8dfb0a02659d81cf2bfd695fb30fafa', '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f', 'f46d3fbb439bd9b921095da657a4de906510d2ce'),
+            ('9d81d76242cc7cfaf7bf74b94b9cd2e29324ed74', '8470d56547eea6236d7c81a644ce74670ca0bbda998e13c629ef6bb3f0d60b69', '05d2a63c81e32f0a36542ca677e8ad852365c538'),
+        )
+        EXTRA_QUERIES = (
+            "6b6be7a84ab179b4240c4302518dc3f6",
+        )
+
+        result_unihashes = set()
+
+
+        for taskhash, outhash, unihash in TEST_INPUT:
+            result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+            result_unihashes.add(result["unihash"])
+
+        query = []
+        expected = []
+
+        for _, _, unihash in TEST_INPUT:
+            query.append(unihash)
+            expected.append(unihash in result_unihashes)
+
+
+        for unihash in EXTRA_QUERIES:
+            query.append(unihash)
+            expected.append(False)
+
+        result = self.client.unihash_exists_batch(query)
+        self.assertListEqual(result, expected)
 
     def test_auth_read_perms(self):
         admin_client = self.start_auth_server()