Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Several fixes to make Completion.acreate(stream=True) work #172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ def _make_session() -> requests.Session:
return s


def parse_stream_helper(line):
def parse_stream_helper(line: bytes):
if line:
if line == b"data: [DONE]":
if line.strip() == b"data: [DONE]":
# return here will cause GeneratorExit exception in urllib3
# and it will close http connection with TCP Reset
return None
Expand All @@ -111,7 +111,7 @@ def parse_stream(rbody):


async def parse_stream_async(rbody: aiohttp.StreamReader):
async for line in rbody:
async for line, _ in rbody.iter_chunks():
_line = parse_stream_helper(line)
if _line is not None:
yield _line
Expand Down Expand Up @@ -294,18 +294,31 @@ async def arequest(
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
async with aiohttp_session() as session:
result = await self.arequest_raw(
method.lower(),
url,
session,
params=params,
supplied_headers=headers,
files=files,
request_id=request_id,
request_timeout=request_timeout,
)
resp, got_stream = await self._interpret_async_response(result, stream)
ctx = aiohttp_session()
session = await ctx.__aenter__()
result = await self.arequest_raw(
method.lower(),
url,
session,
params=params,
supplied_headers=headers,
files=files,
request_id=request_id,
request_timeout=request_timeout,
)
resp, got_stream = await self._interpret_async_response(result, stream)
if got_stream:

async def wrap_resp():
try:
async for r in resp:
yield r
finally:
await ctx.__aexit__(None, None, None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might still not be called if the caller never actually iterates through the response and just drops it, right?

It's probably fine for now to fix this bug but I imagine we'll want to scope the session to the requestor itself in the future so that we can always ensure that it's closed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's an issue inherent to exposing an async iterator as an API here? If the caller doesn't consume it, then all sorts of bad things may happen... The only solution I can think of is to add some cleanup with a timeout, but that sounds a bit invasive? Another option would be to make this a bit fake, and fully consume the iterator ourselves and buffer it all in memory, but that partly defeats the purpose of asking for a stream.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep that's fair.


return wrap_resp(), got_stream, self.api_key
else:
await ctx.__aexit__(None, None, None)
return resp, got_stream, self.api_key

def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
Expand Down Expand Up @@ -507,7 +520,9 @@ def request_raw(
except requests.exceptions.Timeout as e:
raise error.Timeout("Request timed out: {}".format(e)) from e
except requests.exceptions.RequestException as e:
raise error.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
raise error.APIConnectionError(
"Error communicating with OpenAI: {}".format(e)
) from e
util.log_info(
"OpenAI API response",
path=abs_url,
Expand Down
2 changes: 1 addition & 1 deletion openai/api_resources/abstract/engine_api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async def acreate(
engine=engine,
plain_old_data=cls.plain_old_data,
)
for line in response
async for line in response
)
else:
obj = util.convert_to_openai_object(
Expand Down
24 changes: 24 additions & 0 deletions openai/tests/asyncio/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import openai
from openai import error
from aiohttp import ClientSession


pytestmark = [pytest.mark.asyncio]
Expand Down Expand Up @@ -63,3 +64,26 @@ async def test_timeout_does_not_error():
model="ada",
request_timeout=10,
)


async def test_completions_stream_finishes_global_session():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this test.

async with ClientSession() as session:
openai.aiosession.set(session)

# A query that should be fast
parts = []
async for part in await openai.Completion.acreate(
prompt="test", model="ada", request_timeout=3, stream=True
):
parts.append(part)
assert len(parts) > 1


async def test_completions_stream_finishes_local_session():
    """Streaming acreate without a user-supplied aiohttp session yields
    multiple chunks, i.e. the implicitly created session stays open for
    the whole stream.
    """
    # A query that should be fast
    received = []
    stream = await openai.Completion.acreate(
        prompt="test", model="ada", request_timeout=3, stream=True
    )
    async for chunk in stream:
        received.append(chunk)
    assert len(received) > 1