From aba171f1a7b0f36836bd56b62d6149a429071dab Mon Sep 17 00:00:00 2001
From: Viraj Mehta
Date: Mon, 8 Dec 2025 18:13:52 -0500
Subject: [PATCH 01/11] added get inference to the tensorzero client

---
 ui/app/utils/tensorzero/tensorzero.ts | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/ui/app/utils/tensorzero/tensorzero.ts b/ui/app/utils/tensorzero/tensorzero.ts
index 761a33b823..bc76ddb269 100644
--- a/ui/app/utils/tensorzero/tensorzero.ts
+++ b/ui/app/utils/tensorzero/tensorzero.ts
@@ -19,6 +19,7 @@ import type {
   DeleteDatapointsResponse,
   GetDatapointsRequest,
   GetDatapointsResponse,
+  GetInferencesRequest,
   GetInferencesResponse,
   ListDatapointsRequest,
   ListDatasetsResponse,
@@ -633,6 +634,27 @@ export class TensorZeroClient {
     return (await response.json()) as GetInferencesResponse;
   }
 
+  /**
+   * Retrieves specific inferences by their IDs.
+   * Uses the public v1 API endpoint.
+   * @param request - The get inferences request containing IDs and optional filters
+   * @returns A promise that resolves with the inferences response
+   * @throws Error if the request fails
+   */
+  async getInferences(
+    request: GetInferencesRequest,
+  ): Promise<GetInferencesResponse> {
+    const response = await this.fetch("/v1/inferences/get_inferences", {
+      method: "POST",
+      body: JSON.stringify(request),
+    });
+    if (!response.ok) {
+      const message = await this.getErrorText(response);
+      this.handleHttpError({ message, response });
+    }
+    return (await response.json()) as GetInferencesResponse;
+  }
+
   /**
    * Fetches the gateway configuration for the UI.
    * @returns A promise that resolves with the UiConfig object

From 03c9f59b0111ec790fee241c78109cfa3cd39e3f Mon Sep 17 00:00:00 2001
From: Viraj Mehta
Date: Mon, 8 Dec 2025 19:26:23 -0500
Subject: [PATCH 02/11] add extra body to stored inference types

---
 .gitignore                                    |   1 +
 clients/python/tensorzero/generated_types.py  |   5 +
 .../tensorzero-node/lib/bindings/ExtraBody.ts | 116 ++++++++++++++++++
 .../lib/bindings/StoredChatInference.ts       |   2 +
 .../lib/bindings/StoredJsonInference.ts       |   2 +
 .../bindings/UnfilteredInferenceExtraBody.ts  |   8 ++
 .../tensorzero-node/lib/bindings/index.ts     |   2 +
 .../src/db/clickhouse/inference_queries.rs    |   4 +
 .../src/db/clickhouse/query_builder/mod.rs    |  26 ++++
 tensorzero-core/src/db/inferences.rs          |   9 +-
 .../datasets/v1/create_from_inferences.rs     |   1 +
 .../stored_inferences/v1/get_inferences.rs    |   1 +
 .../src/inference/types/extra_body.rs         |   2 +-
 tensorzero-core/src/stored_inference.rs       |  12 ++
 .../stored_inferences/get_inferences.rs       |  63 ++++++++++
 .../tests/e2e/render_inferences.rs            |   8 ++
 16 files changed, 260 insertions(+), 2 deletions(-)
 create mode 100644 internal/tensorzero-node/lib/bindings/ExtraBody.ts
 create mode 100644 internal/tensorzero-node/lib/bindings/UnfilteredInferenceExtraBody.ts

diff --git a/.gitignore b/.gitignore
index 1831eba607..bd523ba340 100644
--- a/.gitignore
+++ b/.gitignore
@@ -57,3 +57,4 @@ tensorzero-optimizers/bindings
 clients/rust/bindings/
 internal/tensorzero-node/bindings/
 *.junit.xml
+tensorzero-core/schemas
diff --git a/clients/python/tensorzero/generated_types.py b/clients/python/tensorzero/generated_types.py
index 8f9db0e225..5d5d32f8c0 100644
--- a/clients/python/tensorzero/generated_types.py
+++ b/clients/python/tensorzero/generated_types.py
@@ -840,6 +840,9 @@ class ToolResult:
     result: str
 
 
+UnfilteredInferenceExtraBody = list[ExtraBody]
+
+
 @dataclass(kw_only=True)
 class Unknown:
     """
@@ -1555,6 +1558,7 @@ class StoredJsonInference:
     timestamp: str
     variant_name: str
     dispreferred_outputs: list[JsonInferenceOutput] | None = 
field(default_factory=lambda: []) + extra_body: UnfilteredInferenceExtraBody | None = field(default_factory=lambda: []) tags: dict[str, str] | None = field(default_factory=lambda: {}) @@ -1899,6 +1903,7 @@ class StoredChatInference: If not provided, all static tools are allowed. """ dispreferred_outputs: list[list[ContentBlockChatOutput]] | None = field(default_factory=lambda: []) + extra_body: UnfilteredInferenceExtraBody | None = field(default_factory=lambda: []) parallel_tool_calls: bool | None = None """ Whether to use parallel tool calls in the inference. Optional. diff --git a/internal/tensorzero-node/lib/bindings/ExtraBody.ts b/internal/tensorzero-node/lib/bindings/ExtraBody.ts new file mode 100644 index 0000000000..7e2886d9b3 --- /dev/null +++ b/internal/tensorzero-node/lib/bindings/ExtraBody.ts @@ -0,0 +1,116 @@ +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { JsonValue } from "./serde_json/JsonValue"; + +export type ExtraBody = + | { + /** + * A fully-qualified model provider name in your configuration (e.g. `tensorzero::model_name::my_model::provider_name::my_provider`) + */ + model_provider_name: string; + /** + * A JSON Pointer to the field to update (e.g. `/enable_agi`) + */ + pointer: string; + /** + * The value to set the field to + */ + value: JsonValue; + } + | { + /** + * A fully-qualified model provider name in your configuration (e.g. `tensorzero::model_name::my_model::provider_name::my_provider`) + */ + model_provider_name: string; + /** + * A JSON Pointer to the field to update (e.g. `/enable_agi`) + */ + pointer: string; + /** + * Set to true to remove the field from the model provider request's body + */ + delete: null; + } + | { + /** + * A variant name in your configuration (e.g. `my_variant`) + */ + variant_name: string; + /** + * A JSON Pointer to the field to update (e.g. `/enable_agi`) + */ + pointer: string; + /** + * The value to set the field to + */ + value: JsonValue; + } + | { + /** + * A variant name in your configuration (e.g. `my_variant`) + */ + variant_name: string; + /** + * A JSON Pointer to the field to update (e.g. `/enable_agi`) + */ + pointer: string; + /** + * Set to true to remove the field from the model provider request's body + */ + delete: null; + } + | { + /** + * A model name in your configuration (e.g. `my_gpt_5`) or a short-hand model name (e.g. `openai::gpt-5`) + */ + model_name: string; + /** + * A provider name for the model you specified (e.g. `my_openai`) + */ + provider_name?: string; + /** + * A JSON Pointer to the field to update (e.g. `/enable_agi`) + */ + pointer: string; + /** + * The value to set the field to + */ + value: JsonValue; + } + | { + /** + * A model name in your configuration (e.g. `my_gpt_5`) or a short-hand model name (e.g. `openai::gpt-5`) + */ + model_name: string; + /** + * A provider name for the model you specified (e.g. `my_openai`) + */ + provider_name?: string; + /** + * A JSON Pointer to the field to update (e.g. `/enable_agi`) + */ + pointer: string; + /** + * Set to true to remove the field from the model provider request's body + */ + delete: null; + } + | { + /** + * A JSON Pointer to the field to update (e.g. `/enable_agi`) + */ + pointer: string; + /** + * The value to set the field to + */ + value: JsonValue; + } + | { + /** + * A JSON Pointer to the field to update (e.g. 
`/enable_agi`) + */ + pointer: string; + /** + * Set to true to remove the field from the model provider request's body + */ + delete: null; + }; diff --git a/internal/tensorzero-node/lib/bindings/StoredChatInference.ts b/internal/tensorzero-node/lib/bindings/StoredChatInference.ts index c4a71a0f16..283a677f5f 100644 --- a/internal/tensorzero-node/lib/bindings/StoredChatInference.ts +++ b/internal/tensorzero-node/lib/bindings/StoredChatInference.ts @@ -4,6 +4,7 @@ import type { ProviderTool } from "./ProviderTool"; import type { StoredInput } from "./StoredInput"; import type { Tool } from "./Tool"; import type { ToolChoice } from "./ToolChoice"; +import type { UnfilteredInferenceExtraBody } from "./UnfilteredInferenceExtraBody"; /** * Wire variant of StoredChatInference for API responses with Python/TypeScript bindings @@ -18,6 +19,7 @@ export type StoredChatInference = { episode_id: string; inference_id: string; tags: { [key in string]?: string }; + extra_body: UnfilteredInferenceExtraBody; /** * A subset of static tools configured for the function that the inference is allowed to use. Optional. * If not provided, all static tools are allowed. diff --git a/internal/tensorzero-node/lib/bindings/StoredJsonInference.ts b/internal/tensorzero-node/lib/bindings/StoredJsonInference.ts index 6938f069f4..2b3fe66973 100644 --- a/internal/tensorzero-node/lib/bindings/StoredJsonInference.ts +++ b/internal/tensorzero-node/lib/bindings/StoredJsonInference.ts @@ -1,6 +1,7 @@ // This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. import type { JsonInferenceOutput } from "./JsonInferenceOutput"; import type { StoredInput } from "./StoredInput"; +import type { UnfilteredInferenceExtraBody } from "./UnfilteredInferenceExtraBody"; import type { JsonValue } from "./serde_json/JsonValue"; export type StoredJsonInference = { @@ -14,4 +15,5 @@ export type StoredJsonInference = { inference_id: string; output_schema: JsonValue; tags: { [key in string]?: string }; + extra_body: UnfilteredInferenceExtraBody; }; diff --git a/internal/tensorzero-node/lib/bindings/UnfilteredInferenceExtraBody.ts b/internal/tensorzero-node/lib/bindings/UnfilteredInferenceExtraBody.ts new file mode 100644 index 0000000000..39e11819e2 --- /dev/null +++ b/internal/tensorzero-node/lib/bindings/UnfilteredInferenceExtraBody.ts @@ -0,0 +1,8 @@ +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { ExtraBody } from "./ExtraBody"; + +/** + * The 'InferenceExtraBody' options provided directly in an inference request. 
+ * These have not yet been filtered by variant name
+ */
+export type UnfilteredInferenceExtraBody = { extra_body: Array<ExtraBody> };
diff --git a/internal/tensorzero-node/lib/bindings/index.ts b/internal/tensorzero-node/lib/bindings/index.ts
index 76f790ee05..1c7bcb5762 100644
--- a/internal/tensorzero-node/lib/bindings/index.ts
+++ b/internal/tensorzero-node/lib/bindings/index.ts
@@ -89,6 +89,7 @@ export * from "./EvaluatorConfig";
 export * from "./ExactMatchConfig";
 export * from "./ExperimentationConfig";
 export * from "./ExportConfig";
+export * from "./ExtraBody";
 export * from "./ExtraHeaderKind";
 export * from "./FeedbackBounds";
 export * from "./FeedbackBoundsByType";
@@ -315,6 +316,7 @@ export * from "./ToolResult";
 export * from "./TrackAndStopConfig";
 export * from "./TrackAndStopState";
 export * from "./UiConfig";
+export * from "./UnfilteredInferenceExtraBody";
 export * from "./UniformConfig";
 export * from "./UninitializedBestOfNEvaluatorConfig";
 export * from "./UninitializedBestOfNSamplingConfig";
diff --git a/tensorzero-core/src/db/clickhouse/inference_queries.rs b/tensorzero-core/src/db/clickhouse/inference_queries.rs
index 8ccfcfbe74..e144dcddef 100644
--- a/tensorzero-core/src/db/clickhouse/inference_queries.rs
+++ b/tensorzero-core/src/db/clickhouse/inference_queries.rs
@@ -435,6 +435,7 @@ fn generate_single_table_query_for_type(
     }
 
     select_clauses.push("i.variant_name as variant_name".to_string());
+    select_clauses.push("i.extra_body as extra_body".to_string());
 
     let mut where_clauses: Vec<String> = Vec::new();
 
@@ -684,6 +685,7 @@ mod tests {
     i.tool_choice as tool_choice,
     i.parallel_tool_calls as parallel_tool_calls,
     i.variant_name as variant_name,
+    i.extra_body as extra_body,
     i.output as output
 FROM
     ChatInference AS i
@@ -708,6 +710,7 @@ mod tests {
     NULL as tool_choice,
     NULL as parallel_tool_calls,
     i.variant_name as variant_name,
+    i.extra_body as extra_body,
     i.output as output
 FROM
     JsonInference AS i
@@ -754,6 +757,7 @@ mod tests {
     NULL as tool_choice,
     NULL as parallel_tool_calls,
     i.variant_name as variant_name,
+    i.extra_body as extra_body,
     i.output as output
 FROM
     JsonInference AS i
diff --git a/tensorzero-core/src/db/clickhouse/query_builder/mod.rs b/tensorzero-core/src/db/clickhouse/query_builder/mod.rs
index 0205d50980..d79cda595d 100644
--- a/tensorzero-core/src/db/clickhouse/query_builder/mod.rs
+++ b/tensorzero-core/src/db/clickhouse/query_builder/mod.rs
@@ -532,6 +532,7 @@ SELECT
     NULL as tool_choice,
     NULL as parallel_tool_calls,
     i.variant_name as variant_name,
+    i.extra_body as extra_body,
     i.output as output
 FROM
     JsonInference AS i
@@ -580,6 +581,7 @@ SELECT
     i.tool_choice as tool_choice,
     i.parallel_tool_calls as parallel_tool_calls,
     i.variant_name as variant_name,
+    i.extra_body as extra_body,
     i.output as output
 FROM
     ChatInference AS i
@@ -634,6 +636,7 @@ SELECT
     NULL as tool_choice,
     NULL as parallel_tool_calls,
     i.variant_name as variant_name,
+    i.extra_body as extra_body,
     i.output as output
 FROM
     JsonInference AS i
@@ -735,6 +738,7 @@ SELECT
     NULL as tool_choice,
     NULL as parallel_tool_calls,
     i.variant_name as variant_name,
+    i.extra_body as extra_body,
     demo_f.value AS output,
     [i.output] as dispreferred_outputs
 FROM
@@ -790,6 +794,7 @@ SELECT
     NULL as tool_choice,
     NULL as parallel_tool_calls,
     i.variant_name as variant_name,
+    i.extra_body as extra_body,
     i.output as output
 FROM
     JsonInference AS i
@@ -859,6 +864,7 @@ SELECT
     NULL as tool_choice,
     NULL as parallel_tool_calls,
     i.variant_name as variant_name,
+    i.extra_body as extra_body,
     i.output as output
 FROM
     JsonInference AS i
@@
-944,6 +950,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1048,6 +1055,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1191,6 +1199,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1284,6 +1293,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1380,6 +1390,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1428,6 +1439,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1482,6 +1494,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1552,6 +1565,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1623,6 +1637,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1685,6 +1700,7 @@ SELECT i.tool_choice as tool_choice, i.parallel_tool_calls as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM ChatInference AS i @@ -1756,6 +1772,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1835,6 +1852,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -1925,6 +1943,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, demo_f.value AS output, [i.output] as dispreferred_outputs FROM @@ -2023,6 +2042,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -2091,6 +2111,7 @@ SELECT i.tool_choice as tool_choice, i.parallel_tool_calls as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM ChatInference AS i @@ -2164,6 +2185,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -2572,6 +2594,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -2629,6 +2652,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM JsonInference AS i @@ -2704,6 +2728,7 @@ SELECT NULL as tool_choice, NULL as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output FROM 
JsonInference AS i @@ -2773,6 +2798,7 @@ SELECT i.tool_choice as tool_choice, i.parallel_tool_calls as parallel_tool_calls, i.variant_name as variant_name, + i.extra_body as extra_body, i.output as output, countSubstringsCaseInsensitiveUTF8(i.input, {p1:String}) as input_term_frequency, countSubstringsCaseInsensitiveUTF8(i.output, {p1:String}) as output_term_frequency, diff --git a/tensorzero-core/src/db/inferences.rs b/tensorzero-core/src/db/inferences.rs index 9419288984..ac4b9c516b 100644 --- a/tensorzero-core/src/db/inferences.rs +++ b/tensorzero-core/src/db/inferences.rs @@ -13,8 +13,9 @@ use mockall::automock; use crate::config::Config; use crate::db::clickhouse::query_builder::{InferenceFilter, OrderBy}; use crate::error::{Error, ErrorDetails}; +use crate::inference::types::extra_body::UnfilteredInferenceExtraBody; use crate::inference::types::{ContentBlockChatOutput, JsonInferenceOutput, StoredInput}; -use crate::serde_util::deserialize_json_string; +use crate::serde_util::{deserialize_defaulted_json_string, deserialize_json_string}; use crate::stored_inference::{ StoredChatInferenceDatabase, StoredInferenceDatabase, StoredJsonInference, }; @@ -38,6 +39,8 @@ pub(super) struct ClickHouseStoredChatInferenceWithDispreferredOutputs { #[serde(flatten, deserialize_with = "deserialize_tool_info")] pub tool_params: ToolCallConfigDatabaseInsert, pub tags: HashMap, + #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] + pub extra_body: UnfilteredInferenceExtraBody, } impl TryFrom for StoredChatInferenceDatabase { @@ -69,6 +72,7 @@ impl TryFrom for StoredCha tool_params: value.tool_params, tags: value.tags, timestamp: value.timestamp, + extra_body: value.extra_body, }) } } @@ -89,6 +93,8 @@ pub(super) struct ClickHouseStoredJsonInferenceWithDispreferredOutputs { #[serde(deserialize_with = "deserialize_json_string")] pub output_schema: Value, pub tags: HashMap, + #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] + pub extra_body: UnfilteredInferenceExtraBody, } impl TryFrom for StoredJsonInference { @@ -119,6 +125,7 @@ impl TryFrom for StoredJso output_schema: value.output_schema, tags: value.tags, timestamp: value.timestamp, + extra_body: value.extra_body, }) } } diff --git a/tensorzero-core/src/endpoints/datasets/v1/create_from_inferences.rs b/tensorzero-core/src/endpoints/datasets/v1/create_from_inferences.rs index 2b3fa9737f..d2401faa6e 100644 --- a/tensorzero-core/src/endpoints/datasets/v1/create_from_inferences.rs +++ b/tensorzero-core/src/endpoints/datasets/v1/create_from_inferences.rs @@ -206,6 +206,7 @@ mod tests { inference_id: id, tool_params: ToolCallConfigDatabaseInsert::default(), tags: HashMap::new(), + extra_body: Default::default(), }) } diff --git a/tensorzero-core/src/endpoints/stored_inferences/v1/get_inferences.rs b/tensorzero-core/src/endpoints/stored_inferences/v1/get_inferences.rs index bbbcc4442d..e5ea06dfbc 100644 --- a/tensorzero-core/src/endpoints/stored_inferences/v1/get_inferences.rs +++ b/tensorzero-core/src/endpoints/stored_inferences/v1/get_inferences.rs @@ -144,6 +144,7 @@ mod tests { inference_id: id, tool_params: ToolCallConfigDatabaseInsert::default(), tags: HashMap::new(), + extra_body: Default::default(), }) } diff --git a/tensorzero-core/src/inference/types/extra_body.rs b/tensorzero-core/src/inference/types/extra_body.rs index 66547b8401..39f501c6d2 100644 --- a/tensorzero-core/src/inference/types/extra_body.rs +++ b/tensorzero-core/src/inference/types/extra_body.rs @@ -33,7 +33,7 @@ pub enum 
ExtraBodyReplacementKind { /// The 'InferenceExtraBody' options provided directly in an inference request. /// These have not yet been filtered by variant name -#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ts_rs::TS)] +#[derive(Clone, Debug, Default, Deserialize, JsonSchema, PartialEq, Serialize, ts_rs::TS)] #[serde(transparent)] pub struct UnfilteredInferenceExtraBody { extra_body: Vec, diff --git a/tensorzero-core/src/stored_inference.rs b/tensorzero-core/src/stored_inference.rs index e760e3612d..d9812b8d97 100644 --- a/tensorzero-core/src/stored_inference.rs +++ b/tensorzero-core/src/stored_inference.rs @@ -10,6 +10,7 @@ use crate::endpoints::datasets::v1::types::{ }; use crate::error::{Error, ErrorDetails}; use crate::function::FunctionConfig; +use crate::inference::types::extra_body::UnfilteredInferenceExtraBody; #[cfg(feature = "pyo3")] use crate::inference::types::pyo3_helpers::{ content_block_chat_output_to_python, serialize_to_dict, uuid_to_python, @@ -19,6 +20,7 @@ use crate::inference::types::{ ContentBlockChatOutput, JsonInferenceOutput, ModelInput, RequestMessage, ResolvedInput, ResolvedRequestMessage, Text, }; +use crate::serde_util::deserialize_defaulted_json_string; use crate::tool::{ DynamicToolParams, StaticToolConfig, ToolCallConfigDatabaseInsert, deserialize_tool_info, }; @@ -214,6 +216,7 @@ impl StoredChatInference { inference_id: self.inference_id, tool_params, tags: self.tags, + extra_body: self.extra_body, }) } } @@ -261,6 +264,8 @@ pub struct StoredChatInference { pub tool_params: DynamicToolParams, #[serde(default)] pub tags: HashMap, + #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] + pub extra_body: UnfilteredInferenceExtraBody, } impl std::fmt::Display for StoredChatInference { @@ -284,6 +289,7 @@ impl StoredChatInferenceDatabase { inference_id: self.inference_id, tool_params: self.tool_params.into(), tags: self.tags, + extra_body: self.extra_body, } } } @@ -304,6 +310,8 @@ pub struct StoredChatInferenceDatabase { pub tool_params: ToolCallConfigDatabaseInsert, #[serde(default)] pub tags: HashMap, + #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] + pub extra_body: UnfilteredInferenceExtraBody, } impl std::fmt::Display for StoredChatInferenceDatabase { @@ -329,6 +337,8 @@ pub struct StoredJsonInference { pub output_schema: Value, #[serde(default)] pub tags: HashMap, + #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] + pub extra_body: UnfilteredInferenceExtraBody, } impl std::fmt::Display for StoredJsonInference { @@ -826,6 +836,7 @@ mod tests { tags.insert("key2".to_string(), "value2".to_string()); tags }, + extra_body: UnfilteredInferenceExtraBody::default(), } } @@ -862,6 +873,7 @@ mod tests { tags.insert("json_key".to_string(), "json_value".to_string()); tags }, + extra_body: UnfilteredInferenceExtraBody::default(), } } diff --git a/tensorzero-core/tests/e2e/endpoints/stored_inferences/get_inferences.rs b/tensorzero-core/tests/e2e/endpoints/stored_inferences/get_inferences.rs index 99e4f5d8a3..5cf661b85e 100644 --- a/tensorzero-core/tests/e2e/endpoints/stored_inferences/get_inferences.rs +++ b/tensorzero-core/tests/e2e/endpoints/stored_inferences/get_inferences.rs @@ -1167,3 +1167,66 @@ pub async fn test_list_inferences_cursor_with_metric_ordering_fails() { "Request with cursor and metric ordering should fail" ); } + +// Tests for extra_body field + +#[tokio::test(flavor = "multi_thread")] +pub async fn test_get_by_ids_with_extra_body() { + let http_client = 
Client::new(); + + // Create an inference with a nontrivial extra_body + let extra_body_value = json!([ + {"pointer": "/test_field", "value": "test_value"}, + {"pointer": "/nested/field", "value": {"key": "nested_value"}} + ]); + + let inference_payload = json!({ + "function_name": "basic_test", + "variant_name": "test", + "input": { + "system": {"assistant_name": "TestBot"}, + "messages": [{"role": "user", "content": "Hello"}] + }, + "stream": false, + "extra_body": extra_body_value + }); + + // Make the inference request + let inference_response = http_client + .post(get_gateway_endpoint("/inference")) + .json(&inference_payload) + .send() + .await + .unwrap(); + + assert!( + inference_response.status().is_success(), + "Inference request failed: status={:?}", + inference_response.status() + ); + + let inference_json: Value = inference_response.json().await.unwrap(); + let inference_id = Uuid::parse_str(inference_json["inference_id"].as_str().unwrap()).unwrap(); + + // Query the inference back + let res = get_inferences_by_ids(vec![inference_id], InferenceOutputSource::Inference) + .await + .unwrap(); + + assert_eq!(res.len(), 1); + + // Assert the extra_body is correctly returned + let extra_body = &res[0]["extra_body"]; + assert!(extra_body.is_array(), "extra_body should be an array"); + + let extra_body_array = extra_body.as_array().unwrap(); + assert_eq!(extra_body_array.len(), 2); + + // Check the first extra_body entry + assert_eq!(extra_body_array[0]["pointer"], "/test_field"); + assert_eq!(extra_body_array[0]["value"], "test_value"); + + // Check the second extra_body entry (nested value) + assert_eq!(extra_body_array[1]["pointer"], "/nested/field"); + assert_eq!(extra_body_array[1]["value"]["key"], "nested_value"); +} diff --git a/tensorzero-core/tests/e2e/render_inferences.rs b/tensorzero-core/tests/e2e/render_inferences.rs index 8ce24710f8..23e958b98a 100644 --- a/tensorzero-core/tests/e2e/render_inferences.rs +++ b/tensorzero-core/tests/e2e/render_inferences.rs @@ -62,6 +62,7 @@ pub async fn test_render_samples_no_function() { timestamp: Utc::now(), dispreferred_outputs: vec![], tags: HashMap::from([("test_key".to_string(), "test_value".to_string())]), + extra_body: Default::default(), })]; let rendered_inferences = client @@ -98,6 +99,7 @@ pub async fn test_render_samples_no_variant() { timestamp: Utc::now(), dispreferred_outputs: vec![], tags: HashMap::new(), + extra_body: Default::default(), })]; let error = client @@ -147,6 +149,7 @@ pub async fn test_render_samples_missing_variable() { timestamp: Utc::now(), dispreferred_outputs: vec![], tags: HashMap::new(), + extra_body: Default::default(), })]; let rendered_inferences = client @@ -188,6 +191,7 @@ pub async fn test_render_samples_normal() { timestamp: Utc::now(), dispreferred_outputs: vec![], tags: HashMap::new(), + extra_body: Default::default(), }), StoredInferenceDatabase::Json(StoredJsonInference { function_name: "json_success".to_string(), @@ -221,6 +225,7 @@ pub async fn test_render_samples_normal() { raw: Some("{}".to_string()), // This should not be validated }], tags: HashMap::new(), + extra_body: Default::default(), }), StoredInferenceDatabase::Chat(StoredChatInferenceDatabase { function_name: "weather_helper".to_string(), @@ -268,6 +273,7 @@ pub async fn test_render_samples_normal() { text: "Hello, world!".to_string(), })]], tags: HashMap::new(), + extra_body: Default::default(), }), StoredInferenceDatabase::Chat(StoredChatInferenceDatabase { function_name: "basic_test".to_string(), @@ -313,6 +319,7 @@ pub 
async fn test_render_samples_normal() { timestamp: Utc::now(), dispreferred_outputs: vec![], tags: HashMap::new(), + extra_body: Default::default(), }), ]; @@ -506,6 +513,7 @@ pub async fn test_render_samples_template_no_schema() { tool_params: ToolCallConfigDatabaseInsert::default(), dispreferred_outputs: vec![], tags: HashMap::new(), + extra_body: Default::default(), })]; let rendered_inferences = client From 87a54142a2f1d00ace3d7c9677b735491594ed0a Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Mon, 8 Dec 2025 21:04:39 -0500 Subject: [PATCH 03/11] wip --- .../inference/InferenceDetailContent.tsx | 36 ++++-- .../inference/InferencePreviewSheet.tsx | 6 +- .../inference/VariantResponseModal.tsx | 5 +- .../api/inference/$inference_id/route.ts | 29 +++-- .../routes/api/tensorzero/inference.utils.tsx | 110 ++++++------------ .../$inference_id/InferenceBasicInfo.tsx | 7 +- .../inferences/$inference_id/route.tsx | 27 +++-- ui/app/utils/resolve.server.ts | 98 ++++++++++++++++ 8 files changed, 205 insertions(+), 113 deletions(-) diff --git a/ui/app/components/inference/InferenceDetailContent.tsx b/ui/app/components/inference/InferenceDetailContent.tsx index c2af9b1da2..a91926f8b0 100644 --- a/ui/app/components/inference/InferenceDetailContent.tsx +++ b/ui/app/components/inference/InferenceDetailContent.tsx @@ -3,12 +3,15 @@ import { type ParsedInferenceRow, type ParsedModelInferenceRow, } from "~/utils/clickhouse/inference"; -import type { FeedbackRow, FeedbackBounds } from "~/types/tensorzero"; +import type { + FeedbackRow, + FeedbackBounds, + StoredInference, +} from "~/types/tensorzero"; import { useEffect, useState } from "react"; import type { ReactNode } from "react"; import { useConfig, useFunctionConfig } from "~/context/config"; import BasicInfo from "~/routes/observability/inferences/$inference_id/InferenceBasicInfo"; -import Input from "~/components/inference/Input"; import { ChatOutputElement } from "~/components/input_output/ChatOutputElement"; import { JsonOutputElement } from "~/components/input_output/JsonOutputElement"; import FeedbackTable from "~/components/feedback/FeedbackTable"; @@ -38,9 +41,13 @@ import { logger } from "~/utils/logger"; import { useFetcherWithReset } from "~/hooks/use-fetcher-with-reset"; import { DEFAULT_FUNCTION } from "~/utils/constants"; import { VariantResponseModal } from "~/components/inference/VariantResponseModal"; +import { InputElement } from "../input_output/InputElement"; +import type { Input } from "~/types/tensorzero"; export interface InferenceDetailData { - inference: ParsedInferenceRow; + inference: StoredInference; + // TODO: remove + input: Input; model_inferences: ParsedModelInferenceRow[]; feedback: FeedbackRow[]; feedback_bounds: FeedbackBounds; @@ -85,6 +92,7 @@ export function InferenceDetailContent({ }: InferenceDetailContentProps) { const { inference, + input, model_inferences, feedback, feedback_bounds, @@ -273,7 +281,7 @@ export function InferenceDetailContent({ isDefaultFunction={isDefault} /> - - {inference.function_type === "json" ? ( + {inference.type === "json" ? 
( - {inference.function_type === "chat" && ( + {inference.type === "chat" && ( {inference.tool_params && ( @@ -384,7 +391,14 @@ export function InferenceDetailContent({ - + entry[1] !== undefined, + ), + )} + isEditing={false} + /> diff --git a/ui/app/components/inference/InferencePreviewSheet.tsx b/ui/app/components/inference/InferencePreviewSheet.tsx index 22795239bc..edea67fc1a 100644 --- a/ui/app/components/inference/InferencePreviewSheet.tsx +++ b/ui/app/components/inference/InferencePreviewSheet.tsx @@ -41,7 +41,7 @@ export function InferencePreviewSheet({ // Extract stable values from fetcher for dependency arrays const fetcherState = fetcher.state; - const fetcherDataInferenceId = fetcher.data?.inference.id; + const fetcherDataInferenceId = fetcher.data?.inference.inference_id; // Fetch data when sheet opens with an inference ID (only if we don't have data) // Also refetch when inference ID changes to avoid showing stale data @@ -101,10 +101,10 @@ export function InferencePreviewSheet({ <> Inference{" "} - {inferenceData.inference.id} + {inferenceData.inference.inference_id} ) : ( diff --git a/ui/app/components/inference/VariantResponseModal.tsx b/ui/app/components/inference/VariantResponseModal.tsx index 1102f971ed..829259ecdb 100644 --- a/ui/app/components/inference/VariantResponseModal.tsx +++ b/ui/app/components/inference/VariantResponseModal.tsx @@ -16,6 +16,7 @@ import type { InferenceResponse } from "~/utils/tensorzero"; import type { ContentBlockChatOutput, JsonInferenceOutput, + StoredInference, } from "~/types/tensorzero"; import { Card, CardContent } from "~/components/ui/card"; import type { VariantResponseInfo } from "~/routes/api/tensorzero/inference.utils"; @@ -112,7 +113,7 @@ interface VariantResponseModalProps { isLoading: boolean; onClose: () => void; // Use a union type to accept either inference or datapoint - item: ParsedInferenceRow | Datapoint; + item: StoredInference | Datapoint; // Make inferenceUsage optional since datasets don't have it by default inferenceUsage?: InferenceUsage; selectedVariant: string; @@ -161,7 +162,7 @@ export function VariantResponseModal({ // Get original variant name if available (only for inferences) const originalVariant = source === "inference" - ? (item as ParsedInferenceRow).variant_name + ? 
(item as StoredInference).variant_name : undefined; const refreshButton = onRefresh && ( diff --git a/ui/app/routes/api/inference/$inference_id/route.ts b/ui/app/routes/api/inference/$inference_id/route.ts index a18f9976c7..a92d7fa2b2 100644 --- a/ui/app/routes/api/inference/$inference_id/route.ts +++ b/ui/app/routes/api/inference/$inference_id/route.ts @@ -12,6 +12,8 @@ import { getUsedVariants } from "~/utils/clickhouse/function"; import { DEFAULT_FUNCTION } from "~/utils/constants"; import { logger } from "~/utils/logger"; import type { InferenceDetailData } from "~/components/inference/InferenceDetailContent"; +import { getTensorZeroClient } from "~/utils/get-tensorzero-client.server"; +import { loadFileDataForStoredInput } from "~/utils/resolve.server"; export async function loader({ request, @@ -27,8 +29,12 @@ export async function loader({ try { const dbClient = await getNativeDatabaseClient(); + const tensorZeroClient = getTensorZeroClient(); - const inferencePromise = queryInferenceById(inference_id); + const inferencesPromise = tensorZeroClient.getInferences({ + ids: [inference_id], + output_source: "inference", + }); const modelInferencesPromise = queryModelInferencesByInferenceId(inference_id); const demonstrationFeedbackPromise = @@ -47,7 +53,7 @@ export async function loader({ limit: 10, }); - let inference, + let inferences, model_inferences, demonstration_feedback, feedback_bounds, @@ -57,9 +63,9 @@ export async function loader({ if (newFeedbackId) { // When there's new feedback, wait for polling to complete before querying // feedbackBounds and latestFeedbackByMetric to ensure ClickHouse materialized views are updated - [inference, model_inferences, demonstration_feedback, feedback] = + [inferences, model_inferences, demonstration_feedback, feedback] = await Promise.all([ - inferencePromise, + inferencesPromise, modelInferencesPromise, demonstrationFeedbackPromise, feedbackDataPromise, @@ -73,14 +79,14 @@ export async function loader({ } else { // Normal case: execute all queries in parallel [ - inference, + inferences, model_inferences, demonstration_feedback, feedback_bounds, feedback, latestFeedbackByMetric, ] = await Promise.all([ - inferencePromise, + inferencesPromise, modelInferencesPromise, demonstrationFeedbackPromise, dbClient.queryFeedbackBoundsByTargetId({ target_id: inference_id }), @@ -89,9 +95,13 @@ export async function loader({ ]); } - if (!inference) { - throw data(`Inference ${inference_id} not found`, { status: 404 }); - } + if (inferences.inferences.length !== 1) { + throw data(`No inference found for id ${inference_id}.`, { + status: 404, + }); + } + const inference = inferences.inferences[0]; + const resolvedInput = await loadFileDataForStoredInput(inference.input); // Get used variants for default function const usedVariants = @@ -101,6 +111,7 @@ export async function loader({ const inferenceData: InferenceDetailData = { inference, + input: resolvedInput, model_inferences, feedback, feedback_bounds, diff --git a/ui/app/routes/api/tensorzero/inference.utils.tsx b/ui/app/routes/api/tensorzero/inference.utils.tsx index 464d1088ae..120fb34994 100644 --- a/ui/app/routes/api/tensorzero/inference.utils.tsx +++ b/ui/app/routes/api/tensorzero/inference.utils.tsx @@ -15,6 +15,8 @@ import type { ToolChoice, Tool, ResolvedTomlPathData, + StoredInference, + StoredInput, } from "~/types/tensorzero"; import type { InputMessageContent as TensorZeroContent, @@ -47,6 +49,7 @@ import type { ZodInputMessageContent, } from "~/utils/clickhouse/common"; import { v7 } from 
"uuid"; +import { loadFileDataForInput, loadFileDataForStoredInput } from "~/utils/resolve.server"; interface InferenceActionError { message: string; @@ -278,13 +281,13 @@ function inputMessageContentToZodInputMessageContent( interface InferenceActionArgs { source: "inference"; - resource: ParsedInferenceRow; + resource: StoredInference; variant: string; } interface InferenceDefaultFunctionActionArgs { source: "inference"; - resource: ParsedInferenceRow; + resource: StoredInference; variant?: undefined; model_name: string; } @@ -295,27 +298,10 @@ interface T0DatapointActionArgs { variant: string; } -interface ClickHouseDatapointActionArgs { - source: "clickhouse_datapoint"; - input: ZodDisplayInput; - functionName: string; - allowed_tools?: string[]; - additional_tools?: Array | null; - tool_choice?: ToolChoice | null; - parallel_tool_calls?: boolean | null; - output_schema?: JsonValue; - variant?: string; - cache_options: CacheParamsOptions; - editedVariantInfo?: VariantInfo; - functionConfig: FunctionConfig; - toolsConfig: { [key in string]?: StaticToolConfig }; -} - type ActionArgs = | InferenceActionArgs | InferenceDefaultFunctionActionArgs - | T0DatapointActionArgs - | ClickHouseDatapointActionArgs; + | T0DatapointActionArgs; function isDefaultFunctionArgs( args: ActionArgs, @@ -326,9 +312,9 @@ function isDefaultFunctionArgs( ); } -export function prepareInferenceActionRequest( +export async function prepareInferenceActionRequest( args: ActionArgs, -): ClientInferenceParams { +): Promise { // Create base ClientInferenceParams with default values const baseParams: ClientInferenceParams = { function_name: null, @@ -372,25 +358,6 @@ export function prepareInferenceActionRequest( args.model_name, ); return { ...baseParams, ...defaultRequest }; - } else if (args.source === "clickhouse_datapoint") { - // Extract tool parameters from the ClickHouse datapoint args - const dynamicVariantInfo = args.editedVariantInfo - ? variantInfoToUninitializedVariantInfo(args.editedVariantInfo) - : null; - - return { - ...baseParams, - function_name: args.functionName, - input: resolvedInputToInput(args.input), - variant_name: args.variant || null, - output_schema: args.output_schema || null, - tool_choice: args.tool_choice || undefined, - parallel_tool_calls: args.parallel_tool_calls || undefined, - additional_tools: args.additional_tools || undefined, - allowed_tools: args.allowed_tools || undefined, - cache_options: args.cache_options, - internal_dynamic_variant_config: dynamicVariantInfo, - }; } else if (args.source === "t0_datapoint") { // Handle datapoints from tensorzero-node (with StoredInput) return { @@ -404,11 +371,11 @@ export function prepareInferenceActionRequest( if ( args.source === "inference" && args.resource.extra_body && - args.resource.extra_body.length > 0 + args.resource.extra_body.extra_body.length > 0 ) { throw new Error("Extra body is not supported for inference in UI."); } - const input = resolvedInputToInput(args.resource.input); + const input = await loadFileDataForStoredInput(args.resource.input); // TODO: this is unsupported in Node bindings for now // const extra_body = // args.source === "inference" ? 
args.resource.extra_body : undefined; @@ -422,24 +389,25 @@ export function prepareInferenceActionRequest( } } -function prepareDefaultFunctionRequest( - inference: ParsedInferenceRow, +async function prepareDefaultFunctionRequest( + inference: StoredInference, selectedVariant: string, -): Partial { - const input = resolvedInputToInput(inference.input); - if (inference.function_type === "chat") { - const tool_choice = inference.tool_params?.tool_choice; - const parallel_tool_calls = inference.tool_params?.parallel_tool_calls; - const tools_available = inference.tool_params?.tools_available; +): Promise> { + const input =await loadFileDataForStoredInput(inference.input); + if (inference.type === "chat") { + const tool_choice = inference.tool_choice; + const parallel_tool_calls = inference.parallel_tool_calls; + const allowed_tools = inference.allowed_tools; return { model_name: selectedVariant, input, tool_choice: tool_choice, - parallel_tool_calls: parallel_tool_calls || undefined, + parallel_tool_calls: parallel_tool_calls, + allowed_tools, // We need to add all tools as additional for the default function - additional_tools: tools_available, + additional_tools: inference.additional_tools, }; - } else if (inference.function_type === "json") { + } else if (inference.type === "json") { // This should never happen, just in case and for type safety const output_schema = inference.output_schema; return { @@ -468,14 +436,14 @@ export type VariantResponseInfo = usage?: InferenceUsage; }; -export function resolvedInputToInput(input: ZodDisplayInput): Input { +function resolvedInputToInput(input: StoredInput): Input { return { - system: input.system || null, - messages: input.messages.map(resolvedInputMessageToInputMessage), + system: input.system, + messages: input.messages.map(), }; } -export function resolvedInputToTensorZeroInput( +function resolvedInputToTensorZeroInput( input: ZodDisplayInput, ): TensorZeroInput { return { @@ -535,7 +503,7 @@ function resolvedFileContentToTensorZeroFile( } function resolvedInputMessageToInputMessage( - message: ZodDisplayInputMessage, + message: InputMessage, ): InputMessage { return { role: message.role, @@ -546,7 +514,7 @@ function resolvedInputMessageToInputMessage( } function resolvedInputMessageContentToInputMessageContent( - content: ZodDisplayInputMessageContent, + content: InputMessageContent, ): InputMessageContent { switch (content.type) { case "template": @@ -556,30 +524,20 @@ function resolvedInputMessageContentToInputMessageContent( type: "text", text: content.text, }; - case "missing_function_text": - return { - type: "text", - text: content.value, - }; case "raw_text": return { type: "raw_text", value: content.value, }; case "tool_call": { - let parsedArguments; - try { - parsedArguments = JSON.parse(content.arguments); - } catch { - parsedArguments = content.arguments; - } + // TODO: handle both types of tool here return { type: "tool_call", id: content.id, name: content.name, - arguments: parsedArguments, - raw_arguments: content.arguments, - raw_name: content.name, + arguments: content.arguments, + raw_arguments: JSON.stringify(content.arguments), + raw_name: content.raw_name, }; } case "tool_result": @@ -604,9 +562,7 @@ function resolvedInputMessageContentToInputMessageContent( provider_name: content.provider_name, }; case "file": - return resolvedFileContentToClientFile(content); - case "file_error": - throw new Error("Can't convert image error to client content"); + return loadFileDataForInput(content); } } diff --git 
a/ui/app/routes/observability/inferences/$inference_id/InferenceBasicInfo.tsx b/ui/app/routes/observability/inferences/$inference_id/InferenceBasicInfo.tsx index f64935c551..feb41b3f30 100644 --- a/ui/app/routes/observability/inferences/$inference_id/InferenceBasicInfo.tsx +++ b/ui/app/routes/observability/inferences/$inference_id/InferenceBasicInfo.tsx @@ -21,6 +21,7 @@ import { import { toFunctionUrl, toVariantUrl, toEpisodeUrl } from "~/utils/urls"; import { formatDateWithSeconds, getTimestampTooltipData } from "~/utils/date"; import { getFunctionTypeIcon } from "~/utils/icon"; +import type { StoredInference } from "~/types/tensorzero"; // Create timestamp tooltip component const createTimestampTooltip = (timestamp: string | number | Date) => { @@ -37,7 +38,7 @@ const createTimestampTooltip = (timestamp: string | number | Date) => { }; interface BasicInfoProps { - inference: ParsedInferenceRow; + inference: StoredInference; inferenceUsage?: InferenceUsage; modelInferences?: ParsedModelInferenceRow[]; } @@ -58,7 +59,7 @@ export default function BasicInfo({ const timestampTooltip = createTimestampTooltip(inference.timestamp); // Get function icon and background - const functionIconConfig = getFunctionTypeIcon(inference.function_type); + const functionIconConfig = getFunctionTypeIcon(inference.type); // Determine cache status from model inferences const hasCachedInferences = modelInferences.some((mi) => mi.cached); @@ -79,7 +80,7 @@ export default function BasicInfo({ icon={functionIconConfig.icon} iconBg={functionIconConfig.iconBg} label={inference.function_name} - secondaryLabel={`· ${inference.function_type}`} + secondaryLabel={`· ${inference.type}`} link={toFunctionUrl(inference.function_name)} font="mono" /> diff --git a/ui/app/routes/observability/inferences/$inference_id/route.tsx b/ui/app/routes/observability/inferences/$inference_id/route.tsx index 038d4b7d66..459433ad41 100644 --- a/ui/app/routes/observability/inferences/$inference_id/route.tsx +++ b/ui/app/routes/observability/inferences/$inference_id/route.tsx @@ -27,6 +27,8 @@ import { InferenceDetailContent, type InferenceDetailData, } from "~/components/inference/InferenceDetailContent"; +import { getTensorZeroClient } from "~/utils/get-tensorzero-client.server"; +import { loadFileDataForStoredInput } from "~/utils/resolve.server"; export const handle: RouteHandle = { crumb: (match) => [{ label: match.params.inference_id!, isIdentifier: true }], @@ -47,8 +49,12 @@ export async function loader({ request, params }: Route.LoaderArgs) { // --- Define all promises, conditionally choosing the feedback promise --- const dbClient = await getNativeDatabaseClient(); + const tensorZeroClient = getTensorZeroClient(); - const inferencePromise = queryInferenceById(inference_id); + const inferencesPromise = tensorZeroClient.getInferences({ + ids: [inference_id], + output_source: "inference", + }); const modelInferencesPromise = queryModelInferencesByInferenceId(inference_id); const demonstrationFeedbackPromise = @@ -72,7 +78,7 @@ export async function loader({ request, params }: Route.LoaderArgs) { // --- Execute promises concurrently (with special handling for new feedback) --- - let inference, + let inferences, model_inferences, demonstration_feedback, feedback_bounds, @@ -82,9 +88,9 @@ export async function loader({ request, params }: Route.LoaderArgs) { if (newFeedbackId) { // When there's new feedback, wait for polling to complete before querying // feedbackBounds and latestFeedbackByMetric to ensure ClickHouse materialized views are 
updated - [inference, model_inferences, demonstration_feedback, feedback] = + [inferences, model_inferences, demonstration_feedback, feedback] = await Promise.all([ - inferencePromise, + inferencesPromise, modelInferencesPromise, demonstrationFeedbackPromise, feedbackDataPromise, @@ -98,14 +104,14 @@ export async function loader({ request, params }: Route.LoaderArgs) { } else { // Normal case: execute all queries in parallel [ - inference, + inferences, model_inferences, demonstration_feedback, feedback_bounds, feedback, latestFeedbackByMetric, ] = await Promise.all([ - inferencePromise, + inferencesPromise, modelInferencesPromise, demonstrationFeedbackPromise, dbClient.queryFeedbackBoundsByTargetId({ target_id: inference_id }), @@ -116,19 +122,22 @@ export async function loader({ request, params }: Route.LoaderArgs) { // --- Process results --- - if (!inference) { + if (inferences.inferences.length !== 1) { throw data(`No inference found for id ${inference_id}.`, { status: 404, }); } + const inference = inferences.inferences[0]; const usedVariants = inference.function_name === DEFAULT_FUNCTION ? await getUsedVariants(inference.function_name) : []; + const resolvedInput = await loadFileDataForStoredInput(inference.input); return { inference, + resolvedInput, model_inferences, usedVariants, feedback, @@ -142,6 +151,7 @@ export async function loader({ request, params }: Route.LoaderArgs) { export default function InferencePage({ loaderData }: Route.ComponentProps) { const { inference, + resolvedInput, model_inferences, usedVariants, feedback, @@ -198,6 +208,7 @@ export default function InferencePage({ loaderData }: Route.ComponentProps) { // Build the data object for InferenceDetailContent const inferenceData: InferenceDetailData = { inference, + input: resolvedInput, model_inferences, feedback, feedback_bounds, @@ -238,7 +249,7 @@ export default function InferencePage({ loaderData }: Route.ComponentProps) { /> } renderHeader={({ basicInfo, actionBar }) => ( - + {basicInfo} {actionBar} diff --git a/ui/app/utils/resolve.server.ts b/ui/app/utils/resolve.server.ts index 83484ea715..d58b124372 100644 --- a/ui/app/utils/resolve.server.ts +++ b/ui/app/utils/resolve.server.ts @@ -18,6 +18,9 @@ import type { JsonValue, Input, InputMessageContent, + StoredInput, + StoredInputMessageContent, + StoredFile, } from "~/types/tensorzero"; import { getTensorZeroClient } from "./tensorzero.server"; @@ -393,3 +396,98 @@ async function loadInputFileData(file: File): Promise { } } } + +/** + * Loads the content of files for an `StoredInput`. + * Converts `ObjectStoragePointer` to `File` with `file_type: "object_storage"`. + * + * TODO (#4674 #4675): This will be handled in the gateway. + */ +export async function loadFileDataForStoredInput( + input: StoredInput, +): Promise { + const resolvedMessages = await Promise.all( + input.messages.map(async (message) => { + const resolvedContent = await Promise.all( + message.content.map(async (content) => { + return loadFileDataForStoredInputContent(content); + }), + ); + return { + role: message.role, + content: resolvedContent, + }; + }), + ); + + return { + system: input.system, + messages: resolvedMessages, + }; +} + +/** + * Resolves a StoredInputMessageContent to InputMessageContent. + * For files: converts StoredFile to File with file_type: "object_storage" by fetching data. 
+ */ +async function loadFileDataForStoredInputContent( + content: StoredInputMessageContent, +): Promise { + switch (content.type) { + case "tool_call": + case "tool_result": + case "raw_text": + case "thought": + case "unknown": + case "template": + case "text": + return content; + case "file": { + const loadedFile = await loadStoredInputFileData(content); + return { + type: "file", + ...loadedFile, + }; + } + } +} + +/** + * Loads the data of a `file`, converting `ObjectStoragePointer` to `ObjectStorage` or `ObjectStorageError`. + * @param file - The file to load. + * @returns Loaded file. + */ +async function loadStoredInputFileData(file: StoredFile): Promise { + try { + const fileContent: ZodFileContent = { + type: "file", + file: { + url: file.source_url, + mime_type: file.mime_type, + }, + storage_path: file.storage_path, + }; + const resolvedFile = await resolveFile(fileContent); + const loadedFile: File = { + file_type: "object_storage", + data: resolvedFile.data, + mime_type: resolvedFile.mime_type, + storage_path: file.storage_path, + source_url: file.source_url, + detail: file.detail, + filename: file.filename, + }; + return loadedFile; + } catch (error) { + const loadFileError: File = { + file_type: "object_storage_error", + source_url: file.source_url, + mime_type: file.mime_type, + storage_path: file.storage_path, + detail: file.detail, + filename: file.filename, + error: error instanceof Error ? error.message : String(error), + }; + return loadFileError; + } +} From 66bb23d615dd13b876a4b606452aedd8505f3d84 Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Mon, 8 Dec 2025 21:22:23 -0500 Subject: [PATCH 04/11] wip --- .../inference/InferenceDetailContent.tsx | 3 ++- .../api/inference/$inference_id/route.ts | 1 - .../routes/api/tensorzero/inference.utils.tsx | 18 +++++++++++------ ui/app/routes/playground/utils.ts | 20 ++----------------- 4 files changed, 16 insertions(+), 26 deletions(-) diff --git a/ui/app/components/inference/InferenceDetailContent.tsx b/ui/app/components/inference/InferenceDetailContent.tsx index a91926f8b0..6dbfcf74a9 100644 --- a/ui/app/components/inference/InferenceDetailContent.tsx +++ b/ui/app/components/inference/InferenceDetailContent.tsx @@ -1,6 +1,5 @@ import { isJsonOutput, - type ParsedInferenceRow, type ParsedModelInferenceRow, } from "~/utils/clickhouse/inference"; import type { @@ -193,6 +192,7 @@ export function InferenceDetailContent({ const onVariantSelect = (variant: string) => { processRequest(variant, { resource: inference, + input, source: variantSource, variant, }); @@ -201,6 +201,7 @@ export function InferenceDetailContent({ const onModelSelect = (model: string) => { processRequest(model, { resource: inference, + input, source: variantSource, model_name: model, }); diff --git a/ui/app/routes/api/inference/$inference_id/route.ts b/ui/app/routes/api/inference/$inference_id/route.ts index a92d7fa2b2..58a048eac4 100644 --- a/ui/app/routes/api/inference/$inference_id/route.ts +++ b/ui/app/routes/api/inference/$inference_id/route.ts @@ -1,6 +1,5 @@ import { data, type LoaderFunctionArgs } from "react-router"; import { - queryInferenceById, queryModelInferencesByInferenceId, } from "~/utils/clickhouse/inference.server"; import { diff --git a/ui/app/routes/api/tensorzero/inference.utils.tsx b/ui/app/routes/api/tensorzero/inference.utils.tsx index 120fb34994..13ffaf84af 100644 --- a/ui/app/routes/api/tensorzero/inference.utils.tsx +++ b/ui/app/routes/api/tensorzero/inference.utils.tsx @@ -282,12 +282,14 @@ function 
inputMessageContentToZodInputMessageContent( interface InferenceActionArgs { source: "inference"; resource: StoredInference; + input: Input; variant: string; } interface InferenceDefaultFunctionActionArgs { source: "inference"; resource: StoredInference; + input: Input; variant?: undefined; model_name: string; } @@ -295,7 +297,8 @@ interface InferenceDefaultFunctionActionArgs { interface T0DatapointActionArgs { source: "t0_datapoint"; resource: ChatInferenceDatapoint | JsonInferenceDatapoint; - variant: string; + variant?: string; + editedVariantInfo?: VariantInfo; } type ActionArgs = @@ -312,9 +315,9 @@ function isDefaultFunctionArgs( ); } -export async function prepareInferenceActionRequest( +export function prepareInferenceActionRequest( args: ActionArgs, -): Promise { +): ClientInferenceParams { // Create base ClientInferenceParams with default values const baseParams: ClientInferenceParams = { function_name: null, @@ -360,11 +363,15 @@ export async function prepareInferenceActionRequest( return { ...baseParams, ...defaultRequest }; } else if (args.source === "t0_datapoint") { // Handle datapoints from tensorzero-node (with StoredInput) + const dynamicVariantInfo = args.editedVariantInfo + ? variantInfoToUninitializedVariantInfo(args.editedVariantInfo) + : null; return { ...baseParams, function_name: args.resource.function_name, input: args.resource.input, - variant_name: args.variant, + variant_name: args.variant || null, + internal_dynamic_variant_config: dynamicVariantInfo, }; } else { // For other sources, the input is already a DisplayInput @@ -375,7 +382,6 @@ export async function prepareInferenceActionRequest( ) { throw new Error("Extra body is not supported for inference in UI."); } - const input = await loadFileDataForStoredInput(args.resource.input); // TODO: this is unsupported in Node bindings for now // const extra_body = // args.source === "inference" ? args.resource.extra_body : undefined; @@ -383,7 +389,7 @@ export async function prepareInferenceActionRequest( return { ...baseParams, function_name: args.resource.function_name, - input, + input: args.input, variant_name: args.variant, }; } diff --git a/ui/app/routes/playground/utils.ts b/ui/app/routes/playground/utils.ts index 183ee92f5e..be509e93d4 100644 --- a/ui/app/routes/playground/utils.ts +++ b/ui/app/routes/playground/utils.ts @@ -156,26 +156,10 @@ export function preparePlaygroundInferenceRequest( } = args; const variantInferenceInfo = getVariantInferenceInfo(variant); const request = prepareInferenceActionRequest({ - source: "clickhouse_datapoint", - input, - functionName, + source: "t0_datapoint", + resource: datapoint, variant: variantInferenceInfo.variant, - allowed_tools: - datapoint?.type === "chat" ? datapoint.allowed_tools : undefined, - additional_tools: - datapoint?.type === "chat" ? datapoint.additional_tools : null, - tool_choice: datapoint?.type === "chat" ? datapoint.tool_choice : null, - parallel_tool_calls: - datapoint?.type === "chat" ? datapoint.parallel_tool_calls : null, - output_schema: datapoint?.type === "json" ? 
datapoint.output_schema : null, - // The default is write_only but we do off in the playground - cache_options: { - max_age_s: null, - enabled: "off", - }, editedVariantInfo: variantInferenceInfo.editedVariantInfo, - functionConfig, - toolsConfig, }); const extraOptions = getExtraInferenceOptions(); return { From 57a5138a8c10b9b38b74bc0bbcf077218f2d7b76 Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Mon, 8 Dec 2025 22:08:03 -0500 Subject: [PATCH 05/11] added missing fields to get inference endpoint --- .../src/db/clickhouse/inference_queries.rs | 12 ++ .../src/db/clickhouse/query_builder/mod.rs | 78 +++++++++ tensorzero-core/src/db/inferences.rs | 15 ++ .../datasets/v1/create_from_inferences.rs | 3 + tensorzero-core/src/endpoints/inference.rs | 5 +- .../stored_inferences/v1/get_inferences.rs | 3 + .../service_tier.rs | 3 +- tensorzero-core/src/stored_inference.rs | 31 +++- tensorzero-core/src/variant/mod.rs | 3 +- .../stored_inferences/get_inferences.rs | 151 +++++++++++++++++- .../tests/e2e/render_inferences.rs | 24 +++ 11 files changed, 315 insertions(+), 13 deletions(-) diff --git a/tensorzero-core/src/db/clickhouse/inference_queries.rs b/tensorzero-core/src/db/clickhouse/inference_queries.rs index e144dcddef..046209ec74 100644 --- a/tensorzero-core/src/db/clickhouse/inference_queries.rs +++ b/tensorzero-core/src/db/clickhouse/inference_queries.rs @@ -436,6 +436,9 @@ fn generate_single_table_query_for_type( select_clauses.push("i.variant_name as variant_name".to_string()); select_clauses.push("i.extra_body as extra_body".to_string()); + select_clauses.push("i.inference_params as inference_params".to_string()); + select_clauses.push("i.processing_time_ms as processing_time_ms".to_string()); + select_clauses.push("i.ttft_ms as ttft_ms".to_string()); let mut where_clauses: Vec = Vec::new(); @@ -686,6 +689,9 @@ mod tests { i.parallel_tool_calls as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM ChatInference AS i @@ -711,6 +717,9 @@ mod tests { NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -758,6 +767,9 @@ mod tests { NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i diff --git a/tensorzero-core/src/db/clickhouse/query_builder/mod.rs b/tensorzero-core/src/db/clickhouse/query_builder/mod.rs index d79cda595d..b60b24b52e 100644 --- a/tensorzero-core/src/db/clickhouse/query_builder/mod.rs +++ b/tensorzero-core/src/db/clickhouse/query_builder/mod.rs @@ -533,6 +533,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -582,6 +585,9 @@ SELECT i.parallel_tool_calls as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM ChatInference AS i @@ -637,6 +643,9 @@ SELECT NULL as 
parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -739,6 +748,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, demo_f.value AS output, [i.output] as dispreferred_outputs FROM @@ -795,6 +807,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -865,6 +880,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -951,6 +969,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1056,6 +1077,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1200,6 +1224,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1294,6 +1321,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1391,6 +1421,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1440,6 +1473,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1495,6 +1531,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1566,6 +1605,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1638,6 +1680,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1701,6 +1746,9 @@ SELECT 
i.parallel_tool_calls as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM ChatInference AS i @@ -1773,6 +1821,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1853,6 +1904,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -1944,6 +1998,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, demo_f.value AS output, [i.output] as dispreferred_outputs FROM @@ -2043,6 +2100,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -2112,6 +2172,9 @@ SELECT i.parallel_tool_calls as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM ChatInference AS i @@ -2186,6 +2249,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -2595,6 +2661,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -2653,6 +2722,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -2729,6 +2801,9 @@ SELECT NULL as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output FROM JsonInference AS i @@ -2799,6 +2874,9 @@ SELECT i.parallel_tool_calls as parallel_tool_calls, i.variant_name as variant_name, i.extra_body as extra_body, + i.inference_params as inference_params, + i.processing_time_ms as processing_time_ms, + i.ttft_ms as ttft_ms, i.output as output, countSubstringsCaseInsensitiveUTF8(i.input, {p1:String}) as input_term_frequency, countSubstringsCaseInsensitiveUTF8(i.output, {p1:String}) as output_term_frequency, diff --git a/tensorzero-core/src/db/inferences.rs b/tensorzero-core/src/db/inferences.rs index ac4b9c516b..e918aa4abe 100644 --- a/tensorzero-core/src/db/inferences.rs +++ b/tensorzero-core/src/db/inferences.rs @@ -12,6 +12,7 @@ use mockall::automock; use crate::config::Config; use crate::db::clickhouse::query_builder::{InferenceFilter, OrderBy}; +use 
crate::endpoints::inference::InferenceParams; use crate::error::{Error, ErrorDetails}; use crate::inference::types::extra_body::UnfilteredInferenceExtraBody; use crate::inference::types::{ContentBlockChatOutput, JsonInferenceOutput, StoredInput}; @@ -41,6 +42,10 @@ pub(super) struct ClickHouseStoredChatInferenceWithDispreferredOutputs { pub tags: HashMap, #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] pub extra_body: UnfilteredInferenceExtraBody, + #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] + pub inference_params: InferenceParams, + pub processing_time_ms: Option, + pub ttft_ms: Option, } impl TryFrom for StoredChatInferenceDatabase { @@ -73,6 +78,9 @@ impl TryFrom for StoredCha tags: value.tags, timestamp: value.timestamp, extra_body: value.extra_body, + inference_params: value.inference_params, + processing_time_ms: value.processing_time_ms, + ttft_ms: value.ttft_ms, }) } } @@ -95,6 +103,10 @@ pub(super) struct ClickHouseStoredJsonInferenceWithDispreferredOutputs { pub tags: HashMap, #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] pub extra_body: UnfilteredInferenceExtraBody, + #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] + pub inference_params: InferenceParams, + pub processing_time_ms: Option, + pub ttft_ms: Option, } impl TryFrom for StoredJsonInference { @@ -126,6 +138,9 @@ impl TryFrom for StoredJso tags: value.tags, timestamp: value.timestamp, extra_body: value.extra_body, + inference_params: value.inference_params, + processing_time_ms: value.processing_time_ms, + ttft_ms: value.ttft_ms, }) } } diff --git a/tensorzero-core/src/endpoints/datasets/v1/create_from_inferences.rs b/tensorzero-core/src/endpoints/datasets/v1/create_from_inferences.rs index d2401faa6e..b5aa707d9c 100644 --- a/tensorzero-core/src/endpoints/datasets/v1/create_from_inferences.rs +++ b/tensorzero-core/src/endpoints/datasets/v1/create_from_inferences.rs @@ -207,6 +207,9 @@ mod tests { tool_params: ToolCallConfigDatabaseInsert::default(), tags: HashMap::new(), extra_body: Default::default(), + inference_params: Default::default(), + processing_time_ms: None, + ttft_ms: None, }) } diff --git a/tensorzero-core/src/endpoints/inference.rs b/tensorzero-core/src/endpoints/inference.rs index a31cd48606..8ac2f78914 100644 --- a/tensorzero-core/src/endpoints/inference.rs +++ b/tensorzero-core/src/endpoints/inference.rs @@ -8,6 +8,7 @@ use futures::stream::Stream; use futures_core::FusedStream; use indexmap::IndexMap; use metrics::counter; +use schemars::JsonSchema; use secrecy::SecretString; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -1365,14 +1366,14 @@ pub struct InferenceModels { /// InferenceParams is the top-level struct for inference parameters. /// We backfill these from the configs given in the variants used and ultimately write them to the database. 
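// Editorial sketch (not part of this patch): a minimal example of how the
// `inference_params` column added in this series round-trips through this
// struct. It assumes only `serde_json` plus the `InferenceParams` and
// `ChatCompletionInferenceParams` types defined in this file; the parameter
// values are hypothetical and mirror the e2e test payload later in this patch.
fn inference_params_roundtrip_sketch() -> Result<(), serde_json::Error> {
    // Stored rows hold the params as a JSON string, e.g. from a request that
    // set `params.chat_completion.{temperature,max_tokens,seed}`.
    let stored = r#"{"chat_completion":{"temperature":0.7,"max_tokens":100,"seed":42}}"#;
    let params: InferenceParams = serde_json::from_str(stored)?;
    assert_eq!(params.chat_completion.temperature, Some(0.7));
    assert_eq!(params.chat_completion.seed, Some(42));
    // Serializing back yields the same shape written to ClickHouse.
    let _json = serde_json::to_string(&params)?;
    Ok(())
}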
-#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ts_rs::TS)] +#[derive(Clone, Debug, Default, Deserialize, JsonSchema, PartialEq, Serialize, ts_rs::TS)] #[ts(export)] #[serde(deny_unknown_fields)] pub struct InferenceParams { pub chat_completion: ChatCompletionInferenceParams, } -#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ts_rs::TS)] +#[derive(Clone, Debug, Default, Deserialize, JsonSchema, PartialEq, Serialize, ts_rs::TS)] #[ts(export)] #[serde(deny_unknown_fields)] pub struct ChatCompletionInferenceParams { diff --git a/tensorzero-core/src/endpoints/stored_inferences/v1/get_inferences.rs b/tensorzero-core/src/endpoints/stored_inferences/v1/get_inferences.rs index e5ea06dfbc..a29bc929bc 100644 --- a/tensorzero-core/src/endpoints/stored_inferences/v1/get_inferences.rs +++ b/tensorzero-core/src/endpoints/stored_inferences/v1/get_inferences.rs @@ -145,6 +145,9 @@ mod tests { tool_params: ToolCallConfigDatabaseInsert::default(), tags: HashMap::new(), extra_body: Default::default(), + inference_params: Default::default(), + processing_time_ms: None, + ttft_ms: None, }) } diff --git a/tensorzero-core/src/inference/types/chat_completion_inference_params/service_tier.rs b/tensorzero-core/src/inference/types/chat_completion_inference_params/service_tier.rs index a2444612e4..2db9bcf841 100644 --- a/tensorzero-core/src/inference/types/chat_completion_inference_params/service_tier.rs +++ b/tensorzero-core/src/inference/types/chat_completion_inference_params/service_tier.rs @@ -1,10 +1,11 @@ +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; /// Service tier for inference requests. /// /// Controls the priority and latency characteristics of the request. /// Different providers map these values differently to their own service tiers. 
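// Editorial sketch (not part of this patch): with `rename_all = "lowercase"`,
// the tier is stored and transmitted as a lowercase string. The variant names
// assumed here ("auto", "default", "priority", "flex") match the generated
// Python `ServiceTier` literal later in this series.
fn service_tier_roundtrip_sketch() -> Result<(), serde_json::Error> {
    let tier: ServiceTier = serde_json::from_str("\"flex\"")?;
    // Serializing back produces the same lowercase wire format.
    assert_eq!(serde_json::to_string(&tier)?, "\"flex\"");
    Ok(())
}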
-#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize, ts_rs::TS)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize, JsonSchema, ts_rs::TS)] #[ts(export)] #[serde(rename_all = "lowercase")] pub enum ServiceTier { diff --git a/tensorzero-core/src/stored_inference.rs b/tensorzero-core/src/stored_inference.rs index d9812b8d97..e014bd6788 100644 --- a/tensorzero-core/src/stored_inference.rs +++ b/tensorzero-core/src/stored_inference.rs @@ -1,5 +1,6 @@ use std::{collections::HashMap, sync::Arc}; +use crate::client::InferenceParams; use crate::config::Config; use crate::db::datasets::{ ChatInferenceDatapointInsert, DatapointInsert, JsonInferenceDatapointInsert, @@ -20,7 +21,7 @@ use crate::inference::types::{ ContentBlockChatOutput, JsonInferenceOutput, ModelInput, RequestMessage, ResolvedInput, ResolvedRequestMessage, Text, }; -use crate::serde_util::deserialize_defaulted_json_string; +use crate::serde_util::{deserialize_defaulted_json_string, deserialize_json_string}; use crate::tool::{ DynamicToolParams, StaticToolConfig, ToolCallConfigDatabaseInsert, deserialize_tool_info, }; @@ -217,6 +218,9 @@ impl StoredChatInference { tool_params, tags: self.tags, extra_body: self.extra_body, + inference_params: self.inference_params, + processing_time_ms: self.processing_time_ms, + ttft_ms: self.ttft_ms, }) } } @@ -264,8 +268,11 @@ pub struct StoredChatInference { pub tool_params: DynamicToolParams, #[serde(default)] pub tags: HashMap, - #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] + #[serde(default)] pub extra_body: UnfilteredInferenceExtraBody, + pub inference_params: InferenceParams, + pub processing_time_ms: Option, + pub ttft_ms: Option, } impl std::fmt::Display for StoredChatInference { @@ -290,6 +297,9 @@ impl StoredChatInferenceDatabase { tool_params: self.tool_params.into(), tags: self.tags, extra_body: self.extra_body, + inference_params: self.inference_params, + processing_time_ms: self.processing_time_ms, + ttft_ms: self.ttft_ms, } } } @@ -312,6 +322,10 @@ pub struct StoredChatInferenceDatabase { pub tags: HashMap, #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] pub extra_body: UnfilteredInferenceExtraBody, + #[serde(default, deserialize_with = "deserialize_json_string")] + pub inference_params: InferenceParams, + pub processing_time_ms: Option, + pub ttft_ms: Option, } impl std::fmt::Display for StoredChatInferenceDatabase { @@ -337,8 +351,12 @@ pub struct StoredJsonInference { pub output_schema: Value, #[serde(default)] pub tags: HashMap, - #[serde(default, deserialize_with = "deserialize_defaulted_json_string")] + #[serde(default)] pub extra_body: UnfilteredInferenceExtraBody, + #[serde(default)] + pub inference_params: InferenceParams, + pub processing_time_ms: Option, + pub ttft_ms: Option, } impl std::fmt::Display for StoredJsonInference { @@ -759,6 +777,7 @@ mod tests { use crate::config::{Config, SchemaData}; use crate::db::datasets::DatapointInsert; use crate::endpoints::datasets::v1::types::CreateDatapointsFromInferenceOutputSource; + use crate::endpoints::inference::InferenceParams; use crate::experimentation::ExperimentationConfig; use crate::function::{FunctionConfig, FunctionConfigChat, FunctionConfigJson}; use crate::inference::types::System; @@ -837,6 +856,9 @@ mod tests { tags }, extra_body: UnfilteredInferenceExtraBody::default(), + inference_params: InferenceParams::default(), + processing_time_ms: None, + ttft_ms: None, } } @@ -874,6 +896,9 @@ mod tests { tags }, extra_body: 
UnfilteredInferenceExtraBody::default(), + inference_params: InferenceParams::default(), + processing_time_ms: None, + ttft_ms: None, } } diff --git a/tensorzero-core/src/variant/mod.rs b/tensorzero-core/src/variant/mod.rs index 94f1d200de..6f15ce7d3d 100644 --- a/tensorzero-core/src/variant/mod.rs +++ b/tensorzero-core/src/variant/mod.rs @@ -4,6 +4,7 @@ use itertools::izip; use pyo3::exceptions::PyValueError; #[cfg(feature = "pyo3")] use pyo3::prelude::*; +use schemars::JsonSchema; use serde::Deserialize; use serde::Serialize; use std::borrow::Cow; @@ -111,7 +112,7 @@ pub struct ChainOfThoughtConfigPyClass { /// Variants represent JSON mode in a slightly more abstract sense than ModelInferenceRequests, as /// we support coercing tool calls into JSON mode. /// This is represented as a tool config in the -#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Serialize)] +#[derive(Clone, Copy, Debug, Deserialize, JsonSchema, PartialEq, Serialize)] #[serde(rename_all = "snake_case")] #[derive(ts_rs::TS)] #[ts(export)] diff --git a/tensorzero-core/tests/e2e/endpoints/stored_inferences/get_inferences.rs b/tensorzero-core/tests/e2e/endpoints/stored_inferences/get_inferences.rs index 5cf661b85e..0d93b38d2f 100644 --- a/tensorzero-core/tests/e2e/endpoints/stored_inferences/get_inferences.rs +++ b/tensorzero-core/tests/e2e/endpoints/stored_inferences/get_inferences.rs @@ -1168,13 +1168,11 @@ pub async fn test_list_inferences_cursor_with_metric_ordering_fails() { ); } -// Tests for extra_body field - #[tokio::test(flavor = "multi_thread")] -pub async fn test_get_by_ids_with_extra_body() { +pub async fn test_get_by_ids_with_extra_body_and_inference_params() { let http_client = Client::new(); - // Create an inference with a nontrivial extra_body + // Create an inference with a nontrivial extra_body and inference params let extra_body_value = json!([ {"pointer": "/test_field", "value": "test_value"}, {"pointer": "/nested/field", "value": {"key": "nested_value"}} @@ -1188,7 +1186,14 @@ pub async fn test_get_by_ids_with_extra_body() { "messages": [{"role": "user", "content": "Hello"}] }, "stream": false, - "extra_body": extra_body_value + "extra_body": extra_body_value, + "params": { + "chat_completion": { + "temperature": 0.7, + "max_tokens": 100, + "seed": 42 + } + } }); // Make the inference request @@ -1215,8 +1220,10 @@ pub async fn test_get_by_ids_with_extra_body() { assert_eq!(res.len(), 1); + let inference = &res[0]; + // Assert the extra_body is correctly returned - let extra_body = &res[0]["extra_body"]; + let extra_body = &inference["extra_body"]; assert!(extra_body.is_array(), "extra_body should be an array"); let extra_body_array = extra_body.as_array().unwrap(); @@ -1229,4 +1236,136 @@ pub async fn test_get_by_ids_with_extra_body() { // Check the second extra_body entry (nested value) assert_eq!(extra_body_array[1]["pointer"], "/nested/field"); assert_eq!(extra_body_array[1]["value"]["key"], "nested_value"); + + // Assert the inference_params are correctly returned + let inference_params = &inference["inference_params"]; + assert!( + inference_params.is_object(), + "inference_params should be an object" + ); + + let chat_completion_params = &inference_params["chat_completion"]; + assert!( + chat_completion_params.is_object(), + "chat_completion should be an object" + ); + assert_eq!(chat_completion_params["temperature"], 0.7); + assert_eq!(chat_completion_params["max_tokens"], 100); + assert_eq!(chat_completion_params["seed"], 42); + + // Assert processing_time_ms is present (should be 
non-null for a completed inference) + let processing_time_ms = &inference["processing_time_ms"]; + assert!( + processing_time_ms.is_u64(), + "processing_time_ms should be a positive integer, but got: {processing_time_ms}" + ); + + // ttft_ms can be null for non-streaming requests, but we still check it's present + assert!( + inference.get("ttft_ms").is_some(), + "ttft_ms field should be present" + ); +} + +#[tokio::test(flavor = "multi_thread")] +pub async fn test_get_by_ids_json_function_with_inference_params() { + let http_client = Client::new(); + + // Create a JSON inference with extra_body and inference params + let extra_body_value = json!([ + {"pointer": "/json_test_field", "value": "json_test_value"}, + {"pointer": "/json_nested/field", "value": {"key": "json_nested_value"}} + ]); + + let inference_payload = json!({ + "function_name": "json_success", + "variant_name": "test", + "input": { + "system": {"assistant_name": "TestBot"}, + "messages": [{"role": "user", "content": [{"type": "template", "name": "user", "arguments": {"country": "France"}}]}] + }, + "stream": false, + "extra_body": extra_body_value, + "params": { + "chat_completion": { + "temperature": 0.5, + "max_tokens": 200, + "top_p": 0.9 + } + } + }); + + // Make the inference request + let inference_response = http_client + .post(get_gateway_endpoint("/inference")) + .json(&inference_payload) + .send() + .await + .unwrap(); + + assert!( + inference_response.status().is_success(), + "Inference request failed: status={:?}, body={:?}", + inference_response.status(), + inference_response.text().await + ); + + let inference_json: Value = inference_response.json().await.unwrap(); + let inference_id = Uuid::parse_str(inference_json["inference_id"].as_str().unwrap()).unwrap(); + + // Query the inference back + let res = get_inferences_by_ids(vec![inference_id], InferenceOutputSource::Inference) + .await + .unwrap(); + + assert_eq!(res.len(), 1); + + let inference = &res[0]; + + // Assert type is JSON + assert_eq!(inference["type"], "json"); + + // Assert the extra_body is correctly returned + let extra_body = &inference["extra_body"]; + assert!(extra_body.is_array(), "extra_body should be an array"); + + let extra_body_array = extra_body.as_array().unwrap(); + assert_eq!(extra_body_array.len(), 2); + + // Check the first extra_body entry + assert_eq!(extra_body_array[0]["pointer"], "/json_test_field"); + assert_eq!(extra_body_array[0]["value"], "json_test_value"); + + // Check the second extra_body entry (nested value) + assert_eq!(extra_body_array[1]["pointer"], "/json_nested/field"); + assert_eq!(extra_body_array[1]["value"]["key"], "json_nested_value"); + + // Assert the inference_params are correctly returned + let inference_params = &inference["inference_params"]; + assert!( + inference_params.is_object(), + "inference_params should be an object" + ); + + let chat_completion_params = &inference_params["chat_completion"]; + assert!( + chat_completion_params.is_object(), + "chat_completion should be an object" + ); + assert_eq!(chat_completion_params["temperature"], 0.5); + assert_eq!(chat_completion_params["max_tokens"], 200); + assert_eq!(chat_completion_params["top_p"], 0.9); + + // Assert processing_time_ms is present + let processing_time_ms = &inference["processing_time_ms"]; + assert!( + processing_time_ms.is_u64(), + "processing_time_ms should be a positive integer, but got: {processing_time_ms}" + ); + + // ttft_ms can be null for non-streaming requests + assert!( + inference.get("ttft_ms").is_some(), + "ttft_ms 
field should be present" + ); } diff --git a/tensorzero-core/tests/e2e/render_inferences.rs b/tensorzero-core/tests/e2e/render_inferences.rs index 23e958b98a..1e2dbf845e 100644 --- a/tensorzero-core/tests/e2e/render_inferences.rs +++ b/tensorzero-core/tests/e2e/render_inferences.rs @@ -63,6 +63,9 @@ pub async fn test_render_samples_no_function() { dispreferred_outputs: vec![], tags: HashMap::from([("test_key".to_string(), "test_value".to_string())]), extra_body: Default::default(), + inference_params: Default::default(), + processing_time_ms: None, + ttft_ms: None, })]; let rendered_inferences = client @@ -100,6 +103,9 @@ pub async fn test_render_samples_no_variant() { dispreferred_outputs: vec![], tags: HashMap::new(), extra_body: Default::default(), + inference_params: Default::default(), + processing_time_ms: None, + ttft_ms: None, })]; let error = client @@ -150,6 +156,9 @@ pub async fn test_render_samples_missing_variable() { dispreferred_outputs: vec![], tags: HashMap::new(), extra_body: Default::default(), + inference_params: Default::default(), + processing_time_ms: None, + ttft_ms: None, })]; let rendered_inferences = client @@ -192,6 +201,9 @@ pub async fn test_render_samples_normal() { dispreferred_outputs: vec![], tags: HashMap::new(), extra_body: Default::default(), + inference_params: Default::default(), + processing_time_ms: None, + ttft_ms: None, }), StoredInferenceDatabase::Json(StoredJsonInference { function_name: "json_success".to_string(), @@ -226,6 +238,9 @@ pub async fn test_render_samples_normal() { }], tags: HashMap::new(), extra_body: Default::default(), + inference_params: Default::default(), + processing_time_ms: None, + ttft_ms: None, }), StoredInferenceDatabase::Chat(StoredChatInferenceDatabase { function_name: "weather_helper".to_string(), @@ -274,6 +289,9 @@ pub async fn test_render_samples_normal() { })]], tags: HashMap::new(), extra_body: Default::default(), + inference_params: Default::default(), + processing_time_ms: None, + ttft_ms: None, }), StoredInferenceDatabase::Chat(StoredChatInferenceDatabase { function_name: "basic_test".to_string(), @@ -320,6 +338,9 @@ pub async fn test_render_samples_normal() { dispreferred_outputs: vec![], tags: HashMap::new(), extra_body: Default::default(), + inference_params: Default::default(), + processing_time_ms: None, + ttft_ms: None, }), ]; @@ -514,6 +535,9 @@ pub async fn test_render_samples_template_no_schema() { dispreferred_outputs: vec![], tags: HashMap::new(), extra_body: Default::default(), + inference_params: Default::default(), + processing_time_ms: None, + ttft_ms: None, })]; let rendered_inferences = client From 7d745f7863191a5ee12b27e6a2dbd9ae15043555 Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Mon, 8 Dec 2025 22:11:15 -0500 Subject: [PATCH 06/11] fixed types --- clients/python/tensorzero/generated_types.py | 38 +++++++++++++++++++ .../lib/bindings/StoredChatInference.ts | 4 ++ .../lib/bindings/StoredJsonInference.ts | 4 ++ .../inference/InferenceDetailContent.tsx | 4 +- .../inference/VariantResponseModal.tsx | 4 +- .../api/inference/$inference_id/route.ts | 20 +++++----- .../routes/api/tensorzero/inference.utils.tsx | 7 +++- 7 files changed, 62 insertions(+), 19 deletions(-) diff --git a/clients/python/tensorzero/generated_types.py b/clients/python/tensorzero/generated_types.py index 5d5d32f8c0..adce3a2efb 100644 --- a/clients/python/tensorzero/generated_types.py +++ b/clients/python/tensorzero/generated_types.py @@ -570,6 +570,9 @@ class JsonInferenceOutput: """ +JsonMode = Literal["off", "on", 
"strict", "tool"] + + @dataclass(kw_only=True) class OpenAICustomToolFormatText: type: Literal["text"] = "text" @@ -600,6 +603,9 @@ class RawText: Role = Literal["user", "assistant"] +ServiceTier = Literal["auto", "default", "priority", "flex"] + + @dataclass(kw_only=True) class StorageKindS3Compatible: """ @@ -934,6 +940,22 @@ class Base64File: source_url: str | None = None +@dataclass(kw_only=True) +class ChatCompletionInferenceParams: + frequency_penalty: float | None = None + json_mode: JsonMode | None = None + max_tokens: int | None = None + presence_penalty: float | None = None + reasoning_effort: str | None = None + seed: int | None = None + service_tier: ServiceTier | None = None + stop_sequences: list[str] | None = None + temperature: float | None = None + thinking_budget_tokens: int | None = None + top_p: float | None = None + verbosity: str | None = None + + @dataclass(kw_only=True) class ContentBlockChatOutputToolCall(InferenceResponseToolCall): """ @@ -1064,6 +1086,16 @@ class InferenceFilterTime(TimeFilter): type: Literal["time"] = "time" +@dataclass(kw_only=True) +class InferenceParams: + """ + InferenceParams is the top-level struct for inference parameters. + We backfill these from the configs given in the variants used and ultimately write them to the database. + """ + + chat_completion: ChatCompletionInferenceParams + + @dataclass(kw_only=True) class InputMessageContentToolCall: type: Literal["tool_call"] = "tool_call" @@ -1559,7 +1591,10 @@ class StoredJsonInference: variant_name: str dispreferred_outputs: list[JsonInferenceOutput] | None = field(default_factory=lambda: []) extra_body: UnfilteredInferenceExtraBody | None = field(default_factory=lambda: []) + inference_params: InferenceParams | None = field(default_factory=lambda: {"chat_completion": {}}) + processing_time_ms: int | None = None tags: dict[str, str] | None = field(default_factory=lambda: {}) + ttft_ms: int | None = None @dataclass(kw_only=True) @@ -1888,6 +1923,7 @@ class StoredChatInference: episode_id: str function_name: str inference_id: str + inference_params: InferenceParams input: StoredInput output: list[ContentBlockChatOutput] timestamp: str @@ -1909,6 +1945,7 @@ class StoredChatInference: Whether to use parallel tool calls in the inference. Optional. If provided during inference, it will override the function-configured parallel tool calls. """ + processing_time_ms: int | None = None provider_tools: list[ProviderTool] | None = field(default_factory=lambda: []) """ Provider-specific tool configurations @@ -1919,6 +1956,7 @@ class StoredChatInference: User-specified tool choice strategy. If provided during inference, it will override the function-configured tool choice. Optional. """ + ttft_ms: int | None = None @dataclass(kw_only=True) diff --git a/internal/tensorzero-node/lib/bindings/StoredChatInference.ts b/internal/tensorzero-node/lib/bindings/StoredChatInference.ts index 283a677f5f..7a5b83dd7e 100644 --- a/internal/tensorzero-node/lib/bindings/StoredChatInference.ts +++ b/internal/tensorzero-node/lib/bindings/StoredChatInference.ts @@ -1,5 +1,6 @@ // This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. 
import type { ContentBlockChatOutput } from "./ContentBlockChatOutput"; +import type { InferenceParams } from "./InferenceParams"; import type { ProviderTool } from "./ProviderTool"; import type { StoredInput } from "./StoredInput"; import type { Tool } from "./Tool"; @@ -20,6 +21,9 @@ export type StoredChatInference = { inference_id: string; tags: { [key in string]?: string }; extra_body: UnfilteredInferenceExtraBody; + inference_params: InferenceParams; + processing_time_ms: bigint | null; + ttft_ms: bigint | null; /** * A subset of static tools configured for the function that the inference is allowed to use. Optional. * If not provided, all static tools are allowed. diff --git a/internal/tensorzero-node/lib/bindings/StoredJsonInference.ts b/internal/tensorzero-node/lib/bindings/StoredJsonInference.ts index 2b3fe66973..72dbfb9b0c 100644 --- a/internal/tensorzero-node/lib/bindings/StoredJsonInference.ts +++ b/internal/tensorzero-node/lib/bindings/StoredJsonInference.ts @@ -1,4 +1,5 @@ // This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { InferenceParams } from "./InferenceParams"; import type { JsonInferenceOutput } from "./JsonInferenceOutput"; import type { StoredInput } from "./StoredInput"; import type { UnfilteredInferenceExtraBody } from "./UnfilteredInferenceExtraBody"; @@ -16,4 +17,7 @@ export type StoredJsonInference = { output_schema: JsonValue; tags: { [key in string]?: string }; extra_body: UnfilteredInferenceExtraBody; + inference_params: InferenceParams; + processing_time_ms: bigint | null; + ttft_ms: bigint | null; }; diff --git a/ui/app/components/inference/InferenceDetailContent.tsx b/ui/app/components/inference/InferenceDetailContent.tsx index 6dbfcf74a9..4b4336721c 100644 --- a/ui/app/components/inference/InferenceDetailContent.tsx +++ b/ui/app/components/inference/InferenceDetailContent.tsx @@ -334,9 +334,7 @@ export function InferenceDetailContent({ - + diff --git a/ui/app/components/inference/VariantResponseModal.tsx b/ui/app/components/inference/VariantResponseModal.tsx index 829259ecdb..0b63598785 100644 --- a/ui/app/components/inference/VariantResponseModal.tsx +++ b/ui/app/components/inference/VariantResponseModal.tsx @@ -161,9 +161,7 @@ export function VariantResponseModal({ // Get original variant name if available (only for inferences) const originalVariant = - source === "inference" - ? (item as StoredInference).variant_name - : undefined; + source === "inference" ? (item as StoredInference).variant_name : undefined; const refreshButton = onRefresh && (