chore: Add pylint for DB-GPT core lib (#1076)

Fangyin Cheng
2024-01-16 17:36:26 +08:00
committed by GitHub
parent 3a54d1ef9a
commit 40c853575a
79 changed files with 2213 additions and 839 deletions

@@ -0,0 +1,4 @@
"""The module to run AWEL operators.
You can implement your own runner by inheriting the `WorkflowRunner` class.
"""

@@ -1,33 +1,38 @@
"""Job manager for DAG."""
import asyncio
import logging
import uuid
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, cast
from ..dag.base import DAG, DAGLifecycle
from ..dag.base import DAGLifecycle
from ..operator.base import CALL_DATA, BaseOperator
logger = logging.getLogger(__name__)
class DAGNodeInstance:
def __init__(self, node_instance: DAG) -> None:
pass
class DAGInstance:
def __init__(self, dag: DAG) -> None:
self._dag = dag
class JobManager(DAGLifecycle):
"""Job manager for DAG.
This class is used to manage the DAG lifecycle.
"""
def __init__(
self,
root_nodes: List[BaseOperator],
all_nodes: List[BaseOperator],
end_node: BaseOperator,
id2call_data: Dict[str, Dict],
id2call_data: Dict[str, Optional[Dict]],
node_name_to_ids: Dict[str, str],
) -> None:
"""Create a job manager.
Args:
root_nodes (List[BaseOperator]): The root nodes of the DAG.
all_nodes (List[BaseOperator]): All nodes of the DAG.
end_node (BaseOperator): The end node of the DAG.
id2call_data (Dict[str, Optional[Dict]]): The call data of each node.
node_name_to_ids (Dict[str, str]): The node name to node id mapping.
"""
self._root_nodes = root_nodes
self._all_nodes = all_nodes
self._end_node = end_node
@@ -38,6 +43,15 @@ class JobManager(DAGLifecycle):
def build_from_end_node(
end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
) -> "JobManager":
"""Build a job manager from the end node.
This will get all upstream nodes from the end node, and build a job manager.
Args:
end_node (BaseOperator): The end node of the DAG.
call_data (Optional[CALL_DATA], optional): The call data of the end node.
Defaults to None.
"""
nodes = _build_from_end_node(end_node)
root_nodes = _get_root_nodes(nodes)
id2call_data = _save_call_data(root_nodes, call_data)
@@ -50,17 +64,22 @@ class JobManager(DAGLifecycle):
return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
"""Get the call data by node id.
Args:
node_id (str): The node id.
"""
return self._id2node_data.get(node_id)
async def before_dag_run(self):
"""The callback before DAG run"""
"""Execute the callback before DAG run."""
tasks = []
for node in self._all_nodes:
tasks.append(node.before_dag_run())
await asyncio.gather(*tasks)
async def after_dag_end(self):
"""The callback after DAG end"""
"""Execute the callback after DAG end."""
tasks = []
for node in self._all_nodes:
tasks.append(node.after_dag_end())
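Note (assumed usage outline, not code from the commit): a runner drives the JobManager lifecycle roughly as sketched below, using only the methods shown in this file (`build_from_end_node`, `before_dag_run`, `get_call_data_by_id`, `after_dag_end`).

async def _run_dag_sketch(
    end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
):
    # Build the manager from the end operator and fire the before-run callbacks.
    job_manager = JobManager.build_from_end_node(end_node, call_data)
    await job_manager.before_dag_run()
    # ... the workflow runner executes each node here; when a node starts, its
    # call data is looked up with job_manager.get_call_data_by_id(node.node_id).
    await job_manager.after_dag_end()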
@@ -68,9 +87,9 @@ class JobManager(DAGLifecycle):
def _save_call_data(
root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
id2call_data = {}
root_nodes: List[BaseOperator], call_data: Optional[CALL_DATA]
) -> Dict[str, Optional[Dict]]:
id2call_data: Dict[str, Optional[Dict]] = {}
logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
if not call_data:
return id2call_data
@@ -82,7 +101,8 @@ def _save_call_data(
for node in root_nodes:
node_id = node.node_id
logger.debug(
f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
f"Save call data to node {node.node_id}, call_data: "
f"{call_data.get(node_id)}"
)
id2call_data[node_id] = call_data.get(node_id)
return id2call_data
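Note (hypothetical illustration, node ids are made up): with the loop above, call data keyed by node id maps straight through, and any root without an entry gets None from call_data.get(node_id).

call_data = {"node-a": {"data": 1}, "node-b": {"data": 2}}
# Given two root nodes whose ids are "node-a" and "node-b", the resulting
# id2call_data would be {"node-a": {"data": 1}, "node-b": {"data": 2}}.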
@@ -91,13 +111,11 @@ def _save_call_data(
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
"""Build all nodes from the end node."""
nodes = []
if isinstance(end_node, BaseOperator):
task_id = end_node.node_id
if not task_id:
task_id = str(uuid.uuid4())
end_node.set_node_id(task_id)
if isinstance(end_node, BaseOperator) and not end_node._node_id:
end_node.set_node_id(str(uuid.uuid4()))
nodes.append(end_node)
for node in end_node.upstream:
node = cast(BaseOperator, node)
nodes += _build_from_end_node(node)
return nodes

@@ -1,12 +1,16 @@
"""Local runner for workflow.
This runner will run the workflow in the current process.
"""
import logging
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, cast
from dbgpt.component import SystemApp
from ..dag.base import DAGContext, DAGVar
from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..operator.common_operator import BranchOperator, JoinOperator
from ..task.base import SKIP_DATA, TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager
@@ -14,6 +18,8 @@ logger = logging.getLogger(__name__)
class DefaultWorkflowRunner(WorkflowRunner):
"""The default workflow runner."""
async def execute_workflow(
self,
node: BaseOperator,
@@ -21,6 +27,17 @@ class DefaultWorkflowRunner(WorkflowRunner):
streaming_call: bool = False,
exist_dag_ctx: Optional[DAGContext] = None,
) -> DAGContext:
"""Execute the workflow.
Args:
node (BaseOperator): The end node of the workflow.
call_data (Optional[CALL_DATA], optional): The call data of the end node.
Defaults to None.
streaming_call (bool, optional): Whether the call is a streaming call.
Defaults to False.
exist_dag_ctx (Optional[DAGContext], optional): The existing DAG context.
Defaults to None.
"""
# Save node output
# dag = node.dag
job_manager = JobManager.build_from_end_node(node, call_data)
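Note (assumed invocation sketch; operator construction is elided and the call_data shape is only an example):

async def _invoke_sketch(end_node: BaseOperator) -> DAGContext:
    # Must run inside an event loop.
    runner = DefaultWorkflowRunner()
    return await runner.execute_workflow(end_node, call_data={"data": 1})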
@@ -37,8 +54,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
)
logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
skip_node_ids = set()
system_app: SystemApp = DAGVar.get_current_system_app()
skip_node_ids: Set[str] = set()
system_app: Optional[SystemApp] = DAGVar.get_current_system_app()
await job_manager.before_dag_run()
await self._execute_node(
@@ -57,7 +74,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
dag_ctx: DAGContext,
node_outputs: Dict[str, TaskContext],
skip_node_ids: Set[str],
system_app: SystemApp,
system_app: Optional[SystemApp],
):
# Skip run node
if node.node_id in node_outputs:
@@ -79,8 +96,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
node_outputs[upstream_node.node_id] for upstream_node in node.upstream
]
input_ctx = DefaultInputContext(inputs)
task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
task_ctx: DefaultTaskContext = DefaultTaskContext(
node.node_id, TaskState.INIT, task_output=None
)
current_call_data = job_manager.get_call_data_by_id(node.node_id)
if current_call_data:
task_ctx.set_call_data(current_call_data)
task_ctx.set_task_input(input_ctx)
dag_ctx.set_current_task_context(task_ctx)
@@ -88,12 +109,13 @@ class DefaultWorkflowRunner(WorkflowRunner):
if node.node_id in skip_node_ids:
task_ctx.set_current_state(TaskState.SKIP)
task_ctx.set_task_output(SimpleTaskOutput(None))
task_ctx.set_task_output(SimpleTaskOutput(SKIP_DATA))
node_outputs[node.node_id] = task_ctx
return
try:
logger.debug(
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
f"Begin run operator, node id: {node.node_id}, node name: "
f"{node.node_name}, cls: {node}"
)
if system_app is not None and node.system_app is None:
node.set_system_app(system_app)
@@ -120,6 +142,7 @@ def _skip_current_downstream_by_node_name(
if not skip_nodes:
return
for child in branch_node.downstream:
child = cast(BaseOperator, child)
if child.node_name in skip_nodes:
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
_skip_downstream_by_id(child, skip_node_ids)
@@ -131,4 +154,5 @@ def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
return
skip_node_ids.add(node.node_id)
for child in node.downstream:
child = cast(BaseOperator, child)
_skip_downstream_by_id(child, skip_node_ids)
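Note (hypothetical illustration; the names and argument order are assumed from the bodies shown above): if a branch decides that the child named "odd_path" must not run, every operator downstream of it ends up in skip_node_ids and is later short-circuited with SimpleTaskOutput(SKIP_DATA) in _execute_node.

# branch_node is an existing BranchOperator instance; "odd_path" is a made-up
# node name for one of its downstream children.
skip_node_ids: Set[str] = set()
_skip_current_downstream_by_node_name(branch_node, ["odd_path"], skip_node_ids)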