From 4948d0f5e277ac970fea380a736641aecb271c88 Mon Sep 17 00:00:00 2001 From: Aaron Hill Date: Thu, 29 Jan 2026 15:20:04 -0500 Subject: [PATCH 1/5] Switch to sse-stream crate --- Cargo.lock | 41 ++++++++++++++--------- Cargo.toml | 4 +-- gateway/Cargo.toml | 1 - internal/autopilot-client/Cargo.toml | 3 +- internal/autopilot-client/src/client.rs | 8 ++--- internal/reqwest-sse-stream/Cargo.toml | 17 ++++++++++ internal/reqwest-sse-stream/src/lib.rs | 39 ++++++++++++++++++++++ provider-proxy/Cargo.toml | 3 +- provider-proxy/tests/e2e/tests.rs | 43 +++++++++++-------------- tensorzero-core/Cargo.toml | 3 +- tensorzero-core/src/http.rs | 17 ++++++---- 11 files changed, 122 insertions(+), 57 deletions(-) create mode 100644 internal/reqwest-sse-stream/Cargo.toml create mode 100644 internal/reqwest-sse-stream/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 53e8054cef..2a1171d870 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -341,13 +341,14 @@ dependencies = [ "futures", "moka", "reqwest 0.12.28", - "reqwest-eventsource", + "reqwest-sse-stream", "schemars 1.2.0", "secrecy", "serde", "serde_json", "serde_path_to_error", "sqlx", + "sse-stream", "tensorzero-derive", "tensorzero-types", "thiserror 2.0.18", @@ -2443,7 +2444,6 @@ dependencies = [ "metrics-exporter-prometheus", "mimalloc", "reqwest 0.12.28", - "reqwest-eventsource", "secrecy", "serde", "serde_json", @@ -4517,11 +4517,12 @@ dependencies = [ "rand 0.9.2", "rcgen", "reqwest 0.12.28", - "reqwest-eventsource", + "reqwest-sse-stream", "rustls", "serde", "serde_json", "sha2", + "sse-stream", "tempfile", "tokio", "tokio-rustls", @@ -5239,19 +5240,15 @@ dependencies = [ ] [[package]] -name = "reqwest-eventsource" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +name = "reqwest-sse-stream" +version = "2026.1.6" dependencies = [ - "eventsource-stream", - "futures-core", - "futures-timer", - "mime 0.3.17", - "nom 7.1.3", - "pin-project-lite", + "futures", + "http 1.4.0", + "http-body 1.0.1", "reqwest 0.12.28", - "thiserror 1.0.69", + "sse-stream", + "thiserror 2.0.18", ] [[package]] @@ -6120,6 +6117,19 @@ dependencies = [ "uuid", ] +[[package]] +name = "sse-stream" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" +dependencies = [ + "bytes", + "futures-util", + "http-body 1.0.1", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -6426,7 +6436,7 @@ dependencies = [ "redis-test", "regex", "reqwest 0.12.28", - "reqwest-eventsource", + "reqwest-sse-stream", "schemars 1.2.0", "secrecy", "serde", @@ -6434,6 +6444,7 @@ dependencies = [ "serde_path_to_error", "sha2", "sqlx", + "sse-stream", "strum", "tempfile", "tensorzero", diff --git a/Cargo.toml b/Cargo.toml index 7ab870ce3b..fa70bcd245 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ members = [ "internal/durable-tools", "internal/tensorzero-types", "internal/tensorzero-types-providers", - "internal/autopilot-worker", + "internal/autopilot-worker", "internal/reqwest-sse-stream", ] resolver = "2" @@ -44,7 +44,6 @@ uuid = { version = "1.20.0", features = ["serde", "v7"] } serde_json = { version = "1.0.143", features = ["preserve_order"] } secrecy = { version = "0.10.2", features = ["serde"] } toml = { version = "0.9.11", features = ["preserve_order"] } -reqwest-eventsource = "0.6.0" async-stream = "0.3.5" 
async-trait = "0.1.89" http = "1.4.0" @@ -60,6 +59,7 @@ clap = { version = "4.5.55", features = ["derive"] } futures = "0.3.30" thiserror = "2.0.18" lazy_static = { version = "1.5.0" } +sse-stream = "0.2.1" minijinja = { version = "2.15.1", features = [ "loader", "debug", diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index 85ccf40f8c..223fee4f86 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -44,7 +44,6 @@ reqwest.workspace = true tempfile = "3.21.0" tensorzero = { path = "../clients/rust", features = ["e2e_tests"] } serde_json = { workspace = true } -reqwest-eventsource = { workspace = true } futures = { workspace = true } secrecy = { workspace = true } uuid = { workspace = true } diff --git a/internal/autopilot-client/Cargo.toml b/internal/autopilot-client/Cargo.toml index 8471ea7af6..371e878941 100644 --- a/internal/autopilot-client/Cargo.toml +++ b/internal/autopilot-client/Cargo.toml @@ -11,7 +11,8 @@ durable-tools-spawn.workspace = true futures.workspace = true moka.workspace = true reqwest = { workspace = true, features = ["stream"] } -reqwest-eventsource.workspace = true +sse-stream = { workspace = true } +reqwest-sse-stream = { path = "../reqwest-sse-stream" } secrecy.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/internal/autopilot-client/src/client.rs b/internal/autopilot-client/src/client.rs index f1200882f6..d7a0446d6e 100644 --- a/internal/autopilot-client/src/client.rs +++ b/internal/autopilot-client/src/client.rs @@ -9,9 +9,9 @@ use durable_tools_spawn::{SpawnClient, SpawnOptions}; use futures::stream::{Stream, StreamExt}; use moka::sync::Cache; use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue}; -use reqwest_eventsource::{Event as SseEvent, EventSource}; use secrecy::{ExposeSecret, SecretString}; use sqlx::PgPool; +use sse_stream::{Sse, SseStream}; use url::Url; use uuid::Uuid; @@ -674,13 +674,13 @@ impl AutopilotClient { let request = self.sse_http_client.get(url).headers(self.auth_headers()); - let mut event_source = - EventSource::new(request).map_err(|e| AutopilotError::Sse(e.to_string()))?; + let event_source = + SseStream::from_byte_stream(request.send().await?.error_for_status()?.bytes_stream()); // Wait for connection to be established or fail. // The first event should be Open on success, or an error on failure. 
match event_source.next().await {
-            Some(Ok(SseEvent::Open)) => {
+            Some(Ok(Sse::Open)) => {
                 // Connection established successfully
             }
             Some(Err(e)) => {

diff --git a/internal/reqwest-sse-stream/Cargo.toml b/internal/reqwest-sse-stream/Cargo.toml
new file mode 100644
index 0000000000..5f68d40e53
--- /dev/null
+++ b/internal/reqwest-sse-stream/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "reqwest-sse-stream"
+version.workspace = true
+rust-version.workspace = true
+license.workspace = true
+edition.workspace = true
+
+[dependencies]
+sse-stream = { workspace = true }
+reqwest = { workspace = true, features = ["stream"] }
+futures = { workspace = true }
+http = { workspace = true }
+thiserror = { workspace = true }
+http-body = "1.0.1"
+
+[lints]
+workspace = true

diff --git a/internal/reqwest-sse-stream/src/lib.rs b/internal/reqwest-sse-stream/src/lib.rs
new file mode 100644
index 0000000000..252f193b7a
--- /dev/null
+++ b/internal/reqwest-sse-stream/src/lib.rs
@@ -0,0 +1,39 @@
+use http_body::Body;
+use sse_stream::SseStream;
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum ReqwestSseStreamError {
+    #[error("Reqwest error: {0}")]
+    ReqwestError(reqwest::Error),
+    #[error("Expected content-type 'text/event-stream', got {0:?}")]
+    InvalidContentType(Option<http::HeaderValue>),
+}
+
+pub async fn into_sse_stream(
+    builder: reqwest::RequestBuilder,
+) -> Result<SseStream<impl Body<Error = reqwest::Error>>, ReqwestSseStreamError> {
+    let response = builder
+        .header(reqwest::header::ACCEPT, "text/event-stream")
+        .send()
+        .await
+        .map_err(ReqwestSseStreamError::ReqwestError)?
+        .error_for_status()
+        .map_err(ReqwestSseStreamError::ReqwestError)?;
+
+    if let Some(content_type) = response.headers().get(reqwest::header::CONTENT_TYPE) {
+        if content_type
+            .to_str()
+            .ok()
+            .is_none_or(|s| !s.eq_ignore_ascii_case("text/event-stream"))
+        {
+            return Err(ReqwestSseStreamError::InvalidContentType(Some(
+                content_type.clone(),
+            )));
+        }
+    } else {
+        return Err(ReqwestSseStreamError::InvalidContentType(None));
+    }
+
+    Ok(SseStream::from_byte_stream(response.bytes_stream()))
+}

diff --git a/provider-proxy/Cargo.toml b/provider-proxy/Cargo.toml
index c14920cd1b..9f843a22a0 100644
--- a/provider-proxy/Cargo.toml
+++ b/provider-proxy/Cargo.toml
@@ -31,6 +31,8 @@ tokio.workspace = true
 tokio-rustls = "0.26.1"
 tracing = "0.1.43"
 tracing-subscriber.workspace = true
+reqwest-sse-stream = { path = "../internal/reqwest-sse-stream" }
+sse-stream = { workspace = true }
 
 tempfile = "3.21.0"
 
@@ -41,5 +43,4 @@ workspace = true
 axum.workspace = true
 rand = "0.9.1"
 async-stream.workspace = true
-reqwest-eventsource.workspace = true
 futures-util = "0.3.31"

diff --git a/provider-proxy/tests/e2e/tests.rs b/provider-proxy/tests/e2e/tests.rs
index 0814d3cdc2..dec257bbe9 100644
--- a/provider-proxy/tests/e2e/tests.rs
+++ b/provider-proxy/tests/e2e/tests.rs
@@ -15,7 +15,6 @@ use axum::{
 use futures_util::StreamExt;
 use provider_proxy::{Args, CacheMode, run_server};
 use rand::Rng;
-use reqwest_eventsource::RequestBuilderExt;
 use serde_json::Value;
 use tokio::{sync::oneshot, task::JoinHandle};

@@ -432,14 +431,14 @@ async fn test_dropped_stream_body() {
         .build()
         .unwrap();
 
-    let mut good_stream = good_client
-        .post(format!("http://{target_server_addr}/slow"))
-        .eventsource()
-        .unwrap();
+    let mut good_stream = reqwest_sse_stream::into_sse_stream(
+        good_client.post(format!("http://{target_server_addr}/slow")),
+    )
+    .await
+    .unwrap();
 
     // Read the entire stream, so that we're sure that provider-proxy will write the file to disk
     while let Some(event) =
good_stream.next().await { match event { - Err(reqwest_eventsource::Error::StreamEnded) => break, Err(e) => panic!("Unexpected error: {e:?}"), Ok(_) => continue, } @@ -474,23 +473,18 @@ async fn test_dropped_stream_body() { .build() .unwrap(); - let mut first_stream = client - .post(format!("http://{target_server_addr}/slow")) - .eventsource() - .unwrap(); + let mut first_stream = reqwest_sse_stream::into_sse_stream( + client.post(format!("http://{target_server_addr}/slow")), + ) + .await + .unwrap(); let first_event = first_stream.next().await.unwrap().unwrap(); - assert_eq!(first_event, reqwest_eventsource::Event::Open); - - let second_event = first_stream.next().await.unwrap().unwrap(); - let reqwest_eventsource::Event::Message(second_event) = second_event else { - panic!("Unexpected event: {second_event:?}"); - }; - assert_eq!(second_event.data, "Hello"); + assert_eq!(first_event.data, Some("Hello".to_string())); // We should get a timeout let err = first_stream.next().await.unwrap().unwrap_err(); assert!( - matches!(&err, reqwest_eventsource::Error::Transport(e) if e.is_timeout()), + matches!(&err, sse_stream::Error::Body(e) if format!("{e:?}").contains("TimedOut")), "Unexpected error: {err:?}" ); @@ -540,16 +534,15 @@ async fn test_stream_body() { .build() .unwrap(); - let mut second_stream = client - .post(format!("http://{target_server_addr}/slow")) - .eventsource() - .unwrap(); + let mut second_stream = reqwest_sse_stream::into_sse_stream( + client.post(format!("http://{target_server_addr}/slow")), + ) + .await + .unwrap(); while let Some(event) = second_stream.next().await { let event = event.unwrap(); - if let reqwest_eventsource::Event::Message(event) = event - && event.data == "[DONE]" - { + if event.data.as_deref() == Some("done") { break; } } diff --git a/tensorzero-core/Cargo.toml b/tensorzero-core/Cargo.toml index 1508471ff3..060c380c7d 100644 --- a/tensorzero-core/Cargo.toml +++ b/tensorzero-core/Cargo.toml @@ -36,6 +36,7 @@ aws-types = "1.3.6" aws-credential-types = { version = "1.2.2", features = [ "hardcoded-credentials", ] } +sse-stream = { workspace = true } aws-sigv4 = "1.3" aws-smithy-eventstream = "0.60" aws-smithy-runtime-api = "1.11.0" @@ -60,11 +61,11 @@ num-bigint = "0.4" object_store = { workspace = true } rand = { workspace = true } reqwest = { workspace = true } -reqwest-eventsource = { workspace = true } secrecy = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } http = { workspace = true } +reqwest-sse-stream = { path = "../internal/reqwest-sse-stream" } serde_path_to_error = { workspace = true } sha2 = "0.10.9" strum = { version = "0.27.1", features = ["derive"] } diff --git a/tensorzero-core/src/http.rs b/tensorzero-core/src/http.rs index 95ab6494d1..0da9ef62fb 100644 --- a/tensorzero-core/src/http.rs +++ b/tensorzero-core/src/http.rs @@ -13,17 +13,14 @@ use std::{ use tracing::Span; use tracing_futures::Instrument; -use eventsource_stream::Eventsource; use futures::{Stream, StreamExt}; use http::{HeaderMap, HeaderName, HeaderValue}; use pin_project::pin_project; use reqwest::header::{ACCEPT, CONTENT_TYPE}; use reqwest::{Body, Response, StatusCode}; use reqwest::{Client, IntoUrl, NoProxy, Proxy, RequestBuilder}; -use reqwest_eventsource::{ - CannotCloneRequestError, Error as ReqwestEventSourceError, Event, RequestBuilderExt, -}; use serde::{Serialize, de::DeserializeOwned}; +use sse_stream::SseStream; use crate::endpoints::status::TENSORZERO_VERSION; use crate::error::IMPOSSIBLE_ERROR_MESSAGE; @@ -543,11 +540,17 @@ 
impl<'a> TensorzeroRequestBuilder<'a> {
         self
     }
 
-    pub fn eventsource(mut self) -> Result<TensorZeroEventSource, CannotCloneRequestError> {
+    pub async fn eventsource(mut self) -> Result<TensorZeroEventSource, reqwest::Error> {
         self = self.with_otlp_headers();
-        let event_source = self.builder.eventsource()?;
+        let response = self
+            .builder
+            .send()
+            .instrument(tensorzero_h2_workaround_span())
+            .await?
+            .error_for_status()?;
+
         Ok(TensorZeroEventSource {
-            stream: Box::pin(event_source.map(|r| r.map_err(Box::new))),
+            stream: SseStream::new(response.body()).pin(),
             ticket: self.ticket.into_owned(),
             span: tensorzero_h2_workaround_span(),
             tensorzero_external_span: tracing::debug_span!(

From 6652b891bd24939582acf627ba02e727976e18be Mon Sep 17 00:00:00 2001
From: Aaron Hill
Date: Thu, 29 Jan 2026 15:27:04 -0500
Subject: [PATCH 2/5] More work

---
 internal/autopilot-client/src/client.rs | 64 ++++++++++---------------
 1 file changed, 26 insertions(+), 38 deletions(-)

diff --git a/internal/autopilot-client/src/client.rs b/internal/autopilot-client/src/client.rs
index d7a0446d6e..6eea9fd2c1 100644
--- a/internal/autopilot-client/src/client.rs
+++ b/internal/autopilot-client/src/client.rs
@@ -672,32 +672,12 @@ impl AutopilotClient {
                 .append_pair("last_event_id", &last_event_id.to_string());
         }
 
-        let request = self.sse_http_client.get(url).headers(self.auth_headers());
-
-        let event_source =
-            SseStream::from_byte_stream(request.send().await?.error_for_status()?.bytes_stream());
-        // Wait for connection to be established or fail.
-        // The first event should be Open on success, or an error on failure.
-        match event_source.next().await {
-            Some(Ok(Sse::Open)) => {
-                // Connection established successfully
-            }
-            Some(Err(e)) => {
-                // Convert SSE error to appropriate AutopilotError
-                return Err(Self::convert_sse_error(e));
-            }
-            Some(Ok(SseEvent::Message(_))) => {
-                return Err(AutopilotError::Sse(
-                    "Received message before connection was established".to_string(),
-                ));
-            }
-            None => {
-                return Err(AutopilotError::Sse(
-                    "Connection closed unexpectedly".to_string(),
-                ));
-            }
-        }
+        let event_source = reqwest_sse_stream::into_sse_stream(
+            self.sse_http_client.get(url).headers(self.auth_headers()),
+        )
+        .await
+        .map_err(|e| Self::convert_sse_error(e))?;
 
         // Connection is good, return the stream
         let cache = self.tool_call_cache.clone();
@@ -710,10 +690,16 @@ impl AutopilotClient {
             let spawn_client = spawn_client.clone();
             async move {
                 match result {
-                    Ok(SseEvent::Open) => None,
-                    Ok(SseEvent::Message(message)) => {
-                        if message.event == "event" {
-                            match serde_json::from_str::(&message.data) {
+                    Ok(sse) => {
+                        if sse.event.as_deref() == Some("event") {
+                            let data = sse.data.as_ref().ok_or_else(|| {
+                                AutopilotError::Sse(format!("Missing data for event: {sse:?}"))
+                            });
+                            let data = match data {
+                                Ok(data) => data,
+                                Err(e) => return Some(Err(AutopilotError::Sse(e.to_string()))),
+                            };
+                            match serde_json::from_str::(data) {
                                 Ok(update) => {
                                     // Cache tool calls for later lookup
                                     if let EventPayload::ToolCall(tool_call) = &update.event.payload
@@ -758,16 +744,18 @@ impl AutopilotClient {
 
     /// Converts an SSE error to the appropriate AutopilotError.
     /// HTTP errors are converted to AutopilotError::Http for consistency.
- fn convert_sse_error(e: reqwest_eventsource::Error) -> AutopilotError { - use reqwest_eventsource::Error as SseError; + fn convert_sse_error(e: reqwest_sse_stream::ReqwestSseStreamError) -> AutopilotError { match e { - SseError::InvalidStatusCode(status, _response) => AutopilotError::Http { - status_code: status.as_u16(), - message: status - .canonical_reason() - .unwrap_or("Unknown error") - .to_string(), - }, + reqwest_sse_stream::ReqwestSseStreamError::ReqwestError(e) if e.is_status() => { + AutopilotError::Http { + status_code: e.status().map(|s| s.as_u16()).unwrap_or(0), + message: e + .status() + .and_then(|s| s.canonical_reason()) + .unwrap_or("Unknown error") + .to_string(), + } + } other => AutopilotError::Sse(other.to_string()), } } From 4537f4ca0d3ebc7198dac79c58233fe8fc0e1313 Mon Sep 17 00:00:00 2001 From: Aaron Hill Date: Thu, 29 Jan 2026 16:45:34 -0500 Subject: [PATCH 3/5] Work on adjusting usage --- gateway/tests/logging.rs | 45 +- gateway/tests/prometheus.rs | 13 +- gateway/tests/relay/raw_response.rs | 160 +++--- gateway/tests/relay/raw_usage.rs | 160 +++--- internal/autopilot-client/src/client.rs | 1 - tensorzero-core/src/client/mod.rs | 32 +- tensorzero-core/src/http.rs | 64 ++- tensorzero-core/src/inference/mod.rs | 7 +- tensorzero-core/src/providers/anthropic.rs | 13 +- .../src/providers/aws_sagemaker.rs | 36 +- tensorzero-core/src/providers/deepseek.rs | 13 +- .../src/providers/fireworks/mod.rs | 15 +- .../src/providers/gcp_vertex_anthropic.rs | 11 +- .../src/providers/gcp_vertex_gemini/mod.rs | 13 +- .../src/providers/google_ai_studio_gemini.rs | 13 +- tensorzero-core/src/providers/groq.rs | 13 +- tensorzero-core/src/providers/helpers.rs | 20 +- tensorzero-core/src/providers/mistral.rs | 13 +- tensorzero-core/src/providers/openai/mod.rs | 19 +- .../src/providers/openai/responses.rs | 11 +- tensorzero-core/src/providers/openrouter.rs | 13 +- tensorzero-core/src/providers/sglang.rs | 15 +- tensorzero-core/src/providers/tgi.rs | 13 +- tensorzero-core/src/providers/together.rs | 15 +- tensorzero-core/tests/e2e/best_of_n.rs | 13 +- tensorzero-core/tests/e2e/cache.rs | 61 +-- tensorzero-core/tests/e2e/dicl.rs | 30 +- .../e2e/endpoints/internal/evaluations.rs | 241 ++++---- tensorzero-core/tests/e2e/inference/mod.rs | 132 +++-- tensorzero-core/tests/e2e/mixture_of_n.rs | 54 +- .../tests/e2e/openai_compatible.rs | 72 ++- .../tests/e2e/providers/anthropic.rs | 122 ++--- tensorzero-core/tests/e2e/providers/common.rs | 518 +++++++++--------- .../providers/commonv2/cache_input_tokens.rs | 20 +- .../tests/e2e/providers/commonv2/raw_usage.rs | 36 +- .../tests/e2e/providers/commonv2/usage.rs | 36 +- .../tests/e2e/providers/reasoning.rs | 60 +- .../tests/e2e/raw_response/cache.rs | 18 +- tensorzero-core/tests/e2e/raw_response/mod.rs | 162 +++--- .../e2e/raw_response/openai_compatible.rs | 48 +- tensorzero-core/tests/e2e/raw_usage/cache.rs | 18 +- tensorzero-core/tests/e2e/raw_usage/mod.rs | 162 +++--- .../tests/e2e/raw_usage/openai_compatible.rs | 71 ++- tensorzero-core/tests/e2e/retries.rs | 28 +- tensorzero-core/tests/e2e/streaming_errors.rs | 66 +-- tensorzero-core/tests/e2e/timeouts.rs | 58 +- 46 files changed, 1379 insertions(+), 1375 deletions(-) diff --git a/gateway/tests/logging.rs b/gateway/tests/logging.rs index 6ccfe2901f..4a6200b186 100644 --- a/gateway/tests/logging.rs +++ b/gateway/tests/logging.rs @@ -5,7 +5,7 @@ mod common; use common::start_gateway_on_random_port; use futures::StreamExt; use http::StatusCode; -use reqwest_eventsource::{Event, RequestBuilderExt}; 
+use reqwest_sse_stream::into_sse_stream; use std::time::Duration; use tokio::time::error::Elapsed; @@ -65,33 +65,34 @@ async fn test_log_early_drop_streaming(model_name: &str, expect_finish: bool) { let client = reqwest::Client::new(); - let mut stream = client - .post(format!("http://{}/inference", child_data.addr)) - .json(&serde_json::json!({ - "model_name": model_name, - "input": { - "messages": [ - { - "role": "user", - "content": "Hello, world!" - } - ] - }, - "stream": true, - })) - .eventsource() - .unwrap(); + let mut stream = into_sse_stream( + client + .post(format!("http://{}/inference", child_data.addr)) + .json(&serde_json::json!({ + "model_name": model_name, + "input": { + "messages": [ + { + "role": "user", + "content": "Hello, world!" + } + ] + }, + "stream": true, + })), + ) + .await + .unwrap(); println!("Started stream"); // Cancel the request early, and verify that the gateway logs a warning. let _elapsed = tokio::time::timeout(Duration::from_millis(500), async move { while let Some(event) = stream.next().await { - let event = event.unwrap(); - println!("Event: {event:?}"); - if let Event::Message(event) = event - && event.data == "[DONE]" - { + let sse = event.unwrap(); + println!("Event: {sse:?}"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } } diff --git a/gateway/tests/prometheus.rs b/gateway/tests/prometheus.rs index 77f7a9d538..11cdfdc5a4 100644 --- a/gateway/tests/prometheus.rs +++ b/gateway/tests/prometheus.rs @@ -1,11 +1,11 @@ #![expect(clippy::print_stdout, clippy::unwrap_used)] use std::time::{Duration, Instant}; +use futures::StreamExt; use reqwest::Client; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use tensorzero::test_helpers::get_metrics; use tokio::task::JoinSet; -use tokio_stream::StreamExt; use crate::common::start_gateway_on_random_port; @@ -69,12 +69,11 @@ async fn test_prometheus_metrics_inference_helper(stream: bool) { .json(&inference_payload); if stream { - let mut event_source = builder.eventsource().unwrap(); + let mut event_source = into_sse_stream(builder).await.unwrap(); while let Some(event) = event_source.next().await { - let event = event.unwrap(); - if let Event::Message(event) = event - && event.data == "[DONE]" - { + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } } diff --git a/gateway/tests/relay/raw_response.rs b/gateway/tests/relay/raw_response.rs index 31f46d1b73..438be5932f 100644 --- a/gateway/tests/relay/raw_response.rs +++ b/gateway/tests/relay/raw_response.rs @@ -6,7 +6,7 @@ use crate::common::relay::start_relay_test_environment; use futures::StreamExt; use reqwest::Client; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use uuid::Uuid; @@ -130,44 +130,44 @@ async fn test_relay_raw_response_streaming() { let env = start_relay_test_environment(downstream_config, relay_config).await; let client = Client::new(); - let mut stream = client - .post(format!("http://{}/inference", env.relay.addr)) - .json(&json!({ - "model_name": "openai::gpt-5-nano", - "episode_id": Uuid::now_v7(), - "input": { - "messages": [ - { - "role": "user", - "content": "Hello" + let mut stream = into_sse_stream( + client + .post(format!("http://{}/inference", env.relay.addr)) + .json(&json!({ + "model_name": "openai::gpt-5-nano", + "episode_id": Uuid::now_v7(), + "input": { + "messages": [ + { + "role": "user", + "content": "Hello" + 
}
+            ]
+        },
+        "stream": true,
+        "include_raw_response": true,
+        "params": {
+            "chat_completion": {
+                "reasoning_effort": "minimal"
             }
-        }
-    }))
-    .eventsource()
-    .unwrap();
+        })),
+    )
+    .await
+    .unwrap();
 
     let mut found_raw_chunk = false;
     let mut content_chunks_count: usize = 0;
     let mut chunks_with_raw_chunk: usize = 0;
 
     while let Some(event) = stream.next().await {
-        let event = event.unwrap();
-        let Event::Message(message) = event else {
-            continue;
-        };
-        if message.data == "[DONE]" {
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
             break;
         }
-        let chunk: Value = serde_json::from_str(&message.data).unwrap();
+        let chunk: Value = serde_json::from_str(&data).unwrap();
 
         // Count content chunks (chunks with content delta)
         if chunk.get("content").is_some() {
@@ -254,40 +254,40 @@ async fn test_relay_raw_response_not_requested_streaming() {
     let env = start_relay_test_environment(downstream_config, relay_config).await;
 
     let client = Client::new();
-    let mut stream = client
-        .post(format!("http://{}/inference", env.relay.addr))
-        .json(&json!({
-            "model_name": "openai::gpt-5-nano",
-            "episode_id": Uuid::now_v7(),
-            "input": {
-                "messages": [
-                    {
-                        "role": "user",
-                        "content": "Hello"
+    let mut stream = into_sse_stream(
+        client
+            .post(format!("http://{}/inference", env.relay.addr))
+            .json(&json!({
+                "model_name": "openai::gpt-5-nano",
+                "episode_id": Uuid::now_v7(),
+                "input": {
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": "Hello"
+                        }
+                    ]
+                },
+                "stream": true,
+                "include_raw_response": false,
+                "params": {
+                    "chat_completion": {
+                        "reasoning_effort": "minimal"
                     }
-                ]
-            },
-            "stream": true,
-            "include_raw_response": false,
-            "params": {
-                "chat_completion": {
-                    "reasoning_effort": "minimal"
                 }
-            }
-        }))
-        .eventsource()
-        .unwrap();
+            })),
+    )
+    .await
+    .unwrap();
 
     while let Some(event) = stream.next().await {
-        let event = event.unwrap();
-        let Event::Message(message) = event else {
-            continue;
-        };
-        if message.data == "[DONE]" {
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
             break;
         }
-        let chunk: Value = serde_json::from_str(&message.data).unwrap();
+        let chunk: Value = serde_json::from_str(&data).unwrap();
 
         // raw_response and raw_chunk should NOT be present when not requested
         assert!(
@@ -440,39 +440,39 @@ reasoning_effort = "minimal"
     let env = start_relay_test_environment(downstream_config, relay_config).await;
 
     let client = Client::new();
-    let mut stream = client
-        .post(format!("http://{}/inference", env.relay.addr))
-        .json(&json!({
-            "function_name": "best_of_n_test",
-            "variant_name": "best_of_n",
-            "episode_id": Uuid::now_v7(),
-            "input": {
-                "messages": [
-                    {
-                        "role": "user",
-                        "content": "Hello"
-                    }
-                ]
-            },
-            "stream": true,
-            "include_raw_response": true
-        }))
-        .eventsource()
-        .unwrap();
+    let mut stream = into_sse_stream(
+        client
+            .post(format!("http://{}/inference", env.relay.addr))
+            .json(&json!({
+                "function_name": "best_of_n_test",
+                "variant_name": "best_of_n",
+                "episode_id": Uuid::now_v7(),
+                "input": {
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": "Hello"
+                        }
+                    ]
+                },
+                "stream": true,
+                "include_raw_response": true
+            })),
+    )
+    .await
+    .unwrap();
 
     let mut raw_response_entries: Vec<Value> = Vec::new();
     let mut found_raw_chunk = false;
 
     while let Some(event) = stream.next().await {
-        let event = event.unwrap();
-        let
Event::Message(message) = event else { - continue; - }; - if message.data == "[DONE]" { + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk: Value = serde_json::from_str(&message.data).unwrap(); + let chunk: Value = serde_json::from_str(&data).unwrap(); // Check if this chunk has raw_response (previous inferences for best-of-n) if let Some(raw_response) = chunk.get("raw_response") diff --git a/gateway/tests/relay/raw_usage.rs b/gateway/tests/relay/raw_usage.rs index 5d4fb04ec3..5e56efd1f4 100644 --- a/gateway/tests/relay/raw_usage.rs +++ b/gateway/tests/relay/raw_usage.rs @@ -6,7 +6,7 @@ use crate::common::relay::start_relay_test_environment; use futures::StreamExt; use reqwest::Client; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use uuid::Uuid; @@ -133,42 +133,42 @@ async fn test_relay_raw_usage_streaming() { let env = start_relay_test_environment(downstream_config, relay_config).await; let client = Client::new(); - let mut stream = client - .post(format!("http://{}/inference", env.relay.addr)) - .json(&json!({ - "model_name": "openai::gpt-5-nano", - "episode_id": Uuid::now_v7(), - "input": { - "messages": [ - { - "role": "user", - "content": "Hello" + let mut stream = into_sse_stream( + client + .post(format!("http://{}/inference", env.relay.addr)) + .json(&json!({ + "model_name": "openai::gpt-5-nano", + "episode_id": Uuid::now_v7(), + "input": { + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + }, + "stream": true, + "include_raw_usage": true, + "params": { + "chat_completion": { + "reasoning_effort": "minimal" } - ] - }, - "stream": true, - "include_raw_usage": true, - "params": { - "chat_completion": { - "reasoning_effort": "minimal" } - } - })) - .eventsource() - .unwrap(); + })), + ) + .await + .unwrap(); let mut found_raw_usage = false; while let Some(event) = stream.next().await { - let event = event.unwrap(); - let Event::Message(message) = event else { - continue; - }; - if message.data == "[DONE]" { + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk: Value = serde_json::from_str(&message.data).unwrap(); + let chunk: Value = serde_json::from_str(&data).unwrap(); // Check if this chunk has raw_usage (sibling to usage at chunk level) if let Some(raw_usage) = chunk.get("raw_usage") { @@ -275,40 +275,40 @@ async fn test_relay_raw_usage_not_requested_streaming() { let env = start_relay_test_environment(downstream_config, relay_config).await; let client = Client::new(); - let mut stream = client - .post(format!("http://{}/inference", env.relay.addr)) - .json(&json!({ - "model_name": "openai::gpt-5-nano", - "episode_id": Uuid::now_v7(), - "input": { - "messages": [ - { - "role": "user", - "content": "Hello" + let mut stream = into_sse_stream( + client + .post(format!("http://{}/inference", env.relay.addr)) + .json(&json!({ + "model_name": "openai::gpt-5-nano", + "episode_id": Uuid::now_v7(), + "input": { + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + }, + "stream": true, + "include_raw_usage": false, + "params": { + "chat_completion": { + "reasoning_effort": "minimal" } - ] - }, - "stream": true, - "include_raw_usage": false, - "params": { - "chat_completion": { - "reasoning_effort": "minimal" } - } - })) - .eventsource() - .unwrap(); + })), + ) + .await + .unwrap(); while let Some(event) = stream.next().await { - let event = 
event.unwrap();
-        let Event::Message(message) = event else {
-            continue;
-        };
-        if message.data == "[DONE]" {
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
             break;
         }
-        let chunk: Value = serde_json::from_str(&message.data).unwrap();
+        let chunk: Value = serde_json::from_str(&data).unwrap();
 
         // raw_usage should NOT be present at chunk level when not requested
         assert!(
@@ -470,38 +470,38 @@ reasoning_effort = "minimal"
     let env = start_relay_test_environment(downstream_config, relay_config).await;
 
     let client = Client::new();
-    let mut stream = client
-        .post(format!("http://{}/inference", env.relay.addr))
-        .json(&json!({
-            "function_name": "best_of_n_test",
-            "variant_name": "best_of_n",
-            "episode_id": Uuid::now_v7(),
-            "input": {
-                "messages": [
-                    {
-                        "role": "user",
-                        "content": "Hello"
-                    }
-                ]
-            },
-            "stream": true,
-            "include_raw_usage": true
-        }))
-        .eventsource()
-        .unwrap();
+    let mut stream = into_sse_stream(
+        client
+            .post(format!("http://{}/inference", env.relay.addr))
+            .json(&json!({
+                "function_name": "best_of_n_test",
+                "variant_name": "best_of_n",
+                "episode_id": Uuid::now_v7(),
+                "input": {
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": "Hello"
+                        }
+                    ]
+                },
+                "stream": true,
+                "include_raw_usage": true
+            })),
+    )
+    .await
+    .unwrap();
 
     let mut raw_usage_entries: Vec<Value> = Vec::new();
 
     while let Some(event) = stream.next().await {
-        let event = event.unwrap();
-        let Event::Message(message) = event else {
-            continue;
-        };
-        if message.data == "[DONE]" {
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
             break;
         }
-        let chunk: Value = serde_json::from_str(&message.data).unwrap();
+        let chunk: Value = serde_json::from_str(&data).unwrap();
 
         // Check if this chunk has raw_usage (sibling to usage at chunk level)
         if let Some(raw_usage) = chunk.get("raw_usage")

diff --git a/internal/autopilot-client/src/client.rs b/internal/autopilot-client/src/client.rs
index 6eea9fd2c1..9ef73a80b4 100644
--- a/internal/autopilot-client/src/client.rs
+++ b/internal/autopilot-client/src/client.rs
@@ -11,7 +11,6 @@ use moka::sync::Cache;
 use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue};
 use secrecy::{ExposeSecret, SecretString};
 use sqlx::PgPool;
-use sse_stream::{Sse, SseStream};
 use url::Url;
 use uuid::Uuid;

diff --git a/tensorzero-core/src/client/mod.rs b/tensorzero-core/src/client/mod.rs
index a28c52cbfd..3ace7f02d2 100644
--- a/tensorzero-core/src/client/mod.rs
+++ b/tensorzero-core/src/client/mod.rs
@@ -10,6 +10,7 @@ use crate::endpoints::openai_compatible::types::embeddings::OpenAIEmbeddingRespo
 use crate::feature_flags;
 use crate::http::TensorzeroResponseWrapper;
 use crate::http::{DEFAULT_HTTP_CLIENT_TIMEOUT, TensorzeroHttpClient, TensorzeroRequestBuilder};
+use crate::http::{Event, ReqwestEventSourceError};
 use crate::inference::types::stored_input::StoragePathResolver;
 use crate::observability::{
     TENSORZERO_OTLP_ATTRIBUTE_PREFIX, TENSORZERO_OTLP_HEADERS_PREFIX,
@@ -25,7 +26,6 @@ use crate::{
     utils::gateway::{GatewayHandle, setup_clickhouse, setup_postgres, setup_valkey},
 };
 use reqwest::header::HeaderMap;
-use reqwest_eventsource::Event;
 use secrecy::{ExposeSecret, SecretString};
 use std::fmt::Debug;
 use tokio::time::error::Elapsed;
@@ -206,15 +206,16 @@ impl HTTPGateway {
         &self,
         builder: TensorzeroRequestBuilder<'_>,
     ) -> Result {
-        let event_source =
-            self.customize_builder(builder)
-                .eventsource()
-                .map_err(|e| TensorZeroError::Other {
-                    source:
Error::new(ErrorDetails::JsonRequest { - message: format!("Error constructing event stream: {e:?}"), - }) - .into(), - })?; + let event_source = self + .customize_builder(builder) + .eventsource() + .await + .map_err(|e| TensorZeroError::Other { + source: Error::new(ErrorDetails::JsonRequest { + message: format!("Error constructing event stream: {e:?}"), + }) + .into(), + })?; let mut event_source = event_source.peekable(); let first = event_source.peek().await; @@ -229,7 +230,7 @@ impl HTTPGateway { let inner_err = Error::new(ErrorDetails::StreamError { source: Box::new(Error::new(ErrorDetails::Serialization { message: err_str })), }); - if let reqwest_eventsource::Error::InvalidStatusCode(code, resp) = *e { + if let ReqwestEventSourceError::InvalidStatusCode(code, resp) = *e { return Err(TensorZeroError::Http { status_code: code.as_u16(), text: resp.text().await.ok(), @@ -248,7 +249,7 @@ impl HTTPGateway { while let Some(ev) = event_source.next().await { match ev { Err(e) => { - if matches!(*e, reqwest_eventsource::Error::StreamEnded) { + if matches!(*e, ReqwestEventSourceError::StreamEnded) { break; } yield Err(Error::new(ErrorDetails::StreamError { @@ -263,10 +264,13 @@ impl HTTPGateway { Ok(e) => match e { Event::Open => continue, Event::Message(message) => { - if message.data == "[DONE]" { + let Some(data) = message.data else { + continue; + }; + if data == "[DONE]" { break; } - let json: serde_json::Value = serde_json::from_str(&message.data).map_err(|e| { + let json: serde_json::Value = serde_json::from_str(&data).map_err(|e| { Error::new(ErrorDetails::Serialization { message: format!("Error deserializing inference response chunk: {}", DisplayOrDebug { val: e, diff --git a/tensorzero-core/src/http.rs b/tensorzero-core/src/http.rs index 0da9ef62fb..bd283be7f8 100644 --- a/tensorzero-core/src/http.rs +++ b/tensorzero-core/src/http.rs @@ -20,9 +20,45 @@ use reqwest::header::{ACCEPT, CONTENT_TYPE}; use reqwest::{Body, Response, StatusCode}; use reqwest::{Client, IntoUrl, NoProxy, Proxy, RequestBuilder}; use serde::{Serialize, de::DeserializeOwned}; +pub use sse_stream::Sse; use sse_stream::SseStream; use crate::endpoints::status::TENSORZERO_VERSION; + +/// An SSE event, compatible with the API of `reqwest_eventsource::Event`. +/// This allows us to use the same code paths that were written for `reqwest_eventsource` +/// while using the `sse-stream` crate instead. +#[derive(Debug, Clone)] +pub enum Event { + /// The event source has been opened. + Open, + /// A message was received. + Message(Sse), +} + +/// An error type for SSE event streams, compatible with `reqwest_eventsource::Error`. +/// This provides the same error variants that the rest of the codebase expects. 
+#[derive(Debug, thiserror::Error)]
+pub enum ReqwestEventSourceError {
+    #[error("Invalid status code: {0}")]
+    InvalidStatusCode(StatusCode, Response),
+    #[error("Invalid content type: {0:?}")]
+    InvalidContentType(HeaderValue, Response),
+    #[error("Transport error: {0}")]
+    Transport(#[source] reqwest::Error),
+    #[error("Stream ended")]
+    StreamEnded,
+    #[error("UTF-8 error: {0}")]
+    Utf8(#[source] std::str::Utf8Error),
+    #[error("Parser error: {0}")]
+    Parser(#[source] sse_stream::Error),
+}
+
+impl From<sse_stream::Error> for ReqwestEventSourceError {
+    fn from(e: sse_stream::Error) -> Self {
+        ReqwestEventSourceError::Parser(e)
+    }
+}
 use crate::error::IMPOSSIBLE_ERROR_MESSAGE;
 use crate::observability::overhead_timing::TENSORZERO_EXTERNAL_SPAN_ATTRIBUTE_NAME;
 use crate::{
@@ -311,7 +347,7 @@ pub struct TensorZeroEventSource {
 }
 
 impl Stream for TensorZeroEventSource {
-    type Item = Result<Event, Box<reqwest_eventsource::Error>>;
+    type Item = Result<Event, Box<ReqwestEventSourceError>>;
 
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
         let this = self.project();
@@ -549,9 +585,18 @@ impl<'a> TensorzeroRequestBuilder<'a> {
             .await?
             .error_for_status()?;
 
+        let ticket = self.ticket.into_owned();
+        // Wrap SSE items in Event::Message and emit an initial Event::Open
+        let stream = SseStream::from_byte_stream(response.bytes_stream()).map(|event| {
+            event
+                .map(Event::Message)
+                .map_err(|e| Box::new(ReqwestEventSourceError::from(e)))
+        });
+        let stream = futures::stream::once(async { Ok(Event::Open) }).chain(stream);
+
         Ok(TensorZeroEventSource {
-            stream: SseStream::new(response.body()).pin(),
-            ticket: self.ticket.into_owned(),
+            stream: Box::pin(stream),
+            ticket,
             span: tensorzero_h2_workaround_span(),
             tensorzero_external_span: tracing::debug_span!(
@@ -578,7 +623,7 @@
         let headers = response.headers().clone();
         let response =
             validate_event_stream_response(response).map_err(|e| (e, Some(headers.clone())))?;
-        let stream = response.bytes_stream().eventsource().map(|event| {
+        let stream = SseStream::from_byte_stream(response.bytes_stream()).map(|event| {
             event
                 .map(Event::Message)
                 .map_err(|e| Box::new(ReqwestEventSourceError::from(e)))
         });
@@ -794,7 +839,9 @@ mod tests {
     use reqwest::Proxy;
     use tokio::task::{JoinHandle, JoinSet};
 
-    use crate::http::{CONCURRENCY_LIMIT, LimitedClient, TensorZeroEventSource};
+    use crate::http::{
+        CONCURRENCY_LIMIT, LimitedClient, ReqwestEventSourceError, TensorZeroEventSource,
+    };
 
     async fn start_target_server() -> (SocketAddr, JoinHandle<Result<(), std::io::Error>>) {
         let app = Router::new()
@@ -827,7 +874,7 @@
             match event {
                 Ok(_) => {}
                 Err(e) => {
-                    if matches!(*e, reqwest_eventsource::Error::StreamEnded) {
+                    if matches!(*e, ReqwestEventSourceError::StreamEnded) {
                         break;
                     }
                     panic!("Error in streaming response: {e:?}");
@@ -883,6 +930,7 @@
         let mut event_source = client
             .get(format!("http://{addr}/hello-stream"))
             .eventsource()
+            .await
             .unwrap();
         process_stream(&mut event_source).await;
         drop(event_source);
@@ -922,6 +970,7 @@
             let mut stream = client
                 .get(format!("http://{addr}/hello-stream"))
                 .eventsource()
+                .await
                 .unwrap();
             process_stream(&mut stream).await;
         });
@@ -974,6 +1023,7 @@
         let mut stream = client
             .get(format!("http://{addr}/hello-stream"))
             .eventsource()
+            .await
            .unwrap();
         process_stream(&mut stream).await;
         drop(stream);
@@ -1021,6 +1071,7 @@
             let mut stream = client
                 .get(format!("http://{addr}/hello-stream"))
                 .eventsource()
+                .await
                 .unwrap();
             process_stream(&mut stream).await;
         });
@@ -1038,6 +1089,7 @@
         let mut stream = client
            .get(format!("http://{addr}/hello-stream"))
             .eventsource()
+            .await
             .unwrap();
         process_stream(&mut stream).await;

diff --git a/tensorzero-core/src/inference/mod.rs b/tensorzero-core/src/inference/mod.rs
index 2ea63217cf..f99814347b 100644
--- a/tensorzero-core/src/inference/mod.rs
+++ b/tensorzero-core/src/inference/mod.rs
@@ -4,6 +4,7 @@ use crate::cache::ModelProviderRequest;
 use crate::endpoints::inference::InferenceCredentials;
 use crate::error::Error;
 use crate::http::TensorzeroHttpClient;
+use crate::http::{Event, ReqwestEventSourceError};
 use crate::inference::types::Latency;
 use crate::inference::types::ModelInferenceRequest;
 use crate::inference::types::PeekableProviderInferenceResponseStream;
@@ -16,19 +17,19 @@ use crate::model::ModelProvider;
 use async_trait::async_trait;
 use futures::Future;
 use futures::Stream;
-use reqwest_eventsource::Event;
 use std::borrow::Cow;
 use std::fmt::Debug;
 use std::pin::Pin;
 use tokio::time::Instant;
 use uuid::Uuid;
 
-/// A helper type for preserving custom errors when working with `reqwest_eventsource`
+/// A helper type for preserving custom errors when working with SSE event streams.
 /// This is currently used by `stream_openai` to allow using it with a provider
 /// that needs to do additional validation when streaming (e.g. Sagemaker)
+#[derive(Debug)]
 pub enum TensorZeroEventError {
     TensorZero(Error),
-    EventSource(Box<reqwest_eventsource::Error>),
+    EventSource(Box<ReqwestEventSourceError>),
 }
 
 pub trait InferenceProvider {

diff --git a/tensorzero-core/src/providers/anthropic.rs b/tensorzero-core/src/providers/anthropic.rs
index eaabae604f..45075837cb 100644
--- a/tensorzero-core/src/providers/anthropic.rs
+++ b/tensorzero-core/src/providers/anthropic.rs
@@ -1,9 +1,9 @@
+use crate::http::Event;
 use futures::StreamExt;
 use futures::future::try_join_all;
 use lazy_static::lazy_static;
 use mime::MediaType;
 use reqwest::StatusCode;
-use reqwest_eventsource::Event;
 use secrecy::{ExposeSecret, SecretString};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
@@ -507,15 +507,18 @@ fn stream_anthropic(
             Ok(event) => match event {
                 Event::Open => continue,
                 Event::Message(message) => {
+                    let Some(message_data) = message.data else {
+                        continue;
+                    };
                     let data: Result<AnthropicStreamMessage, Error> =
-                        serde_json::from_str(&message.data).map_err(|e| Error::new(ErrorDetails::InferenceServer {
+                        serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer {
                             message: format!(
                                 "Error parsing message: {}, Data: {}",
-                                e, message.data
+                                e, message_data
                             ),
                             provider_type: PROVIDER_TYPE.to_string(),
                             raw_request: Some(raw_request.to_string()),
-                            raw_response: Some(message.data.clone()),
+                            raw_response: Some(message_data.clone()),
                         }));
                     // Anthropic streaming API docs specify that this is the last message
                     if let Ok(AnthropicStreamMessage::MessageStop) = data {
@@ -524,7 +527,7 @@
 
                     let response = data.and_then(|data| {
                         anthropic_to_tensorzero_stream_message(
-                            message.data,
+                            message_data,
                             data,
                             start_time.elapsed(),
                             &mut tool_state,

diff --git a/tensorzero-core/src/providers/aws_sagemaker.rs b/tensorzero-core/src/providers/aws_sagemaker.rs
index c3e3de8723..25dcd924ea 100644
--- a/tensorzero-core/src/providers/aws_sagemaker.rs
+++ b/tensorzero-core/src/providers/aws_sagemaker.rs
@@ -1,11 +1,12 @@
 //! AWS SageMaker model provider using direct HTTP calls.
+use crate::http::{Event, ReqwestEventSourceError}; use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder}; use aws_types::region::Region; use bytes::BytesMut; -use eventsource_stream::Eventsource; use futures::StreamExt; use serde::Serialize; +use sse_stream::SseStream; use std::time::Instant; use super::aws_common::{ @@ -24,7 +25,6 @@ use crate::inference::types::{ }; use crate::inference::{InferenceProvider, TensorZeroEventError, WrappedProvider}; use crate::model::ModelProvider; -use eventsource_stream::EventStreamError; #[expect(unused)] const PROVIDER_NAME: &str = "AWS Sagemaker"; @@ -343,20 +343,26 @@ impl InferenceProvider for AWSSagemakerProvider { } }; - // Second, convert the byte stream to SSE events using eventsource_stream + // Second, convert the byte stream to SSE events using sse_stream // The payload bytes contain SSE text from the hosted model (OpenAI/TGI) - let event_stream = futures::stream::iter([Ok(reqwest_eventsource::Event::Open)]).chain( - sagemaker_byte_stream.eventsource().map(|r| match r { - Ok(msg) => Ok(reqwest_eventsource::Event::Message(msg)), - Err(e) => match e { - EventStreamError::Utf8(err) => Err(TensorZeroEventError::EventSource( - Box::new(reqwest_eventsource::Error::Utf8(err)), - )), - EventStreamError::Parser(err) => Err(TensorZeroEventError::EventSource( - Box::new(reqwest_eventsource::Error::Parser(err)), - )), - EventStreamError::Transport(err) => Err(err), - }, + // First, split the stream to handle errors separately from the Ok bytes + let bytes_only_stream = sagemaker_byte_stream.filter_map(|r| async move { + match r { + Ok(bytes) => Some(Ok::<_, std::io::Error>(bytes::Bytes::from(bytes))), + Err(e) => { + // Log or handle TensorZeroEventError here if needed + // For now we skip them as they're fatal errors that terminate the stream + tracing::error!("SageMaker stream error: {e:?}"); + None + } + } + }); + let event_stream = futures::stream::iter([Ok(Event::Open)]).chain( + SseStream::from_byte_stream(bytes_only_stream).map(|r| match r { + Ok(sse) => Ok(Event::Message(sse)), + Err(e) => Err(TensorZeroEventError::EventSource(Box::new( + ReqwestEventSourceError::from(e), + ))), }), ); diff --git a/tensorzero-core/src/providers/deepseek.rs b/tensorzero-core/src/providers/deepseek.rs index 5e75ca228a..cb10ec1a3c 100644 --- a/tensorzero-core/src/providers/deepseek.rs +++ b/tensorzero-core/src/providers/deepseek.rs @@ -1,6 +1,6 @@ +use crate::http::Event; use futures::StreamExt; use lazy_static::lazy_static; -use reqwest_eventsource::Event; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize}; use std::borrow::Cow; @@ -495,23 +495,26 @@ fn stream_deepseek( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - if message.data == "[DONE]" { + let Some(message_data) = message.data else { + continue; + }; + if message_data == "[DONE]" { break; } let data: Result = - serde_json::from_str(&message.data).map_err(|e| Error::new(ErrorDetails::InferenceServer { + serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer { message: format!( "Error parsing chunk. 
Error: {e}", ), raw_request: Some(raw_request.clone()), - raw_response: Some(message.data.clone()), + raw_response: Some(message_data.clone()), provider_type: PROVIDER_TYPE.to_string(), })); let latency = start_time.elapsed(); let stream_message = data.and_then(|d| { deepseek_to_tensorzero_chunk( - message.data, + message_data, d, latency, &mut tool_call_ids, diff --git a/tensorzero-core/src/providers/fireworks/mod.rs b/tensorzero-core/src/providers/fireworks/mod.rs index 2c73c829ad..96e5bb6bf2 100644 --- a/tensorzero-core/src/providers/fireworks/mod.rs +++ b/tensorzero-core/src/providers/fireworks/mod.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use crate::http::TensorzeroHttpClient; +use crate::http::{Event, ReqwestEventSourceError}; use crate::inference::types::chat_completion_inference_params::{ ChatCompletionInferenceParamsV2, warn_inference_parameter_not_supported, }; @@ -11,7 +12,6 @@ use crate::{ }; use futures::StreamExt; use lazy_static::lazy_static; -use reqwest_eventsource::Event; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -691,7 +691,7 @@ fn stream_fireworks( Err(e) => { let message = e.to_string(); let mut raw_response = None; - if let reqwest_eventsource::Error::InvalidStatusCode(_, resp) = *e { + if let ReqwestEventSourceError::InvalidStatusCode(_, resp) = *e { raw_response = resp.text().await.ok(); } yield Err(ErrorDetails::InferenceServer { @@ -704,21 +704,24 @@ fn stream_fireworks( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - if message.data == "[DONE]" { + let Some(message_data) = message.data else { + continue; + }; + if message_data == "[DONE]" { break; } let data: Result = - serde_json::from_str(&message.data).map_err(|e| Error::new(ErrorDetails::InferenceServer { + serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer { message: format!("Error parsing chunk. 
Error: {e}"), raw_request: None, - raw_response: Some(message.data.clone()), + raw_response: Some(message_data.clone()), provider_type: PROVIDER_TYPE.to_string(), })); let latency = start_time.elapsed(); let stream_message = data.and_then(|d| { fireworks_to_tensorzero_chunk( - message.data, + message_data, d, latency, &mut tool_call_ids, diff --git a/tensorzero-core/src/providers/gcp_vertex_anthropic.rs b/tensorzero-core/src/providers/gcp_vertex_anthropic.rs index ad1ac4571f..08e499dd32 100644 --- a/tensorzero-core/src/providers/gcp_vertex_anthropic.rs +++ b/tensorzero-core/src/providers/gcp_vertex_anthropic.rs @@ -1,9 +1,9 @@ use std::borrow::Cow; use std::fmt::Display; +use crate::http::Event; use futures::StreamExt; use futures::future::try_join_all; -use reqwest_eventsource::Event; use serde::{Deserialize, Serialize}; use std::fmt::Debug; use tensorzero_derive::TensorZeroDeserialize; @@ -412,11 +412,14 @@ fn stream_anthropic( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { + let Some(message_data) = message.data else { + continue; + }; let data: Result = - serde_json::from_str(&message.data).map_err(|e| Error::new(ErrorDetails::InferenceServer { + serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer { message: format!( "Error parsing message: {}, Data: {}", - e, message.data + e, message_data ), provider_type: PROVIDER_TYPE.to_string(), raw_request: Some(raw_request.clone()), @@ -429,7 +432,7 @@ fn stream_anthropic( let response = data.and_then(|data| { anthropic_to_tensorzero_stream_message( - message.data, + message_data, data, start_time.elapsed(), &mut tool_state, diff --git a/tensorzero-core/src/providers/gcp_vertex_gemini/mod.rs b/tensorzero-core/src/providers/gcp_vertex_gemini/mod.rs index 592d10c8ba..766e282775 100644 --- a/tensorzero-core/src/providers/gcp_vertex_gemini/mod.rs +++ b/tensorzero-core/src/providers/gcp_vertex_gemini/mod.rs @@ -5,6 +5,7 @@ use std::sync::{Arc, OnceLock}; use std::time::Duration; use crate::error::DelayedError; +use crate::http::{Event, ReqwestEventSourceError}; use axum::http; use futures::StreamExt; use futures::future::try_join_all; @@ -15,7 +16,6 @@ use jsonwebtoken::{Algorithm, EncodingKey, Header, encode}; use object_store::gcp::{GcpCredential, GoogleCloudStorageBuilder}; use object_store::{ObjectStore, ObjectStoreExt, StaticCredentialProvider}; use reqwest::StatusCode; -use reqwest_eventsource::Event; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -1587,7 +1587,7 @@ fn stream_gcp_vertex_gemini( while let Some(ev) = event_source.next().await { match ev { Err(e) => { - if matches!(*e, reqwest_eventsource::Error::StreamEnded) { + if matches!(*e, ReqwestEventSourceError::StreamEnded) { break; } yield Err(convert_stream_error(raw_request.clone(), PROVIDER_TYPE.to_string(), *e, None).await); @@ -1595,12 +1595,15 @@ fn stream_gcp_vertex_gemini( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - let data: Result = serde_json::from_str(&message.data).map_err(|e| { + let Some(message_data) = message.data else { + continue; + }; + let data: Result = serde_json::from_str(&message_data).map_err(|e| { Error::new(ErrorDetails::InferenceServer { message: format!("Error parsing streaming JSON response: {}", DisplayOrDebugGateway::new(e)), provider_type: PROVIDER_TYPE.to_string(), raw_request: Some(raw_request.clone()), - raw_response: Some(message.data.clone()), + raw_response: 
Some(message_data.clone()), }) }); let data = match data { @@ -1611,7 +1614,7 @@ fn stream_gcp_vertex_gemini( } }; yield convert_stream_response_with_metadata_to_chunk( - message.data, + message_data, data, start_time.elapsed(), &mut last_tool_name, diff --git a/tensorzero-core/src/providers/google_ai_studio_gemini.rs b/tensorzero-core/src/providers/google_ai_studio_gemini.rs index 22f2c2d93c..3264d99852 100644 --- a/tensorzero-core/src/providers/google_ai_studio_gemini.rs +++ b/tensorzero-core/src/providers/google_ai_studio_gemini.rs @@ -1,9 +1,9 @@ use std::borrow::Cow; use std::time::Duration; +use crate::http::{Event, ReqwestEventSourceError}; use futures::{StreamExt, future::try_join_all}; use reqwest::StatusCode; -use reqwest_eventsource::Event; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -369,7 +369,7 @@ fn stream_google_ai_studio_gemini( while let Some(ev) = event_source.next().await { match ev { Err(e) => { - if matches!(*e, reqwest_eventsource::Error::StreamEnded) { + if matches!(*e, ReqwestEventSourceError::StreamEnded) { break; } yield Err(convert_stream_error(raw_request.clone(), PROVIDER_TYPE.to_string(), *e, None).await); @@ -377,12 +377,15 @@ fn stream_google_ai_studio_gemini( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - let data: Result = serde_json::from_str(&message.data).map_err(|e| { + let Some(message_data) = message.data else { + continue; + }; + let data: Result = serde_json::from_str(&message_data).map_err(|e| { Error::new(ErrorDetails::InferenceServer { message: format!("Error parsing streaming JSON response: {}", DisplayOrDebugGateway::new(e)), provider_type: PROVIDER_TYPE.to_string(), raw_request: Some(raw_request.clone()), - raw_response: Some(message.data.clone()), + raw_response: Some(message_data.clone()), }) }); let data = match data { @@ -394,7 +397,7 @@ fn stream_google_ai_studio_gemini( }; yield convert_stream_response_with_metadata_to_chunk( ConvertStreamResponseArgs { - raw_response: message.data, + raw_response: message_data, response: data, latency: start_time.elapsed(), last_tool_name: &mut last_tool_name, diff --git a/tensorzero-core/src/providers/groq.rs b/tensorzero-core/src/providers/groq.rs index 5e9271966e..7b6575218d 100644 --- a/tensorzero-core/src/providers/groq.rs +++ b/tensorzero-core/src/providers/groq.rs @@ -1,7 +1,7 @@ +use crate::http::Event; use futures::future::try_join_all; use futures::{Stream, StreamExt, TryStreamExt}; use reqwest::StatusCode; -use reqwest_eventsource::Event; use secrecy::{ExposeSecret, SecretString}; use serde::de::IntoDeserializer; use serde::{Deserialize, Deserializer, Serialize}; @@ -353,23 +353,26 @@ pub fn stream_groq( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - if message.data == "[DONE]" { + let Some(message_data) = message.data else { + continue; + }; + if message_data == "[DONE]" { break; } let data: Result = - serde_json::from_str(&message.data).map_err(|e| Error::new(ErrorDetails::InferenceServer { + serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer { message: format!( "Error parsing chunk. 
Error: {e}", ), raw_request: Some(raw_request.clone()), - raw_response: Some(message.data.clone()), + raw_response: Some(message_data.clone()), provider_type: provider_type.clone(), })); let latency = start_time.elapsed(); let stream_message = data.and_then(|d| { groq_to_tensorzero_chunk( - message.data.clone(), + message_data.clone(), d, latency, &mut tool_call_ids, diff --git a/tensorzero-core/src/providers/helpers.rs b/tensorzero-core/src/providers/helpers.rs index 0c6eb67bff..1c50582c93 100644 --- a/tensorzero-core/src/providers/helpers.rs +++ b/tensorzero-core/src/providers/helpers.rs @@ -8,7 +8,10 @@ use uuid::Uuid; use crate::{ error::{DisplayOrDebugGateway, Error, ErrorDetails, IMPOSSIBLE_ERROR_MESSAGE}, - http::{TensorZeroEventSource, TensorzeroRequestBuilder, TensorzeroResponseWrapper}, + http::{ + ReqwestEventSourceError, TensorZeroEventSource, TensorzeroRequestBuilder, + TensorzeroResponseWrapper, + }, inference::types::{ ProviderInferenceResponseChunk, batch::{ProviderBatchInferenceOutput, ProviderBatchInferenceResponse}, @@ -29,7 +32,7 @@ pub struct JsonlBatchFileInfo { pub async fn convert_stream_error( raw_request: String, provider_type: String, - e: reqwest_eventsource::Error, + e: ReqwestEventSourceError, request_id: Option<&str>, ) -> Error { let base_message = e.to_string(); @@ -39,8 +42,8 @@ pub async fn convert_stream_error( // to avoid holding open a broken stream (which will delay gateway shutdown when we // wait on the parent `Span` to finish) match e { - reqwest_eventsource::Error::InvalidStatusCode(_, resp) - | reqwest_eventsource::Error::InvalidContentType(_, resp) => { + ReqwestEventSourceError::InvalidStatusCode(_, resp) + | ReqwestEventSourceError::InvalidContentType(_, resp) => { let raw_response = resp.text().await.ok(); let message = match (&raw_response, request_id) { (Some(body), Some(id)) => format!("{base_message}: {body} [request_id: {id}]"), @@ -56,7 +59,7 @@ pub async fn convert_stream_error( } .into() } - reqwest_eventsource::Error::Transport(inner) => { + ReqwestEventSourceError::Transport(inner) => { // Timeouts at the reqwest level are from `gateway.global_outbound_http_timeout_ms`. // Variant/model/provider-level timeouts are handled via `tokio::time::timeout` // and produce distinct error types (VariantTimeout, ModelTimeout, ModelProviderTimeout). @@ -345,7 +348,7 @@ pub async fn inject_extra_request_data_and_send_eventsource_with_headers( Err((e, headers)) => { // Extract status code first (by borrowing), then consume Response to read body let (message, raw_response) = match e { - reqwest_eventsource::Error::InvalidStatusCode(status, resp) => { + ReqwestEventSourceError::InvalidStatusCode(status, resp) => { let body = resp.text().await.ok(); let message = match &body { Some(b) => { @@ -355,7 +358,7 @@ pub async fn inject_extra_request_data_and_send_eventsource_with_headers( }; (message, body) } - reqwest_eventsource::Error::InvalidContentType(content_type, resp) => { + ReqwestEventSourceError::InvalidContentType(content_type, resp) => { let body = resp.text().await.ok(); let message = match &body { Some(b) => format!( @@ -373,7 +376,8 @@ pub async fn inject_extra_request_data_and_send_eventsource_with_headers( // Timeouts at the reqwest level are from `gateway.global_outbound_http_timeout_ms`. // Variant/model/provider-level timeouts are handled via `tokio::time::timeout` // and produce distinct error types (VariantTimeout, ModelTimeout, ModelProviderTimeout). 
- let is_timeout = matches!(&other, reqwest_eventsource::Error::Transport(e) if e.is_timeout()); + let is_timeout = + matches!(&other, ReqwestEventSourceError::Transport(e) if e.is_timeout()); let message = if is_timeout { format!( "Request timed out due to `gateway.global_outbound_http_timeout_ms`. Consider increasing this value in your configuration if you expect inferences to take longer to complete. ({})", diff --git a/tensorzero-core/src/providers/mistral.rs b/tensorzero-core/src/providers/mistral.rs index edabe54cb3..8fd0d11510 100644 --- a/tensorzero-core/src/providers/mistral.rs +++ b/tensorzero-core/src/providers/mistral.rs @@ -1,5 +1,6 @@ use std::{borrow::Cow, time::Duration}; +use crate::http::Event; use crate::{ http::{TensorZeroEventSource, TensorzeroHttpClient}, providers::openai::OpenAIMessagesConfig, @@ -7,7 +8,6 @@ use crate::{ use futures::{StreamExt, future::try_join_all}; use lazy_static::lazy_static; use reqwest::StatusCode; -use reqwest_eventsource::Event; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -372,14 +372,17 @@ pub fn stream_mistral( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - if message.data == "[DONE]" { + let Some(message_data) = message.data else { + continue; + }; + if message_data == "[DONE]" { break; } let data: Result = - serde_json::from_str(&message.data).map_err(|e| ErrorDetails::InferenceServer { + serde_json::from_str(&message_data).map_err(|e| ErrorDetails::InferenceServer { message: format!( "Error parsing chunk. Error: {}, Data: {}", - e, message.data + e, message_data ), provider_type: PROVIDER_TYPE.to_string(), raw_request: Some(raw_request.clone()), @@ -388,7 +391,7 @@ pub fn stream_mistral( let latency = start_time.elapsed(); let stream_message = data.and_then(|d| { mistral_to_tensorzero_chunk( - message.data, + message_data, d, latency, &mut last_tool_name, diff --git a/tensorzero-core/src/providers/openai/mod.rs b/tensorzero-core/src/providers/openai/mod.rs index 31a28c5f06..1f2240136a 100644 --- a/tensorzero-core/src/providers/openai/mod.rs +++ b/tensorzero-core/src/providers/openai/mod.rs @@ -1,10 +1,10 @@ +use crate::http::{Event, ReqwestEventSourceError}; use async_trait::async_trait; use futures::future::try_join_all; use futures::{Stream, StreamExt, TryStreamExt}; use lazy_static::lazy_static; use reqwest::StatusCode; use reqwest::multipart::{Form, Part}; -use reqwest_eventsource::Event; use responses::stream_openai_responses; use secrecy::{ExposeSecret, SecretString}; use serde::de::IntoDeserializer; @@ -1062,11 +1062,14 @@ pub fn stream_openai( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - if message.data == "[DONE]" { + let Some(message_data) = message.data else { + continue; + }; + if message_data == "[DONE]" { break; } let data: Result = - serde_json::from_str(&message.data).map_err(|e| { + serde_json::from_str(&message_data).map_err(|e| { let error_message = match &request_id { Some(id) => format!("Error parsing chunk. Error: {e} [request_id: {id}]"), None => format!("Error parsing chunk. 
Error: {e}"), @@ -1074,7 +1077,7 @@ pub fn stream_openai( Error::new(ErrorDetails::InferenceServer { message: error_message, raw_request: Some(raw_request.clone()), - raw_response: Some(message.data.clone()), + raw_response: Some(message_data.clone()), provider_type: provider_type.clone(), }) }); @@ -1082,7 +1085,7 @@ pub fn stream_openai( let latency = start_time.elapsed(); let stream_message = data.and_then(|d| { openai_to_tensorzero_chunk( - message.data, + message_data, d, latency, &mut tool_call_ids, @@ -1237,11 +1240,11 @@ fn extract_request_id(headers: &reqwest::header::HeaderMap) -> Option { } pub(super) fn request_id_from_event_source_error( - error: &reqwest_eventsource::Error, + error: &ReqwestEventSourceError, ) -> Option { match error { - reqwest_eventsource::Error::InvalidStatusCode(_, resp) - | reqwest_eventsource::Error::InvalidContentType(_, resp) => { + ReqwestEventSourceError::InvalidStatusCode(_, resp) + | ReqwestEventSourceError::InvalidContentType(_, resp) => { extract_request_id(resp.headers()) } _ => None, diff --git a/tensorzero-core/src/providers/openai/responses.rs b/tensorzero-core/src/providers/openai/responses.rs index 8778ccaa3c..d77391c843 100644 --- a/tensorzero-core/src/providers/openai/responses.rs +++ b/tensorzero-core/src/providers/openai/responses.rs @@ -12,11 +12,11 @@ use crate::{ }; const PROVIDER_NAME: &str = "OpenAI Responses"; +use crate::http::Event; use crate::providers::helpers::convert_stream_error; use crate::{error::IMPOSSIBLE_ERROR_MESSAGE, inference::TensorZeroEventError}; use futures::StreamExt; use futures::{Stream, future::try_join_all}; -use reqwest_eventsource::Event; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use tokio::time::Instant; @@ -1197,10 +1197,13 @@ pub fn stream_openai_responses( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { + let Some(message_data) = message.data else { + continue; + }; // OpenAI Responses API does not send [DONE] marker // Instead, we check for terminal events: completed, failed, or incomplete let data: Result = - serde_json::from_str(&message.data); + serde_json::from_str(&message_data); // If we can't parse the event at all, log and skip it let event = match data { @@ -1209,7 +1212,7 @@ pub fn stream_openai_responses( tracing::warn!( "Failed to parse OpenAI Responses stream event, skipping. 
Error: {}, Data: {}", e, - message.data + message_data ); continue; } @@ -1225,7 +1228,7 @@ pub fn stream_openai_responses( let latency = start_time.elapsed(); let stream_message = openai_responses_to_tensorzero_chunk( - message.data, + message_data, event, latency, &mut current_tool_id, diff --git a/tensorzero-core/src/providers/openrouter.rs b/tensorzero-core/src/providers/openrouter.rs index 257a1dcdcc..f5b9a86ac4 100644 --- a/tensorzero-core/src/providers/openrouter.rs +++ b/tensorzero-core/src/providers/openrouter.rs @@ -1,8 +1,8 @@ +use crate::http::Event; use futures::future::try_join_all; use futures::{Stream, StreamExt, TryStreamExt}; use lazy_static::lazy_static; use reqwest::StatusCode; -use reqwest_eventsource::Event; use secrecy::{ExposeSecret, SecretString}; use serde::de::IntoDeserializer; use serde::{Deserialize, Deserializer, Serialize}; @@ -491,23 +491,26 @@ pub fn stream_openrouter( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - if message.data == "[DONE]" { + let Some(message_data) = message.data else { + continue; + }; + if message_data == "[DONE]" { break; } let data: Result = - serde_json::from_str(&message.data).map_err(|e| Error::new(ErrorDetails::InferenceServer { + serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer { message: format!( "Error parsing chunk. Error: {e}", ), raw_request: Some(raw_request.clone()), - raw_response: Some(message.data.clone()), + raw_response: Some(message_data.clone()), provider_type: provider_type.clone(), })); let latency = start_time.elapsed(); let stream_message = data.and_then(|d| { openrouter_to_tensorzero_chunk( - message.data, + message_data, d, latency, &mut tool_call_ids, diff --git a/tensorzero-core/src/providers/sglang.rs b/tensorzero-core/src/providers/sglang.rs index 3a98ce12c8..5871dd304c 100644 --- a/tensorzero-core/src/providers/sglang.rs +++ b/tensorzero-core/src/providers/sglang.rs @@ -1,9 +1,9 @@ use std::borrow::Cow; use std::time::Duration; +use crate::http::{Event, ReqwestEventSourceError}; use crate::http::{TensorZeroEventSource, TensorzeroHttpClient}; use futures::StreamExt; -use reqwest_eventsource::Event; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; @@ -391,7 +391,7 @@ fn stream_sglang( Err(e) => { let message = e.to_string(); let mut raw_response = None; - if let reqwest_eventsource::Error::InvalidStatusCode(_, resp) = *e { + if let ReqwestEventSourceError::InvalidStatusCode(_, resp) = *e { raw_response = resp.text().await.ok(); } yield Err(ErrorDetails::InferenceServer { @@ -404,21 +404,24 @@ fn stream_sglang( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - if message.data == "[DONE]" { + let Some(message_data) = message.data else { + continue; + }; + if message_data == "[DONE]" { break; } let data: Result = - serde_json::from_str(&message.data).map_err(|e| Error::new(ErrorDetails::InferenceServer { + serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer { message: format!("Error parsing chunk. 
Error: {e}"), raw_request: None, - raw_response: Some(message.data.clone()), + raw_response: Some(message_data.clone()), provider_type: PROVIDER_TYPE.to_string(), })); let latency = start_time.elapsed(); let stream_message = data.and_then(|d| { sglang_to_tensorzero_chunk( - message.data, + message_data, d, latency, &mut tool_call_ids, diff --git a/tensorzero-core/src/providers/tgi.rs b/tensorzero-core/src/providers/tgi.rs index e8ffc33fca..2e29a451b4 100644 --- a/tensorzero-core/src/providers/tgi.rs +++ b/tensorzero-core/src/providers/tgi.rs @@ -1,3 +1,4 @@ +use crate::http::Event; use crate::http::TensorzeroHttpClient; use async_trait::async_trait; /// TGI integration for TensorZero @@ -14,7 +15,6 @@ use async_trait::async_trait; /// Our implementation currently allows you to use a tool in TGI (nonstreaming), but YMMV. use futures::{Stream, StreamExt, TryStreamExt}; use reqwest::StatusCode; -use reqwest_eventsource::Event; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -408,22 +408,25 @@ fn stream_tgi( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - if message.data == "[DONE]" { + let Some(message_data) = message.data else { + continue; + }; + if message_data == "[DONE]" { break; } let data: Result = - serde_json::from_str(&message.data).map_err(|e| Error::new(ErrorDetails::InferenceServer { + serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer { message: format!( "Error parsing chunk. Error: {e}", ), raw_request: Some(raw_request.clone()), - raw_response: Some(message.data.clone()), + raw_response: Some(message_data.clone()), provider_type: PROVIDER_TYPE.to_string(), })); let latency = start_time.elapsed(); let stream_message = data.and_then(|d| { - tgi_to_tensorzero_chunk(message.data, d, latency, model_inference_id) + tgi_to_tensorzero_chunk(message_data, d, latency, model_inference_id) }); yield stream_message; } diff --git a/tensorzero-core/src/providers/together.rs b/tensorzero-core/src/providers/together.rs index edd5c2beb4..f9ece84cac 100644 --- a/tensorzero-core/src/providers/together.rs +++ b/tensorzero-core/src/providers/together.rs @@ -1,5 +1,6 @@ use std::{borrow::Cow, time::Duration}; +use crate::http::{Event, ReqwestEventSourceError}; use crate::inference::types::RequestMessage; use crate::inference::types::chat_completion_inference_params::{ ChatCompletionInferenceParamsV2, warn_inference_parameter_not_supported, @@ -7,7 +8,6 @@ use crate::inference::types::chat_completion_inference_params::{ use crate::providers::openai::OpenAIMessagesConfig; use futures::{StreamExt, future::try_join_all}; use lazy_static::lazy_static; -use reqwest_eventsource::Event; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -709,7 +709,7 @@ fn stream_together( Err(e) => { let message = e.to_string(); let mut raw_response = None; - if let reqwest_eventsource::Error::InvalidStatusCode(_, resp) = *e { + if let ReqwestEventSourceError::InvalidStatusCode(_, resp) = *e { raw_response = resp.text().await.ok(); } yield Err(ErrorDetails::InferenceServer { @@ -722,21 +722,24 @@ fn stream_together( Ok(event) => match event { Event::Open => continue, Event::Message(message) => { - if message.data == "[DONE]" { + let Some(message_data) = message.data else { + continue; + }; + if message_data == "[DONE]" { break; } let data: Result = - serde_json::from_str(&message.data).map_err(|e| Error::new(ErrorDetails::InferenceServer { 
+                            serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer {
                                 message: format!("Error parsing chunk. Error: {e}"),
                                 raw_request: None,
-                                raw_response: Some(message.data.clone()),
+                                raw_response: Some(message_data.clone()),
                                 provider_type: PROVIDER_TYPE.to_string(),
                             }));
                         let latency = start_time.elapsed();
                         let stream_message = data.and_then(|d| {
                             together_to_tensorzero_chunk(
-                                message.data,
+                                message_data,
                                 d,
                                 latency,
                                 &mut tool_call_ids,
diff --git a/tensorzero-core/tests/e2e/best_of_n.rs b/tensorzero-core/tests/e2e/best_of_n.rs
index f3a66ec0d3..4c130b0d8b 100644
--- a/tensorzero-core/tests/e2e/best_of_n.rs
+++ b/tensorzero-core/tests/e2e/best_of_n.rs
@@ -1,6 +1,6 @@
 use futures::StreamExt;
 use reqwest::{Client, StatusCode};
-use reqwest_eventsource::{Event, RequestBuilderExt};
+use reqwest_sse_stream::into_sse_stream;
 use serde_json::{Value, json};
 use tensorzero_core::{
     inference::types::{Role, StoredContentBlock, StoredRequestMessage, Text, Unknown},
@@ -60,19 +60,18 @@ async fn e2e_test_best_of_n_dummy_candidates_dummy_judge_inner(
         .json(&payload);
     let inference_id = if stream {
-        let mut chunks = builder.eventsource().unwrap();
+        let mut chunks = into_sse_stream(builder).await.unwrap();
         let mut first_inference_id = None;
         while let Some(chunk) = chunks.next().await {
             println!("chunk: {chunk:?}");
-            let chunk = chunk.unwrap();
-            let Event::Message(chunk) = chunk else {
+            let sse = chunk.unwrap();
+            let Some(data) = sse.data else {
                 continue;
             };
-            if chunk.data == "[DONE]" {
+            if data == "[DONE]" {
                 break;
             }
-            let chunk_json = chunk.data;
-            let chunk_json: Value = serde_json::from_str(&chunk_json).unwrap();
+            let chunk_json: Value = serde_json::from_str(&data).unwrap();
             let inference_id = chunk_json.get("inference_id").unwrap().as_str().unwrap();
             let inference_id = Uuid::parse_str(inference_id).unwrap();
             if first_inference_id.is_none() {
diff --git a/tensorzero-core/tests/e2e/cache.rs b/tensorzero-core/tests/e2e/cache.rs
index 284dc8fd52..ac0ea06650 100644
--- a/tensorzero-core/tests/e2e/cache.rs
+++ b/tensorzero-core/tests/e2e/cache.rs
@@ -3,8 +3,7 @@
 use futures::StreamExt;
 use rand::Rng;
 use reqwest::Client;
-use reqwest_eventsource::Event;
-use reqwest_eventsource::RequestBuilderExt;
+use reqwest_sse_stream::into_sse_stream;
 use serde_json::Value;
 use serde_json::json;
 use std::time::Duration;
@@ -525,33 +524,31 @@ pub async fn check_test_streaming_cache_with_err(
         "cache_options": {"enabled": "on", "lookback_s": 10}
     });
 
-    let mut event_source = Client::new()
-        .post(get_gateway_endpoint("/inference"))
-        .json(&payload)
-        .eventsource()
-        .unwrap();
+    let mut event_source = into_sse_stream(
+        Client::new()
+            .post(get_gateway_endpoint("/inference"))
+            .json(&payload),
+    )
+    .await
+    .unwrap();
     let mut chunks = vec![];
     let mut found_done_chunk = false;
     while let Some(event) = event_source.next().await {
-        let event = event.unwrap();
-        match event {
-            Event::Open => continue,
-            Event::Message(message) => {
-                if message.data == "[DONE]" {
-                    found_done_chunk = true;
-                    break;
-                }
-                if serde_json::from_str::<Value>(&message.data)
-                    .unwrap()
-                    .get("error")
-                    .is_some()
-                {
-                    continue;
-                }
-                chunks.push(message.data);
-            }
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
+            found_done_chunk = true;
+            break;
+        }
+        if serde_json::from_str::<Value>(&data)
+            .unwrap()
+            .get("error")
+            .is_some()
+        {
+            continue;
         }
+        chunks.push(data);
     }
 
     assert!(found_done_chunk);
 
@@ -913,10 +910,8 @@ async fn 
test_streaming_cache_usage_only_in_final_chunk_openai() {
     let url = format!("{base_url}/openai/v1/chat/completions");
 
-    let mut chunks = Client::new()
-        .post(&url)
-        .json(&payload)
-        .eventsource()
+    let mut chunks = into_sse_stream(Client::new().post(&url).json(&payload))
+        .await
         .unwrap();
 
     let mut chunks_with_usage = 0;
@@ -925,17 +920,15 @@ async fn test_streaming_cache_usage_only_in_final_chunk_openai() {
     let mut total_completion_tokens = 0u64;
 
     while let Some(chunk) = chunks.next().await {
-        let chunk = chunk.unwrap();
-        let Event::Message(chunk) = chunk else {
-            continue;
-        };
-        if chunk.data == "[DONE]" {
+        let sse = chunk.unwrap();
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
             break;
         }
         total_chunks += 1;
 
-        let chunk_json: Value = serde_json::from_str(&chunk.data).unwrap();
+        let chunk_json: Value = serde_json::from_str(&data).unwrap();
 
         if let Some(usage) = chunk_json.get("usage")
             && !usage.is_null()
         {
diff --git a/tensorzero-core/tests/e2e/dicl.rs b/tensorzero-core/tests/e2e/dicl.rs
index e2715d4311..bbe3312940 100644
--- a/tensorzero-core/tests/e2e/dicl.rs
+++ b/tensorzero-core/tests/e2e/dicl.rs
@@ -1,7 +1,7 @@
 use crate::common::get_gateway_endpoint;
 use futures::StreamExt;
 use reqwest::{Client, StatusCode};
-use reqwest_eventsource::{Event, RequestBuilderExt};
+use reqwest_sse_stream::into_sse_stream;
 use serde_json::{Value, json};
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -737,25 +737,23 @@ pub async fn test_dicl_inference_request_simple() {
             ]
         }
     });
-    let mut event_source = Client::new()
-        .post(get_gateway_endpoint("/inference"))
-        .json(&payload)
-        .eventsource()
-        .unwrap();
+    let mut event_source = into_sse_stream(
+        Client::new()
+            .post(get_gateway_endpoint("/inference"))
+            .json(&payload),
+    )
+    .await
+    .unwrap();
     let mut chunks = vec![];
     let mut found_done_chunk = false;
     while let Some(event) = event_source.next().await {
-        let event = event.unwrap();
-        match event {
-            Event::Open => continue,
-            Event::Message(message) => {
-                if message.data == "[DONE]" {
-                    found_done_chunk = true;
-                    break;
-                }
-                chunks.push(message.data);
-            }
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
+            found_done_chunk = true;
+            break;
         }
+        chunks.push(data);
     }
     assert!(found_done_chunk);
     let mut inference_id: Option<Uuid> = None;
diff --git a/tensorzero-core/tests/e2e/endpoints/internal/evaluations.rs b/tensorzero-core/tests/e2e/endpoints/internal/evaluations.rs
index 66157ece23..073e8c39b5 100644
--- a/tensorzero-core/tests/e2e/endpoints/internal/evaluations.rs
+++ b/tensorzero-core/tests/e2e/endpoints/internal/evaluations.rs
@@ -4,7 +4,7 @@ use std::time::Duration;
 
 use futures::StreamExt;
 use reqwest::{Client, StatusCode};
-use reqwest_eventsource::{Event, RequestBuilderExt};
+use reqwest_sse_stream::into_sse_stream;
 use serde_json::{Value, json};
 use tensorzero_core::db::clickhouse::test_helpers::get_clickhouse;
 use tensorzero_core::db::evaluation_queries::EvaluationResultRow;
@@ -779,11 +779,13 @@ async fn test_run_evaluation_streaming_success() {
     });
 
     // Make the SSE request
-    let mut event_stream = http_client
-        .post(get_gateway_endpoint("/internal/evaluations/run"))
-        .json(&payload)
-        .eventsource()
-        .unwrap();
+    let mut event_stream = into_sse_stream(
+        http_client
+            .post(get_gateway_endpoint("/internal/evaluations/run"))
+            .json(&payload),
+    )
+    .await
+    .unwrap();
 
     let mut events: Vec<Value> = Vec::new();
     let mut start_received = false;
@@ -793,64 +795,62 @@ 
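 // [Editor's illustrative sketch -- not part of this patch.] sse-stream has no
 // `Event::Open` variant; keep-alive/comment frames arrive as `Sse` items whose
 // `data` field is `None`. Every consumption loop in these tests therefore
 // collapses to the same shape (assuming a stream of `Result<Sse, _>` items):
 //
 //     while let Some(item) = stream.next().await {
 //         let Some(data) = item.unwrap().data else { continue };
 //         if data == "[DONE]" {
 //             break;
 //         }
 //         // parse `data` as JSON and handle the event
 //     }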
// Collect events from the stream while let Some(event_result) = event_stream.next().await { - match event_result { - Ok(Event::Open) => continue, - Ok(Event::Message(message)) => { - if message.data == "[DONE]" { - break; - } - - let event: Value = serde_json::from_str(&message.data).unwrap(); - let event_type = event.get("type").and_then(|t| t.as_str()); - - match event_type { - Some("start") => { - start_received = true; - assert!( - event.get("evaluation_run_id").is_some(), - "Start event should have evaluation_run_id" - ); - assert!( - event.get("num_datapoints").is_some(), - "Start event should have num_datapoints" - ); - } - Some("success") => { - success_count += 1; - assert!( - event.get("datapoint").is_some(), - "Success event should have datapoint" - ); - assert!( - event.get("response").is_some(), - "Success event should have response" - ); - assert!( - event.get("evaluations").is_some(), - "Success event should have evaluations" - ); - } - Some("error") => { - error_count += 1; - } - Some("complete") => { - complete_received = true; - assert!( - event.get("evaluation_run_id").is_some(), - "Complete event should have evaluation_run_id" - ); - } - Some("fatal_error") => { - panic!("Received fatal_error event: {:?}", event.get("message")); - } - _ => {} - } + let sse = match event_result { + Ok(sse) => sse, + Err(_) => break, + }; + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; + } - events.push(event); + let event: Value = serde_json::from_str(&data).unwrap(); + let event_type = event.get("type").and_then(|t| t.as_str()); + + match event_type { + Some("start") => { + start_received = true; + assert!( + event.get("evaluation_run_id").is_some(), + "Start event should have evaluation_run_id" + ); + assert!( + event.get("num_datapoints").is_some(), + "Start event should have num_datapoints" + ); } - Err(reqwest_eventsource::Error::StreamEnded) => break, - Err(e) => panic!("SSE stream error: {e:?}"), + Some("success") => { + success_count += 1; + assert!( + event.get("datapoint").is_some(), + "Success event should have datapoint" + ); + assert!( + event.get("response").is_some(), + "Success event should have response" + ); + assert!( + event.get("evaluations").is_some(), + "Success event should have evaluations" + ); + } + Some("error") => { + error_count += 1; + } + Some("complete") => { + complete_received = true; + assert!( + event.get("evaluation_run_id").is_some(), + "Complete event should have evaluation_run_id" + ); + } + Some("fatal_error") => { + panic!("Received fatal_error event: {:?}", event.get("message")); + } + _ => {} } + + events.push(event); } assert!(start_received, "Should receive start event"); @@ -923,47 +923,47 @@ async fn test_run_evaluation_streaming_nonexistent_dataset() { // The request should start streaming. We should get a start event with 0 datapoints, // then a complete event. 
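 // [Editor's note -- illustrative.] For an empty or nonexistent dataset, the
 // frames this test accepts look roughly like (field values hypothetical):
 //
 //     data: {"type":"start","evaluation_run_id":"<uuid>","num_datapoints":0}
 //     data: {"type":"complete","evaluation_run_id":"<uuid>"}
 //     data: [DONE]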
- let mut event_stream = http_client - .post(get_gateway_endpoint("/internal/evaluations/run")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_stream = into_sse_stream( + http_client + .post(get_gateway_endpoint("/internal/evaluations/run")) + .json(&payload), + ) + .await + .unwrap(); let mut found_error_or_empty = false; while let Some(event_result) = event_stream.next().await { - match event_result { - Ok(Event::Open) => continue, - Ok(Event::Message(message)) => { - if message.data == "[DONE]" { - break; - } - let event: Value = serde_json::from_str(&message.data).unwrap(); - let event_type = event.get("type").and_then(|t| t.as_str()); - - match event_type { - Some("start") => { - let num_datapoints = event.get("num_datapoints").and_then(|n| n.as_u64()); - if num_datapoints == Some(0) { - found_error_or_empty = true; - } - } - Some("fatal_error") => { - found_error_or_empty = true; - break; - } - Some("complete") => { - found_error_or_empty = true; - break; - } - _ => {} + let sse = match event_result { + Ok(sse) => sse, + Err(_) => { + found_error_or_empty = true; + break; + } + }; + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; + } + let event: Value = serde_json::from_str(&data).unwrap(); + let event_type = event.get("type").and_then(|t| t.as_str()); + + match event_type { + Some("start") => { + let num_datapoints = event.get("num_datapoints").and_then(|n| n.as_u64()); + if num_datapoints == Some(0) { + found_error_or_empty = true; } } - Err(reqwest_eventsource::Error::StreamEnded) => break, - Err(_) => { + Some("fatal_error") => { found_error_or_empty = true; break; } + Some("complete") => { + found_error_or_empty = true; + break; + } + _ => {} } } @@ -1012,41 +1012,40 @@ async fn test_run_evaluation_streaming_with_specific_datapoint_ids() { "inference_cache": "off", }); - let mut event_stream = http_client - .post(get_gateway_endpoint("/internal/evaluations/run")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_stream = into_sse_stream( + http_client + .post(get_gateway_endpoint("/internal/evaluations/run")) + .json(&payload), + ) + .await + .unwrap(); let mut num_datapoints_reported = None; let mut start_received = false; while let Some(event_result) = event_stream.next().await { - match event_result { - Ok(Event::Open) => continue, - Ok(Event::Message(message)) => { - if message.data == "[DONE]" { - break; - } + let sse = match event_result { + Ok(sse) => sse, + Err(e) => panic!("SSE stream error: {e:?}"), + }; + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; + } - let event: Value = serde_json::from_str(&message.data).unwrap(); - let event_type = event.get("type").and_then(|t| t.as_str()); - - match event_type { - Some("start") => { - start_received = true; - num_datapoints_reported = - event.get("num_datapoints").and_then(|n| n.as_u64()); - } - Some("complete") => break, - Some("fatal_error") => { - panic!("Received fatal_error event: {:?}", event.get("message")); - } - _ => {} - } + let event: Value = serde_json::from_str(&data).unwrap(); + let event_type = event.get("type").and_then(|t| t.as_str()); + + match event_type { + Some("start") => { + start_received = true; + num_datapoints_reported = event.get("num_datapoints").and_then(|n| n.as_u64()); } - Err(reqwest_eventsource::Error::StreamEnded) => break, - Err(e) => panic!("SSE stream error: {e:?}"), + Some("complete") => break, + Some("fatal_error") => { + panic!("Received fatal_error event: {:?}", event.get("message")); + } + _ 
=> {} } } diff --git a/tensorzero-core/tests/e2e/inference/mod.rs b/tensorzero-core/tests/e2e/inference/mod.rs index b79aef7a1b..01e4473120 100644 --- a/tensorzero-core/tests/e2e/inference/mod.rs +++ b/tensorzero-core/tests/e2e/inference/mod.rs @@ -11,7 +11,7 @@ use base64::prelude::{BASE64_STANDARD, Engine as Base64Engine}; use futures::StreamExt; use opentelemetry_sdk::trace::SpanData; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use std::collections::HashMap; use std::{collections::HashSet, sync::Arc}; @@ -326,24 +326,22 @@ async fn test_dummy_only_inference_chat_strip_unknown_block_stream() { "stream": true, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; } + chunks.push(data); } let chunk_json: serde_json::Value = serde_json::from_str(chunks.last().unwrap()).unwrap(); @@ -1561,23 +1559,21 @@ async fn e2e_test_streaming() { }} }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; } + chunks.push(data); } let mut inference_id = None; for (i, chunk) in chunks.iter().enumerate() { @@ -1738,23 +1734,21 @@ async fn e2e_test_streaming_dryrun() { "dryrun": true, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; } + chunks.push(data); } let mut inference_id = None; for (i, chunk) in chunks.iter().enumerate() { @@ -2241,23 +2235,21 @@ async fn e2e_test_tool_call_streaming() { ]}, "stream": true, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; while let Some(event) = 
event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; } + chunks.push(data); } let mut inference_id = None; let mut id: Option = None; @@ -2468,23 +2460,21 @@ async fn e2e_test_tool_call_streaming_split_tool_name() { "stream": true, "variant_name": "split_tool_name", }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; } + chunks.push(data); } let mut inference_id = None; let mut id: Option = None; diff --git a/tensorzero-core/tests/e2e/mixture_of_n.rs b/tensorzero-core/tests/e2e/mixture_of_n.rs index e9e1e16aef..b9aa800da8 100644 --- a/tensorzero-core/tests/e2e/mixture_of_n.rs +++ b/tensorzero-core/tests/e2e/mixture_of_n.rs @@ -1,6 +1,6 @@ use futures::StreamExt; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use tensorzero_core::inference::types::{ Role, StoredContentBlock, StoredRequestMessage, Text, Unknown, Usage, @@ -58,20 +58,17 @@ async fn e2e_test_mixture_of_n_dummy_candidates_dummy_judge_inner( .post(get_gateway_endpoint("/inference")) .json(&payload); let (inference_id, output_usage) = if stream { - let mut chunks = builder.eventsource().unwrap(); + let mut chunks = into_sse_stream(builder).await.unwrap(); let mut first_inference_id = None; let mut last_chunk = None; while let Some(chunk) = chunks.next().await { println!("chunk: {chunk:?}"); - let chunk = chunk.unwrap(); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json = chunk.data; - let chunk_json: Value = serde_json::from_str(&chunk_json).unwrap(); + let chunk_json: Value = serde_json::from_str(&data).unwrap(); let inference_id = chunk_json.get("inference_id").unwrap().as_str().unwrap(); let inference_id = Uuid::parse_str(inference_id).unwrap(); if first_inference_id.is_none() { @@ -288,19 +285,16 @@ async fn e2e_test_mixture_of_n_dummy_candidates_real_judge_inner(stream: bool) { .json(&payload); let (content, inference_id) = if stream { - let mut chunks = builder.eventsource().unwrap(); + let mut chunks = into_sse_stream(builder).await.unwrap(); let mut first_inference_id = None; while let Some(chunk) = chunks.next().await { println!("chunk: {chunk:?}"); - let chunk = chunk.unwrap(); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json = chunk.data; - let chunk_json: Value = serde_json::from_str(&chunk_json).unwrap(); + let chunk_json: Value = 
serde_json::from_str(&data).unwrap(); let inference_id = chunk_json.get("inference_id").unwrap().as_str().unwrap(); let inference_id = Uuid::parse_str(inference_id).unwrap(); if first_inference_id.is_none() { @@ -940,20 +934,17 @@ async fn e2e_test_mixture_of_n_bad_fuser_streaming() { .post(get_gateway_endpoint("/inference")) .json(&payload); - let mut chunks = builder.eventsource().unwrap(); + let mut chunks = into_sse_stream(builder).await.unwrap(); let mut first_inference_id = None; let mut chunk_data = vec![]; while let Some(chunk) = chunks.next().await { println!("chunk: {chunk:?}"); - let chunk = chunk.unwrap(); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json = chunk.data; - let chunk_json: Value = serde_json::from_str(&chunk_json).unwrap(); + let chunk_json: Value = serde_json::from_str(&data).unwrap(); let inference_id = chunk_json.get("inference_id").unwrap().as_str().unwrap(); let inference_id = Uuid::parse_str(inference_id).unwrap(); if first_inference_id.is_none() { @@ -1118,20 +1109,17 @@ async fn e2e_test_mixture_of_n_single_candidate_inner( .post(get_gateway_endpoint("/inference")) .json(&payload); let inference_id = if stream { - let mut chunks = builder.eventsource().unwrap(); + let mut chunks = into_sse_stream(builder).await.unwrap(); let mut first_inference_id = None; let mut chunk_data = vec![]; while let Some(chunk) = chunks.next().await { println!("chunk: {chunk:?}"); - let chunk = chunk.unwrap(); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json = chunk.data; - let chunk_json: Value = serde_json::from_str(&chunk_json).unwrap(); + let chunk_json: Value = serde_json::from_str(&data).unwrap(); let inference_id = chunk_json.get("inference_id").unwrap().as_str().unwrap(); let inference_id = Uuid::parse_str(inference_id).unwrap(); if first_inference_id.is_none() { diff --git a/tensorzero-core/tests/e2e/openai_compatible.rs b/tensorzero-core/tests/e2e/openai_compatible.rs index 76097eea66..a6fab60ba1 100644 --- a/tensorzero-core/tests/e2e/openai_compatible.rs +++ b/tensorzero-core/tests/e2e/openai_compatible.rs @@ -1,12 +1,14 @@ #![expect(clippy::print_stdout)] use std::collections::HashSet; -use tensorzero::ClientExt; use axum::extract::State; +use futures::StreamExt; use http_body_util::BodyExt; use reqwest::{Client, StatusCode}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; +use tensorzero::ClientExt; use uuid::Uuid; use crate::common::get_gateway_endpoint; @@ -766,7 +768,7 @@ async fn test_openai_compatible_route_with_json_schema() { #[tokio::test] async fn test_openai_compatible_streaming_tool_call() { use futures::StreamExt; - use reqwest_eventsource::{Event, RequestBuilderExt}; + use reqwest_sse_stream::into_sse_stream; let client = Client::new(); let episode_id = Uuid::now_v7(); @@ -809,27 +811,25 @@ async fn test_openai_compatible_streaming_tool_call() { "tensorzero::episode_id": episode_id.to_string(), }); - let mut response = client - .post(get_gateway_endpoint("/openai/v1/chat/completions")) - .header("Content-Type", "application/json") - .json(&body) - .eventsource() - .unwrap(); + let mut response = into_sse_stream( + client + .post(get_gateway_endpoint("/openai/v1/chat/completions")) + 
.header("Content-Type", "application/json")
+            .json(&body),
+    )
+    .await
+    .unwrap();
 
-    let mut chunks = vec![];
+    let mut chunks: Vec<String> = vec![];
     let mut found_done_chunk = false;
     while let Some(event) = response.next().await {
-        let event = event.unwrap();
-        match event {
-            Event::Open => continue,
-            Event::Message(message) => {
-                if message.data == "[DONE]" {
-                    found_done_chunk = true;
-                    break;
-                }
-                chunks.push(message.data);
-            }
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
+            found_done_chunk = true;
+            break;
         }
+        chunks.push(data);
     }
     assert!(found_done_chunk);
     let first_chunk = chunks.first().unwrap();
@@ -933,8 +933,6 @@ async fn test_openai_compatible_deny_unknown_fields() {
 
 #[tokio::test]
 async fn test_openai_compatible_streaming() {
-    use futures::StreamExt;
-    use reqwest_eventsource::{Event, RequestBuilderExt};
 
     let client = Client::new();
     let episode_id = Uuid::now_v7();
@@ -950,27 +948,25 @@ async fn test_openai_compatible_streaming() {
         "tensorzero::episode_id": episode_id.to_string(),
     });
 
-    let mut response = client
-        .post(get_gateway_endpoint("/openai/v1/chat/completions"))
-        .header("Content-Type", "application/json")
-        .json(&body)
-        .eventsource()
-        .unwrap();
+    let mut response = into_sse_stream(
+        client
+            .post(get_gateway_endpoint("/openai/v1/chat/completions"))
+            .header("Content-Type", "application/json")
+            .json(&body),
+    )
+    .await
+    .unwrap();
 
-    let mut chunks = vec![];
+    let mut chunks: Vec<String> = vec![];
     let mut found_done_chunk = false;
     while let Some(event) = response.next().await {
-        let event = event.unwrap();
-        match event {
-            Event::Open => continue,
-            Event::Message(message) => {
-                if message.data == "[DONE]" {
-                    found_done_chunk = true;
-                    break;
-                }
-                chunks.push(message.data);
-            }
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
+            found_done_chunk = true;
+            break;
         }
+        chunks.push(data);
     }
     assert!(found_done_chunk);
     let first_chunk = chunks.first().unwrap();
diff --git a/tensorzero-core/tests/e2e/providers/anthropic.rs b/tensorzero-core/tests/e2e/providers/anthropic.rs
index f4e112932c..e1b518fe65 100644
--- a/tensorzero-core/tests/e2e/providers/anthropic.rs
+++ b/tensorzero-core/tests/e2e/providers/anthropic.rs
@@ -4,7 +4,7 @@ use std::collections::HashMap;
 use futures::StreamExt;
 use indexmap::IndexMap;
 use reqwest::{Client, StatusCode};
-use reqwest_eventsource::{Event, RequestBuilderExt};
+use reqwest_sse_stream::into_sse_stream;
 use serde_json::{Value, json};
 use tensorzero::{
     ClientInferenceParams, File, InferenceOutput, InferenceResponse, Input, InputMessage,
@@ -255,23 +255,19 @@ async fn test_empty_chunks_success() {
 
     let client = Client::new();
 
-    let mut event_source = client
-        .post(get_gateway_endpoint("/inference"))
-        .json(&payload)
-        .eventsource()
-        .unwrap();
+    let mut event_source = into_sse_stream(
+        client.post(get_gateway_endpoint("/inference")).json(&payload),
+    )
+    .await
+    .unwrap();
     let mut chunks = vec![];
     while let Some(event) = event_source.next().await {
-        let event = event.unwrap();
-        match event {
-            Event::Open => continue,
-            Event::Message(message) => {
-                if message.data == "[DONE]" {
-                    break;
-                }
-                chunks.push(message.data);
-            }
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
+            break;
         }
+        chunks.push(data);
     }
 
     println!("Chunks: {chunks:?}");
@@ -724,28 +720,24 @@ async fn test_beta_structured_outputs_json_helper(stream: bool) {
     });
 
     let inference_id = if stream {
-        let mut 
event_source = client - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + client.post(get_gateway_endpoint("/inference")).json(&payload), + ) + .await + .unwrap(); let mut first_inference_id = None; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - let chunk_json: Value = serde_json::from_str(&message.data).unwrap(); - if let Some(inference_id) = chunk_json - .get("inference_id") - .and_then(|id| id.as_str().map(|id| Uuid::parse_str(id).unwrap())) - { - first_inference_id = Some(inference_id); - } - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; + } + let chunk_json: Value = serde_json::from_str(&data).unwrap(); + if let Some(inference_id) = chunk_json + .get("inference_id") + .and_then(|id| id.as_str().map(|id| Uuid::parse_str(id).unwrap())) + { + first_inference_id = Some(inference_id); } } first_inference_id.unwrap() @@ -825,28 +817,24 @@ async fn test_beta_structured_outputs_strict_tool_helper(stream: bool) { }); let inference_id = if stream { - let mut event_source = client - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + client.post(get_gateway_endpoint("/inference")).json(&payload), + ) + .await + .unwrap(); let mut first_inference_id = None; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - let chunk_json: Value = serde_json::from_str(&message.data).unwrap(); - if let Some(inference_id) = chunk_json - .get("inference_id") - .and_then(|id| id.as_str().map(|id| Uuid::parse_str(id).unwrap())) - { - first_inference_id = Some(inference_id); - } - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; + } + let chunk_json: Value = serde_json::from_str(&data).unwrap(); + if let Some(inference_id) = chunk_json + .get("inference_id") + .and_then(|id| id.as_str().map(|id| Uuid::parse_str(id).unwrap())) + { + first_inference_id = Some(inference_id); } } first_inference_id.unwrap() @@ -951,23 +939,19 @@ pub async fn test_streaming_thinking_helper(model_name: &str, provider_type: &st let client = Client::new(); - let mut event_source = client - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + client.post(get_gateway_endpoint("/inference")).json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; } + chunks.push(data); } let mut inference_id = None; let mut content_blocks: IndexMap<(String, String), String> = IndexMap::new(); diff --git a/tensorzero-core/tests/e2e/providers/common.rs b/tensorzero-core/tests/e2e/providers/common.rs index 06c98ac639..e3520d3c4c 100644 --- a/tensorzero-core/tests/e2e/providers/common.rs +++ b/tensorzero-core/tests/e2e/providers/common.rs @@ -27,7 +27,7 
@@ use object_store::path::Path; use rand::Rng; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use std::future::IntoFuture; use tensorzero::{ @@ -1997,26 +1997,26 @@ pub async fn test_extra_body_with_provider_and_stream(provider: &E2ETestProvider }); let inference_id = if stream { - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -2201,26 +2201,26 @@ pub async fn test_inference_extra_body_with_provider_and_stream( }); let inference_id = if stream { - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -3729,33 +3729,27 @@ pub async fn test_streaming_invalid_request_with_provider(provider: E2ETestProvi "extra_headers": extra_headers.extra_headers, }); - let mut event_source = Client::new() + // This test expects an error response, so we make the request manually + // instead of using into_sse_stream which calls error_for_status() + let response = Client::new() .post(get_gateway_endpoint("/inference")) .json(&payload) - .eventsource() + .send() + .await .unwrap(); - while let Some(event) = event_source.next().await { - if let Ok(reqwest_eventsource::Event::Open) = event { - continue; - } - let err = event.unwrap_err(); - let reqwest_eventsource::Error::InvalidStatusCode(code, resp) = err else { - panic!("Unexpected error: {err:?}") - }; - assert_eq!(code, StatusCode::INTERNAL_SERVER_ERROR); - let resp: Value = resp.json().await.unwrap(); - let err_msg = resp.get("error").unwrap().as_str().unwrap(); - println!("Error message: {err_msg}"); - assert!( - err_msg.contains("top_p") - || err_msg.contains("topP") - || err_msg.contains("temperature") - || err_msg.contains("presence_penalty") - || err_msg.contains("frequency_penalty"), - "Unexpected error message: {resp}" - ); - } + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + let resp: Value = response.json().await.unwrap(); + let err_msg = resp.get("error").unwrap().as_str().unwrap(); + println!("Error message: {err_msg}"); + assert!( + 
err_msg.contains("top_p") + || err_msg.contains("topP") + || err_msg.contains("temperature") + || err_msg.contains("presence_penalty") + || err_msg.contains("frequency_penalty"), + "Unexpected error message: {resp}" + ); } pub async fn test_simple_streaming_inference_request_with_provider(provider: E2ETestProvider) { @@ -3854,26 +3848,26 @@ pub async fn test_simple_streaming_inference_request_with_provider_cache( "extra_headers": extra_headers.extra_headers, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -4432,26 +4426,26 @@ pub async fn test_inference_params_dynamic_credentials_streaming_inference_reque "extra_headers": extra_headers.extra_headers, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -5076,26 +5070,26 @@ pub async fn test_tool_use_tool_choice_auto_used_streaming_inference_request_wit "extra_headers": extra_headers }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -5693,26 +5687,26 @@ pub async fn test_tool_use_tool_choice_auto_unused_streaming_inference_request_w "extra_headers": extra_headers.extra_headers, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + 
.unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -6296,26 +6290,26 @@ pub async fn test_tool_use_tool_choice_required_streaming_inference_request_with "extra_headers": extra_headers.extra_headers, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -6909,26 +6903,26 @@ pub async fn test_tool_use_tool_choice_none_streaming_inference_request_with_pro "extra_headers": extra_headers.extra_headers, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -7573,26 +7567,26 @@ pub async fn test_tool_use_tool_choice_specific_streaming_inference_request_with "extra_headers": extra_headers.extra_headers, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -8236,26 +8230,26 @@ pub async fn test_tool_use_allowed_tools_streaming_inference_request_with_provid "extra_headers": extra_headers.extra_headers, }); - let mut event_source = Client::new() - 
.post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -8920,26 +8914,26 @@ pub async fn test_tool_multi_turn_streaming_inference_request_with_provider( "stream": true, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -10398,26 +10392,26 @@ pub async fn test_parallel_tool_use_streaming_inference_request_with_provider( "extra_headers": if provider.is_modal_provider() { get_modal_extra_headers() } else { UnfilteredInferenceExtraHeaders::default() }, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -11346,26 +11340,26 @@ pub async fn test_json_mode_streaming_inference_request_with_provider(provider: "extra_headers": extra_headers.extra_headers, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + 
found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -12266,26 +12260,26 @@ pub async fn test_multi_turn_parallel_tool_use_streaming_inference_request_with_ payload["stream"] = json!(true); // Make the second inference request - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -12911,23 +12905,23 @@ pub async fn test_reasoning_multi_turn_thought_streaming_with_provider(provider: "stream": true, }); - let mut event_source = client - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + client + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + break; } + chunks.push(data); } let mut inference_id = None; @@ -13023,23 +13017,23 @@ pub async fn test_reasoning_multi_turn_thought_streaming_with_provider(provider: }); println!("Payload (iteration {iteration}): {payload}"); - let mut event_source = client - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + client + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { + continue; + }; + if data == "[DONE]" { + break; } + chunks.push(data); } // Validate we got chunks and extract inference_id diff --git a/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs b/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs index 1d22868973..b50aad2a12 100644 --- a/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs +++ b/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs @@ -16,7 +16,7 @@ use crate::providers::common::E2ETestProvider; use crate::providers::helpers::get_modal_extra_headers; use futures::StreamExt; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use tensorzero_core::inference::types::extra_headers::UnfilteredInferenceExtraHeaders; use uuid::Uuid; @@ -261,10 +261,8 @@ async fn get_streaming_input_tokens( 
    payload: &Value,
     variant_name: &str,
 ) -> Option<u64> {
-    let mut chunks = client
-        .post(get_gateway_endpoint("/inference"))
-        .json(payload)
-        .eventsource()
+    let mut chunks = into_sse_stream(client.post(get_gateway_endpoint("/inference")).json(payload))
+        .await
         .unwrap_or_else(|e| {
             panic!(
                 "Failed to create eventsource for streaming request for provider {variant_name}: {e}",
             )
         });
@@ -275,20 +273,18 @@ async fn get_streaming_input_tokens(
     let mut all_chunks: Vec<Value> = Vec::new();

     while let Some(chunk) = chunks.next().await {
-        let chunk = chunk.unwrap_or_else(|e| {
+        let sse = chunk.unwrap_or_else(|e| {
             panic!("Failed to receive chunk from stream for provider {variant_name}: {e}",)
         });
-        let Event::Message(chunk) = chunk else {
-            continue;
-        };
-        if chunk.data == "[DONE]" {
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
             break;
         }

-        let chunk_json: Value = serde_json::from_str(&chunk.data).unwrap_or_else(|e| {
+        let chunk_json: Value = serde_json::from_str(&data).unwrap_or_else(|e| {
             panic!(
                 "Failed to parse chunk as JSON for provider {variant_name}: {e}. Data: {}",
-                chunk.data
+                data
             )
         });

diff --git a/tensorzero-core/tests/e2e/providers/commonv2/raw_usage.rs b/tensorzero-core/tests/e2e/providers/commonv2/raw_usage.rs
index df380e0e84..304c4eb20f 100644
--- a/tensorzero-core/tests/e2e/providers/commonv2/raw_usage.rs
+++ b/tensorzero-core/tests/e2e/providers/commonv2/raw_usage.rs
@@ -8,7 +8,7 @@ use crate::providers::common::E2ETestProvider;
 use crate::providers::helpers::get_modal_extra_headers;
 use futures::StreamExt;
 use reqwest::{Client, StatusCode};
-use reqwest_eventsource::{Event, RequestBuilderExt};
+use reqwest_sse_stream::into_sse_stream;
 use serde_json::{Value, json};
 use tensorzero_core::inference::types::extra_headers::UnfilteredInferenceExtraHeaders;
 use uuid::Uuid;
@@ -158,39 +158,39 @@ pub async fn test_raw_usage_inference_with_provider_streaming(provider: E2ETestP
         "extra_headers": extra_headers.extra_headers,
     });

-    let mut chunks = Client::new()
-        .post(get_gateway_endpoint("/inference"))
-        .json(&payload)
-        .eventsource()
-        .unwrap_or_else(|e| {
-            panic!(
-                "Failed to create eventsource for streaming request for provider {}: {e}",
-                provider.variant_name
-            )
-        });
+    let mut chunks = into_sse_stream(
+        Client::new()
+            .post(get_gateway_endpoint("/inference"))
+            .json(&payload),
+    )
+    .await
+    .unwrap_or_else(|e| {
+        panic!(
+            "Failed to create eventsource for streaming request for provider {}: {e}",
+            provider.variant_name
+        )
+    });

     let mut found_raw_usage = false;
     let mut last_chunk_with_usage: Option<Value> = None;
     let mut all_chunks: Vec<Value> = Vec::new();

     while let Some(chunk) = chunks.next().await {
-        let chunk = chunk.unwrap_or_else(|e| {
+        let sse = chunk.unwrap_or_else(|e| {
             panic!(
                 "Failed to receive chunk from stream for provider {}: {e}",
                 provider.variant_name
             )
         });
-        let Event::Message(chunk) = chunk else {
-            continue;
-        };
-        if chunk.data == "[DONE]" {
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
             break;
         }

-        let chunk_json: Value = serde_json::from_str(&chunk.data).unwrap_or_else(|e| {
+        let chunk_json: Value = serde_json::from_str(&data).unwrap_or_else(|e| {
             panic!(
                 "Failed to parse chunk as JSON for provider {}: {e}. 
Data: {}", - provider.variant_name, chunk.data + provider.variant_name, data ) }); diff --git a/tensorzero-core/tests/e2e/providers/commonv2/usage.rs b/tensorzero-core/tests/e2e/providers/commonv2/usage.rs index 06183745d4..3696cdfbfe 100644 --- a/tensorzero-core/tests/e2e/providers/commonv2/usage.rs +++ b/tensorzero-core/tests/e2e/providers/commonv2/usage.rs @@ -5,7 +5,7 @@ use crate::providers::common::E2ETestProvider; use crate::providers::helpers::get_modal_extra_headers; use futures::StreamExt; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use tensorzero_core::inference::types::extra_headers::UnfilteredInferenceExtraHeaders; use uuid::Uuid; @@ -139,38 +139,38 @@ pub async fn test_reasoning_output_tokens_streaming_with_provider(provider: E2ET "extra_headers": extra_headers.extra_headers, }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap_or_else(|e| { - panic!( - "Failed to create eventsource for streaming request for provider {}: {e}", - provider.variant_name - ) - }); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap_or_else(|e| { + panic!( + "Failed to create eventsource for streaming request for provider {}: {e}", + provider.variant_name + ) + }); let mut output_tokens: Option = None; let mut all_chunks: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.unwrap_or_else(|e| { + let sse = chunk.unwrap_or_else(|e| { panic!( "Failed to receive chunk from stream for provider {}: {e}", provider.variant_name ) }); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = serde_json::from_str(&chunk.data).unwrap_or_else(|e| { + let chunk_json: Value = serde_json::from_str(&data).unwrap_or_else(|e| { panic!( "Failed to parse chunk as JSON for provider {}: {e}. 
Data: {}", - provider.variant_name, chunk.data + provider.variant_name, data ) }); diff --git a/tensorzero-core/tests/e2e/providers/reasoning.rs b/tensorzero-core/tests/e2e/providers/reasoning.rs index 25983ba1d2..baa54888c4 100644 --- a/tensorzero-core/tests/e2e/providers/reasoning.rs +++ b/tensorzero-core/tests/e2e/providers/reasoning.rs @@ -5,8 +5,7 @@ use crate::providers::helpers::get_modal_extra_headers; use futures::StreamExt; use reqwest::Client; use reqwest::StatusCode; -use reqwest_eventsource::Event; -use reqwest_eventsource::RequestBuilderExt; +use reqwest_sse_stream::into_sse_stream; use serde_json::Value; use serde_json::json; use tensorzero::Role; @@ -279,7 +278,6 @@ pub async fn test_reasoning_inference_request_simple_nonstreaming_with_provider( pub async fn test_reasoning_inference_request_simple_streaming_with_provider( provider: E2ETestProvider, ) { - use reqwest_eventsource::{Event, RequestBuilderExt}; use serde_json::Value; use crate::common::get_gateway_endpoint; @@ -310,26 +308,24 @@ pub async fn test_reasoning_inference_request_simple_streaming_with_provider( "tags": {"key": tag_value}, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); @@ -846,26 +842,24 @@ pub async fn test_reasoning_inference_request_json_mode_streaming_with_provider( "extra_headers": extra_headers, }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; let mut found_done_chunk = false; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - found_done_chunk = true; - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + found_done_chunk = true; + break; } + chunks.push(data); } assert!(found_done_chunk); diff --git a/tensorzero-core/tests/e2e/raw_response/cache.rs b/tensorzero-core/tests/e2e/raw_response/cache.rs index 7d562cd1bb..27b972fb4e 100644 --- a/tensorzero-core/tests/e2e/raw_response/cache.rs +++ b/tensorzero-core/tests/e2e/raw_response/cache.rs @@ -7,7 +7,7 @@ use futures::StreamExt; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Map, Value, json}; use tensorzero::test_helpers::{ make_embedded_gateway_e2e_with_unique_db, start_http_gateway_with_unique_db, @@ -417,25 +417,21 @@ async fn make_openai_request_to_gateway( let url = format!("{base_url}/openai/v1/chat/completions"); if stream { - let 
mut chunks = Client::new() - .post(&url) - .json(&payload) - .eventsource() + let mut chunks = into_sse_stream(Client::new().post(&url).json(&payload)) + .await .unwrap(); // Collect raw_response entries from all chunks let mut raw_response_entries: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.unwrap(); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = serde_json::from_str(&chunk.data).unwrap(); + let chunk_json: Value = serde_json::from_str(&data).unwrap(); // Check for tensorzero_raw_response at chunk level if let Some(rr) = chunk_json.get("tensorzero_raw_response") && let Some(arr) = rr.as_array() diff --git a/tensorzero-core/tests/e2e/raw_response/mod.rs b/tensorzero-core/tests/e2e/raw_response/mod.rs index b3b8885d3c..974afd6beb 100644 --- a/tensorzero-core/tests/e2e/raw_response/mod.rs +++ b/tensorzero-core/tests/e2e/raw_response/mod.rs @@ -8,7 +8,7 @@ mod openai_compatible; use futures::StreamExt; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use uuid::Uuid; @@ -142,11 +142,13 @@ async fn e2e_test_raw_response_chat_completions_streaming() { "include_raw_response": true }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_chunk = false; let mut content_chunks_count: usize = 0; @@ -154,16 +156,13 @@ async fn e2e_test_raw_response_chat_completions_streaming() { let mut all_chunks: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); all_chunks.push(chunk_json.clone()); @@ -287,26 +286,25 @@ async fn e2e_test_raw_response_responses_api_streaming() { "include_raw_response": true }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_chunk = false; let mut all_chunks: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed 
to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); all_chunks.push(chunk_json.clone()); @@ -392,22 +390,22 @@ async fn e2e_test_raw_response_not_requested_streaming() { // include_raw_response is NOT set }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.unwrap(); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = serde_json::from_str(&chunk.data).unwrap(); + let chunk_json: Value = serde_json::from_str(&data).unwrap(); // raw_chunk should NOT be present at chunk level when not requested assert!( @@ -518,27 +516,26 @@ async fn e2e_test_raw_response_best_of_n_streaming() { "include_raw_response": true }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_response = false; let mut found_raw_chunk = false; let mut raw_response_count = 0; while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); // Check for raw_chunk (current streaming inference) if chunk_json.get("raw_chunk").is_some() { @@ -665,27 +662,26 @@ async fn e2e_test_raw_response_mixture_of_n_streaming() { "include_raw_response": true }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_response = false; let mut found_raw_chunk = false; let mut raw_response_count = 0; while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); // Check for raw_chunk (current streaming inference) if chunk_json.get("raw_chunk").is_some() { @@ -833,27 +829,26 @@ async fn e2e_test_raw_response_dicl_streaming() { 
"include_raw_response": true }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_response = false; let mut found_raw_chunk = false; let mut api_types: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); // Check for raw_chunk if chunk_json.get("raw_chunk").is_some() { @@ -970,25 +965,24 @@ async fn e2e_test_raw_response_json_function_streaming() { "include_raw_response": true }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_chunk = false; while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); // Check for raw_chunk if chunk_json.get("raw_chunk").is_some() { diff --git a/tensorzero-core/tests/e2e/raw_response/openai_compatible.rs b/tensorzero-core/tests/e2e/raw_response/openai_compatible.rs index 4e9335225c..9c4d0f567b 100644 --- a/tensorzero-core/tests/e2e/raw_response/openai_compatible.rs +++ b/tensorzero-core/tests/e2e/raw_response/openai_compatible.rs @@ -5,7 +5,7 @@ use futures::StreamExt; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use uuid::Uuid; @@ -177,11 +177,13 @@ async fn test_openai_compatible_raw_response_streaming() { "tensorzero::include_raw_response": true }); - let mut chunks = client - .post(get_gateway_endpoint("/openai/v1/chat/completions")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + client + .post(get_gateway_endpoint("/openai/v1/chat/completions")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_chunk = false; let mut content_chunks_count: usize = 0; @@ -189,16 +191,13 @@ async fn test_openai_compatible_raw_response_streaming() { let mut all_chunks: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = 
chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); all_chunks.push(chunk_json.clone()); @@ -260,23 +259,22 @@ async fn test_openai_compatible_raw_response_streaming_not_requested() { // Note: tensorzero::include_raw_response is NOT set (defaults to false) }); - let mut chunks = client - .post(get_gateway_endpoint("/openai/v1/chat/completions")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + client + .post(get_gateway_endpoint("/openai/v1/chat/completions")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); // tensorzero_raw_chunk should NOT be present assert!( diff --git a/tensorzero-core/tests/e2e/raw_usage/cache.rs b/tensorzero-core/tests/e2e/raw_usage/cache.rs index f10f538c8e..207edde442 100644 --- a/tensorzero-core/tests/e2e/raw_usage/cache.rs +++ b/tensorzero-core/tests/e2e/raw_usage/cache.rs @@ -7,7 +7,7 @@ use futures::StreamExt; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Map, Value, json}; use tensorzero::test_helpers::{ make_embedded_gateway_e2e_with_unique_db, start_http_gateway_with_unique_db, @@ -412,25 +412,21 @@ async fn make_openai_request_to_gateway( let url = format!("{base_url}/openai/v1/chat/completions"); if stream { - let mut chunks = Client::new() - .post(&url) - .json(&payload) - .eventsource() + let mut chunks = into_sse_stream(Client::new().post(&url).json(&payload)) + .await .unwrap(); // Collect raw_usage entries from all chunks, similar to native API let mut raw_usage_entries: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.unwrap(); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = serde_json::from_str(&chunk.data).unwrap(); + let chunk_json: Value = serde_json::from_str(&data).unwrap(); // Check for tensorzero_raw_usage at chunk level (sibling to usage) if let Some(ru) = chunk_json.get("tensorzero_raw_usage") && let Some(arr) = ru.as_array() diff --git a/tensorzero-core/tests/e2e/raw_usage/mod.rs b/tensorzero-core/tests/e2e/raw_usage/mod.rs index 3fd8b37d3c..84692874c5 100644 --- a/tensorzero-core/tests/e2e/raw_usage/mod.rs +++ b/tensorzero-core/tests/e2e/raw_usage/mod.rs @@ -7,7 +7,7 @@ mod openai_compatible; use futures::StreamExt; use 
reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use uuid::Uuid; @@ -226,27 +226,26 @@ async fn e2e_test_raw_usage_chat_completions_streaming() { "include_raw_usage": true }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_usage = false; let mut last_chunk_with_usage: Option = None; let mut all_chunks: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); all_chunks.push(chunk_json.clone()); @@ -373,27 +372,26 @@ async fn e2e_test_raw_usage_responses_api_streaming() { "include_raw_usage": true }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_usage = false; let mut last_chunk_with_usage: Option = None; let mut all_chunks: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); all_chunks.push(chunk_json.clone()); @@ -493,22 +491,22 @@ async fn e2e_test_raw_usage_not_requested_streaming() { // include_raw_usage is NOT set }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.unwrap(); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = serde_json::from_str(&chunk.data).unwrap(); + let chunk_json: Value = serde_json::from_str(&data).unwrap(); // raw_usage should NOT be present at chunk level when not requested assert!( @@ -615,26 +613,25 @@ async fn e2e_test_raw_usage_best_of_n_streaming() { "include_raw_usage": true }); - let mut chunks = Client::new() - 
.post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_usage = false; let mut raw_usage_count = 0; while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); // Check if this chunk has raw_usage (sibling to usage at chunk level) if let Some(raw_usage) = chunk_json.get("raw_usage") { @@ -768,27 +765,26 @@ async fn e2e_test_raw_usage_mixture_of_n_streaming() { "include_raw_usage": true }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_usage = false; let mut raw_usage_count = 0; let mut all_chunks: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); all_chunks.push(chunk_json.clone()); @@ -937,26 +933,25 @@ async fn e2e_test_raw_usage_dicl_streaming() { "include_raw_usage": true }); - let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_usage = false; let mut api_types: Vec = Vec::new(); while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); // Check if this chunk has raw_usage (sibling to usage at chunk level) if let Some(raw_usage) = chunk_json.get("raw_usage") { @@ -1119,26 +1114,25 @@ async fn e2e_test_raw_usage_json_function_streaming() { "include_raw_usage": true }); - 
let mut chunks = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .expect("Failed to create eventsource for streaming request"); + let mut chunks = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .expect("Failed to create eventsource for streaming request"); let mut found_raw_usage = false; let mut last_chunk_with_usage: Option = None; while let Some(chunk) = chunks.next().await { - let chunk = chunk.expect("Failed to receive chunk from stream"); - let Event::Message(chunk) = chunk else { - continue; - }; - if chunk.data == "[DONE]" { + let sse = chunk.expect("Failed to receive chunk from stream"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&chunk.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); // Check if this chunk has raw_usage (sibling to usage at chunk level) if chunk_json.get("raw_usage").is_some() { diff --git a/tensorzero-core/tests/e2e/raw_usage/openai_compatible.rs b/tensorzero-core/tests/e2e/raw_usage/openai_compatible.rs index 087e64bc29..c84542170a 100644 --- a/tensorzero-core/tests/e2e/raw_usage/openai_compatible.rs +++ b/tensorzero-core/tests/e2e/raw_usage/openai_compatible.rs @@ -5,7 +5,7 @@ use futures::StreamExt; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use uuid::Uuid; @@ -254,27 +254,26 @@ async fn test_openai_compatible_raw_usage_streaming() { "tensorzero::include_raw_usage": true }); - let mut response = client - .post(get_gateway_endpoint("/openai/v1/chat/completions")) - .json(&payload) - .eventsource() - .unwrap(); + let mut response = into_sse_stream( + client + .post(get_gateway_endpoint("/openai/v1/chat/completions")) + .json(&payload), + ) + .await + .unwrap(); let mut all_chunks: Vec = Vec::new(); // Track which chunk indices have raw_usage let mut chunks_with_raw_usage: Vec = Vec::new(); while let Some(event) = response.next().await { - let event = event.expect("Failed to receive event"); - let Event::Message(message) = event else { - continue; - }; - if message.data == "[DONE]" { + let sse = event.expect("Failed to receive event"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&message.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); let chunk_index = all_chunks.len(); all_chunks.push(chunk_json.clone()); @@ -364,27 +363,26 @@ async fn test_openai_compatible_streaming_usage_only_in_last_chunk() { "tensorzero::episode_id": episode_id.to_string() }); - let mut response = client - .post(get_gateway_endpoint("/openai/v1/chat/completions")) - .json(&payload) - .eventsource() - .unwrap(); + let mut response = into_sse_stream( + client + .post(get_gateway_endpoint("/openai/v1/chat/completions")) + .json(&payload), + ) + .await + .unwrap(); let mut all_chunks: Vec = Vec::new(); // Track which chunk indices have usage populated (not null) let mut chunks_with_usage: Vec = Vec::new(); while let Some(event) = response.next().await { - let event = event.expect("Failed to receive event"); - let Event::Message(message) = event else { - continue; - }; - if message.data == "[DONE]" { + let sse = 
event.expect("Failed to receive event"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&message.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); let chunk_index = all_chunks.len(); all_chunks.push(chunk_json.clone()); @@ -460,26 +458,25 @@ async fn test_openai_compatible_streaming_no_usage_when_disabled() { "tensorzero::episode_id": episode_id.to_string() }); - let mut response = client - .post(get_gateway_endpoint("/openai/v1/chat/completions")) - .json(&payload) - .eventsource() - .unwrap(); + let mut response = into_sse_stream( + client + .post(get_gateway_endpoint("/openai/v1/chat/completions")) + .json(&payload), + ) + .await + .unwrap(); let mut all_chunks: Vec = Vec::new(); let mut chunks_with_usage: Vec = Vec::new(); while let Some(event) = response.next().await { - let event = event.expect("Failed to receive event"); - let Event::Message(message) = event else { - continue; - }; - if message.data == "[DONE]" { + let sse = event.expect("Failed to receive event"); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { break; } - let chunk_json: Value = - serde_json::from_str(&message.data).expect("Failed to parse chunk as JSON"); + let chunk_json: Value = serde_json::from_str(&data).expect("Failed to parse chunk as JSON"); let chunk_index = all_chunks.len(); all_chunks.push(chunk_json.clone()); diff --git a/tensorzero-core/tests/e2e/retries.rs b/tensorzero-core/tests/e2e/retries.rs index c7fda17939..52723046d4 100644 --- a/tensorzero-core/tests/e2e/retries.rs +++ b/tensorzero-core/tests/e2e/retries.rs @@ -1,6 +1,6 @@ use futures::StreamExt; use reqwest::{Client, StatusCode}; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_sse_stream::into_sse_stream; use serde_json::{Value, json}; use tensorzero_core::{ inference::types::{Role, StoredContentBlock, StoredRequestMessage, Text}, @@ -189,23 +189,21 @@ async fn e2e_test_streaming_flaky() { }} }); - let mut event_source = Client::new() - .post(get_gateway_endpoint("/inference")) - .json(&payload) - .eventsource() - .unwrap(); + let mut event_source = into_sse_stream( + Client::new() + .post(get_gateway_endpoint("/inference")) + .json(&payload), + ) + .await + .unwrap(); let mut chunks = vec![]; while let Some(event) = event_source.next().await { - let event = event.unwrap(); - match event { - Event::Open => continue, - Event::Message(message) => { - if message.data == "[DONE]" { - break; - } - chunks.push(message.data); - } + let sse = event.unwrap(); + let Some(data) = sse.data else { continue }; + if data == "[DONE]" { + break; } + chunks.push(data); } let mut inference_id = None; for (i, chunk) in chunks.iter().enumerate() { diff --git a/tensorzero-core/tests/e2e/streaming_errors.rs b/tensorzero-core/tests/e2e/streaming_errors.rs index 081eca3774..0d21ce50f5 100644 --- a/tensorzero-core/tests/e2e/streaming_errors.rs +++ b/tensorzero-core/tests/e2e/streaming_errors.rs @@ -1,4 +1,5 @@ use futures::StreamExt; +use reqwest_sse_stream::into_sse_stream; use serde_json::json; use tensorzero::{ Client, ClientInferenceParams, InferenceOutput, InferenceResponseChunk, Input, InputMessage, @@ -7,7 +8,6 @@ use tensorzero::{ use tensorzero_core::inference::types::{Arguments, System, Text}; use crate::common::get_gateway_endpoint; -use reqwest_eventsource::{Event, RequestBuilderExt}; #[tokio::test] async fn 
test_client_stream_with_error_http_gateway() {
@@ -79,47 +79,37 @@ async fn test_stream_with_error() {
         "stream": true,
     });

-    let mut event_stream = reqwest::Client::new()
-        .post(get_gateway_endpoint("/inference"))
-        .json(&payload)
-        .eventsource()
-        .unwrap();
+    let mut event_stream = into_sse_stream(
+        reqwest::Client::new()
+            .post(get_gateway_endpoint("/inference"))
+            .json(&payload),
+    )
+    .await
+    .unwrap();

     let mut good_chunks = 0;
     // Check we receive all client chunks correctly
-    loop {
-        match event_stream.next().await {
-            Some(Ok(e)) => match e {
-                Event::Open => continue,
-                Event::Message(message) => {
-                    if message.data == "[DONE]" {
-                        break;
-                    }
-                    let obj: serde_json::Value = serde_json::from_str(&message.data).unwrap();
-                    if let Some(error) = obj.get("error") {
-                        let error_str: &str = error.as_str().unwrap();
-                        assert!(
-                            error_str.contains("Dummy error in stream"),
-                            "Unexpected error: {error_str}"
-                        );
-                        assert_eq!(good_chunks, 3);
-                    } else {
-                        let _chunk: InferenceResponseChunk =
-                            serde_json::from_str(&message.data).unwrap();
-                    }
-                    good_chunks += 1;
-                }
-            },
-            Some(Err(e)) => {
-                if matches!(e, reqwest_eventsource::Error::StreamEnded) {
-                    break;
-                }
-                panic!("Unexpected error: {e:?}");
-            }
-            None => {
-                panic!("Stream ended unexpectedly");
-            }
+    while let Some(event) = event_stream.next().await {
+        let sse = match event {
+            Ok(sse) => sse,
+            Err(_) => break,
+        };
+        let Some(data) = sse.data else { continue };
+        if data == "[DONE]" {
+            break;
+        }
+        let obj: serde_json::Value = serde_json::from_str(&data).unwrap();
+        if let Some(error) = obj.get("error") {
+            let error_str: &str = error.as_str().unwrap();
+            assert!(
+                error_str.contains("Dummy error in stream"),
+                "Unexpected error: {error_str}"
+            );
+            assert_eq!(good_chunks, 3);
+        } else {
+            let _chunk: InferenceResponseChunk = serde_json::from_str(&data).unwrap();
         }
+        good_chunks += 1;
     }
     assert_eq!(good_chunks, 17);
 }
diff --git a/tensorzero-core/tests/e2e/timeouts.rs b/tensorzero-core/tests/e2e/timeouts.rs
index 471cc16c0d..e74a241042 100644
--- a/tensorzero-core/tests/e2e/timeouts.rs
+++ b/tensorzero-core/tests/e2e/timeouts.rs
@@ -1,8 +1,9 @@
 use std::time::{Duration, Instant};

+use futures::StreamExt;
 use http::StatusCode;
 use reqwest::Client;
-use reqwest_eventsource::{Event, RequestBuilderExt};
+use reqwest_sse_stream::into_sse_stream;
 use serde_json::{Value, json};
 use tensorzero::{ClientInferenceParams, Input, InputMessage, InputMessageContent, Role};
 use tensorzero_core::{
@@ -12,7 +13,6 @@ use tensorzero_core::{
     },
     inference::types::Text,
 };
-use tokio_stream::StreamExt;
 use uuid::Uuid;

 use crate::common::get_gateway_endpoint;
@@ -167,24 +167,25 @@ async fn test_json_inference_ttft_ms() {
 }

 async fn test_inference_ttft_ms(payload: Value, json: bool) {
-    let mut response = Client::new()
-        .post(get_gateway_endpoint("/inference"))
-        .json(&payload)
-        .eventsource()
-        .unwrap();
+    let mut response = into_sse_stream(
+        Client::new()
+            .post(get_gateway_endpoint("/inference"))
+            .json(&payload),
+    )
+    .await
+    .unwrap();

     let mut inference_id = None;
     while let Some(event) = response.next().await {
-        let chunk = event.unwrap();
-        println!("chunk: {chunk:?}");
-        if let Event::Message(event) = chunk {
-            if event.data == "[DONE]" {
-                break;
-            }
-            let event = serde_json::from_str::<Value>(&event.data).unwrap();
-            inference_id = Some(event["inference_id"].as_str().unwrap().parse().unwrap());
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        println!("chunk: {data:?}");
+        if data == "[DONE]" {
+            break;
         }
+        let event = serde_json::from_str::<Value>(&data).unwrap();
+        inference_id = Some(event["inference_id"].as_str().unwrap().parse().unwrap());
     }

     // Sleep for 200ms to allow time for data to be inserted into ClickHouse (trailing writes from API)
@@ -515,24 +516,25 @@ async fn best_of_n_judge_timeout(payload: Value) {

 async fn slow_second_chunk_streaming(payload: Value) {
     let start = Instant::now();

-    let mut response = Client::new()
-        .post(get_gateway_endpoint("/inference"))
-        .json(&payload)
-        .eventsource()
-        .unwrap();
+    let mut response = into_sse_stream(
+        Client::new()
+            .post(get_gateway_endpoint("/inference"))
+            .json(&payload),
+    )
+    .await
+    .unwrap();

     let mut inference_id = None;
     while let Some(event) = response.next().await {
-        let chunk = event.unwrap();
-        println!("chunk: {chunk:?}");
-        if let Event::Message(event) = chunk {
-            if event.data == "[DONE]" {
-                break;
-            }
-            let event = serde_json::from_str::<Value>(&event.data).unwrap();
-            inference_id = Some(event["inference_id"].as_str().unwrap().parse().unwrap());
+        let sse = event.unwrap();
+        let Some(data) = sse.data else { continue };
+        println!("chunk: {data:?}");
+        if data == "[DONE]" {
+            break;
         }
+        let event = serde_json::from_str::<Value>(&data).unwrap();
+        inference_id = Some(event["inference_id"].as_str().unwrap().parse().unwrap());
     }

     // The overall stream duration should be at least 2 seconds, because we used the 'slow_second_chunk' model

From 42aa83136bb42164295d8a26a410df55d82fc45b Mon Sep 17 00:00:00 2001
From: Aaron Hill
Date: Thu, 29 Jan 2026 16:51:53 -0500
Subject: [PATCH 4/5] Add reqwest-sse-stream to gateway dev-deps and fix lints

---
 Cargo.lock                                     |  1 +
 gateway/Cargo.toml                             |  1 +
 internal/autopilot-client/src/client.rs        |  2 +-
 tensorzero-core/src/providers/anthropic.rs     |  3 +--
 .../src/providers/gcp_vertex_anthropic.rs      |  3 +--
 tensorzero-core/src/providers/mistral.rs       |  3 +--
 tensorzero-core/tests/e2e/openai_compatible.rs |  1 -
 .../tests/e2e/providers/anthropic.rs           | 16 ++++++++++++----
 .../providers/commonv2/cache_input_tokens.rs   | 18 +++++++++++-------
 9 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 2a1171d870..0bb141f5ba 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2444,6 +2444,7 @@ dependencies = [
  "metrics-exporter-prometheus",
  "mimalloc",
  "reqwest 0.12.28",
+ "reqwest-sse-stream",
  "secrecy",
  "serde",
  "serde_json",
diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml
index 223fee4f86..1f638af068 100644
--- a/gateway/Cargo.toml
+++ b/gateway/Cargo.toml
@@ -41,6 +41,7 @@ ts-bindings = ["dep:ts-rs", "tensorzero-core/ts-bindings", "tensorzero-optimizer

 [dev-dependencies]
 reqwest.workspace = true
+reqwest-sse-stream = { path = "../internal/reqwest-sse-stream" }
 tempfile = "3.21.0"
 tensorzero = { path = "../clients/rust", features = ["e2e_tests"] }
 serde_json = { workspace = true }
diff --git a/internal/autopilot-client/src/client.rs b/internal/autopilot-client/src/client.rs
index 9ef73a80b4..8c59b81668 100644
--- a/internal/autopilot-client/src/client.rs
+++ b/internal/autopilot-client/src/client.rs
@@ -676,7 +676,7 @@ impl AutopilotClient {
             self.sse_http_client.get(url).headers(self.auth_headers()),
         )
         .await
-        .map_err(|e| Self::convert_sse_error(e))?;
+        .map_err(Self::convert_sse_error)?;

         // Connection is good, return the stream
         let cache = self.tool_call_cache.clone();
diff --git a/tensorzero-core/src/providers/anthropic.rs b/tensorzero-core/src/providers/anthropic.rs
index 45075837cb..96e0f08cbd 100644
--- a/tensorzero-core/src/providers/anthropic.rs
+++ b/tensorzero-core/src/providers/anthropic.rs
@@ -513,8 +513,7 @@ fn stream_anthropic(
let data: Result = serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer { message: format!( - "Error parsing message: {}, Data: {}", - e, message_data + "Error parsing message: {e}, Data: {message_data}" ), provider_type: PROVIDER_TYPE.to_string(), raw_request: Some(raw_request.to_string()), diff --git a/tensorzero-core/src/providers/gcp_vertex_anthropic.rs b/tensorzero-core/src/providers/gcp_vertex_anthropic.rs index 08e499dd32..d5172a2a4d 100644 --- a/tensorzero-core/src/providers/gcp_vertex_anthropic.rs +++ b/tensorzero-core/src/providers/gcp_vertex_anthropic.rs @@ -418,8 +418,7 @@ fn stream_anthropic( let data: Result = serde_json::from_str(&message_data).map_err(|e| Error::new(ErrorDetails::InferenceServer { message: format!( - "Error parsing message: {}, Data: {}", - e, message_data + "Error parsing message: {e}, Data: {message_data}" ), provider_type: PROVIDER_TYPE.to_string(), raw_request: Some(raw_request.clone()), diff --git a/tensorzero-core/src/providers/mistral.rs b/tensorzero-core/src/providers/mistral.rs index 8fd0d11510..3fe793c302 100644 --- a/tensorzero-core/src/providers/mistral.rs +++ b/tensorzero-core/src/providers/mistral.rs @@ -381,8 +381,7 @@ pub fn stream_mistral( let data: Result = serde_json::from_str(&message_data).map_err(|e| ErrorDetails::InferenceServer { message: format!( - "Error parsing chunk. Error: {}, Data: {}", - e, message_data + "Error parsing chunk. Error: {e}, Data: {message_data}" ), provider_type: PROVIDER_TYPE.to_string(), raw_request: Some(raw_request.clone()), diff --git a/tensorzero-core/tests/e2e/openai_compatible.rs b/tensorzero-core/tests/e2e/openai_compatible.rs index a6fab60ba1..a9850edc2f 100644 --- a/tensorzero-core/tests/e2e/openai_compatible.rs +++ b/tensorzero-core/tests/e2e/openai_compatible.rs @@ -933,7 +933,6 @@ async fn test_openai_compatible_deny_unknown_fields() { #[tokio::test] async fn test_openai_compatible_streaming() { - let client = Client::new(); let episode_id = Uuid::now_v7(); let body = json!({ diff --git a/tensorzero-core/tests/e2e/providers/anthropic.rs b/tensorzero-core/tests/e2e/providers/anthropic.rs index e1b518fe65..094a1aa107 100644 --- a/tensorzero-core/tests/e2e/providers/anthropic.rs +++ b/tensorzero-core/tests/e2e/providers/anthropic.rs @@ -256,7 +256,9 @@ async fn test_empty_chunks_success() { let client = Client::new(); let mut event_source = into_sse_stream( - client.post(get_gateway_endpoint("/inference")).json(&payload), + client + .post(get_gateway_endpoint("/inference")) + .json(&payload), ) .await .unwrap(); @@ -721,7 +723,9 @@ async fn test_beta_structured_outputs_json_helper(stream: bool) { let inference_id = if stream { let mut event_source = into_sse_stream( - client.post(get_gateway_endpoint("/inference")).json(&payload), + client + .post(get_gateway_endpoint("/inference")) + .json(&payload), ) .await .unwrap(); @@ -818,7 +822,9 @@ async fn test_beta_structured_outputs_strict_tool_helper(stream: bool) { let inference_id = if stream { let mut event_source = into_sse_stream( - client.post(get_gateway_endpoint("/inference")).json(&payload), + client + .post(get_gateway_endpoint("/inference")) + .json(&payload), ) .await .unwrap(); @@ -940,7 +946,9 @@ pub async fn test_streaming_thinking_helper(model_name: &str, provider_type: &st let client = Client::new(); let mut event_source = into_sse_stream( - client.post(get_gateway_endpoint("/inference")).json(&payload), + client + .post(get_gateway_endpoint("/inference")) + .json(&payload), ) .await .unwrap(); diff 
--git a/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs b/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs
index b50aad2a12..ec45b7a6db 100644
--- a/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs
+++ b/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs
@@ -261,13 +261,17 @@ async fn get_streaming_input_tokens(
     payload: &Value,
     variant_name: &str,
 ) -> Option<u64> {
-    let mut chunks = into_sse_stream(client.post(get_gateway_endpoint("/inference")).json(payload))
-        .await
-        .unwrap_or_else(|e| {
-            panic!(
-                "Failed to create eventsource for streaming request for provider {variant_name}: {e}",
-            )
-        });
+    let mut chunks = into_sse_stream(
+        client
+            .post(get_gateway_endpoint("/inference"))
+            .json(payload),
+    )
+    .await
+    .unwrap_or_else(|e| {
+        panic!(
+            "Failed to create eventsource for streaming request for provider {variant_name}: {e}",
+        )
+    });

     let mut input_tokens: Option<u64> = None;
     let mut all_chunks: Vec<Value> = Vec::new();

From 01668dd27aa3882b1db1efcde981934fd1f5a50d Mon Sep 17 00:00:00 2001
From: Aaron Hill
Date: Thu, 29 Jan 2026 16:56:39 -0500
Subject: [PATCH 5/5] Inline panic! format args in cache_input_tokens test

---
 .../tests/e2e/providers/commonv2/cache_input_tokens.rs | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs b/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs
index ec45b7a6db..f291534c31 100644
--- a/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs
+++ b/tensorzero-core/tests/e2e/providers/commonv2/cache_input_tokens.rs
@@ -286,10 +286,7 @@ async fn get_streaming_input_tokens(
         }

         let chunk_json: Value = serde_json::from_str(&data).unwrap_or_else(|e| {
-            panic!(
-                "Failed to parse chunk as JSON for provider {variant_name}: {e}. Data: {}",
-                data
-            )
+            panic!("Failed to parse chunk as JSON for provider {variant_name}: {e}. Data: {data}")
         });

         all_chunks.push(chunk_json.clone());
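
For reference, every migrated call site in this series converges on the same
consumption pattern. The sketch below is illustrative only and not part of any
patch above: it assumes the `into_sse_stream(RequestBuilder)` helper introduced
in `internal/reqwest-sse-stream` (PATCH 1/5) and the `sse-stream` crate's frame
type with a `data: Option<String>` field; the function name
`collect_stream_chunks` and its `url`/`payload` parameters are invented for the
example.

use futures::StreamExt;
use reqwest::Client;
use reqwest_sse_stream::into_sse_stream;
use serde_json::Value;

// Collect the JSON chunks of one streaming inference, stopping at the
// OpenAI-style "[DONE]" sentinel that the gateway sends as its final frame.
async fn collect_stream_chunks(url: &str, payload: &Value) -> Vec<Value> {
    // `into_sse_stream` sends the request and wraps the response body in an
    // SSE parser; it replaces the old `.eventsource()` extension method.
    let mut events = into_sse_stream(Client::new().post(url).json(payload))
        .await
        .expect("failed to open SSE stream");

    let mut chunks: Vec<Value> = Vec::new();
    while let Some(event) = events.next().await {
        let sse = event.expect("failed to receive SSE frame");
        // Frames with no `data` payload (comments, keep-alives) are skipped;
        // this guard replaces the old `Event::Open`/`Event::Message` match.
        let Some(data) = sse.data else { continue };
        if data == "[DONE]" {
            break;
        }
        chunks.push(serde_json::from_str(&data).expect("chunk should be valid JSON"));
    }
    chunks
}

Unlike reqwest-eventsource, the new stream has no separate `Open` event, and
end-of-stream surfaces as the stream returning `None` rather than as a
`StreamEnded` error, which is why the migrated loops need only the `data`
guard and the `[DONE]` check.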