diff --git a/gateway/src/routes/external.rs b/gateway/src/routes/external.rs
index 2524fb95ff..1f0cff0355 100644
--- a/gateway/src/routes/external.rs
+++ b/gateway/src/routes/external.rs
@@ -9,6 +9,7 @@ use axum::{
     routing::{delete, get, patch, post},
 };
 use metrics_exporter_prometheus::PrometheusHandle;
+use tensorzero_core::endpoints::anthropic_compatible::build_anthropic_compatible_routes;
 use tensorzero_core::endpoints::openai_compatible::build_openai_compatible_routes;
 use tensorzero_core::observability::OtelEnabledRoutes;
 use tensorzero_core::{endpoints, utils::gateway::AppStateData};
@@ -34,6 +35,7 @@ pub fn build_otel_enabled_routes() -> (OtelEnabledRoutes, Router)
         ("/feedback", post(endpoints::feedback::feedback_handler)),
     ];
     routes.extend(build_openai_compatible_routes().routes);
+    routes.extend(build_anthropic_compatible_routes().routes);
     let mut router = Router::new();
     let mut route_names = Vec::with_capacity(routes.len());
     for (path, handler) in routes {
diff --git a/tensorzero-core/src/endpoints/anthropic_compatible/messages.rs b/tensorzero-core/src/endpoints/anthropic_compatible/messages.rs
new file mode 100644
index 0000000000..5a3651b091
--- /dev/null
+++ b/tensorzero-core/src/endpoints/anthropic_compatible/messages.rs
@@ -0,0 +1,100 @@
+//! Messages endpoint handler for Anthropic-compatible API.
+//!
+//! This module implements the HTTP handler for the `/anthropic/v1/messages` endpoint,
+//! which provides Anthropic Messages API compatibility. It handles request validation,
+//! parameter parsing, inference execution, and response formatting for both streaming
+//! and non-streaming requests.
+
+use axum::Json;
+use axum::body::Body;
+use axum::extract::State;
+use axum::response::sse::Sse;
+use axum::response::{IntoResponse, Response};
+use axum::{Extension, debug_handler};
+
+use crate::endpoints::anthropic_compatible::types::messages::AnthropicMessageResponse;
+use crate::endpoints::anthropic_compatible::types::messages::AnthropicMessagesParams;
+use crate::endpoints::anthropic_compatible::types::streaming::prepare_serialized_anthropic_events;
+use crate::endpoints::inference::{InferenceOutput, Params, inference};
+use crate::error::{Error, ErrorDetails};
+use crate::utils::gateway::{AppState, AppStateData, StructuredJson};
+use tensorzero_auth::middleware::RequestApiKeyExtension;
+
+/// A handler for the Anthropic-compatible messages endpoint
+#[debug_handler(state = AppStateData)]
+pub async fn messages_handler(
+    State(AppStateData {
+        config,
+        http_client,
+        clickhouse_connection_info,
+        postgres_connection_info,
+        deferred_tasks,
+        rate_limiting_manager,
+        ..
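+        // `..` ignores the remaining AppStateData fields, which this handler does not use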
+    }): AppState,
+    api_key_ext: Option<Extension<RequestApiKeyExtension>>,
+    StructuredJson(anthropic_params): StructuredJson<AnthropicMessagesParams>,
+) -> Result<Response<Body>, Error> {
+    // Validate that max_tokens is set (it's required in Anthropic's API)
+    if anthropic_params.max_tokens == 0 {
+        return Err(Error::new(
+            ErrorDetails::InvalidAnthropicCompatibleRequest {
+                message: "`max_tokens` is required and must be greater than 0".to_string(),
+            },
+        ));
+    }
+
+    let include_raw_usage = anthropic_params.tensorzero_include_raw_usage;
+    let include_raw_response = anthropic_params.tensorzero_include_raw_response;
+
+    let params = Params::try_from_anthropic(anthropic_params)?;
+
+    // The prefix for the response's `model` field depends on the inference target
+    let response_model_prefix = match (&params.function_name, &params.model_name) {
+        (Some(function_name), None) => Ok::<String, Error>(format!(
+            "tensorzero::function_name::{function_name}::variant_name::",
+        )),
+        (None, Some(_model_name)) => Ok("tensorzero::model_name::".to_string()),
+        (Some(_), Some(_)) => Err(ErrorDetails::InvalidInferenceTarget {
+            message: "Only one of `function_name` or `model_name` can be provided".to_string(),
+        }
+        .into()),
+        (None, None) => Err(ErrorDetails::InvalidInferenceTarget {
+            message: "Either `function_name` or `model_name` must be provided".to_string(),
+        }
+        .into()),
+    }?;
+
+    let response = Box::pin(inference(
+        config,
+        &http_client,
+        clickhouse_connection_info,
+        postgres_connection_info,
+        deferred_tasks,
+        rate_limiting_manager,
+        params,
+        api_key_ext,
+    ))
+    .await?
+    .output;
+
+    match response {
+        InferenceOutput::NonStreaming(response) => {
+            let anthropic_response =
+                AnthropicMessageResponse::from((response, response_model_prefix));
+            Ok(Json(anthropic_response).into_response())
+        }
+        InferenceOutput::Streaming(stream) => {
+            let anthropic_stream = prepare_serialized_anthropic_events(
+                stream,
+                response_model_prefix,
+                true, // include_usage
+                include_raw_usage,
+                include_raw_response,
+            );
+            Ok(Sse::new(anthropic_stream)
+                .keep_alive(axum::response::sse::KeepAlive::new())
+                .into_response())
+        }
+    }
+}
diff --git a/tensorzero-core/src/endpoints/anthropic_compatible/mod.rs b/tensorzero-core/src/endpoints/anthropic_compatible/mod.rs
new file mode 100644
index 0000000000..f49f27f11b
--- /dev/null
+++ b/tensorzero-core/src/endpoints/anthropic_compatible/mod.rs
@@ -0,0 +1,47 @@
+//! Anthropic-compatible API endpoints.
+//!
+//! This module provides compatibility with Anthropic's Messages API format.
+//! It handles routing, request/response conversion, and provides the main entry
+//! points for Anthropic-compatible requests.
+
+pub mod messages;
+pub mod models;
+pub mod types;
+
+use messages::messages_handler;
+use models::models_handler;
+
+use axum::Router;
+use axum::routing::{get, post};
+
+use crate::endpoints::RouteHandlers;
+use crate::utils::gateway::AppStateData;
+
+/// Constructs (but does not register) all of our Anthropic-compatible endpoints.
+/// `RouterExt::register_anthropic_compatible_routes` is a convenience method
+/// for registering all of the routes on a router.
+///
+/// Alternatively, the returned `RouteHandlers` can be inspected (e.g. to allow middleware to see the route paths)
+/// and then manually registered on a router.
+pub fn build_anthropic_compatible_routes() -> RouteHandlers {
+    RouteHandlers {
+        routes: vec![
+            ("/anthropic/v1/messages", post(messages_handler)),
+            ("/anthropic/v1/models", get(models_handler)),
+        ],
+    }
+}
+
+pub trait RouterExt {
+    /// Applies our Anthropic-compatible endpoints to the router.
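+    ///
+    /// A minimal usage sketch (assuming a gateway router with `AppStateData` state):
+    ///
+    /// ```rust,ignore
+    /// let router = axum::Router::new().register_anthropic_compatible_routes();
+    /// ```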
+    fn register_anthropic_compatible_routes(self) -> Self;
+}
+
+impl RouterExt for Router<AppStateData> {
+    fn register_anthropic_compatible_routes(mut self) -> Self {
+        for (path, handler) in build_anthropic_compatible_routes().routes {
+            self = self.route(path, handler);
+        }
+        self
+    }
+}
diff --git a/tensorzero-core/src/endpoints/anthropic_compatible/models.rs b/tensorzero-core/src/endpoints/anthropic_compatible/models.rs
new file mode 100644
index 0000000000..ea64c934a7
--- /dev/null
+++ b/tensorzero-core/src/endpoints/anthropic_compatible/models.rs
@@ -0,0 +1,268 @@
+//! Models API endpoint for Anthropic-compatible API.
+//!
+//! This module provides the `/anthropic/v1/models` endpoint that returns a list of
+//! configured models available in the TensorZero gateway. It follows the Anthropic
+//! API specification for model listing.
+//!
+//! # Example
+//!
+//! ```rust,no_run
+//! use reqwest::Client;
+//!
+//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
+//! let response = Client::new()
+//!     .get("http://localhost:3000/anthropic/v1/models")
+//!     .send()
+//!     .await?;
+//!
+//! let models = response.json::<serde_json::Value>().await?;
+//! println!("Available models: {}", models);
+//! # Ok(())
+//! # }
+//! ```
+//!
+//! # Model IDs
+//!
+//! Models are returned with prefixes indicating their type:
+//! - `tensorzero::function_name::{name}`: TensorZero functions
+//! - `tensorzero::model_name::{name}`: Direct provider models
+//!
+//! # Response Format
+//!
+//! ```json
+//! {
+//!   "data": [
+//!     {
+//!       "id": "tensorzero::function_name::my_function",
+//!       "name": "my_function",
+//!       "type": "model"
+//!     }
+//!   ],
+//!   "object": "list"
+//! }
+//! ```
+
+use axum::extract::State;
+use serde::{Deserialize, Serialize};
+use tracing::instrument;
+
+use crate::error::Error;
+use crate::utils::gateway::{AppState, AppStateData};
+
+/// Prefix for function_name models
+const FUNCTION_NAME_PREFIX: &str = "tensorzero::function_name::";
+
+/// Prefix for model_name models
+const MODEL_NAME_PREFIX: &str = "tensorzero::model_name::";
+
+/// Individual model information for the Models API response
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub struct AnthropicModel {
+    /// Unique identifier for the model (function name or model name)
+    pub id: String,
+    /// Display name for the model
+    pub name: String,
+    /// Type of resource (always "model")
+    #[serde(rename = "type")]
+    pub resource_type: String,
+}
+
+/// Response for GET /anthropic/v1/models
+///
+/// Returns a list of available models that can be used with the Messages API.
+/// Models are returned in the format expected by the Anthropic API.
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub struct AnthropicModelsResponse {
+    /// List of available models (both functions and providers)
+    pub data: Vec<AnthropicModel>,
+    /// Object type (always "list" for consistency with Anthropic API)
+    pub object: String,
+}
+
+/// Handler for `GET /anthropic/v1/models`
+///
+/// Returns all configured models (functions and providers) that can be used
+/// with the Anthropic-compatible Messages API.
+///
+/// # Behavior
+///
+/// This endpoint:
+/// 1. Iterates through all configured functions in the gateway
+/// 2. Includes only functions that have at least one variant configured
+/// 3. Iterates through all configured provider models
+/// 4. Sorts all models alphabetically by ID for consistent ordering
+/// 5. Returns models in Anthropic's format with appropriate prefixes
+///
+/// # Model ID Format
+///
+/// - Functions: `tensorzero::function_name::{function_name}`
+/// - Providers: `tensorzero::model_name::{model_name}`
+///
+/// # Response Format
+///
+/// Returns models in Anthropic's format:
+/// ```json
+/// {
+///   "data": [
+///     {
+///       "id": "tensorzero::function_name::my_function",
+///       "name": "my_function",
+///       "type": "model"
+///     },
+///     {
+///       "id": "tensorzero::model_name::openai::gpt-4o-mini",
+///       "name": "openai::gpt-4o-mini",
+///       "type": "model"
+///     }
+///   ],
+///   "object": "list"
+/// }
+/// ```
+///
+/// # Errors
+///
+/// This function will return an error if:
+/// - The gateway state is unavailable (internal server error)
+/// - Configuration cannot be accessed (internal server error)
+///
+/// # Example Usage
+///
+/// ```bash
+/// curl http://localhost:3000/anthropic/v1/models
+/// ```
+#[axum::debug_handler(state = AppStateData)]
+#[instrument(name = "anthropic_compatible.models", skip_all)]
+#[allow(clippy::unused_async)]
+pub async fn models_handler(
+    State(app_state): AppState,
+) -> Result<axum::Json<AnthropicModelsResponse>, Error> {
+    let config = &app_state.config;
+
+    let mut models = Vec::new();
+
+    // Add function_name models (TensorZero functions)
+    for (function_name, function_config) in &config.functions {
+        // Only include functions that have at least one variant configured
+        if !function_config.variants().is_empty() {
+            models.push(AnthropicModel {
+                id: format!("{FUNCTION_NAME_PREFIX}{function_name}"),
+                name: function_name.clone(),
+                resource_type: "model".to_string(),
+            });
+        }
+    }
+
+    // Add model_name models (direct provider models)
+    for model_name in config.models.table.keys() {
+        models.push(AnthropicModel {
+            id: format!("{MODEL_NAME_PREFIX}{model_name}"),
+            name: model_name.to_string(),
+            resource_type: "model".to_string(),
+        });
+    }
+
+    // Sort models by ID for consistent, predictable ordering
+    models.sort_by(|a, b| a.id.cmp(&b.id));
+
+    Ok(axum::Json(AnthropicModelsResponse {
+        data: models,
+        object: "list".to_string(),
+    }))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::Value;
+
+    #[test]
+    fn test_anthropic_model_serialization() {
+        let model = AnthropicModel {
+            id: "claude-3-5-sonnet-20241022".to_string(),
+            name: "Claude 3.5 Sonnet".to_string(),
+            resource_type: "model".to_string(),
+        };
+
+        let json = serde_json::to_string(&model).unwrap();
+
+        // Verify the JSON structure
+        assert!(json.contains("\"id\""));
+        assert!(json.contains("\"name\""));
+        assert!(json.contains("\"type\""));
+        assert!(json.contains("\"claude-3-5-sonnet-20241022\""));
+
+        // Verify we can deserialize it back
+        let parsed: AnthropicModel = serde_json::from_str(&json).unwrap();
+        assert_eq!(parsed.id, "claude-3-5-sonnet-20241022");
+        assert_eq!(parsed.name, "Claude 3.5 Sonnet");
+        assert_eq!(parsed.resource_type, "model");
+    }
+
+    #[test]
+    fn test_anthropic_models_response_serialization() {
+        let response = AnthropicModelsResponse {
+            data: vec![
+                AnthropicModel {
+                    id: "tensorzero::function_name::test".to_string(),
+                    name: "test".to_string(),
+                    resource_type: "model".to_string(),
+                },
+                AnthropicModel {
+                    id: "tensorzero::model_name::openai::gpt-4o-mini".to_string(),
+                    name: "openai::gpt-4o-mini".to_string(),
+                    resource_type: "model".to_string(),
+                },
+            ],
+            object: "list".to_string(),
+        };
+
+        let json = serde_json::to_string(&response).unwrap();
+        let parsed: Value = serde_json::from_str(&json).unwrap();
+
+        // Verify top-level structure
+        assert_eq!(parsed["object"], "list");
+        assert!(parsed["data"].is_array());
+        assert_eq!(parsed["data"].as_array().unwrap().len(), 2);
+
+        // Verify first model
+        assert_eq!(parsed["data"][0]["id"], "tensorzero::function_name::test");
+        assert_eq!(parsed["data"][0]["name"], "test");
+        assert_eq!(parsed["data"][0]["type"], "model");
+
+        // Verify second model
+        assert_eq!(
+            parsed["data"][1]["id"],
+            "tensorzero::model_name::openai::gpt-4o-mini"
+        );
+        assert_eq!(parsed["data"][1]["name"], "openai::gpt-4o-mini");
+        assert_eq!(parsed["data"][1]["type"], "model");
+    }
+
+    #[test]
+    fn test_anthropic_model_required_fields() {
+        // Test that required fields are present and correctly typed
+        let model = AnthropicModel {
+            id: "test-model".to_string(),
+            name: "Test Model".to_string(),
+            resource_type: "model".to_string(),
+        };
+
+        // Verify all fields are non-empty
+        assert!(!model.id.is_empty());
+        assert!(!model.name.is_empty());
+        assert_eq!(model.resource_type, "model");
+    }
+
+    #[test]
+    fn test_anthropic_models_response_object_type() {
+        let response = AnthropicModelsResponse {
+            data: vec![],
+            object: "list".to_string(),
+        };
+
+        assert_eq!(response.object, "list");
+        assert!(response.data.is_empty());
+    }
+}
diff --git a/tensorzero-core/src/endpoints/anthropic_compatible/types/messages.rs b/tensorzero-core/src/endpoints/anthropic_compatible/types/messages.rs
new file mode 100644
index 0000000000..a269419dbc
--- /dev/null
+++ b/tensorzero-core/src/endpoints/anthropic_compatible/types/messages.rs
@@ -0,0 +1,1179 @@
+//! Message types and conversion logic for Anthropic-compatible API.
+//!
+//! This module provides type definitions and conversion functions for translating
+//! between Anthropic's Messages API format and TensorZero's internal representations.
+//!
+//! # Key Types
+//!
+//! - [`AnthropicMessagesParams`]: Request parameters matching Anthropic's API
+//! - [`AnthropicMessageResponse`]: Response format matching Anthropic's API
+//! - [`AnthropicContentBlock`]: Content blocks for requests (Text, ToolUse, ToolResult)
+//!
+//! # Conversions
+//!
+//! - [`Params::try_from_anthropic()`]: Converts Anthropic params to TensorZero Params
+//! - [`AnthropicMessageResponse::from()`]: Converts TensorZero response to Anthropic format
+//! - [`finish_reason_to_anthropic()`]: Maps TensorZero finish reasons to Anthropic's format
+//!
+//! # Example
+//!
+//! ```rust,ignore
+//! use tensorzero_core::endpoints::anthropic_compatible::types::messages::AnthropicMessagesParams;
+//! use serde_json::json;
+//!
+//! // Basic request
+//! let params = AnthropicMessagesParams {
+//!     model: "tensorzero::function_name::my_function".to_string(),
+//!     max_tokens: 100,
+//!     messages: vec![
+//!         AnthropicMessage::User(AnthropicUserMessage {
+//!             content: serde_json::Value::String("Hello, world!".to_string()),
+//!         }),
+//!     ],
+//!     ..Default::default()
+//! };
+//!
+//! // With system prompt
+//! let params_with_system = AnthropicMessagesParams {
+//!     model: "tensorzero::function_name::my_function".to_string(),
+//!     max_tokens: 100,
+//!     system: Some(json!("You are a helpful assistant")),
+//!     messages: vec![...],
+//!     ..Default::default()
+//! };
+//!
+//! // With tools
+//! let params_with_tools = AnthropicMessagesParams {
+//!     model: "tensorzero::function_name::my_function".to_string(),
+//!     max_tokens: 1000,
+//!     messages: vec![...],
+//!     tools: Some(vec![AnthropicTool {
+//!         name: "get_weather".to_string(),
+//!         description: "Get current weather".to_string(),
+//!         input_schema: AnthropicInputSchema {
+//!             schema_type: "object".to_string(),
+//!             properties: Some(HashMap::from_iter([
+//!                 ("location".to_string(), json!({
+//!                     "type": "string",
+//!                     "description": "City name"
+//!                 })),
+//!             ])),
+//!             required: Some(vec!["location".to_string()]),
+//!             additional_properties: Some(false),
+//!         },
+//!     }]),
+//!     tool_choice: Some(AnthropicToolChoice::Auto),
+//!     ..Default::default()
+//! };
+//! ```

+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::collections::HashMap;
+use tensorzero_derive::TensorZeroDeserialize;
+use uuid::Uuid;
+
+use crate::cache::CacheParamsOptions;
+use crate::config::UninitializedVariantInfo;
+use crate::endpoints::anthropic_compatible::types::tool::{
+    AnthropicTool, AnthropicToolChoice, AnthropicToolChoiceParams,
+};
+use crate::endpoints::anthropic_compatible::types::usage::AnthropicUsage;
+use crate::endpoints::inference::{
+    ChatCompletionInferenceParams, InferenceCredentials, InferenceParams, InferenceResponse, Params,
+};
+use crate::error::{Error, ErrorDetails};
+use crate::inference::types::extra_body::UnfilteredInferenceExtraBody;
+use crate::inference::types::extra_headers::UnfilteredInferenceExtraHeaders;
+use crate::inference::types::usage::{RawResponseEntry, RawUsageEntry};
+use crate::inference::types::{
+    ContentBlockChatOutput, FinishReason, Input, InputMessage, InputMessageContent, RawText, Role,
+    System, Template, Text,
+};
+use crate::tool::{DynamicToolParams, ProviderTool, ToolResult};
+
+// ============================================================================
+// Message Types
+// ============================================================================
+
+/// Anthropic message (user or assistant - system is a separate field)
+#[derive(Clone, Debug, Deserialize, PartialEq)]
+#[serde(tag = "role")]
+#[serde(rename_all = "lowercase")]
+pub enum AnthropicMessage {
+    User(AnthropicUserMessage),
+    Assistant(AnthropicAssistantMessage),
+}
+
+#[derive(Clone, Debug, Deserialize, PartialEq)]
+pub struct AnthropicUserMessage {
+    pub content: AnthropicMessageContent,
+}
+
+#[derive(Clone, Debug, Deserialize, PartialEq)]
+pub struct AnthropicAssistantMessage {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<AnthropicMessageContent>,
+}
+
+/// Content can be a string or an array of content blocks
+pub type AnthropicMessageContent = Value;
+
+/// Content block types for requests
+#[derive(Clone, Debug, PartialEq, TensorZeroDeserialize)]
+#[serde(tag = "type")]
+#[serde(deny_unknown_fields, rename_all = "snake_case")]
+pub enum AnthropicContentBlock {
+    Text {
+        text: String,
+    },
+    ToolUse {
+        id: String,
+        name: String,
+        input: serde_json::Value,
+    },
+    ToolResult {
+        tool_use_id: String,
+        content: String,
+        #[serde(default)]
+        is_error: bool,
+    },
+    #[serde(rename = "tensorzero::raw_text")]
+    RawText(RawText),
+    #[serde(rename = "tensorzero::template")]
+    Template(Template),
+}
+
+// ============================================================================
+// Request Parameter Types
+// ============================================================================
+
+#[derive(Clone, Debug, Default, Deserialize)]
+pub struct AnthropicMessagesParams {
+    pub model: String,
+    pub messages: Vec<AnthropicMessage>,
+    #[serde(default)]
+    pub system: Option<Value>,
+    pub max_tokens: u32,
+    #[serde(default)]
+    pub stop_sequences: Option<Vec<String>>,
+    #[serde(default)]
+    pub temperature: Option<f32>,
+    #[serde(default)]
+    pub top_p: Option<f32>,
+    #[serde(default)]
+    pub top_k: Option<u32>,
+    #[serde(default)]
+    pub stream: Option<bool>,
+    #[serde(default)]
+    pub tools: Option<Vec<AnthropicTool>>,
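+    /// Tool-choice strategy; mapped onto TensorZero's `DynamicToolParams` during conversion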
+    #[serde(default)]
+    pub tool_choice: Option<AnthropicToolChoice>,
+    // TensorZero-specific parameters
+    #[serde(rename = "tensorzero::variant_name")]
+    pub tensorzero_variant_name: Option<String>,
+    #[serde(rename = "tensorzero::dryrun")]
+    pub tensorzero_dryrun: Option<bool>,
+    #[serde(rename = "tensorzero::episode_id")]
+    pub tensorzero_episode_id: Option<Uuid>,
+    #[serde(rename = "tensorzero::cache_options")]
+    pub tensorzero_cache_options: Option<CacheParamsOptions>,
+    #[serde(default, rename = "tensorzero::extra_body")]
+    pub tensorzero_extra_body: UnfilteredInferenceExtraBody,
+    #[serde(default, rename = "tensorzero::extra_headers")]
+    pub tensorzero_extra_headers: UnfilteredInferenceExtraHeaders,
+    #[serde(default, rename = "tensorzero::tags")]
+    pub tensorzero_tags: HashMap<String, String>,
+    #[serde(default, rename = "tensorzero::deny_unknown_fields")]
+    pub tensorzero_deny_unknown_fields: bool,
+    #[serde(default, rename = "tensorzero::credentials")]
+    pub tensorzero_credentials: InferenceCredentials,
+    #[serde(rename = "tensorzero::internal_dynamic_variant_config")]
+    pub tensorzero_internal_dynamic_variant_config: Option<UninitializedVariantInfo>,
+    #[serde(default, rename = "tensorzero::provider_tools")]
+    pub tensorzero_provider_tools: Vec<ProviderTool>,
+    #[serde(default, rename = "tensorzero::params")]
+    pub tensorzero_params: Option<InferenceParams>,
+    #[serde(default, rename = "tensorzero::include_raw_usage")]
+    pub tensorzero_include_raw_usage: bool,
+    #[serde(default, rename = "tensorzero::include_raw_response")]
+    pub tensorzero_include_raw_response: bool,
+}
+
+// ============================================================================
+// Response Types
+// ============================================================================
+
+#[derive(Clone, Debug, PartialEq, Serialize)]
+pub struct AnthropicMessageResponse {
+    pub id: String,
+    #[serde(rename = "type")]
+    pub message_type: String,
+    pub role: String,
+    pub content: Vec<AnthropicOutputContentBlock>,
+    pub model: String,
+    pub stop_reason: AnthropicStopReason,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop_sequence: Option<String>,
+    pub usage: AnthropicUsage,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tensorzero_raw_usage: Option<Vec<RawUsageEntry>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tensorzero_raw_response: Option<Vec<RawResponseEntry>>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize)]
+#[serde(tag = "type")]
+#[serde(rename_all = "snake_case")]
+pub enum AnthropicOutputContentBlock {
+    Text {
+        text: String,
+    },
+    ToolUse {
+        id: String,
+        name: String,
+        input: Value,
+    },
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize)]
+#[serde(rename_all = "snake_case")]
+pub enum AnthropicStopReason {
+    EndTurn,
+    MaxTokens,
+    StopSequence,
+    ToolUse,
+}
+
+impl From<FinishReason> for AnthropicStopReason {
+    fn from(finish_reason: FinishReason) -> Self {
+        match finish_reason {
+            FinishReason::Stop => AnthropicStopReason::EndTurn,
+            FinishReason::StopSequence => AnthropicStopReason::StopSequence,
+            FinishReason::Length => AnthropicStopReason::MaxTokens,
+            FinishReason::ToolCall => AnthropicStopReason::ToolUse,
+            FinishReason::ContentFilter => AnthropicStopReason::EndTurn, // Coerce to end_turn
+            FinishReason::Unknown => AnthropicStopReason::EndTurn,
+        }
+    }
+}
+
+// ============================================================================
+// Conversion Implementations
+// ============================================================================
+
+const TENSORZERO_FUNCTION_NAME_PREFIX: &str = "tensorzero::function_name::";
+const TENSORZERO_MODEL_NAME_PREFIX: &str = "tensorzero::model_name::";
+const ANTHROPIC_MESSAGE_TYPE: &str = "message";
+const ANTHROPIC_ROLE_ASSISTANT: &str = "assistant";
+
+impl Params {
+    pub fn try_from_anthropic(anthropic_params: AnthropicMessagesParams) -> Result<Self, Error> {
+        let (function_name, model_name) = if let Some(function_name) = anthropic_params
+            .model
+            .strip_prefix(TENSORZERO_FUNCTION_NAME_PREFIX)
+        {
+            (Some(function_name.to_string()), None)
+        } else if let Some(model_name) = anthropic_params
+            .model
+            .strip_prefix(TENSORZERO_MODEL_NAME_PREFIX)
+        {
+            (None, Some(model_name.to_string()))
+        } else {
+            return Err(Error::new(ErrorDetails::InvalidAnthropicCompatibleRequest {
+                message: "`model` field must start with `tensorzero::function_name::` or `tensorzero::model_name::`. For example, `tensorzero::function_name::my_function` for a function `my_function` defined in your config, or `tensorzero::model_name::my_model` for a model `my_model` defined in your config.".to_string(),
+            }));
+        };
+
+        if let Some(function_name) = &function_name
+            && function_name.is_empty()
+        {
+            return Err(ErrorDetails::InvalidAnthropicCompatibleRequest {
+                message: "function_name (passed in model field after \"tensorzero::function_name::\") cannot be empty".to_string(),
+            }.into());
+        }
+
+        if let Some(model_name) = &model_name
+            && model_name.is_empty()
+        {
+            return Err(ErrorDetails::InvalidAnthropicCompatibleRequest {
+                message: "model_name (passed in model field after \"tensorzero::model_name::\") cannot be empty".to_string(),
+            }.into());
+        }
+
+        let input =
+            anthropic_messages_to_input(anthropic_params.system, anthropic_params.messages)?;
+
+        let mut inference_params = anthropic_params.tensorzero_params.unwrap_or_default();
+
+        inference_params.chat_completion = ChatCompletionInferenceParams {
+            frequency_penalty: inference_params.chat_completion.frequency_penalty,
+            json_mode: inference_params.chat_completion.json_mode,
+            max_tokens: inference_params
+                .chat_completion
+                .max_tokens
+                .or(Some(anthropic_params.max_tokens)),
+            presence_penalty: inference_params.chat_completion.presence_penalty,
+            reasoning_effort: inference_params.chat_completion.reasoning_effort,
+            service_tier: inference_params.chat_completion.service_tier,
+            seed: inference_params.chat_completion.seed,
+            stop_sequences: inference_params
+                .chat_completion
+                .stop_sequences
+                .or(anthropic_params.stop_sequences),
+            temperature: inference_params
+                .chat_completion
+                .temperature
+                .or(anthropic_params.temperature),
+            thinking_budget_tokens: inference_params.chat_completion.thinking_budget_tokens,
+            top_p: inference_params
+                .chat_completion
+                .top_p
+                .or(anthropic_params.top_p),
+            verbosity: inference_params.chat_completion.verbosity.clone(),
+        };
+
+        let AnthropicToolChoiceParams {
+            allowed_tools,
+            tool_choice,
+        } = anthropic_params
+            .tool_choice
+            .map(|tc| tc.into_tool_params())
+            .unwrap_or_default();
+
+        let dynamic_tool_params = DynamicToolParams {
+            allowed_tools,
+            additional_tools: anthropic_params
+                .tools
+                .map(|tools| tools.into_iter().map(|t| t.into()).collect()),
+            tool_choice,
+            parallel_tool_calls: Some(true), // Anthropic supports parallel tool calls
+            provider_tools: anthropic_params.tensorzero_provider_tools,
+        };
+
+        Ok(Params {
+            function_name,
+            model_name,
+            episode_id: anthropic_params.tensorzero_episode_id,
+            input,
+            stream: anthropic_params.stream,
+            params: inference_params,
+            variant_name: anthropic_params.tensorzero_variant_name,
+            dryrun: anthropic_params.tensorzero_dryrun,
+            dynamic_tool_params,
+            output_schema: None, // Anthropic doesn't have a direct response_format equivalent
+            credentials: anthropic_params.tensorzero_credentials,
+            cache_options: anthropic_params
+                .tensorzero_cache_options
+                .unwrap_or_default(),
+            internal: false,
+            tags: anthropic_params.tensorzero_tags,
+            include_original_response: false, // Deprecated
+            include_raw_response: anthropic_params.tensorzero_include_raw_response,
+            include_raw_usage: anthropic_params.tensorzero_include_raw_usage,
+            extra_body: anthropic_params.tensorzero_extra_body,
+            extra_headers: anthropic_params.tensorzero_extra_headers,
+            internal_dynamic_variant_config: anthropic_params
+                .tensorzero_internal_dynamic_variant_config,
+        })
+    }
+}
+
+/// Convert Anthropic messages to TensorZero Input format.
+///
+/// This function handles the translation from Anthropic's message format to TensorZero's internal format,
+/// including merging consecutive tool result messages (for parallel tool calls).
+///
+/// # Arguments
+/// * `system` - Optional system prompt (string or array of blocks)
+/// * `messages` - Vector of Anthropic messages (User and Assistant)
+///
+/// # Returns
+/// A TensorZero `Input` containing the converted messages and system prompt
+///
+/// # Tool Result Merging
+/// When multiple consecutive user messages contain tool results, they are merged into a single
+/// user message containing all tool results. This handles Anthropic's format where parallel
+/// tool calls are represented as consecutive tool result messages.
+///
+/// # Errors
+/// Returns an error if:
+/// - The system prompt is invalid (not a string or array)
+/// - Any message content is invalid
+pub fn anthropic_messages_to_input(
+    system: Option<Value>,
+    messages: Vec<AnthropicMessage>,
+) -> Result<Input, Error> {
+    // Convert system prompt (can be string or array of blocks)
+    let system_message = if let Some(system) = system {
+        Some(convert_system_prompt(system)?)
+    } else {
+        None
+    };
+
+    let mut converted_messages = Vec::new();
+
+    for message in messages {
+        match message {
+            AnthropicMessage::User(msg) => {
+                let content = convert_anthropic_message_content(msg.content)?;
+                converted_messages.push(InputMessage {
+                    role: Role::User,
+                    content,
+                });
+            }
+            AnthropicMessage::Assistant(msg) => {
+                let mut message_content = Vec::new();
+                if let Some(content) = msg.content {
+                    message_content.extend(convert_anthropic_message_content(content)?);
+                }
+                converted_messages.push(InputMessage {
+                    role: Role::Assistant,
+                    content: message_content,
+                });
+            }
+        }
+    }
+
+    // Merge consecutive tool results (similar to OpenAI endpoint)
+    // This ensures that parallel tool call results can be passed through properly
+    let mut final_messages = Vec::new();
+    let mut i = 0;
+    while i < converted_messages.len() {
+        let message = &converted_messages[i];
+        if message.role == Role::User {
+            // Check if this is a tool result message
+            let has_tool_result = message
+                .content
+                .iter()
+                .any(|c| matches!(c, InputMessageContent::ToolResult(_)));
+
+            if has_tool_result {
+                // Collect all consecutive tool results
+                let mut tool_results = Vec::new();
+                while i < converted_messages.len()
+                    && converted_messages[i].role == Role::User
+                    && converted_messages[i]
+                        .content
+                        .iter()
+                        .any(|c| matches!(c, InputMessageContent::ToolResult(_)))
+                {
+                    for content in &converted_messages[i].content {
+                        if let InputMessageContent::ToolResult(_) = content {
+                            tool_results.push(content.clone());
+                        }
+                    }
+                    i += 1;
+                }
+                // Fold the following user message (if any) into the merged message,
+                // consuming it so its content isn't emitted again on the next iteration
+                if i < converted_messages.len() && converted_messages[i].role == Role::User {
+                    for content in &converted_messages[i].content {
+                        tool_results.push(content.clone());
+                    }
+                    i += 1;
+                }
+                final_messages.push(InputMessage {
+                    role: Role::User,
+                    content: tool_results,
+                });
+                continue;
+            }
+        }
+        final_messages.push(message.clone());
+        i += 1;
+    }
+
+    Ok(Input {
+        system: system_message,
+        messages: final_messages,
+    })
+}
+
+/// Convert system prompt from Anthropic format to TensorZero format.
+///
+/// # Arguments
+/// * `system` - Either a string or an array of content blocks (e.g., for cache_control)
+///
+/// # Returns
+/// A `System` instance containing the converted system prompt
+///
+/// # Errors
+/// Returns an error if the system value is not a string or array
+fn convert_system_prompt(system: Value) -> Result<System, Error> {
+    match system {
+        Value::String(s) => Ok(System::Text(s)),
+        Value::Array(blocks) => {
+            // System as array of blocks (e.g., for cache_control)
+            // Concatenate text blocks with newline separators
+            let text_blocks: Vec<&str> = blocks
+                .iter()
+                .filter_map(|block| {
+                    block
+                        .as_object()
+                        .and_then(|obj| obj.get("type"))
+                        .and_then(|t| t.as_str())
+                        .filter(|&t| t == "text")
+                        .and_then(|_| {
+                            block
+                                .as_object()
+                                .and_then(|obj| obj.get("text"))
+                                .and_then(|text| text.as_str())
+                        })
+                })
+                .collect();
+
+            Ok(System::Text(text_blocks.join("\n")))
+        }
+        _ => Err(ErrorDetails::InvalidAnthropicCompatibleRequest {
+            message: "system must be a string or array of blocks".to_string(),
+        }
+        .into()),
+    }
+}
+
+/// Convert Anthropic message content to TensorZero format.
+///
+/// Handles both simple string content and structured content blocks.
+///
+/// # Arguments
+/// * `content` - JSON value containing either:
+///   - A string (simple text message)
+///   - An array of content blocks (Text, ToolUse, ToolResult, etc.)
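+///
+/// For example, `"hello"` yields a single `Text` block, while
+/// `[{"type": "tool_result", "tool_use_id": "toolu_1", "content": "42"}]`
+/// yields a single `ToolResult` entry.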
+///
+/// # Returns
+/// Vector of TensorZero input message contents
+///
+/// # Errors
+/// Returns an error if:
+/// - The content is not a string or array
+/// - Any content block is malformed or invalid
+///
+/// # Supported Conversions
+/// - `string` → `InputMessageContent::Text`
+/// - `{"type": "text", "text": "..."}` → `InputMessageContent::Text`
+/// - `{"type": "tool_use", ...}` → `InputMessageContent::ToolCall`
+/// - `{"type": "tool_result", ...}` → `InputMessageContent::ToolResult`
+/// - `{"type": "raw_text", ...}` → `InputMessageContent::RawText`
+/// - `{"type": "template", ...}` → `InputMessageContent::Template`
+fn convert_anthropic_message_content(content: Value) -> Result<Vec<InputMessageContent>, Error> {
+    match content {
+        Value::String(s) => Ok(vec![InputMessageContent::Text(Text { text: s })]),
+        Value::Array(blocks) => {
+            blocks
+                .into_iter()
+                .map(|block| {
+                    let content_block =
+                        serde_json::from_value::<AnthropicContentBlock>(block);
+
+                    match content_block {
+                        Ok(AnthropicContentBlock::Text { text }) => {
+                            Ok(InputMessageContent::Text(Text { text }))
+                        }
+                        Ok(AnthropicContentBlock::ToolUse { id, name, input }) => {
+                            Ok(InputMessageContent::ToolCall(
+                                crate::tool::ToolCallWrapper::InferenceResponseToolCall(
+                                    crate::tool::InferenceResponseToolCall {
+                                        id,
+                                        raw_name: name.clone(),
+                                        raw_arguments: serde_json::to_string(&input)
+                                            .unwrap_or_default(),
+                                        name: None,
+                                        arguments: None,
+                                    },
+                                ),
+                            ))
+                        }
+                        Ok(AnthropicContentBlock::ToolResult {
+                            tool_use_id,
+                            content,
+                            ..
+                        }) => Ok(InputMessageContent::ToolResult(ToolResult {
+                            id: tool_use_id,
+                            name: String::new(), // Will be filled in from tool_use_id mapping
+                            result: content,
+                        })),
+                        Ok(AnthropicContentBlock::RawText(raw_text)) => {
+                            Ok(InputMessageContent::RawText(raw_text))
+                        }
+                        Ok(AnthropicContentBlock::Template(t)) => {
+                            Ok(InputMessageContent::Template(t))
+                        }
+                        Err(e) => Err(ErrorDetails::InvalidAnthropicCompatibleRequest {
+                            message: format!("Invalid content block: {e}"),
+                        }
+                        .into()),
+                    }
+                })
+                .collect()
+        }
+        _ => Err(ErrorDetails::InvalidAnthropicCompatibleRequest {
+            message: "message content must be a string or array of content blocks".to_string(),
+        }
+        .into()),
+    }
+}
+
+impl From<(InferenceResponse, String)> for AnthropicMessageResponse {
+    fn from((inference_response, response_model_prefix): (InferenceResponse, String)) -> Self {
+        match inference_response {
+            InferenceResponse::Chat(response) => {
+                let content = process_chat_content(response.content);
+                AnthropicMessageResponse {
+                    id: response.inference_id.to_string(),
+                    message_type: ANTHROPIC_MESSAGE_TYPE.to_string(),
+                    role: ANTHROPIC_ROLE_ASSISTANT.to_string(),
+                    content,
+                    model: format!("{response_model_prefix}{}", response.variant_name),
+                    stop_reason: response.finish_reason.unwrap_or(FinishReason::Stop).into(),
+                    stop_sequence: None,
+                    usage: AnthropicUsage {
+                        input_tokens: response.usage.input_tokens.unwrap_or(0),
+                        output_tokens: response.usage.output_tokens.unwrap_or(0),
+                    },
+                    tensorzero_raw_usage: response.raw_usage,
+                    tensorzero_raw_response: response.raw_response,
+                }
+            }
+            InferenceResponse::Json(response) => AnthropicMessageResponse {
+                id: response.inference_id.to_string(),
+                message_type: ANTHROPIC_MESSAGE_TYPE.to_string(),
+                role: ANTHROPIC_ROLE_ASSISTANT.to_string(),
+                content: vec![AnthropicOutputContentBlock::Text {
+                    text: response.output.raw.unwrap_or_default(),
+                }],
+                model: format!("{response_model_prefix}{}", response.variant_name),
+                stop_reason: response.finish_reason.unwrap_or(FinishReason::Stop).into(),
+                stop_sequence: None,
+                usage: AnthropicUsage {
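+                    // TensorZero usage counts are optional; missing values default to 0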
+                    input_tokens: response.usage.input_tokens.unwrap_or(0),
+                    output_tokens: response.usage.output_tokens.unwrap_or(0),
+                },
+                tensorzero_raw_usage: response.raw_usage,
+                tensorzero_raw_response: response.raw_response,
+            },
+        }
+    }
+}
+
+/// Process chat content blocks and convert to Anthropic-compatible format.
+///
+/// Filters out unsupported block types (Thought, Unknown) with warnings,
+/// and converts Text and ToolCall blocks to Anthropic format.
+///
+/// # Arguments
+/// * `content` - Vector of content blocks from TensorZero chat response
+///
+/// # Returns
+/// Vector of Anthropic-compatible content blocks (Text and ToolUse only)
+///
+/// # Filtering
+/// - `Text` blocks → `AnthropicOutputContentBlock::Text`
+/// - `ToolCall` blocks → `AnthropicOutputContentBlock::ToolUse`
+/// - `Thought` blocks → Logged and filtered out (not supported by Anthropic)
+/// - `Unknown` blocks → Logged and filtered out
+fn process_chat_content(content: Vec<ContentBlockChatOutput>) -> Vec<AnthropicOutputContentBlock> {
+    content
+        .into_iter()
+        .filter_map(|block| match block {
+            ContentBlockChatOutput::Text(text) => {
+                Some(AnthropicOutputContentBlock::Text { text: text.text })
+            }
+            ContentBlockChatOutput::ToolCall(tool_call) => {
+                Some(AnthropicOutputContentBlock::ToolUse {
+                    id: tool_call.id,
+                    name: tool_call.raw_name,
+                    input: serde_json::from_str(&tool_call.raw_arguments).unwrap_or_default(),
+                })
+            }
+            ContentBlockChatOutput::Thought(_) => {
+                tracing::warn!(
+                    "Ignoring 'thought' content block when constructing Anthropic-compatible response"
+                );
+                None
+            }
+            ContentBlockChatOutput::Unknown(_) => {
+                tracing::warn!(
+                    "Ignoring 'unknown' content block when constructing Anthropic-compatible response"
+                );
+                None
+            }
+        })
+        .collect()
+}
+
+// ============================================================================
+// Helper Functions
+// ============================================================================
+
+/// Convert TensorZero FinishReason to Anthropic-compatible stop reason string
+/// This is used by both streaming and non-streaming responses
+pub fn finish_reason_to_anthropic(finish_reason: FinishReason) -> String {
+    match finish_reason {
+        FinishReason::Stop => "end_turn".to_string(),
+        FinishReason::StopSequence => "stop_sequence".to_string(),
+        FinishReason::Length => "max_tokens".to_string(),
+        FinishReason::ToolCall => "tool_use".to_string(),
+        FinishReason::ContentFilter => "end_turn".to_string(),
+        FinishReason::Unknown => "end_turn".to_string(),
+    }
+}
+
+// ============================================================================
+// Tests
+// ============================================================================
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn test_anthropic_messages_to_input_basic() {
+        let messages = vec![AnthropicMessage::User(AnthropicUserMessage {
+            content: Value::String("Hello, world!".to_string()),
+        })];
+
+        let input = anthropic_messages_to_input(None, messages).unwrap();
+        assert_eq!(input.messages.len(), 1);
+        assert_eq!(input.messages[0].role, Role::User);
+        assert_eq!(
+            input.messages[0].content[0],
+            InputMessageContent::Text(Text {
+                text: "Hello, world!".to_string(),
+            })
+        );
+    }
+
+    #[test]
+    fn test_anthropic_messages_to_input_with_system() {
+        let system = json!("You are a helpful assistant");
+        let messages = vec![AnthropicMessage::User(AnthropicUserMessage {
+            content: Value::String("Hello".to_string()),
+        })];
+
+        let input = anthropic_messages_to_input(Some(system), messages).unwrap();
+        assert_eq!(
+            input.system,
Some(System::Text("You are a helpful assistant".to_string())) + ); + } + + #[test] + fn test_anthropic_messages_to_input_with_system_array() { + let system = json!([ + {"type": "text", "text": "You are helpful"} + ]); + let messages = vec![AnthropicMessage::User(AnthropicUserMessage { + content: Value::String("Hello".to_string()), + })]; + + let input = anthropic_messages_to_input(Some(system), messages).unwrap(); + assert_eq!( + input.system, + Some(System::Text("You are helpful".to_string())) + ); + } + + #[test] + fn test_anthropic_messages_to_input_tool_use() { + let messages = vec![ + AnthropicMessage::User(AnthropicUserMessage { + content: Value::String("What's the weather?".to_string()), + }), + AnthropicMessage::Assistant(AnthropicAssistantMessage { + content: Some(Value::Array(vec![json!({ + "type": "tool_use", + "id": "toolu_0123", + "name": "get_weather", + "input": {"location": "SF"} + })])), + }), + ]; + + let input = anthropic_messages_to_input(None, messages).unwrap(); + assert_eq!(input.messages.len(), 2); + assert_eq!(input.messages[0].role, Role::User); + assert_eq!(input.messages[1].role, Role::Assistant); + assert!(matches!( + input.messages[1].content[0], + InputMessageContent::ToolCall(_) + )); + } + + #[test] + fn test_anthropic_messages_to_input_with_tool_result() { + let messages = vec![ + AnthropicMessage::User(AnthropicUserMessage { + content: Value::String("What's the weather?".to_string()), + }), + AnthropicMessage::Assistant(AnthropicAssistantMessage { + content: Some(Value::Array(vec![json!({ + "type": "tool_use", + "id": "toolu_0123", + "name": "get_weather", + "input": {"location": "SF"} + })])), + }), + AnthropicMessage::User(AnthropicUserMessage { + content: Value::Array(vec![json!({ + "type": "tool_result", + "tool_use_id": "toolu_0123", + "content": "68 degrees" + })]), + }), + ]; + + let input = anthropic_messages_to_input(None, messages).unwrap(); + // Should have 3 messages: user question, assistant tool use, user tool result + assert_eq!(input.messages.len(), 3); + assert_eq!(input.messages[0].role, Role::User); + assert_eq!(input.messages[1].role, Role::Assistant); + assert_eq!(input.messages[2].role, Role::User); + // Last message should contain the tool result + assert!( + input.messages[2] + .content + .iter() + .any(|c| matches!(c, InputMessageContent::ToolResult(_))) + ); + } + + #[test] + fn test_params_try_from_anthropic_basic() { + let params = Params::try_from_anthropic(AnthropicMessagesParams { + model: "tensorzero::function_name::test_function".to_string(), + max_tokens: 100, + messages: vec![AnthropicMessage::User(AnthropicUserMessage { + content: Value::String("test".to_string()), + })], + ..Default::default() + }) + .unwrap(); + + assert_eq!(params.function_name, Some("test_function".to_string())); + assert_eq!(params.params.chat_completion.max_tokens, Some(100)); + } + + #[test] + fn test_params_try_from_anthropic_invalid_prefix() { + let result = Params::try_from_anthropic(AnthropicMessagesParams { + model: "gpt-4".to_string(), + max_tokens: 100, + messages: vec![], + ..Default::default() + }); + + assert!(result.is_err()); + } + + #[test] + fn test_params_try_from_anthropic_empty_function_name() { + let result = Params::try_from_anthropic(AnthropicMessagesParams { + model: "tensorzero::function_name::".to_string(), + max_tokens: 100, + messages: vec![], + ..Default::default() + }); + + assert!(result.is_err()); + } + + #[test] + fn test_params_try_from_anthropic_with_temperature() { + let params = 
+            model: "tensorzero::model_name::test_model".to_string(),
+            max_tokens: 100,
+            temperature: Some(0.7),
+            top_p: Some(0.9),
+            messages: vec![AnthropicMessage::User(AnthropicUserMessage {
+                content: Value::String("test".to_string()),
+            })],
+            ..Default::default()
+        })
+        .unwrap();
+
+        assert_eq!(params.model_name, Some("test_model".to_string()));
+        assert_eq!(params.params.chat_completion.temperature, Some(0.7));
+        assert_eq!(params.params.chat_completion.top_p, Some(0.9));
+    }
+
+    #[test]
+    fn test_params_try_from_anthropic_with_stop_sequences() {
+        let params = Params::try_from_anthropic(AnthropicMessagesParams {
+            model: "tensorzero::model_name::test_model".to_string(),
+            max_tokens: 100,
+            stop_sequences: Some(vec!["STOP".to_string(), "END".to_string()]),
+            messages: vec![AnthropicMessage::User(AnthropicUserMessage {
+                content: Value::String("test".to_string()),
+            })],
+            ..Default::default()
+        })
+        .unwrap();
+
+        assert_eq!(
+            params.params.chat_completion.stop_sequences,
+            Some(vec!["STOP".to_string(), "END".to_string()])
+        );
+    }
+
+    #[test]
+    fn test_params_try_from_anthropic_tool_choice_auto() {
+        let params = Params::try_from_anthropic(AnthropicMessagesParams {
+            model: "tensorzero::function_name::test_function".to_string(),
+            max_tokens: 100,
+            tool_choice: Some(AnthropicToolChoice::Auto),
+            messages: vec![],
+            ..Default::default()
+        })
+        .unwrap();
+
+        assert_eq!(
+            params.dynamic_tool_params.tool_choice,
+            Some(tensorzero_types::ToolChoice::Auto)
+        );
+    }
+
+    #[test]
+    fn test_params_try_from_anthropic_tool_choice_specific() {
+        let params = Params::try_from_anthropic(AnthropicMessagesParams {
+            model: "tensorzero::function_name::test_function".to_string(),
+            max_tokens: 100,
+            tool_choice: Some(AnthropicToolChoice::Tool {
+                name: "my_tool".to_string(),
+            }),
+            messages: vec![],
+            ..Default::default()
+        })
+        .unwrap();
+
+        assert_eq!(
+            params.dynamic_tool_params.tool_choice,
+            Some(tensorzero_types::ToolChoice::Specific(
+                "my_tool".to_string()
+            ))
+        );
+        assert_eq!(
+            params.dynamic_tool_params.allowed_tools,
+            Some(vec!["my_tool".to_string()])
+        );
+    }
+
+    #[test]
+    fn test_response_conversion_chat() {
+        use crate::endpoints::inference::ChatInferenceResponse;
+        use crate::inference::types::Usage;
+        use uuid::Uuid;
+
+        let inference_id = Uuid::now_v7();
+        let episode_id = Uuid::now_v7();
+        let response = InferenceResponse::Chat(ChatInferenceResponse {
+            inference_id,
+            episode_id,
+            variant_name: "test_variant".to_string(),
+            content: vec![],
+            usage: Usage {
+                input_tokens: Some(10),
+                output_tokens: Some(20),
+            },
+            raw_usage: None,
+            original_response: None,
+            raw_response: None,
+            finish_reason: Some(crate::inference::types::FinishReason::Stop),
+        });
+
+        let anthropic_response =
+            AnthropicMessageResponse::from((response, "test_prefix::".to_string()));
+
+        assert_eq!(anthropic_response.message_type, "message");
+        assert_eq!(anthropic_response.role, "assistant");
+        assert_eq!(anthropic_response.model, "test_prefix::test_variant");
+        assert_eq!(anthropic_response.stop_reason, AnthropicStopReason::EndTurn);
+        assert_eq!(anthropic_response.usage.input_tokens, 10);
+        assert_eq!(anthropic_response.usage.output_tokens, 20);
+    }
+
+    #[test]
+    fn test_response_conversion_with_text() {
+        use crate::endpoints::inference::ChatInferenceResponse;
+        use crate::inference::types::{Text, Usage};
+        use uuid::Uuid;
+
+        let inference_id = Uuid::now_v7();
+        let episode_id = Uuid::now_v7();
+        let response = InferenceResponse::Chat(ChatInferenceResponse {
+            inference_id,
+            episode_id,
+            variant_name: "test_variant".to_string(),
+            content: vec![crate::inference::types::ContentBlockChatOutput::Text(
+                Text {
+                    text: "Hello, world!".to_string(),
+                },
+            )],
+            usage: Usage {
+                input_tokens: Some(10),
+                output_tokens: Some(20),
+            },
+            raw_usage: None,
+            original_response: None,
+            raw_response: None,
+            finish_reason: None,
+        });
+
+        let anthropic_response = AnthropicMessageResponse::from((response, "prefix::".to_string()));
+
+        assert_eq!(anthropic_response.content.len(), 1);
+        assert_eq!(
+            anthropic_response.content[0],
+            AnthropicOutputContentBlock::Text {
+                text: "Hello, world!".to_string()
+            }
+        );
+    }
+
+    #[test]
+    fn test_response_conversion_with_tool() {
+        use crate::endpoints::inference::ChatInferenceResponse;
+        use crate::inference::types::{ContentBlockChatOutput, Usage};
+        use crate::tool::InferenceResponseToolCall;
+        use uuid::Uuid;
+
+        let inference_id = Uuid::now_v7();
+        let episode_id = Uuid::now_v7();
+        let response = InferenceResponse::Chat(ChatInferenceResponse {
+            inference_id,
+            episode_id,
+            variant_name: "test_variant".to_string(),
+            content: vec![ContentBlockChatOutput::ToolCall(
+                InferenceResponseToolCall {
+                    id: "tool_123".to_string(),
+                    raw_name: "my_tool".to_string(),
+                    raw_arguments: "{\"arg\": \"value\"}".to_string(),
+                    name: None,
+                    arguments: None,
+                },
+            )],
+            usage: Usage {
+                input_tokens: Some(10),
+                output_tokens: Some(20),
+            },
+            raw_usage: None,
+            original_response: None,
+            raw_response: None,
+            finish_reason: Some(crate::inference::types::FinishReason::ToolCall),
+        });
+
+        let anthropic_response = AnthropicMessageResponse::from((response, "prefix::".to_string()));
+
+        assert_eq!(anthropic_response.content.len(), 1);
+        assert_eq!(
+            anthropic_response.content[0],
+            AnthropicOutputContentBlock::ToolUse {
+                id: "tool_123".to_string(),
+                name: "my_tool".to_string(),
+                input: serde_json::json!({"arg": "value"})
+            }
+        );
+        assert_eq!(anthropic_response.stop_reason, AnthropicStopReason::ToolUse);
+    }
+
+    #[test]
+    fn test_stop_reason_conversion() {
+        assert_eq!(
+            AnthropicStopReason::from(crate::inference::types::FinishReason::Stop),
+            AnthropicStopReason::EndTurn
+        );
+        assert_eq!(
+            AnthropicStopReason::from(crate::inference::types::FinishReason::Length),
+            AnthropicStopReason::MaxTokens
+        );
+        assert_eq!(
+            AnthropicStopReason::from(crate::inference::types::FinishReason::ToolCall),
+            AnthropicStopReason::ToolUse
+        );
+        assert_eq!(
+            AnthropicStopReason::from(crate::inference::types::FinishReason::StopSequence),
+            AnthropicStopReason::StopSequence
+        );
+        assert_eq!(
+            AnthropicStopReason::from(crate::inference::types::FinishReason::ContentFilter),
+            AnthropicStopReason::EndTurn
+        );
+    }
+
+    #[test]
+    fn test_content_block_with_text() {
+        let block = AnthropicContentBlock::Text {
+            text: "Hello".to_string(),
+        };
+        assert_eq!(
+            block,
+            AnthropicContentBlock::Text {
+                text: "Hello".to_string()
+            }
+        );
+    }
+
+    #[test]
+    fn test_content_block_with_tool_use() {
+        let block = AnthropicContentBlock::ToolUse {
+            id: "123".to_string(),
+            name: "my_tool".to_string(),
+            input: json!({"arg": "value"}),
+        };
+        assert!(matches!(block, AnthropicContentBlock::ToolUse { .. }));
+        if let AnthropicContentBlock::ToolUse { id, name, .. } = block {
+            assert_eq!(id, "123");
+            assert_eq!(name, "my_tool");
+        }
+    }
+
+    #[test]
+    fn test_content_block_with_tool_result() {
+        let block = AnthropicContentBlock::ToolResult {
+            tool_use_id: "123".to_string(),
+            content: "result".to_string(),
+            is_error: false,
+        };
+        assert!(matches!(block, AnthropicContentBlock::ToolResult { .. }));
+        if let AnthropicContentBlock::ToolResult {
+            tool_use_id,
+            content,
+            is_error,
+        } = block
+        {
+            assert_eq!(tool_use_id, "123");
+            assert_eq!(content, "result");
+            assert!(!is_error);
+        }
+    }
+
+    #[test]
+    fn test_content_block_serialization() {
+        let json = json!({
+            "type": "text",
+            "text": "Hello"
+        });
+        let block: AnthropicContentBlock = serde_json::from_value(json.clone()).unwrap();
+        assert!(matches!(block, AnthropicContentBlock::Text { .. }));
+
+        let json = json!({
+            "type": "tool_use",
+            "id": "123",
+            "name": "my_tool",
+            "input": {"arg": "value"}
+        });
+        let block: AnthropicContentBlock = serde_json::from_value(json.clone()).unwrap();
+        assert!(matches!(block, AnthropicContentBlock::ToolUse { .. }));
+
+        let json = json!({
+            "type": "tool_result",
+            "tool_use_id": "123",
+            "content": "result",
+            "is_error": false
+        });
+        let block: AnthropicContentBlock = serde_json::from_value(json.clone()).unwrap();
+        assert!(matches!(block, AnthropicContentBlock::ToolResult { .. }));
+    }
+}
diff --git a/tensorzero-core/src/endpoints/anthropic_compatible/types/mod.rs b/tensorzero-core/src/endpoints/anthropic_compatible/types/mod.rs
new file mode 100644
index 0000000000..495e3ec8ee
--- /dev/null
+++ b/tensorzero-core/src/endpoints/anthropic_compatible/types/mod.rs
@@ -0,0 +1,11 @@
+//! Type definitions for Anthropic-compatible API.
+
+pub mod messages;
+pub mod streaming;
+pub mod tool;
+pub mod usage;
+
+pub use messages::*;
+pub use streaming::*;
+pub use tool::*;
+pub use usage::*;
diff --git a/tensorzero-core/src/endpoints/anthropic_compatible/types/streaming.rs b/tensorzero-core/src/endpoints/anthropic_compatible/types/streaming.rs
new file mode 100644
index 0000000000..f0f8da9a1a
--- /dev/null
+++ b/tensorzero-core/src/endpoints/anthropic_compatible/types/streaming.rs
@@ -0,0 +1,505 @@
+//! Streaming response types and logic for Anthropic-compatible API.
+//!
+//! This module provides types and functions for streaming message responses in
+//! Server-Sent Events (SSE) format, compatible with Anthropic's Messages API.
+//!
+//! # Event Types
+//!
+//! Anthropic uses multiple event types for streaming (unlike OpenAI's single `chunk` event):
+//! - `message_start`: Initial metadata (id, type, role, model)
+//! - `content_block_start`: Beginning of a content block (text or tool_use)
+//! - `content_block_delta`: Incremental content (text or partial JSON)
+//! - `content_block_stop`: End of a content block
+//! - `message_delta`: Final metadata (stop_reason, usage)
+//! - `message_stop`: Stream complete
+//!
+//! # Example
+//!
+//! ```rust,ignore
+//! use tensorzero_core::endpoints::anthropic_compatible::types::streaming::prepare_serialized_anthropic_events;
+//!
+//! let stream = prepare_serialized_anthropic_events(
+//!     inference_stream,
+//!     "tensorzero::function_name::".to_string(),
+//!     true,  // include_usage
+//!     false, // include_raw_usage
+//!     false, // include_raw_response
+//! );
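+//! // `stream` yields SSE `Event`s that can be wrapped in `axum::response::sse::Sse`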
+//! ```
+
+use axum::response::sse::Event;
+use futures::Stream;
+use serde::Serialize;
+use std::collections::HashMap;
+use tokio_stream::StreamExt;
+
+use crate::error::{Error, ErrorDetails};
+use crate::inference::types::ContentBlockChunk;
+
+use crate::endpoints::anthropic_compatible::types::messages::AnthropicOutputContentBlock;
+use crate::endpoints::anthropic_compatible::types::messages::finish_reason_to_anthropic;
+use crate::endpoints::anthropic_compatible::types::usage::AnthropicStreamingUsage;
+use crate::endpoints::inference::{InferenceResponseChunk, InferenceStream};
+
+#[derive(Clone, Debug, PartialEq, Serialize)]
+#[serde(untagged)]
+pub enum AnthropicStreamingEventData {
+    MessageStart {
+        message: AnthropicMessageStart,
+    },
+    ContentBlockStart {
+        content_block: AnthropicContentBlockStart,
+        index: u32,
+    },
+    ContentBlockDelta {
+        delta: AnthropicDelta,
+        index: u32,
+    },
+    ContentBlockStop {
+        index: u32,
+    },
+    MessageDelta {
+        delta: AnthropicMessageDelta,
+        usage: Option<AnthropicStreamingUsage>,
+    },
+    MessageStop,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize)]
+pub struct AnthropicMessageStart {
+    pub id: String,
+    #[serde(rename = "type")]
+    pub message_type: String,
+    pub role: String,
+    pub content: Vec<AnthropicOutputContentBlock>,
+    pub model: String,
+    pub stop_reason: Option<String>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize)]
+#[serde(tag = "type")]
+#[serde(rename_all = "snake_case")]
+pub enum AnthropicContentBlockStart {
+    Text {
+        index: u32,
+    },
+    ToolUse {
+        id: String,
+        name: String,
+        index: u32,
+    },
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize)]
+pub struct AnthropicDelta {
+    #[serde(rename = "type")]
+    pub delta_type: String,
+    pub text: Option<String>,
+    pub partial_json: Option<String>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize)]
+pub struct AnthropicMessageDelta {
+    pub stop_reason: Option<String>,
+    pub stop_sequence: Option<String>,
+}
+
+/// Converts a TensorZero inference chunk to Anthropic-compatible streaming events.
+///
+/// # Arguments
+/// * `chunk` - The inference response chunk to convert
+/// * `response_model_prefix` - Prefix to prepend to the model name in responses
+/// * `is_first_chunk` - Whether this is the first chunk (triggers message_start event)
+/// * `include_usage` - Whether to include usage information in message_delta
+/// * `_include_raw_usage` - Unused (reserved for future use)
+/// * `_include_raw_response` - Unused (reserved for future use)
+///
+/// # Returns
+/// A vector of streaming event data that can be serialized to SSE format
+///
+/// # Event Flow
+/// 1. First chunk: `message_start` event
+/// 2. Each new content block: `content_block_start` event
+/// 3. Content deltas: `content_block_delta` events
+/// 4. Each block end: `content_block_stop` event
+/// 5. Final chunk: `message_delta` event followed by `message_stop`
+pub fn convert_inference_response_chunk_to_anthropic(
+    chunk: InferenceResponseChunk,
+    response_model_prefix: &str,
+    is_first_chunk: bool,
+    include_usage: bool,
+    _include_raw_usage: bool,
+    _include_raw_response: bool,
+) -> Vec<AnthropicStreamingEventData> {
+    let mut events = Vec::new();
+
+    match chunk {
+        InferenceResponseChunk::Chat(c) => {
+            // Generate message_start event for first chunk
+            if is_first_chunk {
+                events.push(AnthropicStreamingEventData::MessageStart {
+                    message: AnthropicMessageStart {
+                        id: c.inference_id.to_string(),
+                        message_type: "message".to_string(),
+                        role: "assistant".to_string(),
+                        content: Vec::new(), // Will be populated as blocks arrive
+                        model: format!("{response_model_prefix}{}", c.variant_name),
+                        stop_reason: None,
+                    },
+                });
+            }
+
+            // Process content blocks
+            let (text_deltas, tool_calls, has_new_block) = process_chat_content_chunk(c.content);
+
+            // Generate content_block_start for new blocks
+            if has_new_block && !text_deltas.is_empty() {
+                events.push(AnthropicStreamingEventData::ContentBlockStart {
+                    content_block: AnthropicContentBlockStart::Text { index: 0 },
+                    index: 0,
+                });
+            }
+
+            // Generate content_block_delta events
+            for (index, text_delta) in text_deltas.iter().enumerate() {
+                if !text_delta.is_empty() {
+                    events.push(AnthropicStreamingEventData::ContentBlockDelta {
+                        delta: AnthropicDelta {
+                            delta_type: "text_delta".to_string(),
+                            text: Some(text_delta.clone()),
+                            partial_json: None,
+                        },
+                        index: index as u32,
+                    });
+                }
+            }
+
+            // Handle tool calls
+            for (index, tool_call) in tool_calls.iter().enumerate() {
+                if tool_call.is_new {
+                    events.push(AnthropicStreamingEventData::ContentBlockStart {
+                        content_block: AnthropicContentBlockStart::ToolUse {
+                            id: tool_call.id.clone(),
+                            name: tool_call.name.clone(),
+                            index: index as u32 + 1,
+                        },
+                        index: index as u32 + 1,
+                    });
+                }
+
+                let arguments_delta = if tool_call.arguments_delta.is_empty() {
+                    None
+                } else {
+                    Some(tool_call.arguments_delta.clone())
+                };
+
+                if arguments_delta.is_some() {
+                    events.push(AnthropicStreamingEventData::ContentBlockDelta {
+                        delta: AnthropicDelta {
+                            delta_type: "input_json_delta".to_string(),
+                            text: None,
+                            partial_json: arguments_delta,
+                        },
+                        index: index as u32 + 1,
+                    });
+                }
+            }
+
+            // Generate message_delta for final chunk (when finish_reason is present)
+            if c.finish_reason.is_some() {
+                let stop_reason = c.finish_reason.map(finish_reason_to_anthropic);
+
+                let usage = if include_usage {
+                    c.usage.map(|u| AnthropicStreamingUsage {
+                        input_tokens: Some(u.input_tokens.unwrap_or(0)),
+                        output_tokens: Some(u.output_tokens.unwrap_or(0)),
+                    })
+                } else {
+                    None
+                };
+
+                events.push(AnthropicStreamingEventData::MessageDelta {
+                    delta: AnthropicMessageDelta {
+                        stop_reason,
+                        stop_sequence: None,
+                    },
+                    usage,
+                });
+
+                events.push(AnthropicStreamingEventData::MessageStop);
+            }
+        }
+        InferenceResponseChunk::Json(c) => {
+            // JSON mode - similar to chat but simpler
+            if is_first_chunk {
+                events.push(AnthropicStreamingEventData::MessageStart {
+                    message: AnthropicMessageStart {
+                        id: c.inference_id.to_string(),
+                        message_type: "message".to_string(),
+                        role: "assistant".to_string(),
+                        content: vec![],
+                        model: format!("{response_model_prefix}{}", c.variant_name),
+                        stop_reason: None,
+                    },
+                });
+
+                events.push(AnthropicStreamingEventData::ContentBlockStart {
+                    content_block: AnthropicContentBlockStart::Text { index: 0 },
+                    index: 0,
+                });
+            }
+
+            // Add text delta
+            if !c.raw.is_empty() {
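+                // JSON output is streamed as a single text content block at index 0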
+                events.push(AnthropicStreamingEventData::ContentBlockDelta {
+                    delta: AnthropicDelta {
+                        delta_type: "text_delta".to_string(),
+                        text: Some(c.raw),
+                        partial_json: None,
+                    },
+                    index: 0,
+                });
+            }
+
+            if c.finish_reason.is_some() {
+                let stop_reason = c.finish_reason.map(finish_reason_to_anthropic);
+
+                let usage = if include_usage {
+                    c.usage.map(|u| AnthropicStreamingUsage {
+                        input_tokens: Some(u.input_tokens.unwrap_or(0)),
+                        output_tokens: Some(u.output_tokens.unwrap_or(0)),
+                    })
+                } else {
+                    None
+                };
+
+                events.push(AnthropicStreamingEventData::MessageDelta {
+                    delta: AnthropicMessageDelta {
+                        stop_reason,
+                        stop_sequence: None,
+                    },
+                    usage,
+                });
+
+                events.push(AnthropicStreamingEventData::MessageStop);
+            }
+        }
+    }
+
+    events
+}
+
+struct ToolCallDelta {
+    id: String,
+    name: String,
+    arguments_delta: String,
+    is_new: bool,
+}
+
+fn process_chat_content_chunk(
+    content: Vec<ContentBlockChunk>,
+) -> (Vec<String>, Vec<ToolCallDelta>, bool) {
+    let mut text_deltas = Vec::new();
+    let mut tool_calls = HashMap::new();
+
+    for block in content {
+        match block {
+            ContentBlockChunk::Text(text) => {
+                text_deltas.push(text.text);
+            }
+            ContentBlockChunk::ToolCall(tool_call) => {
+                let entry =
+                    tool_calls
+                        .entry(tool_call.id.clone())
+                        .or_insert_with(|| ToolCallDelta {
+                            id: tool_call.id,
+                            name: tool_call.raw_name.unwrap_or_default(),
+                            arguments_delta: String::new(),
+                            is_new: true,
+                        });
+
+                // Heuristic: the first chunk of a tool call carries its name and no
+                // arguments, so a chunk carrying arguments is treated as a continuation
+                // of an already-started block.
+                if !tool_call.raw_arguments.is_empty() {
+                    entry.arguments_delta.push_str(&tool_call.raw_arguments);
+                    entry.is_new = false;
+                }
+            }
+            ContentBlockChunk::Thought(_) => {
+                tracing::warn!(
+                    "Ignoring 'thought' content block chunk when constructing Anthropic-compatible response"
+                );
+            }
+            ContentBlockChunk::Unknown(_) => {
+                tracing::warn!(
+                    "Ignoring 'unknown' content block chunk when constructing Anthropic-compatible response"
+                );
+            }
+        }
+    }
+
+    let has_new_block = !text_deltas.is_empty();
+    let tool_calls_vec = tool_calls.into_values().collect();
+
+    (text_deltas, tool_calls_vec, has_new_block)
+}
+
+/// Prepares `Event`s for SSE on the way out of the gateway.
+/// Converts each `InferenceResponseChunk` to Anthropic-compatible format and streams it.
+pub fn prepare_serialized_anthropic_events(
+    mut stream: InferenceStream,
+    response_model_prefix: String,
+    include_usage: bool,
+    include_raw_usage: bool,
+    include_raw_response: bool,
+) -> impl Stream<Item = Result<Event, Error>> {
+    async_stream::stream! {
+        let mut is_first_chunk = true;
+
+        while let Some(chunk) = stream.next().await {
+            // Errored chunks are skipped rather than terminating the stream
+            let Ok(chunk) = chunk else {
+                continue;
+            };
+
+            let anthropic_events = convert_inference_response_chunk_to_anthropic(
+                chunk,
+                &response_model_prefix,
+                is_first_chunk,
+                include_usage,
+                include_raw_usage,
+                include_raw_response,
+            );
+
+            is_first_chunk = false;
+
+            for event_data in anthropic_events {
+                let event_type = get_event_type(&event_data);
+                yield Event::default()
+                    .event(event_type)
+                    .json_data(&event_data)
+                    .map_err(|e| {
+                        Error::new(ErrorDetails::Inference {
+                            message: format!("Failed to convert chunk to Event: {e}"),
+                        })
+                    });
+            }
+        }
+    }
+}
+
+fn get_event_type(event: &AnthropicStreamingEventData) -> &str {
+    match event {
+        AnthropicStreamingEventData::MessageStart { .. } => "message_start",
+        AnthropicStreamingEventData::ContentBlockStart { .. } => "content_block_start",
+        AnthropicStreamingEventData::ContentBlockDelta { .. } => "content_block_delta",
+        AnthropicStreamingEventData::ContentBlockStop { .. } => "content_block_stop",
+        AnthropicStreamingEventData::MessageDelta { ..
} => "message_delta", + AnthropicStreamingEventData::MessageStop => "message_stop", + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::endpoints::inference::ChatInferenceResponseChunk; + use crate::inference::types::usage::Usage; + use crate::inference::types::{FinishReason, TextChunk}; + use uuid::Uuid; + + #[test] + fn test_convert_chat_chunk_first() { + let inference_id = Uuid::now_v7(); + let episode_id = Uuid::now_v7(); + let chunk = InferenceResponseChunk::Chat(ChatInferenceResponseChunk { + inference_id, + episode_id, + variant_name: "test_variant".to_string(), + content: vec![ContentBlockChunk::Text(TextChunk { + id: "1".to_string(), + text: "Hello".to_string(), + })], + usage: None, + raw_usage: None, + finish_reason: None, + original_chunk: None, + raw_chunk: None, + raw_response: None, + }); + + let events = convert_inference_response_chunk_to_anthropic( + chunk, + "test_prefix::", + true, // is_first_chunk + true, // include_usage + false, // include_raw_usage + false, // include_raw_response + ); + + assert!(!events.is_empty()); + assert!(matches!( + events[0], + AnthropicStreamingEventData::MessageStart { .. } + )); + } + + #[test] + fn test_convert_chat_chunk_final() { + let inference_id = Uuid::now_v7(); + let episode_id = Uuid::now_v7(); + let chunk = InferenceResponseChunk::Chat(ChatInferenceResponseChunk { + inference_id, + episode_id, + variant_name: "test_variant".to_string(), + content: vec![], + usage: Some(Usage { + input_tokens: Some(10), + output_tokens: Some(20), + }), + raw_usage: None, + finish_reason: Some(FinishReason::Stop), + original_chunk: None, + raw_chunk: None, + raw_response: None, + }); + + let events = convert_inference_response_chunk_to_anthropic( + chunk, + "test_prefix::", + false, // is_first_chunk + true, // include_usage + false, // include_raw_usage + false, // include_raw_response + ); + + assert!(!events.is_empty()); + let has_delta = events + .iter() + .any(|e| matches!(e, AnthropicStreamingEventData::MessageDelta { .. })); + assert!(has_delta); + + let has_stop = events + .iter() + .any(|e| matches!(e, AnthropicStreamingEventData::MessageStop)); + assert!(has_stop); + } + + #[test] + fn test_event_type_mapping() { + assert_eq!( + get_event_type(&AnthropicStreamingEventData::MessageStart { + message: AnthropicMessageStart { + id: "test".to_string(), + message_type: "message".to_string(), + role: "assistant".to_string(), + content: vec![], + model: "test".to_string(), + stop_reason: None, + } + }), + "message_start" + ); + + assert_eq!( + get_event_type(&AnthropicStreamingEventData::MessageStop), + "message_stop" + ); + } +} diff --git a/tensorzero-core/src/endpoints/anthropic_compatible/types/tool.rs b/tensorzero-core/src/endpoints/anthropic_compatible/types/tool.rs new file mode 100644 index 0000000000..bf03997d03 --- /dev/null +++ b/tensorzero-core/src/endpoints/anthropic_compatible/types/tool.rs @@ -0,0 +1,216 @@ +//! Tool types for Anthropic-compatible API. 
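+//!
+//! As a rough sketch (mirroring the e2e tests in this PR; not a doctest), a tool
+//! definition in an Anthropic-style request deserializes into [`AnthropicTool`] and
+//! converts into a TensorZero [`Tool`]:
+//!
+//! ```rust,ignore
+//! let anthropic_tool: AnthropicTool = serde_json::from_value(serde_json::json!({
+//!     "name": "get_temperature",
+//!     "description": "Get the current temperature in a given location",
+//!     "input_schema": {
+//!         "type": "object",
+//!         "properties": { "location": { "type": "string" } },
+//!         "required": ["location"]
+//!     }
+//! }))?;
+//! let tool: crate::tool::Tool = anthropic_tool.into();
+//! ```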
+
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+use crate::tool::{ProviderTool, Tool};
+use tensorzero_types::ToolChoice;
+
+/// Tool definition for Anthropic-compatible requests
+#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
+#[cfg_attr(feature = "ts-bindings", derive(ts_rs::TS))]
+#[cfg_attr(feature = "ts-bindings", ts(export))]
+pub struct AnthropicTool {
+    pub name: String,
+    pub description: String,
+    pub input_schema: AnthropicInputSchema,
+}
+
+#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
+#[serde(rename_all = "lowercase")]
+pub enum AnthropicToolChoice {
+    Auto,
+    Any,
+    Tool { name: String },
+}
+
+#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
+#[cfg_attr(feature = "ts-bindings", derive(ts_rs::TS))]
+pub struct AnthropicInputSchema {
+    #[serde(rename = "type")]
+    pub schema_type: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub properties: Option<HashMap<String, serde_json::Value>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub required: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub additional_properties: Option<bool>,
+}
+
+/// Parameters extracted from tool choice
+#[derive(Clone, Debug, Default, PartialEq)]
+pub struct AnthropicToolChoiceParams {
+    pub allowed_tools: Option<Vec<String>>,
+    pub tool_choice: Option<ToolChoice>,
+}
+
+impl AnthropicToolChoice {
+    pub fn into_tool_params(self) -> AnthropicToolChoiceParams {
+        match self {
+            AnthropicToolChoice::Auto => AnthropicToolChoiceParams {
+                allowed_tools: None,
+                tool_choice: Some(ToolChoice::Auto),
+            },
+            AnthropicToolChoice::Any => AnthropicToolChoiceParams {
+                allowed_tools: None,
+                tool_choice: Some(ToolChoice::Required),
+            },
+            AnthropicToolChoice::Tool { name } => AnthropicToolChoiceParams {
+                allowed_tools: Some(vec![name.clone()]),
+                tool_choice: Some(ToolChoice::Specific(name)),
+            },
+        }
+    }
+}
+
+impl From<AnthropicTool> for Tool {
+    fn from(tool: AnthropicTool) -> Self {
+        Tool::Function(crate::tool::FunctionTool {
+            name: tool.name,
+            description: tool.description,
+            parameters: serde_json::json!({
+                "type": tool.input_schema.schema_type,
+                "properties": tool.input_schema.properties.unwrap_or_default(),
+                "required": tool.input_schema.required.unwrap_or_default(),
+                "additionalProperties": tool.input_schema.additional_properties,
+            }),
+            strict: false,
+        })
+    }
+}
+
+impl From<AnthropicTool> for ProviderTool {
+    fn from(tool: AnthropicTool) -> Self {
+        ProviderTool {
+            scope: Default::default(),
+            tool: serde_json::to_value(tool).unwrap_or_default(),
+        }
+    }
+}
+
+// ============================================================================
+// Tests
+// ============================================================================
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn test_tool_choice_auto_conversion() {
+        let tool_choice = AnthropicToolChoice::Auto;
+        let params = tool_choice.into_tool_params();
+
+        assert!(params.allowed_tools.is_none());
+        assert_eq!(params.tool_choice, Some(tensorzero_types::ToolChoice::Auto));
+    }
+
+    #[test]
+    fn test_tool_choice_any_conversion() {
+        let tool_choice = AnthropicToolChoice::Any;
+        let params = tool_choice.into_tool_params();
+
+        assert!(params.allowed_tools.is_none());
+        assert_eq!(
+            params.tool_choice,
+            Some(tensorzero_types::ToolChoice::Required)
+        );
+    }
+
+    #[test]
+    fn test_tool_choice_specific_conversion() {
+        let tool_choice = AnthropicToolChoice::Tool {
+            name: "my_tool".to_string(),
+        };
+        let params = tool_choice.into_tool_params();
+
+        assert_eq!(params.allowed_tools,
Some(vec!["my_tool".to_string()])); + assert_eq!( + params.tool_choice, + Some(tensorzero_types::ToolChoice::Specific( + "my_tool".to_string() + )) + ); + } + + #[test] + fn test_anthropic_tool_serialization() { + let tool = AnthropicTool { + name: "test_tool".to_string(), + description: "A test tool".to_string(), + input_schema: AnthropicInputSchema { + schema_type: "object".to_string(), + properties: Some(HashMap::from_iter([( + "param1".to_string(), + json!({"type": "string"}), + )])), + required: Some(vec!["param1".to_string()]), + additional_properties: Some(false), + }, + }; + + let json = serde_json::to_value(&tool).unwrap(); + assert_eq!(json["name"], "test_tool"); + assert_eq!(json["description"], "A test tool"); + assert_eq!(json["input_schema"]["type"], "object"); + } + + #[test] + fn test_anthropic_tool_deserialization() { + let json = json!({ + "name": "test_tool", + "description": "A test tool", + "input_schema": { + "type": "object", + "properties": { + "param1": {"type": "string"} + }, + "required": ["param1"] + } + }); + + let tool: AnthropicTool = serde_json::from_value(json).unwrap(); + assert_eq!(tool.name, "test_tool"); + assert_eq!(tool.description, "A test tool"); + assert_eq!(tool.input_schema.schema_type, "object"); + assert_eq!(tool.input_schema.required, Some(vec!["param1".to_string()])); + } + + #[test] + fn test_anthropic_tool_to_tool_conversion() { + let anthropic_tool = AnthropicTool { + name: "test_tool".to_string(), + description: "Test".to_string(), + input_schema: AnthropicInputSchema { + schema_type: "object".to_string(), + properties: None, + required: None, + additional_properties: None, + }, + }; + + let tool: Tool = anthropic_tool.into(); + assert!(matches!(tool, Tool::Function(_))); + } + + #[test] + fn test_anthropic_tool_to_provider_tool_conversion() { + let anthropic_tool = AnthropicTool { + name: "test_tool".to_string(), + description: "Test".to_string(), + input_schema: AnthropicInputSchema { + schema_type: "object".to_string(), + properties: None, + required: None, + additional_properties: None, + }, + }; + + let provider_tool: ProviderTool = anthropic_tool.into(); + assert_eq!( + provider_tool.scope, + crate::tool::ProviderToolScope::Unscoped + ); + } +} diff --git a/tensorzero-core/src/endpoints/anthropic_compatible/types/usage.rs b/tensorzero-core/src/endpoints/anthropic_compatible/types/usage.rs new file mode 100644 index 0000000000..c50925a1fd --- /dev/null +++ b/tensorzero-core/src/endpoints/anthropic_compatible/types/usage.rs @@ -0,0 +1,29 @@ +//! Usage types for Anthropic-compatible API. 
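+//!
+//! A small illustrative sketch (not a doctest) of how a final usage total converts into
+//! the streaming form via the `From` impl below:
+//!
+//! ```rust,ignore
+//! let total = AnthropicUsage { input_tokens: 10, output_tokens: 20 };
+//! let streaming: AnthropicStreamingUsage = total.into();
+//! assert_eq!(streaming.input_tokens, Some(10));
+//! assert_eq!(streaming.output_tokens, Some(20));
+//! ```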
+
+use serde::Serialize;
+
+/// Usage information for Anthropic-compatible responses
+#[derive(Clone, Debug, PartialEq, Serialize)]
+pub struct AnthropicUsage {
+    pub input_tokens: u32,
+    pub output_tokens: u32,
+}
+
+/// Usage information for Anthropic-compatible streaming responses.
+/// Some fields may be omitted in intermediate chunks.
+#[derive(Clone, Debug, PartialEq, Serialize)]
+pub struct AnthropicStreamingUsage {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub output_tokens: Option<u32>,
+}
+
+impl From<AnthropicUsage> for AnthropicStreamingUsage {
+    fn from(usage: AnthropicUsage) -> Self {
+        Self {
+            input_tokens: Some(usage.input_tokens),
+            output_tokens: Some(usage.output_tokens),
+        }
+    }
+}
diff --git a/tensorzero-core/src/endpoints/mod.rs b/tensorzero-core/src/endpoints/mod.rs
index 5e14533040..c445c862e6 100644
--- a/tensorzero-core/src/endpoints/mod.rs
+++ b/tensorzero-core/src/endpoints/mod.rs
@@ -5,6 +5,7 @@
 use crate::utils::gateway::AppStateData;
 use axum::routing::MethodRouter;
 use std::convert::Infallible;
+pub mod anthropic_compatible;
 pub mod batch_inference;
 pub mod datasets;
 pub mod embeddings;
diff --git a/tensorzero-core/src/error/mod.rs b/tensorzero-core/src/error/mod.rs
index 7aa2b5a677..caa506a365 100644
--- a/tensorzero-core/src/error/mod.rs
+++ b/tensorzero-core/src/error/mod.rs
@@ -425,6 +425,9 @@ pub enum ErrorDetails {
     InvalidOpenAICompatibleRequest {
         message: String,
     },
+    InvalidAnthropicCompatibleRequest {
+        message: String,
+    },
     InvalidProviderConfig {
         message: String,
     },
@@ -698,6 +701,7 @@ impl ErrorDetails {
             ErrorDetails::InvalidModel { .. } => tracing::Level::ERROR,
             ErrorDetails::InvalidModelProvider { .. } => tracing::Level::ERROR,
             ErrorDetails::InvalidOpenAICompatibleRequest { .. } => tracing::Level::ERROR,
+            ErrorDetails::InvalidAnthropicCompatibleRequest { .. } => tracing::Level::ERROR,
             ErrorDetails::InvalidProviderConfig { .. } => tracing::Level::ERROR,
             ErrorDetails::InvalidRequest { .. } => tracing::Level::WARN,
             ErrorDetails::InvalidTemplatePath => tracing::Level::ERROR,
@@ -856,6 +860,7 @@ impl ErrorDetails {
             ErrorDetails::InvalidModel { .. } => StatusCode::INTERNAL_SERVER_ERROR,
             ErrorDetails::InvalidModelProvider { .. } => StatusCode::INTERNAL_SERVER_ERROR,
             ErrorDetails::InvalidOpenAICompatibleRequest { .. } => StatusCode::BAD_REQUEST,
+            ErrorDetails::InvalidAnthropicCompatibleRequest { .. } => StatusCode::BAD_REQUEST,
             ErrorDetails::InvalidProviderConfig { .. } => StatusCode::INTERNAL_SERVER_ERROR,
             ErrorDetails::InvalidRequest { .. } => StatusCode::BAD_REQUEST,
             ErrorDetails::InvalidRenderedStoredInference { ..
} => StatusCode::BAD_REQUEST, @@ -1385,6 +1390,12 @@ impl std::fmt::Display for ErrorDetails { f, "Invalid request to OpenAI-compatible endpoint: {message}" ), + ErrorDetails::InvalidAnthropicCompatibleRequest { message } => { + write!( + f, + "Invalid request to Anthropic-compatible endpoint: {message}" + ) + } ErrorDetails::InvalidProviderConfig { message } => write!(f, "{message}"), ErrorDetails::InvalidRequest { message } => write!(f, "{message}"), ErrorDetails::InvalidRenderedStoredInference { message } => { diff --git a/tensorzero-core/src/utils/gateway.rs b/tensorzero-core/src/utils/gateway.rs index c08f9ae873..9bc763222e 100644 --- a/tensorzero-core/src/utils/gateway.rs +++ b/tensorzero-core/src/utils/gateway.rs @@ -23,6 +23,7 @@ use crate::db::feedback::FeedbackQueries; use crate::db::postgres::PostgresConnectionInfo; use crate::db::valkey::ValkeyConnectionInfo; use crate::endpoints; +use crate::endpoints::anthropic_compatible::RouterExt as AnthropicRouterExt; use crate::endpoints::openai_compatible::RouterExt; use crate::error::{Error, ErrorDetails}; use crate::howdy::setup_howdy; @@ -689,6 +690,7 @@ pub async fn start_openai_compatible_gateway( let router = Router::new() .register_openai_compatible_routes() + .register_anthropic_compatible_routes() .fallback(endpoints::fallback::handle_404) .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) // increase the default body limit from 2MB to 100MB .layer(axum::middleware::from_fn_with_state( diff --git a/tensorzero-core/tests/e2e/anthropic_compatible.rs b/tensorzero-core/tests/e2e/anthropic_compatible.rs new file mode 100644 index 0000000000..2f1e7e7594 --- /dev/null +++ b/tensorzero-core/tests/e2e/anthropic_compatible.rs @@ -0,0 +1,1081 @@ +#![expect(clippy::print_stdout)] + +use std::collections::HashSet; + +use axum::extract::State; +use http_body_util::BodyExt; +use reqwest::{Client, StatusCode}; +use serde_json::{Value, json}; +use uuid::Uuid; + +use crate::common::get_gateway_endpoint; + +use tensorzero_core::endpoints::anthropic_compatible::messages::messages_handler; +use tensorzero_core::{ + db::clickhouse::test_helpers::{ + get_clickhouse, select_chat_inference_clickhouse, select_model_inference_clickhouse, + }, + utils::gateway::StructuredJson, +}; + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_basic_request() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + let episode_id = Uuid::now_v7(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::basic_test_no_system_schema", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "What is the capital of Japan?" 
+ } + ], + "tensorzero::tags": { + "foo": "bar" + }, + "tensorzero::episode_id": episode_id.to_string(), + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + // Check Response is OK + assert_eq!(response.status(), StatusCode::OK); + let response_json = response.into_body().collect().await.unwrap().to_bytes(); + let response_json: Value = serde_json::from_slice(&response_json).unwrap(); + println!("response: {response_json:?}"); + + // Check basic response structure + assert_eq!( + response_json.get("type").unwrap().as_str().unwrap(), + "message" + ); + assert_eq!( + response_json.get("role").unwrap().as_str().unwrap(), + "assistant" + ); + + // Check content array + let content = response_json.get("content").unwrap().as_array().unwrap(); + assert_eq!(content.len(), 1); + let first_block = content.first().unwrap(); + assert_eq!(first_block.get("type").unwrap().as_str().unwrap(), "text"); + let text = first_block.get("text").unwrap().as_str().unwrap(); + assert_eq!( + text, + "Megumin gleefully chanted her spell, unleashing a thunderous explosion that lit up the sky and left a massive crater in its wake." + ); + + // Check model prefix + let response_model = response_json.get("model").unwrap().as_str().unwrap(); + assert_eq!( + response_model, + "tensorzero::function_name::basic_test_no_system_schema::variant_name::test" + ); + + // Check stop_reason + let stop_reason = response_json.get("stop_reason").unwrap().as_str().unwrap(); + assert_eq!(stop_reason, "end_turn"); + + // Get inference_id + let inference_id: Uuid = response_json + .get("id") + .unwrap() + .as_str() + .unwrap() + .parse() + .unwrap(); + + // Sleep for 1 second to allow time for data to be inserted into ClickHouse + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + // Check ClickHouse + let clickhouse = get_clickhouse().await; + + // Check Inference table + let result = select_chat_inference_clickhouse(&clickhouse, inference_id) + .await + .unwrap(); + let id = result.get("id").unwrap().as_str().unwrap(); + let id_uuid = Uuid::parse_str(id).unwrap(); + assert_eq!(id_uuid, inference_id); + let function_name = result.get("function_name").unwrap().as_str().unwrap(); + assert_eq!(function_name, "basic_test_no_system_schema"); + + // Check tags + let tags = result.get("tags").unwrap().as_object().unwrap(); + assert_eq!(tags.get("foo").unwrap().as_str().unwrap(), "bar"); + assert_eq!(tags.len(), 1); + + // Check variant name + let variant_name = result.get("variant_name").unwrap().as_str().unwrap(); + assert_eq!(variant_name, "test"); + + // Check the ModelInference Table + let result = select_model_inference_clickhouse(&clickhouse, inference_id) + .await + .unwrap(); + println!("ModelInference result: {result:?}"); + let inference_id_result = result.get("inference_id").unwrap().as_str().unwrap(); + let inference_id_result = Uuid::parse_str(inference_id_result).unwrap(); + assert_eq!(inference_id_result, inference_id); + let model_name = result.get("model_name").unwrap().as_str().unwrap(); + assert_eq!(model_name, "test"); + let finish_reason = result.get("finish_reason").unwrap().as_str().unwrap(); + assert_eq!(finish_reason, "stop"); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_system_prompt_string() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": 
"tensorzero::function_name::basic_test_no_system_schema", + "max_tokens": 100, + "system": "You are a helpful assistant named TensorBot.", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let response_json = response.into_body().collect().await.unwrap().to_bytes(); + let response_json: Value = serde_json::from_slice(&response_json).unwrap(); + println!("response: {response_json:?}"); + + // Just check that we got a valid response + assert_eq!(response_json.get("role").unwrap().as_str().unwrap(), "assistant"); + let content = response_json.get("content").unwrap().as_array().unwrap(); + assert!(!content.is_empty()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_system_prompt_array() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::basic_test_no_system_schema", + "max_tokens": 100, + "system": [ + { + "type": "text", + "text": "You are a helpful assistant." + } + ], + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let response_json = response.into_body().collect().await.unwrap().to_bytes(); + let response_json: Value = serde_json::from_slice(&response_json).unwrap(); + + // Just check that we got a valid response + assert_eq!(response_json.get("role").unwrap().as_str().unwrap(), "assistant"); + let content = response_json.get("content").unwrap().as_array().unwrap(); + assert!(!content.is_empty()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_missing_max_tokens() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let error = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::basic_test_no_system_schema", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + })) + .unwrap(), + ), + ) + .await + .unwrap_err(); + + assert_eq!(error.status(), StatusCode::BAD_REQUEST); + assert!(error + .to_string() + .contains("max_tokens")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_max_tokens_zero() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let error = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::basic_test_no_system_schema", + "max_tokens": 0, + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + })) + .unwrap(), + ), + ) + .await + .unwrap_err(); + + assert_eq!(error.status(), StatusCode::BAD_REQUEST); + assert!(error + .to_string() + .contains("max_tokens")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_invalid_model_prefix() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let error = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + 
"model": "invalid::model::name", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + })) + .unwrap(), + ), + ) + .await + .unwrap_err(); + + // Should get an error about invalid model prefix + assert!(error.to_string().contains("model") || error.to_string().contains("function")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_model_name_target() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::model_name::openai::gpt-4o-mini", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "Say 'test passed'" + } + ] + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let response_json = response.into_body().collect().await.unwrap().to_bytes(); + let response_json: Value = serde_json::from_slice(&response_json).unwrap(); + + // Check model prefix + let response_model = response_json.get("model").unwrap().as_str().unwrap(); + assert_eq!(response_model, "tensorzero::model_name::openai::gpt-4o-mini"); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_streaming() { + use futures::StreamExt; + use reqwest_eventsource::{Event, RequestBuilderExt}; + + let client = Client::new(); + let episode_id = Uuid::now_v7(); + + let body = json!({ + "model": "tensorzero::function_name::basic_test_no_system_schema", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "What's the reason for why we use AC not DC?" + } + ], + "stream": true, + "tensorzero::episode_id": episode_id.to_string(), + }); + + let mut response = client + .post(get_gateway_endpoint("/anthropic/v1/messages")) + .header("Content-Type", "application/json") + .json(&body) + .eventsource() + .unwrap(); + + let mut events = vec![]; + let mut found_message_start = false; + let mut found_message_stop = false; + let mut found_content_block_start = false; + let mut found_content_block_delta = false; + let mut found_message_delta = false; + + while let Some(event) = response.next().await { + let event = event.unwrap(); + match event { + Event::Open => continue, + Event::Message(message) => { + if message.data == "[DONE]" { + break; + } + events.push((message.event, message.data)); + } + } + } + + // Check we got the expected event types + for (event_type, data) in &events { + let parsed: Value = serde_json::from_str(data).unwrap(); + println!("Event type: {event_type}, Data: {parsed}"); + + match event_type.as_str() { + "message_start" => { + found_message_start = true; + assert!(parsed.get("message").is_some()); + assert!(parsed["message"].get("id").is_some()); + assert_eq!(parsed["message"]["type"].as_str().unwrap(), "message"); + assert_eq!(parsed["message"]["role"].as_str().unwrap(), "assistant"); + } + "content_block_start" => { + found_content_block_start = true; + assert!(parsed.get("content_block").is_some()); + assert!(parsed.get("index").is_some()); + } + "content_block_delta" => { + found_content_block_delta = true; + assert!(parsed.get("delta").is_some()); + assert!(parsed.get("index").is_some()); + } + "content_block_stop" => { + assert!(parsed.get("index").is_some()); + } + "message_delta" => { + found_message_delta = true; + assert!(parsed.get("delta").is_some()); + assert!(parsed["delta"].get("stop_reason").is_some()); + } 
+ "message_stop" => { + found_message_stop = true; + } + _ => {} + } + } + + assert!(found_message_start, "Should have message_start event"); + assert!(found_content_block_start, "Should have content_block_start event"); + assert!(found_content_block_delta, "Should have content_block_delta event"); + assert!(found_message_delta, "Should have message_delta event"); + assert!(found_message_stop, "Should have message_stop event"); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_tool_use() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + let episode_id = Uuid::now_v7(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::weather_helper", + "max_tokens": 1000, + "messages": [ + { + "role": "user", + "content": "What's the weather in Tokyo?" + } + ], + "tools": [ + { + "name": "get_temperature", + "description": "Get the current temperature in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + } + } + ], + "tensorzero::episode_id": episode_id.to_string(), + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let response_json = response.into_body().collect().await.unwrap().to_bytes(); + let response_json: Value = serde_json::from_slice(&response_json).unwrap(); + println!("response: {response_json:?}"); + + // Check that we got a tool_use block + let content = response_json.get("content").unwrap().as_array().unwrap(); + let tool_use_block = content + .iter() + .find(|block| block.get("type").unwrap().as_str().unwrap() == "tool_use"); + assert!(tool_use_block.is_some()); + + let tool_use = tool_use_block.unwrap(); + assert_eq!(tool_use.get("name").unwrap().as_str().unwrap(), "get_temperature"); + assert!(tool_use.get("id").is_some()); + assert!(tool_use.get("input").is_some()); + + // Check stop_reason is tool_use + assert_eq!( + response_json.get("stop_reason").unwrap().as_str().unwrap(), + "tool_use" + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_streaming_tool_use() { + use futures::StreamExt; + use reqwest_eventsource::{Event, RequestBuilderExt}; + + let client = Client::new(); + let episode_id = Uuid::now_v7(); + + let body = json!({ + "model": "tensorzero::function_name::weather_helper", + "max_tokens": 1000, + "messages": [ + { + "role": "user", + "content": "What's the weather in Tokyo?" + } + ], + "tools": [ + { + "name": "get_temperature", + "description": "Get the current temperature in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA" + } + }, + "required": ["location"] + } + } + ], + "stream": true, + "tensorzero::episode_id": episode_id.to_string(), + }); + + let mut response = client + .post(get_gateway_endpoint("/anthropic/v1/messages")) + .header("Content-Type", "application/json") + .json(&body) + .eventsource() + .unwrap(); + + let mut found_tool_use_start = false; + let mut found_tool_use_delta = false; + + while let Some(event) = response.next().await { + let event = event.unwrap(); + match event { + Event::Open => continue, + Event::Message(message) => { + if message.data == "[DONE]" { + break; + } + let parsed: Value = serde_json::from_str(&message.data).unwrap(); + println!("Event: {} - Data: {}", message.event, parsed); + + if message.event == "content_block_start" { + if let Some(content_block) = parsed.get("content_block") { + if content_block.get("type").unwrap().as_str().unwrap() == "tool_use" { + found_tool_use_start = true; + } + } + } + if message.event == "content_block_delta" { + if let Some(delta) = parsed.get("delta") { + if delta.get("type").unwrap().as_str().unwrap() == "input_json_delta" { + found_tool_use_delta = true; + } + } + } + } + } + } + + assert!(found_tool_use_start, "Should have tool_use content_block_start event"); + assert!(found_tool_use_delta, "Should have input_json_delta event"); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_tool_result() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + let episode_id = Uuid::now_v7(); + + // First, get a tool use + let response1 = messages_handler( + State(state.clone()), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::weather_helper", + "max_tokens": 1000, + "messages": [ + { + "role": "user", + "content": "What's the weather in Tokyo?" + } + ], + "tools": [ + { + "name": "get_temperature", + "description": "Get the current temperature in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + } + } + ], + "tensorzero::episode_id": episode_id.to_string(), + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + let response1_json = response1.into_body().collect().await.unwrap().to_bytes(); + let response1_json: Value = serde_json::from_slice(&response1_json).unwrap(); + + // Extract the tool_use_id + let content = response1_json.get("content").unwrap().as_array().unwrap(); + let tool_use_block = content + .iter() + .find(|block| block.get("type").unwrap().as_str().unwrap() == "tool_use") + .unwrap(); + let tool_use_id = tool_use_block.get("id").unwrap().as_str().unwrap(); + + // Now send the tool result + let response2 = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::weather_helper", + "max_tokens": 1000, + "messages": [ + { + "role": "user", + "content": "What's the weather in Tokyo?" 
+ }, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": tool_use_id, + "name": "get_temperature", + "input": {"location": "Tokyo"} + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": "The temperature in Tokyo is 22°C" + } + ] + } + ], + "tools": [ + { + "name": "get_temperature", + "description": "Get the current temperature in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + } + } + ], + "tensorzero::episode_id": episode_id.to_string(), + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response2.status(), StatusCode::OK); + let response2_json = response2.into_body().collect().await.unwrap().to_bytes(); + let response2_json: Value = serde_json::from_slice(&response2_json).unwrap(); + println!("response2: {response2_json:?}"); + + // Check that we got a text response (not a tool_use) + let content2 = response2_json.get("content").unwrap().as_array().unwrap(); + assert!(content2 + .iter() + .any(|block| block.get("type").unwrap().as_str().unwrap() == "text")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_tool_choice_auto() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::weather_helper", + "max_tokens": 1000, + "messages": [ + { + "role": "user", + "content": "What's the weather in Tokyo?" + } + ], + "tools": [ + { + "name": "get_temperature", + "description": "Get the current temperature in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + } + } + ], + "tool_choice": "auto" + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_tool_choice_any() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::weather_helper", + "max_tokens": 1000, + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "tools": [ + { + "name": "get_temperature", + "description": "Get the current temperature in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA" + } + }, + "required": ["location"] + } + } + ], + "tool_choice": "any" + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + // With "any", it should force a tool call even if not needed + let response_json = response.into_body().collect().await.unwrap().to_bytes(); + let response_json: Value = serde_json::from_slice(&response_json).unwrap(); + + let content = response_json.get("content").unwrap().as_array().unwrap(); + let tool_use_block = content + .iter() + .find(|block| block.get("type").unwrap().as_str().unwrap() == "tool_use"); + assert!(tool_use_block.is_some()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_tool_choice_specific() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::weather_helper", + "max_tokens": 1000, + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "tools": [ + { + "name": "get_temperature", + "description": "Get the current temperature in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + } + } + ], + "tool_choice": { + "type": "tool", + "name": "get_temperature" + } + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + // Should use the specific tool + let response_json = response.into_body().collect().await.unwrap().to_bytes(); + let response_json: Value = serde_json::from_slice(&response_json).unwrap(); + + let content = response_json.get("content").unwrap().as_array().unwrap(); + let tool_use_block = content + .iter() + .find(|block| block.get("type").unwrap().as_str().unwrap() == "tool_use"); + assert!(tool_use_block.is_some()); + + let tool_use = tool_use_block.unwrap(); + assert_eq!( + tool_use.get("name").unwrap().as_str().unwrap(), + "get_temperature" + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_temperature() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::basic_test_no_system_schema", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "temperature": 0.5 + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_top_p() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::basic_test_no_system_schema", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "top_p": 0.9 + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_top_k() { 
+ let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::basic_test_no_system_schema", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "top_k": 40 + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_stop_sequences() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::basic_test_no_system_schema", + "max_tokens": 1000, + "messages": [ + { + "role": "user", + "content": "Count from 1 to 10" + } + ], + "stop_sequences": ["5"] + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let response_json = response.into_body().collect().await.unwrap().to_bytes(); + let response_json: Value = serde_json::from_slice(&response_json).unwrap(); + + // Check stop_reason + let stop_reason = response_json.get("stop_reason").unwrap().as_str().unwrap(); + assert_eq!(stop_reason, "stop_sequence"); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_usage() { + let client = tensorzero::test_helpers::make_embedded_gateway().await; + let state = client.get_app_state_data().unwrap().clone(); + + let response = messages_handler( + State(state), + None, + StructuredJson( + serde_json::from_value(serde_json::json!({ + "model": "tensorzero::function_name::basic_test_no_system_schema", + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + })) + .unwrap(), + ), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let response_json = response.into_body().collect().await.unwrap().to_bytes(); + let response_json: Value = serde_json::from_slice(&response_json).unwrap(); + + // Check usage is present + let usage = response_json.get("usage").unwrap(); + assert!(usage.get("input_tokens").is_some()); + assert!(usage.get("output_tokens").is_some()); + + let input_tokens = usage.get("input_tokens").unwrap().as_u64().unwrap(); + assert!(input_tokens > 0); + + let output_tokens = usage.get("output_tokens").unwrap().as_u64().unwrap(); + assert!(output_tokens > 0); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_anthropic_compatible_models() { + use reqwest::Client; + + let client = tensorzero::test_helpers::make_embedded_gateway().await; + + let response = Client::new() + .get(get_gateway_endpoint("/anthropic/v1/models")) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let response_json: Value = response.json().await.unwrap(); + + // Check response structure + assert_eq!(response_json.get("object").unwrap().as_str().unwrap(), "list"); + + let data = response_json.get("data").unwrap().as_array().unwrap(); + assert!(!data.is_empty(), "Should return at least one model"); + + // Check that each model has the required fields + for model in data { + assert!(model.get("id").is_some(), "Model should have 'id' field"); + assert!(model.get("name").is_some(), "Model should have 'name' field"); + assert_eq!( 
+ model.get("type").unwrap().as_str().unwrap(), + "model", + "Model type should be 'model'" + ); + } + + // Check that function_name models are included + let has_function_model = data + .iter() + .any(|m| { + m.get("id") + .and_then(|id| id.as_str()) + .map(|id| id.starts_with("tensorzero::function_name::")) + .unwrap_or(false) + }); + assert!( + has_function_model, + "Should include at least one function_name model" + ); +} +
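+
+// For manual testing, the messages endpoint can also be exercised against a running
+// gateway with curl (an illustrative sketch; assumes the gateway's default bind address
+// and the function name from the e2e fixtures above):
+//
+//   curl http://localhost:3000/anthropic/v1/messages \
+//     -H "Content-Type: application/json" \
+//     -d '{
+//           "model": "tensorzero::function_name::basic_test_no_system_schema",
+//           "max_tokens": 100,
+//           "messages": [{"role": "user", "content": "Hello"}]
+//         }'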