mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 02:51:07 +00:00
fix(core): Fix bug of sharing data across DAGs (#1102)
This commit is contained in:
@@ -446,27 +446,25 @@ class DAGContext:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_to_outputs: Dict[str, TaskContext],
|
||||
share_data: Dict[str, Any],
|
||||
streaming_call: bool = False,
|
||||
node_to_outputs: Optional[Dict[str, TaskContext]] = None,
|
||||
node_name_to_ids: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
"""Initialize a DAGContext.
|
||||
|
||||
Args:
|
||||
node_to_outputs (Dict[str, TaskContext]): The task outputs of current DAG.
|
||||
share_data (Dict[str, Any]): The share data of current DAG.
|
||||
streaming_call (bool, optional): Whether the current DAG is streaming call.
|
||||
Defaults to False.
|
||||
node_to_outputs (Optional[Dict[str, TaskContext]], optional):
|
||||
The task outputs of current DAG. Defaults to None.
|
||||
node_name_to_ids (Optional[Dict[str, str]], optional):
|
||||
The task name to task id mapping. Defaults to None.
|
||||
node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node
|
||||
"""
|
||||
if not node_to_outputs:
|
||||
node_to_outputs = {}
|
||||
if not node_name_to_ids:
|
||||
node_name_to_ids = {}
|
||||
self._streaming_call = streaming_call
|
||||
self._curr_task_ctx: Optional[TaskContext] = None
|
||||
self._share_data: Dict[str, Any] = {}
|
||||
self._share_data: Dict[str, Any] = share_data
|
||||
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
|
||||
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
|
||||
|
||||
@@ -530,6 +528,7 @@ class DAGContext:
|
||||
Returns:
|
||||
Any: The share data, you can cast it to the real type
|
||||
"""
|
||||
logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
|
||||
return self._share_data.get(key)
|
||||
|
||||
async def save_to_share_data(
|
||||
@@ -545,6 +544,7 @@ class DAGContext:
|
||||
"""
|
||||
if key in self._share_data and not overwrite:
|
||||
raise ValueError(f"Share data key {key} already exists")
|
||||
logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
|
||||
self._share_data[key] = data
|
||||
|
||||
async def get_task_share_data(self, task_name: str, key: str) -> Any:
|
||||
|
Reference in New Issue
Block a user