chore: Merge latest code

This commit is contained in:
Fangyin Cheng
2024-08-30 15:02:53 +08:00
parent bf63a967b5
commit 0e71991f7e
8 changed files with 169 additions and 22 deletions

View File

@@ -619,6 +619,7 @@ class DAGContext:
self._node_name_to_ids: Dict[str, str] = node_name_to_ids
self._event_loop_task_id = event_loop_task_id
self._dag_variables = dag_variables
self._share_data_lock = asyncio.Lock()
@property
def _task_outputs(self) -> Dict[str, TaskContext]:
@@ -680,8 +681,9 @@ class DAGContext:
Returns:
Any: The share data, you can cast it to the real type
"""
logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
return self._share_data.get(key)
async with self._share_data_lock:
logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
return self._share_data.get(key)
async def save_to_share_data(
self, key: str, data: Any, overwrite: bool = False
@@ -694,10 +696,11 @@ class DAGContext:
overwrite (bool): Whether overwrite the share data if the key
already exists. Defaults to None.
"""
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
self._share_data[key] = data
async with self._share_data_lock:
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
self._share_data[key] = data
async def get_task_share_data(self, task_name: str, key: str) -> Any:
"""Get share data by task name and key.