diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py
index 4ab316b9d08..130b657d48d 100644
--- a/libs/core/langchain_core/messages/ai.py
+++ b/libs/core/langchain_core/messages/ai.py
@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Literal, Union
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from typing_extensions import TypedDict

 from langchain_core.messages.base import (
     BaseMessage,
@@ -19,6 +21,20 @@ from langchain_core.utils.json import (
 )


+class UsageMetadata(TypedDict):
+    """Usage metadata for a message, such as token counts.
+
+    Attributes:
+        input_tokens: (int) count of input (or prompt) tokens
+        output_tokens: (int) count of output (or completion) tokens
+        total_tokens: (int) total token count
+    """
+
+    input_tokens: int
+    output_tokens: int
+    total_tokens: int
+
+
 class AIMessage(BaseMessage):
     """Message from an AI."""

@@ -31,6 +47,11 @@ class AIMessage(BaseMessage):
     """If provided, tool calls associated with the message."""
     invalid_tool_calls: List[InvalidToolCall] = []
     """If provided, tool calls with parsing errors associated with the message."""
+    usage_metadata: Optional[UsageMetadata] = None
+    """If provided, usage metadata for a message, such as token counts.
+
+    This is a standard representation of token usage that is consistent across models.
+    """

     type: Literal["ai"] = "ai"

@@ -198,12 +219,29 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
         else:
             tool_call_chunks = []

+        # Token usage
+        if self.usage_metadata or other.usage_metadata:
+            left: UsageMetadata = self.usage_metadata or UsageMetadata(
+                input_tokens=0, output_tokens=0, total_tokens=0
+            )
+            right: UsageMetadata = other.usage_metadata or UsageMetadata(
+                input_tokens=0, output_tokens=0, total_tokens=0
+            )
+            usage_metadata: Optional[UsageMetadata] = {
+                "input_tokens": left["input_tokens"] + right["input_tokens"],
+                "output_tokens": left["output_tokens"] + right["output_tokens"],
+                "total_tokens": left["total_tokens"] + right["total_tokens"],
+            }
+        else:
+            usage_metadata = None
+
         return self.__class__(
             example=self.example,
             content=content,
             additional_kwargs=additional_kwargs,
             tool_call_chunks=tool_call_chunks,
             response_metadata=response_metadata,
+            usage_metadata=usage_metadata,
             id=self.id,
         )
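The hunk above is the heart of the change: `UsageMetadata` is a plain `TypedDict`, and `AIMessageChunk.__add__` sums the three counters field-by-field, treating a side with no metadata as all zeros. A minimal sketch of the resulting behavior, mirroring the unit test further down (assumes `langchain-core` with this patch applied):

```python
from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(
    content="Hello ",
    usage_metadata={"input_tokens": 2, "output_tokens": 1, "total_tokens": 3},
)
right = AIMessageChunk(
    content="world",
    usage_metadata={"input_tokens": 0, "output_tokens": 1, "total_tokens": 1},
)

merged = left + right  # counters are summed field-by-field
assert merged.usage_metadata == {
    "input_tokens": 2,
    "output_tokens": 2,
    "total_tokens": 4,
}

# A chunk without usage metadata acts as the identity element.
assert (AIMessageChunk(content="") + left).usage_metadata == left.usage_metadata
```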
diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr
index a0a7ce07eff..129ae6e0bca 100644
--- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr
+++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr
@@ -5286,6 +5286,9 @@
           'title': 'Type',
           'type': 'string',
         }),
+        'usage_metadata': dict({
+          '$ref': '#/definitions/UsageMetadata',
+        }),
       }),
       'required': list([
         'content',
@@ -5707,6 +5710,29 @@
       'title': 'ToolMessage',
       'type': 'object',
     }),
+    'UsageMetadata': dict({
+      'properties': dict({
+        'input_tokens': dict({
+          'title': 'Input Tokens',
+          'type': 'integer',
+        }),
+        'output_tokens': dict({
+          'title': 'Output Tokens',
+          'type': 'integer',
+        }),
+        'total_tokens': dict({
+          'title': 'Total Tokens',
+          'type': 'integer',
+        }),
+      }),
+      'required': list([
+        'input_tokens',
+        'output_tokens',
+        'total_tokens',
+      ]),
+      'title': 'UsageMetadata',
+      'type': 'object',
+    }),
   }),
   'title': 'FakeListLLMInput',
 })
@@ -5821,6 +5847,9 @@
           'title': 'Type',
           'type': 'string',
         }),
+        'usage_metadata': dict({
+          '$ref': '#/definitions/UsageMetadata',
+        }),
       }),
       'required': list([
         'content',
@@ -6242,6 +6271,29 @@
       'title': 'ToolMessage',
       'type': 'object',
     }),
+    'UsageMetadata': dict({
+      'properties': dict({
+        'input_tokens': dict({
+          'title': 'Input Tokens',
+          'type': 'integer',
+        }),
+        'output_tokens': dict({
+          'title': 'Output Tokens',
+          'type': 'integer',
+        }),
+        'total_tokens': dict({
+          'title': 'Total Tokens',
+          'type': 'integer',
+        }),
+      }),
+      'required': list([
+        'input_tokens',
+        'output_tokens',
+        'total_tokens',
+      ]),
+      'title': 'UsageMetadata',
+      'type': 'object',
+    }),
   }),
   'title': 'FakeListChatModelInput',
 })
@@ -6340,6 +6392,9 @@
           'title': 'Type',
           'type': 'string',
         }),
+        'usage_metadata': dict({
+          '$ref': '#/definitions/UsageMetadata',
+        }),
       }),
       'required': list([
         'content',
@@ -6692,6 +6747,29 @@
       'title': 'ToolMessage',
       'type': 'object',
     }),
+    'UsageMetadata': dict({
+      'properties': dict({
+        'input_tokens': dict({
+          'title': 'Input Tokens',
+          'type': 'integer',
+        }),
+        'output_tokens': dict({
+          'title': 'Output Tokens',
+          'type': 'integer',
+        }),
+        'total_tokens': dict({
+          'title': 'Total Tokens',
+          'type': 'integer',
+        }),
+      }),
+      'required': list([
+        'input_tokens',
+        'output_tokens',
+        'total_tokens',
+      ]),
+      'title': 'UsageMetadata',
+      'type': 'object',
+    }),
   }),
   'title': 'FakeListChatModelOutput',
 })
@@ -6778,6 +6856,9 @@
           'title': 'Type',
           'type': 'string',
         }),
+        'usage_metadata': dict({
+          '$ref': '#/definitions/UsageMetadata',
+        }),
       }),
       'required': list([
         'content',
@@ -7199,6 +7280,29 @@
       'title': 'ToolMessage',
       'type': 'object',
     }),
+    'UsageMetadata': dict({
+      'properties': dict({
+        'input_tokens': dict({
+          'title': 'Input Tokens',
+          'type': 'integer',
+        }),
+        'output_tokens': dict({
+          'title': 'Output Tokens',
+          'type': 'integer',
+        }),
+        'total_tokens': dict({
+          'title': 'Total Tokens',
+          'type': 'integer',
+        }),
+      }),
+      'required': list([
+        'input_tokens',
+        'output_tokens',
+        'total_tokens',
+      ]),
+      'title': 'UsageMetadata',
+      'type': 'object',
+    }),
   }),
   'title': 'ChatPromptTemplateOutput',
 })
@@ -7285,6 +7389,9 @@
           'title': 'Type',
           'type': 'string',
         }),
+        'usage_metadata': dict({
+          '$ref': '#/definitions/UsageMetadata',
+        }),
       }),
       'required': list([
         'content',
@@ -7706,6 +7813,29 @@
       'title': 'ToolMessage',
       'type': 'object',
     }),
+    'UsageMetadata': dict({
+      'properties': dict({
+        'input_tokens': dict({
+          'title': 'Input Tokens',
+          'type': 'integer',
+        }),
+        'output_tokens': dict({
+          'title': 'Output Tokens',
+          'type': 'integer',
+        }),
+        'total_tokens': dict({
+          'title': 'Total Tokens',
+          'type': 'integer',
+        }),
+      }),
+      'required': list([
+        'input_tokens',
+        'output_tokens',
+        'total_tokens',
+      ]),
+      'title': 'UsageMetadata',
+      'type': 'object',
+    }),
   }),
   'title': 'PromptTemplateOutput',
 })
@@ -7784,6 +7914,9 @@
           'title': 'Type',
           'type': 'string',
         }),
+        'usage_metadata': dict({
+          '$ref': '#/definitions/UsageMetadata',
+        }),
       }),
       'required': list([
         'content',
@@ -8216,6 +8349,29 @@
       'title': 'ToolMessage',
       'type': 'object',
     }),
+    'UsageMetadata': dict({
+      'properties': dict({
+        'input_tokens': dict({
+          'title': 'Input Tokens',
+          'type': 'integer',
+        }),
+        'output_tokens': dict({
+          'title': 'Output Tokens',
+          'type': 'integer',
+        }),
+        'total_tokens': dict({
+          'title': 'Total Tokens',
+          'type': 'integer',
+        }),
+      }),
+      'required': list([
+        'input_tokens',
+        'output_tokens',
+        'total_tokens',
+      ]),
+      'title': 'UsageMetadata',
+      'type': 'object',
+    }),
   }),
   'items': dict({
     '$ref': '#/definitions/PromptTemplateOutput',
@@ -8321,6 +8477,9 @@
           'title': 'Type',
           'type': 'string',
         }),
+        'usage_metadata': dict({
+          '$ref': '#/definitions/UsageMetadata',
+        }),
       }),
       'required': list([
         'content',
@@ -8673,6 +8832,29 @@
       'title': 'ToolMessage',
       'type': 'object',
     }),
+    'UsageMetadata': dict({
+      'properties': dict({
+        'input_tokens': dict({
+          'title': 'Input Tokens',
+          'type': 'integer',
+        }),
+        'output_tokens': dict({
+          'title': 'Output Tokens',
+          'type': 'integer',
+        }),
+        'total_tokens': dict({
+          'title': 'Total Tokens',
+          'type': 'integer',
+        }),
+      }),
+      'required': list([
+        'input_tokens',
+        'output_tokens',
+        'total_tokens',
+      ]),
+      'title': 'UsageMetadata',
+      'type': 'object',
+    }),
   }),
   'title': 'CommaSeparatedListOutputParserInput',
 })
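The snapshot churn above is mechanical: every serialized input/output schema that embeds `AIMessage` gains the same `$ref` plus one shared `UsageMetadata` definition. The same shape can be inspected directly on the model (a sketch, assuming the pydantic-v1-style `.schema()` these snapshots are generated from):

```python
from langchain_core.messages import AIMessage

schema = AIMessage.schema()
# The new property is a $ref into the shared definitions table...
assert schema["properties"]["usage_metadata"] == {"$ref": "#/definitions/UsageMetadata"}
# ...and the referenced definition requires all three counters.
assert schema["definitions"]["UsageMetadata"]["required"] == [
    "input_tokens",
    "output_tokens",
    "total_tokens",
]
```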
diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py
index 8a9aa12c004..bc4ef621fc5 100644
--- a/libs/core/tests/unit_tests/runnables/test_graph.py
+++ b/libs/core/tests/unit_tests/runnables/test_graph.py
@@ -227,6 +227,29 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
                 },
                 "required": ["name", "args", "id", "error"],
             },
+            "UsageMetadata": {
+                "title": "UsageMetadata",
+                "type": "object",
+                "properties": {
+                    "input_tokens": {
+                        "title": "Input Tokens",
+                        "type": "integer",
+                    },
+                    "output_tokens": {
+                        "title": "Output Tokens",
+                        "type": "integer",
+                    },
+                    "total_tokens": {
+                        "title": "Total Tokens",
+                        "type": "integer",
+                    },
+                },
+                "required": [
+                    "input_tokens",
+                    "output_tokens",
+                    "total_tokens",
+                ],
+            },
             "AIMessage": {
                 "title": "AIMessage",
                 "description": "Message from an AI.",
@@ -280,6 +303,9 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
                         "type": "array",
                         "items": {"$ref": "#/definitions/InvalidToolCall"},
                     },
+                    "usage_metadata": {
+                        "$ref": "#/definitions/UsageMetadata"
+                    },
                 },
                 "required": ["content"],
             },
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py
index 3fbd7c57c18..72a9494a807 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable.py
@@ -383,6 +383,16 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
                 },
                 "required": ["name", "args", "id", "error"],
             },
+            "UsageMetadata": {
+                "title": "UsageMetadata",
+                "type": "object",
+                "properties": {
+                    "input_tokens": {"title": "Input Tokens", "type": "integer"},
+                    "output_tokens": {"title": "Output Tokens", "type": "integer"},
+                    "total_tokens": {"title": "Total Tokens", "type": "integer"},
+                },
+                "required": ["input_tokens", "output_tokens", "total_tokens"],
+            },
             "AIMessage": {
                 "title": "AIMessage",
                 "description": "Message from an AI.",
@@ -433,6 +443,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
                         "type": "array",
                         "items": {"$ref": "#/definitions/InvalidToolCall"},
                     },
+                    "usage_metadata": {"$ref": "#/definitions/UsageMetadata"},
                 },
                 "required": ["content"],
             },
diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py
index aa86b2f49e7..21884cf1e83 100644
--- a/libs/core/tests/unit_tests/test_messages.py
+++ b/libs/core/tests/unit_tests/test_messages.py
@@ -120,6 +120,22 @@ def test_message_chunks() -> None:
     assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
     assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk

+    # Test token usage
+    left = AIMessageChunk(
+        content="",
+        usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
+    )
+    right = AIMessageChunk(
+        content="",
+        usage_metadata={"input_tokens": 4, "output_tokens": 5, "total_tokens": 9},
+    )
+    assert left + right == AIMessageChunk(
+        content="",
+        usage_metadata={"input_tokens": 5, "output_tokens": 7, "total_tokens": 12},
+    )
+    assert AIMessageChunk(content="") + left == left
+    assert right + AIMessageChunk(content="") == right
+

 def test_chat_message_chunks() -> None:
     assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(
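The partner-package hunks that follow all apply one pattern: the new standard test is inherited from `ChatModelIntegrationTests`, and providers that do not populate `usage_metadata` yet override it with an `xfail` marker so the suite records the gap instead of failing. Condensed sketch (the class name here is a hypothetical placeholder, not from the diff):

```python
import pytest
from langchain_standard_tests.integration_tests.chat_models import (
    ChatModelIntegrationTests,
)


class TestSomeProviderStandard(ChatModelIntegrationTests):  # hypothetical provider
    @pytest.mark.xfail(reason="Not implemented.")
    def test_usage_metadata(self, chat_model_class, chat_model_params) -> None:
        super().test_usage_metadata(chat_model_class, chat_model_params)
```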
diff --git a/libs/partners/ai21/tests/integration_tests/test_standard.py b/libs/partners/ai21/tests/integration_tests/test_standard.py
index e281ff2f06d..2d74ca59ec3 100644
--- a/libs/partners/ai21/tests/integration_tests/test_standard.py
+++ b/libs/partners/ai21/tests/integration_tests/test_standard.py
@@ -41,6 +41,17 @@ class TestAI21J2(ChatModelIntegrationTests):
             chat_model_params,
         )

+    @pytest.mark.xfail(reason="Not implemented.")
+    def test_usage_metadata(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+    ) -> None:
+        super().test_usage_metadata(
+            chat_model_class,
+            chat_model_params,
+        )
+
     @pytest.fixture
     def chat_model_params(self) -> dict:
         return {
@@ -79,6 +90,17 @@ class TestAI21Jamba(ChatModelIntegrationTests):
             chat_model_params,
         )

+    @pytest.mark.xfail(reason="Not implemented.")
+    def test_usage_metadata(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+    ) -> None:
+        super().test_usage_metadata(
+            chat_model_class,
+            chat_model_params,
+        )
+
     @pytest.fixture
     def chat_model_params(self) -> dict:
         return {
diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 73b60dfa2ab..b9d410d8288 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -503,6 +503,12 @@ class ChatAnthropic(BaseChatModel):
             )
         else:
             msg = AIMessage(content=content)
+        # Collect token usage
+        msg.usage_metadata = {
+            "input_tokens": data.usage.input_tokens,
+            "output_tokens": data.usage.output_tokens,
+            "total_tokens": data.usage.input_tokens + data.usage.output_tokens,
+        }
         return ChatResult(
             generations=[ChatGeneration(message=msg)],
             llm_output=llm_output,
diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
index 5da3a8a5e66..d9e382b6cf2 100644
--- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
@@ -89,7 +89,16 @@ def test__format_output() -> None:
     )
     expected = ChatResult(
         generations=[
-            ChatGeneration(message=AIMessage("bar")),
+            ChatGeneration(
+                message=AIMessage(
+                    "bar",
+                    usage_metadata={
+                        "input_tokens": 2,
+                        "output_tokens": 1,
+                        "total_tokens": 3,
+                    },
+                )
+            ),
         ],
         llm_output={
             "id": "foo",
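Note the asymmetry in the Anthropic hunk: the Messages API reports input and output tokens separately but no total, so the integration derives `total_tokens` itself, which is exactly what the updated `test__format_output` expectation (2 + 1 = 3) encodes. The mapping in isolation (illustrative values, not from the diff):

```python
# Anthropic-style usage object -> standardized UsageMetadata keys.
anthropic_usage = {"input_tokens": 2, "output_tokens": 1}  # no total from the API
usage_metadata = {
    "input_tokens": anthropic_usage["input_tokens"],
    "output_tokens": anthropic_usage["output_tokens"],
    "total_tokens": anthropic_usage["input_tokens"] + anthropic_usage["output_tokens"],
}
assert usage_metadata["total_tokens"] == 3
```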
diff --git a/libs/partners/fireworks/tests/integration_tests/test_standard.py b/libs/partners/fireworks/tests/integration_tests/test_standard.py
index bfeeca693d5..26ba020419c 100644
--- a/libs/partners/fireworks/tests/integration_tests/test_standard.py
+++ b/libs/partners/fireworks/tests/integration_tests/test_standard.py
@@ -21,6 +21,17 @@ class TestFireworksStandard(ChatModelIntegrationTests):
             "temperature": 0,
         }

+    @pytest.mark.xfail(reason="Not implemented.")
+    def test_usage_metadata(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+    ) -> None:
+        super().test_usage_metadata(
+            chat_model_class,
+            chat_model_params,
+        )
+
     @pytest.mark.xfail(reason="Not yet implemented.")
     def test_tool_message_histories_list_content(
         self,
diff --git a/libs/partners/groq/tests/integration_tests/test_standard.py b/libs/partners/groq/tests/integration_tests/test_standard.py
index 4048f7e8f6a..8224adc3ec8 100644
--- a/libs/partners/groq/tests/integration_tests/test_standard.py
+++ b/libs/partners/groq/tests/integration_tests/test_standard.py
@@ -14,6 +14,17 @@ class TestMistralStandard(ChatModelIntegrationTests):
     def chat_model_class(self) -> Type[BaseChatModel]:
         return ChatGroq

+    @pytest.mark.xfail(reason="Not implemented.")
+    def test_usage_metadata(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+    ) -> None:
+        super().test_usage_metadata(
+            chat_model_class,
+            chat_model_params,
+        )
+
     @pytest.mark.xfail(reason="Not yet implemented.")
     def test_tool_message_histories_list_content(
         self,
diff --git a/libs/partners/mistralai/tests/integration_tests/test_standard.py b/libs/partners/mistralai/tests/integration_tests/test_standard.py
index d9b8ff19692..7ea8f1bee8f 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_standard.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_standard.py
@@ -20,3 +20,14 @@ class TestMistralStandard(ChatModelIntegrationTests):
             "model": "mistral-large-latest",
             "temperature": 0,
         }
+
+    @pytest.mark.xfail(reason="Not implemented.")
+    def test_usage_metadata(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+    ) -> None:
+        super().test_usage_metadata(
+            chat_model_class,
+            chat_model_params,
+        )
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 1ede891316f..213e036c6dc 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -548,8 +548,15 @@ class BaseChatOpenAI(BaseChatModel):
         if response.get("error"):
             raise ValueError(response.get("error"))

+        token_usage = response.get("usage", {})
         for res in response["choices"]:
             message = _convert_dict_to_message(res["message"])
+            if token_usage and isinstance(message, AIMessage):
+                message.usage_metadata = {
+                    "input_tokens": token_usage.get("prompt_tokens", 0),
+                    "output_tokens": token_usage.get("completion_tokens", 0),
+                    "total_tokens": token_usage.get("total_tokens", 0),
+                }
             generation_info = dict(finish_reason=res.get("finish_reason"))
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
@@ -558,7 +565,6 @@ class BaseChatOpenAI(BaseChatModel):
                 generation_info=generation_info,
             )
             generations.append(gen)
-        token_usage = response.get("usage", {})
         llm_output = {
             "token_usage": token_usage,
             "model_name": self.model_name,
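In the OpenAI hunk the usage lookup simply moves above the choices loop so each `AIMessage` can carry it, with the provider-specific keys renamed into the standard ones. The mapping in isolation (illustrative payload):

```python
# OpenAI-style usage payload -> standardized UsageMetadata keys.
token_usage = {"prompt_tokens": 11, "completion_tokens": 7, "total_tokens": 18}
usage_metadata = {
    "input_tokens": token_usage.get("prompt_tokens", 0),
    "output_tokens": token_usage.get("completion_tokens", 0),
    "total_tokens": token_usage.get("total_tokens", 0),
}
assert usage_metadata == {"input_tokens": 11, "output_tokens": 7, "total_tokens": 18}
```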
diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py
index 5f11c6f1f94..5f669efda16 100644
--- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py
+++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py
@@ -132,6 +132,18 @@ class ChatModelIntegrationTests(ABC):
         assert isinstance(result.content, str)
         assert len(result.content) > 0

+    def test_usage_metadata(
+        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
+    ) -> None:
+        model = chat_model_class(**chat_model_params)
+        result = model.invoke("Hello")
+        assert result is not None
+        assert isinstance(result, AIMessage)
+        assert result.usage_metadata is not None
+        assert isinstance(result.usage_metadata["input_tokens"], int)
+        assert isinstance(result.usage_metadata["output_tokens"], int)
+        assert isinstance(result.usage_metadata["total_tokens"], int)
+
     def test_tool_message_histories_string_content(
         self,
         chat_model_class: Type[BaseChatModel],
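End to end, integrations that populate the field expose it directly on the returned message, which is what the standard test above asserts. A sketch against the patched OpenAI integration (requires `OPENAI_API_KEY`; the printed counts are illustrative):

```python
from langchain_openai import ChatOpenAI

msg = ChatOpenAI(model="gpt-3.5-turbo").invoke("Hello")
print(msg.usage_metadata)
# e.g. {'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}  (counts vary)
```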