mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
core, partners: add token usage attribute to AIMessage (#21944)
```python class UsageMetadata(TypedDict): """Usage metadata for a message, such as token counts. Attributes: input_tokens: (int) count of input (or prompt) tokens output_tokens: (int) count of output (or completion) tokens total_tokens: (int) total token count """ input_tokens: int output_tokens: int total_tokens: int ``` ```python class AIMessage(BaseMessage): ... usage_metadata: Optional[UsageMetadata] = None """If provided, token usage information associated with the message.""" ... ```
This commit is contained in:
parent
3d26807b92
commit
fbfed65fb1
@ -1,4 +1,6 @@
|
|||||||
from typing import Any, Dict, List, Literal, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain_core.messages.base import (
|
from langchain_core.messages.base import (
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@ -19,6 +21,20 @@ from langchain_core.utils.json import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UsageMetadata(TypedDict):
|
||||||
|
"""Usage metadata for a message, such as token counts.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
input_tokens: (int) count of input (or prompt) tokens
|
||||||
|
output_tokens: (int) count of output (or completion) tokens
|
||||||
|
total_tokens: (int) total token count
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_tokens: int
|
||||||
|
output_tokens: int
|
||||||
|
total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
class AIMessage(BaseMessage):
|
class AIMessage(BaseMessage):
|
||||||
"""Message from an AI."""
|
"""Message from an AI."""
|
||||||
|
|
||||||
@ -31,6 +47,11 @@ class AIMessage(BaseMessage):
|
|||||||
"""If provided, tool calls associated with the message."""
|
"""If provided, tool calls associated with the message."""
|
||||||
invalid_tool_calls: List[InvalidToolCall] = []
|
invalid_tool_calls: List[InvalidToolCall] = []
|
||||||
"""If provided, tool calls with parsing errors associated with the message."""
|
"""If provided, tool calls with parsing errors associated with the message."""
|
||||||
|
usage_metadata: Optional[UsageMetadata] = None
|
||||||
|
"""If provided, usage metadata for a message, such as token counts.
|
||||||
|
|
||||||
|
This is a standard representation of token usage that is consistent across models.
|
||||||
|
"""
|
||||||
|
|
||||||
type: Literal["ai"] = "ai"
|
type: Literal["ai"] = "ai"
|
||||||
|
|
||||||
@ -198,12 +219,29 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
|||||||
else:
|
else:
|
||||||
tool_call_chunks = []
|
tool_call_chunks = []
|
||||||
|
|
||||||
|
# Token usage
|
||||||
|
if self.usage_metadata or other.usage_metadata:
|
||||||
|
left: UsageMetadata = self.usage_metadata or UsageMetadata(
|
||||||
|
input_tokens=0, output_tokens=0, total_tokens=0
|
||||||
|
)
|
||||||
|
right: UsageMetadata = other.usage_metadata or UsageMetadata(
|
||||||
|
input_tokens=0, output_tokens=0, total_tokens=0
|
||||||
|
)
|
||||||
|
usage_metadata: Optional[UsageMetadata] = {
|
||||||
|
"input_tokens": left["input_tokens"] + right["input_tokens"],
|
||||||
|
"output_tokens": left["output_tokens"] + right["output_tokens"],
|
||||||
|
"total_tokens": left["total_tokens"] + right["total_tokens"],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
usage_metadata = None
|
||||||
|
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
example=self.example,
|
example=self.example,
|
||||||
content=content,
|
content=content,
|
||||||
additional_kwargs=additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
tool_call_chunks=tool_call_chunks,
|
tool_call_chunks=tool_call_chunks,
|
||||||
response_metadata=response_metadata,
|
response_metadata=response_metadata,
|
||||||
|
usage_metadata=usage_metadata,
|
||||||
id=self.id,
|
id=self.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -5286,6 +5286,9 @@
|
|||||||
'title': 'Type',
|
'title': 'Type',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'usage_metadata': dict({
|
||||||
|
'$ref': '#/definitions/UsageMetadata',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'content',
|
'content',
|
||||||
@ -5707,6 +5710,29 @@
|
|||||||
'title': 'ToolMessage',
|
'title': 'ToolMessage',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'UsageMetadata': dict({
|
||||||
|
'properties': dict({
|
||||||
|
'input_tokens': dict({
|
||||||
|
'title': 'Input Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'output_tokens': dict({
|
||||||
|
'title': 'Output Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'total_tokens': dict({
|
||||||
|
'title': 'Total Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'required': list([
|
||||||
|
'input_tokens',
|
||||||
|
'output_tokens',
|
||||||
|
'total_tokens',
|
||||||
|
]),
|
||||||
|
'title': 'UsageMetadata',
|
||||||
|
'type': 'object',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'title': 'FakeListLLMInput',
|
'title': 'FakeListLLMInput',
|
||||||
})
|
})
|
||||||
@ -5821,6 +5847,9 @@
|
|||||||
'title': 'Type',
|
'title': 'Type',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'usage_metadata': dict({
|
||||||
|
'$ref': '#/definitions/UsageMetadata',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'content',
|
'content',
|
||||||
@ -6242,6 +6271,29 @@
|
|||||||
'title': 'ToolMessage',
|
'title': 'ToolMessage',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'UsageMetadata': dict({
|
||||||
|
'properties': dict({
|
||||||
|
'input_tokens': dict({
|
||||||
|
'title': 'Input Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'output_tokens': dict({
|
||||||
|
'title': 'Output Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'total_tokens': dict({
|
||||||
|
'title': 'Total Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'required': list([
|
||||||
|
'input_tokens',
|
||||||
|
'output_tokens',
|
||||||
|
'total_tokens',
|
||||||
|
]),
|
||||||
|
'title': 'UsageMetadata',
|
||||||
|
'type': 'object',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'title': 'FakeListChatModelInput',
|
'title': 'FakeListChatModelInput',
|
||||||
})
|
})
|
||||||
@ -6340,6 +6392,9 @@
|
|||||||
'title': 'Type',
|
'title': 'Type',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'usage_metadata': dict({
|
||||||
|
'$ref': '#/definitions/UsageMetadata',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'content',
|
'content',
|
||||||
@ -6692,6 +6747,29 @@
|
|||||||
'title': 'ToolMessage',
|
'title': 'ToolMessage',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'UsageMetadata': dict({
|
||||||
|
'properties': dict({
|
||||||
|
'input_tokens': dict({
|
||||||
|
'title': 'Input Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'output_tokens': dict({
|
||||||
|
'title': 'Output Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'total_tokens': dict({
|
||||||
|
'title': 'Total Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'required': list([
|
||||||
|
'input_tokens',
|
||||||
|
'output_tokens',
|
||||||
|
'total_tokens',
|
||||||
|
]),
|
||||||
|
'title': 'UsageMetadata',
|
||||||
|
'type': 'object',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'title': 'FakeListChatModelOutput',
|
'title': 'FakeListChatModelOutput',
|
||||||
})
|
})
|
||||||
@ -6778,6 +6856,9 @@
|
|||||||
'title': 'Type',
|
'title': 'Type',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'usage_metadata': dict({
|
||||||
|
'$ref': '#/definitions/UsageMetadata',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'content',
|
'content',
|
||||||
@ -7199,6 +7280,29 @@
|
|||||||
'title': 'ToolMessage',
|
'title': 'ToolMessage',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'UsageMetadata': dict({
|
||||||
|
'properties': dict({
|
||||||
|
'input_tokens': dict({
|
||||||
|
'title': 'Input Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'output_tokens': dict({
|
||||||
|
'title': 'Output Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'total_tokens': dict({
|
||||||
|
'title': 'Total Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'required': list([
|
||||||
|
'input_tokens',
|
||||||
|
'output_tokens',
|
||||||
|
'total_tokens',
|
||||||
|
]),
|
||||||
|
'title': 'UsageMetadata',
|
||||||
|
'type': 'object',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'title': 'ChatPromptTemplateOutput',
|
'title': 'ChatPromptTemplateOutput',
|
||||||
})
|
})
|
||||||
@ -7285,6 +7389,9 @@
|
|||||||
'title': 'Type',
|
'title': 'Type',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'usage_metadata': dict({
|
||||||
|
'$ref': '#/definitions/UsageMetadata',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'content',
|
'content',
|
||||||
@ -7706,6 +7813,29 @@
|
|||||||
'title': 'ToolMessage',
|
'title': 'ToolMessage',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'UsageMetadata': dict({
|
||||||
|
'properties': dict({
|
||||||
|
'input_tokens': dict({
|
||||||
|
'title': 'Input Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'output_tokens': dict({
|
||||||
|
'title': 'Output Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'total_tokens': dict({
|
||||||
|
'title': 'Total Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'required': list([
|
||||||
|
'input_tokens',
|
||||||
|
'output_tokens',
|
||||||
|
'total_tokens',
|
||||||
|
]),
|
||||||
|
'title': 'UsageMetadata',
|
||||||
|
'type': 'object',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'title': 'PromptTemplateOutput',
|
'title': 'PromptTemplateOutput',
|
||||||
})
|
})
|
||||||
@ -7784,6 +7914,9 @@
|
|||||||
'title': 'Type',
|
'title': 'Type',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'usage_metadata': dict({
|
||||||
|
'$ref': '#/definitions/UsageMetadata',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'content',
|
'content',
|
||||||
@ -8216,6 +8349,29 @@
|
|||||||
'title': 'ToolMessage',
|
'title': 'ToolMessage',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'UsageMetadata': dict({
|
||||||
|
'properties': dict({
|
||||||
|
'input_tokens': dict({
|
||||||
|
'title': 'Input Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'output_tokens': dict({
|
||||||
|
'title': 'Output Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'total_tokens': dict({
|
||||||
|
'title': 'Total Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'required': list([
|
||||||
|
'input_tokens',
|
||||||
|
'output_tokens',
|
||||||
|
'total_tokens',
|
||||||
|
]),
|
||||||
|
'title': 'UsageMetadata',
|
||||||
|
'type': 'object',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'items': dict({
|
'items': dict({
|
||||||
'$ref': '#/definitions/PromptTemplateOutput',
|
'$ref': '#/definitions/PromptTemplateOutput',
|
||||||
@ -8321,6 +8477,9 @@
|
|||||||
'title': 'Type',
|
'title': 'Type',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'usage_metadata': dict({
|
||||||
|
'$ref': '#/definitions/UsageMetadata',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'content',
|
'content',
|
||||||
@ -8673,6 +8832,29 @@
|
|||||||
'title': 'ToolMessage',
|
'title': 'ToolMessage',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'UsageMetadata': dict({
|
||||||
|
'properties': dict({
|
||||||
|
'input_tokens': dict({
|
||||||
|
'title': 'Input Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'output_tokens': dict({
|
||||||
|
'title': 'Output Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
'total_tokens': dict({
|
||||||
|
'title': 'Total Tokens',
|
||||||
|
'type': 'integer',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'required': list([
|
||||||
|
'input_tokens',
|
||||||
|
'output_tokens',
|
||||||
|
'total_tokens',
|
||||||
|
]),
|
||||||
|
'title': 'UsageMetadata',
|
||||||
|
'type': 'object',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'title': 'CommaSeparatedListOutputParserInput',
|
'title': 'CommaSeparatedListOutputParserInput',
|
||||||
})
|
})
|
||||||
|
@ -227,6 +227,29 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
|||||||
},
|
},
|
||||||
"required": ["name", "args", "id", "error"],
|
"required": ["name", "args", "id", "error"],
|
||||||
},
|
},
|
||||||
|
"UsageMetadata": {
|
||||||
|
"title": "UsageMetadata",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"input_tokens": {
|
||||||
|
"title": "Input Tokens",
|
||||||
|
"type": "integer",
|
||||||
|
},
|
||||||
|
"output_tokens": {
|
||||||
|
"title": "Output Tokens",
|
||||||
|
"type": "integer",
|
||||||
|
},
|
||||||
|
"total_tokens": {
|
||||||
|
"title": "Total Tokens",
|
||||||
|
"type": "integer",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"input_tokens",
|
||||||
|
"output_tokens",
|
||||||
|
"total_tokens",
|
||||||
|
],
|
||||||
|
},
|
||||||
"AIMessage": {
|
"AIMessage": {
|
||||||
"title": "AIMessage",
|
"title": "AIMessage",
|
||||||
"description": "Message from an AI.",
|
"description": "Message from an AI.",
|
||||||
@ -280,6 +303,9 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
|||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {"$ref": "#/definitions/InvalidToolCall"},
|
"items": {"$ref": "#/definitions/InvalidToolCall"},
|
||||||
},
|
},
|
||||||
|
"usage_metadata": {
|
||||||
|
"$ref": "#/definitions/UsageMetadata"
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": ["content"],
|
"required": ["content"],
|
||||||
},
|
},
|
||||||
|
@ -383,6 +383,16 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
},
|
},
|
||||||
"required": ["name", "args", "id", "error"],
|
"required": ["name", "args", "id", "error"],
|
||||||
},
|
},
|
||||||
|
"UsageMetadata": {
|
||||||
|
"title": "UsageMetadata",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"input_tokens": {"title": "Input Tokens", "type": "integer"},
|
||||||
|
"output_tokens": {"title": "Output Tokens", "type": "integer"},
|
||||||
|
"total_tokens": {"title": "Total Tokens", "type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["input_tokens", "output_tokens", "total_tokens"],
|
||||||
|
},
|
||||||
"AIMessage": {
|
"AIMessage": {
|
||||||
"title": "AIMessage",
|
"title": "AIMessage",
|
||||||
"description": "Message from an AI.",
|
"description": "Message from an AI.",
|
||||||
@ -433,6 +443,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {"$ref": "#/definitions/InvalidToolCall"},
|
"items": {"$ref": "#/definitions/InvalidToolCall"},
|
||||||
},
|
},
|
||||||
|
"usage_metadata": {"$ref": "#/definitions/UsageMetadata"},
|
||||||
},
|
},
|
||||||
"required": ["content"],
|
"required": ["content"],
|
||||||
},
|
},
|
||||||
|
@ -120,6 +120,22 @@ def test_message_chunks() -> None:
|
|||||||
assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
|
assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
|
||||||
assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk
|
assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk
|
||||||
|
|
||||||
|
# Test token usage
|
||||||
|
left = AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
|
||||||
|
)
|
||||||
|
right = AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
usage_metadata={"input_tokens": 4, "output_tokens": 5, "total_tokens": 9},
|
||||||
|
)
|
||||||
|
assert left + right == AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
usage_metadata={"input_tokens": 5, "output_tokens": 7, "total_tokens": 12},
|
||||||
|
)
|
||||||
|
assert AIMessageChunk(content="") + left == left
|
||||||
|
assert right + AIMessageChunk(content="") == right
|
||||||
|
|
||||||
|
|
||||||
def test_chat_message_chunks() -> None:
|
def test_chat_message_chunks() -> None:
|
||||||
assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(
|
assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(
|
||||||
|
@ -41,6 +41,17 @@ class TestAI21J2(ChatModelIntegrationTests):
|
|||||||
chat_model_params,
|
chat_model_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="Not implemented.")
|
||||||
|
def test_usage_metadata(
|
||||||
|
self,
|
||||||
|
chat_model_class: Type[BaseChatModel],
|
||||||
|
chat_model_params: dict,
|
||||||
|
) -> None:
|
||||||
|
super().test_usage_metadata(
|
||||||
|
chat_model_class,
|
||||||
|
chat_model_params,
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def chat_model_params(self) -> dict:
|
def chat_model_params(self) -> dict:
|
||||||
return {
|
return {
|
||||||
@ -79,6 +90,17 @@ class TestAI21Jamba(ChatModelIntegrationTests):
|
|||||||
chat_model_params,
|
chat_model_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="Not implemented.")
|
||||||
|
def test_usage_metadata(
|
||||||
|
self,
|
||||||
|
chat_model_class: Type[BaseChatModel],
|
||||||
|
chat_model_params: dict,
|
||||||
|
) -> None:
|
||||||
|
super().test_usage_metadata(
|
||||||
|
chat_model_class,
|
||||||
|
chat_model_params,
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def chat_model_params(self) -> dict:
|
def chat_model_params(self) -> dict:
|
||||||
return {
|
return {
|
||||||
|
@ -503,6 +503,12 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg = AIMessage(content=content)
|
msg = AIMessage(content=content)
|
||||||
|
# Collect token usage
|
||||||
|
msg.usage_metadata = {
|
||||||
|
"input_tokens": data.usage.input_tokens,
|
||||||
|
"output_tokens": data.usage.output_tokens,
|
||||||
|
"total_tokens": data.usage.input_tokens + data.usage.output_tokens,
|
||||||
|
}
|
||||||
return ChatResult(
|
return ChatResult(
|
||||||
generations=[ChatGeneration(message=msg)],
|
generations=[ChatGeneration(message=msg)],
|
||||||
llm_output=llm_output,
|
llm_output=llm_output,
|
||||||
|
@ -89,7 +89,16 @@ def test__format_output() -> None:
|
|||||||
)
|
)
|
||||||
expected = ChatResult(
|
expected = ChatResult(
|
||||||
generations=[
|
generations=[
|
||||||
ChatGeneration(message=AIMessage("bar")),
|
ChatGeneration(
|
||||||
|
message=AIMessage(
|
||||||
|
"bar",
|
||||||
|
usage_metadata={
|
||||||
|
"input_tokens": 2,
|
||||||
|
"output_tokens": 1,
|
||||||
|
"total_tokens": 3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
),
|
||||||
],
|
],
|
||||||
llm_output={
|
llm_output={
|
||||||
"id": "foo",
|
"id": "foo",
|
||||||
|
@ -21,6 +21,17 @@ class TestFireworksStandard(ChatModelIntegrationTests):
|
|||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="Not implemented.")
|
||||||
|
def test_usage_metadata(
|
||||||
|
self,
|
||||||
|
chat_model_class: Type[BaseChatModel],
|
||||||
|
chat_model_params: dict,
|
||||||
|
) -> None:
|
||||||
|
super().test_usage_metadata(
|
||||||
|
chat_model_class,
|
||||||
|
chat_model_params,
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.xfail(reason="Not yet implemented.")
|
@pytest.mark.xfail(reason="Not yet implemented.")
|
||||||
def test_tool_message_histories_list_content(
|
def test_tool_message_histories_list_content(
|
||||||
self,
|
self,
|
||||||
|
@ -14,6 +14,17 @@ class TestMistralStandard(ChatModelIntegrationTests):
|
|||||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||||
return ChatGroq
|
return ChatGroq
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="Not implemented.")
|
||||||
|
def test_usage_metadata(
|
||||||
|
self,
|
||||||
|
chat_model_class: Type[BaseChatModel],
|
||||||
|
chat_model_params: dict,
|
||||||
|
) -> None:
|
||||||
|
super().test_usage_metadata(
|
||||||
|
chat_model_class,
|
||||||
|
chat_model_params,
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.xfail(reason="Not yet implemented.")
|
@pytest.mark.xfail(reason="Not yet implemented.")
|
||||||
def test_tool_message_histories_list_content(
|
def test_tool_message_histories_list_content(
|
||||||
self,
|
self,
|
||||||
|
@ -20,3 +20,14 @@ class TestMistralStandard(ChatModelIntegrationTests):
|
|||||||
"model": "mistral-large-latest",
|
"model": "mistral-large-latest",
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="Not implemented.")
|
||||||
|
def test_usage_metadata(
|
||||||
|
self,
|
||||||
|
chat_model_class: Type[BaseChatModel],
|
||||||
|
chat_model_params: dict,
|
||||||
|
) -> None:
|
||||||
|
super().test_usage_metadata(
|
||||||
|
chat_model_class,
|
||||||
|
chat_model_params,
|
||||||
|
)
|
||||||
|
@ -548,8 +548,15 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
if response.get("error"):
|
if response.get("error"):
|
||||||
raise ValueError(response.get("error"))
|
raise ValueError(response.get("error"))
|
||||||
|
|
||||||
|
token_usage = response.get("usage", {})
|
||||||
for res in response["choices"]:
|
for res in response["choices"]:
|
||||||
message = _convert_dict_to_message(res["message"])
|
message = _convert_dict_to_message(res["message"])
|
||||||
|
if token_usage and isinstance(message, AIMessage):
|
||||||
|
message.usage_metadata = {
|
||||||
|
"input_tokens": token_usage.get("prompt_tokens", 0),
|
||||||
|
"output_tokens": token_usage.get("completion_tokens", 0),
|
||||||
|
"total_tokens": token_usage.get("total_tokens", 0),
|
||||||
|
}
|
||||||
generation_info = dict(finish_reason=res.get("finish_reason"))
|
generation_info = dict(finish_reason=res.get("finish_reason"))
|
||||||
if "logprobs" in res:
|
if "logprobs" in res:
|
||||||
generation_info["logprobs"] = res["logprobs"]
|
generation_info["logprobs"] = res["logprobs"]
|
||||||
@ -558,7 +565,6 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
generation_info=generation_info,
|
generation_info=generation_info,
|
||||||
)
|
)
|
||||||
generations.append(gen)
|
generations.append(gen)
|
||||||
token_usage = response.get("usage", {})
|
|
||||||
llm_output = {
|
llm_output = {
|
||||||
"token_usage": token_usage,
|
"token_usage": token_usage,
|
||||||
"model_name": self.model_name,
|
"model_name": self.model_name,
|
||||||
|
@ -132,6 +132,18 @@ class ChatModelIntegrationTests(ABC):
|
|||||||
assert isinstance(result.content, str)
|
assert isinstance(result.content, str)
|
||||||
assert len(result.content) > 0
|
assert len(result.content) > 0
|
||||||
|
|
||||||
|
def test_usage_metadata(
|
||||||
|
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||||
|
) -> None:
|
||||||
|
model = chat_model_class(**chat_model_params)
|
||||||
|
result = model.invoke("Hello")
|
||||||
|
assert result is not None
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
assert result.usage_metadata is not None
|
||||||
|
assert isinstance(result.usage_metadata["input_tokens"], int)
|
||||||
|
assert isinstance(result.usage_metadata["output_tokens"], int)
|
||||||
|
assert isinstance(result.usage_metadata["total_tokens"], int)
|
||||||
|
|
||||||
def test_tool_message_histories_string_content(
|
def test_tool_message_histories_string_content(
|
||||||
self,
|
self,
|
||||||
chat_model_class: Type[BaseChatModel],
|
chat_model_class: Type[BaseChatModel],
|
||||||
|
Loading…
Reference in New Issue
Block a user