From 884125e129c88b04f4cfa4d85e9ba1fc2fba949e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alejandro=20Rodr=C3=ADguez?=
Date: Wed, 2 Apr 2025 19:45:15 -0400
Subject: [PATCH] community: support usage_metadata for litellm (#30625)

Support "usage_metadata" for LiteLLM.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, eyurtsev, ccurme, vbarda, hwchase17.
---
 .../langchain_community/chat_models/litellm.py     | 18 +++++++++++++++++-
 .../chat_models/test_litellm_standard.py           |  4 ++--
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/litellm.py b/libs/community/langchain_community/chat_models/litellm.py
index 3b90102bd90..36335afbd9b 100644
--- a/libs/community/langchain_community/chat_models/litellm.py
+++ b/libs/community/langchain_community/chat_models/litellm.py
@@ -48,6 +48,7 @@ from langchain_core.messages import (
     ToolCallChunk,
     ToolMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -410,14 +411,19 @@ class ChatLiteLLM(BaseChatModel):
 
     def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
         generations = []
+        token_usage = response.get("usage", {})
         for res in response["choices"]:
             message = _convert_dict_to_message(res["message"])
+            if isinstance(message, AIMessage):
+                message.response_metadata = {
+                    "model_name": self.model_name or self.model
+                }
+                message.usage_metadata = _create_usage_metadata(token_usage)
             gen = ChatGeneration(
                 message=message,
                 generation_info=dict(finish_reason=res.get("finish_reason")),
             )
             generations.append(gen)
-        token_usage = response.get("usage", {})
         set_model_value = self.model
         if self.model_name is not None:
             set_model_value = self.model_name
@@ -585,3 +591,13 @@ class ChatLiteLLM(BaseChatModel):
     @property
     def _llm_type(self) -> str:
         return "litellm-chat"
+
+
+def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata:
+    input_tokens = token_usage.get("prompt_tokens", 0)
+    output_tokens = token_usage.get("completion_tokens", 0)
+    return UsageMetadata(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=input_tokens + output_tokens,
+    )
diff --git a/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py b/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py
index 5e87e3ac8a0..e7b753acf3f 100644
--- a/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py
+++ b/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py
@@ -19,5 +19,5 @@ class TestLiteLLMStandard(ChatModelIntegrationTests):
         return {"model": "ollama/mistral"}
 
     @pytest.mark.xfail(reason="Not yet implemented.")
-    def test_usage_metadata(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata(model)
+    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
+        super().test_usage_metadata_streaming(model)
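
A minimal usage sketch of what this patch enables. It assumes LiteLLM is
installed and a reachable model; "ollama/mistral" mirrors the
integration-test config above and is illustrative only, as are the token
counts shown in the comments:

    from langchain_community.chat_models import ChatLiteLLM

    llm = ChatLiteLLM(model="ollama/mistral")
    msg = llm.invoke("Hello!")

    # With this patch, _create_usage_metadata populates usage_metadata on the
    # AIMessage from the provider's "usage" block (prompt/completion tokens).
    print(msg.usage_metadata)
    # e.g. {'input_tokens': 9, 'output_tokens': 12, 'total_tokens': 21}

    # response_metadata now records the model name as well.
    print(msg.response_metadata)
    # {'model_name': 'ollama/mistral'}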