@@ -7,6 +7,7 @@
#
import contextlib
+import http.server
import shutil
import unittest
import unittest.mock
@@ -18,6 +19,7 @@ import os
import signal
import subprocess
import tarfile
+import threading
from bb.fetch2 import URI
import bb
import bb.utils
@@ -1643,6 +1645,41 @@ class FetchCheckStatusTest(FetcherTest):
"ftp://sourceware.org/pub/libffi/libffi-1.20.tar.gz",
]
+ def _start_checkstatus_server(self):
+ class CheckStatusHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
+ def do_HEAD(self):
+ self.server.requests.append((self.path, dict(self.headers)))
+ if self.path == "/a" and self.server.redirect_url:
+ self.send_response(302)
+ self.send_header("Location", self.server.redirect_url)
+ self.end_headers()
+ return
+ self.send_response(200)
+ self.end_headers()
+
+ def log_message(self, format_str, *args):
+ pass
+
+ server = http.server.HTTPServer(("127.0.0.1", 0), CheckStatusHTTPRequestHandler)
+ server.redirect_url = None
+ server.requests = []
+ thread = threading.Thread(target=server.serve_forever, kwargs={"poll_interval": 0.05})
+ thread.daemon = True
+ thread.start()
+
+ def stop_server():
+ server.shutdown()
+ thread.join()
+ server.server_close()
+
+ self.addCleanup(stop_server)
+ return server
+
+ def _checkstatus(self, url):
+ fetch = bb.fetch2.Fetch([url], self.d)
+ ud = fetch.ud[url]
+ return ud.method.checkstatus(fetch, ud, self.d)
+
@skipIfNoNetwork()
def test_wget_checkstatus(self):
fetch = bb.fetch2.Fetch(self.test_wget_uris, self.d)
@@ -1670,6 +1707,31 @@ class FetchCheckStatusTest(FetcherTest):
connection_cache.close_connections()
+ def test_wget_checkstatus_same_origin_redirect_keeps_auth(self):
+ server = self._start_checkstatus_server()
+ server.redirect_url = "http://127.0.0.1:%s/b" % server.server_port
+
+ url = "http://127.0.0.1:%s/a;user=user;pswd=pass" % server.server_port
+ self.assertTrue(self._checkstatus(url))
+
+ self.assertEqual(len(server.requests), 2)
+ redirected_headers = {k.lower(): v for k, v in server.requests[1][1].items()}
+ self.assertIn("authorization", redirected_headers)
+
+ def test_wget_checkstatus_different_origin_redirect_drops_auth(self):
+ origin = self._start_checkstatus_server()
+ target = self._start_checkstatus_server()
+ # Same host but different port is a different origin.
+ origin.redirect_url = "http://127.0.0.1:%s/b" % target.server_port
+
+ url = "http://127.0.0.1:%s/a;user=user;pswd=pass" % origin.server_port
+ self.assertTrue(self._checkstatus(url))
+
+ self.assertEqual(len(origin.requests), 1)
+ self.assertEqual(len(target.requests), 1)
+ redirected_headers = {k.lower(): v for k, v in target.requests[0][1].items()}
+ self.assertNotIn("authorization", redirected_headers)
+
class GitMakeShallowTest(FetcherTest):
def setUp(self):