Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-24 19:08:15 +00:00)
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
0  dbgpt/core/awel/dag/__init__.py  Normal file

371  dbgpt/core/awel/dag/base.py  Normal file
@@ -0,0 +1,371 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import logging
from collections import deque
from functools import cache
from concurrent.futures import Executor

from dbgpt.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext

logger = logging.getLogger(__name__)

DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]


def _is_async_context():
    try:
        loop = asyncio.get_running_loop()
        return asyncio.current_task(loop=loop) is not None
    except RuntimeError:
        return False


class DependencyMixin(ABC):
    @abstractmethod
    def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
        """Set one or more upstream nodes for this node.

        Args:
            nodes (DependencyType): Upstream nodes to be set to current node.

        Returns:
            DependencyMixin: Returns self to allow method chaining.

        Raises:
            ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
        """

    @abstractmethod
    def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
        """Set one or more downstream nodes for this node.

        Args:
            nodes (DependencyType): Downstream nodes to be set to current node.

        Returns:
            DependencyMixin: Returns self to allow method chaining.

        Raises:
            ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
        """

    def __lshift__(self, nodes: DependencyType) -> DependencyType:
        """Implements self << nodes

        Example:

        .. code-block:: python

            # means node.set_upstream(input_node)
            node << input_node

            # means node2.set_upstream([input_node])
            node2 << [input_node]
        """
        self.set_upstream(nodes)
        return nodes

    def __rshift__(self, nodes: DependencyType) -> DependencyType:
        """Implements self >> nodes

        Example:

        .. code-block:: python

            # means node.set_downstream(next_node)
            node >> next_node

            # means node2.set_downstream([next_node])
            node2 >> [next_node]

        """
        self.set_downstream(nodes)
        return nodes

    def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Implements [node] >> self"""
        self.__lshift__(nodes)
        return self

    def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Implements [node] << self"""
        self.__rshift__(nodes)
        return self


class DAGVar:
    _thread_local = threading.local()
    _async_local = contextvars.ContextVar("current_dag_stack", default=deque())
    _system_app: SystemApp = None
    _executor: Executor = None

    @classmethod
    def enter_dag(cls, dag) -> None:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            stack.append(dag)
            cls._async_local.set(stack)
        else:
            if not hasattr(cls._thread_local, "current_dag_stack"):
                cls._thread_local.current_dag_stack = deque()
            cls._thread_local.current_dag_stack.append(dag)

    @classmethod
    def exit_dag(cls) -> None:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            if stack:
                stack.pop()
                cls._async_local.set(stack)
        else:
            if (
                hasattr(cls._thread_local, "current_dag_stack")
                and cls._thread_local.current_dag_stack
            ):
                cls._thread_local.current_dag_stack.pop()

    @classmethod
    def get_current_dag(cls) -> Optional["DAG"]:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            return stack[-1] if stack else None
        else:
            if (
                hasattr(cls._thread_local, "current_dag_stack")
                and cls._thread_local.current_dag_stack
            ):
                return cls._thread_local.current_dag_stack[-1]
            return None

    @classmethod
    def get_current_system_app(cls) -> SystemApp:
        # if not cls._system_app:
        #     raise RuntimeError("System APP not set for DAGVar")
        return cls._system_app

    @classmethod
    def set_current_system_app(cls, system_app: SystemApp) -> None:
        if cls._system_app:
            logger.warning("System APP has already been set, nothing to do")
        else:
            cls._system_app = system_app

    @classmethod
    def get_executor(cls) -> Executor:
        return cls._executor

    @classmethod
    def set_executor(cls, executor: Executor) -> None:
        cls._executor = executor


class DAGNode(DependencyMixin, ABC):
    resource_group: Optional[ResourceGroup] = None
    """The resource group of current DAGNode"""

    def __init__(
        self,
        dag: Optional["DAG"] = None,
        node_id: Optional[str] = None,
        node_name: Optional[str] = None,
        system_app: Optional[SystemApp] = None,
        executor: Optional[Executor] = None,
        **kwargs
    ) -> None:
        super().__init__()
        self._upstream: List["DAGNode"] = []
        self._downstream: List["DAGNode"] = []
        self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
        self._system_app: Optional[SystemApp] = (
            system_app or DAGVar.get_current_system_app()
        )
        self._executor: Optional[Executor] = executor or DAGVar.get_executor()
        if not node_id and self._dag:
            node_id = self._dag._new_node_id()
        self._node_id: str = node_id
        self._node_name: str = node_name

    @property
    def node_id(self) -> str:
        return self._node_id

    @property
    def system_app(self) -> SystemApp:
        return self._system_app

    def set_node_id(self, node_id: str) -> None:
        self._node_id = node_id

    def __hash__(self) -> int:
        if self.node_id:
            return hash(self.node_id)
        else:
            return super().__hash__()

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, DAGNode):
            return False
        return self.node_id == other.node_id

    @property
    def node_name(self) -> str:
        return self._node_name

    @property
    def dag(self) -> "DAG":
        return self._dag

    def set_upstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes)
        return self

    def set_downstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes, is_upstream=False)
        return self

    @property
    def upstream(self) -> List["DAGNode"]:
        return self._upstream

    @property
    def downstream(self) -> List["DAGNode"]:
        return self._downstream

    def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        if not all(isinstance(node, DAGNode) for node in nodes):
            raise ValueError(
                "All nodes in a dependency must be instances of 'DAGNode'"
            )
        nodes: Sequence[DAGNode] = nodes
        dags = set([node.dag for node in nodes if node.dag])
        if self.dag:
            dags.add(self.dag)
        if not dags:
            raise ValueError("Setting a dependency requires a DAG context")
        if len(dags) != 1:
            raise ValueError(
                "Setting a dependency is only supported within a single DAG context"
            )
        dag = dags.pop()
        self._dag = dag

        dag._append_node(self)
        for node in nodes:
            if is_upstream and node not in self.upstream:
                node._dag = dag
                dag._append_node(node)

                self._upstream.append(node)
                node._downstream.append(self)
            elif node not in self._downstream:
                node._dag = dag
                dag._append_node(node)

                self._downstream.append(node)
                node._upstream.append(self)


class DAGContext:
    def __init__(self, streaming_call: bool = False) -> None:
        self._streaming_call = streaming_call
        self._curr_task_ctx = None
        self._share_data: Dict[str, Any] = {}

    @property
    def current_task_context(self) -> TaskContext:
        return self._curr_task_ctx

    @property
    def streaming_call(self) -> bool:
        """Whether the current DAG is a streaming call"""
        return self._streaming_call

    def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
        self._curr_task_ctx = _curr_task_ctx

    async def get_share_data(self, key: str) -> Any:
        return self._share_data.get(key)

    async def save_to_share_data(self, key: str, data: Any) -> None:
        self._share_data[key] = data


class DAG:
    def __init__(
        self, dag_id: str, resource_group: Optional[ResourceGroup] = None
    ) -> None:
        self._dag_id = dag_id
        self.node_map: Dict[str, DAGNode] = {}
        self._root_nodes: List[DAGNode] = None
        self._leaf_nodes: List[DAGNode] = None
        self._trigger_nodes: List[DAGNode] = None

    def _append_node(self, node: DAGNode) -> None:
        self.node_map[node.node_id] = node
        # Invalidate the cached node lists
        self._root_nodes = None
        self._leaf_nodes = None

    def _new_node_id(self) -> str:
        return str(uuid.uuid4())

    @property
    def dag_id(self) -> str:
        return self._dag_id

    def _build(self) -> None:
        from ..operator.common_operator import TriggerOperator

        nodes = set()
        for _, node in self.node_map.items():
            nodes = nodes.union(_get_nodes(node))
        self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
        self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes)))
        self._trigger_nodes = list(
            set(filter(lambda x: isinstance(x, TriggerOperator), nodes))
        )

    @property
    def root_nodes(self) -> List[DAGNode]:
        if not self._root_nodes:
            self._build()
        return self._root_nodes

    @property
    def leaf_nodes(self) -> List[DAGNode]:
        if not self._leaf_nodes:
            self._build()
        return self._leaf_nodes

    @property
    def trigger_nodes(self):
        if not self._trigger_nodes:
            self._build()
        return self._trigger_nodes

    def __enter__(self):
        DAGVar.enter_dag(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        DAGVar.exit_dag()


def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]:
    nodes = set()
    if not node:
        return nodes
    nodes.add(node)
    stream_nodes = node.upstream if is_upstream else node.downstream
    for node in stream_nodes:
        nodes = nodes.union(_get_nodes(node, is_upstream))
    return nodes
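As a quick orientation, here is a minimal usage sketch of the classes above. It is not part of this commit: _EchoNode is a hypothetical concrete subclass, used only to show the `with DAG(...)` context manager and the `>>` dependency operator.

# A minimal sketch, not part of this commit. Assumes the dbgpt package is
# importable; _EchoNode is a hypothetical stand-in for a real AWEL operator.
from dbgpt.core.awel.dag.base import DAG, DAGNode


class _EchoNode(DAGNode):  # hypothetical node, for illustration only
    pass


with DAG("example_dag") as dag:
    a = _EchoNode(node_name="a")
    b = _EchoNode(node_name="b")
    c = _EchoNode(node_name="c")
    a >> b >> c  # a.set_downstream(b), then b.set_downstream(c)

# set_dependency maintains both directions of each edge:
assert b.upstream == [a]
assert b.downstream == [c]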
42  dbgpt/core/awel/dag/dag_manager.py  Normal file
@@ -0,0 +1,42 @@
from typing import Dict, Optional
import logging
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG

logger = logging.getLogger(__name__)


class DAGManager(BaseComponent):
    name = ComponentType.AWEL_DAG_MANAGER

    def __init__(self, system_app: SystemApp, dag_filepath: str):
        super().__init__(system_app)
        self.dag_loader = LocalFileDAGLoader(dag_filepath)
        self.system_app = system_app
        self.dag_map: Dict[str, DAG] = {}

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def load_dags(self):
        dags = self.dag_loader.load_dags()
        triggers = []
        for dag in dags:
            dag_id = dag.dag_id
            if dag_id in self.dag_map:
                raise ValueError(f"Load DAG error, DAG ID {dag_id} already exists")
            # Record the DAG so duplicate IDs are detected on later loads
            self.dag_map[dag_id] = dag
            triggers += dag.trigger_nodes
        from ..trigger.trigger_manager import DefaultTriggerManager

        trigger_manager: DefaultTriggerManager = self.system_app.get_component(
            ComponentType.AWEL_TRIGGER_MANAGER,
            DefaultTriggerManager,
            default_component=None,
        )
        if trigger_manager:
            for trigger in triggers:
                trigger_manager.register_trigger(trigger)
            trigger_manager.after_register()
        else:
            logger.warning("No trigger manager, DAG triggers will not be registered")
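For context, a sketch of how this component might be wired up. This is not from the commit; it assumes SystemApp() can be constructed standalone and exposes register_instance() for pre-built components.

# A minimal sketch, assuming SystemApp() and register_instance() work as below.
from dbgpt.component import SystemApp
from dbgpt.core.awel.dag.dag_manager import DAGManager

system_app = SystemApp()
dag_manager = DAGManager(system_app, "/path/to/dag_files")  # hypothetical path
system_app.register_instance(dag_manager)  # assumption: standard registration
dag_manager.load_dags()  # imports each .py file and registers any triggers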
93  dbgpt/core/awel/dag/loader.py  Normal file
@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback

from .base import DAG

logger = logging.getLogger(__name__)


class DAGLoader(ABC):
    @abstractmethod
    def load_dags(self) -> List[DAG]:
        """Load dags"""


class LocalFileDAGLoader(DAGLoader):
    def __init__(self, filepath: str) -> None:
        super().__init__()
        self._filepath = filepath

    def load_dags(self) -> List[DAG]:
        if not os.path.exists(self._filepath):
            return []
        if os.path.isdir(self._filepath):
            return _process_directory(self._filepath)
        else:
            return _process_file(self._filepath)


def _process_directory(directory: str) -> List[DAG]:
    dags = []
    for file in os.listdir(directory):
        if file.endswith(".py"):
            filepath = os.path.join(directory, file)
            dags += _process_file(filepath)
    return dags


def _process_file(filepath: str) -> List[DAG]:
    mods = _load_modules_from_file(filepath)
    results = _process_modules(mods)
    return results


def _load_modules_from_file(filepath: str):
    import importlib
    import importlib.machinery
    import importlib.util

    logger.info(f"Importing {filepath}")

    org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
    path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
    mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"

    if mod_name in sys.modules:
        del sys.modules[mod_name]

    def parse(mod_name, filepath):
        try:
            loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
            spec = importlib.util.spec_from_loader(mod_name, loader)
            new_module = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = new_module
            loader.exec_module(new_module)
            return [new_module]
        except Exception:
            msg = traceback.format_exc()
            logger.error(f"Failed to import: {filepath}, error message: {msg}")
            # TODO save error message
            return []

    return parse(mod_name, filepath)


def _process_modules(mods) -> List[DAG]:
    top_level_dags = (
        (o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
    )
    found_dags = []
    for dag, mod in top_level_dags:
        try:
            # TODO validate dag params
            logger.info(f"Found dag {dag} from mod {mod} and module file {mod.__file__}")
            found_dags.append(dag)
        except Exception:
            msg = traceback.format_exc()
            logger.error(f"Failed to process DAG file, error message: {msg}")
    return found_dags
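The unusual_prefix_{hash} module name deserves a note: hashing the full path means two DAG files with the same basename in different directories get distinct entries in sys.modules. A small illustration using only the stdlib (the path is hypothetical):

# Mirrors the naming scheme in _load_modules_from_file above.
import hashlib
import os

filepath = "/opt/dags/my_pipeline.py"  # hypothetical path
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
print(f"unusual_prefix_{path_hash}_{org_mod_name}")
# e.g. unusual_prefix_<40-char sha1 hex digest>_my_pipeline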
0  dbgpt/core/awel/dag/tests/__init__.py  Normal file

51  dbgpt/core/awel/dag/tests/test_dag.py  Normal file
@@ -0,0 +1,51 @@
import pytest
import threading
import asyncio

from ..base import DAG, DAGVar


def test_dag_context_sync():
    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    with dag1:
        assert DAGVar.get_current_dag() == dag1
        with dag2:
            assert DAGVar.get_current_dag() == dag2
        assert DAGVar.get_current_dag() == dag1
    assert DAGVar.get_current_dag() is None


def test_dag_context_threading():
    def thread_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()

    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    thread1 = threading.Thread(target=thread_function, args=(dag1,))
    thread2 = threading.Thread(target=thread_function, args=(dag2,))

    thread1.start()
    thread2.start()
    thread1.join()
    thread2.join()

    assert DAGVar.get_current_dag() is None


@pytest.mark.asyncio
async def test_dag_context_async():
    async def async_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()

    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    await asyncio.gather(async_function(dag1), async_function(dag2))

    assert DAGVar.get_current_dag() is None
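Note that the context-tracking helpers live on DAGVar in base.py, so the tests exercise that class directly. The async test relies on @pytest.mark.asyncio, so running it requires the pytest-asyncio plugin; a typical invocation would be `pytest dbgpt/core/awel/dag/tests/test_dag.py` (path as laid out in this commit).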