mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 12:00:46 +00:00
feat(core): Dynamically loading dbgpts (#1211)
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import schedule
|
||||
import tomlkit
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, root_validator
|
||||
@@ -36,6 +39,7 @@ class BasePackage(BaseModel):
|
||||
definition_file: Optional[str] = Field(
|
||||
default=None, description="The definition " "file of the package"
|
||||
)
|
||||
root: str = Field(..., description="The root of the package")
|
||||
repo: str = Field(..., description="The repository of the package")
|
||||
|
||||
@classmethod
|
||||
@@ -48,8 +52,13 @@ class BasePackage(BaseModel):
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
name = values.get("name")
|
||||
root = values.get("root")
|
||||
if not name:
|
||||
raise ValueError("The name is required")
|
||||
if not root:
|
||||
raise ValueError("The root is required")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
# Read the file
|
||||
values["path"] = os.path.dirname(os.path.abspath(path))
|
||||
@@ -91,6 +100,32 @@ class FlowJsonPackage(FlowPackage):
|
||||
class OperatorPackage(BasePackage):
|
||||
package_type = "operator"
|
||||
|
||||
operators: List[type] = Field(
|
||||
default_factory=list, description="The operators of the package"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
from dbgpt.core.awel import BaseOperator
|
||||
from dbgpt.core.awel.dag.loader import _load_modules_from_file
|
||||
|
||||
name = values.get("name")
|
||||
root = values.get("root")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
mods = _load_modules_from_file(str(path), name, show_log=False)
|
||||
all_cls = [_get_classes_from_module(m) for m in mods]
|
||||
operators = []
|
||||
for list_cls in all_cls:
|
||||
for c in list_cls:
|
||||
if issubclass(c, BaseOperator):
|
||||
operators.append(c)
|
||||
values["operators"] = operators
|
||||
return cls(**values)
|
||||
|
||||
|
||||
class InstalledPackage(BaseModel):
|
||||
name: str = Field(..., description="The name of the package")
|
||||
@@ -98,6 +133,15 @@ class InstalledPackage(BaseModel):
|
||||
root: str = Field(..., description="The root of the package")
|
||||
|
||||
|
||||
def _get_classes_from_module(module):
|
||||
classes = [
|
||||
obj
|
||||
for name, obj in inspect.getmembers(module, inspect.isclass)
|
||||
if obj.__module__ == module.__name__
|
||||
]
|
||||
return classes
|
||||
|
||||
|
||||
def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
|
||||
with open(
|
||||
Path(package.root) / DBGPTS_METADATA_FILE, mode="r+", encoding="utf-8"
|
||||
@@ -109,11 +153,17 @@ def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
|
||||
if key == "flow":
|
||||
pkg_dict = value
|
||||
pkg_dict["package_type"] = "flow"
|
||||
elif key == "operator":
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "operator"
|
||||
else:
|
||||
ext_metadata[key] = value
|
||||
pkg_dict["root"] = package.root
|
||||
pkg_dict["repo"] = package.repo
|
||||
if pkg_dict["package_type"] == "flow":
|
||||
return FlowPackage.build_from(pkg_dict, ext_metadata)
|
||||
elif pkg_dict["package_type"] == "operator":
|
||||
return OperatorPackage.build_from(pkg_dict, ext_metadata)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported package package_type: {pkg_dict['package_type']}"
|
||||
@@ -156,12 +206,16 @@ class DBGPTsLoader(BaseComponent):
|
||||
name = "dbgpt_dbgpts_loader"
|
||||
|
||||
def __init__(
|
||||
self, system_app: Optional[SystemApp] = None, install_dir: Optional[str] = None
|
||||
self,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
install_dir: Optional[str] = None,
|
||||
load_dbgpts_interval: int = 10,
|
||||
):
|
||||
"""Initialize the DBGPTsLoader."""
|
||||
self._system_app = None
|
||||
self._install_dir = install_dir or INSTALL_DIR
|
||||
self._packages: Dict[str, BasePackage] = {}
|
||||
self._load_dbgpts_interval = load_dbgpts_interval
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
@@ -170,15 +224,18 @@ class DBGPTsLoader(BaseComponent):
|
||||
|
||||
def before_start(self):
|
||||
"""Execute after the application starts."""
|
||||
self.load_package()
|
||||
self.load_package(is_first=True)
|
||||
|
||||
def load_package(self) -> None:
|
||||
schedule.every(self._load_dbgpts_interval).seconds.do(self.load_package)
|
||||
|
||||
def load_package(self, is_first: bool = False) -> None:
|
||||
"""Load the package by name."""
|
||||
try:
|
||||
packages = _load_package_from_path(self._install_dir)
|
||||
logger.info(
|
||||
f"Found {len(packages)} dbgpts packages from {self._install_dir}"
|
||||
)
|
||||
if is_first:
|
||||
logger.info(
|
||||
f"Found {len(packages)} dbgpts packages from {self._install_dir}"
|
||||
)
|
||||
for package in packages:
|
||||
self._packages[package.name] = package
|
||||
except Exception as e:
|
||||
|
Reference in New Issue
Block a user