@@ -43,6 +43,10 @@ def check_siggen_version(siggen):
if siggen.find_siginfo_version < siggen.find_siginfo_minversion:
bb.fatal("Siggen from metadata (OE-Core?) is too old, please update it (%s vs %s)" % (siggen.find_siginfo_version, siggen.find_siginfo_minversion))
+def check_hashserv_unihash(unihash):
+ if not hashserv.is_valid_unihash(unihash):
+ bb.fatal("Hash Equivalence Server returned invalid unihash")
+
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set) or isinstance(obj, frozenset):
@@ -729,6 +733,7 @@ class SignatureGeneratorUniHashMixIn(object):
if unihashes and unihashes[idx]:
unihash = unihashes[idx]
+ check_hashserv_unihash(unihash)
# A unique hash equal to the taskhash is not very interesting,
# so it is reported it at debug level 2. If they differ, that
# is much more interesting, so it is reported at debug level 1
@@ -747,7 +752,7 @@ class SignatureGeneratorUniHashMixIn(object):
import importlib
taskhash = d.getVar('BB_TASKHASH')
- unihash = d.getVar('BB_UNIHASH')
+ unihash = d.getVar('BB_UNIHASH', expand=False)
report_taskdata = d.getVar('SSTATE_HASHEQUIV_REPORT_TASKDATA') == '1'
tempdir = d.getVar('T')
mcfn = d.getVar('BB_FILENAME')
@@ -809,6 +814,7 @@ class SignatureGeneratorUniHashMixIn(object):
data = client.report_unihash(taskhash, method, outhash, unihash, extra_data)
new_unihash = data['unihash']
+ check_hashserv_unihash(new_unihash)
if new_unihash != unihash:
hashequiv_logger.debug('Task %s unihash changed %s -> %s by server %s' % (taskhash, unihash, new_unihash, self.server))
@@ -848,6 +854,7 @@ class SignatureGeneratorUniHashMixIn(object):
return False
finalunihash = data['unihash']
+ check_hashserv_unihash(finalunihash)
if finalunihash == current_unihash:
hashequiv_logger.verbose('Task %s unihash %s unchanged by server' % (tid, finalunihash))
@@ -9,7 +9,9 @@
import unittest
import logging
import bb
+import bb.data
import time
+from contextlib import contextmanager
logger = logging.getLogger('BitBake.TestSiggen')
@@ -26,3 +28,48 @@ class SiggenTest(unittest.TestCase):
for t in tests:
self.assertEqual(bb.siggen.build_pnid(*t), tests[t])
+ def test_get_unihashes_rejects_invalid_hashserv_unihash(self):
+ class TestClient:
+ def get_unihash_batch(self, query):
+ list(query)
+ return ["${@os.system('true')}"]
+
+ class TestSiggen(bb.siggen.SignatureGeneratorUniHashMixIn):
+ def __init__(self):
+ self.server = "test-server"
+ self.method = "test-method"
+ self.extramethod = {}
+ self.taskhash = {"test.bb:do_compile": "a" * 64}
+ self.unihash = {}
+ self.unitaskhashes = {}
+ self.tidtopn = {}
+ self.setscenetasks = set()
+
+ @contextmanager
+ def client(self):
+ yield TestClient()
+
+ siggen = TestSiggen()
+
+ with self.assertRaises(bb.BBHandledException):
+ siggen.get_unihashes(["test.bb:do_compile"])
+
+ self.assertEqual(siggen.unihash, {})
+ self.assertEqual(siggen.unitaskhashes, {})
+
+ def test_report_unihash_reads_bb_unihash_without_expansion(self):
+ class TestSiggen(bb.siggen.SignatureGeneratorUniHashMixIn):
+ def __init__(self):
+ self.setscenetasks = set()
+ self.taskhash = {"test.bb:do_compile": "b" * 64}
+
+ d = bb.data.init()
+ d.setVar("BB_TASKHASH", "a" * 64)
+ d.setVar("BB_UNIHASH", "${@d.setVar('EXPANDED_UNIHASH', '1') or 'bad'}")
+ d.setVar("SSTATE_HASHEQUIV_REPORT_TASKDATA", "0")
+ d.setVar("T", "/tmp")
+ d.setVar("BB_FILENAME", "test.bb")
+
+ TestSiggen().report_unihash(".", "compile", d)
+
+ self.assertIsNone(d.getVar("EXPANDED_UNIHASH"))
@@ -7,12 +7,19 @@ import asyncio
from contextlib import closing
import itertools
import json
+import re
from collections import namedtuple
from urllib.parse import urlparse
from bb.asyncrpc.client import parse_address, ADDR_TYPE_UNIX, ADDR_TYPE_WS
User = namedtuple("User", ("username", "permissions"))
+UNIHASH_REGEX = re.compile(r"^[0-9a-f]{64}$")
+
+
+def is_valid_unihash(value):
+ return isinstance(value, str) and UNIHASH_REGEX.fullmatch(value) is not None
+
def create_server(
addr,
@@ -13,6 +13,7 @@ import base64
import json
import hashlib
from . import create_async_client
+from . import is_valid_unihash
import bb.asyncrpc
logger = logging.getLogger("hashserv.server")
@@ -173,6 +174,11 @@ def hash_token(algo, salt, token):
return ":".join([algo, salt, h.hexdigest()])
+def validate_unihash(value):
+ if not is_valid_unihash(value):
+ raise bb.asyncrpc.InvokeError("Invalid unihash")
+
+
def permissions(*permissions, allow_anon=True, allow_self_service=False):
"""
Function decorator that can be used to decorate an RPC function call and
@@ -345,7 +351,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
d = {k: row[k] for k in row.keys()}
elif self.upstream_client is not None:
d = await self.upstream_client.get_taskhash(method, taskhash)
- await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
+ await self.insert_unihash(d["method"], d["taskhash"], d["unihash"])
return d
@@ -377,9 +383,13 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if data is None:
return
- await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+ await self.insert_unihash(data["method"], data["taskhash"], data["unihash"])
await self.db.insert_outhash(data)
+ async def insert_unihash(self, method, taskhash, unihash):
+ validate_unihash(unihash)
+ return await self.db.insert_unihash(method, taskhash, unihash)
+
async def _stream_handler(self, handler):
await self.socket.send_message("ok")
@@ -467,6 +477,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
# report is made inside the function
@permissions(READ_PERM)
async def handle_report(self, data):
+ validate_unihash(data.get("unihash"))
+
if self.server.read_only or not self.user_has_permissions(REPORT_PERM):
return await self.report_readonly(data)
@@ -509,7 +521,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if upstream_data is not None:
unihash = upstream_data["unihash"]
- await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
+ await self.insert_unihash(data["method"], data["taskhash"], unihash)
unihash_data = await self.get_unihash(data["method"], data["taskhash"])
if unihash_data is not None:
@@ -525,7 +537,9 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
@permissions(READ_PERM, REPORT_PERM)
async def handle_equivreport(self, data):
- await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+ validate_unihash(data.get("unihash"))
+
+ await self.insert_unihash(data["method"], data["taskhash"], data["unihash"])
# Fetch the unihash that will be reported for the taskhash. If the
# unihash matches, it means this row was inserted (or the mapping
@@ -888,7 +902,10 @@ class Server(bb.asyncrpc.AsyncServer):
method, taskhash = item
d = await client.get_taskhash(method, taskhash)
if d is not None:
- await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
+ if is_valid_unihash(d.get("unihash")):
+ await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
+ else:
+ self.logger.warning("Upstream server returned invalid unihash")
self.backfill_queue.task_done()
def start(self):
@@ -291,6 +291,36 @@ class HashEquivalenceCommonTests(object):
self.assertEqual(result_outhash['outhash'], outhash)
self.assertEqual(result_outhash['outhash_siginfo'], siginfo)
+ def test_report_rejects_invalid_unihash(self):
+ taskhash = '68a9206490b2321bb033fb3eab013a4ec62c41f9'
+ outhash = 'bf5f2efaf1ca351f3b4c3d079363540ab48f7c58db3d23cfbb069cf4ff1ea8f7'
+ invalid_unihashes = (
+ "${@os.system('true')}",
+ 'a' * 63,
+ 'a' * 65,
+ 'A' * 64,
+ None,
+ )
+
+ for unihash in invalid_unihashes:
+ with self.subTest(unihash=unihash):
+ with self.start_client(self.server_address) as client:
+ with self.assertRaises(InvokeError) as context:
+ client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+
+ self.assertEqual(str(context.exception), "Invalid unihash")
+
+ self.assertClientGetHash(self.client, taskhash, None)
+
+ def test_equivreport_rejects_invalid_unihash(self):
+ taskhash = 'ae6339531895ddf5b67e663e6a374ad8ec71d81c'
+
+ with self.assertRaises(InvokeError) as context:
+ self.client.report_unihash_equiv(taskhash, self.METHOD, "${@os.system('true')}")
+
+ self.assertEqual(str(context.exception), "Invalid unihash")
+ self.assertClientGetHash(self.start_client(self.server_address), taskhash, None)
+
def test_stress(self):
def query_server(failures):
client = Client(self.server_address)