Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import os
from typing import Any, Optional, Union

from openai.lib._pydantic import to_strict_json_schema
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a different way to import this function that doesn't go through a private file? I'm a little worried the import path is subject to break/change

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, for now seems like this is the only way to import this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I think we can do this a different way. We should be able to directly use the one from pydantic which looks like this parameters_schema = model.model_json_schema(). This is from the _create_tool_parameters_schema function in ComponentTool

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OpenAI expect a stricter JSON Schema than Pydantic’s default. For example, objects must set additionalProperties and Optional keys are handled differently. As a result, model_json_schema() often isn’t accepted as-is.
Its also discussed here where another solution is offered but I prefer using the openai method over some unpopular library.

Copy link
Contributor Author

@Amnah199 Amnah199 Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevertheless, I spotted a bug in to_dict where schema wasn't stored properly. Fixing this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh okay thanks for the info

from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI
from pydantic import BaseModel

from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.chat import OpenAIChatGenerator
Expand Down Expand Up @@ -123,6 +125,16 @@ def __init__( # pylint: disable=too-many-positional-arguments
Higher values make the model less likely to repeat the token.
- `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
- `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
If provided, the output will always be validated against this
format (unless the model returns a tool call).
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
Note:
- This parameter accepts Pydantic models and JSON schemas for latest models starting from GPT-4o.
Older models only support basic version of structured outputs through `{"type": "json_object"}`.
For detailed information on JSON mode, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs#json-mode).
- For structured outputs with streaming,
the `response_format` must be a JSON schema and not a Pydantic model.
:param default_headers: Default headers to use for the AzureOpenAI client.
:param tools:
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
Expand Down Expand Up @@ -201,6 +213,11 @@ def to_dict(self) -> dict[str, Any]:
azure_ad_token_provider_name = None
if self.azure_ad_token_provider:
azure_ad_token_provider_name = serialize_callable(self.azure_ad_token_provider)
# If the response format is a Pydantic model, its converted to openai's json schema format
# If its already a json schema, it's left as is
response_format = self.generation_kwargs.get("response_format")
if response_format and issubclass(response_format, BaseModel):
self.generation_kwargs["response_format"] = to_strict_json_schema(response_format)
return default_to_dict(
self,
azure_endpoint=self.azure_endpoint,
Expand Down
68 changes: 60 additions & 8 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@
from typing import Any, Optional, Union

from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
from openai.lib._pydantic import to_strict_json_schema
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
ChatCompletionMessageCustomToolCall,
ParsedChatCompletion,
ParsedChatCompletionMessage,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from pydantic import BaseModel

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
Expand Down Expand Up @@ -138,6 +142,16 @@ def __init__( # pylint: disable=too-many-positional-arguments
Bigger values mean the model will be less likely to repeat the same token in the text.
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
- `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
If provided, the output will always be validated against this
format (unless the model returns a tool call).
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
Note:
- This parameter accepts Pydantic models and JSON schemas for latest models starting from GPT-4o.
Older models only support basic version of structured outputs through `{"type": "json_object"}`.
For detailed information on JSON mode, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs#json-mode).
- For structured outputs with streaming,
the `response_format` must be a JSON schema and not a Pydantic model.
:param timeout:
Timeout for OpenAI client calls. If not set, it defaults to either the
`OPENAI_TIMEOUT` environment variable, or 30 seconds.
Expand All @@ -153,6 +167,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
:param http_client_kwargs:
A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).

"""
self.api_key = api_key
self.model = model
Expand Down Expand Up @@ -200,6 +215,13 @@ def to_dict(self) -> dict[str, Any]:
The serialized component as a dictionary.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
response_format = self.generation_kwargs.get("response_format")

# If the response format is a Pydantic model, its converted to openai's json schema format
# If its already a json schema, it's left as is
if response_format and issubclass(response_format, BaseModel):
self.generation_kwargs["response_format"] = to_strict_json_schema(response_format)

return default_to_dict(
self,
model=self.model,
Expand Down Expand Up @@ -272,6 +294,7 @@ def run(
streaming_callback = select_streaming_callback(
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
)
chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion, ParsedChatCompletion]

api_args = self._prepare_api_call(
messages=messages,
Expand All @@ -280,9 +303,11 @@ def run(
tools=tools,
tools_strict=tools_strict,
)
chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
**api_args
)
openai_endpoint = api_args.pop("openai_endpoint")
if openai_endpoint == "create":
chat_completion = self.client.chat.completions.create(**api_args)
elif openai_endpoint == "parse":
chat_completion = self.client.chat.completions.parse(**api_args)

if streaming_callback is not None:
completions = self._handle_stream_response(
Expand Down Expand Up @@ -346,6 +371,7 @@ async def run_async(
streaming_callback = select_streaming_callback(
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
)
chat_completion: Union[AsyncStream[ChatCompletionChunk], ChatCompletion, ParsedChatCompletion]

if len(messages) == 0:
return {"replies": []}
Expand All @@ -358,9 +384,11 @@ async def run_async(
tools_strict=tools_strict,
)

chat_completion: Union[
AsyncStream[ChatCompletionChunk], ChatCompletion
] = await self.async_client.chat.completions.create(**api_args)
openai_endpoint = api_args.pop("openai_endpoint")
if openai_endpoint == "create":
chat_completion = await self.async_client.chat.completions.create(**api_args)
elif openai_endpoint == "parse":
chat_completion = await self.async_client.chat.completions.parse(**api_args)

if streaming_callback is not None:
completions = await self._handle_async_stream_response(
Expand Down Expand Up @@ -393,6 +421,7 @@ def _prepare_api_call( # noqa: PLR0913
) -> dict[str, Any]:
# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
response_format = generation_kwargs.get("response_format") if generation_kwargs else None

# adapt ChatMessage(s) to the format expected by the OpenAI API
openai_formatted_messages = [message.to_openai_dict_format() for message in messages]
Expand All @@ -416,14 +445,35 @@ def _prepare_api_call( # noqa: PLR0913

is_streaming = streaming_callback is not None
num_responses = generation_kwargs.pop("n", 1)

if response_format and not is_streaming:
# for structured outputs without streaming, we use openai's parse endpoint
# Note: `stream` cannot be passed to chat.completions.parse
# we pass a key `openai_endpoint` as a hint to the run method to use the parse endpoint
# this key will be removed before the API call is made
return {
"model": self.model,
"messages": openai_formatted_messages,
"n": num_responses,
"response_format": response_format,
"openai_endpoint": "parse",
**openai_tools,
**generation_kwargs,
}

if is_streaming and num_responses > 1:
raise ValueError("Cannot stream multiple responses, please set n=1.")

# for structured outputs with streaming, we use openai's create endpoint
# we pass a key `openai_endpoint` as a hint to the run method to use the create endpoint
# this key will be removed before the API call is made
return {
"model": self.model,
"messages": openai_formatted_messages,
"stream": streaming_callback is not None,
"n": num_responses,
"response_format": response_format,
"openai_endpoint": "create",
**openai_tools,
**generation_kwargs,
}
Expand Down Expand Up @@ -471,15 +521,17 @@ def _check_finish_reason(meta: dict[str, Any]) -> None:
)


def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice: Choice) -> ChatMessage:
def _convert_chat_completion_to_chat_message(
completion: Union[ChatCompletion, ParsedChatCompletion], choice: Choice
) -> ChatMessage:
"""
Converts the non-streaming response from the OpenAI API to a ChatMessage.

:param completion: The completion returned by the OpenAI API.
:param choice: The choice returned by the OpenAI API.
:return: The ChatMessage.
"""
message: ChatCompletionMessage = choice.message
message: Union[ChatCompletionMessage, ParsedChatCompletionMessage] = choice.message
text = message.content
tool_calls = []
if message.tool_calls:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
features:
- |
`OpenAIChatGenerator` and `AzureOpenAIChatGenerator` now support structured outputs using `response_format`
parameter that can be passed in `generation_kwargs`.
The `response_format` parameter can be a Pydantic model or a JSON schema for non-streaming responses. For streaming responses, the `response_format` must be a JSON schema.
Example usage of the `response_format` parameter:
```python
from pydantic import BaseModel
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

class NobelPrizeInfo(BaseModel):
recipient_name: str
award_year: int
category: str
achievement_description: str
nationality: str

client = OpenAIChatGenerator(
model="gpt-4o-2024-08-06",
generation_kwargs={"response_format": NobelPrizeInfo}
)

response = client.run(messages=[
ChatMessage.from_user("Give me information about the 20th Nobel Peace Prize winner.")
])
print(response)

```
62 changes: 59 additions & 3 deletions test/components/generators/chat/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

import json
import os
from typing import Any, Optional

import pytest
from openai import OpenAIError
from pydantic import BaseModel

from haystack import Pipeline, component
from haystack.components.generators.chat import AzureOpenAIChatGenerator
Expand All @@ -18,6 +20,17 @@
from haystack.utils.azure import default_azure_ad_token_provider


class CalendarEvent(BaseModel):
event_name: str
event_date: str
event_location: str


@pytest.fixture
def calendar_event_model():
return CalendarEvent


def get_weather(city: str) -> dict[str, Any]:
weather_info = {
"Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
Expand Down Expand Up @@ -141,7 +154,7 @@ def test_to_dict_default(self, monkeypatch):
},
}

def test_to_dict_with_parameters(self, monkeypatch):
def test_to_dict_with_parameters(self, monkeypatch, calendar_event_model):
monkeypatch.setenv("ENV_VAR", "test-api-key")
component = AzureOpenAIChatGenerator(
api_key=Secret.from_env_var("ENV_VAR", strict=False),
Expand All @@ -150,7 +163,11 @@ def test_to_dict_with_parameters(self, monkeypatch):
streaming_callback=print_streaming_chunk,
timeout=2.5,
max_retries=10,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
generation_kwargs={
"max_tokens": 10,
"some_test_param": "test-params",
"response_format": calendar_event_model,
},
azure_ad_token_provider=default_azure_ad_token_provider,
http_client_kwargs={"proxy": "http://localhost:8080"},
)
Expand All @@ -167,7 +184,21 @@ def test_to_dict_with_parameters(self, monkeypatch):
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"timeout": 2.5,
"max_retries": 10,
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"generation_kwargs": {
"max_tokens": 10,
"some_test_param": "test-params",
"response_format": {
"properties": {
"event_name": {"title": "Event Name", "type": "string"},
"event_date": {"title": "Event Date", "type": "string"},
"event_location": {"title": "Event Location", "type": "string"},
},
"required": ["event_name", "event_date", "event_location"],
"title": "CalendarEvent",
"type": "object",
"additionalProperties": False,
},
},
"tools": None,
"tools_strict": False,
"default_headers": {},
Expand Down Expand Up @@ -331,6 +362,31 @@ def test_live_run_with_tools(self, tools):
assert tool_call.arguments == {"city": "Paris"}
assert message.meta["finish_reason"] == "tool_calls"

@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None),
reason="Export an env var called AZURE_OPENAI_API_KEY containing the Azure OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_with_response_format(self):
class CalendarEvent(BaseModel):
event_name: str
event_date: str
event_location: str

chat_messages = [ChatMessage.from_user("Give me information about the 20th Nobel Peace Prize.")]
component = AzureOpenAIChatGenerator(
api_version="2024-08-01-preview", generation_kwargs={"response_format": CalendarEvent}
)
results = component.run(chat_messages)
assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
msg = json.loads(message.text)
assert "20th Nobel Peace Prize" in msg["event_name"]
assert isinstance(msg["event_date"], str)
assert isinstance(msg["event_location"], str)

assert message.meta["finish_reason"] == "stop"

def test_to_dict_with_toolset(self, tools, monkeypatch):
"""Test that the AzureOpenAIChatGenerator can be serialized to a dictionary with a Toolset."""
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
Expand Down
Loading
Loading