mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 21:21:08 +00:00
fix(core): Fix AWEL branch bug (#1640)
This commit is contained in:
@@ -17,6 +17,7 @@ from .dag.base import DAG, DAGContext, DAGVar
|
||||
from .operators.base import BaseOperator, WorkflowRunner
|
||||
from .operators.common_operator import (
|
||||
BranchFunc,
|
||||
BranchJoinOperator,
|
||||
BranchOperator,
|
||||
BranchTaskType,
|
||||
InputOperator,
|
||||
@@ -78,6 +79,7 @@ __all__ = [
|
||||
"ReduceStreamOperator",
|
||||
"TriggerOperator",
|
||||
"MapOperator",
|
||||
"BranchJoinOperator",
|
||||
"BranchOperator",
|
||||
"InputOperator",
|
||||
"BranchFunc",
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""The mixin of DAGs."""
|
||||
|
||||
import abc
|
||||
import dataclasses
|
||||
import inspect
|
||||
@@ -337,6 +338,9 @@ class Parameter(TypeMetadata, Serializable):
|
||||
value: Optional[Any] = Field(
|
||||
None, description="The value of the parameter(Saved in the dag file)"
|
||||
)
|
||||
alias: Optional[List[str]] = Field(
|
||||
None, description="The alias of the parameter(Compatible with old version)"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -398,6 +402,7 @@ class Parameter(TypeMetadata, Serializable):
|
||||
description: Optional[str] = None,
|
||||
options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None,
|
||||
resource_type: ResourceType = ResourceType.INSTANCE,
|
||||
alias: Optional[List[str]] = None,
|
||||
):
|
||||
"""Build the parameter from the type."""
|
||||
type_name = type.__qualname__
|
||||
@@ -419,6 +424,7 @@ class Parameter(TypeMetadata, Serializable):
|
||||
placeholder=placeholder,
|
||||
description=description or label,
|
||||
options=options,
|
||||
alias=alias,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -452,7 +458,7 @@ class Parameter(TypeMetadata, Serializable):
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert current metadata to json dict."""
|
||||
dict_value = model_to_dict(self, exclude={"options"})
|
||||
dict_value = model_to_dict(self, exclude={"options", "alias"})
|
||||
if not self.options:
|
||||
dict_value["options"] = None
|
||||
elif isinstance(self.options, BaseDynamicOptions):
|
||||
@@ -677,9 +683,18 @@ class BaseMetadata(BaseResource):
|
||||
for parameter in self.parameters
|
||||
if not parameter.optional
|
||||
}
|
||||
current_parameters = {
|
||||
parameter.name: parameter for parameter in self.parameters
|
||||
}
|
||||
current_parameters = {}
|
||||
current_aliases_parameters = {}
|
||||
for parameter in self.parameters:
|
||||
current_parameters[parameter.name] = parameter
|
||||
if parameter.alias:
|
||||
for alias in parameter.alias:
|
||||
if alias in current_aliases_parameters:
|
||||
raise FlowMetadataException(
|
||||
f"Alias {alias} already exists in the metadata."
|
||||
)
|
||||
current_aliases_parameters[alias] = parameter
|
||||
|
||||
if len(view_required_parameters) < len(current_required_parameters):
|
||||
# TODO, skip the optional parameters.
|
||||
raise FlowParameterMetadataException(
|
||||
@@ -691,12 +706,16 @@ class BaseMetadata(BaseResource):
|
||||
)
|
||||
for view_param in view_parameters:
|
||||
view_param_key = view_param.name
|
||||
if view_param_key not in current_parameters:
|
||||
if view_param_key in current_parameters:
|
||||
current_parameter = current_parameters[view_param_key]
|
||||
elif view_param_key in current_aliases_parameters:
|
||||
current_parameter = current_aliases_parameters[view_param_key]
|
||||
else:
|
||||
raise FlowParameterMetadataException(
|
||||
f"Parameter {view_param_key} not in the metadata."
|
||||
)
|
||||
runnable_parameters.update(
|
||||
current_parameters[view_param_key].to_runnable_parameter(
|
||||
current_parameter.to_runnable_parameter(
|
||||
view_param.get_typed_value(), resources, key_to_resource_instance
|
||||
)
|
||||
)
|
||||
|
@@ -1,8 +1,29 @@
|
||||
"""Compatibility mapping for flow classes."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
_COMPAT_FLOW_MAPPING: Dict[str, str] = {}
|
||||
|
||||
@dataclass
class _RegisterItem:
    """Register item for compatibility mapping.

    One record describes where a class used to live and where it lives now.
    The two ``*_cls_key`` methods render the dotted ``module.ClassName`` keys.
    """

    # Dotted path of the module the class used to live in.
    old_module: str
    # Dotted path of the module the class lives in now.
    new_module: str
    # Class name inside ``old_module``.
    old_name: str
    # Class name inside ``new_module``; when unset the old name is reused
    # (the same rule ``_register`` applies before building an item).
    new_name: Optional[str] = None
    # Version marker stored with the mapping (``_register`` passes its
    # ``after_version`` argument here). It is not read by this class.
    after: Optional[str] = None

    def old_cls_key(self) -> str:
        """Get the old class key.

        Returns:
            str: ``"<old_module>.<old_name>"``.
        """
        return f"{self.old_module}.{self.old_name}"

    def new_cls_key(self) -> str:
        """Get the new class key.

        Returns:
            str: ``"<new_module>.<new_name>"``. If ``new_name`` is not set,
                ``old_name`` is used so the key never ends in ``".None"``.
        """
        return f"{self.new_module}.{self.new_name or self.old_name}"
|
||||
|
||||
|
||||
_COMPAT_FLOW_MAPPING: Dict[str, _RegisterItem] = {}
|
||||
|
||||
|
||||
_OLD_AGENT_RESOURCE_MODULE_1 = "dbgpt.serve.agent.team.layout.agent_operator_resource"
|
||||
@@ -11,17 +32,24 @@ _NEW_AGENT_RESOURCE_MODULE = "dbgpt.agent.core.plan.awel.agent_operator_resource
|
||||
|
||||
|
||||
def _register(
|
||||
old_module: str, new_module: str, old_name: str, new_name: Optional[str] = None
|
||||
old_module: str,
|
||||
new_module: str,
|
||||
old_name: str,
|
||||
new_name: Optional[str] = None,
|
||||
after_version: Optional[str] = None,
|
||||
):
|
||||
if not new_name:
|
||||
new_name = old_name
|
||||
_COMPAT_FLOW_MAPPING[f"{old_module}.{old_name}"] = f"{new_module}.{new_name}"
|
||||
item = _RegisterItem(old_module, new_module, old_name, new_name, after_version)
|
||||
_COMPAT_FLOW_MAPPING[item.old_cls_key()] = item
|
||||
|
||||
|
||||
def get_new_class_name(old_class_name: str) -> Optional[str]:
|
||||
"""Get the new class name for the old class name."""
|
||||
new_cls_name = _COMPAT_FLOW_MAPPING.get(old_class_name, None)
|
||||
return new_cls_name
|
||||
if old_class_name not in _COMPAT_FLOW_MAPPING:
|
||||
return None
|
||||
item = _COMPAT_FLOW_MAPPING[old_class_name]
|
||||
return item.new_cls_key()
|
||||
|
||||
|
||||
_register(
|
||||
@@ -54,3 +82,9 @@ _register(
|
||||
_register(
|
||||
_OLD_AGENT_RESOURCE_MODULE_2, _NEW_AGENT_RESOURCE_MODULE, "AWELAgent", "AWELAgent"
|
||||
)
|
||||
_register(
|
||||
"dbgpt.storage.vector_store.connector",
|
||||
"dbgpt.serve.rag.connector",
|
||||
"VectorStoreConnector",
|
||||
after_version="v0.5.8",
|
||||
)
|
||||
|
@@ -555,14 +555,6 @@ class FlowFactory:
|
||||
downstream = key_to_downstream.get(operator_key, [])
|
||||
if not downstream:
|
||||
raise ValueError("Branch operator should have downstream.")
|
||||
if len(downstream) != len(view_metadata.parameters):
|
||||
raise ValueError(
|
||||
"Branch operator should have the same number of downstream as "
|
||||
"parameters."
|
||||
)
|
||||
for i, param in enumerate(view_metadata.parameters):
|
||||
downstream_key, _, _ = downstream[i]
|
||||
param.value = key_to_operator_nodes[downstream_key].data.name
|
||||
|
||||
try:
|
||||
runnable_params = metadata.get_runnable_parameters(
|
||||
|
@@ -137,6 +137,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
task_name: Optional[str] = None,
|
||||
dag: Optional[DAG] = None,
|
||||
runner: Optional[WorkflowRunner] = None,
|
||||
can_skip_in_branch: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Create a BaseOperator with an optional workflow runner.
|
||||
@@ -157,6 +158,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
|
||||
self._runner: WorkflowRunner = runner
|
||||
self._dag_ctx: Optional[DAGContext] = None
|
||||
self._can_skip_in_branch = can_skip_in_branch
|
||||
|
||||
@property
|
||||
def current_dag_context(self) -> DAGContext:
|
||||
@@ -321,6 +323,10 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
|
||||
"""Get the current event loop task id."""
|
||||
return id(asyncio.current_task())
|
||||
|
||||
def can_skip_in_branch(self) -> bool:
    """Tell whether this operator may be skipped in a branch.

    Returns:
        bool: The ``can_skip_in_branch`` flag stored at construction time.
    """
    skippable = self._can_skip_in_branch
    return skippable
|
||||
|
||||
|
||||
def initialize_runner(runner: WorkflowRunner):
|
||||
"""Initialize the default runner."""
|
||||
|
@@ -16,6 +16,7 @@ from ..task.base import (
|
||||
ReduceFunc,
|
||||
TaskContext,
|
||||
TaskOutput,
|
||||
is_empty_data,
|
||||
)
|
||||
from .base import BaseOperator
|
||||
|
||||
@@ -28,13 +29,16 @@ class JoinOperator(BaseOperator, Generic[OUT]):
|
||||
This node type is useful for combining the outputs of upstream nodes.
|
||||
"""
|
||||
|
||||
def __init__(self, combine_function: JoinFunc, **kwargs):
|
||||
def __init__(
|
||||
self, combine_function: JoinFunc, can_skip_in_branch: bool = True, **kwargs
|
||||
):
|
||||
"""Create a JoinDAGNode with a combine function.
|
||||
|
||||
Args:
|
||||
combine_function: A function that defines how to combine inputs.
|
||||
can_skip_in_branch(bool): Whether the node can be skipped in a branch.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(can_skip_in_branch=can_skip_in_branch, **kwargs)
|
||||
if not callable(combine_function):
|
||||
raise ValueError("combine_function must be callable")
|
||||
self.combine_function = combine_function
|
||||
@@ -57,6 +61,12 @@ class JoinOperator(BaseOperator, Generic[OUT]):
|
||||
curr_task_ctx.set_task_output(join_output)
|
||||
return join_output
|
||||
|
||||
async def _return_first_non_empty(self, *inputs):
    """Pick the first input that ``is_empty_data`` does not report as empty.

    Args:
        *inputs: Candidate values, examined in the order given.

    Returns:
        The first value for which ``is_empty_data`` is false.

    Raises:
        ValueError: If every input is empty (or no input was given).
    """
    # A private sentinel lets us tell "nothing found" apart from any real
    # value, including None.
    nothing = object()
    chosen = next(
        (candidate for candidate in inputs if not is_empty_data(candidate)),
        nothing,
    )
    if chosen is nothing:
        raise ValueError("All inputs are empty")
    return chosen
|
||||
|
||||
|
||||
class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
|
||||
"""Operator that reduces inputs using a custom reduce function."""
|
||||
@@ -287,6 +297,32 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BranchJoinOperator(JoinOperator, Generic[OUT]):
    """Join operator with branch-oriented defaults.

    It behaves like ``JoinOperator`` but changes two defaults: when no combine
    function is supplied, ``_return_first_non_empty`` is used (the first
    non-empty upstream input is returned), and ``can_skip_in_branch``
    defaults to ``False``.
    """

    def __init__(
        self,
        combine_function: Optional[JoinFunc] = None,
        can_skip_in_branch: bool = False,
        **kwargs,
    ):
        """Create a BranchJoinOperator with an optional combine function.

        Args:
            combine_function: A function that defines how to combine inputs.
                If omitted (or falsy), ``self._return_first_non_empty`` is
                used instead.
            can_skip_in_branch(bool): Whether the node can be skipped in a
                branch (default False).
            **kwargs: Passed through unchanged to ``JoinOperator.__init__``.
        """
        super().__init__(
            combine_function=combine_function or self._return_first_non_empty,
            can_skip_in_branch=can_skip_in_branch,
            **kwargs,
        )
|
||||
|
||||
|
||||
class InputOperator(BaseOperator, Generic[OUT]):
|
||||
"""Operator node that reads data from an input source."""
|
||||
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
This runner will run the workflow in the current process.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
@@ -11,7 +12,7 @@ from dbgpt.component import SystemApp
|
||||
|
||||
from ..dag.base import DAGContext, DAGVar
|
||||
from ..operators.base import CALL_DATA, BaseOperator, WorkflowRunner
|
||||
from ..operators.common_operator import BranchOperator, JoinOperator
|
||||
from ..operators.common_operator import BranchOperator
|
||||
from ..task.base import SKIP_DATA, TaskContext, TaskState
|
||||
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
|
||||
from .job_manager import JobManager
|
||||
@@ -184,14 +185,14 @@ def _skip_current_downstream_by_node_name(
|
||||
return
|
||||
for child in branch_node.downstream:
|
||||
child = cast(BaseOperator, child)
|
||||
if child.node_name in skip_nodes:
|
||||
if child.node_name in skip_nodes or child.node_id in skip_node_ids:
|
||||
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
|
||||
_skip_downstream_by_id(child, skip_node_ids)
|
||||
|
||||
|
||||
def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
|
||||
if isinstance(node, JoinOperator):
|
||||
# Not skip join node
|
||||
if not node.can_skip_in_branch():
|
||||
# Current node can not be skipped: stop here, leaving it and its downstream unskipped
|
||||
return
|
||||
skip_node_ids.add(node.node_id)
|
||||
for child in node.downstream:
|
||||
|
@@ -130,7 +130,7 @@ async def test_branch_node(
|
||||
even_node = MapOperator(
|
||||
lambda x: 888, task_id="even_node", task_name="even_node_name"
|
||||
)
|
||||
join_node = JoinOperator(join_func)
|
||||
join_node = JoinOperator(join_func, can_skip_in_branch=False)
|
||||
branch_node = BranchOperator(
|
||||
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
|
||||
)
|
||||
|
@@ -2,12 +2,13 @@
|
||||
|
||||
import dataclasses
|
||||
from abc import ABC
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.core.awel import (
|
||||
BaseOperator,
|
||||
BranchFunc,
|
||||
BranchJoinOperator,
|
||||
BranchOperator,
|
||||
CommonLLMHttpRequestBody,
|
||||
CommonLLMHttpResponseBody,
|
||||
@@ -340,24 +341,7 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
category=OperatorCategory.LLM,
|
||||
operator_type=OperatorType.BRANCH,
|
||||
description=_("Branch the workflow based on the stream flag of the request."),
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
_("Streaming Task Name"),
|
||||
"stream_task_name",
|
||||
str,
|
||||
optional=True,
|
||||
default="streaming_llm_task",
|
||||
description=_("The name of the streaming task."),
|
||||
),
|
||||
Parameter.build_from(
|
||||
_("Non-Streaming Task Name"),
|
||||
"no_stream_task_name",
|
||||
str,
|
||||
optional=True,
|
||||
default="llm_task",
|
||||
description=_("The name of the non-streaming task."),
|
||||
),
|
||||
],
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
_("Model Request"),
|
||||
@@ -382,7 +366,12 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
stream_task_name: Optional[str] = None,
|
||||
no_stream_task_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a new LLM branch operator.
|
||||
|
||||
Args:
|
||||
@@ -390,18 +379,13 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
no_stream_task_name (str): The name of the non-streaming task.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
if not stream_task_name:
|
||||
raise ValueError("stream_task_name is not set")
|
||||
if not no_stream_task_name:
|
||||
raise ValueError("no_stream_task_name is not set")
|
||||
self._stream_task_name = stream_task_name
|
||||
self._no_stream_task_name = no_stream_task_name
|
||||
|
||||
async def branches(
|
||||
self,
|
||||
) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
|
||||
"""
|
||||
Return a dict of branch function and task name.
|
||||
"""Return a dict of branch function and task name.
|
||||
|
||||
Returns:
|
||||
Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task
|
||||
@@ -409,6 +393,18 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
If the predicate function returns True, we will run the corresponding
|
||||
task.
|
||||
"""
|
||||
if self._stream_task_name and self._no_stream_task_name:
|
||||
stream_task_name = self._stream_task_name
|
||||
no_stream_task_name = self._no_stream_task_name
|
||||
else:
|
||||
stream_task_name = ""
|
||||
no_stream_task_name = ""
|
||||
for node in self.downstream:
|
||||
task = cast(BaseOperator, node)
|
||||
if task.streaming_operator:
|
||||
stream_task_name = node.node_name
|
||||
else:
|
||||
no_stream_task_name = node.node_name
|
||||
|
||||
async def check_stream_true(r: ModelRequest) -> bool:
|
||||
# If stream is true, we will run the streaming task. otherwise, we will run
|
||||
@@ -416,8 +412,8 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
return r.stream
|
||||
|
||||
return {
|
||||
check_stream_true: self._stream_task_name,
|
||||
lambda x: not x.stream: self._no_stream_task_name,
|
||||
check_stream_true: stream_task_name,
|
||||
lambda x: not x.stream: no_stream_task_name,
|
||||
}
|
||||
|
||||
|
||||
@@ -553,3 +549,93 @@ class StringOutput2ModelOutputOperator(MapOperator[str, ModelOutput]):
|
||||
text=input_value,
|
||||
error_code=500,
|
||||
)
|
||||
|
||||
|
||||
class LLMBranchJoinOperator(BranchJoinOperator[ModelOutput]):
    """The LLM Branch Join Operator.

    Decide which output to keep (streaming or non-streaming).

    The class adds no logic of its own: it only declares view metadata and
    forwards all constructor arguments to ``BranchJoinOperator``.
    """

    # Class-level flag marking this operator as a streaming operator.
    streaming_operator = True
    # View metadata: two declared inputs (streaming / non-streaming model
    # output) and one declared output; no configurable parameters.
    metadata = ViewMetadata(
        label=_("LLM Branch Join Operator"),
        name="llm_branch_join_operator",
        category=OperatorCategory.LLM,
        operator_type=OperatorType.JOIN,
        description=_("Just keep the first non-empty output."),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Streaming Model Output"),
                "stream_output",
                ModelOutput,
                is_list=True,
                description=_("The streaming output."),
            ),
            IOField.build_from(
                _("Non-Streaming Model Output"),
                "not_stream_output",
                ModelOutput,
                description=_("The non-streaming output."),
            ),
        ],
        outputs=[
            IOField.build_from(
                _("Model Output"),
                "output_value",
                ModelOutput,
                is_list=True,
                description=_("The output value of the operator."),
            ),
        ],
    )

    def __init__(self, **kwargs):
        """Create a new LLM branch join operator.

        Args:
            **kwargs: Forwarded unchanged to ``BranchJoinOperator.__init__``.
        """
        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class StringBranchJoinOperator(BranchJoinOperator[str]):
    """The String Branch Join Operator.

    Decide which output to keep (streaming or non-streaming).

    The class adds no logic of its own: it only declares view metadata and
    forwards all constructor arguments to ``BranchJoinOperator``.
    """

    # Class-level flag marking this operator as a streaming operator.
    streaming_operator = True
    # View metadata: two declared inputs (streaming / non-streaming string
    # output) and one declared output; no configurable parameters.
    metadata = ViewMetadata(
        label=_("String Branch Join Operator"),
        name="string_branch_join_operator",
        category=OperatorCategory.COMMON,
        operator_type=OperatorType.JOIN,
        description=_("Just keep the first non-empty output."),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Streaming String Output"),
                "stream_output",
                str,
                is_list=True,
                description=_("The streaming output."),
            ),
            IOField.build_from(
                _("Non-Streaming String Output"),
                "not_stream_output",
                str,
                description=_("The non-streaming output."),
            ),
        ],
        outputs=[
            IOField.build_from(
                _("String Output"),
                "output_value",
                str,
                is_list=True,
                description=_("The output value of the operator."),
            ),
        ],
    )

    def __init__(self, **kwargs):
        """Create a new string branch join operator.

        Args:
            **kwargs: Forwarded unchanged to ``BranchJoinOperator.__init__``.
        """
        super().__init__(**kwargs)
|
||||
|
@@ -8,6 +8,7 @@ from dbgpt.core.interface.operators.llm_operator import ( # noqa: F401
|
||||
BaseLLM,
|
||||
BaseLLMOperator,
|
||||
BaseStreamingLLMOperator,
|
||||
LLMBranchJoinOperator,
|
||||
LLMBranchOperator,
|
||||
RequestBuilderOperator,
|
||||
)
|
||||
@@ -32,6 +33,7 @@ __ALL__ = [
|
||||
"BaseLLM",
|
||||
"LLMBranchOperator",
|
||||
"BaseLLMOperator",
|
||||
"LLMBranchJoinOperator",
|
||||
"RequestBuilderOperator",
|
||||
"BaseStreamingLLMOperator",
|
||||
"BaseConversationOperator",
|
||||
|
Reference in New Issue
Block a user