Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d9b43f5

Browse files
authored
Merge branch 'openai:main' into client
2 parents 917b97e + b82a3f7 commit d9b43f5

File tree

9 files changed

+208
-15
lines changed

9 files changed

+208
-15
lines changed

openai/api_requestor.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import asyncio
22
import json
3+
import time
34
import platform
45
import sys
56
import threading
7+
import time
68
import warnings
79
from contextlib import asynccontextmanager
810
from json import JSONDecodeError
911
from typing import (
1012
AsyncGenerator,
1113
AsyncIterator,
14+
Callable,
1215
Dict,
1316
Iterator,
1417
Optional,
@@ -32,6 +35,7 @@
3235
from openai.util import ApiType
3336

3437
TIMEOUT_SECS = 600
38+
MAX_SESSION_LIFETIME_SECS = 180
3539
MAX_CONNECTION_RETRIES = 2
3640

3741
# Has one attribute per thread, 'session'.
@@ -149,6 +153,70 @@ def format_app_info(cls, info):
149153
str += " (%s)" % (info["url"],)
150154
return str
151155

156+
def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]):
157+
if not predicate(response):
158+
return
159+
error_data = response.data['error']
160+
message = error_data.get('message', 'Operation failed')
161+
code = error_data.get('code')
162+
raise error.OpenAIError(message=message, code=code)
163+
164+
def _poll(
165+
self,
166+
method,
167+
url,
168+
until,
169+
failed,
170+
params = None,
171+
headers = None,
172+
interval = None,
173+
delay = None
174+
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
175+
if delay:
176+
time.sleep(delay)
177+
178+
response, b, api_key = self.request(method, url, params, headers)
179+
self._check_polling_response(response, failed)
180+
start_time = time.time()
181+
while not until(response):
182+
if time.time() - start_time > TIMEOUT_SECS:
183+
raise error.Timeout("Operation polling timed out.")
184+
185+
time.sleep(interval or response.retry_after or 10)
186+
response, b, api_key = self.request(method, url, params, headers)
187+
self._check_polling_response(response, failed)
188+
189+
response.data = response.data['result']
190+
return response, b, api_key
191+
192+
async def _apoll(
193+
self,
194+
method,
195+
url,
196+
until,
197+
failed,
198+
params = None,
199+
headers = None,
200+
interval = None,
201+
delay = None
202+
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
203+
if delay:
204+
await asyncio.sleep(delay)
205+
206+
response, b, api_key = await self.arequest(method, url, params, headers)
207+
self._check_polling_response(response, failed)
208+
start_time = time.time()
209+
while not until(response):
210+
if time.time() - start_time > TIMEOUT_SECS:
211+
raise error.Timeout("Operation polling timed out.")
212+
213+
await asyncio.sleep(interval or response.retry_after or 10)
214+
response, b, api_key = await self.arequest(method, url, params, headers)
215+
self._check_polling_response(response, failed)
216+
217+
response.data = response.data['result']
218+
return response, b, api_key
219+
152220
@overload
153221
def request(
154222
self,
@@ -516,6 +584,14 @@ def request_raw(
516584

517585
if not hasattr(_thread_context, "session"):
518586
_thread_context.session = _make_session()
587+
_thread_context.session_create_time = time.time()
588+
elif (
589+
time.time() - getattr(_thread_context, "session_create_time", 0)
590+
>= MAX_SESSION_LIFETIME_SECS
591+
):
592+
_thread_context.session.close()
593+
_thread_context.session = _make_session()
594+
_thread_context.session_create_time = time.time()
519595
try:
520596
result = _thread_context.session.request(
521597
method,
@@ -644,6 +720,8 @@ async def _interpret_async_response(
644720
else:
645721
try:
646722
await result.read()
723+
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
724+
raise error.Timeout("Request timed out") from e
647725
except aiohttp.ClientError as e:
648726
util.log_warn(e, body=result.content)
649727
return (

openai/api_resources/audio.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def transcribe(
5959
api_key=api_key,
6060
api_base=api_base,
6161
api_type=api_type,
62+
api_version=api_version,
63+
organization=organization,
6264
**params,
6365
)
6466
url = cls._get_url("transcriptions")
@@ -86,6 +88,8 @@ def translate(
8688
api_key=api_key,
8789
api_base=api_base,
8890
api_type=api_type,
91+
api_version=api_version,
92+
organization=organization,
8993
**params,
9094
)
9195
url = cls._get_url("translations")
@@ -114,6 +118,8 @@ def transcribe_raw(
114118
api_key=api_key,
115119
api_base=api_base,
116120
api_type=api_type,
121+
api_version=api_version,
122+
organization=organization,
117123
**params,
118124
)
119125
url = cls._get_url("transcriptions")
@@ -142,6 +148,8 @@ def translate_raw(
142148
api_key=api_key,
143149
api_base=api_base,
144150
api_type=api_type,
151+
api_version=api_version,
152+
organization=organization,
145153
**params,
146154
)
147155
url = cls._get_url("translations")
@@ -169,6 +177,8 @@ async def atranscribe(
169177
api_key=api_key,
170178
api_base=api_base,
171179
api_type=api_type,
180+
api_version=api_version,
181+
organization=organization,
172182
**params,
173183
)
174184
url = cls._get_url("transcriptions")
@@ -198,6 +208,8 @@ async def atranslate(
198208
api_key=api_key,
199209
api_base=api_base,
200210
api_type=api_type,
211+
api_version=api_version,
212+
organization=organization,
201213
**params,
202214
)
203215
url = cls._get_url("translations")
@@ -228,6 +240,8 @@ async def atranscribe_raw(
228240
api_key=api_key,
229241
api_base=api_base,
230242
api_type=api_type,
243+
api_version=api_version,
244+
organization=organization,
231245
**params,
232246
)
233247
url = cls._get_url("transcriptions")
@@ -258,6 +272,8 @@ async def atranslate_raw(
258272
api_key=api_key,
259273
api_base=api_base,
260274
api_type=api_type,
275+
api_version=api_version,
276+
organization=organization,
261277
**params,
262278
)
263279
url = cls._get_url("translations")

openai/api_resources/chat_completion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def create(cls, *args, **kwargs):
1414
"""
1515
Creates a new chat completion for the provided messages and parameters.
1616
17-
See https://platform.openai.com/docs/api-reference/chat-completions/create
17+
See https://platform.openai.com/docs/api-reference/chat/create
1818
for a list of valid parameters.
1919
"""
2020
start = time.time()
@@ -34,7 +34,7 @@ async def acreate(cls, *args, **kwargs):
3434
"""
3535
Creates a new chat completion for the provided messages and parameters.
3636
37-
See https://platform.openai.com/docs/api-reference/chat-completions/create
37+
See https://platform.openai.com/docs/api-reference/chat/create
3838
for a list of valid parameters.
3939
"""
4040
start = time.time()

openai/api_resources/image.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22
from typing import Any, List
33

44
import openai
5-
from openai import api_requestor, util
5+
from openai import api_requestor, error, util
66
from openai.api_resources.abstract import APIResource
77

88

99
class Image(APIResource):
1010
OBJECT_NAME = "images"
1111

1212
@classmethod
13-
def _get_url(cls, action):
14-
return cls.class_url() + f"/{action}"
13+
def _get_url(cls, action, azure_action, api_type, api_version):
14+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD) and azure_action is not None:
15+
return f"/{cls.azure_api_prefix}{cls.class_url()}/{action}:{azure_action}?api-version={api_version}"
16+
else:
17+
return f"{cls.class_url()}/{action}"
1518

1619
@classmethod
1720
def create(
@@ -31,12 +34,20 @@ def create(
3134
organization=organization,
3235
)
3336

34-
_, api_version = cls._get_api_type_and_version(api_type, api_version)
37+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
3538

3639
response, _, api_key = requestor.request(
37-
"post", cls._get_url("generations"), params
40+
"post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params
3841
)
3942

43+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
44+
requestor.api_base = "" # operation_location is a full url
45+
response, _, api_key = requestor._poll(
46+
"get", response.operation_location,
47+
until=lambda response: response.data['status'] in [ 'succeeded' ],
48+
failed=lambda response: response.data['status'] in [ 'failed' ]
49+
)
50+
4051
return util.convert_to_openai_object(
4152
response, api_key, api_version, organization
4253
)
@@ -60,12 +71,20 @@ async def acreate(
6071
organization=organization,
6172
)
6273

63-
_, api_version = cls._get_api_type_and_version(api_type, api_version)
74+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
6475

6576
response, _, api_key = await requestor.arequest(
66-
"post", cls._get_url("generations"), params
77+
"post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params
6778
)
6879

80+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
81+
requestor.api_base = "" # operation_location is a full url
82+
response, _, api_key = await requestor._apoll(
83+
"get", response.operation_location,
84+
until=lambda response: response.data['status'] in [ 'succeeded' ],
85+
failed=lambda response: response.data['status'] in [ 'failed' ]
86+
)
87+
6988
return util.convert_to_openai_object(
7089
response, api_key, api_version, organization
7190
)
@@ -88,9 +107,9 @@ def _prepare_create_variation(
88107
api_version=api_version,
89108
organization=organization,
90109
)
91-
_, api_version = cls._get_api_type_and_version(api_type, api_version)
110+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
92111

93-
url = cls._get_url("variations")
112+
url = cls._get_url("variations", azure_action=None, api_type=api_type, api_version=api_version)
94113

95114
files: List[Any] = []
96115
for key, value in params.items():
@@ -109,6 +128,9 @@ def create_variation(
109128
organization=None,
110129
**params,
111130
):
131+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
132+
raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.")
133+
112134
requestor, url, files = cls._prepare_create_variation(
113135
image,
114136
api_key,
@@ -136,6 +158,9 @@ async def acreate_variation(
136158
organization=None,
137159
**params,
138160
):
161+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
162+
raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.")
163+
139164
requestor, url, files = cls._prepare_create_variation(
140165
image,
141166
api_key,
@@ -171,9 +196,9 @@ def _prepare_create_edit(
171196
api_version=api_version,
172197
organization=organization,
173198
)
174-
_, api_version = cls._get_api_type_and_version(api_type, api_version)
199+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
175200

176-
url = cls._get_url("edits")
201+
url = cls._get_url("edits", azure_action=None, api_type=api_type, api_version=api_version)
177202

178203
files: List[Any] = []
179204
for key, value in params.items():
@@ -195,6 +220,9 @@ def create_edit(
195220
organization=None,
196221
**params,
197222
):
223+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
224+
raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.")
225+
198226
requestor, url, files = cls._prepare_create_edit(
199227
image,
200228
mask,
@@ -224,6 +252,9 @@ async def acreate_edit(
224252
organization=None,
225253
**params,
226254
):
255+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
256+
raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.")
257+
227258
requestor, url, files = cls._prepare_create_edit(
228259
image,
229260
mask,

openai/openai_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def __repr__(self):
278278

279279
def __str__(self):
280280
obj = self.to_dict_recursive()
281-
return json.dumps(obj, sort_keys=True, indent=2)
281+
return json.dumps(obj, indent=2)
282282

283283
def to_dict(self):
284284
return dict(self)

openai/openai_response.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ def __init__(self, data, headers):
1010
def request_id(self) -> Optional[str]:
1111
return self._headers.get("request-id")
1212

13+
@property
14+
def retry_after(self) -> Optional[int]:
15+
try:
16+
return int(self._headers.get("retry-after"))
17+
except TypeError:
18+
return None
19+
20+
@property
21+
def operation_location(self) -> Optional[str]:
22+
return self._headers.get("operation-location")
23+
1324
@property
1425
def organization(self) -> Optional[str]:
1526
return self._headers.get("OpenAI-Organization")

0 commit comments

Comments
 (0)