fix(core): Fix bug of sharing data across DAGs (#1102)

This commit is contained in:
Fangyin Cheng
2024-01-22 21:56:03 +08:00
committed by GitHub
parent 73c86ff083
commit 13527a8bd4
3 changed files with 96 additions and 55 deletions

View File

@@ -446,27 +446,25 @@ class DAGContext:
def __init__(
self,
node_to_outputs: Dict[str, TaskContext],
share_data: Dict[str, Any],
streaming_call: bool = False,
node_to_outputs: Optional[Dict[str, TaskContext]] = None,
node_name_to_ids: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize a DAGContext.
Args:
node_to_outputs (Dict[str, TaskContext]): The task outputs of current DAG.
share_data (Dict[str, Any]): The share data of current DAG.
streaming_call (bool, optional): Whether the current DAG is streaming call.
Defaults to False.
node_to_outputs (Optional[Dict[str, TaskContext]], optional):
The task outputs of current DAG. Defaults to None.
node_name_to_ids (Optional[Dict[str, str]], optional):
The task name to task id mapping. Defaults to None.
node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node id mapping. Defaults to None.
"""
if not node_to_outputs:
node_to_outputs = {}
if not node_name_to_ids:
node_name_to_ids = {}
self._streaming_call = streaming_call
self._curr_task_ctx: Optional[TaskContext] = None
self._share_data: Dict[str, Any] = {}
self._share_data: Dict[str, Any] = share_data
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
@@ -530,6 +528,7 @@ class DAGContext:
Returns:
Any: The share data, you can cast it to the real type
"""
logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
return self._share_data.get(key)
async def save_to_share_data(
@@ -545,6 +544,7 @@ class DAGContext:
"""
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
self._share_data[key] = data
async def get_task_share_data(self, task_name: str, key: str) -> Any: