Mirror of https://github.com/csunny/DB-GPT.git

chore: Add pylint for DB-GPT core lib (#1076)
dbgpt/core/awel/runner/__init__.py
@@ -0,0 +1,4 @@
+"""The module to run AWEL operators.
+
+You can implement your own runner by inheriting the `WorkflowRunner` class.
+"""
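The new docstring names the extension seam: runners subclass `WorkflowRunner`. As a minimal sketch (not part of this commit), a custom runner could wrap the default one; the absolute import paths are assumed from the relative imports in this diff, and `TimingWorkflowRunner` is a hypothetical name.

import logging
import time
from typing import Optional

from dbgpt.core.awel.dag.base import DAGContext
from dbgpt.core.awel.operator.base import CALL_DATA, BaseOperator
from dbgpt.core.awel.runner.local_runner import DefaultWorkflowRunner

logger = logging.getLogger(__name__)


class TimingWorkflowRunner(DefaultWorkflowRunner):
    """Hypothetical runner that logs how long each workflow run takes."""

    async def execute_workflow(
        self,
        node: BaseOperator,
        call_data: Optional[CALL_DATA] = None,
        streaming_call: bool = False,
        exist_dag_ctx: Optional[DAGContext] = None,
    ) -> DAGContext:
        started = time.monotonic()
        try:
            return await super().execute_workflow(
                node, call_data, streaming_call, exist_dag_ctx
            )
        finally:
            elapsed = time.monotonic() - started
            logger.info("Workflow from %s took %.3fs", node.node_id, elapsed)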
dbgpt/core/awel/runner/job_manager.py
@@ -1,33 +1,38 @@
+"""Job manager for DAG."""
 import asyncio
 import logging
 import uuid
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, cast

-from ..dag.base import DAG, DAGLifecycle
+from ..dag.base import DAGLifecycle
 from ..operator.base import CALL_DATA, BaseOperator

 logger = logging.getLogger(__name__)


-class DAGNodeInstance:
-    def __init__(self, node_instance: DAG) -> None:
-        pass
-
-
-class DAGInstance:
-    def __init__(self, dag: DAG) -> None:
-        self._dag = dag
-
-
 class JobManager(DAGLifecycle):
+    """Job manager for DAG.
+
+    This class is used to manage the DAG lifecycle.
+    """
+
     def __init__(
         self,
         root_nodes: List[BaseOperator],
         all_nodes: List[BaseOperator],
         end_node: BaseOperator,
-        id2call_data: Dict[str, Dict],
+        id2call_data: Dict[str, Optional[Dict]],
         node_name_to_ids: Dict[str, str],
     ) -> None:
+        """Create a job manager.
+
+        Args:
+            root_nodes (List[BaseOperator]): The root nodes of the DAG.
+            all_nodes (List[BaseOperator]): All nodes of the DAG.
+            end_node (BaseOperator): The end node of the DAG.
+            id2call_data (Dict[str, Optional[Dict]]): The call data of each node.
+            node_name_to_ids (Dict[str, str]): The node name to node id mapping.
+        """
        self._root_nodes = root_nodes
        self._all_nodes = all_nodes
        self._end_node = end_node
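The `Dict[str, Dict]` to `Dict[str, Optional[Dict]]` widening is what lets the mapping legally hold missing call data. A minimal standalone sketch of the point:

from typing import Dict, Optional

call_data: Dict[str, Dict] = {"other-node": {"input": 1}}
id2call_data: Dict[str, Optional[Dict]] = {}
# dict.get returns None for absent keys, so the value type must admit None:
id2call_data["node-1"] = call_data.get("node-1")
print(id2call_data)  # {'node-1': None}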
@@ -38,6 +43,15 @@ class JobManager(DAGLifecycle):
     def build_from_end_node(
         end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
     ) -> "JobManager":
+        """Build a job manager from the end node.
+
+        This will get all upstream nodes from the end node, and build a job manager.
+
+        Args:
+            end_node (BaseOperator): The end node of the DAG.
+            call_data (Optional[CALL_DATA], optional): The call data of the end node.
+                Defaults to None.
+        """
         nodes = _build_from_end_node(end_node)
         root_nodes = _get_root_nodes(nodes)
         id2call_data = _save_call_data(root_nodes, call_data)
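A hedged usage sketch for this builder (a fragment, not runnable on its own): `end_task` and `root_task` stand for operators of an already-constructed DAG, and keying `call_data` by root node id follows the `_save_call_data` loop further down.

# Hypothetical fragment: `end_task` is the terminal BaseOperator of a DAG,
# `root_task` one of its roots; call data is looked up per root node id.
call_data = {root_task.node_id: {"input": "hello"}}
job_manager = JobManager.build_from_end_node(end_task, call_data)
print(job_manager.get_call_data_by_id(root_task.node_id))  # {'input': 'hello'}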
@@ -50,17 +64,22 @@ class JobManager(DAGLifecycle):
         return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)

     def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
+        """Get the call data by node id.
+
+        Args:
+            node_id (str): The node id.
+        """
         return self._id2node_data.get(node_id)

     async def before_dag_run(self):
-        """The callback before DAG run"""
+        """Execute the callback before DAG run."""
         tasks = []
         for node in self._all_nodes:
             tasks.append(node.before_dag_run())
         await asyncio.gather(*tasks)

     async def after_dag_end(self):
-        """The callback after DAG end"""
+        """Execute the callback after DAG end."""
         tasks = []
         for node in self._all_nodes:
             tasks.append(node.after_dag_end())
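Both lifecycle hooks use the same fan-out idiom: collect one coroutine per node, then await them all with `asyncio.gather`. A self-contained illustration (the `Node` class is invented for the example):

import asyncio


class Node:
    """Invented stand-in for an operator with a before_dag_run hook."""

    def __init__(self, name: str) -> None:
        self.name = name

    async def before_dag_run(self) -> None:
        await asyncio.sleep(0.01)  # e.g. open connections, warm caches
        print(f"{self.name} ready")


async def main() -> None:
    nodes = [Node("a"), Node("b"), Node("c")]
    # All hooks run concurrently: total wait is roughly the max, not the sum.
    await asyncio.gather(*(node.before_dag_run() for node in nodes))


asyncio.run(main())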
@@ -68,9 +87,9 @@ class JobManager(DAGLifecycle):


 def _save_call_data(
-    root_nodes: List[BaseOperator], call_data: CALL_DATA
-) -> Dict[str, Dict]:
-    id2call_data = {}
+    root_nodes: List[BaseOperator], call_data: Optional[CALL_DATA]
+) -> Dict[str, Optional[Dict]]:
+    id2call_data: Dict[str, Optional[Dict]] = {}
     logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
     if not call_data:
         return id2call_data
@@ -82,7 +101,8 @@ def _save_call_data(
     for node in root_nodes:
         node_id = node.node_id
         logger.debug(
-            f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
+            f"Save call data to node {node.node_id}, call_data: "
+            f"{call_data.get(node_id)}"
         )
         id2call_data[node_id] = call_data.get(node_id)
     return id2call_data
@@ -91,13 +111,11 @@ def _save_call_data(
 def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
     """Build all nodes from the end node."""
     nodes = []
-    if isinstance(end_node, BaseOperator):
-        task_id = end_node.node_id
-        if not task_id:
-            task_id = str(uuid.uuid4())
-            end_node.set_node_id(task_id)
+    if isinstance(end_node, BaseOperator) and not end_node._node_id:
+        end_node.set_node_id(str(uuid.uuid4()))
     nodes.append(end_node)
     for node in end_node.upstream:
+        node = cast(BaseOperator, node)
         nodes += _build_from_end_node(node)
     return nodes

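`cast(BaseOperator, node)` is purely for the type checker: `upstream` is typed against the base node class, and the cast narrows it so operator attributes resolve; at runtime it does nothing. A self-contained sketch of the same recursive collection:

from typing import List, cast


class DAGNode:
    """Simplified stand-in for the AWEL base node type."""

    def __init__(self) -> None:
        self.upstream: List["DAGNode"] = []


class Operator(DAGNode):
    def __init__(self, node_id: str) -> None:
        super().__init__()
        self.node_id = node_id


def collect_ids(end: Operator) -> List[str]:
    ids = [end.node_id]
    for node in end.upstream:
        # upstream is typed List[DAGNode]; cast() narrows it to Operator for
        # the type checker and is a no-op at runtime.
        op = cast(Operator, node)
        ids += collect_ids(op)
    return ids


root, end = Operator("root"), Operator("end")
end.upstream.append(root)
print(collect_ids(end))  # ['end', 'root']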
dbgpt/core/awel/runner/local_runner.py
@@ -1,12 +1,16 @@
+"""Local runner for workflow.
+
+This runner will run the workflow in the current process.
+"""
 import logging
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, Set, cast

 from dbgpt.component import SystemApp

 from ..dag.base import DAGContext, DAGVar
 from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner
-from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
-from ..task.base import TaskContext, TaskState
+from ..operator.common_operator import BranchOperator, JoinOperator
+from ..task.base import SKIP_DATA, TaskContext, TaskState
 from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
 from .job_manager import JobManager
@@ -14,6 +18,8 @@ logger = logging.getLogger(__name__)


 class DefaultWorkflowRunner(WorkflowRunner):
+    """The default workflow runner."""
+
     async def execute_workflow(
         self,
         node: BaseOperator,
@@ -21,6 +27,17 @@ class DefaultWorkflowRunner(WorkflowRunner):
         streaming_call: bool = False,
         exist_dag_ctx: Optional[DAGContext] = None,
     ) -> DAGContext:
+        """Execute the workflow.
+
+        Args:
+            node (BaseOperator): The end node of the workflow.
+            call_data (Optional[CALL_DATA], optional): The call data of the end node.
+                Defaults to None.
+            streaming_call (bool, optional): Whether the call is streaming call.
+                Defaults to False.
+            exist_dag_ctx (Optional[DAGContext], optional): The exist DAG context.
+                Defaults to None.
+        """
         # Save node output
         # dag = node.dag
         job_manager = JobManager.build_from_end_node(node, call_data)
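A hedged driver for the documented signature (a fragment; `end_task` stands for the terminal operator of an already-built DAG, and the absolute import path is assumed):

from dbgpt.core.awel.runner.local_runner import DefaultWorkflowRunner  # path assumed


async def run_once(end_task):
    """Drive one workflow run; `end_task` is a hypothetical terminal operator."""
    runner = DefaultWorkflowRunner()
    # Returns the DAGContext for this run, per the docstring above.
    return await runner.execute_workflow(end_task, call_data={"input": "hi"})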
@@ -37,8 +54,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
         )
         logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
         logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
-        skip_node_ids = set()
-        system_app: SystemApp = DAGVar.get_current_system_app()
+        skip_node_ids: Set[str] = set()
+        system_app: Optional[SystemApp] = DAGVar.get_current_system_app()

         await job_manager.before_dag_run()
         await self._execute_node(
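The re-annotation makes the variable match what the getter can actually return: no system app may be registered yet. A standalone sketch with stand-in types:

from typing import Optional


class SystemApp:
    """Stand-in for dbgpt.component.SystemApp."""


_current_app: Optional[SystemApp] = None


def get_current_system_app() -> Optional[SystemApp]:
    """Stand-in for DAGVar.get_current_system_app(); may return None."""
    return _current_app


# The local annotation must admit None, matching the getter's return type:
system_app: Optional[SystemApp] = get_current_system_app()
print(system_app)  # None until an app is registered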
@@ -57,7 +74,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
         dag_ctx: DAGContext,
         node_outputs: Dict[str, TaskContext],
         skip_node_ids: Set[str],
-        system_app: SystemApp,
+        system_app: Optional[SystemApp],
     ):
         # Skip run node
         if node.node_id in node_outputs:
@@ -79,8 +96,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
             node_outputs[upstream_node.node_id] for upstream_node in node.upstream
         ]
         input_ctx = DefaultInputContext(inputs)
-        task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
-        task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
+        task_ctx: DefaultTaskContext = DefaultTaskContext(
+            node.node_id, TaskState.INIT, task_output=None
+        )
+        current_call_data = job_manager.get_call_data_by_id(node.node_id)
+        if current_call_data:
+            task_ctx.set_call_data(current_call_data)

         task_ctx.set_task_input(input_ctx)
         dag_ctx.set_current_task_context(task_ctx)
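The new guard only forwards call data when there is some, keeping `None` out of `set_call_data` now that `get_call_data_by_id` is typed `Optional[Dict]`. A minimal sketch of the pattern (the function here is a stand-in, not the real API):

from typing import Dict, Optional


def set_call_data(call_data: Dict) -> None:
    """Stand-in for the task context setter, taking a concrete Dict."""
    print("call data:", call_data)


maybe_data: Optional[Dict] = None  # what get_call_data_by_id may return
if maybe_data:  # guard: only forward real call data
    set_call_data(maybe_data)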
@@ -88,12 +109,13 @@ class DefaultWorkflowRunner(WorkflowRunner):

         if node.node_id in skip_node_ids:
             task_ctx.set_current_state(TaskState.SKIP)
-            task_ctx.set_task_output(SimpleTaskOutput(None))
+            task_ctx.set_task_output(SimpleTaskOutput(SKIP_DATA))
             node_outputs[node.node_id] = task_ctx
             return
         try:
             logger.debug(
-                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
+                f"Begin run operator, node id: {node.node_id}, node name: "
+                f"{node.node_name}, cls: {node}"
             )
             if system_app is not None and node.system_app is None:
                 node.set_system_app(system_app)
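Swapping `None` for `SKIP_DATA` is the classic sentinel move: a dedicated marker distinguishes "this task was skipped" from a task that legitimately produced `None`. A generic sketch of the pattern (not AWEL's actual definition):

class _SkipData:
    """Unique marker type; AWEL's SKIP_DATA plays the same role."""

    def __repr__(self) -> str:
        return "SKIP_DATA"


SKIP_DATA = _SkipData()


def run_task(skipped: bool):
    if skipped:
        return SKIP_DATA  # unambiguous: the task did not run
    return None  # a real, if empty, result


print(run_task(skipped=True) is SKIP_DATA)   # True: identity check, not equality
print(run_task(skipped=False) is SKIP_DATA)  # False: None stays a valid output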
@@ -120,6 +142,7 @@ def _skip_current_downstream_by_node_name(
     if not skip_nodes:
         return
     for child in branch_node.downstream:
+        child = cast(BaseOperator, child)
         if child.node_name in skip_nodes:
             logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
             _skip_downstream_by_id(child, skip_node_ids)
@@ -131,4 +154,5 @@ def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
         return
     skip_node_ids.add(node.node_id)
     for child in node.downstream:
+        child = cast(BaseOperator, child)
         _skip_downstream_by_id(child, skip_node_ids)
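Together the two helpers implement transitive skip propagation: once a branch declines a child, everything reachable from it is marked. A condensed, self-contained version over a plain adjacency map:

from typing import Dict, List, Set


def skip_downstream(
    node_id: str, children: Dict[str, List[str]], skipped: Set[str]
) -> None:
    """Mark node_id and every node reachable from it as skipped."""
    if node_id in skipped:
        return  # already visited; also keeps cycles from recursing forever
    skipped.add(node_id)
    for child in children.get(node_id, []):
        skip_downstream(child, children, skipped)


edges = {"b": ["c", "d"], "c": ["e"]}
skipped: Set[str] = set()
skip_downstream("b", edges, skipped)
print(sorted(skipped))  # ['b', 'c', 'd', 'e']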