Thanks to visit codestin.com
Credit goes to github.com

Skip to content

async: replace urllib3 with httpx in THttpClient #261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions src/databricks/sql/auth/async_thrift_http_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import base64
import logging
import urllib.parse
from typing import Dict, Union

import six
import thrift

logger = logging.getLogger(__name__)

import ssl
import warnings
from http.client import HTTPResponse
from io import BytesIO

import httpx

from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy


class THttpClient(thrift.transport.THttpClient.THttpClient):
def __init__(
self,
auth_provider,
uri_or_host,
port=None,
path=None,
cafile=None,
cert_file=None,
key_file=None,
ssl_context=None,
max_connections: int = 1,
retry_policy: Union[DatabricksRetryPolicy, int] = 0,
):
if port is not None:
warnings.warn(
"Please use the THttpClient('http{s}://host:port/path') constructor",
DeprecationWarning,
stacklevel=2,
)
self.host = uri_or_host
self.port = port
assert path
self.path = path
self.scheme = "http"
else:
parsed = urllib.parse.urlsplit(uri_or_host)
self.scheme = parsed.scheme
assert self.scheme in ("http", "https")
if self.scheme == "https":
self.certfile = cert_file
self.keyfile = key_file
self.context = (
ssl.create_default_context(cafile=cafile)
if (cafile and not ssl_context)
else ssl_context
)
self.port = parsed.port
self.host = parsed.hostname
self.path = parsed.path
if parsed.query:
self.path += "?%s" % parsed.query
try:
proxy = urllib.request.getproxies()[self.scheme]
except KeyError:
proxy = None
else:
if urllib.request.proxy_bypass(self.host):
proxy = None
if proxy:
parsed = urllib.parse.urlparse(proxy)

# realhost and realport are the host and port of the actual request
self.realhost = self.host
self.realport = self.port

# this is passed to ProxyManager
self.proxy_uri: str = proxy
self.host = parsed.hostname
self.port = parsed.port
self.proxy_auth = self.basic_proxy_auth_header(parsed)
else:
self.realhost = self.realport = self.proxy_auth = None

self.max_connections = max_connections

# If retry_policy == 0 then urllib3 will not retry automatically
# this falls back to the pre-v3 behaviour where thrift_backend.py handles retry logic
self.retry_policy = retry_policy

self.__wbuf = BytesIO()
self.__rbuf = None
self.__resp: Union[None, httpx.Response] = None
self.__timeout = None
self.__custom_headers = None

self.__auth_provider = auth_provider

def setCustomHeaders(self, headers: Dict[str, str]):
self._headers = headers
super().setCustomHeaders(headers)

def startRetryTimer(self):
"""Notify DatabricksRetryPolicy of the request start time

This is used to enforce the retry_stop_after_attempts_duration
"""
self.retry_policy and self.retry_policy.start_retry_timer()

def open(self):
# pretty sure we need this to not default to 1 once I enable async because each request uses one connection
limits = httpx.Limits(max_connections=self.max_connections)

# httpx automatically handles http(s)
# TODO: implement proxy handling (deferred as we don't have any e2e tests for it)
# TODO: rename self._pool once POC works
self.__pool = httpx.Client(limits=limits)

def close(self):
self.__resp and self.__resp.close()
# must clear out buffer because thrift leaves stray bytes behind
# equivalent to urllib3's .drain_conn() method
self.__rbuf = None
self.__resp = None

def read(self, sz):
return self.__rbuf.read(sz)

def isOpen(self):
return self.__resp is not None

def flush(self):
# Pull data out of buffer that will be sent in this request
data = self.__wbuf.getvalue()
self.__wbuf = BytesIO()

# Header handling

headers = dict(self._headers)
self.__auth_provider.add_headers(headers)
self._headers = headers
self.setCustomHeaders(self._headers)

# Note: we don't set User-Agent explicitly in this class because PySQL
# should always provide one. Unlike the original THttpClient class, our version
# doesn't define a default User-Agent and so should raise an exception if one
# isn't provided.
assert self.__custom_headers and "User-Agent" in self.__custom_headers

headers = {
"Content-Type": "application/x-thrift",
"Content-Length": str(len(data)),
}

if self.__custom_headers:
custom_headers = {key: val for key, val in self.__custom_headers.items()}
headers.update(**custom_headers)

target_url_parts = (self.scheme, self.host, self.path, "", "")
target_url = urllib.parse.urlunsplit(target_url_parts)
# HTTP request
self.__resp = self.__pool.request(
"POST",
url=target_url,
content=data,
headers=headers,
timeout=self.__timeout,
)

self.__rbuf = BytesIO(self.__resp.read())

# Get reply to flush the request
self.code = self.__resp.status_code
self.message = self.__resp.reason_phrase
self.headers = self.__resp.headers

@staticmethod
def basic_proxy_auth_header(proxy):
if proxy is None or not proxy.username:
return None
ap = "%s:%s" % (
urllib.parse.unquote(proxy.username),
urllib.parse.unquote(proxy.password),
)
cr = base64.b64encode(ap.encode()).strip()
return "Basic " + six.ensure_str(cr)

def set_retry_command_type(self, value: CommandType):
"""Pass the provided CommandType to the retry policy"""
if isinstance(self.retry_policy, DatabricksRetryPolicy):
self.retry_policy.command_type = value
else:
logger.warning(
"DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set."
)
3 changes: 2 additions & 1 deletion src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import urllib3.exceptions

import databricks.sql.auth.thrift_http_client
import databricks.sql.auth.async_thrift_http_client
from databricks.sql.auth.thrift_http_client import CommandType
from databricks.sql.auth.authenticators import AuthProvider
from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes
Expand Down Expand Up @@ -212,7 +213,7 @@ def __init__(

additional_transport_args["retry_policy"] = self.retry_policy

self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
self._transport = databricks.sql.auth.async_thrift_http_client.THttpClient(
auth_provider=self._auth_provider,
uri_or_host=uri,
ssl_context=ssl_context,
Expand Down