mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-04 02:25:08 +00:00
feat(core): More AWEL operators and new prompt manager API (#972)
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -11,7 +11,7 @@ from concurrent.futures import Executor
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from ..resource.base import ResourceGroup
|
||||
from ..task.base import TaskContext
|
||||
from ..task.base import TaskContext, TaskOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -168,7 +168,19 @@ class DAGVar:
|
||||
cls._executor = executor
|
||||
|
||||
|
||||
class DAGNode(DependencyMixin, ABC):
|
||||
class DAGLifecycle:
|
||||
"""The lifecycle of DAG"""
|
||||
|
||||
async def before_dag_run(self):
|
||||
"""The callback before DAG run"""
|
||||
pass
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
pass
|
||||
|
||||
|
||||
class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
resource_group: Optional[ResourceGroup] = None
|
||||
"""The resource group of current DAGNode"""
|
||||
|
||||
@@ -179,7 +191,7 @@ class DAGNode(DependencyMixin, ABC):
|
||||
node_name: Optional[str] = None,
|
||||
system_app: Optional[SystemApp] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._upstream: List["DAGNode"] = []
|
||||
@@ -198,10 +210,23 @@ class DAGNode(DependencyMixin, ABC):
|
||||
def node_id(self) -> str:
|
||||
return self._node_id
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dev_mode(self) -> bool:
|
||||
"""Whether current DAGNode is in dev mode"""
|
||||
|
||||
@property
|
||||
def system_app(self) -> SystemApp:
|
||||
return self._system_app
|
||||
|
||||
def set_system_app(self, system_app: SystemApp) -> None:
|
||||
"""Set system app for current DAGNode
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
"""
|
||||
self._system_app = system_app
|
||||
|
||||
def set_node_id(self, node_id: str) -> None:
|
||||
self._node_id = node_id
|
||||
|
||||
@@ -274,11 +299,41 @@ class DAGNode(DependencyMixin, ABC):
|
||||
node._upstream.append(self)
|
||||
|
||||
|
||||
def _build_task_key(task_name: str, key: str) -> str:
|
||||
return f"{task_name}___$$$$$$___{key}"
|
||||
|
||||
|
||||
class DAGContext:
|
||||
def __init__(self, streaming_call: bool = False) -> None:
|
||||
"""The context of current DAG, created when the DAG is running
|
||||
|
||||
Every DAG has been triggered will create a new DAGContext.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
streaming_call: bool = False,
|
||||
node_to_outputs: Dict[str, TaskContext] = None,
|
||||
node_name_to_ids: Dict[str, str] = None,
|
||||
) -> None:
|
||||
if not node_to_outputs:
|
||||
node_to_outputs = {}
|
||||
if not node_name_to_ids:
|
||||
node_name_to_ids = {}
|
||||
self._streaming_call = streaming_call
|
||||
self._curr_task_ctx = None
|
||||
self._share_data: Dict[str, Any] = {}
|
||||
self._node_to_outputs = node_to_outputs
|
||||
self._node_name_to_ids = node_name_to_ids
|
||||
|
||||
@property
|
||||
def _task_outputs(self) -> Dict[str, TaskContext]:
|
||||
"""The task outputs of current DAG
|
||||
|
||||
Just use for internal for now.
|
||||
Returns:
|
||||
Dict[str, TaskContext]: The task outputs of current DAG
|
||||
"""
|
||||
return self._node_to_outputs
|
||||
|
||||
@property
|
||||
def current_task_context(self) -> TaskContext:
|
||||
@@ -292,12 +347,69 @@ class DAGContext:
|
||||
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
|
||||
self._curr_task_ctx = _curr_task_ctx
|
||||
|
||||
async def get_share_data(self, key: str) -> Any:
|
||||
def get_task_output(self, task_name: str) -> TaskOutput:
|
||||
"""Get the task output by task name
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
|
||||
Returns:
|
||||
TaskOutput: The task output
|
||||
"""
|
||||
if task_name is None:
|
||||
raise ValueError("task_name can't be None")
|
||||
node_id = self._node_name_to_ids.get(task_name)
|
||||
if node_id:
|
||||
raise ValueError(f"Task name {task_name} not exists in DAG")
|
||||
return self._task_outputs.get(node_id).task_output
|
||||
|
||||
async def get_from_share_data(self, key: str) -> Any:
|
||||
return self._share_data.get(key)
|
||||
|
||||
async def save_to_share_data(self, key: str, data: Any) -> None:
|
||||
async def save_to_share_data(
|
||||
self, key: str, data: Any, overwrite: Optional[str] = None
|
||||
) -> None:
|
||||
if key in self._share_data and not overwrite:
|
||||
raise ValueError(f"Share data key {key} already exists")
|
||||
self._share_data[key] = data
|
||||
|
||||
async def get_task_share_data(self, task_name: str, key: str) -> Any:
|
||||
"""Get share data by task name and key
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
key (str): The share data key
|
||||
|
||||
Returns:
|
||||
Any: The share data
|
||||
"""
|
||||
if task_name is None:
|
||||
raise ValueError("task_name can't be None")
|
||||
if key is None:
|
||||
raise ValueError("key can't be None")
|
||||
return self.get_from_share_data(_build_task_key(task_name, key))
|
||||
|
||||
async def save_task_share_data(
|
||||
self, task_name: str, key: str, data: Any, overwrite: Optional[str] = None
|
||||
) -> None:
|
||||
"""Save share data by task name and key
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
key (str): The share data key
|
||||
data (Any): The share data
|
||||
overwrite (Optional[str], optional): Whether overwrite the share data if the key already exists.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the share data key already exists and overwrite is not True
|
||||
"""
|
||||
if task_name is None:
|
||||
raise ValueError("task_name can't be None")
|
||||
if key is None:
|
||||
raise ValueError("key can't be None")
|
||||
await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite)
|
||||
|
||||
|
||||
class DAG:
|
||||
def __init__(
|
||||
@@ -305,11 +417,20 @@ class DAG:
|
||||
) -> None:
|
||||
self._dag_id = dag_id
|
||||
self.node_map: Dict[str, DAGNode] = {}
|
||||
self._root_nodes: Set[DAGNode] = None
|
||||
self._leaf_nodes: Set[DAGNode] = None
|
||||
self._trigger_nodes: Set[DAGNode] = None
|
||||
self.node_name_to_node: Dict[str, DAGNode] = {}
|
||||
self._root_nodes: List[DAGNode] = None
|
||||
self._leaf_nodes: List[DAGNode] = None
|
||||
self._trigger_nodes: List[DAGNode] = None
|
||||
|
||||
def _append_node(self, node: DAGNode) -> None:
|
||||
if node.node_id in self.node_map:
|
||||
return
|
||||
if node.node_name:
|
||||
if node.node_name in self.node_name_to_node:
|
||||
raise ValueError(
|
||||
f"Node name {node.node_name} already exists in DAG {self.dag_id}"
|
||||
)
|
||||
self.node_name_to_node[node.node_name] = node
|
||||
self.node_map[node.node_id] = node
|
||||
# clear cached nodes
|
||||
self._root_nodes = None
|
||||
@@ -336,22 +457,44 @@ class DAG:
|
||||
|
||||
@property
|
||||
def root_nodes(self) -> List[DAGNode]:
|
||||
"""The root nodes of current DAG
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The root nodes of current DAG, no repeat
|
||||
"""
|
||||
if not self._root_nodes:
|
||||
self._build()
|
||||
return self._root_nodes
|
||||
|
||||
@property
|
||||
def leaf_nodes(self) -> List[DAGNode]:
|
||||
"""The leaf nodes of current DAG
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The leaf nodes of current DAG, no repeat
|
||||
"""
|
||||
if not self._leaf_nodes:
|
||||
self._build()
|
||||
return self._leaf_nodes
|
||||
|
||||
@property
|
||||
def trigger_nodes(self):
|
||||
def trigger_nodes(self) -> List[DAGNode]:
|
||||
"""The trigger nodes of current DAG
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The trigger nodes of current DAG, no repeat
|
||||
"""
|
||||
if not self._trigger_nodes:
|
||||
self._build()
|
||||
return self._trigger_nodes
|
||||
|
||||
async def _after_dag_end(self) -> None:
|
||||
"""The callback after DAG end"""
|
||||
tasks = []
|
||||
for node in self.node_map.values():
|
||||
tasks.append(node.after_dag_end())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def __enter__(self):
|
||||
DAGVar.enter_dag(self)
|
||||
return self
|
||||
|
Reference in New Issue
Block a user