From a7ab5e83726fe6ede7466f94520804044415d3b7 Mon Sep 17 00:00:00 2001
From: ccurme
Date: Fri, 7 Mar 2025 18:25:05 -0500
Subject: [PATCH] community[patch]: ChatPerplexity: track usage metadata
 (#30175)

---
 .../chat_models/perplexity.py                 | 38 +++++++++++++++++++
 .../chat_models/test_perplexity.py            |  6 ---
 2 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/perplexity.py b/libs/community/langchain_community/chat_models/perplexity.py
index 0e5ea613447..0266244223d 100644
--- a/libs/community/langchain_community/chat_models/perplexity.py
+++ b/libs/community/langchain_community/chat_models/perplexity.py
@@ -38,6 +38,7 @@ from langchain_core.messages import (
     SystemMessageChunk,
     ToolMessageChunk,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
@@ -59,6 +60,17 @@ def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and is_basemodel_subclass(obj)
 
 
+def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
+    input_tokens = token_usage.get("prompt_tokens", 0)
+    output_tokens = token_usage.get("completion_tokens", 0)
+    total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)
+    return UsageMetadata(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+    )
+
+
 class ChatPerplexity(BaseChatModel):
     """`Perplexity AI` Chat models API.
 
@@ -238,9 +250,27 @@ class ChatPerplexity(BaseChatModel):
             messages=message_dicts, stream=True, **params
         )
         first_chunk = True
+        prev_total_usage: Optional[UsageMetadata] = None
         for chunk in stream_resp:
             if not isinstance(chunk, dict):
                 chunk = chunk.dict()
+            # Collect standard usage metadata (transform from aggregate to delta)
+            if total_usage := chunk.get("usage"):
+                lc_total_usage = _create_usage_metadata(total_usage)
+                if prev_total_usage:
+                    usage_metadata: Optional[UsageMetadata] = {
+                        "input_tokens": lc_total_usage["input_tokens"]
+                        - prev_total_usage["input_tokens"],
+                        "output_tokens": lc_total_usage["output_tokens"]
+                        - prev_total_usage["output_tokens"],
+                        "total_tokens": lc_total_usage["total_tokens"]
+                        - prev_total_usage["total_tokens"],
+                    }
+                else:
+                    usage_metadata = lc_total_usage
+                prev_total_usage = lc_total_usage
+            else:
+                usage_metadata = None
             if len(chunk["choices"]) == 0:
                 continue
             choice = chunk["choices"][0]
@@ -249,6 +279,8 @@
             chunk = self._convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
+            if isinstance(chunk, AIMessageChunk) and usage_metadata:
+                chunk.usage_metadata = usage_metadata
             if first_chunk:
                 chunk.additional_kwargs |= {"citations": citations}
                 first_chunk = False
@@ -278,9 +310,15 @@ class ChatPerplexity(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
         response = self.client.chat.completions.create(messages=message_dicts, **params)
+        if usage := getattr(response, "usage", None):
+            usage_metadata = _create_usage_metadata(usage.model_dump())
+        else:
+            usage_metadata = None
+
         message = AIMessage(
             content=response.choices[0].message.content,
             additional_kwargs={"citations": response.citations},
+            usage_metadata=usage_metadata,
         )
         return ChatResult(generations=[ChatGeneration(message=message)])
 
diff --git a/libs/community/tests/integration_tests/chat_models/test_perplexity.py b/libs/community/tests/integration_tests/chat_models/test_perplexity.py
index 63cdd4a59d3..5288fccc9b7 100644
--- a/libs/community/tests/integration_tests/chat_models/test_perplexity.py
+++ b/libs/community/tests/integration_tests/chat_models/test_perplexity.py
@@ -18,12 +18,6 @@ class TestPerplexityStandard(ChatModelIntegrationTests):
     def chat_model_params(self) -> dict:
         return {"model": "sonar"}
 
-    @property
-    def returns_usage_metadata(self) -> bool:
-        # TODO: add usage metadata and delete this property
-        # https://docs.perplexity.ai/api-reference/chat-completions#response-usage
-        return False
-
    @pytest.mark.xfail(reason="TODO: handle in integration.")
     def test_double_messages_conversation(self, model: BaseChatModel) -> None:
         super().test_double_messages_conversation(model)
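
A minimal usage sketch of where the new usage metadata surfaces, assuming a
valid PPLX_API_KEY is configured and using the "sonar" model from the
integration tests above:

    from langchain_community.chat_models import ChatPerplexity

    # Assumes PPLX_API_KEY is set in the environment.
    llm = ChatPerplexity(model="sonar")

    # Non-streaming: usage_metadata is built from response.usage.
    msg = llm.invoke("Hello")
    print(msg.usage_metadata)

    # Streaming: each AIMessageChunk carries a per-chunk delta, so merging
    # the chunks should recover the request totals.
    full = None
    for chunk in llm.stream("Hello"):
        full = chunk if full is None else full + chunk
    print(full.usage_metadata)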