feat(core): More AWEL operators and new prompt manager API (#972)

Author: Fangyin Cheng
Co-authored-by: csunny <cfqsunny@163.com>
Date: 2023-12-25 20:03:22 +08:00
Committed by: GitHub
Parent: 048fb6c402
Commit: 69fb97e508

46 changed files with 2556 additions and 294 deletions

View File

@@ -7,6 +7,7 @@ The stability of this API cannot be guaranteed at present.
"""
from typing import List, Optional
from dbgpt.component import SystemApp
from .dag.base import DAGContext, DAG
@@ -68,6 +69,7 @@ __all__ = [
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"HttpTrigger",
"setup_dev_environment",
]
@@ -85,3 +87,29 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str):
    initialize_runner(DefaultWorkflowRunner())
    # Load all dags
    dag_manager.load_dags()

def setup_dev_environment(
    dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555
) -> None:
    """Set up a development environment for AWEL.

    For use in development environments only, not in production.
    """
    import uvicorn
    from fastapi import FastAPI
    from dbgpt.component import SystemApp
    from .trigger.trigger_manager import DefaultTriggerManager
    from .dag.base import DAGVar

    app = FastAPI()
    system_app = SystemApp(app)
    DAGVar.set_current_system_app(system_app)
    trigger_manager = DefaultTriggerManager()
    system_app.register_instance(trigger_manager)
    for dag in dags:
        for trigger in dag.trigger_nodes:
            trigger_manager.register_trigger(trigger)
    trigger_manager.after_register()
    uvicorn.run(app, host=host, port=port)
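
A minimal usage sketch of the new dev helper. `MapOperator`, its `map_function` keyword, and `>>` chaining are assumed from the wider AWEL API; they are not part of this hunk:

```python
# Hedged sketch: serve a trigger-driven DAG with setup_dev_environment.
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment

with DAG("simple_dag_example") as dag:
    trigger = HttpTrigger("/examples/hello")
    # MapOperator applies a function to each input; signature assumed.
    task = MapOperator(map_function=lambda _: {"message": "hello"})
    trigger >> task

# Registers each DAG's trigger_nodes with a DefaultTriggerManager, then
# serves the resulting FastAPI app with uvicorn.
setup_dev_environment([dag], host="127.0.0.1", port=5555)
```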

View File

@@ -11,7 +11,7 @@ from concurrent.futures import Executor
from dbgpt.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext
from ..task.base import TaskContext, TaskOutput
logger = logging.getLogger(__name__)
@@ -168,7 +168,19 @@ class DAGVar:
        cls._executor = executor


class DAGNode(DependencyMixin, ABC):
class DAGLifecycle:
    """The lifecycle of a DAG."""

    async def before_dag_run(self):
        """The callback before the DAG runs."""
        pass

    async def after_dag_end(self):
        """The callback after the DAG ends."""
        pass


class DAGNode(DAGLifecycle, DependencyMixin, ABC):
    resource_group: Optional[ResourceGroup] = None
    """The resource group of the current DAGNode"""
@@ -179,7 +191,7 @@ class DAGNode(DependencyMixin, ABC):
        node_name: Optional[str] = None,
        system_app: Optional[SystemApp] = None,
        executor: Optional[Executor] = None,
        **kwargs
        **kwargs,
    ) -> None:
        super().__init__()
        self._upstream: List["DAGNode"] = []
@@ -198,10 +210,23 @@ class DAGNode(DependencyMixin, ABC):
    def node_id(self) -> str:
        return self._node_id

    @property
    @abstractmethod
    def dev_mode(self) -> bool:
        """Whether the current DAGNode is in dev mode"""

    @property
    def system_app(self) -> SystemApp:
        return self._system_app

    def set_system_app(self, system_app: SystemApp) -> None:
        """Set the system app for the current DAGNode.

        Args:
            system_app (SystemApp): The system app
        """
        self._system_app = system_app

    def set_node_id(self, node_id: str) -> None:
        self._node_id = node_id
@@ -274,11 +299,41 @@ class DAGNode(DependencyMixin, ABC):
            node._upstream.append(self)


def _build_task_key(task_name: str, key: str) -> str:
    return f"{task_name}___$$$$$$___{key}"


class DAGContext:
    def __init__(self, streaming_call: bool = False) -> None:
    """The context of the current DAG, created when the DAG is running.

    Every time a DAG is triggered, a new DAGContext is created.
    """

    def __init__(
        self,
        streaming_call: bool = False,
        node_to_outputs: Dict[str, TaskContext] = None,
        node_name_to_ids: Dict[str, str] = None,
    ) -> None:
        if not node_to_outputs:
            node_to_outputs = {}
        if not node_name_to_ids:
            node_name_to_ids = {}
        self._streaming_call = streaming_call
        self._curr_task_ctx = None
        self._share_data: Dict[str, Any] = {}
        self._node_to_outputs = node_to_outputs
        self._node_name_to_ids = node_name_to_ids

    @property
    def _task_outputs(self) -> Dict[str, TaskContext]:
        """The task outputs of the current DAG.

        For internal use only, for now.

        Returns:
            Dict[str, TaskContext]: The task outputs of the current DAG
        """
        return self._node_to_outputs
@@ -292,12 +347,69 @@ class DAGContext:
    def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
        self._curr_task_ctx = _curr_task_ctx

    async def get_share_data(self, key: str) -> Any:
    def get_task_output(self, task_name: str) -> TaskOutput:
        """Get the task output by task name.

        Args:
            task_name (str): The task name

        Returns:
            TaskOutput: The task output
        """
        if task_name is None:
            raise ValueError("task_name can't be None")
        node_id = self._node_name_to_ids.get(task_name)
        if not node_id:
            raise ValueError(f"Task name {task_name} does not exist in DAG")
        return self._task_outputs.get(node_id).task_output

    async def get_from_share_data(self, key: str) -> Any:
        return self._share_data.get(key)

    async def save_to_share_data(self, key: str, data: Any) -> None:
    async def save_to_share_data(
        self, key: str, data: Any, overwrite: Optional[bool] = None
    ) -> None:
        if key in self._share_data and not overwrite:
            raise ValueError(f"Share data key {key} already exists")
        self._share_data[key] = data

    async def get_task_share_data(self, task_name: str, key: str) -> Any:
        """Get share data by task name and key.

        Args:
            task_name (str): The task name
            key (str): The share data key

        Returns:
            Any: The share data
        """
        if task_name is None:
            raise ValueError("task_name can't be None")
        if key is None:
            raise ValueError("key can't be None")
        return await self.get_from_share_data(_build_task_key(task_name, key))

    async def save_task_share_data(
        self, task_name: str, key: str, data: Any, overwrite: Optional[bool] = None
    ) -> None:
        """Save share data by task name and key.

        Args:
            task_name (str): The task name
            key (str): The share data key
            data (Any): The share data
            overwrite (Optional[bool], optional): Whether to overwrite the share data
                if the key already exists. Defaults to None.

        Raises:
            ValueError: If the share data key already exists and overwrite is not True
        """
        if task_name is None:
            raise ValueError("task_name can't be None")
        if key is None:
            raise ValueError("key can't be None")
        await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite)
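
A sketch of how the new task-scoped share data might be used from inside operators. `MapOperator` and the `node_name`-based wiring are assumptions, not part of this hunk; it presumes the producer was constructed with `node_name="producer"`:

```python
# Hedged sketch: exchanging data between named tasks through DAGContext.
from dbgpt.core.awel import MapOperator

class ProducerOperator(MapOperator[int, int]):
    # Constructed elsewhere with node_name="producer".
    async def map(self, x: int) -> int:
        ctx = self.current_dag_context
        # Keys are namespaced per task by _build_task_key, so tasks can't collide.
        await ctx.save_task_share_data("producer", "raw_input", x, overwrite=True)
        return x * 2

class ConsumerOperator(MapOperator[int, int]):
    async def map(self, x: int) -> int:
        ctx = self.current_dag_context
        raw = await ctx.get_task_share_data("producer", "raw_input")
        return x + raw
```

Similarly, `ctx.get_task_output("producer")` resolves the producer's `TaskOutput` through the new `node_name_to_ids` mapping.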
class DAG:
    def __init__(
@@ -305,11 +417,20 @@ class DAG:
    ) -> None:
        self._dag_id = dag_id
        self.node_map: Dict[str, DAGNode] = {}
        self._root_nodes: Set[DAGNode] = None
        self._leaf_nodes: Set[DAGNode] = None
        self._trigger_nodes: Set[DAGNode] = None
        self.node_name_to_node: Dict[str, DAGNode] = {}
        self._root_nodes: List[DAGNode] = None
        self._leaf_nodes: List[DAGNode] = None
        self._trigger_nodes: List[DAGNode] = None

    def _append_node(self, node: DAGNode) -> None:
        if node.node_id in self.node_map:
            return
        if node.node_name:
            if node.node_name in self.node_name_to_node:
                raise ValueError(
                    f"Node name {node.node_name} already exists in DAG {self.dag_id}"
                )
            self.node_name_to_node[node.node_name] = node
        self.node_map[node.node_id] = node
        # clear cached nodes
        self._root_nodes = None
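
One consequence of the new `node_name_to_node` index: node names must be unique within a DAG. A sketch, assuming operators auto-register with the enclosing `with DAG(...)` block at construction time:

```python
# Hedged sketch: duplicate node names are now rejected at registration time.
from dbgpt.core.awel import DAG, MapOperator

with DAG("named_nodes") as dag:
    first = MapOperator(map_function=lambda x: x, node_name="step")
    # A second node with the same name raises:
    #   ValueError: Node name step already exists in DAG named_nodes
    second = MapOperator(map_function=lambda x: x, node_name="step")
```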
@@ -336,22 +457,44 @@ class DAG:
    @property
    def root_nodes(self) -> List[DAGNode]:
        """The root nodes of the current DAG.

        Returns:
            List[DAGNode]: The root nodes of the current DAG, without duplicates
        """
        if not self._root_nodes:
            self._build()
        return self._root_nodes

    @property
    def leaf_nodes(self) -> List[DAGNode]:
        """The leaf nodes of the current DAG.

        Returns:
            List[DAGNode]: The leaf nodes of the current DAG, without duplicates
        """
        if not self._leaf_nodes:
            self._build()
        return self._leaf_nodes

    @property
    def trigger_nodes(self):
    def trigger_nodes(self) -> List[DAGNode]:
        """The trigger nodes of the current DAG.

        Returns:
            List[DAGNode]: The trigger nodes of the current DAG, without duplicates
        """
        if not self._trigger_nodes:
            self._build()
        return self._trigger_nodes

    async def _after_dag_end(self) -> None:
        """The callback after the DAG ends."""
        tasks = []
        for node in self.node_map.values():
            tasks.append(node.after_dag_end())
        await asyncio.gather(*tasks)

    def __enter__(self):
        DAGVar.enter_dag(self)
        return self

View File

@@ -146,6 +146,16 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
    def current_dag_context(self) -> DAGContext:
        return self._dag_ctx

    @property
    def dev_mode(self) -> bool:
        """Whether the operator is in dev mode.

        In production mode, the default runner is not None.

        Returns:
            bool: Whether the operator is in dev mode. True if the default runner
                is None.
        """
        return default_runner is None

    async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        if not self.node_id:
            raise ValueError(f"The DAG Node ID can't be empty, current node {self}")

View File

@@ -1,4 +1,14 @@
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
from typing import (
    Generic,
    Dict,
    List,
    Union,
    Callable,
    Any,
    AsyncIterator,
    Awaitable,
    Optional,
)
import asyncio
import logging
@@ -162,7 +172,9 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
"""
def __init__(
self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
self,
branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
**kwargs,
):
"""
Initializes a BranchDAGNode with a branching function.
@@ -203,7 +215,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
        branches = self._branches
        if not branches:
            branches = await self.branchs()
            branches = await self.branches()

        branch_func_tasks = []
        branch_nodes: List[str] = []
@@ -229,7 +241,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
        curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
        return parent_output

    async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
    async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
        raise NotImplementedError
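
With the `branchs` → `branches` rename, subclasses now override `branches()`. A sketch, where the returned string values name downstream operators' `node_name`s ("even_task" and "odd_task" are made up for illustration):

```python
# Hedged sketch: a BranchOperator subclass choosing a downstream task by parity.
class ParityBranchOperator(BranchOperator[int, int]):
    async def branches(self):
        return {
            (lambda x: x % 2 == 0): "even_task",
            (lambda x: x % 2 != 0): "odd_task",
        }
```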

View File

@@ -1,7 +1,8 @@
import asyncio
from typing import List, Set, Optional, Dict
import uuid
import logging
from ..dag.base import DAG
from ..dag.base import DAG, DAGLifecycle
from ..operator.base import BaseOperator, CALL_DATA
@@ -18,18 +19,20 @@ class DAGInstance:
        self._dag = dag


class JobManager:
class JobManager(DAGLifecycle):
    def __init__(
        self,
        root_nodes: List[BaseOperator],
        all_nodes: List[BaseOperator],
        end_node: BaseOperator,
        id2call_data: Dict[str, Dict],
        node_name_to_ids: Dict[str, str],
    ) -> None:
        self._root_nodes = root_nodes
        self._all_nodes = all_nodes
        self._end_node = end_node
        self._id2node_data = id2call_data
        self._node_name_to_ids = node_name_to_ids

    @staticmethod
    def build_from_end_node(
@@ -38,11 +41,31 @@ class JobManager:
        nodes = _build_from_end_node(end_node)
        root_nodes = _get_root_nodes(nodes)
        id2call_data = _save_call_data(root_nodes, call_data)
        return JobManager(root_nodes, nodes, end_node, id2call_data)
        node_name_to_ids = {}
        for node in nodes:
            if node.node_name is not None:
                node_name_to_ids[node.node_name] = node.node_id
        return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)

    def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
        return self._id2node_data.get(node_id)

    async def before_dag_run(self):
        """The callback before the DAG runs."""
        tasks = []
        for node in self._all_nodes:
            tasks.append(node.before_dag_run())
        await asyncio.gather(*tasks)

    async def after_dag_end(self):
        """The callback after the DAG ends."""
        tasks = []
        for node in self._all_nodes:
            tasks.append(node.after_dag_end())
        await asyncio.gather(*tasks)


def _save_call_data(
    root_nodes: List[BaseOperator], call_data: CALL_DATA
@@ -66,6 +89,7 @@ def _save_call_data(
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
    """Build all nodes from the end node."""
    nodes = []
    if isinstance(end_node, BaseOperator):
        task_id = end_node.node_id

View File

@@ -1,7 +1,8 @@
from typing import Dict, Optional, Set, List
import logging
from ..dag.base import DAGContext
from dbgpt.component import SystemApp
from ..dag.base import DAGContext, DAGVar
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
from ..task.base import TaskContext, TaskState
@@ -18,19 +19,29 @@ class DefaultWorkflowRunner(WorkflowRunner):
        call_data: Optional[CALL_DATA] = None,
        streaming_call: bool = False,
    ) -> DAGContext:
        # Create DAG context
        dag_ctx = DAGContext(streaming_call=streaming_call)
        # Save node output
        # dag = node.dag
        node_outputs: Dict[str, TaskContext] = {}
        job_manager = JobManager.build_from_end_node(node, call_data)
        # Create DAG context
        dag_ctx = DAGContext(
            streaming_call=streaming_call,
            node_to_outputs=node_outputs,
            node_name_to_ids=job_manager._node_name_to_ids,
        )
        logger.info(
            f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
        )
        dag = node.dag
        # Save node output
        node_outputs: Dict[str, TaskContext] = {}
        skip_node_ids = set()
        system_app: SystemApp = DAGVar.get_current_system_app()
        await job_manager.before_dag_run()

        await self._execute_node(
            job_manager, node, dag_ctx, node_outputs, skip_node_ids
            job_manager, node, dag_ctx, node_outputs, skip_node_ids, system_app
        )
        if not streaming_call and node.dag:
            # Streaming calls do not end the DAG here
            await node.dag._after_dag_end()

        return dag_ctx
@@ -41,6 +52,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
        dag_ctx: DAGContext,
        node_outputs: Dict[str, TaskContext],
        skip_node_ids: Set[str],
        system_app: SystemApp,
    ):
        # Skip run node
        if node.node_id in node_outputs:
@@ -50,7 +62,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
        for upstream_node in node.upstream:
            if isinstance(upstream_node, BaseOperator):
                await self._execute_node(
                    job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
                    job_manager,
                    upstream_node,
                    dag_ctx,
                    node_outputs,
                    skip_node_ids,
                    system_app,
                )

        inputs = [
@@ -73,6 +90,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
            logger.debug(
                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
            )
            if system_app is not None and node.system_app is None:
                node.set_system_app(system_app)

            await node._run(dag_ctx)
            node_outputs[node.node_id] = dag_ctx.current_task_context
            task_ctx.set_current_state(TaskState.SUCCESS)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict, Callable
from starlette.requests import Request
from starlette.responses import Response
from dbgpt._private.pydantic import BaseModel
@@ -13,7 +13,8 @@ from ..operator.base import BaseOperator
if TYPE_CHECKING:
    from fastapi import APIRouter, FastAPI
RequestBody = Union[Request, Type[BaseModel], str]
RequestBody = Union[Type[Request], Type[BaseModel], str]
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
logger = logging.getLogger(__name__)
@@ -25,6 +26,7 @@ class HttpTrigger(Trigger):
        methods: Optional[Union[str, List[str]]] = "GET",
        request_body: Optional[RequestBody] = None,
        streaming_response: Optional[bool] = False,
        streaming_predict_func: Optional[StreamingPredictFunc] = None,
        response_model: Optional[Type] = None,
        response_headers: Optional[Dict[str, str]] = None,
        response_media_type: Optional[str] = None,
@@ -39,6 +41,7 @@ class HttpTrigger(Trigger):
        self._methods = methods
        self._req_body = request_body
        self._streaming_response = streaming_response
        self._streaming_predict_func = streaming_predict_func
        self._response_model = response_model
        self._status_code = status_code
        self._router_tags = router_tags
@@ -59,10 +62,13 @@ class HttpTrigger(Trigger):
            return await _parse_request_body(request, self._req_body)

        async def route_function(body=Depends(_request_body_dependency)):
            streaming_response = self._streaming_response
            if self._streaming_predict_func:
                streaming_response = self._streaming_predict_func(body)
            return await _trigger_dag(
                body,
                self.dag,
                self._streaming_response,
                streaming_response,
                self._response_headers,
                self._response_media_type,
            )
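
A sketch of the per-request streaming decision this enables; `TriggerReqBody` and the endpoint path are hypothetical, made up for illustration:

```python
# Hedged sketch: let the request body choose between streaming and normal
# responses via the new streaming_predict_func hook.
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import HttpTrigger

class TriggerReqBody(BaseModel):
    name: str
    stream: bool = False

trigger = HttpTrigger(
    "/examples/hello",
    methods="POST",
    request_body=TriggerReqBody,
    # Called with the parsed body; its result overrides streaming_response.
    streaming_predict_func=lambda body: body.stream,
)
```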
@@ -112,6 +118,7 @@ async def _trigger_dag(
    response_headers: Optional[Dict[str, str]] = None,
    response_media_type: Optional[str] = None,
) -> Any:
    from fastapi import BackgroundTasks
    from fastapi.responses import StreamingResponse

    end_node = dag.leaf_nodes
@@ -131,8 +138,11 @@ async def _trigger_dag(
"Transfer-Encoding": "chunked",
}
generator = await end_node.call_stream(call_data={"data": body})
background_tasks = BackgroundTasks()
background_tasks.add_task(end_node.dag._after_dag_end)
return StreamingResponse(
generator,
headers=headers,
media_type=media_type,
background=background_tasks,
)
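
Note how this closes the loop with the runner change above: for streaming calls the runner deliberately skips `_after_dag_end`, and the trigger instead schedules it as a Starlette background task, so the lifecycle callbacks fire only after the streamed response has been fully sent.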