mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
Wfh/async tool (#9878)
Co-authored-by: Daniel Brenot <dbrenot@pelmorex.com> Co-authored-by: Daniel <daniel.alexander.brenot@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
7bba1d911b
commit
d799963870
@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
@ -437,7 +438,7 @@ class Tool(BaseTool):
|
||||
"""Tool that takes in function or coroutine directly."""
|
||||
|
||||
description: str = ""
|
||||
func: Callable[..., str]
|
||||
func: Optional[Callable[..., str]]
|
||||
"""The function to run when the tool is called."""
|
||||
coroutine: Optional[Callable[..., Awaitable[str]]] = None
|
||||
"""The asynchronous version of the function."""
|
||||
@ -488,16 +489,18 @@ class Tool(BaseTool):
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||
return (
|
||||
self.func(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
if self.func:
|
||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||
return (
|
||||
self.func(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else self.func(*args, **kwargs)
|
||||
)
|
||||
if new_argument_supported
|
||||
else self.func(*args, **kwargs)
|
||||
)
|
||||
raise NotImplementedError("Tool does not support sync")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
@ -523,7 +526,7 @@ class Tool(BaseTool):
|
||||
|
||||
# TODO: this is for backwards compatibility, remove in future
|
||||
def __init__(
|
||||
self, name: str, func: Callable, description: str, **kwargs: Any
|
||||
self, name: str, func: Optional[Callable], description: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize tool."""
|
||||
super(Tool, self).__init__(
|
||||
@ -533,17 +536,23 @@ class Tool(BaseTool):
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
func: Callable,
|
||||
func: Optional[Callable],
|
||||
name: str, # We keep these required to support backwards compatibility
|
||||
description: str,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
coroutine: Optional[
|
||||
Callable[..., Awaitable[Any]]
|
||||
] = None, # This is last for compatibility, but should be after func
|
||||
**kwargs: Any,
|
||||
) -> Tool:
|
||||
"""Initialize tool from a function."""
|
||||
if func is None and coroutine is None:
|
||||
raise ValueError("Function and/or coroutine must be provided")
|
||||
return cls(
|
||||
name=name,
|
||||
func=func,
|
||||
coroutine=coroutine,
|
||||
description=description,
|
||||
return_direct=return_direct,
|
||||
args_schema=args_schema,
|
||||
@ -557,7 +566,7 @@ class StructuredTool(BaseTool):
|
||||
description: str = ""
|
||||
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
|
||||
"""The input arguments' schema."""
|
||||
func: Callable[..., Any]
|
||||
func: Optional[Callable[..., Any]]
|
||||
"""The function to run when the tool is called."""
|
||||
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
|
||||
"""The asynchronous version of the function."""
|
||||
@ -592,16 +601,18 @@ class StructuredTool(BaseTool):
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||
return (
|
||||
self.func(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
if self.func:
|
||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||
return (
|
||||
self.func(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else self.func(*args, **kwargs)
|
||||
)
|
||||
if new_argument_supported
|
||||
else self.func(*args, **kwargs)
|
||||
)
|
||||
raise NotImplementedError("Tool does not support sync")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
@ -628,7 +639,8 @@ class StructuredTool(BaseTool):
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
func: Callable,
|
||||
func: Optional[Callable] = None,
|
||||
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
@ -642,6 +654,7 @@ class StructuredTool(BaseTool):
|
||||
|
||||
Args:
|
||||
func: The function from which to create a tool
|
||||
coroutine: The async function from which to create a tool
|
||||
name: The name of the tool. Defaults to the function name
|
||||
description: The description of the tool. Defaults to the function docstring
|
||||
return_direct: Whether to return the result directly or as a callback
|
||||
@ -662,21 +675,31 @@ class StructuredTool(BaseTool):
|
||||
tool = StructuredTool.from_function(add)
|
||||
tool.run(1, 2) # 3
|
||||
"""
|
||||
name = name or func.__name__
|
||||
description = description or func.__doc__
|
||||
assert (
|
||||
description is not None
|
||||
), "Function must have a docstring if description not provided."
|
||||
|
||||
if func is not None:
|
||||
source_function = func
|
||||
elif coroutine is not None:
|
||||
source_function = coroutine
|
||||
else:
|
||||
raise ValueError("Function and/or coroutine must be provided")
|
||||
name = name or source_function.__name__
|
||||
description = description or source_function.__doc__
|
||||
if description is None:
|
||||
raise ValueError(
|
||||
"Function must have a docstring if description not provided."
|
||||
)
|
||||
|
||||
# Description example:
|
||||
# search_api(query: str) - Searches the API for the query.
|
||||
description = f"{name}{signature(func)} - {description.strip()}"
|
||||
sig = signature(source_function)
|
||||
description = f"{name}{sig} - {description.strip()}"
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
_args_schema = create_schema_from_function(f"{name}Schema", func)
|
||||
_args_schema = create_schema_from_function(f"{name}Schema", source_function)
|
||||
return cls(
|
||||
name=name,
|
||||
func=func,
|
||||
coroutine=coroutine,
|
||||
args_schema=_args_schema,
|
||||
description=description,
|
||||
return_direct=return_direct,
|
||||
@ -720,10 +743,18 @@ def tool(
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(func: Callable) -> BaseTool:
|
||||
def _make_tool(dec_func: Callable) -> BaseTool:
|
||||
if inspect.iscoroutinefunction(dec_func):
|
||||
coroutine = dec_func
|
||||
func = None
|
||||
else:
|
||||
coroutine = None
|
||||
func = dec_func
|
||||
|
||||
if infer_schema or args_schema is not None:
|
||||
return StructuredTool.from_function(
|
||||
func,
|
||||
coroutine,
|
||||
name=tool_name,
|
||||
return_direct=return_direct,
|
||||
args_schema=args_schema,
|
||||
@ -731,12 +762,17 @@ def tool(
|
||||
)
|
||||
# If someone doesn't want a schema applied, we must treat it as
|
||||
# a simple string->string function
|
||||
assert func.__doc__ is not None, "Function must have a docstring"
|
||||
if func.__doc__ is None:
|
||||
raise ValueError(
|
||||
"Function must have a docstring if "
|
||||
"description not provided and infer_schema is False."
|
||||
)
|
||||
return Tool(
|
||||
name=tool_name,
|
||||
func=func,
|
||||
description=f"{tool_name} tool",
|
||||
return_direct=return_direct,
|
||||
coroutine=coroutine,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
|
@ -546,7 +546,7 @@ def test_tool_with_kwargs() -> None:
|
||||
def test_missing_docstring() -> None:
|
||||
"""Test error is raised when docstring is missing."""
|
||||
# expect to throw a value error if there's no docstring
|
||||
with pytest.raises(AssertionError, match="Function must have a docstring"):
|
||||
with pytest.raises(ValueError, match="Function must have a docstring"):
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user