Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-16 22:51:24 +00:00)
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
dbgpt/core/awel/__init__.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""Agentic Workflow Expression Language (AWEL)

Note:

    AWEL is still an experimental feature and only exposes the lowest-level API.
    The stability of this API cannot be guaranteed at present.

"""

from dbgpt.component import SystemApp

from .dag.base import DAGContext, DAG

from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import (
    JoinOperator,
    ReduceStreamOperator,
    MapOperator,
    BranchOperator,
    InputOperator,
    BranchFunc,
)

from .operator.stream_operator import (
    StreamifyAbsOperator,
    UnstreamifyAbsOperator,
    TransformStreamAbsOperator,
)

from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
from .task.task_impl import (
    SimpleInputSource,
    SimpleCallDataInputSource,
    DefaultTaskContext,
    DefaultInputContext,
    SimpleTaskOutput,
    SimpleStreamTaskOutput,
    _is_async_iterator,
)
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner

__all__ = [
    "initialize_awel",
    "DAGContext",
    "DAG",
    "BaseOperator",
    "JoinOperator",
    "ReduceStreamOperator",
    "MapOperator",
    "BranchOperator",
    "InputOperator",
    "BranchFunc",
    "WorkflowRunner",
    "TaskState",
    "TaskOutput",
    "TaskContext",
    "InputContext",
    "InputSource",
    "DefaultWorkflowRunner",
    "SimpleInputSource",
    "SimpleCallDataInputSource",
    "DefaultTaskContext",
    "DefaultInputContext",
    "SimpleTaskOutput",
    "SimpleStreamTaskOutput",
    "StreamifyAbsOperator",
    "UnstreamifyAbsOperator",
    "TransformStreamAbsOperator",
    "HttpTrigger",
]


def initialize_awel(system_app: SystemApp, dag_filepath: str):
    from .dag.dag_manager import DAGManager
    from .dag.base import DAGVar
    from .trigger.trigger_manager import DefaultTriggerManager
    from .operator.base import initialize_runner

    DAGVar.set_current_system_app(system_app)

    system_app.register(DefaultTriggerManager)
    dag_manager = DAGManager(system_app, dag_filepath)
    system_app.register_instance(dag_manager)
    initialize_runner(DefaultWorkflowRunner())
    # Load all dags
    dag_manager.load_dags()
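The package above only wires up exports and initialization. As a quick orientation, here is a minimal usage sketch, not part of this commit, showing how the exported pieces are meant to compose: an InputOperator feeds a MapOperator inside a DAG context, and `call()` runs the workflow with the DefaultWorkflowRunner. The `SimpleInputSource(3)` constructor signature is an assumption, since task_impl.py is not shown in this diff.

import asyncio

from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleInputSource


class DoubleOperator(MapOperator[int, int]):
    async def map(self, input_value: int) -> int:
        return input_value * 2


async def main():
    with DAG("example_dag"):
        # SimpleInputSource wrapping a literal value is assumed here
        input_node = InputOperator(input_source=SimpleInputSource(3))
        double_node = DoubleOperator()
        input_node >> double_node
    # call() builds the job from the end node and returns its output
    print(await double_node.call())  # expected: 6 under these assumptions


asyncio.run(main())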
dbgpt/core/awel/base.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from abc import ABC, abstractmethod


class Trigger(ABC):
    @abstractmethod
    async def trigger(self) -> None:
        """Trigger the workflow or a specific operation in the workflow."""
dbgpt/core/awel/dag/__init__.py (new empty file)
dbgpt/core/awel/dag/base.py (new file, 371 lines)
@@ -0,0 +1,371 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import logging
from collections import deque
from functools import cache
from concurrent.futures import Executor

from dbgpt.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext

logger = logging.getLogger(__name__)

DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]


def _is_async_context():
    try:
        loop = asyncio.get_running_loop()
        return asyncio.current_task(loop=loop) is not None
    except RuntimeError:
        return False


class DependencyMixin(ABC):
    @abstractmethod
    def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
        """Set one or more upstream nodes for this node.

        Args:
            nodes (DependencyType): Upstream nodes to be set to current node.

        Returns:
            DependencyMixin: Returns self to allow method chaining.

        Raises:
            ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
        """

    @abstractmethod
    def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
        """Set one or more downstream nodes for this node.

        Args:
            nodes (DependencyType): Downstream nodes to be set to current node.

        Returns:
            DependencyMixin: Returns self to allow method chaining.

        Raises:
            ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
        """

    def __lshift__(self, nodes: DependencyType) -> DependencyType:
        """Implements self << nodes.

        Example:

        .. code-block:: python

            # means node.set_upstream(input_node)
            node << input_node

            # means node2.set_upstream([input_node])
            node2 << [input_node]
        """
        self.set_upstream(nodes)
        return nodes

    def __rshift__(self, nodes: DependencyType) -> DependencyType:
        """Implements self >> nodes.

        Example:

        .. code-block:: python

            # means node.set_downstream(next_node)
            node >> next_node

            # means node2.set_downstream([next_node])
            node2 >> [next_node]

        """
        self.set_downstream(nodes)
        return nodes

    def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Implements [node] >> self"""
        self.__lshift__(nodes)
        return self

    def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Implements [node] << self"""
        self.__rshift__(nodes)
        return self


class DAGVar:
    _thread_local = threading.local()
    _async_local = contextvars.ContextVar("current_dag_stack", default=deque())
    _system_app: SystemApp = None
    _executor: Executor = None

    @classmethod
    def enter_dag(cls, dag) -> None:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            stack.append(dag)
            cls._async_local.set(stack)
        else:
            if not hasattr(cls._thread_local, "current_dag_stack"):
                cls._thread_local.current_dag_stack = deque()
            cls._thread_local.current_dag_stack.append(dag)

    @classmethod
    def exit_dag(cls) -> None:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            if stack:
                stack.pop()
                cls._async_local.set(stack)
        else:
            if (
                hasattr(cls._thread_local, "current_dag_stack")
                and cls._thread_local.current_dag_stack
            ):
                cls._thread_local.current_dag_stack.pop()

    @classmethod
    def get_current_dag(cls) -> Optional["DAG"]:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            return stack[-1] if stack else None
        else:
            if (
                hasattr(cls._thread_local, "current_dag_stack")
                and cls._thread_local.current_dag_stack
            ):
                return cls._thread_local.current_dag_stack[-1]
            return None

    @classmethod
    def get_current_system_app(cls) -> SystemApp:
        # if not cls._system_app:
        #     raise RuntimeError("System APP not set for DAGVar")
        return cls._system_app

    @classmethod
    def set_current_system_app(cls, system_app: SystemApp) -> None:
        if cls._system_app:
            logger.warn("System APP has already been set, nothing to do")
        else:
            cls._system_app = system_app

    @classmethod
    def get_executor(cls) -> Executor:
        return cls._executor

    @classmethod
    def set_executor(cls, executor: Executor) -> None:
        cls._executor = executor


class DAGNode(DependencyMixin, ABC):
    resource_group: Optional[ResourceGroup] = None
    """The resource group of current DAGNode"""

    def __init__(
        self,
        dag: Optional["DAG"] = None,
        node_id: Optional[str] = None,
        node_name: Optional[str] = None,
        system_app: Optional[SystemApp] = None,
        executor: Optional[Executor] = None,
        **kwargs
    ) -> None:
        super().__init__()
        self._upstream: List["DAGNode"] = []
        self._downstream: List["DAGNode"] = []
        self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
        self._system_app: Optional[SystemApp] = (
            system_app or DAGVar.get_current_system_app()
        )
        self._executor: Optional[Executor] = executor or DAGVar.get_executor()
        if not node_id and self._dag:
            node_id = self._dag._new_node_id()
        self._node_id: str = node_id
        self._node_name: str = node_name

    @property
    def node_id(self) -> str:
        return self._node_id

    @property
    def system_app(self) -> SystemApp:
        return self._system_app

    def set_node_id(self, node_id: str) -> None:
        self._node_id = node_id

    def __hash__(self) -> int:
        if self.node_id:
            return hash(self.node_id)
        else:
            return super().__hash__()

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, DAGNode):
            return False
        return self.node_id == other.node_id

    @property
    def node_name(self) -> str:
        return self._node_name

    @property
    def dag(self) -> "DAG":
        return self._dag

    def set_upstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes)

    def set_downstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes, is_upstream=False)

    @property
    def upstream(self) -> List["DAGNode"]:
        return self._upstream

    @property
    def downstream(self) -> List["DAGNode"]:
        return self._downstream

    def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        if not all(isinstance(node, DAGNode) for node in nodes):
            raise ValueError(
                "all nodes to set dependency to current node must be instance of 'DAGNode'"
            )
        nodes: Sequence[DAGNode] = nodes
        dags = set([node.dag for node in nodes if node.dag])
        if self.dag:
            dags.add(self.dag)
        if not dags:
            raise ValueError("set dependency to current node must be in a DAG context")
        if len(dags) != 1:
            raise ValueError(
                "set dependency to current node is only supported in one DAG context"
            )
        dag = dags.pop()
        self._dag = dag

        dag._append_node(self)
        for node in nodes:
            if is_upstream and node not in self.upstream:
                node._dag = dag
                dag._append_node(node)

                self._upstream.append(node)
                node._downstream.append(self)
            elif node not in self._downstream:
                node._dag = dag
                dag._append_node(node)

                self._downstream.append(node)
                node._upstream.append(self)


class DAGContext:
    def __init__(self, streaming_call: bool = False) -> None:
        self._streaming_call = streaming_call
        self._curr_task_ctx = None
        self._share_data: Dict[str, Any] = {}

    @property
    def current_task_context(self) -> TaskContext:
        return self._curr_task_ctx

    @property
    def streaming_call(self) -> bool:
        """Whether the current DAG is a streaming call"""
        return self._streaming_call

    def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
        self._curr_task_ctx = _curr_task_ctx

    async def get_share_data(self, key: str) -> Any:
        return self._share_data.get(key)

    async def save_to_share_data(self, key: str, data: Any) -> None:
        self._share_data[key] = data


class DAG:
    def __init__(
        self, dag_id: str, resource_group: Optional[ResourceGroup] = None
    ) -> None:
        self._dag_id = dag_id
        self.node_map: Dict[str, DAGNode] = {}
        self._root_nodes: Set[DAGNode] = None
        self._leaf_nodes: Set[DAGNode] = None
        self._trigger_nodes: Set[DAGNode] = None

    def _append_node(self, node: DAGNode) -> None:
        self.node_map[node.node_id] = node
        # clear cached nodes
        self._root_nodes = None
        self._leaf_nodes = None

    def _new_node_id(self) -> str:
        return str(uuid.uuid4())

    @property
    def dag_id(self) -> str:
        return self._dag_id

    def _build(self) -> None:
        from ..operator.common_operator import TriggerOperator

        nodes = set()
        for _, node in self.node_map.items():
            nodes = nodes.union(_get_nodes(node))
        self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
        self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes)))
        self._trigger_nodes = list(
            set(filter(lambda x: isinstance(x, TriggerOperator), nodes))
        )

    @property
    def root_nodes(self) -> List[DAGNode]:
        if not self._root_nodes:
            self._build()
        return self._root_nodes

    @property
    def leaf_nodes(self) -> List[DAGNode]:
        if not self._leaf_nodes:
            self._build()
        return self._leaf_nodes

    @property
    def trigger_nodes(self):
        if not self._trigger_nodes:
            self._build()
        return self._trigger_nodes

    def __enter__(self):
        DAGVar.enter_dag(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        DAGVar.exit_dag()


def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
    nodes = set()
    if not node:
        return nodes
    nodes.add(node)
    stream_nodes = node.upstream if is_upstream else node.downstream
    for node in stream_nodes:
        nodes = nodes.union(_get_nodes(node, is_upstream))
    return nodes
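To make the dependency semantics above concrete, here is a short sketch, not part of the commit, of how `>>` wiring and the DAG context manager interact. MapOperator (defined later in this diff) is used only as a concrete DAGNode, and the lambdas are placeholders.

from dbgpt.core.awel import DAG, MapOperator

with DAG("wiring_example") as dag:
    a = MapOperator(map_function=lambda x: x)
    b = MapOperator(map_function=lambda x: x)
    c = MapOperator(map_function=lambda x: x)
    # __rshift__ returns its argument, so chains read left to right
    a >> b >> c

# set_dependency records both directions and registers nodes with the DAG
assert b in a.downstream and a in b.upstream
assert set(dag.node_map.values()) == {a, b, c}
assert dag.root_nodes == [a] and dag.leaf_nodes == [c]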
dbgpt/core/awel/dag/dag_manager.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from typing import Dict, Optional
import logging
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG

logger = logging.getLogger(__name__)


class DAGManager(BaseComponent):
    name = ComponentType.AWEL_DAG_MANAGER

    def __init__(self, system_app: SystemApp, dag_filepath: str):
        super().__init__(system_app)
        self.dag_loader = LocalFileDAGLoader(dag_filepath)
        self.system_app = system_app
        self.dag_map: Dict[str, DAG] = {}

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def load_dags(self):
        dags = self.dag_loader.load_dags()
        triggers = []
        for dag in dags:
            dag_id = dag.dag_id
            if dag_id in self.dag_map:
                raise ValueError(f"Load DAG error, DAG ID {dag_id} already exists")
            triggers += dag.trigger_nodes
        from ..trigger.trigger_manager import DefaultTriggerManager

        trigger_manager: DefaultTriggerManager = self.system_app.get_component(
            ComponentType.AWEL_TRIGGER_MANAGER,
            DefaultTriggerManager,
            default_component=None,
        )
        if trigger_manager:
            for trigger in triggers:
                trigger_manager.register_trigger(trigger)
            trigger_manager.after_register()
        else:
            logger.warn("No trigger manager found, DAG triggers will not be registered")
dbgpt/core/awel/dag/loader.py (new file, 93 lines)
@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import List
import os
import hashlib
import sys
import logging
import traceback

from .base import DAG

logger = logging.getLogger(__name__)


class DAGLoader(ABC):
    @abstractmethod
    def load_dags(self) -> List[DAG]:
        """Load dags"""


class LocalFileDAGLoader(DAGLoader):
    def __init__(self, filepath: str) -> None:
        super().__init__()
        self._filepath = filepath

    def load_dags(self) -> List[DAG]:
        if not os.path.exists(self._filepath):
            return []
        if os.path.isdir(self._filepath):
            return _process_directory(self._filepath)
        else:
            return _process_file(self._filepath)


def _process_directory(directory: str) -> List[DAG]:
    dags = []
    for file in os.listdir(directory):
        if file.endswith(".py"):
            filepath = os.path.join(directory, file)
            dags += _process_file(filepath)
    return dags


def _process_file(filepath) -> List[DAG]:
    mods = _load_modules_from_file(filepath)
    results = _process_modules(mods)
    return results


def _load_modules_from_file(filepath: str):
    import importlib
    import importlib.machinery
    import importlib.util

    logger.info(f"Importing {filepath}")

    org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
    path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
    mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"

    if mod_name in sys.modules:
        del sys.modules[mod_name]

    def parse(mod_name, filepath):
        try:
            loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
            spec = importlib.util.spec_from_loader(mod_name, loader)
            new_module = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = new_module
            loader.exec_module(new_module)
            return [new_module]
        except Exception as e:
            msg = traceback.format_exc()
            logger.error(f"Failed to import: {filepath}, error message: {msg}")
            # TODO save error message
            return []

    return parse(mod_name, filepath)


def _process_modules(mods) -> List[DAG]:
    top_level_dags = (
        (o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
    )
    found_dags = []
    for dag, mod in top_level_dags:
        try:
            # TODO validate dag params
            logger.info(f"Found dag {dag} from mod {mod} and module file {mod.__file__}")
            found_dags.append(dag)
        except Exception:
            msg = traceback.format_exc()
            logger.error(f"Failed to process DAG file, error message: {msg}")
    return found_dags
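The loader imports each .py file and collects any module-level DAG instances, so a DAG definition file is just an ordinary module that leaves a DAG object at top level. A hypothetical example file follows; the file name and contents are illustrative, not part of this commit.

# example_dag.py -- any top-level DAG instance in this module is picked up by
# _process_modules() after LocalFileDAGLoader imports the file.
from dbgpt.core.awel import DAG, MapOperator

with DAG("loader_example_dag") as example_dag:
    task = MapOperator(map_function=lambda x: x + 1)
# `example_dag` stays in the module namespace, so load_dags() will return it.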
dbgpt/core/awel/dag/tests/__init__.py (new empty file)
dbgpt/core/awel/dag/tests/test_dag.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import pytest
import threading
import asyncio
from ..base import DAG, DAGVar


def test_dag_context_sync():
    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    with dag1:
        assert DAGVar.get_current_dag() == dag1
        with dag2:
            assert DAGVar.get_current_dag() == dag2
        assert DAGVar.get_current_dag() == dag1
    assert DAGVar.get_current_dag() is None


def test_dag_context_threading():
    def thread_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()

    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    thread1 = threading.Thread(target=thread_function, args=(dag1,))
    thread2 = threading.Thread(target=thread_function, args=(dag2,))

    thread1.start()
    thread2.start()
    thread1.join()
    thread2.join()

    assert DAGVar.get_current_dag() is None


@pytest.mark.asyncio
async def test_dag_context_async():
    async def async_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()

    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    await asyncio.gather(async_function(dag1), async_function(dag2))

    assert DAGVar.get_current_dag() is None
dbgpt/core/awel/operator/__init__.py (new empty file)
dbgpt/core/awel/operator/base.py (new file, 245 lines)
@@ -0,0 +1,245 @@
from abc import ABC, abstractmethod, ABCMeta

from types import FunctionType
from typing import (
    List,
    Generic,
    TypeVar,
    AsyncIterator,
    Iterator,
    Union,
    Any,
    Dict,
    Optional,
    cast,
)
import functools
from inspect import signature
import asyncio
from dbgpt.component import SystemApp, ComponentType
from dbgpt.util.executor_utils import (
    ExecutorFactory,
    DefaultExecutorFactory,
    blocking_func_to_async,
    BlockingFunction,
    AsyncToSyncIterator,
)

from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import TaskOutput, OUT, T

F = TypeVar("F", bound=FunctionType)

CALL_DATA = Union[Dict, Dict[str, Dict]]


class WorkflowRunner(ABC, Generic[T]):
    """Abstract base class representing a runner for executing workflows in a DAG.

    This class defines the interface for executing workflows within the DAG,
    handling the flow from one DAG node to another.
    """

    @abstractmethod
    async def execute_workflow(
        self,
        node: "BaseOperator",
        call_data: Optional[CALL_DATA] = None,
        streaming_call: bool = False,
    ) -> DAGContext:
        """Execute the workflow starting from a given operator.

        Args:
            node (BaseOperator): The starting node of the workflow to be executed.
            call_data (CALL_DATA): The data passed to the root operator node.
            streaming_call (bool): Whether the call is a streaming call.

        Returns:
            DAGContext: The context after executing the workflow, containing the final state and data.
        """


default_runner: WorkflowRunner = None


class BaseOperatorMeta(ABCMeta):
    """Metaclass of BaseOperator."""

    @classmethod
    def _apply_defaults(cls, func: F) -> F:
        sig_cache = signature(func)

        @functools.wraps(func)
        def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
            dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
            task_id: Optional[str] = kwargs.get("task_id")
            system_app: Optional[SystemApp] = (
                kwargs.get("system_app") or DAGVar.get_current_system_app()
            )
            executor = kwargs.get("executor") or DAGVar.get_executor()
            if not executor:
                if system_app:
                    executor = system_app.get_component(
                        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
                    ).create()
                else:
                    executor = DefaultExecutorFactory().create()
                DAGVar.set_executor(executor)

            if not task_id and dag:
                task_id = dag._new_node_id()
            runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
            # print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}")
            # for arg in sig_cache.parameters:
            #     if arg not in kwargs:
            #         kwargs[arg] = default_args[arg]
            if not kwargs.get("dag"):
                kwargs["dag"] = dag
            if not kwargs.get("task_id"):
                kwargs["task_id"] = task_id
            if not kwargs.get("runner"):
                kwargs["runner"] = runner
            if not kwargs.get("system_app"):
                kwargs["system_app"] = system_app
            if not kwargs.get("executor"):
                kwargs["executor"] = executor
            real_obj = func(self, *args, **kwargs)
            return real_obj

        return cast(T, apply_defaults)

    def __new__(cls, name, bases, namespace, **kwargs):
        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
        new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
        return new_cls


class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
    """Abstract base class for operator nodes that can be executed within a workflow.

    This class extends DAGNode by adding execution capabilities.
    """

    def __init__(
        self,
        task_id: Optional[str] = None,
        task_name: Optional[str] = None,
        dag: Optional[DAG] = None,
        runner: WorkflowRunner = None,
        **kwargs,
    ) -> None:
        """Initializes a BaseOperator with an optional workflow runner.

        Args:
            runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
        """
        super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
        if not runner:
            from dbgpt.core.awel import DefaultWorkflowRunner

            runner = DefaultWorkflowRunner()

        self._runner: WorkflowRunner = runner
        self._dag_ctx: DAGContext = None

    @property
    def current_dag_context(self) -> DAGContext:
        return self._dag_ctx

    async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        if not self.node_id:
            raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
        self._dag_ctx = dag_ctx
        return await self._do_run(dag_ctx)

    @abstractmethod
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """
        Abstract method to run the task within the DAG node.

        Args:
            dag_ctx (DAGContext): The context of the DAG when this node is run.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """

    async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
        """Execute the node and return the output.

        This method is a high-level wrapper for executing the node.

        Args:
            call_data (CALL_DATA): The data passed to the root operator node.

        Returns:
            OUT: The output of the node after execution.
        """
        out_ctx = await self._runner.execute_workflow(self, call_data)
        return out_ctx.current_task_context.task_output.output

    def _blocking_call(
        self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
    ) -> OUT:
        """Execute the node and return the output.

        This method is a high-level wrapper for executing the node.
        This method is only for debugging; please use the `call` method instead.

        Args:
            call_data (CALL_DATA): The data passed to the root operator node.

        Returns:
            OUT: The output of the node after execution.
        """
        from dbgpt.util.utils import get_or_create_event_loop

        if not loop:
            loop = get_or_create_event_loop()
        return loop.run_until_complete(self.call(call_data))

    async def call_stream(
        self, call_data: Optional[CALL_DATA] = None
    ) -> AsyncIterator[OUT]:
        """Execute the node and return the output as a stream.

        This method is used for nodes where the output is a stream.

        Args:
            call_data (CALL_DATA): The data passed to the root operator node.

        Returns:
            AsyncIterator[OUT]: An asynchronous iterator over the output stream.
        """
        out_ctx = await self._runner.execute_workflow(self, call_data)
        return out_ctx.current_task_context.task_output.output_stream

    def _blocking_call_stream(
        self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None
    ) -> Iterator[OUT]:
        """Execute the node and return the output as a stream.

        This method is used for nodes where the output is a stream.
        This method is only for debugging; please use the `call_stream` method instead.

        Args:
            call_data (CALL_DATA): The data passed to the root operator node.

        Returns:
            Iterator[OUT]: An iterator over the output stream.
        """
        from dbgpt.util.utils import get_or_create_event_loop

        if not loop:
            loop = get_or_create_event_loop()
        return AsyncToSyncIterator(self.call_stream(call_data), loop)

    async def blocking_func_to_async(
        self, func: BlockingFunction, *args, **kwargs
    ) -> Any:
        return await blocking_func_to_async(self._executor, func, *args, **kwargs)


def initialize_runner(runner: WorkflowRunner):
    global default_runner
    default_runner = runner
dbgpt/core/awel/operator/common_operator.py (new file, 252 lines)
@@ -0,0 +1,252 @@
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
import asyncio
import logging

from ..dag.base import DAGContext
from ..task.base import (
    TaskContext,
    TaskOutput,
    IN,
    OUT,
    InputContext,
    InputSource,
)

from .base import BaseOperator

logger = logging.getLogger(__name__)


class JoinOperator(BaseOperator, Generic[OUT]):
    """Operator that joins inputs using a custom combine function.

    This node type is useful for combining the outputs of upstream nodes.
    """

    def __init__(self, combine_function, **kwargs):
        super().__init__(**kwargs)
        if not callable(combine_function):
            raise ValueError("combine_function must be callable")
        self.combine_function = combine_function

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the join operation on the DAG context's inputs.

        Args:
            dag_ctx (DAGContext): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        input_ctx: InputContext = await curr_task_ctx.task_input.map_all(
            self.combine_function
        )
        # The joined result is stored in the first parent output
        join_output = input_ctx.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(join_output)
        return join_output


class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
    def __init__(self, reduce_function=None, **kwargs):
        """Initializes a ReduceStreamOperator with a reduce function.

        Args:
            reduce_function: A function that defines how to reduce the stream inputs.

        Raises:
            ValueError: If the reduce_function is not callable.
        """
        super().__init__(**kwargs)
        if reduce_function and not callable(reduce_function):
            raise ValueError("reduce_function must be callable")
        self.reduce_function = reduce_function

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the reduce operation on the DAG context's inputs.

        Args:
            dag_ctx (DAGContext): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        task_input = curr_task_ctx.task_input
        if not task_input.check_stream():
            raise ValueError("ReduceStreamOperator expects stream data")
        if not task_input.check_single_parent():
            raise ValueError("ReduceStreamOperator expects a single parent")

        reduce_function = self.reduce_function or self.reduce

        input_ctx: InputContext = await task_input.reduce(reduce_function)
        # The reduced result is stored in the first parent output
        reduce_output = input_ctx.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(reduce_output)
        return reduce_output

    async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
        raise NotImplementedError


class MapOperator(BaseOperator, Generic[IN, OUT]):
    """Map operator that applies a mapping function to its inputs.

    This operator transforms its input data using a provided mapping function and
    passes the transformed data downstream.
    """

    def __init__(self, map_function=None, **kwargs):
        """Initializes a MapOperator with a mapping function.

        Args:
            map_function: A function that defines how to map the input data.

        Raises:
            ValueError: If the map_function is not callable.
        """
        super().__init__(**kwargs)
        if map_function and not callable(map_function):
            raise ValueError("map_function must be callable")
        self.map_function = map_function

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the mapping operation on the DAG context's inputs.

        This method applies the mapping function to the input context and updates
        the DAG context with the new data.

        Args:
            dag_ctx (DAGContext[IN]): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.

        Raises:
            ValueError: If there is not a single parent or the map_function is not callable.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        call_data = curr_task_ctx.call_data
        if not call_data and not curr_task_ctx.task_input.check_single_parent():
            num_parents = len(curr_task_ctx.task_input.parent_outputs)
            raise ValueError(
                f"task {curr_task_ctx.task_id} MapOperator expects a single parent, current number of parents: {num_parents}"
            )
        map_function = self.map_function or self.map

        if call_data:
            call_data = await curr_task_ctx._call_data_to_output()
            output = await call_data.map(map_function)
            curr_task_ctx.set_task_output(output)
            return output

        input_ctx: InputContext = await curr_task_ctx.task_input.map(map_function)
        # The mapped result is stored in the first parent output
        output = input_ctx.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(output)
        return output

    async def map(self, input_value: IN) -> OUT:
        raise NotImplementedError


BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]


class BranchOperator(BaseOperator, Generic[IN, OUT]):
    """Operator node that branches the workflow based on a provided function.

    This node filters its input data using a branching function and
    allows for conditional paths in the workflow.
    """

    def __init__(
        self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
    ):
        """
        Initializes a BranchOperator with branching functions.

        Args:
            branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of functions that define the branching conditions.

        Raises:
            ValueError: If a branch_function is not callable.
        """
        super().__init__(**kwargs)
        if branches:
            for branch_function, value in branches.items():
                if not callable(branch_function):
                    raise ValueError("branch_function must be callable")
                if isinstance(value, BaseOperator):
                    branches[branch_function] = value.node_name or value.node_id
        self._branches = branches

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the branching operation on the DAG context's inputs.

        This method applies the branching function to the input context to determine
        the path of execution in the workflow.

        Args:
            dag_ctx (DAGContext[IN]): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        task_input = curr_task_ctx.task_input
        if task_input.check_stream():
            raise ValueError("BranchOperator expects no stream data")
        if not task_input.check_single_parent():
            raise ValueError("BranchOperator expects a single parent")

        branches = self._branches
        if not branches:
            branches = await self.branchs()

        branch_func_tasks = []
        branch_nodes: List[str] = []
        for func, node_name in branches.items():
            branch_nodes.append(node_name)
            branch_func_tasks.append(
                curr_task_ctx.task_input.predicate_map(func, failed_value=None)
            )

        branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
        parent_output = task_input.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(parent_output)
        skip_node_names = []
        for i, ctx in enumerate(branch_input_ctxs):
            node_name = branch_nodes[i]
            branch_out = ctx.parent_outputs[0].task_output
            logger.info(
                f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
            )
            if ctx.parent_outputs[0].task_output.is_empty:
                logger.info(f"Skip node name {node_name}")
                skip_node_names.append(node_name)
        curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
        return parent_output

    async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
        raise NotImplementedError


class InputOperator(BaseOperator, Generic[OUT]):
    def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
        super().__init__(**kwargs)
        self._input_source = input_source

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        task_output = await self._input_source.read(curr_task_ctx)
        curr_task_ctx.set_task_output(task_output)
        return task_output


class TriggerOperator(InputOperator, Generic[OUT]):
    def __init__(self, **kwargs) -> None:
        from ..task.task_impl import SimpleCallDataInputSource

        super().__init__(input_source=SimpleCallDataInputSource(), **kwargs)
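As a concrete illustration of how JoinOperator's map_all combines parent outputs, here is a hedged sketch, not part of the commit. It assumes the SimpleInputSource constructor wraps a literal value and that the combine function receives one positional argument per parent output (inferred from the name map_all, whose implementation is not shown in this diff).

import asyncio

from dbgpt.core.awel import DAG, InputOperator, JoinOperator, SimpleInputSource


async def main():
    with DAG("join_example"):
        left = InputOperator(input_source=SimpleInputSource(1))   # assumed signature
        right = InputOperator(input_source=SimpleInputSource(2))  # assumed signature
        # combine_function is assumed to receive one argument per parent output
        add = JoinOperator(combine_function=lambda a, b: a + b)
        left >> add
        right >> add
    print(await add.call())  # expected: 3 under these assumptions


asyncio.run(main())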
dbgpt/core/awel/operator/stream_operator.py (new file, 90 lines)
@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
from typing import Generic, AsyncIterator
from ..task.base import OUT, IN, TaskOutput, TaskContext
from ..dag.base import DAGContext
from .base import BaseOperator


class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
            self.streamify
        )
        curr_task_ctx.set_task_output(output)
        return output

    @abstractmethod
    async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
        """Convert a value of IN to an AsyncIterator[OUT].

        Args:
            input_value (IN): The data of the parent operator's output

        Example:

        .. code-block:: python

            class MyStreamOperator(StreamifyAbsOperator[int, int]):
                async def streamify(self, input_value: int) -> AsyncIterator[int]:
                    for i in range(input_value):
                        yield i
        """


class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[
            0
        ].task_output.unstreamify(self.unstreamify)
        curr_task_ctx.set_task_output(output)
        return output

    @abstractmethod
    async def unstreamify(self, input_value: AsyncIterator[IN]) -> OUT:
        """Convert a value of AsyncIterator[IN] to an OUT.

        Args:
            input_value (AsyncIterator[IN]): The data of the parent operator's output

        Example:

        .. code-block:: python

            class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
                async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
                    value_cnt = 0
                    async for v in input_value:
                        value_cnt += 1
                    return value_cnt
        """


class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[
            0
        ].task_output.transform_stream(self.transform_stream)
        curr_task_ctx.set_task_output(output)
        return output

    @abstractmethod
    async def transform_stream(
        self, input_value: AsyncIterator[IN]
    ) -> AsyncIterator[OUT]:
        """Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.

        Args:
            input_value (AsyncIterator[IN]): The data of the parent operator's output

        Example:

        .. code-block:: python

            class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
                async def transform_stream(
                    self, input_value: AsyncIterator[int]
                ) -> AsyncIterator[int]:
                    async for v in input_value:
                        yield v + 1
        """
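To show how the three stream operator bases can chain, a brief hedged sketch follows; it is not from the commit and reuses the SimpleInputSource assumption from the earlier sketches.

import asyncio
from typing import AsyncIterator

from dbgpt.core.awel import (
    DAG,
    InputOperator,
    SimpleInputSource,
    StreamifyAbsOperator,
    TransformStreamAbsOperator,
    UnstreamifyAbsOperator,
)


class CountUp(StreamifyAbsOperator[int, int]):
    async def streamify(self, input_value: int) -> AsyncIterator[int]:
        for i in range(input_value):
            yield i


class AddOne(TransformStreamAbsOperator[int, int]):
    async def transform_stream(
        self, input_value: AsyncIterator[int]
    ) -> AsyncIterator[int]:
        async for v in input_value:
            yield v + 1


class SumAll(UnstreamifyAbsOperator[int, int]):
    async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
        total = 0
        async for v in input_value:
            total += v
        return total


async def main():
    with DAG("stream_example"):
        source = InputOperator(input_source=SimpleInputSource(3))  # assumed signature
        # >> returns its right operand, so the chain evaluates to the SumAll node
        pipeline = source >> CountUp() >> AddOne() >> SumAll()
    print(await pipeline.call())  # 1 + 2 + 3 = 6 under these assumptions


asyncio.run(main())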
dbgpt/core/awel/resource/__init__.py (new empty file)
dbgpt/core/awel/resource/base.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod


class ResourceGroup(ABC):
    @property
    @abstractmethod
    def name(self) -> str:
        """The name of current resource group"""
dbgpt/core/awel/runner/__init__.py (new empty file)
dbgpt/core/awel/runner/job_manager.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from typing import List, Set, Optional, Dict
import uuid
import logging
from ..dag.base import DAG

from ..operator.base import BaseOperator, CALL_DATA

logger = logging.getLogger(__name__)


class DAGNodeInstance:
    def __init__(self, node_instance: DAG) -> None:
        pass


class DAGInstance:
    def __init__(self, dag: DAG) -> None:
        self._dag = dag


class JobManager:
    def __init__(
        self,
        root_nodes: List[BaseOperator],
        all_nodes: List[BaseOperator],
        end_node: BaseOperator,
        id2call_data: Dict[str, Dict],
    ) -> None:
        self._root_nodes = root_nodes
        self._all_nodes = all_nodes
        self._end_node = end_node
        self._id2node_data = id2call_data

    @staticmethod
    def build_from_end_node(
        end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
    ) -> "JobManager":
        nodes = _build_from_end_node(end_node)
        root_nodes = _get_root_nodes(nodes)
        id2call_data = _save_call_data(root_nodes, call_data)
        return JobManager(root_nodes, nodes, end_node, id2call_data)

    def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
        return self._id2node_data.get(node_id)


def _save_call_data(
    root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
    id2call_data = {}
    logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
    if not call_data:
        return id2call_data
    if len(root_nodes) == 1:
        node = root_nodes[0]
        logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
        id2call_data[node.node_id] = call_data
    else:
        for node in root_nodes:
            node_id = node.node_id
            logger.info(
                f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
            )
            id2call_data[node_id] = call_data.get(node_id)
    return id2call_data


def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
    nodes = []
    if isinstance(end_node, BaseOperator):
        task_id = end_node.node_id
        if not task_id:
            task_id = str(uuid.uuid4())
            end_node.set_node_id(task_id)
        nodes.append(end_node)
        for node in end_node.upstream:
            nodes += _build_from_end_node(node)
    return nodes


def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
    return list(set(filter(lambda x: not x.upstream, nodes)))
dbgpt/core/awel/runner/local_runner.py (new file, 109 lines)
@@ -0,0 +1,109 @@
from typing import Dict, Optional, Set, List
import logging

from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager

logger = logging.getLogger(__name__)


class DefaultWorkflowRunner(WorkflowRunner):
    async def execute_workflow(
        self,
        node: BaseOperator,
        call_data: Optional[CALL_DATA] = None,
        streaming_call: bool = False,
    ) -> DAGContext:
        # Create DAG context
        dag_ctx = DAGContext(streaming_call=streaming_call)
        job_manager = JobManager.build_from_end_node(node, call_data)
        logger.info(
            f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
        )
        dag = node.dag
        # Save node outputs
        node_outputs: Dict[str, TaskContext] = {}
        skip_node_ids = set()
        await self._execute_node(
            job_manager, node, dag_ctx, node_outputs, skip_node_ids
        )

        return dag_ctx

    async def _execute_node(
        self,
        job_manager: JobManager,
        node: BaseOperator,
        dag_ctx: DAGContext,
        node_outputs: Dict[str, TaskContext],
        skip_node_ids: Set[str],
    ):
        # Skip nodes that have already been run
        if node.node_id in node_outputs:
            return

        # Run all upstream nodes first
        for upstream_node in node.upstream:
            if isinstance(upstream_node, BaseOperator):
                await self._execute_node(
                    job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
                )

        inputs = [
            node_outputs[upstream_node.node_id] for upstream_node in node.upstream
        ]
        input_ctx = DefaultInputContext(inputs)
        task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
        task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))

        task_ctx.set_task_input(input_ctx)
        dag_ctx.set_current_task_context(task_ctx)
        task_ctx.set_current_state(TaskState.RUNNING)

        if node.node_id in skip_node_ids:
            task_ctx.set_current_state(TaskState.SKIP)
            task_ctx.set_task_output(SimpleTaskOutput(None))
            node_outputs[node.node_id] = task_ctx
            return
        try:
            logger.debug(
                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
            )
            await node._run(dag_ctx)
            node_outputs[node.node_id] = dag_ctx.current_task_context
            task_ctx.set_current_state(TaskState.SUCCESS)

            if isinstance(node, BranchOperator):
                skip_nodes = task_ctx.metadata.get("skip_node_names", [])
                logger.debug(
                    f"Current is branch operator, skip node names: {skip_nodes}"
                )
                _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
        except Exception as e:
            logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
            task_ctx.set_current_state(TaskState.FAILED)
            raise e


def _skip_current_downstream_by_node_name(
    branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
):
    if not skip_nodes:
        return
    for child in branch_node.downstream:
        if child.node_name in skip_nodes:
            logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
            _skip_downstream_by_id(child, skip_node_ids)


def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
    if isinstance(node, JoinOperator):
        # Do not skip join nodes
        return
    skip_node_ids.add(node.node_id)
    for child in node.downstream:
        _skip_downstream_by_id(child, skip_node_ids)
dbgpt/core/awel/task/__init__.py (new empty file)
371
dbgpt/core/awel/task/base.py
Normal file
371
dbgpt/core/awel/task/base.py
Normal file
@@ -0,0 +1,371 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TypeVar,
|
||||
Generic,
|
||||
Optional,
|
||||
AsyncIterator,
|
||||
Union,
|
||||
Callable,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
)
|
||||
|
||||
IN = TypeVar("IN")
|
||||
OUT = TypeVar("OUT")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
"""Enumeration representing the state of a task in the workflow.
|
||||
|
||||
This Enum defines various states a task can be in during its lifecycle in the DAG.
|
||||
"""
|
||||
|
||||
INIT = "init" # Initial state of the task, not yet started
|
||||
SKIP = "skip" # State indicating the task was skipped
|
||||
RUNNING = "running" # State indicating the task is currently running
|
||||
SUCCESS = "success" # State indicating the task completed successfully
|
||||
FAILED = "failed" # State indicating the task failed during execution
|
||||
|
||||
|
||||
class TaskOutput(ABC, Generic[T]):
|
||||
"""Abstract base class representing the output of a task.
|
||||
|
||||
This class encapsulates the output of a task and provides methods to access the output data.
|
||||
It can be subclassed to implement specific output behaviors.
|
||||
"""
|
||||
|
||||
@property
|
||||
def is_stream(self) -> bool:
|
||||
"""Check if the output is a stream.
|
||||
|
||||
Returns:
|
||||
bool: True if the output is a stream, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if the output is empty.
|
||||
|
||||
Returns:
|
||||
bool: True if the output is empty, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def output(self) -> Optional[T]:
|
||||
"""Return the output of the task.
|
||||
|
||||
Returns:
|
||||
T: The output of the task. None if the output is empty.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_stream(self) -> Optional[AsyncIterator[T]]:
|
||||
"""Return the output of the task as an asynchronous stream.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
|
||||
"""Set the output data to current object.
|
||||
|
||||
Args:
|
||||
output_data (Union[T, AsyncIterator[T]]): Output data.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def new_output(self) -> "TaskOutput[T]":
|
||||
"""Create new output object"""
|
||||
|
||||
async def map(self, map_func) -> "TaskOutput[T]":
|
||||
"""Apply a mapping function to the task's output.
|
||||
|
||||
Args:
|
||||
map_func: A function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the mapping function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reduce(self, reduce_func) -> "TaskOutput[T]":
|
||||
"""Apply a reducing function to the task's output.
|
||||
|
||||
Stream TaskOutput to Nonstream TaskOutput.
|
||||
|
||||
Args:
|
||||
reduce_func: A reducing function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def streamify(
|
||||
self, transform_func: Callable[[T], AsyncIterator[T]]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def transform_stream(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def unstreamify(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], T]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def check_condition(self, condition_func) -> bool:
|
||||
"""Check if current output meets a given condition.
|
||||
|
||||
Args:
|
||||
condition_func: A function to determine if the condition is met.
|
||||
Returns:
|
||||
bool: True if current output meet the condition, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TaskContext(ABC, Generic[T]):
    """Abstract base class representing the context of a task within a DAG.

    This class provides the interface for accessing task-related information
    and manipulating task output.
    """

    @property
    @abstractmethod
    def task_id(self) -> str:
        """Return the unique identifier of the task.

        Returns:
            str: The unique identifier of the task.
        """

    @property
    @abstractmethod
    def task_input(self) -> "InputContext":
        """Return the InputContext of the current task.

        Returns:
            InputContext: The InputContext of the current task.
        """

    @abstractmethod
    def set_task_input(self, input_ctx: "InputContext") -> None:
        """Set the InputContext object for the current task.

        Args:
            input_ctx (InputContext): The InputContext of the current task.
        """

    @property
    @abstractmethod
    def task_output(self) -> TaskOutput[T]:
        """Return the output object of the task.

        Returns:
            TaskOutput[T]: The output object of the task.
        """

    @abstractmethod
    def set_task_output(self, task_output: TaskOutput[T]) -> None:
        """Set the output object for the current task."""

    @property
    @abstractmethod
    def current_state(self) -> TaskState:
        """Get the current state of the task.

        Returns:
            TaskState: The current state of the task.
        """

    @abstractmethod
    def set_current_state(self, task_state: TaskState) -> None:
        """Set the current task state.

        Args:
            task_state (TaskState): The task state to be set.
        """

    @abstractmethod
    def new_ctx(self) -> "TaskContext":
        """Create a new task context.

        Returns:
            TaskContext: A new instance of a TaskContext.
        """

    @property
    @abstractmethod
    def metadata(self) -> Dict[str, Any]:
        """Get the metadata of the current task.

        Returns:
            Dict[str, Any]: The metadata.
        """

    def update_metadata(self, key: str, value: Any) -> None:
        """Update the metadata with a key and value.

        Args:
            key (str): The metadata key.
            value (Any): The value to be added to the metadata.
        """
        self.metadata[key] = value

    @property
    def call_data(self) -> Optional[Dict]:
        """Get the call data of the current task."""
        return self.metadata.get("call_data")

    @abstractmethod
    async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
        """Convert the call data of the current task into a task output."""

    def set_call_data(self, call_data: Dict) -> None:
        """Set the call data for the current task."""
        self.update_metadata("call_data", call_data)

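A quick illustration (not part of this file) of the metadata and call-data contract, using the concrete DefaultTaskContext and SimpleTaskOutput defined later in this commit; the task id and the SUCCESS state are placeholders:

from dbgpt.core.awel import DefaultTaskContext, SimpleTaskOutput, TaskState

ctx = DefaultTaskContext("demo_task", TaskState.SUCCESS, SimpleTaskOutput(None))
ctx.update_metadata("retries", 3)     # stored in ctx.metadata
ctx.set_call_data({"data": "hello"})  # stored under the "call_data" key
assert ctx.metadata["retries"] == 3
assert ctx.call_data == {"data": "hello"}
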
class InputContext(ABC):
    """Abstract base class representing the context of inputs to an operator node.

    This class defines methods to manipulate and access the inputs for an operator node.
    """

    @property
    @abstractmethod
    def parent_outputs(self) -> List[TaskContext]:
        """Get the outputs from the parent nodes.

        Returns:
            List[TaskContext]: A list of contexts of the parent nodes' outputs.
        """

    @abstractmethod
    async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
        """Apply a mapping function to the inputs.

        Args:
            map_func (Callable[[Any], Any]): A function to be applied to the inputs.

        Returns:
            InputContext: A new InputContext instance with the mapped inputs.
        """

    @abstractmethod
    async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
        """Apply a mapping function to all inputs at once.

        Args:
            map_func (Callable[..., Any]): A function to be applied to all inputs.

        Returns:
            InputContext: A new InputContext instance with the mapped inputs.
        """

    @abstractmethod
    async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
        """Apply a reducing function to the inputs.

        Args:
            reduce_func (Callable[[Any], Any]): A function that reduces the inputs.

        Returns:
            InputContext: A new InputContext instance with the reduced inputs.
        """

    @abstractmethod
    async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
        """Filter the inputs based on a provided function.

        Args:
            filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.

        Returns:
            InputContext: A new InputContext instance with the filtered inputs.
        """

    @abstractmethod
    async def predicate_map(
        self, predicate_func: Callable[[Any], bool], failed_value: Any = None
    ) -> "InputContext":
        """Apply a predicate function to the inputs.

        Args:
            predicate_func (Callable[[Any], bool]): A function that returns True for inputs that satisfy the predicate.
            failed_value (Any): The value to be set when the predicate function returns False.

        Returns:
            InputContext: A new InputContext instance with the predicate results.
        """

    def check_single_parent(self) -> bool:
        """Check if there is only a single parent output.

        Returns:
            bool: True if there is only one parent output, False otherwise.
        """
        return len(self.parent_outputs) == 1

    def check_stream(self, skip_empty: bool = False) -> bool:
        """Check if all parent outputs are streams.

        Args:
            skip_empty (bool): Whether to skip empty outputs.

        Returns:
            bool: True if all parent outputs are streams, False otherwise.
        """
        for out in self.parent_outputs:
            if out.task_output.is_empty and skip_empty:
                continue
            if not (out.task_output and out.task_output.is_stream):
                return False
        return True

class InputSource(ABC, Generic[T]):
    """Abstract base class representing the source of inputs to a DAG node."""

    @abstractmethod
    async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
        """Read the data from the current input source.

        Returns:
            TaskOutput[T]: The output object read from the current source.
        """
348
dbgpt/core/awel/task/task_impl.py
Normal file
@@ -0,0 +1,348 @@
from abc import ABC, abstractmethod
from typing import (
    Callable,
    Coroutine,
    Iterator,
    AsyncIterator,
    List,
    Generic,
    TypeVar,
    Any,
    Tuple,
    Dict,
    Union,
    Optional,
)
import asyncio
import logging

from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T

logger = logging.getLogger(__name__)


async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
    # Init accumulator
    try:
        accumulator = await stream.__anext__()
    except StopAsyncIteration:
        raise ValueError("Stream is empty")
    is_async = asyncio.iscoroutinefunction(reduce_function)
    async for element in stream:
        if is_async:
            accumulator = await reduce_function(accumulator, element)
        else:
            accumulator = reduce_function(accumulator, element)
    return accumulator

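A small sanity check (not part of this file) showing that _reduce_stream accepts both plain and coroutine reduce functions, as implemented above:

import asyncio
from typing import AsyncIterator

from dbgpt.core.awel.task.task_impl import _reduce_stream


async def _numbers(n: int) -> AsyncIterator[int]:
    for i in range(n):
        yield i


async def _async_add(x: int, y: int) -> int:
    return x + y


async def _demo() -> None:
    # Plain function: 0 + 1 + 2 + 3 == 6
    assert await _reduce_stream(_numbers(4), lambda x, y: x + y) == 6
    # Coroutine function: detected via asyncio.iscoroutinefunction and awaited
    assert await _reduce_stream(_numbers(4), _async_add) == 6


asyncio.run(_demo())
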
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
    def __init__(self, data: T) -> None:
        super().__init__()
        self._data = data

    @property
    def output(self) -> T:
        return self._data

    def set_output(self, output_data: T | AsyncIterator[T]) -> None:
        self._data = output_data

    def new_output(self) -> TaskOutput[T]:
        return SimpleTaskOutput(None)

    @property
    def is_empty(self) -> bool:
        return not self._data

    async def _apply_func(self, func) -> Any:
        if asyncio.iscoroutinefunction(func):
            out = await func(self._data)
        else:
            out = func(self._data)
        return out

    async def map(self, map_func) -> TaskOutput[T]:
        out = await self._apply_func(map_func)
        return SimpleTaskOutput(out)

    async def check_condition(self, condition_func) -> bool:
        return await self._apply_func(condition_func)

    async def streamify(
        self, transform_func: Callable[[T], AsyncIterator[T]]
    ) -> TaskOutput[T]:
        out = await self._apply_func(transform_func)
        return SimpleStreamTaskOutput(out)

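A short usage sketch (not part of this file): because of _apply_func, SimpleTaskOutput.map() accepts either a plain callable or a coroutine function:

import asyncio

from dbgpt.core.awel import SimpleTaskOutput


async def _demo() -> None:
    out = SimpleTaskOutput(2)
    doubled = await out.map(lambda x: x * 2)  # plain callable
    assert doubled.output == 4

    async def add_one(x: int) -> int:
        return x + 1

    plus_one = await out.map(add_one)         # coroutine function
    assert plus_one.output == 3
    assert await out.check_condition(lambda x: x > 0)


asyncio.run(_demo())
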
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
    def __init__(self, data: AsyncIterator[T]) -> None:
        super().__init__()
        self._data = data

    @property
    def is_stream(self) -> bool:
        return True

    @property
    def is_empty(self) -> bool:
        return not self._data

    @property
    def output_stream(self) -> AsyncIterator[T]:
        return self._data

    def set_output(self, output_data: T | AsyncIterator[T]) -> None:
        self._data = output_data

    def new_output(self) -> TaskOutput[T]:
        return SimpleStreamTaskOutput(None)

    async def map(self, map_func) -> TaskOutput[T]:
        is_async = asyncio.iscoroutinefunction(map_func)

        async def new_iter() -> AsyncIterator[T]:
            async for out in self._data:
                if is_async:
                    out = await map_func(out)
                else:
                    out = map_func(out)
                yield out

        return SimpleStreamTaskOutput(new_iter())

    async def reduce(self, reduce_func) -> TaskOutput[T]:
        out = await _reduce_stream(self._data, reduce_func)
        return SimpleTaskOutput(out)

    async def unstreamify(
        self, transform_func: Callable[[AsyncIterator[T]], T]
    ) -> TaskOutput[T]:
        if asyncio.iscoroutinefunction(transform_func):
            out = await transform_func(self._data)
        else:
            out = transform_func(self._data)
        return SimpleTaskOutput(out)

    async def transform_stream(
        self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
    ) -> TaskOutput[T]:
        if asyncio.iscoroutinefunction(transform_func):
            out = await transform_func(self._data)
        else:
            out = transform_func(self._data)
        return SimpleStreamTaskOutput(out)

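A usage sketch (not part of this file): mapping keeps the output a stream, while reduce() collapses it into a SimpleTaskOutput:

import asyncio
from typing import AsyncIterator

from dbgpt.core.awel import SimpleStreamTaskOutput


async def _numbers() -> AsyncIterator[int]:
    for i in range(4):
        yield i


async def _demo() -> None:
    stream_out = SimpleStreamTaskOutput(_numbers())
    doubled = await stream_out.map(lambda x: x * 2)   # still a stream output
    total = await doubled.reduce(lambda x, y: x + y)  # collapses to SimpleTaskOutput
    assert total.output == 12                         # 0 + 2 + 4 + 6


asyncio.run(_demo())
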
def _is_async_iterator(obj):
    return (
        hasattr(obj, "__anext__")
        and callable(getattr(obj, "__anext__", None))
        and hasattr(obj, "__aiter__")
        and callable(getattr(obj, "__aiter__", None))
    )

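A quick check (not part of this file) of what _is_async_iterator accepts: async generators qualify, ordinary iterables do not:

from dbgpt.core.awel.task.task_impl import _is_async_iterator


async def _gen():
    yield 1


assert _is_async_iterator(_gen())               # async generator object: yes
assert not _is_async_iterator([1, 2, 3])        # plain list: no
assert not _is_async_iterator(iter([1, 2, 3]))  # sync iterator: no
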
class BaseInputSource(InputSource, ABC):
    def __init__(self) -> None:
        super().__init__()
        self._is_read = False

    @abstractmethod
    def _read_data(self, task_ctx: TaskContext) -> Any:
        """Read data with task context"""

    async def read(self, task_ctx: TaskContext) -> TaskOutput:
        data = self._read_data(task_ctx)
        if _is_async_iterator(data):
            if self._is_read:
                raise ValueError(f"Input iterator {data} has been read!")
            output = SimpleStreamTaskOutput(data)
        else:
            output = SimpleTaskOutput(data)
        self._is_read = True
        return output


class SimpleInputSource(BaseInputSource):
    def __init__(self, data: Any) -> None:
        super().__init__()
        self._data = data

    def _read_data(self, task_ctx: TaskContext) -> Any:
        return self._data


class SimpleCallDataInputSource(BaseInputSource):
    def __init__(self) -> None:
        super().__init__()

    def _read_data(self, task_ctx: TaskContext) -> Any:
        call_data = task_ctx.call_data
        data = call_data.get("data") if call_data else None
        if not (call_data and data):
            raise ValueError("No call data for current SimpleCallDataInputSource")
        return data

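A usage sketch (not part of this file) of reading a plain value through SimpleInputSource; the task id and the SUCCESS state are placeholders:

import asyncio

from dbgpt.core.awel import (
    DefaultTaskContext,
    SimpleInputSource,
    SimpleTaskOutput,
    TaskState,
)


async def _demo() -> None:
    ctx = DefaultTaskContext("input_task", TaskState.SUCCESS, SimpleTaskOutput(None))
    source = SimpleInputSource("Hello, AWEL")
    output = await source.read(ctx)
    assert not output.is_stream          # a plain value becomes a SimpleTaskOutput
    assert output.output == "Hello, AWEL"


asyncio.run(_demo())
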
class DefaultTaskContext(TaskContext, Generic[T]):
    def __init__(
        self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
    ) -> None:
        super().__init__()
        self._task_id = task_id
        self._task_state = task_state
        self._output = task_output
        self._task_input = None
        self._metadata = {}

    @property
    def task_id(self) -> str:
        return self._task_id

    @property
    def task_input(self) -> InputContext:
        return self._task_input

    def set_task_input(self, input_ctx: "InputContext") -> None:
        self._task_input = input_ctx

    @property
    def task_output(self) -> TaskOutput:
        return self._output

    def set_task_output(self, task_output: TaskOutput) -> None:
        self._output = task_output

    @property
    def current_state(self) -> TaskState:
        return self._task_state

    def set_current_state(self, task_state: TaskState) -> None:
        self._task_state = task_state

    def new_ctx(self) -> TaskContext:
        new_output = self._output.new_output()
        return DefaultTaskContext(self._task_id, self._task_state, new_output)

    @property
    def metadata(self) -> Dict[str, Any]:
        return self._metadata

    async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
        """Convert the call data of the current task into a task output."""
        call_data = self.call_data
        if not call_data:
            return None
        input_source = SimpleCallDataInputSource()
        return await input_source.read(self)

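A sketch (not part of this file) of the call-data path: set_call_data() stores the payload in the metadata, and _call_data_to_output() reads it back through SimpleCallDataInputSource; the private method is used here only for illustration:

import asyncio

from dbgpt.core.awel import DefaultTaskContext, SimpleTaskOutput, TaskState


async def _demo() -> None:
    ctx = DefaultTaskContext("call_task", TaskState.SUCCESS, SimpleTaskOutput(None))
    ctx.set_call_data({"data": 42})
    output = await ctx._call_data_to_output()  # wraps 42 in a SimpleTaskOutput
    assert output.output == 42


asyncio.run(_demo())
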
class DefaultInputContext(InputContext):
    def __init__(self, outputs: List[TaskContext]) -> None:
        super().__init__()
        self._outputs = outputs

    @property
    def parent_outputs(self) -> List[TaskContext]:
        return self._outputs

    async def _apply_func(
        self, func: Callable[[Any], Any], apply_type: str = "map"
    ) -> Tuple[List[TaskContext], List[TaskOutput]]:
        new_outputs: List[TaskContext] = []
        map_tasks = []
        for out in self._outputs:
            new_outputs.append(out.new_ctx())
            result = None
            if apply_type == "map":
                result = out.task_output.map(func)
            elif apply_type == "reduce":
                result = out.task_output.reduce(func)
            elif apply_type == "check_condition":
                result = out.task_output.check_condition(func)
            else:
                raise ValueError(f"Unsupported apply type {apply_type}")
            map_tasks.append(result)
        results = await asyncio.gather(*map_tasks)
        return new_outputs, results

    async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
        new_outputs, results = await self._apply_func(map_func)
        for i, task_ctx in enumerate(new_outputs):
            task_ctx: TaskContext = task_ctx
            task_ctx.set_task_output(results[i])
        return DefaultInputContext(new_outputs)

    async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
        if not self._outputs:
            return DefaultInputContext([])
        # Some parent outputs may be empty, find the first non-empty one
        not_empty_idx = 0
        for i, p in enumerate(self._outputs):
            if p.task_output.is_empty:
                continue
            not_empty_idx = i
            break
        is_stream = self._outputs[not_empty_idx].task_output.is_stream
        if is_stream:
            if not self.check_stream(skip_empty=True):
                raise ValueError(
                    "All task outputs must have the same output format to apply map_all"
                )
        outputs = []
        for out in self._outputs:
            if out.task_output.is_stream:
                outputs.append(out.task_output.output_stream)
            else:
                outputs.append(out.task_output.output)
        if asyncio.iscoroutinefunction(map_func):
            map_res = await map_func(*outputs)
        else:
            map_res = map_func(*outputs)
        single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
        single_output.task_output.set_output(map_res)
        logger.debug(
            f"Current map_all map_res: {map_res}, is stream: {single_output.task_output.is_stream}"
        )
        return DefaultInputContext([single_output])

    async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
        if not self.check_stream():
            raise ValueError(
                "All task outputs must be streams to apply the reduce function"
            )
        new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
        for i, task_ctx in enumerate(new_outputs):
            task_ctx: TaskContext = task_ctx
            task_ctx.set_task_output(results[i])
        return DefaultInputContext(new_outputs)

    async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
        new_outputs, results = await self._apply_func(
            filter_func, apply_type="check_condition"
        )
        result_outputs = []
        for i, task_ctx in enumerate(new_outputs):
            if results[i]:
                result_outputs.append(task_ctx)
        return DefaultInputContext(result_outputs)

    async def predicate_map(
        self, predicate_func: Callable[[Any], bool], failed_value: Any = None
    ) -> "InputContext":
        new_outputs, results = await self._apply_func(
            predicate_func, apply_type="check_condition"
        )
        result_outputs = []
        for i, task_ctx in enumerate(new_outputs):
            task_ctx: TaskContext = task_ctx
            if results[i]:
                task_ctx.task_output.set_output(True)
                result_outputs.append(task_ctx)
            else:
                task_ctx.task_output.set_output(failed_value)
                result_outputs.append(task_ctx)
        return DefaultInputContext(result_outputs)
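A usage sketch (not part of this file): map() transforms each parent output independently, while map_all() merges all parent outputs into a single one; the task ids are placeholders:

import asyncio

from dbgpt.core.awel import (
    DefaultInputContext,
    DefaultTaskContext,
    SimpleTaskOutput,
    TaskState,
)


def _ctx(task_id: str, value: int) -> DefaultTaskContext:
    return DefaultTaskContext(task_id, TaskState.SUCCESS, SimpleTaskOutput(value))


async def _demo() -> None:
    input_ctx = DefaultInputContext([_ctx("p1", 1), _ctx("p2", 2)])

    doubled = await input_ctx.map(lambda x: x * 2)
    assert [c.task_output.output for c in doubled.parent_outputs] == [2, 4]

    merged = await input_ctx.map_all(lambda a, b: a + b)  # one output: 1 + 2
    assert merged.parent_outputs[0].task_output.output == 3


asyncio.run(_demo())
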
0
dbgpt/core/awel/tests/__init__.py
Normal file
102
dbgpt/core/awel/tests/conftest.py
Normal file
@@ -0,0 +1,102 @@
import pytest
import pytest_asyncio
from typing import AsyncIterator, List
from contextlib import contextmanager, asynccontextmanager
from .. import (
    WorkflowRunner,
    InputOperator,
    DAGContext,
    TaskState,
    DefaultWorkflowRunner,
    SimpleInputSource,
)
from ..task.task_impl import _is_async_iterator


@pytest.fixture
def runner():
    return DefaultWorkflowRunner()


def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
    iters = []
    for _ in range(num_nodes):

        async def stream_iter():
            for i in range(10):
                yield i

        stream_iter = stream_iter()
        assert _is_async_iterator(stream_iter)
        iters.append(stream_iter)
    return iters


def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
    iters = []
    for single_stream in output_streams:

        # Bind the current list as a default argument so each generator keeps
        # its own stream instead of closing over the loop variable.
        async def stream_iter(single_stream=single_stream):
            for i in single_stream:
                yield i

        stream_iter = stream_iter()
        assert _is_async_iterator(stream_iter)
        iters.append(stream_iter)
    return iters


@asynccontextmanager
async def _create_input_node(**kwargs):
    num_nodes = kwargs.get("num_nodes")
    is_stream = kwargs.get("is_stream", False)
    if is_stream:
        outputs = kwargs.get("output_streams")
        if outputs:
            if num_nodes and num_nodes != len(outputs):
                raise ValueError(
                    f"num_nodes {num_nodes} != the length of output_streams {len(outputs)}"
                )
            outputs = _create_stream_from(outputs)
        else:
            num_nodes = num_nodes or 1
            outputs = _create_stream(num_nodes)
    else:
        outputs = kwargs.get("outputs", ["Hello."])
    nodes = []
    for output in outputs:
        print(f"output: {output}")
        input_source = SimpleInputSource(output)
        input_node = InputOperator(input_source)
        nodes.append(input_node)
    yield nodes


@pytest_asyncio.fixture
async def input_node(request):
    param = getattr(request, "param", {})
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes[0]


@pytest_asyncio.fixture
async def stream_input_node(request):
    param = getattr(request, "param", {})
    param["is_stream"] = True
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes[0]


@pytest_asyncio.fixture
async def input_nodes(request):
    param = getattr(request, "param", {})
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes


@pytest_asyncio.fixture
async def stream_input_nodes(request):
    param = getattr(request, "param", {})
    param["is_stream"] = True
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes
51
dbgpt/core/awel/tests/test_http_operator.py
Normal file
@@ -0,0 +1,51 @@
import pytest
from typing import List
from .. import (
    DAG,
    WorkflowRunner,
    DAGContext,
    TaskState,
    InputOperator,
    MapOperator,
    JoinOperator,
    BranchOperator,
    ReduceStreamOperator,
    SimpleInputSource,
)
from .conftest import (
    runner,
    input_node,
    input_nodes,
    stream_input_node,
    stream_input_nodes,
    _is_async_iterator,
)


def _register_dag_to_fastapi_app(dag):
    # TODO
    pass


@pytest.mark.asyncio
async def test_http_operator(runner: WorkflowRunner, stream_input_node: InputOperator):
    with DAG("test_map") as dag:
        pass
        # http_req_task = HttpRequestOperator(endpoint="/api/completions")
        # db_task = DBQueryOperator(table_name="user_info")
        # prompt_task = PromptTemplateOperator(
        #     system_prompt="You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
        # )
        # llm_task = ChatGPTLLMOperator(model="chagpt-3.5")
        # output_parser_task = CommonOutputParserOperator()
        # http_res_task = HttpResponseOperator()
        # (
        #     http_req_task
        #     >> db_task
        #     >> prompt_task
        #     >> llm_task
        #     >> output_parser_task
        #     >> http_res_task
        # )

    _register_dag_to_fastapi_app(dag)
141
dbgpt/core/awel/tests/test_run_dag.py
Normal file
@@ -0,0 +1,141 @@
import pytest
from typing import List
from .. import (
    DAG,
    WorkflowRunner,
    DAGContext,
    TaskState,
    InputOperator,
    MapOperator,
    JoinOperator,
    BranchOperator,
    ReduceStreamOperator,
    SimpleInputSource,
)
from .conftest import (
    runner,
    input_node,
    input_nodes,
    stream_input_node,
    stream_input_nodes,
    _is_async_iterator,
)


@pytest.mark.asyncio
async def test_input_node(runner: WorkflowRunner):
    input_node = InputOperator(SimpleInputSource("hello"))
    res: DAGContext[str] = await runner.execute_workflow(input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    assert res.current_task_context.task_output.output == "hello"

    async def new_stream_iter(n: int):
        for i in range(n):
            yield i

    num_iter = 10
    stream_input_node = InputOperator(SimpleInputSource(new_stream_iter(num_iter)))
    res: DAGContext[str] = await runner.execute_workflow(stream_input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    output_stream = res.current_task_context.task_output.output_stream
    assert output_stream
    assert _is_async_iterator(output_stream)
    i = 0
    async for x in output_stream:
        assert x == i
        i += 1


@pytest.mark.asyncio
async def test_map_node(runner: WorkflowRunner, stream_input_node: InputOperator):
    with DAG("test_map") as dag:
        map_node = MapOperator(lambda x: x * 2)
        stream_input_node >> map_node
    res: DAGContext[int] = await runner.execute_workflow(map_node)
    output_stream = res.current_task_context.task_output.output_stream
    assert output_stream
    i = 0
    async for x in output_stream:
        assert x == i * 2
        i += 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "stream_input_node, expect_sum",
    [
        ({"output_streams": [[0, 1, 2, 3]]}, 6),
        ({"output_streams": [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]}, 55),
    ],
    indirect=["stream_input_node"],
)
async def test_reduce_node(
    runner: WorkflowRunner, stream_input_node: InputOperator, expect_sum: int
):
    with DAG("test_reduce_node") as dag:
        reduce_node = ReduceStreamOperator(lambda x, y: x + y)
        stream_input_node >> reduce_node
    res: DAGContext[int] = await runner.execute_workflow(reduce_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    assert not res.current_task_context.task_output.is_stream
    assert res.current_task_context.task_output.output == expect_sum


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "input_nodes",
    [
        ({"outputs": [0, 1, 2]}),
    ],
    indirect=["input_nodes"],
)
async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator]):
    def join_func(p1, p2, p3) -> int:
        return p1 + p2 + p3

    with DAG("test_join_node") as dag:
        join_node = JoinOperator(join_func)
        for input_node in input_nodes:
            input_node >> join_node
    res: DAGContext[int] = await runner.execute_workflow(join_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    assert not res.current_task_context.task_output.is_stream
    assert res.current_task_context.task_output.output == 3


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "input_node, is_odd",
    [
        ({"outputs": [0]}, False),
        ({"outputs": [1]}, True),
    ],
    indirect=["input_node"],
)
async def test_branch_node(
    runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
):
    def join_func(o1, o2) -> int:
        print(f"join func result, o1: {o1}, o2: {o2}")
        return o1 or o2

    with DAG("test_branch_node") as dag:
        odd_node = MapOperator(
            lambda x: 999, task_id="odd_node", task_name="odd_node_name"
        )
        even_node = MapOperator(
            lambda x: 888, task_id="even_node", task_name="even_node_name"
        )
        join_node = JoinOperator(join_func)
        branch_node = BranchOperator(
            {lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
        )
        branch_node >> odd_node >> join_node
        branch_node >> even_node >> join_node

        input_node >> branch_node

    res: DAGContext[int] = await runner.execute_workflow(join_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    expect_res = 999 if is_odd else 888
    assert res.current_task_context.task_output.output == expect_res
0
dbgpt/core/awel/trigger/__init__.py
Normal file
11
dbgpt/core/awel/trigger/base.py
Normal file
@@ -0,0 +1,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod

from ..operator.common_operator import TriggerOperator


class Trigger(TriggerOperator, ABC):
    @abstractmethod
    async def trigger(self) -> None:
        """Trigger the workflow or a specific operation in the workflow."""
137
dbgpt/core/awel/trigger/http_trigger.py
Normal file
@@ -0,0 +1,137 @@
from __future__ import annotations

from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
from starlette.requests import Request
from starlette.responses import Response
from dbgpt._private.pydantic import BaseModel
import logging

from .base import Trigger
from ..dag.base import DAG
from ..operator.base import BaseOperator

if TYPE_CHECKING:
    from fastapi import APIRouter, FastAPI

RequestBody = Union[Request, Type[BaseModel], str]

logger = logging.getLogger(__name__)


class HttpTrigger(Trigger):
    def __init__(
        self,
        endpoint: str,
        methods: Optional[Union[str, List[str]]] = "GET",
        request_body: Optional[RequestBody] = None,
        streaming_response: Optional[bool] = False,
        response_model: Optional[Type] = None,
        response_headers: Optional[Dict[str, str]] = None,
        response_media_type: Optional[str] = None,
        status_code: Optional[int] = 200,
        router_tags: Optional[List[str]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if not endpoint.startswith("/"):
            endpoint = "/" + endpoint
        self._endpoint = endpoint
        self._methods = methods
        self._req_body = request_body
        self._streaming_response = streaming_response
        self._response_model = response_model
        self._status_code = status_code
        self._router_tags = router_tags
        self._response_headers = response_headers
        self._response_media_type = response_media_type
        self._end_node: Optional[BaseOperator] = None

    async def trigger(self) -> None:
        pass

    def mount_to_router(self, router: "APIRouter") -> None:
        from fastapi import Depends

        methods = self._methods if isinstance(self._methods, list) else [self._methods]

        def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
            async def _request_body_dependency(request: Request):
                return await _parse_request_body(request, self._req_body)

            async def route_function(body=Depends(_request_body_dependency)):
                return await _trigger_dag(
                    body,
                    self.dag,
                    self._streaming_response,
                    self._response_headers,
                    self._response_media_type,
                )

            route_function.__name__ = name
            return route_function

        function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
        request_model = (
            self._req_body
            if isinstance(self._req_body, type)
            and issubclass(self._req_body, BaseModel)
            else None
        )
        dynamic_route_function = create_route_function(function_name, request_model)
        logger.info(
            f"Mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
        )

        router.api_route(
            self._endpoint,
            methods=methods,
            response_model=self._response_model,
            status_code=self._status_code,
            tags=self._router_tags,
        )(dynamic_route_function)


async def _parse_request_body(
    request: Request, request_body_cls: Optional[Type[BaseModel]]
):
    if not request_body_cls:
        return None
    if request.method == "POST":
        json_data = await request.json()
        return request_body_cls(**json_data)
    elif request.method == "GET":
        return request_body_cls(**request.query_params)
    else:
        return request


async def _trigger_dag(
    body: Any,
    dag: DAG,
    streaming_response: Optional[bool] = False,
    response_headers: Optional[Dict[str, str]] = None,
    response_media_type: Optional[str] = None,
) -> Any:
    from fastapi.responses import StreamingResponse

    end_nodes = dag.leaf_nodes
    if len(end_nodes) != 1:
        raise ValueError("HttpTrigger only supports one leaf node in the DAG")
    end_node = end_nodes[0]
    if not streaming_response:
        return await end_node.call(call_data={"data": body})
    else:
        headers = response_headers
        media_type = response_media_type if response_media_type else "text/event-stream"
        if not headers:
            headers = {
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked",
            }
        return StreamingResponse(
            end_node.call_stream(call_data={"data": body}),
            headers=headers,
            media_type=media_type,
        )
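A hedged sketch (not part of this file) of wiring an HttpTrigger into a DAG and mounting it on a FastAPI router; the endpoint, request model, and operator are illustrative:

from fastapi import APIRouter

from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator


class EchoRequest(BaseModel):
    message: str


with DAG("awel_http_echo") as dag:
    trigger = HttpTrigger("/echo", methods="POST", request_body=EchoRequest)
    echo_task = MapOperator(lambda req: {"echo": req.message})
    trigger >> echo_task

# Mounting registers a dynamically named route on the router; in the full
# system this step is performed by the trigger manager shown below.
router = APIRouter()
trigger.mount_to_router(router)
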
74
dbgpt/core/awel/trigger/trigger_manager.py
Normal file
@@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING, Optional
import logging

if TYPE_CHECKING:
    from fastapi import APIRouter

from dbgpt.component import SystemApp, BaseComponent, ComponentType

logger = logging.getLogger(__name__)


class TriggerManager(ABC):
    @abstractmethod
    def register_trigger(self, trigger: Any) -> None:
        """Register a trigger to the current manager."""


class HttpTriggerManager(TriggerManager):
    def __init__(
        self,
        router: Optional["APIRouter"] = None,
        router_prefix: Optional[str] = "/api/v1/awel/trigger",
    ) -> None:
        if not router:
            from fastapi import APIRouter

            router = APIRouter()
        self._router_prefix = router_prefix
        self._router = router
        self._trigger_map = {}

    def register_trigger(self, trigger: Any) -> None:
        from .http_trigger import HttpTrigger

        if not isinstance(trigger, HttpTrigger):
            raise ValueError(f"Current trigger {trigger} is not an instance of HttpTrigger")
        trigger: HttpTrigger = trigger
        trigger_id = trigger.node_id
        if trigger_id not in self._trigger_map:
            trigger.mount_to_router(self._router)
            self._trigger_map[trigger_id] = trigger

    def _init_app(self, system_app: SystemApp):
        logger.info(
            f"Include router {self._router} to prefix path {self._router_prefix}"
        )
        system_app.app.include_router(
            self._router, prefix=self._router_prefix, tags=["AWEL"]
        )


class DefaultTriggerManager(TriggerManager, BaseComponent):
    name = ComponentType.AWEL_TRIGGER_MANAGER

    def __init__(self, system_app: SystemApp | None = None):
        self.system_app = system_app
        self.http_trigger = HttpTriggerManager()
        super().__init__(None)

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def register_trigger(self, trigger: Any) -> None:
        from .http_trigger import HttpTrigger

        if isinstance(trigger, HttpTrigger):
            logger.info(f"Register trigger {trigger}")
            self.http_trigger.register_trigger(trigger)
        else:
            raise ValueError(f"Unsupported trigger: {trigger}")

    def after_register(self) -> None:
        self.http_trigger._init_app(self.system_app)
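A minimal registration sketch (not part of this file) using HttpTriggerManager directly; the route, DAG name, and router prefix are illustrative:

from fastapi import APIRouter

from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.core.awel.trigger.trigger_manager import HttpTriggerManager

with DAG("awel_ping") as dag:
    trigger = HttpTrigger("/ping")
    trigger >> MapOperator(lambda _: "pong")

router = APIRouter()
manager = HttpTriggerManager(router=router, router_prefix="/api/v1/awel/trigger")
manager.register_trigger(trigger)  # mounts the /ping route on the router
# Registering the same trigger again is a no-op thanks to the node_id map.
manager.register_trigger(trigger)
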