mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-09 12:59:43 +00:00
feat(agent): dbgpts support agent (#1417)
This commit is contained in:
@@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, cast
|
||||
|
||||
import schedule
|
||||
import tomlkit
|
||||
@@ -18,6 +18,7 @@ from dbgpt.util.dbgpts.base import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BasePackage(BaseModel):
|
||||
@@ -68,6 +69,32 @@ class BasePackage(BaseModel):
|
||||
def abs_definition_file(self) -> str:
|
||||
return str(Path(self.path) / self.definition_file)
|
||||
|
||||
@classmethod
|
||||
def load_module_class(
|
||||
cls, values: Dict[str, Any], expected_cls: Type[T]
|
||||
) -> List[Type[T]]:
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
from dbgpt.core.awel.dag.loader import _load_modules_from_file
|
||||
|
||||
name = values.get("name")
|
||||
root = values.get("root")
|
||||
if not name:
|
||||
raise ValueError("The name is required")
|
||||
if not root:
|
||||
raise ValueError("The root is required")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
mods = _load_modules_from_file(str(path), name, show_log=False)
|
||||
all_cls = [_get_classes_from_module(m) for m in mods]
|
||||
module_cls = []
|
||||
for list_cls in all_cls:
|
||||
for c in list_cls:
|
||||
if issubclass(c, expected_cls):
|
||||
module_cls.append(c)
|
||||
return module_cls
|
||||
|
||||
|
||||
class FlowPackage(BasePackage):
|
||||
package_type = "flow"
|
||||
@@ -107,24 +134,24 @@ class OperatorPackage(BasePackage):
|
||||
|
||||
@classmethod
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
from dbgpt.core.awel import BaseOperator
|
||||
from dbgpt.core.awel.dag.loader import _load_modules_from_file
|
||||
|
||||
name = values.get("name")
|
||||
root = values.get("root")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
mods = _load_modules_from_file(str(path), name, show_log=False)
|
||||
all_cls = [_get_classes_from_module(m) for m in mods]
|
||||
operators = []
|
||||
for list_cls in all_cls:
|
||||
for c in list_cls:
|
||||
if issubclass(c, BaseOperator):
|
||||
operators.append(c)
|
||||
values["operators"] = operators
|
||||
values["operators"] = cls.load_module_class(values, BaseOperator)
|
||||
return cls(**values)
|
||||
|
||||
|
||||
class AgentPackage(BasePackage):
|
||||
package_type = "agent"
|
||||
|
||||
agents: List[type] = Field(
|
||||
default_factory=list, description="The agents of the package"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
from dbgpt.agent import ConversableAgent
|
||||
|
||||
values["agents"] = cls.load_module_class(values, ConversableAgent)
|
||||
return cls(**values)
|
||||
|
||||
|
||||
@@ -153,11 +180,14 @@ def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
|
||||
pkg_dict = {}
|
||||
for key, value in metadata.items():
|
||||
if key == "flow":
|
||||
pkg_dict = value
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "flow"
|
||||
elif key == "operator":
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "operator"
|
||||
elif key == "agent":
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "agent"
|
||||
else:
|
||||
ext_metadata[key] = value
|
||||
pkg_dict["root"] = package.root
|
||||
@@ -167,6 +197,8 @@ def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
|
||||
return FlowPackage.build_from(pkg_dict, ext_metadata)
|
||||
elif pkg_dict["package_type"] == "operator":
|
||||
return OperatorPackage.build_from(pkg_dict, ext_metadata)
|
||||
elif pkg_dict["package_type"] == "agent":
|
||||
return AgentPackage.build_from(pkg_dict, ext_metadata)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported package package_type: {pkg_dict['package_type']}"
|
||||
@@ -243,6 +275,7 @@ class DBGPTsLoader(BaseComponent):
|
||||
)
|
||||
for package in packages:
|
||||
self._packages[package.name] = package
|
||||
self._register_packages(package)
|
||||
except Exception as e:
|
||||
logger.warning(f"Load dbgpts package error: {e}")
|
||||
|
||||
@@ -268,3 +301,16 @@ class DBGPTsLoader(BaseComponent):
|
||||
}
|
||||
panels.append(FlowPanel(**dict_value))
|
||||
return panels
|
||||
|
||||
def _register_packages(self, package: BasePackage):
|
||||
if package.package_type == "agent":
|
||||
from dbgpt.agent import ConversableAgent, get_agent_manager
|
||||
|
||||
agent_manager = get_agent_manager(self._system_app)
|
||||
pkg = cast(AgentPackage, package)
|
||||
for agent_cls in pkg.agents:
|
||||
if issubclass(agent_cls, ConversableAgent):
|
||||
try:
|
||||
agent_manager.register_agent(agent_cls, ignore_duplicate=True)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Register agent {agent_cls} error: {e}")
|
||||
|
@@ -73,8 +73,7 @@ def _print_repos():
|
||||
table.add_column(_("Repository"), justify="right", style="cyan", no_wrap=True)
|
||||
table.add_column(_("Path"), justify="right", style="green")
|
||||
for repo, full_path in repos:
|
||||
if full_path.startswith(str(Path.home())):
|
||||
full_path = full_path.replace(str(Path.home()), "~")
|
||||
full_path = _print_path(full_path)
|
||||
table.add_row(repo, full_path)
|
||||
cl.print(table)
|
||||
|
||||
@@ -245,7 +244,14 @@ def _copy_and_install(repo: str, name: str, package_path: Path):
|
||||
shutil.copytree(package_path, install_path)
|
||||
cl.info(f"Installing dbgpts '{name}' from {repo}...")
|
||||
os.chdir(install_path)
|
||||
subprocess.run(["poetry", "install"], check=True)
|
||||
subprocess.run(["poetry", "build"], check=True)
|
||||
wheel_files = list(install_path.glob("dist/*.whl"))
|
||||
if not wheel_files:
|
||||
cl.error("No wheel file found after building the package.", exit_code=1)
|
||||
# Install the wheel file using pip
|
||||
wheel_file = wheel_files[0]
|
||||
cl.info(f"Installing dbgpts '{name}' wheel file {_print_path(wheel_file)}...")
|
||||
subprocess.run(["pip", "install", str(wheel_file)], check=True)
|
||||
_write_install_metadata(name, repo, install_path)
|
||||
cl.success(f"Installed dbgpts at {_print_path(install_path)}.")
|
||||
cl.success(f"dbgpts '{name}' installed successfully.")
|
||||
@@ -357,7 +363,6 @@ def list_installed_apps():
|
||||
packages.sort(key=lambda x: (x.package, x.package_type, x.repo))
|
||||
for package in packages:
|
||||
str_path = package.root
|
||||
if str_path.startswith(str(Path.home())):
|
||||
str_path = str_path.replace(str(Path.home()), "~")
|
||||
str_path = _print_path(str_path)
|
||||
table.add_row(package.package, package.package_type, package.repo, str_path)
|
||||
cl.print(table)
|
||||
|
@@ -50,6 +50,15 @@ def create_template(
|
||||
definition_type,
|
||||
working_directory,
|
||||
)
|
||||
elif dbgpts_type == "agent":
|
||||
_create_agent_template(
|
||||
name,
|
||||
mod_name,
|
||||
dbgpts_type,
|
||||
base_metadata,
|
||||
definition_type,
|
||||
working_directory,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid dbgpts type: {dbgpts_type}")
|
||||
|
||||
@@ -111,6 +120,31 @@ def _create_operator_template(
|
||||
_write_manifest_file(working_directory, name, mod_name)
|
||||
|
||||
|
||||
def _create_agent_template(
|
||||
name: str,
|
||||
mod_name: str,
|
||||
dbgpts_type: str,
|
||||
base_metadata: dict,
|
||||
definition_type: str,
|
||||
working_directory: str,
|
||||
):
|
||||
json_dict = {
|
||||
"agent": base_metadata,
|
||||
"python_config": {},
|
||||
"json_config": {},
|
||||
}
|
||||
if definition_type != "python":
|
||||
raise click.ClickException(
|
||||
f"Unsupported definition type: {definition_type} for dbgpts type: "
|
||||
f"{dbgpts_type}"
|
||||
)
|
||||
|
||||
_create_poetry_project(working_directory, name)
|
||||
_write_dbgpts_toml(working_directory, name, json_dict)
|
||||
_write_agent_init_file(working_directory, name, mod_name)
|
||||
_write_manifest_file(working_directory, name, mod_name)
|
||||
|
||||
|
||||
def _create_poetry_project(working_directory: str, name: str):
|
||||
"""Create a new poetry project"""
|
||||
|
||||
@@ -207,3 +241,122 @@ class HelloWorldOperator(MapOperator[str, str]):
|
||||
"""
|
||||
with open(init_file, "w") as f:
|
||||
f.write(f'"""{name} operator package"""\n{content}')
|
||||
|
||||
|
||||
def _write_agent_init_file(working_directory: str, name: str, mod_name: str):
|
||||
"""Write the agent __init__.py file"""
|
||||
|
||||
init_file = Path(working_directory) / name / mod_name / "__init__.py"
|
||||
content = """
|
||||
import asyncio
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from dbgpt.agent import (
|
||||
AgentMessage,
|
||||
Action,
|
||||
ActionOutput,
|
||||
AgentResource,
|
||||
ConversableAgent,
|
||||
)
|
||||
from dbgpt.agent.util import cmp_string_equal
|
||||
|
||||
_HELLO_WORLD = "Hello world"
|
||||
|
||||
|
||||
class HelloWorldSpeakerAgent(ConversableAgent):
|
||||
name: str = "Hodor"
|
||||
profile: str = "HelloWorldSpeaker"
|
||||
goal: str = f"answer any question from user with '{_HELLO_WORLD}'"
|
||||
desc: str = f"You can answer any question from user with '{_HELLO_WORLD}'"
|
||||
constraints: list[str] = [
|
||||
"You can only answer with '{fix_message}'",
|
||||
"You can't use any other words",
|
||||
]
|
||||
examples: str = (
|
||||
f"user: What's your name?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: What's the weather today?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Can you help me?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Please tell me a joke.\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Please answer me without '{_HELLO_WORLD}'.\\nassistant: {_HELLO_WORLD}"
|
||||
"\\n\\n",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([HelloWorldAction])
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
# Fill in the dynamic parameters in the prompt template
|
||||
reply_message.context = {"fix_message": _HELLO_WORLD}
|
||||
return reply_message
|
||||
|
||||
async def correctness_check(
|
||||
self, message: AgentMessage
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
action_report = message.action_report
|
||||
task_result = ""
|
||||
if action_report:
|
||||
task_result = action_report.get("content", "")
|
||||
if not cmp_string_equal(
|
||||
task_result,
|
||||
_HELLO_WORLD,
|
||||
ignore_case=True,
|
||||
ignore_punctuation=True,
|
||||
ignore_whitespace=True,
|
||||
):
|
||||
return False, f"Please answer with {_HELLO_WORLD}, not '{task_result}'"
|
||||
return True, None
|
||||
|
||||
|
||||
class HelloWorldAction(Action[None]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
ai_message: str,
|
||||
resource: Optional[AgentResource] = None,
|
||||
rely_action_out: Optional[ActionOutput] = None,
|
||||
need_vis_render: bool = True,
|
||||
**kwargs,
|
||||
) -> ActionOutput:
|
||||
return ActionOutput(is_exe_success=True, content=ai_message)
|
||||
|
||||
|
||||
async def _test_agent():
|
||||
\"\"\"Test the agent.
|
||||
|
||||
It will not run in the production environment.
|
||||
\"\"\"
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.agent import AgentContext, GptsMemory, UserProxyAgent, LLMConfig
|
||||
|
||||
llm_client = OpenAILLMClient(model_alias="gpt-3.5-turbo")
|
||||
context: AgentContext = AgentContext(conv_id="summarize")
|
||||
|
||||
default_memory: GptsMemory = GptsMemory()
|
||||
|
||||
speaker = (
|
||||
await HelloWorldSpeakerAgent()
|
||||
.bind(context)
|
||||
.bind(LLMConfig(llm_client=llm_client))
|
||||
.bind(default_memory)
|
||||
.build()
|
||||
)
|
||||
|
||||
user_proxy = await UserProxyAgent().bind(default_memory).bind(context).build()
|
||||
await user_proxy.initiate_chat(
|
||||
recipient=speaker,
|
||||
reviewer=user_proxy,
|
||||
message="What's your name?",
|
||||
)
|
||||
print(await default_memory.one_chat_completions("summarize"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(_test_agent())
|
||||
|
||||
"""
|
||||
with open(init_file, "w") as f:
|
||||
f.write(f'"""{name} agent package."""\n{content}')
|
||||
|
Reference in New Issue
Block a user