diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index e85435a86df..bb8b85f5558 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -1,5 +1,5 @@ import inspect -from typing import Any, Callable, Literal, Optional, Union, get_type_hints +from typing import Any, Callable, Literal, Optional, Union, get_type_hints, overload from pydantic import BaseModel, Field, create_model @@ -10,19 +10,79 @@ from langchain_core.tools.simple import Tool from langchain_core.tools.structured import StructuredTool +@overload def tool( - *args: Union[str, Callable, Runnable], + *, return_direct: bool = False, args_schema: Optional[type] = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> Callable: +) -> Callable[[Union[Callable, Runnable]], BaseTool]: ... + + +@overload +def tool( + name_or_callable: str, + runnable: Runnable, + *, + return_direct: bool = False, + args_schema: Optional[type] = None, + infer_schema: bool = True, + response_format: Literal["content", "content_and_artifact"] = "content", + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, +) -> BaseTool: ... + + +@overload +def tool( + name_or_callable: Callable, + *, + return_direct: bool = False, + args_schema: Optional[type] = None, + infer_schema: bool = True, + response_format: Literal["content", "content_and_artifact"] = "content", + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, +) -> BaseTool: ... + + +@overload +def tool( + name_or_callable: str, + *, + return_direct: bool = False, + args_schema: Optional[type] = None, + infer_schema: bool = True, + response_format: Literal["content", "content_and_artifact"] = "content", + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, +) -> Callable[[Union[Callable, Runnable]], BaseTool]: ... + + +def tool( + name_or_callable: Optional[Union[str, Callable]] = None, + runnable: Optional[Runnable] = None, + *args: Any, + return_direct: bool = False, + args_schema: Optional[type] = None, + infer_schema: bool = True, + response_format: Literal["content", "content_and_artifact"] = "content", + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, +) -> Union[ + BaseTool, + Callable[[Union[Callable, Runnable]], BaseTool], +]: """Make tools out of functions, can be used with or without arguments. Args: - *args: The arguments to the tool. + name_or_callable: Optional name of the tool or the callable to be + converted to a tool. Must be provided as a positional argument. + runnable: Optional runnable to convert to a tool. Must be provided as a + positional argument. return_direct: Whether to return directly from the tool rather than continuing the agent loop. Defaults to False. args_schema: optional argument schema for user to specify. @@ -140,8 +200,19 @@ def tool( return bar """ - def _make_with_name(tool_name: str) -> Callable: - def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: + def _create_tool_factory( + tool_name: str, + ) -> Callable[[Union[Callable, Runnable]], BaseTool]: + """Create a decorator that takes a callable and returns a tool. + + Args: + tool_name: The name that will be assigned to the tool. + + Returns: + A function that takes a callable or Runnable and returns a tool. + """ + + def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool: if isinstance(dec_func, Runnable): runnable = dec_func @@ -204,28 +275,63 @@ def tool( response_format=response_format, ) - return _make_tool + return _tool_factory - if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): - return _make_with_name(args[0])(args[1]) - elif len(args) == 1 and isinstance(args[0], str): - # if the argument is a string, then we use the string as the tool name - # Example usage: @tool("search", return_direct=True) - return _make_with_name(args[0]) - elif len(args) == 1 and callable(args[0]): - # if the argument is a function, then we use the function name as the tool name - # Example usage: @tool - return _make_with_name(args[0].__name__)(args[0]) - elif len(args) == 0: - # if there are no arguments, then we use the function name as the tool name - # Example usage: @tool(return_direct=True) - def _partial(func: Callable[[str], str]) -> BaseTool: - return _make_with_name(func.__name__)(func) + if len(args) != 0: + # Triggered if a user attempts to use positional arguments that + # do not exist in the function signature + # e.g., @tool("name", runnable, "extra_arg") + # Here, "extra_arg" is not a valid argument + msg = "Too many arguments for tool decorator. A decorator " + raise ValueError(msg) + + if runnable is not None: + # tool is used as a function + # tool_from_runnable = tool("name", runnable) + if not name_or_callable: + msg = "Runnable without name for tool constructor" + raise ValueError(msg) + if not isinstance(name_or_callable, str): + msg = "Name must be a string for tool constructor" + raise ValueError(msg) + return _create_tool_factory(name_or_callable)(runnable) + elif name_or_callable is not None: + if callable(name_or_callable) and hasattr(name_or_callable, "__name__"): + # Used as a decorator without parameters + # @tool + # def my_tool(): + # pass + return _create_tool_factory(name_or_callable.__name__)(name_or_callable) + elif isinstance(name_or_callable, str): + # Used with a new name for the tool + # @tool("search") + # def my_tool(): + # pass + # + # or + # + # @tool("search", parse_docstring=True) + # def my_tool(): + # pass + return _create_tool_factory(name_or_callable) + else: + msg = ( + f"The first argument must be a string or a callable with a __name__ " + f"for tool decorator. Got {type(name_or_callable)}" + ) + raise ValueError(msg) + else: + # Tool is used as a decorator with parameters specified + # @tool(parse_docstring=True) + # def my_tool(): + # pass + def _partial(func: Union[Callable, Runnable]) -> BaseTool: + """Partial function that takes a callable and returns a tool.""" + name_ = func.get_name() if isinstance(func, Runnable) else func.__name__ + tool_factory = _create_tool_factory(name_) + return tool_factory(func) return _partial - else: - msg = "Too many arguments for tool decorator" - raise ValueError(msg) def _get_description_from_runnable(runnable: Runnable) -> str: