diff --git a/rethinkdb/gevent_net/net_gevent.py b/rethinkdb/gevent_net/net_gevent.py index a151ba5c..2969922e 100644 --- a/rethinkdb/gevent_net/net_gevent.py +++ b/rethinkdb/gevent_net/net_gevent.py @@ -26,6 +26,7 @@ from rethinkdb import net, ql2_pb2 from rethinkdb.errors import ReqlAuthError, ReqlCursorEmpty, ReqlDriverError, ReqlTimeoutError, RqlDriverError, \ RqlTimeoutError +from rethinkdb.helpers import get_hostname_for_ssl_match from rethinkdb.logger import default_logger __all__ = ['Connection'] @@ -103,7 +104,10 @@ def __init__(self, parent): self._socket.close() raise ReqlDriverError("SSL handshake failed (see server log for more information): %s" % str(exc)) try: - ssl.match_hostname(self._socket.getpeercert(), hostname=self.host) + ssl.match_hostname( + self._socket.getpeercert(), + hostname=get_hostname_for_ssl_match(self.host) + ) except ssl.CertificateError: self._socket.close() raise diff --git a/rethinkdb/helpers.py b/rethinkdb/helpers.py index 4a161286..46152e49 100644 --- a/rethinkdb/helpers.py +++ b/rethinkdb/helpers.py @@ -1,10 +1,22 @@ import six + def decode_utf8(string, encoding='utf-8'): if hasattr(string, 'decode'): return string.decode(encoding) return string + def chain_to_bytes(*strings): return b''.join([six.b(string) if isinstance(string, six.string_types) else string for string in strings]) + + +def get_hostname_for_ssl_match(hostname): + parts = hostname.split('.') + + if len(parts) < 3: + return hostname + + parts[0] = '*' + return '.'.join(parts) diff --git a/rethinkdb/net.py b/rethinkdb/net.py index 5a4c8ddc..155e038d 100644 --- a/rethinkdb/net.py +++ b/rethinkdb/net.py @@ -44,6 +44,7 @@ ReqlTimeoutError, ReqlUserError) from rethinkdb.handshake import HandshakeV1_0 +from rethinkdb.helpers import get_hostname_for_ssl_match from rethinkdb.logger import default_logger __all__ = ['Connection', 'Cursor', 'DEFAULT_PORT', 'DefaultConnection', 'make_connection'] @@ -352,7 +353,10 @@ def __init__(self, parent, timeout): "SSL handshake failed (see server log for more information): %s" % str(err)) try: - match_hostname(self._socket.getpeercert(), hostname=self.host) + ssl.match_hostname( + self._socket.getpeercert(), + hostname=get_hostname_for_ssl_match(self.host) + ) except CertificateError: self._socket.close() raise diff --git a/tests/test_helpers.py b/tests/test_helpers.py index ca868de6..68e5fefb 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,6 +1,6 @@ import pytest from mock import Mock -from rethinkdb.helpers import decode_utf8, chain_to_bytes +from rethinkdb.helpers import decode_utf8, chain_to_bytes, get_hostname_for_ssl_match @pytest.mark.unit class TestDecodeUTF8Helper(object): @@ -42,3 +42,34 @@ def test_mixed_chaining(self): result = chain_to_bytes('iron', ' ', b'man') assert result == expected_string + + +@pytest.mark.unit +class TestSSLMatchHostHostnameHelper(object): + def test_subdomain_replaced_to_star(self): + expected_string = '*.example.com' + + result = get_hostname_for_ssl_match('test.example.com') + + assert result == expected_string + + def test_subdomain_replaced_to_star_special_tld(self): + expected_string = '*.example.co.uk' + + result = get_hostname_for_ssl_match('test.example.co.uk') + + assert result == expected_string + + def test_no_subdomain_to_replace(self): + expected_string = 'example.com' + + result = get_hostname_for_ssl_match(expected_string) + + assert result == expected_string + + def test_no_tld(self): + expected_string = 'localhost' + + result = get_hostname_for_ssl_match(expected_string) + + assert result == expected_string