core[patch]: Improve type checking for the tool decorator (#27460)

**Description:**

When annotating a function with the @tool decorator, the symbol should
have type BaseTool. The previous type annotations did not convey that to
type checkers. This patch creates 4 overloads for the tool function for
the 4 different use cases.

1. @tool decorator with no arguments
2. @tool decorator with only keyword arguments
3. @tool decorator with a name argument (and possibly keyword arguments)
4. Invoking tool as function with a name and runnable positional
arguments

The main function is updated to match the overloads. The changes are
100% backwards compatible (all existing calls should continue to work,
just with better type annotations).

**Twitter handle:** @nvachhar

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Neil Vachharajani 2024-10-29 06:59:56 -07:00 committed by GitHub
parent 94e5765416
commit eec35672a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,5 @@
import inspect import inspect
from typing import Any, Callable, Literal, Optional, Union, get_type_hints from typing import Any, Callable, Literal, Optional, Union, get_type_hints, overload
from pydantic import BaseModel, Field, create_model from pydantic import BaseModel, Field, create_model
@ -10,19 +10,79 @@ from langchain_core.tools.simple import Tool
from langchain_core.tools.structured import StructuredTool from langchain_core.tools.structured import StructuredTool
@overload
def tool( def tool(
*args: Union[str, Callable, Runnable], *,
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[type] = None, args_schema: Optional[type] = None,
infer_schema: bool = True, infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content", response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False, parse_docstring: bool = False,
error_on_invalid_docstring: bool = True, error_on_invalid_docstring: bool = True,
) -> Callable: ) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
@overload
def tool(
name_or_callable: str,
runnable: Runnable,
*,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> BaseTool: ...
@overload
def tool(
name_or_callable: Callable,
*,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> BaseTool: ...
@overload
def tool(
name_or_callable: str,
*,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
def tool(
name_or_callable: Optional[Union[str, Callable]] = None,
runnable: Optional[Runnable] = None,
*args: Any,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> Union[
BaseTool,
Callable[[Union[Callable, Runnable]], BaseTool],
]:
"""Make tools out of functions, can be used with or without arguments. """Make tools out of functions, can be used with or without arguments.
Args: Args:
*args: The arguments to the tool. name_or_callable: Optional name of the tool or the callable to be
converted to a tool. Must be provided as a positional argument.
runnable: Optional runnable to convert to a tool. Must be provided as a
positional argument.
return_direct: Whether to return directly from the tool rather return_direct: Whether to return directly from the tool rather
than continuing the agent loop. Defaults to False. than continuing the agent loop. Defaults to False.
args_schema: optional argument schema for user to specify. args_schema: optional argument schema for user to specify.
@ -140,8 +200,19 @@ def tool(
return bar return bar
""" """
def _make_with_name(tool_name: str) -> Callable: def _create_tool_factory(
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: tool_name: str,
) -> Callable[[Union[Callable, Runnable]], BaseTool]:
"""Create a decorator that takes a callable and returns a tool.
Args:
tool_name: The name that will be assigned to the tool.
Returns:
A function that takes a callable or Runnable and returns a tool.
"""
def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool:
if isinstance(dec_func, Runnable): if isinstance(dec_func, Runnable):
runnable = dec_func runnable = dec_func
@ -204,28 +275,63 @@ def tool(
response_format=response_format, response_format=response_format,
) )
return _make_tool return _tool_factory
if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): if len(args) != 0:
return _make_with_name(args[0])(args[1]) # Triggered if a user attempts to use positional arguments that
elif len(args) == 1 and isinstance(args[0], str): # do not exist in the function signature
# if the argument is a string, then we use the string as the tool name # e.g., @tool("name", runnable, "extra_arg")
# Example usage: @tool("search", return_direct=True) # Here, "extra_arg" is not a valid argument
return _make_with_name(args[0]) msg = "Too many arguments for tool decorator. A decorator "
elif len(args) == 1 and callable(args[0]): raise ValueError(msg)
# if the argument is a function, then we use the function name as the tool name
# Example usage: @tool if runnable is not None:
return _make_with_name(args[0].__name__)(args[0]) # tool is used as a function
elif len(args) == 0: # tool_from_runnable = tool("name", runnable)
# if there are no arguments, then we use the function name as the tool name if not name_or_callable:
# Example usage: @tool(return_direct=True) msg = "Runnable without name for tool constructor"
def _partial(func: Callable[[str], str]) -> BaseTool: raise ValueError(msg)
return _make_with_name(func.__name__)(func) if not isinstance(name_or_callable, str):
msg = "Name must be a string for tool constructor"
raise ValueError(msg)
return _create_tool_factory(name_or_callable)(runnable)
elif name_or_callable is not None:
if callable(name_or_callable) and hasattr(name_or_callable, "__name__"):
# Used as a decorator without parameters
# @tool
# def my_tool():
# pass
return _create_tool_factory(name_or_callable.__name__)(name_or_callable)
elif isinstance(name_or_callable, str):
# Used with a new name for the tool
# @tool("search")
# def my_tool():
# pass
#
# or
#
# @tool("search", parse_docstring=True)
# def my_tool():
# pass
return _create_tool_factory(name_or_callable)
else:
msg = (
f"The first argument must be a string or a callable with a __name__ "
f"for tool decorator. Got {type(name_or_callable)}"
)
raise ValueError(msg)
else:
# Tool is used as a decorator with parameters specified
# @tool(parse_docstring=True)
# def my_tool():
# pass
def _partial(func: Union[Callable, Runnable]) -> BaseTool:
"""Partial function that takes a callable and returns a tool."""
name_ = func.get_name() if isinstance(func, Runnable) else func.__name__
tool_factory = _create_tool_factory(name_)
return tool_factory(func)
return _partial return _partial
else:
msg = "Too many arguments for tool decorator"
raise ValueError(msg)
def _get_description_from_runnable(runnable: Runnable) -> str: def _get_description_from_runnable(runnable: Runnable) -> str: