mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-19 17:36:00 +00:00
Add validation on agent instantiation for multi-input tools (#3681)
Tradeoffs here: - No lint-time checking for compatibility - Differs from JS package - The signature inference, etc. in the base tool isn't simple - The `args_schema` is optional Pros: - Forwards compatibility retained - Doesn't break backwards compatibility - User doesn't have to think about which class to subclass (single base tool or dynamic `Tool` interface regardless of input) - No need to change the load_tools, etc. interfaces Co-authored-by: Hasan Patel <mangafield@gmail.com>
This commit is contained in:
@@ -115,6 +115,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def is_single_input(self) -> bool:
|
||||
"""Whether the tool only accepts a single input."""
|
||||
return len(self.args) == 1
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
if self.args_schema is not None:
|
||||
@@ -148,11 +153,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
return callback_manager or get_callback_manager()
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use the tool."""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
def run(
|
||||
@@ -183,7 +188,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
@@ -194,7 +199,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously."""
|
||||
self._parse_input(tool_input)
|
||||
if not self.verbose and verbose is not None:
|
||||
@@ -229,7 +234,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
str(observation),
|
||||
verbose=verbose_,
|
||||
color=color,
|
||||
name=self.name,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_end(
|
||||
@@ -237,6 +246,6 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
)
|
||||
return observation
|
||||
|
||||
def __call__(self, tool_input: str) -> str:
|
||||
def __call__(self, tool_input: Union[str, dict]) -> Any:
|
||||
"""Make tool callable."""
|
||||
return self.run(tool_input)
|
||||
|
Reference in New Issue
Block a user