chore: Add pylint for DB-GPT core lib (#1076)
@@ -1,3 +1,7 @@
"""The base module of DAG.

DAG is the core component of AWEL, it is used to define the relationship between tasks.
"""
import asyncio
import contextvars
import logging
@@ -6,7 +10,7 @@ import uuid
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast

from dbgpt.component import SystemApp
@@ -27,86 +31,108 @@ def _is_async_context():


class DependencyMixin(ABC):
    """The mixin class for DAGNode.

    This class defines the interface for setting upstream and downstream nodes.

    And it also implements the operator << and >> for setting upstream
    and downstream nodes.
    """

    @abstractmethod
    def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
    def set_upstream(self, nodes: DependencyType) -> None:
        """Set one or more upstream nodes for this node.

        Args:
            nodes (DependencyType): Upstream nodes to be set to current node.

        Returns:
            DependencyMixin: Returns self to allow method chaining.

        Raises:
            ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
            ValueError: If no upstream nodes are provided or if an argument is
                not a DependencyMixin.
        """

    @abstractmethod
    def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
    def set_downstream(self, nodes: DependencyType) -> None:
        """Set one or more downstream nodes for this node.

        Args:
            nodes (DependencyType): Downstream nodes to be set to current node.

        Returns:
            DependencyMixin: Returns self to allow method chaining.

        Raises:
            ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
            ValueError: If no downstream nodes are provided or if an argument is
                not a DependencyMixin.
        """

    def __lshift__(self, nodes: DependencyType) -> DependencyType:
        """Implements self << nodes
        """Set upstream nodes for current node.

        Implements: self << nodes.

        Example:
            .. code-block:: python

            .. code-block:: python
                # means node.set_upstream(input_node)
                node << input_node
                # means node2.set_upstream([input_node])
                node2 << [input_node]

                # means node.set_upstream(input_node)
                node << input_node

                # means node2.set_upstream([input_node])
                node2 << [input_node]
        """
        self.set_upstream(nodes)
        return nodes

    def __rshift__(self, nodes: DependencyType) -> DependencyType:
        """Implements self >> nodes
        """Set downstream nodes for current node.

        Example:
        Implements: self >> nodes.

            .. code-block:: python
        Examples:
            .. code-block:: python

                # means node.set_downstream(next_node)
                node >> next_node
                # means node.set_downstream(next_node)
                node >> next_node

                # means node2.set_downstream([next_node])
                node2 >> [next_node]
                # means node2.set_downstream([next_node])
                node2 >> [next_node]

        """
        self.set_downstream(nodes)
        return nodes

    def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Implements [node] >> self"""
        """Set upstream nodes for current node.

        Implements: [node] >> self
        """
        self.__lshift__(nodes)
        return self

    def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Implements [node] << self"""
        """Set downstream nodes for current node.

        Implements: [node] << self
        """
        self.__rshift__(nodes)
        return self
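
The operator overloads above are easiest to see with a toy stand-in. The class below is illustrative only (the real DAGNode routes everything through set_dependency); it simply records links so the effect of <<, >> and the reflected [node] >> self form is visible:

class _ToyNode:
    """Illustrative stand-in that mimics the DependencyMixin operator protocol."""

    def __init__(self, name):
        self.name = name
        self.upstream = []
        self.downstream = []

    def set_upstream(self, nodes):
        for n in nodes if isinstance(nodes, list) else [nodes]:
            self.upstream.append(n)
            n.downstream.append(self)

    def set_downstream(self, nodes):
        for n in nodes if isinstance(nodes, list) else [nodes]:
            self.downstream.append(n)
            n.upstream.append(self)

    def __lshift__(self, nodes):
        self.set_upstream(nodes)
        return nodes

    def __rshift__(self, nodes):
        self.set_downstream(nodes)
        return nodes

    def __rrshift__(self, nodes):
        # [node] >> self: the list has no matching __rshift__, so Python asks self.
        self.__lshift__(nodes)
        return self

    def __rlshift__(self, nodes):
        # [node] << self: reflected form of the downstream case.
        self.__rshift__(nodes)
        return self


a, b, c = _ToyNode("a"), _ToyNode("b"), _ToyNode("c")
a >> b                    # b depends on a
[a, b] >> c               # goes through __rrshift__: c depends on a and b
print([n.name for n in b.upstream])   # ['a']
print([n.name for n in c.upstream])   # ['a', 'b']
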


class DAGVar:
    """The DAGVar is used to store the current DAG context."""

    _thread_local = threading.local()
    _async_local = contextvars.ContextVar("current_dag_stack", default=deque())
    _system_app: SystemApp = None
    _executor: Executor = None
    _async_local: contextvars.ContextVar = contextvars.ContextVar(
        "current_dag_stack", default=deque()
    )
    _system_app: Optional[SystemApp] = None
    # The executor for current DAG, this is used run some sync tasks in async DAG
    _executor: Optional[Executor] = None

    @classmethod
    def enter_dag(cls, dag) -> None:
        """Enter a DAG context.

        Args:
            dag (DAG): The DAG to enter
        """
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
@@ -119,6 +145,7 @@ class DAGVar:

    @classmethod
    def exit_dag(cls) -> None:
        """Exit a DAG context."""
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
@@ -134,6 +161,11 @@ class DAGVar:

    @classmethod
    def get_current_dag(cls) -> Optional["DAG"]:
        """Get the current DAG.

        Returns:
            Optional[DAG]: The current DAG
        """
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
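
Why a ContextVar holding a deque instead of a plain global? Each asyncio task gets its own copy of the context, so nested DAG contexts in concurrent coroutines do not trample each other. A self-contained sketch of the same pattern, separate from DB-GPT (the copy-before-mutate step is what keeps the stacks isolated):

import asyncio
import contextvars
from collections import deque

_dag_stack: contextvars.ContextVar = contextvars.ContextVar("dag_stack", default=deque())


def push(dag) -> None:
    stack = deque(_dag_stack.get())  # copy so each task mutates its own stack
    stack.append(dag)
    _dag_stack.set(stack)


def pop() -> None:
    stack = deque(_dag_stack.get())
    stack.pop()
    _dag_stack.set(stack)


def current():
    stack = _dag_stack.get()
    return stack[-1] if stack else None


async def run(name):
    push(name)
    await asyncio.sleep(0)   # yield to the other task
    seen = current()         # still sees its own DAG
    pop()
    return name, seen


async def main():
    print(await asyncio.gather(run("dag_a"), run("dag_b")))
    # [('dag_a', 'dag_a'), ('dag_b', 'dag_b')]


asyncio.run(main())
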
@@ -147,36 +179,56 @@ class DAGVar:
            return None

    @classmethod
    def get_current_system_app(cls) -> SystemApp:
    def get_current_system_app(cls) -> Optional[SystemApp]:
        """Get the current system app.

        Returns:
            Optional[SystemApp]: The current system app
        """
        # if not cls._system_app:
        #     raise RuntimeError("System APP not set for DAGVar")
        return cls._system_app

    @classmethod
    def set_current_system_app(cls, system_app: SystemApp) -> None:
        """Set the current system app.

        Args:
            system_app (SystemApp): The system app to set
        """
        if cls._system_app:
            logger.warn("System APP has already set, nothing to do")
            logger.warning("System APP has already set, nothing to do")
        else:
            cls._system_app = system_app

    @classmethod
    def get_executor(cls) -> Executor:
    def get_executor(cls) -> Optional[Executor]:
        """Get the current executor.

        Returns:
            Optional[Executor]: The current executor
        """
        return cls._executor

    @classmethod
    def set_executor(cls, executor: Executor) -> None:
        """Set the current executor.

        Args:
            executor (Executor): The executor to set
        """
        cls._executor = executor
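
The executor slot is how AWEL hands blocking, synchronous work to a thread pool instead of stalling the event loop. A hedged wiring sketch (ThreadPoolExecutor is a standard-library choice; where your application calls this depends on how DB-GPT is embedded):

from concurrent.futures import ThreadPoolExecutor

# Hypothetical startup hook: register a pool once, before any DAG runs.
if DAGVar.get_executor() is None:
    DAGVar.set_executor(ThreadPoolExecutor(max_workers=4, thread_name_prefix="awel-sync"))
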


class DAGLifecycle:
    """The lifecycle of DAG"""
    """The lifecycle of DAG."""

    async def before_dag_run(self):
        """The callback before DAG run"""
        """Execute before DAG run."""
        pass

    async def after_dag_end(self):
        """The callback after DAG end,
        """Execute after DAG end.

        This method may be called multiple times, please make sure it is idempotent.
        """
@@ -184,6 +236,8 @@ class DAGLifecycle:


class DAGNode(DAGLifecycle, DependencyMixin, ABC):
    """The base class of DAGNode."""

    resource_group: Optional[ResourceGroup] = None
    """The resource group of current DAGNode"""

@@ -196,6 +250,17 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
        executor: Optional[Executor] = None,
        **kwargs,
    ) -> None:
        """Initialize a DAGNode.

        Args:
            dag (Optional["DAG"], optional): The DAG to add this node to.
                Defaults to None.
            node_id (Optional[str], optional): The node id. Defaults to None.
            node_name (Optional[str], optional): The node name. Defaults to None.
            system_app (Optional[SystemApp], optional): The system app.
                Defaults to None.
            executor (Optional[Executor], optional): The executor. Defaults to None.
        """
        super().__init__()
        self._upstream: List["DAGNode"] = []
        self._downstream: List["DAGNode"] = []
@@ -206,24 +271,28 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
        self._executor: Optional[Executor] = executor or DAGVar.get_executor()
        if not node_id and self._dag:
            node_id = self._dag._new_node_id()
        self._node_id: str = node_id
        self._node_name: str = node_name
        self._node_id: Optional[str] = node_id
        self._node_name: Optional[str] = node_name

    @property
    def node_id(self) -> str:
        """Return the node id of current DAGNode."""
        if not self._node_id:
            raise ValueError("Node id not set for current DAGNode")
        return self._node_id

    @property
    @abstractmethod
    def dev_mode(self) -> bool:
        """Whether current DAGNode is in dev mode"""
        """Whether current DAGNode is in dev mode."""

    @property
    def system_app(self) -> SystemApp:
    def system_app(self) -> Optional[SystemApp]:
        """Return the system app of current DAGNode."""
        return self._system_app

    def set_system_app(self, system_app: SystemApp) -> None:
        """Set system app for current DAGNode
        """Set system app for current DAGNode.

        Args:
            system_app (SystemApp): The system app
@@ -231,50 +300,97 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
        self._system_app = system_app

    def set_node_id(self, node_id: str) -> None:
        """Set node id for current DAGNode.

        Args:
            node_id (str): The node id
        """
        self._node_id = node_id

    def __hash__(self) -> int:
        """Return the hash value of current DAGNode.

        If the node_id is not None, return the hash value of node_id.
        """
        if self.node_id:
            return hash(self.node_id)
        else:
            return super().__hash__()

    def __eq__(self, other: Any) -> bool:
        """Return whether the current DAGNode is equal to other DAGNode."""
        if not isinstance(other, DAGNode):
            return False
        return self.node_id == other.node_id
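
Keying identity on node_id means two node objects with the same id collapse to a single entry in sets and dict keys, which is what the DAG bookkeeping below relies on. A tiny stand-alone demonstration of the same __hash__/__eq__ contract (the real DAGNode is abstract, so a plain class is used here):

class _IdNode:
    def __init__(self, node_id):
        self.node_id = node_id

    def __hash__(self):
        return hash(self.node_id)

    def __eq__(self, other):
        return isinstance(other, _IdNode) and self.node_id == other.node_id


n1, n2 = _IdNode("task-1"), _IdNode("task-1")
print(n1 == n2)        # True: same node_id, different objects
print(len({n1, n2}))   # 1: they share one slot in a set
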

    @property
    def node_name(self) -> str:
    def node_name(self) -> Optional[str]:
        """Return the node name of current DAGNode.

        Returns:
            Optional[str]: The node name of current DAGNode
        """
        return self._node_name

    @property
    def dag(self) -> "DAG":
    def dag(self) -> Optional["DAG"]:
        """Return the DAG of current DAGNode.

        Returns:
            Optional["DAG"]: The DAG of current DAGNode
        """
        return self._dag

    def set_upstream(self, nodes: DependencyType) -> "DAGNode":
    def set_upstream(self, nodes: DependencyType) -> None:
        """Set upstream nodes for current node.

        Args:
            nodes (DependencyType): Upstream nodes to be set to current node.
        """
        self.set_dependency(nodes)

    def set_downstream(self, nodes: DependencyType) -> "DAGNode":
    def set_downstream(self, nodes: DependencyType) -> None:
        """Set downstream nodes for current node.

        Args:
            nodes (DependencyType): Downstream nodes to be set to current node.
        """
        self.set_dependency(nodes, is_upstream=False)

    @property
    def upstream(self) -> List["DAGNode"]:
        """Return the upstream nodes of current DAGNode.

        Returns:
            List["DAGNode"]: The upstream nodes of current DAGNode
        """
        return self._upstream

    @property
    def downstream(self) -> List["DAGNode"]:
        """Return the downstream nodes of current DAGNode.

        Returns:
            List["DAGNode"]: The downstream nodes of current DAGNode
        """
        return self._downstream

    def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
        """Set dependency for current node.

        Args:
            nodes (DependencyType): The nodes to set dependency to current node.
            is_upstream (bool, optional): Whether set upstream nodes. Defaults to True.
        """
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        if not all(isinstance(node, DAGNode) for node in nodes):
            raise ValueError(
                "all nodes to set dependency to current node must be instance of 'DAGNode'"
                "all nodes to set dependency to current node must be instance "
                "of 'DAGNode'"
            )
        nodes: Sequence[DAGNode] = nodes
        dags = set([node.dag for node in nodes if node.dag])
        nodes = cast(Sequence[DAGNode], nodes)
        dags = set([node.dag for node in nodes if node.dag])  # noqa: C403
        if self.dag:
            dags.add(self.dag)
        if not dags:
@@ -302,6 +418,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
                node._upstream.append(self)

    def __repr__(self):
        """Return the representation of current DAGNode."""
        cls_name = self.__class__.__name__
        if self.node_name and self.node_name:
            return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
@@ -313,6 +430,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
        return f"{cls_name}"

    def __str__(self):
        """Return the string of current DAGNode."""
        return self.__repr__()


@@ -321,7 +439,7 @@ def _build_task_key(task_name: str, key: str) -> str:


class DAGContext:
    """The context of current DAG, created when the DAG is running
    """The context of current DAG, created when the DAG is running.

    Every DAG has been triggered will create a new DAGContext.
    """
@@ -329,22 +447,32 @@ class DAGContext:

    def __init__(
        self,
        streaming_call: bool = False,
        node_to_outputs: Dict[str, TaskContext] = None,
        node_name_to_ids: Dict[str, str] = None,
        node_to_outputs: Optional[Dict[str, TaskContext]] = None,
        node_name_to_ids: Optional[Dict[str, str]] = None,
    ) -> None:
        """Initialize a DAGContext.

        Args:
            streaming_call (bool, optional): Whether the current DAG is streaming call.
                Defaults to False.
            node_to_outputs (Optional[Dict[str, TaskContext]], optional):
                The task outputs of current DAG. Defaults to None.
            node_name_to_ids (Optional[Dict[str, str]], optional):
                The task name to task id mapping. Defaults to None.
        """
        if not node_to_outputs:
            node_to_outputs = {}
        if not node_name_to_ids:
            node_name_to_ids = {}
        self._streaming_call = streaming_call
        self._curr_task_ctx = None
        self._curr_task_ctx: Optional[TaskContext] = None
        self._share_data: Dict[str, Any] = {}
        self._node_to_outputs = node_to_outputs
        self._node_name_to_ids = node_name_to_ids
        self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
        self._node_name_to_ids: Dict[str, str] = node_name_to_ids

    @property
    def _task_outputs(self) -> Dict[str, TaskContext]:
        """The task outputs of current DAG
        """Return the task outputs of current DAG.

        Just use for internal for now.
        Returns:
@@ -354,18 +482,28 @@ class DAGContext:

    @property
    def current_task_context(self) -> TaskContext:
        """Return the current task context."""
        if not self._curr_task_ctx:
            raise RuntimeError("Current task context not set")
        return self._curr_task_ctx

    @property
    def streaming_call(self) -> bool:
        """Whether the current DAG is streaming call"""
        """Whether the current DAG is streaming call."""
        return self._streaming_call

    def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
        """Set the current task context.

        When the task is running, the current task context
        will be set to the task context.

        TODO: We should support parallel task running in the future.
        """
        self._curr_task_ctx = _curr_task_ctx

    def get_task_output(self, task_name: str) -> TaskOutput:
        """Get the task output by task name
        """Get the task output by task name.

        Args:
            task_name (str): The task name
@@ -376,22 +514,41 @@ class DAGContext:
        if task_name is None:
            raise ValueError("task_name can't be None")
        node_id = self._node_name_to_ids.get(task_name)
        if node_id:
        if not node_id:
            raise ValueError(f"Task name {task_name} not exists in DAG")
        return self._task_outputs.get(node_id).task_output
        task_output = self._task_outputs.get(node_id)
        if not task_output:
            raise ValueError(f"Task output for task {task_name} not exists")
        return task_output.task_output

    async def get_from_share_data(self, key: str) -> Any:
        """Get share data by key.

        Args:
            key (str): The share data key

        Returns:
            Any: The share data, you can cast it to the real type
        """
        return self._share_data.get(key)

    async def save_to_share_data(
        self, key: str, data: Any, overwrite: bool = False
    ) -> None:
        """Save share data by key.

        Args:
            key (str): The share data key
            data (Any): The share data
            overwrite (bool): Whether overwrite the share data if the key
                already exists. Defaults to None.
        """
        if key in self._share_data and not overwrite:
            raise ValueError(f"Share data key {key} already exists")
        self._share_data[key] = data
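
A short usage sketch of the share-data API above. In a real run the DAGContext is created by the AWEL runner and handed to operators, so constructing one by hand here is purely to exercise the calls:

import asyncio


async def _share_data_demo():
    ctx = DAGContext()
    await ctx.save_to_share_data("retrieved_chunks", ["chunk-1", "chunk-2"])
    print(await ctx.get_from_share_data("retrieved_chunks"))
    # Re-saving the same key raises ValueError unless overwrite=True is passed.
    await ctx.save_to_share_data("retrieved_chunks", ["chunk-3"], overwrite=True)


asyncio.run(_share_data_demo())
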

    async def get_task_share_data(self, task_name: str, key: str) -> Any:
        """Get share data by task name and key
        """Get share data by task name and key.

        Args:
            task_name (str): The task name
@@ -409,14 +566,14 @@ class DAGContext:
    async def save_task_share_data(
        self, task_name: str, key: str, data: Any, overwrite: bool = False
    ) -> None:
        """Save share data by task name and key
        """Save share data by task name and key.

        Args:
            task_name (str): The task name
            key (str): The share data key
            data (Any): The share data
            overwrite (bool): Whether overwrite the share data if the key already exists.
                Defaults to None.
            overwrite (bool): Whether overwrite the share data if the key
                already exists. Defaults to None.

        Raises:
            ValueError: If the share data key already exists and overwrite is not True
@@ -429,15 +586,22 @@ class DAGContext:


class DAG:
    """The DAG class.

    Manage the DAG nodes and the relationship between them.
    """

    def __init__(
        self, dag_id: str, resource_group: Optional[ResourceGroup] = None
    ) -> None:
        """Initialize a DAG."""
        self._dag_id = dag_id
        self.node_map: Dict[str, DAGNode] = {}
        self.node_name_to_node: Dict[str, DAGNode] = {}
        self._root_nodes: List[DAGNode] = None
        self._leaf_nodes: List[DAGNode] = None
        self._trigger_nodes: List[DAGNode] = None
        self._root_nodes: List[DAGNode] = []
        self._leaf_nodes: List[DAGNode] = []
        self._trigger_nodes: List[DAGNode] = []
        self._resource_group: Optional[ResourceGroup] = resource_group

    def _append_node(self, node: DAGNode) -> None:
        if node.node_id in self.node_map:
@@ -448,22 +612,26 @@ class DAG:
                f"Node name {node.node_name} already exists in DAG {self.dag_id}"
            )
        self.node_name_to_node[node.node_name] = node
        self.node_map[node.node_id] = node
        node_id = node.node_id
        if not node_id:
            raise ValueError("Node id can't be None")
        self.node_map[node_id] = node
        # clear cached nodes
        self._root_nodes = None
        self._leaf_nodes = None
        self._root_nodes = []
        self._leaf_nodes = []

    def _new_node_id(self) -> str:
        return str(uuid.uuid4())

    @property
    def dag_id(self) -> str:
        """Return the dag id of current DAG."""
        return self._dag_id

    def _build(self) -> None:
        from ..operator.common_operator import TriggerOperator

        nodes = set()
        nodes: Set[DAGNode] = set()
        for _, node in self.node_map.items():
            nodes = nodes.union(_get_nodes(node))
        self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
@@ -474,7 +642,7 @@ class DAG:

    @property
    def root_nodes(self) -> List[DAGNode]:
        """The root nodes of current DAG
        """Return the root nodes of current DAG.

        Returns:
            List[DAGNode]: The root nodes of current DAG, no repeat
@@ -485,7 +653,7 @@ class DAG:

    @property
    def leaf_nodes(self) -> List[DAGNode]:
        """The leaf nodes of current DAG
        """Return the leaf nodes of current DAG.

        Returns:
            List[DAGNode]: The leaf nodes of current DAG, no repeat
@@ -496,7 +664,7 @@ class DAG:

    @property
    def trigger_nodes(self) -> List[DAGNode]:
        """The trigger nodes of current DAG
        """Return the trigger nodes of current DAG.

        Returns:
            List[DAGNode]: The trigger nodes of current DAG, no repeat
@@ -506,34 +674,42 @@ class DAG:
        return self._trigger_nodes

    async def _after_dag_end(self) -> None:
        """The callback after DAG end"""
        """Execute after DAG end."""
        tasks = []
        for node in self.node_map.values():
            tasks.append(node.after_dag_end())
        await asyncio.gather(*tasks)

    def print_tree(self) -> None:
        """Print the DAG tree"""
        """Print the DAG tree"""  # noqa: D400
        _print_format_dag_tree(self)

    def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
        """Create the DAG graph"""
        """Visualize the DAG.

        Args:
            view (bool, optional): Whether view the DAG graph. Defaults to True,
                if True, it will open the graph file with your default viewer.
        """
        self.print_tree()
        return _visualize_dag(self, view=view, **kwargs)

    def __enter__(self):
        """Enter a DAG context."""
        DAGVar.enter_dag(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit a DAG context."""
        DAGVar.exit_dag()

    def __repr__(self):
        """Return the representation of current DAG."""
        return f"DAG(dag_id={self.dag_id})"


def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
    nodes = set()
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]:
    nodes: Set[DAGNode] = set()
    if not node:
        return nodes
    nodes.add(node)
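
The hunk above is truncated, but the traversal being annotated is a plain recursive walk over upstream or downstream links. A stand-alone sketch of that idea (not the library's exact body):

def _collect_nodes(node, is_upstream=True, seen=None):
    seen = set() if seen is None else seen
    if node is None or node in seen:
        return seen
    seen.add(node)
    neighbours = node.upstream if is_upstream else node.downstream
    for nxt in neighbours:
        _collect_nodes(nxt, is_upstream, seen)
    return seen
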
@@ -553,7 +729,7 @@ def _print_dag(
    level: int = 0,
    prefix: str = "",
    last: bool = True,
    level_dict: Dict[str, Any] = None,
    level_dict: Optional[Dict[int, Any]] = None,
):
    if level_dict is None:
        level_dict = {}
@@ -606,7 +782,7 @@ def _handle_dag_nodes(


def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
    """Visualize the DAG
    """Visualize the DAG.

    Args:
        dag (DAG): The DAG to visualize
@@ -641,7 +817,7 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
        filename = kwargs["filename"]
        del kwargs["filename"]

    if not "directory" in kwargs:
    if "directory" not in kwargs:
        from dbgpt.configs.model_config import LOGDIR

        kwargs["directory"] = LOGDIR