community[patch]: ChatPerplexity: track usage metadata (#30175)

ccurme 2025-03-07 18:25:05 -05:00 committed by GitHub
parent 6c05d4b153
commit a7ab5e8372
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 6 deletions

@@ -38,6 +38,7 @@ from langchain_core.messages import (
    SystemMessageChunk,
    ToolMessageChunk,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
@@ -59,6 +60,17 @@ def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and is_basemodel_subclass(obj)


def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
    input_tokens = token_usage.get("prompt_tokens", 0)
    output_tokens = token_usage.get("completion_tokens", 0)
    total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
    )


class ChatPerplexity(BaseChatModel):
    """`Perplexity AI` Chat models API.
@@ -238,9 +250,27 @@ class ChatPerplexity(BaseChatModel):
            messages=message_dicts, stream=True, **params
        )
        first_chunk = True
        prev_total_usage: Optional[UsageMetadata] = None
        for chunk in stream_resp:
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            # Collect standard usage metadata (transform from aggregate to delta)
            if total_usage := chunk.get("usage"):
                lc_total_usage = _create_usage_metadata(total_usage)
                if prev_total_usage:
                    usage_metadata: Optional[UsageMetadata] = {
                        "input_tokens": lc_total_usage["input_tokens"]
                        - prev_total_usage["input_tokens"],
                        "output_tokens": lc_total_usage["output_tokens"]
                        - prev_total_usage["output_tokens"],
                        "total_tokens": lc_total_usage["total_tokens"]
                        - prev_total_usage["total_tokens"],
                    }
                else:
                    usage_metadata = lc_total_usage
                prev_total_usage = lc_total_usage
            else:
                usage_metadata = None
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
@@ -249,6 +279,8 @@ class ChatPerplexity(BaseChatModel):
            chunk = self._convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            if isinstance(chunk, AIMessageChunk) and usage_metadata:
                chunk.usage_metadata = usage_metadata
            if first_chunk:
                chunk.additional_kwargs |= {"citations": citations}
                first_chunk = False
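
The streaming hunks above turn the cumulative usage the API reports on each chunk into per-chunk deltas, so summing usage_metadata over a stream recovers the true totals. A worked example of that subtraction follows; to_delta is a hypothetical stand-in for the logic inlined in _stream, and the token counts are made up.

    from langchain_core.messages.ai import UsageMetadata

    def to_delta(cumulative: UsageMetadata, prev: UsageMetadata) -> UsageMetadata:
        # Same subtraction as in _stream above, pulled out for illustration.
        return UsageMetadata(
            input_tokens=cumulative["input_tokens"] - prev["input_tokens"],
            output_tokens=cumulative["output_tokens"] - prev["output_tokens"],
            total_tokens=cumulative["total_tokens"] - prev["total_tokens"],
        )

    first = UsageMetadata(input_tokens=5, output_tokens=3, total_tokens=8)    # cumulative after chunk 1
    second = UsageMetadata(input_tokens=5, output_tokens=7, total_tokens=12)  # cumulative after chunk 2

    print(to_delta(second, first))
    # {'input_tokens': 0, 'output_tokens': 4, 'total_tokens': 4}
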
@@ -278,9 +310,15 @@ class ChatPerplexity(BaseChatModel):
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.client.chat.completions.create(messages=message_dicts, **params)
        if usage := getattr(response, "usage", None):
            usage_metadata = _create_usage_metadata(usage.model_dump())
        else:
            usage_metadata = None
        message = AIMessage(
            content=response.choices[0].message.content,
            additional_kwargs={"citations": response.citations},
            usage_metadata=usage_metadata,
        )
        return ChatResult(generations=[ChatGeneration(message=message)])
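
Taken together, the two code paths expose token counts on the returned messages. A minimal sketch of how a caller would see this, assuming a valid PPLX_API_KEY in the environment and the "sonar" model used by the integration tests below:

    from langchain_community.chat_models import ChatPerplexity

    llm = ChatPerplexity(model="sonar", temperature=0)

    # Non-streaming: usage_metadata is attached to the returned AIMessage.
    msg = llm.invoke("What is LangChain?")
    print(msg.usage_metadata)  # {'input_tokens': ..., 'output_tokens': ..., 'total_tokens': ...}

    # Streaming: chunks that report usage carry a delta; adding chunks
    # together accumulates the deltas back into totals.
    aggregate = None
    for chunk in llm.stream("What is LangChain?"):
        aggregate = chunk if aggregate is None else aggregate + chunk
    print(aggregate.usage_metadata)
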

@@ -18,12 +18,6 @@ class TestPerplexityStandard(ChatModelIntegrationTests):
    def chat_model_params(self) -> dict:
        return {"model": "sonar"}

    @property
    def returns_usage_metadata(self) -> bool:
        # TODO: add usage metadata and delete this property
        # https://docs.perplexity.ai/api-reference/chat-completions#response-usage
        return False

    @pytest.mark.xfail(reason="TODO: handle in integration.")
    def test_double_messages_conversation(self, model: BaseChatModel) -> None:
        super().test_double_messages_conversation(model)
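
With the returns_usage_metadata override deleted, the standard suite exercises usage metadata for ChatPerplexity by default. A rough, hypothetical smoke test in the same spirit (not part of this diff), assuming the same "sonar" model and a configured API key:

    from langchain_community.chat_models import ChatPerplexity

    def test_usage_metadata_smoke() -> None:
        # Hypothetical check, loosely mirroring what the standard suite asserts
        # now that returns_usage_metadata falls back to its default.
        model = ChatPerplexity(model="sonar")
        result = model.invoke("Hello")
        assert result.usage_metadata is not None
        assert result.usage_metadata["input_tokens"] > 0
        assert result.usage_metadata["output_tokens"] > 0
        assert result.usage_metadata["total_tokens"] > 0
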