tools refactor (#2961)

Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>
This commit is contained in:
Harrison Chase
2023-04-17 21:35:29 -07:00
committed by GitHub
parent 7a8c935b90
commit db968284f8
13 changed files with 833 additions and 143 deletions

View File

@@ -1,19 +1,84 @@
"""Base implementation for tools or skills."""
from abc import abstractmethod
from typing import Any, Optional
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from pydantic import BaseModel, Extra, Field, validator
from pydantic import BaseModel, Extra, Field, create_model, validator
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
class BaseTool(BaseModel):
"""Class responsible for defining a tool or skill for an LLM."""
def create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]:
"""Create a pydantic model type from a function's signature."""
signature_ = inspect.signature(run_func)
field_definitions: Dict[str, Any] = {}
for name, param in signature_.parameters.items():
if name == "self":
continue
default_value = (
param.default if param.default != inspect.Parameter.empty else None
)
annotation = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
# Handle functions with *args in the signature
if param.kind == inspect.Parameter.VAR_POSITIONAL:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_positional": True}),
)
# handle functions with **kwargs in the signature
elif param.kind == inspect.Parameter.VAR_KEYWORD:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_keyword": True}),
)
# Handle all other named parameters
else:
is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY
field_definitions[name] = (
annotation,
Field(
default=default_value, extra={"is_keyword_only": is_keyword_only}
),
)
return create_model("ArgsModel", **field_definitions) # type: ignore
def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]:
args = []
kwargs = {}
for name, field in model.__fields__.items():
value = getattr(model, name)
# Handle *args in the function signature
if field.field_info.extra.get("extra", {}).get("is_var_positional"):
if isinstance(value, str):
# Base case for backwards compatability
args.append(value)
elif value is not None:
args.extend(value)
# Handle **kwargs in the function signature
elif field.field_info.extra.get("extra", {}).get("is_var_keyword"):
if value is not None:
kwargs.update(value)
elif field.field_info.extra.get("extra", {}).get("is_keyword_only"):
kwargs[name] = value
else:
args.append(value)
return tuple(args), kwargs
class BaseTool(ABC, BaseModel):
"""Interface LangChain tools must implement."""
name: str
description: str
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
@@ -24,6 +89,33 @@ class BaseTool(BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def args(self) -> Type[BaseModel]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
return create_args_schema_model_from_signature(self._run)
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> BaseModel:
"""Convert tool input to pydantic model."""
pydantic_input_type = self.args
if isinstance(tool_input, str):
# For backwards compatibility, a tool that only takes
# a single string input will be converted to a dict.
# to be validated.
field_name = next(iter(pydantic_input_type.__fields__))
tool_input = {field_name: tool_input}
if pydantic_input_type is not None:
return pydantic_input_type.parse_obj(tool_input)
else:
raise ValueError(
f"args_schema required for tool {self.name} in order to"
f" accept input of type {type(tool_input)}"
)
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
@@ -35,39 +127,37 @@ class BaseTool(BaseModel):
return callback_manager or get_callback_manager()
@abstractmethod
def _run(self, tool_input: str) -> str:
def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool."""
@abstractmethod
async def _arun(self, tool_input: str) -> str:
async def _arun(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool asynchronously."""
def __call__(self, tool_input: str) -> str:
"""Make tools callable with str input."""
return self.run(tool_input)
def run(
self,
tool_input: str,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any
**kwargs: Any,
) -> str:
"""Run the tool."""
run_input = self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input,
str(run_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
observation = self._run(tool_input)
args, kwargs = _to_args_and_kwargs(run_input)
observation = self._run(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
@@ -78,13 +168,14 @@ class BaseTool(BaseModel):
async def arun(
self,
tool_input: str,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any
**kwargs: Any,
) -> str:
"""Run the tool asynchronously."""
run_input = self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
@@ -92,7 +183,7 @@ class BaseTool(BaseModel):
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input,
str(run_input.dict()),
verbose=verbose_,
color=start_color,
**kwargs,
@@ -100,14 +191,15 @@ class BaseTool(BaseModel):
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input,
str(run_input.dict()),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
observation = await self._arun(tool_input)
args, kwargs = _to_args_and_kwargs(run_input)
observation = await self._arun(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
@@ -123,3 +215,7 @@ class BaseTool(BaseModel):
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
def __call__(self, tool_input: str) -> str:
"""Make tool callable."""
return self.run(tool_input)

View File

@@ -0,0 +1,29 @@
from typing import Type
from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool
class ReadFileInput(BaseModel):
"""Input for ReadFileTool."""
file_path: str = Field(..., description="name of file")
class ReadFileTool(BaseTool):
name: str = "read_file"
tool_args: Type[BaseModel] = ReadFileInput
description: str = "Read file from disk"
def _run(self, file_path: str) -> str:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return content
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
# TODO: Add aiofiles method
raise NotImplementedError

View File

@@ -0,0 +1,34 @@
import os
from typing import Type
from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool
class WriteFileInput(BaseModel):
"""Input for WriteFileTool."""
file_path: str = Field(..., description="name of file")
text: str = Field(..., description="text to write to file")
class WriteFileTool(BaseTool):
name: str = "write_file"
tool_args: Type[BaseModel] = WriteFileInput
description: str = "Write file to disk"
def _run(self, file_path: str, text: str) -> str:
try:
directory = os.path.dirname(file_path)
if not os.path.exists(directory) and directory:
os.makedirs(directory)
with open(file_path, "w", encoding="utf-8") as f:
f.write(text)
return "File written to successfully."
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
# TODO: Add aiofiles method
raise NotImplementedError