diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index 52affa46dcf..4170c033fba 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -1,3 +1,4 @@
+use std::sync::Arc;
 use std::time::Duration;
 
 use bytes::Bytes;
@@ -25,8 +26,8 @@ use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
 use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
 use crate::models::ContentItem;
 use crate::models::ResponseItem;
-use crate::util::backoff;
 use crate::util::UrlExt;
+use crate::util::backoff;
 
 /// Implementation for the classic Chat Completions API. This is intentionally
 /// minimal: we only stream back plain assistant text.
@@ -35,6 +36,7 @@ pub(crate) async fn stream_chat_completions(
     model: &str,
     client: &reqwest::Client,
     provider: &ModelProviderInfo,
+    token_aggregator: Arc<std::sync::Mutex<TokenAggregator>>,
 ) -> Result<ResponseStream> {
     // Build messages array
     let mut messages = Vec::<serde_json::Value>::new();
@@ -60,10 +62,15 @@ pub(crate) async fn stream_chat_completions(
     let payload = json!({
         "model": model,
         "messages": messages,
-        "stream": true
+        "stream": true,
+        "stream_options": {"include_usage": true}
     });
 
-    let url = provider.base_url.clone().append_path("/chat/completions")?.to_string();
+    let url = provider
+        .base_url
+        .clone()
+        .append_path("/chat/completions")?
+        .to_string();
     debug!("{} POST (chat)", &url);
     trace!("request payload: {}", payload);
 
@@ -87,7 +94,11 @@ pub(crate) async fn stream_chat_completions(
             Ok(resp) if resp.status().is_success() => {
                 let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
                 let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
-                tokio::spawn(process_chat_sse(stream, tx_event));
+                tokio::spawn(process_chat_sse(
+                    stream,
+                    tx_event,
+                    Arc::clone(&token_aggregator),
+                ));
                 return Ok(ResponseStream { rx_event });
             }
             Ok(res) => {
@@ -126,14 +137,21 @@ pub(crate) async fn stream_chat_completions(
 /// Lightweight SSE processor for the Chat Completions streaming format. The
 /// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
 /// of the pipeline can stay agnostic of the underlying wire format.
-async fn process_chat_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
-where
+async fn process_chat_sse<S>(
+    stream: S,
+    tx_event: mpsc::Sender<Result<ResponseEvent>>,
+    token_aggregator: Arc<std::sync::Mutex<TokenAggregator>>,
+) where
     S: Stream<Item = Result<Bytes>> + Unpin,
 {
     let mut stream = stream.eventsource();
 
     let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
 
+    // Track usage information to include in the final completion event.
+    let mut input_tokens = None;
+    let mut output_tokens = None;
+
     loop {
         let sse = match timeout(idle_timeout, stream.next()).await {
             Ok(Some(Ok(ev))) => ev,
@@ -146,6 +164,8 @@ where
                 let _ = tx_event
                     .send(Ok(ResponseEvent::Completed {
                         response_id: String::new(),
+                        input_tokens,
+                        output_tokens,
                     }))
                     .await;
                 return;
@@ -163,6 +183,8 @@ where
                 let _ = tx_event
                     .send(Ok(ResponseEvent::Completed {
                         response_id: String::new(),
+                        input_tokens,
+                        output_tokens,
                     }))
                     .await;
                 return;
@@ -174,6 +196,30 @@ where
             Err(_) => continue,
         };
 
+        // Store usage statistics when received.
+        // The Chat Completions API reports "prompt_tokens" and "completion_tokens".
+        if let Some(usage) = chunk.get("usage") {
+            let usage_input_tokens = usage
+                .get("prompt_tokens")
+                .and_then(|v| v.as_u64())
+                .map(|v| v as u32)
+                .unwrap_or(0);
+            let usage_output_tokens = usage
+                .get("completion_tokens")
+                .and_then(|v| v.as_u64())
+                .map(|v| v as u32)
+                .unwrap_or(0);
+
+            // Add to the session aggregator.
+            token_aggregator
+                .lock()
+                .unwrap()
+                .add_token_usage(usage_input_tokens, usage_output_tokens);
+
+            input_tokens = Some(usage_input_tokens);
+            output_tokens = Some(usage_output_tokens);
+        }
+
         let content_opt = chunk
             .get("choices")
             .and_then(|c| c.get(0))
@@ -251,7 +297,7 @@ where
                     // Swallow partial event; keep polling.
                     continue;
                 }
-                Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
+                Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id, .. }))) => {
                     if !this.cumulative.is_empty() {
                         let aggregated_item = crate::models::ResponseItem::Message {
                             role: "assistant".to_string(),
@@ -261,7 +307,11 @@ where
                         };
 
                         // Buffer Completed so it is returned *after* the aggregated message.
-                        this.pending_completed = Some(ResponseEvent::Completed { response_id });
+                        this.pending_completed = Some(ResponseEvent::Completed {
+                            response_id,
+                            input_tokens: None,
+                            output_tokens: None,
+                        });
 
                         return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
                             aggregated_item,
@@ -269,8 +319,12 @@ where
                 }
 
                 // Nothing aggregated – forward Completed directly.
-                return Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id })));
-            } // No other `Ok` variants exist at the moment, continue polling.
+                return Poll::Ready(Some(Ok(ResponseEvent::Completed {
+                    response_id,
+                    input_tokens: None,
+                    output_tokens: None,
+                })));
+            }
         }
     }
 }
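With `"stream_options": {"include_usage": true}` set above, the Chat Completions stream ends with one extra chunk that carries a `usage` object (and an empty `choices` array). A minimal sketch of what `process_chat_sse` expects from that chunk — the chunk shape and id here are illustrative assumptions, not captured wire traffic:

```rust
use serde_json::json;

fn main() {
    // Hypothetical final SSE chunk; only the `usage` field matters to the parser.
    let final_chunk = json!({
        "id": "chatcmpl-example",           // made-up id
        "object": "chat.completion.chunk",
        "choices": [],                      // the usage chunk carries no delta
        "usage": {
            "prompt_tokens": 1000,
            "completion_tokens": 500,
            "total_tokens": 1500
        }
    });

    // Mirrors the field access in process_chat_sse above.
    if let Some(usage) = final_chunk.get("usage") {
        let input = usage.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
        let output = usage.get("completion_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
        assert_eq!((input, output), (1000, 500));
    }
}
```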
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index fd5176dbb88..0f3e1414a64 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -1,6 +1,7 @@
 use std::collections::BTreeMap;
 use std::io::BufRead;
 use std::path::Path;
+use std::sync::Arc;
 use std::sync::LazyLock;
 use std::time::Duration;
 
@@ -27,6 +28,7 @@ use crate::client_common::Reasoning;
 use crate::client_common::ResponseEvent;
 use crate::client_common::ResponseStream;
 use crate::client_common::Summary;
+use crate::client_common::TokenAggregator;
 use crate::error::CodexErr;
 use crate::error::EnvVarError;
 use crate::error::Result;
@@ -36,8 +38,8 @@ use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
 use crate::model_provider_info::ModelProviderInfo;
 use crate::model_provider_info::WireApi;
 use crate::models::ResponseItem;
-use crate::util::backoff;
 use crate::util::UrlExt;
+use crate::util::backoff;
 
 /// When serialized as JSON, this produces a valid "Tool" in the OpenAI
 /// Responses API.
@@ -107,6 +109,7 @@ pub struct ModelClient {
     model: String,
     client: reqwest::Client,
    provider: ModelProviderInfo,
+    token_aggregator: Arc<std::sync::Mutex<TokenAggregator>>,
 }
 
 impl ModelClient {
@@ -115,9 +118,28 @@ impl ModelClient {
             model: model.to_string(),
             client: reqwest::Client::new(),
             provider,
+            token_aggregator: Arc::new(std::sync::Mutex::new(TokenAggregator::new())),
         }
     }
 
+    /// Returns the model name used for this client. A helper so callers in
+    /// the business-logic layer (e.g. pricing calculations) do not need to
+    /// reach into private fields.
+    pub fn model_name(&self) -> &str {
+        &self.model
+    }
+
+    /// Exposes the provider so higher-level modules (e.g. cost accounting)
+    /// can inspect metadata without breaking encapsulation.
+    pub fn provider(&self) -> &ModelProviderInfo {
+        &self.provider
+    }
+
+    /// Returns the cumulative (input, output) token usage for this session.
+    pub fn get_session_token_usage(&self) -> (u32, u32) {
+        self.token_aggregator.lock().unwrap().get_token_totals()
+    }
+
     /// Dispatches to either the Responses or Chat implementation depending on
     /// the provider config. Public callers always invoke `stream()` – the
     /// specialised helpers are private to avoid accidental misuse.
@@ -126,9 +148,14 @@ impl ModelClient {
             WireApi::Responses => self.stream_responses(prompt).await,
             WireApi::Chat => {
                 // Create the raw streaming connection first.
-                let response_stream =
-                    stream_chat_completions(prompt, &self.model, &self.client, &self.provider)
-                        .await?;
+                let response_stream = stream_chat_completions(
+                    prompt,
+                    &self.model,
+                    &self.client,
+                    &self.provider,
+                    Arc::clone(&self.token_aggregator),
+                )
+                .await?;
 
                 // Wrap it with the aggregation adapter so callers see *only*
                 // the final assistant message per turn (matching the
@@ -199,7 +226,12 @@ impl ModelClient {
             stream: true,
         };
 
-        let url = self.provider.base_url.clone().append_path("/responses")?.to_string();
+        let url = self
+            .provider
+            .base_url
+            .clone()
+            .append_path("/responses")?
+            .to_string();
         debug!("{} POST", url);
         trace!("request payload: {}", serde_json::to_string(&payload)?);
 
@@ -237,7 +269,11 @@ impl ModelClient {
                     // spawn task to process SSE
                     let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
-                    tokio::spawn(process_sse(stream, tx_event));
+                    tokio::spawn(process_sse(
+                        stream,
+                        tx_event,
+                        Arc::clone(&self.token_aggregator),
+                    ));
 
                     return Ok(ResponseStream { rx_event });
                 }
@@ -311,8 +347,11 @@ struct ResponseCompleted {
     id: String,
 }
 
-async fn process_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
-where
+async fn process_sse<S>(
+    stream: S,
+    tx_event: mpsc::Sender<Result<ResponseEvent>>,
+    token_aggregator: Arc<std::sync::Mutex<TokenAggregator>>,
+) where
     S: Stream<Item = Result<Bytes>> + Unpin,
 {
     let mut stream = stream.eventsource();
 
@@ -321,6 +360,9 @@ where
     let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
 
     // The response id returned from the "complete" message.
     let mut response_id = None;
+    // Token usage reported by the final event.
+    let mut input_tokens = None;
+    let mut output_tokens = None;
 
     loop {
         let sse = match timeout(idle_timeout, stream.next()).await {
@@ -334,7 +376,11 @@ where
             Ok(None) => {
                 match response_id {
                     Some(response_id) => {
-                        let event = ResponseEvent::Completed { response_id };
+                        let event = ResponseEvent::Completed {
+                            response_id,
+                            input_tokens,
+                            output_tokens,
+                        };
                         let _ = tx_event.send(Ok(event)).await;
                     }
                     None => {
@@ -398,6 +444,28 @@ where
             // Final response completed – includes array of output items & id
             "response.completed" => {
                 if let Some(resp_val) = event.response {
+                    // Extract usage if present (the Responses API reports
+                    // "input_tokens" and "output_tokens").
+                    if let Some(usage) = resp_val.get("usage") {
+                        let usage_input_tokens = usage
+                            .get("input_tokens")
+                            .and_then(|v| v.as_u64())
+                            .map(|v| v as u32)
+                            .unwrap_or(0);
+                        let usage_output_tokens = usage
+                            .get("output_tokens")
+                            .and_then(|v| v.as_u64())
+                            .map(|v| v as u32)
+                            .unwrap_or(0);
+
+                        token_aggregator
+                            .lock()
+                            .unwrap()
+                            .add_token_usage(usage_input_tokens, usage_output_tokens);
+
+                        input_tokens = Some(usage_input_tokens);
+                        output_tokens = Some(usage_output_tokens);
+                    }
+
                     match serde_json::from_value::<ResponseCompleted>(resp_val) {
                         Ok(r) => {
                             response_id = Some(r.id);
@@ -429,6 +497,7 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
     let rdr = std::io::Cursor::new(content);
     let stream = ReaderStream::new(rdr).map_err(CodexErr::Io);
-    tokio::spawn(process_sse(stream, tx_event));
+    let dummy_aggregator = Arc::new(std::sync::Mutex::new(TokenAggregator::new()));
+    tokio::spawn(process_sse(stream, tx_event, dummy_aggregator));
 
     Ok(ResponseStream { rx_event })
 }
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 8eb8074b1ef..11484f22c97 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -47,7 +47,32 @@ impl Prompt {
 #[derive(Debug)]
 pub enum ResponseEvent {
     OutputItemDone(ResponseItem),
-    Completed { response_id: String },
+    Completed {
+        response_id: String,
+        input_tokens: Option<u32>,
+        output_tokens: Option<u32>,
+    },
+}
+
+/// Accumulates token usage across every request made during a session.
+#[derive(Debug, Default)]
+pub struct TokenAggregator {
+    total_input_tokens: u32,
+    total_output_tokens: u32,
+}
+
+impl TokenAggregator {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn add_token_usage(&mut self, input_tokens: u32, output_tokens: u32) {
+        self.total_input_tokens += input_tokens;
+        self.total_output_tokens += output_tokens;
+    }
+
+    pub fn get_token_totals(&self) -> (u32, u32) {
+        (self.total_input_tokens, self.total_output_tokens)
+    }
 }
 
 #[derive(Debug, Serialize)]
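The aggregator is shared state: `ModelClient` owns it, and each spawned SSE task gets a cloned `Arc` handle, so usage recorded on the stream task is visible when `get_session_token_usage()` is called later. A minimal, self-contained sketch of that pattern (types inlined here for illustration; the real `TokenAggregator` lives in `client_common.rs`):

```rust
use std::sync::{Arc, Mutex};

// Simplified stand-in for client_common::TokenAggregator.
#[derive(Default)]
struct TokenAggregator {
    total_input: u32,
    total_output: u32,
}

fn main() {
    let agg = Arc::new(Mutex::new(TokenAggregator::default()));

    // Each spawned stream task receives its own handle via Arc::clone ...
    let handle = Arc::clone(&agg);
    std::thread::spawn(move || {
        let mut a = handle.lock().unwrap();
        a.total_input += 1000;
        a.total_output += 500;
    })
    .join()
    .unwrap();

    // ... and the owning client reads the totals once the task is done.
    let a = agg.lock().unwrap();
    assert_eq!((a.total_input, a.total_output), (1000, 500));
}
```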
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 705b8260bb1..83828a19f49 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -81,6 +81,7 @@ use crate::rollout::RolloutRecorder;
 use crate::safety::SafetyCheck;
 use crate::safety::assess_command_safety;
 use crate::safety::assess_patch_safety;
+use crate::usage::compute_openai_cost;
 use crate::user_notification::UserNotification;
 use crate::util::backoff;
 
@@ -869,9 +870,38 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
         }
     }
     sess.remove_task(&sub_id);
+
+    // Get aggregated usage from the client and compute the cost for OpenAI models.
+    let (total_input_tokens, total_output_tokens) = sess.client.get_session_token_usage();
+
+    let is_openai_provider = sess
+        .client
+        .provider()
+        .base_url
+        .as_str()
+        .contains("openai.com");
+
+    let token_usage_opt = if is_openai_provider {
+        let model = sess.client.model_name();
+        let cost = compute_openai_cost(model, total_input_tokens, total_output_tokens);
+        Some(crate::protocol::TokenUsage {
+            input_tokens: total_input_tokens,
+            output_tokens: total_output_tokens,
+            total_cost: cost,
+        })
+    } else {
+        Some(crate::protocol::TokenUsage {
+            input_tokens: total_input_tokens,
+            output_tokens: total_output_tokens,
+            total_cost: None,
+        })
+    };
+
     let event = Event {
         id: sub_id,
-        msg: EventMsg::TaskComplete,
+        msg: EventMsg::TaskComplete {
+            token_usage: token_usage_opt,
+        },
     };
     sess.tx_event.send(event).await.ok();
 }
@@ -963,8 +993,7 @@ async fn try_run_turn(
 ) -> CodexResult<Vec<ProcessedResponseItem>> {
     let mut stream = sess.client.clone().stream(prompt).await?;
 
-    // Buffer all the incoming messages from the stream first, then execute them.
-    // If we execute a function call in the middle of handling the stream, it can time out.
+    // Buffer all incoming messages first, then execute them: running a
+    // function call mid-stream can stall the stream long enough to time out.
     let mut input = Vec::new();
     while let Some(event) = stream.next().await {
         input.push(event?);
@@ -977,7 +1006,7 @@ async fn try_run_turn(
                 let response = handle_response_item(sess, sub_id, item.clone()).await?;
                 output.push(ProcessedResponseItem { item, response });
             }
-            ResponseEvent::Completed { response_id } => {
+            ResponseEvent::Completed { response_id, .. } => {
                 let mut state = sess.state.lock().unwrap();
                 state.previous_response_id = Some(response_id);
                 break;
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index 00a65a67258..84d1f4827dc 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -26,6 +26,7 @@ pub mod mcp_server_config;
 mod mcp_tool_call;
 mod message_history;
 mod model_provider_info;
+pub mod usage;
 pub use model_provider_info::ModelProviderInfo;
 pub use model_provider_info::WireApi;
 mod models;
diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs
index 658b9a739b4..89aec6cd32a 100644
--- a/codex-rs/core/src/protocol.rs
+++ b/codex-rs/core/src/protocol.rs
@@ -320,8 +320,14 @@ pub enum EventMsg {
     /// Agent has started a task
     TaskStarted,
 
-    /// Agent has completed all actions
-    TaskComplete,
+    /// Agent has completed all actions. The event carries aggregate token
+    /// usage for the task; `total_cost` (in USD) is present only for
+    /// recognized OpenAI models.
+    TaskComplete {
+        #[serde(skip_serializing_if = "Option::is_none")]
+        #[serde(flatten)]
+        token_usage: Option<TokenUsage>,
+    },
 
     /// Agent text output message
     AgentMessage(AgentMessageEvent),
@@ -358,6 +364,15 @@ pub enum EventMsg {
     GetHistoryEntryResponse(GetHistoryEntryResponseEvent),
 }
 
+/// Token usage information for API calls.
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct TokenUsage {
+    pub input_tokens: u32,
+    pub output_tokens: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub total_cost: Option<f64>,
+}
+
 // Individual event payload types matching each `EventMsg` variant.
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -535,29 +550,3 @@ pub struct Chunk {
     pub inserted_lines: Vec<String>,
 }
 
-#[cfg(test)]
-mod tests {
-    #![allow(clippy::unwrap_used)]
-    use super::*;
-
-    /// Serialize Event to verify that its JSON representation has the expected
-    /// amount of nesting.
-    #[test]
-    fn serialize_event() {
-        let session_id: Uuid = uuid::uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8");
-        let event = Event {
-            id: "1234".to_string(),
-            msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
-                session_id,
-                model: "o4-mini".to_string(),
-                history_log_id: 0,
-                history_entry_count: 0,
-            }),
-        };
-        let serialized = serde_json::to_string(&event).unwrap();
-        assert_eq!(
-            serialized,
-            r#"{"id":"1234","msg":{"type":"session_configured","session_id":"67e55044-10b1-426f-9247-bb680e5fe0c8","model":"o4-mini","history_log_id":0,"history_entry_count":0}}"#
-        );
-    }
-}
diff --git a/codex-rs/core/src/usage.rs b/codex-rs/core/src/usage.rs
new file mode 100644
index 00000000000..70c8ae5f83e
--- /dev/null
+++ b/codex-rs/core/src/usage.rs
@@ -0,0 +1,89 @@
+/// Computes the total cost in USD for the given token usage using OpenAI pricing.
+/// Returns `None` if the model is unknown or pricing is unavailable.
+pub fn compute_openai_cost(model: &str, input_tokens: u32, output_tokens: u32) -> Option<f64> {
+    let (per_input_token_cost, per_output_token_cost) = get_openai_pricing(model)?;
+    // Rates are per token, so multiply directly.
+    let cost = (input_tokens as f64) * per_input_token_cost
+        + (output_tokens as f64) * per_output_token_cost;
+    Some(cost)
+}
+
+/// Returns the OpenAI per-token pricing (input, output) **in USD** for a given
+/// model name. The list is not exhaustive – it covers only the most common
+/// public models, so we can offer reasonable estimates without hard-coding
+/// every variant. Unknown models return `None` so callers can fall back
+/// gracefully.
+pub fn get_openai_pricing(model: &str) -> Option<(f64, f64)> {
+    // Exact mapping (per-*token* rates, not per-1K).
+    // Order matters: more specific prefixes must come before general ones.
+    let detailed: &[(&str, (f64, f64))] = &[
+        // (model, (input, output))
+        ("o3", (10.0 / 1_000_000.0, 40.0 / 1_000_000.0)),
+        ("o4-mini", (1.1 / 1_000_000.0, 4.4 / 1_000_000.0)),
+        ("gpt-4.1-nano", (0.1 / 1_000_000.0, 0.4 / 1_000_000.0)),
+        ("gpt-4.1-mini", (0.4 / 1_000_000.0, 1.6 / 1_000_000.0)),
+        ("gpt-4.1", (2.0 / 1_000_000.0, 8.0 / 1_000_000.0)),
+        ("gpt-4o-mini", (0.6 / 1_000_000.0, 2.4 / 1_000_000.0)),
+        ("gpt-4o", (5.0 / 1_000_000.0, 20.0 / 1_000_000.0)),
+        ("codex-mini-latest", (1.5 / 1_000_000.0, 6.0 / 1_000_000.0)),
+    ];
+
+    let key = model.to_ascii_lowercase();
+    detailed
+        .iter()
+        // `starts_with` matches dated variants (e.g. "gpt-4o-2024-11-20" matches "gpt-4o").
+        .find(|(m, _)| key.starts_with(*m))
+        .map(|(_, r)| *r)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_compute_openai_cost() {
+        // Cost computation for a known model.
+        let cost = compute_openai_cost("gpt-4o-mini", 1000, 500).unwrap();
+        let expected = (1000.0 * 0.6 / 1_000_000.0) + (500.0 * 2.4 / 1_000_000.0);
+        assert_eq!(cost, expected);
+
+        // Unknown models yield no cost.
+        assert_eq!(compute_openai_cost("unknown-model", 1000, 500), None);
+
+        // Zero tokens cost nothing.
+        let cost = compute_openai_cost("gpt-4o-mini", 0, 0).unwrap();
+        assert_eq!(cost, 0.0);
+    }
+
+    #[test]
+    fn test_openai_pricing() {
+        // Exact matches.
+        assert_eq!(
+            get_openai_pricing("gpt-4o-mini"),
+            Some((0.6 / 1_000_000.0, 2.4 / 1_000_000.0))
+        );
+        assert_eq!(
+            get_openai_pricing("gpt-4o"),
+            Some((5.0 / 1_000_000.0, 20.0 / 1_000_000.0))
+        );
+        assert_eq!(
+            get_openai_pricing("codex-mini-latest"),
+            Some((1.5 / 1_000_000.0, 6.0 / 1_000_000.0))
+        );
+
+        // Dated model variants should match by prefix.
+        assert_eq!(
+            get_openai_pricing("gpt-4o-2024-11-20"),
+            Some((5.0 / 1_000_000.0, 20.0 / 1_000_000.0))
+        );
+
+        // Unknown model.
+        assert_eq!(get_openai_pricing("unknown-model"), None);
+
+        // Matching is case-insensitive.
+        assert_eq!(
+            get_openai_pricing("GPT-4O-MINI"),
+            Some((0.6 / 1_000_000.0, 2.4 / 1_000_000.0))
+        );
+    }
+}
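A quick worked example of the arithmetic in `compute_openai_cost`, using the `gpt-4o-mini` row from the table above ($0.60 per million input tokens, $2.40 per million output tokens):

```rust
fn main() {
    // 1000 input tokens at $0.60/M plus 500 output tokens at $2.40/M.
    let (input_rate, output_rate) = (0.6 / 1_000_000.0, 2.4 / 1_000_000.0);
    let cost = 1000.0 * input_rate + 500.0 * output_rate;
    // $0.0006 + $0.0012 = $0.0018; compare with a tolerance since f64
    // rounding makes exact equality fragile.
    assert!((cost - 0.0018_f64).abs() < 1e-12);
}
```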
get_openai_pricing("gpt-4o-2024-11-20"), + Some((5.0 / 1_000_000.0, 20.0 / 1_000_000.0)) + ); + + // Test unknown model + assert_eq!(get_openai_pricing("unknown-model"), None); + + // Test case insensitive + assert_eq!( + get_openai_pricing("GPT-4O-MINI"), + Some((0.6 / 1_000_000.0, 2.4 / 1_000_000.0)) + ); + } +} diff --git a/codex-rs/core/tests/live_agent.rs b/codex-rs/core/tests/live_agent.rs index bc5a1105958..c6d28d2c80a 100644 --- a/codex-rs/core/tests/live_agent.rs +++ b/codex-rs/core/tests/live_agent.rs @@ -98,7 +98,7 @@ async fn live_streaming_and_prev_id_reset() { match ev.msg { EventMsg::AgentMessage(_) => saw_message_before_complete = true, - EventMsg::TaskComplete => break, + EventMsg::TaskComplete { .. } => break, EventMsg::Error(ErrorEvent { message }) => { panic!("agent reported error in task1: {message}") } @@ -136,7 +136,7 @@ async fn live_streaming_and_prev_id_reset() { { got_expected = true; } - EventMsg::TaskComplete => break, + EventMsg::TaskComplete { .. } => break, EventMsg::Error(ErrorEvent { message }) => { panic!("agent reported error in task2: {message}") } @@ -204,7 +204,7 @@ async fn live_shell_function_call() { assert!(stdout.contains(MARKER)); saw_end_with_output = true; } - EventMsg::TaskComplete => break, + EventMsg::TaskComplete { .. } => break, EventMsg::Error(codex_core::protocol::ErrorEvent { message }) => { panic!("agent error during shell test: {message}") } diff --git a/codex-rs/core/tests/previous_response_id.rs b/codex-rs/core/tests/previous_response_id.rs index cefa8a33d13..df37d185ed6 100644 --- a/codex-rs/core/tests/previous_response_id.rs +++ b/codex-rs/core/tests/previous_response_id.rs @@ -132,7 +132,7 @@ async fn keeps_previous_response_id_between_tasks() { .await .unwrap() .unwrap(); - if matches!(ev.msg, EventMsg::TaskComplete) { + if matches!(ev.msg, EventMsg::TaskComplete { .. }) { break; } } @@ -154,7 +154,7 @@ async fn keeps_previous_response_id_between_tasks() { .unwrap() .unwrap(); match ev.msg { - EventMsg::TaskComplete => break, + EventMsg::TaskComplete { .. } => break, EventMsg::Error(ErrorEvent { message }) => { panic!("unexpected error: {message}") } diff --git a/codex-rs/core/tests/protocol_serialization.rs b/codex-rs/core/tests/protocol_serialization.rs new file mode 100644 index 00000000000..71061d3114a --- /dev/null +++ b/codex-rs/core/tests/protocol_serialization.rs @@ -0,0 +1,82 @@ +use codex_core::protocol::*; +use uuid::Uuid; + +/// Serialize Event to verify that its JSON representation has the expected +/// amount of nesting. 
+#[test]
+fn serialize_event() {
+    let session_id: Uuid = uuid::uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8");
+    let event = Event {
+        id: "1234".to_string(),
+        msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
+            session_id,
+            model: "o4-mini".to_string(),
+            history_log_id: 0,
+            history_entry_count: 0,
+        }),
+    };
+    let serialized = serde_json::to_string(&event).unwrap();
+    assert_eq!(
+        serialized,
+        r#"{"id":"1234","msg":{"type":"session_configured","session_id":"67e55044-10b1-426f-9247-bb680e5fe0c8","model":"o4-mini","history_log_id":0,"history_entry_count":0}}"#
+    );
+}
+
+/// Serialize a TaskComplete event with token usage to verify the JSON format.
+#[test]
+fn serialize_task_complete_with_usage() {
+    let event = Event {
+        id: "5678".to_string(),
+        msg: EventMsg::TaskComplete {
+            token_usage: Some(TokenUsage {
+                input_tokens: 1000,
+                output_tokens: 500,
+                total_cost: Some(0.0125),
+            }),
+        },
+    };
+    let serialized = serde_json::to_string(&event).unwrap();
+    assert_eq!(
+        serialized,
+        r#"{"id":"5678","msg":{"type":"task_complete","input_tokens":1000,"output_tokens":500,"total_cost":0.0125}}"#
+    );
+}
+
+/// Serialize a TaskComplete event without cost to verify the JSON format.
+#[test]
+fn serialize_task_complete_no_cost() {
+    let event = Event {
+        id: "9999".to_string(),
+        msg: EventMsg::TaskComplete {
+            token_usage: Some(TokenUsage {
+                input_tokens: 1500,
+                output_tokens: 750,
+                total_cost: None,
+            }),
+        },
+    };
+    let serialized = serde_json::to_string(&event).unwrap();
+    assert_eq!(
+        serialized,
+        r#"{"id":"9999","msg":{"type":"task_complete","input_tokens":1500,"output_tokens":750}}"#
+    );
+}
+
+/// Serialize a TaskComplete event with no token usage at all.
+#[test]
+fn serialize_task_complete_no_usage() {
+    let event = Event {
+        id: "0000".to_string(),
+        msg: EventMsg::TaskComplete { token_usage: None },
+    };
+    let serialized = serde_json::to_string(&event).unwrap();
+    assert_eq!(
+        serialized,
+        r#"{"id":"0000","msg":{"type":"task_complete"}}"#
+    );
+}
\ No newline at end of file
diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs
index 8753b0ed10b..b752fee5dac 100644
--- a/codex-rs/core/tests/stream_no_completed.rs
+++ b/codex-rs/core/tests/stream_no_completed.rs
@@ -118,7 +118,7 @@ async fn retries_on_early_close() {
             .await
             .unwrap()
             .unwrap();
-        if matches!(ev.msg, codex_core::protocol::EventMsg::TaskComplete) {
+        if matches!(ev.msg, codex_core::protocol::EventMsg::TaskComplete { .. }) {
             break;
         }
     }
diff --git a/codex-rs/exec/src/event_processor.rs b/codex-rs/exec/src/event_processor.rs
index 4c8278cc597..674bfc94b67 100644
--- a/codex-rs/exec/src/event_processor.rs
+++ b/codex-rs/exec/src/event_processor.rs
@@ -117,8 +117,17 @@ impl EventProcessor {
                 let msg = format!("Task started: {id}");
                 ts_println!("{}", msg.style(self.dimmed));
             }
-            EventMsg::TaskComplete => {
-                let msg = format!("Task complete: {id}");
+            EventMsg::TaskComplete { token_usage } => {
+                let mut msg = format!("Task complete: {id}");
+                if let Some(usage) = token_usage {
+                    msg.push_str(&format!(
+                        " ({} input, {} output tokens)",
+                        usage.input_tokens, usage.output_tokens
+                    ));
+                    if let Some(cost) = usage.total_cost {
+                        msg.push_str(&format!(" [${:.4}]", cost));
+                    }
+                }
                 ts_println!("{}", msg.style(self.bold));
             }
             EventMsg::AgentMessage(AgentMessageEvent { message }) => {
diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs
index 348bff08e61..f362ab67abe 100644
--- a/codex-rs/exec/src/lib.rs
+++ b/codex-rs/exec/src/lib.rs
@@ -137,7 +137,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> {
         let initial_images_event_id = codex.submit(Op::UserInput { items }).await?;
         info!("Sent images with event ID: {initial_images_event_id}");
         while let Ok(event) = codex.next_event().await {
-            if event.id == initial_images_event_id && matches!(event.msg, EventMsg::TaskComplete) {
+            if event.id == initial_images_event_id && matches!(event.msg, EventMsg::TaskComplete { .. }) {
                 break;
             }
         }
@@ -152,7 +152,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> {
     let mut event_processor = EventProcessor::create_with_ansi(stdout_with_ansi);
     while let Some(event) = rx.recv().await {
         let last_event =
-            event.id == initial_prompt_task_id && matches!(event.msg, EventMsg::TaskComplete);
+            event.id == initial_prompt_task_id && matches!(event.msg, EventMsg::TaskComplete { .. });
         event_processor.process_event(event);
         if last_event {
             break;
diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs
index f6f6798cfea..a1cd8733e24 100644
--- a/codex-rs/mcp-server/src/codex_tool_runner.rs
+++ b/codex-rs/mcp-server/src/codex_tool_runner.rs
@@ -125,7 +125,7 @@ pub async fn run_codex_tool_session(
                     .await;
                 break;
             }
-            EventMsg::TaskComplete => {
+            EventMsg::TaskComplete { .. } => {
                 let result = if let Some(msg) = last_agent_message {
                     CallToolResult {
                         content: vec![CallToolResultContent::TextContent(TextContent {
diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs
index 8ceef95b4a6..fe10b29df76 100644
--- a/codex-rs/tui/src/chatwidget.rs
+++ b/codex-rs/tui/src/chatwidget.rs
@@ -246,7 +246,7 @@ impl ChatWidget<'_> {
                 self.bottom_pane.set_task_running(true);
                 self.request_redraw();
             }
-            EventMsg::TaskComplete => {
+            EventMsg::TaskComplete { .. } => {
                 self.bottom_pane.set_task_running(false);
                 self.request_redraw();
             }
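For downstream consumers, the serialization tests above pin down the wire shape: `#[serde(flatten)]` splices the usage fields directly into the `task_complete` event body rather than nesting them under a `token_usage` key. A sketch of how a client might read such an event (the payload string is taken from the tests; the field handling is illustrative, not a prescribed API):

```rust
use serde_json::Value;

fn main() {
    // Payload copied from serialize_task_complete_with_usage above.
    let raw = r#"{"id":"5678","msg":{"type":"task_complete","input_tokens":1000,"output_tokens":500,"total_cost":0.0125}}"#;
    let event: Value = serde_json::from_str(raw).unwrap();
    let msg = &event["msg"];

    if msg["type"] == "task_complete" {
        // All three fields are optional: absent when the provider reported
        // no usage, and total_cost is absent for non-OpenAI models.
        let input = msg.get("input_tokens").and_then(Value::as_u64);
        let output = msg.get("output_tokens").and_then(Value::as_u64);
        if let Some(cost) = msg.get("total_cost").and_then(Value::as_f64) {
            println!("used {input:?}/{output:?} tokens, cost ${cost:.4}");
        }
    }
}
```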