feat(model): Support BranchOperator

This commit is contained in:
FangYin Cheng
2023-11-16 18:13:16 +08:00
parent 6db8c49d87
commit 1150adbe6a
11 changed files with 452 additions and 102 deletions

View File

@@ -9,6 +9,7 @@ from .operator.common_operator import (
MapOperator,
BranchOperator,
InputOperator,
BranchFunc,
)
from .operator.stream_operator import (
@@ -25,6 +26,7 @@ from .task.task_impl import (
DefaultInputContext,
SimpleTaskOutput,
SimpleStreamTaskOutput,
_is_async_iterator,
)
from .runner.local_runner import DefaultWorkflowRunner
@@ -38,6 +40,7 @@ __all__ = [
"MapOperator",
"BranchOperator",
"InputOperator",
"BranchFunc",
"WorkflowRunner",
"TaskState",
"TaskOutput",

View File

@@ -143,7 +143,9 @@ class DAGNode(DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""
def __init__(self, dag: Optional["DAG"] = None, node_id: str = None) -> None:
def __init__(
self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None
) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
@@ -151,6 +153,7 @@ class DAGNode(DependencyMixin, ABC):
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id
self._node_name: str = node_name
@property
def node_id(self) -> str:
@@ -159,6 +162,21 @@ class DAGNode(DependencyMixin, ABC):
def set_node_id(self, node_id: str) -> None:
self._node_id = node_id
def __hash__(self) -> int:
if self.node_id:
return hash(self.node_id)
else:
return super().__hash__()
def __eq__(self, other: Any) -> bool:
if not isinstance(other, DAGNode):
return False
return self.node_id == other.node_id
@property
def node_name(self) -> str:
return self._node_name
@property
def dag(self) -> "DAGNode":
return self._dag

View File

@@ -100,6 +100,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
def __init__(
self,
task_id: Optional[str] = None,
task_name: Optional[str] = None,
dag: Optional[DAG] = None,
runner: WorkflowRunner = None,
**kwargs,
@@ -109,7 +110,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
Args:
runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
"""
super().__init__(node_id=task_id, dag=dag, **kwargs)
super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
if not runner:
from pilot.awel import DefaultWorkflowRunner

View File

@@ -1,5 +1,6 @@
from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
import asyncio
import logging
from ..dag.base import DAGContext
from ..task.base import (
@@ -14,6 +15,9 @@ from ..task.base import (
from .base import BaseOperator
logger = logging.getLogger(__name__)
class JoinOperator(BaseOperator, Generic[OUT]):
"""Operator that joins inputs using a custom combine function.
@@ -141,31 +145,36 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
raise NotImplementedError
BranchFunc = Union[Callable[[Any], bool], Callable[[Any], Awaitable[bool]]]
BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
class BranchOperator(BaseOperator, Generic[OUT]):
class BranchOperator(BaseOperator, Generic[IN, OUT]):
"""Operator node that branches the workflow based on a provided function.
This node filters its input data using a branching function and
allows for conditional paths in the workflow.
"""
def __init__(self, branches: Dict[BranchFunc, BaseOperator], **kwargs):
def __init__(
self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
):
"""
Initializes a BranchDAGNode with a branching function.
Args:
branches (Dict[BranchFunc, RunnableDAGNode]): Dict of function that defines the branching condition.
branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of function that defines the branching condition.
Raises:
ValueError: If the branch_function is not callable.
"""
super().__init__(**kwargs)
for branch_function in branches.keys():
if branches:
for branch_function, value in branches.items():
if not callable(branch_function):
raise ValueError("branch_function must be callable")
self.branches = branches
if isinstance(value, BaseOperator):
branches[branch_function] = value.node_name or value.node_name
self._branches = branches
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
"""Run the branching operation on the DAG context's inputs.
@@ -186,27 +195,36 @@ class BranchOperator(BaseOperator, Generic[OUT]):
if not task_input.check_single_parent():
raise ValueError("BranchDAGNode expects single parent")
branches = self._branches
if not branches:
branches = await self.branchs()
branch_func_tasks = []
branch_nodes: List[BaseOperator] = []
for func, node in self.branches.items():
branch_nodes.append(node)
branch_nodes: List[str] = []
for func, node_name in branches.items():
branch_nodes.append(node_name)
branch_func_tasks.append(
curr_task_ctx.task_input.predicate_map(func, failed_value=None)
)
branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
parent_output = task_input.parent_outputs[0].task_output
curr_task_ctx.set_task_output(parent_output)
skip_node_names = []
for i, ctx in enumerate(branch_input_ctxs):
node = branch_nodes[i]
node_name = branch_nodes[i]
branch_out = ctx.parent_outputs[0].task_output
logger.info(
f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
)
if ctx.parent_outputs[0].task_output.is_empty:
# Skip current node
# node.current_task_context.set_current_state(TaskState.SKIP)
pass
else:
pass
logger.info(f"Skip node name {node_name}")
skip_node_names.append(node_name)
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
return parent_output
async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
raise NotImplementedError
return None
class InputOperator(BaseOperator, Generic[OUT]):

View File

@@ -1,9 +1,12 @@
from typing import List, Set, Optional, Dict
import uuid
import logging
from ..dag.base import DAG
from ..operator.base import BaseOperator, CALL_DATA
logger = logging.getLogger(__name__)
class DAGNodeInstance:
def __init__(self, node_instance: DAG) -> None:
@@ -45,14 +48,19 @@ def _save_call_data(
root_nodes: List[BaseOperator], call_data: CALL_DATA
) -> Dict[str, Dict]:
id2call_data = {}
logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
if not call_data:
return id2call_data
if len(root_nodes) == 1:
node = root_nodes[0]
logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
id2call_data[node.node_id] = call_data
else:
for node in root_nodes:
node_id = node.node_id
logger.info(
f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
)
id2call_data[node_id] = call_data.get(node_id)
return id2call_data
@@ -71,4 +79,4 @@ def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
return list(filter(lambda x: not x.upstream, nodes))
return list(set(filter(lambda x: not x.upstream, nodes)))

View File

@@ -1,10 +1,15 @@
from typing import Dict, Optional
from typing import Dict, Optional, Set, List
import logging
from ..dag.base import DAGContext
from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
from ..operator.common_operator import BranchOperator, JoinOperator
from ..task.base import TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager
logger = logging.getLogger(__name__)
class DefaultWorkflowRunner(WorkflowRunner):
async def execute_workflow(
@@ -13,10 +18,16 @@ class DefaultWorkflowRunner(WorkflowRunner):
# Create DAG context
dag_ctx = DAGContext()
job_manager = JobManager.build_from_end_node(node, call_data)
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
)
dag = node.dag
# Save node output
node_outputs: Dict[str, TaskContext] = {}
await self._execute_node(job_manager, node, dag_ctx, node_outputs)
skip_node_ids = set()
await self._execute_node(
job_manager, node, dag_ctx, node_outputs, skip_node_ids
)
return dag_ctx
@@ -26,6 +37,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
node: BaseOperator,
dag_ctx: DAGContext,
node_outputs: Dict[str, TaskContext],
skip_node_ids: Set[str],
):
# Skip run node
if node.node_id in node_outputs:
@@ -35,18 +47,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
for upstream_node in node.upstream:
if isinstance(upstream_node, BaseOperator):
await self._execute_node(
job_manager, upstream_node, dag_ctx, node_outputs
job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
)
# if node.current_task_context.current_state == TaskState.SKIP:
# return
# for upstream_node in node.upstream:
# if (
# isinstance(upstream_node, BaseOperator)
# and upstream_node.current_task_context.current_state == TaskState.SKIP
# ):
# return
# Get the input from upstream node
inputs = [
node_outputs[upstream_node.node_id] for upstream_node in node.upstream
]
@@ -56,13 +59,48 @@ class DefaultWorkflowRunner(WorkflowRunner):
task_ctx.set_task_input(input_ctx)
dag_ctx.set_current_task_context(task_ctx)
task_ctx.set_current_state(TaskState.RUNNING)
if node.node_id in skip_node_ids:
task_ctx.set_current_state(TaskState.SKIP)
task_ctx.set_task_output(SimpleTaskOutput(None))
node_outputs[node.node_id] = task_ctx
return
try:
# print(f"Begin run {node}")
logger.info(
f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
)
await node._run(dag_ctx)
node_outputs[node.node_id] = dag_ctx.current_task_context
task_ctx.set_current_state(TaskState.SUCCESS)
if isinstance(node, BranchOperator):
skip_nodes = task_ctx.metadata.get("skip_node_names", [])
logger.info(
f"Current is branch operator, skip node names: {skip_nodes}"
)
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
except Exception as e:
logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
task_ctx.set_current_state(TaskState.FAILED)
raise e
def _skip_current_downstream_by_node_name(
branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
):
if not skip_nodes:
return
for child in branch_node.downstream:
if child.node_name in skip_nodes:
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
_skip_downstream_by_id(child, skip_node_ids)
def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
if isinstance(node, JoinOperator):
# Not skip join node
return
skip_node_ids.add(node.node_id)
for child in node.downstream:
_skip_downstream_by_id(child, skip_node_ids)

View File

@@ -81,6 +81,10 @@ class TaskOutput(ABC, Generic[T]):
output_data (Union[T, AsyncIterator[T]]): Output data.
"""
@abstractmethod
def new_output(self) -> "TaskOutput[T]":
"""Create new output object"""
async def map(self, map_func) -> "TaskOutput[T]":
"""Apply a mapping function to the task's output.
@@ -334,13 +338,18 @@ class InputContext(ABC):
"""
return len(self.parent_outputs) == 1
def check_stream(self) -> bool:
def check_stream(self, skip_empty: bool = False) -> bool:
"""Check if all parent outputs are streams.
Args:
skip_empty (bool): Skip empty output or not.
Returns:
bool: True if all parent outputs are streams, False otherwise.
"""
for out in self.parent_outputs:
if out.task_output.is_empty and skip_empty:
continue
if not (out.task_output and out.task_output.is_stream):
return False
return True

View File

@@ -13,10 +13,13 @@ from typing import (
Union,
)
import asyncio
import logging
from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
logger = logging.getLogger(__name__)
async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
# Init accumulator
try:
@@ -44,6 +47,9 @@ class SimpleTaskOutput(TaskOutput[T], Generic[T]):
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
self._data = output_data
def new_output(self) -> TaskOutput[T]:
return SimpleTaskOutput(None)
@property
def is_empty(self) -> bool:
return not self._data
@@ -89,6 +95,9 @@ class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
def set_output(self, output_data: T | AsyncIterator[T]) -> None:
self._data = output_data
def new_output(self) -> TaskOutput[T]:
return SimpleStreamTaskOutput(None)
async def map(self, map_func) -> TaskOutput[T]:
is_async = asyncio.iscoroutinefunction(map_func)
@@ -213,7 +222,8 @@ class DefaultTaskContext(TaskContext, Generic[T]):
self._task_state = task_state
def new_ctx(self) -> TaskContext:
return DefaultTaskContext(self._task_id, self._task_state, self._output)
new_output = self._output.new_output()
return DefaultTaskContext(self._task_id, self._task_state, new_output)
@property
def metadata(self) -> Dict[str, Any]:
@@ -259,9 +269,17 @@ class DefaultInputContext(InputContext):
async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
if not self._outputs:
return DefaultInputContext([])
is_steam = self._outputs[0].task_output.is_stream
# Some parent may be empty
not_empty_idx = 0
for i, p in enumerate(self._outputs):
if p.task_output.is_empty:
continue
not_empty_idx = i
break
# All output is empty?
is_steam = self._outputs[not_empty_idx].task_output.is_stream
if is_steam:
if not self.check_stream():
if not self.check_stream(skip_empty=True):
raise ValueError(
"The output in all tasks must has same output format to map_all"
)
@@ -275,9 +293,11 @@ class DefaultInputContext(InputContext):
map_res = await map_func(*outputs)
else:
map_res = map_func(*outputs)
single_output: TaskContext = self._outputs[0].new_ctx()
single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
single_output.task_output.set_output(map_res)
logger.debug(
f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
)
return DefaultInputContext([single_output])
async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
@@ -311,6 +331,7 @@ class DefaultInputContext(InputContext):
for i, task_ctx in enumerate(new_outputs):
task_ctx: TaskContext = task_ctx
if results[i]:
task_ctx.task_output.set_output(True)
result_outputs.append(task_ctx)
else:
task_ctx.task_output.set_output(failed_value)

View File

@@ -108,21 +108,34 @@ async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator
"input_node, is_odd",
[
({"outputs": [0]}, False),
# ({"outputs": [1]}, False),
({"outputs": [1]}, True),
],
indirect=["input_node"],
)
async def test_branch_node(
runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
):
def join_func(o1, o2) -> int:
print(f"join func result, o1: {o1}, o2: {o2}")
return o1 or o2
with DAG("test_join_node") as dag:
odd_node = MapOperator(lambda x: 999, task_id="odd_node")
even_node = MapOperator(lambda x: 888, task_id="even_node")
odd_node = MapOperator(
lambda x: 999, task_id="odd_node", task_name="odd_node_name"
)
even_node = MapOperator(
lambda x: 888, task_id="even_node", task_name="even_node_name"
)
join_node = JoinOperator(join_func)
branch_node = BranchOperator(
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
)
branch_node >> odd_node >> join_node
branch_node >> even_node >> join_node
input_node >> branch_node
odd_res: DAGContext[int] = await runner.execute_workflow(odd_node)
even_res: DAGContext[int] = await runner.execute_workflow(even_node)
assert branch_node.current_task_context.current_state == TaskState.SUCCESS
res: DAGContext[int] = await runner.execute_workflow(join_node)
assert res.current_task_context.current_state == TaskState.SUCCESS
expect_res = 999 if is_odd else 888
assert res.current_task_context.task_output.output == expect_res

View File

@@ -1,10 +1,13 @@
from typing import AsyncIterator, Dict
from typing import AsyncIterator, Dict, Union
import logging
from pilot.awel import (
BranchFunc,
StreamifyAbsOperator,
BranchOperator,
MapOperator,
TransformStreamAbsOperator,
)
from pilot.awel.operator.base import BaseOperator
from pilot.model.base import ModelOutput
from pilot.model.cluster import WorkerManager
from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
@@ -16,54 +19,159 @@ _LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"
class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
"""Operator for streaming processing of model outputs.
Args:
worker_manager (WorkerManager): The manager that handles worker processes for model inference.
**kwargs: Additional keyword arguments.
Methods:
streamify: Asynchronously processes a stream of inputs, yielding model outputs.
"""
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
super().__init__(**kwargs)
self.worker_manager = worker_manager
async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
_LLM_MODEL_OUTPUT_CACHE_KEY
)
logger.info(f"llm_cache_value: {llm_cache_value}")
if llm_cache_value:
for out in llm_cache_value.get_value().output:
yield out
return
"""Process inputs as a stream and yield model outputs.
Args:
input_value (Dict): The input value for the model.
Returns:
AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
"""
async for out in self.worker_manager.generate_stream(input_value):
yield out
class ModelOperator(MapOperator[Dict, ModelOutput]):
"""Operator for map-based processing of model outputs.
Args:
worker_manager (WorkerManager): Manager for handling worker processes.
**kwargs: Additional keyword arguments.
Methods:
map: Asynchronously processes a single input and returns the model output.
"""
def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
self.worker_manager = worker_manager
super().__init__(**kwargs)
async def map(self, input_value: Dict) -> ModelOutput:
llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
_LLM_MODEL_OUTPUT_CACHE_KEY
)
logger.info(f"llm_cache_value: {llm_cache_value}")
if llm_cache_value:
return llm_cache_value.get_value().output
"""Process a single input and return the model output.
Args:
input_value (Dict): The input value for the model.
Returns:
ModelOutput: The output from the model.
"""
return await self.worker_manager.generate(input_value)
class ModelCachePreOperator(MapOperator[Dict, Dict]):
def __init__(self, cache_manager: CacheManager, **kwargs):
class CachedModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
"""Operator for streaming processing of model outputs with caching.
Args:
cache_manager (CacheManager): The cache manager to handle caching operations.
**kwargs: Additional keyword arguments.
Methods:
streamify: Processes a stream of inputs with cache support, yielding model outputs.
"""
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
super().__init__(**kwargs)
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def map(self, input_value: Dict) -> Dict:
cache_dict = {
"prompt": input_value.get("prompt"),
"model_name": input_value.get("model"),
"temperature": input_value.get("temperature"),
"max_new_tokens": input_value.get("max_new_tokens"),
"top_p": input_value.get("top_p", "1.0"),
# TODO pass model_type
"model_type": input_value.get("model_type", "huggingface"),
}
async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
"""Process inputs as a stream with cache support and yield model outputs.
Args:
input_value (Dict): The input value for the model.
Returns:
AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
"""
cache_dict = _parse_cache_key_dict(input_value)
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
logger.info(f"llm_cache_value: {llm_cache_value}")
for out in llm_cache_value.get_value().output:
yield out
class CachedModelOperator(MapOperator[Dict, ModelOutput]):
"""Operator for map-based processing of model outputs with caching.
Args:
cache_manager (CacheManager): Manager for caching operations.
**kwargs: Additional keyword arguments.
Methods:
map: Processes a single input with cache support and returns the model output.
"""
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
super().__init__(**kwargs)
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
async def map(self, input_value: Dict) -> ModelOutput:
"""Process a single input with cache support and return the model output.
Args:
input_value (Dict): The input value for the model.
Returns:
ModelOutput: The output from the model.
"""
cache_dict = _parse_cache_key_dict(input_value)
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
logger.info(f"llm_cache_value: {llm_cache_value}")
return llm_cache_value.get_value().output
class ModelCacheBranchOperator(BranchOperator[Dict, Dict]):
"""
A branch operator that decides whether to use cached data or to process data using the model.
Args:
cache_manager (CacheManager): The cache manager for managing cache operations.
model_task_name (str): The name of the task to process data using the model.
cache_task_name (str): The name of the task to process data using the cache.
**kwargs: Additional keyword arguments.
"""
def __init__(
self,
cache_manager: CacheManager,
model_task_name: str,
cache_task_name: str,
**kwargs,
):
super().__init__(branches=None, **kwargs)
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
self._model_task_name = model_task_name
self._cache_task_name = cache_task_name
async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
"""Defines branch logic based on cache availability.
Returns:
Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names.
"""
async def check_cache_true(input_value: Dict) -> bool:
# Check if the cache contains the result for the given input
cache_dict = _parse_cache_key_dict(input_value)
cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
cache_value = await self._client.get(cache_key)
logger.debug(
@@ -72,15 +180,28 @@ class ModelCachePreOperator(MapOperator[Dict, Dict]):
await self.current_dag_context.save_to_share_data(
_LLM_MODEL_INPUT_VALUE_KEY, cache_key
)
if cache_value:
logger.info(f"The model output has cached, cache_value: {cache_value}")
await self.current_dag_context.save_to_share_data(
_LLM_MODEL_OUTPUT_CACHE_KEY, cache_value
)
return input_value
return True if cache_value else False
async def check_cache_false(input_value: Dict):
# Inverse of check_cache_true
return not await check_cache_true(input_value)
return {
check_cache_true: self._cache_task_name,
check_cache_false: self._model_task_name,
}
class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutput]):
class ModelStreamSaveCacheOperator(
TransformStreamAbsOperator[ModelOutput, ModelOutput]
):
"""An operator to save the stream of model outputs to cache.
Args:
cache_manager (CacheManager): The cache manager for handling cache operations.
**kwargs: Additional keyword arguments.
"""
def __init__(self, cache_manager: CacheManager, **kwargs):
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
@@ -89,6 +210,14 @@ class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutp
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> AsyncIterator[ModelOutput]:
"""Transforms the input stream by saving the outputs to cache.
Args:
input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs.
Returns:
AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache.
"""
llm_cache_key: LLMCacheKey = None
outputs = []
async for out in input_value:
@@ -103,13 +232,28 @@ class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutp
await self._client.set(llm_cache_key, llm_cache_value)
class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
"""An operator to save a single model output to cache.
Args:
cache_manager (CacheManager): The cache manager for handling cache operations.
**kwargs: Additional keyword arguments.
"""
def __init__(self, cache_manager: CacheManager, **kwargs):
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def map(self, input_value: ModelOutput) -> ModelOutput:
"""Saves a single model output to cache and returns it.
Args:
input_value (ModelOutput): The output from the model to be cached.
Returns:
ModelOutput: The same input model output.
"""
llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
_LLM_MODEL_INPUT_VALUE_KEY
)
@@ -117,3 +261,26 @@ class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
if llm_cache_key:
await self._client.set(llm_cache_key, llm_cache_value)
return input_value
def _parse_cache_key_dict(input_value: Dict) -> Dict:
"""Parses and extracts relevant fields from input to form a cache key dictionary.
Args:
input_value (Dict): The input dictionary containing model and prompt parameters.
Returns:
Dict: A dictionary used for generating cache keys.
"""
prompt: str = input_value.get("prompt")
if prompt:
prompt = prompt.strip()
return {
"prompt": prompt,
"model_name": input_value.get("model"),
"temperature": input_value.get("temperature"),
"max_new_tokens": input_value.get("max_new_tokens"),
"top_p": input_value.get("top_p", "1.0"),
# TODO pass model_type
"model_type": input_value.get("model_type", "huggingface"),
}

View File

@@ -566,29 +566,83 @@ class BaseChat(ABC):
def _build_model_operator(
is_stream: bool = False, dag_name: str = "llm_model_dag"
) -> BaseOperator:
"""Builds and returns a model processing workflow (DAG) operator.
This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
It includes caching and branching logic to either fetch results from a cache or process
data using the model. It supports both streaming and non-streaming modes.
.. code-block:: python
input_node >> cache_check_branch_node
cache_check_branch_node >> model_node >> save_cached_node >> join_node
cache_check_branch_node >> cached_node >> join_node
equivalent to::
-> model_node -> save_cached_node ->
/ \
input_node -> cache_check_branch_node ---> join_node
\ /
-> cached_node ------------------- ->
Args:
is_stream (bool): Flag to determine if the operator should process data in streaming mode.
dag_name (str): Name of the DAG.
Returns:
BaseOperator: The final operator in the constructed DAG, typically a join node.
"""
from pilot.model.cluster import WorkerManagerFactory
from pilot.awel import JoinOperator
from pilot.model.operator.model_operator import (
ModelCacheOperator,
ModelStreamCacheOperator,
ModelCachePreOperator,
ModelCacheBranchOperator,
CachedModelStreamOperator,
CachedModelOperator,
ModelSaveCacheOperator,
ModelStreamSaveCacheOperator,
)
from pilot.cache import CacheManager
# Fetch worker and cache managers from the system configuration
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CACHE_MANAGER, CacheManager
)
# Define task names for the model and cache nodes
model_task_name = "llm_model_node"
cache_task_name = "llm_model_cache_node"
with DAG(dag_name):
# Create an input node
input_node = InputOperator(SimpleCallDataInputSource())
cache_check_node = ModelCachePreOperator(cache_manager)
# Determine if the workflow should operate in streaming mode
if is_stream:
model_node = ModelStreamOperator(worker_manager)
cache_node = ModelStreamCacheOperator(cache_manager)
model_node = ModelStreamOperator(worker_manager, task_name=model_task_name)
cached_node = CachedModelStreamOperator(
cache_manager, task_name=cache_task_name
)
save_cached_node = ModelStreamSaveCacheOperator(cache_manager)
else:
model_node = ModelOperator(worker_manager)
cache_node = ModelCacheOperator(cache_manager)
input_node >> cache_check_node >> model_node >> cache_node
return cache_node
model_node = ModelOperator(worker_manager, task_name=model_task_name)
cached_node = CachedModelOperator(cache_manager, task_name=cache_task_name)
save_cached_node = ModelSaveCacheOperator(cache_manager)
# Create a branch node to decide between fetching from cache or processing with the model
cache_check_branch_node = ModelCacheBranchOperator(
cache_manager,
model_task_name="llm_model_node",
cache_task_name="llm_model_cache_node",
)
# Create a join node to merge outputs from the model and cache nodes, just keep the fist not empty output
join_node = JoinOperator(
combine_function=lambda model_out, cache_out: cache_out or model_out
)
# Define the workflow structure using the >> operator
input_node >> cache_check_branch_node
cache_check_branch_node >> model_node >> save_cached_node >> join_node
cache_check_branch_node >> cached_node >> join_node
return join_node