mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
core[patch]: allow passing description to @tool decorator (#29976)
This commit is contained in:
parent
21bfc95e14
commit
0fc50b82a0
@ -459,6 +459,12 @@ class ChildTool(BaseTool):
|
||||
@property
|
||||
def tool_call_schema(self) -> ArgsSchema:
|
||||
if isinstance(self.args_schema, dict):
|
||||
if self.description:
|
||||
return {
|
||||
**self.args_schema,
|
||||
"description": self.description,
|
||||
}
|
||||
|
||||
return self.args_schema
|
||||
|
||||
full_schema = self.get_input_schema()
|
||||
|
@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools.base import BaseTool
|
||||
from langchain_core.tools.base import ArgsSchema, BaseTool
|
||||
from langchain_core.tools.simple import Tool
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
|
||||
@ -13,8 +13,9 @@ from langchain_core.tools.structured import StructuredTool
|
||||
@overload
|
||||
def tool(
|
||||
*,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type] = None,
|
||||
args_schema: Optional[ArgsSchema] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
@ -27,8 +28,9 @@ def tool(
|
||||
name_or_callable: str,
|
||||
runnable: Runnable,
|
||||
*,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type] = None,
|
||||
args_schema: Optional[ArgsSchema] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
@ -40,8 +42,9 @@ def tool(
|
||||
def tool(
|
||||
name_or_callable: Callable,
|
||||
*,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type] = None,
|
||||
args_schema: Optional[ArgsSchema] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
@ -53,8 +56,9 @@ def tool(
|
||||
def tool(
|
||||
name_or_callable: str,
|
||||
*,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type] = None,
|
||||
args_schema: Optional[ArgsSchema] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
@ -66,8 +70,9 @@ def tool(
|
||||
name_or_callable: Optional[Union[str, Callable]] = None,
|
||||
runnable: Optional[Runnable] = None,
|
||||
*args: Any,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[type] = None,
|
||||
args_schema: Optional[ArgsSchema] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
@ -83,6 +88,14 @@ def tool(
|
||||
converted to a tool. Must be provided as a positional argument.
|
||||
runnable: Optional runnable to convert to a tool. Must be provided as a
|
||||
positional argument.
|
||||
description: Optional description for the tool.
|
||||
Precedence for the tool description value is as follows:
|
||||
- `description` argument
|
||||
(used even if docstring and/or `args_schema` are provided)
|
||||
- tool function docstring
|
||||
(used even if `args_schema` is provided)
|
||||
- `args_schema` description
|
||||
(used only if `description` / docstring are not provided)
|
||||
return_direct: Whether to return directly from the tool rather
|
||||
than continuing the agent loop. Defaults to False.
|
||||
args_schema: optional argument schema for user to specify.
|
||||
@ -213,6 +226,7 @@ def tool(
|
||||
"""
|
||||
|
||||
def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool:
|
||||
tool_description = description
|
||||
if isinstance(dec_func, Runnable):
|
||||
runnable = dec_func
|
||||
|
||||
@ -232,25 +246,23 @@ def tool(
|
||||
|
||||
coroutine = ainvoke_wrapper
|
||||
func = invoke_wrapper
|
||||
schema: Optional[type[BaseModel]] = runnable.input_schema
|
||||
description = repr(runnable)
|
||||
schema: Optional[ArgsSchema] = runnable.input_schema
|
||||
tool_description = description or repr(runnable)
|
||||
elif inspect.iscoroutinefunction(dec_func):
|
||||
coroutine = dec_func
|
||||
func = None
|
||||
schema = args_schema
|
||||
description = None
|
||||
else:
|
||||
coroutine = None
|
||||
func = dec_func
|
||||
schema = args_schema
|
||||
description = None
|
||||
|
||||
if infer_schema or args_schema is not None:
|
||||
return StructuredTool.from_function(
|
||||
func,
|
||||
coroutine,
|
||||
name=tool_name,
|
||||
description=description,
|
||||
description=tool_description,
|
||||
return_direct=return_direct,
|
||||
args_schema=schema,
|
||||
infer_schema=infer_schema,
|
||||
|
@ -27,6 +27,7 @@ from langchain_core.tools.base import (
|
||||
_get_runnable_config_param,
|
||||
create_schema_from_function,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
class StructuredTool(BaseTool):
|
||||
@ -188,7 +189,16 @@ class StructuredTool(BaseTool):
|
||||
if description is None and not parse_docstring:
|
||||
description_ = source_function.__doc__ or None
|
||||
if description_ is None and args_schema:
|
||||
description_ = args_schema.__doc__ or None
|
||||
if isinstance(args_schema, type) and is_basemodel_subclass(args_schema):
|
||||
description_ = args_schema.__doc__ or None
|
||||
elif isinstance(args_schema, dict):
|
||||
description_ = args_schema.get("description")
|
||||
else:
|
||||
msg = (
|
||||
"Invalid args_schema: expected BaseModel or dict, "
|
||||
f"got {args_schema}"
|
||||
)
|
||||
raise TypeError(msg)
|
||||
if description_ is None:
|
||||
msg = "Function must have a docstring if description not provided."
|
||||
raise ValueError(msg)
|
||||
|
@ -52,6 +52,7 @@ from langchain_core.tools import (
|
||||
tool,
|
||||
)
|
||||
from langchain_core.tools.base import (
|
||||
ArgsSchema,
|
||||
InjectedToolArg,
|
||||
InjectedToolCallId,
|
||||
SchemaAnnotationError,
|
||||
@ -199,7 +200,7 @@ def test_decorator_with_specified_schema() -> None:
|
||||
assert isinstance(tool_func, BaseTool)
|
||||
assert tool_func.args_schema == _MockSchema
|
||||
|
||||
@tool(args_schema=_MockSchemaV1)
|
||||
@tool(args_schema=cast(ArgsSchema, _MockSchemaV1))
|
||||
def tool_func_v1(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
@ -2398,10 +2399,10 @@ def test_structured_tool_args_schema_dict() -> None:
|
||||
"required": ["a", "b"],
|
||||
"title": "add",
|
||||
"type": "object",
|
||||
"description": "add two numbers",
|
||||
}
|
||||
tool = StructuredTool(
|
||||
name="add",
|
||||
description="add two numbers",
|
||||
args_schema=args_schema,
|
||||
func=lambda a, b: a + b,
|
||||
)
|
||||
@ -2433,6 +2434,7 @@ def test_simple_tool_args_schema_dict() -> None:
|
||||
"required": ["a"],
|
||||
"title": "square",
|
||||
"type": "object",
|
||||
"description": "square a number",
|
||||
}
|
||||
tool = Tool(
|
||||
name="square",
|
||||
@ -2468,3 +2470,93 @@ def test_empty_string_tool_call_id() -> None:
|
||||
assert foo.invoke({"type": "tool_call", "args": {"x": 0}, "id": ""}) == ToolMessage(
|
||||
content="hi", name="foo", tool_call_id=""
|
||||
)
|
||||
|
||||
|
||||
def test_tool_decorator_description() -> None:
|
||||
# test basic tool
|
||||
@tool
|
||||
def foo(x: int) -> str:
|
||||
"""Foo."""
|
||||
return "hi"
|
||||
|
||||
assert foo.description == "Foo."
|
||||
assert (
|
||||
cast(BaseModel, foo.tool_call_schema).model_json_schema()["description"]
|
||||
== "Foo."
|
||||
)
|
||||
|
||||
# test basic tool with description
|
||||
@tool(description="description")
|
||||
def foo_description(x: int) -> str:
|
||||
"""Foo."""
|
||||
return "hi"
|
||||
|
||||
assert foo_description.description == "description"
|
||||
assert (
|
||||
cast(BaseModel, foo_description.tool_call_schema).model_json_schema()[
|
||||
"description"
|
||||
]
|
||||
== "description"
|
||||
)
|
||||
|
||||
# test tool with args schema
|
||||
class ArgsSchema(BaseModel):
|
||||
"""Bar."""
|
||||
|
||||
x: int
|
||||
|
||||
@tool(args_schema=ArgsSchema)
|
||||
def foo_args_schema(x: int) -> str:
|
||||
return "hi"
|
||||
|
||||
assert foo_args_schema.description == "Bar."
|
||||
assert (
|
||||
cast(BaseModel, foo_args_schema.tool_call_schema).model_json_schema()[
|
||||
"description"
|
||||
]
|
||||
== "Bar."
|
||||
)
|
||||
|
||||
@tool(description="description", args_schema=ArgsSchema)
|
||||
def foo_args_schema_description(x: int) -> str:
|
||||
return "hi"
|
||||
|
||||
assert foo_args_schema_description.description == "description"
|
||||
assert (
|
||||
cast(
|
||||
BaseModel, foo_args_schema_description.tool_call_schema
|
||||
).model_json_schema()["description"]
|
||||
== "description"
|
||||
)
|
||||
|
||||
args_json_schema = {
|
||||
"description": "JSON Schema.",
|
||||
"properties": {
|
||||
"x": {"description": "my field", "title": "X", "type": "string"}
|
||||
},
|
||||
"required": ["x"],
|
||||
"title": "my_tool",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
@tool(args_schema=args_json_schema)
|
||||
def foo_args_jsons_schema(x: int) -> str:
|
||||
return "hi"
|
||||
|
||||
@tool(description="description", args_schema=args_json_schema)
|
||||
def foo_args_jsons_schema_with_description(x: int) -> str:
|
||||
return "hi"
|
||||
|
||||
assert foo_args_jsons_schema.description == "JSON Schema."
|
||||
assert (
|
||||
cast(dict, foo_args_jsons_schema.tool_call_schema)["description"]
|
||||
== "JSON Schema."
|
||||
)
|
||||
|
||||
assert foo_args_jsons_schema_with_description.description == "description"
|
||||
assert (
|
||||
cast(dict, foo_args_jsons_schema_with_description.tool_call_schema)[
|
||||
"description"
|
||||
]
|
||||
== "description"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user