chore: Add pylint for DB-GPT core lib (#1076)

Fangyin Cheng
2024-01-16 17:36:26 +08:00
committed by GitHub
parent 3a54d1ef9a
commit 40c853575a
79 changed files with 2213 additions and 839 deletions


@@ -1,3 +1,7 @@
"""The base module of DAG.
DAG is the core component of AWEL; it is used to define the relationships between tasks.
"""
import asyncio
import contextvars
import logging
@@ -6,7 +10,7 @@ import uuid
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast
from dbgpt.component import SystemApp
@@ -27,86 +31,108 @@ def _is_async_context():
class DependencyMixin(ABC):
"""The mixin class for DAGNode.
This class defines the interface for setting upstream and downstream nodes.
It also implements the operators << and >> for setting upstream
and downstream nodes.
"""
@abstractmethod
def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
def set_upstream(self, nodes: DependencyType) -> None:
"""Set one or more upstream nodes for this node.
Args:
nodes (DependencyType): Upstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
ValueError: If no upstream nodes are provided or if an argument is
not a DependencyMixin.
"""
@abstractmethod
def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
def set_downstream(self, nodes: DependencyType) -> None:
"""Set one or more downstream nodes for this node.
Args:
nodes (DependencyType): Downstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
ValueError: If no downstream nodes are provided or if an argument is
not a DependencyMixin.
"""
def __lshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self << nodes
"""Set upstream nodes for current node.
Implements: self << nodes.
Example:
.. code-block:: python
# means node.set_upstream(input_node)
node << input_node
# means node2.set_upstream([input_node])
node2 << [input_node]
"""
self.set_upstream(nodes)
return nodes
def __rshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self >> nodes
"""Set downstream nodes for current node.
Example:
Implements: self >> nodes.
Examples:
.. code-block:: python
# means node.set_downstream(next_node)
node >> next_node
# means node2.set_downstream([next_node])
node2 >> [next_node]
"""
self.set_downstream(nodes)
return nodes
def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] >> self"""
"""Set upstream nodes for current node.
Implements: [node] >> self
"""
self.__lshift__(nodes)
return self
def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] << self"""
"""Set downstream nodes for current node.
Implements: [node] << self
"""
self.__rshift__(nodes)
return self
class DAGVar:
"""The DAGVar is used to store the current DAG context."""
_thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
_system_app: SystemApp = None
_executor: Executor = None
_async_local: contextvars.ContextVar = contextvars.ContextVar(
"current_dag_stack", default=deque()
)
_system_app: Optional[SystemApp] = None
# The executor for the current DAG; it is used to run some sync tasks in an async DAG
_executor: Optional[Executor] = None
@classmethod
def enter_dag(cls, dag) -> None:
"""Enter a DAG context.
Args:
dag (DAG): The DAG to enter
"""
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
@@ -119,6 +145,7 @@ class DAGVar:
@classmethod
def exit_dag(cls) -> None:
"""Exit a DAG context."""
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
@@ -134,6 +161,11 @@ class DAGVar:
@classmethod
def get_current_dag(cls) -> Optional["DAG"]:
"""Get the current DAG.
Returns:
Optional[DAG]: The current DAG
"""
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
@@ -147,36 +179,56 @@ class DAGVar:
return None
@classmethod
def get_current_system_app(cls) -> SystemApp:
def get_current_system_app(cls) -> Optional[SystemApp]:
"""Get the current system app.
Returns:
Optional[SystemApp]: The current system app
"""
# if not cls._system_app:
# raise RuntimeError("System APP not set for DAGVar")
return cls._system_app
@classmethod
def set_current_system_app(cls, system_app: SystemApp) -> None:
"""Set the current system app.
Args:
system_app (SystemApp): The system app to set
"""
if cls._system_app:
logger.warn("System APP has already set, nothing to do")
logger.warning("System APP has already set, nothing to do")
else:
cls._system_app = system_app
@classmethod
def get_executor(cls) -> Executor:
def get_executor(cls) -> Optional[Executor]:
"""Get the current executor.
Returns:
Optional[Executor]: The current executor
"""
return cls._executor
@classmethod
def set_executor(cls, executor: Executor) -> None:
"""Set the current executor.
Args:
executor (Executor): The executor to set
"""
cls._executor = executor
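A brief usage sketch of the DAGVar stack shown above. The import path is an assumption based on this file's location; in normal use the DAG context manager calls enter_dag/exit_dag for you.

# Hedged sketch: exercising DAGVar directly; normally `with DAG(...):` does this.
from dbgpt.core.awel.dag.base import DAG, DAGVar  # assumed import path

dag = DAG("example_dag")
DAGVar.enter_dag(dag)  # push onto the thread-local / contextvar stack
try:
    assert DAGVar.get_current_dag() is dag
finally:
    DAGVar.exit_dag()  # pop the stack
assert DAGVar.get_current_dag() is None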
class DAGLifecycle:
"""The lifecycle of DAG"""
"""The lifecycle of DAG."""
async def before_dag_run(self):
"""The callback before DAG run"""
"""Execute before DAG run."""
pass
async def after_dag_end(self):
"""The callback after DAG end,
"""Execute after DAG end.
This method may be called multiple times; please make sure it is idempotent.
"""
@@ -184,6 +236,8 @@ class DAGLifecycle:
class DAGNode(DAGLifecycle, DependencyMixin, ABC):
"""The base class of DAGNode."""
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
@@ -196,6 +250,17 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
executor: Optional[Executor] = None,
**kwargs,
) -> None:
"""Initialize a DAGNode.
Args:
dag (Optional["DAG"], optional): The DAG to add this node to.
Defaults to None.
node_id (Optional[str], optional): The node id. Defaults to None.
node_name (Optional[str], optional): The node name. Defaults to None.
system_app (Optional[SystemApp], optional): The system app.
Defaults to None.
executor (Optional[Executor], optional): The executor. Defaults to None.
"""
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
@@ -206,24 +271,28 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
self._executor: Optional[Executor] = executor or DAGVar.get_executor()
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id
self._node_name: str = node_name
self._node_id: Optional[str] = node_id
self._node_name: Optional[str] = node_name
@property
def node_id(self) -> str:
"""Return the node id of current DAGNode."""
if not self._node_id:
raise ValueError("Node id not set for current DAGNode")
return self._node_id
@property
@abstractmethod
def dev_mode(self) -> bool:
"""Whether current DAGNode is in dev mode"""
"""Whether current DAGNode is in dev mode."""
@property
def system_app(self) -> SystemApp:
def system_app(self) -> Optional[SystemApp]:
"""Return the system app of current DAGNode."""
return self._system_app
def set_system_app(self, system_app: SystemApp) -> None:
"""Set system app for current DAGNode
"""Set system app for current DAGNode.
Args:
system_app (SystemApp): The system app
@@ -231,50 +300,97 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
self._system_app = system_app
def set_node_id(self, node_id: str) -> None:
"""Set node id for current DAGNode.
Args:
node_id (str): The node id
"""
self._node_id = node_id
def __hash__(self) -> int:
"""Return the hash value of current DAGNode.
If the node_id is not None, return the hash value of node_id.
"""
if self.node_id:
return hash(self.node_id)
else:
return super().__hash__()
def __eq__(self, other: Any) -> bool:
"""Return whether the current DAGNode is equal to other DAGNode."""
if not isinstance(other, DAGNode):
return False
return self.node_id == other.node_id
@property
def node_name(self) -> str:
def node_name(self) -> Optional[str]:
"""Return the node name of current DAGNode.
Returns:
Optional[str]: The node name of current DAGNode
"""
return self._node_name
@property
def dag(self) -> "DAG":
def dag(self) -> Optional["DAG"]:
"""Return the DAG of current DAGNode.
Returns:
Optional["DAG"]: The DAG of current DAGNode
"""
return self._dag
def set_upstream(self, nodes: DependencyType) -> "DAGNode":
def set_upstream(self, nodes: DependencyType) -> None:
"""Set upstream nodes for current node.
Args:
nodes (DependencyType): Upstream nodes to be set to current node.
"""
self.set_dependency(nodes)
def set_downstream(self, nodes: DependencyType) -> "DAGNode":
def set_downstream(self, nodes: DependencyType) -> None:
"""Set downstream nodes for current node.
Args:
nodes (DependencyType): Downstream nodes to be set to current node.
"""
self.set_dependency(nodes, is_upstream=False)
@property
def upstream(self) -> List["DAGNode"]:
"""Return the upstream nodes of current DAGNode.
Returns:
List["DAGNode"]: The upstream nodes of current DAGNode
"""
return self._upstream
@property
def downstream(self) -> List["DAGNode"]:
"""Return the downstream nodes of current DAGNode.
Returns:
List["DAGNode"]: The downstream nodes of current DAGNode
"""
return self._downstream
def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
"""Set dependency for current node.
Args:
nodes (DependencyType): The nodes to set dependency to current node.
is_upstream (bool, optional): Whether to set upstream nodes. Defaults to True.
"""
if not isinstance(nodes, Sequence):
nodes = [nodes]
if not all(isinstance(node, DAGNode) for node in nodes):
raise ValueError(
"all nodes to set dependency to current node must be instance of 'DAGNode'"
"all nodes to set dependency to current node must be instance "
"of 'DAGNode'"
)
nodes: Sequence[DAGNode] = nodes
dags = set([node.dag for node in nodes if node.dag])
nodes = cast(Sequence[DAGNode], nodes)
dags = set([node.dag for node in nodes if node.dag]) # noqa: C403
if self.dag:
dags.add(self.dag)
if not dags:
@@ -302,6 +418,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
node._upstream.append(self)
def __repr__(self):
"""Return the representation of current DAGNode."""
cls_name = self.__class__.__name__
if self.node_id and self.node_name:
return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
@@ -313,6 +430,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
return f"{cls_name}"
def __str__(self):
"""Return the string of current DAGNode."""
return self.__repr__()
@@ -321,7 +439,7 @@ def _build_task_key(task_name: str, key: str) -> str:
class DAGContext:
"""The context of current DAG, created when the DAG is running
"""The context of current DAG, created when the DAG is running.
Every time the DAG is triggered, a new DAGContext is created.
"""
@@ -329,22 +447,32 @@ class DAGContext:
def __init__(
self,
streaming_call: bool = False,
node_to_outputs: Dict[str, TaskContext] = None,
node_name_to_ids: Dict[str, str] = None,
node_to_outputs: Optional[Dict[str, TaskContext]] = None,
node_name_to_ids: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize a DAGContext.
Args:
streaming_call (bool, optional): Whether the current DAG is a streaming call.
Defaults to False.
node_to_outputs (Optional[Dict[str, TaskContext]], optional):
The task outputs of current DAG. Defaults to None.
node_name_to_ids (Optional[Dict[str, str]], optional):
The task name to task id mapping. Defaults to None.
"""
if not node_to_outputs:
node_to_outputs = {}
if not node_name_to_ids:
node_name_to_ids = {}
self._streaming_call = streaming_call
self._curr_task_ctx = None
self._curr_task_ctx: Optional[TaskContext] = None
self._share_data: Dict[str, Any] = {}
self._node_to_outputs = node_to_outputs
self._node_name_to_ids = node_name_to_ids
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
@property
def _task_outputs(self) -> Dict[str, TaskContext]:
"""The task outputs of current DAG
"""Return the task outputs of current DAG.
For internal use only, for now.
Returns:
@@ -354,18 +482,28 @@ class DAGContext:
@property
def current_task_context(self) -> TaskContext:
"""Return the current task context."""
if not self._curr_task_ctx:
raise RuntimeError("Current task context not set")
return self._curr_task_ctx
@property
def streaming_call(self) -> bool:
"""Whether the current DAG is streaming call"""
"""Whether the current DAG is streaming call."""
return self._streaming_call
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
"""Set the current task context.
When a task is running, its task context is set as
the current task context.
TODO: We should support parallel task running in the future.
"""
self._curr_task_ctx = _curr_task_ctx
def get_task_output(self, task_name: str) -> TaskOutput:
"""Get the task output by task name
"""Get the task output by task name.
Args:
task_name (str): The task name
@@ -376,22 +514,41 @@ class DAGContext:
if task_name is None:
raise ValueError("task_name can't be None")
node_id = self._node_name_to_ids.get(task_name)
if node_id:
if not node_id:
raise ValueError(f"Task name {task_name} not exists in DAG")
return self._task_outputs.get(node_id).task_output
task_output = self._task_outputs.get(node_id)
if not task_output:
raise ValueError(f"Task output for task {task_name} not exists")
return task_output.task_output
async def get_from_share_data(self, key: str) -> Any:
"""Get share data by key.
Args:
key (str): The share data key
Returns:
Any: The share data; you can cast it to the real type
"""
return self._share_data.get(key)
async def save_to_share_data(
self, key: str, data: Any, overwrite: bool = False
) -> None:
"""Save share data by key.
Args:
key (str): The share data key
data (Any): The share data
overwrite (bool): Whether to overwrite the share data if the key
already exists. Defaults to False.
"""
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
self._share_data[key] = data
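A short sketch of the share-data API above. The import path is assumed; in practice the DAG runtime constructs the DAGContext and hands it to operators.

# Hedged sketch of DAGContext share data; normally the runtime builds the context.
import asyncio

from dbgpt.core.awel.dag.base import DAGContext  # assumed import path


async def _demo() -> None:
    ctx = DAGContext()
    await ctx.save_to_share_data("user_id", 42)
    assert await ctx.get_from_share_data("user_id") == 42
    # Saving the same key again requires overwrite=True, otherwise ValueError is raised
    await ctx.save_to_share_data("user_id", 43, overwrite=True)


asyncio.run(_demo())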
async def get_task_share_data(self, task_name: str, key: str) -> Any:
"""Get share data by task name and key
"""Get share data by task name and key.
Args:
task_name (str): The task name
@@ -409,14 +566,14 @@ class DAGContext:
async def save_task_share_data(
self, task_name: str, key: str, data: Any, overwrite: bool = False
) -> None:
"""Save share data by task name and key
"""Save share data by task name and key.
Args:
task_name (str): The task name
key (str): The share data key
data (Any): The share data
overwrite (bool): Whether overwrite the share data if the key already exists.
Defaults to None.
overwrite (bool): Whether to overwrite the share data if the key
already exists. Defaults to False.
Raises:
ValueError: If the share data key already exists and overwrite is not True
@@ -429,15 +586,22 @@ class DAGContext:
class DAG:
"""The DAG class.
Manage the DAG nodes and the relationship between them.
"""
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None:
"""Initialize a DAG."""
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {}
self.node_name_to_node: Dict[str, DAGNode] = {}
self._root_nodes: List[DAGNode] = None
self._leaf_nodes: List[DAGNode] = None
self._trigger_nodes: List[DAGNode] = None
self._root_nodes: List[DAGNode] = []
self._leaf_nodes: List[DAGNode] = []
self._trigger_nodes: List[DAGNode] = []
self._resource_group: Optional[ResourceGroup] = resource_group
def _append_node(self, node: DAGNode) -> None:
if node.node_id in self.node_map:
@@ -448,22 +612,26 @@ class DAG:
f"Node name {node.node_name} already exists in DAG {self.dag_id}"
)
self.node_name_to_node[node.node_name] = node
self.node_map[node.node_id] = node
node_id = node.node_id
if not node_id:
raise ValueError("Node id can't be None")
self.node_map[node_id] = node
# clear cached nodes
self._root_nodes = None
self._leaf_nodes = None
self._root_nodes = []
self._leaf_nodes = []
def _new_node_id(self) -> str:
return str(uuid.uuid4())
@property
def dag_id(self) -> str:
"""Return the dag id of current DAG."""
return self._dag_id
def _build(self) -> None:
from ..operator.common_operator import TriggerOperator
nodes = set()
nodes: Set[DAGNode] = set()
for _, node in self.node_map.items():
nodes = nodes.union(_get_nodes(node))
self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
@@ -474,7 +642,7 @@ class DAG:
@property
def root_nodes(self) -> List[DAGNode]:
"""The root nodes of current DAG
"""Return the root nodes of current DAG.
Returns:
List[DAGNode]: The root nodes of current DAG, no repeat
@@ -485,7 +653,7 @@ class DAG:
@property
def leaf_nodes(self) -> List[DAGNode]:
"""The leaf nodes of current DAG
"""Return the leaf nodes of current DAG.
Returns:
List[DAGNode]: The leaf nodes of current DAG, no repeat
@@ -496,7 +664,7 @@ class DAG:
@property
def trigger_nodes(self) -> List[DAGNode]:
"""The trigger nodes of current DAG
"""Return the trigger nodes of current DAG.
Returns:
List[DAGNode]: The trigger nodes of current DAG, no repeat
@@ -506,34 +674,42 @@ class DAG:
return self._trigger_nodes
async def _after_dag_end(self) -> None:
"""The callback after DAG end"""
"""Execute after DAG end."""
tasks = []
for node in self.node_map.values():
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)
def print_tree(self) -> None:
"""Print the DAG tree"""
"""Print the DAG tree""" # noqa: D400
_print_format_dag_tree(self)
def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
"""Create the DAG graph"""
"""Visualize the DAG.
Args:
view (bool, optional): Whether to view the DAG graph. Defaults to True;
if True, it will open the graph file with your default viewer.
"""
self.print_tree()
return _visualize_dag(self, view=view, **kwargs)
def __enter__(self):
"""Enter a DAG context."""
DAGVar.enter_dag(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit a DAG context."""
DAGVar.exit_dag()
def __repr__(self):
"""Return the representation of current DAG."""
return f"DAG(dag_id={self.dag_id})"
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
nodes = set()
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]:
nodes: Set[DAGNode] = set()
if not node:
return nodes
nodes.add(node)
@@ -553,7 +729,7 @@ def _print_dag(
level: int = 0,
prefix: str = "",
last: bool = True,
level_dict: Dict[str, Any] = None,
level_dict: Optional[Dict[int, Any]] = None,
):
if level_dict is None:
level_dict = {}
@@ -606,7 +782,7 @@ def _handle_dag_nodes(
def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
"""Visualize the DAG
"""Visualize the DAG.
Args:
dag (DAG): The DAG to visualize
@@ -641,7 +817,7 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
filename = kwargs["filename"]
del kwargs["filename"]
if not "directory" in kwargs:
if "directory" not in kwargs:
from dbgpt.configs.model_config import LOGDIR
kwargs["directory"] = LOGDIR