feat(core): More AWEL operators and new prompt manager API (#972)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Fangyin Cheng
2023-12-25 20:03:22 +08:00
committed by GitHub
parent 048fb6c402
commit 69fb97e508
46 changed files with 2556 additions and 294 deletions

View File

@@ -1,7 +1,8 @@
import asyncio
from typing import List, Set, Optional, Dict
import uuid
import logging
from ..dag.base import DAG
from ..dag.base import DAG, DAGLifecycle
from ..operator.base import BaseOperator, CALL_DATA
@@ -18,18 +19,20 @@ class DAGInstance:
self._dag = dag
class JobManager:
class JobManager(DAGLifecycle):
def __init__(
self,
root_nodes: List[BaseOperator],
all_nodes: List[BaseOperator],
end_node: BaseOperator,
id2call_data: Dict[str, Dict],
node_name_to_ids: Dict[str, str],
) -> None:
self._root_nodes = root_nodes
self._all_nodes = all_nodes
self._end_node = end_node
self._id2node_data = id2call_data
self._node_name_to_ids = node_name_to_ids
@staticmethod
def build_from_end_node(
@@ -38,11 +41,31 @@ class JobManager:
nodes = _build_from_end_node(end_node)
root_nodes = _get_root_nodes(nodes)
id2call_data = _save_call_data(root_nodes, call_data)
return JobManager(root_nodes, nodes, end_node, id2call_data)
node_name_to_ids = {}
for node in nodes:
if node.node_name is not None:
node_name_to_ids[node.node_name] = node.node_id
return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
return self._id2node_data.get(node_id)
async def before_dag_run(self):
"""The callback before DAG run"""
tasks = []
for node in self._all_nodes:
tasks.append(node.before_dag_run())
await asyncio.gather(*tasks)
async def after_dag_end(self):
"""The callback after DAG end"""
tasks = []
for node in self._all_nodes:
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)
def _save_call_data(
root_nodes: List[BaseOperator], call_data: CALL_DATA
@@ -66,6 +89,7 @@ def _save_call_data(
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
"""Build all nodes from the end node."""
nodes = []
if isinstance(end_node, BaseOperator):
task_id = end_node.node_id

View File

@@ -1,7 +1,8 @@
from typing import Dict, Optional, Set, List
import logging
from ..dag.base import DAGContext
from dbgpt.component import SystemApp
from ..dag.base import DAGContext, DAGVar
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
@@ -18,19 +19,29 @@ class DefaultWorkflowRunner(WorkflowRunner):
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
) -> DAGContext:
# Create DAG context
dag_ctx = DAGContext(streaming_call=streaming_call)
# Save node output
# dag = node.dag
node_outputs: Dict[str, TaskContext] = {}
job_manager = JobManager.build_from_end_node(node, call_data)
# Create DAG context
dag_ctx = DAGContext(
streaming_call=streaming_call,
node_to_outputs=node_outputs,
node_name_to_ids=job_manager._node_name_to_ids,
)
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
)
dag = node.dag
# Save node output
node_outputs: Dict[str, TaskContext] = {}
skip_node_ids = set()
system_app: SystemApp = DAGVar.get_current_system_app()
await job_manager.before_dag_run()
await self._execute_node(
job_manager, node, dag_ctx, node_outputs, skip_node_ids
job_manager, node, dag_ctx, node_outputs, skip_node_ids, system_app
)
if not streaming_call and node.dag:
# streaming call not work for dag end
await node.dag._after_dag_end()
return dag_ctx
@@ -41,6 +52,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
dag_ctx: DAGContext,
node_outputs: Dict[str, TaskContext],
skip_node_ids: Set[str],
system_app: SystemApp,
):
# Skip run node
if node.node_id in node_outputs:
@@ -50,7 +62,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
for upstream_node in node.upstream:
if isinstance(upstream_node, BaseOperator):
await self._execute_node(
job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
job_manager,
upstream_node,
dag_ctx,
node_outputs,
skip_node_ids,
system_app,
)
inputs = [
@@ -73,6 +90,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
logger.debug(
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
)
if system_app is not None and node.system_app is None:
node.set_system_app(system_app)
await node._run(dag_ctx)
node_outputs[node.node_id] = dag_ctx.current_task_context
task_ctx.set_current_state(TaskState.SUCCESS)