Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 86 additions & 19 deletions langfuse/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,20 @@ class OpenAiDefinition:
sync=False,
min_version="1.66.0",
),
OpenAiDefinition(
module="openai.resources.embeddings",
object="Embeddings",
method="create",
type="embedding",
sync=True,
),
OpenAiDefinition(
module="openai.resources.embeddings",
object="AsyncEmbeddings",
method="create",
type="embedding",
sync=False,
),
]


Expand Down Expand Up @@ -340,10 +354,13 @@ def _extract_chat_response(kwargs: Any) -> Any:


def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> Any:
name = kwargs.get("name", "OpenAI-generation")
default_name = (
"OpenAI-embedding" if resource.type == "embedding" else "OpenAI-generation"
)
name = kwargs.get("name", default_name)

if name is None:
name = "OpenAI-generation"
name = default_name

if name is not None and not isinstance(name, str):
raise TypeError("name must be a string")
Expand Down Expand Up @@ -395,6 +412,8 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> A
prompt = kwargs.get("input", None)
elif resource.type == "chat":
prompt = _extract_chat_prompt(kwargs)
elif resource.type == "embedding":
prompt = kwargs.get("input", None)

parsed_temperature = (
kwargs.get("temperature", 1)
Expand Down Expand Up @@ -440,23 +459,41 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> A

parsed_n = kwargs.get("n", 1) if not isinstance(kwargs.get("n", 1), NotGiven) else 1

modelParameters = {
"temperature": parsed_temperature,
"max_tokens": parsed_max_tokens, # casing?
"top_p": parsed_top_p,
"frequency_penalty": parsed_frequency_penalty,
"presence_penalty": parsed_presence_penalty,
}
if resource.type == "embedding":
parsed_dimensions = (
kwargs.get("dimensions", None)
if not isinstance(kwargs.get("dimensions", None), NotGiven)
else None
)
parsed_encoding_format = (
kwargs.get("encoding_format", "float")
if not isinstance(kwargs.get("encoding_format", "float"), NotGiven)
else "float"
)

if parsed_max_completion_tokens is not None:
modelParameters.pop("max_tokens", None)
modelParameters["max_completion_tokens"] = parsed_max_completion_tokens
modelParameters = {}
if parsed_dimensions is not None:
modelParameters["dimensions"] = parsed_dimensions
if parsed_encoding_format != "float":
modelParameters["encoding_format"] = parsed_encoding_format
else:
modelParameters = {
"temperature": parsed_temperature,
"max_tokens": parsed_max_tokens,
"top_p": parsed_top_p,
"frequency_penalty": parsed_frequency_penalty,
"presence_penalty": parsed_presence_penalty,
}

if parsed_n is not None and parsed_n > 1:
modelParameters["n"] = parsed_n
if parsed_max_completion_tokens is not None:
modelParameters.pop("max_tokens", None)
modelParameters["max_completion_tokens"] = parsed_max_completion_tokens

if parsed_seed is not None:
modelParameters["seed"] = parsed_seed
if parsed_n is not None and parsed_n > 1:
modelParameters["n"] = parsed_n

if parsed_seed is not None:
modelParameters["seed"] = parsed_seed

langfuse_prompt = kwargs.get("langfuse_prompt", None)

Expand Down Expand Up @@ -521,6 +558,14 @@ def _parse_usage(usage: Optional[Any] = None) -> Any:
k: v for k, v in tokens_details_dict.items() if v is not None
}

if (
len(usage_dict) == 2
and "prompt_tokens" in usage_dict
and "total_tokens" in usage_dict
):
# handle embedding usage
return {"input": usage_dict["prompt_tokens"]}

return usage_dict


Expand Down Expand Up @@ -646,7 +691,7 @@ def _extract_streamed_openai_response(resource: Any, chunks: Any) -> Any:
curr[-1]["arguments"] = ""

curr[-1]["arguments"] += getattr(
tool_call_chunk, "arguments", None
tool_call_chunk, "arguments", ""
)

if resource.type == "completion":
Expand Down Expand Up @@ -729,6 +774,20 @@ def _get_langfuse_data_from_default_response(
else choice.get("message", None)
)

elif resource.type == "embedding":
data = response.get("data", [])
if len(data) > 0:
first_embedding = data[0]
embedding_vector = (
first_embedding.embedding
if hasattr(first_embedding, "embedding")
else first_embedding.get("embedding", [])
)
completion = {
"dimensions": len(embedding_vector) if embedding_vector else 0,
"count": len(data),
}

usage = _parse_usage(response.get("usage", None))

return (model, completion, usage)
Expand Down Expand Up @@ -757,8 +816,12 @@ def _wrap(
langfuse_data = _get_langfuse_data_from_kwargs(open_ai_resource, langfuse_args)
langfuse_client = get_client(public_key=langfuse_args["langfuse_public_key"])

observation_type = (
"embedding" if open_ai_resource.type == "embedding" else "generation"
)

generation = langfuse_client.start_observation(
as_type="generation",
as_type=observation_type, # type: ignore
name=langfuse_data["name"],
input=langfuse_data.get("input", None),
metadata=langfuse_data.get("metadata", None),
Expand Down Expand Up @@ -824,8 +887,12 @@ async def _wrap_async(
langfuse_data = _get_langfuse_data_from_kwargs(open_ai_resource, langfuse_args)
langfuse_client = get_client(public_key=langfuse_args["langfuse_public_key"])

observation_type = (
"embedding" if open_ai_resource.type == "embedding" else "generation"
)

generation = langfuse_client.start_observation(
as_type="generation",
as_type=observation_type, # type: ignore
name=langfuse_data["name"],
input=langfuse_data.get("input", None),
metadata=langfuse_data.get("metadata", None),
Expand Down
90 changes: 90 additions & 0 deletions tests/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,3 +1514,93 @@ def test_response_api_reasoning(openai):
assert generationData.usage.total is not None
assert generationData.output is not None
assert generationData.metadata is not None


def test_openai_embeddings(openai):
    """Sync embeddings call should be traced as an EMBEDDING observation.

    Checks name, metadata, input passthrough, model, timing, usage, and the
    summarized output ({"dimensions": ..., "count": ...}) on the observation.
    """
    observation_name = create_uuid()
    text = "The quick brown fox jumps over the lazy dog"

    openai.OpenAI().embeddings.create(
        name=observation_name,
        model="text-embedding-ada-002",
        input=text,
        metadata={"test_key": "test_value"},
    )

    langfuse.flush()
    sleep(1)

    observations = get_api().observations.get_many(
        name=observation_name, type="EMBEDDING"
    )
    assert len(observations.data) != 0

    obs = observations.data[0]
    assert obs.name == observation_name
    assert obs.metadata["test_key"] == "test_value"
    assert obs.input == text
    assert obs.type == "EMBEDDING"
    assert "text-embedding-ada-002" in obs.model
    # Timing must be recorded and well-ordered.
    assert obs.start_time is not None
    assert obs.end_time is not None
    assert obs.start_time < obs.end_time
    # Embedding usage maps prompt_tokens -> input; total must also be set.
    assert obs.usage.input is not None
    assert obs.usage.total is not None
    # Output is a summary dict, not the raw vector.
    assert obs.output is not None
    assert "dimensions" in obs.output
    assert "count" in obs.output
    assert obs.output["count"] == 1


def test_openai_embeddings_multiple_inputs(openai):
    """Batch embeddings call should record all inputs and a matching count.

    The observation's input must echo the full input list and the output
    summary's "count" must equal the number of inputs embedded.
    """
    observation_name = create_uuid()
    batch = ["The quick brown fox", "jumps over the lazy dog", "Hello world"]

    openai.OpenAI().embeddings.create(
        name=observation_name,
        model="text-embedding-ada-002",
        input=batch,
        metadata={"batch_size": len(batch)},
    )

    langfuse.flush()
    sleep(1)

    observations = get_api().observations.get_many(
        name=observation_name, type="EMBEDDING"
    )
    assert len(observations.data) != 0

    obs = observations.data[0]
    assert obs.name == observation_name
    assert obs.input == batch
    assert obs.type == "EMBEDDING"
    assert "text-embedding-ada-002" in obs.model
    assert obs.usage.input is not None
    assert obs.usage.total is not None
    # One API call with N inputs yields one observation with count == N.
    assert obs.output["count"] == len(batch)


@pytest.mark.asyncio
async def test_async_openai_embeddings(openai):
    """Async embeddings call should be traced as an EMBEDDING observation.

    Mirrors the sync test via ``AsyncOpenAI``: verifies name, input,
    metadata, model, and usage on the recorded observation.
    """
    # Fix: removed leftover debug `print` statements (embedding name and
    # result.usage) that polluted test output; the `result` binding existed
    # only to feed one of those prints, so the awaited call stands alone.
    client = openai.AsyncOpenAI()
    embedding_name = create_uuid()

    await client.embeddings.create(
        name=embedding_name,
        model="text-embedding-ada-002",
        input="Async embedding test",
        metadata={"async": True},
    )

    langfuse.flush()
    sleep(1)

    embedding = get_api().observations.get_many(name=embedding_name, type="EMBEDDING")

    assert len(embedding.data) != 0
    embedding_data = embedding.data[0]
    assert embedding_data.name == embedding_name
    assert embedding_data.input == "Async embedding test"
    assert embedding_data.type == "EMBEDDING"
    assert "text-embedding-ada-002" in embedding_data.model
    assert embedding_data.metadata["async"] is True
    assert embedding_data.usage.input is not None
    assert embedding_data.usage.total is not None