mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-07 05:52:15 +00:00
openai[patch]: add usage metadata details (#27080)
This commit is contained in:
@@ -63,7 +63,11 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from langchain_core.messages.ai import (
|
||||
InputTokenDetails,
|
||||
OutputTokenDetails,
|
||||
UsageMetadata,
|
||||
)
|
||||
from langchain_core.messages.tool import tool_call_chunk
|
||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
@@ -286,16 +290,10 @@ def _convert_chunk_to_generation_chunk(
|
||||
) -> Optional[ChatGenerationChunk]:
|
||||
token_usage = chunk.get("usage")
|
||||
choices = chunk.get("choices", [])
|
||||
usage_metadata: Optional[UsageMetadata] = (
|
||||
UsageMetadata(
|
||||
input_tokens=token_usage.get("prompt_tokens", 0),
|
||||
output_tokens=token_usage.get("completion_tokens", 0),
|
||||
total_tokens=token_usage.get("total_tokens", 0),
|
||||
)
|
||||
if token_usage
|
||||
else None
|
||||
)
|
||||
|
||||
usage_metadata: Optional[UsageMetadata] = (
|
||||
_create_usage_metadata(token_usage) if token_usage else None
|
||||
)
|
||||
if len(choices) == 0:
|
||||
# logprobs is implicitly None
|
||||
generation_chunk = ChatGenerationChunk(
|
||||
@@ -721,15 +719,11 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if response_dict.get("error"):
|
||||
raise ValueError(response_dict.get("error"))
|
||||
|
||||
token_usage = response_dict.get("usage", {})
|
||||
token_usage = response_dict.get("usage")
|
||||
for res in response_dict["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
if token_usage and isinstance(message, AIMessage):
|
||||
message.usage_metadata = {
|
||||
"input_tokens": token_usage.get("prompt_tokens", 0),
|
||||
"output_tokens": token_usage.get("completion_tokens", 0),
|
||||
"total_tokens": token_usage.get("total_tokens", 0),
|
||||
}
|
||||
message.usage_metadata = _create_usage_metadata(token_usage)
|
||||
generation_info = generation_info or {}
|
||||
generation_info["finish_reason"] = (
|
||||
res.get("finish_reason")
|
||||
@@ -2160,3 +2154,34 @@ class OpenAIRefusalError(Exception):
|
||||
|
||||
.. versionadded:: 0.1.21
|
||||
"""
|
||||
|
||||
|
||||
def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
    """Build a ``UsageMetadata`` from an OpenAI-style token usage dict.

    Reads ``prompt_tokens``, ``completion_tokens`` and ``total_tokens`` from
    the top level, plus the optional nested ``prompt_tokens_details`` and
    ``completion_tokens_details`` dicts. Keys that are missing *or* present
    with a ``None`` value are treated the same way, so a payload such as
    ``{"prompt_tokens_details": None}`` does not raise.

    Args:
        oai_token_usage: The ``usage`` mapping taken from a response or chunk.

    Returns:
        A ``UsageMetadata`` with input/output/total token counts and with
        ``input_token_details`` / ``output_token_details`` containing only
        the detail fields that were actually provided (non-``None``).
    """
    # ``dict.get(key, default)`` only falls back when the key is absent; use
    # ``or`` so an explicit ``None`` value is normalised as well.
    input_tokens = oai_token_usage.get("prompt_tokens") or 0
    output_tokens = oai_token_usage.get("completion_tokens") or 0

    total_tokens = oai_token_usage.get("total_tokens")
    if total_tokens is None:
        # No total reported: derive it from the two counts we do have.
        total_tokens = input_tokens + output_tokens

    # Look each nested details dict up once and guard against ``None``.
    prompt_details: dict = oai_token_usage.get("prompt_tokens_details") or {}
    completion_details: dict = (
        oai_token_usage.get("completion_tokens_details") or {}
    )

    input_token_details: dict = {
        "audio": prompt_details.get("audio_tokens"),
        "cache_read": prompt_details.get("cached_tokens"),
    }
    output_token_details: dict = {
        "audio": completion_details.get("audio_tokens"),
        "reasoning": completion_details.get("reasoning_tokens"),
    }

    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        # Drop fields that were not reported rather than storing ``None``.
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
        output_token_details=OutputTokenDetails(
            **{k: v for k, v in output_token_details.items() if v is not None}
        ),
    )
|
||||
|
Reference in New Issue
Block a user