mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 10:05:13 +00:00
chore: Add pylint for DB-GPT core lib (#1076)
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""The module of Task."""
|
||||
|
@@ -1,8 +1,10 @@
|
||||
"""Base classes for task-related objects."""
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
@@ -17,6 +19,24 @@ OUT = TypeVar("OUT")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class _EMPTY_DATA_TYPE:
    """Falsy sentinel type used to mark task data that carries no real value.

    The class defines no ``__eq__``, so two sentinels are equal only when they
    are the same object.

    Args:
        name (str, optional): Label shown by ``repr`` so the different
            sentinels can be told apart in logs and error messages. Defaults
            to "EMPTY_DATA".
    """

    def __init__(self, name: str = "EMPTY_DATA") -> None:
        """Create a sentinel with a human-readable name."""
        self.name = name

    def __bool__(self) -> bool:
        """Return False so that every sentinel is falsy."""
        return False

    def __repr__(self) -> str:
        """Return a readable representation that includes the sentinel name."""
        return f"EmptyData({self.name})"


# Default payload of a freshly created task output (no data set yet).
EMPTY_DATA = _EMPTY_DATA_TYPE("EMPTY_DATA")
# Treated as "empty" by the task outputs' ``is_empty`` checks, like EMPTY_DATA.
SKIP_DATA = _EMPTY_DATA_TYPE("SKIP_DATA")
# Returned by ``check_condition`` when the condition holds: falsy, but not
# counted as empty.
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE("PLACEHOLDER_DATA")
|
||||
|
||||
MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
|
||||
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
|
||||
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
|
||||
UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
|
||||
TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
|
||||
PredicateFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
|
||||
JoinFunc = Union[Callable[..., OUT], Callable[..., Awaitable[OUT]]]
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
"""Enumeration representing the state of a task in the workflow.
|
||||
|
||||
@@ -33,8 +53,8 @@ class TaskState(str, Enum):
|
||||
class TaskOutput(ABC, Generic[T]):
|
||||
"""Abstract base class representing the output of a task.
|
||||
|
||||
This class encapsulates the output of a task and provides methods to access the output data.
|
||||
It can be subclassed to implement specific output behaviors.
|
||||
This class encapsulates the output of a task and provides methods to access the
|
||||
output data. It can be subclassed to implement specific output behaviors.
|
||||
"""
|
||||
|
||||
@property
|
||||
@@ -56,20 +76,30 @@ class TaskOutput(ABC, Generic[T]):
|
||||
return False
|
||||
|
||||
@property
|
||||
def output(self) -> Optional[T]:
|
||||
def is_none(self) -> bool:
|
||||
"""Check if the output is None.
|
||||
|
||||
Returns:
|
||||
bool: True if the output is None, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def output(self) -> T:
|
||||
"""Return the output of the task.
|
||||
|
||||
Returns:
|
||||
T: The output of the task. None if the output is empty.
|
||||
T: The output of the task.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_stream(self) -> Optional[AsyncIterator[T]]:
|
||||
def output_stream(self) -> AsyncIterator[T]:
|
||||
"""Return the output of the task as an asynchronous stream.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
|
||||
AsyncIterator[T]: An asynchronous iterator over the output. None if the
|
||||
output is empty.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -83,39 +113,38 @@ class TaskOutput(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
def new_output(self) -> "TaskOutput[T]":
|
||||
"""Create new output object"""
|
||||
"""Create new output object."""
|
||||
|
||||
async def map(self, map_func) -> "TaskOutput[T]":
|
||||
async def map(self, map_func: MapFunc) -> "TaskOutput[OUT]":
|
||||
"""Apply a mapping function to the task's output.
|
||||
|
||||
Args:
|
||||
map_func: A function to apply to the task's output.
|
||||
map_func (MapFunc): A function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the mapping function.
|
||||
TaskOutput[OUT]: The result of applying the mapping function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reduce(self, reduce_func) -> "TaskOutput[T]":
|
||||
async def reduce(self, reduce_func: ReduceFunc) -> "TaskOutput[OUT]":
|
||||
"""Apply a reducing function to the task's output.
|
||||
|
||||
Stream TaskOutput to Nonstream TaskOutput.
|
||||
Convert a stream TaskOutput to a non-stream TaskOutput.
|
||||
|
||||
Args:
|
||||
reduce_func: A reducing function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
TaskOutput[OUT]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def streamify(
|
||||
self, transform_func: Callable[[T], AsyncIterator[T]]
|
||||
) -> "TaskOutput[T]":
|
||||
async def streamify(self, transform_func: StreamFunc) -> "TaskOutput[T]":
|
||||
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
|
||||
transform_func (StreamFunc): Function to transform a T value into an
|
||||
AsyncIterator[OUT].
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The stream output produced by the transform function.
|
||||
@@ -123,38 +152,39 @@ class TaskOutput(ABC, Generic[T]):
|
||||
raise NotImplementedError
|
||||
|
||||
async def transform_stream(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
|
||||
self, transform_func: TransformFunc
|
||||
) -> "TaskOutput[OUT]":
|
||||
"""Transform an AsyncIterator[T] to another AsyncIterator[T].
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
|
||||
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
|
||||
apply to the AsyncIterator[T].
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The result of applying the transform function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def unstreamify(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], T]
|
||||
) -> "TaskOutput[T]":
|
||||
async def unstreamify(self, transform_func: UnStreamFunc) -> "TaskOutput[OUT]":
|
||||
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
|
||||
transform_func (UnStreamFunc): Function to transform an AsyncIterator[T]
|
||||
into a T value.
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The result of applying the transform function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def check_condition(self, condition_func) -> bool:
|
||||
async def check_condition(self, condition_func) -> "TaskOutput[OUT]":
|
||||
"""Check if current output meets a given condition.
|
||||
|
||||
Args:
|
||||
condition_func: A function to determine if the condition is met.
|
||||
Returns:
|
||||
bool: True if current output meet the condition, False otherwise.
|
||||
TaskOutput[OUT]: A non-empty output if the condition is met.
|
||||
If the condition is not met, return empty output.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -182,6 +212,9 @@ class TaskContext(ABC, Generic[T]):
|
||||
|
||||
Returns:
|
||||
InputContext: The InputContext of current task.
|
||||
|
||||
Raises:
|
||||
Exception: If the InputContext is not set.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -216,7 +249,7 @@ class TaskContext(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
def set_current_state(self, task_state: TaskState) -> None:
|
||||
"""Set current task state
|
||||
"""Set current task state.
|
||||
|
||||
Args:
|
||||
task_state (TaskState): The task state to be set.
|
||||
@@ -224,7 +257,7 @@ class TaskContext(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
def new_ctx(self) -> "TaskContext":
|
||||
"""Create new task context
|
||||
"""Create new task context.
|
||||
|
||||
Returns:
|
||||
TaskContext: A new instance of a TaskContext.
|
||||
@@ -233,14 +266,14 @@ class TaskContext(ABC, Generic[T]):
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
"""Get the metadata of current task
|
||||
"""Return the metadata of current task.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The metadata
|
||||
"""
|
||||
|
||||
def update_metadata(self, key: str, value: Any) -> None:
|
||||
"""Update metadata with key and value
|
||||
"""Update metadata with key and value.
|
||||
|
||||
Args:
|
||||
key (str): The key of metadata
|
||||
@@ -250,15 +283,15 @@ class TaskContext(ABC, Generic[T]):
|
||||
|
||||
@property
|
||||
def call_data(self) -> Optional[Dict]:
|
||||
"""Get the call data for current data"""
|
||||
"""Return the call data for current data."""
|
||||
return self.metadata.get("call_data")
|
||||
|
||||
@abstractmethod
|
||||
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
|
||||
"""Get the call data for current data"""
|
||||
"""Get the call data for current data."""
|
||||
|
||||
def set_call_data(self, call_data: Dict) -> None:
|
||||
"""Set call data for current task"""
|
||||
"""Save the call data for current task."""
|
||||
self.update_metadata("call_data", call_data)
|
||||
|
||||
|
||||
@@ -315,7 +348,8 @@ class InputContext(ABC):
|
||||
"""Filter the inputs based on a provided function.
|
||||
|
||||
Args:
|
||||
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
|
||||
filter_func (Callable[[Any], bool]): A function that returns True for
|
||||
inputs to keep.
|
||||
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the filtered inputs.
|
||||
@@ -323,13 +357,15 @@ class InputContext(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def predicate_map(
|
||||
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
|
||||
self, predicate_func: PredicateFunc, failed_value: Any = None
|
||||
) -> "InputContext":
|
||||
"""Predicate the inputs based on a provided function.
|
||||
|
||||
Args:
|
||||
predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True.
|
||||
failed_value (Any): The value to be set if the return value of predicate function is False
|
||||
predicate_func (Callable[[Any], bool]): A function that returns True for
|
||||
inputs that satisfy the predicate.
|
||||
failed_value (Any): The value to be set if the return value of predicate
|
||||
function is False
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the predicate inputs.
|
||||
"""
|
||||
|
@@ -1,3 +1,7 @@
|
||||
"""The default implementation of Task.
|
||||
|
||||
This implementation can run workflow in local machine.
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -8,15 +12,32 @@ from typing import (
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState
|
||||
from .base import (
|
||||
_EMPTY_DATA_TYPE,
|
||||
EMPTY_DATA,
|
||||
OUT,
|
||||
PLACEHOLDER_DATA,
|
||||
SKIP_DATA,
|
||||
InputContext,
|
||||
InputSource,
|
||||
MapFunc,
|
||||
PredicateFunc,
|
||||
ReduceFunc,
|
||||
StreamFunc,
|
||||
T,
|
||||
TaskContext,
|
||||
TaskOutput,
|
||||
TaskState,
|
||||
TransformFunc,
|
||||
UnStreamFunc,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,101 +58,197 @@ async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
|
||||
|
||||
|
||||
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
|
||||
def __init__(self, data: T) -> None:
|
||||
"""The default implementation of TaskOutput.
|
||||
|
||||
It wraps non-stream data and provides some basic data operations.
|
||||
"""
|
||||
|
||||
def __init__(self, data: Union[T, _EMPTY_DATA_TYPE] = EMPTY_DATA) -> None:
|
||||
"""Create a SimpleTaskOutput.
|
||||
|
||||
Args:
|
||||
data (Union[T, _EMPTY_DATA_TYPE], optional): The output data. Defaults to
|
||||
EMPTY_DATA.
|
||||
"""
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def output(self) -> T:
|
||||
return self._data
|
||||
"""Return the output data."""
|
||||
if self._data == EMPTY_DATA:
|
||||
raise ValueError("No output data for current task output")
|
||||
return cast(T, self._data)
|
||||
|
||||
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
|
||||
self._data = output_data
|
||||
"""Save the output data to current object.
|
||||
|
||||
Args:
|
||||
output_data (T | AsyncIterator[T]): The output data.
|
||||
"""
|
||||
if _is_async_iterator(output_data):
|
||||
raise ValueError(
|
||||
f"Can not set stream data {output_data} to SimpleTaskOutput"
|
||||
)
|
||||
self._data = cast(T, output_data)
|
||||
|
||||
def new_output(self) -> TaskOutput[T]:
|
||||
return SimpleTaskOutput(None)
|
||||
"""Create new output object with empty data."""
|
||||
return SimpleTaskOutput()
|
||||
|
||||
@property
def is_empty(self) -> bool:
    """Return True if the output data is empty.

    The data counts as empty when it is the ``EMPTY_DATA`` or ``SKIP_DATA``
    sentinel. ``None`` and other falsy values are real data and are not
    treated as empty.

    Returns:
        bool: True if the stored data is one of the empty sentinels.
    """
    data = self._data
    # Identity, not ``==``: equality would call the payload's own ``__eq__``,
    # which for array-like or DataFrame-like data returns a non-boolean
    # result and raises when evaluated with ``or``.
    return data is EMPTY_DATA or data is SKIP_DATA
|
||||
|
||||
@property
|
||||
def is_none(self) -> bool:
|
||||
"""Return True if the output data is None."""
|
||||
return self._data is None
|
||||
|
||||
async def _apply_func(self, func) -> Any:
|
||||
"""Apply the function to current output data."""
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
out = await func(self._data)
|
||||
else:
|
||||
out = func(self._data)
|
||||
return out
|
||||
|
||||
async def map(self, map_func) -> TaskOutput[T]:
|
||||
async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
|
||||
"""Apply a mapping function to the task's output.
|
||||
|
||||
Args:
|
||||
map_func (MapFunc): A function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The result of applying the mapping function.
|
||||
"""
|
||||
out = await self._apply_func(map_func)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def check_condition(self, condition_func) -> bool:
|
||||
return await self._apply_func(condition_func)
|
||||
async def check_condition(self, condition_func) -> TaskOutput[OUT]:
|
||||
"""Check the condition function."""
|
||||
out = await self._apply_func(condition_func)
|
||||
if out:
|
||||
return SimpleTaskOutput(PLACEHOLDER_DATA)
|
||||
return SimpleTaskOutput(EMPTY_DATA)
|
||||
|
||||
async def streamify(
|
||||
self, transform_func: Callable[[T], AsyncIterator[T]]
|
||||
) -> TaskOutput[T]:
|
||||
async def streamify(self, transform_func: StreamFunc) -> TaskOutput[OUT]:
|
||||
"""Transform the task's output to a stream output.
|
||||
|
||||
Args:
|
||||
transform_func (StreamFunc): A function to transform the task's output to a
|
||||
stream output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The result of transforming the task's output to a stream
|
||||
output.
|
||||
"""
|
||||
out = await self._apply_func(transform_func)
|
||||
return SimpleStreamTaskOutput(out)
|
||||
|
||||
|
||||
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
|
||||
def __init__(self, data: AsyncIterator[T]) -> None:
|
||||
"""The default stream implementation of TaskOutput."""
|
||||
|
||||
def __init__(
|
||||
self, data: Union[AsyncIterator[T], _EMPTY_DATA_TYPE] = EMPTY_DATA
|
||||
) -> None:
|
||||
"""Create a SimpleStreamTaskOutput.
|
||||
|
||||
Args:
|
||||
data (Union[AsyncIterator[T], _EMPTY_DATA_TYPE], optional): The output data.
|
||||
Defaults to EMPTY_DATA.
|
||||
"""
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def is_stream(self) -> bool:
|
||||
"""Return True if the output data is a stream."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return not self._data
|
||||
"""Return True if the output data is empty."""
|
||||
return self._data == EMPTY_DATA or self._data == SKIP_DATA
|
||||
|
||||
@property
|
||||
def is_none(self) -> bool:
|
||||
"""Return True if the output data is None."""
|
||||
return self._data is None
|
||||
|
||||
@property
|
||||
def output_stream(self) -> AsyncIterator[T]:
|
||||
return self._data
|
||||
"""Return the output data.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[T]: The output data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the output data is empty.
|
||||
"""
|
||||
if self._data == EMPTY_DATA:
|
||||
raise ValueError("No output data for current task output")
|
||||
return cast(AsyncIterator[T], self._data)
|
||||
|
||||
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
|
||||
self._data = output_data
|
||||
"""Save the output data to current object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the output data is not a stream.
|
||||
"""
|
||||
if not _is_async_iterator(output_data):
|
||||
raise ValueError(
|
||||
f"Can not set non-stream data {output_data} to SimpleStreamTaskOutput"
|
||||
)
|
||||
self._data = cast(AsyncIterator[T], output_data)
|
||||
|
||||
def new_output(self) -> TaskOutput[T]:
|
||||
return SimpleStreamTaskOutput(None)
|
||||
"""Create new output object with empty data."""
|
||||
return SimpleStreamTaskOutput()
|
||||
|
||||
async def map(self, map_func) -> TaskOutput[T]:
|
||||
async def map(self, map_func: MapFunc) -> TaskOutput[OUT]:
|
||||
"""Apply a mapping function to the task's output."""
|
||||
is_async = asyncio.iscoroutinefunction(map_func)
|
||||
|
||||
async def new_iter() -> AsyncIterator[T]:
|
||||
async for out in self._data:
|
||||
async def new_iter() -> AsyncIterator[OUT]:
|
||||
async for out in self.output_stream:
|
||||
if is_async:
|
||||
out = await map_func(out)
|
||||
new_out: OUT = await map_func(out)
|
||||
else:
|
||||
out = map_func(out)
|
||||
yield out
|
||||
new_out = cast(OUT, map_func(out))
|
||||
yield new_out
|
||||
|
||||
return SimpleStreamTaskOutput(new_iter())
|
||||
|
||||
async def reduce(self, reduce_func) -> TaskOutput[T]:
|
||||
out = await _reduce_stream(self._data, reduce_func)
|
||||
async def reduce(self, reduce_func: ReduceFunc) -> TaskOutput[OUT]:
|
||||
"""Apply a reduce function to the task's output."""
|
||||
out = await _reduce_stream(self.output_stream, reduce_func)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def unstreamify(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], T]
|
||||
) -> TaskOutput[T]:
|
||||
async def unstreamify(self, transform_func: UnStreamFunc) -> TaskOutput[OUT]:
|
||||
"""Transform the task's output to a non-stream output."""
|
||||
if asyncio.iscoroutinefunction(transform_func):
|
||||
out = await transform_func(self._data)
|
||||
out = await transform_func(self.output_stream)
|
||||
else:
|
||||
out = transform_func(self._data)
|
||||
out = transform_func(self.output_stream)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def transform_stream(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
|
||||
) -> TaskOutput[T]:
|
||||
async def transform_stream(self, transform_func: TransformFunc) -> TaskOutput[OUT]:
|
||||
"""Transform an AsyncIterator[T] to another AsyncIterator[T].
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to
|
||||
apply to the AsyncIterator[T].
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The result of applying the transform function.
|
||||
"""
|
||||
if asyncio.iscoroutinefunction(transform_func):
|
||||
out = await transform_func(self._data)
|
||||
out: AsyncIterator[OUT] = await transform_func(self.output_stream)
|
||||
else:
|
||||
out = transform_func(self._data)
|
||||
out = cast(AsyncIterator[OUT], transform_func(self.output_stream))
|
||||
return SimpleStreamTaskOutput(out)
|
||||
|
||||
|
||||
@@ -145,20 +262,34 @@ def _is_async_iterator(obj):
|
||||
|
||||
|
||||
class BaseInputSource(InputSource, ABC):
|
||||
"""The base class of InputSource."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Create a BaseInputSource."""
|
||||
super().__init__()
|
||||
self._is_read = False
|
||||
|
||||
@abstractmethod
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
"""Read data with task context"""
|
||||
"""Return data with task context."""
|
||||
|
||||
async def read(self, task_ctx: TaskContext) -> TaskOutput:
|
||||
"""Read data with task context.
|
||||
|
||||
Args:
|
||||
task_ctx (TaskContext): The task context.
|
||||
|
||||
Returns:
|
||||
TaskOutput: The task output.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input source is a stream and has been read.
|
||||
"""
|
||||
data = self._read_data(task_ctx)
|
||||
if _is_async_iterator(data):
|
||||
if self._is_read:
|
||||
raise ValueError(f"Input iterator {data} has been read!")
|
||||
output = SimpleStreamTaskOutput(data)
|
||||
output: TaskOutput = SimpleStreamTaskOutput(data)
|
||||
else:
|
||||
output = SimpleTaskOutput(data)
|
||||
self._is_read = True
|
||||
@@ -166,7 +297,14 @@ class BaseInputSource(InputSource, ABC):
|
||||
|
||||
|
||||
class SimpleInputSource(BaseInputSource):
|
||||
"""The default implementation of InputSource."""
|
||||
|
||||
def __init__(self, data: Any) -> None:
|
||||
"""Create a SimpleInputSource.
|
||||
|
||||
Args:
|
||||
data (Any): The input data.
|
||||
"""
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@@ -175,63 +313,121 @@ class SimpleInputSource(BaseInputSource):
|
||||
|
||||
|
||||
class SimpleCallDataInputSource(BaseInputSource):
|
||||
"""The implementation of InputSource for call data."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Create a SimpleCallDataInputSource."""
|
||||
super().__init__()
|
||||
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
"""Read data from task context.
|
||||
|
||||
Returns:
|
||||
Any: The data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the call data is empty.
|
||||
"""
|
||||
call_data = task_ctx.call_data
|
||||
data = call_data.get("data") if call_data else None
|
||||
if not (call_data and data):
|
||||
data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
|
||||
if data == EMPTY_DATA:
|
||||
raise ValueError("No call data for current SimpleCallDataInputSource")
|
||||
return data
|
||||
|
||||
|
||||
class DefaultTaskContext(TaskContext, Generic[T]):
|
||||
"""The default implementation of TaskContext."""
|
||||
|
||||
def __init__(
|
||||
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
|
||||
self,
|
||||
task_id: str,
|
||||
task_state: TaskState,
|
||||
task_output: Optional[TaskOutput[T]] = None,
|
||||
) -> None:
|
||||
"""Create a DefaultTaskContext.
|
||||
|
||||
Args:
|
||||
task_id (str): The task id.
|
||||
task_state (TaskState): The task state.
|
||||
task_output (Optional[TaskOutput[T]], optional): The task output. Defaults
|
||||
to None.
|
||||
"""
|
||||
super().__init__()
|
||||
self._task_id = task_id
|
||||
self._task_state = task_state
|
||||
self._output = task_output
|
||||
self._task_input = None
|
||||
self._metadata = {}
|
||||
self._output: Optional[TaskOutput[T]] = task_output
|
||||
self._task_input: Optional[InputContext] = None
|
||||
self._metadata: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def task_id(self) -> str:
|
||||
"""Return the task id."""
|
||||
return self._task_id
|
||||
|
||||
@property
|
||||
def task_input(self) -> InputContext:
|
||||
"""Return the task input."""
|
||||
if not self._task_input:
|
||||
raise ValueError("No input for current task context")
|
||||
return self._task_input
|
||||
|
||||
def set_task_input(self, input_ctx: "InputContext") -> None:
|
||||
def set_task_input(self, input_ctx: InputContext) -> None:
|
||||
"""Save the task input to current task."""
|
||||
self._task_input = input_ctx
|
||||
|
||||
@property
|
||||
def task_output(self) -> TaskOutput:
|
||||
"""Return the task output.
|
||||
|
||||
Returns:
|
||||
TaskOutput: The task output.
|
||||
|
||||
Raises:
|
||||
ValueError: If the task output is empty.
|
||||
"""
|
||||
if not self._output:
|
||||
raise ValueError("No output for current task context")
|
||||
return self._output
|
||||
|
||||
def set_task_output(self, task_output: TaskOutput) -> None:
|
||||
"""Save the task output to current task.
|
||||
|
||||
Args:
|
||||
task_output (TaskOutput): The task output.
|
||||
"""
|
||||
self._output = task_output
|
||||
|
||||
@property
|
||||
def current_state(self) -> TaskState:
|
||||
"""Return the current task state."""
|
||||
return self._task_state
|
||||
|
||||
def set_current_state(self, task_state: TaskState) -> None:
|
||||
"""Save the current task state to current task."""
|
||||
self._task_state = task_state
|
||||
|
||||
def new_ctx(self) -> TaskContext:
|
||||
"""Create new task context with empty output."""
|
||||
if not self._output:
|
||||
raise ValueError("No output for current task context")
|
||||
new_output = self._output.new_output()
|
||||
return DefaultTaskContext(self._task_id, self._task_state, new_output)
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
"""Return the metadata of current task.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The metadata.
|
||||
"""
|
||||
return self._metadata
|
||||
|
||||
async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
|
||||
"""Get the call data for current data"""
|
||||
"""Return the call data of current task.
|
||||
|
||||
Returns:
|
||||
Optional[TaskOutput[T]]: The call data.
|
||||
"""
|
||||
call_data = self.call_data
|
||||
if not call_data:
|
||||
return None
|
||||
@@ -240,24 +436,48 @@ class DefaultTaskContext(TaskContext, Generic[T]):
|
||||
|
||||
|
||||
class DefaultInputContext(InputContext):
|
||||
"""The default implementation of InputContext.
|
||||
|
||||
It wraps all the inputs from parent tasks and provides some basic data operations.
|
||||
"""
|
||||
|
||||
def __init__(self, outputs: List[TaskContext]) -> None:
|
||||
"""Create a DefaultInputContext.
|
||||
|
||||
Args:
|
||||
outputs (List[TaskContext]): The outputs from parent tasks.
|
||||
"""
|
||||
super().__init__()
|
||||
self._outputs = outputs
|
||||
|
||||
@property
|
||||
def parent_outputs(self) -> List[TaskContext]:
|
||||
"""Return the outputs from parent tasks.
|
||||
|
||||
Returns:
|
||||
List[TaskContext]: The outputs from parent tasks.
|
||||
"""
|
||||
return self._outputs
|
||||
|
||||
async def _apply_func(
|
||||
self, func: Callable[[Any], Any], apply_type: str = "map"
|
||||
) -> Tuple[List[TaskContext], List[TaskOutput]]:
|
||||
"""Apply the function to all parent outputs.
|
||||
|
||||
Args:
|
||||
func (Callable[[Any], Any]): The function to apply.
|
||||
apply_type (str, optional): The apply type. Defaults to "map".
|
||||
|
||||
Returns:
|
||||
Tuple[List[TaskContext], List[TaskOutput]]: The new parent outputs and the
|
||||
results of applying the function.
|
||||
"""
|
||||
new_outputs: List[TaskContext] = []
|
||||
map_tasks = []
|
||||
for out in self._outputs:
|
||||
new_outputs.append(out.new_ctx())
|
||||
result = None
|
||||
if apply_type == "map":
|
||||
result = out.task_output.map(func)
|
||||
result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func)
|
||||
elif apply_type == "reduce":
|
||||
result = out.task_output.reduce(func)
|
||||
elif apply_type == "check_condition":
|
||||
@@ -269,29 +489,40 @@ class DefaultInputContext(InputContext):
|
||||
return new_outputs, results
|
||||
|
||||
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
|
||||
"""Apply a mapping function to all parent outputs."""
|
||||
new_outputs, results = await self._apply_func(map_func)
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
task_ctx = cast(TaskContext, task_ctx)
|
||||
task_ctx.set_task_output(results[i])
|
||||
return DefaultInputContext(new_outputs)
|
||||
|
||||
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
|
||||
"""Apply a mapping function to all parent outputs.
|
||||
|
||||
The parent outputs will be unpacked and passed to the mapping function.
|
||||
|
||||
Args:
|
||||
map_func (Callable[..., Any]): The mapping function.
|
||||
|
||||
Returns:
|
||||
InputContext: The new input context.
|
||||
"""
|
||||
if not self._outputs:
|
||||
return DefaultInputContext([])
|
||||
# Some parent may be empty
|
||||
not_empty_idx = 0
|
||||
for i, p in enumerate(self._outputs):
|
||||
if p.task_output.is_empty:
|
||||
# Skip empty parent
|
||||
continue
|
||||
not_empty_idx = i
|
||||
break
|
||||
# All output is empty?
|
||||
is_steam = self._outputs[not_empty_idx].task_output.is_stream
|
||||
if is_steam:
|
||||
if not self.check_stream(skip_empty=True):
|
||||
raise ValueError(
|
||||
"The output in all tasks must has same output format to map_all"
|
||||
)
|
||||
if is_steam and not self.check_stream(skip_empty=True):
|
||||
raise ValueError(
|
||||
"The output in all tasks must has same output format to map_all"
|
||||
)
|
||||
outputs = []
|
||||
for out in self._outputs:
|
||||
if out.task_output.is_stream:
|
||||
@@ -305,22 +536,26 @@ class DefaultInputContext(InputContext):
|
||||
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
|
||||
single_output.task_output.set_output(map_res)
|
||||
logger.debug(
|
||||
f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
|
||||
f"Current map_all map_res: {map_res}, is steam: "
|
||||
f"{single_output.task_output.is_stream}"
|
||||
)
|
||||
return DefaultInputContext([single_output])
|
||||
|
||||
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
|
||||
"""Apply a reduce function to all parent outputs."""
|
||||
if not self.check_stream():
|
||||
raise ValueError(
|
||||
"The output in all tasks must has same output format of stream to apply reduce function"
|
||||
"The output in all tasks must has same output format of stream to apply"
|
||||
" reduce function"
|
||||
)
|
||||
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
task_ctx = cast(TaskContext, task_ctx)
|
||||
task_ctx.set_task_output(results[i])
|
||||
return DefaultInputContext(new_outputs)
|
||||
|
||||
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
|
||||
"""Filter all parent outputs."""
|
||||
new_outputs, results = await self._apply_func(
|
||||
filter_func, apply_type="check_condition"
|
||||
)
|
||||
@@ -331,15 +566,16 @@ class DefaultInputContext(InputContext):
|
||||
return DefaultInputContext(result_outputs)
|
||||
|
||||
async def predicate_map(
|
||||
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
|
||||
self, predicate_func: PredicateFunc, failed_value: Any = None
|
||||
) -> "InputContext":
|
||||
"""Apply a predicate function to all parent outputs."""
|
||||
new_outputs, results = await self._apply_func(
|
||||
predicate_func, apply_type="check_condition"
|
||||
)
|
||||
result_outputs = []
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
if results[i]:
|
||||
task_ctx = cast(TaskContext, task_ctx)
|
||||
if not results[i].is_empty:
|
||||
task_ctx.task_output.set_output(True)
|
||||
result_outputs.append(task_ctx)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user