mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 20:10:08 +00:00
feat(agent): dbgpts support agent (#1417)
This commit is contained in:
@@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, cast
|
||||
|
||||
import schedule
|
||||
import tomlkit
|
||||
@@ -18,6 +18,7 @@ from dbgpt.util.dbgpts.base import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BasePackage(BaseModel):
|
||||
@@ -68,6 +69,32 @@ class BasePackage(BaseModel):
|
||||
def abs_definition_file(self) -> str:
|
||||
return str(Path(self.path) / self.definition_file)
|
||||
|
||||
@classmethod
|
||||
def load_module_class(
|
||||
cls, values: Dict[str, Any], expected_cls: Type[T]
|
||||
) -> List[Type[T]]:
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
from dbgpt.core.awel.dag.loader import _load_modules_from_file
|
||||
|
||||
name = values.get("name")
|
||||
root = values.get("root")
|
||||
if not name:
|
||||
raise ValueError("The name is required")
|
||||
if not root:
|
||||
raise ValueError("The root is required")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
mods = _load_modules_from_file(str(path), name, show_log=False)
|
||||
all_cls = [_get_classes_from_module(m) for m in mods]
|
||||
module_cls = []
|
||||
for list_cls in all_cls:
|
||||
for c in list_cls:
|
||||
if issubclass(c, expected_cls):
|
||||
module_cls.append(c)
|
||||
return module_cls
|
||||
|
||||
|
||||
class FlowPackage(BasePackage):
|
||||
package_type = "flow"
|
||||
@@ -107,24 +134,24 @@ class OperatorPackage(BasePackage):
|
||||
|
||||
@classmethod
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
from dbgpt.core.awel import BaseOperator
|
||||
from dbgpt.core.awel.dag.loader import _load_modules_from_file
|
||||
|
||||
name = values.get("name")
|
||||
root = values.get("root")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
mods = _load_modules_from_file(str(path), name, show_log=False)
|
||||
all_cls = [_get_classes_from_module(m) for m in mods]
|
||||
operators = []
|
||||
for list_cls in all_cls:
|
||||
for c in list_cls:
|
||||
if issubclass(c, BaseOperator):
|
||||
operators.append(c)
|
||||
values["operators"] = operators
|
||||
values["operators"] = cls.load_module_class(values, BaseOperator)
|
||||
return cls(**values)
|
||||
|
||||
|
||||
class AgentPackage(BasePackage):
|
||||
package_type = "agent"
|
||||
|
||||
agents: List[type] = Field(
|
||||
default_factory=list, description="The agents of the package"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
from dbgpt.agent import ConversableAgent
|
||||
|
||||
values["agents"] = cls.load_module_class(values, ConversableAgent)
|
||||
return cls(**values)
|
||||
|
||||
|
||||
@@ -153,11 +180,14 @@ def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
|
||||
pkg_dict = {}
|
||||
for key, value in metadata.items():
|
||||
if key == "flow":
|
||||
pkg_dict = value
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "flow"
|
||||
elif key == "operator":
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "operator"
|
||||
elif key == "agent":
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "agent"
|
||||
else:
|
||||
ext_metadata[key] = value
|
||||
pkg_dict["root"] = package.root
|
||||
@@ -167,6 +197,8 @@ def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
|
||||
return FlowPackage.build_from(pkg_dict, ext_metadata)
|
||||
elif pkg_dict["package_type"] == "operator":
|
||||
return OperatorPackage.build_from(pkg_dict, ext_metadata)
|
||||
elif pkg_dict["package_type"] == "agent":
|
||||
return AgentPackage.build_from(pkg_dict, ext_metadata)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported package package_type: {pkg_dict['package_type']}"
|
||||
@@ -243,6 +275,7 @@ class DBGPTsLoader(BaseComponent):
|
||||
)
|
||||
for package in packages:
|
||||
self._packages[package.name] = package
|
||||
self._register_packages(package)
|
||||
except Exception as e:
|
||||
logger.warning(f"Load dbgpts package error: {e}")
|
||||
|
||||
@@ -268,3 +301,16 @@ class DBGPTsLoader(BaseComponent):
|
||||
}
|
||||
panels.append(FlowPanel(**dict_value))
|
||||
return panels
|
||||
|
||||
def _register_packages(self, package: BasePackage):
|
||||
if package.package_type == "agent":
|
||||
from dbgpt.agent import ConversableAgent, get_agent_manager
|
||||
|
||||
agent_manager = get_agent_manager(self._system_app)
|
||||
pkg = cast(AgentPackage, package)
|
||||
for agent_cls in pkg.agents:
|
||||
if issubclass(agent_cls, ConversableAgent):
|
||||
try:
|
||||
agent_manager.register_agent(agent_cls, ignore_duplicate=True)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Register agent {agent_cls} error: {e}")
|
||||
|
Reference in New Issue
Block a user