mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 17:45:31 +00:00
feat(rag): Support RAG SDK (#1322)
This commit is contained in:
@@ -710,6 +710,11 @@ class DAG:
|
||||
self.print_tree()
|
||||
return _visualize_dag(self, view=view, **kwargs)
|
||||
|
||||
def show(self, mermaid: bool = False) -> Any:
|
||||
"""Return the graph of current DAG."""
|
||||
dot, mermaid_str = _get_graph(self)
|
||||
return mermaid_str if mermaid else dot
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter a DAG context."""
|
||||
DAGVar.enter_dag(self)
|
||||
@@ -813,26 +818,12 @@ def _handle_dag_nodes(
|
||||
_handle_dag_nodes(is_down_to_up, level, node, func)
|
||||
|
||||
|
||||
def _visualize_dag(
|
||||
dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
|
||||
) -> Optional[str]:
|
||||
"""Visualize the DAG.
|
||||
|
||||
Args:
|
||||
dag (DAG): The DAG to visualize
|
||||
view (bool, optional): Whether view the DAG graph. Defaults to True.
|
||||
generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The filename of the DAG graph
|
||||
"""
|
||||
def _get_graph(dag: DAG):
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
logger.warn("Can't import graphviz, skip visualize DAG")
|
||||
return None
|
||||
|
||||
return None, None
|
||||
dot = Digraph(name=dag.dag_id)
|
||||
mermaid_str = "graph TD;\n" # Initialize Mermaid graph definition
|
||||
# Record the added edges to avoid adding duplicate edges
|
||||
@@ -851,6 +842,26 @@ def _visualize_dag(
|
||||
|
||||
for root in dag.root_nodes:
|
||||
add_edges(root)
|
||||
return dot, mermaid_str
|
||||
|
||||
|
||||
def _visualize_dag(
|
||||
dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
|
||||
) -> Optional[str]:
|
||||
"""Visualize the DAG.
|
||||
|
||||
Args:
|
||||
dag (DAG): The DAG to visualize
|
||||
view (bool, optional): Whether view the DAG graph. Defaults to True.
|
||||
generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The filename of the DAG graph
|
||||
"""
|
||||
dot, mermaid_str = _get_graph(dag)
|
||||
if not dot:
|
||||
return None
|
||||
filename = f"dag-vis-{dag.dag_id}.gv"
|
||||
if "filename" in kwargs:
|
||||
filename = kwargs["filename"]
|
||||
|
@@ -1,12 +1,13 @@
|
||||
"""Common operators of AWEL."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Awaitable, Callable, Dict, Generic, List, Optional, Union
|
||||
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Union
|
||||
|
||||
from ..dag.base import DAGContext
|
||||
from ..task.base import (
|
||||
IN,
|
||||
OUT,
|
||||
SKIP_DATA,
|
||||
InputContext,
|
||||
InputSource,
|
||||
JoinFunc,
|
||||
@@ -276,6 +277,11 @@ class InputOperator(BaseOperator, Generic[OUT]):
|
||||
curr_task_ctx.set_task_output(task_output)
|
||||
return task_output
|
||||
|
||||
@classmethod
|
||||
def dummy_input(cls, dummy_data: Any = SKIP_DATA, **kwargs) -> "InputOperator[OUT]":
|
||||
"""Create a dummy InputOperator with a given input value."""
|
||||
return cls(input_source=InputSource.from_data(dummy_data), **kwargs)
|
||||
|
||||
|
||||
class TriggerOperator(InputOperator[OUT], Generic[OUT]):
|
||||
"""Operator node that triggers the DAG to run."""
|
||||
|
Reference in New Issue
Block a user