Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 8b7d862

Browse files
authored
[5/n] OpenAI realtime transport impl (openai#1072)
Uses the openai realtime impl over websockets. Unlike the TS version, only supports websockets - no in browser stuff. --- [//]: # (BEGIN SAPLING FOOTER) * openai#1074 * openai#1073 * __->__ openai#1072 * openai#1071
1 parent 9d32577 commit 8b7d862

File tree

2 files changed

+354
-0
lines changed

2 files changed

+354
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Repository = "https://github.com/openai/openai-agents-python"
3737
voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
3838
viz = ["graphviz>=0.17"]
3939
litellm = ["litellm>=1.67.4.post1, <2"]
40+
realtime = ["websockets>=15.0, <16"]
4041

4142
[dependency-groups]
4243
dev = [
Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
import asyncio
2+
import base64
3+
import json
4+
import os
5+
from datetime import datetime
6+
from typing import Any
7+
8+
import websockets
9+
from openai.types.beta.realtime.realtime_server_event import (
10+
RealtimeServerEvent as OpenAIRealtimeServerEvent,
11+
)
12+
from pydantic import TypeAdapter
13+
from websockets.asyncio.client import ClientConnection
14+
15+
from ..exceptions import UserError
16+
from ..logger import logger
17+
from .config import RealtimeClientMessage, RealtimeUserInput, get_api_key
18+
from .items import RealtimeMessageItem, RealtimeToolCallItem
19+
from .transport import (
20+
RealtimeSessionTransport,
21+
RealtimeTransportConnectionOptions,
22+
RealtimeTransportListener,
23+
)
24+
from .transport_events import (
25+
RealtimeTransportAudioDoneEvent,
26+
RealtimeTransportAudioEvent,
27+
RealtimeTransportAudioInterruptedEvent,
28+
RealtimeTransportErrorEvent,
29+
RealtimeTransportEvent,
30+
RealtimeTransportInputAudioTranscriptionCompletedEvent,
31+
RealtimeTransportItemDeletedEvent,
32+
RealtimeTransportItemUpdatedEvent,
33+
RealtimeTransportToolCallEvent,
34+
RealtimeTransportTranscriptDelta,
35+
RealtimeTransportTurnEndedEvent,
36+
RealtimeTransportTurnStartedEvent,
37+
)
38+
39+
40+
class OpenAIRealtimeWebSocketTransport(RealtimeSessionTransport):
    """A transport layer for realtime sessions that uses OpenAI's WebSocket API.

    The transport owns one WebSocket connection and a background task that reads
    server events from it, converts them into ``RealtimeTransport*`` events and
    forwards them to every registered listener. It also remembers which audio item
    is currently being streamed so that ``interrupt`` can cancel and truncate it.
    """

    def __init__(self) -> None:
        """Create a disconnected transport with the default model selected."""
        self.model = "gpt-4o-realtime-preview"  # Default model
        self._websocket: ClientConnection | None = None
        self._websocket_task: asyncio.Task[None] | None = None
        self._listeners: list[RealtimeTransportListener] = []
        # Playback bookkeeping for the audio item currently being streamed.
        self._current_item_id: str | None = None
        self._audio_start_time: datetime | None = None
        self._audio_length_ms: float = 0.0
        self._ongoing_response: bool = False
        self._current_audio_content_index: int | None = None
        # Building a TypeAdapter is costly, so build each one once per transport instead
        # of once per incoming server event.
        self._server_event_adapter: TypeAdapter[Any] = TypeAdapter(OpenAIRealtimeServerEvent)
        self._message_item_adapter: TypeAdapter[Any] = TypeAdapter(RealtimeMessageItem)

    async def connect(self, options: RealtimeTransportConnectionOptions) -> None:
        """Establish a connection to the model and keep it alive.

        Args:
            options: Connection options. The ``model``, ``api_key`` and ``url`` keys are
                read; the API key falls back to the ``OPENAI_API_KEY`` environment variable
                and the URL falls back to the OpenAI realtime endpoint for the model.

        Raises:
            UserError: If no API key could be resolved.
        """
        assert self._websocket is None, "Already connected"
        assert self._websocket_task is None, "Already connected"

        self.model = options.get("model", self.model)
        api_key = await get_api_key(options.get("api_key", os.getenv("OPENAI_API_KEY")))

        if not api_key:
            raise UserError("API key is required but was not provided.")

        url = options.get("url", f"wss://api.openai.com/v1/realtime?model={self.model}")

        headers = {
            "Authorization": f"Bearer {api_key}",
            "OpenAI-Beta": "realtime=v1",
        }
        self._websocket = await websockets.connect(url, additional_headers=headers)
        # The reader task runs until the connection closes or close() cancels it.
        self._websocket_task = asyncio.create_task(self._listen_for_messages())

    def add_listener(self, listener: RealtimeTransportListener) -> None:
        """Add a listener to the transport."""
        self._listeners.append(listener)

    async def remove_listener(self, listener: RealtimeTransportListener) -> None:
        """Remove a listener from the transport.

        Raises:
            ValueError: If the listener was never added.
        """
        self._listeners.remove(listener)

    async def _emit_event(self, event: RealtimeTransportEvent) -> None:
        """Emit an event to the listeners, awaiting each one in registration order."""
        # Iterate over a snapshot: a listener may add or remove listeners while handling
        # the event, and mutating the list mid-iteration would skip entries.
        for listener in list(self._listeners):
            await listener.on_event(event)

    async def _listen_for_messages(self):
        """Read messages from the WebSocket until it closes, dispatching each one."""
        assert self._websocket is not None, "Not connected"

        try:
            async for message in self._websocket:
                try:
                    parsed = json.loads(message)
                    await self._handle_ws_event(parsed)
                except websockets.exceptions.ConnectionClosed:
                    # A send performed while handling the event hit a closed socket;
                    # let the outer handler deal with it.
                    raise
                except Exception as e:
                    # One malformed message or failing listener must not stop the
                    # reader for the rest of the session.
                    logger.error(f"Error handling WebSocket message: {e}")

        except websockets.exceptions.ConnectionClosed:
            # TODO connection closed handling (event, cleanup)
            logger.warning("WebSocket connection closed")
        except Exception as e:
            logger.error(f"WebSocket error: {e}")

    async def send_event(self, event: RealtimeClientMessage) -> None:
        """Send an event to the model.

        The wire payload is the event's ``type`` merged with everything under its optional
        ``other_data`` key, serialized as JSON.
        """
        assert self._websocket is not None, "Not connected"
        converted_event = {
            "type": event["type"],
        }

        converted_event.update(event.get("other_data", {}))

        await self._websocket.send(json.dumps(converted_event))

    async def send_message(
        self, message: RealtimeUserInput, other_event_data: dict[str, Any] | None = None
    ) -> None:
        """Send a message to the model and ask it to respond.

        Args:
            message: Either a ready-made item dict, which is sent as is, or any other value,
                which is wrapped as the text of a user message item.
            other_event_data: Extra top-level fields merged into the
                ``conversation.item.create`` event.
        """
        message = (
            message
            if isinstance(message, dict)
            else {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": message}],
            }
        )
        other_data = {
            "item": message,
        }
        if other_event_data:
            other_data.update(other_event_data)

        await self.send_event({"type": "conversation.item.create", "other_data": other_data})

        await self.send_event({"type": "response.create"})

    async def send_audio(self, audio: bytes, *, commit: bool = False) -> None:
        """Send a raw audio chunk to the model.

        Args:
            audio: The audio data to send.
            commit: Whether to commit the audio buffer to the model. If the model does not do turn
                detection, this can be used to indicate the turn is completed.
        """
        assert self._websocket is not None, "Not connected"
        base64_audio = base64.b64encode(audio).decode("utf-8")
        await self.send_event(
            {
                "type": "input_audio_buffer.append",
                "other_data": {
                    "audio": base64_audio,
                },
            }
        )
        if commit:
            await self.send_event({"type": "input_audio_buffer.commit"})

    async def send_tool_output(
        self, tool_call: RealtimeTransportToolCallEvent, output: str, start_response: bool
    ) -> None:
        """Send tool output to the model.

        Args:
            tool_call: The tool call event being answered.
            output: The tool's result.
            start_response: Whether to request a new model response afterwards.
        """
        await self.send_event(
            {
                "type": "conversation.item.create",
                "other_data": {
                    "item": {
                        "type": "function_call_output",
                        "output": output,
                        # The output must reference the function call's call_id; the item
                        # id is only a fallback for events that carry no call_id.
                        "call_id": tool_call.call_id or tool_call.id,
                    },
                },
            }
        )

        tool_item = RealtimeToolCallItem(
            item_id=tool_call.id or "",
            previous_item_id=tool_call.previous_item_id,
            type="function_call",
            status="completed",
            arguments=tool_call.arguments,
            name=tool_call.name,
            output=output,
        )
        await self._emit_event(RealtimeTransportItemUpdatedEvent(item=tool_item))

        if start_response:
            await self.send_event({"type": "response.create"})

    async def interrupt(self) -> None:
        """Interrupt the model.

        Does nothing unless an audio item is currently being tracked. Otherwise cancels any
        ongoing response and, if the audio received so far has not finished playing by
        wall-clock time, notifies listeners and truncates the item at the elapsed position.
        """
        if not self._current_item_id or not self._audio_start_time:
            return

        await self._cancel_response()

        elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000
        if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms:
            await self._emit_event(RealtimeTransportAudioInterruptedEvent())
            await self.send_event(
                {
                    "type": "conversation.item.truncate",
                    "other_data": {
                        "item_id": self._current_item_id,
                        "content_index": self._current_audio_content_index,
                        # The API expects whole milliseconds.
                        "audio_end_ms": int(elapsed_time_ms),
                    },
                }
            )

        self._reset_audio_tracking()

    async def close(self) -> None:
        """Close the session, stop the reader task and clear per-connection state."""
        websocket, self._websocket = self._websocket, None
        task, self._websocket_task = self._websocket_task, None
        try:
            if websocket:
                await websocket.close()
        finally:
            # Cancel the reader even when closing the socket fails, so it cannot leak.
            if task:
                task.cancel()
            self._ongoing_response = False
            self._reset_audio_tracking()

    def _reset_audio_tracking(self) -> None:
        """Forget the audio item that was being streamed."""
        self._current_item_id = None
        self._audio_start_time = None
        self._audio_length_ms = 0.0
        self._current_audio_content_index = None

    async def _cancel_response(self) -> None:
        """Ask the model to cancel the in-flight response, if there is one."""
        if self._ongoing_response:
            await self.send_event({"type": "response.cancel"})
            self._ongoing_response = False

    async def _handle_ws_event(self, event: dict[str, Any]):
        """Validate one decoded server event and translate it into transport events.

        Events that fail validation are logged and reported to listeners as an error event.
        Event types without a branch below are ignored.
        """
        try:
            parsed: OpenAIRealtimeServerEvent = self._server_event_adapter.validate_python(event)
        except Exception as e:
            logger.error(f"Invalid event: {event} - {e}")
            await self._emit_event(RealtimeTransportErrorEvent(error=f"Invalid event: {event}"))
            return

        if parsed.type == "response.audio.delta":
            self._current_audio_content_index = parsed.content_index
            self._current_item_id = parsed.item_id
            if self._audio_start_time is None:
                # First chunk of a new audio item: start the playback clock.
                self._audio_start_time = datetime.now()
                self._audio_length_ms = 0.0

            audio_bytes = base64.b64decode(parsed.delta)
            # Calculate audio length in ms using 24KHz pcm16le
            # (24 samples per ms, 2 bytes per sample).
            self._audio_length_ms += len(audio_bytes) / 24 / 2
            await self._emit_event(
                RealtimeTransportAudioEvent(data=audio_bytes, response_id=parsed.response_id)
            )
        elif parsed.type == "response.audio.done":
            await self._emit_event(RealtimeTransportAudioDoneEvent())
        elif parsed.type == "input_audio_buffer.speech_started":
            # The user started talking over the model.
            await self.interrupt()
        elif parsed.type == "response.created":
            self._ongoing_response = True
            await self._emit_event(RealtimeTransportTurnStartedEvent())
        elif parsed.type == "response.done":
            self._ongoing_response = False
            await self._emit_event(RealtimeTransportTurnEndedEvent())
        elif parsed.type == "session.created":
            # TODO (rm) tracing stuff here
            pass
        elif parsed.type == "error":
            await self._emit_event(RealtimeTransportErrorEvent(error=parsed.error))
        elif parsed.type == "conversation.item.deleted":
            await self._emit_event(RealtimeTransportItemDeletedEvent(item_id=parsed.item_id))
        elif (
            parsed.type == "conversation.item.created"
            or parsed.type == "conversation.item.retrieved"
        ):
            item = parsed.item
            # Only "created" events carry the id of the preceding item.
            previous_item_id = (
                parsed.previous_item_id if parsed.type == "conversation.item.created" else None
            )
            message_item: RealtimeMessageItem = self._message_item_adapter.validate_python(
                {
                    "item_id": item.id or "",
                    "previous_item_id": previous_item_id,
                    "type": item.type,
                    "role": item.role,
                    "content": item.content,
                    "status": "in_progress",
                }
            )
            await self._emit_event(RealtimeTransportItemUpdatedEvent(item=message_item))
        elif (
            parsed.type == "conversation.item.input_audio_transcription.completed"
            or parsed.type == "conversation.item.truncated"
        ):
            # Re-fetch the item the event refers to; the server answers with
            # "conversation.item.retrieved", handled above. The id must come from the
            # event itself: the tracked audio item id is cleared by interrupt() before
            # the "truncated" acknowledgement arrives, and a transcription belongs to a
            # user item rather than the tracked assistant audio item.
            await self.send_event(
                {
                    "type": "conversation.item.retrieve",
                    "other_data": {
                        "item_id": parsed.item_id,
                    },
                }
            )
            if parsed.type == "conversation.item.input_audio_transcription.completed":
                await self._emit_event(
                    RealtimeTransportInputAudioTranscriptionCompletedEvent(
                        item_id=parsed.item_id, transcript=parsed.transcript
                    )
                )
        elif parsed.type == "response.audio_transcript.delta":
            await self._emit_event(
                RealtimeTransportTranscriptDelta(
                    item_id=parsed.item_id, delta=parsed.delta, response_id=parsed.response_id
                )
            )
        elif (
            parsed.type == "conversation.item.input_audio_transcription.delta"
            or parsed.type == "response.text.delta"
            or parsed.type == "response.function_call_arguments.delta"
        ):
            # No support for partials yet
            pass
        elif (
            parsed.type == "response.output_item.added"
            or parsed.type == "response.output_item.done"
        ):
            item = parsed.item
            if item.type == "function_call" and item.status == "completed":
                tool_call = RealtimeToolCallItem(
                    item_id=item.id or "",
                    previous_item_id=None,
                    type="function_call",
                    # We use the same item for tool call and output, so it will be completed by the
                    # output being added
                    status="in_progress",
                    arguments=item.arguments or "",
                    name=item.name or "",
                    output=None,
                )
                await self._emit_event(RealtimeTransportItemUpdatedEvent(item=tool_call))
                await self._emit_event(
                    RealtimeTransportToolCallEvent(
                        # The function call's own call_id is what the tool output has to
                        # reference; fall back to the item id if the server sent none.
                        call_id=item.call_id or item.id or "",
                        name=item.name or "",
                        arguments=item.arguments or "",
                        id=item.id or "",
                    )
                )
            elif item.type == "message":
                message_item = self._message_item_adapter.validate_python(
                    {
                        "item_id": item.id or "",
                        "type": item.type,
                        "role": item.role,
                        "content": item.content,
                        "status": "in_progress",
                    }
                )
                await self._emit_event(RealtimeTransportItemUpdatedEvent(item=message_item))

0 commit comments

Comments
 (0)