"""Local runner for workflow.
This runner will run the workflow in the current process.
"""
import asyncio
import logging
import traceback
from typing import Any, Dict, List, Optional, Set, cast

from dbgpt.component import SystemApp
from dbgpt.util.tracer import root_tracer

from ..dag.base import DAGContext, DAGVar, DAGVariables
from ..operators.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operators.common_operator import BranchOperator
from ..task.base import SKIP_DATA, TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager

logger = logging.getLogger(__name__)

class DefaultWorkflowRunner(WorkflowRunner):
    """The default workflow runner."""

    def __init__(self):
        """Init the default workflow runner."""
        self._running_dag_ctx: Dict[str, DAGContext] = {}
        self._task_log_index_map: Dict[str, int] = {}
        self._lock = asyncio.Lock()

    async def _log_task(self, task_id: str) -> int:
        """Return the next log index for the task, incremented once per call."""
        async with self._lock:
            if task_id not in self._task_log_index_map:
                self._task_log_index_map[task_id] = 0
            self._task_log_index_map[task_id] += 1
            logger.debug(
                f"Task {task_id} log index {self._task_log_index_map[task_id]}"
            )
            return self._task_log_index_map[task_id]
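
    # A minimal behavior sketch (hypothetical task ids, not from this module):
    #
    #     runner = DefaultWorkflowRunner()
    #     await runner._log_task("node-1")  # -> 1
    #     await runner._log_task("node-1")  # -> 2 (incremented per call)
    #     await runner._log_task("node-2")  # -> 1 (counters are per task id)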

    async def execute_workflow(
        self,
        node: BaseOperator,
        call_data: Optional[CALL_DATA] = None,
        streaming_call: bool = False,
        exist_dag_ctx: Optional[DAGContext] = None,
        dag_variables: Optional[DAGVariables] = None,
    ) -> DAGContext:
        """Execute the workflow.

        Args:
            node (BaseOperator): The end node of the workflow.
            call_data (Optional[CALL_DATA], optional): The call data of the end node.
                Defaults to None.
            streaming_call (bool, optional): Whether the call is a streaming call.
                Defaults to False.
            exist_dag_ctx (Optional[DAGContext], optional): The existing DAG context.
                Defaults to None.
            dag_variables (Optional[DAGVariables], optional): The DAG variables.
                Defaults to None.

        Returns:
            DAGContext: The DAG context after the workflow run.
        """
        # Save node output
        # dag = node.dag
        job_manager = JobManager.build_from_end_node(node, call_data)
        if not exist_dag_ctx:
            # Create a new DAG context
            node_outputs: Dict[str, TaskContext] = {}
            share_data: Dict[str, Any] = {}
            event_loop_task_id = id(asyncio.current_task())
        else:
            # Share node outputs with the existing DAG context
            node_outputs = exist_dag_ctx._node_to_outputs
            share_data = exist_dag_ctx._share_data
            event_loop_task_id = exist_dag_ctx._event_loop_task_id
            if dag_variables and exist_dag_ctx._dag_variables:
                # Merge DAG variables, preferring the `dag_variables` parameter
                dag_variables = dag_variables.merge(exist_dag_ctx._dag_variables)
        if node.dag and not dag_variables and node.dag._default_dag_variables:
            # Use the default DAG variables if none are set
            dag_variables = node.dag._default_dag_variables
        dag_ctx = DAGContext(
            event_loop_task_id=event_loop_task_id,
            node_to_outputs=node_outputs,
            share_data=share_data,
            streaming_call=streaming_call,
            node_name_to_ids=job_manager._node_name_to_ids,
            dag_variables=dag_variables,
        )
        # if node.dag:
        #     self._running_dag_ctx[node.dag.dag_id] = dag_ctx
        logger.info(
            f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
        )
        logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
        skip_node_ids: Set[str] = set()
        system_app: Optional[SystemApp] = DAGVar.get_current_system_app()
        if node.dag:
            # Save the DAG context
            await node.dag._save_dag_ctx(dag_ctx)
        await job_manager.before_dag_run()

        with root_tracer.start_span(
            "dbgpt.awel.workflow.run_workflow",
            metadata={
                "exist_dag_ctx": exist_dag_ctx is not None,
                "event_loop_task_id": event_loop_task_id,
                "streaming_call": streaming_call,
                "awel_node_id": node.node_id,
                "awel_node_name": node.node_name,
            },
        ):
            await self._execute_node(
                job_manager, node, dag_ctx, node_outputs, skip_node_ids, system_app
            )
        if not streaming_call and node.dag and exist_dag_ctx is None:
            # A streaming call does not end the DAG here. If exist_dag_ctx is not
            # None, the current DAG is a sub-DAG of an outer run.
            await node.dag._after_dag_end(dag_ctx._event_loop_task_id)
        # if node.dag:
        #     del self._running_dag_ctx[node.dag.dag_id]
        return dag_ctx
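
    # A minimal usage sketch (assumes `DAG` and `MapOperator` are exported from
    # dbgpt.core.awel; the DAG and call data below are hypothetical, and the
    # `{"data": ...}` wrapping mirrors how `BaseOperator.call` appears to wrap
    # raw call data before delegating to this runner -- an assumption, not a
    # confirmed contract):
    #
    #     from dbgpt.core.awel import DAG, MapOperator
    #
    #     with DAG("example_dag"):
    #         end_node = MapOperator(map_function=lambda x: x + 1)
    #
    #     runner = DefaultWorkflowRunner()
    #     dag_ctx = await runner.execute_workflow(end_node, call_data={"data": 1})
    #     output = dag_ctx.current_task_context.task_output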

    async def _execute_node(
        self,
        job_manager: JobManager,
        node: BaseOperator,
        dag_ctx: DAGContext,
        node_outputs: Dict[str, TaskContext],
        skip_node_ids: Set[str],
        system_app: Optional[SystemApp],
    ):
        """Run the node after recursively running all of its upstream nodes."""
        # Skip nodes that have already run
        if node.node_id in node_outputs:
            return
        # Run all upstream nodes first
        # TODO: run in parallel; some code has to be changed for that:
        #   dag_ctx.set_current_task_context(task_ctx)
        for upstream_node in node.upstream:
            if isinstance(upstream_node, BaseOperator):
                await self._execute_node(
                    job_manager,
                    upstream_node,
                    dag_ctx,
                    node_outputs,
                    skip_node_ids,
                    system_app,
                )
        inputs = [
            node_outputs[upstream_node.node_id] for upstream_node in node.upstream
        ]
        input_ctx = DefaultInputContext(inputs)
        # Log the task and get its log index (incremented on every call)
        log_index = await self._log_task(node.node_id)
        task_ctx: DefaultTaskContext = DefaultTaskContext(
            node.node_id, TaskState.INIT, task_output=None, log_index=log_index
        )
        current_call_data = job_manager.get_call_data_by_id(node.node_id)
        if current_call_data:
            task_ctx.set_call_data(current_call_data)

        task_ctx.set_task_input(input_ctx)
        dag_ctx.set_current_task_context(task_ctx)
        task_ctx.set_current_state(TaskState.RUNNING)

        if node.node_id in skip_node_ids:
            task_ctx.set_current_state(TaskState.SKIP)
            task_ctx.set_task_output(SimpleTaskOutput(SKIP_DATA))
            node_outputs[node.node_id] = task_ctx
            return
        try:
            logger.debug(
                f"Begin run operator, node id: {node.node_id}, node name: "
                f"{node.node_name}, cls: {node}"
            )
            if system_app is not None and node.system_app is None:
                node.set_system_app(system_app)
            run_metadata = {
                "awel_node_id": node.node_id,
                "awel_node_name": node.node_name,
                "awel_node_type": str(node),
                "state": TaskState.RUNNING.value,
                "task_log_id": task_ctx.log_id,
            }
            with root_tracer.start_span(
                "dbgpt.awel.workflow.run_operator", metadata=run_metadata
            ) as span:
                await node._run(dag_ctx, task_ctx.log_id)
                node_outputs[node.node_id] = dag_ctx.current_task_context
                task_ctx.set_current_state(TaskState.SUCCESS)

                run_metadata["skip_node_ids"] = ",".join(skip_node_ids)
                run_metadata["state"] = TaskState.SUCCESS.value
                span.metadata = run_metadata
            if isinstance(node, BranchOperator):
                skip_nodes = task_ctx.metadata.get("skip_node_names", [])
                logger.debug(
                    f"Current is branch operator, skip node names: {skip_nodes}"
                )
                _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
        except Exception as e:
            msg = traceback.format_exc()
            logger.info(
                f"Run operator {type(node)}({node.node_id}) error, error message: "
                f"{msg}"
            )
            task_ctx.set_current_state(TaskState.FAILED)
            raise e
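

# A traversal sketch for `_execute_node` (hypothetical diamond DAG; assumes
# `MapOperator`, `JoinOperator`, and `>>` composition from dbgpt.core.awel):
#
#     with DAG("diamond"):
#         a = MapOperator(lambda x: x)
#         b = MapOperator(lambda x: x + 1)
#         c = MapOperator(lambda x: x * 2)
#         d = JoinOperator(combine_function=lambda x, y: x + y)
#         a >> b >> d
#         a >> c >> d
#
# Calling `_execute_node` on `d` recurses into its upstream nodes first, so the
# effective run order is a, b, c, d. Finished nodes are memoized in
# `node_outputs`, so `a` runs only once even though both `b` and `c` depend
# on it.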


def _skip_current_downstream_by_node_name(
    branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
):
    """Mark the branch node's downstream nodes named in `skip_nodes` as skipped."""
    if not skip_nodes:
        return
    for child in branch_node.downstream:
        child = cast(BaseOperator, child)
        if child.node_name in skip_nodes or child.node_id in skip_node_ids:
            logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
            _skip_downstream_by_id(child, skip_node_ids)


def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
    """Recursively add the node and its downstream nodes to the skip set."""
    if not node.can_skip_in_branch():
        # The current node can not be skipped, so stop here and do not skip
        # its downstream nodes either
        return
    skip_node_ids.add(node.node_id)
    for child in node.downstream:
        child = cast(BaseOperator, child)
        _skip_downstream_by_id(child, skip_node_ids)
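
# A worked sketch of the skip propagation (hypothetical node names): suppose a
# branch node has downstream children "even_path" and "odd_path", and the
# branch decides skip_nodes = ["odd_path"]. Then
# `_skip_current_downstream_by_node_name` matches "odd_path", and
# `_skip_downstream_by_id` adds it and its whole downstream subtree to
# `skip_node_ids`, stopping at any node whose `can_skip_in_branch()` returns
# False (for example, a join node that must still run to merge the surviving
# branch).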