fix(core): Fix AWEL branch bug (#1640)

This commit is contained in:
Fangyin Cheng
2024-06-18 11:11:43 +08:00
committed by GitHub
parent 49b56b4576
commit ace169ac46
32 changed files with 870 additions and 481 deletions

View File

@@ -1,4 +1,5 @@
"""The mixin of DAGs."""
import abc
import dataclasses
import inspect
@@ -337,6 +338,9 @@ class Parameter(TypeMetadata, Serializable):
value: Optional[Any] = Field(
None, description="The value of the parameter(Saved in the dag file)"
)
alias: Optional[List[str]] = Field(
None, description="The alias of the parameter(Compatible with old version)"
)
@model_validator(mode="before")
@classmethod
@@ -398,6 +402,7 @@ class Parameter(TypeMetadata, Serializable):
description: Optional[str] = None,
options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None,
resource_type: ResourceType = ResourceType.INSTANCE,
alias: Optional[List[str]] = None,
):
"""Build the parameter from the type."""
type_name = type.__qualname__
@@ -419,6 +424,7 @@ class Parameter(TypeMetadata, Serializable):
placeholder=placeholder,
description=description or label,
options=options,
alias=alias,
)
@classmethod
@@ -452,7 +458,7 @@ class Parameter(TypeMetadata, Serializable):
def to_dict(self) -> Dict:
"""Convert current metadata to json dict."""
dict_value = model_to_dict(self, exclude={"options"})
dict_value = model_to_dict(self, exclude={"options", "alias"})
if not self.options:
dict_value["options"] = None
elif isinstance(self.options, BaseDynamicOptions):
@@ -677,9 +683,18 @@ class BaseMetadata(BaseResource):
for parameter in self.parameters
if not parameter.optional
}
current_parameters = {
parameter.name: parameter for parameter in self.parameters
}
current_parameters = {}
current_aliases_parameters = {}
for parameter in self.parameters:
current_parameters[parameter.name] = parameter
if parameter.alias:
for alias in parameter.alias:
if alias in current_aliases_parameters:
raise FlowMetadataException(
f"Alias {alias} already exists in the metadata."
)
current_aliases_parameters[alias] = parameter
if len(view_required_parameters) < len(current_required_parameters):
# TODO, skip the optional parameters.
raise FlowParameterMetadataException(
@@ -691,12 +706,16 @@ class BaseMetadata(BaseResource):
)
for view_param in view_parameters:
view_param_key = view_param.name
if view_param_key not in current_parameters:
if view_param_key in current_parameters:
current_parameter = current_parameters[view_param_key]
elif view_param_key in current_aliases_parameters:
current_parameter = current_aliases_parameters[view_param_key]
else:
raise FlowParameterMetadataException(
f"Parameter {view_param_key} not in the metadata."
)
runnable_parameters.update(
current_parameters[view_param_key].to_runnable_parameter(
current_parameter.to_runnable_parameter(
view_param.get_typed_value(), resources, key_to_resource_instance
)
)

View File

@@ -1,8 +1,29 @@
"""Compatibility mapping for flow classes."""
from dataclasses import dataclass
from typing import Dict, Optional
_COMPAT_FLOW_MAPPING: Dict[str, str] = {}
@dataclass
class _RegisterItem:
"""Register item for compatibility mapping."""
old_module: str
new_module: str
old_name: str
new_name: Optional[str] = None
after: Optional[str] = None
def old_cls_key(self) -> str:
"""Get the old class key."""
return f"{self.old_module}.{self.old_name}"
def new_cls_key(self) -> str:
"""Get the new class key."""
return f"{self.new_module}.{self.new_name}"
_COMPAT_FLOW_MAPPING: Dict[str, _RegisterItem] = {}
_OLD_AGENT_RESOURCE_MODULE_1 = "dbgpt.serve.agent.team.layout.agent_operator_resource"
@@ -11,17 +32,24 @@ _NEW_AGENT_RESOURCE_MODULE = "dbgpt.agent.core.plan.awel.agent_operator_resource
def _register(
old_module: str, new_module: str, old_name: str, new_name: Optional[str] = None
old_module: str,
new_module: str,
old_name: str,
new_name: Optional[str] = None,
after_version: Optional[str] = None,
):
if not new_name:
new_name = old_name
_COMPAT_FLOW_MAPPING[f"{old_module}.{old_name}"] = f"{new_module}.{new_name}"
item = _RegisterItem(old_module, new_module, old_name, new_name, after_version)
_COMPAT_FLOW_MAPPING[item.old_cls_key()] = item
def get_new_class_name(old_class_name: str) -> Optional[str]:
"""Get the new class name for the old class name."""
new_cls_name = _COMPAT_FLOW_MAPPING.get(old_class_name, None)
return new_cls_name
if old_class_name not in _COMPAT_FLOW_MAPPING:
return None
item = _COMPAT_FLOW_MAPPING[old_class_name]
return item.new_cls_key()
_register(
@@ -54,3 +82,9 @@ _register(
_register(
_OLD_AGENT_RESOURCE_MODULE_2, _NEW_AGENT_RESOURCE_MODULE, "AWELAgent", "AWELAgent"
)
_register(
"dbgpt.storage.vector_store.connector",
"dbgpt.serve.rag.connector",
"VectorStoreConnector",
after_version="v0.5.8",
)

View File

@@ -555,14 +555,6 @@ class FlowFactory:
downstream = key_to_downstream.get(operator_key, [])
if not downstream:
raise ValueError("Branch operator should have downstream.")
if len(downstream) != len(view_metadata.parameters):
raise ValueError(
"Branch operator should have the same number of downstream as "
"parameters."
)
for i, param in enumerate(view_metadata.parameters):
downstream_key, _, _ = downstream[i]
param.value = key_to_operator_nodes[downstream_key].data.name
try:
runnable_params = metadata.get_runnable_parameters(