mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 17:07:25 +00:00
Add validation on agent instantiation for multi-input tools (#3681)
Tradeoffs here: - No lint-time checking for compatibility - Differs from JS package - The signature inference, etc. in the base tool isn't simple - The `args_schema` is optional Pros: - Forwards compatibility retained - Doesn't break backwards compatibility - User doesn't have to think about which class to subclass (single base tool or dynamic `Tool` interface regardless of input) - No need to change the load_tools, etc. interfaces Co-authored-by: Hasan Patel <mangafield@gmail.com>
This commit is contained in:
parent
212aadd4af
commit
4654c58f72
@ -454,7 +454,11 @@ class Agent(BaseSingleActionAgent):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
"""Validate that appropriate tools are passed in."""
|
"""Validate that appropriate tools are passed in."""
|
||||||
pass
|
for tool in tools:
|
||||||
|
if not tool.is_single_input:
|
||||||
|
raise ValueError(
|
||||||
|
f"{cls.__name__} does not support multi-input tool {tool.name}."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -122,6 +122,7 @@ class ZeroShotAgent(Agent):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
|
super()._validate_tools(tools)
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if tool.description is None:
|
if tool.description is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -37,6 +37,7 @@ class ReActDocstoreAgent(Agent):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
|
super()._validate_tools(tools)
|
||||||
if len(tools) != 2:
|
if len(tools) != 2:
|
||||||
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
|
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
|
||||||
tool_names = {tool.name for tool in tools}
|
tool_names = {tool.name for tool in tools}
|
||||||
@ -119,6 +120,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
|
super()._validate_tools(tools)
|
||||||
if len(tools) != 1:
|
if len(tools) != 1:
|
||||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||||
tool_names = {tool.name for tool in tools}
|
tool_names = {tool.name for tool in tools}
|
||||||
|
@ -36,6 +36,7 @@ class SelfAskWithSearchAgent(Agent):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
|
super()._validate_tools(tools)
|
||||||
if len(tools) != 1:
|
if len(tools) != 1:
|
||||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||||
tool_names = {tool.name for tool in tools}
|
tool_names = {tool.name for tool in tools}
|
||||||
|
@ -115,6 +115,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_single_input(self) -> bool:
|
||||||
|
"""Whether the tool only accepts a single input."""
|
||||||
|
return len(self.args) == 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> dict:
|
def args(self) -> dict:
|
||||||
if self.args_schema is not None:
|
if self.args_schema is not None:
|
||||||
@ -148,11 +153,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
return callback_manager or get_callback_manager()
|
return callback_manager or get_callback_manager()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
@ -183,7 +188,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||||
raise e
|
raise e
|
||||||
self.callback_manager.on_tool_end(
|
self.callback_manager.on_tool_end(
|
||||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
|
||||||
)
|
)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
@ -194,7 +199,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
start_color: Optional[str] = "green",
|
start_color: Optional[str] = "green",
|
||||||
color: Optional[str] = "green",
|
color: Optional[str] = "green",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> Any:
|
||||||
"""Run the tool asynchronously."""
|
"""Run the tool asynchronously."""
|
||||||
self._parse_input(tool_input)
|
self._parse_input(tool_input)
|
||||||
if not self.verbose and verbose is not None:
|
if not self.verbose and verbose is not None:
|
||||||
@ -229,7 +234,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
raise e
|
raise e
|
||||||
if self.callback_manager.is_async:
|
if self.callback_manager.is_async:
|
||||||
await self.callback_manager.on_tool_end(
|
await self.callback_manager.on_tool_end(
|
||||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
str(observation),
|
||||||
|
verbose=verbose_,
|
||||||
|
color=color,
|
||||||
|
name=self.name,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.callback_manager.on_tool_end(
|
self.callback_manager.on_tool_end(
|
||||||
@ -237,6 +246,6 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
|||||||
)
|
)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
def __call__(self, tool_input: str) -> str:
|
def __call__(self, tool_input: Union[str, dict]) -> Any:
|
||||||
"""Make tool callable."""
|
"""Make tool callable."""
|
||||||
return self.run(tool_input)
|
return self.run(tool_input)
|
||||||
|
@ -2,11 +2,19 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, Type, Union
|
from typing import Optional, Type, Union
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.agents.agent import Agent
|
||||||
|
from langchain.agents.chat.base import ChatAgent
|
||||||
|
from langchain.agents.conversational.base import ConversationalAgent
|
||||||
|
from langchain.agents.conversational_chat.base import ConversationalChatAgent
|
||||||
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
|
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
|
||||||
|
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||||
from langchain.agents.tools import Tool, tool
|
from langchain.agents.tools import Tool, tool
|
||||||
from langchain.tools.base import BaseTool, SchemaAnnotationError
|
from langchain.tools.base import BaseTool, SchemaAnnotationError
|
||||||
|
|
||||||
@ -152,6 +160,7 @@ def test_decorated_function_schema_equivalent() -> None:
|
|||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
assert isinstance(structured_tool_input, Tool)
|
assert isinstance(structured_tool_input, Tool)
|
||||||
|
assert structured_tool_input.args_schema is not None
|
||||||
assert (
|
assert (
|
||||||
structured_tool_input.args_schema.schema()["properties"]
|
structured_tool_input.args_schema.schema()["properties"]
|
||||||
== _MockSchema.schema()["properties"]
|
== _MockSchema.schema()["properties"]
|
||||||
@ -309,33 +318,38 @@ def test_tool_with_kwargs() -> None:
|
|||||||
|
|
||||||
@tool(return_direct=True)
|
@tool(return_direct=True)
|
||||||
def search_api(
|
def search_api(
|
||||||
arg_1: float,
|
arg_0: str,
|
||||||
|
arg_1: float = 4.3,
|
||||||
ping: str = "hi",
|
ping: str = "hi",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Search the API for the query."""
|
"""Search the API for the query."""
|
||||||
return f"arg_1={arg_1}, ping={ping}"
|
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
|
||||||
|
|
||||||
assert isinstance(search_api, Tool)
|
assert isinstance(search_api, Tool)
|
||||||
result = search_api.run(
|
result = search_api.run(
|
||||||
tool_input={
|
tool_input={
|
||||||
|
"arg_0": "foo",
|
||||||
"arg_1": 3.2,
|
"arg_1": 3.2,
|
||||||
"ping": "pong",
|
"ping": "pong",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
assert result == "arg_1=3.2, ping=pong"
|
assert result == "arg_0=foo, arg_1=3.2, ping=pong"
|
||||||
|
|
||||||
result = search_api.run(
|
result = search_api.run(
|
||||||
tool_input={
|
tool_input={
|
||||||
"arg_1": 3.2,
|
"arg_0": "foo",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
assert result == "arg_1=3.2, ping=hi"
|
assert result == "arg_0=foo, arg_1=4.3, ping=hi"
|
||||||
|
# For backwards compatibility, we still accept a single str arg
|
||||||
|
result = search_api.run("foobar")
|
||||||
|
assert result == "arg_0=foobar, arg_1=4.3, ping=hi"
|
||||||
|
|
||||||
|
|
||||||
def test_missing_docstring() -> None:
|
def test_missing_docstring() -> None:
|
||||||
"""Test error is raised when docstring is missing."""
|
"""Test error is raised when docstring is missing."""
|
||||||
# expect to throw a value error if theres no docstring
|
# expect to throw a value error if theres no docstring
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError, match="Function must have a docstring"):
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def search_api(query: str) -> str:
|
def search_api(query: str) -> str:
|
||||||
@ -348,11 +362,13 @@ def test_create_tool_positional_args() -> None:
|
|||||||
assert test_tool("foo") == "foo"
|
assert test_tool("foo") == "foo"
|
||||||
assert test_tool.name == "test_name"
|
assert test_tool.name == "test_name"
|
||||||
assert test_tool.description == "test_description"
|
assert test_tool.description == "test_description"
|
||||||
|
assert test_tool.is_single_input
|
||||||
|
|
||||||
|
|
||||||
def test_create_tool_keyword_args() -> None:
|
def test_create_tool_keyword_args() -> None:
|
||||||
"""Test that keyword arguments are allowed."""
|
"""Test that keyword arguments are allowed."""
|
||||||
test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")
|
test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")
|
||||||
|
assert test_tool.is_single_input
|
||||||
assert test_tool("foo") == "foo"
|
assert test_tool("foo") == "foo"
|
||||||
assert test_tool.name == "test_name"
|
assert test_tool.name == "test_name"
|
||||||
assert test_tool.description == "test_description"
|
assert test_tool.description == "test_description"
|
||||||
@ -371,8 +387,39 @@ async def test_create_async_tool() -> None:
|
|||||||
description="test_description",
|
description="test_description",
|
||||||
coroutine=_test_func,
|
coroutine=_test_func,
|
||||||
)
|
)
|
||||||
|
assert test_tool.is_single_input
|
||||||
assert test_tool("foo") == "foo"
|
assert test_tool("foo") == "foo"
|
||||||
assert test_tool.name == "test_name"
|
assert test_tool.name == "test_name"
|
||||||
assert test_tool.description == "test_description"
|
assert test_tool.description == "test_description"
|
||||||
assert test_tool.coroutine is not None
|
assert test_tool.coroutine is not None
|
||||||
assert await test_tool.arun("foo") == "foo"
|
assert await test_tool.arun("foo") == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"agent_cls",
|
||||||
|
[
|
||||||
|
ChatAgent,
|
||||||
|
ZeroShotAgent,
|
||||||
|
ConversationalChatAgent,
|
||||||
|
ConversationalAgent,
|
||||||
|
ReActDocstoreAgent,
|
||||||
|
ReActTextWorldAgent,
|
||||||
|
SelfAskWithSearchAgent,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_single_input_agent_raises_error_on_structured_tool(
|
||||||
|
agent_cls: Type[Agent],
|
||||||
|
) -> None:
|
||||||
|
"""Test that older agents raise errors on older tools."""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def the_tool(foo: str, bar: str) -> str:
|
||||||
|
"""Return the concat of foo and bar."""
|
||||||
|
return foo + bar
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=f"{agent_cls.__name__} does not support" # type: ignore
|
||||||
|
f" multi-input tool {the_tool.name}.",
|
||||||
|
):
|
||||||
|
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore
|
||||||
|
Loading…
Reference in New Issue
Block a user