diff mbox series

[scarthgap,2.8,3/4] tests/fetch: cover checkstatus redirect auth handling

Message ID 2b0f7fb5f54a415d851038ba7cb836b18289e000.1781271084.git.jeremy.rosen@smile.fr
State RFC, archived
Headers show
Series [scarthgap,2.8,1/4] fetch2/wget: handle HTTP 308 Permanent Redirect | expand

Commit Message

Jeremy Rosen June 12, 2026, 2:29 p.m. UTC
From: Anders Heimer <anders.heimer@est.tech>

Add local HTTP server tests for Wget.checkstatus() redirects. They check
that Authorization is kept for same-origin redirects and dropped when the
target has a different origin.

Signed-off-by: Anders Heimer <anders.heimer@est.tech>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
(cherry picked from commit c687d42b81b17e7a2399099cab0f1a6aafcf6520)
Signed-off-by: Jeremy Rosen <jeremy.rosen@smile.fr>
---
 lib/bb/tests/fetch.py | 62 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
diff mbox series

Patch

diff --git a/lib/bb/tests/fetch.py b/lib/bb/tests/fetch.py
index 2d95ef87d..a658b89a8 100644
--- a/lib/bb/tests/fetch.py
+++ b/lib/bb/tests/fetch.py
@@ -7,6 +7,7 @@ 
 #
 
 import contextlib
+import http.server
 import shutil
 import unittest
 import hashlib
@@ -16,6 +17,7 @@  import os
 import signal
 import subprocess
 import tarfile
+import threading
 from bb.fetch2 import URI
 from bb.fetch2 import FetchMethod
 import bb
@@ -1610,6 +1612,41 @@  class FetchCheckStatusTest(FetcherTest):
                       "https://github.com/kergoth/tslib/releases/download/1.1/tslib-1.1.tar.xz"
                       ]
 
+    def _start_checkstatus_server(self):
+        class CheckStatusHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
+            def do_HEAD(self):
+                self.server.requests.append((self.path, dict(self.headers)))
+                if self.path == "/a" and self.server.redirect_url:
+                    self.send_response(302)
+                    self.send_header("Location", self.server.redirect_url)
+                    self.end_headers()
+                    return
+                self.send_response(200)
+                self.end_headers()
+
+            def log_message(self, format_str, *args):
+                pass
+
+        server = http.server.HTTPServer(("127.0.0.1", 0), CheckStatusHTTPRequestHandler)
+        server.redirect_url = None
+        server.requests = []
+        thread = threading.Thread(target=server.serve_forever, kwargs={"poll_interval": 0.05})
+        thread.daemon = True
+        thread.start()
+
+        def stop_server():
+            server.shutdown()
+            thread.join()
+            server.server_close()
+
+        self.addCleanup(stop_server)
+        return server
+
+    def _checkstatus(self, url):
+        fetch = bb.fetch2.Fetch([url], self.d)
+        ud = fetch.ud[url]
+        return ud.method.checkstatus(fetch, ud, self.d)
+
     @skipIfNoNetwork()
     def test_wget_checkstatus(self):
         fetch = bb.fetch2.Fetch(self.test_wget_uris, self.d)
@@ -1637,6 +1674,31 @@  class FetchCheckStatusTest(FetcherTest):
 
         connection_cache.close_connections()
 
+    def test_wget_checkstatus_same_origin_redirect_keeps_auth(self):
+        server = self._start_checkstatus_server()
+        server.redirect_url = "http://127.0.0.1:%s/b" % server.server_port
+
+        url = "http://127.0.0.1:%s/a;user=user;pswd=pass" % server.server_port
+        self.assertTrue(self._checkstatus(url))
+
+        self.assertEqual(len(server.requests), 2)
+        redirected_headers = {k.lower(): v for k, v in server.requests[1][1].items()}
+        self.assertIn("authorization", redirected_headers)
+
+    def test_wget_checkstatus_different_origin_redirect_drops_auth(self):
+        origin = self._start_checkstatus_server()
+        target = self._start_checkstatus_server()
+        # Same host but different port is a different origin.
+        origin.redirect_url = "http://127.0.0.1:%s/b" % target.server_port
+
+        url = "http://127.0.0.1:%s/a;user=user;pswd=pass" % origin.server_port
+        self.assertTrue(self._checkstatus(url))
+
+        self.assertEqual(len(origin.requests), 1)
+        self.assertEqual(len(target.requests), 1)
+        redirected_headers = {k.lower(): v for k, v in target.requests[0][1].items()}
+        self.assertNotIn("authorization", redirected_headers)
+
 
 class GitMakeShallowTest(FetcherTest):
     def setUp(self):