mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 11:31:12 +00:00
feat(model): Support model cache and first version of Agentic Workflow Expression Language(AWEL)
This commit is contained in:
57
pilot/awel/__init__.py
Normal file
57
pilot/awel/__init__.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Agentic Workflow Expression Language (AWEL)"""
|
||||
|
||||
from .dag.base import DAGContext, DAG
|
||||
|
||||
from .operator.base import BaseOperator, WorkflowRunner, initialize_awel
|
||||
from .operator.common_operator import (
|
||||
JoinOperator,
|
||||
ReduceStreamOperator,
|
||||
MapOperator,
|
||||
BranchOperator,
|
||||
InputOperator,
|
||||
)
|
||||
|
||||
from .operator.stream_operator import (
|
||||
StreamifyAbsOperator,
|
||||
UnstreamifyAbsOperator,
|
||||
TransformStreamAbsOperator,
|
||||
)
|
||||
|
||||
from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
|
||||
from .task.task_impl import (
|
||||
SimpleInputSource,
|
||||
SimpleCallDataInputSource,
|
||||
DefaultTaskContext,
|
||||
DefaultInputContext,
|
||||
SimpleTaskOutput,
|
||||
SimpleStreamTaskOutput,
|
||||
)
|
||||
from .runner.local_runner import DefaultWorkflowRunner
|
||||
|
||||
__all__ = [
|
||||
"initialize_awel",
|
||||
"DAGContext",
|
||||
"DAG",
|
||||
"BaseOperator",
|
||||
"JoinOperator",
|
||||
"ReduceStreamOperator",
|
||||
"MapOperator",
|
||||
"BranchOperator",
|
||||
"InputOperator",
|
||||
"WorkflowRunner",
|
||||
"TaskState",
|
||||
"TaskOutput",
|
||||
"TaskContext",
|
||||
"InputContext",
|
||||
"InputSource",
|
||||
"DefaultWorkflowRunner",
|
||||
"SimpleInputSource",
|
||||
"SimpleCallDataInputSource",
|
||||
"DefaultTaskContext",
|
||||
"DefaultInputContext",
|
||||
"SimpleTaskOutput",
|
||||
"SimpleStreamTaskOutput",
|
||||
"StreamifyAbsOperator",
|
||||
"UnstreamifyAbsOperator",
|
||||
"TransformStreamAbsOperator",
|
||||
]
|
0
pilot/awel/dag/__init__.py
Normal file
0
pilot/awel/dag/__init__.py
Normal file
252
pilot/awel/dag/base.py
Normal file
252
pilot/awel/dag/base.py
Normal file
@@ -0,0 +1,252 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, List, Sequence, Union, Any
|
||||
import uuid
|
||||
import contextvars
|
||||
import threading
|
||||
import asyncio
|
||||
from collections import deque
|
||||
|
||||
from ..resource.base import ResourceGroup
|
||||
from ..task.base import TaskContext
|
||||
|
||||
DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
|
||||
|
||||
|
||||
def _is_async_context():
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
return asyncio.current_task(loop=loop) is not None
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
|
||||
class DependencyMixin(ABC):
|
||||
@abstractmethod
|
||||
def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
|
||||
"""Set one or more upstream nodes for this node.
|
||||
|
||||
Args:
|
||||
nodes (DependencyType): Upstream nodes to be set to current node.
|
||||
|
||||
Returns:
|
||||
DependencyMixin: Returns self to allow method chaining.
|
||||
|
||||
Raises:
|
||||
ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
|
||||
"""Set one or more downstream nodes for this node.
|
||||
|
||||
Args:
|
||||
nodes (DependencyType): Downstream nodes to be set to current node.
|
||||
|
||||
Returns:
|
||||
DependencyMixin: Returns self to allow method chaining.
|
||||
|
||||
Raises:
|
||||
ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
|
||||
"""
|
||||
|
||||
def __lshift__(self, nodes: DependencyType) -> DependencyType:
|
||||
"""Implements self << nodes
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# means node.set_upstream(input_node)
|
||||
node << input_node
|
||||
|
||||
# means node2.set_upstream([input_node])
|
||||
node2 << [input_node]
|
||||
"""
|
||||
self.set_upstream(nodes)
|
||||
return nodes
|
||||
|
||||
def __rshift__(self, nodes: DependencyType) -> DependencyType:
|
||||
"""Implements self >> nodes
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# means node.set_downstream(next_node)
|
||||
node >> next_node
|
||||
|
||||
# means node2.set_downstream([next_node])
|
||||
node2 >> [next_node]
|
||||
|
||||
"""
|
||||
self.set_downstream(nodes)
|
||||
return nodes
|
||||
|
||||
def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
|
||||
"""Implements [node] >> self"""
|
||||
self.__lshift__(nodes)
|
||||
return self
|
||||
|
||||
def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
|
||||
"""Implements [node] << self"""
|
||||
self.__rshift__(nodes)
|
||||
return self
|
||||
|
||||
|
||||
class DAGVar:
|
||||
_thread_local = threading.local()
|
||||
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
|
||||
|
||||
@classmethod
|
||||
def enter_dag(cls, dag) -> None:
|
||||
is_async = _is_async_context()
|
||||
if is_async:
|
||||
stack = cls._async_local.get()
|
||||
stack.append(dag)
|
||||
cls._async_local.set(stack)
|
||||
else:
|
||||
if not hasattr(cls._thread_local, "current_dag_stack"):
|
||||
cls._thread_local.current_dag_stack = deque()
|
||||
cls._thread_local.current_dag_stack.append(dag)
|
||||
|
||||
@classmethod
|
||||
def exit_dag(cls) -> None:
|
||||
is_async = _is_async_context()
|
||||
if is_async:
|
||||
stack = cls._async_local.get()
|
||||
if stack:
|
||||
stack.pop()
|
||||
cls._async_local.set(stack)
|
||||
else:
|
||||
if (
|
||||
hasattr(cls._thread_local, "current_dag_stack")
|
||||
and cls._thread_local.current_dag_stack
|
||||
):
|
||||
cls._thread_local.current_dag_stack.pop()
|
||||
|
||||
@classmethod
|
||||
def get_current_dag(cls) -> Optional["DAG"]:
|
||||
is_async = _is_async_context()
|
||||
if is_async:
|
||||
stack = cls._async_local.get()
|
||||
return stack[-1] if stack else None
|
||||
else:
|
||||
if (
|
||||
hasattr(cls._thread_local, "current_dag_stack")
|
||||
and cls._thread_local.current_dag_stack
|
||||
):
|
||||
return cls._thread_local.current_dag_stack[-1]
|
||||
return None
|
||||
|
||||
|
||||
class DAGNode(DependencyMixin, ABC):
|
||||
resource_group: Optional[ResourceGroup] = None
|
||||
"""The resource group of current DAGNode"""
|
||||
|
||||
def __init__(self, dag: Optional["DAG"] = None, node_id: str = None) -> None:
|
||||
super().__init__()
|
||||
self._upstream: List["DAGNode"] = []
|
||||
self._downstream: List["DAGNode"] = []
|
||||
self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
|
||||
if not node_id and self._dag:
|
||||
node_id = self._dag._new_node_id()
|
||||
self._node_id: str = node_id
|
||||
|
||||
@property
|
||||
def node_id(self) -> str:
|
||||
return self._node_id
|
||||
|
||||
def set_node_id(self, node_id: str) -> None:
|
||||
self._node_id = node_id
|
||||
|
||||
@property
|
||||
def dag(self) -> "DAGNode":
|
||||
return self._dag
|
||||
|
||||
def set_upstream(self, nodes: DependencyType) -> "DAGNode":
|
||||
self.set_dependency(nodes)
|
||||
|
||||
def set_downstream(self, nodes: DependencyType) -> "DAGNode":
|
||||
self.set_dependency(nodes, is_upstream=False)
|
||||
|
||||
@property
|
||||
def upstream(self) -> List["DAGNode"]:
|
||||
return self._upstream
|
||||
|
||||
@property
|
||||
def downstream(self) -> List["DAGNode"]:
|
||||
return self._downstream
|
||||
|
||||
def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
|
||||
if not isinstance(nodes, Sequence):
|
||||
nodes = [nodes]
|
||||
if not all(isinstance(node, DAGNode) for node in nodes):
|
||||
raise ValueError(
|
||||
"all nodes to set dependency to current node must be instance of 'DAGNode'"
|
||||
)
|
||||
nodes: Sequence[DAGNode] = nodes
|
||||
dags = set([node.dag for node in nodes if node.dag])
|
||||
if self.dag:
|
||||
dags.add(self.dag)
|
||||
if not dags:
|
||||
raise ValueError("set dependency to current node must in a DAG context")
|
||||
if len(dags) != 1:
|
||||
raise ValueError(
|
||||
"set dependency to current node just support in one DAG context"
|
||||
)
|
||||
dag = dags.pop()
|
||||
self._dag = dag
|
||||
|
||||
dag._append_node(self)
|
||||
for node in nodes:
|
||||
if is_upstream and node not in self.upstream:
|
||||
node._dag = dag
|
||||
dag._append_node(node)
|
||||
|
||||
self._upstream.append(node)
|
||||
node._downstream.append(self)
|
||||
elif node not in self._downstream:
|
||||
node._dag = dag
|
||||
dag._append_node(node)
|
||||
|
||||
self._downstream.append(node)
|
||||
node._upstream.append(self)
|
||||
|
||||
|
||||
class DAGContext:
|
||||
def __init__(self) -> None:
|
||||
self._curr_task_ctx = None
|
||||
self._share_data: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def current_task_context(self) -> TaskContext:
|
||||
return self._curr_task_ctx
|
||||
|
||||
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
|
||||
self._curr_task_ctx = _curr_task_ctx
|
||||
|
||||
async def get_share_data(self, key: str) -> Any:
|
||||
return self._share_data.get(key)
|
||||
|
||||
async def save_to_share_data(self, key: str, data: Any) -> None:
|
||||
self._share_data[key] = data
|
||||
|
||||
|
||||
class DAG:
|
||||
def __init__(
|
||||
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
|
||||
) -> None:
|
||||
self.node_map: Dict[str, DAGNode] = {}
|
||||
|
||||
def _append_node(self, node: DAGNode) -> None:
|
||||
self.node_map[node.node_id] = node
|
||||
|
||||
def _new_node_id(self) -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def __enter__(self):
|
||||
DAGVar.enter_dag(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
DAGVar.exit_dag()
|
0
pilot/awel/dag/tests/__init__.py
Normal file
0
pilot/awel/dag/tests/__init__.py
Normal file
51
pilot/awel/dag/tests/test_dag.py
Normal file
51
pilot/awel/dag/tests/test_dag.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import pytest
|
||||
import threading
|
||||
import asyncio
|
||||
from ..dag import DAG, DAGContext
|
||||
|
||||
|
||||
def test_dag_context_sync():
|
||||
dag1 = DAG("dag1")
|
||||
dag2 = DAG("dag2")
|
||||
|
||||
with dag1:
|
||||
assert DAGContext.get_current_dag() == dag1
|
||||
with dag2:
|
||||
assert DAGContext.get_current_dag() == dag2
|
||||
assert DAGContext.get_current_dag() == dag1
|
||||
assert DAGContext.get_current_dag() is None
|
||||
|
||||
|
||||
def test_dag_context_threading():
|
||||
def thread_function(dag):
|
||||
DAGContext.enter_dag(dag)
|
||||
assert DAGContext.get_current_dag() == dag
|
||||
DAGContext.exit_dag()
|
||||
|
||||
dag1 = DAG("dag1")
|
||||
dag2 = DAG("dag2")
|
||||
|
||||
thread1 = threading.Thread(target=thread_function, args=(dag1,))
|
||||
thread2 = threading.Thread(target=thread_function, args=(dag2,))
|
||||
|
||||
thread1.start()
|
||||
thread2.start()
|
||||
thread1.join()
|
||||
thread2.join()
|
||||
|
||||
assert DAGContext.get_current_dag() is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_context_async():
|
||||
async def async_function(dag):
|
||||
DAGContext.enter_dag(dag)
|
||||
assert DAGContext.get_current_dag() == dag
|
||||
DAGContext.exit_dag()
|
||||
|
||||
dag1 = DAG("dag1")
|
||||
dag2 = DAG("dag2")
|
||||
|
||||
await asyncio.gather(async_function(dag1), async_function(dag2))
|
||||
|
||||
assert DAGContext.get_current_dag() is None
|
0
pilot/awel/operator/__init__.py
Normal file
0
pilot/awel/operator/__init__.py
Normal file
176
pilot/awel/operator/base.py
Normal file
176
pilot/awel/operator/base.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from abc import ABC, abstractmethod, ABCMeta
|
||||
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
List,
|
||||
Generic,
|
||||
TypeVar,
|
||||
AsyncIterator,
|
||||
Union,
|
||||
Any,
|
||||
Dict,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
import functools
|
||||
from inspect import signature
|
||||
|
||||
from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
|
||||
from ..task.base import (
|
||||
TaskContext,
|
||||
TaskOutput,
|
||||
TaskState,
|
||||
OUT,
|
||||
T,
|
||||
InputContext,
|
||||
InputSource,
|
||||
)
|
||||
|
||||
F = TypeVar("F", bound=FunctionType)
|
||||
|
||||
CALL_DATA = Union[Dict, Dict[str, Dict]]
|
||||
|
||||
|
||||
class WorkflowRunner(ABC, Generic[T]):
|
||||
"""Abstract base class representing a runner for executing workflows in a DAG.
|
||||
|
||||
This class defines the interface for executing workflows within the DAG,
|
||||
handling the flow from one DAG node to another.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def execute_workflow(
|
||||
self, node: "BaseOperator", call_data: Optional[CALL_DATA] = None
|
||||
) -> DAGContext:
|
||||
"""Execute the workflow starting from a given operator.
|
||||
|
||||
Args:
|
||||
node (RunnableDAGNode): The starting node of the workflow to be executed.
|
||||
call_data (CALL_DATA): The data pass to root operator node.
|
||||
|
||||
Returns:
|
||||
DAGContext: The context after executing the workflow, containing the final state and data.
|
||||
"""
|
||||
|
||||
|
||||
default_runner: WorkflowRunner = None
|
||||
|
||||
|
||||
class BaseOperatorMeta(ABCMeta):
|
||||
"""Metaclass of BaseOperator."""
|
||||
|
||||
@classmethod
|
||||
def _apply_defaults(cls, func: F) -> F:
|
||||
sig_cache = signature(func)
|
||||
|
||||
@functools.wraps(func)
|
||||
def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
|
||||
dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
|
||||
task_id: Optional[str] = kwargs.get("task_id")
|
||||
if not task_id and dag:
|
||||
task_id = dag._new_node_id()
|
||||
runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
|
||||
# print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}")
|
||||
# for arg in sig_cache.parameters:
|
||||
# if arg not in kwargs:
|
||||
# kwargs[arg] = default_args[arg]
|
||||
if not kwargs.get("dag"):
|
||||
kwargs["dag"] = dag
|
||||
if not kwargs.get("task_id"):
|
||||
kwargs["task_id"] = task_id
|
||||
if not kwargs.get("runner"):
|
||||
kwargs["runner"] = runner
|
||||
real_obj = func(self, *args, **kwargs)
|
||||
return real_obj
|
||||
|
||||
return cast(T, apply_defaults)
|
||||
|
||||
def __new__(cls, name, bases, namespace, **kwargs):
|
||||
new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
|
||||
new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
|
||||
return new_cls
|
||||
|
||||
|
||||
class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
"""Abstract base class for operator nodes that can be executed within a workflow.
|
||||
|
||||
This class extends DAGNode by adding execution capabilities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_id: Optional[str] = None,
|
||||
dag: Optional[DAG] = None,
|
||||
runner: WorkflowRunner = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Initializes a BaseOperator with an optional workflow runner.
|
||||
|
||||
Args:
|
||||
runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
|
||||
"""
|
||||
super().__init__(node_id=task_id, dag=dag, **kwargs)
|
||||
if not runner:
|
||||
from pilot.awel import DefaultWorkflowRunner
|
||||
|
||||
runner = DefaultWorkflowRunner()
|
||||
|
||||
self._runner: WorkflowRunner = runner
|
||||
self._dag_ctx: DAGContext = None
|
||||
|
||||
@property
|
||||
def current_dag_context(self) -> DAGContext:
|
||||
return self._dag_ctx
|
||||
|
||||
async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
if not self.node_id:
|
||||
raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
|
||||
self._dag_ctx = dag_ctx
|
||||
return await self._do_run(dag_ctx)
|
||||
|
||||
@abstractmethod
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
"""
|
||||
Abstract method to run the task within the DAG node.
|
||||
|
||||
Args:
|
||||
dag_ctx (DAGContext): The context of the DAG when this node is run.
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The task output after this node has been run.
|
||||
"""
|
||||
|
||||
async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
|
||||
"""Execute the node and return the output.
|
||||
|
||||
This method is a high-level wrapper for executing the node.
|
||||
|
||||
Args:
|
||||
call_data (CALL_DATA): The data pass to root operator node.
|
||||
|
||||
Returns:
|
||||
OUT: The output of the node after execution.
|
||||
"""
|
||||
out_ctx = await self._runner.execute_workflow(self, call_data)
|
||||
return out_ctx.current_task_context.task_output.output
|
||||
|
||||
async def call_stream(
|
||||
self, call_data: Optional[CALL_DATA] = None
|
||||
) -> AsyncIterator[OUT]:
|
||||
"""Execute the node and return the output as a stream.
|
||||
|
||||
This method is used for nodes where the output is a stream.
|
||||
|
||||
Args:
|
||||
call_data (CALL_DATA): The data pass to root operator node.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
|
||||
"""
|
||||
out_ctx = await self._runner.execute_workflow(self, call_data)
|
||||
return out_ctx.current_task_context.task_output.output_stream
|
||||
|
||||
|
||||
def initialize_awel(runner: WorkflowRunner):
|
||||
global default_runner
|
||||
default_runner = runner
|
221
pilot/awel/operator/common_operator.py
Normal file
221
pilot/awel/operator/common_operator.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
|
||||
import asyncio
|
||||
|
||||
from ..dag.base import DAGContext
|
||||
from ..task.base import (
|
||||
TaskContext,
|
||||
TaskOutput,
|
||||
IN,
|
||||
OUT,
|
||||
InputContext,
|
||||
InputSource,
|
||||
)
|
||||
|
||||
from .base import BaseOperator
|
||||
|
||||
|
||||
class JoinOperator(BaseOperator, Generic[OUT]):
|
||||
"""Operator that joins inputs using a custom combine function.
|
||||
|
||||
This node type is useful for combining the outputs of upstream nodes.
|
||||
"""
|
||||
|
||||
def __init__(self, combine_function, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not callable(combine_function):
|
||||
raise ValueError("combine_function must be callable")
|
||||
self.combine_function = combine_function
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
"""Run the join operation on the DAG context's inputs.
|
||||
Args:
|
||||
dag_ctx (DAGContext): The current context of the DAG.
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The task output after this node has been run.
|
||||
"""
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
input_ctx: InputContext = await curr_task_ctx.task_input.map_all(
|
||||
self.combine_function
|
||||
)
|
||||
# All join result store in the first parent output
|
||||
join_output = input_ctx.parent_outputs[0].task_output
|
||||
curr_task_ctx.set_task_output(join_output)
|
||||
return join_output
|
||||
|
||||
|
||||
class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
|
||||
def __init__(self, reduce_function=None, **kwargs):
|
||||
"""Initializes a ReduceStreamOperator with a combine function.
|
||||
|
||||
Args:
|
||||
combine_function: A function that defines how to combine inputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the combine_function is not callable.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
if reduce_function and not callable(reduce_function):
|
||||
raise ValueError("reduce_function must be callable")
|
||||
self.reduce_function = reduce_function
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
"""Run the join operation on the DAG context's inputs.
|
||||
|
||||
Args:
|
||||
dag_ctx (DAGContext): The current context of the DAG.
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The task output after this node has been run.
|
||||
"""
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
task_input = curr_task_ctx.task_input
|
||||
if not task_input.check_stream():
|
||||
raise ValueError("ReduceStreamOperator expects stream data")
|
||||
if not task_input.check_single_parent():
|
||||
raise ValueError("ReduceStreamOperator expects single parent")
|
||||
|
||||
reduce_function = self.reduce_function or self.reduce
|
||||
|
||||
input_ctx: InputContext = await task_input.reduce(reduce_function)
|
||||
# All join result store in the first parent output
|
||||
reduce_output = input_ctx.parent_outputs[0].task_output
|
||||
curr_task_ctx.set_task_output(reduce_output)
|
||||
return reduce_output
|
||||
|
||||
async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MapOperator(BaseOperator, Generic[IN, OUT]):
|
||||
"""Map operator that applies a mapping function to its inputs.
|
||||
|
||||
This operator transforms its input data using a provided mapping function and
|
||||
passes the transformed data downstream.
|
||||
"""
|
||||
|
||||
def __init__(self, map_function=None, **kwargs):
|
||||
"""Initializes a MapDAGNode with a mapping function.
|
||||
|
||||
Args:
|
||||
map_function: A function that defines how to map the input data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the map_function is not callable.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
if map_function and not callable(map_function):
|
||||
raise ValueError("map_function must be callable")
|
||||
self.map_function = map_function
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
"""Run the mapping operation on the DAG context's inputs.
|
||||
|
||||
This method applies the mapping function to the input context and updates
|
||||
the DAG context with the new data.
|
||||
|
||||
Args:
|
||||
dag_ctx (DAGContext[IN]): The current context of the DAG.
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The task output after this node has been run.
|
||||
|
||||
Raises:
|
||||
ValueError: If not a single parent or the map_function is not callable
|
||||
"""
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
if not curr_task_ctx.task_input.check_single_parent():
|
||||
num_parents = len(curr_task_ctx.task_input.parent_outputs)
|
||||
raise ValueError(
|
||||
f"task {curr_task_ctx.task_id} MapDAGNode expects single parent, now number of parents: {num_parents}"
|
||||
)
|
||||
map_function = self.map_function or self.map
|
||||
|
||||
input_ctx: InputContext = await curr_task_ctx.task_input.map(map_function)
|
||||
# All join result store in the first parent output
|
||||
reduce_output = input_ctx.parent_outputs[0].task_output
|
||||
curr_task_ctx.set_task_output(reduce_output)
|
||||
return reduce_output
|
||||
|
||||
async def map(self, input_value: IN) -> OUT:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
BranchFunc = Union[Callable[[Any], bool], Callable[[Any], Awaitable[bool]]]
|
||||
|
||||
|
||||
class BranchOperator(BaseOperator, Generic[OUT]):
|
||||
"""Operator node that branches the workflow based on a provided function.
|
||||
|
||||
This node filters its input data using a branching function and
|
||||
allows for conditional paths in the workflow.
|
||||
"""
|
||||
|
||||
def __init__(self, branches: Dict[BranchFunc, BaseOperator], **kwargs):
|
||||
"""
|
||||
Initializes a BranchDAGNode with a branching function.
|
||||
|
||||
Args:
|
||||
branches (Dict[BranchFunc, RunnableDAGNode]): Dict of function that defines the branching condition.
|
||||
|
||||
Raises:
|
||||
ValueError: If the branch_function is not callable.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
for branch_function in branches.keys():
|
||||
if not callable(branch_function):
|
||||
raise ValueError("branch_function must be callable")
|
||||
self.branches = branches
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
"""Run the branching operation on the DAG context's inputs.
|
||||
|
||||
This method applies the branching function to the input context to determine
|
||||
the path of execution in the workflow.
|
||||
|
||||
Args:
|
||||
dag_ctx (DAGContext[IN]): The current context of the DAG.
|
||||
|
||||
Returns:
|
||||
TaskOutput[OUT]: The task output after this node has been run.
|
||||
"""
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
task_input = curr_task_ctx.task_input
|
||||
if task_input.check_stream():
|
||||
raise ValueError("BranchDAGNode expects no stream data")
|
||||
if not task_input.check_single_parent():
|
||||
raise ValueError("BranchDAGNode expects single parent")
|
||||
|
||||
branch_func_tasks = []
|
||||
branch_nodes: List[BaseOperator] = []
|
||||
for func, node in self.branches.items():
|
||||
branch_nodes.append(node)
|
||||
branch_func_tasks.append(
|
||||
curr_task_ctx.task_input.predicate_map(func, failed_value=None)
|
||||
)
|
||||
branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
|
||||
parent_output = task_input.parent_outputs[0].task_output
|
||||
curr_task_ctx.set_task_output(parent_output)
|
||||
|
||||
for i, ctx in enumerate(branch_input_ctxs):
|
||||
node = branch_nodes[i]
|
||||
if ctx.parent_outputs[0].task_output.is_empty:
|
||||
# Skip current node
|
||||
# node.current_task_context.set_current_state(TaskState.SKIP)
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
raise NotImplementedError
|
||||
return None
|
||||
|
||||
|
||||
class InputOperator(BaseOperator, Generic[OUT]):
|
||||
def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._input_source = input_source
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
task_output = await self._input_source.read(curr_task_ctx)
|
||||
curr_task_ctx.set_task_output(task_output)
|
||||
return task_output
|
90
pilot/awel/operator/stream_operator.py
Normal file
90
pilot/awel/operator/stream_operator.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, AsyncIterator
|
||||
from ..task.base import OUT, IN, TaskOutput, TaskContext
|
||||
from ..dag.base import DAGContext
|
||||
from .base import BaseOperator
|
||||
|
||||
|
||||
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
|
||||
self.streamify
|
||||
)
|
||||
curr_task_ctx.set_task_output(output)
|
||||
return output
|
||||
|
||||
@abstractmethod
|
||||
async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
|
||||
"""Convert a value of IN to an AsyncIterator[OUT]
|
||||
|
||||
Args:
|
||||
input_value (IN): The data of parent operator's output
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyStreamOperator(StreamifyAbsOperator[int, int]):
|
||||
async def streamify(self, input_value: int) -> AsyncIterator[int]
|
||||
for i in range(input_value):
|
||||
yield i
|
||||
"""
|
||||
|
||||
|
||||
class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
output = await curr_task_ctx.task_input.parent_outputs[
|
||||
0
|
||||
].task_output.unstreamify(self.unstreamify)
|
||||
curr_task_ctx.set_task_output(output)
|
||||
return output
|
||||
|
||||
@abstractmethod
|
||||
async def unstreamify(self, input_value: AsyncIterator[IN]) -> OUT:
|
||||
"""Convert a value of AsyncIterator[IN] to an OUT.
|
||||
|
||||
Args:
|
||||
input_value (AsyncIterator[IN])): The data of parent operator's output
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
|
||||
async def unstreamify(self, input_value: AsyncIterator[int]) -> int
|
||||
value_cnt = 0
|
||||
async for v in input_value:
|
||||
value_cnt += 1
|
||||
return value_cnt
|
||||
"""
|
||||
|
||||
|
||||
class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
|
||||
output = await curr_task_ctx.task_input.parent_outputs[
|
||||
0
|
||||
].task_output.transform_stream(self.transform_stream)
|
||||
curr_task_ctx.set_task_output(output)
|
||||
return output
|
||||
|
||||
@abstractmethod
|
||||
async def transform_stream(
|
||||
self, input_value: AsyncIterator[IN]
|
||||
) -> AsyncIterator[OUT]:
|
||||
"""Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.
|
||||
|
||||
Args:
|
||||
input_value (AsyncIterator[IN])): The data of parent operator's output
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
|
||||
async def unstreamify(self, input_value: AsyncIterator[int]) -> AsyncIterator[int]
|
||||
async for v in input_value:
|
||||
yield v + 1
|
||||
"""
|
0
pilot/awel/resource/__init__.py
Normal file
0
pilot/awel/resource/__init__.py
Normal file
8
pilot/awel/resource/base.py
Normal file
8
pilot/awel/resource/base.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ResourceGroup(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""The name of current resource group"""
|
0
pilot/awel/runner/__init__.py
Normal file
0
pilot/awel/runner/__init__.py
Normal file
74
pilot/awel/runner/job_manager.py
Normal file
74
pilot/awel/runner/job_manager.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import List, Set, Optional, Dict
|
||||
import uuid
|
||||
from ..dag.base import DAG
|
||||
|
||||
from ..operator.base import BaseOperator, CALL_DATA
|
||||
|
||||
|
||||
class DAGNodeInstance:
|
||||
def __init__(self, node_instance: DAG) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DAGInstance:
|
||||
def __init__(self, dag: DAG) -> None:
|
||||
self._dag = dag
|
||||
|
||||
|
||||
class JobManager:
|
||||
def __init__(
|
||||
self,
|
||||
root_nodes: List[BaseOperator],
|
||||
all_nodes: List[BaseOperator],
|
||||
end_node: BaseOperator,
|
||||
id2call_data: Dict[str, Dict],
|
||||
) -> None:
|
||||
self._root_nodes = root_nodes
|
||||
self._all_nodes = all_nodes
|
||||
self._end_node = end_node
|
||||
self._id2node_data = id2call_data
|
||||
|
||||
@staticmethod
|
||||
def build_from_end_node(
|
||||
end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
|
||||
) -> "JobManager":
|
||||
nodes = _build_from_end_node(end_node)
|
||||
root_nodes = _get_root_nodes(nodes)
|
||||
id2call_data = _save_call_data(root_nodes, call_data)
|
||||
return JobManager(root_nodes, nodes, end_node, id2call_data)
|
||||
|
||||
def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
|
||||
return self._id2node_data.get(node_id)
|
||||
|
||||
|
||||
def _save_call_data(
|
||||
root_nodes: List[BaseOperator], call_data: CALL_DATA
|
||||
) -> Dict[str, Dict]:
|
||||
id2call_data = {}
|
||||
if not call_data:
|
||||
return id2call_data
|
||||
if len(root_nodes) == 1:
|
||||
node = root_nodes[0]
|
||||
id2call_data[node.node_id] = call_data
|
||||
else:
|
||||
for node in root_nodes:
|
||||
node_id = node.node_id
|
||||
id2call_data[node_id] = call_data.get(node_id)
|
||||
return id2call_data
|
||||
|
||||
|
||||
def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
|
||||
nodes = []
|
||||
if isinstance(end_node, BaseOperator):
|
||||
task_id = end_node.node_id
|
||||
if not task_id:
|
||||
task_id = str(uuid.uuid4())
|
||||
end_node.set_node_id(task_id)
|
||||
nodes.append(end_node)
|
||||
for node in end_node.upstream:
|
||||
nodes += _build_from_end_node(node)
|
||||
return nodes
|
||||
|
||||
|
||||
def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
|
||||
return list(filter(lambda x: not x.upstream, nodes))
|
68
pilot/awel/runner/local_runner.py
Normal file
68
pilot/awel/runner/local_runner.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import Dict, Optional
|
||||
from ..dag.base import DAGContext
|
||||
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
|
||||
from ..task.base import TaskContext, TaskState
|
||||
from ..task.task_impl import DefaultInputContext, DefaultTaskContext
|
||||
from .job_manager import JobManager
|
||||
|
||||
|
||||
class DefaultWorkflowRunner(WorkflowRunner):
|
||||
async def execute_workflow(
|
||||
self, node: BaseOperator, call_data: Optional[CALL_DATA] = None
|
||||
) -> DAGContext:
|
||||
# Create DAG context
|
||||
dag_ctx = DAGContext()
|
||||
job_manager = JobManager.build_from_end_node(node, call_data)
|
||||
dag = node.dag
|
||||
# Save node output
|
||||
node_outputs: Dict[str, TaskContext] = {}
|
||||
await self._execute_node(job_manager, node, dag_ctx, node_outputs)
|
||||
|
||||
return dag_ctx
|
||||
|
||||
async def _execute_node(
|
||||
self,
|
||||
job_manager: JobManager,
|
||||
node: BaseOperator,
|
||||
dag_ctx: DAGContext,
|
||||
node_outputs: Dict[str, TaskContext],
|
||||
):
|
||||
# Skip run node
|
||||
if node.node_id in node_outputs:
|
||||
return
|
||||
|
||||
# Run all upstream node
|
||||
for upstream_node in node.upstream:
|
||||
if isinstance(upstream_node, BaseOperator):
|
||||
await self._execute_node(
|
||||
job_manager, upstream_node, dag_ctx, node_outputs
|
||||
)
|
||||
# if node.current_task_context.current_state == TaskState.SKIP:
|
||||
# return
|
||||
|
||||
# for upstream_node in node.upstream:
|
||||
# if (
|
||||
# isinstance(upstream_node, BaseOperator)
|
||||
# and upstream_node.current_task_context.current_state == TaskState.SKIP
|
||||
# ):
|
||||
# return
|
||||
# Get the input from upstream node
|
||||
inputs = [
|
||||
node_outputs[upstream_node.node_id] for upstream_node in node.upstream
|
||||
]
|
||||
input_ctx = DefaultInputContext(inputs)
|
||||
task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
|
||||
task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
|
||||
|
||||
task_ctx.set_task_input(input_ctx)
|
||||
dag_ctx.set_current_task_context(task_ctx)
|
||||
|
||||
task_ctx.set_current_state(TaskState.RUNNING)
|
||||
try:
|
||||
# print(f"Begin run {node}")
|
||||
await node._run(dag_ctx)
|
||||
node_outputs[node.node_id] = dag_ctx.current_task_context
|
||||
task_ctx.set_current_state(TaskState.SUCCESS)
|
||||
except Exception as e:
|
||||
task_ctx.set_current_state(TaskState.FAILED)
|
||||
raise e
|
0
pilot/awel/task/__init__.py
Normal file
0
pilot/awel/task/__init__.py
Normal file
358
pilot/awel/task/base.py
Normal file
358
pilot/awel/task/base.py
Normal file
@@ -0,0 +1,358 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TypeVar,
|
||||
Generic,
|
||||
Optional,
|
||||
AsyncIterator,
|
||||
Union,
|
||||
Callable,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
)
|
||||
|
||||
IN = TypeVar("IN")
|
||||
OUT = TypeVar("OUT")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
"""Enumeration representing the state of a task in the workflow.
|
||||
|
||||
This Enum defines various states a task can be in during its lifecycle in the DAG.
|
||||
"""
|
||||
|
||||
INIT = "init" # Initial state of the task, not yet started
|
||||
SKIP = "skip" # State indicating the task was skipped
|
||||
RUNNING = "running" # State indicating the task is currently running
|
||||
SUCCESS = "success" # State indicating the task completed successfully
|
||||
FAILED = "failed" # State indicating the task failed during execution
|
||||
|
||||
|
||||
class TaskOutput(ABC, Generic[T]):
|
||||
"""Abstract base class representing the output of a task.
|
||||
|
||||
This class encapsulates the output of a task and provides methods to access the output data.
|
||||
It can be subclassed to implement specific output behaviors.
|
||||
"""
|
||||
|
||||
@property
|
||||
def is_stream(self) -> bool:
|
||||
"""Check if the output is a stream.
|
||||
|
||||
Returns:
|
||||
bool: True if the output is a stream, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if the output is empty.
|
||||
|
||||
Returns:
|
||||
bool: True if the output is empty, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def output(self) -> Optional[T]:
|
||||
"""Return the output of the task.
|
||||
|
||||
Returns:
|
||||
T: The output of the task. None if the output is empty.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_stream(self) -> Optional[AsyncIterator[T]]:
|
||||
"""Return the output of the task as an asynchronous stream.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
|
||||
"""Set the output data to current object.
|
||||
|
||||
Args:
|
||||
output_data (Union[T, AsyncIterator[T]]): Output data.
|
||||
"""
|
||||
|
||||
async def map(self, map_func) -> "TaskOutput[T]":
|
||||
"""Apply a mapping function to the task's output.
|
||||
|
||||
Args:
|
||||
map_func: A function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the mapping function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reduce(self, reduce_func) -> "TaskOutput[T]":
|
||||
"""Apply a reducing function to the task's output.
|
||||
|
||||
Stream TaskOutput to Nonstream TaskOutput.
|
||||
|
||||
Args:
|
||||
reduce_func: A reducing function to apply to the task's output.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def streamify(
|
||||
self, transform_func: Callable[[T], AsyncIterator[T]]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Convert a value of type T to an AsyncIterator[T] using a transform function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def transform_stream(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def unstreamify(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], T]
|
||||
) -> "TaskOutput[T]":
|
||||
"""Convert an AsyncIterator[T] to a value of type T using a transform function.
|
||||
|
||||
Args:
|
||||
transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The result of applying the reducing function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def check_condition(self, condition_func) -> bool:
|
||||
"""Check if current output meets a given condition.
|
||||
|
||||
Args:
|
||||
condition_func: A function to determine if the condition is met.
|
||||
Returns:
|
||||
bool: True if current output meet the condition, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TaskContext(ABC, Generic[T]):
|
||||
"""Abstract base class representing the context of a task within a DAG.
|
||||
|
||||
This class provides the interface for accessing task-related information
|
||||
and manipulating task output.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def task_id(self) -> str:
|
||||
"""Return the unique identifier of the task.
|
||||
|
||||
Returns:
|
||||
str: The unique identifier of the task.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def task_input(self) -> "InputContext":
|
||||
"""Return the InputContext of current task.
|
||||
|
||||
Returns:
|
||||
InputContext: The InputContext of current task.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_task_input(self, input_ctx: "InputContext") -> None:
|
||||
"""Set the InputContext object to current task.
|
||||
|
||||
Args:
|
||||
input_ctx (InputContext): The InputContext of current task
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def task_output(self) -> TaskOutput[T]:
|
||||
"""Return the output object of the task.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The output object of the task.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_task_output(self, task_output: TaskOutput[T]) -> None:
|
||||
"""Set the output object to current task."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def current_state(self) -> TaskState:
|
||||
"""Get the current state of the task.
|
||||
|
||||
Returns:
|
||||
TaskState: The current state of the task.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_current_state(self, task_state: TaskState) -> None:
|
||||
"""Set current task state
|
||||
|
||||
Args:
|
||||
task_state (TaskState): The task state to be set.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def new_ctx(self) -> "TaskContext":
|
||||
"""Create new task context
|
||||
|
||||
Returns:
|
||||
TaskContext: A new instance of a TaskContext.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
"""Get the metadata of current task
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The metadata
|
||||
"""
|
||||
|
||||
def update_metadata(self, key: str, value: Any) -> None:
|
||||
"""Update metadata with key and value
|
||||
|
||||
Args:
|
||||
key (str): The key of metadata
|
||||
value (str): The value to be add to metadata
|
||||
"""
|
||||
self.metadata[key] = value
|
||||
|
||||
@property
|
||||
def call_data(self) -> Optional[Dict]:
|
||||
"""Get the call data for current data"""
|
||||
return self.metadata.get("call_data")
|
||||
|
||||
def set_call_data(self, call_data: Dict) -> None:
|
||||
"""Set call data for current task"""
|
||||
self.update_metadata("call_data", call_data)
|
||||
|
||||
|
||||
class InputContext(ABC):
|
||||
"""Abstract base class representing the context of inputs to a operator node.
|
||||
|
||||
This class defines methods to manipulate and access the inputs for a operator node.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parent_outputs(self) -> List[TaskContext]:
|
||||
"""Get the outputs from the parent nodes.
|
||||
|
||||
Returns:
|
||||
List[TaskContext]: A list of contexts of the parent nodes' outputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
|
||||
"""Apply a mapping function to the inputs.
|
||||
|
||||
Args:
|
||||
map_func (Callable[[Any], Any]): A function to be applied to the inputs.
|
||||
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the mapped inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
|
||||
"""Apply a mapping function to all inputs.
|
||||
|
||||
Args:
|
||||
map_func (Callable[..., Any]): A function to be applied to all inputs.
|
||||
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the mapped inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
|
||||
"""Apply a reducing function to the inputs.
|
||||
|
||||
Args:
|
||||
reduce_func (Callable[[Any], Any]): A function that reduces the inputs.
|
||||
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the reduced inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
|
||||
"""Filter the inputs based on a provided function.
|
||||
|
||||
Args:
|
||||
filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
|
||||
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the filtered inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def predicate_map(
|
||||
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
|
||||
) -> "InputContext":
|
||||
"""Predicate the inputs based on a provided function.
|
||||
|
||||
Args:
|
||||
predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True.
|
||||
failed_value (Any): The value to be set if the return value of predicate function is False
|
||||
Returns:
|
||||
InputContext: A new InputContext instance with the predicate inputs.
|
||||
"""
|
||||
|
||||
def check_single_parent(self) -> bool:
|
||||
"""Check if there is only a single parent output.
|
||||
|
||||
Returns:
|
||||
bool: True if there is only one parent output, False otherwise.
|
||||
"""
|
||||
return len(self.parent_outputs) == 1
|
||||
|
||||
def check_stream(self) -> bool:
|
||||
"""Check if all parent outputs are streams.
|
||||
|
||||
Returns:
|
||||
bool: True if all parent outputs are streams, False otherwise.
|
||||
"""
|
||||
for out in self.parent_outputs:
|
||||
if not (out.task_output and out.task_output.is_stream):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class InputSource(ABC, Generic[T]):
|
||||
"""Abstract base class representing the source of inputs to a DAG node."""
|
||||
|
||||
@abstractmethod
|
||||
async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
|
||||
"""Read the data from current input source.
|
||||
|
||||
Returns:
|
||||
TaskOutput[T]: The output object read from current source
|
||||
"""
|
318
pilot/awel/task/task_impl.py
Normal file
318
pilot/awel/task/task_impl.py
Normal file
@@ -0,0 +1,318 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Callable,
|
||||
Coroutine,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
List,
|
||||
Generic,
|
||||
TypeVar,
|
||||
Any,
|
||||
Tuple,
|
||||
Dict,
|
||||
Union,
|
||||
)
|
||||
import asyncio
|
||||
|
||||
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
|
||||
|
||||
|
||||
async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
|
||||
# Init accumulator
|
||||
try:
|
||||
accumulator = await stream.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise ValueError("Stream is empty")
|
||||
is_async = asyncio.iscoroutinefunction(reduce_function)
|
||||
async for element in stream:
|
||||
if is_async:
|
||||
accumulator = await reduce_function(accumulator, element)
|
||||
else:
|
||||
accumulator = reduce_function(accumulator, element)
|
||||
return accumulator
|
||||
|
||||
|
||||
class SimpleTaskOutput(TaskOutput[T], Generic[T]):
|
||||
def __init__(self, data: T) -> None:
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def output(self) -> T:
|
||||
return self._data
|
||||
|
||||
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
|
||||
self._data = output_data
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return not self._data
|
||||
|
||||
async def _apply_func(self, func) -> Any:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
out = await func(self._data)
|
||||
else:
|
||||
out = func(self._data)
|
||||
return out
|
||||
|
||||
async def map(self, map_func) -> TaskOutput[T]:
|
||||
out = await self._apply_func(map_func)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def check_condition(self, condition_func) -> bool:
|
||||
return await self._apply_func(condition_func)
|
||||
|
||||
async def streamify(
|
||||
self, transform_func: Callable[[T], AsyncIterator[T]]
|
||||
) -> TaskOutput[T]:
|
||||
out = await self._apply_func(transform_func)
|
||||
return SimpleStreamTaskOutput(out)
|
||||
|
||||
|
||||
class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
|
||||
def __init__(self, data: AsyncIterator[T]) -> None:
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def is_stream(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return not self._data
|
||||
|
||||
@property
|
||||
def output_stream(self) -> AsyncIterator[T]:
|
||||
return self._data
|
||||
|
||||
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
|
||||
self._data = output_data
|
||||
|
||||
async def map(self, map_func) -> TaskOutput[T]:
|
||||
is_async = asyncio.iscoroutinefunction(map_func)
|
||||
|
||||
async def new_iter() -> AsyncIterator[T]:
|
||||
async for out in self._data:
|
||||
if is_async:
|
||||
out = await map_func(out)
|
||||
else:
|
||||
out = map_func(out)
|
||||
yield out
|
||||
|
||||
return SimpleStreamTaskOutput(new_iter())
|
||||
|
||||
async def reduce(self, reduce_func) -> TaskOutput[T]:
|
||||
out = await _reduce_stream(self._data, reduce_func)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def unstreamify(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], T]
|
||||
) -> TaskOutput[T]:
|
||||
if asyncio.iscoroutinefunction(transform_func):
|
||||
out = await transform_func(self._data)
|
||||
else:
|
||||
out = transform_func(self._data)
|
||||
return SimpleTaskOutput(out)
|
||||
|
||||
async def transform_stream(
|
||||
self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
|
||||
) -> TaskOutput[T]:
|
||||
if asyncio.iscoroutinefunction(transform_func):
|
||||
out = await transform_func(self._data)
|
||||
else:
|
||||
out = transform_func(self._data)
|
||||
return SimpleStreamTaskOutput(out)
|
||||
|
||||
|
||||
def _is_async_iterator(obj):
|
||||
return (
|
||||
hasattr(obj, "__anext__")
|
||||
and callable(getattr(obj, "__anext__", None))
|
||||
and hasattr(obj, "__aiter__")
|
||||
and callable(getattr(obj, "__aiter__", None))
|
||||
)
|
||||
|
||||
|
||||
class BaseInputSource(InputSource, ABC):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._is_read = False
|
||||
|
||||
@abstractmethod
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
"""Read data with task context"""
|
||||
|
||||
async def read(self, task_ctx: TaskContext) -> Coroutine[Any, Any, TaskOutput]:
|
||||
data = self._read_data(task_ctx)
|
||||
if _is_async_iterator(data):
|
||||
if self._is_read:
|
||||
raise ValueError(f"Input iterator {data} has been read!")
|
||||
output = SimpleStreamTaskOutput(data)
|
||||
else:
|
||||
output = SimpleTaskOutput(data)
|
||||
self._is_read = True
|
||||
return output
|
||||
|
||||
|
||||
class SimpleInputSource(BaseInputSource):
|
||||
def __init__(self, data: Any) -> None:
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
return self._data
|
||||
|
||||
|
||||
class SimpleCallDataInputSource(BaseInputSource):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
call_data = task_ctx.call_data
|
||||
data = call_data.get("data") if call_data else None
|
||||
if not (call_data and data):
|
||||
raise ValueError("No call data for current SimpleCallDataInputSource")
|
||||
return data
|
||||
|
||||
|
||||
class DefaultTaskContext(TaskContext, Generic[T]):
|
||||
def __init__(
|
||||
self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._task_id = task_id
|
||||
self._task_state = task_state
|
||||
self._output = task_output
|
||||
self._task_input = None
|
||||
self._metadata = {}
|
||||
|
||||
@property
|
||||
def task_id(self) -> str:
|
||||
return self._task_id
|
||||
|
||||
@property
|
||||
def task_input(self) -> InputContext:
|
||||
return self._task_input
|
||||
|
||||
def set_task_input(self, input_ctx: "InputContext") -> None:
|
||||
self._task_input = input_ctx
|
||||
|
||||
@property
|
||||
def task_output(self) -> TaskOutput:
|
||||
return self._output
|
||||
|
||||
def set_task_output(self, task_output: TaskOutput) -> None:
|
||||
self._output = task_output
|
||||
|
||||
@property
|
||||
def current_state(self) -> TaskState:
|
||||
return self._task_state
|
||||
|
||||
def set_current_state(self, task_state: TaskState) -> None:
|
||||
self._task_state = task_state
|
||||
|
||||
def new_ctx(self) -> TaskContext:
|
||||
return DefaultTaskContext(self._task_id, self._task_state, self._output)
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
return self._metadata
|
||||
|
||||
|
||||
class DefaultInputContext(InputContext):
|
||||
def __init__(self, outputs: List[TaskContext]) -> None:
|
||||
super().__init__()
|
||||
self._outputs = outputs
|
||||
|
||||
@property
|
||||
def parent_outputs(self) -> List[TaskContext]:
|
||||
return self._outputs
|
||||
|
||||
async def _apply_func(
|
||||
self, func: Callable[[Any], Any], apply_type: str = "map"
|
||||
) -> Tuple[List[TaskContext], List[TaskOutput]]:
|
||||
new_outputs: List[TaskContext] = []
|
||||
map_tasks = []
|
||||
for out in self._outputs:
|
||||
new_outputs.append(out.new_ctx())
|
||||
result = None
|
||||
if apply_type == "map":
|
||||
result = out.task_output.map(func)
|
||||
elif apply_type == "reduce":
|
||||
result = out.task_output.reduce(func)
|
||||
elif apply_type == "check_condition":
|
||||
result = out.task_output.check_condition(func)
|
||||
else:
|
||||
raise ValueError(f"Unsupport apply type {apply_type}")
|
||||
map_tasks.append(result)
|
||||
results = await asyncio.gather(*map_tasks)
|
||||
return new_outputs, results
|
||||
|
||||
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
|
||||
new_outputs, results = await self._apply_func(map_func)
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
task_ctx.set_task_output(results[i])
|
||||
return DefaultInputContext(new_outputs)
|
||||
|
||||
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
|
||||
if not self._outputs:
|
||||
return DefaultInputContext([])
|
||||
is_steam = self._outputs[0].task_output.is_stream
|
||||
if is_steam:
|
||||
if not self.check_stream():
|
||||
raise ValueError(
|
||||
"The output in all tasks must has same output format to map_all"
|
||||
)
|
||||
outputs = []
|
||||
for out in self._outputs:
|
||||
if out.task_output.is_stream:
|
||||
outputs.append(out.task_output.output_stream)
|
||||
else:
|
||||
outputs.append(out.task_output.output)
|
||||
if asyncio.iscoroutinefunction(map_func):
|
||||
map_res = await map_func(*outputs)
|
||||
else:
|
||||
map_res = map_func(*outputs)
|
||||
|
||||
single_output: TaskContext = self._outputs[0].new_ctx()
|
||||
single_output.task_output.set_output(map_res)
|
||||
return DefaultInputContext([single_output])
|
||||
|
||||
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
|
||||
if not self.check_stream():
|
||||
raise ValueError(
|
||||
"The output in all tasks must has same output format of stream to apply reduce function"
|
||||
)
|
||||
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
task_ctx.set_task_output(results[i])
|
||||
return DefaultInputContext(new_outputs)
|
||||
|
||||
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
|
||||
new_outputs, results = await self._apply_func(
|
||||
filter_func, apply_type="check_condition"
|
||||
)
|
||||
result_outputs = []
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
if results[i]:
|
||||
result_outputs.append(task_ctx)
|
||||
return DefaultInputContext(result_outputs)
|
||||
|
||||
async def predicate_map(
|
||||
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
|
||||
) -> "InputContext":
|
||||
new_outputs, results = await self._apply_func(
|
||||
predicate_func, apply_type="check_condition"
|
||||
)
|
||||
result_outputs = []
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
if results[i]:
|
||||
result_outputs.append(task_ctx)
|
||||
else:
|
||||
task_ctx.task_output.set_output(failed_value)
|
||||
result_outputs.append(task_ctx)
|
||||
return DefaultInputContext(result_outputs)
|
0
pilot/awel/tests/__init__.py
Normal file
0
pilot/awel/tests/__init__.py
Normal file
102
pilot/awel/tests/conftest.py
Normal file
102
pilot/awel/tests/conftest.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from typing import AsyncIterator, List
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
from .. import (
|
||||
WorkflowRunner,
|
||||
InputOperator,
|
||||
DAGContext,
|
||||
TaskState,
|
||||
DefaultWorkflowRunner,
|
||||
SimpleInputSource,
|
||||
)
|
||||
from ..task.task_impl import _is_async_iterator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return DefaultWorkflowRunner()
|
||||
|
||||
|
||||
def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
|
||||
iters = []
|
||||
for _ in range(num_nodes):
|
||||
|
||||
async def stream_iter():
|
||||
for i in range(10):
|
||||
yield i
|
||||
|
||||
stream_iter = stream_iter()
|
||||
assert _is_async_iterator(stream_iter)
|
||||
iters.append(stream_iter)
|
||||
return iters
|
||||
|
||||
|
||||
def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
|
||||
iters = []
|
||||
for single_stream in output_streams:
|
||||
|
||||
async def stream_iter():
|
||||
for i in single_stream:
|
||||
yield i
|
||||
|
||||
stream_iter = stream_iter()
|
||||
assert _is_async_iterator(stream_iter)
|
||||
iters.append(stream_iter)
|
||||
return iters
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _create_input_node(**kwargs):
|
||||
num_nodes = kwargs.get("num_nodes")
|
||||
is_stream = kwargs.get("is_stream", False)
|
||||
if is_stream:
|
||||
outputs = kwargs.get("output_streams")
|
||||
if outputs:
|
||||
if num_nodes and num_nodes != len(outputs):
|
||||
raise ValueError(
|
||||
f"num_nodes {num_nodes} != the length of output_streams {len(outputs)}"
|
||||
)
|
||||
outputs = _create_stream_from(outputs)
|
||||
else:
|
||||
num_nodes = num_nodes or 1
|
||||
outputs = _create_stream(num_nodes)
|
||||
else:
|
||||
outputs = kwargs.get("outputs", ["Hello."])
|
||||
nodes = []
|
||||
for output in outputs:
|
||||
print(f"output: {output}")
|
||||
input_source = SimpleInputSource(output)
|
||||
input_node = InputOperator(input_source)
|
||||
nodes.append(input_node)
|
||||
yield nodes
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def input_node(request):
|
||||
param = getattr(request, "param", {})
|
||||
async with _create_input_node(**param) as input_nodes:
|
||||
yield input_nodes[0]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def stream_input_node(request):
|
||||
param = getattr(request, "param", {})
|
||||
param["is_stream"] = True
|
||||
async with _create_input_node(**param) as input_nodes:
|
||||
yield input_nodes[0]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def input_nodes(request):
|
||||
param = getattr(request, "param", {})
|
||||
async with _create_input_node(**param) as input_nodes:
|
||||
yield input_nodes
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def stream_input_nodes(request):
|
||||
param = getattr(request, "param", {})
|
||||
param["is_stream"] = True
|
||||
async with _create_input_node(**param) as input_nodes:
|
||||
yield input_nodes
|
51
pilot/awel/tests/test_http_operator.py
Normal file
51
pilot/awel/tests/test_http_operator.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import pytest
|
||||
from typing import List
|
||||
from .. import (
|
||||
DAG,
|
||||
WorkflowRunner,
|
||||
DAGContext,
|
||||
TaskState,
|
||||
InputOperator,
|
||||
MapOperator,
|
||||
JoinOperator,
|
||||
BranchOperator,
|
||||
ReduceStreamOperator,
|
||||
SimpleInputSource,
|
||||
)
|
||||
from .conftest import (
|
||||
runner,
|
||||
input_node,
|
||||
input_nodes,
|
||||
stream_input_node,
|
||||
stream_input_nodes,
|
||||
_is_async_iterator,
|
||||
)
|
||||
|
||||
|
||||
def _register_dag_to_fastapi_app(dag):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_operator(runner: WorkflowRunner, stream_input_node: InputOperator):
|
||||
with DAG("test_map") as dag:
|
||||
pass
|
||||
# http_req_task = HttpRequestOperator(endpoint="/api/completions")
|
||||
# db_task = DBQueryOperator(table_name="user_info")
|
||||
# prompt_task = PromptTemplateOperator(
|
||||
# system_prompt="You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
|
||||
# )
|
||||
# llm_task = ChatGPTLLMOperator(model="chagpt-3.5")
|
||||
# output_parser_task = CommonOutputParserOperator()
|
||||
# http_res_task = HttpResponseOperator()
|
||||
# (
|
||||
# http_req_task
|
||||
# >> db_task
|
||||
# >> prompt_task
|
||||
# >> llm_task
|
||||
# >> output_parser_task
|
||||
# >> http_res_task
|
||||
# )
|
||||
|
||||
_register_dag_to_fastapi_app(dag)
|
128
pilot/awel/tests/test_run_dag.py
Normal file
128
pilot/awel/tests/test_run_dag.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import pytest
|
||||
from typing import List
|
||||
from .. import (
|
||||
DAG,
|
||||
WorkflowRunner,
|
||||
DAGContext,
|
||||
TaskState,
|
||||
InputOperator,
|
||||
MapOperator,
|
||||
JoinOperator,
|
||||
BranchOperator,
|
||||
ReduceStreamOperator,
|
||||
SimpleInputSource,
|
||||
)
|
||||
from .conftest import (
|
||||
runner,
|
||||
input_node,
|
||||
input_nodes,
|
||||
stream_input_node,
|
||||
stream_input_nodes,
|
||||
_is_async_iterator,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_node(runner: WorkflowRunner):
|
||||
input_node = InputOperator(SimpleInputSource("hello"))
|
||||
res: DAGContext[str] = await runner.execute_workflow(input_node)
|
||||
assert res.current_task_context.current_state == TaskState.SUCCESS
|
||||
assert res.current_task_context.task_output.output == "hello"
|
||||
|
||||
async def new_steam_iter(n: int):
|
||||
for i in range(n):
|
||||
yield i
|
||||
|
||||
num_iter = 10
|
||||
steam_input_node = InputOperator(SimpleInputSource(new_steam_iter(num_iter)))
|
||||
res: DAGContext[str] = await runner.execute_workflow(steam_input_node)
|
||||
assert res.current_task_context.current_state == TaskState.SUCCESS
|
||||
output_steam = res.current_task_context.task_output.output_stream
|
||||
assert output_steam
|
||||
assert _is_async_iterator(output_steam)
|
||||
i = 0
|
||||
async for x in output_steam:
|
||||
assert x == i
|
||||
i += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_node(runner: WorkflowRunner, stream_input_node: InputOperator):
|
||||
with DAG("test_map") as dag:
|
||||
map_node = MapOperator(lambda x: x * 2)
|
||||
stream_input_node >> map_node
|
||||
res: DAGContext[int] = await runner.execute_workflow(map_node)
|
||||
output_steam = res.current_task_context.task_output.output_stream
|
||||
assert output_steam
|
||||
i = 0
|
||||
async for x in output_steam:
|
||||
assert x == i * 2
|
||||
i += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"stream_input_node, expect_sum",
|
||||
[
|
||||
({"output_streams": [[0, 1, 2, 3]]}, 6),
|
||||
({"output_streams": [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]}, 55),
|
||||
],
|
||||
indirect=["stream_input_node"],
|
||||
)
|
||||
async def test_reduce_node(
|
||||
runner: WorkflowRunner, stream_input_node: InputOperator, expect_sum: int
|
||||
):
|
||||
with DAG("test_reduce_node") as dag:
|
||||
reduce_node = ReduceStreamOperator(lambda x, y: x + y)
|
||||
stream_input_node >> reduce_node
|
||||
res: DAGContext[int] = await runner.execute_workflow(reduce_node)
|
||||
assert res.current_task_context.current_state == TaskState.SUCCESS
|
||||
assert not res.current_task_context.task_output.is_stream
|
||||
assert res.current_task_context.task_output.output == expect_sum
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"input_nodes",
|
||||
[
|
||||
({"outputs": [0, 1, 2]}),
|
||||
],
|
||||
indirect=["input_nodes"],
|
||||
)
|
||||
async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator]):
|
||||
def join_func(p1, p2, p3) -> int:
|
||||
return p1 + p2 + p3
|
||||
|
||||
with DAG("test_join_node") as dag:
|
||||
join_node = JoinOperator(join_func)
|
||||
for input_node in input_nodes:
|
||||
input_node >> join_node
|
||||
res: DAGContext[int] = await runner.execute_workflow(join_node)
|
||||
assert res.current_task_context.current_state == TaskState.SUCCESS
|
||||
assert not res.current_task_context.task_output.is_stream
|
||||
assert res.current_task_context.task_output.output == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"input_node, is_odd",
|
||||
[
|
||||
({"outputs": [0]}, False),
|
||||
# ({"outputs": [1]}, False),
|
||||
],
|
||||
indirect=["input_node"],
|
||||
)
|
||||
async def test_branch_node(
|
||||
runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
|
||||
):
|
||||
with DAG("test_join_node") as dag:
|
||||
odd_node = MapOperator(lambda x: 999, task_id="odd_node")
|
||||
even_node = MapOperator(lambda x: 888, task_id="even_node")
|
||||
branch_node = BranchOperator(
|
||||
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
|
||||
)
|
||||
input_node >> branch_node
|
||||
|
||||
odd_res: DAGContext[int] = await runner.execute_workflow(odd_node)
|
||||
even_res: DAGContext[int] = await runner.execute_workflow(even_node)
|
||||
assert branch_node.current_task_context.current_state == TaskState.SUCCESS
|
10
pilot/cache/__init__.py
vendored
Normal file
10
pilot/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
from pilot.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
|
||||
from pilot.cache.manager import CacheManager, initialize_cache
|
||||
|
||||
__all__ = [
|
||||
"LLMCacheKey",
|
||||
"LLMCacheValue",
|
||||
"LLMCacheClient",
|
||||
"CacheManager",
|
||||
"initialize_cache",
|
||||
]
|
161
pilot/cache/base.py
vendored
Normal file
161
pilot/cache/base.py
vendored
Normal file
@@ -0,0 +1,161 @@
|
||||
from abc import ABC, abstractmethod, abstractclassmethod
|
||||
from typing import Any, TypeVar, Generic, Optional, Type, Dict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
T = TypeVar("T", bound="Serializable")
|
||||
|
||||
K = TypeVar("K")
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
class Serializable(ABC):
|
||||
@abstractmethod
|
||||
def serialize(self) -> bytes:
|
||||
"""Convert the object into bytes for storage or transmission.
|
||||
|
||||
Returns:
|
||||
bytes: The byte array after serialization
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert the object's state to a dictionary."""
|
||||
|
||||
# @staticmethod
|
||||
# @abstractclassmethod
|
||||
# def from_dict(cls: Type["Serializable"], obj_dict: Dict) -> "Serializable":
|
||||
# """Deserialize a dictionary to an Serializable object.
|
||||
# """
|
||||
|
||||
|
||||
class RetrievalPolicy(str, Enum):
|
||||
EXACT_MATCH = "exact_match"
|
||||
SIMILARITY_MATCH = "similarity_match"
|
||||
|
||||
|
||||
class CachePolicy(str, Enum):
|
||||
LRU = "lru"
|
||||
FIFO = "fifo"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheConfig:
|
||||
retrieval_policy: Optional[RetrievalPolicy] = RetrievalPolicy.EXACT_MATCH
|
||||
cache_policy: Optional[CachePolicy] = CachePolicy.LRU
|
||||
|
||||
|
||||
class CacheKey(Serializable, ABC, Generic[K]):
|
||||
"""The key of the cache. Must be hashable and comparable.
|
||||
|
||||
Supported cache keys:
|
||||
- The LLM cache key: Include user prompt and the parameters to LLM.
|
||||
- The embedding model cache key: Include the texts to embedding and the parameters to embedding model.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __hash__(self) -> int:
|
||||
"""Return the hash value of the key."""
|
||||
|
||||
@abstractmethod
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check equality with another key."""
|
||||
|
||||
@abstractmethod
|
||||
def get_hash_bytes(self) -> bytes:
|
||||
"""Return the byte array of hash value."""
|
||||
|
||||
@abstractmethod
|
||||
def get_value(self) -> K:
|
||||
"""Get the underlying value of the cache key.
|
||||
|
||||
Returns:
|
||||
K: The real object of current cache key
|
||||
"""
|
||||
|
||||
|
||||
class CacheValue(Serializable, ABC, Generic[V]):
|
||||
"""Cache value abstract class."""
|
||||
|
||||
@abstractmethod
|
||||
def get_value(self) -> V:
|
||||
"""Get the underlying real value."""
|
||||
|
||||
|
||||
class Serializer(ABC):
|
||||
"""The serializer abstract class for serializing cache keys and values."""
|
||||
|
||||
@abstractmethod
|
||||
def serialize(self, obj: Serializable) -> bytes:
|
||||
"""Serialize a cache object.
|
||||
|
||||
Args:
|
||||
obj (Serializable): The object to serialize
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
|
||||
"""Deserialize data back into a cache object of the specified type.
|
||||
|
||||
Args:
|
||||
data (bytes): The byte array to deserialize
|
||||
cls (Type[Serializable]): The type of current object
|
||||
|
||||
Returns:
|
||||
Serializable: The serializable object
|
||||
"""
|
||||
|
||||
|
||||
class CacheClient(ABC, Generic[K, V]):
|
||||
"""The cache client interface."""
|
||||
|
||||
@abstractmethod
|
||||
async def get(
|
||||
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
|
||||
) -> Optional[CacheValue[V]]:
|
||||
"""Retrieve a value from the cache using the provided key.
|
||||
|
||||
Args:
|
||||
key (CacheKey[K]): The key to get cache
|
||||
cache_config (Optional[CacheConfig]): Cache config
|
||||
|
||||
Returns:
|
||||
Optional[CacheValue[V]]: The value retrieved according to key. If cache key not exist, return None.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def set(
|
||||
self,
|
||||
key: CacheKey[K],
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> None:
|
||||
"""Set a value in the cache for the provided key.
|
||||
|
||||
Args:
|
||||
key (CacheKey[K]): The key to set to cache
|
||||
value (CacheValue[V]): The value to set to cache
|
||||
cache_config (Optional[CacheConfig]): Cache config
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def exists(
|
||||
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
|
||||
) -> bool:
|
||||
"""Check if a key exists in the cache.
|
||||
|
||||
Args:
|
||||
key (CacheKey[K]): The key to set to cache
|
||||
cache_config (Optional[CacheConfig]): Cache config
|
||||
|
||||
Return:
|
||||
bool: True if the key in the cache, otherwise is False
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def new_key(self, **kwargs) -> CacheKey[K]:
|
||||
"""Create a cache key with params"""
|
||||
|
||||
@abstractmethod
|
||||
def new_value(self, **kwargs) -> CacheValue[K]:
|
||||
"""Create a cache key with params"""
|
0
pilot/cache/embedding_cache.py
vendored
Normal file
0
pilot/cache/embedding_cache.py
vendored
Normal file
148
pilot/cache/llm_cache.py
vendored
Normal file
148
pilot/cache/llm_cache.py
vendored
Normal file
@@ -0,0 +1,148 @@
|
||||
from typing import Optional, Dict, Any, Union, List
|
||||
from dataclasses import dataclass, asdict
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
from pilot.cache.base import CacheKey, CacheValue, Serializer, CacheClient, CacheConfig
|
||||
from pilot.cache.manager import CacheManager
|
||||
from pilot.cache.storage.base import CacheStorage
|
||||
from pilot.model.base import ModelType, ModelOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMCacheKeyData:
|
||||
prompt: str
|
||||
model_name: str
|
||||
temperature: Optional[float] = 0.7
|
||||
max_new_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = 1.0
|
||||
model_type: Optional[str] = ModelType.HF
|
||||
|
||||
|
||||
CacheOutputType = Union[ModelOutput, List[ModelOutput]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMCacheValueData:
|
||||
output: CacheOutputType
|
||||
user: Optional[str] = None
|
||||
_is_list: Optional[bool] = False
|
||||
|
||||
@staticmethod
|
||||
def from_dict(**kwargs) -> "LLMCacheValueData":
|
||||
output = kwargs.get("output")
|
||||
if not output:
|
||||
raise ValueError("Can't new LLMCacheValueData object, output is None")
|
||||
if isinstance(output, dict):
|
||||
output = ModelOutput(**output)
|
||||
elif isinstance(output, list):
|
||||
kwargs["_is_list"] = True
|
||||
output_list = []
|
||||
for out in output:
|
||||
if isinstance(out, dict):
|
||||
out = ModelOutput(**out)
|
||||
output_list.append(out)
|
||||
output = output_list
|
||||
kwargs["output"] = output
|
||||
return LLMCacheValueData(**kwargs)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
output = self.output
|
||||
is_list = False
|
||||
if isinstance(output, list):
|
||||
output_list = []
|
||||
is_list = True
|
||||
for out in output:
|
||||
output_list.append(out.to_dict())
|
||||
output = output_list
|
||||
else:
|
||||
output = output.to_dict()
|
||||
return {"output": output, "_is_list": is_list, "user": self.user}
|
||||
|
||||
@property
|
||||
def is_list(self) -> bool:
|
||||
return self._is_list
|
||||
|
||||
def __str__(self) -> str:
|
||||
if not isinstance(self.output, list):
|
||||
return f"user: {self.user}, output: {self.output}"
|
||||
else:
|
||||
return f"user: {self.user}, output(last two item): {self.output[-2:]}"
|
||||
|
||||
|
||||
class LLMCacheKey(CacheKey[LLMCacheKeyData]):
|
||||
def __init__(self, serializer: Serializer = None, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self._serializer = serializer
|
||||
self.config = LLMCacheKeyData(**kwargs)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
serialize_bytes = self.serialize()
|
||||
return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, LLMCacheKey):
|
||||
return False
|
||||
return self.config == other.config
|
||||
|
||||
def get_hash_bytes(self) -> bytes:
|
||||
serialize_bytes = self.serialize()
|
||||
return hashlib.sha256(serialize_bytes).digest()
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return asdict(self.config)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
return self._serializer.serialize(self)
|
||||
|
||||
def get_value(self) -> LLMCacheKeyData:
|
||||
return self.config
|
||||
|
||||
|
||||
class LLMCacheValue(CacheValue[LLMCacheValueData]):
|
||||
def __init__(self, serializer: Serializer = None, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self._serializer = serializer
|
||||
self.value = LLMCacheValueData.from_dict(**kwargs)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return self.value.to_dict()
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
return self._serializer.serialize(self)
|
||||
|
||||
def get_value(self) -> LLMCacheValueData:
|
||||
return self.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"vaue: {str(self.value)}"
|
||||
|
||||
|
||||
class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
|
||||
def __init__(self, cache_manager: CacheManager) -> None:
|
||||
super().__init__()
|
||||
self._cache_manager: CacheManager = cache_manager
|
||||
|
||||
async def get(
|
||||
self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
|
||||
) -> Optional[LLMCacheValue]:
|
||||
return await self._cache_manager.get(key, LLMCacheValue, cache_config)
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: LLMCacheKey,
|
||||
value: LLMCacheValue,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> None:
|
||||
return await self._cache_manager.set(key, value, cache_config)
|
||||
|
||||
async def exists(
|
||||
self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
|
||||
) -> bool:
|
||||
return await self.get(key, cache_config) is not None
|
||||
|
||||
def new_key(self, **kwargs) -> LLMCacheKey:
|
||||
return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)
|
||||
|
||||
def new_value(self, **kwargs) -> LLMCacheValue:
|
||||
return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
|
118
pilot/cache/manager.py
vendored
Normal file
118
pilot/cache/manager.py
vendored
Normal file
@@ -0,0 +1,118 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Type
|
||||
import logging
|
||||
from concurrent.futures import Executor
|
||||
from pilot.cache.storage.base import CacheStorage, StorageItem
|
||||
from pilot.cache.base import (
|
||||
K,
|
||||
V,
|
||||
CacheKey,
|
||||
CacheValue,
|
||||
CacheConfig,
|
||||
Serializer,
|
||||
Serializable,
|
||||
)
|
||||
from pilot.component import BaseComponent, ComponentType, SystemApp
|
||||
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheManager(BaseComponent, ABC):
|
||||
name = ComponentType.MODEL_CACHE_MANAGER
|
||||
|
||||
def __init__(self, system_app: SystemApp | None = None):
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
self.system_app = system_app
|
||||
|
||||
@abstractmethod
|
||||
async def set(
|
||||
self,
|
||||
key: CacheKey[K],
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
):
|
||||
"""Set cache"""
|
||||
|
||||
@abstractmethod
|
||||
async def get(
|
||||
self,
|
||||
key: CacheKey[K],
|
||||
cls: Type[Serializable],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> CacheValue[V]:
|
||||
"""Get cache with key"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def serializer(self) -> Serializer:
|
||||
"""Get cache serializer"""
|
||||
|
||||
|
||||
class LocalCacheManager(CacheManager):
|
||||
def __init__(
|
||||
self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
|
||||
) -> None:
|
||||
super().__init__(system_app)
|
||||
self._serializer = serializer
|
||||
self._storage = storage
|
||||
|
||||
@property
|
||||
def executor(self) -> Executor:
|
||||
"""Return executor to submit task"""
|
||||
self._executor = self.system_app.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: CacheKey[K],
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
):
|
||||
if self._storage.support_async():
|
||||
await self._storage.aset(key, value, cache_config)
|
||||
else:
|
||||
await blocking_func_to_async(
|
||||
self.executor, self._storage.set, key, value, cache_config
|
||||
)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
key: CacheKey[K],
|
||||
cls: Type[Serializable],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> CacheValue[V]:
|
||||
if self._storage.support_async():
|
||||
item_bytes = await self._storage.aget(key, cache_config)
|
||||
else:
|
||||
item_bytes = await blocking_func_to_async(
|
||||
self.executor, self._storage.get, key, cache_config
|
||||
)
|
||||
if not item_bytes:
|
||||
return None
|
||||
return self._serializer.deserialize(item_bytes.value_data, cls)
|
||||
|
||||
@property
|
||||
def serializer(self) -> Serializer:
|
||||
return self._serializer
|
||||
|
||||
|
||||
def initialize_cache(system_app: SystemApp, persist_dir: str):
|
||||
from pilot.cache.protocal.json_protocal import JsonSerializer
|
||||
from pilot.cache.storage.base import MemoryCacheStorage
|
||||
|
||||
try:
|
||||
from pilot.cache.storage.disk.disk_storage import DiskCacheStorage
|
||||
|
||||
cache_storage = DiskCacheStorage(persist_dir)
|
||||
except ImportError as e:
|
||||
logger.warn(
|
||||
f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
|
||||
)
|
||||
cache_storage = MemoryCacheStorage()
|
||||
system_app.register(
|
||||
LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
|
||||
)
|
0
pilot/cache/protocal/__init__.py
vendored
Normal file
0
pilot/cache/protocal/__init__.py
vendored
Normal file
44
pilot/cache/protocal/json_protocal.py
vendored
Normal file
44
pilot/cache/protocal/json_protocal.py
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Type
|
||||
import json
|
||||
|
||||
from pilot.cache.base import Serializable, Serializer
|
||||
|
||||
JSON_ENCODING = "utf-8"
|
||||
|
||||
|
||||
class JsonSerializable(Serializable, ABC):
|
||||
@abstractmethod
|
||||
def to_dict(self) -> Dict:
|
||||
"""Return the dict of current serializable object"""
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""Convert the object into bytes for storage or transmission."""
|
||||
return json.dumps(self.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
|
||||
|
||||
|
||||
class JsonSerializer(Serializer):
|
||||
"""The serializer abstract class for serializing cache keys and values."""
|
||||
|
||||
def serialize(self, obj: Serializable) -> bytes:
|
||||
"""Serialize a cache object.
|
||||
|
||||
Args:
|
||||
obj (Serializable): The object to serialize
|
||||
"""
|
||||
return json.dumps(obj.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
|
||||
|
||||
def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
|
||||
"""Deserialize data back into a cache object of the specified type.
|
||||
|
||||
Args:
|
||||
data (bytes): The byte array to deserialize
|
||||
cls (Type[Serializable]): The type of current object
|
||||
|
||||
Returns:
|
||||
Serializable: The serializable object
|
||||
"""
|
||||
# Convert bytes back to JSON and then to the specified class
|
||||
json_data = json.loads(data.decode(JSON_ENCODING))
|
||||
# Assume that the cls has an __init__ that accepts a dictionary
|
||||
return cls(**json_data)
|
0
pilot/cache/storage/__init__.py
vendored
Normal file
0
pilot/cache/storage/__init__.py
vendored
Normal file
252
pilot/cache/storage/base.py
vendored
Normal file
252
pilot/cache/storage/base.py
vendored
Normal file
@@ -0,0 +1,252 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from collections import OrderedDict
|
||||
import msgpack
|
||||
import logging
|
||||
|
||||
from pilot.cache.base import (
|
||||
K,
|
||||
V,
|
||||
CacheKey,
|
||||
CacheValue,
|
||||
CacheClient,
|
||||
CacheConfig,
|
||||
RetrievalPolicy,
|
||||
CachePolicy,
|
||||
)
|
||||
from pilot.utils.memory_utils import _get_object_bytes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorageItem:
|
||||
"""
|
||||
A class representing a storage item.
|
||||
|
||||
This class encapsulates data related to a storage item, such as its length,
|
||||
the hash of the key, and the data for both the key and value.
|
||||
|
||||
Parameters:
|
||||
length (int): The bytes length of the storage item.
|
||||
key_hash (bytes): The hash value of the storage item's key.
|
||||
key_data (bytes): The data of the storage item's key, represented in bytes.
|
||||
value_data (bytes): The data of the storage item's value, also in bytes.
|
||||
"""
|
||||
|
||||
length: int # The bytes length of the storage item
|
||||
key_hash: bytes # The hash value of the storage item's key
|
||||
key_data: bytes # The data of the storage item's key
|
||||
value_data: bytes # The data of the storage item's value
|
||||
|
||||
@staticmethod
|
||||
def build_from(
|
||||
key_hash: bytes, key_data: bytes, value_data: bytes
|
||||
) -> "StorageItem":
|
||||
length = (
|
||||
32
|
||||
+ _get_object_bytes(key_hash)
|
||||
+ _get_object_bytes(key_data)
|
||||
+ _get_object_bytes(value_data)
|
||||
)
|
||||
return StorageItem(
|
||||
length=length, key_hash=key_hash, key_data=key_data, value_data=value_data
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
|
||||
key_hash = key.get_hash_bytes()
|
||||
key_data = key.serialize()
|
||||
value_data = value.serialize()
|
||||
return StorageItem.build_from(key_hash, key_data, value_data)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""Serialize the StorageItem into a byte stream using MessagePack.
|
||||
|
||||
This method packs the object data into a dictionary, marking the
|
||||
key_data and value_data fields as raw binary data to avoid re-serialization.
|
||||
|
||||
Returns:
|
||||
bytes: The serialized bytes.
|
||||
"""
|
||||
obj = {
|
||||
"length": self.length,
|
||||
"key_hash": msgpack.ExtType(1, self.key_hash),
|
||||
"key_data": msgpack.ExtType(2, self.key_data),
|
||||
"value_data": msgpack.ExtType(3, self.value_data),
|
||||
}
|
||||
return msgpack.packb(obj)
|
||||
|
||||
@staticmethod
|
||||
def deserialize(data: bytes) -> "StorageItem":
|
||||
"""Deserialize bytes back into a StorageItem using MessagePack.
|
||||
|
||||
This extracts the fields from the MessagePack dict back into
|
||||
a StorageItem object.
|
||||
|
||||
Args:
|
||||
data (bytes): Serialized bytes
|
||||
|
||||
Returns:
|
||||
StorageItem: Deserialized StorageItem object.
|
||||
"""
|
||||
obj = msgpack.unpackb(data)
|
||||
key_hash = obj["key_hash"].data
|
||||
key_data = obj["key_data"].data
|
||||
value_data = obj["value_data"].data
|
||||
|
||||
return StorageItem(
|
||||
length=obj["length"],
|
||||
key_hash=key_hash,
|
||||
key_data=key_data,
|
||||
value_data=value_data,
|
||||
)
|
||||
|
||||
|
||||
class CacheStorage(ABC):
|
||||
@abstractmethod
|
||||
def check_config(
|
||||
self,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
raise_error: Optional[bool] = True,
|
||||
) -> bool:
|
||||
"""Check whether the CacheConfig is legal.
|
||||
|
||||
Args:
|
||||
cache_config (Optional[CacheConfig]): Cache config.
|
||||
raise_error (Optional[bool]): Whether raise error if illegal.
|
||||
|
||||
Returns:
|
||||
ValueError: Error when raise_error is True and config is illegal.
|
||||
"""
|
||||
|
||||
def support_async(self) -> bool:
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
|
||||
) -> Optional[StorageItem]:
|
||||
"""Retrieve a storage item from the cache using the provided key.
|
||||
|
||||
Args:
|
||||
key (CacheKey[K]): The key to get cache
|
||||
cache_config (Optional[CacheConfig]): Cache config
|
||||
|
||||
Returns:
|
||||
Optional[StorageItem]: The storage item retrieved according to key. If cache key not exist, return None.
|
||||
"""
|
||||
|
||||
async def aget(
|
||||
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
|
||||
) -> Optional[StorageItem]:
|
||||
"""Retrieve a storage item from the cache using the provided key asynchronously.
|
||||
|
||||
Args:
|
||||
key (CacheKey[K]): The key to get cache
|
||||
cache_config (Optional[CacheConfig]): Cache config
|
||||
|
||||
Returns:
|
||||
Optional[StorageItem]: The storage item of bytes retrieved according to key. If cache key not exist, return None.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def set(
|
||||
self,
|
||||
key: CacheKey[K],
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> None:
|
||||
"""Set a value in the cache for the provided key asynchronously.
|
||||
|
||||
Args:
|
||||
key (CacheKey[K]): The key to set to cache
|
||||
value (CacheValue[V]): The value to set to cache
|
||||
cache_config (Optional[CacheConfig]): Cache config
|
||||
"""
|
||||
|
||||
async def aset(
|
||||
self,
|
||||
key: CacheKey[K],
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> None:
|
||||
"""Set a value in the cache for the provided key asynchronously.
|
||||
|
||||
Args:
|
||||
key (CacheKey[K]): The key to set to cache
|
||||
value (CacheValue[V]): The value to set to cache
|
||||
cache_config (Optional[CacheConfig]): Cache config
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MemoryCacheStorage(CacheStorage):
|
||||
def __init__(self, max_memory_mb: int = 1024):
|
||||
self.cache = OrderedDict()
|
||||
self.max_memory = max_memory_mb * 1024 * 1024
|
||||
self.current_memory_usage = 0
|
||||
|
||||
def check_config(
|
||||
self,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
raise_error: Optional[bool] = True,
|
||||
) -> bool:
|
||||
if (
|
||||
cache_config
|
||||
and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
|
||||
):
|
||||
if raise_error:
|
||||
raise ValueError(
|
||||
"MemoryCacheStorage only supports 'EXACT_MATCH' retrieval policy"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def get(
|
||||
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
|
||||
) -> Optional[StorageItem]:
|
||||
self.check_config(cache_config, raise_error=True)
|
||||
# Exact match retrieval
|
||||
key_hash = hash(key)
|
||||
item: StorageItem = self.cache.get(key_hash)
|
||||
logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")
|
||||
|
||||
if not item:
|
||||
return None
|
||||
# Move the item to the end of the OrderedDict to signify recent use.
|
||||
self.cache.move_to_end(key_hash)
|
||||
return item
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: CacheKey[K],
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> None:
|
||||
key_hash = hash(key)
|
||||
item = StorageItem.build_from_kv(key, value)
|
||||
# Calculate memory size of the new entry
|
||||
new_entry_size = _get_object_bytes(item)
|
||||
# Evict entries if necessary
|
||||
while self.current_memory_usage + new_entry_size > self.max_memory:
|
||||
self._apply_cache_policy(cache_config)
|
||||
|
||||
# Store the item in the cache.
|
||||
self.cache[key_hash] = item
|
||||
self.current_memory_usage += new_entry_size
|
||||
logger.debug(f"MemoryCacheStorage set key {key}, hash {key_hash}, item: {item}")
|
||||
|
||||
def exists(
|
||||
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
|
||||
) -> bool:
|
||||
return self.get(key, cache_config) is not None
|
||||
|
||||
def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
|
||||
# Remove the oldest/newest item based on the cache policy.
|
||||
if cache_config and cache_config.cache_policy == CachePolicy.FIFO:
|
||||
self.cache.popitem(last=False)
|
||||
else: # Default is LRU
|
||||
self.cache.popitem(last=True)
|
0
pilot/cache/storage/disk/__init__.py
vendored
Normal file
0
pilot/cache/storage/disk/__init__.py
vendored
Normal file
93
pilot/cache/storage/disk/disk_storage.py
vendored
Normal file
93
pilot/cache/storage/disk/disk_storage.py
vendored
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import Optional
|
||||
import logging
|
||||
from pilot.cache.base import (
|
||||
K,
|
||||
V,
|
||||
CacheKey,
|
||||
CacheValue,
|
||||
CacheConfig,
|
||||
RetrievalPolicy,
|
||||
CachePolicy,
|
||||
)
|
||||
from pilot.cache.storage.base import StorageItem, CacheStorage
|
||||
from rocksdict import Rdict
|
||||
from rocksdict import Rdict, Options, SliceTransform, PlainTableFactoryOptions
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def db_options(
|
||||
mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
|
||||
):
|
||||
opt = Options()
|
||||
# create table
|
||||
opt.create_if_missing(True)
|
||||
# config to more jobs, default 2
|
||||
opt.set_max_background_jobs(background_threads)
|
||||
# configure mem-table to a large value
|
||||
opt.set_write_buffer_size(mem_table_buffer_mb * 1024 * 1024)
|
||||
# opt.set_write_buffer_size(1024)
|
||||
# opt.set_level_zero_file_num_compaction_trigger(4)
|
||||
# configure l0 and l1 size, let them have the same size (1 GB)
|
||||
# opt.set_max_bytes_for_level_base(0x40000000)
|
||||
# 256 MB file size
|
||||
# opt.set_target_file_size_base(0x10000000)
|
||||
# use a smaller compaction multiplier
|
||||
# opt.set_max_bytes_for_level_multiplier(4.0)
|
||||
# use 8-byte prefix (2 ^ 64 is far enough for transaction counts)
|
||||
# opt.set_prefix_extractor(SliceTransform.create_max_len_prefix(8))
|
||||
# set to plain-table
|
||||
# opt.set_plain_table_factory(PlainTableFactoryOptions())
|
||||
return opt
|
||||
|
||||
|
||||
class DiskCacheStorage(CacheStorage):
|
||||
def __init__(
|
||||
self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.db: Rdict = Rdict(
|
||||
persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
|
||||
)
|
||||
|
||||
def check_config(
|
||||
self,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
raise_error: Optional[bool] = True,
|
||||
) -> bool:
|
||||
if (
|
||||
cache_config
|
||||
and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
|
||||
):
|
||||
if raise_error:
|
||||
raise ValueError(
|
||||
"DiskCacheStorage only supports 'EXACT_MATCH' retrieval policy"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def get(
|
||||
self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
|
||||
) -> Optional[StorageItem]:
|
||||
self.check_config(cache_config, raise_error=True)
|
||||
|
||||
# Exact match retrieval
|
||||
key_hash = key.get_hash_bytes()
|
||||
item_bytes = self.db.get(key_hash)
|
||||
if not item_bytes:
|
||||
return None
|
||||
item = StorageItem.deserialize(item_bytes)
|
||||
logger.debug(f"Read file cache, key: {key}, storage item: {item}")
|
||||
return item
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: CacheKey[K],
|
||||
value: CacheValue[V],
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
) -> None:
|
||||
item = StorageItem.build_from_kv(key, value)
|
||||
key_hash = item.key_hash
|
||||
self.db[key_hash] = item.serialize()
|
||||
logger.debug(f"Save file cache, key: {key}, value: {value}")
|
0
pilot/cache/storage/tests/__init__.py
vendored
Normal file
0
pilot/cache/storage/tests/__init__.py
vendored
Normal file
53
pilot/cache/storage/tests/test_storage.py
vendored
Normal file
53
pilot/cache/storage/tests/test_storage.py
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
import pytest
|
||||
from ..base import StorageItem
|
||||
from pilot.utils.memory_utils import _get_object_bytes
|
||||
|
||||
|
||||
def test_build_from():
|
||||
key_hash = b"key_hash"
|
||||
key_data = b"key_data"
|
||||
value_data = b"value_data"
|
||||
item = StorageItem.build_from(key_hash, key_data, value_data)
|
||||
|
||||
assert item.key_hash == key_hash
|
||||
assert item.key_data == key_data
|
||||
assert item.value_data == value_data
|
||||
assert item.length == 32 + _get_object_bytes(key_hash) + _get_object_bytes(
|
||||
key_data
|
||||
) + _get_object_bytes(value_data)
|
||||
|
||||
|
||||
def test_build_from_kv():
|
||||
class MockCacheKey:
|
||||
def get_hash_bytes(self):
|
||||
return b"key_hash"
|
||||
|
||||
def serialize(self):
|
||||
return b"key_data"
|
||||
|
||||
class MockCacheValue:
|
||||
def serialize(self):
|
||||
return b"value_data"
|
||||
|
||||
key = MockCacheKey()
|
||||
value = MockCacheValue()
|
||||
item = StorageItem.build_from_kv(key, value)
|
||||
|
||||
assert item.key_hash == key.get_hash_bytes()
|
||||
assert item.key_data == key.serialize()
|
||||
assert item.value_data == value.serialize()
|
||||
|
||||
|
||||
def test_serialize_deserialize():
|
||||
key_hash = b"key_hash"
|
||||
key_data = b"key_data"
|
||||
value_data = b"value_data"
|
||||
item = StorageItem.build_from(key_hash, key_data, value_data)
|
||||
|
||||
serialized = item.serialize()
|
||||
deserialized = StorageItem.deserialize(serialized)
|
||||
|
||||
assert deserialized.key_hash == item.key_hash
|
||||
assert deserialized.key_data == item.key_data
|
||||
assert deserialized.value_data == item.value_data
|
||||
assert deserialized.length == item.length
|
@@ -48,6 +48,7 @@ class ComponentType(str, Enum):
|
||||
MODEL_CONTROLLER = "dbgpt_model_controller"
|
||||
MODEL_REGISTRY = "dbgpt_model_registry"
|
||||
MODEL_API_SERVER = "dbgpt_model_api_server"
|
||||
MODEL_CACHE_MANAGER = "dbgpt_model_cache_manager"
|
||||
AGENT_HUB = "dbgpt_agent_hub"
|
||||
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
|
||||
TRACER = "dbgpt_tracer"
|
||||
|
@@ -253,6 +253,11 @@ class Config(metaclass=Singleton):
|
||||
### Temporary configuration
|
||||
self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
|
||||
|
||||
self.MODEL_CACHE_STORAGE: str = os.getenv("MODEL_CACHE_STORAGE")
|
||||
self.MODEL_CACHE_STORAGE_DIST_DIR: str = os.getenv(
|
||||
"MODEL_CACHE_STORAGE_DIST_DIR"
|
||||
)
|
||||
|
||||
def set_debug_mode(self, value: bool) -> None:
|
||||
"""Set the debug mode value"""
|
||||
self.debug_mode = value
|
||||
|
@@ -14,6 +14,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
||||
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
||||
MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
|
||||
|
||||
current_directory = os.getcwd()
|
||||
|
||||
|
0
pilot/model/operator/__init__.py
Normal file
0
pilot/model/operator/__init__.py
Normal file
119
pilot/model/operator/model_operator.py
Normal file
119
pilot/model/operator/model_operator.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from typing import AsyncIterator, Dict
|
||||
import logging
|
||||
from pilot.awel import (
|
||||
StreamifyAbsOperator,
|
||||
MapOperator,
|
||||
TransformStreamAbsOperator,
|
||||
)
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.cluster import WorkerManager
|
||||
from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_LLM_MODEL_INPUT_VALUE_KEY = "llm_model_input_value"
|
||||
_LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"
|
||||
|
||||
|
||||
class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
|
||||
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.worker_manager = worker_manager
|
||||
|
||||
async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
|
||||
llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
|
||||
_LLM_MODEL_OUTPUT_CACHE_KEY
|
||||
)
|
||||
logger.info(f"llm_cache_value: {llm_cache_value}")
|
||||
if llm_cache_value:
|
||||
for out in llm_cache_value.get_value().output:
|
||||
yield out
|
||||
return
|
||||
async for out in self.worker_manager.generate_stream(input_value):
|
||||
yield out
|
||||
|
||||
|
||||
class ModelOperator(MapOperator[Dict, ModelOutput]):
|
||||
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
|
||||
self.worker_manager = worker_manager
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: Dict) -> ModelOutput:
|
||||
llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
|
||||
_LLM_MODEL_OUTPUT_CACHE_KEY
|
||||
)
|
||||
logger.info(f"llm_cache_value: {llm_cache_value}")
|
||||
if llm_cache_value:
|
||||
return llm_cache_value.get_value().output
|
||||
return await self.worker_manager.generate(input_value)
|
||||
|
||||
|
||||
class ModelCachePreOperator(MapOperator[Dict, Dict]):
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs):
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: Dict) -> Dict:
|
||||
cache_dict = {
|
||||
"prompt": input_value.get("prompt"),
|
||||
"model_name": input_value.get("model"),
|
||||
"temperature": input_value.get("temperature"),
|
||||
"max_new_tokens": input_value.get("max_new_tokens"),
|
||||
"top_p": input_value.get("top_p", "1.0"),
|
||||
# TODO pass model_type
|
||||
"model_type": input_value.get("model_type", "huggingface"),
|
||||
}
|
||||
cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
|
||||
cache_value = await self._client.get(cache_key)
|
||||
logger.debug(
|
||||
f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
|
||||
)
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
_LLM_MODEL_INPUT_VALUE_KEY, cache_key
|
||||
)
|
||||
if cache_value:
|
||||
logger.info(f"The model output has cached, cache_value: {cache_value}")
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
_LLM_MODEL_OUTPUT_CACHE_KEY, cache_value
|
||||
)
|
||||
return input_value
|
||||
|
||||
|
||||
class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutput]):
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs):
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def transform_stream(
|
||||
self, input_value: AsyncIterator[ModelOutput]
|
||||
) -> AsyncIterator[ModelOutput]:
|
||||
llm_cache_key: LLMCacheKey = None
|
||||
outputs = []
|
||||
async for out in input_value:
|
||||
if not llm_cache_key:
|
||||
llm_cache_key = await self.current_dag_context.get_share_data(
|
||||
_LLM_MODEL_INPUT_VALUE_KEY
|
||||
)
|
||||
outputs.append(out)
|
||||
yield out
|
||||
if llm_cache_key:
|
||||
llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
|
||||
await self._client.set(llm_cache_key, llm_cache_value)
|
||||
|
||||
|
||||
class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs):
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: ModelOutput) -> ModelOutput:
|
||||
llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
|
||||
_LLM_MODEL_INPUT_VALUE_KEY
|
||||
)
|
||||
llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
|
||||
if llm_cache_key:
|
||||
await self._client.set(llm_cache_key, llm_cache_value)
|
||||
return input_value
|
@@ -16,6 +16,8 @@ from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
from pydantic import Extra
|
||||
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
||||
from pilot.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG
|
||||
from pilot.model.operator.model_operator import ModelOperator, ModelStreamOperator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
headers = {"User-Agent": "dbgpt Client"}
|
||||
@@ -88,6 +90,11 @@ class BaseChat(ABC):
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
|
||||
self._model_operator: BaseOperator = _build_model_operator()
|
||||
self._model_stream_operator: BaseOperator = _build_model_operator(
|
||||
is_stream=True, dag_name="llm_stream_model_dag"
|
||||
)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@@ -204,12 +211,9 @@ class BaseChat(ABC):
|
||||
)
|
||||
payload["span_id"] = span.span_id
|
||||
try:
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
|
||||
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
async for output in worker_manager.generate_stream(payload):
|
||||
async for output in await self._model_stream_operator.call_stream(
|
||||
call_data={"data": payload}
|
||||
):
|
||||
### Plug-in research in result generation
|
||||
msg = self.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||
output, self.skip_echo_len
|
||||
@@ -240,14 +244,10 @@ class BaseChat(ABC):
|
||||
)
|
||||
payload["span_id"] = span.span_id
|
||||
try:
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
|
||||
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
|
||||
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
|
||||
model_output = await worker_manager.generate(payload)
|
||||
model_output = await self._model_operator.call(
|
||||
call_data={"data": payload}
|
||||
)
|
||||
|
||||
### output parse
|
||||
ai_response_text = (
|
||||
@@ -307,14 +307,7 @@ class BaseChat(ABC):
|
||||
logger.info(f"Request: \n{payload}")
|
||||
ai_response_text = ""
|
||||
try:
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
|
||||
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
|
||||
model_output = await worker_manager.generate(payload)
|
||||
|
||||
model_output = await self._model_operator.call(call_data={"data": payload})
|
||||
### output parse
|
||||
ai_response_text = (
|
||||
self.prompt_template.output_parser.parse_model_nostream_resp(
|
||||
@@ -568,3 +561,34 @@ class BaseChat(ABC):
|
||||
)
|
||||
else:
|
||||
return prompt_define_response
|
||||
|
||||
|
||||
def _build_model_operator(
|
||||
is_stream: bool = False, dag_name: str = "llm_model_dag"
|
||||
) -> BaseOperator:
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
from pilot.model.operator.model_operator import (
|
||||
ModelCacheOperator,
|
||||
ModelStreamCacheOperator,
|
||||
ModelCachePreOperator,
|
||||
)
|
||||
from pilot.cache import CacheManager
|
||||
|
||||
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.MODEL_CACHE_MANAGER, CacheManager
|
||||
)
|
||||
|
||||
with DAG(dag_name):
|
||||
input_node = InputOperator(SimpleCallDataInputSource())
|
||||
cache_check_node = ModelCachePreOperator(cache_manager)
|
||||
if is_stream:
|
||||
model_node = ModelStreamOperator(worker_manager)
|
||||
cache_node = ModelStreamCacheOperator(cache_manager)
|
||||
else:
|
||||
model_node = ModelOperator(worker_manager)
|
||||
cache_node = ModelCacheOperator(cache_manager)
|
||||
input_node >> cache_check_node >> model_node >> cache_node
|
||||
return cache_node
|
||||
|
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Type
|
||||
import os
|
||||
|
||||
from pilot.component import ComponentType, SystemApp
|
||||
from pilot.configs.model_config import MODEL_DISK_CACHE_DIR
|
||||
from pilot.utils.executor_utils import DefaultExecutorFactory
|
||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
from pilot.server.base import WebWerverParameters
|
||||
@@ -41,6 +42,10 @@ def initialize_components(
|
||||
param, system_app, embedding_model_name, embedding_model_path
|
||||
)
|
||||
|
||||
from pilot.cache import initialize_cache
|
||||
|
||||
initialize_cache(system_app, MODEL_DISK_CACHE_DIR)
|
||||
|
||||
|
||||
def _initialize_embedding_model(
|
||||
param: WebWerverParameters,
|
||||
|
11
pilot/utils/memory_utils.py
Normal file
11
pilot/utils/memory_utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import Any
|
||||
from pympler import asizeof
|
||||
|
||||
|
||||
def _get_object_bytes(obj: Any) -> int:
|
||||
"""Get the bytes of a object in memory
|
||||
|
||||
Args:
|
||||
obj (Any): The object to return the bytes
|
||||
"""
|
||||
return asizeof.asizeof(obj)
|
10
setup.py
10
setup.py
@@ -319,6 +319,8 @@ def core_requires():
|
||||
"alembic==1.12.0",
|
||||
# for excel
|
||||
"openpyxl",
|
||||
# for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
|
||||
"pympler",
|
||||
]
|
||||
|
||||
|
||||
@@ -410,6 +412,13 @@ def vllm_requires():
|
||||
setup_spec.extras["vllm"] = ["vllm"]
|
||||
|
||||
|
||||
def cache_requires():
|
||||
"""
|
||||
pip install "db-gpt[cache]"
|
||||
"""
|
||||
setup_spec.extras["cache"] = ["rocksdict", "msgpack"]
|
||||
|
||||
|
||||
# def chat_scene():
|
||||
# setup_spec.extras["chat"] = [
|
||||
# ""
|
||||
@@ -460,6 +469,7 @@ all_datasource_requires()
|
||||
openai_requires()
|
||||
gpt4all_requires()
|
||||
vllm_requires()
|
||||
cache_requires()
|
||||
|
||||
# must be last
|
||||
default_requires()
|
||||
|
Reference in New Issue
Block a user