fix(core): Fix AWEL branch bug (#1640)

Fangyin Cheng
2024-06-18 11:11:43 +08:00
committed by GitHub
parent 49b56b4576
commit ace169ac46
32 changed files with 870 additions and 481 deletions

View File

@@ -17,6 +17,7 @@ from .dag.base import DAG, DAGContext, DAGVar
from .operators.base import BaseOperator, WorkflowRunner
from .operators.common_operator import (
BranchFunc,
BranchJoinOperator,
BranchOperator,
BranchTaskType,
InputOperator,
@@ -78,6 +79,7 @@ __all__ = [
"ReduceStreamOperator",
"TriggerOperator",
"MapOperator",
"BranchJoinOperator",
"BranchOperator",
"InputOperator",
"BranchFunc",

View File

@@ -1,4 +1,5 @@
"""The mixin of DAGs."""
import abc
import dataclasses
import inspect
@@ -337,6 +338,9 @@ class Parameter(TypeMetadata, Serializable):
value: Optional[Any] = Field(
None, description="The value of the parameter(Saved in the dag file)"
)
alias: Optional[List[str]] = Field(
None, description="The alias of the parameter(Compatible with old version)"
)
@model_validator(mode="before")
@classmethod
@@ -398,6 +402,7 @@ class Parameter(TypeMetadata, Serializable):
description: Optional[str] = None,
options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None,
resource_type: ResourceType = ResourceType.INSTANCE,
alias: Optional[List[str]] = None,
):
"""Build the parameter from the type."""
type_name = type.__qualname__
@@ -419,6 +424,7 @@ class Parameter(TypeMetadata, Serializable):
placeholder=placeholder,
description=description or label,
options=options,
alias=alias,
)
@classmethod
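The `alias` field is what lets a renamed parameter keep binding values that older DAG files saved under its previous name. A minimal declaration sketch, assuming the usual `dbgpt.core.awel.flow` import path and with illustrative parameter names:

```python
from dbgpt.core.awel.flow import Parameter

# Hypothetical rename: "db_name" became "database_name". Keeping the old
# name in `alias` lets flows serialized before the rename still resolve.
param = Parameter.build_from(
    "Database Name",
    "database_name",
    str,
    optional=True,
    default=None,
    description="The name of the database.",
    alias=["db_name"],
)
```

Note that `to_dict` (below) excludes `alias`, so newly saved DAG files only ever carry the current parameter name.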
@@ -452,7 +458,7 @@ class Parameter(TypeMetadata, Serializable):
def to_dict(self) -> Dict:
"""Convert current metadata to json dict."""
dict_value = model_to_dict(self, exclude={"options"})
dict_value = model_to_dict(self, exclude={"options", "alias"})
if not self.options:
dict_value["options"] = None
elif isinstance(self.options, BaseDynamicOptions):
@@ -677,9 +683,18 @@ class BaseMetadata(BaseResource):
for parameter in self.parameters
if not parameter.optional
}
current_parameters = {
parameter.name: parameter for parameter in self.parameters
}
current_parameters = {}
current_aliases_parameters = {}
for parameter in self.parameters:
current_parameters[parameter.name] = parameter
if parameter.alias:
for alias in parameter.alias:
if alias in current_aliases_parameters:
raise FlowMetadataException(
f"Alias {alias} already exists in the metadata."
)
current_aliases_parameters[alias] = parameter
if len(view_required_parameters) < len(current_required_parameters):
# TODO: skip the optional parameters.
raise FlowParameterMetadataException(
@@ -691,12 +706,16 @@ class BaseMetadata(BaseResource):
)
for view_param in view_parameters:
view_param_key = view_param.name
if view_param_key not in current_parameters:
if view_param_key in current_parameters:
current_parameter = current_parameters[view_param_key]
elif view_param_key in current_aliases_parameters:
current_parameter = current_aliases_parameters[view_param_key]
else:
raise FlowParameterMetadataException(
f"Parameter {view_param_key} not in the metadata."
)
runnable_parameters.update(
current_parameters[view_param_key].to_runnable_parameter(
current_parameter.to_runnable_parameter(
view_param.get_typed_value(), resources, key_to_resource_instance
)
)
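The lookup order above is: exact parameter name first, alias table second, error otherwise, with duplicate aliases rejected up front. A self-contained sketch of the same resolution, using plain objects in place of `Parameter`:

```python
from typing import Dict, Iterable, List, Optional


class FlowMetadataError(Exception):
    """Illustrative stand-in for dbgpt's flow metadata exceptions."""


class Param:
    """Illustrative stand-in for Parameter: just a name and its aliases."""

    def __init__(self, name: str, alias: Optional[List[str]] = None):
        self.name = name
        self.alias = alias


def build_lookup_tables(parameters: Iterable[Param]):
    """Mirror of the name/alias tables built in get_runnable_parameters."""
    by_name: Dict[str, Param] = {}
    by_alias: Dict[str, Param] = {}
    for parameter in parameters:
        by_name[parameter.name] = parameter
        for alias in parameter.alias or []:
            if alias in by_alias:
                raise FlowMetadataError(f"Alias {alias} already exists in the metadata.")
            by_alias[alias] = parameter
    return by_name, by_alias


def resolve(key: str, by_name: Dict[str, Param], by_alias: Dict[str, Param]) -> Param:
    """Exact name wins; an alias is the fallback; anything else is an error."""
    if key in by_name:
        return by_name[key]
    if key in by_alias:
        return by_alias[key]
    raise FlowMetadataError(f"Parameter {key} not in the metadata.")


params = [Param("database_name", alias=["db_name"])]
tables = build_lookup_tables(params)
assert resolve("db_name", *tables) is params[0]  # old saved name still binds
```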

View File

@@ -1,8 +1,29 @@
"""Compatibility mapping for flow classes."""
from dataclasses import dataclass
from typing import Dict, Optional
_COMPAT_FLOW_MAPPING: Dict[str, str] = {}
@dataclass
class _RegisterItem:
"""Register item for compatibility mapping."""
old_module: str
new_module: str
old_name: str
new_name: Optional[str] = None
after: Optional[str] = None
def old_cls_key(self) -> str:
"""Get the old class key."""
return f"{self.old_module}.{self.old_name}"
def new_cls_key(self) -> str:
"""Get the new class key."""
return f"{self.new_module}.{self.new_name}"
_COMPAT_FLOW_MAPPING: Dict[str, _RegisterItem] = {}
_OLD_AGENT_RESOURCE_MODULE_1 = "dbgpt.serve.agent.team.layout.agent_operator_resource"
@@ -11,17 +32,24 @@ _NEW_AGENT_RESOURCE_MODULE = "dbgpt.agent.core.plan.awel.agent_operator_resource
def _register(
old_module: str, new_module: str, old_name: str, new_name: Optional[str] = None
old_module: str,
new_module: str,
old_name: str,
new_name: Optional[str] = None,
after_version: Optional[str] = None,
):
if not new_name:
new_name = old_name
_COMPAT_FLOW_MAPPING[f"{old_module}.{old_name}"] = f"{new_module}.{new_name}"
item = _RegisterItem(old_module, new_module, old_name, new_name, after_version)
_COMPAT_FLOW_MAPPING[item.old_cls_key()] = item
def get_new_class_name(old_class_name: str) -> Optional[str]:
"""Get the new class name for the old class name."""
new_cls_name = _COMPAT_FLOW_MAPPING.get(old_class_name, None)
return new_cls_name
if old_class_name not in _COMPAT_FLOW_MAPPING:
return None
item = _COMPAT_FLOW_MAPPING[old_class_name]
return item.new_cls_key()
_register(
@@ -54,3 +82,9 @@ _register(
_register(
_OLD_AGENT_RESOURCE_MODULE_2, _NEW_AGENT_RESOURCE_MODULE, "AWELAgent", "AWELAgent"
)
_register(
"dbgpt.storage.vector_store.connector",
"dbgpt.serve.rag.connector",
"VectorStoreConnector",
after_version="v0.5.8",
)
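Compared with the old string-to-string table, a `_RegisterItem` can also record the version boundary for a move, which the `VectorStoreConnector` entry above uses. A short usage sketch (the `compat` module path is assumed):

```python
from dbgpt.core.awel.flow.compat import get_new_class_name

old_key = "dbgpt.storage.vector_store.connector.VectorStoreConnector"
# -> "dbgpt.serve.rag.connector.VectorStoreConnector"; unknown keys yield None.
print(get_new_class_name(old_key))
```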

View File

@@ -555,14 +555,6 @@ class FlowFactory:
downstream = key_to_downstream.get(operator_key, [])
if not downstream:
raise ValueError("Branch operator should have downstream.")
if len(downstream) != len(view_metadata.parameters):
raise ValueError(
"Branch operator should have the same number of downstream as "
"parameters."
)
for i, param in enumerate(view_metadata.parameters):
downstream_key, _, _ = downstream[i]
param.value = key_to_operator_nodes[downstream_key].data.name
try:
runnable_params = metadata.get_runnable_parameters(

View File

@@ -137,6 +137,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
task_name: Optional[str] = None,
dag: Optional[DAG] = None,
runner: Optional[WorkflowRunner] = None,
can_skip_in_branch: bool = True,
**kwargs,
) -> None:
"""Create a BaseOperator with an optional workflow runner.
@@ -157,6 +158,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
self._runner: WorkflowRunner = runner
self._dag_ctx: Optional[DAGContext] = None
self._can_skip_in_branch = can_skip_in_branch
@property
def current_dag_context(self) -> DAGContext:
@@ -321,6 +323,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
"""Get the current event loop task id."""
return id(asyncio.current_task())
def can_skip_in_branch(self) -> bool:
"""Check if the operator can be skipped in the branch."""
return self._can_skip_in_branch
def initialize_runner(runner: WorkflowRunner):
"""Initialize the default runner."""

View File

@@ -16,6 +16,7 @@ from ..task.base import (
ReduceFunc,
TaskContext,
TaskOutput,
is_empty_data,
)
from .base import BaseOperator
@@ -28,13 +29,16 @@ class JoinOperator(BaseOperator, Generic[OUT]):
This node type is useful for combining the outputs of upstream nodes.
"""
def __init__(self, combine_function: JoinFunc, **kwargs):
def __init__(
self, combine_function: JoinFunc, can_skip_in_branch: bool = True, **kwargs
):
"""Create a JoinDAGNode with a combine function.
Args:
combine_function: A function that defines how to combine inputs.
can_skip_in_branch (bool): Whether the node can be skipped in a branch (default True).
"""
super().__init__(**kwargs)
super().__init__(can_skip_in_branch=can_skip_in_branch, **kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")
self.combine_function = combine_function
@@ -57,6 +61,12 @@ class JoinOperator(BaseOperator, Generic[OUT]):
curr_task_ctx.set_task_output(join_output)
return join_output
async def _return_first_non_empty(self, *inputs):
for data in inputs:
if not is_empty_data(data):
return data
raise ValueError("All inputs are empty")
class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
"""Operator that reduces inputs using a custom reduce function."""
@@ -287,6 +297,32 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
raise NotImplementedError
class BranchJoinOperator(JoinOperator, Generic[OUT]):
"""Operator that joins inputs using a custom combine function.
This node type is useful for combining the outputs of upstream nodes.
"""
def __init__(
self,
combine_function: Optional[JoinFunc] = None,
can_skip_in_branch: bool = False,
**kwargs,
):
"""Create a JoinDAGNode with a combine function.
Args:
combine_function: A function that defines how to combine inputs.
can_skip_in_branch (bool): Whether the node can be skipped in a branch (default False).
"""
super().__init__(
combine_function=combine_function or self._return_first_non_empty,
can_skip_in_branch=can_skip_in_branch,
**kwargs,
)
class InputOperator(BaseOperator, Generic[OUT]):
"""Operator node that reads data from an input source."""

View File

@@ -2,6 +2,7 @@
This runner will run the workflow in the current process.
"""
import asyncio
import logging
import traceback
@@ -11,7 +12,7 @@ from dbgpt.component import SystemApp
from ..dag.base import DAGContext, DAGVar
from ..operators.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operators.common_operator import BranchOperator, JoinOperator
from ..operators.common_operator import BranchOperator
from ..task.base import SKIP_DATA, TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager
@@ -184,14 +185,14 @@ def _skip_current_downstream_by_node_name(
return
for child in branch_node.downstream:
child = cast(BaseOperator, child)
if child.node_name in skip_nodes:
if child.node_name in skip_nodes or child.node_id in skip_node_ids:
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
_skip_downstream_by_id(child, skip_node_ids)
def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
if isinstance(node, JoinOperator):
# Not skip join node
if not node.can_skip_in_branch():
# Current node can not skip, so skip its downstream
return
skip_node_ids.add(node.node_id)
for child in node.downstream:
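This traversal is the core of the fix: instead of special-casing `JoinOperator` by type, skip propagation now halts at any node that reports it cannot be skipped. A self-contained sketch of the walk over toy nodes:

```python
from typing import List, Set


class ToyNode:
    """Illustrative stand-in for BaseOperator."""

    def __init__(self, node_id: str, skippable: bool = True):
        self.node_id = node_id
        self.downstream: List["ToyNode"] = []
        self._skippable = skippable

    def can_skip_in_branch(self) -> bool:
        return self._skippable


def skip_downstream_by_id(node: ToyNode, skip_node_ids: Set[str]) -> None:
    """Mirror of _skip_downstream_by_id after this change."""
    if not node.can_skip_in_branch():
        # A must-run node (e.g. BranchJoinOperator) stops the pruning here,
        # so everything it feeds also stays alive.
        return
    skip_node_ids.add(node.node_id)
    for child in node.downstream:
        skip_downstream_by_id(child, skip_node_ids)


even, join = ToyNode("even"), ToyNode("join", skippable=False)
even.downstream.append(join)
skipped: Set[str] = set()
skip_downstream_by_id(even, skipped)
assert skipped == {"even"}  # the join (and its downstream) is kept
```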

View File

@@ -130,7 +130,7 @@ async def test_branch_node(
even_node = MapOperator(
lambda x: 888, task_id="even_node", task_name="even_node_name"
)
join_node = JoinOperator(join_func)
join_node = JoinOperator(join_func, can_skip_in_branch=False)
branch_node = BranchOperator(
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
)

View File

@@ -2,12 +2,13 @@
import dataclasses
from abc import ABC
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from typing import Any, AsyncIterator, Dict, List, Optional, Union, cast
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import (
BaseOperator,
BranchFunc,
BranchJoinOperator,
BranchOperator,
CommonLLMHttpRequestBody,
CommonLLMHttpResponseBody,
@@ -340,24 +341,7 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
category=OperatorCategory.LLM,
operator_type=OperatorType.BRANCH,
description=_("Branch the workflow based on the stream flag of the request."),
parameters=[
Parameter.build_from(
_("Streaming Task Name"),
"stream_task_name",
str,
optional=True,
default="streaming_llm_task",
description=_("The name of the streaming task."),
),
Parameter.build_from(
_("Non-Streaming Task Name"),
"no_stream_task_name",
str,
optional=True,
default="llm_task",
description=_("The name of the non-streaming task."),
),
],
parameters=[],
inputs=[
IOField.build_from(
_("Model Request"),
@@ -382,7 +366,12 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
],
)
def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs):
def __init__(
self,
stream_task_name: Optional[str] = None,
no_stream_task_name: Optional[str] = None,
**kwargs,
):
"""Create a new LLM branch operator.
Args:
@@ -390,18 +379,13 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
no_stream_task_name (Optional[str]): The name of the non-streaming task.
"""
super().__init__(**kwargs)
if not stream_task_name:
raise ValueError("stream_task_name is not set")
if not no_stream_task_name:
raise ValueError("no_stream_task_name is not set")
self._stream_task_name = stream_task_name
self._no_stream_task_name = no_stream_task_name
async def branches(
self,
) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
"""
Return a dict of branch function and task name.
"""Return a dict of branch function and task name.
Returns:
Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task
@@ -409,6 +393,18 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
If the predicate function returns True, we will run the corresponding
task.
"""
if self._stream_task_name and self._no_stream_task_name:
stream_task_name = self._stream_task_name
no_stream_task_name = self._no_stream_task_name
else:
stream_task_name = ""
no_stream_task_name = ""
for node in self.downstream:
task = cast(BaseOperator, node)
if task.streaming_operator:
stream_task_name = node.node_name
else:
no_stream_task_name = node.node_name
async def check_stream_true(r: ModelRequest) -> bool:
# If stream is true, we will run the streaming task. otherwise, we will run
@@ -416,8 +412,8 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
return r.stream
return {
check_stream_true: self._stream_task_name,
lambda x: not x.stream: self._no_stream_task_name,
check_stream_true: stream_task_name,
lambda x: not x.stream: no_stream_task_name,
}
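The two task names are now optional: when absent, `branches()` infers them at run time from the downstream nodes, routing streaming requests to the child flagged as a streaming operator. A standalone sketch of that inference:

```python
from typing import Iterable, Tuple


class ToyTask:
    """Illustrative stand-in for a downstream BaseOperator."""

    def __init__(self, node_name: str, streaming_operator: bool):
        self.node_name = node_name
        self.streaming_operator = streaming_operator


def infer_task_names(downstream: Iterable[ToyTask]) -> Tuple[str, str]:
    """Mirror of the fallback added to LLMBranchOperator.branches()."""
    stream_task_name, no_stream_task_name = "", ""
    for task in downstream:
        if task.streaming_operator:
            stream_task_name = task.node_name
        else:
            no_stream_task_name = task.node_name
    return stream_task_name, no_stream_task_name


tasks = [ToyTask("streaming_llm_task", True), ToyTask("llm_task", False)]
assert infer_task_names(tasks) == ("streaming_llm_task", "llm_task")
```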
@@ -553,3 +549,93 @@ class StringOutput2ModelOutputOperator(MapOperator[str, ModelOutput]):
text=input_value,
error_code=500,
)
class LLMBranchJoinOperator(BranchJoinOperator[ModelOutput]):
"""The LLM Branch Join Operator.
Decide which output to keep (streaming or non-streaming).
"""
streaming_operator = True
metadata = ViewMetadata(
label=_("LLM Branch Join Operator"),
name="llm_branch_join_operator",
category=OperatorCategory.LLM,
operator_type=OperatorType.JOIN,
description=_("Just keep the first non-empty output."),
parameters=[],
inputs=[
IOField.build_from(
_("Streaming Model Output"),
"stream_output",
ModelOutput,
is_list=True,
description=_("The streaming output."),
),
IOField.build_from(
_("Non-Streaming Model Output"),
"not_stream_output",
ModelOutput,
description=_("The non-streaming output."),
),
],
outputs=[
IOField.build_from(
_("Model Output"),
"output_value",
ModelOutput,
is_list=True,
description=_("The output value of the operator."),
),
],
)
def __init__(self, **kwargs):
"""Create a new LLM branch join operator."""
super().__init__(**kwargs)
class StringBranchJoinOperator(BranchJoinOperator[str]):
"""The String Branch Join Operator.
Decide which output to keep (streaming or non-streaming).
"""
streaming_operator = True
metadata = ViewMetadata(
label=_("String Branch Join Operator"),
name="string_branch_join_operator",
category=OperatorCategory.COMMON,
operator_type=OperatorType.JOIN,
description=_("Just keep the first non-empty output."),
parameters=[],
inputs=[
IOField.build_from(
_("Streaming String Output"),
"stream_output",
str,
is_list=True,
description=_("The streaming output."),
),
IOField.build_from(
_("Non-Streaming String Output"),
"not_stream_output",
str,
description=_("The non-streaming output."),
),
],
outputs=[
IOField.build_from(
_("String Output"),
"output_value",
str,
is_list=True,
description=_("The output value of the operator."),
),
],
)
def __init__(self, **kwargs):
"""Create a new LLM branch join operator."""
super().__init__(**kwargs)
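Together with `LLMBranchOperator`, these joins close the stream/non-stream diamond so a flow exposes one output regardless of the request's stream flag. A rough construction sketch (the two LLM tasks are placeholders; import paths follow the `dbgpt.core.operators` exports touched at the end of this commit):

```python
from dbgpt.core.operators import LLMBranchJoinOperator, LLMBranchOperator

#               /-> streaming LLM task --\
#   request -> branch                     join -> single output
#               \-> non-streaming task --/
#
# Task names may now be omitted: the branch inspects its downstream
# nodes' streaming_operator flags, and the join keeps the first
# non-empty output (the side that actually ran).
branch = LLMBranchOperator()
join = LLMBranchJoinOperator()
```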

View File

@@ -8,6 +8,7 @@ from dbgpt.core.interface.operators.llm_operator import ( # noqa: F401
BaseLLM,
BaseLLMOperator,
BaseStreamingLLMOperator,
LLMBranchJoinOperator,
LLMBranchOperator,
RequestBuilderOperator,
)
@@ -32,6 +33,7 @@ __ALL__ = [
"BaseLLM",
"LLMBranchOperator",
"BaseLLMOperator",
"LLMBranchJoinOperator",
"RequestBuilderOperator",
"BaseStreamingLLMOperator",
"BaseConversationOperator",