fix(core): Fix bug of sharing data across DAGs (#1102)

This commit is contained in:
Fangyin Cheng 2024-01-22 21:56:03 +08:00 committed by GitHub
parent 73c86ff083
commit 13527a8bd4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 96 additions and 55 deletions

View File

@@ -446,27 +446,25 @@ class DAGContext:
def __init__(
self,
node_to_outputs: Dict[str, TaskContext],
share_data: Dict[str, Any],
streaming_call: bool = False,
node_to_outputs: Optional[Dict[str, TaskContext]] = None,
node_name_to_ids: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize a DAGContext.
Args:
node_to_outputs (Dict[str, TaskContext]): The task outputs of current DAG.
share_data (Dict[str, Any]): The share data of current DAG.
streaming_call (bool, optional): Whether the current DAG is streaming call.
Defaults to False.
node_to_outputs (Optional[Dict[str, TaskContext]], optional):
The task outputs of current DAG. Defaults to None.
node_name_to_ids (Optional[Dict[str, str]], optional):
The task name to task id mapping. Defaults to None.
node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node id mapping. Defaults to None.
"""
if not node_to_outputs:
node_to_outputs = {}
if not node_name_to_ids:
node_name_to_ids = {}
self._streaming_call = streaming_call
self._curr_task_ctx: Optional[TaskContext] = None
self._share_data: Dict[str, Any] = {}
self._share_data: Dict[str, Any] = share_data
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
@@ -530,6 +528,7 @@ class DAGContext:
Returns:
Any: The share data, you can cast it to the real type
"""
logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
return self._share_data.get(key)
async def save_to_share_data(
@@ -545,6 +544,7 @@ class DAGContext:
"""
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
self._share_data[key] = data
async def get_task_share_data(self, task_name: str, key: str) -> Any:

View File

@@ -3,7 +3,7 @@
This runner will run the workflow in the current process.
"""
import logging
from typing import Dict, List, Optional, Set, cast
from typing import Any, Dict, List, Optional, Set, cast
from dbgpt.component import SystemApp
@@ -20,6 +20,10 @@ logger = logging.getLogger(__name__)
class DefaultWorkflowRunner(WorkflowRunner):
"""The default workflow runner."""
def __init__(self):
"""Init the default workflow runner."""
self._running_dag_ctx: Dict[str, DAGContext] = {}
async def execute_workflow(
self,
node: BaseOperator,
@@ -44,15 +48,22 @@ class DefaultWorkflowRunner(WorkflowRunner):
if not exist_dag_ctx:
# Create DAG context
node_outputs: Dict[str, TaskContext] = {}
share_data: Dict[str, Any] = {}
else:
# Share node output with exist dag context
node_outputs = exist_dag_ctx._node_to_outputs
share_data = exist_dag_ctx._share_data
dag_ctx = DAGContext(
streaming_call=streaming_call,
node_to_outputs=node_outputs,
share_data=share_data,
streaming_call=streaming_call,
node_name_to_ids=job_manager._node_name_to_ids,
)
logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
if node.dag:
self._running_dag_ctx[node.dag.dag_id] = dag_ctx
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
)
logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
skip_node_ids: Set[str] = set()
system_app: Optional[SystemApp] = DAGVar.get_current_system_app()
@@ -64,7 +75,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
if not streaming_call and node.dag:
# streaming call not work for dag end
await node.dag._after_dag_end()
if node.dag:
del self._running_dag_ctx[node.dag.dag_id]
return dag_ctx
async def _execute_node(

View File

@@ -14,11 +14,21 @@ if TYPE_CHECKING:
from starlette.requests import Request
RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
StreamingPredictFunc = Callable[[Union[Request, BaseModel, str, None]], bool]
logger = logging.getLogger(__name__)
class AWELHttpError(RuntimeError):
"""AWEL Http Error."""
def __init__(self, msg: str, code: Optional[str] = None):
"""Init the AWELHttpError."""
super().__init__(msg)
self.msg = msg
self.code = code
class HttpTrigger(Trigger):
"""Http trigger for AWEL.
@@ -65,27 +75,72 @@ class HttpTrigger(Trigger):
Args:
router (APIRouter): The router to mount the trigger.
"""
from fastapi import Depends
from inspect import Parameter, Signature
from typing import get_type_hints
from starlette.requests import Request
methods = [self._methods] if isinstance(self._methods, str) else self._methods
is_query_method = (
all(method in ["GET", "DELETE"] for method in methods) if methods else True
)
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
async def _request_body_dependency(request: Request):
return await _parse_request_body(request, self._req_body)
async def route_function(body=Depends(_request_body_dependency)):
async def _trigger_dag_func(body: Union[Request, BaseModel, str, None]):
streaming_response = self._streaming_response
if self._streaming_predict_func:
streaming_response = self._streaming_predict_func(body)
dag = self.dag
if not dag:
raise AWELHttpError("DAG is not set")
return await _trigger_dag(
body,
self.dag,
dag,
streaming_response,
self._response_headers,
self._response_media_type,
)
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
async def route_function_request(request: Request):
return await _trigger_dag_func(request)
async def route_function_none():
return await _trigger_dag_func(None)
route_function_request.__name__ = name
route_function_none.__name__ = name
if not req_body_cls:
return route_function_none
if req_body_cls == Request:
return route_function_request
if is_query_method:
if req_body_cls == str:
raise AWELHttpError(f"Query methods {methods} not support str type")
async def route_function_get(**kwargs):
body = req_body_cls(**kwargs)
return await _trigger_dag_func(body)
parameters = [
Parameter(
name=field_name,
kind=Parameter.KEYWORD_ONLY,
default=Parameter.empty,
annotation=field.outer_type_,
)
for field_name, field in req_body_cls.__fields__.items()
]
route_function_get.__signature__ = Signature(parameters) # type: ignore
route_function_get.__annotations__ = get_type_hints(req_body_cls)
route_function_get.__name__ = name
return route_function_get
else:
async def route_function(body: req_body_cls): # type: ignore
return await _trigger_dag_func(body)
route_function.__name__ = name
return route_function
@@ -111,32 +166,6 @@ class HttpTrigger(Trigger):
)(dynamic_route_function)
async def _parse_request_body(
request: "Request", request_body_cls: Optional["RequestBody"]
):
from starlette.requests import Request
if not request_body_cls:
return None
if request_body_cls == Request:
return request
if request.method == "POST":
if request_body_cls == str:
bytes_body = await request.body()
str_body = bytes_body.decode("utf-8")
return str_body
elif issubclass(request_body_cls, BaseModel):
json_data = await request.json()
return request_body_cls(**json_data)
else:
raise ValueError(f"Invalid request body cls: {request_body_cls}")
elif request.method == "GET":
if issubclass(request_body_cls, BaseModel):
return request_body_cls(**request.query_params)
else:
raise ValueError(f"Invalid request body cls: {request_body_cls}")
async def _trigger_dag(
body: Any,
dag: DAG,