|
29 | 29 | import warnings |
30 | 30 | import weakref |
31 | 31 |
|
| 32 | +try: |
| 33 | + import ssl |
| 34 | +except ImportError: # pragma: no cover |
| 35 | + ssl = None |
| 36 | + |
32 | 37 | from . import coroutines |
33 | 38 | from . import events |
34 | 39 | from . import futures |
| 40 | +from . import sslproto |
35 | 41 | from . import tasks |
36 | 42 | from .log import logger |
37 | 43 |
|
@@ -279,7 +285,8 @@ def _make_ssl_transport( |
279 | 285 | self, rawsock, protocol, sslcontext, waiter=None, |
280 | 286 | *, server_side=False, server_hostname=None, |
281 | 287 | extra=None, server=None, |
282 | | - ssl_handshake_timeout=None): |
| 288 | + ssl_handshake_timeout=None, |
| 289 | + call_connection_made=True): |
283 | 290 | """Create SSL transport.""" |
284 | 291 | raise NotImplementedError |
285 | 292 |
|
@@ -795,6 +802,42 @@ async def _create_connection_transport( |
795 | 802 |
|
796 | 803 | return transport, protocol |
797 | 804 |
|
| 805 | + async def start_tls(self, transport, protocol, sslcontext, *, |
| 806 | + server_side=False, |
| 807 | + server_hostname=None, |
| 808 | + ssl_handshake_timeout=None): |
| 809 | + """Upgrade transport to TLS. |
| 810 | +
|
| 811 | + Return a new transport that *protocol* should start using |
| 812 | + immediately. |
| 813 | + """ |
| 814 | + if ssl is None: |
| 815 | + raise RuntimeError('Python ssl module is not available') |
| 816 | + |
| 817 | + if not isinstance(sslcontext, ssl.SSLContext): |
| 818 | + raise TypeError( |
| 819 | + f'sslcontext is expected to be an instance of ssl.SSLContext, ' |
| 820 | + f'got {sslcontext!r}') |
| 821 | + |
| 822 | + if not getattr(transport, '_start_tls_compatible', False): |
| 823 | + raise TypeError( |
| 824 | + f'transport {self!r} is not supported by start_tls()') |
| 825 | + |
| 826 | + waiter = self.create_future() |
| 827 | + ssl_protocol = sslproto.SSLProtocol( |
| 828 | + self, protocol, sslcontext, waiter, |
| 829 | + server_side, server_hostname, |
| 830 | + ssl_handshake_timeout=ssl_handshake_timeout, |
| 831 | + call_connection_made=False) |
| 832 | + |
| 833 | + transport.set_protocol(ssl_protocol) |
| 834 | + self.call_soon(ssl_protocol.connection_made, transport) |
| 835 | + if not transport.is_reading(): |
| 836 | + self.call_soon(transport.resume_reading) |
| 837 | + |
| 838 | + await waiter |
| 839 | + return ssl_protocol._app_transport |
| 840 | + |
798 | 841 | async def create_datagram_endpoint(self, protocol_factory, |
799 | 842 | local_addr=None, remote_addr=None, *, |
800 | 843 | family=0, proto=0, flags=0, |
|
0 commit comments