mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-19 03:26:46 +00:00
209 lines
6.4 KiB
Python
209 lines
6.4 KiB
Python
import logging
|
|
from functools import partial
|
|
from typing import Any, Optional, Type
|
|
|
|
import pydantic
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from langchain.tools.structured import (
|
|
BaseStructuredTool,
|
|
StructuredTool,
|
|
structured_tool,
|
|
)
|
|
|
|
|
|
# Hand-written argument schema reused across the tests below: the mock tool
# classes use it as ``args_schema``, and the decorator tests compare their
# inferred schemas against it.
class _MockSchema(BaseModel):
    arg1: int
    arg2: bool
    # Optional free-form mapping; omitted from input means None.
    arg3: Optional[dict] = None
|
|
|
|
|
|
class _MockStructuredTool(BaseStructuredTool):
    # Minimal concrete BaseStructuredTool: validates input against _MockSchema
    # and echoes the three arguments back from the sync entry point.
    name = "structured_api"
    args_schema: Type[BaseModel] = _MockSchema
    description = "A Structured Tool"

    def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        """Echo the arguments as one space-separated string."""
        pieces = (arg1, arg2, arg3)
        return " ".join(str(piece) for piece in pieces)

    async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        """Async execution is intentionally not supported by this mock."""
        raise NotImplementedError
|
|
|
|
|
|
def test_structured_args() -> None:
    """Run a hand-written structured tool with a dict of named arguments."""
    tool_instance = _MockStructuredTool()
    assert isinstance(tool_instance, BaseStructuredTool)
    assert tool_instance.name == "structured_api"

    tool_args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
    output = tool_instance.run(tool_args)
    assert output == "1 True {'foo': 'bar'}"
|
|
|
|
|
|
def test_subclass_annotated_base_tool_accepted() -> None:
    """A subclass that narrows the args_schema annotation keeps its own schema."""

    class _NarrowedSchemaTool(BaseStructuredTool):
        name = "structured_api"
        # Annotated with the concrete schema class instead of Type[BaseModel].
        args_schema: Type[_MockSchema] = _MockSchema
        description = "A Structured Tool"

        def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
            return " ".join(str(value) for value in (arg1, arg2, arg3))

        async def _arun(
            self, arg1: int, arg2: bool, arg3: Optional[dict] = None
        ) -> str:
            raise NotImplementedError

    assert issubclass(_NarrowedSchemaTool, BaseStructuredTool)
    assert _NarrowedSchemaTool().args_schema == _MockSchema
|
|
|
|
|
|
def test_decorator_with_specified_schema() -> None:
    """An args_schema passed explicitly to the decorator is used as-is."""

    @structured_tool(args_schema=_MockSchema)
    def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        """Return the arguments directly."""
        return " ".join(map(str, (arg1, arg2, arg3)))

    assert isinstance(tool_func, StructuredTool)
    assert tool_func.args_schema == _MockSchema
|
|
|
|
|
|
def test_decorated_function_schema_equivalent() -> None:
    """A schema inferred from a decorated signature matches _MockSchema."""

    @structured_tool
    def structured_tool_input(
        arg1: int, arg2: bool, arg3: Optional[dict] = None
    ) -> str:
        """Return the arguments directly."""
        return " ".join(map(str, (arg1, arg2, arg3)))

    assert isinstance(structured_tool_input, StructuredTool)
    inferred_properties = structured_tool_input.args_schema.schema()["properties"]
    reference_properties = _MockSchema.schema()["properties"]
    assert inferred_properties == reference_properties
|
|
|
|
|
|
def test_tool_lambda_multi_args_schema() -> None:
    """Schema inference for an unannotated two-argument lambda."""
    tool = StructuredTool.from_function(
        func=lambda tool_input, other_arg: f"{tool_input}{other_arg}",  # type: ignore
        name="tool",
        description="A tool",
    )
    properties = tool.args_schema.schema()["properties"]
    # Check the key set first for a clearer failure, then the full mapping.
    assert set(properties) == {"tool_input", "other_arg"}
    assert properties == {
        "tool_input": {"title": "Tool Input"},
        "other_arg": {"title": "Other Arg"},
    }
|
|
|
|
|
|
def test_tool_partial_function_args_schema() -> None:
    """Constructing a StructuredTool from a functools.partial fails validation.

    Schema inference for partial functions is not supported yet, so the
    constructor must reject it loudly instead of producing a wrong schema.
    """

    def func(tool_input: str, other_arg: str) -> str:
        return tool_input + other_arg

    # Use the public top-level name: in pydantic v1 it is the same class as
    # ``pydantic.error_wrappers.ValidationError``, and the private
    # ``error_wrappers`` module path does not survive pydantic v2.
    with pytest.raises(pydantic.ValidationError):
        # We don't yet support args_schema inference for partial functions
        # so want to make sure we proactively raise an error
        StructuredTool(
            name="tool",
            description="A tool",
            func=partial(func, other_arg="foo"),
        )
|
|
|
|
|
|
def test_tool_with_kwargs() -> None:
    """Decorating with only return_direct still yields a working tool."""

    @structured_tool(return_direct=True)
    def search_api(
        arg_1: float,
        ping: str = "hi",
    ) -> str:
        """Search the API for the query."""
        return f"arg_1={arg_1}, ping={ping}"

    assert isinstance(search_api, StructuredTool)

    # First supply the defaulted argument explicitly, then rely on its default.
    cases = [
        ({"arg_1": 3.2, "ping": "pong"}, "arg_1=3.2, ping=pong"),
        ({"arg_1": 3.2}, "arg_1=3.2, ping=hi"),
    ]
    for payload, expected in cases:
        assert search_api.run(tool_input=payload) == expected
|
|
|
|
|
|
def test_empty_args_decorator() -> None:
    """A decorated zero-argument function gets an empty inferred schema."""

    @structured_tool
    def empty_tool_input() -> str:
        """Return a constant."""
        return "the empty result"

    assert isinstance(empty_tool_input, StructuredTool)
    assert empty_tool_input.name == "empty_tool_input"

    schema_properties = empty_tool_input.args_schema.schema()["properties"]
    assert schema_properties == {}

    result = empty_tool_input.run({})
    assert result == "the empty result"
|
|
|
|
|
|
def test_nested_pydantic_args() -> None:
    """Test inferred schema when args are nested pydantic models."""
    # This is a pattern that is common with FastAPI methods.
    # If we only parse a dict input but pass the dict
    # to the function, we are limited only to primitive types
    # in general.

    class SomeNestedInput(BaseModel):
        arg2: str

    class SomeInput(BaseModel):
        arg1: int
        arg2: SomeNestedInput

    @structured_tool
    def nested_tool(some_input: SomeInput) -> dict:
        """Return the parsed input model as a plain dict."""
        # ``.dict()`` only exists on the model, so this call also proves the
        # raw dict input was parsed into SomeInput before reaching us.
        return some_input.dict()

    assert isinstance(nested_tool, StructuredTool)
    assert nested_tool.name == "nested_tool"
    input_ = {"some_input": {"arg1": 1, "arg2": {"arg2": "foo"}}}
    assert nested_tool.run(input_) == input_["some_input"]
|
|
|
|
|
|
def test_warning_on_args_kwargs(caplog: pytest.LogCaptureFixture) -> None:
    """Test that decorating a function with *args/**kwargs logs warnings."""

    with caplog.at_level(logging.WARNING):

        @structured_tool
        def anything_goes(*foo: Any, **bar: Any) -> str:
            """Return the positional and keyword arguments joined by '|'."""
            return str(foo) + "|" + str(bar)

    # Check if the expected warning message was logged
    assert any(
        "anything_goes uses *args" in record.message for record in caplog.records
    )
    assert any(
        "anything_goes uses **kwargs" in record.message for record in caplog.records
    )
|