diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 7ed35e54a..a924ea63c 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -1,8 +1,11 @@ import logging from typing import Dict + import thrift +import urllib.parse, six, base64 + logger = logging.getLogger(__name__) @@ -33,3 +36,14 @@ def flush(self): self._headers = headers self.setCustomHeaders(self._headers) super().flush() + + @staticmethod + def basic_proxy_auth_header(proxy): + if proxy is None or not proxy.username: + return None + ap = "%s:%s" % ( + urllib.parse.unquote(proxy.username), + urllib.parse.unquote(proxy.password), + ) + cr = base64.b64encode(ap.encode()).strip() + return "Basic " + six.ensure_str(cr) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 78181ef8e..1c2e589bf 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -143,6 +143,21 @@ def test_headers_are_set(self, t_http_client_class): ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=AuthProvider()) t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"}) + def test_proxy_headers_are_set(self): + + from databricks.sql.auth.thrift_http_client import THttpClient + from urllib.parse import urlparse + + fake_proxy_spec = "https://someuser:somepassword@8.8.8.8:12340" + parsed_proxy = urlparse(fake_proxy_spec) + + try: + result = THttpClient.basic_proxy_auth_header(parsed_proxy) + except TypeError as e: + assert False + + assert isinstance(result, type(str())) + @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.thrift_backend.create_default_context") def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_client_class):