core[patch]: Improve type checking for the tool decorator (#27460)

**Description:**

When annotating a function with the @tool decorator, the symbol should
have type BaseTool. The previous type annotations did not convey that to
type checkers. This patch creates 4 overloads for the tool function for
the 4 different use cases.

1. @tool decorator with no arguments
2. @tool decorator with only keyword arguments
3. @tool decorator with a name argument (and possibly keyword arguments)
4. Invoking tool as function with a name and runnable positional
arguments

The main function is updated to match the overloads. The changes are
100% backwards compatible (all existing calls should continue to work,
just with better type annotations).

**Twitter handle:** @nvachhar

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Neil Vachharajani 2024-10-29 06:59:56 -07:00 committed by GitHub
parent 94e5765416
commit eec35672a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,5 @@
import inspect
from typing import Any, Callable, Literal, Optional, Union, get_type_hints
from typing import Any, Callable, Literal, Optional, Union, get_type_hints, overload
from pydantic import BaseModel, Field, create_model
@ -10,19 +10,79 @@ from langchain_core.tools.simple import Tool
from langchain_core.tools.structured import StructuredTool
@overload
def tool(
*args: Union[str, Callable, Runnable],
*,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> Callable:
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
@overload
def tool(
name_or_callable: str,
runnable: Runnable,
*,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> BaseTool: ...
@overload
def tool(
name_or_callable: Callable,
*,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> BaseTool: ...
@overload
def tool(
name_or_callable: str,
*,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
def tool(
name_or_callable: Optional[Union[str, Callable]] = None,
runnable: Optional[Runnable] = None,
*args: Any,
return_direct: bool = False,
args_schema: Optional[type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> Union[
BaseTool,
Callable[[Union[Callable, Runnable]], BaseTool],
]:
"""Make tools out of functions, can be used with or without arguments.
Args:
*args: The arguments to the tool.
name_or_callable: Optional name of the tool or the callable to be
converted to a tool. Must be provided as a positional argument.
runnable: Optional runnable to convert to a tool. Must be provided as a
positional argument.
return_direct: Whether to return directly from the tool rather
than continuing the agent loop. Defaults to False.
args_schema: optional argument schema for user to specify.
@ -140,8 +200,19 @@ def tool(
return bar
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
def _create_tool_factory(
tool_name: str,
) -> Callable[[Union[Callable, Runnable]], BaseTool]:
"""Create a decorator that takes a callable and returns a tool.
Args:
tool_name: The name that will be assigned to the tool.
Returns:
A function that takes a callable or Runnable and returns a tool.
"""
def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool:
if isinstance(dec_func, Runnable):
runnable = dec_func
@ -204,28 +275,63 @@ def tool(
response_format=response_format,
)
return _make_tool
return _tool_factory
if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable):
return _make_with_name(args[0])(args[1])
elif len(args) == 1 and isinstance(args[0], str):
# if the argument is a string, then we use the string as the tool name
# Example usage: @tool("search", return_direct=True)
return _make_with_name(args[0])
elif len(args) == 1 and callable(args[0]):
# if the argument is a function, then we use the function name as the tool name
# Example usage: @tool
return _make_with_name(args[0].__name__)(args[0])
elif len(args) == 0:
# if there are no arguments, then we use the function name as the tool name
# Example usage: @tool(return_direct=True)
def _partial(func: Callable[[str], str]) -> BaseTool:
return _make_with_name(func.__name__)(func)
if len(args) != 0:
# Triggered if a user attempts to use positional arguments that
# do not exist in the function signature
# e.g., @tool("name", runnable, "extra_arg")
# Here, "extra_arg" is not a valid argument
msg = "Too many arguments for tool decorator. A decorator "
raise ValueError(msg)
if runnable is not None:
# tool is used as a function
# tool_from_runnable = tool("name", runnable)
if not name_or_callable:
msg = "Runnable without name for tool constructor"
raise ValueError(msg)
if not isinstance(name_or_callable, str):
msg = "Name must be a string for tool constructor"
raise ValueError(msg)
return _create_tool_factory(name_or_callable)(runnable)
elif name_or_callable is not None:
if callable(name_or_callable) and hasattr(name_or_callable, "__name__"):
# Used as a decorator without parameters
# @tool
# def my_tool():
# pass
return _create_tool_factory(name_or_callable.__name__)(name_or_callable)
elif isinstance(name_or_callable, str):
# Used with a new name for the tool
# @tool("search")
# def my_tool():
# pass
#
# or
#
# @tool("search", parse_docstring=True)
# def my_tool():
# pass
return _create_tool_factory(name_or_callable)
else:
msg = (
f"The first argument must be a string or a callable with a __name__ "
f"for tool decorator. Got {type(name_or_callable)}"
)
raise ValueError(msg)
else:
# Tool is used as a decorator with parameters specified
# @tool(parse_docstring=True)
# def my_tool():
# pass
def _partial(func: Union[Callable, Runnable]) -> BaseTool:
"""Partial function that takes a callable and returns a tool."""
name_ = func.get_name() if isinstance(func, Runnable) else func.__name__
tool_factory = _create_tool_factory(name_)
return tool_factory(func)
return _partial
else:
msg = "Too many arguments for tool decorator"
raise ValueError(msg)
def _get_description_from_runnable(runnable: Runnable) -> str: