feat(awel): New MessageConverter and more AWEL operators (#1039)

This commit is contained in:
Fangyin Cheng
2024-01-08 09:40:05 +08:00
committed by GitHub
parent 765fb181f6
commit e8861bd8fa
48 changed files with 2333 additions and 719 deletions

View File

@@ -7,6 +7,7 @@ The stability of this API cannot be guaranteed at present.
"""
import logging
from typing import List, Optional
from dbgpt.component import SystemApp
@@ -39,6 +40,8 @@ from .task.task_impl import (
)
from .trigger.http_trigger import HttpTrigger
logger = logging.getLogger(__name__)
__all__ = [
"initialize_awel",
"DAGContext",
@@ -89,14 +92,24 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
def setup_dev_environment(
dags: List[DAG],
host: Optional[str] = "0.0.0.0",
host: Optional[str] = "127.0.0.1",
port: Optional[int] = 5555,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
show_dag_graph: Optional[bool] = True,
) -> None:
"""Setup a development environment for AWEL.
For use in development environments only, not in production.
Args:
dags (List[DAG]): The DAGs.
host (Optional[str], optional): The host. Defaults to "127.0.0.1".
port (Optional[int], optional): The port. Defaults to 5555.
logging_level (Optional[str], optional): The logging level. Defaults to None.
logger_filename (Optional[str], optional): The logger filename. Defaults to None.
show_dag_graph (Optional[bool], optional): Whether to show the DAG graph. Defaults to True.
If True, the DAG graph will be saved to a file and opened automatically.
"""
import uvicorn
from fastapi import FastAPI
@@ -118,6 +131,15 @@ def setup_dev_environment(
system_app.register_instance(trigger_manager)
for dag in dags:
if show_dag_graph:
try:
dag_graph_file = dag.visualize_dag()
if dag_graph_file:
logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}")
except Exception as e:
logger.warning(
f"Visualize DAG {str(dag)} failed: {e}, if your system has no graphviz, you can install it by `pip install graphviz` or `sudo apt install graphviz`"
)
for trigger in dag.trigger_nodes:
trigger_manager.register_trigger(trigger)
trigger_manager.after_register()
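
For orientation, a minimal sketch of the revised helper in use; the endpoint path and map function are illustrative, not part of this commit:

from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment

with DAG("awel_hello_dag") as dag:
    # Hypothetical endpoint; HttpTrigger is the trigger type registered above.
    trigger = HttpTrigger("/examples/hello")
    task = MapOperator(map_function=lambda _: "hello")
    trigger >> task

# Serves on 127.0.0.1:5555; with show_dag_graph=True it also tries to
# render the DAG graph via graphviz before starting uvicorn.
setup_dev_environment([dag], port=5555)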

View File

@@ -6,8 +6,7 @@ import uuid
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
from functools import cache
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
from dbgpt.component import SystemApp
@@ -177,7 +176,10 @@ class DAGLifecycle:
pass
async def after_dag_end(self):
"""The callback after DAG end"""
"""The callback after DAG end,
This method may be called multiple times, please make sure it is idempotent.
"""
pass
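
Because the callback may now fire more than once, overrides should guard their side effects. A minimal sketch of the idempotency contract, assuming the usual MapOperator subclassing style; the guard flag and operator are illustrative:

from dbgpt.core.awel import MapOperator

class CleanupOperator(MapOperator[str, str]):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._cleaned = False

    async def map(self, value: str) -> str:
        return value

    async def after_dag_end(self):
        # May be called multiple times under the new contract; the flag
        # keeps the cleanup idempotent.
        if self._cleaned:
            return
        self._cleaned = True
        # release connections, flush buffers, etc.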
@@ -299,6 +301,20 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
self._downstream.append(node)
node._upstream.append(self)
def __repr__(self):
cls_name = self.__class__.__name__
if self.node_id and self.node_name:
return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
if self.node_id:
return f"{cls_name}(node_id={self.node_id})"
if self.node_name:
return f"{cls_name}(node_name={self.node_name})"
else:
return f"{cls_name}"
def __str__(self):
return self.__repr__()
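
The four textual shapes this yields, with illustrative values:

# Both set:       MapOperator(node_id=abc123, node_name=loader)
# Only node_id:   MapOperator(node_id=abc123)
# Only node_name: MapOperator(node_name=loader)
# Neither:        MapOperator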
def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"
@@ -496,6 +512,15 @@ class DAG:
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)
def print_tree(self) -> None:
"""Print the DAG tree"""
_print_format_dag_tree(self)
def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
"""Create the DAG graph"""
self.print_tree()
return _visualize_dag(self, view=view, **kwargs)
def __enter__(self):
DAGVar.enter_dag(self)
return self
@@ -516,3 +541,109 @@ def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode
for node in stream_nodes:
nodes = nodes.union(_get_nodes(node, is_upstream))
return nodes
def _print_format_dag_tree(dag: DAG) -> None:
for node in dag.root_nodes:
_print_dag(node)
def _print_dag(
node: DAGNode,
level: int = 0,
prefix: str = "",
last: bool = True,
level_dict: Optional[Dict[int, Any]] = None,
):
if level_dict is None:
level_dict = {}
connector = " -> " if level != 0 else ""
new_prefix = prefix
if last:
if level != 0:
new_prefix += " "
print(prefix + connector + str(node))
else:
if level != 0:
new_prefix += "| "
print(prefix + connector + str(node))
level_dict[level] = level_dict.get(level, 0) + 1
num_children = len(node.downstream)
for i, child in enumerate(node.downstream):
_print_dag(child, level + 1, new_prefix, i == num_children - 1, level_dict)
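
For a linear three-node DAG, the printer emits output along these lines (node names illustrative; exact indentation depends on the prefix strings above):

RequestHandleOperator(node_id=a1)
 -> MapOperator(node_id=b2)
  -> ResultOperator(node_id=c3)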
def _print_dag_tree(root_nodes: List[DAGNode], level_sep: str = " ") -> None:
def _print_node(node: DAGNode, level: int) -> None:
print(f"{level_sep * level}{node}")
_apply_root_node(root_nodes, _print_node)
def _apply_root_node(
root_nodes: List[DAGNode],
func: Callable[[DAGNode, int], None],
) -> None:
for dag_node in root_nodes:
_handle_dag_nodes(False, 0, dag_node, func)
def _handle_dag_nodes(
is_down_to_up: bool,
level: int,
dag_node: DAGNode,
func: Callable[[DAGNode, int], None],
):
if not dag_node:
return
func(dag_node, level)
stream_nodes = dag_node.upstream if is_down_to_up else dag_node.downstream
level += 1
for node in stream_nodes:
_handle_dag_nodes(is_down_to_up, level, node, func)
def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
"""Visualize the DAG
Args:
dag (DAG): The DAG to visualize
view (bool, optional): Whether to view the DAG graph. Defaults to True.
Returns:
Optional[str]: The filename of the DAG graph
"""
try:
from graphviz import Digraph
except ImportError:
logger.warning("Can't import graphviz; skipping DAG visualization")
return None
dot = Digraph(name=dag.dag_id)
# Record the added edges to avoid adding duplicate edges
added_edges = set()
def add_edges(node: DAGNode):
if node.downstream:
for downstream_node in node.downstream:
# Check if the edge has been added
if (str(node), str(downstream_node)) not in added_edges:
dot.edge(str(node), str(downstream_node))
added_edges.add((str(node), str(downstream_node)))
add_edges(downstream_node)
for root in dag.root_nodes:
add_edges(root)
filename = f"dag-vis-{dag.dag_id}.gv"
if "filename" in kwargs:
filename = kwargs["filename"]
del kwargs["filename"]
if not "directory" in kwargs:
from dbgpt.configs.model_config import LOGDIR
kwargs["directory"] = LOGDIR
return dot.render(filename, view=view, **kwargs)
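
A sketch of the two new entry points together; the operators are illustrative, and rendering requires a working graphviz install:

from dbgpt.core.awel import DAG, MapOperator

with DAG("vis_example_dag") as dag:
    first = MapOperator(map_function=lambda x: x + 1)
    second = MapOperator(map_function=lambda x: x * 2)
    first >> second

dag.print_tree()                      # textual tree on stdout
path = dag.visualize_dag(view=False)  # e.g. <LOGDIR>/dag-vis-<dag_id>.gv.pdf
# path is None when graphviz cannot be imported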

View File

@@ -46,6 +46,7 @@ class WorkflowRunner(ABC, Generic[T]):
node: "BaseOperator",
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
dag_ctx: Optional[DAGContext] = None,
) -> DAGContext:
"""Execute the workflow starting from a given operator.
@@ -53,7 +54,7 @@ class WorkflowRunner(ABC, Generic[T]):
node (BaseOperator): The starting node of the workflow to be executed.
call_data (CALL_DATA): The data passed to the root operator node.
streaming_call (bool): Whether the call is a streaming call.
dag_ctx (DAGContext): The context of the DAG when this node is run. Defaults to None.
Returns:
DAGContext: The context after executing the workflow, containing the final state and data.
"""
@@ -174,18 +175,22 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
TaskOutput[OUT]: The task output after this node has been run.
"""
async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
async def call(
self,
call_data: Optional[CALL_DATA] = None,
dag_ctx: Optional[DAGContext] = None,
) -> OUT:
"""Execute the node and return the output.
This method is a high-level wrapper for executing the node.
Args:
call_data (CALL_DATA): The data passed to the root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run. Defaults to None.
Returns:
OUT: The output of the node after execution.
"""
out_ctx = await self._runner.execute_workflow(self, call_data)
out_ctx = await self._runner.execute_workflow(self, call_data, dag_ctx=dag_ctx)
return out_ctx.current_task_context.task_output.output
def _blocking_call(
@@ -209,7 +214,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
return loop.run_until_complete(self.call(call_data))
async def call_stream(
self, call_data: Optional[CALL_DATA] = None
self,
call_data: Optional[CALL_DATA] = None,
dag_ctx: Optional[DAGContext] = None,
) -> AsyncIterator[OUT]:
"""Execute the node and return the output as a stream.
@@ -217,12 +224,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
call_data (CALL_DATA): The data passed to the root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run. Defaults to None.
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
out_ctx = await self._runner.execute_workflow(
self, call_data, streaming_call=True
self, call_data, streaming_call=True, dag_ctx=dag_ctx
)
return out_ctx.current_task_context.task_output.output_stream
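
One intended use of the new parameter is re-entrant calls from inside a running DAG: passing the current context lets a nested call share node outputs instead of starting from an empty context. A hedged sketch; the current_dag_context accessor is assumed, and the operators are illustrative:

from dbgpt.core.awel import MapOperator

class ComposeOperator(MapOperator[str, str]):
    def __init__(self, side_task: MapOperator, **kwargs):
        super().__init__(**kwargs)
        self._side_task = side_task

    async def map(self, value: str) -> str:
        # Reuse the running DAG's context so the nested call sees the
        # outputs of tasks that have already finished in this run.
        side = await self._side_task.call(
            call_data=value, dag_ctx=self.current_dag_context
        )
        return f"{value}|{side}"

The same dag_ctx argument applies to call_stream when the nested operator produces a stream.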

View File

@@ -19,17 +19,21 @@ class DefaultWorkflowRunner(WorkflowRunner):
node: BaseOperator,
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
dag_ctx: Optional[DAGContext] = None,
) -> DAGContext:
# Save node output
# dag = node.dag
node_outputs: Dict[str, TaskContext] = {}
job_manager = JobManager.build_from_end_node(node, call_data)
# Create DAG context
dag_ctx = DAGContext(
streaming_call=streaming_call,
node_to_outputs=node_outputs,
node_name_to_ids=job_manager._node_name_to_ids,
)
if not dag_ctx:
# Create DAG context
node_outputs: Dict[str, TaskContext] = {}
dag_ctx = DAGContext(
streaming_call=streaming_call,
node_to_outputs=node_outputs,
node_name_to_ids=job_manager._node_name_to_ids,
)
else:
node_outputs = dag_ctx._node_to_outputs
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
)
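
In effect, the runner now builds a fresh context only for top-level calls. A sketch of both paths, wrapped in a coroutine; the module path and nodes are assumptions for illustration:

from dbgpt.core.awel.runner.local_runner import DefaultWorkflowRunner

async def run_twice(end_node, other_node):
    runner = DefaultWorkflowRunner()
    # Top-level call: dag_ctx is None, so a new DAGContext is created.
    ctx = await runner.execute_workflow(end_node, call_data={"data": 1})
    # Re-entrant call: the earlier context is reused, so its
    # node_to_outputs mapping already holds the finished task results.
    return await runner.execute_workflow(other_node, dag_ctx=ctx)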