diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 59cb6b1babec54..17fc096ca22bbc 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -744,9 +744,6 @@ def _call_connection_lost(self, exc): server._detach() self._server = None - def get_write_buffer_size(self): - return len(self._buffer) - def _add_reader(self, fd, callback, *args): if self._closing: return @@ -895,6 +892,9 @@ def _read_ready__on_eof(self): else: self.close() + def get_write_buffer_size(self): + return len(self._buffer) + def write(self, data): if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError(f'data argument must be a bytes-like object, ' @@ -992,14 +992,32 @@ def _reset_empty_waiter(self): self._empty_waiter = None -class _SelectorDatagramTransport(_SelectorTransport): +class _SelectorDatagramTransport(transports._FlowControlMixin, + transports.DatagramTransport): _buffer_factory = collections.deque + max_size = 256 * 1024 def __init__(self, loop, sock, protocol, address=None, - waiter=None, extra=None): - super().__init__(loop, sock, protocol, extra) + waiter=None, extra=None, server=None): + super().__init__(loop=loop, extra=extra) self._address = address + self._loop = loop + self._protocol = protocol + self._sock = sock + self._sock_fd = sock.fileno() + self._buffer = self._buffer_factory() + self._server = server + self._conn_lost = 0 + self._closing = False + self._protocol_connected = False + + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except socket.error: + self._extra['peername'] = None + self._loop.call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called self._loop.call_soon(self._add_reader, @@ -1097,3 +1115,70 @@ def _sendto_ready(self): self._loop._remove_writer(self._sock_fd) if self._closing: self._call_connection_lost(None) + + def _add_reader(self, fd, callback, *args): + self._check_closed() + handle = events.Handle(callback, args, self, None) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + return handle + + def _fatal_error(self, exc, message='Fatal error on transport'): + # Should be called from exception handler only. + if isinstance(exc, OSError): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._force_close(exc) + + def _force_close(self, exc): + if self._conn_lost: + return + if self._buffer: + self._buffer.clear() + self._loop._remove_writer(self._sock_fd) + if not self._closing: + self._closing = True + self._loop._remove_reader(self._sock_fd) + self._conn_lost += 1 + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + if self._protocol_connected: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + server = self._server + if server is not None: + server._detach() + self._server = None + + def close(self): + if self._closing: + return + self._closing = True + self._loop._remove_reader(self._sock_fd) + if not self._buffer: + self._conn_lost += 1 + self._loop._remove_writer(self._sock_fd) + self._loop.call_soon(self._call_connection_lost, None) + diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index b684fab2771f20..d9949c2411e1dc 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -1089,6 +1089,10 @@ def datagram_transport(self, address=None): self.addCleanup(close_transport, transport) return transport + def test_isinstance_datagram_transport(self): + transport = self.datagram_transport() + self.assertIsInstance(transport, asyncio.DatagramTransport) + def test_read_ready(self): transport = self.datagram_transport() diff --git a/Misc/NEWS.d/next/Library/2021-12-29-19-31-55.bpo-46194.9PbQHt.rst b/Misc/NEWS.d/next/Library/2021-12-29-19-31-55.bpo-46194.9PbQHt.rst new file mode 100644 index 00000000000000..b4c8e10278781b --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-12-29-19-31-55.bpo-46194.9PbQHt.rst @@ -0,0 +1 @@ +Internal implementation of selector-based datagram transport inherits :class:`asyncio.DatagramTransport`.