@@ -141,15 +141,14 @@ class FlowControlMixin(protocols.Protocol):
141141 resume_reading() and connection_lost(). If the subclass overrides
142142 these it must call the super methods.
143143
144- StreamWriter.drain() must check for error conditions and then call
145- _make_drain_waiter(), which will return either () or a Future
146- depending on the paused state.
144+ StreamWriter.drain() must wait for _drain_helper() coroutine.
147145 """
148146
149147 def __init__ (self , loop = None ):
150148 self ._loop = loop # May be None; we may never need it.
151149 self ._paused = False
152150 self ._drain_waiter = None
151+ self ._connection_lost = False
153152
154153 def pause_writing (self ):
155154 assert not self ._paused
@@ -170,6 +169,7 @@ def resume_writing(self):
170169 waiter .set_result (None )
171170
172171 def connection_lost (self , exc ):
172+ self ._connection_lost = True
173173 # Wake up the writer if currently paused.
174174 if not self ._paused :
175175 return
@@ -184,14 +184,17 @@ def connection_lost(self, exc):
184184 else :
185185 waiter .set_exception (exc )
186186
187- def _make_drain_waiter (self ):
187+ @coroutine
188+ def _drain_helper (self ):
189+ if self ._connection_lost :
190+ raise ConnectionResetError ('Connection lost' )
188191 if not self ._paused :
189- return ()
192+ return
190193 waiter = self ._drain_waiter
191194 assert waiter is None or waiter .cancelled ()
192195 waiter = futures .Future (loop = self ._loop )
193196 self ._drain_waiter = waiter
194- return waiter
197+ yield from waiter
195198
196199
197200class StreamReaderProtocol (FlowControlMixin , protocols .Protocol ):
@@ -247,6 +250,8 @@ class StreamWriter:
247250 def __init__ (self , transport , protocol , reader , loop ):
248251 self ._transport = transport
249252 self ._protocol = protocol
253+ # drain() expects that the reader has a exception() method
254+ assert reader is None or isinstance (reader , StreamReader )
250255 self ._reader = reader
251256 self ._loop = loop
252257
@@ -278,26 +283,20 @@ def close(self):
278283 def get_extra_info (self , name , default = None ):
279284 return self ._transport .get_extra_info (name , default )
280285
286+ @coroutine
281287 def drain (self ):
282- """This method has an unusual return value .
288+ """Flush the write buffer .
283289
284290 The intended use is to write
285291
286292 w.write(data)
287293 yield from w.drain()
288-
289- When there's nothing to wait for, drain() returns (), and the
290- yield-from continues immediately. When the transport buffer
291- is full (the protocol is paused), drain() creates and returns
292- a Future and the yield-from will block until that Future is
293- completed, which will happen when the buffer is (partially)
294- drained and the protocol is resumed.
295294 """
296- if self ._reader is not None and self . _reader . _exception is not None :
297- raise self ._reader ._exception
298- if self . _transport . _conn_lost : # Uses private variable.
299- raise ConnectionResetError ( 'Connection lost' )
300- return self ._protocol ._make_drain_waiter ()
295+ if self ._reader is not None :
296+ exc = self ._reader .exception ()
297+ if exc is not None :
298+ raise exc
299+ yield from self ._protocol ._drain_helper ()
301300
302301
303302class StreamReader :
0 commit comments