feat: Check serialization for AWEL operator function

This commit is contained in:
Fangyin Cheng
2024-09-04 22:08:55 +08:00
parent b1ca247419
commit f8ce7d4580
12 changed files with 236 additions and 4 deletions

View File

@@ -145,6 +145,9 @@ class DAGVar:
_executor: Optional[Executor] = None
_variables_provider: Optional["VariablesProvider"] = None
# Whether to check serializability for AWEL; it will be set to True when
# running an AWEL operator in a remote environment
_check_serializable: Optional[bool] = None
@classmethod
def enter_dag(cls, dag) -> None:
@@ -257,6 +260,24 @@ class DAGVar:
"""
cls._variables_provider = variables_provider
@classmethod
def get_check_serializable(cls) -> Optional[bool]:
    """Return the class-level serialization-check flag.

    Returns:
        Optional[bool]: ``True``/``False`` when the flag has been explicitly
            set via :meth:`set_check_serializable`, otherwise ``None``
            (meaning "not configured").
    """
    return cls._check_serializable
@classmethod
def set_check_serializable(cls, check_serializable: bool) -> None:
    """Set the class-level serialization-check flag.

    Args:
        check_serializable (bool): New value for the flag; ``True`` enables
            serializability checks for AWEL operator functions.
    """
    cls._check_serializable = check_serializable
class DAGLifecycle:
"""The lifecycle of DAG."""
@@ -286,6 +307,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
check_serializable: Optional[bool] = None,
**kwargs,
) -> None:
"""Initialize a DAGNode.
@@ -311,6 +333,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
node_id = self._dag._new_node_id()
self._node_id: Optional[str] = node_id
self._node_name: Optional[str] = node_name
self._check_serializable = check_serializable
if self._dag:
self._dag._append_node(self)
@@ -486,6 +509,20 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
"""Return the string of current DAGNode."""
return self.__repr__()
@classmethod
def _do_check_serializable(cls, obj: Any, obj_name: str = "Object"):
    """Verify that ``obj`` can be serialized, raising on failure.

    Delegates to ``dbgpt.util.serialization.check.check_serializable``.
    The import is kept local so the check machinery is only loaded when a
    check is actually requested.

    Args:
        obj (Any): The object to inspect for serializability.
        obj_name (str): Human-readable label used in the error report.
    """
    from dbgpt.util.serialization.check import check_serializable as _check

    _check(obj, obj_name)
@property
def check_serializable(self) -> bool:
    """Whether serializability checks apply to this DAGNode.

    The per-node setting (given at construction) takes precedence; when it
    is unset (``None``) the process-wide ``DAGVar`` flag is consulted.
    Always returns a plain ``bool`` — an unset flag counts as ``False``.
    """
    node_flag = self._check_serializable
    if node_flag is None:
        return bool(DAGVar.get_check_serializable())
    return bool(node_flag)
def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"

View File

@@ -193,12 +193,29 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
self.incremental_output = bool(kwargs["incremental_output"])
if "output_format" in kwargs:
self.output_format = kwargs["output_format"]
self._runner: WorkflowRunner = runner
self._dag_ctx: Optional[DAGContext] = None
self._can_skip_in_branch = can_skip_in_branch
self._variables_provider = variables_provider
def __getstate__(self):
"""Customize the pickling process."""
state = self.__dict__.copy()
if "_runner" in state:
del state["_runner"]
if "_executor" in state:
del state["_executor"]
if "_system_app" in state:
del state["_system_app"]
return state
def __setstate__(self, state):
    """Customize unpickling: restore state, then rebuild transient fields.

    The attributes stripped by ``__getstate__`` are re-resolved from the
    current process context instead of being carried across the wire.
    """
    self.__dict__.update(state)
    # Re-attach process-local collaborators dropped during pickling.
    self._runner = default_runner
    self._executor = DAGVar.get_executor()
    self._system_app = DAGVar.get_current_system_app()
@property
def current_dag_context(self) -> DAGContext:
"""Return the current DAG context."""

View File

@@ -41,6 +41,12 @@ class JoinOperator(BaseOperator, Generic[OUT]):
super().__init__(can_skip_in_branch=can_skip_in_branch, **kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")
if self.check_serializable:
super()._do_check_serializable(
combine_function,
f"JoinOperator: {self}, combine_function: {combine_function}",
)
self.combine_function = combine_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@@ -83,6 +89,11 @@ class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
super().__init__(**kwargs)
if reduce_function and not callable(reduce_function):
raise ValueError("reduce_function must be callable")
if reduce_function and self.check_serializable:
super()._do_check_serializable(
reduce_function, f"Operator: {self}, reduce_function: {reduce_function}"
)
self.reduce_function = reduce_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@@ -133,6 +144,12 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
super().__init__(**kwargs)
if map_function and not callable(map_function):
raise ValueError("map_function must be callable")
if map_function and self.check_serializable:
super()._do_check_serializable(
map_function, f"Operator: {self}, map_function: {map_function}"
)
self.map_function = map_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: