mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 13:10:29 +00:00
fix(core): Fix AWEL branch bug (#1640)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""The mixin of DAGs."""
|
||||
|
||||
import abc
|
||||
import dataclasses
|
||||
import inspect
|
||||
@@ -337,6 +338,9 @@ class Parameter(TypeMetadata, Serializable):
|
||||
value: Optional[Any] = Field(
|
||||
None, description="The value of the parameter(Saved in the dag file)"
|
||||
)
|
||||
alias: Optional[List[str]] = Field(
|
||||
None, description="The alias of the parameter(Compatible with old version)"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -398,6 +402,7 @@ class Parameter(TypeMetadata, Serializable):
|
||||
description: Optional[str] = None,
|
||||
options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None,
|
||||
resource_type: ResourceType = ResourceType.INSTANCE,
|
||||
alias: Optional[List[str]] = None,
|
||||
):
|
||||
"""Build the parameter from the type."""
|
||||
type_name = type.__qualname__
|
||||
@@ -419,6 +424,7 @@ class Parameter(TypeMetadata, Serializable):
|
||||
placeholder=placeholder,
|
||||
description=description or label,
|
||||
options=options,
|
||||
alias=alias,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -452,7 +458,7 @@ class Parameter(TypeMetadata, Serializable):
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert current metadata to json dict."""
|
||||
dict_value = model_to_dict(self, exclude={"options"})
|
||||
dict_value = model_to_dict(self, exclude={"options", "alias"})
|
||||
if not self.options:
|
||||
dict_value["options"] = None
|
||||
elif isinstance(self.options, BaseDynamicOptions):
|
||||
@@ -677,9 +683,18 @@ class BaseMetadata(BaseResource):
|
||||
for parameter in self.parameters
|
||||
if not parameter.optional
|
||||
}
|
||||
current_parameters = {
|
||||
parameter.name: parameter for parameter in self.parameters
|
||||
}
|
||||
current_parameters = {}
|
||||
current_aliases_parameters = {}
|
||||
for parameter in self.parameters:
|
||||
current_parameters[parameter.name] = parameter
|
||||
if parameter.alias:
|
||||
for alias in parameter.alias:
|
||||
if alias in current_aliases_parameters:
|
||||
raise FlowMetadataException(
|
||||
f"Alias {alias} already exists in the metadata."
|
||||
)
|
||||
current_aliases_parameters[alias] = parameter
|
||||
|
||||
if len(view_required_parameters) < len(current_required_parameters):
|
||||
# TODO, skip the optional parameters.
|
||||
raise FlowParameterMetadataException(
|
||||
@@ -691,12 +706,16 @@ class BaseMetadata(BaseResource):
|
||||
)
|
||||
for view_param in view_parameters:
|
||||
view_param_key = view_param.name
|
||||
if view_param_key not in current_parameters:
|
||||
if view_param_key in current_parameters:
|
||||
current_parameter = current_parameters[view_param_key]
|
||||
elif view_param_key in current_aliases_parameters:
|
||||
current_parameter = current_aliases_parameters[view_param_key]
|
||||
else:
|
||||
raise FlowParameterMetadataException(
|
||||
f"Parameter {view_param_key} not in the metadata."
|
||||
)
|
||||
runnable_parameters.update(
|
||||
current_parameters[view_param_key].to_runnable_parameter(
|
||||
current_parameter.to_runnable_parameter(
|
||||
view_param.get_typed_value(), resources, key_to_resource_instance
|
||||
)
|
||||
)
|
||||
|
@@ -1,8 +1,29 @@
|
||||
"""Compatibility mapping for flow classes."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
_COMPAT_FLOW_MAPPING: Dict[str, str] = {}
|
||||
|
||||
@dataclass
|
||||
class _RegisterItem:
|
||||
"""Register item for compatibility mapping."""
|
||||
|
||||
old_module: str
|
||||
new_module: str
|
||||
old_name: str
|
||||
new_name: Optional[str] = None
|
||||
after: Optional[str] = None
|
||||
|
||||
def old_cls_key(self) -> str:
|
||||
"""Get the old class key."""
|
||||
return f"{self.old_module}.{self.old_name}"
|
||||
|
||||
def new_cls_key(self) -> str:
|
||||
"""Get the new class key."""
|
||||
return f"{self.new_module}.{self.new_name}"
|
||||
|
||||
|
||||
_COMPAT_FLOW_MAPPING: Dict[str, _RegisterItem] = {}
|
||||
|
||||
|
||||
_OLD_AGENT_RESOURCE_MODULE_1 = "dbgpt.serve.agent.team.layout.agent_operator_resource"
|
||||
@@ -11,17 +32,24 @@ _NEW_AGENT_RESOURCE_MODULE = "dbgpt.agent.core.plan.awel.agent_operator_resource
|
||||
|
||||
|
||||
def _register(
|
||||
old_module: str, new_module: str, old_name: str, new_name: Optional[str] = None
|
||||
old_module: str,
|
||||
new_module: str,
|
||||
old_name: str,
|
||||
new_name: Optional[str] = None,
|
||||
after_version: Optional[str] = None,
|
||||
):
|
||||
if not new_name:
|
||||
new_name = old_name
|
||||
_COMPAT_FLOW_MAPPING[f"{old_module}.{old_name}"] = f"{new_module}.{new_name}"
|
||||
item = _RegisterItem(old_module, new_module, old_name, new_name, after_version)
|
||||
_COMPAT_FLOW_MAPPING[item.old_cls_key()] = item
|
||||
|
||||
|
||||
def get_new_class_name(old_class_name: str) -> Optional[str]:
|
||||
"""Get the new class name for the old class name."""
|
||||
new_cls_name = _COMPAT_FLOW_MAPPING.get(old_class_name, None)
|
||||
return new_cls_name
|
||||
if old_class_name not in _COMPAT_FLOW_MAPPING:
|
||||
return None
|
||||
item = _COMPAT_FLOW_MAPPING[old_class_name]
|
||||
return item.new_cls_key()
|
||||
|
||||
|
||||
_register(
|
||||
@@ -54,3 +82,9 @@ _register(
|
||||
_register(
|
||||
_OLD_AGENT_RESOURCE_MODULE_2, _NEW_AGENT_RESOURCE_MODULE, "AWELAgent", "AWELAgent"
|
||||
)
|
||||
_register(
|
||||
"dbgpt.storage.vector_store.connector",
|
||||
"dbgpt.serve.rag.connector",
|
||||
"VectorStoreConnector",
|
||||
after_version="v0.5.8",
|
||||
)
|
||||
|
@@ -555,14 +555,6 @@ class FlowFactory:
|
||||
downstream = key_to_downstream.get(operator_key, [])
|
||||
if not downstream:
|
||||
raise ValueError("Branch operator should have downstream.")
|
||||
if len(downstream) != len(view_metadata.parameters):
|
||||
raise ValueError(
|
||||
"Branch operator should have the same number of downstream as "
|
||||
"parameters."
|
||||
)
|
||||
for i, param in enumerate(view_metadata.parameters):
|
||||
downstream_key, _, _ = downstream[i]
|
||||
param.value = key_to_operator_nodes[downstream_key].data.name
|
||||
|
||||
try:
|
||||
runnable_params = metadata.get_runnable_parameters(
|
||||
|
Reference in New Issue
Block a user