Merged
74 changes: 64 additions & 10 deletions codex-rs/core/src/chat_completions.rs
@@ -1,3 +1,4 @@
+use std::sync::Arc;
use std::time::Duration;

use bytes::Bytes;
@@ -25,8 +26,8 @@ use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::models::ContentItem;
use crate::models::ResponseItem;
-use crate::util::backoff;
use crate::util::UrlExt;
+use crate::util::backoff;

/// Implementation for the classic Chat Completions API. This is intentionally
/// minimal: we only stream back plain assistant text.
@@ -35,6 +36,7 @@ pub(crate) async fn stream_chat_completions(
model: &str,
client: &reqwest::Client,
provider: &ModelProviderInfo,
+token_aggregator: Arc<std::sync::Mutex<crate::client_common::TokenAggregator>>,
) -> Result<ResponseStream> {
// Build messages array
let mut messages = Vec::<serde_json::Value>::new();
@@ -60,10 +62,15 @@
let payload = json!({
"model": model,
"messages": messages,
"stream": true
"stream": true,
"stream_options": {"include_usage": true}
});

-let url = provider.base_url.clone().append_path("/chat/completions")?.to_string();
+let url = provider
+    .base_url
+    .clone()
+    .append_path("/chat/completions")?
+    .to_string();

debug!("{} POST (chat)", &url);
trace!("request payload: {}", payload);
@@ -87,7 +94,11 @@
Ok(resp) if resp.status().is_success() => {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
-tokio::spawn(process_chat_sse(stream, tx_event));
+tokio::spawn(process_chat_sse(
+    stream,
+    tx_event,
+    Arc::clone(&token_aggregator),
+));
return Ok(ResponseStream { rx_event });
}
Ok(res) => {
@@ -126,14 +137,21 @@ pub(crate) async fn stream_chat_completions(
/// Lightweight SSE processor for the Chat Completions streaming format. The
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
/// of the pipeline can stay agnostic of the underlying wire format.
-async fn process_chat_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
-where
+async fn process_chat_sse<S>(
+    stream: S,
+    tx_event: mpsc::Sender<Result<ResponseEvent>>,
+    token_aggregator: Arc<std::sync::Mutex<crate::client_common::TokenAggregator>>,
+) where
S: Stream<Item = Result<Bytes>> + Unpin,
{
let mut stream = stream.eventsource();

let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;

+// Track usage information to include in final completion event
+let mut input_tokens = None;
+let mut output_tokens = None;

loop {
let sse = match timeout(idle_timeout, stream.next()).await {
Ok(Some(Ok(ev))) => ev,
@@ -146,6 +164,8 @@
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
+    input_tokens,
+    output_tokens,
}))
.await;
return;
@@ -163,6 +183,8 @@
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
+    input_tokens,
+    output_tokens,
}))
.await;
return;
@@ -174,6 +196,30 @@
Err(_) => continue,
};

+// Store usage statistics when received.
+// Chat Completions API uses "prompt_tokens" and "completion_tokens"
+if let Some(usage) = chunk.get("usage") {
+    let usage_input_tokens = usage
+        .get("prompt_tokens")
+        .and_then(|v| v.as_u64())
+        .map(|v| v as u32)
+        .unwrap_or(0);
+    let usage_output_tokens = usage
+        .get("completion_tokens")
+        .and_then(|v| v.as_u64())
+        .map(|v| v as u32)
+        .unwrap_or(0);
+
+    // Add to session aggregator
+    token_aggregator
+        .lock()
+        .unwrap()
+        .add_token_usage(usage_input_tokens, usage_output_tokens);
+
+    input_tokens = Some(usage_input_tokens);
+    output_tokens = Some(usage_output_tokens);
+}

let content_opt = chunk
.get("choices")
.and_then(|c| c.get(0))
@@ -251,7 +297,7 @@ where
// Swallow partial event; keep polling.
continue;
}
-Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
+Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id, .. }))) => {
if !this.cumulative.is_empty() {
let aggregated_item = crate::models::ResponseItem::Message {
role: "assistant".to_string(),
@@ -261,16 +307,24 @@ };
};

// Buffer Completed so it is returned *after* the aggregated message.
-this.pending_completed = Some(ResponseEvent::Completed { response_id });
+this.pending_completed = Some(ResponseEvent::Completed {
+    response_id,
+    input_tokens: None,
+    output_tokens: None,
+});

return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
aggregated_item,
))));
}

// Nothing aggregated – forward Completed directly.
-return Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id })));
-} // No other `Ok` variants exist at the moment, continue polling.
+return Poll::Ready(Some(Ok(ResponseEvent::Completed {
+    response_id,
+    input_tokens: None,
+    output_tokens: None,
+})));
+}
}
}
}
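Side note on the wire format: with "stream_options": {"include_usage": true} set in the request payload above, the Chat Completions stream ends with a chunk carrying a usage object, and process_chat_sse reads prompt_tokens and completion_tokens from it. A self-contained sketch of the same extraction, with an illustrative chunk (the counts are made up, not taken from this PR):

use serde_json::json;

fn main() {
    // Shape of the final streamed chunk when include_usage is enabled
    // (the concrete counts here are illustrative).
    let chunk = json!({
        "choices": [],
        "usage": { "prompt_tokens": 57, "completion_tokens": 17 }
    });

    // Same defensive extraction as process_chat_sse above: a missing or
    // non-numeric field falls back to 0 rather than failing the stream.
    if let Some(usage) = chunk.get("usage") {
        let input = usage
            .get("prompt_tokens")
            .and_then(|v| v.as_u64())
            .map(|v| v as u32)
            .unwrap_or(0);
        let output = usage
            .get("completion_tokens")
            .and_then(|v| v.as_u64())
            .map(|v| v as u32)
            .unwrap_or(0);
        assert_eq!((input, output), (57, 17));
    }
}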
89 changes: 79 additions & 10 deletions codex-rs/core/src/client.rs
@@ -1,6 +1,7 @@
use std::collections::BTreeMap;
use std::io::BufRead;
use std::path::Path;
+use std::sync::Arc;
use std::sync::LazyLock;
use std::time::Duration;

@@ -27,6 +28,7 @@ use crate::client_common::Reasoning;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::client_common::Summary;
+use crate::client_common::TokenAggregator;
use crate::error::CodexErr;
use crate::error::EnvVarError;
use crate::error::Result;
@@ -36,8 +38,8 @@ use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::models::ResponseItem;
-use crate::util::backoff;
use crate::util::UrlExt;
+use crate::util::backoff;

/// When serialized as JSON, this produces a valid "Tool" in the OpenAI
/// Responses API.
@@ -107,6 +109,7 @@ pub struct ModelClient {
model: String,
client: reqwest::Client,
provider: ModelProviderInfo,
+token_aggregator: Arc<std::sync::Mutex<TokenAggregator>>,
}

impl ModelClient {
@@ -115,9 +118,28 @@ impl ModelClient {
model: model.to_string(),
client: reqwest::Client::new(),
provider,
+token_aggregator: Arc::new(std::sync::Mutex::new(TokenAggregator::new())),
}
}

+/// Returns the model name used for this client. Helper so callers inside
+/// the business-logic layer (e.g. for pricing calculations) do not need
+/// to reach into private fields.
+pub fn model_name(&self) -> &str {
+    &self.model
+}
+
+/// Expose the provider so higher-level modules (e.g. cost accounting) can
+/// inspect metadata without breaking encapsulation.
+pub fn provider(&self) -> &ModelProviderInfo {
+    &self.provider
+}
+
+/// Get cumulative token usage for this session
+pub fn get_session_token_usage(&self) -> (u32, u32) {
+    self.token_aggregator.lock().unwrap().get_token_totals()
+}

/// Dispatches to either the Responses or Chat implementation depending on
/// the provider config. Public callers always invoke `stream()` – the
/// specialised helpers are private to avoid accidental misuse.
@@ -126,9 +148,14 @@
WireApi::Responses => self.stream_responses(prompt).await,
WireApi::Chat => {
// Create the raw streaming connection first.
-let response_stream =
-    stream_chat_completions(prompt, &self.model, &self.client, &self.provider)
-        .await?;
+let response_stream = stream_chat_completions(
+    prompt,
+    &self.model,
+    &self.client,
+    &self.provider,
+    Arc::clone(&self.token_aggregator),
+)
+.await?;

// Wrap it with the aggregation adapter so callers see *only*
// the final assistant message per turn (matching the
@@ -199,7 +226,12 @@ impl ModelClient {
stream: true,
};

-let url = self.provider.base_url.clone().append_path("/responses")?.to_string();
+let url = self
+    .provider
+    .base_url
+    .clone()
+    .append_path("/responses")?
+    .to_string();

debug!("{} POST", url);
trace!("request payload: {}", serde_json::to_string(&payload)?);
@@ -237,7 +269,11 @@ impl ModelClient {

// spawn task to process SSE
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
-tokio::spawn(process_sse(stream, tx_event));
+tokio::spawn(process_sse(
+    stream,
+    tx_event,
+    Arc::clone(&self.token_aggregator),
+));

return Ok(ResponseStream { rx_event });
}
@@ -311,8 +347,11 @@ struct ResponseCompleted {
id: String,
}

-async fn process_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
-where
+async fn process_sse<S>(
+    stream: S,
+    tx_event: mpsc::Sender<Result<ResponseEvent>>,
+    token_aggregator: Arc<std::sync::Mutex<TokenAggregator>>,
+) where
S: Stream<Item = Result<Bytes>> + Unpin,
{
let mut stream = stream.eventsource();
@@ -321,6 +360,9 @@ where
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
// The response id returned from the "complete" message.
let mut response_id = None;
+// Token usage information
+let mut input_tokens = None;
+let mut output_tokens = None;

loop {
let sse = match timeout(idle_timeout, stream.next()).await {
@@ -334,7 +376,11 @@ where
Ok(None) => {
match response_id {
Some(response_id) => {
-let event = ResponseEvent::Completed { response_id };
+let event = ResponseEvent::Completed {
+    response_id,
+    input_tokens,
+    output_tokens,
+};
let _ = tx_event.send(Ok(event)).await;
}
None => {
@@ -398,6 +444,28 @@ where
// Final response completed – includes array of output items & id
"response.completed" => {
if let Some(resp_val) = event.response {
+// Extract usage if present (Responses API uses input_tokens/output_tokens)
+if let Some(usage) = resp_val.get("usage") {
+    let usage_input_tokens = usage
+        .get("input_tokens")
+        .and_then(|v| v.as_u64())
+        .map(|v| v as u32)
+        .unwrap_or(0);
+    let usage_output_tokens = usage
+        .get("output_tokens")
+        .and_then(|v| v.as_u64())
+        .map(|v| v as u32)
+        .unwrap_or(0);
+
+    token_aggregator
+        .lock()
+        .unwrap()
+        .add_token_usage(usage_input_tokens, usage_output_tokens);
+
+    input_tokens = Some(usage_input_tokens);
+    output_tokens = Some(usage_output_tokens);
+}
+
match serde_json::from_value::<ResponseCompleted>(resp_val) {
Ok(r) => {
response_id = Some(r.id);
@@ -429,6 +497,7 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {

let rdr = std::io::Cursor::new(content);
let stream = ReaderStream::new(rdr).map_err(CodexErr::Io);
-tokio::spawn(process_sse(stream, tx_event));
+let dummy_aggregator = Arc::new(std::sync::Mutex::new(TokenAggregator::new()));
+tokio::spawn(process_sse(stream, tx_event, dummy_aggregator));
Ok(ResponseStream { rx_event })
}
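With the aggregator threaded through both wire formats, any caller holding the ModelClient can read cumulative session totals at any point. A hedged sketch of a possible call site (the report_session_usage helper and the codex_core::client import path are assumptions for illustration, not part of this PR):

use codex_core::client::ModelClient; // assumed public path; adjust to the crate layout

// Prints the running totals accumulated across every completed
// request this client has streamed so far.
fn report_session_usage(client: &ModelClient) {
    let (input_tokens, output_tokens) = client.get_session_token_usage();
    println!(
        "{}: {input_tokens} input / {output_tokens} output tokens this session",
        client.model_name(),
    );
}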
27 changes: 26 additions & 1 deletion codex-rs/core/src/client_common.rs
@@ -47,7 +47,32 @@ impl Prompt {
#[derive(Debug)]
pub enum ResponseEvent {
OutputItemDone(ResponseItem),
-Completed { response_id: String },
+Completed {
+    response_id: String,
+    input_tokens: Option<u32>,
+    output_tokens: Option<u32>,
+},
}

+#[derive(Debug, Default)]
+pub struct TokenAggregator {
+    total_input_tokens: u32,
+    total_output_tokens: u32,
+}
+
+impl TokenAggregator {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn add_token_usage(&mut self, input_tokens: u32, output_tokens: u32) {
+        self.total_input_tokens += input_tokens;
+        self.total_output_tokens += output_tokens;
+    }
+
+    pub fn get_token_totals(&self) -> (u32, u32) {
+        (self.total_input_tokens, self.total_output_tokens)
+    }
+}

#[derive(Debug, Serialize)]
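The aggregator itself is plain additive state shared behind an Arc<Mutex<...>>, matching how both SSE processors feed it. A minimal self-contained sketch of the intended behavior, assuming the TokenAggregator defined above is in scope (the token counts are made up):

use std::sync::{Arc, Mutex};

fn main() {
    let aggregator = Arc::new(Mutex::new(TokenAggregator::new()));

    // Two completed requests in one session.
    aggregator.lock().unwrap().add_token_usage(120, 40);
    aggregator.lock().unwrap().add_token_usage(80, 25);

    // Totals accumulate across the whole session.
    assert_eq!(aggregator.lock().unwrap().get_token_totals(), (200, 65));
}

Note that the totals are plain u32 sums via +=, so this sketch assumes session sizes stay well below the u32 ceiling.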