feat(rag): Support RAG SDK (#1322)

This commit is contained in:
Fangyin Cheng
2024-03-22 15:36:57 +08:00
committed by GitHub
parent e65732d6e4
commit 8a17099dd2
69 changed files with 1332 additions and 558 deletions

View File

@@ -710,6 +710,11 @@ class DAG:
self.print_tree()
return _visualize_dag(self, view=view, **kwargs)
def show(self, mermaid: bool = False) -> Any:
"""Return the graph of current DAG."""
dot, mermaid_str = _get_graph(self)
return mermaid_str if mermaid else dot
def __enter__(self):
"""Enter a DAG context."""
DAGVar.enter_dag(self)
@@ -813,26 +818,12 @@ def _handle_dag_nodes(
_handle_dag_nodes(is_down_to_up, level, node, func)
def _visualize_dag(
dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
) -> Optional[str]:
"""Visualize the DAG.
Args:
dag (DAG): The DAG to visualize
view (bool, optional): Whether view the DAG graph. Defaults to True.
generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
Defaults to True.
Returns:
Optional[str]: The filename of the DAG graph
"""
def _get_graph(dag: DAG):
try:
from graphviz import Digraph
except ImportError:
logger.warn("Can't import graphviz, skip visualize DAG")
return None
return None, None
dot = Digraph(name=dag.dag_id)
mermaid_str = "graph TD;\n" # Initialize Mermaid graph definition
# Record the added edges to avoid adding duplicate edges
@@ -851,6 +842,26 @@ def _visualize_dag(
for root in dag.root_nodes:
add_edges(root)
return dot, mermaid_str
def _visualize_dag(
dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
) -> Optional[str]:
"""Visualize the DAG.
Args:
dag (DAG): The DAG to visualize
view (bool, optional): Whether view the DAG graph. Defaults to True.
generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
Defaults to True.
Returns:
Optional[str]: The filename of the DAG graph
"""
dot, mermaid_str = _get_graph(dag)
if not dot:
return None
filename = f"dag-vis-{dag.dag_id}.gv"
if "filename" in kwargs:
filename = kwargs["filename"]