mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
multiple: enforce standards on tool_choice (#30372)
- Test if models support forcing tool calls via `tool_choice`. If they do, they should support - `"any"` to specify any tool - the tool name as a string to force calling a particular tool - Add `tool_choice` to signature of `BaseChatModel.bind_tools` in core - Deprecate `tool_choice_value` in standard tests in favor of a boolean `has_tool_choice` Will follow up with PRs in external repos (tested in AWS and Google already).
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
@@ -40,11 +40,6 @@ class TestGroqLlama(BaseTestGroq):
|
||||
"rate_limiter": rate_limiter,
|
||||
}
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "any"
|
||||
|
||||
@property
|
||||
def supports_json_mode(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -685,6 +685,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
@@ -705,6 +706,22 @@ class ChatMistralAI(BaseChatModel):
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
if tool_choice:
|
||||
tool_names = []
|
||||
for tool in formatted_tools:
|
||||
if "function" in tool and (name := tool["function"].get("name")):
|
||||
tool_names.append(name)
|
||||
elif name := tool.get("name"):
|
||||
tool_names.append(name)
|
||||
else:
|
||||
pass
|
||||
if tool_choice in tool_names:
|
||||
kwargs["tool_choice"] = {
|
||||
"type": "function",
|
||||
"function": {"name": tool_choice},
|
||||
}
|
||||
else:
|
||||
kwargs["tool_choice"] = tool_choice
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def with_structured_output(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_tests.integration_tests import ( # type: ignore[import-not-found]
|
||||
@@ -22,8 +22,3 @@ class TestMistralStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def supports_json_mode(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "any"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Type
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
@@ -30,11 +30,6 @@ class TestXAIStandard(ChatModelIntegrationTests):
|
||||
"rate_limiter": rate_limiter,
|
||||
}
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "tool_name"
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet supported.")
|
||||
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
|
||||
super().test_usage_metadata_streaming(model)
|
||||
|
||||
Reference in New Issue
Block a user