mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-24 02:53:24 +00:00
fix(core): Fix bug of sharing data across DAGs (#1102)
This commit is contained in:
parent
73c86ff083
commit
13527a8bd4
@ -446,27 +446,25 @@ class DAGContext:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
node_to_outputs: Dict[str, TaskContext],
|
||||||
|
share_data: Dict[str, Any],
|
||||||
streaming_call: bool = False,
|
streaming_call: bool = False,
|
||||||
node_to_outputs: Optional[Dict[str, TaskContext]] = None,
|
|
||||||
node_name_to_ids: Optional[Dict[str, str]] = None,
|
node_name_to_ids: Optional[Dict[str, str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a DAGContext.
|
"""Initialize a DAGContext.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
node_to_outputs (Dict[str, TaskContext]): The task outputs of current DAG.
|
||||||
|
share_data (Dict[str, Any]): The share data of current DAG.
|
||||||
streaming_call (bool, optional): Whether the current DAG is streaming call.
|
streaming_call (bool, optional): Whether the current DAG is streaming call.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
node_to_outputs (Optional[Dict[str, TaskContext]], optional):
|
node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node
|
||||||
The task outputs of current DAG. Defaults to None.
|
|
||||||
node_name_to_ids (Optional[Dict[str, str]], optional):
|
|
||||||
The task name to task id mapping. Defaults to None.
|
|
||||||
"""
|
"""
|
||||||
if not node_to_outputs:
|
|
||||||
node_to_outputs = {}
|
|
||||||
if not node_name_to_ids:
|
if not node_name_to_ids:
|
||||||
node_name_to_ids = {}
|
node_name_to_ids = {}
|
||||||
self._streaming_call = streaming_call
|
self._streaming_call = streaming_call
|
||||||
self._curr_task_ctx: Optional[TaskContext] = None
|
self._curr_task_ctx: Optional[TaskContext] = None
|
||||||
self._share_data: Dict[str, Any] = {}
|
self._share_data: Dict[str, Any] = share_data
|
||||||
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
|
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
|
||||||
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
|
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
|
||||||
|
|
||||||
@ -530,6 +528,7 @@ class DAGContext:
|
|||||||
Returns:
|
Returns:
|
||||||
Any: The share data, you can cast it to the real type
|
Any: The share data, you can cast it to the real type
|
||||||
"""
|
"""
|
||||||
|
logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
|
||||||
return self._share_data.get(key)
|
return self._share_data.get(key)
|
||||||
|
|
||||||
async def save_to_share_data(
|
async def save_to_share_data(
|
||||||
@ -545,6 +544,7 @@ class DAGContext:
|
|||||||
"""
|
"""
|
||||||
if key in self._share_data and not overwrite:
|
if key in self._share_data and not overwrite:
|
||||||
raise ValueError(f"Share data key {key} already exists")
|
raise ValueError(f"Share data key {key} already exists")
|
||||||
|
logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
|
||||||
self._share_data[key] = data
|
self._share_data[key] = data
|
||||||
|
|
||||||
async def get_task_share_data(self, task_name: str, key: str) -> Any:
|
async def get_task_share_data(self, task_name: str, key: str) -> Any:
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
This runner will run the workflow in the current process.
|
This runner will run the workflow in the current process.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Set, cast
|
from typing import Any, Dict, List, Optional, Set, cast
|
||||||
|
|
||||||
from dbgpt.component import SystemApp
|
from dbgpt.component import SystemApp
|
||||||
|
|
||||||
@ -20,6 +20,10 @@ logger = logging.getLogger(__name__)
|
|||||||
class DefaultWorkflowRunner(WorkflowRunner):
|
class DefaultWorkflowRunner(WorkflowRunner):
|
||||||
"""The default workflow runner."""
|
"""The default workflow runner."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Init the default workflow runner."""
|
||||||
|
self._running_dag_ctx: Dict[str, DAGContext] = {}
|
||||||
|
|
||||||
async def execute_workflow(
|
async def execute_workflow(
|
||||||
self,
|
self,
|
||||||
node: BaseOperator,
|
node: BaseOperator,
|
||||||
@ -44,15 +48,22 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
|||||||
if not exist_dag_ctx:
|
if not exist_dag_ctx:
|
||||||
# Create DAG context
|
# Create DAG context
|
||||||
node_outputs: Dict[str, TaskContext] = {}
|
node_outputs: Dict[str, TaskContext] = {}
|
||||||
|
share_data: Dict[str, Any] = {}
|
||||||
else:
|
else:
|
||||||
# Share node output with exist dag context
|
# Share node output with exist dag context
|
||||||
node_outputs = exist_dag_ctx._node_to_outputs
|
node_outputs = exist_dag_ctx._node_to_outputs
|
||||||
|
share_data = exist_dag_ctx._share_data
|
||||||
dag_ctx = DAGContext(
|
dag_ctx = DAGContext(
|
||||||
streaming_call=streaming_call,
|
|
||||||
node_to_outputs=node_outputs,
|
node_to_outputs=node_outputs,
|
||||||
|
share_data=share_data,
|
||||||
|
streaming_call=streaming_call,
|
||||||
node_name_to_ids=job_manager._node_name_to_ids,
|
node_name_to_ids=job_manager._node_name_to_ids,
|
||||||
)
|
)
|
||||||
logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
|
if node.dag:
|
||||||
|
self._running_dag_ctx[node.dag.dag_id] = dag_ctx
|
||||||
|
logger.info(
|
||||||
|
f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
|
||||||
|
)
|
||||||
logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
|
logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
|
||||||
skip_node_ids: Set[str] = set()
|
skip_node_ids: Set[str] = set()
|
||||||
system_app: Optional[SystemApp] = DAGVar.get_current_system_app()
|
system_app: Optional[SystemApp] = DAGVar.get_current_system_app()
|
||||||
@ -64,7 +75,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
|||||||
if not streaming_call and node.dag:
|
if not streaming_call and node.dag:
|
||||||
# streaming call not work for dag end
|
# streaming call not work for dag end
|
||||||
await node.dag._after_dag_end()
|
await node.dag._after_dag_end()
|
||||||
|
if node.dag:
|
||||||
|
del self._running_dag_ctx[node.dag.dag_id]
|
||||||
return dag_ctx
|
return dag_ctx
|
||||||
|
|
||||||
async def _execute_node(
|
async def _execute_node(
|
||||||
|
@ -14,11 +14,21 @@ if TYPE_CHECKING:
|
|||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
|
RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
|
||||||
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
|
StreamingPredictFunc = Callable[[Union[Request, BaseModel, str, None]], bool]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AWELHttpError(RuntimeError):
|
||||||
|
"""AWEL Http Error."""
|
||||||
|
|
||||||
|
def __init__(self, msg: str, code: Optional[str] = None):
|
||||||
|
"""Init the AWELHttpError."""
|
||||||
|
super().__init__(msg)
|
||||||
|
self.msg = msg
|
||||||
|
self.code = code
|
||||||
|
|
||||||
|
|
||||||
class HttpTrigger(Trigger):
|
class HttpTrigger(Trigger):
|
||||||
"""Http trigger for AWEL.
|
"""Http trigger for AWEL.
|
||||||
|
|
||||||
@ -65,29 +75,74 @@ class HttpTrigger(Trigger):
|
|||||||
Args:
|
Args:
|
||||||
router (APIRouter): The router to mount the trigger.
|
router (APIRouter): The router to mount the trigger.
|
||||||
"""
|
"""
|
||||||
from fastapi import Depends
|
from inspect import Parameter, Signature
|
||||||
|
from typing import get_type_hints
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
methods = [self._methods] if isinstance(self._methods, str) else self._methods
|
methods = [self._methods] if isinstance(self._methods, str) else self._methods
|
||||||
|
is_query_method = (
|
||||||
|
all(method in ["GET", "DELETE"] for method in methods) if methods else True
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _trigger_dag_func(body: Union[Request, BaseModel, str, None]):
|
||||||
|
streaming_response = self._streaming_response
|
||||||
|
if self._streaming_predict_func:
|
||||||
|
streaming_response = self._streaming_predict_func(body)
|
||||||
|
dag = self.dag
|
||||||
|
if not dag:
|
||||||
|
raise AWELHttpError("DAG is not set")
|
||||||
|
return await _trigger_dag(
|
||||||
|
body,
|
||||||
|
dag,
|
||||||
|
streaming_response,
|
||||||
|
self._response_headers,
|
||||||
|
self._response_media_type,
|
||||||
|
)
|
||||||
|
|
||||||
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
|
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
|
||||||
async def _request_body_dependency(request: Request):
|
async def route_function_request(request: Request):
|
||||||
return await _parse_request_body(request, self._req_body)
|
return await _trigger_dag_func(request)
|
||||||
|
|
||||||
async def route_function(body=Depends(_request_body_dependency)):
|
async def route_function_none():
|
||||||
streaming_response = self._streaming_response
|
return await _trigger_dag_func(None)
|
||||||
if self._streaming_predict_func:
|
|
||||||
streaming_response = self._streaming_predict_func(body)
|
|
||||||
return await _trigger_dag(
|
|
||||||
body,
|
|
||||||
self.dag,
|
|
||||||
streaming_response,
|
|
||||||
self._response_headers,
|
|
||||||
self._response_media_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
route_function.__name__ = name
|
route_function_request.__name__ = name
|
||||||
return route_function
|
route_function_none.__name__ = name
|
||||||
|
|
||||||
|
if not req_body_cls:
|
||||||
|
return route_function_none
|
||||||
|
if req_body_cls == Request:
|
||||||
|
return route_function_request
|
||||||
|
|
||||||
|
if is_query_method:
|
||||||
|
if req_body_cls == str:
|
||||||
|
raise AWELHttpError(f"Query methods {methods} not support str type")
|
||||||
|
|
||||||
|
async def route_function_get(**kwargs):
|
||||||
|
body = req_body_cls(**kwargs)
|
||||||
|
return await _trigger_dag_func(body)
|
||||||
|
|
||||||
|
parameters = [
|
||||||
|
Parameter(
|
||||||
|
name=field_name,
|
||||||
|
kind=Parameter.KEYWORD_ONLY,
|
||||||
|
default=Parameter.empty,
|
||||||
|
annotation=field.outer_type_,
|
||||||
|
)
|
||||||
|
for field_name, field in req_body_cls.__fields__.items()
|
||||||
|
]
|
||||||
|
route_function_get.__signature__ = Signature(parameters) # type: ignore
|
||||||
|
route_function_get.__annotations__ = get_type_hints(req_body_cls)
|
||||||
|
route_function_get.__name__ = name
|
||||||
|
return route_function_get
|
||||||
|
else:
|
||||||
|
|
||||||
|
async def route_function(body: req_body_cls): # type: ignore
|
||||||
|
return await _trigger_dag_func(body)
|
||||||
|
|
||||||
|
route_function.__name__ = name
|
||||||
|
return route_function
|
||||||
|
|
||||||
function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
|
function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
|
||||||
request_model = (
|
request_model = (
|
||||||
@ -111,32 +166,6 @@ class HttpTrigger(Trigger):
|
|||||||
)(dynamic_route_function)
|
)(dynamic_route_function)
|
||||||
|
|
||||||
|
|
||||||
async def _parse_request_body(
|
|
||||||
request: "Request", request_body_cls: Optional["RequestBody"]
|
|
||||||
):
|
|
||||||
from starlette.requests import Request
|
|
||||||
|
|
||||||
if not request_body_cls:
|
|
||||||
return None
|
|
||||||
if request_body_cls == Request:
|
|
||||||
return request
|
|
||||||
if request.method == "POST":
|
|
||||||
if request_body_cls == str:
|
|
||||||
bytes_body = await request.body()
|
|
||||||
str_body = bytes_body.decode("utf-8")
|
|
||||||
return str_body
|
|
||||||
elif issubclass(request_body_cls, BaseModel):
|
|
||||||
json_data = await request.json()
|
|
||||||
return request_body_cls(**json_data)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid request body cls: {request_body_cls}")
|
|
||||||
elif request.method == "GET":
|
|
||||||
if issubclass(request_body_cls, BaseModel):
|
|
||||||
return request_body_cls(**request.query_params)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid request body cls: {request_body_cls}")
|
|
||||||
|
|
||||||
|
|
||||||
async def _trigger_dag(
|
async def _trigger_dag(
|
||||||
body: Any,
|
body: Any,
|
||||||
dag: DAG,
|
dag: DAG,
|
||||||
|
Loading…
Reference in New Issue
Block a user