mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 00:29:57 +00:00
core[patch]: Improve type checking for the tool decorator (#27460)
**Description:** When annotating a function with the @tool decorator, the symbol should have type BaseTool. The previous type annotations did not convey that to type checkers. This patch creates 4 overloads for the tool function for the 4 different use cases. 1. @tool decorator with no arguments 2. @tool decorator with only keyword arguments 3. @tool decorator with a name argument (and possibly keyword arguments) 4. Invoking tool as function with a name and runnable positional arguments The main function is updated to match the overloads. The changes are 100% backwards compatible (all existing calls should continue to work, just with better type annotations). **Twitter handle:** @nvachhar --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
94e5765416
commit
eec35672a4
@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import Any, Callable, Literal, Optional, Union, get_type_hints
|
||||
from typing import Any, Callable, Literal, Optional, Union, get_type_hints, overload
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
@ -10,19 +10,79 @@ from langchain_core.tools.simple import Tool
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
|
||||
|
||||
@overload
|
||||
def tool(
|
||||
*args: Union[str, Callable, Runnable],
|
||||
*,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
) -> Callable:
|
||||
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def tool(
|
||||
name_or_callable: str,
|
||||
runnable: Runnable,
|
||||
*,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
) -> BaseTool: ...
|
||||
|
||||
|
||||
@overload
|
||||
def tool(
|
||||
name_or_callable: Callable,
|
||||
*,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
) -> BaseTool: ...
|
||||
|
||||
|
||||
@overload
|
||||
def tool(
|
||||
name_or_callable: str,
|
||||
*,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
) -> Callable[[Union[Callable, Runnable]], BaseTool]: ...
|
||||
|
||||
|
||||
def tool(
|
||||
name_or_callable: Optional[Union[str, Callable]] = None,
|
||||
runnable: Optional[Runnable] = None,
|
||||
*args: Any,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
) -> Union[
|
||||
BaseTool,
|
||||
Callable[[Union[Callable, Runnable]], BaseTool],
|
||||
]:
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
Args:
|
||||
*args: The arguments to the tool.
|
||||
name_or_callable: Optional name of the tool or the callable to be
|
||||
converted to a tool. Must be provided as a positional argument.
|
||||
runnable: Optional runnable to convert to a tool. Must be provided as a
|
||||
positional argument.
|
||||
return_direct: Whether to return directly from the tool rather
|
||||
than continuing the agent loop. Defaults to False.
|
||||
args_schema: optional argument schema for user to specify.
|
||||
@ -140,8 +200,19 @@ def tool(
|
||||
return bar
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
|
||||
def _create_tool_factory(
|
||||
tool_name: str,
|
||||
) -> Callable[[Union[Callable, Runnable]], BaseTool]:
|
||||
"""Create a decorator that takes a callable and returns a tool.
|
||||
|
||||
Args:
|
||||
tool_name: The name that will be assigned to the tool.
|
||||
|
||||
Returns:
|
||||
A function that takes a callable or Runnable and returns a tool.
|
||||
"""
|
||||
|
||||
def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool:
|
||||
if isinstance(dec_func, Runnable):
|
||||
runnable = dec_func
|
||||
|
||||
@ -204,28 +275,63 @@ def tool(
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
return _tool_factory
|
||||
|
||||
if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable):
|
||||
return _make_with_name(args[0])(args[1])
|
||||
elif len(args) == 1 and isinstance(args[0], str):
|
||||
# if the argument is a string, then we use the string as the tool name
|
||||
# Example usage: @tool("search", return_direct=True)
|
||||
return _make_with_name(args[0])
|
||||
elif len(args) == 1 and callable(args[0]):
|
||||
# if the argument is a function, then we use the function name as the tool name
|
||||
# Example usage: @tool
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
elif len(args) == 0:
|
||||
# if there are no arguments, then we use the function name as the tool name
|
||||
# Example usage: @tool(return_direct=True)
|
||||
def _partial(func: Callable[[str], str]) -> BaseTool:
|
||||
return _make_with_name(func.__name__)(func)
|
||||
if len(args) != 0:
|
||||
# Triggered if a user attempts to use positional arguments that
|
||||
# do not exist in the function signature
|
||||
# e.g., @tool("name", runnable, "extra_arg")
|
||||
# Here, "extra_arg" is not a valid argument
|
||||
msg = "Too many arguments for tool decorator. A decorator "
|
||||
raise ValueError(msg)
|
||||
|
||||
if runnable is not None:
|
||||
# tool is used as a function
|
||||
# tool_from_runnable = tool("name", runnable)
|
||||
if not name_or_callable:
|
||||
msg = "Runnable without name for tool constructor"
|
||||
raise ValueError(msg)
|
||||
if not isinstance(name_or_callable, str):
|
||||
msg = "Name must be a string for tool constructor"
|
||||
raise ValueError(msg)
|
||||
return _create_tool_factory(name_or_callable)(runnable)
|
||||
elif name_or_callable is not None:
|
||||
if callable(name_or_callable) and hasattr(name_or_callable, "__name__"):
|
||||
# Used as a decorator without parameters
|
||||
# @tool
|
||||
# def my_tool():
|
||||
# pass
|
||||
return _create_tool_factory(name_or_callable.__name__)(name_or_callable)
|
||||
elif isinstance(name_or_callable, str):
|
||||
# Used with a new name for the tool
|
||||
# @tool("search")
|
||||
# def my_tool():
|
||||
# pass
|
||||
#
|
||||
# or
|
||||
#
|
||||
# @tool("search", parse_docstring=True)
|
||||
# def my_tool():
|
||||
# pass
|
||||
return _create_tool_factory(name_or_callable)
|
||||
else:
|
||||
msg = (
|
||||
f"The first argument must be a string or a callable with a __name__ "
|
||||
f"for tool decorator. Got {type(name_or_callable)}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
# Tool is used as a decorator with parameters specified
|
||||
# @tool(parse_docstring=True)
|
||||
# def my_tool():
|
||||
# pass
|
||||
def _partial(func: Union[Callable, Runnable]) -> BaseTool:
|
||||
"""Partial function that takes a callable and returns a tool."""
|
||||
name_ = func.get_name() if isinstance(func, Runnable) else func.__name__
|
||||
tool_factory = _create_tool_factory(name_)
|
||||
return tool_factory(func)
|
||||
|
||||
return _partial
|
||||
else:
|
||||
msg = "Too many arguments for tool decorator"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _get_description_from_runnable(runnable: Runnable) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user