mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-16 14:40:56 +00:00
feat(model): Support BranchOperator
This commit is contained in:
@@ -9,6 +9,7 @@ from .operator.common_operator import (
|
||||
MapOperator,
|
||||
BranchOperator,
|
||||
InputOperator,
|
||||
BranchFunc,
|
||||
)
|
||||
|
||||
from .operator.stream_operator import (
|
||||
@@ -25,6 +26,7 @@ from .task.task_impl import (
|
||||
DefaultInputContext,
|
||||
SimpleTaskOutput,
|
||||
SimpleStreamTaskOutput,
|
||||
_is_async_iterator,
|
||||
)
|
||||
from .runner.local_runner import DefaultWorkflowRunner
|
||||
|
||||
@@ -38,6 +40,7 @@ __all__ = [
|
||||
"MapOperator",
|
||||
"BranchOperator",
|
||||
"InputOperator",
|
||||
"BranchFunc",
|
||||
"WorkflowRunner",
|
||||
"TaskState",
|
||||
"TaskOutput",
|
||||
|
@@ -143,7 +143,9 @@ class DAGNode(DependencyMixin, ABC):
|
||||
resource_group: Optional[ResourceGroup] = None
|
||||
"""The resource group of current DAGNode"""
|
||||
|
||||
def __init__(self, dag: Optional["DAG"] = None, node_id: str = None) -> None:
|
||||
def __init__(
|
||||
self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._upstream: List["DAGNode"] = []
|
||||
self._downstream: List["DAGNode"] = []
|
||||
@@ -151,6 +153,7 @@ class DAGNode(DependencyMixin, ABC):
|
||||
if not node_id and self._dag:
|
||||
node_id = self._dag._new_node_id()
|
||||
self._node_id: str = node_id
|
||||
self._node_name: str = node_name
|
||||
|
||||
@property
|
||||
def node_id(self) -> str:
|
||||
@@ -159,6 +162,21 @@ class DAGNode(DependencyMixin, ABC):
|
||||
def set_node_id(self, node_id: str) -> None:
|
||||
self._node_id = node_id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
if self.node_id:
|
||||
return hash(self.node_id)
|
||||
else:
|
||||
return super().__hash__()
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DAGNode):
|
||||
return False
|
||||
return self.node_id == other.node_id
|
||||
|
||||
@property
|
||||
def node_name(self) -> str:
|
||||
return self._node_name
|
||||
|
||||
@property
|
||||
def dag(self) -> "DAGNode":
|
||||
return self._dag
|
||||
|
@@ -100,6 +100,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
def __init__(
|
||||
self,
|
||||
task_id: Optional[str] = None,
|
||||
task_name: Optional[str] = None,
|
||||
dag: Optional[DAG] = None,
|
||||
runner: WorkflowRunner = None,
|
||||
**kwargs,
|
||||
@@ -109,7 +110,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
Args:
|
||||
runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
|
||||
"""
|
||||
super().__init__(node_id=task_id, dag=dag, **kwargs)
|
||||
super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
|
||||
if not runner:
|
||||
from pilot.awel import DefaultWorkflowRunner
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from ..dag.base import DAGContext
|
||||
from ..task.base import (
|
||||
@@ -14,6 +15,9 @@ from ..task.base import (
|
||||
from .base import BaseOperator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JoinOperator(BaseOperator, Generic[OUT]):
|
||||
"""Operator that joins inputs using a custom combine function.
|
||||
|
||||
@@ -141,31 +145,36 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
BranchFunc = Union[Callable[[Any], bool], Callable[[Any], Awaitable[bool]]]
|
||||
BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
|
||||
|
||||
|
||||
class BranchOperator(BaseOperator, Generic[OUT]):
|
||||
class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
"""Operator node that branches the workflow based on a provided function.
|
||||
|
||||
This node filters its input data using a branching function and
|
||||
allows for conditional paths in the workflow.
|
||||
"""
|
||||
|
||||
def __init__(self, branches: Dict[BranchFunc, BaseOperator], **kwargs):
|
||||
def __init__(
|
||||
self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
|
||||
):
|
||||
"""
|
||||
Initializes a BranchDAGNode with a branching function.
|
||||
|
||||
Args:
|
||||
branches (Dict[BranchFunc, RunnableDAGNode]): Dict of function that defines the branching condition.
|
||||
branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of function that defines the branching condition.
|
||||
|
||||
Raises:
|
||||
ValueError: If the branch_function is not callable.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
for branch_function in branches.keys():
|
||||
if not callable(branch_function):
|
||||
raise ValueError("branch_function must be callable")
|
||||
self.branches = branches
|
||||
if branches:
|
||||
for branch_function, value in branches.items():
|
||||
if not callable(branch_function):
|
||||
raise ValueError("branch_function must be callable")
|
||||
if isinstance(value, BaseOperator):
|
||||
branches[branch_function] = value.node_name or value.node_name
|
||||
self._branches = branches
|
||||
|
||||
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
|
||||
"""Run the branching operation on the DAG context's inputs.
|
||||
@@ -186,27 +195,36 @@ class BranchOperator(BaseOperator, Generic[OUT]):
|
||||
if not task_input.check_single_parent():
|
||||
raise ValueError("BranchDAGNode expects single parent")
|
||||
|
||||
branches = self._branches
|
||||
if not branches:
|
||||
branches = await self.branchs()
|
||||
|
||||
branch_func_tasks = []
|
||||
branch_nodes: List[BaseOperator] = []
|
||||
for func, node in self.branches.items():
|
||||
branch_nodes.append(node)
|
||||
branch_nodes: List[str] = []
|
||||
for func, node_name in branches.items():
|
||||
branch_nodes.append(node_name)
|
||||
branch_func_tasks.append(
|
||||
curr_task_ctx.task_input.predicate_map(func, failed_value=None)
|
||||
)
|
||||
|
||||
branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
|
||||
parent_output = task_input.parent_outputs[0].task_output
|
||||
curr_task_ctx.set_task_output(parent_output)
|
||||
|
||||
skip_node_names = []
|
||||
for i, ctx in enumerate(branch_input_ctxs):
|
||||
node = branch_nodes[i]
|
||||
node_name = branch_nodes[i]
|
||||
branch_out = ctx.parent_outputs[0].task_output
|
||||
logger.info(
|
||||
f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
|
||||
)
|
||||
if ctx.parent_outputs[0].task_output.is_empty:
|
||||
# Skip current node
|
||||
# node.current_task_context.set_current_state(TaskState.SKIP)
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
logger.info(f"Skip node name {node_name}")
|
||||
skip_node_names.append(node_name)
|
||||
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
|
||||
return parent_output
|
||||
|
||||
async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
|
||||
raise NotImplementedError
|
||||
return None
|
||||
|
||||
|
||||
class InputOperator(BaseOperator, Generic[OUT]):
|
||||
|
@@ -1,9 +1,12 @@
|
||||
from typing import List, Set, Optional, Dict
|
||||
import uuid
|
||||
import logging
|
||||
from ..dag.base import DAG
|
||||
|
||||
from ..operator.base import BaseOperator, CALL_DATA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DAGNodeInstance:
|
||||
def __init__(self, node_instance: DAG) -> None:
|
||||
@@ -45,14 +48,19 @@ def _save_call_data(
|
||||
root_nodes: List[BaseOperator], call_data: CALL_DATA
|
||||
) -> Dict[str, Dict]:
|
||||
id2call_data = {}
|
||||
logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
|
||||
if not call_data:
|
||||
return id2call_data
|
||||
if len(root_nodes) == 1:
|
||||
node = root_nodes[0]
|
||||
logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
|
||||
id2call_data[node.node_id] = call_data
|
||||
else:
|
||||
for node in root_nodes:
|
||||
node_id = node.node_id
|
||||
logger.info(
|
||||
f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
|
||||
)
|
||||
id2call_data[node_id] = call_data.get(node_id)
|
||||
return id2call_data
|
||||
|
||||
@@ -71,4 +79,4 @@ def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
|
||||
|
||||
|
||||
def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
|
||||
return list(filter(lambda x: not x.upstream, nodes))
|
||||
return list(set(filter(lambda x: not x.upstream, nodes)))
|
||||
|
@@ -1,10 +1,15 @@
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Set, List
|
||||
import logging
|
||||
|
||||
from ..dag.base import DAGContext
|
||||
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
|
||||
from ..operator.common_operator import BranchOperator, JoinOperator
|
||||
from ..task.base import TaskContext, TaskState
|
||||
from ..task.task_impl import DefaultInputContext, DefaultTaskContext
|
||||
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
|
||||
from .job_manager import JobManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DefaultWorkflowRunner(WorkflowRunner):
|
||||
async def execute_workflow(
|
||||
@@ -13,10 +18,16 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
# Create DAG context
|
||||
dag_ctx = DAGContext()
|
||||
job_manager = JobManager.build_from_end_node(node, call_data)
|
||||
logger.info(
|
||||
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
|
||||
)
|
||||
dag = node.dag
|
||||
# Save node output
|
||||
node_outputs: Dict[str, TaskContext] = {}
|
||||
await self._execute_node(job_manager, node, dag_ctx, node_outputs)
|
||||
skip_node_ids = set()
|
||||
await self._execute_node(
|
||||
job_manager, node, dag_ctx, node_outputs, skip_node_ids
|
||||
)
|
||||
|
||||
return dag_ctx
|
||||
|
||||
@@ -26,6 +37,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
node: BaseOperator,
|
||||
dag_ctx: DAGContext,
|
||||
node_outputs: Dict[str, TaskContext],
|
||||
skip_node_ids: Set[str],
|
||||
):
|
||||
# Skip run node
|
||||
if node.node_id in node_outputs:
|
||||
@@ -35,18 +47,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
for upstream_node in node.upstream:
|
||||
if isinstance(upstream_node, BaseOperator):
|
||||
await self._execute_node(
|
||||
job_manager, upstream_node, dag_ctx, node_outputs
|
||||
job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
|
||||
)
|
||||
# if node.current_task_context.current_state == TaskState.SKIP:
|
||||
# return
|
||||
|
||||
# for upstream_node in node.upstream:
|
||||
# if (
|
||||
# isinstance(upstream_node, BaseOperator)
|
||||
# and upstream_node.current_task_context.current_state == TaskState.SKIP
|
||||
# ):
|
||||
# return
|
||||
# Get the input from upstream node
|
||||
inputs = [
|
||||
node_outputs[upstream_node.node_id] for upstream_node in node.upstream
|
||||
]
|
||||
@@ -56,13 +59,48 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
|
||||
task_ctx.set_task_input(input_ctx)
|
||||
dag_ctx.set_current_task_context(task_ctx)
|
||||
|
||||
task_ctx.set_current_state(TaskState.RUNNING)
|
||||
|
||||
if node.node_id in skip_node_ids:
|
||||
task_ctx.set_current_state(TaskState.SKIP)
|
||||
task_ctx.set_task_output(SimpleTaskOutput(None))
|
||||
node_outputs[node.node_id] = task_ctx
|
||||
return
|
||||
try:
|
||||
# print(f"Begin run {node}")
|
||||
logger.info(
|
||||
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
|
||||
)
|
||||
await node._run(dag_ctx)
|
||||
node_outputs[node.node_id] = dag_ctx.current_task_context
|
||||
task_ctx.set_current_state(TaskState.SUCCESS)
|
||||
|
||||
if isinstance(node, BranchOperator):
|
||||
skip_nodes = task_ctx.metadata.get("skip_node_names", [])
|
||||
logger.info(
|
||||
f"Current is branch operator, skip node names: {skip_nodes}"
|
||||
)
|
||||
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
|
||||
except Exception as e:
|
||||
logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
|
||||
task_ctx.set_current_state(TaskState.FAILED)
|
||||
raise e
|
||||
|
||||
|
||||
def _skip_current_downstream_by_node_name(
|
||||
branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
|
||||
):
|
||||
if not skip_nodes:
|
||||
return
|
||||
for child in branch_node.downstream:
|
||||
if child.node_name in skip_nodes:
|
||||
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
|
||||
_skip_downstream_by_id(child, skip_node_ids)
|
||||
|
||||
|
||||
def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
|
||||
if isinstance(node, JoinOperator):
|
||||
# Not skip join node
|
||||
return
|
||||
skip_node_ids.add(node.node_id)
|
||||
for child in node.downstream:
|
||||
_skip_downstream_by_id(child, skip_node_ids)
|
||||
|
@@ -81,6 +81,10 @@ class TaskOutput(ABC, Generic[T]):
|
||||
output_data (Union[T, AsyncIterator[T]]): Output data.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def new_output(self) -> "TaskOutput[T]":
|
||||
"""Create new output object"""
|
||||
|
||||
async def map(self, map_func) -> "TaskOutput[T]":
|
||||
"""Apply a mapping function to the task's output.
|
||||
|
||||
@@ -334,13 +338,18 @@ class InputContext(ABC):
|
||||
"""
|
||||
return len(self.parent_outputs) == 1
|
||||
|
||||
def check_stream(self) -> bool:
|
||||
def check_stream(self, skip_empty: bool = False) -> bool:
|
||||
"""Check if all parent outputs are streams.
|
||||
|
||||
Args:
|
||||
skip_empty (bool): Skip empty output or not.
|
||||
|
||||
Returns:
|
||||
bool: True if all parent outputs are streams, False otherwise.
|
||||
"""
|
||||
for out in self.parent_outputs:
|
||||
if out.task_output.is_empty and skip_empty:
|
||||
continue
|
||||
if not (out.task_output and out.task_output.is_stream):
|
||||
return False
|
||||
return True
|
||||
|
@@ -13,10 +13,13 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
import asyncio
|
||||
|
||||
import logging
|
||||
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
|
||||
# Init accumulator
|
||||
try:
|
||||
@@ -44,6 +47,9 @@ class SimpleTaskOutput(TaskOutput[T], Generic[T]):
|
||||
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
|
||||
self._data = output_data
|
||||
|
||||
def new_output(self) -> TaskOutput[T]:
|
||||
return SimpleTaskOutput(None)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return not self._data
|
||||
@@ -89,6 +95,9 @@ class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
|
||||
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
|
||||
self._data = output_data
|
||||
|
||||
def new_output(self) -> TaskOutput[T]:
|
||||
return SimpleStreamTaskOutput(None)
|
||||
|
||||
async def map(self, map_func) -> TaskOutput[T]:
|
||||
is_async = asyncio.iscoroutinefunction(map_func)
|
||||
|
||||
@@ -213,7 +222,8 @@ class DefaultTaskContext(TaskContext, Generic[T]):
|
||||
self._task_state = task_state
|
||||
|
||||
def new_ctx(self) -> TaskContext:
|
||||
return DefaultTaskContext(self._task_id, self._task_state, self._output)
|
||||
new_output = self._output.new_output()
|
||||
return DefaultTaskContext(self._task_id, self._task_state, new_output)
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
@@ -259,9 +269,17 @@ class DefaultInputContext(InputContext):
|
||||
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
|
||||
if not self._outputs:
|
||||
return DefaultInputContext([])
|
||||
is_steam = self._outputs[0].task_output.is_stream
|
||||
# Some parent may be empty
|
||||
not_empty_idx = 0
|
||||
for i, p in enumerate(self._outputs):
|
||||
if p.task_output.is_empty:
|
||||
continue
|
||||
not_empty_idx = i
|
||||
break
|
||||
# All output is empty?
|
||||
is_steam = self._outputs[not_empty_idx].task_output.is_stream
|
||||
if is_steam:
|
||||
if not self.check_stream():
|
||||
if not self.check_stream(skip_empty=True):
|
||||
raise ValueError(
|
||||
"The output in all tasks must has same output format to map_all"
|
||||
)
|
||||
@@ -275,9 +293,11 @@ class DefaultInputContext(InputContext):
|
||||
map_res = await map_func(*outputs)
|
||||
else:
|
||||
map_res = map_func(*outputs)
|
||||
|
||||
single_output: TaskContext = self._outputs[0].new_ctx()
|
||||
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
|
||||
single_output.task_output.set_output(map_res)
|
||||
logger.debug(
|
||||
f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
|
||||
)
|
||||
return DefaultInputContext([single_output])
|
||||
|
||||
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
|
||||
@@ -311,6 +331,7 @@ class DefaultInputContext(InputContext):
|
||||
for i, task_ctx in enumerate(new_outputs):
|
||||
task_ctx: TaskContext = task_ctx
|
||||
if results[i]:
|
||||
task_ctx.task_output.set_output(True)
|
||||
result_outputs.append(task_ctx)
|
||||
else:
|
||||
task_ctx.task_output.set_output(failed_value)
|
||||
|
@@ -108,21 +108,34 @@ async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator
|
||||
"input_node, is_odd",
|
||||
[
|
||||
({"outputs": [0]}, False),
|
||||
# ({"outputs": [1]}, False),
|
||||
({"outputs": [1]}, True),
|
||||
],
|
||||
indirect=["input_node"],
|
||||
)
|
||||
async def test_branch_node(
|
||||
runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
|
||||
):
|
||||
def join_func(o1, o2) -> int:
|
||||
print(f"join func result, o1: {o1}, o2: {o2}")
|
||||
return o1 or o2
|
||||
|
||||
with DAG("test_join_node") as dag:
|
||||
odd_node = MapOperator(lambda x: 999, task_id="odd_node")
|
||||
even_node = MapOperator(lambda x: 888, task_id="even_node")
|
||||
odd_node = MapOperator(
|
||||
lambda x: 999, task_id="odd_node", task_name="odd_node_name"
|
||||
)
|
||||
even_node = MapOperator(
|
||||
lambda x: 888, task_id="even_node", task_name="even_node_name"
|
||||
)
|
||||
join_node = JoinOperator(join_func)
|
||||
branch_node = BranchOperator(
|
||||
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
|
||||
)
|
||||
branch_node >> odd_node >> join_node
|
||||
branch_node >> even_node >> join_node
|
||||
|
||||
input_node >> branch_node
|
||||
|
||||
odd_res: DAGContext[int] = await runner.execute_workflow(odd_node)
|
||||
even_res: DAGContext[int] = await runner.execute_workflow(even_node)
|
||||
assert branch_node.current_task_context.current_state == TaskState.SUCCESS
|
||||
res: DAGContext[int] = await runner.execute_workflow(join_node)
|
||||
assert res.current_task_context.current_state == TaskState.SUCCESS
|
||||
expect_res = 999 if is_odd else 888
|
||||
assert res.current_task_context.task_output.output == expect_res
|
||||
|
@@ -1,10 +1,13 @@
|
||||
from typing import AsyncIterator, Dict
|
||||
from typing import AsyncIterator, Dict, Union
|
||||
import logging
|
||||
from pilot.awel import (
|
||||
BranchFunc,
|
||||
StreamifyAbsOperator,
|
||||
BranchOperator,
|
||||
MapOperator,
|
||||
TransformStreamAbsOperator,
|
||||
)
|
||||
from pilot.awel.operator.base import BaseOperator
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.cluster import WorkerManager
|
||||
from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
|
||||
@@ -16,71 +19,189 @@ _LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"
|
||||
|
||||
|
||||
class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
|
||||
"""Operator for streaming processing of model outputs.
|
||||
|
||||
Args:
|
||||
worker_manager (WorkerManager): The manager that handles worker processes for model inference.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Methods:
|
||||
streamify: Asynchronously processes a stream of inputs, yielding model outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.worker_manager = worker_manager
|
||||
|
||||
async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
|
||||
llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
|
||||
_LLM_MODEL_OUTPUT_CACHE_KEY
|
||||
)
|
||||
logger.info(f"llm_cache_value: {llm_cache_value}")
|
||||
if llm_cache_value:
|
||||
for out in llm_cache_value.get_value().output:
|
||||
yield out
|
||||
return
|
||||
"""Process inputs as a stream and yield model outputs.
|
||||
|
||||
Args:
|
||||
input_value (Dict): The input value for the model.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
|
||||
"""
|
||||
async for out in self.worker_manager.generate_stream(input_value):
|
||||
yield out
|
||||
|
||||
|
||||
class ModelOperator(MapOperator[Dict, ModelOutput]):
|
||||
"""Operator for map-based processing of model outputs.
|
||||
|
||||
Args:
|
||||
worker_manager (WorkerManager): Manager for handling worker processes.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Methods:
|
||||
map: Asynchronously processes a single input and returns the model output.
|
||||
"""
|
||||
|
||||
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
|
||||
self.worker_manager = worker_manager
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: Dict) -> ModelOutput:
|
||||
llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
|
||||
_LLM_MODEL_OUTPUT_CACHE_KEY
|
||||
)
|
||||
logger.info(f"llm_cache_value: {llm_cache_value}")
|
||||
if llm_cache_value:
|
||||
return llm_cache_value.get_value().output
|
||||
"""Process a single input and return the model output.
|
||||
|
||||
Args:
|
||||
input_value (Dict): The input value for the model.
|
||||
|
||||
Returns:
|
||||
ModelOutput: The output from the model.
|
||||
"""
|
||||
return await self.worker_manager.generate(input_value)
|
||||
|
||||
|
||||
class ModelCachePreOperator(MapOperator[Dict, Dict]):
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs):
|
||||
class CachedModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
|
||||
"""Operator for streaming processing of model outputs with caching.
|
||||
|
||||
Args:
|
||||
cache_manager (CacheManager): The cache manager to handle caching operations.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Methods:
|
||||
streamify: Processes a stream of inputs with cache support, yielding model outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
|
||||
async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
|
||||
"""Process inputs as a stream with cache support and yield model outputs.
|
||||
|
||||
Args:
|
||||
input_value (Dict): The input value for the model.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
|
||||
"""
|
||||
cache_dict = _parse_cache_key_dict(input_value)
|
||||
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
|
||||
llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
|
||||
logger.info(f"llm_cache_value: {llm_cache_value}")
|
||||
for out in llm_cache_value.get_value().output:
|
||||
yield out
|
||||
|
||||
|
||||
class CachedModelOperator(MapOperator[Dict, ModelOutput]):
|
||||
"""Operator for map-based processing of model outputs with caching.
|
||||
|
||||
Args:
|
||||
cache_manager (CacheManager): Manager for caching operations.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Methods:
|
||||
map: Processes a single input with cache support and returns the model output.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
|
||||
async def map(self, input_value: Dict) -> Dict:
|
||||
cache_dict = {
|
||||
"prompt": input_value.get("prompt"),
|
||||
"model_name": input_value.get("model"),
|
||||
"temperature": input_value.get("temperature"),
|
||||
"max_new_tokens": input_value.get("max_new_tokens"),
|
||||
"top_p": input_value.get("top_p", "1.0"),
|
||||
# TODO pass model_type
|
||||
"model_type": input_value.get("model_type", "huggingface"),
|
||||
}
|
||||
cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
|
||||
cache_value = await self._client.get(cache_key)
|
||||
logger.debug(
|
||||
f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
|
||||
)
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
_LLM_MODEL_INPUT_VALUE_KEY, cache_key
|
||||
)
|
||||
if cache_value:
|
||||
logger.info(f"The model output has cached, cache_value: {cache_value}")
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
_LLM_MODEL_OUTPUT_CACHE_KEY, cache_value
|
||||
async def map(self, input_value: Dict) -> ModelOutput:
|
||||
"""Process a single input with cache support and return the model output.
|
||||
|
||||
Args:
|
||||
input_value (Dict): The input value for the model.
|
||||
|
||||
Returns:
|
||||
ModelOutput: The output from the model.
|
||||
"""
|
||||
cache_dict = _parse_cache_key_dict(input_value)
|
||||
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
|
||||
llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
|
||||
logger.info(f"llm_cache_value: {llm_cache_value}")
|
||||
return llm_cache_value.get_value().output
|
||||
|
||||
|
||||
class ModelCacheBranchOperator(BranchOperator[Dict, Dict]):
|
||||
"""
|
||||
A branch operator that decides whether to use cached data or to process data using the model.
|
||||
|
||||
Args:
|
||||
cache_manager (CacheManager): The cache manager for managing cache operations.
|
||||
model_task_name (str): The name of the task to process data using the model.
|
||||
cache_task_name (str): The name of the task to process data using the cache.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_manager: CacheManager,
|
||||
model_task_name: str,
|
||||
cache_task_name: str,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(branches=None, **kwargs)
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
self._model_task_name = model_task_name
|
||||
self._cache_task_name = cache_task_name
|
||||
|
||||
async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
|
||||
"""Defines branch logic based on cache availability.
|
||||
|
||||
Returns:
|
||||
Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names.
|
||||
"""
|
||||
|
||||
async def check_cache_true(input_value: Dict) -> bool:
|
||||
# Check if the cache contains the result for the given input
|
||||
cache_dict = _parse_cache_key_dict(input_value)
|
||||
cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
|
||||
cache_value = await self._client.get(cache_key)
|
||||
logger.debug(
|
||||
f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
|
||||
)
|
||||
return input_value
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
_LLM_MODEL_INPUT_VALUE_KEY, cache_key
|
||||
)
|
||||
return True if cache_value else False
|
||||
|
||||
async def check_cache_false(input_value: Dict):
|
||||
# Inverse of check_cache_true
|
||||
return not await check_cache_true(input_value)
|
||||
|
||||
return {
|
||||
check_cache_true: self._cache_task_name,
|
||||
check_cache_false: self._model_task_name,
|
||||
}
|
||||
|
||||
|
||||
class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutput]):
|
||||
class ModelStreamSaveCacheOperator(
|
||||
TransformStreamAbsOperator[ModelOutput, ModelOutput]
|
||||
):
|
||||
"""An operator to save the stream of model outputs to cache.
|
||||
|
||||
Args:
|
||||
cache_manager (CacheManager): The cache manager for handling cache operations.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs):
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
@@ -89,6 +210,14 @@ class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutp
|
||||
async def transform_stream(
|
||||
self, input_value: AsyncIterator[ModelOutput]
|
||||
) -> AsyncIterator[ModelOutput]:
|
||||
"""Transforms the input stream by saving the outputs to cache.
|
||||
|
||||
Args:
|
||||
input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache.
|
||||
"""
|
||||
llm_cache_key: LLMCacheKey = None
|
||||
outputs = []
|
||||
async for out in input_value:
|
||||
@@ -103,13 +232,28 @@ class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutp
|
||||
await self._client.set(llm_cache_key, llm_cache_value)
|
||||
|
||||
|
||||
class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
|
||||
class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
|
||||
"""An operator to save a single model output to cache.
|
||||
|
||||
Args:
|
||||
cache_manager (CacheManager): The cache manager for handling cache operations.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_manager: CacheManager, **kwargs):
|
||||
self._cache_manager = cache_manager
|
||||
self._client = LLMCacheClient(cache_manager)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: ModelOutput) -> ModelOutput:
|
||||
"""Saves a single model output to cache and returns it.
|
||||
|
||||
Args:
|
||||
input_value (ModelOutput): The output from the model to be cached.
|
||||
|
||||
Returns:
|
||||
ModelOutput: The same input model output.
|
||||
"""
|
||||
llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
|
||||
_LLM_MODEL_INPUT_VALUE_KEY
|
||||
)
|
||||
@@ -117,3 +261,26 @@ class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
|
||||
if llm_cache_key:
|
||||
await self._client.set(llm_cache_key, llm_cache_value)
|
||||
return input_value
|
||||
|
||||
|
||||
def _parse_cache_key_dict(input_value: Dict) -> Dict:
|
||||
"""Parses and extracts relevant fields from input to form a cache key dictionary.
|
||||
|
||||
Args:
|
||||
input_value (Dict): The input dictionary containing model and prompt parameters.
|
||||
|
||||
Returns:
|
||||
Dict: A dictionary used for generating cache keys.
|
||||
"""
|
||||
prompt: str = input_value.get("prompt")
|
||||
if prompt:
|
||||
prompt = prompt.strip()
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"model_name": input_value.get("model"),
|
||||
"temperature": input_value.get("temperature"),
|
||||
"max_new_tokens": input_value.get("max_new_tokens"),
|
||||
"top_p": input_value.get("top_p", "1.0"),
|
||||
# TODO pass model_type
|
||||
"model_type": input_value.get("model_type", "huggingface"),
|
||||
}
|
||||
|
@@ -566,29 +566,83 @@ class BaseChat(ABC):
|
||||
def _build_model_operator(
|
||||
is_stream: bool = False, dag_name: str = "llm_model_dag"
|
||||
) -> BaseOperator:
|
||||
"""Builds and returns a model processing workflow (DAG) operator.
|
||||
|
||||
This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
|
||||
It includes caching and branching logic to either fetch results from a cache or process
|
||||
data using the model. It supports both streaming and non-streaming modes.
|
||||
|
||||
.. code-block:: python
|
||||
input_node >> cache_check_branch_node
|
||||
cache_check_branch_node >> model_node >> save_cached_node >> join_node
|
||||
cache_check_branch_node >> cached_node >> join_node
|
||||
|
||||
equivalent to::
|
||||
|
||||
-> model_node -> save_cached_node ->
|
||||
/ \
|
||||
input_node -> cache_check_branch_node ---> join_node
|
||||
\ /
|
||||
-> cached_node ------------------- ->
|
||||
|
||||
Args:
|
||||
is_stream (bool): Flag to determine if the operator should process data in streaming mode.
|
||||
dag_name (str): Name of the DAG.
|
||||
|
||||
Returns:
|
||||
BaseOperator: The final operator in the constructed DAG, typically a join node.
|
||||
"""
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
from pilot.awel import JoinOperator
|
||||
from pilot.model.operator.model_operator import (
|
||||
ModelCacheOperator,
|
||||
ModelStreamCacheOperator,
|
||||
ModelCachePreOperator,
|
||||
ModelCacheBranchOperator,
|
||||
CachedModelStreamOperator,
|
||||
CachedModelOperator,
|
||||
ModelSaveCacheOperator,
|
||||
ModelStreamSaveCacheOperator,
|
||||
)
|
||||
from pilot.cache import CacheManager
|
||||
|
||||
# Fetch worker and cache managers from the system configuration
|
||||
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.MODEL_CACHE_MANAGER, CacheManager
|
||||
)
|
||||
# Define task names for the model and cache nodes
|
||||
model_task_name = "llm_model_node"
|
||||
cache_task_name = "llm_model_cache_node"
|
||||
|
||||
with DAG(dag_name):
|
||||
# Create an input node
|
||||
input_node = InputOperator(SimpleCallDataInputSource())
|
||||
cache_check_node = ModelCachePreOperator(cache_manager)
|
||||
# Determine if the workflow should operate in streaming mode
|
||||
if is_stream:
|
||||
model_node = ModelStreamOperator(worker_manager)
|
||||
cache_node = ModelStreamCacheOperator(cache_manager)
|
||||
model_node = ModelStreamOperator(worker_manager, task_name=model_task_name)
|
||||
cached_node = CachedModelStreamOperator(
|
||||
cache_manager, task_name=cache_task_name
|
||||
)
|
||||
save_cached_node = ModelStreamSaveCacheOperator(cache_manager)
|
||||
else:
|
||||
model_node = ModelOperator(worker_manager)
|
||||
cache_node = ModelCacheOperator(cache_manager)
|
||||
input_node >> cache_check_node >> model_node >> cache_node
|
||||
return cache_node
|
||||
model_node = ModelOperator(worker_manager, task_name=model_task_name)
|
||||
cached_node = CachedModelOperator(cache_manager, task_name=cache_task_name)
|
||||
save_cached_node = ModelSaveCacheOperator(cache_manager)
|
||||
|
||||
# Create a branch node to decide between fetching from cache or processing with the model
|
||||
cache_check_branch_node = ModelCacheBranchOperator(
|
||||
cache_manager,
|
||||
model_task_name="llm_model_node",
|
||||
cache_task_name="llm_model_cache_node",
|
||||
)
|
||||
# Create a join node to merge outputs from the model and cache nodes, just keep the fist not empty output
|
||||
join_node = JoinOperator(
|
||||
combine_function=lambda model_out, cache_out: cache_out or model_out
|
||||
)
|
||||
|
||||
# Define the workflow structure using the >> operator
|
||||
input_node >> cache_check_branch_node
|
||||
cache_check_branch_node >> model_node >> save_cached_node >> join_node
|
||||
cache_check_branch_node >> cached_node >> join_node
|
||||
|
||||
return join_node
|
||||
|
Reference in New Issue
Block a user