diff --git a/libs/community/langchain_community/chat_models/litellm.py b/libs/community/langchain_community/chat_models/litellm.py
index 36335afbd9b..c9f588b3d49 100644
--- a/libs/community/langchain_community/chat_models/litellm.py
+++ b/libs/community/langchain_community/chat_models/litellm.py
@@ -452,6 +452,7 @@ class ChatLiteLLM(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
+        added_model_name = False
         for chunk in self.completion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
         ):
@@ -460,7 +461,15 @@ class ChatLiteLLM(BaseChatModel):
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
+            usage = chunk.get("usage", {})
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            if isinstance(chunk, AIMessageChunk):
+                if not added_model_name:
+                    chunk.response_metadata = {
+                        "model_name": self.model_name or self.model
+                    }
+                    added_model_name = True
+                chunk.usage_metadata = _create_usage_metadata(usage)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
             if run_manager:
@@ -478,6 +487,7 @@ class ChatLiteLLM(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
+        added_model_name = False
         async for chunk in await acompletion_with_retry(
             self, messages=message_dicts, run_manager=run_manager, **params
         ):
@@ -486,7 +496,15 @@ class ChatLiteLLM(BaseChatModel):
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
+            usage = chunk.get("usage", {})
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            if isinstance(chunk, AIMessageChunk):
+                if not added_model_name:
+                    chunk.response_metadata = {
+                        "model_name": self.model_name or self.model
+                    }
+                    added_model_name = True
+                chunk.usage_metadata = _create_usage_metadata(usage)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
             if run_manager:
diff --git a/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py b/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py
index e7b753acf3f..d034ece43e4 100644
--- a/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py
+++ b/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py
@@ -2,7 +2,6 @@
 
 from typing import Type
 
-import pytest
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 
@@ -16,8 +15,8 @@ class TestLiteLLMStandard(ChatModelIntegrationTests):
 
     @property
     def chat_model_params(self) -> dict:
-        return {"model": "ollama/mistral"}
-
-    @pytest.mark.xfail(reason="Not yet implemented.")
-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata_streaming(model)
+        return {
+            "model": "ollama/mistral",
+            # Needed to get the usage object when streaming. See https://docs.litellm.ai/docs/completion/usage#streaming-usage
+            "model_kwargs": {"stream_options": {"include_usage": True}},
+        }
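
For reference, here is a minimal sketch (not part of the diff) of how the new streaming metadata surfaces to a caller. The model name and `stream_options` mirror the test parameters above; a configured LiteLLM backend for `ollama/mistral` is assumed, and the prompt is illustrative only:

```python
from langchain_community.chat_models import ChatLiteLLM

# Mirrors chat_model_params from the updated test; any LiteLLM-supported
# model should work, provided the backend is reachable.
llm = ChatLiteLLM(
    model="ollama/mistral",
    # Per the LiteLLM docs linked in the test, the usage object is only
    # attached to streamed chunks when include_usage is set.
    model_kwargs={"stream_options": {"include_usage": True}},
)

aggregate = None
for chunk in llm.stream("Write a haiku about token counting."):
    # Adding AIMessageChunks merges response_metadata dicts and sums
    # usage_metadata token counts across the stream.
    aggregate = chunk if aggregate is None else aggregate + chunk

print(aggregate.response_metadata["model_name"])  # set on the first chunk
print(aggregate.usage_metadata)  # input/output/total token counts
```

The `added_model_name` flag in the diff attaches `model_name` to only the first chunk, presumably so that merging `response_metadata` across chunks does not repeatedly concatenate the same string.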