@@ -199,6 +199,9 @@ async def _drain_helper(self):
199199 self ._drain_waiter = waiter
200200 await waiter
201201
202+ def _get_close_waiter (self , stream ):
203+ raise NotImplementedError
204+
202205
203206class StreamReaderProtocol (FlowControlMixin , protocols .Protocol ):
204207 """Helper class to adapt between Protocol and StreamReader.
@@ -315,6 +318,9 @@ def eof_received(self):
315318 return False
316319 return True
317320
321+ def _get_close_waiter (self , stream ):
322+ return self ._closed
323+
318324 def __del__ (self ):
319325 # Prevent reports about unhandled exceptions.
320326 # Better than self._closed._log_traceback = False hack
@@ -376,7 +382,7 @@ def is_closing(self):
376382 return self ._transport .is_closing ()
377383
378384 async def wait_closed (self ):
379- await self ._protocol ._closed
385+ await self ._protocol ._get_close_waiter ( self )
380386
381387 def get_extra_info (self , name , default = None ):
382388 return self ._transport .get_extra_info (name , default )
@@ -394,13 +400,12 @@ async def drain(self):
394400 if exc is not None :
395401 raise exc
396402 if self ._transport .is_closing ():
397- # Yield to the event loop so connection_lost() may be
398- # called. Without this, _drain_helper() would return
399- # immediately, and code that calls
400- # write(...); await drain()
401- # in a loop would never call connection_lost(), so it
402- # would not see an error when the socket is closed.
403- await sleep (0 , loop = self ._loop )
403+ # Wait for protocol.connection_lost() call
404+ # Raise connection closing error if any,
405+ # ConnectionResetError otherwise
406+ fut = self ._protocol ._get_close_waiter (self )
407+ await fut
408+ raise ConnectionResetError ('Connection lost' )
404409 await self ._protocol ._drain_helper ()
405410
406411 async def aclose (self ):
0 commit comments