diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt
index 8e22d52cf61..79ce46657ac 100644
--- a/libs/community/extended_testing_deps.txt
+++ b/libs/community/extended_testing_deps.txt
@@ -40,6 +40,7 @@ jinja2>=3,<4
 jq>=1.4.1,<2
 jsonschema>1
 keybert>=0.8.5
+litellm>=1.30,<=1.39.5
 lxml>=4.9.3,<6.0
 markdownify>=0.11.6,<0.12
 motor>=3.3.1,<4
diff --git a/libs/community/langchain_community/chat_models/litellm.py b/libs/community/langchain_community/chat_models/litellm.py
index 3ce4836c8a2..f0cf69a7191 100644
--- a/libs/community/langchain_community/chat_models/litellm.py
+++ b/libs/community/langchain_community/chat_models/litellm.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import json
 import logging
 from typing import (
     Any,
@@ -42,6 +43,9 @@ from langchain_core.messages import (
     HumanMessageChunk,
     SystemMessage,
     SystemMessageChunk,
+    ToolCall,
+    ToolCallChunk,
+    ToolMessage,
 )
 from langchain_core.outputs import (
     ChatGeneration,
@@ -132,10 +136,30 @@ def _convert_delta_to_message_chunk(
     else:
         additional_kwargs = {}
 
+    tool_call_chunks = []
+    if raw_tool_calls := _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = raw_tool_calls
+        try:
+            tool_call_chunks = [
+                ToolCallChunk(
+                    name=rtc["function"].get("name"),
+                    args=rtc["function"].get("arguments"),
+                    id=rtc.get("id"),
+                    index=rtc["index"],
+                )
+                for rtc in raw_tool_calls
+            ]
+        except KeyError:
+            pass
+
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
-        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+        return AIMessageChunk(
+            content=content,
+            additional_kwargs=additional_kwargs,
+            tool_call_chunks=tool_call_chunks,
+        )
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
     elif role == "function" or default_class == FunctionMessageChunk:
@@ -146,23 +170,41 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content)  # type: ignore[call-arg]
 
 
+def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
+    return {
+        "type": "function",
+        "id": tool_call["id"],
+        "function": {
+            "name": tool_call["name"],
+            "arguments": json.dumps(tool_call["args"]),
+        },
+    }
+
+
 def _convert_message_to_dict(message: BaseMessage) -> dict:
+    message_dict: Dict[str, Any] = {"content": message.content}
     if isinstance(message, ChatMessage):
-        message_dict = {"role": message.role, "content": message.content}
+        message_dict["role"] = message.role
     elif isinstance(message, HumanMessage):
-        message_dict = {"role": "user", "content": message.content}
+        message_dict["role"] = "user"
     elif isinstance(message, AIMessage):
-        message_dict = {"role": "assistant", "content": message.content}
+        message_dict["role"] = "assistant"
         if "function_call" in message.additional_kwargs:
             message_dict["function_call"] = message.additional_kwargs["function_call"]
+        if message.tool_calls:
+            message_dict["tool_calls"] = [
+                _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
+            ]
+        elif "tool_calls" in message.additional_kwargs:
+            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
     elif isinstance(message, SystemMessage):
-        message_dict = {"role": "system", "content": message.content}
+        message_dict["role"] = "system"
     elif isinstance(message, FunctionMessage):
-        message_dict = {
-            "role": "function",
-            "content": message.content,
-            "name": message.name,
-        }
+        message_dict["role"] = "function"
+        message_dict["name"] = message.name
+    elif isinstance(message, ToolMessage):
+        message_dict["role"] = "tool"
+        message_dict["tool_call_id"] = message.tool_call_id
     else:
         raise ValueError(f"Got unknown type {message}")
     if "name" in message.additional_kwargs:
@@ -360,6 +402,8 @@ class ChatLiteLLM(BaseChatModel):
         for chunk in self.completion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
         ):
+            if not isinstance(chunk, dict):
+                chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
@@ -384,6 +428,8 @@ class ChatLiteLLM(BaseChatModel):
         async for chunk in await acompletion_with_retry(
             self, messages=message_dicts, run_manager=run_manager, **params
         ):
+            if not isinstance(chunk, dict):
+                chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
diff --git a/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py b/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py
new file mode 100644
index 00000000000..972617617bd
--- /dev/null
+++ b/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py
@@ -0,0 +1,23 @@
+"""Standard LangChain interface tests"""
+
+from typing import Type
+
+import pytest
+from langchain_core.language_models import BaseChatModel
+from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
+
+from langchain_community.chat_models.litellm import ChatLiteLLM
+
+
+class TestLiteLLMStandard(ChatModelIntegrationTests):
+    @property
+    def chat_model_class(self) -> Type[BaseChatModel]:
+        return ChatLiteLLM
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"model": "ollama/mistral"}
+
+    @pytest.mark.xfail(reason="Not yet implemented.")
+    def test_usage_metadata(self, model: BaseChatModel) -> None:
+        super().test_usage_metadata(model)
diff --git a/libs/community/tests/unit_tests/chat_models/test_litellm.py b/libs/community/tests/unit_tests/chat_models/test_litellm.py
new file mode 100644
index 00000000000..583c53d64bd
--- /dev/null
+++ b/libs/community/tests/unit_tests/chat_models/test_litellm.py
@@ -0,0 +1,24 @@
+"""Standard LangChain interface tests"""
+
+from typing import Type
+
+import pytest
+from langchain_core.language_models import BaseChatModel
+from langchain_standard_tests.unit_tests import ChatModelUnitTests
+
+from langchain_community.chat_models.litellm import ChatLiteLLM
+
+
+@pytest.mark.requires("litellm")
+class TestLiteLLMStandard(ChatModelUnitTests):
+    @property
+    def chat_model_class(self) -> Type[BaseChatModel]:
+        return ChatLiteLLM
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"api_key": "test_api_key"}
+
+    @pytest.mark.xfail(reason="Not yet implemented.")
+    def test_standard_params(self, model: BaseChatModel) -> None:
+        super().test_standard_params(model)