mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 09:06:55 +00:00
feat(awel): New MessageConverter and more AWEL operators (#1039)
This commit is contained in:
@@ -7,6 +7,7 @@ The stability of this API cannot be guaranteed at present.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
@@ -39,6 +40,8 @@ from .task.task_impl import (
|
||||
)
|
||||
from .trigger.http_trigger import HttpTrigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"initialize_awel",
|
||||
"DAGContext",
|
||||
@@ -89,14 +92,24 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
|
||||
|
||||
def setup_dev_environment(
|
||||
dags: List[DAG],
|
||||
host: Optional[str] = "0.0.0.0",
|
||||
host: Optional[str] = "127.0.0.1",
|
||||
port: Optional[int] = 5555,
|
||||
logging_level: Optional[str] = None,
|
||||
logger_filename: Optional[str] = None,
|
||||
show_dag_graph: Optional[bool] = True,
|
||||
) -> None:
|
||||
"""Setup a development environment for AWEL.
|
||||
|
||||
Just using in development environment, not production environment.
|
||||
|
||||
Args:
|
||||
dags (List[DAG]): The DAGs.
|
||||
host (Optional[str], optional): The host. Defaults to "127.0.0.1"
|
||||
port (Optional[int], optional): The port. Defaults to 5555.
|
||||
logging_level (Optional[str], optional): The logging level. Defaults to None.
|
||||
logger_filename (Optional[str], optional): The logger filename. Defaults to None.
|
||||
show_dag_graph (Optional[bool], optional): Whether show the DAG graph. Defaults to True.
|
||||
If True, the DAG graph will be saved to a file and open it automatically.
|
||||
"""
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
@@ -118,6 +131,15 @@ def setup_dev_environment(
|
||||
system_app.register_instance(trigger_manager)
|
||||
|
||||
for dag in dags:
|
||||
if show_dag_graph:
|
||||
try:
|
||||
dag_graph_file = dag.visualize_dag()
|
||||
if dag_graph_file:
|
||||
logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Visualize DAG {str(dag)} failed: {e}, if your system has no graphviz, you can install it by `pip install graphviz` or `sudo apt install graphviz`"
|
||||
)
|
||||
for trigger in dag.trigger_nodes:
|
||||
trigger_manager.register_trigger(trigger)
|
||||
trigger_manager.after_register()
|
||||
|
@@ -6,8 +6,7 @@ import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from concurrent.futures import Executor
|
||||
from functools import cache
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
@@ -177,7 +176,10 @@ class DAGLifecycle:
|
||||
pass
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
"""The callback after DAG end,
|
||||
|
||||
This method may be called multiple times, please make sure it is idempotent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -299,6 +301,20 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
self._downstream.append(node)
|
||||
node._upstream.append(self)
|
||||
|
||||
def __repr__(self):
    """Return a short, human-readable identifier for this node.

    The class name is always included. ``node_id`` and ``node_name`` are
    appended only when they are truthy, so an unset field is never rendered
    as ``node_id=None``.

    Returns:
        str: One of ``Cls(node_id=..., node_name=...)``, ``Cls(node_id=...)``,
            ``Cls(node_name=...)`` or ``Cls``.
    """
    cls_name = self.__class__.__name__
    # Collect only the identifying fields that are actually set.
    fields = []
    if self.node_id:
        fields.append(f"node_id={self.node_id}")
    if self.node_name:
        fields.append(f"node_name={self.node_name}")
    if not fields:
        return f"{cls_name}"
    joined = ", ".join(fields)
    return f"{cls_name}({joined})"
|
||||
|
||||
def __str__(self):
    """Return the same text as ``__repr__``."""
    text = repr(self)
    return text
|
||||
|
||||
|
||||
def _build_task_key(task_name: str, key: str) -> str:
    """Join ``task_name`` and ``key`` with the ``___$$$$$$___`` separator."""
    return "{}___$$$$$$___{}".format(task_name, key)
|
||||
@@ -496,6 +512,15 @@ class DAG:
|
||||
tasks.append(node.after_dag_end())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def print_tree(self) -> None:
    """Print this DAG as a text tree by delegating to ``_print_format_dag_tree``."""
    _print_format_dag_tree(dag=self)
|
||||
|
||||
def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
    """Print the DAG tree, then build the DAG graph.

    Args:
        view (bool, optional): Forwarded to ``_visualize_dag``. Defaults to True.
        **kwargs: Extra keyword arguments forwarded to ``_visualize_dag``.

    Returns:
        Optional[str]: Whatever ``_visualize_dag`` returns for this DAG.
    """
    self.print_tree()
    graph_file = _visualize_dag(self, view=view, **kwargs)
    return graph_file
|
||||
|
||||
def __enter__(self):
|
||||
DAGVar.enter_dag(self)
|
||||
return self
|
||||
@@ -516,3 +541,109 @@ def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode
|
||||
for node in stream_nodes:
|
||||
nodes = nodes.union(_get_nodes(node, is_upstream))
|
||||
return nodes
|
||||
|
||||
|
||||
def _print_format_dag_tree(dag: DAG) -> None:
    """Print one text tree per root node of ``dag``."""
    roots = dag.root_nodes
    for root in roots:
        _print_dag(root)
|
||||
|
||||
|
||||
def _print_dag(
    node: DAGNode,
    level: int = 0,
    prefix: str = "",
    last: bool = True,
    level_dict: Optional[Dict[int, Any]] = None,
):
    """Recursively print ``node`` and its downstream nodes as a text tree.

    Args:
        node (DAGNode): The node printed on the current line.
        level (int): Depth of ``node``. Level 0 is printed without the
            " -> " connector.
        prefix (str): Text emitted before the connector on the current line.
        last (bool): Whether ``node`` is the last child of its parent.
        level_dict (Optional[Dict[int, Any]]): Mapping from depth to the number
            of nodes printed at that depth. It is updated in place; a new dict
            is created when None is passed.
    """
    if level_dict is None:
        level_dict = {}

    is_root = level == 0
    connector = "" if is_root else " -> "
    # The printed line is the same whether or not this is the last child.
    print(prefix + connector + str(node))

    # Only the prefix handed to the children depends on `last`: a non-last
    # node contributes a "|" rail, a last node contributes padding.
    child_prefix = prefix
    if not is_root:
        child_prefix += " " if last else "| "

    level_dict[level] = level_dict.get(level, 0) + 1
    children = node.downstream
    last_index = len(children) - 1
    for i, child in enumerate(children):
        _print_dag(child, level + 1, child_prefix, i == last_index, level_dict)
|
||||
|
||||
|
||||
def _print_dag_tree(root_nodes: List[DAGNode], level_sep: str = " ") -> None:
    """Print every node reachable downstream from ``root_nodes``, indented by depth.

    Args:
        root_nodes (List[DAGNode]): The nodes to start from.
        level_sep (str): Text repeated once per depth level before each node.
    """

    def _emit(current: DAGNode, depth: int) -> None:
        indent = level_sep * depth
        print(f"{indent}{current}")

    _apply_root_node(root_nodes, _emit)
|
||||
|
||||
|
||||
def _apply_root_node(
    root_nodes: List[DAGNode],
    func: Callable[[DAGNode, int], None],
) -> None:
    """Apply ``func`` to each root node and its downstream nodes.

    Each root is visited at level 0 and traversal runs in the downstream
    direction via ``_handle_dag_nodes``.
    """
    for root in root_nodes:
        _handle_dag_nodes(is_down_to_up=False, level=0, dag_node=root, func=func)
|
||||
|
||||
|
||||
def _handle_dag_nodes(
    is_down_to_up: bool,
    level: int,
    dag_node: DAGNode,
    func: Callable[[DAGNode, int], None],
):
    """Call ``func`` on ``dag_node`` and then, recursively, on its neighbours.

    Args:
        is_down_to_up (bool): Follow ``upstream`` links when True, otherwise
            ``downstream`` links.
        level (int): Depth passed to ``func`` for ``dag_node``; neighbours
            receive ``level + 1``.
        dag_node (DAGNode): The node to visit. A falsy value ends the walk.
        func (Callable[[DAGNode, int], None]): Callback invoked as
            ``func(node, level)``.
    """
    if dag_node:
        func(dag_node, level)
        if is_down_to_up:
            neighbours = dag_node.upstream
        else:
            neighbours = dag_node.downstream
        next_level = level + 1
        for neighbour in neighbours:
            _handle_dag_nodes(is_down_to_up, next_level, neighbour, func)
|
||||
|
||||
|
||||
def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
    """Visualize the DAG

    Args:
        dag (DAG): The DAG to visualize
        view (bool, optional): Whether view the DAG graph. Defaults to True.
        **kwargs: Extra keyword arguments forwarded to ``Digraph.render``.
            ``filename`` replaces the default ``dag-vis-<dag_id>.gv`` and
            ``directory`` replaces the default ``LOGDIR``.

    Returns:
        Optional[str]: The filename of the DAG graph, or None when graphviz
            cannot be imported.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        # Logger.warn is a deprecated alias; use Logger.warning.
        logger.warning("Can't import graphviz, skip visualize DAG")
        return None

    dot = Digraph(name=dag.dag_id)
    # Record the added edges to avoid adding duplicate edges
    added_edges = set()
    # Record the nodes whose outgoing edges were already walked, so a node
    # reachable through several paths is expanded once, not once per path.
    expanded = set()

    def add_edges(node: DAGNode):
        if id(node) in expanded:
            return
        expanded.add(id(node))
        if not node.downstream:
            return
        source = str(node)
        for downstream_node in node.downstream:
            edge = (source, str(downstream_node))
            # Check if the edge has been added
            if edge not in added_edges:
                dot.edge(*edge)
                added_edges.add(edge)
            add_edges(downstream_node)

    for root in dag.root_nodes:
        add_edges(root)

    # A caller-supplied filename wins over the default and must not be passed
    # to render() a second time through **kwargs.
    filename = kwargs.pop("filename", f"dag-vis-{dag.dag_id}.gv")

    if "directory" not in kwargs:
        from dbgpt.configs.model_config import LOGDIR

        kwargs["directory"] = LOGDIR

    return dot.render(filename, view=view, **kwargs)
|
||||
|
@@ -46,6 +46,7 @@ class WorkflowRunner(ABC, Generic[T]):
|
||||
node: "BaseOperator",
|
||||
call_data: Optional[CALL_DATA] = None,
|
||||
streaming_call: bool = False,
|
||||
dag_ctx: Optional[DAGContext] = None,
|
||||
) -> DAGContext:
|
||||
"""Execute the workflow starting from a given operator.
|
||||
|
||||
@@ -53,7 +54,7 @@ class WorkflowRunner(ABC, Generic[T]):
|
||||
node (RunnableDAGNode): The starting node of the workflow to be executed.
|
||||
call_data (CALL_DATA): The data pass to root operator node.
|
||||
streaming_call (bool): Whether the call is a streaming call.
|
||||
|
||||
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
|
||||
Returns:
|
||||
DAGContext: The context after executing the workflow, containing the final state and data.
|
||||
"""
|
||||
@@ -174,18 +175,22 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
TaskOutput[OUT]: The task output after this node has been run.
|
||||
"""
|
||||
|
||||
async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
|
||||
async def call(
|
||||
self,
|
||||
call_data: Optional[CALL_DATA] = None,
|
||||
dag_ctx: Optional[DAGContext] = None,
|
||||
) -> OUT:
|
||||
"""Execute the node and return the output.
|
||||
|
||||
This method is a high-level wrapper for executing the node.
|
||||
|
||||
Args:
|
||||
call_data (CALL_DATA): The data pass to root operator node.
|
||||
|
||||
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
|
||||
Returns:
|
||||
OUT: The output of the node after execution.
|
||||
"""
|
||||
out_ctx = await self._runner.execute_workflow(self, call_data)
|
||||
out_ctx = await self._runner.execute_workflow(self, call_data, dag_ctx=dag_ctx)
|
||||
return out_ctx.current_task_context.task_output.output
|
||||
|
||||
def _blocking_call(
|
||||
@@ -209,7 +214,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
return loop.run_until_complete(self.call(call_data))
|
||||
|
||||
async def call_stream(
|
||||
self, call_data: Optional[CALL_DATA] = None
|
||||
self,
|
||||
call_data: Optional[CALL_DATA] = None,
|
||||
dag_ctx: Optional[DAGContext] = None,
|
||||
) -> AsyncIterator[OUT]:
|
||||
"""Execute the node and return the output as a stream.
|
||||
|
||||
@@ -217,12 +224,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
|
||||
Args:
|
||||
call_data (CALL_DATA): The data pass to root operator node.
|
||||
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
|
||||
"""
|
||||
out_ctx = await self._runner.execute_workflow(
|
||||
self, call_data, streaming_call=True
|
||||
self, call_data, streaming_call=True, dag_ctx=dag_ctx
|
||||
)
|
||||
return out_ctx.current_task_context.task_output.output_stream
|
||||
|
||||
|
@@ -19,17 +19,21 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
node: BaseOperator,
|
||||
call_data: Optional[CALL_DATA] = None,
|
||||
streaming_call: bool = False,
|
||||
dag_ctx: Optional[DAGContext] = None,
|
||||
) -> DAGContext:
|
||||
# Save node output
|
||||
# dag = node.dag
|
||||
node_outputs: Dict[str, TaskContext] = {}
|
||||
job_manager = JobManager.build_from_end_node(node, call_data)
|
||||
# Create DAG context
|
||||
dag_ctx = DAGContext(
|
||||
streaming_call=streaming_call,
|
||||
node_to_outputs=node_outputs,
|
||||
node_name_to_ids=job_manager._node_name_to_ids,
|
||||
)
|
||||
if not dag_ctx:
|
||||
# Create DAG context
|
||||
node_outputs: Dict[str, TaskContext] = {}
|
||||
dag_ctx = DAGContext(
|
||||
streaming_call=streaming_call,
|
||||
node_to_outputs=node_outputs,
|
||||
node_name_to_ids=job_manager._node_name_to_ids,
|
||||
)
|
||||
else:
|
||||
node_outputs = dag_ctx._node_to_outputs
|
||||
logger.info(
|
||||
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
|
||||
)
|
||||
|
Reference in New Issue
Block a user