Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-01 01:04:43 +00:00)

chore: Add pylint for DB-GPT core lib (#1076)
@@ -1,9 +1,10 @@
"""Agentic Workflow Expression Language (AWEL)
"""Agentic Workflow Expression Language (AWEL).

Note:
AWEL is still an experimental feature and only opens the lowest level API.
The stability of this API cannot be guaranteed at present.
Agentic Workflow Expression Language(AWEL) is a set of intelligent agent workflow
expression language specially designed for large model application development. It
provides great functionality and flexibility. Through the AWEL API, you can focus on
the development of business logic for LLMs applications without paying attention to
cumbersome model and environment details.

"""

@@ -71,10 +72,12 @@ __all__ = [
"TransformStreamAbsOperator",
"HttpTrigger",
"setup_dev_environment",
"_is_async_iterator",
]


def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
"""Initialize AWEL."""
from .dag.base import DAGVar
from .dag.dag_manager import DAGManager
from .operator.base import initialize_runner

@@ -92,13 +95,13 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):

def setup_dev_environment(
dags: List[DAG],
host: Optional[str] = "127.0.0.1",
port: Optional[int] = 5555,
host: str = "127.0.0.1",
port: int = 5555,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
show_dag_graph: Optional[bool] = True,
) -> None:
"""Setup a development environment for AWEL.
"""Run AWEL in development environment.

Just using in development environment, not production environment.

@@ -107,9 +110,11 @@ def setup_dev_environment(
host (Optional[str], optional): The host. Defaults to "127.0.0.1"
port (Optional[int], optional): The port. Defaults to 5555.
logging_level (Optional[str], optional): The logging level. Defaults to None.
logger_filename (Optional[str], optional): The logger filename. Defaults to None.
show_dag_graph (Optional[bool], optional): Whether show the DAG graph. Defaults to True.
If True, the DAG graph will be saved to a file and open it automatically.
logger_filename (Optional[str], optional): The logger filename.
Defaults to None.
show_dag_graph (Optional[bool], optional): Whether show the DAG graph.
Defaults to True. If True, the DAG graph will be saved to a file and open
it automatically.
"""
import uvicorn
from fastapi import FastAPI
@@ -138,7 +143,9 @@ def setup_dev_environment(
logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}")
except Exception as e:
logger.warning(
f"Visualize DAG {str(dag)} failed: {e}, if your system has no graphviz, you can install it by `pip install graphviz` or `sudo apt install graphviz`"
f"Visualize DAG {str(dag)} failed: {e}, if your system has no "
f"graphviz, you can install it by `pip install graphviz` or "
f"`sudo apt install graphviz`"
)
for trigger in dag.trigger_nodes:
trigger_manager.register_trigger(trigger)
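A minimal sketch of how the `setup_dev_environment` entry point above is typically used. The package path `dbgpt.core.awel`, the endpoint path, and the handler function are illustrative assumptions, not part of this diff; `DAG`, `HttpTrigger`, `MapOperator`, and `setup_dev_environment` are taken from the `__all__` list shown above.

    from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment

    async def handle_request(req) -> dict:
        # Illustrative handler; the real request type depends on the trigger config.
        return {"message": "hello, AWEL"}

    with DAG("awel_dev_example") as dag:
        trigger = HttpTrigger("/examples/hello")  # hypothetical endpoint path
        task = MapOperator(map_function=handle_request)
        trigger >> task

    # Starts a local development server (uvicorn/FastAPI) on 127.0.0.1:5555.
    setup_dev_environment([dag], port=5555)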
@@ -1,7 +1,10 @@
"""Base classes for AWEL."""
from abc import ABC, abstractmethod


class Trigger(ABC):
"""Base class for trigger."""

@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

@@ -0,0 +1 @@
"""The module of DAGs."""

@@ -1,3 +1,7 @@
"""The base module of DAG.

DAG is the core component of AWEL, it is used to define the relationship between tasks.
"""
import asyncio
import contextvars
import logging
@@ -6,7 +10,7 @@ import uuid
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast

from dbgpt.component import SystemApp
@@ -27,86 +31,108 @@ def _is_async_context():


class DependencyMixin(ABC):
"""The mixin class for DAGNode.

This class defines the interface for setting upstream and downstream nodes.

And it also implements the operator << and >> for setting upstream
and downstream nodes.
"""

@abstractmethod
def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
def set_upstream(self, nodes: DependencyType) -> None:
"""Set one or more upstream nodes for this node.

Args:
nodes (DependencyType): Upstream nodes to be set to current node.

Returns:
DependencyMixin: Returns self to allow method chaining.

Raises:
ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
ValueError: If no upstream nodes are provided or if an argument is
not a DependencyMixin.
"""

@abstractmethod
def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
def set_downstream(self, nodes: DependencyType) -> None:
"""Set one or more downstream nodes for this node.

Args:
nodes (DependencyType): Downstream nodes to be set to current node.

Returns:
DependencyMixin: Returns self to allow method chaining.

Raises:
ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
ValueError: If no downstream nodes are provided or if an argument is
not a DependencyMixin.
"""

def __lshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self << nodes
"""Set upstream nodes for current node.

Implements: self << nodes.

Example:
.. code-block:: python

.. code-block:: python
# means node.set_upstream(input_node)
node << input_node
# means node2.set_upstream([input_node])
node2 << [input_node]

# means node.set_upstream(input_node)
node << input_node

# means node2.set_upstream([input_node])
node2 << [input_node]
"""
self.set_upstream(nodes)
return nodes

def __rshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self >> nodes
"""Set downstream nodes for current node.

Example:
Implements: self >> nodes.

.. code-block:: python
Examples:
.. code-block:: python

# means node.set_downstream(next_node)
node >> next_node
# means node.set_downstream(next_node)
node >> next_node

# means node2.set_downstream([next_node])
node2 >> [next_node]
# means node2.set_downstream([next_node])
node2 >> [next_node]

"""
self.set_downstream(nodes)
return nodes

def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] >> self"""
"""Set upstream nodes for current node.

Implements: [node] >> self
"""
self.__lshift__(nodes)
return self

def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] << self"""
"""Set downstream nodes for current node.

Implements: [node] << self
"""
self.__rshift__(nodes)
return self
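A short usage sketch of the dependency operators documented above; the operator names are placeholders. Chaining works because `>>` and `<<` return the right-hand nodes, and list forms go through `__rrshift__`/`__rlshift__`.

    # a, b, c and join_task are assumed to be DAGNode/BaseOperator instances.
    a >> b >> c          # a.set_downstream(b); b.set_downstream(c)
    c << b               # equivalent to b >> c
    [a, b] >> join_task  # both a and b become upstream of join_task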
|
||||
|
||||
|
||||
class DAGVar:
|
||||
"""The DAGVar is used to store the current DAG context."""
|
||||
|
||||
_thread_local = threading.local()
|
||||
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
|
||||
_system_app: SystemApp = None
|
||||
_executor: Executor = None
|
||||
_async_local: contextvars.ContextVar = contextvars.ContextVar(
|
||||
"current_dag_stack", default=deque()
|
||||
)
|
||||
_system_app: Optional[SystemApp] = None
|
||||
# The executor for current DAG, this is used run some sync tasks in async DAG
|
||||
_executor: Optional[Executor] = None
|
||||
|
||||
@classmethod
|
||||
def enter_dag(cls, dag) -> None:
|
||||
"""Enter a DAG context.
|
||||
|
||||
Args:
|
||||
dag (DAG): The DAG to enter
|
||||
"""
|
||||
is_async = _is_async_context()
|
||||
if is_async:
|
||||
stack = cls._async_local.get()
|
||||
@@ -119,6 +145,7 @@ class DAGVar:
|
||||
|
||||
@classmethod
|
||||
def exit_dag(cls) -> None:
|
||||
"""Exit a DAG context."""
|
||||
is_async = _is_async_context()
|
||||
if is_async:
|
||||
stack = cls._async_local.get()
|
||||
@@ -134,6 +161,11 @@ class DAGVar:
|
||||
|
||||
@classmethod
|
||||
def get_current_dag(cls) -> Optional["DAG"]:
|
||||
"""Get the current DAG.
|
||||
|
||||
Returns:
|
||||
Optional[DAG]: The current DAG
|
||||
"""
|
||||
is_async = _is_async_context()
|
||||
if is_async:
|
||||
stack = cls._async_local.get()
|
||||
@@ -147,36 +179,56 @@ class DAGVar:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_current_system_app(cls) -> SystemApp:
|
||||
def get_current_system_app(cls) -> Optional[SystemApp]:
|
||||
"""Get the current system app.
|
||||
|
||||
Returns:
|
||||
Optional[SystemApp]: The current system app
|
||||
"""
|
||||
# if not cls._system_app:
|
||||
# raise RuntimeError("System APP not set for DAGVar")
|
||||
return cls._system_app
|
||||
|
||||
@classmethod
|
||||
def set_current_system_app(cls, system_app: SystemApp) -> None:
|
||||
"""Set the current system app.
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app to set
|
||||
"""
|
||||
if cls._system_app:
|
||||
logger.warn("System APP has already set, nothing to do")
|
||||
logger.warning("System APP has already set, nothing to do")
|
||||
else:
|
||||
cls._system_app = system_app
|
||||
|
||||
@classmethod
|
||||
def get_executor(cls) -> Executor:
|
||||
def get_executor(cls) -> Optional[Executor]:
|
||||
"""Get the current executor.
|
||||
|
||||
Returns:
|
||||
Optional[Executor]: The current executor
|
||||
"""
|
||||
return cls._executor
|
||||
|
||||
@classmethod
|
||||
def set_executor(cls, executor: Executor) -> None:
|
||||
"""Set the current executor.
|
||||
|
||||
Args:
|
||||
executor (Executor): The executor to set
|
||||
"""
|
||||
cls._executor = executor
|
||||
|
||||
|
||||
class DAGLifecycle:
|
||||
"""The lifecycle of DAG"""
|
||||
"""The lifecycle of DAG."""
|
||||
|
||||
async def before_dag_run(self):
|
||||
"""The callback before DAG run"""
|
||||
"""Execute before DAG run."""
|
||||
pass
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end,
|
||||
"""Execute after DAG end.
|
||||
|
||||
This method may be called multiple times, please make sure it is idempotent.
|
||||
"""
|
||||
@@ -184,6 +236,8 @@ class DAGLifecycle:
|
||||
|
||||
|
||||
class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
"""The base class of DAGNode."""
|
||||
|
||||
resource_group: Optional[ResourceGroup] = None
|
||||
"""The resource group of current DAGNode"""
|
||||
|
||||
@@ -196,6 +250,17 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
executor: Optional[Executor] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Initialize a DAGNode.
|
||||
|
||||
Args:
|
||||
dag (Optional["DAG"], optional): The DAG to add this node to.
|
||||
Defaults to None.
|
||||
node_id (Optional[str], optional): The node id. Defaults to None.
|
||||
node_name (Optional[str], optional): The node name. Defaults to None.
|
||||
system_app (Optional[SystemApp], optional): The system app.
|
||||
Defaults to None.
|
||||
executor (Optional[Executor], optional): The executor. Defaults to None.
|
||||
"""
|
||||
super().__init__()
|
||||
self._upstream: List["DAGNode"] = []
|
||||
self._downstream: List["DAGNode"] = []
|
||||
@@ -206,24 +271,28 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
self._executor: Optional[Executor] = executor or DAGVar.get_executor()
|
||||
if not node_id and self._dag:
|
||||
node_id = self._dag._new_node_id()
|
||||
self._node_id: str = node_id
|
||||
self._node_name: str = node_name
|
||||
self._node_id: Optional[str] = node_id
|
||||
self._node_name: Optional[str] = node_name
|
||||
|
||||
@property
|
||||
def node_id(self) -> str:
|
||||
"""Return the node id of current DAGNode."""
|
||||
if not self._node_id:
|
||||
raise ValueError("Node id not set for current DAGNode")
|
||||
return self._node_id
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dev_mode(self) -> bool:
|
||||
"""Whether current DAGNode is in dev mode"""
|
||||
"""Whether current DAGNode is in dev mode."""
|
||||
|
||||
@property
|
||||
def system_app(self) -> SystemApp:
|
||||
def system_app(self) -> Optional[SystemApp]:
|
||||
"""Return the system app of current DAGNode."""
|
||||
return self._system_app
|
||||
|
||||
def set_system_app(self, system_app: SystemApp) -> None:
|
||||
"""Set system app for current DAGNode
|
||||
"""Set system app for current DAGNode.
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
@@ -231,50 +300,97 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
self._system_app = system_app
|
||||
|
||||
def set_node_id(self, node_id: str) -> None:
|
||||
"""Set node id for current DAGNode.
|
||||
|
||||
Args:
|
||||
node_id (str): The node id
|
||||
"""
|
||||
self._node_id = node_id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return the hash value of current DAGNode.
|
||||
|
||||
If the node_id is not None, return the hash value of node_id.
|
||||
"""
|
||||
if self.node_id:
|
||||
return hash(self.node_id)
|
||||
else:
|
||||
return super().__hash__()
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Return whether the current DAGNode is equal to other DAGNode."""
|
||||
if not isinstance(other, DAGNode):
|
||||
return False
|
||||
return self.node_id == other.node_id
|
||||
|
||||
@property
|
||||
def node_name(self) -> str:
|
||||
def node_name(self) -> Optional[str]:
|
||||
"""Return the node name of current DAGNode.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The node name of current DAGNode
|
||||
"""
|
||||
return self._node_name
|
||||
|
||||
@property
|
||||
def dag(self) -> "DAG":
|
||||
def dag(self) -> Optional["DAG"]:
|
||||
"""Return the DAG of current DAGNode.
|
||||
|
||||
Returns:
|
||||
Optional["DAG"]: The DAG of current DAGNode
|
||||
"""
|
||||
return self._dag
|
||||
|
||||
def set_upstream(self, nodes: DependencyType) -> "DAGNode":
|
||||
def set_upstream(self, nodes: DependencyType) -> None:
|
||||
"""Set upstream nodes for current node.
|
||||
|
||||
Args:
|
||||
nodes (DependencyType): Upstream nodes to be set to current node.
|
||||
"""
|
||||
self.set_dependency(nodes)
|
||||
|
||||
def set_downstream(self, nodes: DependencyType) -> "DAGNode":
|
||||
def set_downstream(self, nodes: DependencyType) -> None:
|
||||
"""Set downstream nodes for current node.
|
||||
|
||||
Args:
|
||||
nodes (DependencyType): Downstream nodes to be set to current node.
|
||||
"""
|
||||
self.set_dependency(nodes, is_upstream=False)
|
||||
|
||||
@property
|
||||
def upstream(self) -> List["DAGNode"]:
|
||||
"""Return the upstream nodes of current DAGNode.
|
||||
|
||||
Returns:
|
||||
List["DAGNode"]: The upstream nodes of current DAGNode
|
||||
"""
|
||||
return self._upstream
|
||||
|
||||
@property
|
||||
def downstream(self) -> List["DAGNode"]:
|
||||
"""Return the downstream nodes of current DAGNode.
|
||||
|
||||
Returns:
|
||||
List["DAGNode"]: The downstream nodes of current DAGNode
|
||||
"""
|
||||
return self._downstream
|
||||
|
||||
def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
|
||||
"""Set dependency for current node.
|
||||
|
||||
Args:
|
||||
nodes (DependencyType): The nodes to set dependency to current node.
|
||||
is_upstream (bool, optional): Whether set upstream nodes. Defaults to True.
|
||||
"""
|
||||
if not isinstance(nodes, Sequence):
|
||||
nodes = [nodes]
|
||||
if not all(isinstance(node, DAGNode) for node in nodes):
|
||||
raise ValueError(
|
||||
"all nodes to set dependency to current node must be instance of 'DAGNode'"
|
||||
"all nodes to set dependency to current node must be instance "
|
||||
"of 'DAGNode'"
|
||||
)
|
||||
nodes: Sequence[DAGNode] = nodes
|
||||
dags = set([node.dag for node in nodes if node.dag])
|
||||
nodes = cast(Sequence[DAGNode], nodes)
|
||||
dags = set([node.dag for node in nodes if node.dag]) # noqa: C403
|
||||
if self.dag:
|
||||
dags.add(self.dag)
|
||||
if not dags:
|
||||
@@ -302,6 +418,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
node._upstream.append(self)
|
||||
|
||||
def __repr__(self):
|
||||
"""Return the representation of current DAGNode."""
|
||||
cls_name = self.__class__.__name__
|
||||
if self.node_name and self.node_name:
|
||||
return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
|
||||
@@ -313,6 +430,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
|
||||
return f"{cls_name}"
|
||||
|
||||
def __str__(self):
|
||||
"""Return the string of current DAGNode."""
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
@@ -321,7 +439,7 @@ def _build_task_key(task_name: str, key: str) -> str:
|
||||
|
||||
|
||||
class DAGContext:
|
||||
"""The context of current DAG, created when the DAG is running
|
||||
"""The context of current DAG, created when the DAG is running.
|
||||
|
||||
Every DAG has been triggered will create a new DAGContext.
|
||||
"""
|
||||
@@ -329,22 +447,32 @@ class DAGContext:
|
||||
def __init__(
|
||||
self,
|
||||
streaming_call: bool = False,
|
||||
node_to_outputs: Dict[str, TaskContext] = None,
|
||||
node_name_to_ids: Dict[str, str] = None,
|
||||
node_to_outputs: Optional[Dict[str, TaskContext]] = None,
|
||||
node_name_to_ids: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
"""Initialize a DAGContext.
|
||||
|
||||
Args:
|
||||
streaming_call (bool, optional): Whether the current DAG is streaming call.
|
||||
Defaults to False.
|
||||
node_to_outputs (Optional[Dict[str, TaskContext]], optional):
|
||||
The task outputs of current DAG. Defaults to None.
|
||||
node_name_to_ids (Optional[Dict[str, str]], optional):
|
||||
The task name to task id mapping. Defaults to None.
|
||||
"""
|
||||
if not node_to_outputs:
|
||||
node_to_outputs = {}
|
||||
if not node_name_to_ids:
|
||||
node_name_to_ids = {}
|
||||
self._streaming_call = streaming_call
|
||||
self._curr_task_ctx = None
|
||||
self._curr_task_ctx: Optional[TaskContext] = None
|
||||
self._share_data: Dict[str, Any] = {}
|
||||
self._node_to_outputs = node_to_outputs
|
||||
self._node_name_to_ids = node_name_to_ids
|
||||
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
|
||||
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
|
||||
|
||||
@property
|
||||
def _task_outputs(self) -> Dict[str, TaskContext]:
|
||||
"""The task outputs of current DAG
|
||||
"""Return the task outputs of current DAG.
|
||||
|
||||
Just use for internal for now.
|
||||
Returns:
|
||||
@@ -354,18 +482,28 @@ class DAGContext:
|
||||
|
||||
@property
|
||||
def current_task_context(self) -> TaskContext:
|
||||
"""Return the current task context."""
|
||||
if not self._curr_task_ctx:
|
||||
raise RuntimeError("Current task context not set")
|
||||
return self._curr_task_ctx
|
||||
|
||||
@property
|
||||
def streaming_call(self) -> bool:
|
||||
"""Whether the current DAG is streaming call"""
|
||||
"""Whether the current DAG is streaming call."""
|
||||
return self._streaming_call
|
||||
|
||||
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
|
||||
"""Set the current task context.
|
||||
|
||||
When the task is running, the current task context
|
||||
will be set to the task context.
|
||||
|
||||
TODO: We should support parallel task running in the future.
|
||||
"""
|
||||
self._curr_task_ctx = _curr_task_ctx
|
||||
|
||||
def get_task_output(self, task_name: str) -> TaskOutput:
|
||||
"""Get the task output by task name
|
||||
"""Get the task output by task name.
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
@@ -376,22 +514,41 @@ class DAGContext:
|
||||
if task_name is None:
|
||||
raise ValueError("task_name can't be None")
|
||||
node_id = self._node_name_to_ids.get(task_name)
|
||||
if node_id:
|
||||
if not node_id:
|
||||
raise ValueError(f"Task name {task_name} not exists in DAG")
|
||||
return self._task_outputs.get(node_id).task_output
|
||||
task_output = self._task_outputs.get(node_id)
|
||||
if not task_output:
|
||||
raise ValueError(f"Task output for task {task_name} not exists")
|
||||
return task_output.task_output
|
||||
|
||||
async def get_from_share_data(self, key: str) -> Any:
"""Get share data by key.

Args:
key (str): The share data key

Returns:
Any: The share data, you can cast it to the real type
"""
return self._share_data.get(key)

async def save_to_share_data(
self, key: str, data: Any, overwrite: bool = False
) -> None:
"""Save share data by key.

Args:
key (str): The share data key
data (Any): The share data
overwrite (bool): Whether overwrite the share data if the key
already exists. Defaults to None.
"""
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
self._share_data[key] = data
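A sketch of how an operator can use the share-data API above from inside its own logic; the operator classes and the key name are illustrative, and `current_dag_context` is the `BaseOperator` property shown later in this diff.

    class SaveLengthOperator(MapOperator[str, str]):
        async def map(self, text: str) -> str:
            # Store a value for downstream tasks under an illustrative key.
            await self.current_dag_context.save_to_share_data("input_length", len(text))
            return text

    class ReadLengthOperator(MapOperator[str, int]):
        async def map(self, text: str) -> int:
            # Read the value saved by the upstream task.
            return await self.current_dag_context.get_from_share_data("input_length")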
|
||||
|
||||
async def get_task_share_data(self, task_name: str, key: str) -> Any:
|
||||
"""Get share data by task name and key
|
||||
"""Get share data by task name and key.
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
@@ -409,14 +566,14 @@ class DAGContext:
|
||||
async def save_task_share_data(
|
||||
self, task_name: str, key: str, data: Any, overwrite: bool = False
|
||||
) -> None:
|
||||
"""Save share data by task name and key
|
||||
"""Save share data by task name and key.
|
||||
|
||||
Args:
|
||||
task_name (str): The task name
|
||||
key (str): The share data key
|
||||
data (Any): The share data
|
||||
overwrite (bool): Whether overwrite the share data if the key already exists.
|
||||
Defaults to None.
|
||||
overwrite (bool): Whether overwrite the share data if the key
|
||||
already exists. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the share data key already exists and overwrite is not True
|
||||
@@ -429,15 +586,22 @@ class DAGContext:
|
||||
|
||||
|
||||
class DAG:
|
||||
"""The DAG class.
|
||||
|
||||
Manage the DAG nodes and the relationship between them.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
|
||||
) -> None:
|
||||
"""Initialize a DAG."""
|
||||
self._dag_id = dag_id
|
||||
self.node_map: Dict[str, DAGNode] = {}
|
||||
self.node_name_to_node: Dict[str, DAGNode] = {}
|
||||
self._root_nodes: List[DAGNode] = None
|
||||
self._leaf_nodes: List[DAGNode] = None
|
||||
self._trigger_nodes: List[DAGNode] = None
|
||||
self._root_nodes: List[DAGNode] = []
|
||||
self._leaf_nodes: List[DAGNode] = []
|
||||
self._trigger_nodes: List[DAGNode] = []
|
||||
self._resource_group: Optional[ResourceGroup] = resource_group
|
||||
|
||||
def _append_node(self, node: DAGNode) -> None:
|
||||
if node.node_id in self.node_map:
|
||||
@@ -448,22 +612,26 @@ class DAG:
|
||||
f"Node name {node.node_name} already exists in DAG {self.dag_id}"
|
||||
)
|
||||
self.node_name_to_node[node.node_name] = node
|
||||
self.node_map[node.node_id] = node
|
||||
node_id = node.node_id
|
||||
if not node_id:
|
||||
raise ValueError("Node id can't be None")
|
||||
self.node_map[node_id] = node
|
||||
# clear cached nodes
|
||||
self._root_nodes = None
|
||||
self._leaf_nodes = None
|
||||
self._root_nodes = []
|
||||
self._leaf_nodes = []
|
||||
|
||||
def _new_node_id(self) -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
@property
|
||||
def dag_id(self) -> str:
|
||||
"""Return the dag id of current DAG."""
|
||||
return self._dag_id
|
||||
|
||||
def _build(self) -> None:
|
||||
from ..operator.common_operator import TriggerOperator
|
||||
|
||||
nodes = set()
|
||||
nodes: Set[DAGNode] = set()
|
||||
for _, node in self.node_map.items():
|
||||
nodes = nodes.union(_get_nodes(node))
|
||||
self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
|
||||
@@ -474,7 +642,7 @@ class DAG:
|
||||
|
||||
@property
|
||||
def root_nodes(self) -> List[DAGNode]:
|
||||
"""The root nodes of current DAG
|
||||
"""Return the root nodes of current DAG.
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The root nodes of current DAG, no repeat
|
||||
@@ -485,7 +653,7 @@ class DAG:
|
||||
|
||||
@property
|
||||
def leaf_nodes(self) -> List[DAGNode]:
|
||||
"""The leaf nodes of current DAG
|
||||
"""Return the leaf nodes of current DAG.
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The leaf nodes of current DAG, no repeat
|
||||
@@ -496,7 +664,7 @@ class DAG:
|
||||
|
||||
@property
|
||||
def trigger_nodes(self) -> List[DAGNode]:
|
||||
"""The trigger nodes of current DAG
|
||||
"""Return the trigger nodes of current DAG.
|
||||
|
||||
Returns:
|
||||
List[DAGNode]: The trigger nodes of current DAG, no repeat
|
||||
@@ -506,34 +674,42 @@ class DAG:
|
||||
return self._trigger_nodes
|
||||
|
||||
async def _after_dag_end(self) -> None:
|
||||
"""The callback after DAG end"""
|
||||
"""Execute after DAG end."""
|
||||
tasks = []
|
||||
for node in self.node_map.values():
|
||||
tasks.append(node.after_dag_end())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def print_tree(self) -> None:
|
||||
"""Print the DAG tree"""
|
||||
"""Print the DAG tree""" # noqa: D400
|
||||
_print_format_dag_tree(self)
|
||||
|
||||
def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
"""Create the DAG graph"""
"""Visualize the DAG.

Args:
view (bool, optional): Whether view the DAG graph. Defaults to True,
if True, it will open the graph file with your default viewer.
"""
self.print_tree()
return _visualize_dag(self, view=view, **kwargs)

def __enter__(self):
"""Enter a DAG context."""
DAGVar.enter_dag(self)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit a DAG context."""
DAGVar.exit_dag()

def __repr__(self):
"""Return the representation of current DAG."""
return f"DAG(dag_id={self.dag_id})"
|
||||
|
||||
|
||||
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
|
||||
nodes = set()
|
||||
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]:
|
||||
nodes: Set[DAGNode] = set()
|
||||
if not node:
|
||||
return nodes
|
||||
nodes.add(node)
|
||||
@@ -553,7 +729,7 @@ def _print_dag(
|
||||
level: int = 0,
|
||||
prefix: str = "",
|
||||
last: bool = True,
|
||||
level_dict: Dict[str, Any] = None,
|
||||
level_dict: Optional[Dict[int, Any]] = None,
|
||||
):
|
||||
if level_dict is None:
|
||||
level_dict = {}
|
||||
@@ -606,7 +782,7 @@ def _handle_dag_nodes(
|
||||
|
||||
|
||||
def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
|
||||
"""Visualize the DAG
|
||||
"""Visualize the DAG.
|
||||
|
||||
Args:
|
||||
dag (DAG): The DAG to visualize
|
||||
@@ -641,7 +817,7 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
|
||||
filename = kwargs["filename"]
|
||||
del kwargs["filename"]
|
||||
|
||||
if not "directory" in kwargs:
|
||||
if "directory" not in kwargs:
|
||||
from dbgpt.configs.model_config import LOGDIR
|
||||
|
||||
kwargs["directory"] = LOGDIR
|
||||
|
@@ -1,3 +1,8 @@
|
||||
"""DAGManager is a component of AWEL, it is used to manage DAGs.
|
||||
|
||||
DAGManager will load DAGs from dag_dirs, and register the trigger nodes
|
||||
to TriggerManager.
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -10,24 +15,35 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DAGManager(BaseComponent):
|
||||
"""The component of DAGManager."""
|
||||
|
||||
name = ComponentType.AWEL_DAG_MANAGER
|
||||
|
||||
def __init__(self, system_app: SystemApp, dag_dirs: List[str]):
|
||||
"""Initialize a DAGManager.
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app.
|
||||
dag_dirs (List[str]): The directories to load DAGs.
|
||||
"""
|
||||
super().__init__(system_app)
|
||||
self.dag_loader = LocalFileDAGLoader(dag_dirs)
|
||||
self.system_app = system_app
|
||||
self.dag_map: Dict[str, DAG] = {}
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize the DAGManager."""
|
||||
self.system_app = system_app
|
||||
|
||||
def load_dags(self):
|
||||
"""Load DAGs from dag_dirs."""
|
||||
dags = self.dag_loader.load_dags()
|
||||
triggers = []
|
||||
for dag in dags:
|
||||
dag_id = dag.dag_id
|
||||
if dag_id in self.dag_map:
|
||||
raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
|
||||
self.dag_map[dag_id] = dag
|
||||
triggers += dag.trigger_nodes
|
||||
from ..trigger.trigger_manager import DefaultTriggerManager
|
||||
|
||||
|
@@ -1,3 +1,8 @@
|
||||
"""DAG loader.
|
||||
|
||||
DAGLoader will load DAGs from dag_dirs or other sources.
|
||||
Now only support load DAGs from local files.
|
||||
"""
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
@@ -12,16 +17,26 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DAGLoader(ABC):
|
||||
"""Abstract base class representing a loader for loading DAGs."""
|
||||
|
||||
@abstractmethod
|
||||
def load_dags(self) -> List[DAG]:
|
||||
"""Load dags"""
|
||||
"""Load dags."""
|
||||
|
||||
|
||||
class LocalFileDAGLoader(DAGLoader):
"""DAG loader for loading DAGs from local files."""

def __init__(self, dag_dirs: List[str]) -> None:
"""Initialize a LocalFileDAGLoader.

Args:
dag_dirs (List[str]): The directories to load DAGs.
"""
self._dag_dirs = dag_dirs

def load_dags(self) -> List[DAG]:
"""Load dags from local files."""
dags = []
for filepath in self._dag_dirs:
if not os.path.exists(filepath):
@@ -70,7 +85,7 @@ def _load_modules_from_file(filepath: str):
sys.modules[spec.name] = new_module
loader.exec_module(new_module)
return [new_module]
except Exception as e:
except Exception:
msg = traceback.format_exc()
logger.error(f"Failed to import: {filepath}, error message: {msg}")
# TODO save error message
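A sketch of using the loader directly; the module path and the directory are assumptions for illustration. In normal use the `DAGManager` shown above drives this loader.

    from dbgpt.core.awel.dag.loader import LocalFileDAGLoader  # assumed module path

    loader = LocalFileDAGLoader(["/path/to/my/awel/dags"])  # hypothetical directory
    dags = loader.load_dags()
    for dag in dags:
        print(dag.dag_id)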
|
||||
|
@@ -0,0 +1 @@
|
||||
"""The module of operator."""
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""Base classes for operators that can be executed within a workflow."""
|
||||
import asyncio
|
||||
import functools
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from inspect import signature
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -9,7 +9,6 @@ from typing import (
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
@@ -21,7 +20,6 @@ from dbgpt.util.executor_utils import (
|
||||
AsyncToSyncIterator,
|
||||
BlockingFunction,
|
||||
DefaultExecutorFactory,
|
||||
ExecutorFactory,
|
||||
blocking_func_to_async,
|
||||
)
|
||||
|
||||
@@ -54,13 +52,15 @@ class WorkflowRunner(ABC, Generic[T]):
|
||||
node (RunnableDAGNode): The starting node of the workflow to be executed.
|
||||
call_data (CALL_DATA): The data pass to root operator node.
|
||||
streaming_call (bool): Whether the call is a streaming call.
|
||||
exist_dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
|
||||
exist_dag_ctx (DAGContext): The context of the DAG when this node is run,
|
||||
Defaults to None.
|
||||
Returns:
|
||||
DAGContext: The context after executing the workflow, containing the final state and data.
|
||||
DAGContext: The context after executing the workflow, containing the final
|
||||
state and data.
|
||||
"""
|
||||
|
||||
|
||||
default_runner: WorkflowRunner = None
|
||||
default_runner: Optional[WorkflowRunner] = None
|
||||
|
||||
|
||||
class BaseOperatorMeta(ABCMeta):
|
||||
@@ -68,8 +68,7 @@ class BaseOperatorMeta(ABCMeta):
|
||||
|
||||
@classmethod
|
||||
def _apply_defaults(cls, func: F) -> F:
|
||||
sig_cache = signature(func)
|
||||
|
||||
# sig_cache = signature(func)
|
||||
@functools.wraps(func)
|
||||
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
|
||||
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
|
||||
@@ -81,7 +80,7 @@ class BaseOperatorMeta(ABCMeta):
|
||||
if not executor:
|
||||
if system_app:
|
||||
executor = system_app.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory
|
||||
).create()
|
||||
else:
|
||||
executor = DefaultExecutorFactory().create()
|
||||
@@ -107,9 +106,10 @@ class BaseOperatorMeta(ABCMeta):
|
||||
real_obj = func(self, *args, **kwargs)
|
||||
return real_obj
|
||||
|
||||
return cast(T, apply_defaults)
|
||||
return cast(F, apply_defaults)
|
||||
|
||||
def __new__(cls, name, bases, namespace, **kwargs):
|
||||
"""Create a new BaseOperator class with default arguments."""
|
||||
new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
|
||||
new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
|
||||
return new_cls
|
||||
@@ -126,13 +126,14 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
task_id: Optional[str] = None,
|
||||
task_name: Optional[str] = None,
|
||||
dag: Optional[DAG] = None,
|
||||
runner: WorkflowRunner = None,
|
||||
runner: Optional[WorkflowRunner] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Initializes a BaseOperator with an optional workflow runner.
|
||||
"""Create a BaseOperator with an optional workflow runner.
|
||||
|
||||
Args:
|
||||
runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
|
||||
runner (WorkflowRunner, optional): The runner used to execute the workflow.
|
||||
Defaults to None.
|
||||
"""
|
||||
super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
|
||||
if not runner:
|
||||
@@ -141,19 +142,24 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
runner = DefaultWorkflowRunner()
|
||||
|
||||
self._runner: WorkflowRunner = runner
|
||||
self._dag_ctx: DAGContext = None
|
||||
self._dag_ctx: Optional[DAGContext] = None
|
||||
|
||||
@property
|
||||
def current_dag_context(self) -> DAGContext:
|
||||
"""Return the current DAG context."""
|
||||
if not self._dag_ctx:
|
||||
raise ValueError("DAGContext is not set")
|
||||
return self._dag_ctx
|
||||
|
||||
@property
|
||||
def dev_mode(self) -> bool:
|
||||
"""Whether the operator is in dev mode.
|
||||
|
||||
In production mode, the default runner is not None.
|
||||
|
||||
Returns:
|
||||
bool: Whether the operator is in dev mode. True if the default runner is None.
|
||||
bool: Whether the operator is in dev mode. True if the
|
||||
default runner is None.
|
||||
"""
|
||||
return default_runner is None
|
||||
|
||||
@@ -186,7 +192,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
|
||||
Args:
|
||||
call_data (CALL_DATA): The data pass to root operator node.
|
||||
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
|
||||
dag_ctx (DAGContext): The context of the DAG when this node is run,
|
||||
Defaults to None.
|
||||
Returns:
|
||||
OUT: The output of the node after execution.
|
||||
"""
|
||||
@@ -196,7 +203,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
return out_ctx.current_task_context.task_output.output
|
||||
|
||||
def _blocking_call(
|
||||
self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
|
||||
self,
|
||||
call_data: Optional[CALL_DATA] = None,
|
||||
loop: Optional[asyncio.BaseEventLoop] = None,
|
||||
) -> OUT:
|
||||
"""Execute the node and return the output.
|
||||
|
||||
@@ -213,6 +222,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
|
||||
if not loop:
|
||||
loop = get_or_create_event_loop()
|
||||
loop = cast(asyncio.BaseEventLoop, loop)
|
||||
return loop.run_until_complete(self.call(call_data))
|
||||
|
||||
async def call_stream(
|
||||
@@ -226,7 +236,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
|
||||
Args:
|
||||
call_data (CALL_DATA): The data pass to root operator node.
|
||||
dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
|
||||
dag_ctx (DAGContext): The context of the DAG when this node is run,
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
|
||||
@@ -237,7 +248,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
return out_ctx.current_task_context.task_output.output_stream
|
||||
|
||||
def _blocking_call_stream(
|
||||
self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
|
||||
self,
|
||||
call_data: Optional[CALL_DATA] = None,
|
||||
loop: Optional[asyncio.BaseEventLoop] = None,
|
||||
) -> Iterator[OUT]:
|
||||
"""Execute the node and return the output as a stream.
|
||||
|
||||
@@ -259,9 +272,22 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
async def blocking_func_to_async(
self, func: BlockingFunction, *args, **kwargs
) -> Any:
"""Execute a blocking function asynchronously.

In AWEL, the operators are executed asynchronously. However,
some functions are blocking, we run them in a separate thread.

Args:
func (BlockingFunction): The blocking function to be executed.
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
"""
if not self._executor:
raise ValueError("Executor is not set")
return await blocking_func_to_async(self._executor, func, *args, **kwargs)


def initialize_runner(runner: WorkflowRunner):
"""Initialize the default runner."""
global default_runner
default_runner = runner
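A sketch of how a wired-up workflow is usually executed through `BaseOperator.call`; the operator classes are placeholders, and the call-data shape (`{"data": ...}`) is an assumption about how the root input source reads it, not something this diff guarantees. In dev mode a `DefaultWorkflowRunner` is created on demand, as the constructor above shows.

    import asyncio

    from dbgpt.core.awel import DAG  # assumed package path

    with DAG("call_example") as dag:
        task = MyMapOperator()        # placeholder operators defined elsewhere
        end_task = MyOtherOperator()
        task >> end_task

    # Run the whole workflow from the end node; assumed call-data convention.
    result = asyncio.run(end_task.call(call_data={"data": 5}))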
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""Common operators of AWEL."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
@@ -13,7 +13,17 @@ from typing import (
|
||||
)
|
||||
|
||||
from ..dag.base import DAGContext
|
||||
from ..task.base import IN, OUT, InputContext, InputSource, TaskContext, TaskOutput
|
||||
from ..task.base import (
|
||||
IN,
|
||||
OUT,
|
||||
InputContext,
|
||||
InputSource,
|
||||
JoinFunc,
|
||||
MapFunc,
|
||||
ReduceFunc,
|
||||
TaskContext,
|
||||
TaskOutput,
|
||||
)
|
||||
from .base import BaseOperator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,7 +35,12 @@ class JoinOperator(BaseOperator, Generic[OUT]):
This node type is useful for combining the outputs of upstream nodes.
"""

def __init__(self, combine_function, **kwargs):
def __init__(self, combine_function: JoinFunc, **kwargs):
"""Create a JoinDAGNode with a combine function.

Args:
combine_function: A function that defines how to combine inputs.
"""
super().__init__(**kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")
|
||||
@@ -33,6 +48,7 @@ class JoinOperator(BaseOperator, Generic[OUT]):
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
"""Run the join operation on the DAG context's inputs.
|
||||
|
||||
Args:
|
||||
dag_ctx (DAGContext): The current context of the DAG.
|
||||
|
||||
@@ -50,8 +66,10 @@ class JoinOperator(BaseOperator, Generic[OUT]):
|
||||
|
||||
|
||||
class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
|
||||
def __init__(self, reduce_function=None, **kwargs):
|
||||
"""Initializes a ReduceStreamOperator with a combine function.
|
||||
"""Operator that reduces inputs using a custom reduce function."""
|
||||
|
||||
def __init__(self, reduce_function: Optional[ReduceFunc] = None, **kwargs):
|
||||
"""Create a ReduceStreamOperator with a combine function.
|
||||
|
||||
Args:
|
||||
combine_function: A function that defines how to combine inputs.
|
||||
@@ -89,6 +107,7 @@ class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
|
||||
return reduce_output
|
||||
|
||||
async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
|
||||
"""Reduce the input stream to a single value."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -99,8 +118,8 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
|
||||
passes the transformed data downstream.
|
||||
"""
|
||||
|
||||
def __init__(self, map_function=None, **kwargs):
|
||||
"""Initializes a MapDAGNode with a mapping function.
|
||||
def __init__(self, map_function: Optional[MapFunc] = None, **kwargs):
|
||||
"""Create a MapDAGNode with a mapping function.
|
||||
|
||||
Args:
|
||||
map_function: A function that defines how to map the input data.
|
||||
@@ -133,13 +152,18 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
|
||||
if not call_data and not curr_task_ctx.task_input.check_single_parent():
|
||||
num_parents = len(curr_task_ctx.task_input.parent_outputs)
|
||||
raise ValueError(
|
||||
f"task {curr_task_ctx.task_id} MapDAGNode expects single parent, now number of parents: {num_parents}"
|
||||
f"task {curr_task_ctx.task_id} MapDAGNode expects single parent,"
|
||||
f"now number of parents: {num_parents}"
|
||||
)
|
||||
map_function = self.map_function or self.map
|
||||
|
||||
if call_data:
|
||||
call_data = await curr_task_ctx._call_data_to_output()
|
||||
output = await call_data.map(map_function)
|
||||
wrapped_call_data = await curr_task_ctx._call_data_to_output()
|
||||
if not wrapped_call_data:
|
||||
raise ValueError(
|
||||
f"task {curr_task_ctx.task_id} MapDAGNode expects wrapped_call_data"
|
||||
)
|
||||
output: TaskOutput[OUT] = await wrapped_call_data.map(map_function)
|
||||
curr_task_ctx.set_task_output(output)
|
||||
return output
|
||||
|
||||
@@ -150,6 +174,7 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
return output

async def map(self, input_value: IN) -> OUT:
"""Map the input data to a new value."""
raise NotImplementedError
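A minimal subclass sketch for the `map` hook above, as an alternative to passing `map_function` to the constructor.

    class DoubleOperator(MapOperator[int, int]):
        """Illustrative operator that doubles its input."""

        async def map(self, input_value: int) -> int:
            return input_value * 2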
|
||||
|
||||
|
||||
@@ -161,6 +186,11 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
|
||||
This node filters its input data using a branching function and
|
||||
allows for conditional paths in the workflow.
|
||||
|
||||
If a branch function returns True, the corresponding task will be executed.
|
||||
otherwise, the corresponding task will be skipped, and the output of
|
||||
this skip node will be set to `SKIP_DATA`
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -168,11 +198,11 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes a BranchDAGNode with a branching function.
|
||||
"""Create a BranchDAGNode with a branching function.
|
||||
|
||||
Args:
|
||||
branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of function that defines the branching condition.
|
||||
branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]):
|
||||
Dict of function that defines the branching condition.
|
||||
|
||||
Raises:
|
||||
ValueError: If the branch_function is not callable.
|
||||
@@ -183,7 +213,9 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
if not callable(branch_function):
|
||||
raise ValueError("branch_function must be callable")
|
||||
if isinstance(value, BaseOperator):
|
||||
branches[branch_function] = value.node_name or value.node_name
|
||||
if not value.node_name:
|
||||
raise ValueError("branch node name must be set")
|
||||
branches[branch_function] = value.node_name
|
||||
self._branches = branches
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
@@ -210,7 +242,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
branches = await self.branches()
|
||||
|
||||
branch_func_tasks = []
|
||||
branch_nodes: List[str] = []
|
||||
branch_nodes: List[Union[BaseOperator, str]] = []
|
||||
for func, node_name in branches.items():
|
||||
branch_nodes.append(node_name)
|
||||
branch_func_tasks.append(
|
||||
@@ -225,20 +257,25 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
node_name = branch_nodes[i]
|
||||
branch_out = ctx.parent_outputs[0].task_output
|
||||
logger.info(
|
||||
f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
|
||||
f"branch_input_ctxs {i} result {branch_out.output}, "
|
||||
f"is_empty: {branch_out.is_empty}"
|
||||
)
|
||||
if ctx.parent_outputs[0].task_output.is_empty:
|
||||
if ctx.parent_outputs[0].task_output.is_none:
|
||||
logger.info(f"Skip node name {node_name}")
|
||||
skip_node_names.append(node_name)
|
||||
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
|
||||
return parent_output
|
||||
|
||||
async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
"""Return branch logic based on input data."""
raise NotImplementedError
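A sketch of how the branching contract above can be used: downstream tasks must have node names set (the constructor enforces this), a branch whose predicate returns True is run, and the others are skipped with `SKIP_DATA`. The handler functions and predicates are illustrative; predicates may also be async.

    with DAG("branch_example") as dag:
        positive_task = MapOperator(map_function=handle_positive, task_name="positive_task")
        negative_task = MapOperator(map_function=handle_negative, task_name="negative_task")
        branch_task = BranchOperator(
            branches={
                lambda x: x >= 0: positive_task,  # run when the predicate is True
                lambda x: x < 0: negative_task,   # otherwise skipped (SKIP_DATA)
            }
        )
        branch_task >> positive_task
        branch_task >> negative_task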
|
||||
|
||||
|
||||
class InputOperator(BaseOperator, Generic[OUT]):
|
||||
"""Operator node that reads data from an input source."""
|
||||
|
||||
def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
|
||||
"""Create an InputDAGNode with an input source."""
|
||||
super().__init__(**kwargs)
|
||||
self._input_source = input_source
|
||||
|
||||
@@ -250,7 +287,10 @@ class InputOperator(BaseOperator, Generic[OUT]):
|
||||
|
||||
|
||||
class TriggerOperator(InputOperator, Generic[OUT]):
|
||||
"""Operator node that triggers the DAG to run."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""Create a TriggerDAGNode."""
|
||||
from ..task.task_impl import SimpleCallDataInputSource
|
||||
|
||||
super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)
|
||||
|
@@ -1,3 +1,4 @@
|
||||
"""The module of stream operator."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator, Generic
|
||||
|
||||
@@ -7,12 +8,18 @@ from .base import BaseOperator
|
||||
|
||||
|
||||
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
|
||||
"""An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
call_data = curr_task_ctx.call_data
|
||||
if call_data:
|
||||
call_data = await curr_task_ctx._call_data_to_output()
|
||||
output = await call_data.streamify(self.streamify)
|
||||
wrapped_call_data = await curr_task_ctx._call_data_to_output()
|
||||
if not wrapped_call_data:
|
||||
raise ValueError(
|
||||
f"task {curr_task_ctx.task_id} MapDAGNode expects wrapped_call_data"
|
||||
)
|
||||
output = await wrapped_call_data.streamify(self.streamify)
|
||||
curr_task_ctx.set_task_output(output)
|
||||
return output
|
||||
output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
|
||||
@@ -23,26 +30,28 @@ class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
|
||||
|
||||
@abstractmethod
|
||||
async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
|
||||
"""Convert a value of IN to an AsyncIterator[OUT]
|
||||
"""Convert a value of IN to an AsyncIterator[OUT].
|
||||
|
||||
Args:
|
||||
input_value (IN): The data of parent operator's output
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
.. code-block:: python
|
||||
class MyStreamOperator(StreamifyAbsOperator[int, int]):
|
||||
async def streamify(self, input_value: int) -> AsyncIterator[int]:
|
||||
for i in range(input_value):
|
||||
yield i
|
||||
|
||||
class MyStreamOperator(StreamifyAbsOperator[int, int]):
|
||||
async def streamify(self, input_value: int) -> AsyncIterator[int]:
|
||||
for i in range(input_value):
|
||||
yield i
|
||||
"""
|
||||
|
||||
|
||||
class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
|
||||
"""An abstract operator that converts a value of AsyncIterator[IN] to an OUT."""
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
output = await curr_task_ctx.task_input.parent_outputs[
|
||||
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
|
||||
0
|
||||
].task_output.unstreamify(self.unstreamify)
|
||||
curr_task_ctx.set_task_output(output)
|
||||
@@ -56,24 +65,30 @@ class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
|
||||
input_value (AsyncIterator[IN])): The data of parent operator's output
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
|
||||
async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
|
||||
value_cnt = 0
|
||||
async for v in input_value:
|
||||
value_cnt += 1
|
||||
return value_cnt
|
||||
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
|
||||
async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
|
||||
value_cnt = 0
|
||||
async for v in input_value:
|
||||
value_cnt += 1
|
||||
return value_cnt
|
||||
"""
|
||||
|
||||
|
||||
class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
|
||||
"""Streaming to other streaming data.
|
||||
|
||||
An abstract operator that transforms a value of
|
||||
AsyncIterator[IN] to another AsyncIterator[OUT].
|
||||
"""
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
output = await curr_task_ctx.task_input.parent_outputs[
|
||||
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
|
||||
0
|
||||
].task_output.transform_stream(self.transform_stream)
|
||||
|
||||
curr_task_ctx.set_task_output(output)
|
||||
return output
|
||||
|
||||
@@ -81,19 +96,18 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
|
||||
async def transform_stream(
|
||||
self, input_value: AsyncIterator[IN]
|
||||
) -> AsyncIterator[OUT]:
|
||||
"""Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.
|
||||
"""Transform an AsyncIterator[IN] to another AsyncIterator[OUT].
|
||||
|
||||
Args:
|
||||
input_value (AsyncIterator[IN])): The data of parent operator's output
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
async def unstreamify(
self, input_value: AsyncIterator[int]
) -> AsyncIterator[int]:
async for v in input_value:
yield v + 1
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
async def unstreamify(
self, input_value: AsyncIterator[int]
) -> AsyncIterator[int]:
async for v in input_value:
yield v + 1
"""
|
||||
|
@@ -0,0 +1,4 @@
|
||||
"""The module of AWEL resource.
|
||||
|
||||
Not implemented yet.
|
||||
"""
|
||||
|
@@ -1,8 +1,15 @@
|
||||
"""Base class for resource group."""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ResourceGroup(ABC):
|
||||
"""Base class for resource group.
|
||||
|
||||
A resource group is a group of resources that are related to each other.
|
||||
It contains the all resources that are needed to run a workflow.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""The name of current resource group"""
|
||||
"""Return the name of current resource group."""
|
||||
|
@@ -0,0 +1,4 @@
|
||||
"""The module to run AWEL operators.
|
||||
|
||||
You can implement your own runner by inheriting the `WorkflowRunner` class.
|
||||
"""
|
||||
|
@@ -1,33 +1,38 @@
|
||||
"""Job manager for DAG."""
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List, Optional, cast
|
||||
|
||||
from ..dag.base import DAG, DAGLifecycle
|
||||
from ..dag.base import DAGLifecycle
|
||||
from ..operator.base import CALL_DATA, BaseOperator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DAGNodeInstance:
|
||||
def __init__(self, node_instance: DAG) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DAGInstance:
|
||||
def __init__(self, dag: DAG) -> None:
|
||||
self._dag = dag
|
||||
|
||||
|
||||
class JobManager(DAGLifecycle):
|
||||
"""Job manager for DAG.
|
||||
|
||||
This class is used to manage the DAG lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_nodes: List[BaseOperator],
|
||||
all_nodes: List[BaseOperator],
|
||||
end_node: BaseOperator,
|
||||
id2call_data: Dict[str, Dict],
|
||||
id2call_data: Dict[str, Optional[Dict]],
|
||||
node_name_to_ids: Dict[str, str],
|
||||
) -> None:
|
||||
"""Create a job manager.
|
||||
|
||||
Args:
|
||||
root_nodes (List[BaseOperator]): The root nodes of the DAG.
|
||||
all_nodes (List[BaseOperator]): All nodes of the DAG.
|
||||
end_node (BaseOperator): The end node of the DAG.
|
||||
id2call_data (Dict[str, Optional[Dict]]): The call data of each node.
|
||||
node_name_to_ids (Dict[str, str]): The node name to node id mapping.
|
||||
"""
|
||||
self._root_nodes = root_nodes
|
||||
self._all_nodes = all_nodes
|
||||
self._end_node = end_node
|
||||
@@ -38,6 +43,15 @@ class JobManager(DAGLifecycle):
|
||||
def build_from_end_node(
|
||||
end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
|
||||
) -> "JobManager":
|
||||
"""Build a job manager from the end node.
|
||||
|
||||
This will get all upstream nodes from the end node, and build a job manager.
|
||||
|
||||
Args:
|
||||
end_node (BaseOperator): The end node of the DAG.
|
||||
call_data (Optional[CALL_DATA], optional): The call data of the end node.
|
||||
Defaults to None.
|
||||
"""
|
||||
nodes = _build_from_end_node(end_node)
|
||||
root_nodes = _get_root_nodes(nodes)
|
||||
id2call_data = _save_call_data(root_nodes, call_data)
|
||||
@@ -50,17 +64,22 @@ class JobManager(DAGLifecycle):
|
||||
return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)
|
||||
|
||||
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
|
||||
"""Get the call data by node id.
|
||||
|
||||
Args:
|
||||
node_id (str): The node id.
|
||||
"""
|
||||
return self._id2node_data.get(node_id)
|
||||
|
||||
async def before_dag_run(self):
|
||||
"""The callback before DAG run"""
|
||||
"""Execute the callback before DAG run."""
|
||||
tasks = []
|
||||
for node in self._all_nodes:
|
||||
tasks.append(node.before_dag_run())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
"""Execute the callback after DAG end."""
|
||||
tasks = []
|
||||
for node in self._all_nodes:
|
||||
tasks.append(node.after_dag_end())
|
||||
@@ -68,9 +87,9 @@ class JobManager(DAGLifecycle):
|
||||
|
||||
|
||||
def _save_call_data(
|
||||
root_nodes: List[BaseOperator], call_data: CALL_DATA
|
||||
) -> Dict[str, Dict]:
|
||||
id2call_data = {}
|
||||
root_nodes: List[BaseOperator], call_data: Optional[CALL_DATA]
|
||||
) -> Dict[str, Optional[Dict]]:
|
||||
id2call_data: Dict[str, Optional[Dict]] = {}
|
||||
logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
|
||||
if not call_data:
|
||||
return id2call_data
|
||||
@@ -82,7 +101,8 @@ def _save_call_data(
|
||||
for node in root_nodes:
|
||||
node_id = node.node_id
|
||||
logger.debug(
|
||||
f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
|
||||
f"Save call data to node {node.node_id}, call_data: "
|
||||
f"{call_data.get(node_id)}"
|
||||
)
|
||||
id2call_data[node_id] = call_data.get(node_id)
|
||||
return id2call_data
|
||||
@@ -91,13 +111,11 @@ def _save_call_data(
|
||||
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
|
||||
"""Build all nodes from the end node."""
|
||||
nodes = []
|
||||
if isinstance(end_node, BaseOperator):
|
||||
task_id = end_node.node_id
|
||||
if not task_id:
|
||||
task_id = str(uuid.uuid4())
|
||||
end_node.set_node_id(task_id)
|
||||
if isinstance(end_node, BaseOperator) and not end_node._node_id:
|
||||
end_node.set_node_id(str(uuid.uuid4()))
|
||||
nodes.append(end_node)
|
||||
for node in end_node.upstream:
|
||||
node = cast(BaseOperator, node)
|
||||
nodes += _build_from_end_node(node)
|
||||
return nodes
|
||||
|
||||
|
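For readers skimming this hunk, the core of `JobManager.build_from_end_node` is a depth-first walk over `upstream` links followed by a root-node filter. The snippet below is a minimal standalone sketch of that traversal (editor-added, not part of this commit), using a toy `Node` class instead of the real `BaseOperator`; the de-duplication in `get_root_nodes` is an assumption, since `_get_root_nodes` itself is outside this hunk.

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    node_id: str
    upstream: List["Node"] = field(default_factory=list)


def build_from_end_node(end_node: Node) -> List[Node]:
    # Depth-first walk: the end node first, then everything reachable upstream.
    nodes = [end_node]
    for parent in end_node.upstream:
        nodes += build_from_end_node(parent)
    return nodes


def get_root_nodes(nodes: List[Node]) -> List[Node]:
    # De-duplicate by id (a node can be reached through several paths) and
    # keep only nodes with no upstream dependencies (assumed behavior).
    unique: Dict[str, Node] = {node.node_id: node for node in nodes}
    return [node for node in unique.values() if not node.upstream]


if __name__ == "__main__":
    a, b = Node("a"), Node("b")
    end = Node("end", upstream=[a, b])
    collected = build_from_end_node(end)
    print([n.node_id for n in collected])                  # ['end', 'a', 'b']
    print([n.node_id for n in get_root_nodes(collected)])  # ['a', 'b']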
@@ -1,12 +1,16 @@
|
||||
"""Local runner for workflow.
|
||||
|
||||
This runner will run the workflow in the current process.
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List, Optional, Set, cast
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
from ..dag.base import DAGContext, DAGVar
|
||||
from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner
|
||||
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
|
||||
from ..task.base import TaskContext, TaskState
|
||||
from ..operator.common_operator import BranchOperator, JoinOperator
|
||||
from ..task.base import SKIP_DATA, TaskContext, TaskState
|
||||
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
|
||||
from .job_manager import JobManager
|
||||
|
||||
@@ -14,6 +18,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DefaultWorkflowRunner(WorkflowRunner):
|
||||
"""The default workflow runner."""
|
||||
|
||||
async def execute_workflow(
|
||||
self,
|
||||
node: BaseOperator,
|
||||
@@ -21,6 +27,17 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
streaming_call: bool = False,
|
||||
exist_dag_ctx: Optional[DAGContext] = None,
|
||||
) -> DAGContext:
|
||||
"""Execute the workflow.
|
||||
|
||||
Args:
|
||||
node (BaseOperator): The end node of the workflow.
|
||||
call_data (Optional[CALL_DATA], optional): The call data of the end node.
|
||||
Defaults to None.
|
||||
streaming_call (bool, optional): Whether the call is streaming call.
|
||||
Defaults to False.
|
||||
exist_dag_ctx (Optional[DAGContext], optional): The exist DAG context.
|
||||
Defaults to None.
|
||||
"""
|
||||
# Save node output
|
||||
# dag = node.dag
|
||||
job_manager = JobManager.build_from_end_node(node, call_data)
|
||||
@@ -37,8 +54,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
)
|
||||
logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
|
||||
logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
|
||||
skip_node_ids = set()
|
||||
system_app: SystemApp = DAGVar.get_current_system_app()
|
||||
skip_node_ids: Set[str] = set()
|
||||
system_app: Optional[SystemApp] = DAGVar.get_current_system_app()
|
||||
|
||||
await job_manager.before_dag_run()
|
||||
await self._execute_node(
|
||||
@@ -57,7 +74,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
dag_ctx: DAGContext,
|
||||
node_outputs: Dict[str, TaskContext],
|
||||
skip_node_ids: Set[str],
|
||||
system_app: SystemApp,
|
||||
system_app: Optional[SystemApp],
|
||||
):
|
||||
# Skip run node
|
||||
if node.node_id in node_outputs:
|
||||
@@ -79,8 +96,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
node_outputs[upstream_node.node_id] for upstream_node in node.upstream
|
||||
]
|
||||
input_ctx = DefaultInputContext(inputs)
|
||||
task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
|
||||
task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
|
||||
task_ctx: DefaultTaskContext = DefaultTaskContext(
|
||||
node.node_id, TaskState.INIT, task_output=None
|
||||
)
|
||||
current_call_data = job_manager.get_call_data_by_id(node.node_id)
|
||||
if current_call_data:
|
||||
task_ctx.set_call_data(current_call_data)
|
||||
|
||||
task_ctx.set_task_input(input_ctx)
|
||||
dag_ctx.set_current_task_context(task_ctx)
|
||||
@@ -88,12 +109,13 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
|
||||
if node.node_id in skip_node_ids:
|
||||
task_ctx.set_current_state(TaskState.SKIP)
|
||||
task_ctx.set_task_output(SimpleTaskOutput(None))
|
||||
task_ctx.set_task_output(SimpleTaskOutput(SKIP_DATA))
|
||||
node_outputs[node.node_id] = task_ctx
|
||||
return
|
||||
try:
|
||||
logger.debug(
|
||||
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
|
||||
f"Begin run operator, node id: {node.node_id}, node name: "
|
||||
f"{node.node_name}, cls: {node}"
|
||||
)
|
||||
if system_app is not None and node.system_app is None:
|
||||
node.set_system_app(system_app)
|
||||
@@ -120,6 +142,7 @@ def _skip_current_downstream_by_node_name(
|
||||
if not skip_nodes:
|
||||
return
|
||||
for child in branch_node.downstream:
|
||||
child = cast(BaseOperator, child)
|
||||
if child.node_name in skip_nodes:
|
||||
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
|
||||
_skip_downstream_by_id(child, skip_node_ids)
|
||||
@@ -131,4 +154,5 @@ def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
|
||||
return
|
||||
skip_node_ids.add(node.node_id)
|
||||
for child in node.downstream:
|
||||
child = cast(BaseOperator, child)
|
||||
_skip_downstream_by_id(child, skip_node_ids)
|
||||
|
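The two `_skip_*` helpers above implement branch skipping: once a branch decides that a named child should not run, that child's entire downstream subtree is marked as skipped. Below is a standalone sketch with toy types (editor-added, not the AWEL operators); the guard before the early `return` in `_skip_downstream_by_id` is outside this hunk, so the toy version simply avoids revisiting nodes already in the skip set, which is an assumption.

from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class Node:
    node_id: str
    downstream: List["Node"] = field(default_factory=list)


def skip_downstream_by_id(node: Node, skip_node_ids: Set[str]) -> None:
    # Recursively mark this node and everything downstream of it as skipped.
    if node.node_id in skip_node_ids:
        return
    skip_node_ids.add(node.node_id)
    for child in node.downstream:
        skip_downstream_by_id(child, skip_node_ids)


if __name__ == "__main__":
    c = Node("c")
    b = Node("b", downstream=[c])
    skipped: Set[str] = set()
    skip_downstream_by_id(b, skipped)
    print(sorted(skipped))  # ['b', 'c']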
@@ -0,0 +1 @@
"""The module of Task."""

@@ -1,8 +1,10 @@
"""Base classes for task-related objects."""
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Generic,
@@ -17,6 +19,24 @@ OUT = TypeVar("OUT")
T = TypeVar("T")


class _EMPTY_DATA_TYPE:
    def __bool__(self):
        return False


EMPTY_DATA = _EMPTY_DATA_TYPE()
SKIP_DATA = _EMPTY_DATA_TYPE()
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()

MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
PredicateFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
JoinFunc = Union[Callable[..., OUT], Callable[..., Awaitable[OUT]]]
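The `_EMPTY_DATA_TYPE` sentinel above exists so that "no data yet" can be told apart from perfectly valid falsy outputs such as None, 0 or "". A minimal standalone sketch of the idea (editor-added; toy names, not the AWEL types themselves):

class _EmptyDataType:
    def __bool__(self) -> bool:
        return False

    def __repr__(self) -> str:
        return "EMPTY_DATA"


EMPTY_DATA = _EmptyDataType()


def read_output(data):
    # Comparing against the sentinel is unambiguous, while `not data` would
    # also be true for None, 0, "" and empty collections.
    if data is EMPTY_DATA:
        raise ValueError("No output data for current task output")
    return data


print(read_output(0))    # 0 is a valid output
print(read_output(""))   # so is an empty string
# read_output(EMPTY_DATA)  # would raise ValueError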
class TaskState(str, Enum):
|
||||
"""Enumeration representing the state of a task in the workflow.
|
||||
|
||||
@@ -33,8 +53,8 @@ class TaskState(str, Enum):
|
||||
class TaskOutput(ABC, Generic[T]):
|
||||
"""Abstract base class representing the output of a task.
|
||||
|
||||
This class encapsulates the output of a task and provides methods to access the output data.
|
||||
It can be subclassed to implement specific output behaviors.
|
||||
This class encapsulates the output of a task and provides methods to access the
|
||||
output data.It can be subclassed to implement specific output behaviors.
|
||||
"""
|
||||
|
||||
@property
|
||||
@@ -56,20 +76,30 @@ class TaskOutput(ABC, Generic[T]):
|
||||
return False
|
||||
|
||||
@property
|
||||
def output(self) -> Optional[T]:
|
||||
def is_none(self) -> bool:
|
||||
"""Check if the output is None.
|
||||
|
||||
Returns:
|
||||
bool: True if the output is None, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def output(self) -> T:
|
||||
"""Return the output of the task.
|
||||
|
||||
Returns:
|
||||
T: The output of the task. None if the output is empty.
|
||||
T: The output of the task.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_stream(self) -> Optional[AsyncIterator[T]]:
|
||||
def output_stream(self) -> AsyncIterator[T]:
|
||||
"""Return the output of the task as an asynchronous stream.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
|
||||
AsyncIterator[T]: An asynchronous iterator over the output. None if the
|
||||
output is empty.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -83,39 +113,38 @@ class TaskOutput(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
def new_output(self) -> "TaskOutput[T]":
|
||||
"""Create new output object"""
|
||||
"""Create new output object."""
|
||||
|
||||
async def map(self, map_func) -> "TaskOutput[T]":
|
||||
async def map(self, map_func: MapFunc) -> "TaskOutput[OUT]":
|
||||
"""Apply a mapping function to the task's output.
|
||||
|
||||
Args:
|
||||
map_func: A function to apply to the task's output.
|
||||
map_func (MapFunc): A function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the mapping function.
|
||||
TaskOutput[OUT]: The result of applying the mapping function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reduce(self, reduce_func) -> "TaskOutput[T]":
|
||||
async def reduce(self, reduce_func: ReduceFunc) -> "TaskOutput[OUT]":
|
||||
"""Apply a reducing function to the task's output.
|
||||
|
||||
Stream TaskOutput to Nonstream TaskOutput.
|
||||
Stream TaskOutput to no stream TaskOutput.
|
||||
|
||||
Args:
|
||||
reduce_func: A reducing function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
TaskOutput[OUT]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def streamify(
|
||||
self, transform_func: Callable[[T], AsyncIterator[T]]
|
||||
) -> "TaskOutput[T]":
|
||||
async def streamify(self, transform_func: StreamFunc) -> "TaskOutput[T]":
|
||||
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
|
||||
transform_func (StreamFunc): Function to transform a T value into an
|
||||
AsyncIterator[OUT].
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
@@ -123,38 +152,39 @@ class TaskOutput(ABC, Generic[T]):
|
||||
raise NotImplementedError
|
||||
|
||||
async def transform_stream(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
|
||||
self, transform_func: TransformFunc
|
||||
) -> "TaskOutput[OUT]":
|
||||
"""Transform an AsyncIterator[T] to another AsyncIterator[T].
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
|
||||
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
|
||||
apply to the AsyncIterator[T].
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def unstreamify(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], T]
|
||||
) -> "TaskOutput[T]":
|
||||
async def unstreamify(self, transform_func: UnStreamFunc) -> "TaskOutput[OUT]":
|
||||
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
|
||||
transform_func (UnStreamFunc): Function to transform an AsyncIterator[T]
|
||||
into a T value.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def check_condition(self, condition_func) -> bool:
|
||||
async def check_condition(self, condition_func) -> "TaskOutput[OUT]":
|
||||
"""Check if current output meets a given condition.
|
||||
|
||||
Args:
|
||||
condition_func: A function to determine if the condition is met.
|
||||
Returns:
|
||||
bool: True if current output meet the condition, False otherwise.
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
If the condition is not met, return empty output.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -182,6 +212,9 @@ class TaskContext(ABC, Generic[T]):
|
||||
|
||||
Returns:
|
||||
InputContext: The InputContext of current task.
|
||||
|
||||
Raises:
|
||||
Exception: If the InputContext is not set.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -216,7 +249,7 @@ class TaskContext(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
def set_current_state(self, task_state: TaskState) -> None:
|
||||
"""Set current task state
|
||||
"""Set current task state.
|
||||
|
||||
Args:
|
||||
task_state (TaskState): The task state to be set.
|
||||
@@ -224,7 +257,7 @@ class TaskContext(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
def new_ctx(self) -> "TaskContext":
|
||||
"""Create new task context
|
||||
"""Create new task context.
|
||||
|
||||
Returns:
|
||||
TaskContext: A new instance of a TaskContext.
|
||||
@@ -233,14 +266,14 @@ class TaskContext(ABC, Generic[T]):
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
"""Get the metadata of current task
|
||||
"""Return the metadata of current task.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The metadata
|
||||
"""
|
||||
|
||||
def update_metadata(self, key: str, value: Any) -> None:
|
||||
"""Update metadata with key and value
|
||||
"""Update metadata with key and value.
|
||||
|
||||
Args:
|
||||
key (str): The key of metadata
|
||||
@@ -250,15 +283,15 @@ class TaskContext(ABC, Generic[T]):
|
||||
|
||||
@property
|
||||
def call_data(self) -> Optional[Dict]:
|
||||
"""Get the call data for current data"""
|
||||
"""Return the call data for current data."""
|
||||
return self.metadata.get("call_data")
|
||||
|
||||
@abstractmethod
|
||||
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
|
||||
"""Get the call data for current data"""
|
||||
"""Get the call data for current data."""
|
||||
|
||||
def set_call_data(self, call_data: Dict) -> None:
|
||||
"""Set call data for current task"""
|
||||
"""Save the call data for current task."""
|
||||
self.update_metadata("call_data", call_data)
|
||||
|
||||
|
||||
@@ -315,7 +348,8 @@ class InputContext(ABC):
|
||||
"""Filter the inputs based on a provided function.
|
||||
|
||||
Args:
|
||||
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
|
||||
filter_func (Callable[[Any], bool]): A function that returns True for
|
||||
inputs to keep.
|
||||
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the filtered inputs.
|
||||
@@ -323,13 +357,15 @@ class InputContext(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def predicate_map(
|
||||
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
|
||||
self, predicate_func: PredicateFunc, failed_value: Any = None
|
||||
) -> "InputContext":
|
||||
"""Predicate the inputs based on a provided function.
|
||||
|
||||
Args:
|
||||
predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True.
|
||||
failed_value (Any): The value to be set if the return value of predicate function is False
|
||||
predicate_func (Callable[[Any], bool]): A function that returns True for
|
||||
inputs is predicate True.
|
||||
failed_value (Any): The value to be set if the return value of predicate
|
||||
function is False
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the predicate inputs.
|
||||
"""
|
||||
|
@@ -1,3 +1,7 @@
|
||||
"""The default implementation of Task.
|
||||
|
||||
This implementation can run workflow in local machine.
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -8,15 +12,32 @@ from typing import (
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState
|
||||
from .base import (
|
||||
_EMPTY_DATA_TYPE,
|
||||
EMPTY_DATA,
|
||||
OUT,
|
||||
PLACEHOLDER_DATA,
|
||||
SKIP_DATA,
|
||||
InputContext,
|
||||
InputSource,
|
||||
MapFunc,
|
||||
PredicateFunc,
|
||||
ReduceFunc,
|
||||
StreamFunc,
|
||||
T,
|
||||
TaskContext,
|
||||
TaskOutput,
|
||||
TaskState,
|
||||
TransformFunc,
|
||||
UnStreamFunc,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,101 +58,197 @@ async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
|
||||
|
||||
|
||||
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
|
||||
def __init__(self, data: T) -> None:
|
||||
"""The default implementation of TaskOutput.
|
||||
|
||||
It wraps the no stream data and provide some basic data operations.
|
||||
"""
|
||||
|
||||
def __init__(self, data: Union[T, _EMPTY_DATA_TYPE] = EMPTY_DATA) -> None:
|
||||
"""Create a SimpleTaskOutput.
|
||||
|
||||
Args:
|
||||
data (Union[T, _EMPTY_DATA_TYPE], optional): The output data. Defaults to
|
||||
EMPTY_DATA.
|
||||
"""
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def output(self) -> T:
|
||||
return self._data
|
||||
"""Return the output data."""
|
||||
if self._data == EMPTY_DATA:
|
||||
raise ValueError("No output data for current task output")
|
||||
return cast(T, self._data)
|
||||
|
||||
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
|
||||
self._data = output_data
|
||||
"""Save the output data to current object.
|
||||
|
||||
Args:
|
||||
output_data (T | AsyncIterator[T]): The output data.
|
||||
"""
|
||||
if _is_async_iterator(output_data):
|
||||
raise ValueError(
|
||||
f"Can not set stream data {output_data} to SimpleTaskOutput"
|
||||
)
|
||||
self._data = cast(T, output_data)
|
||||
|
||||
def new_output(self) -> TaskOutput[T]:
|
||||
return SimpleTaskOutput(None)
|
||||
"""Create new output object with empty data."""
|
||||
return SimpleTaskOutput()
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Return True if the output data is empty."""
|
||||
return self._data == EMPTY_DATA or self._data == SKIP_DATA
|
||||
|
||||
@property
|
||||
def is_none(self) -> bool:
|
||||
"""Return True if the output data is None."""
|
||||
return self._data is None
|
||||
|
||||
async def _apply_func(self, func) -> Any:
|
||||
"""Apply the function to current output data."""
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
out = await func(self._data)
|
||||
else:
|
||||
out = func(self._data)
|
||||
return out
|
||||
|
||||
async def map(self, map_func) -> TaskOutput[T]:
|
||||
async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
|
||||
"""Apply a mapping function to the task's output.
|
||||
|
||||
Args:
|
||||
map_func (MapFunc): A function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The result of applying the mapping function.
|
||||
"""
|
||||
out = await self._apply_func(map_func)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def check_condition(self, condition_func) -> bool:
|
||||
return await self._apply_func(condition_func)
|
||||
async def check_condition(self, condition_func) -> TaskOutput[OUT]:
|
||||
"""Check the condition function."""
|
||||
out = await self._apply_func(condition_func)
|
||||
if out:
|
||||
return SimpleTaskOutput(PLACEHOLDER_DATA)
|
||||
return SimpleTaskOutput(EMPTY_DATA)
|
||||
|
||||
async def streamify(
|
||||
self, transform_func: Callable[[T], AsyncIterator[T]]
|
||||
) -> TaskOutput[T]:
|
||||
async def streamify(self, transform_func: StreamFunc) -> TaskOutput[OUT]:
|
||||
"""Transform the task's output to a stream output.
|
||||
|
||||
Args:
|
||||
transform_func (StreamFunc): A function to transform the task's output to a
|
||||
stream output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The result of transforming the task's output to a stream
|
||||
output.
|
||||
"""
|
||||
out = await self._apply_func(transform_func)
|
||||
return SimpleStreamTaskOutput(out)
|
||||
|
||||
|
||||
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
|
||||
def __init__(self, data: AsyncIterator[T]) -> None:
|
||||
"""The default stream implementation of TaskOutput."""
|
||||
|
||||
def __init__(
|
||||
self, data: Union[AsyncIterator[T], _EMPTY_DATA_TYPE] = EMPTY_DATA
|
||||
) -> None:
|
||||
"""Create a SimpleStreamTaskOutput.
|
||||
|
||||
Args:
|
||||
data (Union[AsyncIterator[T], _EMPTY_DATA_TYPE], optional): The output data.
|
||||
Defaults to EMPTY_DATA.
|
||||
"""
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def is_stream(self) -> bool:
|
||||
"""Return True if the output data is a stream."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return not self._data
|
||||
"""Return True if the output data is empty."""
|
||||
return self._data == EMPTY_DATA or self._data == SKIP_DATA
|
||||
|
||||
@property
|
||||
def is_none(self) -> bool:
|
||||
"""Return True if the output data is None."""
|
||||
return self._data is None
|
||||
|
||||
@property
|
||||
def output_stream(self) -> AsyncIterator[T]:
|
||||
return self._data
|
||||
"""Return the output data.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[T]: The output data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the output data is empty.
|
||||
"""
|
||||
if self._data == EMPTY_DATA:
|
||||
raise ValueError("No output data for current task output")
|
||||
return cast(AsyncIterator[T], self._data)
|
||||
|
||||
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
|
||||
self._data = output_data
|
||||
"""Save the output data to current object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the output data is not a stream.
|
||||
"""
|
||||
if not _is_async_iterator(output_data):
|
||||
raise ValueError(
|
||||
f"Can not set non-stream data {output_data} to SimpleStreamTaskOutput"
|
||||
)
|
||||
self._data = cast(AsyncIterator[T], output_data)
|
||||
|
||||
def new_output(self) -> TaskOutput[T]:
|
||||
return SimpleStreamTaskOutput(None)
|
||||
"""Create new output object with empty data."""
|
||||
return SimpleStreamTaskOutput()
|
||||
|
||||
async def map(self, map_func) -> TaskOutput[T]:
|
||||
async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
|
||||
"""Apply a mapping function to the task's output."""
|
||||
is_async = asyncio.iscoroutinefunction(map_func)
|
||||
|
||||
async def new_iter() -> AsyncIterator[T]:
|
||||
async for out in self._data:
|
||||
async def new_iter() -> AsyncIterator[OUT]:
|
||||
async for out in self.output_stream:
|
||||
if is_async:
|
||||
out = await map_func(out)
|
||||
new_out: OUT = await map_func(out)
|
||||
else:
|
||||
out = map_func(out)
|
||||
yield out
|
||||
new_out = cast(OUT, map_func(out))
|
||||
yield new_out
|
||||
|
||||
return SimpleStreamTaskOutput(new_iter())
|
||||
|
||||
async def reduce(self, reduce_func) -> TaskOutput[T]:
|
||||
out = await _reduce_stream(self._data, reduce_func)
|
||||
async def reduce(self, reduce_func: ReduceFunc) -> TaskOutput[OUT]:
|
||||
"""Apply a reduce function to the task's output."""
|
||||
out = await _reduce_stream(self.output_stream, reduce_func)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def unstreamify(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], T]
|
||||
) -> TaskOutput[T]:
|
||||
async def unstreamify(self, transform_func: UnStreamFunc) -> TaskOutput[OUT]:
|
||||
"""Transform the task's output to a non-stream output."""
|
||||
if asyncio.iscoroutinefunction(transform_func):
|
||||
out = await transform_func(self._data)
|
||||
out = await transform_func(self.output_stream)
|
||||
else:
|
||||
out = transform_func(self._data)
|
||||
out = transform_func(self.output_stream)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def transform_stream(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
|
||||
) -> TaskOutput[T]:
|
||||
async def transform_stream(self, transform_func: TransformFunc) -> TaskOutput[OUT]:
|
||||
"""Transform an AsyncIterator[T] to another AsyncIterator[T].
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
|
||||
apply to the AsyncIterator[T].
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
if asyncio.iscoroutinefunction(transform_func):
|
||||
out = await transform_func(self._data)
|
||||
out: AsyncIterator[OUT] = await transform_func(self.output_stream)
|
||||
else:
|
||||
out = transform_func(self._data)
|
||||
out = cast(AsyncIterator[OUT], transform_func(self.output_stream))
|
||||
return SimpleStreamTaskOutput(out)
|
||||
|
||||
|
||||
@@ -145,20 +262,34 @@ def _is_async_iterator(obj):
|
||||
|
||||
|
||||
class BaseInputSource(InputSource, ABC):
|
||||
"""The base class of InputSource."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Create a BaseInputSource."""
|
||||
super().__init__()
|
||||
self._is_read = False
|
||||
|
||||
@abstractmethod
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
"""Read data with task context"""
|
||||
"""Return data with task context."""
|
||||
|
||||
async def read(self, task_ctx: TaskContext) -> TaskOutput:
|
||||
"""Read data with task context.
|
||||
|
||||
Args:
|
||||
task_ctx (TaskContext): The task context.
|
||||
|
||||
Returns:
|
||||
TaskOutput: The task output.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input source is a stream and has been read.
|
||||
"""
|
||||
data = self._read_data(task_ctx)
|
||||
if _is_async_iterator(data):
|
||||
if self._is_read:
|
||||
raise ValueError(f"Input iterator {data} has been read!")
|
||||
output = SimpleStreamTaskOutput(data)
|
||||
output: TaskOutput = SimpleStreamTaskOutput(data)
|
||||
else:
|
||||
output = SimpleTaskOutput(data)
|
||||
self._is_read = True
|
||||
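`read` above branches on `_is_async_iterator(data)`, whose body is outside this hunk. Below is a hedged standalone sketch (editor-added) of the usual duck-typing check such a helper performs; the details are assumptions, not the repository's implementation.

def is_async_iterator(obj) -> bool:
    # Treat an object as an async iterator when it exposes both protocol hooks.
    return hasattr(obj, "__aiter__") and hasattr(obj, "__anext__")


async def numbers():
    yield 1
    yield 2


print(is_async_iterator(numbers()))  # True
print(is_async_iterator([1, 2]))     # False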
@@ -166,7 +297,14 @@ class BaseInputSource(InputSource, ABC):
|
||||
|
||||
|
||||
class SimpleInputSource(BaseInputSource):
|
||||
"""The default implementation of InputSource."""
|
||||
|
||||
def __init__(self, data: Any) -> None:
|
||||
"""Create a SimpleInputSource.
|
||||
|
||||
Args:
|
||||
data (Any): The input data.
|
||||
"""
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@@ -175,63 +313,121 @@ class SimpleInputSource(BaseInputSource):
|
||||
|
||||
|
||||
class SimpleCallDataInputSource(BaseInputSource):
|
||||
"""The implementation of InputSource for call data."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Create a SimpleCallDataInputSource."""
|
||||
super().__init__()
|
||||
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
"""Read data from task context.
|
||||
|
||||
Returns:
|
||||
Any: The data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the call data is empty.
|
||||
"""
|
||||
call_data = task_ctx.call_data
|
||||
data = call_data.get("data") if call_data else None
|
||||
if not (call_data and data):
|
||||
data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
|
||||
if data == EMPTY_DATA:
|
||||
raise ValueError("No call data for current SimpleCallDataInputSource")
|
||||
return data
|
||||
|
||||
|
||||
class DefaultTaskContext(TaskContext, Generic[T]):
|
||||
"""The default implementation of TaskContext."""
|
||||
|
||||
def __init__(
|
||||
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
|
||||
self,
|
||||
task_id: str,
|
||||
task_state: TaskState,
|
||||
task_output: Optional[TaskOutput[T]] = None,
|
||||
) -> None:
|
||||
"""Create a DefaultTaskContext.
|
||||
|
||||
Args:
|
||||
task_id (str): The task id.
|
||||
task_state (TaskState): The task state.
|
||||
task_output (Optional[TaskOutput[T]], optional): The task output. Defaults
|
||||
to None.
|
||||
"""
|
||||
super().__init__()
|
||||
self._task_id = task_id
|
||||
self._task_state = task_state
|
||||
self._output = task_output
|
||||
self._task_input = None
|
||||
self._metadata = {}
|
||||
self._output: Optional[TaskOutput[T]] = task_output
|
||||
self._task_input: Optional[InputContext] = None
|
||||
self._metadata: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def task_id(self) -> str:
|
||||
"""Return the task id."""
|
||||
return self._task_id
|
||||
|
||||
@property
|
||||
def task_input(self) -> InputContext:
|
||||
"""Return the task input."""
|
||||
if not self._task_input:
|
||||
raise ValueError("No input for current task context")
|
||||
return self._task_input
|
||||
|
||||
def set_task_input(self, input_ctx: "InputContext") -> None:
|
||||
def set_task_input(self, input_ctx: InputContext) -> None:
|
||||
"""Save the task input to current task."""
|
||||
self._task_input = input_ctx
|
||||
|
||||
@property
|
||||
def task_output(self) -> TaskOutput:
|
||||
"""Return the task output.
|
||||
|
||||
Returns:
|
||||
TaskOutput: The task output.
|
||||
|
||||
Raises:
|
||||
ValueError: If the task output is empty.
|
||||
"""
|
||||
if not self._output:
|
||||
raise ValueError("No output for current task context")
|
||||
return self._output
|
||||
|
||||
def set_task_output(self, task_output: TaskOutput) -> None:
|
||||
"""Save the task output to current task.
|
||||
|
||||
Args:
|
||||
task_output (TaskOutput): The task output.
|
||||
"""
|
||||
self._output = task_output
|
||||
|
||||
@property
|
||||
def current_state(self) -> TaskState:
|
||||
"""Return the current task state."""
|
||||
return self._task_state
|
||||
|
||||
def set_current_state(self, task_state: TaskState) -> None:
|
||||
"""Save the current task state to current task."""
|
||||
self._task_state = task_state
|
||||
|
||||
def new_ctx(self) -> TaskContext:
|
||||
"""Create new task context with empty output."""
|
||||
if not self._output:
|
||||
raise ValueError("No output for current task context")
|
||||
new_output = self._output.new_output()
|
||||
return DefaultTaskContext(self._task_id, self._task_state, new_output)
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
"""Return the metadata of current task.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The metadata.
|
||||
"""
|
||||
return self._metadata
|
||||
|
||||
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
|
||||
"""Get the call data for current data"""
|
||||
"""Return the call data of current task.
|
||||
|
||||
Returns:
|
||||
Optional[TaskOutput[T]]: The call data.
|
||||
"""
|
||||
call_data = self.call_data
|
||||
if not call_data:
|
||||
return None
|
||||
@@ -240,24 +436,48 @@ class DefaultTaskContext(TaskContext, Generic[T]):
|
||||
|
||||
|
||||
class DefaultInputContext(InputContext):
|
||||
"""The default implementation of InputContext.
|
||||
|
||||
It wraps the all inputs from parent tasks and provide some basic data operations.
|
||||
"""
|
||||
|
||||
def __init__(self, outputs: List[TaskContext]) -> None:
|
||||
"""Create a DefaultInputContext.
|
||||
|
||||
Args:
|
||||
outputs (List[TaskContext]): The outputs from parent tasks.
|
||||
"""
|
||||
super().__init__()
|
||||
self._outputs = outputs
|
||||
|
||||
@property
|
||||
def parent_outputs(self) -> List[TaskContext]:
|
||||
"""Return the outputs from parent tasks.
|
||||
|
||||
Returns:
|
||||
List[TaskContext]: The outputs from parent tasks.
|
||||
"""
|
||||
return self._outputs
|
||||
|
||||
async def _apply_func(
|
||||
self, func: Callable[[Any], Any], apply_type: str = "map"
|
||||
) -> Tuple[List[TaskContext], List[TaskOutput]]:
|
||||
"""Apply the function to all parent outputs.
|
||||
|
||||
Args:
|
||||
func (Callable[[Any], Any]): The function to apply.
|
||||
apply_type (str, optional): The apply type. Defaults to "map".
|
||||
|
||||
Returns:
|
||||
Tuple[List[TaskContext], List[TaskOutput]]: The new parent outputs and the
|
||||
results of applying the function.
|
||||
"""
|
||||
new_outputs: List[TaskContext] = []
|
||||
map_tasks = []
|
||||
for out in self._outputs:
|
||||
new_outputs.append(out.new_ctx())
|
||||
result = None
|
||||
if apply_type == "map":
|
||||
result = out.task_output.map(func)
|
||||
result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func)
|
||||
elif apply_type == "reduce":
|
||||
result = out.task_output.reduce(func)
|
||||
elif apply_type == "check_condition":
|
||||
@@ -269,29 +489,40 @@ class DefaultInputContext(InputContext):
|
||||
return new_outputs, results
|
||||
|
||||
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
|
||||
"""Apply a mapping function to all parent outputs."""
|
||||
new_outputs, results = await self._apply_func(map_func)
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
task_ctx = cast(TaskContext, task_ctx)
|
||||
task_ctx.set_task_output(results[i])
|
||||
return DefaultInputContext(new_outputs)
|
||||
|
||||
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
|
||||
"""Apply a mapping function to all parent outputs.
|
||||
|
||||
The parent outputs will be unpacked and passed to the mapping function.
|
||||
|
||||
Args:
|
||||
map_func (Callable[..., Any]): The mapping function.
|
||||
|
||||
Returns:
|
||||
InputContext: The new input context.
|
||||
"""
|
||||
if not self._outputs:
|
||||
return DefaultInputContext([])
|
||||
# Some parent may be empty
|
||||
not_empty_idx = 0
|
||||
for i, p in enumerate(self._outputs):
|
||||
if p.task_output.is_empty:
|
||||
# Skip empty parent
|
||||
continue
|
||||
not_empty_idx = i
|
||||
break
|
||||
# All output is empty?
|
||||
is_steam = self._outputs[not_empty_idx].task_output.is_stream
|
||||
if is_steam:
|
||||
if not self.check_stream(skip_empty=True):
|
||||
raise ValueError(
|
||||
"The output in all tasks must has same output format to map_all"
|
||||
)
|
||||
if is_steam and not self.check_stream(skip_empty=True):
|
||||
raise ValueError(
|
||||
"The output in all tasks must has same output format to map_all"
|
||||
)
|
||||
outputs = []
|
||||
for out in self._outputs:
|
||||
if out.task_output.is_stream:
|
||||
@@ -305,22 +536,26 @@ class DefaultInputContext(InputContext):
|
||||
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
|
||||
single_output.task_output.set_output(map_res)
|
||||
logger.debug(
|
||||
f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
|
||||
f"Current map_all map_res: {map_res}, is steam: "
|
||||
f"{single_output.task_output.is_stream}"
|
||||
)
|
||||
return DefaultInputContext([single_output])
|
||||
|
||||
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
|
||||
"""Apply a reduce function to all parent outputs."""
|
||||
if not self.check_stream():
|
||||
raise ValueError(
|
||||
"The output in all tasks must has same output format of stream to apply reduce function"
|
||||
"The output in all tasks must has same output format of stream to apply"
|
||||
" reduce function"
|
||||
)
|
||||
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
task_ctx = cast(TaskContext, task_ctx)
|
||||
task_ctx.set_task_output(results[i])
|
||||
return DefaultInputContext(new_outputs)
|
||||
|
||||
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
|
||||
"""Filter all parent outputs."""
|
||||
new_outputs, results = await self._apply_func(
|
||||
filter_func, apply_type="check_condition"
|
||||
)
|
||||
@@ -331,15 +566,16 @@ class DefaultInputContext(InputContext):
|
||||
return DefaultInputContext(result_outputs)
|
||||
|
||||
async def predicate_map(
|
||||
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
|
||||
self, predicate_func: PredicateFunc, failed_value: Any = None
|
||||
) -> "InputContext":
|
||||
"""Apply a predicate function to all parent outputs."""
|
||||
new_outputs, results = await self._apply_func(
|
||||
predicate_func, apply_type="check_condition"
|
||||
)
|
||||
result_outputs = []
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
if results[i]:
|
||||
task_ctx = cast(TaskContext, task_ctx)
|
||||
if not results[i].is_empty:
|
||||
task_ctx.task_output.set_output(True)
|
||||
result_outputs.append(task_ctx)
|
||||
else:
|
||||
|
@@ -66,10 +66,10 @@ async def _create_input_node(**kwargs):
    else:
        outputs = kwargs.get("outputs", ["Hello."])
        nodes = []
        for output in outputs:
        for i, output in enumerate(outputs):
            print(f"output: {output}")
            input_source = SimpleInputSource(output)
            input_node = InputOperator(input_source)
            input_node = InputOperator(input_source, task_id="input_node_" + str(i))
            nodes.append(input_node)
        yield nodes

@@ -26,7 +26,7 @@ from .conftest import (

@pytest.mark.asyncio
async def test_input_node(runner: WorkflowRunner):
    input_node = InputOperator(SimpleInputSource("hello"))
    input_node = InputOperator(SimpleInputSource("hello"), task_id="112232")
    res: DAGContext[str] = await runner.execute_workflow(input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    assert res.current_task_context.task_output.output == "hello"
@@ -36,7 +36,9 @@ async def test_input_node(runner: WorkflowRunner):
        yield i

    num_iter = 10
    steam_input_node = InputOperator(SimpleInputSource(new_steam_iter(num_iter)))
    steam_input_node = InputOperator(
        SimpleInputSource(new_steam_iter(num_iter)), task_id="112232"
    )
    res: DAGContext[str] = await runner.execute_workflow(steam_input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    output_steam = res.current_task_context.task_output.output_stream
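Outside pytest, the same flow can be driven directly. The sketch below mirrors `test_input_node` above (editor-added, not part of this commit); the `dbgpt.core.awel` import paths and the location of `DefaultWorkflowRunner` are inferred from the modules touched by this commit and should be treated as assumptions.

import asyncio

from dbgpt.core.awel import InputOperator, SimpleInputSource
from dbgpt.core.awel.runner.local_runner import DefaultWorkflowRunner


async def main():
    # A single-node workflow: the input operator just emits "hello".
    input_node = InputOperator(SimpleInputSource("hello"), task_id="hello_input")
    runner = DefaultWorkflowRunner()
    dag_ctx = await runner.execute_workflow(input_node)
    print(dag_ctx.current_task_context.task_output.output)  # hello


if __name__ == "__main__":
    asyncio.run(main())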
@@ -0,0 +1 @@
"""The trigger module of AWEL."""

@@ -1,3 +1,4 @@
"""Base class for all trigger classes."""
from __future__ import annotations

from abc import ABC, abstractmethod
@@ -6,6 +7,11 @@ from ..operator.common_operator import TriggerOperator


class Trigger(TriggerOperator, ABC):
    """Base class for all trigger classes.

    Now only support http trigger.
    """

    @abstractmethod
    async def trigger(self) -> None:
        """Trigger the workflow or a specific operation in the workflow."""
@@ -1,10 +1,11 @@
|
||||
"""Http trigger for AWEL."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
@@ -13,29 +14,35 @@ from ..operator.base import BaseOperator
|
||||
from .base import Trigger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi import APIRouter
|
||||
|
||||
RequestBody = Union[Type[Request], Type[BaseModel], str]
|
||||
RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
|
||||
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HttpTrigger(Trigger):
|
||||
"""Http trigger for AWEL.
|
||||
|
||||
Http trigger is used to trigger a DAG by http request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
methods: Optional[Union[str, List[str]]] = "GET",
|
||||
request_body: Optional[RequestBody] = None,
|
||||
streaming_response: Optional[bool] = False,
|
||||
streaming_response: bool = False,
|
||||
streaming_predict_func: Optional[StreamingPredictFunc] = None,
|
||||
response_model: Optional[Type] = None,
|
||||
response_headers: Optional[Dict[str, str]] = None,
|
||||
response_media_type: Optional[str] = None,
|
||||
status_code: Optional[int] = 200,
|
||||
router_tags: Optional[List[str]] = None,
|
||||
router_tags: Optional[List[str | Enum]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Initialize a HttpTrigger."""
|
||||
super().__init__(**kwargs)
|
||||
if not endpoint.startswith("/"):
|
||||
endpoint = "/" + endpoint
|
||||
@@ -49,15 +56,21 @@ class HttpTrigger(Trigger):
|
||||
self._router_tags = router_tags
|
||||
self._response_headers = response_headers
|
||||
self._response_media_type = response_media_type
|
||||
self._end_node: BaseOperator = None
|
||||
self._end_node: Optional[BaseOperator] = None
|
||||
|
||||
async def trigger(self) -> None:
|
||||
"""Trigger the DAG. Not used in HttpTrigger."""
|
||||
pass
|
||||
|
||||
def mount_to_router(self, router: "APIRouter") -> None:
|
||||
"""Mount the trigger to a router.
|
||||
|
||||
Args:
|
||||
router (APIRouter): The router to mount the trigger.
|
||||
"""
|
||||
from fastapi import Depends
|
||||
|
||||
methods = self._methods if isinstance(self._methods, list) else [self._methods]
|
||||
methods = [self._methods] if isinstance(self._methods, str) else self._methods
|
||||
|
||||
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
|
||||
async def _request_body_dependency(request: Request):
|
||||
@@ -87,7 +100,8 @@ class HttpTrigger(Trigger):
|
||||
)
|
||||
dynamic_route_function = create_route_function(function_name, request_model)
|
||||
logger.info(
|
||||
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
|
||||
f"mount router function {dynamic_route_function}({function_name}), "
|
||||
f"endpoint: {self._endpoint}, methods: {methods}"
|
||||
)
|
||||
|
||||
router.api_route(
|
||||
@@ -100,17 +114,27 @@ class HttpTrigger(Trigger):
|
||||
|
||||
|
||||
async def _parse_request_body(
|
||||
request: Request, request_body_cls: Optional[Type[BaseModel]]
|
||||
request: Request, request_body_cls: Optional[RequestBody]
|
||||
):
|
||||
if not request_body_cls:
|
||||
return None
|
||||
if request.method == "POST":
|
||||
json_data = await request.json()
|
||||
return request_body_cls(**json_data)
|
||||
elif request.method == "GET":
|
||||
return request_body_cls(**request.query_params)
|
||||
else:
|
||||
if request_body_cls == Request:
|
||||
return request
|
||||
if request.method == "POST":
|
||||
if request_body_cls == str:
|
||||
bytes_body = await request.body()
|
||||
str_body = bytes_body.decode("utf-8")
|
||||
return str_body
|
||||
elif issubclass(request_body_cls, BaseModel):
|
||||
json_data = await request.json()
|
||||
return request_body_cls(**json_data)
|
||||
else:
|
||||
raise ValueError(f"Invalid request body cls: {request_body_cls}")
|
||||
elif request.method == "GET":
|
||||
if issubclass(request_body_cls, BaseModel):
|
||||
return request_body_cls(**request.query_params)
|
||||
else:
|
||||
raise ValueError(f"Invalid request body cls: {request_body_cls}")
|
||||
|
||||
|
||||
async def _trigger_dag(
|
||||
@@ -123,10 +147,10 @@ async def _trigger_dag(
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
end_node = dag.leaf_nodes
|
||||
if len(end_node) != 1:
|
||||
leaf_nodes = dag.leaf_nodes
|
||||
if len(leaf_nodes) != 1:
|
||||
raise ValueError("HttpTrigger just support one leaf node in dag")
|
||||
end_node = end_node[0]
|
||||
end_node = cast(BaseOperator, leaf_nodes[0])
|
||||
if not streaming_response:
|
||||
return await end_node.call(call_data={"data": body})
|
||||
else:
|
||||
@@ -141,7 +165,7 @@ async def _trigger_dag(
|
||||
}
|
||||
generator = await end_node.call_stream(call_data={"data": body})
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(end_node.dag._after_dag_end)
|
||||
background_tasks.add_task(dag._after_dag_end)
|
||||
return StreamingResponse(
|
||||
generator,
|
||||
headers=headers,
|
||||
|
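For context, an `HttpTrigger` is normally declared inside a DAG and exposed through `setup_dev_environment`. The sketch below (editor-added) follows the constructor shown above; `DAG`, `MapOperator`, the `>>` composition and the exact import paths come from the repository's examples rather than from this diff, so treat them as assumptions.

from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment


class HelloRequest(BaseModel):
    name: str = "world"


with DAG("awel_http_hello") as dag:
    # GET query parameters are parsed into HelloRequest by _parse_request_body.
    trigger = HttpTrigger("/examples/hello", methods="GET", request_body=HelloRequest)
    greet = MapOperator(map_function=lambda req: f"Hello, {req.name}!")
    trigger >> greet

if __name__ == "__main__":
    # With the default router prefix, the endpoint is served at
    # http://127.0.0.1:5555/api/v1/awel/trigger/examples/hello
    setup_dev_environment([dag], port=5555)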
@@ -1,41 +1,63 @@
"""Trigger manager for AWEL.

After DB-GPT started, the trigger manager will be initialized and register all triggers
"""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from dbgpt.component import BaseComponent, ComponentType, SystemApp

from .base import Trigger

if TYPE_CHECKING:
    from fastapi import APIRouter

from dbgpt.component import BaseComponent, ComponentType, SystemApp

logger = logging.getLogger(__name__)


class TriggerManager(ABC):
    """Base class for trigger manager."""

    @abstractmethod
    def register_trigger(self, trigger: Any) -> None:
        """ "Register a trigger to current manager"""
        """Register a trigger to current manager."""
class HttpTriggerManager(TriggerManager):
|
||||
"""Http trigger manager.
|
||||
|
||||
Register all http triggers to a router.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
router: Optional["APIRouter"] = None,
|
||||
router_prefix: Optional[str] = "/api/v1/awel/trigger",
|
||||
router_prefix: str = "/api/v1/awel/trigger",
|
||||
) -> None:
|
||||
"""Initialize a HttpTriggerManager.
|
||||
|
||||
Args:
|
||||
router (Optional["APIRouter"], optional): The router. Defaults to None.
|
||||
If None, will create a new FastAPI router.
|
||||
router_prefix (str, optional): The router prefix. Defaults
|
||||
to "/api/v1/awel/trigger".
|
||||
"""
|
||||
if not router:
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
self._router_prefix = router_prefix
|
||||
self._router = router
|
||||
self._trigger_map = {}
|
||||
self._trigger_map: Dict[str, Trigger] = {}
|
||||
|
||||
def register_trigger(self, trigger: Any) -> None:
|
||||
"""Register a trigger to current manager."""
|
||||
from .http_trigger import HttpTrigger
|
||||
|
||||
if not isinstance(trigger, HttpTrigger):
|
||||
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
|
||||
trigger: HttpTrigger = trigger
|
||||
trigger_id = trigger.node_id
|
||||
if trigger_id not in self._trigger_map:
|
||||
trigger.mount_to_router(self._router)
|
||||
@@ -45,23 +67,32 @@ class HttpTriggerManager(TriggerManager):
|
||||
logger.info(
|
||||
f"Include router {self._router} to prefix path {self._router_prefix}"
|
||||
)
|
||||
system_app.app.include_router(
|
||||
self._router, prefix=self._router_prefix, tags=["AWEL"]
|
||||
)
|
||||
app = system_app.app
|
||||
if not app:
|
||||
raise RuntimeError("System app not initialized")
|
||||
app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"])
|
||||
|
||||
|
||||
class DefaultTriggerManager(TriggerManager, BaseComponent):
|
||||
"""Default trigger manager for AWEL.
|
||||
|
||||
Manage all trigger managers. Just support http trigger now.
|
||||
"""
|
||||
|
||||
name = ComponentType.AWEL_TRIGGER_MANAGER
|
||||
|
||||
def __init__(self, system_app: SystemApp | None = None):
|
||||
"""Initialize a DefaultTriggerManager."""
|
||||
self.system_app = system_app
|
||||
self.http_trigger = HttpTriggerManager()
|
||||
super().__init__(None)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize the trigger manager."""
|
||||
self.system_app = system_app
|
||||
|
||||
def register_trigger(self, trigger: Any) -> None:
|
||||
"""Register a trigger to current manager."""
|
||||
from .http_trigger import HttpTrigger
|
||||
|
||||
if isinstance(trigger, HttpTrigger):
|
||||
@@ -71,4 +102,6 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
|
||||
raise ValueError(f"Unsupport trigger: {trigger}")
|
||||
|
||||
def after_register(self) -> None:
|
||||
self.http_trigger._init_app(self.system_app)
|
||||
"""After register, init the trigger manager."""
|
||||
if self.system_app:
|
||||
self.http_trigger._init_app(self.system_app)