mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 02:51:07 +00:00
refactor(agent): Refactor resource of agents (#1518)
This commit is contained in:
200
dbgpt/agent/resource/tool/tests/test_base_tool.py
Normal file
200
dbgpt/agent/resource/tool/tests/test_base_tool.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
from ..base import BaseTool, FunctionTool, ToolParameter, tool
|
||||
|
||||
|
||||
class TestBaseTool(BaseTool):
|
||||
@property
|
||||
def name(self):
|
||||
return "test_tool"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "This is a test tool."
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return {}
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
return "executed"
|
||||
|
||||
async def async_execute(self, *args, **kwargs):
|
||||
return "async executed"
|
||||
|
||||
|
||||
def test_base_tool():
|
||||
tool = TestBaseTool()
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "This is a test tool."
|
||||
assert tool.execute() == "executed"
|
||||
assert asyncio.run(tool.async_execute()) == "async executed"
|
||||
|
||||
|
||||
def test_function_tool_sync() -> None:
|
||||
def two_sum(a: int, b: int) -> int:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
ft = FunctionTool(name="sample", func=two_sum)
|
||||
assert ft.execute(1, 2) == 3
|
||||
with pytest.raises(ValueError):
|
||||
asyncio.run(ft.async_execute(1, 2))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_tool_async() -> None:
|
||||
async def sample_async_func(a: int, b: int) -> int:
|
||||
"""Add two numbers asynchronously."""
|
||||
return a + b
|
||||
|
||||
ft = FunctionTool(name="sample_async", func=sample_async_func)
|
||||
with pytest.raises(ValueError):
|
||||
ft.execute(1, 2)
|
||||
assert await ft.async_execute(1, 2) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_tool_sync_with_args() -> None:
|
||||
def two_sum(a: int, b: int) -> int:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
ft = FunctionTool(
|
||||
name="sample",
|
||||
func=two_sum,
|
||||
args={
|
||||
"a": {"type": "integer", "name": "a", "description": "The first number."},
|
||||
"b": {"type": "integer", "name": "b", "description": "The second number."},
|
||||
},
|
||||
)
|
||||
ft1 = FunctionTool(
|
||||
name="sample",
|
||||
func=two_sum,
|
||||
args={
|
||||
"a": ToolParameter(
|
||||
type="integer", name="a", description="The first number."
|
||||
),
|
||||
"b": ToolParameter(
|
||||
type="integer", name="b", description="The second number."
|
||||
),
|
||||
},
|
||||
)
|
||||
assert ft.description == "Add two numbers."
|
||||
assert ft.args.keys() == {"a", "b"}
|
||||
assert ft.args["a"].type == "integer"
|
||||
assert ft.args["a"].name == "a"
|
||||
assert ft.args["a"].description == "The first number."
|
||||
assert ft.args["a"].title == "A"
|
||||
dict_params = [
|
||||
{
|
||||
"name": "a",
|
||||
"type": "integer",
|
||||
"description": "The first number.",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"type": "integer",
|
||||
"description": "The second number.",
|
||||
"required": True,
|
||||
},
|
||||
]
|
||||
json_params = json.dumps(dict_params, ensure_ascii=False)
|
||||
expected_prompt = (
|
||||
f"sample: Call this tool to interact with the sample API. What is the "
|
||||
f"sample API useful for? Add two numbers. Parameters: {json_params}"
|
||||
)
|
||||
assert await ft.get_prompt() == expected_prompt
|
||||
assert await ft1.get_prompt() == expected_prompt
|
||||
assert ft.execute(1, 2) == 3
|
||||
with pytest.raises(ValueError):
|
||||
await ft.async_execute(1, 2)
|
||||
|
||||
|
||||
def test_function_tool_sync_with_complex_types() -> None:
|
||||
@tool
|
||||
def complex_func(
|
||||
a: int,
|
||||
b: Annotated[int, Doc("The second number.")],
|
||||
c: Annotated[str, Doc("The third string.")],
|
||||
d: List[int],
|
||||
e: Annotated[Dict[str, int], Doc("A dictionary of integers.")],
|
||||
f: Optional[float] = None,
|
||||
g: str | None = None,
|
||||
) -> int:
|
||||
"""A complex function."""
|
||||
return (
|
||||
a + b + len(c) + sum(d) + sum(e.values()) + (f or 0) + (len(g) if g else 0)
|
||||
)
|
||||
|
||||
ft: FunctionTool = complex_func._tool
|
||||
assert ft.description == "A complex function."
|
||||
assert ft.args.keys() == {"a", "b", "c", "d", "e", "f", "g"}
|
||||
assert ft.args["a"].type == "integer"
|
||||
assert ft.args["a"].description == "A"
|
||||
assert ft.args["b"].type == "integer"
|
||||
assert ft.args["b"].description == "The second number."
|
||||
assert ft.args["c"].type == "string"
|
||||
assert ft.args["c"].description == "The third string."
|
||||
assert ft.args["d"].type == "array"
|
||||
assert ft.args["d"].description == "D"
|
||||
assert ft.args["e"].type == "object"
|
||||
assert ft.args["e"].description == "A dictionary of integers."
|
||||
assert ft.args["f"].type == "float"
|
||||
assert ft.args["f"].description == "F"
|
||||
assert ft.args["g"].type == "string"
|
||||
assert ft.args["g"].description == "G"
|
||||
|
||||
|
||||
def test_function_tool_sync_with_args_schema() -> None:
|
||||
class ArgsSchema(BaseModel):
|
||||
a: int = Field(description="The first number.")
|
||||
b: int = Field(description="The second number.")
|
||||
c: Optional[str] = Field(None, description="The third string.")
|
||||
d: List[int] = Field(description="Numbers.")
|
||||
|
||||
@tool(args_schema=ArgsSchema)
|
||||
def complex_func(a: int, b: int, c: Optional[str] = None) -> int:
|
||||
"""A complex function."""
|
||||
return a + b + len(c) if c else 0
|
||||
|
||||
ft: FunctionTool = complex_func._tool
|
||||
assert ft.description == "A complex function."
|
||||
assert ft.args.keys() == {"a", "b", "c", "d"}
|
||||
assert ft.args["a"].type == "integer"
|
||||
assert ft.args["a"].description == "The first number."
|
||||
assert ft.args["b"].type == "integer"
|
||||
assert ft.args["b"].description == "The second number."
|
||||
assert ft.args["c"].type == "string"
|
||||
assert ft.args["c"].description == "The third string."
|
||||
assert ft.args["d"].type == "array"
|
||||
assert ft.args["d"].description == "Numbers."
|
||||
|
||||
|
||||
def test_tool_decorator() -> None:
|
||||
@tool(description="Add two numbers")
|
||||
def add(a: int, b: int) -> int:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
assert add(1, 2) == 3
|
||||
assert add._tool.name == "add"
|
||||
assert add._tool.description == "Add two numbers"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_decorator_async() -> None:
|
||||
@tool
|
||||
async def async_add(a: int, b: int) -> int:
|
||||
"""Asynchronously add two numbers."""
|
||||
return a + b
|
||||
|
||||
assert await async_add(1, 2) == 3
|
Reference in New Issue
Block a user