refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
Author: FangYin Cheng
Date: 2023-12-08 14:45:59 +08:00
Committed by: GitHub
Parent: e7e4aff667
Commit: cd725db1fb
573 changed files with 2094 additions and 3571 deletions

dbgpt/core/awel/__init__.py

@@ -0,0 +1,87 @@
"""Agentic Workflow Expression Language (AWEL)
Note:
    AWEL is still an experimental feature and only exposes the lowest-level API.
The stability of this API cannot be guaranteed at present.
"""
from dbgpt.component import SystemApp
from .dag.base import DAGContext, DAG
from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import (
JoinOperator,
ReduceStreamOperator,
MapOperator,
BranchOperator,
InputOperator,
BranchFunc,
)
from .operator.stream_operator import (
StreamifyAbsOperator,
UnstreamifyAbsOperator,
TransformStreamAbsOperator,
)
from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
from .task.task_impl import (
SimpleInputSource,
SimpleCallDataInputSource,
DefaultTaskContext,
DefaultInputContext,
SimpleTaskOutput,
SimpleStreamTaskOutput,
_is_async_iterator,
)
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner
__all__ = [
"initialize_awel",
"DAGContext",
"DAG",
"BaseOperator",
"JoinOperator",
"ReduceStreamOperator",
"MapOperator",
"BranchOperator",
"InputOperator",
"BranchFunc",
"WorkflowRunner",
"TaskState",
"TaskOutput",
"TaskContext",
"InputContext",
"InputSource",
"DefaultWorkflowRunner",
"SimpleInputSource",
"SimpleCallDataInputSource",
"DefaultTaskContext",
"DefaultInputContext",
"SimpleTaskOutput",
"SimpleStreamTaskOutput",
"StreamifyAbsOperator",
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"HttpTrigger",
]
def initialize_awel(system_app: SystemApp, dag_filepath: str):
from .dag.dag_manager import DAGManager
from .dag.base import DAGVar
from .trigger.trigger_manager import DefaultTriggerManager
from .operator.base import initialize_runner
DAGVar.set_current_system_app(system_app)
system_app.register(DefaultTriggerManager)
dag_manager = DAGManager(system_app, dag_filepath)
system_app.register_instance(dag_manager)
initialize_runner(DefaultWorkflowRunner())
# Load all dags
dag_manager.load_dags()
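Illustrative sketch (not part of this commit): wiring and running a tiny AWEL DAG with the names exported above, assuming the package is importable as `dbgpt.core.awel`. The DAG name and the lambda are hypothetical.

import asyncio

from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleInputSource

with DAG("example_dag"):
    input_task = InputOperator(input_source=SimpleInputSource("world"))
    greet_task = MapOperator(map_function=lambda name: f"Hello, {name}!")
    input_task >> greet_task

# `call()` builds a JobManager from the end node and drives the default runner.
result = asyncio.run(greet_task.call())
assert result == "Hello, world!"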

dbgpt/core/awel/base.py

@@ -0,0 +1,7 @@
from abc import ABC, abstractmethod
class Trigger(ABC):
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

dbgpt/core/awel/dag/base.py

@@ -0,0 +1,371 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import logging
from collections import deque
from functools import cache
from concurrent.futures import Executor
from dbgpt.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext
logger = logging.getLogger(__name__)
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
def _is_async_context():
try:
loop = asyncio.get_running_loop()
return asyncio.current_task(loop=loop) is not None
except RuntimeError:
return False
class DependencyMixin(ABC):
@abstractmethod
def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
"""Set one or more upstream nodes for this node.
Args:
nodes (DependencyType): Upstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
"""
@abstractmethod
def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
"""Set one or more downstream nodes for this node.
Args:
nodes (DependencyType): Downstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
"""
def __lshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self << nodes
Example:
.. code-block:: python
# means node.set_upstream(input_node)
node << input_node
# means node2.set_upstream([input_node])
node2 << [input_node]
"""
self.set_upstream(nodes)
return nodes
def __rshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self >> nodes
Example:
.. code-block:: python
# means node.set_downstream(next_node)
node >> next_node
# means node2.set_downstream([next_node])
node2 >> [next_node]
"""
self.set_downstream(nodes)
return nodes
def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] >> self"""
self.__lshift__(nodes)
return self
def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] << self"""
self.__rshift__(nodes)
return self
class DAGVar:
_thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
_system_app: SystemApp = None
_executor: Executor = None
@classmethod
def enter_dag(cls, dag) -> None:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
stack.append(dag)
cls._async_local.set(stack)
else:
if not hasattr(cls._thread_local, "current_dag_stack"):
cls._thread_local.current_dag_stack = deque()
cls._thread_local.current_dag_stack.append(dag)
@classmethod
def exit_dag(cls) -> None:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
if stack:
stack.pop()
cls._async_local.set(stack)
else:
if (
hasattr(cls._thread_local, "current_dag_stack")
and cls._thread_local.current_dag_stack
):
cls._thread_local.current_dag_stack.pop()
@classmethod
def get_current_dag(cls) -> Optional["DAG"]:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
return stack[-1] if stack else None
else:
if (
hasattr(cls._thread_local, "current_dag_stack")
and cls._thread_local.current_dag_stack
):
return cls._thread_local.current_dag_stack[-1]
return None
@classmethod
def get_current_system_app(cls) -> SystemApp:
# if not cls._system_app:
# raise RuntimeError("System APP not set for DAGVar")
return cls._system_app
@classmethod
def set_current_system_app(cls, system_app: SystemApp) -> None:
if cls._system_app:
logger.warn("System APP has already set, nothing to do")
else:
cls._system_app = system_app
@classmethod
def get_executor(cls) -> Executor:
return cls._executor
@classmethod
def set_executor(cls, executor: Executor) -> None:
cls._executor = executor
class DAGNode(DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
def __init__(
self,
dag: Optional["DAG"] = None,
node_id: Optional[str] = None,
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
**kwargs
) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
self._system_app: Optional[SystemApp] = (
system_app or DAGVar.get_current_system_app()
)
self._executor: Optional[Executor] = executor or DAGVar.get_executor()
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id
self._node_name: str = node_name
@property
def node_id(self) -> str:
return self._node_id
@property
def system_app(self) -> SystemApp:
return self._system_app
def set_node_id(self, node_id: str) -> None:
self._node_id = node_id
def __hash__(self) -> int:
if self.node_id:
return hash(self.node_id)
else:
return super().__hash__()
def __eq__(self, other: Any) -> bool:
if not isinstance(other, DAGNode):
return False
return self.node_id == other.node_id
@property
def node_name(self) -> str:
return self._node_name
@property
def dag(self) -> "DAG":
return self._dag
    def set_upstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes)
        return self
    def set_downstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes, is_upstream=False)
        return self
@property
def upstream(self) -> List["DAGNode"]:
return self._upstream
@property
def downstream(self) -> List["DAGNode"]:
return self._downstream
def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
if not isinstance(nodes, Sequence):
nodes = [nodes]
if not all(isinstance(node, DAGNode) for node in nodes):
raise ValueError(
"all nodes to set dependency to current node must be instance of 'DAGNode'"
)
nodes: Sequence[DAGNode] = nodes
dags = set([node.dag for node in nodes if node.dag])
if self.dag:
dags.add(self.dag)
if not dags:
raise ValueError("set dependency to current node must in a DAG context")
if len(dags) != 1:
raise ValueError(
"set dependency to current node just support in one DAG context"
)
dag = dags.pop()
self._dag = dag
dag._append_node(self)
for node in nodes:
if is_upstream and node not in self.upstream:
node._dag = dag
dag._append_node(node)
self._upstream.append(node)
node._downstream.append(self)
elif node not in self._downstream:
node._dag = dag
dag._append_node(node)
self._downstream.append(node)
node._upstream.append(self)
class DAGContext:
def __init__(self, streaming_call: bool = False) -> None:
self._streaming_call = streaming_call
self._curr_task_ctx = None
self._share_data: Dict[str, Any] = {}
@property
def current_task_context(self) -> TaskContext:
return self._curr_task_ctx
@property
def streaming_call(self) -> bool:
"""Whether the current DAG is streaming call"""
return self._streaming_call
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
self._curr_task_ctx = _curr_task_ctx
async def get_share_data(self, key: str) -> Any:
return self._share_data.get(key)
async def save_to_share_data(self, key: str, data: Any) -> None:
self._share_data[key] = data
class DAG:
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None:
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {}
self._root_nodes: Set[DAGNode] = None
self._leaf_nodes: Set[DAGNode] = None
self._trigger_nodes: Set[DAGNode] = None
def _append_node(self, node: DAGNode) -> None:
self.node_map[node.node_id] = node
# clear cached nodes
self._root_nodes = None
self._leaf_nodes = None
def _new_node_id(self) -> str:
return str(uuid.uuid4())
@property
def dag_id(self) -> str:
return self._dag_id
def _build(self) -> None:
from ..operator.common_operator import TriggerOperator
nodes = set()
for _, node in self.node_map.items():
nodes = nodes.union(_get_nodes(node))
self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes)))
self._trigger_nodes = list(
set(filter(lambda x: isinstance(x, TriggerOperator), nodes))
)
@property
def root_nodes(self) -> List[DAGNode]:
if not self._root_nodes:
self._build()
return self._root_nodes
@property
def leaf_nodes(self) -> List[DAGNode]:
if not self._leaf_nodes:
self._build()
return self._leaf_nodes
@property
def trigger_nodes(self):
if not self._trigger_nodes:
self._build()
return self._trigger_nodes
def __enter__(self):
DAGVar.enter_dag(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
DAGVar.exit_dag()
def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
nodes = set()
if not node:
return nodes
nodes.add(node)
stream_nodes = node.upstream if is_upstream else node.downstream
for node in stream_nodes:
nodes = nodes.union(_get_nodes(node, is_upstream))
return nodes
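Illustrative sketch (not part of this commit): how the dependency operators and the DAG context manager defined above fit together. The node names and lambdas are hypothetical; the operator classes come from this commit.

from dbgpt.core.awel import DAG, MapOperator

with DAG("dependency_demo") as dag:
    a = MapOperator(lambda x: x + 1, task_name="a")
    b = MapOperator(lambda x: x * 2, task_name="b")
    c = MapOperator(lambda x: x - 3, task_name="c")
    # `>>` is DependencyMixin.__rshift__: a.set_downstream(b), then b.set_downstream(c)
    a >> b >> c

assert b in a.downstream and a in b.upstream
assert dag.root_nodes == [a] and dag.leaf_nodes == [c]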

dbgpt/core/awel/dag/dag_manager.py

@@ -0,0 +1,42 @@
from typing import Dict, Optional
import logging
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG
logger = logging.getLogger(__name__)
class DAGManager(BaseComponent):
name = ComponentType.AWEL_DAG_MANAGER
def __init__(self, system_app: SystemApp, dag_filepath: str):
super().__init__(system_app)
self.dag_loader = LocalFileDAGLoader(dag_filepath)
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}
def init_app(self, system_app: SystemApp):
self.system_app = system_app
def load_dags(self):
dags = self.dag_loader.load_dags()
triggers = []
for dag in dags:
dag_id = dag.dag_id
if dag_id in self.dag_map:
raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
triggers += dag.trigger_nodes
from ..trigger.trigger_manager import DefaultTriggerManager
trigger_manager: DefaultTriggerManager = self.system_app.get_component(
ComponentType.AWEL_TRIGGER_MANAGER,
DefaultTriggerManager,
default_component=None,
)
if trigger_manager:
for trigger in triggers:
trigger_manager.register_trigger(trigger)
trigger_manager.after_register()
else:
logger.warn("No trigger manager, not register dag trigger")

dbgpt/core/awel/dag/loader.py

@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback
from .base import DAG
logger = logging.getLogger(__name__)
class DAGLoader(ABC):
@abstractmethod
def load_dags(self) -> List[DAG]:
"""Load dags"""
class LocalFileDAGLoader(DAGLoader):
def __init__(self, filepath: str) -> None:
super().__init__()
self._filepath = filepath
def load_dags(self) -> List[DAG]:
if not os.path.exists(self._filepath):
return []
if os.path.isdir(self._filepath):
return _process_directory(self._filepath)
else:
return _process_file(self._filepath)
def _process_directory(directory: str) -> List[DAG]:
dags = []
for file in os.listdir(directory):
if file.endswith(".py"):
filepath = os.path.join(directory, file)
dags += _process_file(filepath)
return dags
def _process_file(filepath) -> List[DAG]:
mods = _load_modules_from_file(filepath)
results = _process_modules(mods)
return results
def _load_modules_from_file(filepath: str):
import importlib
import importlib.machinery
import importlib.util
logger.info(f"Importing {filepath}")
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
if mod_name in sys.modules:
del sys.modules[mod_name]
def parse(mod_name, filepath):
try:
loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
spec = importlib.util.spec_from_loader(mod_name, loader)
new_module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = new_module
loader.exec_module(new_module)
return [new_module]
except Exception as e:
msg = traceback.format_exc()
logger.error(f"Failed to import: {filepath}, error message: {msg}")
# TODO save error message
return []
return parse(mod_name, filepath)
def _process_modules(mods) -> List[DAG]:
top_level_dags = (
(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
)
found_dags = []
for dag, mod in top_level_dags:
try:
# TODO validate dag params
logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
found_dags.append(dag)
except Exception:
msg = traceback.format_exc()
logger.error(f"Failed to dag file, error message: {msg}")
return found_dags
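Illustrative sketch (not part of this commit): what a DAG definition file picked up by `LocalFileDAGLoader` could look like; the path `awel_dags/example.py` and all names are hypothetical.

from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleInputSource

with DAG("loader_example_dag") as example_dag:
    source = InputOperator(input_source=SimpleInputSource("ping"))
    echo = MapOperator(lambda x: x)
    source >> echo

# `_process_modules()` only collects module-level DAG objects, so `example_dag`
# must remain a top-level name in the module.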


@@ -0,0 +1,51 @@
import pytest
import threading
import asyncio
from ..dag import DAG, DAGVar
def test_dag_context_sync():
dag1 = DAG("dag1")
dag2 = DAG("dag2")
with dag1:
        assert DAGVar.get_current_dag() == dag1
        with dag2:
            assert DAGVar.get_current_dag() == dag2
        assert DAGVar.get_current_dag() == dag1
    assert DAGVar.get_current_dag() is None
def test_dag_context_threading():
def thread_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()
dag1 = DAG("dag1")
dag2 = DAG("dag2")
thread1 = threading.Thread(target=thread_function, args=(dag1,))
thread2 = threading.Thread(target=thread_function, args=(dag2,))
thread1.start()
thread2.start()
thread1.join()
thread2.join()
    assert DAGVar.get_current_dag() is None
@pytest.mark.asyncio
async def test_dag_context_async():
async def async_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()
dag1 = DAG("dag1")
dag2 = DAG("dag2")
await asyncio.gather(async_function(dag1), async_function(dag2))
    assert DAGVar.get_current_dag() is None

dbgpt/core/awel/operator/base.py

@@ -0,0 +1,245 @@
from abc import ABC, abstractmethod, ABCMeta
from types import FunctionType
from typing import (
List,
Generic,
TypeVar,
AsyncIterator,
Iterator,
Union,
Any,
Dict,
Optional,
cast,
)
import functools
from inspect import signature
import asyncio
from dbgpt.component import SystemApp, ComponentType
from dbgpt.util.executor_utils import (
ExecutorFactory,
DefaultExecutorFactory,
blocking_func_to_async,
BlockingFunction,
AsyncToSyncIterator,
)
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import TaskOutput, OUT, T
F = TypeVar("F", bound=FunctionType)
CALL_DATA = Union[Dict, Dict[str, Dict]]
class WorkflowRunner(ABC, Generic[T]):
"""Abstract base class representing a runner for executing workflows in a DAG.
This class defines the interface for executing workflows within the DAG,
handling the flow from one DAG node to another.
"""
@abstractmethod
async def execute_workflow(
self,
node: "BaseOperator",
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
) -> DAGContext:
"""Execute the workflow starting from a given operator.
Args:
            node (BaseOperator): The starting node of the workflow to be executed.
            call_data (CALL_DATA): The data passed to the root operator node.
streaming_call (bool): Whether the call is a streaming call.
Returns:
DAGContext: The context after executing the workflow, containing the final state and data.
"""
default_runner: WorkflowRunner = None
class BaseOperatorMeta(ABCMeta):
"""Metaclass of BaseOperator."""
@classmethod
def _apply_defaults(cls, func: F) -> F:
sig_cache = signature(func)
@functools.wraps(func)
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
task_id: Optional[str] = kwargs.get("task_id")
system_app: Optional[SystemApp] = (
kwargs.get("system_app") or DAGVar.get_current_system_app()
)
executor = kwargs.get("executor") or DAGVar.get_executor()
if not executor:
if system_app:
executor = system_app.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
else:
executor = DefaultExecutorFactory().create()
DAGVar.set_executor(executor)
if not task_id and dag:
task_id = dag._new_node_id()
runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
# print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}")
# for arg in sig_cache.parameters:
# if arg not in kwargs:
# kwargs[arg] = default_args[arg]
if not kwargs.get("dag"):
kwargs["dag"] = dag
if not kwargs.get("task_id"):
kwargs["task_id"] = task_id
if not kwargs.get("runner"):
kwargs["runner"] = runner
if not kwargs.get("system_app"):
kwargs["system_app"] = system_app
if not kwargs.get("executor"):
kwargs["executor"] = executor
real_obj = func(self, *args, **kwargs)
return real_obj
return cast(T, apply_defaults)
def __new__(cls, name, bases, namespace, **kwargs):
new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
return new_cls
class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
"""Abstract base class for operator nodes that can be executed within a workflow.
This class extends DAGNode by adding execution capabilities.
"""
def __init__(
self,
task_id: Optional[str] = None,
task_name: Optional[str] = None,
dag: Optional[DAG] = None,
runner: WorkflowRunner = None,
**kwargs,
) -> None:
"""Initializes a BaseOperator with an optional workflow runner.
Args:
runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
"""
super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
if not runner:
from dbgpt.core.awel import DefaultWorkflowRunner
runner = DefaultWorkflowRunner()
self._runner: WorkflowRunner = runner
self._dag_ctx: DAGContext = None
@property
def current_dag_context(self) -> DAGContext:
return self._dag_ctx
async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
if not self.node_id:
raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
self._dag_ctx = dag_ctx
return await self._do_run(dag_ctx)
@abstractmethod
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""
Abstract method to run the task within the DAG node.
Args:
dag_ctx (DAGContext): The context of the DAG when this node is run.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
"""Execute the node and return the output.
This method is a high-level wrapper for executing the node.
Args:
            call_data (CALL_DATA): The data passed to the root operator node.
Returns:
OUT: The output of the node after execution.
"""
out_ctx = await self._runner.execute_workflow(self, call_data)
return out_ctx.current_task_context.task_output.output
def _blocking_call(
self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
) -> OUT:
"""Execute the node and return the output.
This method is a high-level wrapper for executing the node.
        This method is only for debugging. Please use the `call` method instead.
        Args:
            call_data (CALL_DATA): The data passed to the root operator node.
Returns:
OUT: The output of the node after execution.
"""
from dbgpt.util.utils import get_or_create_event_loop
if not loop:
loop = get_or_create_event_loop()
return loop.run_until_complete(self.call(call_data))
async def call_stream(
self, call_data: Optional[CALL_DATA] = None
) -> AsyncIterator[OUT]:
"""Execute the node and return the output as a stream.
This method is used for nodes where the output is a stream.
Args:
            call_data (CALL_DATA): The data passed to the root operator node.
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
out_ctx = await self._runner.execute_workflow(self, call_data)
return out_ctx.current_task_context.task_output.output_stream
def _blocking_call_stream(
self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
) -> Iterator[OUT]:
"""Execute the node and return the output as a stream.
This method is used for nodes where the output is a stream.
        This method is only for debugging. Please use the `call_stream` method instead.
        Args:
            call_data (CALL_DATA): The data passed to the root operator node.
Returns:
Iterator[OUT]: An iterator over the output stream.
"""
from dbgpt.util.utils import get_or_create_event_loop
if not loop:
loop = get_or_create_event_loop()
return AsyncToSyncIterator(self.call_stream(call_data), loop)
async def blocking_func_to_async(
self, func: BlockingFunction, *args, **kwargs
) -> Any:
return await blocking_func_to_async(self._executor, func, *args, **kwargs)
def initialize_runner(runner: WorkflowRunner):
global default_runner
default_runner = runner
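Illustrative sketch (not part of this commit): a custom MapOperator that offloads a blocking function to the operator's executor through `blocking_func_to_async` defined above. `SlowUpperOperator` and `blocking_upper` are hypothetical names.

import time

from dbgpt.core.awel import MapOperator

class SlowUpperOperator(MapOperator[str, str]):
    async def map(self, input_value: str) -> str:
        def blocking_upper(text: str) -> str:
            time.sleep(0.1)  # simulate a blocking call
            return text.upper()

        # Runs `blocking_upper` in self._executor instead of blocking the event loop.
        return await self.blocking_func_to_async(blocking_upper, input_value)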

View File

@@ -0,0 +1,252 @@
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
import asyncio
import logging
from ..dag.base import DAGContext
from ..task.base import (
TaskContext,
TaskOutput,
IN,
OUT,
InputContext,
InputSource,
)
from .base import BaseOperator
logger = logging.getLogger(__name__)
class JoinOperator(BaseOperator, Generic[OUT]):
"""Operator that joins inputs using a custom combine function.
This node type is useful for combining the outputs of upstream nodes.
"""
def __init__(self, combine_function, **kwargs):
super().__init__(**kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")
self.combine_function = combine_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the join operation on the DAG context's inputs.
Args:
dag_ctx (DAGContext): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
input_ctx: InputContext = await curr_task_ctx.task_input.map_all(
self.combine_function
)
        # The join result is stored in the first parent output
join_output = input_ctx.parent_outputs[0].task_output
curr_task_ctx.set_task_output(join_output)
return join_output
class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
def __init__(self, reduce_function=None, **kwargs):
"""Initializes a ReduceStreamOperator with a combine function.
Args:
combine_function: A function that defines how to combine inputs.
Raises:
ValueError: If the combine_function is not callable.
"""
super().__init__(**kwargs)
if reduce_function and not callable(reduce_function):
raise ValueError("reduce_function must be callable")
self.reduce_function = reduce_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the join operation on the DAG context's inputs.
Args:
dag_ctx (DAGContext): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
task_input = curr_task_ctx.task_input
if not task_input.check_stream():
raise ValueError("ReduceStreamOperator expects stream data")
if not task_input.check_single_parent():
raise ValueError("ReduceStreamOperator expects single parent")
reduce_function = self.reduce_function or self.reduce
input_ctx: InputContext = await task_input.reduce(reduce_function)
        # The reduce result is stored in the first parent output
reduce_output = input_ctx.parent_outputs[0].task_output
curr_task_ctx.set_task_output(reduce_output)
return reduce_output
async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
raise NotImplementedError
class MapOperator(BaseOperator, Generic[IN, OUT]):
"""Map operator that applies a mapping function to its inputs.
This operator transforms its input data using a provided mapping function and
passes the transformed data downstream.
"""
def __init__(self, map_function=None, **kwargs):
"""Initializes a MapDAGNode with a mapping function.
Args:
map_function: A function that defines how to map the input data.
Raises:
ValueError: If the map_function is not callable.
"""
super().__init__(**kwargs)
if map_function and not callable(map_function):
raise ValueError("map_function must be callable")
self.map_function = map_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the mapping operation on the DAG context's inputs.
This method applies the mapping function to the input context and updates
the DAG context with the new data.
Args:
dag_ctx (DAGContext[IN]): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
Raises:
ValueError: If not a single parent or the map_function is not callable
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
call_data = curr_task_ctx.call_data
if not call_data and not curr_task_ctx.task_input.check_single_parent():
num_parents = len(curr_task_ctx.task_input.parent_outputs)
raise ValueError(
f"task {curr_task_ctx.task_id} MapDAGNode expects single parent, now number of parents: {num_parents}"
)
map_function = self.map_function or self.map
if call_data:
call_data = await curr_task_ctx._call_data_to_output()
output = await call_data.map(map_function)
curr_task_ctx.set_task_output(output)
return output
input_ctx: InputContext = await curr_task_ctx.task_input.map(map_function)
        # The map result is stored in the first parent output
output = input_ctx.parent_outputs[0].task_output
curr_task_ctx.set_task_output(output)
return output
async def map(self, input_value: IN) -> OUT:
raise NotImplementedError
BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
class BranchOperator(BaseOperator, Generic[IN, OUT]):
"""Operator node that branches the workflow based on a provided function.
This node filters its input data using a branching function and
allows for conditional paths in the workflow.
"""
def __init__(
self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
):
"""
Initializes a BranchDAGNode with a branching function.
Args:
            branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): A mapping from branch functions to the downstream operator (or its node name) chosen when the function returns True.
Raises:
ValueError: If the branch_function is not callable.
"""
super().__init__(**kwargs)
if branches:
for branch_function, value in branches.items():
if not callable(branch_function):
raise ValueError("branch_function must be callable")
if isinstance(value, BaseOperator):
                    branches[branch_function] = value.node_name or value.node_id
self._branches = branches
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the branching operation on the DAG context's inputs.
This method applies the branching function to the input context to determine
the path of execution in the workflow.
Args:
dag_ctx (DAGContext[IN]): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
task_input = curr_task_ctx.task_input
if task_input.check_stream():
raise ValueError("BranchDAGNode expects no stream data")
if not task_input.check_single_parent():
raise ValueError("BranchDAGNode expects single parent")
branches = self._branches
if not branches:
branches = await self.branchs()
branch_func_tasks = []
branch_nodes: List[str] = []
for func, node_name in branches.items():
branch_nodes.append(node_name)
branch_func_tasks.append(
curr_task_ctx.task_input.predicate_map(func, failed_value=None)
)
branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
parent_output = task_input.parent_outputs[0].task_output
curr_task_ctx.set_task_output(parent_output)
skip_node_names = []
for i, ctx in enumerate(branch_input_ctxs):
node_name = branch_nodes[i]
branch_out = ctx.parent_outputs[0].task_output
logger.info(
f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
)
if ctx.parent_outputs[0].task_output.is_empty:
logger.info(f"Skip node name {node_name}")
skip_node_names.append(node_name)
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
return parent_output
async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
raise NotImplementedError
class InputOperator(BaseOperator, Generic[OUT]):
def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
super().__init__(**kwargs)
self._input_source = input_source
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
task_output = await self._input_source.read(curr_task_ctx)
curr_task_ctx.set_task_output(task_output)
return task_output
class TriggerOperator(InputOperator, Generic[OUT]):
def __init__(self, **kwargs) -> None:
from ..task.task_impl import SimpleCallDataInputSource
super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)
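Illustrative sketch (not part of this commit): combining two upstream outputs with the JoinOperator defined above; the DAG name, values, and combine function are hypothetical.

import asyncio

from dbgpt.core.awel import DAG, InputOperator, JoinOperator, SimpleInputSource

with DAG("join_demo"):
    left = InputOperator(SimpleInputSource(2))
    right = InputOperator(SimpleInputSource(3))
    add = JoinOperator(combine_function=lambda a, b: a + b)
    left >> add
    right >> add

# The combine function receives one argument per upstream node, in wiring order.
assert asyncio.run(add.call()) == 5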

View File

@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
from typing import Generic, AsyncIterator
from ..task.base import OUT, IN, TaskOutput, TaskContext
from ..dag.base import DAGContext
from .base import BaseOperator
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
self.streamify
)
curr_task_ctx.set_task_output(output)
return output
@abstractmethod
async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
"""Convert a value of IN to an AsyncIterator[OUT]
Args:
input_value (IN): The data of parent operator's output
Example:
.. code-block:: python
class MyStreamOperator(StreamifyAbsOperator[int, int]):
                async def streamify(self, input_value: int) -> AsyncIterator[int]:
for i in range(input_value):
yield i
"""
class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[
0
].task_output.unstreamify(self.unstreamify)
curr_task_ctx.set_task_output(output)
return output
@abstractmethod
async def unstreamify(self, input_value: AsyncIterator[IN]) -> OUT:
"""Convert a value of AsyncIterator[IN] to an OUT.
Args:
input_value (AsyncIterator[IN])): The data of parent operator's output
Example:
.. code-block:: python
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
                async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
value_cnt = 0
async for v in input_value:
value_cnt += 1
return value_cnt
"""
class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[
0
].task_output.transform_stream(self.transform_stream)
curr_task_ctx.set_task_output(output)
return output
@abstractmethod
async def transform_stream(
self, input_value: AsyncIterator[IN]
) -> AsyncIterator[OUT]:
"""Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.
Args:
input_value (AsyncIterator[IN])): The data of parent operator's output
Example:
.. code-block:: python
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
                async def transform_stream(self, input_value: AsyncIterator[int]) -> AsyncIterator[int]:
async for v in input_value:
yield v + 1
"""

dbgpt/core/awel/resource/base.py

@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
class ResourceGroup(ABC):
@property
@abstractmethod
def name(self) -> str:
"""The name of current resource group"""

dbgpt/core/awel/runner/job_manager.py

@@ -0,0 +1,82 @@
from typing import List, Set, Optional, Dict
import uuid
import logging
from ..dag.base import DAG
from ..operator.base import BaseOperator, CALL_DATA
logger = logging.getLogger(__name__)
class DAGNodeInstance:
def __init__(self, node_instance: DAG) -> None:
pass
class DAGInstance:
def __init__(self, dag: DAG) -> None:
self._dag = dag
class JobManager:
def __init__(
self,
root_nodes: List[BaseOperator],
all_nodes: List[BaseOperator],
end_node: BaseOperator,
id2call_data: Dict[str, Dict],
) -> None:
self._root_nodes = root_nodes
self._all_nodes = all_nodes
self._end_node = end_node
self._id2node_data = id2call_data
@staticmethod
def build_from_end_node(
end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
) -> "JobManager":
nodes = _build_from_end_node(end_node)
root_nodes = _get_root_nodes(nodes)
id2call_data = _save_call_data(root_nodes, call_data)
return JobManager(root_nodes, nodes, end_node, id2call_data)
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
return self._id2node_data.get(node_id)
def _save_call_data(
root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
id2call_data = {}
logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
if not call_data:
return id2call_data
if len(root_nodes) == 1:
node = root_nodes[0]
logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
id2call_data[node.node_id] = call_data
else:
for node in root_nodes:
node_id = node.node_id
logger.info(
f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
)
id2call_data[node_id] = call_data.get(node_id)
return id2call_data
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
nodes = []
if isinstance(end_node, BaseOperator):
task_id = end_node.node_id
if not task_id:
task_id = str(uuid.uuid4())
end_node.set_node_id(task_id)
nodes.append(end_node)
for node in end_node.upstream:
nodes += _build_from_end_node(node)
return nodes
def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
return list(set(filter(lambda x: not x.upstream, nodes)))
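Illustrative sketch (not part of this commit): how call data reaches the root node through JobManager; the module path `dbgpt.core.awel.runner.job_manager` is assumed from the relative imports in this commit, and the DAG contents are hypothetical.

from dbgpt.core.awel import DAG, MapOperator
from dbgpt.core.awel.runner.job_manager import JobManager

with DAG("job_manager_demo"):
    root = MapOperator(lambda x: x)
    leaf = MapOperator(lambda x: x * 2)
    root >> leaf

manager = JobManager.build_from_end_node(leaf, call_data={"data": 21})
# With a single root node, the whole call_data dict is attached to that node.
assert manager.get_call_data_by_id(root.node_id) == {"data": 21}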

dbgpt/core/awel/runner/local_runner.py

@@ -0,0 +1,109 @@
from typing import Dict, Optional, Set, List
import logging
from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager
logger = logging.getLogger(__name__)
class DefaultWorkflowRunner(WorkflowRunner):
async def execute_workflow(
self,
node: BaseOperator,
call_data: Optional[CALL_DATA] = None,
streaming_call: bool = False,
) -> DAGContext:
# Create DAG context
dag_ctx = DAGContext(streaming_call=streaming_call)
job_manager = JobManager.build_from_end_node(node, call_data)
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
)
dag = node.dag
# Save node output
node_outputs: Dict[str, TaskContext] = {}
skip_node_ids = set()
await self._execute_node(
job_manager, node, dag_ctx, node_outputs, skip_node_ids
)
return dag_ctx
async def _execute_node(
self,
job_manager: JobManager,
node: BaseOperator,
dag_ctx: DAGContext,
node_outputs: Dict[str, TaskContext],
skip_node_ids: Set[str],
):
# Skip run node
if node.node_id in node_outputs:
return
# Run all upstream node
for upstream_node in node.upstream:
if isinstance(upstream_node, BaseOperator):
await self._execute_node(
job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
)
inputs = [
node_outputs[upstream_node.node_id] for upstream_node in node.upstream
]
input_ctx = DefaultInputContext(inputs)
task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
task_ctx.set_task_input(input_ctx)
dag_ctx.set_current_task_context(task_ctx)
task_ctx.set_current_state(TaskState.RUNNING)
if node.node_id in skip_node_ids:
task_ctx.set_current_state(TaskState.SKIP)
task_ctx.set_task_output(SimpleTaskOutput(None))
node_outputs[node.node_id] = task_ctx
return
try:
logger.debug(
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
)
await node._run(dag_ctx)
node_outputs[node.node_id] = dag_ctx.current_task_context
task_ctx.set_current_state(TaskState.SUCCESS)
if isinstance(node, BranchOperator):
skip_nodes = task_ctx.metadata.get("skip_node_names", [])
logger.debug(
f"Current is branch operator, skip node names: {skip_nodes}"
)
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
except Exception as e:
logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
task_ctx.set_current_state(TaskState.FAILED)
raise e
def _skip_current_downstream_by_node_name(
branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
):
if not skip_nodes:
return
for child in branch_node.downstream:
if child.node_name in skip_nodes:
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
_skip_downstream_by_id(child, skip_node_ids)
def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
if isinstance(node, JoinOperator):
# Not skip join node
return
skip_node_ids.add(node.node_id)
for child in node.downstream:
_skip_downstream_by_id(child, skip_node_ids)
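Illustrative sketch (not part of this commit): driving DefaultWorkflowRunner directly instead of going through `BaseOperator.call()`; the DAG and values are hypothetical.

import asyncio

from dbgpt.core.awel import (
    DAG,
    DefaultWorkflowRunner,
    InputOperator,
    MapOperator,
    SimpleInputSource,
)

async def main() -> None:
    with DAG("runner_demo"):
        source = InputOperator(SimpleInputSource(10))
        double = MapOperator(lambda x: x * 2)
        source >> double

    runner = DefaultWorkflowRunner()
    dag_ctx = await runner.execute_workflow(double)
    assert dag_ctx.current_task_context.task_output.output == 20

asyncio.run(main())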

dbgpt/core/awel/task/base.py

@@ -0,0 +1,371 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TypeVar,
Generic,
Optional,
AsyncIterator,
Union,
Callable,
Any,
Dict,
List,
)
IN = TypeVar("IN")
OUT = TypeVar("OUT")
T = TypeVar("T")
class TaskState(str, Enum):
"""Enumeration representing the state of a task in the workflow.
This Enum defines various states a task can be in during its lifecycle in the DAG.
"""
INIT = "init" # Initial state of the task, not yet started
SKIP = "skip" # State indicating the task was skipped
RUNNING = "running" # State indicating the task is currently running
SUCCESS = "success" # State indicating the task completed successfully
FAILED = "failed" # State indicating the task failed during execution
class TaskOutput(ABC, Generic[T]):
"""Abstract base class representing the output of a task.
This class encapsulates the output of a task and provides methods to access the output data.
It can be subclassed to implement specific output behaviors.
"""
@property
def is_stream(self) -> bool:
"""Check if the output is a stream.
Returns:
bool: True if the output is a stream, False otherwise.
"""
return False
@property
def is_empty(self) -> bool:
"""Check if the output is empty.
Returns:
bool: True if the output is empty, False otherwise.
"""
return False
@property
def output(self) -> Optional[T]:
"""Return the output of the task.
Returns:
T: The output of the task. None if the output is empty.
"""
raise NotImplementedError
@property
def output_stream(self) -> Optional[AsyncIterator[T]]:
"""Return the output of the task as an asynchronous stream.
Returns:
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
"""
raise NotImplementedError
@abstractmethod
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
"""Set the output data to current object.
Args:
output_data (Union[T, AsyncIterator[T]]): Output data.
"""
@abstractmethod
def new_output(self) -> "TaskOutput[T]":
"""Create new output object"""
async def map(self, map_func) -> "TaskOutput[T]":
"""Apply a mapping function to the task's output.
Args:
map_func: A function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the mapping function.
"""
raise NotImplementedError
async def reduce(self, reduce_func) -> "TaskOutput[T]":
"""Apply a reducing function to the task's output.
        Reduces a stream TaskOutput to a non-stream TaskOutput.
Args:
reduce_func: A reducing function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the reducing function.
"""
raise NotImplementedError
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
Args:
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
Returns:
            TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
Args:
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
Returns:
            TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> "TaskOutput[T]":
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
Args:
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
Returns:
            TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def check_condition(self, condition_func) -> bool:
"""Check if current output meets a given condition.
Args:
condition_func: A function to determine if the condition is met.
Returns:
            bool: True if the current output meets the condition, False otherwise.
"""
raise NotImplementedError
class TaskContext(ABC, Generic[T]):
"""Abstract base class representing the context of a task within a DAG.
This class provides the interface for accessing task-related information
and manipulating task output.
"""
@property
@abstractmethod
def task_id(self) -> str:
"""Return the unique identifier of the task.
Returns:
str: The unique identifier of the task.
"""
@property
@abstractmethod
def task_input(self) -> "InputContext":
"""Return the InputContext of current task.
Returns:
InputContext: The InputContext of current task.
"""
@abstractmethod
def set_task_input(self, input_ctx: "InputContext") -> None:
"""Set the InputContext object to current task.
Args:
input_ctx (InputContext): The InputContext of current task
"""
@property
@abstractmethod
def task_output(self) -> TaskOutput[T]:
"""Return the output object of the task.
Returns:
TaskOutput[T]: The output object of the task.
"""
@abstractmethod
def set_task_output(self, task_output: TaskOutput[T]) -> None:
"""Set the output object to current task."""
@property
@abstractmethod
def current_state(self) -> TaskState:
"""Get the current state of the task.
Returns:
TaskState: The current state of the task.
"""
@abstractmethod
def set_current_state(self, task_state: TaskState) -> None:
"""Set current task state
Args:
task_state (TaskState): The task state to be set.
"""
@abstractmethod
def new_ctx(self) -> "TaskContext":
"""Create new task context
Returns:
TaskContext: A new instance of a TaskContext.
"""
@property
@abstractmethod
def metadata(self) -> Dict[str, Any]:
"""Get the metadata of current task
Returns:
Dict[str, Any]: The metadata
"""
def update_metadata(self, key: str, value: Any) -> None:
"""Update metadata with key and value
Args:
key (str): The key of metadata
            value (Any): The value to add to the metadata
"""
self.metadata[key] = value
@property
def call_data(self) -> Optional[Dict]:
"""Get the call data for current data"""
return self.metadata.get("call_data")
@abstractmethod
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
"""Get the call data for current data"""
def set_call_data(self, call_data: Dict) -> None:
"""Set call data for current task"""
self.update_metadata("call_data", call_data)
class InputContext(ABC):
"""Abstract base class representing the context of inputs to a operator node.
This class defines methods to manipulate and access the inputs for a operator node.
"""
@property
@abstractmethod
def parent_outputs(self) -> List[TaskContext]:
"""Get the outputs from the parent nodes.
Returns:
List[TaskContext]: A list of contexts of the parent nodes' outputs.
"""
@abstractmethod
async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
"""Apply a mapping function to the inputs.
Args:
map_func (Callable[[Any], Any]): A function to be applied to the inputs.
Returns:
InputContext: A new InputContext instance with the mapped inputs.
"""
@abstractmethod
async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
"""Apply a mapping function to all inputs.
Args:
map_func (Callable[..., Any]): A function to be applied to all inputs.
Returns:
InputContext: A new InputContext instance with the mapped inputs.
"""
@abstractmethod
async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
"""Apply a reducing function to the inputs.
Args:
reduce_func (Callable[[Any], Any]): A function that reduces the inputs.
Returns:
InputContext: A new InputContext instance with the reduced inputs.
"""
@abstractmethod
async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
"""Filter the inputs based on a provided function.
Args:
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
Returns:
InputContext: A new InputContext instance with the filtered inputs.
"""
@abstractmethod
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
) -> "InputContext":
"""Predicate the inputs based on a provided function.
Args:
            predicate_func (Callable[[Any], bool]): A function applied to each input that returns True if the predicate holds for that input.
            failed_value (Any): The value to set when the predicate function returns False.
Returns:
InputContext: A new InputContext instance with the predicate inputs.
"""
def check_single_parent(self) -> bool:
"""Check if there is only a single parent output.
Returns:
bool: True if there is only one parent output, False otherwise.
"""
return len(self.parent_outputs) == 1
def check_stream(self, skip_empty: bool = False) -> bool:
"""Check if all parent outputs are streams.
Args:
skip_empty (bool): Skip empty output or not.
Returns:
bool: True if all parent outputs are streams, False otherwise.
"""
for out in self.parent_outputs:
if out.task_output.is_empty and skip_empty:
continue
if not (out.task_output and out.task_output.is_stream):
return False
return True
class InputSource(ABC, Generic[T]):
"""Abstract base class representing the source of inputs to a DAG node."""
@abstractmethod
async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
"""Read the data from current input source.
Returns:
TaskOutput[T]: The output object read from current source
"""

dbgpt/core/awel/task/task_impl.py

@@ -0,0 +1,348 @@
from abc import ABC, abstractmethod
from typing import (
Callable,
Coroutine,
Iterator,
AsyncIterator,
List,
Generic,
TypeVar,
Any,
Tuple,
Dict,
Union,
Optional,
)
import asyncio
import logging
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
logger = logging.getLogger(__name__)
async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
# Init accumulator
try:
accumulator = await stream.__anext__()
except StopAsyncIteration:
raise ValueError("Stream is empty")
is_async = asyncio.iscoroutinefunction(reduce_function)
async for element in stream:
if is_async:
accumulator = await reduce_function(accumulator, element)
else:
accumulator = reduce_function(accumulator, element)
return accumulator
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: T) -> None:
super().__init__()
self._data = data
@property
def output(self) -> T:
return self._data
    def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
self._data = output_data
def new_output(self) -> TaskOutput[T]:
return SimpleTaskOutput(None)
@property
def is_empty(self) -> bool:
return not self._data
async def _apply_func(self, func) -> Any:
if asyncio.iscoroutinefunction(func):
out = await func(self._data)
else:
out = func(self._data)
return out
async def map(self, map_func) -> TaskOutput[T]:
out = await self._apply_func(map_func)
return SimpleTaskOutput(out)
async def check_condition(self, condition_func) -> bool:
return await self._apply_func(condition_func)
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> TaskOutput[T]:
out = await self._apply_func(transform_func)
return SimpleStreamTaskOutput(out)
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: AsyncIterator[T]) -> None:
super().__init__()
self._data = data
@property
def is_stream(self) -> bool:
return True
@property
def is_empty(self) -> bool:
return not self._data
@property
def output_stream(self) -> AsyncIterator[T]:
return self._data
    def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
self._data = output_data
def new_output(self) -> TaskOutput[T]:
return SimpleStreamTaskOutput(None)
async def map(self, map_func) -> TaskOutput[T]:
is_async = asyncio.iscoroutinefunction(map_func)
async def new_iter() -> AsyncIterator[T]:
async for out in self._data:
if is_async:
out = await map_func(out)
else:
out = map_func(out)
yield out
return SimpleStreamTaskOutput(new_iter())
async def reduce(self, reduce_func) -> TaskOutput[T]:
out = await _reduce_stream(self._data, reduce_func)
return SimpleTaskOutput(out)
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> TaskOutput[T]:
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
else:
out = transform_func(self._data)
return SimpleTaskOutput(out)
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> TaskOutput[T]:
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
else:
out = transform_func(self._data)
return SimpleStreamTaskOutput(out)
def _is_async_iterator(obj):
return (
hasattr(obj, "__anext__")
and callable(getattr(obj, "__anext__", None))
and hasattr(obj, "__aiter__")
and callable(getattr(obj, "__aiter__", None))
)
class BaseInputSource(InputSource, ABC):
def __init__(self) -> None:
super().__init__()
self._is_read = False
@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
"""Read data with task context"""
async def read(self, task_ctx: TaskContext) -> TaskOutput:
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output = SimpleStreamTaskOutput(data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
return output
class SimpleInputSource(BaseInputSource):
def __init__(self, data: Any) -> None:
super().__init__()
self._data = data
def _read_data(self, task_ctx: TaskContext) -> Any:
return self._data
class SimpleCallDataInputSource(BaseInputSource):
def __init__(self) -> None:
super().__init__()
def _read_data(self, task_ctx: TaskContext) -> Any:
call_data = task_ctx.call_data
data = call_data.get("data") if call_data else None
if not (call_data and data):
raise ValueError("No call data for current SimpleCallDataInputSource")
return data
class DefaultTaskContext(TaskContext, Generic[T]):
def __init__(
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
) -> None:
super().__init__()
self._task_id = task_id
self._task_state = task_state
self._output = task_output
self._task_input = None
self._metadata = {}
@property
def task_id(self) -> str:
return self._task_id
@property
def task_input(self) -> InputContext:
return self._task_input
def set_task_input(self, input_ctx: "InputContext") -> None:
self._task_input = input_ctx
@property
def task_output(self) -> TaskOutput:
return self._output
def set_task_output(self, task_output: TaskOutput) -> None:
self._output = task_output
@property
def current_state(self) -> TaskState:
return self._task_state
def set_current_state(self, task_state: TaskState) -> None:
self._task_state = task_state
def new_ctx(self) -> TaskContext:
new_output = self._output.new_output()
return DefaultTaskContext(self._task_id, self._task_state, new_output)
@property
def metadata(self) -> Dict[str, Any]:
return self._metadata
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
"""Get the call data for current data"""
call_data = self.call_data
if not call_data:
return None
input_source = SimpleCallDataInputSource()
return await input_source.read(self)
class DefaultInputContext(InputContext):
def __init__(self, outputs: List[TaskContext]) -> None:
super().__init__()
self._outputs = outputs
@property
def parent_outputs(self) -> List[TaskContext]:
return self._outputs
async def _apply_func(
self, func: Callable[[Any], Any], apply_type: str = "map"
) -> Tuple[List[TaskContext], List[TaskOutput]]:
new_outputs: List[TaskContext] = []
map_tasks = []
for out in self._outputs:
new_outputs.append(out.new_ctx())
result = None
if apply_type == "map":
result = out.task_output.map(func)
elif apply_type == "reduce":
result = out.task_output.reduce(func)
elif apply_type == "check_condition":
result = out.task_output.check_condition(func)
else:
raise ValueError(f"Unsupport apply type {apply_type}")
map_tasks.append(result)
results = await asyncio.gather(*map_tasks)
return new_outputs, results
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
new_outputs, results = await self._apply_func(map_func)
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
if not self._outputs:
return DefaultInputContext([])
# Some parent may be empty
not_empty_idx = 0
for i, p in enumerate(self._outputs):
if p.task_output.is_empty:
continue
not_empty_idx = i
break
# All output is empty?
        is_stream = self._outputs[not_empty_idx].task_output.is_stream
        if is_stream:
if not self.check_stream(skip_empty=True):
raise ValueError(
"The output in all tasks must has same output format to map_all"
)
outputs = []
for out in self._outputs:
if out.task_output.is_stream:
outputs.append(out.task_output.output_stream)
else:
outputs.append(out.task_output.output)
if asyncio.iscoroutinefunction(map_func):
map_res = await map_func(*outputs)
else:
map_res = map_func(*outputs)
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
single_output.task_output.set_output(map_res)
logger.debug(
f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
)
return DefaultInputContext([single_output])
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
if not self.check_stream():
raise ValueError(
"The output in all tasks must has same output format of stream to apply reduce function"
)
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
new_outputs, results = await self._apply_func(
filter_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
if results[i]:
result_outputs.append(task_ctx)
return DefaultInputContext(result_outputs)
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
) -> "InputContext":
new_outputs, results = await self._apply_func(
predicate_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
if results[i]:
task_ctx.task_output.set_output(True)
result_outputs.append(task_ctx)
else:
task_ctx.task_output.set_output(failed_value)
result_outputs.append(task_ctx)
return DefaultInputContext(result_outputs)
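A minimal usage sketch (not part of this file): the combinators above are normally reached through operator composition rather than called directly. The import names are assumed to be exported from dbgpt.core.awel, and the DAG name, input value, and doubling function are illustrative only.

import asyncio

from dbgpt.core.awel import (
    DAG,
    DefaultWorkflowRunner,
    InputOperator,
    MapOperator,
    SimpleInputSource,
)


async def _run_map_sketch() -> None:
    with DAG("awel_map_sketch"):
        input_node = InputOperator(SimpleInputSource(2))
        map_node = MapOperator(lambda x: x * 2)
        input_node >> map_node
    runner = DefaultWorkflowRunner()
    res = await runner.execute_workflow(map_node)
    # Expected result: MapOperator applies the function through InputContext.map.
    assert res.current_task_context.task_output.output == 4


if __name__ == "__main__":
    asyncio.run(_run_map_sketch())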

View File

View File

@@ -0,0 +1,102 @@
import pytest
import pytest_asyncio
from typing import AsyncIterator, List
from contextlib import contextmanager, asynccontextmanager
from .. import (
WorkflowRunner,
InputOperator,
DAGContext,
TaskState,
DefaultWorkflowRunner,
SimpleInputSource,
)
from ..task.task_impl import _is_async_iterator
@pytest.fixture
def runner():
return DefaultWorkflowRunner()
def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
iters = []
for _ in range(num_nodes):
async def stream_iter():
for i in range(10):
yield i
stream_iter = stream_iter()
assert _is_async_iterator(stream_iter)
iters.append(stream_iter)
return iters
def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
    iters = []
    for single_stream in output_streams:
        # Bind the current stream via a default argument; a free reference to the
        # loop variable would be resolved lazily, so every iterator would replay
        # the last stream.
        async def stream_iter(data=single_stream):
            for i in data:
                yield i
        stream_iter = stream_iter()
        assert _is_async_iterator(stream_iter)
        iters.append(stream_iter)
    return iters
@asynccontextmanager
async def _create_input_node(**kwargs):
num_nodes = kwargs.get("num_nodes")
is_stream = kwargs.get("is_stream", False)
if is_stream:
outputs = kwargs.get("output_streams")
if outputs:
if num_nodes and num_nodes != len(outputs):
raise ValueError(
f"num_nodes {num_nodes} != the length of output_streams {len(outputs)}"
)
outputs = _create_stream_from(outputs)
else:
num_nodes = num_nodes or 1
outputs = _create_stream(num_nodes)
else:
outputs = kwargs.get("outputs", ["Hello."])
nodes = []
for output in outputs:
print(f"output: {output}")
input_source = SimpleInputSource(output)
input_node = InputOperator(input_source)
nodes.append(input_node)
yield nodes
@pytest_asyncio.fixture
async def input_node(request):
param = getattr(request, "param", {})
async with _create_input_node(**param) as input_nodes:
yield input_nodes[0]
@pytest_asyncio.fixture
async def stream_input_node(request):
param = getattr(request, "param", {})
param["is_stream"] = True
async with _create_input_node(**param) as input_nodes:
yield input_nodes[0]
@pytest_asyncio.fixture
async def input_nodes(request):
param = getattr(request, "param", {})
async with _create_input_node(**param) as input_nodes:
yield input_nodes
@pytest_asyncio.fixture
async def stream_input_nodes(request):
param = getattr(request, "param", {})
param["is_stream"] = True
async with _create_input_node(**param) as input_nodes:
yield input_nodes
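For reference, these fixtures are configured per test through pytest's indirect parametrization: the parametrized value becomes request.param and is forwarded to _create_input_node as keyword arguments. A minimal sketch (the test name and stream contents are illustrative; the tests in this commit follow the same pattern):

import pytest


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "stream_input_node",
    [{"output_streams": [[1, 2, 3]]}],
    indirect=["stream_input_node"],
)
async def test_stream_fixture_sketch(stream_input_node):
    # The fixture yields an InputOperator whose input source replays [1, 2, 3].
    assert stream_input_node is not None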

View File

@@ -0,0 +1,51 @@
import pytest
from typing import List
from .. import (
DAG,
WorkflowRunner,
DAGContext,
TaskState,
InputOperator,
MapOperator,
JoinOperator,
BranchOperator,
ReduceStreamOperator,
SimpleInputSource,
)
from .conftest import (
runner,
input_node,
input_nodes,
stream_input_node,
stream_input_nodes,
_is_async_iterator,
)
def _register_dag_to_fastapi_app(dag):
# TODO
pass
@pytest.mark.asyncio
async def test_http_operator(runner: WorkflowRunner, stream_input_node: InputOperator):
with DAG("test_map") as dag:
pass
# http_req_task = HttpRequestOperator(endpoint="/api/completions")
# db_task = DBQueryOperator(table_name="user_info")
# prompt_task = PromptTemplateOperator(
# system_prompt="You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
# )
# llm_task = ChatGPTLLMOperator(model="chagpt-3.5")
# output_parser_task = CommonOutputParserOperator()
# http_res_task = HttpResponseOperator()
# (
# http_req_task
# >> db_task
# >> prompt_task
# >> llm_task
# >> output_parser_task
# >> http_res_task
# )
_register_dag_to_fastapi_app(dag)

View File

@@ -0,0 +1,141 @@
import pytest
from typing import List
from .. import (
DAG,
WorkflowRunner,
DAGContext,
TaskState,
InputOperator,
MapOperator,
JoinOperator,
BranchOperator,
ReduceStreamOperator,
SimpleInputSource,
)
from .conftest import (
runner,
input_node,
input_nodes,
stream_input_node,
stream_input_nodes,
_is_async_iterator,
)
@pytest.mark.asyncio
async def test_input_node(runner: WorkflowRunner):
input_node = InputOperator(SimpleInputSource("hello"))
res: DAGContext[str] = await runner.execute_workflow(input_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert res.current_task_context.task_output.output == "hello"
    async def new_stream_iter(n: int):
        for i in range(n):
            yield i
    num_iter = 10
    stream_input_node = InputOperator(SimpleInputSource(new_stream_iter(num_iter)))
    res: DAGContext[int] = await runner.execute_workflow(stream_input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    output_stream = res.current_task_context.task_output.output_stream
    assert output_stream
    assert _is_async_iterator(output_stream)
    i = 0
    async for x in output_stream:
assert x == i
i += 1
@pytest.mark.asyncio
async def test_map_node(runner: WorkflowRunner, stream_input_node: InputOperator):
with DAG("test_map") as dag:
map_node = MapOperator(lambda x: x * 2)
stream_input_node >> map_node
res: DAGContext[int] = await runner.execute_workflow(map_node)
    output_stream = res.current_task_context.task_output.output_stream
    assert output_stream
    i = 0
    async for x in output_stream:
assert x == i * 2
i += 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
"stream_input_node, expect_sum",
[
({"output_streams": [[0, 1, 2, 3]]}, 6),
({"output_streams": [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]}, 55),
],
indirect=["stream_input_node"],
)
async def test_reduce_node(
runner: WorkflowRunner, stream_input_node: InputOperator, expect_sum: int
):
with DAG("test_reduce_node") as dag:
reduce_node = ReduceStreamOperator(lambda x, y: x + y)
stream_input_node >> reduce_node
res: DAGContext[int] = await runner.execute_workflow(reduce_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert not res.current_task_context.task_output.is_stream
assert res.current_task_context.task_output.output == expect_sum
@pytest.mark.asyncio
@pytest.mark.parametrize(
"input_nodes",
[
({"outputs": [0, 1, 2]}),
],
indirect=["input_nodes"],
)
async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator]):
def join_func(p1, p2, p3) -> int:
return p1 + p2 + p3
with DAG("test_join_node") as dag:
join_node = JoinOperator(join_func)
for input_node in input_nodes:
input_node >> join_node
res: DAGContext[int] = await runner.execute_workflow(join_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert not res.current_task_context.task_output.is_stream
assert res.current_task_context.task_output.output == 3
@pytest.mark.asyncio
@pytest.mark.parametrize(
"input_node, is_odd",
[
({"outputs": [0]}, False),
({"outputs": [1]}, True),
],
indirect=["input_node"],
)
async def test_branch_node(
runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
):
def join_func(o1, o2) -> int:
print(f"join func result, o1: {o1}, o2: {o2}")
return o1 or o2
with DAG("test_join_node") as dag:
odd_node = MapOperator(
lambda x: 999, task_id="odd_node", task_name="odd_node_name"
)
even_node = MapOperator(
lambda x: 888, task_id="even_node", task_name="even_node_name"
)
join_node = JoinOperator(join_func)
branch_node = BranchOperator(
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
)
branch_node >> odd_node >> join_node
branch_node >> even_node >> join_node
input_node >> branch_node
res: DAGContext[int] = await runner.execute_workflow(join_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
expect_res = 999 if is_odd else 888
assert res.current_task_context.task_output.output == expect_res

View File

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from ..operator.common_operator import TriggerOperator
class Trigger(TriggerOperator, ABC):
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@@ -0,0 +1,137 @@
from __future__ import annotations
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
from starlette.requests import Request
from starlette.responses import Response
from dbgpt._private.pydantic import BaseModel
import logging
from .base import Trigger
from ..dag.base import DAG
from ..operator.base import BaseOperator
if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI
RequestBody = Union[Request, Type[BaseModel], str]
logger = logging.getLogger(__name__)
class HttpTrigger(Trigger):
def __init__(
self,
endpoint: str,
methods: Optional[Union[str, List[str]]] = "GET",
request_body: Optional[RequestBody] = None,
streaming_response: Optional[bool] = False,
response_model: Optional[Type] = None,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
status_code: Optional[int] = 200,
router_tags: Optional[List[str]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
if not endpoint.startswith("/"):
endpoint = "/" + endpoint
self._endpoint = endpoint
self._methods = methods
self._req_body = request_body
self._streaming_response = streaming_response
self._response_model = response_model
self._status_code = status_code
self._router_tags = router_tags
self._response_headers = response_headers
self._response_media_type = response_media_type
        self._end_node: Optional[BaseOperator] = None
async def trigger(self) -> None:
pass
def mount_to_router(self, router: "APIRouter") -> None:
from fastapi import Depends
methods = self._methods if isinstance(self._methods, list) else [self._methods]
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
async def _request_body_dependency(request: Request):
return await _parse_request_body(request, self._req_body)
async def route_function(body=Depends(_request_body_dependency)):
return await _trigger_dag(
body,
self.dag,
self._streaming_response,
self._response_headers,
self._response_media_type,
)
route_function.__name__ = name
return route_function
function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
request_model = (
self._req_body
if isinstance(self._req_body, type)
and issubclass(self._req_body, BaseModel)
else None
)
dynamic_route_function = create_route_function(function_name, request_model)
logger.info(
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
)
router.api_route(
self._endpoint,
methods=methods,
response_model=self._response_model,
status_code=self._status_code,
tags=self._router_tags,
)(dynamic_route_function)
async def _parse_request_body(
request: Request, request_body_cls: Optional[Type[BaseModel]]
):
if not request_body_cls:
return None
if request.method == "POST":
json_data = await request.json()
return request_body_cls(**json_data)
elif request.method == "GET":
return request_body_cls(**request.query_params)
else:
return request
async def _trigger_dag(
body: Any,
dag: DAG,
streaming_response: Optional[bool] = False,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
) -> Any:
from fastapi.responses import StreamingResponse
    leaf_nodes = dag.leaf_nodes
    if len(leaf_nodes) != 1:
        raise ValueError("HttpTrigger only supports exactly one leaf node in the DAG")
    end_node = leaf_nodes[0]
if not streaming_response:
return await end_node.call(call_data={"data": body})
else:
headers = response_headers
media_type = response_media_type if response_media_type else "text/event-stream"
if not headers:
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
headers=headers,
media_type=media_type,
)
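A usage sketch (an assumption, not a file in this commit): declaring an HttpTrigger inside a DAG and chaining it to a downstream operator. Only HttpTrigger's constructor parameters come from the class above; the EchoBody model, the endpoint, and the echo logic are illustrative, the import names are assumed to be exported from dbgpt.core.awel, and trigger nodes are assumed to compose with ">>" like the other operators.

from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator


class EchoBody(BaseModel):
    # Hypothetical request payload for the illustrative endpoint.
    message: str


with DAG("awel_http_echo_sketch") as dag:
    trigger = HttpTrigger(
        "/examples/echo",
        methods="POST",
        request_body=EchoBody,
        streaming_response=False,
    )
    echo_task = MapOperator(lambda body: {"echo": body.message})
    # The trigger feeds the parsed request body into the downstream operator.
    trigger >> echo_task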

View File

@@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING, Optional
import logging
if TYPE_CHECKING:
from fastapi import APIRouter
from dbgpt.component import SystemApp, BaseComponent, ComponentType
logger = logging.getLogger(__name__)
class TriggerManager(ABC):
@abstractmethod
def register_trigger(self, trigger: Any) -> None:
""" "Register a trigger to current manager"""
class HttpTriggerManager(TriggerManager):
def __init__(
self,
router: Optional["APIRouter"] = None,
router_prefix: Optional[str] = "/api/v1/awel/trigger",
) -> None:
if not router:
from fastapi import APIRouter
router = APIRouter()
self._router_prefix = router_prefix
self._router = router
self._trigger_map = {}
def register_trigger(self, trigger: Any) -> None:
from .http_trigger import HttpTrigger
if not isinstance(trigger, HttpTrigger):
raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger")
trigger: HttpTrigger = trigger
trigger_id = trigger.node_id
if trigger_id not in self._trigger_map:
trigger.mount_to_router(self._router)
self._trigger_map[trigger_id] = trigger
def _init_app(self, system_app: SystemApp):
logger.info(
f"Include router {self._router} to prefix path {self._router_prefix}"
)
system_app.app.include_router(
self._router, prefix=self._router_prefix, tags=["AWEL"]
)
class DefaultTriggerManager(TriggerManager, BaseComponent):
name = ComponentType.AWEL_TRIGGER_MANAGER
    def __init__(self, system_app: Optional[SystemApp] = None):
self.system_app = system_app
self.http_trigger = HttpTriggerManager()
super().__init__(None)
def init_app(self, system_app: SystemApp):
self.system_app = system_app
def register_trigger(self, trigger: Any) -> None:
from .http_trigger import HttpTrigger
if isinstance(trigger, HttpTrigger):
logger.info(f"Register trigger {trigger}")
self.http_trigger.register_trigger(trigger)
else:
raise ValueError(f"Unsupport trigger: {trigger}")
def after_register(self) -> None:
self.http_trigger._init_app(self.system_app)
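A registration sketch (an assumption, not part of this file): mounting the HTTP trigger routes on a bare FastAPI app without going through SystemApp. It reaches into the manager's private _router attribute purely for illustration, and the hardcoded prefix simply repeats the manager's default; in the normal flow the router is included by _init_app after component registration.

from fastapi import FastAPI

from dbgpt.core.awel.trigger.trigger_manager import HttpTriggerManager

app = FastAPI()
manager = HttpTriggerManager()

# `trigger` stands for an HttpTrigger created inside a DAG, e.g. the echo sketch
# above; registering it mounts its route on the manager's internal router.
# manager.register_trigger(trigger)

# Expose all registered trigger routes under the manager's default prefix.
app.include_router(manager._router, prefix="/api/v1/awel/trigger", tags=["AWEL"])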