mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-04 10:34:30 +00:00
feat(awel): New MessageConverter and more AWEL operators (#1039)
This commit is contained in:
@@ -6,8 +6,7 @@ import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from concurrent.futures import Executor
|
||||
from functools import cache
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
@@ -177,7 +176,10 @@ class DAGLifecycle:
|
||||
pass
|
||||
|
||||
async def after_dag_end(self):
    """The callback after DAG end.

    This method may be called multiple times, please make sure it is idempotent.
    """
    # Default implementation is a no-op; subclasses override it to release
    # whatever they hold once the DAG has finished.
    pass
|
||||
|
||||
|
||||
@@ -299,6 +301,20 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
self._downstream.append(node)
|
||||
node._upstream.append(self)
|
||||
|
||||
def __repr__(self):
    """Return a short identifier for this node.

    The result is the class name followed by whichever of ``node_id`` and
    ``node_name`` are set (truthy); when neither is set, only the class
    name is returned.
    """
    cls_name = self.__class__.__name__
    # Both identifiers must be present for the combined form. The previous
    # condition tested ``node_name`` twice, so a node with only a name was
    # rendered as ``node_id=None, node_name=...``.
    if self.node_id and self.node_name:
        return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
    if self.node_id:
        return f"{cls_name}(node_id={self.node_id})"
    if self.node_name:
        return f"{cls_name}(node_name={self.node_name})"
    return f"{cls_name}"
|
||||
|
||||
def __str__(self):
    """Return the same text as ``__repr__`` so printing and debugging agree."""
    return self.__repr__()
|
||||
|
||||
|
||||
def _build_task_key(task_name: str, key: str) -> str:
    """Join a task name and a key into one string.

    Args:
        task_name (str): The task name, placed first.
        key (str): The key, placed after the separator.

    Returns:
        str: ``task_name`` and ``key`` joined by the fixed ``___$$$$$$___``
            separator.
    """
    return f"{task_name}___$$$$$$___{key}"
|
||||
@@ -496,6 +512,15 @@ class DAG:
|
||||
tasks.append(node.after_dag_end())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def print_tree(self) -> None:
    """Print the DAG tree.

    Delegates to ``_print_format_dag_tree``, which prints the tree under
    each of this DAG's root nodes.
    """
    _print_format_dag_tree(self)
|
||||
|
||||
def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
    """Create the DAG graph.

    The text tree is printed first, then the graph is rendered by
    ``_visualize_dag``.

    Args:
        view (bool, optional): Whether to view the DAG graph. Defaults to True.
        **kwargs: Forwarded unchanged to ``_visualize_dag``.

    Returns:
        Optional[str]: Whatever ``_visualize_dag`` returns (the rendered file
            name, or None when rendering is skipped).
    """
    self.print_tree()
    return _visualize_dag(self, view=view, **kwargs)
|
||||
|
||||
def __enter__(self):
    """Enter the ``with`` block for this DAG.

    Hands this DAG to ``DAGVar.enter_dag`` and returns it, so that
    ``with ... as dag`` binds the DAG itself.
    """
    # NOTE(review): presumably enter_dag makes this the current DAG for nodes
    # created inside the block; confirm against DAGVar.
    DAGVar.enter_dag(self)
    return self
|
||||
@@ -516,3 +541,109 @@ def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode
|
||||
for node in stream_nodes:
|
||||
nodes = nodes.union(_get_nodes(node, is_upstream))
|
||||
return nodes
|
||||
|
||||
|
||||
def _print_format_dag_tree(dag: DAG) -> None:
    """Print the formatted tree of every root node of the DAG.

    Args:
        dag (DAG): The DAG whose ``root_nodes`` are printed, one tree each.
    """
    for node in dag.root_nodes:
        _print_dag(node)
|
||||
|
||||
|
||||
def _print_dag(
    node: "DAGNode",
    level: int = 0,
    prefix: str = "",
    last: bool = True,
    level_dict: Optional[Dict[int, int]] = None,
):
    """Recursively print a node and everything downstream of it.

    Args:
        node (DAGNode): The node to print; its ``downstream`` nodes are
            printed below it.
        level (int, optional): Depth of ``node`` (0 for a root). Defaults to 0.
        prefix (str, optional): Text printed before the connector on this
            line. Defaults to "".
        last (bool, optional): Whether ``node`` is the last child of its
            parent. Defaults to True.
        level_dict (Optional[Dict[int, int]], optional): Running count of
            printed nodes per level, updated in place. A fresh dict is used
            when omitted.
    """
    if level_dict is None:
        level_dict = {}

    connector = " -> " if level != 0 else ""
    new_prefix = prefix
    if level != 0:
        # Children of a non-last node keep a "|" rail so later siblings stay
        # visually connected; children of the last node get blank padding.
        # NOTE(review): the padding is a single space as extracted; the
        # original width may have been collapsed, confirm against upstream.
        new_prefix += " " if last else "| "
    print(prefix + connector + str(node))

    level_dict[level] = level_dict.get(level, 0) + 1
    num_children = len(node.downstream)
    for i, child in enumerate(node.downstream):
        _print_dag(child, level + 1, new_prefix, i == num_children - 1, level_dict)
|
||||
|
||||
|
||||
def _print_dag_tree(root_nodes: List[DAGNode], level_sep: str = " ") -> None:
    """Print each root node and its downstream nodes, indented by depth.

    Args:
        root_nodes (List[DAGNode]): The nodes to start printing from.
        level_sep (str, optional): Text repeated once per level in front of
            each node. Defaults to " ".
    """
    # NOTE(review): the default separator is a single space as extracted; the
    # original width may have been collapsed, confirm against upstream.

    def _print_node(node: DAGNode, level: int) -> None:
        print(f"{level_sep * level}{node}")

    _apply_root_node(root_nodes, _print_node)
|
||||
|
||||
|
||||
def _apply_root_node(
    root_nodes: List[DAGNode],
    func: Callable[[DAGNode, int], None],
) -> None:
    """Apply ``func`` to every root node and all nodes downstream of it.

    Args:
        root_nodes (List[DAGNode]): The nodes to start from; each is visited
            at level 0.
        func (Callable[[DAGNode, int], None]): Called with each visited node
            and its level.
    """
    for dag_node in root_nodes:
        # False selects the upstream-to-downstream direction.
        _handle_dag_nodes(False, 0, dag_node, func)
|
||||
|
||||
|
||||
def _handle_dag_nodes(
    is_down_to_up: bool,
    level: int,
    dag_node: "DAGNode",
    func: Callable[["DAGNode", int], None],
) -> None:
    """Visit a node and then, depth-first, every node reachable from it.

    Args:
        is_down_to_up (bool): When True follow ``upstream`` links, otherwise
            follow ``downstream`` links.
        level (int): The level reported to ``func`` for ``dag_node``; each hop
            adds one.
        dag_node (DAGNode): The node to start from. A falsy value is ignored.
        func (Callable[[DAGNode, int], None]): Called with each visited node
            and its level, before that node's neighbours are visited.
    """
    if not dag_node:
        return
    func(dag_node, level)
    stream_nodes = dag_node.upstream if is_down_to_up else dag_node.downstream
    next_level = level + 1
    # A node reachable by several paths is visited once per path, which is
    # what a tree-style listing needs.
    for node in stream_nodes:
        _handle_dag_nodes(is_down_to_up, next_level, node, func)
|
||||
|
||||
|
||||
def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
    """Visualize the DAG

    Args:
        dag (DAG): The DAG to visualize
        view (bool, optional): Whether view the DAG graph. Defaults to True.
        **kwargs: Extra keyword arguments forwarded to ``Digraph.render``.
            ``filename`` overrides the default ``dag-vis-{dag_id}.gv``;
            ``directory`` defaults to ``LOGDIR`` when not given.

    Returns:
        Optional[str]: The filename of the DAG graph, or None when graphviz
            cannot be imported.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        # graphviz is optional: visualisation is skipped, not an error.
        logger.warning("Can't import graphviz, skip visualize DAG")
        return None

    dot = Digraph(name=dag.dag_id)
    # Record the added edges to avoid adding duplicate edges
    added_edges = set()
    # Nodes whose outgoing edges were already emitted. Without this a node
    # reachable through several paths is re-walked once per path, which grows
    # exponentially on diamond-shaped DAGs.
    visited = set()

    def add_edges(node: DAGNode):
        if id(node) in visited:
            return
        visited.add(id(node))
        if not node.downstream:
            return
        for downstream_node in node.downstream:
            edge = (str(node), str(downstream_node))
            # Check if the edge has been added
            if edge not in added_edges:
                dot.edge(*edge)
                added_edges.add(edge)
            add_edges(downstream_node)

    for root in dag.root_nodes:
        add_edges(root)

    # "filename" is passed positionally to render(), so it must not also
    # remain in kwargs.
    filename = kwargs.pop("filename", f"dag-vis-{dag.dag_id}.gv")

    if "directory" not in kwargs:
        from dbgpt.configs.model_config import LOGDIR

        kwargs["directory"] = LOGDIR

    return dot.render(filename, view=view, **kwargs)
|
||||
|
Reference in New Issue
Block a user