feat(model): Support model cache and first version of Agentic Workflow Expression Language (AWEL)

This commit is contained in:
FangYin Cheng
2023-11-16 04:05:37 +08:00
parent 8eaf3693f0
commit 6db8c49d87
43 changed files with 3030 additions and 21 deletions

57
pilot/awel/__init__.py Normal file

@@ -0,0 +1,57 @@
"""Agentic Workflow Expression Language (AWEL)"""
from .dag.base import DAGContext, DAG
from .operator.base import BaseOperator, WorkflowRunner, initialize_awel
from .operator.common_operator import (
JoinOperator,
ReduceStreamOperator,
MapOperator,
BranchOperator,
InputOperator,
)
from .operator.stream_operator import (
StreamifyAbsOperator,
UnstreamifyAbsOperator,
TransformStreamAbsOperator,
)
from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
from .task.task_impl import (
SimpleInputSource,
SimpleCallDataInputSource,
DefaultTaskContext,
DefaultInputContext,
SimpleTaskOutput,
SimpleStreamTaskOutput,
)
from .runner.local_runner import DefaultWorkflowRunner
__all__ = [
"initialize_awel",
"DAGContext",
"DAG",
"BaseOperator",
"JoinOperator",
"ReduceStreamOperator",
"MapOperator",
"BranchOperator",
"InputOperator",
"WorkflowRunner",
"TaskState",
"TaskOutput",
"TaskContext",
"InputContext",
"InputSource",
"DefaultWorkflowRunner",
"SimpleInputSource",
"SimpleCallDataInputSource",
"DefaultTaskContext",
"DefaultInputContext",
"SimpleTaskOutput",
"SimpleStreamTaskOutput",
"StreamifyAbsOperator",
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
]
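A minimal usage sketch of these exports (assuming the package is importable as pilot.awel, per the file path above): an input node is wired to a map node inside a DAG and the workflow is run through the default runner. The DAG and operator names are illustrative only.

import asyncio
from pilot.awel import (
    DAG,
    DefaultWorkflowRunner,
    InputOperator,
    MapOperator,
    SimpleInputSource,
    initialize_awel,
)

async def main():
    # Register the global default runner used by operators created afterwards.
    initialize_awel(DefaultWorkflowRunner())
    with DAG("hello_awel"):
        input_node = InputOperator(SimpleInputSource("hello"))
        map_node = MapOperator(lambda x: x.upper())
        input_node >> map_node
    # Run the workflow that ends at map_node and read its output.
    print(await map_node.call())  # HELLO

asyncio.run(main())

The `>>` operator only records the dependency; calling call() on the last node executes the whole chain.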


252
pilot/awel/dag/base.py Normal file

@@ -0,0 +1,252 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any
import uuid
import contextvars
import threading
import asyncio
from collections import deque
from ..resource.base import ResourceGroup
from ..task.base import TaskContext
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
def _is_async_context():
try:
loop = asyncio.get_running_loop()
return asyncio.current_task(loop=loop) is not None
except RuntimeError:
return False
class DependencyMixin(ABC):
@abstractmethod
def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
"""Set one or more upstream nodes for this node.
Args:
nodes (DependencyType): Upstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
"""
@abstractmethod
def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
"""Set one or more downstream nodes for this node.
Args:
nodes (DependencyType): Downstream nodes to be set to current node.
Returns:
DependencyMixin: Returns self to allow method chaining.
Raises:
ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
"""
def __lshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self << nodes
Example:
.. code-block:: python
# means node.set_upstream(input_node)
node << input_node
# means node2.set_upstream([input_node])
node2 << [input_node]
"""
self.set_upstream(nodes)
return nodes
def __rshift__(self, nodes: DependencyType) -> DependencyType:
"""Implements self >> nodes
Example:
.. code-block:: python
# means node.set_downstream(next_node)
node >> next_node
# means node2.set_downstream([next_node])
node2 >> [next_node]
"""
self.set_downstream(nodes)
return nodes
def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] >> self"""
self.__lshift__(nodes)
return self
def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
"""Implements [node] << self"""
self.__rshift__(nodes)
return self
class DAGVar:
_thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
@classmethod
def enter_dag(cls, dag) -> None:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
stack.append(dag)
cls._async_local.set(stack)
else:
if not hasattr(cls._thread_local, "current_dag_stack"):
cls._thread_local.current_dag_stack = deque()
cls._thread_local.current_dag_stack.append(dag)
@classmethod
def exit_dag(cls) -> None:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
if stack:
stack.pop()
cls._async_local.set(stack)
else:
if (
hasattr(cls._thread_local, "current_dag_stack")
and cls._thread_local.current_dag_stack
):
cls._thread_local.current_dag_stack.pop()
@classmethod
def get_current_dag(cls) -> Optional["DAG"]:
is_async = _is_async_context()
if is_async:
stack = cls._async_local.get()
return stack[-1] if stack else None
else:
if (
hasattr(cls._thread_local, "current_dag_stack")
and cls._thread_local.current_dag_stack
):
return cls._thread_local.current_dag_stack[-1]
return None
class DAGNode(DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
def __init__(self, dag: Optional["DAG"] = None, node_id: str = None) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id
@property
def node_id(self) -> str:
return self._node_id
def set_node_id(self, node_id: str) -> None:
self._node_id = node_id
@property
def dag(self) -> Optional["DAG"]:
return self._dag
def set_upstream(self, nodes: DependencyType) -> "DAGNode":
self.set_dependency(nodes)
def set_downstream(self, nodes: DependencyType) -> "DAGNode":
self.set_dependency(nodes, is_upstream=False)
@property
def upstream(self) -> List["DAGNode"]:
return self._upstream
@property
def downstream(self) -> List["DAGNode"]:
return self._downstream
def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
if not isinstance(nodes, Sequence):
nodes = [nodes]
if not all(isinstance(node, DAGNode) for node in nodes):
raise ValueError(
"all nodes to set dependency to current node must be instance of 'DAGNode'"
)
nodes: Sequence[DAGNode] = nodes
dags = set([node.dag for node in nodes if node.dag])
if self.dag:
dags.add(self.dag)
if not dags:
raise ValueError("set dependency to current node must in a DAG context")
if len(dags) != 1:
raise ValueError(
"set dependency to current node just support in one DAG context"
)
dag = dags.pop()
self._dag = dag
dag._append_node(self)
for node in nodes:
if is_upstream and node not in self.upstream:
node._dag = dag
dag._append_node(node)
self._upstream.append(node)
node._downstream.append(self)
elif node not in self._downstream:
node._dag = dag
dag._append_node(node)
self._downstream.append(node)
node._upstream.append(self)
class DAGContext:
def __init__(self) -> None:
self._curr_task_ctx = None
self._share_data: Dict[str, Any] = {}
@property
def current_task_context(self) -> TaskContext:
return self._curr_task_ctx
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
self._curr_task_ctx = _curr_task_ctx
async def get_share_data(self, key: str) -> Any:
return self._share_data.get(key)
async def save_to_share_data(self, key: str, data: Any) -> None:
self._share_data[key] = data
class DAG:
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None:
self.node_map: Dict[str, DAGNode] = {}
def _append_node(self, node: DAGNode) -> None:
self.node_map[node.node_id] = node
def _new_node_id(self) -> str:
return str(uuid.uuid4())
def __enter__(self):
DAGVar.enter_dag(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
DAGVar.exit_dag()
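A small sketch of the context-manager behaviour implemented by DAG and DAGVar above: entering a DAG pushes it onto the thread-local (or task-local) stack and leaving pops it.

from pilot.awel.dag.base import DAG, DAGVar

dag1, dag2 = DAG("dag1"), DAG("dag2")
with dag1:
    assert DAGVar.get_current_dag() is dag1
    with dag2:
        # The innermost DAG shadows the outer one.
        assert DAGVar.get_current_dag() is dag2
    assert DAGVar.get_current_dag() is dag1
assert DAGVar.get_current_dag() is None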


@@ -0,0 +1,51 @@
import pytest
import threading
import asyncio
from ..dag.base import DAG, DAGVar
def test_dag_context_sync():
dag1 = DAG("dag1")
dag2 = DAG("dag2")
with dag1:
assert DAGVar.get_current_dag() == dag1
with dag2:
assert DAGVar.get_current_dag() == dag2
assert DAGVar.get_current_dag() == dag1
assert DAGVar.get_current_dag() is None
def test_dag_context_threading():
def thread_function(dag):
DAGVar.enter_dag(dag)
assert DAGVar.get_current_dag() == dag
DAGVar.exit_dag()
dag1 = DAG("dag1")
dag2 = DAG("dag2")
thread1 = threading.Thread(target=thread_function, args=(dag1,))
thread2 = threading.Thread(target=thread_function, args=(dag2,))
thread1.start()
thread2.start()
thread1.join()
thread2.join()
assert DAGVar.get_current_dag() is None
@pytest.mark.asyncio
async def test_dag_context_async():
async def async_function(dag):
DAGVar.enter_dag(dag)
assert DAGVar.get_current_dag() == dag
DAGVar.exit_dag()
dag1 = DAG("dag1")
dag2 = DAG("dag2")
await asyncio.gather(async_function(dag1), async_function(dag2))
assert DAGVar.get_current_dag() is None


176
pilot/awel/operator/base.py Normal file

@@ -0,0 +1,176 @@
from abc import ABC, abstractmethod, ABCMeta
from types import FunctionType
from typing import (
List,
Generic,
TypeVar,
AsyncIterator,
Union,
Any,
Dict,
Optional,
cast,
)
import functools
from inspect import signature
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import (
TaskContext,
TaskOutput,
TaskState,
OUT,
T,
InputContext,
InputSource,
)
F = TypeVar("F", bound=FunctionType)
CALL_DATA = Union[Dict, Dict[str, Dict]]
class WorkflowRunner(ABC, Generic[T]):
"""Abstract base class representing a runner for executing workflows in a DAG.
This class defines the interface for executing workflows within the DAG,
handling the flow from one DAG node to another.
"""
@abstractmethod
async def execute_workflow(
self, node: "BaseOperator", call_data: Optional[CALL_DATA] = None
) -> DAGContext:
"""Execute the workflow starting from a given operator.
Args:
node (BaseOperator): The starting node of the workflow to be executed.
call_data (CALL_DATA): The data passed to the root operator node.
Returns:
DAGContext: The context after executing the workflow, containing the final state and data.
"""
default_runner: WorkflowRunner = None
class BaseOperatorMeta(ABCMeta):
"""Metaclass of BaseOperator."""
@classmethod
def _apply_defaults(cls, func: F) -> F:
sig_cache = signature(func)
@functools.wraps(func)
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
task_id: Optional[str] = kwargs.get("task_id")
if not task_id and dag:
task_id = dag._new_node_id()
runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
# print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}")
# for arg in sig_cache.parameters:
# if arg not in kwargs:
# kwargs[arg] = default_args[arg]
if not kwargs.get("dag"):
kwargs["dag"] = dag
if not kwargs.get("task_id"):
kwargs["task_id"] = task_id
if not kwargs.get("runner"):
kwargs["runner"] = runner
real_obj = func(self, *args, **kwargs)
return real_obj
return cast(T, apply_defaults)
def __new__(cls, name, bases, namespace, **kwargs):
new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
return new_cls
class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
"""Abstract base class for operator nodes that can be executed within a workflow.
This class extends DAGNode by adding execution capabilities.
"""
def __init__(
self,
task_id: Optional[str] = None,
dag: Optional[DAG] = None,
runner: WorkflowRunner = None,
**kwargs,
) -> None:
"""Initializes a BaseOperator with an optional workflow runner.
Args:
runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
"""
super().__init__(node_id=task_id, dag=dag, **kwargs)
if not runner:
from pilot.awel import DefaultWorkflowRunner
runner = DefaultWorkflowRunner()
self._runner: WorkflowRunner = runner
self._dag_ctx: DAGContext = None
@property
def current_dag_context(self) -> DAGContext:
return self._dag_ctx
async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
if not self.node_id:
raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
self._dag_ctx = dag_ctx
return await self._do_run(dag_ctx)
@abstractmethod
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""
Abstract method to run the task within the DAG node.
Args:
dag_ctx (DAGContext): The context of the DAG when this node is run.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
"""Execute the node and return the output.
This method is a high-level wrapper for executing the node.
Args:
call_data (CALL_DATA): The data passed to the root operator node.
Returns:
OUT: The output of the node after execution.
"""
out_ctx = await self._runner.execute_workflow(self, call_data)
return out_ctx.current_task_context.task_output.output
async def call_stream(
self, call_data: Optional[CALL_DATA] = None
) -> AsyncIterator[OUT]:
"""Execute the node and return the output as a stream.
This method is used for nodes where the output is a stream.
Args:
call_data (CALL_DATA): The data passed to the root operator node.
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
out_ctx = await self._runner.execute_workflow(self, call_data)
return out_ctx.current_task_context.task_output.output_stream
def initialize_awel(runner: WorkflowRunner):
global default_runner
default_runner = runner
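A short sketch of the defaults applied by BaseOperatorMeta: an operator created inside a `with DAG(...)` block picks up the active DAG and an auto-generated task id (the DAG and operator names here are illustrative).

from pilot.awel import DAG, MapOperator

with DAG("defaults_demo") as dag:
    op = MapOperator(lambda x: x)

assert op.dag is dag            # dag injected from the active DAG context
assert op.node_id is not None   # task id generated by dag._new_node_id()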


@@ -0,0 +1,221 @@
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
import asyncio
from ..dag.base import DAGContext
from ..task.base import (
TaskContext,
TaskOutput,
IN,
OUT,
InputContext,
InputSource,
)
from .base import BaseOperator
class JoinOperator(BaseOperator, Generic[OUT]):
"""Operator that joins inputs using a custom combine function.
This node type is useful for combining the outputs of upstream nodes.
"""
def __init__(self, combine_function, **kwargs):
super().__init__(**kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")
self.combine_function = combine_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the join operation on the DAG context's inputs.
Args:
dag_ctx (DAGContext): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
input_ctx: InputContext = await curr_task_ctx.task_input.map_all(
self.combine_function
)
# The join result is stored in the first parent output
join_output = input_ctx.parent_outputs[0].task_output
curr_task_ctx.set_task_output(join_output)
return join_output
class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
def __init__(self, reduce_function=None, **kwargs):
"""Initializes a ReduceStreamOperator with a combine function.
Args:
combine_function: A function that defines how to combine inputs.
Raises:
ValueError: If the combine_function is not callable.
"""
super().__init__(**kwargs)
if reduce_function and not callable(reduce_function):
raise ValueError("reduce_function must be callable")
self.reduce_function = reduce_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the join operation on the DAG context's inputs.
Args:
dag_ctx (DAGContext): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
task_input = curr_task_ctx.task_input
if not task_input.check_stream():
raise ValueError("ReduceStreamOperator expects stream data")
if not task_input.check_single_parent():
raise ValueError("ReduceStreamOperator expects single parent")
reduce_function = self.reduce_function or self.reduce
input_ctx: InputContext = await task_input.reduce(reduce_function)
# The reduce result is stored in the first parent output
reduce_output = input_ctx.parent_outputs[0].task_output
curr_task_ctx.set_task_output(reduce_output)
return reduce_output
async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
raise NotImplementedError
class MapOperator(BaseOperator, Generic[IN, OUT]):
"""Map operator that applies a mapping function to its inputs.
This operator transforms its input data using a provided mapping function and
passes the transformed data downstream.
"""
def __init__(self, map_function=None, **kwargs):
"""Initializes a MapDAGNode with a mapping function.
Args:
map_function: A function that defines how to map the input data.
Raises:
ValueError: If the map_function is not callable.
"""
super().__init__(**kwargs)
if map_function and not callable(map_function):
raise ValueError("map_function must be callable")
self.map_function = map_function
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the mapping operation on the DAG context's inputs.
This method applies the mapping function to the input context and updates
the DAG context with the new data.
Args:
dag_ctx (DAGContext[IN]): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
Raises:
ValueError: If not a single parent or the map_function is not callable
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
if not curr_task_ctx.task_input.check_single_parent():
num_parents = len(curr_task_ctx.task_input.parent_outputs)
raise ValueError(
f"task {curr_task_ctx.task_id} MapDAGNode expects single parent, now number of parents: {num_parents}"
)
map_function = self.map_function or self.map
input_ctx: InputContext = await curr_task_ctx.task_input.map(map_function)
# The map result is stored in the first parent output
map_output = input_ctx.parent_outputs[0].task_output
curr_task_ctx.set_task_output(map_output)
return map_output
async def map(self, input_value: IN) -> OUT:
raise NotImplementedError
BranchFunc = Union[Callable[[Any], bool], Callable[[Any], Awaitable[bool]]]
class BranchOperator(BaseOperator, Generic[OUT]):
"""Operator node that branches the workflow based on a provided function.
This node filters its input data using a branching function and
allows for conditional paths in the workflow.
"""
def __init__(self, branches: Dict[BranchFunc, BaseOperator], **kwargs):
"""
Initializes a BranchOperator with a mapping from branch functions to operators.
Args:
branches (Dict[BranchFunc, BaseOperator]): A dict mapping each branch condition function to the operator to run when the condition holds.
Raises:
ValueError: If the branch_function is not callable.
"""
super().__init__(**kwargs)
for branch_function in branches.keys():
if not callable(branch_function):
raise ValueError("branch_function must be callable")
self.branches = branches
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the branching operation on the DAG context's inputs.
This method applies the branching function to the input context to determine
the path of execution in the workflow.
Args:
dag_ctx (DAGContext[IN]): The current context of the DAG.
Returns:
TaskOutput[OUT]: The task output after this node has been run.
"""
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
task_input = curr_task_ctx.task_input
if task_input.check_stream():
raise ValueError("BranchDAGNode expects no stream data")
if not task_input.check_single_parent():
raise ValueError("BranchDAGNode expects single parent")
branch_func_tasks = []
branch_nodes: List[BaseOperator] = []
for func, node in self.branches.items():
branch_nodes.append(node)
branch_func_tasks.append(
curr_task_ctx.task_input.predicate_map(func, failed_value=None)
)
branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
parent_output = task_input.parent_outputs[0].task_output
curr_task_ctx.set_task_output(parent_output)
for i, ctx in enumerate(branch_input_ctxs):
node = branch_nodes[i]
if ctx.parent_outputs[0].task_output.is_empty:
# Skip current node
# node.current_task_context.set_current_state(TaskState.SKIP)
pass
else:
pass
raise NotImplementedError
return None
class InputOperator(BaseOperator, Generic[OUT]):
def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
super().__init__(**kwargs)
self._input_source = input_source
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
task_output = await self._input_source.read(curr_task_ctx)
curr_task_ctx.set_task_output(task_output)
return task_output
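A sketch of JoinOperator combining the outputs of two upstream input nodes; all names are illustrative.

import asyncio
from pilot.awel import DAG, InputOperator, JoinOperator, SimpleInputSource

async def main():
    with DAG("join_demo"):
        left = InputOperator(SimpleInputSource(2))
        right = InputOperator(SimpleInputSource(3))
        joined = JoinOperator(combine_function=lambda a, b: a + b)
        left >> joined
        right >> joined
    # The combine function receives one argument per parent output.
    print(await joined.call())  # 5

asyncio.run(main())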


@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
from typing import Generic, AsyncIterator
from ..task.base import OUT, IN, TaskOutput, TaskContext
from ..dag.base import DAGContext
from .base import BaseOperator
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
self.streamify
)
curr_task_ctx.set_task_output(output)
return output
@abstractmethod
async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
"""Convert a value of IN to an AsyncIterator[OUT]
Args:
input_value (IN): The data of parent operator's output
Example:
.. code-block:: python
class MyStreamOperator(StreamifyAbsOperator[int, int]):
async def streamify(self, input_value: int) -> AsyncIterator[int]:
for i in range(input_value):
yield i
"""
class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[
0
].task_output.unstreamify(self.unstreamify)
curr_task_ctx.set_task_output(output)
return output
@abstractmethod
async def unstreamify(self, input_value: AsyncIterator[IN]) -> OUT:
"""Convert a value of AsyncIterator[IN] to an OUT.
Args:
input_value (AsyncIterator[IN]): The data of the parent operator's output
Example:
.. code-block:: python
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
value_cnt = 0
async for v in input_value:
value_cnt += 1
return value_cnt
"""
class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
output = await curr_task_ctx.task_input.parent_outputs[
0
].task_output.transform_stream(self.transform_stream)
curr_task_ctx.set_task_output(output)
return output
@abstractmethod
async def transform_stream(
self, input_value: AsyncIterator[IN]
) -> AsyncIterator[OUT]:
"""Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.
Args:
input_value (AsyncIterator[IN]): The data of the parent operator's output
Example:
.. code-block:: python
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
async def transform_stream(self, input_value: AsyncIterator[int]) -> AsyncIterator[int]:
async for v in input_value:
yield v + 1
"""


@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
class ResourceGroup(ABC):
@property
@abstractmethod
def name(self) -> str:
"""The name of current resource group"""


@@ -0,0 +1,74 @@
from typing import List, Set, Optional, Dict
import uuid
from ..dag.base import DAG
from ..operator.base import BaseOperator, CALL_DATA
class DAGNodeInstance:
def __init__(self, node_instance: DAG) -> None:
pass
class DAGInstance:
def __init__(self, dag: DAG) -> None:
self._dag = dag
class JobManager:
def __init__(
self,
root_nodes: List[BaseOperator],
all_nodes: List[BaseOperator],
end_node: BaseOperator,
id2call_data: Dict[str, Dict],
) -> None:
self._root_nodes = root_nodes
self._all_nodes = all_nodes
self._end_node = end_node
self._id2node_data = id2call_data
@staticmethod
def build_from_end_node(
end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
) -> "JobManager":
nodes = _build_from_end_node(end_node)
root_nodes = _get_root_nodes(nodes)
id2call_data = _save_call_data(root_nodes, call_data)
return JobManager(root_nodes, nodes, end_node, id2call_data)
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
return self._id2node_data.get(node_id)
def _save_call_data(
root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
id2call_data = {}
if not call_data:
return id2call_data
if len(root_nodes) == 1:
node = root_nodes[0]
id2call_data[node.node_id] = call_data
else:
for node in root_nodes:
node_id = node.node_id
id2call_data[node_id] = call_data.get(node_id)
return id2call_data
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
nodes = []
if isinstance(end_node, BaseOperator):
task_id = end_node.node_id
if not task_id:
task_id = str(uuid.uuid4())
end_node.set_node_id(task_id)
nodes.append(end_node)
for node in end_node.upstream:
nodes += _build_from_end_node(node)
return nodes
def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
return list(filter(lambda x: not x.upstream, nodes))
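A sketch of how JobManager routes call data to root nodes (the module path pilot.awel.runner.job_manager is inferred from the relative imports): with a single root node the whole call_data dict is attached to it, otherwise it must be keyed by node id. Names are illustrative.

from pilot.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource
from pilot.awel.runner.job_manager import JobManager

with DAG("job_demo"):
    root = InputOperator(SimpleCallDataInputSource())
    end = MapOperator(lambda x: x)
    root >> end

manager = JobManager.build_from_end_node(end, call_data={"data": 42})
assert manager.get_call_data_by_id(root.node_id) == {"data": 42}
assert manager.get_call_data_by_id(end.node_id) is None  # only root nodes receive call data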


@@ -0,0 +1,68 @@
from typing import Dict, Optional
from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext
from .job_manager import JobManager
class DefaultWorkflowRunner(WorkflowRunner):
async def execute_workflow(
self, node: BaseOperator, call_data: Optional[CALL_DATA] = None
) -> DAGContext:
# Create DAG context
dag_ctx = DAGContext()
job_manager = JobManager.build_from_end_node(node, call_data)
dag = node.dag
# Save node output
node_outputs: Dict[str, TaskContext] = {}
await self._execute_node(job_manager, node, dag_ctx, node_outputs)
return dag_ctx
async def _execute_node(
self,
job_manager: JobManager,
node: BaseOperator,
dag_ctx: DAGContext,
node_outputs: Dict[str, TaskContext],
):
# Skip run node
if node.node_id in node_outputs:
return
# Run all upstream node
for upstream_node in node.upstream:
if isinstance(upstream_node, BaseOperator):
await self._execute_node(
job_manager, upstream_node, dag_ctx, node_outputs
)
# if node.current_task_context.current_state == TaskState.SKIP:
# return
# for upstream_node in node.upstream:
# if (
# isinstance(upstream_node, BaseOperator)
# and upstream_node.current_task_context.current_state == TaskState.SKIP
# ):
# return
# Get the input from upstream node
inputs = [
node_outputs[upstream_node.node_id] for upstream_node in node.upstream
]
input_ctx = DefaultInputContext(inputs)
task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
task_ctx.set_task_input(input_ctx)
dag_ctx.set_current_task_context(task_ctx)
task_ctx.set_current_state(TaskState.RUNNING)
try:
# print(f"Begin run {node}")
await node._run(dag_ctx)
node_outputs[node.node_id] = dag_ctx.current_task_context
task_ctx.set_current_state(TaskState.SUCCESS)
except Exception as e:
task_ctx.set_current_state(TaskState.FAILED)
raise e
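A sketch of driving a workflow directly through DefaultWorkflowRunner and inspecting the resulting task state, in the style of the tests later in this commit; names are illustrative.

import asyncio
from pilot.awel import (
    DAG,
    DefaultWorkflowRunner,
    InputOperator,
    MapOperator,
    SimpleInputSource,
    TaskState,
)

async def main():
    runner = DefaultWorkflowRunner()
    with DAG("runner_demo"):
        source = InputOperator(SimpleInputSource("ok"))
        upper = MapOperator(str.upper)
        source >> upper
    ctx = await runner.execute_workflow(upper)
    # The returned DAGContext holds the context of the last executed task.
    assert ctx.current_task_context.current_state == TaskState.SUCCESS
    assert ctx.current_task_context.task_output.output == "OK"

asyncio.run(main())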


358
pilot/awel/task/base.py Normal file

@@ -0,0 +1,358 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TypeVar,
Generic,
Optional,
AsyncIterator,
Union,
Callable,
Any,
Dict,
List,
)
IN = TypeVar("IN")
OUT = TypeVar("OUT")
T = TypeVar("T")
class TaskState(str, Enum):
"""Enumeration representing the state of a task in the workflow.
This Enum defines various states a task can be in during its lifecycle in the DAG.
"""
INIT = "init" # Initial state of the task, not yet started
SKIP = "skip" # State indicating the task was skipped
RUNNING = "running" # State indicating the task is currently running
SUCCESS = "success" # State indicating the task completed successfully
FAILED = "failed" # State indicating the task failed during execution
class TaskOutput(ABC, Generic[T]):
"""Abstract base class representing the output of a task.
This class encapsulates the output of a task and provides methods to access the output data.
It can be subclassed to implement specific output behaviors.
"""
@property
def is_stream(self) -> bool:
"""Check if the output is a stream.
Returns:
bool: True if the output is a stream, False otherwise.
"""
return False
@property
def is_empty(self) -> bool:
"""Check if the output is empty.
Returns:
bool: True if the output is empty, False otherwise.
"""
return False
@property
def output(self) -> Optional[T]:
"""Return the output of the task.
Returns:
T: The output of the task. None if the output is empty.
"""
raise NotImplementedError
@property
def output_stream(self) -> Optional[AsyncIterator[T]]:
"""Return the output of the task as an asynchronous stream.
Returns:
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
"""
raise NotImplementedError
@abstractmethod
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
"""Set the output data to current object.
Args:
output_data (Union[T, AsyncIterator[T]]): Output data.
"""
async def map(self, map_func) -> "TaskOutput[T]":
"""Apply a mapping function to the task's output.
Args:
map_func: A function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the mapping function.
"""
raise NotImplementedError
async def reduce(self, reduce_func) -> "TaskOutput[T]":
"""Apply a reducing function to the task's output.
Stream TaskOutput to Nonstream TaskOutput.
Args:
reduce_func: A reducing function to apply to the task's output.
Returns:
TaskOutput[T]: The result of applying the reducing function.
"""
raise NotImplementedError
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
Args:
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> "TaskOutput[T]":
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
Args:
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> "TaskOutput[T]":
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
Args:
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
Returns:
TaskOutput[T]: The result of applying the transform function.
"""
raise NotImplementedError
async def check_condition(self, condition_func) -> bool:
"""Check if current output meets a given condition.
Args:
condition_func: A function to determine if the condition is met.
Returns:
bool: True if the current output meets the condition, False otherwise.
"""
raise NotImplementedError
class TaskContext(ABC, Generic[T]):
"""Abstract base class representing the context of a task within a DAG.
This class provides the interface for accessing task-related information
and manipulating task output.
"""
@property
@abstractmethod
def task_id(self) -> str:
"""Return the unique identifier of the task.
Returns:
str: The unique identifier of the task.
"""
@property
@abstractmethod
def task_input(self) -> "InputContext":
"""Return the InputContext of current task.
Returns:
InputContext: The InputContext of current task.
"""
@abstractmethod
def set_task_input(self, input_ctx: "InputContext") -> None:
"""Set the InputContext object to current task.
Args:
input_ctx (InputContext): The InputContext of current task
"""
@property
@abstractmethod
def task_output(self) -> TaskOutput[T]:
"""Return the output object of the task.
Returns:
TaskOutput[T]: The output object of the task.
"""
@abstractmethod
def set_task_output(self, task_output: TaskOutput[T]) -> None:
"""Set the output object to current task."""
@property
@abstractmethod
def current_state(self) -> TaskState:
"""Get the current state of the task.
Returns:
TaskState: The current state of the task.
"""
@abstractmethod
def set_current_state(self, task_state: TaskState) -> None:
"""Set current task state
Args:
task_state (TaskState): The task state to be set.
"""
@abstractmethod
def new_ctx(self) -> "TaskContext":
"""Create new task context
Returns:
TaskContext: A new instance of a TaskContext.
"""
@property
@abstractmethod
def metadata(self) -> Dict[str, Any]:
"""Get the metadata of current task
Returns:
Dict[str, Any]: The metadata
"""
def update_metadata(self, key: str, value: Any) -> None:
"""Update metadata with key and value
Args:
key (str): The key of metadata
value (Any): The value to add to the metadata
"""
self.metadata[key] = value
@property
def call_data(self) -> Optional[Dict]:
"""Get the call data for current data"""
return self.metadata.get("call_data")
def set_call_data(self, call_data: Dict) -> None:
"""Set call data for current task"""
self.update_metadata("call_data", call_data)
class InputContext(ABC):
"""Abstract base class representing the context of inputs to a operator node.
This class defines methods to manipulate and access the inputs for a operator node.
"""
@property
@abstractmethod
def parent_outputs(self) -> List[TaskContext]:
"""Get the outputs from the parent nodes.
Returns:
List[TaskContext]: A list of contexts of the parent nodes' outputs.
"""
@abstractmethod
async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
"""Apply a mapping function to the inputs.
Args:
map_func (Callable[[Any], Any]): A function to be applied to the inputs.
Returns:
InputContext: A new InputContext instance with the mapped inputs.
"""
@abstractmethod
async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
"""Apply a mapping function to all inputs.
Args:
map_func (Callable[..., Any]): A function to be applied to all inputs.
Returns:
InputContext: A new InputContext instance with the mapped inputs.
"""
@abstractmethod
async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
"""Apply a reducing function to the inputs.
Args:
reduce_func (Callable[[Any], Any]): A function that reduces the inputs.
Returns:
InputContext: A new InputContext instance with the reduced inputs.
"""
@abstractmethod
async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
"""Filter the inputs based on a provided function.
Args:
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
Returns:
InputContext: A new InputContext instance with the filtered inputs.
"""
@abstractmethod
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
) -> "InputContext":
"""Predicate the inputs based on a provided function.
Args:
predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True.
failed_value (Any): The value to be set if the return value of predicate function is False
Returns:
InputContext: A new InputContext instance with the predicate inputs.
"""
def check_single_parent(self) -> bool:
"""Check if there is only a single parent output.
Returns:
bool: True if there is only one parent output, False otherwise.
"""
return len(self.parent_outputs) == 1
def check_stream(self) -> bool:
"""Check if all parent outputs are streams.
Returns:
bool: True if all parent outputs are streams, False otherwise.
"""
for out in self.parent_outputs:
if not (out.task_output and out.task_output.is_stream):
return False
return True
class InputSource(ABC, Generic[T]):
"""Abstract base class representing the source of inputs to a DAG node."""
@abstractmethod
async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
"""Read the data from current input source.
Returns:
TaskOutput[T]: The output object read from current source
"""


@@ -0,0 +1,318 @@
from abc import ABC, abstractmethod
from typing import (
Callable,
Coroutine,
Iterator,
AsyncIterator,
List,
Generic,
TypeVar,
Any,
Tuple,
Dict,
Union,
)
import asyncio
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
# Init accumulator
try:
accumulator = await stream.__anext__()
except StopAsyncIteration:
raise ValueError("Stream is empty")
is_async = asyncio.iscoroutinefunction(reduce_function)
async for element in stream:
if is_async:
accumulator = await reduce_function(accumulator, element)
else:
accumulator = reduce_function(accumulator, element)
return accumulator
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: T) -> None:
super().__init__()
self._data = data
@property
def output(self) -> T:
return self._data
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
self._data = output_data
@property
def is_empty(self) -> bool:
return not self._data
async def _apply_func(self, func) -> Any:
if asyncio.iscoroutinefunction(func):
out = await func(self._data)
else:
out = func(self._data)
return out
async def map(self, map_func) -> TaskOutput[T]:
out = await self._apply_func(map_func)
return SimpleTaskOutput(out)
async def check_condition(self, condition_func) -> bool:
return await self._apply_func(condition_func)
async def streamify(
self, transform_func: Callable[[T], AsyncIterator[T]]
) -> TaskOutput[T]:
out = await self._apply_func(transform_func)
return SimpleStreamTaskOutput(out)
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
def __init__(self, data: AsyncIterator[T]) -> None:
super().__init__()
self._data = data
@property
def is_stream(self) -> bool:
return True
@property
def is_empty(self) -> bool:
return not self._data
@property
def output_stream(self) -> AsyncIterator[T]:
return self._data
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
self._data = output_data
async def map(self, map_func) -> TaskOutput[T]:
is_async = asyncio.iscoroutinefunction(map_func)
async def new_iter() -> AsyncIterator[T]:
async for out in self._data:
if is_async:
out = await map_func(out)
else:
out = map_func(out)
yield out
return SimpleStreamTaskOutput(new_iter())
async def reduce(self, reduce_func) -> TaskOutput[T]:
out = await _reduce_stream(self._data, reduce_func)
return SimpleTaskOutput(out)
async def unstreamify(
self, transform_func: Callable[[AsyncIterator[T]], T]
) -> TaskOutput[T]:
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
else:
out = transform_func(self._data)
return SimpleTaskOutput(out)
async def transform_stream(
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
) -> TaskOutput[T]:
if asyncio.iscoroutinefunction(transform_func):
out = await transform_func(self._data)
else:
out = transform_func(self._data)
return SimpleStreamTaskOutput(out)
def _is_async_iterator(obj):
return (
hasattr(obj, "__anext__")
and callable(getattr(obj, "__anext__", None))
and hasattr(obj, "__aiter__")
and callable(getattr(obj, "__aiter__", None))
)
class BaseInputSource(InputSource, ABC):
def __init__(self) -> None:
super().__init__()
self._is_read = False
@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
"""Read data with task context"""
async def read(self, task_ctx: TaskContext) -> TaskOutput:
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output = SimpleStreamTaskOutput(data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
return output
class SimpleInputSource(BaseInputSource):
def __init__(self, data: Any) -> None:
super().__init__()
self._data = data
def _read_data(self, task_ctx: TaskContext) -> Any:
return self._data
class SimpleCallDataInputSource(BaseInputSource):
def __init__(self) -> None:
super().__init__()
def _read_data(self, task_ctx: TaskContext) -> Any:
call_data = task_ctx.call_data
data = call_data.get("data") if call_data else None
if not (call_data and data):
raise ValueError("No call data for current SimpleCallDataInputSource")
return data
class DefaultTaskContext(TaskContext, Generic[T]):
def __init__(
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
) -> None:
super().__init__()
self._task_id = task_id
self._task_state = task_state
self._output = task_output
self._task_input = None
self._metadata = {}
@property
def task_id(self) -> str:
return self._task_id
@property
def task_input(self) -> InputContext:
return self._task_input
def set_task_input(self, input_ctx: "InputContext") -> None:
self._task_input = input_ctx
@property
def task_output(self) -> TaskOutput:
return self._output
def set_task_output(self, task_output: TaskOutput) -> None:
self._output = task_output
@property
def current_state(self) -> TaskState:
return self._task_state
def set_current_state(self, task_state: TaskState) -> None:
self._task_state = task_state
def new_ctx(self) -> TaskContext:
return DefaultTaskContext(self._task_id, self._task_state, self._output)
@property
def metadata(self) -> Dict[str, Any]:
return self._metadata
class DefaultInputContext(InputContext):
def __init__(self, outputs: List[TaskContext]) -> None:
super().__init__()
self._outputs = outputs
@property
def parent_outputs(self) -> List[TaskContext]:
return self._outputs
async def _apply_func(
self, func: Callable[[Any], Any], apply_type: str = "map"
) -> Tuple[List[TaskContext], List[TaskOutput]]:
new_outputs: List[TaskContext] = []
map_tasks = []
for out in self._outputs:
new_outputs.append(out.new_ctx())
result = None
if apply_type == "map":
result = out.task_output.map(func)
elif apply_type == "reduce":
result = out.task_output.reduce(func)
elif apply_type == "check_condition":
result = out.task_output.check_condition(func)
else:
raise ValueError(f"Unsupport apply type {apply_type}")
map_tasks.append(result)
results = await asyncio.gather(*map_tasks)
return new_outputs, results
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
new_outputs, results = await self._apply_func(map_func)
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
if not self._outputs:
return DefaultInputContext([])
is_stream = self._outputs[0].task_output.is_stream
if is_stream:
if not self.check_stream():
raise ValueError(
"The outputs of all tasks must have the same stream output format for map_all"
)
outputs = []
for out in self._outputs:
if out.task_output.is_stream:
outputs.append(out.task_output.output_stream)
else:
outputs.append(out.task_output.output)
if asyncio.iscoroutinefunction(map_func):
map_res = await map_func(*outputs)
else:
map_res = map_func(*outputs)
single_output: TaskContext = self._outputs[0].new_ctx()
single_output.task_output.set_output(map_res)
return DefaultInputContext([single_output])
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
if not self.check_stream():
raise ValueError(
"The output in all tasks must has same output format of stream to apply reduce function"
)
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
task_ctx.set_task_output(results[i])
return DefaultInputContext(new_outputs)
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
new_outputs, results = await self._apply_func(
filter_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
if results[i]:
result_outputs.append(task_ctx)
return DefaultInputContext(result_outputs)
async def predicate_map(
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
) -> "InputContext":
new_outputs, results = await self._apply_func(
predicate_func, apply_type="check_condition"
)
result_outputs = []
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
if results[i]:
result_outputs.append(task_ctx)
else:
task_ctx.task_output.set_output(failed_value)
result_outputs.append(task_ctx)
return DefaultInputContext(result_outputs)
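A sketch of feeding call data into a workflow through SimpleCallDataInputSource: the dict passed to call() is routed to the root node and must contain a "data" entry. Names are illustrative.

import asyncio
from pilot.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource

async def main():
    with DAG("call_data_demo"):
        source = InputOperator(SimpleCallDataInputSource())
        double = MapOperator(lambda x: x * 2)
        source >> double
    print(await double.call(call_data={"data": 21}))  # 42

asyncio.run(main())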


@@ -0,0 +1,102 @@
import pytest
import pytest_asyncio
from typing import AsyncIterator, List
from contextlib import contextmanager, asynccontextmanager
from .. import (
WorkflowRunner,
InputOperator,
DAGContext,
TaskState,
DefaultWorkflowRunner,
SimpleInputSource,
)
from ..task.task_impl import _is_async_iterator
@pytest.fixture
def runner():
return DefaultWorkflowRunner()
def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
iters = []
for _ in range(num_nodes):
async def stream_iter():
for i in range(10):
yield i
stream_iter = stream_iter()
assert _is_async_iterator(stream_iter)
iters.append(stream_iter)
return iters
def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
iters = []
for single_stream in output_streams:
async def stream_iter():
for i in single_stream:
yield i
stream_iter = stream_iter()
assert _is_async_iterator(stream_iter)
iters.append(stream_iter)
return iters
@asynccontextmanager
async def _create_input_node(**kwargs):
num_nodes = kwargs.get("num_nodes")
is_stream = kwargs.get("is_stream", False)
if is_stream:
outputs = kwargs.get("output_streams")
if outputs:
if num_nodes and num_nodes != len(outputs):
raise ValueError(
f"num_nodes {num_nodes} != the length of output_streams {len(outputs)}"
)
outputs = _create_stream_from(outputs)
else:
num_nodes = num_nodes or 1
outputs = _create_stream(num_nodes)
else:
outputs = kwargs.get("outputs", ["Hello."])
nodes = []
for output in outputs:
print(f"output: {output}")
input_source = SimpleInputSource(output)
input_node = InputOperator(input_source)
nodes.append(input_node)
yield nodes
@pytest_asyncio.fixture
async def input_node(request):
param = getattr(request, "param", {})
async with _create_input_node(**param) as input_nodes:
yield input_nodes[0]
@pytest_asyncio.fixture
async def stream_input_node(request):
param = getattr(request, "param", {})
param["is_stream"] = True
async with _create_input_node(**param) as input_nodes:
yield input_nodes[0]
@pytest_asyncio.fixture
async def input_nodes(request):
param = getattr(request, "param", {})
async with _create_input_node(**param) as input_nodes:
yield input_nodes
@pytest_asyncio.fixture
async def stream_input_nodes(request):
param = getattr(request, "param", {})
param["is_stream"] = True
async with _create_input_node(**param) as input_nodes:
yield input_nodes


@@ -0,0 +1,51 @@
import pytest
from typing import List
from .. import (
DAG,
WorkflowRunner,
DAGContext,
TaskState,
InputOperator,
MapOperator,
JoinOperator,
BranchOperator,
ReduceStreamOperator,
SimpleInputSource,
)
from .conftest import (
runner,
input_node,
input_nodes,
stream_input_node,
stream_input_nodes,
_is_async_iterator,
)
def _register_dag_to_fastapi_app(dag):
# TODO
pass
@pytest.mark.asyncio
async def test_http_operator(runner: WorkflowRunner, stream_input_node: InputOperator):
with DAG("test_map") as dag:
pass
# http_req_task = HttpRequestOperator(endpoint="/api/completions")
# db_task = DBQueryOperator(table_name="user_info")
# prompt_task = PromptTemplateOperator(
# system_prompt="You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
# )
# llm_task = ChatGPTLLMOperator(model="chagpt-3.5")
# output_parser_task = CommonOutputParserOperator()
# http_res_task = HttpResponseOperator()
# (
# http_req_task
# >> db_task
# >> prompt_task
# >> llm_task
# >> output_parser_task
# >> http_res_task
# )
_register_dag_to_fastapi_app(dag)


@@ -0,0 +1,128 @@
import pytest
from typing import List
from .. import (
DAG,
WorkflowRunner,
DAGContext,
TaskState,
InputOperator,
MapOperator,
JoinOperator,
BranchOperator,
ReduceStreamOperator,
SimpleInputSource,
)
from .conftest import (
runner,
input_node,
input_nodes,
stream_input_node,
stream_input_nodes,
_is_async_iterator,
)
@pytest.mark.asyncio
async def test_input_node(runner: WorkflowRunner):
input_node = InputOperator(SimpleInputSource("hello"))
res: DAGContext[str] = await runner.execute_workflow(input_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert res.current_task_context.task_output.output == "hello"
async def new_steam_iter(n: int):
for i in range(n):
yield i
num_iter = 10
steam_input_node = InputOperator(SimpleInputSource(new_steam_iter(num_iter)))
res: DAGContext[str] = await runner.execute_workflow(steam_input_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
output_steam = res.current_task_context.task_output.output_stream
assert output_steam
assert _is_async_iterator(output_steam)
i = 0
async for x in output_steam:
assert x == i
i += 1
@pytest.mark.asyncio
async def test_map_node(runner: WorkflowRunner, stream_input_node: InputOperator):
with DAG("test_map") as dag:
map_node = MapOperator(lambda x: x * 2)
stream_input_node >> map_node
res: DAGContext[int] = await runner.execute_workflow(map_node)
output_steam = res.current_task_context.task_output.output_stream
assert output_steam
i = 0
async for x in output_steam:
assert x == i * 2
i += 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
"stream_input_node, expect_sum",
[
({"output_streams": [[0, 1, 2, 3]]}, 6),
({"output_streams": [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]}, 55),
],
indirect=["stream_input_node"],
)
async def test_reduce_node(
runner: WorkflowRunner, stream_input_node: InputOperator, expect_sum: int
):
with DAG("test_reduce_node") as dag:
reduce_node = ReduceStreamOperator(lambda x, y: x + y)
stream_input_node >> reduce_node
res: DAGContext[int] = await runner.execute_workflow(reduce_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert not res.current_task_context.task_output.is_stream
assert res.current_task_context.task_output.output == expect_sum
@pytest.mark.asyncio
@pytest.mark.parametrize(
"input_nodes",
[
({"outputs": [0, 1, 2]}),
],
indirect=["input_nodes"],
)
async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator]):
def join_func(p1, p2, p3) -> int:
return p1 + p2 + p3
with DAG("test_join_node") as dag:
join_node = JoinOperator(join_func)
for input_node in input_nodes:
input_node >> join_node
res: DAGContext[int] = await runner.execute_workflow(join_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
assert not res.current_task_context.task_output.is_stream
assert res.current_task_context.task_output.output == 3
@pytest.mark.asyncio
@pytest.mark.parametrize(
"input_node, is_odd",
[
({"outputs": [0]}, False),
# ({"outputs": [1]}, False),
],
indirect=["input_node"],
)
async def test_branch_node(
runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
):
with DAG("test_join_node") as dag:
odd_node = MapOperator(lambda x: 999, task_id="odd_node")
even_node = MapOperator(lambda x: 888, task_id="even_node")
branch_node = BranchOperator(
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
)
input_node >> branch_node
odd_res: DAGContext[int] = await runner.execute_workflow(odd_node)
even_res: DAGContext[int] = await runner.execute_workflow(even_node)
assert branch_node.current_task_context.current_state == TaskState.SUCCESS

10
pilot/cache/__init__.py vendored Normal file

@@ -0,0 +1,10 @@
from pilot.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
from pilot.cache.manager import CacheManager, initialize_cache
__all__ = [
"LLMCacheKey",
"LLMCacheValue",
"LLMCacheClient",
"CacheManager",
"initialize_cache",
]

161
pilot/cache/base.py vendored Normal file

@@ -0,0 +1,161 @@
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, TypeVar, Generic, Optional, Type, Dict
from dataclasses import dataclass
from enum import Enum
T = TypeVar("T", bound="Serializable")
K = TypeVar("K")
V = TypeVar("V")
class Serializable(ABC):
@abstractmethod
def serialize(self) -> bytes:
"""Convert the object into bytes for storage or transmission.
Returns:
bytes: The byte array after serialization
"""
@abstractmethod
def to_dict(self) -> Dict:
"""Convert the object's state to a dictionary."""
# @staticmethod
# @abstractclassmethod
# def from_dict(cls: Type["Serializable"], obj_dict: Dict) -> "Serializable":
# """Deserialize a dictionary to an Serializable object.
# """
class RetrievalPolicy(str, Enum):
EXACT_MATCH = "exact_match"
SIMILARITY_MATCH = "similarity_match"
class CachePolicy(str, Enum):
LRU = "lru"
FIFO = "fifo"
@dataclass
class CacheConfig:
retrieval_policy: Optional[RetrievalPolicy] = RetrievalPolicy.EXACT_MATCH
cache_policy: Optional[CachePolicy] = CachePolicy.LRU
class CacheKey(Serializable, ABC, Generic[K]):
"""The key of the cache. Must be hashable and comparable.
Supported cache keys:
- The LLM cache key: Include user prompt and the parameters to LLM.
- The embedding model cache key: Include the texts to embedding and the parameters to embedding model.
"""
@abstractmethod
def __hash__(self) -> int:
"""Return the hash value of the key."""
@abstractmethod
def __eq__(self, other: Any) -> bool:
"""Check equality with another key."""
@abstractmethod
def get_hash_bytes(self) -> bytes:
"""Return the byte array of hash value."""
@abstractmethod
def get_value(self) -> K:
"""Get the underlying value of the cache key.
Returns:
K: The real object of current cache key
"""
class CacheValue(Serializable, ABC, Generic[V]):
"""Cache value abstract class."""
@abstractmethod
def get_value(self) -> V:
"""Get the underlying real value."""
class Serializer(ABC):
"""The serializer abstract class for serializing cache keys and values."""
@abstractmethod
def serialize(self, obj: Serializable) -> bytes:
"""Serialize a cache object.
Args:
obj (Serializable): The object to serialize
"""
@abstractmethod
def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
"""Deserialize data back into a cache object of the specified type.
Args:
data (bytes): The byte array to deserialize
cls (Type[Serializable]): The type of current object
Returns:
Serializable: The serializable object
"""
class CacheClient(ABC, Generic[K, V]):
"""The cache client interface."""
@abstractmethod
async def get(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[CacheValue[V]]:
"""Retrieve a value from the cache using the provided key.
Args:
key (CacheKey[K]): The key to get cache
cache_config (Optional[CacheConfig]): Cache config
Returns:
Optional[CacheValue[V]]: The value retrieved according to key. If cache key not exist, return None.
"""
@abstractmethod
async def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
"""Set a value in the cache for the provided key.
Args:
key (CacheKey[K]): The key to set to cache
value (CacheValue[V]): The value to set to cache
cache_config (Optional[CacheConfig]): Cache config
"""
@abstractmethod
async def exists(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> bool:
"""Check if a key exists in the cache.
Args:
key (CacheKey[K]): The key to set to cache
cache_config (Optional[CacheConfig]): Cache config
Return:
bool: True if the key exists in the cache, False otherwise
"""
@abstractmethod
def new_key(self, **kwargs) -> CacheKey[K]:
"""Create a cache key with params"""
@abstractmethod
def new_value(self, **kwargs) -> CacheValue[V]:
"""Create a cache value with params"""

0
pilot/cache/embedding_cache.py vendored Normal file

148
pilot/cache/llm_cache.py vendored Normal file

@@ -0,0 +1,148 @@
from typing import Optional, Dict, Any, Union, List
from dataclasses import dataclass, asdict
import json
import hashlib
from pilot.cache.base import CacheKey, CacheValue, Serializer, CacheClient, CacheConfig
from pilot.cache.manager import CacheManager
from pilot.cache.storage.base import CacheStorage
from pilot.model.base import ModelType, ModelOutput
@dataclass
class LLMCacheKeyData:
prompt: str
model_name: str
temperature: Optional[float] = 0.7
max_new_tokens: Optional[int] = None
top_p: Optional[float] = 1.0
model_type: Optional[str] = ModelType.HF
CacheOutputType = Union[ModelOutput, List[ModelOutput]]
@dataclass
class LLMCacheValueData:
output: CacheOutputType
user: Optional[str] = None
_is_list: Optional[bool] = False
@staticmethod
def from_dict(**kwargs) -> "LLMCacheValueData":
output = kwargs.get("output")
if not output:
raise ValueError("Can't new LLMCacheValueData object, output is None")
if isinstance(output, dict):
output = ModelOutput(**output)
elif isinstance(output, list):
kwargs["_is_list"] = True
output_list = []
for out in output:
if isinstance(out, dict):
out = ModelOutput(**out)
output_list.append(out)
output = output_list
kwargs["output"] = output
return LLMCacheValueData(**kwargs)
def to_dict(self) -> Dict:
output = self.output
is_list = False
if isinstance(output, list):
output_list = []
is_list = True
for out in output:
output_list.append(out.to_dict())
output = output_list
else:
output = output.to_dict()
return {"output": output, "_is_list": is_list, "user": self.user}
@property
def is_list(self) -> bool:
return self._is_list
def __str__(self) -> str:
if not isinstance(self.output, list):
return f"user: {self.user}, output: {self.output}"
else:
return f"user: {self.user}, output(last two item): {self.output[-2:]}"
class LLMCacheKey(CacheKey[LLMCacheKeyData]):
def __init__(self, serializer: Optional[Serializer] = None, **kwargs) -> None:
super().__init__()
self._serializer = serializer
self.config = LLMCacheKeyData(**kwargs)
def __hash__(self) -> int:
serialize_bytes = self.serialize()
return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, LLMCacheKey):
return False
return self.config == other.config
def get_hash_bytes(self) -> bytes:
serialize_bytes = self.serialize()
return hashlib.sha256(serialize_bytes).digest()
def to_dict(self) -> Dict:
return asdict(self.config)
def serialize(self) -> bytes:
return self._serializer.serialize(self)
def get_value(self) -> LLMCacheKeyData:
return self.config
class LLMCacheValue(CacheValue[LLMCacheValueData]):
def __init__(self, serializer: Optional[Serializer] = None, **kwargs) -> None:
super().__init__()
self._serializer = serializer
self.value = LLMCacheValueData.from_dict(**kwargs)
def to_dict(self) -> Dict:
return self.value.to_dict()
def serialize(self) -> bytes:
return self._serializer.serialize(self)
def get_value(self) -> LLMCacheValueData:
return self.value
def __str__(self) -> str:
return f"vaue: {str(self.value)}"
class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
def __init__(self, cache_manager: CacheManager) -> None:
super().__init__()
self._cache_manager: CacheManager = cache_manager
async def get(
self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
) -> Optional[LLMCacheValue]:
return await self._cache_manager.get(key, LLMCacheValue, cache_config)
async def set(
self,
key: LLMCacheKey,
value: LLMCacheValue,
cache_config: Optional[CacheConfig] = None,
) -> None:
return await self._cache_manager.set(key, value, cache_config)
async def exists(
self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
) -> bool:
return await self.get(key, cache_config) is not None
def new_key(self, **kwargs) -> LLMCacheKey:
return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)
def new_value(self, **kwargs) -> LLMCacheValue:
return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
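A minimal usage sketch of the LLM cache client, assuming a CacheManager is already registered (see pilot/cache/manager.py below); the helper name warm_and_read is illustrative and the ModelOutput constructor arguments are assumed, since they are not shown in this diff:

    from pilot.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue
    from pilot.model.base import ModelOutput

    async def warm_and_read(cache_manager: CacheManager) -> None:
        client = LLMCacheClient(cache_manager)
        key: LLMCacheKey = client.new_key(
            prompt="Hello", model_name="vicuna-13b", temperature=0.7, top_p=1.0
        )
        if not await client.exists(key):
            # ModelOutput fields below are assumed, not shown in this diff.
            value: LLMCacheValue = client.new_value(
                output=ModelOutput(text="Hi there", error_code=0)
            )
            await client.set(key, value)
        cached = await client.get(key)
        print(cached.get_value().output if cached else None)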

118
pilot/cache/manager.py vendored Normal file
View File

@@ -0,0 +1,118 @@
from abc import ABC, abstractmethod
from typing import Optional, Type
import logging
from concurrent.futures import Executor
from pilot.cache.storage.base import CacheStorage, StorageItem
from pilot.cache.base import (
K,
V,
CacheKey,
CacheValue,
CacheConfig,
Serializer,
Serializable,
)
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
logger = logging.getLogger(__name__)
class CacheManager(BaseComponent, ABC):
name = ComponentType.MODEL_CACHE_MANAGER
def __init__(self, system_app: SystemApp | None = None):
super().__init__(system_app)
def init_app(self, system_app: SystemApp):
self.system_app = system_app
@abstractmethod
async def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
):
"""Set cache"""
@abstractmethod
async def get(
self,
key: CacheKey[K],
cls: Type[Serializable],
cache_config: Optional[CacheConfig] = None,
) -> CacheValue[V]:
"""Get cache with key"""
@property
@abstractmethod
def serializer(self) -> Serializer:
"""Get cache serializer"""
class LocalCacheManager(CacheManager):
def __init__(
self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
) -> None:
super().__init__(system_app)
self._serializer = serializer
self._storage = storage
@property
def executor(self) -> Executor:
    """Return the executor used to run blocking storage calls."""
    self._executor = self.system_app.get_component(
        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
    ).create()
    return self._executor
async def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
):
if self._storage.support_async():
await self._storage.aset(key, value, cache_config)
else:
await blocking_func_to_async(
self.executor, self._storage.set, key, value, cache_config
)
async def get(
self,
key: CacheKey[K],
cls: Type[Serializable],
cache_config: Optional[CacheConfig] = None,
) -> CacheValue[V]:
if self._storage.support_async():
item_bytes = await self._storage.aget(key, cache_config)
else:
item_bytes = await blocking_func_to_async(
self.executor, self._storage.get, key, cache_config
)
if not item_bytes:
return None
return self._serializer.deserialize(item_bytes.value_data, cls)
@property
def serializer(self) -> Serializer:
return self._serializer
def initialize_cache(system_app: SystemApp, persist_dir: str):
from pilot.cache.protocal.json_protocal import JsonSerializer
from pilot.cache.storage.base import MemoryCacheStorage
try:
from pilot.cache.storage.disk.disk_storage import DiskCacheStorage
cache_storage = DiskCacheStorage(persist_dir)
except ImportError as e:
logger.warning(
    f"Can't import DiskCacheStorage, falling back to MemoryCacheStorage, import error message: {str(e)}"
)
cache_storage = MemoryCacheStorage()
system_app.register(
LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
)

0
pilot/cache/protocal/__init__.py vendored Normal file
View File

44
pilot/cache/protocal/json_protocal.py vendored Normal file
View File

@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from typing import Dict, Type
import json
from pilot.cache.base import Serializable, Serializer
JSON_ENCODING = "utf-8"
class JsonSerializable(Serializable, ABC):
@abstractmethod
def to_dict(self) -> Dict:
"""Return the dict of current serializable object"""
def serialize(self) -> bytes:
"""Convert the object into bytes for storage or transmission."""
return json.dumps(self.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
class JsonSerializer(Serializer):
"""The serializer abstract class for serializing cache keys and values."""
def serialize(self, obj: Serializable) -> bytes:
"""Serialize a cache object.
Args:
obj (Serializable): The object to serialize
"""
return json.dumps(obj.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
"""Deserialize data back into a cache object of the specified type.
Args:
data (bytes): The byte array to deserialize
cls (Type[Serializable]): The target type to deserialize into
Returns:
Serializable: The deserialized object
"""
# Convert bytes back to JSON and then to the specified class
json_data = json.loads(data.decode(JSON_ENCODING))
# Assume that the cls has an __init__ that accepts a dictionary
return cls(**json_data)
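A round-trip sketch of the JSON protocol; _Greeting is an illustrative type, and it assumes Serializable declares no abstract methods beyond serialize and to_dict:

    from dataclasses import asdict, dataclass
    from typing import Dict

    from pilot.cache.protocal.json_protocal import JsonSerializable, JsonSerializer

    @dataclass
    class _Greeting(JsonSerializable):
        message: str

        def to_dict(self) -> Dict:
            return asdict(self)

    serializer = JsonSerializer()
    data = serializer.serialize(_Greeting(message="hello"))
    # deserialize passes the decoded dict as keyword arguments to the target class.
    restored = serializer.deserialize(data, _Greeting)
    assert restored == _Greeting(message="hello")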

0
pilot/cache/storage/__init__.py vendored Normal file
View File

252
pilot/cache/storage/base.py vendored Normal file
View File

@@ -0,0 +1,252 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
from collections import OrderedDict
import msgpack
import logging
from pilot.cache.base import (
K,
V,
CacheKey,
CacheValue,
CacheClient,
CacheConfig,
RetrievalPolicy,
CachePolicy,
)
from pilot.utils.memory_utils import _get_object_bytes
logger = logging.getLogger(__name__)
@dataclass
class StorageItem:
"""
A class representing a storage item.
This class encapsulates data related to a storage item, such as its length,
the hash of the key, and the data for both the key and value.
Parameters:
length (int): The bytes length of the storage item.
key_hash (bytes): The hash value of the storage item's key.
key_data (bytes): The data of the storage item's key, represented in bytes.
value_data (bytes): The data of the storage item's value, also in bytes.
"""
length: int # The bytes length of the storage item
key_hash: bytes # The hash value of the storage item's key
key_data: bytes # The data of the storage item's key
value_data: bytes # The data of the storage item's value
@staticmethod
def build_from(
key_hash: bytes, key_data: bytes, value_data: bytes
) -> "StorageItem":
length = (
32
+ _get_object_bytes(key_hash)
+ _get_object_bytes(key_data)
+ _get_object_bytes(value_data)
)
return StorageItem(
length=length, key_hash=key_hash, key_data=key_data, value_data=value_data
)
@staticmethod
def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
key_hash = key.get_hash_bytes()
key_data = key.serialize()
value_data = value.serialize()
return StorageItem.build_from(key_hash, key_data, value_data)
def serialize(self) -> bytes:
"""Serialize the StorageItem into a byte stream using MessagePack.
This method packs the object data into a dictionary, marking the
key_data and value_data fields as raw binary data to avoid re-serialization.
Returns:
bytes: The serialized bytes.
"""
obj = {
"length": self.length,
"key_hash": msgpack.ExtType(1, self.key_hash),
"key_data": msgpack.ExtType(2, self.key_data),
"value_data": msgpack.ExtType(3, self.value_data),
}
return msgpack.packb(obj)
@staticmethod
def deserialize(data: bytes) -> "StorageItem":
"""Deserialize bytes back into a StorageItem using MessagePack.
This extracts the fields from the MessagePack dict back into
a StorageItem object.
Args:
data (bytes): Serialized bytes
Returns:
StorageItem: Deserialized StorageItem object.
"""
obj = msgpack.unpackb(data)
key_hash = obj["key_hash"].data
key_data = obj["key_data"].data
value_data = obj["value_data"].data
return StorageItem(
length=obj["length"],
key_hash=key_hash,
key_data=key_data,
value_data=value_data,
)
class CacheStorage(ABC):
@abstractmethod
def check_config(
self,
cache_config: Optional[CacheConfig] = None,
raise_error: Optional[bool] = True,
) -> bool:
"""Check whether the CacheConfig is legal.
Args:
cache_config (Optional[CacheConfig]): Cache config.
raise_error (Optional[bool]): Whether raise error if illegal.
Returns:
ValueError: Error when raise_error is True and config is illegal.
"""
def support_async(self) -> bool:
return False
@abstractmethod
def get(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[StorageItem]:
"""Retrieve a storage item from the cache using the provided key.
Args:
key (CacheKey[K]): The key to get cache
cache_config (Optional[CacheConfig]): Cache config
Returns:
Optional[StorageItem]: The storage item retrieved for the key, or None if the key does not exist in the cache.
"""
async def aget(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[StorageItem]:
"""Retrieve a storage item from the cache using the provided key asynchronously.
Args:
key (CacheKey[K]): The key to get cache
cache_config (Optional[CacheConfig]): Cache config
Returns:
Optional[StorageItem]: The storage item retrieved for the key, or None if the key does not exist in the cache.
"""
raise NotImplementedError
@abstractmethod
def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
"""Set a value in the cache for the provided key asynchronously.
Args:
key (CacheKey[K]): The key to set to cache
value (CacheValue[V]): The value to set to cache
cache_config (Optional[CacheConfig]): Cache config
"""
async def aset(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
"""Set a value in the cache for the provided key asynchronously.
Args:
key (CacheKey[K]): The key to set to cache
value (CacheValue[V]): The value to set to cache
cache_config (Optional[CacheConfig]): Cache config
"""
raise NotImplementedError
class MemoryCacheStorage(CacheStorage):
def __init__(self, max_memory_mb: int = 1024):
self.cache = OrderedDict()
self.max_memory = max_memory_mb * 1024 * 1024
self.current_memory_usage = 0
def check_config(
self,
cache_config: Optional[CacheConfig] = None,
raise_error: Optional[bool] = True,
) -> bool:
if (
cache_config
and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
):
if raise_error:
raise ValueError(
"MemoryCacheStorage only supports 'EXACT_MATCH' retrieval policy"
)
return False
return True
def get(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[StorageItem]:
self.check_config(cache_config, raise_error=True)
# Exact match retrieval
key_hash = hash(key)
item: StorageItem = self.cache.get(key_hash)
logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")
if not item:
return None
# Move the item to the end of the OrderedDict to signify recent use.
self.cache.move_to_end(key_hash)
return item
def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
key_hash = hash(key)
item = StorageItem.build_from_kv(key, value)
# Calculate memory size of the new entry
new_entry_size = _get_object_bytes(item)
# Evict entries if necessary
while self.current_memory_usage + new_entry_size > self.max_memory:
self._apply_cache_policy(cache_config)
# Store the item in the cache.
self.cache[key_hash] = item
self.current_memory_usage += new_entry_size
logger.debug(f"MemoryCacheStorage set key {key}, hash {key_hash}, item: {item}")
def exists(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> bool:
return self.get(key, cache_config) is not None
def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
    # Evict one entry based on the cache policy and release its tracked memory.
    if cache_config and cache_config.cache_policy == CachePolicy.FIFO:
        # FIFO: evict the oldest inserted entry.
        _, evicted_item = self.cache.popitem(last=False)
    else:
        # Default is LRU: get() moves accessed entries to the end of the
        # OrderedDict, so the least recently used entry is at the front.
        _, evicted_item = self.cache.popitem(last=False)
    self.current_memory_usage -= _get_object_bytes(evicted_item)
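A sketch tying the pieces together with the in-memory backend; the ModelOutput constructor arguments are assumed, since they are not shown in this diff:

    from pilot.cache.llm_cache import LLMCacheKey, LLMCacheValue
    from pilot.cache.protocal.json_protocal import JsonSerializer
    from pilot.cache.storage.base import MemoryCacheStorage
    from pilot.model.base import ModelOutput

    serializer = JsonSerializer()
    storage = MemoryCacheStorage(max_memory_mb=16)
    key = LLMCacheKey(
        serializer=serializer,
        prompt="Hello",
        model_name="vicuna-13b",
        model_type="huggingface",
    )
    # ModelOutput fields are assumed, not shown in this diff.
    value = LLMCacheValue(
        serializer=serializer, output=ModelOutput(text="Hi", error_code=0)
    )
    storage.set(key, value)
    item = storage.get(key)  # exact-match lookup via hash(key)
    assert item is not None and item.value_data == value.serialize()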

0
pilot/cache/storage/disk/__init__.py vendored Normal file
View File

93
pilot/cache/storage/disk/disk_storage.py vendored Normal file
View File

@@ -0,0 +1,93 @@
from typing import Optional
import logging
from pilot.cache.base import (
K,
V,
CacheKey,
CacheValue,
CacheConfig,
RetrievalPolicy,
CachePolicy,
)
from pilot.cache.storage.base import StorageItem, CacheStorage
from rocksdict import Rdict, Options, SliceTransform, PlainTableFactoryOptions
logger = logging.getLogger(__name__)
def db_options(
mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
):
opt = Options()
# Create the database if it does not exist
opt.create_if_missing(True)
# Allow more background jobs (RocksDB default is 2)
opt.set_max_background_jobs(background_threads)
# Configure the mem-table (write buffer) size
opt.set_write_buffer_size(mem_table_buffer_mb * 1024 * 1024)
# opt.set_write_buffer_size(1024)
# opt.set_level_zero_file_num_compaction_trigger(4)
# configure l0 and l1 size, let them have the same size (1 GB)
# opt.set_max_bytes_for_level_base(0x40000000)
# 256 MB file size
# opt.set_target_file_size_base(0x10000000)
# use a smaller compaction multiplier
# opt.set_max_bytes_for_level_multiplier(4.0)
# use 8-byte prefix (2 ^ 64 is far enough for transaction counts)
# opt.set_prefix_extractor(SliceTransform.create_max_len_prefix(8))
# set to plain-table
# opt.set_plain_table_factory(PlainTableFactoryOptions())
return opt
class DiskCacheStorage(CacheStorage):
def __init__(
self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
) -> None:
super().__init__()
self.db: Rdict = Rdict(
persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
)
def check_config(
self,
cache_config: Optional[CacheConfig] = None,
raise_error: Optional[bool] = True,
) -> bool:
if (
cache_config
and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
):
if raise_error:
raise ValueError(
"DiskCacheStorage only supports 'EXACT_MATCH' retrieval policy"
)
return False
return True
def get(
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
) -> Optional[StorageItem]:
self.check_config(cache_config, raise_error=True)
# Exact match retrieval
key_hash = key.get_hash_bytes()
item_bytes = self.db.get(key_hash)
if not item_bytes:
return None
item = StorageItem.deserialize(item_bytes)
logger.debug(f"Read file cache, key: {key}, storage item: {item}")
return item
def set(
self,
key: CacheKey[K],
value: CacheValue[V],
cache_config: Optional[CacheConfig] = None,
) -> None:
item = StorageItem.build_from_kv(key, value)
key_hash = item.key_hash
self.db[key_hash] = item.serialize()
logger.debug(f"Save file cache, key: {key}, value: {value}")

0
pilot/cache/storage/tests/__init__.py vendored Normal file
View File

View File

@@ -0,0 +1,53 @@
import pytest
from ..base import StorageItem
from pilot.utils.memory_utils import _get_object_bytes
def test_build_from():
key_hash = b"key_hash"
key_data = b"key_data"
value_data = b"value_data"
item = StorageItem.build_from(key_hash, key_data, value_data)
assert item.key_hash == key_hash
assert item.key_data == key_data
assert item.value_data == value_data
assert item.length == 32 + _get_object_bytes(key_hash) + _get_object_bytes(
key_data
) + _get_object_bytes(value_data)
def test_build_from_kv():
class MockCacheKey:
def get_hash_bytes(self):
return b"key_hash"
def serialize(self):
return b"key_data"
class MockCacheValue:
def serialize(self):
return b"value_data"
key = MockCacheKey()
value = MockCacheValue()
item = StorageItem.build_from_kv(key, value)
assert item.key_hash == key.get_hash_bytes()
assert item.key_data == key.serialize()
assert item.value_data == value.serialize()
def test_serialize_deserialize():
key_hash = b"key_hash"
key_data = b"key_data"
value_data = b"value_data"
item = StorageItem.build_from(key_hash, key_data, value_data)
serialized = item.serialize()
deserialized = StorageItem.deserialize(serialized)
assert deserialized.key_hash == item.key_hash
assert deserialized.key_data == item.key_data
assert deserialized.value_data == item.value_data
assert deserialized.length == item.length
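A possible companion test for the in-memory backend (a sketch; the fake key and value implement only the methods that StorageItem.build_from_kv and MemoryCacheStorage rely on):

    from ..base import MemoryCacheStorage

    class _FakeKey:
        def __hash__(self) -> int:
            return hash(b"key_hash")

        def get_hash_bytes(self) -> bytes:
            return b"key_hash"

        def serialize(self) -> bytes:
            return b"key_data"

    class _FakeValue:
        def serialize(self) -> bytes:
            return b"value_data"

    def test_memory_storage_set_get_exists():
        storage = MemoryCacheStorage(max_memory_mb=1)
        key, value = _FakeKey(), _FakeValue()
        storage.set(key, value)
        item = storage.get(key)
        assert item is not None
        assert item.value_data == b"value_data"
        assert storage.exists(key)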

View File

@@ -48,6 +48,7 @@ class ComponentType(str, Enum):
MODEL_CONTROLLER = "dbgpt_model_controller"
MODEL_REGISTRY = "dbgpt_model_registry"
MODEL_API_SERVER = "dbgpt_model_api_server"
MODEL_CACHE_MANAGER = "dbgpt_model_cache_manager"
AGENT_HUB = "dbgpt_agent_hub"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
TRACER = "dbgpt_tracer"

View File

@@ -253,6 +253,11 @@ class Config(metaclass=Singleton):
### Temporary configuration
self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
self.MODEL_CACHE_STORAGE: str = os.getenv("MODEL_CACHE_STORAGE")
self.MODEL_CACHE_STORAGE_DIST_DIR: str = os.getenv(
"MODEL_CACHE_STORAGE_DIST_DIR"
)
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""
self.debug_mode = value
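The two cache settings introduced above can be provided through the project's .env file; only the variable names come from this diff, the values below are illustrative:

    MODEL_CACHE_STORAGE=disk
    MODEL_CACHE_STORAGE_DIST_DIR=./data/model_cache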

View File

@@ -14,6 +14,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
current_directory = os.getcwd()

View File

119
pilot/model/operator/model_operator.py Normal file
View File

@@ -0,0 +1,119 @@
from typing import AsyncIterator, Dict, Optional
import logging
from pilot.awel import (
StreamifyAbsOperator,
MapOperator,
TransformStreamAbsOperator,
)
from pilot.model.base import ModelOutput
from pilot.model.cluster import WorkerManager
from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
logger = logging.getLogger(__name__)
_LLM_MODEL_INPUT_VALUE_KEY = "llm_model_input_value"
_LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"
class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
super().__init__(**kwargs)
self.worker_manager = worker_manager
async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
_LLM_MODEL_OUTPUT_CACHE_KEY
)
logger.info(f"llm_cache_value: {llm_cache_value}")
if llm_cache_value:
for out in llm_cache_value.get_value().output:
yield out
return
async for out in self.worker_manager.generate_stream(input_value):
yield out
class ModelOperator(MapOperator[Dict, ModelOutput]):
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
self.worker_manager = worker_manager
super().__init__(**kwargs)
async def map(self, input_value: Dict) -> ModelOutput:
llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
_LLM_MODEL_OUTPUT_CACHE_KEY
)
logger.info(f"llm_cache_value: {llm_cache_value}")
if llm_cache_value:
return llm_cache_value.get_value().output
return await self.worker_manager.generate(input_value)
class ModelCachePreOperator(MapOperator[Dict, Dict]):
def __init__(self, cache_manager: CacheManager, **kwargs):
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def map(self, input_value: Dict) -> Dict:
cache_dict = {
"prompt": input_value.get("prompt"),
"model_name": input_value.get("model"),
"temperature": input_value.get("temperature"),
"max_new_tokens": input_value.get("max_new_tokens"),
"top_p": input_value.get("top_p", "1.0"),
# TODO pass model_type
"model_type": input_value.get("model_type", "huggingface"),
}
cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
cache_value = await self._client.get(cache_key)
logger.debug(
f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
)
await self.current_dag_context.save_to_share_data(
_LLM_MODEL_INPUT_VALUE_KEY, cache_key
)
if cache_value:
logger.info(f"The model output has cached, cache_value: {cache_value}")
await self.current_dag_context.save_to_share_data(
_LLM_MODEL_OUTPUT_CACHE_KEY, cache_value
)
return input_value
class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutput]):
def __init__(self, cache_manager: CacheManager, **kwargs):
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> AsyncIterator[ModelOutput]:
llm_cache_key: Optional[LLMCacheKey] = None
outputs = []
async for out in input_value:
if not llm_cache_key:
llm_cache_key = await self.current_dag_context.get_share_data(
_LLM_MODEL_INPUT_VALUE_KEY
)
outputs.append(out)
yield out
if llm_cache_key:
llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
await self._client.set(llm_cache_key, llm_cache_value)
class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
def __init__(self, cache_manager: CacheManager, **kwargs):
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def map(self, input_value: ModelOutput) -> ModelOutput:
llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
_LLM_MODEL_INPUT_VALUE_KEY
)
llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
if llm_cache_key:
await self._client.set(llm_cache_key, llm_cache_value)
return input_value
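A minimal wiring sketch for the non-stream path; it mirrors _build_model_operator in base_chat.py further below, and cache_manager / worker_manager are assumed to come from the running SystemApp:

    from pilot.awel import DAG, InputOperator, SimpleCallDataInputSource

    with DAG("llm_model_cache_dag"):
        input_node = InputOperator(SimpleCallDataInputSource())
        cache_check_node = ModelCachePreOperator(cache_manager)
        model_node = ModelOperator(worker_manager)
        cache_node = ModelCacheOperator(cache_manager)
        input_node >> cache_check_node >> model_node >> cache_node

    # On a cache hit, ModelCachePreOperator puts the cached value into the DAG
    # share data and ModelOperator returns it without calling the worker manager.
    # model_output = await cache_node.call(call_data={"data": payload})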

View File

@@ -16,6 +16,8 @@ from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
from pilot.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG
from pilot.model.operator.model_operator import ModelOperator, ModelStreamOperator
logger = logging.getLogger(__name__)
headers = {"User-Agent": "dbgpt Client"}
@@ -88,6 +90,11 @@ class BaseChat(ABC):
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
self._model_operator: BaseOperator = _build_model_operator()
self._model_stream_operator: BaseOperator = _build_model_operator(
is_stream=True, dag_name="llm_stream_model_dag"
)
class Config:
"""Configuration for this pydantic object."""
@@ -204,12 +211,9 @@ class BaseChat(ABC):
)
payload["span_id"] = span.span_id
try:
from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
async for output in worker_manager.generate_stream(payload):
async for output in await self._model_stream_operator.call_stream(
call_data={"data": payload}
):
### Plug-in research in result generation
msg = self.prompt_template.output_parser.parse_model_stream_resp_ex(
output, self.skip_echo_len
@@ -240,14 +244,10 @@ class BaseChat(ABC):
)
payload["span_id"] = span.span_id
try:
from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
model_output = await worker_manager.generate(payload)
model_output = await self._model_operator.call(
call_data={"data": payload}
)
### output parse
ai_response_text = (
@@ -307,14 +307,7 @@ class BaseChat(ABC):
logger.info(f"Request: \n{payload}")
ai_response_text = ""
try:
from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
model_output = await worker_manager.generate(payload)
model_output = await self._model_operator.call(call_data={"data": payload})
### output parse
ai_response_text = (
self.prompt_template.output_parser.parse_model_nostream_resp(
@@ -568,3 +561,34 @@ class BaseChat(ABC):
)
else:
return prompt_define_response
def _build_model_operator(
is_stream: bool = False, dag_name: str = "llm_model_dag"
) -> BaseOperator:
from pilot.model.cluster import WorkerManagerFactory
from pilot.model.operator.model_operator import (
ModelCacheOperator,
ModelStreamCacheOperator,
ModelCachePreOperator,
)
from pilot.cache import CacheManager
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CACHE_MANAGER, CacheManager
)
with DAG(dag_name):
input_node = InputOperator(SimpleCallDataInputSource())
cache_check_node = ModelCachePreOperator(cache_manager)
if is_stream:
model_node = ModelStreamOperator(worker_manager)
cache_node = ModelStreamCacheOperator(cache_manager)
else:
model_node = ModelOperator(worker_manager)
cache_node = ModelCacheOperator(cache_manager)
input_node >> cache_check_node >> model_node >> cache_node
return cache_node
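The returned node is the tail cache operator, so a single call drives the whole DAG: the pre-operator checks the cache, the model operator either replays the cached output or calls the worker manager, and the cache operator stores fresh output. Invocation mirrors the chat methods above:

    # Non-stream path
    model_output = await self._model_operator.call(call_data={"data": payload})

    # Stream path
    async for output in await self._model_stream_operator.call_stream(
        call_data={"data": payload}
    ):
        ...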

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Type
import os
from pilot.component import ComponentType, SystemApp
from pilot.configs.model_config import MODEL_DISK_CACHE_DIR
from pilot.utils.executor_utils import DefaultExecutorFactory
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters
@@ -41,6 +42,10 @@ def initialize_components(
param, system_app, embedding_model_name, embedding_model_path
)
from pilot.cache import initialize_cache
initialize_cache(system_app, MODEL_DISK_CACHE_DIR)
def _initialize_embedding_model(
param: WebWerverParameters,

11
pilot/utils/memory_utils.py Normal file
View File

@@ -0,0 +1,11 @@
from typing import Any
from pympler import asizeof
def _get_object_bytes(obj: Any) -> int:
"""Get the bytes of a object in memory
Args:
obj (Any): The object to return the bytes
"""
return asizeof.asizeof(obj)
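A quick sanity check of the helper; exact numbers are platform dependent because pympler measures the full reachable object graph:

    from pilot.utils.memory_utils import _get_object_bytes

    print(_get_object_bytes(b"abc"))                     # includes object overhead, so more than 3
    print(_get_object_bytes({"k": list(range(1000))}))   # the nested list is counted as well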

View File

@@ -319,6 +319,8 @@ def core_requires():
"alembic==1.12.0",
# for excel
"openpyxl",
# for cache. TODO: pympler has not been updated for a long time; we need to find a replacement toolkit.
"pympler",
]
@@ -410,6 +412,13 @@ def vllm_requires():
setup_spec.extras["vllm"] = ["vllm"]
def cache_requires():
"""
pip install "db-gpt[cache]"
"""
setup_spec.extras["cache"] = ["rocksdict", "msgpack"]
# def chat_scene():
# setup_spec.extras["chat"] = [
# ""
@@ -460,6 +469,7 @@ all_datasource_requires()
openai_requires()
gpt4all_requires()
vllm_requires()
cache_requires()
# must be last
default_requires()