multiple: enforce standards on tool_choice (#30372)

- Test if models support forcing tool calls via `tool_choice`. If they
do, they should support
  - `"any"` to specify any tool
  - the tool name as a string to force calling a particular tool
- Add `tool_choice` to signature of `BaseChatModel.bind_tools` in core
- Deprecate `tool_choice_value` in standard tests in favor of a boolean
`has_tool_choice`

Will follow up with PRs in external repos (tested in AWS and Google
already).
This commit is contained in:
ccurme 2025-03-20 13:48:59 -04:00 committed by GitHub
parent b86cd8270c
commit de3960d285
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 190 additions and 56 deletions

View File

@ -194,9 +194,14 @@ class ChatCloudflareWorkersAI(BaseChatModel):
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema."""
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if method == "json_schema":
# "json_schema" is treated as unsupported here: fall back to
# "function_calling" instead of rejecting the request.
method = "function_calling"
if method == "function_calling":
if schema is None:
raise ValueError(

View File

@ -383,6 +383,8 @@ class ChatPerplexity(BaseChatModel):
- "parsing_error": Optional[BaseException]
""" # noqa: E501
if method in ("function_calling", "json_mode"):
method = "json_schema"
if method == "json_schema":
if schema is None:
raise ValueError(

View File

@ -6,6 +6,7 @@ from typing import (
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
@ -382,7 +383,7 @@ class ChatReka(BaseChatModel):
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: str = "auto",
tool_choice: Optional[Union[str, Literal["any"]]] = "auto",
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
@ -421,6 +422,10 @@ class ChatReka(BaseChatModel):
]
# Ensure tool_choice is one of the allowed options
if tool_choice is None:
tool_choice = "auto"
if tool_choice == "any":
tool_choice = "tool"
if tool_choice not in ("auto", "none", "tool"):
raise ValueError(
f"Invalid tool_choice '{tool_choice}' provided. "

View File

@ -1167,6 +1167,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
tools: Sequence[
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
],
*,
tool_choice: Optional[Union[str, Literal["any"]]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError
@ -1280,6 +1282,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Added support for TypedDict class.
""" # noqa: E501
_ = kwargs.pop("method", None)
_ = kwargs.pop("strict", None)
if kwargs:
msg = f"Received unsupported arguments {kwargs}"
raise ValueError(msg)

View File

@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""
from typing import Optional, Type
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
@ -40,11 +40,6 @@ class TestGroqLlama(BaseTestGroq):
"rate_limiter": rate_limiter,
}
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "any"
@property
def supports_json_mode(self) -> bool:
return True

View File

@ -685,6 +685,7 @@ class ChatMistralAI(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
@ -705,6 +706,22 @@ class ChatMistralAI(BaseChatModel):
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice:
tool_names = []
for tool in formatted_tools:
if "function" in tool and (name := tool["function"].get("name")):
tool_names.append(name)
elif name := tool.get("name"):
tool_names.append(name)
else:
pass
if tool_choice in tool_names:
kwargs["tool_choice"] = {
"type": "function",
"function": {"name": tool_choice},
}
else:
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(

View File

@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""
from typing import Optional, Type
from typing import Type
from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ( # type: ignore[import-not-found]
@ -22,8 +22,3 @@ class TestMistralStandard(ChatModelIntegrationTests):
@property
def supports_json_mode(self) -> bool:
return True
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "any"

View File

@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""
from typing import Optional, Type
from typing import Type
import pytest # type: ignore[import-not-found]
from langchain_core.language_models import BaseChatModel
@ -30,11 +30,6 @@ class TestXAIStandard(ChatModelIntegrationTests):
"rate_limiter": rate_limiter,
}
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "tool_name"
@pytest.mark.xfail(reason="Not yet supported.")
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
super().test_usage_metadata_streaming(model)

View File

@ -1,10 +1,12 @@
import base64
import inspect
import json
from typing import Any, List, Literal, Optional, cast
from unittest.mock import MagicMock
import httpx
import pytest
from langchain_core._api import warn_deprecated
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models import BaseChatModel, GenericFakeChatModel
from langchain_core.messages import (
@ -204,12 +206,12 @@ class ChatModelIntegrationTests(ChatModelTests):
Value to use for tool choice when used in tests.
Some tests for tool calling features attempt to force tool calling via a
`tool_choice` parameter. A common value for this parameter is "any". Defaults
to `None`.
Note: if the value is set to "tool_name", the name of the tool used in each
test will be set as the value for `tool_choice`.
.. warning:: Deprecated since version 0.3.15:
This property will be removed in version 0.3.20. If a model supports
``tool_choice``, it should accept ``tool_choice="any"`` and
``tool_choice=<string name of tool>``. If a model does not
support forcing tool calling, override the ``has_tool_choice`` property to
return ``False``.
Example:
@ -219,6 +221,26 @@ class ChatModelIntegrationTests(ChatModelTests):
def tool_choice_value(self) -> Optional[str]:
return "any"
.. dropdown:: has_tool_choice
Boolean property indicating whether the chat model supports forcing tool
calling via a ``tool_choice`` parameter.
By default, this is determined by whether the parameter is included in the
signature for the corresponding ``bind_tools`` method.
If ``True``, the minimum requirement for this feature is that
``tool_choice="any"`` will force a tool call, and ``tool_choice=<tool name>``
will force a call to a specific tool.
Example override:
.. code-block:: python
@property
def has_tool_choice(self) -> bool:
return False
.. dropdown:: has_structured_output
Boolean property indicating whether the chat model supports structured
@ -989,16 +1011,33 @@ class ChatModelIntegrationTests(ChatModelTests):
def test_tool_calling(self, model: BaseChatModel) -> None:
super().test_tool_calling(model)
Otherwise, ensure that the ``tool_choice_value`` property is correctly
specified on the test class.
Otherwise, in the case that only one tool is bound, ensure that
``tool_choice`` supports the string ``"any"`` to force calling that tool.
"""
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "magic_function"
if not self.has_tool_choice:
tool_choice_value = None
else:
tool_choice = self.tool_choice_value
model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice)
tool_choice_value = "any"
# Emit a deprecation warning if the test class overrides `tool_choice_value`.
# `inspect.getattr_static` fetches the raw attribute without invoking the
# property, so the identity check compares the attribute found for this test
# class with the one defined on ChatModelIntegrationTests.
if inspect.getattr_static(
    self, "tool_choice_value"
) is not inspect.getattr_static(ChatModelIntegrationTests, "tool_choice_value"):
    warn_deprecated(
        "0.3.15",
        message=(
            "`tool_choice_value` will be removed in version 0.3.20. If a "
            "model supports `tool_choice`, it should accept `tool_choice='any'` "
            "and `tool_choice=<string name of tool>`. If the model does not "
            "support `tool_choice`, override the `has_tool_choice` "
            "property to return `False`."
        ),
        removal="0.3.20",
    )
model_with_tools = model.bind_tools(
[magic_function], tool_choice=tool_choice_value
)
# Test invoke
query = "What is the value of magic_function(3)? Use the tool."
@ -1012,6 +1051,57 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
def test_tool_choice(self, model: BaseChatModel) -> None:
"""Test that the model can force tool calling via the ``tool_choice``
parameter. This test is skipped if the ``has_tool_choice`` property on the
test class is set to False, or if ``has_tool_calling`` is False.
This test is optional and should be skipped if the model does not support
forcing tool calls via ``tool_choice`` (see Configuration below).
.. dropdown:: Configuration
To disable tool choice tests, set ``has_tool_choice`` to False in your
test class:
.. code-block:: python
class TestMyChatModelIntegration(ChatModelIntegrationTests):
@property
def has_tool_choice(self) -> bool:
return False
.. dropdown:: Troubleshooting
If this test fails, check whether the ``test_tool_calling`` test is passing.
If it is not, refer to the troubleshooting steps in that test first.
If ``test_tool_calling`` is passing, check that the underlying model
supports forced tool calling. If it does, ``bind_tools`` should accept a
``tool_choice`` parameter that can be used to force a tool call.
It should accept (1) the string ``"any"`` to force calling the bound tool,
and (2) the string name of the tool to force calling that tool.
"""
# Forcing a tool call requires both tool calling and ``tool_choice`` support.
if not self.has_tool_choice or not self.has_tool_calling:
pytest.skip("Test requires tool choice.")
# Second tool, bound alongside ``magic_function`` below.
@tool
def get_weather(location: str) -> str:
"""Get weather at a location."""
return "It's sunny."
# "any" must yield some tool call; a tool name must yield a call to that tool.
for tool_choice in ["any", "magic_function"]:
model_with_tools = model.bind_tools(
[magic_function, get_weather], tool_choice=tool_choice
)
# The prompt does not ask for a tool, so a tool call here comes from tool_choice.
result = model_with_tools.invoke("Hello!")
assert isinstance(result, AIMessage)
assert result.tool_calls
if tool_choice == "magic_function":
assert result.tool_calls[0]["name"] == "magic_function"
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
"""Test that the model generates tool calls. This test is skipped if the
``has_tool_calling`` property on the test class is set to False.
@ -1048,16 +1138,18 @@ class ChatModelIntegrationTests(ChatModelTests):
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
await super().test_tool_calling_async(model)
Otherwise, ensure that the ``tool_choice_value`` property is correctly
specified on the test class.
Otherwise, in the case that only one tool is bound, ensure that
``tool_choice`` supports the string ``"any"`` to force calling that tool.
"""
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "magic_function"
if not self.has_tool_choice:
tool_choice_value = None
else:
tool_choice = self.tool_choice_value
model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice)
tool_choice_value = "any"
model_with_tools = model.bind_tools(
[magic_function], tool_choice=tool_choice_value
)
# Test ainvoke
query = "What is the value of magic_function(3)? Use the tool."
@ -1109,18 +1201,17 @@ class ChatModelIntegrationTests(ChatModelTests):
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)
Otherwise, ensure that the ``tool_choice_value`` property is correctly
specified on the test class.
Otherwise, in the case that only one tool is bound, ensure that
``tool_choice`` supports the string ``"any"`` to force calling that tool.
""" # noqa: E501
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "magic_function_no_args"
if not self.has_tool_choice:
tool_choice_value = None
else:
tool_choice = self.tool_choice_value
tool_choice_value = "any"
model_with_tools = model.bind_tools(
[magic_function_no_args], tool_choice=tool_choice
[magic_function_no_args], tool_choice=tool_choice_value
)
query = "What is the value of magic_function_no_args()? Use the tool."
result = model_with_tools.invoke(query)
@ -1184,10 +1275,10 @@ class ChatModelIntegrationTests(ChatModelTests):
name="greeting_generator",
description="Generate a greeting in a particular style of speaking.",
)
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "greeting_generator"
if not self.has_tool_choice:
tool_choice: Optional[str] = "any"
else:
tool_choice = self.tool_choice_value
tool_choice = None
model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice)
query = "Using the tool, generate a Pirate greeting."
result = model_with_tools.invoke(query)
@ -2086,9 +2177,6 @@ class ChatModelIntegrationTests(ChatModelTests):
If this test fails, check that the ``status`` field on ``ToolMessage``
objects is either ignored or passed to the model appropriately.
Otherwise, ensure that the ``tool_choice_value`` property is correctly
specified on the test class.
"""
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

View File

@ -2,6 +2,7 @@
:autodoc-options: autoproperty
"""
import inspect
import os
from abc import abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
@ -127,6 +128,14 @@ class ChatModelTests(BaseStandardTests):
"""(None or str) to use for tool choice when used in tests."""
return None
@property
def has_tool_choice(self) -> bool:
    """(bool) whether the model supports forcing tool calls via ``tool_choice``.

    The default is inferred from the signature of
    ``self.chat_model_class.bind_tools``: it is ``True`` only when
    ``tool_choice`` is an explicitly named parameter. A ``bind_tools`` that
    would only receive it through ``**kwargs`` is reported as ``False``.
    """
    bind_tools_params = inspect.signature(
        self.chat_model_class.bind_tools
    ).parameters
    return "tool_choice" in bind_tools_params
@property
def has_structured_output(self) -> bool:
"""(bool) whether the chat model supports structured output."""
@ -273,12 +282,11 @@ class ChatModelUnitTests(ChatModelTests):
Value to use for tool choice when used in tests.
Some tests for tool calling features attempt to force tool calling via a
`tool_choice` parameter. A common value for this parameter is "any". Defaults
to `None`.
Note: if the value is set to "tool_name", the name of the tool used in each
test will be set as the value for `tool_choice`.
.. warning:: Deprecated since version 0.3.15:
This property will be removed in version 0.3.20. If a model does not
support forcing tool calling, override the ``has_tool_choice`` property to
return ``False``. Otherwise, models should accept values of ``"any"`` or
the name of a tool in ``tool_choice``.
Example:
@ -288,6 +296,26 @@ class ChatModelUnitTests(ChatModelTests):
def tool_choice_value(self) -> Optional[str]:
return "any"
.. dropdown:: has_tool_choice
Boolean property indicating whether the chat model supports forcing tool
calling via a ``tool_choice`` parameter.
By default, this is determined by whether the parameter is included in the
signature for the corresponding ``bind_tools`` method.
If ``True``, the minimum requirement for this feature is that
``tool_choice="any"`` will force a tool call, and ``tool_choice=<tool name>``
will force a call to a specific tool.
Example override:
.. code-block:: python
@property
def has_tool_choice(self) -> bool:
return False
.. dropdown:: has_structured_output
Boolean property indicating whether the chat model supports structured