Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit caf374e

Browse files
amitksingh1490, forge-code-agent, autofix-ci[bot]
authored
fix: correct token double-counting for Anthropic and Bedrock providers (tailcallhq#2861)
Co-authored-by: ForgeCode <[email protected]> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent 09fbef3 commit caf374e

6 files changed

Lines changed: 210 additions & 19 deletions

File tree

crates/forge_app/src/dto/anthropic/response.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -657,12 +657,14 @@ mod tests {
657657
assert_eq!(delta_domain.completion_tokens, TokenCount::Actual(75));
658658
assert_eq!(delta_domain.cached_tokens, TokenCount::Actual(0));
659659

660-
// Accumulate usage (simulating how we'd combine them in practice)
661-
let accumulated = initial_domain.accumulate(&delta_domain);
662-
assert_eq!(accumulated.prompt_tokens, TokenCount::Actual(150));
663-
assert_eq!(accumulated.completion_tokens, TokenCount::Actual(75));
664-
assert_eq!(accumulated.cached_tokens, TokenCount::Actual(50));
665-
assert_eq!(accumulated.total_tokens, TokenCount::Actual(225));
660+
// Merge usage (simulating how we'd combine them in practice)
661+
// Using merge (max) instead of accumulate (sum) since Anthropic
662+
// usage values are cumulative, not incremental deltas.
663+
let merged = initial_domain.merge(&delta_domain);
664+
assert_eq!(merged.prompt_tokens, TokenCount::Actual(150));
665+
assert_eq!(merged.completion_tokens, TokenCount::Actual(75));
666+
assert_eq!(merged.cached_tokens, TokenCount::Actual(50));
667+
assert_eq!(merged.total_tokens, TokenCount::Actual(150)); // max(150, 75)
666668
}
667669

668670
#[test]

crates/forge_domain/src/context.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,21 @@ impl Default for TokenCount {
774774
}
775775
}
776776

777+
impl TokenCount {
778+
/// Returns the larger of two TokenCount values by their inner count.
779+
/// If both are `Actual`, the result is `Actual`. If either is `Approx`,
780+
/// the result is `Approx`.
781+
pub fn max(self, other: TokenCount) -> TokenCount {
782+
use TokenCount::*;
783+
match (self, other) {
784+
(Actual(a), Actual(b)) => Actual(a.max(b)),
785+
(Actual(a), Approx(b)) => Approx(a.max(b)),
786+
(Approx(a), Actual(b)) => Approx(a.max(b)),
787+
(Approx(a), Approx(b)) => Approx(a.max(b)),
788+
}
789+
}
790+
}
791+
777792
impl Deref for TokenCount {
778793
type Target = usize;
779794

crates/forge_domain/src/message.rs

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ pub struct Usage {
3131
}
3232

3333
impl Usage {
34-
/// Accumulates usage from another Usage instance
35-
/// Cost is summed, tokens are added using TokenCount's Add implementation
34+
/// Accumulates usage from another Usage instance by summing all fields.
35+
///
36+
/// Use this for aggregating usage across **independent** requests (e.g.,
37+
/// session-level totals where each message has its own final usage).
3638
pub fn accumulate(mut self, other: &Usage) -> Self {
3739
self.prompt_tokens = self.prompt_tokens + other.prompt_tokens;
3840
self.completion_tokens = self.completion_tokens + other.completion_tokens;
@@ -46,6 +48,34 @@ impl Usage {
4648
};
4749
self
4850
}
51+
52+
/// Merges usage from another Usage instance using a "last non-zero wins"
53+
/// strategy.
54+
///
55+
/// Use this when combining **partial** usage events within a single
56+
/// streaming response where values are **cumulative** (not incremental):
57+
/// - `message_start`: `input_tokens=1000, output_tokens=1`
58+
/// - `message_delta`: `input_tokens=0, output_tokens=75` (cumulative
59+
/// total)
60+
///
61+
/// For each field, the larger of the two values is kept. This prevents
62+
/// double-counting when providers report cumulative token counts across
63+
/// multiple events.
64+
///
65+
/// Cost is summed since cost events are always additive.
66+
pub fn merge(mut self, other: &Usage) -> Self {
67+
self.prompt_tokens = self.prompt_tokens.max(other.prompt_tokens);
68+
self.completion_tokens = self.completion_tokens.max(other.completion_tokens);
69+
self.total_tokens = self.total_tokens.max(other.total_tokens);
70+
self.cached_tokens = self.cached_tokens.max(other.cached_tokens);
71+
self.cost = match (self.cost, other.cost) {
72+
(Some(a), Some(b)) => Some(a + b),
73+
(Some(a), None) => Some(a),
74+
(None, Some(b)) => Some(b),
75+
(None, None) => None,
76+
};
77+
self
78+
}
4979
}
5080

5181
/// Represents a message that was received from the LLM provider
@@ -374,4 +404,68 @@ mod tests {
374404
FinishReason::Stop
375405
);
376406
}
407+
408+
#[test]
fn test_usage_merge_anthropic_cumulative() {
    // Builds a cost-free usage record from
    // (prompt, completion, total, cached) token counts.
    let usage = |prompt: usize, completion: usize, total: usize, cached: usize| Usage {
        prompt_tokens: TokenCount::Actual(prompt),
        completion_tokens: TokenCount::Actual(completion),
        total_tokens: TokenCount::Actual(total),
        cached_tokens: TokenCount::Actual(cached),
        cost: None,
    };

    // Fixture: Anthropic's message_start + message_delta pattern, where
    // output_tokens in message_delta is CUMULATIVE (a total), not a delta.
    let fixture_message_start = usage(1000, 1, 1001, 300);
    let fixture_message_delta = usage(0, 75, 75, 0);

    let actual = fixture_message_start.merge(&fixture_message_delta);

    // Every field is the per-field maximum:
    //   prompt     = max(1000, 0)  = 1000
    //   completion = max(1, 75)    = 75   (NOT 1 + 75 = 76)
    //   total      = max(1001, 75) = 1001
    //   cached     = max(300, 0)   = 300
    let expected = usage(1000, 75, 1001, 300);

    assert_eq!(actual, expected);
}
440+
441+
#[test]
fn test_usage_merge_preserves_costs() {
    // Builds a usage record with no cached tokens from
    // (prompt, completion, total) token counts and a cost.
    let usage = |prompt: usize, completion: usize, total: usize, cost| Usage {
        prompt_tokens: TokenCount::Actual(prompt),
        completion_tokens: TokenCount::Actual(completion),
        total_tokens: TokenCount::Actual(total),
        cached_tokens: TokenCount::Actual(0),
        cost: Some(cost),
    };

    // Fixture: two partial usage events that each carry their own cost.
    let fixture_usage_1 = usage(100, 0, 100, 0.01);
    let fixture_usage_2 = usage(0, 50, 50, 0.02);

    let actual = fixture_usage_1.merge(&fixture_usage_2);

    // Token counts take the per-field maximum; costs are summed, not maxed.
    let expected = usage(100, 50, 100, 0.03);

    assert_eq!(actual, expected);
}
377471
}

crates/forge_domain/src/result_stream_ext.rs

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,17 @@ impl ResultStreamExt<anyhow::Error> for crate::BoxStream<ChatCompletionMessage,
6969
anyhow::Ok(message?).with_context(|| "Failed to process message stream")?;
7070
// Process usage information
7171
// - For Anthropic-style streaming: input tokens in MessageStart, output tokens
72-
// in MessageDelta
72+
// in MessageDelta (values are CUMULATIVE, not incremental)
73+
// ref: https://platform.claude.com/docs/en/build-with-claude/streaming#event-types
7374
// - For OpenAI-style streaming: all tokens in the final chunk
7475
// - For GLM-style: may send complete usage in every chunk (need to replace, not
7576
// accumulate)
77+
// - For Google-style: cumulative usage in every chunk
7678
// - Cost-only events: have 0 tokens but a cost value
7779
if let Some(current_usage) = message.usage.as_ref() {
7880
// If current usage has both prompt and completion tokens, it's a "complete"
79-
// usage In this case, replace instead of accumulate (handles
80-
// GLM-style streaming)
81+
// usage. In this case, replace instead of merge (handles GLM-style streaming
82+
// where every chunk has full usage).
8183
let is_complete_usage =
8284
*current_usage.prompt_tokens > 0 && *current_usage.completion_tokens > 0;
8385

@@ -95,10 +97,19 @@ impl ResultStreamExt<anyhow::Error> for crate::BoxStream<ChatCompletionMessage,
9597
}
9698
} else if is_cost_only {
9799
// Accumulate only the cost to the existing usage
98-
usage.cost = current_usage.cost;
100+
usage.cost = match (usage.cost, current_usage.cost) {
101+
(Some(a), Some(b)) => Some(a + b),
102+
(Some(a), None) => Some(a),
103+
(None, Some(b)) => Some(b),
104+
(None, None) => None,
105+
};
99106
} else {
100-
// Accumulate partial usage (for Anthropic-style streaming)
101-
usage = usage.accumulate(current_usage);
107+
// Merge partial usage using "max" strategy. This correctly handles
108+
// providers like Anthropic where usage values are CUMULATIVE across
109+
// events (message_start has input tokens, message_delta has the
110+
// total output tokens). Using max instead of sum prevents
111+
// double-counting when message_start includes output_tokens=1.
112+
usage = usage.merge(current_usage);
102113
}
103114
}
104115

@@ -485,8 +496,73 @@ mod tests {
485496
}
486497

487498
#[tokio::test]
488-
async fn test_into_full_anthropic_streaming_usage_accumulation() {
499+
async fn test_into_full_anthropic_streaming_usage_merge() {
500+
// Fixture: Simulate Anthropic streaming pattern where message_start has
501+
// output_tokens=1 (the common case) and message_delta has the cumulative total.
502+
// This tests that merge (max) is used instead of accumulate (sum) to prevent
503+
// double-counting.
504+
let messages = vec![
505+
// MessageStart with input token usage AND output_tokens=1
506+
Ok(ChatCompletionMessage::default().usage(Usage {
507+
prompt_tokens: TokenCount::Actual(1000),
508+
completion_tokens: TokenCount::Actual(1),
509+
total_tokens: TokenCount::Actual(1001),
510+
cached_tokens: TokenCount::Actual(300),
511+
cost: None,
512+
})),
513+
// Content deltas
514+
Ok(ChatCompletionMessage::default().content(Content::part("Hello "))),
515+
Ok(ChatCompletionMessage::default().content(Content::part("world!"))),
516+
// MessageDelta with cumulative output token usage
517+
Ok(ChatCompletionMessage::default()
518+
.usage(Usage {
519+
prompt_tokens: TokenCount::Actual(0),
520+
completion_tokens: TokenCount::Actual(50),
521+
total_tokens: TokenCount::Actual(50),
522+
cached_tokens: TokenCount::Actual(0),
523+
cost: None,
524+
})
525+
.finish_reason(FinishReason::Stop)),
526+
];
527+
528+
let result_stream: BoxStream<ChatCompletionMessage, anyhow::Error> =
529+
Box::pin(tokio_stream::iter(messages));
530+
531+
// Actual: Convert stream to full message
532+
let actual = result_stream.into_full(false).await.unwrap();
533+
534+
// Expected: Usage should use max (merge) not sum (accumulate).
535+
// message_start has completion_tokens=1 and prompt_tokens=1000, so
536+
// is_complete_usage=true -> replace: usage = {1000, 1, 1001, 300}
537+
// message_delta has prompt=0, completion=50 -> is_complete_usage=false ->
538+
// merge: prompt = max(1000, 0) = 1000
539+
// completion = max(1, 50) = 50 (NOT 1+50=51)
540+
// total = max(1001, 50) = 1001
541+
// cached = max(300, 0) = 300
542+
let expected = ChatCompletionMessageFull {
543+
content: "Hello world!".to_string(),
544+
tool_calls: vec![],
545+
thought_signature: None,
546+
usage: Usage {
547+
prompt_tokens: TokenCount::Actual(1000),
548+
completion_tokens: TokenCount::Actual(50), // max(1, 50) = 50, NOT 1+50=51
549+
total_tokens: TokenCount::Actual(1001),
550+
cached_tokens: TokenCount::Actual(300),
551+
cost: None,
552+
},
553+
reasoning: None,
554+
reasoning_details: None,
555+
finish_reason: Some(FinishReason::Stop),
556+
phase: None,
557+
};
558+
559+
assert_eq!(actual, expected);
560+
}
561+
562+
#[tokio::test]
563+
async fn test_into_full_anthropic_streaming_usage_merge_zero_output() {
489564
// Fixture: Simulate Anthropic/Vertex AI Anthropic streaming pattern
565+
// where message_start has output_tokens=0 (Vertex AI pattern).
490566
// MessageStart event has input tokens, MessageDelta has output tokens
491567
let messages = vec![
492568
// MessageStart with input token usage
@@ -518,15 +594,15 @@ mod tests {
518594
// Actual: Convert stream to full message
519595
let actual = result_stream.into_full(false).await.unwrap();
520596

521-
// Expected: Usage should be accumulated from both MessageStart and MessageDelta
597+
// Expected: Usage should be merged from both MessageStart and MessageDelta
522598
let expected = ChatCompletionMessageFull {
523599
content: "Hello world!".to_string(),
524600
tool_calls: vec![],
525601
thought_signature: None,
526602
usage: Usage {
527603
prompt_tokens: TokenCount::Actual(1000), // From MessageStart
528604
completion_tokens: TokenCount::Actual(50), // From MessageDelta
529-
total_tokens: TokenCount::Actual(1050), // Sum of both
605+
total_tokens: TokenCount::Actual(1000), // max(1000, 50) = 1000
530606
cached_tokens: TokenCount::Actual(300), // From MessageStart
531607
cost: None,
532608
},

crates/forge_repo/src/provider/bedrock.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ impl IntoDomain for aws_sdk_bedrockruntime::types::ConverseStreamOutput {
407407
.saturating_add(u.cache_write_input_tokens.unwrap_or(0));
408408

409409
forge_domain::Usage {
410-
prompt_tokens: forge_domain::TokenCount::Actual(u.total_tokens as usize),
410+
prompt_tokens: forge_domain::TokenCount::Actual(u.input_tokens as usize),
411411
completion_tokens: forge_domain::TokenCount::Actual(
412412
u.output_tokens as usize,
413413
),
@@ -1418,7 +1418,7 @@ mod tests {
14181418
let actual = fixture.into_domain();
14191419
let expected =
14201420
ChatCompletionMessage::assistant(Content::part("")).usage(forge_domain::Usage {
1421-
prompt_tokens: TokenCount::Actual(1000),
1421+
prompt_tokens: TokenCount::Actual(800),
14221422
completion_tokens: TokenCount::Actual(200),
14231423
total_tokens: TokenCount::Actual(1000),
14241424
cached_tokens: TokenCount::Actual(80), // 50 + 30

crates/forge_repo/src/provider/openai_responses/response.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ pub(super) enum StreamItem {
8383
Message(Box<ChatCompletionMessage>),
8484
}
8585

86+
/// Converts OpenAI Responses API usage into the domain Usage type.
87+
/// Usage is sent once in the `response.completed` event (not split across
88+
/// events).
89+
/// ref: https://developers.openai.com/api/reference/resources/responses#(resource)%20responses%20%3E%20(model)%20response_usage%20%3E%20(schema)
8690
impl IntoDomain for oai::ResponseUsage {
8791
type Domain = Usage;
8892

0 commit comments

Comments
 (0)