mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-22 09:28:42 +00:00
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: lcx01800250 <lcx01800250@alibaba-inc.com> Co-authored-by: licunxing <864255598@qq.com> Co-authored-by: Aralhi <xiaoping0501@gmail.com> Co-authored-by: xuyuan23 <643854343@qq.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: hzh97 <2976151305@qq.com>
209 lines
6.9 KiB
Python
209 lines
6.9 KiB
Python
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, cast
|
|
|
|
import tomlkit
|
|
|
|
from dbgpt._private.pydantic import BaseModel, Field, root_validator
|
|
from dbgpt.component import BaseComponent, SystemApp
|
|
from dbgpt.core.awel.flow.flow_factory import FlowPanel
|
|
from dbgpt.util.dbgpts.base import (
|
|
DBGPTS_METADATA_FILE,
|
|
INSTALL_DIR,
|
|
INSTALL_METADATA_FILE,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class BasePackage(BaseModel):
    """Common metadata shared by every dbgpts package.

    Instances are normally created through ``build_from``. The ``path`` field
    is always (re)computed by the ``pre_fill`` validator from the importable
    Python package named ``name``.
    """

    class Config:
        arbitrary_types_allowed = True

    name: str = Field(..., description="The name of the package")
    label: str = Field(..., description="The label of the package")
    package_type: str = Field(..., description="The type of the package")
    version: str = Field(..., description="The version of the package")
    description: str = Field(..., description="The description of the package")
    path: str = Field(..., description="The path of the package")
    authors: List[str] = Field(
        default_factory=list, description="The authors of the package"
    )
    definition_type: str = Field(
        default="python", description="The type of the package definition"
    )
    definition_file: Optional[str] = Field(
        default=None, description="The definition file of the package"
    )
    repo: str = Field(..., description="The repository of the package")

    @classmethod
    def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
        """Build a package from its metadata.

        Args:
            values: The field values of the package.
            ext_dict: Extra metadata; ignored here, subclasses may use it.
        """
        return cls(**values)

    @root_validator(pre=True)
    def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Pre-fill ``path`` with the directory of the package ``name``.

        Raises:
            ValueError: If ``name`` is missing or empty.
        """
        import importlib.resources as pkg_resources

        name = values.get("name")
        if not name:
            raise ValueError("The name is required")
        # ``name`` has to be importable as a Python package; its directory
        # (where ``__init__.py`` lives) becomes the package path.
        with pkg_resources.path(name, "__init__.py") as path:
            values["path"] = os.path.dirname(os.path.abspath(path))
        return values

    def abs_definition_file(self) -> str:
        """Return ``definition_file`` resolved against the package ``path``.

        Raises:
            ValueError: If the package has no ``definition_file``.
        """
        # Without this check ``Path / None`` raises an unhelpful TypeError.
        if self.definition_file is None:
            raise ValueError(f"The package {self.name} has no definition file")
        return str(Path(self.path) / self.definition_file)
|
|
|
|
|
|
class FlowPackage(BasePackage):
    """A dbgpts package of type ``"flow"``."""

    # Annotated so the override is a proper model field (non-annotated
    # overrides of inherited fields are rejected by pydantic v2).
    package_type: str = "flow"

    @classmethod
    def build_from(
        cls, values: Dict[str, Any], ext_dict: Dict[str, Any]
    ) -> "FlowPackage":
        """Build a flow package, dispatching on ``definition_type``.

        A ``"json"`` definition type is delegated to ``FlowJsonPackage``; any
        other value yields a plain ``FlowPackage``.

        Args:
            values: The field values of the package.
            ext_dict: Extra metadata, forwarded to ``FlowJsonPackage``.
        """
        # Fall back to the field default instead of raising KeyError when the
        # metadata does not declare ``definition_type``.
        definition_type = values.get("definition_type", "python")
        if definition_type == "json":
            return FlowJsonPackage.build_from(values, ext_dict)
        return cls(**values)
|
|
|
|
|
|
class FlowJsonPackage(FlowPackage):
    """A flow package whose definition lives in a JSON file."""

    @classmethod
    def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
        """Build the package, taking its definition file from ``json_config``.

        Raises:
            ValueError: If ``ext_dict`` has no ``json_config`` entry, or that
                entry has no ``file_path``.
        """
        if "json_config" not in ext_dict:
            raise ValueError("The json_config is required")
        json_config = ext_dict["json_config"]
        if "file_path" not in json_config:
            raise ValueError("The file_path is required")
        values["definition_file"] = json_config["file_path"]
        return cls(**values)

    def read_definition_json(self) -> Dict[str, Any]:
        """Read the definition file and return its parsed JSON content."""
        import json

        definition_path = self.abs_definition_file()
        with open(definition_path, "r", encoding="utf-8") as handle:
            raw_text = handle.read()
        return json.loads(raw_text)
|
|
|
|
|
|
class OperatorPackage(BasePackage):
    """A dbgpts package of type ``"operator"``."""

    # Annotated so the override is a proper model field (non-annotated
    # overrides of inherited fields are rejected by pydantic v2).
    package_type: str = "operator"
|
|
|
|
|
|
class InstalledPackage(BaseModel):
    """Minimal record of a package found in the dbgpts install directory.

    It holds the name and repository read from the install metadata file and
    the directory the package was found in; the package's own metadata file
    is parsed separately.
    """

    name: str = Field(..., description="The name of the package")
    repo: str = Field(..., description="The repository of the package")
    # Directory of the installed package (holds its metadata files).
    root: str = Field(..., description="The root of the package")
|
|
|
|
|
|
def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
    """Parse the dbgpts metadata file of an installed package.

    The ``flow`` table supplies the package fields; every other top-level
    entry is passed on as extra metadata to the package builder.

    Args:
        package: The installed package to parse.

    Returns:
        BasePackage: The parsed package.

    Raises:
        ValueError: If the metadata has no supported package table, or the
            package type is not supported.
    """
    metadata_path = Path(package.root) / DBGPTS_METADATA_FILE
    # Read-only access is all we need; "r+" would fail on read-only installs.
    with open(metadata_path, mode="r", encoding="utf-8") as f:
        metadata = tomlkit.loads(f.read())

    ext_metadata: Dict[str, Any] = {}
    pkg_dict: Dict[str, Any] = {}
    for key, value in metadata.items():
        if key == "flow":
            # Copy so the parsed TOML document itself is not mutated.
            pkg_dict = dict(value)
            pkg_dict["package_type"] = "flow"
        else:
            ext_metadata[key] = value

    package_type = pkg_dict.get("package_type")
    if package_type is None:
        # Previously a bare KeyError('package_type'); give a usable message.
        raise ValueError(
            "No supported package definition (e.g. a [flow] table) found in "
            f"{metadata_path}"
        )
    pkg_dict["repo"] = package.repo
    if package_type == "flow":
        return FlowPackage.build_from(pkg_dict, ext_metadata)
    raise ValueError(f"Unsupported package package_type: {package_type}")
|
|
|
|
|
|
def _load_installed_package(path: str) -> List[InstalledPackage]:
    """List the installed packages found directly under ``path``.

    A sub-directory counts as an installed package only when it contains both
    the install metadata file and the dbgpts metadata file; everything else is
    skipped.

    Args:
        path: The install directory to scan.

    Returns:
        List[InstalledPackage]: The packages found, ordered by directory name.
    """
    packages = []
    # Sort so the load order does not depend on the file system's listing order.
    for entry in sorted(os.listdir(path)):
        full_path = Path(path) / entry
        install_metadata_file = full_path / INSTALL_METADATA_FILE
        dbgpts_metadata_file = full_path / DBGPTS_METADATA_FILE
        if not (
            full_path.is_dir()
            and install_metadata_file.exists()
            and dbgpts_metadata_file.exists()
        ):
            continue
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(install_metadata_file, encoding="utf-8") as f:
            metadata = tomlkit.loads(f.read())
        packages.append(
            InstalledPackage(
                name=metadata["name"], repo=metadata["repo"], root=str(full_path)
            )
        )
    return packages
|
|
|
|
|
|
def _load_package_from_path(path: str):
    """Parse every installed package found under ``path``.

    Packages are returned in the order they were discovered; any parsing
    error propagates to the caller.
    """
    installed = _load_installed_package(path)
    return [_parse_package_metadata(pkg) for pkg in installed]
|
|
|
|
|
|
class DBGPTsLoader(BaseComponent):
    """The loader of the dbgpts packages.

    Scans the install directory for installed packages and exposes the
    JSON-defined flow packages as ``FlowPanel`` objects.
    """

    name = "dbgpt_dbgpts_loader"

    def __init__(
        self, system_app: Optional[SystemApp] = None, install_dir: Optional[str] = None
    ):
        """Initialize the DBGPTsLoader.

        Args:
            system_app: The system app, passed on to ``BaseComponent``.
            install_dir: Directory to scan for packages; ``INSTALL_DIR`` is
                used when it is not given (or empty).
        """
        self._system_app = None
        self._install_dir = install_dir or INSTALL_DIR
        # Loaded packages keyed by package name.
        self._packages: Dict[str, BasePackage] = {}
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        """Initialize the DBGPTsLoader."""
        self._system_app = system_app

    def before_start(self):
        """Load the installed packages before the application starts."""
        self.load_package()

    def load_package(self) -> None:
        """Load all packages from the install directory.

        Any error is logged (with its traceback) instead of being raised, so a
        broken install directory does not stop the application.
        """
        try:
            packages = _load_package_from_path(self._install_dir)
            logger.info(
                f"Found {len(packages)} dbgpts packages from {self._install_dir}"
            )
            for package in packages:
                self._packages[package.name] = package
        except Exception as e:
            logger.warning(f"Load dbgpts package error: {e}", exc_info=True)

    def get_flows(self) -> List[FlowPanel]:
        """Get the flows of the loaded JSON flow packages.

        Flow packages whose definition is not JSON cannot be turned into a
        ``FlowPanel`` and are skipped.

        Returns:
            List[FlowPanel]: The list of the flows
        """
        panels = []
        for package in self._packages.values():
            if package.package_type != "flow":
                continue
            # Only JSON-defined flows provide ``read_definition_json``; a
            # python-defined FlowPackage would raise AttributeError below.
            if not isinstance(package, FlowJsonPackage):
                continue
            dict_value = {
                "name": package.name,
                "label": package.label,
                "version": package.version,
                "editable": False,
                "description": package.description,
                "source": package.repo,
                "flow_data": package.read_definition_json(),
            }
            panels.append(FlowPanel(**dict_value))
        return panels
|