mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
Compare commits
6 Commits
mdrxy/vers
...
vwp/tools-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e209ebffc7 | ||
|
|
d3c92ed203 | ||
|
|
e438969ab7 | ||
|
|
db0a9c14cf | ||
|
|
21a1ac36b5 | ||
|
|
57f4309fa8 |
@@ -1,8 +1,8 @@
|
|||||||
"""Interface for tools."""
|
"""Interface for tools."""
|
||||||
from inspect import signature
|
from inspect import Parameter, signature
|
||||||
from typing import Any, Awaitable, Callable, Optional, Union
|
from typing import Any, Awaitable, Callable, List, Optional, Union
|
||||||
|
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import ArgInfo, BaseTool
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseTool):
|
class Tool(BaseTool):
|
||||||
@@ -47,6 +47,21 @@ class InvalidTool(BaseTool):
|
|||||||
return f"{tool_name} is not a valid tool, try another one."
|
return f"{tool_name} is not a valid tool, try another one."
|
||||||
|
|
||||||
|
|
||||||
|
def _get_clean_type_name(annotation: Any) -> str:
|
||||||
|
if annotation == Parameter.empty:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if getattr(annotation, "__origin__", None) == Union:
|
||||||
|
types = ", ".join([_get_clean_type_name(arg) for arg in annotation.__args__])
|
||||||
|
return f"Union[{types}]"
|
||||||
|
|
||||||
|
if getattr(annotation, "__origin__", None) == Optional:
|
||||||
|
optional_type = _get_clean_type_name(annotation.__args__[0])
|
||||||
|
return f"Optional[{optional_type}]"
|
||||||
|
|
||||||
|
return annotation.__name__
|
||||||
|
|
||||||
|
|
||||||
def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
|
def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
|
||||||
"""Make tools out of functions, can be used with or without arguments.
|
"""Make tools out of functions, can be used with or without arguments.
|
||||||
|
|
||||||
@@ -73,9 +88,15 @@ def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
|
|||||||
assert func.__doc__, "Function must have a docstring"
|
assert func.__doc__, "Function must have a docstring"
|
||||||
# Description example:
|
# Description example:
|
||||||
# search_api(query: str) - Searches the API for the query.
|
# search_api(query: str) - Searches the API for the query.
|
||||||
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
func_signature = signature(func)
|
||||||
|
description = f"{tool_name}{func_signature} - {func.__doc__.strip()}"
|
||||||
|
tool_args: List[ArgInfo] = []
|
||||||
|
for param_name, parameter in func_signature.parameters.items():
|
||||||
|
annotation = _get_clean_type_name(parameter.annotation)
|
||||||
|
tool_args.append(ArgInfo(name=param_name, description=annotation))
|
||||||
tool_ = Tool(
|
tool_ = Tool(
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
|
tool_args=tool_args,
|
||||||
func=func,
|
func=func,
|
||||||
description=description,
|
description=description,
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""Base implementation for tools or skills."""
|
"""Base implementation for tools or skills."""
|
||||||
|
|
||||||
from abc import abstractmethod
|
import inspect
|
||||||
from typing import Any, Optional
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, validator
|
from pydantic import BaseModel, Extra, Field, validator
|
||||||
|
|
||||||
@@ -9,11 +10,15 @@ from langchain.callbacks import get_callback_manager
|
|||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(BaseModel):
|
class ArgInfo(BaseModel):
|
||||||
"""Class responsible for defining a tool or skill for an LLM."""
|
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTool(ABC, BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
tool_args: Optional[List[ArgInfo]] = None
|
||||||
return_direct: bool = False
|
return_direct: bool = False
|
||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||||
@@ -24,6 +29,15 @@ class BaseTool(BaseModel):
|
|||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args(self) -> List[ArgInfo]:
|
||||||
|
if self.tool_args is None:
|
||||||
|
# Get the name expected in the run function
|
||||||
|
var_names = inspect.signature(self._run).parameters.keys()
|
||||||
|
return [ArgInfo(name=name, description="") for name in var_names]
|
||||||
|
else:
|
||||||
|
return self.tool_args
|
||||||
|
|
||||||
@validator("callback_manager", pre=True, always=True)
|
@validator("callback_manager", pre=True, always=True)
|
||||||
def set_callback_manager(
|
def set_callback_manager(
|
||||||
cls, callback_manager: Optional[BaseCallbackManager]
|
cls, callback_manager: Optional[BaseCallbackManager]
|
||||||
@@ -35,39 +49,42 @@ class BaseTool(BaseModel):
|
|||||||
return callback_manager or get_callback_manager()
|
return callback_manager or get_callback_manager()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _run(self, tool_input: str) -> str:
|
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _arun(self, tool_input: str) -> str:
|
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
|
|
||||||
def __call__(self, tool_input: str) -> str:
|
|
||||||
"""Make tools callable with str input."""
|
|
||||||
return self.run(tool_input)
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
tool_input: str,
|
tool_input: Union[str, Dict],
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
start_color: Optional[str] = "green",
|
start_color: Optional[str] = "green",
|
||||||
color: Optional[str] = "green",
|
color: Optional[str] = "green",
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run the tool."""
|
"""Run the tool."""
|
||||||
|
if isinstance(tool_input, str):
|
||||||
|
if len(self.args) > 1:
|
||||||
|
raise ValueError("Cannot pass in a string when > 1 argument expected")
|
||||||
|
key = self.args[0].name
|
||||||
|
run_input = {key: tool_input}
|
||||||
|
else:
|
||||||
|
run_input = tool_input
|
||||||
if not self.verbose and verbose is not None:
|
if not self.verbose and verbose is not None:
|
||||||
verbose_ = verbose
|
verbose_ = verbose
|
||||||
else:
|
else:
|
||||||
verbose_ = self.verbose
|
verbose_ = self.verbose
|
||||||
self.callback_manager.on_tool_start(
|
self.callback_manager.on_tool_start(
|
||||||
{"name": self.name, "description": self.description},
|
{"name": self.name, "description": self.description},
|
||||||
tool_input,
|
str(run_input),
|
||||||
verbose=verbose_,
|
verbose=verbose_,
|
||||||
color=start_color,
|
color=start_color,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
observation = self._run(tool_input)
|
observation = self._run(**run_input)
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||||
raise e
|
raise e
|
||||||
@@ -78,13 +95,20 @@ class BaseTool(BaseModel):
|
|||||||
|
|
||||||
async def arun(
|
async def arun(
|
||||||
self,
|
self,
|
||||||
tool_input: str,
|
tool_input: Union[str, Dict],
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
start_color: Optional[str] = "green",
|
start_color: Optional[str] = "green",
|
||||||
color: Optional[str] = "green",
|
color: Optional[str] = "green",
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run the tool asynchronously."""
|
"""Run the tool asynchronously."""
|
||||||
|
if isinstance(tool_input, str):
|
||||||
|
if len(self.args) > 1:
|
||||||
|
raise ValueError("Cannot pass in a string when > 1 argument expected")
|
||||||
|
key = self.args[0].name
|
||||||
|
run_input = {key: tool_input}
|
||||||
|
else:
|
||||||
|
run_input = tool_input
|
||||||
if not self.verbose and verbose is not None:
|
if not self.verbose and verbose is not None:
|
||||||
verbose_ = verbose
|
verbose_ = verbose
|
||||||
else:
|
else:
|
||||||
@@ -92,7 +116,7 @@ class BaseTool(BaseModel):
|
|||||||
if self.callback_manager.is_async:
|
if self.callback_manager.is_async:
|
||||||
await self.callback_manager.on_tool_start(
|
await self.callback_manager.on_tool_start(
|
||||||
{"name": self.name, "description": self.description},
|
{"name": self.name, "description": self.description},
|
||||||
tool_input,
|
str(run_input),
|
||||||
verbose=verbose_,
|
verbose=verbose_,
|
||||||
color=start_color,
|
color=start_color,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -100,14 +124,14 @@ class BaseTool(BaseModel):
|
|||||||
else:
|
else:
|
||||||
self.callback_manager.on_tool_start(
|
self.callback_manager.on_tool_start(
|
||||||
{"name": self.name, "description": self.description},
|
{"name": self.name, "description": self.description},
|
||||||
tool_input,
|
str(run_input),
|
||||||
verbose=verbose_,
|
verbose=verbose_,
|
||||||
color=start_color,
|
color=start_color,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# We then call the tool on the tool input to get an observation
|
# We then call the tool on the tool input to get an observation
|
||||||
observation = await self._arun(tool_input)
|
observation = await self._arun(**run_input)
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
if self.callback_manager.is_async:
|
if self.callback_manager.is_async:
|
||||||
await self.callback_manager.on_tool_error(e, verbose=verbose_)
|
await self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||||
@@ -123,3 +147,7 @@ class BaseTool(BaseModel):
|
|||||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||||
)
|
)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
def __call__(self, tool_input: str) -> str:
|
||||||
|
"""Make tool callable."""
|
||||||
|
return self.run(tool_input)
|
||||||
|
|||||||
0
langchain/tools/file_management/__init__.py
Normal file
0
langchain/tools/file_management/__init__.py
Normal file
20
langchain/tools/file_management/read.py
Normal file
20
langchain/tools/file_management/read.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.tools.base import ArgInfo, BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFileTool(BaseTool):
|
||||||
|
name: str = "read_file"
|
||||||
|
tool_args: List[ArgInfo] = [ArgInfo(name="file_path", description="name of file")]
|
||||||
|
description: str = "Read file from disk"
|
||||||
|
|
||||||
|
def _run(self, file_path: str) -> str:
|
||||||
|
try:
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
return content
|
||||||
|
except Exception as e:
|
||||||
|
return "Error: " + str(e)
|
||||||
|
|
||||||
|
async def _arun(self, tool_input: str) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
27
langchain/tools/file_management/write.py
Normal file
27
langchain/tools/file_management/write.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.tools.base import ArgInfo, BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class WriteFileTool(BaseTool):
|
||||||
|
name: str = "write_file"
|
||||||
|
tool_args: List[ArgInfo] = [
|
||||||
|
ArgInfo(name="file_path", description="name of file"),
|
||||||
|
ArgInfo(name="text", description="text to write to file"),
|
||||||
|
]
|
||||||
|
description: str = "Write file to disk"
|
||||||
|
|
||||||
|
def _run(self, file_path: str, text: str) -> str:
|
||||||
|
try:
|
||||||
|
directory = os.path.dirname(file_path)
|
||||||
|
if not os.path.exists(directory) and directory:
|
||||||
|
os.makedirs(directory)
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(text)
|
||||||
|
return "File written to successfully."
|
||||||
|
except Exception as e:
|
||||||
|
return "Error: " + str(e)
|
||||||
|
|
||||||
|
async def _arun(self, tool_input: str) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
Reference in New Issue
Block a user