Compare commits

...

6 Commits

Author SHA1 Message Date
vowelparrot
e209ebffc7 Add decorator 2023-04-16 20:51:42 -07:00
Harrison Chase
d3c92ed203 cr 2023-04-16 13:37:08 -07:00
Harrison Chase
e438969ab7 Merge branch 'master' into harrison/tools-refactor 2023-04-16 13:18:44 -07:00
Harrison Chase
db0a9c14cf cr 2023-04-16 09:10:56 -07:00
Harrison Chase
21a1ac36b5 cr 2023-04-16 09:06:00 -07:00
Harrison Chase
57f4309fa8 tools refactor 2023-04-15 18:11:02 -07:00
5 changed files with 118 additions and 22 deletions

View File

@@ -1,8 +1,8 @@
"""Interface for tools.""" """Interface for tools."""
from inspect import signature from inspect import Parameter, signature
from typing import Any, Awaitable, Callable, Optional, Union from typing import Any, Awaitable, Callable, List, Optional, Union
from langchain.tools.base import BaseTool from langchain.tools.base import ArgInfo, BaseTool
class Tool(BaseTool): class Tool(BaseTool):
@@ -47,6 +47,21 @@ class InvalidTool(BaseTool):
return f"{tool_name} is not a valid tool, try another one." return f"{tool_name} is not a valid tool, try another one."
def _get_clean_type_name(annotation: Any) -> str:
if annotation == Parameter.empty:
return ""
if getattr(annotation, "__origin__", None) == Union:
types = ", ".join([_get_clean_type_name(arg) for arg in annotation.__args__])
return f"Union[{types}]"
if getattr(annotation, "__origin__", None) == Optional:
optional_type = _get_clean_type_name(annotation.__args__[0])
return f"Optional[{optional_type}]"
return annotation.__name__
def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable: def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
"""Make tools out of functions, can be used with or without arguments. """Make tools out of functions, can be used with or without arguments.
@@ -73,9 +88,15 @@ def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
assert func.__doc__, "Function must have a docstring" assert func.__doc__, "Function must have a docstring"
# Description example: # Description example:
# search_api(query: str) - Searches the API for the query. # search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" func_signature = signature(func)
description = f"{tool_name}{func_signature} - {func.__doc__.strip()}"
tool_args: List[ArgInfo] = []
for param_name, parameter in func_signature.parameters.items():
annotation = _get_clean_type_name(parameter.annotation)
tool_args.append(ArgInfo(name=param_name, description=annotation))
tool_ = Tool( tool_ = Tool(
name=tool_name, name=tool_name,
tool_args=tool_args,
func=func, func=func,
description=description, description=description,
return_direct=return_direct, return_direct=return_direct,

View File

@@ -1,7 +1,8 @@
"""Base implementation for tools or skills.""" """Base implementation for tools or skills."""
from abc import abstractmethod import inspect
from typing import Any, Optional from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Extra, Field, validator from pydantic import BaseModel, Extra, Field, validator
@@ -9,11 +10,15 @@ from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
class BaseTool(BaseModel): class ArgInfo(BaseModel):
"""Class responsible for defining a tool or skill for an LLM."""
name: str name: str
description: str description: str
class BaseTool(ABC, BaseModel):
name: str
description: str
tool_args: Optional[List[ArgInfo]] = None
return_direct: bool = False return_direct: bool = False
verbose: bool = False verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
@@ -24,6 +29,15 @@ class BaseTool(BaseModel):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def args(self) -> List[ArgInfo]:
if self.tool_args is None:
# Get the name expected in the run function
var_names = inspect.signature(self._run).parameters.keys()
return [ArgInfo(name=name, description="") for name in var_names]
else:
return self.tool_args
@validator("callback_manager", pre=True, always=True) @validator("callback_manager", pre=True, always=True)
def set_callback_manager( def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager] cls, callback_manager: Optional[BaseCallbackManager]
@@ -35,39 +49,42 @@ class BaseTool(BaseModel):
return callback_manager or get_callback_manager() return callback_manager or get_callback_manager()
@abstractmethod @abstractmethod
def _run(self, tool_input: str) -> str: def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool.""" """Use the tool."""
@abstractmethod @abstractmethod
async def _arun(self, tool_input: str) -> str: async def _arun(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
def __call__(self, tool_input: str) -> str:
"""Make tools callable with str input."""
return self.run(tool_input)
def run( def run(
self, self,
tool_input: str, tool_input: Union[str, Dict],
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
start_color: Optional[str] = "green", start_color: Optional[str] = "green",
color: Optional[str] = "green", color: Optional[str] = "green",
**kwargs: Any **kwargs: Any
) -> str: ) -> str:
"""Run the tool.""" """Run the tool."""
if isinstance(tool_input, str):
if len(self.args) > 1:
raise ValueError("Cannot pass in a string when > 1 argument expected")
key = self.args[0].name
run_input = {key: tool_input}
else:
run_input = tool_input
if not self.verbose and verbose is not None: if not self.verbose and verbose is not None:
verbose_ = verbose verbose_ = verbose
else: else:
verbose_ = self.verbose verbose_ = self.verbose
self.callback_manager.on_tool_start( self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
tool_input, str(run_input),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
) )
try: try:
observation = self._run(tool_input) observation = self._run(**run_input)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_) self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e raise e
@@ -78,13 +95,20 @@ class BaseTool(BaseModel):
async def arun( async def arun(
self, self,
tool_input: str, tool_input: Union[str, Dict],
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
start_color: Optional[str] = "green", start_color: Optional[str] = "green",
color: Optional[str] = "green", color: Optional[str] = "green",
**kwargs: Any **kwargs: Any
) -> str: ) -> str:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
if isinstance(tool_input, str):
if len(self.args) > 1:
raise ValueError("Cannot pass in a string when > 1 argument expected")
key = self.args[0].name
run_input = {key: tool_input}
else:
run_input = tool_input
if not self.verbose and verbose is not None: if not self.verbose and verbose is not None:
verbose_ = verbose verbose_ = verbose
else: else:
@@ -92,7 +116,7 @@ class BaseTool(BaseModel):
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_tool_start( await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
tool_input, str(run_input),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
@@ -100,14 +124,14 @@ class BaseTool(BaseModel):
else: else:
self.callback_manager.on_tool_start( self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
tool_input, str(run_input),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
) )
try: try:
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
observation = await self._arun(tool_input) observation = await self._arun(**run_input)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_) await self.callback_manager.on_tool_error(e, verbose=verbose_)
@@ -123,3 +147,7 @@ class BaseTool(BaseModel):
observation, verbose=verbose_, color=color, name=self.name, **kwargs observation, verbose=verbose_, color=color, name=self.name, **kwargs
) )
return observation return observation
def __call__(self, tool_input: str) -> str:
"""Make tool callable."""
return self.run(tool_input)

View File

@@ -0,0 +1,20 @@
from typing import List
from langchain.tools.base import ArgInfo, BaseTool
class ReadFileTool(BaseTool):
name: str = "read_file"
tool_args: List[ArgInfo] = [ArgInfo(name="file_path", description="name of file")]
description: str = "Read file from disk"
def _run(self, file_path: str) -> str:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return content
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError

View File

@@ -0,0 +1,27 @@
import os
from typing import List
from langchain.tools.base import ArgInfo, BaseTool
class WriteFileTool(BaseTool):
name: str = "write_file"
tool_args: List[ArgInfo] = [
ArgInfo(name="file_path", description="name of file"),
ArgInfo(name="text", description="text to write to file"),
]
description: str = "Write file to disk"
def _run(self, file_path: str, text: str) -> str:
try:
directory = os.path.dirname(file_path)
if not os.path.exists(directory) and directory:
os.makedirs(directory)
with open(file_path, "w", encoding="utf-8") as f:
f.write(text)
return "File written to successfully."
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError