Compare commits

...

2 Commits

Author SHA1 Message Date
vowelparrot
30b42514aa Filter args when function is only *args and **kwargs 2023-04-27 17:36:07 -07:00
Zander Chase
027638af8a Add validation on agent instantiation for multi-input tools (#3681)
Tradeoffs here:
- No lint-time checking for compatibility
- Differs from JS package
- The signature inference, etc. in the base tool isn't simple
- The `args_schema` is optional

Pros:
- Forwards compatibility retained
- Doesn't break backwards compatibility
- User doesn't have to think about which class to subclass (single base
tool or dynamic `Tool` interface regardless of input)
-  No need to change the load_tools, etc. interfaces

Co-authored-by: Hasan Patel <mangafield@gmail.com>
2023-04-27 16:59:15 -07:00
3 changed files with 69 additions and 23 deletions

View File

@@ -1,7 +1,7 @@
"""Interface for tools."""
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
from pydantic import BaseModel, validate_arguments, validator
@@ -30,17 +30,39 @@ class Tool(BaseTool):
@property
def args(self) -> dict:
"""The tool's input arguments."""
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self.func).model # type: ignore
return get_filtered_args(inferred_model, self.func)
inferred_model = validate_arguments(self.func).model # type: ignore
filtered_args = get_filtered_args(
inferred_model, self.func, invalid_args={"args", "kwargs"}
)
if filtered_args:
return filtered_args
# For backwards compatability, if the function signature is ambiguous,
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}
def _run(self, *args: Any, **kwargs: Any) -> str:
def _to_args_and_kwargs(self, tool_input: str | Dict) -> Tuple[Tuple, Dict]:
"""Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input)
if self.is_single_input:
# For backwards compatability. If no schema is inferred,
# the tool must assume it should be run with a single input
all_args = list(args) + list(kwargs.values())
if len(all_args) != 1:
raise ValueError(
f"Too many arguments to single-input tool {self.name}."
f" Args: {all_args}"
)
return tuple(all_args), {}
return args, kwargs
def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool."""
return self.func(*args, **kwargs)
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
@@ -48,7 +70,7 @@ class Tool(BaseTool):
# TODO: this is for backwards compatibility, remove in future
def __init__(
self, name: str, func: Callable[[str], str], description: str, **kwargs: Any
self, name: str, func: Callable, description: str, **kwargs: Any
) -> None:
"""Initialize tool."""
super(Tool, self).__init__(

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union
from pydantic import (
BaseModel,
@@ -19,15 +19,6 @@ from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(run_input, str):
return (run_input,), {}
else:
return [], run_input
class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
@@ -81,11 +72,16 @@ def _create_subset_model(
return create_model(name, **fields) # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
def get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
invalid_args: Optional[Set[str]] = None,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}
invalid_args = invalid_args or set()
return {k: schema[k] for k in valid_keys if k not in invalid_args}
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
@@ -160,6 +156,14 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
else:
return (), tool_input
def run(
self,
tool_input: Union[str, Dict],
@@ -182,7 +186,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
**kwargs,
)
try:
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
@@ -224,8 +228,8 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = await self._arun(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)

View File

@@ -1,7 +1,7 @@
"""Test tool utils."""
from datetime import datetime
from functools import partial
from typing import Optional, Type, Union
from typing import Any, Optional, Type, Union
from unittest.mock import MagicMock
import pydantic
@@ -423,3 +423,23 @@ def test_single_input_agent_raises_error_on_structured_tool(
f" multi-input tool {the_tool.name}.",
):
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore
def test_tool_no_args_specified_assumes_str() -> None:
"""Older tools could assume *args and **kwargs were passed in."""
def ambiguous_function(*args: Any, **kwargs: Any) -> str:
"""An ambiguously defined function."""
return args[0]
some_tool = Tool(
name="chain_run",
description="Run the chain",
func=ambiguous_function,
)
expected_args = {"tool_input": {"type": "string"}}
assert some_tool.args == expected_args
assert some_tool.run("foobar") == "foobar"
assert some_tool.run({"tool_input": "foobar"}) == "foobar"
with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
some_tool.run({"tool_input": "foobar", "other_input": "bar"})