feat(core): More AWEL operators and new prompt manager API (#972)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Fangyin Cheng
2023-12-25 20:03:22 +08:00
committed by GitHub
parent 048fb6c402
commit 69fb97e508
46 changed files with 2556 additions and 294 deletions

View File

@@ -11,7 +11,7 @@ from concurrent.futures import Executor
from dbgpt.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext
from ..task.base import TaskContext, TaskOutput
logger = logging.getLogger(__name__)
@@ -168,7 +168,19 @@ class DAGVar:
cls._executor = executor
class DAGNode(DependencyMixin, ABC):
class DAGLifecycle:
"""The lifecycle of DAG"""
async def before_dag_run(self):
"""The callback before DAG run"""
pass
async def after_dag_end(self):
"""The callback after DAG end"""
pass
class DAGNode(DAGLifecycle, DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
@@ -179,7 +191,7 @@ class DAGNode(DependencyMixin, ABC):
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
**kwargs
**kwargs,
) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
@@ -198,10 +210,23 @@ class DAGNode(DependencyMixin, ABC):
def node_id(self) -> str:
return self._node_id
@property
@abstractmethod
def dev_mode(self) -> bool:
"""Whether current DAGNode is in dev mode"""
@property
def system_app(self) -> SystemApp:
return self._system_app
def set_system_app(self, system_app: SystemApp) -> None:
"""Set system app for current DAGNode
Args:
system_app (SystemApp): The system app
"""
self._system_app = system_app
def set_node_id(self, node_id: str) -> None:
self._node_id = node_id
@@ -274,11 +299,41 @@ class DAGNode(DependencyMixin, ABC):
node._upstream.append(self)
def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"
class DAGContext:
def __init__(self, streaming_call: bool = False) -> None:
"""The context of current DAG, created when the DAG is running
Every DAG has been triggered will create a new DAGContext.
"""
def __init__(
self,
streaming_call: bool = False,
node_to_outputs: Dict[str, TaskContext] = None,
node_name_to_ids: Dict[str, str] = None,
) -> None:
if not node_to_outputs:
node_to_outputs = {}
if not node_name_to_ids:
node_name_to_ids = {}
self._streaming_call = streaming_call
self._curr_task_ctx = None
self._share_data: Dict[str, Any] = {}
self._node_to_outputs = node_to_outputs
self._node_name_to_ids = node_name_to_ids
@property
def _task_outputs(self) -> Dict[str, TaskContext]:
"""The task outputs of current DAG
Just use for internal for now.
Returns:
Dict[str, TaskContext]: The task outputs of current DAG
"""
return self._node_to_outputs
@property
def current_task_context(self) -> TaskContext:
@@ -292,12 +347,69 @@ class DAGContext:
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
self._curr_task_ctx = _curr_task_ctx
async def get_share_data(self, key: str) -> Any:
def get_task_output(self, task_name: str) -> TaskOutput:
"""Get the task output by task name
Args:
task_name (str): The task name
Returns:
TaskOutput: The task output
"""
if task_name is None:
raise ValueError("task_name can't be None")
node_id = self._node_name_to_ids.get(task_name)
if node_id:
raise ValueError(f"Task name {task_name} not exists in DAG")
return self._task_outputs.get(node_id).task_output
async def get_from_share_data(self, key: str) -> Any:
return self._share_data.get(key)
async def save_to_share_data(self, key: str, data: Any) -> None:
async def save_to_share_data(
self, key: str, data: Any, overwrite: Optional[str] = None
) -> None:
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
self._share_data[key] = data
async def get_task_share_data(self, task_name: str, key: str) -> Any:
"""Get share data by task name and key
Args:
task_name (str): The task name
key (str): The share data key
Returns:
Any: The share data
"""
if task_name is None:
raise ValueError("task_name can't be None")
if key is None:
raise ValueError("key can't be None")
return self.get_from_share_data(_build_task_key(task_name, key))
async def save_task_share_data(
self, task_name: str, key: str, data: Any, overwrite: Optional[str] = None
) -> None:
"""Save share data by task name and key
Args:
task_name (str): The task name
key (str): The share data key
data (Any): The share data
overwrite (Optional[str], optional): Whether overwrite the share data if the key already exists.
Defaults to None.
Raises:
ValueError: If the share data key already exists and overwrite is not True
"""
if task_name is None:
raise ValueError("task_name can't be None")
if key is None:
raise ValueError("key can't be None")
await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite)
class DAG:
def __init__(
@@ -305,11 +417,20 @@ class DAG:
) -> None:
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {}
self._root_nodes: Set[DAGNode] = None
self._leaf_nodes: Set[DAGNode] = None
self._trigger_nodes: Set[DAGNode] = None
self.node_name_to_node: Dict[str, DAGNode] = {}
self._root_nodes: List[DAGNode] = None
self._leaf_nodes: List[DAGNode] = None
self._trigger_nodes: List[DAGNode] = None
def _append_node(self, node: DAGNode) -> None:
if node.node_id in self.node_map:
return
if node.node_name:
if node.node_name in self.node_name_to_node:
raise ValueError(
f"Node name {node.node_name} already exists in DAG {self.dag_id}"
)
self.node_name_to_node[node.node_name] = node
self.node_map[node.node_id] = node
# clear cached nodes
self._root_nodes = None
@@ -336,22 +457,44 @@ class DAG:
@property
def root_nodes(self) -> List[DAGNode]:
"""The root nodes of current DAG
Returns:
List[DAGNode]: The root nodes of current DAG, no repeat
"""
if not self._root_nodes:
self._build()
return self._root_nodes
@property
def leaf_nodes(self) -> List[DAGNode]:
"""The leaf nodes of current DAG
Returns:
List[DAGNode]: The leaf nodes of current DAG, no repeat
"""
if not self._leaf_nodes:
self._build()
return self._leaf_nodes
@property
def trigger_nodes(self):
def trigger_nodes(self) -> List[DAGNode]:
"""The trigger nodes of current DAG
Returns:
List[DAGNode]: The trigger nodes of current DAG, no repeat
"""
if not self._trigger_nodes:
self._build()
return self._trigger_nodes
async def _after_dag_end(self) -> None:
"""The callback after DAG end"""
tasks = []
for node in self.node_map.values():
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)
def __enter__(self):
DAGVar.enter_dag(self)
return self