From c2dc88997477942417b9194d3837d580390cc4cc Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sat, 7 Jan 2023 10:50:40 +0000 Subject: [PATCH 1/6] Added a failing test case for async completion stream --- openai/tests/asyncio/test_endpoints.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/openai/tests/asyncio/test_endpoints.py b/openai/tests/asyncio/test_endpoints.py index e5c6d012cd..e44aa5144f 100644 --- a/openai/tests/asyncio/test_endpoints.py +++ b/openai/tests/asyncio/test_endpoints.py @@ -63,3 +63,15 @@ async def test_timeout_does_not_error(): model="ada", request_timeout=10, ) + + +async def test_completions_stream_finishes(): + # A query that should be fast + parts = [] + async for part in await openai.Completion.acreate( + prompt="test", model="ada", request_timeout=10, stream=True + ): + parts.append(part) + assert ( + len(parts) == 2 + ) # note this assertion is incorrect, but we don't make it here currently From 4f4f3cf350122220268a18a4450a65f50b9812f2 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sat, 7 Jan 2023 10:51:08 +0000 Subject: [PATCH 2/6] Consume async generator with async for --- openai/api_resources/abstract/engine_api_resource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openai/api_resources/abstract/engine_api_resource.py b/openai/api_resources/abstract/engine_api_resource.py index d6fe0d39a9..1f172d8cbd 100644 --- a/openai/api_resources/abstract/engine_api_resource.py +++ b/openai/api_resources/abstract/engine_api_resource.py @@ -236,7 +236,7 @@ async def acreate( engine=engine, plain_old_data=cls.plain_old_data, ) - for line in response + async for line in response ) else: obj = util.convert_to_openai_object( From 2e2e20e66892154465c233894214bd09ecb9d1e9 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sat, 7 Jan 2023 10:58:52 +0000 Subject: [PATCH 3/6] Consume the stream in chunks as sent by API, to avoid "empty" parts The api will send chunks like ``` b'data: {"id": "cmpl-6W18L0k1kFoHUoSsJOwcPq7DKBaGX", "object": "text_completion", "created": 1673088873, "choices": [{"text": "_", "index": 0, "logprobs": null, "finish_reason": null}], "model": "ada"}\n\n' ``` The default iterator will break on each `\n` character, whereas iter_chunks will just output parts as they arrive --- openai/api_requestor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openai/api_requestor.py b/openai/api_requestor.py index a34ee281ec..ae06306938 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -89,9 +89,9 @@ def _make_session() -> requests.Session: return s -def parse_stream_helper(line): +def parse_stream_helper(line: bytes): if line: - if line == b"data: [DONE]": + if line.strip() == b"data: [DONE]": # return here will cause GeneratorExit exception in urllib3 # and it will close http connection with TCP Reset return None @@ -111,7 +111,7 @@ def parse_stream(rbody): async def parse_stream_async(rbody: aiohttp.StreamReader): - async for line in rbody: + async for line, _ in rbody.iter_chunks(): _line = parse_stream_helper(line) if _line is not None: yield _line From 9817bbdaf4e16828af40906b3972fab920ca60a6 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sat, 7 Jan 2023 12:36:20 +0000 Subject: [PATCH 4/6] Add another test using global aiosession --- openai/tests/asyncio/test_endpoints.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/openai/tests/asyncio/test_endpoints.py b/openai/tests/asyncio/test_endpoints.py index e44aa5144f..a0e38a5a04 100644 --- a/openai/tests/asyncio/test_endpoints.py +++ b/openai/tests/asyncio/test_endpoints.py @@ -5,6 +5,7 @@ import openai from openai import error +from aiohttp import ClientSession pytestmark = [pytest.mark.asyncio] @@ -65,13 +66,24 @@ async def test_timeout_does_not_error(): ) -async def test_completions_stream_finishes(): +async def test_completions_stream_finishes_global_session(): + async with ClientSession() as session: + openai.aiosession.set(session) + + # A query that should be fast + parts = [] + async for part in await openai.Completion.acreate( + prompt="test", model="ada", request_timeout=3, stream=True + ): + parts.append(part) + assert len(parts) > 1 + + +async def test_completions_stream_finishes_local_session(): # A query that should be fast parts = [] async for part in await openai.Completion.acreate( - prompt="test", model="ada", request_timeout=10, stream=True + prompt="test", model="ada", request_timeout=3, stream=True ): parts.append(part) - assert ( - len(parts) == 2 - ) # note this assertion is incorrect, but we don't make it here currently + assert len(parts) > 1 From 21ed0ad1b312b645496c1d3234c285cae832cfad Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sat, 7 Jan 2023 12:38:18 +0000 Subject: [PATCH 5/6] Manually consume aiohttp_session asyncontextmanager to ensure that session is only closed once the response stream is finished Previously we'd exit the with statement before the response stream is consumed by the caller, therefore, unless we're using a global ClientSession, the session is closed (and thus the request) before it should be. --- openai/api_requestor.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/openai/api_requestor.py b/openai/api_requestor.py index ae06306938..effbf92dc8 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -294,18 +294,29 @@ async def arequest( request_id: Optional[str] = None, request_timeout: Optional[Union[float, Tuple[float, float]]] = None, ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: - async with aiohttp_session() as session: - result = await self.arequest_raw( - method.lower(), - url, - session, - params=params, - supplied_headers=headers, - files=files, - request_id=request_id, - request_timeout=request_timeout, - ) - resp, got_stream = await self._interpret_async_response(result, stream) + ctx = aiohttp_session() + session = await ctx.__aenter__() + result = await self.arequest_raw( + method.lower(), + url, + session, + params=params, + supplied_headers=headers, + files=files, + request_id=request_id, + request_timeout=request_timeout, + ) + resp, got_stream = await self._interpret_async_response(result, stream) + if got_stream: + + async def wrap_resp(): + async for r in resp: + yield r + await ctx.__aexit__(None, None, None) + + return wrap_resp(), got_stream, self.api_key + else: + await ctx.__aexit__(None, None, None) return resp, got_stream, self.api_key def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False): From 70883524c0b61f14ab9568f5ceb204f8fd3a6e7b Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 12 Jan 2023 07:37:47 +0000 Subject: [PATCH 6/6] Ensure we close the session even if the caller raises an exception while consuming the stream --- openai/api_requestor.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/openai/api_requestor.py b/openai/api_requestor.py index effbf92dc8..eff7dd8a0a 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -310,9 +310,11 @@ async def arequest( if got_stream: async def wrap_resp(): - async for r in resp: - yield r - await ctx.__aexit__(None, None, None) + try: + async for r in resp: + yield r + finally: + await ctx.__aexit__(None, None, None) return wrap_resp(), got_stream, self.api_key else: @@ -518,7 +520,9 @@ def request_raw( except requests.exceptions.Timeout as e: raise error.Timeout("Request timed out: {}".format(e)) from e except requests.exceptions.RequestException as e: - raise error.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e + raise error.APIConnectionError( + "Error communicating with OpenAI: {}".format(e) + ) from e util.log_info( "OpenAI API response", path=abs_url,