feat(model): Support BranchOperator

Author: FangYin Cheng
Date: 2023-11-16 18:13:16 +08:00
parent 6db8c49d87
commit 1150adbe6a

11 changed files with 452 additions and 102 deletions

View File

@@ -9,6 +9,7 @@ from .operator.common_operator import (
     MapOperator,
     BranchOperator,
     InputOperator,
+    BranchFunc,
 )
 from .operator.stream_operator import (
@@ -25,6 +26,7 @@ from .task.task_impl import (
     DefaultInputContext,
     SimpleTaskOutput,
     SimpleStreamTaskOutput,
+    _is_async_iterator,
 )
 from .runner.local_runner import DefaultWorkflowRunner
@@ -38,6 +40,7 @@ __all__ = [
     "MapOperator",
     "BranchOperator",
     "InputOperator",
+    "BranchFunc",
     "WorkflowRunner",
     "TaskState",
     "TaskOutput",

View File

@@ -143,7 +143,9 @@ class DAGNode(DependencyMixin, ABC):
     resource_group: Optional[ResourceGroup] = None
     """The resource group of current DAGNode"""

-    def __init__(self, dag: Optional["DAG"] = None, node_id: str = None) -> None:
+    def __init__(
+        self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None
+    ) -> None:
         super().__init__()
         self._upstream: List["DAGNode"] = []
         self._downstream: List["DAGNode"] = []
@@ -151,6 +153,7 @@ class DAGNode(DependencyMixin, ABC):
         if not node_id and self._dag:
             node_id = self._dag._new_node_id()
         self._node_id: str = node_id
+        self._node_name: str = node_name

     @property
     def node_id(self) -> str:
@@ -159,6 +162,21 @@ class DAGNode(DependencyMixin, ABC):
     def set_node_id(self, node_id: str) -> None:
         self._node_id = node_id

+    def __hash__(self) -> int:
+        if self.node_id:
+            return hash(self.node_id)
+        else:
+            return super().__hash__()
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, DAGNode):
+            return False
+        return self.node_id == other.node_id
+
+    @property
+    def node_name(self) -> str:
+        return self._node_name
+
     @property
     def dag(self) -> "DAGNode":
         return self._dag

View File

@@ -100,6 +100,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     def __init__(
         self,
         task_id: Optional[str] = None,
+        task_name: Optional[str] = None,
         dag: Optional[DAG] = None,
         runner: WorkflowRunner = None,
         **kwargs,
@@ -109,7 +110,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Args:
             runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
         """
-        super().__init__(node_id=task_id, dag=dag, **kwargs)
+        super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
         if not runner:
             from pilot.awel import DefaultWorkflowRunner

View File

@@ -1,5 +1,6 @@
 from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
 import asyncio
+import logging

 from ..dag.base import DAGContext
 from ..task.base import (
@@ -14,6 +15,9 @@ from ..task.base import (

 from .base import BaseOperator

+logger = logging.getLogger(__name__)
+

 class JoinOperator(BaseOperator, Generic[OUT]):
     """Operator that joins inputs using a custom combine function.
@@ -141,31 +145,36 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
         raise NotImplementedError


-BranchFunc = Union[Callable[[Any], bool], Callable[[Any], Awaitable[bool]]]
+BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]


-class BranchOperator(BaseOperator, Generic[OUT]):
+class BranchOperator(BaseOperator, Generic[IN, OUT]):
     """Operator node that branches the workflow based on a provided function.

     This node filters its input data using a branching function and
     allows for conditional paths in the workflow.
     """

-    def __init__(self, branches: Dict[BranchFunc, BaseOperator], **kwargs):
+    def __init__(
+        self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
+    ):
         """
         Initializes a BranchDAGNode with a branching function.

         Args:
-            branches (Dict[BranchFunc, RunnableDAGNode]): Dict of function that defines the branching condition.
+            branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of function that defines the branching condition.

         Raises:
             ValueError: If the branch_function is not callable.
         """
         super().__init__(**kwargs)
-        for branch_function in branches.keys():
-            if not callable(branch_function):
-                raise ValueError("branch_function must be callable")
-        self.branches = branches
+        if branches:
+            for branch_function, value in branches.items():
+                if not callable(branch_function):
+                    raise ValueError("branch_function must be callable")
+                if isinstance(value, BaseOperator):
+                    branches[branch_function] = value.node_name or value.node_name
+        self._branches = branches

     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         """Run the branching operation on the DAG context's inputs.
@@ -186,27 +195,36 @@ class BranchOperator(BaseOperator, Generic[OUT]):
         if not task_input.check_single_parent():
             raise ValueError("BranchDAGNode expects single parent")

+        branches = self._branches
+        if not branches:
+            branches = await self.branchs()
+
         branch_func_tasks = []
-        branch_nodes: List[BaseOperator] = []
-        for func, node in self.branches.items():
-            branch_nodes.append(node)
+        branch_nodes: List[str] = []
+        for func, node_name in branches.items():
+            branch_nodes.append(node_name)
             branch_func_tasks.append(
                 curr_task_ctx.task_input.predicate_map(func, failed_value=None)
             )

         branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
         parent_output = task_input.parent_outputs[0].task_output
         curr_task_ctx.set_task_output(parent_output)
+        skip_node_names = []
         for i, ctx in enumerate(branch_input_ctxs):
-            node = branch_nodes[i]
+            node_name = branch_nodes[i]
+            branch_out = ctx.parent_outputs[0].task_output
+            logger.info(
+                f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
+            )
             if ctx.parent_outputs[0].task_output.is_empty:
-                # Skip current node
-                # node.current_task_context.set_current_state(TaskState.SKIP)
-                pass
-            else:
-                pass
-        return None
+                logger.info(f"Skip node name {node_name}")
+                skip_node_names.append(node_name)
+        curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
+        return parent_output
+
+    async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
+        raise NotImplementedError


 class InputOperator(BaseOperator, Generic[OUT]):

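For orientation, here is a minimal sketch (not part of the commit) of how the named-branch flow above is wired, modeled on the test_branch_node test added later in this commit. Import paths and the construction of `input_node` are assumptions.

    # Sketch only: route an int to one of two named downstream tasks.
    from pilot.awel import (
        DAG,
        BranchOperator,
        DefaultWorkflowRunner,
        JoinOperator,
        MapOperator,
    )

    async def run_branch_example(input_node):
        with DAG("branch_example"):
            odd_node = MapOperator(lambda x: 999, task_name="odd_node_name")
            even_node = MapOperator(lambda x: 888, task_name="even_node_name")
            join_node = JoinOperator(combine_function=lambda o1, o2: o1 or o2)
            branch_node = BranchOperator(
                {lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
            )
            input_node >> branch_node
            branch_node >> odd_node >> join_node
            branch_node >> even_node >> join_node
        # The untaken branch is recorded in the "skip_node_names" metadata and its
        # downstream subtree is skipped by the runner, so the join sees one real
        # output and one empty output.
        return await DefaultWorkflowRunner().execute_workflow(join_node)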
View File

@@ -1,9 +1,12 @@
 from typing import List, Set, Optional, Dict
 import uuid
+import logging

 from ..dag.base import DAG
 from ..operator.base import BaseOperator, CALL_DATA

+logger = logging.getLogger(__name__)
+

 class DAGNodeInstance:
     def __init__(self, node_instance: DAG) -> None:
@@ -45,14 +48,19 @@ def _save_call_data(
     root_nodes: List[BaseOperator], call_data: CALL_DATA
 ) -> Dict[str, Dict]:
     id2call_data = {}
+    logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
     if not call_data:
         return id2call_data
     if len(root_nodes) == 1:
         node = root_nodes[0]
+        logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
         id2call_data[node.node_id] = call_data
     else:
         for node in root_nodes:
             node_id = node.node_id
+            logger.info(
+                f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
+            )
             id2call_data[node_id] = call_data.get(node_id)
     return id2call_data
@@ -71,4 +79,4 @@ def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:

 def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
-    return list(filter(lambda x: not x.upstream, nodes))
+    return list(set(filter(lambda x: not x.upstream, nodes)))

View File

@@ -1,10 +1,15 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Set, List
+import logging

 from ..dag.base import DAGContext
 from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
+from ..operator.common_operator import BranchOperator, JoinOperator
 from ..task.base import TaskContext, TaskState
-from ..task.task_impl import DefaultInputContext, DefaultTaskContext
+from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
 from .job_manager import JobManager

+logger = logging.getLogger(__name__)
+

 class DefaultWorkflowRunner(WorkflowRunner):
     async def execute_workflow(
@@ -13,10 +18,16 @@ class DefaultWorkflowRunner(WorkflowRunner):
         # Create DAG context
         dag_ctx = DAGContext()
         job_manager = JobManager.build_from_end_node(node, call_data)
+        logger.info(
+            f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
+        )
         dag = node.dag
         # Save node output
         node_outputs: Dict[str, TaskContext] = {}
-        await self._execute_node(job_manager, node, dag_ctx, node_outputs)
+        skip_node_ids = set()
+        await self._execute_node(
+            job_manager, node, dag_ctx, node_outputs, skip_node_ids
+        )

         return dag_ctx
@@ -26,6 +37,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
         node: BaseOperator,
         dag_ctx: DAGContext,
         node_outputs: Dict[str, TaskContext],
+        skip_node_ids: Set[str],
     ):
         # Skip run node
         if node.node_id in node_outputs:
@@ -35,18 +47,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
         for upstream_node in node.upstream:
             if isinstance(upstream_node, BaseOperator):
                 await self._execute_node(
-                    job_manager, upstream_node, dag_ctx, node_outputs
+                    job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
                 )
-        # if node.current_task_context.current_state == TaskState.SKIP:
-        #     return
-        # for upstream_node in node.upstream:
-        #     if (
-        #         isinstance(upstream_node, BaseOperator)
-        #         and upstream_node.current_task_context.current_state == TaskState.SKIP
-        #     ):
-        #         return
-        # Get the input from upstream node
+
         inputs = [
             node_outputs[upstream_node.node_id] for upstream_node in node.upstream
         ]
@@ -56,13 +59,48 @@ class DefaultWorkflowRunner(WorkflowRunner):
         task_ctx.set_task_input(input_ctx)
         dag_ctx.set_current_task_context(task_ctx)
         task_ctx.set_current_state(TaskState.RUNNING)
+        if node.node_id in skip_node_ids:
+            task_ctx.set_current_state(TaskState.SKIP)
+            task_ctx.set_task_output(SimpleTaskOutput(None))
+            node_outputs[node.node_id] = task_ctx
+            return
         try:
-            # print(f"Begin run {node}")
+            logger.info(
+                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
+            )
             await node._run(dag_ctx)
             node_outputs[node.node_id] = dag_ctx.current_task_context
             task_ctx.set_current_state(TaskState.SUCCESS)
+
+            if isinstance(node, BranchOperator):
+                skip_nodes = task_ctx.metadata.get("skip_node_names", [])
+                logger.info(
+                    f"Current is branch operator, skip node names: {skip_nodes}"
+                )
+                _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
         except Exception as e:
+            logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
             task_ctx.set_current_state(TaskState.FAILED)
             raise e
+
+
+def _skip_current_downstream_by_node_name(
+    branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
+):
+    if not skip_nodes:
+        return
+    for child in branch_node.downstream:
+        if child.node_name in skip_nodes:
+            logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
+            _skip_downstream_by_id(child, skip_node_ids)
+
+
+def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
+    if isinstance(node, JoinOperator):
+        # Not skip join node
+        return
+    skip_node_ids.add(node.node_id)
+    for child in node.downstream:
+        _skip_downstream_by_id(child, skip_node_ids)

View File

@@ -81,6 +81,10 @@ class TaskOutput(ABC, Generic[T]):
             output_data (Union[T, AsyncIterator[T]]): Output data.
         """

+    @abstractmethod
+    def new_output(self) -> "TaskOutput[T]":
+        """Create new output object"""
+
     async def map(self, map_func) -> "TaskOutput[T]":
         """Apply a mapping function to the task's output.
@@ -334,13 +338,18 @@ class InputContext(ABC):
         """
         return len(self.parent_outputs) == 1

-    def check_stream(self) -> bool:
+    def check_stream(self, skip_empty: bool = False) -> bool:
         """Check if all parent outputs are streams.

+        Args:
+            skip_empty (bool): Skip empty output or not.
+
         Returns:
             bool: True if all parent outputs are streams, False otherwise.
         """
         for out in self.parent_outputs:
+            if out.task_output.is_empty and skip_empty:
+                continue
             if not (out.task_output and out.task_output.is_stream):
                 return False
         return True

View File

@@ -13,10 +13,13 @@ from typing import (
     Union,
 )
 import asyncio
+import logging

 from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T

+logger = logging.getLogger(__name__)
+

 async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
     # Init accumulator
     try:
@@ -44,6 +47,9 @@ class SimpleTaskOutput(TaskOutput[T], Generic[T]):
     def set_output(self, output_data: T | AsyncIterator[T]) -> None:
         self._data = output_data

+    def new_output(self) -> TaskOutput[T]:
+        return SimpleTaskOutput(None)
+
     @property
     def is_empty(self) -> bool:
         return not self._data
@@ -89,6 +95,9 @@ class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
     def set_output(self, output_data: T | AsyncIterator[T]) -> None:
         self._data = output_data

+    def new_output(self) -> TaskOutput[T]:
+        return SimpleStreamTaskOutput(None)
+
     async def map(self, map_func) -> TaskOutput[T]:
         is_async = asyncio.iscoroutinefunction(map_func)
@@ -213,7 +222,8 @@ class DefaultTaskContext(TaskContext, Generic[T]):
         self._task_state = task_state

     def new_ctx(self) -> TaskContext:
-        return DefaultTaskContext(self._task_id, self._task_state, self._output)
+        new_output = self._output.new_output()
+        return DefaultTaskContext(self._task_id, self._task_state, new_output)

     @property
     def metadata(self) -> Dict[str, Any]:
@@ -259,9 +269,17 @@ class DefaultInputContext(InputContext):
     async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
         if not self._outputs:
             return DefaultInputContext([])
-        is_steam = self._outputs[0].task_output.is_stream
+        # Some parent may be empty
+        not_empty_idx = 0
+        for i, p in enumerate(self._outputs):
+            if p.task_output.is_empty:
+                continue
+            not_empty_idx = i
+            break
+        # All output is empty?
+        is_steam = self._outputs[not_empty_idx].task_output.is_stream
         if is_steam:
-            if not self.check_stream():
+            if not self.check_stream(skip_empty=True):
                 raise ValueError(
                     "The output in all tasks must has same output format to map_all"
                 )
@@ -275,9 +293,11 @@ class DefaultInputContext(InputContext):
             map_res = await map_func(*outputs)
         else:
             map_res = map_func(*outputs)
-
-        single_output: TaskContext = self._outputs[0].new_ctx()
+        single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
         single_output.task_output.set_output(map_res)
+        logger.debug(
+            f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
+        )
         return DefaultInputContext([single_output])

     async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
@@ -311,6 +331,7 @@ class DefaultInputContext(InputContext):
         for i, task_ctx in enumerate(new_outputs):
             task_ctx: TaskContext = task_ctx
             if results[i]:
+                task_ctx.task_output.set_output(True)
                 result_outputs.append(task_ctx)
             else:
                 task_ctx.task_output.set_output(failed_value)

View File

@@ -108,21 +108,34 @@ async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator
     "input_node, is_odd",
     [
         ({"outputs": [0]}, False),
-        # ({"outputs": [1]}, False),
+        ({"outputs": [1]}, True),
     ],
     indirect=["input_node"],
 )
 async def test_branch_node(
     runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
 ):
+    def join_func(o1, o2) -> int:
+        print(f"join func result, o1: {o1}, o2: {o2}")
+        return o1 or o2
+
     with DAG("test_join_node") as dag:
-        odd_node = MapOperator(lambda x: 999, task_id="odd_node")
-        even_node = MapOperator(lambda x: 888, task_id="even_node")
+        odd_node = MapOperator(
+            lambda x: 999, task_id="odd_node", task_name="odd_node_name"
+        )
+        even_node = MapOperator(
+            lambda x: 888, task_id="even_node", task_name="even_node_name"
+        )
+        join_node = JoinOperator(join_func)
         branch_node = BranchOperator(
             {lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
         )
+        branch_node >> odd_node >> join_node
+        branch_node >> even_node >> join_node
         input_node >> branch_node
-    odd_res: DAGContext[int] = await runner.execute_workflow(odd_node)
-    even_res: DAGContext[int] = await runner.execute_workflow(even_node)
-    assert branch_node.current_task_context.current_state == TaskState.SUCCESS
+    res: DAGContext[int] = await runner.execute_workflow(join_node)
+    assert res.current_task_context.current_state == TaskState.SUCCESS
+    expect_res = 999 if is_odd else 888
+    assert res.current_task_context.task_output.output == expect_res

View File

@@ -1,10 +1,13 @@
-from typing import AsyncIterator, Dict
+from typing import AsyncIterator, Dict, Union
 import logging

 from pilot.awel import (
+    BranchFunc,
     StreamifyAbsOperator,
+    BranchOperator,
     MapOperator,
     TransformStreamAbsOperator,
 )
+from pilot.awel.operator.base import BaseOperator
 from pilot.model.base import ModelOutput
 from pilot.model.cluster import WorkerManager
 from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
@@ -16,71 +19,189 @@ _LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"

 class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
+    """Operator for streaming processing of model outputs.
+
+    Args:
+        worker_manager (WorkerManager): The manager that handles worker processes for model inference.
+        **kwargs: Additional keyword arguments.
+
+    Methods:
+        streamify: Asynchronously processes a stream of inputs, yielding model outputs.
+    """
+
     def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
         super().__init__(**kwargs)
         self.worker_manager = worker_manager

     async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
-        llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
-            _LLM_MODEL_OUTPUT_CACHE_KEY
-        )
-        logger.info(f"llm_cache_value: {llm_cache_value}")
-        if llm_cache_value:
-            for out in llm_cache_value.get_value().output:
-                yield out
-            return
+        """Process inputs as a stream and yield model outputs.
+
+        Args:
+            input_value (Dict): The input value for the model.
+
+        Returns:
+            AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
+        """
         async for out in self.worker_manager.generate_stream(input_value):
             yield out


 class ModelOperator(MapOperator[Dict, ModelOutput]):
+    """Operator for map-based processing of model outputs.
+
+    Args:
+        worker_manager (WorkerManager): Manager for handling worker processes.
+        **kwargs: Additional keyword arguments.
+
+    Methods:
+        map: Asynchronously processes a single input and returns the model output.
+    """
+
     def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
         self.worker_manager = worker_manager
         super().__init__(**kwargs)

     async def map(self, input_value: Dict) -> ModelOutput:
-        llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
-            _LLM_MODEL_OUTPUT_CACHE_KEY
-        )
-        logger.info(f"llm_cache_value: {llm_cache_value}")
-        if llm_cache_value:
-            return llm_cache_value.get_value().output
+        """Process a single input and return the model output.
+
+        Args:
+            input_value (Dict): The input value for the model.
+
+        Returns:
+            ModelOutput: The output from the model.
+        """
         return await self.worker_manager.generate(input_value)


-class ModelCachePreOperator(MapOperator[Dict, Dict]):
-    def __init__(self, cache_manager: CacheManager, **kwargs):
+class CachedModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
+    """Operator for streaming processing of model outputs with caching.
+
+    Args:
+        cache_manager (CacheManager): The cache manager to handle caching operations.
+        **kwargs: Additional keyword arguments.
+
+    Methods:
+        streamify: Processes a stream of inputs with cache support, yielding model outputs.
+    """
+
+    def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
+        super().__init__(**kwargs)
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)
-        super().__init__(**kwargs)

-    async def map(self, input_value: Dict) -> Dict:
-        cache_dict = {
-            "prompt": input_value.get("prompt"),
-            "model_name": input_value.get("model"),
-            "temperature": input_value.get("temperature"),
-            "max_new_tokens": input_value.get("max_new_tokens"),
-            "top_p": input_value.get("top_p", "1.0"),
-            # TODO pass model_type
-            "model_type": input_value.get("model_type", "huggingface"),
-        }
-        cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
-        cache_value = await self._client.get(cache_key)
-        logger.debug(
-            f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
-        )
-        await self.current_dag_context.save_to_share_data(
-            _LLM_MODEL_INPUT_VALUE_KEY, cache_key
-        )
-        if cache_value:
-            logger.info(f"The model output has cached, cache_value: {cache_value}")
-            await self.current_dag_context.save_to_share_data(
-                _LLM_MODEL_OUTPUT_CACHE_KEY, cache_value
-            )
-        return input_value
+    async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
+        """Process inputs as a stream with cache support and yield model outputs.
+
+        Args:
+            input_value (Dict): The input value for the model.
+
+        Returns:
+            AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
+        """
+        cache_dict = _parse_cache_key_dict(input_value)
+        llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
+        llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
+        logger.info(f"llm_cache_value: {llm_cache_value}")
+        for out in llm_cache_value.get_value().output:
+            yield out
+
+
+class CachedModelOperator(MapOperator[Dict, ModelOutput]):
+    """Operator for map-based processing of model outputs with caching.
+
+    Args:
+        cache_manager (CacheManager): Manager for caching operations.
+        **kwargs: Additional keyword arguments.
+
+    Methods:
+        map: Processes a single input with cache support and returns the model output.
+    """
+
+    def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self._cache_manager = cache_manager
+        self._client = LLMCacheClient(cache_manager)
+
+    async def map(self, input_value: Dict) -> ModelOutput:
+        """Process a single input with cache support and return the model output.
+
+        Args:
+            input_value (Dict): The input value for the model.
+
+        Returns:
+            ModelOutput: The output from the model.
+        """
+        cache_dict = _parse_cache_key_dict(input_value)
+        llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
+        llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
+        logger.info(f"llm_cache_value: {llm_cache_value}")
+        return llm_cache_value.get_value().output
+
+
+class ModelCacheBranchOperator(BranchOperator[Dict, Dict]):
+    """
+    A branch operator that decides whether to use cached data or to process data using the model.
+
+    Args:
+        cache_manager (CacheManager): The cache manager for managing cache operations.
+        model_task_name (str): The name of the task to process data using the model.
+        cache_task_name (str): The name of the task to process data using the cache.
+        **kwargs: Additional keyword arguments.
+    """
+
+    def __init__(
+        self,
+        cache_manager: CacheManager,
+        model_task_name: str,
+        cache_task_name: str,
+        **kwargs,
+    ):
+        super().__init__(branches=None, **kwargs)
+        self._cache_manager = cache_manager
+        self._client = LLMCacheClient(cache_manager)
+        self._model_task_name = model_task_name
+        self._cache_task_name = cache_task_name
+
+    async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
+        """Defines branch logic based on cache availability.
+
+        Returns:
+            Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names.
+        """
+
+        async def check_cache_true(input_value: Dict) -> bool:
+            # Check if the cache contains the result for the given input
+            cache_dict = _parse_cache_key_dict(input_value)
+            cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
+            cache_value = await self._client.get(cache_key)
+            logger.debug(
+                f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
+            )
+            await self.current_dag_context.save_to_share_data(
+                _LLM_MODEL_INPUT_VALUE_KEY, cache_key
+            )
+            return True if cache_value else False
+
+        async def check_cache_false(input_value: Dict):
+            # Inverse of check_cache_true
+            return not await check_cache_true(input_value)
+
+        return {
+            check_cache_true: self._cache_task_name,
+            check_cache_false: self._model_task_name,
+        }


-class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutput]):
+class ModelStreamSaveCacheOperator(
+    TransformStreamAbsOperator[ModelOutput, ModelOutput]
+):
+    """An operator to save the stream of model outputs to cache.
+
+    Args:
+        cache_manager (CacheManager): The cache manager for handling cache operations.
+        **kwargs: Additional keyword arguments.
+    """
+
     def __init__(self, cache_manager: CacheManager, **kwargs):
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)
@@ -89,6 +210,14 @@ class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutp
     async def transform_stream(
         self, input_value: AsyncIterator[ModelOutput]
     ) -> AsyncIterator[ModelOutput]:
+        """Transforms the input stream by saving the outputs to cache.
+
+        Args:
+            input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs.
+
+        Returns:
+            AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache.
+        """
         llm_cache_key: LLMCacheKey = None
         outputs = []
         async for out in input_value:
@@ -103,13 +232,28 @@ class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutp
         await self._client.set(llm_cache_key, llm_cache_value)


-class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
+class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
+    """An operator to save a single model output to cache.
+
+    Args:
+        cache_manager (CacheManager): The cache manager for handling cache operations.
+        **kwargs: Additional keyword arguments.
+    """
+
     def __init__(self, cache_manager: CacheManager, **kwargs):
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)
         super().__init__(**kwargs)

     async def map(self, input_value: ModelOutput) -> ModelOutput:
+        """Saves a single model output to cache and returns it.
+
+        Args:
+            input_value (ModelOutput): The output from the model to be cached.
+
+        Returns:
+            ModelOutput: The same input model output.
+        """
         llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
             _LLM_MODEL_INPUT_VALUE_KEY
         )
@@ -117,3 +261,26 @@ class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
         if llm_cache_key:
             await self._client.set(llm_cache_key, llm_cache_value)
         return input_value
+
+
+def _parse_cache_key_dict(input_value: Dict) -> Dict:
+    """Parses and extracts relevant fields from input to form a cache key dictionary.
+
+    Args:
+        input_value (Dict): The input dictionary containing model and prompt parameters.
+
+    Returns:
+        Dict: A dictionary used for generating cache keys.
+    """
+    prompt: str = input_value.get("prompt")
+    if prompt:
+        prompt = prompt.strip()
+    return {
+        "prompt": prompt,
+        "model_name": input_value.get("model"),
+        "temperature": input_value.get("temperature"),
+        "max_new_tokens": input_value.get("max_new_tokens"),
+        "top_p": input_value.get("top_p", "1.0"),
+        # TODO pass model_type
+        "model_type": input_value.get("model_type", "huggingface"),
+    }

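The ModelCacheBranchOperator above illustrates the new lazy-branch pattern: pass branches=None to the constructor and resolve the predicate-to-task-name mapping later in branchs(). A stripped-down sketch of that pattern follows; the class and task names are illustrative, not part of this commit.

    # Minimal sketch of a lazily-resolved branch, mirroring ModelCacheBranchOperator.
    from typing import Dict, Union

    from pilot.awel import BranchFunc, BranchOperator
    from pilot.awel.operator.base import BaseOperator

    class EvenOddBranchOperator(BranchOperator[int, int]):
        def __init__(self, even_task_name: str, odd_task_name: str, **kwargs):
            # No branches at construction time; they are resolved in branchs().
            super().__init__(branches=None, **kwargs)
            self._even_task_name = even_task_name
            self._odd_task_name = odd_task_name

        async def branchs(self) -> Dict[BranchFunc[int], Union[BaseOperator, str]]:
            async def is_even(value: int) -> bool:
                return value % 2 == 0

            async def is_odd(value: int) -> bool:
                return not await is_even(value)

            # Map each predicate to the task_name of the downstream operator it enables.
            return {is_even: self._even_task_name, is_odd: self._odd_task_name}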
View File

@@ -566,29 +566,83 @@ class BaseChat(ABC):
 def _build_model_operator(
     is_stream: bool = False, dag_name: str = "llm_model_dag"
 ) -> BaseOperator:
+    """Builds and returns a model processing workflow (DAG) operator.
+
+    This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
+    It includes caching and branching logic to either fetch results from a cache or process
+    data using the model. It supports both streaming and non-streaming modes.
+
+    .. code-block:: python
+
+        input_node >> cache_check_branch_node
+        cache_check_branch_node >> model_node >> save_cached_node >> join_node
+        cache_check_branch_node >> cached_node >> join_node
+
+    equivalent to::
+
+                                      -> model_node -> save_cached_node ->
+                                     /                                    \
+        input_node -> cache_check_branch_node                              ---> join_node
+                                     \                                    /
+                                      -> cached_node ------------------- ->
+
+    Args:
+        is_stream (bool): Flag to determine if the operator should process data in streaming mode.
+        dag_name (str): Name of the DAG.
+
+    Returns:
+        BaseOperator: The final operator in the constructed DAG, typically a join node.
+    """
     from pilot.model.cluster import WorkerManagerFactory
+    from pilot.awel import JoinOperator
     from pilot.model.operator.model_operator import (
-        ModelCacheOperator,
-        ModelStreamCacheOperator,
-        ModelCachePreOperator,
+        ModelCacheBranchOperator,
+        CachedModelStreamOperator,
+        CachedModelOperator,
+        ModelSaveCacheOperator,
+        ModelStreamSaveCacheOperator,
     )
     from pilot.cache import CacheManager

+    # Fetch worker and cache managers from the system configuration
     worker_manager = CFG.SYSTEM_APP.get_component(
         ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
     ).create()
     cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
         ComponentType.MODEL_CACHE_MANAGER, CacheManager
     )
+    # Define task names for the model and cache nodes
+    model_task_name = "llm_model_node"
+    cache_task_name = "llm_model_cache_node"
+
     with DAG(dag_name):
+        # Create an input node
         input_node = InputOperator(SimpleCallDataInputSource())
-        cache_check_node = ModelCachePreOperator(cache_manager)
+        # Determine if the workflow should operate in streaming mode
         if is_stream:
-            model_node = ModelStreamOperator(worker_manager)
-            cache_node = ModelStreamCacheOperator(cache_manager)
+            model_node = ModelStreamOperator(worker_manager, task_name=model_task_name)
+            cached_node = CachedModelStreamOperator(
+                cache_manager, task_name=cache_task_name
+            )
+            save_cached_node = ModelStreamSaveCacheOperator(cache_manager)
         else:
-            model_node = ModelOperator(worker_manager)
-            cache_node = ModelCacheOperator(cache_manager)
+            model_node = ModelOperator(worker_manager, task_name=model_task_name)
+            cached_node = CachedModelOperator(cache_manager, task_name=cache_task_name)
+            save_cached_node = ModelSaveCacheOperator(cache_manager)

-        input_node >> cache_check_node >> model_node >> cache_node
-    return cache_node
+        # Create a branch node to decide between fetching from cache or processing with the model
+        cache_check_branch_node = ModelCacheBranchOperator(
+            cache_manager,
+            model_task_name="llm_model_node",
+            cache_task_name="llm_model_cache_node",
+        )
+        # Create a join node to merge outputs from the model and cache nodes, just keep the fist not empty output
+        join_node = JoinOperator(
+            combine_function=lambda model_out, cache_out: cache_out or model_out
+        )
+
+        # Define the workflow structure using the >> operator
+        input_node >> cache_check_branch_node
+        cache_check_branch_node >> model_node >> save_cached_node >> join_node
+        cache_check_branch_node >> cached_node >> join_node
+
+    return join_node
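To round out the picture, here is a hedged sketch of how the operator returned by _build_model_operator might be driven with the runner touched in this commit. The keyword name call_data and the {"data": ...} wrapping expected by SimpleCallDataInputSource are assumptions, not taken from the diff.

    # Illustrative only; parameter names and call_data shape are assumptions.
    from pilot.awel import DefaultWorkflowRunner

    async def generate_with_cache(payload: dict):
        end_node = _build_model_operator(is_stream=False)
        runner = DefaultWorkflowRunner()
        # On a cache hit the branch skips the model path; on a miss it skips the
        # cached path. The JoinOperator keeps whichever output is not empty.
        dag_ctx = await runner.execute_workflow(end_node, call_data={"data": payload})
        return dag_ctx.current_task_context.task_output.output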