diff --git a/openai/api_requestor.py b/openai/api_requestor.py index a34ee281ec..eff7dd8a0a 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -89,9 +89,9 @@ def _make_session() -> requests.Session: return s -def parse_stream_helper(line): +def parse_stream_helper(line: bytes): if line: - if line == b"data: [DONE]": + if line.strip() == b"data: [DONE]": # return here will cause GeneratorExit exception in urllib3 # and it will close http connection with TCP Reset return None @@ -111,7 +111,7 @@ def parse_stream(rbody): async def parse_stream_async(rbody: aiohttp.StreamReader): - async for line in rbody: + async for line, _ in rbody.iter_chunks(): _line = parse_stream_helper(line) if _line is not None: yield _line @@ -294,18 +294,31 @@ async def arequest( request_id: Optional[str] = None, request_timeout: Optional[Union[float, Tuple[float, float]]] = None, ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: - async with aiohttp_session() as session: - result = await self.arequest_raw( - method.lower(), - url, - session, - params=params, - supplied_headers=headers, - files=files, - request_id=request_id, - request_timeout=request_timeout, - ) - resp, got_stream = await self._interpret_async_response(result, stream) + ctx = aiohttp_session() + session = await ctx.__aenter__() + result = await self.arequest_raw( + method.lower(), + url, + session, + params=params, + supplied_headers=headers, + files=files, + request_id=request_id, + request_timeout=request_timeout, + ) + resp, got_stream = await self._interpret_async_response(result, stream) + if got_stream: + + async def wrap_resp(): + try: + async for r in resp: + yield r + finally: + await ctx.__aexit__(None, None, None) + + return wrap_resp(), got_stream, self.api_key + else: + await ctx.__aexit__(None, None, None) return resp, got_stream, self.api_key def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False): @@ -507,7 +520,9 @@ def request_raw( except requests.exceptions.Timeout as e: raise error.Timeout("Request timed out: {}".format(e)) from e except requests.exceptions.RequestException as e: - raise error.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e + raise error.APIConnectionError( + "Error communicating with OpenAI: {}".format(e) + ) from e util.log_info( "OpenAI API response", path=abs_url, diff --git a/openai/api_resources/abstract/engine_api_resource.py b/openai/api_resources/abstract/engine_api_resource.py index d6fe0d39a9..1f172d8cbd 100644 --- a/openai/api_resources/abstract/engine_api_resource.py +++ b/openai/api_resources/abstract/engine_api_resource.py @@ -236,7 +236,7 @@ async def acreate( engine=engine, plain_old_data=cls.plain_old_data, ) - for line in response + async for line in response ) else: obj = util.convert_to_openai_object( diff --git a/openai/tests/asyncio/test_endpoints.py b/openai/tests/asyncio/test_endpoints.py index e5c6d012cd..a0e38a5a04 100644 --- a/openai/tests/asyncio/test_endpoints.py +++ b/openai/tests/asyncio/test_endpoints.py @@ -5,6 +5,7 @@ import openai from openai import error +from aiohttp import ClientSession pytestmark = [pytest.mark.asyncio] @@ -63,3 +64,26 @@ async def test_timeout_does_not_error(): model="ada", request_timeout=10, ) + + +async def test_completions_stream_finishes_global_session(): + async with ClientSession() as session: + openai.aiosession.set(session) + + # A query that should be fast + parts = [] + async for part in await openai.Completion.acreate( + prompt="test", model="ada", request_timeout=3, stream=True + ): + parts.append(part) + assert len(parts) > 1 + + +async def test_completions_stream_finishes_local_session(): + # A query that should be fast + parts = [] + async for part in await openai.Completion.acreate( + prompt="test", model="ada", request_timeout=3, stream=True + ): + parts.append(part) + assert len(parts) > 1