mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
multiple: enforce standards on tool_choice (#30372)
- Test if models support forcing tool calls via `tool_choice`. If they do, they should support - `"any"` to specify any tool - the tool name as a string to force calling a particular tool - Add `tool_choice` to signature of `BaseChatModel.bind_tools` in core - Deprecate `tool_choice_value` in standard tests in favor of a boolean `has_tool_choice` Will follow up with PRs in external repos (tested in AWS and Google already).
This commit is contained in:
parent
b86cd8270c
commit
de3960d285
@ -194,9 +194,14 @@ class ChatCloudflareWorkersAI(BaseChatModel):
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema."""
|
||||
|
||||
_ = kwargs.pop("strict", None)
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = _is_pydantic_class(schema)
|
||||
if method == "json_schema":
|
||||
# Some applications require that incompatible parameters (e.g., unsupported
|
||||
# methods) be handled.
|
||||
method = "function_calling"
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
|
@ -383,6 +383,8 @@ class ChatPerplexity(BaseChatModel):
|
||||
- "parsing_error": Optional[BaseException]
|
||||
|
||||
""" # noqa: E501
|
||||
if method in ("function_calling", "json_mode"):
|
||||
method = "json_schema"
|
||||
if method == "json_schema":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
|
@ -6,6 +6,7 @@ from typing import (
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -382,7 +383,7 @@ class ChatReka(BaseChatModel):
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: str = "auto",
|
||||
tool_choice: Optional[Union[str, Literal["any"]]] = "auto",
|
||||
strict: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
@ -421,6 +422,10 @@ class ChatReka(BaseChatModel):
|
||||
]
|
||||
|
||||
# Ensure tool_choice is one of the allowed options
|
||||
if tool_choice is None:
|
||||
tool_choice = "auto"
|
||||
if tool_choice == "any":
|
||||
tool_choice = "tool"
|
||||
if tool_choice not in ("auto", "none", "tool"):
|
||||
raise ValueError(
|
||||
f"Invalid tool_choice '{tool_choice}' provided. "
|
||||
|
@ -1167,6 +1167,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
tools: Sequence[
|
||||
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
|
||||
],
|
||||
*,
|
||||
tool_choice: Optional[Union[str, Literal["any"]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
raise NotImplementedError
|
||||
@ -1280,6 +1282,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
Added support for TypedDict class.
|
||||
""" # noqa: E501
|
||||
_ = kwargs.pop("method", None)
|
||||
_ = kwargs.pop("strict", None)
|
||||
if kwargs:
|
||||
msg = f"Received unsupported arguments {kwargs}"
|
||||
raise ValueError(msg)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
@ -40,11 +40,6 @@ class TestGroqLlama(BaseTestGroq):
|
||||
"rate_limiter": rate_limiter,
|
||||
}
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "any"
|
||||
|
||||
@property
|
||||
def supports_json_mode(self) -> bool:
|
||||
return True
|
||||
|
@ -685,6 +685,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
@ -705,6 +706,22 @@ class ChatMistralAI(BaseChatModel):
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
if tool_choice:
|
||||
tool_names = []
|
||||
for tool in formatted_tools:
|
||||
if "function" in tool and (name := tool["function"].get("name")):
|
||||
tool_names.append(name)
|
||||
elif name := tool.get("name"):
|
||||
tool_names.append(name)
|
||||
else:
|
||||
pass
|
||||
if tool_choice in tool_names:
|
||||
kwargs["tool_choice"] = {
|
||||
"type": "function",
|
||||
"function": {"name": tool_choice},
|
||||
}
|
||||
else:
|
||||
kwargs["tool_choice"] = tool_choice
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def with_structured_output(
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_tests.integration_tests import ( # type: ignore[import-not-found]
|
||||
@ -22,8 +22,3 @@ class TestMistralStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def supports_json_mode(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "any"
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Type
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
@ -30,11 +30,6 @@ class TestXAIStandard(ChatModelIntegrationTests):
|
||||
"rate_limiter": rate_limiter,
|
||||
}
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "tool_name"
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet supported.")
|
||||
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
|
||||
super().test_usage_metadata_streaming(model)
|
||||
|
@ -1,10 +1,12 @@
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, List, Literal, Optional, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from langchain_core._api import warn_deprecated
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.language_models import BaseChatModel, GenericFakeChatModel
|
||||
from langchain_core.messages import (
|
||||
@ -204,12 +206,12 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
|
||||
Value to use for tool choice when used in tests.
|
||||
|
||||
Some tests for tool calling features attempt to force tool calling via a
|
||||
`tool_choice` parameter. A common value for this parameter is "any". Defaults
|
||||
to `None`.
|
||||
|
||||
Note: if the value is set to "tool_name", the name of the tool used in each
|
||||
test will be set as the value for `tool_choice`.
|
||||
.. warning:: Deprecated since version 0.3.15:
|
||||
This property will be removed in version 0.3.20. If a model supports
|
||||
``tool_choice``, it should accept ``tool_choice="any"`` and
|
||||
``tool_choice=<string name of tool>``. If a model does not
|
||||
support forcing tool calling, override the ``has_tool_choice`` property to
|
||||
return ``False``.
|
||||
|
||||
Example:
|
||||
|
||||
@ -219,6 +221,26 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
return "any"
|
||||
|
||||
.. dropdown:: has_tool_choice
|
||||
|
||||
Boolean property indicating whether the chat model supports forcing tool
|
||||
calling via a ``tool_choice`` parameter.
|
||||
|
||||
By default, this is determined by whether the parameter is included in the
|
||||
signature for the corresponding ``bind_tools`` method.
|
||||
|
||||
If ``True``, the minimum requirement for this feature is that
|
||||
``tool_choice="any"`` will force a tool call, and ``tool_choice=<tool name>``
|
||||
will force a call to a specific tool.
|
||||
|
||||
Example override:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@property
|
||||
def has_tool_choice(self) -> bool:
|
||||
return False
|
||||
|
||||
.. dropdown:: has_structured_output
|
||||
|
||||
Boolean property indicating whether the chat model supports structured
|
||||
@ -989,16 +1011,33 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
def test_tool_calling(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_calling(model)
|
||||
|
||||
Otherwise, ensure that the ``tool_choice_value`` property is correctly
|
||||
specified on the test class.
|
||||
Otherwise, in the case that only one tool is bound, ensure that
|
||||
``tool_choice`` supports the string ``"any"`` to force calling that tool.
|
||||
"""
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
if self.tool_choice_value == "tool_name":
|
||||
tool_choice: Optional[str] = "magic_function"
|
||||
if not self.has_tool_choice:
|
||||
tool_choice_value = None
|
||||
else:
|
||||
tool_choice = self.tool_choice_value
|
||||
model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice)
|
||||
tool_choice_value = "any"
|
||||
# Emit warning if tool_choice_value property is overridden
|
||||
if inspect.getattr_static(
|
||||
self, "tool_choice_value"
|
||||
) is not inspect.getattr_static(ChatModelIntegrationTests, "tool_choice_value"):
|
||||
warn_deprecated(
|
||||
"0.3.15",
|
||||
message=(
|
||||
"`tool_choice_value` will be removed in version 0.3.20. If a "
|
||||
"model supports `tool_choice`, it should accept `tool_choice='any' "
|
||||
"and `tool_choice=<string name of tool>`. If the model does not "
|
||||
"support `tool_choice`, override the `supports_tool_choice` "
|
||||
"property to return `False`."
|
||||
),
|
||||
removal="0.3.20",
|
||||
)
|
||||
model_with_tools = model.bind_tools(
|
||||
[magic_function], tool_choice=tool_choice_value
|
||||
)
|
||||
|
||||
# Test invoke
|
||||
query = "What is the value of magic_function(3)? Use the tool."
|
||||
@ -1012,6 +1051,57 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
assert isinstance(full, AIMessage)
|
||||
_validate_tool_call_message(full)
|
||||
|
||||
def test_tool_choice(self, model: BaseChatModel) -> None:
|
||||
"""Test that the model can force tool calling via the ``tool_choice``
|
||||
parameter. This test is skipped if the ``has_tool_choice`` property on the
|
||||
test class is set to False.
|
||||
|
||||
This test is optional and should be skipped if the model does not support
|
||||
tool calling (see Configuration below).
|
||||
|
||||
.. dropdown:: Configuration
|
||||
|
||||
To disable tool calling tests, set ``has_tool_choice`` to False in your
|
||||
test class:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class TestMyChatModelIntegration(ChatModelIntegrationTests):
|
||||
@property
|
||||
def has_tool_choice(self) -> bool:
|
||||
return False
|
||||
|
||||
.. dropdown:: Troubleshooting
|
||||
|
||||
If this test fails, check whether the ``test_tool_calling`` test is passing.
|
||||
If it is not, refer to the troubleshooting steps in that test first.
|
||||
|
||||
If ``test_tool_calling`` is passing, check that the underlying model
|
||||
supports forced tool calling. If it does, ``bind_tools`` should accept a
|
||||
``tool_choice`` parameter that can be used to force a tool call.
|
||||
|
||||
It should accept (1) the string ``"any"`` to force calling the bound tool,
|
||||
and (2) the string name of the tool to force calling that tool.
|
||||
|
||||
"""
|
||||
if not self.has_tool_choice or not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool choice.")
|
||||
|
||||
@tool
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get weather at a location."""
|
||||
return "It's sunny."
|
||||
|
||||
for tool_choice in ["any", "magic_function"]:
|
||||
model_with_tools = model.bind_tools(
|
||||
[magic_function, get_weather], tool_choice=tool_choice
|
||||
)
|
||||
result = model_with_tools.invoke("Hello!")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.tool_calls
|
||||
if tool_choice == "magic_function":
|
||||
assert result.tool_calls[0]["name"] == "magic_function"
|
||||
|
||||
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
|
||||
"""Test that the model generates tool calls. This test is skipped if the
|
||||
``has_tool_calling`` property on the test class is set to False.
|
||||
@ -1048,16 +1138,18 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
|
||||
await super().test_tool_calling_async(model)
|
||||
|
||||
Otherwise, ensure that the ``tool_choice_value`` property is correctly
|
||||
specified on the test class.
|
||||
Otherwise, in the case that only one tool is bound, ensure that
|
||||
``tool_choice`` supports the string ``"any"`` to force calling that tool.
|
||||
"""
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
if self.tool_choice_value == "tool_name":
|
||||
tool_choice: Optional[str] = "magic_function"
|
||||
if not self.has_tool_choice:
|
||||
tool_choice_value = None
|
||||
else:
|
||||
tool_choice = self.tool_choice_value
|
||||
model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice)
|
||||
tool_choice_value = "any"
|
||||
model_with_tools = model.bind_tools(
|
||||
[magic_function], tool_choice=tool_choice_value
|
||||
)
|
||||
|
||||
# Test ainvoke
|
||||
query = "What is the value of magic_function(3)? Use the tool."
|
||||
@ -1109,18 +1201,17 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_calling_with_no_arguments(model)
|
||||
|
||||
Otherwise, ensure that the ``tool_choice_value`` property is correctly
|
||||
specified on the test class.
|
||||
Otherwise, in the case that only one tool is bound, ensure that
|
||||
``tool_choice`` supports the string ``"any"`` to force calling that tool.
|
||||
""" # noqa: E501
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
||||
if self.tool_choice_value == "tool_name":
|
||||
tool_choice: Optional[str] = "magic_function_no_args"
|
||||
if not self.has_tool_choice:
|
||||
tool_choice_value = None
|
||||
else:
|
||||
tool_choice = self.tool_choice_value
|
||||
tool_choice_value = "any"
|
||||
model_with_tools = model.bind_tools(
|
||||
[magic_function_no_args], tool_choice=tool_choice
|
||||
[magic_function_no_args], tool_choice=tool_choice_value
|
||||
)
|
||||
query = "What is the value of magic_function_no_args()? Use the tool."
|
||||
result = model_with_tools.invoke(query)
|
||||
@ -1184,10 +1275,10 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
name="greeting_generator",
|
||||
description="Generate a greeting in a particular style of speaking.",
|
||||
)
|
||||
if self.tool_choice_value == "tool_name":
|
||||
tool_choice: Optional[str] = "greeting_generator"
|
||||
if not self.has_tool_choice:
|
||||
tool_choice: Optional[str] = "any"
|
||||
else:
|
||||
tool_choice = self.tool_choice_value
|
||||
tool_choice = None
|
||||
model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice)
|
||||
query = "Using the tool, generate a Pirate greeting."
|
||||
result = model_with_tools.invoke(query)
|
||||
@ -2086,9 +2177,6 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
|
||||
If this test fails, check that the ``status`` field on ``ToolMessage``
|
||||
objects is either ignored or passed to the model appropriately.
|
||||
|
||||
Otherwise, ensure that the ``tool_choice_value`` property is correctly
|
||||
specified on the test class.
|
||||
"""
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
@ -2,6 +2,7 @@
|
||||
:autodoc-options: autoproperty
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
|
||||
@ -127,6 +128,14 @@ class ChatModelTests(BaseStandardTests):
|
||||
"""(None or str) to use for tool choice when used in tests."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def has_tool_choice(self) -> bool:
|
||||
"""(bool) whether the model supports tool calling."""
|
||||
bind_tools_params = inspect.signature(
|
||||
self.chat_model_class.bind_tools
|
||||
).parameters
|
||||
return "tool_choice" in bind_tools_params
|
||||
|
||||
@property
|
||||
def has_structured_output(self) -> bool:
|
||||
"""(bool) whether the chat model supports structured output."""
|
||||
@ -273,12 +282,11 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
|
||||
Value to use for tool choice when used in tests.
|
||||
|
||||
Some tests for tool calling features attempt to force tool calling via a
|
||||
`tool_choice` parameter. A common value for this parameter is "any". Defaults
|
||||
to `None`.
|
||||
|
||||
Note: if the value is set to "tool_name", the name of the tool used in each
|
||||
test will be set as the value for `tool_choice`.
|
||||
.. warning:: Deprecated since version 0.3.15:
|
||||
This property will be removed in version 0.3.20. If a model does not
|
||||
support forcing tool calling, override the ``has_tool_choice`` property to
|
||||
return ``False``. Otherwise, models should accept values of ``"any"`` or
|
||||
the name of a tool in ``tool_choice``.
|
||||
|
||||
Example:
|
||||
|
||||
@ -288,6 +296,26 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
return "any"
|
||||
|
||||
.. dropdown:: has_tool_choice
|
||||
|
||||
Boolean property indicating whether the chat model supports forcing tool
|
||||
calling via a ``tool_choice`` parameter.
|
||||
|
||||
By default, this is determined by whether the parameter is included in the
|
||||
signature for the corresponding ``bind_tools`` method.
|
||||
|
||||
If ``True``, the minimum requirement for this feature is that
|
||||
``tool_choice="any"`` will force a tool call, and ``tool_choice=<tool name>``
|
||||
will force a call to a specific tool.
|
||||
|
||||
Example override:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@property
|
||||
def has_tool_choice(self) -> bool:
|
||||
return False
|
||||
|
||||
.. dropdown:: has_structured_output
|
||||
|
||||
Boolean property indicating whether the chat model supports structured
|
||||
|
Loading…
Reference in New Issue
Block a user