diff --git a/libs/community/langchain_community/chat_models/litellm.py b/libs/community/langchain_community/chat_models/litellm.py
index 18ad918914c..e57a5b07252 100644
--- a/libs/community/langchain_community/chat_models/litellm.py
+++ b/libs/community/langchain_community/chat_models/litellm.py
@@ -12,6 +12,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
@@ -21,6 +22,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     agenerate_from_stream,
@@ -46,8 +48,11 @@ from langchain_core.outputs import (
     ChatGenerationChunk,
     ChatResult,
 )
-from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
 from langchain_core.utils import get_from_dict_or_env
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 logger = logging.getLogger(__name__)
 
@@ -411,6 +416,32 @@ class ChatLiteLLM(BaseChatModel):
         )
         return self._create_chat_result(response)
 
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        LiteLLM expects the tools argument in OpenAI format.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            tool_choice: Which tool to require the model to call.
+                Must be the name of the single provided function or
+                "auto" to automatically determine which function to call
+                (if any), or a dict of the form:
+                {"type": "function", "function": {"name": <<tool_name>>}}.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        return super().bind(tools=formatted_tools, **kwargs)
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
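
A minimal usage sketch of the new `bind_tools` method. The `GetWeather` schema, model name, and provider credentials are illustrative assumptions, not part of this diff; any LiteLLM-supported model with the corresponding API key configured should work the same way.

```python
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.pydantic_v1 import BaseModel, Field


class GetWeather(BaseModel):
    """Get the current weather in a given location."""

    location: str = Field(..., description="City and state, e.g. San Francisco, CA")


# Hypothetical model choice; requires the matching provider key in the environment.
llm = ChatLiteLLM(model="gpt-3.5-turbo")

# bind_tools converts the pydantic model to an OpenAI-format tool definition via
# convert_to_openai_tool, then binds it onto the underlying Runnable.
llm_with_tools = llm.bind_tools([GetWeather])

msg = llm_with_tools.invoke("What is the weather like in San Francisco?")
print(msg.tool_calls)
```

Note that `tool_choice` is documented in the docstring but not declared in the signature, so it is forwarded through `**kwargs` to `bind` along with any other provider parameters.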