refactor: The first refactored version for sdk release (#907)

Author: FangYin Cheng (committed by GitHub)
Date: 2023-12-08 14:45:59 +08:00
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

@@ -0,0 +1,82 @@
from typing import List, Optional, Dict
import uuid
import logging

from ..dag.base import DAG
from ..operator.base import BaseOperator, CALL_DATA

logger = logging.getLogger(__name__)

class DAGNodeInstance:
    # Placeholder for the runtime instance of a DAG node (not implemented yet).
    def __init__(self, node_instance: DAG) -> None:
        pass


class DAGInstance:
    # Wraps the DAG that a single execution runs against.
    def __init__(self, dag: DAG) -> None:
        self._dag = dag

class JobManager:
    """Collects the nodes of a workflow and the call data for its root nodes."""

    def __init__(
        self,
        root_nodes: List[BaseOperator],
        all_nodes: List[BaseOperator],
        end_node: BaseOperator,
        id2call_data: Dict[str, Dict],
    ) -> None:
        self._root_nodes = root_nodes
        self._all_nodes = all_nodes
        self._end_node = end_node
        self._id2node_data = id2call_data

    @staticmethod
    def build_from_end_node(
        end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
    ) -> "JobManager":
        """Build a JobManager by walking upstream from the end node."""
        nodes = _build_from_end_node(end_node)
        root_nodes = _get_root_nodes(nodes)
        id2call_data = _save_call_data(root_nodes, call_data)
        return JobManager(root_nodes, nodes, end_node, id2call_data)

    def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
        """Return the call data attached to the given node, if any."""
        return self._id2node_data.get(node_id)

def _save_call_data(
    root_nodes: List[BaseOperator], call_data: Optional[CALL_DATA]
) -> Dict[str, Dict]:
    """Map each root node id to its call data.

    A single root node receives the whole call data; with multiple root
    nodes, the call data must be a dict keyed by node id.
    """
    id2call_data = {}
    logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
    if not call_data:
        return id2call_data
    if len(root_nodes) == 1:
        node = root_nodes[0]
        logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
        id2call_data[node.node_id] = call_data
    else:
        for node in root_nodes:
            node_id = node.node_id
            logger.info(
                f"Save call data to node {node_id}, call_data: {call_data.get(node_id)}"
            )
            id2call_data[node_id] = call_data.get(node_id)
    return id2call_data
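
# Worked example (hypothetical node ids): with two root nodes "a" and "b",
#   call_data = {"a": {"data": 1}, "b": {"data": 2}}
# attaches {"data": 1} to "a" and {"data": 2} to "b"; with a single root node,
# the whole call_data dict is attached to that root unchanged.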

def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
    """Collect the end node and all of its transitive upstream nodes."""
    nodes = []
    if isinstance(end_node, BaseOperator):
        task_id = end_node.node_id
        if not task_id:
            # Assign an id so the node can be addressed in call data and outputs
            task_id = str(uuid.uuid4())
            end_node.set_node_id(task_id)
        nodes.append(end_node)
    for node in end_node.upstream:
        nodes += _build_from_end_node(node)
    return nodes


def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
    # Root nodes are the ones without any upstream dependencies
    return list(set(filter(lambda x: not x.upstream, nodes)))
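
For illustration, a minimal usage sketch of JobManager (not part of this commit; `MyOperator` is a hypothetical BaseOperator subclass and the call-data shape is an assumption):

    # Hypothetical single-node workflow: the end node is also the only root.
    end_node = MyOperator()
    manager = JobManager.build_from_end_node(end_node, call_data={"data": 42})
    # With a single root node, the whole call data is attached to it:
    print(manager.get_call_data_by_id(end_node.node_id))  # -> {'data': 42}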

@@ -0,0 +1,109 @@
from typing import Dict, Optional, Set, List
import logging

from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager

logger = logging.getLogger(__name__)

class DefaultWorkflowRunner(WorkflowRunner):
    async def execute_workflow(
        self,
        node: BaseOperator,
        call_data: Optional[CALL_DATA] = None,
        streaming_call: bool = False,
    ) -> DAGContext:
        """Run the whole workflow that ends at `node` and return the DAG context."""
        # Create the DAG context shared by all operators in this run
        dag_ctx = DAGContext(streaming_call=streaming_call)
        job_manager = JobManager.build_from_end_node(node, call_data)
        logger.info(
            f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
        )
        # Outputs of already executed nodes, keyed by node id
        node_outputs: Dict[str, TaskContext] = {}
        skip_node_ids: Set[str] = set()
        await self._execute_node(
            job_manager, node, dag_ctx, node_outputs, skip_node_ids
        )
        return dag_ctx

    async def _execute_node(
        self,
        job_manager: JobManager,
        node: BaseOperator,
        dag_ctx: DAGContext,
        node_outputs: Dict[str, TaskContext],
        skip_node_ids: Set[str],
    ) -> None:
        # Node already executed, nothing to do
        if node.node_id in node_outputs:
            return
        # Execute all upstream nodes first (depth-first)
        for upstream_node in node.upstream:
            if isinstance(upstream_node, BaseOperator):
                await self._execute_node(
                    job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
                )
        inputs = [
            node_outputs[upstream_node.node_id] for upstream_node in node.upstream
        ]
        input_ctx = DefaultInputContext(inputs)
        task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
        task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
        task_ctx.set_task_input(input_ctx)
        dag_ctx.set_current_task_context(task_ctx)
        task_ctx.set_current_state(TaskState.RUNNING)
        if node.node_id in skip_node_ids:
            # The node was pruned by an upstream branch decision
            task_ctx.set_current_state(TaskState.SKIP)
            task_ctx.set_task_output(SimpleTaskOutput(None))
            node_outputs[node.node_id] = task_ctx
            return
        try:
            logger.debug(
                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
            )
            await node._run(dag_ctx)
            node_outputs[node.node_id] = dag_ctx.current_task_context
            task_ctx.set_current_state(TaskState.SUCCESS)
            if isinstance(node, BranchOperator):
                # A branch operator reports which downstream branches to prune
                skip_nodes = task_ctx.metadata.get("skip_node_names", [])
                logger.debug(
                    f"Current is branch operator, skip node names: {skip_nodes}"
                )
                _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
        except Exception as e:
            logger.warning(f"Run operator {node.node_id} failed, error message: {str(e)}")
            task_ctx.set_current_state(TaskState.FAILED)
            raise

def _skip_current_downstream_by_node_name(
    branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
) -> None:
    if not skip_nodes:
        return
    for child in branch_node.downstream:
        if child.node_name in skip_nodes:
            logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
            _skip_downstream_by_id(child, skip_node_ids)


def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]) -> None:
    if isinstance(node, JoinOperator):
        # Do not skip join nodes: they may still receive inputs from other branches
        return
    skip_node_ids.add(node.node_id)
    for child in node.downstream:
        _skip_downstream_by_id(child, skip_node_ids)
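
For illustration, a minimal end-to-end sketch of the runner (not part of this commit; `DoubleOperator` is a hypothetical operator subclass, and only APIs that appear in this commit are used):

    import asyncio

    async def main():
        end_node = DoubleOperator()  # hypothetical BaseOperator subclass
        runner = DefaultWorkflowRunner()
        dag_ctx = await runner.execute_workflow(end_node, call_data={"data": 2})
        # The TaskContext of the last executed operator is on the DAG context
        print(dag_ctx.current_task_context)

    asyncio.run(main())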