refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
FangYin Cheng
2023-12-08 14:45:59 +08:00
committed by GitHub
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

View File

@@ -0,0 +1,371 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TypeVar,
Generic,
Optional,
AsyncIterator,
Union,
Callable,
Any,
Dict,
List,
)
IN = TypeVar("IN")
OUT = TypeVar("OUT")
T = TypeVar("T")
class TaskState(str, Enum):
"""Enumeration representing the state of a task in the workflow.
This Enum defines various states a task can be in during its lifecycle in the DAG.
"""
INIT = "init" # Initial state of the task, not yet started
SKIP = "skip" # State indicating the task was skipped
RUNNING = "running" # State indicating the task is currently running
SUCCESS = "success" # State indicating the task completed successfully
FAILED = "failed" # State indicating the task failed during execution
class TaskOutput(ABC, Generic[T]):
"""Abstract base class representing the output of a task.
This class encapsulates the output of a task and provides methods to access the output data.
It can be subclassed to implement specific output behaviors.
"""
@property
def is_stream(self) -> bool:
"""Check if the output is a stream.
Returns:
bool: True if the output is a stream, False otherwise.
"""
return False
@property
def is_empty(self) -> bool:
"""Check if the output is empty.
Returns:
bool: True if the output is empty, False otherwise.
"""
return False
@property
def output(self) -> Optional[T]:
"""Return the output of the task.
Returns:
T: The output of the task. None if the output is empty.
"""
raise NotImplementedError
@property
def output_stream(self) -> Optional[AsyncIterator[T]]:
"""Return the output of the task as an asynchronous stream.
Returns:
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
"""
raise NotImplementedError
@abstractmethod
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
"""Set the output data to current object.
Args:
output_data (Union[T, AsyncIterator[T]]): Output data.
"""
@abstractmethod
def new_output(self) -> "TaskOutput[T]":
"""Create new output object"""
async def map(self, map_func) -> "TaskOutput[T]":
"""Apply a mapping function to the task's output.
Args:
map_func: A function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the mapping function.
"""
raise NotImplementedError
async def reduce(self, reduce_func) -> "TaskOutput[T]":
"""Apply a reducing function to the task's output.
Convert a stream TaskOutput to a non-stream TaskOutput.
Args:
reduce_func: A reducing function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the reducing function.
"""
raise NotImplementedError
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
Args:
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
Args:
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> "TaskOutput[T]":
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
Args:
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def check_condition(self, condition_func) -> bool:
"""Check if current output meets a given condition.
Args:
condition_func: A function to determine if the condition is met.
Returns:
bool: True if the current output meets the condition, False otherwise.
"""
raise NotImplementedError
class TaskContext(ABC, Generic[T]):
"""Abstract base class representing the context of a task within a DAG.
This class provides the interface for accessing task-related information
and manipulating task output.
"""
@property
@abstractmethod
def task_id(self) -> str:
"""Return the unique identifier of the task.
Returns:
str: The unique identifier of the task.
"""
@property
@abstractmethod
def task_input(self) -> "InputContext":
"""Return the InputContext of current task.
Returns:
InputContext: The InputContext of current task.
"""
@abstractmethod
def set_task_input(self, input_ctx: "InputContext") -> None:
"""Set the InputContext object to current task.
Args:
input_ctx (InputContext): The InputContext of current task
"""
@property
@abstractmethod
def task_output(self) -> TaskOutput[T]:
"""Return the output object of the task.
Returns:
TaskOutput[T]: The output object of the task.
"""
@abstractmethod
def set_task_output(self, task_output: TaskOutput[T]) -> None:
"""Set the output object to current task."""
@property
@abstractmethod
def current_state(self) -> TaskState:
"""Get the current state of the task.
Returns:
TaskState: The current state of the task.
"""
@abstractmethod
def set_current_state(self, task_state: TaskState) -> None:
"""Set current task state
Args:
task_state (TaskState): The task state to be set.
"""
@abstractmethod
def new_ctx(self) -> "TaskContext":
"""Create new task context
Returns:
TaskContext: A new instance of a TaskContext.
"""
@property
@abstractmethod
def metadata(self) -> Dict[str, Any]:
"""Get the metadata of current task
Returns:
Dict[str, Any]: The metadata
"""
def update_metadata(self, key: str, value: Any) -> None:
"""Update metadata with key and value
Args:
key (str): The key of metadata
value (str): The value to be add to metadata
"""
self.metadata[key] = value
@property
def call_data(self) -> Optional[Dict]:
"""Get the call data for current data"""
return self.metadata.get("call_data")
@abstractmethod
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
"""Get the call data for current data"""
def set_call_data(self, call_data: Dict) -> None:
"""Set call data for current task"""
self.update_metadata("call_data", call_data)
class InputContext(ABC):
"""Abstract base class representing the context of inputs to a operator node.
This class defines methods to manipulate and access the inputs for a operator node.
"""
@property
@abstractmethod
def parent_outputs(self) -> List[TaskContext]:
"""Get the outputs from the parent nodes.
Returns:
List[TaskContext]: A list of contexts of the parent nodes' outputs.
"""
@abstractmethod
async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
"""Apply a mapping function to the inputs.
Args:
map_func (Callable[[Any], Any]): A function to be applied to the inputs.
Returns:
InputContext: A new InputContext instance with the mapped inputs.
"""
@abstractmethod
async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
"""Apply a mapping function to all inputs.
Args:
map_func (Callable[..., Any]): A function to be applied to all inputs.
Returns:
InputContext: A new InputContext instance with the mapped inputs.
"""
@abstractmethod
async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
"""Apply a reducing function to the inputs.
Args:
reduce_func (Callable[[Any], Any]): A function that reduces the inputs.
Returns:
InputContext: A new InputContext instance with the reduced inputs.
"""
@abstractmethod
async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
"""Filter the inputs based on a provided function.
Args:
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
Returns:
InputContext: A new InputContext instance with the filtered inputs.
"""
@abstractmethod
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
) -> "InputContext":
"""Predicate the inputs based on a provided function.
Args:
predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True.
failed_value (Any): The value to be set if the return value of predicate function is False
Returns:
InputContext: A new InputContext instance with the predicate inputs.
"""
def check_single_parent(self) -> bool:
"""Check if there is only a single parent output.
Returns:
bool: True if there is only one parent output, False otherwise.
"""
return len(self.parent_outputs) == 1
def check_stream(self, skip_empty: bool = False) -> bool:
"""Check if all parent outputs are streams.
Args:
skip_empty (bool): Whether to skip empty outputs.
Returns:
bool: True if all parent outputs are streams, False otherwise.
"""
for out in self.parent_outputs:
if out.task_output.is_empty and skip_empty:
continue
if not (out.task_output and out.task_output.is_stream):
return False
return True
class InputSource(ABC, Generic[T]):
"""Abstract base class representing the source of inputs to a DAG node."""
@abstractmethod
async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
"""Read the data from current input source.
Returns:
TaskOutput[T]: The output object read from current source
"""

View File

@@ -0,0 +1,348 @@
from abc import ABC, abstractmethod
from typing import (
Callable,
Coroutine,
Iterator,
AsyncIterator,
List,
Generic,
TypeVar,
Any,
Tuple,
Dict,
Union,
Optional,
)
import asyncio
import logging
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
logger = logging.getLogger(__name__)
async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
# Init accumulator
try:
accumulator = await stream.__anext__()
except StopAsyncIteration:
raise ValueError("Stream is empty")
is_async = asyncio.iscoroutinefunction(reduce_function)
async for element in stream:
if is_async:
accumulator = await reduce_function(accumulator, element)
else:
accumulator = reduce_function(accumulator, element)
return accumulator
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: T) -> None:
super().__init__()
self._data = data
@property
def output(self) -> T:
return self._data
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
self._data = output_data
def new_output(self) -> TaskOutput[T]:
return SimpleTaskOutput(None)
@property
def is_empty(self) -> bool:
return not self._data
async def _apply_func(self, func) -> Any:
if asyncio.iscoroutinefunction(func):
out = await func(self._data)
else:
out = func(self._data)
return out
async def map(self, map_func) -> TaskOutput[T]:
out = await self._apply_func(map_func)
return SimpleTaskOutput(out)
async def check_condition(self, condition_func) -> bool:
return await self._apply_func(condition_func)
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> TaskOutput[T]:
out = await self._apply_func(transform_func)
return SimpleStreamTaskOutput(out)
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: AsyncIterator[T]) -> None:
super().__init__()
self._data = data
@property
def is_stream(self) -> bool:
return True
@property
def is_empty(self) -> bool:
return not self._data
@property
def output_stream(self) -> AsyncIterator[T]:
return self._data
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
self._data = output_data
def new_output(self) -> TaskOutput[T]:
return SimpleStreamTaskOutput(None)
async def map(self, map_func) -> TaskOutput[T]:
is_async = asyncio.iscoroutinefunction(map_func)
async def new_iter() -> AsyncIterator[T]:
async for out in self._data:
if is_async:
out = await map_func(out)
else:
out = map_func(out)
yield out
return SimpleStreamTaskOutput(new_iter())
async def reduce(self, reduce_func) -> TaskOutput[T]:
out = await _reduce_stream(self._data, reduce_func)
return SimpleTaskOutput(out)
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> TaskOutput[T]:
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
else:
out = transform_func(self._data)
return SimpleTaskOutput(out)
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> TaskOutput[T]:
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
else:
out = transform_func(self._data)
return SimpleStreamTaskOutput(out)
def _is_async_iterator(obj):
return (
hasattr(obj, "__anext__")
and callable(getattr(obj, "__anext__", None))
and hasattr(obj, "__aiter__")
and callable(getattr(obj, "__aiter__", None))
)
class BaseInputSource(InputSource, ABC):
def __init__(self) -> None:
super().__init__()
self._is_read = False
@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
"""Read data with task context"""
async def read(self, task_ctx: TaskContext) -> TaskOutput:
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output = SimpleStreamTaskOutput(data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
return output
class SimpleInputSource(BaseInputSource):
def __init__(self, data: Any) -> None:
super().__init__()
self._data = data
def _read_data(self, task_ctx: TaskContext) -> Any:
return self._data
class SimpleCallDataInputSource(BaseInputSource):
def __init__(self) -> None:
super().__init__()
def _read_data(self, task_ctx: TaskContext) -> Any:
call_data = task_ctx.call_data
data = call_data.get("data") if call_data else None
if not (call_data and data):
raise ValueError("No call data for current SimpleCallDataInputSource")
return data
class DefaultTaskContext(TaskContext, Generic[T]):
def __init__(
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
) -> None:
super().__init__()
self._task_id = task_id
self._task_state = task_state
self._output = task_output
self._task_input = None
self._metadata = {}
@property
def task_id(self) -> str:
return self._task_id
@property
def task_input(self) -> InputContext:
return self._task_input
def set_task_input(self, input_ctx: "InputContext") -> None:
self._task_input = input_ctx
@property
def task_output(self) -> TaskOutput:
return self._output
def set_task_output(self, task_output: TaskOutput) -> None:
self._output = task_output
@property
def current_state(self) -> TaskState:
return self._task_state
def set_current_state(self, task_state: TaskState) -> None:
self._task_state = task_state
def new_ctx(self) -> TaskContext:
new_output = self._output.new_output()
return DefaultTaskContext(self._task_id, self._task_state, new_output)
@property
def metadata(self) -> Dict[str, Any]:
return self._metadata
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
"""Get the call data for current data"""
call_data = self.call_data
if not call_data:
return None
input_source = SimpleCallDataInputSource()
return await input_source.read(self)
class DefaultInputContext(InputContext):
def __init__(self, outputs: List[TaskContext]) -> None:
super().__init__()
self._outputs = outputs
@property
def parent_outputs(self) -> List[TaskContext]:
return self._outputs
async def _apply_func(
self, func: Callable[[Any], Any], apply_type: str = "map"
) -> Tuple[List[TaskContext], List[TaskOutput]]:
new_outputs: List[TaskContext] = []
map_tasks = []
for out in self._outputs:
new_outputs.append(out.new_ctx())
result = None
if apply_type == "map":
result = out.task_output.map(func)
elif apply_type == "reduce":
result = out.task_output.reduce(func)
elif apply_type == "check_condition":
result = out.task_output.check_condition(func)
else:
raise ValueError(f"Unsupport apply type {apply_type}")
map_tasks.append(result)
results = await asyncio.gather(*map_tasks)
return new_outputs, results
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
new_outputs, results = await self._apply_func(map_func)
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
if not self._outputs:
return DefaultInputContext([])
# Some parents may be empty, so find the first non-empty output
not_empty_idx = 0
for i, p in enumerate(self._outputs):
if p.task_output.is_empty:
continue
not_empty_idx = i
break
# Note: the case where every output is empty is not handled here
is_stream = self._outputs[not_empty_idx].task_output.is_stream
if is_stream:
if not self.check_stream(skip_empty=True):
raise ValueError(
"The outputs of all tasks must have the same output format to apply map_all"
)
outputs = []
for out in self._outputs:
if out.task_output.is_stream:
outputs.append(out.task_output.output_stream)
else:
outputs.append(out.task_output.output)
if asyncio.iscoroutinefunction(map_func):
map_res = await map_func(*outputs)
else:
map_res = map_func(*outputs)
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
single_output.task_output.set_output(map_res)
logger.debug(
f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
)
return DefaultInputContext([single_output])
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
if not self.check_stream():
raise ValueError(
"The output in all tasks must has same output format of stream to apply reduce function"
)
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
new_outputs, results = await self._apply_func(
filter_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
if results[i]:
result_outputs.append(task_ctx)
return DefaultInputContext(result_outputs)
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
) -> "InputContext":
new_outputs, results = await self._apply_func(
predicate_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
if results[i]:
task_ctx.task_output.set_output(True)
result_outputs.append(task_ctx)
else:
task_ctx.task_output.set_output(failed_value)
result_outputs.append(task_ctx)
return DefaultInputContext(result_outputs)
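For reference, a small usage sketch of the concrete classes above (a hypothetical example, not part of the commit): it wraps a plain value and an async generator into task contexts, applies map through DefaultInputContext, and reduces the streaming branch back to a single value. It assumes the classes defined in this file are importable in the current scope; task ids and values are made up for illustration.

import asyncio

async def _demo() -> None:
    # Non-streaming parent: a plain value wrapped in SimpleTaskOutput.
    ctx_a = DefaultTaskContext("task-a", TaskState.SUCCESS, SimpleTaskOutput(2))

    # Streaming parent: an async generator wrapped in SimpleStreamTaskOutput.
    async def numbers():
        for i in range(3):
            yield i

    ctx_b = DefaultTaskContext(
        "task-b", TaskState.SUCCESS, SimpleStreamTaskOutput(numbers())
    )

    # map() applies the function to each parent output (element-wise for streams).
    mapped = await DefaultInputContext([ctx_a, ctx_b]).map(lambda x: x * 10)
    print(mapped.parent_outputs[0].task_output.output)  # 20

    # Reduce the streaming branch back to a single value: 0 + 10 + 20 = 30.
    reduced = await mapped.parent_outputs[1].task_output.reduce(lambda acc, x: acc + x)
    print(reduced.output)  # 30

if __name__ == "__main__":
    asyncio.run(_demo())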