1 change: 1 addition & 0 deletions .gitignore
@@ -57,3 +57,4 @@ tensorzero-optimizers/bindings
clients/rust/bindings/
internal/tensorzero-node/bindings/
*.junit.xml
tensorzero-core/schemas
2 changes: 2 additions & 0 deletions clients/python/tensorzero/__init__.py
@@ -17,6 +17,7 @@
AlwaysExtraBodyDelete,
AlwaysExtraHeader,
AlwaysExtraHeaderDelete,
ChatCompletionInferenceParams,
ContentBlockChatOutput,
ContentBlockChatOutputText,
ContentBlockChatOutputToolCall,
@@ -41,6 +42,7 @@
InferenceFilterOr,
InferenceFilterTag,
InferenceFilterTime,
InferenceParams,
Input,
InputMessage,
InputMessageContentTemplate,
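With these additions, both new parameter types are importable from the package root; a minimal sketch:

    from tensorzero import ChatCompletionInferenceParams, InferenceParams
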
43 changes: 43 additions & 0 deletions clients/python/tensorzero/generated_types.py
@@ -570,6 +570,9 @@ class JsonInferenceOutput:
"""


JsonMode = Literal["off", "on", "strict", "tool"]


@dataclass(kw_only=True)
class OpenAICustomToolFormatText:
type: Literal["text"] = "text"
@@ -600,6 +603,9 @@ class RawText:
Role = Literal["user", "assistant"]


ServiceTier = Literal["auto", "default", "priority", "flex"]


@dataclass(kw_only=True)
class StorageKindS3Compatible:
"""
@@ -840,6 +846,9 @@ class ToolResult:
result: str


UnfilteredInferenceExtraBody = list[ExtraBody]


@dataclass(kw_only=True)
class Unknown:
"""
@@ -931,6 +940,22 @@ class Base64File:
source_url: str | None = None


@dataclass(kw_only=True)
class ChatCompletionInferenceParams:
frequency_penalty: float | None = None
json_mode: JsonMode | None = None
max_tokens: int | None = None
presence_penalty: float | None = None
reasoning_effort: str | None = None
seed: int | None = None
service_tier: ServiceTier | None = None
stop_sequences: list[str] | None = None
temperature: float | None = None
thinking_budget_tokens: int | None = None
top_p: float | None = None
verbosity: str | None = None
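
A minimal construction sketch (hypothetical values; every field is optional and defaults to None, so callers set only the overrides they need):

    params = ChatCompletionInferenceParams(
        temperature=0.7,  # hypothetical sampling override
        max_tokens=256,
        json_mode="strict",  # JsonMode: "off" | "on" | "strict" | "tool"
        service_tier="default",  # ServiceTier: "auto" | "default" | "priority" | "flex"
    )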


@dataclass(kw_only=True)
class ContentBlockChatOutputToolCall(InferenceResponseToolCall):
"""
@@ -1061,6 +1086,16 @@ class InferenceFilterTime(TimeFilter):
type: Literal["time"] = "time"


@dataclass(kw_only=True)
class InferenceParams:
"""
InferenceParams is the top-level struct for inference parameters.
We backfill these from the configuration of the variants used and ultimately write them to the database.
"""

chat_completion: ChatCompletionInferenceParams
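
As the test fixtures later in this diff show, the struct is built by nesting the chat-completion parameters; a sketch with hypothetical overrides:

    inference_params = InferenceParams(
        chat_completion=ChatCompletionInferenceParams(temperature=0.2, seed=42)
    )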


@dataclass(kw_only=True)
class InputMessageContentToolCall:
type: Literal["tool_call"] = "tool_call"
@@ -1555,7 +1590,11 @@ class StoredJsonInference:
timestamp: str
variant_name: str
dispreferred_outputs: list[JsonInferenceOutput] | None = field(default_factory=lambda: [])
extra_body: UnfilteredInferenceExtraBody | None = field(default_factory=lambda: [])
inference_params: InferenceParams | None = None
processing_time_ms: int | None = None
tags: dict[str, str] | None = field(default_factory=lambda: {})
ttft_ms: int | None = None
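
A sketch of reading the new optional fields back (`stored` is a hypothetical StoredJsonInference instance retrieved elsewhere):

    if stored.inference_params is not None:
        print(stored.inference_params.chat_completion.temperature)
    if stored.ttft_ms is not None:
        print(f"time to first token: {stored.ttft_ms} ms")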


@dataclass(kw_only=True)
@@ -1884,6 +1923,7 @@ class StoredChatInference:
episode_id: str
function_name: str
inference_id: str
inference_params: InferenceParams
input: StoredInput
output: list[ContentBlockChatOutput]
timestamp: str
@@ -1899,11 +1939,13 @@
If not provided, all static tools are allowed.
"""
dispreferred_outputs: list[list[ContentBlockChatOutput]] | None = field(default_factory=lambda: [])
extra_body: UnfilteredInferenceExtraBody | None = field(default_factory=lambda: [])
parallel_tool_calls: bool | None = None
"""
Whether to use parallel tool calls in the inference. Optional.
If provided during inference, it will override the function-configured parallel tool calls.
"""
processing_time_ms: int | None = None
provider_tools: list[ProviderTool] | None = field(default_factory=lambda: [])
"""
Provider-specific tool configurations
@@ -1914,6 +1956,7 @@
User-specified tool choice strategy. If provided during inference, it will override the function-configured tool choice.
Optional.
"""
ttft_ms: int | None = None


@dataclass(kw_only=True)
6 changes: 6 additions & 0 deletions clients/python/tests/conftest.py
@@ -12,6 +12,8 @@
from openai import AsyncOpenAI
from pytest import FixtureRequest
from tensorzero import (
ChatCompletionInferenceParams,
InferenceParams,
AsyncTensorZeroGateway,
ChatDatapointInsert,
ContentBlockChatOutputText,
@@ -124,6 +126,7 @@ def mixed_rendered_samples(
episode_id=str(uuid7()),
inference_id=str(uuid7()),
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
additional_tools=[
FunctionTool(
name="test",
@@ -157,6 +160,7 @@ def mixed_rendered_samples(
episode_id=str(uuid7()),
inference_id=str(uuid7()),
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
output_schema={
"type": "object",
"properties": {"answer": {"type": "string"}},
@@ -192,6 +196,7 @@ def chat_function_rendered_samples(
episode_id=str(uuid7()),
inference_id=str(uuid7()),
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
tool_choice="none",
parallel_tool_calls=False,
dispreferred_outputs=[],
@@ -228,6 +233,7 @@ def json_function_rendered_samples(
episode_id=str(uuid7()),
inference_id=str(uuid7()),
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
output_schema={
"type": "object",
"properties": {"answer": {"type": "string"}},
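Since every fixture above passes the same default parameters, a small shared helper (hypothetical, not part of this diff) could cut the repetition:

    def default_inference_params() -> InferenceParams:
        return InferenceParams(chat_completion=ChatCompletionInferenceParams())
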
14 changes: 14 additions & 0 deletions clients/python/tests/test_rendering.py
@@ -2,6 +2,8 @@

import pytest
from tensorzero import (
ChatCompletionInferenceParams,
InferenceParams,
AsyncTensorZeroGateway,
ContentBlockChatOutputText,
FileBase64,
@@ -101,6 +103,7 @@ def test_sync_render_samples_success(embedded_sync_client: TensorZeroGateway):
dispreferred_outputs=[[ContentBlockChatOutputText(text="goodbye")]],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
),
StoredInferenceJson(
function_name="json_success",
@@ -124,6 +127,7 @@
dispreferred_outputs=[JsonInferenceOutput(parsed={"answer": "Kyoto"}, raw='{"answer": "Kyoto"}')],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
),
],
variants={"basic_test": "test", "json_success": "test"},
@@ -276,6 +280,7 @@ def test_sync_render_samples_nonexistent_function(
dispreferred_outputs=[],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
)
],
variants={},
@@ -310,6 +315,7 @@ def test_sync_render_samples_unspecified_function(
dispreferred_outputs=[],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
)
],
variants={},
@@ -342,6 +348,7 @@ def test_sync_render_samples_no_variant(embedded_sync_client: TensorZeroGateway)
dispreferred_outputs=[],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
)
],
variants={"basic_test": "non_existent_variant"},
@@ -376,6 +383,7 @@ def test_sync_render_samples_missing_variable(
dispreferred_outputs=[],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
)
],
variants={"basic_test": "test"},
@@ -458,6 +466,7 @@ async def test_async_render_samples_success(
dispreferred_outputs=[],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
),
StoredInferenceJson(
function_name="json_success",
@@ -484,6 +493,7 @@
dispreferred_outputs=[],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
),
],
variants={"basic_test": "test", "json_success": "test"},
@@ -634,6 +644,7 @@ async def test_async_render_samples_nonexistent_function(
dispreferred_outputs=[],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
)
],
variants={},
@@ -669,6 +680,7 @@ async def test_async_render_samples_unspecified_function(
dispreferred_outputs=[],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
)
],
variants={},
@@ -704,6 +716,7 @@ async def test_async_render_samples_no_variant(
dispreferred_outputs=[],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
)
],
variants={"basic_test": "non_existent_variant"},
@@ -739,6 +752,7 @@ async def test_async_render_samples_missing_variable(
dispreferred_outputs=[],
tags={},
timestamp=datetime.now(timezone.utc).isoformat(),
inference_params=InferenceParams(chat_completion=ChatCompletionInferenceParams()),
)
],
variants={"basic_test": "test"},
3 changes: 2 additions & 1 deletion clients/rust/src/lib.rs
@@ -66,7 +66,8 @@ pub use tensorzero_core::endpoints::datasets::{
pub use tensorzero_core::endpoints::feedback::FeedbackResponse;
pub use tensorzero_core::endpoints::feedback::Params as FeedbackParams;
pub use tensorzero_core::endpoints::inference::{
-    InferenceOutput, InferenceParams, InferenceResponse, InferenceResponseChunk, InferenceStream,
+    ChatCompletionInferenceParams, InferenceOutput, InferenceParams, InferenceResponse,
+    InferenceResponseChunk, InferenceStream,
};
pub use tensorzero_core::endpoints::object_storage::ObjectResponse;
pub use tensorzero_core::endpoints::stored_inferences::v1::types::{