Update Tool Input (#3103)

- Remove dynamic model creation in the `args()` property. _Only infer
for the decorator (and add an argument to NOT infer if someone wishes to
only pass as a string)_
- Update the validation example to make it less likely to be
misinterpreted as a "safe" way to run a repl


There is one example of "Multi-argument tools" in the custom_tools.ipynb
from yesterday, but we could add more. The output parsing for the base
MRKL agent hasn't been adapted to handle structured args at this point
in time

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Zander Chase
2023-04-18 18:18:33 -07:00
committed by GitHub
parent 19116010ee
commit 90ef705ced
6 changed files with 216 additions and 213 deletions

View File

@@ -1,75 +1,21 @@
"""Base implementation for tools or skills."""
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
from pydantic import BaseModel, Extra, Field, create_model, validator
from pydantic import BaseModel, Extra, Field, validator
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
def create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]:
"""Create a pydantic model type from a function's signature."""
signature_ = inspect.signature(run_func)
field_definitions: Dict[str, Any] = {}
for name, param in signature_.parameters.items():
if name == "self":
continue
default_value = (
param.default if param.default != inspect.Parameter.empty else None
)
annotation = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
# Handle functions with *args in the signature
if param.kind == inspect.Parameter.VAR_POSITIONAL:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_positional": True}),
)
# handle functions with **kwargs in the signature
elif param.kind == inspect.Parameter.VAR_KEYWORD:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_keyword": True}),
)
# Handle all other named parameters
else:
is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY
field_definitions[name] = (
annotation,
Field(
default=default_value, extra={"is_keyword_only": is_keyword_only}
),
)
return create_model("ArgsModel", **field_definitions) # type: ignore
def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]:
args = []
kwargs = {}
for name, field in model.__fields__.items():
value = getattr(model, name)
# Handle *args in the function signature
if field.field_info.extra.get("extra", {}).get("is_var_positional"):
if isinstance(value, str):
# Base case for backwards compatability
args.append(value)
elif value is not None:
args.extend(value)
# Handle **kwargs in the function signature
elif field.field_info.extra.get("extra", {}).get("is_var_keyword"):
if value is not None:
kwargs.update(value)
elif field.field_info.extra.get("extra", {}).get("is_keyword_only"):
kwargs[name] = value
else:
args.append(value)
return tuple(args), kwargs
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(run_input, str):
return (run_input,), {}
else:
return [], run_input
class BaseTool(ABC, BaseModel):
@@ -90,31 +36,28 @@ class BaseTool(ABC, BaseModel):
arbitrary_types_allowed = True
@property
def args(self) -> Type[BaseModel]:
def args(self) -> Union[Type[BaseModel], Type[str]]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
return create_args_schema_model_from_signature(self._run)
return str if self.args_schema is None else self.args_schema
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> BaseModel:
) -> None:
"""Convert tool input to pydantic model."""
pydantic_input_type = self.args
input_args = self.args
if isinstance(tool_input, str):
# For backwards compatibility, a tool that only takes
# a single string input will be converted to a dict.
# to be validated.
field_name = next(iter(pydantic_input_type.__fields__))
tool_input = {field_name: tool_input}
if pydantic_input_type is not None:
return pydantic_input_type.parse_obj(tool_input)
if issubclass(input_args, BaseModel):
key_ = next(iter(input_args.__fields__.keys()))
input_args.validate({key_: tool_input})
else:
raise ValueError(
f"args_schema required for tool {self.name} in order to"
f" accept input of type {type(tool_input)}"
)
if issubclass(input_args, BaseModel):
input_args.validate(tool_input)
else:
raise ValueError(
f"args_schema required for tool {self.name} in order to"
f" accept input of type {type(tool_input)}"
)
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
@@ -143,20 +86,20 @@ class BaseTool(ABC, BaseModel):
**kwargs: Any,
) -> str:
"""Run the tool."""
run_input = self._parse_input(tool_input)
self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
str(run_input),
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
args, kwargs = _to_args_and_kwargs(run_input)
args, kwargs = _to_args_and_kwargs(tool_input)
observation = self._run(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
@@ -175,7 +118,7 @@ class BaseTool(ABC, BaseModel):
**kwargs: Any,
) -> str:
"""Run the tool asynchronously."""
run_input = self._parse_input(tool_input)
self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
@@ -183,7 +126,7 @@ class BaseTool(ABC, BaseModel):
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
str(run_input.dict()),
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
@@ -191,14 +134,14 @@ class BaseTool(ABC, BaseModel):
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
str(run_input.dict()),
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(run_input)
args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:

View File

@@ -29,6 +29,6 @@ class WriteFileTool(BaseTool):
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
async def _arun(self, file_path: str, text: str) -> str:
# TODO: Add aiofiles method
raise NotImplementedError