mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
core, partners: add token usage attribute to AIMessage (#21944)
```python class UsageMetadata(TypedDict): """Usage metadata for a message, such as token counts. Attributes: input_tokens: (int) count of input (or prompt) tokens output_tokens: (int) count of output (or completion) tokens total_tokens: (int) total token count """ input_tokens: int output_tokens: int total_tokens: int ``` ```python class AIMessage(BaseMessage): ... usage_metadata: Optional[UsageMetadata] = None """If provided, token usage information associated with the message.""" ... ```
This commit is contained in:
parent
3d26807b92
commit
fbfed65fb1
@ -1,4 +1,6 @@
|
||||
from typing import Any, Dict, List, Literal, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
@ -19,6 +21,20 @@ from langchain_core.utils.json import (
|
||||
)
|
||||
|
||||
|
||||
class UsageMetadata(TypedDict):
|
||||
"""Usage metadata for a message, such as token counts.
|
||||
|
||||
Attributes:
|
||||
input_tokens: (int) count of input (or prompt) tokens
|
||||
output_tokens: (int) count of output (or completion) tokens
|
||||
total_tokens: (int) total token count
|
||||
"""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class AIMessage(BaseMessage):
|
||||
"""Message from an AI."""
|
||||
|
||||
@ -31,6 +47,11 @@ class AIMessage(BaseMessage):
|
||||
"""If provided, tool calls associated with the message."""
|
||||
invalid_tool_calls: List[InvalidToolCall] = []
|
||||
"""If provided, tool calls with parsing errors associated with the message."""
|
||||
usage_metadata: Optional[UsageMetadata] = None
|
||||
"""If provided, usage metadata for a message, such as token counts.
|
||||
|
||||
This is a standard representation of token usage that is consistent across models.
|
||||
"""
|
||||
|
||||
type: Literal["ai"] = "ai"
|
||||
|
||||
@ -198,12 +219,29 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
else:
|
||||
tool_call_chunks = []
|
||||
|
||||
# Token usage
|
||||
if self.usage_metadata or other.usage_metadata:
|
||||
left: UsageMetadata = self.usage_metadata or UsageMetadata(
|
||||
input_tokens=0, output_tokens=0, total_tokens=0
|
||||
)
|
||||
right: UsageMetadata = other.usage_metadata or UsageMetadata(
|
||||
input_tokens=0, output_tokens=0, total_tokens=0
|
||||
)
|
||||
usage_metadata: Optional[UsageMetadata] = {
|
||||
"input_tokens": left["input_tokens"] + right["input_tokens"],
|
||||
"output_tokens": left["output_tokens"] + right["output_tokens"],
|
||||
"total_tokens": left["total_tokens"] + right["total_tokens"],
|
||||
}
|
||||
else:
|
||||
usage_metadata = None
|
||||
|
||||
return self.__class__(
|
||||
example=self.example,
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
response_metadata=response_metadata,
|
||||
usage_metadata=usage_metadata,
|
||||
id=self.id,
|
||||
)
|
||||
|
||||
|
@ -5286,6 +5286,9 @@
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
'usage_metadata': dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'content',
|
||||
@ -5707,6 +5710,29 @@
|
||||
'title': 'ToolMessage',
|
||||
'type': 'object',
|
||||
}),
|
||||
'UsageMetadata': dict({
|
||||
'properties': dict({
|
||||
'input_tokens': dict({
|
||||
'title': 'Input Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'output_tokens': dict({
|
||||
'title': 'Output Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'total_tokens': dict({
|
||||
'title': 'Total Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'input_tokens',
|
||||
'output_tokens',
|
||||
'total_tokens',
|
||||
]),
|
||||
'title': 'UsageMetadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'title': 'FakeListLLMInput',
|
||||
})
|
||||
@ -5821,6 +5847,9 @@
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
'usage_metadata': dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'content',
|
||||
@ -6242,6 +6271,29 @@
|
||||
'title': 'ToolMessage',
|
||||
'type': 'object',
|
||||
}),
|
||||
'UsageMetadata': dict({
|
||||
'properties': dict({
|
||||
'input_tokens': dict({
|
||||
'title': 'Input Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'output_tokens': dict({
|
||||
'title': 'Output Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'total_tokens': dict({
|
||||
'title': 'Total Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'input_tokens',
|
||||
'output_tokens',
|
||||
'total_tokens',
|
||||
]),
|
||||
'title': 'UsageMetadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'title': 'FakeListChatModelInput',
|
||||
})
|
||||
@ -6340,6 +6392,9 @@
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
'usage_metadata': dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'content',
|
||||
@ -6692,6 +6747,29 @@
|
||||
'title': 'ToolMessage',
|
||||
'type': 'object',
|
||||
}),
|
||||
'UsageMetadata': dict({
|
||||
'properties': dict({
|
||||
'input_tokens': dict({
|
||||
'title': 'Input Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'output_tokens': dict({
|
||||
'title': 'Output Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'total_tokens': dict({
|
||||
'title': 'Total Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'input_tokens',
|
||||
'output_tokens',
|
||||
'total_tokens',
|
||||
]),
|
||||
'title': 'UsageMetadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'title': 'FakeListChatModelOutput',
|
||||
})
|
||||
@ -6778,6 +6856,9 @@
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
'usage_metadata': dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'content',
|
||||
@ -7199,6 +7280,29 @@
|
||||
'title': 'ToolMessage',
|
||||
'type': 'object',
|
||||
}),
|
||||
'UsageMetadata': dict({
|
||||
'properties': dict({
|
||||
'input_tokens': dict({
|
||||
'title': 'Input Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'output_tokens': dict({
|
||||
'title': 'Output Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'total_tokens': dict({
|
||||
'title': 'Total Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'input_tokens',
|
||||
'output_tokens',
|
||||
'total_tokens',
|
||||
]),
|
||||
'title': 'UsageMetadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'title': 'ChatPromptTemplateOutput',
|
||||
})
|
||||
@ -7285,6 +7389,9 @@
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
'usage_metadata': dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'content',
|
||||
@ -7706,6 +7813,29 @@
|
||||
'title': 'ToolMessage',
|
||||
'type': 'object',
|
||||
}),
|
||||
'UsageMetadata': dict({
|
||||
'properties': dict({
|
||||
'input_tokens': dict({
|
||||
'title': 'Input Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'output_tokens': dict({
|
||||
'title': 'Output Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'total_tokens': dict({
|
||||
'title': 'Total Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'input_tokens',
|
||||
'output_tokens',
|
||||
'total_tokens',
|
||||
]),
|
||||
'title': 'UsageMetadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'title': 'PromptTemplateOutput',
|
||||
})
|
||||
@ -7784,6 +7914,9 @@
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
'usage_metadata': dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'content',
|
||||
@ -8216,6 +8349,29 @@
|
||||
'title': 'ToolMessage',
|
||||
'type': 'object',
|
||||
}),
|
||||
'UsageMetadata': dict({
|
||||
'properties': dict({
|
||||
'input_tokens': dict({
|
||||
'title': 'Input Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'output_tokens': dict({
|
||||
'title': 'Output Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'total_tokens': dict({
|
||||
'title': 'Total Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'input_tokens',
|
||||
'output_tokens',
|
||||
'total_tokens',
|
||||
]),
|
||||
'title': 'UsageMetadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'items': dict({
|
||||
'$ref': '#/definitions/PromptTemplateOutput',
|
||||
@ -8321,6 +8477,9 @@
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
'usage_metadata': dict({
|
||||
'$ref': '#/definitions/UsageMetadata',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'content',
|
||||
@ -8673,6 +8832,29 @@
|
||||
'title': 'ToolMessage',
|
||||
'type': 'object',
|
||||
}),
|
||||
'UsageMetadata': dict({
|
||||
'properties': dict({
|
||||
'input_tokens': dict({
|
||||
'title': 'Input Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'output_tokens': dict({
|
||||
'title': 'Output Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
'total_tokens': dict({
|
||||
'title': 'Total Tokens',
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'input_tokens',
|
||||
'output_tokens',
|
||||
'total_tokens',
|
||||
]),
|
||||
'title': 'UsageMetadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'title': 'CommaSeparatedListOutputParserInput',
|
||||
})
|
||||
|
@ -227,6 +227,29 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
||||
},
|
||||
"required": ["name", "args", "id", "error"],
|
||||
},
|
||||
"UsageMetadata": {
|
||||
"title": "UsageMetadata",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_tokens": {
|
||||
"title": "Input Tokens",
|
||||
"type": "integer",
|
||||
},
|
||||
"output_tokens": {
|
||||
"title": "Output Tokens",
|
||||
"type": "integer",
|
||||
},
|
||||
"total_tokens": {
|
||||
"title": "Total Tokens",
|
||||
"type": "integer",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"input_tokens",
|
||||
"output_tokens",
|
||||
"total_tokens",
|
||||
],
|
||||
},
|
||||
"AIMessage": {
|
||||
"title": "AIMessage",
|
||||
"description": "Message from an AI.",
|
||||
@ -280,6 +303,9 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/definitions/InvalidToolCall"},
|
||||
},
|
||||
"usage_metadata": {
|
||||
"$ref": "#/definitions/UsageMetadata"
|
||||
},
|
||||
},
|
||||
"required": ["content"],
|
||||
},
|
||||
|
@ -383,6 +383,16 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
},
|
||||
"required": ["name", "args", "id", "error"],
|
||||
},
|
||||
"UsageMetadata": {
|
||||
"title": "UsageMetadata",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_tokens": {"title": "Input Tokens", "type": "integer"},
|
||||
"output_tokens": {"title": "Output Tokens", "type": "integer"},
|
||||
"total_tokens": {"title": "Total Tokens", "type": "integer"},
|
||||
},
|
||||
"required": ["input_tokens", "output_tokens", "total_tokens"],
|
||||
},
|
||||
"AIMessage": {
|
||||
"title": "AIMessage",
|
||||
"description": "Message from an AI.",
|
||||
@ -433,6 +443,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/definitions/InvalidToolCall"},
|
||||
},
|
||||
"usage_metadata": {"$ref": "#/definitions/UsageMetadata"},
|
||||
},
|
||||
"required": ["content"],
|
||||
},
|
||||
|
@ -120,6 +120,22 @@ def test_message_chunks() -> None:
|
||||
assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
|
||||
assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk
|
||||
|
||||
# Test token usage
|
||||
left = AIMessageChunk(
|
||||
content="",
|
||||
usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
|
||||
)
|
||||
right = AIMessageChunk(
|
||||
content="",
|
||||
usage_metadata={"input_tokens": 4, "output_tokens": 5, "total_tokens": 9},
|
||||
)
|
||||
assert left + right == AIMessageChunk(
|
||||
content="",
|
||||
usage_metadata={"input_tokens": 5, "output_tokens": 7, "total_tokens": 12},
|
||||
)
|
||||
assert AIMessageChunk(content="") + left == left
|
||||
assert right + AIMessageChunk(content="") == right
|
||||
|
||||
|
||||
def test_chat_message_chunks() -> None:
|
||||
assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(
|
||||
|
@ -41,6 +41,17 @@ class TestAI21J2(ChatModelIntegrationTests):
|
||||
chat_model_params,
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(reason="Not implemented.")
|
||||
def test_usage_metadata(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
super().test_usage_metadata(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
@ -79,6 +90,17 @@ class TestAI21Jamba(ChatModelIntegrationTests):
|
||||
chat_model_params,
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(reason="Not implemented.")
|
||||
def test_usage_metadata(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
super().test_usage_metadata(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
|
@ -503,6 +503,12 @@ class ChatAnthropic(BaseChatModel):
|
||||
)
|
||||
else:
|
||||
msg = AIMessage(content=content)
|
||||
# Collect token usage
|
||||
msg.usage_metadata = {
|
||||
"input_tokens": data.usage.input_tokens,
|
||||
"output_tokens": data.usage.output_tokens,
|
||||
"total_tokens": data.usage.input_tokens + data.usage.output_tokens,
|
||||
}
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=msg)],
|
||||
llm_output=llm_output,
|
||||
|
@ -89,7 +89,16 @@ def test__format_output() -> None:
|
||||
)
|
||||
expected = ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=AIMessage("bar")),
|
||||
ChatGeneration(
|
||||
message=AIMessage(
|
||||
"bar",
|
||||
usage_metadata={
|
||||
"input_tokens": 2,
|
||||
"output_tokens": 1,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
)
|
||||
),
|
||||
],
|
||||
llm_output={
|
||||
"id": "foo",
|
||||
|
@ -21,6 +21,17 @@ class TestFireworksStandard(ChatModelIntegrationTests):
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
@pytest.mark.xfail(reason="Not implemented.")
|
||||
def test_usage_metadata(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
super().test_usage_metadata(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet implemented.")
|
||||
def test_tool_message_histories_list_content(
|
||||
self,
|
||||
|
@ -14,6 +14,17 @@ class TestMistralStandard(ChatModelIntegrationTests):
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatGroq
|
||||
|
||||
@pytest.mark.xfail(reason="Not implemented.")
|
||||
def test_usage_metadata(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
super().test_usage_metadata(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet implemented.")
|
||||
def test_tool_message_histories_list_content(
|
||||
self,
|
||||
|
@ -20,3 +20,14 @@ class TestMistralStandard(ChatModelIntegrationTests):
|
||||
"model": "mistral-large-latest",
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
@pytest.mark.xfail(reason="Not implemented.")
|
||||
def test_usage_metadata(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
chat_model_params: dict,
|
||||
) -> None:
|
||||
super().test_usage_metadata(
|
||||
chat_model_class,
|
||||
chat_model_params,
|
||||
)
|
||||
|
@ -548,8 +548,15 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if response.get("error"):
|
||||
raise ValueError(response.get("error"))
|
||||
|
||||
token_usage = response.get("usage", {})
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
if token_usage and isinstance(message, AIMessage):
|
||||
message.usage_metadata = {
|
||||
"input_tokens": token_usage.get("prompt_tokens", 0),
|
||||
"output_tokens": token_usage.get("completion_tokens", 0),
|
||||
"total_tokens": token_usage.get("total_tokens", 0),
|
||||
}
|
||||
generation_info = dict(finish_reason=res.get("finish_reason"))
|
||||
if "logprobs" in res:
|
||||
generation_info["logprobs"] = res["logprobs"]
|
||||
@ -558,7 +565,6 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
generation_info=generation_info,
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = response.get("usage", {})
|
||||
llm_output = {
|
||||
"token_usage": token_usage,
|
||||
"model_name": self.model_name,
|
||||
|
@ -132,6 +132,18 @@ class ChatModelIntegrationTests(ABC):
|
||||
assert isinstance(result.content, str)
|
||||
assert len(result.content) > 0
|
||||
|
||||
def test_usage_metadata(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
model = chat_model_class(**chat_model_params)
|
||||
result = model.invoke("Hello")
|
||||
assert result is not None
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.usage_metadata is not None
|
||||
assert isinstance(result.usage_metadata["input_tokens"], int)
|
||||
assert isinstance(result.usage_metadata["output_tokens"], int)
|
||||
assert isinstance(result.usage_metadata["total_tokens"], int)
|
||||
|
||||
def test_tool_message_histories_string_content(
|
||||
self,
|
||||
chat_model_class: Type[BaseChatModel],
|
||||
|
Loading…
Reference in New Issue
Block a user