feat: Check serialization for AWEL operator function

This commit is contained in:
Fangyin Cheng
2024-09-04 22:08:55 +08:00
parent b1ca247419
commit f8ce7d4580
12 changed files with 236 additions and 4 deletions

View File

@@ -145,6 +145,9 @@ class DAGVar:
_executor: Optional[Executor] = None
_variables_provider: Optional["VariablesProvider"] = None
# Whether check serializable for AWEL, it will be set to True when running AWEL
# operator in remote environment
_check_serializable: Optional[bool] = None
@classmethod
def enter_dag(cls, dag) -> None:
@@ -257,6 +260,24 @@ class DAGVar:
"""
cls._variables_provider = variables_provider
@classmethod
def get_check_serializable(cls) -> Optional[bool]:
"""Get the check serializable flag.
Returns:
Optional[bool]: The check serializable flag
"""
return cls._check_serializable
@classmethod
def set_check_serializable(cls, check_serializable: bool) -> None:
"""Set the check serializable flag.
Args:
check_serializable (bool): The check serializable flag to set
"""
cls._check_serializable = check_serializable
class DAGLifecycle:
"""The lifecycle of DAG."""
@@ -286,6 +307,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
check_serializable: Optional[bool] = None,
**kwargs,
) -> None:
"""Initialize a DAGNode.
@@ -311,6 +333,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
node_id = self._dag._new_node_id()
self._node_id: Optional[str] = node_id
self._node_name: Optional[str] = node_name
self._check_serializable = check_serializable
if self._dag:
self._dag._append_node(self)
@@ -486,6 +509,20 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
"""Return the string of current DAGNode."""
return self.__repr__()
@classmethod
def _do_check_serializable(cls, obj: Any, obj_name: str = "Object"):
"""Check whether the current DAGNode is serializable."""
from dbgpt.util.serialization.check import check_serializable
check_serializable(obj, obj_name)
@property
def check_serializable(self) -> bool:
"""Whether check serializable for current DAGNode."""
if self._check_serializable is not None:
return self._check_serializable or False
return DAGVar.get_check_serializable() or False
def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"