feat(core): Dynamically loading dbgpts (#1211)

This commit is contained in:
Fangyin Cheng
2024-02-29 15:57:49 +08:00
committed by GitHub
parent 673ddaab5b
commit 1d90711952
15 changed files with 504 additions and 33 deletions

View File

@@ -21,6 +21,7 @@ def initialize_components(
): ):
# Lazy import to avoid high time cost # Lazy import to avoid high time cost
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
from dbgpt.app.initialization.scheduler import DefaultScheduler
from dbgpt.app.initialization.serve_initialization import register_serve_apps from dbgpt.app.initialization.serve_initialization import register_serve_apps
from dbgpt.model.cluster.controller.controller import controller from dbgpt.model.cluster.controller.controller import controller
@@ -28,6 +29,7 @@ def initialize_components(
system_app.register( system_app.register(
DefaultExecutorFactory, max_workers=param.default_thread_pool_size DefaultExecutorFactory, max_workers=param.default_thread_pool_size
) )
system_app.register(DefaultScheduler)
system_app.register_instance(controller) system_app.register_instance(controller)
from dbgpt.serve.agent.hub.controller import module_plugin from dbgpt.serve.agent.hub.controller import module_plugin

View File

@@ -0,0 +1,43 @@
import logging
import threading
import time
import schedule
from dbgpt.component import BaseComponent, SystemApp
logger = logging.getLogger(__name__)
class DefaultScheduler(BaseComponent):
    """The default scheduler.

    Runs pending ``schedule`` jobs in a background thread, polling every
    ``scheduler_interval_ms`` milliseconds after waiting an initial
    ``scheduler_delay_ms`` milliseconds.
    """

    name = "dbgpt_default_scheduler"

    def __init__(
        self,
        system_app: SystemApp,
        scheduler_delay_ms: int = 5000,
        scheduler_interval_ms: int = 1000,
    ):
        super().__init__(system_app)
        self.system_app = system_app
        self._scheduler_interval_ms = scheduler_interval_ms
        self._scheduler_delay_ms = scheduler_delay_ms

    def init_app(self, system_app: SystemApp):
        """Bind the scheduler to the application instance."""
        self.system_app = system_app

    def after_start(self):
        """Start the scheduler loop once the application has started."""
        # Daemon thread: the endless polling loop below must not keep the
        # process alive after the main application exits.
        thread = threading.Thread(
            target=self._scheduler, name=self.name, daemon=True
        )
        thread.start()

    def _scheduler(self):
        """Poll and run pending scheduled jobs forever."""
        time.sleep(self._scheduler_delay_ms / 1000)
        while True:
            try:
                schedule.run_pending()
            except Exception as e:
                # Kept at debug level so a repeatedly failing job does not
                # spam the log on every poll interval.
                logger.debug(f"Scheduler error: {e}")
            finally:
                time.sleep(self._scheduler_interval_ms / 1000)

View File

@@ -163,7 +163,7 @@ try:
from dbgpt.util.dbgpts.cli import add_repo from dbgpt.util.dbgpts.cli import add_repo
from dbgpt.util.dbgpts.cli import install as app_install from dbgpt.util.dbgpts.cli import install as app_install
from dbgpt.util.dbgpts.cli import list_all_apps as app_list from dbgpt.util.dbgpts.cli import list_all_apps as app_list
from dbgpt.util.dbgpts.cli import list_repos, remove_repo from dbgpt.util.dbgpts.cli import list_repos, new_dbgpts, remove_repo
from dbgpt.util.dbgpts.cli import uninstall as app_uninstall from dbgpt.util.dbgpts.cli import uninstall as app_uninstall
from dbgpt.util.dbgpts.cli import update_repo from dbgpt.util.dbgpts.cli import update_repo
@@ -174,6 +174,7 @@ try:
add_command_alias(app_install, name="install", parent_group=app) add_command_alias(app_install, name="install", parent_group=app)
add_command_alias(app_uninstall, name="uninstall", parent_group=app) add_command_alias(app_uninstall, name="uninstall", parent_group=app)
add_command_alias(app_list, name="list-remote", parent_group=app) add_command_alias(app_list, name="list-remote", parent_group=app)
add_command_alias(new_dbgpts, name="app", parent_group=new)
except ImportError as e: except ImportError as e:
logging.warning(f"Integrating dbgpt dbgpts command line tool failed: {e}") logging.warning(f"Integrating dbgpt dbgpts command line tool failed: {e}")

View File

@@ -63,19 +63,23 @@ def _process_file(filepath) -> List[DAG]:
return results return results
def _load_modules_from_file(filepath: str): def _load_modules_from_file(
filepath: str, mod_name: str | None = None, show_log: bool = True
):
import importlib import importlib
import importlib.machinery import importlib.machinery
import importlib.util import importlib.util
logger.info(f"Importing {filepath}") if show_log:
logger.info(f"Importing {filepath}")
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1]) org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest() path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}" if mod_name is None:
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
if mod_name in sys.modules: if mod_name in sys.modules:
del sys.modules[mod_name] del sys.modules[mod_name]
def parse(mod_name, filepath): def parse(mod_name, filepath):
try: try:

View File

@@ -1046,7 +1046,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
"key", "key",
str, str,
optional=True, optional=True,
default="", default="messages",
description="The key of the dict, link 'user_input'", description="The key of the dict, link 'user_input'",
) )
], ],

View File

@@ -59,7 +59,7 @@ class WorkerApplyOutput:
return WorkerApplyOutput("Not outputs") return WorkerApplyOutput("Not outputs")
combined_success = all(out.success for out in outs) combined_success = all(out.success for out in outs)
max_timecost = max(out.timecost for out in outs) max_timecost = max(out.timecost for out in outs)
combined_message = ", ".join(out.message for out in outs) combined_message = "\n;".join(out.message for out in outs)
return WorkerApplyOutput(combined_message, combined_success, max_timecost) return WorkerApplyOutput(combined_message, combined_success, max_timecost)

View File

@@ -6,6 +6,7 @@ import os
import random import random
import sys import sys
import time import time
import traceback
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict from dataclasses import asdict
from typing import Awaitable, Callable, Iterator from typing import Awaitable, Callable, Iterator
@@ -490,6 +491,8 @@ class LocalWorkerManager(WorkerManager):
async def _start_all_worker( async def _start_all_worker(
self, apply_req: WorkerApplyRequest self, apply_req: WorkerApplyRequest
) -> WorkerApplyOutput: ) -> WorkerApplyOutput:
from httpx import TimeoutException, TransportError
# TODO avoid start twice # TODO avoid start twice
start_time = time.time() start_time = time.time()
logger.info(f"Begin start all worker, apply_req: {apply_req}") logger.info(f"Begin start all worker, apply_req: {apply_req}")
@@ -520,9 +523,24 @@ class LocalWorkerManager(WorkerManager):
) )
) )
out.message = f"{info} start successfully" out.message = f"{info} start successfully"
except Exception as e: except TimeoutException as e:
out.success = False out.success = False
out.message = f"{info} start failed, {str(e)}" out.message = (
f"{info} start failed for network timeout, please make "
f"sure your port is available, if you are using global network "
f"proxy, please close it"
)
except TransportError as e:
out.success = False
out.message = (
f"{info} start failed for network error, please make "
f"sure your port is available, if you are using global network "
"proxy, please close it"
)
except Exception:
err_msg = traceback.format_exc()
out.success = False
out.message = f"{info} start failed, {err_msg}"
finally: finally:
out.timecost = time.time() - _start_time out.timecost = time.time() - _start_time
return out return out
@@ -837,10 +855,13 @@ def _setup_fastapi(
try: try:
await worker_manager.start() await worker_manager.start()
except Exception as e: except Exception as e:
logger.error(f"Error starting worker manager: {e}") import signal
sys.exit(1)
# It cannot be blocked here because the startup of worker_manager depends on the fastapi app (registered to the controller) logger.error(f"Error starting worker manager: {str(e)}")
os.kill(os.getpid(), signal.SIGINT)
# It cannot be blocked here because the startup of worker_manager depends on
# the fastapi app (registered to the controller)
asyncio.create_task(start_worker_manager()) asyncio.create_task(start_worker_manager())
@app.on_event("shutdown") @app.on_event("shutdown")

View File

@@ -20,3 +20,6 @@ class ServeConfig(BaseServeConfig):
api_keys: Optional[str] = field( api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
) )
load_dbgpts_interval: int = field(
default=5, metadata={"help": "Interval to load dbgpts from installed packages"}
)

View File

@@ -3,6 +3,7 @@ import logging
import traceback import traceback
from typing import Any, List, Optional, cast from typing import Any, List, Optional, cast
import schedule
from fastapi import HTTPException from fastapi import HTTPException
from dbgpt.component import SystemApp from dbgpt.component import SystemApp
@@ -56,7 +57,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
self._dao = self._dao or ServeDao(self._serve_config) self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app self._system_app = system_app
self._dbgpts_loader = system_app.get_component( self._dbgpts_loader = system_app.get_component(
DBGPTsLoader.name, DBGPTsLoader, or_register_component=DBGPTsLoader DBGPTsLoader.name,
DBGPTsLoader,
or_register_component=DBGPTsLoader,
load_dbgpts_interval=self._serve_config.load_dbgpts_interval,
) )
def before_start(self): def before_start(self):
@@ -68,7 +72,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
def after_start(self): def after_start(self):
"""Execute after the application starts""" """Execute after the application starts"""
self.load_dag_from_db() self.load_dag_from_db()
self.load_dag_from_dbgpts() self.load_dag_from_dbgpts(is_first_load=True)
schedule.every(self._serve_config.load_dbgpts_interval).seconds.do(
self.load_dag_from_dbgpts
)
@property @property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
@@ -126,6 +133,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
if save_failed_flow: if save_failed_flow:
request.state = State.LOAD_FAILED request.state = State.LOAD_FAILED
request.error_message = str(e) request.error_message = str(e)
request.dag_id = ""
return self.dao.create(request) return self.dao.create(request)
else: else:
raise e raise e
@@ -147,6 +155,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
if save_failed_flow: if save_failed_flow:
request.state = State.LOAD_FAILED request.state = State.LOAD_FAILED
request.error_message = f"Register DAG error: {str(e)}" request.error_message = f"Register DAG error: {str(e)}"
request.dag_id = ""
self.dao.update({"uid": request.uid}, request) self.dao.update({"uid": request.uid}, request)
else: else:
# Rollback # Rollback
@@ -198,7 +207,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
f"dbgpts error: {str(e)}" f"dbgpts error: {str(e)}"
) )
def load_dag_from_dbgpts(self): def load_dag_from_dbgpts(self, is_first_load: bool = False):
"""Load DAG from dbgpts""" """Load DAG from dbgpts"""
flows = self.dbgpts_loader.get_flows() flows = self.dbgpts_loader.get_flows()
for flow in flows: for flow in flows:
@@ -208,7 +217,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
exist_inst = self.get({"name": flow.name}) exist_inst = self.get({"name": flow.name})
if not exist_inst: if not exist_inst:
self.create_and_save_dag(flow, save_failed_flow=True) self.create_and_save_dag(flow, save_failed_flow=True)
else: elif is_first_load or exist_inst.state != State.RUNNING:
# TODO check version, must be greater than the exist one # TODO check version, must be greater than the exist one
flow.uid = exist_inst.uid flow.uid = exist_inst.uid
self.update_flow(flow, check_editable=False, save_failed_flow=True) self.update_flow(flow, check_editable=False, save_failed_flow=True)
@@ -242,6 +251,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
if save_failed_flow: if save_failed_flow:
request.state = State.LOAD_FAILED request.state = State.LOAD_FAILED
request.error_message = str(e) request.error_message = str(e)
request.dag_id = ""
return self.dao.update({"uid": request.uid}, request) return self.dao.update({"uid": request.uid}, request)
else: else:
raise e raise e
@@ -306,12 +316,13 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
inst = self.get(query_request) inst = self.get(query_request)
if inst is None: if inst is None:
raise HTTPException(status_code=404, detail=f"Flow {uid} not found") raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
if not inst.dag_id: if inst.state == State.RUNNING and not inst.dag_id:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Flow {uid}'s dag id not found" status_code=404, detail=f"Running flow {uid}'s dag id not found"
) )
try: try:
self.dag_manager.unregister_dag(inst.dag_id) if inst.dag_id:
self.dag_manager.unregister_dag(inst.dag_id)
except Exception as e: except Exception as e:
logger.warning(f"Unregister DAG({inst.dag_id}) error: {str(e)}") logger.warning(f"Unregister DAG({inst.dag_id}) error: {str(e)}")
self.dao.delete(query_request) self.dao.delete(query_request)

View File

@@ -23,6 +23,13 @@ DEFAULT_PACKAGE_TYPES = ["agent", "app", "operator", "flow"]
INSTALL_METADATA_FILE = "install_metadata.toml" INSTALL_METADATA_FILE = "install_metadata.toml"
DBGPTS_METADATA_FILE = "dbgpts.toml" DBGPTS_METADATA_FILE = "dbgpts.toml"
# Maps a dbgpts package type (one of DEFAULT_PACKAGE_TYPES) to the name of
# the sub-directory that holds packages of that type.
TYPE_TO_PACKAGE = {
    "agent": "agents",
    "app": "apps",
    "operator": "operators",
    "flow": "workflow",
}
def _get_env_sig() -> str: def _get_env_sig() -> str:
"""Get a unique signature for the current Python environment.""" """Get a unique signature for the current Python environment."""

View File

@@ -1,7 +1,31 @@
import functools import functools
import subprocess
import sys
from pathlib import Path
import click import click
from .base import DEFAULT_PACKAGE_TYPES
def check_poetry_installed():
    """Exit the process (status 1) if the ``poetry`` CLI is not available."""
    try:
        # Probe the poetry executable; discard all of its output.
        subprocess.run(
            ["poetry", "--version"],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        return
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Either poetry is missing or the probe failed: tell the user how to
        # install it, then abort.
        print("Poetry is not installed. Please install Poetry to proceed.")
        print(
            "Visit https://python-poetry.org/docs/#installation for installation "
            "instructions."
        )
        # Exit with error
        sys.exit(1)
def add_tap_options(func): def add_tap_options(func):
@click.option( @click.option(
@@ -26,6 +50,7 @@ def install(repo: str | None, name: str):
"""Install your dbgpts(operators,agents,workflows or apps)""" """Install your dbgpts(operators,agents,workflows or apps)"""
from .repo import install from .repo import install
check_poetry_installed()
install(name, repo) install(name, repo)
@@ -108,3 +133,84 @@ def update_repo(repo: str | None):
else: else:
print(f"Updating repo '{p}'...") print(f"Updating repo '{p}'...")
update_repo(p) update_repo(p)
@click.command(name="app")
@click.option(
    "-n",
    "--name",
    type=str,
    required=True,
    help="The name you want to give to the dbgpt",
)
@click.option(
    "-l",
    "--label",
    type=str,
    default=None,
    required=False,
    help="The label of the dbgpt",
)
@click.option(
    "-d",
    "--description",
    type=str,
    default=None,
    required=False,
    help="The description of the dbgpt",
)
@click.option(
    "-t",
    "--type",
    type=click.Choice(DEFAULT_PACKAGE_TYPES),
    default="flow",
    required=False,
    help="The type of the dbgpt",
)
@click.option(
    "--definition_type",
    type=click.Choice(["json", "python"]),
    default="json",
    required=False,
    help="The definition type of the dbgpt",
)
@click.option(
    "-C",
    "--directory",
    type=str,
    default=None,
    required=False,
    help="The working directory of the dbgpt(defaults to the current directory).",
)
def new_dbgpts(
    name: str,
    label: str | None,
    description: str | None,
    type: str,
    definition_type: str,
    directory: str | None,
):
    """New a dbgpts module structure"""
    # Prompt interactively for anything not supplied on the command line.
    label = label or click.prompt(
        "Please input the label of the dbgpt",
        # Default label: "my-pkg_x" -> "My Pkg X"
        default=name.replace("-", " ").replace("_", " ").title(),
    )
    description = description or click.prompt(
        "Please input the description of the dbgpt", default=""
    )
    directory = directory or click.prompt(
        "Please input the working directory of the dbgpt",
        default=str(Path.cwd()),
        type=click.Path(exists=True, file_okay=False, dir_okay=True),
    )

    # Scaffolding relies on `poetry new`, so bail out early if it is missing.
    check_poetry_installed()
    from .template import create_template

    create_template(name, label, description, type, definition_type, directory)

View File

@@ -1,8 +1,11 @@
import inspect
import logging import logging
import os import os
import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, cast from typing import Any, Dict, List, Optional, cast
import schedule
import tomlkit import tomlkit
from dbgpt._private.pydantic import BaseModel, Field, root_validator from dbgpt._private.pydantic import BaseModel, Field, root_validator
@@ -36,6 +39,7 @@ class BasePackage(BaseModel):
definition_file: Optional[str] = Field( definition_file: Optional[str] = Field(
default=None, description="The definition " "file of the package" default=None, description="The definition " "file of the package"
) )
root: str = Field(..., description="The root of the package")
repo: str = Field(..., description="The repository of the package") repo: str = Field(..., description="The repository of the package")
@classmethod @classmethod
@@ -48,8 +52,13 @@ class BasePackage(BaseModel):
import importlib.resources as pkg_resources import importlib.resources as pkg_resources
name = values.get("name") name = values.get("name")
root = values.get("root")
if not name: if not name:
raise ValueError("The name is required") raise ValueError("The name is required")
if not root:
raise ValueError("The root is required")
if root not in sys.path:
sys.path.append(root)
with pkg_resources.path(name, "__init__.py") as path: with pkg_resources.path(name, "__init__.py") as path:
# Read the file # Read the file
values["path"] = os.path.dirname(os.path.abspath(path)) values["path"] = os.path.dirname(os.path.abspath(path))
@@ -91,6 +100,32 @@ class FlowJsonPackage(FlowPackage):
class OperatorPackage(BasePackage): class OperatorPackage(BasePackage):
package_type = "operator" package_type = "operator"
operators: List[type] = Field(
default_factory=list, description="The operators of the package"
)
@classmethod
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
    """Build an OperatorPackage, collecting its operator classes.

    Imports the installed package's ``__init__.py`` and records every class
    defined in it that subclasses ``BaseOperator`` into
    ``values["operators"]`` before constructing the package model.
    """
    import importlib.resources as pkg_resources

    from dbgpt.core.awel import BaseOperator
    from dbgpt.core.awel.dag.loader import _load_modules_from_file

    name = values.get("name")
    root = values.get("root")
    # The package root must be on sys.path so the module import can resolve.
    if root not in sys.path:
        sys.path.append(root)
    with pkg_resources.path(name, "__init__.py") as path:
        # show_log=False keeps repeated (scheduled) loads quiet.
        mods = _load_modules_from_file(str(path), name, show_log=False)
        all_cls = [_get_classes_from_module(m) for m in mods]
        operators = []
        for list_cls in all_cls:
            for c in list_cls:
                # Keep only AWEL operator classes defined in the module.
                if issubclass(c, BaseOperator):
                    operators.append(c)
        values["operators"] = operators
    return cls(**values)
class InstalledPackage(BaseModel): class InstalledPackage(BaseModel):
name: str = Field(..., description="The name of the package") name: str = Field(..., description="The name of the package")
@@ -98,6 +133,15 @@ class InstalledPackage(BaseModel):
root: str = Field(..., description="The root of the package") root: str = Field(..., description="The root of the package")
def _get_classes_from_module(module):
    """Return the classes defined directly in *module*.

    Classes that were merely imported into the module (their ``__module__``
    points elsewhere) are excluded.
    """
    result = []
    for _member_name, member in inspect.getmembers(module, inspect.isclass):
        if member.__module__ == module.__name__:
            result.append(member)
    return result
def _parse_package_metadata(package: InstalledPackage) -> BasePackage: def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
with open( with open(
Path(package.root) / DBGPTS_METADATA_FILE, mode="r+", encoding="utf-8" Path(package.root) / DBGPTS_METADATA_FILE, mode="r+", encoding="utf-8"
@@ -109,11 +153,17 @@ def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
if key == "flow": if key == "flow":
pkg_dict = value pkg_dict = value
pkg_dict["package_type"] = "flow" pkg_dict["package_type"] = "flow"
elif key == "operator":
pkg_dict = {k: v for k, v in value.items()}
pkg_dict["package_type"] = "operator"
else: else:
ext_metadata[key] = value ext_metadata[key] = value
pkg_dict["root"] = package.root
pkg_dict["repo"] = package.repo pkg_dict["repo"] = package.repo
if pkg_dict["package_type"] == "flow": if pkg_dict["package_type"] == "flow":
return FlowPackage.build_from(pkg_dict, ext_metadata) return FlowPackage.build_from(pkg_dict, ext_metadata)
elif pkg_dict["package_type"] == "operator":
return OperatorPackage.build_from(pkg_dict, ext_metadata)
else: else:
raise ValueError( raise ValueError(
f"Unsupported package package_type: {pkg_dict['package_type']}" f"Unsupported package package_type: {pkg_dict['package_type']}"
@@ -156,12 +206,16 @@ class DBGPTsLoader(BaseComponent):
name = "dbgpt_dbgpts_loader" name = "dbgpt_dbgpts_loader"
def __init__( def __init__(
self, system_app: Optional[SystemApp] = None, install_dir: Optional[str] = None self,
system_app: Optional[SystemApp] = None,
install_dir: Optional[str] = None,
load_dbgpts_interval: int = 10,
): ):
"""Initialize the DBGPTsLoader.""" """Initialize the DBGPTsLoader."""
self._system_app = None self._system_app = None
self._install_dir = install_dir or INSTALL_DIR self._install_dir = install_dir or INSTALL_DIR
self._packages: Dict[str, BasePackage] = {} self._packages: Dict[str, BasePackage] = {}
self._load_dbgpts_interval = load_dbgpts_interval
super().__init__(system_app) super().__init__(system_app)
def init_app(self, system_app: SystemApp): def init_app(self, system_app: SystemApp):
@@ -170,15 +224,18 @@ class DBGPTsLoader(BaseComponent):
def before_start(self): def before_start(self):
"""Execute after the application starts.""" """Execute after the application starts."""
self.load_package() self.load_package(is_first=True)
def load_package(self) -> None: schedule.every(self._load_dbgpts_interval).seconds.do(self.load_package)
def load_package(self, is_first: bool = False) -> None:
"""Load the package by name.""" """Load the package by name."""
try: try:
packages = _load_package_from_path(self._install_dir) packages = _load_package_from_path(self._install_dir)
logger.info( if is_first:
f"Found {len(packages)} dbgpts packages from {self._install_dir}" logger.info(
) f"Found {len(packages)} dbgpts packages from {self._install_dir}"
)
for package in packages: for package in packages:
self._packages[package.name] = package self._packages[package.name] = package
except Exception as e: except Exception as e:

View File

@@ -214,13 +214,18 @@ def _copy_and_install(repo: str, name: str, package_path: Path):
err=True, err=True,
) )
return return
shutil.copytree(package_path, install_path) try:
logger.info(f"Installing dbgpts '{name}' from {repo}...") shutil.copytree(package_path, install_path)
os.chdir(install_path) logger.info(f"Installing dbgpts '{name}' from {repo}...")
subprocess.run(["poetry", "install"], check=True) os.chdir(install_path)
_write_install_metadata(name, repo, install_path) subprocess.run(["poetry", "install"], check=True)
click.echo(f"Installed dbgpts at {_print_path(install_path)}.") _write_install_metadata(name, repo, install_path)
click.echo(f"dbgpts '{name}' installed successfully.") click.echo(f"Installed dbgpts at {_print_path(install_path)}.")
click.echo(f"dbgpts '{name}' installed successfully.")
except Exception as e:
if install_path.exists():
shutil.rmtree(install_path)
raise e
def _write_install_metadata(name: str, repo: str, install_path: Path): def _write_install_metadata(name: str, repo: str, install_path: Path):

View File

@@ -0,0 +1,209 @@
import os
import subprocess
from pathlib import Path
import click
from .base import DBGPTS_METADATA_FILE, TYPE_TO_PACKAGE
def create_template(
    name: str,
    name_label: str,
    description: str,
    dbgpts_type: str,
    definition_type: str,
    working_directory: str,
):
    """Create a new dbgpts package from a template.

    Args:
        name: Package name (may contain ``-``; the module name uses ``_``).
        name_label: Human-readable label stored in the metadata.
        description: Short description stored in the metadata.
        dbgpts_type: Package type; only "flow" and "operator" are supported.
        definition_type: "json" or "python" (forced to "python" for
            non-flow types).
        working_directory: Base directory; the per-type sub-directory from
            TYPE_TO_PACKAGE is appended to it.

    Raises:
        ValueError: If ``dbgpts_type`` is not supported.
        click.ClickException: If the target package directory already exists.
    """
    if dbgpts_type not in ("flow", "operator"):
        # Fail before touching the filesystem for unsupported types
        # (previously this was only detected after path computation).
        raise ValueError(f"Invalid dbgpts type: {dbgpts_type}")
    if dbgpts_type != "flow":
        definition_type = "python"
    mod_name = name.replace("-", "_")
    base_metadata = {
        "label": name_label,
        "name": mod_name,
        "version": "0.1.0",
        "description": description,
        "authors": [],
        "definition_type": definition_type,
    }
    working_directory = os.path.join(working_directory, TYPE_TO_PACKAGE[dbgpts_type])
    package_dir = Path(working_directory) / name
    if os.path.exists(package_dir):
        raise click.ClickException(f"Package '{str(package_dir)}' already exists")
    # The per-type sub-directory may not exist yet; `poetry new` is run
    # inside it, so create it up-front.
    os.makedirs(working_directory, exist_ok=True)
    if dbgpts_type == "flow":
        _create_flow_template(
            name,
            mod_name,
            dbgpts_type,
            base_metadata,
            definition_type,
            working_directory,
        )
    else:
        _create_operator_template(
            name,
            mod_name,
            dbgpts_type,
            base_metadata,
            definition_type,
            working_directory,
        )
def _create_flow_template(
    name: str,
    mod_name: str,
    dbgpts_type: str,
    base_metadata: dict,
    definition_type: str,
    working_directory: str,
):
    """Create a new flow dbgpt package from the template.

    Raises:
        click.ClickException: If ``definition_type`` is not supported. Raised
            before any scaffolding so no partial package is left on disk
            (previously the poetry project, toml and manifest were created
            before the error was raised).
    """
    if definition_type != "json":
        raise click.ClickException(
            f"Unsupported definition type: {definition_type} for dbgpts type: {dbgpts_type}"
        )
    json_dict = {
        "flow": base_metadata,
        "python_config": {},
        # Only the json definition type is supported for flows.
        "json_config": {"file_path": "definition/flow_definition.json"},
    }
    _create_poetry_project(working_directory, name)
    _write_dbgpts_toml(working_directory, name, json_dict)
    _write_manifest_file(working_directory, name, mod_name)
    _write_flow_define_json_file(working_directory, name, mod_name)
def _create_operator_template(
    name: str,
    mod_name: str,
    dbgpts_type: str,
    base_metadata: dict,
    definition_type: str,
    working_directory: str,
):
    """Create a new operator dbgpt"""
    # Operators can only be defined in python; reject anything else before
    # any files are created.
    if definition_type != "python":
        raise click.ClickException(
            f"Unsupported definition type: {definition_type} for dbgpts type: "
            f"{dbgpts_type}"
        )
    metadata = {
        "operator": base_metadata,
        "python_config": {},
        "json_config": {},
    }
    _create_poetry_project(working_directory, name)
    _write_dbgpts_toml(working_directory, name, metadata)
    _write_operator_init_file(working_directory, name, mod_name)
    _write_manifest_file(working_directory, name, mod_name)
def _create_poetry_project(working_directory: str, name: str):
    """Scaffold a new poetry project named *name* under *working_directory*.

    Raises:
        subprocess.CalledProcessError: If ``poetry new`` fails.
    """
    # Run poetry with an explicit cwd instead of os.chdir(), which would
    # permanently change the process-wide working directory as a side effect.
    subprocess.run(
        ["poetry", "new", name, "-n"], check=True, cwd=working_directory
    )
def _write_dbgpts_toml(working_directory: str, name: str, json_data: dict):
    """Serialize *json_data* into the package's dbgpts metadata toml file."""
    import tomlkit

    toml_path = Path(working_directory) / name / DBGPTS_METADATA_FILE
    with open(toml_path, "w") as f:
        tomlkit.dump(json_data, f)
def _write_manifest_file(working_directory: str, name: str, mod_name: str):
    """Write the package's MANIFEST.in, shipping the metadata and definitions."""
    manifest_path = Path(working_directory) / name / "MANIFEST.in"
    lines = [
        "include dbgpts.toml",
        f"include {mod_name}/definition/*.json",
    ]
    with open(manifest_path, "w") as f:
        f.write("\n".join(lines) + "\n")
def _write_flow_define_json_file(working_directory: str, name: str, mod_name: str):
    """Create an empty flow definition json file for the user to fill in."""
    def_file = (
        Path(working_directory)
        / name
        / mod_name
        / "definition"
        / "flow_definition.json"
    )
    # Create the intermediate directories on demand.
    def_file.parent.mkdir(parents=True, exist_ok=True)
    def_file.write_text("")
    print("Please write your flow json to the file: ", def_file)
def _write_operator_init_file(working_directory: str, name: str, mod_name: str):
    """Write the operator ``__init__.py`` file.

    The generated module contains an example ``HelloWorldOperator`` with AWEL
    flow metadata, prefixed by a package docstring built from *name*.
    """
    init_file = Path(working_directory) / name / mod_name / "__init__.py"
    # Example operator source written verbatim into the new package.
    content = """
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import ViewMetadata, OperatorCategory, IOField, Parameter


class HelloWorldOperator(MapOperator[str, str]):
    # The metadata for AWEL flow
    metadata = ViewMetadata(
        label="Hello World Operator",
        name="hello_world_operator",
        category=OperatorCategory.COMMON,
        description="A example operator to say hello to someone.",
        parameters=[
            Parameter.build_from(
                "Name",
                "name",
                str,
                optional=True,
                default="World",
                description="The name to say hello",
            )
        ],
        inputs=[
            IOField.build_from(
                "Input value",
                "value",
                str,
                description="The input value to say hello",
            )
        ],
        outputs=[
            IOField.build_from(
                "Output value",
                "value",
                str,
                description="The output value after saying hello",
            )
        ]
    )

    def __init__(self, name: str = "World", **kwargs):
        super().__init__(**kwargs)
        self.name = name

    async def map(self, value: str) -> str:
        return f"Hello, {self.name}! {value}"
"""
    with open(init_file, "w") as f:
        f.write(f'"""{name} operator package"""\n{content}')

View File

@@ -400,6 +400,8 @@ def core_requires():
"sqlparse==0.4.4", "sqlparse==0.4.4",
"duckdb==0.8.1", "duckdb==0.8.1",
"duckdb-engine", "duckdb-engine",
# lightweight python library for scheduling jobs
"schedule",
] ]
# TODO: remove fschat from simple_framework # TODO: remove fschat from simple_framework
if BUILD_FROM_SOURCE: if BUILD_FROM_SOURCE: