From dd7f1f6aead38fa91d9b719279587bb9a3b5b323 Mon Sep 17 00:00:00 2001
From: Jesse Whitehouse
Date: Wed, 25 Oct 2023 12:06:05 -0400
Subject: [PATCH 1/3] =?UTF-8?q?Clone=20thrift=5Fhttp=5Fclient.py=20?=
 =?UTF-8?q?=E2=86=92=20async=5Fthrift=5Fhttp=5Fclient.py?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Jesse Whitehouse
---
 .../sql/auth/async_thrift_http_client.py     | 214 ++++++++++++++++++
 1 file changed, 214 insertions(+)
 create mode 100644 src/databricks/sql/auth/async_thrift_http_client.py

diff --git a/src/databricks/sql/auth/async_thrift_http_client.py b/src/databricks/sql/auth/async_thrift_http_client.py
new file mode 100644
index 000000000..11589258f
--- /dev/null
+++ b/src/databricks/sql/auth/async_thrift_http_client.py
@@ -0,0 +1,214 @@
+import base64
+import logging
+import urllib.parse
+from typing import Dict, Union
+
+import six
+import thrift
+
+logger = logging.getLogger(__name__)
+
+import ssl
+import warnings
+from http.client import HTTPResponse
+from io import BytesIO
+
+from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
+
+from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
+
+
+class THttpClient(thrift.transport.THttpClient.THttpClient):
+    def __init__(
+        self,
+        auth_provider,
+        uri_or_host,
+        port=None,
+        path=None,
+        cafile=None,
+        cert_file=None,
+        key_file=None,
+        ssl_context=None,
+        max_connections: int = 1,
+        retry_policy: Union[DatabricksRetryPolicy, int] = 0,
+    ):
+        if port is not None:
+            warnings.warn(
+                "Please use the THttpClient('http{s}://host:port/path') constructor",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            self.host = uri_or_host
+            self.port = port
+            assert path
+            self.path = path
+            self.scheme = "http"
+        else:
+            parsed = urllib.parse.urlsplit(uri_or_host)
+            self.scheme = parsed.scheme
+            assert self.scheme in ("http", "https")
+            if self.scheme == "https":
+                self.certfile = cert_file
+                self.keyfile = key_file
+                self.context = (
+                    ssl.create_default_context(cafile=cafile)
+                    if (cafile and not ssl_context)
+                    else ssl_context
+                )
+            self.port = parsed.port
+            self.host = parsed.hostname
+            self.path = parsed.path
+            if parsed.query:
+                self.path += "?%s" % parsed.query
+        try:
+            proxy = urllib.request.getproxies()[self.scheme]
+        except KeyError:
+            proxy = None
+        else:
+            if urllib.request.proxy_bypass(self.host):
+                proxy = None
+        if proxy:
+            parsed = urllib.parse.urlparse(proxy)
+
+            # realhost and realport are the host and port of the actual request
+            self.realhost = self.host
+            self.realport = self.port
+
+            # this is passed to ProxyManager
+            self.proxy_uri: str = proxy
+            self.host = parsed.hostname
+            self.port = parsed.port
+            self.proxy_auth = self.basic_proxy_auth_header(parsed)
+        else:
+            self.realhost = self.realport = self.proxy_auth = None
+
+        self.max_connections = max_connections
+
+        # If retry_policy == 0 then urllib3 will not retry automatically
+        # this falls back to the pre-v3 behaviour where thrift_backend.py handles retry logic
+        self.retry_policy = retry_policy
+
+        self.__wbuf = BytesIO()
+        self.__resp: Union[None, HTTPResponse] = None
+        self.__timeout = None
+        self.__custom_headers = None
+
+        self.__auth_provider = auth_provider
+
+    def setCustomHeaders(self, headers: Dict[str, str]):
+        self._headers = headers
+        super().setCustomHeaders(headers)
+
+    def startRetryTimer(self):
+        """Notify DatabricksRetryPolicy of the request start time
+
+        This is used to enforce the retry_stop_after_attempts_duration
+        """
+        self.retry_policy and self.retry_policy.start_retry_timer()
+
+    def open(self):
+
+        # self.__pool replaces the self.__http used by the original THttpClient
+        if self.scheme == "http":
+            pool_class = HTTPConnectionPool
+        elif self.scheme == "https":
+            pool_class = HTTPSConnectionPool
+
+        _pool_kwargs = {"maxsize": self.max_connections}
+
+        if self.using_proxy():
+            proxy_manager = ProxyManager(
+                self.proxy_uri,
+                num_pools=1,
+                headers={"Proxy-Authorization": self.proxy_auth},
+            )
+            self.__pool = proxy_manager.connection_from_host(
+                host=self.realhost,
+                port=self.realport,
+                scheme=self.scheme,
+                pool_kwargs=_pool_kwargs,
+            )
+        else:
+            self.__pool = pool_class(self.host, self.port, **_pool_kwargs)
+
+    def close(self):
+        self.__resp and self.__resp.drain_conn()
+        self.__resp and self.__resp.release_conn()
+        self.__resp = None
+
+    def read(self, sz):
+        return self.__resp.read(sz)
+
+    def isOpen(self):
+        return self.__resp is not None
+
+    def flush(self):
+
+        # Pull data out of buffer that will be sent in this request
+        data = self.__wbuf.getvalue()
+        self.__wbuf = BytesIO()
+
+        # Header handling
+
+        headers = dict(self._headers)
+        self.__auth_provider.add_headers(headers)
+        self._headers = headers
+        self.setCustomHeaders(self._headers)
+
+        # Note: we don't set User-Agent explicitly in this class because PySQL
+        # should always provide one. Unlike the original THttpClient class, our version
+        # doesn't define a default User-Agent and so should raise an exception if one
+        # isn't provided.
+        assert self.__custom_headers and "User-Agent" in self.__custom_headers
+
+        headers = {
+            "Content-Type": "application/x-thrift",
+            "Content-Length": str(len(data)),
+        }
+
+        if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None:
+            headers["Proxy-Authorization" : self.proxy_auth]
+
+        if self.__custom_headers:
+            custom_headers = {key: val for key, val in self.__custom_headers.items()}
+            headers.update(**custom_headers)
+
+        # HTTP request
+        self.__resp = self.__pool.request(
+            "POST",
+            url=self.path,
+            body=data,
+            headers=headers,
+            preload_content=False,
+            timeout=self.__timeout,
+            retries=self.retry_policy,
+        )
+
+        # Get reply to flush the request
+        self.code = self.__resp.status
+        self.message = self.__resp.reason
+        self.headers = self.__resp.headers
+
+        # Saves the cookie sent by the server response
+        if "Set-Cookie" in self.headers:
+            self.setCustomHeaders(dict("Cookie", self.headers["Set-Cookie"]))
+
+    @staticmethod
+    def basic_proxy_auth_header(proxy):
+        if proxy is None or not proxy.username:
+            return None
+        ap = "%s:%s" % (
+            urllib.parse.unquote(proxy.username),
+            urllib.parse.unquote(proxy.password),
+        )
+        cr = base64.b64encode(ap.encode()).strip()
+        return "Basic " + six.ensure_str(cr)
+
+    def set_retry_command_type(self, value: CommandType):
+        """Pass the provided CommandType to the retry policy"""
+        if isinstance(self.retry_policy, DatabricksRetryPolicy):
+            self.retry_policy.command_type = value
+        else:
+            logger.warning(
+                "DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set."
+            )
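For context, the cloned transport above is driven the same way thrift_backend.py drives the original thrift_http_client.THttpClient. The sketch below is illustrative only and is not part of the patch: the no-op auth provider stands in for databricks.sql.auth.authenticators.AuthProvider, and the workspace URL is a placeholder.

import thrift.protocol.TBinaryProtocol

from databricks.sql.auth.async_thrift_http_client import THttpClient


class NoopAuthProvider:
    """Stand-in for a real AuthProvider; a real one injects an Authorization header."""

    def add_headers(self, headers):
        pass


transport = THttpClient(
    auth_provider=NoopAuthProvider(),
    uri_or_host="https://example.cloud.databricks.com/sql/placeholder-http-path",
    retry_policy=0,  # 0 = no urllib3-level retries; thrift_backend.py retries instead
)
# flush() asserts that the caller always supplies a User-Agent header.
transport.setCustomHeaders({"User-Agent": "ExamplePySQLCaller/0.0"})
protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(transport)
transport.open()

Passing a DatabricksRetryPolicy instance instead of 0 hands retry handling to urllib3 through the retries= argument used in flush().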
From 59248c0f759125dd614e925a18daf9178bcc7a43 Mon Sep 17 00:00:00 2001
From: Jesse Whitehouse
Date: Wed, 25 Oct 2023 12:06:44 -0400
Subject: [PATCH 2/3] import httpx

Signed-off-by: Jesse Whitehouse
---
 src/databricks/sql/auth/async_thrift_http_client.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/databricks/sql/auth/async_thrift_http_client.py b/src/databricks/sql/auth/async_thrift_http_client.py
index 11589258f..4af84d0c7 100644
--- a/src/databricks/sql/auth/async_thrift_http_client.py
+++ b/src/databricks/sql/auth/async_thrift_http_client.py
@@ -13,6 +13,7 @@
 from http.client import HTTPResponse
 from io import BytesIO
 
+import httpx
 from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
 
 from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
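The bare import above stages the next commit. For reference, the httpx calls that commit leans on behave roughly as sketched below; the endpoint URL and payload bytes are placeholders, and the snippet is not part of the patch series.

import httpx

# Connection limits replace urllib3's maxsize; a Client owns its own pool.
limits = httpx.Limits(max_connections=1)
with httpx.Client(limits=limits) as client:
    resp = client.request(
        "POST",
        "https://example.invalid/cliservice",  # placeholder endpoint
        content=b"",  # raw Thrift payload bytes would go here
        headers={"Content-Type": "application/x-thrift"},
        timeout=None,  # None disables httpx timeouts, mirroring self.__timeout
    )
    print(resp.status_code, resp.reason_phrase)
    body = resp.read()  # bytes; the next patch wraps this in BytesIO for thrift to read()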
From 5e632d600080d934f368c994fc343ce6a9e0a4c6 Mon Sep 17 00:00:00 2001
From: Jesse Whitehouse
Date: Wed, 25 Oct 2023 16:22:45 -0400
Subject: [PATCH 3/3] Replace urllib3 with httpx in a synchronous fashion

Signed-off-by: Jesse Whitehouse
---
 .../sql/auth/async_thrift_http_client.py | 64 +++++++------------
 src/databricks/sql/thrift_backend.py     |  3 +-
 2 files changed, 24 insertions(+), 43 deletions(-)

diff --git a/src/databricks/sql/auth/async_thrift_http_client.py b/src/databricks/sql/auth/async_thrift_http_client.py
index 4af84d0c7..940417d7d 100644
--- a/src/databricks/sql/auth/async_thrift_http_client.py
+++ b/src/databricks/sql/auth/async_thrift_http_client.py
@@ -13,8 +13,7 @@
 from http.client import HTTPResponse
 from io import BytesIO
 
-import httpx
-from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
+import httpx
 
 from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
 
@@ -90,7 +89,8 @@ def __init__(
         self.retry_policy = retry_policy
 
         self.__wbuf = BytesIO()
-        self.__resp: Union[None, HTTPResponse] = None
+        self.__rbuf = None
+        self.__resp: Union[None, httpx.Response] = None
         self.__timeout = None
         self.__custom_headers = None
 
@@ -108,43 +108,28 @@ def startRetryTimer(self):
         self.retry_policy and self.retry_policy.start_retry_timer()
 
     def open(self):
+        # pretty sure we need this to not default to 1 once I enable async because each request uses one connection
+        limits = httpx.Limits(max_connections=self.max_connections)
 
-        # self.__pool replaces the self.__http used by the original THttpClient
-        if self.scheme == "http":
-            pool_class = HTTPConnectionPool
-        elif self.scheme == "https":
-            pool_class = HTTPSConnectionPool
-
-        _pool_kwargs = {"maxsize": self.max_connections}
-
-        if self.using_proxy():
-            proxy_manager = ProxyManager(
-                self.proxy_uri,
-                num_pools=1,
-                headers={"Proxy-Authorization": self.proxy_auth},
-            )
-            self.__pool = proxy_manager.connection_from_host(
-                host=self.realhost,
-                port=self.realport,
-                scheme=self.scheme,
-                pool_kwargs=_pool_kwargs,
-            )
-        else:
-            self.__pool = pool_class(self.host, self.port, **_pool_kwargs)
+        # httpx automatically handles http(s)
+        # TODO: implement proxy handling (deferred as we don't have any e2e tests for it)
+        # TODO: rename self._pool once POC works
+        self.__pool = httpx.Client(limits=limits)
 
     def close(self):
-        self.__resp and self.__resp.drain_conn()
-        self.__resp and self.__resp.release_conn()
+        self.__resp and self.__resp.close()
+        # must clear out buffer because thrift leaves stray bytes behind
+        # equivalent to urllib3's .drain_conn() method
+        self.__rbuf = None
        self.__resp = None
 
     def read(self, sz):
-        return self.__resp.read(sz)
+        return self.__rbuf.read(sz)
 
     def isOpen(self):
         return self.__resp is not None
 
     def flush(self):
-
         # Pull data out of buffer that will be sent in this request
         data = self.__wbuf.getvalue()
         self.__wbuf = BytesIO()
@@ -167,33 +152,28 @@ def flush(self):
             "Content-Length": str(len(data)),
         }
 
-        if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None:
-            headers["Proxy-Authorization" : self.proxy_auth]
-
         if self.__custom_headers:
             custom_headers = {key: val for key, val in self.__custom_headers.items()}
             headers.update(**custom_headers)
 
+        target_url_parts = (self.scheme, self.host, self.path, "", "")
+        target_url = urllib.parse.urlunsplit(target_url_parts)
         # HTTP request
         self.__resp = self.__pool.request(
             "POST",
-            url=self.path,
-            body=data,
+            url=target_url,
+            content=data,
             headers=headers,
-            preload_content=False,
             timeout=self.__timeout,
-            retries=self.retry_policy,
         )
 
+        self.__rbuf = BytesIO(self.__resp.read())
+
         # Get reply to flush the request
-        self.code = self.__resp.status
-        self.message = self.__resp.reason
+        self.code = self.__resp.status_code
+        self.message = self.__resp.reason_phrase
         self.headers = self.__resp.headers
 
-        # Saves the cookie sent by the server response
-        if "Set-Cookie" in self.headers:
-            self.setCustomHeaders(dict("Cookie", self.headers["Set-Cookie"]))
-
     @staticmethod
     def basic_proxy_auth_header(proxy):
         if proxy is None or not proxy.username:
diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py
index 6d9707d2f..1ea580c37 100644
--- a/src/databricks/sql/thrift_backend.py
+++ b/src/databricks/sql/thrift_backend.py
@@ -17,6 +17,7 @@
 import urllib3.exceptions
 
 import databricks.sql.auth.thrift_http_client
+import databricks.sql.auth.async_thrift_http_client
 from databricks.sql.auth.thrift_http_client import CommandType
 from databricks.sql.auth.authenticators import AuthProvider
 from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes
@@ -212,7 +213,7 @@ def __init__(
 
         additional_transport_args["retry_policy"] = self.retry_policy
 
-        self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
+        self._transport = databricks.sql.auth.async_thrift_http_client.THttpClient(
             auth_provider=self._auth_provider,
             uri_or_host=uri,
             ssl_context=ssl_context,
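The file name and the max_connections comment point at where this series is headed: swapping httpx.Client for httpx.AsyncClient. The rough sketch below illustrates that direction only; it is not implemented by these patches, and the URL, payload, and helper names are placeholders.

import asyncio
from io import BytesIO

import httpx


async def async_flush(pool: httpx.AsyncClient, url: str, data: bytes, headers: dict) -> BytesIO:
    # Mirrors the synchronous flush(): POST the Thrift payload, then buffer the
    # reply so thrift can read() it without touching the socket again.
    resp = await pool.request("POST", url, content=data, headers=headers)
    return BytesIO(await resp.aread())


async def main():
    # More than one connection matters once requests overlap, since each
    # in-flight request holds a connection from the pool.
    limits = httpx.Limits(max_connections=10)
    async with httpx.AsyncClient(limits=limits) as pool:
        rbuf = await async_flush(
            pool,
            "https://example.invalid/cliservice",  # placeholder
            b"",  # Thrift payload bytes
            {"Content-Type": "application/x-thrift", "User-Agent": "ExamplePySQLCaller/0.0"},
        )
        print(rbuf.getvalue()[:16])


asyncio.run(main())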