Thanks to visit codestin.com
Credit goes to github.com

Skip to content

bpo-36806: Forbid stream objects creation outside of asyncio #13101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 51 additions & 16 deletions Lib/asyncio/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import socket
import sys
import warnings
import weakref

if hasattr(socket, 'AF_UNIX'):
Expand Down Expand Up @@ -42,11 +43,14 @@ async def open_connection(host=None, port=None, *,
"""
if loop is None:
loop = events.get_event_loop()
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop)
reader = StreamReader(limit=limit, loop=loop,
_asyncio_internal=True)
protocol = StreamReaderProtocol(reader, loop=loop,
_asyncio_internal=True)
transport, _ = await loop.create_connection(
lambda: protocol, host, port, **kwds)
writer = StreamWriter(transport, protocol, reader, loop)
writer = StreamWriter(transport, protocol, reader, loop,
_asyncio_internal=True)
return reader, writer


Expand Down Expand Up @@ -77,9 +81,11 @@ async def start_server(client_connected_cb, host=None, port=None, *,
loop = events.get_event_loop()

def factory():
reader = StreamReader(limit=limit, loop=loop)
reader = StreamReader(limit=limit, loop=loop,
_asyncio_internal=True)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop)
loop=loop,
_asyncio_internal=True)
return protocol

return await loop.create_server(factory, host, port, **kwds)
Expand All @@ -93,11 +99,14 @@ async def open_unix_connection(path=None, *,
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
if loop is None:
loop = events.get_event_loop()
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop)
reader = StreamReader(limit=limit, loop=loop,
_asyncio_internal=True)
protocol = StreamReaderProtocol(reader, loop=loop,
_asyncio_internal=True)
transport, _ = await loop.create_unix_connection(
lambda: protocol, path, **kwds)
writer = StreamWriter(transport, protocol, reader, loop)
writer = StreamWriter(transport, protocol, reader, loop,
_asyncio_internal=True)
return reader, writer

async def start_unix_server(client_connected_cb, path=None, *,
Expand All @@ -107,9 +116,11 @@ async def start_unix_server(client_connected_cb, path=None, *,
loop = events.get_event_loop()

def factory():
reader = StreamReader(limit=limit, loop=loop)
reader = StreamReader(limit=limit, loop=loop,
_asyncio_internal=True)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop)
loop=loop,
_asyncio_internal=True)
return protocol

return await loop.create_unix_server(factory, path, **kwds)
Expand All @@ -125,11 +136,20 @@ class FlowControlMixin(protocols.Protocol):
StreamWriter.drain() must wait for _drain_helper() coroutine.
"""

def __init__(self, loop=None):
def __init__(self, loop=None, *, _asyncio_internal=False):
if loop is None:
self._loop = events.get_event_loop()
else:
self._loop = loop
if not _asyncio_internal:
# NOTE:
# Avoid inheritance from FlowControlMixin
# Copy-paste the code to your project
# if you need flow control helpers
warnings.warn(f"{self.__class__} should be instaniated "
"by asyncio internals only, "
"please avoid its creation from user code",
DeprecationWarning)
self._paused = False
self._drain_waiter = None
self._connection_lost = False
Expand Down Expand Up @@ -191,8 +211,9 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):

_source_traceback = None

def __init__(self, stream_reader, client_connected_cb=None, loop=None):
super().__init__(loop=loop)
def __init__(self, stream_reader, client_connected_cb=None, loop=None,
*, _asyncio_internal=False):
super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
if stream_reader is not None:
self._stream_reader_wr = weakref.ref(stream_reader,
self._on_reader_gc)
Expand Down Expand Up @@ -253,7 +274,8 @@ def connection_made(self, transport):
if self._client_connected_cb is not None:
self._stream_writer = StreamWriter(transport, self,
reader,
self._loop)
self._loop,
_asyncio_internal=True)
res = self._client_connected_cb(reader,
self._stream_writer)
if coroutines.iscoroutine(res):
Expand Down Expand Up @@ -311,7 +333,13 @@ class StreamWriter:
directly.
"""

def __init__(self, transport, protocol, reader, loop):
def __init__(self, transport, protocol, reader, loop,
*, _asyncio_internal=False):
if not _asyncio_internal:
warnings.warn(f"{self.__class__} should be instaniated "
"by asyncio internals only, "
"please avoid its creation from user code",
DeprecationWarning)
self._transport = transport
self._protocol = protocol
# drain() expects that the reader has an exception() method
Expand Down Expand Up @@ -388,7 +416,14 @@ class StreamReader:

_source_traceback = None

def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
def __init__(self, limit=_DEFAULT_LIMIT, loop=None,
*, _asyncio_internal=False):
if not _asyncio_internal:
warnings.warn(f"{self.__class__} should be instaniated "
"by asyncio internals only, "
"please avoid its creation from user code",
DeprecationWarning)

# The line length limit is a security feature;
# it also doubles as half the buffer limit.

Expand Down
32 changes: 22 additions & 10 deletions Lib/asyncio/subprocess.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = 'create_subprocess_exec', 'create_subprocess_shell'

import subprocess
import warnings

from . import events
from . import protocols
Expand All @@ -18,8 +19,8 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
protocols.SubprocessProtocol):
"""Like StreamReaderProtocol, but for a subprocess."""

def __init__(self, limit, loop):
super().__init__(loop=loop)
def __init__(self, limit, loop, *, _asyncio_internal=False):
super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
self._limit = limit
self.stdin = self.stdout = self.stderr = None
self._transport = None
Expand All @@ -42,14 +43,16 @@ def connection_made(self, transport):
stdout_transport = transport.get_pipe_transport(1)
if stdout_transport is not None:
self.stdout = streams.StreamReader(limit=self._limit,
loop=self._loop)
loop=self._loop,
_asyncio_internal=True)
self.stdout.set_transport(stdout_transport)
self._pipe_fds.append(1)

stderr_transport = transport.get_pipe_transport(2)
if stderr_transport is not None:
self.stderr = streams.StreamReader(limit=self._limit,
loop=self._loop)
loop=self._loop,
_asyncio_internal=True)
self.stderr.set_transport(stderr_transport)
self._pipe_fds.append(2)

Expand All @@ -58,7 +61,8 @@ def connection_made(self, transport):
self.stdin = streams.StreamWriter(stdin_transport,
protocol=self,
reader=None,
loop=self._loop)
loop=self._loop,
_asyncio_internal=True)

def pipe_data_received(self, fd, data):
if fd == 1:
Expand Down Expand Up @@ -104,7 +108,13 @@ def _maybe_close_transport(self):


class Process:
def __init__(self, transport, protocol, loop):
def __init__(self, transport, protocol, loop, *, _asyncio_internal=False):
if not _asyncio_internal:
warnings.warn(f"{self.__class__} should be instaniated "
"by asyncio internals only, "
"please avoid its creation from user code",
DeprecationWarning)

self._transport = transport
self._protocol = protocol
self._loop = loop
Expand Down Expand Up @@ -195,12 +205,13 @@ async def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None,
if loop is None:
loop = events.get_event_loop()
protocol_factory = lambda: SubprocessStreamProtocol(limit=limit,
loop=loop)
loop=loop,
_asyncio_internal=True)
transport, protocol = await loop.subprocess_shell(
protocol_factory,
cmd, stdin=stdin, stdout=stdout,
stderr=stderr, **kwds)
return Process(transport, protocol, loop)
return Process(transport, protocol, loop, _asyncio_internal=True)


async def create_subprocess_exec(program, *args, stdin=None, stdout=None,
Expand All @@ -209,10 +220,11 @@ async def create_subprocess_exec(program, *args, stdin=None, stdout=None,
if loop is None:
loop = events.get_event_loop()
protocol_factory = lambda: SubprocessStreamProtocol(limit=limit,
loop=loop)
loop=loop,
_asyncio_internal=True)
transport, protocol = await loop.subprocess_exec(
protocol_factory,
program, *args,
stdin=stdin, stdout=stdout,
stderr=stderr, **kwds)
return Process(transport, protocol, loop)
return Process(transport, protocol, loop, _asyncio_internal=True)
Loading