Compare commits

...

6 Commits

Author SHA1 Message Date
vowelparrot
e209ebffc7 Add decorator 2023-04-16 20:51:42 -07:00
Harrison Chase
d3c92ed203 cr 2023-04-16 13:37:08 -07:00
Harrison Chase
e438969ab7 Merge branch 'master' into harrison/tools-refactor 2023-04-16 13:18:44 -07:00
Harrison Chase
db0a9c14cf cr 2023-04-16 09:10:56 -07:00
Harrison Chase
21a1ac36b5 cr 2023-04-16 09:06:00 -07:00
Harrison Chase
57f4309fa8 tools refactor 2023-04-15 18:11:02 -07:00
5 changed files with 118 additions and 22 deletions

View File

@@ -1,8 +1,8 @@
"""Interface for tools."""
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Union
from inspect import Parameter, signature
from typing import Any, Awaitable, Callable, List, Optional, Union
from langchain.tools.base import BaseTool
from langchain.tools.base import ArgInfo, BaseTool
class Tool(BaseTool):
@@ -47,6 +47,21 @@ class InvalidTool(BaseTool):
return f"{tool_name} is not a valid tool, try another one."
def _get_clean_type_name(annotation: Any) -> str:
if annotation == Parameter.empty:
return ""
if getattr(annotation, "__origin__", None) == Union:
types = ", ".join([_get_clean_type_name(arg) for arg in annotation.__args__])
return f"Union[{types}]"
if getattr(annotation, "__origin__", None) == Optional:
optional_type = _get_clean_type_name(annotation.__args__[0])
return f"Optional[{optional_type}]"
return annotation.__name__
def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
@@ -73,9 +88,15 @@ def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
assert func.__doc__, "Function must have a docstring"
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
func_signature = signature(func)
description = f"{tool_name}{func_signature} - {func.__doc__.strip()}"
tool_args: List[ArgInfo] = []
for param_name, parameter in func_signature.parameters.items():
annotation = _get_clean_type_name(parameter.annotation)
tool_args.append(ArgInfo(name=param_name, description=annotation))
tool_ = Tool(
name=tool_name,
tool_args=tool_args,
func=func,
description=description,
return_direct=return_direct,

View File

@@ -1,7 +1,8 @@
"""Base implementation for tools or skills."""
from abc import abstractmethod
from typing import Any, Optional
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Extra, Field, validator
@@ -9,11 +10,15 @@ from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
class BaseTool(BaseModel):
"""Class responsible for defining a tool or skill for an LLM."""
class ArgInfo(BaseModel):
name: str
description: str
class BaseTool(ABC, BaseModel):
name: str
description: str
tool_args: Optional[List[ArgInfo]] = None
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
@@ -24,6 +29,15 @@ class BaseTool(BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def args(self) -> List[ArgInfo]:
if self.tool_args is None:
# Get the name expected in the run function
var_names = inspect.signature(self._run).parameters.keys()
return [ArgInfo(name=name, description="") for name in var_names]
else:
return self.tool_args
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
@@ -35,39 +49,42 @@ class BaseTool(BaseModel):
return callback_manager or get_callback_manager()
@abstractmethod
def _run(self, tool_input: str) -> str:
def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool."""
@abstractmethod
async def _arun(self, tool_input: str) -> str:
async def _arun(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool asynchronously."""
def __call__(self, tool_input: str) -> str:
"""Make tools callable with str input."""
return self.run(tool_input)
def run(
self,
tool_input: str,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any
) -> str:
"""Run the tool."""
if isinstance(tool_input, str):
if len(self.args) > 1:
raise ValueError("Cannot pass in a string when > 1 argument expected")
key = self.args[0].name
run_input = {key: tool_input}
else:
run_input = tool_input
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input,
str(run_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
observation = self._run(tool_input)
observation = self._run(**run_input)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
@@ -78,13 +95,20 @@ class BaseTool(BaseModel):
async def arun(
self,
tool_input: str,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any
) -> str:
"""Run the tool asynchronously."""
if isinstance(tool_input, str):
if len(self.args) > 1:
raise ValueError("Cannot pass in a string when > 1 argument expected")
key = self.args[0].name
run_input = {key: tool_input}
else:
run_input = tool_input
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
@@ -92,7 +116,7 @@ class BaseTool(BaseModel):
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input,
str(run_input),
verbose=verbose_,
color=start_color,
**kwargs,
@@ -100,14 +124,14 @@ class BaseTool(BaseModel):
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input,
str(run_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
observation = await self._arun(tool_input)
observation = await self._arun(**run_input)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
@@ -123,3 +147,7 @@ class BaseTool(BaseModel):
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
def __call__(self, tool_input: str) -> str:
"""Make tool callable."""
return self.run(tool_input)

View File

@@ -0,0 +1,20 @@
from typing import List
from langchain.tools.base import ArgInfo, BaseTool
class ReadFileTool(BaseTool):
name: str = "read_file"
tool_args: List[ArgInfo] = [ArgInfo(name="file_path", description="name of file")]
description: str = "Read file from disk"
def _run(self, file_path: str) -> str:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return content
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError

View File

@@ -0,0 +1,27 @@
import os
from typing import List
from langchain.tools.base import ArgInfo, BaseTool
class WriteFileTool(BaseTool):
name: str = "write_file"
tool_args: List[ArgInfo] = [
ArgInfo(name="file_path", description="name of file"),
ArgInfo(name="text", description="text to write to file"),
]
description: str = "Write file to disk"
def _run(self, file_path: str, text: str) -> str:
try:
directory = os.path.dirname(file_path)
if not os.path.exists(directory) and directory:
os.makedirs(directory)
with open(file_path, "w", encoding="utf-8") as f:
f.write(text)
return "File written to successfully."
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError