diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 7c24c01b3ae..5d59f7959bf 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +import inspect import warnings from abc import abstractmethod from functools import partial @@ -437,7 +438,7 @@ class Tool(BaseTool): """Tool that takes in function or coroutine directly.""" description: str = "" - func: Callable[..., str] + func: Optional[Callable[..., str]] """The function to run when the tool is called.""" coroutine: Optional[Callable[..., Awaitable[str]]] = None """The asynchronous version of the function.""" @@ -488,16 +489,18 @@ class Tool(BaseTool): **kwargs: Any, ) -> Any: """Use the tool.""" - new_argument_supported = signature(self.func).parameters.get("callbacks") - return ( - self.func( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, + if self.func: + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) ) - if new_argument_supported - else self.func(*args, **kwargs) - ) + raise NotImplementedError("Tool does not support sync") async def _arun( self, @@ -523,7 +526,7 @@ class Tool(BaseTool): # TODO: this is for backwards compatibility, remove in future def __init__( - self, name: str, func: Callable, description: str, **kwargs: Any + self, name: str, func: Optional[Callable], description: str, **kwargs: Any ) -> None: """Initialize tool.""" super(Tool, self).__init__( @@ -533,17 +536,23 @@ class Tool(BaseTool): @classmethod def from_function( cls, - func: Callable, + func: Optional[Callable], name: str, # We keep these required to support backwards compatibility description: str, return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, + coroutine: Optional[ + Callable[..., Awaitable[Any]] + ] = None, # This is last for compatibility, but should be after func **kwargs: Any, ) -> Tool: """Initialize tool from a function.""" + if func is None and coroutine is None: + raise ValueError("Function and/or coroutine must be provided") return cls( name=name, func=func, + coroutine=coroutine, description=description, return_direct=return_direct, args_schema=args_schema, @@ -557,7 +566,7 @@ class StructuredTool(BaseTool): description: str = "" args_schema: Type[BaseModel] = Field(..., description="The tool schema.") """The input arguments' schema.""" - func: Callable[..., Any] + func: Optional[Callable[..., Any]] """The function to run when the tool is called.""" coroutine: Optional[Callable[..., Awaitable[Any]]] = None """The asynchronous version of the function.""" @@ -592,16 +601,18 @@ class StructuredTool(BaseTool): **kwargs: Any, ) -> Any: """Use the tool.""" - new_argument_supported = signature(self.func).parameters.get("callbacks") - return ( - self.func( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, + if self.func: + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) ) - if new_argument_supported - else self.func(*args, **kwargs) - ) + raise NotImplementedError("Tool does not support sync") async def _arun( self, @@ -628,7 +639,8 @@ class StructuredTool(BaseTool): @classmethod def from_function( cls, - func: Callable, + func: Optional[Callable] = None, + coroutine: Optional[Callable[..., Awaitable[Any]]] = None, name: Optional[str] = None, description: Optional[str] = None, return_direct: bool = False, @@ -642,6 +654,7 @@ class StructuredTool(BaseTool): Args: func: The function from which to create a tool + coroutine: The async function from which to create a tool name: The name of the tool. Defaults to the function name description: The description of the tool. Defaults to the function docstring return_direct: Whether to return the result directly or as a callback @@ -662,21 +675,31 @@ class StructuredTool(BaseTool): tool = StructuredTool.from_function(add) tool.run(1, 2) # 3 """ - name = name or func.__name__ - description = description or func.__doc__ - assert ( - description is not None - ), "Function must have a docstring if description not provided." + + if func is not None: + source_function = func + elif coroutine is not None: + source_function = coroutine + else: + raise ValueError("Function and/or coroutine must be provided") + name = name or source_function.__name__ + description = description or source_function.__doc__ + if description is None: + raise ValueError( + "Function must have a docstring if description not provided." + ) # Description example: # search_api(query: str) - Searches the API for the query. - description = f"{name}{signature(func)} - {description.strip()}" + sig = signature(source_function) + description = f"{name}{sig} - {description.strip()}" _args_schema = args_schema if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{name}Schema", func) + _args_schema = create_schema_from_function(f"{name}Schema", source_function) return cls( name=name, func=func, + coroutine=coroutine, args_schema=_args_schema, description=description, return_direct=return_direct, @@ -720,10 +743,18 @@ def tool( """ def _make_with_name(tool_name: str) -> Callable: - def _make_tool(func: Callable) -> BaseTool: + def _make_tool(dec_func: Callable) -> BaseTool: + if inspect.iscoroutinefunction(dec_func): + coroutine = dec_func + func = None + else: + coroutine = None + func = dec_func + if infer_schema or args_schema is not None: return StructuredTool.from_function( func, + coroutine, name=tool_name, return_direct=return_direct, args_schema=args_schema, @@ -731,12 +762,17 @@ def tool( ) # If someone doesn't want a schema applied, we must treat it as # a simple string->string function - assert func.__doc__ is not None, "Function must have a docstring" + if func.__doc__ is None: + raise ValueError( + "Function must have a docstring if " + "description not provided and infer_schema is False." + ) return Tool( name=tool_name, func=func, description=f"{tool_name} tool", return_direct=return_direct, + coroutine=coroutine, ) return _make_tool diff --git a/libs/langchain/tests/unit_tests/tools/test_base.py b/libs/langchain/tests/unit_tests/tools/test_base.py index a66953a0669..9c5fdf39e10 100644 --- a/libs/langchain/tests/unit_tests/tools/test_base.py +++ b/libs/langchain/tests/unit_tests/tools/test_base.py @@ -546,7 +546,7 @@ def test_tool_with_kwargs() -> None: def test_missing_docstring() -> None: """Test error is raised when docstring is missing.""" # expect to throw a value error if there's no docstring - with pytest.raises(AssertionError, match="Function must have a docstring"): + with pytest.raises(ValueError, match="Function must have a docstring"): @tool def search_api(query: str) -> str: