core[patch]: allow passing description to @tool decorator (#29976)

This commit is contained in:
Vadym Barda 2025-02-25 17:45:36 -05:00 committed by GitHub
parent 21bfc95e14
commit 0fc50b82a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 134 additions and 14 deletions

View File

@ -459,6 +459,12 @@ class ChildTool(BaseTool):
@property @property
def tool_call_schema(self) -> ArgsSchema: def tool_call_schema(self) -> ArgsSchema:
if isinstance(self.args_schema, dict): if isinstance(self.args_schema, dict):
if self.description:
return {
**self.args_schema,
"description": self.description,
}
return self.args_schema return self.args_schema
full_schema = self.get_input_schema() full_schema = self.get_input_schema()

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, create_model
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
from langchain_core.runnables import Runnable from langchain_core.runnables import Runnable
from langchain_core.tools.base import BaseTool from langchain_core.tools.base import ArgsSchema, BaseTool
from langchain_core.tools.simple import Tool from langchain_core.tools.simple import Tool
from langchain_core.tools.structured import StructuredTool from langchain_core.tools.structured import StructuredTool
@ -13,8 +13,9 @@ from langchain_core.tools.structured import StructuredTool
@overload @overload
def tool( def tool(
*, *,
description: Optional[str] = None,
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[type] = None, args_schema: Optional[ArgsSchema] = None,
infer_schema: bool = True, infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content", response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False, parse_docstring: bool = False,
@ -27,8 +28,9 @@ def tool(
name_or_callable: str, name_or_callable: str,
runnable: Runnable, runnable: Runnable,
*, *,
description: Optional[str] = None,
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[type] = None, args_schema: Optional[ArgsSchema] = None,
infer_schema: bool = True, infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content", response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False, parse_docstring: bool = False,
@ -40,8 +42,9 @@ def tool(
def tool( def tool(
name_or_callable: Callable, name_or_callable: Callable,
*, *,
description: Optional[str] = None,
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[type] = None, args_schema: Optional[ArgsSchema] = None,
infer_schema: bool = True, infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content", response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False, parse_docstring: bool = False,
@ -53,8 +56,9 @@ def tool(
def tool( def tool(
name_or_callable: str, name_or_callable: str,
*, *,
description: Optional[str] = None,
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[type] = None, args_schema: Optional[ArgsSchema] = None,
infer_schema: bool = True, infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content", response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False, parse_docstring: bool = False,
@ -66,8 +70,9 @@ def tool(
name_or_callable: Optional[Union[str, Callable]] = None, name_or_callable: Optional[Union[str, Callable]] = None,
runnable: Optional[Runnable] = None, runnable: Optional[Runnable] = None,
*args: Any, *args: Any,
description: Optional[str] = None,
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[type] = None, args_schema: Optional[ArgsSchema] = None,
infer_schema: bool = True, infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content", response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False, parse_docstring: bool = False,
@ -83,6 +88,14 @@ def tool(
converted to a tool. Must be provided as a positional argument. converted to a tool. Must be provided as a positional argument.
runnable: Optional runnable to convert to a tool. Must be provided as a runnable: Optional runnable to convert to a tool. Must be provided as a
positional argument. positional argument.
description: Optional description for the tool.
Precedence for the tool description value is as follows:
- `description` argument
(used even if docstring and/or `args_schema` are provided)
- tool function docstring
(used even if `args_schema` is provided)
- `args_schema` description
(used only if `description` / docstring are not provided)
return_direct: Whether to return directly from the tool rather return_direct: Whether to return directly from the tool rather
than continuing the agent loop. Defaults to False. than continuing the agent loop. Defaults to False.
args_schema: optional argument schema for user to specify. args_schema: optional argument schema for user to specify.
@ -213,6 +226,7 @@ def tool(
""" """
def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool: def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool:
tool_description = description
if isinstance(dec_func, Runnable): if isinstance(dec_func, Runnable):
runnable = dec_func runnable = dec_func
@ -232,25 +246,23 @@ def tool(
coroutine = ainvoke_wrapper coroutine = ainvoke_wrapper
func = invoke_wrapper func = invoke_wrapper
schema: Optional[type[BaseModel]] = runnable.input_schema schema: Optional[ArgsSchema] = runnable.input_schema
description = repr(runnable) tool_description = description or repr(runnable)
elif inspect.iscoroutinefunction(dec_func): elif inspect.iscoroutinefunction(dec_func):
coroutine = dec_func coroutine = dec_func
func = None func = None
schema = args_schema schema = args_schema
description = None
else: else:
coroutine = None coroutine = None
func = dec_func func = dec_func
schema = args_schema schema = args_schema
description = None
if infer_schema or args_schema is not None: if infer_schema or args_schema is not None:
return StructuredTool.from_function( return StructuredTool.from_function(
func, func,
coroutine, coroutine,
name=tool_name, name=tool_name,
description=description, description=tool_description,
return_direct=return_direct, return_direct=return_direct,
args_schema=schema, args_schema=schema,
infer_schema=infer_schema, infer_schema=infer_schema,

View File

@ -27,6 +27,7 @@ from langchain_core.tools.base import (
_get_runnable_config_param, _get_runnable_config_param,
create_schema_from_function, create_schema_from_function,
) )
from langchain_core.utils.pydantic import is_basemodel_subclass
class StructuredTool(BaseTool): class StructuredTool(BaseTool):
@ -188,7 +189,16 @@ class StructuredTool(BaseTool):
if description is None and not parse_docstring: if description is None and not parse_docstring:
description_ = source_function.__doc__ or None description_ = source_function.__doc__ or None
if description_ is None and args_schema: if description_ is None and args_schema:
description_ = args_schema.__doc__ or None if isinstance(args_schema, type) and is_basemodel_subclass(args_schema):
description_ = args_schema.__doc__ or None
elif isinstance(args_schema, dict):
description_ = args_schema.get("description")
else:
msg = (
"Invalid args_schema: expected BaseModel or dict, "
f"got {args_schema}"
)
raise TypeError(msg)
if description_ is None: if description_ is None:
msg = "Function must have a docstring if description not provided." msg = "Function must have a docstring if description not provided."
raise ValueError(msg) raise ValueError(msg)

View File

@ -52,6 +52,7 @@ from langchain_core.tools import (
tool, tool,
) )
from langchain_core.tools.base import ( from langchain_core.tools.base import (
ArgsSchema,
InjectedToolArg, InjectedToolArg,
InjectedToolCallId, InjectedToolCallId,
SchemaAnnotationError, SchemaAnnotationError,
@ -199,7 +200,7 @@ def test_decorator_with_specified_schema() -> None:
assert isinstance(tool_func, BaseTool) assert isinstance(tool_func, BaseTool)
assert tool_func.args_schema == _MockSchema assert tool_func.args_schema == _MockSchema
@tool(args_schema=_MockSchemaV1) @tool(args_schema=cast(ArgsSchema, _MockSchemaV1))
def tool_func_v1(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: def tool_func_v1(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}" return f"{arg1} {arg2} {arg3}"
@ -2398,10 +2399,10 @@ def test_structured_tool_args_schema_dict() -> None:
"required": ["a", "b"], "required": ["a", "b"],
"title": "add", "title": "add",
"type": "object", "type": "object",
"description": "add two numbers",
} }
tool = StructuredTool( tool = StructuredTool(
name="add", name="add",
description="add two numbers",
args_schema=args_schema, args_schema=args_schema,
func=lambda a, b: a + b, func=lambda a, b: a + b,
) )
@ -2433,6 +2434,7 @@ def test_simple_tool_args_schema_dict() -> None:
"required": ["a"], "required": ["a"],
"title": "square", "title": "square",
"type": "object", "type": "object",
"description": "square a number",
} }
tool = Tool( tool = Tool(
name="square", name="square",
@ -2468,3 +2470,93 @@ def test_empty_string_tool_call_id() -> None:
assert foo.invoke({"type": "tool_call", "args": {"x": 0}, "id": ""}) == ToolMessage( assert foo.invoke({"type": "tool_call", "args": {"x": 0}, "id": ""}) == ToolMessage(
content="hi", name="foo", tool_call_id="" content="hi", name="foo", tool_call_id=""
) )
def test_tool_decorator_description() -> None:
# test basic tool
@tool
def foo(x: int) -> str:
"""Foo."""
return "hi"
assert foo.description == "Foo."
assert (
cast(BaseModel, foo.tool_call_schema).model_json_schema()["description"]
== "Foo."
)
# test basic tool with description
@tool(description="description")
def foo_description(x: int) -> str:
"""Foo."""
return "hi"
assert foo_description.description == "description"
assert (
cast(BaseModel, foo_description.tool_call_schema).model_json_schema()[
"description"
]
== "description"
)
# test tool with args schema
class ArgsSchema(BaseModel):
"""Bar."""
x: int
@tool(args_schema=ArgsSchema)
def foo_args_schema(x: int) -> str:
return "hi"
assert foo_args_schema.description == "Bar."
assert (
cast(BaseModel, foo_args_schema.tool_call_schema).model_json_schema()[
"description"
]
== "Bar."
)
@tool(description="description", args_schema=ArgsSchema)
def foo_args_schema_description(x: int) -> str:
return "hi"
assert foo_args_schema_description.description == "description"
assert (
cast(
BaseModel, foo_args_schema_description.tool_call_schema
).model_json_schema()["description"]
== "description"
)
args_json_schema = {
"description": "JSON Schema.",
"properties": {
"x": {"description": "my field", "title": "X", "type": "string"}
},
"required": ["x"],
"title": "my_tool",
"type": "object",
}
@tool(args_schema=args_json_schema)
def foo_args_jsons_schema(x: int) -> str:
return "hi"
@tool(description="description", args_schema=args_json_schema)
def foo_args_jsons_schema_with_description(x: int) -> str:
return "hi"
assert foo_args_jsons_schema.description == "JSON Schema."
assert (
cast(dict, foo_args_jsons_schema.tool_call_schema)["description"]
== "JSON Schema."
)
assert foo_args_jsons_schema_with_description.description == "description"
assert (
cast(dict, foo_args_jsons_schema_with_description.tool_call_schema)[
"description"
]
== "description"
)