Mirror of https://github.com/csunny/DB-GPT.git
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
dbgpt/core/awel/runner/__init__.py (new file, 0 lines)
dbgpt/core/awel/runner/job_manager.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from typing import List, Set, Optional, Dict
import uuid
import logging

from ..dag.base import DAG
from ..operator.base import BaseOperator, CALL_DATA

logger = logging.getLogger(__name__)


class DAGNodeInstance:
    def __init__(self, node_instance: DAG) -> None:
        pass


class DAGInstance:
    def __init__(self, dag: DAG) -> None:
        self._dag = dag


class JobManager:
    def __init__(
        self,
        root_nodes: List[BaseOperator],
        all_nodes: List[BaseOperator],
        end_node: BaseOperator,
        id2call_data: Dict[str, Dict],
    ) -> None:
        self._root_nodes = root_nodes
        self._all_nodes = all_nodes
        self._end_node = end_node
        self._id2node_data = id2call_data

    @staticmethod
    def build_from_end_node(
        end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
    ) -> "JobManager":
        nodes = _build_from_end_node(end_node)
        root_nodes = _get_root_nodes(nodes)
        id2call_data = _save_call_data(root_nodes, call_data)
        return JobManager(root_nodes, nodes, end_node, id2call_data)

    def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
        return self._id2node_data.get(node_id)


def _save_call_data(
    root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
    id2call_data = {}
    logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
    if not call_data:
        return id2call_data
    if len(root_nodes) == 1:
        node = root_nodes[0]
        logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
        id2call_data[node.node_id] = call_data
    else:
        for node in root_nodes:
            node_id = node.node_id
            logger.info(
                f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
            )
            id2call_data[node_id] = call_data.get(node_id)
    return id2call_data


def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
    nodes = []
    if isinstance(end_node, BaseOperator):
        task_id = end_node.node_id
        if not task_id:
            task_id = str(uuid.uuid4())
            end_node.set_node_id(task_id)
        nodes.append(end_node)
    for node in end_node.upstream:
        nodes += _build_from_end_node(node)
    return nodes


def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
    return list(set(filter(lambda x: not x.upstream, nodes)))
dbgpt/core/awel/runner/local_runner.py (new file, 109 lines)
@@ -0,0 +1,109 @@
from typing import Dict, Optional, Set, List
import logging

from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager

logger = logging.getLogger(__name__)


class DefaultWorkflowRunner(WorkflowRunner):
    async def execute_workflow(
        self,
        node: BaseOperator,
        call_data: Optional[CALL_DATA] = None,
        streaming_call: bool = False,
    ) -> DAGContext:
        # Create the DAG context
        dag_ctx = DAGContext(streaming_call=streaming_call)
        job_manager = JobManager.build_from_end_node(node, call_data)
        logger.info(
            f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
        )
        dag = node.dag
        # Save the output of each node
        node_outputs: Dict[str, TaskContext] = {}
        skip_node_ids = set()
        await self._execute_node(
            job_manager, node, dag_ctx, node_outputs, skip_node_ids
        )

        return dag_ctx

    async def _execute_node(
        self,
        job_manager: JobManager,
        node: BaseOperator,
        dag_ctx: DAGContext,
        node_outputs: Dict[str, TaskContext],
        skip_node_ids: Set[str],
    ):
        # Skip nodes that have already run
        if node.node_id in node_outputs:
            return

        # Run all upstream nodes first
        for upstream_node in node.upstream:
            if isinstance(upstream_node, BaseOperator):
                await self._execute_node(
                    job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
                )

        inputs = [
            node_outputs[upstream_node.node_id] for upstream_node in node.upstream
        ]
        input_ctx = DefaultInputContext(inputs)
        task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
        task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))

        task_ctx.set_task_input(input_ctx)
        dag_ctx.set_current_task_context(task_ctx)
        task_ctx.set_current_state(TaskState.RUNNING)

        if node.node_id in skip_node_ids:
            task_ctx.set_current_state(TaskState.SKIP)
            task_ctx.set_task_output(SimpleTaskOutput(None))
            node_outputs[node.node_id] = task_ctx
            return
        try:
            logger.debug(
                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
            )
            await node._run(dag_ctx)
            node_outputs[node.node_id] = dag_ctx.current_task_context
            task_ctx.set_current_state(TaskState.SUCCESS)

            if isinstance(node, BranchOperator):
                skip_nodes = task_ctx.metadata.get("skip_node_names", [])
                logger.debug(
                    f"Current is branch operator, skip node names: {skip_nodes}"
                )
                _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
        except Exception as e:
            logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
            task_ctx.set_current_state(TaskState.FAILED)
            raise e


def _skip_current_downstream_by_node_name(
    branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
):
    if not skip_nodes:
        return
    for child in branch_node.downstream:
        if child.node_name in skip_nodes:
            logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
            _skip_downstream_by_id(child, skip_node_ids)


def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
    if isinstance(node, JoinOperator):
        # Never skip a join node
        return
    skip_node_ids.add(node.node_id)
    for child in node.downstream:
        _skip_downstream_by_id(child, skip_node_ids)
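The control flow of _execute_node is easiest to see in isolation. The sketch below mirrors its ordering guarantees with the same hypothetical Node stand-in: upstream nodes run first, results are memoized in an outputs map, and nodes in the skip set produce a placeholder output instead of running (the real code records TaskState.SKIP with a SimpleTaskOutput(None)):

import asyncio
from typing import Dict, List, Optional, Set


class Node:
    # Hypothetical stand-in for BaseOperator (illustration only).
    def __init__(self, node_id: str, upstream: Optional[List["Node"]] = None) -> None:
        self.node_id = node_id
        self.upstream = upstream or []

    async def run(self, inputs: List[int]) -> int:
        return sum(inputs) + 1


async def execute(node: Node, outputs: Dict[str, int], skip_ids: Set[str]) -> None:
    if node.node_id in outputs:
        return  # memoized, like the node_outputs check in _execute_node
    for up in node.upstream:
        await execute(up, outputs, skip_ids)  # depth-first: upstream first
    if node.node_id in skip_ids:
        outputs[node.node_id] = 0  # placeholder, like SimpleTaskOutput(None)
        return
    outputs[node.node_id] = await node.run(
        [outputs[up.node_id] for up in node.upstream]
    )


a = Node("a")
b = Node("b", [a])
c = Node("c", [a])
d = Node("d", [b, c])

outputs: Dict[str, int] = {}
asyncio.run(execute(d, outputs, skip_ids={"c"}))
print(outputs)  # {'a': 1, 'b': 2, 'c': 0, 'd': 3}

A skipped branch still leaves an entry in the outputs map, so a downstream JoinOperator keeps running with the placeholder as one of its inputs; that is also why _skip_downstream_by_id stops its recursion at join nodes.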