chore: Add pylint for DB-GPT core lib (#1076)

This commit is contained in:
Fangyin Cheng
2024-01-16 17:36:26 +08:00
committed by GitHub
parent 3a54d1ef9a
commit 40c853575a
79 changed files with 2213 additions and 839 deletions

View File

@@ -0,0 +1 @@
"""The module of Task."""

View File

@@ -1,8 +1,10 @@
"""Base classes for task-related objects."""
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
@@ -17,6 +19,24 @@ OUT = TypeVar("OUT")
T = TypeVar("T")
class _EMPTY_DATA_TYPE:
def __bool__(self):
return False
EMPTY_DATA = _EMPTY_DATA_TYPE()
SKIP_DATA = _EMPTY_DATA_TYPE()
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()
MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
PredicateFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
JoinFunc = Union[Callable[..., OUT], Callable[..., Awaitable[OUT]]]
class TaskState(str, Enum):
"""Enumeration representing the state of a task in the workflow.
@@ -33,8 +53,8 @@ class TaskState(str, Enum):
class TaskOutput(ABC, Generic[T]):
"""Abstract base class representing the output of a task.
This class encapsulates the output of a task and provides methods to access the output data.
It can be subclassed to implement specific output behaviors.
This class encapsulates the output of a task and provides methods to access the
output data.It can be subclassed to implement specific output behaviors.
"""
@property
@@ -56,20 +76,30 @@ class TaskOutput(ABC, Generic[T]):
return False
@property
def output(self) -> Optional[T]:
def is_none(self) -> bool:
"""Check if the output is None.
Returns:
bool: True if the output is None, False otherwise.
"""
return False
@property
def output(self) -> T:
"""Return the output of the task.
Returns:
T: The output of the task. None if the output is empty.
T: The output of the task.
"""
raise NotImplementedError
@property
def output_stream(self) -> Optional[AsyncIterator[T]]:
def output_stream(self) -> AsyncIterator[T]:
"""Return the output of the task as an asynchronous stream.
Returns:
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
AsyncIterator[T]: An asynchronous iterator over the output. None if the
output is empty.
"""
raise NotImplementedError
@@ -83,39 +113,38 @@ class TaskOutput(ABC, Generic[T]):
@abstractmethod
def new_output(self) -> "TaskOutput[T]":
"""Create new output object"""
"""Create new output object."""
async def map(self, map_func) -> "TaskOutput[T]":
async def map(self, map_func: MapFunc) -> "TaskOutput[OUT]":
"""Apply a mapping function to the task's output.
Args:
map_func: A function to apply to the task's output.
map_func (MapFunc): A function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the mapping function.
TaskOutput[OUT]: The result of applying the mapping function.
"""
raise NotImplementedError
async def reduce(self, reduce_func) -> "TaskOutput[T]":
async def reduce(self, reduce_func: ReduceFunc) -> "TaskOutput[OUT]":
"""Apply a reducing function to the task's output.
Stream TaskOutput to Nonstream TaskOutput.
Stream TaskOutput to no stream TaskOutput.
Args:
reduce_func: A reducing function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the reducing function.
TaskOutput[OUT]: The result of applying the reducing function.
"""
raise NotImplementedError
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> "TaskOutput[T]":
async def streamify(self, transform_func: StreamFunc) -> "TaskOutput[T]":
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
Args:
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
transform_func (StreamFunc): Function to transform a T value into an
AsyncIterator[OUT].
Returns:
TaskOutput[T]: The result of applying the reducing function.
@@ -123,38 +152,39 @@ class TaskOutput(ABC, Generic[T]):
raise NotImplementedError
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
self, transform_func: TransformFunc
) -> "TaskOutput[OUT]":
"""Transform an AsyncIterator[T] to another AsyncIterator[T].
Args:
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
apply to the AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the reducing function.
"""
raise NotImplementedError
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> "TaskOutput[T]":
async def unstreamify(self, transform_func: UnStreamFunc) -> "TaskOutput[OUT]":
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
Args:
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
transform_func (UnStreamFunc): Function to transform an AsyncIterator[T]
into a T value.
Returns:
TaskOutput[T]: The result of applying the reducing function.
"""
raise NotImplementedError
async def check_condition(self, condition_func) -> bool:
async def check_condition(self, condition_func) -> "TaskOutput[OUT]":
"""Check if current output meets a given condition.
Args:
condition_func: A function to determine if the condition is met.
Returns:
bool: True if current output meet the condition, False otherwise.
TaskOutput[T]: The result of applying the reducing function.
If the condition is not met, return empty output.
"""
raise NotImplementedError
@@ -182,6 +212,9 @@ class TaskContext(ABC, Generic[T]):
Returns:
InputContext: The InputContext of current task.
Raises:
Exception: If the InputContext is not set.
"""
@abstractmethod
@@ -216,7 +249,7 @@ class TaskContext(ABC, Generic[T]):
@abstractmethod
def set_current_state(self, task_state: TaskState) -> None:
"""Set current task state
"""Set current task state.
Args:
task_state (TaskState): The task state to be set.
@@ -224,7 +257,7 @@ class TaskContext(ABC, Generic[T]):
@abstractmethod
def new_ctx(self) -> "TaskContext":
"""Create new task context
"""Create new task context.
Returns:
TaskContext: A new instance of a TaskContext.
@@ -233,14 +266,14 @@ class TaskContext(ABC, Generic[T]):
@property
@abstractmethod
def metadata(self) -> Dict[str, Any]:
"""Get the metadata of current task
"""Return the metadata of current task.
Returns:
Dict[str, Any]: The metadata
"""
def update_metadata(self, key: str, value: Any) -> None:
"""Update metadata with key and value
"""Update metadata with key and value.
Args:
key (str): The key of metadata
@@ -250,15 +283,15 @@ class TaskContext(ABC, Generic[T]):
@property
def call_data(self) -> Optional[Dict]:
"""Get the call data for current data"""
"""Return the call data for current data."""
return self.metadata.get("call_data")
@abstractmethod
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
"""Get the call data for current data"""
"""Get the call data for current data."""
def set_call_data(self, call_data: Dict) -> None:
"""Set call data for current task"""
"""Save the call data for current task."""
self.update_metadata("call_data", call_data)
@@ -315,7 +348,8 @@ class InputContext(ABC):
"""Filter the inputs based on a provided function.
Args:
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
filter_func (Callable[[Any], bool]): A function that returns True for
inputs to keep.
Returns:
InputContext: A new InputContext instance with the filtered inputs.
@@ -323,13 +357,15 @@ class InputContext(ABC):
@abstractmethod
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
self, predicate_func: PredicateFunc, failed_value: Any = None
) -> "InputContext":
"""Predicate the inputs based on a provided function.
Args:
predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True.
failed_value (Any): The value to be set if the return value of predicate function is False
predicate_func (Callable[[Any], bool]): A function that returns True for
inputs is predicate True.
failed_value (Any): The value to be set if the return value of predicate
function is False
Returns:
InputContext: A new InputContext instance with the predicate inputs.
"""

View File

@@ -1,3 +1,7 @@
"""The default implementation of Task.
This implementation can run workflow in local machine.
"""
import asyncio
import logging
from abc import ABC, abstractmethod
@@ -8,15 +12,32 @@ from typing import (
Coroutine,
Dict,
Generic,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState
from .base import (
_EMPTY_DATA_TYPE,
EMPTY_DATA,
OUT,
PLACEHOLDER_DATA,
SKIP_DATA,
InputContext,
InputSource,
MapFunc,
PredicateFunc,
ReduceFunc,
StreamFunc,
T,
TaskContext,
TaskOutput,
TaskState,
TransformFunc,
UnStreamFunc,
)
logger = logging.getLogger(__name__)
@@ -37,101 +58,197 @@ async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: T) -> None:
"""The default implementation of TaskOutput.
It wraps the no stream data and provide some basic data operations.
"""
def __init__(self, data: Union[T, _EMPTY_DATA_TYPE] = EMPTY_DATA) -> None:
"""Create a SimpleTaskOutput.
Args:
data (Union[T, _EMPTY_DATA_TYPE], optional): The output data. Defaults to
EMPTY_DATA.
"""
super().__init__()
self._data = data
@property
def output(self) -> T:
return self._data
"""Return the output data."""
if self._data == EMPTY_DATA:
raise ValueError("No output data for current task output")
return cast(T, self._data)
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
self._data = output_data
"""Save the output data to current object.
Args:
output_data (T | AsyncIterator[T]): The output data.
"""
if _is_async_iterator(output_data):
raise ValueError(
f"Can not set stream data {output_data} to SimpleTaskOutput"
)
self._data = cast(T, output_data)
def new_output(self) -> TaskOutput[T]:
return SimpleTaskOutput(None)
"""Create new output object with empty data."""
return SimpleTaskOutput()
@property
def is_empty(self) -> bool:
"""Return True if the output data is empty."""
return self._data == EMPTY_DATA or self._data == SKIP_DATA
@property
def is_none(self) -> bool:
"""Return True if the output data is None."""
return self._data is None
async def _apply_func(self, func) -> Any:
"""Apply the function to current output data."""
if asyncio.iscoroutinefunction(func):
out = await func(self._data)
else:
out = func(self._data)
return out
async def map(self, map_func) -> TaskOutput[T]:
async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
"""Apply a mapping function to the task's output.
Args:
map_func (MapFunc): A function to apply to the task's output.
Returns:
TaskOutput[OUT]: The result of applying the mapping function.
"""
out = await self._apply_func(map_func)
return SimpleTaskOutput(out)
async def check_condition(self, condition_func) -> bool:
return await self._apply_func(condition_func)
async def check_condition(self, condition_func) -> TaskOutput[OUT]:
"""Check the condition function."""
out = await self._apply_func(condition_func)
if out:
return SimpleTaskOutput(PLACEHOLDER_DATA)
return SimpleTaskOutput(EMPTY_DATA)
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> TaskOutput[T]:
async def streamify(self, transform_func: StreamFunc) -> TaskOutput[OUT]:
"""Transform the task's output to a stream output.
Args:
transform_func (StreamFunc): A function to transform the task's output to a
stream output.
Returns:
TaskOutput[OUT]: The result of transforming the task's output to a stream
output.
"""
out = await self._apply_func(transform_func)
return SimpleStreamTaskOutput(out)
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: AsyncIterator[T]) -> None:
"""The default stream implementation of TaskOutput."""
def __init__(
self, data: Union[AsyncIterator[T], _EMPTY_DATA_TYPE] = EMPTY_DATA
) -> None:
"""Create a SimpleStreamTaskOutput.
Args:
data (Union[AsyncIterator[T], _EMPTY_DATA_TYPE], optional): The output data.
Defaults to EMPTY_DATA.
"""
super().__init__()
self._data = data
@property
def is_stream(self) -> bool:
"""Return True if the output data is a stream."""
return True
@property
def is_empty(self) -> bool:
return not self._data
"""Return True if the output data is empty."""
return self._data == EMPTY_DATA or self._data == SKIP_DATA
@property
def is_none(self) -> bool:
"""Return True if the output data is None."""
return self._data is None
@property
def output_stream(self) -> AsyncIterator[T]:
return self._data
"""Return the output data.
Returns:
AsyncIterator[T]: The output data.
Raises:
ValueError: If the output data is empty.
"""
if self._data == EMPTY_DATA:
raise ValueError("No output data for current task output")
return cast(AsyncIterator[T], self._data)
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
self._data = output_data
"""Save the output data to current object.
Raises:
ValueError: If the output data is not a stream.
"""
if not _is_async_iterator(output_data):
raise ValueError(
f"Can not set non-stream data {output_data} to SimpleStreamTaskOutput"
)
self._data = cast(AsyncIterator[T], output_data)
def new_output(self) -> TaskOutput[T]:
return SimpleStreamTaskOutput(None)
"""Create new output object with empty data."""
return SimpleStreamTaskOutput()
async def map(self, map_func) -> TaskOutput[T]:
async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
"""Apply a mapping function to the task's output."""
is_async = asyncio.iscoroutinefunction(map_func)
async def new_iter() -> AsyncIterator[T]:
async for out in self._data:
async def new_iter() -> AsyncIterator[OUT]:
async for out in self.output_stream:
if is_async:
out = await map_func(out)
new_out: OUT = await map_func(out)
else:
out = map_func(out)
yield out
new_out = cast(OUT, map_func(out))
yield new_out
return SimpleStreamTaskOutput(new_iter())
async def reduce(self, reduce_func) -> TaskOutput[T]:
out = await _reduce_stream(self._data, reduce_func)
async def reduce(self, reduce_func: ReduceFunc) -> TaskOutput[OUT]:
"""Apply a reduce function to the task's output."""
out = await _reduce_stream(self.output_stream, reduce_func)
return SimpleTaskOutput(out)
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> TaskOutput[T]:
async def unstreamify(self, transform_func: UnStreamFunc) -> TaskOutput[OUT]:
"""Transform the task's output to a non-stream output."""
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
out = await transform_func(self.output_stream)
else:
out = transform_func(self._data)
out = transform_func(self.output_stream)
return SimpleTaskOutput(out)
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> TaskOutput[T]:
async def transform_stream(self, transform_func: TransformFunc) -> TaskOutput[OUT]:
"""Transform an AsyncIterator[T] to another AsyncIterator[T].
Args:
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
apply to the AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the reducing function.
"""
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
out: AsyncIterator[OUT] = await transform_func(self.output_stream)
else:
out = transform_func(self._data)
out = cast(AsyncIterator[OUT], transform_func(self.output_stream))
return SimpleStreamTaskOutput(out)
@@ -145,20 +262,34 @@ def _is_async_iterator(obj):
class BaseInputSource(InputSource, ABC):
"""The base class of InputSource."""
def __init__(self) -> None:
"""Create a BaseInputSource."""
super().__init__()
self._is_read = False
@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
"""Read data with task context"""
"""Return data with task context."""
async def read(self, task_ctx: TaskContext) -> TaskOutput:
"""Read data with task context.
Args:
task_ctx (TaskContext): The task context.
Returns:
TaskOutput: The task output.
Raises:
ValueError: If the input source is a stream and has been read.
"""
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output = SimpleStreamTaskOutput(data)
output: TaskOutput = SimpleStreamTaskOutput(data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
@@ -166,7 +297,14 @@ class BaseInputSource(InputSource, ABC):
class SimpleInputSource(BaseInputSource):
"""The default implementation of InputSource."""
def __init__(self, data: Any) -> None:
"""Create a SimpleInputSource.
Args:
data (Any): The input data.
"""
super().__init__()
self._data = data
@@ -175,63 +313,121 @@ class SimpleInputSource(BaseInputSource):
class SimpleCallDataInputSource(BaseInputSource):
"""The implementation of InputSource for call data."""
def __init__(self) -> None:
"""Create a SimpleCallDataInputSource."""
super().__init__()
def _read_data(self, task_ctx: TaskContext) -> Any:
"""Read data from task context.
Returns:
Any: The data.
Raises:
ValueError: If the call data is empty.
"""
call_data = task_ctx.call_data
data = call_data.get("data") if call_data else None
if not (call_data and data):
data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
if data == EMPTY_DATA:
raise ValueError("No call data for current SimpleCallDataInputSource")
return data
class DefaultTaskContext(TaskContext, Generic[T]):
"""The default implementation of TaskContext."""
def __init__(
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
self,
task_id: str,
task_state: TaskState,
task_output: Optional[TaskOutput[T]] = None,
) -> None:
"""Create a DefaultTaskContext.
Args:
task_id (str): The task id.
task_state (TaskState): The task state.
task_output (Optional[TaskOutput[T]], optional): The task output. Defaults
to None.
"""
super().__init__()
self._task_id = task_id
self._task_state = task_state
self._output = task_output
self._task_input = None
self._metadata = {}
self._output: Optional[TaskOutput[T]] = task_output
self._task_input: Optional[InputContext] = None
self._metadata: Dict[str, Any] = {}
@property
def task_id(self) -> str:
"""Return the task id."""
return self._task_id
@property
def task_input(self) -> InputContext:
"""Return the task input."""
if not self._task_input:
raise ValueError("No input for current task context")
return self._task_input
def set_task_input(self, input_ctx: "InputContext") -> None:
def set_task_input(self, input_ctx: InputContext) -> None:
"""Save the task input to current task."""
self._task_input = input_ctx
@property
def task_output(self) -> TaskOutput:
"""Return the task output.
Returns:
TaskOutput: The task output.
Raises:
ValueError: If the task output is empty.
"""
if not self._output:
raise ValueError("No output for current task context")
return self._output
def set_task_output(self, task_output: TaskOutput) -> None:
"""Save the task output to current task.
Args:
task_output (TaskOutput): The task output.
"""
self._output = task_output
@property
def current_state(self) -> TaskState:
"""Return the current task state."""
return self._task_state
def set_current_state(self, task_state: TaskState) -> None:
"""Save the current task state to current task."""
self._task_state = task_state
def new_ctx(self) -> TaskContext:
"""Create new task context with empty output."""
if not self._output:
raise ValueError("No output for current task context")
new_output = self._output.new_output()
return DefaultTaskContext(self._task_id, self._task_state, new_output)
@property
def metadata(self) -> Dict[str, Any]:
"""Return the metadata of current task.
Returns:
Dict[str, Any]: The metadata.
"""
return self._metadata
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
"""Get the call data for current data"""
"""Return the call data of current task.
Returns:
Optional[TaskOutput[T]]: The call data.
"""
call_data = self.call_data
if not call_data:
return None
@@ -240,24 +436,48 @@ class DefaultTaskContext(TaskContext, Generic[T]):
class DefaultInputContext(InputContext):
"""The default implementation of InputContext.
It wraps the all inputs from parent tasks and provide some basic data operations.
"""
def __init__(self, outputs: List[TaskContext]) -> None:
"""Create a DefaultInputContext.
Args:
outputs (List[TaskContext]): The outputs from parent tasks.
"""
super().__init__()
self._outputs = outputs
@property
def parent_outputs(self) -> List[TaskContext]:
"""Return the outputs from parent tasks.
Returns:
List[TaskContext]: The outputs from parent tasks.
"""
return self._outputs
async def _apply_func(
self, func: Callable[[Any], Any], apply_type: str = "map"
) -> Tuple[List[TaskContext], List[TaskOutput]]:
"""Apply the function to all parent outputs.
Args:
func (Callable[[Any], Any]): The function to apply.
apply_type (str, optional): The apply type. Defaults to "map".
Returns:
Tuple[List[TaskContext], List[TaskOutput]]: The new parent outputs and the
results of applying the function.
"""
new_outputs: List[TaskContext] = []
map_tasks = []
for out in self._outputs:
new_outputs.append(out.new_ctx())
result = None
if apply_type == "map":
result = out.task_output.map(func)
result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func)
elif apply_type == "reduce":
result = out.task_output.reduce(func)
elif apply_type == "check_condition":
@@ -269,29 +489,40 @@ class DefaultInputContext(InputContext):
return new_outputs, results
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
"""Apply a mapping function to all parent outputs."""
new_outputs, results = await self._apply_func(map_func)
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx = cast(TaskContext, task_ctx)
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
"""Apply a mapping function to all parent outputs.
The parent outputs will be unpacked and passed to the mapping function.
Args:
map_func (Callable[..., Any]): The mapping function.
Returns:
InputContext: The new input context.
"""
if not self._outputs:
return DefaultInputContext([])
# Some parent may be empty
not_empty_idx = 0
for i, p in enumerate(self._outputs):
if p.task_output.is_empty:
# Skip empty parent
continue
not_empty_idx = i
break
# All output is empty?
is_steam = self._outputs[not_empty_idx].task_output.is_stream
if is_steam:
if not self.check_stream(skip_empty=True):
raise ValueError(
"The output in all tasks must has same output format to map_all"
)
if is_steam and not self.check_stream(skip_empty=True):
raise ValueError(
"The output in all tasks must has same output format to map_all"
)
outputs = []
for out in self._outputs:
if out.task_output.is_stream:
@@ -305,22 +536,26 @@ class DefaultInputContext(InputContext):
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
single_output.task_output.set_output(map_res)
logger.debug(
f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
f"Current map_all map_res: {map_res}, is steam: "
f"{single_output.task_output.is_stream}"
)
return DefaultInputContext([single_output])
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
"""Apply a reduce function to all parent outputs."""
if not self.check_stream():
raise ValueError(
"The output in all tasks must has same output format of stream to apply reduce function"
"The output in all tasks must has same output format of stream to apply"
" reduce function"
)
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx = cast(TaskContext, task_ctx)
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
"""Filter all parent outputs."""
new_outputs, results = await self._apply_func(
filter_func, apply_type="check_condition"
)
@@ -331,15 +566,16 @@ class DefaultInputContext(InputContext):
return DefaultInputContext(result_outputs)
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
self, predicate_func: PredicateFunc, failed_value: Any = None
) -> "InputContext":
"""Apply a predicate function to all parent outputs."""
new_outputs, results = await self._apply_func(
predicate_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
if results[i]:
task_ctx = cast(TaskContext, task_ctx)
if not results[i].is_empty:
task_ctx.task_output.set_output(True)
result_outputs.append(task_ctx)
else: