Have ChatLlamaCpp handle "auto" and "any" for tool_choice

This commit is contained in:
Clint Adams
2025-04-13 17:42:42 -04:00
parent f005988e31
commit 58fd6266ca

View File

@@ -7,6 +7,7 @@ from typing import (
     Dict,
     Iterator,
     List,
+    Literal,
     Mapping,
     Optional,
     Sequence,
@@ -342,15 +343,10 @@ class ChatLlamaCpp(BaseChatModel):
         self,
         tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
         *,
-        tool_choice: Optional[Union[dict, bool, str]] = None,
+        tool_choice: Optional[Union[dict, bool, str, Literal["auto", "any"]]] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
-        """Bind tool-like objects to this chat model
-
-        tool_choice: does not currently support "any", "auto" choices like OpenAI
-        tool-calling API. should be a dict of the form to force this tool
-        {"type": "function", "function": {"name": <<tool_name>>}}.
-        """
+        """Bind tool-like objects to this chat model """
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         tool_names = [ft["function"]["name"] for ft in formatted_tools]
         if tool_choice:
@@ -363,14 +359,24 @@ class ChatLlamaCpp(BaseChatModel):
                         f"provided tools were {tool_names}."
                     )
             elif isinstance(tool_choice, str):
-                chosen = [
-                    f for f in formatted_tools if f["function"]["name"] == tool_choice
-                ]
-                if not chosen:
-                    raise ValueError(
-                        f"Tool choice {tool_choice=} was specified, but the only "
-                        f"provided tools were {tool_names}."
-                    )
+                if tool_choice == 'any':
+                    if len(formatted_tools) == 1:
+                        tool_choice = formatted_tools[0]
+                    else:
+                        raise ValueError(
+                            "tool_choice `'any'` only supported if one tool is provided."
+                        )
+                elif tool_choice == 'auto':
+                    tool_choice = None
+                else:
+                    chosen = [
+                        f for f in formatted_tools if f["function"]["name"] == tool_choice
+                    ]
+                    if not chosen:
+                        raise ValueError(
+                            f"Tool choice {tool_choice=} was specified, but the only "
+                            f"provided tools were {tool_names}."
+                        )
             elif isinstance(tool_choice, bool):
                 if len(formatted_tools) > 1:
                     raise ValueError(
@@ -386,7 +392,6 @@ class ChatLlamaCpp(BaseChatModel):
                 )
             kwargs["tool_choice"] = tool_choice
-        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         return super().bind(tools=formatted_tools, **kwargs)

     def with_structured_output(