mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
core[patch]: allow passing description to @tool decorator (#29976)
This commit is contained in:
parent
21bfc95e14
commit
0fc50b82a0
@ -459,6 +459,12 @@ class ChildTool(BaseTool):
|
|||||||
@property
|
@property
|
||||||
def tool_call_schema(self) -> ArgsSchema:
|
def tool_call_schema(self) -> ArgsSchema:
|
||||||
if isinstance(self.args_schema, dict):
|
if isinstance(self.args_schema, dict):
|
||||||
|
if self.description:
|
||||||
|
return {
|
||||||
|
**self.args_schema,
|
||||||
|
"description": self.description,
|
||||||
|
}
|
||||||
|
|
||||||
return self.args_schema
|
return self.args_schema
|
||||||
|
|
||||||
full_schema = self.get_input_schema()
|
full_schema = self.get_input_schema()
|
||||||
|
@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, create_model
|
|||||||
|
|
||||||
from langchain_core.callbacks import Callbacks
|
from langchain_core.callbacks import Callbacks
|
||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
from langchain_core.tools.base import BaseTool
|
from langchain_core.tools.base import ArgsSchema, BaseTool
|
||||||
from langchain_core.tools.simple import Tool
|
from langchain_core.tools.simple import Tool
|
||||||
from langchain_core.tools.structured import StructuredTool
|
from langchain_core.tools.structured import StructuredTool
|
||||||
|
|
||||||
@ -13,8 +13,9 @@ from langchain_core.tools.structured import StructuredTool
|
|||||||
@overload
|
@overload
|
||||||
def tool(
|
def tool(
|
||||||
*,
|
*,
|
||||||
|
description: Optional[str] = None,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[type] = None,
|
args_schema: Optional[ArgsSchema] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True,
|
||||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
@ -27,8 +28,9 @@ def tool(
|
|||||||
name_or_callable: str,
|
name_or_callable: str,
|
||||||
runnable: Runnable,
|
runnable: Runnable,
|
||||||
*,
|
*,
|
||||||
|
description: Optional[str] = None,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[type] = None,
|
args_schema: Optional[ArgsSchema] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True,
|
||||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
@ -40,8 +42,9 @@ def tool(
|
|||||||
def tool(
|
def tool(
|
||||||
name_or_callable: Callable,
|
name_or_callable: Callable,
|
||||||
*,
|
*,
|
||||||
|
description: Optional[str] = None,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[type] = None,
|
args_schema: Optional[ArgsSchema] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True,
|
||||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
@ -53,8 +56,9 @@ def tool(
|
|||||||
def tool(
|
def tool(
|
||||||
name_or_callable: str,
|
name_or_callable: str,
|
||||||
*,
|
*,
|
||||||
|
description: Optional[str] = None,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[type] = None,
|
args_schema: Optional[ArgsSchema] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True,
|
||||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
@ -66,8 +70,9 @@ def tool(
|
|||||||
name_or_callable: Optional[Union[str, Callable]] = None,
|
name_or_callable: Optional[Union[str, Callable]] = None,
|
||||||
runnable: Optional[Runnable] = None,
|
runnable: Optional[Runnable] = None,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
|
description: Optional[str] = None,
|
||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[type] = None,
|
args_schema: Optional[ArgsSchema] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True,
|
||||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
@ -83,6 +88,14 @@ def tool(
|
|||||||
converted to a tool. Must be provided as a positional argument.
|
converted to a tool. Must be provided as a positional argument.
|
||||||
runnable: Optional runnable to convert to a tool. Must be provided as a
|
runnable: Optional runnable to convert to a tool. Must be provided as a
|
||||||
positional argument.
|
positional argument.
|
||||||
|
description: Optional description for the tool.
|
||||||
|
Precedence for the tool description value is as follows:
|
||||||
|
- `description` argument
|
||||||
|
(used even if docstring and/or `args_schema` are provided)
|
||||||
|
- tool function docstring
|
||||||
|
(used even if `args_schema` is provided)
|
||||||
|
- `args_schema` description
|
||||||
|
(used only if `description` / docstring are not provided)
|
||||||
return_direct: Whether to return directly from the tool rather
|
return_direct: Whether to return directly from the tool rather
|
||||||
than continuing the agent loop. Defaults to False.
|
than continuing the agent loop. Defaults to False.
|
||||||
args_schema: optional argument schema for user to specify.
|
args_schema: optional argument schema for user to specify.
|
||||||
@ -213,6 +226,7 @@ def tool(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool:
|
def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool:
|
||||||
|
tool_description = description
|
||||||
if isinstance(dec_func, Runnable):
|
if isinstance(dec_func, Runnable):
|
||||||
runnable = dec_func
|
runnable = dec_func
|
||||||
|
|
||||||
@ -232,25 +246,23 @@ def tool(
|
|||||||
|
|
||||||
coroutine = ainvoke_wrapper
|
coroutine = ainvoke_wrapper
|
||||||
func = invoke_wrapper
|
func = invoke_wrapper
|
||||||
schema: Optional[type[BaseModel]] = runnable.input_schema
|
schema: Optional[ArgsSchema] = runnable.input_schema
|
||||||
description = repr(runnable)
|
tool_description = description or repr(runnable)
|
||||||
elif inspect.iscoroutinefunction(dec_func):
|
elif inspect.iscoroutinefunction(dec_func):
|
||||||
coroutine = dec_func
|
coroutine = dec_func
|
||||||
func = None
|
func = None
|
||||||
schema = args_schema
|
schema = args_schema
|
||||||
description = None
|
|
||||||
else:
|
else:
|
||||||
coroutine = None
|
coroutine = None
|
||||||
func = dec_func
|
func = dec_func
|
||||||
schema = args_schema
|
schema = args_schema
|
||||||
description = None
|
|
||||||
|
|
||||||
if infer_schema or args_schema is not None:
|
if infer_schema or args_schema is not None:
|
||||||
return StructuredTool.from_function(
|
return StructuredTool.from_function(
|
||||||
func,
|
func,
|
||||||
coroutine,
|
coroutine,
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
description=description,
|
description=tool_description,
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
args_schema=schema,
|
args_schema=schema,
|
||||||
infer_schema=infer_schema,
|
infer_schema=infer_schema,
|
||||||
|
@ -27,6 +27,7 @@ from langchain_core.tools.base import (
|
|||||||
_get_runnable_config_param,
|
_get_runnable_config_param,
|
||||||
create_schema_from_function,
|
create_schema_from_function,
|
||||||
)
|
)
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
|
|
||||||
class StructuredTool(BaseTool):
|
class StructuredTool(BaseTool):
|
||||||
@ -188,7 +189,16 @@ class StructuredTool(BaseTool):
|
|||||||
if description is None and not parse_docstring:
|
if description is None and not parse_docstring:
|
||||||
description_ = source_function.__doc__ or None
|
description_ = source_function.__doc__ or None
|
||||||
if description_ is None and args_schema:
|
if description_ is None and args_schema:
|
||||||
description_ = args_schema.__doc__ or None
|
if isinstance(args_schema, type) and is_basemodel_subclass(args_schema):
|
||||||
|
description_ = args_schema.__doc__ or None
|
||||||
|
elif isinstance(args_schema, dict):
|
||||||
|
description_ = args_schema.get("description")
|
||||||
|
else:
|
||||||
|
msg = (
|
||||||
|
"Invalid args_schema: expected BaseModel or dict, "
|
||||||
|
f"got {args_schema}"
|
||||||
|
)
|
||||||
|
raise TypeError(msg)
|
||||||
if description_ is None:
|
if description_ is None:
|
||||||
msg = "Function must have a docstring if description not provided."
|
msg = "Function must have a docstring if description not provided."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
@ -52,6 +52,7 @@ from langchain_core.tools import (
|
|||||||
tool,
|
tool,
|
||||||
)
|
)
|
||||||
from langchain_core.tools.base import (
|
from langchain_core.tools.base import (
|
||||||
|
ArgsSchema,
|
||||||
InjectedToolArg,
|
InjectedToolArg,
|
||||||
InjectedToolCallId,
|
InjectedToolCallId,
|
||||||
SchemaAnnotationError,
|
SchemaAnnotationError,
|
||||||
@ -199,7 +200,7 @@ def test_decorator_with_specified_schema() -> None:
|
|||||||
assert isinstance(tool_func, BaseTool)
|
assert isinstance(tool_func, BaseTool)
|
||||||
assert tool_func.args_schema == _MockSchema
|
assert tool_func.args_schema == _MockSchema
|
||||||
|
|
||||||
@tool(args_schema=_MockSchemaV1)
|
@tool(args_schema=cast(ArgsSchema, _MockSchemaV1))
|
||||||
def tool_func_v1(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
def tool_func_v1(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
@ -2398,10 +2399,10 @@ def test_structured_tool_args_schema_dict() -> None:
|
|||||||
"required": ["a", "b"],
|
"required": ["a", "b"],
|
||||||
"title": "add",
|
"title": "add",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
"description": "add two numbers",
|
||||||
}
|
}
|
||||||
tool = StructuredTool(
|
tool = StructuredTool(
|
||||||
name="add",
|
name="add",
|
||||||
description="add two numbers",
|
|
||||||
args_schema=args_schema,
|
args_schema=args_schema,
|
||||||
func=lambda a, b: a + b,
|
func=lambda a, b: a + b,
|
||||||
)
|
)
|
||||||
@ -2433,6 +2434,7 @@ def test_simple_tool_args_schema_dict() -> None:
|
|||||||
"required": ["a"],
|
"required": ["a"],
|
||||||
"title": "square",
|
"title": "square",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
"description": "square a number",
|
||||||
}
|
}
|
||||||
tool = Tool(
|
tool = Tool(
|
||||||
name="square",
|
name="square",
|
||||||
@ -2468,3 +2470,93 @@ def test_empty_string_tool_call_id() -> None:
|
|||||||
assert foo.invoke({"type": "tool_call", "args": {"x": 0}, "id": ""}) == ToolMessage(
|
assert foo.invoke({"type": "tool_call", "args": {"x": 0}, "id": ""}) == ToolMessage(
|
||||||
content="hi", name="foo", tool_call_id=""
|
content="hi", name="foo", tool_call_id=""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_decorator_description() -> None:
|
||||||
|
# test basic tool
|
||||||
|
@tool
|
||||||
|
def foo(x: int) -> str:
|
||||||
|
"""Foo."""
|
||||||
|
return "hi"
|
||||||
|
|
||||||
|
assert foo.description == "Foo."
|
||||||
|
assert (
|
||||||
|
cast(BaseModel, foo.tool_call_schema).model_json_schema()["description"]
|
||||||
|
== "Foo."
|
||||||
|
)
|
||||||
|
|
||||||
|
# test basic tool with description
|
||||||
|
@tool(description="description")
|
||||||
|
def foo_description(x: int) -> str:
|
||||||
|
"""Foo."""
|
||||||
|
return "hi"
|
||||||
|
|
||||||
|
assert foo_description.description == "description"
|
||||||
|
assert (
|
||||||
|
cast(BaseModel, foo_description.tool_call_schema).model_json_schema()[
|
||||||
|
"description"
|
||||||
|
]
|
||||||
|
== "description"
|
||||||
|
)
|
||||||
|
|
||||||
|
# test tool with args schema
|
||||||
|
class ArgsSchema(BaseModel):
|
||||||
|
"""Bar."""
|
||||||
|
|
||||||
|
x: int
|
||||||
|
|
||||||
|
@tool(args_schema=ArgsSchema)
|
||||||
|
def foo_args_schema(x: int) -> str:
|
||||||
|
return "hi"
|
||||||
|
|
||||||
|
assert foo_args_schema.description == "Bar."
|
||||||
|
assert (
|
||||||
|
cast(BaseModel, foo_args_schema.tool_call_schema).model_json_schema()[
|
||||||
|
"description"
|
||||||
|
]
|
||||||
|
== "Bar."
|
||||||
|
)
|
||||||
|
|
||||||
|
@tool(description="description", args_schema=ArgsSchema)
|
||||||
|
def foo_args_schema_description(x: int) -> str:
|
||||||
|
return "hi"
|
||||||
|
|
||||||
|
assert foo_args_schema_description.description == "description"
|
||||||
|
assert (
|
||||||
|
cast(
|
||||||
|
BaseModel, foo_args_schema_description.tool_call_schema
|
||||||
|
).model_json_schema()["description"]
|
||||||
|
== "description"
|
||||||
|
)
|
||||||
|
|
||||||
|
args_json_schema = {
|
||||||
|
"description": "JSON Schema.",
|
||||||
|
"properties": {
|
||||||
|
"x": {"description": "my field", "title": "X", "type": "string"}
|
||||||
|
},
|
||||||
|
"required": ["x"],
|
||||||
|
"title": "my_tool",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
|
||||||
|
@tool(args_schema=args_json_schema)
|
||||||
|
def foo_args_jsons_schema(x: int) -> str:
|
||||||
|
return "hi"
|
||||||
|
|
||||||
|
@tool(description="description", args_schema=args_json_schema)
|
||||||
|
def foo_args_jsons_schema_with_description(x: int) -> str:
|
||||||
|
return "hi"
|
||||||
|
|
||||||
|
assert foo_args_jsons_schema.description == "JSON Schema."
|
||||||
|
assert (
|
||||||
|
cast(dict, foo_args_jsons_schema.tool_call_schema)["description"]
|
||||||
|
== "JSON Schema."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert foo_args_jsons_schema_with_description.description == "description"
|
||||||
|
assert (
|
||||||
|
cast(dict, foo_args_jsons_schema_with_description.tool_call_schema)[
|
||||||
|
"description"
|
||||||
|
]
|
||||||
|
== "description"
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user