mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-05 18:33:52 +00:00
478 lines
14 KiB
Python
478 lines
14 KiB
Python
"""Base classes for task-related objects."""
|
|
from abc import ABC, abstractmethod
|
|
from enum import Enum
|
|
from typing import (
|
|
Any,
|
|
AsyncIterable,
|
|
AsyncIterator,
|
|
Awaitable,
|
|
Callable,
|
|
Dict,
|
|
Generic,
|
|
Iterable,
|
|
List,
|
|
Optional,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
# Generic type parameters used by the aliases and classes in this module.
# Input type consumed by a task function.
IN = TypeVar("IN")
# Output type produced by a task function.
OUT = TypeVar("OUT")
# Payload type carried by TaskOutput, TaskContext and InputSource.
T = TypeVar("T")
|
|
|
|
|
|
class _EMPTY_DATA_TYPE:
    """Sentinel type whose instances stand for "no data".

    Every instance is falsy and carries a ``name`` that is only used when the
    instance is rendered with ``str()``.
    """

    def __init__(self, name: str = "EMPTY_DATA"):
        self.name = name

    def __bool__(self):
        # A sentinel never counts as a real value in a boolean context.
        return False

    def __str__(self):
        return "EmptyData({})".format(self.name)

    def is_same(self, obj: Any) -> bool:
        """Tell whether ``obj`` is a sentinel equal to this one.

        ``__eq__`` is not overridden by this class, so the comparison falls
        back to object identity.

        Args:
            obj (Any): The object to compare with.

        Returns:
            bool: True if ``obj`` is an ``_EMPTY_DATA_TYPE`` instance that
                compares equal to this object, False otherwise.
        """
        return isinstance(obj, _EMPTY_DATA_TYPE) and self == obj
|
|
|
|
|
|
# Module-level sentinel instances. ``is_empty_data`` reports EMPTY_DATA and
# SKIP_DATA as empty; PLACEHOLDER_DATA is not reported as empty.
EMPTY_DATA = _EMPTY_DATA_TYPE(name="EMPTY_DATA")
SKIP_DATA = _EMPTY_DATA_TYPE(name="SKIP_DATA")
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE(name="PLACEHOLDER_DATA")
|
|
|
|
|
|
def is_empty_data(data: Any):
    """Check whether ``data`` should be treated as empty.

    Args:
        data (Any): The value to inspect.

    Returns:
        For a sentinel instance: True only when it is ``EMPTY_DATA`` or
        ``SKIP_DATA``. For any other object: the value of its ``empty``
        attribute, returned as-is, when it has one; otherwise False.
    """
    if isinstance(data, _EMPTY_DATA_TYPE):
        return data in (EMPTY_DATA, SKIP_DATA)
    # Objects without an ``empty`` attribute fall back to False.
    return getattr(data, "empty", False)
|
|
|
|
|
|
# Signatures of the user functions accepted by the task APIs. The ``Union``
# aliases allow either a plain function or one returning an awaitable.

# One input value -> one output value.
MapFunc = Union[
    Callable[[IN], OUT],
    Callable[[IN], Awaitable[OUT]]
]
# Two input values -> one output value.
ReduceFunc = Union[
    Callable[[IN, IN], OUT],
    Callable[[IN, IN], Awaitable[OUT]]
]
# One input value -> awaitable async iterator of output values.
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
# Async iterator of input values -> one output value.
UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
# Async iterator of input values -> awaitable async iterator of output values.
TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
# One input value -> boolean verdict.
PredicateFunc = Union[
    Callable[[IN], bool],
    Callable[[IN], Awaitable[bool]]
]
# Any arguments -> one output value.
JoinFunc = Union[
    Callable[..., OUT],
    Callable[..., Awaitable[OUT]]
]
|
|
|
|
|
|
class TaskState(str, Enum):
    """States a task can be in during its lifecycle in the DAG.

    Members are also ``str`` instances, so each one compares equal to its raw
    string value.
    """

    # Initial state of the task, not yet started.
    INIT = "init"
    # The task was skipped.
    SKIP = "skip"
    # The task is currently running.
    RUNNING = "running"
    # The task completed successfully.
    SUCCESS = "success"
    # The task failed during execution.
    FAILED = "failed"
|
|
|
|
|
|
class TaskOutput(ABC, Generic[T]):
    """Abstract container for the output produced by a task.

    The base class answers the ``is_*`` flags with False and leaves data access
    and every transformation unimplemented (they raise
    ``NotImplementedError``). Only ``set_output`` and ``new_output`` are
    abstract; subclasses override the remaining members they support.
    """

    @property
    def is_stream(self) -> bool:
        """Tell whether the output is a stream.

        Returns:
            bool: True if the output is a stream. The base implementation
                always returns False.
        """
        return False

    @property
    def is_empty(self) -> bool:
        """Tell whether the output is empty.

        Returns:
            bool: True if the output is empty. The base implementation always
                returns False.
        """
        return False

    @property
    def is_none(self) -> bool:
        """Tell whether the output is None.

        Returns:
            bool: True if the output is None. The base implementation always
                returns False.
        """
        return False

    @property
    def output(self) -> T:
        """Return the output value of the task.

        Returns:
            T: The output of the task.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @property
    def output_stream(self) -> AsyncIterator[T]:
        """Return the output of the task as an asynchronous stream.

        Returns:
            AsyncIterator[T]: An asynchronous iterator over the output.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
        """Store ``output_data`` in this object.

        Args:
            output_data (Union[T, AsyncIterator[T]]): A single value or an
                asynchronous iterator of values.
        """

    @abstractmethod
    def new_output(self) -> "TaskOutput[T]":
        """Create a new output object."""

    async def map(self, map_func: MapFunc) -> "TaskOutput[OUT]":
        """Apply a mapping function to the task's output.

        Args:
            map_func (MapFunc): The function to apply to the output.

        Returns:
            TaskOutput[OUT]: The result of applying the mapping function.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    async def reduce(self, reduce_func: ReduceFunc) -> "TaskOutput[OUT]":
        """Apply a reducing function to the task's output.

        Turns a stream output into a non-stream output.

        Args:
            reduce_func (ReduceFunc): The reducing function to apply.

        Returns:
            TaskOutput[OUT]: The result of applying the reducing function.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    async def streamify(self, transform_func: StreamFunc) -> "TaskOutput[T]":
        """Convert a single value into a stream using ``transform_func``.

        Args:
            transform_func (StreamFunc): Function turning a value of type T
                into an AsyncIterator[OUT].

        Returns:
            TaskOutput[T]: The output holding the resulting stream.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    async def transform_stream(
        self, transform_func: TransformFunc
    ) -> "TaskOutput[OUT]":
        """Transform one asynchronous stream into another.

        Args:
            transform_func (TransformFunc): Function applied to the
                AsyncIterator held by this output.

        Returns:
            TaskOutput[OUT]: The output holding the transformed stream.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    async def unstreamify(self, transform_func: UnStreamFunc) -> "TaskOutput[OUT]":
        """Convert a stream into a single value using ``transform_func``.

        Args:
            transform_func (UnStreamFunc): Function turning an AsyncIterator[T]
                into a single value.

        Returns:
            TaskOutput[OUT]: The output holding the resulting value.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    async def check_condition(self, condition_func) -> "TaskOutput[OUT]":
        """Check whether the current output meets a given condition.

        Args:
            condition_func: A function deciding whether the condition is met.

        Returns:
            TaskOutput[OUT]: The resulting output. If the condition is not met,
                an empty output is returned.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError
|
|
|
|
|
|
class TaskContext(ABC, Generic[T]):
    """Abstract interface for the context of a task within a DAG.

    It exposes the task's identifier, input context, output, state and
    metadata, plus helpers that keep "call data" inside the metadata.
    """

    @property
    @abstractmethod
    def task_id(self) -> str:
        """Return the unique identifier of the task.

        Returns:
            str: The unique identifier of the task.
        """

    @property
    @abstractmethod
    def task_input(self) -> "InputContext":
        """Return the InputContext of the current task.

        Returns:
            InputContext: The InputContext of the current task.

        Raises:
            Exception: If the InputContext is not set.
        """

    @abstractmethod
    def set_task_input(self, input_ctx: "InputContext") -> None:
        """Attach an InputContext to the current task.

        Args:
            input_ctx (InputContext): The InputContext of the current task.
        """

    @property
    @abstractmethod
    def task_output(self) -> TaskOutput[T]:
        """Return the output object of the task.

        Returns:
            TaskOutput[T]: The output object of the task.
        """

    @abstractmethod
    def set_task_output(self, task_output: TaskOutput[T]) -> None:
        """Attach an output object to the current task.

        Args:
            task_output (TaskOutput[T]): The output object to store.
        """

    @property
    @abstractmethod
    def current_state(self) -> TaskState:
        """Return the current state of the task.

        Returns:
            TaskState: The current state of the task.
        """

    @abstractmethod
    def set_current_state(self, task_state: TaskState) -> None:
        """Set the current state of the task.

        Args:
            task_state (TaskState): The task state to be set.
        """

    @abstractmethod
    def new_ctx(self) -> "TaskContext":
        """Create a new task context.

        Returns:
            TaskContext: A new instance of a TaskContext.
        """

    @property
    @abstractmethod
    def metadata(self) -> Dict[str, Any]:
        """Return the metadata of the current task.

        Returns:
            Dict[str, Any]: The metadata.
        """

    def update_metadata(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the task metadata.

        The write goes straight into the mapping returned by ``metadata``.

        Args:
            key (str): The metadata key.
            value (Any): The value to store under that key.
        """
        self.metadata[key] = value

    @property
    def call_data(self) -> Optional[Dict]:
        """Return the call data saved for the current task.

        Returns:
            Optional[Dict]: The value stored under the ``"call_data"`` metadata
                key, or None when that key is absent.
        """
        return self.metadata.get("call_data")

    @abstractmethod
    async def _call_data_to_output(self) -> Optional[TaskOutput[T]]:
        """Produce a task output from the call data of the current task.

        NOTE(review): when None is returned is decided by the implementation;
        confirm against the concrete subclass.
        """

    def set_call_data(self, call_data: Dict) -> None:
        """Save the call data for the current task.

        Args:
            call_data (Dict): Stored under the ``"call_data"`` metadata key.
        """
        self.update_metadata("call_data", call_data)
|
|
|
|
|
|
class InputContext(ABC):
    """Abstract interface for the inputs of an operator node.

    The inputs are the task contexts of the parent nodes. Subclasses implement
    the asynchronous transformations; the two ``check_*`` helpers are provided
    here on top of ``parent_outputs``.
    """

    @property
    @abstractmethod
    def parent_outputs(self) -> List[TaskContext]:
        """Return the outputs of the parent nodes.

        Returns:
            List[TaskContext]: The contexts of the parent nodes' outputs.
        """

    @abstractmethod
    async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
        """Apply a mapping function to the inputs.

        Args:
            map_func (Callable[[Any], Any]): The function to apply to the
                inputs.

        Returns:
            InputContext: A new InputContext instance with the mapped inputs.
        """

    @abstractmethod
    async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
        """Apply a mapping function to all inputs.

        Args:
            map_func (Callable[..., Any]): The function to apply to all inputs.

        Returns:
            InputContext: A new InputContext instance with the mapped inputs.
        """

    @abstractmethod
    async def reduce(self, reduce_func: ReduceFunc) -> "InputContext":
        """Apply a reducing function to the inputs.

        Args:
            reduce_func (ReduceFunc): The function that reduces the inputs.

        Returns:
            InputContext: A new InputContext instance with the reduced inputs.
        """

    @abstractmethod
    async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
        """Filter the inputs with the provided function.

        Args:
            filter_func (Callable[[Any], bool]): A function that returns True
                for the inputs to keep.

        Returns:
            InputContext: A new InputContext instance with the filtered inputs.
        """

    @abstractmethod
    async def predicate_map(
        self, predicate_func: PredicateFunc, failed_value: Any = None
    ) -> "InputContext":
        """Evaluate a predicate on the inputs.

        Args:
            predicate_func (PredicateFunc): A function that returns True for
                the inputs that satisfy the predicate.
            failed_value (Any): The value to set when the predicate function
                returns False.

        Returns:
            InputContext: A new InputContext instance with the predicate
                results applied to the inputs.
        """

    def check_single_parent(self) -> bool:
        """Check whether there is exactly one parent output.

        Returns:
            bool: True if there is only one parent output, False otherwise.
        """
        return len(self.parent_outputs) == 1

    def check_stream(self, skip_empty: bool = False) -> bool:
        """Check whether all parent outputs are streams.

        Args:
            skip_empty (bool): When True, parent outputs reporting
                ``is_empty`` are ignored instead of being checked.

        Returns:
            bool: True if every considered parent output is a stream (also
                when there are no parents), False otherwise. A parent whose
                ``task_output`` is None counts as not being a stream.
        """
        for parent_ctx in self.parent_outputs:
            # Read the property once; it is consulted up to three times below.
            task_output = parent_ctx.task_output
            if task_output is None:
                # Guard before touching ``is_empty`` so a missing output
                # yields False instead of an AttributeError.
                return False
            if skip_empty and task_output.is_empty:
                continue
            if not (task_output and task_output.is_stream):
                return False
        return True
|
|
|
|
|
|
class InputSource(ABC, Generic[T]):
    """Abstract interface for the source of inputs to a DAG node.

    The ``from_*`` factories import their concrete classes from ``.task_impl``
    at call time and build them directly, independent of the class the
    factory is invoked on.
    """

    @abstractmethod
    async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
        """Read the data from the current input source.

        Args:
            task_ctx (TaskContext): The context of the task reading the data.

        Returns:
            TaskOutput[T]: The output object read from the current source.
        """

    @classmethod
    def from_data(cls, data: T) -> "InputSource[T]":
        """Create a non-streaming InputSource from a single piece of data.

        Args:
            data (T): The data to create the InputSource from.

        Returns:
            InputSource[T]: A ``SimpleInputSource`` built with
                ``streaming=False``.
        """
        # Imported here rather than at module level, as in the original.
        from .task_impl import SimpleInputSource

        return SimpleInputSource(data, streaming=False)

    @classmethod
    def from_iterable(
        cls, iterable: Union[AsyncIterable[T], Iterable[T]]
    ) -> "InputSource[T]":
        """Create a streaming InputSource from an iterable.

        Args:
            iterable (Union[AsyncIterable[T], Iterable[T]]): The synchronous or
                asynchronous iterable to create the InputSource from.

        Returns:
            InputSource[T]: A ``SimpleInputSource`` built with
                ``streaming=True``.
        """
        from .task_impl import SimpleInputSource

        return SimpleInputSource(iterable, streaming=True)

    @classmethod
    def from_callable(cls) -> "InputSource[T]":
        """Create an InputSource backed by ``SimpleCallDataInputSource``.

        NOTE(review): presumably this source reads the call data saved on the
        task context; confirm against ``task_impl``.

        Returns:
            InputSource[T]: A new ``SimpleCallDataInputSource`` instance.
        """
        from .task_impl import SimpleCallDataInputSource

        return SimpleCallDataInputSource()
|