
Commit 063fc77

gustavocidornelas authored and whoseoyster committed
feat: parse usage after response streaming
1 parent 9c1e276 · commit 063fc77

File tree

1 file changed: +41 −3 lines changed

src/openlayer/lib/integrations/langchain_callback.py

Lines changed: 41 additions & 3 deletions

@@ -492,7 +492,16 @@ def _handle_llm_end(
             return
 
         output = self._extract_output(response)
-        token_info = self._extract_token_info(response)
+
+        # Only extract token info if it hasn't been set during streaming
+        step = self.steps[run_id]
+        token_info = {}
+        if not (
+            hasattr(step, "prompt_tokens")
+            and step.prompt_tokens is not None
+            and step.prompt_tokens > 0
+        ):
+            token_info = self._extract_token_info(response)
 
         self._end_step(
             run_id=run_id,
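The guard in this hunk keeps _handle_llm_end from overwriting token counts that the streaming path (added in the next hunk) has already logged onto the step. A minimal sketch of the same check in isolation; the Step dataclass here is a stand-in for Openlayer's step object, not the real class:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Step:
    """Stand-in for an Openlayer step that may carry streamed token counts."""
    prompt_tokens: Optional[int] = None


def needs_fallback_extraction(step: Step) -> bool:
    # Same condition as the diff: re-extract only when streaming never
    # recorded a positive prompt_tokens value.
    return not (
        hasattr(step, "prompt_tokens")
        and step.prompt_tokens is not None
        and step.prompt_tokens > 0
    )


assert needs_fallback_extraction(Step())                      # nothing streamed
assert needs_fallback_extraction(Step(prompt_tokens=0))       # zero treated as unset
assert not needs_fallback_extraction(Step(prompt_tokens=12))  # streaming already set it

Note that a step with prompt_tokens == 0 still falls back to _extract_token_info, so a provider that streams zeroed usage does not suppress the end-of-run extraction.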
@@ -763,6 +772,35 @@ def _handle_retriever_error(
         """Common logic for retriever error."""
         self._end_step(run_id=run_id, parent_run_id=parent_run_id, error=str(error))
 
+    def _handle_llm_new_token(self, token: str, **kwargs: Any) -> Any:
+        """Common logic for LLM new token."""
+        # Safely check for chunk and usage_metadata
+        chunk = kwargs.get("chunk")
+        if (
+            chunk
+            and hasattr(chunk, "message")
+            and hasattr(chunk.message, "usage_metadata")
+        ):
+            usage = chunk.message.usage_metadata
+
+            # Only proceed if usage is not None
+            if usage:
+                # Extract run_id from kwargs (should be provided by LangChain)
+                run_id = kwargs.get("run_id")
+                if run_id and run_id in self.steps:
+                    # Convert usage to the expected format like _extract_token_info does
+                    token_info = {
+                        "prompt_tokens": usage.get("input_tokens", 0),
+                        "completion_tokens": usage.get("output_tokens", 0),
+                        "tokens": usage.get("total_tokens", 0),
+                    }
+
+                    # Update the step with token usage information
+                    step = self.steps[run_id]
+                    if isinstance(step, steps.ChatCompletionStep):
+                        step.log(**token_info)
+        return
+
 
 class OpenlayerHandler(OpenlayerHandlerMixin, BaseCallbackHandlerClass):  # type: ignore[misc]
     """LangChain callback handler that logs to Openlayer."""
@@ -848,7 +886,7 @@ def on_llm_error(
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
        """Run on new LLM token. Only available when streaming is enabled."""
-        pass
+        return self._handle_llm_new_token(token, **kwargs)
 
     def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
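With on_llm_new_token now delegating to _handle_llm_new_token, each streamed chunk is inspected for usage metadata. In langchain_core, streamed chat output arrives as ChatGenerationChunk objects whose .message is an AIMessageChunk, and providers that report streaming usage attach a usage_metadata mapping with input_tokens/output_tokens/total_tokens, typically only on the final chunk. A sketch of what the handler sees; the literal token counts are made up for illustration:

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

# Final chunk of a stream: empty content, usage totals attached.
final_chunk = ChatGenerationChunk(
    message=AIMessageChunk(
        content="",
        usage_metadata={"input_tokens": 11, "output_tokens": 42, "total_tokens": 53},
    )
)

# LangChain invokes the callback roughly as:
#   handler.on_llm_new_token("", chunk=final_chunk, run_id=run_id)
# and the mixin maps usage_metadata to Openlayer's field names:
usage = final_chunk.message.usage_metadata
token_info = {
    "prompt_tokens": usage.get("input_tokens", 0),
    "completion_tokens": usage.get("output_tokens", 0),
    "tokens": usage.get("total_tokens", 0),
}
print(token_info)  # {'prompt_tokens': 11, 'completion_tokens': 42, 'tokens': 53}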
@@ -1137,7 +1175,7 @@ async def on_llm_error(
         return self._handle_llm_error(error, **kwargs)
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
-        pass
+        return self._handle_llm_new_token(token, **kwargs)
 
     async def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
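Taken together, the sync and async handlers now capture token usage as it streams instead of waiting for _handle_llm_end. A hypothetical end-to-end sketch: the model name is arbitrary, stream_usage=True is langchain-openai's flag for attaching usage to the final streamed chunk, and OpenlayerHandler's constructor arguments may differ in your version:

from langchain_openai import ChatOpenAI
from openlayer.lib.integrations.langchain_callback import OpenlayerHandler

handler = OpenlayerHandler()
llm = ChatOpenAI(
    model="gpt-4o-mini",
    streaming=True,
    stream_usage=True,  # ask the provider to report usage on the stream
    callbacks=[handler],
)

# Every chunk fires on_llm_new_token; the final chunk carries usage_metadata,
# which _handle_llm_new_token logs onto the active ChatCompletionStep.
for chunk in llm.stream("Write a haiku about token accounting."):
    print(chunk.content, end="", flush=True)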
