diff --git a/libs/community/langchain_community/chat_models/llamacpp.py b/libs/community/langchain_community/chat_models/llamacpp.py index ea9b5975d11..f403566e8ea 100644 --- a/libs/community/langchain_community/chat_models/llamacpp.py +++ b/libs/community/langchain_community/chat_models/llamacpp.py @@ -7,6 +7,7 @@ from typing import ( Dict, Iterator, List, + Literal, Mapping, Optional, Sequence, @@ -342,15 +343,10 @@ class ChatLlamaCpp(BaseChatModel): self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], *, - tool_choice: Optional[Union[dict, bool, str]] = None, + tool_choice: Optional[Union[dict, bool, str, Literal["auto", "any"]]] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: - """Bind tool-like objects to this chat model - - tool_choice: does not currently support "any", "auto" choices like OpenAI - tool-calling API. should be a dict of the form to force this tool - {"type": "function", "function": {"name": <>}}. - """ + """Bind tool-like objects to this chat model """ formatted_tools = [convert_to_openai_tool(tool) for tool in tools] tool_names = [ft["function"]["name"] for ft in formatted_tools] if tool_choice: @@ -363,14 +359,24 @@ class ChatLlamaCpp(BaseChatModel): f"provided tools were {tool_names}." ) elif isinstance(tool_choice, str): - chosen = [ - f for f in formatted_tools if f["function"]["name"] == tool_choice - ] - if not chosen: - raise ValueError( - f"Tool choice {tool_choice=} was specified, but the only " - f"provided tools were {tool_names}." - ) + if tool_choice == 'any': + if len(formatted_tools) == 1: + tool_choice = formatted_tools[0] + else: + raise ValueError( + "tool_choice `'any'` only supported if one tool is provided." + ) + elif tool_choice == 'auto': + tool_choice = None + else: + chosen = [ + f for f in formatted_tools if f["function"]["name"] == tool_choice + ] + if not chosen: + raise ValueError( + f"Tool choice {tool_choice=} was specified, but the only " + f"provided tools were {tool_names}." + ) elif isinstance(tool_choice, bool): if len(formatted_tools) > 1: raise ValueError( @@ -386,7 +392,6 @@ class ChatLlamaCpp(BaseChatModel): ) kwargs["tool_choice"] = tool_choice - formatted_tools = [convert_to_openai_tool(tool) for tool in tools] return super().bind(tools=formatted_tools, **kwargs) def with_structured_output(