chore: Add pylint for DB-GPT core lib (#1076)

This commit is contained in:
Fangyin Cheng
2024-01-16 17:36:26 +08:00
committed by GitHub
parent 3a54d1ef9a
commit 40c853575a
79 changed files with 2213 additions and 839 deletions

View File

@@ -1,9 +1,10 @@
"""Agentic Workflow Expression Language (AWEL)
"""Agentic Workflow Expression Language (AWEL).
Note:
AWEL is still an experimental feature and only exposes the lowest-level API.
The stability of this API cannot be guaranteed at present.
Agentic Workflow Expression Language (AWEL) is an intelligent agent workflow
expression language designed specifically for large model application
development. It provides great functionality and flexibility. Through the AWEL
API, you can focus on the business logic of LLM applications without worrying
about cumbersome model and environment details.
"""
@@ -71,10 +72,12 @@ __all__ = [
"TransformStreamAbsOperator",
"HttpTrigger",
"setup_dev_environment",
"_is_async_iterator",
]
def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
"""Initialize AWEL."""
from .dag.base import DAGVar
from .dag.dag_manager import DAGManager
from .operator.base import initialize_runner
@@ -92,13 +95,13 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
def setup_dev_environment(
dags: List[DAG],
host: Optional[str] = "127.0.0.1",
port: Optional[int] = 5555,
host: str = "127.0.0.1",
port: int = 5555,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
show_dag_graph: Optional[bool] = True,
) -> None:
"""Setup a development environment for AWEL.
"""Run AWEL in development environment.
    Use this only in a development environment, not in production.
@@ -107,9 +110,11 @@ def setup_dev_environment(
host (Optional[str], optional): The host. Defaults to "127.0.0.1"
port (Optional[int], optional): The port. Defaults to 5555.
logging_level (Optional[str], optional): The logging level. Defaults to None.
logger_filename (Optional[str], optional): The logger filename. Defaults to None.
show_dag_graph (Optional[bool], optional): Whether show the DAG graph. Defaults to True.
If True, the DAG graph will be saved to a file and open it automatically.
logger_filename (Optional[str], optional): The logger filename.
Defaults to None.
show_dag_graph (Optional[bool], optional): Whether to show the DAG graph.
Defaults to True. If True, the DAG graph will be saved to a file and opened
automatically.
"""
import uvicorn
from fastapi import FastAPI
@@ -138,7 +143,9 @@ def setup_dev_environment(
logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}")
except Exception as e:
logger.warning(
f"Visualize DAG {str(dag)} failed: {e}, if your system has no graphviz, you can install it by `pip install graphviz` or `sudo apt install graphviz`"
f"Visualize DAG {str(dag)} failed: {e}, if your system has no "
f"graphviz, you can install it by `pip install graphviz` or "
f"`sudo apt install graphviz`"
)
for trigger in dag.trigger_nodes:
trigger_manager.register_trigger(trigger)
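As a hedged illustration of setup_dev_environment, the sketch below assumes an HttpTrigger-based DAG; the endpoint path and request handling are made up for the example.

from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment

with DAG("dev_env_example") as dag:
    # Expose the DAG over HTTP; the request data is passed to the first operator.
    trigger = HttpTrigger("/examples/hello")
    echo_task = MapOperator(map_function=lambda req: {"message": f"Hello, {req}"})
    trigger >> echo_task

# Starts a local FastAPI/uvicorn server on 127.0.0.1:5555, registers the trigger,
# and (optionally) renders the DAG graph.
setup_dev_environment([dag], port=5555)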

View File

@@ -1,7 +1,10 @@
"""Base classes for AWEL."""
from abc import ABC, abstractmethod
class Trigger(ABC):
"""Base class for trigger."""
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@@ -0,0 +1 @@
"""The module of DAGs."""

View File

@@ -1,3 +1,7 @@
"""The base module of DAG.
DAG is the core component of AWEL; it is used to define the relationships between tasks.
"""
import asyncio
import contextvars
import logging
@@ -6,7 +10,7 @@ import uuid
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast
from dbgpt.component import SystemApp
@@ -27,86 +31,108 @@ def _is_async_context():
class DependencyMixin(ABC):
"""The mixin class for DAGNode.
This class defines the interface for setting upstream and downstream nodes.
It also implements the << and >> operators for setting upstream and
downstream nodes.
"""
@abstractmethod
def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
def set_upstream(self, nodes: DependencyType) -> None:
"""Set one or more upstream nodes for this node.
Args:
nodes (DependencyType): Upstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
ValueError: If no upstream nodes are provided or if an argument is
not a DependencyMixin.
"""
@abstractmethod
def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
def set_downstream(self, nodes: DependencyType) -> None:
"""Set one or more downstream nodes for this node.
Args:
nodes (DependencyType): Downstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
ValueError: If no downstream nodes are provided or if an argument is
not a DependencyMixin.
"""
def __lshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self << nodes
"""Set upstream nodes for current node.
Implements: self << nodes.
Example:
.. code-block:: python
.. code-block:: python
# means node.set_upstream(input_node)
node << input_node
# means node2.set_upstream([input_node])
node2 << [input_node]
# means node.set_upstream(input_node)
node << input_node
# means node2.set_upstream([input_node])
node2 << [input_node]
"""
self.set_upstream(nodes)
return nodes
def __rshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self >> nodes
"""Set downstream nodes for current node.
Example:
Implements: self >> nodes.
.. code-block:: python
Examples:
.. code-block:: python
# means node.set_downstream(next_node)
node >> next_node
# means node.set_downstream(next_node)
node >> next_node
# means node2.set_downstream([next_node])
node2 >> [next_node]
# means node2.set_downstream([next_node])
node2 >> [next_node]
"""
self.set_downstream(nodes)
return nodes
def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] >> self"""
"""Set upstream nodes for current node.
Implements: [node] >> self
"""
self.__lshift__(nodes)
return self
def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] << self"""
"""Set downstream nodes for current node.
Implements: [node] << self
"""
self.__rshift__(nodes)
return self
class DAGVar:
"""The DAGVar is used to store the current DAG context."""
_thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
_system_app: SystemApp = None
_executor: Executor = None
_async_local: contextvars.ContextVar = contextvars.ContextVar(
"current_dag_stack", default=deque()
)
_system_app: Optional[SystemApp] = None
# The executor for the current DAG; it is used to run sync tasks in an async DAG
_executor: Optional[Executor] = None
@classmethod
def enter_dag(cls, dag) -> None:
"""Enter a DAG context.
Args:
dag (DAG): The DAG to enter
"""
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
@@ -119,6 +145,7 @@ class DAGVar:
@classmethod
def exit_dag(cls) -> None:
"""Exit a DAG context."""
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
@@ -134,6 +161,11 @@ class DAGVar:
@classmethod
def get_current_dag(cls) -> Optional["DAG"]:
"""Get the current DAG.
Returns:
Optional[DAG]: The current DAG
"""
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
@@ -147,36 +179,56 @@ class DAGVar:
return None
@classmethod
def get_current_system_app(cls) -> SystemApp:
def get_current_system_app(cls) -> Optional[SystemApp]:
"""Get the current system app.
Returns:
Optional[SystemApp]: The current system app
"""
# if not cls._system_app:
# raise RuntimeError("System APP not set for DAGVar")
return cls._system_app
@classmethod
def set_current_system_app(cls, system_app: SystemApp) -> None:
"""Set the current system app.
Args:
system_app (SystemApp): The system app to set
"""
if cls._system_app:
logger.warn("System APP has already set, nothing to do")
logger.warning("System APP has already been set, nothing to do")
else:
cls._system_app = system_app
@classmethod
def get_executor(cls) -> Executor:
def get_executor(cls) -> Optional[Executor]:
"""Get the current executor.
Returns:
Optional[Executor]: The current executor
"""
return cls._executor
@classmethod
def set_executor(cls, executor: Executor) -> None:
"""Set the current executor.
Args:
executor (Executor): The executor to set
"""
cls._executor = executor
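A small hedged sketch of how DAGVar tracks the active DAG; the module path dbgpt.core.awel.dag.base is assumed from this file's location in the diff.

from dbgpt.core.awel import DAG
from dbgpt.core.awel.dag.base import DAGVar

with DAG("current_dag_example") as dag:
    # DAG.__enter__ calls DAGVar.enter_dag(dag), so operators created in this
    # block attach themselves to `dag` automatically.
    assert DAGVar.get_current_dag() is dag

# DAG.__exit__ calls DAGVar.exit_dag(), so there is no current DAG afterwards.
assert DAGVar.get_current_dag() is None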
class DAGLifecycle:
"""The lifecycle of DAG"""
"""The lifecycle of DAG."""
async def before_dag_run(self):
"""The callback before DAG run"""
"""Execute before DAG run."""
pass
async def after_dag_end(self):
"""The callback after DAG end,
"""Execute after DAG end.
This method may be called multiple times; please make sure it is idempotent.
"""
@@ -184,6 +236,8 @@ class DAGLifecycle:
class DAGNode(DAGLifecycle, DependencyMixin, ABC):
"""The base class of DAGNode."""
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
@@ -196,6 +250,17 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
executor: Optional[Executor] = None,
**kwargs,
) -> None:
"""Initialize a DAGNode.
Args:
dag (Optional["DAG"], optional): The DAG to add this node to.
Defaults to None.
node_id (Optional[str], optional): The node id. Defaults to None.
node_name (Optional[str], optional): The node name. Defaults to None.
system_app (Optional[SystemApp], optional): The system app.
Defaults to None.
executor (Optional[Executor], optional): The executor. Defaults to None.
"""
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
@@ -206,24 +271,28 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
self._executor: Optional[Executor] = executor or DAGVar.get_executor()
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id
self._node_name: str = node_name
self._node_id: Optional[str] = node_id
self._node_name: Optional[str] = node_name
@property
def node_id(self) -> str:
"""Return the node id of current DAGNode."""
if not self._node_id:
raise ValueError("Node id not set for current DAGNode")
return self._node_id
@property
@abstractmethod
def dev_mode(self) -> bool:
"""Whether current DAGNode is in dev mode"""
"""Whether current DAGNode is in dev mode."""
@property
def system_app(self) -> SystemApp:
def system_app(self) -> Optional[SystemApp]:
"""Return the system app of current DAGNode."""
return self._system_app
def set_system_app(self, system_app: SystemApp) -> None:
"""Set system app for current DAGNode
"""Set system app for current DAGNode.
Args:
system_app (SystemApp): The system app
@@ -231,50 +300,97 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
self._system_app = system_app
def set_node_id(self, node_id: str) -> None:
"""Set node id for current DAGNode.
Args:
node_id (str): The node id
"""
self._node_id = node_id
def __hash__(self) -> int:
"""Return the hash value of current DAGNode.
If the node_id is not None, return the hash value of node_id.
"""
if self.node_id:
return hash(self.node_id)
else:
return super().__hash__()
def __eq__(self, other: Any) -> bool:
"""Return whether the current DAGNode is equal to other DAGNode."""
if not isinstance(other, DAGNode):
return False
return self.node_id == other.node_id
@property
def node_name(self) -> str:
def node_name(self) -> Optional[str]:
"""Return the node name of current DAGNode.
Returns:
Optional[str]: The node name of current DAGNode
"""
return self._node_name
@property
def dag(self) -> "DAG":
def dag(self) -> Optional["DAG"]:
"""Return the DAG of current DAGNode.
Returns:
Optional["DAG"]: The DAG of current DAGNode
"""
return self._dag
def set_upstream(self, nodes: DependencyType) -> "DAGNode":
def set_upstream(self, nodes: DependencyType) -> None:
"""Set upstream nodes for current node.
Args:
nodes (DependencyType): Upstream nodes to be set to current node.
"""
self.set_dependency(nodes)
def set_downstream(self, nodes: DependencyType) -> "DAGNode":
def set_downstream(self, nodes: DependencyType) -> None:
"""Set downstream nodes for current node.
Args:
nodes (DependencyType): Downstream nodes to be set to current node.
"""
self.set_dependency(nodes, is_upstream=False)
@property
def upstream(self) -> List["DAGNode"]:
"""Return the upstream nodes of current DAGNode.
Returns:
List["DAGNode"]: The upstream nodes of current DAGNode
"""
return self._upstream
@property
def downstream(self) -> List["DAGNode"]:
"""Return the downstream nodes of current DAGNode.
Returns:
List["DAGNode"]: The downstream nodes of current DAGNode
"""
return self._downstream
def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
"""Set dependency for current node.
Args:
nodes (DependencyType): The nodes to set dependency to current node.
is_upstream (bool, optional): Whether set upstream nodes. Defaults to True.
"""
if not isinstance(nodes, Sequence):
nodes = [nodes]
if not all(isinstance(node, DAGNode) for node in nodes):
raise ValueError(
"all nodes to set dependency to current node must be instance of 'DAGNode'"
"all nodes to set dependency to current node must be instance "
"of 'DAGNode'"
)
nodes: Sequence[DAGNode] = nodes
dags = set([node.dag for node in nodes if node.dag])
nodes = cast(Sequence[DAGNode], nodes)
dags = set([node.dag for node in nodes if node.dag]) # noqa: C403
if self.dag:
dags.add(self.dag)
if not dags:
@@ -302,6 +418,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
node._upstream.append(self)
def __repr__(self):
"""Return the representation of current DAGNode."""
cls_name = self.__class__.__name__
if self.node_id and self.node_name:
return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
@@ -313,6 +430,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
return f"{cls_name}"
def __str__(self):
"""Return the string of current DAGNode."""
return self.__repr__()
@@ -321,7 +439,7 @@ def _build_task_key(task_name: str, key: str) -> str:
class DAGContext:
"""The context of current DAG, created when the DAG is running
"""The context of current DAG, created when the DAG is running.
Every time a DAG is triggered, a new DAGContext will be created.
"""
@@ -329,22 +447,32 @@ class DAGContext:
def __init__(
self,
streaming_call: bool = False,
node_to_outputs: Dict[str, TaskContext] = None,
node_name_to_ids: Dict[str, str] = None,
node_to_outputs: Optional[Dict[str, TaskContext]] = None,
node_name_to_ids: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize a DAGContext.
Args:
streaming_call (bool, optional): Whether the current DAG is streaming call.
Defaults to False.
node_to_outputs (Optional[Dict[str, TaskContext]], optional):
The task outputs of current DAG. Defaults to None.
node_name_to_ids (Optional[Dict[str, str]], optional):
The task name to task id mapping. Defaults to None.
"""
if not node_to_outputs:
node_to_outputs = {}
if not node_name_to_ids:
node_name_to_ids = {}
self._streaming_call = streaming_call
self._curr_task_ctx = None
self._curr_task_ctx: Optional[TaskContext] = None
self._share_data: Dict[str, Any] = {}
self._node_to_outputs = node_to_outputs
self._node_name_to_ids = node_name_to_ids
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
@property
def _task_outputs(self) -> Dict[str, TaskContext]:
"""The task outputs of current DAG
"""Return the task outputs of current DAG.
For internal use only for now.
Returns:
@@ -354,18 +482,28 @@ class DAGContext:
@property
def current_task_context(self) -> TaskContext:
"""Return the current task context."""
if not self._curr_task_ctx:
raise RuntimeError("Current task context not set")
return self._curr_task_ctx
@property
def streaming_call(self) -> bool:
"""Whether the current DAG is streaming call"""
"""Whether the current DAG is streaming call."""
return self._streaming_call
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
"""Set the current task context.
When the task is running, the current task context
will be set to the task context.
TODO: We should support parallel task running in the future.
"""
self._curr_task_ctx = _curr_task_ctx
def get_task_output(self, task_name: str) -> TaskOutput:
"""Get the task output by task name
"""Get the task output by task name.
Args:
task_name (str): The task name
@@ -376,22 +514,41 @@ class DAGContext:
if task_name is None:
raise ValueError("task_name can't be None")
node_id = self._node_name_to_ids.get(task_name)
if node_id:
if not node_id:
raise ValueError(f"Task name {task_name} not exists in DAG")
return self._task_outputs.get(node_id).task_output
task_output = self._task_outputs.get(node_id)
if not task_output:
raise ValueError(f"Task output for task {task_name} not exists")
return task_output.task_output
async def get_from_share_data(self, key: str) -> Any:
"""Get share data by key.
Args:
key (str): The share data key
Returns:
Any: The share data, you can cast it to the real type
"""
return self._share_data.get(key)
async def save_to_share_data(
self, key: str, data: Any, overwrite: bool = False
) -> None:
"""Save share data by key.
Args:
key (str): The share data key
data (Any): The share data
overwrite (bool): Whether to overwrite the share data if the key
already exists. Defaults to False.
"""
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
self._share_data[key] = data
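To illustrate the share-data API, here is a hedged sketch with two hypothetical operators that pass a value through the DAG context; the operator names are made up, while current_dag_context, save_to_share_data, and get_from_share_data are the methods shown in this file.

from dbgpt.core.awel import MapOperator


class SaveUserOperator(MapOperator[str, str]):
    async def map(self, user_name: str) -> str:
        # Stash the value so downstream operators can read it later.
        await self.current_dag_context.save_to_share_data("user_name", user_name)
        return user_name


class GreetOperator(MapOperator[str, str]):
    async def map(self, _: str) -> str:
        user_name = await self.current_dag_context.get_from_share_data("user_name")
        return f"Hello, {user_name}"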
async def get_task_share_data(self, task_name: str, key: str) -> Any:
"""Get share data by task name and key
"""Get share data by task name and key.
Args:
task_name (str): The task name
@@ -409,14 +566,14 @@ class DAGContext:
async def save_task_share_data(
self, task_name: str, key: str, data: Any, overwrite: bool = False
) -> None:
"""Save share data by task name and key
"""Save share data by task name and key.
Args:
task_name (str): The task name
key (str): The share data key
data (Any): The share data
overwrite (bool): Whether overwrite the share data if the key already exists.
Defaults to None.
overwrite (bool): Whether to overwrite the share data if the key
already exists. Defaults to False.
Raises:
ValueError: If the share data key already exists and overwrite is not True
@@ -429,15 +586,22 @@ class DAGContext:
class DAG:
"""The DAG class.
Manage the DAG nodes and the relationship between them.
"""
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None:
"""Initialize a DAG."""
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {}
self.node_name_to_node: Dict[str, DAGNode] = {}
self._root_nodes: List[DAGNode] = None
self._leaf_nodes: List[DAGNode] = None
self._trigger_nodes: List[DAGNode] = None
self._root_nodes: List[DAGNode] = []
self._leaf_nodes: List[DAGNode] = []
self._trigger_nodes: List[DAGNode] = []
self._resource_group: Optional[ResourceGroup] = resource_group
def _append_node(self, node: DAGNode) -> None:
if node.node_id in self.node_map:
@@ -448,22 +612,26 @@ class DAG:
f"Node name {node.node_name} already exists in DAG {self.dag_id}"
)
self.node_name_to_node[node.node_name] = node
self.node_map[node.node_id] = node
node_id = node.node_id
if not node_id:
raise ValueError("Node id can't be None")
self.node_map[node_id] = node
# clear cached nodes
self._root_nodes = None
self._leaf_nodes = None
self._root_nodes = []
self._leaf_nodes = []
def _new_node_id(self) -> str:
return str(uuid.uuid4())
@property
def dag_id(self) -> str:
"""Return the dag id of current DAG."""
return self._dag_id
def _build(self) -> None:
from ..operator.common_operator import TriggerOperator
nodes = set()
nodes: Set[DAGNode] = set()
for _, node in self.node_map.items():
nodes = nodes.union(_get_nodes(node))
self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
@@ -474,7 +642,7 @@ class DAG:
@property
def root_nodes(self) -> List[DAGNode]:
"""The root nodes of current DAG
"""Return the root nodes of current DAG.
Returns:
List[DAGNode]: The root nodes of current DAG, no repeat
@@ -485,7 +653,7 @@ class DAG:
@property
def leaf_nodes(self) -> List[DAGNode]:
"""The leaf nodes of current DAG
"""Return the leaf nodes of current DAG.
Returns:
List[DAGNode]: The leaf nodes of current DAG, no repeat
@@ -496,7 +664,7 @@ class DAG:
@property
def trigger_nodes(self) -> List[DAGNode]:
"""The trigger nodes of current DAG
"""Return the trigger nodes of current DAG.
Returns:
List[DAGNode]: The trigger nodes of current DAG, no repeat
@@ -506,34 +674,42 @@ class DAG:
return self._trigger_nodes
async def _after_dag_end(self) -> None:
"""The callback after DAG end"""
"""Execute after DAG end."""
tasks = []
for node in self.node_map.values():
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)
def print_tree(self) -> None:
"""Print the DAG tree"""
"""Print the DAG tree""" # noqa: D400
_print_format_dag_tree(self)
def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
"""Create the DAG graph"""
"""Visualize the DAG.
Args:
view (bool, optional): Whether to view the DAG graph. Defaults to True;
if True, it will open the graph file with your default viewer.
"""
self.print_tree()
return _visualize_dag(self, view=view, **kwargs)
def __enter__(self):
"""Enter a DAG context."""
DAGVar.enter_dag(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit a DAG context."""
DAGVar.exit_dag()
def __repr__(self):
"""Return the representation of current DAG."""
return f"DAG(dag_id={self.dag_id})"
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
nodes = set()
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]:
nodes: Set[DAGNode] = set()
if not node:
return nodes
nodes.add(node)
@@ -553,7 +729,7 @@ def _print_dag(
level: int = 0,
prefix: str = "",
last: bool = True,
level_dict: Dict[str, Any] = None,
level_dict: Optional[Dict[int, Any]] = None,
):
if level_dict is None:
level_dict = {}
@@ -606,7 +782,7 @@ def _handle_dag_nodes(
def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
"""Visualize the DAG
"""Visualize the DAG.
Args:
dag (DAG): The DAG to visualize
@@ -641,7 +817,7 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
filename = kwargs["filename"]
del kwargs["filename"]
if not "directory" in kwargs:
if "directory" not in kwargs:
from dbgpt.configs.model_config import LOGDIR
kwargs["directory"] = LOGDIR

View File

@@ -1,3 +1,8 @@
"""DAGManager is a component of AWEL, it is used to manage DAGs.
DAGManager will load DAGs from dag_dirs, and register the trigger nodes
to TriggerManager.
"""
import logging
from typing import Dict, List
@@ -10,24 +15,35 @@ logger = logging.getLogger(__name__)
class DAGManager(BaseComponent):
"""The component of DAGManager."""
name = ComponentType.AWEL_DAG_MANAGER
def __init__(self, system_app: SystemApp, dag_dirs: List[str]):
"""Initialize a DAGManager.
Args:
system_app (SystemApp): The system app.
dag_dirs (List[str]): The directories to load DAGs.
"""
super().__init__(system_app)
self.dag_loader = LocalFileDAGLoader(dag_dirs)
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}
def init_app(self, system_app: SystemApp):
"""Initialize the DAGManager."""
self.system_app = system_app
def load_dags(self):
"""Load DAGs from dag_dirs."""
dags = self.dag_loader.load_dags()
triggers = []
for dag in dags:
dag_id = dag.dag_id
if dag_id in self.dag_map:
raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
self.dag_map[dag_id] = dag
triggers += dag.trigger_nodes
from ..trigger.trigger_manager import DefaultTriggerManager
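A hedged sketch of wiring the DAGManager up outside of initialize_awel; the directory path is illustrative, and register_instance is assumed from the SystemApp component API.

from dbgpt.component import SystemApp
from dbgpt.core.awel.dag.dag_manager import DAGManager

system_app = SystemApp()
dag_manager = DAGManager(system_app, dag_dirs=["/path/to/my/awel/dags"])
system_app.register_instance(dag_manager)

# Loads *.py DAG definitions from the directories and registers their trigger nodes.
dag_manager.load_dags()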

View File

@@ -1,3 +1,8 @@
"""DAG loader.
DAGLoader will load DAGs from dag_dirs or other sources.
Currently, only loading DAGs from local files is supported.
"""
import hashlib
import logging
import os
@@ -12,16 +17,26 @@ logger = logging.getLogger(__name__)
class DAGLoader(ABC):
"""Abstract base class representing a loader for loading DAGs."""
@abstractmethod
def load_dags(self) -> List[DAG]:
"""Load dags"""
"""Load dags."""
class LocalFileDAGLoader(DAGLoader):
"""DAG loader for loading DAGs from local files."""
def __init__(self, dag_dirs: List[str]) -> None:
"""Initialize a LocalFileDAGLoader.
Args:
dag_dirs (List[str]): The directories to load DAGs.
"""
self._dag_dirs = dag_dirs
def load_dags(self) -> List[DAG]:
"""Load dags from local files."""
dags = []
for filepath in self._dag_dirs:
if not os.path.exists(filepath):
@@ -70,7 +85,7 @@ def _load_modules_from_file(filepath: str):
sys.modules[spec.name] = new_module
loader.exec_module(new_module)
return [new_module]
except Exception as e:
except Exception:
msg = traceback.format_exc()
logger.error(f"Failed to import: {filepath}, error message: {msg}")
# TODO save error message

View File

@@ -0,0 +1 @@
"""The module of operator."""

View File

@@ -1,7 +1,7 @@
"""Base classes for operators that can be executed within a workflow."""
import asyncio
import functools
from abc import ABC, ABCMeta, abstractmethod
from inspect import signature
from types import FunctionType
from typing import (
Any,
@@ -9,7 +9,6 @@ from typing import (
Dict,
Generic,
Iterator,
List,
Optional,
TypeVar,
Union,
@@ -21,7 +20,6 @@ from dbgpt.util.executor_utils import (
AsyncToSyncIterator,
BlockingFunction,
DefaultExecutorFactory,
ExecutorFactory,
blocking_func_to_async,
)
@@ -54,13 +52,15 @@ class WorkflowRunner(ABC, Generic[T]):
node (RunnableDAGNode): The starting node of the workflow to be executed.
call_data (CALL_DATA): The data pass to root operator node.
streaming_call (bool): Whether the call is a streaming call.
exist_dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
exist_dag_ctx (DAGContext): The context of the DAG when this node is run.
Defaults to None.
Returns:
DAGContext: The context after executing the workflow, containing the final state and data.
DAGContext: The context after executing the workflow, containing the final
state and data.
"""
default_runner: WorkflowRunner = None
default_runner: Optional[WorkflowRunner] = None
class BaseOperatorMeta(ABCMeta):
@@ -68,8 +68,7 @@ class BaseOperatorMeta(ABCMeta):
@classmethod
def _apply_defaults(cls, func: F) -> F:
sig_cache = signature(func)
# sig_cache = signature(func)
@functools.wraps(func)
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
@@ -81,7 +80,7 @@ class BaseOperatorMeta(ABCMeta):
if not executor:
if system_app:
executor = system_app.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory
).create()
else:
executor = DefaultExecutorFactory().create()
@@ -107,9 +106,10 @@ class BaseOperatorMeta(ABCMeta):
real_obj = func(self, *args, **kwargs)
return real_obj
return cast(T, apply_defaults)
return cast(F, apply_defaults)
def __new__(cls, name, bases, namespace, **kwargs):
"""Create a new BaseOperator class with default arguments."""
new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
return new_cls
@@ -126,13 +126,14 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
task_id: Optional[str] = None,
task_name: Optional[str] = None,
dag: Optional[DAG] = None,
runner: WorkflowRunner = None,
runner: Optional[WorkflowRunner] = None,
**kwargs,
) -> None:
"""Initializes a BaseOperator with an optional workflow runner.
"""Create a BaseOperator with an optional workflow runner.
Args:
runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
runner (WorkflowRunner, optional): The runner used to execute the workflow.
Defaults to None.
"""
super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
if not runner:
@@ -141,19 +142,24 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
runner = DefaultWorkflowRunner()
self._runner: WorkflowRunner = runner
self._dag_ctx: DAGContext = None
self._dag_ctx: Optional[DAGContext] = None
@property
def current_dag_context(self) -> DAGContext:
"""Return the current DAG context."""
if not self._dag_ctx:
raise ValueError("DAGContext is not set")
return self._dag_ctx
@property
def dev_mode(self) -> bool:
"""Whether the operator is in dev mode.
In production mode, the default runner is not None.
Returns:
bool: Whether the operator is in dev mode. True if the default runner is None.
bool: Whether the operator is in dev mode. True if the
default runner is None.
"""
return default_runner is None
@@ -186,7 +192,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
dag_ctx (DAGContext): The context of the DAG when this node is run.
Defaults to None.
Returns:
OUT: The output of the node after execution.
"""
@@ -196,7 +203,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
return out_ctx.current_task_context.task_output.output
def _blocking_call(
self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
self,
call_data: Optional[CALL_DATA] = None,
loop: Optional[asyncio.BaseEventLoop] = None,
) -> OUT:
"""Execute the node and return the output.
@@ -213,6 +222,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
if not loop:
loop = get_or_create_event_loop()
loop = cast(asyncio.BaseEventLoop, loop)
return loop.run_until_complete(self.call(call_data))
async def call_stream(
@@ -226,7 +236,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
call_data (CALL_DATA): The data pass to root operator node.
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
dag_ctx (DAGContext): The context of the DAG when this node is run.
Defaults to None.
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
@@ -237,7 +248,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
return out_ctx.current_task_context.task_output.output_stream
def _blocking_call_stream(
self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
self,
call_data: Optional[CALL_DATA] = None,
loop: Optional[asyncio.BaseEventLoop] = None,
) -> Iterator[OUT]:
"""Execute the node and return the output as a stream.
@@ -259,9 +272,22 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
async def blocking_func_to_async(
self, func: BlockingFunction, *args, **kwargs
) -> Any:
"""Execute a blocking function asynchronously.
In AWEL, the operators are executed asynchronously. However,
some functions are blocking, so we run them in a separate thread.
Args:
func (BlockingFunction): The blocking function to be executed.
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
"""
if not self._executor:
raise ValueError("Executor is not set")
return await blocking_func_to_async(self._executor, func, *args, **kwargs)
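As a hedged example of blocking_func_to_async, the sketch below uses a made-up blocking helper to show how an operator offloads it to the executor.

import time

from dbgpt.core.awel import MapOperator


def read_report_sync(name: str) -> str:
    time.sleep(1)  # stand-in for a blocking disk or network call
    return f"report for {name}"


class ReadReportOperator(MapOperator[str, str]):
    async def map(self, name: str) -> str:
        # Run the blocking function in the operator's executor so the
        # event loop is not blocked.
        return await self.blocking_func_to_async(read_report_sync, name)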
def initialize_runner(runner: WorkflowRunner):
"""Initialize the default runner."""
global default_runner
default_runner = runner

View File

@@ -1,7 +1,7 @@
"""Common operators of AWEL."""
import asyncio
import logging
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
@@ -13,7 +13,17 @@ from typing import (
)
from ..dag.base import DAGContext
from ..task.base import IN, OUT, InputContext, InputSource, TaskContext, TaskOutput
from ..task.base import (
IN,
OUT,
InputContext,
InputSource,
JoinFunc,
MapFunc,
ReduceFunc,
TaskContext,
TaskOutput,
)
from .base import BaseOperator
logger = logging.getLogger(__name__)
@@ -25,7 +35,12 @@ class JoinOperator(BaseOperator, Generic[OUT]):
This node type is useful for combining the outputs of upstream nodes.
"""
def __init__(self, combine_function, **kwargs):
def __init__(self, combine_function: JoinFunc, **kwargs):
"""Create a JoinDAGNode with a combine function.
Args:
combine_function: A function that defines how to combine inputs.
"""
super().__init__(**kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")
@@ -33,6 +48,7 @@ class JoinOperator(BaseOperator, Generic[OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the join operation on the DAG context's inputs.
Args:
dag_ctx (DAGContext): The current context of the DAG.
@@ -50,8 +66,10 @@ class JoinOperator(BaseOperator, Generic[OUT]):
class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
def __init__(self, reduce_function=None, **kwargs):
"""Initializes a ReduceStreamOperator with a combine function.
"""Operator that reduces inputs using a custom reduce function."""
def __init__(self, reduce_function: Optional[ReduceFunc] = None, **kwargs):
"""Create a ReduceStreamOperator with a combine function.
Args:
combine_function: A function that defines how to combine inputs.
@@ -89,6 +107,7 @@ class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
return reduce_output
async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
"""Reduce the input stream to a single value."""
raise NotImplementedError
@@ -99,8 +118,8 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
passes the transformed data downstream.
"""
def __init__(self, map_function=None, **kwargs):
"""Initializes a MapDAGNode with a mapping function.
def __init__(self, map_function: Optional[MapFunc] = None, **kwargs):
"""Create a MapDAGNode with a mapping function.
Args:
map_function: A function that defines how to map the input data.
@@ -133,13 +152,18 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
if not call_data and not curr_task_ctx.task_input.check_single_parent():
num_parents = len(curr_task_ctx.task_input.parent_outputs)
raise ValueError(
f"task {curr_task_ctx.task_id} MapDAGNode expects single parent, now number of parents: {num_parents}"
f"task {curr_task_ctx.task_id} MapDAGNode expects single parent,"
f"now number of parents: {num_parents}"
)
map_function = self.map_function or self.map
if call_data:
call_data = await curr_task_ctx._call_data_to_output()
output = await call_data.map(map_function)
wrapped_call_data = await curr_task_ctx._call_data_to_output()
if not wrapped_call_data:
raise ValueError(
f"task {curr_task_ctx.task_id} MapDAGNode expects wrapped_call_data"
)
output: TaskOutput[OUT] = await wrapped_call_data.map(map_function)
curr_task_ctx.set_task_output(output)
return output
@@ -150,6 +174,7 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
return output
async def map(self, input_value: IN) -> OUT:
"""Map the input data to a new value."""
raise NotImplementedError
@@ -161,6 +186,11 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
This node filters its input data using a branching function and
allows for conditional paths in the workflow.
If a branch function returns True, the corresponding task will be executed;
otherwise, the corresponding task will be skipped, and the output of
the skipped node will be set to `SKIP_DATA`.
"""
def __init__(
@@ -168,11 +198,11 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
**kwargs,
):
"""
Initializes a BranchDAGNode with a branching function.
"""Create a BranchDAGNode with a branching function.
Args:
branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of function that defines the branching condition.
branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]):
Dict of function that defines the branching condition.
Raises:
ValueError: If the branch_function is not callable.
@@ -183,7 +213,9 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
if not callable(branch_function):
raise ValueError("branch_function must be callable")
if isinstance(value, BaseOperator):
branches[branch_function] = value.node_name or value.node_name
if not value.node_name:
raise ValueError("branch node name must be set")
branches[branch_function] = value.node_name
self._branches = branches
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@@ -210,7 +242,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
branches = await self.branches()
branch_func_tasks = []
branch_nodes: List[str] = []
branch_nodes: List[Union[BaseOperator, str]] = []
for func, node_name in branches.items():
branch_nodes.append(node_name)
branch_func_tasks.append(
@@ -225,20 +257,25 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
node_name = branch_nodes[i]
branch_out = ctx.parent_outputs[0].task_output
logger.info(
f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
f"branch_input_ctxs {i} result {branch_out.output}, "
f"is_empty: {branch_out.is_empty}"
)
if ctx.parent_outputs[0].task_output.is_empty:
if ctx.parent_outputs[0].task_output.is_none:
logger.info(f"Skip node name {node_name}")
skip_node_names.append(node_name)
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
return parent_output
async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
"""Return branch logic based on input data."""
raise NotImplementedError
class InputOperator(BaseOperator, Generic[OUT]):
"""Operator node that reads data from an input source."""
def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
"""Create an InputDAGNode with an input source."""
super().__init__(**kwargs)
self._input_source = input_source
@@ -250,7 +287,10 @@ class InputOperator(BaseOperator, Generic[OUT]):
class TriggerOperator(InputOperator, Generic[OUT]):
"""Operator node that triggers the DAG to run."""
def __init__(self, **kwargs) -> None:
"""Create a TriggerDAGNode."""
from ..task.task_impl import SimpleCallDataInputSource
super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)
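To make the common operators concrete, here is a hedged sketch that fans one input out to two branches and joins the results; it assumes JoinOperator passes the upstream outputs to combine_function as positional arguments, and the call_data format carries the same caveat as the earlier sketch.

import asyncio

from dbgpt.core.awel import DAG, JoinOperator, MapOperator

with DAG("join_example") as dag:
    base_task = MapOperator(map_function=lambda x: x)
    double_task = MapOperator(map_function=lambda x: x * 2)
    square_task = MapOperator(map_function=lambda x: x * x)
    # Combine the two branch outputs into one value.
    join_task = JoinOperator(combine_function=lambda a, b: a + b)
    base_task >> double_task >> join_task
    base_task >> square_task >> join_task

result = asyncio.run(join_task.call(call_data=3))
print(result)  # expected: 3 * 2 + 3 * 3 = 15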

View File

@@ -1,3 +1,4 @@
"""The module of stream operator."""
from abc import ABC, abstractmethod
from typing import AsyncIterator, Generic
@@ -7,12 +8,18 @@ from .base import BaseOperator
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
"""An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
call_data = curr_task_ctx.call_data
if call_data:
call_data = await curr_task_ctx._call_data_to_output()
output = await call_data.streamify(self.streamify)
wrapped_call_data = await curr_task_ctx._call_data_to_output()
if not wrapped_call_data:
raise ValueError(
f"task {curr_task_ctx.task_id} MapDAGNode expects wrapped_call_data"
)
output = await wrapped_call_data.streamify(self.streamify)
curr_task_ctx.set_task_output(output)
return output
output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
@@ -23,26 +30,28 @@ class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
@abstractmethod
async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
"""Convert a value of IN to an AsyncIterator[OUT]
"""Convert a value of IN to an AsyncIterator[OUT].
Args:
input_value (IN): The data of parent operator's output
Example:
Examples:
.. code-block:: python
.. code-block:: python
class MyStreamOperator(StreamifyAbsOperator[int, int]):
async def streamify(self, input_value: int) -> AsyncIterator[int]:
for i in range(input_value):
yield i
class MyStreamOperator(StreamifyAbsOperator[int, int]):
async def streamify(self, input_value: int) -> AsyncIterator[int]:
for i in range(input_value):
yield i
"""
class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
"""An abstract operator that converts a value of AsyncIterator[IN] to an OUT."""
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
0
].task_output.unstreamify(self.unstreamify)
curr_task_ctx.set_task_output(output)
@@ -56,24 +65,30 @@ class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
input_value (AsyncIterator[IN]): The data of parent operator's output
Example:
.. code-block:: python
.. code-block:: python
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
value_cnt = 0
async for v in input_value:
value_cnt += 1
return value_cnt
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
value_cnt = 0
async for v in input_value:
value_cnt += 1
return value_cnt
"""
class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
"""Transform streaming data into other streaming data.
An abstract operator that transforms a value of
AsyncIterator[IN] to another AsyncIterator[OUT].
"""
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
0
].task_output.transform_stream(self.transform_stream)
curr_task_ctx.set_task_output(output)
return output
@@ -81,19 +96,18 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
async def transform_stream(
self, input_value: AsyncIterator[IN]
) -> AsyncIterator[OUT]:
"""Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.
"""Transform an AsyncIterator[IN] to another AsyncIterator[OUT].
Args:
input_value (AsyncIterator[IN]): The data of parent operator's output
Example:
Examples:
.. code-block:: python
.. code-block:: python
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
async def unstreamify(
self, input_value: AsyncIterator[int]
) -> AsyncIterator[int]:
async for v in input_value:
yield v + 1
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
async def unstreamify(
self, input_value: AsyncIterator[int]
) -> AsyncIterator[int]:
async for v in input_value:
yield v + 1
"""

View File

@@ -0,0 +1,4 @@
"""The module of AWEL resource.
Not implemented yet.
"""

View File

@@ -1,8 +1,15 @@
"""Base class for resource group."""
from abc import ABC, abstractmethod
class ResourceGroup(ABC):
"""Base class for resource group.
A resource group is a group of resources that are related to each other.
It contains all the resources that are needed to run a workflow.
"""
@property
@abstractmethod
def name(self) -> str:
"""The name of current resource group"""
"""Return the name of current resource group."""

View File

@@ -0,0 +1,4 @@
"""The module to run AWEL operators.
You can implement your own runner by inheriting the `WorkflowRunner` class.
"""

View File

@@ -1,33 +1,38 @@
"""Job manager for DAG."""
import asyncio
import logging
import uuid
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, cast
from ..dag.base import DAG, DAGLifecycle
from ..dag.base import DAGLifecycle
from ..operator.base import CALL_DATA, BaseOperator
logger = logging.getLogger(__name__)
class DAGNodeInstance:
def __init__(self, node_instance: DAG) -> None:
pass
class DAGInstance:
def __init__(self, dag: DAG) -> None:
self._dag = dag
class JobManager(DAGLifecycle):
"""Job manager for DAG.
This class is used to manage the DAG lifecycle.
"""
def __init__(
self,
root_nodes: List[BaseOperator],
all_nodes: List[BaseOperator],
end_node: BaseOperator,
id2call_data: Dict[str, Dict],
id2call_data: Dict[str, Optional[Dict]],
node_name_to_ids: Dict[str, str],
) -> None:
"""Create a job manager.
Args:
root_nodes (List[BaseOperator]): The root nodes of the DAG.
all_nodes (List[BaseOperator]): All nodes of the DAG.
end_node (BaseOperator): The end node of the DAG.
id2call_data (Dict[str, Optional[Dict]]): The call data of each node.
node_name_to_ids (Dict[str, str]): The node name to node id mapping.
"""
self._root_nodes = root_nodes
self._all_nodes = all_nodes
self._end_node = end_node
@@ -38,6 +43,15 @@ class JobManager(DAGLifecycle):
def build_from_end_node(
end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
) -> "JobManager":
"""Build a job manager from the end node.
This will get all upstream nodes from the end node, and build a job manager.
Args:
end_node (BaseOperator): The end node of the DAG.
call_data (Optional[CALL_DATA], optional): The call data of the end node.
Defaults to None.
"""
nodes = _build_from_end_node(end_node)
root_nodes = _get_root_nodes(nodes)
id2call_data = _save_call_data(root_nodes, call_data)
@@ -50,17 +64,22 @@ class JobManager(DAGLifecycle):
return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
"""Get the call data by node id.
Args:
node_id (str): The node id.
"""
return self._id2node_data.get(node_id)
async def before_dag_run(self):
"""The callback before DAG run"""
"""Execute the callback before DAG run."""
tasks = []
for node in self._all_nodes:
tasks.append(node.before_dag_run())
await asyncio.gather(*tasks)
async def after_dag_end(self):
"""The callback after DAG end"""
"""Execute the callback after DAG end."""
tasks = []
for node in self._all_nodes:
tasks.append(node.after_dag_end())
@@ -68,9 +87,9 @@ class JobManager(DAGLifecycle):
def _save_call_data(
root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
id2call_data = {}
root_nodes: List[BaseOperator], call_data: Optional[CALL_DATA]
) -> Dict[str, Optional[Dict]]:
id2call_data: Dict[str, Optional[Dict]] = {}
logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
if not call_data:
return id2call_data
@@ -82,7 +101,8 @@ def _save_call_data(
for node in root_nodes:
node_id = node.node_id
logger.debug(
f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
f"Save call data to node {node.node_id}, call_data: "
f"{call_data.get(node_id)}"
)
id2call_data[node_id] = call_data.get(node_id)
return id2call_data
@@ -91,13 +111,11 @@ def _save_call_data(
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
"""Build all nodes from the end node."""
nodes = []
if isinstance(end_node, BaseOperator):
task_id = end_node.node_id
if not task_id:
task_id = str(uuid.uuid4())
end_node.set_node_id(task_id)
if isinstance(end_node, BaseOperator) and not end_node._node_id:
end_node.set_node_id(str(uuid.uuid4()))
nodes.append(end_node)
for node in end_node.upstream:
node = cast(BaseOperator, node)
nodes += _build_from_end_node(node)
return nodes

View File

@@ -1,12 +1,16 @@
"""Local runner for workflow.
This runner will run the workflow in the current process.
"""
import logging
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, cast
from dbgpt.component import SystemApp
from ..dag.base import DAGContext, DAGVar
from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..operator.common_operator import BranchOperator, JoinOperator
from ..task.base import SKIP_DATA, TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager
@@ -14,6 +18,8 @@ logger = logging.getLogger(__name__)
class DefaultWorkflowRunner(WorkflowRunner):
"""The default workflow runner."""
async def execute_workflow(
self,
node: BaseOperator,
@@ -21,6 +27,17 @@ class DefaultWorkflowRunner(WorkflowRunner):
streaming_call: bool = False,
exist_dag_ctx: Optional[DAGContext] = None,
) -> DAGContext:
"""Execute the workflow.
Args:
node (BaseOperator): The end node of the workflow.
call_data (Optional[CALL_DATA], optional): The call data of the end node.
Defaults to None.
streaming_call (bool, optional): Whether the call is streaming call.
Defaults to False.
exist_dag_ctx (Optional[DAGContext], optional): The exist DAG context.
Defaults to None.
"""
# Save node output
# dag = node.dag
job_manager = JobManager.build_from_end_node(node, call_data)
@@ -37,8 +54,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
)
logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
skip_node_ids = set()
system_app: SystemApp = DAGVar.get_current_system_app()
skip_node_ids: Set[str] = set()
system_app: Optional[SystemApp] = DAGVar.get_current_system_app()
await job_manager.before_dag_run()
await self._execute_node(
@@ -57,7 +74,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
dag_ctx: DAGContext,
node_outputs: Dict[str, TaskContext],
skip_node_ids: Set[str],
system_app: SystemApp,
system_app: Optional[SystemApp],
):
# Skip run node
if node.node_id in node_outputs:
@@ -79,8 +96,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
node_outputs[upstream_node.node_id] for upstream_node in node.upstream
]
input_ctx = DefaultInputContext(inputs)
task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
task_ctx: DefaultTaskContext = DefaultTaskContext(
node.node_id, TaskState.INIT, task_output=None
)
current_call_data = job_manager.get_call_data_by_id(node.node_id)
if current_call_data:
task_ctx.set_call_data(current_call_data)
task_ctx.set_task_input(input_ctx)
dag_ctx.set_current_task_context(task_ctx)
@@ -88,12 +109,13 @@ class DefaultWorkflowRunner(WorkflowRunner):
if node.node_id in skip_node_ids:
task_ctx.set_current_state(TaskState.SKIP)
task_ctx.set_task_output(SimpleTaskOutput(None))
task_ctx.set_task_output(SimpleTaskOutput(SKIP_DATA))
node_outputs[node.node_id] = task_ctx
return
try:
logger.debug(
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
f"Begin run operator, node id: {node.node_id}, node name: "
f"{node.node_name}, cls: {node}"
)
if system_app is not None and node.system_app is None:
node.set_system_app(system_app)
@@ -120,6 +142,7 @@ def _skip_current_downstream_by_node_name(
if not skip_nodes:
return
for child in branch_node.downstream:
child = cast(BaseOperator, child)
if child.node_name in skip_nodes:
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
_skip_downstream_by_id(child, skip_node_ids)
@@ -131,4 +154,5 @@ def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
return
skip_node_ids.add(node.node_id)
for child in node.downstream:
child = cast(BaseOperator, child)
_skip_downstream_by_id(child, skip_node_ids)

View File

@@ -0,0 +1 @@
"""The module of Task."""

View File

@@ -1,8 +1,10 @@
"""Base classes for task-related objects."""
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
@@ -17,6 +19,24 @@ OUT = TypeVar("OUT")
T = TypeVar("T")
class _EMPTY_DATA_TYPE:
def __bool__(self):
return False
EMPTY_DATA = _EMPTY_DATA_TYPE()
SKIP_DATA = _EMPTY_DATA_TYPE()
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()
MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
PredicateFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
JoinFunc = Union[Callable[..., OUT], Callable[..., Awaitable[OUT]]]
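These aliases say that both plain and async callables are accepted; a small hedged sketch, using MapOperator as the consumer:

from dbgpt.core.awel import MapOperator


def double(x: int) -> int:
    return x * 2


async def double_async(x: int) -> int:
    return x * 2


# Either callable satisfies MapFunc and can be used as a map_function.
sync_task = MapOperator(map_function=double)
async_task = MapOperator(map_function=double_async)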
class TaskState(str, Enum):
"""Enumeration representing the state of a task in the workflow.
@@ -33,8 +53,8 @@ class TaskState(str, Enum):
class TaskOutput(ABC, Generic[T]):
"""Abstract base class representing the output of a task.
This class encapsulates the output of a task and provides methods to access the output data.
It can be subclassed to implement specific output behaviors.
This class encapsulates the output of a task and provides methods to access the
output data. It can be subclassed to implement specific output behaviors.
"""
@property
@@ -56,20 +76,30 @@ class TaskOutput(ABC, Generic[T]):
return False
@property
def output(self) -> Optional[T]:
def is_none(self) -> bool:
"""Check if the output is None.
Returns:
bool: True if the output is None, False otherwise.
"""
return False
@property
def output(self) -> T:
"""Return the output of the task.
Returns:
T: The output of the task. None if the output is empty.
T: The output of the task.
"""
raise NotImplementedError
@property
def output_stream(self) -> Optional[AsyncIterator[T]]:
def output_stream(self) -> AsyncIterator[T]:
"""Return the output of the task as an asynchronous stream.
Returns:
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
AsyncIterator[T]: An asynchronous iterator over the output. None if the
output is empty.
"""
raise NotImplementedError
@@ -83,39 +113,38 @@ class TaskOutput(ABC, Generic[T]):
@abstractmethod
def new_output(self) -> "TaskOutput[T]":
"""Create new output object"""
"""Create new output object."""
async def map(self, map_func) -> "TaskOutput[T]":
async def map(self, map_func: MapFunc) -> "TaskOutput[OUT]":
"""Apply a mapping function to the task's output.
Args:
map_func: A function to apply to the task's output.
map_func (MapFunc): A function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the mapping function.
TaskOutput[OUT]: The result of applying the mapping function.
"""
raise NotImplementedError
async def reduce(self, reduce_func) -> "TaskOutput[T]":
async def reduce(self, reduce_func: ReduceFunc) -> "TaskOutput[OUT]":
"""Apply a reducing function to the task's output.
Stream TaskOutput to Nonstream TaskOutput.
Convert a stream TaskOutput to a non-stream TaskOutput.
Args:
reduce_func: A reducing function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the reducing function.
TaskOutput[OUT]: The result of applying the reducing function.
"""
raise NotImplementedError
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> "TaskOutput[T]":
async def streamify(self, transform_func: StreamFunc) -> "TaskOutput[T]":
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
Args:
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
transform_func (StreamFunc): Function to transform a T value into an
AsyncIterator[OUT].
Returns:
TaskOutput[T]: The result of applying the transform function.
@@ -123,38 +152,39 @@ class TaskOutput(ABC, Generic[T]):
raise NotImplementedError
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
self, transform_func: TransformFunc
) -> "TaskOutput[OUT]":
"""Transform an AsyncIterator[T] to another AsyncIterator[T].
Args:
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
apply to the AsyncIterator[T].
Returns:
TaskOutput[OUT]: The result of applying the transform function.
"""
raise NotImplementedError
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> "TaskOutput[T]":
async def unstreamify(self, transform_func: UnStreamFunc) -> "TaskOutput[OUT]":
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
Args:
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
transform_func (UnStreamFunc): Function to transform an AsyncIterator[T]
into a T value.
Returns:
TaskOutput[OUT]: The result of applying the transform function.
"""
raise NotImplementedError
async def check_condition(self, condition_func) -> bool:
async def check_condition(self, condition_func) -> "TaskOutput[OUT]":
"""Check if current output meets a given condition.
Args:
condition_func: A function to determine if the condition is met.
Returns:
bool: True if current output meet the condition, False otherwise.
TaskOutput[OUT]: The output if the condition is met.
If the condition is not met, return an empty output.
"""
raise NotImplementedError
@@ -182,6 +212,9 @@ class TaskContext(ABC, Generic[T]):
Returns:
InputContext: The InputContext of current task.
Raises:
Exception: If the InputContext is not set.
"""
@abstractmethod
@@ -216,7 +249,7 @@ class TaskContext(ABC, Generic[T]):
@abstractmethod
def set_current_state(self, task_state: TaskState) -> None:
"""Set current task state
"""Set current task state.
Args:
task_state (TaskState): The task state to be set.
@@ -224,7 +257,7 @@ class TaskContext(ABC, Generic[T]):
@abstractmethod
def new_ctx(self) -> "TaskContext":
"""Create new task context
"""Create new task context.
Returns:
TaskContext: A new instance of a TaskContext.
@@ -233,14 +266,14 @@ class TaskContext(ABC, Generic[T]):
@property
@abstractmethod
def metadata(self) -> Dict[str, Any]:
"""Get the metadata of current task
"""Return the metadata of current task.
Returns:
Dict[str, Any]: The metadata
"""
def update_metadata(self, key: str, value: Any) -> None:
"""Update metadata with key and value
"""Update metadata with key and value.
Args:
key (str): The key of metadata
@@ -250,15 +283,15 @@ class TaskContext(ABC, Generic[T]):
@property
def call_data(self) -> Optional[Dict]:
"""Get the call data for current data"""
"""Return the call data for current data."""
return self.metadata.get("call_data")
@abstractmethod
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
"""Get the call data for current data"""
"""Get the call data for current data."""
def set_call_data(self, call_data: Dict) -> None:
"""Set call data for current task"""
"""Save the call data for current task."""
self.update_metadata("call_data", call_data)
@@ -315,7 +348,8 @@ class InputContext(ABC):
"""Filter the inputs based on a provided function.
Args:
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
filter_func (Callable[[Any], bool]): A function that returns True for
inputs to keep.
Returns:
InputContext: A new InputContext instance with the filtered inputs.
@@ -323,13 +357,15 @@ class InputContext(ABC):
@abstractmethod
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
self, predicate_func: PredicateFunc, failed_value: Any = None
) -> "InputContext":
"""Predicate the inputs based on a provided function.
Args:
predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True.
failed_value (Any): The value to be set if the return value of predicate function is False
predicate_func (PredicateFunc): A function that returns True for inputs
that satisfy the predicate.
failed_value (Any): The value to set when the predicate function
returns False.
Returns:
InputContext: A new InputContext instance with the predicate inputs.
"""

View File

@@ -1,3 +1,7 @@
"""The default implementation of Task.
This implementation can run workflow in local machine.
"""
import asyncio
import logging
from abc import ABC, abstractmethod
@@ -8,15 +12,32 @@ from typing import (
Coroutine,
Dict,
Generic,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState
from .base import (
_EMPTY_DATA_TYPE,
EMPTY_DATA,
OUT,
PLACEHOLDER_DATA,
SKIP_DATA,
InputContext,
InputSource,
MapFunc,
PredicateFunc,
ReduceFunc,
StreamFunc,
T,
TaskContext,
TaskOutput,
TaskState,
TransformFunc,
UnStreamFunc,
)
logger = logging.getLogger(__name__)
@@ -37,101 +58,197 @@ async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: T) -> None:
"""The default implementation of TaskOutput.
It wraps non-stream data and provides some basic data operations.
"""
def __init__(self, data: Union[T, _EMPTY_DATA_TYPE] = EMPTY_DATA) -> None:
"""Create a SimpleTaskOutput.
Args:
data (Union[T, _EMPTY_DATA_TYPE], optional): The output data. Defaults to
EMPTY_DATA.
"""
super().__init__()
self._data = data
@property
def output(self) -> T:
return self._data
"""Return the output data."""
if self._data == EMPTY_DATA:
raise ValueError("No output data for current task output")
return cast(T, self._data)
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
self._data = output_data
"""Save the output data to current object.
Args:
output_data (T | AsyncIterator[T]): The output data.
"""
if _is_async_iterator(output_data):
raise ValueError(
f"Can not set stream data {output_data} to SimpleTaskOutput"
)
self._data = cast(T, output_data)
def new_output(self) -> TaskOutput[T]:
return SimpleTaskOutput(None)
"""Create new output object with empty data."""
return SimpleTaskOutput()
@property
def is_empty(self) -> bool:
"""Return True if the output data is empty."""
return self._data == EMPTY_DATA or self._data == SKIP_DATA
@property
def is_none(self) -> bool:
"""Return True if the output data is None."""
return self._data is None
async def _apply_func(self, func) -> Any:
"""Apply the function to current output data."""
if asyncio.iscoroutinefunction(func):
out = await func(self._data)
else:
out = func(self._data)
return out
async def map(self, map_func) -> TaskOutput[T]:
async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
"""Apply a mapping function to the task's output.
Args:
map_func (MapFunc): A function to apply to the task's output.
Returns:
TaskOutput[OUT]: The result of applying the mapping function.
"""
out = await self._apply_func(map_func)
return SimpleTaskOutput(out)
async def check_condition(self, condition_func) -> bool:
return await self._apply_func(condition_func)
async def check_condition(self, condition_func) -> TaskOutput[OUT]:
"""Check the condition function."""
out = await self._apply_func(condition_func)
if out:
return SimpleTaskOutput(PLACEHOLDER_DATA)
return SimpleTaskOutput(EMPTY_DATA)
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> TaskOutput[T]:
async def streamify(self, transform_func: StreamFunc) -> TaskOutput[OUT]:
"""Transform the task's output to a stream output.
Args:
transform_func (StreamFunc): A function to transform the task's output to a
stream output.
Returns:
TaskOutput[OUT]: The result of transforming the task's output to a stream
output.
"""
out = await self._apply_func(transform_func)
return SimpleStreamTaskOutput(out)
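
A short usage sketch of the non-stream output above; the import path dbgpt.core.awel.task.task_impl is an assumption based on the file layout in this change:

import asyncio
from dbgpt.core.awel.task.task_impl import SimpleTaskOutput

async def main() -> None:
    out = SimpleTaskOutput(2)
    doubled = await out.map(lambda x: x * 2)  # sync mapper is applied directly
    assert doubled.output == 4

    async def add_one(x: int) -> int:
        return x + 1

    plus_one = await out.map(add_one)  # async mapper is awaited by _apply_func
    assert plus_one.output == 3

asyncio.run(main())
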
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: AsyncIterator[T]) -> None:
"""The default stream implementation of TaskOutput."""
def __init__(
self, data: Union[AsyncIterator[T], _EMPTY_DATA_TYPE] = EMPTY_DATA
) -> None:
"""Create a SimpleStreamTaskOutput.
Args:
data (Union[AsyncIterator[T], _EMPTY_DATA_TYPE], optional): The output data.
Defaults to EMPTY_DATA.
"""
super().__init__()
self._data = data
@property
def is_stream(self) -> bool:
"""Return True if the output data is a stream."""
return True
@property
def is_empty(self) -> bool:
return not self._data
"""Return True if the output data is empty."""
return self._data == EMPTY_DATA or self._data == SKIP_DATA
@property
def is_none(self) -> bool:
"""Return True if the output data is None."""
return self._data is None
@property
def output_stream(self) -> AsyncIterator[T]:
return self._data
"""Return the output data.
Returns:
AsyncIterator[T]: The output data.
Raises:
ValueError: If the output data is empty.
"""
if self._data == EMPTY_DATA:
raise ValueError("No output data for current task output")
return cast(AsyncIterator[T], self._data)
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
self._data = output_data
"""Save the output data to current object.
Raises:
ValueError: If the output data is not a stream.
"""
if not _is_async_iterator(output_data):
raise ValueError(
f"Can not set non-stream data {output_data} to SimpleStreamTaskOutput"
)
self._data = cast(AsyncIterator[T], output_data)
def new_output(self) -> TaskOutput[T]:
return SimpleStreamTaskOutput(None)
"""Create new output object with empty data."""
return SimpleStreamTaskOutput()
async def map(self, map_func) -> TaskOutput[T]:
async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
"""Apply a mapping function to the task's output."""
is_async = asyncio.iscoroutinefunction(map_func)
async def new_iter() -> AsyncIterator[T]:
async for out in self._data:
async def new_iter() -> AsyncIterator[OUT]:
async for out in self.output_stream:
if is_async:
out = await map_func(out)
new_out: OUT = await map_func(out)
else:
out = map_func(out)
yield out
new_out = cast(OUT, map_func(out))
yield new_out
return SimpleStreamTaskOutput(new_iter())
async def reduce(self, reduce_func) -> TaskOutput[T]:
out = await _reduce_stream(self._data, reduce_func)
async def reduce(self, reduce_func: ReduceFunc) -> TaskOutput[OUT]:
"""Apply a reduce function to the task's output."""
out = await _reduce_stream(self.output_stream, reduce_func)
return SimpleTaskOutput(out)
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> TaskOutput[T]:
async def unstreamify(self, transform_func: UnStreamFunc) -> TaskOutput[OUT]:
"""Transform the task's output to a non-stream output."""
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
out = await transform_func(self.output_stream)
else:
out = transform_func(self._data)
out = transform_func(self.output_stream)
return SimpleTaskOutput(out)
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> TaskOutput[T]:
async def transform_stream(self, transform_func: TransformFunc) -> TaskOutput[OUT]:
"""Transform an AsyncIterator[T] to another AsyncIterator[T].
Args:
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
apply to the AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
out: AsyncIterator[OUT] = await transform_func(self.output_stream)
else:
out = transform_func(self._data)
out = cast(AsyncIterator[OUT], transform_func(self.output_stream))
return SimpleStreamTaskOutput(out)
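
A matching sketch for the stream side, chaining map and unstreamify; the import path is again an assumption:

import asyncio
from typing import AsyncIterator

from dbgpt.core.awel.task.task_impl import SimpleStreamTaskOutput

async def numbers() -> AsyncIterator[int]:
    for i in range(3):
        yield i

async def collect_sum(stream: AsyncIterator[int]) -> int:
    total = 0
    async for item in stream:
        total += item
    return total

async def main() -> None:
    stream_out = SimpleStreamTaskOutput(numbers())
    doubled = await stream_out.map(lambda x: x * 2)  # still a stream output
    final = await doubled.unstreamify(collect_sum)   # now a non-stream output
    assert final.output == 6  # 0 + 2 + 4

asyncio.run(main())
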
@@ -145,20 +262,34 @@ def _is_async_iterator(obj):
class BaseInputSource(InputSource, ABC):
"""The base class of InputSource."""
def __init__(self) -> None:
"""Create a BaseInputSource."""
super().__init__()
self._is_read = False
@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
"""Read data with task context"""
"""Return data with task context."""
async def read(self, task_ctx: TaskContext) -> TaskOutput:
"""Read data with task context.
Args:
task_ctx (TaskContext): The task context.
Returns:
TaskOutput: The task output.
Raises:
ValueError: If the input source is a stream and has been read.
"""
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output = SimpleStreamTaskOutput(data)
output: TaskOutput = SimpleStreamTaskOutput(data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
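
Concrete sources only need to implement _read_data; a hedged sketch of a custom source (the class name and import paths are illustrative assumptions):

import os
from typing import Any

from dbgpt.core.awel.task.base import TaskContext
from dbgpt.core.awel.task.task_impl import BaseInputSource

class EnvInputSource(BaseInputSource):
    """Read the input value from an environment variable instead of static data."""

    def __init__(self, env_name: str) -> None:
        super().__init__()
        self._env_name = env_name

    def _read_data(self, task_ctx: TaskContext) -> Any:
        # read() wraps whatever is returned here into a SimpleTaskOutput.
        return os.environ.get(self._env_name, "")
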
@@ -166,7 +297,14 @@ class BaseInputSource(InputSource, ABC):
class SimpleInputSource(BaseInputSource):
"""The default implementation of InputSource."""
def __init__(self, data: Any) -> None:
"""Create a SimpleInputSource.
Args:
data (Any): The input data.
"""
super().__init__()
self._data = data
@@ -175,63 +313,121 @@ class SimpleInputSource(BaseInputSource):
class SimpleCallDataInputSource(BaseInputSource):
"""The implementation of InputSource for call data."""
def __init__(self) -> None:
"""Create a SimpleCallDataInputSource."""
super().__init__()
def _read_data(self, task_ctx: TaskContext) -> Any:
"""Read data from task context.
Returns:
Any: The data.
Raises:
ValueError: If the call data is empty.
"""
call_data = task_ctx.call_data
data = call_data.get("data") if call_data else None
if not (call_data and data):
data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
if data == EMPTY_DATA:
raise ValueError("No call data for current SimpleCallDataInputSource")
return data
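
A hedged sketch of how call data reaches this source at runtime, using the DefaultTaskContext defined below (import paths assumed from this change):

import asyncio
from dbgpt.core.awel.task.base import TaskState
from dbgpt.core.awel.task.task_impl import DefaultTaskContext, SimpleCallDataInputSource

async def main() -> None:
    # Any task state works for this sketch; TaskState.SUCCESS is the one used in the tests.
    ctx = DefaultTaskContext("task-1", TaskState.SUCCESS)
    ctx.set_call_data({"data": "hello from call_data"})

    source = SimpleCallDataInputSource()
    output = await source.read(ctx)  # _read_data pulls "data" out of the call data
    assert output.output == "hello from call_data"

asyncio.run(main())
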
class DefaultTaskContext(TaskContext, Generic[T]):
"""The default implementation of TaskContext."""
def __init__(
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
self,
task_id: str,
task_state: TaskState,
task_output: Optional[TaskOutput[T]] = None,
) -> None:
"""Create a DefaultTaskContext.
Args:
task_id (str): The task id.
task_state (TaskState): The task state.
task_output (Optional[TaskOutput[T]], optional): The task output. Defaults
to None.
"""
super().__init__()
self._task_id = task_id
self._task_state = task_state
self._output = task_output
self._task_input = None
self._metadata = {}
self._output: Optional[TaskOutput[T]] = task_output
self._task_input: Optional[InputContext] = None
self._metadata: Dict[str, Any] = {}
@property
def task_id(self) -> str:
"""Return the task id."""
return self._task_id
@property
def task_input(self) -> InputContext:
"""Return the task input."""
if not self._task_input:
raise ValueError("No input for current task context")
return self._task_input
def set_task_input(self, input_ctx: "InputContext") -> None:
def set_task_input(self, input_ctx: InputContext) -> None:
"""Save the task input to current task."""
self._task_input = input_ctx
@property
def task_output(self) -> TaskOutput:
"""Return the task output.
Returns:
TaskOutput: The task output.
Raises:
ValueError: If the task output is empty.
"""
if not self._output:
raise ValueError("No output for current task context")
return self._output
def set_task_output(self, task_output: TaskOutput) -> None:
"""Save the task output to current task.
Args:
task_output (TaskOutput): The task output.
"""
self._output = task_output
@property
def current_state(self) -> TaskState:
"""Return the current task state."""
return self._task_state
def set_current_state(self, task_state: TaskState) -> None:
"""Save the current task state to current task."""
self._task_state = task_state
def new_ctx(self) -> TaskContext:
"""Create new task context with empty output."""
if not self._output:
raise ValueError("No output for current task context")
new_output = self._output.new_output()
return DefaultTaskContext(self._task_id, self._task_state, new_output)
@property
def metadata(self) -> Dict[str, Any]:
"""Return the metadata of current task.
Returns:
Dict[str, Any]: The metadata.
"""
return self._metadata
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
"""Get the call data for current data"""
"""Return the call data of current task.
Returns:
Optional[TaskOutput[T]]: The call data.
"""
call_data = self.call_data
if not call_data:
return None
@@ -240,24 +436,48 @@ class DefaultTaskContext(TaskContext, Generic[T]):
class DefaultInputContext(InputContext):
"""The default implementation of InputContext.
It wraps all inputs from parent tasks and provides some basic data operations.
"""
def __init__(self, outputs: List[TaskContext]) -> None:
"""Create a DefaultInputContext.
Args:
outputs (List[TaskContext]): The outputs from parent tasks.
"""
super().__init__()
self._outputs = outputs
@property
def parent_outputs(self) -> List[TaskContext]:
"""Return the outputs from parent tasks.
Returns:
List[TaskContext]: The outputs from parent tasks.
"""
return self._outputs
async def _apply_func(
self, func: Callable[[Any], Any], apply_type: str = "map"
) -> Tuple[List[TaskContext], List[TaskOutput]]:
"""Apply the function to all parent outputs.
Args:
func (Callable[[Any], Any]): The function to apply.
apply_type (str, optional): The apply type. Defaults to "map".
Returns:
Tuple[List[TaskContext], List[TaskOutput]]: The new parent outputs and the
results of applying the function.
"""
new_outputs: List[TaskContext] = []
map_tasks = []
for out in self._outputs:
new_outputs.append(out.new_ctx())
result = None
if apply_type == "map":
result = out.task_output.map(func)
result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func)
elif apply_type == "reduce":
result = out.task_output.reduce(func)
elif apply_type == "check_condition":
@@ -269,29 +489,40 @@ class DefaultInputContext(InputContext):
return new_outputs, results
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
"""Apply a mapping function to all parent outputs."""
new_outputs, results = await self._apply_func(map_func)
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx = cast(TaskContext, task_ctx)
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
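
A small sketch of map applied across two parent task contexts (import paths assumed from this change):

import asyncio
from dbgpt.core.awel.task.base import TaskState
from dbgpt.core.awel.task.task_impl import (
    DefaultInputContext,
    DefaultTaskContext,
    SimpleTaskOutput,
)

def parent(task_id: str, value: int) -> DefaultTaskContext:
    # Each parent context carries a finished non-stream output.
    return DefaultTaskContext(task_id, TaskState.SUCCESS, SimpleTaskOutput(value))

async def main() -> None:
    inputs = DefaultInputContext([parent("a", 1), parent("b", 2)])
    mapped = await inputs.map(lambda x: x * 10)
    values = [ctx.task_output.output for ctx in mapped.parent_outputs]
    assert values == [10, 20]

asyncio.run(main())
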
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
"""Apply a mapping function to all parent outputs.
The parent outputs will be unpacked and passed to the mapping function.
Args:
map_func (Callable[..., Any]): The mapping function.
Returns:
InputContext: The new input context.
"""
if not self._outputs:
return DefaultInputContext([])
# Some parent may be empty
not_empty_idx = 0
for i, p in enumerate(self._outputs):
if p.task_output.is_empty:
# Skip empty parent
continue
not_empty_idx = i
break
# All output is empty?
is_steam = self._outputs[not_empty_idx].task_output.is_stream
if is_steam:
if not self.check_stream(skip_empty=True):
raise ValueError(
"The output in all tasks must has same output format to map_all"
)
if is_steam and not self.check_stream(skip_empty=True):
raise ValueError(
"The output in all tasks must has same output format to map_all"
)
outputs = []
for out in self._outputs:
if out.task_output.is_stream:
@@ -305,22 +536,26 @@ class DefaultInputContext(InputContext):
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
single_output.task_output.set_output(map_res)
logger.debug(
f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
f"Current map_all map_res: {map_res}, is steam: "
f"{single_output.task_output.is_stream}"
)
return DefaultInputContext([single_output])
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
"""Apply a reduce function to all parent outputs."""
if not self.check_stream():
raise ValueError(
"The output in all tasks must has same output format of stream to apply reduce function"
"The output in all tasks must has same output format of stream to apply"
" reduce function"
)
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx = cast(TaskContext, task_ctx)
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
"""Filter all parent outputs."""
new_outputs, results = await self._apply_func(
filter_func, apply_type="check_condition"
)
@@ -331,15 +566,16 @@ class DefaultInputContext(InputContext):
return DefaultInputContext(result_outputs)
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
self, predicate_func: PredicateFunc, failed_value: Any = None
) -> "InputContext":
"""Apply a predicate function to all parent outputs."""
new_outputs, results = await self._apply_func(
predicate_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
if results[i]:
task_ctx = cast(TaskContext, task_ctx)
if not results[i].is_empty:
task_ctx.task_output.set_output(True)
result_outputs.append(task_ctx)
else:

View File

@@ -66,10 +66,10 @@ async def _create_input_node(**kwargs):
else:
outputs = kwargs.get("outputs", ["Hello."])
nodes = []
for output in outputs:
for i, output in enumerate(outputs):
print(f"output: {output}")
input_source = SimpleInputSource(output)
input_node = InputOperator(input_source)
input_node = InputOperator(input_source, task_id="input_node_" + str(i))
nodes.append(input_node)
yield nodes

View File

@@ -26,7 +26,7 @@ from .conftest import (
@pytest.mark.asyncio
async def test_input_node(runner: WorkflowRunner):
input_node = InputOperator(SimpleInputSource("hello"))
input_node = InputOperator(SimpleInputSource("hello"), task_id="112232")
res: DAGContext[str] = await runner.execute_workflow(input_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert res.current_task_context.task_output.output == "hello"
@@ -36,7 +36,9 @@ async def test_input_node(runner: WorkflowRunner):
yield i
num_iter = 10
steam_input_node = InputOperator(SimpleInputSource(new_steam_iter(num_iter)))
steam_input_node = InputOperator(
SimpleInputSource(new_steam_iter(num_iter)), task_id="112232"
)
res: DAGContext[str] = await runner.execute_workflow(steam_input_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
output_steam = res.current_task_context.task_output.output_stream

View File

@@ -0,0 +1 @@
"""The trigger module of AWEL."""

View File

@@ -1,3 +1,4 @@
"""Base class for all trigger classes."""
from __future__ import annotations
from abc import ABC, abstractmethod
@@ -6,6 +7,11 @@ from ..operator.common_operator import TriggerOperator
class Trigger(TriggerOperator, ABC):
"""Base class for all trigger classes.
Currently only HTTP triggers are supported.
"""
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@@ -1,10 +1,11 @@
"""Http trigger for AWEL."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast
from starlette.requests import Request
from starlette.responses import Response
from dbgpt._private.pydantic import BaseModel
@@ -13,29 +14,35 @@ from ..operator.base import BaseOperator
from .base import Trigger
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI
from fastapi import APIRouter
RequestBody = Union[Type[Request], Type[BaseModel], str]
RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
logger = logging.getLogger(__name__)
class HttpTrigger(Trigger):
"""Http trigger for AWEL.
Http trigger is used to trigger a DAG by http request.
"""
def __init__(
self,
endpoint: str,
methods: Optional[Union[str, List[str]]] = "GET",
request_body: Optional[RequestBody] = None,
streaming_response: Optional[bool] = False,
streaming_response: bool = False,
streaming_predict_func: Optional[StreamingPredictFunc] = None,
response_model: Optional[Type] = None,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
status_code: Optional[int] = 200,
router_tags: Optional[List[str]] = None,
router_tags: Optional[List[str | Enum]] = None,
**kwargs,
) -> None:
"""Initialize a HttpTrigger."""
super().__init__(**kwargs)
if not endpoint.startswith("/"):
endpoint = "/" + endpoint
@@ -49,15 +56,21 @@ class HttpTrigger(Trigger):
self._router_tags = router_tags
self._response_headers = response_headers
self._response_media_type = response_media_type
self._end_node: BaseOperator = None
self._end_node: Optional[BaseOperator] = None
async def trigger(self) -> None:
"""Trigger the DAG. Not used in HttpTrigger."""
pass
def mount_to_router(self, router: "APIRouter") -> None:
"""Mount the trigger to a router.
Args:
router (APIRouter): The router to mount the trigger.
"""
from fastapi import Depends
methods = self._methods if isinstance(self._methods, list) else [self._methods]
methods = [self._methods] if isinstance(self._methods, str) else self._methods
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
async def _request_body_dependency(request: Request):
@@ -87,7 +100,8 @@ class HttpTrigger(Trigger):
)
dynamic_route_function = create_route_function(function_name, request_model)
logger.info(
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
f"mount router function {dynamic_route_function}({function_name}), "
f"endpoint: {self._endpoint}, methods: {methods}"
)
router.api_route(
@@ -100,17 +114,27 @@ class HttpTrigger(Trigger):
async def _parse_request_body(
request: Request, request_body_cls: Optional[Type[BaseModel]]
request: Request, request_body_cls: Optional[RequestBody]
):
if not request_body_cls:
return None
if request.method == "POST":
json_data = await request.json()
return request_body_cls(**json_data)
elif request.method == "GET":
return request_body_cls(**request.query_params)
else:
if request_body_cls == Request:
return request
if request.method == "POST":
if request_body_cls == str:
bytes_body = await request.body()
str_body = bytes_body.decode("utf-8")
return str_body
elif issubclass(request_body_cls, BaseModel):
json_data = await request.json()
return request_body_cls(**json_data)
else:
raise ValueError(f"Invalid request body cls: {request_body_cls}")
elif request.method == "GET":
if issubclass(request_body_cls, BaseModel):
return request_body_cls(**request.query_params)
else:
raise ValueError(f"Invalid request body cls: {request_body_cls}")
async def _trigger_dag(
@@ -123,10 +147,10 @@ async def _trigger_dag(
from fastapi import BackgroundTasks
from fastapi.responses import StreamingResponse
end_node = dag.leaf_nodes
if len(end_node) != 1:
leaf_nodes = dag.leaf_nodes
if len(leaf_nodes) != 1:
raise ValueError("HttpTrigger just support one leaf node in dag")
end_node = end_node[0]
end_node = cast(BaseOperator, leaf_nodes[0])
if not streaming_response:
return await end_node.call(call_data={"data": body})
else:
@@ -141,7 +165,7 @@ async def _trigger_dag(
}
generator = await end_node.call_stream(call_data={"data": body})
background_tasks = BackgroundTasks()
background_tasks.add_task(end_node.dag._after_dag_end)
background_tasks.add_task(dag._after_dag_end)
return StreamingResponse(
generator,
headers=headers,

View File

@@ -1,41 +1,63 @@
"""Trigger manager for AWEL.
After DB-GPT starts, the trigger manager is initialized and registers all triggers.
"""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .base import Trigger
if TYPE_CHECKING:
from fastapi import APIRouter
from dbgpt.component import BaseComponent, ComponentType, SystemApp
logger = logging.getLogger(__name__)
class TriggerManager(ABC):
"""Base class for trigger manager."""
@abstractmethod
def register_trigger(self, trigger: Any) -> None:
""" "Register a trigger to current manager"""
"""Register a trigger to current manager."""
class HttpTriggerManager(TriggerManager):
"""Http trigger manager.
Register all http triggers to a router.
"""
def __init__(
self,
router: Optional["APIRouter"] = None,
router_prefix: Optional[str] = "/api/v1/awel/trigger",
router_prefix: str = "/api/v1/awel/trigger",
) -> None:
"""Initialize a HttpTriggerManager.
Args:
router (Optional["APIRouter"], optional): The router. Defaults to None.
If None, will create a new FastAPI router.
router_prefix (str, optional): The router prefix. Defaults
to "/api/v1/awel/trigger".
"""
if not router:
from fastapi import APIRouter
router = APIRouter()
self._router_prefix = router_prefix
self._router = router
self._trigger_map = {}
self._trigger_map: Dict[str, Trigger] = {}
def register_trigger(self, trigger: Any) -> None:
"""Register a trigger to current manager."""
from .http_trigger import HttpTrigger
if not isinstance(trigger, HttpTrigger):
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
trigger: HttpTrigger = trigger
trigger_id = trigger.node_id
if trigger_id not in self._trigger_map:
trigger.mount_to_router(self._router)
@@ -45,23 +67,32 @@ class HttpTriggerManager(TriggerManager):
logger.info(
f"Include router {self._router} to prefix path {self._router_prefix}"
)
system_app.app.include_router(
self._router, prefix=self._router_prefix, tags=["AWEL"]
)
app = system_app.app
if not app:
raise RuntimeError("System app not initialized")
app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"])
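
A hedged sketch of driving HttpTriggerManager directly with an existing APIRouter; the import path and the explicit task_id are assumptions, and in normal use DefaultTriggerManager below performs this wiring through the SystemApp:

from fastapi import APIRouter, FastAPI
from dbgpt.core.awel import HttpTrigger
from dbgpt.core.awel.trigger.trigger_manager import HttpTriggerManager

router = APIRouter()
manager = HttpTriggerManager(router=router, router_prefix="/api/v1/awel/trigger")
# task_id is passed explicitly because this trigger is created outside a DAG context.
manager.register_trigger(HttpTrigger(endpoint="/ping", task_id="ping_trigger"))

app = FastAPI()
app.include_router(router, prefix="/api/v1/awel/trigger", tags=["AWEL"])
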
class DefaultTriggerManager(TriggerManager, BaseComponent):
"""Default trigger manager for AWEL.
Manage all trigger managers. Only HTTP triggers are supported now.
"""
name = ComponentType.AWEL_TRIGGER_MANAGER
def __init__(self, system_app: SystemApp | None = None):
"""Initialize a DefaultTriggerManager."""
self.system_app = system_app
self.http_trigger = HttpTriggerManager()
super().__init__(None)
def init_app(self, system_app: SystemApp):
"""Initialize the trigger manager."""
self.system_app = system_app
def register_trigger(self, trigger: Any) -> None:
"""Register a trigger to current manager."""
from .http_trigger import HttpTrigger
if isinstance(trigger, HttpTrigger):
@@ -71,4 +102,6 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
raise ValueError(f"Unsupport trigger: {trigger}")
def after_register(self) -> None:
self.http_trigger._init_app(self.system_app)
"""After register, init the trigger manager."""
if self.system_app:
self.http_trigger._init_app(self.system_app)