@@ -747,22 +747,38 @@ def __init__(self, host, port=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
747747 self ._tunnel_port = None
748748 self ._tunnel_headers = {}
749749
750- self ._set_hostport (host , port )
750+ (self .host , self .port ) = self ._get_hostport (host , port )
751+
752+ # This is stored as an instance variable to allow unit
753+ # tests to replace it with a suitable mockup
754+ self ._create_connection = socket .create_connection
751755
752756 def set_tunnel (self , host , port = None , headers = None ):
753- """ Sets up the host and the port for the HTTP CONNECT Tunnelling.
757+ """Set up host and port for HTTP CONNECT tunnelling.
758+
759+ In a connection that uses HTTP CONNECT tunneling, the host passed to the
760+ constructor is used as a proxy server that relays all communication to
761+ the endpoint passed to `set_tunnel`. This done by sending an HTTP
762+ CONNECT request to the proxy server when the connection is established.
754763
755- The headers argument should be a mapping of extra HTTP headers
756- to send with the CONNECT request.
764+ This method must be called before the HTML connection has been
765+ established.
766+
767+ The headers argument should be a mapping of extra HTTP headers to send
768+ with the CONNECT request.
757769 """
770+
771+ if self .sock :
772+ raise RuntimeError ("Can't set up tunnel for established connection" )
773+
758774 self ._tunnel_host = host
759775 self ._tunnel_port = port
760776 if headers :
761777 self ._tunnel_headers = headers
762778 else :
763779 self ._tunnel_headers .clear ()
764780
765- def _set_hostport (self , host , port ):
781+ def _get_hostport (self , host , port ):
766782 if port is None :
767783 i = host .rfind (':' )
768784 j = host .rfind (']' ) # ipv6 addresses have [...]
@@ -779,15 +795,16 @@ def _set_hostport(self, host, port):
779795 port = self .default_port
780796 if host and host [0 ] == '[' and host [- 1 ] == ']' :
781797 host = host [1 :- 1 ]
782- self . host = host
783- self . port = port
798+
799+ return ( host , port )
784800
785801 def set_debuglevel (self , level ):
786802 self .debuglevel = level
787803
788804 def _tunnel (self ):
789- self ._set_hostport (self ._tunnel_host , self ._tunnel_port )
790- connect_str = "CONNECT %s:%d HTTP/1.0\r \n " % (self .host , self .port )
805+ (host , port ) = self ._get_hostport (self ._tunnel_host ,
806+ self ._tunnel_port )
807+ connect_str = "CONNECT %s:%d HTTP/1.0\r \n " % (host , port )
791808 connect_bytes = connect_str .encode ("ascii" )
792809 self .send (connect_bytes )
793810 for header , value in self ._tunnel_headers .items ():
@@ -815,8 +832,9 @@ def _tunnel(self):
815832
816833 def connect (self ):
817834 """Connect to the host and port specified in __init__."""
818- self .sock = socket .create_connection ((self .host ,self .port ),
819- self .timeout , self .source_address )
835+ self .sock = self ._create_connection ((self .host ,self .port ),
836+ self .timeout , self .source_address )
837+
820838 if self ._tunnel_host :
821839 self ._tunnel ()
822840
@@ -985,22 +1003,29 @@ def putrequest(self, method, url, skip_host=0, skip_accept_encoding=0):
9851003 netloc_enc = netloc .encode ("idna" )
9861004 self .putheader ('Host' , netloc_enc )
9871005 else :
1006+ if self ._tunnel_host :
1007+ host = self ._tunnel_host
1008+ port = self ._tunnel_port
1009+ else :
1010+ host = self .host
1011+ port = self .port
1012+
9881013 try :
989- host_enc = self . host .encode ("ascii" )
1014+ host_enc = host .encode ("ascii" )
9901015 except UnicodeEncodeError :
991- host_enc = self . host .encode ("idna" )
1016+ host_enc = host .encode ("idna" )
9921017
9931018 # As per RFC 273, IPv6 address should be wrapped with []
9941019 # when used as Host header
9951020
996- if self . host .find (':' ) >= 0 :
1021+ if host .find (':' ) >= 0 :
9971022 host_enc = b'[' + host_enc + b']'
9981023
999- if self . port == self .default_port :
1024+ if port == self .default_port :
10001025 self .putheader ('Host' , host_enc )
10011026 else :
10021027 host_enc = host_enc .decode ("ascii" )
1003- self .putheader ('Host' , "%s:%s" % (host_enc , self . port ))
1028+ self .putheader ('Host' , "%s:%s" % (host_enc , port ))
10041029
10051030 # note: we are assuming that clients will not attempt to set these
10061031 # headers since *this* library must deal with the
@@ -1193,19 +1218,19 @@ def __init__(self, host, port=None, key_file=None, cert_file=None,
11931218 def connect (self ):
11941219 "Connect to a host on a given (SSL) port."
11951220
1196- sock = socket .create_connection ((self .host , self .port ),
1197- self .timeout , self .source_address )
1221+ super ().connect ()
11981222
11991223 if self ._tunnel_host :
1200- self .sock = sock
1201- self ._tunnel ()
1224+ server_hostname = self ._tunnel_host
1225+ else :
1226+ server_hostname = self .host
1227+ sni_hostname = server_hostname if ssl .HAS_SNI else None
12021228
1203- server_hostname = self .host if ssl .HAS_SNI else None
1204- self .sock = self ._context .wrap_socket (sock ,
1205- server_hostname = server_hostname )
1229+ self .sock = self ._context .wrap_socket (self .sock ,
1230+ server_hostname = sni_hostname )
12061231 if not self ._context .check_hostname and self ._check_hostname :
12071232 try :
1208- ssl .match_hostname (self .sock .getpeercert (), self . host )
1233+ ssl .match_hostname (self .sock .getpeercert (), server_hostname )
12091234 except Exception :
12101235 self .sock .shutdown (socket .SHUT_RDWR )
12111236 self .sock .close ()
0 commit comments