Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-08 04:23:35 +00:00)
feat(model): Support model cache and first version of Agentic Workflow Expression Language (AWEL)
57
pilot/awel/__init__.py
Normal file
@@ -0,0 +1,57 @@
"""Agentic Workflow Expression Language (AWEL)"""

from .dag.base import DAGContext, DAG

from .operator.base import BaseOperator, WorkflowRunner, initialize_awel
from .operator.common_operator import (
    JoinOperator,
    ReduceStreamOperator,
    MapOperator,
    BranchOperator,
    InputOperator,
)

from .operator.stream_operator import (
    StreamifyAbsOperator,
    UnstreamifyAbsOperator,
    TransformStreamAbsOperator,
)

from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
from .task.task_impl import (
    SimpleInputSource,
    SimpleCallDataInputSource,
    DefaultTaskContext,
    DefaultInputContext,
    SimpleTaskOutput,
    SimpleStreamTaskOutput,
)
from .runner.local_runner import DefaultWorkflowRunner

__all__ = [
    "initialize_awel",
    "DAGContext",
    "DAG",
    "BaseOperator",
    "JoinOperator",
    "ReduceStreamOperator",
    "MapOperator",
    "BranchOperator",
    "InputOperator",
    "WorkflowRunner",
    "TaskState",
    "TaskOutput",
    "TaskContext",
    "InputContext",
    "InputSource",
    "DefaultWorkflowRunner",
    "SimpleInputSource",
    "SimpleCallDataInputSource",
    "DefaultTaskContext",
    "DefaultInputContext",
    "SimpleTaskOutput",
    "SimpleStreamTaskOutput",
    "StreamifyAbsOperator",
    "UnstreamifyAbsOperator",
    "TransformStreamAbsOperator",
]
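Taken together, these exports are enough to build and run a small workflow. The following is a hypothetical usage sketch, not part of this commit and untested against it; it assumes the package is importable as pilot.awel and that the default local runner resolves the two-node graph as described in the files below.

# Hypothetical usage sketch of the exports above (not part of this commit).
import asyncio

from pilot.awel import (
    DAG,
    DefaultWorkflowRunner,
    InputOperator,
    MapOperator,
    SimpleInputSource,
    initialize_awel,
)

# Register the process-wide default runner once at startup.
initialize_awel(DefaultWorkflowRunner())

with DAG("awel_hello_world"):
    input_node = InputOperator(input_source=SimpleInputSource([1, 2, 3]))
    sum_node = MapOperator(map_function=lambda values: sum(values))
    input_node >> sum_node

# BaseOperator.call() drives the runner from the end node backwards.
assert asyncio.run(sum_node.call()) == 6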
0
pilot/awel/dag/__init__.py
Normal file
252
pilot/awel/dag/base.py
Normal file
@@ -0,0 +1,252 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any
import uuid
import contextvars
import threading
import asyncio
from collections import deque

from ..resource.base import ResourceGroup
from ..task.base import TaskContext

DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]


def _is_async_context():
    try:
        loop = asyncio.get_running_loop()
        return asyncio.current_task(loop=loop) is not None
    except RuntimeError:
        return False


class DependencyMixin(ABC):
    @abstractmethod
    def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
        """Set one or more upstream nodes for this node.

        Args:
            nodes (DependencyType): Upstream nodes to be set to current node.

        Returns:
            DependencyMixin: Returns self to allow method chaining.

        Raises:
            ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin.
        """

    @abstractmethod
    def set_downstream(self, nodes: DependencyType) -> "DependencyMixin":
        """Set one or more downstream nodes for this node.

        Args:
            nodes (DependencyType): Downstream nodes to be set to current node.

        Returns:
            DependencyMixin: Returns self to allow method chaining.

        Raises:
            ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin.
        """

    def __lshift__(self, nodes: DependencyType) -> DependencyType:
        """Implements self << nodes

        Example:

        .. code-block:: python

            # means node.set_upstream(input_node)
            node << input_node

            # means node2.set_upstream([input_node])
            node2 << [input_node]
        """
        self.set_upstream(nodes)
        return nodes

    def __rshift__(self, nodes: DependencyType) -> DependencyType:
        """Implements self >> nodes

        Example:

        .. code-block:: python

            # means node.set_downstream(next_node)
            node >> next_node

            # means node2.set_downstream([next_node])
            node2 >> [next_node]
        """
        self.set_downstream(nodes)
        return nodes

    def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Implements [node] >> self"""
        self.__lshift__(nodes)
        return self

    def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
        """Implements [node] << self"""
        self.__rshift__(nodes)
        return self


class DAGVar:
    _thread_local = threading.local()
    _async_local = contextvars.ContextVar("current_dag_stack", default=deque())

    @classmethod
    def enter_dag(cls, dag) -> None:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            stack.append(dag)
            cls._async_local.set(stack)
        else:
            if not hasattr(cls._thread_local, "current_dag_stack"):
                cls._thread_local.current_dag_stack = deque()
            cls._thread_local.current_dag_stack.append(dag)

    @classmethod
    def exit_dag(cls) -> None:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            if stack:
                stack.pop()
            cls._async_local.set(stack)
        else:
            if (
                hasattr(cls._thread_local, "current_dag_stack")
                and cls._thread_local.current_dag_stack
            ):
                cls._thread_local.current_dag_stack.pop()

    @classmethod
    def get_current_dag(cls) -> Optional["DAG"]:
        is_async = _is_async_context()
        if is_async:
            stack = cls._async_local.get()
            return stack[-1] if stack else None
        else:
            if (
                hasattr(cls._thread_local, "current_dag_stack")
                and cls._thread_local.current_dag_stack
            ):
                return cls._thread_local.current_dag_stack[-1]
            return None


class DAGNode(DependencyMixin, ABC):
    resource_group: Optional[ResourceGroup] = None
    """The resource group of current DAGNode"""

    def __init__(self, dag: Optional["DAG"] = None, node_id: str = None) -> None:
        super().__init__()
        self._upstream: List["DAGNode"] = []
        self._downstream: List["DAGNode"] = []
        self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
        if not node_id and self._dag:
            node_id = self._dag._new_node_id()
        self._node_id: str = node_id

    @property
    def node_id(self) -> str:
        return self._node_id

    def set_node_id(self, node_id: str) -> None:
        self._node_id = node_id

    @property
    def dag(self) -> Optional["DAG"]:
        return self._dag

    def set_upstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes)
        return self

    def set_downstream(self, nodes: DependencyType) -> "DAGNode":
        self.set_dependency(nodes, is_upstream=False)
        return self

    @property
    def upstream(self) -> List["DAGNode"]:
        return self._upstream

    @property
    def downstream(self) -> List["DAGNode"]:
        return self._downstream

    def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
        if not isinstance(nodes, Sequence):
            nodes = [nodes]
        if not all(isinstance(node, DAGNode) for node in nodes):
            raise ValueError(
                "all nodes to set dependency to current node must be instance of 'DAGNode'"
            )
        nodes: Sequence[DAGNode] = nodes
        dags = set([node.dag for node in nodes if node.dag])
        if self.dag:
            dags.add(self.dag)
        if not dags:
            raise ValueError("set dependency to current node must in a DAG context")
        if len(dags) != 1:
            raise ValueError(
                "set dependency to current node just support in one DAG context"
            )
        dag = dags.pop()
        self._dag = dag

        dag._append_node(self)
        for node in nodes:
            if is_upstream and node not in self.upstream:
                node._dag = dag
                dag._append_node(node)

                self._upstream.append(node)
                node._downstream.append(self)
            elif node not in self._downstream:
                node._dag = dag
                dag._append_node(node)

                self._downstream.append(node)
                node._upstream.append(self)


class DAGContext:
    def __init__(self) -> None:
        self._curr_task_ctx = None
        self._share_data: Dict[str, Any] = {}

    @property
    def current_task_context(self) -> TaskContext:
        return self._curr_task_ctx

    def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
        self._curr_task_ctx = _curr_task_ctx

    async def get_share_data(self, key: str) -> Any:
        return self._share_data.get(key)

    async def save_to_share_data(self, key: str, data: Any) -> None:
        self._share_data[key] = data


class DAG:
    def __init__(
        self, dag_id: str, resource_group: Optional[ResourceGroup] = None
    ) -> None:
        self.node_map: Dict[str, DAGNode] = {}

    def _append_node(self, node: DAGNode) -> None:
        self.node_map[node.node_id] = node

    def _new_node_id(self) -> str:
        return str(uuid.uuid4())

    def __enter__(self):
        DAGVar.enter_dag(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        DAGVar.exit_dag()
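The DAG context and the dependency operators above compose as follows. This is an illustrative sketch, not part of the commit; MyNode is a hypothetical concrete DAGNode used only to show how nodes attach to the active DAG.

# Illustrative sketch (not part of this commit) of DAG, DAGVar and the
# dependency operators defined above.
from pilot.awel.dag.base import DAG, DAGNode, DAGVar


class MyNode(DAGNode):
    """Hypothetical concrete DAGNode; dependency handling comes from DAGNode."""


with DAG("example_dag") as dag:
    # Inside the `with` block DAGVar tracks the active DAG, so new nodes
    # attach to it automatically.
    assert DAGVar.get_current_dag() is dag
    a, b, c = MyNode(), MyNode(), MyNode()
    a >> b >> c   # a.set_downstream(b); b.set_downstream(c)
    [a, b] >> c   # both a and b become upstream of c

assert DAGVar.get_current_dag() is None
# Nodes are registered in node_map when dependencies are set.
assert set(dag.node_map.values()) >= {a, b, c}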
0
pilot/awel/dag/tests/__init__.py
Normal file
51
pilot/awel/dag/tests/test_dag.py
Normal file
@@ -0,0 +1,51 @@
import pytest
import threading
import asyncio
from ..base import DAG, DAGVar


def test_dag_context_sync():
    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    with dag1:
        assert DAGVar.get_current_dag() == dag1
        with dag2:
            assert DAGVar.get_current_dag() == dag2
        assert DAGVar.get_current_dag() == dag1
    assert DAGVar.get_current_dag() is None


def test_dag_context_threading():
    def thread_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()

    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    thread1 = threading.Thread(target=thread_function, args=(dag1,))
    thread2 = threading.Thread(target=thread_function, args=(dag2,))

    thread1.start()
    thread2.start()
    thread1.join()
    thread2.join()

    assert DAGVar.get_current_dag() is None


@pytest.mark.asyncio
async def test_dag_context_async():
    async def async_function(dag):
        DAGVar.enter_dag(dag)
        assert DAGVar.get_current_dag() == dag
        DAGVar.exit_dag()

    dag1 = DAG("dag1")
    dag2 = DAG("dag2")

    await asyncio.gather(async_function(dag1), async_function(dag2))

    assert DAGVar.get_current_dag() is None
0
pilot/awel/operator/__init__.py
Normal file
176
pilot/awel/operator/base.py
Normal file
@@ -0,0 +1,176 @@
from abc import ABC, abstractmethod, ABCMeta

from types import FunctionType
from typing import (
    List,
    Generic,
    TypeVar,
    AsyncIterator,
    Union,
    Any,
    Dict,
    Optional,
    cast,
)
import functools
from inspect import signature

from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import (
    TaskContext,
    TaskOutput,
    TaskState,
    OUT,
    T,
    InputContext,
    InputSource,
)

F = TypeVar("F", bound=FunctionType)

CALL_DATA = Union[Dict, Dict[str, Dict]]


class WorkflowRunner(ABC, Generic[T]):
    """Abstract base class representing a runner for executing workflows in a DAG.

    This class defines the interface for executing workflows within the DAG,
    handling the flow from one DAG node to another.
    """

    @abstractmethod
    async def execute_workflow(
        self, node: "BaseOperator", call_data: Optional[CALL_DATA] = None
    ) -> DAGContext:
        """Execute the workflow starting from a given operator.

        Args:
            node (BaseOperator): The starting node of the workflow to be executed.
            call_data (CALL_DATA): The data passed to the root operator node.

        Returns:
            DAGContext: The context after executing the workflow, containing the final state and data.
        """


default_runner: WorkflowRunner = None


class BaseOperatorMeta(ABCMeta):
    """Metaclass of BaseOperator."""

    @classmethod
    def _apply_defaults(cls, func: F) -> F:
        sig_cache = signature(func)

        @functools.wraps(func)
        def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
            dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag()
            task_id: Optional[str] = kwargs.get("task_id")
            if not task_id and dag:
                task_id = dag._new_node_id()
            runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
            # print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}")
            # for arg in sig_cache.parameters:
            #     if arg not in kwargs:
            #         kwargs[arg] = default_args[arg]
            if not kwargs.get("dag"):
                kwargs["dag"] = dag
            if not kwargs.get("task_id"):
                kwargs["task_id"] = task_id
            if not kwargs.get("runner"):
                kwargs["runner"] = runner
            real_obj = func(self, *args, **kwargs)
            return real_obj

        return cast(F, apply_defaults)

    def __new__(cls, name, bases, namespace, **kwargs):
        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
        new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
        return new_cls


class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
    """Abstract base class for operator nodes that can be executed within a workflow.

    This class extends DAGNode by adding execution capabilities.
    """

    def __init__(
        self,
        task_id: Optional[str] = None,
        dag: Optional[DAG] = None,
        runner: WorkflowRunner = None,
        **kwargs,
    ) -> None:
        """Initializes a BaseOperator with an optional workflow runner.

        Args:
            runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
        """
        super().__init__(node_id=task_id, dag=dag, **kwargs)
        if not runner:
            from pilot.awel import DefaultWorkflowRunner

            runner = DefaultWorkflowRunner()

        self._runner: WorkflowRunner = runner
        self._dag_ctx: DAGContext = None

    @property
    def current_dag_context(self) -> DAGContext:
        return self._dag_ctx

    async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        if not self.node_id:
            raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
        self._dag_ctx = dag_ctx
        return await self._do_run(dag_ctx)

    @abstractmethod
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """
        Abstract method to run the task within the DAG node.

        Args:
            dag_ctx (DAGContext): The context of the DAG when this node is run.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """

    async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
        """Execute the node and return the output.

        This method is a high-level wrapper for executing the node.

        Args:
            call_data (CALL_DATA): The data passed to the root operator node.

        Returns:
            OUT: The output of the node after execution.
        """
        out_ctx = await self._runner.execute_workflow(self, call_data)
        return out_ctx.current_task_context.task_output.output

    async def call_stream(
        self, call_data: Optional[CALL_DATA] = None
    ) -> AsyncIterator[OUT]:
        """Execute the node and return the output as a stream.

        This method is used for nodes where the output is a stream.

        Args:
            call_data (CALL_DATA): The data passed to the root operator node.

        Returns:
            AsyncIterator[OUT]: An asynchronous iterator over the output stream.
        """
        out_ctx = await self._runner.execute_workflow(self, call_data)
        return out_ctx.current_task_context.task_output.output_stream


def initialize_awel(runner: WorkflowRunner):
    global default_runner
    default_runner = runner
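A short sketch (not part of the commit) of what BaseOperatorMeta buys you: operators constructed inside a DAG context pick up dag, task_id and the default runner without passing them explicitly. MapOperator comes from the next file in this commit.

# Sketch (not part of this commit): operator defaults filled in by the metaclass.
from pilot.awel import DAG, DefaultWorkflowRunner, MapOperator, initialize_awel

initialize_awel(DefaultWorkflowRunner())

with DAG("defaults_demo") as dag:
    op = MapOperator(map_function=lambda x: x * 2)

# apply_defaults resolved these from the surrounding context.
assert op.dag is dag
assert op.node_id is not None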
221
pilot/awel/operator/common_operator.py
Normal file
@@ -0,0 +1,221 @@
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
import asyncio

from ..dag.base import DAGContext
from ..task.base import (
    TaskContext,
    TaskOutput,
    IN,
    OUT,
    InputContext,
    InputSource,
)

from .base import BaseOperator


class JoinOperator(BaseOperator, Generic[OUT]):
    """Operator that joins inputs using a custom combine function.

    This node type is useful for combining the outputs of upstream nodes.
    """

    def __init__(self, combine_function, **kwargs):
        super().__init__(**kwargs)
        if not callable(combine_function):
            raise ValueError("combine_function must be callable")
        self.combine_function = combine_function

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the join operation on the DAG context's inputs.

        Args:
            dag_ctx (DAGContext): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        input_ctx: InputContext = await curr_task_ctx.task_input.map_all(
            self.combine_function
        )
        # The join result is stored in the first parent output
        join_output = input_ctx.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(join_output)
        return join_output


class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
    def __init__(self, reduce_function=None, **kwargs):
        """Initializes a ReduceStreamOperator with a reduce function.

        Args:
            reduce_function: A function that defines how to reduce the stream inputs.

        Raises:
            ValueError: If the reduce_function is not callable.
        """
        super().__init__(**kwargs)
        if reduce_function and not callable(reduce_function):
            raise ValueError("reduce_function must be callable")
        self.reduce_function = reduce_function

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the reduce operation on the DAG context's inputs.

        Args:
            dag_ctx (DAGContext): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        task_input = curr_task_ctx.task_input
        if not task_input.check_stream():
            raise ValueError("ReduceStreamOperator expects stream data")
        if not task_input.check_single_parent():
            raise ValueError("ReduceStreamOperator expects single parent")

        reduce_function = self.reduce_function or self.reduce

        input_ctx: InputContext = await task_input.reduce(reduce_function)
        # The reduce result is stored in the first parent output
        reduce_output = input_ctx.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(reduce_output)
        return reduce_output

    async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
        raise NotImplementedError


class MapOperator(BaseOperator, Generic[IN, OUT]):
    """Map operator that applies a mapping function to its inputs.

    This operator transforms its input data using a provided mapping function and
    passes the transformed data downstream.
    """

    def __init__(self, map_function=None, **kwargs):
        """Initializes a MapOperator with a mapping function.

        Args:
            map_function: A function that defines how to map the input data.

        Raises:
            ValueError: If the map_function is not callable.
        """
        super().__init__(**kwargs)
        if map_function and not callable(map_function):
            raise ValueError("map_function must be callable")
        self.map_function = map_function

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the mapping operation on the DAG context's inputs.

        This method applies the mapping function to the input context and updates
        the DAG context with the new data.

        Args:
            dag_ctx (DAGContext[IN]): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.

        Raises:
            ValueError: If there is not a single parent or the map_function is not callable.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        if not curr_task_ctx.task_input.check_single_parent():
            num_parents = len(curr_task_ctx.task_input.parent_outputs)
            raise ValueError(
                f"task {curr_task_ctx.task_id} MapOperator expects single parent, now number of parents: {num_parents}"
            )
        map_function = self.map_function or self.map

        input_ctx: InputContext = await curr_task_ctx.task_input.map(map_function)
        # The map result is stored in the first parent output
        map_output = input_ctx.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(map_output)
        return map_output

    async def map(self, input_value: IN) -> OUT:
        raise NotImplementedError


BranchFunc = Union[Callable[[Any], bool], Callable[[Any], Awaitable[bool]]]


class BranchOperator(BaseOperator, Generic[OUT]):
    """Operator node that branches the workflow based on a provided function.

    This node filters its input data using a branching function and
    allows for conditional paths in the workflow.
    """

    def __init__(self, branches: Dict[BranchFunc, BaseOperator], **kwargs):
        """
        Initializes a BranchOperator with branching functions.

        Args:
            branches (Dict[BranchFunc, BaseOperator]): Dict of functions that define the branching conditions.

        Raises:
            ValueError: If a branch_function is not callable.
        """
        super().__init__(**kwargs)
        for branch_function in branches.keys():
            if not callable(branch_function):
                raise ValueError("branch_function must be callable")
        self.branches = branches

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        """Run the branching operation on the DAG context's inputs.

        This method applies the branching function to the input context to determine
        the path of execution in the workflow.

        Args:
            dag_ctx (DAGContext[IN]): The current context of the DAG.

        Returns:
            TaskOutput[OUT]: The task output after this node has been run.
        """
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        task_input = curr_task_ctx.task_input
        if task_input.check_stream():
            raise ValueError("BranchOperator expects no stream data")
        if not task_input.check_single_parent():
            raise ValueError("BranchOperator expects single parent")

        branch_func_tasks = []
        branch_nodes: List[BaseOperator] = []
        for func, node in self.branches.items():
            branch_nodes.append(node)
            branch_func_tasks.append(
                curr_task_ctx.task_input.predicate_map(func, failed_value=None)
            )
        branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
        parent_output = task_input.parent_outputs[0].task_output
        curr_task_ctx.set_task_output(parent_output)

        for i, ctx in enumerate(branch_input_ctxs):
            node = branch_nodes[i]
            if ctx.parent_outputs[0].task_output.is_empty:
                # Skip current node
                # node.current_task_context.set_current_state(TaskState.SKIP)
                pass
            else:
                pass
        raise NotImplementedError
        return None


class InputOperator(BaseOperator, Generic[OUT]):
    def __init__(self, input_source: InputSource[OUT], **kwargs) -> None:
        super().__init__(**kwargs)
        self._input_source = input_source

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        task_output = await self._input_source.read(curr_task_ctx)
        curr_task_ctx.set_task_output(task_output)
        return task_output
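A hypothetical fan-in sketch, not part of the commit: two input branches are mapped independently and then joined. It assumes DefaultInputContext.map_all (whose implementation is not shown in this excerpt) passes each parent output as a positional argument to the combine function, which is what JoinOperator's usage implies.

# Hypothetical fan-in sketch (not part of this commit).
import asyncio

from pilot.awel import (
    DAG,
    InputOperator,
    JoinOperator,
    MapOperator,
    SimpleInputSource,
)

with DAG("join_example"):
    left = InputOperator(input_source=SimpleInputSource(2))
    right = InputOperator(input_source=SimpleInputSource(3))
    double = MapOperator(map_function=lambda x: x * 2)
    square = MapOperator(map_function=lambda x: x * x)
    join = JoinOperator(combine_function=lambda a, b: a + b)

    left >> double >> join
    right >> square >> join

# double yields 4, square yields 9, and the join combines them: 4 + 9 == 13
assert asyncio.run(join.call()) == 13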
90
pilot/awel/operator/stream_operator.py
Normal file
@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
from typing import Generic, AsyncIterator
from ..task.base import OUT, IN, TaskOutput, TaskContext
from ..dag.base import DAGContext
from .base import BaseOperator


class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
            self.streamify
        )
        curr_task_ctx.set_task_output(output)
        return output

    @abstractmethod
    async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
        """Convert a value of IN to an AsyncIterator[OUT].

        Args:
            input_value (IN): The data of the parent operator's output

        Example:

        .. code-block:: python

            class MyStreamOperator(StreamifyAbsOperator[int, int]):
                async def streamify(self, input_value: int) -> AsyncIterator[int]:
                    for i in range(input_value):
                        yield i
        """


class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[
            0
        ].task_output.unstreamify(self.unstreamify)
        curr_task_ctx.set_task_output(output)
        return output

    @abstractmethod
    async def unstreamify(self, input_value: AsyncIterator[IN]) -> OUT:
        """Convert a value of AsyncIterator[IN] to an OUT.

        Args:
            input_value (AsyncIterator[IN]): The data of the parent operator's output

        Example:

        .. code-block:: python

            class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
                async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
                    value_cnt = 0
                    async for v in input_value:
                        value_cnt += 1
                    return value_cnt
        """


class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
        output = await curr_task_ctx.task_input.parent_outputs[
            0
        ].task_output.transform_stream(self.transform_stream)
        curr_task_ctx.set_task_output(output)
        return output

    @abstractmethod
    async def transform_stream(
        self, input_value: AsyncIterator[IN]
    ) -> AsyncIterator[OUT]:
        """Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.

        Args:
            input_value (AsyncIterator[IN]): The data of the parent operator's output

        Example:

        .. code-block:: python

            class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
                async def transform_stream(
                    self, input_value: AsyncIterator[int]
                ) -> AsyncIterator[int]:
                    async for v in input_value:
                        yield v + 1
        """
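An illustrative sketch, not part of the commit, mirroring the docstring examples above: a number is fanned out into a stream and collapsed back into a single value.

# Illustrative sketch (not part of this commit): streamify then unstreamify.
import asyncio
from typing import AsyncIterator

from pilot.awel import (
    DAG,
    InputOperator,
    SimpleInputSource,
    StreamifyAbsOperator,
    UnstreamifyAbsOperator,
)


class RangeStreamOperator(StreamifyAbsOperator[int, int]):
    async def streamify(self, input_value: int) -> AsyncIterator[int]:
        for i in range(input_value):
            yield i


class SumStreamOperator(UnstreamifyAbsOperator[int, int]):
    async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
        total = 0
        async for value in input_value:
            total += value
        return total


with DAG("stream_example"):
    source = InputOperator(input_source=SimpleInputSource(5))
    to_stream = RangeStreamOperator()
    to_value = SumStreamOperator()
    source >> to_stream >> to_value

# range(5) streamed and summed: 0 + 1 + 2 + 3 + 4 == 10
assert asyncio.run(to_value.call()) == 10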
0
pilot/awel/resource/__init__.py
Normal file
8
pilot/awel/resource/base.py
Normal file
@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod


class ResourceGroup(ABC):
    @property
    @abstractmethod
    def name(self) -> str:
        """The name of current resource group"""
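A minimal sketch of a concrete ResourceGroup; hypothetical, not part of the commit.

# Hypothetical sketch: a concrete ResourceGroup only has to report its name.
from pilot.awel.resource.base import ResourceGroup


class LocalResourceGroup(ResourceGroup):
    @property
    def name(self) -> str:
        return "local"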
0
pilot/awel/runner/__init__.py
Normal file
74
pilot/awel/runner/job_manager.py
Normal file
@@ -0,0 +1,74 @@
from typing import List, Set, Optional, Dict
import uuid
from ..dag.base import DAG

from ..operator.base import BaseOperator, CALL_DATA


class DAGNodeInstance:
    def __init__(self, node_instance: DAG) -> None:
        pass


class DAGInstance:
    def __init__(self, dag: DAG) -> None:
        self._dag = dag


class JobManager:
    def __init__(
        self,
        root_nodes: List[BaseOperator],
        all_nodes: List[BaseOperator],
        end_node: BaseOperator,
        id2call_data: Dict[str, Dict],
    ) -> None:
        self._root_nodes = root_nodes
        self._all_nodes = all_nodes
        self._end_node = end_node
        self._id2node_data = id2call_data

    @staticmethod
    def build_from_end_node(
        end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
    ) -> "JobManager":
        nodes = _build_from_end_node(end_node)
        root_nodes = _get_root_nodes(nodes)
        id2call_data = _save_call_data(root_nodes, call_data)
        return JobManager(root_nodes, nodes, end_node, id2call_data)

    def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
        return self._id2node_data.get(node_id)


def _save_call_data(
    root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
    id2call_data = {}
    if not call_data:
        return id2call_data
    if len(root_nodes) == 1:
        node = root_nodes[0]
        id2call_data[node.node_id] = call_data
    else:
        for node in root_nodes:
            node_id = node.node_id
            id2call_data[node_id] = call_data.get(node_id)
    return id2call_data


def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
    nodes = []
    if isinstance(end_node, BaseOperator):
        task_id = end_node.node_id
        if not task_id:
            task_id = str(uuid.uuid4())
            end_node.set_node_id(task_id)
        nodes.append(end_node)
    for node in end_node.upstream:
        nodes += _build_from_end_node(node)
    return nodes


def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
    return list(filter(lambda x: not x.upstream, nodes))
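An illustrative sketch, not part of the commit, of how JobManager discovers the graph by walking upstream from the end node and attaches call_data to the single root node.

# Illustrative sketch (not part of this commit): building a JobManager.
from pilot.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource
from pilot.awel.runner.job_manager import JobManager

with DAG("job_manager_example"):
    root = InputOperator(input_source=SimpleCallDataInputSource())
    end = MapOperator(map_function=lambda x: x + 1)
    root >> end

job_manager = JobManager.build_from_end_node(end, call_data={"data": 41})
# With a single root node, the whole call_data dict is attached to that root.
assert job_manager.get_call_data_by_id(root.node_id) == {"data": 41}
assert job_manager.get_call_data_by_id(end.node_id) is None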
68
pilot/awel/runner/local_runner.py
Normal file
@@ -0,0 +1,68 @@
from typing import Dict, Optional
from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext
from .job_manager import JobManager


class DefaultWorkflowRunner(WorkflowRunner):
    async def execute_workflow(
        self, node: BaseOperator, call_data: Optional[CALL_DATA] = None
    ) -> DAGContext:
        # Create DAG context
        dag_ctx = DAGContext()
        job_manager = JobManager.build_from_end_node(node, call_data)
        dag = node.dag
        # Save node outputs
        node_outputs: Dict[str, TaskContext] = {}
        await self._execute_node(job_manager, node, dag_ctx, node_outputs)

        return dag_ctx

    async def _execute_node(
        self,
        job_manager: JobManager,
        node: BaseOperator,
        dag_ctx: DAGContext,
        node_outputs: Dict[str, TaskContext],
    ):
        # Skip the node if it has already been run
        if node.node_id in node_outputs:
            return

        # Run all upstream nodes first
        for upstream_node in node.upstream:
            if isinstance(upstream_node, BaseOperator):
                await self._execute_node(
                    job_manager, upstream_node, dag_ctx, node_outputs
                )
        # if node.current_task_context.current_state == TaskState.SKIP:
        #     return

        # for upstream_node in node.upstream:
        #     if (
        #         isinstance(upstream_node, BaseOperator)
        #         and upstream_node.current_task_context.current_state == TaskState.SKIP
        #     ):
        #         return
        # Get the input from the upstream nodes
        inputs = [
            node_outputs[upstream_node.node_id] for upstream_node in node.upstream
        ]
        input_ctx = DefaultInputContext(inputs)
        task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
        task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))

        task_ctx.set_task_input(input_ctx)
        dag_ctx.set_current_task_context(task_ctx)

        task_ctx.set_current_state(TaskState.RUNNING)
        try:
            # print(f"Begin run {node}")
            await node._run(dag_ctx)
            node_outputs[node.node_id] = dag_ctx.current_task_context
            task_ctx.set_current_state(TaskState.SUCCESS)
        except Exception as e:
            task_ctx.set_current_state(TaskState.FAILED)
            raise e
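A sketch of driving DefaultWorkflowRunner directly instead of going through BaseOperator.call; not part of the commit and untested against it, it assumes the DefaultInputContext mapping behaviour provided in task_impl.py below.

# Sketch (not part of this commit): running a workflow through the runner.
import asyncio

from pilot.awel import (
    DAG,
    DefaultWorkflowRunner,
    InputOperator,
    MapOperator,
    SimpleCallDataInputSource,
    TaskState,
)


async def main() -> None:
    with DAG("runner_example"):
        root = InputOperator(input_source=SimpleCallDataInputSource())
        end = MapOperator(map_function=lambda x: x + 1)
        root >> end

    runner = DefaultWorkflowRunner()
    dag_ctx = await runner.execute_workflow(end, call_data={"data": 41})

    # The returned DAGContext holds the final task context of the end node.
    task_ctx = dag_ctx.current_task_context
    assert task_ctx.current_state == TaskState.SUCCESS
    assert task_ctx.task_output.output == 42


asyncio.run(main())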
0
pilot/awel/task/__init__.py
Normal file
358
pilot/awel/task/base.py
Normal file
@@ -0,0 +1,358 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
    TypeVar,
    Generic,
    Optional,
    AsyncIterator,
    Union,
    Callable,
    Any,
    Dict,
    List,
)

IN = TypeVar("IN")
OUT = TypeVar("OUT")
T = TypeVar("T")


class TaskState(str, Enum):
    """Enumeration representing the state of a task in the workflow.

    This Enum defines various states a task can be in during its lifecycle in the DAG.
    """

    INIT = "init"  # Initial state of the task, not yet started
    SKIP = "skip"  # State indicating the task was skipped
    RUNNING = "running"  # State indicating the task is currently running
    SUCCESS = "success"  # State indicating the task completed successfully
    FAILED = "failed"  # State indicating the task failed during execution


class TaskOutput(ABC, Generic[T]):
    """Abstract base class representing the output of a task.

    This class encapsulates the output of a task and provides methods to access the output data.
    It can be subclassed to implement specific output behaviors.
    """

    @property
    def is_stream(self) -> bool:
        """Check if the output is a stream.

        Returns:
            bool: True if the output is a stream, False otherwise.
        """
        return False

    @property
    def is_empty(self) -> bool:
        """Check if the output is empty.

        Returns:
            bool: True if the output is empty, False otherwise.
        """
        return False

    @property
    def output(self) -> Optional[T]:
        """Return the output of the task.

        Returns:
            T: The output of the task. None if the output is empty.
        """
        raise NotImplementedError

    @property
    def output_stream(self) -> Optional[AsyncIterator[T]]:
        """Return the output of the task as an asynchronous stream.

        Returns:
            AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
        """
        raise NotImplementedError

    @abstractmethod
    def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
        """Set the output data to the current object.

        Args:
            output_data (Union[T, AsyncIterator[T]]): Output data.
        """

    async def map(self, map_func) -> "TaskOutput[T]":
        """Apply a mapping function to the task's output.

        Args:
            map_func: A function to apply to the task's output.

        Returns:
            TaskOutput[T]: The result of applying the mapping function.
        """
        raise NotImplementedError

    async def reduce(self, reduce_func) -> "TaskOutput[T]":
        """Apply a reducing function to the task's output.

        Reduce a stream TaskOutput to a non-stream TaskOutput.

        Args:
            reduce_func: A reducing function to apply to the task's output.

        Returns:
            TaskOutput[T]: The result of applying the reducing function.
        """
        raise NotImplementedError

    async def streamify(
        self, transform_func: Callable[[T], AsyncIterator[T]]
    ) -> "TaskOutput[T]":
        """Convert a value of type T to an AsyncIterator[T] using a transform function.

        Args:
            transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].

        Returns:
            TaskOutput[T]: The result of applying the transform function.
        """
        raise NotImplementedError

    async def transform_stream(
        self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
    ) -> "TaskOutput[T]":
        """Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.

        Args:
            transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].

        Returns:
            TaskOutput[T]: The result of applying the transform function.
        """
        raise NotImplementedError

    async def unstreamify(
        self, transform_func: Callable[[AsyncIterator[T]], T]
    ) -> "TaskOutput[T]":
        """Convert an AsyncIterator[T] to a value of type T using a transform function.

        Args:
            transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.

        Returns:
            TaskOutput[T]: The result of applying the transform function.
        """
        raise NotImplementedError

    async def check_condition(self, condition_func) -> bool:
        """Check if the current output meets a given condition.

        Args:
            condition_func: A function to determine if the condition is met.
        Returns:
            bool: True if the current output meets the condition, False otherwise.
        """
        raise NotImplementedError


class TaskContext(ABC, Generic[T]):
    """Abstract base class representing the context of a task within a DAG.

    This class provides the interface for accessing task-related information
    and manipulating task output.
    """

    @property
    @abstractmethod
    def task_id(self) -> str:
        """Return the unique identifier of the task.

        Returns:
            str: The unique identifier of the task.
        """

    @property
    @abstractmethod
    def task_input(self) -> "InputContext":
        """Return the InputContext of the current task.

        Returns:
            InputContext: The InputContext of the current task.
        """

    @abstractmethod
    def set_task_input(self, input_ctx: "InputContext") -> None:
        """Set the InputContext object to the current task.

        Args:
            input_ctx (InputContext): The InputContext of the current task
        """

    @property
    @abstractmethod
    def task_output(self) -> TaskOutput[T]:
        """Return the output object of the task.

        Returns:
            TaskOutput[T]: The output object of the task.
        """

    @abstractmethod
    def set_task_output(self, task_output: TaskOutput[T]) -> None:
        """Set the output object to the current task."""

    @property
    @abstractmethod
    def current_state(self) -> TaskState:
        """Get the current state of the task.

        Returns:
            TaskState: The current state of the task.
        """

    @abstractmethod
    def set_current_state(self, task_state: TaskState) -> None:
        """Set the current task state.

        Args:
            task_state (TaskState): The task state to be set.
        """

    @abstractmethod
    def new_ctx(self) -> "TaskContext":
        """Create a new task context.

        Returns:
            TaskContext: A new instance of a TaskContext.
        """

    @property
    @abstractmethod
    def metadata(self) -> Dict[str, Any]:
        """Get the metadata of the current task.

        Returns:
            Dict[str, Any]: The metadata
        """

    def update_metadata(self, key: str, value: Any) -> None:
        """Update metadata with key and value.

        Args:
            key (str): The key of metadata
            value (str): The value to be added to metadata
        """
        self.metadata[key] = value

    @property
    def call_data(self) -> Optional[Dict]:
        """Get the call data for the current task"""
        return self.metadata.get("call_data")

    def set_call_data(self, call_data: Dict) -> None:
        """Set call data for the current task"""
        self.update_metadata("call_data", call_data)


class InputContext(ABC):
    """Abstract base class representing the context of inputs to an operator node.

    This class defines methods to manipulate and access the inputs for an operator node.
    """

    @property
    @abstractmethod
    def parent_outputs(self) -> List[TaskContext]:
        """Get the outputs from the parent nodes.

        Returns:
            List[TaskContext]: A list of contexts of the parent nodes' outputs.
        """

    @abstractmethod
    async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
        """Apply a mapping function to the inputs.

        Args:
            map_func (Callable[[Any], Any]): A function to be applied to the inputs.

        Returns:
            InputContext: A new InputContext instance with the mapped inputs.
        """

    @abstractmethod
    async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
        """Apply a mapping function to all inputs.

        Args:
            map_func (Callable[..., Any]): A function to be applied to all inputs.

        Returns:
            InputContext: A new InputContext instance with the mapped inputs.
        """

    @abstractmethod
    async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
        """Apply a reducing function to the inputs.

        Args:
            reduce_func (Callable[[Any], Any]): A function that reduces the inputs.

        Returns:
            InputContext: A new InputContext instance with the reduced inputs.
        """

    @abstractmethod
    async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
        """Filter the inputs based on a provided function.

        Args:
            filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.

        Returns:
            InputContext: A new InputContext instance with the filtered inputs.
        """

    @abstractmethod
    async def predicate_map(
        self, predicate_func: Callable[[Any], bool], failed_value: Any = None
    ) -> "InputContext":
        """Apply a predicate function to the inputs.

        Args:
            predicate_func (Callable[[Any], bool]): A function that returns True for inputs that satisfy the predicate.
            failed_value (Any): The value to be set if the return value of the predicate function is False
        Returns:
            InputContext: A new InputContext instance with the predicate applied to the inputs.
        """

    def check_single_parent(self) -> bool:
        """Check if there is only a single parent output.

        Returns:
            bool: True if there is only one parent output, False otherwise.
        """
        return len(self.parent_outputs) == 1

    def check_stream(self) -> bool:
        """Check if all parent outputs are streams.

        Returns:
            bool: True if all parent outputs are streams, False otherwise.
        """
        for out in self.parent_outputs:
            if not (out.task_output and out.task_output.is_stream):
                return False
        return True


class InputSource(ABC, Generic[T]):
    """Abstract base class representing the source of inputs to a DAG node."""

    @abstractmethod
    async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
        """Read the data from the current input source.

        Returns:
            TaskOutput[T]: The output object read from the current source
        """
318
pilot/awel/task/task_impl.py
Normal file
@@ -0,0 +1,318 @@
from abc import ABC, abstractmethod
from typing import (
    Callable,
    Coroutine,
    Iterator,
    AsyncIterator,
    List,
    Generic,
    TypeVar,
    Any,
    Tuple,
    Dict,
    Union,
)
import asyncio

from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T


async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
    # Init accumulator
    try:
        accumulator = await stream.__anext__()
    except StopAsyncIteration:
        raise ValueError("Stream is empty")
    is_async = asyncio.iscoroutinefunction(reduce_function)
    async for element in stream:
        if is_async:
            accumulator = await reduce_function(accumulator, element)
        else:
            accumulator = reduce_function(accumulator, element)
    return accumulator


class SimpleTaskOutput(TaskOutput[T], Generic[T]):
    def __init__(self, data: T) -> None:
        super().__init__()
        self._data = data

    @property
    def output(self) -> T:
        return self._data

    def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
        self._data = output_data

    @property
    def is_empty(self) -> bool:
        return not self._data

    async def _apply_func(self, func) -> Any:
        if asyncio.iscoroutinefunction(func):
            out = await func(self._data)
        else:
            out = func(self._data)
        return out

    async def map(self, map_func) -> TaskOutput[T]:
        out = await self._apply_func(map_func)
        return SimpleTaskOutput(out)

    async def check_condition(self, condition_func) -> bool:
        return await self._apply_func(condition_func)

    async def streamify(
        self, transform_func: Callable[[T], AsyncIterator[T]]
    ) -> TaskOutput[T]:
        out = await self._apply_func(transform_func)
        return SimpleStreamTaskOutput(out)


class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
    def __init__(self, data: AsyncIterator[T]) -> None:
        super().__init__()
        self._data = data

    @property
    def is_stream(self) -> bool:
        return True

    @property
    def is_empty(self) -> bool:
        return not self._data

    @property
    def output_stream(self) -> AsyncIterator[T]:
        return self._data

    def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
        self._data = output_data

    async def map(self, map_func) -> TaskOutput[T]:
        is_async = asyncio.iscoroutinefunction(map_func)

        async def new_iter() -> AsyncIterator[T]:
            async for out in self._data:
                if is_async:
                    out = await map_func(out)
                else:
                    out = map_func(out)
                yield out

        return SimpleStreamTaskOutput(new_iter())

    async def reduce(self, reduce_func) -> TaskOutput[T]:
        out = await _reduce_stream(self._data, reduce_func)
        return SimpleTaskOutput(out)

    async def unstreamify(
        self, transform_func: Callable[[AsyncIterator[T]], T]
    ) -> TaskOutput[T]:
        if asyncio.iscoroutinefunction(transform_func):
            out = await transform_func(self._data)
        else:
            out = transform_func(self._data)
        return SimpleTaskOutput(out)

    async def transform_stream(
        self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
    ) -> TaskOutput[T]:
        if asyncio.iscoroutinefunction(transform_func):
            out = await transform_func(self._data)
        else:
            out = transform_func(self._data)
        return SimpleStreamTaskOutput(out)


def _is_async_iterator(obj):
    return (
        hasattr(obj, "__anext__")
        and callable(getattr(obj, "__anext__", None))
        and hasattr(obj, "__aiter__")
        and callable(getattr(obj, "__aiter__", None))
    )


class BaseInputSource(InputSource, ABC):
    def __init__(self) -> None:
        super().__init__()
        self._is_read = False

    @abstractmethod
    def _read_data(self, task_ctx: TaskContext) -> Any:
        """Read data with task context"""

    async def read(self, task_ctx: TaskContext) -> Coroutine[Any, Any, TaskOutput]:
        data = self._read_data(task_ctx)
        if _is_async_iterator(data):
            if self._is_read:
                raise ValueError(f"Input iterator {data} has been read!")
            output = SimpleStreamTaskOutput(data)
        else:
            output = SimpleTaskOutput(data)
        self._is_read = True
        return output


class SimpleInputSource(BaseInputSource):
    def __init__(self, data: Any) -> None:
        super().__init__()
        self._data = data

    def _read_data(self, task_ctx: TaskContext) -> Any:
        return self._data


class SimpleCallDataInputSource(BaseInputSource):
    def __init__(self) -> None:
        super().__init__()

    def _read_data(self, task_ctx: TaskContext) -> Any:
        call_data = task_ctx.call_data
        data = call_data.get("data") if call_data else None
        if not (call_data and data):
            raise ValueError("No call data for current SimpleCallDataInputSource")
        return data


class DefaultTaskContext(TaskContext, Generic[T]):
    def __init__(
        self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
    ) -> None:
        super().__init__()
        self._task_id = task_id
        self._task_state = task_state
        self._output = task_output
        self._task_input = None
        self._metadata = {}

    @property
    def task_id(self) -> str:
        return self._task_id

    @property
    def task_input(self) -> InputContext:
        return self._task_input

    def set_task_input(self, input_ctx: "InputContext") -> None:
        self._task_input = input_ctx

    @property
    def task_output(self) -> TaskOutput:
        return self._output

    def set_task_output(self, task_output: TaskOutput) -> None:
        self._output = task_output

    @property
    def current_state(self) -> TaskState:
        return self._task_state

    def set_current_state(self, task_state: TaskState) -> None:
        self._task_state = task_state

    def new_ctx(self) -> TaskContext:
        return DefaultTaskContext(self._task_id, self._task_state, self._output)

    @property
    def metadata(self) -> Dict[str, Any]:
        return self._metadata


class DefaultInputContext(InputContext):
    def __init__(self, outputs: List[TaskContext]) -> None:
        super().__init__()
        self._outputs = outputs

    @property
    def parent_outputs(self) -> List[TaskContext]:
        return self._outputs

    async def _apply_func(
        self, func: Callable[[Any], Any], apply_type: str = "map"
    ) -> Tuple[List[TaskContext], List[TaskOutput]]:
        new_outputs: List[TaskContext] = []
        map_tasks = []
        for out in self._outputs:
            new_outputs.append(out.new_ctx())
            result = None
            if apply_type == "map":
                result = out.task_output.map(func)
            elif apply_type == "reduce":
|
result = out.task_output.reduce(func)
|
||||||
|
elif apply_type == "check_condition":
|
||||||
|
result = out.task_output.check_condition(func)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupport apply type {apply_type}")
|
||||||
|
map_tasks.append(result)
|
||||||
|
results = await asyncio.gather(*map_tasks)
|
||||||
|
return new_outputs, results
|
||||||
|
|
||||||
|
async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
|
||||||
|
new_outputs, results = await self._apply_func(map_func)
|
||||||
|
for i, task_ctx in enumerate(new_outputs):
|
||||||
|
task_ctx: TaskContext = task_ctx
|
||||||
|
task_ctx.set_task_output(results[i])
|
||||||
|
return DefaultInputContext(new_outputs)
|
||||||
|
|
||||||
|
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
|
||||||
|
if not self._outputs:
|
||||||
|
return DefaultInputContext([])
|
||||||
|
is_steam = self._outputs[0].task_output.is_stream
|
||||||
|
if is_steam:
|
||||||
|
if not self.check_stream():
|
||||||
|
raise ValueError(
|
||||||
|
"The output in all tasks must has same output format to map_all"
|
||||||
|
)
|
||||||
|
outputs = []
|
||||||
|
for out in self._outputs:
|
||||||
|
if out.task_output.is_stream:
|
||||||
|
outputs.append(out.task_output.output_stream)
|
||||||
|
else:
|
||||||
|
outputs.append(out.task_output.output)
|
||||||
|
if asyncio.iscoroutinefunction(map_func):
|
||||||
|
map_res = await map_func(*outputs)
|
||||||
|
else:
|
||||||
|
map_res = map_func(*outputs)
|
||||||
|
|
||||||
|
single_output: TaskContext = self._outputs[0].new_ctx()
|
||||||
|
single_output.task_output.set_output(map_res)
|
||||||
|
return DefaultInputContext([single_output])
|
||||||
|
|
||||||
|
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
|
||||||
|
if not self.check_stream():
|
||||||
|
raise ValueError(
|
||||||
|
"The output in all tasks must has same output format of stream to apply reduce function"
|
||||||
|
)
|
||||||
|
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
|
||||||
|
for i, task_ctx in enumerate(new_outputs):
|
||||||
|
task_ctx: TaskContext = task_ctx
|
||||||
|
task_ctx.set_task_output(results[i])
|
||||||
|
return DefaultInputContext(new_outputs)
|
||||||
|
|
||||||
|
async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
|
||||||
|
new_outputs, results = await self._apply_func(
|
||||||
|
filter_func, apply_type="check_condition"
|
||||||
|
)
|
||||||
|
result_outputs = []
|
||||||
|
for i, task_ctx in enumerate(new_outputs):
|
||||||
|
if results[i]:
|
||||||
|
result_outputs.append(task_ctx)
|
||||||
|
return DefaultInputContext(result_outputs)
|
||||||
|
|
||||||
|
async def predicate_map(
|
||||||
|
self, predicate_func: Callable[[Any], bool], failed_value: Any = None
|
||||||
|
) -> "InputContext":
|
||||||
|
new_outputs, results = await self._apply_func(
|
||||||
|
predicate_func, apply_type="check_condition"
|
||||||
|
)
|
||||||
|
result_outputs = []
|
||||||
|
for i, task_ctx in enumerate(new_outputs):
|
||||||
|
task_ctx: TaskContext = task_ctx
|
||||||
|
if results[i]:
|
||||||
|
result_outputs.append(task_ctx)
|
||||||
|
else:
|
||||||
|
task_ctx.task_output.set_output(failed_value)
|
||||||
|
result_outputs.append(task_ctx)
|
||||||
|
return DefaultInputContext(result_outputs)
|
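A minimal usage sketch (not part of the commit) of the task primitives above: it maps over a streaming output element by element and applies a function across parent outputs via an input context. The helper name `_demo` and the use of `TaskState.SUCCESS` as a placeholder state are assumptions for illustration only.

import asyncio


async def _demo():
    async def gen():
        for i in range(3):
            yield i

    # Map over a streaming output element by element.
    doubled = await SimpleStreamTaskOutput(gen()).map(lambda x: x * 2)
    print([x async for x in doubled.output_stream])  # [0, 2, 4]

    # Wrap a finished task's output and map over it through an input context.
    ctx = DefaultTaskContext("t1", TaskState.SUCCESS, SimpleTaskOutput(10))
    mapped = await DefaultInputContext([ctx]).map(lambda x: x + 1)
    print(mapped.parent_outputs[0].task_output.output)  # 11


# asyncio.run(_demo())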
0
pilot/awel/tests/__init__.py
Normal file
102
pilot/awel/tests/conftest.py
Normal file
@@ -0,0 +1,102 @@
import pytest
import pytest_asyncio
from typing import AsyncIterator, List
from contextlib import contextmanager, asynccontextmanager
from .. import (
    WorkflowRunner,
    InputOperator,
    DAGContext,
    TaskState,
    DefaultWorkflowRunner,
    SimpleInputSource,
)
from ..task.task_impl import _is_async_iterator


@pytest.fixture
def runner():
    return DefaultWorkflowRunner()


def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
    iters = []
    for _ in range(num_nodes):

        async def stream_iter():
            for i in range(10):
                yield i

        stream_iter = stream_iter()
        assert _is_async_iterator(stream_iter)
        iters.append(stream_iter)
    return iters


def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
    iters = []
    for single_stream in output_streams:

        # Bind the current stream as a default argument; a plain closure would
        # late-bind `single_stream`, so every iterator would replay the last list.
        async def stream_iter(single_stream=single_stream):
            for i in single_stream:
                yield i

        stream_iter = stream_iter()
        assert _is_async_iterator(stream_iter)
        iters.append(stream_iter)
    return iters


@asynccontextmanager
async def _create_input_node(**kwargs):
    num_nodes = kwargs.get("num_nodes")
    is_stream = kwargs.get("is_stream", False)
    if is_stream:
        outputs = kwargs.get("output_streams")
        if outputs:
            if num_nodes and num_nodes != len(outputs):
                raise ValueError(
                    f"num_nodes {num_nodes} != the length of output_streams {len(outputs)}"
                )
            outputs = _create_stream_from(outputs)
        else:
            num_nodes = num_nodes or 1
            outputs = _create_stream(num_nodes)
    else:
        outputs = kwargs.get("outputs", ["Hello."])
    nodes = []
    for output in outputs:
        print(f"output: {output}")
        input_source = SimpleInputSource(output)
        input_node = InputOperator(input_source)
        nodes.append(input_node)
    yield nodes


@pytest_asyncio.fixture
async def input_node(request):
    param = getattr(request, "param", {})
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes[0]


@pytest_asyncio.fixture
async def stream_input_node(request):
    param = getattr(request, "param", {})
    param["is_stream"] = True
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes[0]


@pytest_asyncio.fixture
async def input_nodes(request):
    param = getattr(request, "param", {})
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes


@pytest_asyncio.fixture
async def stream_input_nodes(request):
    param = getattr(request, "param", {})
    param["is_stream"] = True
    async with _create_input_node(**param) as input_nodes:
        yield input_nodes
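For reference, a small sketch (not part of the commit) of how a test feeds parameters into these fixtures through pytest's `indirect` mechanism; the parametrized value arrives as `request.param` inside the fixture above. The test name is illustrative only.

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "stream_input_node",
    [{"output_streams": [[1, 2, 3]]}],
    indirect=["stream_input_node"],
)
async def test_example(stream_input_node):
    # The fixture received {"output_streams": [[1, 2, 3]]} via request.param.
    assert stream_input_node is not None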
51
pilot/awel/tests/test_http_operator.py
Normal file
@@ -0,0 +1,51 @@
import pytest
from typing import List
from .. import (
    DAG,
    WorkflowRunner,
    DAGContext,
    TaskState,
    InputOperator,
    MapOperator,
    JoinOperator,
    BranchOperator,
    ReduceStreamOperator,
    SimpleInputSource,
)
from .conftest import (
    runner,
    input_node,
    input_nodes,
    stream_input_node,
    stream_input_nodes,
    _is_async_iterator,
)


def _register_dag_to_fastapi_app(dag):
    # TODO
    pass


@pytest.mark.asyncio
async def test_http_operator(runner: WorkflowRunner, stream_input_node: InputOperator):
    with DAG("test_map") as dag:
        pass
        # http_req_task = HttpRequestOperator(endpoint="/api/completions")
        # db_task = DBQueryOperator(table_name="user_info")
        # prompt_task = PromptTemplateOperator(
        #     system_prompt="You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
        # )
        # llm_task = ChatGPTLLMOperator(model="chatgpt-3.5")
        # output_parser_task = CommonOutputParserOperator()
        # http_res_task = HttpResponseOperator()
        # (
        #     http_req_task
        #     >> db_task
        #     >> prompt_task
        #     >> llm_task
        #     >> output_parser_task
        #     >> http_res_task
        # )

    _register_dag_to_fastapi_app(dag)
128
pilot/awel/tests/test_run_dag.py
Normal file
@@ -0,0 +1,128 @@
import pytest
from typing import List
from .. import (
    DAG,
    WorkflowRunner,
    DAGContext,
    TaskState,
    InputOperator,
    MapOperator,
    JoinOperator,
    BranchOperator,
    ReduceStreamOperator,
    SimpleInputSource,
)
from .conftest import (
    runner,
    input_node,
    input_nodes,
    stream_input_node,
    stream_input_nodes,
    _is_async_iterator,
)


@pytest.mark.asyncio
async def test_input_node(runner: WorkflowRunner):
    input_node = InputOperator(SimpleInputSource("hello"))
    res: DAGContext[str] = await runner.execute_workflow(input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    assert res.current_task_context.task_output.output == "hello"

    async def new_stream_iter(n: int):
        for i in range(n):
            yield i

    num_iter = 10
    stream_input_node = InputOperator(SimpleInputSource(new_stream_iter(num_iter)))
    res: DAGContext[str] = await runner.execute_workflow(stream_input_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    output_stream = res.current_task_context.task_output.output_stream
    assert output_stream
    assert _is_async_iterator(output_stream)
    i = 0
    async for x in output_stream:
        assert x == i
        i += 1


@pytest.mark.asyncio
async def test_map_node(runner: WorkflowRunner, stream_input_node: InputOperator):
    with DAG("test_map") as dag:
        map_node = MapOperator(lambda x: x * 2)
        stream_input_node >> map_node
    res: DAGContext[int] = await runner.execute_workflow(map_node)
    output_stream = res.current_task_context.task_output.output_stream
    assert output_stream
    i = 0
    async for x in output_stream:
        assert x == i * 2
        i += 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "stream_input_node, expect_sum",
    [
        ({"output_streams": [[0, 1, 2, 3]]}, 6),
        ({"output_streams": [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]}, 55),
    ],
    indirect=["stream_input_node"],
)
async def test_reduce_node(
    runner: WorkflowRunner, stream_input_node: InputOperator, expect_sum: int
):
    with DAG("test_reduce_node") as dag:
        reduce_node = ReduceStreamOperator(lambda x, y: x + y)
        stream_input_node >> reduce_node
    res: DAGContext[int] = await runner.execute_workflow(reduce_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    assert not res.current_task_context.task_output.is_stream
    assert res.current_task_context.task_output.output == expect_sum


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "input_nodes",
    [
        ({"outputs": [0, 1, 2]}),
    ],
    indirect=["input_nodes"],
)
async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator]):
    def join_func(p1, p2, p3) -> int:
        return p1 + p2 + p3

    with DAG("test_join_node") as dag:
        join_node = JoinOperator(join_func)
        for input_node in input_nodes:
            input_node >> join_node
    res: DAGContext[int] = await runner.execute_workflow(join_node)
    assert res.current_task_context.current_state == TaskState.SUCCESS
    assert not res.current_task_context.task_output.is_stream
    assert res.current_task_context.task_output.output == 3


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "input_node, is_odd",
    [
        ({"outputs": [0]}, False),
        # ({"outputs": [1]}, False),
    ],
    indirect=["input_node"],
)
async def test_branch_node(
    runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
):
    with DAG("test_join_node") as dag:
        odd_node = MapOperator(lambda x: 999, task_id="odd_node")
        even_node = MapOperator(lambda x: 888, task_id="even_node")
        branch_node = BranchOperator(
            {lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
        )
        input_node >> branch_node

    odd_res: DAGContext[int] = await runner.execute_workflow(odd_node)
    even_res: DAGContext[int] = await runner.execute_workflow(even_node)
    assert branch_node.current_task_context.current_state == TaskState.SUCCESS
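The tests above exercise each operator in isolation; below is a hedged sketch (not part of the commit) of chaining the same pieces into one pipeline. It reuses only names defined in this commit; chaining several `>>` operators in one expression follows the commented example in test_http_operator.py and is assumed to be supported.

import asyncio


async def _pipeline_demo():
    async def numbers():
        for i in range(5):
            yield i

    runner = DefaultWorkflowRunner()
    with DAG("demo_pipeline") as dag:
        input_node = InputOperator(SimpleInputSource(numbers()))
        map_node = MapOperator(lambda x: x * 10)
        reduce_node = ReduceStreamOperator(lambda x, y: x + y)
        input_node >> map_node >> reduce_node

    res = await runner.execute_workflow(reduce_node)
    # Sum of 0, 10, 20, 30, 40.
    print(res.current_task_context.task_output.output)  # 100


# asyncio.run(_pipeline_demo())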
10
pilot/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,10 @@
from pilot.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
from pilot.cache.manager import CacheManager, initialize_cache

__all__ = [
    "LLMCacheKey",
    "LLMCacheValue",
    "LLMCacheClient",
    "CacheManager",
    "initialize_cache",
]
161
pilot/cache/base.py
vendored
Normal file
@@ -0,0 +1,161 @@
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, TypeVar, Generic, Optional, Type, Dict
from dataclasses import dataclass
from enum import Enum

T = TypeVar("T", bound="Serializable")

K = TypeVar("K")
V = TypeVar("V")


class Serializable(ABC):
    @abstractmethod
    def serialize(self) -> bytes:
        """Convert the object into bytes for storage or transmission.

        Returns:
            bytes: The byte array after serialization
        """

    @abstractmethod
    def to_dict(self) -> Dict:
        """Convert the object's state to a dictionary."""

    # @staticmethod
    # @abstractclassmethod
    # def from_dict(cls: Type["Serializable"], obj_dict: Dict) -> "Serializable":
    #     """Deserialize a dictionary to a Serializable object."""


class RetrievalPolicy(str, Enum):
    EXACT_MATCH = "exact_match"
    SIMILARITY_MATCH = "similarity_match"


class CachePolicy(str, Enum):
    LRU = "lru"
    FIFO = "fifo"


@dataclass
class CacheConfig:
    retrieval_policy: Optional[RetrievalPolicy] = RetrievalPolicy.EXACT_MATCH
    cache_policy: Optional[CachePolicy] = CachePolicy.LRU


class CacheKey(Serializable, ABC, Generic[K]):
    """The key of the cache. Must be hashable and comparable.

    Supported cache keys:
    - The LLM cache key: includes the user prompt and the parameters to the LLM.
    - The embedding model cache key: includes the texts to embed and the parameters to the embedding model.
    """

    @abstractmethod
    def __hash__(self) -> int:
        """Return the hash value of the key."""

    @abstractmethod
    def __eq__(self, other: Any) -> bool:
        """Check equality with another key."""

    @abstractmethod
    def get_hash_bytes(self) -> bytes:
        """Return the byte array of the hash value."""

    @abstractmethod
    def get_value(self) -> K:
        """Get the underlying value of the cache key.

        Returns:
            K: The real object of the current cache key
        """


class CacheValue(Serializable, ABC, Generic[V]):
    """Cache value abstract class."""

    @abstractmethod
    def get_value(self) -> V:
        """Get the underlying real value."""


class Serializer(ABC):
    """The serializer abstract class for serializing cache keys and values."""

    @abstractmethod
    def serialize(self, obj: Serializable) -> bytes:
        """Serialize a cache object.

        Args:
            obj (Serializable): The object to serialize
        """

    @abstractmethod
    def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
        """Deserialize data back into a cache object of the specified type.

        Args:
            data (bytes): The byte array to deserialize
            cls (Type[Serializable]): The type of the current object

        Returns:
            Serializable: The deserialized object
        """


class CacheClient(ABC, Generic[K, V]):
    """The cache client interface."""

    @abstractmethod
    async def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[CacheValue[V]]:
        """Retrieve a value from the cache using the provided key.

        Args:
            key (CacheKey[K]): The key to get from the cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[CacheValue[V]]: The value retrieved for the key. If the cache key does not exist, return None.
        """

    @abstractmethod
    async def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key.

        Args:
            key (CacheKey[K]): The key to set in the cache
            value (CacheValue[V]): The value to set in the cache
            cache_config (Optional[CacheConfig]): Cache config
        """

    @abstractmethod
    async def exists(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> bool:
        """Check if a key exists in the cache.

        Args:
            key (CacheKey[K]): The key to check in the cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            bool: True if the key is in the cache, otherwise False
        """

    @abstractmethod
    def new_key(self, **kwargs) -> CacheKey[K]:
        """Create a cache key with params."""

    @abstractmethod
    def new_value(self, **kwargs) -> CacheValue[V]:
        """Create a cache value with params."""
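A minimal sketch (not part of the commit) of what a concrete key/value pair for this interface can look like; `StringCacheKey`, `StringCacheValue`, and the UTF-8 encoding are assumptions for illustration only, not types defined elsewhere in the code base.

import hashlib
from typing import Any, Dict


class StringCacheKey(CacheKey[str]):
    def __init__(self, text: str) -> None:
        self._text = text

    def __hash__(self) -> int:
        return hash(self._text)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, StringCacheKey) and self._text == other._text

    def get_hash_bytes(self) -> bytes:
        return hashlib.sha256(self._text.encode("utf-8")).digest()

    def get_value(self) -> str:
        return self._text

    def to_dict(self) -> Dict:
        return {"text": self._text}

    def serialize(self) -> bytes:
        return self._text.encode("utf-8")


class StringCacheValue(CacheValue[str]):
    def __init__(self, value: str) -> None:
        self._value = value

    def get_value(self) -> str:
        return self._value

    def to_dict(self) -> Dict:
        return {"value": self._value}

    def serialize(self) -> bytes:
        return self._value.encode("utf-8")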
0
pilot/cache/embedding_cache.py
vendored
Normal file
148
pilot/cache/llm_cache.py
vendored
Normal file
@@ -0,0 +1,148 @@
from typing import Optional, Dict, Any, Union, List
from dataclasses import dataclass, asdict
import json
import hashlib

from pilot.cache.base import CacheKey, CacheValue, Serializer, CacheClient, CacheConfig
from pilot.cache.manager import CacheManager
from pilot.cache.storage.base import CacheStorage
from pilot.model.base import ModelType, ModelOutput


@dataclass
class LLMCacheKeyData:
    prompt: str
    model_name: str
    temperature: Optional[float] = 0.7
    max_new_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
    model_type: Optional[str] = ModelType.HF


CacheOutputType = Union[ModelOutput, List[ModelOutput]]


@dataclass
class LLMCacheValueData:
    output: CacheOutputType
    user: Optional[str] = None
    _is_list: Optional[bool] = False

    @staticmethod
    def from_dict(**kwargs) -> "LLMCacheValueData":
        output = kwargs.get("output")
        if not output:
            raise ValueError("Can't create a new LLMCacheValueData object, output is None")
        if isinstance(output, dict):
            output = ModelOutput(**output)
        elif isinstance(output, list):
            kwargs["_is_list"] = True
            output_list = []
            for out in output:
                if isinstance(out, dict):
                    out = ModelOutput(**out)
                output_list.append(out)
            output = output_list
        kwargs["output"] = output
        return LLMCacheValueData(**kwargs)

    def to_dict(self) -> Dict:
        output = self.output
        is_list = False
        if isinstance(output, list):
            output_list = []
            is_list = True
            for out in output:
                output_list.append(out.to_dict())
            output = output_list
        else:
            output = output.to_dict()
        return {"output": output, "_is_list": is_list, "user": self.user}

    @property
    def is_list(self) -> bool:
        return self._is_list

    def __str__(self) -> str:
        if not isinstance(self.output, list):
            return f"user: {self.user}, output: {self.output}"
        else:
            return f"user: {self.user}, output (last two items): {self.output[-2:]}"


class LLMCacheKey(CacheKey[LLMCacheKeyData]):
    def __init__(self, serializer: Serializer = None, **kwargs) -> None:
        super().__init__()
        self._serializer = serializer
        self.config = LLMCacheKeyData(**kwargs)

    def __hash__(self) -> int:
        serialize_bytes = self.serialize()
        return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, LLMCacheKey):
            return False
        return self.config == other.config

    def get_hash_bytes(self) -> bytes:
        serialize_bytes = self.serialize()
        return hashlib.sha256(serialize_bytes).digest()

    def to_dict(self) -> Dict:
        return asdict(self.config)

    def serialize(self) -> bytes:
        return self._serializer.serialize(self)

    def get_value(self) -> LLMCacheKeyData:
        return self.config


class LLMCacheValue(CacheValue[LLMCacheValueData]):
    def __init__(self, serializer: Serializer = None, **kwargs) -> None:
        super().__init__()
        self._serializer = serializer
        self.value = LLMCacheValueData.from_dict(**kwargs)

    def to_dict(self) -> Dict:
        return self.value.to_dict()

    def serialize(self) -> bytes:
        return self._serializer.serialize(self)

    def get_value(self) -> LLMCacheValueData:
        return self.value

    def __str__(self) -> str:
        return f"value: {str(self.value)}"


class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
    def __init__(self, cache_manager: CacheManager) -> None:
        super().__init__()
        self._cache_manager: CacheManager = cache_manager

    async def get(
        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
    ) -> Optional[LLMCacheValue]:
        return await self._cache_manager.get(key, LLMCacheValue, cache_config)

    async def set(
        self,
        key: LLMCacheKey,
        value: LLMCacheValue,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        return await self._cache_manager.set(key, value, cache_config)

    async def exists(
        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
    ) -> bool:
        return await self.get(key, cache_config) is not None

    def new_key(self, **kwargs) -> LLMCacheKey:
        return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)

    def new_value(self, **kwargs) -> LLMCacheValue:
        return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
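A hedged usage sketch (not part of the commit) of the client above. It assumes a `cache_manager` has already been registered (see `initialize_cache` below) and that `ModelOutput` accepts `text` and `error_code` keyword arguments; if the real constructor differs, adjust accordingly.

async def _llm_cache_demo(cache_manager: CacheManager):
    client = LLMCacheClient(cache_manager)

    # Build a key from the request parameters that identify a generation.
    key = client.new_key(prompt="Hello, AWEL!", model_name="vicuna-13b")

    cached = await client.get(key)
    if cached is None:
        # Normally the value comes from the model; this output is a placeholder.
        output = ModelOutput(text="Hi there!", error_code=0)
        await client.set(key, client.new_value(output=output))
    else:
        print(cached.get_value().output)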
118
pilot/cache/manager.py
vendored
Normal file
@@ -0,0 +1,118 @@
from abc import ABC, abstractmethod
from typing import Optional, Type
import logging
from concurrent.futures import Executor
from pilot.cache.storage.base import CacheStorage, StorageItem
from pilot.cache.base import (
    K,
    V,
    CacheKey,
    CacheValue,
    CacheConfig,
    Serializer,
    Serializable,
)
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async

logger = logging.getLogger(__name__)


class CacheManager(BaseComponent, ABC):
    name = ComponentType.MODEL_CACHE_MANAGER

    def __init__(self, system_app: SystemApp | None = None):
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    @abstractmethod
    async def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ):
        """Set cache"""

    @abstractmethod
    async def get(
        self,
        key: CacheKey[K],
        cls: Type[Serializable],
        cache_config: Optional[CacheConfig] = None,
    ) -> CacheValue[V]:
        """Get cache with key"""

    @property
    @abstractmethod
    def serializer(self) -> Serializer:
        """Get cache serializer"""


class LocalCacheManager(CacheManager):
    def __init__(
        self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
    ) -> None:
        super().__init__(system_app)
        self._serializer = serializer
        self._storage = storage

    @property
    def executor(self) -> Executor:
        """Return the executor to submit tasks to."""
        self._executor = self.system_app.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()
        return self._executor

    async def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ):
        if self._storage.support_async():
            await self._storage.aset(key, value, cache_config)
        else:
            await blocking_func_to_async(
                self.executor, self._storage.set, key, value, cache_config
            )

    async def get(
        self,
        key: CacheKey[K],
        cls: Type[Serializable],
        cache_config: Optional[CacheConfig] = None,
    ) -> CacheValue[V]:
        if self._storage.support_async():
            item_bytes = await self._storage.aget(key, cache_config)
        else:
            item_bytes = await blocking_func_to_async(
                self.executor, self._storage.get, key, cache_config
            )
        if not item_bytes:
            return None
        return self._serializer.deserialize(item_bytes.value_data, cls)

    @property
    def serializer(self) -> Serializer:
        return self._serializer


def initialize_cache(system_app: SystemApp, persist_dir: str):
    from pilot.cache.protocal.json_protocal import JsonSerializer
    from pilot.cache.storage.base import MemoryCacheStorage

    try:
        from pilot.cache.storage.disk.disk_storage import DiskCacheStorage

        cache_storage = DiskCacheStorage(persist_dir)
    except ImportError as e:
        logger.warning(
            f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
        )
        cache_storage = MemoryCacheStorage()
    system_app.register(
        LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
    )
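A short sketch (not part of the commit) of how an application might wire the cache in and fetch it back as a component. `MODEL_DISK_CACHE_DIR` comes from the constants change later in this commit, but its module path and the no-argument `SystemApp()` construction are assumptions for illustration.

from pilot.component import SystemApp, ComponentType
from pilot.configs.model_config import MODEL_DISK_CACHE_DIR  # assumed module path

system_app = SystemApp()  # assumed: the real app is constructed by the server
initialize_cache(system_app, persist_dir=MODEL_DISK_CACHE_DIR)

# Later, any component can look the manager up by its registered type.
cache_manager: CacheManager = system_app.get_component(
    ComponentType.MODEL_CACHE_MANAGER, CacheManager
)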
0
pilot/cache/protocal/__init__.py
vendored
Normal file
44
pilot/cache/protocal/json_protocal.py
vendored
Normal file
@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from typing import Dict, Type
import json

from pilot.cache.base import Serializable, Serializer

JSON_ENCODING = "utf-8"


class JsonSerializable(Serializable, ABC):
    @abstractmethod
    def to_dict(self) -> Dict:
        """Return the dict of the current serializable object"""

    def serialize(self) -> bytes:
        """Convert the object into bytes for storage or transmission."""
        return json.dumps(self.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)


class JsonSerializer(Serializer):
    """A JSON-based serializer for cache keys and values."""

    def serialize(self, obj: Serializable) -> bytes:
        """Serialize a cache object.

        Args:
            obj (Serializable): The object to serialize
        """
        return json.dumps(obj.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)

    def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
        """Deserialize data back into a cache object of the specified type.

        Args:
            data (bytes): The byte array to deserialize
            cls (Type[Serializable]): The type of the current object

        Returns:
            Serializable: The deserialized object
        """
        # Convert bytes back to JSON and then to the specified class
        json_data = json.loads(data.decode(JSON_ENCODING))
        # Assume that the cls has an __init__ that accepts a dictionary
        return cls(**json_data)
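A round-trip sketch (not part of the commit) showing how the serializer pairs with the deserialization contract above; the `Greeting` dataclass is purely illustrative.

from dataclasses import dataclass
from typing import Dict


@dataclass
class Greeting(JsonSerializable):
    message: str

    def to_dict(self) -> Dict:
        return {"message": self.message}


serializer = JsonSerializer()
data = serializer.serialize(Greeting(message="hello"))
restored = serializer.deserialize(data, Greeting)
assert restored.message == "hello"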
0
pilot/cache/storage/__init__.py
vendored
Normal file
252
pilot/cache/storage/base.py
vendored
Normal file
@@ -0,0 +1,252 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
from collections import OrderedDict
import msgpack
import logging

from pilot.cache.base import (
    K,
    V,
    CacheKey,
    CacheValue,
    CacheClient,
    CacheConfig,
    RetrievalPolicy,
    CachePolicy,
)
from pilot.utils.memory_utils import _get_object_bytes

logger = logging.getLogger(__name__)


@dataclass
class StorageItem:
    """
    A class representing a storage item.

    This class encapsulates data related to a storage item, such as its length,
    the hash of the key, and the data for both the key and value.

    Parameters:
        length (int): The byte length of the storage item.
        key_hash (bytes): The hash value of the storage item's key.
        key_data (bytes): The data of the storage item's key, represented in bytes.
        value_data (bytes): The data of the storage item's value, also in bytes.
    """

    length: int  # The byte length of the storage item
    key_hash: bytes  # The hash value of the storage item's key
    key_data: bytes  # The data of the storage item's key
    value_data: bytes  # The data of the storage item's value

    @staticmethod
    def build_from(
        key_hash: bytes, key_data: bytes, value_data: bytes
    ) -> "StorageItem":
        length = (
            32
            + _get_object_bytes(key_hash)
            + _get_object_bytes(key_data)
            + _get_object_bytes(value_data)
        )
        return StorageItem(
            length=length, key_hash=key_hash, key_data=key_data, value_data=value_data
        )

    @staticmethod
    def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
        key_hash = key.get_hash_bytes()
        key_data = key.serialize()
        value_data = value.serialize()
        return StorageItem.build_from(key_hash, key_data, value_data)

    def serialize(self) -> bytes:
        """Serialize the StorageItem into a byte stream using MessagePack.

        This method packs the object data into a dictionary, marking the
        key_data and value_data fields as raw binary data to avoid re-serialization.

        Returns:
            bytes: The serialized bytes.
        """
        obj = {
            "length": self.length,
            "key_hash": msgpack.ExtType(1, self.key_hash),
            "key_data": msgpack.ExtType(2, self.key_data),
            "value_data": msgpack.ExtType(3, self.value_data),
        }
        return msgpack.packb(obj)

    @staticmethod
    def deserialize(data: bytes) -> "StorageItem":
        """Deserialize bytes back into a StorageItem using MessagePack.

        This extracts the fields from the MessagePack dict back into
        a StorageItem object.

        Args:
            data (bytes): Serialized bytes

        Returns:
            StorageItem: Deserialized StorageItem object.
        """
        obj = msgpack.unpackb(data)
        key_hash = obj["key_hash"].data
        key_data = obj["key_data"].data
        value_data = obj["value_data"].data

        return StorageItem(
            length=obj["length"],
            key_hash=key_hash,
            key_data=key_data,
            value_data=value_data,
        )


class CacheStorage(ABC):
    @abstractmethod
    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        """Check whether the CacheConfig is legal.

        Args:
            cache_config (Optional[CacheConfig]): Cache config.
            raise_error (Optional[bool]): Whether to raise an error if the config is illegal.

        Returns:
            bool: True if the config is legal, otherwise False.

        Raises:
            ValueError: Raised when raise_error is True and the config is illegal.
        """

    def support_async(self) -> bool:
        return False

    @abstractmethod
    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        """Retrieve a storage item from the cache using the provided key.

        Args:
            key (CacheKey[K]): The key to get from the cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[StorageItem]: The storage item retrieved for the key. If the cache key does not exist, return None.
        """

    async def aget(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        """Retrieve a storage item from the cache using the provided key asynchronously.

        Args:
            key (CacheKey[K]): The key to get from the cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[StorageItem]: The storage item retrieved for the key. If the cache key does not exist, return None.
        """
        raise NotImplementedError

    @abstractmethod
    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key.

        Args:
            key (CacheKey[K]): The key to set in the cache
            value (CacheValue[V]): The value to set in the cache
            cache_config (Optional[CacheConfig]): Cache config
        """

    async def aset(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key asynchronously.

        Args:
            key (CacheKey[K]): The key to set in the cache
            value (CacheValue[V]): The value to set in the cache
            cache_config (Optional[CacheConfig]): Cache config
        """
        raise NotImplementedError


class MemoryCacheStorage(CacheStorage):
    def __init__(self, max_memory_mb: int = 1024):
        self.cache = OrderedDict()
        self.max_memory = max_memory_mb * 1024 * 1024
        self.current_memory_usage = 0

    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        if (
            cache_config
            and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
        ):
            if raise_error:
                raise ValueError(
                    "MemoryCacheStorage only supports 'EXACT_MATCH' retrieval policy"
                )
            return False
        return True

    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        self.check_config(cache_config, raise_error=True)
        # Exact match retrieval
        key_hash = hash(key)
        item: StorageItem = self.cache.get(key_hash)
        logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")

        if not item:
            return None
        # Move the item to the end of the OrderedDict to signify recent use.
        self.cache.move_to_end(key_hash)
        return item

    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        key_hash = hash(key)
        item = StorageItem.build_from_kv(key, value)
        # Calculate the memory size of the new entry
        new_entry_size = _get_object_bytes(item)
        # Evict entries if necessary (and stop once the cache is empty).
        while self.cache and self.current_memory_usage + new_entry_size > self.max_memory:
            self._apply_cache_policy(cache_config)

        # Store the item in the cache.
        self.cache[key_hash] = item
        self.current_memory_usage += new_entry_size
        logger.debug(f"MemoryCacheStorage set key {key}, hash {key_hash}, item: {item}")

    def exists(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> bool:
        return self.get(key, cache_config) is not None

    def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
        # Remove the oldest/newest item based on the cache policy and release
        # its accounted memory so the eviction loop in `set` can terminate.
        if cache_config and cache_config.cache_policy == CachePolicy.FIFO:
            _, evicted = self.cache.popitem(last=False)
        else:  # Default is LRU
            _, evicted = self.cache.popitem(last=True)
        self.current_memory_usage -= _get_object_bytes(evicted)
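A small sketch (not part of the commit) that exercises `MemoryCacheStorage` directly with a throwaway key/value, mirroring the mock style used in the storage tests below; `DemoKey` and `DemoValue` are illustrative only.

class DemoKey:
    def __init__(self, name: str) -> None:
        self.name = name

    def __hash__(self) -> int:
        return hash(self.name)

    def get_hash_bytes(self) -> bytes:
        return self.name.encode("utf-8")

    def serialize(self) -> bytes:
        return self.name.encode("utf-8")


class DemoValue:
    def serialize(self) -> bytes:
        return b"demo-value"


storage = MemoryCacheStorage(max_memory_mb=1)
key, value = DemoKey("k1"), DemoValue()
storage.set(key, value)
item = storage.get(key)
assert item is not None and item.value_data == b"demo-value"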
0
pilot/cache/storage/disk/__init__.py
vendored
Normal file
93
pilot/cache/storage/disk/disk_storage.py
vendored
Normal file
@@ -0,0 +1,93 @@
from typing import Optional
import logging
from pilot.cache.base import (
    K,
    V,
    CacheKey,
    CacheValue,
    CacheConfig,
    RetrievalPolicy,
    CachePolicy,
)
from pilot.cache.storage.base import StorageItem, CacheStorage
from rocksdict import Rdict, Options, SliceTransform, PlainTableFactoryOptions

logger = logging.getLogger(__name__)


def db_options(
    mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
):
    opt = Options()
    # create table
    opt.create_if_missing(True)
    # config to more jobs, default 2
    opt.set_max_background_jobs(background_threads)
    # configure mem-table to a large value
    opt.set_write_buffer_size(mem_table_buffer_mb * 1024 * 1024)
    # opt.set_write_buffer_size(1024)
    # opt.set_level_zero_file_num_compaction_trigger(4)
    # configure l0 and l1 size, let them have the same size (1 GB)
    # opt.set_max_bytes_for_level_base(0x40000000)
    # 256 MB file size
    # opt.set_target_file_size_base(0x10000000)
    # use a smaller compaction multiplier
    # opt.set_max_bytes_for_level_multiplier(4.0)
    # use 8-byte prefix (2 ^ 64 is far enough for transaction counts)
    # opt.set_prefix_extractor(SliceTransform.create_max_len_prefix(8))
    # set to plain-table
    # opt.set_plain_table_factory(PlainTableFactoryOptions())
    return opt


class DiskCacheStorage(CacheStorage):
    def __init__(
        self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
    ) -> None:
        super().__init__()
        self.db: Rdict = Rdict(
            persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
        )

    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        if (
            cache_config
            and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
        ):
            if raise_error:
                raise ValueError(
                    "DiskCacheStorage only supports 'EXACT_MATCH' retrieval policy"
                )
            return False
        return True

    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        self.check_config(cache_config, raise_error=True)

        # Exact match retrieval
        key_hash = key.get_hash_bytes()
        item_bytes = self.db.get(key_hash)
        if not item_bytes:
            return None
        item = StorageItem.deserialize(item_bytes)
        logger.debug(f"Read file cache, key: {key}, storage item: {item}")
        return item

    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        item = StorageItem.build_from_kv(key, value)
        key_hash = item.key_hash
        self.db[key_hash] = item.serialize()
        logger.debug(f"Save file cache, key: {key}, value: {value}")
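For completeness, a hedged sketch (not part of the commit) of using the disk storage standalone; the `/tmp/dbgpt_cache_demo` path and the `DemoKey`/`DemoValue` helpers from the previous sketch are assumptions.

# Reuses the illustrative DemoKey / DemoValue classes from the sketch above.
disk_storage = DiskCacheStorage(persist_dir="/tmp/dbgpt_cache_demo")
disk_storage.set(DemoKey("k1"), DemoValue())
item = disk_storage.get(DemoKey("k1"))
assert item is not None and item.value_data == b"demo-value"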
0
pilot/cache/storage/tests/__init__.py
vendored
Normal file
53
pilot/cache/storage/tests/test_storage.py
vendored
Normal file
@@ -0,0 +1,53 @@
import pytest
from ..base import StorageItem
from pilot.utils.memory_utils import _get_object_bytes


def test_build_from():
    key_hash = b"key_hash"
    key_data = b"key_data"
    value_data = b"value_data"
    item = StorageItem.build_from(key_hash, key_data, value_data)

    assert item.key_hash == key_hash
    assert item.key_data == key_data
    assert item.value_data == value_data
    assert item.length == 32 + _get_object_bytes(key_hash) + _get_object_bytes(
        key_data
    ) + _get_object_bytes(value_data)


def test_build_from_kv():
    class MockCacheKey:
        def get_hash_bytes(self):
            return b"key_hash"

        def serialize(self):
            return b"key_data"

    class MockCacheValue:
        def serialize(self):
            return b"value_data"

    key = MockCacheKey()
    value = MockCacheValue()
    item = StorageItem.build_from_kv(key, value)

    assert item.key_hash == key.get_hash_bytes()
    assert item.key_data == key.serialize()
    assert item.value_data == value.serialize()


def test_serialize_deserialize():
    key_hash = b"key_hash"
    key_data = b"key_data"
    value_data = b"value_data"
    item = StorageItem.build_from(key_hash, key_data, value_data)

    serialized = item.serialize()
    deserialized = StorageItem.deserialize(serialized)

    assert deserialized.key_hash == item.key_hash
    assert deserialized.key_data == item.key_data
    assert deserialized.value_data == item.value_data
    assert deserialized.length == item.length
@@ -48,6 +48,7 @@ class ComponentType(str, Enum):
     MODEL_CONTROLLER = "dbgpt_model_controller"
     MODEL_REGISTRY = "dbgpt_model_registry"
     MODEL_API_SERVER = "dbgpt_model_api_server"
+    MODEL_CACHE_MANAGER = "dbgpt_model_cache_manager"
     AGENT_HUB = "dbgpt_agent_hub"
     EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
     TRACER = "dbgpt_tracer"
@@ -253,6 +253,11 @@ class Config(metaclass=Singleton):
         ### Temporary configuration
         self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
+
+        self.MODEL_CACHE_STORAGE: str = os.getenv("MODEL_CACHE_STORAGE")
+        self.MODEL_CACHE_STORAGE_DIST_DIR: str = os.getenv(
+            "MODEL_CACHE_STORAGE_DIST_DIR"
+        )
 
     def set_debug_mode(self, value: bool) -> None:
         """Set the debug mode value"""
         self.debug_mode = value
@@ -14,6 +14,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data")
 # nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
 PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
 FONT_DIR = os.path.join(PILOT_PATH, "fonts")
+MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache")
 
 current_directory = os.getcwd()
0
pilot/model/operator/__init__.py
Normal file
0
pilot/model/operator/__init__.py
Normal file
119
pilot/model/operator/model_operator.py
Normal file
119
pilot/model/operator/model_operator.py
Normal file
@@ -0,0 +1,119 @@
from typing import AsyncIterator, Dict
import logging

from pilot.awel import (
    StreamifyAbsOperator,
    MapOperator,
    TransformStreamAbsOperator,
)
from pilot.model.base import ModelOutput
from pilot.model.cluster import WorkerManager
from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue

logger = logging.getLogger(__name__)

_LLM_MODEL_INPUT_VALUE_KEY = "llm_model_input_value"
_LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"


class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
    def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
        super().__init__(**kwargs)
        self.worker_manager = worker_manager

    async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
        llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
            _LLM_MODEL_OUTPUT_CACHE_KEY
        )
        logger.info(f"llm_cache_value: {llm_cache_value}")
        if llm_cache_value:
            for out in llm_cache_value.get_value().output:
                yield out
            return
        async for out in self.worker_manager.generate_stream(input_value):
            yield out


class ModelOperator(MapOperator[Dict, ModelOutput]):
    def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
        self.worker_manager = worker_manager
        super().__init__(**kwargs)

    async def map(self, input_value: Dict) -> ModelOutput:
        llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
            _LLM_MODEL_OUTPUT_CACHE_KEY
        )
        logger.info(f"llm_cache_value: {llm_cache_value}")
        if llm_cache_value:
            return llm_cache_value.get_value().output
        return await self.worker_manager.generate(input_value)


class ModelCachePreOperator(MapOperator[Dict, Dict]):
    def __init__(self, cache_manager: CacheManager, **kwargs):
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)
        super().__init__(**kwargs)

    async def map(self, input_value: Dict) -> Dict:
        cache_dict = {
            "prompt": input_value.get("prompt"),
            "model_name": input_value.get("model"),
            "temperature": input_value.get("temperature"),
            "max_new_tokens": input_value.get("max_new_tokens"),
            "top_p": input_value.get("top_p", "1.0"),
            # TODO pass model_type
            "model_type": input_value.get("model_type", "huggingface"),
        }
        cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
        cache_value = await self._client.get(cache_key)
        logger.debug(
            f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
        )
        await self.current_dag_context.save_to_share_data(
            _LLM_MODEL_INPUT_VALUE_KEY, cache_key
        )
        if cache_value:
            logger.info(f"The model output has been cached, cache_value: {cache_value}")
            await self.current_dag_context.save_to_share_data(
                _LLM_MODEL_OUTPUT_CACHE_KEY, cache_value
            )
        return input_value


class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutput]):
    def __init__(self, cache_manager: CacheManager, **kwargs):
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)
        super().__init__(**kwargs)

    async def transform_stream(
        self, input_value: AsyncIterator[ModelOutput]
    ) -> AsyncIterator[ModelOutput]:
        llm_cache_key: LLMCacheKey = None
        outputs = []
        async for out in input_value:
            if not llm_cache_key:
                llm_cache_key = await self.current_dag_context.get_share_data(
                    _LLM_MODEL_INPUT_VALUE_KEY
                )
            outputs.append(out)
            yield out
        if llm_cache_key:
            llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
            await self._client.set(llm_cache_key, llm_cache_value)


class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
    def __init__(self, cache_manager: CacheManager, **kwargs):
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)
        super().__init__(**kwargs)

    async def map(self, input_value: ModelOutput) -> ModelOutput:
        llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
            _LLM_MODEL_INPUT_VALUE_KEY
        )
        llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
        if llm_cache_key:
            await self._client.set(llm_cache_key, llm_cache_value)
        return input_value
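Taken together, these operators are meant to be chained into a small caching DAG: a pre-operator computes the cache key and publishes any hit into the DAG's shared data, the model operator either replays the cached output or calls the worker manager, and a tail operator writes fresh output back to the cache. The sketch below is a hypothetical smoke test of the non-streaming path, not part of this commit; it assumes a running system that can supply a WorkerManager and a CacheManager, and the payload values (model name, temperature) are illustrative only.

# Hypothetical smoke test of the non-streaming cache path (not part of this commit).
from pilot.awel import DAG, InputOperator, SimpleCallDataInputSource
from pilot.model.operator.model_operator import (
    ModelCachePreOperator,
    ModelOperator,
    ModelCacheOperator,
)


async def _cache_smoke_test(worker_manager, cache_manager):
    with DAG("llm_cache_smoke_test_dag"):
        input_node = InputOperator(SimpleCallDataInputSource())
        cache_check_node = ModelCachePreOperator(cache_manager)
        model_node = ModelOperator(worker_manager)
        cache_node = ModelCacheOperator(cache_manager)
        input_node >> cache_check_node >> model_node >> cache_node

    # Illustrative payload; real values come from BaseChat's prompt pipeline.
    payload = {"prompt": "Hello", "model": "vicuna-13b-v1.5", "temperature": 0.7}
    # The first call misses the cache and reaches the worker manager; an identical
    # second call should be answered from the cache via the shared DAG context.
    first = await cache_node.call(call_data={"data": payload})
    second = await cache_node.call(call_data={"data": payload})
    return first, second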
@@ -16,6 +16,8 @@ from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
 from pilot.utils.tracer import root_tracer, trace
 from pydantic import Extra
 from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
+from pilot.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG
+from pilot.model.operator.model_operator import ModelOperator, ModelStreamOperator
 
 logger = logging.getLogger(__name__)
 headers = {"User-Agent": "dbgpt Client"}
@@ -88,6 +90,11 @@ class BaseChat(ABC):
             ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
         ).create()
+
+        self._model_operator: BaseOperator = _build_model_operator()
+        self._model_stream_operator: BaseOperator = _build_model_operator(
+            is_stream=True, dag_name="llm_stream_model_dag"
+        )
 
     class Config:
         """Configuration for this pydantic object."""
@@ -204,12 +211,9 @@ class BaseChat(ABC):
         )
         payload["span_id"] = span.span_id
         try:
-            from pilot.model.cluster import WorkerManagerFactory
-
-            worker_manager = CFG.SYSTEM_APP.get_component(
-                ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
-            ).create()
-            async for output in worker_manager.generate_stream(payload):
+            async for output in await self._model_stream_operator.call_stream(
+                call_data={"data": payload}
+            ):
                 ### Plug-in research in result generation
                 msg = self.prompt_template.output_parser.parse_model_stream_resp_ex(
                     output, self.skip_echo_len
@@ -240,14 +244,10 @@ class BaseChat(ABC):
         )
         payload["span_id"] = span.span_id
         try:
-            from pilot.model.cluster import WorkerManagerFactory
-
-            worker_manager = CFG.SYSTEM_APP.get_component(
-                ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
-            ).create()
-
             with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
-                model_output = await worker_manager.generate(payload)
+                model_output = await self._model_operator.call(
+                    call_data={"data": payload}
+                )
 
             ### output parse
             ai_response_text = (
@@ -307,14 +307,7 @@ class BaseChat(ABC):
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
         try:
-            from pilot.model.cluster import WorkerManagerFactory
-
-            worker_manager = CFG.SYSTEM_APP.get_component(
-                ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
-            ).create()
-
-            model_output = await worker_manager.generate(payload)
-
+            model_output = await self._model_operator.call(call_data={"data": payload})
             ### output parse
             ai_response_text = (
                 self.prompt_template.output_parser.parse_model_nostream_resp(
@@ -568,3 +561,34 @@ class BaseChat(ABC):
             )
         else:
             return prompt_define_response
+
+
+def _build_model_operator(
+    is_stream: bool = False, dag_name: str = "llm_model_dag"
+) -> BaseOperator:
+    from pilot.model.cluster import WorkerManagerFactory
+    from pilot.model.operator.model_operator import (
+        ModelCacheOperator,
+        ModelStreamCacheOperator,
+        ModelCachePreOperator,
+    )
+    from pilot.cache import CacheManager
+
+    worker_manager = CFG.SYSTEM_APP.get_component(
+        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+    ).create()
+    cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
+        ComponentType.MODEL_CACHE_MANAGER, CacheManager
+    )
+
+    with DAG(dag_name):
+        input_node = InputOperator(SimpleCallDataInputSource())
+        cache_check_node = ModelCachePreOperator(cache_manager)
+        if is_stream:
+            model_node = ModelStreamOperator(worker_manager)
+            cache_node = ModelStreamCacheOperator(cache_manager)
+        else:
+            model_node = ModelOperator(worker_manager)
+            cache_node = ModelCacheOperator(cache_manager)
+        input_node >> cache_check_node >> model_node >> cache_node
+    return cache_node
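Note that _build_model_operator returns the tail node of the DAG (cache_node); in AWEL, invoking a node runs its upstream workflow, so BaseChat drives the whole cache-check, model, cache-write chain through that single handle. A small sketch of the calling convention, mirroring the BaseChat changes above (the payload dict is illustrative):

# Sketch of the caller side; assumes CFG.SYSTEM_APP is already initialized.
model_operator = _build_model_operator()
model_stream_operator = _build_model_operator(is_stream=True, dag_name="llm_stream_model_dag")


async def _generate(payload: dict):
    # Non-streaming: a single ModelOutput, served from the cache when the key matches.
    return await model_operator.call(call_data={"data": payload})


async def _generate_stream(payload: dict):
    # Streaming: yields ModelOutput chunks; ModelStreamCacheOperator persists them at the end.
    async for output in await model_stream_operator.call_stream(call_data={"data": payload}):
        yield output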
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Type
 import os
 
 from pilot.component import ComponentType, SystemApp
+from pilot.configs.model_config import MODEL_DISK_CACHE_DIR
 from pilot.utils.executor_utils import DefaultExecutorFactory
 from pilot.embedding_engine.embedding_factory import EmbeddingFactory
 from pilot.server.base import WebWerverParameters
@@ -41,6 +42,10 @@ def initialize_components(
         param, system_app, embedding_model_name, embedding_model_path
     )
+
+    from pilot.cache import initialize_cache
+
+    initialize_cache(system_app, MODEL_DISK_CACHE_DIR)
 
 
 def _initialize_embedding_model(
     param: WebWerverParameters,
11
pilot/utils/memory_utils.py
Normal file
@@ -0,0 +1,11 @@
from typing import Any
from pympler import asizeof


def _get_object_bytes(obj: Any) -> int:
    """Get the size of an object in memory, in bytes.

    Args:
        obj (Any): The object to measure.
    """
    return asizeof.asizeof(obj)
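A quick, hypothetical usage example (not part of this commit): _get_object_bytes wraps pympler's asizeof, which measures the deep size of an object graph, so it can be used to gauge how much memory a cached model output actually occupies.

# Hypothetical usage: estimating the memory footprint of a cached value.
from pilot.utils.memory_utils import _get_object_bytes

cached_outputs = [f"chunk-{i}" for i in range(1000)]
print(f"cached outputs occupy about {_get_object_bytes(cached_outputs)} bytes in memory")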
10
setup.py
@@ -319,6 +319,8 @@ def core_requires():
         "alembic==1.12.0",
         # for excel
         "openpyxl",
+        # for cache. TODO: pympler has not been updated for a long time; we need to find a new toolkit.
+        "pympler",
     ]
 
 
@@ -410,6 +412,13 @@ def vllm_requires():
     setup_spec.extras["vllm"] = ["vllm"]
 
 
+def cache_requires():
+    """
+    pip install "db-gpt[cache]"
+    """
+    setup_spec.extras["cache"] = ["rocksdict", "msgpack"]
+
+
 # def chat_scene():
 #     setup_spec.extras["chat"] = [
 #         ""
@@ -460,6 +469,7 @@ all_datasource_requires()
 openai_requires()
 gpt4all_requires()
 vllm_requires()
+cache_requires()
 
 # must be last
 default_requires()