Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 88a5bf0

Browse files
committed
asyncio: Add support for UNIX Domain Sockets.
1 parent c36e504 commit 88a5bf0

10 files changed

Lines changed: 738 additions & 193 deletions

File tree

Lib/asyncio/base_events.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,13 @@ def create_connection(self, protocol_factory, host=None, port=None, *,
407407

408408
sock.setblocking(False)
409409

410+
transport, protocol = yield from self._create_connection_transport(
411+
sock, protocol_factory, ssl, server_hostname)
412+
return transport, protocol
413+
414+
@tasks.coroutine
415+
def _create_connection_transport(self, sock, protocol_factory, ssl,
416+
server_hostname):
410417
protocol = protocol_factory()
411418
waiter = futures.Future(loop=self)
412419
if ssl:

Lib/asyncio/events.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,32 @@ def create_server(self, protocol_factory, host=None, port=None, *,
220220
"""
221221
raise NotImplementedError
222222

223+
def create_unix_connection(self, protocol_factory, path, *,
224+
ssl=None, sock=None,
225+
server_hostname=None):
226+
raise NotImplementedError
227+
228+
def create_unix_server(self, protocol_factory, path, *,
229+
sock=None, backlog=100, ssl=None):
230+
"""A coroutine which creates a UNIX Domain Socket server.
231+
232+
The return valud is a Server object, which can be used to stop
233+
the service.
234+
235+
path is a str, representing a file systsem path to bind the
236+
server socket to.
237+
238+
sock can optionally be specified in order to use a preexisting
239+
socket object.
240+
241+
backlog is the maximum number of queued connections passed to
242+
listen() (defaults to 100).
243+
244+
ssl can be set to an SSLContext to enable SSL over the
245+
accepted connections.
246+
"""
247+
raise NotImplementedError
248+
223249
def create_datagram_endpoint(self, protocol_factory,
224250
local_addr=None, remote_addr=None, *,
225251
family=0, proto=0, flags=0):

Lib/asyncio/streams.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
"""Stream-related things."""
22

33
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
4-
'open_connection', 'start_server', 'IncompleteReadError',
4+
'open_connection', 'start_server',
5+
'open_unix_connection', 'start_unix_server',
6+
'IncompleteReadError',
57
]
68

9+
import socket
10+
711
from . import events
812
from . import futures
913
from . import protocols
@@ -93,6 +97,39 @@ def factory():
9397
return (yield from loop.create_server(factory, host, port, **kwds))
9498

9599

100+
if hasattr(socket, 'AF_UNIX'):
101+
# UNIX Domain Sockets are supported on this platform
102+
103+
@tasks.coroutine
104+
def open_unix_connection(path=None, *,
105+
loop=None, limit=_DEFAULT_LIMIT, **kwds):
106+
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
107+
if loop is None:
108+
loop = events.get_event_loop()
109+
reader = StreamReader(limit=limit, loop=loop)
110+
protocol = StreamReaderProtocol(reader, loop=loop)
111+
transport, _ = yield from loop.create_unix_connection(
112+
lambda: protocol, path, **kwds)
113+
writer = StreamWriter(transport, protocol, reader, loop)
114+
return reader, writer
115+
116+
117+
@tasks.coroutine
118+
def start_unix_server(client_connected_cb, path=None, *,
119+
loop=None, limit=_DEFAULT_LIMIT, **kwds):
120+
"""Similar to `start_server` but works with UNIX Domain Sockets."""
121+
if loop is None:
122+
loop = events.get_event_loop()
123+
124+
def factory():
125+
reader = StreamReader(limit=limit, loop=loop)
126+
protocol = StreamReaderProtocol(reader, client_connected_cb,
127+
loop=loop)
128+
return protocol
129+
130+
return (yield from loop.create_unix_server(factory, path, **kwds))
131+
132+
96133
class FlowControlMixin(protocols.Protocol):
97134
"""Reusable flow control logic for StreamWriter.drain().
98135

Lib/asyncio/test_utils.py

Lines changed: 119 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,18 @@
44
import contextlib
55
import io
66
import os
7+
import socket
8+
import socketserver
79
import sys
10+
import tempfile
811
import threading
912
import time
1013
import unittest
1114
import unittest.mock
15+
16+
from http.server import HTTPServer
1217
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
18+
1319
try:
1420
import ssl
1521
except ImportError: # pragma: no cover
@@ -70,42 +76,51 @@ def run_once(loop):
7076
loop.run_forever()
7177

7278

73-
@contextlib.contextmanager
74-
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
79+
class SilentWSGIRequestHandler(WSGIRequestHandler):
7580

76-
class SilentWSGIRequestHandler(WSGIRequestHandler):
77-
def get_stderr(self):
78-
return io.StringIO()
81+
def get_stderr(self):
82+
return io.StringIO()
7983

80-
def log_message(self, format, *args):
81-
pass
84+
def log_message(self, format, *args):
85+
pass
8286

83-
class SilentWSGIServer(WSGIServer):
84-
def handle_error(self, request, client_address):
87+
88+
class SilentWSGIServer(WSGIServer):
89+
90+
def handle_error(self, request, client_address):
91+
pass
92+
93+
94+
class SSLWSGIServerMixin:
95+
96+
def finish_request(self, request, client_address):
97+
# The relative location of our test directory (which
98+
# contains the ssl key and certificate files) differs
99+
# between the stdlib and stand-alone asyncio.
100+
# Prefer our own if we can find it.
101+
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
102+
if not os.path.isdir(here):
103+
here = os.path.join(os.path.dirname(os.__file__),
104+
'test', 'test_asyncio')
105+
keyfile = os.path.join(here, 'ssl_key.pem')
106+
certfile = os.path.join(here, 'ssl_cert.pem')
107+
ssock = ssl.wrap_socket(request,
108+
keyfile=keyfile,
109+
certfile=certfile,
110+
server_side=True)
111+
try:
112+
self.RequestHandlerClass(ssock, client_address, self)
113+
ssock.close()
114+
except OSError:
115+
# maybe socket has been closed by peer
85116
pass
86117

87-
class SSLWSGIServer(SilentWSGIServer):
88-
def finish_request(self, request, client_address):
89-
# The relative location of our test directory (which
90-
# contains the ssl key and certificate files) differs
91-
# between the stdlib and stand-alone asyncio.
92-
# Prefer our own if we can find it.
93-
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
94-
if not os.path.isdir(here):
95-
here = os.path.join(os.path.dirname(os.__file__),
96-
'test', 'test_asyncio')
97-
keyfile = os.path.join(here, 'ssl_key.pem')
98-
certfile = os.path.join(here, 'ssl_cert.pem')
99-
ssock = ssl.wrap_socket(request,
100-
keyfile=keyfile,
101-
certfile=certfile,
102-
server_side=True)
103-
try:
104-
self.RequestHandlerClass(ssock, client_address, self)
105-
ssock.close()
106-
except OSError:
107-
# maybe socket has been closed by peer
108-
pass
118+
119+
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
120+
pass
121+
122+
123+
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
109124

110125
def app(environ, start_response):
111126
status = '200 OK'
@@ -115,9 +130,9 @@ def app(environ, start_response):
115130

116131
# Run the test WSGI server in a separate thread in order not to
117132
# interfere with event handling in the main thread
118-
server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
119-
httpd = make_server(host, port, app,
120-
server_class, SilentWSGIRequestHandler)
133+
server_class = server_ssl_cls if use_ssl else server_cls
134+
httpd = server_class(address, SilentWSGIRequestHandler)
135+
httpd.set_app(app)
121136
httpd.address = httpd.server_address
122137
server_thread = threading.Thread(target=httpd.serve_forever)
123138
server_thread.start()
@@ -129,6 +144,75 @@ def app(environ, start_response):
129144
server_thread.join()
130145

131146

147+
if hasattr(socket, 'AF_UNIX'):
148+
149+
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
150+
151+
def server_bind(self):
152+
socketserver.UnixStreamServer.server_bind(self)
153+
self.server_name = '127.0.0.1'
154+
self.server_port = 80
155+
156+
157+
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
158+
159+
def server_bind(self):
160+
UnixHTTPServer.server_bind(self)
161+
self.setup_environ()
162+
163+
def get_request(self):
164+
request, client_addr = super().get_request()
165+
# Code in the stdlib expects that get_request
166+
# will return a socket and a tuple (host, port).
167+
# However, this isn't true for UNIX sockets,
168+
# as the second return value will be a path;
169+
# hence we return some fake data sufficient
170+
# to get the tests going
171+
return request, ('127.0.0.1', '')
172+
173+
174+
class SilentUnixWSGIServer(UnixWSGIServer):
175+
176+
def handle_error(self, request, client_address):
177+
pass
178+
179+
180+
class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
181+
pass
182+
183+
184+
def gen_unix_socket_path():
185+
with tempfile.NamedTemporaryFile() as file:
186+
return file.name
187+
188+
189+
@contextlib.contextmanager
190+
def unix_socket_path():
191+
path = gen_unix_socket_path()
192+
try:
193+
yield path
194+
finally:
195+
try:
196+
os.unlink(path)
197+
except OSError:
198+
pass
199+
200+
201+
@contextlib.contextmanager
202+
def run_test_unix_server(*, use_ssl=False):
203+
with unix_socket_path() as path:
204+
yield from _run_test_server(address=path, use_ssl=use_ssl,
205+
server_cls=SilentUnixWSGIServer,
206+
server_ssl_cls=UnixSSLWSGIServer)
207+
208+
209+
@contextlib.contextmanager
210+
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
211+
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
212+
server_cls=SilentWSGIServer,
213+
server_ssl_cls=SSLWSGIServer)
214+
215+
132216
def make_test_protocol(base):
133217
dct = {}
134218
for name in dir(base):
@@ -275,5 +359,6 @@ def _process_events(self, event_list):
275359
def _write_to_self(self):
276360
pass
277361

362+
278363
def MockCallback(**kwargs):
279364
return unittest.mock.Mock(spec=['__call__'], **kwargs)

Lib/asyncio/unix_events.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import threading
1212

1313

14+
from . import base_events
1415
from . import base_subprocess
1516
from . import constants
1617
from . import events
@@ -31,9 +32,9 @@
3132

3233

3334
class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
34-
"""Unix event loop
35+
"""Unix event loop.
3536
36-
Adds signal handling to SelectorEventLoop
37+
Adds signal handling and UNIX Domain Socket support to SelectorEventLoop.
3738
"""
3839

3940
def __init__(self, selector=None):
@@ -164,6 +165,76 @@ def _make_subprocess_transport(self, protocol, args, shell,
164165
def _child_watcher_callback(self, pid, returncode, transp):
165166
self.call_soon_threadsafe(transp._process_exited, returncode)
166167

168+
@tasks.coroutine
169+
def create_unix_connection(self, protocol_factory, path, *,
170+
ssl=None, sock=None,
171+
server_hostname=None):
172+
assert server_hostname is None or isinstance(server_hostname, str)
173+
if ssl:
174+
if server_hostname is None:
175+
raise ValueError(
176+
'you have to pass server_hostname when using ssl')
177+
else:
178+
if server_hostname is not None:
179+
raise ValueError('server_hostname is only meaningful with ssl')
180+
181+
if path is not None:
182+
if sock is not None:
183+
raise ValueError(
184+
'path and sock can not be specified at the same time')
185+
186+
try:
187+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
188+
sock.setblocking(False)
189+
yield from self.sock_connect(sock, path)
190+
except OSError:
191+
if sock is not None:
192+
sock.close()
193+
raise
194+
195+
else:
196+
if sock is None:
197+
raise ValueError('no path and sock were specified')
198+
sock.setblocking(False)
199+
200+
transport, protocol = yield from self._create_connection_transport(
201+
sock, protocol_factory, ssl, server_hostname)
202+
return transport, protocol
203+
204+
@tasks.coroutine
205+
def create_unix_server(self, protocol_factory, path=None, *,
206+
sock=None, backlog=100, ssl=None):
207+
if isinstance(ssl, bool):
208+
raise TypeError('ssl argument must be an SSLContext or None')
209+
210+
if path is not None:
211+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
212+
213+
try:
214+
sock.bind(path)
215+
except OSError as exc:
216+
if exc.errno == errno.EADDRINUSE:
217+
# Let's improve the error message by adding
218+
# with what exact address it occurs.
219+
msg = 'Address {!r} is already in use'.format(path)
220+
raise OSError(errno.EADDRINUSE, msg) from None
221+
else:
222+
raise
223+
else:
224+
if sock is None:
225+
raise ValueError(
226+
'path was not specified, and no sock specified')
227+
228+
if sock.family != socket.AF_UNIX:
229+
raise ValueError(
230+
'A UNIX Domain Socket was expected, got {!r}'.format(sock))
231+
232+
server = base_events.Server(self, [sock])
233+
sock.listen(backlog)
234+
sock.setblocking(False)
235+
self._start_serving(protocol_factory, sock, ssl, server)
236+
return server
237+
167238

168239
def _set_nonblocking(fd):
169240
flags = fcntl.fcntl(fd, fcntl.F_GETFL)

0 commit comments

Comments
 (0)