mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-08 04:23:35 +00:00
feat(core): Dynamically loading dbgpts (#1211)
This commit is contained in:
@@ -21,6 +21,7 @@ def initialize_components(
|
||||
):
|
||||
# Lazy import to avoid high time cost
|
||||
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
|
||||
from dbgpt.app.initialization.scheduler import DefaultScheduler
|
||||
from dbgpt.app.initialization.serve_initialization import register_serve_apps
|
||||
from dbgpt.model.cluster.controller.controller import controller
|
||||
|
||||
@@ -28,6 +29,7 @@ def initialize_components(
|
||||
system_app.register(
|
||||
DefaultExecutorFactory, max_workers=param.default_thread_pool_size
|
||||
)
|
||||
system_app.register(DefaultScheduler)
|
||||
system_app.register_instance(controller)
|
||||
|
||||
from dbgpt.serve.agent.hub.controller import module_plugin
|
||||
|
43
dbgpt/app/initialization/scheduler.py
Normal file
43
dbgpt/app/initialization/scheduler.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import schedule
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DefaultScheduler(BaseComponent):
|
||||
"""The default scheduler"""
|
||||
|
||||
name = "dbgpt_default_scheduler"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_app: SystemApp,
|
||||
scheduler_delay_ms: int = 5000,
|
||||
scheduler_interval_ms: int = 1000,
|
||||
):
|
||||
super().__init__(system_app)
|
||||
self.system_app = system_app
|
||||
self._scheduler_interval_ms = scheduler_interval_ms
|
||||
self._scheduler_delay_ms = scheduler_delay_ms
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
self.system_app = system_app
|
||||
|
||||
def after_start(self):
|
||||
thread = threading.Thread(target=self._scheduler)
|
||||
thread.start()
|
||||
|
||||
def _scheduler(self):
|
||||
time.sleep(self._scheduler_delay_ms / 1000)
|
||||
while True:
|
||||
try:
|
||||
schedule.run_pending()
|
||||
except Exception as e:
|
||||
logger.debug(f"Scheduler error: {e}")
|
||||
finally:
|
||||
time.sleep(self._scheduler_interval_ms / 1000)
|
@@ -163,7 +163,7 @@ try:
|
||||
from dbgpt.util.dbgpts.cli import add_repo
|
||||
from dbgpt.util.dbgpts.cli import install as app_install
|
||||
from dbgpt.util.dbgpts.cli import list_all_apps as app_list
|
||||
from dbgpt.util.dbgpts.cli import list_repos, remove_repo
|
||||
from dbgpt.util.dbgpts.cli import list_repos, new_dbgpts, remove_repo
|
||||
from dbgpt.util.dbgpts.cli import uninstall as app_uninstall
|
||||
from dbgpt.util.dbgpts.cli import update_repo
|
||||
|
||||
@@ -174,6 +174,7 @@ try:
|
||||
add_command_alias(app_install, name="install", parent_group=app)
|
||||
add_command_alias(app_uninstall, name="uninstall", parent_group=app)
|
||||
add_command_alias(app_list, name="list-remote", parent_group=app)
|
||||
add_command_alias(new_dbgpts, name="app", parent_group=new)
|
||||
|
||||
except ImportError as e:
|
||||
logging.warning(f"Integrating dbgpt dbgpts command line tool failed: {e}")
|
||||
|
@@ -63,15 +63,19 @@ def _process_file(filepath) -> List[DAG]:
|
||||
return results
|
||||
|
||||
|
||||
def _load_modules_from_file(filepath: str):
|
||||
def _load_modules_from_file(
|
||||
filepath: str, mod_name: str | None = None, show_log: bool = True
|
||||
):
|
||||
import importlib
|
||||
import importlib.machinery
|
||||
import importlib.util
|
||||
|
||||
if show_log:
|
||||
logger.info(f"Importing {filepath}")
|
||||
|
||||
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
|
||||
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
|
||||
if mod_name is None:
|
||||
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
|
||||
|
||||
if mod_name in sys.modules:
|
||||
|
@@ -1046,7 +1046,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
|
||||
"key",
|
||||
str,
|
||||
optional=True,
|
||||
default="",
|
||||
default="messages",
|
||||
description="The key of the dict, link 'user_input'",
|
||||
)
|
||||
],
|
||||
|
@@ -59,7 +59,7 @@ class WorkerApplyOutput:
|
||||
return WorkerApplyOutput("Not outputs")
|
||||
combined_success = all(out.success for out in outs)
|
||||
max_timecost = max(out.timecost for out in outs)
|
||||
combined_message = ", ".join(out.message for out in outs)
|
||||
combined_message = "\n;".join(out.message for out in outs)
|
||||
return WorkerApplyOutput(combined_message, combined_success, max_timecost)
|
||||
|
||||
|
||||
|
@@ -6,6 +6,7 @@ import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import asdict
|
||||
from typing import Awaitable, Callable, Iterator
|
||||
@@ -490,6 +491,8 @@ class LocalWorkerManager(WorkerManager):
|
||||
async def _start_all_worker(
|
||||
self, apply_req: WorkerApplyRequest
|
||||
) -> WorkerApplyOutput:
|
||||
from httpx import TimeoutException, TransportError
|
||||
|
||||
# TODO avoid start twice
|
||||
start_time = time.time()
|
||||
logger.info(f"Begin start all worker, apply_req: {apply_req}")
|
||||
@@ -520,9 +523,24 @@ class LocalWorkerManager(WorkerManager):
|
||||
)
|
||||
)
|
||||
out.message = f"{info} start successfully"
|
||||
except Exception as e:
|
||||
except TimeoutException as e:
|
||||
out.success = False
|
||||
out.message = f"{info} start failed, {str(e)}"
|
||||
out.message = (
|
||||
f"{info} start failed for network timeout, please make "
|
||||
f"sure your port is available, if you are using global network "
|
||||
f"proxy, please close it"
|
||||
)
|
||||
except TransportError as e:
|
||||
out.success = False
|
||||
out.message = (
|
||||
f"{info} start failed for network error, please make "
|
||||
f"sure your port is available, if you are using global network "
|
||||
"proxy, please close it"
|
||||
)
|
||||
except Exception:
|
||||
err_msg = traceback.format_exc()
|
||||
out.success = False
|
||||
out.message = f"{info} start failed, {err_msg}"
|
||||
finally:
|
||||
out.timecost = time.time() - _start_time
|
||||
return out
|
||||
@@ -837,10 +855,13 @@ def _setup_fastapi(
|
||||
try:
|
||||
await worker_manager.start()
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting worker manager: {e}")
|
||||
sys.exit(1)
|
||||
import signal
|
||||
|
||||
# It cannot be blocked here because the startup of worker_manager depends on the fastapi app (registered to the controller)
|
||||
logger.error(f"Error starting worker manager: {str(e)}")
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
# It cannot be blocked here because the startup of worker_manager depends on
|
||||
# the fastapi app (registered to the controller)
|
||||
asyncio.create_task(start_worker_manager())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
|
@@ -20,3 +20,6 @@ class ServeConfig(BaseServeConfig):
|
||||
api_keys: Optional[str] = field(
|
||||
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
|
||||
)
|
||||
load_dbgpts_interval: int = field(
|
||||
default=5, metadata={"help": "Interval to load dbgpts from installed packages"}
|
||||
)
|
||||
|
@@ -3,6 +3,7 @@ import logging
|
||||
import traceback
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
import schedule
|
||||
from fastapi import HTTPException
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
@@ -56,7 +57,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
self._dao = self._dao or ServeDao(self._serve_config)
|
||||
self._system_app = system_app
|
||||
self._dbgpts_loader = system_app.get_component(
|
||||
DBGPTsLoader.name, DBGPTsLoader, or_register_component=DBGPTsLoader
|
||||
DBGPTsLoader.name,
|
||||
DBGPTsLoader,
|
||||
or_register_component=DBGPTsLoader,
|
||||
load_dbgpts_interval=self._serve_config.load_dbgpts_interval,
|
||||
)
|
||||
|
||||
def before_start(self):
|
||||
@@ -68,7 +72,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
def after_start(self):
|
||||
"""Execute after the application starts"""
|
||||
self.load_dag_from_db()
|
||||
self.load_dag_from_dbgpts()
|
||||
self.load_dag_from_dbgpts(is_first_load=True)
|
||||
schedule.every(self._serve_config.load_dbgpts_interval).seconds.do(
|
||||
self.load_dag_from_dbgpts
|
||||
)
|
||||
|
||||
@property
|
||||
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
|
||||
@@ -126,6 +133,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
if save_failed_flow:
|
||||
request.state = State.LOAD_FAILED
|
||||
request.error_message = str(e)
|
||||
request.dag_id = ""
|
||||
return self.dao.create(request)
|
||||
else:
|
||||
raise e
|
||||
@@ -147,6 +155,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
if save_failed_flow:
|
||||
request.state = State.LOAD_FAILED
|
||||
request.error_message = f"Register DAG error: {str(e)}"
|
||||
request.dag_id = ""
|
||||
self.dao.update({"uid": request.uid}, request)
|
||||
else:
|
||||
# Rollback
|
||||
@@ -198,7 +207,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
f"dbgpts error: {str(e)}"
|
||||
)
|
||||
|
||||
def load_dag_from_dbgpts(self):
|
||||
def load_dag_from_dbgpts(self, is_first_load: bool = False):
|
||||
"""Load DAG from dbgpts"""
|
||||
flows = self.dbgpts_loader.get_flows()
|
||||
for flow in flows:
|
||||
@@ -208,7 +217,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
exist_inst = self.get({"name": flow.name})
|
||||
if not exist_inst:
|
||||
self.create_and_save_dag(flow, save_failed_flow=True)
|
||||
else:
|
||||
elif is_first_load or exist_inst.state != State.RUNNING:
|
||||
# TODO check version, must be greater than the exist one
|
||||
flow.uid = exist_inst.uid
|
||||
self.update_flow(flow, check_editable=False, save_failed_flow=True)
|
||||
@@ -242,6 +251,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
if save_failed_flow:
|
||||
request.state = State.LOAD_FAILED
|
||||
request.error_message = str(e)
|
||||
request.dag_id = ""
|
||||
return self.dao.update({"uid": request.uid}, request)
|
||||
else:
|
||||
raise e
|
||||
@@ -306,11 +316,12 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
inst = self.get(query_request)
|
||||
if inst is None:
|
||||
raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
|
||||
if not inst.dag_id:
|
||||
if inst.state == State.RUNNING and not inst.dag_id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Flow {uid}'s dag id not found"
|
||||
status_code=404, detail=f"Running flow {uid}'s dag id not found"
|
||||
)
|
||||
try:
|
||||
if inst.dag_id:
|
||||
self.dag_manager.unregister_dag(inst.dag_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Unregister DAG({inst.dag_id}) error: {str(e)}")
|
||||
|
@@ -23,6 +23,13 @@ DEFAULT_PACKAGE_TYPES = ["agent", "app", "operator", "flow"]
|
||||
INSTALL_METADATA_FILE = "install_metadata.toml"
|
||||
DBGPTS_METADATA_FILE = "dbgpts.toml"
|
||||
|
||||
TYPE_TO_PACKAGE = {
|
||||
"agent": "agents",
|
||||
"app": "apps",
|
||||
"operator": "operators",
|
||||
"flow": "workflow",
|
||||
}
|
||||
|
||||
|
||||
def _get_env_sig() -> str:
|
||||
"""Get a unique signature for the current Python environment."""
|
||||
|
@@ -1,7 +1,31 @@
|
||||
import functools
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
from .base import DEFAULT_PACKAGE_TYPES
|
||||
|
||||
|
||||
def check_poetry_installed():
|
||||
try:
|
||||
# Check if poetry is installed
|
||||
subprocess.run(
|
||||
["poetry", "--version"],
|
||||
check=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
print("Poetry is not installed. Please install Poetry to proceed.")
|
||||
print(
|
||||
"Visit https://python-poetry.org/docs/#installation for installation "
|
||||
"instructions."
|
||||
)
|
||||
# Exit with error
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def add_tap_options(func):
|
||||
@click.option(
|
||||
@@ -26,6 +50,7 @@ def install(repo: str | None, name: str):
|
||||
"""Install your dbgpts(operators,agents,workflows or apps)"""
|
||||
from .repo import install
|
||||
|
||||
check_poetry_installed()
|
||||
install(name, repo)
|
||||
|
||||
|
||||
@@ -108,3 +133,84 @@ def update_repo(repo: str | None):
|
||||
else:
|
||||
print(f"Updating repo '{p}'...")
|
||||
update_repo(p)
|
||||
|
||||
|
||||
@click.command(name="app")
|
||||
@click.option(
|
||||
"-n",
|
||||
"--name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name you want to give to the dbgpt",
|
||||
)
|
||||
@click.option(
|
||||
"-l",
|
||||
"--label",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="The label of the dbgpt",
|
||||
)
|
||||
@click.option(
|
||||
"-d",
|
||||
"--description",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="The description of the dbgpt",
|
||||
)
|
||||
@click.option(
|
||||
"-t",
|
||||
"--type",
|
||||
type=click.Choice(DEFAULT_PACKAGE_TYPES),
|
||||
default="flow",
|
||||
required=False,
|
||||
help="The type of the dbgpt",
|
||||
)
|
||||
@click.option(
|
||||
"--definition_type",
|
||||
type=click.Choice(["json", "python"]),
|
||||
default="json",
|
||||
required=False,
|
||||
help="The definition type of the dbgpt",
|
||||
)
|
||||
@click.option(
|
||||
"-C",
|
||||
"--directory",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="The working directory of the dbgpt(defaults to the current directory).",
|
||||
)
|
||||
def new_dbgpts(
|
||||
name: str,
|
||||
label: str | None,
|
||||
description: str | None,
|
||||
type: str,
|
||||
definition_type: str,
|
||||
directory: str | None,
|
||||
):
|
||||
"""New a dbgpts module structure"""
|
||||
if not label:
|
||||
# Set label to the name
|
||||
default_label = name.replace("-", " ").replace("_", " ").title()
|
||||
label = click.prompt(
|
||||
"Please input the label of the dbgpt", default=default_label
|
||||
)
|
||||
if not description:
|
||||
# Read with click
|
||||
description = click.prompt(
|
||||
"Please input the description of the dbgpt", default=""
|
||||
)
|
||||
if not directory:
|
||||
# Set directory to the current directory(abs path)
|
||||
directory = click.prompt(
|
||||
"Please input the working directory of the dbgpt",
|
||||
default=str(Path.cwd()),
|
||||
type=click.Path(exists=True, file_okay=False, dir_okay=True),
|
||||
)
|
||||
|
||||
check_poetry_installed()
|
||||
from .template import create_template
|
||||
|
||||
create_template(name, label, description, type, definition_type, directory)
|
||||
|
@@ -1,8 +1,11 @@
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import schedule
|
||||
import tomlkit
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, root_validator
|
||||
@@ -36,6 +39,7 @@ class BasePackage(BaseModel):
|
||||
definition_file: Optional[str] = Field(
|
||||
default=None, description="The definition " "file of the package"
|
||||
)
|
||||
root: str = Field(..., description="The root of the package")
|
||||
repo: str = Field(..., description="The repository of the package")
|
||||
|
||||
@classmethod
|
||||
@@ -48,8 +52,13 @@ class BasePackage(BaseModel):
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
name = values.get("name")
|
||||
root = values.get("root")
|
||||
if not name:
|
||||
raise ValueError("The name is required")
|
||||
if not root:
|
||||
raise ValueError("The root is required")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
# Read the file
|
||||
values["path"] = os.path.dirname(os.path.abspath(path))
|
||||
@@ -91,6 +100,32 @@ class FlowJsonPackage(FlowPackage):
|
||||
class OperatorPackage(BasePackage):
|
||||
package_type = "operator"
|
||||
|
||||
operators: List[type] = Field(
|
||||
default_factory=list, description="The operators of the package"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
from dbgpt.core.awel import BaseOperator
|
||||
from dbgpt.core.awel.dag.loader import _load_modules_from_file
|
||||
|
||||
name = values.get("name")
|
||||
root = values.get("root")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
mods = _load_modules_from_file(str(path), name, show_log=False)
|
||||
all_cls = [_get_classes_from_module(m) for m in mods]
|
||||
operators = []
|
||||
for list_cls in all_cls:
|
||||
for c in list_cls:
|
||||
if issubclass(c, BaseOperator):
|
||||
operators.append(c)
|
||||
values["operators"] = operators
|
||||
return cls(**values)
|
||||
|
||||
|
||||
class InstalledPackage(BaseModel):
|
||||
name: str = Field(..., description="The name of the package")
|
||||
@@ -98,6 +133,15 @@ class InstalledPackage(BaseModel):
|
||||
root: str = Field(..., description="The root of the package")
|
||||
|
||||
|
||||
def _get_classes_from_module(module):
|
||||
classes = [
|
||||
obj
|
||||
for name, obj in inspect.getmembers(module, inspect.isclass)
|
||||
if obj.__module__ == module.__name__
|
||||
]
|
||||
return classes
|
||||
|
||||
|
||||
def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
|
||||
with open(
|
||||
Path(package.root) / DBGPTS_METADATA_FILE, mode="r+", encoding="utf-8"
|
||||
@@ -109,11 +153,17 @@ def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
|
||||
if key == "flow":
|
||||
pkg_dict = value
|
||||
pkg_dict["package_type"] = "flow"
|
||||
elif key == "operator":
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "operator"
|
||||
else:
|
||||
ext_metadata[key] = value
|
||||
pkg_dict["root"] = package.root
|
||||
pkg_dict["repo"] = package.repo
|
||||
if pkg_dict["package_type"] == "flow":
|
||||
return FlowPackage.build_from(pkg_dict, ext_metadata)
|
||||
elif pkg_dict["package_type"] == "operator":
|
||||
return OperatorPackage.build_from(pkg_dict, ext_metadata)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported package package_type: {pkg_dict['package_type']}"
|
||||
@@ -156,12 +206,16 @@ class DBGPTsLoader(BaseComponent):
|
||||
name = "dbgpt_dbgpts_loader"
|
||||
|
||||
def __init__(
|
||||
self, system_app: Optional[SystemApp] = None, install_dir: Optional[str] = None
|
||||
self,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
install_dir: Optional[str] = None,
|
||||
load_dbgpts_interval: int = 10,
|
||||
):
|
||||
"""Initialize the DBGPTsLoader."""
|
||||
self._system_app = None
|
||||
self._install_dir = install_dir or INSTALL_DIR
|
||||
self._packages: Dict[str, BasePackage] = {}
|
||||
self._load_dbgpts_interval = load_dbgpts_interval
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
@@ -170,12 +224,15 @@ class DBGPTsLoader(BaseComponent):
|
||||
|
||||
def before_start(self):
|
||||
"""Execute after the application starts."""
|
||||
self.load_package()
|
||||
self.load_package(is_first=True)
|
||||
|
||||
def load_package(self) -> None:
|
||||
schedule.every(self._load_dbgpts_interval).seconds.do(self.load_package)
|
||||
|
||||
def load_package(self, is_first: bool = False) -> None:
|
||||
"""Load the package by name."""
|
||||
try:
|
||||
packages = _load_package_from_path(self._install_dir)
|
||||
if is_first:
|
||||
logger.info(
|
||||
f"Found {len(packages)} dbgpts packages from {self._install_dir}"
|
||||
)
|
||||
|
@@ -214,6 +214,7 @@ def _copy_and_install(repo: str, name: str, package_path: Path):
|
||||
err=True,
|
||||
)
|
||||
return
|
||||
try:
|
||||
shutil.copytree(package_path, install_path)
|
||||
logger.info(f"Installing dbgpts '{name}' from {repo}...")
|
||||
os.chdir(install_path)
|
||||
@@ -221,6 +222,10 @@ def _copy_and_install(repo: str, name: str, package_path: Path):
|
||||
_write_install_metadata(name, repo, install_path)
|
||||
click.echo(f"Installed dbgpts at {_print_path(install_path)}.")
|
||||
click.echo(f"dbgpts '{name}' installed successfully.")
|
||||
except Exception as e:
|
||||
if install_path.exists():
|
||||
shutil.rmtree(install_path)
|
||||
raise e
|
||||
|
||||
|
||||
def _write_install_metadata(name: str, repo: str, install_path: Path):
|
||||
|
209
dbgpt/util/dbgpts/template.py
Normal file
209
dbgpt/util/dbgpts/template.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
from .base import DBGPTS_METADATA_FILE, TYPE_TO_PACKAGE
|
||||
|
||||
|
||||
def create_template(
|
||||
name: str,
|
||||
name_label: str,
|
||||
description: str,
|
||||
dbgpts_type: str,
|
||||
definition_type: str,
|
||||
working_directory: str,
|
||||
):
|
||||
"""Create a new flow dbgpt"""
|
||||
if dbgpts_type != "flow":
|
||||
definition_type = "python"
|
||||
mod_name = name.replace("-", "_")
|
||||
base_metadata = {
|
||||
"label": name_label,
|
||||
"name": mod_name,
|
||||
"version": "0.1.0",
|
||||
"description": description,
|
||||
"authors": [],
|
||||
"definition_type": definition_type,
|
||||
}
|
||||
working_directory = os.path.join(working_directory, TYPE_TO_PACKAGE[dbgpts_type])
|
||||
package_dir = Path(working_directory) / name
|
||||
if os.path.exists(package_dir):
|
||||
raise click.ClickException(f"Package '{str(package_dir)}' already exists")
|
||||
|
||||
if dbgpts_type == "flow":
|
||||
_create_flow_template(
|
||||
name,
|
||||
mod_name,
|
||||
dbgpts_type,
|
||||
base_metadata,
|
||||
definition_type,
|
||||
working_directory,
|
||||
)
|
||||
elif dbgpts_type == "operator":
|
||||
_create_operator_template(
|
||||
name,
|
||||
mod_name,
|
||||
dbgpts_type,
|
||||
base_metadata,
|
||||
definition_type,
|
||||
working_directory,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid dbgpts type: {dbgpts_type}")
|
||||
|
||||
|
||||
def _create_flow_template(
|
||||
name: str,
|
||||
mod_name: str,
|
||||
dbgpts_type: str,
|
||||
base_metadata: dict,
|
||||
definition_type: str,
|
||||
working_directory: str,
|
||||
):
|
||||
"""Create a new flow dbgpt"""
|
||||
|
||||
json_dict = {
|
||||
"flow": base_metadata,
|
||||
"python_config": {},
|
||||
"json_config": {},
|
||||
}
|
||||
if definition_type == "json":
|
||||
json_dict["json_config"] = {"file_path": "definition/flow_definition.json"}
|
||||
|
||||
_create_poetry_project(working_directory, name)
|
||||
_write_dbgpts_toml(working_directory, name, json_dict)
|
||||
_write_manifest_file(working_directory, name, mod_name)
|
||||
|
||||
if definition_type == "json":
|
||||
_write_flow_define_json_file(working_directory, name, mod_name)
|
||||
else:
|
||||
raise click.ClickException(
|
||||
f"Unsupported definition type: {definition_type} for dbgpts type: {dbgpts_type}"
|
||||
)
|
||||
|
||||
|
||||
def _create_operator_template(
|
||||
name: str,
|
||||
mod_name: str,
|
||||
dbgpts_type: str,
|
||||
base_metadata: dict,
|
||||
definition_type: str,
|
||||
working_directory: str,
|
||||
):
|
||||
"""Create a new operator dbgpt"""
|
||||
|
||||
json_dict = {
|
||||
"operator": base_metadata,
|
||||
"python_config": {},
|
||||
"json_config": {},
|
||||
}
|
||||
if definition_type != "python":
|
||||
raise click.ClickException(
|
||||
f"Unsupported definition type: {definition_type} for dbgpts type: "
|
||||
f"{dbgpts_type}"
|
||||
)
|
||||
|
||||
_create_poetry_project(working_directory, name)
|
||||
_write_dbgpts_toml(working_directory, name, json_dict)
|
||||
_write_operator_init_file(working_directory, name, mod_name)
|
||||
_write_manifest_file(working_directory, name, mod_name)
|
||||
|
||||
|
||||
def _create_poetry_project(working_directory: str, name: str):
|
||||
"""Create a new poetry project"""
|
||||
|
||||
os.chdir(working_directory)
|
||||
subprocess.run(["poetry", "new", name, "-n"], check=True)
|
||||
|
||||
|
||||
def _write_dbgpts_toml(working_directory: str, name: str, json_data: dict):
|
||||
"""Write the dbgpts.toml file"""
|
||||
|
||||
import tomlkit
|
||||
|
||||
with open(Path(working_directory) / name / DBGPTS_METADATA_FILE, "w") as f:
|
||||
tomlkit.dump(json_data, f)
|
||||
|
||||
|
||||
def _write_manifest_file(working_directory: str, name: str, mod_name: str):
|
||||
"""Write the manifest file"""
|
||||
|
||||
manifest = f"""include dbgpts.toml
|
||||
include {mod_name}/definition/*.json
|
||||
"""
|
||||
with open(Path(working_directory) / name / "MANIFEST.in", "w") as f:
|
||||
f.write(manifest)
|
||||
|
||||
|
||||
def _write_flow_define_json_file(working_directory: str, name: str, mod_name: str):
|
||||
"""Write the flow define json file"""
|
||||
|
||||
def_file = (
|
||||
Path(working_directory)
|
||||
/ name
|
||||
/ mod_name
|
||||
/ "definition"
|
||||
/ "flow_definition.json"
|
||||
)
|
||||
if not def_file.parent.exists():
|
||||
def_file.parent.mkdir(parents=True)
|
||||
with open(def_file, "w") as f:
|
||||
f.write("")
|
||||
print("Please write your flow json to the file: ", def_file)
|
||||
|
||||
|
||||
def _write_operator_init_file(working_directory: str, name: str, mod_name: str):
|
||||
"""Write the operator __init__.py file"""
|
||||
|
||||
init_file = Path(working_directory) / name / mod_name / "__init__.py"
|
||||
content = """
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.flow import ViewMetadata, OperatorCategory, IOField, Parameter
|
||||
|
||||
|
||||
class HelloWorldOperator(MapOperator[str, str]):
|
||||
# The metadata for AWEL flow
|
||||
metadata = ViewMetadata(
|
||||
label="Hello World Operator",
|
||||
name="hello_world_operator",
|
||||
category=OperatorCategory.COMMON,
|
||||
description="A example operator to say hello to someone.",
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
"Name",
|
||||
"name",
|
||||
str,
|
||||
optional=True,
|
||||
default="World",
|
||||
description="The name to say hello",
|
||||
)
|
||||
],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"Input value",
|
||||
"value",
|
||||
str,
|
||||
description="The input value to say hello",
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Output value",
|
||||
"value",
|
||||
str,
|
||||
description="The output value after saying hello",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(self, name: str = "World", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = name
|
||||
|
||||
async def map(self, value: str) -> str:
|
||||
return f"Hello, {self.name}! {value}"
|
||||
"""
|
||||
with open(init_file, "w") as f:
|
||||
f.write(f'"""{name} operator package"""\n{content}')
|
Reference in New Issue
Block a user