feat(rag): Support RAG SDK (#1322)

This commit is contained in:
Fangyin Cheng
2024-03-22 15:36:57 +08:00
committed by GitHub
parent e65732d6e4
commit 8a17099dd2
69 changed files with 1332 additions and 558 deletions

View File

@@ -710,6 +710,11 @@ class DAG:
self.print_tree()
return _visualize_dag(self, view=view, **kwargs)
def show(self, mermaid: bool = False) -> Any:
"""Return the graph of current DAG."""
dot, mermaid_str = _get_graph(self)
return mermaid_str if mermaid else dot
def __enter__(self):
"""Enter a DAG context."""
DAGVar.enter_dag(self)
@@ -813,26 +818,12 @@ def _handle_dag_nodes(
_handle_dag_nodes(is_down_to_up, level, node, func)
def _visualize_dag(
dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
) -> Optional[str]:
"""Visualize the DAG.
Args:
dag (DAG): The DAG to visualize
view (bool, optional): Whether view the DAG graph. Defaults to True.
generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
Defaults to True.
Returns:
Optional[str]: The filename of the DAG graph
"""
def _get_graph(dag: DAG):
try:
from graphviz import Digraph
except ImportError:
logger.warn("Can't import graphviz, skip visualize DAG")
return None
return None, None
dot = Digraph(name=dag.dag_id)
mermaid_str = "graph TD;\n" # Initialize Mermaid graph definition
# Record the added edges to avoid adding duplicate edges
@@ -851,6 +842,26 @@ def _visualize_dag(
for root in dag.root_nodes:
add_edges(root)
return dot, mermaid_str
def _visualize_dag(
dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
) -> Optional[str]:
"""Visualize the DAG.
Args:
dag (DAG): The DAG to visualize
view (bool, optional): Whether view the DAG graph. Defaults to True.
generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
Defaults to True.
Returns:
Optional[str]: The filename of the DAG graph
"""
dot, mermaid_str = _get_graph(dag)
if not dot:
return None
filename = f"dag-vis-{dag.dag_id}.gv"
if "filename" in kwargs:
filename = kwargs["filename"]

View File

@@ -1,12 +1,13 @@
"""Common operators of AWEL."""
import asyncio
import logging
from typing import Awaitable, Callable, Dict, Generic, List, Optional, Union
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Union
from ..dag.base import DAGContext
from ..task.base import (
IN,
OUT,
SKIP_DATA,
InputContext,
InputSource,
JoinFunc,
@@ -276,6 +277,11 @@ class InputOperator(BaseOperator, Generic[OUT]):
curr_task_ctx.set_task_output(task_output)
return task_output
@classmethod
def dummy_input(cls, dummy_data: Any = SKIP_DATA, **kwargs) -> "InputOperator[OUT]":
"""Create a dummy InputOperator with a given input value."""
return cls(input_source=InputSource.from_data(dummy_data), **kwargs)
class TriggerOperator(InputOperator[OUT], Generic[OUT]):
"""Operator node that triggers the DAG to run."""