refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
FangYin Cheng
2023-12-08 14:45:59 +08:00
committed by GitHub
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

View File

371
dbgpt/core/awel/dag/base.py Normal file
View File

@@ -0,0 +1,371 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import logging
from collections import deque
from functools import cache
from concurrent.futures import Executor
from dbgpt.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Accepted argument shape for the dependency-setting APIs below: either a
# single node or a sequence of nodes.
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
def _is_async_context():
    """Tell whether the caller is executing inside a running asyncio task.

    Returns:
        bool: True when an event loop is running in this thread and a task is
        currently being executed on it; False otherwise.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread.
        return False
    return asyncio.current_task(loop=running_loop) is not None
class DependencyMixin(ABC):
    """Interface for objects that can be wired together with ``<<`` and ``>>``.

    Subclasses implement :meth:`set_upstream` and :meth:`set_downstream`; the
    shift operators defined here simply delegate to those two methods.
    """

    @abstractmethod
    def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
        """Register ``nodes`` as upstream of this node.

        Args:
            nodes (DependencyType): One node or a sequence of nodes that
                should run before the current node.

        Returns:
            DependencyMixin: ``self``, so that calls can be chained.

        Raises:
            ValueError: If no upstream nodes are provided or if an argument
                is not a DependencyMixin.
        """

    @abstractmethod
    def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
        """Register ``nodes`` as downstream of this node.

        Args:
            nodes (DependencyType): One node or a sequence of nodes that
                should run after the current node.

        Returns:
            DependencyMixin: ``self``, so that calls can be chained.

        Raises:
            ValueError: If no downstream nodes are provided or if an argument
                is not a DependencyMixin.
        """

    def __lshift__(self, nodes: DependencyType) -> DependencyType:
        """Support ``self << nodes``: make ``nodes`` upstream of ``self``.

        The right-hand operand is returned so expressions can be chained.

        Example:

            .. code-block:: python

                # equivalent to node.set_upstream(input_node)
                node << input_node
                # equivalent to node2.set_upstream([input_node])
                node2 << [input_node]
        """
        self.set_upstream(nodes)
        return nodes

    def __rshift__(self, nodes: DependencyType) -> DependencyType:
        """Support ``self >> nodes``: make ``nodes`` downstream of ``self``.

        The right-hand operand is returned so expressions can be chained.

        Example:

            .. code-block:: python

                # equivalent to node.set_downstream(next_node)
                node >> next_node
                # equivalent to node2.set_downstream([next_node])
                node2 >> [next_node]
        """
        self.set_downstream(nodes)
        return nodes

    def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Support ``[node] >> self`` by treating it as ``self << [node]``."""
        self.__lshift__(nodes)
        return self

    def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Support ``[node] << self`` by treating it as ``self >> [node]``."""
        self.__rshift__(nodes)
        return self
class DAGVar:
    """Class-level registry of the "current" DAG and of shared defaults.

    ``DAG.__enter__`` / ``DAG.__exit__`` push and pop DAGs here so that code
    running inside a ``with dag:`` block can look up the innermost DAG.
    Synchronous callers use a per-thread stack; callers running inside an
    asyncio task use a stack stored in a ``ContextVar``.

    The class also stores a default ``SystemApp`` and a default ``Executor``.
    """

    # Per-thread DAG stack used outside of asyncio tasks.
    _thread_local = threading.local()
    # Per-context DAG stack used inside asyncio tasks. The default is None on
    # purpose: a mutable default (e.g. ``deque()``) is a single object shared
    # by every context that has not called ``set`` yet, so pushing onto it
    # would leak one task's DAG into all other tasks.
    _async_local: contextvars.ContextVar = contextvars.ContextVar(
        "current_dag_stack", default=None
    )
    _system_app: Optional[SystemApp] = None
    _executor: Optional[Executor] = None

    @classmethod
    def enter_dag(cls, dag) -> None:
        """Push ``dag`` as the current DAG for the calling thread or task."""
        if _is_async_context():
            # Copy-on-write: never mutate a deque that another context may
            # also be holding a reference to (child tasks inherit a snapshot
            # of the parent's context variables).
            stack = deque(cls._async_local.get() or ())
            stack.append(dag)
            cls._async_local.set(stack)
        else:
            stack = getattr(cls._thread_local, "current_dag_stack", None)
            if stack is None:
                stack = cls._thread_local.current_dag_stack = deque()
            stack.append(dag)

    @classmethod
    def exit_dag(cls) -> None:
        """Pop the current DAG; a no-op when no DAG has been entered."""
        if _is_async_context():
            current = cls._async_local.get()
            if current:
                # Same copy-on-write rule as in ``enter_dag``.
                stack = deque(current)
                stack.pop()
                cls._async_local.set(stack)
        else:
            stack = getattr(cls._thread_local, "current_dag_stack", None)
            if stack:
                stack.pop()

    @classmethod
    def get_current_dag(cls) -> Optional["DAG"]:
        """Return the innermost entered DAG, or None when there is none."""
        if _is_async_context():
            stack = cls._async_local.get()
        else:
            stack = getattr(cls._thread_local, "current_dag_stack", None)
        return stack[-1] if stack else None

    @classmethod
    def get_current_system_app(cls) -> Optional[SystemApp]:
        """Return the registered ``SystemApp``.

        Returns None (rather than raising) when no system app has been set.
        """
        return cls._system_app

    @classmethod
    def set_current_system_app(cls, system_app: SystemApp) -> None:
        """Register ``system_app`` unless one is already registered.

        The first registration wins; later calls only log a warning.
        """
        if cls._system_app:
            # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
            logger.warning("System APP has already set, nothing to do")
        else:
            cls._system_app = system_app

    @classmethod
    def get_executor(cls) -> Optional[Executor]:
        """Return the default executor, or None when none has been set."""
        return cls._executor

    @classmethod
    def set_executor(cls, executor: Executor) -> None:
        """Set (or replace) the default executor."""
        cls._executor = executor
class DAGNode(DependencyMixin, ABC):
    """A node of a DAG that tracks its upstream and downstream neighbours.

    Args:
        dag: DAG the node belongs to. Defaults to ``DAGVar.get_current_dag()``.
        node_id: Identifier of the node. When omitted and a DAG is known, a
            new id is requested from that DAG.
        node_name: Optional name of the node.
        system_app: Defaults to ``DAGVar.get_current_system_app()``.
        executor: Defaults to ``DAGVar.get_executor()``.
        **kwargs: Accepted and not used by this constructor.
    """

    # The resource group of the current DAGNode.
    resource_group: Optional[ResourceGroup] = None

    def __init__(
        self,
        dag: Optional["DAG"] = None,
        node_id: Optional[str] = None,
        node_name: Optional[str] = None,
        system_app: Optional[SystemApp] = None,
        executor: Optional[Executor] = None,
        **kwargs
    ) -> None:
        super().__init__()
        self._upstream: List["DAGNode"] = []
        self._downstream: List["DAGNode"] = []
        self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
        self._system_app: Optional[SystemApp] = (
            system_app or DAGVar.get_current_system_app()
        )
        self._executor: Optional[Executor] = executor or DAGVar.get_executor()
        if not node_id and self._dag:
            node_id = self._dag._new_node_id()
        self._node_id: Optional[str] = node_id
        self._node_name: Optional[str] = node_name

    @property
    def node_id(self) -> str:
        """Identifier of the node; may be None when created outside a DAG."""
        return self._node_id

    @property
    def system_app(self) -> SystemApp:
        """The ``SystemApp`` associated with this node, if any."""
        return self._system_app

    def set_node_id(self, node_id: str) -> None:
        """Overwrite the node identifier."""
        self._node_id = node_id

    def __hash__(self) -> int:
        # Nodes without an id hash by identity.
        if self.node_id:
            return hash(self.node_id)
        return super().__hash__()

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, DAGNode):
            return False
        if not self.node_id or not other.node_id:
            # Keep equality consistent with ``__hash__``: id-less nodes hash
            # by identity, so they must only be equal to themselves.
            return self is other
        return self.node_id == other.node_id

    @property
    def node_name(self) -> str:
        """Name of the node, as passed to the constructor."""
        return self._node_name

    @property
    def dag(self) -> "DAG":
        """The DAG this node belongs to, or None when not attached yet."""
        return self._dag

    def set_upstream(self, nodes: DependencyType) -> "DAGNode":
        """Add ``nodes`` as upstream of this node and return ``self``."""
        self.set_dependency(nodes)
        return self

    def set_downstream(self, nodes: DependencyType) -> "DAGNode":
        """Add ``nodes`` as downstream of this node and return ``self``."""
        self.set_dependency(nodes, is_upstream=False)
        return self

    @property
    def upstream(self) -> List["DAGNode"]:
        """Nodes directly upstream of this node."""
        return self._upstream

    @property
    def downstream(self) -> List["DAGNode"]:
        """Nodes directly downstream of this node."""
        return self._downstream

    def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
        """Connect this node with ``nodes`` inside a single DAG.

        Both sides of every edge are updated, and every involved node is
        attached to (and registered in) the one DAG shared by the group.
        Edges that already exist are left untouched.

        Args:
            nodes: A node or a sequence of nodes to connect to.
            is_upstream: When True ``nodes`` become upstream of this node,
                otherwise they become downstream of it.

        Raises:
            ValueError: If any element is not a ``DAGNode``, if none of the
                involved nodes belongs to a DAG, or if they belong to more
                than one DAG.
        """
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        if not all(isinstance(node, DAGNode) for node in nodes):
            raise ValueError(
                "all nodes to set dependency to current node must be instance of 'DAGNode'"
            )
        dags = {node.dag for node in nodes if node.dag}
        if self.dag:
            dags.add(self.dag)
        if not dags:
            raise ValueError("set dependency to current node must in a DAG context")
        if len(dags) != 1:
            raise ValueError(
                "set dependency to current node just support in one DAG context"
            )
        dag = dags.pop()
        self._dag = dag
        dag._append_node(self)
        for node in nodes:
            # Pick the direction first. Deciding it together with the
            # duplicate check would let an already-registered upstream node
            # fall through and be linked as a downstream node.
            if is_upstream:
                own_side, other_side = self._upstream, node._downstream
            else:
                own_side, other_side = self._downstream, node._upstream
            if node in own_side:
                continue
            node._dag = dag
            dag._append_node(node)
            own_side.append(node)
            other_side.append(self)
class DAGContext:
    """State object holding a streaming flag, a task context and shared data.

    Args:
        streaming_call: Whether the current DAG is a streaming call.
    """

    def __init__(self, streaming_call: bool = False) -> None:
        self._share_data: Dict[str, Any] = {}
        self._curr_task_ctx = None
        self._streaming_call = streaming_call

    @property
    def streaming_call(self) -> bool:
        """Whether the current DAG is streaming call"""
        return self._streaming_call

    @property
    def current_task_context(self) -> TaskContext:
        """The task context most recently stored, or None if never set."""
        return self._curr_task_ctx

    def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
        """Store ``_curr_task_ctx`` as the current task context."""
        self._curr_task_ctx = _curr_task_ctx

    async def save_to_share_data(self, key: str, data: Any) -> None:
        """Store ``data`` under ``key``, replacing any previous value."""
        self._share_data[key] = data

    async def get_share_data(self, key: str) -> Any:
        """Return the value stored under ``key``, or None when absent."""
        return self._share_data.get(key)
class DAG:
    """A collection of ``DAGNode`` objects keyed by node id.

    Used as a context manager, a DAG registers itself with ``DAGVar`` on
    entry and unregisters on exit.

    Args:
        dag_id: Identifier of the DAG.
        resource_group: Optional resource group; stored on the instance.
    """

    def __init__(
        self, dag_id: str, resource_group: Optional[ResourceGroup] = None
    ) -> None:
        self._dag_id = dag_id
        # Keep the argument instead of silently discarding it.
        self._resource_group = resource_group
        self.node_map: Dict[str, DAGNode] = {}
        # Lazily built caches; None means "not built yet".
        self._root_nodes: Optional[List[DAGNode]] = None
        self._leaf_nodes: Optional[List[DAGNode]] = None
        self._trigger_nodes: Optional[List[DAGNode]] = None

    def _append_node(self, node: DAGNode) -> None:
        """Register ``node`` under its id and invalidate every cache."""
        self.node_map[node.node_id] = node
        # All three caches derive from the node set, so all three must be
        # dropped; leaving the trigger cache would hide newly added triggers.
        self._root_nodes = None
        self._leaf_nodes = None
        self._trigger_nodes = None

    def _new_node_id(self) -> str:
        """Return a fresh random (UUID4) node id."""
        return str(uuid.uuid4())

    @property
    def dag_id(self) -> str:
        """Identifier of this DAG."""
        return self._dag_id

    def _build(self) -> None:
        """Recompute the root, leaf and trigger node caches."""
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import - confirm.
        from ..operator.common_operator import TriggerOperator

        nodes: Set[DAGNode] = set()
        for node in self.node_map.values():
            nodes |= _get_nodes(node)
        self._root_nodes = [node for node in nodes if not node.upstream]
        self._leaf_nodes = [node for node in nodes if not node.downstream]
        self._trigger_nodes = [
            node for node in nodes if isinstance(node, TriggerOperator)
        ]

    @property
    def root_nodes(self) -> List[DAGNode]:
        """Nodes that have no upstream node."""
        # Test for None, not truthiness, so an empty result is cached too.
        if self._root_nodes is None:
            self._build()
        return self._root_nodes

    @property
    def leaf_nodes(self) -> List[DAGNode]:
        """Nodes that have no downstream node."""
        if self._leaf_nodes is None:
            self._build()
        return self._leaf_nodes

    @property
    def trigger_nodes(self) -> List[DAGNode]:
        """Nodes that are instances of ``TriggerOperator``."""
        if self._trigger_nodes is None:
            self._build()
        return self._trigger_nodes

    def __enter__(self):
        DAGVar.enter_dag(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions raised in the block propagate.
        DAGVar.exit_dag()
def _get_nodes(node: "DAGNode", is_upstream: Optional[bool] = True) -> "Set[DAGNode]":
    """Collect ``node`` and everything reachable from it in one direction.

    The traversal is iterative and remembers visited nodes, so shared
    ancestors are walked once and a cycle cannot cause unbounded recursion.

    Args:
        node: Starting node. A falsy value yields an empty set.
        is_upstream: Follow ``upstream`` links when true, ``downstream``
            links otherwise.

    Returns:
        The set containing ``node`` and all nodes reachable from it.
    """
    visited = set()
    if not node:
        return visited
    pending = [node]
    while pending:
        current = pending.pop()
        if current in visited:
            continue
        visited.add(current)
        neighbours = current.upstream if is_upstream else current.downstream
        # Falsy neighbours are skipped, matching the guard on the start node.
        pending.extend(n for n in neighbours if n and n not in visited)
    return visited

View File

@@ -0,0 +1,42 @@
from typing import Dict, Optional
import logging
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG
logger = logging.getLogger(__name__)
class DAGManager(BaseComponent):
    """Component that loads DAGs from a path and registers their triggers.

    Args:
        system_app: The application the component belongs to.
        dag_filepath: Path handed to ``LocalFileDAGLoader``.
    """

    name = ComponentType.AWEL_DAG_MANAGER

    def __init__(self, system_app: SystemApp, dag_filepath: str):
        super().__init__(system_app)
        self.dag_loader = LocalFileDAGLoader(dag_filepath)
        self.system_app = system_app
        # Loaded DAGs keyed by DAG id.
        self.dag_map: Dict[str, DAG] = {}

    def init_app(self, system_app: SystemApp):
        """Remember the application this component is attached to."""
        self.system_app = system_app

    def load_dags(self):
        """Load DAGs, record them in ``dag_map`` and register their triggers.

        Raises:
            ValueError: If a loaded DAG id is already known, either from a
                previous load or from another DAG in the same batch. In that
                case ``dag_map`` is left unchanged.
        """
        dags = self.dag_loader.load_dags()
        loaded: Dict[str, DAG] = {}
        triggers = []
        for dag in dags:
            dag_id = dag.dag_id
            if dag_id in self.dag_map or dag_id in loaded:
                raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
            loaded[dag_id] = dag
            triggers += dag.trigger_nodes
        # Record the DAGs only once the whole batch is valid. Without this
        # the duplicate check above could never fire and ``dag_map`` would
        # stay empty.
        self.dag_map.update(loaded)

        from ..trigger.trigger_manager import DefaultTriggerManager

        trigger_manager: DefaultTriggerManager = self.system_app.get_component(
            ComponentType.AWEL_TRIGGER_MANAGER,
            DefaultTriggerManager,
            default_component=None,
        )
        if trigger_manager:
            for trigger in triggers:
                trigger_manager.register_trigger(trigger)
            trigger_manager.after_register()
        else:
            # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
            logger.warning("No trigger manager, not register dag trigger")

View File

@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback
from .base import DAG
logger = logging.getLogger(__name__)
class DAGLoader(ABC):
    """Abstract interface for objects that produce a list of DAGs."""

    @abstractmethod
    def load_dags(self) -> List[DAG]:
        """Load dags

        Returns:
            List[DAG]: The DAGs found by the concrete loader.
        """
class LocalFileDAGLoader(DAGLoader):
    """Loader that reads DAGs from a local file or directory.

    Args:
        filepath: Path of a Python file, or of a directory to scan.
    """

    def __init__(self, filepath: str) -> None:
        super().__init__()
        self._filepath = filepath

    def load_dags(self) -> List[DAG]:
        """Return the DAGs found at the configured path.

        A missing path yields an empty list; a directory is scanned, any
        other path is processed as a single file.
        """
        target = self._filepath
        if not os.path.exists(target):
            return []
        handler = _process_directory if os.path.isdir(target) else _process_file
        return handler(target)
def _process_directory(directory: str) -> "List[DAG]":
    """Load the DAGs defined by every ``.py`` file directly in ``directory``.

    Entries are visited in sorted name order so the result order does not
    depend on the (arbitrary) order returned by ``os.listdir``. Entries that
    end in ``.py`` but are not regular files are skipped. Sub-directories
    are not descended into.

    Args:
        directory: Path of the directory to scan.

    Returns:
        The DAGs found, concatenated file by file.
    """
    dags = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".py"):
            continue
        filepath = os.path.join(directory, filename)
        if not os.path.isfile(filepath):
            continue
        dags += _process_file(filepath)
    return dags
def _process_file(filepath) -> List[DAG]:
    """Import the file at ``filepath`` and return the DAGs it defines."""
    return _process_modules(_load_modules_from_file(filepath))
def _load_modules_from_file(filepath: str):
    """Import the Python source at ``filepath`` as a uniquely named module.

    The module name is derived from the SHA-1 of the path plus the file's
    base name. Any module already registered under that name is discarded
    first, so the file is executed again on every call.

    Args:
        filepath: Path of the Python source file to import.

    Returns:
        A one-element list holding the imported module, or an empty list if
        the import failed (the failure is logged, not raised).
    """
    import importlib
    import importlib.machinery
    import importlib.util

    logger.info(f"Importing {filepath}")

    org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
    # The hash only serves to build a unique module name.
    path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
    mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"

    sys.modules.pop(mod_name, None)

    try:
        loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
        spec = importlib.util.spec_from_loader(mod_name, loader)
        new_module = importlib.util.module_from_spec(spec)
        # Registered before execution, following the usual import order.
        sys.modules[spec.name] = new_module
        loader.exec_module(new_module)
        return [new_module]
    except Exception:
        # Do not leave a half-initialised module behind in ``sys.modules``.
        sys.modules.pop(mod_name, None)
        msg = traceback.format_exc()
        logger.error(f"Failed to import: {filepath}, error message: {msg}")
        # TODO save error message
        return []
def _process_modules(mods) -> List[DAG]:
    """Collect every ``DAG`` instance found at the top level of ``mods``.

    Args:
        mods: Iterable of imported modules.

    Returns:
        The DAG objects found among the modules' attributes, in iteration
        order. A DAG whose handling raises is logged and left out.
    """
    found_dags = []
    for mod in mods:
        for candidate in mod.__dict__.values():
            if not isinstance(candidate, DAG):
                continue
            try:
                # TODO validate dag params
                logger.info(
                    f"Found dag {candidate} from mod {mod} and model file {mod.__file__}"
                )
                found_dags.append(candidate)
            except Exception:
                msg = traceback.format_exc()
                logger.error(f"Failed to dag file, error message: {msg}")
    return found_dags

View File

View File

@@ -0,0 +1,51 @@
import pytest
import threading
import asyncio
from ..dag import DAG, DAGContext
def test_dag_context_sync():
    """Nested ``with`` blocks expose the innermost DAG and unwind in order."""
    # NOTE(review): the DAGContext class visible in dag/base.py defines no
    # get_current_dag(); that method lives on DAGVar there. Confirm which
    # class the `..dag` import actually provides.
    outer = DAG("dag1")
    inner = DAG("dag2")

    with outer:
        assert DAGContext.get_current_dag() == outer
        with inner:
            assert DAGContext.get_current_dag() == inner
        assert DAGContext.get_current_dag() == outer
    assert DAGContext.get_current_dag() is None
def test_dag_context_threading():
    """Each thread sees only the DAG it entered itself."""
    # NOTE(review): the DAGContext class visible in dag/base.py defines no
    # enter_dag()/exit_dag()/get_current_dag(); those live on DAGVar there.
    # Confirm which class the `..dag` import actually provides.

    def thread_function(dag):
        # NOTE(review): an AssertionError raised inside a worker thread is
        # not propagated to the main thread, so it would not fail this test.
        DAGContext.enter_dag(dag)
        assert DAGContext.get_current_dag() == dag
        DAGContext.exit_dag()

    workers = [
        threading.Thread(target=thread_function, args=(DAG(dag_id),))
        for dag_id in ("dag1", "dag2")
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    assert DAGContext.get_current_dag() is None
@pytest.mark.asyncio
async def test_dag_context_async():
    """Concurrent coroutines each see the DAG they entered themselves."""
    # NOTE(review): the DAGContext class visible in dag/base.py defines no
    # enter_dag()/exit_dag()/get_current_dag(); those live on DAGVar there.
    # Confirm which class the `..dag` import actually provides.

    async def async_function(dag):
        DAGContext.enter_dag(dag)
        assert DAGContext.get_current_dag() == dag
        DAGContext.exit_dag()

    first, second = DAG("dag1"), DAG("dag2")
    await asyncio.gather(async_function(first), async_function(second))

    assert DAGContext.get_current_dag() is None