mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 20:09:01 +00:00
Structured Tool Bugfixes (#3324)
- Proactively raise error if a tool subclasses BaseTool, defines its own schema, but fails to add the type-hints - fix the auto-inferred schema of the decorator to strip the unneeded virtual kwargs from the schema dict Helps avoid silent instances of #3297
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
"""Test tool utils."""
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.base import BaseTool, SchemaAnnotationError
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
@@ -51,10 +53,116 @@ def test_structured_args() -> None:
|
||||
assert structured_api.run(args) == expected_result
|
||||
|
||||
|
||||
def test_structured_args_decorator() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
def test_unannotated_base_tool_raises_error() -> None:
|
||||
"""Test that a BaseTool without type hints raises an exception.""" ""
|
||||
with pytest.raises(SchemaAnnotationError):
|
||||
|
||||
class _UnAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
# This would silently be ignored without the custom metaclass
|
||||
args_schema = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_misannotated_base_tool_raises_error() -> None:
|
||||
"""Test that a BaseTool with the incorrrect typehint raises an exception.""" ""
|
||||
with pytest.raises(SchemaAnnotationError):
|
||||
|
||||
class _MisAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
# This would silently be ignored without the custom metaclass
|
||||
args_schema: BaseModel = _MockSchema # type: ignore
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_forward_ref_annotated_base_tool_accepted() -> None:
|
||||
"""Test that a using forward ref annotation syntax is accepted.""" ""
|
||||
|
||||
class _ForwardRefAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: "Type[BaseModel]" = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_subclass_annotated_base_tool_accepted() -> None:
|
||||
"""Test BaseTool child w/ custom schema isn't overwritten."""
|
||||
|
||||
class _ForwardRefAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: Type[_MockSchema] = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
assert issubclass(_ForwardRefAnnotatedTool, BaseTool)
|
||||
tool = _ForwardRefAnnotatedTool()
|
||||
assert tool.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorator_with_specified_schema() -> None:
|
||||
"""Test that manually specified schemata are passed through to the tool."""
|
||||
|
||||
@tool(args_schema=_MockSchema)
|
||||
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(tool_func, Tool)
|
||||
assert tool_func.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorated_function_schema_equivalent() -> None:
|
||||
"""Test that a BaseTool without a schema meets expectations."""
|
||||
|
||||
@tool
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(structured_tool_input, Tool)
|
||||
assert (
|
||||
structured_tool_input.args_schema.schema()["properties"]
|
||||
== _MockSchema.schema()["properties"]
|
||||
== structured_tool_input.args
|
||||
)
|
||||
|
||||
|
||||
def test_structured_args_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
|
||||
) -> str:
|
||||
@@ -68,8 +176,83 @@ def test_structured_args_decorator() -> None:
|
||||
assert structured_tool_input.run(args) == expected_result
|
||||
|
||||
|
||||
def test_structured_single_str_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def unstructured_tool_input(tool_input: str) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{tool_input}"
|
||||
|
||||
assert isinstance(unstructured_tool_input, Tool)
|
||||
assert unstructured_tool_input.args_schema is None
|
||||
|
||||
|
||||
def test_base_tool_inheritance_base_schema() -> None:
|
||||
"""Test schema is correctly inferred when inheriting from BaseTool."""
|
||||
|
||||
class _MockSimpleTool(BaseTool):
|
||||
name = "simple_tool"
|
||||
description = "A Simple Tool"
|
||||
|
||||
def _run(self, tool_input: str) -> str:
|
||||
return f"{tool_input}"
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
simple_tool = _MockSimpleTool()
|
||||
assert simple_tool.args_schema is None
|
||||
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
|
||||
assert simple_tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_lambda_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
|
||||
tool = Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=lambda tool_input: tool_input,
|
||||
)
|
||||
assert tool.args_schema is None
|
||||
expected_args = {"tool_input": {"title": "Tool Input"}}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_lambda_multi_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
tool = Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
|
||||
)
|
||||
assert tool.args_schema is None
|
||||
expected_args = {
|
||||
"tool_input": {"title": "Tool Input"},
|
||||
"other_arg": {"title": "Other Arg"},
|
||||
}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_partial_function_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a partial function."""
|
||||
|
||||
def func(tool_input: str, other_arg: str) -> str:
|
||||
return tool_input + other_arg
|
||||
|
||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
||||
# We don't yet support args_schema inference for partial functions
|
||||
# so want to make sure we proactively raise an error
|
||||
Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=partial(func, other_arg="foo"),
|
||||
)
|
||||
|
||||
|
||||
def test_empty_args_decorator() -> None:
|
||||
"""Test functionality with no args parsed as a decorator."""
|
||||
"""Test inferred schema of decorated fn with no args."""
|
||||
|
||||
@tool
|
||||
def empty_tool_input() -> str:
|
||||
@@ -78,6 +261,7 @@ def test_empty_args_decorator() -> None:
|
||||
|
||||
assert isinstance(empty_tool_input, Tool)
|
||||
assert empty_tool_input.name == "empty_tool_input"
|
||||
assert empty_tool_input.args == {}
|
||||
assert empty_tool_input.run({}) == "the empty result"
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user