Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-01 01:04:43 +00:00
feat(core): More AWEL operators and new prompt manager API (#972)
Co-authored-by: csunny <cfqsunny@163.com>
@@ -7,6 +7,7 @@ The stability of this API cannot be guaranteed at present.

 """
+from typing import List, Optional
 from dbgpt.component import SystemApp

 from .dag.base import DAGContext, DAG
@@ -68,6 +69,7 @@ __all__ = [
     "UnstreamifyAbsOperator",
     "TransformStreamAbsOperator",
     "HttpTrigger",
+    "setup_dev_environment",
 ]
@@ -85,3 +87,29 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str):
     initialize_runner(DefaultWorkflowRunner())
     # Load all dags
     dag_manager.load_dags()
+
+
+def setup_dev_environment(
+    dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555
+) -> None:
+    """Set up a development environment for AWEL.
+
+    Only for use in a development environment, not in production.
+    """
+    import uvicorn
+    from fastapi import FastAPI
+    from dbgpt.component import SystemApp
+    from .trigger.trigger_manager import DefaultTriggerManager
+    from .dag.base import DAGVar
+
+    app = FastAPI()
+    system_app = SystemApp(app)
+    DAGVar.set_current_system_app(system_app)
+    trigger_manager = DefaultTriggerManager()
+    system_app.register_instance(trigger_manager)
+
+    for dag in dags:
+        for trigger in dag.trigger_nodes:
+            trigger_manager.register_trigger(trigger)
+    trigger_manager.after_register()
+    uvicorn.run(app, host=host, port=port)
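For context, a minimal sketch of how the new setup_dev_environment helper might be used. The DAG name, endpoint, and operator wiring below are illustrative, not part of this commit:

    from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment

    with DAG("awel_hello_dag") as dag:
        # Hypothetical two-node DAG: an HTTP trigger feeding a map operator.
        trigger = HttpTrigger("/examples/hello", methods="GET")
        task = MapOperator(map_function=lambda req: {"message": "hello"})
        trigger >> task

    # Registers the DAG's triggers on a throwaway FastAPI app and serves it
    # with uvicorn; intended for local development only.
    setup_dev_environment([dag], port=5555)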
@@ -11,7 +11,7 @@ from concurrent.futures import Executor

 from dbgpt.component import SystemApp
 from ..resource.base import ResourceGroup
-from ..task.base import TaskContext
+from ..task.base import TaskContext, TaskOutput

 logger = logging.getLogger(__name__)
@@ -168,7 +168,19 @@ class DAGVar:
         cls._executor = executor


-class DAGNode(DependencyMixin, ABC):
+class DAGLifecycle:
+    """The lifecycle of a DAG"""
+
+    async def before_dag_run(self):
+        """The callback before the DAG runs"""
+        pass
+
+    async def after_dag_end(self):
+        """The callback after the DAG ends"""
+        pass
+
+
+class DAGNode(DAGLifecycle, DependencyMixin, ABC):
     resource_group: Optional[ResourceGroup] = None
     """The resource group of current DAGNode"""
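DAGLifecycle gives every node two optional async hooks that the runner now invokes around a DAG run. A sketch of an operator overriding both, assuming AWEL's MapOperator subclassing style (the class itself is hypothetical):

    import logging

    from dbgpt.core.awel import MapOperator

    logger = logging.getLogger(__name__)


    class AuditedOperator(MapOperator[str, str]):
        """Hypothetical operator that observes the DAG lifecycle."""

        async def before_dag_run(self):
            # Invoked for this node before the DAG starts executing.
            logger.info("DAG is about to run")

        async def after_dag_end(self):
            # Invoked after the DAG finishes; the JobManager gathers these
            # coroutines across all nodes concurrently.
            logger.info("DAG has ended")

        async def map(self, value: str) -> str:
            return value.upper()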
@@ -179,7 +191,7 @@ class DAGNode(DependencyMixin, ABC):
         node_name: Optional[str] = None,
         system_app: Optional[SystemApp] = None,
         executor: Optional[Executor] = None,
-        **kwargs
+        **kwargs,
     ) -> None:
         super().__init__()
         self._upstream: List["DAGNode"] = []
@@ -198,10 +210,23 @@ class DAGNode(DependencyMixin, ABC):
     def node_id(self) -> str:
         return self._node_id

+    @property
+    @abstractmethod
+    def dev_mode(self) -> bool:
+        """Whether the current DAGNode is in dev mode"""
+
     @property
     def system_app(self) -> SystemApp:
         return self._system_app

+    def set_system_app(self, system_app: SystemApp) -> None:
+        """Set the system app for the current DAGNode
+
+        Args:
+            system_app (SystemApp): The system app
+        """
+        self._system_app = system_app
+
     def set_node_id(self, node_id: str) -> None:
         self._node_id = node_id
@@ -274,11 +299,41 @@ class DAGNode(DependencyMixin, ABC):
             node._upstream.append(self)


+def _build_task_key(task_name: str, key: str) -> str:
+    return f"{task_name}___$$$$$$___{key}"
+
+
 class DAGContext:
-    def __init__(self, streaming_call: bool = False) -> None:
-        """The context of current DAG, created when the DAG is running
-
-        Every DAG has been triggered will create a new DAGContext.
-        """
+    """The context of the current DAG, created when the DAG is running
+
+    Every DAG that has been triggered creates a new DAGContext.
+    """
+
+    def __init__(
+        self,
+        streaming_call: bool = False,
+        node_to_outputs: Dict[str, TaskContext] = None,
+        node_name_to_ids: Dict[str, str] = None,
+    ) -> None:
+        if not node_to_outputs:
+            node_to_outputs = {}
+        if not node_name_to_ids:
+            node_name_to_ids = {}
         self._streaming_call = streaming_call
         self._curr_task_ctx = None
         self._share_data: Dict[str, Any] = {}
+        self._node_to_outputs = node_to_outputs
+        self._node_name_to_ids = node_name_to_ids
+
+    @property
+    def _task_outputs(self) -> Dict[str, TaskContext]:
+        """The task outputs of the current DAG
+
+        For internal use only for now.
+
+        Returns:
+            Dict[str, TaskContext]: The task outputs of the current DAG
+        """
+        return self._node_to_outputs

     @property
     def current_task_context(self) -> TaskContext:
@@ -292,12 +347,69 @@ class DAGContext:
     def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
         self._curr_task_ctx = _curr_task_ctx

-    async def get_share_data(self, key: str) -> Any:
+    def get_task_output(self, task_name: str) -> TaskOutput:
+        """Get the task output by task name
+
+        Args:
+            task_name (str): The task name
+
+        Returns:
+            TaskOutput: The task output
+        """
+        if task_name is None:
+            raise ValueError("task_name can't be None")
+        node_id = self._node_name_to_ids.get(task_name)
+        if not node_id:
+            raise ValueError(f"Task name {task_name} does not exist in DAG")
+        return self._task_outputs.get(node_id).task_output
+
+    async def get_from_share_data(self, key: str) -> Any:
         return self._share_data.get(key)

-    async def save_to_share_data(self, key: str, data: Any) -> None:
+    async def save_to_share_data(
+        self, key: str, data: Any, overwrite: Optional[bool] = None
+    ) -> None:
+        if key in self._share_data and not overwrite:
+            raise ValueError(f"Share data key {key} already exists")
         self._share_data[key] = data

+    async def get_task_share_data(self, task_name: str, key: str) -> Any:
+        """Get share data by task name and key
+
+        Args:
+            task_name (str): The task name
+            key (str): The share data key
+
+        Returns:
+            Any: The share data
+        """
+        if task_name is None:
+            raise ValueError("task_name can't be None")
+        if key is None:
+            raise ValueError("key can't be None")
+        return await self.get_from_share_data(_build_task_key(task_name, key))
+
+    async def save_task_share_data(
+        self, task_name: str, key: str, data: Any, overwrite: Optional[bool] = None
+    ) -> None:
+        """Save share data by task name and key
+
+        Args:
+            task_name (str): The task name
+            key (str): The share data key
+            data (Any): The share data
+            overwrite (Optional[bool], optional): Whether to overwrite the share data
+                if the key already exists. Defaults to None.
+
+        Raises:
+            ValueError: If the share data key already exists and overwrite is not True
+        """
+        if task_name is None:
+            raise ValueError("task_name can't be None")
+        if key is None:
+            raise ValueError("key can't be None")
+        await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite)
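A sketch of how an operator might use the new share-data and task-output APIs. The two operator classes and the "extractor" task name are hypothetical; current_dag_context is the BaseOperator property shown in a later hunk:

    from dbgpt.core.awel import MapOperator


    class PublishOperator(MapOperator[dict, dict]):
        """Hypothetical operator that publishes a DAG-wide value."""

        async def map(self, payload: dict) -> dict:
            ctx = self.current_dag_context
            # Raises ValueError if the key already exists and overwrite is not set.
            await ctx.save_to_share_data("request_id", payload.get("id"))
            return payload


    class ConsumeOperator(MapOperator[dict, str]):
        """Hypothetical operator that reads the shared value back."""

        async def map(self, payload: dict) -> str:
            ctx = self.current_dag_context
            request_id = await ctx.get_from_share_data("request_id")
            # Resolve a named task to its recorded TaskOutput.
            extractor_output = ctx.get_task_output("extractor")
            return f"{request_id}: {extractor_output}"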
 class DAG:
     def __init__(
@@ -305,11 +417,20 @@
     ) -> None:
         self._dag_id = dag_id
         self.node_map: Dict[str, DAGNode] = {}
-        self._root_nodes: Set[DAGNode] = None
-        self._leaf_nodes: Set[DAGNode] = None
-        self._trigger_nodes: Set[DAGNode] = None
+        self.node_name_to_node: Dict[str, DAGNode] = {}
+        self._root_nodes: List[DAGNode] = None
+        self._leaf_nodes: List[DAGNode] = None
+        self._trigger_nodes: List[DAGNode] = None

     def _append_node(self, node: DAGNode) -> None:
         if node.node_id in self.node_map:
             return
+        if node.node_name:
+            if node.node_name in self.node_name_to_node:
+                raise ValueError(
+                    f"Node name {node.node_name} already exists in DAG {self.dag_id}"
+                )
+            self.node_name_to_node[node.node_name] = node
         self.node_map[node.node_id] = node
         # clear cached nodes
         self._root_nodes = None
@@ -336,22 +457,44 @@ class DAG:

     @property
     def root_nodes(self) -> List[DAGNode]:
+        """The root nodes of the current DAG
+
+        Returns:
+            List[DAGNode]: The root nodes of the current DAG, without duplicates
+        """
         if not self._root_nodes:
             self._build()
         return self._root_nodes

     @property
     def leaf_nodes(self) -> List[DAGNode]:
+        """The leaf nodes of the current DAG
+
+        Returns:
+            List[DAGNode]: The leaf nodes of the current DAG, without duplicates
+        """
         if not self._leaf_nodes:
             self._build()
         return self._leaf_nodes

     @property
-    def trigger_nodes(self):
+    def trigger_nodes(self) -> List[DAGNode]:
+        """The trigger nodes of the current DAG
+
+        Returns:
+            List[DAGNode]: The trigger nodes of the current DAG, without duplicates
+        """
         if not self._trigger_nodes:
             self._build()
         return self._trigger_nodes

+    async def _after_dag_end(self) -> None:
+        """The callback after the DAG ends"""
+        tasks = []
+        for node in self.node_map.values():
+            tasks.append(node.after_dag_end())
+        await asyncio.gather(*tasks)
+
     def __enter__(self):
         DAGVar.enter_dag(self)
         return self
@@ -146,6 +146,16 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     def current_dag_context(self) -> DAGContext:
         return self._dag_ctx

+    @property
+    def dev_mode(self) -> bool:
+        """Whether the operator is in dev mode.
+
+        In production mode, the default runner is not None.
+
+        Returns:
+            bool: Whether the operator is in dev mode. True if the default runner is None.
+        """
+        return default_runner is None
+
     async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         if not self.node_id:
             raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
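dev_mode simply reports whether the global default runner has been initialized, so an operator can branch on it. A small illustrative sketch (the operator is hypothetical):

    from dbgpt.core.awel import MapOperator


    class EchoOperator(MapOperator[str, str]):
        """Hypothetical operator with extra logging in dev mode."""

        async def map(self, value: str) -> str:
            if self.dev_mode:
                # True when no default runner was initialized, e.g. under
                # setup_dev_environment rather than a deployed app.
                print(f"[dev] EchoOperator input: {value!r}")
            return value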
@@ -1,4 +1,14 @@
-from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
+from typing import (
+    Generic,
+    Dict,
+    List,
+    Union,
+    Callable,
+    Any,
+    AsyncIterator,
+    Awaitable,
+    Optional,
+)
 import asyncio
 import logging
@@ -162,7 +172,9 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
     """

     def __init__(
-        self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
+        self,
+        branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
+        **kwargs,
     ):
         """
         Initializes a BranchDAGNode with a branching function.
@@ -203,7 +215,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):

         branches = self._branches
         if not branches:
-            branches = await self.branchs()
+            branches = await self.branches()

         branch_func_tasks = []
         branch_nodes: List[str] = []
@@ -229,7 +241,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
         curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
         return parent_output

-    async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
+    async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
         raise NotImplementedError
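With the typo fixed, subclasses supply their routing table by overriding branches() (or by passing the dict to the constructor). A sketch under this API; the predicates and downstream task names are illustrative:

    from dbgpt.core.awel import BranchOperator


    class LengthBranchOperator(BranchOperator[str, str]):
        """Hypothetical branch operator that routes by input length."""

        async def branches(self):
            # Each key is an async predicate over the parent output; each value
            # is the downstream operator (or its task name) to run when the
            # predicate returns True. Unselected branches are skipped.
            async def is_short(value: str) -> bool:
                return len(value) < 10

            async def is_long(value: str) -> bool:
                return len(value) >= 10

            return {is_short: "short_path", is_long: "long_path"}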
@@ -1,7 +1,8 @@
+import asyncio
 from typing import List, Set, Optional, Dict
 import uuid
 import logging
-from ..dag.base import DAG
+from ..dag.base import DAG, DAGLifecycle

 from ..operator.base import BaseOperator, CALL_DATA
@@ -18,18 +19,20 @@ class DAGInstance:
         self._dag = dag


-class JobManager:
+class JobManager(DAGLifecycle):
     def __init__(
         self,
         root_nodes: List[BaseOperator],
         all_nodes: List[BaseOperator],
         end_node: BaseOperator,
         id2call_data: Dict[str, Dict],
+        node_name_to_ids: Dict[str, str],
     ) -> None:
         self._root_nodes = root_nodes
         self._all_nodes = all_nodes
         self._end_node = end_node
         self._id2node_data = id2call_data
+        self._node_name_to_ids = node_name_to_ids

     @staticmethod
     def build_from_end_node(
@@ -38,11 +41,31 @@ class JobManager:
         nodes = _build_from_end_node(end_node)
         root_nodes = _get_root_nodes(nodes)
         id2call_data = _save_call_data(root_nodes, call_data)
-        return JobManager(root_nodes, nodes, end_node, id2call_data)
+
+        node_name_to_ids = {}
+        for node in nodes:
+            if node.node_name is not None:
+                node_name_to_ids[node.node_name] = node.node_id
+
+        return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids)

     def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
         return self._id2node_data.get(node_id)

+    async def before_dag_run(self):
+        """The callback before the DAG runs"""
+        tasks = []
+        for node in self._all_nodes:
+            tasks.append(node.before_dag_run())
+        await asyncio.gather(*tasks)
+
+    async def after_dag_end(self):
+        """The callback after the DAG ends"""
+        tasks = []
+        for node in self._all_nodes:
+            tasks.append(node.after_dag_end())
+        await asyncio.gather(*tasks)
+

 def _save_call_data(
     root_nodes: List[BaseOperator], call_data: CALL_DATA
@@ -66,6 +89,7 @@ def _save_call_data(


 def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
+    """Build all nodes from the end node."""
     nodes = []
     if isinstance(end_node, BaseOperator):
         task_id = end_node.node_id
@@ -1,7 +1,8 @@
 from typing import Dict, Optional, Set, List
 import logging

-from ..dag.base import DAGContext
+from dbgpt.component import SystemApp
+from ..dag.base import DAGContext, DAGVar
 from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
 from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator
 from ..task.base import TaskContext, TaskState
@@ -18,19 +19,29 @@ class DefaultWorkflowRunner(WorkflowRunner):
         call_data: Optional[CALL_DATA] = None,
         streaming_call: bool = False,
     ) -> DAGContext:
-        # Create DAG context
-        dag_ctx = DAGContext(streaming_call=streaming_call)
+        # Save node output
+        # dag = node.dag
+        node_outputs: Dict[str, TaskContext] = {}
+        job_manager = JobManager.build_from_end_node(node, call_data)
+        # Create DAG context
+        dag_ctx = DAGContext(
+            streaming_call=streaming_call,
+            node_to_outputs=node_outputs,
+            node_name_to_ids=job_manager._node_name_to_ids,
+        )
         logger.info(
             f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
         )
-        dag = node.dag
-        # Save node output
-        node_outputs: Dict[str, TaskContext] = {}
         skip_node_ids = set()
+        system_app: SystemApp = DAGVar.get_current_system_app()
+
+        await job_manager.before_dag_run()
         await self._execute_node(
-            job_manager, node, dag_ctx, node_outputs, skip_node_ids
+            job_manager, node, dag_ctx, node_outputs, skip_node_ids, system_app
         )
+        if not streaming_call and node.dag:
+            # Streaming calls do not work with the DAG-end callback here
+            await node.dag._after_dag_end()

         return dag_ctx
@@ -41,6 +52,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
         dag_ctx: DAGContext,
         node_outputs: Dict[str, TaskContext],
         skip_node_ids: Set[str],
+        system_app: SystemApp,
     ):
         # Skip run node
         if node.node_id in node_outputs:
@@ -50,7 +62,12 @@ class DefaultWorkflowRunner(WorkflowRunner):
         for upstream_node in node.upstream:
             if isinstance(upstream_node, BaseOperator):
                 await self._execute_node(
-                    job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
+                    job_manager,
+                    upstream_node,
+                    dag_ctx,
+                    node_outputs,
+                    skip_node_ids,
+                    system_app,
                 )

         inputs = [
||||
inputs = [
|
||||
@@ -73,6 +90,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
logger.debug(
|
||||
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
|
||||
)
|
||||
if system_app is not None and node.system_app is None:
|
||||
node.set_system_app(system_app)
|
||||
|
||||
await node._run(dag_ctx)
|
||||
node_outputs[node.node_id] = dag_ctx.current_task_context
|
||||
task_ctx.set_current_state(TaskState.SUCCESS)
|
||||
|
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict
+from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict, Callable
 from starlette.requests import Request
 from starlette.responses import Response
 from dbgpt._private.pydantic import BaseModel
@@ -13,7 +13,8 @@ from ..operator.base import BaseOperator
 if TYPE_CHECKING:
     from fastapi import APIRouter, FastAPI

-RequestBody = Union[Request, Type[BaseModel], str]
+RequestBody = Union[Type[Request], Type[BaseModel], str]
+StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]

 logger = logging.getLogger(__name__)
@@ -25,6 +26,7 @@ class HttpTrigger(Trigger):
         methods: Optional[Union[str, List[str]]] = "GET",
         request_body: Optional[RequestBody] = None,
         streaming_response: Optional[bool] = False,
+        streaming_predict_func: Optional[StreamingPredictFunc] = None,
         response_model: Optional[Type] = None,
         response_headers: Optional[Dict[str, str]] = None,
         response_media_type: Optional[str] = None,
@@ -39,6 +41,7 @@ class HttpTrigger(Trigger):
         self._methods = methods
         self._req_body = request_body
         self._streaming_response = streaming_response
+        self._streaming_predict_func = streaming_predict_func
         self._response_model = response_model
         self._status_code = status_code
         self._router_tags = router_tags
@@ -59,10 +62,13 @@ class HttpTrigger(Trigger):
             return await _parse_request_body(request, self._req_body)

         async def route_function(body=Depends(_request_body_dependency)):
+            streaming_response = self._streaming_response
+            if self._streaming_predict_func:
+                streaming_response = self._streaming_predict_func(body)
             return await _trigger_dag(
                 body,
                 self.dag,
-                self._streaming_response,
+                streaming_response,
                 self._response_headers,
                 self._response_media_type,
             )
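The new streaming_predict_func lets a single endpoint decide per request whether to answer with a streaming response. A sketch; the request model and its stream field are illustrative:

    from dbgpt._private.pydantic import BaseModel
    from dbgpt.core.awel import DAG, HttpTrigger


    class ChatRequest(BaseModel):
        # Hypothetical request body carrying a client-controlled stream flag.
        prompt: str
        stream: bool = False


    with DAG("example_chat_dag") as dag:
        trigger = HttpTrigger(
            "/examples/chat",
            methods="POST",
            request_body=ChatRequest,
            # Evaluated per request; overrides the static streaming_response.
            streaming_predict_func=lambda req: req.stream,
        )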
@@ -112,6 +118,7 @@ async def _trigger_dag(
     response_headers: Optional[Dict[str, str]] = None,
     response_media_type: Optional[str] = None,
 ) -> Any:
+    from fastapi import BackgroundTasks
     from fastapi.responses import StreamingResponse

     end_node = dag.leaf_nodes
@@ -131,8 +138,11 @@ async def _trigger_dag(
             "Transfer-Encoding": "chunked",
         }
         generator = await end_node.call_stream(call_data={"data": body})
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(end_node.dag._after_dag_end)
         return StreamingResponse(
             generator,
             headers=headers,
             media_type=media_type,
+            background=background_tasks,
         )