core, partners: add token usage attribute to AIMessage (#21944)

```python
class UsageMetadata(TypedDict):
    """Usage metadata for a message, such as token counts.

    Attributes:
        input_tokens: (int) count of input (or prompt) tokens
        output_tokens: (int) count of output (or completion) tokens
        total_tokens: (int) total token count
    """

    input_tokens: int
    output_tokens: int
    total_tokens: int
```
```python
class AIMessage(BaseMessage):
    ...
    usage_metadata: Optional[UsageMetadata] = None
    """If provided, token usage information associated with the message."""
    ...
```
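For illustration, a minimal sketch of how the new field surfaces to callers — the message content and counts below are invented; a chat model that reports usage would populate the field itself (see the partner-package changes further down):

```python
from langchain_core.messages import AIMessage

# Invented values for illustration; providers fill these in from their
# API responses.
msg = AIMessage(
    content="Hello!",
    usage_metadata={"input_tokens": 2, "output_tokens": 1, "total_tokens": 3},
)

if msg.usage_metadata is not None:  # None when no usage was reported
    print(msg.usage_metadata["total_tokens"])  # -> 3
```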
ccurme 2024-05-23 14:21:58 -04:00 committed by GitHub
parent 3d26807b92
commit fbfed65fb1
13 changed files with 364 additions and 3 deletions


@@ -1,4 +1,6 @@
from typing import Any, Dict, List, Literal, Union
from typing import Any, Dict, List, Literal, Optional, Union
from typing_extensions import TypedDict
from langchain_core.messages.base import (
    BaseMessage,
@@ -19,6 +21,20 @@ from langchain_core.utils.json import (
)


class UsageMetadata(TypedDict):
    """Usage metadata for a message, such as token counts.

    Attributes:
        input_tokens: (int) count of input (or prompt) tokens
        output_tokens: (int) count of output (or completion) tokens
        total_tokens: (int) total token count
    """

    input_tokens: int
    output_tokens: int
    total_tokens: int


class AIMessage(BaseMessage):
    """Message from an AI."""
@@ -31,6 +47,11 @@ class AIMessage(BaseMessage):
    """If provided, tool calls associated with the message."""
    invalid_tool_calls: List[InvalidToolCall] = []
    """If provided, tool calls with parsing errors associated with the message."""
    usage_metadata: Optional[UsageMetadata] = None
    """If provided, usage metadata for a message, such as token counts.

    This is a standard representation of token usage that is consistent across models.
    """

    type: Literal["ai"] = "ai"
@@ -198,12 +219,29 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
        else:
            tool_call_chunks = []

        # Token usage
        if self.usage_metadata or other.usage_metadata:
            left: UsageMetadata = self.usage_metadata or UsageMetadata(
                input_tokens=0, output_tokens=0, total_tokens=0
            )
            right: UsageMetadata = other.usage_metadata or UsageMetadata(
                input_tokens=0, output_tokens=0, total_tokens=0
            )
            usage_metadata: Optional[UsageMetadata] = {
                "input_tokens": left["input_tokens"] + right["input_tokens"],
                "output_tokens": left["output_tokens"] + right["output_tokens"],
                "total_tokens": left["total_tokens"] + right["total_tokens"],
            }
        else:
            usage_metadata = None

        return self.__class__(
            example=self.example,
            content=content,
            additional_kwargs=additional_kwargs,
            tool_call_chunks=tool_call_chunks,
            response_metadata=response_metadata,
            usage_metadata=usage_metadata,
            id=self.id,
        )
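Since `AIMessageChunk.__add__` sums the counts field-wise (treating a missing value as zeros), accumulating chunks during streaming yields a running total. A small sketch with invented counts; the unit test added below exercises the same behavior:

```python
from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(
    content="Hel",
    usage_metadata={"input_tokens": 5, "output_tokens": 1, "total_tokens": 6},
)
right = AIMessageChunk(
    content="lo",
    usage_metadata={"input_tokens": 0, "output_tokens": 1, "total_tokens": 1},
)

merged = left + right  # concatenates content, adds token counts field-wise
assert merged.content == "Hello"
assert merged.usage_metadata == {
    "input_tokens": 5,
    "output_tokens": 2,
    "total_tokens": 7,
}
```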


@@ -5286,6 +5286,9 @@
            'title': 'Type',
            'type': 'string',
          }),
          'usage_metadata': dict({
            '$ref': '#/definitions/UsageMetadata',
          }),
        }),
        'required': list([
          'content',
@@ -5707,6 +5710,29 @@
        'title': 'ToolMessage',
        'type': 'object',
      }),
      'UsageMetadata': dict({
        'properties': dict({
          'input_tokens': dict({
            'title': 'Input Tokens',
            'type': 'integer',
          }),
          'output_tokens': dict({
            'title': 'Output Tokens',
            'type': 'integer',
          }),
          'total_tokens': dict({
            'title': 'Total Tokens',
            'type': 'integer',
          }),
        }),
        'required': list([
          'input_tokens',
          'output_tokens',
          'total_tokens',
        ]),
        'title': 'UsageMetadata',
        'type': 'object',
      }),
    }),
    'title': 'FakeListLLMInput',
  })
@@ -5821,6 +5847,9 @@
            'title': 'Type',
            'type': 'string',
          }),
          'usage_metadata': dict({
            '$ref': '#/definitions/UsageMetadata',
          }),
        }),
        'required': list([
          'content',
@@ -6242,6 +6271,29 @@
        'title': 'ToolMessage',
        'type': 'object',
      }),
      'UsageMetadata': dict({
        'properties': dict({
          'input_tokens': dict({
            'title': 'Input Tokens',
            'type': 'integer',
          }),
          'output_tokens': dict({
            'title': 'Output Tokens',
            'type': 'integer',
          }),
          'total_tokens': dict({
            'title': 'Total Tokens',
            'type': 'integer',
          }),
        }),
        'required': list([
          'input_tokens',
          'output_tokens',
          'total_tokens',
        ]),
        'title': 'UsageMetadata',
        'type': 'object',
      }),
    }),
    'title': 'FakeListChatModelInput',
  })
@@ -6340,6 +6392,9 @@
            'title': 'Type',
            'type': 'string',
          }),
          'usage_metadata': dict({
            '$ref': '#/definitions/UsageMetadata',
          }),
        }),
        'required': list([
          'content',
@@ -6692,6 +6747,29 @@
        'title': 'ToolMessage',
        'type': 'object',
      }),
      'UsageMetadata': dict({
        'properties': dict({
          'input_tokens': dict({
            'title': 'Input Tokens',
            'type': 'integer',
          }),
          'output_tokens': dict({
            'title': 'Output Tokens',
            'type': 'integer',
          }),
          'total_tokens': dict({
            'title': 'Total Tokens',
            'type': 'integer',
          }),
        }),
        'required': list([
          'input_tokens',
          'output_tokens',
          'total_tokens',
        ]),
        'title': 'UsageMetadata',
        'type': 'object',
      }),
    }),
    'title': 'FakeListChatModelOutput',
  })
@@ -6778,6 +6856,9 @@
            'title': 'Type',
            'type': 'string',
          }),
          'usage_metadata': dict({
            '$ref': '#/definitions/UsageMetadata',
          }),
        }),
        'required': list([
          'content',
@@ -7199,6 +7280,29 @@
        'title': 'ToolMessage',
        'type': 'object',
      }),
      'UsageMetadata': dict({
        'properties': dict({
          'input_tokens': dict({
            'title': 'Input Tokens',
            'type': 'integer',
          }),
          'output_tokens': dict({
            'title': 'Output Tokens',
            'type': 'integer',
          }),
          'total_tokens': dict({
            'title': 'Total Tokens',
            'type': 'integer',
          }),
        }),
        'required': list([
          'input_tokens',
          'output_tokens',
          'total_tokens',
        ]),
        'title': 'UsageMetadata',
        'type': 'object',
      }),
    }),
    'title': 'ChatPromptTemplateOutput',
  })
@@ -7285,6 +7389,9 @@
            'title': 'Type',
            'type': 'string',
          }),
          'usage_metadata': dict({
            '$ref': '#/definitions/UsageMetadata',
          }),
        }),
        'required': list([
          'content',
@@ -7706,6 +7813,29 @@
        'title': 'ToolMessage',
        'type': 'object',
      }),
      'UsageMetadata': dict({
        'properties': dict({
          'input_tokens': dict({
            'title': 'Input Tokens',
            'type': 'integer',
          }),
          'output_tokens': dict({
            'title': 'Output Tokens',
            'type': 'integer',
          }),
          'total_tokens': dict({
            'title': 'Total Tokens',
            'type': 'integer',
          }),
        }),
        'required': list([
          'input_tokens',
          'output_tokens',
          'total_tokens',
        ]),
        'title': 'UsageMetadata',
        'type': 'object',
      }),
    }),
    'title': 'PromptTemplateOutput',
  })
@@ -7784,6 +7914,9 @@
            'title': 'Type',
            'type': 'string',
          }),
          'usage_metadata': dict({
            '$ref': '#/definitions/UsageMetadata',
          }),
        }),
        'required': list([
          'content',
@@ -8216,6 +8349,29 @@
        'title': 'ToolMessage',
        'type': 'object',
      }),
      'UsageMetadata': dict({
        'properties': dict({
          'input_tokens': dict({
            'title': 'Input Tokens',
            'type': 'integer',
          }),
          'output_tokens': dict({
            'title': 'Output Tokens',
            'type': 'integer',
          }),
          'total_tokens': dict({
            'title': 'Total Tokens',
            'type': 'integer',
          }),
        }),
        'required': list([
          'input_tokens',
          'output_tokens',
          'total_tokens',
        ]),
        'title': 'UsageMetadata',
        'type': 'object',
      }),
    }),
    'items': dict({
      '$ref': '#/definitions/PromptTemplateOutput',
@@ -8321,6 +8477,9 @@
            'title': 'Type',
            'type': 'string',
          }),
          'usage_metadata': dict({
            '$ref': '#/definitions/UsageMetadata',
          }),
        }),
        'required': list([
          'content',
@@ -8673,6 +8832,29 @@
        'title': 'ToolMessage',
        'type': 'object',
      }),
      'UsageMetadata': dict({
        'properties': dict({
          'input_tokens': dict({
            'title': 'Input Tokens',
            'type': 'integer',
          }),
          'output_tokens': dict({
            'title': 'Output Tokens',
            'type': 'integer',
          }),
          'total_tokens': dict({
            'title': 'Total Tokens',
            'type': 'integer',
          }),
        }),
        'required': list([
          'input_tokens',
          'output_tokens',
          'total_tokens',
        ]),
        'title': 'UsageMetadata',
        'type': 'object',
      }),
    }),
    'title': 'CommaSeparatedListOutputParserInput',
  })


@@ -227,6 +227,29 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
            },
            "required": ["name", "args", "id", "error"],
        },
        "UsageMetadata": {
            "title": "UsageMetadata",
            "type": "object",
            "properties": {
                "input_tokens": {
                    "title": "Input Tokens",
                    "type": "integer",
                },
                "output_tokens": {
                    "title": "Output Tokens",
                    "type": "integer",
                },
                "total_tokens": {
                    "title": "Total Tokens",
                    "type": "integer",
                },
            },
            "required": [
                "input_tokens",
                "output_tokens",
                "total_tokens",
            ],
        },
        "AIMessage": {
            "title": "AIMessage",
            "description": "Message from an AI.",
@@ -280,6 +303,9 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
                    "type": "array",
                    "items": {"$ref": "#/definitions/InvalidToolCall"},
                },
                "usage_metadata": {
                    "$ref": "#/definitions/UsageMetadata"
                },
            },
            "required": ["content"],
        },


@@ -383,6 +383,16 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
            },
            "required": ["name", "args", "id", "error"],
        },
        "UsageMetadata": {
            "title": "UsageMetadata",
            "type": "object",
            "properties": {
                "input_tokens": {"title": "Input Tokens", "type": "integer"},
                "output_tokens": {"title": "Output Tokens", "type": "integer"},
                "total_tokens": {"title": "Total Tokens", "type": "integer"},
            },
            "required": ["input_tokens", "output_tokens", "total_tokens"],
        },
        "AIMessage": {
            "title": "AIMessage",
            "description": "Message from an AI.",
@@ -433,6 +443,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
                    "type": "array",
                    "items": {"$ref": "#/definitions/InvalidToolCall"},
                },
                "usage_metadata": {"$ref": "#/definitions/UsageMetadata"},
            },
            "required": ["content"],
        },


@@ -120,6 +120,22 @@ def test_message_chunks() -> None:
    assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
    assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk

    # Test token usage
    left = AIMessageChunk(
        content="",
        usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
    )
    right = AIMessageChunk(
        content="",
        usage_metadata={"input_tokens": 4, "output_tokens": 5, "total_tokens": 9},
    )
    assert left + right == AIMessageChunk(
        content="",
        usage_metadata={"input_tokens": 5, "output_tokens": 7, "total_tokens": 12},
    )
    assert AIMessageChunk(content="") + left == left
    assert right + AIMessageChunk(content="") == right


def test_chat_message_chunks() -> None:
    assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(


@@ -41,6 +41,17 @@ class TestAI21J2(ChatModelIntegrationTests):
            chat_model_params,
        )

    @pytest.mark.xfail(reason="Not implemented.")
    def test_usage_metadata(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
    ) -> None:
        super().test_usage_metadata(
            chat_model_class,
            chat_model_params,
        )

    @pytest.fixture
    def chat_model_params(self) -> dict:
        return {
@@ -79,6 +90,17 @@ class TestAI21Jamba(ChatModelIntegrationTests):
            chat_model_params,
        )

    @pytest.mark.xfail(reason="Not implemented.")
    def test_usage_metadata(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
    ) -> None:
        super().test_usage_metadata(
            chat_model_class,
            chat_model_params,
        )

    @pytest.fixture
    def chat_model_params(self) -> dict:
        return {


@@ -503,6 +503,12 @@ class ChatAnthropic(BaseChatModel):
            )
        else:
            msg = AIMessage(content=content)
        # Collect token usage
        msg.usage_metadata = {
            "input_tokens": data.usage.input_tokens,
            "output_tokens": data.usage.output_tokens,
            "total_tokens": data.usage.input_tokens + data.usage.output_tokens,
        }
        return ChatResult(
            generations=[ChatGeneration(message=msg)],
            llm_output=llm_output,


@@ -89,7 +89,16 @@ def test__format_output() -> None:
    )
    expected = ChatResult(
        generations=[
            ChatGeneration(message=AIMessage("bar")),
            ChatGeneration(
                message=AIMessage(
                    "bar",
                    usage_metadata={
                        "input_tokens": 2,
                        "output_tokens": 1,
                        "total_tokens": 3,
                    },
                )
            ),
        ],
        llm_output={
            "id": "foo",


@@ -21,6 +21,17 @@ class TestFireworksStandard(ChatModelIntegrationTests):
            "temperature": 0,
        }

    @pytest.mark.xfail(reason="Not implemented.")
    def test_usage_metadata(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
    ) -> None:
        super().test_usage_metadata(
            chat_model_class,
            chat_model_params,
        )

    @pytest.mark.xfail(reason="Not yet implemented.")
    def test_tool_message_histories_list_content(
        self,


@@ -14,6 +14,17 @@ class TestMistralStandard(ChatModelIntegrationTests):
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatGroq

    @pytest.mark.xfail(reason="Not implemented.")
    def test_usage_metadata(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
    ) -> None:
        super().test_usage_metadata(
            chat_model_class,
            chat_model_params,
        )

    @pytest.mark.xfail(reason="Not yet implemented.")
    def test_tool_message_histories_list_content(
        self,


@@ -20,3 +20,14 @@ class TestMistralStandard(ChatModelIntegrationTests):
            "model": "mistral-large-latest",
            "temperature": 0,
        }

    @pytest.mark.xfail(reason="Not implemented.")
    def test_usage_metadata(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
    ) -> None:
        super().test_usage_metadata(
            chat_model_class,
            chat_model_params,
        )


@@ -548,8 +548,15 @@ class BaseChatOpenAI(BaseChatModel):
        if response.get("error"):
            raise ValueError(response.get("error"))
        token_usage = response.get("usage", {})
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            if token_usage and isinstance(message, AIMessage):
                message.usage_metadata = {
                    "input_tokens": token_usage.get("prompt_tokens", 0),
                    "output_tokens": token_usage.get("completion_tokens", 0),
                    "total_tokens": token_usage.get("total_tokens", 0),
                }
            generation_info = dict(finish_reason=res.get("finish_reason"))
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
@@ -558,7 +565,6 @@ class BaseChatOpenAI(BaseChatModel):
                generation_info=generation_info,
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,


@@ -132,6 +132,18 @@ class ChatModelIntegrationTests(ABC):
        assert isinstance(result.content, str)
        assert len(result.content) > 0

    def test_usage_metadata(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        model = chat_model_class(**chat_model_params)
        result = model.invoke("Hello")
        assert result is not None
        assert isinstance(result, AIMessage)
        assert result.usage_metadata is not None
        assert isinstance(result.usage_metadata["input_tokens"], int)
        assert isinstance(result.usage_metadata["output_tokens"], int)
        assert isinstance(result.usage_metadata["total_tokens"], int)

    def test_tool_message_histories_string_content(
        self,
        chat_model_class: Type[BaseChatModel],