Invalid tool_choice being passed to ChatLiteLLM (#28198)

- **Description:** Invalid `tool_choice` is given to `ChatLiteLLM` to
`bind_tools` due to it's parent's class default value being pass through
`with_structured_output`.
- **Issue:** #28176
This commit is contained in:
Mohammad Mohtashim 2024-12-08 00:33:40 +05:00 committed by GitHub
parent dd0085a9ff
commit 524ee6d9ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,6 +11,7 @@ from typing import (
Dict, Dict,
Iterator, Iterator,
List, List,
Literal,
Mapping, Mapping,
Optional, Optional,
Sequence, Sequence,
@ -212,6 +213,33 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
return message_dict return message_dict
_OPENAI_MODELS = [
"o1-mini",
"o1-preview",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-2024-05-13",
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
]
class ChatLiteLLM(BaseChatModel): class ChatLiteLLM(BaseChatModel):
"""Chat model that uses the LiteLLM API.""" """Chat model that uses the LiteLLM API."""
@ -465,6 +493,9 @@ class ChatLiteLLM(BaseChatModel):
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model. """Bind tool-like objects to this chat model.
@ -476,17 +507,47 @@ class ChatLiteLLM(BaseChatModel):
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to models, callables, and BaseTools will be automatically converted to
their schema dictionary representation. their schema dictionary representation.
tool_choice: Which tool to require the model to call. tool_choice: Which tool to require the model to call. Options are:
Must be the name of the single provided function or - str of the form ``"<<tool_name>>"``: calls <<tool_name>> tool.
"auto" to automatically determine which function to call - ``"auto"``:
(if any), or a dict of the form: automatically selects a tool (including no tool).
{"type": "function", "function": {"name": <<tool_name>>}}. - ``"none"``:
does not call a tool.
- ``"any"`` or ``"required"`` or ``True``:
forces least one tool to be called.
- dict of the form:
``{"type": "function", "function": {"name": <<tool_name>>}}``
- ``False`` or ``None``: no effect
**kwargs: Any additional parameters to pass to the **kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor. :class:`~langchain.runnable.Runnable` constructor.
""" """
formatted_tools = [convert_to_openai_tool(tool) for tool in tools] formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
# In case of openai if tool_choice is `any` or if bool has been provided we
# change it to `required` as that is suppored by openai.
if (
(self.model is not None and "azure" in self.model)
or (self.model_name is not None and "azure" in self.model_name)
or (self.model is not None and self.model in _OPENAI_MODELS)
or (self.model_name is not None and self.model_name in _OPENAI_MODELS)
) and (tool_choice == "any" or isinstance(tool_choice, bool)):
tool_choice = "required"
# If tool_choice is bool apart from openai we make it `any`
elif isinstance(tool_choice, bool):
tool_choice = "any"
elif isinstance(tool_choice, dict):
tool_names = [
formatted_tool["function"]["name"] for formatted_tool in formatted_tools
]
if not any(
tool_name == tool_choice["function"]["name"] for tool_name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tools were {tool_names}."
)
return super().bind(tools=formatted_tools, tool_choice=tool_choice, **kwargs)
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> Dict[str, Any]: