fix(core): Fix bug of sharing data across DAGs (#1102)

Fangyin Cheng 2024-01-22 21:56:03 +08:00 committed by GitHub
parent 73c86ff083
commit 13527a8bd4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 96 additions and 55 deletions


@@ -446,27 +446,25 @@ class DAGContext:
     def __init__(
         self,
+        node_to_outputs: Dict[str, TaskContext],
+        share_data: Dict[str, Any],
         streaming_call: bool = False,
-        node_to_outputs: Optional[Dict[str, TaskContext]] = None,
         node_name_to_ids: Optional[Dict[str, str]] = None,
     ) -> None:
         """Initialize a DAGContext.

         Args:
+            node_to_outputs (Dict[str, TaskContext]): The task outputs of current DAG.
+            share_data (Dict[str, Any]): The share data of current DAG.
             streaming_call (bool, optional): Whether the current DAG is streaming call.
                 Defaults to False.
-            node_to_outputs (Optional[Dict[str, TaskContext]], optional):
-                The task outputs of current DAG. Defaults to None.
-            node_name_to_ids (Optional[Dict[str, str]], optional):
-                The task name to task id mapping. Defaults to None.
+            node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node
+                id mapping. Defaults to None.
         """
-        if not node_to_outputs:
-            node_to_outputs = {}
         if not node_name_to_ids:
             node_name_to_ids = {}
         self._streaming_call = streaming_call
         self._curr_task_ctx: Optional[TaskContext] = None
-        self._share_data: Dict[str, Any] = {}
+        self._share_data: Dict[str, Any] = share_data
         self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
         self._node_name_to_ids: Dict[str, str] = node_name_to_ids

@@ -530,6 +528,7 @@ class DAGContext:
         Returns:
             Any: The share data, you can cast it to the real type
         """
+        logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
         return self._share_data.get(key)

     async def save_to_share_data(

@@ -545,6 +544,7 @@ class DAGContext:
         """
         if key in self._share_data and not overwrite:
             raise ValueError(f"Share data key {key} already exists")
+        logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
         self._share_data[key] = data

     async def get_task_share_data(self, task_name: str, key: str) -> Any:
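
The change above threads a caller-supplied share_data dict into DAGContext instead of unconditionally creating a fresh one per context. A minimal sketch of the resulting behavior; it assumes DAGContext is importable from dbgpt.core.awel and that the getter shown above is get_from_share_data, as in this codebase:

import asyncio

from dbgpt.core.awel import DAGContext


async def main():
    share_data = {}
    # Both contexts are handed the same dict object.
    outer = DAGContext(node_to_outputs={}, share_data=share_data)
    inner = DAGContext(node_to_outputs={}, share_data=share_data)

    await outer.save_to_share_data("user_id", 42)
    # Before this fix, each DAGContext initialized its own empty dict,
    # so this lookup would have returned None.
    assert await inner.get_from_share_data("user_id") == 42


asyncio.run(main())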


@@ -3,7 +3,7 @@
 This runner will run the workflow in the current process.
 """
 import logging
-from typing import Dict, List, Optional, Set, cast
+from typing import Any, Dict, List, Optional, Set, cast

 from dbgpt.component import SystemApp

@@ -20,6 +20,10 @@ logger = logging.getLogger(__name__)
 class DefaultWorkflowRunner(WorkflowRunner):
     """The default workflow runner."""

+    def __init__(self):
+        """Init the default workflow runner."""
+        self._running_dag_ctx: Dict[str, DAGContext] = {}
+
     async def execute_workflow(
         self,
         node: BaseOperator,

@@ -44,15 +48,22 @@ class DefaultWorkflowRunner(WorkflowRunner):
         if not exist_dag_ctx:
             # Create DAG context
             node_outputs: Dict[str, TaskContext] = {}
+            share_data: Dict[str, Any] = {}
         else:
             # Share node output with exist dag context
             node_outputs = exist_dag_ctx._node_to_outputs
+            share_data = exist_dag_ctx._share_data
         dag_ctx = DAGContext(
-            streaming_call=streaming_call,
             node_to_outputs=node_outputs,
+            share_data=share_data,
+            streaming_call=streaming_call,
             node_name_to_ids=job_manager._node_name_to_ids,
         )
-        logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
+        if node.dag:
+            self._running_dag_ctx[node.dag.dag_id] = dag_ctx
+        logger.info(
+            f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
+        )
         logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
         skip_node_ids: Set[str] = set()
         system_app: Optional[SystemApp] = DAGVar.get_current_system_app()

@@ -64,7 +75,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
         if not streaming_call and node.dag:
             # streaming call not work for dag end
             await node.dag._after_dag_end()
+        if node.dag:
+            del self._running_dag_ctx[node.dag.dag_id]
         return dag_ctx

     async def _execute_node(
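
The runner now tracks running DAG contexts per instance: the context is registered under its dag_id before execution and deleted once the workflow finishes. A self-contained sketch of that register-run-cleanup pattern with a hypothetical Runner class (not the project's API); it uses try/finally, so the entry is also removed when execution raises, whereas the diff above deletes it only on the success path:

import asyncio
from typing import Any, Dict


class Runner:
    """Toy stand-in for DefaultWorkflowRunner's registry bookkeeping."""

    def __init__(self) -> None:
        self._running_dag_ctx: Dict[str, Any] = {}

    async def execute(self, dag_id: str, ctx: Any) -> Any:
        # Register before running so concurrent callers can look it up.
        self._running_dag_ctx[dag_id] = ctx
        try:
            await asyncio.sleep(0)  # stand-in for executing the workflow
            return ctx
        finally:
            # Removed even if the workflow raises.
            self._running_dag_ctx.pop(dag_id, None)


asyncio.run(Runner().execute("dag_1", {"share_data": {}}))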


@@ -14,11 +14,21 @@ if TYPE_CHECKING:
     from starlette.requests import Request

     RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
-    StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
+    StreamingPredictFunc = Callable[[Union[Request, BaseModel, str, None]], bool]

 logger = logging.getLogger(__name__)

+
+class AWELHttpError(RuntimeError):
+    """AWEL Http Error."""
+
+    def __init__(self, msg: str, code: Optional[str] = None):
+        """Init the AWELHttpError."""
+        super().__init__(msg)
+        self.msg = msg
+        self.code = code
+

 class HttpTrigger(Trigger):
     """Http trigger for AWEL.

@@ -65,29 +75,74 @@ class HttpTrigger:
         Args:
             router (APIRouter): The router to mount the trigger.
         """
-        from fastapi import Depends
+        from inspect import Parameter, Signature
+        from typing import get_type_hints
+
         from starlette.requests import Request

         methods = [self._methods] if isinstance(self._methods, str) else self._methods
+        is_query_method = (
+            all(method in ["GET", "DELETE"] for method in methods) if methods else True
+        )
+
+        async def _trigger_dag_func(body: Union[Request, BaseModel, str, None]):
+            streaming_response = self._streaming_response
+            if self._streaming_predict_func:
+                streaming_response = self._streaming_predict_func(body)
+            dag = self.dag
+            if not dag:
+                raise AWELHttpError("DAG is not set")
+            return await _trigger_dag(
+                body,
+                dag,
+                streaming_response,
+                self._response_headers,
+                self._response_media_type,
+            )

         def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
-            async def _request_body_dependency(request: Request):
-                return await _parse_request_body(request, self._req_body)
+            async def route_function_request(request: Request):
+                return await _trigger_dag_func(request)

-            async def route_function(body=Depends(_request_body_dependency)):
-                streaming_response = self._streaming_response
-                if self._streaming_predict_func:
-                    streaming_response = self._streaming_predict_func(body)
-                return await _trigger_dag(
-                    body,
-                    self.dag,
-                    streaming_response,
-                    self._response_headers,
-                    self._response_media_type,
-                )
+            async def route_function_none():
+                return await _trigger_dag_func(None)

-            route_function.__name__ = name
-            return route_function
+            route_function_request.__name__ = name
+            route_function_none.__name__ = name
+
+            if not req_body_cls:
+                return route_function_none
+            if req_body_cls == Request:
+                return route_function_request
+
+            if is_query_method:
+                if req_body_cls == str:
+                    raise AWELHttpError(f"Query methods {methods} not support str type")
+
+                async def route_function_get(**kwargs):
+                    body = req_body_cls(**kwargs)
+                    return await _trigger_dag_func(body)
+
+                parameters = [
+                    Parameter(
+                        name=field_name,
+                        kind=Parameter.KEYWORD_ONLY,
+                        default=Parameter.empty,
+                        annotation=field.outer_type_,
+                    )
+                    for field_name, field in req_body_cls.__fields__.items()
+                ]
+                route_function_get.__signature__ = Signature(parameters)  # type: ignore
+                route_function_get.__annotations__ = get_type_hints(req_body_cls)
+                route_function_get.__name__ = name
+                return route_function_get
+            else:
+
+                async def route_function(body: req_body_cls):  # type: ignore
+                    return await _trigger_dag_func(body)
+
+                route_function.__name__ = name
+                return route_function

         function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
         request_model = (
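
In the is_query_method branch above, GET/DELETE routes are built by synthesizing an inspect.Signature for a **kwargs handler from the Pydantic model's fields, so FastAPI maps each field to a query parameter rather than a request body. A standalone sketch of the same technique; the Cat model and /cats route are illustrative, and it assumes pydantic v1-style __fields__ entries with outer_type_, matching the code above:

from inspect import Parameter, Signature
from typing import get_type_hints

from fastapi import FastAPI
from pydantic import BaseModel


class Cat(BaseModel):
    name: str
    age: int


async def route_function_get(**kwargs):
    # FastAPI validates query params against the synthesized signature
    # below, so kwargs holds one entry per model field.
    return Cat(**kwargs)


route_function_get.__signature__ = Signature(
    [
        Parameter(
            name=field_name,
            kind=Parameter.KEYWORD_ONLY,
            default=Parameter.empty,
            annotation=field.outer_type_,
        )
        for field_name, field in Cat.__fields__.items()
    ]
)
route_function_get.__annotations__ = get_type_hints(Cat)

app = FastAPI()
# GET /cats?name=Tom&age=3 responds with {"name": "Tom", "age": 3}
app.get("/cats")(route_function_get)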
@@ -111,32 +166,6 @@ class HttpTrigger:
         )(dynamic_route_function)


-async def _parse_request_body(
-    request: "Request", request_body_cls: Optional["RequestBody"]
-):
-    from starlette.requests import Request
-
-    if not request_body_cls:
-        return None
-    if request_body_cls == Request:
-        return request
-    if request.method == "POST":
-        if request_body_cls == str:
-            bytes_body = await request.body()
-            str_body = bytes_body.decode("utf-8")
-            return str_body
-        elif issubclass(request_body_cls, BaseModel):
-            json_data = await request.json()
-            return request_body_cls(**json_data)
-        else:
-            raise ValueError(f"Invalid request body cls: {request_body_cls}")
-    elif request.method == "GET":
-        if issubclass(request_body_cls, BaseModel):
-            return request_body_cls(**request.query_params)
-        else:
-            raise ValueError(f"Invalid request body cls: {request_body_cls}")
-
-
 async def _trigger_dag(
     body: Any,
     dag: DAG,