Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-19 08:47:32 +00:00)

fix(core): Fix bug of sharing data across DAGs (#1102)

parent 73c86ff083
commit 13527a8bd4
@@ -446,27 +446,25 @@ class DAGContext:

     def __init__(
         self,
+        node_to_outputs: Dict[str, TaskContext],
+        share_data: Dict[str, Any],
         streaming_call: bool = False,
-        node_to_outputs: Optional[Dict[str, TaskContext]] = None,
         node_name_to_ids: Optional[Dict[str, str]] = None,
     ) -> None:
         """Initialize a DAGContext.

         Args:
+            node_to_outputs (Dict[str, TaskContext]): The task outputs of current DAG.
+            share_data (Dict[str, Any]): The share data of current DAG.
             streaming_call (bool, optional): Whether the current DAG is streaming call.
                 Defaults to False.
-            node_to_outputs (Optional[Dict[str, TaskContext]], optional):
-                The task outputs of current DAG. Defaults to None.
-            node_name_to_ids (Optional[Dict[str, str]], optional):
-                The task name to task id mapping. Defaults to None.
+            node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node
+                id mapping. Defaults to None.
         """
-        if not node_to_outputs:
-            node_to_outputs = {}
         if not node_name_to_ids:
             node_name_to_ids = {}
         self._streaming_call = streaming_call
         self._curr_task_ctx: Optional[TaskContext] = None
-        self._share_data: Dict[str, Any] = {}
+        self._share_data: Dict[str, Any] = share_data
         self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
         self._node_name_to_ids: Dict[str, str] = node_name_to_ids
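This hunk is the core of the fix: previously every DAGContext built its own empty `_share_data` dict, so a DAG triggered from inside another DAG could never read data saved by its parent. Making `share_data` a required constructor argument lets the runner pass a single dict by reference. A minimal sketch of the aliasing behavior, using a hypothetical `Ctx` class rather than the real DB-GPT API:

from typing import Any, Dict


class Ctx:
    def __init__(self, share_data: Dict[str, Any]) -> None:
        # Keep a reference to the caller's dict instead of building a new one.
        self._share_data = share_data


shared: Dict[str, Any] = {}
parent = Ctx(shared)
child = Ctx(shared)  # a nested DAG run reuses the parent's dict

parent._share_data["user_id"] = 42
assert child._share_data["user_id"] == 42  # writes are visible to both contexts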
@@ -530,6 +528,7 @@ class DAGContext:

         Returns:
             Any: The share data, you can cast it to the real type
         """
+        logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
         return self._share_data.get(key)

     async def save_to_share_data(

@@ -545,6 +544,7 @@ class DAGContext:
         """
         if key in self._share_data and not overwrite:
             raise ValueError(f"Share data key {key} already exists")
+        logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
         self._share_data[key] = data

     async def get_task_share_data(self, task_name: str, key: str) -> Any:
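The two new `logger.debug` lines print `id(self._share_data)`, which makes the aliasing directly observable: two contexts share data if and only if they log the same object id. A small illustration of the diagnostic idea:

a: dict = {}
b = a        # alias: same object, writes are visible to both
c = dict(a)  # copy: different object, writes diverge

print(id(a) == id(b))  # True  -> sharing works
print(id(a) == id(c))  # False -> the pre-fix symptom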
@@ -3,7 +3,7 @@

 This runner will run the workflow in the current process.
 """
 import logging
-from typing import Dict, List, Optional, Set, cast
+from typing import Any, Dict, List, Optional, Set, cast

 from dbgpt.component import SystemApp
@@ -20,6 +20,10 @@ logger = logging.getLogger(__name__)
 class DefaultWorkflowRunner(WorkflowRunner):
     """The default workflow runner."""

+    def __init__(self):
+        """Init the default workflow runner."""
+        self._running_dag_ctx: Dict[str, DAGContext] = {}
+
     async def execute_workflow(
         self,
         node: BaseOperator,
@@ -44,15 +48,22 @@ class DefaultWorkflowRunner(WorkflowRunner):
         if not exist_dag_ctx:
             # Create DAG context
             node_outputs: Dict[str, TaskContext] = {}
+            share_data: Dict[str, Any] = {}
         else:
             # Share node output with exist dag context
             node_outputs = exist_dag_ctx._node_to_outputs
+            share_data = exist_dag_ctx._share_data
         dag_ctx = DAGContext(
-            streaming_call=streaming_call,
             node_to_outputs=node_outputs,
+            share_data=share_data,
+            streaming_call=streaming_call,
             node_name_to_ids=job_manager._node_name_to_ids,
         )
-        logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
+        if node.dag:
+            self._running_dag_ctx[node.dag.dag_id] = dag_ctx
+        logger.info(
+            f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
+        )
         logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
         skip_node_ids: Set[str] = set()
         system_app: Optional[SystemApp] = DAGVar.get_current_system_app()
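The runner now decides once, up front, whether this is a fresh top-level run (new empty dicts) or a nested run inside an existing DAG context (inherit the parent's dicts). A compact sketch of that branch, using a hypothetical helper name:

from typing import Any, Dict, Optional, Tuple


def resolve_shared_state(
    exist_dag_ctx: Optional["DAGContext"],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Hypothetical helper mirroring the branch above.
    if not exist_dag_ctx:
        return {}, {}  # fresh run: new node outputs and new share data
    # Nested run: reuse the parent's dicts so state is visible both ways.
    return exist_dag_ctx._node_to_outputs, exist_dag_ctx._share_data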
@@ -64,7 +75,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
         if not streaming_call and node.dag:
             # streaming call not work for dag end
             await node.dag._after_dag_end()
-
+        if node.dag:
+            del self._running_dag_ctx[node.dag.dag_id]
         return dag_ctx

     async def _execute_node(
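`_running_dag_ctx` gives the runner a per-DAG registry: the context is registered under its `dag_id` when a run starts and deleted when it finishes. A sketch of the same register/cleanup pattern; it uses `try`/`finally` so the entry is also dropped on errors, which the patch itself does not require:

import asyncio
from typing import Any, Dict


class SketchRunner:
    def __init__(self) -> None:
        self._running: Dict[str, Dict[str, Any]] = {}

    async def execute(self, dag_id: str) -> Dict[str, Any]:
        # Reuse the live context for this DAG if one exists, else start fresh.
        ctx = self._running.get(dag_id, {"share_data": {}})
        self._running[dag_id] = ctx
        try:
            await asyncio.sleep(0)  # ... run the workflow here ...
            return ctx
        finally:
            self._running.pop(dag_id, None)  # always drop the entry


asyncio.run(SketchRunner().execute("dag_1"))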
@@ -14,11 +14,21 @@ if TYPE_CHECKING:
     from starlette.requests import Request

 RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
-StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
+StreamingPredictFunc = Callable[[Union[Request, BaseModel, str, None]], bool]

 logger = logging.getLogger(__name__)


+class AWELHttpError(RuntimeError):
+    """AWEL Http Error."""
+
+    def __init__(self, msg: str, code: Optional[str] = None):
+        """Init the AWELHttpError."""
+        super().__init__(msg)
+        self.msg = msg
+        self.code = code
+
+
 class HttpTrigger(Trigger):
     """Http trigger for AWEL.
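Widening `StreamingPredictFunc` to accept `str` and `None` matches the new route functions, which may pass a raw string body or no body at all. A hypothetical predicate of the widened type:

from typing import Union

from pydantic import BaseModel
from starlette.requests import Request


def predict_streaming(body: Union[Request, BaseModel, str, None]) -> bool:
    # Decide per request whether the DAG should stream its response.
    if body is None:
        return False
    if isinstance(body, str):
        return "stream" in body
    return bool(getattr(body, "stream", False))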
@@ -65,27 +75,72 @@ class HttpTrigger(Trigger):
         Args:
             router (APIRouter): The router to mount the trigger.
         """
-        from fastapi import Depends
+        from inspect import Parameter, Signature
+        from typing import get_type_hints

         from starlette.requests import Request

         methods = [self._methods] if isinstance(self._methods, str) else self._methods
+        is_query_method = (
+            all(method in ["GET", "DELETE"] for method in methods) if methods else True
+        )

-        def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
-            async def _request_body_dependency(request: Request):
-                return await _parse_request_body(request, self._req_body)
-
-            async def route_function(body=Depends(_request_body_dependency)):
+        async def _trigger_dag_func(body: Union[Request, BaseModel, str, None]):
             streaming_response = self._streaming_response
             if self._streaming_predict_func:
                 streaming_response = self._streaming_predict_func(body)
+            dag = self.dag
+            if not dag:
+                raise AWELHttpError("DAG is not set")
             return await _trigger_dag(
                 body,
-                self.dag,
+                dag,
                 streaming_response,
                 self._response_headers,
                 self._response_media_type,
             )

+        def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
+            async def route_function_request(request: Request):
+                return await _trigger_dag_func(request)
+
+            async def route_function_none():
+                return await _trigger_dag_func(None)
+
+            route_function_request.__name__ = name
+            route_function_none.__name__ = name
+
+            if not req_body_cls:
+                return route_function_none
+            if req_body_cls == Request:
+                return route_function_request
+
+            if is_query_method:
+                if req_body_cls == str:
+                    raise AWELHttpError(f"Query methods {methods} not support str type")
+
+                async def route_function_get(**kwargs):
+                    body = req_body_cls(**kwargs)
+                    return await _trigger_dag_func(body)
+
+                parameters = [
+                    Parameter(
+                        name=field_name,
+                        kind=Parameter.KEYWORD_ONLY,
+                        default=Parameter.empty,
+                        annotation=field.outer_type_,
+                    )
+                    for field_name, field in req_body_cls.__fields__.items()
+                ]
+                route_function_get.__signature__ = Signature(parameters)  # type: ignore
+                route_function_get.__annotations__ = get_type_hints(req_body_cls)
+                route_function_get.__name__ = name
+                return route_function_get
+            else:
+
+                async def route_function(body: req_body_cls):  # type: ignore
+                    return await _trigger_dag_func(body)
+
+                route_function.__name__ = name
+                return route_function
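The interesting trick here is `route_function_get`: FastAPI inspects a handler's signature to determine its query parameters, so the code synthesizes an `inspect.Signature` from the Pydantic model's fields and attaches it to a generic `**kwargs` handler. A standalone sketch of the technique with hypothetical model and handler names; the diff's `__fields__`/`field.outer_type_` are Pydantic v1 APIs, while the sketch uses `get_type_hints` for simplicity:

from inspect import Parameter, Signature, signature
from typing import get_type_hints

from pydantic import BaseModel


class Filters(BaseModel):
    name: str
    limit: int


async def handler(**kwargs):
    # FastAPI validates each query parameter against the synthesized signature.
    return Filters(**kwargs)


handler.__signature__ = Signature(
    [
        Parameter(field_name, kind=Parameter.KEYWORD_ONLY, annotation=hint)
        for field_name, hint in get_type_hints(Filters).items()
    ]
)
print(signature(handler))  # (*, name: str, limit: int)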
@@ -111,32 +166,6 @@ class HttpTrigger(Trigger):
         )(dynamic_route_function)


-async def _parse_request_body(
-    request: "Request", request_body_cls: Optional["RequestBody"]
-):
-    from starlette.requests import Request
-
-    if not request_body_cls:
-        return None
-    if request_body_cls == Request:
-        return request
-    if request.method == "POST":
-        if request_body_cls == str:
-            bytes_body = await request.body()
-            str_body = bytes_body.decode("utf-8")
-            return str_body
-        elif issubclass(request_body_cls, BaseModel):
-            json_data = await request.json()
-            return request_body_cls(**json_data)
-        else:
-            raise ValueError(f"Invalid request body cls: {request_body_cls}")
-    elif request.method == "GET":
-        if issubclass(request_body_cls, BaseModel):
-            return request_body_cls(**request.query_params)
-        else:
-            raise ValueError(f"Invalid request body cls: {request_body_cls}")
-
-
 async def _trigger_dag(
     body: Any,
     dag: DAG,
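With typed route functions generated per request-body class, the manual `_parse_request_body` helper becomes dead code: FastAPI itself parses JSON bodies into Pydantic models and validates query parameters. A minimal example of the framework behavior the deleted helper used to reimplement, with a hypothetical `EchoBody` model:

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class EchoBody(BaseModel):
    text: str


@app.post("/echo")
async def echo(body: EchoBody) -> dict:
    # FastAPI parses and validates the JSON body into EchoBody automatically.
    return {"text": body.text}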