multiple: enforce standards on tool_choice (#30372)

- Test if models support forcing tool calls via `tool_choice`. If they
do, they should support
  - `"any"` to specify any tool
  - the tool name as a string to force calling a particular tool
- Add `tool_choice` to signature of `BaseChatModel.bind_tools` in core
- Deprecate `tool_choice_value` in standard tests in favor of a boolean
`has_tool_choice`

Will follow up with PRs in external repos (tested in AWS and Google
already).
This commit is contained in:
ccurme 2025-03-20 13:48:59 -04:00 committed by GitHub
parent b86cd8270c
commit de3960d285
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 190 additions and 56 deletions

View File

@ -194,9 +194,14 @@ class ChatCloudflareWorkersAI(BaseChatModel):
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.""" """Model wrapper that returns outputs formatted to match the given schema."""
_ = kwargs.pop("strict", None)
if kwargs: if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}") raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema) is_pydantic_schema = _is_pydantic_class(schema)
if method == "json_schema":
# Some applications require that incompatible parameters (e.g., unsupported
# methods) be handled.
method = "function_calling"
if method == "function_calling": if method == "function_calling":
if schema is None: if schema is None:
raise ValueError( raise ValueError(

View File

@ -383,6 +383,8 @@ class ChatPerplexity(BaseChatModel):
- "parsing_error": Optional[BaseException] - "parsing_error": Optional[BaseException]
""" # noqa: E501 """ # noqa: E501
if method in ("function_calling", "json_mode"):
method = "json_schema"
if method == "json_schema": if method == "json_schema":
if schema is None: if schema is None:
raise ValueError( raise ValueError(

View File

@ -6,6 +6,7 @@ from typing import (
Dict, Dict,
Iterator, Iterator,
List, List,
Literal,
Mapping, Mapping,
Optional, Optional,
Sequence, Sequence,
@ -382,7 +383,7 @@ class ChatReka(BaseChatModel):
self, self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*, *,
tool_choice: str = "auto", tool_choice: Optional[Union[str, Literal["any"]]] = "auto",
strict: Optional[bool] = None, strict: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
@ -421,6 +422,10 @@ class ChatReka(BaseChatModel):
] ]
# Ensure tool_choice is one of the allowed options # Ensure tool_choice is one of the allowed options
if tool_choice is None:
tool_choice = "auto"
if tool_choice == "any":
tool_choice = "tool"
if tool_choice not in ("auto", "none", "tool"): if tool_choice not in ("auto", "none", "tool"):
raise ValueError( raise ValueError(
f"Invalid tool_choice '{tool_choice}' provided. " f"Invalid tool_choice '{tool_choice}' provided. "

View File

@ -1167,6 +1167,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
tools: Sequence[ tools: Sequence[
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
], ],
*,
tool_choice: Optional[Union[str, Literal["any"]]] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError raise NotImplementedError
@ -1280,6 +1282,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
Added support for TypedDict class. Added support for TypedDict class.
""" # noqa: E501 """ # noqa: E501
_ = kwargs.pop("method", None)
_ = kwargs.pop("strict", None)
if kwargs: if kwargs:
msg = f"Received unsupported arguments {kwargs}" msg = f"Received unsupported arguments {kwargs}"
raise ValueError(msg) raise ValueError(msg)

View File

@ -1,6 +1,6 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Optional, Type from typing import Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
@ -40,11 +40,6 @@ class TestGroqLlama(BaseTestGroq):
"rate_limiter": rate_limiter, "rate_limiter": rate_limiter,
} }
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "any"
@property @property
def supports_json_mode(self) -> bool: def supports_json_mode(self) -> bool:
return True return True

View File

@ -685,6 +685,7 @@ class ChatMistralAI(BaseChatModel):
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model. """Bind tool-like objects to this chat model.
@ -705,6 +706,22 @@ class ChatMistralAI(BaseChatModel):
""" """
formatted_tools = [convert_to_openai_tool(tool) for tool in tools] formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice:
tool_names = []
for tool in formatted_tools:
if "function" in tool and (name := tool["function"].get("name")):
tool_names.append(name)
elif name := tool.get("name"):
tool_names.append(name)
else:
pass
if tool_choice in tool_names:
kwargs["tool_choice"] = {
"type": "function",
"function": {"name": tool_choice},
}
else:
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs) return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output( def with_structured_output(

View File

@ -1,6 +1,6 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Optional, Type from typing import Type
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ( # type: ignore[import-not-found] from langchain_tests.integration_tests import ( # type: ignore[import-not-found]
@ -22,8 +22,3 @@ class TestMistralStandard(ChatModelIntegrationTests):
@property @property
def supports_json_mode(self) -> bool: def supports_json_mode(self) -> bool:
return True return True
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "any"

View File

@ -1,6 +1,6 @@
"""Standard LangChain interface tests""" """Standard LangChain interface tests"""
from typing import Optional, Type from typing import Type
import pytest # type: ignore[import-not-found] import pytest # type: ignore[import-not-found]
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
@ -30,11 +30,6 @@ class TestXAIStandard(ChatModelIntegrationTests):
"rate_limiter": rate_limiter, "rate_limiter": rate_limiter,
} }
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "tool_name"
@pytest.mark.xfail(reason="Not yet supported.") @pytest.mark.xfail(reason="Not yet supported.")
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
super().test_usage_metadata_streaming(model) super().test_usage_metadata_streaming(model)

View File

@ -1,10 +1,12 @@
import base64 import base64
import inspect
import json import json
from typing import Any, List, Literal, Optional, cast from typing import Any, List, Literal, Optional, cast
from unittest.mock import MagicMock from unittest.mock import MagicMock
import httpx import httpx
import pytest import pytest
from langchain_core._api import warn_deprecated
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models import BaseChatModel, GenericFakeChatModel from langchain_core.language_models import BaseChatModel, GenericFakeChatModel
from langchain_core.messages import ( from langchain_core.messages import (
@ -204,12 +206,12 @@ class ChatModelIntegrationTests(ChatModelTests):
Value to use for tool choice when used in tests. Value to use for tool choice when used in tests.
Some tests for tool calling features attempt to force tool calling via a .. warning:: Deprecated since version 0.3.15:
`tool_choice` parameter. A common value for this parameter is "any". Defaults This property will be removed in version 0.3.20. If a model supports
to `None`. ``tool_choice``, it should accept ``tool_choice="any"`` and
``tool_choice=<string name of tool>``. If a model does not
Note: if the value is set to "tool_name", the name of the tool used in each support forcing tool calling, override the ``has_tool_choice`` property to
test will be set as the value for `tool_choice`. return ``False``.
Example: Example:
@ -219,6 +221,26 @@ class ChatModelIntegrationTests(ChatModelTests):
def tool_choice_value(self) -> Optional[str]: def tool_choice_value(self) -> Optional[str]:
return "any" return "any"
.. dropdown:: has_tool_choice
Boolean property indicating whether the chat model supports forcing tool
calling via a ``tool_choice`` parameter.
By default, this is determined by whether the parameter is included in the
signature for the corresponding ``bind_tools`` method.
If ``True``, the minimum requirement for this feature is that
``tool_choice="any"`` will force a tool call, and ``tool_choice=<tool name>``
will force a call to a specific tool.
Example override:
.. code-block:: python
@property
def has_tool_choice(self) -> bool:
return False
.. dropdown:: has_structured_output .. dropdown:: has_structured_output
Boolean property indicating whether the chat model supports structured Boolean property indicating whether the chat model supports structured
@ -989,16 +1011,33 @@ class ChatModelIntegrationTests(ChatModelTests):
def test_tool_calling(self, model: BaseChatModel) -> None: def test_tool_calling(self, model: BaseChatModel) -> None:
super().test_tool_calling(model) super().test_tool_calling(model)
Otherwise, ensure that the ``tool_choice_value`` property is correctly Otherwise, in the case that only one tool is bound, ensure that
specified on the test class. ``tool_choice`` supports the string ``"any"`` to force calling that tool.
""" """
if not self.has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
if self.tool_choice_value == "tool_name": if not self.has_tool_choice:
tool_choice: Optional[str] = "magic_function" tool_choice_value = None
else: else:
tool_choice = self.tool_choice_value tool_choice_value = "any"
model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice) # Emit warning if tool_choice_value property is overridden
if inspect.getattr_static(
self, "tool_choice_value"
) is not inspect.getattr_static(ChatModelIntegrationTests, "tool_choice_value"):
warn_deprecated(
"0.3.15",
message=(
"`tool_choice_value` will be removed in version 0.3.20. If a "
"model supports `tool_choice`, it should accept `tool_choice='any'` "
"and `tool_choice=<string name of tool>`. If the model does not "
"support `tool_choice`, override the `has_tool_choice` "
"property to return `False`."
),
removal="0.3.20",
)
model_with_tools = model.bind_tools(
[magic_function], tool_choice=tool_choice_value
)
# Test invoke # Test invoke
query = "What is the value of magic_function(3)? Use the tool." query = "What is the value of magic_function(3)? Use the tool."
@ -1012,6 +1051,57 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(full, AIMessage) assert isinstance(full, AIMessage)
_validate_tool_call_message(full) _validate_tool_call_message(full)
def test_tool_choice(self, model: BaseChatModel) -> None:
    """Test that the model can force tool calling via the ``tool_choice``
    parameter. This test is skipped if either the ``has_tool_choice`` or the
    ``has_tool_calling`` property on the test class is False.

    This test is optional and should be skipped if the model does not support
    forcing tool calls (see Configuration below).

    .. dropdown:: Configuration

        To disable tool choice tests, set ``has_tool_choice`` to False in your
        test class:

        .. code-block:: python

            class TestMyChatModelIntegration(ChatModelIntegrationTests):
                @property
                def has_tool_choice(self) -> bool:
                    return False

    .. dropdown:: Troubleshooting

        If this test fails, check whether the ``test_tool_calling`` test is
        passing. If it is not, refer to the troubleshooting steps in that test
        first.

        If ``test_tool_calling`` is passing, check that the underlying model
        supports forced tool calling. If it does, ``bind_tools`` should accept a
        ``tool_choice`` parameter that can be used to force a tool call.

        It should accept (1) the string ``"any"`` to force calling one of the
        bound tools, and (2) the string name of a tool to force calling that
        specific tool.
    """
    # Both capabilities are required: forcing a call is meaningless without
    # tool calling itself.
    if not (self.has_tool_choice and self.has_tool_calling):
        pytest.skip("Test requires tool choice.")

    @tool
    def get_weather(location: str) -> str:
        """Get weather at a location."""
        return "It's sunny."

    # Two tools are bound so that forcing a tool by name is distinguishable
    # from merely forcing "some" tool call.
    for forced_choice in ("any", "magic_function"):
        bound_model = model.bind_tools(
            [magic_function, get_weather], tool_choice=forced_choice
        )
        response = bound_model.invoke("Hello!")
        assert isinstance(response, AIMessage)
        assert response.tool_calls
        if forced_choice == "magic_function":
            assert response.tool_calls[0]["name"] == "magic_function"
async def test_tool_calling_async(self, model: BaseChatModel) -> None: async def test_tool_calling_async(self, model: BaseChatModel) -> None:
"""Test that the model generates tool calls. This test is skipped if the """Test that the model generates tool calls. This test is skipped if the
``has_tool_calling`` property on the test class is set to False. ``has_tool_calling`` property on the test class is set to False.
@ -1048,16 +1138,18 @@ class ChatModelIntegrationTests(ChatModelTests):
async def test_tool_calling_async(self, model: BaseChatModel) -> None: async def test_tool_calling_async(self, model: BaseChatModel) -> None:
await super().test_tool_calling_async(model) await super().test_tool_calling_async(model)
Otherwise, ensure that the ``tool_choice_value`` property is correctly Otherwise, in the case that only one tool is bound, ensure that
specified on the test class. ``tool_choice`` supports the string ``"any"`` to force calling that tool.
""" """
if not self.has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
if self.tool_choice_value == "tool_name": if not self.has_tool_choice:
tool_choice: Optional[str] = "magic_function" tool_choice_value = None
else: else:
tool_choice = self.tool_choice_value tool_choice_value = "any"
model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice) model_with_tools = model.bind_tools(
[magic_function], tool_choice=tool_choice_value
)
# Test ainvoke # Test ainvoke
query = "What is the value of magic_function(3)? Use the tool." query = "What is the value of magic_function(3)? Use the tool."
@ -1109,18 +1201,17 @@ class ChatModelIntegrationTests(ChatModelTests):
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model) super().test_tool_calling_with_no_arguments(model)
Otherwise, ensure that the ``tool_choice_value`` property is correctly Otherwise, in the case that only one tool is bound, ensure that
specified on the test class. ``tool_choice`` supports the string ``"any"`` to force calling that tool.
""" # noqa: E501 """ # noqa: E501
if not self.has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
if not self.has_tool_choice:
if self.tool_choice_value == "tool_name": tool_choice_value = None
tool_choice: Optional[str] = "magic_function_no_args"
else: else:
tool_choice = self.tool_choice_value tool_choice_value = "any"
model_with_tools = model.bind_tools( model_with_tools = model.bind_tools(
[magic_function_no_args], tool_choice=tool_choice [magic_function_no_args], tool_choice=tool_choice_value
) )
query = "What is the value of magic_function_no_args()? Use the tool." query = "What is the value of magic_function_no_args()? Use the tool."
result = model_with_tools.invoke(query) result = model_with_tools.invoke(query)
@ -1184,10 +1275,10 @@ class ChatModelIntegrationTests(ChatModelTests):
name="greeting_generator", name="greeting_generator",
description="Generate a greeting in a particular style of speaking.", description="Generate a greeting in a particular style of speaking.",
) )
if self.tool_choice_value == "tool_name": if self.has_tool_choice:
tool_choice: Optional[str] = "greeting_generator" tool_choice: Optional[str] = "any"
else: else:
tool_choice = self.tool_choice_value tool_choice = None
model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice) model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice)
query = "Using the tool, generate a Pirate greeting." query = "Using the tool, generate a Pirate greeting."
result = model_with_tools.invoke(query) result = model_with_tools.invoke(query)
@ -2086,9 +2177,6 @@ class ChatModelIntegrationTests(ChatModelTests):
If this test fails, check that the ``status`` field on ``ToolMessage`` If this test fails, check that the ``status`` field on ``ToolMessage``
objects is either ignored or passed to the model appropriately. objects is either ignored or passed to the model appropriately.
Otherwise, ensure that the ``tool_choice_value`` property is correctly
specified on the test class.
""" """
if not self.has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")

View File

@ -2,6 +2,7 @@
:autodoc-options: autoproperty :autodoc-options: autoproperty
""" """
import inspect
import os import os
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Type from typing import Any, Dict, List, Literal, Optional, Tuple, Type
@ -127,6 +128,14 @@ class ChatModelTests(BaseStandardTests):
"""(None or str) to use for tool choice when used in tests.""" """(None or str) to use for tool choice when used in tests."""
return None return None
@property
def has_tool_choice(self) -> bool:
    """(bool) whether the model supports forcing tool calls via ``tool_choice``.

    Determined by inspecting the signature of ``chat_model_class.bind_tools``:
    the property is ``True`` exactly when that signature declares a parameter
    named ``tool_choice``.
    """
    signature = inspect.signature(self.chat_model_class.bind_tools)
    return "tool_choice" in signature.parameters
@property @property
def has_structured_output(self) -> bool: def has_structured_output(self) -> bool:
"""(bool) whether the chat model supports structured output.""" """(bool) whether the chat model supports structured output."""
@ -273,12 +282,11 @@ class ChatModelUnitTests(ChatModelTests):
Value to use for tool choice when used in tests. Value to use for tool choice when used in tests.
Some tests for tool calling features attempt to force tool calling via a .. warning:: Deprecated since version 0.3.15:
`tool_choice` parameter. A common value for this parameter is "any". Defaults This property will be removed in version 0.3.20. If a model does not
to `None`. support forcing tool calling, override the ``has_tool_choice`` property to
return ``False``. Otherwise, models should accept values of ``"any"`` or
Note: if the value is set to "tool_name", the name of the tool used in each the name of a tool in ``tool_choice``.
test will be set as the value for `tool_choice`.
Example: Example:
@ -288,6 +296,26 @@ class ChatModelUnitTests(ChatModelTests):
def tool_choice_value(self) -> Optional[str]: def tool_choice_value(self) -> Optional[str]:
return "any" return "any"
.. dropdown:: has_tool_choice
Boolean property indicating whether the chat model supports forcing tool
calling via a ``tool_choice`` parameter.
By default, this is determined by whether the parameter is included in the
signature for the corresponding ``bind_tools`` method.
If ``True``, the minimum requirement for this feature is that
``tool_choice="any"`` will force a tool call, and ``tool_choice=<tool name>``
will force a call to a specific tool.
Example override:
.. code-block:: python
@property
def has_tool_choice(self) -> bool:
return False
.. dropdown:: has_structured_output .. dropdown:: has_structured_output
Boolean property indicating whether the chat model supports structured Boolean property indicating whether the chat model supports structured