mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 20:53:48 +00:00
feat(core): Multiple ways to run dbgpts (#1734)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
DAG is the core component of AWEL, it is used to define the relationship between tasks.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging
|
||||
@@ -613,10 +614,14 @@ class DAG:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
|
||||
self,
|
||||
dag_id: str,
|
||||
resource_group: Optional[ResourceGroup] = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
"""Initialize a DAG."""
|
||||
self._dag_id = dag_id
|
||||
self._tags: Dict[str, str] = tags or {}
|
||||
self.node_map: Dict[str, DAGNode] = {}
|
||||
self.node_name_to_node: Dict[str, DAGNode] = {}
|
||||
self._root_nodes: List[DAGNode] = []
|
||||
@@ -651,6 +656,22 @@ class DAG:
|
||||
"""Return the dag id of current DAG."""
|
||||
return self._dag_id
|
||||
|
||||
@property
|
||||
def tags(self) -> Dict[str, str]:
|
||||
"""Return the tags of current DAG."""
|
||||
return self._tags
|
||||
|
||||
@property
|
||||
def dev_mode(self) -> bool:
|
||||
"""Whether the current DAG is in dev mode.
|
||||
|
||||
Returns:
|
||||
bool: Whether the current DAG is in dev mode
|
||||
"""
|
||||
from ..operators.base import _dev_mode
|
||||
|
||||
return _dev_mode()
|
||||
|
||||
def _build(self) -> None:
|
||||
from ..operators.common_operator import TriggerOperator
|
||||
|
||||
|
@@ -3,18 +3,49 @@
|
||||
DAGManager will load DAGs from dag_dirs, and register the trigger nodes
|
||||
to TriggerManager.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
|
||||
from .. import BaseOperator
|
||||
from ..trigger.base import TriggerMetadata
|
||||
from .base import DAG
|
||||
from .loader import LocalFileDAGLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DAGMetadata(BaseModel):
|
||||
"""Metadata for the DAG."""
|
||||
|
||||
triggers: List[TriggerMetadata] = Field(
|
||||
default_factory=list, description="The trigger metadata"
|
||||
)
|
||||
sse_output: bool = Field(
|
||||
default=False, description="Whether the DAG is a server-sent event output"
|
||||
)
|
||||
streaming_output: bool = Field(
|
||||
default=False, description="Whether the DAG is a streaming output"
|
||||
)
|
||||
tags: Optional[Dict[str, str]] = Field(
|
||||
default=None, description="The tags of the DAG"
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert the metadata to dict."""
|
||||
triggers_dict = []
|
||||
for trigger in self.triggers:
|
||||
triggers_dict.append(trigger.dict())
|
||||
dict_value = model_to_dict(self, exclude={"triggers"})
|
||||
dict_value["triggers"] = triggers_dict
|
||||
return dict_value
|
||||
|
||||
|
||||
class DAGManager(BaseComponent):
|
||||
"""The component of DAGManager."""
|
||||
|
||||
@@ -35,6 +66,8 @@ class DAGManager(BaseComponent):
|
||||
self.system_app = system_app
|
||||
self.dag_map: Dict[str, DAG] = {}
|
||||
self.dag_alias_map: Dict[str, str] = {}
|
||||
self._dag_metadata_map: Dict[str, DAGMetadata] = {}
|
||||
self._tags_to_dag_ids: Dict[str, Dict[str, Set[str]]] = {}
|
||||
self._trigger_manager: Optional["DefaultTriggerManager"] = None
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
@@ -73,12 +106,26 @@ class DAGManager(BaseComponent):
|
||||
if alias_name:
|
||||
self.dag_alias_map[alias_name] = dag_id
|
||||
|
||||
trigger_metadata: List["TriggerMetadata"] = []
|
||||
dag_metadata = _parse_metadata(dag)
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.register_trigger(trigger, self.system_app)
|
||||
tm = self._trigger_manager.register_trigger(
|
||||
trigger, self.system_app
|
||||
)
|
||||
if tm:
|
||||
trigger_metadata.append(tm)
|
||||
self._trigger_manager.after_register()
|
||||
else:
|
||||
logger.warning("No trigger manager, not register dag trigger")
|
||||
dag_metadata.triggers = trigger_metadata
|
||||
self._dag_metadata_map[dag_id] = dag_metadata
|
||||
tags = dag_metadata.tags
|
||||
if tags:
|
||||
for tag_key, tag_value in tags.items():
|
||||
if tag_key not in self._tags_to_dag_ids:
|
||||
self._tags_to_dag_ids[tag_key] = defaultdict(set)
|
||||
self._tags_to_dag_ids[tag_key][tag_value].add(dag_id)
|
||||
|
||||
def unregister_dag(self, dag_id: str):
|
||||
"""Unregister a DAG."""
|
||||
@@ -104,7 +151,13 @@ class DAGManager(BaseComponent):
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.unregister_trigger(trigger, self.system_app)
|
||||
# Finally remove the DAG from the map
|
||||
metadata = self._dag_metadata_map[dag_id]
|
||||
del self.dag_map[dag_id]
|
||||
del self._dag_metadata_map[dag_id]
|
||||
if metadata.tags:
|
||||
for tag_key, tag_value in metadata.tags.items():
|
||||
if tag_key in self._tags_to_dag_ids:
|
||||
self._tags_to_dag_ids[tag_key][tag_value].remove(dag_id)
|
||||
|
||||
def get_dag(
|
||||
self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
|
||||
@@ -116,3 +169,33 @@ class DAGManager(BaseComponent):
|
||||
if alias_name in self.dag_alias_map:
|
||||
return self.dag_map.get(self.dag_alias_map[alias_name])
|
||||
return None
|
||||
|
||||
def get_dags_by_tag(self, tag_key: str, tag_value) -> List[DAG]:
|
||||
"""Get all DAGs with the given tag."""
|
||||
with self.lock:
|
||||
dag_ids = self._tags_to_dag_ids.get(tag_key, {}).get(tag_value, set())
|
||||
return [self.dag_map[dag_id] for dag_id in dag_ids]
|
||||
|
||||
def get_dag_metadata(
|
||||
self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
|
||||
) -> Optional[DAGMetadata]:
|
||||
"""Get a DAGMetadata by dag_id or alias_name."""
|
||||
dag = self.get_dag(dag_id, alias_name)
|
||||
if not dag:
|
||||
return None
|
||||
return self._dag_metadata_map.get(dag.dag_id)
|
||||
|
||||
|
||||
def _parse_metadata(dag: DAG):
|
||||
from ..util.chat_util import _is_sse_output
|
||||
|
||||
metadata = DAGMetadata()
|
||||
metadata.tags = dag.tags
|
||||
if not dag.leaf_nodes:
|
||||
return metadata
|
||||
end_node = dag.leaf_nodes[0]
|
||||
if not isinstance(end_node, BaseOperator):
|
||||
return metadata
|
||||
metadata.sse_output = _is_sse_output(end_node)
|
||||
metadata.streaming_output = end_node.streaming_operator
|
||||
return metadata
|
||||
|
Reference in New Issue
Block a user