mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 03:56:39 +00:00
core[patch]: Improve type checking for the tool decorator (#27460)
**Description:** When annotating a function with the @tool decorator, the symbol should have type BaseTool. The previous type annotations did not convey that to type checkers. This patch creates 4 overloads for the tool function for the 4 different use cases. 1. @tool decorator with no arguments 2. @tool decorator with only keyword arguments 3. @tool decorator with a name argument (and possibly keyword arguments) 4. Invoking tool as function with a name and runnable positional arguments The main function is updated to match the overloads. The changes are 100% backwards compatible (all existing calls should continue to work, just with better type annotations). **Twitter handle:** @nvachhar --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
94e5765416
commit
eec35672a4
@ -1,5 +1,5 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Callable, Literal, Optional, Union, get_type_hints
|
from typing import Any, Callable, Literal, Optional, Union, get_type_hints, overload
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import BaseModel, Field, create_model
|
||||||
|
|
||||||
@ -10,19 +10,79 @@ from langchain_core.tools.simple import Tool
|
|||||||
from langchain_core.tools.structured import StructuredTool
|
from langchain_core.tools.structured import StructuredTool
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
def tool(
|
def tool(
|
||||||
*args: Union[str, Callable, Runnable],
|
*,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[type] = None,
|
args_schema: Optional[type] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True,
|
||||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
error_on_invalid_docstring: bool = True,
|
error_on_invalid_docstring: bool = True,
|
||||||
) -> Callable:
|
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def tool(
|
||||||
|
name_or_callable: str,
|
||||||
|
runnable: Runnable,
|
||||||
|
*,
|
||||||
|
return_direct: bool = False,
|
||||||
|
args_schema: Optional[type] = None,
|
||||||
|
infer_schema: bool = True,
|
||||||
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
|
parse_docstring: bool = False,
|
||||||
|
error_on_invalid_docstring: bool = True,
|
||||||
|
) -> BaseTool: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def tool(
|
||||||
|
name_or_callable: Callable,
|
||||||
|
*,
|
||||||
|
return_direct: bool = False,
|
||||||
|
args_schema: Optional[type] = None,
|
||||||
|
infer_schema: bool = True,
|
||||||
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
|
parse_docstring: bool = False,
|
||||||
|
error_on_invalid_docstring: bool = True,
|
||||||
|
) -> BaseTool: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def tool(
|
||||||
|
name_or_callable: str,
|
||||||
|
*,
|
||||||
|
return_direct: bool = False,
|
||||||
|
args_schema: Optional[type] = None,
|
||||||
|
infer_schema: bool = True,
|
||||||
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
|
parse_docstring: bool = False,
|
||||||
|
error_on_invalid_docstring: bool = True,
|
||||||
|
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def tool(
|
||||||
|
name_or_callable: Optional[Union[str, Callable]] = None,
|
||||||
|
runnable: Optional[Runnable] = None,
|
||||||
|
*args: Any,
|
||||||
|
return_direct: bool = False,
|
||||||
|
args_schema: Optional[type] = None,
|
||||||
|
infer_schema: bool = True,
|
||||||
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
|
parse_docstring: bool = False,
|
||||||
|
error_on_invalid_docstring: bool = True,
|
||||||
|
) -> Union[
|
||||||
|
BaseTool,
|
||||||
|
Callable[[Union[Callable, Runnable]], BaseTool],
|
||||||
|
]:
|
||||||
"""Make tools out of functions, can be used with or without arguments.
|
"""Make tools out of functions, can be used with or without arguments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*args: The arguments to the tool.
|
name_or_callable: Optional name of the tool or the callable to be
|
||||||
|
converted to a tool. Must be provided as a positional argument.
|
||||||
|
runnable: Optional runnable to convert to a tool. Must be provided as a
|
||||||
|
positional argument.
|
||||||
return_direct: Whether to return directly from the tool rather
|
return_direct: Whether to return directly from the tool rather
|
||||||
than continuing the agent loop. Defaults to False.
|
than continuing the agent loop. Defaults to False.
|
||||||
args_schema: optional argument schema for user to specify.
|
args_schema: optional argument schema for user to specify.
|
||||||
@ -140,8 +200,19 @@ def tool(
|
|||||||
return bar
|
return bar
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _make_with_name(tool_name: str) -> Callable:
|
def _create_tool_factory(
|
||||||
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
|
tool_name: str,
|
||||||
|
) -> Callable[[Union[Callable, Runnable]], BaseTool]:
|
||||||
|
"""Create a decorator that takes a callable and returns a tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: The name that will be assigned to the tool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function that takes a callable or Runnable and returns a tool.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool:
|
||||||
if isinstance(dec_func, Runnable):
|
if isinstance(dec_func, Runnable):
|
||||||
runnable = dec_func
|
runnable = dec_func
|
||||||
|
|
||||||
@ -204,28 +275,63 @@ def tool(
|
|||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _make_tool
|
return _tool_factory
|
||||||
|
|
||||||
if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable):
|
if len(args) != 0:
|
||||||
return _make_with_name(args[0])(args[1])
|
# Triggered if a user attempts to use positional arguments that
|
||||||
elif len(args) == 1 and isinstance(args[0], str):
|
# do not exist in the function signature
|
||||||
# if the argument is a string, then we use the string as the tool name
|
# e.g., @tool("name", runnable, "extra_arg")
|
||||||
# Example usage: @tool("search", return_direct=True)
|
# Here, "extra_arg" is not a valid argument
|
||||||
return _make_with_name(args[0])
|
msg = "Too many arguments for tool decorator. A decorator "
|
||||||
elif len(args) == 1 and callable(args[0]):
|
raise ValueError(msg)
|
||||||
# if the argument is a function, then we use the function name as the tool name
|
|
||||||
# Example usage: @tool
|
if runnable is not None:
|
||||||
return _make_with_name(args[0].__name__)(args[0])
|
# tool is used as a function
|
||||||
elif len(args) == 0:
|
# tool_from_runnable = tool("name", runnable)
|
||||||
# if there are no arguments, then we use the function name as the tool name
|
if not name_or_callable:
|
||||||
# Example usage: @tool(return_direct=True)
|
msg = "Runnable without name for tool constructor"
|
||||||
def _partial(func: Callable[[str], str]) -> BaseTool:
|
raise ValueError(msg)
|
||||||
return _make_with_name(func.__name__)(func)
|
if not isinstance(name_or_callable, str):
|
||||||
|
msg = "Name must be a string for tool constructor"
|
||||||
|
raise ValueError(msg)
|
||||||
|
return _create_tool_factory(name_or_callable)(runnable)
|
||||||
|
elif name_or_callable is not None:
|
||||||
|
if callable(name_or_callable) and hasattr(name_or_callable, "__name__"):
|
||||||
|
# Used as a decorator without parameters
|
||||||
|
# @tool
|
||||||
|
# def my_tool():
|
||||||
|
# pass
|
||||||
|
return _create_tool_factory(name_or_callable.__name__)(name_or_callable)
|
||||||
|
elif isinstance(name_or_callable, str):
|
||||||
|
# Used with a new name for the tool
|
||||||
|
# @tool("search")
|
||||||
|
# def my_tool():
|
||||||
|
# pass
|
||||||
|
#
|
||||||
|
# or
|
||||||
|
#
|
||||||
|
# @tool("search", parse_docstring=True)
|
||||||
|
# def my_tool():
|
||||||
|
# pass
|
||||||
|
return _create_tool_factory(name_or_callable)
|
||||||
|
else:
|
||||||
|
msg = (
|
||||||
|
f"The first argument must be a string or a callable with a __name__ "
|
||||||
|
f"for tool decorator. Got {type(name_or_callable)}"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
else:
|
||||||
|
# Tool is used as a decorator with parameters specified
|
||||||
|
# @tool(parse_docstring=True)
|
||||||
|
# def my_tool():
|
||||||
|
# pass
|
||||||
|
def _partial(func: Union[Callable, Runnable]) -> BaseTool:
|
||||||
|
"""Partial function that takes a callable and returns a tool."""
|
||||||
|
name_ = func.get_name() if isinstance(func, Runnable) else func.__name__
|
||||||
|
tool_factory = _create_tool_factory(name_)
|
||||||
|
return tool_factory(func)
|
||||||
|
|
||||||
return _partial
|
return _partial
|
||||||
else:
|
|
||||||
msg = "Too many arguments for tool decorator"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_description_from_runnable(runnable: Runnable) -> str:
|
def _get_description_from_runnable(runnable: Runnable) -> str:
|
||||||
|
Loading…
Reference in New Issue
Block a user