Compare commits

...

3 Commits

Author SHA1 Message Date
vowelparrot
2cbc201c15 Tool Single-Input, Structured Tool Any-Input 2023-04-28 07:27:56 -07:00
vowelparrot
c91c71df7d Filter args when function is only *args and **kwargs 2023-04-28 06:24:54 -07:00
Zander Chase
7439002045 Add validation on agent instantiation for multi-input tools (#3681)
Tradeoffs here:
- No lint-time checking for compatibility
- Differs from JS package
- The signature inference, etc. in the base tool isn't simple
- The `args_schema` is optional

Pros:
- Forwards compatibility retained
- Doesn't break backwards compatibility
- User doesn't have to think about which class to subclass (single base
tool or dynamic `Tool` interface regardless of input)
-  No need to change the load_tools, etc. interfaces

Co-authored-by: Hasan Patel <mangafield@gmail.com>
2023-04-28 06:24:54 -07:00
4 changed files with 195 additions and 61 deletions

View File

@@ -1,3 +1,4 @@
import re
from typing import Any, List, Optional, Sequence, Tuple
from pydantic import Field
@@ -49,6 +50,11 @@ class ChatAgent(Agent):
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return ChatOutputParser()
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
"""Validate that appropriate tools are passed in."""
pass
@property
def _stop(self) -> List[str]:
return ["Observation:"]
@@ -62,7 +68,13 @@ class ChatAgent(Agent):
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
) -> BasePromptTemplate:
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_strings_ = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings_.append(
f"> {tool.name}: {tool.description}\nArgs: {args_schema}"
)
tool_strings = "\n".join(tool_strings_)
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])

View File

@@ -1,15 +1,10 @@
"""Interface for tools."""
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
from pydantic import BaseModel, validate_arguments, validator
from pydantic import BaseModel, validator
from langchain.tools.base import (
BaseTool,
create_schema_from_function,
get_filtered_args,
)
from langchain.tools.base import BaseTool, StructuredTool
class Tool(BaseTool):
@@ -30,17 +25,30 @@ class Tool(BaseTool):
@property
def args(self) -> dict:
"""The tool's input arguments."""
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self.func).model # type: ignore
return get_filtered_args(inferred_model, self.func)
# For backwards compatibility, if the function signature is ambiguous,
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}
def _run(self, *args: Any, **kwargs: Any) -> str:
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
"""Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input)
# For backwards compatibility. The tool must be run with a single input
all_args = list(args) + list(kwargs.values())
if len(all_args) != 1:
raise ValueError(
f"Too many arguments to single-input tool {self.name}."
f" Args: {all_args}"
)
return tuple(all_args), {}
def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool."""
return self.func(*args, **kwargs)
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
@@ -48,7 +56,7 @@ class Tool(BaseTool):
# TODO: this is for backwards compatibility, remove in future
def __init__(
self, name: str, func: Callable[[str], str], description: str, **kwargs: Any
self, name: str, func: Callable, description: str, **kwargs: Any
) -> None:
"""Initialize tool."""
super(Tool, self).__init__(
@@ -107,22 +115,24 @@ def tool(
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable) -> Tool:
assert func.__doc__, "Function must have a docstring"
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
tool_ = Tool(
def _make_tool(func: Callable) -> BaseTool:
if infer_schema or args_schema is not None:
return StructuredTool.from_function(
func,
name=tool_name,
return_direct=return_direct,
args_schema=args_schema,
infer_schema=infer_schema,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
assert func.__doc__ is not None, "Function must have a docstring"
return Tool(
name=tool_name,
func=func,
args_schema=_args_schema,
description=description,
description=f"{tool_name} tool",
return_direct=return_direct,
)
return tool_
return _make_tool

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
from pydantic import (
BaseModel,
@@ -19,15 +19,6 @@ from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(run_input, str):
return (run_input,), {}
else:
return [], run_input
class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
@@ -81,14 +72,20 @@ def _create_subset_model(
return create_model(name, **fields) # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
def get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
inferred_model = validate_arguments(func).model # type: ignore
# Pydantic adds placeholder virtual fields we need to strip
@@ -102,12 +99,23 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
"""Interface LangChain tools must implement."""
name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool.
You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
"""Whether to return the tool's output directly. Setting this to True means
that after the tool is called, the AgentExecutor will stop looping.
"""
verbose: bool = False
"""Whether to log the tool's progress."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
"""Callback manager for this tool."""
class Config:
"""Configuration for this pydantic object."""
@@ -160,6 +168,14 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
else:
return (), tool_input
def run(
self,
tool_input: Union[str, Dict],
@@ -182,7 +198,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
**kwargs,
)
try:
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
@@ -224,8 +240,8 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = await self._arun(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
@@ -249,3 +265,62 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
def __call__(self, tool_input: Union[str, dict]) -> Any:
"""Make tool callable."""
return self.run(tool_input)
class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""
description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
"""The input arguments' schema."""
func: Callable[..., str]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function."""
@property
def args(self) -> dict:
"""The tool's input arguments."""
return self.args_schema.schema()["properties"]
def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool."""
return self.func(*args, **kwargs)
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
raise NotImplementedError("Tool does not support async")
@classmethod
def from_function(
cls,
func: Callable,
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
**kwargs: Any,
) -> StructuredTool:
name = name or func.__name__
description = description or func.__doc__
assert (
description is not None
), "Function must have a docstring if description not provided."
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{name}{signature(func)} - {description.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{name}Schema", func)
return cls(
name=name,
func=func,
args_schema=_args_schema,
description=description,
return_direct=return_direct,
**kwargs,
)

View File

@@ -1,7 +1,7 @@
"""Test tool utils."""
from datetime import datetime
from functools import partial
from typing import Optional, Type, Union
from typing import Any, Optional, Type, Union
from unittest.mock import MagicMock
import pydantic
@@ -16,7 +16,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool, SchemaAnnotationError
from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool
def test_unnamed_decorator() -> None:
@@ -27,7 +27,7 @@ def test_unnamed_decorator() -> None:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api"
assert not search_api.return_direct
assert search_api("test") == "API result"
@@ -145,7 +145,7 @@ def test_decorator_with_specified_schema() -> None:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func, Tool)
assert isinstance(tool_func, BaseTool)
assert tool_func.args_schema == _MockSchema
@@ -159,7 +159,7 @@ def test_decorated_function_schema_equivalent() -> None:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(structured_tool_input, Tool)
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.args_schema is not None
assert (
structured_tool_input.args_schema.schema()["properties"]
@@ -171,14 +171,14 @@ def test_decorated_function_schema_equivalent() -> None:
def test_structured_args_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
@tool
def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1}, {arg2}, {opt_arg}"
assert isinstance(structured_tool_input, Tool)
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.name == "structured_tool_input"
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
expected_result = "1, 0.001, {'foo': 'bar'}"
@@ -193,8 +193,9 @@ def test_structured_single_str_decorator_no_infer_schema() -> None:
"""Return the arguments directly."""
return f"{tool_input}"
assert isinstance(unstructured_tool_input, Tool)
assert isinstance(unstructured_tool_input, BaseTool)
assert unstructured_tool_input.args_schema is None
assert unstructured_tool_input.run("foo") == "foo"
def test_base_tool_inheritance_base_schema() -> None:
@@ -225,18 +226,18 @@ def test_tool_lambda_args_schema() -> None:
func=lambda tool_input: tool_input,
)
assert tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input"}}
expected_args = {"tool_input": {"type": "string"}}
assert tool.args == expected_args
def test_tool_lambda_multi_args_schema() -> None:
def test_structured_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = Tool(
tool = StructuredTool.from_function(
name="tool",
description="A tool",
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
)
assert tool.args_schema is None
assert tool.args_schema is not None
expected_args = {
"tool_input": {"title": "Tool Input"},
"other_arg": {"title": "Other Arg"},
@@ -268,7 +269,7 @@ def test_empty_args_decorator() -> None:
"""Return a constant."""
return "the empty result"
assert isinstance(empty_tool_input, Tool)
assert isinstance(empty_tool_input, BaseTool)
assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.args == {}
assert empty_tool_input.run({}) == "the empty result"
@@ -282,7 +283,7 @@ def test_named_tool_decorator() -> None:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search"
assert not search_api.return_direct
@@ -295,7 +296,7 @@ def test_named_tool_decorator_return_direct() -> None:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search"
assert search_api.return_direct
@@ -308,7 +309,7 @@ def test_unnamed_tool_decorator_return_direct() -> None:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api"
assert search_api.return_direct
@@ -325,7 +326,7 @@ def test_tool_with_kwargs() -> None:
"""Search the API for the query."""
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
result = search_api.run(
tool_input={
"arg_0": "foo",
@@ -398,9 +399,8 @@ async def test_create_async_tool() -> None:
@pytest.mark.parametrize(
"agent_cls",
[
ChatAgent,
ZeroShotAgent,
ConversationalChatAgent,
ZeroShotAgent,
ConversationalAgent,
ReActDocstoreAgent,
ReActTextWorldAgent,
@@ -423,3 +423,40 @@ def test_single_input_agent_raises_error_on_structured_tool(
f" multi-input tool {the_tool.name}.",
):
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore
@pytest.mark.parametrize(
"agent_cls",
[
ChatAgent,
],
)
def test_multi_input_agents_permit_structured_tool(agent_cls: Type[Agent]) -> None:
"""Test that newer agents permit multi-input tools."""
@tool
def the_tool(foo: str, bar: str) -> str:
"""Return the concat of foo and bar."""
return foo + bar
agent_cls._validate_tools([the_tool]) # type: ignore
def test_tool_no_args_specified_assumes_str() -> None:
"""Older tools could assume *args and **kwargs were passed in."""
def ambiguous_function(*args: Any, **kwargs: Any) -> str:
"""An ambiguously defined function."""
return args[0]
some_tool = Tool(
name="chain_run",
description="Run the chain",
func=ambiguous_function,
)
expected_args = {"tool_input": {"type": "string"}}
assert some_tool.args == expected_args
assert some_tool.run("foobar") == "foobar"
assert some_tool.run({"tool_input": "foobar"}) == "foobar"
with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
some_tool.run({"tool_input": "foobar", "other_input": "bar"})