feat(awel): New MessageConverter and more AWEL operators (#1039)

This commit is contained in:
Fangyin Cheng
2024-01-08 09:40:05 +08:00
committed by GitHub
parent 765fb181f6
commit e8861bd8fa
48 changed files with 2333 additions and 719 deletions

View File

@@ -6,8 +6,7 @@ import uuid
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
from functools import cache
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
from dbgpt.component import SystemApp
@@ -177,7 +176,10 @@ class DAGLifecycle:
pass
async def after_dag_end(self):
    """The callback after DAG end.

    This method may be called multiple times, please make sure it is
    idempotent.
    """
    pass
@@ -299,6 +301,20 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
self._downstream.append(node)
node._upstream.append(self)
def __repr__(self):
    """Return a short description built from the class name, id and name.

    Only the identifiers that are set (truthy) are included, so the result
    is one of ``Cls(node_id=..., node_name=...)``, ``Cls(node_id=...)``,
    ``Cls(node_name=...)`` or the bare class name.
    """
    cls_name = self.__class__.__name__
    # Both identifiers must be present for the two-field form; checking
    # node_name twice let a missing id through as "node_id=None".
    if self.node_id and self.node_name:
        return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
    if self.node_id:
        return f"{cls_name}(node_id={self.node_id})"
    if self.node_name:
        return f"{cls_name}(node_name={self.node_name})"
    return f"{cls_name}"
def __str__(self):
    """Return the same text as ``__repr__`` so both forms always agree."""
    text = self.__repr__()
    return text
def _build_task_key(task_name: str, key: str) -> str:
    """Join ``task_name`` and ``key`` with the fixed ``___$$$$$$___`` marker."""
    separator = "___$$$$$$___"
    return f"{task_name}{separator}{key}"
@@ -496,6 +512,15 @@ class DAG:
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)
def print_tree(self) -> None:
    """Print this DAG as a text tree.

    Delegates to ``_print_format_dag_tree``, which prints to stdout.
    """
    _print_format_dag_tree(self)
def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
    """Print the DAG tree, then build the DAG graph.

    Args:
        view (bool, optional): Forwarded to ``_visualize_dag``. Defaults to
            True.
        **kwargs: Forwarded unchanged to ``_visualize_dag``.

    Returns:
        Optional[str]: Whatever ``_visualize_dag`` returns.
    """
    # The text tree is always printed first, regardless of what happens next.
    self.print_tree()
    rendered = _visualize_dag(self, view=view, **kwargs)
    return rendered
def __enter__(self):
    """Enter the ``with`` block: hand this DAG to ``DAGVar.enter_dag``.

    Returns:
        This DAG, so it can be bound with ``with ... as dag``.
    """
    DAGVar.enter_dag(self)
    return self
@@ -516,3 +541,109 @@ def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode
for node in stream_nodes:
nodes = nodes.union(_get_nodes(node, is_upstream))
return nodes
def _print_format_dag_tree(dag: DAG) -> None:
    """Print one formatted tree per entry of ``dag.root_nodes``."""
    roots = dag.root_nodes
    for root in roots:
        _print_dag(root)
def _print_dag(
    node: DAGNode,
    level: int = 0,
    prefix: str = "",
    last: bool = True,
    level_dict: Optional[Dict[int, int]] = None,
):
    """Recursively print ``node`` and its downstream nodes as a text tree.

    Args:
        node (DAGNode): The node to print; its ``downstream`` nodes are
            printed after it, one level deeper.
        level (int, optional): Depth of ``node``. Level 0 is printed without
            a connector. Defaults to 0.
        prefix (str, optional): Text printed before the connector on this
            node's line. Defaults to "".
        last (bool, optional): Whether ``node`` is the last child of its
            parent; selects the prefix extension used for its children.
            Defaults to True.
        level_dict (Optional[Dict[int, int]], optional): Maps a level to the
            number of nodes printed at that level. It is updated in place
            and shared with every recursive call. A new dict is created when
            omitted.
    """
    if level_dict is None:
        level_dict = {}

    is_root = level == 0
    connector = "" if is_root else " -> "
    print(prefix + connector + str(node))

    # Children of a non-root node inherit this node's prefix, extended with a
    # bar while this node still has siblings below it.
    new_prefix = prefix
    if not is_root:
        new_prefix += " " if last else "| "

    level_dict[level] = level_dict.get(level, 0) + 1

    children = node.downstream
    last_index = len(children) - 1
    for i, child in enumerate(children):
        _print_dag(child, level + 1, new_prefix, i == last_index, level_dict)
def _print_dag_tree(root_nodes: List[DAGNode], level_sep: str = " ") -> None:
    """Print each node on its own line, indented by its level.

    Args:
        root_nodes (List[DAGNode]): The nodes to start from.
        level_sep (str, optional): Text repeated once per level in front of
            each node. Defaults to " ".
    """

    def _emit(node: DAGNode, level: int) -> None:
        indent = level_sep * level
        print(f"{indent}{node}")

    _apply_root_node(root_nodes, _emit)
def _apply_root_node(
    root_nodes: List[DAGNode],
    func: Callable[[DAGNode, int], None],
) -> None:
    """Run ``func`` over every root node and the nodes downstream of it.

    Args:
        root_nodes (List[DAGNode]): The nodes to start from; each starts at
            level 0.
        func (Callable[[DAGNode, int], None]): Called with each node and its
            level.
    """
    for root in root_nodes:
        _handle_dag_nodes(is_down_to_up=False, level=0, dag_node=root, func=func)
def _handle_dag_nodes(
    is_down_to_up: bool,
    level: int,
    dag_node: DAGNode,
    func: Callable[[DAGNode, int], None],
):
    """Call ``func`` on ``dag_node`` and then on its neighbours, depth first.

    Args:
        is_down_to_up (bool): Follow ``upstream`` links when True, otherwise
            ``downstream`` links.
        level (int): Level passed to ``func`` for ``dag_node``; neighbours
            get ``level + 1``.
        dag_node (DAGNode): The node to start from. A falsy value is ignored.
        func (Callable[[DAGNode, int], None]): Called with each node and its
            level before that node's neighbours are visited.
    """
    if dag_node:
        func(dag_node, level)
        if is_down_to_up:
            neighbours = dag_node.upstream
        else:
            neighbours = dag_node.downstream
        next_level = level + 1
        for neighbour in neighbours:
            _handle_dag_nodes(is_down_to_up, next_level, neighbour, func)
def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
    """Visualize the DAG

    Args:
        dag (DAG): The DAG to visualize
        view (bool, optional): Whether view the DAG graph. Defaults to True.
        **kwargs: Extra options. ``filename`` overrides the default
            ``dag-vis-<dag_id>.gv`` name; ``directory`` defaults to the
            project ``LOGDIR``; everything else is passed to ``dot.render``.

    Returns:
        Optional[str]: The value returned by ``dot.render`` (the filename of
            the DAG graph), or None when graphviz can't be imported.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        # ``Logger.warn`` is a deprecated alias; ``warning`` is the supported name.
        logger.warning("Can't import graphviz, skip visualize DAG")
        return None

    dot = Digraph(name=dag.dag_id)
    # Record the added edges to avoid adding duplicate edges
    added_edges = set()
    # Record expanded nodes by identity so a node reachable through several
    # paths is walked only once.
    visited = set()

    def add_edges(node: DAGNode):
        if id(node) in visited:
            return
        visited.add(id(node))
        source = str(node)
        for downstream_node in node.downstream or ():
            edge = (source, str(downstream_node))
            # Check if the edge has been added
            if edge not in added_edges:
                dot.edge(*edge)
                added_edges.add(edge)
            add_edges(downstream_node)

    for root in dag.root_nodes:
        add_edges(root)

    filename = kwargs.pop("filename", f"dag-vis-{dag.dag_id}.gv")
    if "directory" not in kwargs:
        # Imported lazily: only needed when the caller gives no directory.
        from dbgpt.configs.model_config import LOGDIR

        kwargs["directory"] = LOGDIR

    return dot.render(filename, view=view, **kwargs)