Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-22 06:39:52 +00:00)

Commit a7ab5e8372 (parent 6c05d4b153)
community[patch]: ChatPerplexity: track usage metadata (#30175)
@@ -38,6 +38,7 @@ from langchain_core.messages import (
     SystemMessageChunk,
     ToolMessageChunk,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
@@ -59,6 +60,17 @@ def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and is_basemodel_subclass(obj)
 
 
+def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
+    input_tokens = token_usage.get("prompt_tokens", 0)
+    output_tokens = token_usage.get("completion_tokens", 0)
+    total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)
+    return UsageMetadata(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+    )
+
+
 class ChatPerplexity(BaseChatModel):
     """`Perplexity AI` Chat models API.
 
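Context for the hunk above: `_create_usage_metadata` maps the OpenAI-style usage payload that Perplexity returns (`prompt_tokens`, `completion_tokens`, `total_tokens`) onto LangChain's `UsageMetadata` TypedDict, falling back to `input + output` when `total_tokens` is absent. A minimal sketch of that mapping with an invented payload (the numbers are illustrative, not from a real response):

    from langchain_core.messages.ai import UsageMetadata

    # Invented usage payload shaped like Perplexity's OpenAI-compatible response.
    token_usage = {"prompt_tokens": 12, "completion_tokens": 34}

    input_tokens = token_usage.get("prompt_tokens", 0)
    output_tokens = token_usage.get("completion_tokens", 0)
    usage = UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        # Fall back to summing when the API omits total_tokens.
        total_tokens=token_usage.get("total_tokens", input_tokens + output_tokens),
    )
    assert usage == {"input_tokens": 12, "output_tokens": 34, "total_tokens": 46}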
@@ -238,9 +250,27 @@ class ChatPerplexity(BaseChatModel):
             messages=message_dicts, stream=True, **params
         )
         first_chunk = True
+        prev_total_usage: Optional[UsageMetadata] = None
         for chunk in stream_resp:
             if not isinstance(chunk, dict):
                 chunk = chunk.dict()
+            # Collect standard usage metadata (transform from aggregate to delta)
+            if total_usage := chunk.get("usage"):
+                lc_total_usage = _create_usage_metadata(total_usage)
+                if prev_total_usage:
+                    usage_metadata: Optional[UsageMetadata] = {
+                        "input_tokens": lc_total_usage["input_tokens"]
+                        - prev_total_usage["input_tokens"],
+                        "output_tokens": lc_total_usage["output_tokens"]
+                        - prev_total_usage["output_tokens"],
+                        "total_tokens": lc_total_usage["total_tokens"]
+                        - prev_total_usage["total_tokens"],
+                    }
+                else:
+                    usage_metadata = lc_total_usage
+                prev_total_usage = lc_total_usage
+            else:
+                usage_metadata = None
             if len(chunk["choices"]) == 0:
                 continue
             choice = chunk["choices"][0]
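Why the aggregate-to-delta transform: Perplexity reports usage cumulatively on every streamed chunk, while LangChain consumers expect each chunk's `usage_metadata` to carry only that chunk's share (chunk addition merges the counts by summing). The hunk therefore subtracts the previous cumulative snapshot from the current one. A standalone sketch of the same logic, using invented snapshots:

    from typing import Optional

    from langchain_core.messages.ai import UsageMetadata

    def to_delta(total: UsageMetadata, prev: Optional[UsageMetadata]) -> UsageMetadata:
        # First cumulative report: nothing to subtract, so it is its own delta.
        if prev is None:
            return total
        return UsageMetadata(
            input_tokens=total["input_tokens"] - prev["input_tokens"],
            output_tokens=total["output_tokens"] - prev["output_tokens"],
            total_tokens=total["total_tokens"] - prev["total_tokens"],
        )

    # Invented cumulative snapshots from three streamed chunks.
    snapshots = [
        UsageMetadata(input_tokens=10, output_tokens=1, total_tokens=11),
        UsageMetadata(input_tokens=10, output_tokens=5, total_tokens=15),
        UsageMetadata(input_tokens=10, output_tokens=9, total_tokens=19),
    ]
    deltas, prev = [], None
    for snap in snapshots:
        deltas.append(to_delta(snap, prev))
        prev = snap
    # Summing the per-chunk deltas recovers the final cumulative output count.
    assert sum(d["output_tokens"] for d in deltas) == 9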
@@ -249,6 +279,8 @@ class ChatPerplexity(BaseChatModel):
             chunk = self._convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
+            if isinstance(chunk, AIMessageChunk) and usage_metadata:
+                chunk.usage_metadata = usage_metadata
             if first_chunk:
                 chunk.additional_kwargs |= {"citations": citations}
                 first_chunk = False
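With the delta attached to each `AIMessageChunk`, ordinary chunk accumulation now yields correct totals, because `+` on message chunks merges `usage_metadata` additively. A hedged usage sketch (assumes a valid `PPLX_API_KEY` in the environment; the printed values are illustrative):

    from langchain_community.chat_models import ChatPerplexity

    chat = ChatPerplexity(model="sonar")
    full = None
    for chunk in chat.stream("Answer in one word: what color is the sky?"):
        full = chunk if full is None else full + chunk  # `+` merges usage_metadata
    print(full.usage_metadata)
    # e.g. {'input_tokens': 12, 'output_tokens': 3, 'total_tokens': 15}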
@@ -278,9 +310,15 @@ class ChatPerplexity(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
         response = self.client.chat.completions.create(messages=message_dicts, **params)
+        if usage := getattr(response, "usage", None):
+            usage_metadata = _create_usage_metadata(usage.model_dump())
+        else:
+            usage_metadata = None
+
         message = AIMessage(
             content=response.choices[0].message.content,
             additional_kwargs={"citations": response.citations},
+            usage_metadata=usage_metadata,
         )
         return ChatResult(generations=[ChatGeneration(message=message)])
 
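On the non-streaming path, the hunk above reads `response.usage` from the client (guarded with `getattr`, since it may be absent), converts it via `model_dump()`, and passes the result straight into the `AIMessage`. After this change a plain invoke should expose the counts; another hedged sketch (assumes `PPLX_API_KEY` is set, values illustrative):

    from langchain_community.chat_models import ChatPerplexity

    chat = ChatPerplexity(model="sonar")
    msg = chat.invoke("Answer in one word: what color is the sky?")
    print(msg.usage_metadata)
    # e.g. {'input_tokens': 12, 'output_tokens': 3, 'total_tokens': 15}
    print(msg.additional_kwargs["citations"])  # citations still ride along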
@@ -18,12 +18,6 @@ class TestPerplexityStandard(ChatModelIntegrationTests):
     def chat_model_params(self) -> dict:
         return {"model": "sonar"}
 
-    @property
-    def returns_usage_metadata(self) -> bool:
-        # TODO: add usage metadata and delete this property
-        # https://docs.perplexity.ai/api-reference/chat-completions#response-usage
-        return False
-
     @pytest.mark.xfail(reason="TODO: handle in integration.")
     def test_double_messages_conversation(self, model: BaseChatModel) -> None:
         super().test_double_messages_conversation(model)
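Deleting the `returns_usage_metadata` override lets the property fall back to the base `ChatModelIntegrationTests` default, re-enabling the suite's standard usage-metadata assertions for this model. Roughly, the re-enabled check behaves like the sketch below (the name and body are illustrative, not the exact test source):

    def test_usage_metadata(model) -> None:
        result = model.invoke("Hello")
        assert result.usage_metadata is not None
        assert result.usage_metadata["input_tokens"] > 0
        assert result.usage_metadata["output_tokens"] > 0
        assert result.usage_metadata["total_tokens"] > 0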