mirror of https://github.com/csunny/DB-GPT.git
synced 2025-09-16 14:40:56 +00:00
feat(model): Support BranchOperator
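
This change makes the AWEL `BranchOperator` usable end to end. Branch targets are stored as task names (operator values are converted via `node_name`), branches whose predicate fails are recorded under the task metadata key "skip_node_names", and `DefaultWorkflowRunner` expands those names into node ids and short-circuits the skipped subtrees, stopping at `JoinOperator`. Supporting changes: `DAGNode` gains `node_name`, `__hash__`, and `__eq__`; `TaskOutput` gains `new_output()` so cloned task contexts no longer share one output object; and the model cache operators are reworked around a new `ModelCacheBranchOperator` that routes each request either to the model or to the cache.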

@@ -9,6 +9,7 @@ from .operator.common_operator import (
     MapOperator,
     BranchOperator,
     InputOperator,
+    BranchFunc,
 )
 
 from .operator.stream_operator import (

@@ -25,6 +26,7 @@ from .task.task_impl import (
     DefaultInputContext,
     SimpleTaskOutput,
     SimpleStreamTaskOutput,
+    _is_async_iterator,
 )
 from .runner.local_runner import DefaultWorkflowRunner
 

@@ -38,6 +40,7 @@ __all__ = [
     "MapOperator",
     "BranchOperator",
     "InputOperator",
+    "BranchFunc",
     "WorkflowRunner",
     "TaskState",
     "TaskOutput",

@@ -143,7 +143,9 @@ class DAGNode(DependencyMixin, ABC):
     resource_group: Optional[ResourceGroup] = None
     """The resource group of current DAGNode"""
 
-    def __init__(self, dag: Optional["DAG"] = None, node_id: str = None) -> None:
+    def __init__(
+        self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None
+    ) -> None:
         super().__init__()
         self._upstream: List["DAGNode"] = []
         self._downstream: List["DAGNode"] = []

@@ -151,6 +153,7 @@ class DAGNode(DependencyMixin, ABC):
         if not node_id and self._dag:
             node_id = self._dag._new_node_id()
         self._node_id: str = node_id
+        self._node_name: str = node_name
 
     @property
     def node_id(self) -> str:

@@ -159,6 +162,21 @@ class DAGNode(DependencyMixin, ABC):
     def set_node_id(self, node_id: str) -> None:
         self._node_id = node_id
 
+    def __hash__(self) -> int:
+        if self.node_id:
+            return hash(self.node_id)
+        else:
+            return super().__hash__()
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, DAGNode):
+            return False
+        return self.node_id == other.node_id
+
+    @property
+    def node_name(self) -> str:
+        return self._node_name
+
     @property
     def dag(self) -> "DAGNode":
         return self._dag
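
Note: identity by `node_id` is what lets DAG nodes live in sets and serve as dict keys; the `_get_root_nodes` change further down relies on it to deduplicate roots via `list(set(...))`. A standalone sketch of the semantics (illustrative class, not the project's):

class Node:
    def __init__(self, node_id: str):
        self.node_id = node_id

    def __hash__(self) -> int:
        # Same id, same hash; fall back to object identity when the id is
        # unset, mirroring DAGNode.__hash__
        return hash(self.node_id) if self.node_id else super().__hash__()

    def __eq__(self, other) -> bool:
        # Non-nodes never compare equal, mirroring DAGNode.__eq__
        return isinstance(other, Node) and self.node_id == other.node_id

a, b = Node("n1"), Node("n1")
assert a == b and len({a, b}) == 1  # duplicates collapse in a set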

@@ -100,6 +100,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     def __init__(
         self,
         task_id: Optional[str] = None,
+        task_name: Optional[str] = None,
         dag: Optional[DAG] = None,
         runner: WorkflowRunner = None,
         **kwargs,

@@ -109,7 +110,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Args:
             runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None.
         """
-        super().__init__(node_id=task_id, dag=dag, **kwargs)
+        super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs)
         if not runner:
             from pilot.awel import DefaultWorkflowRunner
 

@@ -1,5 +1,6 @@
 from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable
 import asyncio
+import logging
 
 from ..dag.base import DAGContext
 from ..task.base import (

@@ -14,6 +15,9 @@ from ..task.base import (
 from .base import BaseOperator
 
 
+logger = logging.getLogger(__name__)
+
+
 class JoinOperator(BaseOperator, Generic[OUT]):
     """Operator that joins inputs using a custom combine function.
 

@@ -141,31 +145,36 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
         raise NotImplementedError
 
 
-BranchFunc = Union[Callable[[Any], bool], Callable[[Any], Awaitable[bool]]]
+BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
 
 
-class BranchOperator(BaseOperator, Generic[OUT]):
+class BranchOperator(BaseOperator, Generic[IN, OUT]):
     """Operator node that branches the workflow based on a provided function.
 
     This node filters its input data using a branching function and
     allows for conditional paths in the workflow.
     """
 
-    def __init__(self, branches: Dict[BranchFunc, BaseOperator], **kwargs):
+    def __init__(
+        self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs
+    ):
         """
         Initializes a BranchDAGNode with a branching function.
 
         Args:
-            branches (Dict[BranchFunc, RunnableDAGNode]): Dict of function that defines the branching condition.
+            branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of function that defines the branching condition.
 
         Raises:
             ValueError: If the branch_function is not callable.
         """
         super().__init__(**kwargs)
-        for branch_function in branches.keys():
-            if not callable(branch_function):
-                raise ValueError("branch_function must be callable")
-        self.branches = branches
+        if branches:
+            for branch_function, value in branches.items():
+                if not callable(branch_function):
+                    raise ValueError("branch_function must be callable")
+                if isinstance(value, BaseOperator):
+                    branches[branch_function] = value.node_name or value.node_name
+        self._branches = branches
 
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         """Run the branching operation on the DAG context's inputs.

@@ -186,27 +195,36 @@ class BranchOperator(BaseOperator, Generic[OUT]):
         if not task_input.check_single_parent():
             raise ValueError("BranchDAGNode expects single parent")
 
+        branches = self._branches
+        if not branches:
+            branches = await self.branchs()
+
         branch_func_tasks = []
-        branch_nodes: List[BaseOperator] = []
-        for func, node in self.branches.items():
-            branch_nodes.append(node)
+        branch_nodes: List[str] = []
+        for func, node_name in branches.items():
+            branch_nodes.append(node_name)
             branch_func_tasks.append(
                 curr_task_ctx.task_input.predicate_map(func, failed_value=None)
             )
 
         branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
         parent_output = task_input.parent_outputs[0].task_output
         curr_task_ctx.set_task_output(parent_output)
+        skip_node_names = []
         for i, ctx in enumerate(branch_input_ctxs):
-            node = branch_nodes[i]
+            node_name = branch_nodes[i]
+            branch_out = ctx.parent_outputs[0].task_output
+            logger.info(
+                f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}"
+            )
             if ctx.parent_outputs[0].task_output.is_empty:
-                # Skip current node
-                # node.current_task_context.set_current_state(TaskState.SKIP)
-                pass
-            else:
-                pass
+                logger.info(f"Skip node name {node_name}")
+                skip_node_names.append(node_name)
+        curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
+        return parent_output
+
+    async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
         raise NotImplementedError
-        return None
 
 
 class InputOperator(BaseOperator, Generic[OUT]):
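
Note: two declaration styles follow from this. Pass an explicit dict (operator values are converted to their node names up front), or subclass and override `branchs()` to compute the mapping lazily, as `ModelCacheBranchOperator` does later in this commit. One oddity worth flagging: `value.node_name or value.node_name` is redundant as written (presumably a fallback such as `value.node_id` was intended), so an operator without a `node_name` maps to `None`. A minimal sketch of the dict style, using hypothetical nodes (same shape as the test further down):

def is_odd(x: int) -> bool:
    return x % 2 == 1

def is_even(x: int) -> bool:
    return x % 2 == 0

odd_node = MapOperator(lambda x: 999, task_name="odd_node_name")
even_node = MapOperator(lambda x: 888, task_name="even_node_name")

# Each predicate is evaluated against the single parent's output; branches
# whose predicate is falsy end up in the "skip_node_names" metadata that
# the runner consumes.
branch_node = BranchOperator(branches={is_odd: odd_node, is_even: even_node})
branch_node >> odd_node
branch_node >> even_node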

@@ -1,9 +1,12 @@
 from typing import List, Set, Optional, Dict
 import uuid
+import logging
 from ..dag.base import DAG
 
 from ..operator.base import BaseOperator, CALL_DATA
 
+logger = logging.getLogger(__name__)
+
 
 class DAGNodeInstance:
     def __init__(self, node_instance: DAG) -> None:

@@ -45,14 +48,19 @@ def _save_call_data(
     root_nodes: List[BaseOperator], call_data: CALL_DATA
 ) -> Dict[str, Dict]:
     id2call_data = {}
+    logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
     if not call_data:
         return id2call_data
     if len(root_nodes) == 1:
         node = root_nodes[0]
+        logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
         id2call_data[node.node_id] = call_data
     else:
         for node in root_nodes:
             node_id = node.node_id
+            logger.info(
+                f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
+            )
             id2call_data[node_id] = call_data.get(node_id)
     return id2call_data
 

@@ -71,4 +79,4 @@ def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
 
 
 def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
-    return list(filter(lambda x: not x.upstream, nodes))
+    return list(set(filter(lambda x: not x.upstream, nodes)))

@@ -1,10 +1,15 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Set, List
+import logging
 
 from ..dag.base import DAGContext
 from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
+from ..operator.common_operator import BranchOperator, JoinOperator
 from ..task.base import TaskContext, TaskState
-from ..task.task_impl import DefaultInputContext, DefaultTaskContext
+from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
 from .job_manager import JobManager
 
+logger = logging.getLogger(__name__)
+
+
 class DefaultWorkflowRunner(WorkflowRunner):
     async def execute_workflow(

@@ -13,10 +18,16 @@ class DefaultWorkflowRunner(WorkflowRunner):
         # Create DAG context
         dag_ctx = DAGContext()
         job_manager = JobManager.build_from_end_node(node, call_data)
+        logger.info(
+            f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
+        )
         dag = node.dag
         # Save node output
         node_outputs: Dict[str, TaskContext] = {}
-        await self._execute_node(job_manager, node, dag_ctx, node_outputs)
+        skip_node_ids = set()
+        await self._execute_node(
+            job_manager, node, dag_ctx, node_outputs, skip_node_ids
+        )
 
         return dag_ctx
 

@@ -26,6 +37,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
         node: BaseOperator,
         dag_ctx: DAGContext,
         node_outputs: Dict[str, TaskContext],
+        skip_node_ids: Set[str],
     ):
         # Skip run node
         if node.node_id in node_outputs:

@@ -35,18 +47,9 @@ class DefaultWorkflowRunner(WorkflowRunner):
         for upstream_node in node.upstream:
             if isinstance(upstream_node, BaseOperator):
                 await self._execute_node(
-                    job_manager, upstream_node, dag_ctx, node_outputs
+                    job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
                 )
-        # if node.current_task_context.current_state == TaskState.SKIP:
-        #     return
 
-        # for upstream_node in node.upstream:
-        #     if (
-        #         isinstance(upstream_node, BaseOperator)
-        #         and upstream_node.current_task_context.current_state == TaskState.SKIP
-        #     ):
-        #         return
-        # Get the input from upstream node
         inputs = [
             node_outputs[upstream_node.node_id] for upstream_node in node.upstream
         ]

@@ -56,13 +59,48 @@ class DefaultWorkflowRunner(WorkflowRunner):
 
         task_ctx.set_task_input(input_ctx)
         dag_ctx.set_current_task_context(task_ctx)
 
         task_ctx.set_current_state(TaskState.RUNNING)
 
+        if node.node_id in skip_node_ids:
+            task_ctx.set_current_state(TaskState.SKIP)
+            task_ctx.set_task_output(SimpleTaskOutput(None))
+            node_outputs[node.node_id] = task_ctx
+            return
         try:
-            # print(f"Begin run {node}")
+            logger.info(
+                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
+            )
             await node._run(dag_ctx)
             node_outputs[node.node_id] = dag_ctx.current_task_context
             task_ctx.set_current_state(TaskState.SUCCESS)
+
+            if isinstance(node, BranchOperator):
+                skip_nodes = task_ctx.metadata.get("skip_node_names", [])
+                logger.info(
+                    f"Current is branch operator, skip node names: {skip_nodes}"
+                )
+                _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
         except Exception as e:
+            logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
             task_ctx.set_current_state(TaskState.FAILED)
             raise e
+
+
+def _skip_current_downstream_by_node_name(
+    branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
+):
+    if not skip_nodes:
+        return
+    for child in branch_node.downstream:
+        if child.node_name in skip_nodes:
+            logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
+            _skip_downstream_by_id(child, skip_node_ids)
+
+
+def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
+    if isinstance(node, JoinOperator):
+        # Not skip join node
+        return
+    skip_node_ids.add(node.node_id)
+    for child in node.downstream:
+        _skip_downstream_by_id(child, skip_node_ids)
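
Note: the skip mechanism is two-phase. `BranchOperator._do_run` only records branch names in task metadata; after the branch node succeeds, the runner resolves those names against its direct children and marks each skipped child's downstream subtree by node id, stopping at `JoinOperator` so the surviving branch can still be merged. A self-contained sketch of the propagation rule (toy node class, not the project's):

class N:
    def __init__(self, node_id: str, is_join: bool = False):
        self.node_id, self.is_join, self.downstream = node_id, is_join, []

def skip_downstream(node: N, skip_node_ids: set) -> None:
    if node.is_join:  # mirrors "Not skip join node"
        return
    skip_node_ids.add(node.node_id)
    for child in node.downstream:
        skip_downstream(child, skip_node_ids)

# branch -> a -> join and branch -> b -> join; the branch decided to skip "b"
join = N("join", is_join=True)
a, b = N("a"), N("b")
a.downstream.append(join)
b.downstream.append(join)

skipped = set()
skip_downstream(b, skipped)
assert skipped == {"b"}  # "a" and "join" still run; "b" is short-circuited
# In _execute_node, a skipped id gets TaskState.SKIP and SimpleTaskOutput(None).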

@@ -81,6 +81,10 @@ class TaskOutput(ABC, Generic[T]):
             output_data (Union[T, AsyncIterator[T]]): Output data.
         """
 
+    @abstractmethod
+    def new_output(self) -> "TaskOutput[T]":
+        """Create new output object"""
+
     async def map(self, map_func) -> "TaskOutput[T]":
         """Apply a mapping function to the task's output.
 

@@ -334,13 +338,18 @@ class InputContext(ABC):
         """
         return len(self.parent_outputs) == 1
 
-    def check_stream(self) -> bool:
+    def check_stream(self, skip_empty: bool = False) -> bool:
         """Check if all parent outputs are streams.
 
+        Args:
+            skip_empty (bool): Skip empty output or not.
+
         Returns:
             bool: True if all parent outputs are streams, False otherwise.
         """
         for out in self.parent_outputs:
+            if out.task_output.is_empty and skip_empty:
+                continue
             if not (out.task_output and out.task_output.is_stream):
                 return False
         return True

@@ -13,10 +13,13 @@ from typing import (
     Union,
 )
 import asyncio
+import logging
 from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
 
 
+logger = logging.getLogger(__name__)
+
+
 async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
     # Init accumulator
     try:

@@ -44,6 +47,9 @@ class SimpleTaskOutput(TaskOutput[T], Generic[T]):
     def set_output(self, output_data: T | AsyncIterator[T]) -> None:
         self._data = output_data
 
+    def new_output(self) -> TaskOutput[T]:
+        return SimpleTaskOutput(None)
+
     @property
     def is_empty(self) -> bool:
         return not self._data

@@ -89,6 +95,9 @@ class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
     def set_output(self, output_data: T | AsyncIterator[T]) -> None:
         self._data = output_data
 
+    def new_output(self) -> TaskOutput[T]:
+        return SimpleStreamTaskOutput(None)
+
     async def map(self, map_func) -> TaskOutput[T]:
         is_async = asyncio.iscoroutinefunction(map_func)
 

@@ -213,7 +222,8 @@ class DefaultTaskContext(TaskContext, Generic[T]):
         self._task_state = task_state
 
     def new_ctx(self) -> TaskContext:
-        return DefaultTaskContext(self._task_id, self._task_state, self._output)
+        new_output = self._output.new_output()
+        return DefaultTaskContext(self._task_id, self._task_state, new_output)
 
     @property
     def metadata(self) -> Dict[str, Any]:
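
Note: the `new_ctx` fix matters for the skip path above. Previously every cloned context shared `self._output`, so `set_output` on one clone mutated all of them, including the parent output that a branch forwards; `new_output()` hands each clone a fresh, empty container of the matching type (plain or stream). A minimal illustration of the aliasing problem, with a stand-in holder class:

class Output:
    def __init__(self, data=None):
        self.data = data

    def new_output(self) -> "Output":
        return Output(None)  # fresh, empty container of the same type

shared = Output("parent result")
alias = shared                 # old behavior: clones aliased one object
fresh = shared.new_output()    # new behavior: independent per clone

fresh.data = "mapped result"
assert shared.data == "parent result"  # the original stays intact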

@@ -259,9 +269,17 @@ class DefaultInputContext(InputContext):
     async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
         if not self._outputs:
             return DefaultInputContext([])
-        is_steam = self._outputs[0].task_output.is_stream
+        # Some parent may be empty
+        not_empty_idx = 0
+        for i, p in enumerate(self._outputs):
+            if p.task_output.is_empty:
+                continue
+            not_empty_idx = i
+            break
+        # All output is empty?
+        is_steam = self._outputs[not_empty_idx].task_output.is_stream
         if is_steam:
-            if not self.check_stream():
+            if not self.check_stream(skip_empty=True):
                 raise ValueError(
                     "The output in all tasks must has same output format to map_all"
                 )

@@ -275,9 +293,11 @@ class DefaultInputContext(InputContext):
             map_res = await map_func(*outputs)
         else:
             map_res = map_func(*outputs)
-        single_output: TaskContext = self._outputs[0].new_ctx()
+        single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
         single_output.task_output.set_output(map_res)
+        logger.debug(
+            f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}"
+        )
         return DefaultInputContext([single_output])
 
     async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:

@@ -311,6 +331,7 @@ class DefaultInputContext(InputContext):
         for i, task_ctx in enumerate(new_outputs):
             task_ctx: TaskContext = task_ctx
             if results[i]:
+                task_ctx.task_output.set_output(True)
                 result_outputs.append(task_ctx)
             else:
                 task_ctx.task_output.set_output(failed_value)

@@ -108,21 +108,34 @@ async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator
     "input_node, is_odd",
     [
         ({"outputs": [0]}, False),
-        # ({"outputs": [1]}, False),
+        ({"outputs": [1]}, True),
     ],
     indirect=["input_node"],
 )
 async def test_branch_node(
     runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
 ):
+    def join_func(o1, o2) -> int:
+        print(f"join func result, o1: {o1}, o2: {o2}")
+        return o1 or o2
+
     with DAG("test_join_node") as dag:
-        odd_node = MapOperator(lambda x: 999, task_id="odd_node")
-        even_node = MapOperator(lambda x: 888, task_id="even_node")
+        odd_node = MapOperator(
+            lambda x: 999, task_id="odd_node", task_name="odd_node_name"
+        )
+        even_node = MapOperator(
+            lambda x: 888, task_id="even_node", task_name="even_node_name"
+        )
+        join_node = JoinOperator(join_func)
         branch_node = BranchOperator(
             {lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
         )
+        branch_node >> odd_node >> join_node
+        branch_node >> even_node >> join_node
 
         input_node >> branch_node
 
-    odd_res: DAGContext[int] = await runner.execute_workflow(odd_node)
-    even_res: DAGContext[int] = await runner.execute_workflow(even_node)
-    assert branch_node.current_task_context.current_state == TaskState.SUCCESS
+    res: DAGContext[int] = await runner.execute_workflow(join_node)
+    assert res.current_task_context.current_state == TaskState.SUCCESS
+    expect_res = 999 if is_odd else 888
+    assert res.current_task_context.task_output.output == expect_res
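
Note: parametrizing over `[0]` and `[1]` exercises both paths. An even input skips `odd_node`, so the join computes `None or 888`; an odd input skips `even_node` and computes `999 or None`. The `o1 or o2` combine works because the skipped branch reaches the join as `SimpleTaskOutput(None)`, a falsy value, while the surviving branch carries the real result.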

@@ -1,10 +1,13 @@
-from typing import AsyncIterator, Dict
+from typing import AsyncIterator, Dict, Union
 import logging
 from pilot.awel import (
+    BranchFunc,
     StreamifyAbsOperator,
+    BranchOperator,
     MapOperator,
     TransformStreamAbsOperator,
 )
+from pilot.awel.operator.base import BaseOperator
 from pilot.model.base import ModelOutput
 from pilot.model.cluster import WorkerManager
 from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue

@@ -16,71 +19,189 @@ _LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"
 
 
 class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
+    """Operator for streaming processing of model outputs.
+
+    Args:
+        worker_manager (WorkerManager): The manager that handles worker processes for model inference.
+        **kwargs: Additional keyword arguments.
+
+    Methods:
+        streamify: Asynchronously processes a stream of inputs, yielding model outputs.
+    """
+
     def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
         super().__init__(**kwargs)
         self.worker_manager = worker_manager
 
     async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
-        llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
-            _LLM_MODEL_OUTPUT_CACHE_KEY
-        )
-        logger.info(f"llm_cache_value: {llm_cache_value}")
-        if llm_cache_value:
-            for out in llm_cache_value.get_value().output:
-                yield out
-            return
+        """Process inputs as a stream and yield model outputs.
+
+        Args:
+            input_value (Dict): The input value for the model.
+
+        Returns:
+            AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
+        """
         async for out in self.worker_manager.generate_stream(input_value):
             yield out
 
 
 class ModelOperator(MapOperator[Dict, ModelOutput]):
+    """Operator for map-based processing of model outputs.
+
+    Args:
+        worker_manager (WorkerManager): Manager for handling worker processes.
+        **kwargs: Additional keyword arguments.
+
+    Methods:
+        map: Asynchronously processes a single input and returns the model output.
+    """
+
     def __init__(self, worker_manager: WorkerManager, **kwargs) -> None:
         self.worker_manager = worker_manager
         super().__init__(**kwargs)
 
     async def map(self, input_value: Dict) -> ModelOutput:
-        llm_cache_value: LLMCacheValue = await self.current_dag_context.get_share_data(
-            _LLM_MODEL_OUTPUT_CACHE_KEY
-        )
-        logger.info(f"llm_cache_value: {llm_cache_value}")
-        if llm_cache_value:
-            return llm_cache_value.get_value().output
+        """Process a single input and return the model output.
+
+        Args:
+            input_value (Dict): The input value for the model.
+
+        Returns:
+            ModelOutput: The output from the model.
+        """
         return await self.worker_manager.generate(input_value)
 
 
-class ModelCachePreOperator(MapOperator[Dict, Dict]):
-    def __init__(self, cache_manager: CacheManager, **kwargs):
+class CachedModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]):
+    """Operator for streaming processing of model outputs with caching.
+
+    Args:
+        cache_manager (CacheManager): The cache manager to handle caching operations.
+        **kwargs: Additional keyword arguments.
+
+    Methods:
+        streamify: Processes a stream of inputs with cache support, yielding model outputs.
+    """
+
+    def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
+        super().__init__(**kwargs)
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)
+
+    async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]:
+        """Process inputs as a stream with cache support and yield model outputs.
+
+        Args:
+            input_value (Dict): The input value for the model.
+
+        Returns:
+            AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
+        """
+        cache_dict = _parse_cache_key_dict(input_value)
+        llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
+        llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
+        logger.info(f"llm_cache_value: {llm_cache_value}")
+        for out in llm_cache_value.get_value().output:
+            yield out
+
+
+class CachedModelOperator(MapOperator[Dict, ModelOutput]):
+    """Operator for map-based processing of model outputs with caching.
+
+    Args:
+        cache_manager (CacheManager): Manager for caching operations.
+        **kwargs: Additional keyword arguments.
+
+    Methods:
+        map: Processes a single input with cache support and returns the model output.
+    """
+
+    def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
         super().__init__(**kwargs)
+        self._cache_manager = cache_manager
+        self._client = LLMCacheClient(cache_manager)
 
-    async def map(self, input_value: Dict) -> Dict:
-        cache_dict = {
-            "prompt": input_value.get("prompt"),
-            "model_name": input_value.get("model"),
-            "temperature": input_value.get("temperature"),
-            "max_new_tokens": input_value.get("max_new_tokens"),
-            "top_p": input_value.get("top_p", "1.0"),
-            # TODO pass model_type
-            "model_type": input_value.get("model_type", "huggingface"),
-        }
-        cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
-        cache_value = await self._client.get(cache_key)
-        logger.debug(
-            f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
-        )
-        await self.current_dag_context.save_to_share_data(
-            _LLM_MODEL_INPUT_VALUE_KEY, cache_key
-        )
-        if cache_value:
-            logger.info(f"The model output has cached, cache_value: {cache_value}")
-            await self.current_dag_context.save_to_share_data(
-                _LLM_MODEL_OUTPUT_CACHE_KEY, cache_value
-            )
-        return input_value
+    async def map(self, input_value: Dict) -> ModelOutput:
+        """Process a single input with cache support and return the model output.
+
+        Args:
+            input_value (Dict): The input value for the model.
+
+        Returns:
+            ModelOutput: The output from the model.
+        """
+        cache_dict = _parse_cache_key_dict(input_value)
+        llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
+        llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
+        logger.info(f"llm_cache_value: {llm_cache_value}")
+        return llm_cache_value.get_value().output
+
+
+class ModelCacheBranchOperator(BranchOperator[Dict, Dict]):
+    """
+    A branch operator that decides whether to use cached data or to process data using the model.
+
+    Args:
+        cache_manager (CacheManager): The cache manager for managing cache operations.
+        model_task_name (str): The name of the task to process data using the model.
+        cache_task_name (str): The name of the task to process data using the cache.
+        **kwargs: Additional keyword arguments.
+    """
+
+    def __init__(
+        self,
+        cache_manager: CacheManager,
+        model_task_name: str,
+        cache_task_name: str,
+        **kwargs,
+    ):
+        super().__init__(branches=None, **kwargs)
+        self._cache_manager = cache_manager
+        self._client = LLMCacheClient(cache_manager)
+        self._model_task_name = model_task_name
+        self._cache_task_name = cache_task_name
+
+    async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]:
+        """Defines branch logic based on cache availability.
+
+        Returns:
+            Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names.
+        """
+
+        async def check_cache_true(input_value: Dict) -> bool:
+            # Check if the cache contains the result for the given input
+            cache_dict = _parse_cache_key_dict(input_value)
+            cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
+            cache_value = await self._client.get(cache_key)
+            logger.debug(
+                f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
+            )
+            await self.current_dag_context.save_to_share_data(
+                _LLM_MODEL_INPUT_VALUE_KEY, cache_key
+            )
+            return True if cache_value else False
+
+        async def check_cache_false(input_value: Dict):
+            # Inverse of check_cache_true
+            return not await check_cache_true(input_value)
+
+        return {
+            check_cache_true: self._cache_task_name,
+            check_cache_false: self._model_task_name,
+        }
 
 
-class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutput]):
+class ModelStreamSaveCacheOperator(
+    TransformStreamAbsOperator[ModelOutput, ModelOutput]
+):
+    """An operator to save the stream of model outputs to cache.
+
+    Args:
+        cache_manager (CacheManager): The cache manager for handling cache operations.
+        **kwargs: Additional keyword arguments.
+    """
+
     def __init__(self, cache_manager: CacheManager, **kwargs):
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)
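
Note: two details here are easy to miss. First, `check_cache_true` has a deliberate side effect: besides answering the predicate, it stashes the computed `LLMCacheKey` in shared DAG data under `_LLM_MODEL_INPUT_VALUE_KEY`, which the save-cache operators read back when persisting outputs. Second, `branchs()` maps the predicates to task names rather than operator instances, so targets are matched by `node_name` at run time; this is why `_build_model_operator` below constructs the model and cache nodes with matching `task_name` values.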

@@ -89,6 +210,14 @@ class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutp
     async def transform_stream(
         self, input_value: AsyncIterator[ModelOutput]
     ) -> AsyncIterator[ModelOutput]:
+        """Transforms the input stream by saving the outputs to cache.
+
+        Args:
+            input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs.
+
+        Returns:
+            AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache.
+        """
         llm_cache_key: LLMCacheKey = None
         outputs = []
         async for out in input_value:

@@ -103,13 +232,28 @@ class ModelStreamCacheOperator(TransformStreamAbsOperator[ModelOutput, ModelOutp
             await self._client.set(llm_cache_key, llm_cache_value)
 
 
-class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
+class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
+    """An operator to save a single model output to cache.
+
+    Args:
+        cache_manager (CacheManager): The cache manager for handling cache operations.
+        **kwargs: Additional keyword arguments.
+    """
+
     def __init__(self, cache_manager: CacheManager, **kwargs):
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)
         super().__init__(**kwargs)
 
     async def map(self, input_value: ModelOutput) -> ModelOutput:
+        """Saves a single model output to cache and returns it.
+
+        Args:
+            input_value (ModelOutput): The output from the model to be cached.
+
+        Returns:
+            ModelOutput: The same input model output.
+        """
         llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data(
             _LLM_MODEL_INPUT_VALUE_KEY
         )

@@ -117,3 +261,26 @@ class ModelCacheOperator(MapOperator[ModelOutput, ModelOutput]):
         if llm_cache_key:
             await self._client.set(llm_cache_key, llm_cache_value)
         return input_value
+
+
+def _parse_cache_key_dict(input_value: Dict) -> Dict:
+    """Parses and extracts relevant fields from input to form a cache key dictionary.
+
+    Args:
+        input_value (Dict): The input dictionary containing model and prompt parameters.
+
+    Returns:
+        Dict: A dictionary used for generating cache keys.
+    """
+    prompt: str = input_value.get("prompt")
+    if prompt:
+        prompt = prompt.strip()
+    return {
+        "prompt": prompt,
+        "model_name": input_value.get("model"),
+        "temperature": input_value.get("temperature"),
+        "max_new_tokens": input_value.get("max_new_tokens"),
+        "top_p": input_value.get("top_p", "1.0"),
+        # TODO pass model_type
+        "model_type": input_value.get("model_type", "huggingface"),
+    }

@@ -566,29 +566,83 @@ class BaseChat(ABC):
 def _build_model_operator(
     is_stream: bool = False, dag_name: str = "llm_model_dag"
 ) -> BaseOperator:
+    """Builds and returns a model processing workflow (DAG) operator.
+
+    This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
+    It includes caching and branching logic to either fetch results from a cache or process
+    data using the model. It supports both streaming and non-streaming modes.
+
+    .. code-block:: python
+
+        input_node >> cache_check_branch_node
+        cache_check_branch_node >> model_node >> save_cached_node >> join_node
+        cache_check_branch_node >> cached_node >> join_node
+
+    equivalent to::
+
+                               -> model_node -> save_cached_node ->
+                             /                                      \
+        input_node -> cache_check_branch_node                        ---> join_node
+                             \                                      /
+                               -> cached_node ------------------- ->
+
+    Args:
+        is_stream (bool): Flag to determine if the operator should process data in streaming mode.
+        dag_name (str): Name of the DAG.
+
+    Returns:
+        BaseOperator: The final operator in the constructed DAG, typically a join node.
+    """
     from pilot.model.cluster import WorkerManagerFactory
+    from pilot.awel import JoinOperator
     from pilot.model.operator.model_operator import (
-        ModelCacheOperator,
-        ModelStreamCacheOperator,
-        ModelCachePreOperator,
+        ModelCacheBranchOperator,
+        CachedModelStreamOperator,
+        CachedModelOperator,
+        ModelSaveCacheOperator,
+        ModelStreamSaveCacheOperator,
     )
     from pilot.cache import CacheManager
 
+    # Fetch worker and cache managers from the system configuration
     worker_manager = CFG.SYSTEM_APP.get_component(
         ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
     ).create()
     cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
         ComponentType.MODEL_CACHE_MANAGER, CacheManager
     )
+    # Define task names for the model and cache nodes
+    model_task_name = "llm_model_node"
+    cache_task_name = "llm_model_cache_node"
 
     with DAG(dag_name):
+        # Create an input node
         input_node = InputOperator(SimpleCallDataInputSource())
-        cache_check_node = ModelCachePreOperator(cache_manager)
+        # Determine if the workflow should operate in streaming mode
         if is_stream:
-            model_node = ModelStreamOperator(worker_manager)
-            cache_node = ModelStreamCacheOperator(cache_manager)
+            model_node = ModelStreamOperator(worker_manager, task_name=model_task_name)
+            cached_node = CachedModelStreamOperator(
+                cache_manager, task_name=cache_task_name
+            )
+            save_cached_node = ModelStreamSaveCacheOperator(cache_manager)
         else:
-            model_node = ModelOperator(worker_manager)
-            cache_node = ModelCacheOperator(cache_manager)
-        input_node >> cache_check_node >> model_node >> cache_node
-        return cache_node
+            model_node = ModelOperator(worker_manager, task_name=model_task_name)
+            cached_node = CachedModelOperator(cache_manager, task_name=cache_task_name)
+            save_cached_node = ModelSaveCacheOperator(cache_manager)
+
+        # Create a branch node to decide between fetching from cache or processing with the model
+        cache_check_branch_node = ModelCacheBranchOperator(
+            cache_manager,
+            model_task_name="llm_model_node",
+            cache_task_name="llm_model_cache_node",
+        )
+        # Create a join node to merge outputs from the model and cache nodes, just keep the fist not empty output
+        join_node = JoinOperator(
+            combine_function=lambda model_out, cache_out: cache_out or model_out
+        )
+
+        # Define the workflow structure using the >> operator
+        input_node >> cache_check_branch_node
+        cache_check_branch_node >> model_node >> save_cached_node >> join_node
+        cache_check_branch_node >> cached_node >> join_node
+
+        return join_node
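
Note: a hypothetical call site, using only what this diff shows (`DefaultWorkflowRunner.execute_workflow(node, call_data)` and the task-output accessors from the tests); the payload shape expected by `SimpleCallDataInputSource` is illustrative:

async def generate(payload: dict):
    # Build the non-streaming DAG; the returned operator is the join node.
    join_node = _build_model_operator(is_stream=False)
    runner = DefaultWorkflowRunner()
    dag_ctx = await runner.execute_workflow(join_node, call_data=payload)
    # Either the cached value or the fresh model output, whichever is non-empty.
    return dag_ctx.current_task_context.task_output.output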