DB-GPT/dbgpt/agent/resource/tool/tests/test_base_tool.py
明天 b124ecc10b
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
2024-08-21 17:37:45 +08:00

203 lines
6.2 KiB
Python

import asyncio
import json
from typing import Dict, List, Optional
import pytest
from typing_extensions import Annotated, Doc
from dbgpt._private.pydantic import BaseModel, Field
from ..base import BaseTool, FunctionTool, ToolParameter, tool
class TestBaseTool(BaseTool):
    """Minimal concrete ``BaseTool`` used as a fixture by ``test_base_tool``.

    Implements ``name``, ``description``, ``args``, ``execute`` and
    ``async_execute`` with fixed return values so the test can check the
    base-class interface without any real tool logic.
    """

    # The class name matches pytest's default ``Test*`` collection pattern, but
    # this is a fixture, not a test class; opt out of collection explicitly.
    __test__ = False

    @property
    def name(self):
        """Return the fixed tool name."""
        return "test_tool"

    @property
    def description(self):
        """Return the fixed tool description."""
        return "This is a test tool."

    @property
    def args(self):
        """Return an empty parameter mapping (the tool takes no arguments)."""
        return {}

    def execute(self, *args, **kwargs):
        """Ignore all arguments and return the sync sentinel string."""
        return "executed"

    async def async_execute(self, *args, **kwargs):
        """Ignore all arguments and return the async sentinel string."""
        return "async executed"
def test_base_tool():
    """The fixture tool exposes its fixed metadata and both execution paths."""
    instance = TestBaseTool()

    # Metadata properties.
    assert instance.name == "test_tool"
    assert instance.description == "This is a test tool."

    # Synchronous, then asynchronous execution.
    assert instance.execute() == "executed"
    async_result = asyncio.run(instance.async_execute())
    assert async_result == "async executed"
def test_function_tool_sync() -> None:
    """A ``FunctionTool`` wrapping a sync function runs only through ``execute``."""

    def two_sum(a: int, b: int) -> int:
        """Add two numbers."""
        return a + b

    sync_tool = FunctionTool(name="sample", func=two_sum)
    result = sync_tool.execute(1, 2)
    assert result == 3

    # The async entry point rejects a synchronous function.
    with pytest.raises(ValueError):
        asyncio.run(sync_tool.async_execute(1, 2))
@pytest.mark.asyncio
async def test_function_tool_async() -> None:
    """A ``FunctionTool`` wrapping a coroutine runs only through ``async_execute``."""

    async def sample_async_func(a: int, b: int) -> int:
        """Add two numbers asynchronously."""
        return a + b

    async_tool = FunctionTool(name="sample_async", func=sample_async_func)

    # The sync entry point rejects a coroutine function.
    with pytest.raises(ValueError):
        async_tool.execute(1, 2)

    total = await async_tool.async_execute(1, 2)
    assert total == 3
@pytest.mark.asyncio
async def test_function_tool_sync_with_args() -> None:
    """Explicit ``args`` (plain dicts or ``ToolParameter``) drive metadata and prompt."""

    def two_sum(a: int, b: int) -> int:
        """Add two numbers."""
        return a + b

    descriptions = {"a": "The first number.", "b": "The second number."}

    # The same two parameters, declared once as plain dicts ...
    dict_args_tool = FunctionTool(
        name="sample",
        func=two_sum,
        args={
            key: {"type": "integer", "name": key, "description": text}
            for key, text in descriptions.items()
        },
    )
    # ... and once as ToolParameter instances.
    model_args_tool = FunctionTool(
        name="sample",
        func=two_sum,
        args={
            key: ToolParameter(type="integer", name=key, description=text)
            for key, text in descriptions.items()
        },
    )

    assert dict_args_tool.description == "Add two numbers."
    assert dict_args_tool.args.keys() == {"a", "b"}
    first = dict_args_tool.args["a"]
    assert first.type == "integer"
    assert first.name == "a"
    assert first.description == "The first number."
    assert first.title == "A"

    expected_params = json.dumps(
        [
            {
                "name": key,
                "type": "integer",
                "description": text,
                "required": True,
            }
            for key, text in descriptions.items()
        ],
        ensure_ascii=False,
    )
    expected_prompt = (
        "sample: Call this tool to interact with the sample API. What is the "
        f"sample API useful for? Add two numbers. Parameters: {expected_params}"
    )

    # Both declaration styles must render the identical prompt.
    prompts = []
    for candidate in (dict_args_tool, model_args_tool):
        prompt, _info = await candidate.get_prompt()
        prompts.append(prompt)
    assert prompts == [expected_prompt, expected_prompt]

    assert dict_args_tool.execute(1, 2) == 3
    with pytest.raises(ValueError):
        await dict_args_tool.async_execute(1, 2)
def test_function_tool_sync_with_complex_types() -> None:
    """``@tool`` infers parameter types and descriptions from annotations."""

    @tool
    def complex_func(
        a: int,
        b: Annotated[int, Doc("The second number.")],
        c: Annotated[str, Doc("The third string.")],
        d: List[int],
        e: Annotated[Dict[str, int], Doc("A dictionary of integers.")],
        f: Optional[float] = None,
        g: str | None = None,
    ) -> int:
        """A complex function."""
        return (
            a + b + len(c) + sum(d) + sum(e.values()) + (f or 0) + (len(g) if g else 0)
        )

    function_tool: FunctionTool = complex_func._tool
    assert function_tool.description == "A complex function."

    # (type, description) per parameter. Parameters without a ``Doc``
    # annotation are expected to fall back to their capitalised name.
    expected = {
        "a": ("integer", "A"),
        "b": ("integer", "The second number."),
        "c": ("string", "The third string."),
        "d": ("array", "D"),
        "e": ("object", "A dictionary of integers."),
        "f": ("float", "F"),
        "g": ("string", "G"),
    }
    assert function_tool.args.keys() == set(expected)
    for arg_name, (arg_type, arg_description) in expected.items():
        parameter = function_tool.args[arg_name]
        assert parameter.type == arg_type
        assert parameter.description == arg_description
def test_function_tool_sync_with_args_schema() -> None:
    """``args_schema`` supplies the tool's parameters instead of the signature.

    The schema declares a field ``d`` that ``complex_func`` does not accept.
    The assertions check that ``args`` still lists it, i.e. the schema wins.
    """

    class ArgsSchema(BaseModel):
        a: int = Field(description="The first number.")
        b: int = Field(description="The second number.")
        c: Optional[str] = Field(None, description="The third string.")
        d: List[int] = Field(description="Numbers.")

    @tool(args_schema=ArgsSchema)
    def complex_func(a: int, b: int, c: Optional[str] = None) -> int:
        """A complex function."""
        # Parenthesize the conditional. Without the parentheses the whole sum is
        # discarded, so the function returns 0 instead of a + b when c is None.
        return a + b + (len(c) if c else 0)

    ft: FunctionTool = complex_func._tool
    assert ft.description == "A complex function."
    assert ft.args.keys() == {"a", "b", "c", "d"}
    assert ft.args["a"].type == "integer"
    assert ft.args["a"].description == "The first number."
    assert ft.args["b"].type == "integer"
    assert ft.args["b"].description == "The second number."
    assert ft.args["c"].type == "string"
    assert ft.args["c"].description == "The third string."
    assert ft.args["d"].type == "array"
    assert ft.args["d"].description == "Numbers."

    # The decorated function is still callable; an omitted ``c`` contributes 0.
    assert complex_func(1, 2) == 3
    assert complex_func(1, 2, "abc") == 6
def test_tool_decorator() -> None:
    """An explicit ``description`` wins over the docstring; the name is the function's."""

    @tool(description="Add two numbers")
    def add(a: int, b: int) -> int:
        """Add two numbers."""
        return a + b

    # The decorated function remains directly callable.
    assert add(1, 2) == 3

    metadata = add._tool
    assert metadata.name == "add"
    assert metadata.description == "Add two numbers"
@pytest.mark.asyncio
async def test_tool_decorator_async() -> None:
    """``@tool`` applied to a coroutine function yields an awaitable callable."""

    @tool
    async def async_add(a: int, b: int) -> int:
        """Asynchronously add two numbers."""
        return a + b

    result = await async_add(1, 2)
    assert result == 3