DB-GPT/dbgpt/core/awel/dag/base.py
2024-08-29 19:39:42 +08:00

1086 lines
34 KiB
Python

"""The base module of DAG.
DAG is the core component of AWEL, it is used to define the relationship between tasks.
"""
import asyncio
import contextvars
import dataclasses
import logging
import threading
import uuid
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Union,
cast,
)
from dbgpt.component import SystemApp
from ..flow.base import ViewMixin
from ..resource.base import ResourceGroup
from ..task.base import TaskContext, TaskOutput
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Argument shape accepted by the dependency setters: one node or a sequence of nodes.
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
if TYPE_CHECKING:
from ...interface.variables import VariablesProvider
def _is_async_context():
    """Tell whether the caller is executing inside an asyncio task.

    Returns:
        bool: True when an event loop is running in this thread and a task is
            currently being executed by it, False otherwise.
    """
    try:
        # get_running_loop() raises RuntimeError when no loop is running.
        running_task = asyncio.current_task(loop=asyncio.get_running_loop())
    except RuntimeError:
        return False
    return running_task is not None
class DependencyMixin(ABC):
    """Interface for objects that can be wired together with ``<<`` and ``>>``.

    Concrete classes implement :meth:`set_upstream` and :meth:`set_downstream`;
    the shift operators defined here are thin wrappers around those two
    methods.
    """

    @abstractmethod
    def set_upstream(self, nodes: DependencyType) -> None:
        """Register one or more nodes as upstream of this node.

        Args:
            nodes (DependencyType): A single node or a sequence of nodes that
                should run before this node.

        Raises:
            ValueError: If no upstream nodes are provided or if an argument is
                not a DependencyMixin.
        """

    @abstractmethod
    def set_downstream(self, nodes: DependencyType) -> None:
        """Register one or more nodes as downstream of this node.

        Args:
            nodes (DependencyType): A single node or a sequence of nodes that
                should run after this node.

        Raises:
            ValueError: If no downstream nodes are provided or if an argument is
                not a DependencyMixin.
        """

    def __lshift__(self, nodes: DependencyType) -> DependencyType:
        """Handle ``self << nodes`` by making ``nodes`` upstream of ``self``.

        Example:
            .. code-block:: python

                # same as node.set_upstream(input_node)
                node << input_node
                # same as node2.set_upstream([input_node])
                node2 << [input_node]

        Returns:
            DependencyType: The right-hand operand, so expressions can be chained.
        """
        self.set_upstream(nodes)
        return nodes

    def __rshift__(self, nodes: DependencyType) -> DependencyType:
        """Handle ``self >> nodes`` by making ``nodes`` downstream of ``self``.

        Example:
            .. code-block:: python

                # same as node.set_downstream(next_node)
                node >> next_node
                # same as node2.set_downstream([next_node])
                node2 >> [next_node]

        Returns:
            DependencyType: The right-hand operand, so expressions can be chained.
        """
        self.set_downstream(nodes)
        return nodes

    def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Handle ``[node] >> self`` by making ``nodes`` upstream of ``self``.

        Returns:
            DependencyMixin: This node.
        """
        # Delegates to __lshift__ so that subclass overrides are honoured.
        self.__lshift__(nodes)
        return self

    def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Handle ``[node] << self`` by making ``nodes`` downstream of ``self``.

        Returns:
            DependencyMixin: This node.
        """
        # Delegates to __rshift__ so that subclass overrides are honoured.
        self.__rshift__(nodes)
        return self
class DAGVar:
    """Hold the "current DAG" stack plus a few process-wide DAG settings.

    The current-DAG stack is kept in a ``contextvars.ContextVar`` when the
    caller runs inside an asyncio task and in a ``threading.local`` otherwise.
    The system app, executor and variables provider are plain class
    attributes shared by everything that reads them through this class.
    """

    _thread_local = threading.local()
    # NOTE: the default deque is one object shared by every context. It must
    # never be mutated in place; enter_dag/exit_dag publish a modified copy so
    # that one task's DAG stack cannot leak into unrelated tasks.
    _async_local: contextvars.ContextVar = contextvars.ContextVar(
        "current_dag_stack", default=deque()
    )
    _system_app: Optional[SystemApp] = None
    # The executor for current DAG, this is used run some sync tasks in async DAG
    _executor: Optional[Executor] = None
    _variables_provider: Optional["VariablesProvider"] = None

    @classmethod
    def enter_dag(cls, dag) -> None:
        """Push ``dag`` onto the current-DAG stack.

        Args:
            dag (DAG): The DAG to enter
        """
        if _is_async_context():
            # Copy-on-write: never touch the deque another context may hold.
            stack = deque(cls._async_local.get())
            stack.append(dag)
            cls._async_local.set(stack)
        else:
            if not hasattr(cls._thread_local, "current_dag_stack"):
                cls._thread_local.current_dag_stack = deque()
            cls._thread_local.current_dag_stack.append(dag)

    @classmethod
    def exit_dag(cls) -> None:
        """Pop the innermost DAG from the current-DAG stack, if there is one."""
        if _is_async_context():
            stack = cls._async_local.get()
            if stack:
                # Copy-on-write, mirroring enter_dag.
                stack = deque(stack)
                stack.pop()
                cls._async_local.set(stack)
        else:
            if (
                hasattr(cls._thread_local, "current_dag_stack")
                and cls._thread_local.current_dag_stack
            ):
                cls._thread_local.current_dag_stack.pop()

    @classmethod
    def get_current_dag(cls) -> Optional["DAG"]:
        """Return the innermost DAG that has been entered.

        Returns:
            Optional[DAG]: The current DAG, or None when no DAG is active.
        """
        if _is_async_context():
            stack = cls._async_local.get()
            return stack[-1] if stack else None
        if (
            hasattr(cls._thread_local, "current_dag_stack")
            and cls._thread_local.current_dag_stack
        ):
            return cls._thread_local.current_dag_stack[-1]
        return None

    @classmethod
    def get_current_system_app(cls) -> Optional[SystemApp]:
        """Return the registered system app.

        Returns:
            Optional[SystemApp]: The system app, or None when none has been
                set (no error is raised in that case).
        """
        return cls._system_app

    @classmethod
    def set_current_system_app(cls, system_app: SystemApp) -> None:
        """Register the system app; only the first registration takes effect.

        Args:
            system_app (SystemApp): The system app to set
        """
        if cls._system_app:
            logger.warning("System APP has already set, nothing to do")
        else:
            cls._system_app = system_app

    @classmethod
    def get_executor(cls) -> Optional[Executor]:
        """Return the registered executor.

        Returns:
            Optional[Executor]: The executor, or None when none has been set.
        """
        return cls._executor

    @classmethod
    def set_executor(cls, executor: Executor) -> None:
        """Register the executor, replacing any previous one.

        Args:
            executor (Executor): The executor to set
        """
        cls._executor = executor

    @classmethod
    def get_variables_provider(cls) -> Optional["VariablesProvider"]:
        """Return the registered variables provider.

        Returns:
            Optional[VariablesProvider]: The provider, or None when none has
                been set.
        """
        return cls._variables_provider

    @classmethod
    def set_variables_provider(cls, variables_provider: "VariablesProvider") -> None:
        """Register the variables provider, replacing any previous one.

        Args:
            variables_provider (VariablesProvider): The variables provider to set
        """
        cls._variables_provider = variables_provider
class DAGLifecycle:
    """Lifecycle hooks of a DAG; both hooks are no-ops in this base class."""

    async def before_dag_run(self):
        """Hook intended to run before the DAG starts; does nothing here."""

    async def after_dag_end(self, event_loop_task_id: int):
        """Hook intended to run after the DAG has ended; does nothing here.

        This method may be called multiple times, please make sure it is
        idempotent.

        Args:
            event_loop_task_id (int): Identifier of the event loop task the DAG
                run is associated with.
        """
class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
    """The base class of DAGNode.

    A node keeps references to its upstream and downstream nodes and to the
    DAG it belongs to. Nodes that have an id are hashed and compared by that
    id; nodes without an id fall back to object identity.
    """

    # The resource group of current DAGNode
    resource_group: Optional[ResourceGroup] = None

    def __init__(
        self,
        dag: Optional["DAG"] = None,
        node_id: Optional[str] = None,
        node_name: Optional[str] = None,
        system_app: Optional[SystemApp] = None,
        executor: Optional[Executor] = None,
        **kwargs,
    ) -> None:
        """Initialize a DAGNode.

        Args:
            dag (Optional["DAG"], optional): The DAG to add this node to.
                Defaults to None, in which case the current DAG from DAGVar
                is used if there is one.
            node_id (Optional[str], optional): The node id. Defaults to None,
                in which case the DAG (if any) generates one.
            node_name (Optional[str], optional): The node name. Defaults to None.
            system_app (Optional[SystemApp], optional): The system app.
                Defaults to None, falling back to the one stored in DAGVar.
            executor (Optional[Executor], optional): The executor. Defaults to
                None, falling back to the one stored in DAGVar.
        """
        super().__init__()
        self._upstream: List["DAGNode"] = []
        self._downstream: List["DAGNode"] = []
        self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
        self._system_app: Optional[SystemApp] = (
            system_app or DAGVar.get_current_system_app()
        )
        self._executor: Optional[Executor] = executor or DAGVar.get_executor()
        if not node_id and self._dag:
            node_id = self._dag._new_node_id()
        self._node_id: Optional[str] = node_id
        self._node_name: Optional[str] = node_name
        if self._dag:
            self._dag._append_node(self)

    @property
    def node_id(self) -> str:
        """Return the node id of current DAGNode.

        Raises:
            ValueError: If no node id has been set.
        """
        if not self._node_id:
            raise ValueError("Node id not set for current DAGNode")
        return self._node_id

    @property
    @abstractmethod
    def dev_mode(self) -> bool:
        """Whether current DAGNode is in dev mode."""

    @property
    def system_app(self) -> Optional[SystemApp]:
        """Return the system app of current DAGNode."""
        return self._system_app

    def set_system_app(self, system_app: SystemApp) -> None:
        """Set system app for current DAGNode.

        Args:
            system_app (SystemApp): The system app
        """
        self._system_app = system_app

    def set_node_id(self, node_id: str) -> None:
        """Set node id for current DAGNode.

        Args:
            node_id (str): The node id
        """
        self._node_id = node_id

    def __hash__(self) -> int:
        """Return the hash value of current DAGNode.

        If the node id is set, return the hash of the node id; otherwise fall
        back to the inherited hash.
        """
        # Read the raw attribute: the ``node_id`` property raises when the id
        # is missing, which would make the fallback below unreachable.
        if self._node_id:
            return hash(self._node_id)
        return super().__hash__()

    def __eq__(self, other: Any) -> bool:
        """Return whether the current DAGNode is equal to other DAGNode.

        Two nodes are equal when they are the same object or when both have an
        id and the ids match. A node without an id only equals itself, which
        keeps equality consistent with :meth:`__hash__`.
        """
        if self is other:
            return True
        if not isinstance(other, DAGNode):
            return False
        if not self._node_id or not other._node_id:
            return False
        return self._node_id == other._node_id

    @property
    def node_name(self) -> Optional[str]:
        """Return the node name of current DAGNode.

        Returns:
            Optional[str]: The node name of current DAGNode
        """
        return self._node_name

    @property
    def dag(self) -> Optional["DAG"]:
        """Return the DAG of current DAGNode.

        Returns:
            Optional["DAG"]: The DAG of current DAGNode
        """
        return self._dag

    def set_upstream(self, nodes: DependencyType) -> None:
        """Set upstream nodes for current node.

        Args:
            nodes (DependencyType): Upstream nodes to be set to current node.
        """
        self.set_dependency(nodes)

    def set_downstream(self, nodes: DependencyType) -> None:
        """Set downstream nodes for current node.

        Args:
            nodes (DependencyType): Downstream nodes to be set to current node.
        """
        self.set_dependency(nodes, is_upstream=False)

    @property
    def upstream(self) -> List["DAGNode"]:
        """Return the upstream nodes of current DAGNode.

        Returns:
            List["DAGNode"]: The upstream nodes of current DAGNode
        """
        return self._upstream

    @property
    def downstream(self) -> List["DAGNode"]:
        """Return the downstream nodes of current DAGNode.

        Returns:
            List["DAGNode"]: The downstream nodes of current DAGNode
        """
        return self._downstream

    def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
        """Set dependency for current node.

        All involved nodes end up attached to the single DAG shared by this
        node and ``nodes``. Linking a node that is already linked in the
        requested direction is a no-op.

        Args:
            nodes (DependencyType): The nodes to set dependency to current node.
            is_upstream (bool, optional): Whether set upstream nodes. Defaults to True.

        Raises:
            ValueError: If any item is not a DAGNode, if none of the nodes
                belongs to a DAG, or if the nodes belong to different DAGs.
        """
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        if not all(isinstance(node, DAGNode) for node in nodes):
            raise ValueError(
                "all nodes to set dependency to current node must be instance "
                "of 'DAGNode'"
            )
        nodes = cast(Sequence[DAGNode], nodes)
        dags = {node.dag for node in nodes if node.dag}
        if self.dag:
            dags.add(self.dag)
        if not dags:
            raise ValueError("set dependency to current node must in a DAG context")
        if len(dags) != 1:
            raise ValueError(
                "set dependency to current node just support in one DAG context"
            )
        dag = dags.pop()
        self._dag = dag

        dag._append_node(self)
        for node in nodes:
            # The direction is decided first, so an already-linked upstream can
            # never fall through and be linked as a downstream by mistake.
            if is_upstream:
                if node in self._upstream:
                    continue
                node._dag = dag
                dag._append_node(node)

                self._upstream.append(node)
                node._downstream.append(self)
            elif node not in self._downstream:
                node._dag = dag
                dag._append_node(node)

                self._downstream.append(node)
                node._upstream.append(self)

    def __repr__(self):
        """Return the representation of current DAGNode."""
        cls_name = self.__class__.__name__
        # Raw attribute on purpose: repr must not raise for a node without id.
        node_id = self._node_id
        node_name = self.node_name
        if node_id and node_name:
            return f"{cls_name}(node_id={node_id}, node_name={node_name})"
        if node_id:
            return f"{cls_name}(node_id={node_id})"
        if node_name:
            return f"{cls_name}(node_name={node_name})"
        return f"{cls_name}"

    @property
    def graph_str(self):
        """Return the graph string of current DAGNode."""
        cls_name = self.__class__.__name__
        # Raw attribute on purpose, same reason as in __repr__.
        node_id = self._node_id
        node_name = self.node_name
        if node_id and node_name:
            return f"{node_id}({cls_name},{node_name})"
        if node_id:
            return f"{node_id}({cls_name})"
        if node_name:
            return f"{node_name}_{cls_name}({cls_name})"
        return f"{cls_name}"

    def __str__(self):
        """Return the string of current DAGNode."""
        return self.__repr__()
def _build_task_key(task_name: str, key: str) -> str:
    """Build the share-data key that namespaces ``key`` under ``task_name``.

    Args:
        task_name (str): The task name
        key (str): The share data key

    Returns:
        str: The task name and the key joined by a fixed separator.
    """
    namespaced_key = f"{task_name}___$$$$$$___{key}"
    return namespaced_key
@dataclasses.dataclass
class _DAGVariablesItem:
    """A single variable definition carried by :class:`DAGVariables`.

    It is a private class, just used for internal.
    """

    # Required fields.
    key: str
    name: str
    label: str
    value: Any
    # Optional fields with their defaults.
    category: Literal["common", "secret"] = "common"
    scope: str = "global"
    value_type: Optional[str] = None
    scope_key: Optional[str] = None
    sys_code: Optional[str] = None
    user_name: Optional[str] = None
    description: Optional[str] = None
@dataclasses.dataclass
class DAGVariables:
    """The DAG variables.

    Holds a list of variable items and lazily builds (then caches) a
    variables provider from them.
    """

    items: List[_DAGVariablesItem] = dataclasses.field(default_factory=list)
    _cached_provider: Optional["VariablesProvider"] = None
    # Guards the one-time construction of ``_cached_provider``.
    _lock: threading.Lock = dataclasses.field(default_factory=threading.Lock)

    def merge(self, dag_variables: "DAGVariables") -> "DAGVariables":
        """Merge the DAG variables.

        Items of ``self`` are always kept; an item of ``dag_variables`` is
        added only when no item of ``self`` has the same identity key.

        Args:
            dag_variables (DAGVariables): The DAG variables to merge

        Returns:
            DAGVariables: A new instance holding the merged items and the
                first available cached provider (``self`` takes precedence).
        """

        def _build_key(item: _DAGVariablesItem):
            # Identity of an item: key, name and scope, plus whichever of the
            # optional scope_key / sys_code / user_name are set.
            key = "_".join([item.key, item.name, item.scope])
            if item.scope_key:
                key += f"_{item.scope_key}"
            if item.sys_code:
                key += f"_{item.sys_code}"
            if item.user_name:
                key += f"_{item.user_name}"
            return key

        new_items = list(self.items)
        exist_vars = {_build_key(item) for item in self.items}
        new_items.extend(
            item for item in dag_variables.items if _build_key(item) not in exist_vars
        )
        return DAGVariables(
            items=new_items,
            _cached_provider=self._cached_provider or dag_variables._cached_provider,
        )

    def to_provider(self) -> "VariablesProvider":
        """Convert the DAG variables to variables provider.

        The provider is built once and cached; later calls return the cached
        instance.

        Returns:
            VariablesProvider: The variables provider
        """
        if self._cached_provider:
            return self._cached_provider

        from ...interface.variables import (
            StorageVariables,
            StorageVariablesProvider,
        )

        with self._lock:
            # Re-check under the lock: another thread may have built the
            # provider while this one was waiting to acquire it.
            if not self._cached_provider:
                provider = StorageVariablesProvider()
                for item in self.items:
                    storage_vars = StorageVariables(
                        key=item.key,
                        name=item.name,
                        label=item.label,
                        value=item.value,
                        category=item.category,
                        scope=item.scope,
                        value_type=item.value_type,
                        scope_key=item.scope_key,
                        sys_code=item.sys_code,
                        user_name=item.user_name,
                        description=item.description,
                    )
                    provider.save(storage_vars)
                self._cached_provider = provider

        return self._cached_provider
class DAGContext:
    """The context of current DAG, created when the DAG is running.

    Every DAG has been triggered will create a new DAGContext.
    """

    def __init__(
        self,
        node_to_outputs: Dict[str, TaskContext],
        share_data: Dict[str, Any],
        event_loop_task_id: int,
        streaming_call: bool = False,
        node_name_to_ids: Optional[Dict[str, str]] = None,
        dag_variables: Optional[DAGVariables] = None,
    ) -> None:
        """Initialize a DAGContext.

        Args:
            node_to_outputs (Dict[str, TaskContext]): The task outputs of current DAG.
            share_data (Dict[str, Any]): The share data of current DAG.
            event_loop_task_id (int): The id of the event loop task this
                context is stored under.
            streaming_call (bool, optional): Whether the current DAG is streaming call.
                Defaults to False.
            node_name_to_ids (Optional[Dict[str, str]], optional): The mapping
                from node name to node id. Defaults to None (empty mapping).
            dag_variables (Optional[DAGVariables], optional): The DAG variables.
        """
        if not node_name_to_ids:
            node_name_to_ids = {}
        self._streaming_call = streaming_call
        self._curr_task_ctx: Optional[TaskContext] = None
        self._share_data: Dict[str, Any] = share_data
        self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
        self._node_name_to_ids: Dict[str, str] = node_name_to_ids
        self._event_loop_task_id = event_loop_task_id
        self._dag_variables = dag_variables
        # Serializes reads and writes of the share data.
        self._share_data_lock = asyncio.Lock()

    @property
    def _task_outputs(self) -> Dict[str, TaskContext]:
        """Return the task outputs of current DAG.

        Just use for internal for now.

        Returns:
            Dict[str, TaskContext]: The task outputs of current DAG
        """
        return self._node_to_outputs

    @property
    def current_task_context(self) -> TaskContext:
        """Return the current task context.

        Raises:
            RuntimeError: If no task context has been set yet.
        """
        if not self._curr_task_ctx:
            raise RuntimeError("Current task context not set")
        return self._curr_task_ctx

    @property
    def streaming_call(self) -> bool:
        """Whether the current DAG is streaming call."""
        return self._streaming_call

    def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
        """Set the current task context.

        When the task is running, the current task context
        will be set to the task context.

        TODO: We should support parallel task running in the future.
        """
        self._curr_task_ctx = _curr_task_ctx

    def get_task_output(self, task_name: str) -> TaskOutput:
        """Get the task output by task name.

        Args:
            task_name (str): The task name

        Returns:
            TaskOutput: The task output

        Raises:
            ValueError: If ``task_name`` is None, is not a known node name, or
                has no recorded output.
        """
        if task_name is None:
            raise ValueError("task_name can't be None")
        node_id = self._node_name_to_ids.get(task_name)
        if not node_id:
            raise ValueError(f"Task name {task_name} not in DAG")
        task_output = self._task_outputs.get(node_id)
        if not task_output:
            raise ValueError(f"Task output for task {task_name} not exists")
        return task_output.task_output

    async def get_from_share_data(self, key: str) -> Any:
        """Get share data by key.

        Args:
            key (str): The share data key

        Returns:
            Any: The share data, you can cast it to the real type. None when
                the key does not exist.
        """
        async with self._share_data_lock:
            logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
            return self._share_data.get(key)

    async def save_to_share_data(
        self, key: str, data: Any, overwrite: bool = False
    ) -> None:
        """Save share data by key.

        Args:
            key (str): The share data key
            data (Any): The share data
            overwrite (bool): Whether overwrite the share data if the key
                already exists. Defaults to False.

        Raises:
            ValueError: If the share data key already exists and overwrite is not True
        """
        async with self._share_data_lock:
            if key in self._share_data and not overwrite:
                raise ValueError(f"Share data key {key} already exists")
            logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
            self._share_data[key] = data

    async def get_task_share_data(self, task_name: str, key: str) -> Any:
        """Get share data by task name and key.

        Args:
            task_name (str): The task name
            key (str): The share data key

        Returns:
            Any: The share data

        Raises:
            ValueError: If ``task_name`` or ``key`` is None.
        """
        if task_name is None:
            raise ValueError("task_name can't be None")
        if key is None:
            raise ValueError("key can't be None")
        # Must be awaited: without it the caller would receive a coroutine
        # object instead of the stored value.
        return await self.get_from_share_data(_build_task_key(task_name, key))

    async def save_task_share_data(
        self, task_name: str, key: str, data: Any, overwrite: bool = False
    ) -> None:
        """Save share data by task name and key.

        Args:
            task_name (str): The task name
            key (str): The share data key
            data (Any): The share data
            overwrite (bool): Whether overwrite the share data if the key
                already exists. Defaults to False.

        Raises:
            ValueError: If ``task_name`` or ``key`` is None, or if the share
                data key already exists and overwrite is not True
        """
        if task_name is None:
            raise ValueError("task_name can't be None")
        if key is None:
            raise ValueError("key can't be None")
        await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite)

    async def _clean_all(self):
        """Cleanup hook called when the DAG run ends; currently a no-op."""
        pass
class DAG:
    """The DAG class.

    Manage the DAG nodes and the relationship between them. Root, leaf and
    trigger node lists are computed lazily and cached; the caches are dropped
    whenever a new node is appended.
    """

    def __init__(
        self,
        dag_id: str,
        resource_group: Optional[ResourceGroup] = None,
        tags: Optional[Dict[str, str]] = None,
        description: Optional[str] = None,
        default_dag_variables: Optional[DAGVariables] = None,
    ) -> None:
        """Initialize a DAG.

        Args:
            dag_id (str): The id of the DAG; also used for hashing/equality.
            resource_group (Optional[ResourceGroup], optional): The resource
                group. Defaults to None.
            tags (Optional[Dict[str, str]], optional): Tags of the DAG.
                Defaults to None (empty dict).
            description (Optional[str], optional): Description of the DAG.
            default_dag_variables (Optional[DAGVariables], optional): The
                default DAG variables.
        """
        self._dag_id = dag_id
        self._tags: Dict[str, str] = tags or {}
        self._description = description
        self.node_map: Dict[str, DAGNode] = {}
        self.node_name_to_node: Dict[str, DAGNode] = {}
        self._root_nodes: List[DAGNode] = []
        self._leaf_nodes: List[DAGNode] = []
        self._trigger_nodes: List[DAGNode] = []
        self._resource_group: Optional[ResourceGroup] = resource_group
        self._lock = asyncio.Lock()
        self._event_loop_task_id_to_ctx: Dict[int, DAGContext] = {}
        self._default_dag_variables = default_dag_variables

    def _append_node(self, node: DAGNode) -> None:
        """Register ``node`` in this DAG; a node already registered is ignored.

        Raises:
            ValueError: If the node has no id, or if its name is already used
                by another node of this DAG.
        """
        # Validate the id first so a failure leaves no partial registration.
        node_id = node.node_id
        if not node_id:
            raise ValueError("Node id can't be None")
        if node_id in self.node_map:
            return
        if node.node_name:
            if node.node_name in self.node_name_to_node:
                raise ValueError(
                    f"Node name {node.node_name} already exists in DAG {self.dag_id}"
                )
            self.node_name_to_node[node.node_name] = node
        self.node_map[node_id] = node
        # clear cached nodes
        self._root_nodes = []
        self._leaf_nodes = []
        # The trigger cache must be dropped as well, otherwise trigger_nodes
        # keeps returning a list computed before this node was added.
        self._trigger_nodes = []

    def _new_node_id(self) -> str:
        """Generate a new random (uuid4) node id."""
        return str(uuid.uuid4())

    @property
    def dag_id(self) -> str:
        """Return the dag id of current DAG."""
        return self._dag_id

    @property
    def tags(self) -> Dict[str, str]:
        """Return the tags of current DAG."""
        return self._tags

    @property
    def description(self) -> Optional[str]:
        """Return the description of current DAG."""
        return self._description

    @property
    def dev_mode(self) -> bool:
        """Whether the current DAG is in dev mode.

        Returns:
            bool: Whether the current DAG is in dev mode
        """
        from ..operators.base import _dev_mode

        return _dev_mode()

    def _build(self) -> None:
        """Recompute the cached root, leaf and trigger node lists."""
        from ..operators.common_operator import TriggerOperator

        nodes: Set[DAGNode] = set()
        for _, node in self.node_map.items():
            nodes = nodes.union(_get_nodes(node))
        self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
        self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes)))
        self._trigger_nodes = list(
            set(filter(lambda x: isinstance(x, TriggerOperator), nodes))
        )

    @property
    def root_nodes(self) -> List[DAGNode]:
        """Return the root nodes of current DAG.

        Returns:
            List[DAGNode]: The root nodes of current DAG, no repeat
        """
        if not self._root_nodes:
            self._build()
        return self._root_nodes

    @property
    def leaf_nodes(self) -> List[DAGNode]:
        """Return the leaf nodes of current DAG.

        Returns:
            List[DAGNode]: The leaf nodes of current DAG, no repeat
        """
        if not self._leaf_nodes:
            self._build()
        return self._leaf_nodes

    @property
    def trigger_nodes(self) -> List[DAGNode]:
        """Return the trigger nodes of current DAG.

        Returns:
            List[DAGNode]: The trigger nodes of current DAG, no repeat
        """
        if not self._trigger_nodes:
            self._build()
        return self._trigger_nodes

    async def _save_dag_ctx(self, dag_ctx: DAGContext) -> None:
        """Store ``dag_ctx`` under its event loop task id."""
        async with self._lock:
            event_loop_task_id = dag_ctx._event_loop_task_id
            current_task = asyncio.current_task()
            task_name = current_task.get_name() if current_task else None
            self._event_loop_task_id_to_ctx[event_loop_task_id] = dag_ctx
            logger.debug(
                f"Save DAG context {dag_ctx} to event loop task {event_loop_task_id}, "
                f"task_name: {task_name}"
            )

    async def _after_dag_end(self, event_loop_task_id: Optional[int] = None) -> None:
        """Execute after DAG end.

        Runs every node's ``after_dag_end`` hook concurrently, then removes
        and cleans the DAG context stored for the task.

        Args:
            event_loop_task_id (Optional[int], optional): The task id the
                context was saved under. Defaults to ``id()`` of the current
                asyncio task.

        Raises:
            RuntimeError: If no DAG context is stored for the task id.
        """
        tasks = []
        event_loop_task_id = event_loop_task_id or id(asyncio.current_task())
        for node in self.node_map.values():
            tasks.append(node.after_dag_end(event_loop_task_id))
        await asyncio.gather(*tasks)

        # Clear the DAG context
        async with self._lock:
            current_task = asyncio.current_task()
            task_name = current_task.get_name() if current_task else None
            if event_loop_task_id not in self._event_loop_task_id_to_ctx:
                raise RuntimeError(
                    f"DAG context not found with event loop task id "
                    f"{event_loop_task_id}, task_name: {task_name}"
                )
            logger.debug(
                f"Clean DAG context with event loop task id {event_loop_task_id}, "
                f"task_name: {task_name}"
            )
            dag_ctx = self._event_loop_task_id_to_ctx.pop(event_loop_task_id)
            await dag_ctx._clean_all()

    def print_tree(self) -> None:
        """Print the DAG tree"""  # noqa: D400
        _print_format_dag_tree(self)

    def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
        """Visualize the DAG.

        Args:
            view (bool, optional): Whether view the DAG graph. Defaults to True,
                if True, it will open the graph file with your default viewer.

        Returns:
            Optional[str]: Whatever ``_visualize_dag`` returns (the rendered
                file name, or None when no graph could be built).
        """
        self.print_tree()
        return _visualize_dag(self, view=view, **kwargs)

    def show(self, mermaid: bool = False) -> Any:
        """Return the graph of current DAG.

        Args:
            mermaid (bool, optional): Return the Mermaid text instead of the
                graphviz object. Defaults to False.
        """
        dot, mermaid_str = _get_graph(self)
        return mermaid_str if mermaid else dot

    def __enter__(self):
        """Enter a DAG context."""
        DAGVar.enter_dag(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit a DAG context."""
        DAGVar.exit_dag()

    def __hash__(self) -> int:
        """Return the hash value of current DAG.

        If the dag_id is not None, return the hash value of dag_id.
        """
        if self.dag_id:
            return hash(self.dag_id)
        else:
            return super().__hash__()

    def __eq__(self, other):
        """Return whether the current DAG is equal to other DAG."""
        if not isinstance(other, DAG):
            return False
        return self.dag_id == other.dag_id

    def __repr__(self):
        """Return the representation of current DAG."""
        return f"DAG(dag_id={self.dag_id})"
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]:
    """Collect ``node`` and every node reachable from it in one direction.

    The walk is iterative and each node is expanded at most once, so shared
    ancestors are not re-traversed and an accidental cycle cannot cause
    unbounded recursion.

    Args:
        node (DAGNode): The node to start from. A falsy value yields an empty
            set.
        is_upstream (Optional[bool], optional): Follow upstream links when
            truthy, downstream links otherwise. Defaults to True.

    Returns:
        Set[DAGNode]: The start node plus all nodes reachable from it.
    """
    nodes: Set[DAGNode] = set()
    pending = [node]
    while pending:
        current = pending.pop()
        # Skip missing entries and anything already collected.
        if not current or current in nodes:
            continue
        nodes.add(current)
        pending.extend(current.upstream if is_upstream else current.downstream)
    return nodes
def _print_format_dag_tree(dag: DAG) -> None:
    """Print one tree per root node of ``dag`` using :func:`_print_dag`."""
    roots = dag.root_nodes
    for root in roots:
        _print_dag(root)
def _print_dag(
    node: DAGNode,
    level: int = 0,
    prefix: str = "",
    last: bool = True,
    level_dict: Optional[Dict[int, Any]] = None,
):
    """Print ``node`` and, recursively, its downstream nodes as a text tree.

    Args:
        node (DAGNode): The node to print.
        level (int): Depth of ``node``; 0 for a root.
        prefix (str): Text printed before the connector on this line.
        last (bool): Whether ``node`` is the last child of its parent.
        level_dict (Optional[Dict[int, Any]]): Per-level counter of printed
            nodes, shared across the recursion. Created when None.
    """
    level_dict = {} if level_dict is None else level_dict
    is_root = level == 0

    connector = "" if is_root else " -> "
    print(prefix + connector + str(node))

    # Children of a non-root node are printed with a longer prefix; the
    # added text depends on whether this node is the last sibling.
    if is_root:
        child_prefix = prefix
    else:
        child_prefix = prefix + (" " if last else "| ")

    level_dict[level] = level_dict.get(level, 0) + 1

    children = node.downstream
    last_index = len(children) - 1
    for index, child in enumerate(children):
        _print_dag(child, level + 1, child_prefix, index == last_index, level_dict)
def _print_dag_tree(root_nodes: List[DAGNode], level_sep: str = " ") -> None:
    """Print every node reachable downstream from ``root_nodes``.

    Args:
        root_nodes (List[DAGNode]): The nodes to start from.
        level_sep (str): Text repeated once per depth level before each node.
    """

    def _emit(current: DAGNode, depth: int) -> None:
        indent = level_sep * depth
        print(f"{indent}{current}")

    _apply_root_node(root_nodes, _emit)
def _apply_root_node(
    root_nodes: List[DAGNode],
    func: Callable[[DAGNode, int], None],
) -> None:
    """Apply ``func`` to each root node and its downstream nodes.

    Args:
        root_nodes (List[DAGNode]): The nodes to start from (level 0).
        func (Callable[[DAGNode, int], None]): Called with each visited node
            and its depth.
    """
    for root in root_nodes:
        # Walk top-down (not down-to-up), starting at depth 0.
        _handle_dag_nodes(False, 0, root, func)
def _handle_dag_nodes(
    is_down_to_up: bool,
    level: int,
    dag_node: DAGNode,
    func: Callable[[DAGNode, int], None],
):
    """Call ``func`` on ``dag_node``, then recurse into its neighbours.

    Args:
        is_down_to_up (bool): Follow upstream links when True, downstream
            links otherwise.
        level (int): Depth passed to ``func`` for ``dag_node``.
        dag_node (DAGNode): The node to visit; a falsy value is ignored.
        func (Callable[[DAGNode, int], None]): Visitor called with the node
            and its depth.
    """
    if not dag_node:
        return
    func(dag_node, level)
    if is_down_to_up:
        neighbours = dag_node.upstream
    else:
        neighbours = dag_node.downstream
    next_level = level + 1
    for neighbour in neighbours:
        _handle_dag_nodes(is_down_to_up, next_level, neighbour, func)
def _get_graph(dag: DAG):
    """Build a graphviz graph and a Mermaid definition for ``dag``.

    Args:
        dag (DAG): The DAG to convert.

    Returns:
        tuple: ``(dot, mermaid_str)`` where ``dot`` is a ``graphviz.Digraph``
            and ``mermaid_str`` the Mermaid text, or ``(None, None)`` when
            graphviz cannot be imported.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        logger.warning("Can't import graphviz, skip visualize DAG")
        return None, None
    dot = Digraph(name=dag.dag_id)
    mermaid_str = "graph TD;\n"  # Initialize Mermaid graph definition
    # Record the added edges to avoid adding duplicate edges
    added_edges = set()

    def add_edges(node: DAGNode):
        nonlocal mermaid_str
        for downstream_node in node.downstream:
            edge = (str(node), str(downstream_node))
            if edge in added_edges:
                # The child was already expanded when this edge was first
                # added; skipping avoids redundant re-traversal and endless
                # recursion should the graph contain a cycle.
                continue
            dot.edge(*edge)
            mermaid_str += f" {node.graph_str} --> {downstream_node.graph_str};\n"  # noqa
            added_edges.add(edge)
            add_edges(downstream_node)

    for root in dag.root_nodes:
        add_edges(root)
    return dot, mermaid_str
def _visualize_dag(
    dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
) -> Optional[str]:
    """Visualize the DAG.

    Args:
        dag (DAG): The DAG to visualize
        view (bool, optional): Whether view the DAG graph. Defaults to True.
        generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
            Defaults to True.
        **kwargs: Forwarded to ``dot.render``. ``filename`` overrides the
            default output name; ``directory`` defaults to ``LOGDIR``.

    Returns:
        Optional[str]: The filename of the DAG graph, or None when no graph
            could be built.
    """
    import os

    dot, mermaid_str = _get_graph(dag)
    if not dot:
        return None
    filename = kwargs.pop("filename", f"dag-vis-{dag.dag_id}.gv")
    if "directory" not in kwargs:
        from dbgpt.configs.model_config import LOGDIR

        kwargs["directory"] = LOGDIR

    # Generate Mermaid syntax file if requested
    if generate_mermaid:
        directory = kwargs.get("directory") or ""
        # Swap only a trailing ".gv" for ".md"; any other name gets ".md"
        # appended so the Mermaid file never collides with the graphviz source.
        stem, ext = os.path.splitext(filename)
        mermaid_filename = (stem if ext == ".gv" else filename) + ".md"
        if directory:
            # This file is written before dot.render runs, so the target
            # directory may not exist yet.
            os.makedirs(directory, exist_ok=True)
        # os.path.join keeps an empty directory relative instead of turning
        # it into the filesystem root.
        mermaid_path = os.path.join(directory, mermaid_filename)
        with open(mermaid_path, "w", encoding="utf-8") as mermaid_file:
            logger.info(f"Writing Mermaid syntax to {mermaid_filename}")
            mermaid_file.write(mermaid_str)

    return dot.render(filename, view=view, **kwargs)