From 36051c0a17281cc20ce0925a0805deb8f52fe048 Mon Sep 17 00:00:00 2001 From: Arnaud Stiegler Date: Thu, 22 May 2025 12:29:14 -0400 Subject: [PATCH 1/9] adding cost --- codex-rs/core/src/chat_completions.rs | 30 +++- codex-rs/core/src/client.rs | 36 +++- codex-rs/core/src/client_common.rs | 4 + codex-rs/core/src/codex.rs | 173 +++++++++++++++++-- codex-rs/core/src/protocol.rs | 13 +- codex-rs/core/tests/live_agent.rs | 6 +- codex-rs/core/tests/previous_response_id.rs | 4 +- codex-rs/core/tests/stream_no_completed.rs | 2 +- codex-rs/exec/src/event_processor.rs | 7 +- codex-rs/exec/src/lib.rs | 4 +- codex-rs/mcp-server/src/codex_tool_runner.rs | 2 +- codex-rs/tui/src/chatwidget.rs | 2 +- 12 files changed, 254 insertions(+), 29 deletions(-) diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index 52affa46dcf..f41018b6cc9 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -60,7 +60,8 @@ pub(crate) async fn stream_chat_completions( let payload = json!({ "model": model, "messages": messages, - "stream": true + "stream": true, + "stream_options": {"include_usage": true} }); let url = provider.base_url.clone().append_path("/chat/completions")?.to_string(); @@ -174,6 +175,27 @@ where Err(_) => continue, }; + // Forward usage statistics when requested. + if let Some(usage) = chunk.get("usage") { + let prompt_tokens = usage + .get("prompt_tokens") + .or_else(|| usage.get("input_tokens")) + .and_then(|v| v.as_u64()) + .unwrap_or(0) as u32; + let completion_tokens = usage + .get("completion_tokens") + .or_else(|| usage.get("output_tokens")) + .and_then(|v| v.as_u64()) + .unwrap_or(0) as u32; + + let _ = tx_event + .send(Ok(ResponseEvent::Usage { + prompt_tokens, + completion_tokens, + })) + .await; + } + let content_opt = chunk .get("choices") .and_then(|c| c.get(0)) @@ -270,7 +292,11 @@ where // Nothing aggregated – forward Completed directly. return Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))); - } // No other `Ok` variants exist at the moment, continue polling. + } + Poll::Ready(Some(Ok(ev))) => { + // Forward any other event types (e.g., Usage). + return Poll::Ready(Some(Ok(ev))); + } } } } diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index fd5176dbb88..ef5e5809927 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -118,6 +118,19 @@ impl ModelClient { } } + /// Returns the model name used for this client. Helper so callers inside + /// the business-logic layer (e.g. for pricing calculations) do not need + /// to reach into private fields. + pub fn model_name(&self) -> &str { + &self.model + } + + /// Expose the provider so higher-level modules (e.g. cost accounting) can + /// inspect metadata without breaking encapsulation. + pub fn provider(&self) -> &ModelProviderInfo { + &self.provider + } + /// Dispatches to either the Responses or Chat implementation depending on /// the provider config. Public callers always invoke `stream()` – the /// specialised helpers are private to avoid accidental misuse. 
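Note on the payload shapes this patch parses: with `"stream_options": {"include_usage": true}` set above, the Chat Completions stream ends with one extra chunk whose `usage` object uses the `prompt_tokens`/`completion_tokens` keys, while the Responses API reports usage on its `response.completed` event under `input_tokens`/`output_tokens`; the `or_else` fallbacks in the next hunk accept either shape. A hedged sketch of the two payloads (field values illustrative only):

    // Chat Completions, final chunk when include_usage is set:
    //   {"choices": [], "usage": {"prompt_tokens": 1200, "completion_tokens": 340, "total_tokens": 1540}}
    // Responses API, response.completed event:
    //   {"response": {"id": "...", "usage": {"input_tokens": 1200, "output_tokens": 340}}}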
@@ -398,6 +411,27 @@ where
             // Final response completed – includes array of output items & id
             "response.completed" => {
                 if let Some(resp_val) = event.response {
+                    // Extract usage if present
+                    if let Some(usage) = resp_val.get("usage") {
+                        let prompt_tokens = usage
+                            .get("prompt_tokens")
+                            .or_else(|| usage.get("input_tokens"))
+                            .and_then(|v| v.as_u64())
+                            .unwrap_or(0) as u32;
+                        let completion_tokens = usage
+                            .get("completion_tokens")
+                            .or_else(|| usage.get("output_tokens"))
+                            .and_then(|v| v.as_u64())
+                            .unwrap_or(0) as u32;
+
+                        let _ = tx_event
+                            .send(Ok(ResponseEvent::Usage {
+                                prompt_tokens,
+                                completion_tokens,
+                            }))
+                            .await;
+                    }
+
                     match serde_json::from_value::<ResponseCompleted>(resp_val) {
                         Ok(r) => {
                             response_id = Some(r.id);
@@ -407,7 +441,7 @@ where
                             continue;
                         }
                     };
-                };
+                }
             }
             other => debug!(other, "sse event"),
         }
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 8eb8074b1ef..dfe2c182f5a 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -48,6 +48,10 @@ impl Prompt {
 pub enum ResponseEvent {
     OutputItemDone(ResponseItem),
     Completed { response_id: String },
+    Usage {
+        prompt_tokens: u32,
+        completion_tokens: u32,
+    },
 }
 
 #[derive(Debug, Serialize)]
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 705b8260bb1..5bd5843f699 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -193,6 +193,89 @@ impl Session {
     }
 }
 
+// -----------------------------------------------------------------------------
+// Helper functions (private to this module)
+// -----------------------------------------------------------------------------
+
+/// Very rough approximation for the token count of an arbitrary string. We use
+/// a simple heuristic of 4 characters per token, which is commonly accepted as
+/// “good enough” for estimating costs without a tokenizer. The result is
+/// *never* used for billing – only for displaying approximate usage stats to
+/// the user.
+fn approx_token_count(s: &str) -> usize {
+    // Round up so that short, non-empty strings count as at least one token.
+    if s.is_empty() {
+        0
+    } else {
+        (s.len() + 3) / 4 // round up
+    }
+}
+
+/// Counts the number of tokens contained in a collection of [`ResponseItem`]s
+/// by summing up the textual content of all `InputText` and `OutputText`
+/// elements.
+fn count_tokens_in_items(items: &[ResponseItem]) -> usize {
+    items
+        .iter()
+        .map(|item| match item {
+            ResponseItem::Message { content, .. } => content
+                .iter()
+                .filter_map(|c| match c {
+                    ContentItem::InputText { text } | ContentItem::OutputText { text } => {
+                        Some(approx_token_count(text))
+                    }
+                    _ => None,
+                })
+                .sum::<usize>(),
+            _ => 0,
+        })
+        .sum()
+}
+
+/// Returns the OpenAI per-token pricing (prompt, completion) **in USD** for
+/// a given model name. The list is *not* exhaustive – it only covers the most
+/// common public models so we offer reasonable estimates without hard-coding
+/// every single variant. Unknown models default to `None` so callers can fall
+/// back gracefully.
+fn get_openai_pricing(model: &str) -> Option<(f64, f64)> {
+    // Exact mapping (per *token* rates, not per-1K)
+    let detailed: &[(&str, (f64, f64))] = &[ // (model, (input, output))
+        ("o3", (10.0 / 1_000_000.0, 40.0 / 1_000_000.0)),
+        ("o4-mini", (1.1 / 1_000_000.0, 4.4 / 1_000_000.0)),
+        ("gpt-4.1-nano", (0.1 / 1_000_000.0, 0.4 / 1_000_000.0)),
+        ("gpt-4.1-mini", (0.4 / 1_000_000.0, 1.6 / 1_000_000.0)),
+        ("gpt-4.1", (2.0 / 1_000_000.0, 8.0 / 1_000_000.0)),
+        ("gpt-4o-mini", (0.6 / 1_000_000.0, 2.4 / 1_000_000.0)),
+        ("gpt-4o", (5.0 / 1_000_000.0, 20.0 / 1_000_000.0)),
+    ];
+
+    let key = model.to_ascii_lowercase();
+    if let Some((in_rate, out_rate)) = detailed
+        .iter()
+        .find(|(m, _)| key.starts_with(*m))
+        .map(|(_, r)| *r)
+    {
+        return Some((in_rate, out_rate));
+    }
+
+    // Fallback coarse buckets (per-1K rates → convert to per-token)
+    let per_1k_to_per_token = |x: f64| x / 1000.0;
+    if key.contains("gpt-4o") {
+        return Some((per_1k_to_per_token(0.005), per_1k_to_per_token(0.015)));
+    }
+    if key.contains("gpt-4-turbo") {
+        return Some((per_1k_to_per_token(0.01), per_1k_to_per_token(0.03)));
+    }
+    if key.contains("gpt-4") {
+        return Some((per_1k_to_per_token(0.03), per_1k_to_per_token(0.06)));
+    }
+    if key.contains("gpt-3.5-turbo") {
+        return Some((per_1k_to_per_token(0.0005), per_1k_to_per_token(0.0015)));
+    }
+    None
+}
+
+
 /// Mutable state of the agent
 #[derive(Default)]
 struct State {
@@ -765,6 +848,14 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
         return;
     }
 
+    // Track overall token usage for this task so we can expose usage/cost
+    // statistics in the final `TaskComplete` event. These are *approximate*
+    // counts based on a naive 4-character-per-token heuristic – sufficient
+    // for ballpark cost estimation without pulling in a heavyweight tokenizer
+    // dependency.
+    let mut total_prompt_tokens: usize = 0;
+    let mut total_completion_tokens: usize = 0;
+
     let mut pending_response_input: Vec<ResponseInputItem> = vec![ResponseInputItem::from(input)];
     loop {
         let mut net_new_turn_input = pending_response_input
@@ -807,6 +898,10 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
             net_new_turn_input
         };
 
+        // Token accounting for providers that do not return detailed usage
+        // stats. We add *approximate* counts here and later replace them with
+        // exact ones when `usage_in` becomes available.
+
         let turn_input_messages: Vec<String> = turn_input
             .iter()
             .filter_map(|item| match item {
@@ -820,8 +915,15 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
                 })
             })
            .collect();
+
+        // Pre-compute approximate token count before `turn_input` is moved.
+        let turn_input_approx_tokens = count_tokens_in_items(&turn_input);
         match run_turn(&sess, sub_id.clone(), turn_input).await {
-            Ok(turn_output) => {
+            Ok((turn_output, usage_in, usage_out)) => {
+                // Accumulate exact token usage when available.
+                total_prompt_tokens += usage_in as usize;
+                total_completion_tokens += usage_out as usize;
+
                 let (items, responses): (Vec<_>, Vec<_>) = turn_output
                     .into_iter()
                     .map(|p| (p.item, p.response))
@@ -832,6 +934,15 @@
                     .collect::<Vec<ResponseInputItem>>();
                 let last_assistant_message = get_last_assistant_message_from_turn(&items);
 
+                // When usage info is not provided (usage_out == 0), fall back to the approximation.
+                if usage_out == 0 {
+                    total_completion_tokens += count_tokens_in_items(&items);
+                }
+
+                if usage_in == 0 {
+                    total_prompt_tokens += turn_input_approx_tokens;
+                }
+
                 // Only attempt to take the lock if there is something to record.
                 if !items.is_empty() {
                     // First persist model-generated output to the rollout file – this only borrows.
@@ -869,9 +980,36 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
             }
         }
     }
     sess.remove_task(&sub_id);
+
+    // Calculate total cost for OpenAI models if possible.
+    let is_openai_provider = sess
+        .client
+        .provider()
+        .base_url
+        .contains("openai.com");
+
+    let (total_cost_opt, prompt_tokens_opt, completion_tokens_opt) = if is_openai_provider {
+        let model = sess.client.model_name();
+        let (prompt_rate, completion_rate) = get_openai_pricing(model).unwrap_or((0.0, 0.0));
+        // Rates are per-token. Multiply directly.
+        let cost = (total_prompt_tokens as f64) * prompt_rate
+            + (total_completion_tokens as f64) * completion_rate;
+        (
+            Some(cost),
+            Some(total_prompt_tokens as u32),
+            Some(total_completion_tokens as u32),
+        )
+    } else {
+        (None, None, None)
+    };
+
     let event = Event {
         id: sub_id,
-        msg: EventMsg::TaskComplete,
+        msg: EventMsg::TaskComplete {
+            total_cost: total_cost_opt,
+            prompt_tokens: prompt_tokens_opt,
+            completion_tokens: completion_tokens_opt,
+        },
     };
     sess.tx_event.send(event).await.ok();
 }
@@ -880,7 +1018,7 @@ async fn run_turn(
     sess: &Session,
     sub_id: String,
     input: Vec<ResponseInputItem>,
-) -> CodexResult<Vec<ProcessedResponseItem>> {
+) -> CodexResult<(Vec<ProcessedResponseItem>, u32, u32)> {
     // Decide whether to use server-side storage (previous_response_id) or disable it
     let (prev_id, store, is_first_turn) = {
         let state = sess.state.lock().unwrap();
@@ -914,7 +1052,7 @@ async fn run_turn(
     let mut retries = 0;
     loop {
         match try_run_turn(sess, &sub_id, &prompt).await {
-            Ok(output) => return Ok(output),
+            Ok(res) => return Ok(res),
             Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
             Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
             Err(e) => {
@@ -960,18 +1098,20 @@ async fn try_run_turn(
     sess: &Session,
     sub_id: &str,
     prompt: &Prompt,
-) -> CodexResult<Vec<ProcessedResponseItem>> {
+) -> CodexResult<(Vec<ProcessedResponseItem>, u32, u32)> {
     let mut stream = sess.client.clone().stream(prompt).await?;
 
-    // Buffer all the incoming messages from the stream first, then execute them.
-    // If we execute a function call in the middle of handling the stream, it can time out.
-    let mut input = Vec::new();
+    // Buffer all incoming messages first as before.
+    let mut input_events = Vec::new();
     while let Some(event) = stream.next().await {
-        input.push(event?);
+        input_events.push(event?);
     }
 
     let mut output = Vec::new();
-    for event in input {
+    let mut prompt_tokens: u32 = 0;
+    let mut completion_tokens: u32 = 0;
+
+    for event in input_events {
         match event {
             ResponseEvent::OutputItemDone(item) => {
                 let response = handle_response_item(sess, sub_id, item.clone()).await?;
                 output.push(ProcessedResponseItem { item, response });
             }
@@ -980,11 +1120,20 @@ async fn try_run_turn(
             ResponseEvent::Completed { response_id } => {
                 let mut state = sess.state.lock().unwrap();
                 state.previous_response_id = Some(response_id);
-                break;
+                // Do not break – there might be a Usage event afterwards, but
+                // in practice Completed comes last; we keep scanning.
+            }
+            ResponseEvent::Usage {
+                prompt_tokens: p,
+                completion_tokens: c,
+            } => {
+                prompt_tokens += p;
+                completion_tokens += c;
             }
         }
     }
-    Ok(output)
+
+    Ok((output, prompt_tokens, completion_tokens))
 }
 
 async fn handle_response_item(
diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs
index 658b9a739b4..1cdb13f0a3e 100644
--- a/codex-rs/core/src/protocol.rs
+++ b/codex-rs/core/src/protocol.rs
@@ -320,8 +320,17 @@ pub enum EventMsg {
     /// Agent has started a task
     TaskStarted,
 
-    /// Agent has completed all actions
-    TaskComplete,
+    /// Agent has completed all actions. When using an OpenAI provider, the
+    /// server includes the total cost in USD as well as token usage metrics
+    /// for the entire task. For non-OpenAI providers these fields are `null`.
+    TaskComplete {
+        #[serde(skip_serializing_if = "Option::is_none")]
+        total_cost: Option<f64>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        prompt_tokens: Option<u32>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        completion_tokens: Option<u32>,
+    },
 
     /// Agent text output message
     AgentMessage(AgentMessageEvent),
diff --git a/codex-rs/core/tests/live_agent.rs b/codex-rs/core/tests/live_agent.rs
index bc5a1105958..c6d28d2c80a 100644
--- a/codex-rs/core/tests/live_agent.rs
+++ b/codex-rs/core/tests/live_agent.rs
@@ -98,7 +98,7 @@ async fn live_streaming_and_prev_id_reset() {
 
         match ev.msg {
             EventMsg::AgentMessage(_) => saw_message_before_complete = true,
-            EventMsg::TaskComplete => break,
+            EventMsg::TaskComplete { .. } => break,
             EventMsg::Error(ErrorEvent { message }) => {
                 panic!("agent reported error in task1: {message}")
             }
@@ -136,7 +136,7 @@ async fn live_streaming_and_prev_id_reset() {
             {
                 got_expected = true;
             }
-            EventMsg::TaskComplete => break,
+            EventMsg::TaskComplete { .. } => break,
             EventMsg::Error(ErrorEvent { message }) => {
                 panic!("agent reported error in task2: {message}")
             }
@@ -204,7 +204,7 @@ async fn live_shell_function_call() {
                 assert!(stdout.contains(MARKER));
                 saw_end_with_output = true;
             }
-            EventMsg::TaskComplete => break,
+            EventMsg::TaskComplete { .. } => break,
             EventMsg::Error(codex_core::protocol::ErrorEvent { message }) => {
                 panic!("agent error during shell test: {message}")
             }
diff --git a/codex-rs/core/tests/previous_response_id.rs b/codex-rs/core/tests/previous_response_id.rs
index cefa8a33d13..df37d185ed6 100644
--- a/codex-rs/core/tests/previous_response_id.rs
+++ b/codex-rs/core/tests/previous_response_id.rs
@@ -132,7 +132,7 @@ async fn keeps_previous_response_id_between_tasks() {
             .await
             .unwrap()
             .unwrap();
-        if matches!(ev.msg, EventMsg::TaskComplete) {
+        if matches!(ev.msg, EventMsg::TaskComplete { .. }) {
             break;
         }
     }
@@ -154,7 +154,7 @@ async fn keeps_previous_response_id_between_tasks() {
             .unwrap()
             .unwrap();
         match ev.msg {
-            EventMsg::TaskComplete => break,
+            EventMsg::TaskComplete { .. } => break,
             EventMsg::Error(ErrorEvent { message }) => {
                 panic!("unexpected error: {message}")
             }
diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs
index 8753b0ed10b..b752fee5dac 100644
--- a/codex-rs/core/tests/stream_no_completed.rs
+++ b/codex-rs/core/tests/stream_no_completed.rs
@@ -118,7 +118,7 @@ async fn retries_on_early_close() {
             .await
             .unwrap()
             .unwrap();
-        if matches!(ev.msg, codex_core::protocol::EventMsg::TaskComplete) {
+        if matches!(ev.msg, codex_core::protocol::EventMsg::TaskComplete { ..
}) { break; } } diff --git a/codex-rs/exec/src/event_processor.rs b/codex-rs/exec/src/event_processor.rs index 4c8278cc597..de74d193aa9 100644 --- a/codex-rs/exec/src/event_processor.rs +++ b/codex-rs/exec/src/event_processor.rs @@ -117,8 +117,11 @@ impl EventProcessor { let msg = format!("Task started: {id}"); ts_println!("{}", msg.style(self.dimmed)); } - EventMsg::TaskComplete => { - let msg = format!("Task complete: {id}"); + EventMsg::TaskComplete { total_cost, prompt_tokens, completion_tokens } => { + let mut msg = format!("Task complete: {id}"); + if let (Some(cost), Some(inp), Some(out)) = (total_cost, prompt_tokens, completion_tokens) { + msg.push_str(&format!(" – cost ${:.4} ({} in / {} out tokens)", cost, inp, out)); + } ts_println!("{}", msg.style(self.bold)); } EventMsg::AgentMessage(AgentMessageEvent { message }) => { diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index 348bff08e61..f362ab67abe 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -137,7 +137,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> { let initial_images_event_id = codex.submit(Op::UserInput { items }).await?; info!("Sent images with event ID: {initial_images_event_id}"); while let Ok(event) = codex.next_event().await { - if event.id == initial_images_event_id && matches!(event.msg, EventMsg::TaskComplete) { + if event.id == initial_images_event_id && matches!(event.msg, EventMsg::TaskComplete { .. }) { break; } } @@ -152,7 +152,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> { let mut event_processor = EventProcessor::create_with_ansi(stdout_with_ansi); while let Some(event) = rx.recv().await { let last_event = - event.id == initial_prompt_task_id && matches!(event.msg, EventMsg::TaskComplete); + event.id == initial_prompt_task_id && matches!(event.msg, EventMsg::TaskComplete { .. }); event_processor.process_event(event); if last_event { break; diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index f6f6798cfea..a1cd8733e24 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -125,7 +125,7 @@ pub async fn run_codex_tool_session( .await; break; } - EventMsg::TaskComplete => { + EventMsg::TaskComplete { .. } => { let result = if let Some(msg) = last_agent_message { CallToolResult { content: vec![CallToolResultContent::TextContent(TextContent { diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index 8ceef95b4a6..fe10b29df76 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -246,7 +246,7 @@ impl ChatWidget<'_> { self.bottom_pane.set_task_running(true); self.request_redraw(); } - EventMsg::TaskComplete => { + EventMsg::TaskComplete { .. 
} => { self.bottom_pane.set_task_running(false); self.request_redraw(); } From ac0e158e8d73b4fb5f7fa0e6f7c44ec463ab14b4 Mon Sep 17 00:00:00 2001 From: Arnaud Stiegler Date: Thu, 22 May 2025 10:00:15 -0400 Subject: [PATCH 2/9] add pricing for codex-mini-latest --- codex-rs/core/src/codex.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 5bd5843f699..8bb7969baab 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -247,6 +247,7 @@ fn get_openai_pricing(model: &str) -> Option<(f64, f64)> { ("gpt-4.1", (2.0 / 1_000_000.0, 8.0 / 1_000_000.0)), ("gpt-4o-mini", (0.6 / 1_000_000.0, 2.4 / 1_000_000.0)), ("gpt-4o", (5.0 / 1_000_000.0, 20.0 / 1_000_000.0)), + ("codex-mini-latest", (1.5 / 1_000_000.0, 6.0 / 1_000_000.0)), ]; let key = model.to_ascii_lowercase(); From f03f48ede288fcd31e1a200ea33a3b683591eccd Mon Sep 17 00:00:00 2001 From: Arnaud Stiegler Date: Thu, 22 May 2025 14:38:55 -0400 Subject: [PATCH 3/9] addressing comments --- codex-rs/core/src/chat_completions.rs | 47 +++++---- codex-rs/core/src/client.rs | 24 ++--- codex-rs/core/src/client_common.rs | 9 +- codex-rs/core/src/codex.rs | 141 +++++--------------------- codex-rs/core/src/lib.rs | 1 + codex-rs/core/src/protocol.rs | 4 +- codex-rs/core/src/usage.rs | 26 +++++ codex-rs/exec/src/event_processor.rs | 4 +- 8 files changed, 98 insertions(+), 158 deletions(-) create mode 100644 codex-rs/core/src/usage.rs diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index f41018b6cc9..5a93597ad20 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -134,6 +134,10 @@ where let mut stream = stream.eventsource(); let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS; + + // Track usage information to include in final completion event + let mut input_tokens = None; + let mut output_tokens = None; loop { let sse = match timeout(idle_timeout, stream.next()).await { @@ -147,6 +151,8 @@ where let _ = tx_event .send(Ok(ResponseEvent::Completed { response_id: String::new(), + input_tokens, + output_tokens, })) .await; return; @@ -164,6 +170,8 @@ where let _ = tx_event .send(Ok(ResponseEvent::Completed { response_id: String::new(), + input_tokens, + output_tokens, })) .await; return; @@ -175,25 +183,18 @@ where Err(_) => continue, }; - // Forward usage statistics when requested. + // Store usage statistics when received. + // For the completion API, keys are "prompt_tokens" and "completion_tokens" + // which differs from the keys in the responses API. if let Some(usage) = chunk.get("usage") { - let prompt_tokens = usage + input_tokens = usage .get("prompt_tokens") - .or_else(|| usage.get("input_tokens")) .and_then(|v| v.as_u64()) - .unwrap_or(0) as u32; - let completion_tokens = usage + .map(|v| v as u32); + output_tokens = usage .get("completion_tokens") - .or_else(|| usage.get("output_tokens")) .and_then(|v| v.as_u64()) - .unwrap_or(0) as u32; - - let _ = tx_event - .send(Ok(ResponseEvent::Usage { - prompt_tokens, - completion_tokens, - })) - .await; + .map(|v| v as u32); } let content_opt = chunk @@ -273,7 +274,7 @@ where // Swallow partial event; keep polling. continue; } - Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => { + Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id, .. 
}))) => {
                    if !this.cumulative.is_empty() {
                        let aggregated_item = crate::models::ResponseItem::Message {
                            role: "assistant".to_string(),
@@ -283,7 +284,11 @@ where
                        };
 
                        // Buffer Completed so it is returned *after* the aggregated message.
-                        this.pending_completed = Some(ResponseEvent::Completed { response_id });
+                        this.pending_completed = Some(ResponseEvent::Completed {
+                            response_id,
+                            input_tokens: None,
+                            output_tokens: None,
+                        });
 
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
                            aggregated_item,
@@ -291,11 +296,11 @@ where
                    }
 
                    // Nothing aggregated – forward Completed directly.
-                    return Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id })));
-                }
-                Poll::Ready(Some(Ok(ev))) => {
-                    // Forward any other event types (e.g., Usage).
-                    return Poll::Ready(Some(Ok(ev)));
+                    return Poll::Ready(Some(Ok(ResponseEvent::Completed {
+                        response_id,
+                        input_tokens: None,
+                        output_tokens: None,
+                    })));
                }
            }
        }
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index ef5e5809927..dc53f227da9 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -334,6 +334,9 @@ where
     let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
     // The response id returned from the "complete" message.
     let mut response_id = None;
+    // Token usage information
+    let mut input_tokens = None;
+    let mut output_tokens = None;
 
     loop {
         let sse = match timeout(idle_timeout, stream.next()).await {
@@ -347,7 +350,11 @@ where
             Ok(None) => {
                 match response_id {
                     Some(response_id) => {
-                        let event = ResponseEvent::Completed { response_id };
+                        let event = ResponseEvent::Completed {
+                            response_id,
+                            input_tokens,
+                            output_tokens,
+                        };
                         let _ = tx_event.send(Ok(event)).await;
                     }
                     None => {
@@ -413,23 +420,16 @@ where
             if let Some(resp_val) = event.response {
                 // Extract usage if present
                 if let Some(usage) = resp_val.get("usage") {
-                    let prompt_tokens = usage
+                    input_tokens = usage
                         .get("prompt_tokens")
                         .or_else(|| usage.get("input_tokens"))
                         .and_then(|v| v.as_u64())
-                        .unwrap_or(0) as u32;
-                    let completion_tokens = usage
+                        .map(|v| v as u32);
+                    output_tokens = usage
                         .get("completion_tokens")
                         .or_else(|| usage.get("output_tokens"))
                         .and_then(|v| v.as_u64())
-                        .unwrap_or(0) as u32;
-
-                    let _ = tx_event
-                        .send(Ok(ResponseEvent::Usage {
-                            prompt_tokens,
-                            completion_tokens,
-                        }))
-                        .await;
+                        .map(|v| v as u32);
                 }
 
                 match serde_json::from_value::<ResponseCompleted>(resp_val) {
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index dfe2c182f5a..47a193ec1f0 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -47,13 +47,14 @@ impl Prompt {
 #[derive(Debug)]
 pub enum ResponseEvent {
     OutputItemDone(ResponseItem),
-    Completed { response_id: String },
-    Usage {
-        prompt_tokens: u32,
-        completion_tokens: u32,
+    Completed {
+        response_id: String,
+        input_tokens: Option<u32>,
+        output_tokens: Option<u32>,
     },
 }
 
+
 #[derive(Debug, Serialize)]
 pub(crate) struct Reasoning {
     pub(crate) effort: &'static str,
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 8bb7969baab..8457042a244 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -38,6 +38,7 @@
 use crate::client_common::Prompt;
 use crate::client_common::ResponseEvent;
 use crate::config::Config;
 use crate::conversation_history::ConversationHistory;
+use crate::usage::get_openai_pricing;
 use crate::error::CodexErr;
 use crate::error::Result as CodexResult;
 use crate::error::SandboxErr;
@@ -194,89 +194,6 @@ impl Session {
     }
 }
 
-// -----------------------------------------------------------------------------
-// Helper functions (private to this module)
-// -----------------------------------------------------------------------------
-
-/// Very rough approximation for the token count of an arbitrary string. We use
-/// a simple heuristic of 4 characters per token, which is commonly accepted as
-/// “good enough” for estimating costs without a tokenizer. The result is
-/// *never* used for billing – only for displaying approximate usage stats to
-/// the user.
-fn approx_token_count(s: &str) -> usize {
-    // Round up so that short, non-empty strings count as at least one token.
-    if s.is_empty() {
-        0
-    } else {
-        (s.len() + 3) / 4 // round up
-    }
-}
-
-/// Counts the number of tokens contained in a collection of [`ResponseItem`]s
-/// by summing up the textual content of all `InputText` and `OutputText`
-/// elements.
-fn count_tokens_in_items(items: &[ResponseItem]) -> usize {
-    items
-        .iter()
-        .map(|item| match item {
-            ResponseItem::Message { content, .. } => content
-                .iter()
-                .filter_map(|c| match c {
-                    ContentItem::InputText { text } | ContentItem::OutputText { text } => {
-                        Some(approx_token_count(text))
-                    }
-                    _ => None,
-                })
-                .sum::<usize>(),
-            _ => 0,
-        })
-        .sum()
-}
-
-/// Returns the OpenAI per-token pricing (prompt, completion) **in USD** for
-/// a given model name. The list is *not* exhaustive – it only covers the most
-/// common public models so we offer reasonable estimates without hard-coding
-/// every single variant. Unknown models default to `None` so callers can fall
-/// back gracefully.
-fn get_openai_pricing(model: &str) -> Option<(f64, f64)> {
-    // Exact mapping (per *token* rates, not per-1K)
-    let detailed: &[(&str, (f64, f64))] = &[ // (model, (input, output))
-        ("o3", (10.0 / 1_000_000.0, 40.0 / 1_000_000.0)),
-        ("o4-mini", (1.1 / 1_000_000.0, 4.4 / 1_000_000.0)),
-        ("gpt-4.1-nano", (0.1 / 1_000_000.0, 0.4 / 1_000_000.0)),
-        ("gpt-4.1-mini", (0.4 / 1_000_000.0, 1.6 / 1_000_000.0)),
-        ("gpt-4.1", (2.0 / 1_000_000.0, 8.0 / 1_000_000.0)),
-        ("gpt-4o-mini", (0.6 / 1_000_000.0, 2.4 / 1_000_000.0)),
-        ("gpt-4o", (5.0 / 1_000_000.0, 20.0 / 1_000_000.0)),
-        ("codex-mini-latest", (1.5 / 1_000_000.0, 6.0 / 1_000_000.0)),
-    ];
-
-    let key = model.to_ascii_lowercase();
-    if let Some((in_rate, out_rate)) = detailed
-        .iter()
-        .find(|(m, _)| key.starts_with(*m))
-        .map(|(_, r)| *r)
-    {
-        return Some((in_rate, out_rate));
-    }
-
-    // Fallback coarse buckets (per-1K rates → convert to per-token)
-    let per_1k_to_per_token = |x: f64| x / 1000.0;
-    if key.contains("gpt-4o") {
-        return Some((per_1k_to_per_token(0.005), per_1k_to_per_token(0.015)));
-    }
-    if key.contains("gpt-4-turbo") {
-        return Some((per_1k_to_per_token(0.01), per_1k_to_per_token(0.03)));
-    }
-    if key.contains("gpt-4") {
-        return Some((per_1k_to_per_token(0.03), per_1k_to_per_token(0.06)));
-    }
-    if key.contains("gpt-3.5-turbo") {
-        return Some((per_1k_to_per_token(0.0005), per_1k_to_per_token(0.0015)));
-    }
-    None
-}
-
-
 /// Mutable state of the agent
 #[derive(Default)]
 struct State {
@@ -854,8 +772,8 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
     // counts based on a naive 4-character-per-token heuristic – sufficient
     // for ballpark cost estimation without pulling in a heavyweight tokenizer
     // dependency.
-    let mut total_prompt_tokens: usize = 0;
-    let mut total_completion_tokens: usize = 0;
+    let mut total_input_tokens: usize = 0;
+    let mut total_output_tokens: usize = 0;
 
     let mut pending_response_input: Vec<ResponseInputItem> = vec![ResponseInputItem::from(input)];
     loop {
@@ -836,13 +835,11 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
             })
            .collect();
 
-        // Pre-compute approximate token count before `turn_input` is moved.
-        let turn_input_approx_tokens = count_tokens_in_items(&turn_input);
         match run_turn(&sess, sub_id.clone(), turn_input).await {
-            Ok((turn_output, usage_in, usage_out)) => {
-                // Accumulate exact token usage when available.
-                total_prompt_tokens += usage_in as usize;
-                total_completion_tokens += usage_out as usize;
+            Ok((turn_output, usage_in, usage_out)) => {
+                // Accumulate exact token usage from API
+                total_input_tokens += usage_in as usize;
+                total_output_tokens += usage_out as usize;
 
                 let (items, responses): (Vec<_>, Vec<_>) = turn_output
                     .into_iter()
                     .map(|p| (p.item, p.response))
@@ -935,15 +851,6 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
                     .collect::<Vec<ResponseInputItem>>();
                 let last_assistant_message = get_last_assistant_message_from_turn(&items);
 
-                // When usage info is not provided (usage_out == 0), fall back to the approximation.
-                if usage_out == 0 {
-                    total_completion_tokens += count_tokens_in_items(&items);
-                }
-
-                if usage_in == 0 {
-                    total_prompt_tokens += turn_input_approx_tokens;
-                }
-
                 // Only attempt to take the lock if there is something to record.
                 if !items.is_empty() {
                     // First persist model-generated output to the rollout file – this only borrows.
@@ -987,18 +894,19 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
         .client
         .provider()
         .base_url
+        .as_str()
         .contains("openai.com");
 
-    let (total_cost_opt, prompt_tokens_opt, completion_tokens_opt) = if is_openai_provider {
+    let (total_cost_opt, input_tokens_opt, output_tokens_opt) = if is_openai_provider {
         let model = sess.client.model_name();
-        let (prompt_rate, completion_rate) = get_openai_pricing(model).unwrap_or((0.0, 0.0));
+        let (per_input_token_cost, per_output_token_cost) = get_openai_pricing(model).unwrap_or((0.0, 0.0));
         // Rates are per-token. Multiply directly.
-        let cost = (total_prompt_tokens as f64) * prompt_rate
-            + (total_completion_tokens as f64) * completion_rate;
+        let cost = (total_input_tokens as f64) * per_input_token_cost
+            + (total_output_tokens as f64) * per_output_token_cost;
         (
             Some(cost),
-            Some(total_prompt_tokens as u32),
-            Some(total_completion_tokens as u32),
+            Some(total_input_tokens as u32),
+            Some(total_output_tokens as u32),
         )
     } else {
         (None, None, None)
     };
@@ -1008,8 +916,8 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
         msg: EventMsg::TaskComplete {
             total_cost: total_cost_opt,
-            prompt_tokens: prompt_tokens_opt,
-            completion_tokens: completion_tokens_opt,
+            input_tokens: input_tokens_opt,
+            output_tokens: output_tokens_opt,
         },
     };
     sess.tx_event.send(event).await.ok();
@@ -1026,22 +1026,22 @@ async fn try_run_turn(
-            ResponseEvent::Completed { response_id } => {
+            ResponseEvent::Completed { response_id, input_tokens, output_tokens } => {
                 let mut state = sess.state.lock().unwrap();
                 state.previous_response_id = Some(response_id);
-                // Do not break – there might be a Usage event afterwards, but
-                // in practice Completed comes last; we keep scanning.
-            }
-            ResponseEvent::Usage {
-                prompt_tokens: p,
-                completion_tokens: c,
-            } => {
-                prompt_tokens += p;
-                completion_tokens += c;
+
+                // Add token usage if available
+                if let Some(p) = input_tokens {
+                    prompt_tokens += p;
+                }
+                if let Some(c) = output_tokens {
+                    completion_tokens += c;
+                }
             }
         }
     }
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index 00a65a67258..af89b009cbb 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -14,6 +14,7 @@ pub mod codex_wrapper;
 pub mod config;
 pub mod config_profile;
 mod conversation_history;
+pub mod usage;
 pub mod error;
 pub mod exec;
 pub mod exec_linux;
diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs
index 1cdb13f0a3e..d4f3b781cf3 100644
--- a/codex-rs/core/src/protocol.rs
+++ b/codex-rs/core/src/protocol.rs
@@ -327,9 +327,9 @@ pub enum EventMsg {
         #[serde(skip_serializing_if = "Option::is_none")]
         total_cost: Option<f64>,
         #[serde(skip_serializing_if = "Option::is_none")]
-        prompt_tokens: Option<u32>,
+        input_tokens: Option<u32>,
         #[serde(skip_serializing_if = "Option::is_none")]
-        completion_tokens: Option<u32>,
+        output_tokens: Option<u32>,
     },
 
     /// Agent text output message
diff --git a/codex-rs/core/src/usage.rs b/codex-rs/core/src/usage.rs
new file mode 100644
index 00000000000..5601058adce
--- /dev/null
+++ b/codex-rs/core/src/usage.rs
@@ -0,0 +1,26 @@
+/// Returns the OpenAI per-token pricing (input, output) **in USD** for
+/// a given model name. The list is not exhaustive – it only covers the most
+/// common public models so we offer reasonable estimates without hard-coding
+/// every single variant. Unknown models return `None` so callers can fall
+/// back gracefully.
+pub fn get_openai_pricing(model: &str) -> Option<(f64, f64)> {
+    // Exact mapping (per *token* rates, not per-1K)
+    // Order matters: more specific matches must come before general ones
+    let detailed: &[(&str, (f64, f64))] = &[ // (model, (input, output))
+        ("o3", (10.0 / 1_000_000.0, 40.0 / 1_000_000.0)),
+        ("o4-mini", (1.1 / 1_000_000.0, 4.4 / 1_000_000.0)),
+        ("gpt-4.1-nano", (0.1 / 1_000_000.0, 0.4 / 1_000_000.0)),
+        ("gpt-4.1-mini", (0.4 / 1_000_000.0, 1.6 / 1_000_000.0)),
+        ("gpt-4.1", (2.0 / 1_000_000.0, 8.0 / 1_000_000.0)),
+        ("gpt-4o-mini", (0.6 / 1_000_000.0, 2.4 / 1_000_000.0)),
+        ("gpt-4o", (5.0 / 1_000_000.0, 20.0 / 1_000_000.0)),
+        ("codex-mini-latest", (1.5 / 1_000_000.0, 6.0 / 1_000_000.0)),
+    ];
+
+    let key = model.to_ascii_lowercase();
+    detailed
+        .iter()
+        // We use starts_with to match model variants (e.g., "gpt-4o-2024-11-20" matches "gpt-4o")
+        .find(|(m, _)| key.starts_with(*m))
+        .map(|(_, r)| *r)
+}
\ No newline at end of file
diff --git a/codex-rs/exec/src/event_processor.rs b/codex-rs/exec/src/event_processor.rs
index de74d193aa9..e9a70af534a 100644
--- a/codex-rs/exec/src/event_processor.rs
+++ b/codex-rs/exec/src/event_processor.rs
@@ -117,9 +117,9 @@ impl EventProcessor {
                 let msg = format!("Task started: {id}");
                 ts_println!("{}", msg.style(self.dimmed));
             }
-            EventMsg::TaskComplete { total_cost, prompt_tokens, completion_tokens } => {
+            EventMsg::TaskComplete { total_cost, input_tokens, output_tokens } => {
                 let mut msg = format!("Task complete: {id}");
-                if let (Some(cost), Some(inp), Some(out)) = (total_cost, prompt_tokens, completion_tokens) {
+                if let (Some(cost), Some(inp), Some(out)) = (total_cost, input_tokens, output_tokens) {
                     msg.push_str(&format!(" – cost ${:.4} ({} in / {} out tokens)", cost, inp, out));
                 }
                 ts_println!("{}", msg.style(self.bold));

From 9cbebeec274ed285b955e66e3ea10e6fc9a7e256 Mon Sep 17
00:00:00 2001
From: Arnaud Stiegler
Date: Thu, 22 May 2025 16:21:18 -0400
Subject: [PATCH 4/9] fix

---
 codex-rs/core/src/client.rs |  6 ++----
 codex-rs/core/src/codex.rs  | 16 ++++++++--------
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index dc53f227da9..a628a16b4eb 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -421,13 +421,11 @@ where
                 // Extract usage if present
                 if let Some(usage) = resp_val.get("usage") {
                     input_tokens = usage
-                        .get("prompt_tokens")
-                        .or_else(|| usage.get("input_tokens"))
+                        .get("input_tokens")
                         .and_then(|v| v.as_u64())
                         .map(|v| v as u32);
                     output_tokens = usage
-                        .get("completion_tokens")
-                        .or_else(|| usage.get("output_tokens"))
+                        .get("output_tokens")
                         .and_then(|v| v.as_u64())
                         .map(|v| v as u32);
                 }
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 8457042a244..bbeae712424 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -1017,8 +1017,8 @@ async fn try_run_turn(
     }
 
     let mut output = Vec::new();
-    let mut prompt_tokens: u32 = 0;
-    let mut completion_tokens: u32 = 0;
+    let mut input_tokens: u32 = 0;
+    let mut output_tokens: u32 = 0;
 
     for event in input_events {
         match event {
@@ -1026,22 +1026,22 @@ async fn try_run_turn(
                 let response = handle_response_item(sess, sub_id, item.clone()).await?;
                 output.push(ProcessedResponseItem { item, response });
             }
-            ResponseEvent::Completed { response_id, input_tokens, output_tokens } => {
+            ResponseEvent::Completed { response_id, input_tokens: resp_input_tokens, output_tokens: resp_output_tokens } => {
                 let mut state = sess.state.lock().unwrap();
                 state.previous_response_id = Some(response_id);
 
                 // Add token usage if available
-                if let Some(p) = input_tokens {
-                    prompt_tokens += p;
+                if let Some(p) = resp_input_tokens {
+                    input_tokens += p;
                 }
-                if let Some(c) = output_tokens {
-                    completion_tokens += c;
+                if let Some(c) = resp_output_tokens {
+                    output_tokens += c;
                 }
             }
         }
     }
 
-    Ok((output, prompt_tokens, completion_tokens))
+    Ok((output, input_tokens, output_tokens))
 }
 
 async fn handle_response_item(

From 0bd9d0969111e1c6cd37689e33b8c576269d27e5 Mon Sep 17 00:00:00 2001
From: Arnaud Stiegler
Date: Thu, 22 May 2025 16:52:38 -0400
Subject: [PATCH 5/9] move aggregation to client.rs

---
 codex-rs/core/src/chat_completions.rs | 29 ++++++++++++-----
 codex-rs/core/src/client.rs           | 45 +++++++++++++++++++++------
 codex-rs/core/src/client_common.rs    | 25 +++++++++++++++
 codex-rs/core/src/codex.rs            | 22 +++++--------
 codex-rs/core/src/usage.rs            | 22 +++++++++++++
 5 files changed, 112 insertions(+), 31 deletions(-)

diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index 5a93597ad20..79ef80ce8af 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -1,3 +1,4 @@
+use std::sync::Arc;
 use std::time::Duration;
 
 use bytes::Bytes;
@@ -35,6 +36,7 @@ pub(crate) async fn stream_chat_completions(
     model: &str,
     client: &reqwest::Client,
     provider: &ModelProviderInfo,
+    token_aggregator: Arc<std::sync::Mutex<TokenAggregator>>,
 ) -> Result<ResponseStream> {
     // Build messages array
     let mut messages = Vec::<serde_json::Value>::new();
@@ -88,7 +90,7 @@ pub(crate) async fn stream_chat_completions(
             Ok(resp) if resp.status().is_success() => {
                 let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
                 let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
-                tokio::spawn(process_chat_sse(stream, tx_event));
+                tokio::spawn(process_chat_sse(stream, tx_event, Arc::clone(&token_aggregator)));
                 return Ok(ResponseStream { rx_event });
             }
             Ok(res) => {
@@ -127,7 +129,11 @@ pub(crate) async fn stream_chat_completions(
 /// Lightweight SSE processor for the Chat Completions streaming format. The
 /// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
 /// of the pipeline can stay agnostic of the underlying wire format.
-async fn process_chat_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
+async fn process_chat_sse<S>(
+    stream: S,
+    tx_event: mpsc::Sender<Result<ResponseEvent>>,
+    token_aggregator: Arc<std::sync::Mutex<TokenAggregator>>,
+)
 where
     S: Stream<Item = Result<Bytes, CodexErr>> + Unpin,
 {
@@ -184,17 +190,24 @@ where
         };
 
         // Store usage statistics when received.
-        // For the completion API, keys are "prompt_tokens" and "completion_tokens"
-        // which differs from the keys in the responses API.
+        // Chat Completions API uses "prompt_tokens" and "completion_tokens"
         if let Some(usage) = chunk.get("usage") {
-            input_tokens = usage
+            let usage_input_tokens = usage
                 .get("prompt_tokens")
                 .and_then(|v| v.as_u64())
-                .map(|v| v as u32);
-            output_tokens = usage
+                .map(|v| v as u32)
+                .unwrap_or(0);
+            let usage_output_tokens = usage
                 .get("completion_tokens")
                 .and_then(|v| v.as_u64())
-                .map(|v| v as u32);
+                .map(|v| v as u32)
+                .unwrap_or(0);
+
+            // Add to session aggregator
+            token_aggregator.lock().unwrap().add_usage(usage_input_tokens, usage_output_tokens);
+
+            input_tokens = Some(usage_input_tokens);
+            output_tokens = Some(usage_output_tokens);
         }
 
         let content_opt = chunk
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index a628a16b4eb..7b71c0a61a4 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -1,6 +1,7 @@
 use std::collections::BTreeMap;
 use std::io::BufRead;
 use std::path::Path;
+use std::sync::Arc;
 use std::sync::LazyLock;
 use std::time::Duration;
 
@@ -27,6 +28,7 @@
 use crate::client_common::Reasoning;
 use crate::client_common::ResponseEvent;
 use crate::client_common::ResponseStream;
 use crate::client_common::Summary;
+use crate::client_common::TokenAggregator;
 use crate::error::CodexErr;
 use crate::error::EnvVarError;
 use crate::error::Result;
@@ -107,6 +109,7 @@ pub struct ModelClient {
     model: String,
     client: reqwest::Client,
     provider: ModelProviderInfo,
+    token_aggregator: Arc<std::sync::Mutex<TokenAggregator>>,
 }
 
 impl ModelClient {
@@ -115,6 +118,7 @@ impl ModelClient {
             model: model.to_string(),
             client: reqwest::Client::new(),
             provider,
+            token_aggregator: Arc::new(std::sync::Mutex::new(TokenAggregator::new())),
         }
     }
 
@@ -131,6 +135,16 @@ impl ModelClient {
         &self.provider
     }
 
+    /// Get cumulative token usage for this session
+    pub fn get_session_usage(&self) -> (u32, u32) {
+        self.token_aggregator.lock().unwrap().get_totals()
+    }
+
+    /// Reset token counters (e.g., for new session)
+    pub fn reset_usage(&self) {
+        self.token_aggregator.lock().unwrap().reset()
+    }
+
     /// Dispatches to either the Responses or Chat implementation depending on
     /// the provider config. Public callers always invoke `stream()` – the
     /// specialised helpers are private to avoid accidental misuse.
@@ -140,7 +154,7 @@ impl ModelClient {
             WireApi::Chat => {
                 // Create the raw streaming connection first.
                 let response_stream =
-                    stream_chat_completions(prompt, &self.model, &self.client, &self.provider)
+                    stream_chat_completions(prompt, &self.model, &self.client, &self.provider, Arc::clone(&self.token_aggregator))
                         .await?;
 
                 // Wrap it with the aggregation adapter so callers see *only*
@@ -250,7 +264,7 @@ impl ModelClient {
 
         // spawn task to process SSE
         let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
-        tokio::spawn(process_sse(stream, tx_event));
+        tokio::spawn(process_sse(stream, tx_event, Arc::clone(&self.token_aggregator)));
 
         return Ok(ResponseStream { rx_event });
     }
@@ -324,7 +338,11 @@ struct ResponseCompleted {
     id: String,
 }
 
-async fn process_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
+async fn process_sse<S>(
+    stream: S,
+    tx_event: mpsc::Sender<Result<ResponseEvent>>,
+    token_aggregator: Arc<std::sync::Mutex<TokenAggregator>>,
+)
 where
     S: Stream<Item = Result<Bytes, CodexErr>> + Unpin,
 {
@@ -418,16 +436,24 @@ where
             // Final response completed – includes array of output items & id
             "response.completed" => {
                 if let Some(resp_val) = event.response {
-                    // Extract usage if present
+                    // Extract usage if present (Responses API uses input_tokens/output_tokens)
                     if let Some(usage) = resp_val.get("usage") {
-                        input_tokens = usage
-                            .get("input_tokens")
-                            .and_then(|v| v.as_u64())
-                            .map(|v| v as u32);
-                        output_tokens = usage
-                            .get("output_tokens")
-                            .and_then(|v| v.as_u64())
-                            .map(|v| v as u32);
+                        let usage_input_tokens = usage
+                            .get("input_tokens")
+                            .and_then(|v| v.as_u64())
+                            .map(|v| v as u32)
+                            .unwrap_or(0);
+                        let usage_output_tokens = usage
+                            .get("output_tokens")
+                            .and_then(|v| v.as_u64())
+                            .map(|v| v as u32)
+                            .unwrap_or(0);
+
+                        // Add to session aggregator
+                        token_aggregator.lock().unwrap().add_usage(usage_input_tokens, usage_output_tokens);
+
+                        input_tokens = Some(usage_input_tokens);
+                        output_tokens = Some(usage_output_tokens);
                     }
 
                     match serde_json::from_value::<ResponseCompleted>(resp_val) {
@@ -461,6 +487,7 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
     let rdr = std::io::Cursor::new(content);
     let stream = ReaderStream::new(rdr).map_err(CodexErr::Io);
-    tokio::spawn(process_sse(stream, tx_event));
+    let dummy_aggregator = Arc::new(std::sync::Mutex::new(TokenAggregator::new()));
+    tokio::spawn(process_sse(stream, tx_event, dummy_aggregator));
     Ok(ResponseStream { rx_event })
 }
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 47a193ec1f0..acddaa27772 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -54,6 +54,31 @@ pub enum ResponseEvent {
     },
 }
 
+#[derive(Debug, Default)]
+pub struct TokenAggregator {
+    total_input_tokens: u32,
+    total_output_tokens: u32,
+}
+
+impl TokenAggregator {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn add_usage(&mut self, input_tokens: u32, output_tokens: u32) {
+        self.total_input_tokens += input_tokens;
+        self.total_output_tokens += output_tokens;
+    }
+
+    pub fn get_totals(&self) -> (u32, u32) {
+        (self.total_input_tokens, self.total_output_tokens)
+    }
+
+    pub fn reset(&mut self) {
+        self.total_input_tokens = 0;
+        self.total_output_tokens = 0;
+    }
+}
 
 #[derive(Debug, Serialize)]
 pub(crate) struct Reasoning {
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index bbeae712424..97e5fafff43 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -767,13 +767,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
         return;
     }
 
-    // Track overall token usage for this task so we can expose usage/cost
-    // statistics in the final `TaskComplete` event.
These are *approximate* - // counts based on a naive 4-character-per-token heuristic – sufficient - // for ballpark cost estimation without pulling in a heavyweight tokenizer - // dependency. - let mut total_input_tokens: usize = 0; - let mut total_output_tokens: usize = 0; + // Token usage is now tracked automatically at the client level let mut pending_response_input: Vec = vec![ResponseInputItem::from(input)]; loop { @@ -836,10 +830,8 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { .collect(); match run_turn(&sess, sub_id.clone(), turn_input).await { - Ok((turn_output, usage_in, usage_out)) => { - // Accumulate exact token usage from API - total_input_tokens += usage_in as usize; - total_output_tokens += usage_out as usize; + Ok((turn_output, _usage_in, _usage_out)) => { + // Token usage is now accumulated automatically in the client let (items, responses): (Vec<_>, Vec<_>) = turn_output .into_iter() @@ -889,7 +881,9 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { } sess.remove_task(&sub_id); - // Calculate total cost for OpenAI models if possible. + // Get aggregated usage from client and calculate cost for OpenAI models + let (total_input_tokens, total_output_tokens) = sess.client.get_session_usage(); + let is_openai_provider = sess .client .provider() @@ -905,8 +899,8 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { + (total_output_tokens as f64) * per_output_token_cost; ( Some(cost), - Some(total_input_tokens as u32), - Some(total_output_tokens as u32), + Some(total_input_tokens), + Some(total_output_tokens), ) } else { (None, None, None) diff --git a/codex-rs/core/src/usage.rs b/codex-rs/core/src/usage.rs index 5601058adce..f63b3adaa66 100644 --- a/codex-rs/core/src/usage.rs +++ b/codex-rs/core/src/usage.rs @@ -23,4 +23,26 @@ pub fn get_openai_pricing(model: &str) -> Option<(f64, f64)> { // We use starts_with to match model variants (e.g., "gpt-4o-2024-11-20" matches "gpt-4o") .find(|(m, _)| key.starts_with(*m)) .map(|(_, r)| *r) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_openai_pricing() { + // Test exact matches + assert_eq!(get_openai_pricing("gpt-4o-mini"), Some((0.6 / 1_000_000.0, 2.4 / 1_000_000.0))); + assert_eq!(get_openai_pricing("gpt-4o"), Some((5.0 / 1_000_000.0, 20.0 / 1_000_000.0))); + assert_eq!(get_openai_pricing("codex-mini-latest"), Some((1.5 / 1_000_000.0, 6.0 / 1_000_000.0))); + + // Test model variants (should match prefix) + assert_eq!(get_openai_pricing("gpt-4o-2024-11-20"), Some((5.0 / 1_000_000.0, 20.0 / 1_000_000.0))); + + // Test unknown model + assert_eq!(get_openai_pricing("unknown-model"), None); + + // Test case insensitive + assert_eq!(get_openai_pricing("GPT-4O-MINI"), Some((0.6 / 1_000_000.0, 2.4 / 1_000_000.0))); + } } \ No newline at end of file From f8bde98b68857ddeda6be407cb348cbc1400f61e Mon Sep 17 00:00:00 2001 From: Arnaud Stiegler Date: Thu, 22 May 2025 17:04:21 -0400 Subject: [PATCH 6/9] cleanup --- codex-rs/core/src/codex.rs | 34 ++++++---------------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 97e5fafff43..7eeb41da328 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -194,7 +194,6 @@ impl Session { } } - /// Mutable state of the agent #[derive(Default)] struct State { @@ -767,8 +766,6 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { return; } - // Token usage is now tracked automatically at the client level - let mut pending_response_input: 
Vec = vec![ResponseInputItem::from(input)]; loop { let mut net_new_turn_input = pending_response_input @@ -811,10 +808,6 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { net_new_turn_input }; - // Token accounting for providers that do not return detailed usage - // stats. We add *approximate* counts here and later replace them with - // exact ones when `usage_in` becomes available. - let turn_input_messages: Vec = turn_input .iter() .filter_map(|item| match item { @@ -828,11 +821,8 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { }) }) .collect(); - match run_turn(&sess, sub_id.clone(), turn_input).await { - Ok((turn_output, _usage_in, _usage_out)) => { - // Token usage is now accumulated automatically in the client - + Ok(turn_output) => { let (items, responses): (Vec<_>, Vec<_>) = turn_output .into_iter() .map(|p| (p.item, p.response)) @@ -921,7 +911,7 @@ async fn run_turn( sess: &Session, sub_id: String, input: Vec, -) -> CodexResult<(Vec, u32, u32)> { +) -> CodexResult> { // Decide whether to use server-side storage (previous_response_id) or disable it let (prev_id, store, is_first_turn) = { let state = sess.state.lock().unwrap(); @@ -955,7 +945,7 @@ async fn run_turn( let mut retries = 0; loop { match try_run_turn(sess, &sub_id, &prompt).await { - Ok(res) => return Ok(res), + Ok(output) => return Ok(output), Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted), Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), Err(e) => { @@ -1001,7 +991,7 @@ async fn try_run_turn( sess: &Session, sub_id: &str, prompt: &Prompt, -) -> CodexResult<(Vec, u32, u32)> { +) -> CodexResult> { let mut stream = sess.client.clone().stream(prompt).await?; // Buffer all incoming messages first as before. @@ -1011,31 +1001,19 @@ async fn try_run_turn( } let mut output = Vec::new(); - let mut input_tokens: u32 = 0; - let mut output_tokens: u32 = 0; - for event in input_events { match event { ResponseEvent::OutputItemDone(item) => { let response = handle_response_item(sess, sub_id, item.clone()).await?; output.push(ProcessedResponseItem { item, response }); } - ResponseEvent::Completed { response_id, input_tokens: resp_input_tokens, output_tokens: resp_output_tokens } => { + ResponseEvent::Completed { response_id, input_tokens: _resp_input_tokens, output_tokens: _resp_output_tokens } => { let mut state = sess.state.lock().unwrap(); state.previous_response_id = Some(response_id); - - // Add token usage if available - if let Some(p) = resp_input_tokens { - input_tokens += p; - } - if let Some(c) = resp_output_tokens { - output_tokens += c; - } } } } - - Ok((output, input_tokens, output_tokens)) + Ok(output) } async fn handle_response_item( From 5838ea5feaa1318470d9eef0c6d94afad8d84135 Mon Sep 17 00:00:00 2001 From: Arnaud Stiegler Date: Fri, 23 May 2025 09:58:59 -0400 Subject: [PATCH 7/9] fmt --- codex-rs/core/src/chat_completions.rs | 30 ++++++++---- codex-rs/core/src/client.rs | 58 ++++++++++++---------- codex-rs/core/src/client_common.rs | 13 ++--- codex-rs/core/src/codex.rs | 18 ++++--- codex-rs/core/src/lib.rs | 2 +- codex-rs/core/src/usage.rs | 70 ++++++++++++++++----------- codex-rs/exec/src/event_processor.rs | 7 +-- codex-rs/exec/src/lib.rs | 4 +- 8 files changed, 115 insertions(+), 87 deletions(-) diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index 79ef80ce8af..4170c033fba 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -26,8 +26,8 @@ use 
crate::flags::OPENAI_REQUEST_MAX_RETRIES; use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS; use crate::models::ContentItem; use crate::models::ResponseItem; -use crate::util::backoff; use crate::util::UrlExt; +use crate::util::backoff; /// Implementation for the classic Chat Completions API. This is intentionally /// minimal: we only stream back plain assistant text. @@ -66,7 +66,11 @@ pub(crate) async fn stream_chat_completions( "stream_options": {"include_usage": true} }); - let url = provider.base_url.clone().append_path("/chat/completions")?.to_string(); + let url = provider + .base_url + .clone() + .append_path("/chat/completions")? + .to_string(); debug!("{} POST (chat)", &url); trace!("request payload: {}", payload); @@ -90,7 +94,11 @@ pub(crate) async fn stream_chat_completions( Ok(resp) if resp.status().is_success() => { let (tx_event, rx_event) = mpsc::channel::>(16); let stream = resp.bytes_stream().map_err(CodexErr::Reqwest); - tokio::spawn(process_chat_sse(stream, tx_event, Arc::clone(&token_aggregator))); + tokio::spawn(process_chat_sse( + stream, + tx_event, + Arc::clone(&token_aggregator), + )); return Ok(ResponseStream { rx_event }); } Ok(res) => { @@ -130,17 +138,16 @@ pub(crate) async fn stream_chat_completions( /// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest /// of the pipeline can stay agnostic of the underlying wire format. async fn process_chat_sse( - stream: S, + stream: S, tx_event: mpsc::Sender>, token_aggregator: Arc>, -) -where +) where S: Stream> + Unpin, { let mut stream = stream.eventsource(); let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS; - + // Track usage information to include in final completion event let mut input_tokens = None; let mut output_tokens = None; @@ -202,10 +209,13 @@ where .and_then(|v| v.as_u64()) .map(|v| v as u32) .unwrap_or(0); - + // Add to session aggregator - token_aggregator.lock().unwrap().add_usage(usage_input_tokens, usage_output_tokens); - + token_aggregator + .lock() + .unwrap() + .add_token_usage(usage_input_tokens, usage_output_tokens); + input_tokens = Some(usage_input_tokens); output_tokens = Some(usage_output_tokens); } diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 7b71c0a61a4..6e5b9c70533 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -38,8 +38,8 @@ use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS; use crate::model_provider_info::ModelProviderInfo; use crate::model_provider_info::WireApi; use crate::models::ResponseItem; -use crate::util::backoff; use crate::util::UrlExt; +use crate::util::backoff; /// When serialized as JSON, this produces a valid "Tool" in the OpenAI /// Responses API. @@ -136,13 +136,8 @@ impl ModelClient { } /// Get cumulative token usage for this session - pub fn get_session_usage(&self) -> (u32, u32) { - self.token_aggregator.lock().unwrap().get_totals() - } - - /// Reset token counters (e.g., for new session) - pub fn reset_usage(&self) { - self.token_aggregator.lock().unwrap().reset() + pub fn get_session_token_usage(&self) -> (u32, u32) { + self.token_aggregator.lock().unwrap().get_token_totals() } /// Dispatches to either the Responses or Chat implementation depending on @@ -153,9 +148,14 @@ impl ModelClient { WireApi::Responses => self.stream_responses(prompt).await, WireApi::Chat => { // Create the raw streaming connection first. 
- let response_stream = - stream_chat_completions(prompt, &self.model, &self.client, &self.provider, Arc::clone(&self.token_aggregator)) - .await?; + let response_stream = stream_chat_completions( + prompt, + &self.model, + &self.client, + &self.provider, + Arc::clone(&self.token_aggregator), + ) + .await?; // Wrap it with the aggregation adapter so callers see *only* // the final assistant message per turn (matching the @@ -226,7 +226,12 @@ impl ModelClient { stream: true, }; - let url = self.provider.base_url.clone().append_path("/responses")?.to_string(); + let url = self + .provider + .base_url + .clone() + .append_path("/responses")? + .to_string(); debug!("{} POST", url); trace!("request payload: {}", serde_json::to_string(&payload)?); @@ -264,7 +269,11 @@ impl ModelClient { // spawn task to process SSE let stream = resp.bytes_stream().map_err(CodexErr::Reqwest); - tokio::spawn(process_sse(stream, tx_event, Arc::clone(&self.token_aggregator))); + tokio::spawn(process_sse( + stream, + tx_event, + Arc::clone(&self.token_aggregator), + )); return Ok(ResponseStream { rx_event }); } @@ -339,11 +348,10 @@ struct ResponseCompleted { } async fn process_sse( - stream: S, + stream: S, tx_event: mpsc::Sender>, token_aggregator: Arc>, -) -where +) where S: Stream> + Unpin, { let mut stream = stream.eventsource(); @@ -441,17 +449,17 @@ where let usage_input_tokens = usage .get("input_tokens") .and_then(|v| v.as_u64()) - .map(|v| v as u32) - .unwrap_or(0); + .unwrap_or(0) as u32; let usage_output_tokens = usage .get("output_tokens") .and_then(|v| v.as_u64()) - .map(|v| v as u32) - .unwrap_or(0); - - // Add to session aggregator - token_aggregator.lock().unwrap().add_usage(usage_input_tokens, usage_output_tokens); - + .unwrap_or(0) as u32; + + token_aggregator + .lock() + .unwrap() + .add_token_usage(usage_input_tokens, usage_output_tokens); + input_tokens = Some(usage_input_tokens); output_tokens = Some(usage_output_tokens); } @@ -465,7 +473,7 @@ where continue; } }; - } + }; } other => debug!(other, "sse event"), } diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index acddaa27772..11484f22c97 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -64,20 +64,15 @@ impl TokenAggregator { pub fn new() -> Self { Self::default() } - - pub fn add_usage(&mut self, input_tokens: u32, output_tokens: u32) { + + pub fn add_token_usage(&mut self, input_tokens: u32, output_tokens: u32) { self.total_input_tokens += input_tokens; self.total_output_tokens += output_tokens; } - - pub fn get_totals(&self) -> (u32, u32) { + + pub fn get_token_totals(&self) -> (u32, u32) { (self.total_input_tokens, self.total_output_tokens) } - - pub fn reset(&mut self) { - self.total_input_tokens = 0; - self.total_output_tokens = 0; - } } #[derive(Debug, Serialize)] diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 7eeb41da328..bcc52fadbe0 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -38,7 +38,6 @@ use crate::client_common::Prompt; use crate::client_common::ResponseEvent; use crate::config::Config; use crate::conversation_history::ConversationHistory; -use crate::usage::get_openai_pricing; use crate::error::CodexErr; use crate::error::Result as CodexResult; use crate::error::SandboxErr; @@ -82,6 +81,7 @@ use crate::rollout::RolloutRecorder; use crate::safety::SafetyCheck; use crate::safety::assess_command_safety; use crate::safety::assess_patch_safety; +use crate::usage::get_openai_pricing; use 
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 7eeb41da328..bcc52fadbe0 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -38,7 +38,6 @@ use crate::client_common::Prompt;
 use crate::client_common::ResponseEvent;
 use crate::config::Config;
 use crate::conversation_history::ConversationHistory;
-use crate::usage::get_openai_pricing;
 use crate::error::CodexErr;
 use crate::error::Result as CodexResult;
 use crate::error::SandboxErr;
@@ -82,6 +81,7 @@ use crate::rollout::RolloutRecorder;
 use crate::safety::SafetyCheck;
 use crate::safety::assess_command_safety;
 use crate::safety::assess_patch_safety;
+use crate::usage::get_openai_pricing;
 use crate::user_notification::UserNotification;
 use crate::util::backoff;
 
@@ -872,8 +872,8 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
     sess.remove_task(&sub_id);
 
     // Get aggregated usage from client and calculate cost for OpenAI models
-    let (total_input_tokens, total_output_tokens) = sess.client.get_session_usage();
-
+    let (total_input_tokens, total_output_tokens) = sess.client.get_session_token_usage();
+
     let is_openai_provider = sess
         .client
         .provider()
@@ -883,7 +883,8 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
     let (total_cost_opt, input_tokens_opt, output_tokens_opt) = if is_openai_provider {
         let model = sess.client.model_name();
-        let (per_input_token_cost, per_output_token_cost) = get_openai_pricing(model).unwrap_or((0.0, 0.0));
+        let (per_input_token_cost, per_output_token_cost) =
+            get_openai_pricing(model).unwrap_or((0.0, 0.0));
         // Rates are per-token. Multiply directly.
         let cost = (total_input_tokens as f64) * per_input_token_cost
             + (total_output_tokens as f64) * per_output_token_cost;
@@ -995,21 +996,22 @@ async fn try_run_turn(
     let mut stream = sess.client.clone().stream(prompt).await?;
 
     // Buffer all incoming messages first as before.
-    let mut input_events = Vec::new();
+    let mut input = Vec::new();
     while let Some(event) = stream.next().await {
-        input_events.push(event?);
+        input.push(event?);
     }
 
     let mut output = Vec::new();
-    for event in input_events {
+    for event in input {
         match event {
             ResponseEvent::OutputItemDone(item) => {
                 let response = handle_response_item(sess, sub_id, item.clone()).await?;
                 output.push(ProcessedResponseItem { item, response });
             }
-            ResponseEvent::Completed { response_id, input_tokens: _resp_input_tokens, output_tokens: _resp_output_tokens } => {
+            ResponseEvent::Completed { response_id, .. } => {
                 let mut state = sess.state.lock().unwrap();
                 state.previous_response_id = Some(response_id);
+                break;
             }
         }
     }
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index af89b009cbb..84d1f4827dc 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -14,7 +14,6 @@ pub mod codex_wrapper;
 pub mod config;
 pub mod config_profile;
 mod conversation_history;
-pub mod usage;
 pub mod error;
 pub mod exec;
 pub mod exec_linux;
@@ -27,6 +26,7 @@ pub mod mcp_server_config;
 mod mcp_tool_call;
 mod message_history;
 mod model_provider_info;
+pub mod usage;
 pub use model_provider_info::ModelProviderInfo;
 pub use model_provider_info::WireApi;
 mod models;
diff --git a/codex-rs/core/src/usage.rs b/codex-rs/core/src/usage.rs
index f63b3adaa66..d3a7b5b5cb3 100644
--- a/codex-rs/core/src/usage.rs
+++ b/codex-rs/core/src/usage.rs
@@ -4,25 +4,26 @@
 /// every single variant. Unknown models return `None` so callers can fall
 /// back gracefully.
pub fn get_openai_pricing(model: &str) -> Option<(f64, f64)> { - // Exact mapping (per *token* rates, not per-1K) - // Order matters: more specific matches must come before general ones - let detailed: &[(&str, (f64, f64))] = &[ // (model, (input, output)) - ("o3", (10.0 / 1_000_000.0, 40.0 / 1_000_000.0)), - ("o4-mini", (1.1 / 1_000_000.0, 4.4 / 1_000_000.0)), - ("gpt-4.1-nano", (0.1 / 1_000_000.0, 0.4 / 1_000_000.0)), - ("gpt-4.1-mini", (0.4 / 1_000_000.0, 1.6 / 1_000_000.0)), - ("gpt-4.1", (2.0 / 1_000_000.0, 8.0 / 1_000_000.0)), - ("gpt-4o-mini", (0.6 / 1_000_000.0, 2.4 / 1_000_000.0)), - ("gpt-4o", (5.0 / 1_000_000.0, 20.0 / 1_000_000.0)), - ("codex-mini-latest", (1.5 / 1_000_000.0, 6.0 / 1_000_000.0)), - ]; + // Exact mapping (per *token* rates, not per-1K) + // Order matters: more specific matches must come before general ones + let detailed: &[(&str, (f64, f64))] = &[ + // (model, (input, output)) + ("o3", (10.0 / 1_000_000.0, 40.0 / 1_000_000.0)), + ("o4-mini", (1.1 / 1_000_000.0, 4.4 / 1_000_000.0)), + ("gpt-4.1-nano", (0.1 / 1_000_000.0, 0.4 / 1_000_000.0)), + ("gpt-4.1-mini", (0.4 / 1_000_000.0, 1.6 / 1_000_000.0)), + ("gpt-4.1", (2.0 / 1_000_000.0, 8.0 / 1_000_000.0)), + ("gpt-4o-mini", (0.6 / 1_000_000.0, 2.4 / 1_000_000.0)), + ("gpt-4o", (5.0 / 1_000_000.0, 20.0 / 1_000_000.0)), + ("codex-mini-latest", (1.5 / 1_000_000.0, 6.0 / 1_000_000.0)), + ]; - let key = model.to_ascii_lowercase(); - detailed - .iter() - // We use starts_with to match model variants (e.g., "gpt-4o-2024-11-20" matches "gpt-4o") - .find(|(m, _)| key.starts_with(*m)) - .map(|(_, r)| *r) + let key = model.to_ascii_lowercase(); + detailed + .iter() + // We use starts_with to match model variants (e.g., "gpt-4o-2024-11-20" matches "gpt-4o") + .find(|(m, _)| key.starts_with(*m)) + .map(|(_, r)| *r) } #[cfg(test)] @@ -32,17 +33,32 @@ mod tests { #[test] fn test_openai_pricing() { // Test exact matches - assert_eq!(get_openai_pricing("gpt-4o-mini"), Some((0.6 / 1_000_000.0, 2.4 / 1_000_000.0))); - assert_eq!(get_openai_pricing("gpt-4o"), Some((5.0 / 1_000_000.0, 20.0 / 1_000_000.0))); - assert_eq!(get_openai_pricing("codex-mini-latest"), Some((1.5 / 1_000_000.0, 6.0 / 1_000_000.0))); - + assert_eq!( + get_openai_pricing("gpt-4o-mini"), + Some((0.6 / 1_000_000.0, 2.4 / 1_000_000.0)) + ); + assert_eq!( + get_openai_pricing("gpt-4o"), + Some((5.0 / 1_000_000.0, 20.0 / 1_000_000.0)) + ); + assert_eq!( + get_openai_pricing("codex-mini-latest"), + Some((1.5 / 1_000_000.0, 6.0 / 1_000_000.0)) + ); + // Test model variants (should match prefix) - assert_eq!(get_openai_pricing("gpt-4o-2024-11-20"), Some((5.0 / 1_000_000.0, 20.0 / 1_000_000.0))); - + assert_eq!( + get_openai_pricing("gpt-4o-2024-11-20"), + Some((5.0 / 1_000_000.0, 20.0 / 1_000_000.0)) + ); + // Test unknown model assert_eq!(get_openai_pricing("unknown-model"), None); - + // Test case insensitive - assert_eq!(get_openai_pricing("GPT-4O-MINI"), Some((0.6 / 1_000_000.0, 2.4 / 1_000_000.0))); + assert_eq!( + get_openai_pricing("GPT-4O-MINI"), + Some((0.6 / 1_000_000.0, 2.4 / 1_000_000.0)) + ); } -} \ No newline at end of file +} diff --git a/codex-rs/exec/src/event_processor.rs b/codex-rs/exec/src/event_processor.rs index e9a70af534a..4c8278cc597 100644 --- a/codex-rs/exec/src/event_processor.rs +++ b/codex-rs/exec/src/event_processor.rs @@ -117,11 +117,8 @@ impl EventProcessor { let msg = format!("Task started: {id}"); ts_println!("{}", msg.style(self.dimmed)); } - EventMsg::TaskComplete { total_cost, input_tokens, output_tokens } => { - let mut msg = 
format!("Task complete: {id}"); - if let (Some(cost), Some(inp), Some(out)) = (total_cost, input_tokens, output_tokens) { - msg.push_str(&format!(" – cost ${:.4} ({} in / {} out tokens)", cost, inp, out)); - } + EventMsg::TaskComplete => { + let msg = format!("Task complete: {id}"); ts_println!("{}", msg.style(self.bold)); } EventMsg::AgentMessage(AgentMessageEvent { message }) => { diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index f362ab67abe..348bff08e61 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -137,7 +137,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> { let initial_images_event_id = codex.submit(Op::UserInput { items }).await?; info!("Sent images with event ID: {initial_images_event_id}"); while let Ok(event) = codex.next_event().await { - if event.id == initial_images_event_id && matches!(event.msg, EventMsg::TaskComplete { .. }) { + if event.id == initial_images_event_id && matches!(event.msg, EventMsg::TaskComplete) { break; } } @@ -152,7 +152,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> { let mut event_processor = EventProcessor::create_with_ansi(stdout_with_ansi); while let Some(event) = rx.recv().await { let last_event = - event.id == initial_prompt_task_id && matches!(event.msg, EventMsg::TaskComplete { .. }); + event.id == initial_prompt_task_id && matches!(event.msg, EventMsg::TaskComplete); event_processor.process_event(event); if last_event { break; From 678c909c67053b29193daa1bdb44112890b81d4e Mon Sep 17 00:00:00 2001 From: Arnaud Stiegler Date: Fri, 23 May 2025 10:35:59 -0400 Subject: [PATCH 8/9] fixes --- codex-rs/core/src/codex.rs | 30 +++++------ codex-rs/core/src/exec.rs | 1 - codex-rs/core/src/protocol.rs | 79 +++++++++++++++++++++++++--- codex-rs/core/src/usage.rs | 25 +++++++++ codex-rs/core/src/util.rs | 22 ++++---- codex-rs/core/tests/test_url_ext.rs | 12 +++-- codex-rs/exec/src/event_processor.rs | 13 ++++- codex-rs/exec/src/lib.rs | 4 +- 8 files changed, 144 insertions(+), 42 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index bcc52fadbe0..83828a19f49 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -81,7 +81,7 @@ use crate::rollout::RolloutRecorder; use crate::safety::SafetyCheck; use crate::safety::assess_command_safety; use crate::safety::assess_patch_safety; -use crate::usage::get_openai_pricing; +use crate::usage::compute_openai_cost; use crate::user_notification::UserNotification; use crate::util::backoff; @@ -881,28 +881,26 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { .as_str() .contains("openai.com"); - let (total_cost_opt, input_tokens_opt, output_tokens_opt) = if is_openai_provider { + let token_usage_opt = if is_openai_provider { let model = sess.client.model_name(); - let (per_input_token_cost, per_output_token_cost) = - get_openai_pricing(model).unwrap_or((0.0, 0.0)); - // Rates are per-token. Multiply directly. 
-        let cost = (total_input_tokens as f64) * per_input_token_cost
-            + (total_output_tokens as f64) * per_output_token_cost;
-        (
-            Some(cost),
-            Some(total_input_tokens),
-            Some(total_output_tokens),
-        )
+        let cost = compute_openai_cost(model, total_input_tokens, total_output_tokens);
+        Some(crate::protocol::TokenUsage {
+            input_tokens: total_input_tokens,
+            output_tokens: total_output_tokens,
+            total_cost: cost,
+        })
     } else {
-        (None, None, None)
+        Some(crate::protocol::TokenUsage {
+            input_tokens: total_input_tokens,
+            output_tokens: total_output_tokens,
+            total_cost: None,
+        })
     };
 
     let event = Event {
         id: sub_id,
         msg: EventMsg::TaskComplete {
-            total_cost: total_cost_opt,
-            input_tokens: input_tokens_opt,
-            output_tokens: output_tokens_opt,
+            token_usage: token_usage_opt,
         },
     };
     sess.tx_event.send(event).await.ok();
diff --git a/codex-rs/core/src/exec.rs b/codex-rs/core/src/exec.rs
index 8e9420a029b..d9aeea5be3e 100644
--- a/codex-rs/core/src/exec.rs
+++ b/codex-rs/core/src/exec.rs
@@ -25,7 +25,6 @@ use crate::protocol::SandboxPolicy;
 
 use once_cell::sync::Lazy;
 
-
 /// Each value is parsed as an unsigned integer. If parsing fails, or the
 /// environment variable is unset, we fall back to the hard-coded default.
 pub(crate) static MAX_STREAM_OUTPUT: Lazy = Lazy::new(|| {
diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs
index d4f3b781cf3..894e654e0ef 100644
--- a/codex-rs/core/src/protocol.rs
+++ b/codex-rs/core/src/protocol.rs
@@ -321,15 +321,11 @@ pub enum EventMsg {
     TaskStarted,
 
     /// Agent has completed all actions. When using an OpenAI provider, the
-    /// server includes the total cost in USD as well as token usage metrics
-    /// for the entire task. For non-OpenAI providers these fields are `null`.
+    /// server includes token usage metrics for the entire task, plus the total
+    /// cost in USD; for non-OpenAI providers `total_cost` is omitted.
     TaskComplete {
-        #[serde(skip_serializing_if = "Option::is_none")]
-        total_cost: Option<f64>,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        input_tokens: Option<u32>,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        output_tokens: Option<u32>,
+        #[serde(flatten)]
+        token_usage: Option<TokenUsage>,
     },
 
     /// Agent text output message
@@ -367,6 +363,15 @@ pub enum EventMsg {
     GetHistoryEntryResponse(GetHistoryEntryResponseEvent),
 }
 
+/// Token usage information for API calls
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct TokenUsage {
+    pub input_tokens: u32,
+    pub output_tokens: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub total_cost: Option<f64>,
+}
+
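On the consumer side, the flattened variant destructures as sketched below. These are standalone stand-ins for `TokenUsage` and `EventMsg` rather than the crate's real definitions, with values echoing the serialization tests:

struct TokenUsage {
    input_tokens: u32,
    output_tokens: u32,
    total_cost: Option<f64>,
}

enum EventMsg {
    TaskComplete { token_usage: Option<TokenUsage> },
}

fn describe(msg: &EventMsg) -> String {
    match msg {
        EventMsg::TaskComplete { token_usage: Some(u) } => {
            // Always report token counts; append cost only when it was priced.
            let mut s = format!("{} in / {} out tokens", u.input_tokens, u.output_tokens);
            if let Some(c) = u.total_cost {
                s.push_str(&format!(" (${c:.4})"));
            }
            s
        }
        EventMsg::TaskComplete { token_usage: None } => "no usage reported".to_string(),
    }
}

fn main() {
    let msg = EventMsg::TaskComplete {
        token_usage: Some(TokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            total_cost: Some(0.0125),
        }),
    };
    assert_eq!(describe(&msg), "1000 in / 500 out tokens ($0.0125)");
}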
 // Individual event payload types matching each `EventMsg` variant.
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -569,4 +575,63 @@ mod tests {
             r#"{"id":"1234","msg":{"type":"session_configured","session_id":"67e55044-10b1-426f-9247-bb680e5fe0c8","model":"o4-mini","history_log_id":0,"history_entry_count":0}}"#
         );
     }
+
+    /// Serialize TaskComplete event with token usage to verify JSON format
+    #[test]
+    fn serialize_task_complete_with_usage() {
+        let event = Event {
+            id: "5678".to_string(),
+            msg: EventMsg::TaskComplete {
+                token_usage: Some(TokenUsage {
+                    input_tokens: 1000,
+                    output_tokens: 500,
+                    total_cost: Some(0.0125),
+                }),
+            },
+        };
+        let serialized = serde_json::to_string(&event).unwrap();
+        println!("JSON with cost: {}", serialized);
+        assert_eq!(
+            serialized,
+            r#"{"id":"5678","msg":{"type":"task_complete","input_tokens":1000,"output_tokens":500,"total_cost":0.0125}}"#
+        );
+    }
+
+    /// Serialize TaskComplete event without cost to verify JSON format
+    #[test]
+    fn serialize_task_complete_no_cost() {
+        let event = Event {
+            id: "9999".to_string(),
+            msg: EventMsg::TaskComplete {
+                token_usage: Some(TokenUsage {
+                    input_tokens: 1500,
+                    output_tokens: 750,
+                    total_cost: None,
+                }),
+            },
+        };
+        let serialized = serde_json::to_string(&event).unwrap();
+        println!("JSON without cost: {}", serialized);
+        assert_eq!(
+            serialized,
+            r#"{"id":"9999","msg":{"type":"task_complete","input_tokens":1500,"output_tokens":750}}"#
+        );
+    }
+
+    /// Serialize TaskComplete event with no token usage
+    #[test]
+    fn serialize_task_complete_no_usage() {
+        let event = Event {
+            id: "0000".to_string(),
+            msg: EventMsg::TaskComplete {
+                token_usage: None,
+            },
+        };
+        let serialized = serde_json::to_string(&event).unwrap();
+        println!("JSON no usage: {}", serialized);
+        assert_eq!(
+            serialized,
+            r#"{"id":"0000","msg":{"type":"task_complete"}}"#
+        );
+    }
 }
diff --git a/codex-rs/core/src/usage.rs b/codex-rs/core/src/usage.rs
index d3a7b5b5cb3..70c8ae5f83e 100644
--- a/codex-rs/core/src/usage.rs
+++ b/codex-rs/core/src/usage.rs
@@ -1,3 +1,13 @@
+/// Computes the total cost in USD for the given token usage using OpenAI pricing.
+/// Returns `None` if the model is unknown or pricing is unavailable.
+pub fn compute_openai_cost(model: &str, input_tokens: u32, output_tokens: u32) -> Option<f64> {
+    let (per_input_token_cost, per_output_token_cost) = get_openai_pricing(model)?;
+    // Rates are per-token. Multiply directly.
+    let cost = (input_tokens as f64) * per_input_token_cost
+        + (output_tokens as f64) * per_output_token_cost;
+    Some(cost)
+}
+
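To make the per-token arithmetic concrete, here is a worked example using the gpt-4o-mini rates from the pricing table; the token counts of 1000 input and 500 output are illustrative:

fn main() {
    // gpt-4o-mini: $0.60 per 1M input tokens, $2.40 per 1M output tokens.
    let (input_rate, output_rate) = (0.6 / 1_000_000.0, 2.4 / 1_000_000.0);
    let (input_tokens, output_tokens) = (1000u32, 500u32);
    let cost = (input_tokens as f64) * input_rate + (output_tokens as f64) * output_rate;
    // 1000 * $0.0000006 + 500 * $0.0000024 = $0.0006 + $0.0012 = $0.0018
    assert!((cost - 0.0018).abs() < 1e-12);
    println!("estimated cost: ${cost:.4}"); // prints: estimated cost: $0.0018
}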
 /// Returns the OpenAI per-token pricing (input, output) **in USD** for
 /// a given model name. The list is not exhaustive – it only covers the most
 /// common public models so we offer reasonable estimates without hard-coding
@@ -30,6 +40,21 @@ pub fn get_openai_pricing(model: &str) -> Option<(f64, f64)> {
 mod tests {
     use super::*;
 
+    #[test]
+    fn test_compute_openai_cost() {
+        // Test cost computation for known model
+        let cost = compute_openai_cost("gpt-4o-mini", 1000, 500).unwrap();
+        let expected = (1000.0 * 0.6 / 1_000_000.0) + (500.0 * 2.4 / 1_000_000.0);
+        assert_eq!(cost, expected);
+
+        // Test cost computation for unknown model
+        assert_eq!(compute_openai_cost("unknown-model", 1000, 500), None);
+
+        // Test zero tokens
+        let cost = compute_openai_cost("gpt-4o-mini", 0, 0).unwrap();
+        assert_eq!(cost, 0.0);
+    }
+
     #[test]
     fn test_openai_pricing() {
         // Test exact matches
diff --git a/codex-rs/core/src/util.rs b/codex-rs/core/src/util.rs
index 7a86e4519bc..2c252e5f78f 100644
--- a/codex-rs/core/src/util.rs
+++ b/codex-rs/core/src/util.rs
@@ -71,7 +71,7 @@ pub trait UrlExt {
     /// Append a path to the URL, without modifying the original URL components.
     /// It allows us to configure query parameters and carry them over when we use
     /// different Wire API endpoints.
-    ///
+    ///
     /// This is necessary as some APIs (e.g. Azure OpenAI) require query parameters
     /// to select different versions.
     fn append_path(self, path: &str) -> Result<Url>;
@@ -80,17 +80,16 @@ impl UrlExt for Url {
     fn append_path(self, path: &str) -> Result<Url> {
         let mut url = self.clone();
-
+
         // Validate path doesn't contain invalid characters
         if path.contains(|c: char| c.is_whitespace() || c == '?' || c == '#') {
-            return Err(anyhow::anyhow!("Invalid path: contains whitespace or special characters"));
+            return Err(anyhow::anyhow!(
+                "Invalid path: contains whitespace or special characters"
+            ));
         }
 
         // Split the path into segments, filtering out empty ones
-        let segments: Vec<&str> = path
-            .split('/')
-            .filter(|s| !s.is_empty())
-            .collect();
+        let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
 
         if segments.is_empty() {
             return Ok(url);
@@ -98,12 +97,13 @@ impl UrlExt for Url {
 
         // Get path segments and add new segments
         {
-            let mut path_segments = url.path_segments_mut()
+            let mut path_segments = url
+                .path_segments_mut()
                 .map_err(|_| anyhow::anyhow!("Failed to get path segments"))?;
-
+
             // Remove trailing empty segment if it exists
             path_segments.pop_if_empty();
-
+
             // Add each non-empty segment
             for segment in segments {
                 path_segments.push(segment);
@@ -112,4 +112,4 @@ impl UrlExt for Url {
 
         Ok(url)
     }
-}
\ No newline at end of file
+}
diff --git a/codex-rs/core/tests/test_url_ext.rs b/codex-rs/core/tests/test_url_ext.rs
index a78c5d9f624..0de86cb021b 100644
--- a/codex-rs/core/tests/test_url_ext.rs
+++ b/codex-rs/core/tests/test_url_ext.rs
@@ -1,5 +1,5 @@
-use url::Url;
 use codex_core::util::UrlExt;
+use url::Url;
 
 #[cfg(test)]
 mod tests {
@@ -16,7 +16,10 @@ mod tests {
     fn test_append_path_with_query_params() {
         let base_url = Url::parse("https://api.example.com/v1?version=2023").unwrap();
         let result = base_url.append_path("/models").unwrap();
-        assert_eq!(result.as_str(), "https://api.example.com/v1/models?version=2023");
+        assert_eq!(
+            result.as_str(),
+            "https://api.example.com/v1/models?version=2023"
+        );
     }
 
     #[test]
@@ -44,7 +47,10 @@ mod tests {
     fn test_append_path_with_complex_query() {
         let base_url = Url::parse("https://api.example.com/v1?version=2023&api-key=123").unwrap();
         let result = base_url.append_path("/models/gpt-4").unwrap();
"https://api.example.com/v1/models/gpt-4?version=2023&api-key=123"); + assert_eq!( + result.as_str(), + "https://api.example.com/v1/models/gpt-4?version=2023&api-key=123" + ); } #[test] diff --git a/codex-rs/exec/src/event_processor.rs b/codex-rs/exec/src/event_processor.rs index 4c8278cc597..674bfc94b67 100644 --- a/codex-rs/exec/src/event_processor.rs +++ b/codex-rs/exec/src/event_processor.rs @@ -117,8 +117,17 @@ impl EventProcessor { let msg = format!("Task started: {id}"); ts_println!("{}", msg.style(self.dimmed)); } - EventMsg::TaskComplete => { - let msg = format!("Task complete: {id}"); + EventMsg::TaskComplete { token_usage } => { + let mut msg = format!("Task complete: {id}"); + if let Some(usage) = token_usage { + msg.push_str(&format!( + " ({}input, {}output tokens)", + usage.input_tokens, usage.output_tokens + )); + if let Some(cost) = usage.total_cost { + msg.push_str(&format!(" [${:.4}]", cost)); + } + } ts_println!("{}", msg.style(self.bold)); } EventMsg::AgentMessage(AgentMessageEvent { message }) => { diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index 348bff08e61..f362ab67abe 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -137,7 +137,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> { let initial_images_event_id = codex.submit(Op::UserInput { items }).await?; info!("Sent images with event ID: {initial_images_event_id}"); while let Ok(event) = codex.next_event().await { - if event.id == initial_images_event_id && matches!(event.msg, EventMsg::TaskComplete) { + if event.id == initial_images_event_id && matches!(event.msg, EventMsg::TaskComplete { .. }) { break; } } @@ -152,7 +152,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> { let mut event_processor = EventProcessor::create_with_ansi(stdout_with_ansi); while let Some(event) = rx.recv().await { let last_event = - event.id == initial_prompt_task_id && matches!(event.msg, EventMsg::TaskComplete); + event.id == initial_prompt_task_id && matches!(event.msg, EventMsg::TaskComplete { .. }); event_processor.process_event(event); if last_event { break; From f25713cc5da760dac4849145a6c4679962cec37f Mon Sep 17 00:00:00 2001 From: Arnaud Stiegler Date: Fri, 23 May 2025 10:43:52 -0400 Subject: [PATCH 9/9] fixes --- codex-rs/core/src/client.rs | 6 +- codex-rs/core/src/exec.rs | 1 + codex-rs/core/src/protocol.rs | 85 ------------------- codex-rs/core/src/util.rs | 22 ++--- codex-rs/core/tests/protocol_serialization.rs | 82 ++++++++++++++++++ codex-rs/core/tests/test_url_ext.rs | 12 +-- 6 files changed, 101 insertions(+), 107 deletions(-) create mode 100644 codex-rs/core/tests/protocol_serialization.rs diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 6e5b9c70533..0f3e1414a64 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -449,11 +449,13 @@ async fn process_sse( let usage_input_tokens = usage .get("input_tokens") .and_then(|v| v.as_u64()) - .unwrap_or(0) as u32; + .map(|v| v as u32) + .unwrap_or(0); let usage_output_tokens = usage .get("output_tokens") .and_then(|v| v.as_u64()) - .unwrap_or(0) as u32; + .map(|v| v as u32) + .unwrap_or(0); token_aggregator .lock() diff --git a/codex-rs/core/src/exec.rs b/codex-rs/core/src/exec.rs index d9aeea5be3e..8e9420a029b 100644 --- a/codex-rs/core/src/exec.rs +++ b/codex-rs/core/src/exec.rs @@ -25,6 +25,7 @@ use crate::protocol::SandboxPolicy; use once_cell::sync::Lazy; + /// Each value is parsed as an unsigned integer. 
 /// Each value is parsed as an unsigned integer. If parsing fails, or the
 /// environment variable is unset, we fall back to the hard-coded default.
 pub(crate) static MAX_STREAM_OUTPUT: Lazy = Lazy::new(|| {
diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs
index 894e654e0ef..89aec6cd32a 100644
--- a/codex-rs/core/src/protocol.rs
+++ b/codex-rs/core/src/protocol.rs
@@ -550,88 +550,3 @@ pub struct Chunk {
     pub inserted_lines: Vec<String>,
 }
 
-#[cfg(test)]
-mod tests {
-    #![allow(clippy::unwrap_used)]
-    use super::*;
-
-    /// Serialize Event to verify that its JSON representation has the expected
-    /// amount of nesting.
-    #[test]
-    fn serialize_event() {
-        let session_id: Uuid = uuid::uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8");
-        let event = Event {
-            id: "1234".to_string(),
-            msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
-                session_id,
-                model: "o4-mini".to_string(),
-                history_log_id: 0,
-                history_entry_count: 0,
-            }),
-        };
-        let serialized = serde_json::to_string(&event).unwrap();
-        assert_eq!(
-            serialized,
-            r#"{"id":"1234","msg":{"type":"session_configured","session_id":"67e55044-10b1-426f-9247-bb680e5fe0c8","model":"o4-mini","history_log_id":0,"history_entry_count":0}}"#
-        );
-    }
-
-    /// Serialize TaskComplete event with token usage to verify JSON format
-    #[test]
-    fn serialize_task_complete_with_usage() {
-        let event = Event {
-            id: "5678".to_string(),
-            msg: EventMsg::TaskComplete {
-                token_usage: Some(TokenUsage {
-                    input_tokens: 1000,
-                    output_tokens: 500,
-                    total_cost: Some(0.0125),
-                }),
-            },
-        };
-        let serialized = serde_json::to_string(&event).unwrap();
-        println!("JSON with cost: {}", serialized);
-        assert_eq!(
-            serialized,
-            r#"{"id":"5678","msg":{"type":"task_complete","input_tokens":1000,"output_tokens":500,"total_cost":0.0125}}"#
-        );
-    }
-
-    /// Serialize TaskComplete event without cost to verify JSON format
-    #[test]
-    fn serialize_task_complete_no_cost() {
-        let event = Event {
-            id: "9999".to_string(),
-            msg: EventMsg::TaskComplete {
-                token_usage: Some(TokenUsage {
-                    input_tokens: 1500,
-                    output_tokens: 750,
-                    total_cost: None,
-                }),
-            },
-        };
-        let serialized = serde_json::to_string(&event).unwrap();
-        println!("JSON without cost: {}", serialized);
-        assert_eq!(
-            serialized,
-            r#"{"id":"9999","msg":{"type":"task_complete","input_tokens":1500,"output_tokens":750}}"#
-        );
-    }
-
-    /// Serialize TaskComplete event with no token usage
-    #[test]
-    fn serialize_task_complete_no_usage() {
-        let event = Event {
-            id: "0000".to_string(),
-            msg: EventMsg::TaskComplete {
-                token_usage: None,
-            },
-        };
-        let serialized = serde_json::to_string(&event).unwrap();
-        println!("JSON no usage: {}", serialized);
-        assert_eq!(
-            serialized,
-            r#"{"id":"0000","msg":{"type":"task_complete"}}"#
-        );
-    }
-}
diff --git a/codex-rs/core/src/util.rs b/codex-rs/core/src/util.rs
index 2c252e5f78f..7a86e4519bc 100644
--- a/codex-rs/core/src/util.rs
+++ b/codex-rs/core/src/util.rs
@@ -71,7 +71,7 @@ pub trait UrlExt {
     /// Append a path to the URL, without modifying the original URL components.
     /// It allows us to configure query parameters and carry them over when we use
     /// different Wire API endpoints.
-    ///
+    ///
     /// This is necessary as some APIs (e.g. Azure OpenAI) require query parameters
     /// to select different versions.
     fn append_path(self, path: &str) -> Result<Url>;
@@ -80,16 +80,17 @@ impl UrlExt for Url {
     fn append_path(self, path: &str) -> Result<Url> {
         let mut url = self.clone();
-
+
         // Validate path doesn't contain invalid characters
         if path.contains(|c: char| c.is_whitespace() || c == '?'
|| c == '#') { - return Err(anyhow::anyhow!( - "Invalid path: contains whitespace or special characters" - )); + return Err(anyhow::anyhow!("Invalid path: contains whitespace or special characters")); } // Split the path into segments, filtering out empty ones - let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect(); + let segments: Vec<&str> = path + .split('/') + .filter(|s| !s.is_empty()) + .collect(); if segments.is_empty() { return Ok(url); @@ -97,13 +98,12 @@ impl UrlExt for Url { // Get path segments and add new segments { - let mut path_segments = url - .path_segments_mut() + let mut path_segments = url.path_segments_mut() .map_err(|_| anyhow::anyhow!("Failed to get path segments"))?; - + // Remove trailing empty segment if it exists path_segments.pop_if_empty(); - + // Add each non-empty segment for segment in segments { path_segments.push(segment); @@ -112,4 +112,4 @@ impl UrlExt for Url { Ok(url) } -} +} \ No newline at end of file diff --git a/codex-rs/core/tests/protocol_serialization.rs b/codex-rs/core/tests/protocol_serialization.rs new file mode 100644 index 00000000000..71061d3114a --- /dev/null +++ b/codex-rs/core/tests/protocol_serialization.rs @@ -0,0 +1,82 @@ +use codex_core::protocol::*; +use uuid::Uuid; + +/// Serialize Event to verify that its JSON representation has the expected +/// amount of nesting. +#[test] +fn serialize_event() { + let session_id: Uuid = uuid::uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"); + let event = Event { + id: "1234".to_string(), + msg: EventMsg::SessionConfigured(SessionConfiguredEvent { + session_id, + model: "o4-mini".to_string(), + history_log_id: 0, + history_entry_count: 0, + }), + }; + let serialized = serde_json::to_string(&event).unwrap(); + assert_eq!( + serialized, + r#"{"id":"1234","msg":{"type":"session_configured","session_id":"67e55044-10b1-426f-9247-bb680e5fe0c8","model":"o4-mini","history_log_id":0,"history_entry_count":0}}"# + ); +} + +/// Serialize TaskComplete event with token usage to verify JSON format +#[test] +fn serialize_task_complete_with_usage() { + let event = Event { + id: "5678".to_string(), + msg: EventMsg::TaskComplete { + token_usage: Some(TokenUsage { + input_tokens: 1000, + output_tokens: 500, + total_cost: Some(0.0125), + }), + }, + }; + let serialized = serde_json::to_string(&event).unwrap(); + println!("JSON with cost: {}", serialized); + assert_eq!( + serialized, + r#"{"id":"5678","msg":{"type":"task_complete","input_tokens":1000,"output_tokens":500,"total_cost":0.0125}}"# + ); +} + +/// Serialize TaskComplete event without cost to verify JSON format +#[test] +fn serialize_task_complete_no_cost() { + let event = Event { + id: "9999".to_string(), + msg: EventMsg::TaskComplete { + token_usage: Some(TokenUsage { + input_tokens: 1500, + output_tokens: 750, + total_cost: None, + }), + }, + }; + let serialized = serde_json::to_string(&event).unwrap(); + println!("JSON without cost: {}", serialized); + assert_eq!( + serialized, + r#"{"id":"9999","msg":{"type":"task_complete","input_tokens":1500,"output_tokens":750}}"# + ); +} + +/// Serialize TaskComplete event with no token usage +#[test] +fn serialize_task_complete_no_usage() { + let event = Event { + id: "0000".to_string(), + msg: EventMsg::TaskComplete { + token_usage: None, + }, + }; + let serialized = serde_json::to_string(&event).unwrap(); + println!("JSON no usage: {}", serialized); + assert_eq!( + serialized, + r#"{"id":"0000","msg":{"type":"task_complete"}}"# + ); +} \ No newline at end of file diff --git 
a/codex-rs/core/tests/test_url_ext.rs b/codex-rs/core/tests/test_url_ext.rs index 0de86cb021b..a78c5d9f624 100644 --- a/codex-rs/core/tests/test_url_ext.rs +++ b/codex-rs/core/tests/test_url_ext.rs @@ -1,5 +1,5 @@ -use codex_core::util::UrlExt; use url::Url; +use codex_core::util::UrlExt; #[cfg(test)] mod tests { @@ -16,10 +16,7 @@ mod tests { fn test_append_path_with_query_params() { let base_url = Url::parse("https://api.example.com/v1?version=2023").unwrap(); let result = base_url.append_path("/models").unwrap(); - assert_eq!( - result.as_str(), - "https://api.example.com/v1/models?version=2023" - ); + assert_eq!(result.as_str(), "https://api.example.com/v1/models?version=2023"); } #[test] @@ -47,10 +44,7 @@ mod tests { fn test_append_path_with_complex_query() { let base_url = Url::parse("https://api.example.com/v1?version=2023&api-key=123").unwrap(); let result = base_url.append_path("/models/gpt-4").unwrap(); - assert_eq!( - result.as_str(), - "https://api.example.com/v1/models/gpt-4?version=2023&api-key=123" - ); + assert_eq!(result.as_str(), "https://api.example.com/v1/models/gpt-4?version=2023&api-key=123"); } #[test]