feat(core): More AWEL operators and new prompt manager API (#972)
Co-authored-by: csunny <cfqsunny@163.com>
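The hunks shown below cover the AWEL core part of this change: JobManager becomes a DAGLifecycle, records a node-name-to-node-id map when a job is built from the end node, gains before_dag_run/after_dag_end callbacks that fan out to every node, and the default workflow runner threads a SystemApp from DAGVar into each operator it executes.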
@@ -1,7 +1,8 @@
+import asyncio
 from typing import List, Set, Optional, Dict
 import uuid
 import logging
-from ..dag.base import DAG
+from ..dag.base import DAG, DAGLifecycle
 
 from ..operator.base import BaseOperator, CALL_DATA
 
@@ -18,18 +19,20 @@ class DAGInstance:
         self._dag = dag
 
 
-class JobManager:
+class JobManager(DAGLifecycle):
     def __init__(
         self,
         root_nodes: List[BaseOperator],
        all_nodes: List[BaseOperator],
        end_node: BaseOperator,
        id2call_data: Dict[str, Dict],
+        node_name_to_ids: Dict[str, str],
    ) -> None:
        self._root_nodes = root_nodes
        self._all_nodes = all_nodes
        self._end_node = end_node
        self._id2node_data = id2call_data
+        self._node_name_to_ids = node_name_to_ids

    @staticmethod
    def build_from_end_node(
@@ -38,11 +41,31 @@ class JobManager:
         nodes = _build_from_end_node(end_node)
         root_nodes = _get_root_nodes(nodes)
         id2call_data = _save_call_data(root_nodes, call_data)
-        return JobManager(root_nodes, nodes, end_node, id2call_data)
+
+        node_name_to_ids = {}
+        for node in nodes:
+            if node.node_name is not None:
+                node_name_to_ids[node.node_name] = node.node_id
+
+        return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)
 
     def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
         return self._id2node_data.get(node_id)
 
+    async def before_dag_run(self):
+        """The callback before DAG run"""
+        tasks = []
+        for node in self._all_nodes:
+            tasks.append(node.before_dag_run())
+        await asyncio.gather(*tasks)
+
+    async def after_dag_end(self):
+        """The callback after DAG end"""
+        tasks = []
+        for node in self._all_nodes:
+            tasks.append(node.after_dag_end())
+        await asyncio.gather(*tasks)
+
 
 def _save_call_data(
     root_nodes: List[BaseOperator], call_data: CALL_DATA
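Aside (not part of the diff): the two new callbacks simply fan the DAG-level lifecycle out to every node concurrently via asyncio.gather. A minimal, self-contained sketch of the same pattern, using a hypothetical Node class rather than the real AWEL BaseOperator:

import asyncio
from typing import List


class Node:
    """Hypothetical stand-in for an AWEL operator with lifecycle hooks."""

    def __init__(self, name: str) -> None:
        self.name = name

    async def before_dag_run(self) -> None:
        print(f"{self.name}: before run")

    async def after_dag_end(self) -> None:
        print(f"{self.name}: after end")


async def run_lifecycle(nodes: List[Node]) -> None:
    # Fire every node's callback concurrently and wait for all of them,
    # mirroring how JobManager.before_dag_run/after_dag_end gather tasks.
    await asyncio.gather(*(n.before_dag_run() for n in nodes))
    # ... the workflow itself would execute here ...
    await asyncio.gather(*(n.after_dag_end() for n in nodes))


if __name__ == "__main__":
    asyncio.run(run_lifecycle([Node("a"), Node("b")]))

Gathering the coroutines keeps slow per-node setup or teardown from serializing across the whole DAG.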
@@ -66,6 +89,7 @@ def _save_call_data(
 
 
 def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
+    """Build all nodes from the end node."""
     nodes = []
     if isinstance(end_node, BaseOperator):
         task_id = end_node.node_id
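Aside (not part of the diff): the node_name_to_ids map that build_from_end_node now passes to JobManager is a plain filtered index over the traversed nodes; only operators that carry a node_name are recorded. An equivalent standalone version, with a hypothetical SimpleNode stand-in:

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class SimpleNode:
    # Hypothetical stand-in for an operator with an optional human-readable name.
    node_id: str
    node_name: Optional[str] = None


def build_name_index(nodes: List[SimpleNode]) -> Dict[str, str]:
    # Same shape as the loop in JobManager.build_from_end_node:
    # only named nodes are indexed, keyed by name, valued by node id.
    return {n.node_name: n.node_id for n in nodes if n.node_name is not None}


nodes = [SimpleNode("n-1", "loader"), SimpleNode("n-2"), SimpleNode("n-3", "summarizer")]
assert build_name_index(nodes) == {"loader": "n-1", "summarizer": "n-3"}

The remaining hunks patch the default workflow runner (DefaultWorkflowRunner).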
@@ -1,7 +1,8 @@
 from typing import Dict, Optional, Set, List
 import logging
 
-from ..dag.base import DAGContext
+from dbgpt.component import SystemApp
+from ..dag.base import DAGContext, DAGVar
 from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
 from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
 from ..task.base import TaskContext, TaskState
@@ -18,19 +19,29 @@ class DefaultWorkflowRunner(WorkflowRunner):
         call_data: Optional[CALL_DATA] = None,
         streaming_call: bool = False,
     ) -> DAGContext:
-        # Create DAG context
-        dag_ctx = DAGContext(streaming_call=streaming_call)
+        # Save node output
+        # dag = node.dag
+        node_outputs: Dict[str, TaskContext] = {}
         job_manager = JobManager.build_from_end_node(node, call_data)
+        # Create DAG context
+        dag_ctx = DAGContext(
+            streaming_call=streaming_call,
+            node_to_outputs=node_outputs,
+            node_name_to_ids=job_manager._node_name_to_ids,
+        )
         logger.info(
             f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
         )
-        dag = node.dag
-        # Save node output
-        node_outputs: Dict[str, TaskContext] = {}
         skip_node_ids = set()
+        system_app: SystemApp = DAGVar.get_current_system_app()
+
+        await job_manager.before_dag_run()
         await self._execute_node(
-            job_manager, node, dag_ctx, node_outputs, skip_node_ids
+            job_manager, node, dag_ctx, node_outputs, skip_node_ids, system_app
         )
+        if not streaming_call and node.dag:
+            # streaming call not work for dag end
+            await node.dag._after_dag_end()
 
         return dag_ctx
 
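Aside (not part of the diff): execute_workflow now builds the JobManager first (so the DAGContext can carry node_to_outputs and node_name_to_ids), runs the before_dag_run callbacks, executes the graph, and fires after_dag_end only for non-streaming calls. A compact sketch of that ordering with placeholder methods (TinyRunner is illustrative, not the real DefaultWorkflowRunner):

import asyncio


class TinyRunner:
    """Illustrative only: mirrors the ordering in the patched execute_workflow."""

    async def execute(self, streaming_call: bool = False) -> str:
        await self.before_dag_run()      # lifecycle fan-out first
        result = await self.run_graph()  # then the actual node execution
        if not streaming_call:
            # Streaming calls keep the DAG open, so the end callback is skipped,
            # matching `if not streaming_call and node.dag:` in the diff.
            await self.after_dag_end()
        return result

    async def before_dag_run(self) -> None:
        print("before_dag_run")

    async def run_graph(self) -> str:
        return "output"

    async def after_dag_end(self) -> None:
        print("after_dag_end")


if __name__ == "__main__":
    print(asyncio.run(TinyRunner().execute(streaming_call=False)))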
@@ -41,6 +52,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
         dag_ctx: DAGContext,
         node_outputs: Dict[str, TaskContext],
         skip_node_ids: Set[str],
+        system_app: SystemApp,
     ):
         # Skip run node
         if node.node_id in node_outputs:
@@ -50,7 +62,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
         for upstream_node in node.upstream:
             if isinstance(upstream_node, BaseOperator):
                 await self._execute_node(
-                    job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
+                    job_manager,
+                    upstream_node,
+                    dag_ctx,
+                    node_outputs,
+                    skip_node_ids,
+                    system_app,
                 )
 
         inputs = [
@@ -73,6 +90,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
             logger.debug(
                 f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
             )
+            if system_app is not None and node.system_app is None:
+                node.set_system_app(system_app)
+
             await node._run(dag_ctx)
             node_outputs[node.node_id] = dag_ctx.current_task_context
             task_ctx.set_current_state(TaskState.SUCCESS)
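Aside (not part of the diff): the new guard injects the runner's SystemApp into an operator only when the runner actually has one and the operator is not already bound to its own. A standalone sketch of that check, with hypothetical stubs in place of SystemApp and BaseOperator:

from typing import Optional


class SystemAppStub:
    """Hypothetical stand-in for dbgpt.component.SystemApp."""


class OperatorStub:
    def __init__(self, system_app: Optional[SystemAppStub] = None) -> None:
        self.system_app = system_app

    def set_system_app(self, system_app: SystemAppStub) -> None:
        self.system_app = system_app


def inject_system_app(node: OperatorStub, system_app: Optional[SystemAppStub]) -> None:
    # Same guard as the diff: only fill in the app when the runner has one
    # and the operator was not already bound to an app of its own.
    if system_app is not None and node.system_app is None:
        node.set_system_app(system_app)


shared_app = SystemAppStub()
node = OperatorStub()
inject_system_app(node, shared_app)
assert node.system_app is shared_app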