feat(core): Multiple ways to run dbgpts (#1734)

This commit is contained in:
Fangyin Cheng
2024-07-18 17:50:40 +08:00
committed by GitHub
parent d389fddc2f
commit f889fa3775
19 changed files with 1410 additions and 304 deletions

View File

@@ -2,6 +2,7 @@
DAG is the core component of AWEL, it is used to define the relationship between tasks.
"""
import asyncio
import contextvars
import logging
@@ -613,10 +614,14 @@ class DAG:
"""
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
self,
dag_id: str,
resource_group: Optional[ResourceGroup] = None,
tags: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize a DAG."""
self._dag_id = dag_id
self._tags: Dict[str, str] = tags or {}
self.node_map: Dict[str, DAGNode] = {}
self.node_name_to_node: Dict[str, DAGNode] = {}
self._root_nodes: List[DAGNode] = []
@@ -651,6 +656,22 @@ class DAG:
"""Return the dag id of current DAG."""
return self._dag_id
@property
def tags(self) -> Dict[str, str]:
"""Return the tags of current DAG."""
return self._tags
@property
def dev_mode(self) -> bool:
"""Whether the current DAG is in dev mode.
Returns:
bool: Whether the current DAG is in dev mode
"""
from ..operators.base import _dev_mode
return _dev_mode()
def _build(self) -> None:
from ..operators.common_operator import TriggerOperator

View File

@@ -3,18 +3,49 @@
DAGManager will load DAGs from dag_dirs, and register the trigger nodes
to TriggerManager.
"""
import logging
import threading
from typing import Dict, List, Optional
from collections import defaultdict
from typing import Dict, List, Optional, Set
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .. import BaseOperator
from ..trigger.base import TriggerMetadata
from .base import DAG
from .loader import LocalFileDAGLoader
logger = logging.getLogger(__name__)
class DAGMetadata(BaseModel):
"""Metadata for the DAG."""
triggers: List[TriggerMetadata] = Field(
default_factory=list, description="The trigger metadata"
)
sse_output: bool = Field(
default=False, description="Whether the DAG is a server-sent event output"
)
streaming_output: bool = Field(
default=False, description="Whether the DAG is a streaming output"
)
tags: Optional[Dict[str, str]] = Field(
default=None, description="The tags of the DAG"
)
def to_dict(self):
"""Convert the metadata to dict."""
triggers_dict = []
for trigger in self.triggers:
triggers_dict.append(trigger.dict())
dict_value = model_to_dict(self, exclude={"triggers"})
dict_value["triggers"] = triggers_dict
return dict_value
class DAGManager(BaseComponent):
"""The component of DAGManager."""
@@ -35,6 +66,8 @@ class DAGManager(BaseComponent):
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}
self.dag_alias_map: Dict[str, str] = {}
self._dag_metadata_map: Dict[str, DAGMetadata] = {}
self._tags_to_dag_ids: Dict[str, Dict[str, Set[str]]] = {}
self._trigger_manager: Optional["DefaultTriggerManager"] = None
def init_app(self, system_app: SystemApp):
@@ -73,12 +106,26 @@ class DAGManager(BaseComponent):
if alias_name:
self.dag_alias_map[alias_name] = dag_id
trigger_metadata: List["TriggerMetadata"] = []
dag_metadata = _parse_metadata(dag)
if self._trigger_manager:
for trigger in dag.trigger_nodes:
self._trigger_manager.register_trigger(trigger, self.system_app)
tm = self._trigger_manager.register_trigger(
trigger, self.system_app
)
if tm:
trigger_metadata.append(tm)
self._trigger_manager.after_register()
else:
logger.warning("No trigger manager, not register dag trigger")
dag_metadata.triggers = trigger_metadata
self._dag_metadata_map[dag_id] = dag_metadata
tags = dag_metadata.tags
if tags:
for tag_key, tag_value in tags.items():
if tag_key not in self._tags_to_dag_ids:
self._tags_to_dag_ids[tag_key] = defaultdict(set)
self._tags_to_dag_ids[tag_key][tag_value].add(dag_id)
def unregister_dag(self, dag_id: str):
"""Unregister a DAG."""
@@ -104,7 +151,13 @@ class DAGManager(BaseComponent):
for trigger in dag.trigger_nodes:
self._trigger_manager.unregister_trigger(trigger, self.system_app)
# Finally remove the DAG from the map
metadata = self._dag_metadata_map[dag_id]
del self.dag_map[dag_id]
del self._dag_metadata_map[dag_id]
if metadata.tags:
for tag_key, tag_value in metadata.tags.items():
if tag_key in self._tags_to_dag_ids:
self._tags_to_dag_ids[tag_key][tag_value].remove(dag_id)
def get_dag(
self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
@@ -116,3 +169,33 @@ class DAGManager(BaseComponent):
if alias_name in self.dag_alias_map:
return self.dag_map.get(self.dag_alias_map[alias_name])
return None
def get_dags_by_tag(self, tag_key: str, tag_value) -> List[DAG]:
"""Get all DAGs with the given tag."""
with self.lock:
dag_ids = self._tags_to_dag_ids.get(tag_key, {}).get(tag_value, set())
return [self.dag_map[dag_id] for dag_id in dag_ids]
def get_dag_metadata(
self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
) -> Optional[DAGMetadata]:
"""Get a DAGMetadata by dag_id or alias_name."""
dag = self.get_dag(dag_id, alias_name)
if not dag:
return None
return self._dag_metadata_map.get(dag.dag_id)
def _parse_metadata(dag: DAG):
from ..util.chat_util import _is_sse_output
metadata = DAGMetadata()
metadata.tags = dag.tags
if not dag.leaf_nodes:
return metadata
end_node = dag.leaf_nodes[0]
if not isinstance(end_node, BaseOperator):
return metadata
metadata.sse_output = _is_sse_output(end_node)
metadata.streaming_output = end_node.streaming_operator
return metadata