mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
Add validation on agent instantiation for multi-input tools (#3681)
Tradeoffs here: - No lint-time checking for compatibility - Differs from JS package - The signature inference, etc. in the base tool isn't simple - The `args_schema` is optional Pros: - Forwards compatibility retained - Doesn't break backwards compatibility - User doesn't have to think about which class to subclass (single base tool or dynamic `Tool` interface regardless of input) - No need to change the load_tools, etc. interfaces Co-authored-by: Hasan Patel <mangafield@gmail.com>
This commit is contained in:
parent
212aadd4af
commit
4654c58f72
@ -454,7 +454,11 @@ class Agent(BaseSingleActionAgent):
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
"""Validate that appropriate tools are passed in."""
|
||||
pass
|
||||
for tool in tools:
|
||||
if not tool.is_single_input:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} does not support multi-input tool {tool.name}."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
|
@ -122,6 +122,7 @@ class ZeroShotAgent(Agent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
for tool in tools:
|
||||
if tool.description is None:
|
||||
raise ValueError(
|
||||
|
@ -37,6 +37,7 @@ class ReActDocstoreAgent(Agent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
if len(tools) != 2:
|
||||
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
@ -119,6 +120,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
if len(tools) != 1:
|
||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
|
@ -36,6 +36,7 @@ class SelfAskWithSearchAgent(Agent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
if len(tools) != 1:
|
||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
|
@ -115,6 +115,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def is_single_input(self) -> bool:
|
||||
"""Whether the tool only accepts a single input."""
|
||||
return len(self.args) == 1
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
if self.args_schema is not None:
|
||||
@ -148,11 +153,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
return callback_manager or get_callback_manager()
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use the tool."""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
def run(
|
||||
@ -183,7 +188,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
@ -194,7 +199,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously."""
|
||||
self._parse_input(tool_input)
|
||||
if not self.verbose and verbose is not None:
|
||||
@ -229,7 +234,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
str(observation),
|
||||
verbose=verbose_,
|
||||
color=color,
|
||||
name=self.name,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_end(
|
||||
@ -237,6 +246,6 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
)
|
||||
return observation
|
||||
|
||||
def __call__(self, tool_input: str) -> str:
|
||||
def __call__(self, tool_input: Union[str, dict]) -> Any:
|
||||
"""Make tool callable."""
|
||||
return self.run(tool_input)
|
||||
|
@ -2,11 +2,19 @@
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Optional, Type, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.agents.chat.base import ChatAgent
|
||||
from langchain.agents.conversational.base import ConversationalAgent
|
||||
from langchain.agents.conversational_chat.base import ConversationalChatAgent
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from langchain.tools.base import BaseTool, SchemaAnnotationError
|
||||
|
||||
@ -152,6 +160,7 @@ def test_decorated_function_schema_equivalent() -> None:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(structured_tool_input, Tool)
|
||||
assert structured_tool_input.args_schema is not None
|
||||
assert (
|
||||
structured_tool_input.args_schema.schema()["properties"]
|
||||
== _MockSchema.schema()["properties"]
|
||||
@ -309,33 +318,38 @@ def test_tool_with_kwargs() -> None:
|
||||
|
||||
@tool(return_direct=True)
|
||||
def search_api(
|
||||
arg_1: float,
|
||||
arg_0: str,
|
||||
arg_1: float = 4.3,
|
||||
ping: str = "hi",
|
||||
) -> str:
|
||||
"""Search the API for the query."""
|
||||
return f"arg_1={arg_1}, ping={ping}"
|
||||
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
|
||||
|
||||
assert isinstance(search_api, Tool)
|
||||
result = search_api.run(
|
||||
tool_input={
|
||||
"arg_0": "foo",
|
||||
"arg_1": 3.2,
|
||||
"ping": "pong",
|
||||
}
|
||||
)
|
||||
assert result == "arg_1=3.2, ping=pong"
|
||||
assert result == "arg_0=foo, arg_1=3.2, ping=pong"
|
||||
|
||||
result = search_api.run(
|
||||
tool_input={
|
||||
"arg_1": 3.2,
|
||||
"arg_0": "foo",
|
||||
}
|
||||
)
|
||||
assert result == "arg_1=3.2, ping=hi"
|
||||
assert result == "arg_0=foo, arg_1=4.3, ping=hi"
|
||||
# For backwards compatibility, we still accept a single str arg
|
||||
result = search_api.run("foobar")
|
||||
assert result == "arg_0=foobar, arg_1=4.3, ping=hi"
|
||||
|
||||
|
||||
def test_missing_docstring() -> None:
|
||||
"""Test error is raised when docstring is missing."""
|
||||
# expect to throw a value error if theres no docstring
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(AssertionError, match="Function must have a docstring"):
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
@ -348,11 +362,13 @@ def test_create_tool_positional_args() -> None:
|
||||
assert test_tool("foo") == "foo"
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
assert test_tool.is_single_input
|
||||
|
||||
|
||||
def test_create_tool_keyword_args() -> None:
|
||||
"""Test that keyword arguments are allowed."""
|
||||
test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")
|
||||
assert test_tool.is_single_input
|
||||
assert test_tool("foo") == "foo"
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
@ -371,8 +387,39 @@ async def test_create_async_tool() -> None:
|
||||
description="test_description",
|
||||
coroutine=_test_func,
|
||||
)
|
||||
assert test_tool.is_single_input
|
||||
assert test_tool("foo") == "foo"
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
assert test_tool.coroutine is not None
|
||||
assert await test_tool.arun("foo") == "foo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_cls",
|
||||
[
|
||||
ChatAgent,
|
||||
ZeroShotAgent,
|
||||
ConversationalChatAgent,
|
||||
ConversationalAgent,
|
||||
ReActDocstoreAgent,
|
||||
ReActTextWorldAgent,
|
||||
SelfAskWithSearchAgent,
|
||||
],
|
||||
)
|
||||
def test_single_input_agent_raises_error_on_structured_tool(
|
||||
agent_cls: Type[Agent],
|
||||
) -> None:
|
||||
"""Test that older agents raise errors on older tools."""
|
||||
|
||||
@tool
|
||||
def the_tool(foo: str, bar: str) -> str:
|
||||
"""Return the concat of foo and bar."""
|
||||
return foo + bar
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=f"{agent_cls.__name__} does not support" # type: ignore
|
||||
f" multi-input tool {the_tool.name}.",
|
||||
):
|
||||
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore
|
||||
|
Loading…
Reference in New Issue
Block a user